In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
# Experiment configuration.
k = 32  # passed as `k=` to ClassificationKNNLoss
tree_depth = 12  # NOTE(review): unused in the visible cells — confirm it is still needed
batch_size = 512
seed = 0
# Seed torch's RNGs (CPU and CUDA) so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(seed)
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [3]:
# Settings shared by both loaders. Pinned host memory only speeds up copies to a CUDA device.
loader_kwargs = dict(batch_size=batch_size,
                     num_workers=1,
                     pin_memory=str(device).startswith('cuda'))

# drop_last=True keeps every training batch at exactly `batch_size` samples.
# NOTE(review): presumably so the KNN loss always sees a full batch — confirm.
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                              shuffle=True,
                                              drop_last=True,
                                              **loader_kwargs)

# Evaluation accuracy does not depend on sample order, so iterate deterministically.
test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             shuffle=False,
                                             **loader_kwargs)

In [4]:
class ConvBlock(nn.Module):
    """Residual 1-D convolution block followed by max-pooling.

    Two length-preserving convolutions with a skip connection around them,
    a ReLU after the addition, then ``MaxPool1d(kernel_size=5, stride=2)``.
    The pooling maps a length-``L`` sequence to ``(L - 5) // 2 + 1``.

    Args:
        channels: number of input and output channels. They must be equal
            because of the residual addition. Defaults to 32, the value
            previously hard-coded.
        kernel_size: odd convolution kernel size. Padding of
            ``kernel_size // 2`` keeps the length unchanged so ``y + x`` is
            well-defined. Defaults to 5.

    Raises:
        ValueError: if ``kernel_size`` is even. An even kernel would change the
            sequence length and break the residual addition.
    """

    def __init__(self, channels=32, kernel_size=5):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=kernel_size, stride=1, padding=padding)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, L) to (batch, channels, (L - 5) // 2 + 1)."""
        y = self.relu1(self.conv1(x))
        y = self.conv2(y)
        # Residual connection: add the block input back before the final activation.
        y = self.relu2(y + x)
        return self.pool(y)


class ECGModel(nn.Module):
    """1-D CNN classifier producing 5 logits per input sequence.

    The network has these stages:
    - a stem convolution that lifts 1 input channel to 32 channels;
    - five residual ``ConvBlock`` stages, each roughly halving the length;
    - a two-layer MLP head (64 -> 32 -> 5).

    ``fc1`` expects 64 flattened features, i.e. 32 channels x 2 positions
    after the last block. For example, a length-187 input goes
    187 -> 185 -> 91 -> 44 -> 20 -> 8 -> 2. Any length from 159 to 190 ends
    at 2 positions.
    NOTE(review): assumes MITBIH yields sequences in that range — confirm.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        # Registered as block1..block5, in this order, so parameter order and
        # state_dict keys ("block1.conv1.weight", ...) stay unchanged.
        for idx in range(1, 6):
            setattr(self, f'block{idx}', ConvBlock())
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        """Compute class logits for ``x``, presumably of shape (batch, 1, length).

        Args:
            x: input signal tensor. The stem conv requires a single channel.
            return_interm: when True, also return the flattened feature
                vector that feeds ``fc1``.

        Returns:
            ``logits`` of shape (batch, 5), or ``(logits, interm)`` when
            ``return_interm`` is True.
        """
        h = self.conv(x)
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            h = block(h)
        interm = h.flatten(1)
        logits = self.fc2(self.relu(self.fc1(interm)))
        if not return_interm:
            return logits
        return logits, interm

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    The objective is cross-entropy on the logits plus ``knn_crt`` applied to
    the intermediate features from ``model(batch, return_interm=True)``.
    The KNN term is replaced by zero for a batch in two cases:
    - ``knn_crt`` raises ``ValueError``;
    - the KNN term is not finite (inf or NaN).
    In both cases only the classification term is optimised for that batch,
    and a divergent KNN value never reaches the weights.

    NOTE(review): the progress log reads the notebook globals ``log_every``,
    ``epoch`` and ``epochs``. They must be defined before this is called.

    Args:
        model: module whose forward accepts ``return_interm=True`` and returns
            ``(logits, features)``.
        loader: iterable of ``(batch, target)`` pairs; ``len(loader)`` is used.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        float: sum of per-batch total losses divided by ``len(loader)``
        (0.0 for an empty loader).
    """
    model.train()

    total_loss = 0.0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        # F.cross_entropy already reduces to a scalar batch mean.
        cls_loss = F.cross_entropy(outputs, target)
        try:
            knn_loss = knn_crt(interm, target)
        except ValueError:
            knn_loss = None
        if knn_loss is None or not torch.isfinite(knn_loss):
            # Float zero on the same device/dtype so the sum stays well-typed.
            knn_loss = torch.zeros((), device=cls_loss.device, dtype=cls_loss.dtype)
        loss = cls_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    num_batches = len(loader)
    return total_loss / num_batches if num_batches else 0.0

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
# Optimisation settings. `train` also reads `epochs` and `log_every` as globals.
epochs, lr, log_every = 200, 1e-3, 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Count every scalar weight (trainable or not) in the network.
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
best_state = None  # copy of the weights that achieved best_valid_acc
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    # `train` reads the globals `epoch` / `epochs` / `log_every` for its progress log.
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)
    # float(...) so the history lists hold plain numbers, not (possibly CUDA) tensors.
    train_acc = float(test(model, train_data_iter, device))
    train_accs.append(train_acc)
    # NOTE(review): "valid" is measured on the test split, so selecting the best epoch
    # on it gives an optimistic estimate; consider a separate validation split.
    valid_acc = float(test(model, test_data_iter, device))
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.318262577056885 | KNN Loss: 5.820803642272949 | CLS Loss: 1.497458815574646
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.080780982971191 | KNN Loss: 4.2764387130737305 | CLS Loss: 0.8043422698974609
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 4.642612934112549 | KNN Loss: 3.95671010017395 | CLS Loss: 0.6859029531478882
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 4.5135345458984375 | KNN Loss: 3.8869571685791016 | CLS Loss: 0.6265771985054016
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 4.441784858703613 | KNN Loss: 3.8211803436279297 | CLS Loss: 0.6206047534942627
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 4.345569133758545 | KNN Loss: 3.8533806800842285 | CLS Loss: 0.49218854308128357
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.419981002807617 | KNN Loss: 3.8789937496185303 | CLS Loss: 0.5409873127937317
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.247969150543213 | KNN Loss: 3.846813678741455 | C

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 3.79848575592041 | KNN Loss: 3.689203977584839 | CLS Loss: 0.10928184539079666
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 3.8146355152130127 | KNN Loss: 3.6692540645599365 | CLS Loss: 0.14538134634494781
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 3.789215564727783 | KNN Loss: 3.7060489654541016 | CLS Loss: 0.08316660672426224
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 3.8082547187805176 | KNN Loss: 3.6728105545043945 | CLS Loss: 0.13544407486915588
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 3.847170114517212 | KNN Loss: 3.7044966220855713 | CLS Loss: 0.1426735818386078
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 3.7556545734405518 | KNN Loss: 3.6470677852630615 | CLS Loss: 0.10858669877052307
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 3.7774887084960938 | KNN Loss: 3.6704304218292236 | CLS Loss: 0.10705823451280594
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 3.7957208156585693 | KNN Loss: 3.7

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 3.7533836364746094 | KNN Loss: 3.643889904022217 | CLS Loss: 0.1094936728477478
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 3.759463310241699 | KNN Loss: 3.6787896156311035 | CLS Loss: 0.08067367225885391
Epoch: 007, Loss: 3.7559, Train: 0.9793, Valid: 0.9760, Best: 0.9760
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 3.738450288772583 | KNN Loss: 3.6555938720703125 | CLS Loss: 0.08285630494356155
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 3.70348858833313 | KNN Loss: 3.6520161628723145 | CLS Loss: 0.05147242546081543
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 3.7704567909240723 | KNN Loss: 3.6611456871032715 | CLS Loss: 0.10931115597486496
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 3.7452759742736816 | KNN Loss: 3.672139883041382 | CLS Loss: 0.07313603907823563
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 3.708233594894409 | KNN Loss: 3.630563497543335 | CLS Loss: 0.07767011225223541
Epoch 8 / 200 | iterat

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 3.7627627849578857 | KNN Loss: 3.675680160522461 | CLS Loss: 0.08708266913890839
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 3.7342638969421387 | KNN Loss: 3.6771769523620605 | CLS Loss: 0.05708690732717514
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 3.7306833267211914 | KNN Loss: 3.6280019283294678 | CLS Loss: 0.1026814728975296
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 3.719848871231079 | KNN Loss: 3.638535737991333 | CLS Loss: 0.0813131183385849
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 3.70465350151062 | KNN Loss: 3.6479318141937256 | CLS Loss: 0.05672163888812065
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 3.710153579711914 | KNN Loss: 3.6282947063446045 | CLS Loss: 0.08185885101556778
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 3.698760509490967 | KNN Loss: 3.644787073135376 | CLS Loss: 0.05397352576255798
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 3.728572368621826 | KNN Loss: 3.6

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 3.6565840244293213 | KNN Loss: 3.6200175285339355 | CLS Loss: 0.03656653314828873
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 3.672551155090332 | KNN Loss: 3.6199028491973877 | CLS Loss: 0.052648186683654785
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 3.731665849685669 | KNN Loss: 3.6519761085510254 | CLS Loss: 0.07968978583812714
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 3.680635452270508 | KNN Loss: 3.608729600906372 | CLS Loss: 0.07190591841936111
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 3.754255771636963 | KNN Loss: 3.6525609493255615 | CLS Loss: 0.10169482231140137
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 3.693587064743042 | KNN Loss: 3.6085550785064697 | CLS Loss: 0.0850318968296051
Epoch: 014, Loss: 3.6998, Train: 0.9820, Valid: 0.9783, Best: 0.9814
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 3.724168300628662 | KNN Loss: 3.647566556930542 | CLS Loss: 0.07660174369812012
Epoch 15 /

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 3.638951063156128 | KNN Loss: 3.6093945503234863 | CLS Loss: 0.02955649234354496
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 3.6590747833251953 | KNN Loss: 3.617400646209717 | CLS Loss: 0.04167410731315613
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 3.6594011783599854 | KNN Loss: 3.617701768875122 | CLS Loss: 0.041699472814798355
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 3.6963508129119873 | KNN Loss: 3.63618540763855 | CLS Loss: 0.06016538292169571
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 3.6822354793548584 | KNN Loss: 3.6098101139068604 | CLS Loss: 0.07242526113986969
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 3.6879372596740723 | KNN Loss: 3.6214959621429443 | CLS Loss: 0.06644134223461151
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 3.671886444091797 | KNN Loss: 3.611564874649048 | CLS Loss: 0.060321662575006485
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 3.6674792766571045 | KNN Loss: 

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 3.6557419300079346 | KNN Loss: 3.6284549236297607 | CLS Loss: 0.02728702314198017
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 3.6604034900665283 | KNN Loss: 3.6110756397247314 | CLS Loss: 0.04932793974876404
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 3.6701767444610596 | KNN Loss: 3.587149143218994 | CLS Loss: 0.08302749693393707
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 3.6688661575317383 | KNN Loss: 3.635118007659912 | CLS Loss: 0.0337480790913105
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 3.712132215499878 | KNN Loss: 3.6369788646698 | CLS Loss: 0.07515344023704529
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 3.647012948989868 | KNN Loss: 3.5717649459838867 | CLS Loss: 0.07524795085191727
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 3.6938745975494385 | KNN Loss: 3.630488395690918 | CLS Loss: 0.06338613480329514
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 3.6343488693237305 | KNN Loss

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 3.6566600799560547 | KNN Loss: 3.6270554065704346 | CLS Loss: 0.02960479073226452
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 3.639728307723999 | KNN Loss: 3.614372968673706 | CLS Loss: 0.025355232879519463
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 3.651150941848755 | KNN Loss: 3.602566719055176 | CLS Loss: 0.048584215342998505
Epoch: 024, Loss: 3.6587, Train: 0.9890, Valid: 0.9834, Best: 0.9845
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 3.635643243789673 | KNN Loss: 3.615349292755127 | CLS Loss: 0.020293917506933212
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 3.6388304233551025 | KNN Loss: 3.6025454998016357 | CLS Loss: 0.0362849161028862
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 3.6316399574279785 | KNN Loss: 3.587451219558716 | CLS Loss: 0.04418875649571419
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 3.6823229789733887 | KNN Loss: 3.646953582763672 | CLS Loss: 0.03536949306726456
Epoch 25 / 

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 3.6517903804779053 | KNN Loss: 3.6172561645507812 | CLS Loss: 0.03453418239951134
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 3.6446778774261475 | KNN Loss: 3.618469715118408 | CLS Loss: 0.026208214461803436
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 3.6588780879974365 | KNN Loss: 3.629643678665161 | CLS Loss: 0.0292343832552433
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 3.671288251876831 | KNN Loss: 3.6366539001464844 | CLS Loss: 0.03463425859808922
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 3.6374566555023193 | KNN Loss: 3.608553886413574 | CLS Loss: 0.02890276536345482
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 3.6206307411193848 | KNN Loss: 3.610518217086792 | CLS Loss: 0.010112565942108631
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 3.6882412433624268 | KNN Loss: 3.6393935680389404 | CLS Loss: 0.04884764179587364
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 3.6753103733062744 | KNN Los

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 3.686310291290283 | KNN Loss: 3.6592633724212646 | CLS Loss: 0.027047021314501762
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 3.6658778190612793 | KNN Loss: 3.618687391281128 | CLS Loss: 0.04719039425253868
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 3.7259602546691895 | KNN Loss: 3.6905179023742676 | CLS Loss: 0.0354422889649868
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 3.6128954887390137 | KNN Loss: 3.5914816856384277 | CLS Loss: 0.021413851529359818
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 3.659183979034424 | KNN Loss: 3.6283867359161377 | CLS Loss: 0.030797215178608894
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 3.671973943710327 | KNN Loss: 3.6387782096862793 | CLS Loss: 0.033195625990629196
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 3.6727097034454346 | KNN Loss: 3.6313233375549316 | CLS Loss: 0.04138626158237457
Epoch: 031, Loss: 3.6592, Train: 0.9915, Valid: 0.9844, Best: 0.9849
E

Epoch: 034, Loss: 3.6557, Train: 0.9927, Valid: 0.9858, Best: 0.9858
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 3.6351327896118164 | KNN Loss: 3.611428737640381 | CLS Loss: 0.023704133927822113
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 3.6505351066589355 | KNN Loss: 3.6283681392669678 | CLS Loss: 0.02216694876551628
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 3.664210081100464 | KNN Loss: 3.624635934829712 | CLS Loss: 0.03957405686378479
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 3.6433396339416504 | KNN Loss: 3.614870071411133 | CLS Loss: 0.02846948616206646
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 3.6094255447387695 | KNN Loss: 3.589462995529175 | CLS Loss: 0.01996249333024025
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 3.6713778972625732 | KNN Loss: 3.6424942016601562 | CLS Loss: 0.02888377010822296
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 3.6588034629821777 | KNN Loss: 3.6262643337249756 | CLS Loss: 0.03253912553191185
Epoch 35 / 2

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 3.6182613372802734 | KNN Loss: 3.594242572784424 | CLS Loss: 0.024018676951527596
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 3.6807491779327393 | KNN Loss: 3.634270191192627 | CLS Loss: 0.04647887125611305
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 3.674145460128784 | KNN Loss: 3.62239408493042 | CLS Loss: 0.05175129324197769
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 3.6495323181152344 | KNN Loss: 3.6389074325561523 | CLS Loss: 0.010624919086694717
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 3.6816635131835938 | KNN Loss: 3.652076005935669 | CLS Loss: 0.02958747372031212
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 3.6377005577087402 | KNN Loss: 3.6228861808776855 | CLS Loss: 0.014814487658441067
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 3.7147233486175537 | KNN Loss: 3.680055856704712 | CLS Loss: 0.03466740995645523
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 3.7258317470550537 | KNN 

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 3.7075746059417725 | KNN Loss: 3.659881353378296 | CLS Loss: 0.0476931557059288
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 3.6322319507598877 | KNN Loss: 3.607978582382202 | CLS Loss: 0.02425335720181465
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 3.649883508682251 | KNN Loss: 3.6193301677703857 | CLS Loss: 0.030553355813026428
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 3.6559553146362305 | KNN Loss: 3.6211116313934326 | CLS Loss: 0.034843720495700836
Epoch: 041, Loss: 3.6586, Train: 0.9932, Valid: 0.9850, Best: 0.9863
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 3.6240220069885254 | KNN Loss: 3.605971574783325 | CLS Loss: 0.018050508573651314
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 3.702122688293457 | KNN Loss: 3.6732213497161865 | CLS Loss: 0.02890133112668991
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 3.659257411956787 | KNN Loss: 3.6155362129211426 | CLS Loss: 0.04372125118970871
Epoch 42

Epoch 45 / 200 | iteration 20 / 171 | Total Loss: 3.606844186782837 | KNN Loss: 3.5996198654174805 | CLS Loss: 0.00722424266859889
Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 3.646623373031616 | KNN Loss: 3.611558198928833 | CLS Loss: 0.035065069794654846
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 3.631519079208374 | KNN Loss: 3.6012954711914062 | CLS Loss: 0.030223648995161057
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 3.649301767349243 | KNN Loss: 3.6111438274383545 | CLS Loss: 0.038157861679792404
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 3.6337218284606934 | KNN Loss: 3.6242470741271973 | CLS Loss: 0.009474842809140682
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 3.592681646347046 | KNN Loss: 3.579632520675659 | CLS Loss: 0.013049175031483173
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 3.647430181503296 | KNN Loss: 3.582977533340454 | CLS Loss: 0.0644526258111
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 3.6700057983398438 | KNN Loss: 3.6

Epoch 48 / 200 | iteration 90 / 171 | Total Loss: 3.63954496383667 | KNN Loss: 3.621581792831421 | CLS Loss: 0.017963232472538948
Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 3.6340675354003906 | KNN Loss: 3.605264663696289 | CLS Loss: 0.028802908957004547
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 3.63861346244812 | KNN Loss: 3.6154401302337646 | CLS Loss: 0.023173406720161438
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 3.679769277572632 | KNN Loss: 3.641517400741577 | CLS Loss: 0.038251809775829315
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 3.6917829513549805 | KNN Loss: 3.675417423248291 | CLS Loss: 0.016365446150302887
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 3.6352436542510986 | KNN Loss: 3.6241824626922607 | CLS Loss: 0.01106125395745039
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 3.627027750015259 | KNN Loss: 3.5985138416290283 | CLS Loss: 0.028513982892036438
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 3.662531614303589 | KNN

Epoch 51 / 200 | iteration 160 / 171 | Total Loss: 3.706047534942627 | KNN Loss: 3.6429874897003174 | CLS Loss: 0.06306011229753494
Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 3.6699507236480713 | KNN Loss: 3.644773244857788 | CLS Loss: 0.025177521631121635
Epoch: 051, Loss: 3.6483, Train: 0.9939, Valid: 0.9865, Best: 0.9873
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 3.6316661834716797 | KNN Loss: 3.607577323913574 | CLS Loss: 0.024088917300105095
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 3.6518542766571045 | KNN Loss: 3.618896007537842 | CLS Loss: 0.03295822814106941
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 3.6706149578094482 | KNN Loss: 3.6519784927368164 | CLS Loss: 0.018636571243405342
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 3.651196241378784 | KNN Loss: 3.639460325241089 | CLS Loss: 0.011735948733985424
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 3.6370701789855957 | KNN Loss: 3.6106810569763184 | CLS Loss: 0.02638903819024563
Epoch 52

Epoch 55 / 200 | iteration 50 / 171 | Total Loss: 3.6613314151763916 | KNN Loss: 3.6292049884796143 | CLS Loss: 0.03212636336684227
Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 3.643455743789673 | KNN Loss: 3.606968641281128 | CLS Loss: 0.03648701310157776
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 3.6610770225524902 | KNN Loss: 3.632298231124878 | CLS Loss: 0.028778746724128723
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 3.6350581645965576 | KNN Loss: 3.5970938205718994 | CLS Loss: 0.03796426206827164
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 3.617839813232422 | KNN Loss: 3.6127371788024902 | CLS Loss: 0.005102622788399458
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 3.666402578353882 | KNN Loss: 3.642460584640503 | CLS Loss: 0.023941950872540474
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 3.65158748626709 | KNN Loss: 3.626426935195923 | CLS Loss: 0.025160498917102814
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 3.663989543914795 | KNN Loss

Epoch 58 / 200 | iteration 120 / 171 | Total Loss: 3.7090811729431152 | KNN Loss: 3.650923252105713 | CLS Loss: 0.05815788730978966
Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 3.6430752277374268 | KNN Loss: 3.5913162231445312 | CLS Loss: 0.05175893008708954
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 3.683300733566284 | KNN Loss: 3.6716830730438232 | CLS Loss: 0.011617613025009632
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 3.622406244277954 | KNN Loss: 3.589998960494995 | CLS Loss: 0.0324072502553463
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 3.6579744815826416 | KNN Loss: 3.6300365924835205 | CLS Loss: 0.02793784812092781
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 3.6593587398529053 | KNN Loss: 3.62252140045166 | CLS Loss: 0.03683742508292198
Epoch: 058, Loss: 3.6498, Train: 0.9943, Valid: 0.9853, Best: 0.9873
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 3.6590826511383057 | KNN Loss: 3.6548049449920654 | CLS Loss: 0.004277745261788368
Epoch 5

Epoch 62 / 200 | iteration 0 / 171 | Total Loss: 3.6650583744049072 | KNN Loss: 3.6297364234924316 | CLS Loss: 0.03532198444008827
Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 3.6379528045654297 | KNN Loss: 3.6248228549957275 | CLS Loss: 0.01312994584441185
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 3.6866202354431152 | KNN Loss: 3.6764466762542725 | CLS Loss: 0.010173634625971317
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 3.6666083335876465 | KNN Loss: 3.6461474895477295 | CLS Loss: 0.020460914820432663
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 3.6463522911071777 | KNN Loss: 3.62431263923645 | CLS Loss: 0.022039690986275673
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 3.669015884399414 | KNN Loss: 3.642732620239258 | CLS Loss: 0.026283256709575653
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 3.653456926345825 | KNN Loss: 3.638542413711548 | CLS Loss: 0.014914433471858501
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 3.63602876663208 | KNN Loss:

Epoch 65 / 200 | iteration 70 / 171 | Total Loss: 3.659623384475708 | KNN Loss: 3.622908353805542 | CLS Loss: 0.036715004593133926
Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 3.6287477016448975 | KNN Loss: 3.62618350982666 | CLS Loss: 0.0025640816893428564
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 3.6369149684906006 | KNN Loss: 3.6267123222351074 | CLS Loss: 0.010202603414654732
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 3.682544469833374 | KNN Loss: 3.6554970741271973 | CLS Loss: 0.027047472074627876
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 3.6456539630889893 | KNN Loss: 3.6307246685028076 | CLS Loss: 0.014929354190826416
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 3.612827777862549 | KNN Loss: 3.5889060497283936 | CLS Loss: 0.023921679705381393
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 3.6711368560791016 | KNN Loss: 3.6494579315185547 | CLS Loss: 0.02167890965938568
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 3.622426748275757 | 

Epoch 68 / 200 | iteration 140 / 171 | Total Loss: 3.6116585731506348 | KNN Loss: 3.601923942565918 | CLS Loss: 0.009734644554555416
Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 3.6132140159606934 | KNN Loss: 3.599552631378174 | CLS Loss: 0.013661324977874756
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 3.650015115737915 | KNN Loss: 3.6417548656463623 | CLS Loss: 0.00826026126742363
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 3.627016067504883 | KNN Loss: 3.5881052017211914 | CLS Loss: 0.038910817354917526
Epoch: 068, Loss: 3.6334, Train: 0.9952, Valid: 0.9863, Best: 0.9873
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 3.6023855209350586 | KNN Loss: 3.5903031826019287 | CLS Loss: 0.01208241656422615
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 3.6404638290405273 | KNN Loss: 3.6167640686035156 | CLS Loss: 0.023699862882494926
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 3.6384425163269043 | KNN Loss: 3.619255781173706 | CLS Loss: 0.01918662153184414
Epoch

Epoch 72 / 200 | iteration 20 / 171 | Total Loss: 3.6078531742095947 | KNN Loss: 3.6015708446502686 | CLS Loss: 0.006282305344939232
Epoch 72 / 200 | iteration 30 / 171 | Total Loss: 3.6265857219696045 | KNN Loss: 3.605905294418335 | CLS Loss: 0.020680485293269157
Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 3.612851142883301 | KNN Loss: 3.6102089881896973 | CLS Loss: 0.002642161911353469
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 3.6367785930633545 | KNN Loss: 3.6213433742523193 | CLS Loss: 0.015435220673680305
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 3.6311161518096924 | KNN Loss: 3.6260986328125 | CLS Loss: 0.005017537157982588
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 3.6116833686828613 | KNN Loss: 3.606553792953491 | CLS Loss: 0.0051295701414346695
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 3.637434959411621 | KNN Loss: 3.6163313388824463 | CLS Loss: 0.02110353298485279
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 3.705763816833496 | KNN Lo

Epoch 75 / 200 | iteration 90 / 171 | Total Loss: 3.6432509422302246 | KNN Loss: 3.623993158340454 | CLS Loss: 0.019257858395576477
Epoch 75 / 200 | iteration 100 / 171 | Total Loss: 3.6497132778167725 | KNN Loss: 3.6173081398010254 | CLS Loss: 0.03240522742271423
Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 3.657618284225464 | KNN Loss: 3.643691301345825 | CLS Loss: 0.01392704900354147
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 3.587270498275757 | KNN Loss: 3.5662460327148438 | CLS Loss: 0.021024387329816818
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 3.6258833408355713 | KNN Loss: 3.609583616256714 | CLS Loss: 0.016299722716212273
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 3.6030611991882324 | KNN Loss: 3.591930389404297 | CLS Loss: 0.01113081257790327
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 3.6527023315429688 | KNN Loss: 3.6281728744506836 | CLS Loss: 0.024529486894607544
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 3.6275405883789062 | 

Epoch 78 / 200 | iteration 160 / 171 | Total Loss: 3.608360767364502 | KNN Loss: 3.592115640640259 | CLS Loss: 0.016245054081082344
Epoch 78 / 200 | iteration 170 / 171 | Total Loss: 3.591012954711914 | KNN Loss: 3.579185962677002 | CLS Loss: 0.011826964095234871
Epoch: 078, Loss: 3.6246, Train: 0.9960, Valid: 0.9869, Best: 0.9881
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 3.603792190551758 | KNN Loss: 3.5797924995422363 | CLS Loss: 0.023999638855457306
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 3.6212425231933594 | KNN Loss: 3.6008007526397705 | CLS Loss: 0.020441750064492226
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 3.6015663146972656 | KNN Loss: 3.5890729427337646 | CLS Loss: 0.012493393383920193
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 3.623159885406494 | KNN Loss: 3.614776611328125 | CLS Loss: 0.008383376523852348
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 3.6364541053771973 | KNN Loss: 3.616581678390503 | CLS Loss: 0.019872449338436127
Epoch 7

Epoch 82 / 200 | iteration 40 / 171 | Total Loss: 3.590385675430298 | KNN Loss: 3.5843987464904785 | CLS Loss: 0.005986956879496574
Epoch 82 / 200 | iteration 50 / 171 | Total Loss: 3.60817813873291 | KNN Loss: 3.6028826236724854 | CLS Loss: 0.005295445676892996
Epoch 82 / 200 | iteration 60 / 171 | Total Loss: 3.6131114959716797 | KNN Loss: 3.598600149154663 | CLS Loss: 0.014511308632791042
Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 3.6079277992248535 | KNN Loss: 3.59498929977417 | CLS Loss: 0.012938479892909527
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 3.6848158836364746 | KNN Loss: 3.6809210777282715 | CLS Loss: 0.0038948534056544304
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 3.598494529724121 | KNN Loss: 3.5889899730682373 | CLS Loss: 0.00950460322201252
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 3.628730058670044 | KNN Loss: 3.612852096557617 | CLS Loss: 0.01587800495326519
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 3.6119167804718018 | KNN Lo

Epoch 85 / 200 | iteration 110 / 171 | Total Loss: 3.6526379585266113 | KNN Loss: 3.6186540126800537 | CLS Loss: 0.03398396819829941
Epoch 85 / 200 | iteration 120 / 171 | Total Loss: 3.6158230304718018 | KNN Loss: 3.605956554412842 | CLS Loss: 0.009866478852927685
Epoch 85 / 200 | iteration 130 / 171 | Total Loss: 3.6284825801849365 | KNN Loss: 3.599581003189087 | CLS Loss: 0.028901543468236923
Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 3.6334500312805176 | KNN Loss: 3.626453161239624 | CLS Loss: 0.006996871437877417
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 3.5954108238220215 | KNN Loss: 3.577103614807129 | CLS Loss: 0.018307199701666832
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 3.5969326496124268 | KNN Loss: 3.5808956623077393 | CLS Loss: 0.01603698544204235
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 3.618598461151123 | KNN Loss: 3.611574411392212 | CLS Loss: 0.007024068851023912
Epoch: 085, Loss: 3.6276, Train: 0.9963, Valid: 0.9871, Best: 0.9881
E

Epoch: 088, Loss: 3.6282, Train: 0.9950, Valid: 0.9861, Best: 0.9881
Epoch 89 / 200 | iteration 0 / 171 | Total Loss: 3.6097090244293213 | KNN Loss: 3.5991828441619873 | CLS Loss: 0.01052629854530096
Epoch 89 / 200 | iteration 10 / 171 | Total Loss: 3.61694073677063 | KNN Loss: 3.6133177280426025 | CLS Loss: 0.0036229512188583612
Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 3.6056883335113525 | KNN Loss: 3.5985605716705322 | CLS Loss: 0.00712782796472311
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 3.646388530731201 | KNN Loss: 3.6406948566436768 | CLS Loss: 0.0056937141343951225
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 3.5958356857299805 | KNN Loss: 3.5877678394317627 | CLS Loss: 0.008067801594734192
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 3.6543004512786865 | KNN Loss: 3.641294240951538 | CLS Loss: 0.013006219640374184
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 3.637173652648926 | KNN Loss: 3.615574359893799 | CLS Loss: 0.02159930393099785
Epoch 89

Epoch 92 / 200 | iteration 60 / 171 | Total Loss: 3.6299216747283936 | KNN Loss: 3.6202521324157715 | CLS Loss: 0.009669617749750614
Epoch 92 / 200 | iteration 70 / 171 | Total Loss: 3.604034900665283 | KNN Loss: 3.5838005542755127 | CLS Loss: 0.020234448835253716
Epoch 92 / 200 | iteration 80 / 171 | Total Loss: 3.6052756309509277 | KNN Loss: 3.6022756099700928 | CLS Loss: 0.002999914577230811
Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 3.6048531532287598 | KNN Loss: 3.6022427082061768 | CLS Loss: 0.0026104506105184555
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 3.6874210834503174 | KNN Loss: 3.6572487354278564 | CLS Loss: 0.03017231449484825
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 3.616624355316162 | KNN Loss: 3.601573944091797 | CLS Loss: 0.015050292015075684
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 3.612304449081421 | KNN Loss: 3.5936269760131836 | CLS Loss: 0.01867748610675335
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 3.6032867431640625 |

Epoch 95 / 200 | iteration 130 / 171 | Total Loss: 3.6031577587127686 | KNN Loss: 3.5908613204956055 | CLS Loss: 0.012296427972614765
Epoch 95 / 200 | iteration 140 / 171 | Total Loss: 3.6065027713775635 | KNN Loss: 3.599504232406616 | CLS Loss: 0.006998507771641016
Epoch 95 / 200 | iteration 150 / 171 | Total Loss: 3.6330788135528564 | KNN Loss: 3.6258912086486816 | CLS Loss: 0.007187516428530216
Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 3.60300874710083 | KNN Loss: 3.5895142555236816 | CLS Loss: 0.013494417071342468
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 3.6084210872650146 | KNN Loss: 3.607045888900757 | CLS Loss: 0.0013751289807260036
Epoch: 095, Loss: 3.6249, Train: 0.9965, Valid: 0.9871, Best: 0.9881
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 3.626405715942383 | KNN Loss: 3.6208367347717285 | CLS Loss: 0.005568910855799913
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 3.7036592960357666 | KNN Loss: 3.6880064010620117 | CLS Loss: 0.015652917325496674


Epoch 99 / 200 | iteration 10 / 171 | Total Loss: 3.607104778289795 | KNN Loss: 3.604343891143799 | CLS Loss: 0.0027609674725681543
Epoch 99 / 200 | iteration 20 / 171 | Total Loss: 3.593759536743164 | KNN Loss: 3.586265802383423 | CLS Loss: 0.00749370688572526
Epoch 99 / 200 | iteration 30 / 171 | Total Loss: 3.6244349479675293 | KNN Loss: 3.6207594871520996 | CLS Loss: 0.0036754265893250704
Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 3.600209951400757 | KNN Loss: 3.5927391052246094 | CLS Loss: 0.007470823358744383
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 3.6026833057403564 | KNN Loss: 3.5829503536224365 | CLS Loss: 0.01973291113972664
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 3.58573842048645 | KNN Loss: 3.5762736797332764 | CLS Loss: 0.00946483388543129
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 3.565946340560913 | KNN Loss: 3.5513689517974854 | CLS Loss: 0.014577271416783333
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 3.622305393218994 | KNN Loss

Epoch 102 / 200 | iteration 80 / 171 | Total Loss: 3.6054251194000244 | KNN Loss: 3.601679801940918 | CLS Loss: 0.0037453973200172186
Epoch 102 / 200 | iteration 90 / 171 | Total Loss: 3.6136953830718994 | KNN Loss: 3.6123852729797363 | CLS Loss: 0.0013101596850901842
Epoch 102 / 200 | iteration 100 / 171 | Total Loss: 3.6126914024353027 | KNN Loss: 3.5850138664245605 | CLS Loss: 0.027677452191710472
Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 3.6270532608032227 | KNN Loss: 3.6112849712371826 | CLS Loss: 0.015768174082040787
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 3.595400094985962 | KNN Loss: 3.5868515968322754 | CLS Loss: 0.008548563346266747
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 3.588284730911255 | KNN Loss: 3.5782132148742676 | CLS Loss: 0.010071457363665104
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 3.601775646209717 | KNN Loss: 3.5766208171844482 | CLS Loss: 0.025154942646622658
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 3.62042

Epoch 105 / 200 | iteration 140 / 171 | Total Loss: 3.621258497238159 | KNN Loss: 3.595087766647339 | CLS Loss: 0.0261706430464983
Epoch 105 / 200 | iteration 150 / 171 | Total Loss: 3.57934308052063 | KNN Loss: 3.578195810317993 | CLS Loss: 0.0011473491322249174
Epoch 105 / 200 | iteration 160 / 171 | Total Loss: 3.6048479080200195 | KNN Loss: 3.5728423595428467 | CLS Loss: 0.03200559318065643
Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 3.5945873260498047 | KNN Loss: 3.5807077884674072 | CLS Loss: 0.01387965027242899
Epoch: 105, Loss: 3.6159, Train: 0.9963, Valid: 0.9868, Best: 0.9881
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 3.6360819339752197 | KNN Loss: 3.6221022605895996 | CLS Loss: 0.013979783281683922
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 3.5761168003082275 | KNN Loss: 3.5731942653656006 | CLS Loss: 0.0029225219041109085
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 3.6563286781311035 | KNN Loss: 3.642249345779419 | CLS Loss: 0.01407929137349128

Epoch 109 / 200 | iteration 20 / 171 | Total Loss: 3.6155145168304443 | KNN Loss: 3.591719627380371 | CLS Loss: 0.02379487454891205
Epoch 109 / 200 | iteration 30 / 171 | Total Loss: 3.654296875 | KNN Loss: 3.631864309310913 | CLS Loss: 0.022432496771216393
Epoch 109 / 200 | iteration 40 / 171 | Total Loss: 3.644212007522583 | KNN Loss: 3.6005733013153076 | CLS Loss: 0.04363860934972763
Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 3.605515956878662 | KNN Loss: 3.590158462524414 | CLS Loss: 0.015357498079538345
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 3.6847920417785645 | KNN Loss: 3.6636672019958496 | CLS Loss: 0.021124867722392082
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 3.7126827239990234 | KNN Loss: 3.6932153701782227 | CLS Loss: 0.01946740597486496
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 3.6285064220428467 | KNN Loss: 3.6124963760375977 | CLS Loss: 0.016010086983442307
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 3.601691484451294 | KNN Lo

Epoch 112 / 200 | iteration 80 / 171 | Total Loss: 3.6035053730010986 | KNN Loss: 3.5965304374694824 | CLS Loss: 0.006974935065954924
Epoch 112 / 200 | iteration 90 / 171 | Total Loss: 3.6135776042938232 | KNN Loss: 3.587245225906372 | CLS Loss: 0.02633245661854744
Epoch 112 / 200 | iteration 100 / 171 | Total Loss: 3.6555404663085938 | KNN Loss: 3.630105972290039 | CLS Loss: 0.02543458715081215
Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 3.645601272583008 | KNN Loss: 3.632251262664795 | CLS Loss: 0.013349977321922779
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 3.5855722427368164 | KNN Loss: 3.571620464324951 | CLS Loss: 0.013951738364994526
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 3.6021888256073 | KNN Loss: 3.5948171615600586 | CLS Loss: 0.0073716104961931705
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 3.590501546859741 | KNN Loss: 3.584798812866211 | CLS Loss: 0.005702849477529526
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 3.60858607292175

Epoch 115 / 200 | iteration 140 / 171 | Total Loss: 3.6236655712127686 | KNN Loss: 3.6154463291168213 | CLS Loss: 0.008219242095947266
Epoch 115 / 200 | iteration 150 / 171 | Total Loss: 3.6411449909210205 | KNN Loss: 3.6250200271606445 | CLS Loss: 0.016124870628118515
Epoch 115 / 200 | iteration 160 / 171 | Total Loss: 3.60744309425354 | KNN Loss: 3.5959525108337402 | CLS Loss: 0.01149052381515503
Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 3.6377522945404053 | KNN Loss: 3.63360333442688 | CLS Loss: 0.0041489670984447
Epoch: 115, Loss: 3.6178, Train: 0.9975, Valid: 0.9877, Best: 0.9881
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 3.6135458946228027 | KNN Loss: 3.6086859703063965 | CLS Loss: 0.00485984468832612
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 3.6337218284606934 | KNN Loss: 3.6279385089874268 | CLS Loss: 0.005783253349363804
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 3.583249568939209 | KNN Loss: 3.580674409866333 | CLS Loss: 0.0025752221699804068

Epoch 119 / 200 | iteration 20 / 171 | Total Loss: 3.6071670055389404 | KNN Loss: 3.5991220474243164 | CLS Loss: 0.008044973015785217
Epoch 119 / 200 | iteration 30 / 171 | Total Loss: 3.612978219985962 | KNN Loss: 3.6033852100372314 | CLS Loss: 0.00959306862205267
Epoch 119 / 200 | iteration 40 / 171 | Total Loss: 3.606078624725342 | KNN Loss: 3.5992918014526367 | CLS Loss: 0.006786815822124481
Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 3.6129939556121826 | KNN Loss: 3.6075706481933594 | CLS Loss: 0.005423199385404587
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 3.6527276039123535 | KNN Loss: 3.638612747192383 | CLS Loss: 0.01411491073668003
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 3.604142665863037 | KNN Loss: 3.5883240699768066 | CLS Loss: 0.015818608924746513
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 3.5924365520477295 | KNN Loss: 3.5850725173950195 | CLS Loss: 0.007364145014435053
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 3.679576635360717

Epoch 122 / 200 | iteration 80 / 171 | Total Loss: 3.652555465698242 | KNN Loss: 3.632920742034912 | CLS Loss: 0.019634706899523735
Epoch 122 / 200 | iteration 90 / 171 | Total Loss: 3.608720302581787 | KNN Loss: 3.6025447845458984 | CLS Loss: 0.006175558548420668
Epoch 122 / 200 | iteration 100 / 171 | Total Loss: 3.617546796798706 | KNN Loss: 3.6056737899780273 | CLS Loss: 0.011872935108840466
Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 3.594090700149536 | KNN Loss: 3.5758755207061768 | CLS Loss: 0.018215246498584747
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 3.575608253479004 | KNN Loss: 3.561797618865967 | CLS Loss: 0.013810553587973118
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 3.6691012382507324 | KNN Loss: 3.6638784408569336 | CLS Loss: 0.005222736392170191
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 3.6868655681610107 | KNN Loss: 3.6741299629211426 | CLS Loss: 0.012735721655189991
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 3.6283428668

Epoch 125 / 200 | iteration 140 / 171 | Total Loss: 3.625856637954712 | KNN Loss: 3.608428478240967 | CLS Loss: 0.017428120598196983
Epoch 125 / 200 | iteration 150 / 171 | Total Loss: 3.572056293487549 | KNN Loss: 3.5525996685028076 | CLS Loss: 0.019456731155514717
Epoch 125 / 200 | iteration 160 / 171 | Total Loss: 3.640126943588257 | KNN Loss: 3.6269543170928955 | CLS Loss: 0.013172728940844536
Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 3.595499277114868 | KNN Loss: 3.5779781341552734 | CLS Loss: 0.017521236091852188
Epoch: 125, Loss: 3.6183, Train: 0.9955, Valid: 0.9864, Best: 0.9881
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 3.592500686645508 | KNN Loss: 3.570540189743042 | CLS Loss: 0.021960459649562836
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 3.6385421752929688 | KNN Loss: 3.6291301250457764 | CLS Loss: 0.009412012994289398
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 3.619993209838867 | KNN Loss: 3.617649555206299 | CLS Loss: 0.002343720290809869

Epoch 129 / 200 | iteration 20 / 171 | Total Loss: 3.591430902481079 | KNN Loss: 3.588898181915283 | CLS Loss: 0.002532710786908865
Epoch 129 / 200 | iteration 30 / 171 | Total Loss: 3.6317508220672607 | KNN Loss: 3.6221466064453125 | CLS Loss: 0.009604228660464287
Epoch 129 / 200 | iteration 40 / 171 | Total Loss: 3.6359524726867676 | KNN Loss: 3.6015257835388184 | CLS Loss: 0.034426577389240265
Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 3.6138741970062256 | KNN Loss: 3.589832305908203 | CLS Loss: 0.0240419190376997
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 3.5843639373779297 | KNN Loss: 3.5764667987823486 | CLS Loss: 0.007897049188613892
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 3.5820722579956055 | KNN Loss: 3.5799121856689453 | CLS Loss: 0.0021600862964987755
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 3.603118419647217 | KNN Loss: 3.5968575477600098 | CLS Loss: 0.006260890047997236
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 3.62306785583496

Epoch 132 / 200 | iteration 80 / 171 | Total Loss: 3.59708571434021 | KNN Loss: 3.5937249660491943 | CLS Loss: 0.0033608004450798035
Epoch 132 / 200 | iteration 90 / 171 | Total Loss: 3.5914409160614014 | KNN Loss: 3.5786988735198975 | CLS Loss: 0.012742137536406517
Epoch 132 / 200 | iteration 100 / 171 | Total Loss: 3.60382080078125 | KNN Loss: 3.573345899581909 | CLS Loss: 0.030475009232759476
Epoch 132 / 200 | iteration 110 / 171 | Total Loss: 3.574259042739868 | KNN Loss: 3.5666537284851074 | CLS Loss: 0.007605295162647963
Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 3.590351104736328 | KNN Loss: 3.582319498062134 | CLS Loss: 0.008031499572098255
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 3.5852653980255127 | KNN Loss: 3.574251890182495 | CLS Loss: 0.01101356465369463
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 3.619638204574585 | KNN Loss: 3.5997350215911865 | CLS Loss: 0.019903244450688362
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 3.7050676345825

Epoch 135 / 200 | iteration 140 / 171 | Total Loss: 3.638998508453369 | KNN Loss: 3.6181862354278564 | CLS Loss: 0.02081228606402874
Epoch 135 / 200 | iteration 150 / 171 | Total Loss: 3.6016907691955566 | KNN Loss: 3.5941572189331055 | CLS Loss: 0.007533504161983728
Epoch 135 / 200 | iteration 160 / 171 | Total Loss: 3.614584445953369 | KNN Loss: 3.601925849914551 | CLS Loss: 0.012658597901463509
Epoch 135 / 200 | iteration 170 / 171 | Total Loss: 3.588029623031616 | KNN Loss: 3.5850725173950195 | CLS Loss: 0.002957066288217902
Epoch: 135, Loss: 3.6185, Train: 0.9970, Valid: 0.9866, Best: 0.9881
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 3.635249137878418 | KNN Loss: 3.632143497467041 | CLS Loss: 0.003105557058006525
Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 3.5765738487243652 | KNN Loss: 3.5744457244873047 | CLS Loss: 0.0021280148066580296
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 3.6222288608551025 | KNN Loss: 3.6037158966064453 | CLS Loss: 0.018512906506657

Epoch 139 / 200 | iteration 20 / 171 | Total Loss: 3.620741605758667 | KNN Loss: 3.610426664352417 | CLS Loss: 0.010314843617379665
Epoch 139 / 200 | iteration 30 / 171 | Total Loss: 3.565764904022217 | KNN Loss: 3.558926582336426 | CLS Loss: 0.006838324014097452
Epoch 139 / 200 | iteration 40 / 171 | Total Loss: 3.5933353900909424 | KNN Loss: 3.5923497676849365 | CLS Loss: 0.0009857024997472763
Epoch 139 / 200 | iteration 50 / 171 | Total Loss: 3.6001458168029785 | KNN Loss: 3.5923349857330322 | CLS Loss: 0.0078109088353812695
Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 3.5672271251678467 | KNN Loss: 3.5647635459899902 | CLS Loss: 0.002463472541421652
Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 3.567338705062866 | KNN Loss: 3.559861660003662 | CLS Loss: 0.007477061823010445
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 3.6141796112060547 | KNN Loss: 3.6078758239746094 | CLS Loss: 0.006303807720541954
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 3.6710164546966

Epoch 142 / 200 | iteration 80 / 171 | Total Loss: 3.63802170753479 | KNN Loss: 3.6356401443481445 | CLS Loss: 0.002381605328992009
Epoch 142 / 200 | iteration 90 / 171 | Total Loss: 3.6373796463012695 | KNN Loss: 3.6069679260253906 | CLS Loss: 0.030411651358008385
Epoch 142 / 200 | iteration 100 / 171 | Total Loss: 3.617928981781006 | KNN Loss: 3.6127731800079346 | CLS Loss: 0.005155717954039574
Epoch 142 / 200 | iteration 110 / 171 | Total Loss: 3.6379926204681396 | KNN Loss: 3.6239194869995117 | CLS Loss: 0.01407324057072401
Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 3.615962266921997 | KNN Loss: 3.6140403747558594 | CLS Loss: 0.001922010094858706
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 3.6017777919769287 | KNN Loss: 3.5948679447174072 | CLS Loss: 0.006909769494086504
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 3.6013970375061035 | KNN Loss: 3.599187135696411 | CLS Loss: 0.002209796104580164
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 3.612819194

Epoch 145 / 200 | iteration 140 / 171 | Total Loss: 3.619194746017456 | KNN Loss: 3.6101105213165283 | CLS Loss: 0.009084267541766167
Epoch 145 / 200 | iteration 150 / 171 | Total Loss: 3.6200993061065674 | KNN Loss: 3.6127164363861084 | CLS Loss: 0.007382981013506651
Epoch 145 / 200 | iteration 160 / 171 | Total Loss: 3.622480630874634 | KNN Loss: 3.5836446285247803 | CLS Loss: 0.03883596137166023
Epoch 145 / 200 | iteration 170 / 171 | Total Loss: 3.6269521713256836 | KNN Loss: 3.595228672027588 | CLS Loss: 0.03172346577048302
Epoch: 145, Loss: 3.6076, Train: 0.9959, Valid: 0.9855, Best: 0.9881
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 3.5707640647888184 | KNN Loss: 3.543552875518799 | CLS Loss: 0.027211103588342667
Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 3.635529041290283 | KNN Loss: 3.628830909729004 | CLS Loss: 0.0066981082782149315
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 3.6629738807678223 | KNN Loss: 3.6550395488739014 | CLS Loss: 0.0079342862591147

Epoch 149 / 200 | iteration 20 / 171 | Total Loss: 3.61661434173584 | KNN Loss: 3.595770835876465 | CLS Loss: 0.020843585953116417
Epoch 149 / 200 | iteration 30 / 171 | Total Loss: 3.615825891494751 | KNN Loss: 3.6063807010650635 | CLS Loss: 0.009445198811590672
Epoch 149 / 200 | iteration 40 / 171 | Total Loss: 3.578794479370117 | KNN Loss: 3.571270704269409 | CLS Loss: 0.007523700129240751
Epoch 149 / 200 | iteration 50 / 171 | Total Loss: 3.6378655433654785 | KNN Loss: 3.597137451171875 | CLS Loss: 0.04072817787528038
Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 3.600748062133789 | KNN Loss: 3.5907275676727295 | CLS Loss: 0.010020378045737743
Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 3.5854859352111816 | KNN Loss: 3.579847574234009 | CLS Loss: 0.005638458766043186
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 3.598257303237915 | KNN Loss: 3.590209722518921 | CLS Loss: 0.008047573268413544
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 3.611137866973877 | KNN

Epoch 152 / 200 | iteration 80 / 171 | Total Loss: 3.656128406524658 | KNN Loss: 3.6495559215545654 | CLS Loss: 0.00657245796173811
Epoch 152 / 200 | iteration 90 / 171 | Total Loss: 3.649089813232422 | KNN Loss: 3.6288185119628906 | CLS Loss: 0.02027127519249916
Epoch 152 / 200 | iteration 100 / 171 | Total Loss: 3.6156325340270996 | KNN Loss: 3.598353624343872 | CLS Loss: 0.017278896644711494
Epoch 152 / 200 | iteration 110 / 171 | Total Loss: 3.6222665309906006 | KNN Loss: 3.6105713844299316 | CLS Loss: 0.011695112101733685
Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 3.6158249378204346 | KNN Loss: 3.6053824424743652 | CLS Loss: 0.010442475788295269
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 3.587689161300659 | KNN Loss: 3.5831356048583984 | CLS Loss: 0.004553659353405237
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 3.602689266204834 | KNN Loss: 3.5993494987487793 | CLS Loss: 0.003339653369039297
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 3.5994122028

Epoch 155 / 200 | iteration 140 / 171 | Total Loss: 3.6331231594085693 | KNN Loss: 3.623807430267334 | CLS Loss: 0.009315677918493748
Epoch 155 / 200 | iteration 150 / 171 | Total Loss: 3.582397937774658 | KNN Loss: 3.5760319232940674 | CLS Loss: 0.0063661313615739346
Epoch 155 / 200 | iteration 160 / 171 | Total Loss: 3.5692152976989746 | KNN Loss: 3.5606164932250977 | CLS Loss: 0.00859872717410326
Epoch 155 / 200 | iteration 170 / 171 | Total Loss: 3.577214479446411 | KNN Loss: 3.575634241104126 | CLS Loss: 0.0015801527770236135
Epoch: 155, Loss: 3.6080, Train: 0.9975, Valid: 0.9873, Best: 0.9881
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 3.6580278873443604 | KNN Loss: 3.637699604034424 | CLS Loss: 0.02032817155122757
Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 3.634885549545288 | KNN Loss: 3.631197929382324 | CLS Loss: 0.0036876776721328497
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 3.597041368484497 | KNN Loss: 3.59324312210083 | CLS Loss: 0.003798344172537327

Epoch 159 / 200 | iteration 20 / 171 | Total Loss: 3.6185574531555176 | KNN Loss: 3.5999081134796143 | CLS Loss: 0.018649283796548843
Epoch 159 / 200 | iteration 30 / 171 | Total Loss: 3.602532148361206 | KNN Loss: 3.600550413131714 | CLS Loss: 0.001981717301532626
Epoch 159 / 200 | iteration 40 / 171 | Total Loss: 3.598616123199463 | KNN Loss: 3.5965113639831543 | CLS Loss: 0.0021047741174697876
Epoch 159 / 200 | iteration 50 / 171 | Total Loss: 3.6107890605926514 | KNN Loss: 3.5970678329467773 | CLS Loss: 0.013721192255616188
Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 3.632171630859375 | KNN Loss: 3.6282150745391846 | CLS Loss: 0.00395645946264267
Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 3.6164462566375732 | KNN Loss: 3.6070451736450195 | CLS Loss: 0.009401059709489346
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 3.573749303817749 | KNN Loss: 3.5688915252685547 | CLS Loss: 0.004857802297919989
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 3.56506490707397

Epoch 162 / 200 | iteration 80 / 171 | Total Loss: 3.59816837310791 | KNN Loss: 3.584925889968872 | CLS Loss: 0.01324236486107111
Epoch 162 / 200 | iteration 90 / 171 | Total Loss: 3.6189887523651123 | KNN Loss: 3.6036689281463623 | CLS Loss: 0.015319869853556156
Epoch 162 / 200 | iteration 100 / 171 | Total Loss: 3.5916478633880615 | KNN Loss: 3.587113857269287 | CLS Loss: 0.004533973056823015
Epoch 162 / 200 | iteration 110 / 171 | Total Loss: 3.647597074508667 | KNN Loss: 3.598750591278076 | CLS Loss: 0.048846494406461716
Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 3.620100498199463 | KNN Loss: 3.6078479290008545 | CLS Loss: 0.012252485379576683
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 3.6257102489471436 | KNN Loss: 3.6220171451568604 | CLS Loss: 0.0036931296344846487
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 3.60410475730896 | KNN Loss: 3.596083402633667 | CLS Loss: 0.008021265268325806
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 3.6099836826324

Epoch 165 / 200 | iteration 140 / 171 | Total Loss: 3.566638946533203 | KNN Loss: 3.55916428565979 | CLS Loss: 0.007474665064364672
Epoch 165 / 200 | iteration 150 / 171 | Total Loss: 3.650378465652466 | KNN Loss: 3.644834041595459 | CLS Loss: 0.0055445232428610325
Epoch 165 / 200 | iteration 160 / 171 | Total Loss: 3.650665044784546 | KNN Loss: 3.6418678760528564 | CLS Loss: 0.00879719015210867
Epoch 165 / 200 | iteration 170 / 171 | Total Loss: 3.638160228729248 | KNN Loss: 3.626802444458008 | CLS Loss: 0.01135767437517643
Epoch: 165, Loss: 3.6076, Train: 0.9957, Valid: 0.9862, Best: 0.9881
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 3.5909039974212646 | KNN Loss: 3.568639039993286 | CLS Loss: 0.022265035659074783
Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 3.6396291255950928 | KNN Loss: 3.631071090698242 | CLS Loss: 0.008558094501495361
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 3.623832941055298 | KNN Loss: 3.615786552429199 | CLS Loss: 0.008046336472034454
Epo

Epoch 169 / 200 | iteration 20 / 171 | Total Loss: 3.5800859928131104 | KNN Loss: 3.572683572769165 | CLS Loss: 0.007402312941849232
Epoch 169 / 200 | iteration 30 / 171 | Total Loss: 3.6201202869415283 | KNN Loss: 3.6089909076690674 | CLS Loss: 0.011129427701234818
Epoch 169 / 200 | iteration 40 / 171 | Total Loss: 3.6165590286254883 | KNN Loss: 3.613389492034912 | CLS Loss: 0.0031694984063506126
Epoch 169 / 200 | iteration 50 / 171 | Total Loss: 3.6751906871795654 | KNN Loss: 3.64603590965271 | CLS Loss: 0.02915477566421032
Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 3.593156576156616 | KNN Loss: 3.5811197757720947 | CLS Loss: 0.012036804109811783
Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 3.574723482131958 | KNN Loss: 3.5601401329040527 | CLS Loss: 0.014583430252969265
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 3.6079561710357666 | KNN Loss: 3.597006320953369 | CLS Loss: 0.010949870571494102
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 3.6253714561462402

Epoch 172 / 200 | iteration 80 / 171 | Total Loss: 3.6310698986053467 | KNN Loss: 3.6195900440216064 | CLS Loss: 0.011479861102998257
Epoch 172 / 200 | iteration 90 / 171 | Total Loss: 3.6158447265625 | KNN Loss: 3.5708298683166504 | CLS Loss: 0.04501485452055931
Epoch 172 / 200 | iteration 100 / 171 | Total Loss: 3.632399797439575 | KNN Loss: 3.629894495010376 | CLS Loss: 0.0025053992867469788
Epoch 172 / 200 | iteration 110 / 171 | Total Loss: 3.60424542427063 | KNN Loss: 3.586005687713623 | CLS Loss: 0.018239634111523628
Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 3.5701873302459717 | KNN Loss: 3.562110185623169 | CLS Loss: 0.008077147416770458
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 3.5933268070220947 | KNN Loss: 3.5916860103607178 | CLS Loss: 0.0016408308874815702
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 3.588416814804077 | KNN Loss: 3.5773844718933105 | CLS Loss: 0.01103235874325037
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 3.5487191677093

Epoch 175 / 200 | iteration 140 / 171 | Total Loss: 3.677370071411133 | KNN Loss: 3.6745712757110596 | CLS Loss: 0.002798736561089754
Epoch 175 / 200 | iteration 150 / 171 | Total Loss: 3.6616055965423584 | KNN Loss: 3.658848285675049 | CLS Loss: 0.002757239155471325
Epoch 175 / 200 | iteration 160 / 171 | Total Loss: 3.6153523921966553 | KNN Loss: 3.5960094928741455 | CLS Loss: 0.019342787563800812
Epoch 175 / 200 | iteration 170 / 171 | Total Loss: 3.623452663421631 | KNN Loss: 3.6061508655548096 | CLS Loss: 0.01730176992714405
Epoch: 175, Loss: 3.6059, Train: 0.9951, Valid: 0.9851, Best: 0.9881
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 3.6335887908935547 | KNN Loss: 3.613145112991333 | CLS Loss: 0.02044370397925377
Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 3.6115128993988037 | KNN Loss: 3.5786967277526855 | CLS Loss: 0.03281623497605324
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 3.627920150756836 | KNN Loss: 3.6090309619903564 | CLS Loss: 0.01888924092054367

Epoch 179 / 200 | iteration 20 / 171 | Total Loss: 3.615893840789795 | KNN Loss: 3.6083104610443115 | CLS Loss: 0.007583432365208864
Epoch 179 / 200 | iteration 30 / 171 | Total Loss: 3.598613739013672 | KNN Loss: 3.588073253631592 | CLS Loss: 0.010540553368628025
Epoch 179 / 200 | iteration 40 / 171 | Total Loss: 3.5743699073791504 | KNN Loss: 3.5663187503814697 | CLS Loss: 0.008051042445003986
Epoch 179 / 200 | iteration 50 / 171 | Total Loss: 3.6808083057403564 | KNN Loss: 3.6779603958129883 | CLS Loss: 0.002847852883860469
Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 3.6325833797454834 | KNN Loss: 3.605116844177246 | CLS Loss: 0.027466539293527603
Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 3.6561825275421143 | KNN Loss: 3.6430938243865967 | CLS Loss: 0.013088789768517017
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 3.620182752609253 | KNN Loss: 3.612419366836548 | CLS Loss: 0.00776332151144743
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 3.617846727371216 

Epoch 182 / 200 | iteration 80 / 171 | Total Loss: 3.622291088104248 | KNN Loss: 3.59264874458313 | CLS Loss: 0.029642246663570404
Epoch 182 / 200 | iteration 90 / 171 | Total Loss: 3.6464555263519287 | KNN Loss: 3.636695384979248 | CLS Loss: 0.009760159999132156
Epoch 182 / 200 | iteration 100 / 171 | Total Loss: 3.6399755477905273 | KNN Loss: 3.6218209266662598 | CLS Loss: 0.018154628574848175
Epoch 182 / 200 | iteration 110 / 171 | Total Loss: 3.610546350479126 | KNN Loss: 3.5860025882720947 | CLS Loss: 0.024543670937418938
Epoch 182 / 200 | iteration 120 / 171 | Total Loss: 3.594552516937256 | KNN Loss: 3.5824785232543945 | CLS Loss: 0.012074087746441364
Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 3.5958774089813232 | KNN Loss: 3.592085123062134 | CLS Loss: 0.0037921948824077845
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 3.604705572128296 | KNN Loss: 3.601583242416382 | CLS Loss: 0.0031223229598253965
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 3.5843250751

Epoch 185 / 200 | iteration 140 / 171 | Total Loss: 3.5805492401123047 | KNN Loss: 3.5769805908203125 | CLS Loss: 0.003568677231669426
Epoch 185 / 200 | iteration 150 / 171 | Total Loss: 3.5757715702056885 | KNN Loss: 3.57362961769104 | CLS Loss: 0.0021419499535113573
Epoch 185 / 200 | iteration 160 / 171 | Total Loss: 3.6233317852020264 | KNN Loss: 3.606220006942749 | CLS Loss: 0.017111727967858315
Epoch 185 / 200 | iteration 170 / 171 | Total Loss: 3.6639420986175537 | KNN Loss: 3.6395184993743896 | CLS Loss: 0.024423567578196526
Epoch: 185, Loss: 3.6137, Train: 0.9960, Valid: 0.9851, Best: 0.9881
Epoch 186 / 200 | iteration 0 / 171 | Total Loss: 3.584944725036621 | KNN Loss: 3.572535514831543 | CLS Loss: 0.012409098446369171
Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 3.581165313720703 | KNN Loss: 3.576125383377075 | CLS Loss: 0.005039928946644068
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 3.6369593143463135 | KNN Loss: 3.624846935272217 | CLS Loss: 0.0121124750003218

Epoch 189 / 200 | iteration 20 / 171 | Total Loss: 3.6054458618164062 | KNN Loss: 3.597970724105835 | CLS Loss: 0.007475217338651419
Epoch 189 / 200 | iteration 30 / 171 | Total Loss: 3.614314079284668 | KNN Loss: 3.59910249710083 | CLS Loss: 0.015211476944386959
Epoch 189 / 200 | iteration 40 / 171 | Total Loss: 3.598947525024414 | KNN Loss: 3.5901756286621094 | CLS Loss: 0.008771778084337711
Epoch 189 / 200 | iteration 50 / 171 | Total Loss: 3.5944154262542725 | KNN Loss: 3.579150438308716 | CLS Loss: 0.015265066176652908
Epoch 189 / 200 | iteration 60 / 171 | Total Loss: 3.6110453605651855 | KNN Loss: 3.5961081981658936 | CLS Loss: 0.01493705902248621
Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 3.5712316036224365 | KNN Loss: 3.556103467941284 | CLS Loss: 0.01512810867279768
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 3.6456170082092285 | KNN Loss: 3.6373629570007324 | CLS Loss: 0.008254148066043854
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 3.6576714515686035 |

Epoch 192 / 200 | iteration 80 / 171 | Total Loss: 3.5929462909698486 | KNN Loss: 3.5877981185913086 | CLS Loss: 0.005148110445588827
Epoch 192 / 200 | iteration 90 / 171 | Total Loss: 3.596129894256592 | KNN Loss: 3.588235855102539 | CLS Loss: 0.007894113659858704
Epoch 192 / 200 | iteration 100 / 171 | Total Loss: 3.62921142578125 | KNN Loss: 3.6154186725616455 | CLS Loss: 0.013792778365314007
Epoch 192 / 200 | iteration 110 / 171 | Total Loss: 3.6082143783569336 | KNN Loss: 3.606281280517578 | CLS Loss: 0.0019331170478835702
Epoch 192 / 200 | iteration 120 / 171 | Total Loss: 3.6608004570007324 | KNN Loss: 3.642062187194824 | CLS Loss: 0.01873824931681156
Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 3.7065107822418213 | KNN Loss: 3.680169105529785 | CLS Loss: 0.02634158544242382
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 3.6399686336517334 | KNN Loss: 3.6156833171844482 | CLS Loss: 0.024285387247800827
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 3.60504770278

Epoch 195 / 200 | iteration 150 / 171 | Total Loss: 3.6351470947265625 | KNN Loss: 3.6327919960021973 | CLS Loss: 0.0023551343474537134
Epoch 195 / 200 | iteration 160 / 171 | Total Loss: 3.559770345687866 | KNN Loss: 3.5528485774993896 | CLS Loss: 0.0069217742420732975
Epoch 195 / 200 | iteration 170 / 171 | Total Loss: 3.5854344367980957 | KNN Loss: 3.576892375946045 | CLS Loss: 0.008541978895664215
Epoch: 195, Loss: 3.6105, Train: 0.9979, Valid: 0.9871, Best: 0.9881
Epoch 196 / 200 | iteration 0 / 171 | Total Loss: 3.591273784637451 | KNN Loss: 3.5829689502716064 | CLS Loss: 0.008304794318974018
Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 3.5628163814544678 | KNN Loss: 3.5591280460357666 | CLS Loss: 0.0036883633583784103
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 3.576659917831421 | KNN Loss: 3.5755321979522705 | CLS Loss: 0.0011278289603069425
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 3.5898005962371826 | KNN Loss: 3.5870003700256348 | CLS Loss: 0.0028003354

Epoch 199 / 200 | iteration 30 / 171 | Total Loss: 3.6455960273742676 | KNN Loss: 3.640336275100708 | CLS Loss: 0.005259797442704439
Epoch 199 / 200 | iteration 40 / 171 | Total Loss: 3.6235580444335938 | KNN Loss: 3.5912466049194336 | CLS Loss: 0.032311420887708664
Epoch 199 / 200 | iteration 50 / 171 | Total Loss: 3.664645195007324 | KNN Loss: 3.6504340171813965 | CLS Loss: 0.014211189933121204
Epoch 199 / 200 | iteration 60 / 171 | Total Loss: 3.5963103771209717 | KNN Loss: 3.581380605697632 | CLS Loss: 0.014929704368114471
Epoch 199 / 200 | iteration 70 / 171 | Total Loss: 3.5769057273864746 | KNN Loss: 3.571948766708374 | CLS Loss: 0.004956903867423534
Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 3.5752196311950684 | KNN Loss: 3.5630714893341064 | CLS Loss: 0.012148236855864525
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 3.600748062133789 | KNN Loss: 3.584953546524048 | CLS Loss: 0.015794595703482628
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 3.56090664863586

In [8]:
# Evaluate the trained model on the test loader; the returned value (presumably accuracy) is shown as the cell output.
test_result = test(model, test_data_iter, device)
test_result

tensor(0.9870, device='cuda:0')

In [9]:
# Training loss recorded in `losses` by the training loop above.
plt.figure()
plt.plot(losses, label='train loss')
plt.ylabel('Loss')
plt.title('Training loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Train vs. validation accuracy recorded in `train_accs` / `val_accs` by the training loop above.
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.ylabel('Accuracy')
plt.title('Train vs. validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
# Run every test batch through the model and keep, on the CPU:
#   test_samples - the raw input batches,
#   labels       - the loader's labels (keeps the loader's dtype),
#   projections  - the model's intermediate representation, flattened per sample.
# model(x, True) is expected to return (output, intermediate) - as used in the original cell.
model.eval()  # deterministic features: disable dropout / use running batch-norm stats, if the model has them
sample_chunks, label_chunks, projection_chunks = [], [], []

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_chunks.append(x)
        label_chunks.append(y)
        _, interm = model(x.to(device), True)
        projection_chunks.append(interm.cpu().flatten(1))

# Concatenate once at the end instead of re-copying the growing tensors on every batch.
test_samples = torch.cat(sample_chunks)
labels = torch.cat(label_chunks)
projections = torch.cat(projection_chunks)

  0%|          | 0/43 [00:00<?, ?it/s]

In [12]:
# Full n x n matrix of pairwise (Euclidean, sklearn's default metric) distances between projections.
# NOTE: for the whole test set this matrix is large; keep an eye on memory.
distances = pairwise_distances(projections)
# ravel() returns a view of the (C-contiguous) matrix instead of the full copy flatten() always makes.
distances_f = distances.ravel()

plt.matshow(distances)
plt.colorbar(label='Euclidean distance')
plt.figure()
# Zero distances (the diagonal and exact duplicates) are left out of the histogram.
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.xlabel('Euclidean distance')
plt.ylabel('Count')
plt.title('Pairwise distances between projections')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
# Density-based clustering of the projections; DBSCAN labels noise points with -1.
dbscan = DBSCAN(eps=2, min_samples=10)
clusters = dbscan.fit_predict(projections)

In [14]:
# Points DBSCAN put into a cluster (label != -1); report both the count and the fraction,
# since the label "Number of inliers" previously printed only a fraction.
inlier_mask = clusters != -1
n_inliers = int(np.count_nonzero(inlier_mask))
print(f"Number of inliers: {n_inliers} / {len(clusters)} ({n_inliers / len(clusters):.4f})")

Number of inliers: 0.9703074322781051


In [15]:
# 2-D t-SNE view of the clustered projections, coloured by DBSCAN cluster (noise points excluded).
perplexity = 100
non_noise = clusters != -1
p = reduce_dims_and_plot(
    projections[non_noise],
    y=clusters[non_noise],
    title=f'perplexity: {perplexity}',
    file_name=None,
    perplexity=perplexity,
    library='Multicore-TSNE',
    perform_PCA=False,
    projected=None,
    figure_type='2d',
    show_figure=True,
    close_figure=False,
    text=None,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a soft decision tree on the DBSCAN self-labels

## Prepare the dataset

In [16]:
# Self-labelled dataset for the tree: each flattened test signal paired with its DBSCAN cluster id,
# keeping only the points DBSCAN did not mark as noise (-1).
is_clustered = clusters != -1
tree_dataset = list(zip(test_samples.flatten(1)[is_clustered], clusters[is_clustered]))
batch_size = 512  # NOTE: re-binds the global batch_size from the config cell (512 there as well)
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Helpers for pruning the inner-node weights and measuring their sparsity

In [17]:
def prune_node(node_weights, factor=1):
    """Zero, in place, every weight lying strictly inside mean +/- factor * std of the node.

    The mean and (population) standard deviation are computed with numpy over all
    entries of ``node_weights``. Returns the same tensor object.
    """
    as_array = node_weights.cpu().detach().numpy()
    center = np.mean(as_array)
    spread = np.std(as_array)
    lower = center - spread * factor
    upper = center + spread * factor
    inside_band = (node_weights > lower) & (node_weights < upper)
    node_weights[inside_band] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Zero all but the ``keep`` largest-magnitude weights of a single node, in place.

    Args:
        node_weights: 1-D tensor holding one inner node's weights; modified in place.
        keep: number of largest-|w| weights left untouched. ``keep <= 0`` zeros every
            weight; ``keep >= len(node_weights)`` leaves the node unchanged.

    Returns:
        ``node_weights`` (the same tensor object).
    """
    w = node_weights.cpu().detach().numpy()
    # argsort is ascending, so the first len(w) - keep indices are the smallest magnitudes.
    # (Slicing with [:-keep] was wrong for keep == 0: [:-0] is empty, so nothing got pruned.)
    n_throw = max(len(w) - keep, 0)
    throw_idx = np.argsort(np.abs(w))[:n_throw]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Hard-prune every inner node of ``tree_`` in place.

    Each row of ``tree_.inner_nodes.weight`` is passed to ``prune_node_keep``, so only its
    ``factor`` largest-|w| weights survive. Despite its name, ``factor`` is a keep *count*
    here, not the standard-deviation factor used by ``prune_node``.
    """
    with torch.no_grad():
        # Work on a copy outside autograd: the per-row in-place zeroing must not be recorded
        # in a computation graph (the weights are parameters that require grad).
        new_weights = tree_.inner_nodes.weight.clone()
        for i in range(new_weights.shape[0]):
            prune_node_keep(new_weights[i, :], factor)  # zeros the row view in place
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Average over the rows of 2-D tensor ``x`` of each row's fraction of exactly-zero entries."""
    def _zero_fraction(row):
        width = len(row)
        return (width - torch.norm(row, 0).item()) / width

    return np.mean([_zero_fraction(x[r, :]) for r in range(x.shape[0])])

def compute_regularization_by_level(tree):
    """Sum of the L2 norms of the inner-node weight rows, each scaled by 2 ** -level.

    Row ``i`` is treated as sitting on level ``floor(log2(i + 1))``, so deeper nodes
    are penalised less. Returns a tensor (0 if there are no rows).
    """
    total_reg = 0
    for node_idx, node_w in enumerate(tree.inner_nodes.weight):
        level = np.floor(np.log2(node_idx + 1))
        total_reg += 2 ** (-level) * torch.norm(node_w.view(-1), 2)
    return total_reg

def show_sparseness(tree):
    """Print the average and per-layer sparseness of the tree's inner-node weights.

    Sparseness of a node is the fraction of its weights that are exactly zero. Inner node
    ``i`` (0-based row of ``tree.inner_nodes.weight``) is assigned to layer
    ``floor(log2(i + 1))``.

    Returns:
        The average sparseness over all inner nodes (same value ``sparseness`` computes).
    """
    weights = tree.inner_nodes.weight
    node_sps = []
    for i in range(weights.shape[0]):
        x_ = weights[i, :]
        node_sps.append((len(x_) - torch.norm(x_, 0).item()) / len(x_))

    avg_sp = np.mean(node_sps)
    print(f"Average sparseness: {avg_sp}")

    layer = 0
    sps = []
    for i, sp in enumerate(node_sps):
        cur_layer = (i + 1).bit_length() - 1  # exact integer floor(log2(i + 1))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer
        sps.append(sp)
    # A layer is only reported when the next one starts, so flush the deepest layer here
    # (previously it was silently dropped).
    if sps:
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configurations

In [18]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration, add_regularization=False):
    """Train ``model`` for one pass over ``loader`` and return the updated iteration counter.

    Relies on the notebook-level ``optimizer``, ``criterion``, ``sparsity_lamda`` and
    ``start_time``.

    Args:
        model: the tree being trained; calling it returns ``(output, penalty)``.
        loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to.
        log_interval: print a status line every ``log_interval`` batches.
        losses, accs: lists that receive each batch's loss and accuracy (appended in place).
        epoch: epoch number, used only in the log line.
        iteration: running batch counter across epochs; used for the sec/iter estimate.
        add_regularization: when True, add ``sparsity_lamda * compute_regularization_by_level(model)``
            to the optimised loss. Defaults to False, which keeps the previous objective
            (the regularization term is only computed for logging).

    Returns:
        ``iteration`` advanced by the number of batches processed.
    """
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        # Use the `model` argument (the original read the global `tree` instead).
        output, penalty = model(data)

        # Classification loss plus the tree's own penalty term.
        loss_tree = criterion(output, target.view(-1)) + penalty

        should_log = batch_idx % log_interval == 0
        regularization = None
        if add_regularization:
            regularization = sparsity_lamda * compute_regularization_by_level(model)
            loss = loss_tree + regularization
        else:
            loss = loss_tree
            if should_log:
                # Not part of the loss: compute it only when it is printed, outside autograd.
                with torch.no_grad():
                    regularization = sparsity_lamda * compute_regularization_by_level(model)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.detach().argmax(dim=1)
        correct = pred.eq(target.view(-1)).sum()
        batch_acc = correct.item() / data.size(0)
        accs.append(batch_acc)

        # Print training status
        if should_log:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")

    return iteration


In [31]:
# Soft-decision-tree training hyper-parameters.
epochs = 400
log_interval = 100          # do_epoch prints a status line every `log_interval` batches
lr = 0.005                  # Adam learning rate
weight_decay = 0.0005       # Adam weight decay
sparsity_lamda = 0.002      # scale of the level-weighted L2 term computed in do_epoch
use_cuda = (device != 'cpu')

In [32]:
# One output class per DBSCAN cluster. Subtracting {-1} (the noise label) instead of doing
# `len(set(clusters)) - 1` stays correct when DBSCAN happens to produce no noise points.
n_tree_classes = len(set(clusters) - {-1})
# input_dim uses test_samples.shape[2] - assumes samples are (N, 1, length); TODO confirm.
tree = SDT(input_dim=test_samples.shape[2], output_dim=n_tree_classes, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
# Move the model first, then build the optimizer over its (already placed) parameters.
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [33]:
# Per-batch tree loss / accuracy (filled by do_epoch) and per-epoch average sparseness.
# NOTE: this re-binds `losses`, dropping the encoder loss history plotted earlier.
losses, accs, sparsity = [], [], []

In [None]:
# Alternate training epochs with hard pruning of every inner node.
prune_every = 1      # prune after every `prune_every`-th epoch (1 = every epoch, as before)
keep_per_node = 3    # passed to prune_tree as `factor`: number of largest-|w| weights kept per node

start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Record the sparseness the tree has going into this epoch.
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)

    if epoch % prune_every == 0:
        prune_tree(tree, factor=keep_per_node)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 042 | Total loss: 1.627 | Reg loss: 0.014 | Tree loss: 1.627 | Accuracy: 0.072266 | 3.746 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 01 | Batch: 000 / 042 | Total loss: 1.587 | Reg loss: 0.006 | Tree loss: 1.587 | Accuracy: 0.429688 | 3.663 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
l

Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 18 | Batch: 000 / 042 | Total loss: 1.263 | Reg loss: 0.014 | Tree loss: 1.263 | Accuracy: 0.560547 | 3.886 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 19 | Batch: 000 / 042 | Total loss: 1.232 | Reg loss: 0.014 | Tree loss: 1.232 | Accuracy: 0.587891 | 3.888 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.984042

Epoch: 35 | Batch: 000 / 042 | Total loss: 1.214 | Reg loss: 0.015 | Tree loss: 1.214 | Accuracy: 0.607422 | 3.915 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 36 | Batch: 000 / 042 | Total loss: 1.184 | Reg loss: 0.015 | Tree loss: 1.184 | Accuracy: 0.615234 | 3.915 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 37 | Batch: 000 / 042 | Total loss: 1.249 | Reg l

Epoch: 53 | Batch: 000 / 042 | Total loss: 1.147 | Reg loss: 0.014 | Tree loss: 1.147 | Accuracy: 0.640625 | 3.924 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 54 | Batch: 000 / 042 | Total loss: 1.218 | Reg loss: 0.014 | Tree loss: 1.218 | Accuracy: 0.599609 | 3.924 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 55 | Batch: 000 / 042 | Total loss: 1.238 | Reg l

In [None]:
# Per-batch training accuracy of the tree, as appended to `accs` by do_epoch.
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.title("Soft decision tree: per-batch training accuracy")
plt.legend()  # the line label was previously never shown
plt.show()

In [None]:
# Per-batch training loss of the tree, as appended to `losses` by do_epoch.
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.legend()  # the line label was previously never shown
plt.show()

# Distribution of all inner-node weights; the red lines mark mean +/- one standard deviation.
plt.figure()
weights = tree.inner_nodes.weight.detach().cpu().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.xlabel("Weight value")
plt.ylabel("Count (log scale)")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
# Draw the trained tree. visualize() also returns `avg_height` and the `root` node that the
# rule-extraction cells below operate on.
plt.figure(dpi=80, figsize=(15, 10))
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
# Each leaf of the tree is treated as one pattern.
leaf_count = len(root.get_leaves())
print(f"Number of patterns: {leaf_count}")

In [None]:
method = "greedy"  # strategy passed to root.accumulate_samples below

In [None]:
# Start from empty leaves, then accumulate every tree-training sample into the leaves using `method`.
root.clear_leaves_samples()

with torch.no_grad():
    for batch, _targets in tree_loader:
        root.accumulate_samples(batch, method)

# Tighten boundaries

In [None]:
# For every leaf (pattern): reset its path, tighten its boundaries with the accumulated samples,
# and record the summed comprehensibility of the conditions returned by get_path_conditions.
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0  # NOTE(review): never updated below
comprehensibilities = []
for pattern_number, leaf in enumerate(leaves, start=1):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_number} ==============")
    comprehensibilities.append(sum(cond.comprehensibility for cond in conds))

summary_stats = (
    ("Average", np.mean),
    ("std", np.std),
    ("var", np.var),
    ("minimum", np.min),
    ("maximum", np.max),
)
for stat_label, stat_fn in summary_stats:
    print(f"{stat_label} comprehensibility: {stat_fn(comprehensibilities)}")