In [None]:
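# Evaluation of the pretrained HuiduRep model: clustering ARI, silhouette score and Hungarian-matched accuracy / precision / recall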

In [None]:
# Automatically reload imported modules before each cell runs, so edits to the
# project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import tqdm

from utils.CMAES_utils import load_model
from utils.monitor_utils import *
from model.CMAES import CMAES
from utils.test_utils import generate_testset, visualize_tsne, visualize_umap
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
from scipy.optimize import linear_sum_assignment
from scipy import linalg
from sklearn.metrics import accuracy_score, precision_score, recall_score

import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
import torch
# Explicit stdlib imports, kept after the wildcard import so it cannot shadow them.
# `random` is used by the seeding and repeated-run cells; import it explicitly
# instead of relying on the wildcard import to provide it.
import random
import time

# Seed the global Python, PyTorch and NumPy RNGs so a fresh "Run All" is repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# These datasets are larger than 5 GB; because of storage limits we cannot distribute them. Load your own arrays into the variables below.
train_labels = None
test_labels = None
train_dataset = None
test_dataset = None

# Fail early with an actionable message instead of an opaque error from
# torch.tensor(None) when the placeholders above have not been filled in.
_missing = [name for name, value in (('train_dataset', train_dataset),
                                     ('test_dataset', test_dataset),
                                     ('test_labels', test_labels))
            if value is None]
if _missing:
    raise RuntimeError("Dataset placeholders are not set: " + ", ".join(_missing)
                       + ". Load the spike data into these variables before running the evaluation.")

# Convert to float tensors and swap the last two axes with permute(0, 2, 1).
# NOTE(review): presumably the raw arrays are (n_spikes, n_samples, n_channels)
# and the model expects channels before time — confirm against the data source.
train_dataset = torch.tensor(train_dataset, dtype=torch.float).permute(0, 2, 1)
test_dataset = torch.tensor(test_dataset, dtype=torch.float).permute(0, 2, 1)

print(test_dataset.shape)
print(train_dataset.shape)
print(np.unique(test_labels).shape)
print(np.unique(test_labels))
print(test_labels)

In [None]:
# Load the pretrained HuiduRep model.
# The constructor arguments must describe the same architecture as the saved
# checkpoint; presumably these are the settings HuiduRep.pt was trained with.
huidu = CMAES(T=0.2,
              mask_ratio=0.0,
              use_embedding=True,
              n_heads=4,
              m=0.99,
              use_avg_pool=True,
              K=4096,
              embedding_dim=64,
              ff_dim=128,
              num_layers=4,
              dropout=0.1,
              moco_v3=True).to('cuda')

# This notebook only evaluates, so switch to eval mode: otherwise dropout
# (dropout=0.1 above) stays active and perturbs every embedding. Module.eval()
# returns the module itself and is harmless if load_model already set eval mode.
# Assumes CMAES is a torch.nn.Module (it is used via .to() / named_parameters()).
huidu = load_model(huidu, './resources/checkpoint/HuiduRep.pt').to('cuda').eval()

In [None]:
# Count the parameters of the encoder that is used at inference time.
def count_named_parameters(model, name_filter, verbose=False):
    """Return the number of scalar parameters whose name matches ``name_filter``.

    Args:
        model: an object exposing ``named_parameters()`` (e.g. a torch.nn.Module).
        name_filter: a substring, or an iterable of substrings. A parameter is
            counted when its qualified name contains any of them.
        verbose: if True, print every parameter name while scanning.

    Returns:
        int: the summed ``numel()`` of all matching parameters.
    """
    # A bare string is a single pattern, not an iterable of characters.
    patterns = (name_filter,) if isinstance(name_filter, str) else tuple(name_filter)
    total = 0
    for name, param in model.named_parameters():
        if verbose:
            print(name)
        if any(pattern in name for pattern in patterns):
            total += param.numel()
    return total


# Sub-modules counted as "the encoder": the conv embedding plus the `_q` encoder
# and reduction head (presumably the MoCo query branch, given moco_v3=True).
# Parameters of any other sub-module are not counted.
ENCODER_PARAM_KEYS = ('conv_embedding', 'encoder_q', 'reduce_q')
print("Encoder params:", count_named_parameters(huidu, ENCODER_PARAM_KEYS))

In [None]:
# Randomly draw the evaluation subset: spikes from NUM_TEST_UNITS units of the test set.
NUM_TEST_UNITS = 10
test_data, test_units, labels = generate_testset(
    test_dataset,
    test_labels,
    num_units=NUM_TEST_UNITS,
)

In [None]:
# Single evaluation run on the sampled units: denoise the spikes with the model,
# then fit GMMs via gmm_monitor and score the clusterings (ARI).
# as_tensor avoids a redundant copy (and a warning) when this cell is re-run and
# test_units is already a tensor.
test_units = torch.as_tensor(test_units, dtype=torch.float).to('cpu')
test_data = test_data.to('cuda')

start = time.perf_counter()
# Inference only: no_grad skips building an autograd graph over the whole batch.
# The output was detached anyway, so its values are unchanged.
with torch.no_grad():
    test_data_denoise = huidu.denoise(test_data).cpu().detach()
# calculate ARI
res, gmm_test, test_spikes = gmm_monitor(huidu,
                                         None,
                                         test_units,
                                         test_data_denoise,
                                         test_units,
                                         labels,
                                         device='cuda',
                                         epochs=20,
                                         use_pca=False,
                                         use_scaler=False,
                                         covariance_type='full',
                                         use_iso=False,
                                         score=True,
                                         test_data_origin=test_data)
end = time.perf_counter()

print(f"elapsed: {end - start:.3f} s")
print(f"ARI  mean: {np.mean(res):.4f}  std: {np.std(res):.4f}  "
      f"max: {np.max(res):.4f}  min: {np.min(res):.4f}")

In [None]:
# Silhouette score of every clustering in gmm_test, computed on test_spikes
# (both returned by gmm_monitor above).
from sklearn.metrics import silhouette_score

# The iteration variable must not be the global `labels`: that name holds the
# output of generate_testset and is passed to gmm_monitor, so overwriting it here
# silently corrupts any later re-run of those cells.
scores = [silhouette_score(test_spikes, cluster_labels) for cluster_labels in gmm_test]

print(f"silhouette  mean: {np.mean(scores):.4f}  std: {np.std(scores):.4f}")

In [None]:
def match_labels(y_true, y_pred):
    """Relabel predicted cluster ids so they best agree with the true labels.

    Builds the contingency table between ``y_true`` and ``y_pred`` over the
    sorted union of their label values, then solves the assignment problem with
    the Hungarian algorithm (``linear_sum_assignment``). This gives the one-to-one
    mapping from predicted to true labels that maximises the number of samples
    on which they agree.

    Args:
        y_true: 1-D array-like of ground-truth labels.
        y_pred: 1-D array-like of predicted labels, same length as ``y_true``.
            Label values may be arbitrary (they need not be ``0..k-1``).

    Returns:
        np.ndarray with one entry per element of ``y_pred``: the label it was
        mapped to. Each label value receives at most one predicted cluster, so
        surplus clusters are mapped to values absent from ``y_true`` and their
        samples count as errors.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different lengths.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred must have the same length, got {len(y_true)} and {len(y_pred)}"
        )
    if y_pred.size == 0:
        return y_pred.copy()

    # Work with indices into the sorted union of label values, so arbitrary labels
    # (strings, non-contiguous ints) map onto rows/columns of the table.
    classes = np.union1d(y_true, y_pred)
    n_classes = len(classes)
    true_idx = np.searchsorted(classes, y_true)
    pred_idx = np.searchsorted(classes, y_pred)

    # Contingency table: cm[i, j] = number of samples with true label classes[i]
    # and predicted label classes[j].
    cm = np.bincount(true_idx * n_classes + pred_idx,
                     minlength=n_classes * n_classes).reshape(n_classes, n_classes)

    # Hungarian algorithm: maximum-agreement one-to-one matching.
    row_ind, col_ind = linear_sum_assignment(cm, maximize=True)

    # lookup[j] is the true label assigned to predicted label classes[j].
    lookup = np.empty(n_classes, dtype=classes.dtype)
    lookup[col_ind] = classes[row_ind]
    return lookup[pred_idx]

# Hungarian-matched classification metrics for every clustering in gmm_test.
# NOTE(review): y_true is encoded from the full `test_labels` array, whereas the
# clusterings come from the subset drawn by generate_testset. The metrics are only
# meaningful if both have the same length and sample order — confirm; otherwise
# encode the subset's per-sample ground truth instead.
total_acc = []
total_precision = []
total_recall = []

# The ground truth is the same for every clustering, so encode it only once.
y_true_encoded = LabelEncoder().fit_transform(test_labels)
for data_point in gmm_test:
    y_pred_encoded = LabelEncoder().fit_transform(data_point)
    y_pred_aligned = match_labels(y_pred=y_pred_encoded, y_true=y_true_encoded)

    # Classification metrics on the aligned labels
    acc = accuracy_score(y_true_encoded, y_pred_aligned)
    precision = precision_score(y_true_encoded, y_pred_aligned, average='macro')
    recall = recall_score(y_true_encoded, y_pred_aligned, average='macro')

    total_acc.append(acc)
    total_precision.append(precision)
    total_recall.append(recall)
    print(acc)

print(f"Accuracy: {np.mean(total_acc):.4f}")
print(f"Precision (macro): {np.mean(total_precision):.4f}")
print(f"Recall (macro): {np.mean(total_recall):.4f}")

In [None]:
# Repeat the evaluation N_RUNS times. Before each run, Python's `random` is
# re-seeded with the run index and a new set of units is drawn.
# `import tqdm` in the imports cell binds the *module*, which is not callable,
# so import the progress-bar function explicitly.
from tqdm.auto import tqdm

N_RUNS = 100
score = []
times = []
for i in tqdm(range(N_RUNS)):
    start = time.perf_counter()
    # NOTE(review): only Python's `random` is re-seeded. If generate_testset or the
    # GMM fits draw from numpy/torch RNGs, a single run cannot be reproduced from
    # its index alone — consider also calling np.random.seed(i) / torch.manual_seed(i).
    random.seed(i)
    test_data, test_units, labels = generate_testset(test_dataset, test_labels, num_units=10)
    # NOTE(review): unlike the single-run cell, this passes test_data directly
    # (no denoise / test_data_origin) and device='gpu' instead of 'cuda' —
    # confirm both are intended.
    res, gmm_test, test_spikes = gmm_monitor(huidu,
                                             None,
                                             None,
                                             test_data,
                                             test_units,
                                             labels,
                                             verbose=False,
                                             use_iso=False,
                                             score=True,
                                             max_iter=100,
                                             covariance_type='full',
                                             device='gpu',
                                             epochs=50)
    # Stop the clock before logging so the per-run time excludes console output.
    end = time.perf_counter()
    times.append(end - start)
    score.append(np.mean(res))
    # tqdm.write prints without breaking the progress bar.
    tqdm.write(f"running mean: {np.mean(score)}  labels: {labels}")

print(f"time  mean: {np.mean(times):.4f} s  std: {np.std(times):.4f} s")
print(f"score mean: {np.mean(score):.4f}  std: {np.std(score):.4f}  "
      f"max: {np.max(score):.4f}  min: {np.min(score):.4f}")

In [None]:
# Summary over all runs, with the standard error of the mean: the sample std
# (ddof=1) divided by the square root of the number of runs.
# The sample count is len(score) / len(times). `res` only holds the last run's
# results, so sqrt(len(res)) is not the number of runs.
score_sem = np.std(score, ddof=1) / np.sqrt(len(score))
times_sem = np.std(times, ddof=1) / np.sqrt(len(times))

print(f"score mean: {np.mean(score):.4f}  SEM: {score_sem:.4f}")
print(f"time  mean: {np.mean(times):.4f} s  SEM: {times_sem:.4f} s")
print(f"score max: {np.max(score):.4f}  min: {np.min(score):.4f}")