In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 16
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.131396293640137 | KNN Loss: 6.227344036102295 | BCE Loss: 1.9040518999099731
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.219286918640137 | KNN Loss: 6.227286338806152 | BCE Loss: 1.9920002222061157
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.189247131347656 | KNN Loss: 6.226543426513672 | BCE Loss: 1.9627039432525635
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.169183731079102 | KNN Loss: 6.226537704467773 | BCE Loss: 1.9426462650299072
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.15272331237793 | KNN Loss: 6.2260355949401855 | BCE Loss: 1.926687479019165
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.198753356933594 | KNN Loss: 6.225461959838867 | BCE Loss: 1.973291039466858
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.116584777832031 | KNN Loss: 6.225305557250977 | BCE Loss: 1.8912787437438965
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.173934936523438 | KNN Loss: 6.224647521972656 | BCE Loss: 1.9492878

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.735421180725098 | KNN Loss: 4.626675605773926 | BCE Loss: 1.1087453365325928
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.6172637939453125 | KNN Loss: 4.494812965393066 | BCE Loss: 1.122450590133667
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.558211326599121 | KNN Loss: 4.4204607009887695 | BCE Loss: 1.1377503871917725
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.433544158935547 | KNN Loss: 4.292762279510498 | BCE Loss: 1.140782117843628
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.3397369384765625 | KNN Loss: 4.232474327087402 | BCE Loss: 1.1072626113891602
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.2252984046936035 | KNN Loss: 4.119917392730713 | BCE Loss: 1.1053811311721802
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.177340507507324 | KNN Loss: 4.050661563873291 | BCE Loss: 1.126678705215454
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.075552940368652 | KNN Loss: 3.9711804389953613 | BCE Lo

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.31995964050293 | KNN Loss: 3.2258055210113525 | BCE Loss: 1.094153881072998
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.285663604736328 | KNN Loss: 3.239567518234253 | BCE Loss: 1.0460963249206543
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.283781051635742 | KNN Loss: 3.248910903930664 | BCE Loss: 1.034869909286499
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.315661907196045 | KNN Loss: 3.242034912109375 | BCE Loss: 1.0736271142959595
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.284772872924805 | KNN Loss: 3.226449728012085 | BCE Loss: 1.0583233833312988
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.303432464599609 | KNN Loss: 3.245673418045044 | BCE Loss: 1.0577588081359863
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.270933628082275 | KNN Loss: 3.180241346359253 | BCE Loss: 1.0906922817230225
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.238199234008789 | KNN Loss: 3.1943888664245605 | BCE Loss

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.13721227645874 | KNN Loss: 3.1251296997070312 | BCE Loss: 1.012082576751709
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.149568557739258 | KNN Loss: 3.1208372116088867 | BCE Loss: 1.0287314653396606
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.170516014099121 | KNN Loss: 3.142583131790161 | BCE Loss: 1.0279330015182495
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.163435935974121 | KNN Loss: 3.113342761993408 | BCE Loss: 1.0500929355621338
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.1312336921691895 | KNN Loss: 3.119642496109009 | BCE Loss: 1.0115911960601807
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.195626258850098 | KNN Loss: 3.148836612701416 | BCE Loss: 1.0467896461486816
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.19155216217041 | KNN Loss: 3.149794816970825 | BCE Loss: 1.041757583618164
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.164190292358398 | KNN Loss: 3.1522083282470703 | BCE Los

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.186305999755859 | KNN Loss: 3.139052391052246 | BCE Loss: 1.0472533702850342
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.175289154052734 | KNN Loss: 3.1487531661987305 | BCE Loss: 1.0265361070632935
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.121947288513184 | KNN Loss: 3.1192805767059326 | BCE Loss: 1.00266695022583
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.216286659240723 | KNN Loss: 3.1491591930389404 | BCE Loss: 1.0671277046203613
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.175373077392578 | KNN Loss: 3.1491129398345947 | BCE Loss: 1.0262600183486938
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.200582981109619 | KNN Loss: 3.1648108959198 | BCE Loss: 1.0357720851898193
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.13418436050415 | KNN Loss: 3.1326236724853516 | BCE Loss: 1.0015605688095093
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 4.181057929992676 | KNN Loss: 3.1439993381500244 | BCE Lo

Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 4.112297058105469 | KNN Loss: 3.101496934890747 | BCE Loss: 1.0108003616333008
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.187385082244873 | KNN Loss: 3.1383895874023438 | BCE Loss: 1.0489956140518188
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.11583948135376 | KNN Loss: 3.110107421875 | BCE Loss: 1.0057320594787598
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.141858100891113 | KNN Loss: 3.113483190536499 | BCE Loss: 1.0283749103546143
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.1421918869018555 | KNN Loss: 3.1321442127227783 | BCE Loss: 1.0100479125976562
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.166543006896973 | KNN Loss: 3.104938507080078 | BCE Loss: 1.0616044998168945
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.148497581481934 | KNN Loss: 3.126340627670288 | BCE Loss: 1.0221569538116455
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.132956027984619 | KNN Loss: 3.110224962234497 | BCE Loss: 

Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 4.137509822845459 | KNN Loss: 3.108593225479126 | BCE Loss: 1.0289167165756226
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.164376258850098 | KNN Loss: 3.126413345336914 | BCE Loss: 1.0379631519317627
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.120737075805664 | KNN Loss: 3.098545789718628 | BCE Loss: 1.022191047668457
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.113468170166016 | KNN Loss: 3.087542772293091 | BCE Loss: 1.0259253978729248
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.144238471984863 | KNN Loss: 3.13447642326355 | BCE Loss: 1.0097622871398926
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.091134548187256 | KNN Loss: 3.071777105331421 | BCE Loss: 1.0193575620651245
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.081925392150879 | KNN Loss: 3.073878049850464 | BCE Loss: 1.0080475807189941
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 4.124319076538086 | KNN Loss: 3.097728967666626 | BCE Loss: 

Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 4.102123260498047 | KNN Loss: 3.0964179039001465 | BCE Loss: 1.0057055950164795
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.149783134460449 | KNN Loss: 3.1033875942230225 | BCE Loss: 1.0463953018188477
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.123112201690674 | KNN Loss: 3.090872287750244 | BCE Loss: 1.0322397947311401
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.051721096038818 | KNN Loss: 3.0492944717407227 | BCE Loss: 1.0024265050888062
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.124121189117432 | KNN Loss: 3.087127208709717 | BCE Loss: 1.0369940996170044
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.096490859985352 | KNN Loss: 3.0937042236328125 | BCE Loss: 1.00278639793396
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.122089385986328 | KNN Loss: 3.1060736179351807 | BCE Loss: 1.0160157680511475
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.123082160949707 | KNN Loss: 3.0880813598632812 | BCE

Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 4.119910717010498 | KNN Loss: 3.100619316101074 | BCE Loss: 1.0192914009094238
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.1505584716796875 | KNN Loss: 3.093820095062256 | BCE Loss: 1.0567383766174316
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.106929302215576 | KNN Loss: 3.075411081314087 | BCE Loss: 1.0315182209014893
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.141281604766846 | KNN Loss: 3.1092159748077393 | BCE Loss: 1.0320656299591064
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.099991798400879 | KNN Loss: 3.063936233520508 | BCE Loss: 1.0360554456710815
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.088787078857422 | KNN Loss: 3.0723607540130615 | BCE Loss: 1.0164262056350708
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.074583530426025 | KNN Loss: 3.0688979625701904 | BCE Loss: 1.005685567855835
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 4.162245750427246 | KNN Loss: 3.1315207481384277 | BCE

Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 4.122441291809082 | KNN Loss: 3.077540397644043 | BCE Loss: 1.0449007749557495
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.1388678550720215 | KNN Loss: 3.0992136001586914 | BCE Loss: 1.0396541357040405
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.140501976013184 | KNN Loss: 3.0726561546325684 | BCE Loss: 1.0678460597991943
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.123679161071777 | KNN Loss: 3.110231637954712 | BCE Loss: 1.0134477615356445
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.1192121505737305 | KNN Loss: 3.0950584411621094 | BCE Loss: 1.0241539478302002
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.102678298950195 | KNN Loss: 3.0854270458221436 | BCE Loss: 1.0172513723373413
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.08939266204834 | KNN Loss: 3.084026575088501 | BCE Loss: 1.0053658485412598
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.1371002197265625 | KNN Loss: 3.089794635772705 | BCE

Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 4.1084113121032715 | KNN Loss: 3.070949077606201 | BCE Loss: 1.0374621152877808
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.055578708648682 | KNN Loss: 3.0447654724121094 | BCE Loss: 1.0108132362365723
Epoch   108: reducing learning rate of group 0 to 1.2005e-03.
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.096551895141602 | KNN Loss: 3.067798614501953 | BCE Loss: 1.0287530422210693
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.086984157562256 | KNN Loss: 3.0695960521698 | BCE Loss: 1.017388105392456
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.115662574768066 | KNN Loss: 3.079214572906494 | BCE Loss: 1.0364482402801514
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.149948596954346 | KNN Loss: 3.109612226486206 | BCE Loss: 1.0403363704681396
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.130278587341309 | KNN Loss: 3.078092098236084 | BCE Loss: 1.052186369895935
Epoch 108 / 500 | iteration 25 / 30 | T

Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 4.115325927734375 | KNN Loss: 3.0848610401153564 | BCE Loss: 1.0304651260375977
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.088039398193359 | KNN Loss: 3.0564053058624268 | BCE Loss: 1.0316338539123535
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.080898284912109 | KNN Loss: 3.0773017406463623 | BCE Loss: 1.003596544265747
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.082291603088379 | KNN Loss: 3.0610828399658203 | BCE Loss: 1.0212087631225586
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.174953460693359 | KNN Loss: 3.1253714561462402 | BCE Loss: 1.04958176612854
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.093533515930176 | KNN Loss: 3.0718834400177 | BCE Loss: 1.0216498374938965
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.093579292297363 | KNN Loss: 3.0596117973327637 | BCE Loss: 1.0339672565460205
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.072288990020752 | KNN Loss: 3.068572998046875 

Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 4.124573707580566 | KNN Loss: 3.1080117225646973 | BCE Loss: 1.0165618658065796
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.068890571594238 | KNN Loss: 3.0575110912323 | BCE Loss: 1.0113792419433594
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.150649547576904 | KNN Loss: 3.103487253189087 | BCE Loss: 1.0471622943878174
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.114579677581787 | KNN Loss: 3.076709508895874 | BCE Loss: 1.037870168685913
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.137100696563721 | KNN Loss: 3.0883638858795166 | BCE Loss: 1.048736810684204
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.103177547454834 | KNN Loss: 3.1135854721069336 | BCE Loss: 0.9895921945571899
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.092764854431152 | KNN Loss: 3.063499689102173 | BCE Loss: 1.0292654037475586
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.114772796630859 | KNN Loss: 3.0711395740509033 | B

Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 4.089846134185791 | KNN Loss: 3.065338373184204 | BCE Loss: 1.0245076417922974
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.053929328918457 | KNN Loss: 3.0428435802459717 | BCE Loss: 1.0110857486724854
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.047557830810547 | KNN Loss: 3.0383076667785645 | BCE Loss: 1.0092504024505615
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.083378791809082 | KNN Loss: 3.074779510498047 | BCE Loss: 1.0085991621017456
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.130561351776123 | KNN Loss: 3.094024419784546 | BCE Loss: 1.0365370512008667
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.155498027801514 | KNN Loss: 3.086491346359253 | BCE Loss: 1.0690068006515503
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.060441970825195 | KNN Loss: 3.0615506172180176 | BCE Loss: 0.9988915920257568
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.061905384063721 | KNN Loss: 3.07179450988769

Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 4.084072113037109 | KNN Loss: 3.0746548175811768 | BCE Loss: 1.0094174146652222
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.096889972686768 | KNN Loss: 3.0281858444213867 | BCE Loss: 1.0687041282653809
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.092436790466309 | KNN Loss: 3.0867021083831787 | BCE Loss: 1.0057344436645508
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.099284648895264 | KNN Loss: 3.0897903442382812 | BCE Loss: 1.0094943046569824
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.085549354553223 | KNN Loss: 3.0559909343719482 | BCE Loss: 1.0295581817626953
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.068517684936523 | KNN Loss: 3.057922840118408 | BCE Loss: 1.0105946063995361
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.104161739349365 | KNN Loss: 3.0666351318359375 | BCE Loss: 1.0375266075134277
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.098696708679199 | KNN Loss: 3.10074353218

Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 4.118847846984863 | KNN Loss: 3.0953078269958496 | BCE Loss: 1.0235401391983032
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 4.104015350341797 | KNN Loss: 3.065640926361084 | BCE Loss: 1.0383741855621338
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.121308326721191 | KNN Loss: 3.08447265625 | BCE Loss: 1.036835789680481
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.085082054138184 | KNN Loss: 3.0700488090515137 | BCE Loss: 1.01503324508667
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.08622932434082 | KNN Loss: 3.05708646774292 | BCE Loss: 1.02914297580719
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.074995994567871 | KNN Loss: 3.0457217693328857 | BCE Loss: 1.0292744636535645
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.192248344421387 | KNN Loss: 3.115696430206299 | BCE Loss: 1.0765516757965088
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.0836076736450195 | KNN Loss: 3.072368621826172 | BCE Los

Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 4.074243545532227 | KNN Loss: 3.054414749145508 | BCE Loss: 1.0198285579681396
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 4.090884208679199 | KNN Loss: 3.070280075073242 | BCE Loss: 1.020604133605957
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.080811023712158 | KNN Loss: 3.0609686374664307 | BCE Loss: 1.0198423862457275
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.056197643280029 | KNN Loss: 3.0462186336517334 | BCE Loss: 1.0099788904190063
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.092959403991699 | KNN Loss: 3.0743303298950195 | BCE Loss: 1.0186288356781006
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.103797435760498 | KNN Loss: 3.069753885269165 | BCE Loss: 1.034043550491333
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.095306873321533 | KNN Loss: 3.078124523162842 | BCE Loss: 1.0171823501586914
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.088261604309082 | KNN Loss: 3.0722692012786865

Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 4.134346961975098 | KNN Loss: 3.089235544204712 | BCE Loss: 1.0451116561889648
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 4.103343963623047 | KNN Loss: 3.10176157951355 | BCE Loss: 1.001582384109497
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.095044136047363 | KNN Loss: 3.078857421875 | BCE Loss: 1.0161864757537842
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.05525016784668 | KNN Loss: 3.0586464405059814 | BCE Loss: 0.9966038465499878
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.088606834411621 | KNN Loss: 3.0439577102661133 | BCE Loss: 1.0446490049362183
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.073754787445068 | KNN Loss: 3.063962697982788 | BCE Loss: 1.0097920894622803
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.089110374450684 | KNN Loss: 3.0878055095672607 | BCE Loss: 1.001305103302002
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 4.104755401611328 | KNN Loss: 3.103487730026245 | BCE 

Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 4.064971923828125 | KNN Loss: 3.038792133331299 | BCE Loss: 1.026179552078247
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 4.052854061126709 | KNN Loss: 3.080481767654419 | BCE Loss: 0.9723724722862244
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.069075584411621 | KNN Loss: 3.0547313690185547 | BCE Loss: 1.0143440961837769
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.0598249435424805 | KNN Loss: 3.02384614944458 | BCE Loss: 1.0359790325164795
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.0571770668029785 | KNN Loss: 3.055776596069336 | BCE Loss: 1.0014004707336426
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.059244632720947 | KNN Loss: 3.03291654586792 | BCE Loss: 1.0263279676437378
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.12431001663208 | KNN Loss: 3.0812485218048096 | BCE Loss: 1.043061375617981
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 4.089166641235352 | KNN Loss: 3.0641841888427734 | 

Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 4.107048511505127 | KNN Loss: 3.0775537490844727 | BCE Loss: 1.0294947624206543
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 4.081412315368652 | KNN Loss: 3.066655397415161 | BCE Loss: 1.014756679534912
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.125693321228027 | KNN Loss: 3.0944645404815674 | BCE Loss: 1.03122878074646
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.076618194580078 | KNN Loss: 3.051234483718872 | BCE Loss: 1.025383710861206
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.080120086669922 | KNN Loss: 3.063258647918701 | BCE Loss: 1.0168616771697998
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.060517311096191 | KNN Loss: 3.029059648513794 | BCE Loss: 1.0314574241638184
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.064132213592529 | KNN Loss: 3.0640344619750977 | BCE Loss: 1.0000977516174316
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 4.111478805541992 | KNN Loss: 3.086312770843506 | 

Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 4.120598316192627 | KNN Loss: 3.0905704498291016 | BCE Loss: 1.0300278663635254
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 4.145369529724121 | KNN Loss: 3.096073627471924 | BCE Loss: 1.0492961406707764
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.148062705993652 | KNN Loss: 3.098355770111084 | BCE Loss: 1.0497069358825684
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.086316108703613 | KNN Loss: 3.0515215396881104 | BCE Loss: 1.0347943305969238
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.063759803771973 | KNN Loss: 3.067037582397461 | BCE Loss: 0.9967223405838013
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.120593547821045 | KNN Loss: 3.0807435512542725 | BCE Loss: 1.039849877357483
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 4.12528133392334 | KNN Loss: 3.061448574066162 | BCE Loss: 1.0638326406478882
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 4.0742506980896 | KNN Loss: 3.0651090145111084 | 

Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 4.076581001281738 | KNN Loss: 3.052670478820801 | BCE Loss: 1.0239107608795166
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 4.104897499084473 | KNN Loss: 3.0452442169189453 | BCE Loss: 1.059653401374817
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.108001708984375 | KNN Loss: 3.0808987617492676 | BCE Loss: 1.0271027088165283
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.116347312927246 | KNN Loss: 3.088130235671997 | BCE Loss: 1.028217077255249
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.117647647857666 | KNN Loss: 3.0619256496429443 | BCE Loss: 1.0557218790054321
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.098866939544678 | KNN Loss: 3.079468250274658 | BCE Loss: 1.019398808479309
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 4.090507507324219 | KNN Loss: 3.0597829818725586 | BCE Loss: 1.0307246446609497
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 4.041130542755127 | KNN Loss: 3.0301756858825684 

Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 4.11304235458374 | KNN Loss: 3.079608678817749 | BCE Loss: 1.0334336757659912
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.084815979003906 | KNN Loss: 3.0672075748443604 | BCE Loss: 1.0176081657409668
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.072519779205322 | KNN Loss: 3.0584123134613037 | BCE Loss: 1.0141074657440186
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.111647605895996 | KNN Loss: 3.061455488204956 | BCE Loss: 1.0501922369003296
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.1021270751953125 | KNN Loss: 3.0939254760742188 | BCE Loss: 1.0082013607025146
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.0620527267456055 | KNN Loss: 3.042569398880005 | BCE Loss: 1.0194833278656006
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 4.043419361114502 | KNN Loss: 3.017348051071167 | BCE Loss: 1.026071310043335
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 4.126520156860352 | KNN Loss: 3.08962655067443

Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 4.092326641082764 | KNN Loss: 3.0443007946014404 | BCE Loss: 1.0480258464813232
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.145635604858398 | KNN Loss: 3.0890908241271973 | BCE Loss: 1.0565450191497803
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.061256408691406 | KNN Loss: 3.0601353645324707 | BCE Loss: 1.0011212825775146
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.0957112312316895 | KNN Loss: 3.0801279544830322 | BCE Loss: 1.0155832767486572
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.0715179443359375 | KNN Loss: 3.0436267852783203 | BCE Loss: 1.027890920639038
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.108040809631348 | KNN Loss: 3.070302963256836 | BCE Loss: 1.0377378463745117
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 4.092835903167725 | KNN Loss: 3.0975546836853027 | BCE Loss: 0.9952813386917114
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 4.098410129547119 | KNN Loss: 3.09741425514

Epoch 256 / 500 | iteration 20 / 30 | Total Loss: 4.044073581695557 | KNN Loss: 3.020670175552368 | BCE Loss: 1.0234034061431885
Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 4.086979866027832 | KNN Loss: 3.079332113265991 | BCE Loss: 1.0076477527618408
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.136874198913574 | KNN Loss: 3.0958328247070312 | BCE Loss: 1.0410414934158325
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.081936836242676 | KNN Loss: 3.065920114517212 | BCE Loss: 1.0160164833068848
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.086392402648926 | KNN Loss: 3.0447659492492676 | BCE Loss: 1.0416263341903687
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.103999614715576 | KNN Loss: 3.085829019546509 | BCE Loss: 1.0181704759597778
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.105611324310303 | KNN Loss: 3.0609638690948486 | BCE Loss: 1.044647455215454
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 4.081439971923828 | KNN Loss: 3.070415735244751

Epoch 267 / 500 | iteration 10 / 30 | Total Loss: 4.09248685836792 | KNN Loss: 3.067960739135742 | BCE Loss: 1.0245261192321777
Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 4.074141025543213 | KNN Loss: 3.0553879737854004 | BCE Loss: 1.0187530517578125
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.124065399169922 | KNN Loss: 3.1013076305389404 | BCE Loss: 1.0227575302124023
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.072769641876221 | KNN Loss: 3.0630996227264404 | BCE Loss: 1.0096700191497803
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.090710639953613 | KNN Loss: 3.0787253379821777 | BCE Loss: 1.0119853019714355
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.07977819442749 | KNN Loss: 3.0427346229553223 | BCE Loss: 1.037043571472168
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.120046138763428 | KNN Loss: 3.095799207687378 | BCE Loss: 1.0242469310760498
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 4.0994720458984375 | KNN Loss: 3.07459115982055

Epoch 278 / 500 | iteration 0 / 30 | Total Loss: 4.0761518478393555 | KNN Loss: 3.080254554748535 | BCE Loss: 0.9958970546722412
Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 4.099803924560547 | KNN Loss: 3.07232928276062 | BCE Loss: 1.0274745225906372
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.094264030456543 | KNN Loss: 3.070964813232422 | BCE Loss: 1.0232994556427002
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.142199516296387 | KNN Loss: 3.0860135555267334 | BCE Loss: 1.0561859607696533
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.100869178771973 | KNN Loss: 3.0912442207336426 | BCE Loss: 1.0096250772476196
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.159628868103027 | KNN Loss: 3.0965328216552734 | BCE Loss: 1.0630958080291748
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.076394081115723 | KNN Loss: 3.064857006072998 | BCE Loss: 1.0115370750427246
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 4.083259582519531 | KNN Loss: 3.068960428237915 

Epoch 288 / 500 | iteration 20 / 30 | Total Loss: 4.085001468658447 | KNN Loss: 3.0719473361968994 | BCE Loss: 1.0130540132522583
Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 4.057371139526367 | KNN Loss: 3.04414701461792 | BCE Loss: 1.0132241249084473
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.0879926681518555 | KNN Loss: 3.0583269596099854 | BCE Loss: 1.0296659469604492
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.10365629196167 | KNN Loss: 3.0687003135681152 | BCE Loss: 1.0349558591842651
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.172995567321777 | KNN Loss: 3.1112582683563232 | BCE Loss: 1.061737060546875
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.088571071624756 | KNN Loss: 3.081333875656128 | BCE Loss: 1.007237195968628
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.082385063171387 | KNN Loss: 3.071139335632324 | BCE Loss: 1.0112459659576416
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 4.087089538574219 | KNN Loss: 3.076913595199585 

Epoch 299 / 500 | iteration 10 / 30 | Total Loss: 4.113282680511475 | KNN Loss: 3.076467275619507 | BCE Loss: 1.0368152856826782
Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 4.028761386871338 | KNN Loss: 3.033860445022583 | BCE Loss: 0.9949008226394653
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.061757564544678 | KNN Loss: 3.053492784500122 | BCE Loss: 1.0082647800445557
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.099350929260254 | KNN Loss: 3.045837879180908 | BCE Loss: 1.0535131692886353
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.06881046295166 | KNN Loss: 3.0591232776641846 | BCE Loss: 1.009687066078186
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.076324462890625 | KNN Loss: 3.0659828186035156 | BCE Loss: 1.0103414058685303
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.109277248382568 | KNN Loss: 3.0929689407348633 | BCE Loss: 1.016308307647705
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 4.06688928604126 | KNN Loss: 3.054436206817627 | 

Epoch 310 / 500 | iteration 0 / 30 | Total Loss: 4.052274703979492 | KNN Loss: 3.04594087600708 | BCE Loss: 1.006333589553833
Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 4.08652925491333 | KNN Loss: 3.0514590740203857 | BCE Loss: 1.0350701808929443
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.111854553222656 | KNN Loss: 3.0568714141845703 | BCE Loss: 1.054983377456665
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.084563732147217 | KNN Loss: 3.060687780380249 | BCE Loss: 1.0238760709762573
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.144804954528809 | KNN Loss: 3.0996556282043457 | BCE Loss: 1.045149564743042
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.053451061248779 | KNN Loss: 3.081650495529175 | BCE Loss: 0.9718003869056702
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.079009056091309 | KNN Loss: 3.051497220993042 | BCE Loss: 1.0275115966796875
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 4.118320941925049 | KNN Loss: 3.078925132751465 | BCE

Epoch 320 / 500 | iteration 20 / 30 | Total Loss: 4.105191230773926 | KNN Loss: 3.1009697914123535 | BCE Loss: 1.0042213201522827
Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 4.101876735687256 | KNN Loss: 3.0505573749542236 | BCE Loss: 1.0513193607330322
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.114068984985352 | KNN Loss: 3.10219407081604 | BCE Loss: 1.0118751525878906
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.086054801940918 | KNN Loss: 3.0530011653900146 | BCE Loss: 1.0330535173416138
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.115144729614258 | KNN Loss: 3.082744836807251 | BCE Loss: 1.032400131225586
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.095075607299805 | KNN Loss: 3.044609308242798 | BCE Loss: 1.0504661798477173
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.0922651290893555 | KNN Loss: 3.0420875549316406 | BCE Loss: 1.0501773357391357
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 4.108036994934082 | KNN Loss: 3.08954477310180

Epoch 331 / 500 | iteration 10 / 30 | Total Loss: 4.063437461853027 | KNN Loss: 3.0654261112213135 | BCE Loss: 0.9980112314224243
Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.056706428527832 | KNN Loss: 3.037548065185547 | BCE Loss: 1.0191584825515747
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.1034955978393555 | KNN Loss: 3.0682175159454346 | BCE Loss: 1.035278081893921
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.107128143310547 | KNN Loss: 3.076261281967163 | BCE Loss: 1.0308668613433838
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.061897277832031 | KNN Loss: 3.035414934158325 | BCE Loss: 1.0264822244644165
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.095222473144531 | KNN Loss: 3.063845634460449 | BCE Loss: 1.031376838684082
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.0807976722717285 | KNN Loss: 3.0590994358062744 | BCE Loss: 1.0216981172561646
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 4.1414971351623535 | KNN Loss: 3.1070537567138

Epoch 342 / 500 | iteration 0 / 30 | Total Loss: 4.050844669342041 | KNN Loss: 3.044635057449341 | BCE Loss: 1.0062096118927002
Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.074559688568115 | KNN Loss: 3.0337564945220947 | BCE Loss: 1.04080331325531
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.130638599395752 | KNN Loss: 3.078523635864258 | BCE Loss: 1.0521150827407837
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.038408279418945 | KNN Loss: 3.057567596435547 | BCE Loss: 0.9808404445648193
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.0691142082214355 | KNN Loss: 3.0593366622924805 | BCE Loss: 1.0097774267196655
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.109340667724609 | KNN Loss: 3.075244188308716 | BCE Loss: 1.0340962409973145
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 4.058300018310547 | KNN Loss: 3.0556116104125977 | BCE Loss: 1.0026881694793701
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 4.112673759460449 | KNN Loss: 3.0901453495025635 

Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.095426082611084 | KNN Loss: 3.068218231201172 | BCE Loss: 1.0272077322006226
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.061656951904297 | KNN Loss: 3.052975654602051 | BCE Loss: 1.008681297302246
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.066148281097412 | KNN Loss: 3.057011127471924 | BCE Loss: 1.0091372728347778
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.118610382080078 | KNN Loss: 3.0762710571289062 | BCE Loss: 1.0423392057418823
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.117094039916992 | KNN Loss: 3.0718977451324463 | BCE Loss: 1.045196294784546
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 4.107043266296387 | KNN Loss: 3.105374813079834 | BCE Loss: 1.0016683340072632
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 4.0703887939453125 | KNN Loss: 3.075707197189331 | BCE Loss: 0.9946817755699158
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 4.073463439941406 | KNN Loss: 3.07967209815979 |

Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.097957611083984 | KNN Loss: 3.0697274208068848 | BCE Loss: 1.0282303094863892
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.036758899688721 | KNN Loss: 3.024502992630005 | BCE Loss: 1.0122559070587158
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.057812690734863 | KNN Loss: 3.05973744392395 | BCE Loss: 0.9980754852294922
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.093024253845215 | KNN Loss: 3.085183620452881 | BCE Loss: 1.007840871810913
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.045061111450195 | KNN Loss: 3.0348076820373535 | BCE Loss: 1.0102531909942627
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 4.059871196746826 | KNN Loss: 3.0367536544799805 | BCE Loss: 1.0231174230575562
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 4.057836532592773 | KNN Loss: 3.048776388168335 | BCE Loss: 1.0090601444244385
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 4.067967414855957 | KNN Loss: 3.0481982231140137

Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.09586238861084 | KNN Loss: 3.044299840927124 | BCE Loss: 1.0515625476837158
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.114771842956543 | KNN Loss: 3.099344491958618 | BCE Loss: 1.0154271125793457
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.0606184005737305 | KNN Loss: 3.047515392303467 | BCE Loss: 1.0131032466888428
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.095880508422852 | KNN Loss: 3.042097568511963 | BCE Loss: 1.0537828207015991
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.054931163787842 | KNN Loss: 3.0412046909332275 | BCE Loss: 1.0137264728546143
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 4.100155353546143 | KNN Loss: 3.0954654216766357 | BCE Loss: 1.0046899318695068
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 4.071311950683594 | KNN Loss: 3.040191411972046 | BCE Loss: 1.0311205387115479
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 4.074844837188721 | KNN Loss: 3.06019926071167 | 

Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.104311943054199 | KNN Loss: 3.065690040588379 | BCE Loss: 1.0386216640472412
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.118304252624512 | KNN Loss: 3.0907137393951416 | BCE Loss: 1.0275905132293701
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.126136779785156 | KNN Loss: 3.0589377880096436 | BCE Loss: 1.0671991109848022
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.075124740600586 | KNN Loss: 3.0616018772125244 | BCE Loss: 1.0135226249694824
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.078861236572266 | KNN Loss: 3.062480926513672 | BCE Loss: 1.0163803100585938
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 4.083714962005615 | KNN Loss: 3.0681660175323486 | BCE Loss: 1.015548825263977
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 4.084773540496826 | KNN Loss: 3.065049648284912 | BCE Loss: 1.0197240114212036
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 4.093136310577393 | KNN Loss: 3.07685375213623

Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.08995246887207 | KNN Loss: 3.0600340366363525 | BCE Loss: 1.0299181938171387
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.074589729309082 | KNN Loss: 3.059720754623413 | BCE Loss: 1.0148688554763794
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.126965522766113 | KNN Loss: 3.071751832962036 | BCE Loss: 1.055213451385498
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.087882041931152 | KNN Loss: 3.066810369491577 | BCE Loss: 1.0210715532302856
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.05764102935791 | KNN Loss: 3.049177646636963 | BCE Loss: 1.0084635019302368
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.09837532043457 | KNN Loss: 3.0724523067474365 | BCE Loss: 1.025923252105713
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 4.100473880767822 | KNN Loss: 3.0920450687408447 | BCE Loss: 1.0084288120269775
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 4.043381690979004 | KNN Loss: 3.041574716567993 | BC

Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.067869186401367 | KNN Loss: 3.048295497894287 | BCE Loss: 1.019573450088501
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.11198091506958 | KNN Loss: 3.087237596511841 | BCE Loss: 1.0247431993484497
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.202536106109619 | KNN Loss: 3.15704607963562 | BCE Loss: 1.0454899072647095
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.065980434417725 | KNN Loss: 3.0564651489257812 | BCE Loss: 1.0095152854919434
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.078166961669922 | KNN Loss: 3.0687882900238037 | BCE Loss: 1.0093787908554077
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.106908798217773 | KNN Loss: 3.0565907955169678 | BCE Loss: 1.0503180027008057
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 4.096429824829102 | KNN Loss: 3.082143783569336 | BCE Loss: 1.0142862796783447
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 4.168573379516602 | KNN Loss: 3.0847578048706055 |

Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.120395660400391 | KNN Loss: 3.1064107418060303 | BCE Loss: 1.0139851570129395
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.084008693695068 | KNN Loss: 3.049210548400879 | BCE Loss: 1.0347980260849
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.0950822830200195 | KNN Loss: 3.098548650741577 | BCE Loss: 0.9965334534645081
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.077378749847412 | KNN Loss: 3.0431835651397705 | BCE Loss: 1.0341951847076416
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.0467729568481445 | KNN Loss: 3.040727138519287 | BCE Loss: 1.0060458183288574
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.101114273071289 | KNN Loss: 3.074624538421631 | BCE Loss: 1.0264899730682373
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 4.156754016876221 | KNN Loss: 3.1156563758850098 | BCE Loss: 1.0410975217819214
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 4.077102184295654 | KNN Loss: 3.054123878479004

Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.063459396362305 | KNN Loss: 3.0340020656585693 | BCE Loss: 1.0294570922851562
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.106435775756836 | KNN Loss: 3.0735971927642822 | BCE Loss: 1.0328385829925537
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.092832565307617 | KNN Loss: 3.0779454708099365 | BCE Loss: 1.0148870944976807
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.1314377784729 | KNN Loss: 3.1037068367004395 | BCE Loss: 1.027730941772461
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.132693290710449 | KNN Loss: 3.099058151245117 | BCE Loss: 1.033634901046753
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.119777679443359 | KNN Loss: 3.095893621444702 | BCE Loss: 1.0238842964172363
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 4.090920925140381 | KNN Loss: 3.0997676849365234 | BCE Loss: 0.9911531805992126
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 4.113251209259033 | KNN Loss: 3.081963300704956 |

Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.0708794593811035 | KNN Loss: 3.041016101837158 | BCE Loss: 1.0298633575439453
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.063405990600586 | KNN Loss: 3.059786558151245 | BCE Loss: 1.0036191940307617
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.121001243591309 | KNN Loss: 3.0869781970977783 | BCE Loss: 1.0340228080749512
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.089197635650635 | KNN Loss: 3.061354875564575 | BCE Loss: 1.02784264087677
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.094061374664307 | KNN Loss: 3.0630154609680176 | BCE Loss: 1.031045913696289
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.052301406860352 | KNN Loss: 3.0573949813842773 | BCE Loss: 0.9949064254760742
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 4.104227066040039 | KNN Loss: 3.0930612087249756 | BCE Loss: 1.0111660957336426
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 4.105855941772461 | KNN Loss: 3.0757195949554443

Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.089088439941406 | KNN Loss: 3.071681499481201 | BCE Loss: 1.0174068212509155
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.125425338745117 | KNN Loss: 3.0975501537323 | BCE Loss: 1.027875304222107
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.044632434844971 | KNN Loss: 3.066856861114502 | BCE Loss: 0.9777754545211792
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.0968170166015625 | KNN Loss: 3.0623044967651367 | BCE Loss: 1.0345127582550049
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.073984146118164 | KNN Loss: 3.0623459815979004 | BCE Loss: 1.0116382837295532
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 4.081967830657959 | KNN Loss: 3.0752947330474854 | BCE Loss: 1.0066732168197632
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 4.1162848472595215 | KNN Loss: 3.1017158031463623 | BCE Loss: 1.0145691633224487
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 4.095612049102783 | KNN Loss: 3.08211660385131

Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.108442783355713 | KNN Loss: 3.052743434906006 | BCE Loss: 1.055699348449707
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.067770004272461 | KNN Loss: 3.056199073791504 | BCE Loss: 1.011570930480957
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.112699508666992 | KNN Loss: 3.0618655681610107 | BCE Loss: 1.0508339405059814
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.136714458465576 | KNN Loss: 3.108010768890381 | BCE Loss: 1.0287036895751953
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.08895206451416 | KNN Loss: 3.0643649101257324 | BCE Loss: 1.0245869159698486
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 4.0924224853515625 | KNN Loss: 3.0653116703033447 | BCE Loss: 1.0271108150482178
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 4.111757755279541 | KNN Loss: 3.0552256107330322 | BCE Loss: 1.0565322637557983
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 4.106442451477051 | KNN Loss: 3.0806546211242676

Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.0749053955078125 | KNN Loss: 3.053128242492676 | BCE Loss: 1.0217769145965576
Epoch   470: reducing learning rate of group 0 to 2.7058e-08.
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.154278755187988 | KNN Loss: 3.1213719844818115 | BCE Loss: 1.0329068899154663
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.072127819061279 | KNN Loss: 3.065067768096924 | BCE Loss: 1.0070600509643555
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.122323036193848 | KNN Loss: 3.084078550338745 | BCE Loss: 1.038244605064392
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.089859962463379 | KNN Loss: 3.0588338375091553 | BCE Loss: 1.0310263633728027
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 4.087821006774902 | KNN Loss: 3.0655465126037598 | BCE Loss: 1.0222746133804321
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 4.069508075714111 | KNN Loss: 3.0296149253845215 | BCE Loss: 1.0398931503295898
Epoch 471 / 500 | iteration 0 / 3

Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.045282363891602 | KNN Loss: 3.0295751094818115 | BCE Loss: 1.0157071352005005
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.100578308105469 | KNN Loss: 3.0946602821350098 | BCE Loss: 1.0059177875518799
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.094522953033447 | KNN Loss: 3.0792617797851562 | BCE Loss: 1.015261173248291
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.060585021972656 | KNN Loss: 3.048348903656006 | BCE Loss: 1.0122363567352295
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.15877103805542 | KNN Loss: 3.103651523590088 | BCE Loss: 1.055119514465332
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 4.100108623504639 | KNN Loss: 3.0888283252716064 | BCE Loss: 1.0112801790237427
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 4.133885383605957 | KNN Loss: 3.100940227508545 | BCE Loss: 1.032944917678833
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 4.09637975692749 | KNN Loss: 3.0578696727752686 |

Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.086661338806152 | KNN Loss: 3.0787205696105957 | BCE Loss: 1.0079410076141357
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.0501017570495605 | KNN Loss: 3.0588767528533936 | BCE Loss: 0.9912250638008118
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.078587055206299 | KNN Loss: 3.0444648265838623 | BCE Loss: 1.0341222286224365
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.038388252258301 | KNN Loss: 3.0538175106048584 | BCE Loss: 0.9845705032348633
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.095339775085449 | KNN Loss: 3.0685949325561523 | BCE Loss: 1.0267446041107178
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 4.064320087432861 | KNN Loss: 3.0709352493286133 | BCE Loss: 0.993384838104248
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 4.153305530548096 | KNN Loss: 3.1162731647491455 | BCE Loss: 1.0370323657989502
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 4.1040754318237305 | KNN Loss: 3.0542881488

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.2370,  2.7679,  2.7122,  3.9908,  3.8847,  0.8725,  3.0327,  2.5618,
          2.2363,  1.6938,  2.2864,  2.5082,  0.6424,  2.1298,  1.3242,  1.0866,
          2.1310,  2.3357,  2.0970,  2.6638,  1.7034,  2.1554,  1.3457,  1.5651,
          2.0842,  1.5516,  1.5837,  1.6459,  1.2925,  0.4660, -0.3103,  0.5602,
          0.1078,  1.1122,  1.7155,  1.2236,  0.4354,  2.3027,  0.8398,  1.3784,
          0.6218, -0.6223, -0.1409,  2.6776,  2.5059,  0.5122, -0.3030,  0.2509,
          1.6245,  1.8960,  2.0517, -0.2963,  1.5548,  0.5244, -0.7659,  1.3300,
          1.7193,  1.1864,  1.4329,  1.3418,  0.5420,  0.9076, -0.0140,  1.9244,
          1.2947,  1.9178, -2.1642,  0.5483,  2.1487,  2.4974,  2.8048,  0.6341,
          1.2793,  2.8174,  1.6929,  0.9510, -0.0296,  0.4604,  0.0906,  1.9074,
         -0.1839,  0.4141,  0.8577, -0.2542,  0.4213, -1.1251, -2.7708, -0.2355,
          0.6990, -2.0052,  0.0347, -0.0140, -0.6624, -1.0185,  0.7804,  0.9752,
         -0.7505, -0.7863,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [12]:
dataset_ = [d[0].cpu() for d in dataset]

In [15]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 81.46it/s]


In [16]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [17]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [18]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [20]:
tensor_dataset = torch.stack(dataset_)

In [21]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [22]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [23]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [24]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [25]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [26]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [27]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [28]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [29]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [30]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [31]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [32]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [33]:
losses = []
accs = []
sparsity = []

In [34]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 018 | Total loss: 9.623 | Reg loss: 0.009 | Tree loss: 9.623 | Accuracy: 0.000000 | 0.241 sec/iter
Epoch: 00 | Batch: 001 / 018 | Total loss: 9.620 | Reg loss: 0.009 | Tree loss: 9.620 | Accuracy: 0.000000 | 0.229 sec/iter
Epoch: 00 | Batch: 002 / 018 | Total loss: 9.619 | Reg loss: 0.008 | Tree loss: 9.619 | Accuracy: 0.000000 | 0.227 sec/iter
Epoch: 00 | Batch: 003 / 018 | Total loss: 9.612 | Reg loss: 0.008 | Tree loss: 9.612 | Accuracy: 0.000000 | 0.225 sec/iter
Epoch: 00 | Batch: 004 / 018 | Total loss: 9.609 | Reg loss: 0.008 | Tree loss: 9.609 | Accuracy: 0.000000 | 0.227 sec/iter
Epoch: 00 | Batch: 005 / 018 | Total loss: 9.609 | Reg loss: 0.007 | Tree loss: 9.609 | Accuracy: 0.000000 | 0.226 sec/iter
Epoch: 00 | Batch: 006 / 018 | Total loss: 9.600 | Reg loss: 0.007 | Tree loss: 9.600 | Accuracy: 0.000000 | 0.226 sec/iter
Epoch: 00 | Batch

Epoch: 03 | Batch: 006 / 018 | Total loss: 9.526 | Reg loss: 0.006 | Tree loss: 9.526 | Accuracy: 0.058594 | 0.23 sec/iter
Epoch: 03 | Batch: 007 / 018 | Total loss: 9.529 | Reg loss: 0.006 | Tree loss: 9.529 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 03 | Batch: 008 / 018 | Total loss: 9.515 | Reg loss: 0.006 | Tree loss: 9.515 | Accuracy: 0.076172 | 0.23 sec/iter
Epoch: 03 | Batch: 009 / 018 | Total loss: 9.517 | Reg loss: 0.006 | Tree loss: 9.517 | Accuracy: 0.080078 | 0.23 sec/iter
Epoch: 03 | Batch: 010 / 018 | Total loss: 9.521 | Reg loss: 0.006 | Tree loss: 9.521 | Accuracy: 0.050781 | 0.23 sec/iter
Epoch: 03 | Batch: 011 / 018 | Total loss: 9.517 | Reg loss: 0.007 | Tree loss: 9.517 | Accuracy: 0.060547 | 0.23 sec/iter
Epoch: 03 | Batch: 012 / 018 | Total loss: 9.517 | Reg loss: 0.007 | Tree loss: 9.517 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 03 | Batch: 013 / 018 | Total loss: 9.510 | Reg loss: 0.007 | Tree loss: 9.510 | Accuracy: 0.074219 | 0.23 sec/iter
Epoch: 03 | Batc

Epoch: 06 | Batch: 013 / 018 | Total loss: 9.405 | Reg loss: 0.010 | Tree loss: 9.405 | Accuracy: 0.074219 | 0.228 sec/iter
Epoch: 06 | Batch: 014 / 018 | Total loss: 9.405 | Reg loss: 0.011 | Tree loss: 9.405 | Accuracy: 0.076172 | 0.228 sec/iter
Epoch: 06 | Batch: 015 / 018 | Total loss: 9.404 | Reg loss: 0.011 | Tree loss: 9.404 | Accuracy: 0.068359 | 0.228 sec/iter
Epoch: 06 | Batch: 016 / 018 | Total loss: 9.385 | Reg loss: 0.011 | Tree loss: 9.385 | Accuracy: 0.072266 | 0.228 sec/iter
Epoch: 06 | Batch: 017 / 018 | Total loss: 9.388 | Reg loss: 0.012 | Tree loss: 9.388 | Accuracy: 0.066253 | 0.228 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 07 | Batch: 000 / 018 | Total loss: 9.429 | Reg loss: 0.009 | Tree loss: 9.429 | Accuracy: 0.078125 | 0.228 sec/iter
Epoch: 07 | Batch: 001

Epoch: 10 | Batch: 000 / 018 | Total loss: 9.246 | Reg loss: 0.014 | Tree loss: 9.246 | Accuracy: 0.046875 | 0.229 sec/iter
Epoch: 10 | Batch: 001 / 018 | Total loss: 9.230 | Reg loss: 0.014 | Tree loss: 9.230 | Accuracy: 0.064453 | 0.229 sec/iter
Epoch: 10 | Batch: 002 / 018 | Total loss: 9.204 | Reg loss: 0.014 | Tree loss: 9.204 | Accuracy: 0.072266 | 0.229 sec/iter
Epoch: 10 | Batch: 003 / 018 | Total loss: 9.193 | Reg loss: 0.014 | Tree loss: 9.193 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 10 | Batch: 004 / 018 | Total loss: 9.177 | Reg loss: 0.014 | Tree loss: 9.177 | Accuracy: 0.078125 | 0.229 sec/iter
Epoch: 10 | Batch: 005 / 018 | Total loss: 9.167 | Reg loss: 0.015 | Tree loss: 9.167 | Accuracy: 0.066406 | 0.229 sec/iter
Epoch: 10 | Batch: 006 / 018 | Total loss: 9.158 | Reg loss: 0.015 | Tree loss: 9.158 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 10 | Batch: 007 / 018 | Total loss: 9.152 | Reg loss: 0.015 | Tree loss: 9.152 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 1

Epoch: 13 | Batch: 007 / 018 | Total loss: 8.751 | Reg loss: 0.020 | Tree loss: 8.751 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 13 | Batch: 008 / 018 | Total loss: 8.734 | Reg loss: 0.021 | Tree loss: 8.734 | Accuracy: 0.072266 | 0.229 sec/iter
Epoch: 13 | Batch: 009 / 018 | Total loss: 8.686 | Reg loss: 0.021 | Tree loss: 8.686 | Accuracy: 0.087891 | 0.229 sec/iter
Epoch: 13 | Batch: 010 / 018 | Total loss: 8.690 | Reg loss: 0.021 | Tree loss: 8.690 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 13 | Batch: 011 / 018 | Total loss: 8.651 | Reg loss: 0.022 | Tree loss: 8.651 | Accuracy: 0.080078 | 0.229 sec/iter
Epoch: 13 | Batch: 012 / 018 | Total loss: 8.628 | Reg loss: 0.022 | Tree loss: 8.628 | Accuracy: 0.093750 | 0.229 sec/iter
Epoch: 13 | Batch: 013 / 018 | Total loss: 8.623 | Reg loss: 0.023 | Tree loss: 8.623 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 13 | Batch: 014 / 018 | Total loss: 8.597 | Reg loss: 0.023 | Tree loss: 8.597 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 1

Epoch: 16 | Batch: 014 / 018 | Total loss: 8.108 | Reg loss: 0.026 | Tree loss: 8.108 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 16 | Batch: 015 / 018 | Total loss: 8.128 | Reg loss: 0.027 | Tree loss: 8.128 | Accuracy: 0.052734 | 0.229 sec/iter
Epoch: 16 | Batch: 016 / 018 | Total loss: 8.071 | Reg loss: 0.027 | Tree loss: 8.071 | Accuracy: 0.070312 | 0.229 sec/iter
Epoch: 16 | Batch: 017 / 018 | Total loss: 8.070 | Reg loss: 0.027 | Tree loss: 8.070 | Accuracy: 0.049689 | 0.229 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 17 | Batch: 000 / 018 | Total loss: 8.305 | Reg loss: 0.025 | Tree loss: 8.305 | Accuracy: 0.058594 | 0.229 sec/iter
Epoch: 17 | Batch: 001 / 018 | Total loss: 8.274 | Reg loss: 0.025 | Tree loss: 8.274 | Accuracy: 0.085938 | 0.229 sec/iter
Epoch: 17 | Batch: 002

Epoch: 20 | Batch: 001 / 018 | Total loss: 7.786 | Reg loss: 0.027 | Tree loss: 7.786 | Accuracy: 0.083984 | 0.229 sec/iter
Epoch: 20 | Batch: 002 / 018 | Total loss: 7.817 | Reg loss: 0.028 | Tree loss: 7.817 | Accuracy: 0.062500 | 0.229 sec/iter
Epoch: 20 | Batch: 003 / 018 | Total loss: 7.795 | Reg loss: 0.028 | Tree loss: 7.795 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 20 | Batch: 004 / 018 | Total loss: 7.715 | Reg loss: 0.028 | Tree loss: 7.715 | Accuracy: 0.076172 | 0.229 sec/iter
Epoch: 20 | Batch: 005 / 018 | Total loss: 7.742 | Reg loss: 0.028 | Tree loss: 7.742 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 20 | Batch: 006 / 018 | Total loss: 7.679 | Reg loss: 0.028 | Tree loss: 7.679 | Accuracy: 0.070312 | 0.229 sec/iter
Epoch: 20 | Batch: 007 / 018 | Total loss: 7.675 | Reg loss: 0.028 | Tree loss: 7.675 | Accuracy: 0.074219 | 0.229 sec/iter
Epoch: 20 | Batch: 008 / 018 | Total loss: 7.649 | Reg loss: 0.028 | Tree loss: 7.649 | Accuracy: 0.056641 | 0.229 sec/iter
Epoch: 2

Epoch: 23 | Batch: 008 / 018 | Total loss: 7.245 | Reg loss: 0.029 | Tree loss: 7.245 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 23 | Batch: 009 / 018 | Total loss: 7.199 | Reg loss: 0.030 | Tree loss: 7.199 | Accuracy: 0.085938 | 0.229 sec/iter
Epoch: 23 | Batch: 010 / 018 | Total loss: 7.229 | Reg loss: 0.030 | Tree loss: 7.229 | Accuracy: 0.052734 | 0.229 sec/iter
Epoch: 23 | Batch: 011 / 018 | Total loss: 7.122 | Reg loss: 0.030 | Tree loss: 7.122 | Accuracy: 0.093750 | 0.229 sec/iter
Epoch: 23 | Batch: 012 / 018 | Total loss: 7.137 | Reg loss: 0.030 | Tree loss: 7.137 | Accuracy: 0.074219 | 0.229 sec/iter
Epoch: 23 | Batch: 013 / 018 | Total loss: 7.106 | Reg loss: 0.030 | Tree loss: 7.106 | Accuracy: 0.070312 | 0.229 sec/iter
Epoch: 23 | Batch: 014 / 018 | Total loss: 7.120 | Reg loss: 0.030 | Tree loss: 7.120 | Accuracy: 0.050781 | 0.229 sec/iter
Epoch: 23 | Batch: 015 / 018 | Total loss: 7.085 | Reg loss: 0.030 | Tree loss: 7.085 | Accuracy: 0.089844 | 0.229 sec/iter
Epoch: 2

Epoch: 26 | Batch: 015 / 018 | Total loss: 6.740 | Reg loss: 0.031 | Tree loss: 6.740 | Accuracy: 0.072266 | 0.229 sec/iter
Epoch: 26 | Batch: 016 / 018 | Total loss: 6.697 | Reg loss: 0.031 | Tree loss: 6.697 | Accuracy: 0.066406 | 0.229 sec/iter
Epoch: 26 | Batch: 017 / 018 | Total loss: 6.721 | Reg loss: 0.031 | Tree loss: 6.721 | Accuracy: 0.070393 | 0.229 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 27 | Batch: 000 / 018 | Total loss: 6.855 | Reg loss: 0.030 | Tree loss: 6.855 | Accuracy: 0.078125 | 0.229 sec/iter
Epoch: 27 | Batch: 001 / 018 | Total loss: 6.908 | Reg loss: 0.030 | Tree loss: 6.908 | Accuracy: 0.056641 | 0.229 sec/iter
Epoch: 27 | Batch: 002 / 018 | Total loss: 6.876 | Reg loss: 0.030 | Tree loss: 6.876 | Accuracy: 0.054688 | 0.229 sec/iter
Epoch: 27 | Batch: 003

Epoch: 30 | Batch: 002 / 018 | Total loss: 6.502 | Reg loss: 0.031 | Tree loss: 6.502 | Accuracy: 0.062500 | 0.229 sec/iter
Epoch: 30 | Batch: 003 / 018 | Total loss: 6.446 | Reg loss: 0.031 | Tree loss: 6.446 | Accuracy: 0.058594 | 0.229 sec/iter
Epoch: 30 | Batch: 004 / 018 | Total loss: 6.426 | Reg loss: 0.031 | Tree loss: 6.426 | Accuracy: 0.091797 | 0.229 sec/iter
Epoch: 30 | Batch: 005 / 018 | Total loss: 6.396 | Reg loss: 0.031 | Tree loss: 6.396 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 30 | Batch: 006 / 018 | Total loss: 6.397 | Reg loss: 0.031 | Tree loss: 6.397 | Accuracy: 0.080078 | 0.229 sec/iter
Epoch: 30 | Batch: 007 / 018 | Total loss: 6.339 | Reg loss: 0.031 | Tree loss: 6.339 | Accuracy: 0.078125 | 0.229 sec/iter
Epoch: 30 | Batch: 008 / 018 | Total loss: 6.326 | Reg loss: 0.031 | Tree loss: 6.326 | Accuracy: 0.082031 | 0.229 sec/iter
Epoch: 30 | Batch: 009 / 018 | Total loss: 6.380 | Reg loss: 0.031 | Tree loss: 6.380 | Accuracy: 0.072266 | 0.229 sec/iter
Epoch: 3

Epoch: 33 | Batch: 009 / 018 | Total loss: 5.961 | Reg loss: 0.032 | Tree loss: 5.961 | Accuracy: 0.080078 | 0.229 sec/iter
Epoch: 33 | Batch: 010 / 018 | Total loss: 6.006 | Reg loss: 0.032 | Tree loss: 6.006 | Accuracy: 0.078125 | 0.229 sec/iter
Epoch: 33 | Batch: 011 / 018 | Total loss: 5.968 | Reg loss: 0.032 | Tree loss: 5.968 | Accuracy: 0.076172 | 0.229 sec/iter
Epoch: 33 | Batch: 012 / 018 | Total loss: 5.933 | Reg loss: 0.032 | Tree loss: 5.933 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 33 | Batch: 013 / 018 | Total loss: 5.944 | Reg loss: 0.032 | Tree loss: 5.944 | Accuracy: 0.058594 | 0.229 sec/iter
Epoch: 33 | Batch: 014 / 018 | Total loss: 5.895 | Reg loss: 0.032 | Tree loss: 5.895 | Accuracy: 0.074219 | 0.229 sec/iter
Epoch: 33 | Batch: 015 / 018 | Total loss: 5.915 | Reg loss: 0.032 | Tree loss: 5.915 | Accuracy: 0.056641 | 0.229 sec/iter
Epoch: 33 | Batch: 016 / 018 | Total loss: 5.960 | Reg loss: 0.032 | Tree loss: 5.960 | Accuracy: 0.052734 | 0.229 sec/iter
Epoch: 3

Epoch: 36 | Batch: 016 / 018 | Total loss: 5.650 | Reg loss: 0.032 | Tree loss: 5.650 | Accuracy: 0.054688 | 0.229 sec/iter
Epoch: 36 | Batch: 017 / 018 | Total loss: 5.606 | Reg loss: 0.032 | Tree loss: 5.606 | Accuracy: 0.074534 | 0.229 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 37 | Batch: 000 / 018 | Total loss: 5.701 | Reg loss: 0.032 | Tree loss: 5.701 | Accuracy: 0.082031 | 0.229 sec/iter
Epoch: 37 | Batch: 001 / 018 | Total loss: 5.701 | Reg loss: 0.032 | Tree loss: 5.701 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 37 | Batch: 002 / 018 | Total loss: 5.752 | Reg loss: 0.032 | Tree loss: 5.752 | Accuracy: 0.060547 | 0.229 sec/iter
Epoch: 37 | Batch: 003 / 018 | Total loss: 5.694 | Reg loss: 0.032 | Tree loss: 5.694 | Accuracy: 0.072266 | 0.229 sec/iter
Epoch: 37 | Batch: 004

Epoch: 40 | Batch: 003 / 018 | Total loss: 5.452 | Reg loss: 0.032 | Tree loss: 5.452 | Accuracy: 0.064453 | 0.229 sec/iter
Epoch: 40 | Batch: 004 / 018 | Total loss: 5.391 | Reg loss: 0.032 | Tree loss: 5.391 | Accuracy: 0.072266 | 0.229 sec/iter
Epoch: 40 | Batch: 005 / 018 | Total loss: 5.408 | Reg loss: 0.032 | Tree loss: 5.408 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 40 | Batch: 006 / 018 | Total loss: 5.387 | Reg loss: 0.032 | Tree loss: 5.387 | Accuracy: 0.083984 | 0.229 sec/iter
Epoch: 40 | Batch: 007 / 018 | Total loss: 5.357 | Reg loss: 0.032 | Tree loss: 5.357 | Accuracy: 0.089844 | 0.229 sec/iter
Epoch: 40 | Batch: 008 / 018 | Total loss: 5.380 | Reg loss: 0.032 | Tree loss: 5.380 | Accuracy: 0.064453 | 0.229 sec/iter
Epoch: 40 | Batch: 009 / 018 | Total loss: 5.364 | Reg loss: 0.032 | Tree loss: 5.364 | Accuracy: 0.068359 | 0.229 sec/iter
Epoch: 40 | Batch: 010 / 018 | Total loss: 5.326 | Reg loss: 0.032 | Tree loss: 5.326 | Accuracy: 0.074219 | 0.229 sec/iter
Epoch: 4

Epoch: 43 | Batch: 010 / 018 | Total loss: 5.146 | Reg loss: 0.031 | Tree loss: 5.146 | Accuracy: 0.052734 | 0.23 sec/iter
Epoch: 43 | Batch: 011 / 018 | Total loss: 5.084 | Reg loss: 0.031 | Tree loss: 5.084 | Accuracy: 0.089844 | 0.23 sec/iter
Epoch: 43 | Batch: 012 / 018 | Total loss: 5.079 | Reg loss: 0.031 | Tree loss: 5.079 | Accuracy: 0.070312 | 0.23 sec/iter
Epoch: 43 | Batch: 013 / 018 | Total loss: 5.125 | Reg loss: 0.031 | Tree loss: 5.125 | Accuracy: 0.080078 | 0.23 sec/iter
Epoch: 43 | Batch: 014 / 018 | Total loss: 5.045 | Reg loss: 0.031 | Tree loss: 5.045 | Accuracy: 0.080078 | 0.23 sec/iter
Epoch: 43 | Batch: 015 / 018 | Total loss: 5.109 | Reg loss: 0.031 | Tree loss: 5.109 | Accuracy: 0.048828 | 0.23 sec/iter
Epoch: 43 | Batch: 016 / 018 | Total loss: 5.049 | Reg loss: 0.031 | Tree loss: 5.049 | Accuracy: 0.085938 | 0.23 sec/iter
Epoch: 43 | Batch: 017 / 018 | Total loss: 5.117 | Reg loss: 0.032 | Tree loss: 5.117 | Accuracy: 0.043478 | 0.23 sec/iter
Average sparsene

Epoch: 46 | Batch: 017 / 018 | Total loss: 4.878 | Reg loss: 0.031 | Tree loss: 4.878 | Accuracy: 0.064182 | 0.23 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 47 | Batch: 000 / 018 | Total loss: 5.053 | Reg loss: 0.030 | Tree loss: 5.053 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 47 | Batch: 001 / 018 | Total loss: 5.007 | Reg loss: 0.030 | Tree loss: 5.007 | Accuracy: 0.054688 | 0.23 sec/iter
Epoch: 47 | Batch: 002 / 018 | Total loss: 5.020 | Reg loss: 0.030 | Tree loss: 5.020 | Accuracy: 0.066406 | 0.23 sec/iter
Epoch: 47 | Batch: 003 / 018 | Total loss: 4.982 | Reg loss: 0.030 | Tree loss: 4.982 | Accuracy: 0.085938 | 0.23 sec/iter
Epoch: 47 | Batch: 004 / 018 | Total loss: 4.966 | Reg loss: 0.030 | Tree loss: 4.966 | Accuracy: 0.078125 | 0.23 sec/iter
Epoch: 47 | Batch: 005 / 018

Epoch: 50 | Batch: 004 / 018 | Total loss: 4.798 | Reg loss: 0.029 | Tree loss: 4.798 | Accuracy: 0.074219 | 0.23 sec/iter
Epoch: 50 | Batch: 005 / 018 | Total loss: 4.830 | Reg loss: 0.029 | Tree loss: 4.830 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 50 | Batch: 006 / 018 | Total loss: 4.833 | Reg loss: 0.029 | Tree loss: 4.833 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 50 | Batch: 007 / 018 | Total loss: 4.812 | Reg loss: 0.029 | Tree loss: 4.812 | Accuracy: 0.060547 | 0.23 sec/iter
Epoch: 50 | Batch: 008 / 018 | Total loss: 4.839 | Reg loss: 0.029 | Tree loss: 4.839 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 50 | Batch: 009 / 018 | Total loss: 4.793 | Reg loss: 0.029 | Tree loss: 4.793 | Accuracy: 0.082031 | 0.23 sec/iter
Epoch: 50 | Batch: 010 / 018 | Total loss: 4.783 | Reg loss: 0.029 | Tree loss: 4.783 | Accuracy: 0.070312 | 0.23 sec/iter
Epoch: 50 | Batch: 011 / 018 | Total loss: 4.739 | Reg loss: 0.029 | Tree loss: 4.739 | Accuracy: 0.087891 | 0.23 sec/iter
Epoch: 50 | Batc

Epoch: 53 | Batch: 011 / 018 | Total loss: 4.576 | Reg loss: 0.028 | Tree loss: 4.576 | Accuracy: 0.080078 | 0.23 sec/iter
Epoch: 53 | Batch: 012 / 018 | Total loss: 4.622 | Reg loss: 0.028 | Tree loss: 4.622 | Accuracy: 0.082031 | 0.23 sec/iter
Epoch: 53 | Batch: 013 / 018 | Total loss: 4.629 | Reg loss: 0.028 | Tree loss: 4.629 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 53 | Batch: 014 / 018 | Total loss: 4.590 | Reg loss: 0.028 | Tree loss: 4.590 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 53 | Batch: 015 / 018 | Total loss: 4.629 | Reg loss: 0.028 | Tree loss: 4.629 | Accuracy: 0.052734 | 0.23 sec/iter
Epoch: 53 | Batch: 016 / 018 | Total loss: 4.643 | Reg loss: 0.028 | Tree loss: 4.643 | Accuracy: 0.060547 | 0.23 sec/iter
Epoch: 53 | Batch: 017 / 018 | Total loss: 4.600 | Reg loss: 0.028 | Tree loss: 4.600 | Accuracy: 0.072464 | 0.23 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.98214

Epoch: 57 | Batch: 000 / 018 | Total loss: 4.606 | Reg loss: 0.029 | Tree loss: 4.606 | Accuracy: 0.078125 | 0.23 sec/iter
Epoch: 57 | Batch: 001 / 018 | Total loss: 4.660 | Reg loss: 0.029 | Tree loss: 4.660 | Accuracy: 0.070312 | 0.23 sec/iter
Epoch: 57 | Batch: 002 / 018 | Total loss: 4.599 | Reg loss: 0.029 | Tree loss: 4.599 | Accuracy: 0.072266 | 0.23 sec/iter
Epoch: 57 | Batch: 003 / 018 | Total loss: 4.588 | Reg loss: 0.029 | Tree loss: 4.588 | Accuracy: 0.054688 | 0.23 sec/iter
Epoch: 57 | Batch: 004 / 018 | Total loss: 4.567 | Reg loss: 0.029 | Tree loss: 4.567 | Accuracy: 0.089844 | 0.23 sec/iter
Epoch: 57 | Batch: 005 / 018 | Total loss: 4.551 | Reg loss: 0.029 | Tree loss: 4.551 | Accuracy: 0.078125 | 0.23 sec/iter
Epoch: 57 | Batch: 006 / 018 | Total loss: 4.485 | Reg loss: 0.029 | Tree loss: 4.485 | Accuracy: 0.056641 | 0.23 sec/iter
Epoch: 57 | Batch: 007 / 018 | Total loss: 4.558 | Reg loss: 0.029 | Tree loss: 4.558 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 57 | Batc

Epoch: 60 | Batch: 007 / 018 | Total loss: 4.509 | Reg loss: 0.031 | Tree loss: 4.509 | Accuracy: 0.066406 | 0.23 sec/iter
Epoch: 60 | Batch: 008 / 018 | Total loss: 4.459 | Reg loss: 0.031 | Tree loss: 4.459 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 60 | Batch: 009 / 018 | Total loss: 4.466 | Reg loss: 0.031 | Tree loss: 4.466 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 60 | Batch: 010 / 018 | Total loss: 4.438 | Reg loss: 0.031 | Tree loss: 4.438 | Accuracy: 0.072266 | 0.23 sec/iter
Epoch: 60 | Batch: 011 / 018 | Total loss: 4.364 | Reg loss: 0.031 | Tree loss: 4.364 | Accuracy: 0.074219 | 0.23 sec/iter
Epoch: 60 | Batch: 012 / 018 | Total loss: 4.441 | Reg loss: 0.031 | Tree loss: 4.441 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 60 | Batch: 013 / 018 | Total loss: 4.367 | Reg loss: 0.031 | Tree loss: 4.367 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 60 | Batch: 014 / 018 | Total loss: 4.334 | Reg loss: 0.031 | Tree loss: 4.334 | Accuracy: 0.103516 | 0.23 sec/iter
Epoch: 60 | Batc

Epoch: 63 | Batch: 014 / 018 | Total loss: 4.256 | Reg loss: 0.033 | Tree loss: 4.256 | Accuracy: 0.076172 | 0.23 sec/iter
Epoch: 63 | Batch: 015 / 018 | Total loss: 4.346 | Reg loss: 0.033 | Tree loss: 4.346 | Accuracy: 0.058594 | 0.23 sec/iter
Epoch: 63 | Batch: 016 / 018 | Total loss: 4.305 | Reg loss: 0.033 | Tree loss: 4.305 | Accuracy: 0.064453 | 0.23 sec/iter
Epoch: 63 | Batch: 017 / 018 | Total loss: 4.300 | Reg loss: 0.033 | Tree loss: 4.300 | Accuracy: 0.047619 | 0.23 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 64 | Batch: 000 / 018 | Total loss: 4.462 | Reg loss: 0.032 | Tree loss: 4.462 | Accuracy: 0.068359 | 0.23 sec/iter
Epoch: 64 | Batch: 001 / 018 | Total loss: 4.465 | Reg loss: 0.032 | Tree loss: 4.465 | Accuracy: 0.062500 | 0.23 sec/iter
Epoch: 64 | Batch: 002 / 018

Epoch: 67 | Batch: 001 / 018 | Total loss: 4.393 | Reg loss: 0.033 | Tree loss: 4.393 | Accuracy: 0.082031 | 0.231 sec/iter
Epoch: 67 | Batch: 002 / 018 | Total loss: 4.397 | Reg loss: 0.033 | Tree loss: 4.397 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 67 | Batch: 003 / 018 | Total loss: 4.386 | Reg loss: 0.033 | Tree loss: 4.386 | Accuracy: 0.054688 | 0.231 sec/iter
Epoch: 67 | Batch: 004 / 018 | Total loss: 4.310 | Reg loss: 0.033 | Tree loss: 4.310 | Accuracy: 0.091797 | 0.231 sec/iter
Epoch: 67 | Batch: 005 / 018 | Total loss: 4.340 | Reg loss: 0.033 | Tree loss: 4.340 | Accuracy: 0.064453 | 0.231 sec/iter
Epoch: 67 | Batch: 006 / 018 | Total loss: 4.403 | Reg loss: 0.033 | Tree loss: 4.403 | Accuracy: 0.062500 | 0.231 sec/iter
Epoch: 67 | Batch: 007 / 018 | Total loss: 4.304 | Reg loss: 0.033 | Tree loss: 4.304 | Accuracy: 0.064453 | 0.231 sec/iter
Epoch: 67 | Batch: 008 / 018 | Total loss: 4.347 | Reg loss: 0.033 | Tree loss: 4.347 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 6

Epoch: 70 | Batch: 008 / 018 | Total loss: 4.294 | Reg loss: 0.034 | Tree loss: 4.294 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 70 | Batch: 009 / 018 | Total loss: 4.285 | Reg loss: 0.034 | Tree loss: 4.285 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 70 | Batch: 010 / 018 | Total loss: 4.255 | Reg loss: 0.034 | Tree loss: 4.255 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 70 | Batch: 011 / 018 | Total loss: 4.230 | Reg loss: 0.034 | Tree loss: 4.230 | Accuracy: 0.058594 | 0.231 sec/iter
Epoch: 70 | Batch: 012 / 018 | Total loss: 4.274 | Reg loss: 0.034 | Tree loss: 4.274 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 70 | Batch: 013 / 018 | Total loss: 4.195 | Reg loss: 0.034 | Tree loss: 4.195 | Accuracy: 0.064453 | 0.231 sec/iter
Epoch: 70 | Batch: 014 / 018 | Total loss: 4.154 | Reg loss: 0.035 | Tree loss: 4.154 | Accuracy: 0.095703 | 0.231 sec/iter
Epoch: 70 | Batch: 015 / 018 | Total loss: 4.232 | Reg loss: 0.035 | Tree loss: 4.232 | Accuracy: 0.062500 | 0.231 sec/iter
Epoch: 7

Epoch: 73 | Batch: 015 / 018 | Total loss: 4.134 | Reg loss: 0.035 | Tree loss: 4.134 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 73 | Batch: 016 / 018 | Total loss: 4.217 | Reg loss: 0.035 | Tree loss: 4.217 | Accuracy: 0.072266 | 0.231 sec/iter
Epoch: 73 | Batch: 017 / 018 | Total loss: 4.230 | Reg loss: 0.035 | Tree loss: 4.230 | Accuracy: 0.051760 | 0.231 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 74 | Batch: 000 / 018 | Total loss: 4.330 | Reg loss: 0.035 | Tree loss: 4.330 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 74 | Batch: 001 / 018 | Total loss: 4.298 | Reg loss: 0.035 | Tree loss: 4.298 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 74 | Batch: 002 / 018 | Total loss: 4.323 | Reg loss: 0.035 | Tree loss: 4.323 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 74 | Batch: 003

Epoch: 77 | Batch: 002 / 018 | Total loss: 4.349 | Reg loss: 0.035 | Tree loss: 4.349 | Accuracy: 0.058594 | 0.231 sec/iter
Epoch: 77 | Batch: 003 / 018 | Total loss: 4.279 | Reg loss: 0.035 | Tree loss: 4.279 | Accuracy: 0.083984 | 0.231 sec/iter
Epoch: 77 | Batch: 004 / 018 | Total loss: 4.300 | Reg loss: 0.035 | Tree loss: 4.300 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 77 | Batch: 005 / 018 | Total loss: 4.195 | Reg loss: 0.035 | Tree loss: 4.195 | Accuracy: 0.058594 | 0.231 sec/iter
Epoch: 77 | Batch: 006 / 018 | Total loss: 4.224 | Reg loss: 0.035 | Tree loss: 4.224 | Accuracy: 0.083984 | 0.231 sec/iter
Epoch: 77 | Batch: 007 / 018 | Total loss: 4.218 | Reg loss: 0.035 | Tree loss: 4.218 | Accuracy: 0.070312 | 0.232 sec/iter
Epoch: 77 | Batch: 008 / 018 | Total loss: 4.215 | Reg loss: 0.035 | Tree loss: 4.215 | Accuracy: 0.064453 | 0.232 sec/iter
Epoch: 77 | Batch: 009 / 018 | Total loss: 4.187 | Reg loss: 0.035 | Tree loss: 4.187 | Accuracy: 0.070312 | 0.232 sec/iter
Epoch: 7

Epoch: 80 | Batch: 009 / 018 | Total loss: 4.169 | Reg loss: 0.036 | Tree loss: 4.169 | Accuracy: 0.080078 | 0.232 sec/iter
Epoch: 80 | Batch: 010 / 018 | Total loss: 4.230 | Reg loss: 0.036 | Tree loss: 4.230 | Accuracy: 0.068359 | 0.232 sec/iter
Epoch: 80 | Batch: 011 / 018 | Total loss: 4.235 | Reg loss: 0.036 | Tree loss: 4.235 | Accuracy: 0.068359 | 0.232 sec/iter
Epoch: 80 | Batch: 012 / 018 | Total loss: 4.169 | Reg loss: 0.036 | Tree loss: 4.169 | Accuracy: 0.070312 | 0.232 sec/iter
Epoch: 80 | Batch: 013 / 018 | Total loss: 4.130 | Reg loss: 0.036 | Tree loss: 4.130 | Accuracy: 0.072266 | 0.231 sec/iter
Epoch: 80 | Batch: 014 / 018 | Total loss: 4.118 | Reg loss: 0.036 | Tree loss: 4.118 | Accuracy: 0.087891 | 0.231 sec/iter
Epoch: 80 | Batch: 015 / 018 | Total loss: 4.116 | Reg loss: 0.036 | Tree loss: 4.116 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 80 | Batch: 016 / 018 | Total loss: 4.131 | Reg loss: 0.036 | Tree loss: 4.131 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 8

Epoch: 83 | Batch: 016 / 018 | Total loss: 4.060 | Reg loss: 0.036 | Tree loss: 4.060 | Accuracy: 0.072266 | 0.231 sec/iter
Epoch: 83 | Batch: 017 / 018 | Total loss: 4.107 | Reg loss: 0.037 | Tree loss: 4.107 | Accuracy: 0.053830 | 0.231 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 84 | Batch: 000 / 018 | Total loss: 4.288 | Reg loss: 0.036 | Tree loss: 4.288 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 84 | Batch: 001 / 018 | Total loss: 4.231 | Reg loss: 0.036 | Tree loss: 4.231 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 84 | Batch: 002 / 018 | Total loss: 4.289 | Reg loss: 0.036 | Tree loss: 4.289 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 84 | Batch: 003 / 018 | Total loss: 4.318 | Reg loss: 0.036 | Tree loss: 4.318 | Accuracy: 0.062500 | 0.231 sec/iter
Epoch: 84 | Batch: 004

Epoch: 87 | Batch: 003 / 018 | Total loss: 4.294 | Reg loss: 0.036 | Tree loss: 4.294 | Accuracy: 0.064453 | 0.231 sec/iter
Epoch: 87 | Batch: 004 / 018 | Total loss: 4.205 | Reg loss: 0.036 | Tree loss: 4.205 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 87 | Batch: 005 / 018 | Total loss: 4.177 | Reg loss: 0.036 | Tree loss: 4.177 | Accuracy: 0.072266 | 0.231 sec/iter
Epoch: 87 | Batch: 006 / 018 | Total loss: 4.235 | Reg loss: 0.036 | Tree loss: 4.235 | Accuracy: 0.054688 | 0.231 sec/iter
Epoch: 87 | Batch: 007 / 018 | Total loss: 4.256 | Reg loss: 0.036 | Tree loss: 4.256 | Accuracy: 0.082031 | 0.231 sec/iter
Epoch: 87 | Batch: 008 / 018 | Total loss: 4.224 | Reg loss: 0.036 | Tree loss: 4.224 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 87 | Batch: 009 / 018 | Total loss: 4.160 | Reg loss: 0.036 | Tree loss: 4.160 | Accuracy: 0.082031 | 0.231 sec/iter
Epoch: 87 | Batch: 010 / 018 | Total loss: 4.127 | Reg loss: 0.036 | Tree loss: 4.127 | Accuracy: 0.062500 | 0.231 sec/iter
Epoch: 8

Epoch: 90 | Batch: 010 / 018 | Total loss: 4.171 | Reg loss: 0.037 | Tree loss: 4.171 | Accuracy: 0.062500 | 0.231 sec/iter
Epoch: 90 | Batch: 011 / 018 | Total loss: 4.160 | Reg loss: 0.037 | Tree loss: 4.160 | Accuracy: 0.066406 | 0.231 sec/iter
Epoch: 90 | Batch: 012 / 018 | Total loss: 4.093 | Reg loss: 0.037 | Tree loss: 4.093 | Accuracy: 0.074219 | 0.231 sec/iter
Epoch: 90 | Batch: 013 / 018 | Total loss: 4.119 | Reg loss: 0.037 | Tree loss: 4.119 | Accuracy: 0.064453 | 0.231 sec/iter
Epoch: 90 | Batch: 014 / 018 | Total loss: 4.118 | Reg loss: 0.037 | Tree loss: 4.118 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 90 | Batch: 015 / 018 | Total loss: 4.140 | Reg loss: 0.037 | Tree loss: 4.140 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 90 | Batch: 016 / 018 | Total loss: 4.071 | Reg loss: 0.037 | Tree loss: 4.071 | Accuracy: 0.074219 | 0.231 sec/iter
Epoch: 90 | Batch: 017 / 018 | Total loss: 4.052 | Reg loss: 0.037 | Tree loss: 4.052 | Accuracy: 0.078675 | 0.231 sec/iter
Average 

Epoch: 93 | Batch: 017 / 018 | Total loss: 4.018 | Reg loss: 0.037 | Tree loss: 4.018 | Accuracy: 0.086957 | 0.231 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 94 | Batch: 000 / 018 | Total loss: 4.304 | Reg loss: 0.037 | Tree loss: 4.304 | Accuracy: 0.058594 | 0.231 sec/iter
Epoch: 94 | Batch: 001 / 018 | Total loss: 4.269 | Reg loss: 0.037 | Tree loss: 4.269 | Accuracy: 0.060547 | 0.231 sec/iter
Epoch: 94 | Batch: 002 / 018 | Total loss: 4.254 | Reg loss: 0.037 | Tree loss: 4.254 | Accuracy: 0.072266 | 0.231 sec/iter
Epoch: 94 | Batch: 003 / 018 | Total loss: 4.203 | Reg loss: 0.037 | Tree loss: 4.203 | Accuracy: 0.066406 | 0.231 sec/iter
Epoch: 94 | Batch: 004 / 018 | Total loss: 4.224 | Reg loss: 0.037 | Tree loss: 4.224 | Accuracy: 0.066406 | 0.231 sec/iter
Epoch: 94 | Batch: 005

Epoch: 97 | Batch: 004 / 018 | Total loss: 4.211 | Reg loss: 0.037 | Tree loss: 4.211 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 97 | Batch: 005 / 018 | Total loss: 4.241 | Reg loss: 0.037 | Tree loss: 4.241 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 97 | Batch: 006 / 018 | Total loss: 4.253 | Reg loss: 0.037 | Tree loss: 4.253 | Accuracy: 0.064453 | 0.231 sec/iter
Epoch: 97 | Batch: 007 / 018 | Total loss: 4.167 | Reg loss: 0.037 | Tree loss: 4.167 | Accuracy: 0.076172 | 0.231 sec/iter
Epoch: 97 | Batch: 008 / 018 | Total loss: 4.138 | Reg loss: 0.037 | Tree loss: 4.138 | Accuracy: 0.070312 | 0.231 sec/iter
Epoch: 97 | Batch: 009 / 018 | Total loss: 4.107 | Reg loss: 0.037 | Tree loss: 4.107 | Accuracy: 0.087891 | 0.231 sec/iter
Epoch: 97 | Batch: 010 / 018 | Total loss: 4.153 | Reg loss: 0.037 | Tree loss: 4.153 | Accuracy: 0.068359 | 0.231 sec/iter
Epoch: 97 | Batch: 011 / 018 | Total loss: 4.116 | Reg loss: 0.037 | Tree loss: 4.116 | Accuracy: 0.078125 | 0.231 sec/iter
Epoch: 9

In [35]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [36]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [37]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 7.996078431372549


# Extract Rules

# Accumulate samples in the leaves

In [38]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 255


In [39]:
method = 'greedy'

In [40]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [44]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
#     print(conds)
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

6307
1696
1184


Average comprehensibility: 36.58039215686274
std comprehensibility: 1.6355151023865735
var comprehensibility: 2.674909650134564
minimum comprehensibility: 32
maximum comprehensibility: 42
