In [22]:
# Importing all necessary libraries
%load_ext autoreload
%autoreload 2

# standard library
import os
from collections import Counter, OrderedDict

# external packages
import torch
import torchvision
import numpy as np
import sklearn
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, confusion_matrix
from sklearn.decomposition import PCA
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

# util functions
from util.util import *

# dataset functions
from dataset import load_util

# models (autoencoder, SimCLR, RotNet / IDEC) and cluster-accuracy metric
from models.autoencoder.conv_ae import ConvAE
from models.simclr.simclr import *
from models.simclr.transforms import *
from models.rotnet.rotnet import *
from models.rotnet.IDEC import *
from models.rotnet.custom_stl10 import *
from cluster_accuracy import cluster_accuracy

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
%%javascript
// Replace the output area's scroll check with a function that always returns false.
// NOTE(review): presumably this keeps long cell outputs fully expanded instead of
// putting them in a scroll box; it relies on the `IPython` JavaScript global being
// available in the front end — confirm it still applies to the front end in use.
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [24]:
# Report the versions of the core libraries, one "<name>: <version>" line each,
# under a "Versions:" heading.
print(
    "Versions:",
    *(
        f"{name}: {module.__version__}"
        for name, module in (
            ("torch", torch),
            ("torchvision", torchvision),
            ("numpy", np),
            ("scikit-learn", sklearn),
        )
    ),
    sep="\n",
)

# Ask the project helper which compute device to use and show the answer.
device = detect_device()
print("Using device: ", device)

Versions:
torch: 1.8.1+cu111
torchvision: 0.9.1+cu111
numpy: 1.19.5
scikit-learn: 0.24.1
Using device:  cuda


## Preparation

In [18]:
# Learning hyper-parameters.
# NOTE(review): learning_rate, epochs and train are not read by any cell visible
# in this notebook — presumably left over from the training notebooks; confirm
# before removing them.
batch_size = 128
learning_rate = 0.1
epochs = 100

# Training switch.

train = True

In [19]:
# Evaluation data, loaded through the project helper with train=False.
# NOTE(review): the meaning of data_percent / for_model / transforms is defined in
# dataset.load_util, which is not visible here — confirm before changing them.
test_data = load_util.load_custom_cifar('./data', download=False, train=False, data_percent=1.0, for_model='SimCLR', transforms=False)

# Use the notebook-level `batch_size` (128) instead of repeating the literal, so the
# configuration cell stays the single place where the batch size is set.
# NOTE(review): shuffle=True and drop_last=True are left untouched; presumably the
# saved metrics were produced with them. drop_last discards the final incomplete
# batch, so the metrics are computed on slightly fewer samples than the full split.
testloader = torch.utils.data.DataLoader(test_data,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

In [20]:
# Map every integer class index to its class name, in the dataset's own order.
colors_classes = dict(enumerate(test_data.classes))

In [14]:
def get_nmis_and_cas_10_runs(model_name, colors_classes, device, n_models=10, n_runs=10, loader=None):
    """Evaluate a family of saved checkpoints with repeated clustering runs.

    For every checkpoint index ``i`` in ``range(n_models)`` the file name
    ``f'{model_name}_{i}.pth'`` is passed to ``load_model``; the loaded model is
    then evaluated ``n_runs`` times with ``compute_nmi_and_pca``. Each evaluation
    contributes one NMI value and one cluster accuracy (``cluster_accuracy`` of the
    returned labels against ``kmeans.labels_``). Progress is printed to stdout.

    Args:
        model_name: Checkpoint name prefix; ``_{i}.pth`` is appended per model.
        colors_classes: Forwarded unchanged to ``compute_nmi_and_pca``.
        device: Forwarded to ``load_model`` and ``compute_nmi_and_pca``.
        n_models: Number of checkpoints to evaluate. Defaults to 10, the value
            that used to be hard-coded.
        n_runs: Number of evaluations per checkpoint. Defaults to 10, the value
            that used to be hard-coded.
        loader: Data loader handed to ``compute_nmi_and_pca``. When ``None`` the
            notebook-level ``testloader`` is used, which is what this function
            always did before the parameter existed.

    Returns:
        ``(nmis, cas)``: two dicts keyed by checkpoint index; each value is the
        list of ``n_runs`` NMI values / cluster accuracies for that checkpoint.
    """
    if loader is None:
        # Fall back to the notebook-global loader (previous, implicit behaviour).
        loader = testloader

    nmis = {}
    cas = {}

    for i in range(n_models):
        name = f'{model_name}_{i}.pth'
        print(name)

        # NOTE(review): the random (4, 12288) tensor is presumably a dummy input
        # that load_model needs to build the network — confirm against util.util.
        model = load_model(name, device, torch.rand(size=(4, 12288)))
        nmis[i] = []
        cas[i] = []

        for k in range(n_runs):
            labels, kmeans, nmi, _, _ = compute_nmi_and_pca(model, name, colors_classes, device, loader)
            nmis[i].append(nmi)

            ca = cluster_accuracy(labels, kmeans.labels_)
            cas[i].append(ca)
            print(f'Run: {k}\n{name}:\nNMI:{nmi}\nCA:{ca}\n')

            # Drop the per-run results before the next evaluation allocates new ones.
            del labels
            del kmeans
    return nmis, cas

In [15]:
def get_nmis_and_cas_10_runs_1_model(model_name, colors_classes, device, n_runs=5, loader=None):
    """Evaluate a single saved checkpoint with repeated clustering runs.

    ``model_name`` is passed to ``load_model`` as-is (no suffix is appended) and
    the loaded model is evaluated ``n_runs`` times with ``compute_nmi_and_pca``.
    Each evaluation contributes one NMI value and one cluster accuracy
    (``cluster_accuracy`` of the returned labels against ``kmeans.labels_``).
    Progress is printed to stdout.

    NOTE(review): despite the "10_runs" in the name, this function has always
    performed 5 runs; the default of ``n_runs`` keeps that behaviour. Pass
    ``n_runs=10`` if ten runs are actually wanted.

    Args:
        model_name: Checkpoint name handed to ``load_model``.
        colors_classes: Forwarded unchanged to ``compute_nmi_and_pca``.
        device: Forwarded to ``load_model`` and ``compute_nmi_and_pca``.
        n_runs: Number of evaluations. Defaults to 5, the value that used to be
            hard-coded.
        loader: Data loader handed to ``compute_nmi_and_pca``. When ``None`` the
            notebook-level ``testloader`` is used, as before.

    Returns:
        ``(nmis, cas)``: two lists of length ``n_runs`` with the NMI values and
        the cluster accuracies, in run order.
    """
    if loader is None:
        # Fall back to the notebook-global loader (previous, implicit behaviour).
        loader = testloader

    print(model_name)

    # NOTE(review): the random (4, 12288) tensor is presumably a dummy input that
    # load_model needs to build the network — confirm against util.util.
    model = load_model(model_name, device, torch.rand(size=(4, 12288)))
    nmis = []
    cas = []

    for k in range(n_runs):
        labels, kmeans, nmi, _, _ = compute_nmi_and_pca(model, model_name, colors_classes, device, loader)
        nmis.append(nmi)

        ca = cluster_accuracy(labels, kmeans.labels_)
        cas.append(ca)
        print(f'Run: {k}\n{model_name}:\nNMI:{nmi}\nCA:{ca}\n')

        # Drop the per-run results before the next evaluation allocates new ones.
        del labels
        del kmeans
    return nmis, cas

In [15]:
def print_total_mean_and_var(metrics_name, df):
    """Print the mean and the standard deviation of the per-column means of ``df``.

    Every column of ``df`` is averaged first (``df.mean()``); the two printed
    values are the mean and the sample standard deviation (pandas default,
    ``ddof=1``) of those column averages, formatted to four decimals.

    The second line used to be labelled "variance" although the value has always
    been a standard deviation. Only the label is corrected; the number printed is
    unchanged.

    Args:
        metrics_name: Label prefixed to both printed lines (e.g. ``'NMI'``).
        df: ``pandas.DataFrame`` of metric values. NOTE(review): presumably one
            column per model checkpoint and one row per run — confirm against the
            saved CSVs.

    Returns:
        None. The result is written to stdout only.
    """
    per_column_means = df.mean()

    total_mean = per_column_means.mean()
    total_std = per_column_means.std()

    print(f'{metrics_name} mean: {total_mean:.4f}')
    print(f'{metrics_name} std: {total_std:.4f}')

## Overall comparison

In [44]:
nmi_all = pd.read_csv('trained_models/nmis_all.csv', index_col=0, header=0)

In [45]:
ca_all = pd.read_csv('trained_models/cas_all.csv', index_col=0, header=0)

In [46]:
nmi_all

Unnamed: 0_level_0,CIFAR NMI,CIFAR NMI VAR,STL10 NMI,STL10 NMI VAR
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
RotNet,0.3506,0.0017,0.3357,0.0032
RotNet + DEC,0.1042,0.0001,0.0,0.0
SimCLR,0.4724,0.0033,0.0,0.0
SimCLR + 100,0.4924,0.0033,0.0,0.0
SimCLR + IDEC,0.5569,0.0001,0.0,0.0


In [47]:
ca_all

Unnamed: 0_level_0,CIFAR CA,CIFAR CA VAR,STL10 CA,STL10 CA VAR
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
RotNet,0.4357,0.0051,0.3641,0.0076
RotNet + DEC,0.2065,0.0002,0.0,0.0
SimCLR,0.5904,0.009,0.0,0.0
SimCLR + 100,0.5815,0.0017,0.0,0.0
SimCLR + IDEC,0.6576,0.0,0.0,0.0


## RotNet

### CIFAR : Pretraining

In [110]:
path_nmis = 'trained_models/RotNet/CIFAR/nmis.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar, rotnet_cas_cifar = get_nmis_and_cas_10_runs('RotNet\CIFAR\pretrained_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar = pd.DataFrame(rotnet_nmis_cifar)
# df_rotnet_cas_cifar = pd.DataFrame(rotnet_cas_cifar)

In [74]:
# df_rotnet_nmis_cifar.to_csv(path_nmis)
# df_rotnet_cas_cifar.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [111]:
df_rotnet_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [112]:
print_total_mean_and_var('NMI', df_rotnet_nmis_cifar)

NMI mean: 0.3506
NMI variance: 0.0017


In [113]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_cifar)

Cluser accuracy mean: 0.4357
Cluser accuracy variance: 0.0051


### CIFAR: DEC

In [114]:
path_nmis = 'trained_models/RotNet/CIFAR/nmis_dec.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar_dec, rotnet_cas_cifar_dec = get_nmis_and_cas_10_runs('RotNet\CIFAR\DEC_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar_dec = pd.DataFrame(rotnet_nmis_cifar_dec)
# df_rotnet_cas_cifar_dec = pd.DataFrame(rotnet_cas_cifar_dec)

In [105]:
# df_rotnet_nmis_cifar_dec.to_csv(path_nmis)
# df_rotnet_cas_cifar_dec.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [115]:
df_rotnet_nmis_cifar_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar_dec = pd.read_csv(path_cas, index_col=0)

In [117]:
print_total_mean_and_var('NMI', df_rotnet_nmis_cifar_dec)

NMI mean: 0.1042
NMI variance: 0.0001


In [118]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_cifar_dec)

Cluser accuracy mean: 0.2065
Cluser accuracy variance: 0.0002


### STL10: Pretraining

In [144]:
path_nmis = 'trained_models/RotNet/STL10/nmis.csv'
path_cas = 'trained_models/RotNet/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_stl10, rotnet_cas_stl10 = get_nmis_and_cas_10_runs('RotNet\STL10\pretrained_RotNet_STL10', colors_classes, device)

In [None]:
# df_rotnet_nmis_stl10 = pd.DataFrame(rotnet_nmis_stl10)
# df_rotnet_cas_stl10 = pd.DataFrame(rotnet_cas_stl10)

In [149]:
# df_rotnet_nmis_stl10.to_csv(path_nmis)
# df_rotnet_cas_stl10.to_csv(path_cas)

In [145]:
df_rotnet_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10 = pd.read_csv(path_cas, index_col=0)

In [128]:
print_total_mean_and_var('NMI', df_rotnet_nmis_stl10)

NMI mean: 0.2977
NMI variance: 0.0022


In [129]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_stl10)

Cluser accuracy mean: 0.3826
Cluser accuracy variance: 0.0069


### STL10: DEC

In [8]:
path_nmis = 'trained_models/RotNet/STL10/nmis_dec.csv'
path_cas = 'trained_models/RotNet/STL10/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_stl10_dec, rotnet_cas_stl10_dec = get_nmis_and_cas_10_runs('RotNet\STL10\DEC_RotNet_STL10', colors_classes, device)

In [17]:
# df_rotnet_nmis_stl10_dec = pd.DataFrame(rotnet_nmis_stl10_dec)
# df_rotnet_cas_stl10_dec = pd.DataFrame(rotnet_cas_stl10_dec)

In [18]:
# df_rotnet_nmis_stl10_dec.to_csv(path_nmis)
# df_rotnet_cas_stl10_dec.to_csv(path_cas)

In [19]:
df_rotnet_nmis_stl10_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10_dec = pd.read_csv(path_cas, index_col=0)

In [20]:
print_total_mean_and_var('NMI', df_rotnet_nmis_stl10_dec)

NMI mean: 0.1001
NMI variance: 0.0005


In [21]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_stl10_dec)

Cluser accuracy mean: 0.1968
Cluser accuracy variance: 0.0005


## SimCLR

### CIFAR: Pretraining

In [48]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar, simclr_cas_cifar = get_nmis_and_cas_10_runs('SimCLR\CIFAR\pretrained_SimCLR', colors_classes, device)

In [24]:
# df_simclr_nmis_cifar = pd.DataFrame(simclr_nmis_cifar)
# df_simclr_cas_cifar = pd.DataFrame(simclr_cas_cifar)

In [134]:
# df_simclr_nmis_cifar.to_csv(path_nmis)
# df_simclr_cas_cifar.to_csv(path_cas)

In [49]:
df_simclr_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [51]:
df_simclr_cas_cifar

Unnamed: 0,Run 0,Run 1,Run 2,Run 3,Run 4,Run 5,Run 6,Run 7,Run 8,Run 9
0,0.610377,0.610978,0.588642,0.563502,0.599459,0.597356,0.625,0.566907,0.622796,0.5622
1,0.609776,0.634415,0.58734,0.55609,0.599259,0.595653,0.624099,0.56871,0.576623,0.558994
2,0.608774,0.634615,0.590845,0.558193,0.603866,0.596955,0.626502,0.557893,0.554487,0.563001
3,0.610777,0.634615,0.591947,0.561398,0.596855,0.593249,0.623297,0.566506,0.578325,0.555489
4,0.611078,0.639022,0.589643,0.558894,0.599359,0.591647,0.620693,0.561799,0.573818,0.557091
5,0.611378,0.633514,0.589042,0.559195,0.59385,0.597556,0.619191,0.5626,0.580629,0.567808
6,0.606871,0.606971,0.492688,0.584034,0.592849,0.592548,0.616286,0.561098,0.585036,0.59385
7,0.60647,0.636218,0.583333,0.555389,0.586939,0.595353,0.6251,0.554788,0.577123,0.565905
8,0.607672,0.604167,0.588642,0.553285,0.59365,0.588041,0.624399,0.588942,0.578425,0.568309
9,0.610276,0.638522,0.58143,0.5624,0.592949,0.589744,0.623898,0.557792,0.621294,0.560096


In [54]:
# NOTE(review): the saved output of this cell (0.5904) equals the *mean* that
# print_total_mean_and_var reports for this frame, not the 0.0090 it reports for
# the same .mean().std() expression — the stored output looks stale; re-run the cell.
df_simclr_cas_cifar.mean().std()

0.5903635817307691

In [135]:
print_total_mean_and_var('NMI', df_simclr_nmis_cifar)

NMI mean: 0.4724
NMI variance: 0.0033


In [136]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar)

Cluser accuracy mean: 0.5904
Cluser accuracy variance: 0.0090


In [98]:
# Column labels "Run 0" … "Run 9" used to relabel the SimCLR/CIFAR result frames.
columns = list(map('Run {}'.format, range(10)))

In [133]:
# Relabel the columns of both SimCLR/CIFAR frames in place with the `columns` list.
# NOTE(review): this cell sits after the cells that display and summarise these
# frames, and its execution count (133) is out of sequence — on a fresh
# Restart & Run All the earlier cells see whatever column names the CSVs contain.
df_simclr_nmis_cifar.columns = columns
df_simclr_cas_cifar.columns = columns

### CIFAR: IDEC

In [55]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis_idec.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas_idec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar_idec, simclr_cas_cifar_idec = get_nmis_and_cas_10_runs('SimCLR\CIFAR\IDEC_SimCLR', colors_classes, device)

In [27]:
# df_simclr_nmis_cifar_idec = pd.DataFrame(simclr_nmis_cifar_idec)
# df_simclr_cas_cifar_idec = pd.DataFrame(simclr_cas_cifar_idec)

In [141]:
# df_simclr_nmis_cifar_idec.to_csv(path_nmis)
# df_simclr_cas_cifar_idec.to_csv(path_cas)

In [56]:
df_simclr_nmis_cifar_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar_idec = pd.read_csv(path_cas, index_col=0)

In [57]:
# NOTE(review): the saved output of this cell (0.0480) disagrees with the 0.0000
# that print_total_mean_and_var reports for the same .mean().std() expression on
# this frame — one of the two stored outputs is stale; re-run both cells.
df_simclr_cas_cifar_idec.mean().std()

0.04796861744598033

In [142]:
print_total_mean_and_var('NMI', df_simclr_nmis_cifar_idec)

NMI mean: 0.5569
NMI variance: 0.0001


In [143]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar_idec)

Cluser accuracy mean: 0.6576
Cluser accuracy variance: 0.0000


### STL10: Pretraining

In [22]:
path_nmis = 'trained_models/SimCLR/STL10/nmis.csv'
path_cas = 'trained_models/SimCLR/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_stl10, simclr_cas_stl10 = get_nmis_and_cas_10_runs('SimCLR\STL10\pretrained_SimCLR_STL10', colors_classes, device)

In [24]:
# df_simclr_nmis_stl10 = pd.DataFrame(simclr_nmis_stl10)
# df_simclr_cas_stl0 = pd.DataFrame(simclr_cas_stl10)

In [25]:
# df_simclr_nmis_stl10.to_csv(path_nmis)
# df_simclr_cas_stl0.to_csv(path_cas)

In [26]:
# Load the saved SimCLR/STL10 metric tables from the paths set for this section.
df_simclr_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
# NOTE(review): `df_simclr_cas_stl0` looks like a typo for `df_simclr_cas_stl10`
# (compare the `*_stl10` name of the NMI frame above). The later cell that
# summarises cluster accuracy uses the same misspelled name, so rename both together.
df_simclr_cas_stl0 = pd.read_csv(path_cas, index_col=0)

In [27]:
print_total_mean_and_var('NMI', df_simclr_nmis_stl10)

NMI mean: 0.3215
NMI variance: 0.0011


In [28]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_stl0)

Cluser accuracy mean: 0.4231
Cluser accuracy variance: 0.0067


### STL10: IDEC

In [29]:
# Use the same `path_nmis` / `path_cas` names as every other section. The cells
# below read (and, when uncommented, write) through those two names, so assigning
# only `path_*_dec` left them pointing at the SimCLR/STL10 pretraining CSVs.
path_nmis = 'trained_models/SimCLR/STL10/nmis_idec.csv'
path_cas = 'trained_models/SimCLR/STL10/cas_idec.csv'

# Backward-compatible aliases for the names this cell used to define.
path_nmis_dec = path_nmis
path_cas_dec = path_cas

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_stl10_idec, simclr_cas_stl10_idec = get_nmis_and_cas_10_runs('SimCLR\STL10\IDEC_SimCLR_STL10', colors_classes, device)

In [None]:
# df_simclr_nmis_stl10_idec = pd.DataFrame(simclr_nmis_stl10_idec)
# df_simclr_cas_stl10_idec = pd.DataFrame(simclr_cas_stl10_idec)

In [None]:
# df_simclr_nmis_stl10_idec.to_csv(path_nmis)
# df_simclr_cas_stl10_idec.to_csv(path_cas)

In [None]:
df_simclr_nmis_stl10_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl10_idec = pd.read_csv(path_cas, index_col=0)

In [None]:
print_total_mean_and_var('NMI', df_simclr_nmis_stl10_idec)

In [None]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_stl10_idec)