In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 10
batch_size = 512
device = 'cpu'
train_data_path = r'F:\Downloads\archive\mitbih_train.csv'
test_data_path = r'F:\Downloads\archive\mitbih_test.csv'

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0)
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.587115287780762 | KNN Loss: 5.718296527862549 | CLS Loss: 1.8688185214996338
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.539576053619385 | KNN Loss: 4.427477836608887 | CLS Loss: 1.1120980978012085
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 4.7229814529418945 | KNN Loss: 3.9933700561523438 | CLS Loss: 0.729611337184906
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 4.43195915222168 | KNN Loss: 3.8444793224334717 | CLS Loss: 0.5874797105789185
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 4.483575820922852 | KNN Loss: 3.8387410640716553 | CLS Loss: 0.6448348164558411
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 4.415455341339111 | KNN Loss: 3.8455843925476074 | CLS Loss: 0.5698708295822144
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.365846157073975 | KNN Loss: 3.7997422218322754 | CLS Loss: 0.5661040544509888
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.351258277893066 | KNN Loss: 3.831059455871582 | CL

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 3.8422582149505615 | KNN Loss: 3.6666877269744873 | CLS Loss: 0.17557038366794586
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 3.843769073486328 | KNN Loss: 3.654789686203003 | CLS Loss: 0.1889793574810028
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 3.826467990875244 | KNN Loss: 3.6591689586639404 | CLS Loss: 0.16729910671710968
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 3.8243889808654785 | KNN Loss: 3.6761441230773926 | CLS Loss: 0.1482449471950531
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 3.7941391468048096 | KNN Loss: 3.6541688442230225 | CLS Loss: 0.1399703025817871
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 3.7871880531311035 | KNN Loss: 3.6518688201904297 | CLS Loss: 0.13531926274299622
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 3.8234703540802 | KNN Loss: 3.676626443862915 | CLS Loss: 0.14684398472309113
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 3.854278326034546 | KNN Loss: 3.675780

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 3.7192060947418213 | KNN Loss: 3.6243038177490234 | CLS Loss: 0.09490218758583069
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 3.69795560836792 | KNN Loss: 3.6322481632232666 | CLS Loss: 0.06570734828710556
Epoch: 007, Loss: 3.7457, Train: 0.9769, Valid: 0.9742, Best: 0.9742
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 3.70446515083313 | KNN Loss: 3.6155078411102295 | CLS Loss: 0.08895742148160934
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 3.734149217605591 | KNN Loss: 3.6574413776397705 | CLS Loss: 0.07670783251523972
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 3.722667694091797 | KNN Loss: 3.6310293674468994 | CLS Loss: 0.09163843840360641
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 3.700176477432251 | KNN Loss: 3.6133313179016113 | CLS Loss: 0.08684512227773666
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 3.714923620223999 | KNN Loss: 3.64667010307312 | CLS Loss: 0.06825359165668488
Epoch 8 / 200 | iterati

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 3.685824394226074 | KNN Loss: 3.596081018447876 | CLS Loss: 0.08974327892065048
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 3.7582485675811768 | KNN Loss: 3.6727967262268066 | CLS Loss: 0.08545186370611191
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 3.6923916339874268 | KNN Loss: 3.6167309284210205 | CLS Loss: 0.07566066086292267
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 3.669750213623047 | KNN Loss: 3.6241214275360107 | CLS Loss: 0.04562876373529434
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 3.686194658279419 | KNN Loss: 3.5548713207244873 | CLS Loss: 0.13132326304912567
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 3.7228708267211914 | KNN Loss: 3.6078102588653564 | CLS Loss: 0.11506067216396332
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 3.689363718032837 | KNN Loss: 3.6066689491271973 | CLS Loss: 0.08269474655389786
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 3.7255859375 | KNN Loss: 3.6

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 3.656766891479492 | KNN Loss: 3.604492425918579 | CLS Loss: 0.05227439105510712
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 3.6770365238189697 | KNN Loss: 3.6193835735321045 | CLS Loss: 0.05765296518802643
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 3.7052364349365234 | KNN Loss: 3.636918306350708 | CLS Loss: 0.06831815838813782
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 3.679098129272461 | KNN Loss: 3.6211349964141846 | CLS Loss: 0.057963162660598755
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 3.650550603866577 | KNN Loss: 3.6043601036071777 | CLS Loss: 0.04619060084223747
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 3.611435890197754 | KNN Loss: 3.586193084716797 | CLS Loss: 0.025242773815989494
Epoch: 014, Loss: 3.6855, Train: 0.9826, Valid: 0.9795, Best: 0.9797
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 3.6574466228485107 | KNN Loss: 3.5909981727600098 | CLS Loss: 0.06644848734140396
Epoch 1

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 3.6598823070526123 | KNN Loss: 3.596204996109009 | CLS Loss: 0.06367722898721695
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 3.6729860305786133 | KNN Loss: 3.6223526000976562 | CLS Loss: 0.05063336342573166
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 3.730236053466797 | KNN Loss: 3.65004825592041 | CLS Loss: 0.08018787950277328
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 3.7101175785064697 | KNN Loss: 3.668700695037842 | CLS Loss: 0.04141680896282196
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 3.678621768951416 | KNN Loss: 3.6211235523223877 | CLS Loss: 0.057498227804899216
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 3.6612026691436768 | KNN Loss: 3.556119918823242 | CLS Loss: 0.10508272051811218
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 3.61236572265625 | KNN Loss: 3.5856475830078125 | CLS Loss: 0.026718031615018845
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 3.6716558933258057 | KNN Loss: 3.

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 3.6628754138946533 | KNN Loss: 3.5964133739471436 | CLS Loss: 0.0664619579911232
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 3.6751885414123535 | KNN Loss: 3.6318159103393555 | CLS Loss: 0.04337269812822342
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 3.630153179168701 | KNN Loss: 3.597013235092163 | CLS Loss: 0.0331399068236351
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 3.6671855449676514 | KNN Loss: 3.6241397857666016 | CLS Loss: 0.04304575175046921
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 3.6547491550445557 | KNN Loss: 3.599541187286377 | CLS Loss: 0.05520789325237274
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 3.6733055114746094 | KNN Loss: 3.642901659011841 | CLS Loss: 0.030403781682252884
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 3.671067953109741 | KNN Loss: 3.6139323711395264 | CLS Loss: 0.0571356900036335
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 3.6647772789001465 | KNN Lo

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 3.6529245376586914 | KNN Loss: 3.587357997894287 | CLS Loss: 0.06556656211614609
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 3.659420967102051 | KNN Loss: 3.624328136444092 | CLS Loss: 0.03509288281202316
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 3.6037163734436035 | KNN Loss: 3.5643045902252197 | CLS Loss: 0.03941170871257782
Epoch: 024, Loss: 3.6496, Train: 0.9889, Valid: 0.9841, Best: 0.9841
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 3.5947272777557373 | KNN Loss: 3.570183515548706 | CLS Loss: 0.02454385533928871
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 3.5984323024749756 | KNN Loss: 3.5689890384674072 | CLS Loss: 0.02944316156208515
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 3.6254680156707764 | KNN Loss: 3.590101480484009 | CLS Loss: 0.035366423428058624
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 3.6293327808380127 | KNN Loss: 3.5626914501190186 | CLS Loss: 0.06664127856492996
Epoch 25 

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 3.664663076400757 | KNN Loss: 3.621389389038086 | CLS Loss: 0.04327370598912239
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 3.6018247604370117 | KNN Loss: 3.5840976238250732 | CLS Loss: 0.017727019265294075
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 3.6275362968444824 | KNN Loss: 3.5626912117004395 | CLS Loss: 0.06484515219926834
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 3.6240897178649902 | KNN Loss: 3.5790674686431885 | CLS Loss: 0.04502223804593086
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 3.6949009895324707 | KNN Loss: 3.6034622192382812 | CLS Loss: 0.09143884479999542
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 3.6456689834594727 | KNN Loss: 3.5888876914978027 | CLS Loss: 0.05678130313754082
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 3.6619744300842285 | KNN Loss: 3.5994913578033447 | CLS Loss: 0.06248318403959274
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 3.6092026233673096 | KNN 

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 3.5875353813171387 | KNN Loss: 3.5693368911743164 | CLS Loss: 0.01819850690662861
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 3.6558549404144287 | KNN Loss: 3.6060428619384766 | CLS Loss: 0.04981211572885513
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 3.6491622924804688 | KNN Loss: 3.6076276302337646 | CLS Loss: 0.04153454676270485
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 3.6279120445251465 | KNN Loss: 3.5890257358551025 | CLS Loss: 0.03888633847236633
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 3.671757936477661 | KNN Loss: 3.6437039375305176 | CLS Loss: 0.028054043650627136
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 3.6417324542999268 | KNN Loss: 3.607290267944336 | CLS Loss: 0.03444208577275276
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 3.623818874359131 | KNN Loss: 3.59360933303833 | CLS Loss: 0.030209552496671677
Epoch: 031, Loss: 3.6399, Train: 0.9895, Valid: 0.9843, Best: 0.9843
Epo

Epoch: 034, Loss: 3.6298, Train: 0.9903, Valid: 0.9842, Best: 0.9850
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 3.6459591388702393 | KNN Loss: 3.6063640117645264 | CLS Loss: 0.039595216512680054
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 3.648469924926758 | KNN Loss: 3.613757371902466 | CLS Loss: 0.034712616354227066
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 3.6075522899627686 | KNN Loss: 3.5999791622161865 | CLS Loss: 0.007573017850518227
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 3.614513635635376 | KNN Loss: 3.5567119121551514 | CLS Loss: 0.05780177190899849
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 3.63446044921875 | KNN Loss: 3.5775022506713867 | CLS Loss: 0.0569581463932991
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 3.6043877601623535 | KNN Loss: 3.5687026977539062 | CLS Loss: 0.03568503260612488
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 3.646314859390259 | KNN Loss: 3.589869260787964 | CLS Loss: 0.056445490568876266
Epoch 35 / 2

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 3.609412670135498 | KNN Loss: 3.5557339191436768 | CLS Loss: 0.05367881804704666
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 3.5950465202331543 | KNN Loss: 3.5797388553619385 | CLS Loss: 0.01530769094824791
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 3.6174142360687256 | KNN Loss: 3.601377010345459 | CLS Loss: 0.016037175431847572
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 3.609300374984741 | KNN Loss: 3.5839457511901855 | CLS Loss: 0.025354551151394844
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 3.618511915206909 | KNN Loss: 3.5956578254699707 | CLS Loss: 0.0228541512042284
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 3.6295716762542725 | KNN Loss: 3.608710765838623 | CLS Loss: 0.02086099237203598
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 3.6345434188842773 | KNN Loss: 3.6109437942504883 | CLS Loss: 0.023599542677402496
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 3.6151649951934814 | KNN

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 3.7036964893341064 | KNN Loss: 3.643734931945801 | CLS Loss: 0.059961605817079544
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 3.6348960399627686 | KNN Loss: 3.6005005836486816 | CLS Loss: 0.03439541906118393
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 3.6818277835845947 | KNN Loss: 3.632594108581543 | CLS Loss: 0.049233578145504
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 3.6594271659851074 | KNN Loss: 3.6432623863220215 | CLS Loss: 0.01616472564637661
Epoch: 041, Loss: 3.6296, Train: 0.9896, Valid: 0.9846, Best: 0.9850
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 3.6760714054107666 | KNN Loss: 3.644674301147461 | CLS Loss: 0.0313970111310482
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 3.6073074340820312 | KNN Loss: 3.5931053161621094 | CLS Loss: 0.014202069491147995
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 3.6203484535217285 | KNN Loss: 3.5867741107940674 | CLS Loss: 0.033574458211660385
Epoch 4

Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 3.672511577606201 | KNN Loss: 3.6472835540771484 | CLS Loss: 0.02522796392440796
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 3.6237564086914062 | KNN Loss: 3.585114002227783 | CLS Loss: 0.03864250332117081
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 3.6537187099456787 | KNN Loss: 3.617685079574585 | CLS Loss: 0.03603354096412659
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 3.6815381050109863 | KNN Loss: 3.640855312347412 | CLS Loss: 0.04068271815776825
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 3.6332342624664307 | KNN Loss: 3.6071860790252686 | CLS Loss: 0.02604827843606472
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 3.597831964492798 | KNN Loss: 3.568232536315918 | CLS Loss: 0.029599417001008987
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 3.58534836769104 | KNN Loss: 3.564540147781372 | CLS Loss: 0.020808208733797073
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 3.6304054260253906 | KNN Loss: 3

Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 3.6072285175323486 | KNN Loss: 3.5588467121124268 | CLS Loss: 0.04838176444172859
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 3.626706600189209 | KNN Loss: 3.5956313610076904 | CLS Loss: 0.031075267121195793
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 3.610544443130493 | KNN Loss: 3.5900862216949463 | CLS Loss: 0.020458189770579338
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 3.641087055206299 | KNN Loss: 3.5824134349823 | CLS Loss: 0.05867372080683708
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 3.593996286392212 | KNN Loss: 3.5709331035614014 | CLS Loss: 0.023063167929649353
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 3.5676443576812744 | KNN Loss: 3.5578088760375977 | CLS Loss: 0.009835569187998772
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 3.625774383544922 | KNN Loss: 3.592315196990967 | CLS Loss: 0.033459294587373734
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 3.6116888523101807 | K

Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 3.6032204627990723 | KNN Loss: 3.5726003646850586 | CLS Loss: 0.030620137229561806
Epoch: 051, Loss: 3.6266, Train: 0.9920, Valid: 0.9852, Best: 0.9860
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 3.613009214401245 | KNN Loss: 3.583034038543701 | CLS Loss: 0.02997509576380253
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 3.6173224449157715 | KNN Loss: 3.57765531539917 | CLS Loss: 0.03966708108782768
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 3.6434085369110107 | KNN Loss: 3.6217548847198486 | CLS Loss: 0.02165362238883972
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 3.6049838066101074 | KNN Loss: 3.5747838020324707 | CLS Loss: 0.03020007535815239
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 3.6127684116363525 | KNN Loss: 3.59572434425354 | CLS Loss: 0.017044078558683395
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 3.580366849899292 | KNN Loss: 3.5657854080200195 | CLS Loss: 0.014581328257918358
Epoch 52 / 

Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 3.6295175552368164 | KNN Loss: 3.593012809753418 | CLS Loss: 0.036504726856946945
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 3.632831573486328 | KNN Loss: 3.596482992172241 | CLS Loss: 0.03634863346815109
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 3.5897128582000732 | KNN Loss: 3.57761812210083 | CLS Loss: 0.012094796635210514
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 3.6397511959075928 | KNN Loss: 3.6111011505126953 | CLS Loss: 0.028650090098381042
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 3.6083717346191406 | KNN Loss: 3.5765533447265625 | CLS Loss: 0.031818389892578125
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 3.624729633331299 | KNN Loss: 3.599191904067993 | CLS Loss: 0.02553783357143402
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 3.656813144683838 | KNN Loss: 3.6471805572509766 | CLS Loss: 0.009632624685764313
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 3.5841877460479736 | KNN 

Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 3.630267381668091 | KNN Loss: 3.6204400062561035 | CLS Loss: 0.009827465750277042
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 3.5875704288482666 | KNN Loss: 3.5773775577545166 | CLS Loss: 0.010192761197686195
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 3.6248230934143066 | KNN Loss: 3.606240749359131 | CLS Loss: 0.018582366406917572
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 3.5821385383605957 | KNN Loss: 3.5691964626312256 | CLS Loss: 0.012941958382725716
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 3.61069917678833 | KNN Loss: 3.595754623413086 | CLS Loss: 0.014944599941372871
Epoch: 058, Loss: 3.6247, Train: 0.9931, Valid: 0.9860, Best: 0.9865
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 3.642448902130127 | KNN Loss: 3.619199275970459 | CLS Loss: 0.02324952557682991
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 3.6042840480804443 | KNN Loss: 3.559816598892212 | CLS Loss: 0.04446745663881302
Epoch 

Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 3.637253761291504 | KNN Loss: 3.631664752960205 | CLS Loss: 0.00558896828442812
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 3.5965261459350586 | KNN Loss: 3.5670018196105957 | CLS Loss: 0.029524315148591995
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 3.673548460006714 | KNN Loss: 3.6577529907226562 | CLS Loss: 0.01579555869102478
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 3.6438748836517334 | KNN Loss: 3.636187791824341 | CLS Loss: 0.00768704991787672
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 3.607801675796509 | KNN Loss: 3.567981004714966 | CLS Loss: 0.0398206003010273
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 3.581993579864502 | KNN Loss: 3.574312686920166 | CLS Loss: 0.007680797018110752
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 3.646906852722168 | KNN Loss: 3.637376308441162 | CLS Loss: 0.009530573152005672
Epoch 62 / 200 | iteration 90 / 171 | Total Loss: 3.5848042964935303 | KNN Loss: 3.5

Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 3.6567087173461914 | KNN Loss: 3.6498827934265137 | CLS Loss: 0.006825863849371672
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 3.609076738357544 | KNN Loss: 3.5919992923736572 | CLS Loss: 0.017077390104532242
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 3.594413995742798 | KNN Loss: 3.565845489501953 | CLS Loss: 0.028568435460329056
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 3.665515661239624 | KNN Loss: 3.625492572784424 | CLS Loss: 0.04002310335636139
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 3.574702024459839 | KNN Loss: 3.562844753265381 | CLS Loss: 0.011857302859425545
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 3.5914602279663086 | KNN Loss: 3.5712389945983887 | CLS Loss: 0.020221326500177383
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 3.6661250591278076 | KNN Loss: 3.630819320678711 | CLS Loss: 0.035305704921483994
Epoch 65 / 200 | iteration 160 / 171 | Total Loss: 3.5855720043182373 | 

Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 3.603968620300293 | KNN Loss: 3.5870862007141113 | CLS Loss: 0.01688249409198761
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 3.6074318885803223 | KNN Loss: 3.600886106491089 | CLS Loss: 0.006545674055814743
Epoch: 068, Loss: 3.6204, Train: 0.9939, Valid: 0.9855, Best: 0.9866
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 3.542783498764038 | KNN Loss: 3.5305745601654053 | CLS Loss: 0.012208975851535797
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 3.607537031173706 | KNN Loss: 3.6018497943878174 | CLS Loss: 0.005687124561518431
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 3.631160020828247 | KNN Loss: 3.60326886177063 | CLS Loss: 0.027891084551811218
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 3.617069721221924 | KNN Loss: 3.5854134559631348 | CLS Loss: 0.031656187027692795
Epoch 69 / 200 | iteration 40 / 171 | Total Loss: 3.569350242614746 | KNN Loss: 3.5433928966522217 | CLS Loss: 0.025957411155104637
Epoch 69 

Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 3.595602035522461 | KNN Loss: 3.565916061401367 | CLS Loss: 0.02968592569231987
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 3.6294562816619873 | KNN Loss: 3.608227491378784 | CLS Loss: 0.021228674799203873
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 3.567329168319702 | KNN Loss: 3.559237480163574 | CLS Loss: 0.008091665804386139
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 3.6053998470306396 | KNN Loss: 3.5969083309173584 | CLS Loss: 0.008491543121635914
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 3.624803066253662 | KNN Loss: 3.6117005348205566 | CLS Loss: 0.013102540746331215
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 3.631629228591919 | KNN Loss: 3.6225013732910156 | CLS Loss: 0.00912792794406414
Epoch 72 / 200 | iteration 110 / 171 | Total Loss: 3.6003642082214355 | KNN Loss: 3.5710885524749756 | CLS Loss: 0.029275698587298393
Epoch 72 / 200 | iteration 120 / 171 | Total Loss: 3.6126856803894043 | KNN 

Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 3.588238000869751 | KNN Loss: 3.5727734565734863 | CLS Loss: 0.01546463929116726
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 3.618821620941162 | KNN Loss: 3.5987560749053955 | CLS Loss: 0.02006547339260578
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 3.615615129470825 | KNN Loss: 3.606846570968628 | CLS Loss: 0.008768527768552303
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 3.6212635040283203 | KNN Loss: 3.603041410446167 | CLS Loss: 0.01822216622531414
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 3.596771001815796 | KNN Loss: 3.5731301307678223 | CLS Loss: 0.023640954867005348
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 3.5844931602478027 | KNN Loss: 3.5729644298553467 | CLS Loss: 0.01152863260358572
Epoch: 075, Loss: 3.6129, Train: 0.9935, Valid: 0.9857, Best: 0.9873
Epoch 76 / 200 | iteration 0 / 171 | Total Loss: 3.5972414016723633 | KNN Loss: 3.5787317752838135 | CLS Loss: 0.018509624525904655
Epoch

Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 3.5795435905456543 | KNN Loss: 3.574016809463501 | CLS Loss: 0.005526815541088581
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 3.611707925796509 | KNN Loss: 3.597028970718384 | CLS Loss: 0.014678975567221642
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 3.65199613571167 | KNN Loss: 3.626490592956543 | CLS Loss: 0.025505470111966133
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 3.601844310760498 | KNN Loss: 3.587657928466797 | CLS Loss: 0.014186344109475613
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 3.585623025894165 | KNN Loss: 3.5665955543518066 | CLS Loss: 0.01902754418551922
Epoch 79 / 200 | iteration 60 / 171 | Total Loss: 3.6452815532684326 | KNN Loss: 3.634308099746704 | CLS Loss: 0.01097337156534195
Epoch 79 / 200 | iteration 70 / 171 | Total Loss: 3.5722925662994385 | KNN Loss: 3.569079875946045 | CLS Loss: 0.0032126158475875854
Epoch 79 / 200 | iteration 80 / 171 | Total Loss: 3.6039657592773438 | KNN Loss: 

Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 3.6057262420654297 | KNN Loss: 3.595538377761841 | CLS Loss: 0.010187924839556217
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 3.643526315689087 | KNN Loss: 3.612896203994751 | CLS Loss: 0.03063000552356243
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 3.6049020290374756 | KNN Loss: 3.59982967376709 | CLS Loss: 0.005072413478046656
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 3.595329523086548 | KNN Loss: 3.5910658836364746 | CLS Loss: 0.004263700917363167
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 3.5860626697540283 | KNN Loss: 3.5682106018066406 | CLS Loss: 0.017852067947387695
Epoch 82 / 200 | iteration 130 / 171 | Total Loss: 3.6055381298065186 | KNN Loss: 3.590984582901001 | CLS Loss: 0.014553564600646496
Epoch 82 / 200 | iteration 140 / 171 | Total Loss: 3.5788862705230713 | KNN Loss: 3.566373109817505 | CLS Loss: 0.012513046152889729
Epoch 82 / 200 | iteration 150 / 171 | Total Loss: 3.6350977420806885 | K

Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 3.612131357192993 | KNN Loss: 3.581892728805542 | CLS Loss: 0.030238624662160873
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 3.578245162963867 | KNN Loss: 3.5351645946502686 | CLS Loss: 0.043080586940050125
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 3.5763943195343018 | KNN Loss: 3.556379556655884 | CLS Loss: 0.020014654844999313
Epoch: 085, Loss: 3.6088, Train: 0.9956, Valid: 0.9875, Best: 0.9875
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 3.5819389820098877 | KNN Loss: 3.5481066703796387 | CLS Loss: 0.03383241221308708
Epoch 86 / 200 | iteration 10 / 171 | Total Loss: 3.6101012229919434 | KNN Loss: 3.5839273929595947 | CLS Loss: 0.026173725724220276
Epoch 86 / 200 | iteration 20 / 171 | Total Loss: 3.592787027359009 | KNN Loss: 3.57601261138916 | CLS Loss: 0.016774320974946022
Epoch 86 / 200 | iteration 30 / 171 | Total Loss: 3.6169490814208984 | KNN Loss: 3.5953681468963623 | CLS Loss: 0.02158096246421337
Epoch 8

Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 3.6409811973571777 | KNN Loss: 3.6046719551086426 | CLS Loss: 0.03630915656685829
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 3.6155614852905273 | KNN Loss: 3.608588218688965 | CLS Loss: 0.006973337382078171
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 3.6152987480163574 | KNN Loss: 3.580678701400757 | CLS Loss: 0.03462015837430954
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 3.607699394226074 | KNN Loss: 3.5945327281951904 | CLS Loss: 0.013166782446205616
Epoch 89 / 200 | iteration 70 / 171 | Total Loss: 3.649501323699951 | KNN Loss: 3.6113343238830566 | CLS Loss: 0.0381670817732811
Epoch 89 / 200 | iteration 80 / 171 | Total Loss: 3.6177093982696533 | KNN Loss: 3.5968685150146484 | CLS Loss: 0.0208408385515213
Epoch 89 / 200 | iteration 90 / 171 | Total Loss: 3.6179189682006836 | KNN Loss: 3.601796865463257 | CLS Loss: 0.016122013330459595
Epoch 89 / 200 | iteration 100 / 171 | Total Loss: 3.610335111618042 | KNN Loss:

Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 3.6637511253356934 | KNN Loss: 3.635371208190918 | CLS Loss: 0.0283798985183239
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 3.6254169940948486 | KNN Loss: 3.604728937149048 | CLS Loss: 0.02068793959915638
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 3.658700942993164 | KNN Loss: 3.6387224197387695 | CLS Loss: 0.01997850462794304
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 3.6086857318878174 | KNN Loss: 3.58693528175354 | CLS Loss: 0.0217505544424057
Epoch 92 / 200 | iteration 140 / 171 | Total Loss: 3.6589272022247314 | KNN Loss: 3.6508240699768066 | CLS Loss: 0.00810316763818264
Epoch 92 / 200 | iteration 150 / 171 | Total Loss: 3.6004223823547363 | KNN Loss: 3.5737216472625732 | CLS Loss: 0.026700809597969055
Epoch 92 / 200 | iteration 160 / 171 | Total Loss: 3.649937629699707 | KNN Loss: 3.6400482654571533 | CLS Loss: 0.0098893903195858
Epoch 92 / 200 | iteration 170 / 171 | Total Loss: 3.64971923828125 | KNN Los

Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 3.6290297508239746 | KNN Loss: 3.622307777404785 | CLS Loss: 0.006721979007124901
Epoch: 095, Loss: 3.6244, Train: 0.9954, Valid: 0.9863, Best: 0.9875
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 3.6155080795288086 | KNN Loss: 3.6072757244110107 | CLS Loss: 0.008232351392507553
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 3.61220121383667 | KNN Loss: 3.6019012928009033 | CLS Loss: 0.010299992747604847
Epoch 96 / 200 | iteration 20 / 171 | Total Loss: 3.6252317428588867 | KNN Loss: 3.614065647125244 | CLS Loss: 0.01116605568677187
Epoch 96 / 200 | iteration 30 / 171 | Total Loss: 3.630497455596924 | KNN Loss: 3.6249401569366455 | CLS Loss: 0.00555720878764987
Epoch 96 / 200 | iteration 40 / 171 | Total Loss: 3.6307482719421387 | KNN Loss: 3.6112895011901855 | CLS Loss: 0.019458767026662827
Epoch 96 / 200 | iteration 50 / 171 | Total Loss: 3.6126651763916016 | KNN Loss: 3.5999324321746826 | CLS Loss: 0.012732641771435738
Epoch 9

Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 3.670285940170288 | KNN Loss: 3.656085252761841 | CLS Loss: 0.014200570061802864
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 3.6393704414367676 | KNN Loss: 3.6265666484832764 | CLS Loss: 0.012803874909877777
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 3.616502523422241 | KNN Loss: 3.606290817260742 | CLS Loss: 0.010211811400949955
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 3.609179735183716 | KNN Loss: 3.600865125656128 | CLS Loss: 0.008314545266330242
Epoch 99 / 200 | iteration 90 / 171 | Total Loss: 3.6458022594451904 | KNN Loss: 3.634718894958496 | CLS Loss: 0.011083351448178291
Epoch 99 / 200 | iteration 100 / 171 | Total Loss: 3.6333556175231934 | KNN Loss: 3.6135127544403076 | CLS Loss: 0.019842753186821938
Epoch 99 / 200 | iteration 110 / 171 | Total Loss: 3.5879669189453125 | KNN Loss: 3.5675997734069824 | CLS Loss: 0.0203670933842659
Epoch 99 / 200 | iteration 120 / 171 | Total Loss: 3.6800918579101562 | KNN 

Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 3.6326003074645996 | KNN Loss: 3.626147985458374 | CLS Loss: 0.006452325265854597
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 3.5986056327819824 | KNN Loss: 3.575068712234497 | CLS Loss: 0.023536834865808487
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 3.6016316413879395 | KNN Loss: 3.5974018573760986 | CLS Loss: 0.004229837097227573
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 3.6811420917510986 | KNN Loss: 3.674394130706787 | CLS Loss: 0.00674797035753727
Epoch 102 / 200 | iteration 160 / 171 | Total Loss: 3.6441943645477295 | KNN Loss: 3.6273746490478516 | CLS Loss: 0.016819771379232407
Epoch 102 / 200 | iteration 170 / 171 | Total Loss: 3.6239867210388184 | KNN Loss: 3.6141674518585205 | CLS Loss: 0.009819337166845798
Epoch: 102, Loss: 3.6124, Train: 0.9968, Valid: 0.9868, Best: 0.9875
Epoch 103 / 200 | iteration 0 / 171 | Total Loss: 3.5726523399353027 | KNN Loss: 3.5694966316223145 | CLS Loss: 0.00315571599

Epoch: 105, Loss: 3.6099, Train: 0.9966, Valid: 0.9867, Best: 0.9875
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 3.588693857192993 | KNN Loss: 3.579767942428589 | CLS Loss: 0.008925908245146275
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 3.6000139713287354 | KNN Loss: 3.5821282863616943 | CLS Loss: 0.017885761335492134
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 3.5936825275421143 | KNN Loss: 3.5896706581115723 | CLS Loss: 0.004011890850961208
Epoch 106 / 200 | iteration 30 / 171 | Total Loss: 3.5753836631774902 | KNN Loss: 3.5697450637817383 | CLS Loss: 0.0056386590003967285
Epoch 106 / 200 | iteration 40 / 171 | Total Loss: 3.6065213680267334 | KNN Loss: 3.5842292308807373 | CLS Loss: 0.02229221910238266
Epoch 106 / 200 | iteration 50 / 171 | Total Loss: 3.6030256748199463 | KNN Loss: 3.593996524810791 | CLS Loss: 0.00902919378131628
Epoch 106 / 200 | iteration 60 / 171 | Total Loss: 3.59574031829834 | KNN Loss: 3.5622966289520264 | CLS Loss: 0.033443599939346313


Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 3.6601572036743164 | KNN Loss: 3.632415294647217 | CLS Loss: 0.027741989120841026
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 3.6308915615081787 | KNN Loss: 3.6178839206695557 | CLS Loss: 0.013007537461817265
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 3.621035575866699 | KNN Loss: 3.601595640182495 | CLS Loss: 0.01943988725543022
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 3.573359251022339 | KNN Loss: 3.560431480407715 | CLS Loss: 0.012927884235978127
Epoch 109 / 200 | iteration 100 / 171 | Total Loss: 3.6564619541168213 | KNN Loss: 3.6409859657287598 | CLS Loss: 0.015475952066481113
Epoch 109 / 200 | iteration 110 / 171 | Total Loss: 3.6041371822357178 | KNN Loss: 3.5755436420440674 | CLS Loss: 0.02859349548816681
Epoch 109 / 200 | iteration 120 / 171 | Total Loss: 3.685354471206665 | KNN Loss: 3.671964168548584 | CLS Loss: 0.013390250504016876
Epoch 109 / 200 | iteration 130 / 171 | Total Loss: 3.60727238655090

Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 3.622364044189453 | KNN Loss: 3.6030898094177246 | CLS Loss: 0.019274236634373665
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 3.6354804039001465 | KNN Loss: 3.628413677215576 | CLS Loss: 0.007066616788506508
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 3.6217353343963623 | KNN Loss: 3.6079416275024414 | CLS Loss: 0.013793774880468845
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 3.657914638519287 | KNN Loss: 3.6376898288726807 | CLS Loss: 0.020224811509251595
Epoch 112 / 200 | iteration 160 / 171 | Total Loss: 3.6453847885131836 | KNN Loss: 3.6207504272460938 | CLS Loss: 0.024634359404444695
Epoch 112 / 200 | iteration 170 / 171 | Total Loss: 3.582343101501465 | KNN Loss: 3.575068950653076 | CLS Loss: 0.007274136412888765
Epoch: 112, Loss: 3.6163, Train: 0.9938, Valid: 0.9852, Best: 0.9875
Epoch 113 / 200 | iteration 0 / 171 | Total Loss: 3.6486217975616455 | KNN Loss: 3.5922060012817383 | CLS Loss: 0.056415870785

Epoch: 115, Loss: 3.6127, Train: 0.9965, Valid: 0.9873, Best: 0.9877
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 3.6561129093170166 | KNN Loss: 3.6365363597869873 | CLS Loss: 0.019576620310544968
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 3.638606071472168 | KNN Loss: 3.6319005489349365 | CLS Loss: 0.006705405190587044
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 3.5630295276641846 | KNN Loss: 3.542226552963257 | CLS Loss: 0.02080288715660572
Epoch 116 / 200 | iteration 30 / 171 | Total Loss: 3.705322027206421 | KNN Loss: 3.668548345565796 | CLS Loss: 0.03677358105778694
Epoch 116 / 200 | iteration 40 / 171 | Total Loss: 3.5994927883148193 | KNN Loss: 3.591642379760742 | CLS Loss: 0.00785032194107771
Epoch 116 / 200 | iteration 50 / 171 | Total Loss: 3.600027322769165 | KNN Loss: 3.594968795776367 | CLS Loss: 0.005058483220636845
Epoch 116 / 200 | iteration 60 / 171 | Total Loss: 3.6387765407562256 | KNN Loss: 3.626321792602539 | CLS Loss: 0.012454641051590443
Epoch

Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 3.6527762413024902 | KNN Loss: 3.616206645965576 | CLS Loss: 0.036569613963365555
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 3.6106479167938232 | KNN Loss: 3.591291904449463 | CLS Loss: 0.01935601606965065
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 3.583786725997925 | KNN Loss: 3.5602989196777344 | CLS Loss: 0.023487897589802742
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 3.5860371589660645 | KNN Loss: 3.57378888130188 | CLS Loss: 0.01224822923541069
Epoch 119 / 200 | iteration 100 / 171 | Total Loss: 3.641547203063965 | KNN Loss: 3.6205272674560547 | CLS Loss: 0.02101999707520008
Epoch 119 / 200 | iteration 110 / 171 | Total Loss: 3.5877671241760254 | KNN Loss: 3.569577932357788 | CLS Loss: 0.018189137801527977
Epoch 119 / 200 | iteration 120 / 171 | Total Loss: 3.5822908878326416 | KNN Loss: 3.559713125228882 | CLS Loss: 0.022577764466404915
Epoch 119 / 200 | iteration 130 / 171 | Total Loss: 3.685929775238037 

Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 3.6100127696990967 | KNN Loss: 3.5973424911499023 | CLS Loss: 0.012670342810451984
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 3.573925733566284 | KNN Loss: 3.5707552433013916 | CLS Loss: 0.0031705701258033514
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 3.644120931625366 | KNN Loss: 3.6242971420288086 | CLS Loss: 0.019823750481009483
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 3.5831189155578613 | KNN Loss: 3.5766780376434326 | CLS Loss: 0.006440938916057348
Epoch 122 / 200 | iteration 160 / 171 | Total Loss: 3.5909249782562256 | KNN Loss: 3.5724759101867676 | CLS Loss: 0.018449164927005768
Epoch 122 / 200 | iteration 170 / 171 | Total Loss: 3.6484649181365967 | KNN Loss: 3.6397483348846436 | CLS Loss: 0.008716619573533535
Epoch: 122, Loss: 3.6143, Train: 0.9949, Valid: 0.9844, Best: 0.9877
Epoch 123 / 200 | iteration 0 / 171 | Total Loss: 3.5803630352020264 | KNN Loss: 3.5702266693115234 | CLS Loss: 0.01013643

Epoch: 125, Loss: 3.6164, Train: 0.9963, Valid: 0.9866, Best: 0.9877
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 3.5924618244171143 | KNN Loss: 3.575559139251709 | CLS Loss: 0.016902752220630646
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 3.598458766937256 | KNN Loss: 3.582456350326538 | CLS Loss: 0.01600247249007225
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 3.594820737838745 | KNN Loss: 3.584144115447998 | CLS Loss: 0.010676519013941288
Epoch 126 / 200 | iteration 30 / 171 | Total Loss: 3.5910470485687256 | KNN Loss: 3.5745248794555664 | CLS Loss: 0.01652228645980358
Epoch 126 / 200 | iteration 40 / 171 | Total Loss: 3.579653739929199 | KNN Loss: 3.5648913383483887 | CLS Loss: 0.014762460254132748
Epoch 126 / 200 | iteration 50 / 171 | Total Loss: 3.574450731277466 | KNN Loss: 3.5623207092285156 | CLS Loss: 0.012129954993724823
Epoch 126 / 200 | iteration 60 / 171 | Total Loss: 3.6342904567718506 | KNN Loss: 3.6186904907226562 | CLS Loss: 0.015599938109517097
Epo

Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 3.5673868656158447 | KNN Loss: 3.561445713043213 | CLS Loss: 0.005941166076809168
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 3.6052794456481934 | KNN Loss: 3.601505756378174 | CLS Loss: 0.0037736310623586178
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 3.6172239780426025 | KNN Loss: 3.6154584884643555 | CLS Loss: 0.0017655410338193178
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 3.5998589992523193 | KNN Loss: 3.5767412185668945 | CLS Loss: 0.023117845878005028
Epoch 129 / 200 | iteration 100 / 171 | Total Loss: 3.630061626434326 | KNN Loss: 3.6064648628234863 | CLS Loss: 0.02359664998948574
Epoch 129 / 200 | iteration 110 / 171 | Total Loss: 3.624075412750244 | KNN Loss: 3.611659526824951 | CLS Loss: 0.01241578720510006
Epoch 129 / 200 | iteration 120 / 171 | Total Loss: 3.594989538192749 | KNN Loss: 3.583191156387329 | CLS Loss: 0.011798477731645107
Epoch 129 / 200 | iteration 130 / 171 | Total Loss: 3.611742734909

Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 3.6130335330963135 | KNN Loss: 3.6087419986724854 | CLS Loss: 0.004291596356779337
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 3.589982748031616 | KNN Loss: 3.584198474884033 | CLS Loss: 0.005784334149211645
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 3.584033727645874 | KNN Loss: 3.5796096324920654 | CLS Loss: 0.004424158949404955
Epoch 132 / 200 | iteration 160 / 171 | Total Loss: 3.59462308883667 | KNN Loss: 3.567808151245117 | CLS Loss: 0.0268148984760046
Epoch 132 / 200 | iteration 170 / 171 | Total Loss: 3.6687536239624023 | KNN Loss: 3.6445584297180176 | CLS Loss: 0.024195225909352303
Epoch: 132, Loss: 3.6076, Train: 0.9958, Valid: 0.9860, Best: 0.9877
Epoch 133 / 200 | iteration 0 / 171 | Total Loss: 3.6357345581054688 | KNN Loss: 3.6161179542541504 | CLS Loss: 0.019616615027189255
Epoch 133 / 200 | iteration 10 / 171 | Total Loss: 3.5950515270233154 | KNN Loss: 3.586660385131836 | CLS Loss: 0.00839118380099535

Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 3.5610289573669434 | KNN Loss: 3.5594146251678467 | CLS Loss: 0.0016144179971888661
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 3.6158742904663086 | KNN Loss: 3.6068971157073975 | CLS Loss: 0.00897705927491188
Epoch 136 / 200 | iteration 30 / 171 | Total Loss: 3.6829161643981934 | KNN Loss: 3.66926908493042 | CLS Loss: 0.01364714652299881
Epoch 136 / 200 | iteration 40 / 171 | Total Loss: 3.6661744117736816 | KNN Loss: 3.6549806594848633 | CLS Loss: 0.01119375042617321
Epoch 136 / 200 | iteration 50 / 171 | Total Loss: 3.5809338092803955 | KNN Loss: 3.5691380500793457 | CLS Loss: 0.011795686557888985
Epoch 136 / 200 | iteration 60 / 171 | Total Loss: 3.574237823486328 | KNN Loss: 3.566317319869995 | CLS Loss: 0.00792052410542965
Epoch 136 / 200 | iteration 70 / 171 | Total Loss: 3.6133646965026855 | KNN Loss: 3.600616693496704 | CLS Loss: 0.012748093344271183
Epoch 136 / 200 | iteration 80 / 171 | Total Loss: 3.6210227012634277 

Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 3.58958101272583 | KNN Loss: 3.5708680152893066 | CLS Loss: 0.018713073804974556
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 3.5985968112945557 | KNN Loss: 3.5779452323913574 | CLS Loss: 0.020651696249842644
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 3.6102895736694336 | KNN Loss: 3.5998618602752686 | CLS Loss: 0.010427772998809814
Epoch 139 / 200 | iteration 100 / 171 | Total Loss: 3.550642967224121 | KNN Loss: 3.536332368850708 | CLS Loss: 0.014310515485703945
Epoch 139 / 200 | iteration 110 / 171 | Total Loss: 3.5837700366973877 | KNN Loss: 3.580122709274292 | CLS Loss: 0.0036474394146353006
Epoch 139 / 200 | iteration 120 / 171 | Total Loss: 3.6114580631256104 | KNN Loss: 3.6008615493774414 | CLS Loss: 0.01059652492403984
Epoch 139 / 200 | iteration 130 / 171 | Total Loss: 3.6001102924346924 | KNN Loss: 3.5706570148468018 | CLS Loss: 0.029453160241246223
Epoch 139 / 200 | iteration 140 / 171 | Total Loss: 3.596314907

Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 3.5890674591064453 | KNN Loss: 3.565585136413574 | CLS Loss: 0.02348233386874199
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 3.61161470413208 | KNN Loss: 3.581075429916382 | CLS Loss: 0.030539188534021378
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 3.6062328815460205 | KNN Loss: 3.5816354751586914 | CLS Loss: 0.02459731511771679
Epoch 142 / 200 | iteration 160 / 171 | Total Loss: 3.6792173385620117 | KNN Loss: 3.6246109008789062 | CLS Loss: 0.05460640788078308
Epoch 142 / 200 | iteration 170 / 171 | Total Loss: 3.651214122772217 | KNN Loss: 3.605290174484253 | CLS Loss: 0.04592403396964073
Epoch: 142, Loss: 3.6130, Train: 0.9965, Valid: 0.9863, Best: 0.9877
Epoch 143 / 200 | iteration 0 / 171 | Total Loss: 3.573319673538208 | KNN Loss: 3.5469679832458496 | CLS Loss: 0.026351701468229294
Epoch 143 / 200 | iteration 10 / 171 | Total Loss: 3.559103488922119 | KNN Loss: 3.5551562309265137 | CLS Loss: 0.003947185818105936
E

Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 3.5896432399749756 | KNN Loss: 3.5823323726654053 | CLS Loss: 0.0073108673095703125
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 3.581458330154419 | KNN Loss: 3.5765645503997803 | CLS Loss: 0.004893822595477104
Epoch 146 / 200 | iteration 30 / 171 | Total Loss: 3.6006951332092285 | KNN Loss: 3.5832879543304443 | CLS Loss: 0.017407283186912537
Epoch 146 / 200 | iteration 40 / 171 | Total Loss: 3.5913913249969482 | KNN Loss: 3.573756217956543 | CLS Loss: 0.017635134980082512
Epoch 146 / 200 | iteration 50 / 171 | Total Loss: 3.6220993995666504 | KNN Loss: 3.6197056770324707 | CLS Loss: 0.0023938361555337906
Epoch 146 / 200 | iteration 60 / 171 | Total Loss: 3.6426050662994385 | KNN Loss: 3.613386392593384 | CLS Loss: 0.02921857312321663
Epoch 146 / 200 | iteration 70 / 171 | Total Loss: 3.6541919708251953 | KNN Loss: 3.6353280544281006 | CLS Loss: 0.018863823264837265
Epoch 146 / 200 | iteration 80 / 171 | Total Loss: 3.67793631553

Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 3.667750597000122 | KNN Loss: 3.657876491546631 | CLS Loss: 0.009874027222394943
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 3.5609798431396484 | KNN Loss: 3.5565030574798584 | CLS Loss: 0.004476810339838266
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 3.589830160140991 | KNN Loss: 3.5805208683013916 | CLS Loss: 0.009309304878115654
Epoch 149 / 200 | iteration 100 / 171 | Total Loss: 3.6372735500335693 | KNN Loss: 3.6219236850738525 | CLS Loss: 0.01534995436668396
Epoch 149 / 200 | iteration 110 / 171 | Total Loss: 3.590461254119873 | KNN Loss: 3.56977915763855 | CLS Loss: 0.020682021975517273
Epoch 149 / 200 | iteration 120 / 171 | Total Loss: 3.6229746341705322 | KNN Loss: 3.6078553199768066 | CLS Loss: 0.015119382180273533
Epoch 149 / 200 | iteration 130 / 171 | Total Loss: 3.554780960083008 | KNN Loss: 3.5528335571289062 | CLS Loss: 0.0019473781576380134
Epoch 149 / 200 | iteration 140 / 171 | Total Loss: 3.57972383499

Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 3.6103904247283936 | KNN Loss: 3.600930690765381 | CLS Loss: 0.009459781460464
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 3.619649648666382 | KNN Loss: 3.592257499694824 | CLS Loss: 0.027392050251364708
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 3.6176135540008545 | KNN Loss: 3.5943808555603027 | CLS Loss: 0.02323276549577713
Epoch 152 / 200 | iteration 160 / 171 | Total Loss: 3.5575950145721436 | KNN Loss: 3.548957347869873 | CLS Loss: 0.008637731894850731
Epoch 152 / 200 | iteration 170 / 171 | Total Loss: 3.6263792514801025 | KNN Loss: 3.601743221282959 | CLS Loss: 0.024636147543787956
Epoch: 152, Loss: 3.6114, Train: 0.9971, Valid: 0.9868, Best: 0.9877
Epoch 153 / 200 | iteration 0 / 171 | Total Loss: 3.6297600269317627 | KNN Loss: 3.609898567199707 | CLS Loss: 0.019861385226249695
Epoch 153 / 200 | iteration 10 / 171 | Total Loss: 3.625674247741699 | KNN Loss: 3.60467791557312 | CLS Loss: 0.020996246486902237
Ep

Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 3.592717170715332 | KNN Loss: 3.590681552886963 | CLS Loss: 0.0020356278400868177
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 3.571916103363037 | KNN Loss: 3.563425302505493 | CLS Loss: 0.008490857668220997
Epoch 156 / 200 | iteration 30 / 171 | Total Loss: 3.5947558879852295 | KNN Loss: 3.5844175815582275 | CLS Loss: 0.010338193736970425
Epoch 156 / 200 | iteration 40 / 171 | Total Loss: 3.6217103004455566 | KNN Loss: 3.6191670894622803 | CLS Loss: 0.0025432920083403587
Epoch 156 / 200 | iteration 50 / 171 | Total Loss: 3.6377947330474854 | KNN Loss: 3.612751007080078 | CLS Loss: 0.02504374459385872
Epoch 156 / 200 | iteration 60 / 171 | Total Loss: 3.571620464324951 | KNN Loss: 3.560326099395752 | CLS Loss: 0.01129437331110239
Epoch 156 / 200 | iteration 70 / 171 | Total Loss: 3.586012125015259 | KNN Loss: 3.5716958045959473 | CLS Loss: 0.014316221699118614
Epoch 156 / 200 | iteration 80 / 171 | Total Loss: 3.653892755508423 |

Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 3.5892202854156494 | KNN Loss: 3.5673866271972656 | CLS Loss: 0.021833760663866997
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 3.623215675354004 | KNN Loss: 3.6153366565704346 | CLS Loss: 0.007879078388214111
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 3.6845314502716064 | KNN Loss: 3.6749908924102783 | CLS Loss: 0.009540514089167118
Epoch 159 / 200 | iteration 100 / 171 | Total Loss: 3.65633225440979 | KNN Loss: 3.6339476108551025 | CLS Loss: 0.022384658455848694
Epoch 159 / 200 | iteration 110 / 171 | Total Loss: 3.6059606075286865 | KNN Loss: 3.5956501960754395 | CLS Loss: 0.010310517624020576
Epoch 159 / 200 | iteration 120 / 171 | Total Loss: 3.6664841175079346 | KNN Loss: 3.6341187953948975 | CLS Loss: 0.032365210354328156
Epoch 159 / 200 | iteration 130 / 171 | Total Loss: 3.5950911045074463 | KNN Loss: 3.572720766067505 | CLS Loss: 0.022370368242263794
Epoch 159 / 200 | iteration 140 / 171 | Total Loss: 3.59227681

Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 3.561610460281372 | KNN Loss: 3.5555505752563477 | CLS Loss: 0.006059957202523947
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 3.6036105155944824 | KNN Loss: 3.5765738487243652 | CLS Loss: 0.027036620303988457
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 3.5745978355407715 | KNN Loss: 3.571420192718506 | CLS Loss: 0.00317764631472528
Epoch 162 / 200 | iteration 160 / 171 | Total Loss: 3.631453275680542 | KNN Loss: 3.6268575191497803 | CLS Loss: 0.0045957970432937145
Epoch 162 / 200 | iteration 170 / 171 | Total Loss: 3.6564853191375732 | KNN Loss: 3.650585651397705 | CLS Loss: 0.005899720825254917
Epoch: 162, Loss: 3.6073, Train: 0.9972, Valid: 0.9870, Best: 0.9877
Epoch 163 / 200 | iteration 0 / 171 | Total Loss: 3.58699631690979 | KNN Loss: 3.5810065269470215 | CLS Loss: 0.005989685654640198
Epoch 163 / 200 | iteration 10 / 171 | Total Loss: 3.6058897972106934 | KNN Loss: 3.595282793045044 | CLS Loss: 0.010607009753584

Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 3.610295295715332 | KNN Loss: 3.599414587020874 | CLS Loss: 0.010880750603973866
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 3.600490093231201 | KNN Loss: 3.5959813594818115 | CLS Loss: 0.004508644342422485
Epoch 166 / 200 | iteration 30 / 171 | Total Loss: 3.6041746139526367 | KNN Loss: 3.590022087097168 | CLS Loss: 0.01415251474827528
Epoch 166 / 200 | iteration 40 / 171 | Total Loss: 3.6109721660614014 | KNN Loss: 3.5998878479003906 | CLS Loss: 0.011084415018558502
Epoch 166 / 200 | iteration 50 / 171 | Total Loss: 3.6076087951660156 | KNN Loss: 3.5857436656951904 | CLS Loss: 0.0218652430921793
Epoch 166 / 200 | iteration 60 / 171 | Total Loss: 3.612952709197998 | KNN Loss: 3.600865125656128 | CLS Loss: 0.012087572365999222
Epoch 166 / 200 | iteration 70 / 171 | Total Loss: 3.6187756061553955 | KNN Loss: 3.599456548690796 | CLS Loss: 0.01931903511285782
Epoch 166 / 200 | iteration 80 / 171 | Total Loss: 3.6160459518432617 | K

Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 3.619835376739502 | KNN Loss: 3.614124298095703 | CLS Loss: 0.0057111261412501335
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 3.659003257751465 | KNN Loss: 3.6524784564971924 | CLS Loss: 0.006524842232465744
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 3.674769163131714 | KNN Loss: 3.6512980461120605 | CLS Loss: 0.023471098393201828
Epoch 169 / 200 | iteration 100 / 171 | Total Loss: 3.6432082653045654 | KNN Loss: 3.6216230392456055 | CLS Loss: 0.021585343405604362
Epoch 169 / 200 | iteration 110 / 171 | Total Loss: 3.587221145629883 | KNN Loss: 3.576486110687256 | CLS Loss: 0.010734958574175835
Epoch 169 / 200 | iteration 120 / 171 | Total Loss: 3.6051242351531982 | KNN Loss: 3.6006057262420654 | CLS Loss: 0.004518571309745312
Epoch 169 / 200 | iteration 130 / 171 | Total Loss: 3.5958802700042725 | KNN Loss: 3.586392879486084 | CLS Loss: 0.00948740728199482
Epoch 169 / 200 | iteration 140 / 171 | Total Loss: 3.61589217185

Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 3.6051454544067383 | KNN Loss: 3.5983338356018066 | CLS Loss: 0.006811695639044046
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 3.595937967300415 | KNN Loss: 3.583299160003662 | CLS Loss: 0.012638726271688938
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 3.6070070266723633 | KNN Loss: 3.5919599533081055 | CLS Loss: 0.015047064982354641
Epoch 172 / 200 | iteration 160 / 171 | Total Loss: 3.668304920196533 | KNN Loss: 3.660273551940918 | CLS Loss: 0.008031333796679974
Epoch 172 / 200 | iteration 170 / 171 | Total Loss: 3.6334006786346436 | KNN Loss: 3.622349977493286 | CLS Loss: 0.01105080172419548
Epoch: 172, Loss: 3.6123, Train: 0.9956, Valid: 0.9863, Best: 0.9885
Epoch 173 / 200 | iteration 0 / 171 | Total Loss: 3.6683459281921387 | KNN Loss: 3.640483856201172 | CLS Loss: 0.02786218374967575
Epoch 173 / 200 | iteration 10 / 171 | Total Loss: 3.585043430328369 | KNN Loss: 3.5653820037841797 | CLS Loss: 0.01966133154928684

Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 3.5862302780151367 | KNN Loss: 3.569478988647461 | CLS Loss: 0.016751373186707497
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 3.5696170330047607 | KNN Loss: 3.563492774963379 | CLS Loss: 0.00612422963604331
Epoch 176 / 200 | iteration 30 / 171 | Total Loss: 3.5802314281463623 | KNN Loss: 3.572284460067749 | CLS Loss: 0.007946949452161789
Epoch 176 / 200 | iteration 40 / 171 | Total Loss: 3.575937271118164 | KNN Loss: 3.5668232440948486 | CLS Loss: 0.009114140644669533
Epoch 176 / 200 | iteration 50 / 171 | Total Loss: 3.6394729614257812 | KNN Loss: 3.6343910694122314 | CLS Loss: 0.005081782117486
Epoch 176 / 200 | iteration 60 / 171 | Total Loss: 3.6693246364593506 | KNN Loss: 3.650700569152832 | CLS Loss: 0.01862398535013199
Epoch 176 / 200 | iteration 70 / 171 | Total Loss: 3.5968399047851562 | KNN Loss: 3.5809834003448486 | CLS Loss: 0.01585659757256508
Epoch 176 / 200 | iteration 80 / 171 | Total Loss: 3.5746941566467285 | K

Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 3.608278512954712 | KNN Loss: 3.591878890991211 | CLS Loss: 0.016399675980210304
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 3.629469156265259 | KNN Loss: 3.6031579971313477 | CLS Loss: 0.02631116472184658
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 3.5785350799560547 | KNN Loss: 3.573194980621338 | CLS Loss: 0.005340203642845154
Epoch 179 / 200 | iteration 100 / 171 | Total Loss: 3.6118016242980957 | KNN Loss: 3.601712226867676 | CLS Loss: 0.010089440271258354
Epoch 179 / 200 | iteration 110 / 171 | Total Loss: 3.6131718158721924 | KNN Loss: 3.5885632038116455 | CLS Loss: 0.024608612060546875
Epoch 179 / 200 | iteration 120 / 171 | Total Loss: 3.6418910026550293 | KNN Loss: 3.6226749420166016 | CLS Loss: 0.019216017797589302
Epoch 179 / 200 | iteration 130 / 171 | Total Loss: 3.5918469429016113 | KNN Loss: 3.5655362606048584 | CLS Loss: 0.026310762390494347
Epoch 179 / 200 | iteration 140 / 171 | Total Loss: 3.6133613586

Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 3.652622938156128 | KNN Loss: 3.6363136768341064 | CLS Loss: 0.01630914770066738
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 3.5854251384735107 | KNN Loss: 3.582097291946411 | CLS Loss: 0.0033279391936957836
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 3.6310925483703613 | KNN Loss: 3.627176284790039 | CLS Loss: 0.0039161862805485725
Epoch 182 / 200 | iteration 160 / 171 | Total Loss: 3.628720760345459 | KNN Loss: 3.601499557495117 | CLS Loss: 0.0272213164716959
Epoch 182 / 200 | iteration 170 / 171 | Total Loss: 3.613668203353882 | KNN Loss: 3.6050148010253906 | CLS Loss: 0.008653485216200352
Epoch: 182, Loss: 3.6052, Train: 0.9967, Valid: 0.9869, Best: 0.9885
Epoch 183 / 200 | iteration 0 / 171 | Total Loss: 3.586686134338379 | KNN Loss: 3.5837438106536865 | CLS Loss: 0.002942229388281703
Epoch 183 / 200 | iteration 10 / 171 | Total Loss: 3.640517234802246 | KNN Loss: 3.6372783184051514 | CLS Loss: 0.00323893688619136

Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 3.5789926052093506 | KNN Loss: 3.5747244358062744 | CLS Loss: 0.004268075339496136
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 3.6802499294281006 | KNN Loss: 3.6704108715057373 | CLS Loss: 0.009839098900556564
Epoch 186 / 200 | iteration 30 / 171 | Total Loss: 3.640944242477417 | KNN Loss: 3.633958339691162 | CLS Loss: 0.006985971704125404
Epoch 186 / 200 | iteration 40 / 171 | Total Loss: 3.641639471054077 | KNN Loss: 3.6195478439331055 | CLS Loss: 0.02209164947271347
Epoch 186 / 200 | iteration 50 / 171 | Total Loss: 3.5816233158111572 | KNN Loss: 3.578193426132202 | CLS Loss: 0.0034298275131732225
Epoch 186 / 200 | iteration 60 / 171 | Total Loss: 3.5820138454437256 | KNN Loss: 3.576108932495117 | CLS Loss: 0.005904989317059517
Epoch 186 / 200 | iteration 70 / 171 | Total Loss: 3.6609060764312744 | KNN Loss: 3.656869888305664 | CLS Loss: 0.004036185797303915
Epoch 186 / 200 | iteration 80 / 171 | Total Loss: 3.561441898345947

Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 3.600261926651001 | KNN Loss: 3.5846450328826904 | CLS Loss: 0.015616881661117077
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 3.6257970333099365 | KNN Loss: 3.6098709106445312 | CLS Loss: 0.015926232561469078
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 3.6277945041656494 | KNN Loss: 3.61649489402771 | CLS Loss: 0.01129960548132658
Epoch 189 / 200 | iteration 100 / 171 | Total Loss: 3.6359384059906006 | KNN Loss: 3.6051104068756104 | CLS Loss: 0.03082795813679695
Epoch 189 / 200 | iteration 110 / 171 | Total Loss: 3.627716302871704 | KNN Loss: 3.6101317405700684 | CLS Loss: 0.017584552988409996
Epoch 189 / 200 | iteration 120 / 171 | Total Loss: 3.651902914047241 | KNN Loss: 3.629162549972534 | CLS Loss: 0.022740429267287254
Epoch 189 / 200 | iteration 130 / 171 | Total Loss: 3.6467716693878174 | KNN Loss: 3.6378331184387207 | CLS Loss: 0.008938651531934738
Epoch 189 / 200 | iteration 140 / 171 | Total Loss: 3.625596046447

Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 3.605968475341797 | KNN Loss: 3.589134454727173 | CLS Loss: 0.01683402806520462
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 3.5807862281799316 | KNN Loss: 3.576380729675293 | CLS Loss: 0.0044055818580091
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 3.5892229080200195 | KNN Loss: 3.5634138584136963 | CLS Loss: 0.02580908313393593
Epoch 192 / 200 | iteration 160 / 171 | Total Loss: 3.624157667160034 | KNN Loss: 3.6008410453796387 | CLS Loss: 0.023316603153944016
Epoch 192 / 200 | iteration 170 / 171 | Total Loss: 3.6173601150512695 | KNN Loss: 3.596792697906494 | CLS Loss: 0.020567480474710464
Epoch: 192, Loss: 3.6088, Train: 0.9963, Valid: 0.9871, Best: 0.9885
Epoch 193 / 200 | iteration 0 / 171 | Total Loss: 3.617873191833496 | KNN Loss: 3.600017547607422 | CLS Loss: 0.017855700105428696
Epoch 193 / 200 | iteration 10 / 171 | Total Loss: 3.620344877243042 | KNN Loss: 3.585192918777466 | CLS Loss: 0.03515198081731796
Epo

Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 3.56596040725708 | KNN Loss: 3.5565896034240723 | CLS Loss: 0.009370706044137478
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 3.569096565246582 | KNN Loss: 3.5610709190368652 | CLS Loss: 0.00802560430020094
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 3.7674403190612793 | KNN Loss: 3.7329766750335693 | CLS Loss: 0.034463535994291306
Epoch 196 / 200 | iteration 40 / 171 | Total Loss: 3.6068832874298096 | KNN Loss: 3.5955653190612793 | CLS Loss: 0.011317950673401356
Epoch 196 / 200 | iteration 50 / 171 | Total Loss: 3.605165481567383 | KNN Loss: 3.593226432800293 | CLS Loss: 0.011938974261283875
Epoch 196 / 200 | iteration 60 / 171 | Total Loss: 3.58068585395813 | KNN Loss: 3.57743501663208 | CLS Loss: 0.003250734182074666
Epoch 196 / 200 | iteration 70 / 171 | Total Loss: 3.6278390884399414 | KNN Loss: 3.586243152618408 | CLS Loss: 0.041595980525016785
Epoch 196 / 200 | iteration 80 / 171 | Total Loss: 3.6637308597564697 | K

Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 3.5880959033966064 | KNN Loss: 3.5725131034851074 | CLS Loss: 0.015582882799208164
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 3.620037078857422 | KNN Loss: 3.5903987884521484 | CLS Loss: 0.029638340696692467
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 3.633448362350464 | KNN Loss: 3.6003639698028564 | CLS Loss: 0.03308449313044548
Epoch 199 / 200 | iteration 110 / 171 | Total Loss: 3.605926513671875 | KNN Loss: 3.601919174194336 | CLS Loss: 0.004007227253168821
Epoch 199 / 200 | iteration 120 / 171 | Total Loss: 3.648345708847046 | KNN Loss: 3.6371028423309326 | CLS Loss: 0.011242974549531937
Epoch 199 / 200 | iteration 130 / 171 | Total Loss: 3.5967555046081543 | KNN Loss: 3.5879805088043213 | CLS Loss: 0.00877506285905838
Epoch 199 / 200 | iteration 140 / 171 | Total Loss: 3.622724771499634 | KNN Loss: 3.608492374420166 | CLS Loss: 0.014232449233531952
Epoch 199 / 200 | iteration 150 / 171 | Total Loss: 3.611444473266

In [8]:
test(model, test_data_iter, device)

tensor(0.9865)

In [9]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

HBox(children=(FloatProgress(value=0.0, max=43.0), HTML(value='')))




In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
res = []
for i in np.arange(0.5, 4, 0.1):
    clusters = DBSCAN(eps=i, min_samples=10).fit_predict(projections)
    print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")
    res.append(sum(clusters != -1) / len(clusters))

Number of inliers: 0.618564706957197
Number of inliers: 0.6916084235530583
Number of inliers: 0.7413548947055868
Number of inliers: 0.7792243387693573
Number of inliers: 0.8090996299849254
Number of inliers: 0.8321227901877484
Number of inliers: 0.8487506281120095
Number of inliers: 0.865743913023617
Number of inliers: 0.881138367365584
Number of inliers: 0.8950710337581654
Number of inliers: 0.9076332739481979
Number of inliers: 0.9195103010369559
Number of inliers: 0.9304737106573477
Number of inliers: 0.939792608834681
Number of inliers: 0.9477867616828833
Number of inliers: 0.9554154675437394
Number of inliers: 0.9634096203919419
Number of inliers: 0.9691197295692293
Number of inliers: 0.9738705404047325
Number of inliers: 0.9789411173541638
Number of inliers: 0.9821387784934448
Number of inliers: 0.9841944177972682
Number of inliers: 0.9866155040884381
Number of inliers: 0.9886711433922617
Number of inliers: 0.990407016582157
Number of inliers: 0.9923712941391439
Number of inliers

In [26]:
plt.figure()
plt.plot(np.arange(0.5, 4, 0.1), res)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [36]:
clusters = DBSCAN(eps=1.5, min_samples=25).fit_predict(projections)

In [37]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.8923758622264858


In [38]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [39]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [40]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [41]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [46]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 10
use_cuda = device != 'cpu'

In [47]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [48]:
losses = []
accs = []
sparsity = []

In [None]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=1)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 039 | Total loss: 2.504 | Reg loss: 0.011 | Tree loss: 2.504 | Accuracy: 0.027344 | 7.191 sec/iter
Epoch: 00 | Batch: 010 / 039 | Total loss: 2.358 | Reg loss: 0.012 | Tree loss: 2.358 | Accuracy: 0.179688 | 6.568 sec/iter
Epoch: 00 | Batch: 020 / 039 | Total loss: 2.228 | Reg loss: 0.014 | Tree loss: 2.228 | Accuracy: 0.298828 | 6.546 sec/iter
Epoch: 00 | Batch: 030 / 039 | Total loss: 2.087 | Reg loss: 0.018 | Tree loss: 2.087 | Accuracy: 0.404297 | 6.471 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 01 | Batch: 000 / 039 | Total loss: 2.493 | Reg loss: 0.004 | Tre

Epoch: 10 | Batch: 020 / 039 | Total loss: 1.993 | Reg loss: 0.014 | Tree loss: 1.993 | Accuracy: 0.398438 | 6.256 sec/iter
Epoch: 10 | Batch: 030 / 039 | Total loss: 1.799 | Reg loss: 0.017 | Tree loss: 1.799 | Accuracy: 0.400391 | 6.258 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 11 | Batch: 000 / 039 | Total loss: 2.474 | Reg loss: 0.010 | Tree loss: 2.474 | Accuracy: 0.146484 | 6.245 sec/iter
Epoch: 11 | Batch: 010 / 039 | Total loss: 2.243 | Reg loss: 0.012 | Tree loss: 2.243 | Accuracy: 0.199219 | 6.242 sec/iter
Epoch: 11 | Batch: 020 / 039 | Total loss: 1.965 | Reg loss: 0.015 | Tree loss: 1.965 | Accuracy: 0.402344 | 6.248 sec/iter
Epoch: 11 | Batch: 030 / 039 | Total loss: 1.733 | Reg loss: 0.017 | Tree loss: 1.733 | A

Epoch: 21 | Batch: 000 / 039 | Total loss: 2.450 | Reg loss: 0.013 | Tree loss: 2.450 | Accuracy: 0.294922 | 6.206 sec/iter
Epoch: 21 | Batch: 010 / 039 | Total loss: 2.071 | Reg loss: 0.015 | Tree loss: 2.071 | Accuracy: 0.357422 | 6.21 sec/iter
Epoch: 21 | Batch: 020 / 039 | Total loss: 1.765 | Reg loss: 0.017 | Tree loss: 1.765 | Accuracy: 0.476562 | 6.208 sec/iter
Epoch: 21 | Batch: 030 / 039 | Total loss: 1.646 | Reg loss: 0.019 | Tree loss: 1.646 | Accuracy: 0.458984 | 6.207 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 22 | Batch: 000 / 039 | Total loss: 2.451 | Reg loss: 0.014 | Tree loss: 2.451 | Accuracy: 0.250000 | 6.199 sec/iter
Epoch: 22 | Batch: 010 / 039 | Total loss: 2.093 | Reg loss: 0.015 | Tree loss: 2.093 | Ac

Epoch: 31 | Batch: 030 / 039 | Total loss: 1.556 | Reg loss: 0.021 | Tree loss: 1.556 | Accuracy: 0.435547 | 6.129 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 32 | Batch: 000 / 039 | Total loss: 2.427 | Reg loss: 0.016 | Tree loss: 2.427 | Accuracy: 0.312500 | 6.123 sec/iter
Epoch: 32 | Batch: 010 / 039 | Total loss: 2.059 | Reg loss: 0.017 | Tree loss: 2.059 | Accuracy: 0.433594 | 6.12 sec/iter
Epoch: 32 | Batch: 020 / 039 | Total loss: 1.698 | Reg loss: 0.019 | Tree loss: 1.698 | Accuracy: 0.474609 | 6.118 sec/iter
Epoch: 32 | Batch: 030 / 039 | Total loss: 1.488 | Reg loss: 0.021 | Tree loss: 1.488 | Accuracy: 0.496094 | 6.115 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.994680851063

Epoch: 42 | Batch: 000 / 039 | Total loss: 2.421 | Reg loss: 0.018 | Tree loss: 2.421 | Accuracy: 0.300781 | 6.048 sec/iter
Epoch: 42 | Batch: 010 / 039 | Total loss: 1.988 | Reg loss: 0.019 | Tree loss: 1.988 | Accuracy: 0.484375 | 6.05 sec/iter
Epoch: 42 | Batch: 020 / 039 | Total loss: 1.661 | Reg loss: 0.021 | Tree loss: 1.661 | Accuracy: 0.468750 | 6.051 sec/iter
Epoch: 42 | Batch: 030 / 039 | Total loss: 1.453 | Reg loss: 0.023 | Tree loss: 1.453 | Accuracy: 0.501953 | 6.052 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 43 | Batch: 000 / 039 | Total loss: 2.424 | Reg loss: 0.019 | Tree loss: 2.424 | Accuracy: 0.257812 | 6.05 sec/iter
Epoch: 43 | Batch: 010 / 039 | Total loss: 2.026 | Reg loss: 0.019 | Tree loss: 2.026 | Acc

Epoch: 52 | Batch: 030 / 039 | Total loss: 1.470 | Reg loss: 0.024 | Tree loss: 1.470 | Accuracy: 0.539062 | 6.045 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 53 | Batch: 000 / 039 | Total loss: 2.399 | Reg loss: 0.020 | Tree loss: 2.399 | Accuracy: 0.281250 | 6.048 sec/iter
Epoch: 53 | Batch: 010 / 039 | Total loss: 1.923 | Reg loss: 0.021 | Tree loss: 1.923 | Accuracy: 0.466797 | 6.049 sec/iter
Epoch: 53 | Batch: 020 / 039 | Total loss: 1.624 | Reg loss: 0.022 | Tree loss: 1.624 | Accuracy: 0.494141 | 6.052 sec/iter
Epoch: 53 | Batch: 030 / 039 | Total loss: 1.413 | Reg loss: 0.024 | Tree loss: 1.413 | Accuracy: 0.533203 | 6.053 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.99468085106

Epoch: 63 | Batch: 000 / 039 | Total loss: 2.388 | Reg loss: 0.021 | Tree loss: 2.388 | Accuracy: 0.273438 | 6.07 sec/iter
Epoch: 63 | Batch: 010 / 039 | Total loss: 1.956 | Reg loss: 0.021 | Tree loss: 1.956 | Accuracy: 0.464844 | 6.072 sec/iter
Epoch: 63 | Batch: 020 / 039 | Total loss: 1.588 | Reg loss: 0.023 | Tree loss: 1.588 | Accuracy: 0.458984 | 6.075 sec/iter
Epoch: 63 | Batch: 030 / 039 | Total loss: 1.414 | Reg loss: 0.024 | Tree loss: 1.414 | Accuracy: 0.531250 | 6.075 sec/iter
Average sparseness: 0.9946808510638298
layer 0: 0.9946808510638298
layer 1: 0.9946808510638298
layer 2: 0.9946808510638298
layer 3: 0.9946808510638298
layer 4: 0.9946808510638298
layer 5: 0.9946808510638298
layer 6: 0.9946808510638298
layer 7: 0.9946808510638298
layer 8: 0.9946808510638298
Epoch: 64 | Batch: 000 / 039 | Total loss: 2.402 | Reg loss: 0.021 | Tree loss: 2.402 | Accuracy: 0.238281 | 6.073 sec/iter
Epoch: 64 | Batch: 010 / 039 | Total loss: 1.958 | Reg loss: 0.022 | Tree loss: 1.958 | Ac

In [None]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
print(f"Number of patterns: {len(root.get_leaves())}")

In [None]:
method = 'greedy'

In [None]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")