In [None]:
import os
import torch
import numpy as np
import yaml
from sklearn.metrics import accuracy_score, normalized_mutual_info_score
from sklearn.linear_model import LogisticRegression

from models.model import TreeVAE
from utils.model_utils import construct_tree_fromnpy
from utils.data_utils import get_data, get_gen
from utils.training_utils import predict
from utils.utils import cluster_acc

# Directory of the trained TreeVAE run to evaluate — replace with your own experiment path.
checkpoint_path = 'models/experiments/mnist/20231025-175819_d6be9'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): yaml.Loader can construct arbitrary Python objects, so only point this at
# checkpoints you trust. It is kept (rather than yaml.safe_load) in case the saved config
# contains Python-specific tags — TODO confirm whether safe_load would suffice.
config_file = os.path.join(checkpoint_path, "config.yaml")
with open(config_file, 'r') as config_stream:
    configs = yaml.load(config_stream, Loader=yaml.Loader)

In [None]:
# Build the architecture, restore the saved tree structure, and only THEN load the weights.
# model_weights.pt belongs to the model whose tree is described by data_tree.npy. Loading
# before construct_tree_fromnpy can fail strict key matching, or have the loaded parameters
# replaced if the rebuild creates new modules (presumably it grows/replaces nodes — verify
# in utils.model_utils). Loading after the rebuild is correct either way.
model = TreeVAE(**configs['training'])

# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load trusted checkpoints.
tree_structure = np.load(os.path.join(checkpoint_path, "data_tree.npy"), allow_pickle=True)
model = construct_tree_fromnpy(model, tree_structure, configs)

state_dict = torch.load(os.path.join(checkpoint_path, "model_weights.pt"), map_location=device)
model.load_state_dict(state_dict, strict=True)

model.to(device)
model.eval()
print("model loaded")

In [None]:
# Evaluation loaders use shuffle=False so that predictions produced later line up
# one-to-one with the label arrays extracted below.
trainset, trainset_eval, testset = get_data(configs)
gen_train_eval = get_gen(trainset_eval, configs, validation=True, shuffle=False)
gen_test = get_gen(testset, configs, validation=True, shuffle=False)


def _subset_targets(subset):
    """Return the ground-truth labels of a Subset-like dataset as a 1-D numpy array.

    Reads ``subset.dataset.targets`` at positions ``subset.indices``. ``np.asarray``
    accepts both CPU tensors and plain Python lists of targets. The explicit integer
    dtype keeps an empty index list usable for indexing.
    """
    targets = np.asarray(subset.dataset.targets)
    return targets[np.asarray(subset.indices, dtype=np.intp)]


y_train = _subset_targets(trainset_eval)
y_test = _subset_targets(testset)

# The previous "cluster accuracy" printed here compared the ground-truth labels with
# themselves, so it was always perfect and carried no information. The real clustering
# metrics are computed from model predictions in the next cell.
print(f"data loaded | Train: {len(y_train)}, Test: {len(y_test)}")

In [None]:
# Hard cluster assignments for the test split (second element returned by predict).
_, y_pred = predict(gen_test, model, device)
y_pred_np = y_pred.numpy()

# ACC via utils.utils.cluster_acc (presumably maps clusters to labels before scoring —
# see its definition); NMI from sklearn is invariant to label permutations.
acc = cluster_acc(y_test, y_pred_np, return_index=False)
nmi = normalized_mutual_info_score(y_test, y_pred_np)

for metric_name, metric_value in (("Clustering ACC", acc), ("NMI", nmi)):
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
def _deepest_bottom_up(loader):
    """Last entry of ``predict(..., 'bottom_up')`` for ``loader``, as a CPU numpy array.

    Presumably the deepest bottom-up representation — TODO confirm ordering in predict.
    """
    return predict(loader, model, device, 'bottom_up')[-1].cpu().numpy()


# Representations used as features for a logistic-regression linear probe.
z_train = _deepest_bottom_up(gen_train_eval)
z_test = _deepest_bottom_up(gen_test)

# Fit on the training split, evaluate on the held-out test split (fit returns the estimator).
clf = LogisticRegression(max_iter=1000).fit(z_train, y_train)
y_lp_pred = clf.predict(z_test)

lp_acc = accuracy_score(y_test, y_lp_pred)
print(f"Linear Probe Accuracy: {lp_acc:.4f}")

In [None]:
# Per-sample leaf probabilities on the test split. Each sample is assigned to its most
# probable leaf (argmax over axis 1). Only the test split is needed for the
# decision-path score, so no forward pass over the training set is made here.
prob_test = predict(gen_test, model, device, 'prob_leaves')
leaf_test = prob_test.argmax(axis=1)

def compute_dp_score(labels, leaves):
    """Decision-path agreement: fraction of same-label pairs that share a leaf.

    Over all unordered pairs ``(i, j)`` with ``labels[i] == labels[j]``, returns the
    share for which ``leaves[i] == leaves[j]`` as well. Pairs are counted from group
    sizes (a group of ``k`` items contains ``k * (k - 1) / 2`` pairs) instead of an
    explicit double loop. This makes the cost O(n log n) rather than O(n^2); the
    pairwise Python loop is very slow on test sets of thousands of samples.

    Args:
        labels: 1-D array-like of ground-truth labels (list, numpy array, CPU tensor).
        leaves: 1-D array-like of leaf assignments, same length as ``labels``.

    Returns:
        A float in [0, 1]; 0.0 when no two samples share a label.

    Raises:
        ValueError: if ``labels`` and ``leaves`` have different lengths.
    """
    labels = np.asarray(labels).ravel()
    leaves = np.asarray(leaves).ravel()
    if labels.shape != leaves.shape:
        raise ValueError(
            f"labels and leaves must have the same length, got {labels.size} and {leaves.size}"
        )
    if labels.size < 2:
        return 0.0

    def _n_pairs(group_sizes):
        # Number of unordered pairs inside groups of the given sizes.
        sizes = np.asarray(group_sizes, dtype=np.int64)
        return int((sizes * (sizes - 1) // 2).sum())

    # Dense integer codes let labels/leaves of any (sortable) dtype be combined below.
    _, label_codes = np.unique(labels, return_inverse=True)
    _, leaf_codes = np.unique(leaves, return_inverse=True)
    label_codes = label_codes.ravel()
    leaf_codes = leaf_codes.ravel()

    total = _n_pairs(np.bincount(label_codes))
    if total == 0:
        return 0.0

    # One code per (label, leaf) combination: equal codes <=> same label AND same leaf.
    joint_codes = label_codes * (int(leaf_codes.max()) + 1) + leaf_codes
    count = _n_pairs(np.unique(joint_codes, return_counts=True)[1])
    return count / total

# Fraction of same-label test pairs that end up in the same leaf.
dp_score = compute_dp_score(labels=y_test, leaves=leaf_test)
print("Decision Path Agreement: " + format(dp_score, ".4f"))

In [None]:
# Consolidated report of every metric computed above (labels padded for column alignment).
print("======== TreeVAE Final Evaluation ========")
summary_rows = (
    ("Clustering ACC      ", acc),
    ("NMI                 ", nmi),
    ("Linear Probe ACC    ", lp_acc),
    ("Decision Path Agree ", dp_score),
)
for row_label, row_value in summary_rows:
    print(f"{row_label}: {row_value:.4f}")