In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 16
tree_depth = 10
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.220763206481934 | KNN Loss: 6.225732326507568 | BCE Loss: 1.9950306415557861
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.19043254852295 | KNN Loss: 6.225588321685791 | BCE Loss: 1.9648442268371582
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.231579780578613 | KNN Loss: 6.225534439086914 | BCE Loss: 2.006045341491699
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.17219352722168 | KNN Loss: 6.225142955780029 | BCE Loss: 1.947050929069519
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.164810180664062 | KNN Loss: 6.2255425453186035 | BCE Loss: 1.939267635345459
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.082189559936523 | KNN Loss: 6.224427700042725 | BCE Loss: 1.857762098312378
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.12572956085205 | KNN Loss: 6.224025726318359 | BCE Loss: 1.9017040729522705
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.135915756225586 | KNN Loss: 6.2237677574157715 | BCE Loss: 1.9121475219

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.893028736114502 | KNN Loss: 4.714121341705322 | BCE Loss: 1.1789073944091797
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.67885160446167 | KNN Loss: 4.536438465118408 | BCE Loss: 1.1424131393432617
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.617745876312256 | KNN Loss: 4.462981224060059 | BCE Loss: 1.1547646522521973
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.47523307800293 | KNN Loss: 4.33170223236084 | BCE Loss: 1.1435308456420898
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.381700038909912 | KNN Loss: 4.258389949798584 | BCE Loss: 1.1233100891113281
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.270651817321777 | KNN Loss: 4.141578674316406 | BCE Loss: 1.129073143005371
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.185419082641602 | KNN Loss: 4.051194667816162 | BCE Loss: 1.13422429561615
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.1508612632751465 | KNN Loss: 4.007626533508301 | BCE Loss: 1.1

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.242801666259766 | KNN Loss: 3.2004213333129883 | BCE Loss: 1.0423803329467773
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.2983222007751465 | KNN Loss: 3.217681884765625 | BCE Loss: 1.0806403160095215
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.241363048553467 | KNN Loss: 3.191906452178955 | BCE Loss: 1.0494565963745117
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.24920129776001 | KNN Loss: 3.2198221683502197 | BCE Loss: 1.0293792486190796
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.245300769805908 | KNN Loss: 3.216764211654663 | BCE Loss: 1.0285365581512451
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.269352912902832 | KNN Loss: 3.2253353595733643 | BCE Loss: 1.0440176725387573
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.278356552124023 | KNN Loss: 3.2230496406555176 | BCE Loss: 1.0553069114685059
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.317556381225586 | KNN Loss: 3.2745354175567627 | BC

Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 4.19987678527832 | KNN Loss: 3.17140793800354 | BCE Loss: 1.0284690856933594
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.165718078613281 | KNN Loss: 3.1408562660217285 | BCE Loss: 1.0248615741729736
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.24215030670166 | KNN Loss: 3.179236650466919 | BCE Loss: 1.062913417816162
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.17344856262207 | KNN Loss: 3.1509287357330322 | BCE Loss: 1.022519588470459
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.247107982635498 | KNN Loss: 3.177555561065674 | BCE Loss: 1.0695523023605347
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.196765899658203 | KNN Loss: 3.199018716812134 | BCE Loss: 0.9977471232414246
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.234078407287598 | KNN Loss: 3.1877129077911377 | BCE Loss: 1.0463652610778809
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.192921161651611 | KNN Loss: 3.1578729152679443 | BCE Loss:

Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 4.209124565124512 | KNN Loss: 3.176114320755005 | BCE Loss: 1.0330100059509277
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.183102607727051 | KNN Loss: 3.1572511196136475 | BCE Loss: 1.0258512496948242
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.148543834686279 | KNN Loss: 3.1326045989990234 | BCE Loss: 1.0159392356872559
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.208381175994873 | KNN Loss: 3.163536787033081 | BCE Loss: 1.044844388961792
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.18538761138916 | KNN Loss: 3.1472864151000977 | BCE Loss: 1.0381011962890625
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.159329414367676 | KNN Loss: 3.1228044033050537 | BCE Loss: 1.0365251302719116
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.176938056945801 | KNN Loss: 3.125596761703491 | BCE Loss: 1.0513412952423096
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.148064136505127 | KNN Loss: 3.1323349475860596 | BCE L

Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 4.173694610595703 | KNN Loss: 3.1289548873901367 | BCE Loss: 1.0447399616241455
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 4.117863655090332 | KNN Loss: 3.0925514698028564 | BCE Loss: 1.0253123044967651
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.124027252197266 | KNN Loss: 3.0948495864868164 | BCE Loss: 1.0291776657104492
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.139721870422363 | KNN Loss: 3.1219544410705566 | BCE Loss: 1.0177671909332275
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.1156511306762695 | KNN Loss: 3.0814177989959717 | BCE Loss: 1.0342330932617188
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.114005088806152 | KNN Loss: 3.0988214015960693 | BCE Loss: 1.015183448791504
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.159073829650879 | KNN Loss: 3.1355228424072266 | BCE Loss: 1.0235509872436523
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.180908203125 | KNN Loss: 3.116016387939453 | BCE 

Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 4.086202621459961 | KNN Loss: 3.0866317749023438 | BCE Loss: 0.9995708465576172
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 4.110311508178711 | KNN Loss: 3.068460464477539 | BCE Loss: 1.0418509244918823
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 4.068864822387695 | KNN Loss: 3.0672857761383057 | BCE Loss: 1.0015792846679688
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.086534023284912 | KNN Loss: 3.079218864440918 | BCE Loss: 1.0073152780532837
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.120290756225586 | KNN Loss: 3.128277540206909 | BCE Loss: 0.992013156414032
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.151926040649414 | KNN Loss: 3.1280758380889893 | BCE Loss: 1.0238502025604248
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.119924545288086 | KNN Loss: 3.104179859161377 | BCE Loss: 1.015744924545288
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.085965156555176 | KNN Loss: 3.086134672164917 | BCE Lo

Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 4.14432430267334 | KNN Loss: 3.100449800491333 | BCE Loss: 1.0438742637634277
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 4.107976913452148 | KNN Loss: 3.0922722816467285 | BCE Loss: 1.0157043933868408
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 4.112934589385986 | KNN Loss: 3.104480504989624 | BCE Loss: 1.0084539651870728
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.130946159362793 | KNN Loss: 3.0945303440093994 | BCE Loss: 1.0364155769348145
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.109489440917969 | KNN Loss: 3.095674753189087 | BCE Loss: 1.013814926147461
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.11566162109375 | KNN Loss: 3.0900044441223145 | BCE Loss: 1.0256569385528564
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.14339017868042 | KNN Loss: 3.1207950115203857 | BCE Loss: 1.0225950479507446
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.103658676147461 | KNN Loss: 3.0873661041259766 | BCE Los

Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 4.154972553253174 | KNN Loss: 3.1094422340393066 | BCE Loss: 1.0455303192138672
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 4.079039573669434 | KNN Loss: 3.077627420425415 | BCE Loss: 1.0014121532440186
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 4.123193264007568 | KNN Loss: 3.1338441371917725 | BCE Loss: 0.9893491864204407
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 4.125030517578125 | KNN Loss: 3.1264941692352295 | BCE Loss: 0.9985362887382507
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.1130547523498535 | KNN Loss: 3.090869903564453 | BCE Loss: 1.0221848487854004
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.150155067443848 | KNN Loss: 3.087007522583008 | BCE Loss: 1.0631476640701294
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.105134963989258 | KNN Loss: 3.070647954940796 | BCE Loss: 1.034487009048462
Epoch    87: reducing learning rate of group 0 to 2.4500e-03.
Epoch 87 / 500 | iteration 0 / 30 | Total

Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 4.105724334716797 | KNN Loss: 3.0808966159820557 | BCE Loss: 1.024827480316162
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 4.126381874084473 | KNN Loss: 3.100461721420288 | BCE Loss: 1.0259203910827637
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 4.138910293579102 | KNN Loss: 3.1139800548553467 | BCE Loss: 1.024930477142334
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 4.097437381744385 | KNN Loss: 3.0835530757904053 | BCE Loss: 1.0138843059539795
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.14650297164917 | KNN Loss: 3.130908489227295 | BCE Loss: 1.015594482421875
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.142204761505127 | KNN Loss: 3.1192705631256104 | BCE Loss: 1.0229341983795166
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.174666404724121 | KNN Loss: 3.1399638652801514 | BCE Loss: 1.0347023010253906
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.112357139587402 | KNN Loss: 3.097900629043579 | BCE Lo

Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 4.101624965667725 | KNN Loss: 3.1017003059387207 | BCE Loss: 0.9999246001243591
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 4.122209548950195 | KNN Loss: 3.084055185317993 | BCE Loss: 1.0381546020507812
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 4.142438888549805 | KNN Loss: 3.110187530517578 | BCE Loss: 1.0322511196136475
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 4.084172248840332 | KNN Loss: 3.0965769290924072 | BCE Loss: 0.98759526014328
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.075799465179443 | KNN Loss: 3.048888683319092 | BCE Loss: 1.0269107818603516
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.079070568084717 | KNN Loss: 3.072615146636963 | BCE Loss: 1.006455421447754
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.090881824493408 | KNN Loss: 3.0759220123291016 | BCE Loss: 1.0149598121643066
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.09749698638916 | KNN Loss: 3.059884548187256 | B

Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 4.088991641998291 | KNN Loss: 3.0749428272247314 | BCE Loss: 1.0140489339828491
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 4.11130428314209 | KNN Loss: 3.0946810245513916 | BCE Loss: 1.0166232585906982
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 4.143617630004883 | KNN Loss: 3.099778175354004 | BCE Loss: 1.0438392162322998
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 4.0449113845825195 | KNN Loss: 3.0432794094085693 | BCE Loss: 1.0016318559646606
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.152543544769287 | KNN Loss: 3.1343414783477783 | BCE Loss: 1.0182021856307983
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.158658027648926 | KNN Loss: 3.128305196762085 | BCE Loss: 1.0303528308868408
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.099148750305176 | KNN Loss: 3.0938189029693604 | BCE Loss: 1.005329966545105
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.047561168670654 | KNN Loss: 3.05747318267822

Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 4.1043243408203125 | KNN Loss: 3.0825438499450684 | BCE Loss: 1.0217804908752441
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 4.088597297668457 | KNN Loss: 3.075810670852661 | BCE Loss: 1.0127863883972168
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 4.077401638031006 | KNN Loss: 3.041799306869507 | BCE Loss: 1.0356022119522095
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 4.063545227050781 | KNN Loss: 3.0408759117126465 | BCE Loss: 1.0226693153381348
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.133244514465332 | KNN Loss: 3.117233991622925 | BCE Loss: 1.0160107612609863
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.120431423187256 | KNN Loss: 3.1015899181365967 | BCE Loss: 1.0188416242599487
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.146996974945068 | KNN Loss: 3.0999042987823486 | BCE Loss: 1.0470927953720093
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.089232444763184 | KNN Loss: 3.067626237869

Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 4.118356704711914 | KNN Loss: 3.0806989669799805 | BCE Loss: 1.0376577377319336
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 4.10492467880249 | KNN Loss: 3.072922468185425 | BCE Loss: 1.032002329826355
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 4.024467468261719 | KNN Loss: 3.0274770259857178 | BCE Loss: 0.9969906210899353
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 4.090290069580078 | KNN Loss: 3.076848268508911 | BCE Loss: 1.0134419202804565
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.08042573928833 | KNN Loss: 3.1041672229766846 | BCE Loss: 0.9762583374977112
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.092552661895752 | KNN Loss: 3.054715871810913 | BCE Loss: 1.0378366708755493
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.127178192138672 | KNN Loss: 3.0826611518859863 | BCE Loss: 1.0445168018341064
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.079256057739258 | KNN Loss: 3.0691330432891846 

Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 4.031757831573486 | KNN Loss: 3.048058032989502 | BCE Loss: 0.9836997985839844
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 4.110666275024414 | KNN Loss: 3.087629795074463 | BCE Loss: 1.0230364799499512
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 4.0825514793396 | KNN Loss: 3.107619285583496 | BCE Loss: 0.974932074546814
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 4.098141670227051 | KNN Loss: 3.08329701423645 | BCE Loss: 1.0148447751998901
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.091085433959961 | KNN Loss: 3.047450065612793 | BCE Loss: 1.0436352491378784
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.094181060791016 | KNN Loss: 3.064678907394409 | BCE Loss: 1.0295021533966064
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.0595197677612305 | KNN Loss: 3.0255987644195557 | BCE Loss: 1.0339207649230957
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.096003532409668 | KNN Loss: 3.0674545764923096 | B

Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 4.068169116973877 | KNN Loss: 3.0701870918273926 | BCE Loss: 0.9979819059371948
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 4.014050483703613 | KNN Loss: 3.025045394897461 | BCE Loss: 0.9890052080154419
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 4.101040840148926 | KNN Loss: 3.0744972229003906 | BCE Loss: 1.0265438556671143
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 4.110960483551025 | KNN Loss: 3.1106042861938477 | BCE Loss: 1.0003561973571777
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.108709335327148 | KNN Loss: 3.078617572784424 | BCE Loss: 1.0300917625427246
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.1109771728515625 | KNN Loss: 3.067495346069336 | BCE Loss: 1.0434819459915161
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.086967945098877 | KNN Loss: 3.080873727798462 | BCE Loss: 1.0060943365097046
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.137491226196289 | KNN Loss: 3.1138575077056

Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 4.07960844039917 | KNN Loss: 3.0653884410858154 | BCE Loss: 1.0142199993133545
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 4.104319095611572 | KNN Loss: 3.0629022121429443 | BCE Loss: 1.0414170026779175
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 4.076900482177734 | KNN Loss: 3.06660532951355 | BCE Loss: 1.0102953910827637
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 4.063222885131836 | KNN Loss: 3.050391435623169 | BCE Loss: 1.012831687927246
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.062729358673096 | KNN Loss: 3.0519890785217285 | BCE Loss: 1.0107402801513672
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.078193187713623 | KNN Loss: 3.041116952896118 | BCE Loss: 1.0370761156082153
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.132658004760742 | KNN Loss: 3.088465929031372 | BCE Loss: 1.044191837310791
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.114150047302246 | KNN Loss: 3.059462070465088 | B

Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 4.110685348510742 | KNN Loss: 3.0619139671325684 | BCE Loss: 1.048771619796753
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 4.053529739379883 | KNN Loss: 3.069918632507324 | BCE Loss: 0.9836112260818481
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 4.08386754989624 | KNN Loss: 3.070089817047119 | BCE Loss: 1.013777732849121
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 4.078964710235596 | KNN Loss: 3.0665690898895264 | BCE Loss: 1.0123955011367798
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.111300468444824 | KNN Loss: 3.065309524536133 | BCE Loss: 1.045991063117981
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.038296699523926 | KNN Loss: 3.0567877292633057 | BCE Loss: 0.9815092086791992
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.077134609222412 | KNN Loss: 3.0388717651367188 | BCE Loss: 1.038262963294983
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.114812850952148 | KNN Loss: 3.1141862869262695 | 

Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 4.1161088943481445 | KNN Loss: 3.069572687149048 | BCE Loss: 1.0465362071990967
Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 4.110224723815918 | KNN Loss: 3.0623161792755127 | BCE Loss: 1.0479086637496948
Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 4.050692081451416 | KNN Loss: 3.0625481605529785 | BCE Loss: 0.988143801689148
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 4.058401584625244 | KNN Loss: 3.047548294067383 | BCE Loss: 1.0108532905578613
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.12888240814209 | KNN Loss: 3.0950584411621094 | BCE Loss: 1.0338237285614014
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.0584235191345215 | KNN Loss: 3.0505118370056152 | BCE Loss: 1.0079118013381958
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.054477691650391 | KNN Loss: 3.0411696434020996 | BCE Loss: 1.0133081674575806
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.060898780822754 | KNN Loss: 3.057616949081

Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 4.109368324279785 | KNN Loss: 3.098463296890259 | BCE Loss: 1.0109049081802368
Epoch 203 / 500 | iteration 10 / 30 | Total Loss: 4.132368087768555 | KNN Loss: 3.112765073776245 | BCE Loss: 1.01960289478302
Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 4.115475654602051 | KNN Loss: 3.1049609184265137 | BCE Loss: 1.0105146169662476
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 4.061158180236816 | KNN Loss: 3.0620696544647217 | BCE Loss: 0.9990884065628052
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.075298309326172 | KNN Loss: 3.051128625869751 | BCE Loss: 1.0241698026657104
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.054834365844727 | KNN Loss: 3.032728672027588 | BCE Loss: 1.0221054553985596
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.0728936195373535 | KNN Loss: 3.0860190391540527 | BCE Loss: 0.9868745803833008
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.123208045959473 | KNN Loss: 3.1173596382141113

Epoch 213 / 500 | iteration 25 / 30 | Total Loss: 4.100217819213867 | KNN Loss: 3.093705892562866 | BCE Loss: 1.00651216506958
Epoch 214 / 500 | iteration 0 / 30 | Total Loss: 4.1015825271606445 | KNN Loss: 3.0775129795074463 | BCE Loss: 1.0240695476531982
Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 4.0597686767578125 | KNN Loss: 3.0618488788604736 | BCE Loss: 0.9979199767112732
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 4.049434185028076 | KNN Loss: 3.071361780166626 | BCE Loss: 0.9780725240707397
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.069096088409424 | KNN Loss: 3.0365078449249268 | BCE Loss: 1.0325881242752075
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.067480564117432 | KNN Loss: 3.0328094959259033 | BCE Loss: 1.0346709489822388
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.077341079711914 | KNN Loss: 3.036071300506592 | BCE Loss: 1.0412697792053223
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.149571895599365 | KNN Loss: 3.08752608299255

Epoch 224 / 500 | iteration 15 / 30 | Total Loss: 4.1001081466674805 | KNN Loss: 3.0563971996307373 | BCE Loss: 1.0437109470367432
Epoch 224 / 500 | iteration 20 / 30 | Total Loss: 4.0709452629089355 | KNN Loss: 3.0471765995025635 | BCE Loss: 1.023768663406372
Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 4.061312198638916 | KNN Loss: 3.0470733642578125 | BCE Loss: 1.014238715171814
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 4.086259841918945 | KNN Loss: 3.069460391998291 | BCE Loss: 1.0167996883392334
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.090339183807373 | KNN Loss: 3.0428049564361572 | BCE Loss: 1.0475343465805054
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.119556427001953 | KNN Loss: 3.061906337738037 | BCE Loss: 1.057649850845337
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.0661797523498535 | KNN Loss: 3.0710203647613525 | BCE Loss: 0.9951594471931458
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.091028213500977 | KNN Loss: 3.083507537841

Epoch 235 / 500 | iteration 5 / 30 | Total Loss: 4.087567329406738 | KNN Loss: 3.057976007461548 | BCE Loss: 1.0295915603637695
Epoch 235 / 500 | iteration 10 / 30 | Total Loss: 4.069497108459473 | KNN Loss: 3.0662550926208496 | BCE Loss: 1.0032422542572021
Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 4.083083152770996 | KNN Loss: 3.0574429035186768 | BCE Loss: 1.0256404876708984
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.084417343139648 | KNN Loss: 3.0467917919158936 | BCE Loss: 1.037625789642334
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.092174053192139 | KNN Loss: 3.0481700897216797 | BCE Loss: 1.044003963470459
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.079317092895508 | KNN Loss: 3.064122438430786 | BCE Loss: 1.0151946544647217
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.039323806762695 | KNN Loss: 3.0614986419677734 | BCE Loss: 0.9778250455856323
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.103768348693848 | KNN Loss: 3.081108570098877

Epoch 245 / 500 | iteration 25 / 30 | Total Loss: 4.031405448913574 | KNN Loss: 3.0304250717163086 | BCE Loss: 1.0009801387786865
Epoch 246 / 500 | iteration 0 / 30 | Total Loss: 4.058538436889648 | KNN Loss: 3.051798105239868 | BCE Loss: 1.0067400932312012
Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 4.111529350280762 | KNN Loss: 3.0832135677337646 | BCE Loss: 1.028315544128418
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.083789825439453 | KNN Loss: 3.065896511077881 | BCE Loss: 1.0178935527801514
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.097379684448242 | KNN Loss: 3.071638584136963 | BCE Loss: 1.0257412195205688
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.059595584869385 | KNN Loss: 3.0547678470611572 | BCE Loss: 1.004827618598938
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.080341339111328 | KNN Loss: 3.059281826019287 | BCE Loss: 1.021059513092041
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.022783279418945 | KNN Loss: 3.017277479171753 | 

Epoch 256 / 500 | iteration 15 / 30 | Total Loss: 4.071317195892334 | KNN Loss: 3.065279960632324 | BCE Loss: 1.0060372352600098
Epoch 256 / 500 | iteration 20 / 30 | Total Loss: 4.046391487121582 | KNN Loss: 3.0492308139801025 | BCE Loss: 0.9971605539321899
Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 4.084527969360352 | KNN Loss: 3.0566728115081787 | BCE Loss: 1.0278551578521729
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.104671001434326 | KNN Loss: 3.075524091720581 | BCE Loss: 1.0291470289230347
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.0445451736450195 | KNN Loss: 3.046260118484497 | BCE Loss: 0.9982848167419434
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.069267272949219 | KNN Loss: 3.048719882965088 | BCE Loss: 1.02054762840271
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.111875534057617 | KNN Loss: 3.0794360637664795 | BCE Loss: 1.0324397087097168
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.084146499633789 | KNN Loss: 3.058710336685180

Epoch 267 / 500 | iteration 5 / 30 | Total Loss: 4.1019721031188965 | KNN Loss: 3.0672097206115723 | BCE Loss: 1.0347622632980347
Epoch 267 / 500 | iteration 10 / 30 | Total Loss: 4.076491355895996 | KNN Loss: 3.0590267181396484 | BCE Loss: 1.0174648761749268
Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 4.0941314697265625 | KNN Loss: 3.0670557022094727 | BCE Loss: 1.0270755290985107
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.049969673156738 | KNN Loss: 3.03505802154541 | BCE Loss: 1.0149115324020386
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.101031303405762 | KNN Loss: 3.0845513343811035 | BCE Loss: 1.0164802074432373
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.052271842956543 | KNN Loss: 3.050604820251465 | BCE Loss: 1.0016671419143677
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.0339508056640625 | KNN Loss: 3.039543628692627 | BCE Loss: 0.9944073557853699
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.073513984680176 | KNN Loss: 3.062321186065

Epoch 277 / 500 | iteration 25 / 30 | Total Loss: 4.097932815551758 | KNN Loss: 3.070021152496338 | BCE Loss: 1.0279115438461304
Epoch 278 / 500 | iteration 0 / 30 | Total Loss: 4.070549964904785 | KNN Loss: 3.0431389808654785 | BCE Loss: 1.0274109840393066
Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 4.045168876647949 | KNN Loss: 3.018531560897827 | BCE Loss: 1.026637077331543
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.111894130706787 | KNN Loss: 3.0833990573883057 | BCE Loss: 1.028494954109192
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.09842586517334 | KNN Loss: 3.0603489875793457 | BCE Loss: 1.0380768775939941
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.047945022583008 | KNN Loss: 3.02748966217041 | BCE Loss: 1.020455241203308
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.0932135581970215 | KNN Loss: 3.0669822692871094 | BCE Loss: 1.026231288909912
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.081758499145508 | KNN Loss: 3.080761194229126 | B

Epoch 288 / 500 | iteration 15 / 30 | Total Loss: 4.065742492675781 | KNN Loss: 3.05542254447937 | BCE Loss: 1.0103198289871216
Epoch 288 / 500 | iteration 20 / 30 | Total Loss: 4.105730056762695 | KNN Loss: 3.084167003631592 | BCE Loss: 1.0215630531311035
Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 4.023646831512451 | KNN Loss: 3.0563290119171143 | BCE Loss: 0.9673178195953369
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.048339366912842 | KNN Loss: 3.04000186920166 | BCE Loss: 1.0083376169204712
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.029666423797607 | KNN Loss: 3.0634779930114746 | BCE Loss: 0.9661886096000671
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.064867973327637 | KNN Loss: 3.086627960205078 | BCE Loss: 0.978239893913269
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.046764373779297 | KNN Loss: 3.036769390106201 | BCE Loss: 1.0099949836730957
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.099961280822754 | KNN Loss: 3.0812606811523438 |

Epoch 299 / 500 | iteration 5 / 30 | Total Loss: 4.0399885177612305 | KNN Loss: 3.018707036972046 | BCE Loss: 1.0212817192077637
Epoch 299 / 500 | iteration 10 / 30 | Total Loss: 4.08001708984375 | KNN Loss: 3.057105779647827 | BCE Loss: 1.0229110717773438
Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 4.095734596252441 | KNN Loss: 3.0582337379455566 | BCE Loss: 1.0375010967254639
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.0829010009765625 | KNN Loss: 3.037047863006592 | BCE Loss: 1.0458531379699707
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.080162048339844 | KNN Loss: 3.0853521823883057 | BCE Loss: 0.9948100447654724
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.053709983825684 | KNN Loss: 3.058764696121216 | BCE Loss: 0.9949450492858887
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.114897727966309 | KNN Loss: 3.0739548206329346 | BCE Loss: 1.0409431457519531
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.080594062805176 | KNN Loss: 3.07332682609558

Epoch 309 / 500 | iteration 25 / 30 | Total Loss: 4.1005120277404785 | KNN Loss: 3.069261312484741 | BCE Loss: 1.0312507152557373
Epoch 310 / 500 | iteration 0 / 30 | Total Loss: 4.0732316970825195 | KNN Loss: 3.035371780395508 | BCE Loss: 1.0378600358963013
Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 4.088452339172363 | KNN Loss: 3.0697457790374756 | BCE Loss: 1.0187065601348877
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.045154571533203 | KNN Loss: 3.040414571762085 | BCE Loss: 1.0047398805618286
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.10651969909668 | KNN Loss: 3.0534355640411377 | BCE Loss: 1.053084373474121
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.108852386474609 | KNN Loss: 3.078925132751465 | BCE Loss: 1.0299270153045654
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.077104091644287 | KNN Loss: 3.056283712387085 | BCE Loss: 1.0208204984664917
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.14307165145874 | KNN Loss: 3.103580951690674 |

Epoch 320 / 500 | iteration 15 / 30 | Total Loss: 4.084763526916504 | KNN Loss: 3.0778989791870117 | BCE Loss: 1.006864309310913
Epoch 320 / 500 | iteration 20 / 30 | Total Loss: 4.096271991729736 | KNN Loss: 3.0633974075317383 | BCE Loss: 1.032874584197998
Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 4.096059799194336 | KNN Loss: 3.0861282348632812 | BCE Loss: 1.0099313259124756
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.074652194976807 | KNN Loss: 3.0500199794769287 | BCE Loss: 1.024632215499878
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.119438171386719 | KNN Loss: 3.0931994915008545 | BCE Loss: 1.0262386798858643
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.072399616241455 | KNN Loss: 3.0183281898498535 | BCE Loss: 1.054071307182312
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.075682640075684 | KNN Loss: 3.076514959335327 | BCE Loss: 0.999167799949646
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.059852123260498 | KNN Loss: 3.0519766807556152

Epoch 331 / 500 | iteration 5 / 30 | Total Loss: 4.098315238952637 | KNN Loss: 3.064304828643799 | BCE Loss: 1.0340101718902588
Epoch 331 / 500 | iteration 10 / 30 | Total Loss: 4.104279041290283 | KNN Loss: 3.082552194595337 | BCE Loss: 1.0217269659042358
Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.1098175048828125 | KNN Loss: 3.065845489501953 | BCE Loss: 1.0439722537994385
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.080868244171143 | KNN Loss: 3.0816409587860107 | BCE Loss: 0.9992272853851318
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.064018726348877 | KNN Loss: 3.082602024078369 | BCE Loss: 0.9814167022705078
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.077089309692383 | KNN Loss: 3.048429489135742 | BCE Loss: 1.0286600589752197
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.069936275482178 | KNN Loss: 3.053239107131958 | BCE Loss: 1.0166971683502197
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.0792155265808105 | KNN Loss: 3.055415153503418

Epoch 341 / 500 | iteration 25 / 30 | Total Loss: 4.07908821105957 | KNN Loss: 3.065523624420166 | BCE Loss: 1.0135648250579834
Epoch 342 / 500 | iteration 0 / 30 | Total Loss: 4.105158805847168 | KNN Loss: 3.085634231567383 | BCE Loss: 1.0195248126983643
Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.073190689086914 | KNN Loss: 3.0670089721679688 | BCE Loss: 1.0061814785003662
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.098124980926514 | KNN Loss: 3.0828466415405273 | BCE Loss: 1.0152784585952759
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.0344953536987305 | KNN Loss: 3.04754376411438 | BCE Loss: 0.9869518280029297
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.025619983673096 | KNN Loss: 3.0167596340179443 | BCE Loss: 1.0088603496551514
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.063281059265137 | KNN Loss: 3.0475542545318604 | BCE Loss: 1.015726923942566
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 4.0590596199035645 | KNN Loss: 3.042356491088867

Epoch 352 / 500 | iteration 15 / 30 | Total Loss: 4.032246112823486 | KNN Loss: 3.0293757915496826 | BCE Loss: 1.0028703212738037
Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.069103240966797 | KNN Loss: 3.0453760623931885 | BCE Loss: 1.0237271785736084
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.054629325866699 | KNN Loss: 3.0577046871185303 | BCE Loss: 0.996924638748169
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.050727844238281 | KNN Loss: 3.0572831630706787 | BCE Loss: 0.993444561958313
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.109753608703613 | KNN Loss: 3.0856170654296875 | BCE Loss: 1.0241365432739258
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.0666022300720215 | KNN Loss: 3.0489020347595215 | BCE Loss: 1.0177001953125
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 4.080479145050049 | KNN Loss: 3.0562362670898438 | BCE Loss: 1.0242429971694946
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 4.065407752990723 | KNN Loss: 3.03054070472717

Epoch 363 / 500 | iteration 5 / 30 | Total Loss: 4.097436904907227 | KNN Loss: 3.058626651763916 | BCE Loss: 1.0388102531433105
Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.1158833503723145 | KNN Loss: 3.0995635986328125 | BCE Loss: 1.0163198709487915
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.100815773010254 | KNN Loss: 3.058811902999878 | BCE Loss: 1.042003870010376
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.0848188400268555 | KNN Loss: 3.0716888904571533 | BCE Loss: 1.013129711151123
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.061610221862793 | KNN Loss: 3.0529844760894775 | BCE Loss: 1.0086255073547363
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.064194679260254 | KNN Loss: 3.0640203952789307 | BCE Loss: 1.0001745223999023
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 4.0738325119018555 | KNN Loss: 3.064962387084961 | BCE Loss: 1.008870244026184
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 4.062845230102539 | KNN Loss: 3.05519628524780

Epoch 373 / 500 | iteration 25 / 30 | Total Loss: 4.09687614440918 | KNN Loss: 3.0753509998321533 | BCE Loss: 1.0215251445770264
Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.116168022155762 | KNN Loss: 3.085193634033203 | BCE Loss: 1.0309746265411377
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.091175556182861 | KNN Loss: 3.0607194900512695 | BCE Loss: 1.0304559469223022
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.146644115447998 | KNN Loss: 3.1129567623138428 | BCE Loss: 1.0336874723434448
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.118609428405762 | KNN Loss: 3.0652854442596436 | BCE Loss: 1.0533239841461182
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.095663547515869 | KNN Loss: 3.04640531539917 | BCE Loss: 1.0492582321166992
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 4.099172592163086 | KNN Loss: 3.0700058937072754 | BCE Loss: 1.0291668176651
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 4.079672336578369 | KNN Loss: 3.048884630203247 | 

Epoch 384 / 500 | iteration 15 / 30 | Total Loss: 4.083115100860596 | KNN Loss: 3.0481956005096436 | BCE Loss: 1.0349193811416626
Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.042105674743652 | KNN Loss: 3.040719985961914 | BCE Loss: 1.0013854503631592
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.090600490570068 | KNN Loss: 3.064639091491699 | BCE Loss: 1.0259613990783691
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.065023422241211 | KNN Loss: 3.058389902114868 | BCE Loss: 1.0066334009170532
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.0854010581970215 | KNN Loss: 3.0810868740081787 | BCE Loss: 1.0043141841888428
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.038599491119385 | KNN Loss: 3.022996664047241 | BCE Loss: 1.0156028270721436
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 4.103513240814209 | KNN Loss: 3.096059560775757 | BCE Loss: 1.0074536800384521
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 4.053197860717773 | KNN Loss: 3.05938911437988

Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.1250224113464355 | KNN Loss: 3.077732801437378 | BCE Loss: 1.0472896099090576
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.070103645324707 | KNN Loss: 3.035731554031372 | BCE Loss: 1.0343722105026245
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.046183109283447 | KNN Loss: 3.0474846363067627 | BCE Loss: 0.9986982941627502
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.07952356338501 | KNN Loss: 3.084867000579834 | BCE Loss: 0.9946564435958862
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.190280914306641 | KNN Loss: 3.127986431121826 | BCE Loss: 1.0622942447662354
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.0727972984313965 | KNN Loss: 3.0358119010925293 | BCE Loss: 1.0369852781295776
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 4.095860481262207 | KNN Loss: 3.0900018215179443 | BCE Loss: 1.0058584213256836
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 4.092868804931641 | KNN Loss: 3.07130479812622

Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.05124568939209 | KNN Loss: 3.045185089111328 | BCE Loss: 1.0060608386993408
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.094038009643555 | KNN Loss: 3.076324701309204 | BCE Loss: 1.0177130699157715
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.091113090515137 | KNN Loss: 3.074333429336548 | BCE Loss: 1.0167794227600098
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.097942352294922 | KNN Loss: 3.0645244121551514 | BCE Loss: 1.0334177017211914
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.1038737297058105 | KNN Loss: 3.08347487449646 | BCE Loss: 1.0203988552093506
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.077759742736816 | KNN Loss: 3.062635898590088 | BCE Loss: 1.0151240825653076
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 4.086764335632324 | KNN Loss: 3.087003707885742 | BCE Loss: 0.999760627746582
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 4.107909202575684 | KNN Loss: 3.086360216140747 | B

Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.087550163269043 | KNN Loss: 3.0498249530792236 | BCE Loss: 1.0377252101898193
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.081151008605957 | KNN Loss: 3.032632827758789 | BCE Loss: 1.0485179424285889
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.088641166687012 | KNN Loss: 3.0709195137023926 | BCE Loss: 1.01772141456604
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.086776256561279 | KNN Loss: 3.0423665046691895 | BCE Loss: 1.0444097518920898
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.097029685974121 | KNN Loss: 3.052994728088379 | BCE Loss: 1.044034719467163
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.042773246765137 | KNN Loss: 3.0644874572753906 | BCE Loss: 0.978285551071167
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 4.114668846130371 | KNN Loss: 3.0660581588745117 | BCE Loss: 1.0486106872558594
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 4.092311859130859 | KNN Loss: 3.0503692626953125

Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.094338417053223 | KNN Loss: 3.074815273284912 | BCE Loss: 1.0195233821868896
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.074184417724609 | KNN Loss: 3.061771869659424 | BCE Loss: 1.0124123096466064
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.060611724853516 | KNN Loss: 3.047820806503296 | BCE Loss: 1.0127910375595093
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.1168694496154785 | KNN Loss: 3.057279348373413 | BCE Loss: 1.0595901012420654
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.040396690368652 | KNN Loss: 3.023773670196533 | BCE Loss: 1.01662278175354
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.0471086502075195 | KNN Loss: 3.0114212036132812 | BCE Loss: 1.0356872081756592
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 4.086655616760254 | KNN Loss: 3.0692930221557617 | BCE Loss: 1.017362356185913
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 4.0512285232543945 | KNN Loss: 3.0401103496551514

Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.081905364990234 | KNN Loss: 3.057138442993164 | BCE Loss: 1.0247671604156494
Epoch   438: reducing learning rate of group 0 to 2.2999e-07.
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.101387977600098 | KNN Loss: 3.0439934730529785 | BCE Loss: 1.05739426612854
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.114784240722656 | KNN Loss: 3.0964395999908447 | BCE Loss: 1.0183446407318115
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.092372417449951 | KNN Loss: 3.0818772315979004 | BCE Loss: 1.0104953050613403
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.0650739669799805 | KNN Loss: 3.054426908493042 | BCE Loss: 1.0106472969055176
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.1188859939575195 | KNN Loss: 3.0869786739349365 | BCE Loss: 1.031907558441162
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 4.107195854187012 | KNN Loss: 3.0797200202941895 | BCE Loss: 1.0274758338928223
Epoch 439 / 500 | iteration 0 / 3

Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.087451457977295 | KNN Loss: 3.0844593048095703 | BCE Loss: 1.0029921531677246
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.091189861297607 | KNN Loss: 3.087977170944214 | BCE Loss: 1.003212571144104
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.084078311920166 | KNN Loss: 3.0850090980529785 | BCE Loss: 0.9990692138671875
Epoch   449: reducing learning rate of group 0 to 1.6100e-07.
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.108673095703125 | KNN Loss: 3.070420265197754 | BCE Loss: 1.038252592086792
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.087719440460205 | KNN Loss: 3.068009853363037 | BCE Loss: 1.019709587097168
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 4.104234218597412 | KNN Loss: 3.080981492996216 | BCE Loss: 1.0232527256011963
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 4.097179889678955 | KNN Loss: 3.0841450691223145 | BCE Loss: 1.0130349397659302
Epoch 449 / 500 | iteration 20 / 30 |

Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.105681419372559 | KNN Loss: 3.0701873302459717 | BCE Loss: 1.035494327545166
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.093100547790527 | KNN Loss: 3.058084726333618 | BCE Loss: 1.0350158214569092
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.093588829040527 | KNN Loss: 3.0786702632904053 | BCE Loss: 1.014918565750122
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.063531875610352 | KNN Loss: 3.0682661533355713 | BCE Loss: 0.9952654838562012
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.095851421356201 | KNN Loss: 3.084588050842285 | BCE Loss: 1.011263370513916
Epoch   460: reducing learning rate of group 0 to 1.1270e-07.
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 4.03083610534668 | KNN Loss: 3.045865774154663 | BCE Loss: 0.9849704504013062
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 4.088153839111328 | KNN Loss: 3.0413007736206055 | BCE Loss: 1.0468533039093018
Epoch 460 / 500 | iteration 10 / 30 | 

Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.071141242980957 | KNN Loss: 3.0663435459136963 | BCE Loss: 1.0047974586486816
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.078019618988037 | KNN Loss: 3.0741524696350098 | BCE Loss: 1.0038670301437378
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.061554908752441 | KNN Loss: 3.0299925804138184 | BCE Loss: 1.0315624475479126
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.113236427307129 | KNN Loss: 3.0986287593841553 | BCE Loss: 1.0146076679229736
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.0313591957092285 | KNN Loss: 3.0465540885925293 | BCE Loss: 0.9848049879074097
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 4.147325038909912 | KNN Loss: 3.1089487075805664 | BCE Loss: 1.0383764505386353
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 4.082969665527344 | KNN Loss: 3.075295925140381 | BCE Loss: 1.0076736211776733
Epoch   471: reducing learning rate of group 0 to 7.8888e-08.
Epoch 471 / 500 | iteration 0 

Epoch 480 / 500 | iteration 10 / 30 | Total Loss: 4.089734077453613 | KNN Loss: 3.046983242034912 | BCE Loss: 1.042750597000122
Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.052148818969727 | KNN Loss: 3.0352730751037598 | BCE Loss: 1.016875982284546
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.065838813781738 | KNN Loss: 3.047142267227173 | BCE Loss: 1.0186963081359863
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.11040735244751 | KNN Loss: 3.1161696910858154 | BCE Loss: 0.9942377805709839
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.094037055969238 | KNN Loss: 3.091932535171509 | BCE Loss: 1.0021045207977295
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.109869956970215 | KNN Loss: 3.0841615200042725 | BCE Loss: 1.0257084369659424
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 4.079349517822266 | KNN Loss: 3.076646327972412 | BCE Loss: 1.0027031898498535
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 4.082265853881836 | KNN Loss: 3.038120746612549 |

Epoch 491 / 500 | iteration 0 / 30 | Total Loss: 4.123971939086914 | KNN Loss: 3.0755860805511475 | BCE Loss: 1.0483858585357666
Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.075064182281494 | KNN Loss: 3.0567729473114014 | BCE Loss: 1.0182912349700928
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.0984954833984375 | KNN Loss: 3.062119483947754 | BCE Loss: 1.0363757610321045
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.119644641876221 | KNN Loss: 3.0852420330047607 | BCE Loss: 1.03440260887146
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.0983662605285645 | KNN Loss: 3.0622236728668213 | BCE Loss: 1.0361427068710327
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.060647964477539 | KNN Loss: 3.026803731918335 | BCE Loss: 1.033844232559204
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 4.0682172775268555 | KNN Loss: 3.0446863174438477 | BCE Loss: 1.0235308408737183
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 4.107405662536621 | KNN Loss: 3.10675168037414

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.6970,  3.1057,  2.7477,  3.4826,  3.7074,  0.7246,  2.7431,  2.0297,
          2.1912,  1.7263,  2.2342,  1.8351,  0.9432,  1.6760,  1.1199,  1.4433,
          2.3246,  3.3025,  2.7007,  2.4342,  1.5816,  3.0688,  2.2299,  2.5344,
          1.9169,  1.8646,  2.2183,  1.4689,  1.4033,  0.3506, -0.2870,  1.0944,
          0.3384,  0.9997,  1.2990,  1.5376,  1.1457,  3.5365,  0.7669,  1.3839,
          1.0160, -0.7720, -0.3782,  2.4335,  1.7112,  0.5022, -0.3516,  0.0256,
          1.3029,  2.5729,  1.9033,  0.2487,  1.4146,  0.5848, -0.5768,  1.2612,
          1.6574,  1.1411,  1.4519,  1.8085,  0.7603,  0.7590,  0.2704,  1.9292,
          1.4704,  1.3470, -1.8991,  0.3703,  2.5268,  2.0931,  2.6446,  0.3866,
          1.0189,  2.6251,  2.0826,  1.2561,  0.2715,  0.7285,  0.2521,  1.5650,
          0.0886,  0.5018,  1.9238, -0.3570,  0.3787, -1.0158, -2.3823, -0.4354,
          0.5506, -1.8535,  0.3873, -0.0160, -0.5293, -1.0027,  0.3250,  1.4222,
         -0.8213, -0.8191,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 85.35it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 019 | Total loss: 9.621 | Reg loss: 0.012 | Tree loss: 9.621 | Accuracy: 0.000000 | 0.821 sec/iter
Epoch: 00 | Batch: 001 / 019 | Total loss: 9.618 | Reg loss: 0.011 | Tree loss: 9.618 | Accuracy: 0.000000 | 0.784 sec/iter
Epoch: 00 | Batch: 002 / 019 | Total loss: 9.617 | Reg loss: 0.010 | Tree loss: 9.617 | Accuracy: 0.000000 | 0.775 sec/iter
Epoch: 00 | Batch: 003 / 019 | Total loss: 9.615 | Reg loss: 0.009 | Tree loss: 9.615 | Accuracy: 0.000000 | 0.772 sec/iter
Epoch: 00 | Batch: 004 / 019 | Total loss: 9.613 | Reg loss: 0.009 | Tree loss: 9.613 | Accuracy: 0.000000 | 0.769 sec/iter
Epoch: 00 | Batch: 005 / 019 | Total loss: 9.612 | Reg loss: 0.008 | Tree loss: 9.612 | Accuracy: 0.000000 | 0.768 sec/iter
Epoch: 00 | Batch: 006 / 019 | Total loss: 9.609 | Reg loss: 0.007 | Tree loss: 9.609 | Accuracy: 0.000000 | 0.768 

Epoch: 03 | Batch: 001 / 019 | Total loss: 9.569 | Reg loss: 0.005 | Tree loss: 9.569 | Accuracy: 0.048828 | 0.785 sec/iter
Epoch: 03 | Batch: 002 / 019 | Total loss: 9.567 | Reg loss: 0.005 | Tree loss: 9.567 | Accuracy: 0.056641 | 0.784 sec/iter
Epoch: 03 | Batch: 003 / 019 | Total loss: 9.568 | Reg loss: 0.005 | Tree loss: 9.568 | Accuracy: 0.066406 | 0.784 sec/iter
Epoch: 03 | Batch: 004 / 019 | Total loss: 9.568 | Reg loss: 0.005 | Tree loss: 9.568 | Accuracy: 0.048828 | 0.784 sec/iter
Epoch: 03 | Batch: 005 / 019 | Total loss: 9.567 | Reg loss: 0.005 | Tree loss: 9.567 | Accuracy: 0.048828 | 0.783 sec/iter
Epoch: 03 | Batch: 006 / 019 | Total loss: 9.563 | Reg loss: 0.005 | Tree loss: 9.563 | Accuracy: 0.054688 | 0.783 sec/iter
Epoch: 03 | Batch: 007 / 019 | Total loss: 9.562 | Reg loss: 0.006 | Tree loss: 9.562 | Accuracy: 0.044922 | 0.783 sec/iter
Epoch: 03 | Batch: 008 / 019 | Total loss: 9.560 | Reg loss: 0.006 | Tree loss: 9.560 | Accuracy: 0.058594 | 0.783 sec/iter
Epoch: 0

Epoch: 06 | Batch: 004 / 019 | Total loss: 9.509 | Reg loss: 0.009 | Tree loss: 9.509 | Accuracy: 0.083984 | 0.78 sec/iter
Epoch: 06 | Batch: 005 / 019 | Total loss: 9.500 | Reg loss: 0.009 | Tree loss: 9.500 | Accuracy: 0.091797 | 0.78 sec/iter
Epoch: 06 | Batch: 006 / 019 | Total loss: 9.498 | Reg loss: 0.010 | Tree loss: 9.498 | Accuracy: 0.097656 | 0.78 sec/iter
Epoch: 06 | Batch: 007 / 019 | Total loss: 9.502 | Reg loss: 0.010 | Tree loss: 9.502 | Accuracy: 0.070312 | 0.78 sec/iter
Epoch: 06 | Batch: 008 / 019 | Total loss: 9.494 | Reg loss: 0.010 | Tree loss: 9.494 | Accuracy: 0.074219 | 0.78 sec/iter
Epoch: 06 | Batch: 009 / 019 | Total loss: 9.492 | Reg loss: 0.011 | Tree loss: 9.492 | Accuracy: 0.091797 | 0.78 sec/iter
Epoch: 06 | Batch: 010 / 019 | Total loss: 9.480 | Reg loss: 0.011 | Tree loss: 9.480 | Accuracy: 0.070312 | 0.78 sec/iter
Epoch: 06 | Batch: 011 / 019 | Total loss: 9.480 | Reg loss: 0.011 | Tree loss: 9.480 | Accuracy: 0.076172 | 0.78 sec/iter
Epoch: 06 | Batc

Epoch: 09 | Batch: 007 / 019 | Total loss: 9.191 | Reg loss: 0.015 | Tree loss: 9.191 | Accuracy: 0.076172 | 0.785 sec/iter
Epoch: 09 | Batch: 008 / 019 | Total loss: 9.187 | Reg loss: 0.016 | Tree loss: 9.187 | Accuracy: 0.064453 | 0.785 sec/iter
Epoch: 09 | Batch: 009 / 019 | Total loss: 9.160 | Reg loss: 0.016 | Tree loss: 9.160 | Accuracy: 0.058594 | 0.786 sec/iter
Epoch: 09 | Batch: 010 / 019 | Total loss: 9.128 | Reg loss: 0.016 | Tree loss: 9.128 | Accuracy: 0.050781 | 0.786 sec/iter
Epoch: 09 | Batch: 011 / 019 | Total loss: 9.126 | Reg loss: 0.017 | Tree loss: 9.126 | Accuracy: 0.054688 | 0.786 sec/iter
Epoch: 09 | Batch: 012 / 019 | Total loss: 9.113 | Reg loss: 0.017 | Tree loss: 9.113 | Accuracy: 0.052734 | 0.786 sec/iter
Epoch: 09 | Batch: 013 / 019 | Total loss: 9.087 | Reg loss: 0.017 | Tree loss: 9.087 | Accuracy: 0.068359 | 0.787 sec/iter
Epoch: 09 | Batch: 014 / 019 | Total loss: 9.061 | Reg loss: 0.018 | Tree loss: 9.061 | Accuracy: 0.050781 | 0.787 sec/iter
Epoch: 0

Epoch: 12 | Batch: 010 / 019 | Total loss: 8.543 | Reg loss: 0.020 | Tree loss: 8.543 | Accuracy: 0.068359 | 0.804 sec/iter
Epoch: 12 | Batch: 011 / 019 | Total loss: 8.519 | Reg loss: 0.020 | Tree loss: 8.519 | Accuracy: 0.056641 | 0.804 sec/iter
Epoch: 12 | Batch: 012 / 019 | Total loss: 8.532 | Reg loss: 0.020 | Tree loss: 8.532 | Accuracy: 0.041016 | 0.804 sec/iter
Epoch: 12 | Batch: 013 / 019 | Total loss: 8.509 | Reg loss: 0.021 | Tree loss: 8.509 | Accuracy: 0.056641 | 0.804 sec/iter
Epoch: 12 | Batch: 014 / 019 | Total loss: 8.482 | Reg loss: 0.021 | Tree loss: 8.482 | Accuracy: 0.058594 | 0.805 sec/iter
Epoch: 12 | Batch: 015 / 019 | Total loss: 8.450 | Reg loss: 0.021 | Tree loss: 8.450 | Accuracy: 0.048828 | 0.805 sec/iter
Epoch: 12 | Batch: 016 / 019 | Total loss: 8.454 | Reg loss: 0.022 | Tree loss: 8.454 | Accuracy: 0.058594 | 0.805 sec/iter
Epoch: 12 | Batch: 017 / 019 | Total loss: 8.426 | Reg loss: 0.022 | Tree loss: 8.426 | Accuracy: 0.035156 | 0.804 sec/iter
Epoch: 1

Epoch: 15 | Batch: 013 / 019 | Total loss: 7.900 | Reg loss: 0.023 | Tree loss: 7.900 | Accuracy: 0.046875 | 0.816 sec/iter
Epoch: 15 | Batch: 014 / 019 | Total loss: 7.881 | Reg loss: 0.023 | Tree loss: 7.881 | Accuracy: 0.062500 | 0.816 sec/iter
Epoch: 15 | Batch: 015 / 019 | Total loss: 7.809 | Reg loss: 0.023 | Tree loss: 7.809 | Accuracy: 0.046875 | 0.817 sec/iter
Epoch: 15 | Batch: 016 / 019 | Total loss: 7.822 | Reg loss: 0.023 | Tree loss: 7.822 | Accuracy: 0.037109 | 0.817 sec/iter
Epoch: 15 | Batch: 017 / 019 | Total loss: 7.821 | Reg loss: 0.024 | Tree loss: 7.821 | Accuracy: 0.044922 | 0.817 sec/iter
Epoch: 15 | Batch: 018 / 019 | Total loss: 7.746 | Reg loss: 0.024 | Tree loss: 7.746 | Accuracy: 0.062718 | 0.817 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 18 | Batch: 016 / 019 | Total loss: 7.237 | Reg loss: 0.024 | Tree loss: 7.237 | Accuracy: 0.046875 | 0.827 sec/iter
Epoch: 18 | Batch: 017 / 019 | Total loss: 7.188 | Reg loss: 0.024 | Tree loss: 7.188 | Accuracy: 0.068359 | 0.827 sec/iter
Epoch: 18 | Batch: 018 / 019 | Total loss: 7.228 | Reg loss: 0.025 | Tree loss: 7.228 | Accuracy: 0.048780 | 0.827 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 19 | Batch: 000 / 019 | Total loss: 7.352 | Reg loss: 0.023 | Tree loss: 7.352 | Accuracy: 0.041016 | 0.829 sec/iter
Epoch: 19 | Batch: 001 / 019 | Total loss: 7.291 | Reg loss: 0.023 | Tree loss: 7.291 | Accuracy: 0.060547 | 0.829 sec/iter
Epoch: 19 | Batch: 002 / 019 | Total loss: 7.288 | Reg loss: 0.023 | Tree loss: 7.288 | Ac

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 22 | Batch: 000 / 019 | Total loss: 6.800 | Reg loss: 0.024 | Tree loss: 6.800 | Accuracy: 0.050781 | 0.834 sec/iter
Epoch: 22 | Batch: 001 / 019 | Total loss: 6.824 | Reg loss: 0.024 | Tree loss: 6.824 | Accuracy: 0.042969 | 0.834 sec/iter
Epoch: 22 | Batch: 002 / 019 | Total loss: 6.770 | Reg loss: 0.024 | Tree loss: 6.770 | Accuracy: 0.052734 | 0.834 sec/iter
Epoch: 22 | Batch: 003 / 019 | Total loss: 6.745 | Reg loss: 0.024 | Tree loss: 6.745 | Accuracy: 0.050781 | 0.834 sec/iter
Epoch: 22 | Batch: 004 / 019 | Total loss: 6.754 | Reg loss: 0.024 | Tree loss: 6.754 | Accuracy: 0.046875 | 0.834 sec/iter
Epoch: 22 | Batch: 005 / 019 | Total loss: 6.743 | Reg loss: 0.024 | Tree loss: 6.743 | Ac

layer 8: 0.9821428571428573
Epoch: 25 | Batch: 000 / 019 | Total loss: 6.276 | Reg loss: 0.024 | Tree loss: 6.276 | Accuracy: 0.068359 | 0.838 sec/iter
Epoch: 25 | Batch: 001 / 019 | Total loss: 6.305 | Reg loss: 0.024 | Tree loss: 6.305 | Accuracy: 0.050781 | 0.838 sec/iter
Epoch: 25 | Batch: 002 / 019 | Total loss: 6.256 | Reg loss: 0.024 | Tree loss: 6.256 | Accuracy: 0.074219 | 0.838 sec/iter
Epoch: 25 | Batch: 003 / 019 | Total loss: 6.223 | Reg loss: 0.024 | Tree loss: 6.223 | Accuracy: 0.046875 | 0.838 sec/iter
Epoch: 25 | Batch: 004 / 019 | Total loss: 6.298 | Reg loss: 0.024 | Tree loss: 6.298 | Accuracy: 0.050781 | 0.838 sec/iter
Epoch: 25 | Batch: 005 / 019 | Total loss: 6.205 | Reg loss: 0.024 | Tree loss: 6.205 | Accuracy: 0.052734 | 0.838 sec/iter
Epoch: 25 | Batch: 006 / 019 | Total loss: 6.239 | Reg loss: 0.024 | Tree loss: 6.239 | Accuracy: 0.039062 | 0.838 sec/iter
Epoch: 25 | Batch: 007 / 019 | Total loss: 6.209 | Reg loss: 0.024 | Tree loss: 6.209 | Accuracy: 0.0644

Epoch: 28 | Batch: 002 / 019 | Total loss: 5.861 | Reg loss: 0.024 | Tree loss: 5.861 | Accuracy: 0.054688 | 0.842 sec/iter
Epoch: 28 | Batch: 003 / 019 | Total loss: 5.851 | Reg loss: 0.024 | Tree loss: 5.851 | Accuracy: 0.039062 | 0.842 sec/iter
Epoch: 28 | Batch: 004 / 019 | Total loss: 5.841 | Reg loss: 0.024 | Tree loss: 5.841 | Accuracy: 0.042969 | 0.842 sec/iter
Epoch: 28 | Batch: 005 / 019 | Total loss: 5.802 | Reg loss: 0.024 | Tree loss: 5.802 | Accuracy: 0.058594 | 0.842 sec/iter
Epoch: 28 | Batch: 006 / 019 | Total loss: 5.768 | Reg loss: 0.024 | Tree loss: 5.768 | Accuracy: 0.052734 | 0.842 sec/iter
Epoch: 28 | Batch: 007 / 019 | Total loss: 5.772 | Reg loss: 0.024 | Tree loss: 5.772 | Accuracy: 0.052734 | 0.842 sec/iter
Epoch: 28 | Batch: 008 / 019 | Total loss: 5.821 | Reg loss: 0.024 | Tree loss: 5.821 | Accuracy: 0.041016 | 0.842 sec/iter
Epoch: 28 | Batch: 009 / 019 | Total loss: 5.779 | Reg loss: 0.024 | Tree loss: 5.779 | Accuracy: 0.058594 | 0.842 sec/iter
Epoch: 2

Epoch: 31 | Batch: 005 / 019 | Total loss: 5.430 | Reg loss: 0.024 | Tree loss: 5.430 | Accuracy: 0.062500 | 0.845 sec/iter
Epoch: 31 | Batch: 006 / 019 | Total loss: 5.389 | Reg loss: 0.024 | Tree loss: 5.389 | Accuracy: 0.050781 | 0.845 sec/iter
Epoch: 31 | Batch: 007 / 019 | Total loss: 5.395 | Reg loss: 0.024 | Tree loss: 5.395 | Accuracy: 0.056641 | 0.845 sec/iter
Epoch: 31 | Batch: 008 / 019 | Total loss: 5.387 | Reg loss: 0.024 | Tree loss: 5.387 | Accuracy: 0.056641 | 0.845 sec/iter
Epoch: 31 | Batch: 009 / 019 | Total loss: 5.334 | Reg loss: 0.024 | Tree loss: 5.334 | Accuracy: 0.056641 | 0.845 sec/iter
Epoch: 31 | Batch: 010 / 019 | Total loss: 5.406 | Reg loss: 0.024 | Tree loss: 5.406 | Accuracy: 0.054688 | 0.845 sec/iter
Epoch: 31 | Batch: 011 / 019 | Total loss: 5.393 | Reg loss: 0.024 | Tree loss: 5.393 | Accuracy: 0.048828 | 0.845 sec/iter
Epoch: 31 | Batch: 012 / 019 | Total loss: 5.432 | Reg loss: 0.024 | Tree loss: 5.432 | Accuracy: 0.037109 | 0.845 sec/iter
Epoch: 3

Epoch: 34 | Batch: 008 / 019 | Total loss: 5.098 | Reg loss: 0.024 | Tree loss: 5.098 | Accuracy: 0.052734 | 0.846 sec/iter
Epoch: 34 | Batch: 009 / 019 | Total loss: 5.038 | Reg loss: 0.024 | Tree loss: 5.038 | Accuracy: 0.062500 | 0.846 sec/iter
Epoch: 34 | Batch: 010 / 019 | Total loss: 5.108 | Reg loss: 0.024 | Tree loss: 5.108 | Accuracy: 0.041016 | 0.846 sec/iter
Epoch: 34 | Batch: 011 / 019 | Total loss: 5.039 | Reg loss: 0.024 | Tree loss: 5.039 | Accuracy: 0.062500 | 0.846 sec/iter
Epoch: 34 | Batch: 012 / 019 | Total loss: 5.059 | Reg loss: 0.024 | Tree loss: 5.059 | Accuracy: 0.039062 | 0.846 sec/iter
Epoch: 34 | Batch: 013 / 019 | Total loss: 5.044 | Reg loss: 0.024 | Tree loss: 5.044 | Accuracy: 0.066406 | 0.846 sec/iter
Epoch: 34 | Batch: 014 / 019 | Total loss: 5.035 | Reg loss: 0.024 | Tree loss: 5.035 | Accuracy: 0.060547 | 0.846 sec/iter
Epoch: 34 | Batch: 015 / 019 | Total loss: 5.036 | Reg loss: 0.024 | Tree loss: 5.036 | Accuracy: 0.052734 | 0.846 sec/iter
Epoch: 3

Epoch: 37 | Batch: 011 / 019 | Total loss: 4.788 | Reg loss: 0.024 | Tree loss: 4.788 | Accuracy: 0.056641 | 0.848 sec/iter
Epoch: 37 | Batch: 012 / 019 | Total loss: 4.777 | Reg loss: 0.024 | Tree loss: 4.777 | Accuracy: 0.054688 | 0.848 sec/iter
Epoch: 37 | Batch: 013 / 019 | Total loss: 4.779 | Reg loss: 0.024 | Tree loss: 4.779 | Accuracy: 0.066406 | 0.848 sec/iter
Epoch: 37 | Batch: 014 / 019 | Total loss: 4.785 | Reg loss: 0.024 | Tree loss: 4.785 | Accuracy: 0.050781 | 0.848 sec/iter
Epoch: 37 | Batch: 015 / 019 | Total loss: 4.772 | Reg loss: 0.024 | Tree loss: 4.772 | Accuracy: 0.054688 | 0.848 sec/iter
Epoch: 37 | Batch: 016 / 019 | Total loss: 4.768 | Reg loss: 0.024 | Tree loss: 4.768 | Accuracy: 0.050781 | 0.848 sec/iter
Epoch: 37 | Batch: 017 / 019 | Total loss: 4.740 | Reg loss: 0.024 | Tree loss: 4.740 | Accuracy: 0.068359 | 0.848 sec/iter
Epoch: 37 | Batch: 018 / 019 | Total loss: 4.748 | Reg loss: 0.024 | Tree loss: 4.748 | Accuracy: 0.055749 | 0.848 sec/iter
Average 

Epoch: 40 | Batch: 014 / 019 | Total loss: 4.572 | Reg loss: 0.024 | Tree loss: 4.572 | Accuracy: 0.052734 | 0.849 sec/iter
Epoch: 40 | Batch: 015 / 019 | Total loss: 4.550 | Reg loss: 0.024 | Tree loss: 4.550 | Accuracy: 0.039062 | 0.849 sec/iter
Epoch: 40 | Batch: 016 / 019 | Total loss: 4.531 | Reg loss: 0.024 | Tree loss: 4.531 | Accuracy: 0.072266 | 0.849 sec/iter
Epoch: 40 | Batch: 017 / 019 | Total loss: 4.564 | Reg loss: 0.024 | Tree loss: 4.564 | Accuracy: 0.054688 | 0.849 sec/iter
Epoch: 40 | Batch: 018 / 019 | Total loss: 4.612 | Reg loss: 0.024 | Tree loss: 4.612 | Accuracy: 0.038328 | 0.849 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 41 | Batch: 000 / 019 | Total loss: 4.556 | Reg loss: 0.024 | Tree loss: 4.556 | Ac

Epoch: 43 | Batch: 017 / 019 | Total loss: 4.376 | Reg loss: 0.024 | Tree loss: 4.376 | Accuracy: 0.048828 | 0.851 sec/iter
Epoch: 43 | Batch: 018 / 019 | Total loss: 4.357 | Reg loss: 0.024 | Tree loss: 4.357 | Accuracy: 0.055749 | 0.851 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 44 | Batch: 000 / 019 | Total loss: 4.407 | Reg loss: 0.023 | Tree loss: 4.407 | Accuracy: 0.046875 | 0.852 sec/iter
Epoch: 44 | Batch: 001 / 019 | Total loss: 4.354 | Reg loss: 0.023 | Tree loss: 4.354 | Accuracy: 0.064453 | 0.852 sec/iter
Epoch: 44 | Batch: 002 / 019 | Total loss: 4.370 | Reg loss: 0.023 | Tree loss: 4.370 | Accuracy: 0.048828 | 0.852 sec/iter
Epoch: 44 | Batch: 003 / 019 | Total loss: 4.350 | Reg loss: 0.023 | Tree loss: 4.350 | Ac

Epoch: 47 | Batch: 000 / 019 | Total loss: 4.201 | Reg loss: 0.023 | Tree loss: 4.201 | Accuracy: 0.083984 | 0.854 sec/iter
Epoch: 47 | Batch: 001 / 019 | Total loss: 4.219 | Reg loss: 0.023 | Tree loss: 4.219 | Accuracy: 0.095703 | 0.854 sec/iter
Epoch: 47 | Batch: 002 / 019 | Total loss: 4.223 | Reg loss: 0.023 | Tree loss: 4.223 | Accuracy: 0.074219 | 0.853 sec/iter
Epoch: 47 | Batch: 003 / 019 | Total loss: 4.207 | Reg loss: 0.023 | Tree loss: 4.207 | Accuracy: 0.064453 | 0.853 sec/iter
Epoch: 47 | Batch: 004 / 019 | Total loss: 4.183 | Reg loss: 0.023 | Tree loss: 4.183 | Accuracy: 0.093750 | 0.853 sec/iter
Epoch: 47 | Batch: 005 / 019 | Total loss: 4.191 | Reg loss: 0.023 | Tree loss: 4.191 | Accuracy: 0.066406 | 0.853 sec/iter
Epoch: 47 | Batch: 006 / 019 | Total loss: 4.225 | Reg loss: 0.023 | Tree loss: 4.225 | Accuracy: 0.082031 | 0.853 sec/iter
Epoch: 47 | Batch: 007 / 019 | Total loss: 4.148 | Reg loss: 0.023 | Tree loss: 4.148 | Accuracy: 0.087891 | 0.853 sec/iter
Epoch: 4

Epoch: 50 | Batch: 003 / 019 | Total loss: 4.086 | Reg loss: 0.023 | Tree loss: 4.086 | Accuracy: 0.091797 | 0.855 sec/iter
Epoch: 50 | Batch: 004 / 019 | Total loss: 4.071 | Reg loss: 0.023 | Tree loss: 4.071 | Accuracy: 0.074219 | 0.854 sec/iter
Epoch: 50 | Batch: 005 / 019 | Total loss: 4.079 | Reg loss: 0.023 | Tree loss: 4.079 | Accuracy: 0.097656 | 0.854 sec/iter
Epoch: 50 | Batch: 006 / 019 | Total loss: 4.091 | Reg loss: 0.023 | Tree loss: 4.091 | Accuracy: 0.068359 | 0.854 sec/iter
Epoch: 50 | Batch: 007 / 019 | Total loss: 4.090 | Reg loss: 0.023 | Tree loss: 4.090 | Accuracy: 0.072266 | 0.854 sec/iter
Epoch: 50 | Batch: 008 / 019 | Total loss: 4.075 | Reg loss: 0.023 | Tree loss: 4.075 | Accuracy: 0.070312 | 0.854 sec/iter
Epoch: 50 | Batch: 009 / 019 | Total loss: 4.067 | Reg loss: 0.023 | Tree loss: 4.067 | Accuracy: 0.068359 | 0.854 sec/iter
Epoch: 50 | Batch: 010 / 019 | Total loss: 4.094 | Reg loss: 0.023 | Tree loss: 4.094 | Accuracy: 0.058594 | 0.854 sec/iter
Epoch: 5

Epoch: 53 | Batch: 006 / 019 | Total loss: 3.969 | Reg loss: 0.023 | Tree loss: 3.969 | Accuracy: 0.087891 | 0.856 sec/iter
Epoch: 53 | Batch: 007 / 019 | Total loss: 3.985 | Reg loss: 0.023 | Tree loss: 3.985 | Accuracy: 0.080078 | 0.856 sec/iter
Epoch: 53 | Batch: 008 / 019 | Total loss: 3.973 | Reg loss: 0.023 | Tree loss: 3.973 | Accuracy: 0.080078 | 0.856 sec/iter
Epoch: 53 | Batch: 009 / 019 | Total loss: 3.948 | Reg loss: 0.023 | Tree loss: 3.948 | Accuracy: 0.105469 | 0.856 sec/iter
Epoch: 53 | Batch: 010 / 019 | Total loss: 3.988 | Reg loss: 0.023 | Tree loss: 3.988 | Accuracy: 0.070312 | 0.856 sec/iter
Epoch: 53 | Batch: 011 / 019 | Total loss: 3.961 | Reg loss: 0.023 | Tree loss: 3.961 | Accuracy: 0.082031 | 0.856 sec/iter
Epoch: 53 | Batch: 012 / 019 | Total loss: 3.997 | Reg loss: 0.023 | Tree loss: 3.997 | Accuracy: 0.066406 | 0.856 sec/iter
Epoch: 53 | Batch: 013 / 019 | Total loss: 3.957 | Reg loss: 0.023 | Tree loss: 3.957 | Accuracy: 0.072266 | 0.856 sec/iter
Epoch: 5

Epoch: 56 | Batch: 009 / 019 | Total loss: 3.897 | Reg loss: 0.023 | Tree loss: 3.897 | Accuracy: 0.064453 | 0.857 sec/iter
Epoch: 56 | Batch: 010 / 019 | Total loss: 3.886 | Reg loss: 0.023 | Tree loss: 3.886 | Accuracy: 0.070312 | 0.857 sec/iter
Epoch: 56 | Batch: 011 / 019 | Total loss: 3.880 | Reg loss: 0.023 | Tree loss: 3.880 | Accuracy: 0.068359 | 0.857 sec/iter
Epoch: 56 | Batch: 012 / 019 | Total loss: 3.901 | Reg loss: 0.023 | Tree loss: 3.901 | Accuracy: 0.068359 | 0.857 sec/iter
Epoch: 56 | Batch: 013 / 019 | Total loss: 3.895 | Reg loss: 0.023 | Tree loss: 3.895 | Accuracy: 0.076172 | 0.857 sec/iter
Epoch: 56 | Batch: 014 / 019 | Total loss: 3.901 | Reg loss: 0.023 | Tree loss: 3.901 | Accuracy: 0.074219 | 0.856 sec/iter
Epoch: 56 | Batch: 015 / 019 | Total loss: 3.883 | Reg loss: 0.023 | Tree loss: 3.883 | Accuracy: 0.076172 | 0.856 sec/iter
Epoch: 56 | Batch: 016 / 019 | Total loss: 3.860 | Reg loss: 0.023 | Tree loss: 3.860 | Accuracy: 0.103516 | 0.856 sec/iter
Epoch: 5

Epoch: 59 | Batch: 012 / 019 | Total loss: 3.843 | Reg loss: 0.023 | Tree loss: 3.843 | Accuracy: 0.066406 | 0.858 sec/iter
Epoch: 59 | Batch: 013 / 019 | Total loss: 3.825 | Reg loss: 0.023 | Tree loss: 3.825 | Accuracy: 0.089844 | 0.858 sec/iter
Epoch: 59 | Batch: 014 / 019 | Total loss: 3.824 | Reg loss: 0.023 | Tree loss: 3.824 | Accuracy: 0.072266 | 0.858 sec/iter
Epoch: 59 | Batch: 015 / 019 | Total loss: 3.828 | Reg loss: 0.023 | Tree loss: 3.828 | Accuracy: 0.080078 | 0.858 sec/iter
Epoch: 59 | Batch: 016 / 019 | Total loss: 3.838 | Reg loss: 0.023 | Tree loss: 3.838 | Accuracy: 0.060547 | 0.858 sec/iter
Epoch: 59 | Batch: 017 / 019 | Total loss: 3.783 | Reg loss: 0.023 | Tree loss: 3.783 | Accuracy: 0.072266 | 0.857 sec/iter
Epoch: 59 | Batch: 018 / 019 | Total loss: 3.802 | Reg loss: 0.023 | Tree loss: 3.802 | Accuracy: 0.052265 | 0.857 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 62 | Batch: 015 / 019 | Total loss: 3.759 | Reg loss: 0.023 | Tree loss: 3.759 | Accuracy: 0.074219 | 0.858 sec/iter
Epoch: 62 | Batch: 016 / 019 | Total loss: 3.739 | Reg loss: 0.023 | Tree loss: 3.739 | Accuracy: 0.082031 | 0.858 sec/iter
Epoch: 62 | Batch: 017 / 019 | Total loss: 3.770 | Reg loss: 0.023 | Tree loss: 3.770 | Accuracy: 0.074219 | 0.858 sec/iter
Epoch: 62 | Batch: 018 / 019 | Total loss: 3.723 | Reg loss: 0.023 | Tree loss: 3.723 | Accuracy: 0.090592 | 0.858 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 63 | Batch: 000 / 019 | Total loss: 3.794 | Reg loss: 0.023 | Tree loss: 3.794 | Accuracy: 0.070312 | 0.859 sec/iter
Epoch: 63 | Batch: 001 / 019 | Total loss: 3.792 | Reg loss: 0.023 | Tree loss: 3.792 | Ac

Epoch: 65 | Batch: 018 / 019 | Total loss: 3.731 | Reg loss: 0.023 | Tree loss: 3.731 | Accuracy: 0.052265 | 0.859 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 66 | Batch: 000 / 019 | Total loss: 3.703 | Reg loss: 0.023 | Tree loss: 3.703 | Accuracy: 0.066406 | 0.86 sec/iter
Epoch: 66 | Batch: 001 / 019 | Total loss: 3.741 | Reg loss: 0.023 | Tree loss: 3.741 | Accuracy: 0.076172 | 0.86 sec/iter
Epoch: 66 | Batch: 002 / 019 | Total loss: 3.707 | Reg loss: 0.023 | Tree loss: 3.707 | Accuracy: 0.076172 | 0.86 sec/iter
Epoch: 66 | Batch: 003 / 019 | Total loss: 3.708 | Reg loss: 0.023 | Tree loss: 3.708 | Accuracy: 0.072266 | 0.86 sec/iter
Epoch: 66 | Batch: 004 / 019 | Total loss: 3.721 | Reg loss: 0.023 | Tree loss: 3.721 | Accura

Epoch: 69 | Batch: 000 / 019 | Total loss: 3.686 | Reg loss: 0.023 | Tree loss: 3.686 | Accuracy: 0.070312 | 0.86 sec/iter
Epoch: 69 | Batch: 001 / 019 | Total loss: 3.665 | Reg loss: 0.023 | Tree loss: 3.665 | Accuracy: 0.091797 | 0.86 sec/iter
Epoch: 69 | Batch: 002 / 019 | Total loss: 3.692 | Reg loss: 0.023 | Tree loss: 3.692 | Accuracy: 0.070312 | 0.86 sec/iter
Epoch: 69 | Batch: 003 / 019 | Total loss: 3.638 | Reg loss: 0.023 | Tree loss: 3.638 | Accuracy: 0.085938 | 0.86 sec/iter
Epoch: 69 | Batch: 004 / 019 | Total loss: 3.677 | Reg loss: 0.023 | Tree loss: 3.677 | Accuracy: 0.082031 | 0.86 sec/iter
Epoch: 69 | Batch: 005 / 019 | Total loss: 3.656 | Reg loss: 0.023 | Tree loss: 3.656 | Accuracy: 0.089844 | 0.86 sec/iter
Epoch: 69 | Batch: 006 / 019 | Total loss: 3.711 | Reg loss: 0.023 | Tree loss: 3.711 | Accuracy: 0.052734 | 0.86 sec/iter
Epoch: 69 | Batch: 007 / 019 | Total loss: 3.702 | Reg loss: 0.023 | Tree loss: 3.702 | Accuracy: 0.058594 | 0.86 sec/iter
Epoch: 69 | Batc

Epoch: 72 | Batch: 003 / 019 | Total loss: 3.635 | Reg loss: 0.023 | Tree loss: 3.635 | Accuracy: 0.078125 | 0.86 sec/iter
Epoch: 72 | Batch: 004 / 019 | Total loss: 3.657 | Reg loss: 0.023 | Tree loss: 3.657 | Accuracy: 0.087891 | 0.86 sec/iter
Epoch: 72 | Batch: 005 / 019 | Total loss: 3.617 | Reg loss: 0.023 | Tree loss: 3.617 | Accuracy: 0.087891 | 0.86 sec/iter
Epoch: 72 | Batch: 006 / 019 | Total loss: 3.628 | Reg loss: 0.023 | Tree loss: 3.628 | Accuracy: 0.068359 | 0.86 sec/iter
Epoch: 72 | Batch: 007 / 019 | Total loss: 3.684 | Reg loss: 0.023 | Tree loss: 3.684 | Accuracy: 0.070312 | 0.86 sec/iter
Epoch: 72 | Batch: 008 / 019 | Total loss: 3.636 | Reg loss: 0.023 | Tree loss: 3.636 | Accuracy: 0.085938 | 0.86 sec/iter
Epoch: 72 | Batch: 009 / 019 | Total loss: 3.620 | Reg loss: 0.023 | Tree loss: 3.620 | Accuracy: 0.070312 | 0.86 sec/iter
Epoch: 72 | Batch: 010 / 019 | Total loss: 3.647 | Reg loss: 0.023 | Tree loss: 3.647 | Accuracy: 0.076172 | 0.86 sec/iter
Epoch: 72 | Batc

Epoch: 75 | Batch: 006 / 019 | Total loss: 3.610 | Reg loss: 0.023 | Tree loss: 3.610 | Accuracy: 0.076172 | 0.861 sec/iter
Epoch: 75 | Batch: 007 / 019 | Total loss: 3.604 | Reg loss: 0.023 | Tree loss: 3.604 | Accuracy: 0.093750 | 0.861 sec/iter
Epoch: 75 | Batch: 008 / 019 | Total loss: 3.589 | Reg loss: 0.023 | Tree loss: 3.589 | Accuracy: 0.091797 | 0.86 sec/iter
Epoch: 75 | Batch: 009 / 019 | Total loss: 3.617 | Reg loss: 0.023 | Tree loss: 3.617 | Accuracy: 0.082031 | 0.86 sec/iter
Epoch: 75 | Batch: 010 / 019 | Total loss: 3.604 | Reg loss: 0.023 | Tree loss: 3.604 | Accuracy: 0.082031 | 0.86 sec/iter
Epoch: 75 | Batch: 011 / 019 | Total loss: 3.615 | Reg loss: 0.023 | Tree loss: 3.615 | Accuracy: 0.058594 | 0.86 sec/iter
Epoch: 75 | Batch: 012 / 019 | Total loss: 3.607 | Reg loss: 0.023 | Tree loss: 3.607 | Accuracy: 0.068359 | 0.86 sec/iter
Epoch: 75 | Batch: 013 / 019 | Total loss: 3.610 | Reg loss: 0.023 | Tree loss: 3.610 | Accuracy: 0.091797 | 0.86 sec/iter
Epoch: 75 | Ba

Epoch: 78 | Batch: 009 / 019 | Total loss: 3.600 | Reg loss: 0.023 | Tree loss: 3.600 | Accuracy: 0.087891 | 0.861 sec/iter
Epoch: 78 | Batch: 010 / 019 | Total loss: 3.611 | Reg loss: 0.023 | Tree loss: 3.611 | Accuracy: 0.078125 | 0.861 sec/iter
Epoch: 78 | Batch: 011 / 019 | Total loss: 3.589 | Reg loss: 0.023 | Tree loss: 3.589 | Accuracy: 0.083984 | 0.861 sec/iter
Epoch: 78 | Batch: 012 / 019 | Total loss: 3.582 | Reg loss: 0.023 | Tree loss: 3.582 | Accuracy: 0.083984 | 0.861 sec/iter
Epoch: 78 | Batch: 013 / 019 | Total loss: 3.611 | Reg loss: 0.023 | Tree loss: 3.611 | Accuracy: 0.072266 | 0.861 sec/iter
Epoch: 78 | Batch: 014 / 019 | Total loss: 3.577 | Reg loss: 0.023 | Tree loss: 3.577 | Accuracy: 0.078125 | 0.861 sec/iter
Epoch: 78 | Batch: 015 / 019 | Total loss: 3.618 | Reg loss: 0.023 | Tree loss: 3.618 | Accuracy: 0.072266 | 0.861 sec/iter
Epoch: 78 | Batch: 016 / 019 | Total loss: 3.582 | Reg loss: 0.023 | Tree loss: 3.582 | Accuracy: 0.082031 | 0.861 sec/iter
Epoch: 7

Epoch: 81 | Batch: 012 / 019 | Total loss: 3.545 | Reg loss: 0.023 | Tree loss: 3.545 | Accuracy: 0.080078 | 0.861 sec/iter
Epoch: 81 | Batch: 013 / 019 | Total loss: 3.575 | Reg loss: 0.023 | Tree loss: 3.575 | Accuracy: 0.080078 | 0.861 sec/iter
Epoch: 81 | Batch: 014 / 019 | Total loss: 3.559 | Reg loss: 0.023 | Tree loss: 3.559 | Accuracy: 0.078125 | 0.861 sec/iter
Epoch: 81 | Batch: 015 / 019 | Total loss: 3.602 | Reg loss: 0.023 | Tree loss: 3.602 | Accuracy: 0.074219 | 0.861 sec/iter
Epoch: 81 | Batch: 016 / 019 | Total loss: 3.554 | Reg loss: 0.023 | Tree loss: 3.554 | Accuracy: 0.089844 | 0.861 sec/iter
Epoch: 81 | Batch: 017 / 019 | Total loss: 3.589 | Reg loss: 0.023 | Tree loss: 3.589 | Accuracy: 0.083984 | 0.861 sec/iter
Epoch: 81 | Batch: 018 / 019 | Total loss: 3.555 | Reg loss: 0.023 | Tree loss: 3.555 | Accuracy: 0.062718 | 0.861 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 84 | Batch: 015 / 019 | Total loss: 3.508 | Reg loss: 0.023 | Tree loss: 3.508 | Accuracy: 0.087891 | 0.861 sec/iter
Epoch: 84 | Batch: 016 / 019 | Total loss: 3.541 | Reg loss: 0.023 | Tree loss: 3.541 | Accuracy: 0.078125 | 0.861 sec/iter
Epoch: 84 | Batch: 017 / 019 | Total loss: 3.577 | Reg loss: 0.023 | Tree loss: 3.577 | Accuracy: 0.082031 | 0.861 sec/iter
Epoch: 84 | Batch: 018 / 019 | Total loss: 3.550 | Reg loss: 0.023 | Tree loss: 3.550 | Accuracy: 0.062718 | 0.861 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 85 | Batch: 000 / 019 | Total loss: 3.553 | Reg loss: 0.023 | Tree loss: 3.553 | Accuracy: 0.111328 | 0.862 sec/iter
Epoch: 85 | Batch: 001 / 019 | Total loss: 3.538 | Reg loss: 0.023 | Tree loss: 3.538 | Ac

Epoch: 87 | Batch: 018 / 019 | Total loss: 3.557 | Reg loss: 0.023 | Tree loss: 3.557 | Accuracy: 0.090592 | 0.861 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 88 | Batch: 000 / 019 | Total loss: 3.543 | Reg loss: 0.023 | Tree loss: 3.543 | Accuracy: 0.083984 | 0.862 sec/iter
Epoch: 88 | Batch: 001 / 019 | Total loss: 3.555 | Reg loss: 0.023 | Tree loss: 3.555 | Accuracy: 0.048828 | 0.862 sec/iter
Epoch: 88 | Batch: 002 / 019 | Total loss: 3.495 | Reg loss: 0.023 | Tree loss: 3.495 | Accuracy: 0.107422 | 0.862 sec/iter
Epoch: 88 | Batch: 003 / 019 | Total loss: 3.537 | Reg loss: 0.023 | Tree loss: 3.537 | Accuracy: 0.097656 | 0.862 sec/iter
Epoch: 88 | Batch: 004 / 019 | Total loss: 3.551 | Reg loss: 0.023 | Tree loss: 3.551 | Ac

Epoch: 91 | Batch: 000 / 019 | Total loss: 3.546 | Reg loss: 0.023 | Tree loss: 3.546 | Accuracy: 0.066406 | 0.862 sec/iter
Epoch: 91 | Batch: 001 / 019 | Total loss: 3.534 | Reg loss: 0.023 | Tree loss: 3.534 | Accuracy: 0.068359 | 0.862 sec/iter
Epoch: 91 | Batch: 002 / 019 | Total loss: 3.515 | Reg loss: 0.023 | Tree loss: 3.515 | Accuracy: 0.080078 | 0.862 sec/iter
Epoch: 91 | Batch: 003 / 019 | Total loss: 3.517 | Reg loss: 0.023 | Tree loss: 3.517 | Accuracy: 0.080078 | 0.862 sec/iter
Epoch: 91 | Batch: 004 / 019 | Total loss: 3.542 | Reg loss: 0.023 | Tree loss: 3.542 | Accuracy: 0.078125 | 0.862 sec/iter
Epoch: 91 | Batch: 005 / 019 | Total loss: 3.552 | Reg loss: 0.023 | Tree loss: 3.552 | Accuracy: 0.074219 | 0.862 sec/iter
Epoch: 91 | Batch: 006 / 019 | Total loss: 3.511 | Reg loss: 0.023 | Tree loss: 3.511 | Accuracy: 0.095703 | 0.862 sec/iter
Epoch: 91 | Batch: 007 / 019 | Total loss: 3.538 | Reg loss: 0.023 | Tree loss: 3.538 | Accuracy: 0.080078 | 0.862 sec/iter
Epoch: 9

Epoch: 94 | Batch: 003 / 019 | Total loss: 3.493 | Reg loss: 0.023 | Tree loss: 3.493 | Accuracy: 0.085938 | 0.862 sec/iter
Epoch: 94 | Batch: 004 / 019 | Total loss: 3.516 | Reg loss: 0.023 | Tree loss: 3.516 | Accuracy: 0.062500 | 0.862 sec/iter
Epoch: 94 | Batch: 005 / 019 | Total loss: 3.518 | Reg loss: 0.023 | Tree loss: 3.518 | Accuracy: 0.072266 | 0.862 sec/iter
Epoch: 94 | Batch: 006 / 019 | Total loss: 3.524 | Reg loss: 0.023 | Tree loss: 3.524 | Accuracy: 0.083984 | 0.862 sec/iter
Epoch: 94 | Batch: 007 / 019 | Total loss: 3.530 | Reg loss: 0.023 | Tree loss: 3.530 | Accuracy: 0.074219 | 0.862 sec/iter
Epoch: 94 | Batch: 008 / 019 | Total loss: 3.536 | Reg loss: 0.023 | Tree loss: 3.536 | Accuracy: 0.064453 | 0.862 sec/iter
Epoch: 94 | Batch: 009 / 019 | Total loss: 3.517 | Reg loss: 0.023 | Tree loss: 3.517 | Accuracy: 0.062500 | 0.862 sec/iter
Epoch: 94 | Batch: 010 / 019 | Total loss: 3.532 | Reg loss: 0.023 | Tree loss: 3.532 | Accuracy: 0.085938 | 0.862 sec/iter
Epoch: 9

Epoch: 97 | Batch: 006 / 019 | Total loss: 3.495 | Reg loss: 0.023 | Tree loss: 3.495 | Accuracy: 0.087891 | 0.863 sec/iter
Epoch: 97 | Batch: 007 / 019 | Total loss: 3.530 | Reg loss: 0.023 | Tree loss: 3.530 | Accuracy: 0.074219 | 0.863 sec/iter
Epoch: 97 | Batch: 008 / 019 | Total loss: 3.489 | Reg loss: 0.023 | Tree loss: 3.489 | Accuracy: 0.074219 | 0.863 sec/iter
Epoch: 97 | Batch: 009 / 019 | Total loss: 3.492 | Reg loss: 0.023 | Tree loss: 3.492 | Accuracy: 0.078125 | 0.863 sec/iter
Epoch: 97 | Batch: 010 / 019 | Total loss: 3.519 | Reg loss: 0.023 | Tree loss: 3.519 | Accuracy: 0.078125 | 0.863 sec/iter
Epoch: 97 | Batch: 011 / 019 | Total loss: 3.505 | Reg loss: 0.023 | Tree loss: 3.505 | Accuracy: 0.087891 | 0.863 sec/iter
Epoch: 97 | Batch: 012 / 019 | Total loss: 3.522 | Reg loss: 0.023 | Tree loss: 3.522 | Accuracy: 0.068359 | 0.863 sec/iter
Epoch: 97 | Batch: 013 / 019 | Total loss: 3.510 | Reg loss: 0.023 | Tree loss: 3.510 | Accuracy: 0.087891 | 0.863 sec/iter
Epoch: 9

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 10.0


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 1024


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))






9503


Average comprehensibility: 44.109375
std comprehensibility: 2.418142284766345
var comprehensibility: 5.847412109375
minimum comprehensibility: 40
maximum comprehensibility: 52
