In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 10
batch_size = 512
device = 'cuda'
train_data_path = r'/mnt/qnap/ekosman/mitbih_train.csv'
test_data_path = r'/mnt/qnap/ekosman/mitbih_test.csv'

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.358572006225586 | KNN Loss: 5.69699239730835 | CLS Loss: 1.6615793704986572
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 4.858458518981934 | KNN Loss: 4.118841648101807 | CLS Loss: 0.739617109298706
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 4.7575297355651855 | KNN Loss: 4.001819610595703 | CLS Loss: 0.7557100653648376
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 4.550957679748535 | KNN Loss: 3.9674794673919678 | CLS Loss: 0.5834780335426331
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 4.4556169509887695 | KNN Loss: 3.8908891677856445 | CLS Loss: 0.5647277235984802
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 4.35487699508667 | KNN Loss: 3.8714163303375244 | CLS Loss: 0.48346057534217834
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.348923206329346 | KNN Loss: 3.8277058601379395 | CLS Loss: 0.5212175250053406
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.310129642486572 | KNN Loss: 3.895495653152466 | CL

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 3.8621749877929688 | KNN Loss: 3.7199788093566895 | CLS Loss: 0.14219610393047333
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 3.846350908279419 | KNN Loss: 3.7377986907958984 | CLS Loss: 0.10855214297771454
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 3.8782801628112793 | KNN Loss: 3.679821252822876 | CLS Loss: 0.19845889508724213
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 3.8147568702697754 | KNN Loss: 3.693053722381592 | CLS Loss: 0.121703140437603
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 3.862760305404663 | KNN Loss: 3.7014079093933105 | CLS Loss: 0.16135230660438538
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 3.8173587322235107 | KNN Loss: 3.670098304748535 | CLS Loss: 0.14726035296916962
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 3.7751920223236084 | KNN Loss: 3.7059693336486816 | CLS Loss: 0.06922276318073273
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 3.875886917114258 | KNN Loss: 3.701

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 3.7856101989746094 | KNN Loss: 3.687330722808838 | CLS Loss: 0.09827938675880432
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 3.7485692501068115 | KNN Loss: 3.6494884490966797 | CLS Loss: 0.09908083826303482
Epoch: 007, Loss: 3.7624, Train: 0.9775, Valid: 0.9732, Best: 0.9732
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 3.782048463821411 | KNN Loss: 3.664184093475342 | CLS Loss: 0.11786431819200516
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 3.771681070327759 | KNN Loss: 3.6252777576446533 | CLS Loss: 0.1464032530784607
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 3.708716869354248 | KNN Loss: 3.6196749210357666 | CLS Loss: 0.0890420526266098
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 3.746598958969116 | KNN Loss: 3.6318962574005127 | CLS Loss: 0.11470276117324829
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 3.7737317085266113 | KNN Loss: 3.6878867149353027 | CLS Loss: 0.08584505319595337
Epoch 8 / 200 | itera

Epoch 11 / 200 | iteration 50 / 171 | Total Loss: 3.7452592849731445 | KNN Loss: 3.6759467124938965 | CLS Loss: 0.06931264698505402
Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 3.735203742980957 | KNN Loss: 3.6836297512054443 | CLS Loss: 0.05157403647899628
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 3.7291078567504883 | KNN Loss: 3.664486885070801 | CLS Loss: 0.06462091207504272
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 3.723984479904175 | KNN Loss: 3.665686845779419 | CLS Loss: 0.05829761177301407
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 3.703261613845825 | KNN Loss: 3.6206865310668945 | CLS Loss: 0.0825750008225441
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 3.743879795074463 | KNN Loss: 3.671294927597046 | CLS Loss: 0.07258490473031998
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 3.729762315750122 | KNN Loss: 3.6473186016082764 | CLS Loss: 0.08244366198778152
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 3.7208974361419678 | KNN Loss: 3

Epoch 14 / 200 | iteration 120 / 171 | Total Loss: 3.726778507232666 | KNN Loss: 3.677480459213257 | CLS Loss: 0.049298059195280075
Epoch 14 / 200 | iteration 130 / 171 | Total Loss: 3.713148832321167 | KNN Loss: 3.667767286300659 | CLS Loss: 0.04538147523999214
Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 3.6706933975219727 | KNN Loss: 3.615631341934204 | CLS Loss: 0.055062130093574524
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 3.7304608821868896 | KNN Loss: 3.6823830604553223 | CLS Loss: 0.04807782918214798
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 3.6848113536834717 | KNN Loss: 3.6287119388580322 | CLS Loss: 0.05609948933124542
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 3.6925601959228516 | KNN Loss: 3.6284523010253906 | CLS Loss: 0.0641079694032669
Epoch: 014, Loss: 3.7151, Train: 0.9834, Valid: 0.9782, Best: 0.9801
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 3.7169532775878906 | KNN Loss: 3.634455442428589 | CLS Loss: 0.08249791711568832
Epoch 1

Epoch 18 / 200 | iteration 10 / 171 | Total Loss: 3.6905293464660645 | KNN Loss: 3.6615214347839355 | CLS Loss: 0.029007909819483757
Epoch 18 / 200 | iteration 20 / 171 | Total Loss: 3.706444025039673 | KNN Loss: 3.633948802947998 | CLS Loss: 0.07249516248703003
Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 3.68507981300354 | KNN Loss: 3.631970167160034 | CLS Loss: 0.05310957878828049
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 3.680185317993164 | KNN Loss: 3.606022357940674 | CLS Loss: 0.0741630345582962
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 3.683511972427368 | KNN Loss: 3.6246085166931152 | CLS Loss: 0.05890338122844696
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 3.679985761642456 | KNN Loss: 3.6364259719848633 | CLS Loss: 0.043559860438108444
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 3.67526912689209 | KNN Loss: 3.6158130168914795 | CLS Loss: 0.059456031769514084
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 3.703122138977051 | KNN Loss: 3.665

Epoch 21 / 200 | iteration 80 / 171 | Total Loss: 3.6741855144500732 | KNN Loss: 3.644930601119995 | CLS Loss: 0.029254989698529243
Epoch 21 / 200 | iteration 90 / 171 | Total Loss: 3.691176414489746 | KNN Loss: 3.6542088985443115 | CLS Loss: 0.03696751967072487
Epoch 21 / 200 | iteration 100 / 171 | Total Loss: 3.6645641326904297 | KNN Loss: 3.6213107109069824 | CLS Loss: 0.04325348883867264
Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 3.6763253211975098 | KNN Loss: 3.633737087249756 | CLS Loss: 0.042588163167238235
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 3.675701141357422 | KNN Loss: 3.6434972286224365 | CLS Loss: 0.032203931361436844
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 3.6884822845458984 | KNN Loss: 3.631119728088379 | CLS Loss: 0.05736253783106804
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 3.7111411094665527 | KNN Loss: 3.6606998443603516 | CLS Loss: 0.05044128745794296
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 3.6807398796081543 | K

Epoch 24 / 200 | iteration 150 / 171 | Total Loss: 3.6903743743896484 | KNN Loss: 3.6219820976257324 | CLS Loss: 0.0683922991156578
Epoch 24 / 200 | iteration 160 / 171 | Total Loss: 3.701960325241089 | KNN Loss: 3.6588029861450195 | CLS Loss: 0.043157245963811874
Epoch 24 / 200 | iteration 170 / 171 | Total Loss: 3.6730947494506836 | KNN Loss: 3.6150455474853516 | CLS Loss: 0.05804911628365517
Epoch: 024, Loss: 3.6882, Train: 0.9878, Valid: 0.9827, Best: 0.9833
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 3.6799540519714355 | KNN Loss: 3.621143341064453 | CLS Loss: 0.05881066620349884
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 3.6978490352630615 | KNN Loss: 3.6547534465789795 | CLS Loss: 0.043095577508211136
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 3.680058717727661 | KNN Loss: 3.6266684532165527 | CLS Loss: 0.053390275686979294
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 3.657360792160034 | KNN Loss: 3.6209774017333984 | CLS Loss: 0.03638339042663574
Epoch 2

Epoch 28 / 200 | iteration 40 / 171 | Total Loss: 3.683732271194458 | KNN Loss: 3.63784122467041 | CLS Loss: 0.045891132205724716
Epoch 28 / 200 | iteration 50 / 171 | Total Loss: 3.7241547107696533 | KNN Loss: 3.6810832023620605 | CLS Loss: 0.04307141527533531
Epoch 28 / 200 | iteration 60 / 171 | Total Loss: 3.651355504989624 | KNN Loss: 3.6278281211853027 | CLS Loss: 0.02352738007903099
Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 3.700629711151123 | KNN Loss: 3.6274824142456055 | CLS Loss: 0.0731472596526146
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 3.6668405532836914 | KNN Loss: 3.621349811553955 | CLS Loss: 0.04549062252044678
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 3.6885180473327637 | KNN Loss: 3.663614511489868 | CLS Loss: 0.024903450161218643
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 3.652031660079956 | KNN Loss: 3.612787961959839 | CLS Loss: 0.03924375772476196
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 3.6337103843688965 | KNN Loss: 3

Epoch 31 / 200 | iteration 110 / 171 | Total Loss: 3.688228130340576 | KNN Loss: 3.6507999897003174 | CLS Loss: 0.03742813691496849
Epoch 31 / 200 | iteration 120 / 171 | Total Loss: 3.6421101093292236 | KNN Loss: 3.586568593978882 | CLS Loss: 0.05554141104221344
Epoch 31 / 200 | iteration 130 / 171 | Total Loss: 3.7111170291900635 | KNN Loss: 3.667830467224121 | CLS Loss: 0.0432865284383297
Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 3.691829204559326 | KNN Loss: 3.659367084503174 | CLS Loss: 0.032462190836668015
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 3.6361770629882812 | KNN Loss: 3.6278343200683594 | CLS Loss: 0.008342758752405643
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 3.6458358764648438 | KNN Loss: 3.619546890258789 | CLS Loss: 0.02628910169005394
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 3.6320698261260986 | KNN Loss: 3.6098809242248535 | CLS Loss: 0.022188959643244743
Epoch: 031, Loss: 3.6718, Train: 0.9904, Valid: 0.9854, Best: 0.9854
Epoc

Epoch: 034, Loss: 3.6655, Train: 0.9904, Valid: 0.9851, Best: 0.9854
Epoch 35 / 200 | iteration 0 / 171 | Total Loss: 3.659313678741455 | KNN Loss: 3.646801471710205 | CLS Loss: 0.012512186542153358
Epoch 35 / 200 | iteration 10 / 171 | Total Loss: 3.658104658126831 | KNN Loss: 3.633042573928833 | CLS Loss: 0.02506212517619133
Epoch 35 / 200 | iteration 20 / 171 | Total Loss: 3.6805267333984375 | KNN Loss: 3.6537559032440186 | CLS Loss: 0.02677087113261223
Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 3.659950017929077 | KNN Loss: 3.6438732147216797 | CLS Loss: 0.016076767817139626
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 3.638293504714966 | KNN Loss: 3.612647771835327 | CLS Loss: 0.025645634159445763
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 3.749516725540161 | KNN Loss: 3.6963722705841064 | CLS Loss: 0.053144536912441254
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 3.63759446144104 | KNN Loss: 3.623734474182129 | CLS Loss: 0.013860026374459267
Epoch 35 / 200

Epoch 38 / 200 | iteration 70 / 171 | Total Loss: 3.7646243572235107 | KNN Loss: 3.6833763122558594 | CLS Loss: 0.08124809712171555
Epoch 38 / 200 | iteration 80 / 171 | Total Loss: 3.6657559871673584 | KNN Loss: 3.6157639026641846 | CLS Loss: 0.04999213665723801
Epoch 38 / 200 | iteration 90 / 171 | Total Loss: 3.6492063999176025 | KNN Loss: 3.6175856590270996 | CLS Loss: 0.03162074089050293
Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 3.6513671875 | KNN Loss: 3.631314754486084 | CLS Loss: 0.020052360370755196
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 3.6682796478271484 | KNN Loss: 3.629516839981079 | CLS Loss: 0.03876269981265068
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 3.6607375144958496 | KNN Loss: 3.62886643409729 | CLS Loss: 0.03187102451920509
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 3.722790002822876 | KNN Loss: 3.6868550777435303 | CLS Loss: 0.03593502938747406
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 3.6698596477508545 | KNN Loss: 

Epoch 41 / 200 | iteration 140 / 171 | Total Loss: 3.652107000350952 | KNN Loss: 3.636575222015381 | CLS Loss: 0.015531751327216625
Epoch 41 / 200 | iteration 150 / 171 | Total Loss: 3.7022857666015625 | KNN Loss: 3.6791675090789795 | CLS Loss: 0.02311818115413189
Epoch 41 / 200 | iteration 160 / 171 | Total Loss: 3.66413950920105 | KNN Loss: 3.6246023178100586 | CLS Loss: 0.039537109434604645
Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 3.6606552600860596 | KNN Loss: 3.6410980224609375 | CLS Loss: 0.019557276740670204
Epoch: 041, Loss: 3.6582, Train: 0.9932, Valid: 0.9861, Best: 0.9861
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 3.6573989391326904 | KNN Loss: 3.621546983718872 | CLS Loss: 0.035851992666721344
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 3.63030743598938 | KNN Loss: 3.61606764793396 | CLS Loss: 0.014239874668419361
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 3.651611328125 | KNN Loss: 3.636916160583496 | CLS Loss: 0.014695207588374615
Epoch 42 / 2

Epoch 45 / 200 | iteration 30 / 171 | Total Loss: 3.6657557487487793 | KNN Loss: 3.6391284465789795 | CLS Loss: 0.026627320796251297
Epoch 45 / 200 | iteration 40 / 171 | Total Loss: 3.6343002319335938 | KNN Loss: 3.6166653633117676 | CLS Loss: 0.0176348015666008
Epoch 45 / 200 | iteration 50 / 171 | Total Loss: 3.6442413330078125 | KNN Loss: 3.594874382019043 | CLS Loss: 0.0493670292198658
Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 3.6451804637908936 | KNN Loss: 3.6213901042938232 | CLS Loss: 0.023790359497070312
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 3.627195358276367 | KNN Loss: 3.606367349624634 | CLS Loss: 0.02082797884941101
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 3.6710758209228516 | KNN Loss: 3.619884490966797 | CLS Loss: 0.05119134858250618
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 3.6580140590667725 | KNN Loss: 3.6266937255859375 | CLS Loss: 0.03132032975554466
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 3.6403822898864746 | KNN Loss

Epoch 48 / 200 | iteration 100 / 171 | Total Loss: 3.6719095706939697 | KNN Loss: 3.645430326461792 | CLS Loss: 0.026479190215468407
Epoch 48 / 200 | iteration 110 / 171 | Total Loss: 3.633803367614746 | KNN Loss: 3.6126937866210938 | CLS Loss: 0.021109672263264656
Epoch 48 / 200 | iteration 120 / 171 | Total Loss: 3.692962408065796 | KNN Loss: 3.637575149536133 | CLS Loss: 0.055387210100889206
Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 3.662323236465454 | KNN Loss: 3.6105175018310547 | CLS Loss: 0.05180574953556061
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 3.6380743980407715 | KNN Loss: 3.612069606781006 | CLS Loss: 0.0260047297924757
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 3.662916898727417 | KNN Loss: 3.6141738891601562 | CLS Loss: 0.048742905259132385
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 3.6160552501678467 | KNN Loss: 3.600604295730591 | CLS Loss: 0.015450932085514069
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 3.647099733352661 | KN

Epoch 51 / 200 | iteration 170 / 171 | Total Loss: 3.66347599029541 | KNN Loss: 3.6424946784973145 | CLS Loss: 0.020981252193450928
Epoch: 051, Loss: 3.6474, Train: 0.9923, Valid: 0.9843, Best: 0.9863
Epoch 52 / 200 | iteration 0 / 171 | Total Loss: 3.658198595046997 | KNN Loss: 3.6435351371765137 | CLS Loss: 0.014663368463516235
Epoch 52 / 200 | iteration 10 / 171 | Total Loss: 3.694763422012329 | KNN Loss: 3.6649184226989746 | CLS Loss: 0.02984493039548397
Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 3.665208101272583 | KNN Loss: 3.6309077739715576 | CLS Loss: 0.034300416707992554
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 3.6332333087921143 | KNN Loss: 3.611560344696045 | CLS Loss: 0.02167295292019844
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 3.6391637325286865 | KNN Loss: 3.6167502403259277 | CLS Loss: 0.022413594648241997
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 3.6723878383636475 | KNN Loss: 3.6497390270233154 | CLS Loss: 0.0226488895714283
Epoch 52 /

Epoch 55 / 200 | iteration 60 / 171 | Total Loss: 3.611485719680786 | KNN Loss: 3.6060333251953125 | CLS Loss: 0.005452336277812719
Epoch 55 / 200 | iteration 70 / 171 | Total Loss: 3.6485650539398193 | KNN Loss: 3.624685764312744 | CLS Loss: 0.023879220709204674
Epoch 55 / 200 | iteration 80 / 171 | Total Loss: 3.6490020751953125 | KNN Loss: 3.63244891166687 | CLS Loss: 0.016553224995732307
Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 3.6282870769500732 | KNN Loss: 3.611921548843384 | CLS Loss: 0.016365550458431244
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 3.619569778442383 | KNN Loss: 3.6002280712127686 | CLS Loss: 0.01934160105884075
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 3.643597364425659 | KNN Loss: 3.6222341060638428 | CLS Loss: 0.021363308653235435
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 3.649646043777466 | KNN Loss: 3.6280531883239746 | CLS Loss: 0.021592896431684494
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 3.669174909591675 | KNN 

Epoch 58 / 200 | iteration 130 / 171 | Total Loss: 3.707387685775757 | KNN Loss: 3.6706461906433105 | CLS Loss: 0.036741551011800766
Epoch 58 / 200 | iteration 140 / 171 | Total Loss: 3.6527771949768066 | KNN Loss: 3.588421106338501 | CLS Loss: 0.06435597687959671
Epoch 58 / 200 | iteration 150 / 171 | Total Loss: 3.6639392375946045 | KNN Loss: 3.637831211090088 | CLS Loss: 0.02610807865858078
Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 3.6500024795532227 | KNN Loss: 3.6244359016418457 | CLS Loss: 0.025566695258021355
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 3.672175168991089 | KNN Loss: 3.649306535720825 | CLS Loss: 0.02286870777606964
Epoch: 058, Loss: 3.6435, Train: 0.9943, Valid: 0.9857, Best: 0.9865
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 3.6151020526885986 | KNN Loss: 3.5994064807891846 | CLS Loss: 0.015695666894316673
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 3.668445587158203 | KNN Loss: 3.6377317905426025 | CLS Loss: 0.030713798478245735
Epoch

Epoch 62 / 200 | iteration 10 / 171 | Total Loss: 3.624119281768799 | KNN Loss: 3.6121408939361572 | CLS Loss: 0.011978409253060818
Epoch 62 / 200 | iteration 20 / 171 | Total Loss: 3.6425552368164062 | KNN Loss: 3.611420154571533 | CLS Loss: 0.031134996563196182
Epoch 62 / 200 | iteration 30 / 171 | Total Loss: 3.6632542610168457 | KNN Loss: 3.644176721572876 | CLS Loss: 0.019077586010098457
Epoch 62 / 200 | iteration 40 / 171 | Total Loss: 3.673609972000122 | KNN Loss: 3.6403732299804688 | CLS Loss: 0.03323666751384735
Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 3.6287682056427 | KNN Loss: 3.606776475906372 | CLS Loss: 0.021991783753037453
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 3.658700704574585 | KNN Loss: 3.631220817565918 | CLS Loss: 0.02747991681098938
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 3.635669231414795 | KNN Loss: 3.6194655895233154 | CLS Loss: 0.016203636303544044
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 3.6306393146514893 | KNN Loss: 3

Epoch 65 / 200 | iteration 80 / 171 | Total Loss: 3.6313467025756836 | KNN Loss: 3.5977885723114014 | CLS Loss: 0.03355815261602402
Epoch 65 / 200 | iteration 90 / 171 | Total Loss: 3.6197214126586914 | KNN Loss: 3.611952781677246 | CLS Loss: 0.007768668700009584
Epoch 65 / 200 | iteration 100 / 171 | Total Loss: 3.6323742866516113 | KNN Loss: 3.6215708255767822 | CLS Loss: 0.010803543031215668
Epoch 65 / 200 | iteration 110 / 171 | Total Loss: 3.6546242237091064 | KNN Loss: 3.6324174404144287 | CLS Loss: 0.02220667526125908
Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 3.61924147605896 | KNN Loss: 3.607754945755005 | CLS Loss: 0.011486422270536423
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 3.620983362197876 | KNN Loss: 3.6117069721221924 | CLS Loss: 0.009276455268263817
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 3.6439037322998047 | KNN Loss: 3.6124794483184814 | CLS Loss: 0.031424324959516525
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 3.6806492805480957 |

Epoch 68 / 200 | iteration 150 / 171 | Total Loss: 3.660557508468628 | KNN Loss: 3.632819652557373 | CLS Loss: 0.027737827971577644
Epoch 68 / 200 | iteration 160 / 171 | Total Loss: 3.619223117828369 | KNN Loss: 3.608908176422119 | CLS Loss: 0.01031486876308918
Epoch 68 / 200 | iteration 170 / 171 | Total Loss: 3.6122071743011475 | KNN Loss: 3.5982677936553955 | CLS Loss: 0.013939343392848969
Epoch: 068, Loss: 3.6413, Train: 0.9947, Valid: 0.9851, Best: 0.9867
Epoch 69 / 200 | iteration 0 / 171 | Total Loss: 3.6390368938446045 | KNN Loss: 3.6132309436798096 | CLS Loss: 0.025805959478020668
Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 3.6611318588256836 | KNN Loss: 3.648470163345337 | CLS Loss: 0.012661672197282314
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 3.6345252990722656 | KNN Loss: 3.6111583709716797 | CLS Loss: 0.023366957902908325
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 3.6910717487335205 | KNN Loss: 3.6603481769561768 | CLS Loss: 0.030723586678504944
Epoc

Epoch 72 / 200 | iteration 30 / 171 | Total Loss: 3.6367435455322266 | KNN Loss: 3.616748332977295 | CLS Loss: 0.019995326176285744
Epoch 72 / 200 | iteration 40 / 171 | Total Loss: 3.6451759338378906 | KNN Loss: 3.610538959503174 | CLS Loss: 0.03463687375187874
Epoch 72 / 200 | iteration 50 / 171 | Total Loss: 3.6310174465179443 | KNN Loss: 3.6221866607666016 | CLS Loss: 0.008830725215375423
Epoch 72 / 200 | iteration 60 / 171 | Total Loss: 3.6177468299865723 | KNN Loss: 3.6047182083129883 | CLS Loss: 0.013028637506067753
Epoch 72 / 200 | iteration 70 / 171 | Total Loss: 3.609560251235962 | KNN Loss: 3.5998153686523438 | CLS Loss: 0.009744923561811447
Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 3.6321043968200684 | KNN Loss: 3.618283748626709 | CLS Loss: 0.013820596970617771
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 3.6085708141326904 | KNN Loss: 3.5882840156555176 | CLS Loss: 0.020286763086915016
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 3.6050188541412354 | KN

Epoch 75 / 200 | iteration 100 / 171 | Total Loss: 3.625706911087036 | KNN Loss: 3.6131656169891357 | CLS Loss: 0.012541298754513264
Epoch 75 / 200 | iteration 110 / 171 | Total Loss: 3.6068897247314453 | KNN Loss: 3.5812323093414307 | CLS Loss: 0.02565743215382099
Epoch 75 / 200 | iteration 120 / 171 | Total Loss: 3.6437222957611084 | KNN Loss: 3.6286258697509766 | CLS Loss: 0.015096385963261127
Epoch 75 / 200 | iteration 130 / 171 | Total Loss: 3.607807159423828 | KNN Loss: 3.6026077270507812 | CLS Loss: 0.005199415609240532
Epoch 75 / 200 | iteration 140 / 171 | Total Loss: 3.632398843765259 | KNN Loss: 3.621151924133301 | CLS Loss: 0.011246957816183567
Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 3.621037006378174 | KNN Loss: 3.6117959022521973 | CLS Loss: 0.009241144172847271
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 3.638749837875366 | KNN Loss: 3.630941390991211 | CLS Loss: 0.007808331400156021
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 3.645456552505493 |

Epoch 78 / 200 | iteration 170 / 171 | Total Loss: 3.6490485668182373 | KNN Loss: 3.622202157974243 | CLS Loss: 0.02684645727276802
Epoch: 078, Loss: 3.6330, Train: 0.9942, Valid: 0.9841, Best: 0.9872
Epoch 79 / 200 | iteration 0 / 171 | Total Loss: 3.646577835083008 | KNN Loss: 3.6311686038970947 | CLS Loss: 0.015409338288009167
Epoch 79 / 200 | iteration 10 / 171 | Total Loss: 3.5857865810394287 | KNN Loss: 3.582533597946167 | CLS Loss: 0.003253028029575944
Epoch 79 / 200 | iteration 20 / 171 | Total Loss: 3.603862762451172 | KNN Loss: 3.592116594314575 | CLS Loss: 0.011746094562113285
Epoch 79 / 200 | iteration 30 / 171 | Total Loss: 3.6003713607788086 | KNN Loss: 3.5903584957122803 | CLS Loss: 0.010012773796916008
Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 3.6088056564331055 | KNN Loss: 3.5953662395477295 | CLS Loss: 0.013439413160085678
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 3.665938377380371 | KNN Loss: 3.655069589614868 | CLS Loss: 0.010868889279663563
Epoch 79

Epoch 82 / 200 | iteration 50 / 171 | Total Loss: 3.611201524734497 | KNN Loss: 3.5927698612213135 | CLS Loss: 0.018431762233376503
Epoch 82 / 200 | iteration 60 / 171 | Total Loss: 3.686183452606201 | KNN Loss: 3.645203113555908 | CLS Loss: 0.04098033159971237
Epoch 82 / 200 | iteration 70 / 171 | Total Loss: 3.620623826980591 | KNN Loss: 3.6062984466552734 | CLS Loss: 0.01432543434202671
Epoch 82 / 200 | iteration 80 / 171 | Total Loss: 3.655012607574463 | KNN Loss: 3.6447653770446777 | CLS Loss: 0.010247151367366314
Epoch 82 / 200 | iteration 90 / 171 | Total Loss: 3.62788724899292 | KNN Loss: 3.594909191131592 | CLS Loss: 0.032978083938360214
Epoch 82 / 200 | iteration 100 / 171 | Total Loss: 3.670849323272705 | KNN Loss: 3.6434481143951416 | CLS Loss: 0.027401186525821686
Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 3.6312432289123535 | KNN Loss: 3.620168685913086 | CLS Loss: 0.011074580252170563
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 3.645087957382202 | KNN Loss

Epoch 85 / 200 | iteration 120 / 171 | Total Loss: 3.6125588417053223 | KNN Loss: 3.594191312789917 | CLS Loss: 0.018367473036050797
Epoch 85 / 200 | iteration 130 / 171 | Total Loss: 3.6592376232147217 | KNN Loss: 3.6227962970733643 | CLS Loss: 0.0364413745701313
Epoch 85 / 200 | iteration 140 / 171 | Total Loss: 3.6025285720825195 | KNN Loss: 3.599350929260254 | CLS Loss: 0.00317756412550807
Epoch 85 / 200 | iteration 150 / 171 | Total Loss: 3.6251296997070312 | KNN Loss: 3.6142663955688477 | CLS Loss: 0.010863285511732101
Epoch 85 / 200 | iteration 160 / 171 | Total Loss: 3.64123797416687 | KNN Loss: 3.6189916133880615 | CLS Loss: 0.022246433421969414
Epoch 85 / 200 | iteration 170 / 171 | Total Loss: 3.6929051876068115 | KNN Loss: 3.6594653129577637 | CLS Loss: 0.033439815044403076
Epoch: 085, Loss: 3.6288, Train: 0.9959, Valid: 0.9863, Best: 0.9872
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 3.6058900356292725 | KNN Loss: 3.5889649391174316 | CLS Loss: 0.01692505180835724
Epo

Epoch 89 / 200 | iteration 0 / 171 | Total Loss: 3.6149990558624268 | KNN Loss: 3.60939884185791 | CLS Loss: 0.005600270349532366
Epoch 89 / 200 | iteration 10 / 171 | Total Loss: 3.6503407955169678 | KNN Loss: 3.645698070526123 | CLS Loss: 0.004642823711037636
Epoch 89 / 200 | iteration 20 / 171 | Total Loss: 3.654783248901367 | KNN Loss: 3.623379945755005 | CLS Loss: 0.031403351575136185
Epoch 89 / 200 | iteration 30 / 171 | Total Loss: 3.6293904781341553 | KNN Loss: 3.6184256076812744 | CLS Loss: 0.010964852757751942
Epoch 89 / 200 | iteration 40 / 171 | Total Loss: 3.6030709743499756 | KNN Loss: 3.5788307189941406 | CLS Loss: 0.024240346625447273
Epoch 89 / 200 | iteration 50 / 171 | Total Loss: 3.6082193851470947 | KNN Loss: 3.603856325149536 | CLS Loss: 0.004363050684332848
Epoch 89 / 200 | iteration 60 / 171 | Total Loss: 3.6563720703125 | KNN Loss: 3.6345913410186768 | CLS Loss: 0.021780764684081078
Epoch 89 / 200 | iteration 70 / 171 | Total Loss: 3.5934576988220215 | KNN Loss

Epoch 92 / 200 | iteration 70 / 171 | Total Loss: 3.6233327388763428 | KNN Loss: 3.604680061340332 | CLS Loss: 0.01865277625620365
Epoch 92 / 200 | iteration 80 / 171 | Total Loss: 3.641547203063965 | KNN Loss: 3.6300694942474365 | CLS Loss: 0.011477712541818619
Epoch 92 / 200 | iteration 90 / 171 | Total Loss: 3.5979886054992676 | KNN Loss: 3.580007791519165 | CLS Loss: 0.017980894073843956
Epoch 92 / 200 | iteration 100 / 171 | Total Loss: 3.6289801597595215 | KNN Loss: 3.603811740875244 | CLS Loss: 0.025168459862470627
Epoch 92 / 200 | iteration 110 / 171 | Total Loss: 3.587360382080078 | KNN Loss: 3.5826799869537354 | CLS Loss: 0.004680283833295107
Epoch 92 / 200 | iteration 120 / 171 | Total Loss: 3.624922037124634 | KNN Loss: 3.586778402328491 | CLS Loss: 0.03814353793859482
Epoch 92 / 200 | iteration 130 / 171 | Total Loss: 3.5966484546661377 | KNN Loss: 3.5827393531799316 | CLS Loss: 0.013909182511270046
Epoch 92 / 200 | iteration 140 / 171 | Total Loss: 3.6142418384552 | KNN L

Epoch 95 / 200 | iteration 140 / 171 | Total Loss: 3.6254518032073975 | KNN Loss: 3.6028213500976562 | CLS Loss: 0.022630352526903152
Epoch 95 / 200 | iteration 150 / 171 | Total Loss: 3.6673552989959717 | KNN Loss: 3.6468217372894287 | CLS Loss: 0.02053362876176834
Epoch 95 / 200 | iteration 160 / 171 | Total Loss: 3.6056509017944336 | KNN Loss: 3.5973331928253174 | CLS Loss: 0.008317629806697369
Epoch 95 / 200 | iteration 170 / 171 | Total Loss: 3.621460437774658 | KNN Loss: 3.615311622619629 | CLS Loss: 0.006148811895400286
Epoch: 095, Loss: 3.6287, Train: 0.9967, Valid: 0.9868, Best: 0.9872
Epoch 96 / 200 | iteration 0 / 171 | Total Loss: 3.626277446746826 | KNN Loss: 3.6060903072357178 | CLS Loss: 0.02018715813755989
Epoch 96 / 200 | iteration 10 / 171 | Total Loss: 3.5971829891204834 | KNN Loss: 3.5822949409484863 | CLS Loss: 0.014887976460158825
Epoch 96 / 200 | iteration 20 / 171 | Total Loss: 3.6301753520965576 | KNN Loss: 3.6139516830444336 | CLS Loss: 0.016223646700382233
Ep

Epoch 99 / 200 | iteration 20 / 171 | Total Loss: 3.643935203552246 | KNN Loss: 3.6282942295074463 | CLS Loss: 0.01564090885221958
Epoch 99 / 200 | iteration 30 / 171 | Total Loss: 3.628748655319214 | KNN Loss: 3.6232962608337402 | CLS Loss: 0.005452287849038839
Epoch 99 / 200 | iteration 40 / 171 | Total Loss: 3.6101179122924805 | KNN Loss: 3.5979905128479004 | CLS Loss: 0.012127382680773735
Epoch 99 / 200 | iteration 50 / 171 | Total Loss: 3.614577531814575 | KNN Loss: 3.6094369888305664 | CLS Loss: 0.005140623543411493
Epoch 99 / 200 | iteration 60 / 171 | Total Loss: 3.6312336921691895 | KNN Loss: 3.613185167312622 | CLS Loss: 0.01804843731224537
Epoch 99 / 200 | iteration 70 / 171 | Total Loss: 3.6148841381073 | KNN Loss: 3.603874921798706 | CLS Loss: 0.011009251698851585
Epoch 99 / 200 | iteration 80 / 171 | Total Loss: 3.6315157413482666 | KNN Loss: 3.624821186065674 | CLS Loss: 0.006694588344544172
Epoch 99 / 200 | iteration 90 / 171 | Total Loss: 3.6649937629699707 | KNN Loss:

Epoch 102 / 200 | iteration 90 / 171 | Total Loss: 3.616023063659668 | KNN Loss: 3.5905110836029053 | CLS Loss: 0.025511950254440308
Epoch 102 / 200 | iteration 100 / 171 | Total Loss: 3.5911567211151123 | KNN Loss: 3.586636781692505 | CLS Loss: 0.004519968293607235
Epoch 102 / 200 | iteration 110 / 171 | Total Loss: 3.6030962467193604 | KNN Loss: 3.592863082885742 | CLS Loss: 0.010233080945909023
Epoch 102 / 200 | iteration 120 / 171 | Total Loss: 3.6229147911071777 | KNN Loss: 3.6107606887817383 | CLS Loss: 0.012154157273471355
Epoch 102 / 200 | iteration 130 / 171 | Total Loss: 3.604405403137207 | KNN Loss: 3.594231367111206 | CLS Loss: 0.010174010880291462
Epoch 102 / 200 | iteration 140 / 171 | Total Loss: 3.5934271812438965 | KNN Loss: 3.5919947624206543 | CLS Loss: 0.0014323517680168152
Epoch 102 / 200 | iteration 150 / 171 | Total Loss: 3.640777587890625 | KNN Loss: 3.6150975227355957 | CLS Loss: 0.025680137798190117
Epoch 102 / 200 | iteration 160 / 171 | Total Loss: 3.6158695

Epoch 105 / 200 | iteration 150 / 171 | Total Loss: 3.637131929397583 | KNN Loss: 3.627732276916504 | CLS Loss: 0.009399675764143467
Epoch 105 / 200 | iteration 160 / 171 | Total Loss: 3.693878412246704 | KNN Loss: 3.6868221759796143 | CLS Loss: 0.0070561314933001995
Epoch 105 / 200 | iteration 170 / 171 | Total Loss: 3.666853904724121 | KNN Loss: 3.638063907623291 | CLS Loss: 0.028789926320314407
Epoch: 105, Loss: 3.6214, Train: 0.9960, Valid: 0.9868, Best: 0.9872
Epoch 106 / 200 | iteration 0 / 171 | Total Loss: 3.6201179027557373 | KNN Loss: 3.6119039058685303 | CLS Loss: 0.008214105851948261
Epoch 106 / 200 | iteration 10 / 171 | Total Loss: 3.6124489307403564 | KNN Loss: 3.59419846534729 | CLS Loss: 0.018250498920679092
Epoch 106 / 200 | iteration 20 / 171 | Total Loss: 3.623868465423584 | KNN Loss: 3.6168158054351807 | CLS Loss: 0.007052757311612368
Epoch 106 / 200 | iteration 30 / 171 | Total Loss: 3.6058387756347656 | KNN Loss: 3.601623058319092 | CLS Loss: 0.004215700086206198

Epoch 109 / 200 | iteration 30 / 171 | Total Loss: 3.6160988807678223 | KNN Loss: 3.6098179817199707 | CLS Loss: 0.006280964706093073
Epoch 109 / 200 | iteration 40 / 171 | Total Loss: 3.5975704193115234 | KNN Loss: 3.5949535369873047 | CLS Loss: 0.002616843208670616
Epoch 109 / 200 | iteration 50 / 171 | Total Loss: 3.6131083965301514 | KNN Loss: 3.61004376411438 | CLS Loss: 0.0030646156519651413
Epoch 109 / 200 | iteration 60 / 171 | Total Loss: 3.576046943664551 | KNN Loss: 3.5703213214874268 | CLS Loss: 0.005725712515413761
Epoch 109 / 200 | iteration 70 / 171 | Total Loss: 3.6134490966796875 | KNN Loss: 3.6017003059387207 | CLS Loss: 0.011748863384127617
Epoch 109 / 200 | iteration 80 / 171 | Total Loss: 3.640180826187134 | KNN Loss: 3.6360015869140625 | CLS Loss: 0.004179274197667837
Epoch 109 / 200 | iteration 90 / 171 | Total Loss: 3.6261119842529297 | KNN Loss: 3.6187703609466553 | CLS Loss: 0.007341707590967417
Epoch 109 / 200 | iteration 100 / 171 | Total Loss: 3.69681167602

Epoch 112 / 200 | iteration 90 / 171 | Total Loss: 3.6001713275909424 | KNN Loss: 3.5960514545440674 | CLS Loss: 0.004119829274713993
Epoch 112 / 200 | iteration 100 / 171 | Total Loss: 3.690674066543579 | KNN Loss: 3.6427431106567383 | CLS Loss: 0.04793107137084007
Epoch 112 / 200 | iteration 110 / 171 | Total Loss: 3.6179661750793457 | KNN Loss: 3.613675832748413 | CLS Loss: 0.004290330223739147
Epoch 112 / 200 | iteration 120 / 171 | Total Loss: 3.6199963092803955 | KNN Loss: 3.5862138271331787 | CLS Loss: 0.03378257900476456
Epoch 112 / 200 | iteration 130 / 171 | Total Loss: 3.6557350158691406 | KNN Loss: 3.653904914855957 | CLS Loss: 0.0018300311639904976
Epoch 112 / 200 | iteration 140 / 171 | Total Loss: 3.594385862350464 | KNN Loss: 3.5854032039642334 | CLS Loss: 0.008982697501778603
Epoch 112 / 200 | iteration 150 / 171 | Total Loss: 3.666994094848633 | KNN Loss: 3.6585443019866943 | CLS Loss: 0.008449867367744446
Epoch 112 / 200 | iteration 160 / 171 | Total Loss: 3.64307498

Epoch 115 / 200 | iteration 150 / 171 | Total Loss: 3.667391300201416 | KNN Loss: 3.659186363220215 | CLS Loss: 0.008204974234104156
Epoch 115 / 200 | iteration 160 / 171 | Total Loss: 3.6423676013946533 | KNN Loss: 3.6119027137756348 | CLS Loss: 0.030464958399534225
Epoch 115 / 200 | iteration 170 / 171 | Total Loss: 3.674896478652954 | KNN Loss: 3.6354258060455322 | CLS Loss: 0.039470698684453964
Epoch: 115, Loss: 3.6166, Train: 0.9967, Valid: 0.9859, Best: 0.9873
Epoch 116 / 200 | iteration 0 / 171 | Total Loss: 3.5990591049194336 | KNN Loss: 3.5933303833007812 | CLS Loss: 0.00572860985994339
Epoch 116 / 200 | iteration 10 / 171 | Total Loss: 3.6124744415283203 | KNN Loss: 3.585911512374878 | CLS Loss: 0.026562850922346115
Epoch 116 / 200 | iteration 20 / 171 | Total Loss: 3.6149215698242188 | KNN Loss: 3.6072356700897217 | CLS Loss: 0.007685820106416941
Epoch 116 / 200 | iteration 30 / 171 | Total Loss: 3.606109142303467 | KNN Loss: 3.587613344192505 | CLS Loss: 0.01849576272070408

Epoch 119 / 200 | iteration 30 / 171 | Total Loss: 3.5880138874053955 | KNN Loss: 3.586416482925415 | CLS Loss: 0.0015974962152540684
Epoch 119 / 200 | iteration 40 / 171 | Total Loss: 3.669332265853882 | KNN Loss: 3.6591172218322754 | CLS Loss: 0.01021499466150999
Epoch 119 / 200 | iteration 50 / 171 | Total Loss: 3.6062562465667725 | KNN Loss: 3.5904457569122314 | CLS Loss: 0.01581037975847721
Epoch 119 / 200 | iteration 60 / 171 | Total Loss: 3.6290996074676514 | KNN Loss: 3.606717348098755 | CLS Loss: 0.02238227240741253
Epoch 119 / 200 | iteration 70 / 171 | Total Loss: 3.6070291996002197 | KNN Loss: 3.5953526496887207 | CLS Loss: 0.011676636524498463
Epoch 119 / 200 | iteration 80 / 171 | Total Loss: 3.617767572402954 | KNN Loss: 3.5971200466156006 | CLS Loss: 0.020647641271352768
Epoch 119 / 200 | iteration 90 / 171 | Total Loss: 3.596949577331543 | KNN Loss: 3.5942912101745605 | CLS Loss: 0.0026583196595311165
Epoch 119 / 200 | iteration 100 / 171 | Total Loss: 3.62596678733825

Epoch 122 / 200 | iteration 90 / 171 | Total Loss: 3.6284444332122803 | KNN Loss: 3.6070897579193115 | CLS Loss: 0.021354753524065018
Epoch 122 / 200 | iteration 100 / 171 | Total Loss: 3.6197142601013184 | KNN Loss: 3.590127468109131 | CLS Loss: 0.02958681434392929
Epoch 122 / 200 | iteration 110 / 171 | Total Loss: 3.6331405639648438 | KNN Loss: 3.60144305229187 | CLS Loss: 0.03169739618897438
Epoch 122 / 200 | iteration 120 / 171 | Total Loss: 3.637389659881592 | KNN Loss: 3.6289987564086914 | CLS Loss: 0.008390932343900204
Epoch 122 / 200 | iteration 130 / 171 | Total Loss: 3.6075961589813232 | KNN Loss: 3.6012182235717773 | CLS Loss: 0.0063778311014175415
Epoch 122 / 200 | iteration 140 / 171 | Total Loss: 3.682307004928589 | KNN Loss: 3.6609835624694824 | CLS Loss: 0.02132345549762249
Epoch 122 / 200 | iteration 150 / 171 | Total Loss: 3.6095097064971924 | KNN Loss: 3.5997941493988037 | CLS Loss: 0.009715652093291283
Epoch 122 / 200 | iteration 160 / 171 | Total Loss: 3.629939556

Epoch 125 / 200 | iteration 150 / 171 | Total Loss: 3.6021640300750732 | KNN Loss: 3.589186191558838 | CLS Loss: 0.012977758422493935
Epoch 125 / 200 | iteration 160 / 171 | Total Loss: 3.6385340690612793 | KNN Loss: 3.633214235305786 | CLS Loss: 0.005319810938090086
Epoch 125 / 200 | iteration 170 / 171 | Total Loss: 3.6353070735931396 | KNN Loss: 3.5923147201538086 | CLS Loss: 0.042992401868104935
Epoch: 125, Loss: 3.6209, Train: 0.9962, Valid: 0.9865, Best: 0.9873
Epoch 126 / 200 | iteration 0 / 171 | Total Loss: 3.593187093734741 | KNN Loss: 3.5822174549102783 | CLS Loss: 0.010969581082463264
Epoch 126 / 200 | iteration 10 / 171 | Total Loss: 3.5845084190368652 | KNN Loss: 3.580240488052368 | CLS Loss: 0.004268005024641752
Epoch 126 / 200 | iteration 20 / 171 | Total Loss: 3.609532117843628 | KNN Loss: 3.6068239212036133 | CLS Loss: 0.0027080923318862915
Epoch 126 / 200 | iteration 30 / 171 | Total Loss: 3.6287171840667725 | KNN Loss: 3.6144731044769287 | CLS Loss: 0.01424402929842

Epoch 129 / 200 | iteration 30 / 171 | Total Loss: 3.601958990097046 | KNN Loss: 3.6009514331817627 | CLS Loss: 0.0010076501639559865
Epoch 129 / 200 | iteration 40 / 171 | Total Loss: 3.5804295539855957 | KNN Loss: 3.5635340213775635 | CLS Loss: 0.01689557544887066
Epoch 129 / 200 | iteration 50 / 171 | Total Loss: 3.595039129257202 | KNN Loss: 3.5880329608917236 | CLS Loss: 0.007006144616752863
Epoch 129 / 200 | iteration 60 / 171 | Total Loss: 3.587733507156372 | KNN Loss: 3.5782101154327393 | CLS Loss: 0.009523285552859306
Epoch 129 / 200 | iteration 70 / 171 | Total Loss: 3.6597046852111816 | KNN Loss: 3.6367604732513428 | CLS Loss: 0.022944152355194092
Epoch 129 / 200 | iteration 80 / 171 | Total Loss: 3.6338281631469727 | KNN Loss: 3.6102514266967773 | CLS Loss: 0.023576835170388222
Epoch 129 / 200 | iteration 90 / 171 | Total Loss: 3.6011688709259033 | KNN Loss: 3.593583345413208 | CLS Loss: 0.0075854165479540825
Epoch 129 / 200 | iteration 100 / 171 | Total Loss: 3.64389085769

Epoch 132 / 200 | iteration 90 / 171 | Total Loss: 3.620718002319336 | KNN Loss: 3.5988247394561768 | CLS Loss: 0.021893350407481194
Epoch 132 / 200 | iteration 100 / 171 | Total Loss: 3.601369857788086 | KNN Loss: 3.593755006790161 | CLS Loss: 0.007614814210683107
Epoch 132 / 200 | iteration 110 / 171 | Total Loss: 3.589838981628418 | KNN Loss: 3.5793421268463135 | CLS Loss: 0.010496850125491619
Epoch 132 / 200 | iteration 120 / 171 | Total Loss: 3.6110453605651855 | KNN Loss: 3.602983236312866 | CLS Loss: 0.008062202483415604
Epoch 132 / 200 | iteration 130 / 171 | Total Loss: 3.6219592094421387 | KNN Loss: 3.61103892326355 | CLS Loss: 0.010920331813395023
Epoch 132 / 200 | iteration 140 / 171 | Total Loss: 3.6495277881622314 | KNN Loss: 3.63167405128479 | CLS Loss: 0.017853671684861183
Epoch 132 / 200 | iteration 150 / 171 | Total Loss: 3.6902642250061035 | KNN Loss: 3.668630838394165 | CLS Loss: 0.02163342945277691
Epoch 132 / 200 | iteration 160 / 171 | Total Loss: 3.6240770816802

Epoch 135 / 200 | iteration 150 / 171 | Total Loss: 3.607530117034912 | KNN Loss: 3.600569009780884 | CLS Loss: 0.006961209233850241
Epoch 135 / 200 | iteration 160 / 171 | Total Loss: 3.5983974933624268 | KNN Loss: 3.58443546295166 | CLS Loss: 0.01396195963025093
Epoch 135 / 200 | iteration 170 / 171 | Total Loss: 3.6272025108337402 | KNN Loss: 3.624633312225342 | CLS Loss: 0.0025690984912216663
Epoch: 135, Loss: 3.6146, Train: 0.9981, Valid: 0.9863, Best: 0.9873
Epoch 136 / 200 | iteration 0 / 171 | Total Loss: 3.580533504486084 | KNN Loss: 3.5670857429504395 | CLS Loss: 0.013447693549096584
Epoch 136 / 200 | iteration 10 / 171 | Total Loss: 3.5956010818481445 | KNN Loss: 3.582900047302246 | CLS Loss: 0.01270113606005907
Epoch 136 / 200 | iteration 20 / 171 | Total Loss: 3.6120312213897705 | KNN Loss: 3.607238292694092 | CLS Loss: 0.0047928267158567905
Epoch 136 / 200 | iteration 30 / 171 | Total Loss: 3.6065456867218018 | KNN Loss: 3.6039133071899414 | CLS Loss: 0.002632281510159373

Epoch 139 / 200 | iteration 30 / 171 | Total Loss: 3.587672233581543 | KNN Loss: 3.5781185626983643 | CLS Loss: 0.009553562849760056
Epoch 139 / 200 | iteration 40 / 171 | Total Loss: 3.638401508331299 | KNN Loss: 3.616711139678955 | CLS Loss: 0.02169027552008629
Epoch 139 / 200 | iteration 50 / 171 | Total Loss: 3.6253812313079834 | KNN Loss: 3.594162940979004 | CLS Loss: 0.03121821954846382
Epoch 139 / 200 | iteration 60 / 171 | Total Loss: 3.5962934494018555 | KNN Loss: 3.592224597930908 | CLS Loss: 0.004068780690431595
Epoch 139 / 200 | iteration 70 / 171 | Total Loss: 3.6255927085876465 | KNN Loss: 3.6206037998199463 | CLS Loss: 0.004988865461200476
Epoch 139 / 200 | iteration 80 / 171 | Total Loss: 3.5999655723571777 | KNN Loss: 3.59908127784729 | CLS Loss: 0.0008843513205647469
Epoch 139 / 200 | iteration 90 / 171 | Total Loss: 3.6218366622924805 | KNN Loss: 3.59755802154541 | CLS Loss: 0.02427872270345688
Epoch 139 / 200 | iteration 100 / 171 | Total Loss: 3.6101150512695312 | 

Epoch 142 / 200 | iteration 90 / 171 | Total Loss: 3.6214141845703125 | KNN Loss: 3.6164357662200928 | CLS Loss: 0.004978329408913851
Epoch 142 / 200 | iteration 100 / 171 | Total Loss: 3.594512939453125 | KNN Loss: 3.5803816318511963 | CLS Loss: 0.014131244271993637
Epoch 142 / 200 | iteration 110 / 171 | Total Loss: 3.601670265197754 | KNN Loss: 3.598320245742798 | CLS Loss: 0.0033500606659799814
Epoch 142 / 200 | iteration 120 / 171 | Total Loss: 3.6443235874176025 | KNN Loss: 3.6393051147460938 | CLS Loss: 0.005018517374992371
Epoch 142 / 200 | iteration 130 / 171 | Total Loss: 3.60664439201355 | KNN Loss: 3.585989475250244 | CLS Loss: 0.020654991269111633
Epoch 142 / 200 | iteration 140 / 171 | Total Loss: 3.6561429500579834 | KNN Loss: 3.637166738510132 | CLS Loss: 0.018976159393787384
Epoch 142 / 200 | iteration 150 / 171 | Total Loss: 3.6517677307128906 | KNN Loss: 3.6375603675842285 | CLS Loss: 0.014207348227500916
Epoch 142 / 200 | iteration 160 / 171 | Total Loss: 3.62634158

Epoch 145 / 200 | iteration 150 / 171 | Total Loss: 3.612424850463867 | KNN Loss: 3.601109027862549 | CLS Loss: 0.01131591945886612
Epoch 145 / 200 | iteration 160 / 171 | Total Loss: 3.630927801132202 | KNN Loss: 3.610706090927124 | CLS Loss: 0.02022174745798111
Epoch 145 / 200 | iteration 170 / 171 | Total Loss: 3.618440628051758 | KNN Loss: 3.589381217956543 | CLS Loss: 0.029059482738375664
Epoch: 145, Loss: 3.6166, Train: 0.9975, Valid: 0.9863, Best: 0.9873
Epoch 146 / 200 | iteration 0 / 171 | Total Loss: 3.617413282394409 | KNN Loss: 3.6088168621063232 | CLS Loss: 0.008596319705247879
Epoch 146 / 200 | iteration 10 / 171 | Total Loss: 3.585026502609253 | KNN Loss: 3.571716070175171 | CLS Loss: 0.01331049483269453
Epoch 146 / 200 | iteration 20 / 171 | Total Loss: 3.5845677852630615 | KNN Loss: 3.582979202270508 | CLS Loss: 0.0015885697212070227
Epoch 146 / 200 | iteration 30 / 171 | Total Loss: 3.655428171157837 | KNN Loss: 3.640380859375 | CLS Loss: 0.015047217719256878
Epoch 14

Epoch 149 / 200 | iteration 30 / 171 | Total Loss: 3.6163346767425537 | KNN Loss: 3.613731622695923 | CLS Loss: 0.002603048225864768
Epoch 149 / 200 | iteration 40 / 171 | Total Loss: 3.600409984588623 | KNN Loss: 3.5897397994995117 | CLS Loss: 0.010670186951756477
Epoch 149 / 200 | iteration 50 / 171 | Total Loss: 3.607001543045044 | KNN Loss: 3.5936014652252197 | CLS Loss: 0.013400118798017502
Epoch 149 / 200 | iteration 60 / 171 | Total Loss: 3.615276336669922 | KNN Loss: 3.6073906421661377 | CLS Loss: 0.007885790430009365
Epoch 149 / 200 | iteration 70 / 171 | Total Loss: 3.599039077758789 | KNN Loss: 3.577286720275879 | CLS Loss: 0.02175244502723217
Epoch 149 / 200 | iteration 80 / 171 | Total Loss: 3.5807228088378906 | KNN Loss: 3.5710887908935547 | CLS Loss: 0.009634094312787056
Epoch 149 / 200 | iteration 90 / 171 | Total Loss: 3.603090286254883 | KNN Loss: 3.5983340740203857 | CLS Loss: 0.004756133072078228
Epoch 149 / 200 | iteration 100 / 171 | Total Loss: 3.59619402885437 |

Epoch 152 / 200 | iteration 90 / 171 | Total Loss: 3.574205160140991 | KNN Loss: 3.5696210861206055 | CLS Loss: 0.004583998117595911
Epoch 152 / 200 | iteration 100 / 171 | Total Loss: 3.602250337600708 | KNN Loss: 3.5985629558563232 | CLS Loss: 0.0036873258650302887
Epoch 152 / 200 | iteration 110 / 171 | Total Loss: 3.6003403663635254 | KNN Loss: 3.598818778991699 | CLS Loss: 0.001521618920378387
Epoch 152 / 200 | iteration 120 / 171 | Total Loss: 3.5991268157958984 | KNN Loss: 3.5798685550689697 | CLS Loss: 0.01925818808376789
Epoch 152 / 200 | iteration 130 / 171 | Total Loss: 3.6267588138580322 | KNN Loss: 3.6045923233032227 | CLS Loss: 0.02216646634042263
Epoch 152 / 200 | iteration 140 / 171 | Total Loss: 3.6806697845458984 | KNN Loss: 3.671369791030884 | CLS Loss: 0.009299997240304947
Epoch 152 / 200 | iteration 150 / 171 | Total Loss: 3.5981342792510986 | KNN Loss: 3.596355676651001 | CLS Loss: 0.0017787069082260132
Epoch 152 / 200 | iteration 160 / 171 | Total Loss: 3.6248607

Epoch 155 / 200 | iteration 150 / 171 | Total Loss: 3.652214765548706 | KNN Loss: 3.6428871154785156 | CLS Loss: 0.009327743202447891
Epoch 155 / 200 | iteration 160 / 171 | Total Loss: 3.6340417861938477 | KNN Loss: 3.600449800491333 | CLS Loss: 0.03359202668070793
Epoch 155 / 200 | iteration 170 / 171 | Total Loss: 3.643828868865967 | KNN Loss: 3.6334309577941895 | CLS Loss: 0.010397886857390404
Epoch: 155, Loss: 3.6230, Train: 0.9960, Valid: 0.9863, Best: 0.9873
Epoch 156 / 200 | iteration 0 / 171 | Total Loss: 3.603025436401367 | KNN Loss: 3.596945285797119 | CLS Loss: 0.006080143619328737
Epoch 156 / 200 | iteration 10 / 171 | Total Loss: 3.6273210048675537 | KNN Loss: 3.616830825805664 | CLS Loss: 0.010490186512470245
Epoch 156 / 200 | iteration 20 / 171 | Total Loss: 3.591986894607544 | KNN Loss: 3.5867273807525635 | CLS Loss: 0.005259543191641569
Epoch 156 / 200 | iteration 30 / 171 | Total Loss: 3.731299638748169 | KNN Loss: 3.728640079498291 | CLS Loss: 0.002659557620063424
E

Epoch 159 / 200 | iteration 30 / 171 | Total Loss: 3.5973193645477295 | KNN Loss: 3.591245412826538 | CLS Loss: 0.006073928903788328
Epoch 159 / 200 | iteration 40 / 171 | Total Loss: 3.6194677352905273 | KNN Loss: 3.610872983932495 | CLS Loss: 0.00859485100954771
Epoch 159 / 200 | iteration 50 / 171 | Total Loss: 3.6374118328094482 | KNN Loss: 3.6199445724487305 | CLS Loss: 0.01746717467904091
Epoch 159 / 200 | iteration 60 / 171 | Total Loss: 3.600090265274048 | KNN Loss: 3.5947842597961426 | CLS Loss: 0.0053061190992593765
Epoch 159 / 200 | iteration 70 / 171 | Total Loss: 3.6180572509765625 | KNN Loss: 3.6023590564727783 | CLS Loss: 0.01569819450378418
Epoch 159 / 200 | iteration 80 / 171 | Total Loss: 3.6313908100128174 | KNN Loss: 3.6193957328796387 | CLS Loss: 0.011995102278888226
Epoch 159 / 200 | iteration 90 / 171 | Total Loss: 3.5974953174591064 | KNN Loss: 3.5915188789367676 | CLS Loss: 0.0059763905592262745
Epoch 159 / 200 | iteration 100 / 171 | Total Loss: 3.642891883850

Epoch 162 / 200 | iteration 90 / 171 | Total Loss: 3.6324679851531982 | KNN Loss: 3.6206247806549072 | CLS Loss: 0.011843198910355568
Epoch 162 / 200 | iteration 100 / 171 | Total Loss: 3.652033567428589 | KNN Loss: 3.6354427337646484 | CLS Loss: 0.016590747982263565
Epoch 162 / 200 | iteration 110 / 171 | Total Loss: 3.591543436050415 | KNN Loss: 3.589052200317383 | CLS Loss: 0.0024913023225963116
Epoch 162 / 200 | iteration 120 / 171 | Total Loss: 3.6497058868408203 | KNN Loss: 3.6349241733551025 | CLS Loss: 0.01478164829313755
Epoch 162 / 200 | iteration 130 / 171 | Total Loss: 3.6558566093444824 | KNN Loss: 3.6342878341674805 | CLS Loss: 0.02156883478164673
Epoch 162 / 200 | iteration 140 / 171 | Total Loss: 3.6092941761016846 | KNN Loss: 3.5913054943084717 | CLS Loss: 0.017988568171858788
Epoch 162 / 200 | iteration 150 / 171 | Total Loss: 3.5948421955108643 | KNN Loss: 3.586367130279541 | CLS Loss: 0.008474988862872124
Epoch 162 / 200 | iteration 160 / 171 | Total Loss: 3.6540074

Epoch 165 / 200 | iteration 150 / 171 | Total Loss: 3.625072956085205 | KNN Loss: 3.586512804031372 | CLS Loss: 0.038560230284929276
Epoch 165 / 200 | iteration 160 / 171 | Total Loss: 3.6518633365631104 | KNN Loss: 3.634127616882324 | CLS Loss: 0.017735827714204788
Epoch 165 / 200 | iteration 170 / 171 | Total Loss: 3.6123790740966797 | KNN Loss: 3.607651472091675 | CLS Loss: 0.004727485589683056
Epoch: 165, Loss: 3.6103, Train: 0.9953, Valid: 0.9835, Best: 0.9873
Epoch 166 / 200 | iteration 0 / 171 | Total Loss: 3.658466339111328 | KNN Loss: 3.6255409717559814 | CLS Loss: 0.03292524814605713
Epoch 166 / 200 | iteration 10 / 171 | Total Loss: 3.614391326904297 | KNN Loss: 3.6020262241363525 | CLS Loss: 0.012365058064460754
Epoch 166 / 200 | iteration 20 / 171 | Total Loss: 3.6320271492004395 | KNN Loss: 3.6207103729248047 | CLS Loss: 0.011316860094666481
Epoch 166 / 200 | iteration 30 / 171 | Total Loss: 3.623945474624634 | KNN Loss: 3.601285934448242 | CLS Loss: 0.022659551352262497


Epoch 169 / 200 | iteration 30 / 171 | Total Loss: 3.64953351020813 | KNN Loss: 3.6454033851623535 | CLS Loss: 0.004130121320486069
Epoch 169 / 200 | iteration 40 / 171 | Total Loss: 3.6368253231048584 | KNN Loss: 3.6022422313690186 | CLS Loss: 0.03458316996693611
Epoch 169 / 200 | iteration 50 / 171 | Total Loss: 3.624819040298462 | KNN Loss: 3.6238741874694824 | CLS Loss: 0.0009449228527955711
Epoch 169 / 200 | iteration 60 / 171 | Total Loss: 3.5972673892974854 | KNN Loss: 3.5870914459228516 | CLS Loss: 0.010175846517086029
Epoch 169 / 200 | iteration 70 / 171 | Total Loss: 3.574266195297241 | KNN Loss: 3.566946268081665 | CLS Loss: 0.007320007774978876
Epoch 169 / 200 | iteration 80 / 171 | Total Loss: 3.6083102226257324 | KNN Loss: 3.6007401943206787 | CLS Loss: 0.007570106070488691
Epoch 169 / 200 | iteration 90 / 171 | Total Loss: 3.616387367248535 | KNN Loss: 3.6151654720306396 | CLS Loss: 0.0012219223426654935
Epoch 169 / 200 | iteration 100 / 171 | Total Loss: 3.5915424823760

Epoch 172 / 200 | iteration 90 / 171 | Total Loss: 3.643503427505493 | KNN Loss: 3.628448963165283 | CLS Loss: 0.015054397284984589
Epoch 172 / 200 | iteration 100 / 171 | Total Loss: 3.6177072525024414 | KNN Loss: 3.6021597385406494 | CLS Loss: 0.015547611750662327
Epoch 172 / 200 | iteration 110 / 171 | Total Loss: 3.5977001190185547 | KNN Loss: 3.5921237468719482 | CLS Loss: 0.005576361436396837
Epoch 172 / 200 | iteration 120 / 171 | Total Loss: 3.6594042778015137 | KNN Loss: 3.652432680130005 | CLS Loss: 0.006971603259444237
Epoch 172 / 200 | iteration 130 / 171 | Total Loss: 3.6650748252868652 | KNN Loss: 3.6298720836639404 | CLS Loss: 0.035202641040086746
Epoch 172 / 200 | iteration 140 / 171 | Total Loss: 3.6301209926605225 | KNN Loss: 3.6064422130584717 | CLS Loss: 0.023678747937083244
Epoch 172 / 200 | iteration 150 / 171 | Total Loss: 3.610697031021118 | KNN Loss: 3.609217882156372 | CLS Loss: 0.0014790728455409408
Epoch 172 / 200 | iteration 160 / 171 | Total Loss: 3.576448

Epoch 175 / 200 | iteration 150 / 171 | Total Loss: 3.592108964920044 | KNN Loss: 3.579643964767456 | CLS Loss: 0.012464931234717369
Epoch 175 / 200 | iteration 160 / 171 | Total Loss: 3.67590594291687 | KNN Loss: 3.662576198577881 | CLS Loss: 0.013329686596989632
Epoch 175 / 200 | iteration 170 / 171 | Total Loss: 3.6042051315307617 | KNN Loss: 3.6005187034606934 | CLS Loss: 0.0036863936111330986
Epoch: 175, Loss: 3.6178, Train: 0.9960, Valid: 0.9852, Best: 0.9873
Epoch 176 / 200 | iteration 0 / 171 | Total Loss: 3.6242218017578125 | KNN Loss: 3.6086461544036865 | CLS Loss: 0.015575655736029148
Epoch 176 / 200 | iteration 10 / 171 | Total Loss: 3.610130548477173 | KNN Loss: 3.6018142700195312 | CLS Loss: 0.00831635296344757
Epoch 176 / 200 | iteration 20 / 171 | Total Loss: 3.662177085876465 | KNN Loss: 3.6481142044067383 | CLS Loss: 0.014062825590372086
Epoch 176 / 200 | iteration 30 / 171 | Total Loss: 3.614053726196289 | KNN Loss: 3.5913026332855225 | CLS Loss: 0.022751133888959885

Epoch 179 / 200 | iteration 30 / 171 | Total Loss: 3.628148078918457 | KNN Loss: 3.5964274406433105 | CLS Loss: 0.031720541417598724
Epoch 179 / 200 | iteration 40 / 171 | Total Loss: 3.644183397293091 | KNN Loss: 3.635913133621216 | CLS Loss: 0.008270302787423134
Epoch 179 / 200 | iteration 50 / 171 | Total Loss: 3.6128337383270264 | KNN Loss: 3.5976450443267822 | CLS Loss: 0.015188588760793209
Epoch 179 / 200 | iteration 60 / 171 | Total Loss: 3.6255431175231934 | KNN Loss: 3.6192498207092285 | CLS Loss: 0.006293192505836487
Epoch 179 / 200 | iteration 70 / 171 | Total Loss: 3.595693588256836 | KNN Loss: 3.5863776206970215 | CLS Loss: 0.009315891191363335
Epoch 179 / 200 | iteration 80 / 171 | Total Loss: 3.657918691635132 | KNN Loss: 3.6547350883483887 | CLS Loss: 0.003183704800903797
Epoch 179 / 200 | iteration 90 / 171 | Total Loss: 3.6026127338409424 | KNN Loss: 3.5928945541381836 | CLS Loss: 0.009718182496726513
Epoch 179 / 200 | iteration 100 / 171 | Total Loss: 3.6397223472595

Epoch 182 / 200 | iteration 90 / 171 | Total Loss: 3.6130211353302 | KNN Loss: 3.600275993347168 | CLS Loss: 0.012745173647999763
Epoch 182 / 200 | iteration 100 / 171 | Total Loss: 3.6137053966522217 | KNN Loss: 3.6059470176696777 | CLS Loss: 0.007758395280689001
Epoch 182 / 200 | iteration 110 / 171 | Total Loss: 3.617621421813965 | KNN Loss: 3.5978212356567383 | CLS Loss: 0.019800256937742233
Epoch 182 / 200 | iteration 120 / 171 | Total Loss: 3.5986597537994385 | KNN Loss: 3.594216823577881 | CLS Loss: 0.004443038254976273
Epoch 182 / 200 | iteration 130 / 171 | Total Loss: 3.6056764125823975 | KNN Loss: 3.5979807376861572 | CLS Loss: 0.007695665583014488
Epoch 182 / 200 | iteration 140 / 171 | Total Loss: 3.592532157897949 | KNN Loss: 3.5888452529907227 | CLS Loss: 0.00368681363761425
Epoch 182 / 200 | iteration 150 / 171 | Total Loss: 3.566488265991211 | KNN Loss: 3.5656583309173584 | CLS Loss: 0.0008300240151584148
Epoch 182 / 200 | iteration 160 / 171 | Total Loss: 3.6308686733

Epoch 185 / 200 | iteration 150 / 171 | Total Loss: 3.5888113975524902 | KNN Loss: 3.582388401031494 | CLS Loss: 0.006422976031899452
Epoch 185 / 200 | iteration 160 / 171 | Total Loss: 3.605147361755371 | KNN Loss: 3.5967044830322266 | CLS Loss: 0.008442973718047142
Epoch 185 / 200 | iteration 170 / 171 | Total Loss: 3.585801124572754 | KNN Loss: 3.5842857360839844 | CLS Loss: 0.0015154159627854824
Epoch: 185, Loss: 3.6157, Train: 0.9977, Valid: 0.9864, Best: 0.9873
Epoch 186 / 200 | iteration 0 / 171 | Total Loss: 3.6137428283691406 | KNN Loss: 3.60479736328125 | CLS Loss: 0.008945409208536148
Epoch 186 / 200 | iteration 10 / 171 | Total Loss: 3.654035806655884 | KNN Loss: 3.6368348598480225 | CLS Loss: 0.017201025038957596
Epoch 186 / 200 | iteration 20 / 171 | Total Loss: 3.613801956176758 | KNN Loss: 3.6103832721710205 | CLS Loss: 0.003418781328946352
Epoch 186 / 200 | iteration 30 / 171 | Total Loss: 3.6219284534454346 | KNN Loss: 3.596198797225952 | CLS Loss: 0.02572975307703018

Epoch 189 / 200 | iteration 30 / 171 | Total Loss: 3.645723342895508 | KNN Loss: 3.6447136402130127 | CLS Loss: 0.001009657047688961
Epoch 189 / 200 | iteration 40 / 171 | Total Loss: 3.6180782318115234 | KNN Loss: 3.6140189170837402 | CLS Loss: 0.004059290513396263
Epoch 189 / 200 | iteration 50 / 171 | Total Loss: 3.618532180786133 | KNN Loss: 3.59781813621521 | CLS Loss: 0.020713994279503822
Epoch 189 / 200 | iteration 60 / 171 | Total Loss: 3.591970682144165 | KNN Loss: 3.588420867919922 | CLS Loss: 0.0035498125944286585
Epoch 189 / 200 | iteration 70 / 171 | Total Loss: 3.639238119125366 | KNN Loss: 3.6237573623657227 | CLS Loss: 0.015480652451515198
Epoch 189 / 200 | iteration 80 / 171 | Total Loss: 3.6415834426879883 | KNN Loss: 3.6260440349578857 | CLS Loss: 0.015539480373263359
Epoch 189 / 200 | iteration 90 / 171 | Total Loss: 3.607095718383789 | KNN Loss: 3.5986931324005127 | CLS Loss: 0.008402650244534016
Epoch 189 / 200 | iteration 100 / 171 | Total Loss: 3.620838642120361

Epoch 192 / 200 | iteration 90 / 171 | Total Loss: 3.6029415130615234 | KNN Loss: 3.5985329151153564 | CLS Loss: 0.004408539272844791
Epoch 192 / 200 | iteration 100 / 171 | Total Loss: 3.6114110946655273 | KNN Loss: 3.6111257076263428 | CLS Loss: 0.00028540604398585856
Epoch 192 / 200 | iteration 110 / 171 | Total Loss: 3.574134111404419 | KNN Loss: 3.5719404220581055 | CLS Loss: 0.0021937794517725706
Epoch 192 / 200 | iteration 120 / 171 | Total Loss: 3.6275594234466553 | KNN Loss: 3.5978901386260986 | CLS Loss: 0.02966918982565403
Epoch 192 / 200 | iteration 130 / 171 | Total Loss: 3.613825559616089 | KNN Loss: 3.602100372314453 | CLS Loss: 0.011725218035280704
Epoch 192 / 200 | iteration 140 / 171 | Total Loss: 3.6067306995391846 | KNN Loss: 3.604588270187378 | CLS Loss: 0.002142466139048338
Epoch 192 / 200 | iteration 150 / 171 | Total Loss: 3.6595616340637207 | KNN Loss: 3.6351382732391357 | CLS Loss: 0.02442331612110138
Epoch 192 / 200 | iteration 160 / 171 | Total Loss: 3.65298

Epoch 195 / 200 | iteration 150 / 171 | Total Loss: 3.6119351387023926 | KNN Loss: 3.6010773181915283 | CLS Loss: 0.010857892222702503
Epoch 195 / 200 | iteration 160 / 171 | Total Loss: 3.5734660625457764 | KNN Loss: 3.569589614868164 | CLS Loss: 0.0038764362689107656
Epoch 195 / 200 | iteration 170 / 171 | Total Loss: 3.6822028160095215 | KNN Loss: 3.662424325942993 | CLS Loss: 0.01977851800620556
Epoch: 195, Loss: 3.6062, Train: 0.9974, Valid: 0.9858, Best: 0.9873
Epoch 196 / 200 | iteration 0 / 171 | Total Loss: 3.5784912109375 | KNN Loss: 3.5717251300811768 | CLS Loss: 0.006766071543097496
Epoch 196 / 200 | iteration 10 / 171 | Total Loss: 3.6340038776397705 | KNN Loss: 3.611436367034912 | CLS Loss: 0.02256753109395504
Epoch 196 / 200 | iteration 20 / 171 | Total Loss: 3.608452558517456 | KNN Loss: 3.593545436859131 | CLS Loss: 0.01490722969174385
Epoch 196 / 200 | iteration 30 / 171 | Total Loss: 3.632634162902832 | KNN Loss: 3.615565299987793 | CLS Loss: 0.01706889271736145
Epoc

Epoch 199 / 200 | iteration 30 / 171 | Total Loss: 3.609245538711548 | KNN Loss: 3.6047208309173584 | CLS Loss: 0.0045247189700603485
Epoch 199 / 200 | iteration 40 / 171 | Total Loss: 3.595888614654541 | KNN Loss: 3.5880517959594727 | CLS Loss: 0.007836858741939068
Epoch 199 / 200 | iteration 50 / 171 | Total Loss: 3.6148111820220947 | KNN Loss: 3.6065025329589844 | CLS Loss: 0.008308730088174343
Epoch 199 / 200 | iteration 60 / 171 | Total Loss: 3.579241991043091 | KNN Loss: 3.5780508518218994 | CLS Loss: 0.0011911524925380945
Epoch 199 / 200 | iteration 70 / 171 | Total Loss: 3.600745439529419 | KNN Loss: 3.575310468673706 | CLS Loss: 0.02543490007519722
Epoch 199 / 200 | iteration 80 / 171 | Total Loss: 3.6279661655426025 | KNN Loss: 3.617826223373413 | CLS Loss: 0.010140030644834042
Epoch 199 / 200 | iteration 90 / 171 | Total Loss: 3.5906684398651123 | KNN Loss: 3.584307909011841 | CLS Loss: 0.006360555998980999
Epoch 199 / 200 | iteration 100 / 171 | Total Loss: 3.62377357482910

In [8]:
test(model, test_data_iter, device)

tensor(0.9854, device='cuda:0')

In [9]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [14]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.8574756749349047


In [15]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [16]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [17]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [18]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [19]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [20]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [21]:
losses = []
accs = []
sparsity = []

In [22]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 037 | Total loss: 2.232 | Reg loss: 0.012 | Tree loss: 2.232 | Accuracy: 0.064453 | 1.322 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 01 | Batch: 000 / 037 | Total loss: 2.209 | Reg loss: 0.005 | Tree loss: 2.209 | Accuracy: 0.230469 | 0.829 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 02 | Batch: 000 / 037 

Epoch: 20 | Batch: 000 / 037 | Total loss: 1.992 | Reg loss: 0.024 | Tree loss: 1.992 | Accuracy: 0.458984 | 0.819 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 21 | Batch: 000 / 037 | Total loss: 1.990 | Reg loss: 0.025 | Tree loss: 1.990 | Accuracy: 0.449219 | 0.819 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 22 | Batch: 000 / 037 | Total loss: 1.980 | Reg loss: 0.025 | Tree loss: 1.980 | Accuracy: 0.478516 | 0.818 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 40 | Batch: 000 / 037 | Total loss: 1.805 | Reg loss: 0.031 | Tree loss: 1.805 | Accuracy: 0.488281 | 0.816 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 41 | Batch: 000 / 037 | Total loss: 1.828 | Reg loss: 0.032 | Tree loss: 1.828 | Accuracy: 0.429688 | 0.816 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 42 | Batch: 000 / 037 | Total loss: 1.782 | Reg loss: 0.032 | Tree loss: 1.782 | Accuracy: 0.500000 | 0.815 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 60 | Batch: 000 / 037 | Total loss: 1.706 | Reg loss: 0.035 | Tree loss: 1.706 | Accuracy: 0.486328 | 0.804 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 61 | Batch: 000 / 037 | Total loss: 1.706 | Reg loss: 0.035 | Tree loss: 1.706 | Accuracy: 0.498047 | 0.804 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 62 | Batch: 000 / 037 | Total loss: 1.685 | Reg loss: 0.035 | Tree loss: 1.685 | Accuracy: 0.492188 | 0.803 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 80 | Batch: 000 / 037 | Total loss: 1.665 | Reg loss: 0.037 | Tree loss: 1.665 | Accuracy: 0.496094 | 0.795 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 81 | Batch: 000 / 037 | Total loss: 1.657 | Reg loss: 0.037 | Tree loss: 1.657 | Accuracy: 0.478516 | 0.794 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 82 | Batch: 000 / 037 | Total loss: 1.613 | Reg loss: 0.037 | Tree loss: 1.613 | Accuracy: 0.525391 | 0.794 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 100 | Batch: 000 / 037 | Total loss: 1.683 | Reg loss: 0.038 | Tree loss: 1.683 | Accuracy: 0.451172 | 0.789 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 101 | Batch: 000 / 037 | Total loss: 1.630 | Reg loss: 0.038 | Tree loss: 1.630 | Accuracy: 0.513672 | 0.788 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 102 | Batch: 000 / 037 | Total loss: 1.671 | Reg loss: 0.038 | Tree loss: 1.671 | Accuracy: 0.484375 | 0.788 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 120 | Batch: 000 / 037 | Total loss: 1.666 | Reg loss: 0.038 | Tree loss: 1.666 | Accuracy: 0.478516 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 121 | Batch: 000 / 037 | Total loss: 1.632 | Reg loss: 0.038 | Tree loss: 1.632 | Accuracy: 0.464844 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 122 | Batch: 000 / 037 | Total loss: 1.678 | Reg loss: 0.038 | Tree loss: 1.678 | Accuracy: 0.470703 | 0.784 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 140 | Batch: 000 / 037 | Total loss: 1.670 | Reg loss: 0.038 | Tree loss: 1.670 | Accuracy: 0.460938 | 0.782 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 141 | Batch: 000 / 037 | Total loss: 1.643 | Reg loss: 0.038 | Tree loss: 1.643 | Accuracy: 0.476562 | 0.782 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 142 | Batch: 000 / 037 | Total loss: 1.674 | Reg loss: 0.038 | Tree loss: 1.674 | Accuracy: 0.480469 | 0.781 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 160 | Batch: 000 / 037 | Total loss: 1.618 | Reg loss: 0.038 | Tree loss: 1.618 | Accuracy: 0.492188 | 0.782 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 161 | Batch: 000 / 037 | Total loss: 1.671 | Reg loss: 0.038 | Tree loss: 1.671 | Accuracy: 0.453125 | 0.782 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 162 | Batch: 000 / 037 | Total loss: 1.654 | Reg loss: 0.038 | Tree loss: 1.654 | Accuracy: 0.449219 | 0.782 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 180 | Batch: 000 / 037 | Total loss: 1.645 | Reg loss: 0.038 | Tree loss: 1.645 | Accuracy: 0.458984 | 0.783 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 181 | Batch: 000 / 037 | Total loss: 1.619 | Reg loss: 0.038 | Tree loss: 1.619 | Accuracy: 0.464844 | 0.783 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 182 | Batch: 000 / 037 | Total loss: 1.632 | Reg loss: 0.038 | Tree loss: 1.632 | Accuracy: 0.462891 | 0.783 sec/iter
Average sparseness: 0.9840425531914894
laye

layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 200 | Batch: 000 / 037 | Total loss: 1.618 | Reg loss: 0.039 | Tree loss: 1.618 | Accuracy: 0.431641 | 0.784 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 201 | Batch: 000 / 037 | Total loss: 1.591 | Reg loss: 0.039 | Tree loss: 1.591 | Accuracy: 0.496094 | 0.784 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 202 | Batch: 000 / 037 | Total loss: 1.603 | Reg loss: 0.039 | Tree loss: 1.6

Epoch: 220 | Batch: 000 / 037 | Total loss: 1.616 | Reg loss: 0.039 | Tree loss: 1.616 | Accuracy: 0.457031 | 0.784 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 221 | Batch: 000 / 037 | Total loss: 1.657 | Reg loss: 0.039 | Tree loss: 1.657 | Accuracy: 0.439453 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 222 | Batch: 000 / 037 | Total loss: 1.617 | Reg loss: 0.039 | Tree loss: 1.617 | Accuracy: 0.462891 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 240 | Batch: 000 / 037 | Total loss: 1.650 | Reg loss: 0.039 | Tree loss: 1.650 | Accuracy: 0.421875 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 241 | Batch: 000 / 037 | Total loss: 1.625 | Reg loss: 0.039 | Tree loss: 1.625 | Accuracy: 0.460938 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 242 | Batch: 000 / 037 | Total loss: 1.625 | Reg loss: 0.039 | Tree loss: 1.625 | Accuracy: 0.458984 | 0.785 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 260 | Batch: 000 / 037 | Total loss: 1.622 | Reg loss: 0.039 | Tree loss: 1.622 | Accuracy: 0.501953 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 261 | Batch: 000 / 037 | Total loss: 1.663 | Reg loss: 0.039 | Tree loss: 1.663 | Accuracy: 0.472656 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 262 | Batch: 000 / 037 | Total loss: 1.630 | Reg loss: 0.039 | Tree loss: 1.630 | Accuracy: 0.472656 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 280 | Batch: 000 / 037 | Total loss: 1.628 | Reg loss: 0.039 | Tree loss: 1.628 | Accuracy: 0.533203 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 281 | Batch: 000 / 037 | Total loss: 1.657 | Reg loss: 0.039 | Tree loss: 1.657 | Accuracy: 0.507812 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 282 | Batch: 000 / 037 | Total loss: 1.570 | Reg loss: 0.039 | Tree loss: 1.570 | Accuracy: 0.548828 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 300 | Batch: 000 / 037 | Total loss: 1.602 | Reg loss: 0.039 | Tree loss: 1.602 | Accuracy: 0.556641 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 301 | Batch: 000 / 037 | Total loss: 1.646 | Reg loss: 0.039 | Tree loss: 1.646 | Accuracy: 0.519531 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 302 | Batch: 000 / 037 | Total loss: 1.633 | Reg loss: 0.039 | Tree loss: 1.633 | Accuracy: 0.507812 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 320 | Batch: 000 / 037 | Total loss: 1.646 | Reg loss: 0.039 | Tree loss: 1.646 | Accuracy: 0.527344 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 321 | Batch: 000 / 037 | Total loss: 1.636 | Reg loss: 0.039 | Tree loss: 1.636 | Accuracy: 0.496094 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 322 | Batch: 000 / 037 | Total loss: 1.605 | Reg loss: 0.039 | Tree loss: 1.605 | Accuracy: 0.521484 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 340 | Batch: 000 / 037 | Total loss: 1.626 | Reg loss: 0.039 | Tree loss: 1.626 | Accuracy: 0.484375 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 341 | Batch: 000 / 037 | Total loss: 1.634 | Reg loss: 0.039 | Tree loss: 1.634 | Accuracy: 0.492188 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 342 | Batch: 000 / 037 | Total loss: 1.658 | Reg loss: 0.039 | Tree loss: 1.658 | Accuracy: 0.472656 | 0.786 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 360 | Batch: 000 / 037 | Total loss: 1.608 | Reg loss: 0.039 | Tree loss: 1.608 | Accuracy: 0.511719 | 0.787 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 361 | Batch: 000 / 037 | Total loss: 1.594 | Reg loss: 0.039 | Tree loss: 1.594 | Accuracy: 0.537109 | 0.787 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 362 | Batch: 000 / 037 | Total loss: 1.629 | Reg loss: 0.039 | Tree loss: 1.629 | Accuracy: 0.498047 | 0.787 sec/iter
Average sparseness: 0.9840425531914894
laye

Epoch: 380 | Batch: 000 / 037 | Total loss: 1.632 | Reg loss: 0.039 | Tree loss: 1.632 | Accuracy: 0.490234 | 0.787 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 381 | Batch: 000 / 037 | Total loss: 1.637 | Reg loss: 0.039 | Tree loss: 1.637 | Accuracy: 0.472656 | 0.787 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 382 | Batch: 000 / 037 | Total loss: 1.626 | Reg loss: 0.039 | Tree loss: 1.626 | Accuracy: 0.503906 | 0.787 sec/iter
Average sparseness: 0.9840425531914894
laye

In [23]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [24]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [25]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 7.51063829787234


# Extract Rules

# Accumulate samples in the leaves

In [26]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 47


In [27]:
method = 'greedy'

In [28]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [29]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

8
57
39
5190
2
12
7091
2952
1
2760
659
Average comprehensibility: 42.765957446808514
std comprehensibility: 12.565086322666533
var comprehensibility: 157.8813942960616
minimum comprehensibility: 12
maximum comprehensibility: 60


  return np.log(1 / (1 - x))
