In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [3]:
k = 8
tree_depth = 6
batch_size = 512
device = 'cuda'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [4]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [5]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [6]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [9]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [11]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.52977180480957 | KNN Loss: 5.478849411010742 | CLS Loss: 2.050922155380249
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 3.683405876159668 | KNN Loss: 2.786168098449707 | CLS Loss: 0.8972376585006714
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 3.446229934692383 | KNN Loss: 2.6365323066711426 | CLS Loss: 0.8096975684165955
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 3.278867721557617 | KNN Loss: 2.509679079055786 | CLS Loss: 0.769188642501831
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.1989879608154297 | KNN Loss: 2.5724802017211914 | CLS Loss: 0.6265076994895935
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 3.1702568531036377 | KNN Loss: 2.635077953338623 | CLS Loss: 0.5351788997650146
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 3.150221586227417 | KNN Loss: 2.5995898246765137 | CLS Loss: 0.5506318211555481
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 3.1188483238220215 | KNN Loss: 2.5542352199554443 | CL

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 2.6416022777557373 | KNN Loss: 2.507258892059326 | CLS Loss: 0.1343434751033783
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 2.6252903938293457 | KNN Loss: 2.466026544570923 | CLS Loss: 0.15926380455493927
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 2.632112503051758 | KNN Loss: 2.4795005321502686 | CLS Loss: 0.15261198580265045
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 2.6179697513580322 | KNN Loss: 2.491077184677124 | CLS Loss: 0.12689253687858582
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 2.622281074523926 | KNN Loss: 2.503100633621216 | CLS Loss: 0.11918036639690399
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 2.625605583190918 | KNN Loss: 2.4916012287139893 | CLS Loss: 0.13400425016880035
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 2.697336435317993 | KNN Loss: 2.5071699619293213 | CLS Loss: 0.1901664286851883
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 2.6291849613189697 | KNN Loss: 2.52084

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 2.5521764755249023 | KNN Loss: 2.437957763671875 | CLS Loss: 0.11421866714954376
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 2.560342788696289 | KNN Loss: 2.450408935546875 | CLS Loss: 0.10993392020463943
Epoch: 007, Loss: 2.5501, Train: 0.9765, Valid: 0.9739, Best: 0.9739
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 2.5339887142181396 | KNN Loss: 2.462736129760742 | CLS Loss: 0.07125262171030045
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 2.5076987743377686 | KNN Loss: 2.4443304538726807 | CLS Loss: 0.06336834281682968
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 2.5391862392425537 | KNN Loss: 2.462339162826538 | CLS Loss: 0.07684702426195145
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 2.537283420562744 | KNN Loss: 2.4383511543273926 | CLS Loss: 0.0989321619272232
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 2.5249199867248535 | KNN Loss: 2.42889666557312 | CLS Loss: 0.096023328602314
Epoch 8 / 200 | iteratio

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 2.533390760421753 | KNN Loss: 2.4464941024780273 | CLS Loss: 0.08689658343791962
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 2.504801034927368 | KNN Loss: 2.42911434173584 | CLS Loss: 0.07568667083978653
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 2.5146284103393555 | KNN Loss: 2.4148519039154053 | CLS Loss: 0.09977652877569199
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 2.5232040882110596 | KNN Loss: 2.4557511806488037 | CLS Loss: 0.06745290011167526
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 2.4928107261657715 | KNN Loss: 2.41213321685791 | CLS Loss: 0.08067762106657028
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 2.5432424545288086 | KNN Loss: 2.459280252456665 | CLS Loss: 0.08396217226982117
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 2.536130666732788 | KNN Loss: 2.4660532474517822 | CLS Loss: 0.07007752358913422
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 2.5211100578308105 | KNN Loss: 

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 2.4870569705963135 | KNN Loss: 2.4060440063476562 | CLS Loss: 0.0810130164027214
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 2.472944974899292 | KNN Loss: 2.4321982860565186 | CLS Loss: 0.04074675589799881
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 2.482830047607422 | KNN Loss: 2.402920961380005 | CLS Loss: 0.07990897446870804
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 2.436797618865967 | KNN Loss: 2.4038889408111572 | CLS Loss: 0.032908715307712555
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 2.4863297939300537 | KNN Loss: 2.436267375946045 | CLS Loss: 0.050062377005815506
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 2.4573404788970947 | KNN Loss: 2.401641607284546 | CLS Loss: 0.05569878965616226
Epoch: 014, Loss: 2.4987, Train: 0.9845, Valid: 0.9803, Best: 0.9803
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 2.478961706161499 | KNN Loss: 2.4204139709472656 | CLS Loss: 0.05854771286249161
Epoch 15

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 2.4539005756378174 | KNN Loss: 2.4090795516967773 | CLS Loss: 0.04482097178697586
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 2.49318790435791 | KNN Loss: 2.4290058612823486 | CLS Loss: 0.06418205797672272
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 2.469351291656494 | KNN Loss: 2.4031410217285156 | CLS Loss: 0.06621021777391434
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 2.4979114532470703 | KNN Loss: 2.417506217956543 | CLS Loss: 0.0804053321480751
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 2.484232187271118 | KNN Loss: 2.4077649116516113 | CLS Loss: 0.07646722346544266
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 2.4698147773742676 | KNN Loss: 2.4296364784240723 | CLS Loss: 0.040178414434194565
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 2.477363348007202 | KNN Loss: 2.4143216609954834 | CLS Loss: 0.06304175406694412
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 2.446329116821289 | KNN Loss: 2.

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 2.460493564605713 | KNN Loss: 2.4292216300964355 | CLS Loss: 0.0312720388174057
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 2.4652044773101807 | KNN Loss: 2.40796160697937 | CLS Loss: 0.05724276602268219
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 2.481473445892334 | KNN Loss: 2.4282922744750977 | CLS Loss: 0.053181279450654984
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 2.482329845428467 | KNN Loss: 2.4284374713897705 | CLS Loss: 0.05389241501688957
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 2.483583927154541 | KNN Loss: 2.4493372440338135 | CLS Loss: 0.03424671292304993
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 2.5140562057495117 | KNN Loss: 2.4761171340942383 | CLS Loss: 0.03793904557824135
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 2.4397835731506348 | KNN Loss: 2.4003894329071045 | CLS Loss: 0.0393940731883049
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 2.4839394092559814 | KNN Lo

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 2.4883062839508057 | KNN Loss: 2.4558558464050293 | CLS Loss: 0.03245038911700249
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 2.4675092697143555 | KNN Loss: 2.445359230041504 | CLS Loss: 0.022149983793497086
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 2.449309825897217 | KNN Loss: 2.4265058040618896 | CLS Loss: 0.02280394360423088
Epoch: 024, Loss: 2.4654, Train: 0.9890, Valid: 0.9832, Best: 0.9832
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 2.475696086883545 | KNN Loss: 2.415036916732788 | CLS Loss: 0.06065918877720833
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 2.46584415435791 | KNN Loss: 2.416666269302368 | CLS Loss: 0.049177780747413635
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 2.4355766773223877 | KNN Loss: 2.3894009590148926 | CLS Loss: 0.04617573693394661
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 2.4771034717559814 | KNN Loss: 2.4271039962768555 | CLS Loss: 0.04999959096312523
Epoch 25 /

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 2.4825997352600098 | KNN Loss: 2.4116973876953125 | CLS Loss: 0.07090230286121368
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 2.4778966903686523 | KNN Loss: 2.411468029022217 | CLS Loss: 0.06642865389585495
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 2.41678786277771 | KNN Loss: 2.3814404010772705 | CLS Loss: 0.03534757345914841
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 2.4440457820892334 | KNN Loss: 2.3942651748657227 | CLS Loss: 0.04978051781654358
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 2.4999334812164307 | KNN Loss: 2.444127082824707 | CLS Loss: 0.0558064840734005
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 2.4510903358459473 | KNN Loss: 2.3916826248168945 | CLS Loss: 0.05940760299563408
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 2.4633383750915527 | KNN Loss: 2.41359543800354 | CLS Loss: 0.04974288493394852
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 2.42134165763855 | KNN Loss: 2.

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 2.4818503856658936 | KNN Loss: 2.434185266494751 | CLS Loss: 0.04766512289643288
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 2.4273600578308105 | KNN Loss: 2.3936688899993896 | CLS Loss: 0.033691149204969406
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 2.489271402359009 | KNN Loss: 2.447596549987793 | CLS Loss: 0.04167487099766731
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 2.460070848464966 | KNN Loss: 2.4281094074249268 | CLS Loss: 0.03196144476532936
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 2.418210506439209 | KNN Loss: 2.3808417320251465 | CLS Loss: 0.037368666380643845
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 2.445751190185547 | KNN Loss: 2.411686658859253 | CLS Loss: 0.03406447917222977
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 2.4296584129333496 | KNN Loss: 2.4023871421813965 | CLS Loss: 0.027271229773759842
Epoch: 031, Loss: 2.4526, Train: 0.9907, Valid: 0.9839, Best: 0.9845
Epoc

Epoch: 034, Loss: 2.4532, Train: 0.9897, Valid: 0.9841, Best: 0.9845
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 2.517206907272339 | KNN Loss: 2.4651033878326416 | CLS Loss: 0.05210347846150398
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 2.483157157897949 | KNN Loss: 2.4534711837768555 | CLS Loss: 0.029685938730835915
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 2.4959094524383545 | KNN Loss: 2.4635841846466064 | CLS Loss: 0.03232528641819954
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 2.441018581390381 | KNN Loss: 2.3870441913604736 | CLS Loss: 0.05397440120577812
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 2.457491636276245 | KNN Loss: 2.4232401847839355 | CLS Loss: 0.03425145149230957
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 2.4416353702545166 | KNN Loss: 2.381139039993286 | CLS Loss: 0.06049637123942375
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 2.4496896266937256 | KNN Loss: 2.4001247882843018 | CLS Loss: 0.04956480860710144
Epoch 35 / 2

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 2.454136848449707 | KNN Loss: 2.418821096420288 | CLS Loss: 0.03531566262245178
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 2.4294815063476562 | KNN Loss: 2.3769097328186035 | CLS Loss: 0.05257188901305199
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 2.469087600708008 | KNN Loss: 2.441887378692627 | CLS Loss: 0.02720019780099392
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 2.466557502746582 | KNN Loss: 2.4054830074310303 | CLS Loss: 0.06107444316148758
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 2.4904024600982666 | KNN Loss: 2.454918622970581 | CLS Loss: 0.035483818501234055
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 2.4492619037628174 | KNN Loss: 2.4023635387420654 | CLS Loss: 0.04689840227365494
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 2.4208014011383057 | KNN Loss: 2.4107556343078613 | CLS Loss: 0.010045796632766724
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 2.4701249599456787 | KNN 

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 2.44218111038208 | KNN Loss: 2.4063808917999268 | CLS Loss: 0.03580012544989586
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 2.4464876651763916 | KNN Loss: 2.4155213832855225 | CLS Loss: 0.030966177582740784
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 2.446061372756958 | KNN Loss: 2.4275972843170166 | CLS Loss: 0.01846412755548954
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 2.4749808311462402 | KNN Loss: 2.4332032203674316 | CLS Loss: 0.041777580976486206
Epoch: 041, Loss: 2.4464, Train: 0.9891, Valid: 0.9827, Best: 0.9846
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 2.4944796562194824 | KNN Loss: 2.446552038192749 | CLS Loss: 0.047927673906087875
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 2.452305793762207 | KNN Loss: 2.4282050132751465 | CLS Loss: 0.0241008959710598
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 2.4339780807495117 | KNN Loss: 2.4152019023895264 | CLS Loss: 0.01877608895301819
Epoch 4

Epoch 45 / 200 | iteration 20 / 171 | Total Loss: 2.442269802093506 | KNN Loss: 2.4124035835266113 | CLS Loss: 0.029866263270378113
Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 2.4533894062042236 | KNN Loss: 2.4183874130249023 | CLS Loss: 0.0350019671022892
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 2.435678482055664 | KNN Loss: 2.4068074226379395 | CLS Loss: 0.028871044516563416
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 2.398043155670166 | KNN Loss: 2.3812410831451416 | CLS Loss: 0.016802074387669563
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 2.450256824493408 | KNN Loss: 2.4234704971313477 | CLS Loss: 0.026786241680383682
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 2.481976270675659 | KNN Loss: 2.451282262802124 | CLS Loss: 0.030694114044308662
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 2.421339273452759 | KNN Loss: 2.396886110305786 | CLS Loss: 0.024453258141875267
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 2.4625027179718018 | KNN Loss

Epoch 48 / 200 | iteration 90 / 171 | Total Loss: 2.4567718505859375 | KNN Loss: 2.4055402278900146 | CLS Loss: 0.051231712102890015
Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 2.4219202995300293 | KNN Loss: 2.3998847007751465 | CLS Loss: 0.022035673260688782
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 2.4557180404663086 | KNN Loss: 2.4037559032440186 | CLS Loss: 0.05196215584874153
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 2.4206645488739014 | KNN Loss: 2.386237382888794 | CLS Loss: 0.03442712128162384
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 2.4688076972961426 | KNN Loss: 2.4180307388305664 | CLS Loss: 0.05077705159783363
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 2.424656391143799 | KNN Loss: 2.4016687870025635 | CLS Loss: 0.02298755757510662
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 2.411862850189209 | KNN Loss: 2.390775442123413 | CLS Loss: 0.021087318658828735
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 2.4263827800750732 |

Epoch 51 / 200 | iteration 160 / 171 | Total Loss: 2.409090757369995 | KNN Loss: 2.3777310848236084 | CLS Loss: 0.03135956451296806
Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 2.402597665786743 | KNN Loss: 2.367567539215088 | CLS Loss: 0.03503013774752617
Epoch: 051, Loss: 2.4357, Train: 0.9929, Valid: 0.9860, Best: 0.9860
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 2.415120840072632 | KNN Loss: 2.4008758068084717 | CLS Loss: 0.014245065860450268
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 2.444334030151367 | KNN Loss: 2.418369770050049 | CLS Loss: 0.025964368134737015
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 2.470932722091675 | KNN Loss: 2.4322025775909424 | CLS Loss: 0.03873008117079735
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 2.439960241317749 | KNN Loss: 2.421548843383789 | CLS Loss: 0.018411358818411827
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 2.435547351837158 | KNN Loss: 2.4010045528411865 | CLS Loss: 0.0345427580177784
Epoch 52 / 200

Epoch 55 / 200 | iteration 50 / 171 | Total Loss: 2.47633695602417 | KNN Loss: 2.455465793609619 | CLS Loss: 0.020871154963970184
Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 2.4389569759368896 | KNN Loss: 2.403843879699707 | CLS Loss: 0.03511308133602142
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 2.3977174758911133 | KNN Loss: 2.3854451179504395 | CLS Loss: 0.0122722452506423
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 2.436746597290039 | KNN Loss: 2.413393020629883 | CLS Loss: 0.023353606462478638
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 2.4991507530212402 | KNN Loss: 2.4536969661712646 | CLS Loss: 0.04545385017991066
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 2.4004952907562256 | KNN Loss: 2.372343063354492 | CLS Loss: 0.02815225161612034
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 2.4772653579711914 | KNN Loss: 2.4401230812072754 | CLS Loss: 0.03714223578572273
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 2.4260425567626953 | KNN Loss

Epoch 58 / 200 | iteration 120 / 171 | Total Loss: 2.4304449558258057 | KNN Loss: 2.4017574787139893 | CLS Loss: 0.028687389567494392
Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 2.394601821899414 | KNN Loss: 2.381950855255127 | CLS Loss: 0.012651074677705765
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 2.461202621459961 | KNN Loss: 2.4432148933410645 | CLS Loss: 0.01798783428966999
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 2.4144227504730225 | KNN Loss: 2.406914234161377 | CLS Loss: 0.007508426439017057
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 2.4096717834472656 | KNN Loss: 2.388731002807617 | CLS Loss: 0.020940696820616722
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 2.426217555999756 | KNN Loss: 2.4011027812957764 | CLS Loss: 0.025114860385656357
Epoch: 058, Loss: 2.4342, Train: 0.9936, Valid: 0.9856, Best: 0.9860
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 2.460653305053711 | KNN Loss: 2.427990674972534 | CLS Loss: 0.032662615180015564
Epoc

Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 2.4191205501556396 | KNN Loss: 2.394473075866699 | CLS Loss: 0.02464747615158558
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 2.399761438369751 | KNN Loss: 2.379631280899048 | CLS Loss: 0.020130233839154243
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 2.4271254539489746 | KNN Loss: 2.3896608352661133 | CLS Loss: 0.037464678287506104
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 2.4202849864959717 | KNN Loss: 2.3849093914031982 | CLS Loss: 0.03537552058696747
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 2.4359982013702393 | KNN Loss: 2.382061719894409 | CLS Loss: 0.053936492651700974
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 2.4417028427124023 | KNN Loss: 2.4204752445220947 | CLS Loss: 0.02122764103114605
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 2.459484577178955 | KNN Loss: 2.443368434906006 | CLS Loss: 0.016116205602884293
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 2.479203939437866 | KNN Loss

Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 2.437222719192505 | KNN Loss: 2.396211624145508 | CLS Loss: 0.041011132299900055
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 2.4104537963867188 | KNN Loss: 2.3935484886169434 | CLS Loss: 0.016905302181839943
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 2.4210989475250244 | KNN Loss: 2.4019148349761963 | CLS Loss: 0.0191841721534729
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 2.386496067047119 | KNN Loss: 2.378312826156616 | CLS Loss: 0.008183357305824757
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 2.4694080352783203 | KNN Loss: 2.446728467941284 | CLS Loss: 0.022679666057229042
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 2.4052248001098633 | KNN Loss: 2.3882646560668945 | CLS Loss: 0.016960179433226585
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 2.4872426986694336 | KNN Loss: 2.445955276489258 | CLS Loss: 0.04128730297088623
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 2.456246852874756 | KN

Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 2.4303667545318604 | KNN Loss: 2.3876781463623047 | CLS Loss: 0.04268871620297432
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 2.4061567783355713 | KNN Loss: 2.383423089981079 | CLS Loss: 0.022733604535460472
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 2.436418056488037 | KNN Loss: 2.423677921295166 | CLS Loss: 0.012740061618387699
Epoch: 068, Loss: 2.4277, Train: 0.9944, Valid: 0.9857, Best: 0.9866
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 2.436387062072754 | KNN Loss: 2.422041177749634 | CLS Loss: 0.0143458666279912
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 2.457785129547119 | KNN Loss: 2.423933744430542 | CLS Loss: 0.03385132551193237
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 2.3917553424835205 | KNN Loss: 2.3669304847717285 | CLS Loss: 0.02482491545379162
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 2.4182498455047607 | KNN Loss: 2.4029653072357178 | CLS Loss: 0.015284519642591476
Epoch 69 /

Epoch 72 / 200 | iteration 30 / 171 | Total Loss: 2.430877447128296 | KNN Loss: 2.417602062225342 | CLS Loss: 0.013275275938212872
Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 2.464191198348999 | KNN Loss: 2.429706573486328 | CLS Loss: 0.034484680742025375
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 2.448460340499878 | KNN Loss: 2.4159939289093018 | CLS Loss: 0.032466430217027664
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 2.4413986206054688 | KNN Loss: 2.4285240173339844 | CLS Loss: 0.012874648906290531
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 2.411278247833252 | KNN Loss: 2.4019155502319336 | CLS Loss: 0.009362760931253433
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 2.4497344493865967 | KNN Loss: 2.412018060684204 | CLS Loss: 0.03771643340587616
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 2.401552200317383 | KNN Loss: 2.382859706878662 | CLS Loss: 0.018692491576075554
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 2.404322862625122 | KNN Loss

Epoch 75 / 200 | iteration 100 / 171 | Total Loss: 2.4048728942871094 | KNN Loss: 2.392854928970337 | CLS Loss: 0.012018021196126938
Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 2.466926336288452 | KNN Loss: 2.4456522464752197 | CLS Loss: 0.02127411775290966
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 2.405311107635498 | KNN Loss: 2.393173933029175 | CLS Loss: 0.012137136422097683
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 2.438831329345703 | KNN Loss: 2.384939432144165 | CLS Loss: 0.05389198288321495
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 2.4550788402557373 | KNN Loss: 2.433105230331421 | CLS Loss: 0.021973541006445885
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 2.463841676712036 | KNN Loss: 2.4456517696380615 | CLS Loss: 0.01818997785449028
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 2.4163732528686523 | KNN Loss: 2.3851430416107178 | CLS Loss: 0.03123011253774166
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 2.428574323654175 | KNN

Epoch 78 / 200 | iteration 170 / 171 | Total Loss: 2.4345672130584717 | KNN Loss: 2.4214274883270264 | CLS Loss: 0.013139664195477962
Epoch: 078, Loss: 2.4296, Train: 0.9949, Valid: 0.9866, Best: 0.9867
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 2.4269144535064697 | KNN Loss: 2.4038190841674805 | CLS Loss: 0.023095306009054184
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 2.443507432937622 | KNN Loss: 2.431255340576172 | CLS Loss: 0.012252112850546837
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 2.4081921577453613 | KNN Loss: 2.3945302963256836 | CLS Loss: 0.013661973178386688
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 2.4331908226013184 | KNN Loss: 2.412888526916504 | CLS Loss: 0.020302342250943184
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 2.438342332839966 | KNN Loss: 2.4174304008483887 | CLS Loss: 0.02091195620596409
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 2.4545767307281494 | KNN Loss: 2.4238579273223877 | CLS Loss: 0.03071880340576172
Epoch 

Epoch 82 / 200 | iteration 50 / 171 | Total Loss: 2.4492440223693848 | KNN Loss: 2.4388415813446045 | CLS Loss: 0.010402447544038296
Epoch 82 / 200 | iteration 60 / 171 | Total Loss: 2.4688706398010254 | KNN Loss: 2.4586684703826904 | CLS Loss: 0.010202163830399513
Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 2.463597297668457 | KNN Loss: 2.4525861740112305 | CLS Loss: 0.011011175811290741
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 2.415379285812378 | KNN Loss: 2.375894069671631 | CLS Loss: 0.03948516771197319
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 2.419111728668213 | KNN Loss: 2.385976791381836 | CLS Loss: 0.033134832978248596
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 2.4012298583984375 | KNN Loss: 2.3946969509124756 | CLS Loss: 0.006532819475978613
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 2.432478904724121 | KNN Loss: 2.3923470973968506 | CLS Loss: 0.04013191908597946
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 2.382209300994873 | KNN 

Epoch 85 / 200 | iteration 120 / 171 | Total Loss: 2.4342520236968994 | KNN Loss: 2.411630868911743 | CLS Loss: 0.02262107841670513
Epoch 85 / 200 | iteration 130 / 171 | Total Loss: 2.422337055206299 | KNN Loss: 2.3854331970214844 | CLS Loss: 0.036903850734233856
Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 2.418377161026001 | KNN Loss: 2.4061927795410156 | CLS Loss: 0.01218440756201744
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 2.418635845184326 | KNN Loss: 2.3924341201782227 | CLS Loss: 0.026201780885457993
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 2.4250240325927734 | KNN Loss: 2.4066641330718994 | CLS Loss: 0.01835988275706768
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 2.446455717086792 | KNN Loss: 2.4265034198760986 | CLS Loss: 0.01995234563946724
Epoch: 085, Loss: 2.4278, Train: 0.9949, Valid: 0.9858, Best: 0.9868
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 2.418454647064209 | KNN Loss: 2.409790277481079 | CLS Loss: 0.008664438501000404
Epoch 

Epoch 89 / 200 | iteration 0 / 171 | Total Loss: 2.4605512619018555 | KNN Loss: 2.4267759323120117 | CLS Loss: 0.03377522528171539
Epoch 89 / 200 | iteration 10 / 171 | Total Loss: 2.4157261848449707 | KNN Loss: 2.3930156230926514 | CLS Loss: 0.022710641846060753
Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 2.409142255783081 | KNN Loss: 2.394427537918091 | CLS Loss: 0.014714610762894154
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 2.4521121978759766 | KNN Loss: 2.440814733505249 | CLS Loss: 0.011297388933598995
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 2.4288036823272705 | KNN Loss: 2.4195895195007324 | CLS Loss: 0.009214275516569614
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 2.4774787425994873 | KNN Loss: 2.464404582977295 | CLS Loss: 0.013074223883450031
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 2.4256155490875244 | KNN Loss: 2.4093194007873535 | CLS Loss: 0.01629617251455784
Epoch 89 / 200 | iteration 70 / 171 | Total Loss: 2.41373872756958 | KNN Los

Epoch 92 / 200 | iteration 70 / 171 | Total Loss: 2.427361249923706 | KNN Loss: 2.412095308303833 | CLS Loss: 0.015265850350260735
Epoch 92 / 200 | iteration 80 / 171 | Total Loss: 2.4231443405151367 | KNN Loss: 2.3930857181549072 | CLS Loss: 0.030058644711971283
Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 2.4359352588653564 | KNN Loss: 2.4269769191741943 | CLS Loss: 0.008958343416452408
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 2.4010519981384277 | KNN Loss: 2.3887932300567627 | CLS Loss: 0.012258786708116531
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 2.415271282196045 | KNN Loss: 2.399458408355713 | CLS Loss: 0.01581294648349285
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 2.4276633262634277 | KNN Loss: 2.4040791988372803 | CLS Loss: 0.02358406037092209
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 2.4201266765594482 | KNN Loss: 2.3990988731384277 | CLS Loss: 0.021027911454439163
Epoch 92 / 200 | iteration 140 / 171 | Total Loss: 2.4096932411193848 |

Epoch 95 / 200 | iteration 140 / 171 | Total Loss: 2.4244463443756104 | KNN Loss: 2.4045023918151855 | CLS Loss: 0.019944000989198685
Epoch 95 / 200 | iteration 150 / 171 | Total Loss: 2.438995838165283 | KNN Loss: 2.427135944366455 | CLS Loss: 0.011859981343150139
Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 2.460590362548828 | KNN Loss: 2.4392266273498535 | CLS Loss: 0.02136370725929737
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 2.4060490131378174 | KNN Loss: 2.393434762954712 | CLS Loss: 0.01261430699378252
Epoch: 095, Loss: 2.4278, Train: 0.9958, Valid: 0.9873, Best: 0.9873
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 2.404930830001831 | KNN Loss: 2.3953051567077637 | CLS Loss: 0.009625597856938839
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 2.406830310821533 | KNN Loss: 2.404090642929077 | CLS Loss: 0.002739636693149805
Epoch 96 / 200 | iteration 20 / 171 | Total Loss: 2.4484074115753174 | KNN Loss: 2.441992998123169 | CLS Loss: 0.0064143287017941475
Epoch 

Epoch 99 / 200 | iteration 30 / 171 | Total Loss: 2.427884340286255 | KNN Loss: 2.402944326400757 | CLS Loss: 0.02493991330265999
Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 2.4274613857269287 | KNN Loss: 2.395341634750366 | CLS Loss: 0.032119765877723694
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 2.3902907371520996 | KNN Loss: 2.382582187652588 | CLS Loss: 0.007708489894866943
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 2.4152441024780273 | KNN Loss: 2.3993537425994873 | CLS Loss: 0.015890317037701607
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 2.410616397857666 | KNN Loss: 2.4008688926696777 | CLS Loss: 0.00974738597869873
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 2.4133572578430176 | KNN Loss: 2.4091951847076416 | CLS Loss: 0.004162167198956013
Epoch 99 / 200 | iteration 90 / 171 | Total Loss: 2.4021308422088623 | KNN Loss: 2.3790841102600098 | CLS Loss: 0.023046741262078285
Epoch 99 / 200 | iteration 100 / 171 | Total Loss: 2.450831651687622 | KNN L

Epoch 102 / 200 | iteration 100 / 171 | Total Loss: 2.411527395248413 | KNN Loss: 2.4024436473846436 | CLS Loss: 0.009083799086511135
Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 2.408468246459961 | KNN Loss: 2.392312526702881 | CLS Loss: 0.016155634075403214
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 2.4035086631774902 | KNN Loss: 2.396930456161499 | CLS Loss: 0.006578207481652498
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 2.4272241592407227 | KNN Loss: 2.400563955307007 | CLS Loss: 0.02666030265390873
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 2.428779125213623 | KNN Loss: 2.3815722465515137 | CLS Loss: 0.0472068265080452
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 2.4267849922180176 | KNN Loss: 2.4046990871429443 | CLS Loss: 0.022085899487137794
Epoch 102 / 200 | iteration 160 / 171 | Total Loss: 2.4326443672180176 | KNN Loss: 2.407682418823242 | CLS Loss: 0.02496200054883957
Epoch 102 / 200 | iteration 170 / 171 | Total Loss: 2.431112527847

Epoch 105 / 200 | iteration 160 / 171 | Total Loss: 2.4155220985412598 | KNN Loss: 2.3995425701141357 | CLS Loss: 0.015979526564478874
Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 2.4406630992889404 | KNN Loss: 2.4181275367736816 | CLS Loss: 0.022535528987646103
Epoch: 105, Loss: 2.4310, Train: 0.9942, Valid: 0.9847, Best: 0.9873
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 2.4072766304016113 | KNN Loss: 2.4002904891967773 | CLS Loss: 0.006986184511333704
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 2.4370627403259277 | KNN Loss: 2.4285566806793213 | CLS Loss: 0.008506151847541332
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 2.4430904388427734 | KNN Loss: 2.417677640914917 | CLS Loss: 0.025412827730178833
Epoch 106 / 200 | iteration 30 / 171 | Total Loss: 2.482295036315918 | KNN Loss: 2.452075719833374 | CLS Loss: 0.03021942265331745
Epoch 106 / 200 | iteration 40 / 171 | Total Loss: 2.4652256965637207 | KNN Loss: 2.454318046569824 | CLS Loss: 0.0109077692031860

Epoch 109 / 200 | iteration 40 / 171 | Total Loss: 2.4329068660736084 | KNN Loss: 2.4146578311920166 | CLS Loss: 0.018248990178108215
Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 2.4329724311828613 | KNN Loss: 2.426724672317505 | CLS Loss: 0.006247641518712044
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 2.4577348232269287 | KNN Loss: 2.4234988689422607 | CLS Loss: 0.034236036241054535
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 2.435185670852661 | KNN Loss: 2.3980472087860107 | CLS Loss: 0.03713855519890785
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 2.403531312942505 | KNN Loss: 2.3827927112579346 | CLS Loss: 0.02073850855231285
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 2.4227259159088135 | KNN Loss: 2.412520408630371 | CLS Loss: 0.010205565020442009
Epoch 109 / 200 | iteration 100 / 171 | Total Loss: 2.383897304534912 | KNN Loss: 2.370675802230835 | CLS Loss: 0.013221485540270805
Epoch 109 / 200 | iteration 110 / 171 | Total Loss: 2.420781373977661

Epoch 112 / 200 | iteration 100 / 171 | Total Loss: 2.4082491397857666 | KNN Loss: 2.3917176723480225 | CLS Loss: 0.01653146930038929
Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 2.4349284172058105 | KNN Loss: 2.414348840713501 | CLS Loss: 0.020579524338245392
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 2.38258957862854 | KNN Loss: 2.3809475898742676 | CLS Loss: 0.0016420057509094477
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 2.412882089614868 | KNN Loss: 2.3887999057769775 | CLS Loss: 0.024082284420728683
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 2.453146457672119 | KNN Loss: 2.419658899307251 | CLS Loss: 0.03348749503493309
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 2.4231884479522705 | KNN Loss: 2.4028847217559814 | CLS Loss: 0.0203036367893219
Epoch 112 / 200 | iteration 160 / 171 | Total Loss: 2.4183382987976074 | KNN Loss: 2.4025518894195557 | CLS Loss: 0.01578647643327713
Epoch 112 / 200 | iteration 170 / 171 | Total Loss: 2.42771196365

Epoch 115 / 200 | iteration 160 / 171 | Total Loss: 2.409503936767578 | KNN Loss: 2.3997392654418945 | CLS Loss: 0.00976472720503807
Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 2.4153614044189453 | KNN Loss: 2.403148889541626 | CLS Loss: 0.012212545610964298
Epoch: 115, Loss: 2.4271, Train: 0.9960, Valid: 0.9868, Best: 0.9873
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 2.399867534637451 | KNN Loss: 2.394695997238159 | CLS Loss: 0.005171430762857199
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 2.420572519302368 | KNN Loss: 2.4115099906921387 | CLS Loss: 0.009062465280294418
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 2.3876309394836426 | KNN Loss: 2.3739027976989746 | CLS Loss: 0.013728126883506775
Epoch 116 / 200 | iteration 30 / 171 | Total Loss: 2.449387788772583 | KNN Loss: 2.441389322280884 | CLS Loss: 0.007998385466635227
Epoch 116 / 200 | iteration 40 / 171 | Total Loss: 2.408507823944092 | KNN Loss: 2.404376983642578 | CLS Loss: 0.004130837973207235
Ep

Epoch 119 / 200 | iteration 40 / 171 | Total Loss: 2.4057178497314453 | KNN Loss: 2.37969708442688 | CLS Loss: 0.026020776480436325
Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 2.421367883682251 | KNN Loss: 2.4002532958984375 | CLS Loss: 0.02111465483903885
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 2.3934879302978516 | KNN Loss: 2.3834939002990723 | CLS Loss: 0.009994005784392357
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 2.433379888534546 | KNN Loss: 2.4190707206726074 | CLS Loss: 0.014309154823422432
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 2.419113874435425 | KNN Loss: 2.4131522178649902 | CLS Loss: 0.005961698479950428
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 2.434375286102295 | KNN Loss: 2.4237418174743652 | CLS Loss: 0.010633576661348343
Epoch 119 / 200 | iteration 100 / 171 | Total Loss: 2.459456205368042 | KNN Loss: 2.452127456665039 | CLS Loss: 0.007328852079808712
Epoch 119 / 200 | iteration 110 / 171 | Total Loss: 2.425854206085205 

Epoch 122 / 200 | iteration 100 / 171 | Total Loss: 2.404557943344116 | KNN Loss: 2.397101640701294 | CLS Loss: 0.00745633477345109
Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 2.434159278869629 | KNN Loss: 2.408547878265381 | CLS Loss: 0.02561141364276409
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 2.4244706630706787 | KNN Loss: 2.407862424850464 | CLS Loss: 0.016608137637376785
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 2.382575750350952 | KNN Loss: 2.364762783050537 | CLS Loss: 0.01781296357512474
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 2.4460153579711914 | KNN Loss: 2.4384233951568604 | CLS Loss: 0.0075920443050563335
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 2.448647975921631 | KNN Loss: 2.421083688735962 | CLS Loss: 0.02756418101489544
Epoch 122 / 200 | iteration 160 / 171 | Total Loss: 2.433551549911499 | KNN Loss: 2.4132750034332275 | CLS Loss: 0.020276609808206558
Epoch 122 / 200 | iteration 170 / 171 | Total Loss: 2.43859910964965

Epoch 125 / 200 | iteration 160 / 171 | Total Loss: 2.4472408294677734 | KNN Loss: 2.4342398643493652 | CLS Loss: 0.01300099678337574
Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 2.43477201461792 | KNN Loss: 2.42405104637146 | CLS Loss: 0.010720849968492985
Epoch: 125, Loss: 2.4259, Train: 0.9959, Valid: 0.9864, Best: 0.9873
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 2.41845965385437 | KNN Loss: 2.414199113845825 | CLS Loss: 0.004260484594851732
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 2.4278173446655273 | KNN Loss: 2.4222986698150635 | CLS Loss: 0.005518776830285788
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 2.4058234691619873 | KNN Loss: 2.3878586292266846 | CLS Loss: 0.01796477474272251
Epoch 126 / 200 | iteration 30 / 171 | Total Loss: 2.426086664199829 | KNN Loss: 2.400967836380005 | CLS Loss: 0.02511891908943653
Epoch 126 / 200 | iteration 40 / 171 | Total Loss: 2.4320223331451416 | KNN Loss: 2.410773515701294 | CLS Loss: 0.021248819306492805
Epoch

Epoch 129 / 200 | iteration 40 / 171 | Total Loss: 2.4978885650634766 | KNN Loss: 2.473923444747925 | CLS Loss: 0.02396516315639019
Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 2.4431567192077637 | KNN Loss: 2.4270777702331543 | CLS Loss: 0.016078844666481018
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 2.4162588119506836 | KNN Loss: 2.410750389099121 | CLS Loss: 0.005508344154804945
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 2.449242115020752 | KNN Loss: 2.4126551151275635 | CLS Loss: 0.036586880683898926
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 2.426483154296875 | KNN Loss: 2.409296989440918 | CLS Loss: 0.01718609407544136
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 2.5149950981140137 | KNN Loss: 2.4875359535217285 | CLS Loss: 0.02745908498764038
Epoch 129 / 200 | iteration 100 / 171 | Total Loss: 2.4442625045776367 | KNN Loss: 2.419865131378174 | CLS Loss: 0.024397436529397964
Epoch 129 / 200 | iteration 110 / 171 | Total Loss: 2.4199516773223877

Epoch 132 / 200 | iteration 110 / 171 | Total Loss: 2.4730632305145264 | KNN Loss: 2.450782060623169 | CLS Loss: 0.022281277924776077
Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 2.4296774864196777 | KNN Loss: 2.4229705333709717 | CLS Loss: 0.006706861779093742
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 2.4011168479919434 | KNN Loss: 2.395430564880371 | CLS Loss: 0.005686297547072172
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 2.4600839614868164 | KNN Loss: 2.438626766204834 | CLS Loss: 0.021457141265273094
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 2.4551262855529785 | KNN Loss: 2.444669485092163 | CLS Loss: 0.010456729680299759
Epoch 132 / 200 | iteration 160 / 171 | Total Loss: 2.4455366134643555 | KNN Loss: 2.4281678199768066 | CLS Loss: 0.017368797212839127
Epoch 132 / 200 | iteration 170 / 171 | Total Loss: 2.4402213096618652 | KNN Loss: 2.394592761993408 | CLS Loss: 0.045628540217876434
Epoch: 132, Loss: 2.4229, Train: 0.9952, Valid: 0.9852, Best

Epoch 135 / 200 | iteration 170 / 171 | Total Loss: 2.425281286239624 | KNN Loss: 2.4159607887268066 | CLS Loss: 0.00932050496339798
Epoch: 135, Loss: 2.4253, Train: 0.9958, Valid: 0.9858, Best: 0.9873
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 2.4530434608459473 | KNN Loss: 2.4429354667663574 | CLS Loss: 0.010107909329235554
Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 2.4287827014923096 | KNN Loss: 2.413175582885742 | CLS Loss: 0.015607066452503204
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 2.391200542449951 | KNN Loss: 2.388995885848999 | CLS Loss: 0.0022047085221856833
Epoch 136 / 200 | iteration 30 / 171 | Total Loss: 2.443066358566284 | KNN Loss: 2.4274377822875977 | CLS Loss: 0.015628468245267868
Epoch 136 / 200 | iteration 40 / 171 | Total Loss: 2.4120326042175293 | KNN Loss: 2.398960828781128 | CLS Loss: 0.013071807101368904
Epoch 136 / 200 | iteration 50 / 171 | Total Loss: 2.4319143295288086 | KNN Loss: 2.416499614715576 | CLS Loss: 0.015414833091199398


Epoch 139 / 200 | iteration 50 / 171 | Total Loss: 2.3725643157958984 | KNN Loss: 2.356304168701172 | CLS Loss: 0.016260169446468353
Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 2.3915722370147705 | KNN Loss: 2.382469415664673 | CLS Loss: 0.009102879092097282
Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 2.4216549396514893 | KNN Loss: 2.403608560562134 | CLS Loss: 0.018046371638774872
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 2.45607590675354 | KNN Loss: 2.434978485107422 | CLS Loss: 0.02109731175005436
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 2.379976749420166 | KNN Loss: 2.370692253112793 | CLS Loss: 0.00928457546979189
Epoch 139 / 200 | iteration 100 / 171 | Total Loss: 2.4368975162506104 | KNN Loss: 2.419914484024048 | CLS Loss: 0.016983142122626305
Epoch 139 / 200 | iteration 110 / 171 | Total Loss: 2.399829387664795 | KNN Loss: 2.3822567462921143 | CLS Loss: 0.01757262833416462
Epoch 139 / 200 | iteration 120 / 171 | Total Loss: 2.405946969985962 | K

Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 2.4305477142333984 | KNN Loss: 2.423388957977295 | CLS Loss: 0.007158856838941574
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 2.4116644859313965 | KNN Loss: 2.401197910308838 | CLS Loss: 0.010466641746461391
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 2.467960834503174 | KNN Loss: 2.4271717071533203 | CLS Loss: 0.04078923165798187
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 2.437964677810669 | KNN Loss: 2.43249249458313 | CLS Loss: 0.005472262855619192
Epoch 142 / 200 | iteration 160 / 171 | Total Loss: 2.385284662246704 | KNN Loss: 2.3762407302856445 | CLS Loss: 0.009043947793543339
Epoch 142 / 200 | iteration 170 / 171 | Total Loss: 2.410123109817505 | KNN Loss: 2.4049603939056396 | CLS Loss: 0.005162746645510197
Epoch: 142, Loss: 2.4237, Train: 0.9958, Valid: 0.9866, Best: 0.9873
Epoch 143 / 200 | iteration 0 / 171 | Total Loss: 2.384040117263794 | KNN Loss: 2.3664634227752686 | CLS Loss: 0.01757677644491195

Epoch: 145, Loss: 2.4251, Train: 0.9952, Valid: 0.9862, Best: 0.9873
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 2.3937766551971436 | KNN Loss: 2.381634473800659 | CLS Loss: 0.012142088264226913
Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 2.401862859725952 | KNN Loss: 2.382481575012207 | CLS Loss: 0.01938120275735855
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 2.3987979888916016 | KNN Loss: 2.3898544311523438 | CLS Loss: 0.00894360151141882
Epoch 146 / 200 | iteration 30 / 171 | Total Loss: 2.3843867778778076 | KNN Loss: 2.3635029792785645 | CLS Loss: 0.02088373899459839
Epoch 146 / 200 | iteration 40 / 171 | Total Loss: 2.4498534202575684 | KNN Loss: 2.4409682750701904 | CLS Loss: 0.008885029703378677
Epoch 146 / 200 | iteration 50 / 171 | Total Loss: 2.379932165145874 | KNN Loss: 2.368635416030884 | CLS Loss: 0.011296837590634823
Epoch 146 / 200 | iteration 60 / 171 | Total Loss: 2.3864927291870117 | KNN Loss: 2.370435953140259 | CLS Loss: 0.016056746244430542
Epo

Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 2.4352221488952637 | KNN Loss: 2.4212939739227295 | CLS Loss: 0.013928103260695934
Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 2.4316940307617188 | KNN Loss: 2.4191813468933105 | CLS Loss: 0.01251278631389141
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 2.417574882507324 | KNN Loss: 2.41314435005188 | CLS Loss: 0.004430593457072973
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 2.3880486488342285 | KNN Loss: 2.369636058807373 | CLS Loss: 0.018412502482533455
Epoch 149 / 200 | iteration 100 / 171 | Total Loss: 2.401202917098999 | KNN Loss: 2.390273094177246 | CLS Loss: 0.010929832234978676
Epoch 149 / 200 | iteration 110 / 171 | Total Loss: 2.4214134216308594 | KNN Loss: 2.4011754989624023 | CLS Loss: 0.0202378761023283
Epoch 149 / 200 | iteration 120 / 171 | Total Loss: 2.406450033187866 | KNN Loss: 2.40173077583313 | CLS Loss: 0.004719339311122894
Epoch 149 / 200 | iteration 130 / 171 | Total Loss: 2.4433255195617676 

Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 2.3931679725646973 | KNN Loss: 2.3809800148010254 | CLS Loss: 0.012188032269477844
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 2.4273054599761963 | KNN Loss: 2.4180350303649902 | CLS Loss: 0.0092704389244318
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 2.4420692920684814 | KNN Loss: 2.415151596069336 | CLS Loss: 0.026917627081274986
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 2.3985612392425537 | KNN Loss: 2.3767404556274414 | CLS Loss: 0.02182081714272499
Epoch 152 / 200 | iteration 160 / 171 | Total Loss: 2.4378674030303955 | KNN Loss: 2.3955562114715576 | CLS Loss: 0.04231124371290207
Epoch 152 / 200 | iteration 170 / 171 | Total Loss: 2.407341718673706 | KNN Loss: 2.400265693664551 | CLS Loss: 0.007075936067849398
Epoch: 152, Loss: 2.4186, Train: 0.9960, Valid: 0.9860, Best: 0.9877
Epoch 153 / 200 | iteration 0 / 171 | Total Loss: 2.383733034133911 | KNN Loss: 2.3741135597229004 | CLS Loss: 0.009619360789656

Epoch: 155, Loss: 2.4190, Train: 0.9962, Valid: 0.9857, Best: 0.9877
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 2.426699161529541 | KNN Loss: 2.403536796569824 | CLS Loss: 0.02316235564649105
Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 2.4227330684661865 | KNN Loss: 2.418480157852173 | CLS Loss: 0.004252924583852291
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 2.418699264526367 | KNN Loss: 2.4063868522644043 | CLS Loss: 0.012312395498156548
Epoch 156 / 200 | iteration 30 / 171 | Total Loss: 2.38853120803833 | KNN Loss: 2.374603748321533 | CLS Loss: 0.013927376829087734
Epoch 156 / 200 | iteration 40 / 171 | Total Loss: 2.3793036937713623 | KNN Loss: 2.3694777488708496 | CLS Loss: 0.009826061315834522
Epoch 156 / 200 | iteration 50 / 171 | Total Loss: 2.4142889976501465 | KNN Loss: 2.4078969955444336 | CLS Loss: 0.006392107345163822
Epoch 156 / 200 | iteration 60 / 171 | Total Loss: 2.4464635848999023 | KNN Loss: 2.4304423332214355 | CLS Loss: 0.016021285206079483
Ep

Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 2.414212226867676 | KNN Loss: 2.407844066619873 | CLS Loss: 0.0063681709580123425
Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 2.42338228225708 | KNN Loss: 2.4188199043273926 | CLS Loss: 0.004562459886074066
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 2.378603219985962 | KNN Loss: 2.372594118118286 | CLS Loss: 0.006009100470691919
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 2.3845295906066895 | KNN Loss: 2.3780343532562256 | CLS Loss: 0.006495276000350714
Epoch 159 / 200 | iteration 100 / 171 | Total Loss: 2.418332815170288 | KNN Loss: 2.392852306365967 | CLS Loss: 0.025480465963482857
Epoch 159 / 200 | iteration 110 / 171 | Total Loss: 2.4543795585632324 | KNN Loss: 2.432668924331665 | CLS Loss: 0.02171061933040619
Epoch 159 / 200 | iteration 120 / 171 | Total Loss: 2.422098159790039 | KNN Loss: 2.4148752689361572 | CLS Loss: 0.0072228508070111275
Epoch 159 / 200 | iteration 130 / 171 | Total Loss: 2.42977213859558

Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 2.438447952270508 | KNN Loss: 2.4200491905212402 | CLS Loss: 0.018398789688944817
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 2.419670581817627 | KNN Loss: 2.3884103298187256 | CLS Loss: 0.031260184943675995
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 2.4277660846710205 | KNN Loss: 2.4169833660125732 | CLS Loss: 0.010782704688608646
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 2.3850669860839844 | KNN Loss: 2.374346971511841 | CLS Loss: 0.010719978250563145
Epoch 162 / 200 | iteration 160 / 171 | Total Loss: 2.4704678058624268 | KNN Loss: 2.457977294921875 | CLS Loss: 0.012490605004131794
Epoch 162 / 200 | iteration 170 / 171 | Total Loss: 2.3771064281463623 | KNN Loss: 2.3634033203125 | CLS Loss: 0.013702993281185627
Epoch: 162, Loss: 2.4231, Train: 0.9961, Valid: 0.9871, Best: 0.9877
Epoch 163 / 200 | iteration 0 / 171 | Total Loss: 2.3837435245513916 | KNN Loss: 2.3719606399536133 | CLS Loss: 0.01178288739174

Epoch: 165, Loss: 2.4173, Train: 0.9952, Valid: 0.9872, Best: 0.9877
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 2.384030342102051 | KNN Loss: 2.3750882148742676 | CLS Loss: 0.008942064829170704
Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 2.412233352661133 | KNN Loss: 2.3902266025543213 | CLS Loss: 0.02200683020055294
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 2.402087688446045 | KNN Loss: 2.399810552597046 | CLS Loss: 0.0022772469092160463
Epoch 166 / 200 | iteration 30 / 171 | Total Loss: 2.4090025424957275 | KNN Loss: 2.399564504623413 | CLS Loss: 0.009438032284379005
Epoch 166 / 200 | iteration 40 / 171 | Total Loss: 2.397714614868164 | KNN Loss: 2.393087148666382 | CLS Loss: 0.004627539776265621
Epoch 166 / 200 | iteration 50 / 171 | Total Loss: 2.3911352157592773 | KNN Loss: 2.376361846923828 | CLS Loss: 0.01477347407490015
Epoch 166 / 200 | iteration 60 / 171 | Total Loss: 2.506476640701294 | KNN Loss: 2.477640390396118 | CLS Loss: 0.028836244717240334
Epoch

Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 2.4103851318359375 | KNN Loss: 2.3778553009033203 | CLS Loss: 0.03252982348203659
Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 2.4260482788085938 | KNN Loss: 2.4181010723114014 | CLS Loss: 0.007947313599288464
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 2.4116547107696533 | KNN Loss: 2.4069337844848633 | CLS Loss: 0.0047210450284183025
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 2.4077038764953613 | KNN Loss: 2.402024030685425 | CLS Loss: 0.005679942201822996
Epoch 169 / 200 | iteration 100 / 171 | Total Loss: 2.414956569671631 | KNN Loss: 2.4110352993011475 | CLS Loss: 0.003921192605048418
Epoch 169 / 200 | iteration 110 / 171 | Total Loss: 2.3735311031341553 | KNN Loss: 2.3605504035949707 | CLS Loss: 0.012980758212506771
Epoch 169 / 200 | iteration 120 / 171 | Total Loss: 2.442890167236328 | KNN Loss: 2.41556453704834 | CLS Loss: 0.027325674891471863
Epoch 169 / 200 | iteration 130 / 171 | Total Loss: 2.4164586067

Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 2.4437978267669678 | KNN Loss: 2.432168483734131 | CLS Loss: 0.011629343964159489
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 2.4632632732391357 | KNN Loss: 2.4509127140045166 | CLS Loss: 0.012350494042038918
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 2.4422378540039062 | KNN Loss: 2.4264841079711914 | CLS Loss: 0.015753639861941338
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 2.392085313796997 | KNN Loss: 2.367666482925415 | CLS Loss: 0.024418821558356285
Epoch 172 / 200 | iteration 160 / 171 | Total Loss: 2.3966619968414307 | KNN Loss: 2.374375343322754 | CLS Loss: 0.022286580875515938
Epoch 172 / 200 | iteration 170 / 171 | Total Loss: 2.4043161869049072 | KNN Loss: 2.399260997772217 | CLS Loss: 0.005055295769125223
Epoch: 172, Loss: 2.4163, Train: 0.9966, Valid: 0.9866, Best: 0.9877
Epoch 173 / 200 | iteration 0 / 171 | Total Loss: 2.4368276596069336 | KNN Loss: 2.4200804233551025 | CLS Loss: 0.016747314482

Epoch: 175, Loss: 2.4176, Train: 0.9962, Valid: 0.9868, Best: 0.9877
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 2.3878118991851807 | KNN Loss: 2.385547161102295 | CLS Loss: 0.0022646316792815924
Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 2.389198064804077 | KNN Loss: 2.3673198223114014 | CLS Loss: 0.021878231316804886
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 2.408796787261963 | KNN Loss: 2.3899219036102295 | CLS Loss: 0.01887483336031437
Epoch 176 / 200 | iteration 30 / 171 | Total Loss: 2.4296250343322754 | KNN Loss: 2.417938709259033 | CLS Loss: 0.011686262674629688
Epoch 176 / 200 | iteration 40 / 171 | Total Loss: 2.3954977989196777 | KNN Loss: 2.3905324935913086 | CLS Loss: 0.0049652280285954475
Epoch 176 / 200 | iteration 50 / 171 | Total Loss: 2.396230697631836 | KNN Loss: 2.3734073638916016 | CLS Loss: 0.02282334864139557
Epoch 176 / 200 | iteration 60 / 171 | Total Loss: 2.411369562149048 | KNN Loss: 2.401437520980835 | CLS Loss: 0.009932056069374084
E

Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 2.429917812347412 | KNN Loss: 2.4188618659973145 | CLS Loss: 0.0110558420419693
Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 2.400336742401123 | KNN Loss: 2.3888142108917236 | CLS Loss: 0.011522614397108555
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 2.4450461864471436 | KNN Loss: 2.439887762069702 | CLS Loss: 0.005158418323844671
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 2.4267995357513428 | KNN Loss: 2.398348808288574 | CLS Loss: 0.028450623154640198
Epoch 179 / 200 | iteration 100 / 171 | Total Loss: 2.433804750442505 | KNN Loss: 2.4274840354919434 | CLS Loss: 0.006320657674223185
Epoch 179 / 200 | iteration 110 / 171 | Total Loss: 2.377309560775757 | KNN Loss: 2.3728644847869873 | CLS Loss: 0.004445097874850035
Epoch 179 / 200 | iteration 120 / 171 | Total Loss: 2.389059066772461 | KNN Loss: 2.381943464279175 | CLS Loss: 0.007115675136446953
Epoch 179 / 200 | iteration 130 / 171 | Total Loss: 2.432914495468139

Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 2.4116954803466797 | KNN Loss: 2.389427900314331 | CLS Loss: 0.022267693653702736
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 2.3778882026672363 | KNN Loss: 2.3729352951049805 | CLS Loss: 0.004952860996127129
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 2.405331611633301 | KNN Loss: 2.4006457328796387 | CLS Loss: 0.00468584056943655
Epoch 182 / 200 | iteration 160 / 171 | Total Loss: 2.41795015335083 | KNN Loss: 2.3957574367523193 | CLS Loss: 0.02219267003238201
Epoch 182 / 200 | iteration 170 / 171 | Total Loss: 2.408967971801758 | KNN Loss: 2.3834261894226074 | CLS Loss: 0.025541871786117554
Epoch: 182, Loss: 2.4184, Train: 0.9966, Valid: 0.9853, Best: 0.9877
Epoch 183 / 200 | iteration 0 / 171 | Total Loss: 2.4108808040618896 | KNN Loss: 2.4074294567108154 | CLS Loss: 0.003451260272413492
Epoch 183 / 200 | iteration 10 / 171 | Total Loss: 2.414402484893799 | KNN Loss: 2.406243324279785 | CLS Loss: 0.00815905164927244

Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 2.423741102218628 | KNN Loss: 2.4106104373931885 | CLS Loss: 0.01313056517392397
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 2.4519309997558594 | KNN Loss: 2.440330982208252 | CLS Loss: 0.011600104160606861
Epoch 186 / 200 | iteration 30 / 171 | Total Loss: 2.422361373901367 | KNN Loss: 2.4131765365600586 | CLS Loss: 0.009184829890727997
Epoch 186 / 200 | iteration 40 / 171 | Total Loss: 2.4058241844177246 | KNN Loss: 2.3885395526885986 | CLS Loss: 0.017284715548157692
Epoch 186 / 200 | iteration 50 / 171 | Total Loss: 2.423069953918457 | KNN Loss: 2.4066174030303955 | CLS Loss: 0.016452541574835777
Epoch 186 / 200 | iteration 60 / 171 | Total Loss: 2.419846534729004 | KNN Loss: 2.39336895942688 | CLS Loss: 0.02647750824689865
Epoch 186 / 200 | iteration 70 / 171 | Total Loss: 2.4381065368652344 | KNN Loss: 2.4240856170654297 | CLS Loss: 0.014020973816514015
Epoch 186 / 200 | iteration 80 / 171 | Total Loss: 2.3900530338287354 |

Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 2.4003379344940186 | KNN Loss: 2.3877363204956055 | CLS Loss: 0.012601522728800774
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 2.4133224487304688 | KNN Loss: 2.3927459716796875 | CLS Loss: 0.020576534792780876
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 2.3987808227539062 | KNN Loss: 2.3797974586486816 | CLS Loss: 0.01898333989083767
Epoch 189 / 200 | iteration 100 / 171 | Total Loss: 2.445039987564087 | KNN Loss: 2.4145350456237793 | CLS Loss: 0.03050500713288784
Epoch 189 / 200 | iteration 110 / 171 | Total Loss: 2.433316230773926 | KNN Loss: 2.419560670852661 | CLS Loss: 0.013755575753748417
Epoch 189 / 200 | iteration 120 / 171 | Total Loss: 2.501502275466919 | KNN Loss: 2.4793009757995605 | CLS Loss: 0.02220122143626213
Epoch 189 / 200 | iteration 130 / 171 | Total Loss: 2.442021608352661 | KNN Loss: 2.430443525314331 | CLS Loss: 0.011578013189136982
Epoch 189 / 200 | iteration 140 / 171 | Total Loss: 2.4157562255859

Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 2.4154891967773438 | KNN Loss: 2.40907883644104 | CLS Loss: 0.0064102946780622005
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 2.41257643699646 | KNN Loss: 2.407850503921509 | CLS Loss: 0.004725826904177666
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 2.377790689468384 | KNN Loss: 2.354823589324951 | CLS Loss: 0.02296704612672329
Epoch 192 / 200 | iteration 160 / 171 | Total Loss: 2.392366647720337 | KNN Loss: 2.384329080581665 | CLS Loss: 0.008037499152123928
Epoch 192 / 200 | iteration 170 / 171 | Total Loss: 2.438843250274658 | KNN Loss: 2.4248478412628174 | CLS Loss: 0.01399534847587347
Epoch: 192, Loss: 2.4226, Train: 0.9970, Valid: 0.9871, Best: 0.9877
Epoch 193 / 200 | iteration 0 / 171 | Total Loss: 2.4052398204803467 | KNN Loss: 2.3996529579162598 | CLS Loss: 0.005586802028119564
Epoch 193 / 200 | iteration 10 / 171 | Total Loss: 2.400081157684326 | KNN Loss: 2.3848702907562256 | CLS Loss: 0.015210882760584354
E

Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 2.3891751766204834 | KNN Loss: 2.378584623336792 | CLS Loss: 0.010590544901788235
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 2.4515790939331055 | KNN Loss: 2.446146011352539 | CLS Loss: 0.005433198995888233
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 2.3746628761291504 | KNN Loss: 2.368917942047119 | CLS Loss: 0.005744998808950186
Epoch 196 / 200 | iteration 40 / 171 | Total Loss: 2.405022621154785 | KNN Loss: 2.3930423259735107 | CLS Loss: 0.011980370618402958
Epoch 196 / 200 | iteration 50 / 171 | Total Loss: 2.385594129562378 | KNN Loss: 2.36598801612854 | CLS Loss: 0.01960601843893528
Epoch 196 / 200 | iteration 60 / 171 | Total Loss: 2.4038076400756836 | KNN Loss: 2.398301362991333 | CLS Loss: 0.005506326910108328
Epoch 196 / 200 | iteration 70 / 171 | Total Loss: 2.469566822052002 | KNN Loss: 2.4619991779327393 | CLS Loss: 0.00756765715777874
Epoch 196 / 200 | iteration 80 / 171 | Total Loss: 2.375807523727417 | KN

Epoch 199 / 200 | iteration 70 / 171 | Total Loss: 2.4101948738098145 | KNN Loss: 2.392791748046875 | CLS Loss: 0.017403146252036095
Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 2.3714137077331543 | KNN Loss: 2.3656351566314697 | CLS Loss: 0.0057786256074905396
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 2.4124298095703125 | KNN Loss: 2.402292251586914 | CLS Loss: 0.010137462988495827
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 2.4406585693359375 | KNN Loss: 2.4180727005004883 | CLS Loss: 0.022585898637771606
Epoch 199 / 200 | iteration 110 / 171 | Total Loss: 2.4331226348876953 | KNN Loss: 2.419377326965332 | CLS Loss: 0.013745265081524849
Epoch 199 / 200 | iteration 120 / 171 | Total Loss: 2.4068851470947266 | KNN Loss: 2.3972902297973633 | CLS Loss: 0.009594847448170185
Epoch 199 / 200 | iteration 130 / 171 | Total Loss: 2.4206607341766357 | KNN Loss: 2.404902935028076 | CLS Loss: 0.015757769346237183
Epoch 199 / 200 | iteration 140 / 171 | Total Loss: 2.4278264

In [12]:
test(model, test_data_iter, device)

tensor(0.9856, device='cuda:0')

In [13]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [16]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [18]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.9099629984925311


In [19]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [20]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [21]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [22]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [23]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [24]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [25]:
losses = []
accs = []
sparsity = []

In [26]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 039 | Total loss: 3.270 | Reg loss: 0.007 | Tree loss: 3.270 | Accuracy: 0.046875 | 0.091 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 01 | Batch: 000 / 039 | Total loss: 3.193 | Reg loss: 0.005 | Tree loss: 3.193 | Accuracy: 0.136719 | 0.063 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 02 | Batch: 000 / 039 | Total loss: 3.135 | Reg loss: 0.008 | Tree loss: 3.135 | Accuracy: 0.146484 | 0.062 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 28 | Batch: 000 / 039 | Total loss: 2.743 | Reg loss: 0.023 | Tree loss: 2.743 | Accuracy: 0.216797 | 0.062 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 29 | Batch: 000 / 039 | Total loss: 2.686 | Reg loss: 0.023 | Tree loss: 2.686 | Accuracy: 0.267578 | 0.062 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 30 | Batch: 000 / 039 | Total loss: 2.647 | Reg loss: 0.023 | Tree loss: 2.647 | Accuracy: 0.304688 | 0.062 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 56 | Batch: 000 / 039 | Total loss: 2.454 | Reg loss: 0.024 | Tree loss: 2.454 | Accuracy: 0.283203 | 0.058 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 57 | Batch: 000 / 039 | Total loss: 2.436 | Reg loss: 0.024 | Tree loss: 2.436 | Accuracy: 0.296875 | 0.057 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 58 | Batch: 000 / 039 | Total loss: 2.474 | Reg loss: 0.024 | Tree loss: 2.474 | Accuracy: 0.285156 | 0.057 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 84 | Batch: 000 / 039 | Total loss: 2.435 | Reg loss: 0.025 | Tree loss: 2.435 | Accuracy: 0.294922 | 0.054 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 85 | Batch: 000 / 039 | Total loss: 2.439 | Reg loss: 0.025 | Tree loss: 2.439 | Accuracy: 0.300781 | 0.054 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 86 | Batch: 000 / 039 | Total loss: 2.442 | Reg loss: 0.025 | Tree loss: 2.442 | Accuracy: 0.302734 | 0.054 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 112 | Batch: 000 / 039 | Total loss: 2.385 | Reg loss: 0.025 | Tree loss: 2.385 | Accuracy: 0.294922 | 0.052 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 113 | Batch: 000 / 039 | Total loss: 2.428 | Reg loss: 0.025 | Tree loss: 2.428 | Accuracy: 0.300781 | 0.052 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 114 | Batch: 000 / 039 | Total loss: 2.449 | Reg loss: 0.025 | Tree loss: 2.449 | Accuracy: 0.273438 | 0.052 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 140 | Batch: 000 / 039 | Total loss: 2.385 | Reg loss: 0.025 | Tree loss: 2.385 | Accuracy: 0.328125 | 0.053 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 141 | Batch: 000 / 039 | Total loss: 2.248 | Reg loss: 0.025 | Tree loss: 2.248 | Accuracy: 0.353516 | 0.054 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 142 | Batch: 000 / 039 | Total loss: 2.353 | Reg loss: 0.025 | Tree loss: 2.353 | Accuracy: 0.310547 | 0.054 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 168 | Batch: 000 / 039 | Total loss: 2.284 | Reg loss: 0.026 | Tree loss: 2.284 | Accuracy: 0.322266 | 0.055 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 169 | Batch: 000 / 039 | Total loss: 2.292 | Reg loss: 0.026 | Tree loss: 2.292 | Accuracy: 0.337891 | 0.055 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 170 | Batch: 000 / 039 | Total loss: 2.302 | Reg loss: 0.026 | Tree loss: 2.302 | Accuracy: 0.310547 | 0.055 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 196 | Batch: 000 / 039 | Total loss: 2.238 | Reg loss: 0.027 | Tree loss: 2.238 | Accuracy: 0.320312 | 0.057 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 197 | Batch: 000 / 039 | Total loss: 2.220 | Reg loss: 0.027 | Tree loss: 2.220 | Accuracy: 0.310547 | 0.057 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 198 | Batch: 000 / 039 | Total loss: 2.280 | Reg loss: 0.027 | Tree loss: 2.280 | Accuracy: 0.318359 | 0.057 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 224 | Batch: 000 / 039 | Total loss: 2.154 | Reg loss: 0.028 | Tree loss: 2.154 | Accuracy: 0.400391 | 0.059 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 225 | Batch: 000 / 039 | Total loss: 2.203 | Reg loss: 0.028 | Tree loss: 2.203 | Accuracy: 0.345703 | 0.059 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 226 | Batch: 000 / 039 | Total loss: 2.206 | Reg loss: 0.028 | Tree loss: 2.206 | Accuracy: 0.357422 | 0.059 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 252 | Batch: 000 / 039 | Total loss: 2.203 | Reg loss: 0.028 | Tree loss: 2.203 | Accuracy: 0.363281 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 253 | Batch: 000 / 039 | Total loss: 2.111 | Reg loss: 0.028 | Tree loss: 2.111 | Accuracy: 0.390625 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 254 | Batch: 000 / 039 | Total loss: 2.272 | Reg loss: 0.028 | Tree loss: 2.272 | Accuracy: 0.322266 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 280 | Batch: 000 / 039 | Total loss: 2.139 | Reg loss: 0.028 | Tree loss: 2.139 | Accuracy: 0.396484 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 281 | Batch: 000 / 039 | Total loss: 2.096 | Reg loss: 0.028 | Tree loss: 2.096 | Accuracy: 0.400391 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 282 | Batch: 000 / 039 | Total loss: 2.188 | Reg loss: 0.028 | Tree loss: 2.188 | Accuracy: 0.373047 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 308 | Batch: 000 / 039 | Total loss: 2.106 | Reg loss: 0.027 | Tree loss: 2.106 | Accuracy: 0.400391 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 309 | Batch: 000 / 039 | Total loss: 2.086 | Reg loss: 0.027 | Tree loss: 2.086 | Accuracy: 0.431641 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 310 | Batch: 000 / 039 | Total loss: 2.049 | Reg loss: 0.027 | Tree loss: 2.049 | Accuracy: 0.417969 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 336 | Batch: 000 / 039 | Total loss: 2.083 | Reg loss: 0.027 | Tree loss: 2.083 | Accuracy: 0.423828 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 337 | Batch: 000 / 039 | Total loss: 2.132 | Reg loss: 0.027 | Tree loss: 2.132 | Accuracy: 0.414062 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 338 | Batch: 000 / 039 | Total loss: 2.030 | Reg loss: 0.027 | Tree loss: 2.030 | Accuracy: 0.472656 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 364 | Batch: 000 / 039 | Total loss: 2.151 | Reg loss: 0.027 | Tree loss: 2.151 | Accuracy: 0.394531 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 365 | Batch: 000 / 039 | Total loss: 2.134 | Reg loss: 0.027 | Tree loss: 2.134 | Accuracy: 0.392578 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 366 | Batch: 000 / 039 | Total loss: 2.060 | Reg loss: 0.027 | Tree loss: 2.060 | Accuracy: 0.431641 | 0.06 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894


Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 392 | Batch: 000 / 039 | Total loss: 2.094 | Reg loss: 0.027 | Tree loss: 2.094 | Accuracy: 0.423828 | 0.059 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 393 | Batch: 000 / 039 | Total loss: 2.139 | Reg loss: 0.027 | Tree loss: 2.139 | Accuracy: 0.408203 | 0.059 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
Epoch: 394 | Batch: 000 / 039 | Total loss: 2.114 | Reg loss: 0.027 | Tree loss: 2.114 | Accuracy: 0.417969 | 0.059 sec/iter
Average sparseness: 0.984042553191489
layer 0: 0.9840425531914894
layer 1: 0.98404255319148

In [27]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [28]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [29]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.222222222222222


# Extract Rules

# Accumulate samples in the leaves

In [30]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 18


In [31]:
method = 'greedy'

In [32]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [35]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

254
335
1302
1420
626
2427
5248
5396
526
2386
Average comprehensibility: 28.333333333333332
std comprehensibility: 8.034647195462632
var comprehensibility: 64.55555555555556
minimum comprehensibility: 4
maximum comprehensibility: 34
