In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 64
tree_depth = 12
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.212848663330078 | KNN Loss: 6.229985237121582 | BCE Loss: 1.9828636646270752
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.168519973754883 | KNN Loss: 6.229908466339111 | BCE Loss: 1.9386110305786133
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.14419937133789 | KNN Loss: 6.229614734649658 | BCE Loss: 1.9145846366882324
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.14875316619873 | KNN Loss: 6.229379653930664 | BCE Loss: 1.919373631477356
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.138741493225098 | KNN Loss: 6.229327201843262 | BCE Loss: 1.9094146490097046
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.163969039916992 | KNN Loss: 6.229116439819336 | BCE Loss: 1.9348526000976562
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.142416000366211 | KNN Loss: 6.228728771209717 | BCE Loss: 1.9136877059936523
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.133009910583496 | KNN Loss: 6.228733062744141 | BCE Loss: 1.90427720

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.1136908531188965 | KNN Loss: 6.006924152374268 | BCE Loss: 1.106766700744629
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.062594890594482 | KNN Loss: 5.958533763885498 | BCE Loss: 1.1040611267089844
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.014702796936035 | KNN Loss: 5.912962436676025 | BCE Loss: 1.1017404794692993
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.985175132751465 | KNN Loss: 5.881042003631592 | BCE Loss: 1.104133129119873
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.982003211975098 | KNN Loss: 5.843730926513672 | BCE Loss: 1.1382720470428467
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.8553266525268555 | KNN Loss: 5.777772426605225 | BCE Loss: 1.0775541067123413
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.823657989501953 | KNN Loss: 5.716406345367432 | BCE Loss: 1.1072516441345215
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.77890157699585 | KNN Loss: 5.681288242340088 | BCE Loss:

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 5.507378101348877 | KNN Loss: 4.423064231872559 | BCE Loss: 1.084313988685608
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 5.532404899597168 | KNN Loss: 4.467677116394043 | BCE Loss: 1.064727544784546
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 5.539498329162598 | KNN Loss: 4.463871002197266 | BCE Loss: 1.0756272077560425
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 5.464534759521484 | KNN Loss: 4.4394354820251465 | BCE Loss: 1.0250990390777588
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 5.4879913330078125 | KNN Loss: 4.418454170227051 | BCE Loss: 1.0695372819900513
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 5.512447357177734 | KNN Loss: 4.455063343048096 | BCE Loss: 1.0573842525482178
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 5.4826130867004395 | KNN Loss: 4.447882175445557 | BCE Loss: 1.0347310304641724
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 5.572408199310303 | KNN Loss: 4.486761569976807 | BCE Lo

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 5.467789649963379 | KNN Loss: 4.427908420562744 | BCE Loss: 1.0398812294006348
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 5.515353679656982 | KNN Loss: 4.440191745758057 | BCE Loss: 1.0751619338989258
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 5.499302864074707 | KNN Loss: 4.43162727355957 | BCE Loss: 1.0676755905151367
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 5.4820556640625 | KNN Loss: 4.421136379241943 | BCE Loss: 1.060919165611267
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 5.483309268951416 | KNN Loss: 4.434974193572998 | BCE Loss: 1.048335075378418
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 5.489395618438721 | KNN Loss: 4.430936813354492 | BCE Loss: 1.058458924293518
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 5.49208402633667 | KNN Loss: 4.427916526794434 | BCE Loss: 1.0641676187515259
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 5.442554473876953 | KNN Loss: 4.402794361114502 | BCE Loss: 1.03

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 5.453668594360352 | KNN Loss: 4.428226947784424 | BCE Loss: 1.0254415273666382
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 5.423710823059082 | KNN Loss: 4.404366493225098 | BCE Loss: 1.0193440914154053
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 5.475948333740234 | KNN Loss: 4.4180426597595215 | BCE Loss: 1.057905673980713
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 5.454903602600098 | KNN Loss: 4.396880626678467 | BCE Loss: 1.0580229759216309
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 5.440617561340332 | KNN Loss: 4.386813640594482 | BCE Loss: 1.05380380153656
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 5.4464497566223145 | KNN Loss: 4.408674716949463 | BCE Loss: 1.037774920463562
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 5.449211120605469 | KNN Loss: 4.405559062957764 | BCE Loss: 1.043652057647705
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 5.44399356842041 | KNN Loss: 4.415781497955322 | BCE Loss: 1

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 5.434545516967773 | KNN Loss: 4.3978753089904785 | BCE Loss: 1.036670207977295
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 5.446662902832031 | KNN Loss: 4.389941692352295 | BCE Loss: 1.0567212104797363
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 5.4409284591674805 | KNN Loss: 4.405837059020996 | BCE Loss: 1.0350916385650635
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 5.483055591583252 | KNN Loss: 4.4302496910095215 | BCE Loss: 1.05280601978302
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 5.443774223327637 | KNN Loss: 4.407904624938965 | BCE Loss: 1.0358693599700928
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 5.478604793548584 | KNN Loss: 4.426198959350586 | BCE Loss: 1.0524059534072876
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 5.390327453613281 | KNN Loss: 4.3837103843688965 | BCE Loss: 1.0066173076629639
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 5.435505390167236 | KNN Loss: 4.40519905090332 | BCE Loss

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 5.482130527496338 | KNN Loss: 4.4495720863342285 | BCE Loss: 1.0325583219528198
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 5.416349411010742 | KNN Loss: 4.391154766082764 | BCE Loss: 1.0251948833465576
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 5.40132474899292 | KNN Loss: 4.3675537109375 | BCE Loss: 1.0337709188461304
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 5.4426188468933105 | KNN Loss: 4.402673244476318 | BCE Loss: 1.0399457216262817
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 5.425076961517334 | KNN Loss: 4.366964340209961 | BCE Loss: 1.0581125020980835
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 5.445033073425293 | KNN Loss: 4.4020562171936035 | BCE Loss: 1.0429770946502686
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 5.413837432861328 | KNN Loss: 4.368912696838379 | BCE Loss: 1.0449244976043701
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 5.414743423461914 | KNN Loss: 4.385481834411621 | BCE Loss:

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 5.430509567260742 | KNN Loss: 4.407435894012451 | BCE Loss: 1.023073673248291
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 5.410862445831299 | KNN Loss: 4.379970550537109 | BCE Loss: 1.0308917760849
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 5.372937202453613 | KNN Loss: 4.369564056396484 | BCE Loss: 1.0033729076385498
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 5.407941818237305 | KNN Loss: 4.3928070068359375 | BCE Loss: 1.015134572982788
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 5.412952899932861 | KNN Loss: 4.3678388595581055 | BCE Loss: 1.0451140403747559
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 5.430639266967773 | KNN Loss: 4.390250205993652 | BCE Loss: 1.0403889417648315
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 5.407138824462891 | KNN Loss: 4.386468887329102 | BCE Loss: 1.0206701755523682
Epoch 77 / 500 | iteration 0 / 30 | Total Loss: 5.411380767822266 | KNN Loss: 4.355388641357422 | BCE Loss: 1

Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 5.400032997131348 | KNN Loss: 4.377598285675049 | BCE Loss: 1.0224347114562988
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 5.382449626922607 | KNN Loss: 4.363124847412109 | BCE Loss: 1.019324779510498
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 5.414988040924072 | KNN Loss: 4.399794101715088 | BCE Loss: 1.015194058418274
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 5.438265800476074 | KNN Loss: 4.3769989013671875 | BCE Loss: 1.0612666606903076
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 5.432356834411621 | KNN Loss: 4.434698104858398 | BCE Loss: 0.9976584911346436
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 5.386996269226074 | KNN Loss: 4.361823558807373 | BCE Loss: 1.0251728296279907
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 5.396116256713867 | KNN Loss: 4.374709129333496 | BCE Loss: 1.0214072465896606
Epoch 87 / 500 | iteration 25 / 30 | Total Loss: 5.388461589813232 | KNN Loss: 4.357341766357422 | BCE Loss

Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 5.3568806648254395 | KNN Loss: 4.362505912780762 | BCE Loss: 0.9943749308586121
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 5.423254013061523 | KNN Loss: 4.384307384490967 | BCE Loss: 1.0389467477798462
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 5.373359203338623 | KNN Loss: 4.3711724281311035 | BCE Loss: 1.002186894416809
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 5.384055137634277 | KNN Loss: 4.355935096740723 | BCE Loss: 1.0281202793121338
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 5.454859733581543 | KNN Loss: 4.401942253112793 | BCE Loss: 1.05291748046875
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 5.398942947387695 | KNN Loss: 4.381374359130859 | BCE Loss: 1.017568826675415
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 5.4126410484313965 | KNN Loss: 4.365908145904541 | BCE Loss: 1.046732783317566
Epoch 98 / 500 | iteration 15 / 30 | Total Loss: 5.38692569732666 | KNN Loss: 4.355402946472168 | BCE Loss: 

Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 5.371675491333008 | KNN Loss: 4.346179485321045 | BCE Loss: 1.0254961252212524
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 5.365689754486084 | KNN Loss: 4.352165699005127 | BCE Loss: 1.0135241746902466
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 5.394217491149902 | KNN Loss: 4.344084739685059 | BCE Loss: 1.0501327514648438
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 5.407317161560059 | KNN Loss: 4.36619234085083 | BCE Loss: 1.041124939918518
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 5.41560173034668 | KNN Loss: 4.379758834838867 | BCE Loss: 1.0358431339263916
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 5.345952987670898 | KNN Loss: 4.336705684661865 | BCE Loss: 1.0092475414276123
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 5.355257034301758 | KNN Loss: 4.335052967071533 | BCE Loss: 1.0202040672302246
Epoch 109 / 500 | iteration 5 / 30 | Total Loss: 5.379927635192871 | KNN Loss: 4.373350620269775 | BCE 

Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 5.3619585037231445 | KNN Loss: 4.334133625030518 | BCE Loss: 1.027824878692627
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 5.364171028137207 | KNN Loss: 4.351348876953125 | BCE Loss: 1.012822151184082
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 5.396387577056885 | KNN Loss: 4.346325397491455 | BCE Loss: 1.0500621795654297
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 5.381161212921143 | KNN Loss: 4.363773822784424 | BCE Loss: 1.0173872709274292
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 5.381629467010498 | KNN Loss: 4.378503322601318 | BCE Loss: 1.0031261444091797
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 5.409842014312744 | KNN Loss: 4.380771636962891 | BCE Loss: 1.029070496559143
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 5.420607566833496 | KNN Loss: 4.399882793426514 | BCE Loss: 1.0207245349884033
Epoch 119 / 500 | iteration 25 / 30 | Total Loss: 5.368844032287598 | KNN Loss: 4.324952125549316 | B

Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 5.387449264526367 | KNN Loss: 4.358529567718506 | BCE Loss: 1.0289195775985718
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 5.385261535644531 | KNN Loss: 4.359062671661377 | BCE Loss: 1.0261991024017334
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 5.379024028778076 | KNN Loss: 4.3583598136901855 | BCE Loss: 1.0206643342971802
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 5.367504119873047 | KNN Loss: 4.338450908660889 | BCE Loss: 1.0290533304214478
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 5.407863140106201 | KNN Loss: 4.360833644866943 | BCE Loss: 1.0470296144485474
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 5.4156365394592285 | KNN Loss: 4.346874713897705 | BCE Loss: 1.0687617063522339
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 5.3751139640808105 | KNN Loss: 4.358725547790527 | BCE Loss: 1.0163884162902832
Epoch 130 / 500 | iteration 15 / 30 | Total Loss: 5.425351142883301 | KNN Loss: 4.36621332168579

Epoch   140: reducing learning rate of group 0 to 1.2005e-03.
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 5.3478217124938965 | KNN Loss: 4.334627628326416 | BCE Loss: 1.0131940841674805
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 5.357454299926758 | KNN Loss: 4.319551944732666 | BCE Loss: 1.0379023551940918
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 5.425537109375 | KNN Loss: 4.387564182281494 | BCE Loss: 1.0379729270935059
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 5.379631996154785 | KNN Loss: 4.346367835998535 | BCE Loss: 1.0332640409469604
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 5.392107009887695 | KNN Loss: 4.335211277008057 | BCE Loss: 1.0568957328796387
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 5.354142189025879 | KNN Loss: 4.360055446624756 | BCE Loss: 0.9940869212150574
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 5.344489097595215 | KNN Loss: 4.339416980743408 | BCE Loss: 1.0050718784332275
Epoch 141 / 500 | iteration 5 / 30 | Tot

Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 5.3416829109191895 | KNN Loss: 4.344029426574707 | BCE Loss: 0.9976534843444824
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 5.351125240325928 | KNN Loss: 4.360593795776367 | BCE Loss: 0.9905316233634949
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 5.385076522827148 | KNN Loss: 4.356515407562256 | BCE Loss: 1.028560996055603
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 5.3765106201171875 | KNN Loss: 4.341403484344482 | BCE Loss: 1.035107135772705
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 5.383762359619141 | KNN Loss: 4.349235534667969 | BCE Loss: 1.0345265865325928
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 5.393272399902344 | KNN Loss: 4.374727725982666 | BCE Loss: 1.0185449123382568
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 5.359528541564941 | KNN Loss: 4.335381507873535 | BCE Loss: 1.0241470336914062
Epoch 151 / 500 | iteration 25 / 30 | Total Loss: 5.388387680053711 | KNN Loss: 4.39791202545166 | 

Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 5.399094104766846 | KNN Loss: 4.341405868530273 | BCE Loss: 1.0576882362365723
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 5.378274917602539 | KNN Loss: 4.356893062591553 | BCE Loss: 1.0213820934295654
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 5.378725051879883 | KNN Loss: 4.3676438331604 | BCE Loss: 1.0110814571380615
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 5.351779937744141 | KNN Loss: 4.35300874710083 | BCE Loss: 0.9987711906433105
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 5.337748050689697 | KNN Loss: 4.316977024078369 | BCE Loss: 1.0207711458206177
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 5.366402626037598 | KNN Loss: 4.3359150886535645 | BCE Loss: 1.0304875373840332
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 5.371758937835693 | KNN Loss: 4.321863651275635 | BCE Loss: 1.0498952865600586
Epoch 162 / 500 | iteration 15 / 30 | Total Loss: 5.368221759796143 | KNN Loss: 4.346579074859619 | B

Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 5.388667583465576 | KNN Loss: 4.359400272369385 | BCE Loss: 1.0292671918869019
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 5.357452392578125 | KNN Loss: 4.33066463470459 | BCE Loss: 1.026787519454956
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 5.384127616882324 | KNN Loss: 4.349640369415283 | BCE Loss: 1.0344874858856201
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 5.387388706207275 | KNN Loss: 4.34055757522583 | BCE Loss: 1.0468310117721558
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 5.390714645385742 | KNN Loss: 4.356607437133789 | BCE Loss: 1.0341072082519531
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 5.344223976135254 | KNN Loss: 4.324267864227295 | BCE Loss: 1.0199558734893799
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 5.379513263702393 | KNN Loss: 4.343180179595947 | BCE Loss: 1.0363332033157349
Epoch 173 / 500 | iteration 5 / 30 | Total Loss: 5.391201972961426 | KNN Loss: 4.361751556396484 | BCE 

Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 5.375922203063965 | KNN Loss: 4.3643927574157715 | BCE Loss: 1.0115294456481934
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 5.369171619415283 | KNN Loss: 4.338595867156982 | BCE Loss: 1.0305757522583008
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 5.336714267730713 | KNN Loss: 4.325250625610352 | BCE Loss: 1.0114637613296509
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 5.358863353729248 | KNN Loss: 4.368281364440918 | BCE Loss: 0.9905818104743958
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 5.396547317504883 | KNN Loss: 4.35877799987793 | BCE Loss: 1.0377695560455322
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 5.370776176452637 | KNN Loss: 4.341398239135742 | BCE Loss: 1.0293781757354736
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 5.366744518280029 | KNN Loss: 4.36600399017334 | BCE Loss: 1.000740647315979
Epoch 183 / 500 | iteration 25 / 30 | Total Loss: 5.329010486602783 | KNN Loss: 4.330703258514404 | B

Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 5.331056594848633 | KNN Loss: 4.323740005493164 | BCE Loss: 1.0073168277740479
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 5.373711585998535 | KNN Loss: 4.360569953918457 | BCE Loss: 1.0131416320800781
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 5.360617637634277 | KNN Loss: 4.320070743560791 | BCE Loss: 1.0405471324920654
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 5.397516250610352 | KNN Loss: 4.378911018371582 | BCE Loss: 1.0186049938201904
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 5.364614963531494 | KNN Loss: 4.372374057769775 | BCE Loss: 0.9922410249710083
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 5.375185012817383 | KNN Loss: 4.350447177886963 | BCE Loss: 1.024738073348999
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 5.368709564208984 | KNN Loss: 4.334160327911377 | BCE Loss: 1.034549355506897
Epoch 194 / 500 | iteration 15 / 30 | Total Loss: 5.361690998077393 | KNN Loss: 4.338366508483887 | B

Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 5.345978736877441 | KNN Loss: 4.31674861907959 | BCE Loss: 1.0292302370071411
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 5.331292629241943 | KNN Loss: 4.316649913787842 | BCE Loss: 1.0146427154541016
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 5.340664863586426 | KNN Loss: 4.326846599578857 | BCE Loss: 1.0138185024261475
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 5.320669174194336 | KNN Loss: 4.309599876403809 | BCE Loss: 1.0110692977905273
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 5.397899627685547 | KNN Loss: 4.362926959991455 | BCE Loss: 1.0349727869033813
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 5.366705894470215 | KNN Loss: 4.357923984527588 | BCE Loss: 1.0087817907333374
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 5.354964733123779 | KNN Loss: 4.330198287963867 | BCE Loss: 1.0247663259506226
Epoch 205 / 500 | iteration 5 / 30 | Total Loss: 5.392163276672363 | KNN Loss: 4.342433452606201 | BC

Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 5.368538856506348 | KNN Loss: 4.339588642120361 | BCE Loss: 1.0289503335952759
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 5.362987995147705 | KNN Loss: 4.3609418869018555 | BCE Loss: 1.0020461082458496
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 5.39769983291626 | KNN Loss: 4.364389896392822 | BCE Loss: 1.0333099365234375
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 5.345359802246094 | KNN Loss: 4.323641777038574 | BCE Loss: 1.0217177867889404
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 5.384208679199219 | KNN Loss: 4.337435245513916 | BCE Loss: 1.0467736721038818
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 5.350454330444336 | KNN Loss: 4.317714691162109 | BCE Loss: 1.0327394008636475
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 5.352890968322754 | KNN Loss: 4.341034412384033 | BCE Loss: 1.0118566751480103
Epoch 215 / 500 | iteration 25 / 30 | Total Loss: 5.415961265563965 | KNN Loss: 4.399179458618164 |

Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 5.344145774841309 | KNN Loss: 4.307365894317627 | BCE Loss: 1.0367796421051025
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 5.384427070617676 | KNN Loss: 4.348245143890381 | BCE Loss: 1.036182165145874
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 5.352842807769775 | KNN Loss: 4.366208076477051 | BCE Loss: 0.9866347312927246
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 5.316404342651367 | KNN Loss: 4.313209533691406 | BCE Loss: 1.0031949281692505
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 5.3978729248046875 | KNN Loss: 4.384249687194824 | BCE Loss: 1.0136229991912842
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 5.375531196594238 | KNN Loss: 4.314569473266602 | BCE Loss: 1.0609617233276367
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 5.361615180969238 | KNN Loss: 4.3304948806762695 | BCE Loss: 1.0311203002929688
Epoch 226 / 500 | iteration 15 / 30 | Total Loss: 5.334537982940674 | KNN Loss: 4.32710075378418 |

Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 5.340029716491699 | KNN Loss: 4.324741840362549 | BCE Loss: 1.0152881145477295
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 5.36462926864624 | KNN Loss: 4.33961820602417 | BCE Loss: 1.0250111818313599
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 5.339353561401367 | KNN Loss: 4.335816860198975 | BCE Loss: 1.0035369396209717
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 5.364731311798096 | KNN Loss: 4.314316749572754 | BCE Loss: 1.0504145622253418
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 5.332450866699219 | KNN Loss: 4.319447040557861 | BCE Loss: 1.0130040645599365
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 5.339555740356445 | KNN Loss: 4.333128452301025 | BCE Loss: 1.0064270496368408
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 5.356820583343506 | KNN Loss: 4.352006435394287 | BCE Loss: 1.0048142671585083
Epoch 237 / 500 | iteration 5 / 30 | Total Loss: 5.403218746185303 | KNN Loss: 4.413683891296387 | BCE

Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 5.411556243896484 | KNN Loss: 4.341714859008789 | BCE Loss: 1.0698415040969849
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 5.374431610107422 | KNN Loss: 4.358969211578369 | BCE Loss: 1.0154621601104736
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 5.373733997344971 | KNN Loss: 4.3409576416015625 | BCE Loss: 1.0327763557434082
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 5.409998416900635 | KNN Loss: 4.397111415863037 | BCE Loss: 1.0128871202468872
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 5.320089340209961 | KNN Loss: 4.322471618652344 | BCE Loss: 0.9976178407669067
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 5.359219551086426 | KNN Loss: 4.320870399475098 | BCE Loss: 1.038348913192749
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 5.379520416259766 | KNN Loss: 4.336143970489502 | BCE Loss: 1.0433766841888428
Epoch 247 / 500 | iteration 25 / 30 | Total Loss: 5.332324981689453 | KNN Loss: 4.322441577911377 |

Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 5.407764911651611 | KNN Loss: 4.372744083404541 | BCE Loss: 1.0350208282470703
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 5.310039520263672 | KNN Loss: 4.310190200805664 | BCE Loss: 0.9998493194580078
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 5.365618705749512 | KNN Loss: 4.333090782165527 | BCE Loss: 1.0325276851654053
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 5.347814083099365 | KNN Loss: 4.32740592956543 | BCE Loss: 1.0204081535339355
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 5.323830604553223 | KNN Loss: 4.3236541748046875 | BCE Loss: 1.0001763105392456
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 5.332890033721924 | KNN Loss: 4.319753646850586 | BCE Loss: 1.013136386871338
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 5.363752365112305 | KNN Loss: 4.340979099273682 | BCE Loss: 1.022773265838623
Epoch 258 / 500 | iteration 15 / 30 | Total Loss: 5.34808349609375 | KNN Loss: 4.326876163482666 | BC

Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 5.329870700836182 | KNN Loss: 4.333670616149902 | BCE Loss: 0.996199905872345
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 5.408700466156006 | KNN Loss: 4.367873668670654 | BCE Loss: 1.0408267974853516
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 5.318916320800781 | KNN Loss: 4.297286510467529 | BCE Loss: 1.0216299295425415
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 5.383553504943848 | KNN Loss: 4.333485126495361 | BCE Loss: 1.0500683784484863
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 5.342248916625977 | KNN Loss: 4.355373382568359 | BCE Loss: 0.9868755340576172
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 5.36327600479126 | KNN Loss: 4.367438793182373 | BCE Loss: 0.9958371520042419
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 5.384873390197754 | KNN Loss: 4.375438213348389 | BCE Loss: 1.0094350576400757
Epoch 269 / 500 | iteration 5 / 30 | Total Loss: 5.381702899932861 | KNN Loss: 4.366947174072266 | BCE

Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 5.398136615753174 | KNN Loss: 4.381958484649658 | BCE Loss: 1.0161781311035156
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 5.3371429443359375 | KNN Loss: 4.339204788208008 | BCE Loss: 0.9979379773139954
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 5.379116535186768 | KNN Loss: 4.3719635009765625 | BCE Loss: 1.0071529150009155
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 5.39136266708374 | KNN Loss: 4.3366498947143555 | BCE Loss: 1.0547127723693848
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 5.360705852508545 | KNN Loss: 4.357628345489502 | BCE Loss: 1.003077507019043
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 5.312936782836914 | KNN Loss: 4.312992572784424 | BCE Loss: 0.9999443292617798
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 5.370644569396973 | KNN Loss: 4.340776443481445 | BCE Loss: 1.0298678874969482
Epoch 279 / 500 | iteration 25 / 30 | Total Loss: 5.346396446228027 | KNN Loss: 4.316761493682861 

Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 5.425412654876709 | KNN Loss: 4.3963847160339355 | BCE Loss: 1.0290278196334839
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 5.341582298278809 | KNN Loss: 4.326824188232422 | BCE Loss: 1.0147583484649658
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 5.381550312042236 | KNN Loss: 4.370083808898926 | BCE Loss: 1.0114665031433105
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 5.351917266845703 | KNN Loss: 4.332967758178711 | BCE Loss: 1.0189496278762817
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 5.354959487915039 | KNN Loss: 4.3364033699035645 | BCE Loss: 1.0185561180114746
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 5.401604652404785 | KNN Loss: 4.352441310882568 | BCE Loss: 1.0491632223129272
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 5.330008029937744 | KNN Loss: 4.3344573974609375 | BCE Loss: 0.995550811290741
Epoch 290 / 500 | iteration 15 / 30 | Total Loss: 5.339901924133301 | KNN Loss: 4.32356595993042 

Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 5.400449752807617 | KNN Loss: 4.355912685394287 | BCE Loss: 1.0445373058319092
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 5.391111373901367 | KNN Loss: 4.359960079193115 | BCE Loss: 1.031151533126831
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 5.391209125518799 | KNN Loss: 4.353403091430664 | BCE Loss: 1.0378060340881348
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 5.347984313964844 | KNN Loss: 4.3153886795043945 | BCE Loss: 1.0325953960418701
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 5.336020469665527 | KNN Loss: 4.323631286621094 | BCE Loss: 1.0123893022537231
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 5.398921489715576 | KNN Loss: 4.345242977142334 | BCE Loss: 1.0536786317825317
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 5.340672492980957 | KNN Loss: 4.352476119995117 | BCE Loss: 0.9881964921951294
Epoch 301 / 500 | iteration 5 / 30 | Total Loss: 5.395585060119629 | KNN Loss: 4.346435070037842 | B

Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 5.38198184967041 | KNN Loss: 4.354580402374268 | BCE Loss: 1.0274014472961426
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 5.357471466064453 | KNN Loss: 4.357285499572754 | BCE Loss: 1.0001859664916992
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 5.411555290222168 | KNN Loss: 4.376363277435303 | BCE Loss: 1.0351921319961548
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 5.3743133544921875 | KNN Loss: 4.352755546569824 | BCE Loss: 1.0215578079223633
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 5.37565279006958 | KNN Loss: 4.344982147216797 | BCE Loss: 1.0306705236434937
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 5.418898105621338 | KNN Loss: 4.391064167022705 | BCE Loss: 1.0278340578079224
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 5.36530876159668 | KNN Loss: 4.346837043762207 | BCE Loss: 1.018471598625183
Epoch 311 / 500 | iteration 25 / 30 | Total Loss: 5.347202777862549 | KNN Loss: 4.320523262023926 | BC

Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 5.354340553283691 | KNN Loss: 4.343430519104004 | BCE Loss: 1.0109102725982666
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 5.340439796447754 | KNN Loss: 4.328736305236816 | BCE Loss: 1.0117034912109375
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 5.387660026550293 | KNN Loss: 4.368063449859619 | BCE Loss: 1.019596815109253
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 5.352064609527588 | KNN Loss: 4.363457679748535 | BCE Loss: 0.9886068105697632
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 5.3659138679504395 | KNN Loss: 4.350650787353516 | BCE Loss: 1.0152629613876343
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 5.382969379425049 | KNN Loss: 4.358417987823486 | BCE Loss: 1.0245513916015625
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 5.37991189956665 | KNN Loss: 4.347390651702881 | BCE Loss: 1.0325212478637695
Epoch 322 / 500 | iteration 15 / 30 | Total Loss: 5.352431774139404 | KNN Loss: 4.341634750366211 | 

Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 5.367119789123535 | KNN Loss: 4.361668586730957 | BCE Loss: 1.0054514408111572
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 5.31496000289917 | KNN Loss: 4.325533390045166 | BCE Loss: 0.9894266128540039
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 5.4323625564575195 | KNN Loss: 4.3792853355407715 | BCE Loss: 1.0530774593353271
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 5.357402324676514 | KNN Loss: 4.332403659820557 | BCE Loss: 1.0249987840652466
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 5.369745254516602 | KNN Loss: 4.320011138916016 | BCE Loss: 1.049734115600586
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 5.366730213165283 | KNN Loss: 4.37130880355835 | BCE Loss: 0.9954215288162231
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 5.349822521209717 | KNN Loss: 4.35531759262085 | BCE Loss: 0.9945051074028015
Epoch 333 / 500 | iteration 5 / 30 | Total Loss: 5.394263744354248 | KNN Loss: 4.346890926361084 | BCE

Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 5.32889461517334 | KNN Loss: 4.3204498291015625 | BCE Loss: 1.0084447860717773
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 5.385364532470703 | KNN Loss: 4.346867561340332 | BCE Loss: 1.0384972095489502
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 5.412097930908203 | KNN Loss: 4.381918430328369 | BCE Loss: 1.0301792621612549
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 5.305268287658691 | KNN Loss: 4.30812931060791 | BCE Loss: 0.997139036655426
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 5.381422519683838 | KNN Loss: 4.316338539123535 | BCE Loss: 1.0650840997695923
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 5.331233024597168 | KNN Loss: 4.320980072021484 | BCE Loss: 1.0102529525756836
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 5.374724388122559 | KNN Loss: 4.335403919219971 | BCE Loss: 1.039320707321167
Epoch 343 / 500 | iteration 25 / 30 | Total Loss: 5.380877494812012 | KNN Loss: 4.374922752380371 | BC

Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 5.3176727294921875 | KNN Loss: 4.314280986785889 | BCE Loss: 1.0033915042877197
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 5.365478992462158 | KNN Loss: 4.341991901397705 | BCE Loss: 1.0234870910644531
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 5.3853349685668945 | KNN Loss: 4.367983341217041 | BCE Loss: 1.0173516273498535
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 5.372584342956543 | KNN Loss: 4.357744216918945 | BCE Loss: 1.0148403644561768
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 5.365780353546143 | KNN Loss: 4.33652400970459 | BCE Loss: 1.0292563438415527
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 5.332111358642578 | KNN Loss: 4.326883792877197 | BCE Loss: 1.0052275657653809
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 5.378783226013184 | KNN Loss: 4.367127418518066 | BCE Loss: 1.0116560459136963
Epoch 354 / 500 | iteration 15 / 30 | Total Loss: 5.355651378631592 | KNN Loss: 4.324416637420654 

Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 5.350191593170166 | KNN Loss: 4.338006019592285 | BCE Loss: 1.0121856927871704
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 5.357417106628418 | KNN Loss: 4.339258193969727 | BCE Loss: 1.0181591510772705
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 5.360299110412598 | KNN Loss: 4.342634677886963 | BCE Loss: 1.0176646709442139
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 5.33585262298584 | KNN Loss: 4.3140692710876465 | BCE Loss: 1.0217833518981934
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 5.338958740234375 | KNN Loss: 4.320520877838135 | BCE Loss: 1.0184376239776611
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 5.383331775665283 | KNN Loss: 4.370683193206787 | BCE Loss: 1.0126484632492065
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 5.3489227294921875 | KNN Loss: 4.354379177093506 | BCE Loss: 0.9945437908172607
Epoch 365 / 500 | iteration 5 / 30 | Total Loss: 5.3695969581604 | KNN Loss: 4.355698108673096 | BC

Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 5.342226028442383 | KNN Loss: 4.327207088470459 | BCE Loss: 1.015019178390503
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 5.369772434234619 | KNN Loss: 4.334789752960205 | BCE Loss: 1.0349825620651245
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 5.391184329986572 | KNN Loss: 4.3752923011779785 | BCE Loss: 1.0158920288085938
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 5.383306503295898 | KNN Loss: 4.355643272399902 | BCE Loss: 1.0276634693145752
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 5.365352630615234 | KNN Loss: 4.337962627410889 | BCE Loss: 1.0273897647857666
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 5.375711441040039 | KNN Loss: 4.356298923492432 | BCE Loss: 1.0194122791290283
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 5.374667167663574 | KNN Loss: 4.351133823394775 | BCE Loss: 1.0235333442687988
Epoch 375 / 500 | iteration 25 / 30 | Total Loss: 5.355513095855713 | KNN Loss: 4.34341287612915 | 

Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 5.318363189697266 | KNN Loss: 4.314396381378174 | BCE Loss: 1.0039669275283813
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 5.428670883178711 | KNN Loss: 4.358550548553467 | BCE Loss: 1.0701205730438232
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 5.3431172370910645 | KNN Loss: 4.32890510559082 | BCE Loss: 1.0142121315002441
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 5.321158409118652 | KNN Loss: 4.315371990203857 | BCE Loss: 1.0057862997055054
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 5.419076919555664 | KNN Loss: 4.350430011749268 | BCE Loss: 1.0686466693878174
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 5.373912334442139 | KNN Loss: 4.353179454803467 | BCE Loss: 1.0207327604293823
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 5.3345627784729 | KNN Loss: 4.318312168121338 | BCE Loss: 1.016250729560852
Epoch 386 / 500 | iteration 15 / 30 | Total Loss: 5.378401756286621 | KNN Loss: 4.370051860809326 | BC

Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 5.384727478027344 | KNN Loss: 4.367968559265137 | BCE Loss: 1.016758680343628
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 5.4140143394470215 | KNN Loss: 4.384173393249512 | BCE Loss: 1.0298408269882202
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 5.387234687805176 | KNN Loss: 4.365289688110352 | BCE Loss: 1.0219447612762451
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 5.367539405822754 | KNN Loss: 4.3419060707092285 | BCE Loss: 1.0256335735321045
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 5.342684745788574 | KNN Loss: 4.307793140411377 | BCE Loss: 1.0348918437957764
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 5.325706481933594 | KNN Loss: 4.3463897705078125 | BCE Loss: 0.9793168306350708
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 5.373893737792969 | KNN Loss: 4.355412006378174 | BCE Loss: 1.0184818506240845
Epoch 397 / 500 | iteration 5 / 30 | Total Loss: 5.363255023956299 | KNN Loss: 4.342340469360352 |

Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 5.393985748291016 | KNN Loss: 4.3762078285217285 | BCE Loss: 1.0177780389785767
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 5.421331882476807 | KNN Loss: 4.413797855377197 | BCE Loss: 1.007534146308899
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 5.325910568237305 | KNN Loss: 4.335816383361816 | BCE Loss: 0.9900941848754883
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 5.333857536315918 | KNN Loss: 4.328846454620361 | BCE Loss: 1.0050112009048462
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 5.360915184020996 | KNN Loss: 4.356260776519775 | BCE Loss: 1.0046545267105103
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 5.367401599884033 | KNN Loss: 4.3526997566223145 | BCE Loss: 1.0147018432617188
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 5.391834259033203 | KNN Loss: 4.334609031677246 | BCE Loss: 1.0572254657745361
Epoch 407 / 500 | iteration 25 / 30 | Total Loss: 5.389742374420166 | KNN Loss: 4.356434345245361 

Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 5.391670227050781 | KNN Loss: 4.370935916900635 | BCE Loss: 1.0207343101501465
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 5.349358081817627 | KNN Loss: 4.352179527282715 | BCE Loss: 0.9971783757209778
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 5.352061748504639 | KNN Loss: 4.324198246002197 | BCE Loss: 1.027863621711731
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 5.3263959884643555 | KNN Loss: 4.31308650970459 | BCE Loss: 1.0133097171783447
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 5.412875652313232 | KNN Loss: 4.392711162567139 | BCE Loss: 1.0201644897460938
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 5.37197732925415 | KNN Loss: 4.3461384773254395 | BCE Loss: 1.025838851928711
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 5.365981578826904 | KNN Loss: 4.325695037841797 | BCE Loss: 1.040286660194397
Epoch 418 / 500 | iteration 15 / 30 | Total Loss: 5.353357315063477 | KNN Loss: 4.332359790802002 | BC

Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 5.38817834854126 | KNN Loss: 4.370124340057373 | BCE Loss: 1.0180541276931763
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 5.366745471954346 | KNN Loss: 4.344224452972412 | BCE Loss: 1.0225211381912231
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 5.362945079803467 | KNN Loss: 4.344245910644531 | BCE Loss: 1.0186991691589355
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 5.349668502807617 | KNN Loss: 4.318215847015381 | BCE Loss: 1.0314526557922363
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 5.344742298126221 | KNN Loss: 4.332143306732178 | BCE Loss: 1.0125988721847534
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 5.41508674621582 | KNN Loss: 4.385757923126221 | BCE Loss: 1.0293288230895996
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 5.345803260803223 | KNN Loss: 4.332432270050049 | BCE Loss: 1.0133711099624634
Epoch 429 / 500 | iteration 5 / 30 | Total Loss: 5.35964298248291 | KNN Loss: 4.3458733558654785 | BCE

Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 5.323318958282471 | KNN Loss: 4.311581611633301 | BCE Loss: 1.0117374658584595
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 5.362170696258545 | KNN Loss: 4.335082054138184 | BCE Loss: 1.0270886421203613
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 5.435814380645752 | KNN Loss: 4.394586563110352 | BCE Loss: 1.0412278175354004
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 5.361844539642334 | KNN Loss: 4.361156940460205 | BCE Loss: 1.0006874799728394
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 5.325392246246338 | KNN Loss: 4.312139987945557 | BCE Loss: 1.0132522583007812
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 5.370399475097656 | KNN Loss: 4.345261573791504 | BCE Loss: 1.0251379013061523
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 5.35184383392334 | KNN Loss: 4.333664894104004 | BCE Loss: 1.0181788206100464
Epoch 439 / 500 | iteration 25 / 30 | Total Loss: 5.35351037979126 | KNN Loss: 4.330833911895752 | B

Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 5.347927093505859 | KNN Loss: 4.343445301055908 | BCE Loss: 1.004481554031372
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 5.371289253234863 | KNN Loss: 4.334761142730713 | BCE Loss: 1.03652822971344
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 5.332276344299316 | KNN Loss: 4.34283447265625 | BCE Loss: 0.9894418120384216
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 5.342615127563477 | KNN Loss: 4.345034599304199 | BCE Loss: 0.9975806474685669
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 5.359007835388184 | KNN Loss: 4.367350101470947 | BCE Loss: 0.9916574954986572
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 5.344439506530762 | KNN Loss: 4.330578327178955 | BCE Loss: 1.0138612985610962
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 5.396506309509277 | KNN Loss: 4.3855133056640625 | BCE Loss: 1.0109928846359253
Epoch 450 / 500 | iteration 15 / 30 | Total Loss: 5.382510662078857 | KNN Loss: 4.359109878540039 | BC

Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 5.353022575378418 | KNN Loss: 4.3539886474609375 | BCE Loss: 0.9990341663360596
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 5.396484375 | KNN Loss: 4.3535637855529785 | BCE Loss: 1.0429205894470215
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 5.351693630218506 | KNN Loss: 4.327581405639648 | BCE Loss: 1.0241122245788574
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 5.358039379119873 | KNN Loss: 4.322794437408447 | BCE Loss: 1.0352450609207153
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 5.373109817504883 | KNN Loss: 4.332299709320068 | BCE Loss: 1.0408098697662354
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 5.381913185119629 | KNN Loss: 4.349060535430908 | BCE Loss: 1.0328526496887207
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 5.328824043273926 | KNN Loss: 4.314085960388184 | BCE Loss: 1.0147383213043213
Epoch 461 / 500 | iteration 5 / 30 | Total Loss: 5.407964706420898 | KNN Loss: 4.3662428855896 | BCE Los

Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 5.346129894256592 | KNN Loss: 4.314916133880615 | BCE Loss: 1.0312138795852661
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 5.4247517585754395 | KNN Loss: 4.376527309417725 | BCE Loss: 1.0482245683670044
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 5.441575050354004 | KNN Loss: 4.41087007522583 | BCE Loss: 1.0307047367095947
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 5.354607582092285 | KNN Loss: 4.321402072906494 | BCE Loss: 1.0332057476043701
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 5.36297607421875 | KNN Loss: 4.337094783782959 | BCE Loss: 1.025881052017212
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 5.333868980407715 | KNN Loss: 4.327200889587402 | BCE Loss: 1.0066680908203125
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 5.383738040924072 | KNN Loss: 4.366393566131592 | BCE Loss: 1.017344355583191
Epoch 471 / 500 | iteration 25 / 30 | Total Loss: 5.364250659942627 | KNN Loss: 4.344139575958252 | BC

Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 5.380999565124512 | KNN Loss: 4.382349491119385 | BCE Loss: 0.9986500144004822
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 5.323890209197998 | KNN Loss: 4.32285737991333 | BCE Loss: 1.001032829284668
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 5.402249336242676 | KNN Loss: 4.356246471405029 | BCE Loss: 1.0460028648376465
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 5.36060905456543 | KNN Loss: 4.328298568725586 | BCE Loss: 1.0323104858398438
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 5.40244722366333 | KNN Loss: 4.394598484039307 | BCE Loss: 1.0078487396240234
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 5.389484405517578 | KNN Loss: 4.368040561676025 | BCE Loss: 1.0214439630508423
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 5.365406036376953 | KNN Loss: 4.366781711578369 | BCE Loss: 0.9986240863800049
Epoch 482 / 500 | iteration 15 / 30 | Total Loss: 5.40535306930542 | KNN Loss: 4.355584144592285 | BCE 

Epoch   492: reducing learning rate of group 0 to 5.5221e-08.
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 5.350766658782959 | KNN Loss: 4.337794780731201 | BCE Loss: 1.0129719972610474
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 5.3958539962768555 | KNN Loss: 4.357967376708984 | BCE Loss: 1.0378867387771606
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 5.367641448974609 | KNN Loss: 4.3413214683532715 | BCE Loss: 1.026320219039917
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 5.3972930908203125 | KNN Loss: 4.384162902832031 | BCE Loss: 1.0131301879882812
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 5.36600399017334 | KNN Loss: 4.376267910003662 | BCE Loss: 0.9897359609603882
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 5.402640342712402 | KNN Loss: 4.371490955352783 | BCE Loss: 1.03114914894104
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 5.321385383605957 | KNN Loss: 4.327150344848633 | BCE Loss: 0.9942350387573242
Epoch 493 / 500 | iteration 5 / 30 | To

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.8325,  4.6160,  3.0677,  2.2880,  4.1206,  0.6368,  2.1909,  1.6315,
          1.8769,  2.4564,  2.1214,  1.3905,  1.1680,  2.0682,  1.4719,  0.8914,
          2.3867,  1.8554,  3.2410,  2.5316,  2.0258,  1.7814,  2.7396,  2.8656,
          1.6242,  1.8446,  1.2237,  0.8096,  1.0281,  0.5324,  0.1516,  0.8719,
         -0.3351,  0.9580,  2.0002,  0.9887,  0.4817,  3.2517,  0.8580,  0.8546,
          0.7083, -1.2026, -0.3733,  2.2330,  2.6924,  0.8294,  0.1402, -0.0804,
          1.7870,  1.6762,  2.2953,  0.1201,  0.7758,  0.4457, -0.4955,  1.0042,
          1.3151,  1.6255,  1.2598,  1.7012,  0.4507,  0.7257,  0.1015,  1.6082,
          1.2640,  1.6263, -1.6217,  0.1696,  2.7663,  1.7402,  2.9492,  0.2148,
          1.1837,  2.8738,  1.5330,  1.7702,  0.6192,  0.5534, -0.3117,  1.0783,
         -0.0685,  0.2990,  2.1036, -0.3987,  0.5665, -1.0399, -2.3851,  0.2871,
          0.5592, -1.7341,  0.2284, -0.0333, -0.1515, -1.0996,  0.6627,  0.8139,
         -0.4803, -0.3365,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 83.33it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 028 | Total loss: 9.634 | Reg loss: 0.014 | Tree loss: 9.634 | Accuracy: 0.000000 | 3.626 sec/iter
Epoch: 00 | Batch: 001 / 028 | Total loss: 9.632 | Reg loss: 0.013 | Tree loss: 9.632 | Accuracy: 0.000000 | 3.556 sec/iter
Epoch: 00 | Batch: 002 / 028 | Total loss: 9.629 | Reg loss: 0.012 | Tree loss: 9.629 | Accuracy: 0.000000 | 3.564 sec/iter
Epoch: 00 | Batch: 003 / 028 | Total loss: 9.625 | Reg loss: 0.011 | Tree loss: 9.625 | Accuracy: 0.000000 | 4.062 sec/iter
Epoch: 00 | Batch: 004 / 028 | Total loss: 9.624 | Reg loss: 0.010 | Tree loss: 9.624 | Accuracy: 0.000000 | 4.377 sec/iter
Epoch: 00 | Batch: 005 / 028 | Total loss: 9.621 | Reg loss: 0.009 | Tree loss: 9.621 | Accuracy: 0.000000 | 4.601 sec/iter
Epoch: 00 | Batch: 006 / 028 | Total loss: 9.618 | Reg loss: 0.008 | Tree loss: 9.618 | 

Epoch: 02 | Batch: 004 / 028 | Total loss: 9.558 | Reg loss: 0.007 | Tree loss: 9.558 | Accuracy: 0.171875 | 6.834 sec/iter
Epoch: 02 | Batch: 005 / 028 | Total loss: 9.555 | Reg loss: 0.007 | Tree loss: 9.555 | Accuracy: 0.199219 | 6.833 sec/iter
Epoch: 02 | Batch: 006 / 028 | Total loss: 9.556 | Reg loss: 0.007 | Tree loss: 9.556 | Accuracy: 0.169922 | 6.83 sec/iter
Epoch: 02 | Batch: 007 / 028 | Total loss: 9.552 | Reg loss: 0.007 | Tree loss: 9.552 | Accuracy: 0.179688 | 6.828 sec/iter
Epoch: 02 | Batch: 008 / 028 | Total loss: 9.551 | Reg loss: 0.007 | Tree loss: 9.551 | Accuracy: 0.175781 | 6.828 sec/iter
Epoch: 02 | Batch: 009 / 028 | Total loss: 9.546 | Reg loss: 0.008 | Tree loss: 9.546 | Accuracy: 0.228516 | 6.826 sec/iter
Epoch: 02 | Batch: 010 / 028 | Total loss: 9.544 | Reg loss: 0.008 | Tree loss: 9.544 | Accuracy: 0.220703 | 6.827 sec/iter
Epoch: 02 | Batch: 011 / 028 | Total loss: 9.540 | Reg loss: 0.008 | Tree loss: 9.540 | Accuracy: 0.232422 | 6.825 sec/iter
Epoch: 02

Epoch: 04 | Batch: 009 / 028 | Total loss: 9.444 | Reg loss: 0.013 | Tree loss: 9.444 | Accuracy: 0.222656 | 6.774 sec/iter
Epoch: 04 | Batch: 010 / 028 | Total loss: 9.435 | Reg loss: 0.013 | Tree loss: 9.435 | Accuracy: 0.205078 | 6.777 sec/iter
Epoch: 04 | Batch: 011 / 028 | Total loss: 9.427 | Reg loss: 0.014 | Tree loss: 9.427 | Accuracy: 0.177734 | 6.78 sec/iter
Epoch: 04 | Batch: 012 / 028 | Total loss: 9.421 | Reg loss: 0.014 | Tree loss: 9.421 | Accuracy: 0.203125 | 6.783 sec/iter
Epoch: 04 | Batch: 013 / 028 | Total loss: 9.416 | Reg loss: 0.015 | Tree loss: 9.416 | Accuracy: 0.177734 | 6.786 sec/iter
Epoch: 04 | Batch: 014 / 028 | Total loss: 9.403 | Reg loss: 0.015 | Tree loss: 9.403 | Accuracy: 0.205078 | 6.789 sec/iter
Epoch: 04 | Batch: 015 / 028 | Total loss: 9.400 | Reg loss: 0.015 | Tree loss: 9.400 | Accuracy: 0.183594 | 6.791 sec/iter
Epoch: 04 | Batch: 016 / 028 | Total loss: 9.377 | Reg loss: 0.016 | Tree loss: 9.377 | Accuracy: 0.169922 | 6.795 sec/iter
Epoch: 04

Epoch: 06 | Batch: 014 / 028 | Total loss: 9.040 | Reg loss: 0.021 | Tree loss: 9.040 | Accuracy: 0.156250 | 6.823 sec/iter
Epoch: 06 | Batch: 015 / 028 | Total loss: 9.006 | Reg loss: 0.021 | Tree loss: 9.006 | Accuracy: 0.162109 | 6.823 sec/iter
Epoch: 06 | Batch: 016 / 028 | Total loss: 8.982 | Reg loss: 0.022 | Tree loss: 8.982 | Accuracy: 0.210938 | 6.823 sec/iter
Epoch: 06 | Batch: 017 / 028 | Total loss: 8.968 | Reg loss: 0.022 | Tree loss: 8.968 | Accuracy: 0.171875 | 6.823 sec/iter
Epoch: 06 | Batch: 018 / 028 | Total loss: 8.960 | Reg loss: 0.023 | Tree loss: 8.960 | Accuracy: 0.169922 | 6.823 sec/iter
Epoch: 06 | Batch: 019 / 028 | Total loss: 8.906 | Reg loss: 0.023 | Tree loss: 8.906 | Accuracy: 0.167969 | 6.823 sec/iter
Epoch: 06 | Batch: 020 / 028 | Total loss: 8.885 | Reg loss: 0.024 | Tree loss: 8.885 | Accuracy: 0.195312 | 6.823 sec/iter
Epoch: 06 | Batch: 021 / 028 | Total loss: 8.881 | Reg loss: 0.024 | Tree loss: 8.881 | Accuracy: 0.208984 | 6.823 sec/iter
Epoch: 0

Epoch: 08 | Batch: 019 / 028 | Total loss: 8.379 | Reg loss: 0.026 | Tree loss: 8.379 | Accuracy: 0.195312 | 6.832 sec/iter
Epoch: 08 | Batch: 020 / 028 | Total loss: 8.336 | Reg loss: 0.027 | Tree loss: 8.336 | Accuracy: 0.201172 | 6.831 sec/iter
Epoch: 08 | Batch: 021 / 028 | Total loss: 8.329 | Reg loss: 0.027 | Tree loss: 8.329 | Accuracy: 0.173828 | 6.83 sec/iter
Epoch: 08 | Batch: 022 / 028 | Total loss: 8.321 | Reg loss: 0.027 | Tree loss: 8.321 | Accuracy: 0.179688 | 6.83 sec/iter
Epoch: 08 | Batch: 023 / 028 | Total loss: 8.288 | Reg loss: 0.028 | Tree loss: 8.288 | Accuracy: 0.181641 | 6.829 sec/iter
Epoch: 08 | Batch: 024 / 028 | Total loss: 8.234 | Reg loss: 0.028 | Tree loss: 8.234 | Accuracy: 0.179688 | 6.829 sec/iter
Epoch: 08 | Batch: 025 / 028 | Total loss: 8.216 | Reg loss: 0.028 | Tree loss: 8.216 | Accuracy: 0.175781 | 6.83 sec/iter
Epoch: 08 | Batch: 026 / 028 | Total loss: 8.206 | Reg loss: 0.029 | Tree loss: 8.206 | Accuracy: 0.197266 | 6.831 sec/iter
Epoch: 08 |

Epoch: 10 | Batch: 024 / 028 | Total loss: 7.722 | Reg loss: 0.030 | Tree loss: 7.722 | Accuracy: 0.156250 | 6.799 sec/iter
Epoch: 10 | Batch: 025 / 028 | Total loss: 7.685 | Reg loss: 0.030 | Tree loss: 7.685 | Accuracy: 0.173828 | 6.801 sec/iter
Epoch: 10 | Batch: 026 / 028 | Total loss: 7.655 | Reg loss: 0.030 | Tree loss: 7.655 | Accuracy: 0.177734 | 6.802 sec/iter
Epoch: 10 | Batch: 027 / 028 | Total loss: 7.528 | Reg loss: 0.030 | Tree loss: 7.528 | Accuracy: 0.375000 | 6.791 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 11 | Batch: 000 / 028 | Total loss: 8.040 | Reg loss: 0.026 | Tree loss: 8.040 | Accuracy: 0.210938 | 6.95 sec/iter
Epoch: 11 | Batch: 001 / 028 | Tot

Epoch: 13 | Batch: 000 / 028 | Total loss: 7.512 | Reg loss: 0.029 | Tree loss: 7.512 | Accuracy: 0.173828 | 6.964 sec/iter
Epoch: 13 | Batch: 001 / 028 | Total loss: 7.483 | Reg loss: 0.029 | Tree loss: 7.483 | Accuracy: 0.228516 | 6.964 sec/iter
Epoch: 13 | Batch: 002 / 028 | Total loss: 7.438 | Reg loss: 0.029 | Tree loss: 7.438 | Accuracy: 0.193359 | 6.964 sec/iter
Epoch: 13 | Batch: 003 / 028 | Total loss: 7.447 | Reg loss: 0.029 | Tree loss: 7.447 | Accuracy: 0.181641 | 6.962 sec/iter
Epoch: 13 | Batch: 004 / 028 | Total loss: 7.426 | Reg loss: 0.029 | Tree loss: 7.426 | Accuracy: 0.212891 | 6.96 sec/iter
Epoch: 13 | Batch: 005 / 028 | Total loss: 7.378 | Reg loss: 0.029 | Tree loss: 7.378 | Accuracy: 0.169922 | 6.954 sec/iter
Epoch: 13 | Batch: 006 / 028 | Total loss: 7.353 | Reg loss: 0.029 | Tree loss: 7.353 | Accuracy: 0.175781 | 6.949 sec/iter
Epoch: 13 | Batch: 007 / 028 | Total loss: 7.342 | Reg loss: 0.029 | Tree loss: 7.342 | Accuracy: 0.167969 | 6.943 sec/iter
Epoch: 13

Epoch: 15 | Batch: 005 / 028 | Total loss: 6.848 | Reg loss: 0.031 | Tree loss: 6.848 | Accuracy: 0.173828 | 6.943 sec/iter
Epoch: 15 | Batch: 006 / 028 | Total loss: 6.783 | Reg loss: 0.032 | Tree loss: 6.783 | Accuracy: 0.154297 | 6.944 sec/iter
Epoch: 15 | Batch: 007 / 028 | Total loss: 6.775 | Reg loss: 0.032 | Tree loss: 6.775 | Accuracy: 0.181641 | 6.946 sec/iter
Epoch: 15 | Batch: 008 / 028 | Total loss: 6.701 | Reg loss: 0.032 | Tree loss: 6.701 | Accuracy: 0.171875 | 6.946 sec/iter
Epoch: 15 | Batch: 009 / 028 | Total loss: 6.687 | Reg loss: 0.032 | Tree loss: 6.687 | Accuracy: 0.181641 | 6.947 sec/iter
Epoch: 15 | Batch: 010 / 028 | Total loss: 6.737 | Reg loss: 0.032 | Tree loss: 6.737 | Accuracy: 0.207031 | 6.948 sec/iter
Epoch: 15 | Batch: 011 / 028 | Total loss: 6.632 | Reg loss: 0.032 | Tree loss: 6.632 | Accuracy: 0.189453 | 6.944 sec/iter
Epoch: 15 | Batch: 012 / 028 | Total loss: 6.594 | Reg loss: 0.033 | Tree loss: 6.594 | Accuracy: 0.169922 | 6.94 sec/iter
Epoch: 15

Epoch: 17 | Batch: 010 / 028 | Total loss: 6.129 | Reg loss: 0.034 | Tree loss: 6.129 | Accuracy: 0.187500 | 6.913 sec/iter
Epoch: 17 | Batch: 011 / 028 | Total loss: 6.058 | Reg loss: 0.034 | Tree loss: 6.058 | Accuracy: 0.160156 | 6.913 sec/iter
Epoch: 17 | Batch: 012 / 028 | Total loss: 6.042 | Reg loss: 0.035 | Tree loss: 6.042 | Accuracy: 0.187500 | 6.913 sec/iter
Epoch: 17 | Batch: 013 / 028 | Total loss: 6.000 | Reg loss: 0.035 | Tree loss: 6.000 | Accuracy: 0.193359 | 6.914 sec/iter
Epoch: 17 | Batch: 014 / 028 | Total loss: 5.970 | Reg loss: 0.035 | Tree loss: 5.970 | Accuracy: 0.207031 | 6.914 sec/iter
Epoch: 17 | Batch: 015 / 028 | Total loss: 5.990 | Reg loss: 0.035 | Tree loss: 5.990 | Accuracy: 0.183594 | 6.915 sec/iter
Epoch: 17 | Batch: 016 / 028 | Total loss: 5.939 | Reg loss: 0.035 | Tree loss: 5.939 | Accuracy: 0.201172 | 6.916 sec/iter
Epoch: 17 | Batch: 017 / 028 | Total loss: 5.921 | Reg loss: 0.035 | Tree loss: 5.921 | Accuracy: 0.164062 | 6.917 sec/iter
Epoch: 1

Epoch: 19 | Batch: 015 / 028 | Total loss: 5.409 | Reg loss: 0.037 | Tree loss: 5.409 | Accuracy: 0.191406 | 6.88 sec/iter
Epoch: 19 | Batch: 016 / 028 | Total loss: 5.458 | Reg loss: 0.037 | Tree loss: 5.458 | Accuracy: 0.195312 | 6.878 sec/iter
Epoch: 19 | Batch: 017 / 028 | Total loss: 5.419 | Reg loss: 0.037 | Tree loss: 5.419 | Accuracy: 0.185547 | 6.878 sec/iter
Epoch: 19 | Batch: 018 / 028 | Total loss: 5.426 | Reg loss: 0.037 | Tree loss: 5.426 | Accuracy: 0.195312 | 6.877 sec/iter
Epoch: 19 | Batch: 019 / 028 | Total loss: 5.383 | Reg loss: 0.037 | Tree loss: 5.383 | Accuracy: 0.144531 | 6.877 sec/iter
Epoch: 19 | Batch: 020 / 028 | Total loss: 5.331 | Reg loss: 0.038 | Tree loss: 5.331 | Accuracy: 0.193359 | 6.877 sec/iter
Epoch: 19 | Batch: 021 / 028 | Total loss: 5.311 | Reg loss: 0.038 | Tree loss: 5.311 | Accuracy: 0.167969 | 6.877 sec/iter
Epoch: 19 | Batch: 022 / 028 | Total loss: 5.294 | Reg loss: 0.038 | Tree loss: 5.294 | Accuracy: 0.185547 | 6.877 sec/iter
Epoch: 19

Epoch: 21 | Batch: 020 / 028 | Total loss: 4.886 | Reg loss: 0.039 | Tree loss: 4.886 | Accuracy: 0.183594 | 6.879 sec/iter
Epoch: 21 | Batch: 021 / 028 | Total loss: 4.897 | Reg loss: 0.039 | Tree loss: 4.897 | Accuracy: 0.164062 | 6.879 sec/iter
Epoch: 21 | Batch: 022 / 028 | Total loss: 4.863 | Reg loss: 0.039 | Tree loss: 4.863 | Accuracy: 0.189453 | 6.879 sec/iter
Epoch: 21 | Batch: 023 / 028 | Total loss: 4.886 | Reg loss: 0.039 | Tree loss: 4.886 | Accuracy: 0.166016 | 6.876 sec/iter
Epoch: 21 | Batch: 024 / 028 | Total loss: 4.826 | Reg loss: 0.039 | Tree loss: 4.826 | Accuracy: 0.152344 | 6.872 sec/iter
Epoch: 21 | Batch: 025 / 028 | Total loss: 4.865 | Reg loss: 0.040 | Tree loss: 4.865 | Accuracy: 0.156250 | 6.869 sec/iter
Epoch: 21 | Batch: 026 / 028 | Total loss: 4.847 | Reg loss: 0.040 | Tree loss: 4.847 | Accuracy: 0.181641 | 6.866 sec/iter
Epoch: 21 | Batch: 027 / 028 | Total loss: 4.943 | Reg loss: 0.040 | Tree loss: 4.943 | Accuracy: 0.187500 | 6.86 sec/iter
Average s

Epoch: 23 | Batch: 025 / 028 | Total loss: 4.365 | Reg loss: 0.041 | Tree loss: 4.365 | Accuracy: 0.201172 | 6.865 sec/iter
Epoch: 23 | Batch: 026 / 028 | Total loss: 4.406 | Reg loss: 0.041 | Tree loss: 4.406 | Accuracy: 0.181641 | 6.866 sec/iter
Epoch: 23 | Batch: 027 / 028 | Total loss: 4.233 | Reg loss: 0.041 | Tree loss: 4.233 | Accuracy: 0.312500 | 6.86 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 24 | Batch: 000 / 028 | Total loss: 4.728 | Reg loss: 0.039 | Tree loss: 4.728 | Accuracy: 0.205078 | 6.871 sec/iter
Epoch: 24 | Batch: 001 / 028 | Total loss: 4.726 | Reg loss: 0.039 | Tree loss: 4.726 | Accuracy: 0.177734 | 6.871 sec/iter
Epoch: 24 | Batch: 002 / 028 | Tot

layer 10: 0.9821428571428573
Epoch: 26 | Batch: 000 / 028 | Total loss: 4.299 | Reg loss: 0.040 | Tree loss: 4.299 | Accuracy: 0.175781 | 6.853 sec/iter
Epoch: 26 | Batch: 001 / 028 | Total loss: 4.287 | Reg loss: 0.040 | Tree loss: 4.287 | Accuracy: 0.197266 | 6.85 sec/iter
Epoch: 26 | Batch: 002 / 028 | Total loss: 4.307 | Reg loss: 0.040 | Tree loss: 4.307 | Accuracy: 0.164062 | 6.848 sec/iter
Epoch: 26 | Batch: 003 / 028 | Total loss: 4.279 | Reg loss: 0.040 | Tree loss: 4.279 | Accuracy: 0.169922 | 6.845 sec/iter
Epoch: 26 | Batch: 004 / 028 | Total loss: 4.218 | Reg loss: 0.041 | Tree loss: 4.218 | Accuracy: 0.191406 | 6.842 sec/iter
Epoch: 26 | Batch: 005 / 028 | Total loss: 4.229 | Reg loss: 0.041 | Tree loss: 4.229 | Accuracy: 0.181641 | 6.839 sec/iter
Epoch: 26 | Batch: 006 / 028 | Total loss: 4.152 | Reg loss: 0.041 | Tree loss: 4.152 | Accuracy: 0.203125 | 6.837 sec/iter
Epoch: 26 | Batch: 007 / 028 | Total loss: 4.085 | Reg loss: 0.041 | Tree loss: 4.085 | Accuracy: 0.2187

Epoch: 28 | Batch: 005 / 028 | Total loss: 3.791 | Reg loss: 0.041 | Tree loss: 3.791 | Accuracy: 0.179688 | 6.831 sec/iter
Epoch: 28 | Batch: 006 / 028 | Total loss: 3.799 | Reg loss: 0.041 | Tree loss: 3.799 | Accuracy: 0.167969 | 6.831 sec/iter
Epoch: 28 | Batch: 007 / 028 | Total loss: 3.789 | Reg loss: 0.042 | Tree loss: 3.789 | Accuracy: 0.167969 | 6.831 sec/iter
Epoch: 28 | Batch: 008 / 028 | Total loss: 3.798 | Reg loss: 0.042 | Tree loss: 3.798 | Accuracy: 0.148438 | 6.831 sec/iter
Epoch: 28 | Batch: 009 / 028 | Total loss: 3.786 | Reg loss: 0.042 | Tree loss: 3.786 | Accuracy: 0.185547 | 6.832 sec/iter
Epoch: 28 | Batch: 010 / 028 | Total loss: 3.792 | Reg loss: 0.042 | Tree loss: 3.792 | Accuracy: 0.193359 | 6.833 sec/iter
Epoch: 28 | Batch: 011 / 028 | Total loss: 3.694 | Reg loss: 0.042 | Tree loss: 3.694 | Accuracy: 0.197266 | 6.833 sec/iter
Epoch: 28 | Batch: 012 / 028 | Total loss: 3.706 | Reg loss: 0.042 | Tree loss: 3.706 | Accuracy: 0.185547 | 6.834 sec/iter
Epoch: 2

Epoch: 30 | Batch: 010 / 028 | Total loss: 3.426 | Reg loss: 0.042 | Tree loss: 3.426 | Accuracy: 0.167969 | 6.843 sec/iter
Epoch: 30 | Batch: 011 / 028 | Total loss: 3.438 | Reg loss: 0.042 | Tree loss: 3.438 | Accuracy: 0.199219 | 6.844 sec/iter
Epoch: 30 | Batch: 012 / 028 | Total loss: 3.406 | Reg loss: 0.042 | Tree loss: 3.406 | Accuracy: 0.187500 | 6.844 sec/iter
Epoch: 30 | Batch: 013 / 028 | Total loss: 3.463 | Reg loss: 0.042 | Tree loss: 3.463 | Accuracy: 0.160156 | 6.845 sec/iter
Epoch: 30 | Batch: 014 / 028 | Total loss: 3.379 | Reg loss: 0.043 | Tree loss: 3.379 | Accuracy: 0.187500 | 6.845 sec/iter
Epoch: 30 | Batch: 015 / 028 | Total loss: 3.380 | Reg loss: 0.043 | Tree loss: 3.380 | Accuracy: 0.191406 | 6.845 sec/iter
Epoch: 30 | Batch: 016 / 028 | Total loss: 3.353 | Reg loss: 0.043 | Tree loss: 3.353 | Accuracy: 0.201172 | 6.846 sec/iter
Epoch: 30 | Batch: 017 / 028 | Total loss: 3.340 | Reg loss: 0.043 | Tree loss: 3.340 | Accuracy: 0.197266 | 6.846 sec/iter
Epoch: 3

Epoch: 32 | Batch: 015 / 028 | Total loss: 3.160 | Reg loss: 0.043 | Tree loss: 3.160 | Accuracy: 0.162109 | 6.844 sec/iter
Epoch: 32 | Batch: 016 / 028 | Total loss: 3.178 | Reg loss: 0.043 | Tree loss: 3.178 | Accuracy: 0.183594 | 6.844 sec/iter
Epoch: 32 | Batch: 017 / 028 | Total loss: 3.142 | Reg loss: 0.043 | Tree loss: 3.142 | Accuracy: 0.183594 | 6.845 sec/iter
Epoch: 32 | Batch: 018 / 028 | Total loss: 3.110 | Reg loss: 0.043 | Tree loss: 3.110 | Accuracy: 0.199219 | 6.845 sec/iter
Epoch: 32 | Batch: 019 / 028 | Total loss: 3.121 | Reg loss: 0.043 | Tree loss: 3.121 | Accuracy: 0.199219 | 6.845 sec/iter
Epoch: 32 | Batch: 020 / 028 | Total loss: 3.056 | Reg loss: 0.043 | Tree loss: 3.056 | Accuracy: 0.156250 | 6.845 sec/iter
Epoch: 32 | Batch: 021 / 028 | Total loss: 3.062 | Reg loss: 0.043 | Tree loss: 3.062 | Accuracy: 0.201172 | 6.845 sec/iter
Epoch: 32 | Batch: 022 / 028 | Total loss: 3.087 | Reg loss: 0.043 | Tree loss: 3.087 | Accuracy: 0.152344 | 6.845 sec/iter
Epoch: 3

Epoch: 34 | Batch: 020 / 028 | Total loss: 2.940 | Reg loss: 0.043 | Tree loss: 2.940 | Accuracy: 0.181641 | 6.838 sec/iter
Epoch: 34 | Batch: 021 / 028 | Total loss: 2.913 | Reg loss: 0.043 | Tree loss: 2.913 | Accuracy: 0.169922 | 6.838 sec/iter
Epoch: 34 | Batch: 022 / 028 | Total loss: 2.907 | Reg loss: 0.043 | Tree loss: 2.907 | Accuracy: 0.181641 | 6.839 sec/iter
Epoch: 34 | Batch: 023 / 028 | Total loss: 2.857 | Reg loss: 0.044 | Tree loss: 2.857 | Accuracy: 0.193359 | 6.839 sec/iter
Epoch: 34 | Batch: 024 / 028 | Total loss: 2.887 | Reg loss: 0.044 | Tree loss: 2.887 | Accuracy: 0.175781 | 6.839 sec/iter
Epoch: 34 | Batch: 025 / 028 | Total loss: 2.879 | Reg loss: 0.044 | Tree loss: 2.879 | Accuracy: 0.164062 | 6.839 sec/iter
Epoch: 34 | Batch: 026 / 028 | Total loss: 2.914 | Reg loss: 0.044 | Tree loss: 2.914 | Accuracy: 0.177734 | 6.839 sec/iter
Epoch: 34 | Batch: 027 / 028 | Total loss: 2.591 | Reg loss: 0.044 | Tree loss: 2.591 | Accuracy: 0.250000 | 6.836 sec/iter
Average 

Epoch: 36 | Batch: 025 / 028 | Total loss: 2.781 | Reg loss: 0.044 | Tree loss: 2.781 | Accuracy: 0.156250 | 6.843 sec/iter
Epoch: 36 | Batch: 026 / 028 | Total loss: 2.775 | Reg loss: 0.044 | Tree loss: 2.775 | Accuracy: 0.162109 | 6.843 sec/iter
Epoch: 36 | Batch: 027 / 028 | Total loss: 2.795 | Reg loss: 0.044 | Tree loss: 2.795 | Accuracy: 0.187500 | 6.839 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 37 | Batch: 000 / 028 | Total loss: 2.948 | Reg loss: 0.043 | Tree loss: 2.948 | Accuracy: 0.171875 | 6.848 sec/iter
Epoch: 37 | Batch: 001 / 028 | Total loss: 2.981 | Reg loss: 0.043 | Tree loss: 2.981 | Accuracy: 0.193359 | 6.848 sec/iter
Epoch: 37 | Batch: 002 / 028 | To

layer 10: 0.9821428571428573
Epoch: 39 | Batch: 000 / 028 | Total loss: 2.863 | Reg loss: 0.043 | Tree loss: 2.863 | Accuracy: 0.189453 | 6.822 sec/iter
Epoch: 39 | Batch: 001 / 028 | Total loss: 2.888 | Reg loss: 0.043 | Tree loss: 2.888 | Accuracy: 0.175781 | 6.821 sec/iter
Epoch: 39 | Batch: 002 / 028 | Total loss: 2.872 | Reg loss: 0.043 | Tree loss: 2.872 | Accuracy: 0.148438 | 6.82 sec/iter
Epoch: 39 | Batch: 003 / 028 | Total loss: 2.903 | Reg loss: 0.043 | Tree loss: 2.903 | Accuracy: 0.173828 | 6.818 sec/iter
Epoch: 39 | Batch: 004 / 028 | Total loss: 2.822 | Reg loss: 0.043 | Tree loss: 2.822 | Accuracy: 0.187500 | 6.819 sec/iter
Epoch: 39 | Batch: 005 / 028 | Total loss: 2.831 | Reg loss: 0.043 | Tree loss: 2.831 | Accuracy: 0.166016 | 6.819 sec/iter
Epoch: 39 | Batch: 006 / 028 | Total loss: 2.800 | Reg loss: 0.043 | Tree loss: 2.800 | Accuracy: 0.185547 | 6.82 sec/iter
Epoch: 39 | Batch: 007 / 028 | Total loss: 2.843 | Reg loss: 0.043 | Tree loss: 2.843 | Accuracy: 0.19335

Epoch: 41 | Batch: 005 / 028 | Total loss: 2.722 | Reg loss: 0.043 | Tree loss: 2.722 | Accuracy: 0.193359 | 6.828 sec/iter
Epoch: 41 | Batch: 006 / 028 | Total loss: 2.713 | Reg loss: 0.043 | Tree loss: 2.713 | Accuracy: 0.193359 | 6.828 sec/iter
Epoch: 41 | Batch: 007 / 028 | Total loss: 2.742 | Reg loss: 0.043 | Tree loss: 2.742 | Accuracy: 0.181641 | 6.826 sec/iter
Epoch: 41 | Batch: 008 / 028 | Total loss: 2.731 | Reg loss: 0.043 | Tree loss: 2.731 | Accuracy: 0.166016 | 6.827 sec/iter
Epoch: 41 | Batch: 009 / 028 | Total loss: 2.721 | Reg loss: 0.043 | Tree loss: 2.721 | Accuracy: 0.142578 | 6.828 sec/iter
Epoch: 41 | Batch: 010 / 028 | Total loss: 2.665 | Reg loss: 0.043 | Tree loss: 2.665 | Accuracy: 0.193359 | 6.828 sec/iter
Epoch: 41 | Batch: 011 / 028 | Total loss: 2.674 | Reg loss: 0.043 | Tree loss: 2.674 | Accuracy: 0.164062 | 6.828 sec/iter
Epoch: 41 | Batch: 012 / 028 | Total loss: 2.636 | Reg loss: 0.043 | Tree loss: 2.636 | Accuracy: 0.193359 | 6.829 sec/iter
Epoch: 4

Epoch: 43 | Batch: 010 / 028 | Total loss: 2.631 | Reg loss: 0.042 | Tree loss: 2.631 | Accuracy: 0.152344 | 6.837 sec/iter
Epoch: 43 | Batch: 011 / 028 | Total loss: 2.653 | Reg loss: 0.042 | Tree loss: 2.653 | Accuracy: 0.175781 | 6.836 sec/iter
Epoch: 43 | Batch: 012 / 028 | Total loss: 2.587 | Reg loss: 0.042 | Tree loss: 2.587 | Accuracy: 0.164062 | 6.836 sec/iter
Epoch: 43 | Batch: 013 / 028 | Total loss: 2.610 | Reg loss: 0.042 | Tree loss: 2.610 | Accuracy: 0.177734 | 6.834 sec/iter
Epoch: 43 | Batch: 014 / 028 | Total loss: 2.555 | Reg loss: 0.042 | Tree loss: 2.555 | Accuracy: 0.193359 | 6.832 sec/iter
Epoch: 43 | Batch: 015 / 028 | Total loss: 2.646 | Reg loss: 0.042 | Tree loss: 2.646 | Accuracy: 0.162109 | 6.831 sec/iter
Epoch: 43 | Batch: 016 / 028 | Total loss: 2.558 | Reg loss: 0.042 | Tree loss: 2.558 | Accuracy: 0.179688 | 6.829 sec/iter
Epoch: 43 | Batch: 017 / 028 | Total loss: 2.552 | Reg loss: 0.042 | Tree loss: 2.552 | Accuracy: 0.181641 | 6.827 sec/iter
Epoch: 4

Epoch: 45 | Batch: 015 / 028 | Total loss: 2.479 | Reg loss: 0.042 | Tree loss: 2.479 | Accuracy: 0.199219 | 6.829 sec/iter
Epoch: 45 | Batch: 016 / 028 | Total loss: 2.525 | Reg loss: 0.042 | Tree loss: 2.525 | Accuracy: 0.193359 | 6.83 sec/iter
Epoch: 45 | Batch: 017 / 028 | Total loss: 2.515 | Reg loss: 0.042 | Tree loss: 2.515 | Accuracy: 0.201172 | 6.83 sec/iter
Epoch: 45 | Batch: 018 / 028 | Total loss: 2.475 | Reg loss: 0.042 | Tree loss: 2.475 | Accuracy: 0.169922 | 6.83 sec/iter
Epoch: 45 | Batch: 019 / 028 | Total loss: 2.529 | Reg loss: 0.042 | Tree loss: 2.529 | Accuracy: 0.183594 | 6.83 sec/iter
Epoch: 45 | Batch: 020 / 028 | Total loss: 2.470 | Reg loss: 0.042 | Tree loss: 2.470 | Accuracy: 0.205078 | 6.831 sec/iter
Epoch: 45 | Batch: 021 / 028 | Total loss: 2.497 | Reg loss: 0.042 | Tree loss: 2.497 | Accuracy: 0.148438 | 6.831 sec/iter
Epoch: 45 | Batch: 022 / 028 | Total loss: 2.483 | Reg loss: 0.042 | Tree loss: 2.483 | Accuracy: 0.181641 | 6.831 sec/iter
Epoch: 45 | 

Epoch: 47 | Batch: 020 / 028 | Total loss: 2.432 | Reg loss: 0.042 | Tree loss: 2.432 | Accuracy: 0.175781 | 6.852 sec/iter
Epoch: 47 | Batch: 021 / 028 | Total loss: 2.497 | Reg loss: 0.042 | Tree loss: 2.497 | Accuracy: 0.187500 | 6.852 sec/iter
Epoch: 47 | Batch: 022 / 028 | Total loss: 2.431 | Reg loss: 0.042 | Tree loss: 2.431 | Accuracy: 0.238281 | 6.851 sec/iter
Epoch: 47 | Batch: 023 / 028 | Total loss: 2.402 | Reg loss: 0.042 | Tree loss: 2.402 | Accuracy: 0.166016 | 6.851 sec/iter
Epoch: 47 | Batch: 024 / 028 | Total loss: 2.405 | Reg loss: 0.042 | Tree loss: 2.405 | Accuracy: 0.208984 | 6.851 sec/iter
Epoch: 47 | Batch: 025 / 028 | Total loss: 2.420 | Reg loss: 0.042 | Tree loss: 2.420 | Accuracy: 0.160156 | 6.851 sec/iter
Epoch: 47 | Batch: 026 / 028 | Total loss: 2.414 | Reg loss: 0.042 | Tree loss: 2.414 | Accuracy: 0.173828 | 6.851 sec/iter
Epoch: 47 | Batch: 027 / 028 | Total loss: 2.690 | Reg loss: 0.042 | Tree loss: 2.690 | Accuracy: 0.187500 | 6.848 sec/iter
Average 

Epoch: 49 | Batch: 025 / 028 | Total loss: 2.397 | Reg loss: 0.042 | Tree loss: 2.397 | Accuracy: 0.169922 | 6.862 sec/iter
Epoch: 49 | Batch: 026 / 028 | Total loss: 2.425 | Reg loss: 0.042 | Tree loss: 2.425 | Accuracy: 0.181641 | 6.863 sec/iter
Epoch: 49 | Batch: 027 / 028 | Total loss: 2.393 | Reg loss: 0.042 | Tree loss: 2.393 | Accuracy: 0.062500 | 6.86 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 50 | Batch: 000 / 028 | Total loss: 2.530 | Reg loss: 0.041 | Tree loss: 2.530 | Accuracy: 0.195312 | 6.87 sec/iter
Epoch: 50 | Batch: 001 / 028 | Total loss: 2.501 | Reg loss: 0.041 | Tree loss: 2.501 | Accuracy: 0.181641 | 6.869 sec/iter
Epoch: 50 | Batch: 002 / 028 | Tota

layer 10: 0.9821428571428573
Epoch: 52 | Batch: 000 / 028 | Total loss: 2.467 | Reg loss: 0.041 | Tree loss: 2.467 | Accuracy: 0.207031 | 6.892 sec/iter
Epoch: 52 | Batch: 001 / 028 | Total loss: 2.553 | Reg loss: 0.041 | Tree loss: 2.553 | Accuracy: 0.185547 | 6.891 sec/iter
Epoch: 52 | Batch: 002 / 028 | Total loss: 2.491 | Reg loss: 0.041 | Tree loss: 2.491 | Accuracy: 0.175781 | 6.89 sec/iter
Epoch: 52 | Batch: 003 / 028 | Total loss: 2.477 | Reg loss: 0.041 | Tree loss: 2.477 | Accuracy: 0.187500 | 6.89 sec/iter
Epoch: 52 | Batch: 004 / 028 | Total loss: 2.513 | Reg loss: 0.041 | Tree loss: 2.513 | Accuracy: 0.173828 | 6.891 sec/iter
Epoch: 52 | Batch: 005 / 028 | Total loss: 2.492 | Reg loss: 0.041 | Tree loss: 2.492 | Accuracy: 0.148438 | 6.891 sec/iter
Epoch: 52 | Batch: 006 / 028 | Total loss: 2.522 | Reg loss: 0.041 | Tree loss: 2.522 | Accuracy: 0.154297 | 6.891 sec/iter
Epoch: 52 | Batch: 007 / 028 | Total loss: 2.433 | Reg loss: 0.041 | Tree loss: 2.433 | Accuracy: 0.15820

Epoch: 54 | Batch: 005 / 028 | Total loss: 2.473 | Reg loss: 0.040 | Tree loss: 2.473 | Accuracy: 0.171875 | 6.907 sec/iter
Epoch: 54 | Batch: 006 / 028 | Total loss: 2.384 | Reg loss: 0.040 | Tree loss: 2.384 | Accuracy: 0.164062 | 6.908 sec/iter
Epoch: 54 | Batch: 007 / 028 | Total loss: 2.462 | Reg loss: 0.040 | Tree loss: 2.462 | Accuracy: 0.193359 | 6.908 sec/iter
Epoch: 54 | Batch: 008 / 028 | Total loss: 2.426 | Reg loss: 0.040 | Tree loss: 2.426 | Accuracy: 0.185547 | 6.908 sec/iter
Epoch: 54 | Batch: 009 / 028 | Total loss: 2.353 | Reg loss: 0.040 | Tree loss: 2.353 | Accuracy: 0.216797 | 6.908 sec/iter
Epoch: 54 | Batch: 010 / 028 | Total loss: 2.382 | Reg loss: 0.040 | Tree loss: 2.382 | Accuracy: 0.181641 | 6.909 sec/iter
Epoch: 54 | Batch: 011 / 028 | Total loss: 2.381 | Reg loss: 0.040 | Tree loss: 2.381 | Accuracy: 0.179688 | 6.909 sec/iter
Epoch: 54 | Batch: 012 / 028 | Total loss: 2.390 | Reg loss: 0.040 | Tree loss: 2.390 | Accuracy: 0.197266 | 6.909 sec/iter
Epoch: 5

Epoch: 56 | Batch: 010 / 028 | Total loss: 2.376 | Reg loss: 0.040 | Tree loss: 2.376 | Accuracy: 0.183594 | 6.92 sec/iter
Epoch: 56 | Batch: 011 / 028 | Total loss: 2.396 | Reg loss: 0.040 | Tree loss: 2.396 | Accuracy: 0.152344 | 6.92 sec/iter
Epoch: 56 | Batch: 012 / 028 | Total loss: 2.410 | Reg loss: 0.040 | Tree loss: 2.410 | Accuracy: 0.179688 | 6.92 sec/iter
Epoch: 56 | Batch: 013 / 028 | Total loss: 2.358 | Reg loss: 0.040 | Tree loss: 2.358 | Accuracy: 0.189453 | 6.92 sec/iter
Epoch: 56 | Batch: 014 / 028 | Total loss: 2.358 | Reg loss: 0.040 | Tree loss: 2.358 | Accuracy: 0.177734 | 6.92 sec/iter
Epoch: 56 | Batch: 015 / 028 | Total loss: 2.365 | Reg loss: 0.040 | Tree loss: 2.365 | Accuracy: 0.152344 | 6.92 sec/iter
Epoch: 56 | Batch: 016 / 028 | Total loss: 2.347 | Reg loss: 0.040 | Tree loss: 2.347 | Accuracy: 0.214844 | 6.92 sec/iter
Epoch: 56 | Batch: 017 / 028 | Total loss: 2.365 | Reg loss: 0.040 | Tree loss: 2.365 | Accuracy: 0.181641 | 6.92 sec/iter
Epoch: 56 | Batc

Epoch: 58 | Batch: 015 / 028 | Total loss: 2.352 | Reg loss: 0.040 | Tree loss: 2.352 | Accuracy: 0.160156 | 6.933 sec/iter
Epoch: 58 | Batch: 016 / 028 | Total loss: 2.362 | Reg loss: 0.040 | Tree loss: 2.362 | Accuracy: 0.183594 | 6.933 sec/iter
Epoch: 58 | Batch: 017 / 028 | Total loss: 2.330 | Reg loss: 0.040 | Tree loss: 2.330 | Accuracy: 0.183594 | 6.933 sec/iter
Epoch: 58 | Batch: 018 / 028 | Total loss: 2.334 | Reg loss: 0.040 | Tree loss: 2.334 | Accuracy: 0.187500 | 6.933 sec/iter
Epoch: 58 | Batch: 019 / 028 | Total loss: 2.273 | Reg loss: 0.040 | Tree loss: 2.273 | Accuracy: 0.179688 | 6.933 sec/iter
Epoch: 58 | Batch: 020 / 028 | Total loss: 2.294 | Reg loss: 0.040 | Tree loss: 2.294 | Accuracy: 0.191406 | 6.932 sec/iter
Epoch: 58 | Batch: 021 / 028 | Total loss: 2.277 | Reg loss: 0.040 | Tree loss: 2.277 | Accuracy: 0.210938 | 6.932 sec/iter
Epoch: 58 | Batch: 022 / 028 | Total loss: 2.290 | Reg loss: 0.040 | Tree loss: 2.290 | Accuracy: 0.193359 | 6.932 sec/iter
Epoch: 5

Epoch: 60 | Batch: 020 / 028 | Total loss: 2.299 | Reg loss: 0.039 | Tree loss: 2.299 | Accuracy: 0.193359 | 6.949 sec/iter
Epoch: 60 | Batch: 021 / 028 | Total loss: 2.315 | Reg loss: 0.039 | Tree loss: 2.315 | Accuracy: 0.167969 | 6.949 sec/iter
Epoch: 60 | Batch: 022 / 028 | Total loss: 2.239 | Reg loss: 0.039 | Tree loss: 2.239 | Accuracy: 0.201172 | 6.949 sec/iter
Epoch: 60 | Batch: 023 / 028 | Total loss: 2.288 | Reg loss: 0.039 | Tree loss: 2.288 | Accuracy: 0.183594 | 6.949 sec/iter
Epoch: 60 | Batch: 024 / 028 | Total loss: 2.274 | Reg loss: 0.039 | Tree loss: 2.274 | Accuracy: 0.222656 | 6.949 sec/iter
Epoch: 60 | Batch: 025 / 028 | Total loss: 2.310 | Reg loss: 0.039 | Tree loss: 2.310 | Accuracy: 0.185547 | 6.949 sec/iter
Epoch: 60 | Batch: 026 / 028 | Total loss: 2.274 | Reg loss: 0.039 | Tree loss: 2.274 | Accuracy: 0.207031 | 6.949 sec/iter
Epoch: 60 | Batch: 027 / 028 | Total loss: 2.320 | Reg loss: 0.039 | Tree loss: 2.320 | Accuracy: 0.125000 | 6.947 sec/iter
Average 

Epoch: 62 | Batch: 025 / 028 | Total loss: 2.273 | Reg loss: 0.039 | Tree loss: 2.273 | Accuracy: 0.171875 | 6.949 sec/iter
Epoch: 62 | Batch: 026 / 028 | Total loss: 2.274 | Reg loss: 0.039 | Tree loss: 2.274 | Accuracy: 0.189453 | 6.949 sec/iter
Epoch: 62 | Batch: 027 / 028 | Total loss: 2.320 | Reg loss: 0.039 | Tree loss: 2.320 | Accuracy: 0.312500 | 6.947 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 63 | Batch: 000 / 028 | Total loss: 2.381 | Reg loss: 0.039 | Tree loss: 2.381 | Accuracy: 0.218750 | 6.956 sec/iter
Epoch: 63 | Batch: 001 / 028 | Total loss: 2.381 | Reg loss: 0.039 | Tree loss: 2.381 | Accuracy: 0.220703 | 6.957 sec/iter
Epoch: 63 | Batch: 002 / 028 | To

layer 10: 0.9821428571428573
Epoch: 65 | Batch: 000 / 028 | Total loss: 2.362 | Reg loss: 0.038 | Tree loss: 2.362 | Accuracy: 0.203125 | 6.955 sec/iter
Epoch: 65 | Batch: 001 / 028 | Total loss: 2.363 | Reg loss: 0.038 | Tree loss: 2.363 | Accuracy: 0.187500 | 6.954 sec/iter
Epoch: 65 | Batch: 002 / 028 | Total loss: 2.399 | Reg loss: 0.038 | Tree loss: 2.399 | Accuracy: 0.171875 | 6.954 sec/iter
Epoch: 65 | Batch: 003 / 028 | Total loss: 2.356 | Reg loss: 0.038 | Tree loss: 2.356 | Accuracy: 0.181641 | 6.954 sec/iter
Epoch: 65 | Batch: 004 / 028 | Total loss: 2.351 | Reg loss: 0.038 | Tree loss: 2.351 | Accuracy: 0.207031 | 6.954 sec/iter
Epoch: 65 | Batch: 005 / 028 | Total loss: 2.334 | Reg loss: 0.038 | Tree loss: 2.334 | Accuracy: 0.195312 | 6.953 sec/iter
Epoch: 65 | Batch: 006 / 028 | Total loss: 2.310 | Reg loss: 0.038 | Tree loss: 2.310 | Accuracy: 0.191406 | 6.953 sec/iter
Epoch: 65 | Batch: 007 / 028 | Total loss: 2.357 | Reg loss: 0.038 | Tree loss: 2.357 | Accuracy: 0.208

Epoch: 67 | Batch: 005 / 028 | Total loss: 2.345 | Reg loss: 0.038 | Tree loss: 2.345 | Accuracy: 0.197266 | 6.957 sec/iter
Epoch: 67 | Batch: 006 / 028 | Total loss: 2.281 | Reg loss: 0.038 | Tree loss: 2.281 | Accuracy: 0.185547 | 6.957 sec/iter
Epoch: 67 | Batch: 007 / 028 | Total loss: 2.318 | Reg loss: 0.038 | Tree loss: 2.318 | Accuracy: 0.240234 | 6.957 sec/iter
Epoch: 67 | Batch: 008 / 028 | Total loss: 2.311 | Reg loss: 0.038 | Tree loss: 2.311 | Accuracy: 0.210938 | 6.957 sec/iter
Epoch: 67 | Batch: 009 / 028 | Total loss: 2.262 | Reg loss: 0.038 | Tree loss: 2.262 | Accuracy: 0.210938 | 6.957 sec/iter
Epoch: 67 | Batch: 010 / 028 | Total loss: 2.289 | Reg loss: 0.038 | Tree loss: 2.289 | Accuracy: 0.197266 | 6.957 sec/iter
Epoch: 67 | Batch: 011 / 028 | Total loss: 2.317 | Reg loss: 0.038 | Tree loss: 2.317 | Accuracy: 0.181641 | 6.958 sec/iter
Epoch: 67 | Batch: 012 / 028 | Total loss: 2.337 | Reg loss: 0.038 | Tree loss: 2.337 | Accuracy: 0.183594 | 6.958 sec/iter
Epoch: 6

Epoch: 69 | Batch: 010 / 028 | Total loss: 2.288 | Reg loss: 0.038 | Tree loss: 2.288 | Accuracy: 0.183594 | 6.965 sec/iter
Epoch: 69 | Batch: 011 / 028 | Total loss: 2.293 | Reg loss: 0.038 | Tree loss: 2.293 | Accuracy: 0.222656 | 6.965 sec/iter
Epoch: 69 | Batch: 012 / 028 | Total loss: 2.328 | Reg loss: 0.038 | Tree loss: 2.328 | Accuracy: 0.195312 | 6.966 sec/iter
Epoch: 69 | Batch: 013 / 028 | Total loss: 2.286 | Reg loss: 0.038 | Tree loss: 2.286 | Accuracy: 0.175781 | 6.966 sec/iter
Epoch: 69 | Batch: 014 / 028 | Total loss: 2.226 | Reg loss: 0.038 | Tree loss: 2.226 | Accuracy: 0.222656 | 6.966 sec/iter
Epoch: 69 | Batch: 015 / 028 | Total loss: 2.293 | Reg loss: 0.038 | Tree loss: 2.293 | Accuracy: 0.195312 | 6.966 sec/iter
Epoch: 69 | Batch: 016 / 028 | Total loss: 2.283 | Reg loss: 0.038 | Tree loss: 2.283 | Accuracy: 0.199219 | 6.966 sec/iter
Epoch: 69 | Batch: 017 / 028 | Total loss: 2.262 | Reg loss: 0.038 | Tree loss: 2.262 | Accuracy: 0.193359 | 6.966 sec/iter
Epoch: 6

Epoch: 71 | Batch: 015 / 028 | Total loss: 2.247 | Reg loss: 0.037 | Tree loss: 2.247 | Accuracy: 0.193359 | 6.969 sec/iter
Epoch: 71 | Batch: 016 / 028 | Total loss: 2.249 | Reg loss: 0.037 | Tree loss: 2.249 | Accuracy: 0.220703 | 6.97 sec/iter
Epoch: 71 | Batch: 017 / 028 | Total loss: 2.232 | Reg loss: 0.037 | Tree loss: 2.232 | Accuracy: 0.218750 | 6.97 sec/iter
Epoch: 71 | Batch: 018 / 028 | Total loss: 2.268 | Reg loss: 0.037 | Tree loss: 2.268 | Accuracy: 0.185547 | 6.97 sec/iter
Epoch: 71 | Batch: 019 / 028 | Total loss: 2.292 | Reg loss: 0.037 | Tree loss: 2.292 | Accuracy: 0.169922 | 6.97 sec/iter
Epoch: 71 | Batch: 020 / 028 | Total loss: 2.245 | Reg loss: 0.038 | Tree loss: 2.245 | Accuracy: 0.175781 | 6.97 sec/iter
Epoch: 71 | Batch: 021 / 028 | Total loss: 2.217 | Reg loss: 0.038 | Tree loss: 2.217 | Accuracy: 0.193359 | 6.97 sec/iter
Epoch: 71 | Batch: 022 / 028 | Total loss: 2.292 | Reg loss: 0.038 | Tree loss: 2.292 | Accuracy: 0.150391 | 6.97 sec/iter
Epoch: 71 | Bat

Epoch: 73 | Batch: 020 / 028 | Total loss: 2.216 | Reg loss: 0.037 | Tree loss: 2.216 | Accuracy: 0.230469 | 6.982 sec/iter
Epoch: 73 | Batch: 021 / 028 | Total loss: 2.269 | Reg loss: 0.037 | Tree loss: 2.269 | Accuracy: 0.205078 | 6.982 sec/iter
Epoch: 73 | Batch: 022 / 028 | Total loss: 2.212 | Reg loss: 0.037 | Tree loss: 2.212 | Accuracy: 0.183594 | 6.982 sec/iter
Epoch: 73 | Batch: 023 / 028 | Total loss: 2.305 | Reg loss: 0.037 | Tree loss: 2.305 | Accuracy: 0.158203 | 6.982 sec/iter
Epoch: 73 | Batch: 024 / 028 | Total loss: 2.221 | Reg loss: 0.037 | Tree loss: 2.221 | Accuracy: 0.197266 | 6.982 sec/iter
Epoch: 73 | Batch: 025 / 028 | Total loss: 2.198 | Reg loss: 0.037 | Tree loss: 2.198 | Accuracy: 0.191406 | 6.982 sec/iter
Epoch: 73 | Batch: 026 / 028 | Total loss: 2.184 | Reg loss: 0.037 | Tree loss: 2.184 | Accuracy: 0.218750 | 6.982 sec/iter
Epoch: 73 | Batch: 027 / 028 | Total loss: 2.051 | Reg loss: 0.037 | Tree loss: 2.051 | Accuracy: 0.375000 | 6.98 sec/iter
Average s

Epoch: 75 | Batch: 025 / 028 | Total loss: 2.253 | Reg loss: 0.037 | Tree loss: 2.253 | Accuracy: 0.189453 | 6.989 sec/iter
Epoch: 75 | Batch: 026 / 028 | Total loss: 2.207 | Reg loss: 0.037 | Tree loss: 2.207 | Accuracy: 0.210938 | 6.989 sec/iter
Epoch: 75 | Batch: 027 / 028 | Total loss: 2.166 | Reg loss: 0.037 | Tree loss: 2.166 | Accuracy: 0.187500 | 6.987 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 76 | Batch: 000 / 028 | Total loss: 2.312 | Reg loss: 0.037 | Tree loss: 2.312 | Accuracy: 0.205078 | 6.995 sec/iter
Epoch: 76 | Batch: 001 / 028 | Total loss: 2.302 | Reg loss: 0.037 | Tree loss: 2.302 | Accuracy: 0.205078 | 6.995 sec/iter
Epoch: 76 | Batch: 002 / 028 | To

layer 10: 0.9821428571428573
Epoch: 78 | Batch: 000 / 028 | Total loss: 2.299 | Reg loss: 0.036 | Tree loss: 2.299 | Accuracy: 0.191406 | 7.002 sec/iter
Epoch: 78 | Batch: 001 / 028 | Total loss: 2.308 | Reg loss: 0.036 | Tree loss: 2.308 | Accuracy: 0.214844 | 7.002 sec/iter
Epoch: 78 | Batch: 002 / 028 | Total loss: 2.371 | Reg loss: 0.036 | Tree loss: 2.371 | Accuracy: 0.191406 | 7.002 sec/iter
Epoch: 78 | Batch: 003 / 028 | Total loss: 2.297 | Reg loss: 0.036 | Tree loss: 2.297 | Accuracy: 0.185547 | 7.002 sec/iter
Epoch: 78 | Batch: 004 / 028 | Total loss: 2.334 | Reg loss: 0.036 | Tree loss: 2.334 | Accuracy: 0.175781 | 7.002 sec/iter
Epoch: 78 | Batch: 005 / 028 | Total loss: 2.284 | Reg loss: 0.036 | Tree loss: 2.284 | Accuracy: 0.187500 | 7.002 sec/iter
Epoch: 78 | Batch: 006 / 028 | Total loss: 2.283 | Reg loss: 0.036 | Tree loss: 2.283 | Accuracy: 0.236328 | 7.002 sec/iter
Epoch: 78 | Batch: 007 / 028 | Total loss: 2.240 | Reg loss: 0.036 | Tree loss: 2.240 | Accuracy: 0.201

Epoch: 80 | Batch: 005 / 028 | Total loss: 2.261 | Reg loss: 0.036 | Tree loss: 2.261 | Accuracy: 0.195312 | 7.0 sec/iter
Epoch: 80 | Batch: 006 / 028 | Total loss: 2.287 | Reg loss: 0.036 | Tree loss: 2.287 | Accuracy: 0.189453 | 7.0 sec/iter
Epoch: 80 | Batch: 007 / 028 | Total loss: 2.225 | Reg loss: 0.036 | Tree loss: 2.225 | Accuracy: 0.208984 | 6.999 sec/iter
Epoch: 80 | Batch: 008 / 028 | Total loss: 2.272 | Reg loss: 0.036 | Tree loss: 2.272 | Accuracy: 0.177734 | 6.999 sec/iter
Epoch: 80 | Batch: 009 / 028 | Total loss: 2.259 | Reg loss: 0.036 | Tree loss: 2.259 | Accuracy: 0.207031 | 6.999 sec/iter
Epoch: 80 | Batch: 010 / 028 | Total loss: 2.255 | Reg loss: 0.036 | Tree loss: 2.255 | Accuracy: 0.199219 | 6.998 sec/iter
Epoch: 80 | Batch: 011 / 028 | Total loss: 2.257 | Reg loss: 0.036 | Tree loss: 2.257 | Accuracy: 0.175781 | 6.997 sec/iter
Epoch: 80 | Batch: 012 / 028 | Total loss: 2.283 | Reg loss: 0.036 | Tree loss: 2.283 | Accuracy: 0.191406 | 6.997 sec/iter
Epoch: 80 | 

Epoch: 82 | Batch: 010 / 028 | Total loss: 2.225 | Reg loss: 0.036 | Tree loss: 2.225 | Accuracy: 0.212891 | 7.007 sec/iter
Epoch: 82 | Batch: 011 / 028 | Total loss: 2.243 | Reg loss: 0.036 | Tree loss: 2.243 | Accuracy: 0.197266 | 7.007 sec/iter
Epoch: 82 | Batch: 012 / 028 | Total loss: 2.231 | Reg loss: 0.036 | Tree loss: 2.231 | Accuracy: 0.197266 | 7.007 sec/iter
Epoch: 82 | Batch: 013 / 028 | Total loss: 2.328 | Reg loss: 0.036 | Tree loss: 2.328 | Accuracy: 0.146484 | 7.007 sec/iter
Epoch: 82 | Batch: 014 / 028 | Total loss: 2.229 | Reg loss: 0.036 | Tree loss: 2.229 | Accuracy: 0.201172 | 7.007 sec/iter
Epoch: 82 | Batch: 015 / 028 | Total loss: 2.259 | Reg loss: 0.036 | Tree loss: 2.259 | Accuracy: 0.205078 | 7.007 sec/iter
Epoch: 82 | Batch: 016 / 028 | Total loss: 2.196 | Reg loss: 0.036 | Tree loss: 2.196 | Accuracy: 0.187500 | 7.007 sec/iter
Epoch: 82 | Batch: 017 / 028 | Total loss: 2.168 | Reg loss: 0.036 | Tree loss: 2.168 | Accuracy: 0.228516 | 7.007 sec/iter
Epoch: 8

Epoch: 84 | Batch: 015 / 028 | Total loss: 2.218 | Reg loss: 0.036 | Tree loss: 2.218 | Accuracy: 0.212891 | 7.019 sec/iter
Epoch: 84 | Batch: 016 / 028 | Total loss: 2.192 | Reg loss: 0.036 | Tree loss: 2.192 | Accuracy: 0.208984 | 7.018 sec/iter
Epoch: 84 | Batch: 017 / 028 | Total loss: 2.200 | Reg loss: 0.036 | Tree loss: 2.200 | Accuracy: 0.220703 | 7.017 sec/iter
Epoch: 84 | Batch: 018 / 028 | Total loss: 2.217 | Reg loss: 0.036 | Tree loss: 2.217 | Accuracy: 0.199219 | 7.017 sec/iter
Epoch: 84 | Batch: 019 / 028 | Total loss: 2.218 | Reg loss: 0.036 | Tree loss: 2.218 | Accuracy: 0.212891 | 7.017 sec/iter
Epoch: 84 | Batch: 020 / 028 | Total loss: 2.197 | Reg loss: 0.036 | Tree loss: 2.197 | Accuracy: 0.181641 | 7.017 sec/iter
Epoch: 84 | Batch: 021 / 028 | Total loss: 2.229 | Reg loss: 0.036 | Tree loss: 2.229 | Accuracy: 0.222656 | 7.017 sec/iter
Epoch: 84 | Batch: 022 / 028 | Total loss: 2.230 | Reg loss: 0.036 | Tree loss: 2.230 | Accuracy: 0.195312 | 7.017 sec/iter
Epoch: 8

Epoch: 86 | Batch: 020 / 028 | Total loss: 2.193 | Reg loss: 0.035 | Tree loss: 2.193 | Accuracy: 0.222656 | 7.016 sec/iter
Epoch: 86 | Batch: 021 / 028 | Total loss: 2.140 | Reg loss: 0.035 | Tree loss: 2.140 | Accuracy: 0.218750 | 7.016 sec/iter
Epoch: 86 | Batch: 022 / 028 | Total loss: 2.185 | Reg loss: 0.036 | Tree loss: 2.185 | Accuracy: 0.222656 | 7.016 sec/iter
Epoch: 86 | Batch: 023 / 028 | Total loss: 2.170 | Reg loss: 0.036 | Tree loss: 2.170 | Accuracy: 0.203125 | 7.016 sec/iter
Epoch: 86 | Batch: 024 / 028 | Total loss: 2.232 | Reg loss: 0.036 | Tree loss: 2.232 | Accuracy: 0.207031 | 7.016 sec/iter
Epoch: 86 | Batch: 025 / 028 | Total loss: 2.173 | Reg loss: 0.036 | Tree loss: 2.173 | Accuracy: 0.195312 | 7.015 sec/iter
Epoch: 86 | Batch: 026 / 028 | Total loss: 2.188 | Reg loss: 0.036 | Tree loss: 2.188 | Accuracy: 0.197266 | 7.014 sec/iter
Epoch: 86 | Batch: 027 / 028 | Total loss: 2.076 | Reg loss: 0.036 | Tree loss: 2.076 | Accuracy: 0.187500 | 7.012 sec/iter
Average 

Epoch: 88 | Batch: 025 / 028 | Total loss: 2.202 | Reg loss: 0.035 | Tree loss: 2.202 | Accuracy: 0.191406 | 7.016 sec/iter
Epoch: 88 | Batch: 026 / 028 | Total loss: 2.153 | Reg loss: 0.035 | Tree loss: 2.153 | Accuracy: 0.193359 | 7.016 sec/iter
Epoch: 88 | Batch: 027 / 028 | Total loss: 2.107 | Reg loss: 0.035 | Tree loss: 2.107 | Accuracy: 0.125000 | 7.014 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 89 | Batch: 000 / 028 | Total loss: 2.274 | Reg loss: 0.035 | Tree loss: 2.274 | Accuracy: 0.214844 | 7.015 sec/iter
Epoch: 89 | Batch: 001 / 028 | Total loss: 2.294 | Reg loss: 0.035 | Tree loss: 2.294 | Accuracy: 0.166016 | 7.016 sec/iter
Epoch: 89 | Batch: 002 / 028 | To

layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 91 | Batch: 000 / 028 | Total loss: 2.279 | Reg loss: 0.035 | Tree loss: 2.279 | Accuracy: 0.199219 | 7.021 sec/iter
Epoch: 91 | Batch: 001 / 028 | Total loss: 2.288 | Reg loss: 0.035 | Tree loss: 2.288 | Accuracy: 0.197266 | 7.02 sec/iter
Epoch: 91 | Batch: 002 / 028 | Total loss: 2.263 | Reg loss: 0.035 | Tree loss: 2.263 | Accuracy: 0.197266 | 7.02 sec/iter
Epoch: 91 | Batch: 003 / 028 | Total loss: 2.358 | Reg loss: 0.035 | Tree loss: 2.358 | Accuracy: 0.197266 | 7.019 sec/iter
Epoch: 91 | Batch: 004 / 028 | Total loss: 2.224 | Reg loss: 0.035 | Tree loss: 2.224 | Accuracy: 0.230469 | 7.019 sec/iter
Epoch: 91 | Batch: 005 / 028 | Total loss: 2.245 | Reg loss: 0.035 | Tree loss: 2.245 | Accuracy: 0.201172 | 7.019 sec/iter
Epoch: 91 | Batch: 006 / 028 | Total loss: 2.258 | Reg loss: 0.035 | Tree loss: 2.258 | Accuracy: 0.203125 | 7.019 sec

Epoch: 93 | Batch: 004 / 028 | Total loss: 2.242 | Reg loss: 0.034 | Tree loss: 2.242 | Accuracy: 0.193359 | 7.03 sec/iter
Epoch: 93 | Batch: 005 / 028 | Total loss: 2.227 | Reg loss: 0.034 | Tree loss: 2.227 | Accuracy: 0.187500 | 7.03 sec/iter
Epoch: 93 | Batch: 006 / 028 | Total loss: 2.293 | Reg loss: 0.034 | Tree loss: 2.293 | Accuracy: 0.226562 | 7.03 sec/iter
Epoch: 93 | Batch: 007 / 028 | Total loss: 2.326 | Reg loss: 0.034 | Tree loss: 2.326 | Accuracy: 0.177734 | 7.03 sec/iter
Epoch: 93 | Batch: 008 / 028 | Total loss: 2.276 | Reg loss: 0.034 | Tree loss: 2.276 | Accuracy: 0.187500 | 7.03 sec/iter
Epoch: 93 | Batch: 009 / 028 | Total loss: 2.234 | Reg loss: 0.034 | Tree loss: 2.234 | Accuracy: 0.207031 | 7.03 sec/iter
Epoch: 93 | Batch: 010 / 028 | Total loss: 2.216 | Reg loss: 0.034 | Tree loss: 2.216 | Accuracy: 0.195312 | 7.03 sec/iter
Epoch: 93 | Batch: 011 / 028 | Total loss: 2.255 | Reg loss: 0.034 | Tree loss: 2.255 | Accuracy: 0.158203 | 7.031 sec/iter
Epoch: 93 | Bat

Epoch: 95 | Batch: 009 / 028 | Total loss: 2.289 | Reg loss: 0.034 | Tree loss: 2.289 | Accuracy: 0.167969 | 7.039 sec/iter
Epoch: 95 | Batch: 010 / 028 | Total loss: 2.211 | Reg loss: 0.034 | Tree loss: 2.211 | Accuracy: 0.210938 | 7.038 sec/iter
Epoch: 95 | Batch: 011 / 028 | Total loss: 2.209 | Reg loss: 0.034 | Tree loss: 2.209 | Accuracy: 0.199219 | 7.038 sec/iter
Epoch: 95 | Batch: 012 / 028 | Total loss: 2.208 | Reg loss: 0.034 | Tree loss: 2.208 | Accuracy: 0.179688 | 7.038 sec/iter
Epoch: 95 | Batch: 013 / 028 | Total loss: 2.185 | Reg loss: 0.034 | Tree loss: 2.185 | Accuracy: 0.191406 | 7.038 sec/iter
Epoch: 95 | Batch: 014 / 028 | Total loss: 2.174 | Reg loss: 0.034 | Tree loss: 2.174 | Accuracy: 0.214844 | 7.038 sec/iter
Epoch: 95 | Batch: 015 / 028 | Total loss: 2.199 | Reg loss: 0.034 | Tree loss: 2.199 | Accuracy: 0.183594 | 7.038 sec/iter
Epoch: 95 | Batch: 016 / 028 | Total loss: 2.225 | Reg loss: 0.034 | Tree loss: 2.225 | Accuracy: 0.183594 | 7.038 sec/iter
Epoch: 9

Epoch: 97 | Batch: 014 / 028 | Total loss: 2.157 | Reg loss: 0.034 | Tree loss: 2.157 | Accuracy: 0.197266 | 7.047 sec/iter
Epoch: 97 | Batch: 015 / 028 | Total loss: 2.205 | Reg loss: 0.034 | Tree loss: 2.205 | Accuracy: 0.199219 | 7.047 sec/iter
Epoch: 97 | Batch: 016 / 028 | Total loss: 2.207 | Reg loss: 0.034 | Tree loss: 2.207 | Accuracy: 0.201172 | 7.047 sec/iter
Epoch: 97 | Batch: 017 / 028 | Total loss: 2.198 | Reg loss: 0.034 | Tree loss: 2.198 | Accuracy: 0.181641 | 7.047 sec/iter
Epoch: 97 | Batch: 018 / 028 | Total loss: 2.214 | Reg loss: 0.034 | Tree loss: 2.214 | Accuracy: 0.212891 | 7.047 sec/iter
Epoch: 97 | Batch: 019 / 028 | Total loss: 2.157 | Reg loss: 0.034 | Tree loss: 2.157 | Accuracy: 0.207031 | 7.047 sec/iter
Epoch: 97 | Batch: 020 / 028 | Total loss: 2.169 | Reg loss: 0.034 | Tree loss: 2.169 | Accuracy: 0.199219 | 7.047 sec/iter
Epoch: 97 | Batch: 021 / 028 | Total loss: 2.174 | Reg loss: 0.034 | Tree loss: 2.174 | Accuracy: 0.203125 | 7.047 sec/iter
Epoch: 9

Epoch: 99 | Batch: 019 / 028 | Total loss: 2.194 | Reg loss: 0.034 | Tree loss: 2.194 | Accuracy: 0.214844 | 7.047 sec/iter
Epoch: 99 | Batch: 020 / 028 | Total loss: 2.189 | Reg loss: 0.034 | Tree loss: 2.189 | Accuracy: 0.193359 | 7.047 sec/iter
Epoch: 99 | Batch: 021 / 028 | Total loss: 2.190 | Reg loss: 0.034 | Tree loss: 2.190 | Accuracy: 0.197266 | 7.047 sec/iter
Epoch: 99 | Batch: 022 / 028 | Total loss: 2.144 | Reg loss: 0.034 | Tree loss: 2.144 | Accuracy: 0.203125 | 7.047 sec/iter
Epoch: 99 | Batch: 023 / 028 | Total loss: 2.177 | Reg loss: 0.034 | Tree loss: 2.177 | Accuracy: 0.203125 | 7.047 sec/iter
Epoch: 99 | Batch: 024 / 028 | Total loss: 2.152 | Reg loss: 0.034 | Tree loss: 2.152 | Accuracy: 0.169922 | 7.048 sec/iter
Epoch: 99 | Batch: 025 / 028 | Total loss: 2.144 | Reg loss: 0.034 | Tree loss: 2.144 | Accuracy: 0.187500 | 7.048 sec/iter
Epoch: 99 | Batch: 026 / 028 | Total loss: 2.131 | Reg loss: 0.034 | Tree loss: 2.131 | Accuracy: 0.210938 | 7.048 sec/iter
Epoch: 9

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 11.93182416425375


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 3799


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


13840
















Average comprehensibility: 64.17425638325875
std comprehensibility: 4.402990269709591
var comprehensibility: 19.38632331515733
minimum comprehensibility: 40
maximum comprehensibility: 72
