In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 6
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.164664268493652 | KNN Loss: 6.227475166320801 | BCE Loss: 1.937188982963562
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.172765731811523 | KNN Loss: 6.227191925048828 | BCE Loss: 1.9455742835998535
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.190939903259277 | KNN Loss: 6.227177143096924 | BCE Loss: 1.9637631177902222
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.146702766418457 | KNN Loss: 6.226847171783447 | BCE Loss: 1.9198553562164307
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.167062759399414 | KNN Loss: 6.2267584800720215 | BCE Loss: 1.9403047561645508
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.134511947631836 | KNN Loss: 6.226141452789307 | BCE Loss: 1.908370018005371
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.144550323486328 | KNN Loss: 6.225564002990723 | BCE Loss: 1.9189860820770264
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.135579109191895 | KNN Loss: 6.225414752960205 | BCE Loss: 1.910163

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.511267185211182 | KNN Loss: 4.427506923675537 | BCE Loss: 1.0837602615356445
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.47360897064209 | KNN Loss: 4.334427833557129 | BCE Loss: 1.139181137084961
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.351345062255859 | KNN Loss: 4.25172233581543 | BCE Loss: 1.0996226072311401
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.377434730529785 | KNN Loss: 4.280810356140137 | BCE Loss: 1.0966246128082275
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.298580169677734 | KNN Loss: 4.186285972595215 | BCE Loss: 1.1122944355010986
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.240612030029297 | KNN Loss: 4.130659580230713 | BCE Loss: 1.1099525690078735
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.193686485290527 | KNN Loss: 4.103450775146484 | BCE Loss: 1.090235710144043
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.167356967926025 | KNN Loss: 4.063889980316162 | BCE Loss: 1.

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.867629528045654 | KNN Loss: 3.8126978874206543 | BCE Loss: 1.0549317598342896
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.894550800323486 | KNN Loss: 3.8383522033691406 | BCE Loss: 1.0561985969543457
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.868419647216797 | KNN Loss: 3.809732437133789 | BCE Loss: 1.058687448501587
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.86475944519043 | KNN Loss: 3.838819742202759 | BCE Loss: 1.0259398221969604
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.947162628173828 | KNN Loss: 3.889233350753784 | BCE Loss: 1.057929515838623
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.845092296600342 | KNN Loss: 3.7960262298583984 | BCE Loss: 1.0490660667419434
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.894075393676758 | KNN Loss: 3.827622175216675 | BCE Loss: 1.066453456878662
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.861667633056641 | KNN Loss: 3.7902114391326904 | BCE Los

Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 4.831170558929443 | KNN Loss: 3.7563765048980713 | BCE Loss: 1.074794054031372
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.797304153442383 | KNN Loss: 3.7578086853027344 | BCE Loss: 1.0394952297210693
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.778392791748047 | KNN Loss: 3.75398850440979 | BCE Loss: 1.0244042873382568
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.798393249511719 | KNN Loss: 3.75352144241333 | BCE Loss: 1.0448715686798096
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.7605390548706055 | KNN Loss: 3.7299633026123047 | BCE Loss: 1.0305759906768799
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.825061321258545 | KNN Loss: 3.779970407485962 | BCE Loss: 1.0450907945632935
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.787725448608398 | KNN Loss: 3.7464547157287598 | BCE Loss: 1.0412707328796387
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.800629615783691 | KNN Loss: 3.7373907566070557 | BCE 

Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 4.838930130004883 | KNN Loss: 3.808166027069092 | BCE Loss: 1.0307643413543701
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.802123069763184 | KNN Loss: 3.764037847518921 | BCE Loss: 1.0380849838256836
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.778770923614502 | KNN Loss: 3.7332231998443604 | BCE Loss: 1.0455477237701416
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.752194881439209 | KNN Loss: 3.7203030586242676 | BCE Loss: 1.0318917036056519
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.779771327972412 | KNN Loss: 3.7534127235412598 | BCE Loss: 1.0263586044311523
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.772917747497559 | KNN Loss: 3.726386070251465 | BCE Loss: 1.0465319156646729
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.778808116912842 | KNN Loss: 3.7498276233673096 | BCE Loss: 1.0289804935455322
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.833897590637207 | KNN Loss: 3.7689502239227295 | BCE

Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 4.769696235656738 | KNN Loss: 3.7385637760162354 | BCE Loss: 1.031132698059082
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.792182922363281 | KNN Loss: 3.75553560256958 | BCE Loss: 1.0366473197937012
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.7573442459106445 | KNN Loss: 3.721357822418213 | BCE Loss: 1.0359864234924316
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.750117778778076 | KNN Loss: 3.7282280921936035 | BCE Loss: 1.0218896865844727
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.801402568817139 | KNN Loss: 3.74560546875 | BCE Loss: 1.0557971000671387
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.743971824645996 | KNN Loss: 3.7212440967559814 | BCE Loss: 1.0227277278900146
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.8330583572387695 | KNN Loss: 3.7855491638183594 | BCE Loss: 1.047508955001831
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.7697343826293945 | KNN Loss: 3.7485311031341553 | BCE Loss

Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 4.750327110290527 | KNN Loss: 3.7059431076049805 | BCE Loss: 1.044384241104126
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 4.752415657043457 | KNN Loss: 3.7225630283355713 | BCE Loss: 1.0298526287078857
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.778557300567627 | KNN Loss: 3.7284209728240967 | BCE Loss: 1.0501363277435303
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.766260147094727 | KNN Loss: 3.7340734004974365 | BCE Loss: 1.0321866273880005
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.7468438148498535 | KNN Loss: 3.716355800628662 | BCE Loss: 1.0304880142211914
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.746288299560547 | KNN Loss: 3.7063331604003906 | BCE Loss: 1.0399550199508667
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.728255271911621 | KNN Loss: 3.7093734741210938 | BCE Loss: 1.0188815593719482
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.7857232093811035 | KNN Loss: 3.748563528060913 | 

Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 4.71764612197876 | KNN Loss: 3.691009759902954 | BCE Loss: 1.0266363620758057
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.750609874725342 | KNN Loss: 3.715437650680542 | BCE Loss: 1.0351722240447998
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.739856719970703 | KNN Loss: 3.7106285095214844 | BCE Loss: 1.0292282104492188
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.816424369812012 | KNN Loss: 3.7428250312805176 | BCE Loss: 1.0735994577407837
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.742466926574707 | KNN Loss: 3.710428476333618 | BCE Loss: 1.0320383310317993
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.770359039306641 | KNN Loss: 3.7363409996032715 | BCE Loss: 1.0340182781219482
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.751928329467773 | KNN Loss: 3.716123580932617 | BCE Loss: 1.0358047485351562
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.769192695617676 | KNN Loss: 3.7200663089752197 | BCE 

Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 4.734511375427246 | KNN Loss: 3.718853712081909 | BCE Loss: 1.015657901763916
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.741369247436523 | KNN Loss: 3.7101144790649414 | BCE Loss: 1.031254768371582
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.758678436279297 | KNN Loss: 3.7331881523132324 | BCE Loss: 1.0254902839660645
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.735609531402588 | KNN Loss: 3.7117559909820557 | BCE Loss: 1.0238536596298218
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.721056938171387 | KNN Loss: 3.7231664657592773 | BCE Loss: 0.9978905916213989
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.763482093811035 | KNN Loss: 3.733769655227661 | BCE Loss: 1.0297125577926636
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.735794544219971 | KNN Loss: 3.7286901473999023 | BCE Loss: 1.0071043968200684
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 4.729981422424316 | KNN Loss: 3.724855661392212 | BCE 

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.745556831359863 | KNN Loss: 3.7185747623443604 | BCE Loss: 1.026982307434082
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.793820381164551 | KNN Loss: 3.752561569213867 | BCE Loss: 1.0412588119506836
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.72690486907959 | KNN Loss: 3.6875112056732178 | BCE Loss: 1.039393424987793
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.762628555297852 | KNN Loss: 3.7398593425750732 | BCE Loss: 1.0227689743041992
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.753249645233154 | KNN Loss: 3.7064459323883057 | BCE Loss: 1.046803593635559
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.74613618850708 | KNN Loss: 3.692047119140625 | BCE Loss: 1.054089069366455
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.746635437011719 | KNN Loss: 3.700798273086548 | BCE Loss: 1.0458370447158813
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 4.769066333770752 | KNN Loss: 3.726808786392212 | BCE Loss: 

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.743544578552246 | KNN Loss: 3.681439161300659 | BCE Loss: 1.062105655670166
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.766185760498047 | KNN Loss: 3.7034788131713867 | BCE Loss: 1.0627071857452393
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.791109085083008 | KNN Loss: 3.7352957725524902 | BCE Loss: 1.0558130741119385
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.735445022583008 | KNN Loss: 3.7206239700317383 | BCE Loss: 1.0148212909698486
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.695124626159668 | KNN Loss: 3.6845762729644775 | BCE Loss: 1.0105481147766113
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.725499153137207 | KNN Loss: 3.697798728942871 | BCE Loss: 1.0277005434036255
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 4.75405740737915 | KNN Loss: 3.7240090370178223 | BCE Loss: 1.0300483703613281
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 4.700139999389648 | KNN Loss: 3.700690746307373

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.783628463745117 | KNN Loss: 3.7377724647521973 | BCE Loss: 1.0458557605743408
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.743024826049805 | KNN Loss: 3.6964240074157715 | BCE Loss: 1.0466006994247437
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.739471435546875 | KNN Loss: 3.7177317142486572 | BCE Loss: 1.0217397212982178
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.763558387756348 | KNN Loss: 3.7154903411865234 | BCE Loss: 1.0480680465698242
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.747273921966553 | KNN Loss: 3.7147529125213623 | BCE Loss: 1.0325210094451904
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.701626777648926 | KNN Loss: 3.6998910903930664 | BCE Loss: 1.0017355680465698
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.729096412658691 | KNN Loss: 3.7122371196746826 | BCE Loss: 1.0168592929840088
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 4.738917350769043 | KNN Loss: 3.6905372142

Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 4.731118202209473 | KNN Loss: 3.707352876663208 | BCE Loss: 1.0237655639648438
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.7231974601745605 | KNN Loss: 3.689502716064453 | BCE Loss: 1.0336946249008179
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.7765984535217285 | KNN Loss: 3.695077896118164 | BCE Loss: 1.0815205574035645
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.744012832641602 | KNN Loss: 3.713778257369995 | BCE Loss: 1.0302345752716064
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.722775936126709 | KNN Loss: 3.6860833168029785 | BCE Loss: 1.0366926193237305
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.740422248840332 | KNN Loss: 3.7133355140686035 | BCE Loss: 1.0270867347717285
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.711063861846924 | KNN Loss: 3.6855080127716064 | BCE Loss: 1.0255558490753174
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.722482204437256 | KNN Loss: 3.70292925834655

Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 4.691895484924316 | KNN Loss: 3.6823508739471436 | BCE Loss: 1.0095444917678833
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.69080924987793 | KNN Loss: 3.6834664344787598 | BCE Loss: 1.007343053817749
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.739829063415527 | KNN Loss: 3.700241804122925 | BCE Loss: 1.0395870208740234
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.762032508850098 | KNN Loss: 3.726933479309082 | BCE Loss: 1.0350992679595947
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.701656341552734 | KNN Loss: 3.6810414791107178 | BCE Loss: 1.020614743232727
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.736347675323486 | KNN Loss: 3.7102222442626953 | BCE Loss: 1.0261255502700806
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.748898506164551 | KNN Loss: 3.7047295570373535 | BCE Loss: 1.0441687107086182
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.734158992767334 | KNN Loss: 3.688397884368896

Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 4.726978302001953 | KNN Loss: 3.6990952491760254 | BCE Loss: 1.0278832912445068
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.730175018310547 | KNN Loss: 3.7058253288269043 | BCE Loss: 1.0243499279022217
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.750066757202148 | KNN Loss: 3.7111928462982178 | BCE Loss: 1.0388739109039307
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.778169631958008 | KNN Loss: 3.717513084411621 | BCE Loss: 1.0606567859649658
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.667409420013428 | KNN Loss: 3.6733477115631104 | BCE Loss: 0.9940618276596069
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.783348083496094 | KNN Loss: 3.7502939701080322 | BCE Loss: 1.0330541133880615
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.704705238342285 | KNN Loss: 3.6838345527648926 | BCE Loss: 1.020870566368103
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.74923038482666 | KNN Loss: 3.7266347408294

Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 4.720484733581543 | KNN Loss: 3.7381415367126465 | BCE Loss: 0.9823431968688965
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.719350814819336 | KNN Loss: 3.724574327468872 | BCE Loss: 0.9947762489318848
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.696784973144531 | KNN Loss: 3.6837544441223145 | BCE Loss: 1.0130305290222168
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.747257709503174 | KNN Loss: 3.7201690673828125 | BCE Loss: 1.0270886421203613
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.718799591064453 | KNN Loss: 3.68786358833313 | BCE Loss: 1.0309357643127441
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.737159252166748 | KNN Loss: 3.6986985206604004 | BCE Loss: 1.0384607315063477
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.769438743591309 | KNN Loss: 3.712846279144287 | BCE Loss: 1.0565927028656006
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 4.71798038482666 | KNN Loss: 3.6960742473602295 

Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 4.738338947296143 | KNN Loss: 3.7039976119995117 | BCE Loss: 1.0343414545059204
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.732970714569092 | KNN Loss: 3.707430839538574 | BCE Loss: 1.0255398750305176
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.76054573059082 | KNN Loss: 3.707520008087158 | BCE Loss: 1.0530259609222412
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.693306922912598 | KNN Loss: 3.6857004165649414 | BCE Loss: 1.0076066255569458
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.778915882110596 | KNN Loss: 3.742236614227295 | BCE Loss: 1.0366792678833008
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.733243942260742 | KNN Loss: 3.7092552185058594 | BCE Loss: 1.0239884853363037
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.717403888702393 | KNN Loss: 3.691443681716919 | BCE Loss: 1.0259602069854736
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 4.763835906982422 | KNN Loss: 3.724456071853637

Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 4.777622699737549 | KNN Loss: 3.740391492843628 | BCE Loss: 1.037231206893921
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.732438087463379 | KNN Loss: 3.68967342376709 | BCE Loss: 1.0427649021148682
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.704916000366211 | KNN Loss: 3.6908280849456787 | BCE Loss: 1.0140880346298218
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.771174430847168 | KNN Loss: 3.725107431411743 | BCE Loss: 1.046067237854004
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.725032329559326 | KNN Loss: 3.6630303859710693 | BCE Loss: 1.0620019435882568
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.743923664093018 | KNN Loss: 3.7194745540618896 | BCE Loss: 1.024449110031128
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 4.750846862792969 | KNN Loss: 3.72432541847229 | BCE Loss: 1.0265214443206787
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 4.690490245819092 | KNN Loss: 3.7020835876464844 | 

Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 4.692324161529541 | KNN Loss: 3.6948978900909424 | BCE Loss: 0.9974261522293091
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.749891757965088 | KNN Loss: 3.7270169258117676 | BCE Loss: 1.0228747129440308
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.7060651779174805 | KNN Loss: 3.7034506797790527 | BCE Loss: 1.0026146173477173
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.730594635009766 | KNN Loss: 3.688636541366577 | BCE Loss: 1.0419583320617676
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.700528621673584 | KNN Loss: 3.6944522857666016 | BCE Loss: 1.0060762166976929
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.676318168640137 | KNN Loss: 3.670138359069824 | BCE Loss: 1.006179690361023
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 4.693743705749512 | KNN Loss: 3.68746280670166 | BCE Loss: 1.0062806606292725
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 4.776888847351074 | KNN Loss: 3.7225799560546875

Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 4.700628280639648 | KNN Loss: 3.691561222076416 | BCE Loss: 1.0090668201446533
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.697819232940674 | KNN Loss: 3.67606258392334 | BCE Loss: 1.021756649017334
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.726080417633057 | KNN Loss: 3.70048451423645 | BCE Loss: 1.025596022605896
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.714714050292969 | KNN Loss: 3.6922948360443115 | BCE Loss: 1.0224194526672363
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.7030768394470215 | KNN Loss: 3.7001852989196777 | BCE Loss: 1.0028916597366333
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.710176467895508 | KNN Loss: 3.698385715484619 | BCE Loss: 1.0117907524108887
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 4.782822608947754 | KNN Loss: 3.708951950073242 | BCE Loss: 1.0738704204559326
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 4.726590156555176 | KNN Loss: 3.6882076263427734 |

Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 4.731794834136963 | KNN Loss: 3.6806206703186035 | BCE Loss: 1.051174283027649
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.737914085388184 | KNN Loss: 3.7077267169952393 | BCE Loss: 1.0301876068115234
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.719261646270752 | KNN Loss: 3.711268186569214 | BCE Loss: 1.0079935789108276
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.783514976501465 | KNN Loss: 3.721160411834717 | BCE Loss: 1.062354564666748
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.723406791687012 | KNN Loss: 3.685626745223999 | BCE Loss: 1.0377800464630127
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 4.725470066070557 | KNN Loss: 3.697281837463379 | BCE Loss: 1.0281881093978882
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 4.675640106201172 | KNN Loss: 3.6669130325317383 | BCE Loss: 1.0087270736694336
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 4.746035575866699 | KNN Loss: 3.704097032546997 

Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 4.695624828338623 | KNN Loss: 3.698115348815918 | BCE Loss: 0.9975094795227051
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.6477813720703125 | KNN Loss: 3.665881633758545 | BCE Loss: 0.9818999767303467
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.764874458312988 | KNN Loss: 3.7062900066375732 | BCE Loss: 1.0585846900939941
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.7406721115112305 | KNN Loss: 3.6761012077331543 | BCE Loss: 1.064570665359497
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.725223064422607 | KNN Loss: 3.6855783462524414 | BCE Loss: 1.039644718170166
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 4.771454811096191 | KNN Loss: 3.7212390899658203 | BCE Loss: 1.0502156019210815
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 4.742733001708984 | KNN Loss: 3.7014875411987305 | BCE Loss: 1.041245460510254
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 4.691166877746582 | KNN Loss: 3.695117712020874

Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.7099289894104 | KNN Loss: 3.693317413330078 | BCE Loss: 1.0166115760803223
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.704799175262451 | KNN Loss: 3.6972432136535645 | BCE Loss: 1.0075558423995972
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.71965217590332 | KNN Loss: 3.6739585399627686 | BCE Loss: 1.0456938743591309
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.752005577087402 | KNN Loss: 3.718529224395752 | BCE Loss: 1.0334763526916504
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.718829154968262 | KNN Loss: 3.667057752609253 | BCE Loss: 1.0517714023590088
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 4.7447710037231445 | KNN Loss: 3.6928117275238037 | BCE Loss: 1.0519593954086304
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 4.730156898498535 | KNN Loss: 3.726097583770752 | BCE Loss: 1.0040595531463623
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 4.701536178588867 | KNN Loss: 3.697327136993408 

Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.663172721862793 | KNN Loss: 3.65230393409729 | BCE Loss: 1.010869026184082
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.708013534545898 | KNN Loss: 3.7041800022125244 | BCE Loss: 1.003833293914795
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.7195916175842285 | KNN Loss: 3.702199935913086 | BCE Loss: 1.0173918008804321
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.7378740310668945 | KNN Loss: 3.680939197540283 | BCE Loss: 1.0569348335266113
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.725154399871826 | KNN Loss: 3.6858909130096436 | BCE Loss: 1.0392636060714722
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 4.730515003204346 | KNN Loss: 3.7082533836364746 | BCE Loss: 1.0222615003585815
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 4.69728946685791 | KNN Loss: 3.6866676807403564 | BCE Loss: 1.0106215476989746
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 4.72092342376709 | KNN Loss: 3.7108099460601807

Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.7534074783325195 | KNN Loss: 3.703187942504883 | BCE Loss: 1.0502197742462158
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.696640968322754 | KNN Loss: 3.6728460788726807 | BCE Loss: 1.0237948894500732
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.668384552001953 | KNN Loss: 3.6642165184020996 | BCE Loss: 1.004167914390564
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.747287750244141 | KNN Loss: 3.699702501296997 | BCE Loss: 1.0475850105285645
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.724896430969238 | KNN Loss: 3.69954252243042 | BCE Loss: 1.0253537893295288
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 4.736482620239258 | KNN Loss: 3.7071335315704346 | BCE Loss: 1.0293490886688232
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 4.725542068481445 | KNN Loss: 3.7000577449798584 | BCE Loss: 1.0254840850830078
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 4.71171760559082 | KNN Loss: 3.692807912826538 |

Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.760209083557129 | KNN Loss: 3.708247423171997 | BCE Loss: 1.0519616603851318
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.737545967102051 | KNN Loss: 3.7071492671966553 | BCE Loss: 1.0303969383239746
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.677646160125732 | KNN Loss: 3.6564791202545166 | BCE Loss: 1.0211671590805054
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.684292793273926 | KNN Loss: 3.689962863922119 | BCE Loss: 0.9943298101425171
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.703194618225098 | KNN Loss: 3.670804500579834 | BCE Loss: 1.0323903560638428
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 4.666982173919678 | KNN Loss: 3.6500494480133057 | BCE Loss: 1.0169328451156616
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 4.69497013092041 | KNN Loss: 3.67573618888855 | BCE Loss: 1.0192337036132812
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 4.715019702911377 | KNN Loss: 3.6817996501922607

Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.751784324645996 | KNN Loss: 3.703399658203125 | BCE Loss: 1.048384666442871
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.735965728759766 | KNN Loss: 3.7318644523620605 | BCE Loss: 1.0041015148162842
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.726239204406738 | KNN Loss: 3.6770248413085938 | BCE Loss: 1.0492146015167236
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.666477680206299 | KNN Loss: 3.6742289066314697 | BCE Loss: 0.9922487735748291
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.684869289398193 | KNN Loss: 3.681619882583618 | BCE Loss: 1.0032495260238647
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 4.773609638214111 | KNN Loss: 3.7388665676116943 | BCE Loss: 1.0347431898117065
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 4.712279319763184 | KNN Loss: 3.675114631652832 | BCE Loss: 1.037164568901062
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 4.724983215332031 | KNN Loss: 3.689967393875122

Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.73099946975708 | KNN Loss: 3.683333158493042 | BCE Loss: 1.047666311264038
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.742441177368164 | KNN Loss: 3.732931137084961 | BCE Loss: 1.0095100402832031
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.689058303833008 | KNN Loss: 3.6563942432403564 | BCE Loss: 1.032664179801941
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.700741767883301 | KNN Loss: 3.6658177375793457 | BCE Loss: 1.0349242687225342
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.672262191772461 | KNN Loss: 3.650599241256714 | BCE Loss: 1.021662950515747
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 4.6722517013549805 | KNN Loss: 3.667447328567505 | BCE Loss: 1.0048041343688965
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 4.7086992263793945 | KNN Loss: 3.689448118209839 | BCE Loss: 1.0192508697509766
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 4.797212600708008 | KNN Loss: 3.7146687507629395 | 

Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.708841323852539 | KNN Loss: 3.688969612121582 | BCE Loss: 1.0198719501495361
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.710621356964111 | KNN Loss: 3.674962282180786 | BCE Loss: 1.0356590747833252
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.729092597961426 | KNN Loss: 3.7069194316864014 | BCE Loss: 1.0221730470657349
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.697409629821777 | KNN Loss: 3.6514501571655273 | BCE Loss: 1.0459595918655396
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.7498369216918945 | KNN Loss: 3.714813232421875 | BCE Loss: 1.0350239276885986
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 4.6861982345581055 | KNN Loss: 3.6597402095794678 | BCE Loss: 1.0264581441879272
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 4.717418670654297 | KNN Loss: 3.700214147567749 | BCE Loss: 1.0172045230865479
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 4.70927619934082 | KNN Loss: 3.6772389411926

Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.6960248947143555 | KNN Loss: 3.6922380924224854 | BCE Loss: 1.0037866830825806
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.705151081085205 | KNN Loss: 3.688892364501953 | BCE Loss: 1.016258716583252
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.71847677230835 | KNN Loss: 3.6952755451202393 | BCE Loss: 1.0232012271881104
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.75611686706543 | KNN Loss: 3.7108566761016846 | BCE Loss: 1.0452601909637451
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.664460182189941 | KNN Loss: 3.658505439758301 | BCE Loss: 1.0059548616409302
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 4.742511749267578 | KNN Loss: 3.7012107372283936 | BCE Loss: 1.0413012504577637
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 4.71280574798584 | KNN Loss: 3.6783576011657715 | BCE Loss: 1.0344479084014893
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 4.719955921173096 | KNN Loss: 3.687764406204223

Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.71147346496582 | KNN Loss: 3.678046226501465 | BCE Loss: 1.033427357673645
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.734026908874512 | KNN Loss: 3.730025053024292 | BCE Loss: 1.0040016174316406
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.761735439300537 | KNN Loss: 3.736593723297119 | BCE Loss: 1.0251415967941284
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.717321395874023 | KNN Loss: 3.6874523162841797 | BCE Loss: 1.0298693180084229
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.723474502563477 | KNN Loss: 3.6831252574920654 | BCE Loss: 1.040349006652832
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 4.7253546714782715 | KNN Loss: 3.692889928817749 | BCE Loss: 1.032464623451233
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 4.695505619049072 | KNN Loss: 3.686382532119751 | BCE Loss: 1.0091230869293213
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 4.720823764801025 | KNN Loss: 3.691710948944092 | BC

Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.737398624420166 | KNN Loss: 3.738171339035034 | BCE Loss: 0.9992274045944214
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.729674816131592 | KNN Loss: 3.7005040645599365 | BCE Loss: 1.0291706323623657
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.722132205963135 | KNN Loss: 3.684361696243286 | BCE Loss: 1.037770390510559
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.681082725524902 | KNN Loss: 3.6716907024383545 | BCE Loss: 1.009392261505127
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.749914169311523 | KNN Loss: 3.7205870151519775 | BCE Loss: 1.029327392578125
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.7605085372924805 | KNN Loss: 3.7273824214935303 | BCE Loss: 1.0331259965896606
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 4.719786167144775 | KNN Loss: 3.692368268966675 | BCE Loss: 1.0274180173873901
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 4.730352401733398 | KNN Loss: 3.678578376770019

Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.7331647872924805 | KNN Loss: 3.7018520832061768 | BCE Loss: 1.0313127040863037
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.770681381225586 | KNN Loss: 3.7402164936065674 | BCE Loss: 1.0304646492004395
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.733707427978516 | KNN Loss: 3.7056727409362793 | BCE Loss: 1.0280349254608154
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.7370805740356445 | KNN Loss: 3.7476553916931152 | BCE Loss: 0.9894253015518188
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.730641841888428 | KNN Loss: 3.6750762462615967 | BCE Loss: 1.055565595626831
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 4.714796543121338 | KNN Loss: 3.686882257461548 | BCE Loss: 1.02791428565979
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 4.666537761688232 | KNN Loss: 3.6694324016571045 | BCE Loss: 0.9971053600311279
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 4.77365779876709 | KNN Loss: 3.70765590667724

Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.721430778503418 | KNN Loss: 3.6900882720947266 | BCE Loss: 1.0313427448272705
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.706968307495117 | KNN Loss: 3.7064082622528076 | BCE Loss: 1.0005601644515991
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.712953567504883 | KNN Loss: 3.698476791381836 | BCE Loss: 1.014477014541626
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.768014430999756 | KNN Loss: 3.7510414123535156 | BCE Loss: 1.0169730186462402
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.710958480834961 | KNN Loss: 3.686586856842041 | BCE Loss: 1.0243713855743408
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 4.769651889801025 | KNN Loss: 3.700533628463745 | BCE Loss: 1.0691182613372803
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 4.735799789428711 | KNN Loss: 3.725614547729492 | BCE Loss: 1.0101852416992188
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 4.719125747680664 | KNN Loss: 3.704738378524780

Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.7037858963012695 | KNN Loss: 3.6657140254974365 | BCE Loss: 1.038071870803833
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.705646514892578 | KNN Loss: 3.6757941246032715 | BCE Loss: 1.0298523902893066
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.696738243103027 | KNN Loss: 3.6903233528137207 | BCE Loss: 1.006414771080017
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.672262191772461 | KNN Loss: 3.6548614501953125 | BCE Loss: 1.0174009799957275
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.757284641265869 | KNN Loss: 3.705980062484741 | BCE Loss: 1.0513044595718384
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 4.697528839111328 | KNN Loss: 3.6949398517608643 | BCE Loss: 1.0025891065597534
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 4.741352081298828 | KNN Loss: 3.702030897140503 | BCE Loss: 1.0393214225769043
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 4.698066711425781 | KNN Loss: 3.6980419158935

Epoch 373 / 500 | iteration 25 / 30 | Total Loss: 4.699648380279541 | KNN Loss: 3.6955785751342773 | BCE Loss: 1.0040699243545532
Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.715725898742676 | KNN Loss: 3.6800684928894043 | BCE Loss: 1.035657525062561
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.69816780090332 | KNN Loss: 3.676811933517456 | BCE Loss: 1.0213558673858643
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.74888801574707 | KNN Loss: 3.6960158348083496 | BCE Loss: 1.0528719425201416
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.722936630249023 | KNN Loss: 3.6922590732574463 | BCE Loss: 1.0306777954101562
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.677726745605469 | KNN Loss: 3.685103178024292 | BCE Loss: 0.9926234483718872
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 4.679096221923828 | KNN Loss: 3.6714377403259277 | BCE Loss: 1.0076582431793213
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 4.660531044006348 | KNN Loss: 3.667128324508667 

Epoch 384 / 500 | iteration 15 / 30 | Total Loss: 4.702356338500977 | KNN Loss: 3.6861188411712646 | BCE Loss: 1.016237735748291
Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.75984001159668 | KNN Loss: 3.7262327671051025 | BCE Loss: 1.0336074829101562
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.707622528076172 | KNN Loss: 3.6855697631835938 | BCE Loss: 1.0220528841018677
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.716506004333496 | KNN Loss: 3.704629421234131 | BCE Loss: 1.0118765830993652
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.690215110778809 | KNN Loss: 3.6658213138580322 | BCE Loss: 1.0243937969207764
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.731057643890381 | KNN Loss: 3.714479923248291 | BCE Loss: 1.0165777206420898
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 4.766258239746094 | KNN Loss: 3.6951682567596436 | BCE Loss: 1.071089744567871
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 4.706774711608887 | KNN Loss: 3.671059131622314

Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.721700668334961 | KNN Loss: 3.716920852661133 | BCE Loss: 1.0047798156738281
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.70701789855957 | KNN Loss: 3.6546289920806885 | BCE Loss: 1.0523889064788818
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.69720458984375 | KNN Loss: 3.6814184188842773 | BCE Loss: 1.015786051750183
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.724734783172607 | KNN Loss: 3.6906940937042236 | BCE Loss: 1.0340406894683838
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.734705924987793 | KNN Loss: 3.730337381362915 | BCE Loss: 1.0043686628341675
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.702478408813477 | KNN Loss: 3.705914258956909 | BCE Loss: 0.9965639114379883
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 4.707912921905518 | KNN Loss: 3.689957857131958 | BCE Loss: 1.0179550647735596
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 4.721449851989746 | KNN Loss: 3.7163314819335938 |

Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.686516284942627 | KNN Loss: 3.6792151927948 | BCE Loss: 1.0073010921478271
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.683380126953125 | KNN Loss: 3.692497491836548 | BCE Loss: 0.9908825755119324
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.734013557434082 | KNN Loss: 3.7130496501922607 | BCE Loss: 1.0209639072418213
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.785103797912598 | KNN Loss: 3.7416744232177734 | BCE Loss: 1.0434296131134033
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.704380989074707 | KNN Loss: 3.689220666885376 | BCE Loss: 1.0151602029800415
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.729823589324951 | KNN Loss: 3.6996164321899414 | BCE Loss: 1.0302071571350098
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 4.686915397644043 | KNN Loss: 3.6740875244140625 | BCE Loss: 1.0128276348114014
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 4.714256286621094 | KNN Loss: 3.6863956451416016

Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.71793794631958 | KNN Loss: 3.7189621925354004 | BCE Loss: 0.9989758729934692
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.751008987426758 | KNN Loss: 3.655109167098999 | BCE Loss: 1.0958995819091797
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.746922492980957 | KNN Loss: 3.7293834686279297 | BCE Loss: 1.0175390243530273
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.695389270782471 | KNN Loss: 3.6814396381378174 | BCE Loss: 1.0139496326446533
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.708314895629883 | KNN Loss: 3.6958742141723633 | BCE Loss: 1.0124404430389404
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.690959453582764 | KNN Loss: 3.699187755584717 | BCE Loss: 0.9917716383934021
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 4.705693244934082 | KNN Loss: 3.6715152263641357 | BCE Loss: 1.0341777801513672
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 4.6919355392456055 | KNN Loss: 3.643036842346

Epoch 427 / 500 | iteration 0 / 30 | Total Loss: 4.716227054595947 | KNN Loss: 3.7163543701171875 | BCE Loss: 0.9998728632926941
Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.697392463684082 | KNN Loss: 3.6813416481018066 | BCE Loss: 1.0160508155822754
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.676826000213623 | KNN Loss: 3.6649653911590576 | BCE Loss: 1.0118606090545654
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.725605010986328 | KNN Loss: 3.6983401775360107 | BCE Loss: 1.0272650718688965
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.711603164672852 | KNN Loss: 3.688124895095825 | BCE Loss: 1.0234782695770264
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.744886875152588 | KNN Loss: 3.700871706008911 | BCE Loss: 1.0440152883529663
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.73754358291626 | KNN Loss: 3.698155403137207 | BCE Loss: 1.0393882989883423
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 4.7168989181518555 | KNN Loss: 3.693422317504883

Epoch 437 / 500 | iteration 20 / 30 | Total Loss: 4.724127769470215 | KNN Loss: 3.6652274131774902 | BCE Loss: 1.0589005947113037
Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.693363189697266 | KNN Loss: 3.6762561798095703 | BCE Loss: 1.0171067714691162
Epoch   438: reducing learning rate of group 0 to 9.5791e-07.
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.667969226837158 | KNN Loss: 3.668151378631592 | BCE Loss: 0.9998180270195007
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.674234390258789 | KNN Loss: 3.6698861122131348 | BCE Loss: 1.0043482780456543
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.716817378997803 | KNN Loss: 3.725106716156006 | BCE Loss: 0.9917107224464417
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.7270402908325195 | KNN Loss: 3.684619665145874 | BCE Loss: 1.0424203872680664
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.73661994934082 | KNN Loss: 3.7208058834075928 | BCE Loss: 1.0158143043518066
Epoch 438 / 500 | iteration 25 / 

Epoch 448 / 500 | iteration 10 / 30 | Total Loss: 4.748598098754883 | KNN Loss: 3.703897476196289 | BCE Loss: 1.0447006225585938
Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.709958553314209 | KNN Loss: 3.684103488922119 | BCE Loss: 1.0258550643920898
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.682947158813477 | KNN Loss: 3.665107488632202 | BCE Loss: 1.0178394317626953
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.696419715881348 | KNN Loss: 3.6882333755493164 | BCE Loss: 1.0081863403320312
Epoch   449: reducing learning rate of group 0 to 6.7053e-07.
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.688918113708496 | KNN Loss: 3.6989476680755615 | BCE Loss: 0.9899702668190002
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.684876918792725 | KNN Loss: 3.6805100440979004 | BCE Loss: 1.0043667554855347
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 4.676281929016113 | KNN Loss: 3.67991304397583 | BCE Loss: 0.9963686466217041
Epoch 449 / 500 | iteration 15 / 30

Epoch 459 / 500 | iteration 0 / 30 | Total Loss: 4.739774703979492 | KNN Loss: 3.69266939163208 | BCE Loss: 1.047105073928833
Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.714485168457031 | KNN Loss: 3.6632633209228516 | BCE Loss: 1.0512220859527588
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.720749855041504 | KNN Loss: 3.6967527866363525 | BCE Loss: 1.0239968299865723
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.785182476043701 | KNN Loss: 3.7344589233398438 | BCE Loss: 1.0507235527038574
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.728562355041504 | KNN Loss: 3.7009694576263428 | BCE Loss: 1.027592658996582
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.679009437561035 | KNN Loss: 3.6754894256591797 | BCE Loss: 1.0035202503204346
Epoch   460: reducing learning rate of group 0 to 4.6937e-07.
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 4.720353603363037 | KNN Loss: 3.7050390243530273 | BCE Loss: 1.0153146982192993
Epoch 460 / 500 | iteration 5 / 30 

Epoch 469 / 500 | iteration 20 / 30 | Total Loss: 4.723755836486816 | KNN Loss: 3.6922805309295654 | BCE Loss: 1.031475305557251
Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.755258083343506 | KNN Loss: 3.71797251701355 | BCE Loss: 1.037285566329956
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.722499370574951 | KNN Loss: 3.7153282165527344 | BCE Loss: 1.0071711540222168
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.697483062744141 | KNN Loss: 3.651270627975464 | BCE Loss: 1.0462125539779663
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.736935615539551 | KNN Loss: 3.716346263885498 | BCE Loss: 1.0205895900726318
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.670272350311279 | KNN Loss: 3.653196096420288 | BCE Loss: 1.0170761346817017
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 4.682635307312012 | KNN Loss: 3.6709365844726562 | BCE Loss: 1.0116984844207764
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 4.701333045959473 | KNN Loss: 3.6814544200897217 

Epoch 480 / 500 | iteration 10 / 30 | Total Loss: 4.723355770111084 | KNN Loss: 3.701094388961792 | BCE Loss: 1.0222612619400024
Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.721074104309082 | KNN Loss: 3.7213540077209473 | BCE Loss: 0.9997198581695557
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.700201988220215 | KNN Loss: 3.6684844493865967 | BCE Loss: 1.0317176580429077
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.729432106018066 | KNN Loss: 3.6830976009368896 | BCE Loss: 1.0463342666625977
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.713353157043457 | KNN Loss: 3.698345899581909 | BCE Loss: 1.0150070190429688
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.67999792098999 | KNN Loss: 3.65663743019104 | BCE Loss: 1.0233604907989502
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 4.685439109802246 | KNN Loss: 3.682257890701294 | BCE Loss: 1.0031814575195312
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 4.680056095123291 | KNN Loss: 3.6669938564300537

Epoch 491 / 500 | iteration 0 / 30 | Total Loss: 4.713675498962402 | KNN Loss: 3.686420202255249 | BCE Loss: 1.0272551774978638
Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.739717483520508 | KNN Loss: 3.6963613033294678 | BCE Loss: 1.0433560609817505
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.723142147064209 | KNN Loss: 3.700502872467041 | BCE Loss: 1.022639274597168
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.720836639404297 | KNN Loss: 3.7010433673858643 | BCE Loss: 1.0197930335998535
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.700216770172119 | KNN Loss: 3.6801035404205322 | BCE Loss: 1.0201133489608765
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.777024269104004 | KNN Loss: 3.728052854537964 | BCE Loss: 1.0489716529846191
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 4.656067848205566 | KNN Loss: 3.6576991081237793 | BCE Loss: 0.9983688592910767
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 4.736042022705078 | KNN Loss: 3.716486930847168 

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.9380,  2.3906,  3.0006,  2.2567,  3.8740,  0.6361,  3.0656,  2.5721,
          1.4370,  2.3509,  2.0332,  1.6066,  0.9440,  2.1109,  1.3601,  1.5618,
          2.6201,  2.0559,  1.9476,  2.5626,  1.7362,  2.5003,  2.7152,  3.0096,
          2.2980,  1.4191,  2.2821,  1.4492,  1.4496,  0.1748,  0.0623,  1.1730,
          0.3292,  0.9314,  1.1736,  1.4804,  1.3876,  3.7322,  0.4617,  1.4594,
          1.1210, -0.7043, -0.0831,  1.9769,  2.5086,  1.0815, -0.2059,  0.0990,
          1.7805,  1.4562,  2.1947,  0.0654,  1.6704,  0.6176, -0.3426,  1.1495,
          1.7938,  1.6236,  1.4176,  1.8316,  0.8881,  1.0403,  0.0438,  1.8676,
          1.6128,  2.0330, -1.9911,  0.4503,  2.6109,  2.5046,  2.0279,  0.5732,
          1.5381,  2.8802,  2.3922,  1.5608,  0.3403,  0.6326,  0.3242,  1.9697,
          0.1613,  0.6509,  2.1566, -0.2382,  0.5027, -0.9883, -2.2782, -0.1516,
          0.5367, -1.8180,  0.3130, -0.1467, -0.6027, -0.9736,  0.5923,  1.4121,
         -0.6954, -0.6870,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [12]:
dataset_ = [d[0].cpu() for d in dataset]

In [13]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 82.46it/s]


In [14]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [15]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [16]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [18]:
tensor_dataset = torch.stack(dataset_)

In [19]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [20]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [21]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [22]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [23]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [24]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [25]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [26]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [27]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [28]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [29]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [30]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [31]:
losses = []
accs = []
sparsity = []

In [32]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 026 | Total loss: 9.665 | Reg loss: 0.007 | Tree loss: 9.665 | Accuracy: 0.000000 | 0.09 sec/iter
Epoch: 00 | Batch: 001 / 026 | Total loss: 9.657 | Reg loss: 0.007 | Tree loss: 9.657 | Accuracy: 0.000000 | 0.084 sec/iter
Epoch: 00 | Batch: 002 / 026 | Total loss: 9.645 | Reg loss: 0.007 | Tree loss: 9.645 | Accuracy: 0.000000 | 0.078 sec/iter
Epoch: 00 | Batch: 003 / 026 | Total loss: 9.643 | Reg loss: 0.006 | Tree loss: 9.643 | Accuracy: 0.000000 | 0.076 sec/iter
Epoch: 00 | Batch: 004 / 026 | Total loss: 9.624 | Reg loss: 0.006 | Tree loss: 9.624 | Accuracy: 0.000000 | 0.078 sec/iter
Epoch: 00 | Batch: 005 / 026 | Total loss: 9.616 | Reg loss: 0.006 | Tree loss: 9.616 | Accuracy: 0.000000 | 0.078 sec/iter
Epoch: 00 | Batch: 006 / 026 | Total loss: 9.618 | Reg loss: 0.006 | Tree loss: 9.618 | Accuracy: 0.000000 | 0.077 sec/iter
Epoch: 00 | Batch: 007 / 026 | Total loss: 9

Epoch: 02 | Batch: 012 / 026 | Total loss: 9.292 | Reg loss: 0.007 | Tree loss: 9.292 | Accuracy: 0.080078 | 0.071 sec/iter
Epoch: 02 | Batch: 013 / 026 | Total loss: 9.272 | Reg loss: 0.007 | Tree loss: 9.272 | Accuracy: 0.099609 | 0.071 sec/iter
Epoch: 02 | Batch: 014 / 026 | Total loss: 9.268 | Reg loss: 0.007 | Tree loss: 9.268 | Accuracy: 0.095703 | 0.071 sec/iter
Epoch: 02 | Batch: 015 / 026 | Total loss: 9.272 | Reg loss: 0.008 | Tree loss: 9.272 | Accuracy: 0.076172 | 0.072 sec/iter
Epoch: 02 | Batch: 016 / 026 | Total loss: 9.258 | Reg loss: 0.008 | Tree loss: 9.258 | Accuracy: 0.085938 | 0.071 sec/iter
Epoch: 02 | Batch: 017 / 026 | Total loss: 9.250 | Reg loss: 0.008 | Tree loss: 9.250 | Accuracy: 0.078125 | 0.071 sec/iter
Epoch: 02 | Batch: 018 / 026 | Total loss: 9.244 | Reg loss: 0.008 | Tree loss: 9.244 | Accuracy: 0.095703 | 0.071 sec/iter
Epoch: 02 | Batch: 019 / 026 | Total loss: 9.235 | Reg loss: 0.009 | Tree loss: 9.235 | Accuracy: 0.074219 | 0.071 sec/iter
Epoch: 0

Epoch: 05 | Batch: 000 / 026 | Total loss: 8.989 | Reg loss: 0.010 | Tree loss: 8.989 | Accuracy: 0.089844 | 0.072 sec/iter
Epoch: 05 | Batch: 001 / 026 | Total loss: 8.981 | Reg loss: 0.010 | Tree loss: 8.981 | Accuracy: 0.066406 | 0.072 sec/iter
Epoch: 05 | Batch: 002 / 026 | Total loss: 8.971 | Reg loss: 0.010 | Tree loss: 8.971 | Accuracy: 0.076172 | 0.072 sec/iter
Epoch: 05 | Batch: 003 / 026 | Total loss: 8.956 | Reg loss: 0.010 | Tree loss: 8.956 | Accuracy: 0.082031 | 0.072 sec/iter
Epoch: 05 | Batch: 004 / 026 | Total loss: 8.950 | Reg loss: 0.010 | Tree loss: 8.950 | Accuracy: 0.072266 | 0.072 sec/iter
Epoch: 05 | Batch: 005 / 026 | Total loss: 8.939 | Reg loss: 0.011 | Tree loss: 8.939 | Accuracy: 0.076172 | 0.072 sec/iter
Epoch: 05 | Batch: 006 / 026 | Total loss: 8.941 | Reg loss: 0.011 | Tree loss: 8.941 | Accuracy: 0.054688 | 0.072 sec/iter
Epoch: 05 | Batch: 007 / 026 | Total loss: 8.918 | Reg loss: 0.011 | Tree loss: 8.918 | Accuracy: 0.097656 | 0.072 sec/iter
Epoch: 0

Epoch: 07 | Batch: 015 / 026 | Total loss: 8.531 | Reg loss: 0.016 | Tree loss: 8.531 | Accuracy: 0.072266 | 0.072 sec/iter
Epoch: 07 | Batch: 016 / 026 | Total loss: 8.516 | Reg loss: 0.016 | Tree loss: 8.516 | Accuracy: 0.080078 | 0.072 sec/iter
Epoch: 07 | Batch: 017 / 026 | Total loss: 8.517 | Reg loss: 0.016 | Tree loss: 8.517 | Accuracy: 0.074219 | 0.072 sec/iter
Epoch: 07 | Batch: 018 / 026 | Total loss: 8.493 | Reg loss: 0.017 | Tree loss: 8.493 | Accuracy: 0.109375 | 0.072 sec/iter
Epoch: 07 | Batch: 019 / 026 | Total loss: 8.482 | Reg loss: 0.017 | Tree loss: 8.482 | Accuracy: 0.095703 | 0.072 sec/iter
Epoch: 07 | Batch: 020 / 026 | Total loss: 8.477 | Reg loss: 0.017 | Tree loss: 8.477 | Accuracy: 0.087891 | 0.072 sec/iter
Epoch: 07 | Batch: 021 / 026 | Total loss: 8.451 | Reg loss: 0.017 | Tree loss: 8.451 | Accuracy: 0.103516 | 0.072 sec/iter
Epoch: 07 | Batch: 022 / 026 | Total loss: 8.436 | Reg loss: 0.018 | Tree loss: 8.436 | Accuracy: 0.128906 | 0.072 sec/iter
Epoch: 0

Epoch: 10 | Batch: 000 / 026 | Total loss: 8.239 | Reg loss: 0.018 | Tree loss: 8.239 | Accuracy: 0.056641 | 0.072 sec/iter
Epoch: 10 | Batch: 001 / 026 | Total loss: 8.212 | Reg loss: 0.018 | Tree loss: 8.212 | Accuracy: 0.089844 | 0.072 sec/iter
Epoch: 10 | Batch: 002 / 026 | Total loss: 8.197 | Reg loss: 0.018 | Tree loss: 8.197 | Accuracy: 0.078125 | 0.072 sec/iter
Epoch: 10 | Batch: 003 / 026 | Total loss: 8.194 | Reg loss: 0.018 | Tree loss: 8.194 | Accuracy: 0.080078 | 0.072 sec/iter
Epoch: 10 | Batch: 004 / 026 | Total loss: 8.151 | Reg loss: 0.018 | Tree loss: 8.151 | Accuracy: 0.085938 | 0.072 sec/iter
Epoch: 10 | Batch: 005 / 026 | Total loss: 8.161 | Reg loss: 0.018 | Tree loss: 8.161 | Accuracy: 0.068359 | 0.072 sec/iter
Epoch: 10 | Batch: 006 / 026 | Total loss: 8.147 | Reg loss: 0.018 | Tree loss: 8.147 | Accuracy: 0.070312 | 0.073 sec/iter
Epoch: 10 | Batch: 007 / 026 | Total loss: 8.125 | Reg loss: 0.018 | Tree loss: 8.125 | Accuracy: 0.078125 | 0.073 sec/iter
Epoch: 1

Epoch: 12 | Batch: 012 / 026 | Total loss: 7.713 | Reg loss: 0.021 | Tree loss: 7.713 | Accuracy: 0.083984 | 0.073 sec/iter
Epoch: 12 | Batch: 013 / 026 | Total loss: 7.706 | Reg loss: 0.021 | Tree loss: 7.706 | Accuracy: 0.076172 | 0.073 sec/iter
Epoch: 12 | Batch: 014 / 026 | Total loss: 7.680 | Reg loss: 0.022 | Tree loss: 7.680 | Accuracy: 0.101562 | 0.073 sec/iter
Epoch: 12 | Batch: 015 / 026 | Total loss: 7.672 | Reg loss: 0.022 | Tree loss: 7.672 | Accuracy: 0.087891 | 0.073 sec/iter
Epoch: 12 | Batch: 016 / 026 | Total loss: 7.645 | Reg loss: 0.022 | Tree loss: 7.645 | Accuracy: 0.072266 | 0.073 sec/iter
Epoch: 12 | Batch: 017 / 026 | Total loss: 7.620 | Reg loss: 0.022 | Tree loss: 7.620 | Accuracy: 0.117188 | 0.073 sec/iter
Epoch: 12 | Batch: 018 / 026 | Total loss: 7.635 | Reg loss: 0.022 | Tree loss: 7.635 | Accuracy: 0.074219 | 0.073 sec/iter
Epoch: 12 | Batch: 019 / 026 | Total loss: 7.592 | Reg loss: 0.023 | Tree loss: 7.592 | Accuracy: 0.099609 | 0.073 sec/iter
Epoch: 1

Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 15 | Batch: 000 / 026 | Total loss: 7.326 | Reg loss: 0.023 | Tree loss: 7.326 | Accuracy: 0.076172 | 0.073 sec/iter
Epoch: 15 | Batch: 001 / 026 | Total loss: 7.329 | Reg loss: 0.023 | Tree loss: 7.329 | Accuracy: 0.056641 | 0.073 sec/iter
Epoch: 15 | Batch: 002 / 026 | Total loss: 7.294 | Reg loss: 0.023 | Tree loss: 7.294 | Accuracy: 0.068359 | 0.073 sec/iter
Epoch: 15 | Batch: 003 / 026 | Total loss: 7.266 | Reg loss: 0.023 | Tree loss: 7.266 | Accuracy: 0.072266 | 0.073 sec/iter
Epoch: 15 | Batch: 004 / 026 | Total loss: 7.249 | Reg loss: 0.023 | Tree loss: 7.249 | Accuracy: 0.087891 | 0.073 sec/iter
Epoch: 15 | Batch: 005 / 026 | Total loss: 7.241 | Reg loss: 0.023 | Tree loss: 7.241 | Accuracy: 0.089844 | 0.073 sec/iter
Epoch: 15 | Batch: 006 / 026 | Total loss: 7.205 | Reg loss: 0.023 | Tree los

Epoch: 17 | Batch: 012 / 026 | Total loss: 6.647 | Reg loss: 0.027 | Tree loss: 6.647 | Accuracy: 0.083984 | 0.073 sec/iter
Epoch: 17 | Batch: 013 / 026 | Total loss: 6.591 | Reg loss: 0.027 | Tree loss: 6.591 | Accuracy: 0.093750 | 0.073 sec/iter
Epoch: 17 | Batch: 014 / 026 | Total loss: 6.589 | Reg loss: 0.027 | Tree loss: 6.589 | Accuracy: 0.080078 | 0.073 sec/iter
Epoch: 17 | Batch: 015 / 026 | Total loss: 6.578 | Reg loss: 0.028 | Tree loss: 6.578 | Accuracy: 0.058594 | 0.073 sec/iter
Epoch: 17 | Batch: 016 / 026 | Total loss: 6.562 | Reg loss: 0.028 | Tree loss: 6.562 | Accuracy: 0.054688 | 0.073 sec/iter
Epoch: 17 | Batch: 017 / 026 | Total loss: 6.528 | Reg loss: 0.028 | Tree loss: 6.528 | Accuracy: 0.083984 | 0.073 sec/iter
Epoch: 17 | Batch: 018 / 026 | Total loss: 6.482 | Reg loss: 0.028 | Tree loss: 6.482 | Accuracy: 0.093750 | 0.073 sec/iter
Epoch: 17 | Batch: 019 / 026 | Total loss: 6.488 | Reg loss: 0.028 | Tree loss: 6.488 | Accuracy: 0.076172 | 0.073 sec/iter
Epoch: 1

Epoch: 19 | Batch: 024 / 026 | Total loss: 5.974 | Reg loss: 0.031 | Tree loss: 5.974 | Accuracy: 0.068359 | 0.072 sec/iter
Epoch: 19 | Batch: 025 / 026 | Total loss: 5.950 | Reg loss: 0.031 | Tree loss: 5.950 | Accuracy: 0.086351 | 0.072 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 20 | Batch: 000 / 026 | Total loss: 6.260 | Reg loss: 0.029 | Tree loss: 6.260 | Accuracy: 0.062500 | 0.072 sec/iter
Epoch: 20 | Batch: 001 / 026 | Total loss: 6.261 | Reg loss: 0.029 | Tree loss: 6.261 | Accuracy: 0.082031 | 0.072 sec/iter
Epoch: 20 | Batch: 002 / 026 | Total loss: 6.248 | Reg loss: 0.029 | Tree loss: 6.248 | Accuracy: 0.070312 | 0.072 sec/iter
Epoch: 20 | Batch: 003 / 026 | Total loss: 6.186 | Reg loss: 0.029 | Tree loss: 6.186 | Accuracy: 0.082031 | 0.072 sec/iter
Epoch: 20 | Batch: 004 / 026 | Total loss: 6.195 | Reg loss: 0.029 | Tree los

Epoch: 22 | Batch: 011 / 026 | Total loss: 5.660 | Reg loss: 0.031 | Tree loss: 5.660 | Accuracy: 0.082031 | 0.072 sec/iter
Epoch: 22 | Batch: 012 / 026 | Total loss: 5.639 | Reg loss: 0.031 | Tree loss: 5.639 | Accuracy: 0.080078 | 0.072 sec/iter
Epoch: 22 | Batch: 013 / 026 | Total loss: 5.622 | Reg loss: 0.031 | Tree loss: 5.622 | Accuracy: 0.060547 | 0.072 sec/iter
Epoch: 22 | Batch: 014 / 026 | Total loss: 5.614 | Reg loss: 0.031 | Tree loss: 5.614 | Accuracy: 0.076172 | 0.072 sec/iter
Epoch: 22 | Batch: 015 / 026 | Total loss: 5.625 | Reg loss: 0.031 | Tree loss: 5.625 | Accuracy: 0.068359 | 0.072 sec/iter
Epoch: 22 | Batch: 016 / 026 | Total loss: 5.561 | Reg loss: 0.031 | Tree loss: 5.561 | Accuracy: 0.087891 | 0.072 sec/iter
Epoch: 22 | Batch: 017 / 026 | Total loss: 5.572 | Reg loss: 0.032 | Tree loss: 5.572 | Accuracy: 0.070312 | 0.072 sec/iter
Epoch: 22 | Batch: 018 / 026 | Total loss: 5.538 | Reg loss: 0.032 | Tree loss: 5.538 | Accuracy: 0.072266 | 0.072 sec/iter
Epoch: 2

Epoch: 24 | Batch: 023 / 026 | Total loss: 5.150 | Reg loss: 0.033 | Tree loss: 5.150 | Accuracy: 0.097656 | 0.071 sec/iter
Epoch: 24 | Batch: 024 / 026 | Total loss: 5.123 | Reg loss: 0.033 | Tree loss: 5.123 | Accuracy: 0.083984 | 0.071 sec/iter
Epoch: 24 | Batch: 025 / 026 | Total loss: 5.079 | Reg loss: 0.033 | Tree loss: 5.079 | Accuracy: 0.080780 | 0.071 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 25 | Batch: 000 / 026 | Total loss: 5.444 | Reg loss: 0.032 | Tree loss: 5.444 | Accuracy: 0.062500 | 0.071 sec/iter
Epoch: 25 | Batch: 001 / 026 | Total loss: 5.394 | Reg loss: 0.032 | Tree loss: 5.394 | Accuracy: 0.093750 | 0.071 sec/iter
Epoch: 25 | Batch: 002 / 026 | Total loss: 5.393 | Reg loss: 0.032 | Tree loss: 5.393 | Accuracy: 0.070312 | 0.071 sec/iter
Epoch: 25 | Batch: 003 / 026 | Total loss: 5.365 | Reg loss: 0.032 | Tree los

Epoch: 27 | Batch: 010 / 026 | Total loss: 4.934 | Reg loss: 0.033 | Tree loss: 4.934 | Accuracy: 0.087891 | 0.07 sec/iter
Epoch: 27 | Batch: 011 / 026 | Total loss: 4.937 | Reg loss: 0.033 | Tree loss: 4.937 | Accuracy: 0.072266 | 0.07 sec/iter
Epoch: 27 | Batch: 012 / 026 | Total loss: 4.882 | Reg loss: 0.033 | Tree loss: 4.882 | Accuracy: 0.074219 | 0.07 sec/iter
Epoch: 27 | Batch: 013 / 026 | Total loss: 4.878 | Reg loss: 0.033 | Tree loss: 4.878 | Accuracy: 0.072266 | 0.07 sec/iter
Epoch: 27 | Batch: 014 / 026 | Total loss: 4.904 | Reg loss: 0.033 | Tree loss: 4.904 | Accuracy: 0.064453 | 0.07 sec/iter
Epoch: 27 | Batch: 015 / 026 | Total loss: 4.811 | Reg loss: 0.033 | Tree loss: 4.811 | Accuracy: 0.132812 | 0.07 sec/iter
Epoch: 27 | Batch: 016 / 026 | Total loss: 4.825 | Reg loss: 0.033 | Tree loss: 4.825 | Accuracy: 0.082031 | 0.07 sec/iter
Epoch: 27 | Batch: 017 / 026 | Total loss: 4.807 | Reg loss: 0.033 | Tree loss: 4.807 | Accuracy: 0.089844 | 0.07 sec/iter
Epoch: 27 | Batc

Epoch: 29 | Batch: 025 / 026 | Total loss: 4.372 | Reg loss: 0.035 | Tree loss: 4.372 | Accuracy: 0.086351 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 30 | Batch: 000 / 026 | Total loss: 4.628 | Reg loss: 0.033 | Tree loss: 4.628 | Accuracy: 0.074219 | 0.069 sec/iter
Epoch: 30 | Batch: 001 / 026 | Total loss: 4.673 | Reg loss: 0.033 | Tree loss: 4.673 | Accuracy: 0.085938 | 0.069 sec/iter
Epoch: 30 | Batch: 002 / 026 | Total loss: 4.667 | Reg loss: 0.033 | Tree loss: 4.667 | Accuracy: 0.068359 | 0.069 sec/iter
Epoch: 30 | Batch: 003 / 026 | Total loss: 4.627 | Reg loss: 0.033 | Tree loss: 4.627 | Accuracy: 0.056641 | 0.069 sec/iter
Epoch: 30 | Batch: 004 / 026 | Total loss: 4.624 | Reg loss: 0.033 | Tree loss: 4.624 | Accuracy: 0.054688 | 0.069 sec/iter
Epoch: 30 | Batch: 005 / 026 | Total loss: 4.599 | Reg loss: 0.033 | Tree los

Epoch: 32 | Batch: 010 / 026 | Total loss: 4.199 | Reg loss: 0.034 | Tree loss: 4.199 | Accuracy: 0.062500 | 0.069 sec/iter
Epoch: 32 | Batch: 011 / 026 | Total loss: 4.182 | Reg loss: 0.034 | Tree loss: 4.182 | Accuracy: 0.062500 | 0.069 sec/iter
Epoch: 32 | Batch: 012 / 026 | Total loss: 4.167 | Reg loss: 0.034 | Tree loss: 4.167 | Accuracy: 0.072266 | 0.069 sec/iter
Epoch: 32 | Batch: 013 / 026 | Total loss: 4.159 | Reg loss: 0.034 | Tree loss: 4.159 | Accuracy: 0.072266 | 0.069 sec/iter
Epoch: 32 | Batch: 014 / 026 | Total loss: 4.188 | Reg loss: 0.034 | Tree loss: 4.188 | Accuracy: 0.072266 | 0.069 sec/iter
Epoch: 32 | Batch: 015 / 026 | Total loss: 4.169 | Reg loss: 0.034 | Tree loss: 4.169 | Accuracy: 0.080078 | 0.069 sec/iter
Epoch: 32 | Batch: 016 / 026 | Total loss: 4.132 | Reg loss: 0.034 | Tree loss: 4.132 | Accuracy: 0.066406 | 0.068 sec/iter
Epoch: 32 | Batch: 017 / 026 | Total loss: 4.138 | Reg loss: 0.034 | Tree loss: 4.138 | Accuracy: 0.072266 | 0.068 sec/iter
Epoch: 3

Epoch: 34 | Batch: 022 / 026 | Total loss: 3.744 | Reg loss: 0.035 | Tree loss: 3.744 | Accuracy: 0.064453 | 0.068 sec/iter
Epoch: 34 | Batch: 023 / 026 | Total loss: 3.794 | Reg loss: 0.035 | Tree loss: 3.794 | Accuracy: 0.078125 | 0.068 sec/iter
Epoch: 34 | Batch: 024 / 026 | Total loss: 3.813 | Reg loss: 0.035 | Tree loss: 3.813 | Accuracy: 0.058594 | 0.068 sec/iter
Epoch: 34 | Batch: 025 / 026 | Total loss: 3.816 | Reg loss: 0.035 | Tree loss: 3.816 | Accuracy: 0.052925 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 35 | Batch: 000 / 026 | Total loss: 3.980 | Reg loss: 0.034 | Tree loss: 3.980 | Accuracy: 0.064453 | 0.068 sec/iter
Epoch: 35 | Batch: 001 / 026 | Total loss: 3.990 | Reg loss: 0.034 | Tree loss: 3.990 | Accuracy: 0.074219 | 0.068 sec/iter
Epoch: 35 | Batch: 002 / 026 | Total loss: 3.967 | Reg loss: 0.034 | Tree los

Epoch: 37 | Batch: 006 / 026 | Total loss: 3.777 | Reg loss: 0.034 | Tree loss: 3.777 | Accuracy: 0.082031 | 0.067 sec/iter
Epoch: 37 | Batch: 007 / 026 | Total loss: 3.722 | Reg loss: 0.034 | Tree loss: 3.722 | Accuracy: 0.080078 | 0.067 sec/iter
Epoch: 37 | Batch: 008 / 026 | Total loss: 3.681 | Reg loss: 0.034 | Tree loss: 3.681 | Accuracy: 0.103516 | 0.067 sec/iter
Epoch: 37 | Batch: 009 / 026 | Total loss: 3.683 | Reg loss: 0.034 | Tree loss: 3.683 | Accuracy: 0.111328 | 0.067 sec/iter
Epoch: 37 | Batch: 010 / 026 | Total loss: 3.644 | Reg loss: 0.034 | Tree loss: 3.644 | Accuracy: 0.099609 | 0.067 sec/iter
Epoch: 37 | Batch: 011 / 026 | Total loss: 3.611 | Reg loss: 0.034 | Tree loss: 3.611 | Accuracy: 0.095703 | 0.067 sec/iter
Epoch: 37 | Batch: 012 / 026 | Total loss: 3.635 | Reg loss: 0.034 | Tree loss: 3.635 | Accuracy: 0.095703 | 0.067 sec/iter
Epoch: 37 | Batch: 013 / 026 | Total loss: 3.683 | Reg loss: 0.034 | Tree loss: 3.683 | Accuracy: 0.072266 | 0.067 sec/iter
Epoch: 3

Epoch: 39 | Batch: 018 / 026 | Total loss: 3.417 | Reg loss: 0.034 | Tree loss: 3.417 | Accuracy: 0.132812 | 0.067 sec/iter
Epoch: 39 | Batch: 019 / 026 | Total loss: 3.375 | Reg loss: 0.034 | Tree loss: 3.375 | Accuracy: 0.111328 | 0.067 sec/iter
Epoch: 39 | Batch: 020 / 026 | Total loss: 3.451 | Reg loss: 0.034 | Tree loss: 3.451 | Accuracy: 0.083984 | 0.067 sec/iter
Epoch: 39 | Batch: 021 / 026 | Total loss: 3.399 | Reg loss: 0.034 | Tree loss: 3.399 | Accuracy: 0.117188 | 0.067 sec/iter
Epoch: 39 | Batch: 022 / 026 | Total loss: 3.411 | Reg loss: 0.034 | Tree loss: 3.411 | Accuracy: 0.111328 | 0.067 sec/iter
Epoch: 39 | Batch: 023 / 026 | Total loss: 3.408 | Reg loss: 0.034 | Tree loss: 3.408 | Accuracy: 0.091797 | 0.067 sec/iter
Epoch: 39 | Batch: 024 / 026 | Total loss: 3.417 | Reg loss: 0.034 | Tree loss: 3.417 | Accuracy: 0.072266 | 0.067 sec/iter
Epoch: 39 | Batch: 025 / 026 | Total loss: 3.350 | Reg loss: 0.034 | Tree loss: 3.350 | Accuracy: 0.116992 | 0.067 sec/iter
Average 

Epoch: 42 | Batch: 003 / 026 | Total loss: 3.362 | Reg loss: 0.033 | Tree loss: 3.362 | Accuracy: 0.111328 | 0.066 sec/iter
Epoch: 42 | Batch: 004 / 026 | Total loss: 3.412 | Reg loss: 0.033 | Tree loss: 3.412 | Accuracy: 0.087891 | 0.066 sec/iter
Epoch: 42 | Batch: 005 / 026 | Total loss: 3.340 | Reg loss: 0.033 | Tree loss: 3.340 | Accuracy: 0.095703 | 0.066 sec/iter
Epoch: 42 | Batch: 006 / 026 | Total loss: 3.375 | Reg loss: 0.033 | Tree loss: 3.375 | Accuracy: 0.093750 | 0.066 sec/iter
Epoch: 42 | Batch: 007 / 026 | Total loss: 3.360 | Reg loss: 0.033 | Tree loss: 3.360 | Accuracy: 0.109375 | 0.066 sec/iter
Epoch: 42 | Batch: 008 / 026 | Total loss: 3.369 | Reg loss: 0.033 | Tree loss: 3.369 | Accuracy: 0.093750 | 0.066 sec/iter
Epoch: 42 | Batch: 009 / 026 | Total loss: 3.355 | Reg loss: 0.033 | Tree loss: 3.355 | Accuracy: 0.099609 | 0.066 sec/iter
Epoch: 42 | Batch: 010 / 026 | Total loss: 3.306 | Reg loss: 0.033 | Tree loss: 3.306 | Accuracy: 0.136719 | 0.066 sec/iter
Epoch: 4

Epoch: 44 | Batch: 018 / 026 | Total loss: 3.180 | Reg loss: 0.033 | Tree loss: 3.180 | Accuracy: 0.097656 | 0.066 sec/iter
Epoch: 44 | Batch: 019 / 026 | Total loss: 3.139 | Reg loss: 0.033 | Tree loss: 3.139 | Accuracy: 0.117188 | 0.066 sec/iter
Epoch: 44 | Batch: 020 / 026 | Total loss: 3.167 | Reg loss: 0.033 | Tree loss: 3.167 | Accuracy: 0.101562 | 0.066 sec/iter
Epoch: 44 | Batch: 021 / 026 | Total loss: 3.138 | Reg loss: 0.033 | Tree loss: 3.138 | Accuracy: 0.101562 | 0.066 sec/iter
Epoch: 44 | Batch: 022 / 026 | Total loss: 3.144 | Reg loss: 0.033 | Tree loss: 3.144 | Accuracy: 0.085938 | 0.066 sec/iter
Epoch: 44 | Batch: 023 / 026 | Total loss: 3.146 | Reg loss: 0.033 | Tree loss: 3.146 | Accuracy: 0.103516 | 0.066 sec/iter
Epoch: 44 | Batch: 024 / 026 | Total loss: 3.151 | Reg loss: 0.033 | Tree loss: 3.151 | Accuracy: 0.119141 | 0.066 sec/iter
Epoch: 44 | Batch: 025 / 026 | Total loss: 3.128 | Reg loss: 0.033 | Tree loss: 3.128 | Accuracy: 0.097493 | 0.066 sec/iter
Average 

Epoch: 47 | Batch: 004 / 026 | Total loss: 3.171 | Reg loss: 0.033 | Tree loss: 3.171 | Accuracy: 0.107422 | 0.066 sec/iter
Epoch: 47 | Batch: 005 / 026 | Total loss: 3.183 | Reg loss: 0.033 | Tree loss: 3.183 | Accuracy: 0.113281 | 0.066 sec/iter
Epoch: 47 | Batch: 006 / 026 | Total loss: 3.163 | Reg loss: 0.033 | Tree loss: 3.163 | Accuracy: 0.111328 | 0.066 sec/iter
Epoch: 47 | Batch: 007 / 026 | Total loss: 3.153 | Reg loss: 0.033 | Tree loss: 3.153 | Accuracy: 0.109375 | 0.066 sec/iter
Epoch: 47 | Batch: 008 / 026 | Total loss: 3.158 | Reg loss: 0.033 | Tree loss: 3.158 | Accuracy: 0.099609 | 0.066 sec/iter
Epoch: 47 | Batch: 009 / 026 | Total loss: 3.143 | Reg loss: 0.033 | Tree loss: 3.143 | Accuracy: 0.113281 | 0.066 sec/iter
Epoch: 47 | Batch: 010 / 026 | Total loss: 3.128 | Reg loss: 0.033 | Tree loss: 3.128 | Accuracy: 0.121094 | 0.066 sec/iter
Epoch: 47 | Batch: 011 / 026 | Total loss: 3.112 | Reg loss: 0.033 | Tree loss: 3.112 | Accuracy: 0.099609 | 0.066 sec/iter
Epoch: 4

Epoch: 49 | Batch: 017 / 026 | Total loss: 3.053 | Reg loss: 0.032 | Tree loss: 3.053 | Accuracy: 0.105469 | 0.065 sec/iter
Epoch: 49 | Batch: 018 / 026 | Total loss: 3.019 | Reg loss: 0.032 | Tree loss: 3.019 | Accuracy: 0.107422 | 0.065 sec/iter
Epoch: 49 | Batch: 019 / 026 | Total loss: 3.037 | Reg loss: 0.032 | Tree loss: 3.037 | Accuracy: 0.121094 | 0.065 sec/iter
Epoch: 49 | Batch: 020 / 026 | Total loss: 3.022 | Reg loss: 0.033 | Tree loss: 3.022 | Accuracy: 0.130859 | 0.065 sec/iter
Epoch: 49 | Batch: 021 / 026 | Total loss: 3.007 | Reg loss: 0.033 | Tree loss: 3.007 | Accuracy: 0.111328 | 0.065 sec/iter
Epoch: 49 | Batch: 022 / 026 | Total loss: 3.004 | Reg loss: 0.033 | Tree loss: 3.004 | Accuracy: 0.101562 | 0.065 sec/iter
Epoch: 49 | Batch: 023 / 026 | Total loss: 3.010 | Reg loss: 0.033 | Tree loss: 3.010 | Accuracy: 0.121094 | 0.065 sec/iter
Epoch: 49 | Batch: 024 / 026 | Total loss: 3.017 | Reg loss: 0.033 | Tree loss: 3.017 | Accuracy: 0.117188 | 0.065 sec/iter
Epoch: 4

Epoch: 52 | Batch: 002 / 026 | Total loss: 3.064 | Reg loss: 0.032 | Tree loss: 3.064 | Accuracy: 0.101562 | 0.065 sec/iter
Epoch: 52 | Batch: 003 / 026 | Total loss: 3.116 | Reg loss: 0.032 | Tree loss: 3.116 | Accuracy: 0.111328 | 0.065 sec/iter
Epoch: 52 | Batch: 004 / 026 | Total loss: 3.066 | Reg loss: 0.032 | Tree loss: 3.066 | Accuracy: 0.105469 | 0.065 sec/iter
Epoch: 52 | Batch: 005 / 026 | Total loss: 3.064 | Reg loss: 0.032 | Tree loss: 3.064 | Accuracy: 0.101562 | 0.065 sec/iter
Epoch: 52 | Batch: 006 / 026 | Total loss: 3.050 | Reg loss: 0.032 | Tree loss: 3.050 | Accuracy: 0.117188 | 0.065 sec/iter
Epoch: 52 | Batch: 007 / 026 | Total loss: 3.067 | Reg loss: 0.032 | Tree loss: 3.067 | Accuracy: 0.132812 | 0.065 sec/iter
Epoch: 52 | Batch: 008 / 026 | Total loss: 3.042 | Reg loss: 0.032 | Tree loss: 3.042 | Accuracy: 0.105469 | 0.065 sec/iter
Epoch: 52 | Batch: 009 / 026 | Total loss: 3.027 | Reg loss: 0.032 | Tree loss: 3.027 | Accuracy: 0.125000 | 0.065 sec/iter
Epoch: 5

Epoch: 54 | Batch: 014 / 026 | Total loss: 2.972 | Reg loss: 0.032 | Tree loss: 2.972 | Accuracy: 0.113281 | 0.065 sec/iter
Epoch: 54 | Batch: 015 / 026 | Total loss: 2.972 | Reg loss: 0.032 | Tree loss: 2.972 | Accuracy: 0.126953 | 0.065 sec/iter
Epoch: 54 | Batch: 016 / 026 | Total loss: 2.947 | Reg loss: 0.032 | Tree loss: 2.947 | Accuracy: 0.117188 | 0.065 sec/iter
Epoch: 54 | Batch: 017 / 026 | Total loss: 2.967 | Reg loss: 0.032 | Tree loss: 2.967 | Accuracy: 0.125000 | 0.065 sec/iter
Epoch: 54 | Batch: 018 / 026 | Total loss: 2.959 | Reg loss: 0.032 | Tree loss: 2.959 | Accuracy: 0.121094 | 0.065 sec/iter
Epoch: 54 | Batch: 019 / 026 | Total loss: 2.949 | Reg loss: 0.032 | Tree loss: 2.949 | Accuracy: 0.103516 | 0.065 sec/iter
Epoch: 54 | Batch: 020 / 026 | Total loss: 2.924 | Reg loss: 0.032 | Tree loss: 2.924 | Accuracy: 0.126953 | 0.065 sec/iter
Epoch: 54 | Batch: 021 / 026 | Total loss: 2.935 | Reg loss: 0.032 | Tree loss: 2.935 | Accuracy: 0.164062 | 0.065 sec/iter
Epoch: 5

Epoch: 57 | Batch: 003 / 026 | Total loss: 2.989 | Reg loss: 0.031 | Tree loss: 2.989 | Accuracy: 0.126953 | 0.065 sec/iter
Epoch: 57 | Batch: 004 / 026 | Total loss: 2.997 | Reg loss: 0.031 | Tree loss: 2.997 | Accuracy: 0.113281 | 0.065 sec/iter
Epoch: 57 | Batch: 005 / 026 | Total loss: 3.005 | Reg loss: 0.031 | Tree loss: 3.005 | Accuracy: 0.113281 | 0.065 sec/iter
Epoch: 57 | Batch: 006 / 026 | Total loss: 2.970 | Reg loss: 0.031 | Tree loss: 2.970 | Accuracy: 0.113281 | 0.065 sec/iter
Epoch: 57 | Batch: 007 / 026 | Total loss: 2.997 | Reg loss: 0.031 | Tree loss: 2.997 | Accuracy: 0.105469 | 0.065 sec/iter
Epoch: 57 | Batch: 008 / 026 | Total loss: 2.984 | Reg loss: 0.031 | Tree loss: 2.984 | Accuracy: 0.121094 | 0.065 sec/iter
Epoch: 57 | Batch: 009 / 026 | Total loss: 2.983 | Reg loss: 0.031 | Tree loss: 2.983 | Accuracy: 0.119141 | 0.065 sec/iter
Epoch: 57 | Batch: 010 / 026 | Total loss: 2.943 | Reg loss: 0.032 | Tree loss: 2.943 | Accuracy: 0.117188 | 0.065 sec/iter
Epoch: 5

Epoch: 59 | Batch: 015 / 026 | Total loss: 2.919 | Reg loss: 0.031 | Tree loss: 2.919 | Accuracy: 0.095703 | 0.064 sec/iter
Epoch: 59 | Batch: 016 / 026 | Total loss: 2.890 | Reg loss: 0.031 | Tree loss: 2.890 | Accuracy: 0.132812 | 0.064 sec/iter
Epoch: 59 | Batch: 017 / 026 | Total loss: 2.922 | Reg loss: 0.031 | Tree loss: 2.922 | Accuracy: 0.099609 | 0.064 sec/iter
Epoch: 59 | Batch: 018 / 026 | Total loss: 2.937 | Reg loss: 0.031 | Tree loss: 2.937 | Accuracy: 0.105469 | 0.064 sec/iter
Epoch: 59 | Batch: 019 / 026 | Total loss: 2.887 | Reg loss: 0.032 | Tree loss: 2.887 | Accuracy: 0.111328 | 0.064 sec/iter
Epoch: 59 | Batch: 020 / 026 | Total loss: 2.881 | Reg loss: 0.032 | Tree loss: 2.881 | Accuracy: 0.107422 | 0.064 sec/iter
Epoch: 59 | Batch: 021 / 026 | Total loss: 2.874 | Reg loss: 0.032 | Tree loss: 2.874 | Accuracy: 0.121094 | 0.064 sec/iter
Epoch: 59 | Batch: 022 / 026 | Total loss: 2.890 | Reg loss: 0.032 | Tree loss: 2.890 | Accuracy: 0.128906 | 0.064 sec/iter
Epoch: 5

Epoch: 62 | Batch: 001 / 026 | Total loss: 2.981 | Reg loss: 0.031 | Tree loss: 2.981 | Accuracy: 0.128906 | 0.064 sec/iter
Epoch: 62 | Batch: 002 / 026 | Total loss: 2.969 | Reg loss: 0.031 | Tree loss: 2.969 | Accuracy: 0.107422 | 0.064 sec/iter
Epoch: 62 | Batch: 003 / 026 | Total loss: 2.958 | Reg loss: 0.031 | Tree loss: 2.958 | Accuracy: 0.113281 | 0.064 sec/iter
Epoch: 62 | Batch: 004 / 026 | Total loss: 2.945 | Reg loss: 0.031 | Tree loss: 2.945 | Accuracy: 0.136719 | 0.064 sec/iter
Epoch: 62 | Batch: 005 / 026 | Total loss: 2.939 | Reg loss: 0.031 | Tree loss: 2.939 | Accuracy: 0.132812 | 0.064 sec/iter
Epoch: 62 | Batch: 006 / 026 | Total loss: 2.960 | Reg loss: 0.031 | Tree loss: 2.960 | Accuracy: 0.087891 | 0.064 sec/iter
Epoch: 62 | Batch: 007 / 026 | Total loss: 2.920 | Reg loss: 0.031 | Tree loss: 2.920 | Accuracy: 0.136719 | 0.064 sec/iter
Epoch: 62 | Batch: 008 / 026 | Total loss: 2.929 | Reg loss: 0.031 | Tree loss: 2.929 | Accuracy: 0.121094 | 0.064 sec/iter
Epoch: 6

Epoch: 64 | Batch: 015 / 026 | Total loss: 2.872 | Reg loss: 0.031 | Tree loss: 2.872 | Accuracy: 0.126953 | 0.064 sec/iter
Epoch: 64 | Batch: 016 / 026 | Total loss: 2.874 | Reg loss: 0.031 | Tree loss: 2.874 | Accuracy: 0.126953 | 0.064 sec/iter
Epoch: 64 | Batch: 017 / 026 | Total loss: 2.909 | Reg loss: 0.031 | Tree loss: 2.909 | Accuracy: 0.128906 | 0.064 sec/iter
Epoch: 64 | Batch: 018 / 026 | Total loss: 2.872 | Reg loss: 0.031 | Tree loss: 2.872 | Accuracy: 0.111328 | 0.064 sec/iter
Epoch: 64 | Batch: 019 / 026 | Total loss: 2.864 | Reg loss: 0.031 | Tree loss: 2.864 | Accuracy: 0.103516 | 0.064 sec/iter
Epoch: 64 | Batch: 020 / 026 | Total loss: 2.836 | Reg loss: 0.031 | Tree loss: 2.836 | Accuracy: 0.121094 | 0.064 sec/iter
Epoch: 64 | Batch: 021 / 026 | Total loss: 2.870 | Reg loss: 0.031 | Tree loss: 2.870 | Accuracy: 0.115234 | 0.064 sec/iter
Epoch: 64 | Batch: 022 / 026 | Total loss: 2.859 | Reg loss: 0.031 | Tree loss: 2.859 | Accuracy: 0.093750 | 0.064 sec/iter
Epoch: 6

Epoch: 67 | Batch: 001 / 026 | Total loss: 2.940 | Reg loss: 0.031 | Tree loss: 2.940 | Accuracy: 0.128906 | 0.064 sec/iter
Epoch: 67 | Batch: 002 / 026 | Total loss: 2.959 | Reg loss: 0.031 | Tree loss: 2.959 | Accuracy: 0.087891 | 0.064 sec/iter
Epoch: 67 | Batch: 003 / 026 | Total loss: 2.877 | Reg loss: 0.031 | Tree loss: 2.877 | Accuracy: 0.126953 | 0.064 sec/iter
Epoch: 67 | Batch: 004 / 026 | Total loss: 2.949 | Reg loss: 0.031 | Tree loss: 2.949 | Accuracy: 0.119141 | 0.064 sec/iter
Epoch: 67 | Batch: 005 / 026 | Total loss: 2.889 | Reg loss: 0.031 | Tree loss: 2.889 | Accuracy: 0.119141 | 0.064 sec/iter
Epoch: 67 | Batch: 006 / 026 | Total loss: 2.897 | Reg loss: 0.031 | Tree loss: 2.897 | Accuracy: 0.126953 | 0.064 sec/iter
Epoch: 67 | Batch: 007 / 026 | Total loss: 2.926 | Reg loss: 0.031 | Tree loss: 2.926 | Accuracy: 0.113281 | 0.064 sec/iter
Epoch: 67 | Batch: 008 / 026 | Total loss: 2.890 | Reg loss: 0.031 | Tree loss: 2.890 | Accuracy: 0.140625 | 0.064 sec/iter
Epoch: 6

Epoch: 69 | Batch: 013 / 026 | Total loss: 2.868 | Reg loss: 0.031 | Tree loss: 2.868 | Accuracy: 0.146484 | 0.064 sec/iter
Epoch: 69 | Batch: 014 / 026 | Total loss: 2.867 | Reg loss: 0.031 | Tree loss: 2.867 | Accuracy: 0.113281 | 0.064 sec/iter
Epoch: 69 | Batch: 015 / 026 | Total loss: 2.842 | Reg loss: 0.031 | Tree loss: 2.842 | Accuracy: 0.126953 | 0.064 sec/iter
Epoch: 69 | Batch: 016 / 026 | Total loss: 2.826 | Reg loss: 0.031 | Tree loss: 2.826 | Accuracy: 0.132812 | 0.064 sec/iter
Epoch: 69 | Batch: 017 / 026 | Total loss: 2.836 | Reg loss: 0.031 | Tree loss: 2.836 | Accuracy: 0.101562 | 0.064 sec/iter
Epoch: 69 | Batch: 018 / 026 | Total loss: 2.856 | Reg loss: 0.031 | Tree loss: 2.856 | Accuracy: 0.097656 | 0.064 sec/iter
Epoch: 69 | Batch: 019 / 026 | Total loss: 2.877 | Reg loss: 0.031 | Tree loss: 2.877 | Accuracy: 0.093750 | 0.064 sec/iter
Epoch: 69 | Batch: 020 / 026 | Total loss: 2.819 | Reg loss: 0.031 | Tree loss: 2.819 | Accuracy: 0.107422 | 0.064 sec/iter
Epoch: 6

Epoch: 72 | Batch: 001 / 026 | Total loss: 2.893 | Reg loss: 0.031 | Tree loss: 2.893 | Accuracy: 0.132812 | 0.063 sec/iter
Epoch: 72 | Batch: 002 / 026 | Total loss: 2.887 | Reg loss: 0.031 | Tree loss: 2.887 | Accuracy: 0.123047 | 0.063 sec/iter
Epoch: 72 | Batch: 003 / 026 | Total loss: 2.899 | Reg loss: 0.031 | Tree loss: 2.899 | Accuracy: 0.105469 | 0.063 sec/iter
Epoch: 72 | Batch: 004 / 026 | Total loss: 2.882 | Reg loss: 0.031 | Tree loss: 2.882 | Accuracy: 0.134766 | 0.063 sec/iter
Epoch: 72 | Batch: 005 / 026 | Total loss: 2.878 | Reg loss: 0.031 | Tree loss: 2.878 | Accuracy: 0.119141 | 0.063 sec/iter
Epoch: 72 | Batch: 006 / 026 | Total loss: 2.894 | Reg loss: 0.031 | Tree loss: 2.894 | Accuracy: 0.103516 | 0.063 sec/iter
Epoch: 72 | Batch: 007 / 026 | Total loss: 2.874 | Reg loss: 0.031 | Tree loss: 2.874 | Accuracy: 0.119141 | 0.063 sec/iter
Epoch: 72 | Batch: 008 / 026 | Total loss: 2.886 | Reg loss: 0.031 | Tree loss: 2.886 | Accuracy: 0.130859 | 0.063 sec/iter
Epoch: 7

Epoch: 74 | Batch: 013 / 026 | Total loss: 2.817 | Reg loss: 0.031 | Tree loss: 2.817 | Accuracy: 0.134766 | 0.063 sec/iter
Epoch: 74 | Batch: 014 / 026 | Total loss: 2.827 | Reg loss: 0.031 | Tree loss: 2.827 | Accuracy: 0.119141 | 0.063 sec/iter
Epoch: 74 | Batch: 015 / 026 | Total loss: 2.792 | Reg loss: 0.031 | Tree loss: 2.792 | Accuracy: 0.152344 | 0.063 sec/iter
Epoch: 74 | Batch: 016 / 026 | Total loss: 2.832 | Reg loss: 0.031 | Tree loss: 2.832 | Accuracy: 0.125000 | 0.063 sec/iter
Epoch: 74 | Batch: 017 / 026 | Total loss: 2.835 | Reg loss: 0.031 | Tree loss: 2.835 | Accuracy: 0.101562 | 0.063 sec/iter
Epoch: 74 | Batch: 018 / 026 | Total loss: 2.823 | Reg loss: 0.031 | Tree loss: 2.823 | Accuracy: 0.099609 | 0.063 sec/iter
Epoch: 74 | Batch: 019 / 026 | Total loss: 2.824 | Reg loss: 0.031 | Tree loss: 2.824 | Accuracy: 0.093750 | 0.063 sec/iter
Epoch: 74 | Batch: 020 / 026 | Total loss: 2.817 | Reg loss: 0.031 | Tree loss: 2.817 | Accuracy: 0.101562 | 0.063 sec/iter
Epoch: 7

Epoch: 77 | Batch: 001 / 026 | Total loss: 2.892 | Reg loss: 0.031 | Tree loss: 2.892 | Accuracy: 0.111328 | 0.063 sec/iter
Epoch: 77 | Batch: 002 / 026 | Total loss: 2.867 | Reg loss: 0.031 | Tree loss: 2.867 | Accuracy: 0.121094 | 0.063 sec/iter
Epoch: 77 | Batch: 003 / 026 | Total loss: 2.938 | Reg loss: 0.031 | Tree loss: 2.938 | Accuracy: 0.111328 | 0.063 sec/iter
Epoch: 77 | Batch: 004 / 026 | Total loss: 2.877 | Reg loss: 0.031 | Tree loss: 2.877 | Accuracy: 0.109375 | 0.063 sec/iter
Epoch: 77 | Batch: 005 / 026 | Total loss: 2.894 | Reg loss: 0.031 | Tree loss: 2.894 | Accuracy: 0.123047 | 0.063 sec/iter
Epoch: 77 | Batch: 006 / 026 | Total loss: 2.859 | Reg loss: 0.031 | Tree loss: 2.859 | Accuracy: 0.136719 | 0.063 sec/iter
Epoch: 77 | Batch: 007 / 026 | Total loss: 2.838 | Reg loss: 0.031 | Tree loss: 2.838 | Accuracy: 0.128906 | 0.063 sec/iter
Epoch: 77 | Batch: 008 / 026 | Total loss: 2.887 | Reg loss: 0.031 | Tree loss: 2.887 | Accuracy: 0.089844 | 0.063 sec/iter
Epoch: 7

Epoch: 79 | Batch: 014 / 026 | Total loss: 2.826 | Reg loss: 0.031 | Tree loss: 2.826 | Accuracy: 0.111328 | 0.063 sec/iter
Epoch: 79 | Batch: 015 / 026 | Total loss: 2.802 | Reg loss: 0.031 | Tree loss: 2.802 | Accuracy: 0.121094 | 0.063 sec/iter
Epoch: 79 | Batch: 016 / 026 | Total loss: 2.798 | Reg loss: 0.031 | Tree loss: 2.798 | Accuracy: 0.123047 | 0.063 sec/iter
Epoch: 79 | Batch: 017 / 026 | Total loss: 2.785 | Reg loss: 0.031 | Tree loss: 2.785 | Accuracy: 0.113281 | 0.063 sec/iter
Epoch: 79 | Batch: 018 / 026 | Total loss: 2.826 | Reg loss: 0.031 | Tree loss: 2.826 | Accuracy: 0.121094 | 0.063 sec/iter
Epoch: 79 | Batch: 019 / 026 | Total loss: 2.794 | Reg loss: 0.031 | Tree loss: 2.794 | Accuracy: 0.125000 | 0.063 sec/iter
Epoch: 79 | Batch: 020 / 026 | Total loss: 2.815 | Reg loss: 0.031 | Tree loss: 2.815 | Accuracy: 0.113281 | 0.063 sec/iter
Epoch: 79 | Batch: 021 / 026 | Total loss: 2.791 | Reg loss: 0.031 | Tree loss: 2.791 | Accuracy: 0.087891 | 0.063 sec/iter
Epoch: 7

Epoch: 82 | Batch: 003 / 026 | Total loss: 2.863 | Reg loss: 0.030 | Tree loss: 2.863 | Accuracy: 0.125000 | 0.063 sec/iter
Epoch: 82 | Batch: 004 / 026 | Total loss: 2.898 | Reg loss: 0.030 | Tree loss: 2.898 | Accuracy: 0.097656 | 0.063 sec/iter
Epoch: 82 | Batch: 005 / 026 | Total loss: 2.872 | Reg loss: 0.030 | Tree loss: 2.872 | Accuracy: 0.111328 | 0.063 sec/iter
Epoch: 82 | Batch: 006 / 026 | Total loss: 2.826 | Reg loss: 0.030 | Tree loss: 2.826 | Accuracy: 0.134766 | 0.063 sec/iter
Epoch: 82 | Batch: 007 / 026 | Total loss: 2.875 | Reg loss: 0.030 | Tree loss: 2.875 | Accuracy: 0.099609 | 0.063 sec/iter
Epoch: 82 | Batch: 008 / 026 | Total loss: 2.851 | Reg loss: 0.030 | Tree loss: 2.851 | Accuracy: 0.087891 | 0.063 sec/iter
Epoch: 82 | Batch: 009 / 026 | Total loss: 2.831 | Reg loss: 0.030 | Tree loss: 2.831 | Accuracy: 0.121094 | 0.063 sec/iter
Epoch: 82 | Batch: 010 / 026 | Total loss: 2.819 | Reg loss: 0.030 | Tree loss: 2.819 | Accuracy: 0.105469 | 0.063 sec/iter
Epoch: 8

Epoch: 84 | Batch: 015 / 026 | Total loss: 2.849 | Reg loss: 0.030 | Tree loss: 2.849 | Accuracy: 0.097656 | 0.063 sec/iter
Epoch: 84 | Batch: 016 / 026 | Total loss: 2.784 | Reg loss: 0.030 | Tree loss: 2.784 | Accuracy: 0.111328 | 0.063 sec/iter
Epoch: 84 | Batch: 017 / 026 | Total loss: 2.800 | Reg loss: 0.030 | Tree loss: 2.800 | Accuracy: 0.113281 | 0.063 sec/iter
Epoch: 84 | Batch: 018 / 026 | Total loss: 2.775 | Reg loss: 0.031 | Tree loss: 2.775 | Accuracy: 0.140625 | 0.063 sec/iter
Epoch: 84 | Batch: 019 / 026 | Total loss: 2.770 | Reg loss: 0.031 | Tree loss: 2.770 | Accuracy: 0.117188 | 0.063 sec/iter
Epoch: 84 | Batch: 020 / 026 | Total loss: 2.802 | Reg loss: 0.031 | Tree loss: 2.802 | Accuracy: 0.101562 | 0.063 sec/iter
Epoch: 84 | Batch: 021 / 026 | Total loss: 2.802 | Reg loss: 0.031 | Tree loss: 2.802 | Accuracy: 0.099609 | 0.063 sec/iter
Epoch: 84 | Batch: 022 / 026 | Total loss: 2.802 | Reg loss: 0.031 | Tree loss: 2.802 | Accuracy: 0.117188 | 0.063 sec/iter
Epoch: 8

Epoch: 87 | Batch: 000 / 026 | Total loss: 2.887 | Reg loss: 0.030 | Tree loss: 2.887 | Accuracy: 0.093750 | 0.062 sec/iter
Epoch: 87 | Batch: 001 / 026 | Total loss: 2.896 | Reg loss: 0.030 | Tree loss: 2.896 | Accuracy: 0.101562 | 0.062 sec/iter
Epoch: 87 | Batch: 002 / 026 | Total loss: 2.867 | Reg loss: 0.030 | Tree loss: 2.867 | Accuracy: 0.115234 | 0.062 sec/iter
Epoch: 87 | Batch: 003 / 026 | Total loss: 2.898 | Reg loss: 0.030 | Tree loss: 2.898 | Accuracy: 0.119141 | 0.062 sec/iter
Epoch: 87 | Batch: 004 / 026 | Total loss: 2.855 | Reg loss: 0.030 | Tree loss: 2.855 | Accuracy: 0.128906 | 0.062 sec/iter
Epoch: 87 | Batch: 005 / 026 | Total loss: 2.880 | Reg loss: 0.030 | Tree loss: 2.880 | Accuracy: 0.121094 | 0.062 sec/iter
Epoch: 87 | Batch: 006 / 026 | Total loss: 2.870 | Reg loss: 0.030 | Tree loss: 2.870 | Accuracy: 0.101562 | 0.062 sec/iter
Epoch: 87 | Batch: 007 / 026 | Total loss: 2.820 | Reg loss: 0.030 | Tree loss: 2.820 | Accuracy: 0.125000 | 0.062 sec/iter
Epoch: 8

Epoch: 89 | Batch: 012 / 026 | Total loss: 2.798 | Reg loss: 0.030 | Tree loss: 2.798 | Accuracy: 0.115234 | 0.062 sec/iter
Epoch: 89 | Batch: 013 / 026 | Total loss: 2.781 | Reg loss: 0.030 | Tree loss: 2.781 | Accuracy: 0.125000 | 0.062 sec/iter
Epoch: 89 | Batch: 014 / 026 | Total loss: 2.815 | Reg loss: 0.030 | Tree loss: 2.815 | Accuracy: 0.087891 | 0.062 sec/iter
Epoch: 89 | Batch: 015 / 026 | Total loss: 2.806 | Reg loss: 0.030 | Tree loss: 2.806 | Accuracy: 0.128906 | 0.062 sec/iter
Epoch: 89 | Batch: 016 / 026 | Total loss: 2.781 | Reg loss: 0.030 | Tree loss: 2.781 | Accuracy: 0.130859 | 0.062 sec/iter
Epoch: 89 | Batch: 017 / 026 | Total loss: 2.773 | Reg loss: 0.030 | Tree loss: 2.773 | Accuracy: 0.119141 | 0.062 sec/iter
Epoch: 89 | Batch: 018 / 026 | Total loss: 2.776 | Reg loss: 0.030 | Tree loss: 2.776 | Accuracy: 0.111328 | 0.062 sec/iter
Epoch: 89 | Batch: 019 / 026 | Total loss: 2.768 | Reg loss: 0.030 | Tree loss: 2.768 | Accuracy: 0.119141 | 0.062 sec/iter
Epoch: 8

Epoch: 91 | Batch: 024 / 026 | Total loss: 2.739 | Reg loss: 0.030 | Tree loss: 2.739 | Accuracy: 0.105469 | 0.062 sec/iter
Epoch: 91 | Batch: 025 / 026 | Total loss: 2.759 | Reg loss: 0.030 | Tree loss: 2.759 | Accuracy: 0.114206 | 0.062 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 92 | Batch: 000 / 026 | Total loss: 2.860 | Reg loss: 0.030 | Tree loss: 2.860 | Accuracy: 0.113281 | 0.062 sec/iter
Epoch: 92 | Batch: 001 / 026 | Total loss: 2.887 | Reg loss: 0.030 | Tree loss: 2.887 | Accuracy: 0.109375 | 0.062 sec/iter
Epoch: 92 | Batch: 002 / 026 | Total loss: 2.866 | Reg loss: 0.030 | Tree loss: 2.866 | Accuracy: 0.117188 | 0.062 sec/iter
Epoch: 92 | Batch: 003 / 026 | Total loss: 2.864 | Reg loss: 0.030 | Tree loss: 2.864 | Accuracy: 0.115234 | 0.062 sec/iter
Epoch: 92 | Batch: 004 / 026 | Total loss: 2.848 | Reg loss: 0.030 | Tree los

Epoch: 94 | Batch: 008 / 026 | Total loss: 2.822 | Reg loss: 0.030 | Tree loss: 2.822 | Accuracy: 0.113281 | 0.062 sec/iter
Epoch: 94 | Batch: 009 / 026 | Total loss: 2.834 | Reg loss: 0.030 | Tree loss: 2.834 | Accuracy: 0.113281 | 0.062 sec/iter
Epoch: 94 | Batch: 010 / 026 | Total loss: 2.846 | Reg loss: 0.030 | Tree loss: 2.846 | Accuracy: 0.091797 | 0.062 sec/iter
Epoch: 94 | Batch: 011 / 026 | Total loss: 2.802 | Reg loss: 0.030 | Tree loss: 2.802 | Accuracy: 0.105469 | 0.062 sec/iter
Epoch: 94 | Batch: 012 / 026 | Total loss: 2.798 | Reg loss: 0.030 | Tree loss: 2.798 | Accuracy: 0.121094 | 0.062 sec/iter
Epoch: 94 | Batch: 013 / 026 | Total loss: 2.815 | Reg loss: 0.030 | Tree loss: 2.815 | Accuracy: 0.134766 | 0.062 sec/iter
Epoch: 94 | Batch: 014 / 026 | Total loss: 2.782 | Reg loss: 0.030 | Tree loss: 2.782 | Accuracy: 0.119141 | 0.062 sec/iter
Epoch: 94 | Batch: 015 / 026 | Total loss: 2.802 | Reg loss: 0.030 | Tree loss: 2.802 | Accuracy: 0.103516 | 0.062 sec/iter
Epoch: 9

Epoch: 96 | Batch: 020 / 026 | Total loss: 2.780 | Reg loss: 0.030 | Tree loss: 2.780 | Accuracy: 0.080078 | 0.062 sec/iter
Epoch: 96 | Batch: 021 / 026 | Total loss: 2.761 | Reg loss: 0.030 | Tree loss: 2.761 | Accuracy: 0.111328 | 0.062 sec/iter
Epoch: 96 | Batch: 022 / 026 | Total loss: 2.787 | Reg loss: 0.030 | Tree loss: 2.787 | Accuracy: 0.115234 | 0.062 sec/iter
Epoch: 96 | Batch: 023 / 026 | Total loss: 2.765 | Reg loss: 0.030 | Tree loss: 2.765 | Accuracy: 0.109375 | 0.062 sec/iter
Epoch: 96 | Batch: 024 / 026 | Total loss: 2.738 | Reg loss: 0.030 | Tree loss: 2.738 | Accuracy: 0.132812 | 0.062 sec/iter
Epoch: 96 | Batch: 025 / 026 | Total loss: 2.745 | Reg loss: 0.030 | Tree loss: 2.745 | Accuracy: 0.144847 | 0.062 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 97 | Batch: 000 / 026 | Total loss: 2.866 | Reg loss: 0.030 | Tree los

Epoch: 99 | Batch: 005 / 026 | Total loss: 2.808 | Reg loss: 0.030 | Tree loss: 2.808 | Accuracy: 0.142578 | 0.062 sec/iter
Epoch: 99 | Batch: 006 / 026 | Total loss: 2.850 | Reg loss: 0.030 | Tree loss: 2.850 | Accuracy: 0.103516 | 0.062 sec/iter
Epoch: 99 | Batch: 007 / 026 | Total loss: 2.834 | Reg loss: 0.030 | Tree loss: 2.834 | Accuracy: 0.097656 | 0.062 sec/iter
Epoch: 99 | Batch: 008 / 026 | Total loss: 2.827 | Reg loss: 0.030 | Tree loss: 2.827 | Accuracy: 0.103516 | 0.062 sec/iter
Epoch: 99 | Batch: 009 / 026 | Total loss: 2.814 | Reg loss: 0.030 | Tree loss: 2.814 | Accuracy: 0.113281 | 0.062 sec/iter
Epoch: 99 | Batch: 010 / 026 | Total loss: 2.837 | Reg loss: 0.030 | Tree loss: 2.837 | Accuracy: 0.091797 | 0.062 sec/iter
Epoch: 99 | Batch: 011 / 026 | Total loss: 2.818 | Reg loss: 0.030 | Tree loss: 2.818 | Accuracy: 0.089844 | 0.062 sec/iter
Epoch: 99 | Batch: 012 / 026 | Total loss: 2.781 | Reg loss: 0.030 | Tree loss: 2.781 | Accuracy: 0.099609 | 0.062 sec/iter
Epoch: 9

In [33]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [34]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [35]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.595238095238095


# Extract Rules

# Accumulate samples in the leaves

In [40]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 42


In [41]:
method = 'greedy'

In [42]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [44]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

9744
3415
Average comprehensibility: 26.047619047619047
std comprehensibility: 3.2215966239205422
var comprehensibility: 10.378684807256237
minimum comprehensibility: 18
maximum comprehensibility: 30


  return np.log(1 / (1 - x))
