In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 8
batch_size = 512
device = 'cuda'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.338754177093506 | KNN Loss: 5.834799289703369 | CLS Loss: 1.5039548873901367
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.342403888702393 | KNN Loss: 4.256844997406006 | CLS Loss: 1.0855587720870972
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 4.604902744293213 | KNN Loss: 3.9738428592681885 | CLS Loss: 0.6310598850250244
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 4.551351070404053 | KNN Loss: 3.9501264095306396 | CLS Loss: 0.6012245416641235
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 4.384238243103027 | KNN Loss: 3.87967848777771 | CLS Loss: 0.5045595765113831
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 4.387600898742676 | KNN Loss: 3.8757076263427734 | CLS Loss: 0.5118935108184814
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.261226177215576 | KNN Loss: 3.863406181335449 | CLS Loss: 0.3978199064731598
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.3324408531188965 | KNN Loss: 3.8777098655700684 | CL

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 3.843834638595581 | KNN Loss: 3.724855661392212 | CLS Loss: 0.1189790666103363
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 3.889632225036621 | KNN Loss: 3.6959853172302246 | CLS Loss: 0.19364680349826813
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 3.847315549850464 | KNN Loss: 3.7167088985443115 | CLS Loss: 0.13060657680034637
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 3.943624496459961 | KNN Loss: 3.7300667762756348 | CLS Loss: 0.21355783939361572
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 3.8840489387512207 | KNN Loss: 3.7766077518463135 | CLS Loss: 0.1074412390589714
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 3.8714261054992676 | KNN Loss: 3.7358646392822266 | CLS Loss: 0.1355615109205246
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 3.9323315620422363 | KNN Loss: 3.7160022258758545 | CLS Loss: 0.21632935106754303
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 3.868086338043213 | KNN Loss: 3.7704

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 3.749708414077759 | KNN Loss: 3.7052414417266846 | CLS Loss: 0.04446694999933243
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 3.840590476989746 | KNN Loss: 3.7238759994506836 | CLS Loss: 0.11671452969312668
Epoch: 007, Loss: 3.8055, Train: 0.9751, Valid: 0.9724, Best: 0.9737
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 3.7702958583831787 | KNN Loss: 3.7018682956695557 | CLS Loss: 0.0684276595711708
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 3.813961982727051 | KNN Loss: 3.7236132621765137 | CLS Loss: 0.09034878015518188
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 3.771840810775757 | KNN Loss: 3.6930980682373047 | CLS Loss: 0.07874263823032379
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 3.795667886734009 | KNN Loss: 3.6814167499542236 | CLS Loss: 0.11425109207630157
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 3.7971737384796143 | KNN Loss: 3.695945978164673 | CLS Loss: 0.10122781991958618
Epoch 8 / 200 | iter

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 3.7585644721984863 | KNN Loss: 3.6319265365600586 | CLS Loss: 0.1266380250453949
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 3.759752035140991 | KNN Loss: 3.68269419670105 | CLS Loss: 0.077057845890522
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 3.824002265930176 | KNN Loss: 3.7147364616394043 | CLS Loss: 0.1092657595872879
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 3.772289276123047 | KNN Loss: 3.680137872695923 | CLS Loss: 0.09215141087770462
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 3.7551050186157227 | KNN Loss: 3.672175168991089 | CLS Loss: 0.08292978256940842
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 3.7618613243103027 | KNN Loss: 3.681504249572754 | CLS Loss: 0.08035717159509659
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 3.759122848510742 | KNN Loss: 3.623027801513672 | CLS Loss: 0.13609500229358673
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 3.7889795303344727 | KNN Loss: 3.6959

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 3.753809690475464 | KNN Loss: 3.6689274311065674 | CLS Loss: 0.0848822072148323
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 3.7419090270996094 | KNN Loss: 3.6707661151885986 | CLS Loss: 0.07114293426275253
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 3.7657008171081543 | KNN Loss: 3.708181142807007 | CLS Loss: 0.05751967802643776
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 3.7609260082244873 | KNN Loss: 3.6992526054382324 | CLS Loss: 0.061673350632190704
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 3.7297489643096924 | KNN Loss: 3.6697888374328613 | CLS Loss: 0.059960171580314636
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 3.776679277420044 | KNN Loss: 3.7186031341552734 | CLS Loss: 0.058076124638319016
Epoch: 014, Loss: 3.7413, Train: 0.9839, Valid: 0.9805, Best: 0.9806
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 3.7490642070770264 | KNN Loss: 3.6808879375457764 | CLS Loss: 0.06817632168531418
Epo

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 3.6982250213623047 | KNN Loss: 3.6511051654815674 | CLS Loss: 0.04711996018886566
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 3.7589612007141113 | KNN Loss: 3.6595091819763184 | CLS Loss: 0.09945190697908401
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 3.7329001426696777 | KNN Loss: 3.6983208656311035 | CLS Loss: 0.03457936644554138
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 3.7600438594818115 | KNN Loss: 3.6894490718841553 | CLS Loss: 0.07059486955404282
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 3.7165403366088867 | KNN Loss: 3.6628689765930176 | CLS Loss: 0.05367138236761093
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 3.7211008071899414 | KNN Loss: 3.6618824005126953 | CLS Loss: 0.05921830236911774
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 3.688033103942871 | KNN Loss: 3.6543681621551514 | CLS Loss: 0.03366489335894585
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 3.6989071369171143 | KNN Lo

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 3.7312204837799072 | KNN Loss: 3.6779186725616455 | CLS Loss: 0.05330183357000351
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 3.702188730239868 | KNN Loss: 3.6583821773529053 | CLS Loss: 0.04380656033754349
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 3.6925582885742188 | KNN Loss: 3.6545724868774414 | CLS Loss: 0.03798582777380943
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 3.6783676147460938 | KNN Loss: 3.662086248397827 | CLS Loss: 0.016281459480524063
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 3.6803810596466064 | KNN Loss: 3.65169620513916 | CLS Loss: 0.028684748336672783
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 3.6873257160186768 | KNN Loss: 3.6547253131866455 | CLS Loss: 0.03260049968957901
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 3.706716775894165 | KNN Loss: 3.6632566452026367 | CLS Loss: 0.04346016049385071
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 3.7160792350769043 | KN

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 3.7082808017730713 | KNN Loss: 3.677102565765381 | CLS Loss: 0.031178202480077744
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 3.707207679748535 | KNN Loss: 3.6545321941375732 | CLS Loss: 0.052675433456897736
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 3.671010971069336 | KNN Loss: 3.603168249130249 | CLS Loss: 0.06784262508153915
Epoch: 024, Loss: 3.6938, Train: 0.9883, Valid: 0.9824, Best: 0.9825
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 3.677830696105957 | KNN Loss: 3.6054742336273193 | CLS Loss: 0.07235658168792725
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 3.7444167137145996 | KNN Loss: 3.69122576713562 | CLS Loss: 0.053190965205430984
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 3.7102222442626953 | KNN Loss: 3.6525917053222656 | CLS Loss: 0.05763062834739685
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 3.710568428039551 | KNN Loss: 3.6350183486938477 | CLS Loss: 0.07554997503757477
Epoch 25 /

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 3.6742031574249268 | KNN Loss: 3.65909743309021 | CLS Loss: 0.015105719678103924
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 3.69950008392334 | KNN Loss: 3.6642954349517822 | CLS Loss: 0.03520473465323448
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 3.681307077407837 | KNN Loss: 3.657381534576416 | CLS Loss: 0.023925427347421646
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 3.6771240234375 | KNN Loss: 3.6539933681488037 | CLS Loss: 0.023130672052502632
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 3.686328411102295 | KNN Loss: 3.6419615745544434 | CLS Loss: 0.04436672478914261
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 3.6734156608581543 | KNN Loss: 3.6384098529815674 | CLS Loss: 0.03500586748123169
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 3.6955769062042236 | KNN Loss: 3.6351115703582764 | CLS Loss: 0.06046542152762413
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 3.6680774688720703 | KNN Loss: 

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 3.6898162364959717 | KNN Loss: 3.653327226638794 | CLS Loss: 0.036488939076662064
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 3.675814390182495 | KNN Loss: 3.630084753036499 | CLS Loss: 0.04572966694831848
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 3.6825597286224365 | KNN Loss: 3.6712446212768555 | CLS Loss: 0.011315049603581429
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 3.6810450553894043 | KNN Loss: 3.6249330043792725 | CLS Loss: 0.05611199885606766
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 3.6789464950561523 | KNN Loss: 3.654362201690674 | CLS Loss: 0.024584298953413963
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 3.716630697250366 | KNN Loss: 3.671271324157715 | CLS Loss: 0.04535936191678047
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 3.671799421310425 | KNN Loss: 3.6574580669403076 | CLS Loss: 0.014341408386826515
Epoch: 031, Loss: 3.6868, Train: 0.9889, Valid: 0.9824, Best: 0.9848
Epo

Epoch: 034, Loss: 3.6727, Train: 0.9908, Valid: 0.9837, Best: 0.9852
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 3.66903018951416 | KNN Loss: 3.6490793228149414 | CLS Loss: 0.019950777292251587
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 3.7062532901763916 | KNN Loss: 3.665992259979248 | CLS Loss: 0.040260955691337585
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 3.638197183609009 | KNN Loss: 3.615208625793457 | CLS Loss: 0.02298852987587452
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 3.691382646560669 | KNN Loss: 3.6681060791015625 | CLS Loss: 0.023276664316654205
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 3.6589534282684326 | KNN Loss: 3.6350064277648926 | CLS Loss: 0.023947114124894142
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 3.6544852256774902 | KNN Loss: 3.5937671661376953 | CLS Loss: 0.06071798503398895
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 3.687897205352783 | KNN Loss: 3.6643226146698 | CLS Loss: 0.023574480786919594
Epoch 35 / 20

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 3.6436147689819336 | KNN Loss: 3.6138155460357666 | CLS Loss: 0.02979922853410244
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 3.683253288269043 | KNN Loss: 3.6731791496276855 | CLS Loss: 0.010074188932776451
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 3.6697933673858643 | KNN Loss: 3.6583645343780518 | CLS Loss: 0.011428726837038994
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 3.6550073623657227 | KNN Loss: 3.6068708896636963 | CLS Loss: 0.04813650622963905
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 3.681947708129883 | KNN Loss: 3.653435468673706 | CLS Loss: 0.028512325137853622
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 3.7096524238586426 | KNN Loss: 3.6866910457611084 | CLS Loss: 0.022961309179663658
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 3.6898062229156494 | KNN Loss: 3.6535158157348633 | CLS Loss: 0.03629042208194733
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 3.657222032546997 | 

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 3.675118923187256 | KNN Loss: 3.6515920162200928 | CLS Loss: 0.02352684922516346
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 3.710148572921753 | KNN Loss: 3.6831130981445312 | CLS Loss: 0.02703559212386608
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 3.6759989261627197 | KNN Loss: 3.629619598388672 | CLS Loss: 0.04637935757637024
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 3.6849188804626465 | KNN Loss: 3.6351089477539062 | CLS Loss: 0.04980981722474098
Epoch: 041, Loss: 3.6709, Train: 0.9924, Valid: 0.9857, Best: 0.9857
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 3.695998191833496 | KNN Loss: 3.6695501804351807 | CLS Loss: 0.026448095217347145
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 3.6669323444366455 | KNN Loss: 3.6379847526550293 | CLS Loss: 0.02894764393568039
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 3.6435353755950928 | KNN Loss: 3.616520404815674 | CLS Loss: 0.027014879509806633
Epoch 4

Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 3.674020767211914 | KNN Loss: 3.6628055572509766 | CLS Loss: 0.011215245351195335
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 3.638474941253662 | KNN Loss: 3.6134133338928223 | CLS Loss: 0.025061555206775665
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 3.6613900661468506 | KNN Loss: 3.6379590034484863 | CLS Loss: 0.02343105338513851
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 3.6373536586761475 | KNN Loss: 3.607299327850342 | CLS Loss: 0.0300543662160635
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 3.6596665382385254 | KNN Loss: 3.6432206630706787 | CLS Loss: 0.016445988789200783
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 3.6829710006713867 | KNN Loss: 3.6282901763916016 | CLS Loss: 0.054680801928043365
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 3.650747060775757 | KNN Loss: 3.6377227306365967 | CLS Loss: 0.01302428636699915
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 3.6886324882507324 | KNN L

Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 3.6621346473693848 | KNN Loss: 3.635104179382324 | CLS Loss: 0.027030566707253456
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 3.6518406867980957 | KNN Loss: 3.6111538410186768 | CLS Loss: 0.040686801075935364
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 3.6415610313415527 | KNN Loss: 3.622943878173828 | CLS Loss: 0.018617264926433563
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 3.664109945297241 | KNN Loss: 3.6110951900482178 | CLS Loss: 0.053014837205410004
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 3.6380250453948975 | KNN Loss: 3.6337947845458984 | CLS Loss: 0.00423037214204669
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 3.6721742153167725 | KNN Loss: 3.6584389209747314 | CLS Loss: 0.01373538002371788
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 3.6348464488983154 | KNN Loss: 3.6189370155334473 | CLS Loss: 0.015909496694803238
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 3.70760989189147

Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 3.6501994132995605 | KNN Loss: 3.6316144466400146 | CLS Loss: 0.018585067242383957
Epoch: 051, Loss: 3.6710, Train: 0.9931, Valid: 0.9864, Best: 0.9866
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 3.643723249435425 | KNN Loss: 3.6131808757781982 | CLS Loss: 0.030542319640517235
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 3.6688201427459717 | KNN Loss: 3.632675886154175 | CLS Loss: 0.03614432364702225
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 3.660720109939575 | KNN Loss: 3.6442432403564453 | CLS Loss: 0.016476891934871674
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 3.6683483123779297 | KNN Loss: 3.6474688053131104 | CLS Loss: 0.020879516378045082
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 3.726088523864746 | KNN Loss: 3.683061122894287 | CLS Loss: 0.043027304112911224
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 3.6515748500823975 | KNN Loss: 3.6245107650756836 | CLS Loss: 0.02706417813897133
Epoch 5

Epoch 55 / 200 | iteration 50 / 171 | Total Loss: 3.6232404708862305 | KNN Loss: 3.6008400917053223 | CLS Loss: 0.022400496527552605
Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 3.67880916595459 | KNN Loss: 3.6511306762695312 | CLS Loss: 0.027678564190864563
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 3.6708579063415527 | KNN Loss: 3.6291024684906006 | CLS Loss: 0.041755400598049164
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 3.662759780883789 | KNN Loss: 3.643327474594116 | CLS Loss: 0.019432326778769493
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 3.719548225402832 | KNN Loss: 3.6568658351898193 | CLS Loss: 0.06268235296010971
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 3.6348793506622314 | KNN Loss: 3.616398334503174 | CLS Loss: 0.01848096214234829
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 3.664475202560425 | KNN Loss: 3.6424765586853027 | CLS Loss: 0.021998753771185875
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 3.6665384769439697 | KNN 

Epoch 58 / 200 | iteration 120 / 171 | Total Loss: 3.6733193397521973 | KNN Loss: 3.6568078994750977 | CLS Loss: 0.016511349007487297
Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 3.6703085899353027 | KNN Loss: 3.6463091373443604 | CLS Loss: 0.023999441415071487
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 3.7099149227142334 | KNN Loss: 3.674295663833618 | CLS Loss: 0.03561931848526001
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 3.6543710231781006 | KNN Loss: 3.639430522918701 | CLS Loss: 0.014940431341528893
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 3.637864351272583 | KNN Loss: 3.611377000808716 | CLS Loss: 0.02648734487593174
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 3.6436285972595215 | KNN Loss: 3.619504690170288 | CLS Loss: 0.02412392385303974
Epoch: 058, Loss: 3.6595, Train: 0.9936, Valid: 0.9860, Best: 0.9866
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 3.6654913425445557 | KNN Loss: 3.6352896690368652 | CLS Loss: 0.03020167350769043
Epoc

Epoch 62 / 200 | iteration 0 / 171 | Total Loss: 3.635624647140503 | KNN Loss: 3.6134488582611084 | CLS Loss: 0.02217579260468483
Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 3.6664607524871826 | KNN Loss: 3.6211302280426025 | CLS Loss: 0.04533042758703232
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 3.6346020698547363 | KNN Loss: 3.6162123680114746 | CLS Loss: 0.018389767035841942
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 3.657862663269043 | KNN Loss: 3.6349878311157227 | CLS Loss: 0.022874891757965088
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 3.674889087677002 | KNN Loss: 3.64983868598938 | CLS Loss: 0.025050325319170952
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 3.653926134109497 | KNN Loss: 3.614013671875 | CLS Loss: 0.03991256281733513
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 3.651498556137085 | KNN Loss: 3.5997302532196045 | CLS Loss: 0.05176827311515808
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 3.6652157306671143 | KNN Loss: 3.64

Epoch 65 / 200 | iteration 70 / 171 | Total Loss: 3.6487972736358643 | KNN Loss: 3.6241044998168945 | CLS Loss: 0.024692706763744354
Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 3.6401638984680176 | KNN Loss: 3.635779857635498 | CLS Loss: 0.004384030122309923
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 3.662536144256592 | KNN Loss: 3.646332025527954 | CLS Loss: 0.016204185783863068
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 3.675102710723877 | KNN Loss: 3.655867099761963 | CLS Loss: 0.019235599786043167
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 3.6989712715148926 | KNN Loss: 3.684954881668091 | CLS Loss: 0.014016365632414818
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 3.6574454307556152 | KNN Loss: 3.6245100498199463 | CLS Loss: 0.032935336232185364
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 3.692349910736084 | KNN Loss: 3.6638948917388916 | CLS Loss: 0.028454937040805817
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 3.6678965091705322 | 

Epoch 68 / 200 | iteration 140 / 171 | Total Loss: 3.6745409965515137 | KNN Loss: 3.646989583969116 | CLS Loss: 0.02755132131278515
Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 3.740340232849121 | KNN Loss: 3.6903610229492188 | CLS Loss: 0.049979183822870255
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 3.6219089031219482 | KNN Loss: 3.6049227714538574 | CLS Loss: 0.016986172646284103
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 3.6177713871002197 | KNN Loss: 3.5958950519561768 | CLS Loss: 0.02187643013894558
Epoch: 068, Loss: 3.6499, Train: 0.9946, Valid: 0.9869, Best: 0.9871
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 3.644516706466675 | KNN Loss: 3.6328940391540527 | CLS Loss: 0.011622578836977482
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 3.6498193740844727 | KNN Loss: 3.6263158321380615 | CLS Loss: 0.023503422737121582
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 3.661085844039917 | KNN Loss: 3.6435465812683105 | CLS Loss: 0.01753927581012249
Epoc

Epoch 72 / 200 | iteration 30 / 171 | Total Loss: 3.6593809127807617 | KNN Loss: 3.633129835128784 | CLS Loss: 0.02625107578933239
Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 3.6406376361846924 | KNN Loss: 3.6021127700805664 | CLS Loss: 0.038524918258190155
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 3.6645243167877197 | KNN Loss: 3.6356539726257324 | CLS Loss: 0.028870441019535065
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 3.6215262413024902 | KNN Loss: 3.6087350845336914 | CLS Loss: 0.012791207991540432
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 3.698610305786133 | KNN Loss: 3.669224500656128 | CLS Loss: 0.02938583306968212
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 3.6752636432647705 | KNN Loss: 3.6679985523223877 | CLS Loss: 0.007265185937285423
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 3.608842611312866 | KNN Loss: 3.585810422897339 | CLS Loss: 0.023032132536172867
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 3.6437437534332275 | KNN 

Epoch 75 / 200 | iteration 100 / 171 | Total Loss: 3.629589557647705 | KNN Loss: 3.620208740234375 | CLS Loss: 0.009380880743265152
Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 3.6835131645202637 | KNN Loss: 3.6481447219848633 | CLS Loss: 0.03536849468946457
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 3.6549484729766846 | KNN Loss: 3.6391243934631348 | CLS Loss: 0.01582399196922779
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 3.670414924621582 | KNN Loss: 3.6283726692199707 | CLS Loss: 0.042042337357997894
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 3.6243300437927246 | KNN Loss: 3.6124939918518066 | CLS Loss: 0.011835956946015358
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 3.66377592086792 | KNN Loss: 3.6246390342712402 | CLS Loss: 0.03913681209087372
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 3.6479074954986572 | KNN Loss: 3.6244730949401855 | CLS Loss: 0.023434314876794815
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 3.7073044776916504 

Epoch 78 / 200 | iteration 170 / 171 | Total Loss: 3.6806752681732178 | KNN Loss: 3.656628131866455 | CLS Loss: 0.02404705435037613
Epoch: 078, Loss: 3.6440, Train: 0.9943, Valid: 0.9852, Best: 0.9871
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 3.6794559955596924 | KNN Loss: 3.6505236625671387 | CLS Loss: 0.028932280838489532
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 3.6548264026641846 | KNN Loss: 3.630750894546509 | CLS Loss: 0.02407558634877205
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 3.6762869358062744 | KNN Loss: 3.6091275215148926 | CLS Loss: 0.06715936958789825
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 3.6240711212158203 | KNN Loss: 3.6055538654327393 | CLS Loss: 0.01851736381649971
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 3.6293253898620605 | KNN Loss: 3.615821361541748 | CLS Loss: 0.01350412517786026
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 3.6440300941467285 | KNN Loss: 3.640756368637085 | CLS Loss: 0.0032736710272729397
Epoch 79

Epoch 82 / 200 | iteration 50 / 171 | Total Loss: 3.6693270206451416 | KNN Loss: 3.64971661567688 | CLS Loss: 0.019610371440649033
Epoch 82 / 200 | iteration 60 / 171 | Total Loss: 3.680440902709961 | KNN Loss: 3.6684892177581787 | CLS Loss: 0.011951753869652748
Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 3.6179721355438232 | KNN Loss: 3.6015796661376953 | CLS Loss: 0.016392365097999573
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 3.64730167388916 | KNN Loss: 3.6203465461730957 | CLS Loss: 0.02695503830909729
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 3.6373231410980225 | KNN Loss: 3.616776466369629 | CLS Loss: 0.020546643063426018
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 3.655268430709839 | KNN Loss: 3.648651361465454 | CLS Loss: 0.006617000326514244
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 3.67612361907959 | KNN Loss: 3.6556308269500732 | CLS Loss: 0.020492851734161377
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 3.622727394104004 | KNN Los

Epoch 85 / 200 | iteration 120 / 171 | Total Loss: 3.710954427719116 | KNN Loss: 3.6796069145202637 | CLS Loss: 0.031347621232271194
Epoch 85 / 200 | iteration 130 / 171 | Total Loss: 3.6440553665161133 | KNN Loss: 3.63805890083313 | CLS Loss: 0.005996455438435078
Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 3.621293544769287 | KNN Loss: 3.6140012741088867 | CLS Loss: 0.007292275782674551
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 3.638615369796753 | KNN Loss: 3.624866247177124 | CLS Loss: 0.013749146834015846
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 3.6559500694274902 | KNN Loss: 3.633148193359375 | CLS Loss: 0.02280183508992195
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 3.66935133934021 | KNN Loss: 3.6571555137634277 | CLS Loss: 0.012195764109492302
Epoch: 085, Loss: 3.6441, Train: 0.9932, Valid: 0.9849, Best: 0.9873
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 3.624724864959717 | KNN Loss: 3.612138271331787 | CLS Loss: 0.012586616910994053
Epoch 8

Epoch 89 / 200 | iteration 10 / 171 | Total Loss: 3.6185052394866943 | KNN Loss: 3.606656551361084 | CLS Loss: 0.011848701164126396
Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 3.6562089920043945 | KNN Loss: 3.6352710723876953 | CLS Loss: 0.02093781717121601
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 3.7009100914001465 | KNN Loss: 3.6621413230895996 | CLS Loss: 0.038768887519836426
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 3.687605857849121 | KNN Loss: 3.667494773864746 | CLS Loss: 0.02011103183031082
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 3.594698905944824 | KNN Loss: 3.573939561843872 | CLS Loss: 0.02075928822159767
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 3.657963275909424 | KNN Loss: 3.635880947113037 | CLS Loss: 0.022082427516579628
Epoch 89 / 200 | iteration 70 / 171 | Total Loss: 3.621464252471924 | KNN Loss: 3.6160178184509277 | CLS Loss: 0.0054463171400129795
Epoch 89 / 200 | iteration 80 / 171 | Total Loss: 3.648895502090454 | KNN Loss:

Epoch 92 / 200 | iteration 80 / 171 | Total Loss: 3.6609861850738525 | KNN Loss: 3.633877754211426 | CLS Loss: 0.02710840106010437
Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 3.6494057178497314 | KNN Loss: 3.6150572299957275 | CLS Loss: 0.0343485027551651
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 3.6334824562072754 | KNN Loss: 3.6219983100891113 | CLS Loss: 0.011484050191938877
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 3.658475160598755 | KNN Loss: 3.627619504928589 | CLS Loss: 0.03085567243397236
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 3.6442413330078125 | KNN Loss: 3.625507116317749 | CLS Loss: 0.01873416267335415
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 3.648421049118042 | KNN Loss: 3.6352591514587402 | CLS Loss: 0.013161790557205677
Epoch 92 / 200 | iteration 140 / 171 | Total Loss: 3.6441667079925537 | KNN Loss: 3.634119749069214 | CLS Loss: 0.010047038085758686
Epoch 92 / 200 | iteration 150 / 171 | Total Loss: 3.66007137298584 | KNN L

Epoch 95 / 200 | iteration 150 / 171 | Total Loss: 3.6398766040802 | KNN Loss: 3.6279990673065186 | CLS Loss: 0.011877420358359814
Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 3.6895132064819336 | KNN Loss: 3.673330068588257 | CLS Loss: 0.01618315279483795
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 3.6631927490234375 | KNN Loss: 3.6479098796844482 | CLS Loss: 0.015282964333891869
Epoch: 095, Loss: 3.6446, Train: 0.9958, Valid: 0.9865, Best: 0.9875
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 3.6321237087249756 | KNN Loss: 3.6211249828338623 | CLS Loss: 0.010998680256307125
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 3.65777587890625 | KNN Loss: 3.6440787315368652 | CLS Loss: 0.013697213493287563
Epoch 96 / 200 | iteration 20 / 171 | Total Loss: 3.624239444732666 | KNN Loss: 3.6049747467041016 | CLS Loss: 0.019264714792370796
Epoch 96 / 200 | iteration 30 / 171 | Total Loss: 3.6720454692840576 | KNN Loss: 3.6416194438934326 | CLS Loss: 0.03042605333030224
Epoch 9

Epoch 99 / 200 | iteration 30 / 171 | Total Loss: 3.590543746948242 | KNN Loss: 3.574340581893921 | CLS Loss: 0.016203269362449646
Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 3.664567708969116 | KNN Loss: 3.6438872814178467 | CLS Loss: 0.02068045176565647
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 3.652207851409912 | KNN Loss: 3.617781162261963 | CLS Loss: 0.0344267301261425
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 3.632126808166504 | KNN Loss: 3.613577127456665 | CLS Loss: 0.018549608066678047
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 3.626056671142578 | KNN Loss: 3.6199591159820557 | CLS Loss: 0.006097497418522835
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 3.6516902446746826 | KNN Loss: 3.6312856674194336 | CLS Loss: 0.020404508337378502
Epoch 99 / 200 | iteration 90 / 171 | Total Loss: 3.620685338973999 | KNN Loss: 3.6141581535339355 | CLS Loss: 0.006527135148644447
Epoch 99 / 200 | iteration 100 / 171 | Total Loss: 3.6380841732025146 | KNN Loss:

Epoch 102 / 200 | iteration 100 / 171 | Total Loss: 3.598004102706909 | KNN Loss: 3.5810935497283936 | CLS Loss: 0.016910653561353683
Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 3.6251959800720215 | KNN Loss: 3.615276336669922 | CLS Loss: 0.009919758886098862
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 3.6302168369293213 | KNN Loss: 3.6213455200195312 | CLS Loss: 0.008871311321854591
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 3.6304023265838623 | KNN Loss: 3.615344524383545 | CLS Loss: 0.015057907439768314
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 3.61515736579895 | KNN Loss: 3.6048765182495117 | CLS Loss: 0.01028084009885788
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 3.663729667663574 | KNN Loss: 3.6584222316741943 | CLS Loss: 0.005307443905621767
Epoch 102 / 200 | iteration 160 / 171 | Total Loss: 3.6596312522888184 | KNN Loss: 3.657477855682373 | CLS Loss: 0.0021534189581871033
Epoch 102 / 200 | iteration 170 / 171 | Total Loss: 3.59442663

Epoch 105 / 200 | iteration 160 / 171 | Total Loss: 3.687199354171753 | KNN Loss: 3.673489570617676 | CLS Loss: 0.013709687627851963
Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 3.6546335220336914 | KNN Loss: 3.6253554821014404 | CLS Loss: 0.029278073459863663
Epoch: 105, Loss: 3.6456, Train: 0.9939, Valid: 0.9849, Best: 0.9876
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 3.595285177230835 | KNN Loss: 3.587592601776123 | CLS Loss: 0.007692515384405851
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 3.605112075805664 | KNN Loss: 3.594559907913208 | CLS Loss: 0.010552093386650085
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 3.6308014392852783 | KNN Loss: 3.6164045333862305 | CLS Loss: 0.014397014863789082
Epoch 106 / 200 | iteration 30 / 171 | Total Loss: 3.6385042667388916 | KNN Loss: 3.622450828552246 | CLS Loss: 0.016053510829806328
Epoch 106 / 200 | iteration 40 / 171 | Total Loss: 3.634216070175171 | KNN Loss: 3.624659776687622 | CLS Loss: 0.009556321427226067
E

Epoch 109 / 200 | iteration 40 / 171 | Total Loss: 3.603267192840576 | KNN Loss: 3.59265398979187 | CLS Loss: 0.01061319001019001
Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 3.6370372772216797 | KNN Loss: 3.6281139850616455 | CLS Loss: 0.008923223242163658
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 3.6509816646575928 | KNN Loss: 3.6451058387756348 | CLS Loss: 0.005875799804925919
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 3.6125121116638184 | KNN Loss: 3.605048894882202 | CLS Loss: 0.007463289424777031
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 3.6549787521362305 | KNN Loss: 3.6410744190216064 | CLS Loss: 0.013904359191656113
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 3.6227829456329346 | KNN Loss: 3.6138699054718018 | CLS Loss: 0.008913114666938782
Epoch 109 / 200 | iteration 100 / 171 | Total Loss: 3.6348259449005127 | KNN Loss: 3.624406337738037 | CLS Loss: 0.010419672355055809
Epoch 109 / 200 | iteration 110 / 171 | Total Loss: 3.7000041007995

Epoch 112 / 200 | iteration 100 / 171 | Total Loss: 3.639369249343872 | KNN Loss: 3.6178033351898193 | CLS Loss: 0.0215659961104393
Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 3.6339588165283203 | KNN Loss: 3.6301848888397217 | CLS Loss: 0.003773984033614397
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 3.6596360206604004 | KNN Loss: 3.6250369548797607 | CLS Loss: 0.03459900617599487
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 3.6661224365234375 | KNN Loss: 3.64313006401062 | CLS Loss: 0.022992491722106934
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 3.6625897884368896 | KNN Loss: 3.631444215774536 | CLS Loss: 0.03114561177790165
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 3.6347053050994873 | KNN Loss: 3.6217122077941895 | CLS Loss: 0.012993120588362217
Epoch 112 / 200 | iteration 160 / 171 | Total Loss: 3.617220640182495 | KNN Loss: 3.611640691757202 | CLS Loss: 0.005579921416938305
Epoch 112 / 200 | iteration 170 / 171 | Total Loss: 3.64120268821

Epoch 115 / 200 | iteration 160 / 171 | Total Loss: 3.6638641357421875 | KNN Loss: 3.634939432144165 | CLS Loss: 0.028924688696861267
Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 3.625331401824951 | KNN Loss: 3.614860773086548 | CLS Loss: 0.010470605455338955
Epoch: 115, Loss: 3.6360, Train: 0.9949, Valid: 0.9855, Best: 0.9876
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 3.6388189792633057 | KNN Loss: 3.627091884613037 | CLS Loss: 0.01172720268368721
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 3.6111295223236084 | KNN Loss: 3.5981829166412354 | CLS Loss: 0.012946641072630882
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 3.620400905609131 | KNN Loss: 3.607592821121216 | CLS Loss: 0.012808116152882576
Epoch 116 / 200 | iteration 30 / 171 | Total Loss: 3.593369722366333 | KNN Loss: 3.575413942337036 | CLS Loss: 0.01795574650168419
Epoch 116 / 200 | iteration 40 / 171 | Total Loss: 3.632918119430542 | KNN Loss: 3.618520975112915 | CLS Loss: 0.014397196471691132
Epoc

Epoch 119 / 200 | iteration 40 / 171 | Total Loss: 3.6227574348449707 | KNN Loss: 3.618394136428833 | CLS Loss: 0.00436341343447566
Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 3.623582363128662 | KNN Loss: 3.6167571544647217 | CLS Loss: 0.006825141608715057
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 3.6019954681396484 | KNN Loss: 3.5815632343292236 | CLS Loss: 0.02043216861784458
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 3.642270088195801 | KNN Loss: 3.628070592880249 | CLS Loss: 0.014199549332261086
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 3.601231098175049 | KNN Loss: 3.5870089530944824 | CLS Loss: 0.014222201891243458
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 3.6364803314208984 | KNN Loss: 3.626359701156616 | CLS Loss: 0.010120580904185772
Epoch 119 / 200 | iteration 100 / 171 | Total Loss: 3.6075737476348877 | KNN Loss: 3.594346284866333 | CLS Loss: 0.013227449730038643
Epoch 119 / 200 | iteration 110 / 171 | Total Loss: 3.6275248527526855

Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 3.64738130569458 | KNN Loss: 3.6198973655700684 | CLS Loss: 0.02748394012451172
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 3.5880374908447266 | KNN Loss: 3.5774731636047363 | CLS Loss: 0.01056424155831337
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 3.676305055618286 | KNN Loss: 3.641549825668335 | CLS Loss: 0.03475524112582207
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 3.6424403190612793 | KNN Loss: 3.6222097873687744 | CLS Loss: 0.02023046277463436
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 3.6325597763061523 | KNN Loss: 3.621894598007202 | CLS Loss: 0.010665112175047398
Epoch 122 / 200 | iteration 160 / 171 | Total Loss: 3.602773666381836 | KNN Loss: 3.5993573665618896 | CLS Loss: 0.003416236490011215
Epoch 122 / 200 | iteration 170 / 171 | Total Loss: 3.6067333221435547 | KNN Loss: 3.582409620285034 | CLS Loss: 0.02432367391884327
Epoch: 122, Loss: 3.6387, Train: 0.9972, Valid: 0.9868, Best: 0.987

Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 3.6330528259277344 | KNN Loss: 3.611884593963623 | CLS Loss: 0.02116812951862812
Epoch: 125, Loss: 3.6303, Train: 0.9971, Valid: 0.9872, Best: 0.9876
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 3.6070337295532227 | KNN Loss: 3.6008448600769043 | CLS Loss: 0.006188869010657072
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 3.612562417984009 | KNN Loss: 3.6066770553588867 | CLS Loss: 0.005885340739041567
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 3.6351852416992188 | KNN Loss: 3.61542010307312 | CLS Loss: 0.019765105098485947
Epoch 126 / 200 | iteration 30 / 171 | Total Loss: 3.6569817066192627 | KNN Loss: 3.641549825668335 | CLS Loss: 0.015431815758347511
Epoch 126 / 200 | iteration 40 / 171 | Total Loss: 3.6285853385925293 | KNN Loss: 3.6270718574523926 | CLS Loss: 0.0015135523863136768
Epoch 126 / 200 | iteration 50 / 171 | Total Loss: 3.6108267307281494 | KNN Loss: 3.6027166843414307 | CLS Loss: 0.00810998864471912

Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 3.5954298973083496 | KNN Loss: 3.5792620182037354 | CLS Loss: 0.01616794802248478
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 3.6653692722320557 | KNN Loss: 3.660672187805176 | CLS Loss: 0.004697154741734266
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 3.63822603225708 | KNN Loss: 3.63690185546875 | CLS Loss: 0.001324193668551743
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 3.617915391921997 | KNN Loss: 3.600184679031372 | CLS Loss: 0.017730824649333954
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 3.6822149753570557 | KNN Loss: 3.65476655960083 | CLS Loss: 0.02744852751493454
Epoch 129 / 200 | iteration 100 / 171 | Total Loss: 3.638619899749756 | KNN Loss: 3.633920907974243 | CLS Loss: 0.0046988981775939465
Epoch 129 / 200 | iteration 110 / 171 | Total Loss: 3.681699275970459 | KNN Loss: 3.6619575023651123 | CLS Loss: 0.019741810858249664
Epoch 129 / 200 | iteration 120 / 171 | Total Loss: 3.6591529846191406 | 

Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 3.652517795562744 | KNN Loss: 3.620737314224243 | CLS Loss: 0.0317804254591465
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 3.640562057495117 | KNN Loss: 3.62310791015625 | CLS Loss: 0.017454082146286964
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 3.6282551288604736 | KNN Loss: 3.6213486194610596 | CLS Loss: 0.006906555965542793
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 3.6288001537323 | KNN Loss: 3.6045637130737305 | CLS Loss: 0.024236558005213737
Epoch 132 / 200 | iteration 160 / 171 | Total Loss: 3.6037237644195557 | KNN Loss: 3.599393129348755 | CLS Loss: 0.0043307337909936905
Epoch 132 / 200 | iteration 170 / 171 | Total Loss: 3.626636266708374 | KNN Loss: 3.608339786529541 | CLS Loss: 0.018296506255865097
Epoch: 132, Loss: 3.6245, Train: 0.9964, Valid: 0.9861, Best: 0.9876
Epoch 133 / 200 | iteration 0 / 171 | Total Loss: 3.626883029937744 | KNN Loss: 3.624532461166382 | CLS Loss: 0.0023506709840148687
E

Epoch: 135, Loss: 3.6318, Train: 0.9968, Valid: 0.9855, Best: 0.9876
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 3.651228427886963 | KNN Loss: 3.6492791175842285 | CLS Loss: 0.001949426019564271
Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 3.621546745300293 | KNN Loss: 3.604398250579834 | CLS Loss: 0.017148494720458984
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 3.6652674674987793 | KNN Loss: 3.6407558917999268 | CLS Loss: 0.02451159618794918
Epoch 136 / 200 | iteration 30 / 171 | Total Loss: 3.6104776859283447 | KNN Loss: 3.5993235111236572 | CLS Loss: 0.011154129169881344
Epoch 136 / 200 | iteration 40 / 171 | Total Loss: 3.620788335800171 | KNN Loss: 3.6153926849365234 | CLS Loss: 0.005395571701228619
Epoch 136 / 200 | iteration 50 / 171 | Total Loss: 3.6073827743530273 | KNN Loss: 3.584610939025879 | CLS Loss: 0.022771866992115974
Epoch 136 / 200 | iteration 60 / 171 | Total Loss: 3.6058828830718994 | KNN Loss: 3.5982143878936768 | CLS Loss: 0.007668403908610344


Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 3.6040139198303223 | KNN Loss: 3.5955166816711426 | CLS Loss: 0.008497237227857113
Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 3.6183176040649414 | KNN Loss: 3.609783411026001 | CLS Loss: 0.008534112013876438
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 3.6277382373809814 | KNN Loss: 3.622814893722534 | CLS Loss: 0.00492327893152833
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 3.679542303085327 | KNN Loss: 3.6577162742614746 | CLS Loss: 0.02182612754404545
Epoch 139 / 200 | iteration 100 / 171 | Total Loss: 3.6138951778411865 | KNN Loss: 3.5937552452087402 | CLS Loss: 0.020140044391155243
Epoch 139 / 200 | iteration 110 / 171 | Total Loss: 3.6832289695739746 | KNN Loss: 3.6762146949768066 | CLS Loss: 0.007014315575361252
Epoch 139 / 200 | iteration 120 / 171 | Total Loss: 3.612980365753174 | KNN Loss: 3.600708484649658 | CLS Loss: 0.012271893210709095
Epoch 139 / 200 | iteration 130 / 171 | Total Loss: 3.651127338409

Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 3.629305839538574 | KNN Loss: 3.6237807273864746 | CLS Loss: 0.005525009706616402
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 3.602107524871826 | KNN Loss: 3.5902013778686523 | CLS Loss: 0.011906190775334835
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 3.6525707244873047 | KNN Loss: 3.6429953575134277 | CLS Loss: 0.009575475007295609
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 3.6267483234405518 | KNN Loss: 3.6000771522521973 | CLS Loss: 0.02667112462222576
Epoch 142 / 200 | iteration 160 / 171 | Total Loss: 3.609285354614258 | KNN Loss: 3.596386432647705 | CLS Loss: 0.012899026274681091
Epoch 142 / 200 | iteration 170 / 171 | Total Loss: 3.6330983638763428 | KNN Loss: 3.6079976558685303 | CLS Loss: 0.025100799277424812
Epoch: 142, Loss: 3.6361, Train: 0.9966, Valid: 0.9856, Best: 0.9876
Epoch 143 / 200 | iteration 0 / 171 | Total Loss: 3.5815141201019287 | KNN Loss: 3.5699427127838135 | CLS Loss: 0.011571464128

Epoch: 145, Loss: 3.6263, Train: 0.9961, Valid: 0.9857, Best: 0.9876
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 3.630467176437378 | KNN Loss: 3.6008429527282715 | CLS Loss: 0.029624175280332565
Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 3.669983386993408 | KNN Loss: 3.662234306335449 | CLS Loss: 0.0077491640113294125
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 3.641341209411621 | KNN Loss: 3.6309127807617188 | CLS Loss: 0.010428446345031261
Epoch 146 / 200 | iteration 30 / 171 | Total Loss: 3.6484375 | KNN Loss: 3.643630027770996 | CLS Loss: 0.004807477351278067
Epoch 146 / 200 | iteration 40 / 171 | Total Loss: 3.574648857116699 | KNN Loss: 3.569812059402466 | CLS Loss: 0.00483689084649086
Epoch 146 / 200 | iteration 50 / 171 | Total Loss: 3.6487419605255127 | KNN Loss: 3.6372599601745605 | CLS Loss: 0.011482003144919872
Epoch 146 / 200 | iteration 60 / 171 | Total Loss: 3.6397705078125 | KNN Loss: 3.634249210357666 | CLS Loss: 0.005521188955754042
Epoch 146 / 20

Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 3.631011724472046 | KNN Loss: 3.6190078258514404 | CLS Loss: 0.01200378593057394
Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 3.628061294555664 | KNN Loss: 3.625187397003174 | CLS Loss: 0.002873779507353902
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 3.6225202083587646 | KNN Loss: 3.5924363136291504 | CLS Loss: 0.03008384257555008
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 3.6766152381896973 | KNN Loss: 3.6687607765197754 | CLS Loss: 0.00785440020263195
Epoch 149 / 200 | iteration 100 / 171 | Total Loss: 3.6344192028045654 | KNN Loss: 3.621608018875122 | CLS Loss: 0.012811115942895412
Epoch 149 / 200 | iteration 110 / 171 | Total Loss: 3.65803861618042 | KNN Loss: 3.647348642349243 | CLS Loss: 0.010689914226531982
Epoch 149 / 200 | iteration 120 / 171 | Total Loss: 3.686155319213867 | KNN Loss: 3.6373753547668457 | CLS Loss: 0.048779960721731186
Epoch 149 / 200 | iteration 130 / 171 | Total Loss: 3.640346050262451 

Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 3.648953676223755 | KNN Loss: 3.6344218254089355 | CLS Loss: 0.014531935565173626
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 3.6194186210632324 | KNN Loss: 3.598566770553589 | CLS Loss: 0.02085183933377266
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 3.588383197784424 | KNN Loss: 3.577138662338257 | CLS Loss: 0.01124450284987688
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 3.650620460510254 | KNN Loss: 3.644127368927002 | CLS Loss: 0.006493148393929005
Epoch 152 / 200 | iteration 160 / 171 | Total Loss: 3.6075587272644043 | KNN Loss: 3.5925955772399902 | CLS Loss: 0.014963115565478802
Epoch 152 / 200 | iteration 170 / 171 | Total Loss: 3.636598587036133 | KNN Loss: 3.6303563117980957 | CLS Loss: 0.00624237023293972
Epoch: 152, Loss: 3.6270, Train: 0.9962, Valid: 0.9873, Best: 0.9876
Epoch 153 / 200 | iteration 0 / 171 | Total Loss: 3.6427061557769775 | KNN Loss: 3.6322102546691895 | CLS Loss: 0.01049584057182073

Epoch: 155, Loss: 3.6246, Train: 0.9973, Valid: 0.9871, Best: 0.9876
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 3.58880352973938 | KNN Loss: 3.5806405544281006 | CLS Loss: 0.00816306658089161
Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 3.617133855819702 | KNN Loss: 3.6070237159729004 | CLS Loss: 0.010110078379511833
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 3.6441988945007324 | KNN Loss: 3.634873867034912 | CLS Loss: 0.009325118735432625
Epoch 156 / 200 | iteration 30 / 171 | Total Loss: 3.6147701740264893 | KNN Loss: 3.612868547439575 | CLS Loss: 0.0019017342710867524
Epoch 156 / 200 | iteration 40 / 171 | Total Loss: 3.5938832759857178 | KNN Loss: 3.590275526046753 | CLS Loss: 0.003607644699513912
Epoch 156 / 200 | iteration 50 / 171 | Total Loss: 3.5858376026153564 | KNN Loss: 3.579669713973999 | CLS Loss: 0.0061679743230342865
Epoch 156 / 200 | iteration 60 / 171 | Total Loss: 3.6042098999023438 | KNN Loss: 3.5838799476623535 | CLS Loss: 0.020329907536506653


Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 3.646608829498291 | KNN Loss: 3.640458106994629 | CLS Loss: 0.006150651257485151
Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 3.6408417224884033 | KNN Loss: 3.6133334636688232 | CLS Loss: 0.027508312836289406
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 3.6237363815307617 | KNN Loss: 3.5896825790405273 | CLS Loss: 0.034053727984428406
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 3.622403144836426 | KNN Loss: 3.6150670051574707 | CLS Loss: 0.007336210925132036
Epoch 159 / 200 | iteration 100 / 171 | Total Loss: 3.598031997680664 | KNN Loss: 3.5888750553131104 | CLS Loss: 0.00915694609284401
Epoch 159 / 200 | iteration 110 / 171 | Total Loss: 3.6189346313476562 | KNN Loss: 3.6119136810302734 | CLS Loss: 0.007020971737802029
Epoch 159 / 200 | iteration 120 / 171 | Total Loss: 3.583162546157837 | KNN Loss: 3.576295852661133 | CLS Loss: 0.006866663694381714
Epoch 159 / 200 | iteration 130 / 171 | Total Loss: 3.606523752212

Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 3.654322385787964 | KNN Loss: 3.631718158721924 | CLS Loss: 0.02260415069758892
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 3.6650290489196777 | KNN Loss: 3.651296615600586 | CLS Loss: 0.013732338324189186
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 3.612222671508789 | KNN Loss: 3.600280284881592 | CLS Loss: 0.011942435055971146
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 3.6138627529144287 | KNN Loss: 3.6079585552215576 | CLS Loss: 0.005904106423258781
Epoch 162 / 200 | iteration 160 / 171 | Total Loss: 3.6468663215637207 | KNN Loss: 3.644155263900757 | CLS Loss: 0.002710946137085557
Epoch 162 / 200 | iteration 170 / 171 | Total Loss: 3.7044856548309326 | KNN Loss: 3.6938228607177734 | CLS Loss: 0.010662887245416641
Epoch: 162, Loss: 3.6274, Train: 0.9971, Valid: 0.9871, Best: 0.9876
Epoch 163 / 200 | iteration 0 / 171 | Total Loss: 3.582772731781006 | KNN Loss: 3.575026750564575 | CLS Loss: 0.0077460668981075

Epoch: 165, Loss: 3.6256, Train: 0.9963, Valid: 0.9856, Best: 0.9876
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 3.64422607421875 | KNN Loss: 3.634563684463501 | CLS Loss: 0.009662310592830181
Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 3.6277756690979004 | KNN Loss: 3.624656915664673 | CLS Loss: 0.003118707099929452
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 3.5975778102874756 | KNN Loss: 3.5925686359405518 | CLS Loss: 0.0050091990269720554
Epoch 166 / 200 | iteration 30 / 171 | Total Loss: 3.624056816101074 | KNN Loss: 3.6207542419433594 | CLS Loss: 0.0033026766031980515
Epoch 166 / 200 | iteration 40 / 171 | Total Loss: 3.600933074951172 | KNN Loss: 3.5862720012664795 | CLS Loss: 0.014661001041531563
Epoch 166 / 200 | iteration 50 / 171 | Total Loss: 3.6075472831726074 | KNN Loss: 3.604901075363159 | CLS Loss: 0.0026461449451744556
Epoch 166 / 200 | iteration 60 / 171 | Total Loss: 3.6233067512512207 | KNN Loss: 3.619687795639038 | CLS Loss: 0.003618944669142365

Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 3.59635591506958 | KNN Loss: 3.58709979057312 | CLS Loss: 0.009256059303879738
Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 3.6413075923919678 | KNN Loss: 3.630943775177002 | CLS Loss: 0.010363870300352573
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 3.664003849029541 | KNN Loss: 3.6574158668518066 | CLS Loss: 0.006588030606508255
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 3.619335651397705 | KNN Loss: 3.6113383769989014 | CLS Loss: 0.007997189648449421
Epoch 169 / 200 | iteration 100 / 171 | Total Loss: 3.6831157207489014 | KNN Loss: 3.6601810455322266 | CLS Loss: 0.02293458580970764
Epoch 169 / 200 | iteration 110 / 171 | Total Loss: 3.6903533935546875 | KNN Loss: 3.6694326400756836 | CLS Loss: 0.020920682698488235
Epoch 169 / 200 | iteration 120 / 171 | Total Loss: 3.6932919025421143 | KNN Loss: 3.687347173690796 | CLS Loss: 0.005944840610027313
Epoch 169 / 200 | iteration 130 / 171 | Total Loss: 3.61713337898254

Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 3.580548048019409 | KNN Loss: 3.57124662399292 | CLS Loss: 0.009301463142037392
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 3.6521294116973877 | KNN Loss: 3.649071455001831 | CLS Loss: 0.0030578547157347202
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 3.655874252319336 | KNN Loss: 3.645141363143921 | CLS Loss: 0.010732807219028473
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 3.622750997543335 | KNN Loss: 3.597238540649414 | CLS Loss: 0.025512436404824257
Epoch 172 / 200 | iteration 160 / 171 | Total Loss: 3.5924313068389893 | KNN Loss: 3.5897388458251953 | CLS Loss: 0.0026924856938421726
Epoch 172 / 200 | iteration 170 / 171 | Total Loss: 3.5869181156158447 | KNN Loss: 3.577371597290039 | CLS Loss: 0.009546566754579544
Epoch: 172, Loss: 3.6239, Train: 0.9977, Valid: 0.9870, Best: 0.9876
Epoch 173 / 200 | iteration 0 / 171 | Total Loss: 3.5847556591033936 | KNN Loss: 3.5813450813293457 | CLS Loss: 0.00341056287288

Epoch: 175, Loss: 3.6220, Train: 0.9971, Valid: 0.9859, Best: 0.9876
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 3.6143691539764404 | KNN Loss: 3.6049280166625977 | CLS Loss: 0.009441026486456394
Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 3.605889320373535 | KNN Loss: 3.597479820251465 | CLS Loss: 0.008409553207457066
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 3.615475654602051 | KNN Loss: 3.6118106842041016 | CLS Loss: 0.003664921037852764
Epoch 176 / 200 | iteration 30 / 171 | Total Loss: 3.632847309112549 | KNN Loss: 3.6202259063720703 | CLS Loss: 0.012621358968317509
Epoch 176 / 200 | iteration 40 / 171 | Total Loss: 3.5956640243530273 | KNN Loss: 3.585963487625122 | CLS Loss: 0.009700545109808445
Epoch 176 / 200 | iteration 50 / 171 | Total Loss: 3.6137568950653076 | KNN Loss: 3.599789619445801 | CLS Loss: 0.013967392034828663
Epoch 176 / 200 | iteration 60 / 171 | Total Loss: 3.607226848602295 | KNN Loss: 3.60377836227417 | CLS Loss: 0.0034485303331166506
Ep

Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 3.652689218521118 | KNN Loss: 3.627095937728882 | CLS Loss: 0.025593332946300507
Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 3.6557202339172363 | KNN Loss: 3.6421141624450684 | CLS Loss: 0.01360598485916853
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 3.61965274810791 | KNN Loss: 3.614889144897461 | CLS Loss: 0.0047636511735618114
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 3.6102490425109863 | KNN Loss: 3.606499433517456 | CLS Loss: 0.0037496844306588173
Epoch 179 / 200 | iteration 100 / 171 | Total Loss: 3.626084327697754 | KNN Loss: 3.6149988174438477 | CLS Loss: 0.011085602454841137
Epoch 179 / 200 | iteration 110 / 171 | Total Loss: 3.698547601699829 | KNN Loss: 3.6567821502685547 | CLS Loss: 0.04176536574959755
Epoch 179 / 200 | iteration 120 / 171 | Total Loss: 3.6583595275878906 | KNN Loss: 3.6436009407043457 | CLS Loss: 0.014758583158254623
Epoch 179 / 200 | iteration 130 / 171 | Total Loss: 3.6692814826965

Epoch 182 / 200 | iteration 120 / 171 | Total Loss: 3.6429431438446045 | KNN Loss: 3.6204166412353516 | CLS Loss: 0.02252655103802681
Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 3.639878749847412 | KNN Loss: 3.6379127502441406 | CLS Loss: 0.001965959556400776
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 3.5996601581573486 | KNN Loss: 3.5913119316101074 | CLS Loss: 0.00834817998111248
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 3.6642725467681885 | KNN Loss: 3.6548068523406982 | CLS Loss: 0.00946575403213501
Epoch 182 / 200 | iteration 160 / 171 | Total Loss: 3.6733999252319336 | KNN Loss: 3.6675338745117188 | CLS Loss: 0.005866130348294973
Epoch 182 / 200 | iteration 170 / 171 | Total Loss: 3.634690523147583 | KNN Loss: 3.6240346431732178 | CLS Loss: 0.010655797086656094
Epoch: 182, Loss: 3.6295, Train: 0.9956, Valid: 0.9859, Best: 0.9876
Epoch 183 / 200 | iteration 0 / 171 | Total Loss: 3.628014087677002 | KNN Loss: 3.602497100830078 | CLS Loss: 0.02551709488034

Epoch: 185, Loss: 3.6264, Train: 0.9976, Valid: 0.9869, Best: 0.9876
Epoch 186 / 200 | iteration 0 / 171 | Total Loss: 3.5996034145355225 | KNN Loss: 3.5942909717559814 | CLS Loss: 0.0053123896941542625
Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 3.618114471435547 | KNN Loss: 3.615853786468506 | CLS Loss: 0.0022606896236538887
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 3.621063470840454 | KNN Loss: 3.613410234451294 | CLS Loss: 0.007653196342289448
Epoch 186 / 200 | iteration 30 / 171 | Total Loss: 3.5906076431274414 | KNN Loss: 3.5887577533721924 | CLS Loss: 0.0018499090801924467
Epoch 186 / 200 | iteration 40 / 171 | Total Loss: 3.614187002182007 | KNN Loss: 3.6103265285491943 | CLS Loss: 0.0038604652509093285
Epoch 186 / 200 | iteration 50 / 171 | Total Loss: 3.6237847805023193 | KNN Loss: 3.610759973526001 | CLS Loss: 0.013024895451962948
Epoch 186 / 200 | iteration 60 / 171 | Total Loss: 3.643181324005127 | KNN Loss: 3.6329989433288574 | CLS Loss: 0.0101824980229139

Epoch 189 / 200 | iteration 60 / 171 | Total Loss: 3.592528820037842 | KNN Loss: 3.585294485092163 | CLS Loss: 0.007234451826661825
Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 3.591280937194824 | KNN Loss: 3.5898499488830566 | CLS Loss: 0.0014308879617601633
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 3.579073429107666 | KNN Loss: 3.5757646560668945 | CLS Loss: 0.0033087963238358498
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 3.6223948001861572 | KNN Loss: 3.6163384914398193 | CLS Loss: 0.0060562132857739925
Epoch 189 / 200 | iteration 100 / 171 | Total Loss: 3.6112325191497803 | KNN Loss: 3.6046500205993652 | CLS Loss: 0.006582415662705898
Epoch 189 / 200 | iteration 110 / 171 | Total Loss: 3.6086058616638184 | KNN Loss: 3.594255208969116 | CLS Loss: 0.014350628480315208
Epoch 189 / 200 | iteration 120 / 171 | Total Loss: 3.5999295711517334 | KNN Loss: 3.596799612045288 | CLS Loss: 0.003129939315840602
Epoch 189 / 200 | iteration 130 / 171 | Total Loss: 3.59919261

Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 3.637855052947998 | KNN Loss: 3.6268444061279297 | CLS Loss: 0.01101070735603571
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 3.6233179569244385 | KNN Loss: 3.614175319671631 | CLS Loss: 0.009142590686678886
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 3.5889978408813477 | KNN Loss: 3.583848714828491 | CLS Loss: 0.005149089731276035
Epoch 192 / 200 | iteration 160 / 171 | Total Loss: 3.6393251419067383 | KNN Loss: 3.6283762454986572 | CLS Loss: 0.010948811657726765
Epoch 192 / 200 | iteration 170 / 171 | Total Loss: 3.6060292720794678 | KNN Loss: 3.5987157821655273 | CLS Loss: 0.007313478272408247
Epoch: 192, Loss: 3.6251, Train: 0.9975, Valid: 0.9870, Best: 0.9876
Epoch 193 / 200 | iteration 0 / 171 | Total Loss: 3.5758273601531982 | KNN Loss: 3.570014476776123 | CLS Loss: 0.005812915973365307
Epoch 193 / 200 | iteration 10 / 171 | Total Loss: 3.5749545097351074 | KNN Loss: 3.5733964443206787 | CLS Loss: 0.0015581520274

Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 3.6029610633850098 | KNN Loss: 3.580819606781006 | CLS Loss: 0.022141383960843086
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 3.607645273208618 | KNN Loss: 3.6037845611572266 | CLS Loss: 0.003860811935737729
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 3.6153924465179443 | KNN Loss: 3.6134893894195557 | CLS Loss: 0.0019030848052352667
Epoch 196 / 200 | iteration 40 / 171 | Total Loss: 3.616936445236206 | KNN Loss: 3.612663507461548 | CLS Loss: 0.004272931255400181
Epoch 196 / 200 | iteration 50 / 171 | Total Loss: 3.619908094406128 | KNN Loss: 3.6017515659332275 | CLS Loss: 0.01815658062696457
Epoch 196 / 200 | iteration 60 / 171 | Total Loss: 3.6128487586975098 | KNN Loss: 3.608612537384033 | CLS Loss: 0.004236267879605293
Epoch 196 / 200 | iteration 70 / 171 | Total Loss: 3.5964102745056152 | KNN Loss: 3.593170166015625 | CLS Loss: 0.0032401911448687315
Epoch 196 / 200 | iteration 80 / 171 | Total Loss: 3.668031930923462

Epoch 199 / 200 | iteration 70 / 171 | Total Loss: 3.5943009853363037 | KNN Loss: 3.5925846099853516 | CLS Loss: 0.0017163126030936837
Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 3.597454786300659 | KNN Loss: 3.5926666259765625 | CLS Loss: 0.004788219463080168
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 3.6179463863372803 | KNN Loss: 3.605574131011963 | CLS Loss: 0.012372189201414585
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 3.617128372192383 | KNN Loss: 3.605790376663208 | CLS Loss: 0.011337882839143276
Epoch 199 / 200 | iteration 110 / 171 | Total Loss: 3.601569175720215 | KNN Loss: 3.578174114227295 | CLS Loss: 0.023395152762532234
Epoch 199 / 200 | iteration 120 / 171 | Total Loss: 3.665844440460205 | KNN Loss: 3.6365015506744385 | CLS Loss: 0.029342852532863617
Epoch 199 / 200 | iteration 130 / 171 | Total Loss: 3.6328747272491455 | KNN Loss: 3.616969108581543 | CLS Loss: 0.01590556465089321
Epoch 199 / 200 | iteration 140 / 171 | Total Loss: 3.626754760742

In [8]:
test(model, test_data_iter, device)

tensor(0.9878, device='cuda:0')

In [9]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [14]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.9780731807592161


In [15]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [16]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [17]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [18]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [19]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [20]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [21]:
losses = []
accs = []
sparsity = []

In [22]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 042 | Total loss: 1.413 | Reg loss: 0.009 | Tree loss: 1.413 | Accuracy: 0.064453 | 0.308 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 01 | Batch: 000 / 042 | Total loss: 1.185 | Reg loss: 0.004 | Tree loss: 1.185 | Accuracy: 0.794922 | 0.22 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 02 | Batch: 000 / 042 | Total loss: 1.109 | Reg loss: 0.007 | Tree loss: 1.109 | Accuracy: 0.744141 | 0.218 sec/iter
Average sparseness: 0.9840425531914894
layer

Epoch: 23 | Batch: 000 / 042 | Total loss: 0.677 | Reg loss: 0.021 | Tree loss: 0.677 | Accuracy: 0.763672 | 0.227 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 24 | Batch: 000 / 042 | Total loss: 0.688 | Reg loss: 0.021 | Tree loss: 0.688 | Accuracy: 0.759766 | 0.227 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 25 | Batch: 000 / 042 | Total loss: 0.615 | Reg loss: 0.022 | Tree loss: 0.615 | Accuracy: 0.808594 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 46 | Batch: 000 / 042 | Total loss: 0.650 | Reg loss: 0.024 | Tree loss: 0.650 | Accuracy: 0.783203 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 47 | Batch: 000 / 042 | Total loss: 0.607 | Reg loss: 0.024 | Tree loss: 0.607 | Accuracy: 0.796875 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 48 | Batch: 000 / 042 | Total loss: 0.627 | Reg loss: 0.024 | Tree loss: 0.627 | Accuracy: 0.802734 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 69 | Batch: 000 / 042 | Total loss: 0.615 | Reg loss: 0.025 | Tree loss: 0.615 | Accuracy: 0.808594 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 70 | Batch: 000 / 042 | Total loss: 0.624 | Reg loss: 0.025 | Tree loss: 0.624 | Accuracy: 0.804688 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 71 | Batch: 000 / 042 | Total loss: 0.652 | Reg loss: 0.025 | Tree loss: 0.652 | Accuracy: 0.792969 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

layer 6: 0.9840425531914895
Epoch: 92 | Batch: 000 / 042 | Total loss: 0.685 | Reg loss: 0.025 | Tree loss: 0.685 | Accuracy: 0.765625 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 93 | Batch: 000 / 042 | Total loss: 0.626 | Reg loss: 0.025 | Tree loss: 0.626 | Accuracy: 0.787109 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 94 | Batch: 000 / 042 | Total loss: 0.612 | Reg loss: 0.025 | Tree loss: 0.612 | Accuracy: 0.798828 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3

Epoch: 115 | Batch: 000 / 042 | Total loss: 0.604 | Reg loss: 0.025 | Tree loss: 0.604 | Accuracy: 0.800781 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 116 | Batch: 000 / 042 | Total loss: 0.623 | Reg loss: 0.025 | Tree loss: 0.623 | Accuracy: 0.792969 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 117 | Batch: 000 / 042 | Total loss: 0.604 | Reg loss: 0.025 | Tree loss: 0.604 | Accuracy: 0.794922 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 138 | Batch: 000 / 042 | Total loss: 0.582 | Reg loss: 0.025 | Tree loss: 0.582 | Accuracy: 0.820312 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 139 | Batch: 000 / 042 | Total loss: 0.599 | Reg loss: 0.025 | Tree loss: 0.599 | Accuracy: 0.810547 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 140 | Batch: 000 / 042 | Total loss: 0.619 | Reg loss: 0.025 | Tree loss: 0.619 | Accuracy: 0.792969 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 161 | Batch: 000 / 042 | Total loss: 0.587 | Reg loss: 0.025 | Tree loss: 0.587 | Accuracy: 0.816406 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 162 | Batch: 000 / 042 | Total loss: 0.619 | Reg loss: 0.025 | Tree loss: 0.619 | Accuracy: 0.806641 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 163 | Batch: 000 / 042 | Total loss: 0.668 | Reg loss: 0.025 | Tree loss: 0.668 | Accuracy: 0.773438 | 0.228 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 184 | Batch: 000 / 042 | Total loss: 0.578 | Reg loss: 0.025 | Tree loss: 0.578 | Accuracy: 0.802734 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 185 | Batch: 000 / 042 | Total loss: 0.579 | Reg loss: 0.025 | Tree loss: 0.579 | Accuracy: 0.808594 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 186 | Batch: 000 / 042 | Total loss: 0.687 | Reg loss: 0.025 | Tree loss: 0.687 | Accuracy: 0.755859 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 207 | Batch: 000 / 042 | Total loss: 0.608 | Reg loss: 0.025 | Tree loss: 0.608 | Accuracy: 0.812500 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 208 | Batch: 000 / 042 | Total loss: 0.635 | Reg loss: 0.025 | Tree loss: 0.635 | Accuracy: 0.792969 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 209 | Batch: 000 / 042 | Total loss: 0.621 | Reg loss: 0.025 | Tree loss: 0.621 | Accuracy: 0.785156 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 230 | Batch: 000 / 042 | Total loss: 0.640 | Reg loss: 0.024 | Tree loss: 0.640 | Accuracy: 0.785156 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 231 | Batch: 000 / 042 | Total loss: 0.633 | Reg loss: 0.024 | Tree loss: 0.633 | Accuracy: 0.789062 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 232 | Batch: 000 / 042 | Total loss: 0.600 | Reg loss: 0.024 | Tree loss: 0.600 | Accuracy: 0.814453 | 0.229 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 253 | Batch: 000 / 042 | Total loss: 0.625 | Reg loss: 0.025 | Tree loss: 0.625 | Accuracy: 0.794922 | 0.23 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 254 | Batch: 000 / 042 | Total loss: 0.619 | Reg loss: 0.024 | Tree loss: 0.619 | Accuracy: 0.792969 | 0.23 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 255 | Batch: 000 / 042 | Total loss: 0.596 | Reg loss: 0.025 | Tree loss: 0.596 | Accuracy: 0.800781 | 0.23 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 276 | Batch: 000 / 042 | Total loss: 0.641 | Reg loss: 0.025 | Tree loss: 0.641 | Accuracy: 0.789062 | 0.231 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 277 | Batch: 000 / 042 | Total loss: 0.631 | Reg loss: 0.025 | Tree loss: 0.631 | Accuracy: 0.787109 | 0.231 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 278 | Batch: 000 / 042 | Total loss: 0.609 | Reg loss: 0.025 | Tree loss: 0.609 | Accuracy: 0.787109 | 0.231 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 299 | Batch: 000 / 042 | Total loss: 0.524 | Reg loss: 0.025 | Tree loss: 0.524 | Accuracy: 0.841797 | 0.232 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 300 | Batch: 000 / 042 | Total loss: 0.588 | Reg loss: 0.025 | Tree loss: 0.588 | Accuracy: 0.818359 | 0.232 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 301 | Batch: 000 / 042 | Total loss: 0.579 | Reg loss: 0.025 | Tree loss: 0.579 | Accuracy: 0.833984 | 0.232 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 322 | Batch: 000 / 042 | Total loss: 0.621 | Reg loss: 0.025 | Tree loss: 0.621 | Accuracy: 0.785156 | 0.232 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 323 | Batch: 000 / 042 | Total loss: 0.678 | Reg loss: 0.025 | Tree loss: 0.678 | Accuracy: 0.769531 | 0.232 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 324 | Batch: 000 / 042 | Total loss: 0.648 | Reg loss: 0.025 | Tree loss: 0.648 | Accuracy: 0.789062 | 0.232 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

layer 6: 0.9840425531914895
Epoch: 345 | Batch: 000 / 042 | Total loss: 0.617 | Reg loss: 0.025 | Tree loss: 0.617 | Accuracy: 0.796875 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 346 | Batch: 000 / 042 | Total loss: 0.571 | Reg loss: 0.025 | Tree loss: 0.571 | Accuracy: 0.826172 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 347 | Batch: 000 / 042 | Total loss: 0.587 | Reg loss: 0.025 | Tree loss: 0.587 | Accuracy: 0.820312 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
laye

Epoch: 368 | Batch: 000 / 042 | Total loss: 0.615 | Reg loss: 0.025 | Tree loss: 0.615 | Accuracy: 0.791016 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 369 | Batch: 000 / 042 | Total loss: 0.591 | Reg loss: 0.025 | Tree loss: 0.591 | Accuracy: 0.810547 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 370 | Batch: 000 / 042 | Total loss: 0.608 | Reg loss: 0.025 | Tree loss: 0.608 | Accuracy: 0.808594 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 391 | Batch: 000 / 042 | Total loss: 0.606 | Reg loss: 0.025 | Tree loss: 0.606 | Accuracy: 0.804688 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 392 | Batch: 000 / 042 | Total loss: 0.640 | Reg loss: 0.025 | Tree loss: 0.640 | Accuracy: 0.791016 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 393 | Batch: 000 / 042 | Total loss: 0.600 | Reg loss: 0.025 | Tree loss: 0.600 | Accuracy: 0.820312 | 0.233 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

In [23]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [25]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.666666666666667


# Extract Rules

# Accumulate samples in the leaves

In [26]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 24


In [27]:
method = 'greedy'

In [28]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [29]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

138
507
485
10797
7571
10
1903
Average comprehensibility: 32.666666666666664
std comprehensibility: 10.514540197058336
var comprehensibility: 110.55555555555556
minimum comprehensibility: 12
maximum comprehensibility: 48


  return np.log(1 / (1 - x))
