In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
# Experiment configuration.
k = 8              # `k` passed to ClassificationKNNLoss (presumably the number of neighbours)
tree_depth = 10    # not used by any cell visible here
batch_size = 512
seed = 0           # fixed seed so weight init and loader shuffling are reproducible
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(seed)  # seeds the RNG on all devices (CPU and CUDA)
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [3]:
# Settings shared by both loaders. Pinned host memory only speeds up
# host -> CUDA copies, so enable it only when training on a CUDA device.
_loader_kwargs = dict(batch_size=batch_size,
                      num_workers=1,
                      pin_memory=str(device).startswith('cuda'))

# drop_last=True: every training batch has exactly `batch_size` samples
# (presumably so the per-batch KNN loss always sees a full batch -- confirm).
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                              shuffle=True,
                                              drop_last=True,
                                              **_loader_kwargs)

# Evaluation order does not change accuracy, so the test split is not
# shuffled: evaluation stays deterministic and does not consume global RNG.
test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             shuffle=False,
                                             **_loader_kwargs)

In [4]:
class ConvBlock(nn.Module):
    """Residual 1-D convolution block followed by max-pooling.

    Computes ``pool(relu(conv2(relu(conv1(x))) + x))``. Both convolutions map
    32 -> 32 channels with kernel 5, stride 1, padding 2, so they preserve the
    sequence length and the skip connection can be added directly. The final
    MaxPool1d(kernel 5, stride 2) maps length L to floor((L - 5) / 2) + 1.

    Input: tensor with 32 channels, e.g. shape (N, 32, L).
    """

    def __init__(self):
        super().__init__()
        conv_cfg = dict(kernel_size=5, stride=1, padding=2)
        # Submodules are registered in the same order and under the same names
        # as before, so state_dict keys and seeded initialisation are unchanged.
        self.conv1 = nn.Conv1d(32, 32, **conv_cfg)
        self.conv2 = nn.Conv1d(32, 32, **conv_cfg)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)

    def forward(self, x):
        residual = self.conv2(self.relu1(self.conv1(x)))
        return self.pool(self.relu2(residual + x))


class ECGModel(nn.Module):
    """Residual 1-D CNN classifier for single-channel sequences.

    Layout: conv Conv1d(1 -> 32, kernel 5, padding 1) -> block1..block5
    (ConvBlock) -> flatten -> fc1 Linear(64, 32) -> ReLU -> fc2
    Linear(32, num_classes).

    ``fc1`` is hard-wired to 64 input features, i.e. 32 channels x 2 positions.
    The stem conv shortens the sequence by 2, and each ConvBlock's max-pool maps
    L -> floor((L - 5) / 2) + 1. A length-187 input therefore goes
    187 -> 185 -> 91 -> 44 -> 20 -> 8 -> 2.
    NOTE(review): MIT-BIH beats are presumably 187 samples long -- confirm
    against ``MITBIH``. Inputs whose length does not reduce to 2 will raise a
    shape mismatch at ``fc1``.

    Args:
        num_classes: number of output logits. Defaults to 5, the previously
            hard-coded value, so existing calls and checkpoints are unaffected.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        # Registration order (and thus state_dict keys / seeded init) is
        # identical to the original definition.
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)  # 64 = 32 channels x 2 positions (see class docstring)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x, return_interm=False):
        """Compute class logits.

        Args:
            x: input with a single channel (``self.conv`` has in_channels=1),
                e.g. shape (N, 1, L).
            return_interm: if True, also return the flattened features that
                feed ``fc1``.

        Returns:
            ``logits`` of shape (N, num_classes), or ``(logits, interm)`` where
            ``interm`` is the (N, 64) flattened output of ``block5`` (``train``
            feeds it to the KNN loss).
        """
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)

        if return_interm:
            return x, interm

        return x

In [5]:
# Auxiliary KNN-based loss; `train` applies it to the model's flattened intermediate features.
knn_crt = ClassificationKNNLoss(k=k)
knn_crt = knn_crt.to(device)

def train(model, loader, optimizer, device, *, epoch=None, epochs=None, log_every=None):
    """Run one training epoch and return the mean per-batch loss.

    The objective for each batch is cross-entropy on the logits plus the
    notebook-level ``knn_crt`` loss on the intermediate features returned by
    ``model(batch, return_interm=True)``. If ``knn_crt`` raises ``ValueError``
    or returns a non-finite value (inf or NaN), that batch uses cross-entropy
    alone (the KNN term is replaced by 0).

    Args:
        model: module whose forward accepts ``return_interm=True`` and then
            returns ``(logits, features)``.
        loader: sized iterable of ``(batch, target)`` pairs.
        optimizer: optimizer over ``model``'s parameters.
        device: device that batches and targets are moved to.
        epoch, epochs: only used in the progress message. When omitted, they
            fall back to the notebook variables of the same name (the original
            behaviour).
        log_every: print progress every ``log_every`` iterations. When omitted,
            falls back to the notebook variable ``log_every``, or 10 if it does
            not exist. A value of 0 disables logging.

    Returns:
        float: sum of per-batch total losses divided by ``len(loader)``
        (0.0 for an empty loader).
    """
    # Backward compatibility: this function used to read these as globals.
    nb_globals = globals()
    if epoch is None:
        epoch = nb_globals.get('epoch')
    if epochs is None:
        epochs = nb_globals.get('epochs')
    if log_every is None:
        log_every = nb_globals.get('log_every', 10)

    model.train()
    # Float zero on the right device, used when the KNN term must be dropped.
    zero = torch.zeros((), device=device)

    total_loss = 0.0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        # F.cross_entropy already reduces to the batch-mean scalar.
        cls_loss = F.cross_entropy(outputs, target)
        try:
            knn_loss = knn_crt(interm, target)
        except ValueError:
            knn_loss = zero
        else:
            # NaN would poison the weights just like inf, so reject both.
            if not torch.isfinite(knn_loss):
                knn_loss = zero
        loss = cls_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if log_every and iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    return total_loss / max(len(loader), 1)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
# Training hyperparameters: number of epochs, Adam learning rate, and how
# often (in iterations) `train` prints progress.
epochs, lr, log_every = 200, 1e-3, 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum(p.numel() for p in model.parameters())
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
# Per-epoch history: mean training loss, train accuracy, test accuracy.
best_valid_acc = 0
best_epoch = 0
# Copy of the weights at the best test accuracy. NOTE: selection is done on the
# test split, so `best_valid_acc` is an optimistic estimate.
best_state = None
losses = []
train_accs = []
val_accs = []
# `train` logs the notebook-level `epoch` / `epochs` when they are not passed
# explicitly, so the loop variable keeps the name `epoch`.
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        best_epoch = epoch
        # state_dict() shares storage with the live parameters, so clone it.
        best_state = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.200286865234375 | KNN Loss: 5.570528030395508 | CLS Loss: 1.629758596420288
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.334754467010498 | KNN Loss: 2.962836265563965 | CLS Loss: 1.3719182014465332
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 3.275571346282959 | KNN Loss: 2.6373131275177 | CLS Loss: 0.6382583379745483
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 3.2182888984680176 | KNN Loss: 2.552614212036133 | CLS Loss: 0.6656746864318848
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 3.1036977767944336 | KNN Loss: 2.5522866249084473 | CLS Loss: 0.5514110922813416
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 3.072950839996338 | KNN Loss: 2.5594332218170166 | CLS Loss: 0.5135177373886108
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 3.1318154335021973 | KNN Loss: 2.584388017654419 | CLS Loss: 0.5474274754524231
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 3.086531162261963 | KNN Loss: 2.600828170776367 | CLS 

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 2.577674388885498 | KNN Loss: 2.4594333171844482 | CLS Loss: 0.11824096739292145
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 2.6667587757110596 | KNN Loss: 2.52266788482666 | CLS Loss: 0.14409081637859344
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 2.6705310344696045 | KNN Loss: 2.476156711578369 | CLS Loss: 0.19437432289123535
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 2.605355739593506 | KNN Loss: 2.477708578109741 | CLS Loss: 0.127647265791893
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 2.6377131938934326 | KNN Loss: 2.461115837097168 | CLS Loss: 0.176597461104393
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 2.645249605178833 | KNN Loss: 2.4669461250305176 | CLS Loss: 0.17830350995063782
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 2.6023247241973877 | KNN Loss: 2.5080676078796387 | CLS Loss: 0.09425705671310425
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 2.5906553268432617 | KNN Loss: 2.4685647

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 2.548036575317383 | KNN Loss: 2.425438404083252 | CLS Loss: 0.12259820103645325
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 2.602889060974121 | KNN Loss: 2.4497108459472656 | CLS Loss: 0.1531781554222107
Epoch: 007, Loss: 2.5477, Train: 0.9756, Valid: 0.9728, Best: 0.9728
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 2.501044273376465 | KNN Loss: 2.4234397411346436 | CLS Loss: 0.07760453969240189
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 2.5922679901123047 | KNN Loss: 2.467078447341919 | CLS Loss: 0.1251896321773529
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 2.538774251937866 | KNN Loss: 2.4373648166656494 | CLS Loss: 0.10140953212976456
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 2.5473005771636963 | KNN Loss: 2.436347007751465 | CLS Loss: 0.11095353960990906
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 2.5847597122192383 | KNN Loss: 2.475551128387451 | CLS Loss: 0.10920850187540054
Epoch 8 / 200 | iterati

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 2.505141258239746 | KNN Loss: 2.4568467140197754 | CLS Loss: 0.04829449579119682
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 2.4942829608917236 | KNN Loss: 2.391362190246582 | CLS Loss: 0.10292084515094757
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 2.4991743564605713 | KNN Loss: 2.4225046634674072 | CLS Loss: 0.07666979730129242
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 2.461980104446411 | KNN Loss: 2.391669750213623 | CLS Loss: 0.07031026482582092
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 2.5431132316589355 | KNN Loss: 2.470979928970337 | CLS Loss: 0.07213333249092102
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 2.5163300037384033 | KNN Loss: 2.440521001815796 | CLS Loss: 0.07580895721912384
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 2.4715356826782227 | KNN Loss: 2.4231200218200684 | CLS Loss: 0.04841567203402519
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 2.4814271926879883 | KNN Loss

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 2.5127854347229004 | KNN Loss: 2.414885997772217 | CLS Loss: 0.09789931774139404
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 2.5002973079681396 | KNN Loss: 2.405195951461792 | CLS Loss: 0.09510133415460587
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 2.460742473602295 | KNN Loss: 2.411587953567505 | CLS Loss: 0.04915445297956467
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 2.466646671295166 | KNN Loss: 2.4057438373565674 | CLS Loss: 0.06090276688337326
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 2.462177276611328 | KNN Loss: 2.3895647525787354 | CLS Loss: 0.07261242717504501
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 2.5027318000793457 | KNN Loss: 2.4199397563934326 | CLS Loss: 0.0827919989824295
Epoch: 014, Loss: 2.4868, Train: 0.9833, Valid: 0.9797, Best: 0.9797
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 2.4999172687530518 | KNN Loss: 2.444162130355835 | CLS Loss: 0.05575508251786232
Epoch 15 /

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 2.4529902935028076 | KNN Loss: 2.418497323989868 | CLS Loss: 0.03449290245771408
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 2.4805970191955566 | KNN Loss: 2.4324324131011963 | CLS Loss: 0.0481647253036499
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 2.472309112548828 | KNN Loss: 2.42435359954834 | CLS Loss: 0.04795561730861664
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 2.473250389099121 | KNN Loss: 2.416815996170044 | CLS Loss: 0.05643438547849655
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 2.468705415725708 | KNN Loss: 2.407219409942627 | CLS Loss: 0.061485905200242996
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 2.4704768657684326 | KNN Loss: 2.4206905364990234 | CLS Loss: 0.049786217510700226
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 2.485203742980957 | KNN Loss: 2.426199197769165 | CLS Loss: 0.059004444628953934
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 2.4409685134887695 | KNN Loss: 2.3

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 2.4808499813079834 | KNN Loss: 2.4251315593719482 | CLS Loss: 0.05571844428777695
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 2.433563709259033 | KNN Loss: 2.401470422744751 | CLS Loss: 0.03209329769015312
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 2.424501657485962 | KNN Loss: 2.3625714778900146 | CLS Loss: 0.06193007156252861
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 2.490281581878662 | KNN Loss: 2.4171173572540283 | CLS Loss: 0.07316433638334274
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 2.476388454437256 | KNN Loss: 2.380314588546753 | CLS Loss: 0.09607388079166412
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 2.4758460521698 | KNN Loss: 2.372680902481079 | CLS Loss: 0.10316509753465652
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 2.4553472995758057 | KNN Loss: 2.4163966178894043 | CLS Loss: 0.03895071893930435
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 2.460681915283203 | KNN Loss: 

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 2.4731428623199463 | KNN Loss: 2.3835816383361816 | CLS Loss: 0.08956122398376465
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 2.480616331100464 | KNN Loss: 2.431067705154419 | CLS Loss: 0.049548521637916565
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 2.4604032039642334 | KNN Loss: 2.4391157627105713 | CLS Loss: 0.02128739468753338
Epoch: 024, Loss: 2.4561, Train: 0.9886, Valid: 0.9839, Best: 0.9839
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 2.441122531890869 | KNN Loss: 2.382991313934326 | CLS Loss: 0.05813124030828476
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 2.463212013244629 | KNN Loss: 2.4097578525543213 | CLS Loss: 0.0534542053937912
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 2.4547388553619385 | KNN Loss: 2.3924949169158936 | CLS Loss: 0.06224387139081955
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 2.4742774963378906 | KNN Loss: 2.4180617332458496 | CLS Loss: 0.05621572211384773
Epoch 25 /

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 2.4286694526672363 | KNN Loss: 2.3808529376983643 | CLS Loss: 0.04781649261713028
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 2.4273369312286377 | KNN Loss: 2.4134469032287598 | CLS Loss: 0.013889942318201065
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 2.4584829807281494 | KNN Loss: 2.4104766845703125 | CLS Loss: 0.048006389290094376
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 2.427949905395508 | KNN Loss: 2.369306802749634 | CLS Loss: 0.05864310637116432
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 2.4265265464782715 | KNN Loss: 2.4052140712738037 | CLS Loss: 0.021312404423952103
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 2.445178508758545 | KNN Loss: 2.393740177154541 | CLS Loss: 0.051438264548778534
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 2.4476516246795654 | KNN Loss: 2.4162697792053223 | CLS Loss: 0.03138191998004913
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 2.4530062675476074 | KNN

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 2.448814868927002 | KNN Loss: 2.395123243331909 | CLS Loss: 0.0536915622651577
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 2.4839813709259033 | KNN Loss: 2.415451765060425 | CLS Loss: 0.06852968782186508
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 2.4616382122039795 | KNN Loss: 2.4349606037139893 | CLS Loss: 0.02667768858373165
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 2.4315826892852783 | KNN Loss: 2.379340410232544 | CLS Loss: 0.05224237218499184
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 2.464388847351074 | KNN Loss: 2.397658586502075 | CLS Loss: 0.06673017144203186
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 2.449435234069824 | KNN Loss: 2.4197943210601807 | CLS Loss: 0.029640814289450645
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 2.4375765323638916 | KNN Loss: 2.400148868560791 | CLS Loss: 0.03742769733071327
Epoch: 031, Loss: 2.4438, Train: 0.9897, Valid: 0.9852, Best: 0.9852
Epoch 32

Epoch: 034, Loss: 2.4382, Train: 0.9904, Valid: 0.9831, Best: 0.9852
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 2.453667640686035 | KNN Loss: 2.413444757461548 | CLS Loss: 0.04022293910384178
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 2.482313394546509 | KNN Loss: 2.3888933658599854 | CLS Loss: 0.09341995418071747
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 2.4356231689453125 | KNN Loss: 2.4200551509857178 | CLS Loss: 0.015568077564239502
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 2.428253650665283 | KNN Loss: 2.3907439708709717 | CLS Loss: 0.037509772926568985
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 2.4387760162353516 | KNN Loss: 2.3981502056121826 | CLS Loss: 0.04062582179903984
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 2.39978289604187 | KNN Loss: 2.375951051712036 | CLS Loss: 0.023831866681575775
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 2.4629411697387695 | KNN Loss: 2.394022226333618 | CLS Loss: 0.06891892850399017
Epoch 35 / 20

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 2.4266653060913086 | KNN Loss: 2.390873670578003 | CLS Loss: 0.035791538655757904
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 2.409208059310913 | KNN Loss: 2.3874573707580566 | CLS Loss: 0.02175065688788891
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 2.4285478591918945 | KNN Loss: 2.3856167793273926 | CLS Loss: 0.04293113574385643
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 2.450925827026367 | KNN Loss: 2.38423228263855 | CLS Loss: 0.06669352203607559
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 2.432851791381836 | KNN Loss: 2.3716628551483154 | CLS Loss: 0.061188954859972
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 2.4274561405181885 | KNN Loss: 2.4031805992126465 | CLS Loss: 0.024275517091155052
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 2.439312696456909 | KNN Loss: 2.4160356521606445 | CLS Loss: 0.023277027532458305
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 2.414743185043335 | KNN Los

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 2.4318292140960693 | KNN Loss: 2.3899600505828857 | CLS Loss: 0.041869111359119415
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 2.4123754501342773 | KNN Loss: 2.371616840362549 | CLS Loss: 0.04075856879353523
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 2.473102569580078 | KNN Loss: 2.4372599124908447 | CLS Loss: 0.03584270179271698
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 2.409773826599121 | KNN Loss: 2.382098436355591 | CLS Loss: 0.027675427496433258
Epoch: 041, Loss: 2.4294, Train: 0.9923, Valid: 0.9851, Best: 0.9856
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 2.407243251800537 | KNN Loss: 2.3847317695617676 | CLS Loss: 0.022511545568704605
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 2.4022576808929443 | KNN Loss: 2.3826730251312256 | CLS Loss: 0.019584763795137405
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 2.3996481895446777 | KNN Loss: 2.3610355854034424 | CLS Loss: 0.03861256688833237
Epoch

Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 2.4013092517852783 | KNN Loss: 2.364500045776367 | CLS Loss: 0.0368092842400074
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 2.4418785572052 | KNN Loss: 2.4079415798187256 | CLS Loss: 0.03393692150712013
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 2.4135656356811523 | KNN Loss: 2.361543655395508 | CLS Loss: 0.052021972835063934
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 2.453874111175537 | KNN Loss: 2.4235379695892334 | CLS Loss: 0.030336234718561172
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 2.4761056900024414 | KNN Loss: 2.446397542953491 | CLS Loss: 0.029708154499530792
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 2.3900556564331055 | KNN Loss: 2.3479058742523193 | CLS Loss: 0.04214989393949509
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 2.4200856685638428 | KNN Loss: 2.3991928100585938 | CLS Loss: 0.020892826840281487
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 2.420787811279297 | KNN Loss:

Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 2.4027414321899414 | KNN Loss: 2.3924121856689453 | CLS Loss: 0.010329249314963818
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 2.42620587348938 | KNN Loss: 2.406595230102539 | CLS Loss: 0.019610699266195297
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 2.406614303588867 | KNN Loss: 2.3956828117370605 | CLS Loss: 0.010931439697742462
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 2.403782606124878 | KNN Loss: 2.376964569091797 | CLS Loss: 0.026818130165338516
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 2.4145877361297607 | KNN Loss: 2.3791756629943848 | CLS Loss: 0.035412050783634186
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 2.4001307487487793 | KNN Loss: 2.372769832611084 | CLS Loss: 0.027360999956727028
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 2.4581501483917236 | KNN Loss: 2.438227891921997 | CLS Loss: 0.019922195002436638
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 2.4068164825439453 

Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 2.383180856704712 | KNN Loss: 2.367116689682007 | CLS Loss: 0.016064073890447617
Epoch: 051, Loss: 2.4296, Train: 0.9918, Valid: 0.9846, Best: 0.9859
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 2.439235210418701 | KNN Loss: 2.4167168140411377 | CLS Loss: 0.022518321871757507
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 2.4614853858947754 | KNN Loss: 2.4223248958587646 | CLS Loss: 0.03916052356362343
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 2.46586275100708 | KNN Loss: 2.429654359817505 | CLS Loss: 0.03620840609073639
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 2.4359941482543945 | KNN Loss: 2.4014086723327637 | CLS Loss: 0.03458549454808235
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 2.4478704929351807 | KNN Loss: 2.418701171875 | CLS Loss: 0.029169205576181412
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 2.4477317333221436 | KNN Loss: 2.4179015159606934 | CLS Loss: 0.02983017824590206
Epoch 52 / 200

Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 2.4424192905426025 | KNN Loss: 2.436392307281494 | CLS Loss: 0.006027053575962782
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 2.403080701828003 | KNN Loss: 2.3860957622528076 | CLS Loss: 0.01698501594364643
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 2.408107280731201 | KNN Loss: 2.3774731159210205 | CLS Loss: 0.030634064227342606
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 2.4201889038085938 | KNN Loss: 2.4083738327026367 | CLS Loss: 0.011814998462796211
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 2.436630964279175 | KNN Loss: 2.4129209518432617 | CLS Loss: 0.023710057139396667
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 2.424551010131836 | KNN Loss: 2.398169994354248 | CLS Loss: 0.026381107047200203
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 2.4592528343200684 | KNN Loss: 2.41036319732666 | CLS Loss: 0.04888974875211716
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 2.441842555999756 | KNN L

Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 2.388741970062256 | KNN Loss: 2.3770382404327393 | CLS Loss: 0.011703667230904102
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 2.413048505783081 | KNN Loss: 2.352484941482544 | CLS Loss: 0.060563549399375916
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 2.443957805633545 | KNN Loss: 2.420886993408203 | CLS Loss: 0.023070892319083214
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 2.435236692428589 | KNN Loss: 2.4116647243499756 | CLS Loss: 0.023571966215968132
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 2.400999069213867 | KNN Loss: 2.3807828426361084 | CLS Loss: 0.020216302946209908
Epoch: 058, Loss: 2.4274, Train: 0.9931, Valid: 0.9855, Best: 0.9868
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 2.4178919792175293 | KNN Loss: 2.3945934772491455 | CLS Loss: 0.023298412561416626
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 2.413278341293335 | KNN Loss: 2.389416456222534 | CLS Loss: 0.02386198192834854
Epoch 

Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 2.4439456462860107 | KNN Loss: 2.421152353286743 | CLS Loss: 0.02279340662062168
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 2.438267946243286 | KNN Loss: 2.4255149364471436 | CLS Loss: 0.01275312528014183
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 2.428596019744873 | KNN Loss: 2.4045965671539307 | CLS Loss: 0.023999415338039398
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 2.4088127613067627 | KNN Loss: 2.398982048034668 | CLS Loss: 0.00983075425028801
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 2.432647228240967 | KNN Loss: 2.407931327819824 | CLS Loss: 0.024715907871723175
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 2.3738207817077637 | KNN Loss: 2.351099967956543 | CLS Loss: 0.022720834240317345
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 2.395958423614502 | KNN Loss: 2.3737189769744873 | CLS Loss: 0.022239400073885918
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 2.447103977203369 | KNN Loss: 

Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 2.430645704269409 | KNN Loss: 2.417114496231079 | CLS Loss: 0.01353116799145937
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 2.441927194595337 | KNN Loss: 2.4079346656799316 | CLS Loss: 0.03399248421192169
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 2.417349100112915 | KNN Loss: 2.3898544311523438 | CLS Loss: 0.02749469131231308
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 2.4397170543670654 | KNN Loss: 2.4198453426361084 | CLS Loss: 0.019871629774570465
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 2.4230477809906006 | KNN Loss: 2.407465696334839 | CLS Loss: 0.015582086518406868
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 2.4232308864593506 | KNN Loss: 2.412252426147461 | CLS Loss: 0.010978416539728642
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 2.4082701206207275 | KNN Loss: 2.404573678970337 | CLS Loss: 0.0036965401377528906
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 2.410902738571167 | KN

Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 2.4427478313446045 | KNN Loss: 2.4077916145324707 | CLS Loss: 0.03495631739497185
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 2.402865409851074 | KNN Loss: 2.393317222595215 | CLS Loss: 0.009548068046569824
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 2.3980324268341064 | KNN Loss: 2.392202377319336 | CLS Loss: 0.005830093752592802
Epoch: 068, Loss: 2.4253, Train: 0.9953, Valid: 0.9873, Best: 0.9873
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 2.4014928340911865 | KNN Loss: 2.3811848163604736 | CLS Loss: 0.02030794508755207
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 2.463613986968994 | KNN Loss: 2.417757511138916 | CLS Loss: 0.04585643857717514
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 2.390284538269043 | KNN Loss: 2.371675968170166 | CLS Loss: 0.01860848255455494
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 2.4387784004211426 | KNN Loss: 2.417240619659424 | CLS Loss: 0.021537698805332184
Epoch 69 /

Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 2.4155778884887695 | KNN Loss: 2.408078193664551 | CLS Loss: 0.007499757222831249
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 2.3933913707733154 | KNN Loss: 2.377941131591797 | CLS Loss: 0.01545034907758236
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 2.4248249530792236 | KNN Loss: 2.4084410667419434 | CLS Loss: 0.016383929178118706
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 2.451303482055664 | KNN Loss: 2.41023588180542 | CLS Loss: 0.04106765240430832
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 2.375675678253174 | KNN Loss: 2.369980573654175 | CLS Loss: 0.005694990511983633
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 2.421900510787964 | KNN Loss: 2.4148387908935547 | CLS Loss: 0.007061742711812258
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 2.409619092941284 | KNN Loss: 2.3713834285736084 | CLS Loss: 0.03823569416999817
Epoch 72 / 200 | iteration 110 / 171 | Total Loss: 2.426363468170166 | KNN Loss:

Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 2.4178225994110107 | KNN Loss: 2.394462823867798 | CLS Loss: 0.02335987240076065
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 2.4204583168029785 | KNN Loss: 2.3959810733795166 | CLS Loss: 0.02447732537984848
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 2.386488676071167 | KNN Loss: 2.3843629360198975 | CLS Loss: 0.002125857165083289
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 2.415059804916382 | KNN Loss: 2.4065418243408203 | CLS Loss: 0.008517922833561897
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 2.4154393672943115 | KNN Loss: 2.400676727294922 | CLS Loss: 0.01476267073303461
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 2.4131627082824707 | KNN Loss: 2.4085028171539307 | CLS Loss: 0.004659784026443958
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 2.455623149871826 | KNN Loss: 2.4258384704589844 | CLS Loss: 0.02978459745645523
Epoch: 075, Loss: 2.4250, Train: 0.9953, Valid: 0.9866, Best: 0.9873
Ep

Epoch: 078, Loss: 2.4270, Train: 0.9955, Valid: 0.9868, Best: 0.9873
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 2.376992702484131 | KNN Loss: 2.352743625640869 | CLS Loss: 0.024249140173196793
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 2.3856430053710938 | KNN Loss: 2.361626386642456 | CLS Loss: 0.024016527459025383
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 2.421844244003296 | KNN Loss: 2.3990275859832764 | CLS Loss: 0.022816767916083336
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 2.443485736846924 | KNN Loss: 2.3931186199188232 | CLS Loss: 0.050367195159196854
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 2.4322733879089355 | KNN Loss: 2.420227527618408 | CLS Loss: 0.01204590406268835
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 2.4135196208953857 | KNN Loss: 2.390939235687256 | CLS Loss: 0.0225803442299366
Epoch 79 / 200 | iteration 60 / 171 | Total Loss: 2.396223306655884 | KNN Loss: 2.3897836208343506 | CLS Loss: 0.006439626216888428
Epoch 79 / 2

Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 2.405061721801758 | KNN Loss: 2.384415864944458 | CLS Loss: 0.0206457506865263
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 2.4314162731170654 | KNN Loss: 2.4019508361816406 | CLS Loss: 0.02946542389690876
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 2.4418952465057373 | KNN Loss: 2.423374652862549 | CLS Loss: 0.018520548939704895
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 2.415668249130249 | KNN Loss: 2.4099442958831787 | CLS Loss: 0.005723912268877029
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 2.4273722171783447 | KNN Loss: 2.3950557708740234 | CLS Loss: 0.032316550612449646
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 2.417983055114746 | KNN Loss: 2.3913090229034424 | CLS Loss: 0.026673946529626846
Epoch 82 / 200 | iteration 130 / 171 | Total Loss: 2.4300155639648438 | KNN Loss: 2.410757064819336 | CLS Loss: 0.019258510321378708
Epoch 82 / 200 | iteration 140 / 171 | Total Loss: 2.4115939140319824 | KN

Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 2.3920018672943115 | KNN Loss: 2.3878166675567627 | CLS Loss: 0.004185243975371122
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 2.433656692504883 | KNN Loss: 2.4298200607299805 | CLS Loss: 0.003836589166894555
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 2.4512534141540527 | KNN Loss: 2.437114715576172 | CLS Loss: 0.014138591475784779
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 2.4326424598693848 | KNN Loss: 2.4072489738464355 | CLS Loss: 0.025393567979335785
Epoch: 085, Loss: 2.4234, Train: 0.9950, Valid: 0.9858, Best: 0.9873
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 2.4308853149414062 | KNN Loss: 2.4150936603546143 | CLS Loss: 0.01579161360859871
Epoch 86 / 200 | iteration 10 / 171 | Total Loss: 2.4373247623443604 | KNN Loss: 2.4138340950012207 | CLS Loss: 0.02349069155752659
Epoch 86 / 200 | iteration 20 / 171 | Total Loss: 2.457421064376831 | KNN Loss: 2.4461724758148193 | CLS Loss: 0.011248493567109108
Ep

Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 2.3883755207061768 | KNN Loss: 2.3773701190948486 | CLS Loss: 0.011005382053554058
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 2.4080421924591064 | KNN Loss: 2.4008255004882812 | CLS Loss: 0.007216796278953552
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 2.4349918365478516 | KNN Loss: 2.402470111846924 | CLS Loss: 0.032521724700927734
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 2.408445358276367 | KNN Loss: 2.3998830318450928 | CLS Loss: 0.008562210947275162
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 2.40065598487854 | KNN Loss: 2.3872783184051514 | CLS Loss: 0.01337760966271162
Epoch 89 / 200 | iteration 70 / 171 | Total Loss: 2.4268054962158203 | KNN Loss: 2.3912508487701416 | CLS Loss: 0.03555469959974289
Epoch 89 / 200 | iteration 80 / 171 | Total Loss: 2.433711528778076 | KNN Loss: 2.422884702682495 | CLS Loss: 0.010826931335031986
Epoch 89 / 200 | iteration 90 / 171 | Total Loss: 2.421034097671509 | KNN Los

Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 2.4029977321624756 | KNN Loss: 2.3961119651794434 | CLS Loss: 0.006885819602757692
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 2.4191980361938477 | KNN Loss: 2.392493963241577 | CLS Loss: 0.026704076677560806
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 2.35256290435791 | KNN Loss: 2.3485300540924072 | CLS Loss: 0.004032933618873358
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 2.3934671878814697 | KNN Loss: 2.3859195709228516 | CLS Loss: 0.00754750519990921
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 2.419423818588257 | KNN Loss: 2.401606798171997 | CLS Loss: 0.017816966399550438
Epoch 92 / 200 | iteration 140 / 171 | Total Loss: 2.411888837814331 | KNN Loss: 2.383178949356079 | CLS Loss: 0.028709838166832924
Epoch 92 / 200 | iteration 150 / 171 | Total Loss: 2.4506752490997314 | KNN Loss: 2.4212212562561035 | CLS Loss: 0.02945406548678875
Epoch 92 / 200 | iteration 160 / 171 | Total Loss: 2.401254415512085 | K

Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 2.4127485752105713 | KNN Loss: 2.4051308631896973 | CLS Loss: 0.007617610041052103
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 2.411968469619751 | KNN Loss: 2.4053170680999756 | CLS Loss: 0.006651314906775951
Epoch: 095, Loss: 2.4203, Train: 0.9955, Valid: 0.9867, Best: 0.9873
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 2.44950795173645 | KNN Loss: 2.4479095935821533 | CLS Loss: 0.001598259201273322
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 2.3888707160949707 | KNN Loss: 2.3593785762786865 | CLS Loss: 0.02949216030538082
Epoch 96 / 200 | iteration 20 / 171 | Total Loss: 2.353353261947632 | KNN Loss: 2.3503575325012207 | CLS Loss: 0.0029957792721688747
Epoch 96 / 200 | iteration 30 / 171 | Total Loss: 2.403921604156494 | KNN Loss: 2.387380599975586 | CLS Loss: 0.016541047021746635
Epoch 96 / 200 | iteration 40 / 171 | Total Loss: 2.4101336002349854 | KNN Loss: 2.388340711593628 | CLS Loss: 0.02179277129471302
Epoch 9

Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 2.4299163818359375 | KNN Loss: 2.4222893714904785 | CLS Loss: 0.007626964244991541
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 2.420658826828003 | KNN Loss: 2.40999436378479 | CLS Loss: 0.010664572939276695
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 2.3870248794555664 | KNN Loss: 2.376127004623413 | CLS Loss: 0.01089776586741209
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 2.381593942642212 | KNN Loss: 2.374173164367676 | CLS Loss: 0.0074206627905368805
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 2.4164435863494873 | KNN Loss: 2.393423080444336 | CLS Loss: 0.023020509630441666
Epoch 99 / 200 | iteration 90 / 171 | Total Loss: 2.4080607891082764 | KNN Loss: 2.3970139026641846 | CLS Loss: 0.011046904139220715
Epoch 99 / 200 | iteration 100 / 171 | Total Loss: 2.4019315242767334 | KNN Loss: 2.392029285430908 | CLS Loss: 0.00990226585417986
Epoch 99 / 200 | iteration 110 / 171 | Total Loss: 2.3934903144836426 | KNN L

Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 2.4470889568328857 | KNN Loss: 2.440800428390503 | CLS Loss: 0.006288525182753801
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 2.419532060623169 | KNN Loss: 2.416633129119873 | CLS Loss: 0.0028988898266106844
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 2.4242331981658936 | KNN Loss: 2.400282382965088 | CLS Loss: 0.023950789123773575
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 2.397582769393921 | KNN Loss: 2.3825252056121826 | CLS Loss: 0.015057520940899849
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 2.4539170265197754 | KNN Loss: 2.4146547317504883 | CLS Loss: 0.03926220163702965
Epoch 102 / 200 | iteration 160 / 171 | Total Loss: 2.428683042526245 | KNN Loss: 2.423800230026245 | CLS Loss: 0.004882738459855318
Epoch 102 / 200 | iteration 170 / 171 | Total Loss: 2.3784542083740234 | KNN Loss: 2.3721566200256348 | CLS Loss: 0.006297666113823652
Epoch: 102, Loss: 2.4191, Train: 0.9959, Valid: 0.9868, Best: 

Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 2.4207677841186523 | KNN Loss: 2.3963191509246826 | CLS Loss: 0.024448616430163383
Epoch: 105, Loss: 2.4157, Train: 0.9962, Valid: 0.9857, Best: 0.9873
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 2.4087765216827393 | KNN Loss: 2.388765573501587 | CLS Loss: 0.020010931417346
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 2.422978639602661 | KNN Loss: 2.4101197719573975 | CLS Loss: 0.012858793139457703
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 2.41858172416687 | KNN Loss: 2.3976829051971436 | CLS Loss: 0.02089879661798477
Epoch 106 / 200 | iteration 30 / 171 | Total Loss: 2.4162795543670654 | KNN Loss: 2.411379098892212 | CLS Loss: 0.004900495987385511
Epoch 106 / 200 | iteration 40 / 171 | Total Loss: 2.393186569213867 | KNN Loss: 2.380333423614502 | CLS Loss: 0.012853122316300869
Epoch 106 / 200 | iteration 50 / 171 | Total Loss: 2.3979806900024414 | KNN Loss: 2.3935930728912354 | CLS Loss: 0.004387693013995886
Epoc

Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 2.3843092918395996 | KNN Loss: 2.3605544567108154 | CLS Loss: 0.02375483699142933
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 2.432204484939575 | KNN Loss: 2.42305064201355 | CLS Loss: 0.009153781458735466
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 2.3904967308044434 | KNN Loss: 2.3699452877044678 | CLS Loss: 0.020551485940814018
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 2.406862735748291 | KNN Loss: 2.3841168880462646 | CLS Loss: 0.022745957598090172
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 2.375570058822632 | KNN Loss: 2.3685615062713623 | CLS Loss: 0.007008666172623634
Epoch 109 / 200 | iteration 100 / 171 | Total Loss: 2.4109108448028564 | KNN Loss: 2.3904175758361816 | CLS Loss: 0.020493384450674057
Epoch 109 / 200 | iteration 110 / 171 | Total Loss: 2.3895843029022217 | KNN Loss: 2.375119924545288 | CLS Loss: 0.014464424923062325
Epoch 109 / 200 | iteration 120 / 171 | Total Loss: 2.4272382259368

Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 2.430297613143921 | KNN Loss: 2.4236857891082764 | CLS Loss: 0.0066117895767092705
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 2.3980236053466797 | KNN Loss: 2.39180326461792 | CLS Loss: 0.006220315583050251
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 2.383681058883667 | KNN Loss: 2.3762366771698 | CLS Loss: 0.007444469723850489
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 2.403081178665161 | KNN Loss: 2.3907902240753174 | CLS Loss: 0.012291030958294868
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 2.39434552192688 | KNN Loss: 2.387427806854248 | CLS Loss: 0.006917779799550772
Epoch 112 / 200 | iteration 160 / 171 | Total Loss: 2.3927085399627686 | KNN Loss: 2.3909246921539307 | CLS Loss: 0.0017838198691606522
Epoch 112 / 200 | iteration 170 / 171 | Total Loss: 2.402440309524536 | KNN Loss: 2.39227557182312 | CLS Loss: 0.010164712555706501
Epoch: 112, Loss: 2.4140, Train: 0.9967, Valid: 0.9858, Best: 0.987

Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 2.423062801361084 | KNN Loss: 2.4049456119537354 | CLS Loss: 0.018117304891347885
Epoch: 115, Loss: 2.4119, Train: 0.9962, Valid: 0.9863, Best: 0.9873
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 2.4032175540924072 | KNN Loss: 2.3933804035186768 | CLS Loss: 0.009837208315730095
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 2.440272092819214 | KNN Loss: 2.409156084060669 | CLS Loss: 0.031116044148802757
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 2.3997416496276855 | KNN Loss: 2.395519495010376 | CLS Loss: 0.0042221397161483765
Epoch 116 / 200 | iteration 30 / 171 | Total Loss: 2.438030242919922 | KNN Loss: 2.4234163761138916 | CLS Loss: 0.014613796956837177
Epoch 116 / 200 | iteration 40 / 171 | Total Loss: 2.362839460372925 | KNN Loss: 2.359600305557251 | CLS Loss: 0.0032391392160207033
Epoch 116 / 200 | iteration 50 / 171 | Total Loss: 2.4195475578308105 | KNN Loss: 2.409435510635376 | CLS Loss: 0.010111944749951363

Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 2.384678602218628 | KNN Loss: 2.3718409538269043 | CLS Loss: 0.012837736867368221
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 2.421987295150757 | KNN Loss: 2.4131510257720947 | CLS Loss: 0.008836383931338787
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 2.380098581314087 | KNN Loss: 2.378784418106079 | CLS Loss: 0.0013142063980922103
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 2.4195339679718018 | KNN Loss: 2.39725661277771 | CLS Loss: 0.02227729558944702
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 2.4090404510498047 | KNN Loss: 2.4065518379211426 | CLS Loss: 0.0024885262828320265
Epoch 119 / 200 | iteration 100 / 171 | Total Loss: 2.411806344985962 | KNN Loss: 2.4074387550354004 | CLS Loss: 0.004367700777947903
Epoch 119 / 200 | iteration 110 / 171 | Total Loss: 2.4145240783691406 | KNN Loss: 2.407722234725952 | CLS Loss: 0.006801781244575977
Epoch 119 / 200 | iteration 120 / 171 | Total Loss: 2.3874950408935

Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 2.4538934230804443 | KNN Loss: 2.4384639263153076 | CLS Loss: 0.015429539605975151
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 2.4427754878997803 | KNN Loss: 2.4245285987854004 | CLS Loss: 0.01824679970741272
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 2.392854690551758 | KNN Loss: 2.3757922649383545 | CLS Loss: 0.01706242561340332
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 2.4179563522338867 | KNN Loss: 2.4107754230499268 | CLS Loss: 0.007181039545685053
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 2.4148714542388916 | KNN Loss: 2.3898143768310547 | CLS Loss: 0.02505703642964363
Epoch 122 / 200 | iteration 160 / 171 | Total Loss: 2.4245173931121826 | KNN Loss: 2.4049861431121826 | CLS Loss: 0.01953134499490261
Epoch 122 / 200 | iteration 170 / 171 | Total Loss: 2.4076144695281982 | KNN Loss: 2.3962197303771973 | CLS Loss: 0.011394725181162357
Epoch: 122, Loss: 2.4161, Train: 0.9962, Valid: 0.9856, Best

Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 2.3640105724334717 | KNN Loss: 2.349822998046875 | CLS Loss: 0.014187674038112164
Epoch: 125, Loss: 2.4091, Train: 0.9969, Valid: 0.9870, Best: 0.9873
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 2.3790769577026367 | KNN Loss: 2.3695647716522217 | CLS Loss: 0.009512077085673809
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 2.4047458171844482 | KNN Loss: 2.400683641433716 | CLS Loss: 0.004062094260007143
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 2.3958306312561035 | KNN Loss: 2.38499116897583 | CLS Loss: 0.010839428752660751
Epoch 126 / 200 | iteration 30 / 171 | Total Loss: 2.389822244644165 | KNN Loss: 2.381594657897949 | CLS Loss: 0.008227543905377388
Epoch 126 / 200 | iteration 40 / 171 | Total Loss: 2.4305593967437744 | KNN Loss: 2.4210057258605957 | CLS Loss: 0.009553560987114906
Epoch 126 / 200 | iteration 50 / 171 | Total Loss: 2.431279420852661 | KNN Loss: 2.4188268184661865 | CLS Loss: 0.012452551163733006


Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 2.396239995956421 | KNN Loss: 2.3921101093292236 | CLS Loss: 0.004129784647375345
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 2.383392572402954 | KNN Loss: 2.3664088249206543 | CLS Loss: 0.01698373444378376
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 2.4110445976257324 | KNN Loss: 2.404224395751953 | CLS Loss: 0.006820301990956068
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 2.4406089782714844 | KNN Loss: 2.4294238090515137 | CLS Loss: 0.01118509005755186
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 2.3722786903381348 | KNN Loss: 2.3647451400756836 | CLS Loss: 0.007533631287515163
Epoch 129 / 200 | iteration 100 / 171 | Total Loss: 2.4363529682159424 | KNN Loss: 2.4068217277526855 | CLS Loss: 0.029531141743063927
Epoch 129 / 200 | iteration 110 / 171 | Total Loss: 2.4161946773529053 | KNN Loss: 2.4145376682281494 | CLS Loss: 0.0016569711733609438
Epoch 129 / 200 | iteration 120 / 171 | Total Loss: 2.3672382831

Epoch 132 / 200 | iteration 110 / 171 | Total Loss: 2.4385828971862793 | KNN Loss: 2.43265962600708 | CLS Loss: 0.0059232865460217
Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 2.4079086780548096 | KNN Loss: 2.4016733169555664 | CLS Loss: 0.006235423032194376
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 2.449580669403076 | KNN Loss: 2.4338555335998535 | CLS Loss: 0.01572505570948124
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 2.3679986000061035 | KNN Loss: 2.360365629196167 | CLS Loss: 0.007633047178387642
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 2.3914036750793457 | KNN Loss: 2.382167100906372 | CLS Loss: 0.009236661717295647
Epoch 132 / 200 | iteration 160 / 171 | Total Loss: 2.386888265609741 | KNN Loss: 2.3850085735321045 | CLS Loss: 0.001879691262729466
Epoch 132 / 200 | iteration 170 / 171 | Total Loss: 2.3856234550476074 | KNN Loss: 2.3746795654296875 | CLS Loss: 0.010943992994725704
Epoch: 132, Loss: 2.4078, Train: 0.9974, Valid: 0.9862, Best: 0.

Epoch 135 / 200 | iteration 170 / 171 | Total Loss: 2.3749632835388184 | KNN Loss: 2.370333433151245 | CLS Loss: 0.004629960283637047
Epoch: 135, Loss: 2.4093, Train: 0.9967, Valid: 0.9861, Best: 0.9873
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 2.407447576522827 | KNN Loss: 2.4050748348236084 | CLS Loss: 0.0023726823274046183
Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 2.4324264526367188 | KNN Loss: 2.423845052719116 | CLS Loss: 0.008581466972827911
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 2.4318933486938477 | KNN Loss: 2.4183642864227295 | CLS Loss: 0.01352910976856947
Epoch 136 / 200 | iteration 30 / 171 | Total Loss: 2.426661968231201 | KNN Loss: 2.4244561195373535 | CLS Loss: 0.0022057692985981703
Epoch 136 / 200 | iteration 40 / 171 | Total Loss: 2.367790937423706 | KNN Loss: 2.3606414794921875 | CLS Loss: 0.007149370387196541
Epoch 136 / 200 | iteration 50 / 171 | Total Loss: 2.4401955604553223 | KNN Loss: 2.422445774078369 | CLS Loss: 0.01774968206882476

Epoch 139 / 200 | iteration 50 / 171 | Total Loss: 2.4172120094299316 | KNN Loss: 2.412446975708008 | CLS Loss: 0.0047649601474404335
Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 2.431915521621704 | KNN Loss: 2.3994805812835693 | CLS Loss: 0.0324348509311676
Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 2.405132293701172 | KNN Loss: 2.400583505630493 | CLS Loss: 0.0045488192699849606
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 2.4302492141723633 | KNN Loss: 2.4075887203216553 | CLS Loss: 0.022660406306385994
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 2.4307241439819336 | KNN Loss: 2.3914613723754883 | CLS Loss: 0.039262741804122925
Epoch 139 / 200 | iteration 100 / 171 | Total Loss: 2.4382901191711426 | KNN Loss: 2.429764747619629 | CLS Loss: 0.00852537713944912
Epoch 139 / 200 | iteration 110 / 171 | Total Loss: 2.398163318634033 | KNN Loss: 2.3860394954681396 | CLS Loss: 0.012123714201152325
Epoch 139 / 200 | iteration 120 / 171 | Total Loss: 2.3919785022735

Epoch 142 / 200 | iteration 110 / 171 | Total Loss: 2.388078451156616 | KNN Loss: 2.378056049346924 | CLS Loss: 0.010022372007369995
Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 2.386617422103882 | KNN Loss: 2.3847668170928955 | CLS Loss: 0.0018505797488614917
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 2.3819990158081055 | KNN Loss: 2.3765296936035156 | CLS Loss: 0.005469375289976597
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 2.4530091285705566 | KNN Loss: 2.432772397994995 | CLS Loss: 0.02023683302104473
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 2.3915677070617676 | KNN Loss: 2.368978500366211 | CLS Loss: 0.022589117288589478
Epoch 142 / 200 | iteration 160 / 171 | Total Loss: 2.406944513320923 | KNN Loss: 2.402430772781372 | CLS Loss: 0.00451385322958231
Epoch 142 / 200 | iteration 170 / 171 | Total Loss: 2.4577558040618896 | KNN Loss: 2.440119981765747 | CLS Loss: 0.01763584278523922
Epoch: 142, Loss: 2.4103, Train: 0.9970, Valid: 0.9873, Best: 0.9

Epoch 145 / 200 | iteration 170 / 171 | Total Loss: 2.362464189529419 | KNN Loss: 2.3520984649658203 | CLS Loss: 0.010365644469857216
Epoch: 145, Loss: 2.4098, Train: 0.9973, Valid: 0.9868, Best: 0.9873
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 2.4505908489227295 | KNN Loss: 2.4496190547943115 | CLS Loss: 0.0009716874337755144
Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 2.373491048812866 | KNN Loss: 2.371509313583374 | CLS Loss: 0.001981850014999509
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 2.424107074737549 | KNN Loss: 2.418142318725586 | CLS Loss: 0.005964841693639755
Epoch 146 / 200 | iteration 30 / 171 | Total Loss: 2.3825972080230713 | KNN Loss: 2.370879650115967 | CLS Loss: 0.011717663146555424
Epoch 146 / 200 | iteration 40 / 171 | Total Loss: 2.4214975833892822 | KNN Loss: 2.411771774291992 | CLS Loss: 0.009725798852741718
Epoch 146 / 200 | iteration 50 / 171 | Total Loss: 2.3817975521087646 | KNN Loss: 2.377549171447754 | CLS Loss: 0.0042483326978981495

Epoch 149 / 200 | iteration 50 / 171 | Total Loss: 2.4287261962890625 | KNN Loss: 2.4239253997802734 | CLS Loss: 0.004800730384886265
Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 2.3656692504882812 | KNN Loss: 2.362079381942749 | CLS Loss: 0.0035899209324270487
Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 2.3981258869171143 | KNN Loss: 2.394566774368286 | CLS Loss: 0.0035590820480138063
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 2.4138548374176025 | KNN Loss: 2.39920973777771 | CLS Loss: 0.014645061455667019
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 2.4016120433807373 | KNN Loss: 2.3813247680664062 | CLS Loss: 0.020287370309233665
Epoch 149 / 200 | iteration 100 / 171 | Total Loss: 2.4095265865325928 | KNN Loss: 2.4061920642852783 | CLS Loss: 0.003334427485242486
Epoch 149 / 200 | iteration 110 / 171 | Total Loss: 2.385101556777954 | KNN Loss: 2.381878614425659 | CLS Loss: 0.0032228983473032713
Epoch 149 / 200 | iteration 120 / 171 | Total Loss: 2.410022020

Epoch 152 / 200 | iteration 110 / 171 | Total Loss: 2.420301675796509 | KNN Loss: 2.4019999504089355 | CLS Loss: 0.01830165460705757
Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 2.375483512878418 | KNN Loss: 2.3714241981506348 | CLS Loss: 0.004059280268847942
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 2.406527042388916 | KNN Loss: 2.398822784423828 | CLS Loss: 0.007704178337007761
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 2.4117586612701416 | KNN Loss: 2.3906869888305664 | CLS Loss: 0.02107161283493042
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 2.4128150939941406 | KNN Loss: 2.404513120651245 | CLS Loss: 0.0083018708974123
Epoch 152 / 200 | iteration 160 / 171 | Total Loss: 2.3998820781707764 | KNN Loss: 2.3982737064361572 | CLS Loss: 0.0016084040980786085
Epoch 152 / 200 | iteration 170 / 171 | Total Loss: 2.40511417388916 | KNN Loss: 2.398831367492676 | CLS Loss: 0.00628284877166152
Epoch: 152, Loss: 2.4084, Train: 0.9971, Valid: 0.9852, Best: 0.987

Epoch 155 / 200 | iteration 170 / 171 | Total Loss: 2.4034264087677 | KNN Loss: 2.3831756114959717 | CLS Loss: 0.020250700414180756
Epoch: 155, Loss: 2.4100, Train: 0.9973, Valid: 0.9866, Best: 0.9873
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 2.4005753993988037 | KNN Loss: 2.3992385864257812 | CLS Loss: 0.0013368421932682395
Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 2.4437615871429443 | KNN Loss: 2.441312789916992 | CLS Loss: 0.0024488456547260284
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 2.394867181777954 | KNN Loss: 2.3931117057800293 | CLS Loss: 0.0017554920632392168
Epoch 156 / 200 | iteration 30 / 171 | Total Loss: 2.3774962425231934 | KNN Loss: 2.3681440353393555 | CLS Loss: 0.00935214851051569
Epoch 156 / 200 | iteration 40 / 171 | Total Loss: 2.3658230304718018 | KNN Loss: 2.3592936992645264 | CLS Loss: 0.006529365200549364
Epoch 156 / 200 | iteration 50 / 171 | Total Loss: 2.388507604598999 | KNN Loss: 2.3818066120147705 | CLS Loss: 0.0067009474150836

Epoch 159 / 200 | iteration 50 / 171 | Total Loss: 2.4185779094696045 | KNN Loss: 2.4077422618865967 | CLS Loss: 0.0108357397839427
Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 2.423785924911499 | KNN Loss: 2.407517194747925 | CLS Loss: 0.016268765553832054
Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 2.415050506591797 | KNN Loss: 2.4125800132751465 | CLS Loss: 0.002470549661666155
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 2.3834667205810547 | KNN Loss: 2.3739535808563232 | CLS Loss: 0.00951309222728014
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 2.381997585296631 | KNN Loss: 2.379599094390869 | CLS Loss: 0.0023984508588910103
Epoch 159 / 200 | iteration 100 / 171 | Total Loss: 2.4311444759368896 | KNN Loss: 2.4003872871398926 | CLS Loss: 0.03075719252228737
Epoch 159 / 200 | iteration 110 / 171 | Total Loss: 2.4444899559020996 | KNN Loss: 2.4251363277435303 | CLS Loss: 0.019353697076439857
Epoch 159 / 200 | iteration 120 / 171 | Total Loss: 2.41695380210876

Epoch 162 / 200 | iteration 110 / 171 | Total Loss: 2.405320405960083 | KNN Loss: 2.3948957920074463 | CLS Loss: 0.010424529202282429
Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 2.3790550231933594 | KNN Loss: 2.3770594596862793 | CLS Loss: 0.0019956075120717287
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 2.41799259185791 | KNN Loss: 2.4140822887420654 | CLS Loss: 0.003910236991941929
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 2.421107292175293 | KNN Loss: 2.401874303817749 | CLS Loss: 0.019232885912060738
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 2.3934381008148193 | KNN Loss: 2.3866894245147705 | CLS Loss: 0.006748669780790806
Epoch 162 / 200 | iteration 160 / 171 | Total Loss: 2.398024797439575 | KNN Loss: 2.3891639709472656 | CLS Loss: 0.008860744535923004
Epoch 162 / 200 | iteration 170 / 171 | Total Loss: 2.4304285049438477 | KNN Loss: 2.408466100692749 | CLS Loss: 0.0219624824821949
Epoch: 162, Loss: 2.4038, Train: 0.9976, Valid: 0.9865, Best: 0

Epoch 165 / 200 | iteration 170 / 171 | Total Loss: 2.4317097663879395 | KNN Loss: 2.424625873565674 | CLS Loss: 0.007083808537572622
Epoch: 165, Loss: 2.4065, Train: 0.9963, Valid: 0.9856, Best: 0.9873
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 2.4341816902160645 | KNN Loss: 2.418151617050171 | CLS Loss: 0.01602996699512005
Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 2.392561435699463 | KNN Loss: 2.379124641418457 | CLS Loss: 0.01343678031116724
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 2.390676736831665 | KNN Loss: 2.382258892059326 | CLS Loss: 0.008417951874434948
Epoch 166 / 200 | iteration 30 / 171 | Total Loss: 2.424752712249756 | KNN Loss: 2.4215316772460938 | CLS Loss: 0.0032210731878876686
Epoch 166 / 200 | iteration 40 / 171 | Total Loss: 2.3880441188812256 | KNN Loss: 2.384467363357544 | CLS Loss: 0.0035768328234553337
Epoch 166 / 200 | iteration 50 / 171 | Total Loss: 2.4430792331695557 | KNN Loss: 2.4369232654571533 | CLS Loss: 0.0061558568850159645


Epoch 169 / 200 | iteration 50 / 171 | Total Loss: 2.4255077838897705 | KNN Loss: 2.420090436935425 | CLS Loss: 0.005417464766651392
Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 2.390815496444702 | KNN Loss: 2.3847062587738037 | CLS Loss: 0.006109143141657114
Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 2.4227755069732666 | KNN Loss: 2.4192752838134766 | CLS Loss: 0.0035002168733626604
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 2.4134175777435303 | KNN Loss: 2.4037513732910156 | CLS Loss: 0.009666308760643005
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 2.426567554473877 | KNN Loss: 2.424578905105591 | CLS Loss: 0.001988608855754137
Epoch 169 / 200 | iteration 100 / 171 | Total Loss: 2.382051944732666 | KNN Loss: 2.3725602626800537 | CLS Loss: 0.009491756558418274
Epoch 169 / 200 | iteration 110 / 171 | Total Loss: 2.378204345703125 | KNN Loss: 2.362532615661621 | CLS Loss: 0.015671750530600548
Epoch 169 / 200 | iteration 120 / 171 | Total Loss: 2.417549848556

Epoch 172 / 200 | iteration 110 / 171 | Total Loss: 2.4419405460357666 | KNN Loss: 2.4361395835876465 | CLS Loss: 0.005800983402878046
Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 2.3599190711975098 | KNN Loss: 2.349717855453491 | CLS Loss: 0.010201255790889263
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 2.407052755355835 | KNN Loss: 2.4011101722717285 | CLS Loss: 0.005942523945122957
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 2.421414613723755 | KNN Loss: 2.418891668319702 | CLS Loss: 0.002522991504520178
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 2.388122797012329 | KNN Loss: 2.3839452266693115 | CLS Loss: 0.00417746976017952
Epoch 172 / 200 | iteration 160 / 171 | Total Loss: 2.420609474182129 | KNN Loss: 2.402312994003296 | CLS Loss: 0.018296560272574425
Epoch 172 / 200 | iteration 170 / 171 | Total Loss: 2.4343409538269043 | KNN Loss: 2.4047329425811768 | CLS Loss: 0.029607970267534256
Epoch: 172, Loss: 2.4056, Train: 0.9975, Valid: 0.9879, Best: 0

Epoch 175 / 200 | iteration 170 / 171 | Total Loss: 2.447420597076416 | KNN Loss: 2.435206890106201 | CLS Loss: 0.012213627807796001
Epoch: 175, Loss: 2.4076, Train: 0.9976, Valid: 0.9866, Best: 0.9879
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 2.3773133754730225 | KNN Loss: 2.3642430305480957 | CLS Loss: 0.013070271350443363
Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 2.3882036209106445 | KNN Loss: 2.3842883110046387 | CLS Loss: 0.0039153448306024075
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 2.4226443767547607 | KNN Loss: 2.4037094116210938 | CLS Loss: 0.018934909254312515
Epoch 176 / 200 | iteration 30 / 171 | Total Loss: 2.4463536739349365 | KNN Loss: 2.431215286254883 | CLS Loss: 0.015138499438762665
Epoch 176 / 200 | iteration 40 / 171 | Total Loss: 2.397960662841797 | KNN Loss: 2.3858954906463623 | CLS Loss: 0.0120651014149189
Epoch 176 / 200 | iteration 50 / 171 | Total Loss: 2.412271738052368 | KNN Loss: 2.4024672508239746 | CLS Loss: 0.009804506786167622

Epoch 179 / 200 | iteration 50 / 171 | Total Loss: 2.3864083290100098 | KNN Loss: 2.374725580215454 | CLS Loss: 0.011682683601975441
Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 2.3844540119171143 | KNN Loss: 2.3769640922546387 | CLS Loss: 0.007490007672458887
Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 2.49143385887146 | KNN Loss: 2.4787652492523193 | CLS Loss: 0.012668697163462639
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 2.3748507499694824 | KNN Loss: 2.3733279705047607 | CLS Loss: 0.0015228955307975411
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 2.3721649646759033 | KNN Loss: 2.367286443710327 | CLS Loss: 0.004878529813140631
Epoch 179 / 200 | iteration 100 / 171 | Total Loss: 2.4086239337921143 | KNN Loss: 2.3939995765686035 | CLS Loss: 0.014624453149735928
Epoch 179 / 200 | iteration 110 / 171 | Total Loss: 2.388645648956299 | KNN Loss: 2.3852450847625732 | CLS Loss: 0.0034006594214588404
Epoch 179 / 200 | iteration 120 / 171 | Total Loss: 2.365791797

Epoch 182 / 200 | iteration 110 / 171 | Total Loss: 2.4333391189575195 | KNN Loss: 2.402238368988037 | CLS Loss: 0.031100701540708542
Epoch 182 / 200 | iteration 120 / 171 | Total Loss: 2.443258047103882 | KNN Loss: 2.4226553440093994 | CLS Loss: 0.020602649077773094
Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 2.461087465286255 | KNN Loss: 2.448310136795044 | CLS Loss: 0.012777361087501049
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 2.4481472969055176 | KNN Loss: 2.4400370121002197 | CLS Loss: 0.008110394701361656
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 2.4079201221466064 | KNN Loss: 2.3979594707489014 | CLS Loss: 0.009960675612092018
Epoch 182 / 200 | iteration 160 / 171 | Total Loss: 2.4123709201812744 | KNN Loss: 2.4099221229553223 | CLS Loss: 0.002448693383485079
Epoch 182 / 200 | iteration 170 / 171 | Total Loss: 2.4348597526550293 | KNN Loss: 2.4313321113586426 | CLS Loss: 0.003527525346726179
Epoch: 182, Loss: 2.4061, Train: 0.9966, Valid: 0.9868, Bes

Epoch 185 / 200 | iteration 170 / 171 | Total Loss: 2.418367862701416 | KNN Loss: 2.410569667816162 | CLS Loss: 0.007798283360898495
Epoch: 185, Loss: 2.4164, Train: 0.9969, Valid: 0.9858, Best: 0.9879
Epoch 186 / 200 | iteration 0 / 171 | Total Loss: 2.4439210891723633 | KNN Loss: 2.420003890991211 | CLS Loss: 0.023917116224765778
Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 2.424248456954956 | KNN Loss: 2.4124906063079834 | CLS Loss: 0.011757914908230305
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 2.4127585887908936 | KNN Loss: 2.4028172492980957 | CLS Loss: 0.00994129665195942
Epoch 186 / 200 | iteration 30 / 171 | Total Loss: 2.4063804149627686 | KNN Loss: 2.396390914916992 | CLS Loss: 0.009989582002162933
Epoch 186 / 200 | iteration 40 / 171 | Total Loss: 2.385514259338379 | KNN Loss: 2.378171920776367 | CLS Loss: 0.00734240235760808
Epoch 186 / 200 | iteration 50 / 171 | Total Loss: 2.4034523963928223 | KNN Loss: 2.3900885581970215 | CLS Loss: 0.013363897800445557
Ep

Epoch 189 / 200 | iteration 50 / 171 | Total Loss: 2.4163219928741455 | KNN Loss: 2.404879331588745 | CLS Loss: 0.011442761868238449
Epoch 189 / 200 | iteration 60 / 171 | Total Loss: 2.388144016265869 | KNN Loss: 2.3862338066101074 | CLS Loss: 0.00191009440459311
Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 2.4168198108673096 | KNN Loss: 2.4165663719177246 | CLS Loss: 0.0002535441890358925
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 2.3894286155700684 | KNN Loss: 2.382960081100464 | CLS Loss: 0.006468483246862888
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 2.3842203617095947 | KNN Loss: 2.3822245597839355 | CLS Loss: 0.0019959076307713985
Epoch 189 / 200 | iteration 100 / 171 | Total Loss: 2.443408966064453 | KNN Loss: 2.442505121231079 | CLS Loss: 0.0009037788840942085
Epoch 189 / 200 | iteration 110 / 171 | Total Loss: 2.414090156555176 | KNN Loss: 2.3917224407196045 | CLS Loss: 0.022367624565958977
Epoch 189 / 200 | iteration 120 / 171 | Total Loss: 2.4128868579

Epoch 192 / 200 | iteration 110 / 171 | Total Loss: 2.43387508392334 | KNN Loss: 2.4193668365478516 | CLS Loss: 0.014508312568068504
Epoch 192 / 200 | iteration 120 / 171 | Total Loss: 2.4155220985412598 | KNN Loss: 2.3944644927978516 | CLS Loss: 0.021057486534118652
Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 2.4336280822753906 | KNN Loss: 2.4298341274261475 | CLS Loss: 0.0037940246984362602
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 2.422480344772339 | KNN Loss: 2.409694194793701 | CLS Loss: 0.012786205857992172
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 2.4170117378234863 | KNN Loss: 2.4160447120666504 | CLS Loss: 0.0009671399020589888
Epoch 192 / 200 | iteration 160 / 171 | Total Loss: 2.416262626647949 | KNN Loss: 2.410883903503418 | CLS Loss: 0.005378612782806158
Epoch 192 / 200 | iteration 170 / 171 | Total Loss: 2.4066991806030273 | KNN Loss: 2.402902364730835 | CLS Loss: 0.003796837292611599
Epoch: 192, Loss: 2.4150, Train: 0.9973, Valid: 0.9873, Best

Epoch 195 / 200 | iteration 170 / 171 | Total Loss: 2.4165523052215576 | KNN Loss: 2.3889713287353516 | CLS Loss: 0.02758098393678665
Epoch: 195, Loss: 2.4115, Train: 0.9972, Valid: 0.9874, Best: 0.9879
Epoch 196 / 200 | iteration 0 / 171 | Total Loss: 2.4283392429351807 | KNN Loss: 2.4202663898468018 | CLS Loss: 0.008072968572378159
Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 2.3838796615600586 | KNN Loss: 2.382077932357788 | CLS Loss: 0.0018016919493675232
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 2.477060079574585 | KNN Loss: 2.4472100734710693 | CLS Loss: 0.02985006384551525
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 2.4015281200408936 | KNN Loss: 2.390451669692993 | CLS Loss: 0.011076342314481735
Epoch 196 / 200 | iteration 40 / 171 | Total Loss: 2.4147157669067383 | KNN Loss: 2.410370111465454 | CLS Loss: 0.004345686640590429
Epoch 196 / 200 | iteration 50 / 171 | Total Loss: 2.4007561206817627 | KNN Loss: 2.390831708908081 | CLS Loss: 0.009924476966261864

Epoch 199 / 200 | iteration 50 / 171 | Total Loss: 2.4234280586242676 | KNN Loss: 2.4157543182373047 | CLS Loss: 0.007673816755414009
Epoch 199 / 200 | iteration 60 / 171 | Total Loss: 2.4474942684173584 | KNN Loss: 2.423646926879883 | CLS Loss: 0.023847365751862526
Epoch 199 / 200 | iteration 70 / 171 | Total Loss: 2.4488182067871094 | KNN Loss: 2.435717821121216 | CLS Loss: 0.013100458309054375
Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 2.396406412124634 | KNN Loss: 2.3829433917999268 | CLS Loss: 0.013462989591062069
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 2.4728357791900635 | KNN Loss: 2.4690585136413574 | CLS Loss: 0.0037773221265524626
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 2.397653818130493 | KNN Loss: 2.389310598373413 | CLS Loss: 0.008343221619725227
Epoch 199 / 200 | iteration 110 / 171 | Total Loss: 2.431727409362793 | KNN Loss: 2.4261274337768555 | CLS Loss: 0.0055998992174863815
Epoch 199 / 200 | iteration 120 / 171 | Total Loss: 2.3848183155

In [8]:
# Evaluate the trained model on the held-out test loader.
# NOTE(review): `test` is defined in an earlier cell (not visible here); the displayed
# tensor is presumably the test accuracy — confirm against its definition.
test(model, test_data_iter, device)

tensor(0.9867, device='cuda:0')

In [9]:
# Training-loss curve of the embedding/classification model trained above.
# NOTE(review): `losses` is filled by the earlier training cell; presumably one entry
# per epoch (matching the per-epoch "Loss:" log lines) — confirm.
fig, ax = plt.subplots()
ax.plot(losses, label='train loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Model training loss')
ax.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Train vs. validation accuracy of the embedding/classification model.
# NOTE(review): presumably one entry per epoch (matching the per-epoch
# "Train:"/"Valid:" log lines) — confirm against the training cell.
fig, ax = plt.subplots()
ax.plot(train_accs, label='train accuracy')
ax.plot(val_accs, label='validation accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Model train / validation accuracy')
ax.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
# Embed every test sample with the trained model. The raw inputs, labels and
# embeddings are gathered in the same pass because the loader shuffles, so the
# three tensors stay row-aligned.
sample_chunks, label_chunks, projection_chunks = [], [], []

# NOTE(review): switch to inference mode so any dropout / batch-norm layers behave
# deterministically while extracting embeddings — confirm this matches `test`.
model.eval()
with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_chunks.append(x)
        label_chunks.append(y)
        _, interm = model(x.to(device), True)
        projection_chunks.append(interm.cpu().flatten(1))

# Concatenate once at the end: calling torch.cat on the growing tensors inside the
# loop re-copies everything gathered so far each batch (quadratic). The labels also
# keep the loader's dtype instead of being promoted to float32 by an empty float seed.
test_samples = torch.cat(sample_chunks) if sample_chunks else torch.tensor([])
labels = torch.cat(label_chunks) if label_chunks else torch.tensor([])
projections = torch.cat(projection_chunks) if projection_chunks else torch.tensor([])

  0%|          | 0/43 [00:00<?, ?it/s]

In [12]:
# Pairwise Euclidean distances between all test embeddings, shown as a heat map and
# as a histogram (used to pick the DBSCAN `eps` below).
# NOTE: this is an N x N matrix (N = number of test samples), so it is
# memory-hungry. `ravel()` returns a view when the matrix is contiguous, avoiding
# the second full copy that `flatten()` always made.
distances = pairwise_distances(projections)
distances_f = distances.ravel()

plt.matshow(distances)
plt.colorbar()
plt.figure()
# Drop exact zeros (the diagonal self-distances and any duplicate points).
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
# Self-label the embeddings with DBSCAN. Points DBSCAN treats as noise get label -1,
# which the following cells filter out.
# NOTE(review): eps=2 / min_samples=10 look hand-tuned from the distance histogram above — confirm.
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [14]:
# Report how many embeddings DBSCAN assigned to a cluster (label -1 marks noise),
# both as a count and as a fraction of all test samples.
inlier_mask = clusters != -1
n_inliers = int(inlier_mask.sum())
print(f"Inliers: {n_inliers} / {len(clusters)} (fraction: {n_inliers / len(clusters)})")

Number of inliers: 0.9194646201635375


In [15]:
# 2-D t-SNE view (Multicore-TSNE backend) of the DBSCAN inliers only; the cluster
# ids are passed as `y`, presumably used for colouring by the project helper.
perplexity = 100
tsne_mask = clusters != -1
p = reduce_dims_and_plot(
    projections[tsne_mask],
    y=clusters[tsne_mask],
    title=f'perplexity: {perplexity}',
    file_name=None,
    perplexity=perplexity,
    library='Multicore-TSNE',
    perform_PCA=False,
    projected=None,
    figure_type='2d',
    show_figure=True,
    close_figure=False,
    text=None,
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft Decision Tree on the DBSCAN self-labels

## Prepare the dataset

In [16]:
# Training set for the soft decision tree: the flattened raw test signals of the
# DBSCAN inliers as inputs, and their cluster ids as self-labels.
# A TensorDataset avoids materialising one Python tuple per sample and lets the
# loader stack ready-made tensors.
tree_mask = clusters != -1
tree_dataset = torch.utils.data.TensorDataset(
    test_samples.flatten(1)[tree_mask],
    torch.as_tensor(clusters[tree_mask]),
)
# Separate name so the global `batch_size` from the config cell is not clobbered.
tree_batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=tree_batch_size, shuffle=True)

# Helpers for pruning node weights, measuring sparsity, and level-weighted regularization

In [17]:
def prune_node(node_weights, factor=1):
    """Zero, in place, every weight lying strictly inside mean ± factor·std.

    Args:
        node_weights: tensor of one node's weights; modified in place.
        factor: half-width of the kept-out band, in (population) standard deviations.

    Returns:
        The same ``node_weights`` tensor.
    """
    values = node_weights.cpu().detach().numpy()
    center, spread = np.mean(values), np.std(values)
    lower, upper = center - factor * spread, center + factor * spread
    inside_band = (node_weights > lower) & (node_weights < upper)
    node_weights[inside_band] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Zero, in place, all but the ``keep`` largest-magnitude entries of a 1-D weight vector.

    Args:
        node_weights: 1-D tensor holding one inner node's weights; modified in place.
        keep: number of entries to preserve. ``keep <= 0`` zeroes the whole vector;
            ``keep >= len(node_weights)`` leaves it untouched.

    Returns:
        The same ``node_weights`` tensor.
    """
    w = node_weights.detach().cpu().numpy()
    # Count the entries to drop explicitly: the slice `[:-keep]` is empty for
    # keep == 0 (`[:-0]` == `[:0]`), which silently pruned nothing.
    n_drop = max(len(w) - max(keep, 0), 0)
    if n_drop > 0:
        throw_idx = np.argsort(np.abs(w))[:n_drop]
        node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Sparsify every inner node of a soft decision tree in place.

    Each row of ``tree_.inner_nodes.weight`` is pruned with ``prune_node_keep``.
    ``factor`` is forwarded as that helper's ``keep`` argument, i.e. it is the
    number of weights kept per node (not a std-deviation factor as in ``prune_node``).
    """
    with torch.no_grad():
        # Edit a detached copy so the ~2 in-place ops per row are not recorded by
        # autograd, then write the result back into the parameter with one copy.
        new_weights = tree_.inner_nodes.weight.detach().clone()
        for i in range(new_weights.shape[0]):
            new_weights[i, :] = prune_node_keep(new_weights[i, :], factor)
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Mean, over the rows of a 2-D tensor, of each row's fraction of exactly-zero entries."""
    n_cols = x.shape[1]
    # Count zeros for all rows at once, then bring the counts to the host in one go.
    zeros_per_row = (x == 0).sum(dim=1).tolist()
    return np.mean([zeros / n_cols for zeros in zeros_per_row])

def compute_regularization_by_level(tree):
    """Depth-discounted L2 penalty on a soft decision tree's inner-node weights.

    Inner node ``i`` (0-based, heap-style indexing) is assigned level
    ``floor(log2(i + 1))``. The L2 norm of its weight row is scaled by ``2 ** -level``,
    and all scaled norms are summed.

    Returns:
        A scalar tensor, differentiable w.r.t. the weights. An empty weight
        matrix yields a zero tensor.
    """
    weight = tree.inner_nodes.weight
    n_nodes = weight.shape[0]
    if n_nodes == 0:
        return weight.new_zeros(())
    # (i + 1).bit_length() - 1 == floor(log2(i + 1)), computed exactly on ints.
    scales = torch.tensor(
        [2.0 ** -((i + 1).bit_length() - 1) for i in range(n_nodes)],
        dtype=weight.dtype,
        device=weight.device,
    )
    # One vectorised norm instead of a Python loop launching small ops per node
    # (this runs every training batch).
    node_norms = torch.linalg.vector_norm(weight.reshape(n_nodes, -1), ord=2, dim=1)
    return (scales * node_norms).sum()

def show_sparseness(tree):
    """Print the average and per-layer sparseness of a soft decision tree's inner nodes.

    A node's sparseness is the fraction of exactly-zero entries in its weight row.
    Node ``i`` is assigned layer ``floor(log2(i + 1))`` (heap-style indexing).
    Every layer, including the deepest one, is reported.

    Returns:
        The mean sparseness over all inner nodes.
    """
    weight = tree.inner_nodes.weight
    n_cols = weight.shape[1]
    # One device->host transfer for all nodes instead of an `.item()` sync per node.
    zeros_per_node = (weight == 0).sum(dim=1).tolist()
    node_sps = [zeros / n_cols for zeros in zeros_per_node]

    avg_sp = np.mean(node_sps)
    print(f"Average sparseness: {avg_sp}")

    # Group per layer, then print every group; printing only on a layer change
    # never reported the final (deepest) layer.
    layer_sps = {}
    for i, sp in enumerate(node_sps):
        layer_sps.setdefault((i + 1).bit_length() - 1, []).append(sp)
    for layer, sps in layer_sps.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training loop and configuration

In [18]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration,
             apply_regularization=False):
    """Train ``model`` for one pass over ``loader``.

    Relies on notebook globals defined in other cells: ``criterion``, ``optimizer``,
    ``sparsity_lamda``, ``start_time`` and ``compute_regularization_by_level``.

    Args:
        model: the soft decision tree; calling it returns ``(output, penalty)``.
        loader: yields ``(data, target)`` batches.
        device: device each batch is moved to.
        log_interval: print a status line every this many batches.
        losses: list that each batch's optimised loss is appended to (mutated).
        accs: list that each batch's accuracy is appended to (mutated).
        epoch: epoch number, used only for logging.
        iteration: running iteration counter carried across epochs.
        apply_regularization: if True, add ``sparsity_lamda`` times the level-weighted
            L2 term to the optimised loss. Defaults to False, where the term is only
            computed for the log line (the previous behaviour).

    Returns:
        The updated iteration counter.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)
        target = target.view(-1)
        log_now = batch_idx % log_interval == 0

        # Use the `model` argument (previously the global `tree` was called here).
        output, penalty = model(data)

        # Classification loss plus the tree's own penalty term.
        loss_tree = criterion(output, target) + penalty

        if apply_regularization:
            regularization = sparsity_lamda * compute_regularization_by_level(model)
            loss = loss_tree + regularization
        else:
            loss = loss_tree
            regularization = None
            if log_now:
                # Only needed for logging: skip building an autograd graph for it.
                with torch.no_grad():
                    regularization = sparsity_lamda * compute_regularization_by_level(model)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        batch_acc = (output.argmax(dim=1) == target).sum().item() / data.size(0)
        accs.append(batch_acc)

        # Print training status
        if log_now:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {batch_acc:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")

    return iteration


In [19]:
# Optimiser settings for the soft decision tree.
lr, weight_decay = 5e-3, 5e-4
# Weight of the level-discounted L2 sparsity term (name kept: `do_epoch` reads it).
sparsity_lamda = 2e-3
# Training length and how often (in batches) `do_epoch` prints a status line.
epochs, log_interval = 400, 100
use_cuda = (device != 'cpu')

In [20]:
# One tree output per DBSCAN cluster, excluding the noise label -1. Counting the
# non-noise labels directly stays correct even when DBSCAN marks no point as noise
# (`len(set(clusters)) - 1` under-counted the classes in that case).
n_tree_classes = len(np.unique(clusters[clusters != -1]))
# Input width of the flattened samples the tree is trained on (equals
# test_samples.shape[2] when there is a single channel).
tree_input_dim = test_samples.flatten(1).shape[1]

tree = SDT(input_dim=tree_input_dim, output_dim=n_tree_classes, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
# Move the model first so the optimizer is built over the parameters on their final device.
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [21]:
# Fresh metric buffers for the tree: per-batch loss/accuracy (appended by
# `do_epoch`) and per-epoch average sparseness (appended by the training loop).
losses, accs, sparsity = [], [], []

In [22]:
# Train the tree. Every `prune_interval` epochs (starting at epoch 0), keep only the
# `prune_keep` largest-magnitude weights of each inner node; `prune_tree` forwards
# its `factor` argument as that number of kept weights.
prune_interval = 1
prune_keep = 3

start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Record sparseness before this epoch's updates.
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)

    if epoch % prune_interval == 0:
        prune_tree(tree, factor=prune_keep)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 040 | Total loss: 3.146 | Reg loss: 0.012 | Tree loss: 3.146 | Accuracy: 0.015625 | 0.885 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 01 | Batch: 000 / 040 | Total loss: 3.107 | Reg loss: 0.006 | Tree loss: 3.107 | Accuracy: 0.166016 | 0.768 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 02 | Batch: 000 / 040 

Epoch: 20 | Batch: 000 / 040 | Total loss: 2.533 | Reg loss: 0.024 | Tree loss: 2.533 | Accuracy: 0.310547 | 0.759 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 21 | Batch: 000 / 040 | Total loss: 2.529 | Reg loss: 0.024 | Tree loss: 2.529 | Accuracy: 0.314453 | 0.759 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 22 | Batch: 000 / 040 | Total loss: 2.505 | Reg loss: 0.025 | Tree loss: 2.505 | Accuracy: 0.322266 | 0.759 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 40 | Batch: 000 / 040 | Total loss: 2.398 | Reg loss: 0.028 | Tree loss: 2.398 | Accuracy: 0.347656 | 0.759 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 41 | Batch: 000 / 040 | Total loss: 2.389 | Reg loss: 0.028 | Tree loss: 2.389 | Accuracy: 0.384766 | 0.759 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 42 | Batch: 000 / 040 | Total loss: 2.342 | Reg loss: 0.028 | Tree loss: 2.342 | Accuracy: 0.375000 | 0.759 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 60 | Batch: 000 / 040 | Total loss: 2.353 | Reg loss: 0.029 | Tree loss: 2.353 | Accuracy: 0.361328 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 61 | Batch: 000 / 040 | Total loss: 2.365 | Reg loss: 0.029 | Tree loss: 2.365 | Accuracy: 0.347656 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 62 | Batch: 000 / 040 | Total loss: 2.348 | Reg loss: 0.029 | Tree loss: 2.348 | Accuracy: 0.367188 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0

Epoch: 80 | Batch: 000 / 040 | Total loss: 2.288 | Reg loss: 0.031 | Tree loss: 2.288 | Accuracy: 0.365234 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 81 | Batch: 000 / 040 | Total loss: 2.402 | Reg loss: 0.031 | Tree loss: 2.402 | Accuracy: 0.343750 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 82 | Batch: 000 / 040 | Total loss: 2.367 | Reg loss: 0.031 | Tree loss: 2.367 | Accuracy: 0.357422 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0

Epoch: 100 | Batch: 000 / 040 | Total loss: 2.335 | Reg loss: 0.032 | Tree loss: 2.335 | Accuracy: 0.355469 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 101 | Batch: 000 / 040 | Total loss: 2.328 | Reg loss: 0.032 | Tree loss: 2.328 | Accuracy: 0.378906 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 102 | Batch: 000 / 040 | Total loss: 2.286 | Reg loss: 0.032 | Tree loss: 2.286 | Accuracy: 0.376953 | 0.76 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 120 | Batch: 000 / 040 | Total loss: 2.330 | Reg loss: 0.032 | Tree loss: 2.330 | Accuracy: 0.363281 | 0.757 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 121 | Batch: 000 / 040 | Total loss: 2.333 | Reg loss: 0.032 | Tree loss: 2.333 | Accuracy: 0.382812 | 0.757 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 122 | Batch: 000 / 040 | Total loss: 2.333 | Reg loss: 0.032 | Tree loss: 2.333 | Accuracy: 0.365234 | 0.756 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 140 | Batch: 000 / 040 | Total loss: 2.323 | Reg loss: 0.032 | Tree loss: 2.323 | Accuracy: 0.392578 | 0.752 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 141 | Batch: 000 / 040 | Total loss: 2.377 | Reg loss: 0.032 | Tree loss: 2.377 | Accuracy: 0.328125 | 0.752 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 142 | Batch: 000 / 040 | Total loss: 2.346 | Reg loss: 0.032 | Tree loss: 2.346 | Accuracy: 0.367188 | 0.752 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 160 | Batch: 000 / 040 | Total loss: 2.356 | Reg loss: 0.032 | Tree loss: 2.356 | Accuracy: 0.343750 | 0.749 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 161 | Batch: 000 / 040 | Total loss: 2.341 | Reg loss: 0.032 | Tree loss: 2.341 | Accuracy: 0.365234 | 0.748 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 162 | Batch: 000 / 040 | Total loss: 2.356 | Reg loss: 0.032 | Tree loss: 2.356 | Accuracy: 0.376953 | 0.748 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 180 | Batch: 000 / 040 | Total loss: 2.380 | Reg loss: 0.032 | Tree loss: 2.380 | Accuracy: 0.341797 | 0.745 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 181 | Batch: 000 / 040 | Total loss: 2.382 | Reg loss: 0.032 | Tree loss: 2.382 | Accuracy: 0.330078 | 0.745 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 182 | Batch: 000 / 040 | Total loss: 2.316 | Reg loss: 0.032 | Tree loss: 2.316 | Accuracy: 0.384766 | 0.745 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 200 | Batch: 000 / 040 | Total loss: 2.323 | Reg loss: 0.033 | Tree loss: 2.323 | Accuracy: 0.363281 | 0.743 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 201 | Batch: 000 / 040 | Total loss: 2.341 | Reg loss: 0.033 | Tree loss: 2.341 | Accuracy: 0.378906 | 0.743 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 202 | Batch: 000 / 040 | Total loss: 2.321 | Reg loss: 0.033 | Tree loss: 2.321 | Accuracy: 0.349609 | 0.743 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 220 | Batch: 000 / 040 | Total loss: 2.324 | Reg loss: 0.033 | Tree loss: 2.324 | Accuracy: 0.371094 | 0.742 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 221 | Batch: 000 / 040 | Total loss: 2.354 | Reg loss: 0.033 | Tree loss: 2.354 | Accuracy: 0.365234 | 0.741 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 222 | Batch: 000 / 040 | Total loss: 2.298 | Reg loss: 0.033 | Tree loss: 2.298 | Accuracy: 0.378906 | 0.741 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 240 | Batch: 000 / 040 | Total loss: 2.305 | Reg loss: 0.033 | Tree loss: 2.305 | Accuracy: 0.384766 | 0.74 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 241 | Batch: 000 / 040 | Total loss: 2.277 | Reg loss: 0.033 | Tree loss: 2.277 | Accuracy: 0.414062 | 0.74 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 242 | Batch: 000 / 040 | Total loss: 2.357 | Reg loss: 0.033 | Tree loss: 2.357 | Accuracy: 0.347656 | 0.74 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 263 | Batch: 000 / 040 | Total loss: 2.312 | Reg loss: 0.033 | Tree loss: 2.312 | Accuracy: 0.386719 | 0.738 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 264 | Batch: 000 / 040 | Total loss: 2.333 | Reg loss: 0.033 | Tree loss: 2.333 | Accuracy: 0.402344 | 0.738 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 265 | Batch: 000 / 040 | Total loss: 2.344 | Reg loss: 0.033 | Tree loss: 2.344 | Accuracy: 0.357422 | 0.738 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 283 | Batch: 000 / 040 | Total loss: 2.298 | Reg loss: 0.033 | Tree loss: 2.298 | Accuracy: 0.394531 | 0.737 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 284 | Batch: 000 / 040 | Total loss: 2.348 | Reg loss: 0.033 | Tree loss: 2.348 | Accuracy: 0.376953 | 0.737 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 285 | Batch: 000 / 040 | Total loss: 2.372 | Reg loss: 0.033 | Tree loss: 2.372 | Accuracy: 0.343750 | 0.737 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 303 | Batch: 000 / 040 | Total loss: 2.387 | Reg loss: 0.033 | Tree loss: 2.387 | Accuracy: 0.322266 | 0.736 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 304 | Batch: 000 / 040 | Total loss: 2.338 | Reg loss: 0.033 | Tree loss: 2.338 | Accuracy: 0.365234 | 0.736 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 305 | Batch: 000 / 040 | Total loss: 2.330 | Reg loss: 0.033 | Tree loss: 2.330 | Accuracy: 0.373047 | 0.736 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 323 | Batch: 000 / 040 | Total loss: 2.333 | Reg loss: 0.033 | Tree loss: 2.333 | Accuracy: 0.347656 | 0.735 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 324 | Batch: 000 / 040 | Total loss: 2.359 | Reg loss: 0.033 | Tree loss: 2.359 | Accuracy: 0.353516 | 0.735 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 325 | Batch: 000 / 040 | Total loss: 2.302 | Reg loss: 0.033 | Tree loss: 2.302 | Accuracy: 0.402344 | 0.735 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 343 | Batch: 000 / 040 | Total loss: 2.329 | Reg loss: 0.033 | Tree loss: 2.329 | Accuracy: 0.378906 | 0.734 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 344 | Batch: 000 / 040 | Total loss: 2.323 | Reg loss: 0.033 | Tree loss: 2.323 | Accuracy: 0.390625 | 0.734 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 345 | Batch: 000 / 040 | Total loss: 2.336 | Reg loss: 0.033 | Tree loss: 2.336 | Accuracy: 0.367188 | 0.733 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 363 | Batch: 000 / 040 | Total loss: 2.355 | Reg loss: 0.033 | Tree loss: 2.355 | Accuracy: 0.347656 | 0.733 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 364 | Batch: 000 / 040 | Total loss: 2.356 | Reg loss: 0.033 | Tree loss: 2.356 | Accuracy: 0.349609 | 0.733 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 365 | Batch: 000 / 040 | Total loss: 2.362 | Reg loss: 0.033 | Tree loss: 2.362 | Accuracy: 0.337891 | 0.733 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 383 | Batch: 000 / 040 | Total loss: 2.336 | Reg loss: 0.033 | Tree loss: 2.336 | Accuracy: 0.361328 | 0.732 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 384 | Batch: 000 / 040 | Total loss: 2.362 | Reg loss: 0.033 | Tree loss: 2.362 | Accuracy: 0.369141 | 0.732 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 385 | Batch: 000 / 040 | Total loss: 2.390 | Reg loss: 0.033 | Tree loss: 2.390 | Accuracy: 0.341797 | 0.732 sec/iter
Average sparseness: 0.9840425531914894
laye

In [23]:
# Per-batch training accuracy of the tree (one entry per `do_epoch` batch).
# The description is used as the title: previously it was passed as `label=`
# without a legend, so it was never shown.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(accs)
ax.set_xlabel('Iteration')
ax.set_ylabel("Accuracy")
ax.set_title('Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
# Per-batch training loss of the tree. The description is used as the title:
# previously it was passed as `label=` without a legend, so it was never shown.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Iteration')
ax.set_ylabel("Loss")
ax.set_title('Loss vs iteration')
plt.show()

# Distribution of all inner-node weights after training and pruning, with the
# mean ± 1 std marked in red. Counts are log-scaled because the pruned zeros dominate.
# `.flatten()` (a copy) is kept deliberately: on CPU, `.numpy()` shares memory with
# the parameter, so a view would change if the tree were pruned again later.
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
weights_std = np.std(weights)
weights_mean = np.mean(weights)
fig, ax = plt.subplots()
ax.hist(weights, bins=500)
ax.axvline(weights_mean + weights_std, color='r')
ax.axvline(weights_mean - weights_std, color='r')
ax.set_title(f"Mean: {weights_mean}   |   STD: {weights_std}")
ax.set_yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [25]:
# Render the trained tree. `visualize` is project code (not shown here). It returns
# the average height (also printed) and the root node that the rule-extraction
# cells below operate on.
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 8.075471698113208


# Extract Rules

# Accumulate samples in the leaves

In [26]:
# Each leaf of the tree is treated as one extracted pattern (rule) in the cells below.
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 53


In [27]:
# Leaf-assignment strategy passed to `accumulate_samples` in the next cell.
# NOTE(review): presumably 'greedy' routes each sample down the single most probable
# branch at every inner node — confirm in the SDT implementation.
method = 'greedy'

In [28]:
# Start from empty leaves, then route every tree-training sample to a leaf so
# each leaf's boundaries can later be tightened from the samples that reach it.
root.clear_leaves_samples()

with torch.no_grad():
    for data, target in tree_loader:
        root.accumulate_samples(data, method)



# Tighten boundaries

In [29]:
# For every leaf (pattern): reset its path, tighten its boundaries with the samples
# accumulated above, and score the rule's comprehensibility as the sum of its
# conditions' `comprehensibility` values. Summary statistics are printed at the end.
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum(cond.comprehensibility for cond in conds))

print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

40
30
9615
8687
1756
Average comprehensibility: 45.77358490566038
std comprehensibility: 12.515110126885581
var comprehensibility: 156.62798148807403
minimum comprehensibility: 12
maximum comprehensibility: 60


  return np.log(1 / (1 - x))
