In [12]:
# Importing all necessary libraries
%load_ext autoreload
%autoreload 2

# standard library
import os
from collections import Counter, OrderedDict

# external packages
import torch
import torchvision
import numpy as np
import sklearn
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, confusion_matrix
from sklearn.decomposition import PCA
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

# util functions
from util.util import *

# dataset functions
from dataset import load_util

# models and evaluation helpers
from models.autoencoder.conv_ae import ConvAE
from models.simclr.simclr import *
from models.simclr.transforms import *
from models.rotnet.rotnet import *
from models.rotnet.IDEC import *
from models.rotnet.custom_stl10 import *
from cluster_accuracy import cluster_accuracy

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
%%javascript
// Replace the output area's `_should_scroll` check with a function that always
// returns false, whatever the number of output lines.
// NOTE(review): presumably this keeps long outputs from being placed in a
// scrollable box; it relies on the global `IPython.OutputArea` object, which is
// presumably only available in the classic Notebook front end — confirm.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [14]:
# Report the library versions this notebook is running with.
print("Versions:")
print(f"torch: {torch.__version__}")
print(f"torchvision: {torchvision.__version__}")
print(f"numpy: {np.__version__}")  # stray trailing comma inside print(...) removed
print(f"scikit-learn: {sklearn.__version__}")

# `detect_device` is not defined in this notebook; presumably it is provided by
# the `from util.util import *` star import — confirm.
device = detect_device()
print("Using device: ", device)

Versions:
torch: 1.8.1+cu111
torchvision: 0.9.1+cu111
numpy: 1.19.5
scikit-learn: 0.24.1
Using device:  cuda


## Preparation

In [15]:
# specify learning params
# NOTE(review): none of the values in this cell is read by a cell visible in
# this notebook (the test DataLoader hard-codes its own batch size of 128) —
# confirm whether they are still needed.
batch_size = 128
learning_rate = 0.1
epochs = 100

# training

# NOTE(review): presumably a leftover switch from a training notebook — confirm.
train = True

In [16]:
# Build the evaluation dataset via the project helper `load_util.load_custom_cifar`
# (called with train=False) and wrap it in a DataLoader.
# NOTE(review): the literal 128 duplicates `batch_size` from the parameter cell.
# NOTE(review): drop_last=True discards the final incomplete batch, and with
# shuffle=True the samples that land in that discarded batch change on every pass,
# so repeated evaluations do not see exactly the same samples — confirm this is
# intended for a test set.
test_data = load_util.load_custom_cifar('./data', download=False, train=False, data_percent=1.0, for_model='SimCLR', transforms=False)
testloader = torch.utils.data.DataLoader(test_data,
                                          batch_size=128,
                                          shuffle=True,
                                          drop_last=True)

In [17]:
# Map each integer class index to its class name, in the order of `test_data.classes`.
colors_classes = dict(enumerate(test_data.classes))

In [18]:
def get_nmis_and_cas_10_runs(model_name, colors_classes, device,
                             n_models=10, n_runs=10, dataloader=None):
    """Evaluate the checkpoints ``{model_name}_{i}.pth`` repeatedly.

    For every checkpoint index ``i`` the model is loaded once with
    ``load_model`` and ``compute_nmi_and_pca`` is then called ``n_runs`` times.
    Each call yields labels, an object named ``kmeans`` and an NMI value; the
    cluster accuracy is computed by ``cluster_accuracy`` from the labels and
    ``kmeans.labels_``. Progress is printed after every run.

    Args:
        model_name: Checkpoint name prefix; ``_{i}.pth`` is appended to it.
        colors_classes: Forwarded unchanged to ``compute_nmi_and_pca``.
        device: Forwarded to ``load_model`` and ``compute_nmi_and_pca``.
        n_models: Number of checkpoint indices to evaluate (``0 .. n_models-1``).
            Defaults to 10, the previously hard-coded value.
        n_runs: Number of evaluations per checkpoint. Defaults to 10, the
            previously hard-coded value.
        dataloader: Loader handed to ``compute_nmi_and_pca``. When ``None``
            (default) the notebook-level ``testloader`` is used, as before.

    Returns:
        Tuple ``(nmis, cas)`` of dicts mapping the checkpoint index to the list
        of NMI values / cluster accuracies, one entry per run.
    """
    if dataloader is None:
        # Backward-compatible default: fall back to the notebook global.
        dataloader = testloader

    nmis = {}
    cas = {}

    for i in range(n_models):
        name = f'{model_name}_{i}.pth'
        print(name)

        # The random (4, 12288) tensor is passed to `load_model` as its third
        # argument; presumably a dummy input batch — TODO confirm.
        model = load_model(name, device, torch.rand(size=(4, 12288)))
        nmis[i] = []
        cas[i] = []

        for k in range(n_runs):
            labels, kmeans, nmi, _, _ = compute_nmi_and_pca(model, name, colors_classes, device, dataloader)
            nmis[i].append(nmi)

            ca = cluster_accuracy(labels, kmeans.labels_)
            cas[i].append(ca)
            print(f'Run: {k}\n{name}:\nNMI:{nmi}\nCA:{ca}\n')
            # Drop the per-run objects before the next evaluation starts.
            del labels
            del kmeans
    return nmis, cas

In [19]:
def get_nmis_and_cas_10_runs_1_model(model_name, colors_classes, device,
                                     n_runs=5, dataloader=None):
    """Evaluate a single checkpoint repeatedly and collect NMI and cluster accuracy.

    The model is loaded once with ``load_model``; ``compute_nmi_and_pca`` is
    then called ``n_runs`` times and the cluster accuracy is computed by
    ``cluster_accuracy`` from the returned labels and ``kmeans.labels_``.
    Progress is printed after every run.

    Note: despite the "10_runs" in the function name, the original
    implementation performed 5 runs; that remains the default so existing
    callers get the same result.

    Args:
        model_name: Checkpoint name passed as-is to ``load_model``.
        colors_classes: Forwarded unchanged to ``compute_nmi_and_pca``.
        device: Forwarded to ``load_model`` and ``compute_nmi_and_pca``.
        n_runs: Number of evaluations to perform (default 5).
        dataloader: Loader handed to ``compute_nmi_and_pca``. When ``None``
            (default) the notebook-level ``testloader`` is used, as before.

    Returns:
        Tuple ``(nmis, cas)`` of lists with one NMI value / cluster accuracy
        per run.
    """
    if dataloader is None:
        # Backward-compatible default: fall back to the notebook global.
        dataloader = testloader

    print(model_name)

    # The random (4, 12288) tensor is passed to `load_model` as its third
    # argument; presumably a dummy input batch — TODO confirm.
    model = load_model(model_name, device, torch.rand(size=(4, 12288)))
    nmis = []
    cas = []

    for k in range(n_runs):
        labels, kmeans, nmi, _, _ = compute_nmi_and_pca(model, model_name, colors_classes, device, dataloader)
        nmis.append(nmi)

        ca = cluster_accuracy(labels, kmeans.labels_)
        cas.append(ca)
        print(f'Run: {k}\n{model_name}:\nNMI:{nmi}\nCA:{ca}\n')
        # Drop the per-run objects before the next evaluation starts.
        del labels
        del kmeans
    return nmis, cas

In [20]:
def print_total_mean_and_std(metrics_name, df):
    """Print and return the mean and std of the per-column means of ``df``.

    Each column of ``df`` is first averaged; the returned values are the mean
    and the standard deviation (pandas' default ``Series.std()``) of those
    column averages, i.e. the spread between columns, not between single rows.

    Args:
        metrics_name: Label used as prefix in the two printed lines.
        df: DataFrame of metric values.

    Returns:
        Tuple ``(mean, std)`` of the column averages.
    """
    column_means = df.mean()
    grand_mean, spread = column_means.mean(), column_means.std()

    print(f'{metrics_name} mean: {grand_mean:.4f}',
          f'{metrics_name} std: {spread:.4f}',
          sep='\n')
    return grand_mean, spread

## Overall comparison

In [80]:
## df_final_nmi = pd.DataFrame(columns=['Model', 'CIFAR NMI', 'CIFAR NMI STD', 'STL10 NMI', 'STL10 NMI STD'])
## df_final_ca = pd.DataFrame(columns=['Model', 'CIFAR CA', 'CIFAR CA STD', 'STL10 CA', 'STL10 CA STD'])

In [94]:
## df_final_nmi.to_csv('trained_models/nmis_all.csv')
## df_final_ca.to_csv('trained_models/cas_all.csv')

In [59]:
# Load the saved summary tables from CSV (first CSV column -> index, first row -> header).
# The cells that would create and write these files are commented out above.
df_final_nmi = pd.read_csv('trained_models/nmis_all.csv', index_col=0, header=0)
df_final_ca = pd.read_csv('trained_models/cas_all.csv', index_col=0, header=0)

In [172]:
# Display the cluster-accuracy summary table.
# NOTE(review): the output recorded for this cell differs from the table shown at
# the end of the notebook in the 'SimCLR + IDEC' (STL10) and 'SimCLR + 100' rows.
df_final_ca

Unnamed: 0,Model,CIFAR CA,CIFAR CA STD,STL10 CA,STL10 CA STD
0,RotNet,0.43565,0.018361,0.364128,0.014185
1,RotNet + DEC,0.206545,0.002399,0.196739,0.029117
2,SimCLR,0.590364,0.023882,0.411395,0.022278
3,SimCLR + IDEC,0.657623,0.047969,0.445916,0.015449
4,SimCLR + 100,0.5815,0.0017,0.3937,0.006011


In [173]:
# Display the NMI summary table.
df_final_nmi

Unnamed: 0,Model,CIFAR NMI,CIFAR NMI STD,STL10 NMI,STL10 NMI STD
0,RotNet,0.350597,0.006679,0.335731,0.006869
1,RotNet + DEC,0.104224,0.003862,0.129233,0.047574
2,SimCLR,0.472371,0.009836,0.35022,0.005799
3,SimCLR + IDEC,0.556923,0.020652,0.388456,0.007139
4,SimCLR + 100,0.4924,0.0033,0.351471,0.003135


## RotNet

In [29]:
# Summary rows laid out as [model name, CIFAR mean, CIFAR std, STL10 mean, STL10 std].
# The four numeric slots start at 0 and are filled in by the cells below.
rotnet_list_nmi = ['RotNet'] + [0] * 4
rotnet_dec_list_nmi = ['RotNet + DEC'] + [0] * 4
rotnet_list_ca = ['RotNet'] + [0] * 4
rotnet_dec_list_ca = ['RotNet + DEC'] + [0] * 4

### CIFAR: Pretraining

In [4]:
# CSV files with the per-run NMI / cluster-accuracy values for RotNet pretraining on CIFAR.
# NOTE(review): `path_nmis` / `path_cas` are reassigned in every section of this
# notebook, so the read cell that follows loads whichever paths were assigned last.
path_nmis = 'trained_models/RotNet/CIFAR/nmis.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [34]:
# NOTE(review): the model path below (like the other commented-out
# get_nmis_and_cas_10_runs calls in this notebook) uses backslashes. They act as
# path separators on Windows only, and in a normal string literal `\C` and `\p`
# are unrecognized escape sequences; prefer forward slashes or os.path.join.
# rotnet_nmis_cifar, rotnet_cas_cifar = get_nmis_and_cas_10_runs('RotNet\CIFAR\pretrained_RotNet', colors_classes, device)

In [35]:
# df_rotnet_nmis_cifar = pd.DataFrame(rotnet_nmis_cifar)
# df_rotnet_cas_cifar = pd.DataFrame(rotnet_cas_cifar)

In [36]:
# df_rotnet_nmis_cifar.to_csv(path_nmis)
# df_rotnet_cas_cifar.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [10]:
# Load the saved RotNet / CIFAR results; the first CSV column becomes the index.
df_rotnet_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [21]:
# Fill the CIFAR mean/std slots (1 and 2) of the RotNet NMI summary row.
rotnet_list_nmi[1], rotnet_list_nmi[2] = print_total_mean_and_std('NMI', df_rotnet_nmis_cifar)

NMI mean: 0.3506
NMI std: 0.0067


In [22]:
# Fill the CIFAR mean/std slots (1 and 2) of the RotNet cluster-accuracy summary row.
rotnet_list_ca[1], rotnet_list_ca[2] = print_total_mean_and_std('Cluster accuracy', df_rotnet_cas_cifar)

Cluster accuracy mean: 0.4357
Cluster accuracy std: 0.0184


### CIFAR: DEC

In [23]:
# CSV files with the per-run NMI / cluster-accuracy values for RotNet + DEC on CIFAR.
path_nmis = 'trained_models/RotNet/CIFAR/nmis_dec.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar_dec, rotnet_cas_cifar_dec = get_nmis_and_cas_10_runs('RotNet\CIFAR\DEC_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar_dec = pd.DataFrame(rotnet_nmis_cifar_dec)
# df_rotnet_cas_cifar_dec = pd.DataFrame(rotnet_cas_cifar_dec)

In [105]:
# df_rotnet_nmis_cifar_dec.to_csv(path_nmis)
# df_rotnet_cas_cifar_dec.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [24]:
# Load the saved RotNet + DEC / CIFAR results; the first CSV column becomes the index.
df_rotnet_nmis_cifar_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar_dec = pd.read_csv(path_cas, index_col=0)

In [25]:
# Fill the CIFAR mean/std slots (1 and 2) of the RotNet + DEC NMI summary row.
rotnet_dec_list_nmi[1], rotnet_dec_list_nmi[2] = print_total_mean_and_std('NMI', df_rotnet_nmis_cifar_dec)

NMI mean: 0.1042
NMI std: 0.0039


In [26]:
# Fill the CIFAR mean/std slots (1 and 2) of the RotNet + DEC cluster-accuracy summary row.
rotnet_dec_list_ca[1], rotnet_dec_list_ca[2] = print_total_mean_and_std('Cluster accuracy', df_rotnet_cas_cifar_dec)

Cluster accuracy mean: 0.2065
Cluster accuracy std: 0.0024


### STL10: Pretraining

In [120]:
# CSV files with the per-run NMI / cluster-accuracy values for RotNet pretraining on STL10.
path_nmis = 'trained_models/RotNet/STL10/nmis.csv'
path_cas = 'trained_models/RotNet/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_stl10, rotnet_cas_stl10 = get_nmis_and_cas_10_runs('RotNet\STL10\pretrained_RotNet_STL10', colors_classes, device)

In [None]:
# df_rotnet_nmis_stl10 = pd.DataFrame(rotnet_nmis_stl10)
# df_rotnet_cas_stl10 = pd.DataFrame(rotnet_cas_stl10)

In [149]:
# df_rotnet_nmis_stl10.to_csv(path_nmis)
# df_rotnet_cas_stl10.to_csv(path_cas)

In [121]:
# Load the saved RotNet / STL10 results; the first CSV column becomes the index.
df_rotnet_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10 = pd.read_csv(path_cas, index_col=0)

In [122]:
# Fill the STL10 mean/std slots (3 and 4) of the RotNet NMI summary row.
rotnet_list_nmi[3], rotnet_list_nmi[4] = print_total_mean_and_std('NMI', df_rotnet_nmis_stl10)

NMI mean: 0.3357
NMI std: 0.0069


In [123]:
# Fill the STL10 mean/std slots (3 and 4) of the RotNet cluster-accuracy summary row.
rotnet_list_ca[3], rotnet_list_ca[4] = print_total_mean_and_std('Cluster accuracy', df_rotnet_cas_stl10)

Cluster accuracy mean: 0.3641
Cluster accuracy std: 0.0142


### STL10: DEC

In [124]:
# CSV files with the per-run NMI / cluster-accuracy values for RotNet + DEC on STL10.
path_nmis = 'trained_models/RotNet/STL10/nmis_dec.csv'
path_cas = 'trained_models/RotNet/STL10/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_stl10_dec, rotnet_cas_stl10_dec = get_nmis_and_cas_10_runs('RotNet\STL10\DEC_RotNet_STL10', colors_classes, device)

In [17]:
# df_rotnet_nmis_stl10_dec = pd.DataFrame(rotnet_nmis_stl10_dec)
# df_rotnet_cas_stl10_dec = pd.DataFrame(rotnet_cas_stl10_dec)

In [18]:
# df_rotnet_nmis_stl10_dec.to_csv(path_nmis)
# df_rotnet_cas_stl10_dec.to_csv(path_cas)

In [125]:
# Load the saved RotNet + DEC / STL10 results; the first CSV column becomes the index.
df_rotnet_nmis_stl10_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10_dec = pd.read_csv(path_cas, index_col=0)

In [126]:
# Fill the STL10 mean/std slots (3 and 4) of the RotNet + DEC NMI summary row.
rotnet_dec_list_nmi[3], rotnet_dec_list_nmi[4] = print_total_mean_and_std('NMI', df_rotnet_nmis_stl10_dec)

NMI mean: 0.1292
NMI std: 0.0476


In [127]:
# Fill the STL10 mean/std slots (3 and 4) of the RotNet + DEC cluster-accuracy summary row.
rotnet_dec_list_ca[3], rotnet_dec_list_ca[4] = print_total_mean_and_std('Cluster accuracy', df_rotnet_cas_stl10_dec)

Cluster accuracy mean: 0.1967
Cluster accuracy std: 0.0291


In [156]:
# Write the RotNet rows into the summary tables: row 0 = RotNet, row 1 = RotNet + DEC.
df_final_nmi.loc[0], df_final_nmi.loc[1] = rotnet_list_nmi, rotnet_dec_list_nmi
df_final_ca.loc[0], df_final_ca.loc[1] = rotnet_list_ca, rotnet_dec_list_ca

## SimCLR

In [33]:
# Summary rows laid out as [model name, CIFAR mean, CIFAR std, STL10 mean, STL10 std].
# The four numeric slots start at 0 and are filled in by the cells below.
simclr_list_nmi = ['SimCLR'] + [0] * 4
simclr_idec_list_nmi = ['SimCLR + IDEC'] + [0] * 4
simclr_list_ca = ['SimCLR'] + [0] * 4
simclr_idec_list_ca = ['SimCLR + IDEC'] + [0] * 4
simclr_list_100_nmi = ['SimCLR + 100'] + [0] * 4
simclr_list_100_ca = ['SimCLR + 100'] + [0] * 4

### CIFAR: Pretraining

In [34]:
# CSV files with the per-run NMI / cluster-accuracy values for SimCLR pretraining on CIFAR.
path_nmis = 'trained_models/SimCLR/CIFAR/nmis.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar, simclr_cas_cifar = get_nmis_and_cas_10_runs('SimCLR\CIFAR\pretrained_SimCLR', colors_classes, device)

In [24]:
# df_simclr_nmis_cifar = pd.DataFrame(simclr_nmis_cifar)
# df_simclr_cas_cifar = pd.DataFrame(simclr_cas_cifar)

In [134]:
# df_simclr_nmis_cifar.to_csv(path_nmis)
# df_simclr_cas_cifar.to_csv(path_cas)

In [35]:
# Load the saved SimCLR / CIFAR results; the first CSV column becomes the index.
df_simclr_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [51]:
# Display the raw SimCLR / CIFAR cluster-accuracy table.
df_simclr_cas_cifar

Unnamed: 0,Run 0,Run 1,Run 2,Run 3,Run 4,Run 5,Run 6,Run 7,Run 8,Run 9
0,0.610377,0.610978,0.588642,0.563502,0.599459,0.597356,0.625,0.566907,0.622796,0.5622
1,0.609776,0.634415,0.58734,0.55609,0.599259,0.595653,0.624099,0.56871,0.576623,0.558994
2,0.608774,0.634615,0.590845,0.558193,0.603866,0.596955,0.626502,0.557893,0.554487,0.563001
3,0.610777,0.634615,0.591947,0.561398,0.596855,0.593249,0.623297,0.566506,0.578325,0.555489
4,0.611078,0.639022,0.589643,0.558894,0.599359,0.591647,0.620693,0.561799,0.573818,0.557091
5,0.611378,0.633514,0.589042,0.559195,0.59385,0.597556,0.619191,0.5626,0.580629,0.567808
6,0.606871,0.606971,0.492688,0.584034,0.592849,0.592548,0.616286,0.561098,0.585036,0.59385
7,0.60647,0.636218,0.583333,0.555389,0.586939,0.595353,0.6251,0.554788,0.577123,0.565905
8,0.607672,0.604167,0.588642,0.553285,0.59365,0.588041,0.624399,0.588942,0.578425,0.568309
9,0.610276,0.638522,0.58143,0.5624,0.592949,0.589744,0.623898,0.557792,0.621294,0.560096


In [54]:
# Std of the per-column means, i.e. the value print_total_mean_and_std reports as 'std'.
# NOTE(review): the output recorded for this cell (0.5903...) equals the mean that
# print_total_mean_and_std prints for this frame (0.5904), not its std (0.0239), so
# the stored output is presumably left over from an earlier version of this
# expression — re-run the cell to refresh it.
df_simclr_cas_cifar.mean().std()

0.5903635817307691

In [36]:
# Fill the CIFAR mean/std slots (1 and 2) of the SimCLR NMI summary row.
simclr_list_nmi[1], simclr_list_nmi[2] = print_total_mean_and_std('NMI', df_simclr_nmis_cifar)

NMI mean: 0.4724
NMI std: 0.0098


In [37]:
# Fill the CIFAR mean/std slots (1 and 2) of the SimCLR cluster-accuracy summary row.
# The printed label previously read 'Cluser accuracy'; spelling fixed to match the other cells.
simclr_list_ca[1], simclr_list_ca[2] = print_total_mean_and_std('Cluster accuracy', df_simclr_cas_cifar)

Cluser accuracy mean: 0.5904
Cluser accuracy std: 0.0239


### CIFAR: Pretraining + 100

In [85]:
# CSV files with the per-run NMI / cluster-accuracy values for the 'SimCLR + 100' variant on CIFAR.
path_nmis = 'trained_models/SimCLR/CIFAR/nmis_100.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas_100.csv'

In [86]:
# Load the saved 'SimCLR + 100' / CIFAR results; the first CSV column becomes the index.
# NOTE(review): the `_dec` suffix in these variable names is misleading — the files
# loaded here are the `*_100.csv` results, not DEC/IDEC results.
df_simclr_nmis_cifar_dec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar_dec = pd.read_csv(path_cas, index_col=0)

In [89]:
# Fill the CIFAR mean/std slots (1 and 2) of the 'SimCLR + 100' NMI summary row.
simclr_list_100_nmi[1], simclr_list_100_nmi[2] = print_total_mean_and_std('NMI', df_simclr_nmis_cifar_dec)

NMI mean: 0.4905
NMI std: 0.0105


In [90]:
# Fill the CIFAR mean/std slots (1 and 2) of the 'SimCLR + 100' cluster-accuracy summary row.
simclr_list_100_ca[1], simclr_list_100_ca[2] = print_total_mean_and_std('Cluster accuracy', df_simclr_cas_cifar_dec)

Cluster accuracy mean: 0.6106
Cluster accuracy std: 0.0266


### CIFAR: IDEC

In [42]:
# CSV files with the per-run NMI / cluster-accuracy values for SimCLR + IDEC on CIFAR.
path_nmis = 'trained_models/SimCLR/CIFAR/nmis_idec.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas_idec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar_idec, simclr_cas_cifar_idec = get_nmis_and_cas_10_runs('SimCLR\CIFAR\IDEC_SimCLR', colors_classes, device)

In [27]:
# df_simclr_nmis_cifar_idec = pd.DataFrame(simclr_nmis_cifar_idec)
# df_simclr_cas_cifar_idec = pd.DataFrame(simclr_cas_cifar_idec)

In [141]:
# df_simclr_nmis_cifar_idec.to_csv(path_nmis)
# df_simclr_cas_cifar_idec.to_csv(path_cas)

In [43]:
# Load the saved SimCLR + IDEC / CIFAR results; the first CSV column becomes the index.
df_simclr_nmis_cifar_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar_idec = pd.read_csv(path_cas, index_col=0)

In [44]:
# Std of the per-column means, i.e. the value print_total_mean_and_std reports as 'std'.
df_simclr_cas_cifar_idec.mean().std()

0.04796861744598033

In [45]:
# Fill the CIFAR mean/std slots (1 and 2) of the SimCLR + IDEC NMI summary row.
simclr_idec_list_nmi[1], simclr_idec_list_nmi[2] = print_total_mean_and_std('NMI', df_simclr_nmis_cifar_idec)

NMI mean: 0.5569
NMI std: 0.0207


In [46]:
# Fill the CIFAR mean/std slots (1 and 2) of the SimCLR + IDEC cluster-accuracy summary row.
simclr_idec_list_ca[1], simclr_idec_list_ca[2] = print_total_mean_and_std('Cluster accuracy', df_simclr_cas_cifar_idec)

Cluster accuracy mean: 0.6576
Cluster accuracy std: 0.0480


### STL10: Pretraining

In [64]:
# CSV files with the per-run NMI / cluster-accuracy values for SimCLR pretraining on STL10.
path_nmis = 'trained_models/SimCLR/STL10/nmis.csv'
path_cas = 'trained_models/SimCLR/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [65]:
# simclr_nmis_stl10, simclr_cas_stl10 = get_nmis_and_cas_10_runs('SimCLR\STL10\pretrained_SimCLR_STL10', colors_classes, device)

In [66]:
# df_simclr_nmis_stl10 = pd.DataFrame(simclr_nmis_stl10)
# df_simclr_cas_stl0 = pd.DataFrame(simclr_cas_stl10)

In [67]:
# df_simclr_nmis_stl10.to_csv(path_nmis)
# df_simclr_cas_stl0.to_csv(path_cas)

In [68]:
# Load the saved SimCLR / STL10 results; the first CSV column becomes the index.
# NOTE(review): `df_simclr_cas_stl0` is presumably a typo for `..._stl10`; the name is
# used consistently in the cells below, so renaming must be done everywhere at once.
df_simclr_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl0 = pd.read_csv(path_cas, index_col=0)

In [69]:
# Fill the STL10 mean/std slots (3 and 4) of the SimCLR NMI summary row.
simclr_list_nmi[3], simclr_list_nmi[4] = print_total_mean_and_std('NMI', df_simclr_nmis_stl10)

NMI mean: 0.3502
NMI std: 0.0058


In [70]:
# Fill the STL10 mean/std slots (3 and 4) of the SimCLR cluster-accuracy summary row.
simclr_list_ca[3], simclr_list_ca[4] = print_total_mean_and_std('Cluster accuracy', df_simclr_cas_stl0)

Cluster accuracy mean: 0.4114
Cluster accuracy std: 0.0223


### STL10: Pretraining + 100

In [47]:
# CSV files with the per-run NMI / cluster-accuracy values for the 'SimCLR + 100' variant on STL10.
path_nmis = 'trained_models/SimCLR/STL10/nmis_100.csv'
path_cas = 'trained_models/SimCLR/STL10/cas_100.csv'

In [48]:
# Load the saved 'SimCLR + 100' / STL10 results; the first CSV column becomes the index.
# NOTE(review): this reuses the names `df_simclr_nmis_stl10` / `df_simclr_cas_stl0`
# from the STL10 pretraining section, so those frames are overwritten here and any
# later use of the names sees the `*_100.csv` data.
df_simclr_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl0 = pd.read_csv(path_cas, index_col=0)

In [49]:
# Fill the STL10 mean/std slots (3 and 4) of the 'SimCLR + 100' NMI summary row.
simclr_list_100_nmi[3], simclr_list_100_nmi[4] = print_total_mean_and_std('NMI', df_simclr_nmis_stl10)

NMI mean: 0.3515
NMI std: 0.0051


In [50]:
# Fill the STL10 mean/std slots (3 and 4) of the 'SimCLR + 100' cluster-accuracy summary row.
simclr_list_100_ca[3], simclr_list_100_ca[4] = print_total_mean_and_std('Cluster accuracy', df_simclr_cas_stl0)

Cluster accuracy mean: 0.4106
Cluster accuracy std: 0.0163


### STL10: IDEC

In [51]:
# CSV files with the per-run NMI / cluster-accuracy values for SimCLR + IDEC on STL10.
# NOTE(review): unlike the CIFAR IDEC section (`nmis_idec.csv` / `cas_idec.csv`),
# the files here are named `*_new.csv` — confirm these are the intended IDEC results.
path_nmis = 'trained_models/SimCLR/STL10/nmis_new.csv'
path_cas = 'trained_models/SimCLR/STL10/cas_new.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [52]:
# simclr_nmis_stl10_idec, simclr_cas_stl10_idec = get_nmis_and_cas_10_runs('SimCLR\STL10\IDEC_SimCLR_STL10', colors_classes, device)

In [53]:
# df_simclr_nmis_stl10_idec = pd.DataFrame(simclr_nmis_stl10_idec)
# df_simclr_cas_stl10_idec = pd.DataFrame(simclr_cas_stl10_idec)

In [54]:
# df_simclr_nmis_stl10_idec.to_csv(path_nmis)
# df_simclr_cas_stl10_idec.to_csv(path_cas)

In [55]:
# Load the saved SimCLR + IDEC / STL10 results; the first CSV column becomes the index.
df_simclr_nmis_stl10_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl10_idec = pd.read_csv(path_cas, index_col=0)

In [56]:
# Fill the STL10 mean/std slots (3 and 4) of the SimCLR + IDEC NMI summary row.
simclr_idec_list_nmi[3], simclr_idec_list_nmi[4] = print_total_mean_and_std('NMI', df_simclr_nmis_stl10_idec)

NMI mean: 0.3595
NMI std: 0.0173


In [57]:
# Fill the STL10 mean/std slots (3 and 4) of the SimCLR + IDEC cluster-accuracy summary row.
simclr_idec_list_ca[3], simclr_idec_list_ca[4] = print_total_mean_and_std('Cluster accuracy', df_simclr_cas_stl10_idec)

Cluster accuracy mean: 0.4026
Cluster accuracy std: 0.0220


In [91]:
# Write the SimCLR rows into the summary tables:
# row 2 = SimCLR, row 3 = SimCLR + IDEC, row 4 = SimCLR + 100.
df_final_nmi.loc[2], df_final_nmi.loc[3] = simclr_list_nmi, simclr_idec_list_nmi
df_final_nmi.loc[4] = simclr_list_100_nmi
df_final_ca.loc[2], df_final_ca.loc[3] = simclr_list_ca, simclr_idec_list_ca
df_final_ca.loc[4] = simclr_list_100_ca

In [96]:
# Display the cluster-accuracy summary table after the SimCLR rows were written.
df_final_ca

Unnamed: 0,Model,CIFAR CA,CIFAR CA STD,STL10 CA,STL10 CA STD
0,RotNet,0.43565,0.018361,0.364128,0.014185
1,RotNet + DEC,0.206545,0.002399,0.196739,0.029117
2,SimCLR,0.590364,0.023882,0.411395,0.022278
3,SimCLR + IDEC,0.657623,0.047969,0.40261,0.02198
4,SimCLR + 100,0.610551,0.026599,0.41062,0.016269


TODO:
- recompute SimCLR STL10 IDEC
- recompute SimCLR STL10 +100