In [30]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
k = 64
tree_depth = 6
batch_size = 512
device = 'cuda'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [32]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [33]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [34]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [35]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.303014755249023 | KNN Loss: 5.731903553009033 | CLS Loss: 1.5711109638214111
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.449155807495117 | KNN Loss: 4.73738956451416 | CLS Loss: 0.7117664813995361
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 5.2877278327941895 | KNN Loss: 4.609186172485352 | CLS Loss: 0.6785417795181274
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 5.199596405029297 | KNN Loss: 4.553031921386719 | CLS Loss: 0.6465644836425781
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 5.04219388961792 | KNN Loss: 4.514331340789795 | CLS Loss: 0.5278626084327698
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 5.134302616119385 | KNN Loss: 4.567552089691162 | CLS Loss: 0.5667505860328674
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 5.0261921882629395 | KNN Loss: 4.5067338943481445 | CLS Loss: 0.5194582939147949
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.983083724975586 | KNN Loss: 4.485644817352295 | CLS L

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 4.453580856323242 | KNN Loss: 4.280031204223633 | CLS Loss: 0.17354948818683624
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 4.456829071044922 | KNN Loss: 4.278853893280029 | CLS Loss: 0.1779753863811493
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 4.422021389007568 | KNN Loss: 4.252720832824707 | CLS Loss: 0.169300377368927
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 4.439184665679932 | KNN Loss: 4.2957000732421875 | CLS Loss: 0.14348460733890533
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 4.44588041305542 | KNN Loss: 4.277695178985596 | CLS Loss: 0.16818521916866302
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 4.444919109344482 | KNN Loss: 4.2993974685668945 | CLS Loss: 0.1455218642950058
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 4.41594123840332 | KNN Loss: 4.2986297607421875 | CLS Loss: 0.11731164157390594
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 4.403146743774414 | KNN Loss: 4.2885956764221

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 4.271496772766113 | KNN Loss: 4.237403869628906 | CLS Loss: 0.03409278392791748
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 4.4215545654296875 | KNN Loss: 4.311953544616699 | CLS Loss: 0.10960079729557037
Epoch: 007, Loss: 4.3496, Train: 0.9781, Valid: 0.9762, Best: 0.9762
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 4.328602313995361 | KNN Loss: 4.25540018081665 | CLS Loss: 0.07320191711187363
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 4.354703426361084 | KNN Loss: 4.2576446533203125 | CLS Loss: 0.09705885499715805
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 4.363182544708252 | KNN Loss: 4.277888774871826 | CLS Loss: 0.08529369533061981
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 4.294840335845947 | KNN Loss: 4.2391510009765625 | CLS Loss: 0.05568942427635193
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 4.3226799964904785 | KNN Loss: 4.238986492156982 | CLS Loss: 0.08369331061840057
Epoch 8 / 200 | iteratio

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 4.319055080413818 | KNN Loss: 4.250992298126221 | CLS Loss: 0.06806281954050064
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 4.287293910980225 | KNN Loss: 4.216714859008789 | CLS Loss: 0.0705791637301445
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 4.3196024894714355 | KNN Loss: 4.24448823928833 | CLS Loss: 0.07511436194181442
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 4.336727142333984 | KNN Loss: 4.268614768981934 | CLS Loss: 0.06811220943927765
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 4.2627034187316895 | KNN Loss: 4.196775913238525 | CLS Loss: 0.06592747569084167
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 4.307506561279297 | KNN Loss: 4.2461066246032715 | CLS Loss: 0.061399850994348526
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 4.333591938018799 | KNN Loss: 4.2286481857299805 | CLS Loss: 0.10494387149810791
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 4.277791500091553 | KNN Loss: 4.23

Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 4.287924289703369 | KNN Loss: 4.241753578186035 | CLS Loss: 0.04617079719901085
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 4.2644243240356445 | KNN Loss: 4.207264423370361 | CLS Loss: 0.057159729301929474
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 4.306877136230469 | KNN Loss: 4.237769603729248 | CLS Loss: 0.06910748034715652
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 4.267545700073242 | KNN Loss: 4.225411415100098 | CLS Loss: 0.04213442653417587
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 4.248343467712402 | KNN Loss: 4.228784561157227 | CLS Loss: 0.019559090957045555
Epoch: 014, Loss: 4.2794, Train: 0.9854, Valid: 0.9820, Best: 0.9820
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 4.325794219970703 | KNN Loss: 4.262497425079346 | CLS Loss: 0.06329665333032608
Epoch 15 / 200 | iteration 10 / 171 | Total Loss: 4.276278972625732 | KNN Loss: 4.239833354949951 | CLS Loss: 0.0364457443356514
Epoch 15 / 200 

Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 4.303857803344727 | KNN Loss: 4.226426124572754 | CLS Loss: 0.07743176817893982
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 4.239066123962402 | KNN Loss: 4.17639684677124 | CLS Loss: 0.06266903877258301
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 4.26939058303833 | KNN Loss: 4.226423263549805 | CLS Loss: 0.04296719282865524
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 4.297374725341797 | KNN Loss: 4.213367462158203 | CLS Loss: 0.08400703966617584
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 4.271676540374756 | KNN Loss: 4.241681098937988 | CLS Loss: 0.02999560534954071
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 4.251960754394531 | KNN Loss: 4.231335639953613 | CLS Loss: 0.020625127479434013
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 4.242746353149414 | KNN Loss: 4.223612308502197 | CLS Loss: 0.019134126603603363
Epoch 18 / 200 | iteration 90 / 171 | Total Loss: 4.2283406257629395 | KNN Loss: 4.1855177

Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 4.259561538696289 | KNN Loss: 4.207320690155029 | CLS Loss: 0.052240628749132156
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 4.250393390655518 | KNN Loss: 4.197073936462402 | CLS Loss: 0.053319595754146576
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 4.259017467498779 | KNN Loss: 4.217649936676025 | CLS Loss: 0.041367728263139725
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 4.250395774841309 | KNN Loss: 4.1975555419921875 | CLS Loss: 0.05284019187092781
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 4.186774730682373 | KNN Loss: 4.170956611633301 | CLS Loss: 0.015817929059267044
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 4.233867168426514 | KNN Loss: 4.186148166656494 | CLS Loss: 0.04771913215517998
Epoch 21 / 200 | iteration 160 / 171 | Total Loss: 4.241882801055908 | KNN Loss: 4.191761493682861 | CLS Loss: 0.0501212440431118
Epoch 21 / 200 | iteration 170 / 171 | Total Loss: 4.277629852294922 | KNN Loss

Epoch: 024, Loss: 4.2390, Train: 0.9909, Valid: 0.9854, Best: 0.9854
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 4.265495300292969 | KNN Loss: 4.238539218902588 | CLS Loss: 0.02695620246231556
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 4.219695091247559 | KNN Loss: 4.180814743041992 | CLS Loss: 0.03888051211833954
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 4.232290744781494 | KNN Loss: 4.200233459472656 | CLS Loss: 0.032057516276836395
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 4.259670734405518 | KNN Loss: 4.225083827972412 | CLS Loss: 0.034587111324071884
Epoch 25 / 200 | iteration 40 / 171 | Total Loss: 4.251523017883301 | KNN Loss: 4.226673126220703 | CLS Loss: 0.024849819019436836
Epoch 25 / 200 | iteration 50 / 171 | Total Loss: 4.218588352203369 | KNN Loss: 4.196338653564453 | CLS Loss: 0.022249750792980194
Epoch 25 / 200 | iteration 60 / 171 | Total Loss: 4.233736515045166 | KNN Loss: 4.207286834716797 | CLS Loss: 0.02644953317940235
Epoch 25 / 200 | i

Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 4.275606155395508 | KNN Loss: 4.237156867980957 | CLS Loss: 0.03844938054680824
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 4.209798812866211 | KNN Loss: 4.193807601928711 | CLS Loss: 0.015991419553756714
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 4.20304536819458 | KNN Loss: 4.172657012939453 | CLS Loss: 0.03038826398551464
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 4.267799377441406 | KNN Loss: 4.218329906463623 | CLS Loss: 0.04946959391236305
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 4.239386558532715 | KNN Loss: 4.214444160461426 | CLS Loss: 0.024942200630903244
Epoch 28 / 200 | iteration 120 / 171 | Total Loss: 4.184256076812744 | KNN Loss: 4.156450271606445 | CLS Loss: 0.027805667370557785
Epoch 28 / 200 | iteration 130 / 171 | Total Loss: 4.254861831665039 | KNN Loss: 4.222545623779297 | CLS Loss: 0.032315973192453384
Epoch 28 / 200 | iteration 140 / 171 | Total Loss: 4.228485107421875 | KNN Loss: 4.

Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 4.229220390319824 | KNN Loss: 4.211800575256348 | CLS Loss: 0.017419731244444847
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 4.192266464233398 | KNN Loss: 4.172729969024658 | CLS Loss: 0.0195364598184824
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 4.243007659912109 | KNN Loss: 4.205879211425781 | CLS Loss: 0.03712835535407066
Epoch: 031, Loss: 4.2213, Train: 0.9923, Valid: 0.9853, Best: 0.9860
Epoch 32 / 200 | iteration 0 / 171 | Total Loss: 4.189051151275635 | KNN Loss: 4.1735639572143555 | CLS Loss: 0.015487393364310265
Epoch 32 / 200 | iteration 10 / 171 | Total Loss: 4.225273609161377 | KNN Loss: 4.191262722015381 | CLS Loss: 0.034010663628578186
Epoch 32 / 200 | iteration 20 / 171 | Total Loss: 4.220917701721191 | KNN Loss: 4.197173118591309 | CLS Loss: 0.02374441921710968
Epoch 32 / 200 | iteration 30 / 171 | Total Loss: 4.222548961639404 | KNN Loss: 4.196850299835205 | CLS Loss: 0.025698574259877205
Epoch 32 / 200 

Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 4.229077339172363 | KNN Loss: 4.195725917816162 | CLS Loss: 0.03335139527916908
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 4.223647594451904 | KNN Loss: 4.196990966796875 | CLS Loss: 0.02665676362812519
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 4.197102069854736 | KNN Loss: 4.175550937652588 | CLS Loss: 0.02155124582350254
Epoch 35 / 200 | iteration 70 / 171 | Total Loss: 4.214033126831055 | KNN Loss: 4.189996242523193 | CLS Loss: 0.024036716669797897
Epoch 35 / 200 | iteration 80 / 171 | Total Loss: 4.235169887542725 | KNN Loss: 4.202627182006836 | CLS Loss: 0.032542772591114044
Epoch 35 / 200 | iteration 90 / 171 | Total Loss: 4.245207786560059 | KNN Loss: 4.1876983642578125 | CLS Loss: 0.05750952288508415
Epoch 35 / 200 | iteration 100 / 171 | Total Loss: 4.187944412231445 | KNN Loss: 4.153622150421143 | CLS Loss: 0.0343221016228199
Epoch 35 / 200 | iteration 110 / 171 | Total Loss: 4.234807014465332 | KNN Loss: 4.1863

Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 4.225594520568848 | KNN Loss: 4.195630073547363 | CLS Loss: 0.029964551329612732
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 4.187765121459961 | KNN Loss: 4.166758060455322 | CLS Loss: 0.021006902679800987
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 4.200759410858154 | KNN Loss: 4.188342571258545 | CLS Loss: 0.012416871264576912
Epoch 38 / 200 | iteration 150 / 171 | Total Loss: 4.227649211883545 | KNN Loss: 4.194164276123047 | CLS Loss: 0.03348511457443237
Epoch 38 / 200 | iteration 160 / 171 | Total Loss: 4.270568370819092 | KNN Loss: 4.220340251922607 | CLS Loss: 0.05022818222641945
Epoch 38 / 200 | iteration 170 / 171 | Total Loss: 4.2346367835998535 | KNN Loss: 4.2113447189331055 | CLS Loss: 0.02329188585281372
Epoch: 038, Loss: 4.2077, Train: 0.9912, Valid: 0.9844, Best: 0.9860
Epoch 39 / 200 | iteration 0 / 171 | Total Loss: 4.251513481140137 | KNN Loss: 4.220104217529297 | CLS Loss: 0.03140903264284134
Epoch 39 / 

Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 4.178593158721924 | KNN Loss: 4.17196798324585 | CLS Loss: 0.006625153124332428
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 4.2046895027160645 | KNN Loss: 4.1780524253845215 | CLS Loss: 0.02663695067167282
Epoch 42 / 200 | iteration 30 / 171 | Total Loss: 4.174792766571045 | KNN Loss: 4.166518211364746 | CLS Loss: 0.008274378255009651
Epoch 42 / 200 | iteration 40 / 171 | Total Loss: 4.287055492401123 | KNN Loss: 4.255039691925049 | CLS Loss: 0.03201565518975258
Epoch 42 / 200 | iteration 50 / 171 | Total Loss: 4.218233585357666 | KNN Loss: 4.1948676109313965 | CLS Loss: 0.023365966975688934
Epoch 42 / 200 | iteration 60 / 171 | Total Loss: 4.211311340332031 | KNN Loss: 4.18679141998291 | CLS Loss: 0.02451995015144348
Epoch 42 / 200 | iteration 70 / 171 | Total Loss: 4.191420555114746 | KNN Loss: 4.151583671569824 | CLS Loss: 0.03983696922659874
Epoch 42 / 200 | iteration 80 / 171 | Total Loss: 4.273916721343994 | KNN Loss: 4.2302

Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 4.2150163650512695 | KNN Loss: 4.199432373046875 | CLS Loss: 0.015583793632686138
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 4.2043633460998535 | KNN Loss: 4.1850738525390625 | CLS Loss: 0.01928948424756527
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 4.213202476501465 | KNN Loss: 4.193416595458984 | CLS Loss: 0.019785992801189423
Epoch 45 / 200 | iteration 110 / 171 | Total Loss: 4.189053058624268 | KNN Loss: 4.167591571807861 | CLS Loss: 0.021461494266986847
Epoch 45 / 200 | iteration 120 / 171 | Total Loss: 4.203876495361328 | KNN Loss: 4.184410095214844 | CLS Loss: 0.019466590136289597
Epoch 45 / 200 | iteration 130 / 171 | Total Loss: 4.207544803619385 | KNN Loss: 4.178778648376465 | CLS Loss: 0.02876628190279007
Epoch 45 / 200 | iteration 140 / 171 | Total Loss: 4.223671913146973 | KNN Loss: 4.192754745483398 | CLS Loss: 0.030917160212993622
Epoch 45 / 200 | iteration 150 / 171 | Total Loss: 4.210974216461182 | KNN Lo

Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 4.202439785003662 | KNN Loss: 4.172946929931641 | CLS Loss: 0.02949298359453678
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 4.217105865478516 | KNN Loss: 4.204037189483643 | CLS Loss: 0.013068810105323792
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 4.218495845794678 | KNN Loss: 4.19435453414917 | CLS Loss: 0.02414124459028244
Epoch: 048, Loss: 4.1964, Train: 0.9943, Valid: 0.9858, Best: 0.9877
Epoch 49 / 200 | iteration 0 / 171 | Total Loss: 4.174630165100098 | KNN Loss: 4.161332130432129 | CLS Loss: 0.01329818181693554
Epoch 49 / 200 | iteration 10 / 171 | Total Loss: 4.209981441497803 | KNN Loss: 4.201780796051025 | CLS Loss: 0.008200523443520069
Epoch 49 / 200 | iteration 20 / 171 | Total Loss: 4.175197124481201 | KNN Loss: 4.1686320304870605 | CLS Loss: 0.006565262097865343
Epoch 49 / 200 | iteration 30 / 171 | Total Loss: 4.208946228027344 | KNN Loss: 4.193747520446777 | CLS Loss: 0.015198523178696632
Epoch 49 / 200 

Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 4.1946821212768555 | KNN Loss: 4.176973819732666 | CLS Loss: 0.017708133906126022
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 4.202315807342529 | KNN Loss: 4.189834117889404 | CLS Loss: 0.012481702491641045
Epoch 52 / 200 | iteration 60 / 171 | Total Loss: 4.20235013961792 | KNN Loss: 4.175153732299805 | CLS Loss: 0.02719629369676113
Epoch 52 / 200 | iteration 70 / 171 | Total Loss: 4.306292533874512 | KNN Loss: 4.272714614868164 | CLS Loss: 0.03357803821563721
Epoch 52 / 200 | iteration 80 / 171 | Total Loss: 4.217885971069336 | KNN Loss: 4.194332122802734 | CLS Loss: 0.023553909733891487
Epoch 52 / 200 | iteration 90 / 171 | Total Loss: 4.210465908050537 | KNN Loss: 4.188170433044434 | CLS Loss: 0.022295527160167694
Epoch 52 / 200 | iteration 100 / 171 | Total Loss: 4.2572150230407715 | KNN Loss: 4.216191291809082 | CLS Loss: 0.041023626923561096
Epoch 52 / 200 | iteration 110 / 171 | Total Loss: 4.2148590087890625 | KNN Loss: 4

Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 4.283218860626221 | KNN Loss: 4.244460582733154 | CLS Loss: 0.038758207112550735
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 4.160972595214844 | KNN Loss: 4.155418872833252 | CLS Loss: 0.005553707480430603
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 4.201815605163574 | KNN Loss: 4.174366474151611 | CLS Loss: 0.02744893543422222
Epoch 55 / 200 | iteration 140 / 171 | Total Loss: 4.202620506286621 | KNN Loss: 4.1774773597717285 | CLS Loss: 0.025143008679151535
Epoch 55 / 200 | iteration 150 / 171 | Total Loss: 4.160615921020508 | KNN Loss: 4.144134521484375 | CLS Loss: 0.016481172293424606
Epoch 55 / 200 | iteration 160 / 171 | Total Loss: 4.17690372467041 | KNN Loss: 4.168121337890625 | CLS Loss: 0.00878248829394579
Epoch 55 / 200 | iteration 170 / 171 | Total Loss: 4.2019243240356445 | KNN Loss: 4.184511661529541 | CLS Loss: 0.017412450164556503
Epoch: 055, Loss: 4.1892, Train: 0.9947, Valid: 0.9854, Best: 0.9877
Epoch 56

Epoch: 058, Loss: 4.1864, Train: 0.9949, Valid: 0.9857, Best: 0.9877
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 4.1580610275268555 | KNN Loss: 4.152759552001953 | CLS Loss: 0.005301548633724451
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 4.188549995422363 | KNN Loss: 4.148524284362793 | CLS Loss: 0.040025655180215836
Epoch 59 / 200 | iteration 20 / 171 | Total Loss: 4.178872108459473 | KNN Loss: 4.157114505767822 | CLS Loss: 0.021757641807198524
Epoch 59 / 200 | iteration 30 / 171 | Total Loss: 4.163790225982666 | KNN Loss: 4.147476673126221 | CLS Loss: 0.016313612461090088
Epoch 59 / 200 | iteration 40 / 171 | Total Loss: 4.23149299621582 | KNN Loss: 4.222084045410156 | CLS Loss: 0.009408792480826378
Epoch 59 / 200 | iteration 50 / 171 | Total Loss: 4.192058563232422 | KNN Loss: 4.184995174407959 | CLS Loss: 0.007063187193125486
Epoch 59 / 200 | iteration 60 / 171 | Total Loss: 4.179533004760742 | KNN Loss: 4.176252365112305 | CLS Loss: 0.003280709031969309
Epoch 59 / 200 

Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 4.202852249145508 | KNN Loss: 4.180104732513428 | CLS Loss: 0.02274765633046627
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 4.186528205871582 | KNN Loss: 4.161948204040527 | CLS Loss: 0.024579916149377823
Epoch 62 / 200 | iteration 90 / 171 | Total Loss: 4.224067687988281 | KNN Loss: 4.204464435577393 | CLS Loss: 0.019603334367275238
Epoch 62 / 200 | iteration 100 / 171 | Total Loss: 4.140321254730225 | KNN Loss: 4.1305155754089355 | CLS Loss: 0.009805455803871155
Epoch 62 / 200 | iteration 110 / 171 | Total Loss: 4.218142032623291 | KNN Loss: 4.1864471435546875 | CLS Loss: 0.031695082783699036
Epoch 62 / 200 | iteration 120 / 171 | Total Loss: 4.177969932556152 | KNN Loss: 4.159965991973877 | CLS Loss: 0.018004121258854866
Epoch 62 / 200 | iteration 130 / 171 | Total Loss: 4.169165134429932 | KNN Loss: 4.156026363372803 | CLS Loss: 0.0131387235596776
Epoch 62 / 200 | iteration 140 / 171 | Total Loss: 4.2020158767700195 | KNN Loss

Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 4.149967193603516 | KNN Loss: 4.126682758331299 | CLS Loss: 0.023284204304218292
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 4.1950812339782715 | KNN Loss: 4.183898448944092 | CLS Loss: 0.01118288841098547
Epoch 65 / 200 | iteration 160 / 171 | Total Loss: 4.1710052490234375 | KNN Loss: 4.163559436798096 | CLS Loss: 0.007445796392858028
Epoch 65 / 200 | iteration 170 / 171 | Total Loss: 4.160733699798584 | KNN Loss: 4.144943714141846 | CLS Loss: 0.015789829194545746
Epoch: 065, Loss: 4.1793, Train: 0.9953, Valid: 0.9852, Best: 0.9877
Epoch 66 / 200 | iteration 0 / 171 | Total Loss: 4.219118118286133 | KNN Loss: 4.183114528656006 | CLS Loss: 0.03600335121154785
Epoch 66 / 200 | iteration 10 / 171 | Total Loss: 4.160523891448975 | KNN Loss: 4.146389961242676 | CLS Loss: 0.01413408201187849
Epoch 66 / 200 | iteration 20 / 171 | Total Loss: 4.176066875457764 | KNN Loss: 4.167725563049316 | CLS Loss: 0.008341513574123383
Epoch 66 / 2

Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 4.161281585693359 | KNN Loss: 4.155218601226807 | CLS Loss: 0.006062796339392662
Epoch 69 / 200 | iteration 40 / 171 | Total Loss: 4.206069469451904 | KNN Loss: 4.188847541809082 | CLS Loss: 0.017221912741661072
Epoch 69 / 200 | iteration 50 / 171 | Total Loss: 4.157599925994873 | KNN Loss: 4.15139102935791 | CLS Loss: 0.00620880164206028
Epoch 69 / 200 | iteration 60 / 171 | Total Loss: 4.141510963439941 | KNN Loss: 4.135335445404053 | CLS Loss: 0.006175358314067125
Epoch 69 / 200 | iteration 70 / 171 | Total Loss: 4.170724868774414 | KNN Loss: 4.14848518371582 | CLS Loss: 0.02223968505859375
Epoch 69 / 200 | iteration 80 / 171 | Total Loss: 4.239158630371094 | KNN Loss: 4.203597545623779 | CLS Loss: 0.03556087985634804
Epoch 69 / 200 | iteration 90 / 171 | Total Loss: 4.225435733795166 | KNN Loss: 4.179186820983887 | CLS Loss: 0.04624883085489273
Epoch 69 / 200 | iteration 100 / 171 | Total Loss: 4.150387763977051 | KNN Loss: 4.144852

Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 4.170676231384277 | KNN Loss: 4.148681640625 | CLS Loss: 0.02199462056159973
Epoch 72 / 200 | iteration 110 / 171 | Total Loss: 4.2044477462768555 | KNN Loss: 4.183150291442871 | CLS Loss: 0.021297693252563477
Epoch 72 / 200 | iteration 120 / 171 | Total Loss: 4.215217113494873 | KNN Loss: 4.193902969360352 | CLS Loss: 0.021314125508069992
Epoch 72 / 200 | iteration 130 / 171 | Total Loss: 4.149684906005859 | KNN Loss: 4.13804817199707 | CLS Loss: 0.011636967770755291
Epoch 72 / 200 | iteration 140 / 171 | Total Loss: 4.20574426651001 | KNN Loss: 4.187355041503906 | CLS Loss: 0.01838923618197441
Epoch 72 / 200 | iteration 150 / 171 | Total Loss: 4.197363376617432 | KNN Loss: 4.1769843101501465 | CLS Loss: 0.0203792005777359
Epoch 72 / 200 | iteration 160 / 171 | Total Loss: 4.219241619110107 | KNN Loss: 4.185904502868652 | CLS Loss: 0.03333720564842224
Epoch 72 / 200 | iteration 170 / 171 | Total Loss: 4.235894203186035 | KNN Loss: 4.2

Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 4.165290355682373 | KNN Loss: 4.154426574707031 | CLS Loss: 0.010863877832889557
Epoch: 075, Loss: 4.1779, Train: 0.9956, Valid: 0.9850, Best: 0.9877
Epoch 76 / 200 | iteration 0 / 171 | Total Loss: 4.209721088409424 | KNN Loss: 4.195563793182373 | CLS Loss: 0.014157470315694809
Epoch 76 / 200 | iteration 10 / 171 | Total Loss: 4.16067361831665 | KNN Loss: 4.154938697814941 | CLS Loss: 0.005734800361096859
Epoch 76 / 200 | iteration 20 / 171 | Total Loss: 4.171138763427734 | KNN Loss: 4.147661209106445 | CLS Loss: 0.023477718234062195
Epoch 76 / 200 | iteration 30 / 171 | Total Loss: 4.242914199829102 | KNN Loss: 4.216883659362793 | CLS Loss: 0.026030760258436203
Epoch 76 / 200 | iteration 40 / 171 | Total Loss: 4.16958475112915 | KNN Loss: 4.156491279602051 | CLS Loss: 0.013093606568872929
Epoch 76 / 200 | iteration 50 / 171 | Total Loss: 4.16823148727417 | KNN Loss: 4.158453464508057 | CLS Loss: 0.009778019040822983
Epoch 76 / 200 | 

Epoch 79 / 200 | iteration 60 / 171 | Total Loss: 4.166930198669434 | KNN Loss: 4.16139030456543 | CLS Loss: 0.005539972335100174
Epoch 79 / 200 | iteration 70 / 171 | Total Loss: 4.169836044311523 | KNN Loss: 4.141866207122803 | CLS Loss: 0.027969712391495705
Epoch 79 / 200 | iteration 80 / 171 | Total Loss: 4.2296037673950195 | KNN Loss: 4.2214484214782715 | CLS Loss: 0.008155218325555325
Epoch 79 / 200 | iteration 90 / 171 | Total Loss: 4.1908769607543945 | KNN Loss: 4.160884857177734 | CLS Loss: 0.029992029070854187
Epoch 79 / 200 | iteration 100 / 171 | Total Loss: 4.188178062438965 | KNN Loss: 4.1837286949157715 | CLS Loss: 0.004449509549885988
Epoch 79 / 200 | iteration 110 / 171 | Total Loss: 4.142493724822998 | KNN Loss: 4.134003162384033 | CLS Loss: 0.008490631356835365
Epoch 79 / 200 | iteration 120 / 171 | Total Loss: 4.185122489929199 | KNN Loss: 4.175537109375 | CLS Loss: 0.009585217572748661
Epoch 79 / 200 | iteration 130 / 171 | Total Loss: 4.192245006561279 | KNN Loss:

Epoch 82 / 200 | iteration 130 / 171 | Total Loss: 4.155994415283203 | KNN Loss: 4.132330417633057 | CLS Loss: 0.023664116859436035
Epoch 82 / 200 | iteration 140 / 171 | Total Loss: 4.1669535636901855 | KNN Loss: 4.163586616516113 | CLS Loss: 0.0033669774420559406
Epoch 82 / 200 | iteration 150 / 171 | Total Loss: 4.171258926391602 | KNN Loss: 4.163424015045166 | CLS Loss: 0.00783504731953144
Epoch 82 / 200 | iteration 160 / 171 | Total Loss: 4.204437732696533 | KNN Loss: 4.1814751625061035 | CLS Loss: 0.02296256273984909
Epoch 82 / 200 | iteration 170 / 171 | Total Loss: 4.179472923278809 | KNN Loss: 4.148589611053467 | CLS Loss: 0.03088338114321232
Epoch: 082, Loss: 4.1793, Train: 0.9959, Valid: 0.9866, Best: 0.9878
Epoch 83 / 200 | iteration 0 / 171 | Total Loss: 4.178383827209473 | KNN Loss: 4.163675785064697 | CLS Loss: 0.014707834459841251
Epoch 83 / 200 | iteration 10 / 171 | Total Loss: 4.192801475524902 | KNN Loss: 4.166308879852295 | CLS Loss: 0.02649264968931675
Epoch 83 / 

Epoch 86 / 200 | iteration 20 / 171 | Total Loss: 4.153111457824707 | KNN Loss: 4.150784969329834 | CLS Loss: 0.0023263453040271997
Epoch 86 / 200 | iteration 30 / 171 | Total Loss: 4.211200714111328 | KNN Loss: 4.2052083015441895 | CLS Loss: 0.005992430727928877
Epoch 86 / 200 | iteration 40 / 171 | Total Loss: 4.188557147979736 | KNN Loss: 4.170945644378662 | CLS Loss: 0.017611712217330933
Epoch 86 / 200 | iteration 50 / 171 | Total Loss: 4.149726867675781 | KNN Loss: 4.135758399963379 | CLS Loss: 0.013968263752758503
Epoch 86 / 200 | iteration 60 / 171 | Total Loss: 4.198293685913086 | KNN Loss: 4.182538032531738 | CLS Loss: 0.015755776315927505
Epoch 86 / 200 | iteration 70 / 171 | Total Loss: 4.164061069488525 | KNN Loss: 4.139739036560059 | CLS Loss: 0.02432223968207836
Epoch 86 / 200 | iteration 80 / 171 | Total Loss: 4.149442672729492 | KNN Loss: 4.140484809875488 | CLS Loss: 0.008958091959357262
Epoch 86 / 200 | iteration 90 / 171 | Total Loss: 4.155911445617676 | KNN Loss: 4.

Epoch 89 / 200 | iteration 90 / 171 | Total Loss: 4.1525468826293945 | KNN Loss: 4.1374945640563965 | CLS Loss: 0.015052217990159988
Epoch 89 / 200 | iteration 100 / 171 | Total Loss: 4.18501091003418 | KNN Loss: 4.164855003356934 | CLS Loss: 0.020155927166342735
Epoch 89 / 200 | iteration 110 / 171 | Total Loss: 4.166687488555908 | KNN Loss: 4.149980545043945 | CLS Loss: 0.01670677773654461
Epoch 89 / 200 | iteration 120 / 171 | Total Loss: 4.208606719970703 | KNN Loss: 4.180209159851074 | CLS Loss: 0.028397560119628906
Epoch 89 / 200 | iteration 130 / 171 | Total Loss: 4.2185540199279785 | KNN Loss: 4.210280418395996 | CLS Loss: 0.008273464627563953
Epoch 89 / 200 | iteration 140 / 171 | Total Loss: 4.214430809020996 | KNN Loss: 4.202151298522949 | CLS Loss: 0.012279417365789413
Epoch 89 / 200 | iteration 150 / 171 | Total Loss: 4.182023048400879 | KNN Loss: 4.167685031890869 | CLS Loss: 0.014338172972202301
Epoch 89 / 200 | iteration 160 / 171 | Total Loss: 4.17145299911499 | KNN Lo

Epoch 92 / 200 | iteration 160 / 171 | Total Loss: 4.180145740509033 | KNN Loss: 4.1617536544799805 | CLS Loss: 0.01839185319840908
Epoch 92 / 200 | iteration 170 / 171 | Total Loss: 4.194956302642822 | KNN Loss: 4.1919379234313965 | CLS Loss: 0.0030185317154973745
Epoch: 092, Loss: 4.1729, Train: 0.9963, Valid: 0.9865, Best: 0.9878
Epoch 93 / 200 | iteration 0 / 171 | Total Loss: 4.200282573699951 | KNN Loss: 4.175217628479004 | CLS Loss: 0.025065094232559204
Epoch 93 / 200 | iteration 10 / 171 | Total Loss: 4.13863468170166 | KNN Loss: 4.134592533111572 | CLS Loss: 0.0040422710590064526
Epoch 93 / 200 | iteration 20 / 171 | Total Loss: 4.167498588562012 | KNN Loss: 4.149827003479004 | CLS Loss: 0.017671680077910423
Epoch 93 / 200 | iteration 30 / 171 | Total Loss: 4.143689155578613 | KNN Loss: 4.1400861740112305 | CLS Loss: 0.0036030588671565056
Epoch 93 / 200 | iteration 40 / 171 | Total Loss: 4.132369041442871 | KNN Loss: 4.126833915710449 | CLS Loss: 0.005535160191357136
Epoch 93 

Epoch 96 / 200 | iteration 50 / 171 | Total Loss: 4.188086986541748 | KNN Loss: 4.171458721160889 | CLS Loss: 0.01662849821150303
Epoch 96 / 200 | iteration 60 / 171 | Total Loss: 4.154277324676514 | KNN Loss: 4.1363911628723145 | CLS Loss: 0.017885973677039146
Epoch 96 / 200 | iteration 70 / 171 | Total Loss: 4.154818534851074 | KNN Loss: 4.14848518371582 | CLS Loss: 0.006333326920866966
Epoch 96 / 200 | iteration 80 / 171 | Total Loss: 4.169034481048584 | KNN Loss: 4.16160249710083 | CLS Loss: 0.007432049140334129
Epoch 96 / 200 | iteration 90 / 171 | Total Loss: 4.147444248199463 | KNN Loss: 4.139397144317627 | CLS Loss: 0.008047022856771946
Epoch 96 / 200 | iteration 100 / 171 | Total Loss: 4.214555740356445 | KNN Loss: 4.159028053283691 | CLS Loss: 0.05552755668759346
Epoch 96 / 200 | iteration 110 / 171 | Total Loss: 4.18003511428833 | KNN Loss: 4.1660075187683105 | CLS Loss: 0.014027445577085018
Epoch 96 / 200 | iteration 120 / 171 | Total Loss: 4.181146144866943 | KNN Loss: 4.1

Epoch 99 / 200 | iteration 120 / 171 | Total Loss: 4.182658672332764 | KNN Loss: 4.172940254211426 | CLS Loss: 0.009718348272144794
Epoch 99 / 200 | iteration 130 / 171 | Total Loss: 4.150428295135498 | KNN Loss: 4.142909049987793 | CLS Loss: 0.007519348058849573
Epoch 99 / 200 | iteration 140 / 171 | Total Loss: 4.1450347900390625 | KNN Loss: 4.140814781188965 | CLS Loss: 0.004220094531774521
Epoch 99 / 200 | iteration 150 / 171 | Total Loss: 4.186178207397461 | KNN Loss: 4.180509090423584 | CLS Loss: 0.00566928181797266
Epoch 99 / 200 | iteration 160 / 171 | Total Loss: 4.161584854125977 | KNN Loss: 4.156813621520996 | CLS Loss: 0.0047710672952234745
Epoch 99 / 200 | iteration 170 / 171 | Total Loss: 4.174562931060791 | KNN Loss: 4.171230316162109 | CLS Loss: 0.0033327247947454453
Epoch: 099, Loss: 4.1704, Train: 0.9962, Valid: 0.9871, Best: 0.9878
Epoch 100 / 200 | iteration 0 / 171 | Total Loss: 4.171031475067139 | KNN Loss: 4.1543288230896 | CLS Loss: 0.01670273393392563
Epoch 100

Epoch 103 / 200 | iteration 0 / 171 | Total Loss: 4.172691345214844 | KNN Loss: 4.165294170379639 | CLS Loss: 0.007397236302495003
Epoch 103 / 200 | iteration 10 / 171 | Total Loss: 4.156515121459961 | KNN Loss: 4.151280403137207 | CLS Loss: 0.005234846379607916
Epoch 103 / 200 | iteration 20 / 171 | Total Loss: 4.1643967628479 | KNN Loss: 4.157748222351074 | CLS Loss: 0.006648611277341843
Epoch 103 / 200 | iteration 30 / 171 | Total Loss: 4.1371684074401855 | KNN Loss: 4.1332855224609375 | CLS Loss: 0.0038827117532491684
Epoch 103 / 200 | iteration 40 / 171 | Total Loss: 4.189187049865723 | KNN Loss: 4.182706832885742 | CLS Loss: 0.006480189971625805
Epoch 103 / 200 | iteration 50 / 171 | Total Loss: 4.200575351715088 | KNN Loss: 4.194869518280029 | CLS Loss: 0.005705840419977903
Epoch 103 / 200 | iteration 60 / 171 | Total Loss: 4.1659345626831055 | KNN Loss: 4.159405708312988 | CLS Loss: 0.006528736092150211
Epoch 103 / 200 | iteration 70 / 171 | Total Loss: 4.186144828796387 | KNN 

In [37]:
test(model, test_data_iter, device)

tensor(0.9879, device='cuda:0')

In [39]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [40]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [41]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [42]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [43]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [44]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.9344022657713216


In [45]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [46]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [47]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [48]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [49]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [50]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [51]:
losses = []
accs = []
sparsity = []

In [52]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 040 | Total loss: 1.575 | Reg loss: 0.007 | Tree loss: 1.575 | Accuracy: 0.066406 | 0.07 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 01 | Batch: 000 / 040 | Total loss: 1.395 | Reg loss: 0.004 | Tree loss: 1.395 | Accuracy: 0.689453 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 02 | Batch: 000 / 040 | Total loss: 1.311 | Reg loss: 0.006 | Tree loss: 1.311 | Accuracy: 0.636719 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch:

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 28 | Batch: 000 / 040 | Total loss: 0.815 | Reg loss: 0.021 | Tree loss: 0.815 | Accuracy: 0.660156 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 29 | Batch: 000 / 040 | Total loss: 0.802 | Reg loss: 0.021 | Tree loss: 0.802 | Accuracy: 0.669922 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 30 | Batch: 000 / 040 | Total loss: 0.781 | Reg loss: 0.021 | Tree loss: 0.781 | Accuracy: 0.675781 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 56 | Batch: 000 / 040 | Total loss: 0.692 | Reg loss: 0.020 | Tree loss: 0.692 | Accuracy: 0.679688 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 57 | Batch: 000 / 040 | Total loss: 0.701 | Reg loss: 0.020 | Tree loss: 0.701 | Accuracy: 0.697266 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 58 | Batch: 000 / 040 | Total loss: 0.640 | Reg loss: 0.020 | Tree loss: 0.640 | Accuracy: 0.718750 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 84 | Batch: 000 / 040 | Total loss: 0.692 | Reg loss: 0.020 | Tree loss: 0.692 | Accuracy: 0.681641 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 85 | Batch: 000 / 040 | Total loss: 0.701 | Reg loss: 0.020 | Tree loss: 0.701 | Accuracy: 0.685547 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 86 | Batch: 000 / 040 | Total loss: 0.623 | Reg loss: 0.020 | Tree loss: 0.623 | Accuracy: 0.724609 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 112 | Batch: 000 / 040 | Total loss: 0.749 | Reg loss: 0.020 | Tree loss: 0.749 | Accuracy: 0.666016 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 113 | Batch: 000 / 040 | Total loss: 0.666 | Reg loss: 0.020 | Tree loss: 0.666 | Accuracy: 0.691406 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 114 | Batch: 000 / 040 | Total loss: 0.646 | Reg loss: 0.020 | Tree loss: 0.646 | Accuracy: 0.707031 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 140 | Batch: 000 / 040 | Total loss: 0.671 | Reg loss: 0.020 | Tree loss: 0.671 | Accuracy: 0.697266 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 141 | Batch: 000 / 040 | Total loss: 0.702 | Reg loss: 0.020 | Tree loss: 0.702 | Accuracy: 0.712891 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 142 | Batch: 000 / 040 | Total loss: 0.726 | Reg loss: 0.020 | Tree loss: 0.726 | Accuracy: 0.679688 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 168 | Batch: 000 / 040 | Total loss: 0.593 | Reg loss: 0.020 | Tree loss: 0.593 | Accuracy: 0.742188 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 169 | Batch: 000 / 040 | Total loss: 0.672 | Reg loss: 0.020 | Tree loss: 0.672 | Accuracy: 0.716797 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 170 | Batch: 000 / 040 | Total loss: 0.652 | Reg loss: 0.020 | Tree loss: 0.652 | Accuracy: 0.718750 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 196 | Batch: 000 / 040 | Total loss: 0.645 | Reg loss: 0.020 | Tree loss: 0.645 | Accuracy: 0.720703 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 197 | Batch: 000 / 040 | Total loss: 0.674 | Reg loss: 0.020 | Tree loss: 0.674 | Accuracy: 0.755859 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 198 | Batch: 000 / 040 | Total loss: 0.641 | Reg loss: 0.020 | Tree loss: 0.641 | Accuracy: 0.742188 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 224 | Batch: 000 / 040 | Total loss: 0.614 | Reg loss: 0.020 | Tree loss: 0.614 | Accuracy: 0.769531 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 225 | Batch: 000 / 040 | Total loss: 0.582 | Reg loss: 0.020 | Tree loss: 0.582 | Accuracy: 0.775391 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 226 | Batch: 000 / 040 | Total loss: 0.684 | Reg loss: 0.020 | Tree loss: 0.684 | Accuracy: 0.691406 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 252 | Batch: 000 / 040 | Total loss: 0.660 | Reg loss: 0.020 | Tree loss: 0.660 | Accuracy: 0.724609 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 253 | Batch: 000 / 040 | Total loss: 0.657 | Reg loss: 0.020 | Tree loss: 0.657 | Accuracy: 0.718750 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 254 | Batch: 000 / 040 | Total loss: 0.674 | Reg loss: 0.020 | Tree loss: 0.674 | Accuracy: 0.708984 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 280 | Batch: 000 / 040 | Total loss: 0.635 | Reg loss: 0.020 | Tree loss: 0.635 | Accuracy: 0.755859 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 281 | Batch: 000 / 040 | Total loss: 0.616 | Reg loss: 0.020 | Tree loss: 0.616 | Accuracy: 0.748047 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 282 | Batch: 000 / 040 | Total loss: 0.619 | Reg loss: 0.020 | Tree loss: 0.619 | Accuracy: 0.746094 | 0.046 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 308 | Batch: 000 / 040 | Total loss: 0.633 | Reg loss: 0.020 | Tree loss: 0.633 | Accuracy: 0.753906 | 0.047 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 309 | Batch: 000 / 040 | Total loss: 0.639 | Reg loss: 0.020 | Tree loss: 0.639 | Accuracy: 0.751953 | 0.047 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 310 | Batch: 000 / 040 | Total loss: 0.637 | Reg loss: 0.020 | Tree loss: 0.637 | Accuracy: 0.750000 | 0.047 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 336 | Batch: 000 / 040 | Total loss: 0.558 | Reg loss: 0.020 | Tree loss: 0.558 | Accuracy: 0.765625 | 0.048 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 337 | Batch: 000 / 040 | Total loss: 0.576 | Reg loss: 0.020 | Tree loss: 0.576 | Accuracy: 0.767578 | 0.048 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 338 | Batch: 000 / 040 | Total loss: 0.640 | Reg loss: 0.020 | Tree loss: 0.640 | Accuracy: 0.753906 | 0.048 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 364 | Batch: 000 / 040 | Total loss: 0.681 | Reg loss: 0.020 | Tree loss: 0.681 | Accuracy: 0.720703 | 0.049 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 365 | Batch: 000 / 040 | Total loss: 0.655 | Reg loss: 0.020 | Tree loss: 0.655 | Accuracy: 0.742188 | 0.049 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 366 | Batch: 000 / 040 | Total loss: 0.617 | Reg loss: 0.020 | Tree loss: 0.617 | Accuracy: 0.738281 | 0.049 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 392 | Batch: 000 / 040 | Total loss: 0.668 | Reg loss: 0.020 | Tree loss: 0.668 | Accuracy: 0.734375 | 0.05 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 393 | Batch: 000 / 040 | Total loss: 0.648 | Reg loss: 0.020 | Tree loss: 0.648 | Accuracy: 0.708984 | 0.05 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 394 | Batch: 000 / 040 | Total loss: 0.651 | Reg loss: 0.020 | Tree loss: 0.651 | Accuracy: 0.718750 | 0.05 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


In [53]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [54]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [55]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 4.928571428571429


# Extract Rules

# Accumulate samples in the leaves

In [56]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 14


In [57]:
method = 'greedy'

In [58]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [59]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

5002
3331
11581
138
403
Average comprehensibility: 25.428571428571427
std comprehensibility: 8.666143364630344
var comprehensibility: 75.10204081632654
minimum comprehensibility: 4
maximum comprehensibility: 32


  return np.log(1 / (1 - x))
