In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 8
tree_depth = 12
batch_size = 512
device = 'cuda'
train_data_path = r'/mnt/qnap/ekosman/mitbih_train.csv'
test_data_path = r'/mnt/qnap/ekosman/mitbih_test.csv'

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.476043224334717 | KNN Loss: 5.769813060760498 | CLS Loss: 1.7062300443649292
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.464776992797852 | KNN Loss: 3.2373194694519043 | CLS Loss: 1.2274576425552368
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 3.391493320465088 | KNN Loss: 2.6543593406677246 | CLS Loss: 0.7371340990066528
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 3.1886754035949707 | KNN Loss: 2.561807632446289 | CLS Loss: 0.6268676519393921
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.188472032546997 | KNN Loss: 2.5338451862335205 | CLS Loss: 0.6546267867088318
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 3.1445138454437256 | KNN Loss: 2.6530253887176514 | CLS Loss: 0.49148836731910706
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 3.0928361415863037 | KNN Loss: 2.5160574913024902 | CLS Loss: 0.5767786502838135
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 2.983839511871338 | KNN Loss: 2.521601200103759

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 2.598839044570923 | KNN Loss: 2.4779484272003174 | CLS Loss: 0.12089067697525024
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 2.6051814556121826 | KNN Loss: 2.512763023376465 | CLS Loss: 0.09241849184036255
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 2.6455912590026855 | KNN Loss: 2.53206729888916 | CLS Loss: 0.11352407932281494
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 2.606755018234253 | KNN Loss: 2.515778064727783 | CLS Loss: 0.09097685664892197
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 2.613258123397827 | KNN Loss: 2.504054546356201 | CLS Loss: 0.10920365154743195
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 2.590320587158203 | KNN Loss: 2.4629762172698975 | CLS Loss: 0.1273442655801773
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 2.6003620624542236 | KNN Loss: 2.474505662918091 | CLS Loss: 0.1258564591407776
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 2.6393747329711914 | KNN Loss: 2.5004622

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 2.5393362045288086 | KNN Loss: 2.4390337467193604 | CLS Loss: 0.10030245035886765
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 2.56878662109375 | KNN Loss: 2.441195011138916 | CLS Loss: 0.12759165465831757
Epoch: 007, Loss: 2.5576, Train: 0.9771, Valid: 0.9739, Best: 0.9739
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 2.5666160583496094 | KNN Loss: 2.468266725540161 | CLS Loss: 0.09834936261177063
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 2.557650327682495 | KNN Loss: 2.459782123565674 | CLS Loss: 0.09786819666624069
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 2.5491745471954346 | KNN Loss: 2.4246954917907715 | CLS Loss: 0.12447910010814667
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 2.5385468006134033 | KNN Loss: 2.462998867034912 | CLS Loss: 0.07554784417152405
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 2.5647027492523193 | KNN Loss: 2.4557669162750244 | CLS Loss: 0.10893575102090836
Epoch 8 / 200 | iter

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 2.4872567653656006 | KNN Loss: 2.4295566082000732 | CLS Loss: 0.05770007520914078
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 2.546365261077881 | KNN Loss: 2.4590630531311035 | CLS Loss: 0.0873023197054863
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 2.5020530223846436 | KNN Loss: 2.4328184127807617 | CLS Loss: 0.0692346841096878
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 2.5381650924682617 | KNN Loss: 2.461623191833496 | CLS Loss: 0.07654178887605667
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 2.5211596488952637 | KNN Loss: 2.43990159034729 | CLS Loss: 0.08125804364681244
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 2.533362627029419 | KNN Loss: 2.455315589904785 | CLS Loss: 0.07804711163043976
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 2.4923484325408936 | KNN Loss: 2.420955181121826 | CLS Loss: 0.07139333337545395
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 2.4960505962371826 | KNN Loss: 2

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 2.499396562576294 | KNN Loss: 2.453481435775757 | CLS Loss: 0.04591507464647293
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 2.462001323699951 | KNN Loss: 2.413766384124756 | CLS Loss: 0.04823485016822815
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 2.498473882675171 | KNN Loss: 2.407113552093506 | CLS Loss: 0.09136022627353668
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 2.4779276847839355 | KNN Loss: 2.412907123565674 | CLS Loss: 0.06502047181129456
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 2.4980461597442627 | KNN Loss: 2.4089159965515137 | CLS Loss: 0.08913024514913559
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 2.499948263168335 | KNN Loss: 2.43463397026062 | CLS Loss: 0.06531422585248947
Epoch: 014, Loss: 2.4911, Train: 0.9813, Valid: 0.9778, Best: 0.9798
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 2.5111746788024902 | KNN Loss: 2.3962152004241943 | CLS Loss: 0.11495952308177948
Epoch 15 / 2

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 2.476365327835083 | KNN Loss: 2.4304699897766113 | CLS Loss: 0.04589524492621422
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 2.512005567550659 | KNN Loss: 2.453900098800659 | CLS Loss: 0.058105386793613434
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 2.462614059448242 | KNN Loss: 2.4157018661499023 | CLS Loss: 0.04691208153963089
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 2.44916033744812 | KNN Loss: 2.4168167114257812 | CLS Loss: 0.03234367445111275
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 2.4674506187438965 | KNN Loss: 2.431638479232788 | CLS Loss: 0.035812027752399445
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 2.498500108718872 | KNN Loss: 2.440647840499878 | CLS Loss: 0.05785217508673668
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 2.4817447662353516 | KNN Loss: 2.411778688430786 | CLS Loss: 0.06996599584817886
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 2.5137434005737305 | KNN Loss: 2.4

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 2.473090648651123 | KNN Loss: 2.429137706756592 | CLS Loss: 0.043952930718660355
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 2.4419784545898438 | KNN Loss: 2.399592399597168 | CLS Loss: 0.04238595440983772
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 2.447232961654663 | KNN Loss: 2.397538900375366 | CLS Loss: 0.0496939979493618
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 2.482562303543091 | KNN Loss: 2.449773073196411 | CLS Loss: 0.032789330929517746
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 2.5182583332061768 | KNN Loss: 2.4657480716705322 | CLS Loss: 0.05251025781035423
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 2.493154287338257 | KNN Loss: 2.429997205734253 | CLS Loss: 0.06315703690052032
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 2.4538466930389404 | KNN Loss: 2.415354013442993 | CLS Loss: 0.0384925976395607
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 2.510446310043335 | KNN Loss: 

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 2.478961706161499 | KNN Loss: 2.426452398300171 | CLS Loss: 0.052509281784296036
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 2.461373805999756 | KNN Loss: 2.3844196796417236 | CLS Loss: 0.07695411145687103
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 2.4501399993896484 | KNN Loss: 2.420170783996582 | CLS Loss: 0.02996929921209812
Epoch: 024, Loss: 2.4611, Train: 0.9874, Valid: 0.9823, Best: 0.9823
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 2.465010166168213 | KNN Loss: 2.4102790355682373 | CLS Loss: 0.05473123490810394
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 2.429643392562866 | KNN Loss: 2.407382011413574 | CLS Loss: 0.022261325269937515
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 2.4425854682922363 | KNN Loss: 2.4156391620635986 | CLS Loss: 0.026946209371089935
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 2.461549997329712 | KNN Loss: 2.4253170490264893 | CLS Loss: 0.036232877522706985
Epoch 25 

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 2.4560916423797607 | KNN Loss: 2.4172110557556152 | CLS Loss: 0.038880616426467896
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 2.4351308345794678 | KNN Loss: 2.397585391998291 | CLS Loss: 0.037545353174209595
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 2.459334135055542 | KNN Loss: 2.4448766708374023 | CLS Loss: 0.014457393437623978
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 2.422696590423584 | KNN Loss: 2.3990533351898193 | CLS Loss: 0.023643314838409424
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 2.4593446254730225 | KNN Loss: 2.4292666912078857 | CLS Loss: 0.03007788211107254
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 2.4258744716644287 | KNN Loss: 2.400040626525879 | CLS Loss: 0.025833861902356148
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 2.42549729347229 | KNN Loss: 2.3973891735076904 | CLS Loss: 0.02810804173350334
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 2.434732437133789 | KNN L

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 2.423253059387207 | KNN Loss: 2.394390344619751 | CLS Loss: 0.028862737119197845
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 2.497260093688965 | KNN Loss: 2.406461000442505 | CLS Loss: 0.09079910814762115
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 2.4309191703796387 | KNN Loss: 2.37016224861145 | CLS Loss: 0.060756832361221313
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 2.460653305053711 | KNN Loss: 2.4172117710113525 | CLS Loss: 0.043441418558359146
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 2.429929256439209 | KNN Loss: 2.407388925552368 | CLS Loss: 0.02254021354019642
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 2.445472240447998 | KNN Loss: 2.373246669769287 | CLS Loss: 0.07222560793161392
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 2.4646317958831787 | KNN Loss: 2.4043023586273193 | CLS Loss: 0.06032932177186012
Epoch: 031, Loss: 2.4467, Train: 0.9901, Valid: 0.9848, Best: 0.9848
Epoch 32

Epoch: 034, Loss: 2.4372, Train: 0.9896, Valid: 0.9838, Best: 0.9848
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 2.4702939987182617 | KNN Loss: 2.4292545318603516 | CLS Loss: 0.041039370000362396
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 2.401801824569702 | KNN Loss: 2.362006187438965 | CLS Loss: 0.03979559242725372
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 2.4323740005493164 | KNN Loss: 2.4127323627471924 | CLS Loss: 0.019641710445284843
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 2.4204773902893066 | KNN Loss: 2.3849031925201416 | CLS Loss: 0.03557419404387474
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 2.4500319957733154 | KNN Loss: 2.421945333480835 | CLS Loss: 0.028086746111512184
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 2.4397902488708496 | KNN Loss: 2.413207769393921 | CLS Loss: 0.02658259868621826
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 2.465892791748047 | KNN Loss: 2.4149134159088135 | CLS Loss: 0.0509793795645237
Epoch 35 / 

Epoch 38 / 200 | iteration 60 / 171 | Total Loss: 2.4443163871765137 | KNN Loss: 2.4109508991241455 | CLS Loss: 0.033365555107593536
Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 2.515061616897583 | KNN Loss: 2.447596311569214 | CLS Loss: 0.0674654170870781
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 2.4252536296844482 | KNN Loss: 2.406633138656616 | CLS Loss: 0.018620382994413376
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 2.437495708465576 | KNN Loss: 2.399496555328369 | CLS Loss: 0.03799916431307793
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 2.448355197906494 | KNN Loss: 2.4140427112579346 | CLS Loss: 0.0343124084174633
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 2.428218126296997 | KNN Loss: 2.4171054363250732 | CLS Loss: 0.011112766340374947
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 2.460858106613159 | KNN Loss: 2.4175121784210205 | CLS Loss: 0.043345846235752106
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 2.4029033184051514 | KNN Los

Epoch 41 / 200 | iteration 130 / 171 | Total Loss: 2.45642352104187 | KNN Loss: 2.4265215396881104 | CLS Loss: 0.029901890084147453
Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 2.4397263526916504 | KNN Loss: 2.381004571914673 | CLS Loss: 0.05872169882059097
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 2.4666736125946045 | KNN Loss: 2.4100866317749023 | CLS Loss: 0.056587085127830505
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 2.4598941802978516 | KNN Loss: 2.4122819900512695 | CLS Loss: 0.04761224985122681
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 2.4564359188079834 | KNN Loss: 2.379107713699341 | CLS Loss: 0.07732813060283661
Epoch: 041, Loss: 2.4353, Train: 0.9909, Valid: 0.9844, Best: 0.9855
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 2.4529037475585938 | KNN Loss: 2.4050447940826416 | CLS Loss: 0.04785895720124245
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 2.444174289703369 | KNN Loss: 2.4141197204589844 | CLS Loss: 0.030054498463869095
Epoch

Epoch 45 / 200 | iteration 10 / 171 | Total Loss: 2.4254403114318848 | KNN Loss: 2.409135341644287 | CLS Loss: 0.01630507968366146
Epoch 45 / 200 | iteration 20 / 171 | Total Loss: 2.4383227825164795 | KNN Loss: 2.3875625133514404 | CLS Loss: 0.050760384649038315
Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 2.406935214996338 | KNN Loss: 2.3941409587860107 | CLS Loss: 0.01279417984187603
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 2.4219818115234375 | KNN Loss: 2.4021973609924316 | CLS Loss: 0.01978437229990959
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 2.4444479942321777 | KNN Loss: 2.414515733718872 | CLS Loss: 0.029932230710983276
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 2.427824020385742 | KNN Loss: 2.397669553756714 | CLS Loss: 0.030154500156641006
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 2.4205048084259033 | KNN Loss: 2.3979949951171875 | CLS Loss: 0.02250991202890873
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 2.462672710418701 | KNN Loss

Epoch 48 / 200 | iteration 80 / 171 | Total Loss: 2.403923988342285 | KNN Loss: 2.3744709491729736 | CLS Loss: 0.029453100636601448
Epoch 48 / 200 | iteration 90 / 171 | Total Loss: 2.3851051330566406 | KNN Loss: 2.36529803276062 | CLS Loss: 0.01980709843337536
Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 2.438324451446533 | KNN Loss: 2.4056169986724854 | CLS Loss: 0.03270747512578964
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 2.403348445892334 | KNN Loss: 2.386944532394409 | CLS Loss: 0.01640380546450615
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 2.488034963607788 | KNN Loss: 2.4501867294311523 | CLS Loss: 0.03784829378128052
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 2.4452977180480957 | KNN Loss: 2.4301037788391113 | CLS Loss: 0.015193942002952099
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 2.430123805999756 | KNN Loss: 2.391632080078125 | CLS Loss: 0.03849177062511444
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 2.435600757598877 | KNN Los

Epoch 51 / 200 | iteration 150 / 171 | Total Loss: 2.385340452194214 | KNN Loss: 2.359915018081665 | CLS Loss: 0.025425374507904053
Epoch 51 / 200 | iteration 160 / 171 | Total Loss: 2.4093568325042725 | KNN Loss: 2.39463472366333 | CLS Loss: 0.014722065068781376
Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 2.4576432704925537 | KNN Loss: 2.4137659072875977 | CLS Loss: 0.04387732595205307
Epoch: 051, Loss: 2.4225, Train: 0.9937, Valid: 0.9859, Best: 0.9860
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 2.416010618209839 | KNN Loss: 2.3905551433563232 | CLS Loss: 0.025455553084611893
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 2.4474287033081055 | KNN Loss: 2.3864963054656982 | CLS Loss: 0.06093234196305275
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 2.411317825317383 | KNN Loss: 2.393038034439087 | CLS Loss: 0.018279794603586197
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 2.4189419746398926 | KNN Loss: 2.4055140018463135 | CLS Loss: 0.01342797465622425
Epoch 52

Epoch 55 / 200 | iteration 30 / 171 | Total Loss: 2.424492359161377 | KNN Loss: 2.3797924518585205 | CLS Loss: 0.044699911028146744
Epoch 55 / 200 | iteration 40 / 171 | Total Loss: 2.436307907104492 | KNN Loss: 2.3976376056671143 | CLS Loss: 0.03867029771208763
Epoch 55 / 200 | iteration 50 / 171 | Total Loss: 2.4385111331939697 | KNN Loss: 2.4099462032318115 | CLS Loss: 0.028564922511577606
Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 2.4468417167663574 | KNN Loss: 2.4197678565979004 | CLS Loss: 0.027073752135038376
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 2.4102749824523926 | KNN Loss: 2.388718366622925 | CLS Loss: 0.02155664749443531
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 2.425039052963257 | KNN Loss: 2.4102883338928223 | CLS Loss: 0.014750654809176922
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 2.4437363147735596 | KNN Loss: 2.4171199798583984 | CLS Loss: 0.026616238057613373
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 2.4503867626190186 | KNN

Epoch 58 / 200 | iteration 100 / 171 | Total Loss: 2.443161964416504 | KNN Loss: 2.4257256984710693 | CLS Loss: 0.017436208203434944
Epoch 58 / 200 | iteration 110 / 171 | Total Loss: 2.4094314575195312 | KNN Loss: 2.37402606010437 | CLS Loss: 0.035405367612838745
Epoch 58 / 200 | iteration 120 / 171 | Total Loss: 2.4204819202423096 | KNN Loss: 2.3981552124023438 | CLS Loss: 0.02232675813138485
Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 2.433419942855835 | KNN Loss: 2.3967061042785645 | CLS Loss: 0.03671391308307648
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 2.393270969390869 | KNN Loss: 2.379490852355957 | CLS Loss: 0.013780014589428902
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 2.4027881622314453 | KNN Loss: 2.393852710723877 | CLS Loss: 0.008935380727052689
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 2.462623357772827 | KNN Loss: 2.447075366973877 | CLS Loss: 0.015547892078757286
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 2.43915057182312 | KNN

Epoch 61 / 200 | iteration 170 / 171 | Total Loss: 2.4484915733337402 | KNN Loss: 2.433476448059082 | CLS Loss: 0.015015038661658764
Epoch: 061, Loss: 2.4163, Train: 0.9938, Valid: 0.9863, Best: 0.9866
Epoch 62 / 200 | iteration 0 / 171 | Total Loss: 2.3923816680908203 | KNN Loss: 2.380629062652588 | CLS Loss: 0.011752615682780743
Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 2.400252342224121 | KNN Loss: 2.3560538291931152 | CLS Loss: 0.04419856145977974
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 2.4265213012695312 | KNN Loss: 2.4142675399780273 | CLS Loss: 0.012253833934664726
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 2.3598058223724365 | KNN Loss: 2.348642110824585 | CLS Loss: 0.011163600720465183
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 2.3771982192993164 | KNN Loss: 2.3567731380462646 | CLS Loss: 0.020425181835889816
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 2.4069948196411133 | KNN Loss: 2.3843259811401367 | CLS Loss: 0.022668899968266487
Epoch

Epoch 65 / 200 | iteration 50 / 171 | Total Loss: 2.408505916595459 | KNN Loss: 2.362030506134033 | CLS Loss: 0.04647543653845787
Epoch 65 / 200 | iteration 60 / 171 | Total Loss: 2.4002785682678223 | KNN Loss: 2.382356643676758 | CLS Loss: 0.01792190782725811
Epoch 65 / 200 | iteration 70 / 171 | Total Loss: 2.4258835315704346 | KNN Loss: 2.3888189792633057 | CLS Loss: 0.037064582109451294
Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 2.410062074661255 | KNN Loss: 2.382854700088501 | CLS Loss: 0.027207354083657265
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 2.401031970977783 | KNN Loss: 2.3783507347106934 | CLS Loss: 0.022681163623929024
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 2.3490424156188965 | KNN Loss: 2.331092119216919 | CLS Loss: 0.017950186505913734
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 2.4045069217681885 | KNN Loss: 2.3863487243652344 | CLS Loss: 0.018158189952373505
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 2.3989577293395996 | KNN 

Epoch 68 / 200 | iteration 120 / 171 | Total Loss: 2.4230194091796875 | KNN Loss: 2.394366502761841 | CLS Loss: 0.028652790933847427
Epoch 68 / 200 | iteration 130 / 171 | Total Loss: 2.404181957244873 | KNN Loss: 2.3799145221710205 | CLS Loss: 0.024267427623271942
Epoch 68 / 200 | iteration 140 / 171 | Total Loss: 2.4208555221557617 | KNN Loss: 2.409362316131592 | CLS Loss: 0.011493168771266937
Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 2.4084174633026123 | KNN Loss: 2.3685362339019775 | CLS Loss: 0.0398811511695385
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 2.4145121574401855 | KNN Loss: 2.4092984199523926 | CLS Loss: 0.005213804543018341
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 2.346195697784424 | KNN Loss: 2.3315963745117188 | CLS Loss: 0.014599375426769257
Epoch: 068, Loss: 2.4145, Train: 0.9946, Valid: 0.9860, Best: 0.9870
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 2.4166018962860107 | KNN Loss: 2.390693187713623 | CLS Loss: 0.025908783078193665
Ep

Epoch 72 / 200 | iteration 0 / 171 | Total Loss: 2.368018865585327 | KNN Loss: 2.355069875717163 | CLS Loss: 0.012949004769325256
Epoch 72 / 200 | iteration 10 / 171 | Total Loss: 2.3890514373779297 | KNN Loss: 2.3738434314727783 | CLS Loss: 0.015208037570118904
Epoch 72 / 200 | iteration 20 / 171 | Total Loss: 2.387073278427124 | KNN Loss: 2.3728458881378174 | CLS Loss: 0.014227384701371193
Epoch 72 / 200 | iteration 30 / 171 | Total Loss: 2.3678743839263916 | KNN Loss: 2.3587167263031006 | CLS Loss: 0.009157653898000717
Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 2.4119327068328857 | KNN Loss: 2.4028570652008057 | CLS Loss: 0.009075754322111607
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 2.39367938041687 | KNN Loss: 2.3858113288879395 | CLS Loss: 0.007867933250963688
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 2.423067569732666 | KNN Loss: 2.40354061126709 | CLS Loss: 0.019527047872543335
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 2.423534393310547 | KNN Loss

Epoch 75 / 200 | iteration 70 / 171 | Total Loss: 2.4620580673217773 | KNN Loss: 2.449739456176758 | CLS Loss: 0.012318691238760948
Epoch 75 / 200 | iteration 80 / 171 | Total Loss: 2.387554168701172 | KNN Loss: 2.379441738128662 | CLS Loss: 0.008112340234220028
Epoch 75 / 200 | iteration 90 / 171 | Total Loss: 2.411417245864868 | KNN Loss: 2.3929836750030518 | CLS Loss: 0.018433624878525734
Epoch 75 / 200 | iteration 100 / 171 | Total Loss: 2.4166247844696045 | KNN Loss: 2.403425693511963 | CLS Loss: 0.013199173845350742
Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 2.420776128768921 | KNN Loss: 2.395629644393921 | CLS Loss: 0.02514658309519291
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 2.4043240547180176 | KNN Loss: 2.3859939575195312 | CLS Loss: 0.01833019219338894
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 2.447852373123169 | KNN Loss: 2.4210171699523926 | CLS Loss: 0.026835277676582336
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 2.4298691749572754 | KNN

Epoch 78 / 200 | iteration 140 / 171 | Total Loss: 2.3851075172424316 | KNN Loss: 2.3683154582977295 | CLS Loss: 0.016792094334959984
Epoch 78 / 200 | iteration 150 / 171 | Total Loss: 2.3794193267822266 | KNN Loss: 2.3760485649108887 | CLS Loss: 0.003370713908225298
Epoch 78 / 200 | iteration 160 / 171 | Total Loss: 2.4055888652801514 | KNN Loss: 2.4018564224243164 | CLS Loss: 0.0037324409931898117
Epoch 78 / 200 | iteration 170 / 171 | Total Loss: 2.376258134841919 | KNN Loss: 2.35781192779541 | CLS Loss: 0.018446089699864388
Epoch: 078, Loss: 2.4069, Train: 0.9954, Valid: 0.9861, Best: 0.9870
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 2.3918685913085938 | KNN Loss: 2.376077890396118 | CLS Loss: 0.015790794044733047
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 2.381571054458618 | KNN Loss: 2.3624930381774902 | CLS Loss: 0.019078057259321213
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 2.37876296043396 | KNN Loss: 2.374074935913086 | CLS Loss: 0.004688005894422531
Epoc

Epoch 82 / 200 | iteration 20 / 171 | Total Loss: 2.4049623012542725 | KNN Loss: 2.392625093460083 | CLS Loss: 0.012337195686995983
Epoch 82 / 200 | iteration 30 / 171 | Total Loss: 2.4060940742492676 | KNN Loss: 2.38356876373291 | CLS Loss: 0.02252526767551899
Epoch 82 / 200 | iteration 40 / 171 | Total Loss: 2.3783302307128906 | KNN Loss: 2.3739070892333984 | CLS Loss: 0.004423215985298157
Epoch 82 / 200 | iteration 50 / 171 | Total Loss: 2.4224019050598145 | KNN Loss: 2.4096133708953857 | CLS Loss: 0.012788445688784122
Epoch 82 / 200 | iteration 60 / 171 | Total Loss: 2.44667387008667 | KNN Loss: 2.419440507888794 | CLS Loss: 0.02723325788974762
Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 2.431288242340088 | KNN Loss: 2.420243501663208 | CLS Loss: 0.011044629849493504
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 2.3953049182891846 | KNN Loss: 2.387882947921753 | CLS Loss: 0.007421907968819141
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 2.4151647090911865 | KNN Loss:

Epoch 85 / 200 | iteration 90 / 171 | Total Loss: 2.436995029449463 | KNN Loss: 2.4124646186828613 | CLS Loss: 0.024530477821826935
Epoch 85 / 200 | iteration 100 / 171 | Total Loss: 2.4446861743927 | KNN Loss: 2.437509775161743 | CLS Loss: 0.007176365703344345
Epoch 85 / 200 | iteration 110 / 171 | Total Loss: 2.398958444595337 | KNN Loss: 2.388225555419922 | CLS Loss: 0.010732815600931644
Epoch 85 / 200 | iteration 120 / 171 | Total Loss: 2.4128384590148926 | KNN Loss: 2.390338897705078 | CLS Loss: 0.022499555721879005
Epoch 85 / 200 | iteration 130 / 171 | Total Loss: 2.438443183898926 | KNN Loss: 2.421572685241699 | CLS Loss: 0.016870535910129547
Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 2.4197299480438232 | KNN Loss: 2.406599283218384 | CLS Loss: 0.013130630366504192
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 2.376617193222046 | KNN Loss: 2.3699934482574463 | CLS Loss: 0.006623793859034777
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 2.4485151767730713 | KNN

Epoch 88 / 200 | iteration 160 / 171 | Total Loss: 2.421468496322632 | KNN Loss: 2.4140453338623047 | CLS Loss: 0.0074231442995369434
Epoch 88 / 200 | iteration 170 / 171 | Total Loss: 2.4335970878601074 | KNN Loss: 2.4257922172546387 | CLS Loss: 0.007804783992469311
Epoch: 088, Loss: 2.4173, Train: 0.9963, Valid: 0.9864, Best: 0.9870
Epoch 89 / 200 | iteration 0 / 171 | Total Loss: 2.402679920196533 | KNN Loss: 2.3951048851013184 | CLS Loss: 0.007575063034892082
Epoch 89 / 200 | iteration 10 / 171 | Total Loss: 2.406231641769409 | KNN Loss: 2.399745464324951 | CLS Loss: 0.006486090831458569
Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 2.4038681983947754 | KNN Loss: 2.387392044067383 | CLS Loss: 0.016476167365908623
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 2.42533802986145 | KNN Loss: 2.4097063541412354 | CLS Loss: 0.015631593763828278
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 2.4160501956939697 | KNN Loss: 2.401535987854004 | CLS Loss: 0.014514271169900894
Epoch 

Epoch 92 / 200 | iteration 40 / 171 | Total Loss: 2.4651124477386475 | KNN Loss: 2.4454314708709717 | CLS Loss: 0.019681042060256004
Epoch 92 / 200 | iteration 50 / 171 | Total Loss: 2.406785488128662 | KNN Loss: 2.3961617946624756 | CLS Loss: 0.01062364038079977
Epoch 92 / 200 | iteration 60 / 171 | Total Loss: 2.4199676513671875 | KNN Loss: 2.4152958393096924 | CLS Loss: 0.004671864211559296
Epoch 92 / 200 | iteration 70 / 171 | Total Loss: 2.4186787605285645 | KNN Loss: 2.4108617305755615 | CLS Loss: 0.00781709048897028
Epoch 92 / 200 | iteration 80 / 171 | Total Loss: 2.460679054260254 | KNN Loss: 2.4170899391174316 | CLS Loss: 0.043589036911726
Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 2.406858444213867 | KNN Loss: 2.3903110027313232 | CLS Loss: 0.01654748059809208
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 2.3963940143585205 | KNN Loss: 2.3835110664367676 | CLS Loss: 0.012882894836366177
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 2.4486448764801025 | KNN L

Epoch 95 / 200 | iteration 110 / 171 | Total Loss: 2.4156079292297363 | KNN Loss: 2.4104323387145996 | CLS Loss: 0.005175556987524033
Epoch 95 / 200 | iteration 120 / 171 | Total Loss: 2.4309065341949463 | KNN Loss: 2.4187920093536377 | CLS Loss: 0.012114496901631355
Epoch 95 / 200 | iteration 130 / 171 | Total Loss: 2.443948745727539 | KNN Loss: 2.440953016281128 | CLS Loss: 0.00299580255523324
Epoch 95 / 200 | iteration 140 / 171 | Total Loss: 2.420226573944092 | KNN Loss: 2.41505765914917 | CLS Loss: 0.005168868228793144
Epoch 95 / 200 | iteration 150 / 171 | Total Loss: 2.4071362018585205 | KNN Loss: 2.383702039718628 | CLS Loss: 0.023434164002537727
Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 2.406251907348633 | KNN Loss: 2.397050619125366 | CLS Loss: 0.009201178327202797
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 2.431401491165161 | KNN Loss: 2.4119269847869873 | CLS Loss: 0.01947452500462532
Epoch: 095, Loss: 2.4185, Train: 0.9960, Valid: 0.9863, Best: 0.9870
Epoc

Epoch: 098, Loss: 2.4173, Train: 0.9951, Valid: 0.9849, Best: 0.9870
Epoch 99 / 200 | iteration 0 / 171 | Total Loss: 2.4523048400878906 | KNN Loss: 2.4183411598205566 | CLS Loss: 0.03396376967430115
Epoch 99 / 200 | iteration 10 / 171 | Total Loss: 2.4421989917755127 | KNN Loss: 2.4221854209899902 | CLS Loss: 0.02001352794468403
Epoch 99 / 200 | iteration 20 / 171 | Total Loss: 2.442675828933716 | KNN Loss: 2.4247429370880127 | CLS Loss: 0.017932849004864693
Epoch 99 / 200 | iteration 30 / 171 | Total Loss: 2.3987739086151123 | KNN Loss: 2.3848423957824707 | CLS Loss: 0.01393153890967369
Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 2.4056754112243652 | KNN Loss: 2.372952938079834 | CLS Loss: 0.0327225923538208
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 2.457200050354004 | KNN Loss: 2.4324018955230713 | CLS Loss: 0.02479827031493187
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 2.397240161895752 | KNN Loss: 2.374983549118042 | CLS Loss: 0.0222565196454525
Epoch 99 / 200

Epoch 102 / 200 | iteration 60 / 171 | Total Loss: 2.3852458000183105 | KNN Loss: 2.376171112060547 | CLS Loss: 0.009074619971215725
Epoch 102 / 200 | iteration 70 / 171 | Total Loss: 2.4600932598114014 | KNN Loss: 2.4394562244415283 | CLS Loss: 0.020637130364775658
Epoch 102 / 200 | iteration 80 / 171 | Total Loss: 2.41862154006958 | KNN Loss: 2.404925584793091 | CLS Loss: 0.013695919886231422
Epoch 102 / 200 | iteration 90 / 171 | Total Loss: 2.3987865447998047 | KNN Loss: 2.3953492641448975 | CLS Loss: 0.0034373211674392223
Epoch 102 / 200 | iteration 100 / 171 | Total Loss: 2.42932391166687 | KNN Loss: 2.4042611122131348 | CLS Loss: 0.02506287582218647
Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 2.4069204330444336 | KNN Loss: 2.3977105617523193 | CLS Loss: 0.009209905751049519
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 2.434507369995117 | KNN Loss: 2.4197723865509033 | CLS Loss: 0.014734945259988308
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 2.409370899200

Epoch 105 / 200 | iteration 120 / 171 | Total Loss: 2.4168472290039062 | KNN Loss: 2.413574695587158 | CLS Loss: 0.0032726323697715998
Epoch 105 / 200 | iteration 130 / 171 | Total Loss: 2.425452709197998 | KNN Loss: 2.415271282196045 | CLS Loss: 0.010181366465985775
Epoch 105 / 200 | iteration 140 / 171 | Total Loss: 2.412182092666626 | KNN Loss: 2.383709192276001 | CLS Loss: 0.02847299911081791
Epoch 105 / 200 | iteration 150 / 171 | Total Loss: 2.4597644805908203 | KNN Loss: 2.446932077407837 | CLS Loss: 0.012832383625209332
Epoch 105 / 200 | iteration 160 / 171 | Total Loss: 2.4042937755584717 | KNN Loss: 2.395169258117676 | CLS Loss: 0.009124553762376308
Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 2.4052982330322266 | KNN Loss: 2.4007411003112793 | CLS Loss: 0.004557175096124411
Epoch: 105, Loss: 2.4173, Train: 0.9962, Valid: 0.9863, Best: 0.9870
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 2.393795967102051 | KNN Loss: 2.3769400119781494 | CLS Loss: 0.016855992376804

Epoch: 108, Loss: 2.4169, Train: 0.9967, Valid: 0.9866, Best: 0.9870
Epoch 109 / 200 | iteration 0 / 171 | Total Loss: 2.417480945587158 | KNN Loss: 2.3990001678466797 | CLS Loss: 0.018480826169252396
Epoch 109 / 200 | iteration 10 / 171 | Total Loss: 2.3961849212646484 | KNN Loss: 2.3906407356262207 | CLS Loss: 0.005544242914766073
Epoch 109 / 200 | iteration 20 / 171 | Total Loss: 2.3604087829589844 | KNN Loss: 2.3551676273345947 | CLS Loss: 0.005241159815341234
Epoch 109 / 200 | iteration 30 / 171 | Total Loss: 2.4372761249542236 | KNN Loss: 2.4272913932800293 | CLS Loss: 0.00998464785516262
Epoch 109 / 200 | iteration 40 / 171 | Total Loss: 2.4340336322784424 | KNN Loss: 2.4157543182373047 | CLS Loss: 0.01827928237617016
Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 2.4266507625579834 | KNN Loss: 2.415034055709839 | CLS Loss: 0.011616635136306286
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 2.4249379634857178 | KNN Loss: 2.421326160430908 | CLS Loss: 0.003611853811889887

Epoch 112 / 200 | iteration 60 / 171 | Total Loss: 2.4094011783599854 | KNN Loss: 2.406369686126709 | CLS Loss: 0.0030314791947603226
Epoch 112 / 200 | iteration 70 / 171 | Total Loss: 2.398012399673462 | KNN Loss: 2.3939149379730225 | CLS Loss: 0.004097490571439266
Epoch 112 / 200 | iteration 80 / 171 | Total Loss: 2.429311990737915 | KNN Loss: 2.414421796798706 | CLS Loss: 0.01489012036472559
Epoch 112 / 200 | iteration 90 / 171 | Total Loss: 2.4133269786834717 | KNN Loss: 2.407501459121704 | CLS Loss: 0.0058255321346223354
Epoch 112 / 200 | iteration 100 / 171 | Total Loss: 2.3620738983154297 | KNN Loss: 2.3573949337005615 | CLS Loss: 0.004678880330175161
Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 2.3816614151000977 | KNN Loss: 2.3783910274505615 | CLS Loss: 0.0032704814802855253
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 2.4178731441497803 | KNN Loss: 2.4011666774749756 | CLS Loss: 0.016706442460417747
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 2.38792443

Epoch 115 / 200 | iteration 120 / 171 | Total Loss: 2.4208109378814697 | KNN Loss: 2.4130618572235107 | CLS Loss: 0.007749016396701336
Epoch 115 / 200 | iteration 130 / 171 | Total Loss: 2.4110419750213623 | KNN Loss: 2.405611991882324 | CLS Loss: 0.0054300385527312756
Epoch 115 / 200 | iteration 140 / 171 | Total Loss: 2.4018967151641846 | KNN Loss: 2.3929219245910645 | CLS Loss: 0.008974701166152954
Epoch 115 / 200 | iteration 150 / 171 | Total Loss: 2.371898651123047 | KNN Loss: 2.368736505508423 | CLS Loss: 0.0031620431691408157
Epoch 115 / 200 | iteration 160 / 171 | Total Loss: 2.4303157329559326 | KNN Loss: 2.4179165363311768 | CLS Loss: 0.012399178929626942
Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 2.430546522140503 | KNN Loss: 2.425171375274658 | CLS Loss: 0.005375260021537542
Epoch: 115, Loss: 2.4175, Train: 0.9955, Valid: 0.9848, Best: 0.9870
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 2.410747528076172 | KNN Loss: 2.391237735748291 | CLS Loss: 0.019509743899

Epoch: 118, Loss: 2.4116, Train: 0.9953, Valid: 0.9844, Best: 0.9870
Epoch 119 / 200 | iteration 0 / 171 | Total Loss: 2.4362974166870117 | KNN Loss: 2.4298620223999023 | CLS Loss: 0.006435394752770662
Epoch 119 / 200 | iteration 10 / 171 | Total Loss: 2.4023447036743164 | KNN Loss: 2.390296459197998 | CLS Loss: 0.01204831525683403
Epoch 119 / 200 | iteration 20 / 171 | Total Loss: 2.404844284057617 | KNN Loss: 2.3731110095977783 | CLS Loss: 0.03173327073454857
Epoch 119 / 200 | iteration 30 / 171 | Total Loss: 2.4085395336151123 | KNN Loss: 2.3965132236480713 | CLS Loss: 0.012026423588395119
Epoch 119 / 200 | iteration 40 / 171 | Total Loss: 2.4360713958740234 | KNN Loss: 2.4265594482421875 | CLS Loss: 0.009512060321867466
Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 2.4272701740264893 | KNN Loss: 2.4199352264404297 | CLS Loss: 0.0073348358273506165
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 2.427964210510254 | KNN Loss: 2.4036123752593994 | CLS Loss: 0.02435180172324180

Epoch 122 / 200 | iteration 60 / 171 | Total Loss: 2.3907456398010254 | KNN Loss: 2.3761260509490967 | CLS Loss: 0.014619573950767517
Epoch 122 / 200 | iteration 70 / 171 | Total Loss: 2.437025785446167 | KNN Loss: 2.430744171142578 | CLS Loss: 0.006281713955104351
Epoch 122 / 200 | iteration 80 / 171 | Total Loss: 2.413574695587158 | KNN Loss: 2.402433395385742 | CLS Loss: 0.011141324415802956
Epoch 122 / 200 | iteration 90 / 171 | Total Loss: 2.4266183376312256 | KNN Loss: 2.4154694080352783 | CLS Loss: 0.011148910038173199
Epoch 122 / 200 | iteration 100 / 171 | Total Loss: 2.4263272285461426 | KNN Loss: 2.403916835784912 | CLS Loss: 0.022410370409488678
Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 2.412644624710083 | KNN Loss: 2.4009196758270264 | CLS Loss: 0.011724921874701977
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 2.437598705291748 | KNN Loss: 2.4150497913360596 | CLS Loss: 0.022548945620656013
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 2.385656833648

Epoch 125 / 200 | iteration 120 / 171 | Total Loss: 2.439101457595825 | KNN Loss: 2.420600652694702 | CLS Loss: 0.01850087381899357
Epoch 125 / 200 | iteration 130 / 171 | Total Loss: 2.418489694595337 | KNN Loss: 2.4145569801330566 | CLS Loss: 0.003932792693376541
Epoch 125 / 200 | iteration 140 / 171 | Total Loss: 2.4175102710723877 | KNN Loss: 2.405975580215454 | CLS Loss: 0.011534666642546654
Epoch 125 / 200 | iteration 150 / 171 | Total Loss: 2.4131479263305664 | KNN Loss: 2.3925201892852783 | CLS Loss: 0.020627837628126144
Epoch 125 / 200 | iteration 160 / 171 | Total Loss: 2.379848003387451 | KNN Loss: 2.3683745861053467 | CLS Loss: 0.011473424732685089
Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 2.4358131885528564 | KNN Loss: 2.4272522926330566 | CLS Loss: 0.008560797199606895
Epoch: 125, Loss: 2.4151, Train: 0.9967, Valid: 0.9867, Best: 0.9870
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 2.408902406692505 | KNN Loss: 2.3954036235809326 | CLS Loss: 0.01349875144660

Epoch: 128, Loss: 2.4163, Train: 0.9957, Valid: 0.9852, Best: 0.9870
Epoch 129 / 200 | iteration 0 / 171 | Total Loss: 2.390692949295044 | KNN Loss: 2.3826498985290527 | CLS Loss: 0.008043126203119755
Epoch 129 / 200 | iteration 10 / 171 | Total Loss: 2.411470890045166 | KNN Loss: 2.3970823287963867 | CLS Loss: 0.014388534240424633
Epoch 129 / 200 | iteration 20 / 171 | Total Loss: 2.442150354385376 | KNN Loss: 2.4227421283721924 | CLS Loss: 0.019408326596021652
Epoch 129 / 200 | iteration 30 / 171 | Total Loss: 2.4101297855377197 | KNN Loss: 2.396545886993408 | CLS Loss: 0.013583862222731113
Epoch 129 / 200 | iteration 40 / 171 | Total Loss: 2.435899019241333 | KNN Loss: 2.4035658836364746 | CLS Loss: 0.03233322873711586
Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 2.466184377670288 | KNN Loss: 2.4581122398376465 | CLS Loss: 0.008072148077189922
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 2.3869035243988037 | KNN Loss: 2.3843331336975098 | CLS Loss: 0.002570290584117174
E

Epoch 132 / 200 | iteration 60 / 171 | Total Loss: 2.423673629760742 | KNN Loss: 2.4195034503936768 | CLS Loss: 0.004170069005340338
Epoch 132 / 200 | iteration 70 / 171 | Total Loss: 2.458491802215576 | KNN Loss: 2.447695255279541 | CLS Loss: 0.010796554386615753
Epoch 132 / 200 | iteration 80 / 171 | Total Loss: 2.4014525413513184 | KNN Loss: 2.387829542160034 | CLS Loss: 0.013622942380607128
Epoch 132 / 200 | iteration 90 / 171 | Total Loss: 2.434735059738159 | KNN Loss: 2.4321162700653076 | CLS Loss: 0.0026187968906015158
Epoch 132 / 200 | iteration 100 / 171 | Total Loss: 2.407010316848755 | KNN Loss: 2.403341054916382 | CLS Loss: 0.0036691795103251934
Epoch 132 / 200 | iteration 110 / 171 | Total Loss: 2.394099712371826 | KNN Loss: 2.3781158924102783 | CLS Loss: 0.01598382368683815
Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 2.395442247390747 | KNN Loss: 2.386655569076538 | CLS Loss: 0.00878671370446682
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 2.411722660064697

Epoch 135 / 200 | iteration 120 / 171 | Total Loss: 2.3926191329956055 | KNN Loss: 2.377835273742676 | CLS Loss: 0.014783822931349277
Epoch 135 / 200 | iteration 130 / 171 | Total Loss: 2.403862953186035 | KNN Loss: 2.379260301589966 | CLS Loss: 0.024602683261036873
Epoch 135 / 200 | iteration 140 / 171 | Total Loss: 2.440702199935913 | KNN Loss: 2.419074296951294 | CLS Loss: 0.021627986803650856
Epoch 135 / 200 | iteration 150 / 171 | Total Loss: 2.458510637283325 | KNN Loss: 2.432723045349121 | CLS Loss: 0.02578757517039776
Epoch 135 / 200 | iteration 160 / 171 | Total Loss: 2.4666924476623535 | KNN Loss: 2.462677240371704 | CLS Loss: 0.0040151989087462425
Epoch 135 / 200 | iteration 170 / 171 | Total Loss: 2.419105052947998 | KNN Loss: 2.4024643898010254 | CLS Loss: 0.016640733927488327
Epoch: 135, Loss: 2.4131, Train: 0.9967, Valid: 0.9858, Best: 0.9872
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 2.4257922172546387 | KNN Loss: 2.421905279159546 | CLS Loss: 0.00388699583709239

Epoch: 138, Loss: 2.4070, Train: 0.9966, Valid: 0.9872, Best: 0.9872
Epoch 139 / 200 | iteration 0 / 171 | Total Loss: 2.4293875694274902 | KNN Loss: 2.419623613357544 | CLS Loss: 0.00976401474326849
Epoch 139 / 200 | iteration 10 / 171 | Total Loss: 2.404351234436035 | KNN Loss: 2.399686574935913 | CLS Loss: 0.004664599895477295
Epoch 139 / 200 | iteration 20 / 171 | Total Loss: 2.4004955291748047 | KNN Loss: 2.381991386413574 | CLS Loss: 0.01850416511297226
Epoch 139 / 200 | iteration 30 / 171 | Total Loss: 2.419252872467041 | KNN Loss: 2.404296636581421 | CLS Loss: 0.014956243336200714
Epoch 139 / 200 | iteration 40 / 171 | Total Loss: 2.4266433715820312 | KNN Loss: 2.4201161861419678 | CLS Loss: 0.00652730418369174
Epoch 139 / 200 | iteration 50 / 171 | Total Loss: 2.417839527130127 | KNN Loss: 2.406299114227295 | CLS Loss: 0.01154048927128315
Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 2.4182417392730713 | KNN Loss: 2.398493766784668 | CLS Loss: 0.01974790170788765
Epoch 13

Epoch 142 / 200 | iteration 60 / 171 | Total Loss: 2.396801710128784 | KNN Loss: 2.3930184841156006 | CLS Loss: 0.003783267457038164
Epoch 142 / 200 | iteration 70 / 171 | Total Loss: 2.417140483856201 | KNN Loss: 2.3986928462982178 | CLS Loss: 0.018447542563080788
Epoch 142 / 200 | iteration 80 / 171 | Total Loss: 2.4030094146728516 | KNN Loss: 2.3969268798828125 | CLS Loss: 0.006082584615796804
Epoch 142 / 200 | iteration 90 / 171 | Total Loss: 2.426377296447754 | KNN Loss: 2.4216182231903076 | CLS Loss: 0.00475896243005991
Epoch 142 / 200 | iteration 100 / 171 | Total Loss: 2.39205002784729 | KNN Loss: 2.385650634765625 | CLS Loss: 0.006399284582585096
Epoch 142 / 200 | iteration 110 / 171 | Total Loss: 2.3765580654144287 | KNN Loss: 2.3732762336730957 | CLS Loss: 0.0032818240579217672
Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 2.426480770111084 | KNN Loss: 2.422492027282715 | CLS Loss: 0.0039887810125947
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 2.402863979339599

Epoch 145 / 200 | iteration 120 / 171 | Total Loss: 2.4133758544921875 | KNN Loss: 2.394998550415039 | CLS Loss: 0.018377244472503662
Epoch 145 / 200 | iteration 130 / 171 | Total Loss: 2.3793256282806396 | KNN Loss: 2.3762621879577637 | CLS Loss: 0.0030634193681180477
Epoch 145 / 200 | iteration 140 / 171 | Total Loss: 2.4180376529693604 | KNN Loss: 2.3935165405273438 | CLS Loss: 0.024521199986338615
Epoch 145 / 200 | iteration 150 / 171 | Total Loss: 2.4487240314483643 | KNN Loss: 2.4326465129852295 | CLS Loss: 0.016077591106295586
Epoch 145 / 200 | iteration 160 / 171 | Total Loss: 2.4032468795776367 | KNN Loss: 2.3990578651428223 | CLS Loss: 0.004189037252217531
Epoch 145 / 200 | iteration 170 / 171 | Total Loss: 2.367698907852173 | KNN Loss: 2.361074686050415 | CLS Loss: 0.006624209228903055
Epoch: 145, Loss: 2.4067, Train: 0.9978, Valid: 0.9869, Best: 0.9872
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 2.421119213104248 | KNN Loss: 2.4184162616729736 | CLS Loss: 0.0027029984

Epoch: 148, Loss: 2.4070, Train: 0.9959, Valid: 0.9845, Best: 0.9872
Epoch 149 / 200 | iteration 0 / 171 | Total Loss: 2.444162130355835 | KNN Loss: 2.409745693206787 | CLS Loss: 0.03441649675369263
Epoch 149 / 200 | iteration 10 / 171 | Total Loss: 2.417348861694336 | KNN Loss: 2.412590503692627 | CLS Loss: 0.004758280701935291
Epoch 149 / 200 | iteration 20 / 171 | Total Loss: 2.4143965244293213 | KNN Loss: 2.393092393875122 | CLS Loss: 0.021304160356521606
Epoch 149 / 200 | iteration 30 / 171 | Total Loss: 2.4054372310638428 | KNN Loss: 2.3918607234954834 | CLS Loss: 0.01357639953494072
Epoch 149 / 200 | iteration 40 / 171 | Total Loss: 2.4007883071899414 | KNN Loss: 2.3769075870513916 | CLS Loss: 0.023880617693066597
Epoch 149 / 200 | iteration 50 / 171 | Total Loss: 2.392603874206543 | KNN Loss: 2.390293836593628 | CLS Loss: 0.0023101160768419504
Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 2.440016746520996 | KNN Loss: 2.433034658432007 | CLS Loss: 0.006982048507779837
Epoc

Epoch 152 / 200 | iteration 60 / 171 | Total Loss: 2.4184517860412598 | KNN Loss: 2.413170337677002 | CLS Loss: 0.005281556863337755
Epoch 152 / 200 | iteration 70 / 171 | Total Loss: 2.412416696548462 | KNN Loss: 2.408963203430176 | CLS Loss: 0.0034534367732703686
Epoch 152 / 200 | iteration 80 / 171 | Total Loss: 2.3978819847106934 | KNN Loss: 2.396674394607544 | CLS Loss: 0.0012075315462425351
Epoch 152 / 200 | iteration 90 / 171 | Total Loss: 2.370090961456299 | KNN Loss: 2.363598585128784 | CLS Loss: 0.006492315791547298
Epoch 152 / 200 | iteration 100 / 171 | Total Loss: 2.4271326065063477 | KNN Loss: 2.4233052730560303 | CLS Loss: 0.003827235661447048
Epoch 152 / 200 | iteration 110 / 171 | Total Loss: 2.467238426208496 | KNN Loss: 2.4562411308288574 | CLS Loss: 0.010997183620929718
Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 2.3843889236450195 | KNN Loss: 2.3700692653656006 | CLS Loss: 0.01431970950216055
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 2.38552618026

Epoch 155 / 200 | iteration 120 / 171 | Total Loss: 2.434494733810425 | KNN Loss: 2.424710273742676 | CLS Loss: 0.009784398600459099
Epoch 155 / 200 | iteration 130 / 171 | Total Loss: 2.4036617279052734 | KNN Loss: 2.3947813510894775 | CLS Loss: 0.008880479261279106
Epoch 155 / 200 | iteration 140 / 171 | Total Loss: 2.4097115993499756 | KNN Loss: 2.4057745933532715 | CLS Loss: 0.003936906810849905
Epoch 155 / 200 | iteration 150 / 171 | Total Loss: 2.410839557647705 | KNN Loss: 2.4064414501190186 | CLS Loss: 0.00439799576997757
Epoch 155 / 200 | iteration 160 / 171 | Total Loss: 2.404160499572754 | KNN Loss: 2.3962314128875732 | CLS Loss: 0.00792916864156723
Epoch 155 / 200 | iteration 170 / 171 | Total Loss: 2.3837244510650635 | KNN Loss: 2.3737735748291016 | CLS Loss: 0.009950980544090271
Epoch: 155, Loss: 2.4137, Train: 0.9970, Valid: 0.9863, Best: 0.9872
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 2.4218616485595703 | KNN Loss: 2.4171524047851562 | CLS Loss: 0.0047091497108

Epoch: 158, Loss: 2.4192, Train: 0.9964, Valid: 0.9853, Best: 0.9872
Epoch 159 / 200 | iteration 0 / 171 | Total Loss: 2.443228006362915 | KNN Loss: 2.433905839920044 | CLS Loss: 0.00932206679135561
Epoch 159 / 200 | iteration 10 / 171 | Total Loss: 2.43428897857666 | KNN Loss: 2.4239776134490967 | CLS Loss: 0.010311344638466835
Epoch 159 / 200 | iteration 20 / 171 | Total Loss: 2.4657115936279297 | KNN Loss: 2.456658363342285 | CLS Loss: 0.009053127840161324
Epoch 159 / 200 | iteration 30 / 171 | Total Loss: 2.4090301990509033 | KNN Loss: 2.4027915000915527 | CLS Loss: 0.006238613277673721
Epoch 159 / 200 | iteration 40 / 171 | Total Loss: 2.4046480655670166 | KNN Loss: 2.396949052810669 | CLS Loss: 0.007698944304138422
Epoch 159 / 200 | iteration 50 / 171 | Total Loss: 2.384688377380371 | KNN Loss: 2.3804023265838623 | CLS Loss: 0.004286156967282295
Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 2.404247522354126 | KNN Loss: 2.395864963531494 | CLS Loss: 0.008382528088986874
Epoc

Epoch 162 / 200 | iteration 60 / 171 | Total Loss: 2.4197309017181396 | KNN Loss: 2.4131736755371094 | CLS Loss: 0.006557187996804714
Epoch 162 / 200 | iteration 70 / 171 | Total Loss: 2.409358501434326 | KNN Loss: 2.40689754486084 | CLS Loss: 0.0024608904495835304
Epoch 162 / 200 | iteration 80 / 171 | Total Loss: 2.3979856967926025 | KNN Loss: 2.391758680343628 | CLS Loss: 0.006227103527635336
Epoch 162 / 200 | iteration 90 / 171 | Total Loss: 2.4295308589935303 | KNN Loss: 2.4244730472564697 | CLS Loss: 0.005057752598077059
Epoch 162 / 200 | iteration 100 / 171 | Total Loss: 2.3954851627349854 | KNN Loss: 2.3830649852752686 | CLS Loss: 0.012420210056006908
Epoch 162 / 200 | iteration 110 / 171 | Total Loss: 2.4105212688446045 | KNN Loss: 2.400390863418579 | CLS Loss: 0.010130386799573898
Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 2.4687814712524414 | KNN Loss: 2.457796335220337 | CLS Loss: 0.01098511554300785
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 2.45480847358

Epoch 165 / 200 | iteration 120 / 171 | Total Loss: 2.418165445327759 | KNN Loss: 2.40348744392395 | CLS Loss: 0.01467798464000225
Epoch 165 / 200 | iteration 130 / 171 | Total Loss: 2.435713529586792 | KNN Loss: 2.4294190406799316 | CLS Loss: 0.006294600665569305
Epoch 165 / 200 | iteration 140 / 171 | Total Loss: 2.4189960956573486 | KNN Loss: 2.414642095565796 | CLS Loss: 0.004354076460003853
Epoch 165 / 200 | iteration 150 / 171 | Total Loss: 2.4105513095855713 | KNN Loss: 2.4028000831604004 | CLS Loss: 0.007751145865768194
Epoch 165 / 200 | iteration 160 / 171 | Total Loss: 2.416099786758423 | KNN Loss: 2.411648988723755 | CLS Loss: 0.004450783133506775
Epoch 165 / 200 | iteration 170 / 171 | Total Loss: 2.437040090560913 | KNN Loss: 2.4231956005096436 | CLS Loss: 0.013844595290720463
Epoch: 165, Loss: 2.4162, Train: 0.9978, Valid: 0.9871, Best: 0.9872
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 2.389057159423828 | KNN Loss: 2.383793592453003 | CLS Loss: 0.005263453815132379

Epoch: 168, Loss: 2.4199, Train: 0.9970, Valid: 0.9852, Best: 0.9872
Epoch 169 / 200 | iteration 0 / 171 | Total Loss: 2.3933002948760986 | KNN Loss: 2.3888840675354004 | CLS Loss: 0.004416283685714006
Epoch 169 / 200 | iteration 10 / 171 | Total Loss: 2.4267635345458984 | KNN Loss: 2.4117419719696045 | CLS Loss: 0.015021542087197304
Epoch 169 / 200 | iteration 20 / 171 | Total Loss: 2.4258947372436523 | KNN Loss: 2.4238414764404297 | CLS Loss: 0.0020533695351332426
Epoch 169 / 200 | iteration 30 / 171 | Total Loss: 2.395958662033081 | KNN Loss: 2.3925158977508545 | CLS Loss: 0.0034427656792104244
Epoch 169 / 200 | iteration 40 / 171 | Total Loss: 2.4156854152679443 | KNN Loss: 2.413133382797241 | CLS Loss: 0.0025520557537674904
Epoch 169 / 200 | iteration 50 / 171 | Total Loss: 2.3704793453216553 | KNN Loss: 2.366595506668091 | CLS Loss: 0.003883955767378211
Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 2.4039924144744873 | KNN Loss: 2.402238368988037 | CLS Loss: 0.00175395293626

Epoch 172 / 200 | iteration 60 / 171 | Total Loss: 2.414000988006592 | KNN Loss: 2.395493268966675 | CLS Loss: 0.018507622182369232
Epoch 172 / 200 | iteration 70 / 171 | Total Loss: 2.392164945602417 | KNN Loss: 2.387985944747925 | CLS Loss: 0.004179018083959818
Epoch 172 / 200 | iteration 80 / 171 | Total Loss: 2.4149200916290283 | KNN Loss: 2.4134469032287598 | CLS Loss: 0.0014730974799022079
Epoch 172 / 200 | iteration 90 / 171 | Total Loss: 2.4165005683898926 | KNN Loss: 2.4132676124572754 | CLS Loss: 0.0032328381203114986
Epoch 172 / 200 | iteration 100 / 171 | Total Loss: 2.407639741897583 | KNN Loss: 2.4062561988830566 | CLS Loss: 0.0013834249693900347
Epoch 172 / 200 | iteration 110 / 171 | Total Loss: 2.4261586666107178 | KNN Loss: 2.415469169616699 | CLS Loss: 0.010689567774534225
Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 2.447883129119873 | KNN Loss: 2.4248173236846924 | CLS Loss: 0.02306571789085865
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 2.4002780914

Epoch 175 / 200 | iteration 120 / 171 | Total Loss: 2.4374735355377197 | KNN Loss: 2.4320428371429443 | CLS Loss: 0.005430754739791155
Epoch 175 / 200 | iteration 130 / 171 | Total Loss: 2.4023661613464355 | KNN Loss: 2.395905017852783 | CLS Loss: 0.006461227312684059
Epoch 175 / 200 | iteration 140 / 171 | Total Loss: 2.450218915939331 | KNN Loss: 2.400432586669922 | CLS Loss: 0.04978625848889351
Epoch 175 / 200 | iteration 150 / 171 | Total Loss: 2.3913938999176025 | KNN Loss: 2.3854787349700928 | CLS Loss: 0.005915118847042322
Epoch 175 / 200 | iteration 160 / 171 | Total Loss: 2.4014580249786377 | KNN Loss: 2.3997857570648193 | CLS Loss: 0.0016723406733945012
Epoch 175 / 200 | iteration 170 / 171 | Total Loss: 2.463972806930542 | KNN Loss: 2.435152292251587 | CLS Loss: 0.028820551931858063
Epoch: 175, Loss: 2.4187, Train: 0.9957, Valid: 0.9840, Best: 0.9874
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 2.429286003112793 | KNN Loss: 2.4078259468078613 | CLS Loss: 0.0214599706232

Epoch: 178, Loss: 2.4145, Train: 0.9978, Valid: 0.9858, Best: 0.9874
Epoch 179 / 200 | iteration 0 / 171 | Total Loss: 2.4150428771972656 | KNN Loss: 2.409811496734619 | CLS Loss: 0.005231310613453388
Epoch 179 / 200 | iteration 10 / 171 | Total Loss: 2.4563872814178467 | KNN Loss: 2.4371211528778076 | CLS Loss: 0.019266119226813316
Epoch 179 / 200 | iteration 20 / 171 | Total Loss: 2.3951241970062256 | KNN Loss: 2.3914458751678467 | CLS Loss: 0.0036782645620405674
Epoch 179 / 200 | iteration 30 / 171 | Total Loss: 2.3775322437286377 | KNN Loss: 2.3705897331237793 | CLS Loss: 0.0069425893016159534
Epoch 179 / 200 | iteration 40 / 171 | Total Loss: 2.4068145751953125 | KNN Loss: 2.391777276992798 | CLS Loss: 0.015037310309708118
Epoch 179 / 200 | iteration 50 / 171 | Total Loss: 2.3904595375061035 | KNN Loss: 2.372316598892212 | CLS Loss: 0.018142856657505035
Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 2.4148144721984863 | KNN Loss: 2.408608913421631 | CLS Loss: 0.006205475423485

Epoch 182 / 200 | iteration 60 / 171 | Total Loss: 2.4124467372894287 | KNN Loss: 2.411271333694458 | CLS Loss: 0.0011753460858017206
Epoch 182 / 200 | iteration 70 / 171 | Total Loss: 2.434814691543579 | KNN Loss: 2.4138009548187256 | CLS Loss: 0.02101362869143486
Epoch 182 / 200 | iteration 80 / 171 | Total Loss: 2.4243712425231934 | KNN Loss: 2.4127674102783203 | CLS Loss: 0.011603936553001404
Epoch 182 / 200 | iteration 90 / 171 | Total Loss: 2.4099419116973877 | KNN Loss: 2.404778480529785 | CLS Loss: 0.005163388792425394
Epoch 182 / 200 | iteration 100 / 171 | Total Loss: 2.404109477996826 | KNN Loss: 2.3904213905334473 | CLS Loss: 0.013687968254089355
Epoch 182 / 200 | iteration 110 / 171 | Total Loss: 2.388370990753174 | KNN Loss: 2.38631010055542 | CLS Loss: 0.002060906495898962
Epoch 182 / 200 | iteration 120 / 171 | Total Loss: 2.380505323410034 | KNN Loss: 2.375601053237915 | CLS Loss: 0.004904215224087238
Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 2.42561984062194

Epoch 185 / 200 | iteration 120 / 171 | Total Loss: 2.410003662109375 | KNN Loss: 2.4045491218566895 | CLS Loss: 0.005454644560813904
Epoch 185 / 200 | iteration 130 / 171 | Total Loss: 2.4125468730926514 | KNN Loss: 2.4065325260162354 | CLS Loss: 0.006014366168528795
Epoch 185 / 200 | iteration 140 / 171 | Total Loss: 2.4414896965026855 | KNN Loss: 2.4342782497406006 | CLS Loss: 0.007211505901068449
Epoch 185 / 200 | iteration 150 / 171 | Total Loss: 2.402709484100342 | KNN Loss: 2.4011120796203613 | CLS Loss: 0.0015975134447216988
Epoch 185 / 200 | iteration 160 / 171 | Total Loss: 2.429591417312622 | KNN Loss: 2.4050540924072266 | CLS Loss: 0.024537308141589165
Epoch 185 / 200 | iteration 170 / 171 | Total Loss: 2.407379150390625 | KNN Loss: 2.4010794162750244 | CLS Loss: 0.006299692206084728
Epoch: 185, Loss: 2.4127, Train: 0.9964, Valid: 0.9852, Best: 0.9874
Epoch 186 / 200 | iteration 0 / 171 | Total Loss: 2.3946070671081543 | KNN Loss: 2.3843612670898438 | CLS Loss: 0.0102457758

Epoch: 188, Loss: 2.4163, Train: 0.9968, Valid: 0.9856, Best: 0.9874
Epoch 189 / 200 | iteration 0 / 171 | Total Loss: 2.4022114276885986 | KNN Loss: 2.39015531539917 | CLS Loss: 0.01205622497946024
Epoch 189 / 200 | iteration 10 / 171 | Total Loss: 2.4021103382110596 | KNN Loss: 2.3974359035491943 | CLS Loss: 0.004674519412219524
Epoch 189 / 200 | iteration 20 / 171 | Total Loss: 2.4069371223449707 | KNN Loss: 2.3938426971435547 | CLS Loss: 0.013094501569867134
Epoch 189 / 200 | iteration 30 / 171 | Total Loss: 2.3911221027374268 | KNN Loss: 2.3857240676879883 | CLS Loss: 0.0053981030359864235
Epoch 189 / 200 | iteration 40 / 171 | Total Loss: 2.377262592315674 | KNN Loss: 2.370898723602295 | CLS Loss: 0.006363935302942991
Epoch 189 / 200 | iteration 50 / 171 | Total Loss: 2.475409984588623 | KNN Loss: 2.4523887634277344 | CLS Loss: 0.02302112616598606
Epoch 189 / 200 | iteration 60 / 171 | Total Loss: 2.4302632808685303 | KNN Loss: 2.4257004261016846 | CLS Loss: 0.00456285011023283
E

Epoch 192 / 200 | iteration 60 / 171 | Total Loss: 2.445564031600952 | KNN Loss: 2.4437737464904785 | CLS Loss: 0.0017902180552482605
Epoch 192 / 200 | iteration 70 / 171 | Total Loss: 2.4203107357025146 | KNN Loss: 2.4160983562469482 | CLS Loss: 0.004212354775518179
Epoch 192 / 200 | iteration 80 / 171 | Total Loss: 2.4323344230651855 | KNN Loss: 2.4142065048217773 | CLS Loss: 0.01812788099050522
Epoch 192 / 200 | iteration 90 / 171 | Total Loss: 2.41702938079834 | KNN Loss: 2.4050469398498535 | CLS Loss: 0.01198238879442215
Epoch 192 / 200 | iteration 100 / 171 | Total Loss: 2.3841145038604736 | KNN Loss: 2.379514694213867 | CLS Loss: 0.004599805921316147
Epoch 192 / 200 | iteration 110 / 171 | Total Loss: 2.4017508029937744 | KNN Loss: 2.399730920791626 | CLS Loss: 0.0020198074635118246
Epoch 192 / 200 | iteration 120 / 171 | Total Loss: 2.4248809814453125 | KNN Loss: 2.407672643661499 | CLS Loss: 0.01720832660794258
Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 2.423055887222

Epoch 195 / 200 | iteration 120 / 171 | Total Loss: 2.3954384326934814 | KNN Loss: 2.3866419792175293 | CLS Loss: 0.00879654660820961
Epoch 195 / 200 | iteration 130 / 171 | Total Loss: 2.3906569480895996 | KNN Loss: 2.3863706588745117 | CLS Loss: 0.0042863208800554276
Epoch 195 / 200 | iteration 140 / 171 | Total Loss: 2.378544330596924 | KNN Loss: 2.371440887451172 | CLS Loss: 0.007103500887751579
Epoch 195 / 200 | iteration 150 / 171 | Total Loss: 2.413036346435547 | KNN Loss: 2.37988018989563 | CLS Loss: 0.03315627574920654
Epoch 195 / 200 | iteration 160 / 171 | Total Loss: 2.413043975830078 | KNN Loss: 2.409041166305542 | CLS Loss: 0.004002727568149567
Epoch 195 / 200 | iteration 170 / 171 | Total Loss: 2.436136245727539 | KNN Loss: 2.4286253452301025 | CLS Loss: 0.007510796654969454
Epoch: 195, Loss: 2.4133, Train: 0.9975, Valid: 0.9863, Best: 0.9874
Epoch 196 / 200 | iteration 0 / 171 | Total Loss: 2.405136823654175 | KNN Loss: 2.396449089050293 | CLS Loss: 0.008687762543559074

Epoch: 198, Loss: 2.4146, Train: 0.9965, Valid: 0.9855, Best: 0.9874
Epoch 199 / 200 | iteration 0 / 171 | Total Loss: 2.458559513092041 | KNN Loss: 2.4451005458831787 | CLS Loss: 0.013458899222314358
Epoch 199 / 200 | iteration 10 / 171 | Total Loss: 2.393702745437622 | KNN Loss: 2.3843774795532227 | CLS Loss: 0.009325203485786915
Epoch 199 / 200 | iteration 20 / 171 | Total Loss: 2.411714792251587 | KNN Loss: 2.408041477203369 | CLS Loss: 0.0036733821034431458
Epoch 199 / 200 | iteration 30 / 171 | Total Loss: 2.4443836212158203 | KNN Loss: 2.425159454345703 | CLS Loss: 0.01922428421676159
Epoch 199 / 200 | iteration 40 / 171 | Total Loss: 2.3968935012817383 | KNN Loss: 2.3940272331237793 | CLS Loss: 0.002866328228265047
Epoch 199 / 200 | iteration 50 / 171 | Total Loss: 2.4277451038360596 | KNN Loss: 2.412027359008789 | CLS Loss: 0.01571783795952797
Epoch 199 / 200 | iteration 60 / 171 | Total Loss: 2.4099137783050537 | KNN Loss: 2.4005932807922363 | CLS Loss: 0.009320582263171673
E

In [8]:
test(model, test_data_iter, device)

tensor(0.9856, device='cuda:0')

In [9]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=43.0), HTML(value='')))




In [23]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

KeyboardInterrupt: 

In [25]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [26]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.8591658672513819


In [27]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [28]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [29]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [30]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [31]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [32]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [33]:
losses = []
accs = []
sparsity = []

In [None]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 037 | Total loss: 3.313 | Reg loss: 0.014 | Tree loss: 3.313 | Accuracy: 0.033203 | 7.816 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 01 | Batch: 000 / 037 | Total loss: 3.305 | Reg loss: 0.005 | Tree loss: 3.305 | Accuracy: 0.060547 | 5.372 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
l

Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 18 | Batch: 000 / 037 | Total loss: 3.137 | Reg loss: 0.015 | Tree loss: 3.137 | Accuracy: 0.289062 | 5.258 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 19 | Batch: 000 / 037 | Total loss: 3.138 | Reg loss: 0.015 | Tree loss: 3.138 | Accuracy: 0.269531 | 5.261 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.984042

layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 35 | Batch: 000 / 037 | Total loss: 3.127 | Reg loss: 0.018 | Tree loss: 3.127 | Accuracy: 0.255859 | 5.227 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 36 | Batch: 000 / 037 | Total loss: 3.079 | Reg loss: 0.018 | Tree loss: 3.079 | Accuracy: 0.296875 | 5.226 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895

In [None]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
print(f"Number of patterns: {len(root.get_leaves())}")

In [None]:
method = 'greedy'

In [None]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")