In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 64
tree_depth = 10
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.198735237121582 | KNN Loss: 6.228330612182617 | BCE Loss: 1.9704046249389648
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.193399429321289 | KNN Loss: 6.228418827056885 | BCE Loss: 1.9649808406829834
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.223655700683594 | KNN Loss: 6.228407859802246 | BCE Loss: 1.9952478408813477
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.12940502166748 | KNN Loss: 6.228500843048096 | BCE Loss: 1.9009044170379639
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.163383483886719 | KNN Loss: 6.227762699127197 | BCE Loss: 1.9356210231781006
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.170571327209473 | KNN Loss: 6.227832317352295 | BCE Loss: 1.9427390098571777
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.164538383483887 | KNN Loss: 6.22742223739624 | BCE Loss: 1.9371163845062256
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.091979026794434 | KNN Loss: 6.22742223739624 | BCE Loss: 1.86455714

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 6.738680839538574 | KNN Loss: 5.639242649078369 | BCE Loss: 1.099438190460205
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 6.765195846557617 | KNN Loss: 5.623289108276367 | BCE Loss: 1.14190673828125
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 6.651707649230957 | KNN Loss: 5.5264081954956055 | BCE Loss: 1.1252994537353516
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.5760297775268555 | KNN Loss: 5.452567100524902 | BCE Loss: 1.1234629154205322
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.507997035980225 | KNN Loss: 5.396655082702637 | BCE Loss: 1.1113418340682983
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.445337295532227 | KNN Loss: 5.33567476272583 | BCE Loss: 1.1096622943878174
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.339080810546875 | KNN Loss: 5.23173189163208 | BCE Loss: 1.107349157333374
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.220019817352295 | KNN Loss: 5.13062858581543 | BCE Loss: 1.0

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 5.498337745666504 | KNN Loss: 4.450432777404785 | BCE Loss: 1.0479047298431396
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 5.499588966369629 | KNN Loss: 4.450000286102295 | BCE Loss: 1.049588918685913
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 5.5169782638549805 | KNN Loss: 4.461717128753662 | BCE Loss: 1.0552613735198975
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 5.492061614990234 | KNN Loss: 4.438493251800537 | BCE Loss: 1.0535683631896973
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 5.5384840965271 | KNN Loss: 4.464260101318359 | BCE Loss: 1.0742239952087402
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 5.467388153076172 | KNN Loss: 4.433053493499756 | BCE Loss: 1.0343348979949951
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 5.527736663818359 | KNN Loss: 4.471822261810303 | BCE Loss: 1.0559141635894775
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 5.500675201416016 | KNN Loss: 4.458330154418945 | BCE Loss:

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 5.4654645919799805 | KNN Loss: 4.422767162322998 | BCE Loss: 1.0426971912384033
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 5.493514060974121 | KNN Loss: 4.415186405181885 | BCE Loss: 1.0783274173736572
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 5.479903221130371 | KNN Loss: 4.41825532913208 | BCE Loss: 1.0616480112075806
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 5.450873374938965 | KNN Loss: 4.432248592376709 | BCE Loss: 1.018625020980835
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 5.460605621337891 | KNN Loss: 4.403741359710693 | BCE Loss: 1.0568640232086182
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 5.501046180725098 | KNN Loss: 4.443245887756348 | BCE Loss: 1.05780029296875
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 5.489946365356445 | KNN Loss: 4.433122158050537 | BCE Loss: 1.0568244457244873
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 5.447597980499268 | KNN Loss: 4.405772686004639 | BCE Loss: 

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 5.494498252868652 | KNN Loss: 4.442195892333984 | BCE Loss: 1.0523024797439575
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 5.465056896209717 | KNN Loss: 4.438093185424805 | BCE Loss: 1.0269638299942017
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 5.502208709716797 | KNN Loss: 4.44742488861084 | BCE Loss: 1.054783582687378
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 5.425683498382568 | KNN Loss: 4.4065117835998535 | BCE Loss: 1.0191717147827148
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 5.431161880493164 | KNN Loss: 4.394833564758301 | BCE Loss: 1.0363280773162842
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 5.4571638107299805 | KNN Loss: 4.421267509460449 | BCE Loss: 1.0358960628509521
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 5.449094295501709 | KNN Loss: 4.409332752227783 | BCE Loss: 1.0397614240646362
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 5.469348907470703 | KNN Loss: 4.410404205322266 | BCE Los

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 5.455676555633545 | KNN Loss: 4.403257369995117 | BCE Loss: 1.0524191856384277
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 5.373770713806152 | KNN Loss: 4.362937927246094 | BCE Loss: 1.0108327865600586
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 5.454716205596924 | KNN Loss: 4.371763229370117 | BCE Loss: 1.0829529762268066
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 5.394976615905762 | KNN Loss: 4.365742206573486 | BCE Loss: 1.0292344093322754
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 5.4025068283081055 | KNN Loss: 4.363577842712402 | BCE Loss: 1.038928747177124
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 5.432860374450684 | KNN Loss: 4.372990608215332 | BCE Loss: 1.059869647026062
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 5.394655704498291 | KNN Loss: 4.3650360107421875 | BCE Loss: 1.0296196937561035
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 5.402189254760742 | KNN Loss: 4.371366024017334 | BCE Loss

Epoch    65: reducing learning rate of group 0 to 3.5000e-03.
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 5.438379287719727 | KNN Loss: 4.365175247192383 | BCE Loss: 1.0732038021087646
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 5.397183418273926 | KNN Loss: 4.358837604522705 | BCE Loss: 1.0383455753326416
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 5.391772270202637 | KNN Loss: 4.379667282104492 | BCE Loss: 1.0121049880981445
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 5.393117904663086 | KNN Loss: 4.383541584014893 | BCE Loss: 1.0095760822296143
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 5.403880596160889 | KNN Loss: 4.349991798400879 | BCE Loss: 1.0538889169692993
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 5.401690483093262 | KNN Loss: 4.366482734680176 | BCE Loss: 1.0352078676223755
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 5.359391689300537 | KNN Loss: 4.335411548614502 | BCE Loss: 1.0239801406860352
Epoch 66 / 500 | iteration 5 / 30 | Total Los

Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 5.373533248901367 | KNN Loss: 4.358471393585205 | BCE Loss: 1.0150617361068726
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 5.399715423583984 | KNN Loss: 4.359380722045898 | BCE Loss: 1.0403344631195068
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 5.357043266296387 | KNN Loss: 4.344985008239746 | BCE Loss: 1.0120584964752197
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 5.422059059143066 | KNN Loss: 4.387399196624756 | BCE Loss: 1.0346601009368896
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 5.3639702796936035 | KNN Loss: 4.351160526275635 | BCE Loss: 1.0128097534179688
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 5.408161163330078 | KNN Loss: 4.357083797454834 | BCE Loss: 1.0510776042938232
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 5.380096912384033 | KNN Loss: 4.3423051834106445 | BCE Loss: 1.0377917289733887
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 5.393805027008057 | KNN Loss: 4.356424808502197 | BCE L

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 5.426111221313477 | KNN Loss: 4.402726650238037 | BCE Loss: 1.0233845710754395
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 5.384760856628418 | KNN Loss: 4.345700740814209 | BCE Loss: 1.039060354232788
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 5.3843889236450195 | KNN Loss: 4.349915504455566 | BCE Loss: 1.0344735383987427
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 5.376343250274658 | KNN Loss: 4.386961936950684 | BCE Loss: 0.9893814325332642
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 5.396801471710205 | KNN Loss: 4.354222297668457 | BCE Loss: 1.0425790548324585
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 5.42839241027832 | KNN Loss: 4.372525215148926 | BCE Loss: 1.0558671951293945
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 5.386249542236328 | KNN Loss: 4.344707489013672 | BCE Loss: 1.0415420532226562
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 5.3686418533325195 | KNN Loss: 4.362735271453857 | BCE Los

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 5.43046236038208 | KNN Loss: 4.382347106933594 | BCE Loss: 1.0481152534484863
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 5.405642509460449 | KNN Loss: 4.362426280975342 | BCE Loss: 1.0432161092758179
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 5.369318008422852 | KNN Loss: 4.351657390594482 | BCE Loss: 1.01766037940979
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 5.374583721160889 | KNN Loss: 4.366170883178711 | BCE Loss: 1.0084127187728882
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 5.343667507171631 | KNN Loss: 4.340145111083984 | BCE Loss: 1.003522276878357
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 5.357397556304932 | KNN Loss: 4.344089984893799 | BCE Loss: 1.0133075714111328
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 5.34920597076416 | KNN Loss: 4.357241153717041 | BCE Loss: 0.99196457862854
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 5.355490684509277 | KNN Loss: 4.330639839172363 | BCE Loss: 1.024

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 5.368741989135742 | KNN Loss: 4.339701175689697 | BCE Loss: 1.0290406942367554
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 5.372265338897705 | KNN Loss: 4.341917991638184 | BCE Loss: 1.0303473472595215
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 5.401910781860352 | KNN Loss: 4.36207914352417 | BCE Loss: 1.0398316383361816
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 5.427125930786133 | KNN Loss: 4.393468856811523 | BCE Loss: 1.033657193183899
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 5.353525638580322 | KNN Loss: 4.347275733947754 | BCE Loss: 1.0062499046325684
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 5.37416410446167 | KNN Loss: 4.356557369232178 | BCE Loss: 1.0176067352294922
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 5.3692145347595215 | KNN Loss: 4.331296443939209 | BCE Loss: 1.037918210029602
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 5.38596773147583 | KNN Loss: 4.350383281707764 | BCE 

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 5.371885299682617 | KNN Loss: 4.3273024559021 | BCE Loss: 1.0445830821990967
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 5.398066997528076 | KNN Loss: 4.3732452392578125 | BCE Loss: 1.0248217582702637
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 5.381849765777588 | KNN Loss: 4.352941036224365 | BCE Loss: 1.028908610343933
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 5.355139255523682 | KNN Loss: 4.3359785079956055 | BCE Loss: 1.0191607475280762
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 5.3712944984436035 | KNN Loss: 4.349690914154053 | BCE Loss: 1.0216037034988403
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 5.373980522155762 | KNN Loss: 4.353891849517822 | BCE Loss: 1.0200884342193604
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 5.375360488891602 | KNN Loss: 4.337036609649658 | BCE Loss: 1.0383238792419434
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 5.370473861694336 | KNN Loss: 4.354710102081299 |

Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 5.396441459655762 | KNN Loss: 4.349708080291748 | BCE Loss: 1.0467332601547241
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 5.39109992980957 | KNN Loss: 4.377358436584473 | BCE Loss: 1.0137414932250977
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 5.4089884757995605 | KNN Loss: 4.383548736572266 | BCE Loss: 1.025439739227295
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 5.416195392608643 | KNN Loss: 4.370363235473633 | BCE Loss: 1.0458320379257202
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 5.400256633758545 | KNN Loss: 4.376749038696289 | BCE Loss: 1.0235077142715454
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 5.387774467468262 | KNN Loss: 4.352993488311768 | BCE Loss: 1.0347812175750732
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 5.346580982208252 | KNN Loss: 4.322575569152832 | BCE Loss: 1.02400541305542
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 5.333121299743652 | KNN Loss: 4.330036163330078 | BCE

Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 5.369327545166016 | KNN Loss: 4.333723068237305 | BCE Loss: 1.0356043577194214
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 5.357727527618408 | KNN Loss: 4.3232741355896 | BCE Loss: 1.0344535112380981
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 5.402801513671875 | KNN Loss: 4.353353977203369 | BCE Loss: 1.0494476556777954
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 5.34882926940918 | KNN Loss: 4.340889930725098 | BCE Loss: 1.007939338684082
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 5.380702018737793 | KNN Loss: 4.3462138175964355 | BCE Loss: 1.0344880819320679
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 5.354951858520508 | KNN Loss: 4.340846061706543 | BCE Loss: 1.0141057968139648
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 5.375967502593994 | KNN Loss: 4.342637538909912 | BCE Loss: 1.0333298444747925
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 5.348485946655273 | KNN Loss: 4.347087860107422 | BCE

Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 5.391356468200684 | KNN Loss: 4.337191581726074 | BCE Loss: 1.0541646480560303
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 5.380334854125977 | KNN Loss: 4.352839469909668 | BCE Loss: 1.0274951457977295
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 5.37720251083374 | KNN Loss: 4.338650226593018 | BCE Loss: 1.038552165031433
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 5.391068458557129 | KNN Loss: 4.367131233215332 | BCE Loss: 1.023937463760376
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 5.3671040534973145 | KNN Loss: 4.345143795013428 | BCE Loss: 1.0219602584838867
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 5.376875877380371 | KNN Loss: 4.345682144165039 | BCE Loss: 1.0311939716339111
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 5.3774261474609375 | KNN Loss: 4.337563991546631 | BCE Loss: 1.0398619174957275
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 5.35535192489624 | KNN Loss: 4.353546619415283 | B

Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 5.384667873382568 | KNN Loss: 4.344208240509033 | BCE Loss: 1.0404596328735352
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 5.383322715759277 | KNN Loss: 4.340900421142578 | BCE Loss: 1.0424225330352783
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 5.4023237228393555 | KNN Loss: 4.3553314208984375 | BCE Loss: 1.0469920635223389
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 5.424866676330566 | KNN Loss: 4.415889263153076 | BCE Loss: 1.0089771747589111
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 5.378792762756348 | KNN Loss: 4.353711128234863 | BCE Loss: 1.0250813961029053
Epoch   162: reducing learning rate of group 0 to 4.1177e-04.
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 5.3864426612854 | KNN Loss: 4.3561787605285645 | BCE Loss: 1.030263900756836
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 5.338135242462158 | KNN Loss: 4.3354363441467285 | BCE Loss: 1.0026987791061401
Epoch 162 / 500 | iteration 10 / 30 |

Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 5.327816963195801 | KNN Loss: 4.338626861572266 | BCE Loss: 0.9891899824142456
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 5.362992286682129 | KNN Loss: 4.333486080169678 | BCE Loss: 1.0295063257217407
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 5.363020896911621 | KNN Loss: 4.3705058097839355 | BCE Loss: 0.9925153255462646
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 5.33917236328125 | KNN Loss: 4.328516960144043 | BCE Loss: 1.0106556415557861
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 5.359407424926758 | KNN Loss: 4.333460330963135 | BCE Loss: 1.025947093963623
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 5.374569892883301 | KNN Loss: 4.34463357925415 | BCE Loss: 1.0299365520477295
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 5.3525800704956055 | KNN Loss: 4.324286937713623 | BCE Loss: 1.028293251991272
Epoch   173: reducing learning rate of group 0 to 2.8824e-04.
Epoch 173 / 500 | iteration 0 / 30 | To

Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 5.417591094970703 | KNN Loss: 4.371377468109131 | BCE Loss: 1.0462136268615723
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 5.36545991897583 | KNN Loss: 4.356813907623291 | BCE Loss: 1.008646011352539
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 5.343201637268066 | KNN Loss: 4.347002983093262 | BCE Loss: 0.9961984157562256
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 5.409252643585205 | KNN Loss: 4.386538028717041 | BCE Loss: 1.0227144956588745
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 5.384638786315918 | KNN Loss: 4.3416571617126465 | BCE Loss: 1.0429816246032715
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 5.389078140258789 | KNN Loss: 4.338427543640137 | BCE Loss: 1.050650715827942
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 5.3370256423950195 | KNN Loss: 4.345485687255859 | BCE Loss: 0.9915398955345154
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 5.350704193115234 | KNN Loss: 4.346925735473633 | 

Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 5.410998344421387 | KNN Loss: 4.408833980560303 | BCE Loss: 1.0021641254425049
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 5.330832004547119 | KNN Loss: 4.340184688568115 | BCE Loss: 0.9906473159790039
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 5.348545074462891 | KNN Loss: 4.33862829208374 | BCE Loss: 1.0099165439605713
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 5.408653736114502 | KNN Loss: 4.331836223602295 | BCE Loss: 1.0768176317214966
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 5.385564804077148 | KNN Loss: 4.343106746673584 | BCE Loss: 1.042458176612854
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 5.412370681762695 | KNN Loss: 4.3549346923828125 | BCE Loss: 1.0574359893798828
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 5.393054008483887 | KNN Loss: 4.359588146209717 | BCE Loss: 1.0334656238555908
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 5.380929946899414 | KNN Loss: 4.35763692855835 | BC

Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 5.389151573181152 | KNN Loss: 4.366763591766357 | BCE Loss: 1.022388219833374
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 5.376889228820801 | KNN Loss: 4.335150718688965 | BCE Loss: 1.041738510131836
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 5.356049537658691 | KNN Loss: 4.334020614624023 | BCE Loss: 1.0220286846160889
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 5.369775772094727 | KNN Loss: 4.345573902130127 | BCE Loss: 1.0242018699645996
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 5.367750644683838 | KNN Loss: 4.340217113494873 | BCE Loss: 1.0275336503982544
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 5.376115322113037 | KNN Loss: 4.358524322509766 | BCE Loss: 1.0175909996032715
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 5.387426376342773 | KNN Loss: 4.331012725830078 | BCE Loss: 1.0564136505126953
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 5.341602325439453 | KNN Loss: 4.36120080947876 | BCE

Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 5.399343967437744 | KNN Loss: 4.348114490509033 | BCE Loss: 1.0512295961380005
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 5.373453140258789 | KNN Loss: 4.3209147453308105 | BCE Loss: 1.0525381565093994
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 5.390584468841553 | KNN Loss: 4.36714506149292 | BCE Loss: 1.0234395265579224
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 5.399969577789307 | KNN Loss: 4.360569477081299 | BCE Loss: 1.0394001007080078
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 5.3771891593933105 | KNN Loss: 4.330700874328613 | BCE Loss: 1.0464882850646973
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 5.351072311401367 | KNN Loss: 4.34326171875 | BCE Loss: 1.0078104734420776
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 5.37966775894165 | KNN Loss: 4.346429347991943 | BCE Loss: 1.0332385301589966
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 5.352939128875732 | KNN Loss: 4.347033500671387 | BCE

Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 5.384949207305908 | KNN Loss: 4.346677303314209 | BCE Loss: 1.0382719039916992
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 5.3904829025268555 | KNN Loss: 4.347760200500488 | BCE Loss: 1.0427225828170776
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 5.367539405822754 | KNN Loss: 4.363221645355225 | BCE Loss: 1.0043175220489502
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 5.375859260559082 | KNN Loss: 4.355532646179199 | BCE Loss: 1.020326852798462
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 5.396213054656982 | KNN Loss: 4.337543964385986 | BCE Loss: 1.0586692094802856
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 5.390443801879883 | KNN Loss: 4.35444450378418 | BCE Loss: 1.0359992980957031
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 5.39876651763916 | KNN Loss: 4.360323905944824 | BCE Loss: 1.0384423732757568
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 5.390728950500488 | KNN Loss: 4.369758605957031 | BC

Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 5.412865161895752 | KNN Loss: 4.3706841468811035 | BCE Loss: 1.0421810150146484
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 5.358555793762207 | KNN Loss: 4.346323490142822 | BCE Loss: 1.0122325420379639
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 5.370770454406738 | KNN Loss: 4.334403038024902 | BCE Loss: 1.0363671779632568
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 5.366395473480225 | KNN Loss: 4.34962797164917 | BCE Loss: 1.0167675018310547
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 5.336621284484863 | KNN Loss: 4.347736835479736 | BCE Loss: 0.9888846278190613
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 5.32765007019043 | KNN Loss: 4.332512378692627 | BCE Loss: 0.995137631893158
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 5.452523231506348 | KNN Loss: 4.397891998291016 | BCE Loss: 1.054631233215332
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 5.383674621582031 | KNN Loss: 4.338926315307617 | BCE

Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 5.354574680328369 | KNN Loss: 4.335957050323486 | BCE Loss: 1.0186176300048828
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 5.380064964294434 | KNN Loss: 4.35535192489624 | BCE Loss: 1.0247132778167725
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 5.384084701538086 | KNN Loss: 4.3588972091674805 | BCE Loss: 1.0251874923706055
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 5.325982093811035 | KNN Loss: 4.323070049285889 | BCE Loss: 1.0029120445251465
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 5.384626388549805 | KNN Loss: 4.367661476135254 | BCE Loss: 1.0169649124145508
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 5.344942569732666 | KNN Loss: 4.336441516876221 | BCE Loss: 1.0085011720657349
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 5.351956367492676 | KNN Loss: 4.338168621063232 | BCE Loss: 1.0137879848480225
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 5.340036392211914 | KNN Loss: 4.320937633514404 |

Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 5.336450099945068 | KNN Loss: 4.33142614364624 | BCE Loss: 1.0050240755081177
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 5.3653950691223145 | KNN Loss: 4.321305274963379 | BCE Loss: 1.0440897941589355
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 5.353693008422852 | KNN Loss: 4.351749897003174 | BCE Loss: 1.0019428730010986
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 5.33466100692749 | KNN Loss: 4.326825141906738 | BCE Loss: 1.007835865020752
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 5.346556663513184 | KNN Loss: 4.342149257659912 | BCE Loss: 1.0044071674346924
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 5.420558452606201 | KNN Loss: 4.380152702331543 | BCE Loss: 1.0404057502746582
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 5.421347618103027 | KNN Loss: 4.365427494049072 | BCE Loss: 1.055919885635376
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 5.353852272033691 | KNN Loss: 4.347198009490967 | BCE

Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 5.376496315002441 | KNN Loss: 4.369040489196777 | BCE Loss: 1.007455825805664
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 5.377857685089111 | KNN Loss: 4.3503804206848145 | BCE Loss: 1.0274772644042969
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 5.363445281982422 | KNN Loss: 4.351357936859131 | BCE Loss: 1.0120872259140015
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 5.406569957733154 | KNN Loss: 4.3738226890563965 | BCE Loss: 1.0327472686767578
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 5.416513442993164 | KNN Loss: 4.411596775054932 | BCE Loss: 1.0049166679382324
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 5.370360851287842 | KNN Loss: 4.336607933044434 | BCE Loss: 1.0337527990341187
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 5.353951454162598 | KNN Loss: 4.338425159454346 | BCE Loss: 1.015526294708252
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 5.432212829589844 | KNN Loss: 4.389011383056641 | 

Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 5.352625370025635 | KNN Loss: 4.32868766784668 | BCE Loss: 1.023937702178955
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 5.3901896476745605 | KNN Loss: 4.339118480682373 | BCE Loss: 1.0510711669921875
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 5.378846168518066 | KNN Loss: 4.340297698974609 | BCE Loss: 1.0385485887527466
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 5.337159633636475 | KNN Loss: 4.340985298156738 | BCE Loss: 0.9961743354797363
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 5.3386077880859375 | KNN Loss: 4.332905292510986 | BCE Loss: 1.0057027339935303
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 5.363412380218506 | KNN Loss: 4.346210479736328 | BCE Loss: 1.0172020196914673
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 5.414684295654297 | KNN Loss: 4.385612964630127 | BCE Loss: 1.0290714502334595
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 5.3822712898254395 | KNN Loss: 4.362052917480469 

Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 5.3787384033203125 | KNN Loss: 4.349741458892822 | BCE Loss: 1.0289967060089111
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 5.391236305236816 | KNN Loss: 4.364835739135742 | BCE Loss: 1.0264006853103638
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 5.386870861053467 | KNN Loss: 4.3490309715271 | BCE Loss: 1.0378398895263672
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 5.380796432495117 | KNN Loss: 4.343257904052734 | BCE Loss: 1.0375384092330933
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 5.366450786590576 | KNN Loss: 4.345076084136963 | BCE Loss: 1.0213747024536133
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 5.3672027587890625 | KNN Loss: 4.343896865844727 | BCE Loss: 1.023305892944336
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 5.381892681121826 | KNN Loss: 4.345914363861084 | BCE Loss: 1.0359783172607422
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 5.389949798583984 | KNN Loss: 4.359908103942871 | B

Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 5.387683868408203 | KNN Loss: 4.339895248413086 | BCE Loss: 1.0477887392044067
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 5.371133327484131 | KNN Loss: 4.343096733093262 | BCE Loss: 1.0280365943908691
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 5.410622596740723 | KNN Loss: 4.3871684074401855 | BCE Loss: 1.023453950881958
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 5.35328483581543 | KNN Loss: 4.3300275802612305 | BCE Loss: 1.0232571363449097
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 5.420668601989746 | KNN Loss: 4.387984275817871 | BCE Loss: 1.032684564590454
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 5.409950256347656 | KNN Loss: 4.366508483886719 | BCE Loss: 1.0434417724609375
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 5.461366176605225 | KNN Loss: 4.411252975463867 | BCE Loss: 1.050113320350647
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 5.426380157470703 | KNN Loss: 4.364851951599121 | BC

Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 5.36876106262207 | KNN Loss: 4.344345569610596 | BCE Loss: 1.0244154930114746
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 5.328984260559082 | KNN Loss: 4.330155372619629 | BCE Loss: 0.9988290071487427
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 5.389225959777832 | KNN Loss: 4.3387274742126465 | BCE Loss: 1.0504987239837646
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 5.375657081604004 | KNN Loss: 4.353789806365967 | BCE Loss: 1.0218675136566162
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 5.385025978088379 | KNN Loss: 4.366623401641846 | BCE Loss: 1.0184025764465332
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 5.340858459472656 | KNN Loss: 4.337552547454834 | BCE Loss: 1.0033060312271118
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 5.411264896392822 | KNN Loss: 4.367744445800781 | BCE Loss: 1.0435205698013306
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 5.374178409576416 | KNN Loss: 4.335440158843994 |

Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 5.345055103302002 | KNN Loss: 4.333880424499512 | BCE Loss: 1.0111747980117798
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 5.370384216308594 | KNN Loss: 4.35182523727417 | BCE Loss: 1.018559217453003
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 5.418431282043457 | KNN Loss: 4.357544898986816 | BCE Loss: 1.0608866214752197
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 5.364306449890137 | KNN Loss: 4.331836700439453 | BCE Loss: 1.0324697494506836
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 5.391447067260742 | KNN Loss: 4.365975379943848 | BCE Loss: 1.0254719257354736
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 5.417884826660156 | KNN Loss: 4.382458686828613 | BCE Loss: 1.035426139831543
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 5.395211219787598 | KNN Loss: 4.339386463165283 | BCE Loss: 1.0558247566223145
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 5.442824363708496 | KNN Loss: 4.388823509216309 | BCE

Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 5.360718727111816 | KNN Loss: 4.329967498779297 | BCE Loss: 1.0307509899139404
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 5.404683589935303 | KNN Loss: 4.355605602264404 | BCE Loss: 1.049078106880188
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 5.344674110412598 | KNN Loss: 4.338256359100342 | BCE Loss: 1.0064175128936768
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 5.335905075073242 | KNN Loss: 4.3275837898254395 | BCE Loss: 1.0083215236663818
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 5.357287406921387 | KNN Loss: 4.341869354248047 | BCE Loss: 1.0154178142547607
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 5.359344482421875 | KNN Loss: 4.31995964050293 | BCE Loss: 1.0393846035003662
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 5.398453712463379 | KNN Loss: 4.354249954223633 | BCE Loss: 1.044203519821167
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 5.413182735443115 | KNN Loss: 4.349185466766357 | BC

Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 5.46096658706665 | KNN Loss: 4.421188831329346 | BCE Loss: 1.0397777557373047
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 5.417422771453857 | KNN Loss: 4.410400867462158 | BCE Loss: 1.0070219039916992
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 5.393093109130859 | KNN Loss: 4.359714508056641 | BCE Loss: 1.0333783626556396
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 5.368679046630859 | KNN Loss: 4.329094886779785 | BCE Loss: 1.0395841598510742
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 5.367444038391113 | KNN Loss: 4.35665225982666 | BCE Loss: 1.0107920169830322
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 5.357845783233643 | KNN Loss: 4.334264278411865 | BCE Loss: 1.0235815048217773
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 5.416504859924316 | KNN Loss: 4.39518404006958 | BCE Loss: 1.0213205814361572
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 5.343430995941162 | KNN Loss: 4.336134910583496 | BC

Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 5.343161582946777 | KNN Loss: 4.318607330322266 | BCE Loss: 1.0245540142059326
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 5.372434616088867 | KNN Loss: 4.356947422027588 | BCE Loss: 1.0154869556427002
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 5.413322448730469 | KNN Loss: 4.338711738586426 | BCE Loss: 1.074610710144043
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 5.3520588874816895 | KNN Loss: 4.337334632873535 | BCE Loss: 1.0147241353988647
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 5.358988285064697 | KNN Loss: 4.341034412384033 | BCE Loss: 1.017953872680664
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 5.364492416381836 | KNN Loss: 4.342912197113037 | BCE Loss: 1.0215799808502197
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 5.353063106536865 | KNN Loss: 4.3678131103515625 | BCE Loss: 0.9852499961853027
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 5.353179454803467 | KNN Loss: 4.3340911865234375 |

Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 5.363353729248047 | KNN Loss: 4.346924304962158 | BCE Loss: 1.0164294242858887
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 5.385161876678467 | KNN Loss: 4.328492164611816 | BCE Loss: 1.05666983127594
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 5.342675685882568 | KNN Loss: 4.341615676879883 | BCE Loss: 1.0010600090026855
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 5.385725498199463 | KNN Loss: 4.344820976257324 | BCE Loss: 1.0409046411514282
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 5.357307434082031 | KNN Loss: 4.336922645568848 | BCE Loss: 1.0203845500946045
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 5.41559362411499 | KNN Loss: 4.374716281890869 | BCE Loss: 1.0408774614334106
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 5.398985862731934 | KNN Loss: 4.345499515533447 | BCE Loss: 1.0534861087799072
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 5.34314489364624 | KNN Loss: 4.338852405548096 | BCE 

Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 5.39036226272583 | KNN Loss: 4.346982955932617 | BCE Loss: 1.0433794260025024
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 5.384981632232666 | KNN Loss: 4.382969379425049 | BCE Loss: 1.0020123720169067
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 5.378666877746582 | KNN Loss: 4.36087703704834 | BCE Loss: 1.0177898406982422
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 5.414875030517578 | KNN Loss: 4.375783920288086 | BCE Loss: 1.0390911102294922
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 5.398048400878906 | KNN Loss: 4.36453104019165 | BCE Loss: 1.033517599105835
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 5.4039387702941895 | KNN Loss: 4.356095790863037 | BCE Loss: 1.0478428602218628
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 5.358743190765381 | KNN Loss: 4.34894323348999 | BCE Loss: 1.0097999572753906
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 5.36942195892334 | KNN Loss: 4.33769416809082 | BCE L

Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 5.3524041175842285 | KNN Loss: 4.342040061950684 | BCE Loss: 1.0103641748428345
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 5.353541851043701 | KNN Loss: 4.3537397384643555 | BCE Loss: 0.9998019933700562
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 5.38395357131958 | KNN Loss: 4.351023197174072 | BCE Loss: 1.0329303741455078
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 5.377869606018066 | KNN Loss: 4.358887672424316 | BCE Loss: 1.018982172012329
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 5.3930768966674805 | KNN Loss: 4.344080448150635 | BCE Loss: 1.0489966869354248
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 5.341936111450195 | KNN Loss: 4.316619396209717 | BCE Loss: 1.0253164768218994
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 5.371123790740967 | KNN Loss: 4.340087413787842 | BCE Loss: 1.0310362577438354
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 5.378000259399414 | KNN Loss: 4.341484069824219 |

Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 5.393837928771973 | KNN Loss: 4.379480361938477 | BCE Loss: 1.014357566833496
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 5.371673583984375 | KNN Loss: 4.357054233551025 | BCE Loss: 1.0146191120147705
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 5.400928497314453 | KNN Loss: 4.375899791717529 | BCE Loss: 1.025028944015503
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 5.381948471069336 | KNN Loss: 4.352011203765869 | BCE Loss: 1.0299371480941772
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 5.386092185974121 | KNN Loss: 4.353244304656982 | BCE Loss: 1.0328478813171387
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 5.378301620483398 | KNN Loss: 4.3560686111450195 | BCE Loss: 1.0222328901290894
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 5.376584529876709 | KNN Loss: 4.3347487449646 | BCE Loss: 1.0418356657028198
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 5.3806047439575195 | KNN Loss: 4.353725433349609 | BC

Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 5.3968095779418945 | KNN Loss: 4.369535446166992 | BCE Loss: 1.0272743701934814
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 5.375361442565918 | KNN Loss: 4.363291263580322 | BCE Loss: 1.0120701789855957
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 5.391444683074951 | KNN Loss: 4.34165620803833 | BCE Loss: 1.0497885942459106
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 5.363994598388672 | KNN Loss: 4.335515975952148 | BCE Loss: 1.0284783840179443
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 5.3449296951293945 | KNN Loss: 4.374381065368652 | BCE Loss: 0.970548689365387
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 5.374638080596924 | KNN Loss: 4.353221893310547 | BCE Loss: 1.021416187286377
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 5.3875627517700195 | KNN Loss: 4.3436174392700195 | BCE Loss: 1.043945074081421
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 5.354793548583984 | KNN Loss: 4.34135103225708 | 

Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 5.3681535720825195 | KNN Loss: 4.3369598388671875 | BCE Loss: 1.0311938524246216
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 5.4168701171875 | KNN Loss: 4.372065544128418 | BCE Loss: 1.044804334640503
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 5.358777046203613 | KNN Loss: 4.340908527374268 | BCE Loss: 1.0178682804107666
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 5.35737419128418 | KNN Loss: 4.336169242858887 | BCE Loss: 1.0212047100067139
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 5.38026237487793 | KNN Loss: 4.347719669342041 | BCE Loss: 1.0325424671173096
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 5.374818801879883 | KNN Loss: 4.34874153137207 | BCE Loss: 1.0260775089263916
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 5.364240646362305 | KNN Loss: 4.340979099273682 | BCE Loss: 1.0232617855072021
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 5.40300989151001 | KNN Loss: 4.382999420166016 | BCE L

Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 5.339339733123779 | KNN Loss: 4.328603744506836 | BCE Loss: 1.0107359886169434
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 5.359095573425293 | KNN Loss: 4.332515239715576 | BCE Loss: 1.0265803337097168
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 5.374670505523682 | KNN Loss: 4.3401665687561035 | BCE Loss: 1.0345039367675781
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 5.368521690368652 | KNN Loss: 4.332732677459717 | BCE Loss: 1.0357887744903564
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 5.359218120574951 | KNN Loss: 4.345033645629883 | BCE Loss: 1.0141844749450684
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 5.356226921081543 | KNN Loss: 4.326531410217285 | BCE Loss: 1.0296955108642578
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 5.3976149559021 | KNN Loss: 4.345983982086182 | BCE Loss: 1.0516308546066284
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 5.468920707702637 | KNN Loss: 4.413146495819092 | B

Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 5.383783340454102 | KNN Loss: 4.346730709075928 | BCE Loss: 1.0370525121688843
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 5.375587463378906 | KNN Loss: 4.350536346435547 | BCE Loss: 1.0250508785247803
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 5.395153522491455 | KNN Loss: 4.3513336181640625 | BCE Loss: 1.0438199043273926
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 5.381987571716309 | KNN Loss: 4.383342266082764 | BCE Loss: 0.9986453652381897
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 5.3755598068237305 | KNN Loss: 4.357601165771484 | BCE Loss: 1.017958641052246
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 5.335978031158447 | KNN Loss: 4.339078426361084 | BCE Loss: 0.9968996047973633
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 5.366513252258301 | KNN Loss: 4.342787265777588 | BCE Loss: 1.023726224899292
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 5.38269567489624 | KNN Loss: 4.350576877593994 | 

Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 5.3634033203125 | KNN Loss: 4.333193778991699 | BCE Loss: 1.0302095413208008
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 5.404452800750732 | KNN Loss: 4.383240699768066 | BCE Loss: 1.0212122201919556
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 5.322530269622803 | KNN Loss: 4.326475143432617 | BCE Loss: 0.9960550665855408
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 5.410157203674316 | KNN Loss: 4.3770952224731445 | BCE Loss: 1.033062219619751
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 5.386936664581299 | KNN Loss: 4.351184844970703 | BCE Loss: 1.0357518196105957
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 5.3716936111450195 | KNN Loss: 4.333949089050293 | BCE Loss: 1.0377447605133057
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 5.333678722381592 | KNN Loss: 4.344855308532715 | BCE Loss: 0.9888232946395874
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 5.397200107574463 | KNN Loss: 4.3757147789001465 | 

Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 5.342750549316406 | KNN Loss: 4.327213287353516 | BCE Loss: 1.0155372619628906
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 5.390583038330078 | KNN Loss: 4.345832824707031 | BCE Loss: 1.0447500944137573
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 5.360966205596924 | KNN Loss: 4.330828666687012 | BCE Loss: 1.030137538909912
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 5.360551834106445 | KNN Loss: 4.328232288360596 | BCE Loss: 1.0323193073272705
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 5.377540111541748 | KNN Loss: 4.339588165283203 | BCE Loss: 1.0379518270492554
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 5.392430782318115 | KNN Loss: 4.370214462280273 | BCE Loss: 1.0222163200378418
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 5.371382713317871 | KNN Loss: 4.358486652374268 | BCE Loss: 1.0128958225250244
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 5.382922172546387 | KNN Loss: 4.359556674957275 | B

Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 5.358429431915283 | KNN Loss: 4.32494592666626 | BCE Loss: 1.0334833860397339
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 5.368271350860596 | KNN Loss: 4.356976509094238 | BCE Loss: 1.011294960975647
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 5.3690080642700195 | KNN Loss: 4.339217662811279 | BCE Loss: 1.0297901630401611
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 5.34496545791626 | KNN Loss: 4.337339878082275 | BCE Loss: 1.0076255798339844
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 5.389900207519531 | KNN Loss: 4.372283458709717 | BCE Loss: 1.0176165103912354
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 5.404962539672852 | KNN Loss: 4.372114658355713 | BCE Loss: 1.0328476428985596
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 5.380067825317383 | KNN Loss: 4.3546648025512695 | BCE Loss: 1.0254027843475342
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 5.427299499511719 | KNN Loss: 4.386161804199219 | 

Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 5.369363784790039 | KNN Loss: 4.339537620544434 | BCE Loss: 1.0298261642456055
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 5.389568328857422 | KNN Loss: 4.364975929260254 | BCE Loss: 1.024592399597168
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 5.369410514831543 | KNN Loss: 4.359279155731201 | BCE Loss: 1.0101313591003418
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 5.3662590980529785 | KNN Loss: 4.34240198135376 | BCE Loss: 1.0238569974899292
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 5.434374809265137 | KNN Loss: 4.3694376945495605 | BCE Loss: 1.0649373531341553
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 5.386667728424072 | KNN Loss: 4.363371849060059 | BCE Loss: 1.0232957601547241
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 5.3647918701171875 | KNN Loss: 4.347118854522705 | BCE Loss: 1.0176732540130615
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 5.4127044677734375 | KNN Loss: 4.358959197998047 

Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 5.353229522705078 | KNN Loss: 4.346161365509033 | BCE Loss: 1.007068395614624
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 5.421252250671387 | KNN Loss: 4.371200084686279 | BCE Loss: 1.0500521659851074
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 5.3846845626831055 | KNN Loss: 4.335709571838379 | BCE Loss: 1.0489747524261475
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 5.403069496154785 | KNN Loss: 4.333566665649414 | BCE Loss: 1.0695027112960815
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 5.36777400970459 | KNN Loss: 4.340906143188477 | BCE Loss: 1.0268676280975342
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 5.383893966674805 | KNN Loss: 4.3631415367126465 | BCE Loss: 1.020752191543579
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 5.348280429840088 | KNN Loss: 4.353436470031738 | BCE Loss: 0.9948439598083496
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 5.375085353851318 | KNN Loss: 4.338935852050781 | B

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.3926,  4.0860,  2.8788,  3.9504,  2.8986,  0.4837,  2.7811,  1.9030,
          2.1603,  2.3338,  2.5971,  2.5487,  1.0050,  1.4871,  1.5329,  1.5996,
          2.2310,  3.5485,  2.7080,  1.6907,  2.0699,  2.6769,  2.0625,  2.9899,
          1.9688,  2.0603,  2.1489,  1.4235,  1.8046,  0.2838, -0.4469,  0.9981,
         -0.2694,  0.8887,  1.6710,  1.5721,  0.9988,  3.7110,  1.0052,  1.3846,
          0.9644, -0.8202, -0.4330,  2.1562,  2.2686,  0.7811, -0.1593,  0.0440,
          1.2229,  2.8391,  2.1248,  0.1024,  1.5054,  0.5918, -0.5685,  1.0243,
          1.4711,  1.4002,  1.3616,  2.1332,  0.5040,  0.9411,  0.2120,  1.7772,
          1.1621,  1.6978, -2.1811,  0.0616,  2.3492,  2.2480,  1.9991,  0.3823,
          1.1766,  2.1853,  2.2963,  1.0241,  0.3581,  0.8016,  0.1199,  1.6242,
          0.0295,  0.4071,  1.8388, -0.4244,  0.2974, -1.0537, -2.5488, -0.3111,
          0.5336, -2.0137,  0.4038, -0.1622, -0.6053, -0.9296,  0.4635,  1.2992,
         -0.6913, -0.6401,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 88.88it/s] 


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 029 | Total loss: 9.607 | Reg loss: 0.011 | Tree loss: 9.607 | Accuracy: 0.000000 | 1.088 sec/iter
Epoch: 00 | Batch: 001 / 029 | Total loss: 9.603 | Reg loss: 0.011 | Tree loss: 9.603 | Accuracy: 0.000000 | 0.95 sec/iter
Epoch: 00 | Batch: 002 / 029 | Total loss: 9.598 | Reg loss: 0.010 | Tree loss: 9.598 | Accuracy: 0.000000 | 0.909 sec/iter
Epoch: 00 | Batch: 003 / 029 | Total loss: 9.594 | Reg loss: 0.009 | Tree loss: 9.594 | Accuracy: 0.000000 | 0.892 sec/iter
Epoch: 00 | Batch: 004 / 029 | Total loss: 9.588 | Reg loss: 0.009 | Tree loss: 9.588 | Accuracy: 0.000000 | 0.88 sec/iter
Epoch: 00 | Batch: 005 / 029 | Total loss: 9.582 | Reg loss: 0.008 | Tree loss: 9.582 | Accuracy: 0.000000 | 0.874 sec/iter
Epoch: 00 | Batch: 006 / 029 | Total loss: 9.577 | Reg loss: 0.008 | Tree loss: 9.577 | Accuracy: 0.000000 | 0.868 se

Epoch: 02 | Batch: 003 / 029 | Total loss: 9.457 | Reg loss: 0.008 | Tree loss: 9.457 | Accuracy: 0.218750 | 0.863 sec/iter
Epoch: 02 | Batch: 004 / 029 | Total loss: 9.452 | Reg loss: 0.008 | Tree loss: 9.452 | Accuracy: 0.189453 | 0.863 sec/iter
Epoch: 02 | Batch: 005 / 029 | Total loss: 9.444 | Reg loss: 0.008 | Tree loss: 9.444 | Accuracy: 0.220703 | 0.862 sec/iter
Epoch: 02 | Batch: 006 / 029 | Total loss: 9.439 | Reg loss: 0.008 | Tree loss: 9.439 | Accuracy: 0.193359 | 0.861 sec/iter
Epoch: 02 | Batch: 007 / 029 | Total loss: 9.436 | Reg loss: 0.009 | Tree loss: 9.436 | Accuracy: 0.212891 | 0.861 sec/iter
Epoch: 02 | Batch: 008 / 029 | Total loss: 9.427 | Reg loss: 0.009 | Tree loss: 9.427 | Accuracy: 0.222656 | 0.86 sec/iter
Epoch: 02 | Batch: 009 / 029 | Total loss: 9.424 | Reg loss: 0.009 | Tree loss: 9.424 | Accuracy: 0.212891 | 0.86 sec/iter
Epoch: 02 | Batch: 010 / 029 | Total loss: 9.416 | Reg loss: 0.010 | Tree loss: 9.416 | Accuracy: 0.224609 | 0.86 sec/iter
Epoch: 02 |

Epoch: 04 | Batch: 007 / 029 | Total loss: 9.230 | Reg loss: 0.015 | Tree loss: 9.230 | Accuracy: 0.246094 | 0.86 sec/iter
Epoch: 04 | Batch: 008 / 029 | Total loss: 9.230 | Reg loss: 0.015 | Tree loss: 9.230 | Accuracy: 0.199219 | 0.86 sec/iter
Epoch: 04 | Batch: 009 / 029 | Total loss: 9.211 | Reg loss: 0.015 | Tree loss: 9.211 | Accuracy: 0.222656 | 0.86 sec/iter
Epoch: 04 | Batch: 010 / 029 | Total loss: 9.202 | Reg loss: 0.016 | Tree loss: 9.202 | Accuracy: 0.205078 | 0.86 sec/iter
Epoch: 04 | Batch: 011 / 029 | Total loss: 9.198 | Reg loss: 0.016 | Tree loss: 9.198 | Accuracy: 0.191406 | 0.86 sec/iter
Epoch: 04 | Batch: 012 / 029 | Total loss: 9.175 | Reg loss: 0.017 | Tree loss: 9.175 | Accuracy: 0.199219 | 0.86 sec/iter
Epoch: 04 | Batch: 013 / 029 | Total loss: 9.169 | Reg loss: 0.017 | Tree loss: 9.169 | Accuracy: 0.195312 | 0.859 sec/iter
Epoch: 04 | Batch: 014 / 029 | Total loss: 9.154 | Reg loss: 0.018 | Tree loss: 9.154 | Accuracy: 0.199219 | 0.859 sec/iter
Epoch: 04 | Ba

Epoch: 06 | Batch: 011 / 029 | Total loss: 8.789 | Reg loss: 0.022 | Tree loss: 8.789 | Accuracy: 0.220703 | 0.861 sec/iter
Epoch: 06 | Batch: 012 / 029 | Total loss: 8.775 | Reg loss: 0.022 | Tree loss: 8.775 | Accuracy: 0.212891 | 0.861 sec/iter
Epoch: 06 | Batch: 013 / 029 | Total loss: 8.758 | Reg loss: 0.022 | Tree loss: 8.758 | Accuracy: 0.193359 | 0.861 sec/iter
Epoch: 06 | Batch: 014 / 029 | Total loss: 8.708 | Reg loss: 0.023 | Tree loss: 8.708 | Accuracy: 0.216797 | 0.861 sec/iter
Epoch: 06 | Batch: 015 / 029 | Total loss: 8.712 | Reg loss: 0.023 | Tree loss: 8.712 | Accuracy: 0.187500 | 0.86 sec/iter
Epoch: 06 | Batch: 016 / 029 | Total loss: 8.705 | Reg loss: 0.024 | Tree loss: 8.705 | Accuracy: 0.201172 | 0.86 sec/iter
Epoch: 06 | Batch: 017 / 029 | Total loss: 8.679 | Reg loss: 0.024 | Tree loss: 8.679 | Accuracy: 0.191406 | 0.86 sec/iter
Epoch: 06 | Batch: 018 / 029 | Total loss: 8.624 | Reg loss: 0.024 | Tree loss: 8.624 | Accuracy: 0.201172 | 0.86 sec/iter
Epoch: 06 | 

Epoch: 08 | Batch: 015 / 029 | Total loss: 8.196 | Reg loss: 0.027 | Tree loss: 8.196 | Accuracy: 0.181641 | 0.864 sec/iter
Epoch: 08 | Batch: 016 / 029 | Total loss: 8.162 | Reg loss: 0.027 | Tree loss: 8.162 | Accuracy: 0.189453 | 0.864 sec/iter
Epoch: 08 | Batch: 017 / 029 | Total loss: 8.122 | Reg loss: 0.027 | Tree loss: 8.122 | Accuracy: 0.230469 | 0.864 sec/iter
Epoch: 08 | Batch: 018 / 029 | Total loss: 8.113 | Reg loss: 0.028 | Tree loss: 8.113 | Accuracy: 0.167969 | 0.863 sec/iter
Epoch: 08 | Batch: 019 / 029 | Total loss: 8.111 | Reg loss: 0.028 | Tree loss: 8.111 | Accuracy: 0.212891 | 0.863 sec/iter
Epoch: 08 | Batch: 020 / 029 | Total loss: 8.069 | Reg loss: 0.028 | Tree loss: 8.069 | Accuracy: 0.207031 | 0.863 sec/iter
Epoch: 08 | Batch: 021 / 029 | Total loss: 8.053 | Reg loss: 0.028 | Tree loss: 8.053 | Accuracy: 0.214844 | 0.863 sec/iter
Epoch: 08 | Batch: 022 / 029 | Total loss: 8.055 | Reg loss: 0.029 | Tree loss: 8.055 | Accuracy: 0.201172 | 0.863 sec/iter
Epoch: 0

Epoch: 10 | Batch: 019 / 029 | Total loss: 7.578 | Reg loss: 0.030 | Tree loss: 7.578 | Accuracy: 0.179688 | 0.864 sec/iter
Epoch: 10 | Batch: 020 / 029 | Total loss: 7.536 | Reg loss: 0.030 | Tree loss: 7.536 | Accuracy: 0.193359 | 0.864 sec/iter
Epoch: 10 | Batch: 021 / 029 | Total loss: 7.527 | Reg loss: 0.030 | Tree loss: 7.527 | Accuracy: 0.173828 | 0.864 sec/iter
Epoch: 10 | Batch: 022 / 029 | Total loss: 7.512 | Reg loss: 0.030 | Tree loss: 7.512 | Accuracy: 0.205078 | 0.864 sec/iter
Epoch: 10 | Batch: 023 / 029 | Total loss: 7.516 | Reg loss: 0.030 | Tree loss: 7.516 | Accuracy: 0.177734 | 0.864 sec/iter
Epoch: 10 | Batch: 024 / 029 | Total loss: 7.449 | Reg loss: 0.031 | Tree loss: 7.449 | Accuracy: 0.210938 | 0.864 sec/iter
Epoch: 10 | Batch: 025 / 029 | Total loss: 7.457 | Reg loss: 0.031 | Tree loss: 7.457 | Accuracy: 0.212891 | 0.864 sec/iter
Epoch: 10 | Batch: 026 / 029 | Total loss: 7.424 | Reg loss: 0.031 | Tree loss: 7.424 | Accuracy: 0.205078 | 0.864 sec/iter
Epoch: 1

Epoch: 12 | Batch: 023 / 029 | Total loss: 6.939 | Reg loss: 0.030 | Tree loss: 6.939 | Accuracy: 0.208984 | 0.864 sec/iter
Epoch: 12 | Batch: 024 / 029 | Total loss: 6.942 | Reg loss: 0.031 | Tree loss: 6.942 | Accuracy: 0.185547 | 0.864 sec/iter
Epoch: 12 | Batch: 025 / 029 | Total loss: 6.909 | Reg loss: 0.031 | Tree loss: 6.909 | Accuracy: 0.208984 | 0.864 sec/iter
Epoch: 12 | Batch: 026 / 029 | Total loss: 6.905 | Reg loss: 0.031 | Tree loss: 6.905 | Accuracy: 0.216797 | 0.864 sec/iter
Epoch: 12 | Batch: 027 / 029 | Total loss: 6.864 | Reg loss: 0.031 | Tree loss: 6.864 | Accuracy: 0.199219 | 0.863 sec/iter
Epoch: 12 | Batch: 028 / 029 | Total loss: 6.906 | Reg loss: 0.031 | Tree loss: 6.906 | Accuracy: 0.258621 | 0.863 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 14 | Batch: 027 / 029 | Total loss: 6.294 | Reg loss: 0.030 | Tree loss: 6.294 | Accuracy: 0.220703 | 0.864 sec/iter
Epoch: 14 | Batch: 028 / 029 | Total loss: 6.260 | Reg loss: 0.030 | Tree loss: 6.260 | Accuracy: 0.206897 | 0.863 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 15 | Batch: 000 / 029 | Total loss: 6.631 | Reg loss: 0.028 | Tree loss: 6.631 | Accuracy: 0.195312 | 0.865 sec/iter
Epoch: 15 | Batch: 001 / 029 | Total loss: 6.573 | Reg loss: 0.028 | Tree loss: 6.573 | Accuracy: 0.208984 | 0.865 sec/iter
Epoch: 15 | Batch: 002 / 029 | Total loss: 6.555 | Reg loss: 0.028 | Tree loss: 6.555 | Accuracy: 0.169922 | 0.864 sec/iter
Epoch: 15 | Batch: 003 / 029 | Total loss: 6.540 | Reg loss: 0.028 | Tree loss: 6.540 | Ac

layer 8: 0.9821428571428573
Epoch: 17 | Batch: 000 / 029 | Total loss: 6.003 | Reg loss: 0.026 | Tree loss: 6.003 | Accuracy: 0.187500 | 0.862 sec/iter
Epoch: 17 | Batch: 001 / 029 | Total loss: 5.938 | Reg loss: 0.026 | Tree loss: 5.938 | Accuracy: 0.224609 | 0.862 sec/iter
Epoch: 17 | Batch: 002 / 029 | Total loss: 5.944 | Reg loss: 0.026 | Tree loss: 5.944 | Accuracy: 0.183594 | 0.862 sec/iter
Epoch: 17 | Batch: 003 / 029 | Total loss: 5.876 | Reg loss: 0.026 | Tree loss: 5.876 | Accuracy: 0.197266 | 0.862 sec/iter
Epoch: 17 | Batch: 004 / 029 | Total loss: 5.831 | Reg loss: 0.026 | Tree loss: 5.831 | Accuracy: 0.195312 | 0.862 sec/iter
Epoch: 17 | Batch: 005 / 029 | Total loss: 5.814 | Reg loss: 0.026 | Tree loss: 5.814 | Accuracy: 0.214844 | 0.862 sec/iter
Epoch: 17 | Batch: 006 / 029 | Total loss: 5.807 | Reg loss: 0.026 | Tree loss: 5.807 | Accuracy: 0.212891 | 0.862 sec/iter
Epoch: 17 | Batch: 007 / 029 | Total loss: 5.828 | Reg loss: 0.026 | Tree loss: 5.828 | Accuracy: 0.2011

Epoch: 19 | Batch: 004 / 029 | Total loss: 5.202 | Reg loss: 0.025 | Tree loss: 5.202 | Accuracy: 0.220703 | 0.862 sec/iter
Epoch: 19 | Batch: 005 / 029 | Total loss: 5.164 | Reg loss: 0.026 | Tree loss: 5.164 | Accuracy: 0.236328 | 0.862 sec/iter
Epoch: 19 | Batch: 006 / 029 | Total loss: 5.208 | Reg loss: 0.026 | Tree loss: 5.208 | Accuracy: 0.199219 | 0.862 sec/iter
Epoch: 19 | Batch: 007 / 029 | Total loss: 5.131 | Reg loss: 0.026 | Tree loss: 5.131 | Accuracy: 0.197266 | 0.862 sec/iter
Epoch: 19 | Batch: 008 / 029 | Total loss: 5.112 | Reg loss: 0.026 | Tree loss: 5.112 | Accuracy: 0.232422 | 0.862 sec/iter
Epoch: 19 | Batch: 009 / 029 | Total loss: 5.114 | Reg loss: 0.026 | Tree loss: 5.114 | Accuracy: 0.216797 | 0.862 sec/iter
Epoch: 19 | Batch: 010 / 029 | Total loss: 5.069 | Reg loss: 0.026 | Tree loss: 5.069 | Accuracy: 0.232422 | 0.862 sec/iter
Epoch: 19 | Batch: 011 / 029 | Total loss: 5.018 | Reg loss: 0.026 | Tree loss: 5.018 | Accuracy: 0.214844 | 0.862 sec/iter
Epoch: 1

Epoch: 21 | Batch: 008 / 029 | Total loss: 4.577 | Reg loss: 0.028 | Tree loss: 4.577 | Accuracy: 0.201172 | 0.862 sec/iter
Epoch: 21 | Batch: 009 / 029 | Total loss: 4.563 | Reg loss: 0.028 | Tree loss: 4.563 | Accuracy: 0.224609 | 0.861 sec/iter
Epoch: 21 | Batch: 010 / 029 | Total loss: 4.527 | Reg loss: 0.028 | Tree loss: 4.527 | Accuracy: 0.189453 | 0.861 sec/iter
Epoch: 21 | Batch: 011 / 029 | Total loss: 4.532 | Reg loss: 0.028 | Tree loss: 4.532 | Accuracy: 0.199219 | 0.861 sec/iter
Epoch: 21 | Batch: 012 / 029 | Total loss: 4.521 | Reg loss: 0.028 | Tree loss: 4.521 | Accuracy: 0.226562 | 0.861 sec/iter
Epoch: 21 | Batch: 013 / 029 | Total loss: 4.496 | Reg loss: 0.028 | Tree loss: 4.496 | Accuracy: 0.191406 | 0.861 sec/iter
Epoch: 21 | Batch: 014 / 029 | Total loss: 4.429 | Reg loss: 0.028 | Tree loss: 4.429 | Accuracy: 0.218750 | 0.861 sec/iter
Epoch: 21 | Batch: 015 / 029 | Total loss: 4.436 | Reg loss: 0.028 | Tree loss: 4.436 | Accuracy: 0.199219 | 0.861 sec/iter
Epoch: 2

Epoch: 23 | Batch: 012 / 029 | Total loss: 4.103 | Reg loss: 0.029 | Tree loss: 4.103 | Accuracy: 0.195312 | 0.861 sec/iter
Epoch: 23 | Batch: 013 / 029 | Total loss: 4.054 | Reg loss: 0.029 | Tree loss: 4.054 | Accuracy: 0.203125 | 0.861 sec/iter
Epoch: 23 | Batch: 014 / 029 | Total loss: 4.078 | Reg loss: 0.029 | Tree loss: 4.078 | Accuracy: 0.191406 | 0.861 sec/iter
Epoch: 23 | Batch: 015 / 029 | Total loss: 4.047 | Reg loss: 0.030 | Tree loss: 4.047 | Accuracy: 0.179688 | 0.861 sec/iter
Epoch: 23 | Batch: 016 / 029 | Total loss: 4.000 | Reg loss: 0.030 | Tree loss: 4.000 | Accuracy: 0.220703 | 0.861 sec/iter
Epoch: 23 | Batch: 017 / 029 | Total loss: 3.991 | Reg loss: 0.030 | Tree loss: 3.991 | Accuracy: 0.216797 | 0.861 sec/iter
Epoch: 23 | Batch: 018 / 029 | Total loss: 3.982 | Reg loss: 0.030 | Tree loss: 3.982 | Accuracy: 0.199219 | 0.861 sec/iter
Epoch: 23 | Batch: 019 / 029 | Total loss: 3.867 | Reg loss: 0.030 | Tree loss: 3.867 | Accuracy: 0.257812 | 0.861 sec/iter
Epoch: 2

Epoch: 25 | Batch: 016 / 029 | Total loss: 3.678 | Reg loss: 0.030 | Tree loss: 3.678 | Accuracy: 0.193359 | 0.861 sec/iter
Epoch: 25 | Batch: 017 / 029 | Total loss: 3.661 | Reg loss: 0.031 | Tree loss: 3.661 | Accuracy: 0.205078 | 0.861 sec/iter
Epoch: 25 | Batch: 018 / 029 | Total loss: 3.649 | Reg loss: 0.031 | Tree loss: 3.649 | Accuracy: 0.199219 | 0.861 sec/iter
Epoch: 25 | Batch: 019 / 029 | Total loss: 3.628 | Reg loss: 0.031 | Tree loss: 3.628 | Accuracy: 0.199219 | 0.861 sec/iter
Epoch: 25 | Batch: 020 / 029 | Total loss: 3.630 | Reg loss: 0.031 | Tree loss: 3.630 | Accuracy: 0.193359 | 0.861 sec/iter
Epoch: 25 | Batch: 021 / 029 | Total loss: 3.572 | Reg loss: 0.031 | Tree loss: 3.572 | Accuracy: 0.220703 | 0.861 sec/iter
Epoch: 25 | Batch: 022 / 029 | Total loss: 3.498 | Reg loss: 0.031 | Tree loss: 3.498 | Accuracy: 0.226562 | 0.861 sec/iter
Epoch: 25 | Batch: 023 / 029 | Total loss: 3.596 | Reg loss: 0.031 | Tree loss: 3.596 | Accuracy: 0.205078 | 0.861 sec/iter
Epoch: 2

Epoch: 27 | Batch: 020 / 029 | Total loss: 3.354 | Reg loss: 0.031 | Tree loss: 3.354 | Accuracy: 0.208984 | 0.861 sec/iter
Epoch: 27 | Batch: 021 / 029 | Total loss: 3.278 | Reg loss: 0.031 | Tree loss: 3.278 | Accuracy: 0.181641 | 0.86 sec/iter
Epoch: 27 | Batch: 022 / 029 | Total loss: 3.347 | Reg loss: 0.031 | Tree loss: 3.347 | Accuracy: 0.212891 | 0.86 sec/iter
Epoch: 27 | Batch: 023 / 029 | Total loss: 3.321 | Reg loss: 0.032 | Tree loss: 3.321 | Accuracy: 0.197266 | 0.86 sec/iter
Epoch: 27 | Batch: 024 / 029 | Total loss: 3.347 | Reg loss: 0.032 | Tree loss: 3.347 | Accuracy: 0.169922 | 0.86 sec/iter
Epoch: 27 | Batch: 025 / 029 | Total loss: 3.255 | Reg loss: 0.032 | Tree loss: 3.255 | Accuracy: 0.222656 | 0.86 sec/iter
Epoch: 27 | Batch: 026 / 029 | Total loss: 3.281 | Reg loss: 0.032 | Tree loss: 3.281 | Accuracy: 0.216797 | 0.86 sec/iter
Epoch: 27 | Batch: 027 / 029 | Total loss: 3.208 | Reg loss: 0.032 | Tree loss: 3.208 | Accuracy: 0.240234 | 0.86 sec/iter
Epoch: 27 | Bat

Epoch: 29 | Batch: 024 / 029 | Total loss: 3.058 | Reg loss: 0.032 | Tree loss: 3.058 | Accuracy: 0.216797 | 0.86 sec/iter
Epoch: 29 | Batch: 025 / 029 | Total loss: 3.042 | Reg loss: 0.032 | Tree loss: 3.042 | Accuracy: 0.205078 | 0.86 sec/iter
Epoch: 29 | Batch: 026 / 029 | Total loss: 3.141 | Reg loss: 0.032 | Tree loss: 3.141 | Accuracy: 0.173828 | 0.86 sec/iter
Epoch: 29 | Batch: 027 / 029 | Total loss: 3.008 | Reg loss: 0.032 | Tree loss: 3.008 | Accuracy: 0.203125 | 0.86 sec/iter
Epoch: 29 | Batch: 028 / 029 | Total loss: 3.081 | Reg loss: 0.032 | Tree loss: 3.081 | Accuracy: 0.155172 | 0.86 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 30 | Batch: 000 / 029 | Total loss: 3.196 | Reg loss: 0.031 | Tree loss: 3.196 | Accurac

Epoch: 31 | Batch: 028 / 029 | Total loss: 2.838 | Reg loss: 0.032 | Tree loss: 2.838 | Accuracy: 0.224138 | 0.859 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 32 | Batch: 000 / 029 | Total loss: 3.072 | Reg loss: 0.032 | Tree loss: 3.072 | Accuracy: 0.185547 | 0.86 sec/iter
Epoch: 32 | Batch: 001 / 029 | Total loss: 2.984 | Reg loss: 0.032 | Tree loss: 2.984 | Accuracy: 0.220703 | 0.86 sec/iter
Epoch: 32 | Batch: 002 / 029 | Total loss: 2.979 | Reg loss: 0.032 | Tree loss: 2.979 | Accuracy: 0.203125 | 0.86 sec/iter
Epoch: 32 | Batch: 003 / 029 | Total loss: 3.013 | Reg loss: 0.032 | Tree loss: 3.013 | Accuracy: 0.210938 | 0.86 sec/iter
Epoch: 32 | Batch: 004 / 029 | Total loss: 3.004 | Reg loss: 0.032 | Tree loss: 3.004 | Accura

Epoch: 34 | Batch: 001 / 029 | Total loss: 2.893 | Reg loss: 0.032 | Tree loss: 2.893 | Accuracy: 0.195312 | 0.861 sec/iter
Epoch: 34 | Batch: 002 / 029 | Total loss: 2.856 | Reg loss: 0.032 | Tree loss: 2.856 | Accuracy: 0.218750 | 0.86 sec/iter
Epoch: 34 | Batch: 003 / 029 | Total loss: 2.822 | Reg loss: 0.032 | Tree loss: 2.822 | Accuracy: 0.238281 | 0.86 sec/iter
Epoch: 34 | Batch: 004 / 029 | Total loss: 2.850 | Reg loss: 0.032 | Tree loss: 2.850 | Accuracy: 0.218750 | 0.86 sec/iter
Epoch: 34 | Batch: 005 / 029 | Total loss: 2.824 | Reg loss: 0.032 | Tree loss: 2.824 | Accuracy: 0.197266 | 0.86 sec/iter
Epoch: 34 | Batch: 006 / 029 | Total loss: 2.803 | Reg loss: 0.032 | Tree loss: 2.803 | Accuracy: 0.210938 | 0.86 sec/iter
Epoch: 34 | Batch: 007 / 029 | Total loss: 2.887 | Reg loss: 0.032 | Tree loss: 2.887 | Accuracy: 0.185547 | 0.86 sec/iter
Epoch: 34 | Batch: 008 / 029 | Total loss: 2.827 | Reg loss: 0.032 | Tree loss: 2.827 | Accuracy: 0.205078 | 0.86 sec/iter
Epoch: 34 | Bat

Epoch: 36 | Batch: 005 / 029 | Total loss: 2.693 | Reg loss: 0.032 | Tree loss: 2.693 | Accuracy: 0.197266 | 0.86 sec/iter
Epoch: 36 | Batch: 006 / 029 | Total loss: 2.711 | Reg loss: 0.032 | Tree loss: 2.711 | Accuracy: 0.187500 | 0.86 sec/iter
Epoch: 36 | Batch: 007 / 029 | Total loss: 2.713 | Reg loss: 0.032 | Tree loss: 2.713 | Accuracy: 0.167969 | 0.86 sec/iter
Epoch: 36 | Batch: 008 / 029 | Total loss: 2.681 | Reg loss: 0.032 | Tree loss: 2.681 | Accuracy: 0.214844 | 0.86 sec/iter
Epoch: 36 | Batch: 009 / 029 | Total loss: 2.697 | Reg loss: 0.032 | Tree loss: 2.697 | Accuracy: 0.218750 | 0.86 sec/iter
Epoch: 36 | Batch: 010 / 029 | Total loss: 2.702 | Reg loss: 0.032 | Tree loss: 2.702 | Accuracy: 0.177734 | 0.86 sec/iter
Epoch: 36 | Batch: 011 / 029 | Total loss: 2.685 | Reg loss: 0.032 | Tree loss: 2.685 | Accuracy: 0.218750 | 0.86 sec/iter
Epoch: 36 | Batch: 012 / 029 | Total loss: 2.681 | Reg loss: 0.032 | Tree loss: 2.681 | Accuracy: 0.189453 | 0.86 sec/iter
Epoch: 36 | Batc

Epoch: 38 | Batch: 009 / 029 | Total loss: 2.618 | Reg loss: 0.032 | Tree loss: 2.618 | Accuracy: 0.199219 | 0.86 sec/iter
Epoch: 38 | Batch: 010 / 029 | Total loss: 2.620 | Reg loss: 0.032 | Tree loss: 2.620 | Accuracy: 0.216797 | 0.86 sec/iter
Epoch: 38 | Batch: 011 / 029 | Total loss: 2.568 | Reg loss: 0.032 | Tree loss: 2.568 | Accuracy: 0.208984 | 0.86 sec/iter
Epoch: 38 | Batch: 012 / 029 | Total loss: 2.634 | Reg loss: 0.032 | Tree loss: 2.634 | Accuracy: 0.177734 | 0.86 sec/iter
Epoch: 38 | Batch: 013 / 029 | Total loss: 2.563 | Reg loss: 0.032 | Tree loss: 2.563 | Accuracy: 0.218750 | 0.86 sec/iter
Epoch: 38 | Batch: 014 / 029 | Total loss: 2.577 | Reg loss: 0.032 | Tree loss: 2.577 | Accuracy: 0.214844 | 0.86 sec/iter
Epoch: 38 | Batch: 015 / 029 | Total loss: 2.587 | Reg loss: 0.032 | Tree loss: 2.587 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 38 | Batch: 016 / 029 | Total loss: 2.639 | Reg loss: 0.032 | Tree loss: 2.639 | Accuracy: 0.183594 | 0.859 sec/iter
Epoch: 38 | Ba

Epoch: 40 | Batch: 013 / 029 | Total loss: 2.539 | Reg loss: 0.032 | Tree loss: 2.539 | Accuracy: 0.199219 | 0.86 sec/iter
Epoch: 40 | Batch: 014 / 029 | Total loss: 2.546 | Reg loss: 0.032 | Tree loss: 2.546 | Accuracy: 0.175781 | 0.86 sec/iter
Epoch: 40 | Batch: 015 / 029 | Total loss: 2.506 | Reg loss: 0.032 | Tree loss: 2.506 | Accuracy: 0.189453 | 0.86 sec/iter
Epoch: 40 | Batch: 016 / 029 | Total loss: 2.485 | Reg loss: 0.032 | Tree loss: 2.485 | Accuracy: 0.193359 | 0.86 sec/iter
Epoch: 40 | Batch: 017 / 029 | Total loss: 2.568 | Reg loss: 0.032 | Tree loss: 2.568 | Accuracy: 0.166016 | 0.86 sec/iter
Epoch: 40 | Batch: 018 / 029 | Total loss: 2.528 | Reg loss: 0.032 | Tree loss: 2.528 | Accuracy: 0.193359 | 0.86 sec/iter
Epoch: 40 | Batch: 019 / 029 | Total loss: 2.489 | Reg loss: 0.032 | Tree loss: 2.489 | Accuracy: 0.201172 | 0.86 sec/iter
Epoch: 40 | Batch: 020 / 029 | Total loss: 2.513 | Reg loss: 0.032 | Tree loss: 2.513 | Accuracy: 0.207031 | 0.86 sec/iter
Epoch: 40 | Batc

Epoch: 42 | Batch: 017 / 029 | Total loss: 2.403 | Reg loss: 0.032 | Tree loss: 2.403 | Accuracy: 0.218750 | 0.859 sec/iter
Epoch: 42 | Batch: 018 / 029 | Total loss: 2.475 | Reg loss: 0.032 | Tree loss: 2.475 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 42 | Batch: 019 / 029 | Total loss: 2.446 | Reg loss: 0.032 | Tree loss: 2.446 | Accuracy: 0.216797 | 0.859 sec/iter
Epoch: 42 | Batch: 020 / 029 | Total loss: 2.442 | Reg loss: 0.032 | Tree loss: 2.442 | Accuracy: 0.214844 | 0.859 sec/iter
Epoch: 42 | Batch: 021 / 029 | Total loss: 2.504 | Reg loss: 0.032 | Tree loss: 2.504 | Accuracy: 0.166016 | 0.859 sec/iter
Epoch: 42 | Batch: 022 / 029 | Total loss: 2.463 | Reg loss: 0.032 | Tree loss: 2.463 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 42 | Batch: 023 / 029 | Total loss: 2.438 | Reg loss: 0.032 | Tree loss: 2.438 | Accuracy: 0.226562 | 0.859 sec/iter
Epoch: 42 | Batch: 024 / 029 | Total loss: 2.426 | Reg loss: 0.032 | Tree loss: 2.426 | Accuracy: 0.189453 | 0.859 sec/iter
Epoch: 4

Epoch: 44 | Batch: 021 / 029 | Total loss: 2.381 | Reg loss: 0.032 | Tree loss: 2.381 | Accuracy: 0.199219 | 0.859 sec/iter
Epoch: 44 | Batch: 022 / 029 | Total loss: 2.425 | Reg loss: 0.032 | Tree loss: 2.425 | Accuracy: 0.197266 | 0.859 sec/iter
Epoch: 44 | Batch: 023 / 029 | Total loss: 2.383 | Reg loss: 0.032 | Tree loss: 2.383 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 44 | Batch: 024 / 029 | Total loss: 2.374 | Reg loss: 0.032 | Tree loss: 2.374 | Accuracy: 0.216797 | 0.859 sec/iter
Epoch: 44 | Batch: 025 / 029 | Total loss: 2.390 | Reg loss: 0.032 | Tree loss: 2.390 | Accuracy: 0.207031 | 0.859 sec/iter
Epoch: 44 | Batch: 026 / 029 | Total loss: 2.428 | Reg loss: 0.032 | Tree loss: 2.428 | Accuracy: 0.183594 | 0.859 sec/iter
Epoch: 44 | Batch: 027 / 029 | Total loss: 2.398 | Reg loss: 0.032 | Tree loss: 2.398 | Accuracy: 0.228516 | 0.859 sec/iter
Epoch: 44 | Batch: 028 / 029 | Total loss: 2.303 | Reg loss: 0.032 | Tree loss: 2.303 | Accuracy: 0.224138 | 0.859 sec/iter
Average 

Epoch: 46 | Batch: 025 / 029 | Total loss: 2.395 | Reg loss: 0.032 | Tree loss: 2.395 | Accuracy: 0.224609 | 0.859 sec/iter
Epoch: 46 | Batch: 026 / 029 | Total loss: 2.325 | Reg loss: 0.032 | Tree loss: 2.325 | Accuracy: 0.216797 | 0.859 sec/iter
Epoch: 46 | Batch: 027 / 029 | Total loss: 2.367 | Reg loss: 0.032 | Tree loss: 2.367 | Accuracy: 0.185547 | 0.859 sec/iter
Epoch: 46 | Batch: 028 / 029 | Total loss: 2.426 | Reg loss: 0.032 | Tree loss: 2.426 | Accuracy: 0.172414 | 0.859 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 47 | Batch: 000 / 029 | Total loss: 2.408 | Reg loss: 0.032 | Tree loss: 2.408 | Accuracy: 0.181641 | 0.86 sec/iter
Epoch: 47 | Batch: 001 / 029 | Total loss: 2.442 | Reg loss: 0.032 | Tree loss: 2.442 | Acc

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 49 | Batch: 000 / 029 | Total loss: 2.415 | Reg loss: 0.032 | Tree loss: 2.415 | Accuracy: 0.171875 | 0.86 sec/iter
Epoch: 49 | Batch: 001 / 029 | Total loss: 2.382 | Reg loss: 0.032 | Tree loss: 2.382 | Accuracy: 0.185547 | 0.86 sec/iter
Epoch: 49 | Batch: 002 / 029 | Total loss: 2.392 | Reg loss: 0.032 | Tree loss: 2.392 | Accuracy: 0.205078 | 0.86 sec/iter
Epoch: 49 | Batch: 003 / 029 | Total loss: 2.367 | Reg loss: 0.032 | Tree loss: 2.367 | Accuracy: 0.230469 | 0.86 sec/iter
Epoch: 49 | Batch: 004 / 029 | Total loss: 2.350 | Reg loss: 0.032 | Tree loss: 2.350 | Accuracy: 0.187500 | 0.86 sec/iter
Epoch: 49 | Batch: 005 / 029 | Total loss: 2.367 | Reg loss: 0.032 | Tree loss: 2.367 | Accurac

Epoch: 51 | Batch: 002 / 029 | Total loss: 2.347 | Reg loss: 0.031 | Tree loss: 2.347 | Accuracy: 0.218750 | 0.86 sec/iter
Epoch: 51 | Batch: 003 / 029 | Total loss: 2.385 | Reg loss: 0.031 | Tree loss: 2.385 | Accuracy: 0.162109 | 0.86 sec/iter
Epoch: 51 | Batch: 004 / 029 | Total loss: 2.391 | Reg loss: 0.031 | Tree loss: 2.391 | Accuracy: 0.199219 | 0.859 sec/iter
Epoch: 51 | Batch: 005 / 029 | Total loss: 2.354 | Reg loss: 0.031 | Tree loss: 2.354 | Accuracy: 0.193359 | 0.859 sec/iter
Epoch: 51 | Batch: 006 / 029 | Total loss: 2.369 | Reg loss: 0.031 | Tree loss: 2.369 | Accuracy: 0.189453 | 0.859 sec/iter
Epoch: 51 | Batch: 007 / 029 | Total loss: 2.350 | Reg loss: 0.031 | Tree loss: 2.350 | Accuracy: 0.187500 | 0.859 sec/iter
Epoch: 51 | Batch: 008 / 029 | Total loss: 2.377 | Reg loss: 0.031 | Tree loss: 2.377 | Accuracy: 0.207031 | 0.859 sec/iter
Epoch: 51 | Batch: 009 / 029 | Total loss: 2.326 | Reg loss: 0.031 | Tree loss: 2.326 | Accuracy: 0.224609 | 0.859 sec/iter
Epoch: 51 

Epoch: 53 | Batch: 006 / 029 | Total loss: 2.350 | Reg loss: 0.031 | Tree loss: 2.350 | Accuracy: 0.212891 | 0.859 sec/iter
Epoch: 53 | Batch: 007 / 029 | Total loss: 2.328 | Reg loss: 0.031 | Tree loss: 2.328 | Accuracy: 0.224609 | 0.859 sec/iter
Epoch: 53 | Batch: 008 / 029 | Total loss: 2.322 | Reg loss: 0.031 | Tree loss: 2.322 | Accuracy: 0.205078 | 0.859 sec/iter
Epoch: 53 | Batch: 009 / 029 | Total loss: 2.277 | Reg loss: 0.031 | Tree loss: 2.277 | Accuracy: 0.232422 | 0.859 sec/iter
Epoch: 53 | Batch: 010 / 029 | Total loss: 2.330 | Reg loss: 0.031 | Tree loss: 2.330 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 53 | Batch: 011 / 029 | Total loss: 2.339 | Reg loss: 0.031 | Tree loss: 2.339 | Accuracy: 0.207031 | 0.859 sec/iter
Epoch: 53 | Batch: 012 / 029 | Total loss: 2.320 | Reg loss: 0.031 | Tree loss: 2.320 | Accuracy: 0.222656 | 0.859 sec/iter
Epoch: 53 | Batch: 013 / 029 | Total loss: 2.303 | Reg loss: 0.031 | Tree loss: 2.303 | Accuracy: 0.189453 | 0.859 sec/iter
Epoch: 5

Epoch: 55 | Batch: 010 / 029 | Total loss: 2.282 | Reg loss: 0.031 | Tree loss: 2.282 | Accuracy: 0.181641 | 0.859 sec/iter
Epoch: 55 | Batch: 011 / 029 | Total loss: 2.267 | Reg loss: 0.031 | Tree loss: 2.267 | Accuracy: 0.189453 | 0.859 sec/iter
Epoch: 55 | Batch: 012 / 029 | Total loss: 2.289 | Reg loss: 0.031 | Tree loss: 2.289 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 55 | Batch: 013 / 029 | Total loss: 2.303 | Reg loss: 0.031 | Tree loss: 2.303 | Accuracy: 0.208984 | 0.859 sec/iter
Epoch: 55 | Batch: 014 / 029 | Total loss: 2.271 | Reg loss: 0.031 | Tree loss: 2.271 | Accuracy: 0.218750 | 0.859 sec/iter
Epoch: 55 | Batch: 015 / 029 | Total loss: 2.278 | Reg loss: 0.031 | Tree loss: 2.278 | Accuracy: 0.199219 | 0.859 sec/iter
Epoch: 55 | Batch: 016 / 029 | Total loss: 2.317 | Reg loss: 0.031 | Tree loss: 2.317 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 55 | Batch: 017 / 029 | Total loss: 2.282 | Reg loss: 0.031 | Tree loss: 2.282 | Accuracy: 0.207031 | 0.859 sec/iter
Epoch: 5

Epoch: 57 | Batch: 014 / 029 | Total loss: 2.243 | Reg loss: 0.031 | Tree loss: 2.243 | Accuracy: 0.232422 | 0.859 sec/iter
Epoch: 57 | Batch: 015 / 029 | Total loss: 2.235 | Reg loss: 0.031 | Tree loss: 2.235 | Accuracy: 0.216797 | 0.859 sec/iter
Epoch: 57 | Batch: 016 / 029 | Total loss: 2.282 | Reg loss: 0.031 | Tree loss: 2.282 | Accuracy: 0.248047 | 0.859 sec/iter
Epoch: 57 | Batch: 017 / 029 | Total loss: 2.244 | Reg loss: 0.031 | Tree loss: 2.244 | Accuracy: 0.226562 | 0.859 sec/iter
Epoch: 57 | Batch: 018 / 029 | Total loss: 2.243 | Reg loss: 0.031 | Tree loss: 2.243 | Accuracy: 0.193359 | 0.859 sec/iter
Epoch: 57 | Batch: 019 / 029 | Total loss: 2.269 | Reg loss: 0.031 | Tree loss: 2.269 | Accuracy: 0.173828 | 0.859 sec/iter
Epoch: 57 | Batch: 020 / 029 | Total loss: 2.258 | Reg loss: 0.031 | Tree loss: 2.258 | Accuracy: 0.208984 | 0.859 sec/iter
Epoch: 57 | Batch: 021 / 029 | Total loss: 2.267 | Reg loss: 0.031 | Tree loss: 2.267 | Accuracy: 0.214844 | 0.859 sec/iter
Epoch: 5

Epoch: 59 | Batch: 018 / 029 | Total loss: 2.216 | Reg loss: 0.031 | Tree loss: 2.216 | Accuracy: 0.189453 | 0.859 sec/iter
Epoch: 59 | Batch: 019 / 029 | Total loss: 2.312 | Reg loss: 0.031 | Tree loss: 2.312 | Accuracy: 0.191406 | 0.859 sec/iter
Epoch: 59 | Batch: 020 / 029 | Total loss: 2.262 | Reg loss: 0.031 | Tree loss: 2.262 | Accuracy: 0.183594 | 0.859 sec/iter
Epoch: 59 | Batch: 021 / 029 | Total loss: 2.319 | Reg loss: 0.031 | Tree loss: 2.319 | Accuracy: 0.162109 | 0.859 sec/iter
Epoch: 59 | Batch: 022 / 029 | Total loss: 2.218 | Reg loss: 0.031 | Tree loss: 2.218 | Accuracy: 0.210938 | 0.859 sec/iter
Epoch: 59 | Batch: 023 / 029 | Total loss: 2.230 | Reg loss: 0.031 | Tree loss: 2.230 | Accuracy: 0.197266 | 0.859 sec/iter
Epoch: 59 | Batch: 024 / 029 | Total loss: 2.259 | Reg loss: 0.031 | Tree loss: 2.259 | Accuracy: 0.181641 | 0.859 sec/iter
Epoch: 59 | Batch: 025 / 029 | Total loss: 2.198 | Reg loss: 0.031 | Tree loss: 2.198 | Accuracy: 0.201172 | 0.859 sec/iter
Epoch: 5

Epoch: 61 | Batch: 022 / 029 | Total loss: 2.230 | Reg loss: 0.031 | Tree loss: 2.230 | Accuracy: 0.199219 | 0.857 sec/iter
Epoch: 61 | Batch: 023 / 029 | Total loss: 2.214 | Reg loss: 0.031 | Tree loss: 2.214 | Accuracy: 0.195312 | 0.857 sec/iter
Epoch: 61 | Batch: 024 / 029 | Total loss: 2.259 | Reg loss: 0.031 | Tree loss: 2.259 | Accuracy: 0.193359 | 0.857 sec/iter
Epoch: 61 | Batch: 025 / 029 | Total loss: 2.238 | Reg loss: 0.031 | Tree loss: 2.238 | Accuracy: 0.193359 | 0.857 sec/iter
Epoch: 61 | Batch: 026 / 029 | Total loss: 2.234 | Reg loss: 0.031 | Tree loss: 2.234 | Accuracy: 0.201172 | 0.857 sec/iter
Epoch: 61 | Batch: 027 / 029 | Total loss: 2.259 | Reg loss: 0.031 | Tree loss: 2.259 | Accuracy: 0.222656 | 0.857 sec/iter
Epoch: 61 | Batch: 028 / 029 | Total loss: 2.152 | Reg loss: 0.031 | Tree loss: 2.152 | Accuracy: 0.189655 | 0.857 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 63 | Batch: 026 / 029 | Total loss: 2.211 | Reg loss: 0.031 | Tree loss: 2.211 | Accuracy: 0.201172 | 0.856 sec/iter
Epoch: 63 | Batch: 027 / 029 | Total loss: 2.178 | Reg loss: 0.031 | Tree loss: 2.178 | Accuracy: 0.203125 | 0.856 sec/iter
Epoch: 63 | Batch: 028 / 029 | Total loss: 2.156 | Reg loss: 0.031 | Tree loss: 2.156 | Accuracy: 0.241379 | 0.856 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 64 | Batch: 000 / 029 | Total loss: 2.291 | Reg loss: 0.031 | Tree loss: 2.291 | Accuracy: 0.208984 | 0.856 sec/iter
Epoch: 64 | Batch: 001 / 029 | Total loss: 2.300 | Reg loss: 0.031 | Tree loss: 2.300 | Accuracy: 0.244141 | 0.856 sec/iter
Epoch: 64 | Batch: 002 / 029 | Total loss: 2.255 | Reg loss: 0.031 | Tree loss: 2.255 | Ac

Epoch: 66 | Batch: 000 / 029 | Total loss: 2.328 | Reg loss: 0.031 | Tree loss: 2.328 | Accuracy: 0.210938 | 0.854 sec/iter
Epoch: 66 | Batch: 001 / 029 | Total loss: 2.264 | Reg loss: 0.031 | Tree loss: 2.264 | Accuracy: 0.216797 | 0.854 sec/iter
Epoch: 66 | Batch: 002 / 029 | Total loss: 2.319 | Reg loss: 0.031 | Tree loss: 2.319 | Accuracy: 0.185547 | 0.854 sec/iter
Epoch: 66 | Batch: 003 / 029 | Total loss: 2.302 | Reg loss: 0.031 | Tree loss: 2.302 | Accuracy: 0.189453 | 0.853 sec/iter
Epoch: 66 | Batch: 004 / 029 | Total loss: 2.259 | Reg loss: 0.031 | Tree loss: 2.259 | Accuracy: 0.218750 | 0.853 sec/iter
Epoch: 66 | Batch: 005 / 029 | Total loss: 2.246 | Reg loss: 0.031 | Tree loss: 2.246 | Accuracy: 0.207031 | 0.853 sec/iter
Epoch: 66 | Batch: 006 / 029 | Total loss: 2.265 | Reg loss: 0.031 | Tree loss: 2.265 | Accuracy: 0.205078 | 0.853 sec/iter
Epoch: 66 | Batch: 007 / 029 | Total loss: 2.274 | Reg loss: 0.031 | Tree loss: 2.274 | Accuracy: 0.205078 | 0.853 sec/iter
Epoch: 6

Epoch: 68 | Batch: 004 / 029 | Total loss: 2.268 | Reg loss: 0.030 | Tree loss: 2.268 | Accuracy: 0.205078 | 0.851 sec/iter
Epoch: 68 | Batch: 005 / 029 | Total loss: 2.215 | Reg loss: 0.030 | Tree loss: 2.215 | Accuracy: 0.232422 | 0.851 sec/iter
Epoch: 68 | Batch: 006 / 029 | Total loss: 2.242 | Reg loss: 0.030 | Tree loss: 2.242 | Accuracy: 0.187500 | 0.851 sec/iter
Epoch: 68 | Batch: 007 / 029 | Total loss: 2.267 | Reg loss: 0.030 | Tree loss: 2.267 | Accuracy: 0.185547 | 0.851 sec/iter
Epoch: 68 | Batch: 008 / 029 | Total loss: 2.227 | Reg loss: 0.030 | Tree loss: 2.227 | Accuracy: 0.203125 | 0.851 sec/iter
Epoch: 68 | Batch: 009 / 029 | Total loss: 2.214 | Reg loss: 0.030 | Tree loss: 2.214 | Accuracy: 0.224609 | 0.851 sec/iter
Epoch: 68 | Batch: 010 / 029 | Total loss: 2.222 | Reg loss: 0.031 | Tree loss: 2.222 | Accuracy: 0.208984 | 0.851 sec/iter
Epoch: 68 | Batch: 011 / 029 | Total loss: 2.241 | Reg loss: 0.031 | Tree loss: 2.241 | Accuracy: 0.214844 | 0.851 sec/iter
Epoch: 6

Epoch: 70 | Batch: 008 / 029 | Total loss: 2.192 | Reg loss: 0.030 | Tree loss: 2.192 | Accuracy: 0.214844 | 0.85 sec/iter
Epoch: 70 | Batch: 009 / 029 | Total loss: 2.213 | Reg loss: 0.030 | Tree loss: 2.213 | Accuracy: 0.181641 | 0.85 sec/iter
Epoch: 70 | Batch: 010 / 029 | Total loss: 2.227 | Reg loss: 0.030 | Tree loss: 2.227 | Accuracy: 0.191406 | 0.85 sec/iter
Epoch: 70 | Batch: 011 / 029 | Total loss: 2.217 | Reg loss: 0.030 | Tree loss: 2.217 | Accuracy: 0.226562 | 0.85 sec/iter
Epoch: 70 | Batch: 012 / 029 | Total loss: 2.200 | Reg loss: 0.030 | Tree loss: 2.200 | Accuracy: 0.236328 | 0.85 sec/iter
Epoch: 70 | Batch: 013 / 029 | Total loss: 2.230 | Reg loss: 0.031 | Tree loss: 2.230 | Accuracy: 0.195312 | 0.85 sec/iter
Epoch: 70 | Batch: 014 / 029 | Total loss: 2.232 | Reg loss: 0.031 | Tree loss: 2.232 | Accuracy: 0.207031 | 0.85 sec/iter
Epoch: 70 | Batch: 015 / 029 | Total loss: 2.252 | Reg loss: 0.031 | Tree loss: 2.252 | Accuracy: 0.166016 | 0.85 sec/iter
Epoch: 70 | Batc

Epoch: 72 | Batch: 012 / 029 | Total loss: 2.210 | Reg loss: 0.030 | Tree loss: 2.210 | Accuracy: 0.230469 | 0.848 sec/iter
Epoch: 72 | Batch: 013 / 029 | Total loss: 2.227 | Reg loss: 0.030 | Tree loss: 2.227 | Accuracy: 0.234375 | 0.848 sec/iter
Epoch: 72 | Batch: 014 / 029 | Total loss: 2.201 | Reg loss: 0.031 | Tree loss: 2.201 | Accuracy: 0.205078 | 0.848 sec/iter
Epoch: 72 | Batch: 015 / 029 | Total loss: 2.190 | Reg loss: 0.031 | Tree loss: 2.190 | Accuracy: 0.183594 | 0.848 sec/iter
Epoch: 72 | Batch: 016 / 029 | Total loss: 2.159 | Reg loss: 0.031 | Tree loss: 2.159 | Accuracy: 0.220703 | 0.848 sec/iter
Epoch: 72 | Batch: 017 / 029 | Total loss: 2.214 | Reg loss: 0.031 | Tree loss: 2.214 | Accuracy: 0.197266 | 0.848 sec/iter
Epoch: 72 | Batch: 018 / 029 | Total loss: 2.164 | Reg loss: 0.031 | Tree loss: 2.164 | Accuracy: 0.191406 | 0.848 sec/iter
Epoch: 72 | Batch: 019 / 029 | Total loss: 2.184 | Reg loss: 0.031 | Tree loss: 2.184 | Accuracy: 0.205078 | 0.848 sec/iter
Epoch: 7

Epoch: 74 | Batch: 016 / 029 | Total loss: 2.223 | Reg loss: 0.030 | Tree loss: 2.223 | Accuracy: 0.207031 | 0.846 sec/iter
Epoch: 74 | Batch: 017 / 029 | Total loss: 2.203 | Reg loss: 0.031 | Tree loss: 2.203 | Accuracy: 0.218750 | 0.846 sec/iter
Epoch: 74 | Batch: 018 / 029 | Total loss: 2.179 | Reg loss: 0.031 | Tree loss: 2.179 | Accuracy: 0.193359 | 0.846 sec/iter
Epoch: 74 | Batch: 019 / 029 | Total loss: 2.181 | Reg loss: 0.031 | Tree loss: 2.181 | Accuracy: 0.214844 | 0.846 sec/iter
Epoch: 74 | Batch: 020 / 029 | Total loss: 2.184 | Reg loss: 0.031 | Tree loss: 2.184 | Accuracy: 0.201172 | 0.846 sec/iter
Epoch: 74 | Batch: 021 / 029 | Total loss: 2.178 | Reg loss: 0.031 | Tree loss: 2.178 | Accuracy: 0.179688 | 0.846 sec/iter
Epoch: 74 | Batch: 022 / 029 | Total loss: 2.201 | Reg loss: 0.031 | Tree loss: 2.201 | Accuracy: 0.214844 | 0.846 sec/iter
Epoch: 74 | Batch: 023 / 029 | Total loss: 2.223 | Reg loss: 0.031 | Tree loss: 2.223 | Accuracy: 0.164062 | 0.846 sec/iter
Epoch: 7

Epoch: 76 | Batch: 020 / 029 | Total loss: 2.182 | Reg loss: 0.031 | Tree loss: 2.182 | Accuracy: 0.212891 | 0.844 sec/iter
Epoch: 76 | Batch: 021 / 029 | Total loss: 2.165 | Reg loss: 0.031 | Tree loss: 2.165 | Accuracy: 0.203125 | 0.844 sec/iter
Epoch: 76 | Batch: 022 / 029 | Total loss: 2.172 | Reg loss: 0.031 | Tree loss: 2.172 | Accuracy: 0.228516 | 0.844 sec/iter
Epoch: 76 | Batch: 023 / 029 | Total loss: 2.154 | Reg loss: 0.031 | Tree loss: 2.154 | Accuracy: 0.216797 | 0.844 sec/iter
Epoch: 76 | Batch: 024 / 029 | Total loss: 2.166 | Reg loss: 0.031 | Tree loss: 2.166 | Accuracy: 0.228516 | 0.844 sec/iter
Epoch: 76 | Batch: 025 / 029 | Total loss: 2.170 | Reg loss: 0.031 | Tree loss: 2.170 | Accuracy: 0.195312 | 0.844 sec/iter
Epoch: 76 | Batch: 026 / 029 | Total loss: 2.133 | Reg loss: 0.031 | Tree loss: 2.133 | Accuracy: 0.195312 | 0.844 sec/iter
Epoch: 76 | Batch: 027 / 029 | Total loss: 2.145 | Reg loss: 0.031 | Tree loss: 2.145 | Accuracy: 0.167969 | 0.844 sec/iter
Epoch: 7

Epoch: 78 | Batch: 024 / 029 | Total loss: 2.163 | Reg loss: 0.031 | Tree loss: 2.163 | Accuracy: 0.191406 | 0.843 sec/iter
Epoch: 78 | Batch: 025 / 029 | Total loss: 2.105 | Reg loss: 0.031 | Tree loss: 2.105 | Accuracy: 0.236328 | 0.843 sec/iter
Epoch: 78 | Batch: 026 / 029 | Total loss: 2.161 | Reg loss: 0.031 | Tree loss: 2.161 | Accuracy: 0.207031 | 0.843 sec/iter
Epoch: 78 | Batch: 027 / 029 | Total loss: 2.146 | Reg loss: 0.031 | Tree loss: 2.146 | Accuracy: 0.214844 | 0.843 sec/iter
Epoch: 78 | Batch: 028 / 029 | Total loss: 2.047 | Reg loss: 0.031 | Tree loss: 2.047 | Accuracy: 0.310345 | 0.843 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 79 | Batch: 000 / 029 | Total loss: 2.299 | Reg loss: 0.030 | Tree loss: 2.299 | Ac

Epoch: 80 | Batch: 028 / 029 | Total loss: 2.228 | Reg loss: 0.031 | Tree loss: 2.228 | Accuracy: 0.172414 | 0.841 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 81 | Batch: 000 / 029 | Total loss: 2.249 | Reg loss: 0.030 | Tree loss: 2.249 | Accuracy: 0.216797 | 0.841 sec/iter
Epoch: 81 | Batch: 001 / 029 | Total loss: 2.253 | Reg loss: 0.030 | Tree loss: 2.253 | Accuracy: 0.199219 | 0.841 sec/iter
Epoch: 81 | Batch: 002 / 029 | Total loss: 2.289 | Reg loss: 0.030 | Tree loss: 2.289 | Accuracy: 0.203125 | 0.841 sec/iter
Epoch: 81 | Batch: 003 / 029 | Total loss: 2.232 | Reg loss: 0.030 | Tree loss: 2.232 | Accuracy: 0.195312 | 0.841 sec/iter
Epoch: 81 | Batch: 004 / 029 | Total loss: 2.229 | Reg loss: 0.030 | Tree loss: 2.229 | Ac

Epoch: 83 | Batch: 001 / 029 | Total loss: 2.259 | Reg loss: 0.030 | Tree loss: 2.259 | Accuracy: 0.208984 | 0.84 sec/iter
Epoch: 83 | Batch: 002 / 029 | Total loss: 2.222 | Reg loss: 0.030 | Tree loss: 2.222 | Accuracy: 0.218750 | 0.84 sec/iter
Epoch: 83 | Batch: 003 / 029 | Total loss: 2.231 | Reg loss: 0.030 | Tree loss: 2.231 | Accuracy: 0.203125 | 0.84 sec/iter
Epoch: 83 | Batch: 004 / 029 | Total loss: 2.222 | Reg loss: 0.030 | Tree loss: 2.222 | Accuracy: 0.210938 | 0.84 sec/iter
Epoch: 83 | Batch: 005 / 029 | Total loss: 2.196 | Reg loss: 0.030 | Tree loss: 2.196 | Accuracy: 0.208984 | 0.84 sec/iter
Epoch: 83 | Batch: 006 / 029 | Total loss: 2.215 | Reg loss: 0.030 | Tree loss: 2.215 | Accuracy: 0.220703 | 0.84 sec/iter
Epoch: 83 | Batch: 007 / 029 | Total loss: 2.234 | Reg loss: 0.030 | Tree loss: 2.234 | Accuracy: 0.218750 | 0.84 sec/iter
Epoch: 83 | Batch: 008 / 029 | Total loss: 2.190 | Reg loss: 0.030 | Tree loss: 2.190 | Accuracy: 0.197266 | 0.84 sec/iter
Epoch: 83 | Batc

Epoch: 85 | Batch: 005 / 029 | Total loss: 2.212 | Reg loss: 0.030 | Tree loss: 2.212 | Accuracy: 0.201172 | 0.839 sec/iter
Epoch: 85 | Batch: 006 / 029 | Total loss: 2.246 | Reg loss: 0.030 | Tree loss: 2.246 | Accuracy: 0.203125 | 0.839 sec/iter
Epoch: 85 | Batch: 007 / 029 | Total loss: 2.315 | Reg loss: 0.030 | Tree loss: 2.315 | Accuracy: 0.183594 | 0.839 sec/iter
Epoch: 85 | Batch: 008 / 029 | Total loss: 2.218 | Reg loss: 0.030 | Tree loss: 2.218 | Accuracy: 0.226562 | 0.839 sec/iter
Epoch: 85 | Batch: 009 / 029 | Total loss: 2.225 | Reg loss: 0.030 | Tree loss: 2.225 | Accuracy: 0.207031 | 0.838 sec/iter
Epoch: 85 | Batch: 010 / 029 | Total loss: 2.243 | Reg loss: 0.030 | Tree loss: 2.243 | Accuracy: 0.212891 | 0.838 sec/iter
Epoch: 85 | Batch: 011 / 029 | Total loss: 2.176 | Reg loss: 0.030 | Tree loss: 2.176 | Accuracy: 0.214844 | 0.838 sec/iter
Epoch: 85 | Batch: 012 / 029 | Total loss: 2.232 | Reg loss: 0.030 | Tree loss: 2.232 | Accuracy: 0.208984 | 0.838 sec/iter
Epoch: 8

Epoch: 87 | Batch: 009 / 029 | Total loss: 2.207 | Reg loss: 0.030 | Tree loss: 2.207 | Accuracy: 0.187500 | 0.837 sec/iter
Epoch: 87 | Batch: 010 / 029 | Total loss: 2.236 | Reg loss: 0.030 | Tree loss: 2.236 | Accuracy: 0.189453 | 0.837 sec/iter
Epoch: 87 | Batch: 011 / 029 | Total loss: 2.177 | Reg loss: 0.030 | Tree loss: 2.177 | Accuracy: 0.195312 | 0.837 sec/iter
Epoch: 87 | Batch: 012 / 029 | Total loss: 2.212 | Reg loss: 0.030 | Tree loss: 2.212 | Accuracy: 0.208984 | 0.837 sec/iter
Epoch: 87 | Batch: 013 / 029 | Total loss: 2.211 | Reg loss: 0.030 | Tree loss: 2.211 | Accuracy: 0.171875 | 0.837 sec/iter
Epoch: 87 | Batch: 014 / 029 | Total loss: 2.188 | Reg loss: 0.030 | Tree loss: 2.188 | Accuracy: 0.175781 | 0.837 sec/iter
Epoch: 87 | Batch: 015 / 029 | Total loss: 2.191 | Reg loss: 0.030 | Tree loss: 2.191 | Accuracy: 0.205078 | 0.837 sec/iter
Epoch: 87 | Batch: 016 / 029 | Total loss: 2.186 | Reg loss: 0.030 | Tree loss: 2.186 | Accuracy: 0.220703 | 0.837 sec/iter
Epoch: 8

Epoch: 89 | Batch: 013 / 029 | Total loss: 2.228 | Reg loss: 0.030 | Tree loss: 2.228 | Accuracy: 0.205078 | 0.836 sec/iter
Epoch: 89 | Batch: 014 / 029 | Total loss: 2.143 | Reg loss: 0.030 | Tree loss: 2.143 | Accuracy: 0.199219 | 0.836 sec/iter
Epoch: 89 | Batch: 015 / 029 | Total loss: 2.180 | Reg loss: 0.030 | Tree loss: 2.180 | Accuracy: 0.222656 | 0.836 sec/iter
Epoch: 89 | Batch: 016 / 029 | Total loss: 2.165 | Reg loss: 0.030 | Tree loss: 2.165 | Accuracy: 0.197266 | 0.836 sec/iter
Epoch: 89 | Batch: 017 / 029 | Total loss: 2.177 | Reg loss: 0.030 | Tree loss: 2.177 | Accuracy: 0.189453 | 0.836 sec/iter
Epoch: 89 | Batch: 018 / 029 | Total loss: 2.228 | Reg loss: 0.030 | Tree loss: 2.228 | Accuracy: 0.177734 | 0.836 sec/iter
Epoch: 89 | Batch: 019 / 029 | Total loss: 2.139 | Reg loss: 0.030 | Tree loss: 2.139 | Accuracy: 0.191406 | 0.836 sec/iter
Epoch: 89 | Batch: 020 / 029 | Total loss: 2.177 | Reg loss: 0.030 | Tree loss: 2.177 | Accuracy: 0.210938 | 0.836 sec/iter
Epoch: 8

Epoch: 91 | Batch: 017 / 029 | Total loss: 2.161 | Reg loss: 0.030 | Tree loss: 2.161 | Accuracy: 0.214844 | 0.834 sec/iter
Epoch: 91 | Batch: 018 / 029 | Total loss: 2.129 | Reg loss: 0.030 | Tree loss: 2.129 | Accuracy: 0.208984 | 0.834 sec/iter
Epoch: 91 | Batch: 019 / 029 | Total loss: 2.124 | Reg loss: 0.030 | Tree loss: 2.124 | Accuracy: 0.244141 | 0.834 sec/iter
Epoch: 91 | Batch: 020 / 029 | Total loss: 2.139 | Reg loss: 0.030 | Tree loss: 2.139 | Accuracy: 0.207031 | 0.834 sec/iter
Epoch: 91 | Batch: 021 / 029 | Total loss: 2.152 | Reg loss: 0.030 | Tree loss: 2.152 | Accuracy: 0.193359 | 0.834 sec/iter
Epoch: 91 | Batch: 022 / 029 | Total loss: 2.175 | Reg loss: 0.030 | Tree loss: 2.175 | Accuracy: 0.189453 | 0.834 sec/iter
Epoch: 91 | Batch: 023 / 029 | Total loss: 2.136 | Reg loss: 0.030 | Tree loss: 2.136 | Accuracy: 0.208984 | 0.834 sec/iter
Epoch: 91 | Batch: 024 / 029 | Total loss: 2.181 | Reg loss: 0.030 | Tree loss: 2.181 | Accuracy: 0.201172 | 0.834 sec/iter
Epoch: 9

Epoch: 93 | Batch: 021 / 029 | Total loss: 2.144 | Reg loss: 0.030 | Tree loss: 2.144 | Accuracy: 0.212891 | 0.833 sec/iter
Epoch: 93 | Batch: 022 / 029 | Total loss: 2.130 | Reg loss: 0.030 | Tree loss: 2.130 | Accuracy: 0.244141 | 0.833 sec/iter
Epoch: 93 | Batch: 023 / 029 | Total loss: 2.118 | Reg loss: 0.030 | Tree loss: 2.118 | Accuracy: 0.224609 | 0.833 sec/iter
Epoch: 93 | Batch: 024 / 029 | Total loss: 2.154 | Reg loss: 0.030 | Tree loss: 2.154 | Accuracy: 0.179688 | 0.833 sec/iter
Epoch: 93 | Batch: 025 / 029 | Total loss: 2.112 | Reg loss: 0.030 | Tree loss: 2.112 | Accuracy: 0.214844 | 0.833 sec/iter
Epoch: 93 | Batch: 026 / 029 | Total loss: 2.175 | Reg loss: 0.030 | Tree loss: 2.175 | Accuracy: 0.199219 | 0.832 sec/iter
Epoch: 93 | Batch: 027 / 029 | Total loss: 2.179 | Reg loss: 0.030 | Tree loss: 2.179 | Accuracy: 0.185547 | 0.832 sec/iter
Epoch: 93 | Batch: 028 / 029 | Total loss: 2.145 | Reg loss: 0.030 | Tree loss: 2.145 | Accuracy: 0.206897 | 0.832 sec/iter
Average 

Epoch: 95 | Batch: 025 / 029 | Total loss: 2.148 | Reg loss: 0.030 | Tree loss: 2.148 | Accuracy: 0.179688 | 0.831 sec/iter
Epoch: 95 | Batch: 026 / 029 | Total loss: 2.130 | Reg loss: 0.030 | Tree loss: 2.130 | Accuracy: 0.228516 | 0.831 sec/iter
Epoch: 95 | Batch: 027 / 029 | Total loss: 2.137 | Reg loss: 0.030 | Tree loss: 2.137 | Accuracy: 0.189453 | 0.831 sec/iter
Epoch: 95 | Batch: 028 / 029 | Total loss: 2.106 | Reg loss: 0.030 | Tree loss: 2.106 | Accuracy: 0.206897 | 0.831 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 96 | Batch: 000 / 029 | Total loss: 2.241 | Reg loss: 0.029 | Tree loss: 2.241 | Accuracy: 0.205078 | 0.831 sec/iter
Epoch: 96 | Batch: 001 / 029 | Total loss: 2.257 | Reg loss: 0.029 | Tree loss: 2.257 | Ac

Epoch: 98 | Batch: 000 / 029 | Total loss: 2.238 | Reg loss: 0.029 | Tree loss: 2.238 | Accuracy: 0.210938 | 0.829 sec/iter
Epoch: 98 | Batch: 001 / 029 | Total loss: 2.290 | Reg loss: 0.029 | Tree loss: 2.290 | Accuracy: 0.177734 | 0.829 sec/iter
Epoch: 98 | Batch: 002 / 029 | Total loss: 2.235 | Reg loss: 0.029 | Tree loss: 2.235 | Accuracy: 0.199219 | 0.829 sec/iter
Epoch: 98 | Batch: 003 / 029 | Total loss: 2.235 | Reg loss: 0.029 | Tree loss: 2.235 | Accuracy: 0.226562 | 0.829 sec/iter
Epoch: 98 | Batch: 004 / 029 | Total loss: 2.217 | Reg loss: 0.029 | Tree loss: 2.217 | Accuracy: 0.199219 | 0.829 sec/iter
Epoch: 98 | Batch: 005 / 029 | Total loss: 2.212 | Reg loss: 0.029 | Tree loss: 2.212 | Accuracy: 0.214844 | 0.829 sec/iter
Epoch: 98 | Batch: 006 / 029 | Total loss: 2.232 | Reg loss: 0.029 | Tree loss: 2.232 | Accuracy: 0.195312 | 0.829 sec/iter
Epoch: 98 | Batch: 007 / 029 | Total loss: 2.207 | Reg loss: 0.029 | Tree loss: 2.207 | Accuracy: 0.208984 | 0.828 sec/iter
Epoch: 9

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 9.895196506550219


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 916


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


13420
974


Average comprehensibility: 51.6768558951965
std comprehensibility: 4.179429421497326
var comprehensibility: 17.467630289277473
minimum comprehensibility: 36
maximum comprehensibility: 58
