In [None]:
import random
import sys
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.datasets as datasets
from sklearn.cluster import KMeans

# Make the project's ../src package modules (modules, utils, trainer) importable.
sys.path.append('../src')

from modules import IDC
from utils import get_synthetic_dataset, plot_synthetic_dataset, clustering_accuracy
from trainer import idc_trainer, device

warnings.filterwarnings('ignore')

In [None]:
# Seed every RNG the notebook relies on, up front: Python's `random` (used for
# the subset shuffle), NumPy's global RNG, and torch (torch.manual_seed seeds
# the generators on all devices). Trailing `;` suppresses the Generator repr.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

In [None]:
# Load the MNIST training split and flatten every image into a float vector
# scaled to [0, 1] (raw pixels are uint8).
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
mnist_X = mnist_trainset.data.flatten(start_dim=1).float() / 255.
mnist_y = mnist_trainset.targets

# Class-balanced subset: the first `subset_size_per_class` samples of every
# label, in dataset order. `as_tuple=True` always yields a 1-D index tensor,
# unlike `.nonzero().squeeze()`, which collapses to 0-d on a single match.
subset_size_per_class = 1000
classes = torch.unique(mnist_y)
subset_indices = torch.cat([
    torch.nonzero(mnist_y == class_label, as_tuple=True)[0][:subset_size_per_class]
    for class_label in classes
]).tolist()

# Shuffle plain ints with the re-seeded global RNG; random.shuffle's
# permutation depends only on the list length and RNG state.
random.seed(SEED)
random.shuffle(subset_indices)
subset_order = torch.tensor(subset_indices)

mnist_X_subset = mnist_X[subset_order]
mnist_y_subset = mnist_y[subset_order]

# Sanity checks: aligned features/labels, and exactly the requested count per class.
assert len(mnist_X_subset) == len(mnist_y_subset)
assert (torch.bincount(mnist_y_subset) == subset_size_per_class).all(), "subset is not class-balanced"

In [None]:
# IDC architecture. The AE's first layer takes the flattened input, so its
# width is derived from the data instead of repeating the literal 784.
data_input_dim = mnist_X_subset.shape[1]
# Later cells reshape samples and model outputs to 28x28 images.
assert data_input_dim == 28 * 28, "expected flattened 28x28 MNIST images"

ae_layer_dims = [data_input_dim, 512, 512, 2048, 10]
gnn_hidden_dim = 784
cluster_hidden_dim = 2048
nb_classes = 10  # presumably the cluster head's output size; MNIST has 10 digit classes

idc = IDC(data_input_dim, ae_layer_dims, gnn_hidden_dim, cluster_hidden_dim, nb_classes)

In [None]:
# Stage-one settings for idc_trainer (AE and AE+GNN training; their loss
# histories come back under training_result["stage_one"]).
ae_gnn_config_train = dict(
    batch_size=256,
    lr=1e-3,
    local_gates_lmbd=100,
    reg_lmbd=10,
    epochs=150,
    end_pretrain_epoch=100,
)

# Stage-two settings for idc_trainer (cluster head and auxiliary head; their
# loss histories come back under training_result["stage_two"]).
clust_config_train = dict(
    batch_size=256,
    lr_cluster_head=1e-2,
    lr_aux=1e-1,
    lr_zg=1e-1,
    gamma=4,
    global_gates_lmbd=10,
    epochs=500,
    end_pretrain_epoch=100,
)

# Train the model on the subset with both stage configurations.
training_result = idc_trainer(
    idc,
    mnist_X_subset,
    ae_gnn_config_train,
    clust_config_train,
)

In [None]:
fig, ax = plt.subplots()
ax.plot(training_result["stage_one"]["ae_sparse_losses"])
# x is the position in the recorded loss history (per epoch or per batch,
# depending on how idc_trainer logs — TODO confirm).
ax.set(title="AE Sparse loss", xlabel="step", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(training_result["stage_one"]["ae_gnn_sparse_losses"])
# x is the position in the recorded loss history (per epoch or per batch,
# depending on how idc_trainer logs — TODO confirm).
ax.set(title="AE + GNN Sparse loss", xlabel="step", ylabel="loss")
plt.show()

In [None]:
idx = 5  # which sample of the subset to visualise

# Run in inference mode (disables dropout / train-time noise if IDC has any),
# then restore the previous mode so re-running the training cell is unaffected.
was_training = idc.training
idc.eval()
try:
    with torch.no_grad():
        # Feed the whole subset, not just sample `idx`: the GNN output for one
        # row may depend on the rest of the batch (TODO confirm in modules.IDC).
        X = mnist_X_subset.to(device)
        X_Z, _, _ = idc.gnn(X)
        X_z_hat = idc.ae(X_Z)  # AE output for the GNN output X_Z
        X_hat = idc.ae(X)      # AE output for the raw input X
finally:
    idc.train(was_training)

# Show the chosen sample at every stage side by side.
panels = [
    ("Input X", X),
    ("AE fine-tuned: output for X", X_hat),
    ("GNN output X_Z", X_Z),
    ("AE fine-tuned: output for X_Z", X_z_hat),
]
fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 3))
for ax, (title, images) in zip(axes, panels):
    ax.imshow(images[idx].cpu().reshape(28, 28), cmap="gray")
    ax.set_title(title)
    ax.axis("off")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(training_result["stage_two"]["clust_head_pretrain_losses"])
# x is the position in the recorded loss history (per epoch or per batch,
# depending on how idc_trainer logs — TODO confirm).
ax.set(title="Cluster Head pretrain loss", xlabel="step", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(training_result["stage_two"]["clust_head_finetune_losses"])
# x is the position in the recorded loss history (per epoch or per batch,
# depending on how idc_trainer logs — TODO confirm).
ax.set(title="Cluster Head finetuned loss", xlabel="step", ylabel="loss")
plt.show()

In [None]:
fig, ax = plt.subplots()
ax.plot(training_result["stage_two"]["aux_losses"])
# x is the position in the recorded loss history (per epoch or per batch,
# depending on how idc_trainer logs — TODO confirm).
ax.set(title="Aux loss", xlabel="step", ylabel="loss")
plt.show()

In [None]:
# Cluster assignments of the trained IDC model on the full subset, computed in
# inference mode; the previous train/eval mode is restored afterwards.
was_training = idc.training
idc.eval()
try:
    with torch.no_grad():
        X = mnist_X_subset.to(device)
        X_Z, _, _ = idc.gnn(X)
        H = idc.ae.encoder(X_Z)
        clust_logits, _, _ = idc.clusterNN(X_Z, H)
finally:
    idc.train(was_training)

# Hard assignment: arg-max over the cluster head's logits for each sample.
yhat_idc = clust_logits.argmax(dim=1).cpu()

print(clustering_accuracy(yhat_idc, mnist_y_subset))
# Cluster sizes — a collapsed solution shows up as a few very large clusters.
print(np.unique(yhat_idc.numpy(), return_counts=True))

In [None]:
# Baseline: k-means on raw pixels of the same subset, scored with the same metric.
# random_state and n_init are pinned so the result is reproducible across runs
# and scikit-learn versions (the n_init default changed to 'auto' in 1.4).
km = KMeans(n_clusters=10, n_init=10, random_state=SEED)
yhat_kmeans = km.fit_predict(mnist_X_subset.numpy())

clustering_accuracy(yhat_kmeans, mnist_y_subset)