In [73]:
# Importing all necessary libraries
%load_ext autoreload
%autoreload 2

# internal packages
import os
from collections import Counter, OrderedDict

# external packages
import torch
import torchvision
import numpy as np
import sklearn
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, confusion_matrix
from sklearn.decomposition import PCA
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

# util functions
from util.util import *

# dataset functions
from dataset import load_util

# autoencoder
from models.autoencoder.conv_ae import ConvAE
from models.simclr.simclr import *
from models.simclr.transforms import *
from models.rotnet.rotnet import *
from models.rotnet.IDEC import *
from models.rotnet.custom_stl10 import *
from cluster_accuracy import cluster_accuracy
from models.simclr.custom_stl10 import SimCLRSTL10
from models.simclr.custom_fmnist import SimCLRFMNIST

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [7]:
print("Versions:")
print(f"torch: {torch.__version__}")
print(f"torchvision: {torchvision.__version__}")
print(f"numpy: {np.__version__}",)
print(f"scikit-learn: {sklearn.__version__}")

device = detect_device()
print("Using device: ", device)

Versions:
torch: 1.8.1+cu111
torchvision: 0.9.1+cu111
numpy: 1.19.5
scikit-learn: 0.24.1
Using device:  cuda


## Preparation

In [8]:
# specify learning params
batch_size = 128
learning_rate = 0.1
epochs = 100

# training

train = True

In [10]:
cifar_test_data = load_util.load_custom_cifar('./data', download=False, train=False, data_percent=1.0, for_model='SimCLR', transforms=False)
cifar_testloader = torch.utils.data.DataLoader(cifar_test_data,
                                          batch_size=128,
                                          shuffle=True,
                                          drop_last=True)

In [12]:
cifar_colors_classes = {i: color_class for i, color_class in zip(range(len(cifar_test_data.classes)), cifar_test_data.classes)}

In [34]:
stl10_test_data = SimCLRSTL10(data_percent=1.0, train=False)
stl10_testloader = torch.utils.data.DataLoader(stl10_test_data,
                                         batch_size=128,
                                         shuffle=True,
                                         drop_last=True)

In [35]:
stl10_colors_classes = {i: color_class for i, color_class in zip(range(len(stl10_test_data.classes)), stl10_test_data.classes)}

In [125]:
fmnist_test_data = SimCLRFMNIST('./data', data_percent=1.0, train=False, download=True)
fmnist_testloader = torch.utils.data.DataLoader(fmnist_test_data,
                                         batch_size=128,
                                         shuffle=True,
                                         drop_last=True)

In [126]:
fmnist_colors_classes = {i: color_class for i, color_class in zip(range(len(fmnist_test_data.classes)), fmnist_test_data.classes)}

In [138]:
def get_nmis_and_cas_10_runs(model_name, colors_classes, device, dataset='CIFAR'):
    nmis = {}
    cas = {}
    
    tls = {
        'CIFAR': cifar_testloader,
        'FMNIST': fmnist_testloader,
        'STL10':stl10_testloader
    }
    testloader = tls[dataset]
    # testloader = cifar_testloader if 'CIFAR' in dataset else stl10_testloader

    for i in range(10):
        name = f'{model_name}_{i}.pth'
        print(name)

        model = load_model(name, device, torch.rand(size=(10, 9408)))
        nmis[i] = []
        cas[i] = []

        for k in range(10):
            labels, kmeans, nmi, _, _ = compute_nmi_and_pca(model, name, colors_classes, device, testloader)
            nmis[i].append(nmi)

            ca = cluster_accuracy(labels, kmeans.labels_)
            cas[i].append(ca)
            print(f'Run: {k}\n{name}:\nNMI:{nmi}\nCA:{ca}\n')
            del labels
            del kmeans
    return nmis, cas

In [15]:
def print_total_mean_and_var(metrics_name, df):
    models_means = df.mean()

    total_mean = models_means.mean()
    total_std = models_means.std()
    
    print(f'{metrics_name} mean: {total_mean:.4f}')
    print(f'{metrics_name} variance: {total_std:.4f}')

## RotNet

### CIFAR : Pretraining

In [110]:
path_nmis = 'trained_models/RotNet/CIFAR/nmis.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar, rotnet_cas_cifar = get_nmis_and_cas_10_runs('RotNet\CIFAR\pretrained_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar = pd.DataFrame(rotnet_nmis_cifar)
# df_rotnet_cas_cifar = pd.DataFrame(rotnet_cas_cifar)

In [74]:
# df_rotnet_nmis_cifar.to_csv(path_nmis)
# df_rotnet_cas_cifar.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [111]:
df_rotnet_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [112]:
print_total_mean_and_var('NMI', df_rotnet_nmis_cifar)

NMI mean: 0.3506
NMI variance: 0.0017


In [113]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_cifar)

Cluser accuracy mean: 0.4357
Cluser accuracy variance: 0.0051


### CIFAR: DEC

In [114]:
path_nmis = 'trained_models/RotNet/CIFAR/nmis_dec.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar_dec, rotnet_cas_cifar_dec = get_nmis_and_cas_10_runs('RotNet\CIFAR\DEC_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar_dec = pd.DataFrame(rotnet_nmis_cifar_dec)
# df_rotnet_cas_cifar_dec = pd.DataFrame(rotnet_cas_cifar_dec)

In [105]:
# df_rotnet_nmis_cifar_dec.to_csv(path_nmis)
# df_rotnet_cas_cifar_dec.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [115]:
df_rotnet_nmis_cifar_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar_dec = pd.read_csv(path_cas, index_col=0)

In [117]:
print_total_mean_and_var('NMI', df_rotnet_nmis_cifar_dec)

NMI mean: 0.1042
NMI variance: 0.0001


In [118]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_cifar_dec)

Cluser accuracy mean: 0.2065
Cluser accuracy variance: 0.0002


### STL10: Pretraining

In [14]:
path_nmis = 'trained_models/RotNet/STL10/nmis.csv'
path_cas = 'trained_models/RotNet/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
rotnet_nmis_stl10, rotnet_cas_stl10 = get_nmis_and_cas_10_runs('RotNet\STL10\pretrained_RotNet_STL10', stl10_colors_classes, device, dataset='STL10')

In [33]:
df_rotnet_nmis_stl10 = pd.DataFrame(rotnet_nmis_stl10)
df_rotnet_cas_stl10 = pd.DataFrame(rotnet_cas_stl10)

In [34]:
df_rotnet_nmis_stl10.to_csv(path_nmis)
df_rotnet_cas_stl10.to_csv(path_cas)

In [35]:
df_rotnet_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10 = pd.read_csv(path_cas, index_col=0)

In [36]:
print_total_mean_and_var('NMI', df_rotnet_nmis_stl10)

NMI mean: 0.3357
NMI variance: 0.0032


In [37]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_stl10)

Cluser accuracy mean: 0.3641
Cluser accuracy variance: 0.0076


### STL10: DEC

In [39]:
path_nmis = 'trained_models/RotNet/STL10/nmis_dec.csv'
path_cas = 'trained_models/RotNet/STL10/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_stl10_dec, rotnet_cas_stl10_dec = get_nmis_and_cas_10_runs('RotNet\STL10\DEC_RotNet_STL10', stl10_colors_classes, device, dataset='STL10')

In [41]:
# df_rotnet_nmis_stl10_dec = pd.DataFrame(rotnet_nmis_stl10_dec)
# df_rotnet_cas_stl10_dec = pd.DataFrame(rotnet_cas_stl10_dec)

In [42]:
# df_rotnet_nmis_stl10_dec.to_csv(path_nmis)
# df_rotnet_cas_stl10_dec.to_csv(path_cas)

In [43]:
df_rotnet_nmis_stl10_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10_dec = pd.read_csv(path_cas, index_col=0)

In [44]:
print_total_mean_and_var('NMI', df_rotnet_nmis_stl10_dec)

NMI mean: 0.1292
NMI variance: 0.0003


In [45]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_stl10_dec)

Cluser accuracy mean: 0.1967
Cluser accuracy variance: 0.0003


### FMNIST: Pretraining

In [79]:
path_nmis = 'trained_models/RotNet/FMNIST/nmis.csv'
path_cas = 'trained_models/RotNet/FMNIST/cas.csv'
model_name = 'RotNet/FMNIST/pretrained_RotNet_FMNIST'

In [None]:
df_nmis, df_cas = compute_everything(path_nmis, path_cas, model_name, fmnist_colors_classes, compute=True, dataset='FMNIST')

In [133]:
path_nmis = 'trained_models/RotNet/FMNIST/nmis_dec.csv'
path_cas = 'trained_models/RotNet/FMNIST/cas_dec.csv'
model_name = 'RotNet/FMNIST/DEC_RotNet_FMINST'

In [139]:
df_nmis, df_cas = compute_everything(path_nmis, path_cas, model_name, fmnist_colors_classes, compute=True, dataset='FMNIST')

RotNet/FMNIST/DEC_RotNet_FMINST_0.pth
Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5111750003117222
CA:0.4052483974358974

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5112861978178415
CA:0.40705128205128205

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.511487030945079
CA:0.4055488782051282

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5208013783324408
CA:0.44381009615384615

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5109350058289088
CA:0.40635016025641024

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5200279233570336
CA:0.4444110576923077

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5111183252819648
CA:0.4052483974358974

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.52011483902179
CA:0.44421073717948717

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5205905650959054
CA:0.4440104166666667

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_0.pth:
NMI:0.5109040502420529
CA:0.4049479166666667

RotNet/FMNIST/DEC_RotNet_FMINST_1.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5124410113755884
CA:0.40705128205128205

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5132735958565171
CA:0.40795272435897434

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5124848027065108
CA:0.4065504807692308

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5213424825719485
CA:0.4441105769230769

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5125709877792544
CA:0.4072516025641026

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5130170298196638
CA:0.4074519230769231

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5214905221363141
CA:0.4443108974358974

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5123637200902981
CA:0.40645032051282054

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5214418420514535
CA:0.4437099358974359

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_1.pth:
NMI:0.5212974930574278
CA:0.44421073717948717

RotNet/FMNIST/DEC_RotNet_FMINST_2.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.4992182172784722
CA:0.4026442307692308

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.5049394313240794
CA:0.418369391025641

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.5049251854125774
CA:0.4182692307692308

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.49891850827890677
CA:0.402744391025641

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.5047164192560829
CA:0.418369391025641

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.5051944266680105
CA:0.398036858974359

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.4984451698641219
CA:0.40294471153846156

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.5048888944246762
CA:0.41816907051282054

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.5046024562702965
CA:0.41816907051282054

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_2.pth:
NMI:0.4988915745933595
CA:0.402744391025641

RotNet/FMNIST/DEC_RotNet_FMINST_3.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.4740521758951735
CA:0.3984375

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.4739263270253726
CA:0.3977363782051282

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.4729529179370591
CA:0.398036858974359

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.4741493867930161
CA:0.3987379807692308

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.47326390573961463
CA:0.39763621794871795

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.47378113873356315
CA:0.3981370192307692

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.473104537732523
CA:0.398036858974359

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.4729523606534087
CA:0.39793669871794873

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.47295789022186624
CA:0.398036858974359

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_3.pth:
NMI:0.4730804011085234
CA:0.398036858974359

RotNet/FMNIST/DEC_RotNet_FMINST_4.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.49006586795215407
CA:0.3851161858974359

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.4905117104395257
CA:0.3844150641025641

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.49065347676371
CA:0.3858173076923077

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.48914137601306296
CA:0.3808092948717949

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.48967150540328014
CA:0.38501602564102566

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.4871928287103092
CA:0.3639823717948718

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.4864524820003454
CA:0.3602764423076923

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.4901313014148745
CA:0.3844150641025641

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.48759157528315167
CA:0.36438301282051283

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_4.pth:
NMI:0.48758841276407444
CA:0.36498397435897434

RotNet/FMNIST/DEC_RotNet_FMINST_5.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5316469334504155
CA:0.40805288461538464

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5318720595150951
CA:0.4085536858974359

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5428561891851108
CA:0.40965544871794873

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5430367355417106
CA:0.40995592948717946

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5320851785885464
CA:0.4088541666666667

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5317291560550744
CA:0.40865384615384615

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5399839415052222
CA:0.39463141025641024

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5395338637360112
CA:0.3952323717948718

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5318137389540714
CA:0.40905448717948717

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_5.pth:
NMI:0.5310274270792134
CA:0.4087540064102564

RotNet/FMNIST/DEC_RotNet_FMINST_6.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4740569975893971
CA:0.38451522435897434

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4746107117015851
CA:0.3860176282051282

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4681656854596609
CA:0.36548477564102566

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.47474778625770575
CA:0.37650240384615385

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4743749595541032
CA:0.38561698717948717

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.47437701178148595
CA:0.38521634615384615

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.474939683174093
CA:0.37650240384615385

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4749894273044162
CA:0.3769030448717949

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4740418025587962
CA:0.3844150641025641

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_6.pth:
NMI:0.4751250430910765
CA:0.37670272435897434

RotNet/FMNIST/DEC_RotNet_FMINST_7.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.47960613051607526
CA:0.36548477564102566

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.47888011026329014
CA:0.36698717948717946

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.4779881786541113
CA:0.38752003205128205

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.4789982330879156
CA:0.36658653846153844

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.4860081751849253
CA:0.39693509615384615

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.4833048239148344
CA:0.38912259615384615

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.4859160725844688
CA:0.3971354166666667

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.47967852992184745
CA:0.3674879807692308

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.47916268802529965
CA:0.36538461538461536

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_7.pth:
NMI:0.4791992670643614
CA:0.3671875

RotNet/FMNIST/DEC_RotNet_FMINST_8.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.48866432025860684
CA:0.40314503205128205

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.48480876009868623
CA:0.39663461538461536

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.48818286294964197
CA:0.38822115384615385

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.4820415414632958
CA:0.38721955128205127

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.48222473149105854
CA:0.38752003205128205

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.4830043031017925
CA:0.38661858974358976

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.4849857729050275
CA:0.3968349358974359

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.4879010036807052
CA:0.3877203525641026

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.4838277238504005
CA:0.3883213141025641

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_8.pth:
NMI:0.47959043279094615
CA:0.3835136217948718

RotNet/FMNIST/DEC_RotNet_FMINST_9.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.4916008484599371
CA:0.37710336538461536

Run: 1
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.4912402811214174
CA:0.37670272435897434

Run: 2
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.49711932453488533
CA:0.3931290064102564

Run: 3
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.4919030757139076
CA:0.3773036858974359

Run: 4
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.49155225734537916
CA:0.37680288461538464

Run: 5
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.4891165928011638
CA:0.3773036858974359

Run: 6
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.4917237327428518
CA:0.3775040064102564

Run: 7
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.49182758691221423
CA:0.3776041666666667

Run: 8
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.49148185416325463
CA:0.3776041666666667

Run: 9
RotNet/FMNIST/DEC_RotNet_FMINST_9.pth:
NMI:0.49139039012489794
CA:0.3777043269230769

NMI mean: 0.4963
NMI variance: 0.0205
Cluser accuracy mean: 0.3959
Cluse

## SimCLR

### CIFAR: Pretraining

In [130]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar, simclr_cas_cifar = get_nmis_and_cas_10_runs('SimCLR\CIFAR\pretrained_SimCLR', colors_classes, device)

In [24]:
# df_simclr_nmis_cifar = pd.DataFrame(simclr_nmis_cifar)
# df_simclr_cas_cifar = pd.DataFrame(simclr_cas_cifar)

In [134]:
# df_simclr_nmis_cifar.to_csv(path_nmis)
# df_simclr_cas_cifar.to_csv(path_cas)

In [132]:
df_simclr_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [135]:
print_total_mean_and_var('NMI', df_simclr_nmis_cifar)

NMI mean: 0.4724
NMI variance: 0.0033


In [136]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar)

Cluser accuracy mean: 0.5904
Cluser accuracy variance: 0.0090


In [98]:
columns = [f'Run {i}' for i in range(10)]

In [133]:
df_simclr_nmis_cifar.columns = columns
df_simclr_cas_cifar.columns = columns

### CIFAR: +100 pretraining

In [22]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis_100.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas_100.csv'
model_name = 'SimCLR/CIFAR/pretrained_SimCLR_100_CIFAR'

In [33]:
df_nmis_100, df_cas_100 = compute_everything(path_nmis, path_cas, model_name, cifar_colors_classes, compute=False)

NMI mean: 0.4905
NMI variance: 0.0105
Cluser accuracy mean: 0.6106
Cluser accuracy variance: 0.0266


In [37]:
def compute_everything(path_nmis, path_cas, model_name, colors_classes, compute=True, dataset='CIFAR'):
    if compute: 
        simclr_nmis_cifar_idec, simclr_cas_cifar_idec = get_nmis_and_cas_10_runs(model_name, colors_classes, device, dataset=dataset)

        df_simclr_nmis_cifar_idec = pd.DataFrame(simclr_nmis_cifar_idec)
        df_simclr_cas_cifar_idec = pd.DataFrame(simclr_cas_cifar_idec)
        
        df_simclr_nmis_cifar_idec.to_csv(path_nmis)
        df_simclr_cas_cifar_idec.to_csv(path_cas)
    
    df_simclr_nmis_cifar_idec = pd.read_csv(path_nmis, index_col=0)
    df_simclr_cas_cifar_idec = pd.read_csv(path_cas, index_col=0)

    print_total_mean_and_var('NMI', df_simclr_nmis_cifar_idec)
    print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar_idec)
    
    return df_simclr_nmis_cifar_idec, df_simclr_cas_cifar_idec

### CIFAR: IDEC

In [137]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis_idec.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas_idec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar_idec, simclr_cas_cifar_idec = get_nmis_and_cas_10_runs('SimCLR\CIFAR\IDEC_SimCLR', colors_classes, device)

In [27]:
# df_simclr_nmis_cifar_idec = pd.DataFrame(simclr_nmis_cifar_idec)
# df_simclr_cas_cifar_idec = pd.DataFrame(simclr_cas_cifar_idec)

In [141]:
# df_simclr_nmis_cifar_idec.to_csv(path_nmis)
# df_simclr_cas_cifar_idec.to_csv(path_cas)

In [138]:
df_simclr_nmis_cifar_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar_idec = pd.read_csv(path_cas, index_col=0)

In [142]:
print_total_mean_and_var('NMI', df_simclr_nmis_cifar_idec)

NMI mean: 0.5569
NMI variance: 0.0001


In [143]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar_idec)

Cluser accuracy mean: 0.6576
Cluser accuracy variance: 0.0000


### STL10: Pretraining

In [46]:
path_nmis = 'trained_models/SimCLR/STL10/nmis.csv'
path_cas = 'trained_models/SimCLR/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
simclr_nmis_stl10, simclr_cas_stl10 = get_nmis_and_cas_10_runs('SimCLR\STL10\pretrained_SimCLR_STL10', stl10_colors_classes, device, dataset='STL10')

In [48]:
df_simclr_nmis_stl10 = pd.DataFrame(simclr_nmis_stl10)
df_simclr_cas_stl0 = pd.DataFrame(simclr_cas_stl10)

In [49]:
df_simclr_nmis_stl10.to_csv(path_nmis)
df_simclr_cas_stl0.to_csv(path_cas)

In [50]:
df_simclr_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl0 = pd.read_csv(path_cas, index_col=0)

In [51]:
print_total_mean_and_var('NMI', df_simclr_nmis_stl10)

NMI mean: 0.3502
NMI variance: 0.0015


In [52]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_stl0)

Cluser accuracy mean: 0.4114
Cluser accuracy variance: 0.0051


### STL10: +100 Pretraining

In [65]:
path_nmis = 'trained_models/SimCLR/STL10/nmis_100.csv'
path_cas = 'trained_models/SimCLR/STL10/cas_100.csv'
model_name = 'SimCLR/STL10/pretrained_SimCLR_100_STL10'

In [66]:
df_nmis_100, df_cas_100 = compute_everything(path_nmis, path_cas, model_name, stl10_colors_classes, compute=True, dataset='STL10')

SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth
Run: 0
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.3569698345805579
CA:0.4264112903225806

Run: 1
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.3508253159846291
CA:0.4261592741935484

Run: 2
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.35745604466282127
CA:0.41822076612903225

Run: 3
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.3537210088397678
CA:0.39742943548387094

Run: 4
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.3684602545190243
CA:0.4592993951612903

Run: 5
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.35942969557192095
CA:0.39591733870967744

Run: 6
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.3533192995519575
CA:0.4112903225806452

Run: 7
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.35090527296537494
CA:0.3897429435483871

Run: 8
SimCLR/STL10/pretrained_SimCLR_100_STL10_0.pth:
NMI:0.35455645591602863
CA:0.4122983870967742

Run: 9
SimCLR/STL10/pretrained_SimCLR

Run: 8
SimCLR/STL10/pretrained_SimCLR_100_STL10_8.pth:
NMI:0.3618373005741058
CA:0.4371219758064516

Run: 9
SimCLR/STL10/pretrained_SimCLR_100_STL10_8.pth:
NMI:0.35850066801374036
CA:0.4202368951612903

SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth
Run: 0
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.345780918406083
CA:0.4308215725806452

Run: 1
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.34263444537424864
CA:0.4275453629032258

Run: 2
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.3395235057401971
CA:0.4162046370967742

Run: 3
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.34818286140455823
CA:0.42477318548387094

Run: 4
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.34996029237783643
CA:0.43119959677419356

Run: 5
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.3483539582407274
CA:0.4348538306451613

Run: 6
SimCLR/STL10/pretrained_SimCLR_100_STL10_9.pth:
NMI:0.36046175392237284
CA:0.421875

Run: 7
SimCLR/STL10/pretrained_SimCLR_100_STL10_

### STL10: IDEC

In [129]:
path_nmis = 'trained_models/SimCLR/STL10/nmis_new.csv'
path_cas = 'trained_models/SimCLR/STL10/cas_new.csv'
model_name = 'SimCLR/STL10/pretrained_IDEC_SimCLR'

In [130]:
df_nmis_100, df_cas_100 = compute_everything(path_nmis, path_cas, model_name, stl10_colors_classes, compute=True, dataset='STL10')

SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.35185063644231607
CA:0.3821824596774194

Run: 1
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.351086659757207
CA:0.3824344758064516

Run: 2
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.35140961741247956
CA:0.3834425403225806

Run: 3
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.3508059195785592
CA:0.3815524193548387

Run: 4
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.3509081167066927
CA:0.3824344758064516

Run: 5
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.35094221413443727
CA:0.3824344758064516

Run: 6
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.3513377567819569
CA:0.38306451612903225

Run: 7
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.3512743600391995
CA:0.3829385080645161

Run: 8
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.3515272813476721
CA:0.3824344758064516

Run: 9
SimCLR/STL10/pretrained_IDEC_SimCLR_0.pth:
NMI:0.3518546498245455
CA:0.3836945564516129

SimCLR/STL10/pretrained_IDEC_SimCLR_1

  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.3629892805092482
CA:0.40032762096774194

Run: 1
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.3634968958654731
CA:0.40095766129032256

Run: 2
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.3633315528386298
CA:0.40083165322580644

Run: 3
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.3632249167819628
CA:0.40133568548387094

Run: 4
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.3633832004348692
CA:0.40045362903225806

Run: 5
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.3627780302463887
CA:0.40083165322580644

Run: 6
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.36247733387834996
CA:0.40083165322580644

Run: 7
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.36344233436282336
CA:0.4019657258064516

Run: 8
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.36354347508011337
CA:0.4017137096774194

Run: 9
SimCLR/STL10/pretrained_IDEC_SimCLR_1.pth:
NMI:0.36299613410836734
CA:0.40095766129032256

SimCLR/STL10/pretrained_IDEC

  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
SimCLR/STL10/pretrained_IDEC_SimCLR_2.pth:
NMI:0.3647841052034609
CA:0.39705141129032256

Run: 1
SimCLR/STL10/pretrained_IDEC_SimCLR_2.pth:
NMI:0.36502152163647766
CA:0.39818548387096775

Run: 2
SimCLR/STL10/pretrained_IDEC_SimCLR_2.pth:
NMI:0.36450288672299963
CA:0.3979334677419355

Run: 3
SimCLR/STL10/pretrained_IDEC_SimCLR_2.pth:
NMI:0.3652596724528327
CA:0.3973034274193548



KeyboardInterrupt: 

In [55]:
path_nmis_dec = 'trained_models/SimCLR/STL10/nmis_idec.csv'
path_cas_dec = 'trained_models/SimCLR/STL10/cas_idec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [58]:
simclr_nmis_stl10_idec, simclr_cas_stl10_idec = get_nmis_and_cas_10_runs('SimCLR\STL10\IDEC_SimCLR', stl10_colors_classes, device, dataset='STL10')

SimCLR\STL10\IDEC_SimCLR_0.pth
Run: 0
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.3869628444149622
CA:0.40940020161290325

Run: 1
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.3879816584306516
CA:0.4592993951612903

Run: 2
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.38804349746831
CA:0.45955141129032256

Run: 3
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.38824743891354596
CA:0.4609375

Run: 4
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.38797855102524814
CA:0.45904737903225806

Run: 5
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.389371483921124
CA:0.46018145161290325

Run: 6
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.38789969893711596
CA:0.46118951612903225

Run: 7
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.38762481632268164
CA:0.4598034274193548

Run: 8
SimCLR\STL10\IDEC_SimCLR_0.pth:
NMI:0.3885151264576402
CA:0.4604334677419355



KeyboardInterrupt: 

In [None]:
df_simclr_nmis_stl10_idec = pd.DataFrame(simclr_nmis_stl10_idec)
df_simclr_cas_stl10_idec = pd.DataFrame(simclr_cas_stl10_idec)

In [None]:
df_simclr_nmis_stl10_idec.to_csv(path_nmis)
df_simclr_cas_stl10_idec.to_csv(path_cas)

In [None]:
df_simclr_nmis_stl10_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl10_idec = pd.read_csv(path_cas, index_col=0)

In [None]:
print_total_mean_and_var('NMI', df_simclr_nmis_stl10_idec)

In [None]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_stl10_idec)