In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 8
tree_depth = 8
batch_size = 512
device = 'cuda'
train_data_path = r'/mnt/qnap/ekosman/mitbih_train.csv'
test_data_path = r'/mnt/qnap/ekosman/mitbih_test.csv'

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.258726119995117 | KNN Loss: 5.58424711227417 | CLS Loss: 1.6744791269302368
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.858896255493164 | KNN Loss: 4.17528772354126 | CLS Loss: 0.68360835313797
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 3.323235511779785 | KNN Loss: 2.731367588043213 | CLS Loss: 0.5918678641319275
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 3.1718411445617676 | KNN Loss: 2.56550669670105 | CLS Loss: 0.606334388256073
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.0967538356781006 | KNN Loss: 2.5365593433380127 | CLS Loss: 0.5601944327354431
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 2.9649579524993896 | KNN Loss: 2.501180410385132 | CLS Loss: 0.46377745270729065
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 3.0078890323638916 | KNN Loss: 2.502568006515503 | CLS Loss: 0.5053209662437439
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 2.9528684616088867 | KNN Loss: 2.4429826736450195 | CLS 

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 2.566361427307129 | KNN Loss: 2.463594675064087 | CLS Loss: 0.1027667224407196
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 2.593855142593384 | KNN Loss: 2.463956356048584 | CLS Loss: 0.12989874184131622
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 2.6004343032836914 | KNN Loss: 2.4911108016967773 | CLS Loss: 0.10932353138923645
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 2.61438250541687 | KNN Loss: 2.477689266204834 | CLS Loss: 0.13669322431087494
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 2.604036569595337 | KNN Loss: 2.478548049926758 | CLS Loss: 0.12548844516277313
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 2.550182819366455 | KNN Loss: 2.457052230834961 | CLS Loss: 0.09313064813613892
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 2.5695672035217285 | KNN Loss: 2.4919986724853516 | CLS Loss: 0.07756859809160233
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 2.602182149887085 | KNN Loss: 2.46675968

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 2.5054526329040527 | KNN Loss: 2.4551796913146973 | CLS Loss: 0.050273049622774124
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 2.5363872051239014 | KNN Loss: 2.4591224193573 | CLS Loss: 0.0772646889090538
Epoch: 007, Loss: 2.5362, Train: 0.9793, Valid: 0.9762, Best: 0.9762
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 2.514387607574463 | KNN Loss: 2.455143690109253 | CLS Loss: 0.05924391746520996
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 2.5086286067962646 | KNN Loss: 2.436633348464966 | CLS Loss: 0.07199514657258987
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 2.554171562194824 | KNN Loss: 2.47037935256958 | CLS Loss: 0.08379209041595459
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 2.537275791168213 | KNN Loss: 2.4291815757751465 | CLS Loss: 0.10809414088726044
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 2.5084729194641113 | KNN Loss: 2.450446367263794 | CLS Loss: 0.05802658945322037
Epoch 8 / 200 | iteratio

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 2.452906370162964 | KNN Loss: 2.4076926708221436 | CLS Loss: 0.04521365463733673
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 2.5393195152282715 | KNN Loss: 2.4924399852752686 | CLS Loss: 0.046879444271326065
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 2.507828712463379 | KNN Loss: 2.4357590675354004 | CLS Loss: 0.07206976413726807
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 2.4950079917907715 | KNN Loss: 2.415837287902832 | CLS Loss: 0.07917068898677826
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 2.5553457736968994 | KNN Loss: 2.4509036540985107 | CLS Loss: 0.10444211214780807
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 2.5299899578094482 | KNN Loss: 2.4251089096069336 | CLS Loss: 0.10488102585077286
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 2.478713274002075 | KNN Loss: 2.441230297088623 | CLS Loss: 0.03748295456171036
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 2.5011115074157715 | KNN Lo

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 2.450160503387451 | KNN Loss: 2.4191582202911377 | CLS Loss: 0.031002182513475418
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 2.490746021270752 | KNN Loss: 2.4164295196533203 | CLS Loss: 0.0743165984749794
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 2.510970115661621 | KNN Loss: 2.420605421066284 | CLS Loss: 0.09036476165056229
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 2.470524311065674 | KNN Loss: 2.4175517559051514 | CLS Loss: 0.05297257751226425
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 2.4557316303253174 | KNN Loss: 2.3701064586639404 | CLS Loss: 0.08562511205673218
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 2.411316156387329 | KNN Loss: 2.3702523708343506 | CLS Loss: 0.04106369614601135
Epoch: 014, Loss: 2.4797, Train: 0.9843, Valid: 0.9812, Best: 0.9812
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 2.4774720668792725 | KNN Loss: 2.4158856868743896 | CLS Loss: 0.06158629059791565
Epoch 15

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 2.4552762508392334 | KNN Loss: 2.4185264110565186 | CLS Loss: 0.03674991801381111
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 2.494626998901367 | KNN Loss: 2.4380931854248047 | CLS Loss: 0.05653371661901474
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 2.4752871990203857 | KNN Loss: 2.4012084007263184 | CLS Loss: 0.07407869398593903
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 2.440117597579956 | KNN Loss: 2.378377914428711 | CLS Loss: 0.0617397204041481
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 2.45060133934021 | KNN Loss: 2.3973543643951416 | CLS Loss: 0.05324699729681015
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 2.4633748531341553 | KNN Loss: 2.388291597366333 | CLS Loss: 0.07508321106433868
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 2.4639015197753906 | KNN Loss: 2.3991198539733887 | CLS Loss: 0.06478168070316315
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 2.5040483474731445 | KNN Loss: 2.

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 2.4143147468566895 | KNN Loss: 2.367396116256714 | CLS Loss: 0.046918682754039764
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 2.436619281768799 | KNN Loss: 2.3937363624572754 | CLS Loss: 0.04288286715745926
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 2.466261386871338 | KNN Loss: 2.3826353549957275 | CLS Loss: 0.08362608402967453
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 2.4562079906463623 | KNN Loss: 2.3994433879852295 | CLS Loss: 0.05676458030939102
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 2.4521865844726562 | KNN Loss: 2.413733720779419 | CLS Loss: 0.03845291957259178
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 2.427682638168335 | KNN Loss: 2.398148775100708 | CLS Loss: 0.02953382395207882
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 2.4180171489715576 | KNN Loss: 2.356450319290161 | CLS Loss: 0.06156686693429947
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 2.4465363025665283 | KNN L

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 2.4299185276031494 | KNN Loss: 2.367173671722412 | CLS Loss: 0.06274493783712387
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 2.4671168327331543 | KNN Loss: 2.4222846031188965 | CLS Loss: 0.044832292944192886
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 2.41422176361084 | KNN Loss: 2.390052556991577 | CLS Loss: 0.02416921779513359
Epoch: 024, Loss: 2.4477, Train: 0.9889, Valid: 0.9840, Best: 0.9841
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 2.4431638717651367 | KNN Loss: 2.4172558784484863 | CLS Loss: 0.025907965376973152
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 2.4306464195251465 | KNN Loss: 2.3835816383361816 | CLS Loss: 0.04706469550728798
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 2.469083547592163 | KNN Loss: 2.4303953647613525 | CLS Loss: 0.038688212633132935
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 2.503422737121582 | KNN Loss: 2.439587354660034 | CLS Loss: 0.06383540481328964
Epoch 25 

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 2.4749159812927246 | KNN Loss: 2.421358585357666 | CLS Loss: 0.05355742946267128
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 2.46362042427063 | KNN Loss: 2.4107449054718018 | CLS Loss: 0.05287555232644081
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 2.4160211086273193 | KNN Loss: 2.3896467685699463 | CLS Loss: 0.02637444995343685
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 2.4543700218200684 | KNN Loss: 2.4069278240203857 | CLS Loss: 0.04744216427206993
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 2.4510867595672607 | KNN Loss: 2.4124388694763184 | CLS Loss: 0.03864789009094238
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 2.469883918762207 | KNN Loss: 2.439171075820923 | CLS Loss: 0.030712760984897614
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 2.4902870655059814 | KNN Loss: 2.364896535873413 | CLS Loss: 0.12539047002792358
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 2.476637363433838 | KNN Loss:

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 2.4378864765167236 | KNN Loss: 2.408557415008545 | CLS Loss: 0.0293289665132761
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 2.42397403717041 | KNN Loss: 2.4074039459228516 | CLS Loss: 0.016570014879107475
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 2.4283604621887207 | KNN Loss: 2.373887538909912 | CLS Loss: 0.05447299778461456
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 2.4238345623016357 | KNN Loss: 2.387714385986328 | CLS Loss: 0.03612016886472702
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 2.3981339931488037 | KNN Loss: 2.3370521068573 | CLS Loss: 0.06108184903860092
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 2.4838526248931885 | KNN Loss: 2.450967311859131 | CLS Loss: 0.032885342836380005
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 2.430816411972046 | KNN Loss: 2.388277053833008 | CLS Loss: 0.04253930225968361
Epoch: 031, Loss: 2.4371, Train: 0.9897, Valid: 0.9843, Best: 0.9857
Epoch 32 /

Epoch: 034, Loss: 2.4354, Train: 0.9912, Valid: 0.9857, Best: 0.9859
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 2.4314017295837402 | KNN Loss: 2.416856288909912 | CLS Loss: 0.014545371755957603
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 2.416207790374756 | KNN Loss: 2.38043212890625 | CLS Loss: 0.035775549709796906
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 2.4131674766540527 | KNN Loss: 2.3559951782226562 | CLS Loss: 0.05717229098081589
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 2.433051586151123 | KNN Loss: 2.4058895111083984 | CLS Loss: 0.0271621011197567
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 2.4639055728912354 | KNN Loss: 2.4469003677368164 | CLS Loss: 0.01700526475906372
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 2.4514620304107666 | KNN Loss: 2.4148125648498535 | CLS Loss: 0.03664937987923622
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 2.4326627254486084 | KNN Loss: 2.394394636154175 | CLS Loss: 0.03826816752552986
Epoch 35 / 20

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 2.418576717376709 | KNN Loss: 2.3969714641571045 | CLS Loss: 0.021605150774121284
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 2.3996236324310303 | KNN Loss: 2.383439302444458 | CLS Loss: 0.016184283420443535
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 2.386899948120117 | KNN Loss: 2.3661489486694336 | CLS Loss: 0.020750973373651505
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 2.418302297592163 | KNN Loss: 2.3895363807678223 | CLS Loss: 0.028765907511115074
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 2.4517104625701904 | KNN Loss: 2.4021573066711426 | CLS Loss: 0.04955307021737099
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 2.427091360092163 | KNN Loss: 2.375933885574341 | CLS Loss: 0.05115745589137077
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 2.4216666221618652 | KNN Loss: 2.3800528049468994 | CLS Loss: 0.04161386936903
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 2.3962554931640625 | KNN L

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 2.415483236312866 | KNN Loss: 2.398831844329834 | CLS Loss: 0.016651304438710213
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 2.4473180770874023 | KNN Loss: 2.4126017093658447 | CLS Loss: 0.03471630439162254
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 2.50728702545166 | KNN Loss: 2.4508509635925293 | CLS Loss: 0.05643594264984131
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 2.430040121078491 | KNN Loss: 2.4011049270629883 | CLS Loss: 0.028935085982084274
Epoch: 041, Loss: 2.4266, Train: 0.9927, Valid: 0.9866, Best: 0.9866
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 2.4634575843811035 | KNN Loss: 2.417475938796997 | CLS Loss: 0.04598162695765495
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 2.409532308578491 | KNN Loss: 2.361368179321289 | CLS Loss: 0.04816414788365364
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 2.414159059524536 | KNN Loss: 2.388326406478882 | CLS Loss: 0.025832761079072952
Epoch 42 / 

Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 2.4263980388641357 | KNN Loss: 2.404324769973755 | CLS Loss: 0.02207331731915474
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 2.4409589767456055 | KNN Loss: 2.4092812538146973 | CLS Loss: 0.0316777341067791
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 2.444166898727417 | KNN Loss: 2.422528028488159 | CLS Loss: 0.02163882553577423
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 2.437716007232666 | KNN Loss: 2.41519832611084 | CLS Loss: 0.022517703473567963
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 2.4179019927978516 | KNN Loss: 2.3838284015655518 | CLS Loss: 0.034073520451784134
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 2.4071106910705566 | KNN Loss: 2.3628339767456055 | CLS Loss: 0.04427679628133774
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 2.4187171459198 | KNN Loss: 2.4087746143341064 | CLS Loss: 0.009942558594048023
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 2.430152177810669 | KNN Loss: 2.

Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 2.4047234058380127 | KNN Loss: 2.3777356147766113 | CLS Loss: 0.02698773890733719
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 2.447726249694824 | KNN Loss: 2.37937068939209 | CLS Loss: 0.0683554857969284
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 2.427299737930298 | KNN Loss: 2.411612033843994 | CLS Loss: 0.015687694773077965
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 2.4396004676818848 | KNN Loss: 2.41352915763855 | CLS Loss: 0.026071222499012947
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 2.4562788009643555 | KNN Loss: 2.431913137435913 | CLS Loss: 0.024365771561861038
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 2.4486563205718994 | KNN Loss: 2.4196228981018066 | CLS Loss: 0.029033519327640533
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 2.451486825942993 | KNN Loss: 2.4369468688964844 | CLS Loss: 0.01454001571983099
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 2.422532320022583 | KNN 

Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 2.4168429374694824 | KNN Loss: 2.3775289058685303 | CLS Loss: 0.03931400924921036
Epoch: 051, Loss: 2.4350, Train: 0.9916, Valid: 0.9838, Best: 0.9866
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 2.4293642044067383 | KNN Loss: 2.399405002593994 | CLS Loss: 0.02995912730693817
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 2.453402280807495 | KNN Loss: 2.4334073066711426 | CLS Loss: 0.019994977861642838
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 2.436356782913208 | KNN Loss: 2.418696165084839 | CLS Loss: 0.017660507932305336
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 2.402538299560547 | KNN Loss: 2.3975844383239746 | CLS Loss: 0.004953878931701183
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 2.4394915103912354 | KNN Loss: 2.405895709991455 | CLS Loss: 0.03359571099281311
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 2.446819543838501 | KNN Loss: 2.406564474105835 | CLS Loss: 0.04025515168905258
Epoch 52 / 2

Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 2.4255211353302 | KNN Loss: 2.4016261100769043 | CLS Loss: 0.02389501966536045
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 2.4023327827453613 | KNN Loss: 2.3707354068756104 | CLS Loss: 0.031597379595041275
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 2.4490511417388916 | KNN Loss: 2.4195799827575684 | CLS Loss: 0.02947109378874302
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 2.3745996952056885 | KNN Loss: 2.3657288551330566 | CLS Loss: 0.008870908990502357
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 2.4290733337402344 | KNN Loss: 2.396984338760376 | CLS Loss: 0.032088905572891235
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 2.4100141525268555 | KNN Loss: 2.389298439025879 | CLS Loss: 0.020715629681944847
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 2.4547533988952637 | KNN Loss: 2.4191441535949707 | CLS Loss: 0.03560929372906685
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 2.51659893989563 | KNN 

Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 2.4260196685791016 | KNN Loss: 2.3995072841644287 | CLS Loss: 0.02651236578822136
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 2.4658007621765137 | KNN Loss: 2.403367519378662 | CLS Loss: 0.06243321672081947
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 2.4302268028259277 | KNN Loss: 2.4152133464813232 | CLS Loss: 0.015013477765023708
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 2.4097259044647217 | KNN Loss: 2.390004873275757 | CLS Loss: 0.019720951095223427
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 2.4610517024993896 | KNN Loss: 2.426239252090454 | CLS Loss: 0.03481238707900047
Epoch: 058, Loss: 2.4313, Train: 0.9942, Valid: 0.9867, Best: 0.9867
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 2.4272241592407227 | KNN Loss: 2.4211552143096924 | CLS Loss: 0.006068967282772064
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 2.4456770420074463 | KNN Loss: 2.412442445755005 | CLS Loss: 0.033234648406505585
Epo

Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 2.4671196937561035 | KNN Loss: 2.420119285583496 | CLS Loss: 0.04700038209557533
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 2.437140941619873 | KNN Loss: 2.4040844440460205 | CLS Loss: 0.033056557178497314
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 2.448408842086792 | KNN Loss: 2.420020818710327 | CLS Loss: 0.02838806062936783
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 2.4199583530426025 | KNN Loss: 2.4121463298797607 | CLS Loss: 0.00781198451295495
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 2.4052696228027344 | KNN Loss: 2.3479018211364746 | CLS Loss: 0.0573677234351635
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 2.3793725967407227 | KNN Loss: 2.3718619346618652 | CLS Loss: 0.0075105950236320496
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 2.468064785003662 | KNN Loss: 2.4284355640411377 | CLS Loss: 0.03962932154536247
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 2.419922351837158 | KNN Loss:

Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 2.467092275619507 | KNN Loss: 2.4481406211853027 | CLS Loss: 0.01895175129175186
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 2.4158260822296143 | KNN Loss: 2.3903536796569824 | CLS Loss: 0.02547239512205124
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 2.458944797515869 | KNN Loss: 2.4195303916931152 | CLS Loss: 0.03941429406404495
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 2.449298620223999 | KNN Loss: 2.4285452365875244 | CLS Loss: 0.02075330540537834
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 2.4389209747314453 | KNN Loss: 2.3916969299316406 | CLS Loss: 0.047224123030900955
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 2.41792368888855 | KNN Loss: 2.4040770530700684 | CLS Loss: 0.013846532441675663
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 2.452256441116333 | KNN Loss: 2.4081647396087646 | CLS Loss: 0.044091612100601196
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 2.452345132827759 | KNN

Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 2.4085986614227295 | KNN Loss: 2.40266489982605 | CLS Loss: 0.005933693610131741
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 2.4361820220947266 | KNN Loss: 2.4153876304626465 | CLS Loss: 0.020794417709112167
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 2.4543840885162354 | KNN Loss: 2.407576084136963 | CLS Loss: 0.046807993203401566
Epoch: 068, Loss: 2.4283, Train: 0.9928, Valid: 0.9839, Best: 0.9870
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 2.431929349899292 | KNN Loss: 2.410444974899292 | CLS Loss: 0.02148429863154888
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 2.4101815223693848 | KNN Loss: 2.3824336528778076 | CLS Loss: 0.027747903019189835
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 2.4099998474121094 | KNN Loss: 2.38503098487854 | CLS Loss: 0.024968847632408142
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 2.42826247215271 | KNN Loss: 2.411430597305298 | CLS Loss: 0.01683192141354084
Epoch 69 /

Epoch 72 / 200 | iteration 30 / 171 | Total Loss: 2.4147679805755615 | KNN Loss: 2.368011474609375 | CLS Loss: 0.046756573021411896
Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 2.403822422027588 | KNN Loss: 2.3975296020507812 | CLS Loss: 0.0062927803955972195
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 2.426968574523926 | KNN Loss: 2.417637348175049 | CLS Loss: 0.00933121982961893
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 2.448246955871582 | KNN Loss: 2.4361321926116943 | CLS Loss: 0.012114799581468105
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 2.462702751159668 | KNN Loss: 2.4387154579162598 | CLS Loss: 0.02398722805082798
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 2.41074275970459 | KNN Loss: 2.394761323928833 | CLS Loss: 0.015981456264853477
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 2.4185707569122314 | KNN Loss: 2.3942296504974365 | CLS Loss: 0.02434120699763298
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 2.431295156478882 | KNN Loss:

Epoch 75 / 200 | iteration 100 / 171 | Total Loss: 2.4168057441711426 | KNN Loss: 2.3943166732788086 | CLS Loss: 0.022488955408334732
Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 2.412680149078369 | KNN Loss: 2.4000706672668457 | CLS Loss: 0.012609390541911125
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 2.4174296855926514 | KNN Loss: 2.380970001220703 | CLS Loss: 0.036459799855947495
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 2.468350648880005 | KNN Loss: 2.4354817867279053 | CLS Loss: 0.03286897391080856
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 2.440643787384033 | KNN Loss: 2.392266273498535 | CLS Loss: 0.048377398401498795
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 2.45760440826416 | KNN Loss: 2.4144065380096436 | CLS Loss: 0.04319782555103302
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 2.4324522018432617 | KNN Loss: 2.4145588874816895 | CLS Loss: 0.017893211916089058
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 2.4566662311553955 |

Epoch 78 / 200 | iteration 170 / 171 | Total Loss: 2.4067165851593018 | KNN Loss: 2.3945562839508057 | CLS Loss: 0.012160374782979488
Epoch: 078, Loss: 2.4245, Train: 0.9948, Valid: 0.9874, Best: 0.9874
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 2.447366714477539 | KNN Loss: 2.4205687046051025 | CLS Loss: 0.026798121631145477
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 2.417611837387085 | KNN Loss: 2.394016742706299 | CLS Loss: 0.023595081642270088
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 2.380323886871338 | KNN Loss: 2.3669354915618896 | CLS Loss: 0.013388407416641712
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 2.40809965133667 | KNN Loss: 2.401171922683716 | CLS Loss: 0.006927793379873037
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 2.421306848526001 | KNN Loss: 2.4159114360809326 | CLS Loss: 0.0053954096511006355
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 2.435999870300293 | KNN Loss: 2.3997762203216553 | CLS Loss: 0.036223605275154114
Epoch 79

Epoch 82 / 200 | iteration 60 / 171 | Total Loss: 2.464536190032959 | KNN Loss: 2.4517765045166016 | CLS Loss: 0.012759581208229065
Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 2.3993563652038574 | KNN Loss: 2.390620231628418 | CLS Loss: 0.008736107498407364
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 2.4322257041931152 | KNN Loss: 2.407219648361206 | CLS Loss: 0.025006145238876343
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 2.423063278198242 | KNN Loss: 2.3959202766418457 | CLS Loss: 0.02714288979768753
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 2.4190518856048584 | KNN Loss: 2.3951735496520996 | CLS Loss: 0.023878393694758415
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 2.4245893955230713 | KNN Loss: 2.4150545597076416 | CLS Loss: 0.009534938260912895
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 2.4484574794769287 | KNN Loss: 2.4385032653808594 | CLS Loss: 0.009954111650586128
Epoch 82 / 200 | iteration 130 / 171 | Total Loss: 2.3965511322021484 |

Epoch 85 / 200 | iteration 130 / 171 | Total Loss: 2.414275884628296 | KNN Loss: 2.3914177417755127 | CLS Loss: 0.02285817824304104
Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 2.4685890674591064 | KNN Loss: 2.4525399208068848 | CLS Loss: 0.016049187630414963
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 2.4019265174865723 | KNN Loss: 2.3910160064697266 | CLS Loss: 0.010910623706877232
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 2.4272024631500244 | KNN Loss: 2.4029476642608643 | CLS Loss: 0.024254851043224335
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 2.4302279949188232 | KNN Loss: 2.3954596519470215 | CLS Loss: 0.03476835414767265
Epoch: 085, Loss: 2.4203, Train: 0.9946, Valid: 0.9865, Best: 0.9874
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 2.4484245777130127 | KNN Loss: 2.4126083850860596 | CLS Loss: 0.0358162522315979
Epoch 86 / 200 | iteration 10 / 171 | Total Loss: 2.4568562507629395 | KNN Loss: 2.4447968006134033 | CLS Loss: 0.012059363536536694
E

Epoch 89 / 200 | iteration 10 / 171 | Total Loss: 2.3781630992889404 | KNN Loss: 2.3645365238189697 | CLS Loss: 0.01362649817019701
Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 2.42268443107605 | KNN Loss: 2.4082818031311035 | CLS Loss: 0.01440261397510767
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 2.4165287017822266 | KNN Loss: 2.4078850746154785 | CLS Loss: 0.008643708191812038
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 2.432921886444092 | KNN Loss: 2.3960483074188232 | CLS Loss: 0.036873478442430496
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 2.4538512229919434 | KNN Loss: 2.4281489849090576 | CLS Loss: 0.025702303275465965
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 2.4155359268188477 | KNN Loss: 2.406480073928833 | CLS Loss: 0.009055943228304386
Epoch 89 / 200 | iteration 70 / 171 | Total Loss: 2.4191620349884033 | KNN Loss: 2.4032602310180664 | CLS Loss: 0.015901880338788033
Epoch 89 / 200 | iteration 80 / 171 | Total Loss: 2.419233798980713 | KNN L

Epoch 92 / 200 | iteration 80 / 171 | Total Loss: 2.4104998111724854 | KNN Loss: 2.3926427364349365 | CLS Loss: 0.017856977880001068
Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 2.4327399730682373 | KNN Loss: 2.4046895503997803 | CLS Loss: 0.028050504624843597
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 2.405285358428955 | KNN Loss: 2.3821730613708496 | CLS Loss: 0.023112183436751366
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 2.4403703212738037 | KNN Loss: 2.4008705615997314 | CLS Loss: 0.03949976712465286
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 2.4379982948303223 | KNN Loss: 2.42527174949646 | CLS Loss: 0.012726476415991783
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 2.4200901985168457 | KNN Loss: 2.4034669399261475 | CLS Loss: 0.016623152419924736
Epoch 92 / 200 | iteration 140 / 171 | Total Loss: 2.4229533672332764 | KNN Loss: 2.396937131881714 | CLS Loss: 0.02601628005504608
Epoch 92 / 200 | iteration 150 / 171 | Total Loss: 2.4088377952575684 

Epoch 95 / 200 | iteration 150 / 171 | Total Loss: 2.4187397956848145 | KNN Loss: 2.4157168865203857 | CLS Loss: 0.0030228826217353344
Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 2.4157612323760986 | KNN Loss: 2.3981504440307617 | CLS Loss: 0.017610689625144005
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 2.4149396419525146 | KNN Loss: 2.4077906608581543 | CLS Loss: 0.007148889359086752
Epoch: 095, Loss: 2.4165, Train: 0.9965, Valid: 0.9876, Best: 0.9878
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 2.4604711532592773 | KNN Loss: 2.4448139667510986 | CLS Loss: 0.01565719209611416
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 2.384208917617798 | KNN Loss: 2.3568758964538574 | CLS Loss: 0.027333088219165802
Epoch 96 / 200 | iteration 20 / 171 | Total Loss: 2.4307100772857666 | KNN Loss: 2.411741018295288 | CLS Loss: 0.018969163298606873
Epoch 96 / 200 | iteration 30 / 171 | Total Loss: 2.4435362815856934 | KNN Loss: 2.4034769535064697 | CLS Loss: 0.04005931317806244
E

Epoch 99 / 200 | iteration 30 / 171 | Total Loss: 2.4090805053710938 | KNN Loss: 2.3818883895874023 | CLS Loss: 0.027192117646336555
Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 2.4093453884124756 | KNN Loss: 2.4031081199645996 | CLS Loss: 0.006237211171537638
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 2.411881923675537 | KNN Loss: 2.4073455333709717 | CLS Loss: 0.004536332096904516
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 2.406984806060791 | KNN Loss: 2.3789916038513184 | CLS Loss: 0.02799312025308609
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 2.417468786239624 | KNN Loss: 2.405515670776367 | CLS Loss: 0.011953094974160194
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 2.3962228298187256 | KNN Loss: 2.3719875812530518 | CLS Loss: 0.024235185235738754
Epoch 99 / 200 | iteration 90 / 171 | Total Loss: 2.397979497909546 | KNN Loss: 2.3949782848358154 | CLS Loss: 0.003001247765496373
Epoch 99 / 200 | iteration 100 / 171 | Total Loss: 2.3935914039611816 | KNN

Epoch 102 / 200 | iteration 100 / 171 | Total Loss: 2.4281089305877686 | KNN Loss: 2.401885747909546 | CLS Loss: 0.026223143562674522
Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 2.4158198833465576 | KNN Loss: 2.399688720703125 | CLS Loss: 0.016131170094013214
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 2.4092230796813965 | KNN Loss: 2.4031050205230713 | CLS Loss: 0.006118148099631071
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 2.4101831912994385 | KNN Loss: 2.404726505279541 | CLS Loss: 0.0054567172192037106
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 2.4063880443573 | KNN Loss: 2.4006831645965576 | CLS Loss: 0.005704963114112616
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 2.4285666942596436 | KNN Loss: 2.38278865814209 | CLS Loss: 0.04577803984284401
Epoch 102 / 200 | iteration 160 / 171 | Total Loss: 2.440495491027832 | KNN Loss: 2.4072861671447754 | CLS Loss: 0.03320937231183052
Epoch 102 / 200 | iteration 170 / 171 | Total Loss: 2.41416239738

Epoch 105 / 200 | iteration 160 / 171 | Total Loss: 2.4261093139648438 | KNN Loss: 2.40960693359375 | CLS Loss: 0.016502458602190018
Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 2.4165689945220947 | KNN Loss: 2.402536392211914 | CLS Loss: 0.014032592065632343
Epoch: 105, Loss: 2.4133, Train: 0.9949, Valid: 0.9855, Best: 0.9878
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 2.4064979553222656 | KNN Loss: 2.3946051597595215 | CLS Loss: 0.011892887763679028
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 2.4008901119232178 | KNN Loss: 2.387813091278076 | CLS Loss: 0.013077130541205406
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 2.3920645713806152 | KNN Loss: 2.3680689334869385 | CLS Loss: 0.02399558201432228
Epoch 106 / 200 | iteration 30 / 171 | Total Loss: 2.384657144546509 | KNN Loss: 2.3551809787750244 | CLS Loss: 0.02947618067264557
Epoch 106 / 200 | iteration 40 / 171 | Total Loss: 2.411860466003418 | KNN Loss: 2.399411916732788 | CLS Loss: 0.012448485940694809
E

Epoch 109 / 200 | iteration 40 / 171 | Total Loss: 2.414196491241455 | KNN Loss: 2.396958589553833 | CLS Loss: 0.01723797619342804
Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 2.3666443824768066 | KNN Loss: 2.3635504245758057 | CLS Loss: 0.0030940205324441195
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 2.4171595573425293 | KNN Loss: 2.409950017929077 | CLS Loss: 0.007209639996290207
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 2.4189882278442383 | KNN Loss: 2.4077250957489014 | CLS Loss: 0.011263096705079079
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 2.394908905029297 | KNN Loss: 2.386657238006592 | CLS Loss: 0.008251691237092018
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 2.4291090965270996 | KNN Loss: 2.406256914138794 | CLS Loss: 0.02285226620733738
Epoch 109 / 200 | iteration 100 / 171 | Total Loss: 2.4501631259918213 | KNN Loss: 2.4087367057800293 | CLS Loss: 0.0414264015853405
Epoch 109 / 200 | iteration 110 / 171 | Total Loss: 2.443814516067505 

Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 2.4357142448425293 | KNN Loss: 2.4121744632720947 | CLS Loss: 0.02353987656533718
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 2.403343915939331 | KNN Loss: 2.3875319957733154 | CLS Loss: 0.015811841934919357
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 2.3748886585235596 | KNN Loss: 2.356793165206909 | CLS Loss: 0.018095415085554123
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 2.41402530670166 | KNN Loss: 2.3978939056396484 | CLS Loss: 0.016131456941366196
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 2.4091620445251465 | KNN Loss: 2.3824219703674316 | CLS Loss: 0.026739979162812233
Epoch 112 / 200 | iteration 160 / 171 | Total Loss: 2.3840346336364746 | KNN Loss: 2.365967035293579 | CLS Loss: 0.01806754246354103
Epoch 112 / 200 | iteration 170 / 171 | Total Loss: 2.439099073410034 | KNN Loss: 2.4266445636749268 | CLS Loss: 0.012454397976398468
Epoch: 112, Loss: 2.4169, Train: 0.9960, Valid: 0.9865, Best: 0

Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 2.438997507095337 | KNN Loss: 2.4319984912872314 | CLS Loss: 0.006999124772846699
Epoch: 115, Loss: 2.4173, Train: 0.9968, Valid: 0.9862, Best: 0.9878
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 2.3950066566467285 | KNN Loss: 2.376633882522583 | CLS Loss: 0.01837271638214588
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 2.406200408935547 | KNN Loss: 2.4019083976745605 | CLS Loss: 0.004291938152164221
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 2.3948638439178467 | KNN Loss: 2.391629457473755 | CLS Loss: 0.003234376199543476
Epoch 116 / 200 | iteration 30 / 171 | Total Loss: 2.4062178134918213 | KNN Loss: 2.4028382301330566 | CLS Loss: 0.003379630157724023
Epoch 116 / 200 | iteration 40 / 171 | Total Loss: 2.3929810523986816 | KNN Loss: 2.386455535888672 | CLS Loss: 0.006525494623929262
Epoch 116 / 200 | iteration 50 / 171 | Total Loss: 2.4133455753326416 | KNN Loss: 2.404627799987793 | CLS Loss: 0.008717675693333149


Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 2.4119420051574707 | KNN Loss: 2.399339199066162 | CLS Loss: 0.012602746486663818
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 2.445822238922119 | KNN Loss: 2.4344425201416016 | CLS Loss: 0.011379708535969257
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 2.4564578533172607 | KNN Loss: 2.422043800354004 | CLS Loss: 0.03441406041383743
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 2.399264335632324 | KNN Loss: 2.3899075984954834 | CLS Loss: 0.009356831200420856
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 2.4069035053253174 | KNN Loss: 2.396620512008667 | CLS Loss: 0.010282972827553749
Epoch 119 / 200 | iteration 100 / 171 | Total Loss: 2.428077220916748 | KNN Loss: 2.419060707092285 | CLS Loss: 0.009016572497785091
Epoch 119 / 200 | iteration 110 / 171 | Total Loss: 2.4236481189727783 | KNN Loss: 2.4158525466918945 | CLS Loss: 0.007795600686222315
Epoch 119 / 200 | iteration 120 / 171 | Total Loss: 2.41658091545105

Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 2.38179087638855 | KNN Loss: 2.376298666000366 | CLS Loss: 0.005492095369845629
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 2.4388599395751953 | KNN Loss: 2.428873300552368 | CLS Loss: 0.00998655054718256
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 2.4287333488464355 | KNN Loss: 2.421967029571533 | CLS Loss: 0.006766209844499826
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 2.4204506874084473 | KNN Loss: 2.393866777420044 | CLS Loss: 0.026583993807435036
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 2.4248270988464355 | KNN Loss: 2.4154984951019287 | CLS Loss: 0.009328568354249
Epoch 122 / 200 | iteration 160 / 171 | Total Loss: 2.423283100128174 | KNN Loss: 2.4173009395599365 | CLS Loss: 0.005982252303510904
Epoch 122 / 200 | iteration 170 / 171 | Total Loss: 2.3937227725982666 | KNN Loss: 2.3762879371643066 | CLS Loss: 0.017434820532798767
Epoch: 122, Loss: 2.4127, Train: 0.9970, Valid: 0.9874, Best: 0.98

Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 2.4161996841430664 | KNN Loss: 2.414077043533325 | CLS Loss: 0.002122635720297694
Epoch: 125, Loss: 2.4119, Train: 0.9972, Valid: 0.9874, Best: 0.9878
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 2.3918304443359375 | KNN Loss: 2.3797755241394043 | CLS Loss: 0.012054921127855778
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 2.408883571624756 | KNN Loss: 2.3956170082092285 | CLS Loss: 0.01326647587120533
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 2.454699754714966 | KNN Loss: 2.4398210048675537 | CLS Loss: 0.014878763817250729
Epoch 126 / 200 | iteration 30 / 171 | Total Loss: 2.405946731567383 | KNN Loss: 2.4010908603668213 | CLS Loss: 0.0048558032140135765
Epoch 126 / 200 | iteration 40 / 171 | Total Loss: 2.395657777786255 | KNN Loss: 2.3926138877868652 | CLS Loss: 0.0030439484398812056
Epoch 126 / 200 | iteration 50 / 171 | Total Loss: 2.428980588912964 | KNN Loss: 2.418198585510254 | CLS Loss: 0.01078200712800026


Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 2.4026544094085693 | KNN Loss: 2.378816604614258 | CLS Loss: 0.02383785881102085
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 2.40240478515625 | KNN Loss: 2.398435115814209 | CLS Loss: 0.00396970147266984
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 2.393235683441162 | KNN Loss: 2.3899266719818115 | CLS Loss: 0.0033089183270931244
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 2.403447389602661 | KNN Loss: 2.400646686553955 | CLS Loss: 0.0028006487991660833
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 2.392164945602417 | KNN Loss: 2.3875296115875244 | CLS Loss: 0.004635377787053585
Epoch 129 / 200 | iteration 100 / 171 | Total Loss: 2.400472402572632 | KNN Loss: 2.3831539154052734 | CLS Loss: 0.01731852814555168
Epoch 129 / 200 | iteration 110 / 171 | Total Loss: 2.411667585372925 | KNN Loss: 2.401001214981079 | CLS Loss: 0.01066632755100727
Epoch 129 / 200 | iteration 120 / 171 | Total Loss: 2.4276723861694336 | 

Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 2.390885591506958 | KNN Loss: 2.3725316524505615 | CLS Loss: 0.01835402473807335
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 2.434852361679077 | KNN Loss: 2.430737257003784 | CLS Loss: 0.004115118645131588
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 2.4231197834014893 | KNN Loss: 2.4098620414733887 | CLS Loss: 0.013257824815809727
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 2.4707040786743164 | KNN Loss: 2.4606776237487793 | CLS Loss: 0.010026425123214722
Epoch 132 / 200 | iteration 160 / 171 | Total Loss: 2.430321455001831 | KNN Loss: 2.4128646850585938 | CLS Loss: 0.01745671033859253
Epoch 132 / 200 | iteration 170 / 171 | Total Loss: 2.4235169887542725 | KNN Loss: 2.395570993423462 | CLS Loss: 0.02794599160552025
Epoch: 132, Loss: 2.4201, Train: 0.9955, Valid: 0.9851, Best: 0.9878
Epoch 133 / 200 | iteration 0 / 171 | Total Loss: 2.4560294151306152 | KNN Loss: 2.437328338623047 | CLS Loss: 0.0187010429799556

Epoch: 135, Loss: 2.4134, Train: 0.9968, Valid: 0.9870, Best: 0.9878
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 2.4188365936279297 | KNN Loss: 2.412299871444702 | CLS Loss: 0.006536669097840786
Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 2.473193407058716 | KNN Loss: 2.445981025695801 | CLS Loss: 0.027212362736463547
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 2.404040813446045 | KNN Loss: 2.3996236324310303 | CLS Loss: 0.004417239222675562
Epoch 136 / 200 | iteration 30 / 171 | Total Loss: 2.3975908756256104 | KNN Loss: 2.381883382797241 | CLS Loss: 0.015707453712821007
Epoch 136 / 200 | iteration 40 / 171 | Total Loss: 2.4638140201568604 | KNN Loss: 2.456021547317505 | CLS Loss: 0.0077923936769366264
Epoch 136 / 200 | iteration 50 / 171 | Total Loss: 2.3968091011047363 | KNN Loss: 2.3684298992156982 | CLS Loss: 0.02837924100458622
Epoch 136 / 200 | iteration 60 / 171 | Total Loss: 2.415238618850708 | KNN Loss: 2.3944485187530518 | CLS Loss: 0.020790155977010727
E

Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 2.4090769290924072 | KNN Loss: 2.384261131286621 | CLS Loss: 0.02481582760810852
Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 2.4444634914398193 | KNN Loss: 2.430837869644165 | CLS Loss: 0.013625568710267544
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 2.4055309295654297 | KNN Loss: 2.3918895721435547 | CLS Loss: 0.01364132296293974
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 2.460820436477661 | KNN Loss: 2.4348654747009277 | CLS Loss: 0.025955067947506905
Epoch 139 / 200 | iteration 100 / 171 | Total Loss: 2.405714511871338 | KNN Loss: 2.402527332305908 | CLS Loss: 0.003187081078067422
Epoch 139 / 200 | iteration 110 / 171 | Total Loss: 2.469928741455078 | KNN Loss: 2.4643123149871826 | CLS Loss: 0.00561648840084672
Epoch 139 / 200 | iteration 120 / 171 | Total Loss: 2.440058946609497 | KNN Loss: 2.41267728805542 | CLS Loss: 0.027381733059883118
Epoch 139 / 200 | iteration 130 / 171 | Total Loss: 2.4092459678649902 

Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 2.3910398483276367 | KNN Loss: 2.3744266033172607 | CLS Loss: 0.016613297164440155
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 2.3794608116149902 | KNN Loss: 2.372133493423462 | CLS Loss: 0.0073273335583508015
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 2.412013530731201 | KNN Loss: 2.398181676864624 | CLS Loss: 0.013831904157996178
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 2.4066977500915527 | KNN Loss: 2.3948781490325928 | CLS Loss: 0.011819511651992798
Epoch 142 / 200 | iteration 160 / 171 | Total Loss: 2.4306023120880127 | KNN Loss: 2.4174585342407227 | CLS Loss: 0.013143709860742092
Epoch 142 / 200 | iteration 170 / 171 | Total Loss: 2.4174110889434814 | KNN Loss: 2.392571210861206 | CLS Loss: 0.024839825928211212
Epoch: 142, Loss: 2.4072, Train: 0.9970, Valid: 0.9860, Best: 0.9878
Epoch 143 / 200 | iteration 0 / 171 | Total Loss: 2.4113452434539795 | KNN Loss: 2.397935390472412 | CLS Loss: 0.01340986415

Epoch: 145, Loss: 2.4103, Train: 0.9968, Valid: 0.9858, Best: 0.9878
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 2.3801896572113037 | KNN Loss: 2.3722307682037354 | CLS Loss: 0.007958979345858097
Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 2.4068245887756348 | KNN Loss: 2.3983242511749268 | CLS Loss: 0.008500367403030396
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 2.4224398136138916 | KNN Loss: 2.4152138233184814 | CLS Loss: 0.0072260363958776
Epoch 146 / 200 | iteration 30 / 171 | Total Loss: 2.438340902328491 | KNN Loss: 2.4358150959014893 | CLS Loss: 0.0025257994420826435
Epoch 146 / 200 | iteration 40 / 171 | Total Loss: 2.3999216556549072 | KNN Loss: 2.389310359954834 | CLS Loss: 0.010611379519104958
Epoch 146 / 200 | iteration 50 / 171 | Total Loss: 2.450612783432007 | KNN Loss: 2.4282572269439697 | CLS Loss: 0.022355644032359123
Epoch 146 / 200 | iteration 60 / 171 | Total Loss: 2.3785080909729004 | KNN Loss: 2.375030755996704 | CLS Loss: 0.003477226011455059

Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 2.403106927871704 | KNN Loss: 2.396996259689331 | CLS Loss: 0.006110745016485453
Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 2.418189525604248 | KNN Loss: 2.41595196723938 | CLS Loss: 0.002237651264294982
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 2.4031460285186768 | KNN Loss: 2.3978588581085205 | CLS Loss: 0.005287060979753733
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 2.4220457077026367 | KNN Loss: 2.394346237182617 | CLS Loss: 0.027699587866663933
Epoch 149 / 200 | iteration 100 / 171 | Total Loss: 2.3865294456481934 | KNN Loss: 2.363023042678833 | CLS Loss: 0.023506473749876022
Epoch 149 / 200 | iteration 110 / 171 | Total Loss: 2.3946444988250732 | KNN Loss: 2.391211748123169 | CLS Loss: 0.0034327416215091944
Epoch 149 / 200 | iteration 120 / 171 | Total Loss: 2.3653595447540283 | KNN Loss: 2.3491673469543457 | CLS Loss: 0.01619214005768299
Epoch 149 / 200 | iteration 130 / 171 | Total Loss: 2.4121408462524

Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 2.4020724296569824 | KNN Loss: 2.3929545879364014 | CLS Loss: 0.009117750450968742
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 2.4267237186431885 | KNN Loss: 2.4099233150482178 | CLS Loss: 0.01680043153464794
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 2.4244234561920166 | KNN Loss: 2.3945634365081787 | CLS Loss: 0.02985996939241886
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 2.40876841545105 | KNN Loss: 2.4021267890930176 | CLS Loss: 0.006641555577516556
Epoch 152 / 200 | iteration 160 / 171 | Total Loss: 2.4401607513427734 | KNN Loss: 2.437626361846924 | CLS Loss: 0.0025343128945678473
Epoch 152 / 200 | iteration 170 / 171 | Total Loss: 2.4408645629882812 | KNN Loss: 2.4125330448150635 | CLS Loss: 0.028331534937024117
Epoch: 152, Loss: 2.4121, Train: 0.9971, Valid: 0.9864, Best: 0.9878
Epoch 153 / 200 | iteration 0 / 171 | Total Loss: 2.415532112121582 | KNN Loss: 2.4092495441436768 | CLS Loss: 0.006282678339

Epoch: 155, Loss: 2.4068, Train: 0.9959, Valid: 0.9858, Best: 0.9878
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 2.4301488399505615 | KNN Loss: 2.416956663131714 | CLS Loss: 0.013192208483815193
Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 2.4006025791168213 | KNN Loss: 2.384650230407715 | CLS Loss: 0.01595238409936428
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 2.3919906616210938 | KNN Loss: 2.382678270339966 | CLS Loss: 0.009312466718256474
Epoch 156 / 200 | iteration 30 / 171 | Total Loss: 2.41636061668396 | KNN Loss: 2.3836965560913086 | CLS Loss: 0.03266414627432823
Epoch 156 / 200 | iteration 40 / 171 | Total Loss: 2.413684844970703 | KNN Loss: 2.409841775894165 | CLS Loss: 0.0038430080749094486
Epoch 156 / 200 | iteration 50 / 171 | Total Loss: 2.449755907058716 | KNN Loss: 2.4446308612823486 | CLS Loss: 0.005125163588672876
Epoch 156 / 200 | iteration 60 / 171 | Total Loss: 2.3922030925750732 | KNN Loss: 2.387766122817993 | CLS Loss: 0.004436911549419165
Epoc

Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 2.3926849365234375 | KNN Loss: 2.371378183364868 | CLS Loss: 0.021306656301021576
Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 2.419893503189087 | KNN Loss: 2.403132915496826 | CLS Loss: 0.01676054857671261
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 2.416393756866455 | KNN Loss: 2.4046285152435303 | CLS Loss: 0.011765331961214542
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 2.4224112033843994 | KNN Loss: 2.420743465423584 | CLS Loss: 0.0016677876701578498
Epoch 159 / 200 | iteration 100 / 171 | Total Loss: 2.430586576461792 | KNN Loss: 2.415789842605591 | CLS Loss: 0.014796807430684566
Epoch 159 / 200 | iteration 110 / 171 | Total Loss: 2.4244868755340576 | KNN Loss: 2.408724308013916 | CLS Loss: 0.01576252467930317
Epoch 159 / 200 | iteration 120 / 171 | Total Loss: 2.384993314743042 | KNN Loss: 2.3762307167053223 | CLS Loss: 0.008762693032622337
Epoch 159 / 200 | iteration 130 / 171 | Total Loss: 2.359407424926758

Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 2.4279768466949463 | KNN Loss: 2.416605234146118 | CLS Loss: 0.011371677741408348
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 2.3861067295074463 | KNN Loss: 2.3776297569274902 | CLS Loss: 0.008476980961859226
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 2.4356577396392822 | KNN Loss: 2.4158670902252197 | CLS Loss: 0.01979055628180504
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 2.4314465522766113 | KNN Loss: 2.4183788299560547 | CLS Loss: 0.013067708350718021
Epoch 162 / 200 | iteration 160 / 171 | Total Loss: 2.4127767086029053 | KNN Loss: 2.4097578525543213 | CLS Loss: 0.0030187482479959726
Epoch 162 / 200 | iteration 170 / 171 | Total Loss: 2.4368746280670166 | KNN Loss: 2.4188387393951416 | CLS Loss: 0.018035883083939552
Epoch: 162, Loss: 2.4079, Train: 0.9969, Valid: 0.9867, Best: 0.9878
Epoch 163 / 200 | iteration 0 / 171 | Total Loss: 2.3787312507629395 | KNN Loss: 2.3642079830169678 | CLS Loss: 0.01452328

Epoch: 165, Loss: 2.4091, Train: 0.9971, Valid: 0.9860, Best: 0.9878
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 2.4372200965881348 | KNN Loss: 2.4259555339813232 | CLS Loss: 0.011264658533036709
Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 2.3875250816345215 | KNN Loss: 2.3857245445251465 | CLS Loss: 0.0018005723832175136
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 2.3904170989990234 | KNN Loss: 2.385345220565796 | CLS Loss: 0.005071794148534536
Epoch 166 / 200 | iteration 30 / 171 | Total Loss: 2.3754959106445312 | KNN Loss: 2.371304750442505 | CLS Loss: 0.00419115275144577
Epoch 166 / 200 | iteration 40 / 171 | Total Loss: 2.4166958332061768 | KNN Loss: 2.409127712249756 | CLS Loss: 0.007568098604679108
Epoch 166 / 200 | iteration 50 / 171 | Total Loss: 2.421764612197876 | KNN Loss: 2.4106554985046387 | CLS Loss: 0.011109111830592155
Epoch 166 / 200 | iteration 60 / 171 | Total Loss: 2.3961334228515625 | KNN Loss: 2.3910839557647705 | CLS Loss: 0.00504949735477566

Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 2.4314417839050293 | KNN Loss: 2.4178848266601562 | CLS Loss: 0.013556952588260174
Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 2.3991620540618896 | KNN Loss: 2.3952479362487793 | CLS Loss: 0.003914100117981434
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 2.413848876953125 | KNN Loss: 2.4015579223632812 | CLS Loss: 0.012291026301681995
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 2.384444236755371 | KNN Loss: 2.3831424713134766 | CLS Loss: 0.001301768934354186
Epoch 169 / 200 | iteration 100 / 171 | Total Loss: 2.3608477115631104 | KNN Loss: 2.343308448791504 | CLS Loss: 0.01753917522728443
Epoch 169 / 200 | iteration 110 / 171 | Total Loss: 2.402322292327881 | KNN Loss: 2.399563789367676 | CLS Loss: 0.0027585593052208424
Epoch 169 / 200 | iteration 120 / 171 | Total Loss: 2.3921091556549072 | KNN Loss: 2.3855252265930176 | CLS Loss: 0.006583996117115021
Epoch 169 / 200 | iteration 130 / 171 | Total Loss: 2.3984954357

Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 2.4095072746276855 | KNN Loss: 2.4025354385375977 | CLS Loss: 0.006971778813749552
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 2.3734593391418457 | KNN Loss: 2.3647592067718506 | CLS Loss: 0.008700129576027393
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 2.3498165607452393 | KNN Loss: 2.349445104598999 | CLS Loss: 0.00037140410859137774
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 2.417059898376465 | KNN Loss: 2.4098212718963623 | CLS Loss: 0.007238606456667185
Epoch 172 / 200 | iteration 160 / 171 | Total Loss: 2.4199373722076416 | KNN Loss: 2.3972558975219727 | CLS Loss: 0.02268156222999096
Epoch 172 / 200 | iteration 170 / 171 | Total Loss: 2.4041714668273926 | KNN Loss: 2.395745038986206 | CLS Loss: 0.008426348678767681
Epoch: 172, Loss: 2.4085, Train: 0.9965, Valid: 0.9865, Best: 0.9878
Epoch 173 / 200 | iteration 0 / 171 | Total Loss: 2.405426025390625 | KNN Loss: 2.402008056640625 | CLS Loss: 0.00341799319

Epoch: 175, Loss: 2.4089, Train: 0.9960, Valid: 0.9861, Best: 0.9878
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 2.439706563949585 | KNN Loss: 2.423344373703003 | CLS Loss: 0.016362294554710388
Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 2.419208288192749 | KNN Loss: 2.417212963104248 | CLS Loss: 0.001995422877371311
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 2.444315195083618 | KNN Loss: 2.436753749847412 | CLS Loss: 0.007561384234577417
Epoch 176 / 200 | iteration 30 / 171 | Total Loss: 2.398893356323242 | KNN Loss: 2.3881239891052246 | CLS Loss: 0.010769395157694817
Epoch 176 / 200 | iteration 40 / 171 | Total Loss: 2.390425205230713 | KNN Loss: 2.3878254890441895 | CLS Loss: 0.0025998326018452644
Epoch 176 / 200 | iteration 50 / 171 | Total Loss: 2.384436845779419 | KNN Loss: 2.3731329441070557 | CLS Loss: 0.011303821578621864
Epoch 176 / 200 | iteration 60 / 171 | Total Loss: 2.397850275039673 | KNN Loss: 2.38834285736084 | CLS Loss: 0.009507318027317524
Epoch

Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 2.4312684535980225 | KNN Loss: 2.418348550796509 | CLS Loss: 0.012919902801513672
Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 2.4015331268310547 | KNN Loss: 2.3825480937957764 | CLS Loss: 0.018984949216246605
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 2.4184584617614746 | KNN Loss: 2.4028100967407227 | CLS Loss: 0.015648363158106804
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 2.387951374053955 | KNN Loss: 2.377260446548462 | CLS Loss: 0.010690818540751934
Epoch 179 / 200 | iteration 100 / 171 | Total Loss: 2.431504726409912 | KNN Loss: 2.4227287769317627 | CLS Loss: 0.008775957860052586
Epoch 179 / 200 | iteration 110 / 171 | Total Loss: 2.378124237060547 | KNN Loss: 2.37115478515625 | CLS Loss: 0.006969396956264973
Epoch 179 / 200 | iteration 120 / 171 | Total Loss: 2.417039632797241 | KNN Loss: 2.4109106063842773 | CLS Loss: 0.006128915119916201
Epoch 179 / 200 | iteration 130 / 171 | Total Loss: 2.3768625259399

Epoch 182 / 200 | iteration 120 / 171 | Total Loss: 2.3971948623657227 | KNN Loss: 2.3941457271575928 | CLS Loss: 0.0030491214711219072
Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 2.3931772708892822 | KNN Loss: 2.3886098861694336 | CLS Loss: 0.004567311145365238
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 2.3916101455688477 | KNN Loss: 2.3841545581817627 | CLS Loss: 0.007455611135810614
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 2.417396068572998 | KNN Loss: 2.413971424102783 | CLS Loss: 0.0034247359726577997
Epoch 182 / 200 | iteration 160 / 171 | Total Loss: 2.421266555786133 | KNN Loss: 2.402632236480713 | CLS Loss: 0.018634293228387833
Epoch 182 / 200 | iteration 170 / 171 | Total Loss: 2.395117998123169 | KNN Loss: 2.388073444366455 | CLS Loss: 0.007044651545584202
Epoch: 182, Loss: 2.4079, Train: 0.9971, Valid: 0.9868, Best: 0.9878
Epoch 183 / 200 | iteration 0 / 171 | Total Loss: 2.4373679161071777 | KNN Loss: 2.4256834983825684 | CLS Loss: 0.01168431434

Epoch: 185, Loss: 2.4103, Train: 0.9980, Valid: 0.9878, Best: 0.9878
Epoch 186 / 200 | iteration 0 / 171 | Total Loss: 2.4008283615112305 | KNN Loss: 2.385937213897705 | CLS Loss: 0.014891088008880615
Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 2.3708126544952393 | KNN Loss: 2.3694918155670166 | CLS Loss: 0.001320941955782473
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 2.4096171855926514 | KNN Loss: 2.3974850177764893 | CLS Loss: 0.012132060714066029
Epoch 186 / 200 | iteration 30 / 171 | Total Loss: 2.3881924152374268 | KNN Loss: 2.38106107711792 | CLS Loss: 0.007131386548280716
Epoch 186 / 200 | iteration 40 / 171 | Total Loss: 2.362730026245117 | KNN Loss: 2.3547840118408203 | CLS Loss: 0.00794593058526516
Epoch 186 / 200 | iteration 50 / 171 | Total Loss: 2.391960382461548 | KNN Loss: 2.387556314468384 | CLS Loss: 0.004404039587825537
Epoch 186 / 200 | iteration 60 / 171 | Total Loss: 2.3919930458068848 | KNN Loss: 2.3834145069122314 | CLS Loss: 0.008578580804169178
E

Epoch 189 / 200 | iteration 60 / 171 | Total Loss: 2.425607681274414 | KNN Loss: 2.422868013381958 | CLS Loss: 0.00273971538990736
Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 2.403491258621216 | KNN Loss: 2.3969478607177734 | CLS Loss: 0.006543281488120556
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 2.3887836933135986 | KNN Loss: 2.3818166255950928 | CLS Loss: 0.006966956425458193
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 2.432523012161255 | KNN Loss: 2.4093360900878906 | CLS Loss: 0.023186897858977318
Epoch 189 / 200 | iteration 100 / 171 | Total Loss: 2.386491537094116 | KNN Loss: 2.3740074634552 | CLS Loss: 0.012484166771173477
Epoch 189 / 200 | iteration 110 / 171 | Total Loss: 2.426100492477417 | KNN Loss: 2.4050347805023193 | CLS Loss: 0.021065719425678253
Epoch 189 / 200 | iteration 120 / 171 | Total Loss: 2.386218309402466 | KNN Loss: 2.384631872177124 | CLS Loss: 0.0015865350142121315
Epoch 189 / 200 | iteration 130 / 171 | Total Loss: 2.4077579975128174

Epoch 192 / 200 | iteration 120 / 171 | Total Loss: 2.3859593868255615 | KNN Loss: 2.378988742828369 | CLS Loss: 0.006970677059143782
Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 2.409832715988159 | KNN Loss: 2.38132905960083 | CLS Loss: 0.028503581881523132
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 2.4100027084350586 | KNN Loss: 2.4044291973114014 | CLS Loss: 0.005573393311351538
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 2.433279514312744 | KNN Loss: 2.4250168800354004 | CLS Loss: 0.00826265662908554
Epoch 192 / 200 | iteration 160 / 171 | Total Loss: 2.4009337425231934 | KNN Loss: 2.390872001647949 | CLS Loss: 0.010061747394502163
Epoch 192 / 200 | iteration 170 / 171 | Total Loss: 2.368534564971924 | KNN Loss: 2.3657431602478027 | CLS Loss: 0.0027914096135646105
Epoch: 192, Loss: 2.4036, Train: 0.9964, Valid: 0.9853, Best: 0.9878
Epoch 193 / 200 | iteration 0 / 171 | Total Loss: 2.4220030307769775 | KNN Loss: 2.418938159942627 | CLS Loss: 0.003064960474148

Epoch: 195, Loss: 2.4044, Train: 0.9959, Valid: 0.9861, Best: 0.9878
Epoch 196 / 200 | iteration 0 / 171 | Total Loss: 2.382582664489746 | KNN Loss: 2.3770408630371094 | CLS Loss: 0.005541868507862091
Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 2.387298345565796 | KNN Loss: 2.3784537315368652 | CLS Loss: 0.008844670839607716
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 2.4020190238952637 | KNN Loss: 2.400002956390381 | CLS Loss: 0.0020159974228590727
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 2.4295778274536133 | KNN Loss: 2.4134514331817627 | CLS Loss: 0.016126492992043495
Epoch 196 / 200 | iteration 40 / 171 | Total Loss: 2.3838071823120117 | KNN Loss: 2.369483709335327 | CLS Loss: 0.014323582872748375
Epoch 196 / 200 | iteration 50 / 171 | Total Loss: 2.4163501262664795 | KNN Loss: 2.4147579669952393 | CLS Loss: 0.0015922078164294362
Epoch 196 / 200 | iteration 60 / 171 | Total Loss: 2.3970820903778076 | KNN Loss: 2.391136884689331 | CLS Loss: 0.0059452550485730

Epoch 199 / 200 | iteration 60 / 171 | Total Loss: 2.4184951782226562 | KNN Loss: 2.4127755165100098 | CLS Loss: 0.005719643551856279
Epoch 199 / 200 | iteration 70 / 171 | Total Loss: 2.3907227516174316 | KNN Loss: 2.373054027557373 | CLS Loss: 0.017668629065155983
Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 2.435577392578125 | KNN Loss: 2.432734966278076 | CLS Loss: 0.002842331537976861
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 2.37748384475708 | KNN Loss: 2.372047185897827 | CLS Loss: 0.005436714738607407
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 2.427713632583618 | KNN Loss: 2.4072234630584717 | CLS Loss: 0.02049008198082447
Epoch 199 / 200 | iteration 110 / 171 | Total Loss: 2.403379440307617 | KNN Loss: 2.3833584785461426 | CLS Loss: 0.020020943135023117
Epoch 199 / 200 | iteration 120 / 171 | Total Loss: 2.404508113861084 | KNN Loss: 2.400773286819458 | CLS Loss: 0.0037347835022956133
Epoch 199 / 200 | iteration 130 / 171 | Total Loss: 2.395252227783203

In [8]:
test(model, test_data_iter, device)

tensor(0.9863, device='cuda:0')

In [9]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [14]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.9101457219862044


In [15]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [16]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [17]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [18]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [19]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [20]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [21]:
losses = []
accs = []
sparsity = []

In [22]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 039 | Total loss: 3.071 | Reg loss: 0.009 | Tree loss: 3.071 | Accuracy: 0.037109 | 0.245 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 01 | Batch: 000 / 039 | Total loss: 2.984 | Reg loss: 0.005 | Tree loss: 2.984 | Accuracy: 0.326172 | 0.211 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 02 | Batch: 000 / 039 | Total loss: 2.934 | Reg loss: 0.008 | Tree loss: 2.934 | Accuracy: 0.316406 | 0.209 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 23 | Batch: 000 / 039 | Total loss: 2.394 | Reg loss: 0.024 | Tree loss: 2.394 | Accuracy: 0.408203 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 24 | Batch: 000 / 039 | Total loss: 2.455 | Reg loss: 0.025 | Tree loss: 2.455 | Accuracy: 0.341797 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 25 | Batch: 000 / 039 | Total loss: 2.424 | Reg loss: 0.025 | Tree loss: 2.424 | Accuracy: 0.376953 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 46 | Batch: 000 / 039 | Total loss: 2.372 | Reg loss: 0.028 | Tree loss: 2.372 | Accuracy: 0.335938 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 47 | Batch: 000 / 039 | Total loss: 2.272 | Reg loss: 0.029 | Tree loss: 2.272 | Accuracy: 0.392578 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 48 | Batch: 000 / 039 | Total loss: 2.365 | Reg loss: 0.029 | Tree loss: 2.365 | Accuracy: 0.339844 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 69 | Batch: 000 / 039 | Total loss: 2.414 | Reg loss: 0.030 | Tree loss: 2.414 | Accuracy: 0.292969 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 70 | Batch: 000 / 039 | Total loss: 2.290 | Reg loss: 0.030 | Tree loss: 2.290 | Accuracy: 0.333984 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 71 | Batch: 000 / 039 | Total loss: 2.336 | Reg loss: 0.030 | Tree loss: 2.336 | Accuracy: 0.320312 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 92 | Batch: 000 / 039 | Total loss: 2.222 | Reg loss: 0.031 | Tree loss: 2.222 | Accuracy: 0.361328 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 93 | Batch: 000 / 039 | Total loss: 2.330 | Reg loss: 0.031 | Tree loss: 2.330 | Accuracy: 0.312500 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 94 | Batch: 000 / 039 | Total loss: 2.276 | Reg loss: 0.031 | Tree loss: 2.276 | Accuracy: 0.347656 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 115 | Batch: 000 / 039 | Total loss: 2.301 | Reg loss: 0.031 | Tree loss: 2.301 | Accuracy: 0.349609 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 116 | Batch: 000 / 039 | Total loss: 2.399 | Reg loss: 0.031 | Tree loss: 2.399 | Accuracy: 0.310547 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 117 | Batch: 000 / 039 | Total loss: 2.249 | Reg loss: 0.031 | Tree loss: 2.249 | Accuracy: 0.369141 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 138 | Batch: 000 / 039 | Total loss: 2.265 | Reg loss: 0.032 | Tree loss: 2.265 | Accuracy: 0.347656 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 139 | Batch: 000 / 039 | Total loss: 2.236 | Reg loss: 0.032 | Tree loss: 2.236 | Accuracy: 0.361328 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 140 | Batch: 000 / 039 | Total loss: 2.197 | Reg loss: 0.032 | Tree loss: 2.197 | Accuracy: 0.353516 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 161 | Batch: 000 / 039 | Total loss: 2.192 | Reg loss: 0.031 | Tree loss: 2.192 | Accuracy: 0.388672 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 162 | Batch: 000 / 039 | Total loss: 2.227 | Reg loss: 0.031 | Tree loss: 2.227 | Accuracy: 0.375000 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 163 | Batch: 000 / 039 | Total loss: 2.190 | Reg loss: 0.031 | Tree loss: 2.190 | Accuracy: 0.388672 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 184 | Batch: 000 / 039 | Total loss: 2.123 | Reg loss: 0.031 | Tree loss: 2.123 | Accuracy: 0.429688 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 185 | Batch: 000 / 039 | Total loss: 2.088 | Reg loss: 0.031 | Tree loss: 2.088 | Accuracy: 0.417969 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 186 | Batch: 000 / 039 | Total loss: 2.253 | Reg loss: 0.031 | Tree loss: 2.253 | Accuracy: 0.369141 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 207 | Batch: 000 / 039 | Total loss: 2.099 | Reg loss: 0.032 | Tree loss: 2.099 | Accuracy: 0.400391 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 208 | Batch: 000 / 039 | Total loss: 2.324 | Reg loss: 0.032 | Tree loss: 2.324 | Accuracy: 0.343750 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 209 | Batch: 000 / 039 | Total loss: 2.209 | Reg loss: 0.032 | Tree loss: 2.209 | Accuracy: 0.396484 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 230 | Batch: 000 / 039 | Total loss: 2.181 | Reg loss: 0.033 | Tree loss: 2.181 | Accuracy: 0.406250 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 231 | Batch: 000 / 039 | Total loss: 2.302 | Reg loss: 0.033 | Tree loss: 2.302 | Accuracy: 0.361328 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 232 | Batch: 000 / 039 | Total loss: 2.287 | Reg loss: 0.033 | Tree loss: 2.287 | Accuracy: 0.363281 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 253 | Batch: 000 / 039 | Total loss: 2.229 | Reg loss: 0.033 | Tree loss: 2.229 | Accuracy: 0.369141 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 254 | Batch: 000 / 039 | Total loss: 2.177 | Reg loss: 0.033 | Tree loss: 2.177 | Accuracy: 0.404297 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 255 | Batch: 000 / 039 | Total loss: 2.211 | Reg loss: 0.033 | Tree loss: 2.211 | Accuracy: 0.376953 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 276 | Batch: 000 / 039 | Total loss: 2.161 | Reg loss: 0.032 | Tree loss: 2.161 | Accuracy: 0.376953 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 277 | Batch: 000 / 039 | Total loss: 2.275 | Reg loss: 0.032 | Tree loss: 2.275 | Accuracy: 0.355469 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 278 | Batch: 000 / 039 | Total loss: 2.215 | Reg loss: 0.032 | Tree loss: 2.215 | Accuracy: 0.367188 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 299 | Batch: 000 / 039 | Total loss: 2.148 | Reg loss: 0.032 | Tree loss: 2.148 | Accuracy: 0.392578 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 300 | Batch: 000 / 039 | Total loss: 2.200 | Reg loss: 0.032 | Tree loss: 2.200 | Accuracy: 0.376953 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 301 | Batch: 000 / 039 | Total loss: 2.179 | Reg loss: 0.032 | Tree loss: 2.179 | Accuracy: 0.375000 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 322 | Batch: 000 / 039 | Total loss: 2.130 | Reg loss: 0.032 | Tree loss: 2.130 | Accuracy: 0.382812 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 323 | Batch: 000 / 039 | Total loss: 2.135 | Reg loss: 0.032 | Tree loss: 2.135 | Accuracy: 0.384766 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 324 | Batch: 000 / 039 | Total loss: 2.234 | Reg loss: 0.032 | Tree loss: 2.234 | Accuracy: 0.349609 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 345 | Batch: 000 / 039 | Total loss: 2.196 | Reg loss: 0.032 | Tree loss: 2.196 | Accuracy: 0.359375 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 346 | Batch: 000 / 039 | Total loss: 2.127 | Reg loss: 0.032 | Tree loss: 2.127 | Accuracy: 0.402344 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 347 | Batch: 000 / 039 | Total loss: 2.219 | Reg loss: 0.032 | Tree loss: 2.219 | Accuracy: 0.367188 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 368 | Batch: 000 / 039 | Total loss: 2.139 | Reg loss: 0.032 | Tree loss: 2.139 | Accuracy: 0.398438 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 369 | Batch: 000 / 039 | Total loss: 2.168 | Reg loss: 0.032 | Tree loss: 2.168 | Accuracy: 0.382812 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 370 | Batch: 000 / 039 | Total loss: 2.166 | Reg loss: 0.032 | Tree loss: 2.166 | Accuracy: 0.388672 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

Epoch: 391 | Batch: 000 / 039 | Total loss: 2.142 | Reg loss: 0.032 | Tree loss: 2.142 | Accuracy: 0.394531 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 392 | Batch: 000 / 039 | Total loss: 2.239 | Reg loss: 0.032 | Tree loss: 2.239 | Accuracy: 0.355469 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 393 | Batch: 000 / 039 | Total loss: 2.181 | Reg loss: 0.032 | Tree loss: 2.181 | Accuracy: 0.367188 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
laye

In [23]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [25]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 6.6521739130434785


# Extract Rules

# Accumulate samples in the leaves

In [26]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 46


In [27]:
method = 'greedy'

In [28]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [29]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

2292
4
15706
105
224
1593
Average comprehensibility: 37.17391304347826
std comprehensibility: 9.785313635697138
var comprehensibility: 95.75236294896033
minimum comprehensibility: 16
maximum comprehensibility: 48


  return np.log(1 / (1 - x))
