In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import PIL 
from tqdm import tqdm
import matplotlib.pyplot as plt
# tsne and pca
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from DeepTaxonNet import DeepTaxonNet
import argparse
import utils

from sklearn.mixture import GaussianMixture
import os
import sys

In [3]:
train_loader, test_loader, train_set, test_set = utils.get_data_loader('cifar-10-eval', 128, False)



In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = 'cpu'

In [5]:
n_layers=10
# model = DeepTaxonNet(
#     n_layers=n_layers,
#     enc_hidden_dim=128*1*1,
#     dec_hidden_dim=(128,1,1),
#     input_dim=1*28*28,
#     latent_dim=10,
#     encoder_name='omniglot',
#     decoder_name='omniglot',
#     kl1_weight=1
# ).to(device)

## For CIFAR-10
model = DeepTaxonNet(
    n_layers=n_layers,
    enc_hidden_dim=512*1*1,
    dec_hidden_dim=(512,1,1),
    input_dim=3*32*32,
    latent_dim=64,
    encoder_name='resnet18',
    decoder_name='resnet18',
    kl1_weight=1
).to(device)

path = './models/'
path = '/nethome/zwang910/file_storage/nips-2025/deep-taxon/project-checkin/'
model_name = 'dtn-10-cifar10/deep_taxon_480.pt'
model.load_state_dict(torch.load(f'{path}{model_name}'), strict=False)

<All keys matched successfully>

# Accuracy

In [6]:
annotation = utils.label_annotation(model, train_loader, 10, device)
acc = utils.basic_node_evaluation(model, annotation, test_loader, device)
print('acc:', acc)



acc: 0.7421


# NMI

In [7]:
nmi = utils.compute_nmi(model, annotation, test_loader, device)
print(f"MNI: {nmi}")

MNI: 0.6298395308270804


# DP

In [10]:
dendrogram_purity = utils.soft_dendrogram_purity(model, test_loader, device)
print('dendrogram_purity:', dendrogram_purity)

Processing test data to get probability distributions (pcx)...


Evaluating Test Set: 100%|██████████| 40/40 [00:02<00:00, 13.92it/s]


Processed 10000 test samples. Found 10 classes and 2047 nodes.
Calculating node purities based on test set expected counts...
Node purities calculated.
Calculating Soft Dendrogram Purity (iterating over pairs)...


Processing Classes: 100%|██████████| 10/10 [07:41<00:00, 46.13s/it]

Calculation complete.
dendrogram_purity: 0.6083609662417684





# LP

In [9]:
overall_leaf_purity, per_leaf_purities = utils.leaf_purity(model, test_loader, device)
print('overall_leaf_purity:', overall_leaf_purity)
print('per_leaf_purities:', per_leaf_purities)



overall_leaf_purity: 0.7082329805834606
per_leaf_purities: {0: (0.3900905145452013, 2.6341689711425865e-05), 1: (0.3883206867556593, 2.0811201398438424e-05), 2: (0.34909211759623643, 6.050795960118756e-05), 3: (0.36274272690227066, 0.00012190487900403048), 4: (0.5084044821284158, 1.9739190855037122e-05), 5: (0.6039799358745291, 2.5170417988681546e-06), 6: (0.7496543362034338, 6.631616104530866e-05), 7: (0.4524243699302665, 4.932483288511019e-05), 8: (0.6583253022449825, 7.319118809277796e-06), 9: (0.48645169168325414, 1.7737364027776506e-05), 10: (0.6030373609290551, 1.638922947934914e-06), 11: (0.628584730293214, 7.02956247089188e-05), 12: (0.5347940844247707, 0.0002813685495513814), 13: (0.7698444509193378, 0.0009695223202456966), 14: (0.36772823333060906, 0.0006524857981222037), 15: (0.8730729670179036, 0.000995547705024841), 16: (0.546135155150578, 9.239336387396198e-05), 17: (0.3208015361519769, 0.0001131206943259317), 18: (0.38826958561480995, 6.553380169715615e-05), 19: (0.75751