In [4]:
# Importing all necessary libraries
%load_ext autoreload
%autoreload 2

# internal packages
import os
from collections import Counter, OrderedDict

# external packages
import torch
import torchvision
import numpy as np
import sklearn
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, confusion_matrix
from sklearn.decomposition import PCA
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd

# util functions
from util.util import *

# dataset functions
from dataset import load_util

# autoencoder
from models.autoencoder.conv_ae import ConvAE
from models.simclr.simclr import *
from models.simclr.transforms import *
from models.rotnet.rotnet import *
from models.rotnet.IDEC import *
from models.rotnet.custom_stl10 import *
from cluster_accuracy import cluster_accuracy

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

<IPython.core.display.Javascript object>

In [10]:
print("Versions:")
print(f"torch: {torch.__version__}")
print(f"torchvision: {torchvision.__version__}")
print(f"numpy: {np.__version__}",)
print(f"scikit-learn: {sklearn.__version__}")

device = detect_device()
print("Using device: ", device)

Versions:
torch: 1.8.1+cu111
torchvision: 0.9.1+cu111
numpy: 1.19.5
scikit-learn: 0.24.1
Using device:  cuda


## Preparation

In [11]:
# specify learning params
batch_size = 128
learning_rate = 0.1
epochs = 100

# training

train = True

In [12]:
test_data = load_util.load_custom_cifar('./data', download=False, train=False, data_percent=1.0, for_model='SimCLR', transforms=False)
testloader = torch.utils.data.DataLoader(test_data,
                                          batch_size=128,
                                          shuffle=True,
                                          drop_last=True)

In [13]:
colors_classes = {i: color_class for i, color_class in zip(range(len(test_data.classes)), test_data.classes)}

In [14]:
def get_nmis_and_cas_10_runs(model_name, colors_classes, device):
    nmis = {}
    cas = {}

    for i in range(10):
        name = f'{model_name}_{i}.pth'
        print(name)
        
        model = load_model(name, device, torch.rand(size=(4, 12288)))
        nmis[i] = []
        cas[i] = []

        for k in range(10):
            labels, kmeans, nmi, _, _ = compute_nmi_and_pca(model, name, colors_classes, device, testloader)
            nmis[i].append(nmi)

            ca = cluster_accuracy(labels, kmeans.labels_)
            cas[i].append(ca)
            print(f'Run: {k}\n{name}:\nNMI:{nmi}\nCA:{ca}\n')
            del labels
            del kmeans
    return nmis, cas

In [15]:
def print_total_mean_and_var(metrics_name, df):
    models_means = df.mean()
    models_stds = df.std()

    total_mean = models_means.mean()
    total_std = models_stds.std()
    
    print(f'{metrics_name} mean: {total_mean:.4f}')
    print(f'{metrics_name} variance: {total_std:.4f}')

## RotNet

### CIFAR : Pretraining

In [110]:
path_nmis = 'trained_models/RotNet/CIFAR/nmis.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar, rotnet_cas_cifar = get_nmis_and_cas_10_runs('RotNet\CIFAR\pretrained_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar = pd.DataFrame(rotnet_nmis_cifar)
# df_rotnet_cas_cifar = pd.DataFrame(rotnet_cas_cifar)

In [74]:
# df_rotnet_nmis_cifar.to_csv(path_nmis)
# df_rotnet_cas_cifar.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [111]:
df_rotnet_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [112]:
print_total_mean_and_var('NMI', df_rotnet_nmis_cifar)

NMI mean: 0.3506
NMI variance: 0.0017


In [113]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_cifar)

Cluser accuracy mean: 0.4357
Cluser accuracy variance: 0.0051


### CIFAR: DEC

In [114]:
path_nmis = 'trained_models/RotNet/CIFAR/nmis_dec.csv'
path_cas = 'trained_models/RotNet/CIFAR/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_cifar_dec, rotnet_cas_cifar_dec = get_nmis_and_cas_10_runs('RotNet\CIFAR\DEC_RotNet', colors_classes, device)

In [None]:
# df_rotnet_nmis_cifar_dec = pd.DataFrame(rotnet_nmis_cifar_dec)
# df_rotnet_cas_cifar_dec = pd.DataFrame(rotnet_cas_cifar_dec)

In [105]:
# df_rotnet_nmis_cifar_dec.to_csv(path_nmis)
# df_rotnet_cas_cifar_dec.to_csv(path_cas)

This code loads existing dataframes from specified paths

In [115]:
df_rotnet_nmis_cifar_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_cifar_dec = pd.read_csv(path_cas, index_col=0)

In [117]:
print_total_mean_and_var('NMI', df_rotnet_nmis_cifar_dec)

NMI mean: 0.1042
NMI variance: 0.0001


In [118]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_cifar_dec)

Cluser accuracy mean: 0.2065
Cluser accuracy variance: 0.0002


### STL10: Pretraining

In [144]:
path_nmis = 'trained_models/RotNet/STL10/nmis.csv'
path_cas = 'trained_models/RotNet/STL10/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# rotnet_nmis_stl10, rotnet_cas_stl10 = get_nmis_and_cas_10_runs('RotNet\STL10\pretrained_RotNet_STL10', colors_classes, device)

In [None]:
# df_rotnet_nmis_stl10 = pd.DataFrame(rotnet_nmis_stl10)
# df_rotnet_cas_stl10 = pd.DataFrame(rotnet_cas_stl10)

In [149]:
# df_rotnet_nmis_stl10.to_csv(path_nmis)
# df_rotnet_cas_stl10.to_csv(path_cas)

In [145]:
df_rotnet_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10 = pd.read_csv(path_cas, index_col=0)

In [128]:
print_total_mean_and_var('NMI', df_rotnet_nmis_stl10)

NMI mean: 0.2977
NMI variance: 0.0022


In [129]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_stl10)

Cluser accuracy mean: 0.3826
Cluser accuracy variance: 0.0069


### STL10: DEC

In [8]:
path_nmis = 'trained_models/RotNet/STL10/nmis_dec.csv'
path_cas = 'trained_models/RotNet/STL10/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [16]:
rotnet_nmis_stl10_dec, rotnet_cas_stl10_dec = get_nmis_and_cas_10_runs('RotNet\STL10\DEC_RotNet_STL10', colors_classes, device)

RotNet\STL10\DEC_RotNet_STL10_0.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12208453487474598
CA:0.22876602564102563

Run: 1
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12155474772359603
CA:0.22846554487179488

Run: 2
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.121995399029329
CA:0.22906650641025642

Run: 3
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12096344375228539
CA:0.2272636217948718

Run: 4
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12348306702860942
CA:0.2281650641025641

Run: 5
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12218242504376206
CA:0.22876602564102563

Run: 6
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12175842687384915
CA:0.22896634615384615

Run: 7
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12156853269391757
CA:0.22826522435897437

Run: 8
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12202411442203921
CA:0.22876602564102563

Run: 9
RotNet\STL10\DEC_RotNet_STL10_0.pth:
NMI:0.12157784478325348
CA:0.22846554487179488

RotNet\STL10\DEC_RotNet_STL10_1.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11714083535160714
CA:0.21374198717948717

Run: 1
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11686183501592559
CA:0.21374198717948717

Run: 2
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11708882572716388
CA:0.21394230769230768

Run: 3
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11722081115418638
CA:0.21564503205128205

Run: 4
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.117096532542
CA:0.21354166666666666

Run: 5
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11690777099668934
CA:0.21374198717948717

Run: 6
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11667191082239005
CA:0.2150440705128205

Run: 7
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11683256582667337
CA:0.21494391025641027

Run: 8
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11718121158974824
CA:0.21374198717948717

Run: 9
RotNet\STL10\DEC_RotNet_STL10_1.pth:
NMI:0.11742802851143412
CA:0.2153445512820513

RotNet\STL10\DEC_RotNet_STL10_2.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11878755577315159
CA:0.2104366987179487

Run: 1
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11915803067001182
CA:0.21233974358974358

Run: 2
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.1181032359796965
CA:0.2109375

Run: 3
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11902453249001482
CA:0.2102363782051282

Run: 4
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11861027806390949
CA:0.20943509615384615

Run: 5
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11837208304506236
CA:0.2116386217948718

Run: 6
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.12005438883814992
CA:0.21003605769230768

Run: 7
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11892067056552408
CA:0.2102363782051282

Run: 8
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11844823581341114
CA:0.21153846153846154

Run: 9
RotNet\STL10\DEC_RotNet_STL10_2.pth:
NMI:0.11863243315694638
CA:0.21153846153846154

RotNet\STL10\DEC_RotNet_STL10_3.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11481090836530751
CA:0.2068309294871795

Run: 1
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11490314962938558
CA:0.20833333333333334

Run: 2
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11466961414364969
CA:0.20833333333333334

Run: 3
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.1147444918015921
CA:0.2075320512820513

Run: 4
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11462953464023001
CA:0.20813301282051283

Run: 5
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.1147921994354148
CA:0.20703125

Run: 6
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11463103235598757
CA:0.20743189102564102

Run: 7
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11485224278044892
CA:0.20793269230769232

Run: 8
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11464456767272883
CA:0.20783253205128205

Run: 9
RotNet\STL10\DEC_RotNet_STL10_3.pth:
NMI:0.11487219051556215
CA:0.20803285256410256

RotNet\STL10\DEC_RotNet_STL10_4.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10984162224051433
CA:0.21864983974358973

Run: 1
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10955392408189737
CA:0.2180488782051282

Run: 2
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10960210831407613
CA:0.2182491987179487

Run: 3
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10980239731816047
CA:0.21844951923076922

Run: 4
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.1065465723200158
CA:0.21684695512820512

Run: 5
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10954781815845377
CA:0.2180488782051282

Run: 6
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.1097335309044981
CA:0.21814903846153846

Run: 7
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10623292290041411
CA:0.21654647435897437

Run: 8
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10966196378986925
CA:0.21834935897435898

Run: 9
RotNet\STL10\DEC_RotNet_STL10_4.pth:
NMI:0.10968061311002841
CA:0.21814903846153846

RotNet\STL10\DEC_RotNet_STL10_5.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.011886814589668409
CA:0.10947516025641026

Run: 1
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012991787477077231
CA:0.11017628205128205

Run: 2
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012430547299732643
CA:0.10917467948717949

Run: 3
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.011917582538389532
CA:0.10897435897435898

Run: 4
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012346044006012831
CA:0.109375

Run: 5
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012578370581424818
CA:0.109375

Run: 6
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012582682913604253
CA:0.10977564102564102

Run: 7
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.011956522224587087
CA:0.10927483974358974

Run: 8
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012268233235650551
CA:0.10947516025641026

Run: 9
RotNet\STL10\DEC_RotNet_STL10_5.pth:
NMI:0.012891492804493367
CA:0.10987580128205128

RotNet\STL10\DEC_RotNet_STL10_6.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.02777956319256333
CA:0.1321113782051282

Run: 1
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.02691674992336181
CA:0.13191105769230768

Run: 2
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.02556482113749972
CA:0.12890625

Run: 3
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.027487564256295942
CA:0.13110977564102563

Run: 4
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.025603564859462167
CA:0.1291065705128205

Run: 5
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.02551681645902679
CA:0.12890625

Run: 6
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.025709456911561908
CA:0.12880608974358973

Run: 7
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.02583086695959913
CA:0.12850560897435898

Run: 8
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.027173225760340966
CA:0.13151041666666666

Run: 9
RotNet\STL10\DEC_RotNet_STL10_6.pth:
NMI:0.02698372103611639
CA:0.1323116987179487

RotNet\STL10\DEC_RotNet_STL10_7.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12926007906795292
CA:0.22415865384615385

Run: 1
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12839192692011347
CA:0.22155448717948717

Run: 2
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12948731435229918
CA:0.22395833333333334

Run: 3
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12801140782245685
CA:0.22175480769230768

Run: 4
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12832377368330014
CA:0.22165464743589744

Run: 5
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12825159196915634
CA:0.22155448717948717

Run: 6
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12781375720134547
CA:0.2224559294871795

Run: 7
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12927045954023636
CA:0.22435897435897437

Run: 8
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.12770130135894825
CA:0.21834935897435898

Run: 9
RotNet\STL10\DEC_RotNet_STL10_7.pth:
NMI:0.1245613036680977
CA:0.22355769230769232

RotNet\STL10\DEC_RotNet_STL10_8.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.1253114629800414
CA:0.2169471153846154

Run: 1
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12519931653552893
CA:0.21684695512820512

Run: 2
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12497405415195449
CA:0.21674679487179488

Run: 3
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12511579325221464
CA:0.2166466346153846

Run: 4
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12546297231309284
CA:0.21864983974358973

Run: 5
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12540116112335312
CA:0.2192508012820513

Run: 6
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12514338355219828
CA:0.2166466346153846

Run: 7
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12522213510986605
CA:0.21684695512820512

Run: 8
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12534790208237245
CA:0.21875

Run: 9
RotNet\STL10\DEC_RotNet_STL10_8.pth:
NMI:0.12504791934050963
CA:0.21734775641025642

RotNet\STL10\DEC_RotNet_STL10_9.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


Run: 0
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.1274944931730199
CA:0.20963541666666666

Run: 1
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.12731539234372838
CA:0.20973557692307693

Run: 2
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.1277800022873754
CA:0.20973557692307693

Run: 3
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.12757360905858173
CA:0.20953525641025642

Run: 4
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.1273482477345077
CA:0.20963541666666666

Run: 5
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.126852790142821
CA:0.20562900641025642

Run: 6
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.12753890718129676
CA:0.20993589743589744

Run: 7
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.1273642161753056
CA:0.21003605769230768

Run: 8
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.12736290116398388
CA:0.20963541666666666

Run: 9
RotNet\STL10\DEC_RotNet_STL10_9.pth:
NMI:0.12756945764023717
CA:0.20963541666666666



In [17]:
df_rotnet_nmis_stl10_dec = pd.DataFrame(rotnet_nmis_stl10_dec)
df_rotnet_cas_stl10_dec = pd.DataFrame(rotnet_cas_stl10_dec)

In [18]:
df_rotnet_nmis_stl10_dec.to_csv(path_nmis)
df_rotnet_cas_stl10_dec.to_csv(path_cas)

In [19]:
df_rotnet_nmis_stl10_dec = pd.read_csv(path_nmis, index_col=0)
df_rotnet_cas_stl10_dec = pd.read_csv(path_cas, index_col=0)

In [20]:
print_total_mean_and_var('NMI', df_rotnet_nmis_stl10_dec)

NMI mean: 0.1001
NMI variance: 0.0005


In [21]:
print_total_mean_and_var('Cluser accuracy', df_rotnet_cas_stl10_dec)

Cluser accuracy mean: 0.1968
Cluser accuracy variance: 0.0005


## SimCLR

### CIFAR: Pretraining

In [130]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar, simclr_cas_cifar = get_nmis_and_cas_10_runs('SimCLR\CIFAR\pretrained_SimCLR', colors_classes, device)

In [24]:
# df_simclr_nmis_cifar = pd.DataFrame(simclr_nmis_cifar)
# df_simclr_cas_cifar = pd.DataFrame(simclr_cas_cifar)

In [134]:
# df_simclr_nmis_cifar.to_csv(path_nmis)
# df_simclr_cas_cifar.to_csv(path_cas)

In [132]:
df_simclr_nmis_cifar = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar = pd.read_csv(path_cas, index_col=0)

In [135]:
print_total_mean_and_var('NMI', df_simclr_nmis_cifar)

NMI mean: 0.4724
NMI variance: 0.0033


In [136]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar)

Cluser accuracy mean: 0.5904
Cluser accuracy variance: 0.0090


In [98]:
columns = [f'Run {i}' for i in range(10)]

In [133]:
df_simclr_nmis_cifar.columns = columns
df_simclr_cas_cifar.columns = columns

### CIFAR: IDEC

In [137]:
path_nmis = 'trained_models/SimCLR/CIFAR/nmis_idec.csv'
path_cas = 'trained_models/SimCLR/CIFAR/cas_idec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [None]:
# simclr_nmis_cifar_idec, simclr_cas_cifar_idec = get_nmis_and_cas_10_runs('SimCLR\CIFAR\IDEC_SimCLR', colors_classes, device)

In [27]:
# df_simclr_nmis_cifar_idec = pd.DataFrame(simclr_nmis_cifar_idec)
# df_simclr_cas_cifar_idec = pd.DataFrame(simclr_cas_cifar_idec)

In [141]:
# df_simclr_nmis_cifar_idec.to_csv(path_nmis)
# df_simclr_cas_cifar_idec.to_csv(path_cas)

In [138]:
df_simclr_nmis_cifar_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_cifar_idec = pd.read_csv(path_cas, index_col=0)

In [142]:
print_total_mean_and_var('NMI', df_simclr_nmis_cifar_idec)

NMI mean: 0.5569
NMI variance: 0.0001


In [143]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_cifar_idec)

Cluser accuracy mean: 0.6576
Cluser accuracy variance: 0.0000


### STL10: Pretraining

In [22]:
path_nmis = 'trained_models/SimCLR/STL10/nmis_dec.csv'
path_cas = 'trained_models/SimCLR/STL10/cas_dec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [23]:
simclr_nmis_stl10, simclr_cas_stl10 = get_nmis_and_cas_10_runs('SimCLR\STL10\pretrained_SimCLR_STL10', colors_classes, device)

SimCLR\STL10\pretrained_SimCLR_STL10_0.pth
Run: 0
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.31808179438875117
CA:0.422275641025641

Run: 1
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.31469791407695513
CA:0.3915264423076923

Run: 2
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.31422274999198446
CA:0.41456330128205127

Run: 3
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.32019680765389813
CA:0.4305889423076923

Run: 4
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.3186096554249888
CA:0.4276842948717949

Run: 5
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.3252095068360046
CA:0.433994391025641

Run: 6
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.3261675674244283
CA:0.43970352564102566

Run: 7
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.31748564988624384
CA:0.42788461538461536

Run: 8
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.31879043087646225
CA:0.4282852564102564

Run: 9
SimCLR\STL10\pretrained_SimCLR_STL10_0.pth:
NMI:0.3233407561167077
CA:

Run: 0
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.3078016799568831
CA:0.36889022435897434

Run: 1
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.30451119012303807
CA:0.3759014423076923

Run: 2
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.30645449744366227
CA:0.4130608974358974

Run: 3
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.30372132368227234
CA:0.36889022435897434

Run: 4
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.3071798990859893
CA:0.36568509615384615

Run: 5
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.3056697968536708
CA:0.36047676282051283

Run: 6
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.31300577985800976
CA:0.4193709935897436

Run: 7
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.30443797839362396
CA:0.358974358974359

Run: 8
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.3110231860972469
CA:0.41786858974358976

Run: 9
SimCLR\STL10\pretrained_SimCLR_STL10_8.pth:
NMI:0.3078767920568552
CA:0.39923878205128205

SimCLR\STL10\pretrai

In [24]:
df_simclr_nmis_stl10 = pd.DataFrame(simclr_nmis_stl10)
df_simclr_cas_stl0 = pd.DataFrame(simclr_cas_stl10)

In [25]:
df_simclr_nmis_stl10.to_csv(path_nmis)
df_simclr_cas_stl0.to_csv(path_cas)

In [26]:
df_simclr_nmis_stl10 = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl0 = pd.read_csv(path_cas, index_col=0)

In [27]:
print_total_mean_and_var('NMI', df_simclr_nmis_stl10)

NMI mean: 0.3215
NMI variance: 0.0011


In [28]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_stl0)

Cluser accuracy mean: 0.4231
Cluser accuracy variance: 0.0067


### STL10: IDEC

In [29]:
path_nmis_dec = 'trained_models/SimCLR/STL10/nmis_idec.csv'
path_cas_dec = 'trained_models/SimCLR/STL10/cas_idec.csv'

Uncomment this part to generate new NMI and cluster accuracies

In [30]:
simclr_nmis_stl10_idec, simclr_cas_stl10_idec = get_nmis_and_cas_10_runs('SimCLR\STL10\IDEC_SimCLR_STL10', colors_classes, device)

SimCLR\STL10\IDEC_SimCLR_STL10_0.pth


  self.centers = torch.nn.Parameter(torch.tensor(init_np_centers), requires_grad=True)


FileNotFoundError: [Errno 2] No such file or directory: 'trained_models/SimCLR\\STL10\\IDEC_SimCLR_STL10_0.pth'

In [None]:
df_simclr_nmis_stl10_idec = pd.DataFrame(simclr_nmis_stl10_idec)
df_simclr_cas_stl10_idec = pd.DataFrame(simclr_cas_stl10_idec)

In [None]:
df_simclr_nmis_stl10_idec.to_csv(path_nmis)
df_simclr_cas_stl10_idec.to_csv(path_cas)

In [None]:
df_simclr_nmis_stl10_idec = pd.read_csv(path_nmis, index_col=0)
df_simclr_cas_stl10_idec = pd.read_csv(path_cas, index_col=0)

In [None]:
print_total_mean_and_var('NMI', df_simclr_nmis_stl10_idec)

In [None]:
print_total_mean_and_var('Cluser accuracy', df_simclr_cas_stl10_idec)