In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 128
tree_depth = 10
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.14717960357666 | KNN Loss: 6.230565071105957 | BCE Loss: 1.916614294052124
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.178574562072754 | KNN Loss: 6.230663299560547 | BCE Loss: 1.947911262512207
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.182377815246582 | KNN Loss: 6.230414390563965 | BCE Loss: 1.9519634246826172
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.153038024902344 | KNN Loss: 6.230363368988037 | BCE Loss: 1.9226750135421753
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.175795555114746 | KNN Loss: 6.23012638092041 | BCE Loss: 1.945669174194336
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.15682315826416 | KNN Loss: 6.229822635650635 | BCE Loss: 1.9270007610321045
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.178810119628906 | KNN Loss: 6.229875564575195 | BCE Loss: 1.9489349126815796
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.180387496948242 | KNN Loss: 6.229690074920654 | BCE Loss: 1.95069742202

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.28761625289917 | KNN Loss: 6.1548004150390625 | BCE Loss: 1.1328157186508179
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.281442642211914 | KNN Loss: 6.142515659332275 | BCE Loss: 1.1389267444610596
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.249922275543213 | KNN Loss: 6.137760639190674 | BCE Loss: 1.1121617555618286
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 7.2522125244140625 | KNN Loss: 6.126208305358887 | BCE Loss: 1.1260042190551758
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 7.251084327697754 | KNN Loss: 6.1104416847229 | BCE Loss: 1.1406424045562744
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 7.217717170715332 | KNN Loss: 6.0983428955078125 | BCE Loss: 1.11937415599823
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 7.160482406616211 | KNN Loss: 6.068140983581543 | BCE Loss: 1.092341423034668
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 7.165571212768555 | KNN Loss: 6.057934284210205 | BCE Loss: 1

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.108476638793945 | KNN Loss: 5.061213493347168 | BCE Loss: 1.0472630262374878
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.104894161224365 | KNN Loss: 5.063118934631348 | BCE Loss: 1.0417752265930176
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.156625747680664 | KNN Loss: 5.079104900360107 | BCE Loss: 1.0775208473205566
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.1539106369018555 | KNN Loss: 5.085153102874756 | BCE Loss: 1.0687572956085205
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.137800216674805 | KNN Loss: 5.079092502593994 | BCE Loss: 1.058707594871521
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.13166618347168 | KNN Loss: 5.064792633056641 | BCE Loss: 1.066873550415039
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.091497421264648 | KNN Loss: 5.0570597648620605 | BCE Loss: 1.034437656402588
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.151091575622559 | KNN Loss: 5.051975250244141 | BCE Loss:

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.107427597045898 | KNN Loss: 5.051238536834717 | BCE Loss: 1.0561891794204712
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.115278720855713 | KNN Loss: 5.04813814163208 | BCE Loss: 1.0671404600143433
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.083350658416748 | KNN Loss: 5.041615009307861 | BCE Loss: 1.0417355298995972
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.092284202575684 | KNN Loss: 5.055273532867432 | BCE Loss: 1.037010669708252
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.086962699890137 | KNN Loss: 5.048601150512695 | BCE Loss: 1.0383614301681519
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.096273422241211 | KNN Loss: 5.04991340637207 | BCE Loss: 1.0463597774505615
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.074934959411621 | KNN Loss: 5.046614646911621 | BCE Loss: 1.0283203125
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.096368789672852 | KNN Loss: 5.045067310333252 | BCE Loss: 1.0513

Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 6.062371730804443 | KNN Loss: 5.03947639465332 | BCE Loss: 1.0228954553604126
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.0904436111450195 | KNN Loss: 5.043878555297852 | BCE Loss: 1.0465651750564575
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.092868804931641 | KNN Loss: 5.044665336608887 | BCE Loss: 1.048203468322754
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.085016250610352 | KNN Loss: 5.041781425476074 | BCE Loss: 1.0432350635528564
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.1180877685546875 | KNN Loss: 5.065965175628662 | BCE Loss: 1.0521225929260254
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.081908226013184 | KNN Loss: 5.034112930297852 | BCE Loss: 1.0477955341339111
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.106303691864014 | KNN Loss: 5.046926498413086 | BCE Loss: 1.0593771934509277
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.1107258796691895 | KNN Loss: 5.0440263748168945 | BCE Lo

Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 6.096177577972412 | KNN Loss: 5.045533180236816 | BCE Loss: 1.0506443977355957
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.091556549072266 | KNN Loss: 5.043697834014893 | BCE Loss: 1.0478585958480835
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.100545883178711 | KNN Loss: 5.044763565063477 | BCE Loss: 1.0557821989059448
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.096068382263184 | KNN Loss: 5.0420660972595215 | BCE Loss: 1.0540025234222412
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.108269691467285 | KNN Loss: 5.038491249084473 | BCE Loss: 1.0697782039642334
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.089432716369629 | KNN Loss: 5.039303779602051 | BCE Loss: 1.050128698348999
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.113760471343994 | KNN Loss: 5.051718235015869 | BCE Loss: 1.062042236328125
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.082190990447998 | KNN Loss: 5.049178123474121 | BCE Loss: 

Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 6.087122917175293 | KNN Loss: 5.052177906036377 | BCE Loss: 1.0349452495574951
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 6.084155082702637 | KNN Loss: 5.036402702331543 | BCE Loss: 1.0477523803710938
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.0978593826293945 | KNN Loss: 5.036871910095215 | BCE Loss: 1.0609874725341797
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.061302661895752 | KNN Loss: 5.0426812171936035 | BCE Loss: 1.0186213254928589
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.076380729675293 | KNN Loss: 5.037388801574707 | BCE Loss: 1.038992166519165
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.095370292663574 | KNN Loss: 5.056882381439209 | BCE Loss: 1.0384880304336548
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.104027271270752 | KNN Loss: 5.064815998077393 | BCE Loss: 1.0392112731933594
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.106443405151367 | KNN Loss: 5.036929607391357 | BCE Lo

Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 6.0768232345581055 | KNN Loss: 5.051542282104492 | BCE Loss: 1.0252811908721924
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 6.061604976654053 | KNN Loss: 5.036231517791748 | BCE Loss: 1.0253734588623047
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 6.075700759887695 | KNN Loss: 5.031344413757324 | BCE Loss: 1.0443565845489502
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.0973687171936035 | KNN Loss: 5.049161434173584 | BCE Loss: 1.048207402229309
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.086277961730957 | KNN Loss: 5.052222728729248 | BCE Loss: 1.0340549945831299
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.085200309753418 | KNN Loss: 5.037827014923096 | BCE Loss: 1.0473735332489014
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.088960647583008 | KNN Loss: 5.034401893615723 | BCE Loss: 1.0545589923858643
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.091104507446289 | KNN Loss: 5.034475326538086 | BCE Lo

Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 6.069293022155762 | KNN Loss: 5.03639554977417 | BCE Loss: 1.0328973531723022
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 6.09158182144165 | KNN Loss: 5.032209873199463 | BCE Loss: 1.059372067451477
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.06585693359375 | KNN Loss: 5.03915548324585 | BCE Loss: 1.0267012119293213
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.083051681518555 | KNN Loss: 5.036562919616699 | BCE Loss: 1.0464890003204346
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.079998016357422 | KNN Loss: 5.03474235534668 | BCE Loss: 1.0452557802200317
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.086360931396484 | KNN Loss: 5.0373311042785645 | BCE Loss: 1.0490297079086304
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.09236478805542 | KNN Loss: 5.035271644592285 | BCE Loss: 1.0570932626724243
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.050549507141113 | KNN Loss: 5.02980899810791 | BCE Loss: 1.020

Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 6.084042549133301 | KNN Loss: 5.027827262878418 | BCE Loss: 1.056215524673462
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 6.056000709533691 | KNN Loss: 5.038578033447266 | BCE Loss: 1.0174229145050049
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 6.107810020446777 | KNN Loss: 5.057973861694336 | BCE Loss: 1.0498363971710205
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.077352523803711 | KNN Loss: 5.030940055847168 | BCE Loss: 1.046412467956543
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.089074611663818 | KNN Loss: 5.038368225097656 | BCE Loss: 1.050706386566162
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.097002983093262 | KNN Loss: 5.050445556640625 | BCE Loss: 1.0465574264526367
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.034694671630859 | KNN Loss: 5.024302005767822 | BCE Loss: 1.0103929042816162
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.089261054992676 | KNN Loss: 5.045051097869873 | BCE Loss: 1

Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 6.052009582519531 | KNN Loss: 5.028120517730713 | BCE Loss: 1.0238889455795288
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 6.063867568969727 | KNN Loss: 5.033749103546143 | BCE Loss: 1.030118465423584
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 6.080366134643555 | KNN Loss: 5.0443501472473145 | BCE Loss: 1.0360157489776611
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.073193550109863 | KNN Loss: 5.034810543060303 | BCE Loss: 1.0383827686309814
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.060562610626221 | KNN Loss: 5.0455241203308105 | BCE Loss: 1.0150386095046997
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.066206455230713 | KNN Loss: 5.034526348114014 | BCE Loss: 1.0316799879074097
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.068295001983643 | KNN Loss: 5.027552127838135 | BCE Loss: 1.0407429933547974
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.078431606292725 | KNN Loss: 5.03693962097168 |

Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 6.082198619842529 | KNN Loss: 5.041184425354004 | BCE Loss: 1.0410141944885254
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 6.091071605682373 | KNN Loss: 5.0430731773376465 | BCE Loss: 1.0479984283447266
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 6.050337791442871 | KNN Loss: 5.035341739654541 | BCE Loss: 1.0149962902069092
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.0671281814575195 | KNN Loss: 5.032009124755859 | BCE Loss: 1.035118818283081
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.057733535766602 | KNN Loss: 5.030279159545898 | BCE Loss: 1.0274542570114136
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.140458106994629 | KNN Loss: 5.109394550323486 | BCE Loss: 1.0310635566711426
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.093816757202148 | KNN Loss: 5.037773132324219 | BCE Loss: 1.0560436248779297
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.059527397155762 | KNN Loss: 5.027854919433594 |

Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 6.070181846618652 | KNN Loss: 5.041353225708008 | BCE Loss: 1.0288288593292236
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 6.087186813354492 | KNN Loss: 5.0392279624938965 | BCE Loss: 1.0479586124420166
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 6.069699764251709 | KNN Loss: 5.028414249420166 | BCE Loss: 1.0412856340408325
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.072876930236816 | KNN Loss: 5.037140846252441 | BCE Loss: 1.0357359647750854
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.043679714202881 | KNN Loss: 5.039313793182373 | BCE Loss: 1.0043660402297974
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.062193870544434 | KNN Loss: 5.0416789054870605 | BCE Loss: 1.0205152034759521
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.106677532196045 | KNN Loss: 5.035653591156006 | BCE Loss: 1.0710238218307495
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 6.092171669006348 | KNN Loss: 5.035059928894043 

Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 6.100985050201416 | KNN Loss: 5.057323455810547 | BCE Loss: 1.0436615943908691
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 6.089192867279053 | KNN Loss: 5.036113262176514 | BCE Loss: 1.0530794858932495
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 6.054187774658203 | KNN Loss: 5.035865783691406 | BCE Loss: 1.018322229385376
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.083724498748779 | KNN Loss: 5.05403995513916 | BCE Loss: 1.0296844244003296
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.070745468139648 | KNN Loss: 5.034542560577393 | BCE Loss: 1.0362030267715454
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.087112903594971 | KNN Loss: 5.041832447052002 | BCE Loss: 1.0452804565429688
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.118324279785156 | KNN Loss: 5.039999485015869 | BCE Loss: 1.078324794769287
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.079291820526123 | KNN Loss: 5.036701679229736 | BC

Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 6.071098327636719 | KNN Loss: 5.041639804840088 | BCE Loss: 1.0294585227966309
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 6.099087715148926 | KNN Loss: 5.037156105041504 | BCE Loss: 1.061931848526001
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 6.077216148376465 | KNN Loss: 5.041511058807373 | BCE Loss: 1.0357049703598022
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.088779449462891 | KNN Loss: 5.032336235046387 | BCE Loss: 1.056443214416504
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.077906608581543 | KNN Loss: 5.03281831741333 | BCE Loss: 1.0450884103775024
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.090187072753906 | KNN Loss: 5.032061576843262 | BCE Loss: 1.058125376701355
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.081716537475586 | KNN Loss: 5.032903671264648 | BCE Loss: 1.0488128662109375
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.104953765869141 | KNN Loss: 5.046867370605469 | BCE 

Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 6.068413734436035 | KNN Loss: 5.038077354431152 | BCE Loss: 1.0303362607955933
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 6.089548587799072 | KNN Loss: 5.028110504150391 | BCE Loss: 1.0614380836486816
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 6.0749616622924805 | KNN Loss: 5.030587673187256 | BCE Loss: 1.0443741083145142
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.071917533874512 | KNN Loss: 5.03441047668457 | BCE Loss: 1.0375070571899414
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.086417198181152 | KNN Loss: 5.027633190155029 | BCE Loss: 1.058783769607544
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.042969703674316 | KNN Loss: 5.027047634124756 | BCE Loss: 1.0159218311309814
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.066227912902832 | KNN Loss: 5.030499458312988 | BCE Loss: 1.0357283353805542
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.105352401733398 | KNN Loss: 5.026792049407959 | B

Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 6.078695297241211 | KNN Loss: 5.036757946014404 | BCE Loss: 1.0419371128082275
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 6.061735153198242 | KNN Loss: 5.024326801300049 | BCE Loss: 1.0374085903167725
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 6.0650248527526855 | KNN Loss: 5.032693386077881 | BCE Loss: 1.0323314666748047
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.087860584259033 | KNN Loss: 5.029001712799072 | BCE Loss: 1.0588587522506714
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.058897018432617 | KNN Loss: 5.031718730926514 | BCE Loss: 1.0271785259246826
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.109053134918213 | KNN Loss: 5.049563407897949 | BCE Loss: 1.0594897270202637
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.071697235107422 | KNN Loss: 5.035006999969482 | BCE Loss: 1.036690354347229
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.093623161315918 | KNN Loss: 5.033542633056641 |

Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 6.048887252807617 | KNN Loss: 5.03870964050293 | BCE Loss: 1.0101778507232666
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 6.09234619140625 | KNN Loss: 5.06146764755249 | BCE Loss: 1.0308787822723389
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 6.08812952041626 | KNN Loss: 5.026316165924072 | BCE Loss: 1.061813235282898
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.041009426116943 | KNN Loss: 5.030131816864014 | BCE Loss: 1.0108776092529297
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.067170143127441 | KNN Loss: 5.027286529541016 | BCE Loss: 1.0398838520050049
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.075752258300781 | KNN Loss: 5.033078670501709 | BCE Loss: 1.0426733493804932
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.061255931854248 | KNN Loss: 5.03233003616333 | BCE Loss: 1.028925895690918
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.058967590332031 | KNN Loss: 5.030661582946777 | BCE Los

Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 6.0494842529296875 | KNN Loss: 5.026808261871338 | BCE Loss: 1.0226757526397705
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 6.111610412597656 | KNN Loss: 5.039653778076172 | BCE Loss: 1.0719566345214844
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 6.069733142852783 | KNN Loss: 5.057985305786133 | BCE Loss: 1.0117478370666504
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.081287384033203 | KNN Loss: 5.043806552886963 | BCE Loss: 1.0374810695648193
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.087515830993652 | KNN Loss: 5.041334629058838 | BCE Loss: 1.046181082725525
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.075685501098633 | KNN Loss: 5.03734016418457 | BCE Loss: 1.0383450984954834
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.0766096115112305 | KNN Loss: 5.027439117431641 | BCE Loss: 1.0491702556610107
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.062281608581543 | KNN Loss: 5.030760288238525 | 

Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 6.103603839874268 | KNN Loss: 5.028852462768555 | BCE Loss: 1.0747514963150024
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 6.071135520935059 | KNN Loss: 5.041805744171143 | BCE Loss: 1.0293298959732056
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 6.050136566162109 | KNN Loss: 5.03346061706543 | BCE Loss: 1.0166759490966797
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.067919731140137 | KNN Loss: 5.04030704498291 | BCE Loss: 1.0276126861572266
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.064104080200195 | KNN Loss: 5.040430545806885 | BCE Loss: 1.0236736536026
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.06163215637207 | KNN Loss: 5.033855438232422 | BCE Loss: 1.0277769565582275
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.060523509979248 | KNN Loss: 5.0311737060546875 | BCE Loss: 1.02934992313385
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.084150314331055 | KNN Loss: 5.037420272827148 | BCE Lo

Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 6.08603572845459 | KNN Loss: 5.027030944824219 | BCE Loss: 1.059004545211792
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 6.065065860748291 | KNN Loss: 5.033153533935547 | BCE Loss: 1.0319123268127441
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 6.076951026916504 | KNN Loss: 5.036865234375 | BCE Loss: 1.040086030960083
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.08395528793335 | KNN Loss: 5.040464878082275 | BCE Loss: 1.0434904098510742
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.080463886260986 | KNN Loss: 5.031987190246582 | BCE Loss: 1.0484766960144043
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.073736190795898 | KNN Loss: 5.039417743682861 | BCE Loss: 1.034318208694458
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.061346530914307 | KNN Loss: 5.032276630401611 | BCE Loss: 1.0290699005126953
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.0648908615112305 | KNN Loss: 5.046885967254639 | BCE Los

Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 6.041339874267578 | KNN Loss: 5.031439304351807 | BCE Loss: 1.0099003314971924
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 6.0415143966674805 | KNN Loss: 5.028441429138184 | BCE Loss: 1.0130727291107178
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 6.065425872802734 | KNN Loss: 5.032253265380859 | BCE Loss: 1.033172369003296
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.065649032592773 | KNN Loss: 5.0324835777282715 | BCE Loss: 1.0331655740737915
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.055556297302246 | KNN Loss: 5.026913166046143 | BCE Loss: 1.0286431312561035
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.106433868408203 | KNN Loss: 5.0446553230285645 | BCE Loss: 1.0617787837982178
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.083904266357422 | KNN Loss: 5.026700019836426 | BCE Loss: 1.057204246520996
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.079084873199463 | KNN Loss: 5.033573627471924 |

Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 6.0731401443481445 | KNN Loss: 5.026750087738037 | BCE Loss: 1.046390175819397
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 6.082367897033691 | KNN Loss: 5.027311325073242 | BCE Loss: 1.0550568103790283
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 6.061336517333984 | KNN Loss: 5.034867286682129 | BCE Loss: 1.0264692306518555
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.073512554168701 | KNN Loss: 5.031449794769287 | BCE Loss: 1.0420626401901245
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.075008392333984 | KNN Loss: 5.038653373718262 | BCE Loss: 1.0363550186157227
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.067986488342285 | KNN Loss: 5.029291152954102 | BCE Loss: 1.0386953353881836
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.0874152183532715 | KNN Loss: 5.039962291717529 | BCE Loss: 1.0474530458450317
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.040919303894043 | KNN Loss: 5.037524700164795 

Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 6.055229187011719 | KNN Loss: 5.036261558532715 | BCE Loss: 1.0189673900604248
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 6.089101791381836 | KNN Loss: 5.045258522033691 | BCE Loss: 1.043843388557434
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 6.055931091308594 | KNN Loss: 5.025661945343018 | BCE Loss: 1.0302691459655762
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.070645809173584 | KNN Loss: 5.028929233551025 | BCE Loss: 1.0417165756225586
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.062268257141113 | KNN Loss: 5.0315093994140625 | BCE Loss: 1.0307588577270508
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.105381965637207 | KNN Loss: 5.043605327606201 | BCE Loss: 1.0617766380310059
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.079785346984863 | KNN Loss: 5.0329155921936035 | BCE Loss: 1.0468697547912598
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.078884124755859 | KNN Loss: 5.036197662353516 |

Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 6.041276931762695 | KNN Loss: 5.02662992477417 | BCE Loss: 1.0146468877792358
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 6.083417892456055 | KNN Loss: 5.036128997802734 | BCE Loss: 1.0472886562347412
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 6.05924129486084 | KNN Loss: 5.030652046203613 | BCE Loss: 1.0285892486572266
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.083641052246094 | KNN Loss: 5.033901214599609 | BCE Loss: 1.0497397184371948
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.095284461975098 | KNN Loss: 5.032439708709717 | BCE Loss: 1.0628447532653809
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.07939338684082 | KNN Loss: 5.0299177169799805 | BCE Loss: 1.0494756698608398
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.054522514343262 | KNN Loss: 5.0301947593688965 | BCE Loss: 1.0243279933929443
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.066711902618408 | KNN Loss: 5.036285877227783 | B

Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 6.086468696594238 | KNN Loss: 5.0313401222229 | BCE Loss: 1.055128574371338
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 6.087555885314941 | KNN Loss: 5.046238422393799 | BCE Loss: 1.0413177013397217
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 6.056197643280029 | KNN Loss: 5.035429000854492 | BCE Loss: 1.020768642425537
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.072237014770508 | KNN Loss: 5.034692287445068 | BCE Loss: 1.0375444889068604
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.060156345367432 | KNN Loss: 5.029608249664307 | BCE Loss: 1.030548095703125
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.06443977355957 | KNN Loss: 5.037171840667725 | BCE Loss: 1.0272676944732666
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.098935127258301 | KNN Loss: 5.03472375869751 | BCE Loss: 1.064211368560791
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.097809791564941 | KNN Loss: 5.047125816345215 | BCE Los

Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 6.0572052001953125 | KNN Loss: 5.023245334625244 | BCE Loss: 1.0339596271514893
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 6.079711437225342 | KNN Loss: 5.042956829071045 | BCE Loss: 1.0367544889450073
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 6.069212913513184 | KNN Loss: 5.040118217468262 | BCE Loss: 1.029094934463501
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.061029434204102 | KNN Loss: 5.0335307121276855 | BCE Loss: 1.0274986028671265
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.098328113555908 | KNN Loss: 5.034042835235596 | BCE Loss: 1.0642852783203125
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.089414119720459 | KNN Loss: 5.03481388092041 | BCE Loss: 1.0546001195907593
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.0807952880859375 | KNN Loss: 5.032888412475586 | BCE Loss: 1.0479066371917725
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.0657148361206055 | KNN Loss: 5.03255033493042 |

Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 6.0515594482421875 | KNN Loss: 5.032769680023193 | BCE Loss: 1.0187900066375732
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 6.083308219909668 | KNN Loss: 5.037139415740967 | BCE Loss: 1.0461688041687012
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 6.092341423034668 | KNN Loss: 5.0307297706604 | BCE Loss: 1.0616118907928467
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.058474540710449 | KNN Loss: 5.038815498352051 | BCE Loss: 1.0196588039398193
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.053526401519775 | KNN Loss: 5.034684658050537 | BCE Loss: 1.0188418626785278
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.061361312866211 | KNN Loss: 5.044064998626709 | BCE Loss: 1.017296552658081
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.065059185028076 | KNN Loss: 5.024938583374023 | BCE Loss: 1.0401207208633423
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.1092424392700195 | KNN Loss: 5.036752700805664 | B

Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 6.07545280456543 | KNN Loss: 5.040984153747559 | BCE Loss: 1.0344688892364502
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 6.058280944824219 | KNN Loss: 5.037469387054443 | BCE Loss: 1.0208115577697754
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 6.069382190704346 | KNN Loss: 5.045237064361572 | BCE Loss: 1.024145245552063
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.079411029815674 | KNN Loss: 5.037721157073975 | BCE Loss: 1.0416897535324097
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.0621209144592285 | KNN Loss: 5.037947654724121 | BCE Loss: 1.0241732597351074
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.0689239501953125 | KNN Loss: 5.033682823181152 | BCE Loss: 1.035240888595581
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.074100017547607 | KNN Loss: 5.029005527496338 | BCE Loss: 1.045094609260559
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.099788665771484 | KNN Loss: 5.033951282501221 | B

Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 6.081766128540039 | KNN Loss: 5.027576446533203 | BCE Loss: 1.054189682006836
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 6.089794158935547 | KNN Loss: 5.039479732513428 | BCE Loss: 1.0503146648406982
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 6.1127166748046875 | KNN Loss: 5.041300296783447 | BCE Loss: 1.0714163780212402
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.088671684265137 | KNN Loss: 5.05385160446167 | BCE Loss: 1.0348198413848877
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.057802200317383 | KNN Loss: 5.023895740509033 | BCE Loss: 1.0339064598083496
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.06911039352417 | KNN Loss: 5.025988578796387 | BCE Loss: 1.0431219339370728
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.065236568450928 | KNN Loss: 5.025812149047852 | BCE Loss: 1.0394243001937866
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.076449394226074 | KNN Loss: 5.031102180480957 | BC

Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 6.07716178894043 | KNN Loss: 5.039038181304932 | BCE Loss: 1.0381238460540771
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 6.0949788093566895 | KNN Loss: 5.047385215759277 | BCE Loss: 1.047593593597412
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 6.08770227432251 | KNN Loss: 5.035257816314697 | BCE Loss: 1.0524444580078125
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.0838823318481445 | KNN Loss: 5.033627033233643 | BCE Loss: 1.050255537033081
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.031824111938477 | KNN Loss: 5.031899929046631 | BCE Loss: 0.99992436170578
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.099966049194336 | KNN Loss: 5.032649993896484 | BCE Loss: 1.0673160552978516
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.053305149078369 | KNN Loss: 5.02946138381958 | BCE Loss: 1.0238438844680786
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.07880973815918 | KNN Loss: 5.035634994506836 | BCE Lo

Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 6.097304344177246 | KNN Loss: 5.057595252990723 | BCE Loss: 1.0397093296051025
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 6.050758361816406 | KNN Loss: 5.027897834777832 | BCE Loss: 1.0228602886199951
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 6.097956657409668 | KNN Loss: 5.041385173797607 | BCE Loss: 1.0565712451934814
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.098381042480469 | KNN Loss: 5.044937610626221 | BCE Loss: 1.053443431854248
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.057406425476074 | KNN Loss: 5.037785053253174 | BCE Loss: 1.01962149143219
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.073476314544678 | KNN Loss: 5.027244567871094 | BCE Loss: 1.0462318658828735
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.069616317749023 | KNN Loss: 5.037257671356201 | BCE Loss: 1.0323587656021118
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.0621466636657715 | KNN Loss: 5.027753829956055 | B

Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 6.0912041664123535 | KNN Loss: 5.035335540771484 | BCE Loss: 1.0558687448501587
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 6.085033416748047 | KNN Loss: 5.047191619873047 | BCE Loss: 1.037842035293579
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 6.060906410217285 | KNN Loss: 5.029366970062256 | BCE Loss: 1.0315396785736084
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.061972618103027 | KNN Loss: 5.030698299407959 | BCE Loss: 1.031274437904358
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.038148403167725 | KNN Loss: 5.029440879821777 | BCE Loss: 1.0087075233459473
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.070773124694824 | KNN Loss: 5.030824661254883 | BCE Loss: 1.0399484634399414
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.042928695678711 | KNN Loss: 5.034734725952148 | BCE Loss: 1.0081937313079834
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.05419921875 | KNN Loss: 5.033200740814209 | BCE L

Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 6.080989837646484 | KNN Loss: 5.040487766265869 | BCE Loss: 1.0405018329620361
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 6.086658000946045 | KNN Loss: 5.028843879699707 | BCE Loss: 1.057814121246338
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 6.049939155578613 | KNN Loss: 5.0324625968933105 | BCE Loss: 1.0174767971038818
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.059459686279297 | KNN Loss: 5.030266761779785 | BCE Loss: 1.0291929244995117
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.061821937561035 | KNN Loss: 5.0346903800964355 | BCE Loss: 1.0271315574645996
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.041083812713623 | KNN Loss: 5.032082557678223 | BCE Loss: 1.0090012550354004
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.100796699523926 | KNN Loss: 5.039881229400635 | BCE Loss: 1.0609153509140015
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.119894504547119 | KNN Loss: 5.032805442810059 |

Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 6.063923358917236 | KNN Loss: 5.037242889404297 | BCE Loss: 1.026680588722229
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 6.074005126953125 | KNN Loss: 5.035107612609863 | BCE Loss: 1.0388977527618408
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 6.076988220214844 | KNN Loss: 5.043315887451172 | BCE Loss: 1.033672571182251
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.070315361022949 | KNN Loss: 5.030518054962158 | BCE Loss: 1.039797067642212
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.090076446533203 | KNN Loss: 5.050982475280762 | BCE Loss: 1.0390942096710205
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.121639251708984 | KNN Loss: 5.052523136138916 | BCE Loss: 1.0691161155700684
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.106818675994873 | KNN Loss: 5.05309534072876 | BCE Loss: 1.0537232160568237
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.059082508087158 | KNN Loss: 5.030580043792725 | BCE

Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 6.083366870880127 | KNN Loss: 5.036783218383789 | BCE Loss: 1.0465835332870483
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 6.058257102966309 | KNN Loss: 5.030787944793701 | BCE Loss: 1.0274691581726074
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 6.058038234710693 | KNN Loss: 5.031131267547607 | BCE Loss: 1.026906967163086
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.079001426696777 | KNN Loss: 5.033856391906738 | BCE Loss: 1.04514479637146
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.08532190322876 | KNN Loss: 5.046485424041748 | BCE Loss: 1.0388364791870117
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.063356399536133 | KNN Loss: 5.039042949676514 | BCE Loss: 1.0243136882781982
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.044239044189453 | KNN Loss: 5.027237892150879 | BCE Loss: 1.0170009136199951
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.0674028396606445 | KNN Loss: 5.0397162437438965 | BC

Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 6.043972492218018 | KNN Loss: 5.026390552520752 | BCE Loss: 1.0175819396972656
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 6.071282386779785 | KNN Loss: 5.037160396575928 | BCE Loss: 1.0341217517852783
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 6.046072006225586 | KNN Loss: 5.0238776206970215 | BCE Loss: 1.0221941471099854
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.080229759216309 | KNN Loss: 5.039120197296143 | BCE Loss: 1.0411098003387451
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.098684310913086 | KNN Loss: 5.046571731567383 | BCE Loss: 1.052112340927124
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.057171821594238 | KNN Loss: 5.028180122375488 | BCE Loss: 1.028991460800171
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.069547653198242 | KNN Loss: 5.033443927764893 | BCE Loss: 1.0361034870147705
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.059021949768066 | KNN Loss: 5.029446125030518 | B

Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 6.125980854034424 | KNN Loss: 5.054740905761719 | BCE Loss: 1.0712398290634155
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 6.045041561126709 | KNN Loss: 5.038346290588379 | BCE Loss: 1.0066953897476196
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 6.05496883392334 | KNN Loss: 5.027224063873291 | BCE Loss: 1.0277445316314697
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.0308027267456055 | KNN Loss: 5.042102336883545 | BCE Loss: 0.9887003302574158
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.054042339324951 | KNN Loss: 5.035949230194092 | BCE Loss: 1.0180931091308594
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.104779243469238 | KNN Loss: 5.047010898590088 | BCE Loss: 1.0577685832977295
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.071751594543457 | KNN Loss: 5.036447525024414 | BCE Loss: 1.0353041887283325
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.050614356994629 | KNN Loss: 5.037130355834961 |

Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 6.095922946929932 | KNN Loss: 5.046838283538818 | BCE Loss: 1.0490846633911133
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 6.056415557861328 | KNN Loss: 5.02637243270874 | BCE Loss: 1.030043125152588
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 6.069551467895508 | KNN Loss: 5.032211780548096 | BCE Loss: 1.037339448928833
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.066575050354004 | KNN Loss: 5.027983665466309 | BCE Loss: 1.0385916233062744
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.072520732879639 | KNN Loss: 5.033507347106934 | BCE Loss: 1.039013385772705
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.067142486572266 | KNN Loss: 5.030191421508789 | BCE Loss: 1.0369511842727661
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.062765598297119 | KNN Loss: 5.027805805206299 | BCE Loss: 1.0349596738815308
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.085871696472168 | KNN Loss: 5.03318452835083 | BCE L

Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 6.071086883544922 | KNN Loss: 5.036660671234131 | BCE Loss: 1.034425973892212
Epoch   417: reducing learning rate of group 0 to 7.8888e-08.
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 6.077520370483398 | KNN Loss: 5.03652811050415 | BCE Loss: 1.0409921407699585
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 6.122861862182617 | KNN Loss: 5.073422908782959 | BCE Loss: 1.0494391918182373
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.067098617553711 | KNN Loss: 5.0358405113220215 | BCE Loss: 1.0312583446502686
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.039282321929932 | KNN Loss: 5.034909248352051 | BCE Loss: 1.0043730735778809
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.063196182250977 | KNN Loss: 5.035141468048096 | BCE Loss: 1.0280547142028809
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.072604179382324 | KNN Loss: 5.029934883117676 | BCE Loss: 1.0426692962646484
Epoch 418 / 500 | iteration 0 / 30 | T

Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 6.062180042266846 | KNN Loss: 5.026371002197266 | BCE Loss: 1.0358091592788696
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 6.050676345825195 | KNN Loss: 5.028989315032959 | BCE Loss: 1.0216871500015259
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 6.114026069641113 | KNN Loss: 5.033078670501709 | BCE Loss: 1.0809476375579834
Epoch   428: reducing learning rate of group 0 to 5.5221e-08.
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.051994323730469 | KNN Loss: 5.031574249267578 | BCE Loss: 1.0204203128814697
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.065311908721924 | KNN Loss: 5.033810615539551 | BCE Loss: 1.031501293182373
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.040180683135986 | KNN Loss: 5.02991247177124 | BCE Loss: 1.010268211364746
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.069477558135986 | KNN Loss: 5.028721332550049 | BCE Loss: 1.0407562255859375
Epoch 428 / 500 | iteration 20 / 30 | To

Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 6.085519790649414 | KNN Loss: 5.031088352203369 | BCE Loss: 1.0544315576553345
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 6.05369234085083 | KNN Loss: 5.031317710876465 | BCE Loss: 1.0223746299743652
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 6.089005470275879 | KNN Loss: 5.030367374420166 | BCE Loss: 1.058638334274292
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.090102672576904 | KNN Loss: 5.054186820983887 | BCE Loss: 1.035915732383728
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.048975944519043 | KNN Loss: 5.028647422790527 | BCE Loss: 1.020328402519226
Epoch   439: reducing learning rate of group 0 to 3.8655e-08.
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.054462432861328 | KNN Loss: 5.035567283630371 | BCE Loss: 1.0188950300216675
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.09264612197876 | KNN Loss: 5.032491207122803 | BCE Loss: 1.0601547956466675
Epoch 439 / 500 | iteration 10 / 30 | Total

Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 6.085725784301758 | KNN Loss: 5.032927989959717 | BCE Loss: 1.052797555923462
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 6.07209587097168 | KNN Loss: 5.041560173034668 | BCE Loss: 1.0305359363555908
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 6.062740802764893 | KNN Loss: 5.03169059753418 | BCE Loss: 1.031050205230713
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.059811592102051 | KNN Loss: 5.02908182144165 | BCE Loss: 1.0307297706604004
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.049190521240234 | KNN Loss: 5.028652191162109 | BCE Loss: 1.020538091659546
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.059130668640137 | KNN Loss: 5.042571067810059 | BCE Loss: 1.016559362411499
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.058109283447266 | KNN Loss: 5.032116413116455 | BCE Loss: 1.0259931087493896
Epoch   450: reducing learning rate of group 0 to 2.7058e-08.
Epoch 450 / 500 | iteration 0 / 30 | Total L

Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 6.046849727630615 | KNN Loss: 5.028506278991699 | BCE Loss: 1.0183435678482056
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 6.051676273345947 | KNN Loss: 5.032701015472412 | BCE Loss: 1.0189752578735352
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 6.064789295196533 | KNN Loss: 5.031075954437256 | BCE Loss: 1.0337133407592773
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.062212944030762 | KNN Loss: 5.031175136566162 | BCE Loss: 1.0310375690460205
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.075439453125 | KNN Loss: 5.028006553649902 | BCE Loss: 1.0474330186843872
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.065629959106445 | KNN Loss: 5.045800685882568 | BCE Loss: 1.0198293924331665
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.087003707885742 | KNN Loss: 5.031047344207764 | BCE Loss: 1.0559563636779785
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.056163787841797 | KNN Loss: 5.03107213973999 | BCE

Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 6.060054779052734 | KNN Loss: 5.034311771392822 | BCE Loss: 1.0257431268692017
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 6.076012134552002 | KNN Loss: 5.031058311462402 | BCE Loss: 1.0449538230895996
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 6.061872482299805 | KNN Loss: 5.039194107055664 | BCE Loss: 1.0226786136627197
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.129312992095947 | KNN Loss: 5.029121398925781 | BCE Loss: 1.1001917123794556
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.044210433959961 | KNN Loss: 5.0272393226623535 | BCE Loss: 1.0169711112976074
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.068950176239014 | KNN Loss: 5.03914213180542 | BCE Loss: 1.0298079252243042
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.090576171875 | KNN Loss: 5.035552501678467 | BCE Loss: 1.0550237894058228
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.102415084838867 | KNN Loss: 5.051633358001709 | BCE

Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 6.085815906524658 | KNN Loss: 5.036717414855957 | BCE Loss: 1.0490984916687012
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 6.078423500061035 | KNN Loss: 5.040165901184082 | BCE Loss: 1.0382577180862427
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 6.091352462768555 | KNN Loss: 5.043668746948242 | BCE Loss: 1.0476837158203125
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.054566860198975 | KNN Loss: 5.036487579345703 | BCE Loss: 1.0180792808532715
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.045516490936279 | KNN Loss: 5.042888641357422 | BCE Loss: 1.0026277303695679
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.068411827087402 | KNN Loss: 5.026695251464844 | BCE Loss: 1.041716456413269
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.059846878051758 | KNN Loss: 5.0303053855896 | BCE Loss: 1.029541254043579
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.07157564163208 | KNN Loss: 5.039799690246582 | BCE L

Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 6.06508731842041 | KNN Loss: 5.025021076202393 | BCE Loss: 1.0400664806365967
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 6.080235004425049 | KNN Loss: 5.030549049377441 | BCE Loss: 1.0496859550476074
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.075401782989502 | KNN Loss: 5.027512550354004 | BCE Loss: 1.0478891134262085
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.110662460327148 | KNN Loss: 5.062050819396973 | BCE Loss: 1.0486117601394653
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.048522472381592 | KNN Loss: 5.0362067222595215 | BCE Loss: 1.0123156309127808
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.075016021728516 | KNN Loss: 5.038321495056152 | BCE Loss: 1.0366945266723633
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.092489719390869 | KNN Loss: 5.0410990715026855 | BCE Loss: 1.051390528678894
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 6.068001747131348 | KNN Loss: 5.043178558349609 |

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.1004,  2.6817,  2.7539,  2.4920,  3.6779,  0.7150,  2.7538,  2.4548,
          2.5670,  1.9031,  2.4237,  2.1422,  0.9989,  1.9468,  1.3392,  1.4412,
          1.9644,  2.2851,  2.9922,  2.3174,  1.7563,  3.0557,  2.2364,  1.6399,
          2.4427,  1.5676,  2.3108,  1.6814,  1.3940,  0.0944, -0.1598,  1.1897,
          0.2207,  1.0250,  1.5466,  1.4039,  0.7653,  3.5941,  0.6789,  1.5340,
          0.9580, -0.7711, -0.3380,  2.2785,  2.4960,  0.7685, -0.0435, -0.1210,
          1.4144,  2.7551,  1.6707,  0.1121,  1.4718,  0.4221, -0.7428,  1.1006,
          1.6403,  1.3968,  1.3351,  1.9359,  0.2931,  0.7008,  0.2174,  1.9189,
          1.3878,  1.7041, -1.8440,  0.3478,  2.1132,  1.7647,  2.5614,  0.4807,
          1.2871,  2.6378,  1.7086,  1.1656,  0.2173,  0.7676,  0.3412,  1.7553,
          0.0473,  0.2614,  2.0176, -0.3224,  0.4090, -1.1499, -2.7509, -0.2275,
          0.5427, -1.7010,  0.4813, -0.1728, -0.4958, -0.8848,  0.5925,  1.2779,
         -0.5448, -0.9115,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 84.87it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 029 | Total loss: 9.606 | Reg loss: 0.012 | Tree loss: 9.606 | Accuracy: 0.000000 | 1.466 sec/iter
Epoch: 00 | Batch: 001 / 029 | Total loss: 9.597 | Reg loss: 0.011 | Tree loss: 9.597 | Accuracy: 0.000000 | 1.161 sec/iter
Epoch: 00 | Batch: 002 / 029 | Total loss: 9.588 | Reg loss: 0.010 | Tree loss: 9.588 | Accuracy: 0.000000 | 1.091 sec/iter
Epoch: 00 | Batch: 003 / 029 | Total loss: 9.579 | Reg loss: 0.010 | Tree loss: 9.579 | Accuracy: 0.000000 | 1.04 sec/iter
Epoch: 00 | Batch: 004 / 029 | Total loss: 9.572 | Reg loss: 0.009 | Tree loss: 9.572 | Accuracy: 0.000000 | 1.011 sec/iter
Epoch: 00 | Batch: 005 / 029 | Total loss: 9.562 | Reg loss: 0.009 | Tree loss: 9.562 | Accuracy: 0.003906 | 0.989 sec/iter
Epoch: 00 | Batch: 006 / 029 | Total loss: 9.552 | Reg loss: 0.009 | Tree loss: 9.552 | Accuracy: 0.005859 | 0.975 s

Epoch: 02 | Batch: 003 / 029 | Total loss: 9.348 | Reg loss: 0.009 | Tree loss: 9.348 | Accuracy: 0.390625 | 0.918 sec/iter
Epoch: 02 | Batch: 004 / 029 | Total loss: 9.345 | Reg loss: 0.009 | Tree loss: 9.345 | Accuracy: 0.332031 | 0.917 sec/iter
Epoch: 02 | Batch: 005 / 029 | Total loss: 9.330 | Reg loss: 0.009 | Tree loss: 9.330 | Accuracy: 0.388672 | 0.917 sec/iter
Epoch: 02 | Batch: 006 / 029 | Total loss: 9.338 | Reg loss: 0.010 | Tree loss: 9.338 | Accuracy: 0.285156 | 0.916 sec/iter
Epoch: 02 | Batch: 007 / 029 | Total loss: 9.321 | Reg loss: 0.010 | Tree loss: 9.321 | Accuracy: 0.376953 | 0.916 sec/iter
Epoch: 02 | Batch: 008 / 029 | Total loss: 9.313 | Reg loss: 0.010 | Tree loss: 9.313 | Accuracy: 0.363281 | 0.916 sec/iter
Epoch: 02 | Batch: 009 / 029 | Total loss: 9.297 | Reg loss: 0.011 | Tree loss: 9.297 | Accuracy: 0.367188 | 0.916 sec/iter
Epoch: 02 | Batch: 010 / 029 | Total loss: 9.291 | Reg loss: 0.011 | Tree loss: 9.291 | Accuracy: 0.312500 | 0.916 sec/iter
Epoch: 0

Epoch: 04 | Batch: 007 / 029 | Total loss: 9.029 | Reg loss: 0.016 | Tree loss: 9.029 | Accuracy: 0.376953 | 0.901 sec/iter
Epoch: 04 | Batch: 008 / 029 | Total loss: 9.031 | Reg loss: 0.017 | Tree loss: 9.031 | Accuracy: 0.363281 | 0.901 sec/iter
Epoch: 04 | Batch: 009 / 029 | Total loss: 9.035 | Reg loss: 0.017 | Tree loss: 9.035 | Accuracy: 0.314453 | 0.901 sec/iter
Epoch: 04 | Batch: 010 / 029 | Total loss: 9.004 | Reg loss: 0.017 | Tree loss: 9.004 | Accuracy: 0.330078 | 0.9 sec/iter
Epoch: 04 | Batch: 011 / 029 | Total loss: 8.990 | Reg loss: 0.018 | Tree loss: 8.990 | Accuracy: 0.355469 | 0.9 sec/iter
Epoch: 04 | Batch: 012 / 029 | Total loss: 8.961 | Reg loss: 0.018 | Tree loss: 8.961 | Accuracy: 0.355469 | 0.9 sec/iter
Epoch: 04 | Batch: 013 / 029 | Total loss: 8.950 | Reg loss: 0.019 | Tree loss: 8.950 | Accuracy: 0.355469 | 0.9 sec/iter
Epoch: 04 | Batch: 014 / 029 | Total loss: 8.946 | Reg loss: 0.019 | Tree loss: 8.946 | Accuracy: 0.367188 | 0.9 sec/iter
Epoch: 04 | Batch:

Epoch: 06 | Batch: 011 / 029 | Total loss: 8.561 | Reg loss: 0.023 | Tree loss: 8.561 | Accuracy: 0.380859 | 0.898 sec/iter
Epoch: 06 | Batch: 012 / 029 | Total loss: 8.541 | Reg loss: 0.024 | Tree loss: 8.541 | Accuracy: 0.347656 | 0.898 sec/iter
Epoch: 06 | Batch: 013 / 029 | Total loss: 8.520 | Reg loss: 0.024 | Tree loss: 8.520 | Accuracy: 0.339844 | 0.898 sec/iter
Epoch: 06 | Batch: 014 / 029 | Total loss: 8.498 | Reg loss: 0.024 | Tree loss: 8.498 | Accuracy: 0.333984 | 0.898 sec/iter
Epoch: 06 | Batch: 015 / 029 | Total loss: 8.469 | Reg loss: 0.025 | Tree loss: 8.469 | Accuracy: 0.343750 | 0.898 sec/iter
Epoch: 06 | Batch: 016 / 029 | Total loss: 8.454 | Reg loss: 0.025 | Tree loss: 8.454 | Accuracy: 0.353516 | 0.898 sec/iter
Epoch: 06 | Batch: 017 / 029 | Total loss: 8.435 | Reg loss: 0.026 | Tree loss: 8.435 | Accuracy: 0.345703 | 0.898 sec/iter
Epoch: 06 | Batch: 018 / 029 | Total loss: 8.432 | Reg loss: 0.026 | Tree loss: 8.432 | Accuracy: 0.365234 | 0.898 sec/iter
Epoch: 0

Epoch: 08 | Batch: 015 / 029 | Total loss: 7.950 | Reg loss: 0.028 | Tree loss: 7.950 | Accuracy: 0.363281 | 0.897 sec/iter
Epoch: 08 | Batch: 016 / 029 | Total loss: 7.949 | Reg loss: 0.028 | Tree loss: 7.949 | Accuracy: 0.371094 | 0.897 sec/iter
Epoch: 08 | Batch: 017 / 029 | Total loss: 7.916 | Reg loss: 0.028 | Tree loss: 7.916 | Accuracy: 0.365234 | 0.897 sec/iter
Epoch: 08 | Batch: 018 / 029 | Total loss: 7.935 | Reg loss: 0.029 | Tree loss: 7.935 | Accuracy: 0.332031 | 0.897 sec/iter
Epoch: 08 | Batch: 019 / 029 | Total loss: 7.903 | Reg loss: 0.029 | Tree loss: 7.903 | Accuracy: 0.361328 | 0.897 sec/iter
Epoch: 08 | Batch: 020 / 029 | Total loss: 7.874 | Reg loss: 0.029 | Tree loss: 7.874 | Accuracy: 0.320312 | 0.897 sec/iter
Epoch: 08 | Batch: 021 / 029 | Total loss: 7.850 | Reg loss: 0.029 | Tree loss: 7.850 | Accuracy: 0.376953 | 0.897 sec/iter
Epoch: 08 | Batch: 022 / 029 | Total loss: 7.799 | Reg loss: 0.030 | Tree loss: 7.799 | Accuracy: 0.341797 | 0.896 sec/iter
Epoch: 0

Epoch: 10 | Batch: 019 / 029 | Total loss: 7.326 | Reg loss: 0.030 | Tree loss: 7.326 | Accuracy: 0.373047 | 0.896 sec/iter
Epoch: 10 | Batch: 020 / 029 | Total loss: 7.305 | Reg loss: 0.030 | Tree loss: 7.305 | Accuracy: 0.349609 | 0.896 sec/iter
Epoch: 10 | Batch: 021 / 029 | Total loss: 7.291 | Reg loss: 0.030 | Tree loss: 7.291 | Accuracy: 0.330078 | 0.896 sec/iter
Epoch: 10 | Batch: 022 / 029 | Total loss: 7.248 | Reg loss: 0.031 | Tree loss: 7.248 | Accuracy: 0.382812 | 0.896 sec/iter
Epoch: 10 | Batch: 023 / 029 | Total loss: 7.215 | Reg loss: 0.031 | Tree loss: 7.215 | Accuracy: 0.392578 | 0.896 sec/iter
Epoch: 10 | Batch: 024 / 029 | Total loss: 7.190 | Reg loss: 0.031 | Tree loss: 7.190 | Accuracy: 0.386719 | 0.896 sec/iter
Epoch: 10 | Batch: 025 / 029 | Total loss: 7.164 | Reg loss: 0.032 | Tree loss: 7.164 | Accuracy: 0.353516 | 0.896 sec/iter
Epoch: 10 | Batch: 026 / 029 | Total loss: 7.156 | Reg loss: 0.032 | Tree loss: 7.156 | Accuracy: 0.363281 | 0.896 sec/iter
Epoch: 1

Epoch: 12 | Batch: 023 / 029 | Total loss: 6.615 | Reg loss: 0.034 | Tree loss: 6.615 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 12 | Batch: 024 / 029 | Total loss: 6.572 | Reg loss: 0.035 | Tree loss: 6.572 | Accuracy: 0.353516 | 0.895 sec/iter
Epoch: 12 | Batch: 025 / 029 | Total loss: 6.526 | Reg loss: 0.035 | Tree loss: 6.526 | Accuracy: 0.357422 | 0.895 sec/iter
Epoch: 12 | Batch: 026 / 029 | Total loss: 6.500 | Reg loss: 0.036 | Tree loss: 6.500 | Accuracy: 0.406250 | 0.895 sec/iter
Epoch: 12 | Batch: 027 / 029 | Total loss: 6.493 | Reg loss: 0.036 | Tree loss: 6.493 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 12 | Batch: 028 / 029 | Total loss: 6.488 | Reg loss: 0.037 | Tree loss: 6.488 | Accuracy: 0.356902 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 14 | Batch: 027 / 029 | Total loss: 5.861 | Reg loss: 0.039 | Tree loss: 5.861 | Accuracy: 0.392578 | 0.895 sec/iter
Epoch: 14 | Batch: 028 / 029 | Total loss: 5.867 | Reg loss: 0.039 | Tree loss: 5.867 | Accuracy: 0.313131 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 15 | Batch: 000 / 029 | Total loss: 6.434 | Reg loss: 0.031 | Tree loss: 6.434 | Accuracy: 0.371094 | 0.896 sec/iter
Epoch: 15 | Batch: 001 / 029 | Total loss: 6.449 | Reg loss: 0.031 | Tree loss: 6.449 | Accuracy: 0.349609 | 0.896 sec/iter
Epoch: 15 | Batch: 002 / 029 | Total loss: 6.377 | Reg loss: 0.031 | Tree loss: 6.377 | Accuracy: 0.371094 | 0.896 sec/iter
Epoch: 15 | Batch: 003 / 029 | Total loss: 6.362 | Reg loss: 0.032 | Tree loss: 6.362 | Ac

Epoch: 17 | Batch: 000 / 029 | Total loss: 5.917 | Reg loss: 0.034 | Tree loss: 5.917 | Accuracy: 0.353516 | 0.896 sec/iter
Epoch: 17 | Batch: 001 / 029 | Total loss: 5.839 | Reg loss: 0.034 | Tree loss: 5.839 | Accuracy: 0.367188 | 0.896 sec/iter
Epoch: 17 | Batch: 002 / 029 | Total loss: 5.832 | Reg loss: 0.034 | Tree loss: 5.832 | Accuracy: 0.378906 | 0.896 sec/iter
Epoch: 17 | Batch: 003 / 029 | Total loss: 5.784 | Reg loss: 0.034 | Tree loss: 5.784 | Accuracy: 0.353516 | 0.896 sec/iter
Epoch: 17 | Batch: 004 / 029 | Total loss: 5.727 | Reg loss: 0.034 | Tree loss: 5.727 | Accuracy: 0.376953 | 0.896 sec/iter
Epoch: 17 | Batch: 005 / 029 | Total loss: 5.694 | Reg loss: 0.035 | Tree loss: 5.694 | Accuracy: 0.351562 | 0.896 sec/iter
Epoch: 17 | Batch: 006 / 029 | Total loss: 5.655 | Reg loss: 0.035 | Tree loss: 5.655 | Accuracy: 0.365234 | 0.896 sec/iter
Epoch: 17 | Batch: 007 / 029 | Total loss: 5.622 | Reg loss: 0.035 | Tree loss: 5.622 | Accuracy: 0.341797 | 0.895 sec/iter
Epoch: 1

Epoch: 19 | Batch: 004 / 029 | Total loss: 5.124 | Reg loss: 0.037 | Tree loss: 5.124 | Accuracy: 0.396484 | 0.895 sec/iter
Epoch: 19 | Batch: 005 / 029 | Total loss: 5.120 | Reg loss: 0.037 | Tree loss: 5.120 | Accuracy: 0.376953 | 0.895 sec/iter
Epoch: 19 | Batch: 006 / 029 | Total loss: 5.124 | Reg loss: 0.037 | Tree loss: 5.124 | Accuracy: 0.359375 | 0.895 sec/iter
Epoch: 19 | Batch: 007 / 029 | Total loss: 5.085 | Reg loss: 0.037 | Tree loss: 5.085 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 19 | Batch: 008 / 029 | Total loss: 5.061 | Reg loss: 0.037 | Tree loss: 5.061 | Accuracy: 0.308594 | 0.895 sec/iter
Epoch: 19 | Batch: 009 / 029 | Total loss: 4.962 | Reg loss: 0.037 | Tree loss: 4.962 | Accuracy: 0.382812 | 0.895 sec/iter
Epoch: 19 | Batch: 010 / 029 | Total loss: 4.938 | Reg loss: 0.038 | Tree loss: 4.938 | Accuracy: 0.361328 | 0.895 sec/iter
Epoch: 19 | Batch: 011 / 029 | Total loss: 4.862 | Reg loss: 0.038 | Tree loss: 4.862 | Accuracy: 0.375000 | 0.895 sec/iter
Epoch: 1

Epoch: 21 | Batch: 008 / 029 | Total loss: 4.503 | Reg loss: 0.039 | Tree loss: 4.503 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 21 | Batch: 009 / 029 | Total loss: 4.464 | Reg loss: 0.039 | Tree loss: 4.464 | Accuracy: 0.347656 | 0.895 sec/iter
Epoch: 21 | Batch: 010 / 029 | Total loss: 4.451 | Reg loss: 0.040 | Tree loss: 4.451 | Accuracy: 0.402344 | 0.895 sec/iter
Epoch: 21 | Batch: 011 / 029 | Total loss: 4.367 | Reg loss: 0.040 | Tree loss: 4.367 | Accuracy: 0.355469 | 0.895 sec/iter
Epoch: 21 | Batch: 012 / 029 | Total loss: 4.405 | Reg loss: 0.040 | Tree loss: 4.405 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 21 | Batch: 013 / 029 | Total loss: 4.409 | Reg loss: 0.040 | Tree loss: 4.409 | Accuracy: 0.320312 | 0.895 sec/iter
Epoch: 21 | Batch: 014 / 029 | Total loss: 4.284 | Reg loss: 0.041 | Tree loss: 4.284 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 21 | Batch: 015 / 029 | Total loss: 4.257 | Reg loss: 0.041 | Tree loss: 4.257 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 2

Epoch: 23 | Batch: 012 / 029 | Total loss: 3.864 | Reg loss: 0.042 | Tree loss: 3.864 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 23 | Batch: 013 / 029 | Total loss: 3.773 | Reg loss: 0.042 | Tree loss: 3.773 | Accuracy: 0.414062 | 0.895 sec/iter
Epoch: 23 | Batch: 014 / 029 | Total loss: 3.821 | Reg loss: 0.042 | Tree loss: 3.821 | Accuracy: 0.406250 | 0.895 sec/iter
Epoch: 23 | Batch: 015 / 029 | Total loss: 3.834 | Reg loss: 0.042 | Tree loss: 3.834 | Accuracy: 0.380859 | 0.895 sec/iter
Epoch: 23 | Batch: 016 / 029 | Total loss: 3.739 | Reg loss: 0.043 | Tree loss: 3.739 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 23 | Batch: 017 / 029 | Total loss: 3.738 | Reg loss: 0.043 | Tree loss: 3.738 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 23 | Batch: 018 / 029 | Total loss: 3.712 | Reg loss: 0.043 | Tree loss: 3.712 | Accuracy: 0.335938 | 0.895 sec/iter
Epoch: 23 | Batch: 019 / 029 | Total loss: 3.700 | Reg loss: 0.043 | Tree loss: 3.700 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 2

Epoch: 25 | Batch: 016 / 029 | Total loss: 3.415 | Reg loss: 0.044 | Tree loss: 3.415 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 25 | Batch: 017 / 029 | Total loss: 3.373 | Reg loss: 0.044 | Tree loss: 3.373 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 25 | Batch: 018 / 029 | Total loss: 3.329 | Reg loss: 0.044 | Tree loss: 3.329 | Accuracy: 0.392578 | 0.895 sec/iter
Epoch: 25 | Batch: 019 / 029 | Total loss: 3.307 | Reg loss: 0.044 | Tree loss: 3.307 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 25 | Batch: 020 / 029 | Total loss: 3.261 | Reg loss: 0.045 | Tree loss: 3.261 | Accuracy: 0.361328 | 0.895 sec/iter
Epoch: 25 | Batch: 021 / 029 | Total loss: 3.281 | Reg loss: 0.045 | Tree loss: 3.281 | Accuracy: 0.396484 | 0.895 sec/iter
Epoch: 25 | Batch: 022 / 029 | Total loss: 3.269 | Reg loss: 0.045 | Tree loss: 3.269 | Accuracy: 0.375000 | 0.895 sec/iter
Epoch: 25 | Batch: 023 / 029 | Total loss: 3.213 | Reg loss: 0.045 | Tree loss: 3.213 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 2

Epoch: 27 | Batch: 020 / 029 | Total loss: 3.023 | Reg loss: 0.045 | Tree loss: 3.023 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 27 | Batch: 021 / 029 | Total loss: 2.957 | Reg loss: 0.045 | Tree loss: 2.957 | Accuracy: 0.388672 | 0.895 sec/iter
Epoch: 27 | Batch: 022 / 029 | Total loss: 3.022 | Reg loss: 0.046 | Tree loss: 3.022 | Accuracy: 0.324219 | 0.895 sec/iter
Epoch: 27 | Batch: 023 / 029 | Total loss: 2.982 | Reg loss: 0.046 | Tree loss: 2.982 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 27 | Batch: 024 / 029 | Total loss: 2.949 | Reg loss: 0.046 | Tree loss: 2.949 | Accuracy: 0.376953 | 0.895 sec/iter
Epoch: 27 | Batch: 025 / 029 | Total loss: 2.879 | Reg loss: 0.046 | Tree loss: 2.879 | Accuracy: 0.375000 | 0.895 sec/iter
Epoch: 27 | Batch: 026 / 029 | Total loss: 2.818 | Reg loss: 0.046 | Tree loss: 2.818 | Accuracy: 0.396484 | 0.895 sec/iter
Epoch: 27 | Batch: 027 / 029 | Total loss: 2.791 | Reg loss: 0.046 | Tree loss: 2.791 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 2

Epoch: 29 | Batch: 024 / 029 | Total loss: 2.647 | Reg loss: 0.046 | Tree loss: 2.647 | Accuracy: 0.386719 | 0.895 sec/iter
Epoch: 29 | Batch: 025 / 029 | Total loss: 2.654 | Reg loss: 0.046 | Tree loss: 2.654 | Accuracy: 0.400391 | 0.895 sec/iter
Epoch: 29 | Batch: 026 / 029 | Total loss: 2.633 | Reg loss: 0.047 | Tree loss: 2.633 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 29 | Batch: 027 / 029 | Total loss: 2.635 | Reg loss: 0.047 | Tree loss: 2.635 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 29 | Batch: 028 / 029 | Total loss: 2.583 | Reg loss: 0.047 | Tree loss: 2.583 | Accuracy: 0.353535 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 30 | Batch: 000 / 029 | Total loss: 3.174 | Reg loss: 0.044 | Tree loss: 3.174 | Ac

Epoch: 31 | Batch: 028 / 029 | Total loss: 2.416 | Reg loss: 0.047 | Tree loss: 2.416 | Accuracy: 0.346801 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 32 | Batch: 000 / 029 | Total loss: 2.917 | Reg loss: 0.044 | Tree loss: 2.917 | Accuracy: 0.361328 | 0.896 sec/iter
Epoch: 32 | Batch: 001 / 029 | Total loss: 2.925 | Reg loss: 0.044 | Tree loss: 2.925 | Accuracy: 0.382812 | 0.896 sec/iter
Epoch: 32 | Batch: 002 / 029 | Total loss: 2.889 | Reg loss: 0.044 | Tree loss: 2.889 | Accuracy: 0.378906 | 0.896 sec/iter
Epoch: 32 | Batch: 003 / 029 | Total loss: 2.863 | Reg loss: 0.044 | Tree loss: 2.863 | Accuracy: 0.382812 | 0.896 sec/iter
Epoch: 32 | Batch: 004 / 029 | Total loss: 2.855 | Reg loss: 0.044 | Tree loss: 2.855 | Ac

Epoch: 34 | Batch: 001 / 029 | Total loss: 2.697 | Reg loss: 0.045 | Tree loss: 2.697 | Accuracy: 0.412109 | 0.896 sec/iter
Epoch: 34 | Batch: 002 / 029 | Total loss: 2.714 | Reg loss: 0.045 | Tree loss: 2.714 | Accuracy: 0.378906 | 0.896 sec/iter
Epoch: 34 | Batch: 003 / 029 | Total loss: 2.652 | Reg loss: 0.045 | Tree loss: 2.652 | Accuracy: 0.410156 | 0.896 sec/iter
Epoch: 34 | Batch: 004 / 029 | Total loss: 2.645 | Reg loss: 0.045 | Tree loss: 2.645 | Accuracy: 0.373047 | 0.896 sec/iter
Epoch: 34 | Batch: 005 / 029 | Total loss: 2.566 | Reg loss: 0.045 | Tree loss: 2.566 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 34 | Batch: 006 / 029 | Total loss: 2.647 | Reg loss: 0.045 | Tree loss: 2.647 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 34 | Batch: 007 / 029 | Total loss: 2.598 | Reg loss: 0.045 | Tree loss: 2.598 | Accuracy: 0.384766 | 0.895 sec/iter
Epoch: 34 | Batch: 008 / 029 | Total loss: 2.551 | Reg loss: 0.045 | Tree loss: 2.551 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 3

Epoch: 36 | Batch: 005 / 029 | Total loss: 2.495 | Reg loss: 0.045 | Tree loss: 2.495 | Accuracy: 0.357422 | 0.896 sec/iter
Epoch: 36 | Batch: 006 / 029 | Total loss: 2.489 | Reg loss: 0.045 | Tree loss: 2.489 | Accuracy: 0.351562 | 0.895 sec/iter
Epoch: 36 | Batch: 007 / 029 | Total loss: 2.449 | Reg loss: 0.045 | Tree loss: 2.449 | Accuracy: 0.353516 | 0.895 sec/iter
Epoch: 36 | Batch: 008 / 029 | Total loss: 2.458 | Reg loss: 0.045 | Tree loss: 2.458 | Accuracy: 0.353516 | 0.895 sec/iter
Epoch: 36 | Batch: 009 / 029 | Total loss: 2.331 | Reg loss: 0.045 | Tree loss: 2.331 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 36 | Batch: 010 / 029 | Total loss: 2.315 | Reg loss: 0.045 | Tree loss: 2.315 | Accuracy: 0.382812 | 0.895 sec/iter
Epoch: 36 | Batch: 011 / 029 | Total loss: 2.320 | Reg loss: 0.045 | Tree loss: 2.320 | Accuracy: 0.400391 | 0.895 sec/iter
Epoch: 36 | Batch: 012 / 029 | Total loss: 2.328 | Reg loss: 0.045 | Tree loss: 2.328 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 3

Epoch: 38 | Batch: 009 / 029 | Total loss: 2.355 | Reg loss: 0.045 | Tree loss: 2.355 | Accuracy: 0.343750 | 0.895 sec/iter
Epoch: 38 | Batch: 010 / 029 | Total loss: 2.281 | Reg loss: 0.045 | Tree loss: 2.281 | Accuracy: 0.343750 | 0.895 sec/iter
Epoch: 38 | Batch: 011 / 029 | Total loss: 2.303 | Reg loss: 0.045 | Tree loss: 2.303 | Accuracy: 0.353516 | 0.895 sec/iter
Epoch: 38 | Batch: 012 / 029 | Total loss: 2.235 | Reg loss: 0.045 | Tree loss: 2.235 | Accuracy: 0.386719 | 0.895 sec/iter
Epoch: 38 | Batch: 013 / 029 | Total loss: 2.179 | Reg loss: 0.045 | Tree loss: 2.179 | Accuracy: 0.380859 | 0.895 sec/iter
Epoch: 38 | Batch: 014 / 029 | Total loss: 2.247 | Reg loss: 0.045 | Tree loss: 2.247 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 38 | Batch: 015 / 029 | Total loss: 2.167 | Reg loss: 0.045 | Tree loss: 2.167 | Accuracy: 0.371094 | 0.895 sec/iter
Epoch: 38 | Batch: 016 / 029 | Total loss: 2.199 | Reg loss: 0.045 | Tree loss: 2.199 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 3

Epoch: 40 | Batch: 013 / 029 | Total loss: 2.138 | Reg loss: 0.045 | Tree loss: 2.138 | Accuracy: 0.384766 | 0.895 sec/iter
Epoch: 40 | Batch: 014 / 029 | Total loss: 2.105 | Reg loss: 0.045 | Tree loss: 2.105 | Accuracy: 0.375000 | 0.895 sec/iter
Epoch: 40 | Batch: 015 / 029 | Total loss: 2.078 | Reg loss: 0.045 | Tree loss: 2.078 | Accuracy: 0.384766 | 0.895 sec/iter
Epoch: 40 | Batch: 016 / 029 | Total loss: 2.043 | Reg loss: 0.045 | Tree loss: 2.043 | Accuracy: 0.392578 | 0.895 sec/iter
Epoch: 40 | Batch: 017 / 029 | Total loss: 2.100 | Reg loss: 0.045 | Tree loss: 2.100 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 40 | Batch: 018 / 029 | Total loss: 2.037 | Reg loss: 0.045 | Tree loss: 2.037 | Accuracy: 0.333984 | 0.895 sec/iter
Epoch: 40 | Batch: 019 / 029 | Total loss: 2.031 | Reg loss: 0.045 | Tree loss: 2.031 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 40 | Batch: 020 / 029 | Total loss: 2.031 | Reg loss: 0.045 | Tree loss: 2.031 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 4

Epoch: 42 | Batch: 017 / 029 | Total loss: 1.962 | Reg loss: 0.045 | Tree loss: 1.962 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 42 | Batch: 018 / 029 | Total loss: 1.985 | Reg loss: 0.045 | Tree loss: 1.985 | Accuracy: 0.376953 | 0.895 sec/iter
Epoch: 42 | Batch: 019 / 029 | Total loss: 1.969 | Reg loss: 0.045 | Tree loss: 1.969 | Accuracy: 0.384766 | 0.895 sec/iter
Epoch: 42 | Batch: 020 / 029 | Total loss: 1.953 | Reg loss: 0.045 | Tree loss: 1.953 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 42 | Batch: 021 / 029 | Total loss: 1.938 | Reg loss: 0.045 | Tree loss: 1.938 | Accuracy: 0.361328 | 0.895 sec/iter
Epoch: 42 | Batch: 022 / 029 | Total loss: 1.978 | Reg loss: 0.045 | Tree loss: 1.978 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 42 | Batch: 023 / 029 | Total loss: 1.913 | Reg loss: 0.045 | Tree loss: 1.913 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 42 | Batch: 024 / 029 | Total loss: 1.948 | Reg loss: 0.045 | Tree loss: 1.948 | Accuracy: 0.322266 | 0.895 sec/iter
Epoch: 4

Epoch: 44 | Batch: 021 / 029 | Total loss: 1.835 | Reg loss: 0.044 | Tree loss: 1.835 | Accuracy: 0.414062 | 0.895 sec/iter
Epoch: 44 | Batch: 022 / 029 | Total loss: 1.868 | Reg loss: 0.045 | Tree loss: 1.868 | Accuracy: 0.388672 | 0.895 sec/iter
Epoch: 44 | Batch: 023 / 029 | Total loss: 1.834 | Reg loss: 0.045 | Tree loss: 1.834 | Accuracy: 0.400391 | 0.895 sec/iter
Epoch: 44 | Batch: 024 / 029 | Total loss: 1.900 | Reg loss: 0.045 | Tree loss: 1.900 | Accuracy: 0.351562 | 0.895 sec/iter
Epoch: 44 | Batch: 025 / 029 | Total loss: 1.855 | Reg loss: 0.045 | Tree loss: 1.855 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 44 | Batch: 026 / 029 | Total loss: 1.814 | Reg loss: 0.045 | Tree loss: 1.814 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 44 | Batch: 027 / 029 | Total loss: 1.831 | Reg loss: 0.045 | Tree loss: 1.831 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 44 | Batch: 028 / 029 | Total loss: 1.801 | Reg loss: 0.045 | Tree loss: 1.801 | Accuracy: 0.390572 | 0.895 sec/iter
Average 

Epoch: 46 | Batch: 025 / 029 | Total loss: 1.800 | Reg loss: 0.044 | Tree loss: 1.800 | Accuracy: 0.333984 | 0.895 sec/iter
Epoch: 46 | Batch: 026 / 029 | Total loss: 1.774 | Reg loss: 0.044 | Tree loss: 1.774 | Accuracy: 0.371094 | 0.895 sec/iter
Epoch: 46 | Batch: 027 / 029 | Total loss: 1.786 | Reg loss: 0.044 | Tree loss: 1.786 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 46 | Batch: 028 / 029 | Total loss: 1.685 | Reg loss: 0.044 | Tree loss: 1.685 | Accuracy: 0.410774 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 47 | Batch: 000 / 029 | Total loss: 2.161 | Reg loss: 0.043 | Tree loss: 2.161 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 47 | Batch: 001 / 029 | Total loss: 2.058 | Reg loss: 0.043 | Tree loss: 2.058 | Ac

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 49 | Batch: 000 / 029 | Total loss: 2.086 | Reg loss: 0.042 | Tree loss: 2.086 | Accuracy: 0.343750 | 0.895 sec/iter
Epoch: 49 | Batch: 001 / 029 | Total loss: 2.055 | Reg loss: 0.042 | Tree loss: 2.055 | Accuracy: 0.394531 | 0.895 sec/iter
Epoch: 49 | Batch: 002 / 029 | Total loss: 2.040 | Reg loss: 0.042 | Tree loss: 2.040 | Accuracy: 0.363281 | 0.895 sec/iter
Epoch: 49 | Batch: 003 / 029 | Total loss: 2.035 | Reg loss: 0.042 | Tree loss: 2.035 | Accuracy: 0.375000 | 0.895 sec/iter
Epoch: 49 | Batch: 004 / 029 | Total loss: 2.018 | Reg loss: 0.042 | Tree loss: 2.018 | Accuracy: 0.380859 | 0.895 sec/iter
Epoch: 49 | Batch: 005 / 029 | Total loss: 1.989 | Reg loss: 0.042 | Tree loss: 1.989 | Ac

Epoch: 51 | Batch: 002 / 029 | Total loss: 1.969 | Reg loss: 0.042 | Tree loss: 1.969 | Accuracy: 0.355469 | 0.895 sec/iter
Epoch: 51 | Batch: 003 / 029 | Total loss: 1.968 | Reg loss: 0.042 | Tree loss: 1.968 | Accuracy: 0.345703 | 0.895 sec/iter
Epoch: 51 | Batch: 004 / 029 | Total loss: 1.967 | Reg loss: 0.042 | Tree loss: 1.967 | Accuracy: 0.410156 | 0.895 sec/iter
Epoch: 51 | Batch: 005 / 029 | Total loss: 1.965 | Reg loss: 0.042 | Tree loss: 1.965 | Accuracy: 0.400391 | 0.895 sec/iter
Epoch: 51 | Batch: 006 / 029 | Total loss: 1.911 | Reg loss: 0.042 | Tree loss: 1.911 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 51 | Batch: 007 / 029 | Total loss: 1.899 | Reg loss: 0.042 | Tree loss: 1.899 | Accuracy: 0.351562 | 0.895 sec/iter
Epoch: 51 | Batch: 008 / 029 | Total loss: 1.909 | Reg loss: 0.042 | Tree loss: 1.909 | Accuracy: 0.357422 | 0.895 sec/iter
Epoch: 51 | Batch: 009 / 029 | Total loss: 1.877 | Reg loss: 0.042 | Tree loss: 1.877 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 5

Epoch: 53 | Batch: 006 / 029 | Total loss: 1.892 | Reg loss: 0.041 | Tree loss: 1.892 | Accuracy: 0.371094 | 0.895 sec/iter
Epoch: 53 | Batch: 007 / 029 | Total loss: 1.871 | Reg loss: 0.041 | Tree loss: 1.871 | Accuracy: 0.375000 | 0.895 sec/iter
Epoch: 53 | Batch: 008 / 029 | Total loss: 1.866 | Reg loss: 0.041 | Tree loss: 1.866 | Accuracy: 0.351562 | 0.895 sec/iter
Epoch: 53 | Batch: 009 / 029 | Total loss: 1.884 | Reg loss: 0.041 | Tree loss: 1.884 | Accuracy: 0.388672 | 0.895 sec/iter
Epoch: 53 | Batch: 010 / 029 | Total loss: 1.826 | Reg loss: 0.041 | Tree loss: 1.826 | Accuracy: 0.357422 | 0.895 sec/iter
Epoch: 53 | Batch: 011 / 029 | Total loss: 1.848 | Reg loss: 0.042 | Tree loss: 1.848 | Accuracy: 0.347656 | 0.895 sec/iter
Epoch: 53 | Batch: 012 / 029 | Total loss: 1.767 | Reg loss: 0.042 | Tree loss: 1.767 | Accuracy: 0.378906 | 0.895 sec/iter
Epoch: 53 | Batch: 013 / 029 | Total loss: 1.779 | Reg loss: 0.042 | Tree loss: 1.779 | Accuracy: 0.367188 | 0.895 sec/iter
Epoch: 5

Epoch: 55 | Batch: 010 / 029 | Total loss: 1.843 | Reg loss: 0.041 | Tree loss: 1.843 | Accuracy: 0.396484 | 0.895 sec/iter
Epoch: 55 | Batch: 011 / 029 | Total loss: 1.756 | Reg loss: 0.041 | Tree loss: 1.756 | Accuracy: 0.382812 | 0.895 sec/iter
Epoch: 55 | Batch: 012 / 029 | Total loss: 1.813 | Reg loss: 0.041 | Tree loss: 1.813 | Accuracy: 0.343750 | 0.895 sec/iter
Epoch: 55 | Batch: 013 / 029 | Total loss: 1.784 | Reg loss: 0.041 | Tree loss: 1.784 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 55 | Batch: 014 / 029 | Total loss: 1.803 | Reg loss: 0.041 | Tree loss: 1.803 | Accuracy: 0.343750 | 0.895 sec/iter
Epoch: 55 | Batch: 015 / 029 | Total loss: 1.748 | Reg loss: 0.041 | Tree loss: 1.748 | Accuracy: 0.394531 | 0.895 sec/iter
Epoch: 55 | Batch: 016 / 029 | Total loss: 1.727 | Reg loss: 0.041 | Tree loss: 1.727 | Accuracy: 0.404297 | 0.895 sec/iter
Epoch: 55 | Batch: 017 / 029 | Total loss: 1.772 | Reg loss: 0.041 | Tree loss: 1.772 | Accuracy: 0.353516 | 0.895 sec/iter
Epoch: 5

Epoch: 57 | Batch: 014 / 029 | Total loss: 1.764 | Reg loss: 0.041 | Tree loss: 1.764 | Accuracy: 0.343750 | 0.895 sec/iter
Epoch: 57 | Batch: 015 / 029 | Total loss: 1.693 | Reg loss: 0.041 | Tree loss: 1.693 | Accuracy: 0.388672 | 0.895 sec/iter
Epoch: 57 | Batch: 016 / 029 | Total loss: 1.690 | Reg loss: 0.041 | Tree loss: 1.690 | Accuracy: 0.353516 | 0.895 sec/iter
Epoch: 57 | Batch: 017 / 029 | Total loss: 1.737 | Reg loss: 0.041 | Tree loss: 1.737 | Accuracy: 0.363281 | 0.895 sec/iter
Epoch: 57 | Batch: 018 / 029 | Total loss: 1.651 | Reg loss: 0.041 | Tree loss: 1.651 | Accuracy: 0.421875 | 0.895 sec/iter
Epoch: 57 | Batch: 019 / 029 | Total loss: 1.690 | Reg loss: 0.041 | Tree loss: 1.690 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 57 | Batch: 020 / 029 | Total loss: 1.693 | Reg loss: 0.041 | Tree loss: 1.693 | Accuracy: 0.361328 | 0.895 sec/iter
Epoch: 57 | Batch: 021 / 029 | Total loss: 1.757 | Reg loss: 0.041 | Tree loss: 1.757 | Accuracy: 0.337891 | 0.895 sec/iter
Epoch: 5

Epoch: 59 | Batch: 018 / 029 | Total loss: 1.651 | Reg loss: 0.041 | Tree loss: 1.651 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 59 | Batch: 019 / 029 | Total loss: 1.650 | Reg loss: 0.041 | Tree loss: 1.650 | Accuracy: 0.380859 | 0.895 sec/iter
Epoch: 59 | Batch: 020 / 029 | Total loss: 1.641 | Reg loss: 0.041 | Tree loss: 1.641 | Accuracy: 0.398438 | 0.895 sec/iter
Epoch: 59 | Batch: 021 / 029 | Total loss: 1.612 | Reg loss: 0.041 | Tree loss: 1.612 | Accuracy: 0.351562 | 0.895 sec/iter
Epoch: 59 | Batch: 022 / 029 | Total loss: 1.608 | Reg loss: 0.041 | Tree loss: 1.608 | Accuracy: 0.380859 | 0.895 sec/iter
Epoch: 59 | Batch: 023 / 029 | Total loss: 1.663 | Reg loss: 0.041 | Tree loss: 1.663 | Accuracy: 0.371094 | 0.895 sec/iter
Epoch: 59 | Batch: 024 / 029 | Total loss: 1.612 | Reg loss: 0.041 | Tree loss: 1.612 | Accuracy: 0.351562 | 0.895 sec/iter
Epoch: 59 | Batch: 025 / 029 | Total loss: 1.550 | Reg loss: 0.041 | Tree loss: 1.550 | Accuracy: 0.386719 | 0.895 sec/iter
Epoch: 5

Epoch: 61 | Batch: 022 / 029 | Total loss: 1.631 | Reg loss: 0.041 | Tree loss: 1.631 | Accuracy: 0.341797 | 0.895 sec/iter
Epoch: 61 | Batch: 023 / 029 | Total loss: 1.560 | Reg loss: 0.041 | Tree loss: 1.560 | Accuracy: 0.396484 | 0.895 sec/iter
Epoch: 61 | Batch: 024 / 029 | Total loss: 1.571 | Reg loss: 0.041 | Tree loss: 1.571 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 61 | Batch: 025 / 029 | Total loss: 1.580 | Reg loss: 0.041 | Tree loss: 1.580 | Accuracy: 0.355469 | 0.895 sec/iter
Epoch: 61 | Batch: 026 / 029 | Total loss: 1.580 | Reg loss: 0.041 | Tree loss: 1.580 | Accuracy: 0.363281 | 0.895 sec/iter
Epoch: 61 | Batch: 027 / 029 | Total loss: 1.589 | Reg loss: 0.041 | Tree loss: 1.589 | Accuracy: 0.363281 | 0.895 sec/iter
Epoch: 61 | Batch: 028 / 029 | Total loss: 1.606 | Reg loss: 0.041 | Tree loss: 1.606 | Accuracy: 0.400673 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 63 | Batch: 026 / 029 | Total loss: 1.558 | Reg loss: 0.041 | Tree loss: 1.558 | Accuracy: 0.416016 | 0.895 sec/iter
Epoch: 63 | Batch: 027 / 029 | Total loss: 1.584 | Reg loss: 0.041 | Tree loss: 1.584 | Accuracy: 0.345703 | 0.895 sec/iter
Epoch: 63 | Batch: 028 / 029 | Total loss: 1.540 | Reg loss: 0.041 | Tree loss: 1.540 | Accuracy: 0.343434 | 0.895 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 64 | Batch: 000 / 029 | Total loss: 1.871 | Reg loss: 0.040 | Tree loss: 1.871 | Accuracy: 0.355469 | 0.895 sec/iter
Epoch: 64 | Batch: 001 / 029 | Total loss: 1.868 | Reg loss: 0.040 | Tree loss: 1.868 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 64 | Batch: 002 / 029 | Total loss: 1.806 | Reg loss: 0.040 | Tree loss: 1.806 | Ac

Epoch: 66 | Batch: 000 / 029 | Total loss: 1.846 | Reg loss: 0.040 | Tree loss: 1.846 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 66 | Batch: 001 / 029 | Total loss: 1.848 | Reg loss: 0.040 | Tree loss: 1.848 | Accuracy: 0.400391 | 0.895 sec/iter
Epoch: 66 | Batch: 002 / 029 | Total loss: 1.821 | Reg loss: 0.040 | Tree loss: 1.821 | Accuracy: 0.376953 | 0.895 sec/iter
Epoch: 66 | Batch: 003 / 029 | Total loss: 1.845 | Reg loss: 0.040 | Tree loss: 1.845 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 66 | Batch: 004 / 029 | Total loss: 1.766 | Reg loss: 0.040 | Tree loss: 1.766 | Accuracy: 0.341797 | 0.895 sec/iter
Epoch: 66 | Batch: 005 / 029 | Total loss: 1.754 | Reg loss: 0.040 | Tree loss: 1.754 | Accuracy: 0.349609 | 0.895 sec/iter
Epoch: 66 | Batch: 006 / 029 | Total loss: 1.736 | Reg loss: 0.040 | Tree loss: 1.736 | Accuracy: 0.361328 | 0.895 sec/iter
Epoch: 66 | Batch: 007 / 029 | Total loss: 1.845 | Reg loss: 0.040 | Tree loss: 1.845 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 6

Epoch: 68 | Batch: 004 / 029 | Total loss: 1.817 | Reg loss: 0.040 | Tree loss: 1.817 | Accuracy: 0.390625 | 0.895 sec/iter
Epoch: 68 | Batch: 005 / 029 | Total loss: 1.806 | Reg loss: 0.040 | Tree loss: 1.806 | Accuracy: 0.388672 | 0.895 sec/iter
Epoch: 68 | Batch: 006 / 029 | Total loss: 1.779 | Reg loss: 0.040 | Tree loss: 1.779 | Accuracy: 0.376953 | 0.895 sec/iter
Epoch: 68 | Batch: 007 / 029 | Total loss: 1.765 | Reg loss: 0.040 | Tree loss: 1.765 | Accuracy: 0.330078 | 0.895 sec/iter
Epoch: 68 | Batch: 008 / 029 | Total loss: 1.706 | Reg loss: 0.040 | Tree loss: 1.706 | Accuracy: 0.400391 | 0.895 sec/iter
Epoch: 68 | Batch: 009 / 029 | Total loss: 1.628 | Reg loss: 0.040 | Tree loss: 1.628 | Accuracy: 0.384766 | 0.895 sec/iter
Epoch: 68 | Batch: 010 / 029 | Total loss: 1.696 | Reg loss: 0.040 | Tree loss: 1.696 | Accuracy: 0.365234 | 0.895 sec/iter
Epoch: 68 | Batch: 011 / 029 | Total loss: 1.667 | Reg loss: 0.040 | Tree loss: 1.667 | Accuracy: 0.398438 | 0.895 sec/iter
Epoch: 6

Epoch: 70 | Batch: 008 / 029 | Total loss: 1.746 | Reg loss: 0.040 | Tree loss: 1.746 | Accuracy: 0.382812 | 0.895 sec/iter
Epoch: 70 | Batch: 009 / 029 | Total loss: 1.723 | Reg loss: 0.040 | Tree loss: 1.723 | Accuracy: 0.339844 | 0.895 sec/iter
Epoch: 70 | Batch: 010 / 029 | Total loss: 1.635 | Reg loss: 0.040 | Tree loss: 1.635 | Accuracy: 0.408203 | 0.895 sec/iter
Epoch: 70 | Batch: 011 / 029 | Total loss: 1.666 | Reg loss: 0.040 | Tree loss: 1.666 | Accuracy: 0.324219 | 0.895 sec/iter
Epoch: 70 | Batch: 012 / 029 | Total loss: 1.667 | Reg loss: 0.040 | Tree loss: 1.667 | Accuracy: 0.337891 | 0.895 sec/iter
Epoch: 70 | Batch: 013 / 029 | Total loss: 1.669 | Reg loss: 0.040 | Tree loss: 1.669 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 70 | Batch: 014 / 029 | Total loss: 1.674 | Reg loss: 0.040 | Tree loss: 1.674 | Accuracy: 0.369141 | 0.895 sec/iter
Epoch: 70 | Batch: 015 / 029 | Total loss: 1.621 | Reg loss: 0.040 | Tree loss: 1.621 | Accuracy: 0.390625 | 0.895 sec/iter
Epoch: 7

Epoch: 72 | Batch: 012 / 029 | Total loss: 1.640 | Reg loss: 0.040 | Tree loss: 1.640 | Accuracy: 0.357422 | 0.895 sec/iter
Epoch: 72 | Batch: 013 / 029 | Total loss: 1.609 | Reg loss: 0.040 | Tree loss: 1.609 | Accuracy: 0.402344 | 0.895 sec/iter
Epoch: 72 | Batch: 014 / 029 | Total loss: 1.586 | Reg loss: 0.040 | Tree loss: 1.586 | Accuracy: 0.373047 | 0.895 sec/iter
Epoch: 72 | Batch: 015 / 029 | Total loss: 1.624 | Reg loss: 0.040 | Tree loss: 1.624 | Accuracy: 0.355469 | 0.895 sec/iter
Epoch: 72 | Batch: 016 / 029 | Total loss: 1.619 | Reg loss: 0.040 | Tree loss: 1.619 | Accuracy: 0.324219 | 0.894 sec/iter
Epoch: 72 | Batch: 017 / 029 | Total loss: 1.538 | Reg loss: 0.040 | Tree loss: 1.538 | Accuracy: 0.386719 | 0.894 sec/iter
Epoch: 72 | Batch: 018 / 029 | Total loss: 1.558 | Reg loss: 0.040 | Tree loss: 1.558 | Accuracy: 0.388672 | 0.894 sec/iter
Epoch: 72 | Batch: 019 / 029 | Total loss: 1.583 | Reg loss: 0.040 | Tree loss: 1.583 | Accuracy: 0.349609 | 0.894 sec/iter
Epoch: 7

Epoch: 74 | Batch: 016 / 029 | Total loss: 1.562 | Reg loss: 0.040 | Tree loss: 1.562 | Accuracy: 0.345703 | 0.894 sec/iter
Epoch: 74 | Batch: 017 / 029 | Total loss: 1.607 | Reg loss: 0.040 | Tree loss: 1.607 | Accuracy: 0.330078 | 0.894 sec/iter
Epoch: 74 | Batch: 018 / 029 | Total loss: 1.641 | Reg loss: 0.040 | Tree loss: 1.641 | Accuracy: 0.367188 | 0.894 sec/iter
Epoch: 74 | Batch: 019 / 029 | Total loss: 1.593 | Reg loss: 0.040 | Tree loss: 1.593 | Accuracy: 0.367188 | 0.894 sec/iter
Epoch: 74 | Batch: 020 / 029 | Total loss: 1.539 | Reg loss: 0.040 | Tree loss: 1.539 | Accuracy: 0.369141 | 0.894 sec/iter
Epoch: 74 | Batch: 021 / 029 | Total loss: 1.537 | Reg loss: 0.040 | Tree loss: 1.537 | Accuracy: 0.351562 | 0.894 sec/iter
Epoch: 74 | Batch: 022 / 029 | Total loss: 1.540 | Reg loss: 0.040 | Tree loss: 1.540 | Accuracy: 0.378906 | 0.894 sec/iter
Epoch: 74 | Batch: 023 / 029 | Total loss: 1.525 | Reg loss: 0.040 | Tree loss: 1.525 | Accuracy: 0.388672 | 0.894 sec/iter
Epoch: 7

Epoch: 76 | Batch: 020 / 029 | Total loss: 1.506 | Reg loss: 0.040 | Tree loss: 1.506 | Accuracy: 0.367188 | 0.894 sec/iter
Epoch: 76 | Batch: 021 / 029 | Total loss: 1.617 | Reg loss: 0.040 | Tree loss: 1.617 | Accuracy: 0.375000 | 0.894 sec/iter
Epoch: 76 | Batch: 022 / 029 | Total loss: 1.583 | Reg loss: 0.040 | Tree loss: 1.583 | Accuracy: 0.337891 | 0.894 sec/iter
Epoch: 76 | Batch: 023 / 029 | Total loss: 1.475 | Reg loss: 0.040 | Tree loss: 1.475 | Accuracy: 0.404297 | 0.894 sec/iter
Epoch: 76 | Batch: 024 / 029 | Total loss: 1.533 | Reg loss: 0.040 | Tree loss: 1.533 | Accuracy: 0.357422 | 0.894 sec/iter
Epoch: 76 | Batch: 025 / 029 | Total loss: 1.499 | Reg loss: 0.040 | Tree loss: 1.499 | Accuracy: 0.351562 | 0.894 sec/iter
Epoch: 76 | Batch: 026 / 029 | Total loss: 1.497 | Reg loss: 0.040 | Tree loss: 1.497 | Accuracy: 0.376953 | 0.894 sec/iter
Epoch: 76 | Batch: 027 / 029 | Total loss: 1.502 | Reg loss: 0.040 | Tree loss: 1.502 | Accuracy: 0.404297 | 0.894 sec/iter
Epoch: 7

Epoch: 78 | Batch: 024 / 029 | Total loss: 1.510 | Reg loss: 0.040 | Tree loss: 1.510 | Accuracy: 0.384766 | 0.894 sec/iter
Epoch: 78 | Batch: 025 / 029 | Total loss: 1.513 | Reg loss: 0.040 | Tree loss: 1.513 | Accuracy: 0.355469 | 0.894 sec/iter
Epoch: 78 | Batch: 026 / 029 | Total loss: 1.461 | Reg loss: 0.040 | Tree loss: 1.461 | Accuracy: 0.386719 | 0.894 sec/iter
Epoch: 78 | Batch: 027 / 029 | Total loss: 1.511 | Reg loss: 0.040 | Tree loss: 1.511 | Accuracy: 0.347656 | 0.894 sec/iter
Epoch: 78 | Batch: 028 / 029 | Total loss: 1.501 | Reg loss: 0.040 | Tree loss: 1.501 | Accuracy: 0.387205 | 0.894 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 79 | Batch: 000 / 029 | Total loss: 1.773 | Reg loss: 0.039 | Tree loss: 1.773 | Ac

Epoch: 80 | Batch: 028 / 029 | Total loss: 1.498 | Reg loss: 0.040 | Tree loss: 1.498 | Accuracy: 0.377104 | 0.894 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 81 | Batch: 000 / 029 | Total loss: 1.850 | Reg loss: 0.039 | Tree loss: 1.850 | Accuracy: 0.376953 | 0.894 sec/iter
Epoch: 81 | Batch: 001 / 029 | Total loss: 1.803 | Reg loss: 0.039 | Tree loss: 1.803 | Accuracy: 0.332031 | 0.894 sec/iter
Epoch: 81 | Batch: 002 / 029 | Total loss: 1.787 | Reg loss: 0.039 | Tree loss: 1.787 | Accuracy: 0.341797 | 0.894 sec/iter
Epoch: 81 | Batch: 003 / 029 | Total loss: 1.754 | Reg loss: 0.039 | Tree loss: 1.754 | Accuracy: 0.384766 | 0.894 sec/iter
Epoch: 81 | Batch: 004 / 029 | Total loss: 1.823 | Reg loss: 0.039 | Tree loss: 1.823 | Ac

Epoch: 83 | Batch: 001 / 029 | Total loss: 1.819 | Reg loss: 0.038 | Tree loss: 1.819 | Accuracy: 0.335938 | 0.894 sec/iter
Epoch: 83 | Batch: 002 / 029 | Total loss: 1.850 | Reg loss: 0.038 | Tree loss: 1.850 | Accuracy: 0.351562 | 0.894 sec/iter
Epoch: 83 | Batch: 003 / 029 | Total loss: 1.805 | Reg loss: 0.038 | Tree loss: 1.805 | Accuracy: 0.351562 | 0.894 sec/iter
Epoch: 83 | Batch: 004 / 029 | Total loss: 1.756 | Reg loss: 0.038 | Tree loss: 1.756 | Accuracy: 0.375000 | 0.894 sec/iter
Epoch: 83 | Batch: 005 / 029 | Total loss: 1.678 | Reg loss: 0.038 | Tree loss: 1.678 | Accuracy: 0.382812 | 0.894 sec/iter
Epoch: 83 | Batch: 006 / 029 | Total loss: 1.702 | Reg loss: 0.038 | Tree loss: 1.702 | Accuracy: 0.371094 | 0.894 sec/iter
Epoch: 83 | Batch: 007 / 029 | Total loss: 1.717 | Reg loss: 0.038 | Tree loss: 1.717 | Accuracy: 0.378906 | 0.894 sec/iter
Epoch: 83 | Batch: 008 / 029 | Total loss: 1.667 | Reg loss: 0.039 | Tree loss: 1.667 | Accuracy: 0.357422 | 0.894 sec/iter
Epoch: 8

Epoch: 85 | Batch: 005 / 029 | Total loss: 1.701 | Reg loss: 0.038 | Tree loss: 1.701 | Accuracy: 0.378906 | 0.894 sec/iter
Epoch: 85 | Batch: 006 / 029 | Total loss: 1.628 | Reg loss: 0.038 | Tree loss: 1.628 | Accuracy: 0.406250 | 0.894 sec/iter
Epoch: 85 | Batch: 007 / 029 | Total loss: 1.722 | Reg loss: 0.038 | Tree loss: 1.722 | Accuracy: 0.371094 | 0.894 sec/iter
Epoch: 85 | Batch: 008 / 029 | Total loss: 1.629 | Reg loss: 0.038 | Tree loss: 1.629 | Accuracy: 0.416016 | 0.894 sec/iter
Epoch: 85 | Batch: 009 / 029 | Total loss: 1.716 | Reg loss: 0.038 | Tree loss: 1.716 | Accuracy: 0.332031 | 0.894 sec/iter
Epoch: 85 | Batch: 010 / 029 | Total loss: 1.633 | Reg loss: 0.038 | Tree loss: 1.633 | Accuracy: 0.390625 | 0.894 sec/iter
Epoch: 85 | Batch: 011 / 029 | Total loss: 1.644 | Reg loss: 0.038 | Tree loss: 1.644 | Accuracy: 0.363281 | 0.894 sec/iter
Epoch: 85 | Batch: 012 / 029 | Total loss: 1.589 | Reg loss: 0.039 | Tree loss: 1.589 | Accuracy: 0.394531 | 0.894 sec/iter
Epoch: 8

Epoch: 87 | Batch: 009 / 029 | Total loss: 1.677 | Reg loss: 0.038 | Tree loss: 1.677 | Accuracy: 0.367188 | 0.894 sec/iter
Epoch: 87 | Batch: 010 / 029 | Total loss: 1.595 | Reg loss: 0.038 | Tree loss: 1.595 | Accuracy: 0.398438 | 0.894 sec/iter
Epoch: 87 | Batch: 011 / 029 | Total loss: 1.627 | Reg loss: 0.038 | Tree loss: 1.627 | Accuracy: 0.351562 | 0.894 sec/iter
Epoch: 87 | Batch: 012 / 029 | Total loss: 1.576 | Reg loss: 0.038 | Tree loss: 1.576 | Accuracy: 0.365234 | 0.894 sec/iter
Epoch: 87 | Batch: 013 / 029 | Total loss: 1.602 | Reg loss: 0.038 | Tree loss: 1.602 | Accuracy: 0.388672 | 0.894 sec/iter
Epoch: 87 | Batch: 014 / 029 | Total loss: 1.585 | Reg loss: 0.039 | Tree loss: 1.585 | Accuracy: 0.367188 | 0.894 sec/iter
Epoch: 87 | Batch: 015 / 029 | Total loss: 1.524 | Reg loss: 0.039 | Tree loss: 1.524 | Accuracy: 0.392578 | 0.894 sec/iter
Epoch: 87 | Batch: 016 / 029 | Total loss: 1.553 | Reg loss: 0.039 | Tree loss: 1.553 | Accuracy: 0.375000 | 0.894 sec/iter
Epoch: 8

Epoch: 89 | Batch: 013 / 029 | Total loss: 1.607 | Reg loss: 0.038 | Tree loss: 1.607 | Accuracy: 0.396484 | 0.894 sec/iter
Epoch: 89 | Batch: 014 / 029 | Total loss: 1.585 | Reg loss: 0.038 | Tree loss: 1.585 | Accuracy: 0.357422 | 0.894 sec/iter
Epoch: 89 | Batch: 015 / 029 | Total loss: 1.535 | Reg loss: 0.038 | Tree loss: 1.535 | Accuracy: 0.382812 | 0.894 sec/iter
Epoch: 89 | Batch: 016 / 029 | Total loss: 1.552 | Reg loss: 0.038 | Tree loss: 1.552 | Accuracy: 0.347656 | 0.894 sec/iter
Epoch: 89 | Batch: 017 / 029 | Total loss: 1.480 | Reg loss: 0.039 | Tree loss: 1.480 | Accuracy: 0.378906 | 0.894 sec/iter
Epoch: 89 | Batch: 018 / 029 | Total loss: 1.470 | Reg loss: 0.039 | Tree loss: 1.470 | Accuracy: 0.380859 | 0.894 sec/iter
Epoch: 89 | Batch: 019 / 029 | Total loss: 1.507 | Reg loss: 0.039 | Tree loss: 1.507 | Accuracy: 0.365234 | 0.894 sec/iter
Epoch: 89 | Batch: 020 / 029 | Total loss: 1.495 | Reg loss: 0.039 | Tree loss: 1.495 | Accuracy: 0.365234 | 0.894 sec/iter
Epoch: 8

Epoch: 91 | Batch: 017 / 029 | Total loss: 1.525 | Reg loss: 0.038 | Tree loss: 1.525 | Accuracy: 0.382812 | 0.894 sec/iter
Epoch: 91 | Batch: 018 / 029 | Total loss: 1.547 | Reg loss: 0.038 | Tree loss: 1.547 | Accuracy: 0.373047 | 0.894 sec/iter
Epoch: 91 | Batch: 019 / 029 | Total loss: 1.509 | Reg loss: 0.039 | Tree loss: 1.509 | Accuracy: 0.349609 | 0.894 sec/iter
Epoch: 91 | Batch: 020 / 029 | Total loss: 1.510 | Reg loss: 0.039 | Tree loss: 1.510 | Accuracy: 0.396484 | 0.894 sec/iter
Epoch: 91 | Batch: 021 / 029 | Total loss: 1.545 | Reg loss: 0.039 | Tree loss: 1.545 | Accuracy: 0.335938 | 0.894 sec/iter
Epoch: 91 | Batch: 022 / 029 | Total loss: 1.469 | Reg loss: 0.039 | Tree loss: 1.469 | Accuracy: 0.378906 | 0.894 sec/iter
Epoch: 91 | Batch: 023 / 029 | Total loss: 1.493 | Reg loss: 0.039 | Tree loss: 1.493 | Accuracy: 0.365234 | 0.894 sec/iter
Epoch: 91 | Batch: 024 / 029 | Total loss: 1.492 | Reg loss: 0.039 | Tree loss: 1.492 | Accuracy: 0.353516 | 0.894 sec/iter
Epoch: 9

Epoch: 93 | Batch: 021 / 029 | Total loss: 1.472 | Reg loss: 0.039 | Tree loss: 1.472 | Accuracy: 0.373047 | 0.894 sec/iter
Epoch: 93 | Batch: 022 / 029 | Total loss: 1.484 | Reg loss: 0.039 | Tree loss: 1.484 | Accuracy: 0.359375 | 0.894 sec/iter
Epoch: 93 | Batch: 023 / 029 | Total loss: 1.472 | Reg loss: 0.039 | Tree loss: 1.472 | Accuracy: 0.355469 | 0.894 sec/iter
Epoch: 93 | Batch: 024 / 029 | Total loss: 1.501 | Reg loss: 0.039 | Tree loss: 1.501 | Accuracy: 0.371094 | 0.894 sec/iter
Epoch: 93 | Batch: 025 / 029 | Total loss: 1.474 | Reg loss: 0.039 | Tree loss: 1.474 | Accuracy: 0.341797 | 0.894 sec/iter
Epoch: 93 | Batch: 026 / 029 | Total loss: 1.494 | Reg loss: 0.039 | Tree loss: 1.494 | Accuracy: 0.388672 | 0.894 sec/iter
Epoch: 93 | Batch: 027 / 029 | Total loss: 1.477 | Reg loss: 0.039 | Tree loss: 1.477 | Accuracy: 0.375000 | 0.894 sec/iter
Epoch: 93 | Batch: 028 / 029 | Total loss: 1.467 | Reg loss: 0.039 | Tree loss: 1.467 | Accuracy: 0.370370 | 0.894 sec/iter
Average 

Epoch: 95 | Batch: 025 / 029 | Total loss: 1.457 | Reg loss: 0.039 | Tree loss: 1.457 | Accuracy: 0.365234 | 0.894 sec/iter
Epoch: 95 | Batch: 026 / 029 | Total loss: 1.423 | Reg loss: 0.039 | Tree loss: 1.423 | Accuracy: 0.357422 | 0.894 sec/iter
Epoch: 95 | Batch: 027 / 029 | Total loss: 1.479 | Reg loss: 0.039 | Tree loss: 1.479 | Accuracy: 0.349609 | 0.894 sec/iter
Epoch: 95 | Batch: 028 / 029 | Total loss: 1.424 | Reg loss: 0.039 | Tree loss: 1.424 | Accuracy: 0.387205 | 0.894 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 96 | Batch: 000 / 029 | Total loss: 1.843 | Reg loss: 0.037 | Tree loss: 1.843 | Accuracy: 0.363281 | 0.894 sec/iter
Epoch: 96 | Batch: 001 / 029 | Total loss: 1.832 | Reg loss: 0.037 | Tree loss: 1.832 | Ac

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 98 | Batch: 000 / 029 | Total loss: 1.793 | Reg loss: 0.037 | Tree loss: 1.793 | Accuracy: 0.355469 | 0.894 sec/iter
Epoch: 98 | Batch: 001 / 029 | Total loss: 1.799 | Reg loss: 0.037 | Tree loss: 1.799 | Accuracy: 0.363281 | 0.894 sec/iter
Epoch: 98 | Batch: 002 / 029 | Total loss: 1.765 | Reg loss: 0.037 | Tree loss: 1.765 | Accuracy: 0.363281 | 0.894 sec/iter
Epoch: 98 | Batch: 003 / 029 | Total loss: 1.777 | Reg loss: 0.037 | Tree loss: 1.777 | Accuracy: 0.365234 | 0.894 sec/iter
Epoch: 98 | Batch: 004 / 029 | Total loss: 1.752 | Reg loss: 0.037 | Tree loss: 1.752 | Accuracy: 0.375000 | 0.894 sec/iter
Epoch: 98 | Batch: 005 / 029 | Total loss: 1.695 | Reg loss: 0.037 | Tree loss: 1.695 | Ac

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 9.652777777777779


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 720


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")



  return np.log(1 / (1 - x))


14633


Average comprehensibility: 48.25555555555555
std comprehensibility: 3.8313360336140105
var comprehensibility: 14.679135802469137
minimum comprehensibility: 34
maximum comprehensibility: 56
