In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 12
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.161576271057129 | KNN Loss: 6.226352691650391 | BCE Loss: 1.9352238178253174
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.159000396728516 | KNN Loss: 6.226648807525635 | BCE Loss: 1.932352066040039
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.188031196594238 | KNN Loss: 6.226075172424316 | BCE Loss: 1.9619560241699219
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.213977813720703 | KNN Loss: 6.225383758544922 | BCE Loss: 1.9885936975479126
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.180749893188477 | KNN Loss: 6.225130081176758 | BCE Loss: 1.9556200504302979
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.228500366210938 | KNN Loss: 6.224370956420898 | BCE Loss: 2.004129409790039
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.17105484008789 | KNN Loss: 6.22470760345459 | BCE Loss: 1.9463468790054321
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.153498649597168 | KNN Loss: 6.224190711975098 | BCE Loss: 1.929308056

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 6.008642196655273 | KNN Loss: 4.881716251373291 | BCE Loss: 1.1269259452819824
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.880500793457031 | KNN Loss: 4.756161212921143 | BCE Loss: 1.1243395805358887
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.784660816192627 | KNN Loss: 4.677248954772949 | BCE Loss: 1.1074117422103882
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.713019847869873 | KNN Loss: 4.59433126449585 | BCE Loss: 1.118688702583313
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.579710006713867 | KNN Loss: 4.452706813812256 | BCE Loss: 1.1270031929016113
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.4775896072387695 | KNN Loss: 4.382556438446045 | BCE Loss: 1.0950329303741455
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.459367752075195 | KNN Loss: 4.336392402648926 | BCE Loss: 1.1229753494262695
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.407453536987305 | KNN Loss: 4.294209957122803 | BCE Loss:

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.889537811279297 | KNN Loss: 3.8321566581726074 | BCE Loss: 1.0573813915252686
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.884041786193848 | KNN Loss: 3.817591905593872 | BCE Loss: 1.0664498805999756
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.904548645019531 | KNN Loss: 3.816315174102783 | BCE Loss: 1.0882335901260376
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.870769023895264 | KNN Loss: 3.8403100967407227 | BCE Loss: 1.030458927154541
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.880009651184082 | KNN Loss: 3.8069984912872314 | BCE Loss: 1.0730112791061401
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.868500709533691 | KNN Loss: 3.8365886211395264 | BCE Loss: 1.0319123268127441
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.897292613983154 | KNN Loss: 3.8595547676086426 | BCE Loss: 1.0377379655838013
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.877297401428223 | KNN Loss: 3.8279037475585938 | BC

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.769382953643799 | KNN Loss: 3.7356760501861572 | BCE Loss: 1.0337069034576416
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.821622848510742 | KNN Loss: 3.7641007900238037 | BCE Loss: 1.0575222969055176
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.8612380027771 | KNN Loss: 3.805204153060913 | BCE Loss: 1.0560338497161865
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.810745716094971 | KNN Loss: 3.7756741046905518 | BCE Loss: 1.0350717306137085
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.864107608795166 | KNN Loss: 3.8003976345062256 | BCE Loss: 1.0637099742889404
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.850316524505615 | KNN Loss: 3.797315835952759 | BCE Loss: 1.0530006885528564
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.829188346862793 | KNN Loss: 3.7873759269714355 | BCE Loss: 1.0418126583099365
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.849262237548828 | KNN Loss: 3.7826976776123047 | BCE

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.815564155578613 | KNN Loss: 3.7677969932556152 | BCE Loss: 1.0477674007415771
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.830121994018555 | KNN Loss: 3.8136463165283203 | BCE Loss: 1.0164759159088135
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.812733173370361 | KNN Loss: 3.7600300312042236 | BCE Loss: 1.0527031421661377
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.761682510375977 | KNN Loss: 3.7450225353240967 | BCE Loss: 1.0166598558425903
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.831888198852539 | KNN Loss: 3.7563159465789795 | BCE Loss: 1.0755720138549805
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.798451900482178 | KNN Loss: 3.7779626846313477 | BCE Loss: 1.02048921585083
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.813273906707764 | KNN Loss: 3.7830395698547363 | BCE Loss: 1.030234456062317
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 4.818156719207764 | KNN Loss: 3.7547125816345215 | BC

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.8205156326293945 | KNN Loss: 3.7558465003967285 | BCE Loss: 1.0646693706512451
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.785021781921387 | KNN Loss: 3.7652742862701416 | BCE Loss: 1.0197476148605347
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.803799629211426 | KNN Loss: 3.752368211746216 | BCE Loss: 1.05143141746521
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.7835564613342285 | KNN Loss: 3.739664077758789 | BCE Loss: 1.04389226436615
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.757462024688721 | KNN Loss: 3.723784923553467 | BCE Loss: 1.0336772203445435
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.785367488861084 | KNN Loss: 3.7681097984313965 | BCE Loss: 1.0172576904296875
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.766208171844482 | KNN Loss: 3.7340595722198486 | BCE Loss: 1.0321484804153442
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 4.766202926635742 | KNN Loss: 3.7332630157470703 | BCE L

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.717559814453125 | KNN Loss: 3.6916518211364746 | BCE Loss: 1.0259077548980713
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.754786014556885 | KNN Loss: 3.7242918014526367 | BCE Loss: 1.030494213104248
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.736239433288574 | KNN Loss: 3.7219812870025635 | BCE Loss: 1.0142581462860107
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.726287841796875 | KNN Loss: 3.7138025760650635 | BCE Loss: 1.0124852657318115
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.714808464050293 | KNN Loss: 3.6771745681762695 | BCE Loss: 1.0376336574554443
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.722078800201416 | KNN Loss: 3.713974952697754 | BCE Loss: 1.008103847503662
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 4.817697525024414 | KNN Loss: 3.742600917816162 | BCE Loss: 1.075096607208252
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 4.726203918457031 | KNN Loss: 3.697857618331909 | BCE Los

Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.70012903213501 | KNN Loss: 3.6739377975463867 | BCE Loss: 1.0261913537979126
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.748657703399658 | KNN Loss: 3.7083072662353516 | BCE Loss: 1.040350317955017
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.758847713470459 | KNN Loss: 3.7303507328033447 | BCE Loss: 1.0284970998764038
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.747636795043945 | KNN Loss: 3.712843894958496 | BCE Loss: 1.0347930192947388
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.830390930175781 | KNN Loss: 3.7853589057922363 | BCE Loss: 1.045032262802124
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.697503089904785 | KNN Loss: 3.700519323348999 | BCE Loss: 0.9969840049743652
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.725960731506348 | KNN Loss: 3.6897521018981934 | BCE Loss: 1.0362083911895752
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 4.737920761108398 | KNN Loss: 3.7127323150634766 | BCE 

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.782723903656006 | KNN Loss: 3.7330589294433594 | BCE Loss: 1.049664855003357
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.758654594421387 | KNN Loss: 3.7373623847961426 | BCE Loss: 1.0212922096252441
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.709909439086914 | KNN Loss: 3.687044382095337 | BCE Loss: 1.022864818572998
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.7103753089904785 | KNN Loss: 3.667794704437256 | BCE Loss: 1.0425806045532227
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.736180305480957 | KNN Loss: 3.7055439949035645 | BCE Loss: 1.0306364297866821
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.759105682373047 | KNN Loss: 3.7183895111083984 | BCE Loss: 1.0407160520553589
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 4.696104526519775 | KNN Loss: 3.6824448108673096 | BCE Loss: 1.0136598348617554
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 4.759469985961914 | KNN Loss: 3.7415833473205566 | BC

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.726352214813232 | KNN Loss: 3.6945323944091797 | BCE Loss: 1.0318198204040527
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.732864856719971 | KNN Loss: 3.7146449089050293 | BCE Loss: 1.0182199478149414
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.753122329711914 | KNN Loss: 3.708754539489746 | BCE Loss: 1.0443676710128784
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.735610485076904 | KNN Loss: 3.705475330352783 | BCE Loss: 1.030135154724121
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.734410285949707 | KNN Loss: 3.7329373359680176 | BCE Loss: 1.001473069190979
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.749842643737793 | KNN Loss: 3.721266269683838 | BCE Loss: 1.0285766124725342
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.754935264587402 | KNN Loss: 3.708028793334961 | BCE Loss: 1.0469067096710205
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 4.786403179168701 | KNN Loss: 3.734344005584717 | BCE Los

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.768698692321777 | KNN Loss: 3.7170653343200684 | BCE Loss: 1.0516332387924194
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.730510234832764 | KNN Loss: 3.7095818519592285 | BCE Loss: 1.0209283828735352
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.746492385864258 | KNN Loss: 3.703134775161743 | BCE Loss: 1.0433573722839355
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.768903732299805 | KNN Loss: 3.7134742736816406 | BCE Loss: 1.055429458618164
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.6883697509765625 | KNN Loss: 3.681579828262329 | BCE Loss: 1.0067898035049438
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.808526515960693 | KNN Loss: 3.7909724712371826 | BCE Loss: 1.0175540447235107
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 4.7757768630981445 | KNN Loss: 3.7663114070892334 | BCE Loss: 1.009465217590332
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 4.735389709472656 | KNN Loss: 3.6886401176452

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.721090793609619 | KNN Loss: 3.696495294570923 | BCE Loss: 1.0245953798294067
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.762164115905762 | KNN Loss: 3.715609550476074 | BCE Loss: 1.046554684638977
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.714425086975098 | KNN Loss: 3.6814746856689453 | BCE Loss: 1.032950520515442
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.681241035461426 | KNN Loss: 3.6533665657043457 | BCE Loss: 1.02787446975708
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.733404159545898 | KNN Loss: 3.727436065673828 | BCE Loss: 1.0059682130813599
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.7258687019348145 | KNN Loss: 3.6961257457733154 | BCE Loss: 1.0297428369522095
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.731088638305664 | KNN Loss: 3.7130115032196045 | BCE Loss: 1.0180772542953491
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 4.739720344543457 | KNN Loss: 3.7130990028381348

Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.71942138671875 | KNN Loss: 3.723304033279419 | BCE Loss: 0.9961174726486206
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.719966888427734 | KNN Loss: 3.6803359985351562 | BCE Loss: 1.039630651473999
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.73471212387085 | KNN Loss: 3.7119104862213135 | BCE Loss: 1.0228015184402466
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.762751579284668 | KNN Loss: 3.7187418937683105 | BCE Loss: 1.0440094470977783
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.7164306640625 | KNN Loss: 3.658285617828369 | BCE Loss: 1.0581451654434204
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.794734954833984 | KNN Loss: 3.733668804168701 | BCE Loss: 1.0610661506652832
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.6592116355896 | KNN Loss: 3.6446800231933594 | BCE Loss: 1.0145316123962402
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 4.765515327453613 | KNN Loss: 3.7434771060943604 | BC

Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.730654716491699 | KNN Loss: 3.6917569637298584 | BCE Loss: 1.0388975143432617
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.719101428985596 | KNN Loss: 3.6691977977752686 | BCE Loss: 1.0499035120010376
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.71204948425293 | KNN Loss: 3.706813097000122 | BCE Loss: 1.005236268043518
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.752737045288086 | KNN Loss: 3.7210421562194824 | BCE Loss: 1.0316946506500244
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.721447467803955 | KNN Loss: 3.682311534881592 | BCE Loss: 1.0391359329223633
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.693603515625 | KNN Loss: 3.670137882232666 | BCE Loss: 1.0234657526016235
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.739391326904297 | KNN Loss: 3.7140326499938965 | BCE Loss: 1.0253584384918213
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 4.694085121154785 | KNN Loss: 3.662022113800049 | B

Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.692320346832275 | KNN Loss: 3.679750680923462 | BCE Loss: 1.0125696659088135
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.74092435836792 | KNN Loss: 3.6927499771118164 | BCE Loss: 1.048174262046814
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.672476768493652 | KNN Loss: 3.6641311645507812 | BCE Loss: 1.0083458423614502
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.692797660827637 | KNN Loss: 3.6731109619140625 | BCE Loss: 1.0196869373321533
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.741195201873779 | KNN Loss: 3.706538438796997 | BCE Loss: 1.0346566438674927
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.726855754852295 | KNN Loss: 3.731053352355957 | BCE Loss: 0.9958022832870483
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.709287643432617 | KNN Loss: 3.697166919708252 | BCE Loss: 1.0121204853057861
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 4.70134162902832 | KNN Loss: 3.693608522415161 |

Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.720521926879883 | KNN Loss: 3.6986353397369385 | BCE Loss: 1.0218863487243652
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.762421607971191 | KNN Loss: 3.699746608734131 | BCE Loss: 1.0626747608184814
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.748936176300049 | KNN Loss: 3.728175640106201 | BCE Loss: 1.0207605361938477
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.747129440307617 | KNN Loss: 3.719137191772461 | BCE Loss: 1.0279922485351562
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.699103355407715 | KNN Loss: 3.689349889755249 | BCE Loss: 1.0097532272338867
Epoch   162: reducing learning rate of group 0 to 8.4035e-04.
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.767461776733398 | KNN Loss: 3.717923402786255 | BCE Loss: 1.0495386123657227
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 4.740508556365967 | KNN Loss: 3.729621648788452 | BCE Loss: 1.010886788368225
Epoch 162 / 500 | iteration 10 / 30 | 

Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.723074913024902 | KNN Loss: 3.6920275688171387 | BCE Loss: 1.0310475826263428
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.76519775390625 | KNN Loss: 3.7023539543151855 | BCE Loss: 1.0628437995910645
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.6731648445129395 | KNN Loss: 3.6617534160614014 | BCE Loss: 1.011411428451538
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.714776515960693 | KNN Loss: 3.705110549926758 | BCE Loss: 1.0096659660339355
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.775636672973633 | KNN Loss: 3.732318878173828 | BCE Loss: 1.0433180332183838
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.742336273193359 | KNN Loss: 3.7158775329589844 | BCE Loss: 1.026458740234375
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 4.7466840744018555 | KNN Loss: 3.693270683288574 | BCE Loss: 1.0534135103225708
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 4.74367618560791 | KNN Loss: 3.695646047592163 

Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.728199481964111 | KNN Loss: 3.717276096343994 | BCE Loss: 1.0109233856201172
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.724344253540039 | KNN Loss: 3.726048469543457 | BCE Loss: 0.9982956647872925
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.70720100402832 | KNN Loss: 3.6923863887786865 | BCE Loss: 1.0148143768310547
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.755712509155273 | KNN Loss: 3.7218666076660156 | BCE Loss: 1.0338457822799683
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.698261260986328 | KNN Loss: 3.6680076122283936 | BCE Loss: 1.0302538871765137
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 4.702728271484375 | KNN Loss: 3.6781487464904785 | BCE Loss: 1.0245792865753174
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 4.738921165466309 | KNN Loss: 3.7188212871551514 | BCE Loss: 1.0200999975204468
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 4.722631931304932 | KNN Loss: 3.6905257701873

Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.7431769371032715 | KNN Loss: 3.6902525424957275 | BCE Loss: 1.052924394607544
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.755003452301025 | KNN Loss: 3.7211110591888428 | BCE Loss: 1.0338923931121826
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.743887901306152 | KNN Loss: 3.7114546298980713 | BCE Loss: 1.032433032989502
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.713747024536133 | KNN Loss: 3.69150447845459 | BCE Loss: 1.0222423076629639
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.718604564666748 | KNN Loss: 3.687448501586914 | BCE Loss: 1.0311559438705444
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 4.720799446105957 | KNN Loss: 3.6797728538513184 | BCE Loss: 1.0410263538360596
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 4.698553085327148 | KNN Loss: 3.6851823329925537 | BCE Loss: 1.0133705139160156
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 4.697116374969482 | KNN Loss: 3.679446697235107

Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.670753002166748 | KNN Loss: 3.656561851501465 | BCE Loss: 1.0141910314559937
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.73293399810791 | KNN Loss: 3.679028272628784 | BCE Loss: 1.0539054870605469
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.773617744445801 | KNN Loss: 3.706202745437622 | BCE Loss: 1.0674148797988892
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.658775329589844 | KNN Loss: 3.649751901626587 | BCE Loss: 1.009023666381836
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.724366188049316 | KNN Loss: 3.6884162425994873 | BCE Loss: 1.035949945449829
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 4.76778507232666 | KNN Loss: 3.7399439811706543 | BCE Loss: 1.027841329574585
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 4.754150390625 | KNN Loss: 3.6985888481140137 | BCE Loss: 1.0555617809295654
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 4.8139262199401855 | KNN Loss: 3.739851951599121 | BCE 

Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.733274936676025 | KNN Loss: 3.7063345909118652 | BCE Loss: 1.0269403457641602
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.655884265899658 | KNN Loss: 3.644085645675659 | BCE Loss: 1.011798620223999
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.745719909667969 | KNN Loss: 3.721095323562622 | BCE Loss: 1.0246245861053467
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.760196208953857 | KNN Loss: 3.7233102321624756 | BCE Loss: 1.0368858575820923
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 4.697086334228516 | KNN Loss: 3.690844774246216 | BCE Loss: 1.006241798400879
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 4.7418599128723145 | KNN Loss: 3.716752767562866 | BCE Loss: 1.0251071453094482
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 4.703171730041504 | KNN Loss: 3.688396692276001 | BCE Loss: 1.0147749185562134
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 4.667035102844238 | KNN Loss: 3.6743946075439453

Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.690882682800293 | KNN Loss: 3.663172960281372 | BCE Loss: 1.0277094841003418
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.711453914642334 | KNN Loss: 3.674525737762451 | BCE Loss: 1.0369281768798828
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.704532623291016 | KNN Loss: 3.6768031120300293 | BCE Loss: 1.0277293920516968
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.74006986618042 | KNN Loss: 3.720184564590454 | BCE Loss: 1.0198853015899658
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 4.698287010192871 | KNN Loss: 3.671630859375 | BCE Loss: 1.0266563892364502
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 4.737943649291992 | KNN Loss: 3.7200491428375244 | BCE Loss: 1.0178947448730469
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 4.749603748321533 | KNN Loss: 3.713608980178833 | BCE Loss: 1.0359947681427002
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 4.713703155517578 | KNN Loss: 3.6990466117858887 | B

Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.726428985595703 | KNN Loss: 3.70947527885437 | BCE Loss: 1.0169538259506226
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.697216987609863 | KNN Loss: 3.6754963397979736 | BCE Loss: 1.0217208862304688
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.7593536376953125 | KNN Loss: 3.7357773780822754 | BCE Loss: 1.023576021194458
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.7160773277282715 | KNN Loss: 3.661588191986084 | BCE Loss: 1.0544891357421875
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 4.757218360900879 | KNN Loss: 3.727879762649536 | BCE Loss: 1.0293388366699219
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 4.703752517700195 | KNN Loss: 3.6653473377227783 | BCE Loss: 1.0384050607681274
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 4.717349052429199 | KNN Loss: 3.6914241313934326 | BCE Loss: 1.0259249210357666
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 4.707173824310303 | KNN Loss: 3.68599891662597

Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.663125514984131 | KNN Loss: 3.6654231548309326 | BCE Loss: 0.9977021813392639
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.677103042602539 | KNN Loss: 3.6687116622924805 | BCE Loss: 1.0083911418914795
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.751274585723877 | KNN Loss: 3.687795400619507 | BCE Loss: 1.0634791851043701
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.728036403656006 | KNN Loss: 3.6967201232910156 | BCE Loss: 1.0313162803649902
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 4.705807209014893 | KNN Loss: 3.6932244300842285 | BCE Loss: 1.0125826597213745
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 4.699333190917969 | KNN Loss: 3.6660399436950684 | BCE Loss: 1.0332930088043213
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 4.682737827301025 | KNN Loss: 3.663724184036255 | BCE Loss: 1.0190136432647705
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 4.673327445983887 | KNN Loss: 3.672988176345

Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.686240196228027 | KNN Loss: 3.6769490242004395 | BCE Loss: 1.0092909336090088
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.713218688964844 | KNN Loss: 3.6742780208587646 | BCE Loss: 1.0389409065246582
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.69713830947876 | KNN Loss: 3.6589245796203613 | BCE Loss: 1.0382137298583984
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.675715446472168 | KNN Loss: 3.6596076488494873 | BCE Loss: 1.0161077976226807
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 4.70760440826416 | KNN Loss: 3.6653239727020264 | BCE Loss: 1.042280673980713
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 4.670154571533203 | KNN Loss: 3.6453938484191895 | BCE Loss: 1.0247609615325928
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 4.751132965087891 | KNN Loss: 3.7314257621765137 | BCE Loss: 1.019707441329956
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 4.714669227600098 | KNN Loss: 3.684413909912109

Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.690254211425781 | KNN Loss: 3.6842732429504395 | BCE Loss: 1.0059807300567627
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.714166164398193 | KNN Loss: 3.673591375350952 | BCE Loss: 1.0405747890472412
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.695077419281006 | KNN Loss: 3.677628993988037 | BCE Loss: 1.0174484252929688
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.762144088745117 | KNN Loss: 3.7387232780456543 | BCE Loss: 1.0234205722808838
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 4.707087516784668 | KNN Loss: 3.6951851844787598 | BCE Loss: 1.0119025707244873
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 4.733908176422119 | KNN Loss: 3.703760862350464 | BCE Loss: 1.0301471948623657
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 4.743232727050781 | KNN Loss: 3.729060649871826 | BCE Loss: 1.014172077178955
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 4.718256950378418 | KNN Loss: 3.70161509513855 |

Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.682867050170898 | KNN Loss: 3.6782047748565674 | BCE Loss: 1.004662275314331
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.756335258483887 | KNN Loss: 3.707280397415161 | BCE Loss: 1.0490546226501465
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.712508201599121 | KNN Loss: 3.673041582107544 | BCE Loss: 1.0394665002822876
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.72740364074707 | KNN Loss: 3.7171902656555176 | BCE Loss: 1.0102134943008423
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 4.697874069213867 | KNN Loss: 3.6791491508483887 | BCE Loss: 1.018724799156189
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 4.705450534820557 | KNN Loss: 3.6719131469726562 | BCE Loss: 1.0335373878479004
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 4.713605880737305 | KNN Loss: 3.6799371242523193 | BCE Loss: 1.0336689949035645
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 4.705572128295898 | KNN Loss: 3.667717695236206

Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.760621547698975 | KNN Loss: 3.7019033432006836 | BCE Loss: 1.0587183237075806
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.719483375549316 | KNN Loss: 3.6946074962615967 | BCE Loss: 1.0248757600784302
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.71480131149292 | KNN Loss: 3.681589126586914 | BCE Loss: 1.0332121849060059
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.73521614074707 | KNN Loss: 3.6880173683166504 | BCE Loss: 1.0471986532211304
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 4.7194294929504395 | KNN Loss: 3.6698851585388184 | BCE Loss: 1.0495442152023315
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 4.683517932891846 | KNN Loss: 3.6646318435668945 | BCE Loss: 1.0188862085342407
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 4.706836223602295 | KNN Loss: 3.6750364303588867 | BCE Loss: 1.0317997932434082
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 4.745070934295654 | KNN Loss: 3.7051889896392

Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.729122161865234 | KNN Loss: 3.691704273223877 | BCE Loss: 1.0374176502227783
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.758906364440918 | KNN Loss: 3.700498104095459 | BCE Loss: 1.058408498764038
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.719759941101074 | KNN Loss: 3.7201356887817383 | BCE Loss: 0.9996241331100464
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.7085161209106445 | KNN Loss: 3.6874711513519287 | BCE Loss: 1.0210449695587158
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 4.774001121520996 | KNN Loss: 3.71958327293396 | BCE Loss: 1.0544177293777466
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 4.717996597290039 | KNN Loss: 3.68058705329895 | BCE Loss: 1.0374093055725098
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 4.716028213500977 | KNN Loss: 3.7086071968078613 | BCE Loss: 1.0074210166931152
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 4.70789909362793 | KNN Loss: 3.7055933475494385 |

Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.6973490715026855 | KNN Loss: 3.6869826316833496 | BCE Loss: 1.010366439819336
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.762295722961426 | KNN Loss: 3.740694046020508 | BCE Loss: 1.021601915359497
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.765730381011963 | KNN Loss: 3.7262954711914062 | BCE Loss: 1.0394349098205566
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.670647144317627 | KNN Loss: 3.644244432449341 | BCE Loss: 1.0264027118682861
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 4.692130088806152 | KNN Loss: 3.671787977218628 | BCE Loss: 1.0203418731689453
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 4.709624767303467 | KNN Loss: 3.69140362739563 | BCE Loss: 1.0182210206985474
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 4.742819786071777 | KNN Loss: 3.708972930908203 | BCE Loss: 1.0338467359542847
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 4.745131492614746 | KNN Loss: 3.6836631298065186 

Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.703286170959473 | KNN Loss: 3.7073047161102295 | BCE Loss: 0.9959813356399536
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.7307538986206055 | KNN Loss: 3.6719863414764404 | BCE Loss: 1.058767318725586
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.744879245758057 | KNN Loss: 3.7323477268218994 | BCE Loss: 1.0125315189361572
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.709120273590088 | KNN Loss: 3.7197728157043457 | BCE Loss: 0.9893473386764526
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 4.719443321228027 | KNN Loss: 3.6686720848083496 | BCE Loss: 1.0507714748382568
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 4.702714920043945 | KNN Loss: 3.6722261905670166 | BCE Loss: 1.0304887294769287
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 4.727326393127441 | KNN Loss: 3.696885824203491 | BCE Loss: 1.0304408073425293
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 4.694815158843994 | KNN Loss: 3.673060655593

Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.696038246154785 | KNN Loss: 3.683161973953247 | BCE Loss: 1.0128765106201172
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.717602729797363 | KNN Loss: 3.699557304382324 | BCE Loss: 1.01804518699646
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.715530872344971 | KNN Loss: 3.663956880569458 | BCE Loss: 1.0515739917755127
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.771039962768555 | KNN Loss: 3.7389886379241943 | BCE Loss: 1.0320510864257812
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 4.737974166870117 | KNN Loss: 3.699420213699341 | BCE Loss: 1.0385539531707764
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 4.757017135620117 | KNN Loss: 3.7121214866638184 | BCE Loss: 1.0448956489562988
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 4.698191165924072 | KNN Loss: 3.702141523361206 | BCE Loss: 0.9960498213768005
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 4.745579719543457 | KNN Loss: 3.698697328567505 | 

Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.716294288635254 | KNN Loss: 3.6879234313964844 | BCE Loss: 1.0283708572387695
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.706413269042969 | KNN Loss: 3.661684513092041 | BCE Loss: 1.0447287559509277
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.7120442390441895 | KNN Loss: 3.714740514755249 | BCE Loss: 0.9973036646842957
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 4.732468128204346 | KNN Loss: 3.701120376586914 | BCE Loss: 1.0313477516174316
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 4.715949058532715 | KNN Loss: 3.708346366882324 | BCE Loss: 1.0076026916503906
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 4.73274564743042 | KNN Loss: 3.7274937629699707 | BCE Loss: 1.0052520036697388
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 4.6937761306762695 | KNN Loss: 3.668198347091675 | BCE Loss: 1.0255775451660156
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 4.6882734298706055 | KNN Loss: 3.6559057235717

Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.737476348876953 | KNN Loss: 3.7125961780548096 | BCE Loss: 1.024880051612854
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.73302698135376 | KNN Loss: 3.67484188079834 | BCE Loss: 1.05818510055542
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.700045108795166 | KNN Loss: 3.6693005561828613 | BCE Loss: 1.0307446718215942
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 4.69725227355957 | KNN Loss: 3.6881394386291504 | BCE Loss: 1.009113073348999
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 4.690520763397217 | KNN Loss: 3.6620655059814453 | BCE Loss: 1.0284552574157715
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 4.747366905212402 | KNN Loss: 3.694300651550293 | BCE Loss: 1.0530664920806885
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 4.665631294250488 | KNN Loss: 3.6557323932647705 | BCE Loss: 1.0098990201950073
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 4.743204116821289 | KNN Loss: 3.7181787490844727 | BC

Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.7802228927612305 | KNN Loss: 3.7588999271392822 | BCE Loss: 1.0213230848312378
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.673035144805908 | KNN Loss: 3.671469211578369 | BCE Loss: 1.0015658140182495
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.734552383422852 | KNN Loss: 3.7308170795440674 | BCE Loss: 1.003735065460205
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 4.733162879943848 | KNN Loss: 3.708827257156372 | BCE Loss: 1.0243357419967651
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 4.796446323394775 | KNN Loss: 3.7646076679229736 | BCE Loss: 1.0318386554718018
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 4.76054573059082 | KNN Loss: 3.7119338512420654 | BCE Loss: 1.0486116409301758
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 4.728576183319092 | KNN Loss: 3.7051239013671875 | BCE Loss: 1.0234522819519043
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 4.7119059562683105 | KNN Loss: 3.707667589187

Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.717179298400879 | KNN Loss: 3.6856465339660645 | BCE Loss: 1.0315330028533936
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.699735164642334 | KNN Loss: 3.6713080406188965 | BCE Loss: 1.0284271240234375
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.7407636642456055 | KNN Loss: 3.693535804748535 | BCE Loss: 1.0472276210784912
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 4.7882795333862305 | KNN Loss: 3.7234835624694824 | BCE Loss: 1.0647960901260376
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 4.648393630981445 | KNN Loss: 3.6593244075775146 | BCE Loss: 0.9890689849853516
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 4.685923099517822 | KNN Loss: 3.6749303340911865 | BCE Loss: 1.0109926462173462
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 4.7249345779418945 | KNN Loss: 3.6860744953155518 | BCE Loss: 1.0388603210449219
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 4.713101387023926 | KNN Loss: 3.67882728

Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.716984748840332 | KNN Loss: 3.6645429134368896 | BCE Loss: 1.052441954612732
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.726771354675293 | KNN Loss: 3.6962080001831055 | BCE Loss: 1.0305631160736084
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.729870796203613 | KNN Loss: 3.7021656036376953 | BCE Loss: 1.0277049541473389
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.705059051513672 | KNN Loss: 3.7195026874542236 | BCE Loss: 0.9855563640594482
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 4.6992058753967285 | KNN Loss: 3.6730051040649414 | BCE Loss: 1.0262008905410767
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 4.747584342956543 | KNN Loss: 3.686039447784424 | BCE Loss: 1.0615448951721191
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 4.733644485473633 | KNN Loss: 3.683284044265747 | BCE Loss: 1.0503602027893066
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 4.743600845336914 | KNN Loss: 3.7328524589538

Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.726590156555176 | KNN Loss: 3.656073808670044 | BCE Loss: 1.0705164670944214
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.750820159912109 | KNN Loss: 3.7186710834503174 | BCE Loss: 1.032149314880371
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.727725028991699 | KNN Loss: 3.7031919956207275 | BCE Loss: 1.0245330333709717
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.721490859985352 | KNN Loss: 3.6906983852386475 | BCE Loss: 1.0307927131652832
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 4.728784561157227 | KNN Loss: 3.67907977104187 | BCE Loss: 1.049704670906067
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 4.707558631896973 | KNN Loss: 3.70527720451355 | BCE Loss: 1.0022815465927124
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 4.707256317138672 | KNN Loss: 3.6947617530822754 | BCE Loss: 1.0124948024749756
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 4.685079574584961 | KNN Loss: 3.668729543685913 |

Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.733777046203613 | KNN Loss: 3.692228317260742 | BCE Loss: 1.0415486097335815
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.715044975280762 | KNN Loss: 3.677948236465454 | BCE Loss: 1.0370965003967285
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.741888046264648 | KNN Loss: 3.7062532901763916 | BCE Loss: 1.0356345176696777
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.699357509613037 | KNN Loss: 3.6756443977355957 | BCE Loss: 1.0237131118774414
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 4.725636005401611 | KNN Loss: 3.6912457942962646 | BCE Loss: 1.0343900918960571
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 4.746334075927734 | KNN Loss: 3.6990106105804443 | BCE Loss: 1.047323226928711
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 4.717850685119629 | KNN Loss: 3.678654193878174 | BCE Loss: 1.039196252822876
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 4.73098087310791 | KNN Loss: 3.6944801807403564 

Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.695993423461914 | KNN Loss: 3.6871416568756104 | BCE Loss: 1.0088517665863037
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.7436747550964355 | KNN Loss: 3.7013444900512695 | BCE Loss: 1.042330265045166
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.745013236999512 | KNN Loss: 3.719170570373535 | BCE Loss: 1.0258427858352661
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.722750663757324 | KNN Loss: 3.6956098079681396 | BCE Loss: 1.0271408557891846
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 4.711954116821289 | KNN Loss: 3.6852829456329346 | BCE Loss: 1.0266709327697754
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 4.706849098205566 | KNN Loss: 3.688894033432007 | BCE Loss: 1.0179550647735596
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 4.733284950256348 | KNN Loss: 3.708181381225586 | BCE Loss: 1.0251034498214722
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 4.721634864807129 | KNN Loss: 3.70802593231201

Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.6972222328186035 | KNN Loss: 3.671234130859375 | BCE Loss: 1.0259881019592285
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.695345878601074 | KNN Loss: 3.669992208480835 | BCE Loss: 1.0253534317016602
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.705785751342773 | KNN Loss: 3.69469952583313 | BCE Loss: 1.0110859870910645
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.731802940368652 | KNN Loss: 3.6968953609466553 | BCE Loss: 1.034907579421997
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 4.780817031860352 | KNN Loss: 3.7310500144958496 | BCE Loss: 1.049767017364502
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 4.679352283477783 | KNN Loss: 3.636094093322754 | BCE Loss: 1.0432581901550293
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 4.761282444000244 | KNN Loss: 3.7173852920532227 | BCE Loss: 1.043897032737732
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 4.656574249267578 | KNN Loss: 3.6647729873657227 

Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.6569976806640625 | KNN Loss: 3.6538920402526855 | BCE Loss: 1.003105640411377
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.7059736251831055 | KNN Loss: 3.6790621280670166 | BCE Loss: 1.0269112586975098
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.715023517608643 | KNN Loss: 3.6542091369628906 | BCE Loss: 1.060814380645752
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.751033306121826 | KNN Loss: 3.719531774520874 | BCE Loss: 1.0315016508102417
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 4.732845783233643 | KNN Loss: 3.6994423866271973 | BCE Loss: 1.0334035158157349
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 4.72944974899292 | KNN Loss: 3.699167013168335 | BCE Loss: 1.030282735824585
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 4.774263381958008 | KNN Loss: 3.7340073585510254 | BCE Loss: 1.0402560234069824
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 4.7733049392700195 | KNN Loss: 3.74291634559631

Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.692770481109619 | KNN Loss: 3.685917615890503 | BCE Loss: 1.0068527460098267
Epoch   449: reducing learning rate of group 0 to 1.1270e-07.
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.7409772872924805 | KNN Loss: 3.6936752796173096 | BCE Loss: 1.04730224609375
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.7634100914001465 | KNN Loss: 3.726328134536743 | BCE Loss: 1.0370819568634033
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 4.75009822845459 | KNN Loss: 3.7112393379211426 | BCE Loss: 1.0388586521148682
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 4.666780471801758 | KNN Loss: 3.674314498901367 | BCE Loss: 0.9924658536911011
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 4.720156192779541 | KNN Loss: 3.6891753673553467 | BCE Loss: 1.0309808254241943
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 4.714337348937988 | KNN Loss: 3.6602962017059326 | BCE Loss: 1.0540409088134766
Epoch 450 / 500 | iteration 0 / 30

Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.792212963104248 | KNN Loss: 3.7266476154327393 | BCE Loss: 1.0655654668807983
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.684537410736084 | KNN Loss: 3.6626217365264893 | BCE Loss: 1.0219155550003052
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.727397441864014 | KNN Loss: 3.6909310817718506 | BCE Loss: 1.036466360092163
Epoch   460: reducing learning rate of group 0 to 7.8888e-08.
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 4.748035430908203 | KNN Loss: 3.686722993850708 | BCE Loss: 1.0613125562667847
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 4.705294609069824 | KNN Loss: 3.6800954341888428 | BCE Loss: 1.0251991748809814
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 4.7114105224609375 | KNN Loss: 3.686742067337036 | BCE Loss: 1.0246682167053223
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 4.734474182128906 | KNN Loss: 3.681114912033081 | BCE Loss: 1.0533595085144043
Epoch 460 / 500 | iteration 20 / 

Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.760499000549316 | KNN Loss: 3.741896629333496 | BCE Loss: 1.0186023712158203
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.7106242179870605 | KNN Loss: 3.6862051486968994 | BCE Loss: 1.0244190692901611
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.741048812866211 | KNN Loss: 3.709479808807373 | BCE Loss: 1.031569004058838
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 4.735234260559082 | KNN Loss: 3.6780471801757812 | BCE Loss: 1.0571868419647217
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 4.744417190551758 | KNN Loss: 3.7214269638061523 | BCE Loss: 1.0229899883270264
Epoch   471: reducing learning rate of group 0 to 5.5221e-08.
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 4.764540672302246 | KNN Loss: 3.7196900844573975 | BCE Loss: 1.0448505878448486
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 4.7561492919921875 | KNN Loss: 3.7200989723205566 | BCE Loss: 1.0360503196716309
Epoch 471 / 500 | iteration 10 /

Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.724907875061035 | KNN Loss: 3.7121329307556152 | BCE Loss: 1.0127747058868408
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.679256439208984 | KNN Loss: 3.6544671058654785 | BCE Loss: 1.0247893333435059
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.726641654968262 | KNN Loss: 3.69728946685791 | BCE Loss: 1.0293521881103516
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 4.739554405212402 | KNN Loss: 3.729313373565674 | BCE Loss: 1.0102410316467285
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 4.686199188232422 | KNN Loss: 3.6696643829345703 | BCE Loss: 1.0165349245071411
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 4.730989933013916 | KNN Loss: 3.6781208515167236 | BCE Loss: 1.0528690814971924
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 4.7508225440979 | KNN Loss: 3.6973724365234375 | BCE Loss: 1.053450107574463
Epoch   482: reducing learning rate of group 0 to 3.8655e-08.
Epoch 482 / 500 | iteration 0 / 30 |

Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.682605266571045 | KNN Loss: 3.6712288856506348 | BCE Loss: 1.0113763809204102
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.703763008117676 | KNN Loss: 3.685659408569336 | BCE Loss: 1.0181033611297607
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.735307693481445 | KNN Loss: 3.7150192260742188 | BCE Loss: 1.0202887058258057
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 4.648627758026123 | KNN Loss: 3.6306040287017822 | BCE Loss: 1.0180238485336304
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 4.706360816955566 | KNN Loss: 3.6899285316467285 | BCE Loss: 1.016432523727417
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 4.772653102874756 | KNN Loss: 3.7352843284606934 | BCE Loss: 1.0373687744140625
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 4.7292914390563965 | KNN Loss: 3.6734671592712402 | BCE Loss: 1.0558241605758667
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 4.723285675048828 | KNN Loss: 3.70104026794

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.8700,  3.2148,  2.5888,  3.2199,  2.9580,  0.7435,  2.5452,  1.7208,
          2.3905,  2.0455,  2.3058,  1.8545,  0.7572,  1.9303,  1.3485,  1.4854,
          2.4915,  2.9553,  2.8168,  1.8230,  1.8092,  2.7377,  2.3335,  2.2920,
          2.6009,  1.8221,  2.0219,  1.3945,  1.5464,  0.3820, -0.2454,  0.9553,
          0.2409,  1.0272,  1.3222,  1.4977,  1.0836,  2.8201,  0.8823,  1.3845,
          0.9177, -0.6213, -0.2419,  2.4166,  2.2063,  0.6436, -0.1784, -0.0311,
          1.5638,  2.1203,  1.9193,  0.2048,  1.4929,  0.5505, -0.6681,  1.1283,
          1.5035,  0.8375,  1.4346,  1.8754,  0.6843,  0.9019,  0.2014,  1.7707,
          1.2772,  1.5920, -2.0957,  0.3177,  2.0157,  1.9995,  2.4371,  0.5219,
          1.3739,  2.4774,  2.0058,  1.2663,  0.2416,  0.7286,  0.2631,  1.6594,
          0.0389,  0.4202,  1.9244, -0.2867,  0.2105, -1.0360, -2.3141, -0.4303,
          0.5728, -1.8165,  0.4920, -0.1380, -0.5948, -0.8402,  0.5739,  1.3905,
         -0.8045, -0.6958,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 78.14it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 026 | Total loss: 9.626 | Reg loss: 0.014 | Tree loss: 9.626 | Accuracy: 0.000000 | 6.821 sec/iter
Epoch: 00 | Batch: 001 / 026 | Total loss: 9.624 | Reg loss: 0.013 | Tree loss: 9.624 | Accuracy: 0.000000 | 6.117 sec/iter
Epoch: 00 | Batch: 002 / 026 | Total loss: 9.623 | Reg loss: 0.012 | Tree loss: 9.623 | Accuracy: 0.000000 | 5.831 sec/iter
Epoch: 00 | Batch: 003 / 026 | Total loss: 9.621 | Reg loss: 0.011 | Tree loss: 9.621 | Accuracy: 0.000000 | 5.674 sec/iter
Epoch: 00 | Batch: 004 / 026 | Total loss: 9.619 | Reg loss: 0.010 | Tree loss: 9.619 | Accuracy: 0.000000 | 5.561 sec/iter
Epoch: 00 | Batch: 005 / 026 | Total loss: 9.617 | Reg loss: 0.009 | Tree loss: 9.617 | Accuracy: 0.000000 | 5.479 sec/iter
Epoch: 00 | Batch: 006 / 026 | Total loss: 9.616 | Reg loss: 0.008 | Tree loss: 9.616 | 

Epoch: 02 | Batch: 008 / 026 | Total loss: 9.578 | Reg loss: 0.006 | Tree loss: 9.578 | Accuracy: 0.111328 | 5.041 sec/iter
Epoch: 02 | Batch: 009 / 026 | Total loss: 9.579 | Reg loss: 0.006 | Tree loss: 9.579 | Accuracy: 0.115234 | 5.037 sec/iter
Epoch: 02 | Batch: 010 / 026 | Total loss: 9.576 | Reg loss: 0.007 | Tree loss: 9.576 | Accuracy: 0.138672 | 5.03 sec/iter
Epoch: 02 | Batch: 011 / 026 | Total loss: 9.575 | Reg loss: 0.007 | Tree loss: 9.575 | Accuracy: 0.132812 | 5.015 sec/iter
Epoch: 02 | Batch: 012 / 026 | Total loss: 9.573 | Reg loss: 0.007 | Tree loss: 9.573 | Accuracy: 0.132812 | 5.026 sec/iter
Epoch: 02 | Batch: 013 / 026 | Total loss: 9.573 | Reg loss: 0.007 | Tree loss: 9.573 | Accuracy: 0.134766 | 5.038 sec/iter
Epoch: 02 | Batch: 014 / 026 | Total loss: 9.572 | Reg loss: 0.007 | Tree loss: 9.572 | Accuracy: 0.130859 | 5.051 sec/iter
Epoch: 02 | Batch: 015 / 026 | Total loss: 9.570 | Reg loss: 0.008 | Tree loss: 9.570 | Accuracy: 0.138672 | 5.064 sec/iter
Epoch: 02

Epoch: 04 | Batch: 017 / 026 | Total loss: 9.478 | Reg loss: 0.014 | Tree loss: 9.478 | Accuracy: 0.130859 | 5.113 sec/iter
Epoch: 04 | Batch: 018 / 026 | Total loss: 9.483 | Reg loss: 0.015 | Tree loss: 9.483 | Accuracy: 0.111328 | 5.115 sec/iter
Epoch: 04 | Batch: 019 / 026 | Total loss: 9.474 | Reg loss: 0.015 | Tree loss: 9.474 | Accuracy: 0.099609 | 5.117 sec/iter
Epoch: 04 | Batch: 020 / 026 | Total loss: 9.463 | Reg loss: 0.016 | Tree loss: 9.463 | Accuracy: 0.105469 | 5.117 sec/iter
Epoch: 04 | Batch: 021 / 026 | Total loss: 9.454 | Reg loss: 0.016 | Tree loss: 9.454 | Accuracy: 0.105469 | 5.116 sec/iter
Epoch: 04 | Batch: 022 / 026 | Total loss: 9.450 | Reg loss: 0.016 | Tree loss: 9.450 | Accuracy: 0.111328 | 5.114 sec/iter
Epoch: 04 | Batch: 023 / 026 | Total loss: 9.445 | Reg loss: 0.017 | Tree loss: 9.445 | Accuracy: 0.095703 | 5.111 sec/iter
Epoch: 04 | Batch: 024 / 026 | Total loss: 9.430 | Reg loss: 0.017 | Tree loss: 9.430 | Accuracy: 0.058594 | 5.109 sec/iter
Epoch: 0

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 07 | Batch: 000 / 026 | Total loss: 9.285 | Reg loss: 0.017 | Tree loss: 9.285 | Accuracy: 0.128906 | 5.137 sec/iter
Epoch: 07 | Batch: 001 / 026 | Total loss: 9.248 | Reg loss: 0.017 | Tree loss: 9.248 | Accuracy: 0.109375 | 5.139 sec/iter
Epoch: 07 | Batch: 002 / 026 | Total loss: 9.236 | Reg loss: 0.017 | Tree loss: 9.236 | Accuracy: 0.107422 | 5.14 sec/iter
Epoch: 07 | Batch: 003 / 026 | Total loss: 9.213 | Reg loss: 0.017 | Tree loss: 9.213 | Accuracy: 0.097656 | 5.14 sec/iter
Epoch: 07 | Batch: 004 / 026 | Total loss: 9.208 | Reg loss: 0.017 | Tree loss: 9.208 | Accuracy: 0.103516 | 5.141 sec/iter
Epoch: 07 | Batch: 005 / 026 | Tota

Epoch: 09 | Batch: 006 / 026 | Total loss: 8.710 | Reg loss: 0.021 | Tree loss: 8.710 | Accuracy: 0.056641 | 5.129 sec/iter
Epoch: 09 | Batch: 007 / 026 | Total loss: 8.681 | Reg loss: 0.021 | Tree loss: 8.681 | Accuracy: 0.058594 | 5.128 sec/iter
Epoch: 09 | Batch: 008 / 026 | Total loss: 8.633 | Reg loss: 0.022 | Tree loss: 8.633 | Accuracy: 0.066406 | 5.124 sec/iter
Epoch: 09 | Batch: 009 / 026 | Total loss: 8.592 | Reg loss: 0.022 | Tree loss: 8.592 | Accuracy: 0.064453 | 5.126 sec/iter
Epoch: 09 | Batch: 010 / 026 | Total loss: 8.551 | Reg loss: 0.022 | Tree loss: 8.551 | Accuracy: 0.062500 | 5.128 sec/iter
Epoch: 09 | Batch: 011 / 026 | Total loss: 8.539 | Reg loss: 0.022 | Tree loss: 8.539 | Accuracy: 0.062500 | 5.13 sec/iter
Epoch: 09 | Batch: 012 / 026 | Total loss: 8.552 | Reg loss: 0.023 | Tree loss: 8.552 | Accuracy: 0.066406 | 5.132 sec/iter
Epoch: 09 | Batch: 013 / 026 | Total loss: 8.505 | Reg loss: 0.023 | Tree loss: 8.505 | Accuracy: 0.070312 | 5.135 sec/iter
Epoch: 09

Epoch: 11 | Batch: 015 / 026 | Total loss: 7.885 | Reg loss: 0.025 | Tree loss: 7.885 | Accuracy: 0.062500 | 5.128 sec/iter
Epoch: 11 | Batch: 016 / 026 | Total loss: 7.870 | Reg loss: 0.026 | Tree loss: 7.870 | Accuracy: 0.064453 | 5.128 sec/iter
Epoch: 11 | Batch: 017 / 026 | Total loss: 7.841 | Reg loss: 0.026 | Tree loss: 7.841 | Accuracy: 0.056641 | 5.128 sec/iter
Epoch: 11 | Batch: 018 / 026 | Total loss: 7.794 | Reg loss: 0.026 | Tree loss: 7.794 | Accuracy: 0.066406 | 5.127 sec/iter
Epoch: 11 | Batch: 019 / 026 | Total loss: 7.786 | Reg loss: 0.026 | Tree loss: 7.786 | Accuracy: 0.054688 | 5.126 sec/iter
Epoch: 11 | Batch: 020 / 026 | Total loss: 7.769 | Reg loss: 0.026 | Tree loss: 7.769 | Accuracy: 0.068359 | 5.125 sec/iter
Epoch: 11 | Batch: 021 / 026 | Total loss: 7.702 | Reg loss: 0.027 | Tree loss: 7.702 | Accuracy: 0.072266 | 5.124 sec/iter
Epoch: 11 | Batch: 022 / 026 | Total loss: 7.736 | Reg loss: 0.027 | Tree loss: 7.736 | Accuracy: 0.056641 | 5.123 sec/iter
Epoch: 1

Epoch: 13 | Batch: 024 / 026 | Total loss: 7.177 | Reg loss: 0.028 | Tree loss: 7.177 | Accuracy: 0.054688 | 5.133 sec/iter
Epoch: 13 | Batch: 025 / 026 | Total loss: 7.139 | Reg loss: 0.028 | Tree loss: 7.139 | Accuracy: 0.083871 | 5.129 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 14 | Batch: 000 / 026 | Total loss: 7.455 | Reg loss: 0.026 | Tree loss: 7.455 | Accuracy: 0.060547 | 5.136 sec/iter
Epoch: 14 | Batch: 001 / 026 | Total loss: 7.448 | Reg loss: 0.026 | Tree loss: 7.448 | Accuracy: 0.054688 | 5.137 sec/iter
Epoch: 14 | Batch: 002 / 026 | Total loss: 7.385 | Reg loss: 0.026 | Tree loss: 7.385 | Accuracy: 0.046875 | 5.138 sec/iter
Epoch: 14 | Batch: 003 / 026 | To

Epoch: 16 | Batch: 004 / 026 | Total loss: 6.857 | Reg loss: 0.027 | Tree loss: 6.857 | Accuracy: 0.054688 | 5.138 sec/iter
Epoch: 16 | Batch: 005 / 026 | Total loss: 6.829 | Reg loss: 0.027 | Tree loss: 6.829 | Accuracy: 0.054688 | 5.137 sec/iter
Epoch: 16 | Batch: 006 / 026 | Total loss: 6.799 | Reg loss: 0.027 | Tree loss: 6.799 | Accuracy: 0.076172 | 5.137 sec/iter
Epoch: 16 | Batch: 007 / 026 | Total loss: 6.791 | Reg loss: 0.027 | Tree loss: 6.791 | Accuracy: 0.046875 | 5.135 sec/iter
Epoch: 16 | Batch: 008 / 026 | Total loss: 6.756 | Reg loss: 0.028 | Tree loss: 6.756 | Accuracy: 0.062500 | 5.132 sec/iter
Epoch: 16 | Batch: 009 / 026 | Total loss: 6.788 | Reg loss: 0.028 | Tree loss: 6.788 | Accuracy: 0.066406 | 5.134 sec/iter
Epoch: 16 | Batch: 010 / 026 | Total loss: 6.704 | Reg loss: 0.028 | Tree loss: 6.704 | Accuracy: 0.058594 | 5.135 sec/iter
Epoch: 16 | Batch: 011 / 026 | Total loss: 6.713 | Reg loss: 0.028 | Tree loss: 6.713 | Accuracy: 0.054688 | 5.135 sec/iter
Epoch: 1

Epoch: 18 | Batch: 013 / 026 | Total loss: 6.216 | Reg loss: 0.029 | Tree loss: 6.216 | Accuracy: 0.062500 | 5.121 sec/iter
Epoch: 18 | Batch: 014 / 026 | Total loss: 6.201 | Reg loss: 0.029 | Tree loss: 6.201 | Accuracy: 0.072266 | 5.122 sec/iter
Epoch: 18 | Batch: 015 / 026 | Total loss: 6.231 | Reg loss: 0.029 | Tree loss: 6.231 | Accuracy: 0.068359 | 5.122 sec/iter
Epoch: 18 | Batch: 016 / 026 | Total loss: 6.177 | Reg loss: 0.029 | Tree loss: 6.177 | Accuracy: 0.062500 | 5.122 sec/iter
Epoch: 18 | Batch: 017 / 026 | Total loss: 6.181 | Reg loss: 0.029 | Tree loss: 6.181 | Accuracy: 0.062500 | 5.121 sec/iter
Epoch: 18 | Batch: 018 / 026 | Total loss: 6.167 | Reg loss: 0.029 | Tree loss: 6.167 | Accuracy: 0.048828 | 5.121 sec/iter
Epoch: 18 | Batch: 019 / 026 | Total loss: 6.141 | Reg loss: 0.029 | Tree loss: 6.141 | Accuracy: 0.050781 | 5.12 sec/iter
Epoch: 18 | Batch: 020 / 026 | Total loss: 6.124 | Reg loss: 0.029 | Tree loss: 6.124 | Accuracy: 0.044922 | 5.12 sec/iter
Epoch: 18 

Epoch: 20 | Batch: 022 / 026 | Total loss: 5.704 | Reg loss: 0.030 | Tree loss: 5.704 | Accuracy: 0.050781 | 5.113 sec/iter
Epoch: 20 | Batch: 023 / 026 | Total loss: 5.688 | Reg loss: 0.030 | Tree loss: 5.688 | Accuracy: 0.052734 | 5.113 sec/iter
Epoch: 20 | Batch: 024 / 026 | Total loss: 5.684 | Reg loss: 0.030 | Tree loss: 5.684 | Accuracy: 0.060547 | 5.112 sec/iter
Epoch: 20 | Batch: 025 / 026 | Total loss: 5.653 | Reg loss: 0.030 | Tree loss: 5.653 | Accuracy: 0.083871 | 5.109 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 21 | Batch: 000 / 026 | Total loss: 5.825 | Reg loss: 0.029 | Tree loss: 5.825 | Accuracy: 0.068359 | 5.116 sec/iter
Epoch: 21 | Batch: 001 / 026 | To

Epoch: 23 | Batch: 002 / 026 | Total loss: 5.382 | Reg loss: 0.030 | Tree loss: 5.382 | Accuracy: 0.062500 | 5.123 sec/iter
Epoch: 23 | Batch: 003 / 026 | Total loss: 5.397 | Reg loss: 0.030 | Tree loss: 5.397 | Accuracy: 0.046875 | 5.122 sec/iter
Epoch: 23 | Batch: 004 / 026 | Total loss: 5.358 | Reg loss: 0.030 | Tree loss: 5.358 | Accuracy: 0.056641 | 5.122 sec/iter
Epoch: 23 | Batch: 005 / 026 | Total loss: 5.359 | Reg loss: 0.030 | Tree loss: 5.359 | Accuracy: 0.054688 | 5.121 sec/iter
Epoch: 23 | Batch: 006 / 026 | Total loss: 5.381 | Reg loss: 0.030 | Tree loss: 5.381 | Accuracy: 0.056641 | 5.119 sec/iter
Epoch: 23 | Batch: 007 / 026 | Total loss: 5.349 | Reg loss: 0.030 | Tree loss: 5.349 | Accuracy: 0.060547 | 5.12 sec/iter
Epoch: 23 | Batch: 008 / 026 | Total loss: 5.358 | Reg loss: 0.030 | Tree loss: 5.358 | Accuracy: 0.044922 | 5.12 sec/iter
Epoch: 23 | Batch: 009 / 026 | Total loss: 5.294 | Reg loss: 0.030 | Tree loss: 5.294 | Accuracy: 0.058594 | 5.12 sec/iter
Epoch: 23 |

Epoch: 25 | Batch: 011 / 026 | Total loss: 5.033 | Reg loss: 0.030 | Tree loss: 5.033 | Accuracy: 0.050781 | 5.114 sec/iter
Epoch: 25 | Batch: 012 / 026 | Total loss: 4.968 | Reg loss: 0.030 | Tree loss: 4.968 | Accuracy: 0.060547 | 5.115 sec/iter
Epoch: 25 | Batch: 013 / 026 | Total loss: 4.936 | Reg loss: 0.030 | Tree loss: 4.936 | Accuracy: 0.068359 | 5.116 sec/iter
Epoch: 25 | Batch: 014 / 026 | Total loss: 4.951 | Reg loss: 0.030 | Tree loss: 4.951 | Accuracy: 0.052734 | 5.117 sec/iter
Epoch: 25 | Batch: 015 / 026 | Total loss: 4.927 | Reg loss: 0.030 | Tree loss: 4.927 | Accuracy: 0.062500 | 5.117 sec/iter
Epoch: 25 | Batch: 016 / 026 | Total loss: 4.918 | Reg loss: 0.030 | Tree loss: 4.918 | Accuracy: 0.068359 | 5.118 sec/iter
Epoch: 25 | Batch: 017 / 026 | Total loss: 4.909 | Reg loss: 0.030 | Tree loss: 4.909 | Accuracy: 0.066406 | 5.118 sec/iter
Epoch: 25 | Batch: 018 / 026 | Total loss: 4.905 | Reg loss: 0.030 | Tree loss: 4.905 | Accuracy: 0.058594 | 5.119 sec/iter
Epoch: 2

Epoch: 27 | Batch: 020 / 026 | Total loss: 4.583 | Reg loss: 0.031 | Tree loss: 4.583 | Accuracy: 0.052734 | 5.112 sec/iter
Epoch: 27 | Batch: 021 / 026 | Total loss: 4.608 | Reg loss: 0.031 | Tree loss: 4.608 | Accuracy: 0.078125 | 5.111 sec/iter
Epoch: 27 | Batch: 022 / 026 | Total loss: 4.588 | Reg loss: 0.031 | Tree loss: 4.588 | Accuracy: 0.066406 | 5.111 sec/iter
Epoch: 27 | Batch: 023 / 026 | Total loss: 4.578 | Reg loss: 0.031 | Tree loss: 4.578 | Accuracy: 0.078125 | 5.11 sec/iter
Epoch: 27 | Batch: 024 / 026 | Total loss: 4.588 | Reg loss: 0.031 | Tree loss: 4.588 | Accuracy: 0.068359 | 5.11 sec/iter
Epoch: 27 | Batch: 025 / 026 | Total loss: 4.584 | Reg loss: 0.031 | Tree loss: 4.584 | Accuracy: 0.070968 | 5.108 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428

Epoch: 30 | Batch: 000 / 026 | Total loss: 4.403 | Reg loss: 0.031 | Tree loss: 4.403 | Accuracy: 0.087891 | 5.113 sec/iter
Epoch: 30 | Batch: 001 / 026 | Total loss: 4.361 | Reg loss: 0.031 | Tree loss: 4.361 | Accuracy: 0.099609 | 5.113 sec/iter
Epoch: 30 | Batch: 002 / 026 | Total loss: 4.358 | Reg loss: 0.031 | Tree loss: 4.358 | Accuracy: 0.093750 | 5.114 sec/iter
Epoch: 30 | Batch: 003 / 026 | Total loss: 4.356 | Reg loss: 0.031 | Tree loss: 4.356 | Accuracy: 0.099609 | 5.113 sec/iter
Epoch: 30 | Batch: 004 / 026 | Total loss: 4.362 | Reg loss: 0.031 | Tree loss: 4.362 | Accuracy: 0.080078 | 5.113 sec/iter
Epoch: 30 | Batch: 005 / 026 | Total loss: 4.337 | Reg loss: 0.031 | Tree loss: 4.337 | Accuracy: 0.111328 | 5.114 sec/iter
Epoch: 30 | Batch: 006 / 026 | Total loss: 4.346 | Reg loss: 0.031 | Tree loss: 4.346 | Accuracy: 0.105469 | 5.114 sec/iter
Epoch: 30 | Batch: 007 / 026 | Total loss: 4.344 | Reg loss: 0.031 | Tree loss: 4.344 | Accuracy: 0.089844 | 5.115 sec/iter
Epoch: 3

Epoch: 32 | Batch: 009 / 026 | Total loss: 4.114 | Reg loss: 0.031 | Tree loss: 4.114 | Accuracy: 0.091797 | 5.12 sec/iter
Epoch: 32 | Batch: 010 / 026 | Total loss: 4.106 | Reg loss: 0.031 | Tree loss: 4.106 | Accuracy: 0.083984 | 5.12 sec/iter
Epoch: 32 | Batch: 011 / 026 | Total loss: 4.051 | Reg loss: 0.031 | Tree loss: 4.051 | Accuracy: 0.125000 | 5.121 sec/iter
Epoch: 32 | Batch: 012 / 026 | Total loss: 4.076 | Reg loss: 0.031 | Tree loss: 4.076 | Accuracy: 0.125000 | 5.121 sec/iter
Epoch: 32 | Batch: 013 / 026 | Total loss: 4.111 | Reg loss: 0.031 | Tree loss: 4.111 | Accuracy: 0.091797 | 5.121 sec/iter
Epoch: 32 | Batch: 014 / 026 | Total loss: 4.100 | Reg loss: 0.031 | Tree loss: 4.100 | Accuracy: 0.089844 | 5.122 sec/iter
Epoch: 32 | Batch: 015 / 026 | Total loss: 4.064 | Reg loss: 0.031 | Tree loss: 4.064 | Accuracy: 0.126953 | 5.122 sec/iter
Epoch: 32 | Batch: 016 / 026 | Total loss: 4.057 | Reg loss: 0.031 | Tree loss: 4.057 | Accuracy: 0.101562 | 5.122 sec/iter
Epoch: 32 

Epoch: 34 | Batch: 018 / 026 | Total loss: 3.894 | Reg loss: 0.031 | Tree loss: 3.894 | Accuracy: 0.119141 | 5.123 sec/iter
Epoch: 34 | Batch: 019 / 026 | Total loss: 3.869 | Reg loss: 0.031 | Tree loss: 3.869 | Accuracy: 0.125000 | 5.123 sec/iter
Epoch: 34 | Batch: 020 / 026 | Total loss: 3.872 | Reg loss: 0.032 | Tree loss: 3.872 | Accuracy: 0.091797 | 5.123 sec/iter
Epoch: 34 | Batch: 021 / 026 | Total loss: 3.866 | Reg loss: 0.032 | Tree loss: 3.866 | Accuracy: 0.123047 | 5.123 sec/iter
Epoch: 34 | Batch: 022 / 026 | Total loss: 3.865 | Reg loss: 0.032 | Tree loss: 3.865 | Accuracy: 0.097656 | 5.122 sec/iter
Epoch: 34 | Batch: 023 / 026 | Total loss: 3.854 | Reg loss: 0.032 | Tree loss: 3.854 | Accuracy: 0.105469 | 5.122 sec/iter
Epoch: 34 | Batch: 024 / 026 | Total loss: 3.888 | Reg loss: 0.032 | Tree loss: 3.888 | Accuracy: 0.121094 | 5.122 sec/iter
Epoch: 34 | Batch: 025 / 026 | Total loss: 3.804 | Reg loss: 0.032 | Tree loss: 3.804 | Accuracy: 0.122581 | 5.12 sec/iter
Average s

Epoch: 37 | Batch: 000 / 026 | Total loss: 3.758 | Reg loss: 0.031 | Tree loss: 3.758 | Accuracy: 0.103516 | 5.119 sec/iter
Epoch: 37 | Batch: 001 / 026 | Total loss: 3.710 | Reg loss: 0.031 | Tree loss: 3.710 | Accuracy: 0.121094 | 5.118 sec/iter
Epoch: 37 | Batch: 002 / 026 | Total loss: 3.759 | Reg loss: 0.031 | Tree loss: 3.759 | Accuracy: 0.103516 | 5.119 sec/iter
Epoch: 37 | Batch: 003 / 026 | Total loss: 3.755 | Reg loss: 0.031 | Tree loss: 3.755 | Accuracy: 0.093750 | 5.119 sec/iter
Epoch: 37 | Batch: 004 / 026 | Total loss: 3.731 | Reg loss: 0.031 | Tree loss: 3.731 | Accuracy: 0.105469 | 5.119 sec/iter
Epoch: 37 | Batch: 005 / 026 | Total loss: 3.775 | Reg loss: 0.031 | Tree loss: 3.775 | Accuracy: 0.072266 | 5.12 sec/iter
Epoch: 37 | Batch: 006 / 026 | Total loss: 3.734 | Reg loss: 0.031 | Tree loss: 3.734 | Accuracy: 0.101562 | 5.12 sec/iter
Epoch: 37 | Batch: 007 / 026 | Total loss: 3.711 | Reg loss: 0.031 | Tree loss: 3.711 | Accuracy: 0.105469 | 5.12 sec/iter
Epoch: 37 |

Epoch: 39 | Batch: 009 / 026 | Total loss: 3.559 | Reg loss: 0.032 | Tree loss: 3.559 | Accuracy: 0.123047 | 5.122 sec/iter
Epoch: 39 | Batch: 010 / 026 | Total loss: 3.584 | Reg loss: 0.032 | Tree loss: 3.584 | Accuracy: 0.126953 | 5.122 sec/iter
Epoch: 39 | Batch: 011 / 026 | Total loss: 3.592 | Reg loss: 0.032 | Tree loss: 3.592 | Accuracy: 0.119141 | 5.122 sec/iter
Epoch: 39 | Batch: 012 / 026 | Total loss: 3.588 | Reg loss: 0.032 | Tree loss: 3.588 | Accuracy: 0.105469 | 5.123 sec/iter
Epoch: 39 | Batch: 013 / 026 | Total loss: 3.590 | Reg loss: 0.032 | Tree loss: 3.590 | Accuracy: 0.111328 | 5.123 sec/iter
Epoch: 39 | Batch: 014 / 026 | Total loss: 3.582 | Reg loss: 0.032 | Tree loss: 3.582 | Accuracy: 0.130859 | 5.123 sec/iter
Epoch: 39 | Batch: 015 / 026 | Total loss: 3.583 | Reg loss: 0.032 | Tree loss: 3.583 | Accuracy: 0.138672 | 5.123 sec/iter
Epoch: 39 | Batch: 016 / 026 | Total loss: 3.579 | Reg loss: 0.032 | Tree loss: 3.579 | Accuracy: 0.132812 | 5.123 sec/iter
Epoch: 3

Epoch: 41 | Batch: 018 / 026 | Total loss: 3.534 | Reg loss: 0.032 | Tree loss: 3.534 | Accuracy: 0.097656 | 5.125 sec/iter
Epoch: 41 | Batch: 019 / 026 | Total loss: 3.462 | Reg loss: 0.032 | Tree loss: 3.462 | Accuracy: 0.111328 | 5.125 sec/iter
Epoch: 41 | Batch: 020 / 026 | Total loss: 3.422 | Reg loss: 0.032 | Tree loss: 3.422 | Accuracy: 0.158203 | 5.125 sec/iter
Epoch: 41 | Batch: 021 / 026 | Total loss: 3.474 | Reg loss: 0.032 | Tree loss: 3.474 | Accuracy: 0.132812 | 5.125 sec/iter
Epoch: 41 | Batch: 022 / 026 | Total loss: 3.457 | Reg loss: 0.032 | Tree loss: 3.457 | Accuracy: 0.105469 | 5.125 sec/iter
Epoch: 41 | Batch: 023 / 026 | Total loss: 3.463 | Reg loss: 0.032 | Tree loss: 3.463 | Accuracy: 0.105469 | 5.125 sec/iter
Epoch: 41 | Batch: 024 / 026 | Total loss: 3.407 | Reg loss: 0.032 | Tree loss: 3.407 | Accuracy: 0.144531 | 5.124 sec/iter
Epoch: 41 | Batch: 025 / 026 | Total loss: 3.390 | Reg loss: 0.032 | Tree loss: 3.390 | Accuracy: 0.148387 | 5.122 sec/iter
Average 

Epoch: 44 | Batch: 000 / 026 | Total loss: 3.422 | Reg loss: 0.032 | Tree loss: 3.422 | Accuracy: 0.130859 | 5.121 sec/iter
Epoch: 44 | Batch: 001 / 026 | Total loss: 3.407 | Reg loss: 0.032 | Tree loss: 3.407 | Accuracy: 0.126953 | 5.122 sec/iter
Epoch: 44 | Batch: 002 / 026 | Total loss: 3.374 | Reg loss: 0.032 | Tree loss: 3.374 | Accuracy: 0.125000 | 5.122 sec/iter
Epoch: 44 | Batch: 003 / 026 | Total loss: 3.388 | Reg loss: 0.032 | Tree loss: 3.388 | Accuracy: 0.125000 | 5.123 sec/iter
Epoch: 44 | Batch: 004 / 026 | Total loss: 3.367 | Reg loss: 0.032 | Tree loss: 3.367 | Accuracy: 0.128906 | 5.123 sec/iter
Epoch: 44 | Batch: 005 / 026 | Total loss: 3.337 | Reg loss: 0.032 | Tree loss: 3.337 | Accuracy: 0.144531 | 5.124 sec/iter
Epoch: 44 | Batch: 006 / 026 | Total loss: 3.350 | Reg loss: 0.032 | Tree loss: 3.350 | Accuracy: 0.138672 | 5.124 sec/iter
Epoch: 44 | Batch: 007 / 026 | Total loss: 3.438 | Reg loss: 0.032 | Tree loss: 3.438 | Accuracy: 0.107422 | 5.124 sec/iter
Epoch: 4

Epoch: 46 | Batch: 009 / 026 | Total loss: 3.300 | Reg loss: 0.032 | Tree loss: 3.300 | Accuracy: 0.121094 | 5.124 sec/iter
Epoch: 46 | Batch: 010 / 026 | Total loss: 3.282 | Reg loss: 0.032 | Tree loss: 3.282 | Accuracy: 0.152344 | 5.124 sec/iter
Epoch: 46 | Batch: 011 / 026 | Total loss: 3.276 | Reg loss: 0.032 | Tree loss: 3.276 | Accuracy: 0.144531 | 5.124 sec/iter
Epoch: 46 | Batch: 012 / 026 | Total loss: 3.338 | Reg loss: 0.032 | Tree loss: 3.338 | Accuracy: 0.111328 | 5.124 sec/iter
Epoch: 46 | Batch: 013 / 026 | Total loss: 3.289 | Reg loss: 0.032 | Tree loss: 3.289 | Accuracy: 0.101562 | 5.124 sec/iter
Epoch: 46 | Batch: 014 / 026 | Total loss: 3.320 | Reg loss: 0.032 | Tree loss: 3.320 | Accuracy: 0.121094 | 5.123 sec/iter
Epoch: 46 | Batch: 015 / 026 | Total loss: 3.285 | Reg loss: 0.032 | Tree loss: 3.285 | Accuracy: 0.121094 | 5.123 sec/iter
Epoch: 46 | Batch: 016 / 026 | Total loss: 3.265 | Reg loss: 0.032 | Tree loss: 3.265 | Accuracy: 0.152344 | 5.123 sec/iter
Epoch: 4

Epoch: 48 | Batch: 018 / 026 | Total loss: 3.251 | Reg loss: 0.032 | Tree loss: 3.251 | Accuracy: 0.105469 | 5.124 sec/iter
Epoch: 48 | Batch: 019 / 026 | Total loss: 3.198 | Reg loss: 0.032 | Tree loss: 3.198 | Accuracy: 0.125000 | 5.124 sec/iter
Epoch: 48 | Batch: 020 / 026 | Total loss: 3.241 | Reg loss: 0.032 | Tree loss: 3.241 | Accuracy: 0.123047 | 5.123 sec/iter
Epoch: 48 | Batch: 021 / 026 | Total loss: 3.192 | Reg loss: 0.032 | Tree loss: 3.192 | Accuracy: 0.126953 | 5.123 sec/iter
Epoch: 48 | Batch: 022 / 026 | Total loss: 3.200 | Reg loss: 0.032 | Tree loss: 3.200 | Accuracy: 0.132812 | 5.122 sec/iter
Epoch: 48 | Batch: 023 / 026 | Total loss: 3.188 | Reg loss: 0.032 | Tree loss: 3.188 | Accuracy: 0.144531 | 5.123 sec/iter
Epoch: 48 | Batch: 024 / 026 | Total loss: 3.269 | Reg loss: 0.032 | Tree loss: 3.269 | Accuracy: 0.093750 | 5.123 sec/iter
Epoch: 48 | Batch: 025 / 026 | Total loss: 3.188 | Reg loss: 0.032 | Tree loss: 3.188 | Accuracy: 0.103226 | 5.122 sec/iter
Average 

Epoch: 51 | Batch: 000 / 026 | Total loss: 3.169 | Reg loss: 0.032 | Tree loss: 3.169 | Accuracy: 0.148438 | 5.124 sec/iter
Epoch: 51 | Batch: 001 / 026 | Total loss: 3.223 | Reg loss: 0.032 | Tree loss: 3.223 | Accuracy: 0.130859 | 5.124 sec/iter
Epoch: 51 | Batch: 002 / 026 | Total loss: 3.169 | Reg loss: 0.032 | Tree loss: 3.169 | Accuracy: 0.140625 | 5.124 sec/iter
Epoch: 51 | Batch: 003 / 026 | Total loss: 3.226 | Reg loss: 0.032 | Tree loss: 3.226 | Accuracy: 0.111328 | 5.124 sec/iter
Epoch: 51 | Batch: 004 / 026 | Total loss: 3.164 | Reg loss: 0.032 | Tree loss: 3.164 | Accuracy: 0.128906 | 5.124 sec/iter
Epoch: 51 | Batch: 005 / 026 | Total loss: 3.208 | Reg loss: 0.032 | Tree loss: 3.208 | Accuracy: 0.107422 | 5.124 sec/iter
Epoch: 51 | Batch: 006 / 026 | Total loss: 3.187 | Reg loss: 0.032 | Tree loss: 3.187 | Accuracy: 0.115234 | 5.124 sec/iter
Epoch: 51 | Batch: 007 / 026 | Total loss: 3.171 | Reg loss: 0.032 | Tree loss: 3.171 | Accuracy: 0.132812 | 5.124 sec/iter
Epoch: 5

Epoch: 53 | Batch: 009 / 026 | Total loss: 3.204 | Reg loss: 0.032 | Tree loss: 3.204 | Accuracy: 0.125000 | 5.127 sec/iter
Epoch: 53 | Batch: 010 / 026 | Total loss: 3.167 | Reg loss: 0.032 | Tree loss: 3.167 | Accuracy: 0.130859 | 5.127 sec/iter
Epoch: 53 | Batch: 011 / 026 | Total loss: 3.133 | Reg loss: 0.032 | Tree loss: 3.133 | Accuracy: 0.095703 | 5.127 sec/iter
Epoch: 53 | Batch: 012 / 026 | Total loss: 3.142 | Reg loss: 0.032 | Tree loss: 3.142 | Accuracy: 0.138672 | 5.126 sec/iter
Epoch: 53 | Batch: 013 / 026 | Total loss: 3.088 | Reg loss: 0.032 | Tree loss: 3.088 | Accuracy: 0.134766 | 5.126 sec/iter
Epoch: 53 | Batch: 014 / 026 | Total loss: 3.119 | Reg loss: 0.032 | Tree loss: 3.119 | Accuracy: 0.128906 | 5.126 sec/iter
Epoch: 53 | Batch: 015 / 026 | Total loss: 3.150 | Reg loss: 0.032 | Tree loss: 3.150 | Accuracy: 0.128906 | 5.125 sec/iter
Epoch: 53 | Batch: 016 / 026 | Total loss: 3.135 | Reg loss: 0.032 | Tree loss: 3.135 | Accuracy: 0.107422 | 5.125 sec/iter
Epoch: 5

Epoch: 55 | Batch: 018 / 026 | Total loss: 3.114 | Reg loss: 0.032 | Tree loss: 3.114 | Accuracy: 0.128906 | 5.126 sec/iter
Epoch: 55 | Batch: 019 / 026 | Total loss: 3.123 | Reg loss: 0.032 | Tree loss: 3.123 | Accuracy: 0.109375 | 5.125 sec/iter
Epoch: 55 | Batch: 020 / 026 | Total loss: 3.090 | Reg loss: 0.032 | Tree loss: 3.090 | Accuracy: 0.113281 | 5.124 sec/iter
Epoch: 55 | Batch: 021 / 026 | Total loss: 3.089 | Reg loss: 0.032 | Tree loss: 3.089 | Accuracy: 0.115234 | 5.125 sec/iter
Epoch: 55 | Batch: 022 / 026 | Total loss: 3.108 | Reg loss: 0.032 | Tree loss: 3.108 | Accuracy: 0.128906 | 5.125 sec/iter
Epoch: 55 | Batch: 023 / 026 | Total loss: 3.064 | Reg loss: 0.032 | Tree loss: 3.064 | Accuracy: 0.132812 | 5.125 sec/iter
Epoch: 55 | Batch: 024 / 026 | Total loss: 3.078 | Reg loss: 0.032 | Tree loss: 3.078 | Accuracy: 0.160156 | 5.125 sec/iter
Epoch: 55 | Batch: 025 / 026 | Total loss: 3.075 | Reg loss: 0.032 | Tree loss: 3.075 | Accuracy: 0.135484 | 5.124 sec/iter
Average 

Epoch: 58 | Batch: 000 / 026 | Total loss: 3.092 | Reg loss: 0.032 | Tree loss: 3.092 | Accuracy: 0.115234 | 5.132 sec/iter
Epoch: 58 | Batch: 001 / 026 | Total loss: 3.098 | Reg loss: 0.032 | Tree loss: 3.098 | Accuracy: 0.140625 | 5.132 sec/iter
Epoch: 58 | Batch: 002 / 026 | Total loss: 3.083 | Reg loss: 0.032 | Tree loss: 3.083 | Accuracy: 0.126953 | 5.132 sec/iter
Epoch: 58 | Batch: 003 / 026 | Total loss: 3.114 | Reg loss: 0.032 | Tree loss: 3.114 | Accuracy: 0.136719 | 5.132 sec/iter
Epoch: 58 | Batch: 004 / 026 | Total loss: 3.138 | Reg loss: 0.032 | Tree loss: 3.138 | Accuracy: 0.132812 | 5.132 sec/iter
Epoch: 58 | Batch: 005 / 026 | Total loss: 3.103 | Reg loss: 0.032 | Tree loss: 3.103 | Accuracy: 0.142578 | 5.132 sec/iter
Epoch: 58 | Batch: 006 / 026 | Total loss: 3.098 | Reg loss: 0.032 | Tree loss: 3.098 | Accuracy: 0.123047 | 5.131 sec/iter
Epoch: 58 | Batch: 007 / 026 | Total loss: 3.109 | Reg loss: 0.032 | Tree loss: 3.109 | Accuracy: 0.103516 | 5.131 sec/iter
Epoch: 5

Epoch: 60 | Batch: 009 / 026 | Total loss: 3.053 | Reg loss: 0.032 | Tree loss: 3.053 | Accuracy: 0.126953 | 5.132 sec/iter
Epoch: 60 | Batch: 010 / 026 | Total loss: 3.052 | Reg loss: 0.032 | Tree loss: 3.052 | Accuracy: 0.132812 | 5.132 sec/iter
Epoch: 60 | Batch: 011 / 026 | Total loss: 3.075 | Reg loss: 0.032 | Tree loss: 3.075 | Accuracy: 0.109375 | 5.132 sec/iter
Epoch: 60 | Batch: 012 / 026 | Total loss: 3.060 | Reg loss: 0.032 | Tree loss: 3.060 | Accuracy: 0.113281 | 5.132 sec/iter
Epoch: 60 | Batch: 013 / 026 | Total loss: 3.035 | Reg loss: 0.032 | Tree loss: 3.035 | Accuracy: 0.136719 | 5.131 sec/iter
Epoch: 60 | Batch: 014 / 026 | Total loss: 3.038 | Reg loss: 0.032 | Tree loss: 3.038 | Accuracy: 0.111328 | 5.131 sec/iter
Epoch: 60 | Batch: 015 / 026 | Total loss: 3.067 | Reg loss: 0.032 | Tree loss: 3.067 | Accuracy: 0.138672 | 5.131 sec/iter
Epoch: 60 | Batch: 016 / 026 | Total loss: 3.030 | Reg loss: 0.032 | Tree loss: 3.030 | Accuracy: 0.132812 | 5.13 sec/iter
Epoch: 60

Epoch: 62 | Batch: 018 / 026 | Total loss: 3.010 | Reg loss: 0.032 | Tree loss: 3.010 | Accuracy: 0.146484 | 5.129 sec/iter
Epoch: 62 | Batch: 019 / 026 | Total loss: 2.966 | Reg loss: 0.032 | Tree loss: 2.966 | Accuracy: 0.132812 | 5.129 sec/iter
Epoch: 62 | Batch: 020 / 026 | Total loss: 3.007 | Reg loss: 0.032 | Tree loss: 3.007 | Accuracy: 0.119141 | 5.129 sec/iter
Epoch: 62 | Batch: 021 / 026 | Total loss: 3.046 | Reg loss: 0.032 | Tree loss: 3.046 | Accuracy: 0.121094 | 5.13 sec/iter
Epoch: 62 | Batch: 022 / 026 | Total loss: 3.045 | Reg loss: 0.032 | Tree loss: 3.045 | Accuracy: 0.121094 | 5.13 sec/iter
Epoch: 62 | Batch: 023 / 026 | Total loss: 3.039 | Reg loss: 0.032 | Tree loss: 3.039 | Accuracy: 0.105469 | 5.13 sec/iter
Epoch: 62 | Batch: 024 / 026 | Total loss: 3.000 | Reg loss: 0.032 | Tree loss: 3.000 | Accuracy: 0.119141 | 5.13 sec/iter
Epoch: 62 | Batch: 025 / 026 | Total loss: 3.027 | Reg loss: 0.032 | Tree loss: 3.027 | Accuracy: 0.083871 | 5.129 sec/iter
Average spar

Epoch: 65 | Batch: 000 / 026 | Total loss: 3.023 | Reg loss: 0.032 | Tree loss: 3.023 | Accuracy: 0.099609 | 5.133 sec/iter
Epoch: 65 | Batch: 001 / 026 | Total loss: 3.057 | Reg loss: 0.032 | Tree loss: 3.057 | Accuracy: 0.117188 | 5.133 sec/iter
Epoch: 65 | Batch: 002 / 026 | Total loss: 3.066 | Reg loss: 0.032 | Tree loss: 3.066 | Accuracy: 0.113281 | 5.133 sec/iter
Epoch: 65 | Batch: 003 / 026 | Total loss: 3.046 | Reg loss: 0.032 | Tree loss: 3.046 | Accuracy: 0.128906 | 5.133 sec/iter
Epoch: 65 | Batch: 004 / 026 | Total loss: 3.015 | Reg loss: 0.032 | Tree loss: 3.015 | Accuracy: 0.138672 | 5.133 sec/iter
Epoch: 65 | Batch: 005 / 026 | Total loss: 3.049 | Reg loss: 0.032 | Tree loss: 3.049 | Accuracy: 0.101562 | 5.133 sec/iter
Epoch: 65 | Batch: 006 / 026 | Total loss: 3.050 | Reg loss: 0.032 | Tree loss: 3.050 | Accuracy: 0.113281 | 5.133 sec/iter
Epoch: 65 | Batch: 007 / 026 | Total loss: 3.064 | Reg loss: 0.032 | Tree loss: 3.064 | Accuracy: 0.152344 | 5.133 sec/iter
Epoch: 6

Epoch: 67 | Batch: 009 / 026 | Total loss: 3.054 | Reg loss: 0.032 | Tree loss: 3.054 | Accuracy: 0.105469 | 5.138 sec/iter
Epoch: 67 | Batch: 010 / 026 | Total loss: 3.002 | Reg loss: 0.032 | Tree loss: 3.002 | Accuracy: 0.119141 | 5.137 sec/iter
Epoch: 67 | Batch: 011 / 026 | Total loss: 2.998 | Reg loss: 0.032 | Tree loss: 2.998 | Accuracy: 0.105469 | 5.137 sec/iter
Epoch: 67 | Batch: 012 / 026 | Total loss: 3.025 | Reg loss: 0.032 | Tree loss: 3.025 | Accuracy: 0.119141 | 5.137 sec/iter
Epoch: 67 | Batch: 013 / 026 | Total loss: 3.019 | Reg loss: 0.032 | Tree loss: 3.019 | Accuracy: 0.109375 | 5.137 sec/iter
Epoch: 67 | Batch: 014 / 026 | Total loss: 2.994 | Reg loss: 0.032 | Tree loss: 2.994 | Accuracy: 0.109375 | 5.137 sec/iter
Epoch: 67 | Batch: 015 / 026 | Total loss: 2.970 | Reg loss: 0.032 | Tree loss: 2.970 | Accuracy: 0.125000 | 5.136 sec/iter
Epoch: 67 | Batch: 016 / 026 | Total loss: 2.983 | Reg loss: 0.032 | Tree loss: 2.983 | Accuracy: 0.123047 | 5.136 sec/iter
Epoch: 6

Epoch: 69 | Batch: 018 / 026 | Total loss: 3.008 | Reg loss: 0.032 | Tree loss: 3.008 | Accuracy: 0.107422 | 5.127 sec/iter
Epoch: 69 | Batch: 019 / 026 | Total loss: 2.956 | Reg loss: 0.032 | Tree loss: 2.956 | Accuracy: 0.144531 | 5.126 sec/iter
Epoch: 69 | Batch: 020 / 026 | Total loss: 2.979 | Reg loss: 0.032 | Tree loss: 2.979 | Accuracy: 0.107422 | 5.126 sec/iter
Epoch: 69 | Batch: 021 / 026 | Total loss: 2.976 | Reg loss: 0.032 | Tree loss: 2.976 | Accuracy: 0.142578 | 5.125 sec/iter
Epoch: 69 | Batch: 022 / 026 | Total loss: 2.919 | Reg loss: 0.032 | Tree loss: 2.919 | Accuracy: 0.138672 | 5.124 sec/iter
Epoch: 69 | Batch: 023 / 026 | Total loss: 2.987 | Reg loss: 0.032 | Tree loss: 2.987 | Accuracy: 0.136719 | 5.123 sec/iter
Epoch: 69 | Batch: 024 / 026 | Total loss: 2.948 | Reg loss: 0.032 | Tree loss: 2.948 | Accuracy: 0.125000 | 5.123 sec/iter
Epoch: 69 | Batch: 025 / 026 | Total loss: 2.895 | Reg loss: 0.032 | Tree loss: 2.895 | Accuracy: 0.161290 | 5.122 sec/iter
Average 

Epoch: 72 | Batch: 000 / 026 | Total loss: 3.059 | Reg loss: 0.032 | Tree loss: 3.059 | Accuracy: 0.109375 | 5.087 sec/iter
Epoch: 72 | Batch: 001 / 026 | Total loss: 3.001 | Reg loss: 0.032 | Tree loss: 3.001 | Accuracy: 0.123047 | 5.086 sec/iter
Epoch: 72 | Batch: 002 / 026 | Total loss: 3.037 | Reg loss: 0.032 | Tree loss: 3.037 | Accuracy: 0.111328 | 5.085 sec/iter
Epoch: 72 | Batch: 003 / 026 | Total loss: 2.979 | Reg loss: 0.032 | Tree loss: 2.979 | Accuracy: 0.123047 | 5.084 sec/iter
Epoch: 72 | Batch: 004 / 026 | Total loss: 3.001 | Reg loss: 0.032 | Tree loss: 3.001 | Accuracy: 0.083984 | 5.084 sec/iter
Epoch: 72 | Batch: 005 / 026 | Total loss: 3.013 | Reg loss: 0.032 | Tree loss: 3.013 | Accuracy: 0.132812 | 5.083 sec/iter
Epoch: 72 | Batch: 006 / 026 | Total loss: 3.054 | Reg loss: 0.032 | Tree loss: 3.054 | Accuracy: 0.109375 | 5.082 sec/iter
Epoch: 72 | Batch: 007 / 026 | Total loss: 2.994 | Reg loss: 0.032 | Tree loss: 2.994 | Accuracy: 0.132812 | 5.082 sec/iter
Epoch: 7

Epoch: 74 | Batch: 009 / 026 | Total loss: 3.013 | Reg loss: 0.032 | Tree loss: 3.013 | Accuracy: 0.132812 | 5.048 sec/iter
Epoch: 74 | Batch: 010 / 026 | Total loss: 2.980 | Reg loss: 0.032 | Tree loss: 2.980 | Accuracy: 0.152344 | 5.047 sec/iter
Epoch: 74 | Batch: 011 / 026 | Total loss: 2.970 | Reg loss: 0.032 | Tree loss: 2.970 | Accuracy: 0.101562 | 5.047 sec/iter
Epoch: 74 | Batch: 012 / 026 | Total loss: 2.918 | Reg loss: 0.032 | Tree loss: 2.918 | Accuracy: 0.132812 | 5.046 sec/iter
Epoch: 74 | Batch: 013 / 026 | Total loss: 2.997 | Reg loss: 0.032 | Tree loss: 2.997 | Accuracy: 0.123047 | 5.046 sec/iter
Epoch: 74 | Batch: 014 / 026 | Total loss: 2.945 | Reg loss: 0.032 | Tree loss: 2.945 | Accuracy: 0.121094 | 5.045 sec/iter
Epoch: 74 | Batch: 015 / 026 | Total loss: 2.985 | Reg loss: 0.032 | Tree loss: 2.985 | Accuracy: 0.123047 | 5.044 sec/iter
Epoch: 74 | Batch: 016 / 026 | Total loss: 2.949 | Reg loss: 0.032 | Tree loss: 2.949 | Accuracy: 0.136719 | 5.044 sec/iter
Epoch: 7

Epoch: 76 | Batch: 018 / 026 | Total loss: 2.910 | Reg loss: 0.032 | Tree loss: 2.910 | Accuracy: 0.144531 | 5.012 sec/iter
Epoch: 76 | Batch: 019 / 026 | Total loss: 2.941 | Reg loss: 0.032 | Tree loss: 2.941 | Accuracy: 0.148438 | 5.012 sec/iter
Epoch: 76 | Batch: 020 / 026 | Total loss: 2.958 | Reg loss: 0.032 | Tree loss: 2.958 | Accuracy: 0.126953 | 5.011 sec/iter
Epoch: 76 | Batch: 021 / 026 | Total loss: 2.952 | Reg loss: 0.032 | Tree loss: 2.952 | Accuracy: 0.115234 | 5.011 sec/iter
Epoch: 76 | Batch: 022 / 026 | Total loss: 2.920 | Reg loss: 0.032 | Tree loss: 2.920 | Accuracy: 0.103516 | 5.01 sec/iter
Epoch: 76 | Batch: 023 / 026 | Total loss: 2.943 | Reg loss: 0.032 | Tree loss: 2.943 | Accuracy: 0.105469 | 5.01 sec/iter
Epoch: 76 | Batch: 024 / 026 | Total loss: 2.936 | Reg loss: 0.032 | Tree loss: 2.936 | Accuracy: 0.130859 | 5.009 sec/iter
Epoch: 76 | Batch: 025 / 026 | Total loss: 2.905 | Reg loss: 0.032 | Tree loss: 2.905 | Accuracy: 0.122581 | 5.008 sec/iter
Average sp

Epoch: 79 | Batch: 000 / 026 | Total loss: 2.988 | Reg loss: 0.032 | Tree loss: 2.988 | Accuracy: 0.115234 | 4.98 sec/iter
Epoch: 79 | Batch: 001 / 026 | Total loss: 2.993 | Reg loss: 0.032 | Tree loss: 2.993 | Accuracy: 0.126953 | 4.979 sec/iter
Epoch: 79 | Batch: 002 / 026 | Total loss: 3.021 | Reg loss: 0.032 | Tree loss: 3.021 | Accuracy: 0.105469 | 4.978 sec/iter
Epoch: 79 | Batch: 003 / 026 | Total loss: 3.059 | Reg loss: 0.032 | Tree loss: 3.059 | Accuracy: 0.111328 | 4.978 sec/iter
Epoch: 79 | Batch: 004 / 026 | Total loss: 2.954 | Reg loss: 0.032 | Tree loss: 2.954 | Accuracy: 0.117188 | 4.977 sec/iter
Epoch: 79 | Batch: 005 / 026 | Total loss: 3.002 | Reg loss: 0.032 | Tree loss: 3.002 | Accuracy: 0.117188 | 4.977 sec/iter
Epoch: 79 | Batch: 006 / 026 | Total loss: 3.007 | Reg loss: 0.032 | Tree loss: 3.007 | Accuracy: 0.107422 | 4.976 sec/iter
Epoch: 79 | Batch: 007 / 026 | Total loss: 2.986 | Reg loss: 0.032 | Tree loss: 2.986 | Accuracy: 0.117188 | 4.976 sec/iter
Epoch: 79

Epoch: 81 | Batch: 009 / 026 | Total loss: 2.938 | Reg loss: 0.032 | Tree loss: 2.938 | Accuracy: 0.140625 | 4.948 sec/iter
Epoch: 81 | Batch: 010 / 026 | Total loss: 2.925 | Reg loss: 0.032 | Tree loss: 2.925 | Accuracy: 0.144531 | 4.948 sec/iter
Epoch: 81 | Batch: 011 / 026 | Total loss: 2.947 | Reg loss: 0.032 | Tree loss: 2.947 | Accuracy: 0.121094 | 4.947 sec/iter
Epoch: 81 | Batch: 012 / 026 | Total loss: 2.973 | Reg loss: 0.032 | Tree loss: 2.973 | Accuracy: 0.115234 | 4.947 sec/iter
Epoch: 81 | Batch: 013 / 026 | Total loss: 2.962 | Reg loss: 0.032 | Tree loss: 2.962 | Accuracy: 0.119141 | 4.946 sec/iter
Epoch: 81 | Batch: 014 / 026 | Total loss: 2.974 | Reg loss: 0.032 | Tree loss: 2.974 | Accuracy: 0.138672 | 4.946 sec/iter
Epoch: 81 | Batch: 015 / 026 | Total loss: 2.945 | Reg loss: 0.032 | Tree loss: 2.945 | Accuracy: 0.093750 | 4.945 sec/iter
Epoch: 81 | Batch: 016 / 026 | Total loss: 2.927 | Reg loss: 0.032 | Tree loss: 2.927 | Accuracy: 0.142578 | 4.945 sec/iter
Epoch: 8

Epoch: 83 | Batch: 018 / 026 | Total loss: 2.939 | Reg loss: 0.032 | Tree loss: 2.939 | Accuracy: 0.111328 | 4.919 sec/iter
Epoch: 83 | Batch: 019 / 026 | Total loss: 2.909 | Reg loss: 0.032 | Tree loss: 2.909 | Accuracy: 0.140625 | 4.919 sec/iter
Epoch: 83 | Batch: 020 / 026 | Total loss: 2.951 | Reg loss: 0.032 | Tree loss: 2.951 | Accuracy: 0.119141 | 4.918 sec/iter
Epoch: 83 | Batch: 021 / 026 | Total loss: 2.946 | Reg loss: 0.032 | Tree loss: 2.946 | Accuracy: 0.123047 | 4.918 sec/iter
Epoch: 83 | Batch: 022 / 026 | Total loss: 2.863 | Reg loss: 0.032 | Tree loss: 2.863 | Accuracy: 0.126953 | 4.917 sec/iter
Epoch: 83 | Batch: 023 / 026 | Total loss: 2.936 | Reg loss: 0.032 | Tree loss: 2.936 | Accuracy: 0.107422 | 4.917 sec/iter
Epoch: 83 | Batch: 024 / 026 | Total loss: 2.867 | Reg loss: 0.032 | Tree loss: 2.867 | Accuracy: 0.126953 | 4.916 sec/iter
Epoch: 83 | Batch: 025 / 026 | Total loss: 2.892 | Reg loss: 0.032 | Tree loss: 2.892 | Accuracy: 0.077419 | 4.916 sec/iter
Average 

Epoch: 86 | Batch: 000 / 026 | Total loss: 3.032 | Reg loss: 0.032 | Tree loss: 3.032 | Accuracy: 0.119141 | 4.892 sec/iter
Epoch: 86 | Batch: 001 / 026 | Total loss: 2.982 | Reg loss: 0.032 | Tree loss: 2.982 | Accuracy: 0.128906 | 4.891 sec/iter
Epoch: 86 | Batch: 002 / 026 | Total loss: 3.000 | Reg loss: 0.032 | Tree loss: 3.000 | Accuracy: 0.117188 | 4.891 sec/iter
Epoch: 86 | Batch: 003 / 026 | Total loss: 2.972 | Reg loss: 0.032 | Tree loss: 2.972 | Accuracy: 0.130859 | 4.89 sec/iter
Epoch: 86 | Batch: 004 / 026 | Total loss: 2.984 | Reg loss: 0.032 | Tree loss: 2.984 | Accuracy: 0.101562 | 4.89 sec/iter
Epoch: 86 | Batch: 005 / 026 | Total loss: 2.967 | Reg loss: 0.032 | Tree loss: 2.967 | Accuracy: 0.130859 | 4.889 sec/iter
Epoch: 86 | Batch: 006 / 026 | Total loss: 2.949 | Reg loss: 0.032 | Tree loss: 2.949 | Accuracy: 0.125000 | 4.889 sec/iter
Epoch: 86 | Batch: 007 / 026 | Total loss: 2.969 | Reg loss: 0.032 | Tree loss: 2.969 | Accuracy: 0.126953 | 4.888 sec/iter
Epoch: 86 

Epoch: 88 | Batch: 009 / 026 | Total loss: 2.935 | Reg loss: 0.032 | Tree loss: 2.935 | Accuracy: 0.113281 | 4.865 sec/iter
Epoch: 88 | Batch: 010 / 026 | Total loss: 2.959 | Reg loss: 0.032 | Tree loss: 2.959 | Accuracy: 0.105469 | 4.865 sec/iter
Epoch: 88 | Batch: 011 / 026 | Total loss: 2.993 | Reg loss: 0.032 | Tree loss: 2.993 | Accuracy: 0.105469 | 4.864 sec/iter
Epoch: 88 | Batch: 012 / 026 | Total loss: 2.934 | Reg loss: 0.032 | Tree loss: 2.934 | Accuracy: 0.119141 | 4.864 sec/iter
Epoch: 88 | Batch: 013 / 026 | Total loss: 2.936 | Reg loss: 0.032 | Tree loss: 2.936 | Accuracy: 0.123047 | 4.863 sec/iter
Epoch: 88 | Batch: 014 / 026 | Total loss: 2.921 | Reg loss: 0.032 | Tree loss: 2.921 | Accuracy: 0.103516 | 4.863 sec/iter
Epoch: 88 | Batch: 015 / 026 | Total loss: 2.917 | Reg loss: 0.032 | Tree loss: 2.917 | Accuracy: 0.134766 | 4.862 sec/iter
Epoch: 88 | Batch: 016 / 026 | Total loss: 2.951 | Reg loss: 0.032 | Tree loss: 2.951 | Accuracy: 0.101562 | 4.862 sec/iter
Epoch: 8

Epoch: 90 | Batch: 018 / 026 | Total loss: 2.915 | Reg loss: 0.032 | Tree loss: 2.915 | Accuracy: 0.148438 | 4.84 sec/iter
Epoch: 90 | Batch: 019 / 026 | Total loss: 2.913 | Reg loss: 0.032 | Tree loss: 2.913 | Accuracy: 0.111328 | 4.84 sec/iter
Epoch: 90 | Batch: 020 / 026 | Total loss: 2.905 | Reg loss: 0.032 | Tree loss: 2.905 | Accuracy: 0.121094 | 4.839 sec/iter
Epoch: 90 | Batch: 021 / 026 | Total loss: 2.882 | Reg loss: 0.032 | Tree loss: 2.882 | Accuracy: 0.144531 | 4.839 sec/iter
Epoch: 90 | Batch: 022 / 026 | Total loss: 2.862 | Reg loss: 0.032 | Tree loss: 2.862 | Accuracy: 0.152344 | 4.838 sec/iter
Epoch: 90 | Batch: 023 / 026 | Total loss: 2.926 | Reg loss: 0.032 | Tree loss: 2.926 | Accuracy: 0.111328 | 4.838 sec/iter
Epoch: 90 | Batch: 024 / 026 | Total loss: 2.883 | Reg loss: 0.032 | Tree loss: 2.883 | Accuracy: 0.123047 | 4.838 sec/iter
Epoch: 90 | Batch: 025 / 026 | Total loss: 2.868 | Reg loss: 0.032 | Tree loss: 2.868 | Accuracy: 0.154839 | 4.837 sec/iter
Average sp

Epoch: 93 | Batch: 000 / 026 | Total loss: 3.043 | Reg loss: 0.031 | Tree loss: 3.043 | Accuracy: 0.115234 | 4.816 sec/iter
Epoch: 93 | Batch: 001 / 026 | Total loss: 3.011 | Reg loss: 0.031 | Tree loss: 3.011 | Accuracy: 0.101562 | 4.816 sec/iter
Epoch: 93 | Batch: 002 / 026 | Total loss: 2.975 | Reg loss: 0.031 | Tree loss: 2.975 | Accuracy: 0.111328 | 4.816 sec/iter
Epoch: 93 | Batch: 003 / 026 | Total loss: 3.018 | Reg loss: 0.031 | Tree loss: 3.018 | Accuracy: 0.123047 | 4.815 sec/iter
Epoch: 93 | Batch: 004 / 026 | Total loss: 2.980 | Reg loss: 0.031 | Tree loss: 2.980 | Accuracy: 0.140625 | 4.815 sec/iter
Epoch: 93 | Batch: 005 / 026 | Total loss: 2.980 | Reg loss: 0.031 | Tree loss: 2.980 | Accuracy: 0.117188 | 4.814 sec/iter
Epoch: 93 | Batch: 006 / 026 | Total loss: 2.947 | Reg loss: 0.031 | Tree loss: 2.947 | Accuracy: 0.126953 | 4.814 sec/iter
Epoch: 93 | Batch: 007 / 026 | Total loss: 2.935 | Reg loss: 0.031 | Tree loss: 2.935 | Accuracy: 0.128906 | 4.814 sec/iter
Epoch: 9

Epoch: 95 | Batch: 009 / 026 | Total loss: 2.974 | Reg loss: 0.031 | Tree loss: 2.974 | Accuracy: 0.119141 | 4.794 sec/iter
Epoch: 95 | Batch: 010 / 026 | Total loss: 2.923 | Reg loss: 0.031 | Tree loss: 2.923 | Accuracy: 0.128906 | 4.793 sec/iter
Epoch: 95 | Batch: 011 / 026 | Total loss: 2.936 | Reg loss: 0.031 | Tree loss: 2.936 | Accuracy: 0.107422 | 4.793 sec/iter
Epoch: 95 | Batch: 012 / 026 | Total loss: 2.918 | Reg loss: 0.031 | Tree loss: 2.918 | Accuracy: 0.107422 | 4.793 sec/iter
Epoch: 95 | Batch: 013 / 026 | Total loss: 2.911 | Reg loss: 0.031 | Tree loss: 2.911 | Accuracy: 0.101562 | 4.792 sec/iter
Epoch: 95 | Batch: 014 / 026 | Total loss: 2.921 | Reg loss: 0.031 | Tree loss: 2.921 | Accuracy: 0.109375 | 4.792 sec/iter
Epoch: 95 | Batch: 015 / 026 | Total loss: 2.879 | Reg loss: 0.032 | Tree loss: 2.879 | Accuracy: 0.132812 | 4.792 sec/iter
Epoch: 95 | Batch: 016 / 026 | Total loss: 2.891 | Reg loss: 0.032 | Tree loss: 2.891 | Accuracy: 0.119141 | 4.791 sec/iter
Epoch: 9

Epoch: 97 | Batch: 018 / 026 | Total loss: 2.894 | Reg loss: 0.032 | Tree loss: 2.894 | Accuracy: 0.111328 | 4.772 sec/iter
Epoch: 97 | Batch: 019 / 026 | Total loss: 2.879 | Reg loss: 0.032 | Tree loss: 2.879 | Accuracy: 0.144531 | 4.772 sec/iter
Epoch: 97 | Batch: 020 / 026 | Total loss: 2.880 | Reg loss: 0.032 | Tree loss: 2.880 | Accuracy: 0.119141 | 4.772 sec/iter
Epoch: 97 | Batch: 021 / 026 | Total loss: 2.892 | Reg loss: 0.032 | Tree loss: 2.892 | Accuracy: 0.115234 | 4.771 sec/iter
Epoch: 97 | Batch: 022 / 026 | Total loss: 2.853 | Reg loss: 0.032 | Tree loss: 2.853 | Accuracy: 0.128906 | 4.771 sec/iter
Epoch: 97 | Batch: 023 / 026 | Total loss: 2.922 | Reg loss: 0.032 | Tree loss: 2.922 | Accuracy: 0.101562 | 4.771 sec/iter
Epoch: 97 | Batch: 024 / 026 | Total loss: 2.864 | Reg loss: 0.032 | Tree loss: 2.864 | Accuracy: 0.117188 | 4.77 sec/iter
Epoch: 97 | Batch: 025 / 026 | Total loss: 2.883 | Reg loss: 0.032 | Tree loss: 2.883 | Accuracy: 0.096774 | 4.77 sec/iter
Average sp

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 11.991871921182266


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 4060


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")



  return np.log(1 / (1 - x))





12955






Average comprehensibility: 60.83793103448276


std comprehensibility: 4.078338748915884
var comprehensibility: 16.63284695090878
minimum comprehensibility: 42
maximum comprehensibility: 70
