In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 256
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.134346008300781 | KNN Loss: 6.232436656951904 | BCE Loss: 1.9019089937210083
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.142910957336426 | KNN Loss: 6.23236083984375 | BCE Loss: 1.9105498790740967
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.195171356201172 | KNN Loss: 6.232297420501709 | BCE Loss: 1.962874412536621
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.213438034057617 | KNN Loss: 6.232086658477783 | BCE Loss: 1.981351375579834
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.177547454833984 | KNN Loss: 6.232169151306152 | BCE Loss: 1.945378303527832
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.185166358947754 | KNN Loss: 6.23225212097168 | BCE Loss: 1.9529139995574951
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.142288208007812 | KNN Loss: 6.232115745544434 | BCE Loss: 1.910172939300537
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.10138988494873 | KNN Loss: 6.232104778289795 | BCE Loss: 1.869284868240

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.31865930557251 | KNN Loss: 6.1899638175964355 | BCE Loss: 1.1286954879760742
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.280426025390625 | KNN Loss: 6.181915283203125 | BCE Loss: 1.098510980606079
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.262902736663818 | KNN Loss: 6.180129051208496 | BCE Loss: 1.0827738046646118
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 7.279411315917969 | KNN Loss: 6.1774139404296875 | BCE Loss: 1.1019973754882812
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 7.265653133392334 | KNN Loss: 6.172086715698242 | BCE Loss: 1.0935664176940918
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 7.274707794189453 | KNN Loss: 6.172283172607422 | BCE Loss: 1.1024248600006104
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 7.233590126037598 | KNN Loss: 6.16810941696167 | BCE Loss: 1.0654809474945068
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 7.221923828125 | KNN Loss: 6.165285587310791 | BCE Loss: 1.

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.826141834259033 | KNN Loss: 5.800817489624023 | BCE Loss: 1.0253244638442993
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.8504414558410645 | KNN Loss: 5.806095600128174 | BCE Loss: 1.044345736503601
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.882237911224365 | KNN Loss: 5.80863094329834 | BCE Loss: 1.0736068487167358
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.842680931091309 | KNN Loss: 5.801398277282715 | BCE Loss: 1.0412826538085938
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.855913162231445 | KNN Loss: 5.80914831161499 | BCE Loss: 1.046764850616455
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.846343040466309 | KNN Loss: 5.790359020233154 | BCE Loss: 1.0559839010238647
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.848967552185059 | KNN Loss: 5.792078018188477 | BCE Loss: 1.056889533996582
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.8172688484191895 | KNN Loss: 5.775050163269043 | BCE Loss: 

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.697096824645996 | KNN Loss: 5.649101257324219 | BCE Loss: 1.0479955673217773
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.687907695770264 | KNN Loss: 5.642653942108154 | BCE Loss: 1.0452537536621094
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.684504985809326 | KNN Loss: 5.637027740478516 | BCE Loss: 1.047477126121521
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.726513862609863 | KNN Loss: 5.665369033813477 | BCE Loss: 1.0611450672149658
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.746542930603027 | KNN Loss: 5.689077854156494 | BCE Loss: 1.057464838027954
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.70589017868042 | KNN Loss: 5.639133453369141 | BCE Loss: 1.0667567253112793
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.704431056976318 | KNN Loss: 5.657406806945801 | BCE Loss: 1.0470243692398071
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.731531620025635 | KNN Loss: 5.679787635803223 | BCE Loss: 

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.752091884613037 | KNN Loss: 5.704764366149902 | BCE Loss: 1.0473275184631348
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.636539459228516 | KNN Loss: 5.615149021148682 | BCE Loss: 1.0213905572891235
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.690989971160889 | KNN Loss: 5.654325485229492 | BCE Loss: 1.036664366722107
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.6825852394104 | KNN Loss: 5.641018867492676 | BCE Loss: 1.041566252708435
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.712943077087402 | KNN Loss: 5.646514415740967 | BCE Loss: 1.066428780555725
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.7194294929504395 | KNN Loss: 5.655757427215576 | BCE Loss: 1.0636720657348633
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.678214073181152 | KNN Loss: 5.614256858825684 | BCE Loss: 1.0639570951461792
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 6.651150703430176 | KNN Loss: 5.6342315673828125 | BCE Loss: 

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.708770275115967 | KNN Loss: 5.668721675872803 | BCE Loss: 1.040048599243164
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.696259498596191 | KNN Loss: 5.674571514129639 | BCE Loss: 1.0216882228851318
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.712340831756592 | KNN Loss: 5.658931255340576 | BCE Loss: 1.053409457206726
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.7107768058776855 | KNN Loss: 5.620892524719238 | BCE Loss: 1.0898842811584473
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.682238578796387 | KNN Loss: 5.6255388259887695 | BCE Loss: 1.0566998720169067
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.720672130584717 | KNN Loss: 5.663945198059082 | BCE Loss: 1.0567269325256348
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.668315887451172 | KNN Loss: 5.620479106903076 | BCE Loss: 1.0478367805480957
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 6.702477931976318 | KNN Loss: 5.645752906799316 | BCE Loss

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.648127555847168 | KNN Loss: 5.601924896240234 | BCE Loss: 1.0462024211883545
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.660036087036133 | KNN Loss: 5.614432334899902 | BCE Loss: 1.0456039905548096
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.674197673797607 | KNN Loss: 5.626425743103027 | BCE Loss: 1.04777193069458
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.666938304901123 | KNN Loss: 5.613509178161621 | BCE Loss: 1.0534290075302124
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.675251007080078 | KNN Loss: 5.637765407562256 | BCE Loss: 1.0374853610992432
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.763779163360596 | KNN Loss: 5.714779853820801 | BCE Loss: 1.0489994287490845
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.637813091278076 | KNN Loss: 5.600454807281494 | BCE Loss: 1.037358283996582
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 6.673925399780273 | KNN Loss: 5.616410732269287 | BCE Loss: 1.

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.664184093475342 | KNN Loss: 5.608595371246338 | BCE Loss: 1.0555886030197144
Epoch    76: reducing learning rate of group 0 to 3.5000e-03.
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.741213321685791 | KNN Loss: 5.684236526489258 | BCE Loss: 1.0569767951965332
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.6903486251831055 | KNN Loss: 5.609631538391113 | BCE Loss: 1.0807173252105713
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.673060417175293 | KNN Loss: 5.642778396606445 | BCE Loss: 1.0302822589874268
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.763747692108154 | KNN Loss: 5.694355010986328 | BCE Loss: 1.0693926811218262
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.699502468109131 | KNN Loss: 5.643534183502197 | BCE Loss: 1.055968165397644
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 6.661111831665039 | KNN Loss: 5.623355388641357 | BCE Loss: 1.0377564430236816
Epoch 77 / 500 | iteration 0 / 30 | Total Lo

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.652132511138916 | KNN Loss: 5.604006767272949 | BCE Loss: 1.0481258630752563
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.720600605010986 | KNN Loss: 5.674674034118652 | BCE Loss: 1.045926570892334
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.62612771987915 | KNN Loss: 5.615839004516602 | BCE Loss: 1.0102887153625488
Epoch    87: reducing learning rate of group 0 to 2.4500e-03.
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.633025646209717 | KNN Loss: 5.601206302642822 | BCE Loss: 1.031819462776184
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.683670997619629 | KNN Loss: 5.6762213706970215 | BCE Loss: 1.0074496269226074
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.679852485656738 | KNN Loss: 5.623786926269531 | BCE Loss: 1.0560657978057861
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 6.6787495613098145 | KNN Loss: 5.649400234222412 | BCE Loss: 1.0293493270874023
Epoch 87 / 500 | iteration 20 / 30 | Total Lo

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 6.631680488586426 | KNN Loss: 5.6056647300720215 | BCE Loss: 1.0260157585144043
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.675940990447998 | KNN Loss: 5.615682601928711 | BCE Loss: 1.0602585077285767
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.635870456695557 | KNN Loss: 5.600407123565674 | BCE Loss: 1.0354633331298828
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.744009017944336 | KNN Loss: 5.690992832183838 | BCE Loss: 1.053016185760498
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.644431114196777 | KNN Loss: 5.592933654785156 | BCE Loss: 1.051497459411621
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.630541801452637 | KNN Loss: 5.60539436340332 | BCE Loss: 1.0251474380493164
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 6.730459213256836 | KNN Loss: 5.694844722747803 | BCE Loss: 1.0356147289276123
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 6.668277740478516 | KNN Loss: 5.62092399597168 | BCE Loss: 1

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 6.658070087432861 | KNN Loss: 5.632130146026611 | BCE Loss: 1.0259400606155396
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.688243389129639 | KNN Loss: 5.659222602844238 | BCE Loss: 1.0290207862854004
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.672596454620361 | KNN Loss: 5.618656635284424 | BCE Loss: 1.053939700126648
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.6593017578125 | KNN Loss: 5.598618030548096 | BCE Loss: 1.0606837272644043
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.634176254272461 | KNN Loss: 5.597323417663574 | BCE Loss: 1.0368525981903076
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.699206829071045 | KNN Loss: 5.653608322143555 | BCE Loss: 1.0455985069274902
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 6.661615371704102 | KNN Loss: 5.602289199829102 | BCE Loss: 1.0593262910842896
Epoch   109: reducing learning rate of group 0 to 1.7150e-03.
Epoch 109 / 500 | iteration 0 / 30 | Tot

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 6.67404317855835 | KNN Loss: 5.633572101593018 | BCE Loss: 1.040471076965332
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.711287498474121 | KNN Loss: 5.689304828643799 | BCE Loss: 1.0219826698303223
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.646989822387695 | KNN Loss: 5.6071672439575195 | BCE Loss: 1.0398228168487549
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.760867118835449 | KNN Loss: 5.693462371826172 | BCE Loss: 1.0674049854278564
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.62050199508667 | KNN Loss: 5.5963134765625 | BCE Loss: 1.0241886377334595
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.706087112426758 | KNN Loss: 5.681237697601318 | BCE Loss: 1.0248496532440186
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 6.6842851638793945 | KNN Loss: 5.635470390319824 | BCE Loss: 1.0488150119781494
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 6.648004531860352 | KNN Loss: 5.608229637145996 | BC

Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 6.704456329345703 | KNN Loss: 5.672240734100342 | BCE Loss: 1.0322155952453613
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.653114318847656 | KNN Loss: 5.595937728881836 | BCE Loss: 1.0571763515472412
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.709204196929932 | KNN Loss: 5.65390682220459 | BCE Loss: 1.0552973747253418
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.687353134155273 | KNN Loss: 5.6281633377075195 | BCE Loss: 1.0591899156570435
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.65064001083374 | KNN Loss: 5.596179008483887 | BCE Loss: 1.054460883140564
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 6.714920997619629 | KNN Loss: 5.661226749420166 | BCE Loss: 1.053694248199463
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 6.652503967285156 | KNN Loss: 5.608149528503418 | BCE Loss: 1.0443546772003174
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 6.6246256828308105 | KNN Loss: 5.599148273468018 | BC

Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 6.698829650878906 | KNN Loss: 5.633424758911133 | BCE Loss: 1.0654048919677734
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.625405311584473 | KNN Loss: 5.594020366668701 | BCE Loss: 1.0313847064971924
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.707432746887207 | KNN Loss: 5.647252082824707 | BCE Loss: 1.060180902481079
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.672046184539795 | KNN Loss: 5.6268157958984375 | BCE Loss: 1.0452302694320679
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.658307075500488 | KNN Loss: 5.604001522064209 | BCE Loss: 1.0543056726455688
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.6671857833862305 | KNN Loss: 5.611615180969238 | BCE Loss: 1.0555706024169922
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 6.750798225402832 | KNN Loss: 5.708104133605957 | BCE Loss: 1.042694091796875
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 6.663018226623535 | KNN Loss: 5.617671012878418 | 

Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 6.661893844604492 | KNN Loss: 5.606217861175537 | BCE Loss: 1.055675745010376
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.6903815269470215 | KNN Loss: 5.643754959106445 | BCE Loss: 1.0466265678405762
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.752562046051025 | KNN Loss: 5.692263603210449 | BCE Loss: 1.0602985620498657
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.657244682312012 | KNN Loss: 5.609175682067871 | BCE Loss: 1.0480687618255615
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.66038703918457 | KNN Loss: 5.635445594787598 | BCE Loss: 1.0249414443969727
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.649491310119629 | KNN Loss: 5.5976338386535645 | BCE Loss: 1.0518572330474854
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 6.681614875793457 | KNN Loss: 5.608395099639893 | BCE Loss: 1.0732195377349854
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 6.6671528816223145 | KNN Loss: 5.628996849060059 

Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 6.674901962280273 | KNN Loss: 5.630231857299805 | BCE Loss: 1.0446702241897583
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.654296398162842 | KNN Loss: 5.601962566375732 | BCE Loss: 1.052333950996399
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.682249069213867 | KNN Loss: 5.6502532958984375 | BCE Loss: 1.0319960117340088
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.679445266723633 | KNN Loss: 5.625283718109131 | BCE Loss: 1.0541613101959229
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.658733367919922 | KNN Loss: 5.603339195251465 | BCE Loss: 1.055393934249878
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.667007923126221 | KNN Loss: 5.627679347991943 | BCE Loss: 1.0393284559249878
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 6.674077033996582 | KNN Loss: 5.599569797515869 | BCE Loss: 1.0745073556900024
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 6.671327590942383 | KNN Loss: 5.607838153839111 | B

Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 6.654973983764648 | KNN Loss: 5.608983516693115 | BCE Loss: 1.0459905862808228
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.742824554443359 | KNN Loss: 5.693508148193359 | BCE Loss: 1.049316644668579
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.654591083526611 | KNN Loss: 5.618764877319336 | BCE Loss: 1.035826325416565
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.7074432373046875 | KNN Loss: 5.6508283615112305 | BCE Loss: 1.0566151142120361
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.663857936859131 | KNN Loss: 5.627274990081787 | BCE Loss: 1.0365828275680542
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.692356586456299 | KNN Loss: 5.63015079498291 | BCE Loss: 1.0622057914733887
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 6.702936172485352 | KNN Loss: 5.620359897613525 | BCE Loss: 1.0825763940811157
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 6.643446922302246 | KNN Loss: 5.60283088684082 | BC

Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 6.694148540496826 | KNN Loss: 5.653868198394775 | BCE Loss: 1.0402802228927612
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.639190673828125 | KNN Loss: 5.594297885894775 | BCE Loss: 1.0448927879333496
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.691929817199707 | KNN Loss: 5.653535842895508 | BCE Loss: 1.0383939743041992
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.645488739013672 | KNN Loss: 5.59893274307251 | BCE Loss: 1.046555995941162
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.681106090545654 | KNN Loss: 5.605096340179443 | BCE Loss: 1.0760098695755005
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.676749229431152 | KNN Loss: 5.617606163024902 | BCE Loss: 1.059142827987671
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 6.642585754394531 | KNN Loss: 5.596022129058838 | BCE Loss: 1.0465636253356934
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 6.689332008361816 | KNN Loss: 5.602673053741455 | BC

Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 6.627572059631348 | KNN Loss: 5.595192909240723 | BCE Loss: 1.032378911972046
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.663183212280273 | KNN Loss: 5.635003089904785 | BCE Loss: 1.0281801223754883
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.699776649475098 | KNN Loss: 5.630991458892822 | BCE Loss: 1.0687849521636963
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.659555435180664 | KNN Loss: 5.627314567565918 | BCE Loss: 1.032240629196167
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.647723197937012 | KNN Loss: 5.598362922668457 | BCE Loss: 1.0493602752685547
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.6737213134765625 | KNN Loss: 5.616128921508789 | BCE Loss: 1.0575923919677734
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 6.685408115386963 | KNN Loss: 5.64432430267334 | BCE Loss: 1.041083812713623
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 6.655383586883545 | KNN Loss: 5.611464500427246 | BCE

Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 6.655445575714111 | KNN Loss: 5.625586986541748 | BCE Loss: 1.0298585891723633
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.628536224365234 | KNN Loss: 5.599133491516113 | BCE Loss: 1.0294029712677002
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.652048587799072 | KNN Loss: 5.597593307495117 | BCE Loss: 1.054455280303955
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.661632537841797 | KNN Loss: 5.63989782333374 | BCE Loss: 1.0217347145080566
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.678160667419434 | KNN Loss: 5.622011661529541 | BCE Loss: 1.0561492443084717
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.617369651794434 | KNN Loss: 5.591113567352295 | BCE Loss: 1.0262560844421387
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 6.713399887084961 | KNN Loss: 5.659091472625732 | BCE Loss: 1.0543086528778076
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 6.675411701202393 | KNN Loss: 5.647708415985107 | BC

Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 6.657649993896484 | KNN Loss: 5.607687473297119 | BCE Loss: 1.0499622821807861
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.754766941070557 | KNN Loss: 5.709828853607178 | BCE Loss: 1.044938087463379
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.678420066833496 | KNN Loss: 5.601779937744141 | BCE Loss: 1.0766398906707764
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.679011344909668 | KNN Loss: 5.625884532928467 | BCE Loss: 1.0531266927719116
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.633805274963379 | KNN Loss: 5.612529277801514 | BCE Loss: 1.0212759971618652
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.6993584632873535 | KNN Loss: 5.666172027587891 | BCE Loss: 1.033186435699463
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 6.652082443237305 | KNN Loss: 5.599139213562012 | BCE Loss: 1.052943468093872
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 6.6958818435668945 | KNN Loss: 5.645003318786621 | 

Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 6.653926372528076 | KNN Loss: 5.606147766113281 | BCE Loss: 1.047778606414795
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.676407814025879 | KNN Loss: 5.63228178024292 | BCE Loss: 1.044126033782959
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.656899929046631 | KNN Loss: 5.595702648162842 | BCE Loss: 1.0611974000930786
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.660134315490723 | KNN Loss: 5.617269515991211 | BCE Loss: 1.0428646802902222
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.662383079528809 | KNN Loss: 5.634780406951904 | BCE Loss: 1.0276026725769043
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.758236408233643 | KNN Loss: 5.689918041229248 | BCE Loss: 1.0683183670043945
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 6.651419639587402 | KNN Loss: 5.6304168701171875 | BCE Loss: 1.0210027694702148
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 6.6547675132751465 | KNN Loss: 5.621790409088135 | B

Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 6.696529388427734 | KNN Loss: 5.639762878417969 | BCE Loss: 1.0567665100097656
Epoch   236: reducing learning rate of group 0 to 4.8445e-05.
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.649460315704346 | KNN Loss: 5.628605842590332 | BCE Loss: 1.0208545923233032
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.685639381408691 | KNN Loss: 5.610720157623291 | BCE Loss: 1.0749194622039795
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.641822338104248 | KNN Loss: 5.592301845550537 | BCE Loss: 1.0495203733444214
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.626049041748047 | KNN Loss: 5.608924865722656 | BCE Loss: 1.017124056816101
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.674127578735352 | KNN Loss: 5.6397809982299805 | BCE Loss: 1.0343466997146606
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 6.650354862213135 | KNN Loss: 5.605289459228516 | BCE Loss: 1.0450654029846191
Epoch 237 / 500 | iteration 0 / 30 | 

Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 6.6576457023620605 | KNN Loss: 5.5944600105285645 | BCE Loss: 1.0631855726242065
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.663517951965332 | KNN Loss: 5.638101577758789 | BCE Loss: 1.025416612625122
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.626543045043945 | KNN Loss: 5.609468936920166 | BCE Loss: 1.0170741081237793
Epoch   247: reducing learning rate of group 0 to 3.3911e-05.
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.650997638702393 | KNN Loss: 5.62204647064209 | BCE Loss: 1.0289510488510132
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.646271705627441 | KNN Loss: 5.599305152893066 | BCE Loss: 1.046966314315796
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.669759750366211 | KNN Loss: 5.625309467315674 | BCE Loss: 1.0444505214691162
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 6.7035112380981445 | KNN Loss: 5.647756576538086 | BCE Loss: 1.0557544231414795
Epoch 247 / 500 | iteration 20 / 30 |

Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 6.667695045471191 | KNN Loss: 5.6216912269592285 | BCE Loss: 1.0460035800933838
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.692686080932617 | KNN Loss: 5.645289897918701 | BCE Loss: 1.047395944595337
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.649008274078369 | KNN Loss: 5.617607116699219 | BCE Loss: 1.0314011573791504
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.6841535568237305 | KNN Loss: 5.596723556518555 | BCE Loss: 1.0874302387237549
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.634777069091797 | KNN Loss: 5.604469299316406 | BCE Loss: 1.0303077697753906
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.67627477645874 | KNN Loss: 5.621228218078613 | BCE Loss: 1.0550464391708374
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 6.69715690612793 | KNN Loss: 5.625401020050049 | BCE Loss: 1.0717560052871704
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 6.637969017028809 | KNN Loss: 5.601071357727051 | B

Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 6.664780616760254 | KNN Loss: 5.627842426300049 | BCE Loss: 1.0369383096694946
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.676797866821289 | KNN Loss: 5.615804195404053 | BCE Loss: 1.0609939098358154
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.64719295501709 | KNN Loss: 5.5883612632751465 | BCE Loss: 1.0588319301605225
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.666759967803955 | KNN Loss: 5.632730484008789 | BCE Loss: 1.034029483795166
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.625604629516602 | KNN Loss: 5.595890998840332 | BCE Loss: 1.0297133922576904
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.802513122558594 | KNN Loss: 5.718184947967529 | BCE Loss: 1.0843284130096436
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 6.638161659240723 | KNN Loss: 5.5976948738098145 | BCE Loss: 1.0404669046401978
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 6.642302989959717 | KNN Loss: 5.595124244689941 | 

Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 6.674072265625 | KNN Loss: 5.660765647888184 | BCE Loss: 1.013306736946106
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.715733528137207 | KNN Loss: 5.669051647186279 | BCE Loss: 1.0466818809509277
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.679945945739746 | KNN Loss: 5.631691932678223 | BCE Loss: 1.0482540130615234
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.69443941116333 | KNN Loss: 5.64022970199585 | BCE Loss: 1.054209589958191
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.630275726318359 | KNN Loss: 5.594376564025879 | BCE Loss: 1.035899043083191
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.631745338439941 | KNN Loss: 5.596572399139404 | BCE Loss: 1.0351728200912476
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 6.754539966583252 | KNN Loss: 5.698551177978516 | BCE Loss: 1.0559886693954468
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 6.6699090003967285 | KNN Loss: 5.611634731292725 | BCE Lo

Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 6.671802520751953 | KNN Loss: 5.5978875160217285 | BCE Loss: 1.0739150047302246
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.651455879211426 | KNN Loss: 5.603970527648926 | BCE Loss: 1.0474852323532104
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.652804374694824 | KNN Loss: 5.61592435836792 | BCE Loss: 1.0368802547454834
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.71419095993042 | KNN Loss: 5.644701957702637 | BCE Loss: 1.0694891214370728
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.685623645782471 | KNN Loss: 5.640969276428223 | BCE Loss: 1.044654369354248
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.697491645812988 | KNN Loss: 5.651240348815918 | BCE Loss: 1.0462511777877808
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 6.675634384155273 | KNN Loss: 5.6447858810424805 | BCE Loss: 1.030848741531372
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 6.704216957092285 | KNN Loss: 5.651259422302246 | BC

Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 6.656129360198975 | KNN Loss: 5.616641521453857 | BCE Loss: 1.0394878387451172
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.625397205352783 | KNN Loss: 5.59857702255249 | BCE Loss: 1.0268203020095825
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.697272300720215 | KNN Loss: 5.645955562591553 | BCE Loss: 1.0513166189193726
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.641472816467285 | KNN Loss: 5.615651607513428 | BCE Loss: 1.025821328163147
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.6731109619140625 | KNN Loss: 5.607153415679932 | BCE Loss: 1.0659573078155518
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.7098565101623535 | KNN Loss: 5.660873889923096 | BCE Loss: 1.0489826202392578
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 6.67380952835083 | KNN Loss: 5.613523960113525 | BCE Loss: 1.0602856874465942
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 6.668851852416992 | KNN Loss: 5.606689929962158 | B

Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 6.672600746154785 | KNN Loss: 5.624458312988281 | BCE Loss: 1.0481421947479248
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.650379657745361 | KNN Loss: 5.622498989105225 | BCE Loss: 1.0278807878494263
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.665004253387451 | KNN Loss: 5.5984039306640625 | BCE Loss: 1.0666003227233887
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.640080451965332 | KNN Loss: 5.607560634613037 | BCE Loss: 1.032520055770874
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.65714168548584 | KNN Loss: 5.631250381469727 | BCE Loss: 1.0258913040161133
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.616027355194092 | KNN Loss: 5.594976902008057 | BCE Loss: 1.0210504531860352
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 6.681960105895996 | KNN Loss: 5.64342737197876 | BCE Loss: 1.0385327339172363
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 6.656435489654541 | KNN Loss: 5.596707820892334 | B

Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 6.678255558013916 | KNN Loss: 5.6677680015563965 | BCE Loss: 1.0104875564575195
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.689375877380371 | KNN Loss: 5.652961254119873 | BCE Loss: 1.036414384841919
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.680019855499268 | KNN Loss: 5.643674373626709 | BCE Loss: 1.036345362663269
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.6661505699157715 | KNN Loss: 5.605766296386719 | BCE Loss: 1.0603841543197632
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.694101333618164 | KNN Loss: 5.628669261932373 | BCE Loss: 1.0654323101043701
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.675800323486328 | KNN Loss: 5.61116886138916 | BCE Loss: 1.0646312236785889
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 6.72736930847168 | KNN Loss: 5.654639720916748 | BCE Loss: 1.0727298259735107
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 6.694095611572266 | KNN Loss: 5.659055709838867 | BC

Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 6.678648948669434 | KNN Loss: 5.5883588790893555 | BCE Loss: 1.0902900695800781
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.739222526550293 | KNN Loss: 5.660601615905762 | BCE Loss: 1.0786211490631104
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.646421432495117 | KNN Loss: 5.591630458831787 | BCE Loss: 1.05479097366333
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.64228630065918 | KNN Loss: 5.595375061035156 | BCE Loss: 1.0469114780426025
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.666668891906738 | KNN Loss: 5.597116947174072 | BCE Loss: 1.069551944732666
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.706603050231934 | KNN Loss: 5.650901794433594 | BCE Loss: 1.0557011365890503
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 6.7288289070129395 | KNN Loss: 5.684493064880371 | BCE Loss: 1.0443357229232788
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 6.691821098327637 | KNN Loss: 5.651113033294678 | BC

Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 6.646759986877441 | KNN Loss: 5.597620010375977 | BCE Loss: 1.0491397380828857
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.764778137207031 | KNN Loss: 5.709144115447998 | BCE Loss: 1.055633783340454
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.627460479736328 | KNN Loss: 5.6091413497924805 | BCE Loss: 1.0183193683624268
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.647134780883789 | KNN Loss: 5.595400810241699 | BCE Loss: 1.051734209060669
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.679005146026611 | KNN Loss: 5.637478828430176 | BCE Loss: 1.0415263175964355
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.653043746948242 | KNN Loss: 5.622384548187256 | BCE Loss: 1.0306590795516968
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 6.637462615966797 | KNN Loss: 5.598940849304199 | BCE Loss: 1.0385217666625977
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 6.659110069274902 | KNN Loss: 5.6266913414001465 |

Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 6.65253210067749 | KNN Loss: 5.621780872344971 | BCE Loss: 1.03075110912323
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.707886695861816 | KNN Loss: 5.651646137237549 | BCE Loss: 1.0562403202056885
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.715741157531738 | KNN Loss: 5.6532769203186035 | BCE Loss: 1.0624641180038452
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.67443323135376 | KNN Loss: 5.591287136077881 | BCE Loss: 1.083146095275879
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.659801959991455 | KNN Loss: 5.597300052642822 | BCE Loss: 1.0625019073486328
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.667236328125 | KNN Loss: 5.633293628692627 | BCE Loss: 1.033942461013794
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 6.644109725952148 | KNN Loss: 5.598114013671875 | BCE Loss: 1.0459957122802734
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 6.669895648956299 | KNN Loss: 5.616996765136719 | BCE Loss

Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 6.68305778503418 | KNN Loss: 5.632984638214111 | BCE Loss: 1.0500731468200684
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.68613862991333 | KNN Loss: 5.650823593139648 | BCE Loss: 1.0353151559829712
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.624023914337158 | KNN Loss: 5.597300052642822 | BCE Loss: 1.026723861694336
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.684288024902344 | KNN Loss: 5.6822404861450195 | BCE Loss: 1.0020475387573242
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.667510032653809 | KNN Loss: 5.615283012390137 | BCE Loss: 1.0522270202636719
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.641329288482666 | KNN Loss: 5.63900089263916 | BCE Loss: 1.0023282766342163
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 6.627747535705566 | KNN Loss: 5.5930304527282715 | BCE Loss: 1.0347168445587158
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 6.639228820800781 | KNN Loss: 5.6053595542907715 | B

Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 6.65596342086792 | KNN Loss: 5.624782085418701 | BCE Loss: 1.0311812162399292
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.692930221557617 | KNN Loss: 5.64038610458374 | BCE Loss: 1.052544116973877
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.738572597503662 | KNN Loss: 5.692989826202393 | BCE Loss: 1.0455827713012695
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.660191059112549 | KNN Loss: 5.621275424957275 | BCE Loss: 1.038915753364563
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.649115085601807 | KNN Loss: 5.595417499542236 | BCE Loss: 1.0536975860595703
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.64787483215332 | KNN Loss: 5.6032562255859375 | BCE Loss: 1.044618844985962
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 6.65924072265625 | KNN Loss: 5.599930286407471 | BCE Loss: 1.0593105554580688
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 6.699351787567139 | KNN Loss: 5.653830051422119 | BCE L

Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 6.649996757507324 | KNN Loss: 5.616557598114014 | BCE Loss: 1.0334391593933105
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.693792343139648 | KNN Loss: 5.605706214904785 | BCE Loss: 1.0880863666534424
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.647134304046631 | KNN Loss: 5.61154317855835 | BCE Loss: 1.0355911254882812
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.659626007080078 | KNN Loss: 5.60463809967041 | BCE Loss: 1.054987907409668
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.6413164138793945 | KNN Loss: 5.598004341125488 | BCE Loss: 1.0433118343353271
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.700868129730225 | KNN Loss: 5.666780948638916 | BCE Loss: 1.0340871810913086
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 6.625696182250977 | KNN Loss: 5.607847213745117 | BCE Loss: 1.0178489685058594
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 6.661925315856934 | KNN Loss: 5.617593765258789 | BC

Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 6.6920294761657715 | KNN Loss: 5.613185882568359 | BCE Loss: 1.0788434743881226
Epoch   396: reducing learning rate of group 0 to 3.2856e-07.
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.6530351638793945 | KNN Loss: 5.603431701660156 | BCE Loss: 1.0496032238006592
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.655535697937012 | KNN Loss: 5.617489337921143 | BCE Loss: 1.0380465984344482
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.682718276977539 | KNN Loss: 5.658532619476318 | BCE Loss: 1.0241857767105103
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.668723106384277 | KNN Loss: 5.647424221038818 | BCE Loss: 1.021298885345459
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.668506622314453 | KNN Loss: 5.598869800567627 | BCE Loss: 1.0696368217468262
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 6.645707607269287 | KNN Loss: 5.610860824584961 | BCE Loss: 1.0348467826843262
Epoch 397 / 500 | iteration 0 / 30 |

Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 6.642141819000244 | KNN Loss: 5.604162216186523 | BCE Loss: 1.0379797220230103
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.670628070831299 | KNN Loss: 5.603306770324707 | BCE Loss: 1.0673213005065918
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.640661716461182 | KNN Loss: 5.593796253204346 | BCE Loss: 1.0468655824661255
Epoch   407: reducing learning rate of group 0 to 2.2999e-07.
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.66605806350708 | KNN Loss: 5.602176666259766 | BCE Loss: 1.063881278038025
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.675603866577148 | KNN Loss: 5.60789155960083 | BCE Loss: 1.0677125453948975
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.6283440589904785 | KNN Loss: 5.600160598754883 | BCE Loss: 1.0281834602355957
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 6.658884048461914 | KNN Loss: 5.6255292892456055 | BCE Loss: 1.0333545207977295
Epoch 407 / 500 | iteration 20 / 30 | 

Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 6.6894378662109375 | KNN Loss: 5.614733695983887 | BCE Loss: 1.0747041702270508
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.639437675476074 | KNN Loss: 5.6038498878479 | BCE Loss: 1.0355877876281738
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.629878520965576 | KNN Loss: 5.6123948097229 | BCE Loss: 1.0174837112426758
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.733138084411621 | KNN Loss: 5.692924976348877 | BCE Loss: 1.040212869644165
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.654616355895996 | KNN Loss: 5.622399806976318 | BCE Loss: 1.0322165489196777
Epoch   418: reducing learning rate of group 0 to 1.6100e-07.
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.652790069580078 | KNN Loss: 5.606924533843994 | BCE Loss: 1.045865535736084
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 6.72524356842041 | KNN Loss: 5.679763317108154 | BCE Loss: 1.0454803705215454
Epoch 418 / 500 | iteration 10 / 30 | Total 

Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 6.676879405975342 | KNN Loss: 5.637180805206299 | BCE Loss: 1.039698600769043
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.648305892944336 | KNN Loss: 5.593700408935547 | BCE Loss: 1.05460524559021
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.667458534240723 | KNN Loss: 5.640483856201172 | BCE Loss: 1.0269744396209717
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.638981819152832 | KNN Loss: 5.5918378829956055 | BCE Loss: 1.047143816947937
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.662377834320068 | KNN Loss: 5.619941711425781 | BCE Loss: 1.0424362421035767
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.639060020446777 | KNN Loss: 5.599078178405762 | BCE Loss: 1.0399816036224365
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 6.719330310821533 | KNN Loss: 5.666464328765869 | BCE Loss: 1.0528661012649536
Epoch   429: reducing learning rate of group 0 to 1.1270e-07.
Epoch 429 / 500 | iteration 0 / 30 | Tot

Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 6.673843860626221 | KNN Loss: 5.61041259765625 | BCE Loss: 1.0634312629699707
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.785172462463379 | KNN Loss: 5.7191948890686035 | BCE Loss: 1.0659778118133545
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.679688930511475 | KNN Loss: 5.632978439331055 | BCE Loss: 1.04671049118042
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.754046440124512 | KNN Loss: 5.708021640777588 | BCE Loss: 1.046025037765503
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.663088798522949 | KNN Loss: 5.595804691314697 | BCE Loss: 1.067284107208252
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.690507411956787 | KNN Loss: 5.624465465545654 | BCE Loss: 1.0660420656204224
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 6.637878894805908 | KNN Loss: 5.6069488525390625 | BCE Loss: 1.0309301614761353
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 6.65732479095459 | KNN Loss: 5.608053207397461 | BCE

Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 6.639945983886719 | KNN Loss: 5.6075263023376465 | BCE Loss: 1.0324194431304932
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.705848693847656 | KNN Loss: 5.629512786865234 | BCE Loss: 1.0763359069824219
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.694762229919434 | KNN Loss: 5.648298740386963 | BCE Loss: 1.0464632511138916
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.657317161560059 | KNN Loss: 5.600157737731934 | BCE Loss: 1.057159423828125
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.647641181945801 | KNN Loss: 5.602902412414551 | BCE Loss: 1.044739007949829
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 6.652824878692627 | KNN Loss: 5.614565849304199 | BCE Loss: 1.0382590293884277
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 6.627872943878174 | KNN Loss: 5.608248710632324 | BCE Loss: 1.0196242332458496
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 6.683035373687744 | KNN Loss: 5.635503768920898 | B

Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 6.704185962677002 | KNN Loss: 5.65622091293335 | BCE Loss: 1.0479649305343628
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.7202558517456055 | KNN Loss: 5.668644905090332 | BCE Loss: 1.051611065864563
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.665330410003662 | KNN Loss: 5.607845306396484 | BCE Loss: 1.0574849843978882
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.696154594421387 | KNN Loss: 5.6513214111328125 | BCE Loss: 1.0448329448699951
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.623646259307861 | KNN Loss: 5.6089348793029785 | BCE Loss: 1.0147113800048828
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.674218654632568 | KNN Loss: 5.618333339691162 | BCE Loss: 1.0558851957321167
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 6.702733039855957 | KNN Loss: 5.653134346008301 | BCE Loss: 1.0495986938476562
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 6.70739221572876 | KNN Loss: 5.6386895179748535 |

Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 6.616541862487793 | KNN Loss: 5.59578800201416 | BCE Loss: 1.020754098892212
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.731933116912842 | KNN Loss: 5.661772727966309 | BCE Loss: 1.0701605081558228
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.6596879959106445 | KNN Loss: 5.622954845428467 | BCE Loss: 1.0367331504821777
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.687159538269043 | KNN Loss: 5.616240978240967 | BCE Loss: 1.070918321609497
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.671160697937012 | KNN Loss: 5.635196208953857 | BCE Loss: 1.0359644889831543
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.667853355407715 | KNN Loss: 5.627680778503418 | BCE Loss: 1.0401724576950073
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 6.660338401794434 | KNN Loss: 5.59870719909668 | BCE Loss: 1.061631441116333
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 6.759291648864746 | KNN Loss: 5.706948757171631 | BCE

Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 6.665174961090088 | KNN Loss: 5.610004901885986 | BCE Loss: 1.0551700592041016
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.739755153656006 | KNN Loss: 5.687113285064697 | BCE Loss: 1.0526418685913086
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.690363883972168 | KNN Loss: 5.615386486053467 | BCE Loss: 1.0749772787094116
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.667476177215576 | KNN Loss: 5.633862495422363 | BCE Loss: 1.033613681793213
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.6948466300964355 | KNN Loss: 5.644711971282959 | BCE Loss: 1.0501346588134766
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.717164039611816 | KNN Loss: 5.660238265991211 | BCE Loss: 1.0569257736206055
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 6.6404266357421875 | KNN Loss: 5.60251522064209 | BCE Loss: 1.0379115343093872
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 6.659369468688965 | KNN Loss: 5.6508355140686035 |

Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 6.654203414916992 | KNN Loss: 5.602957248687744 | BCE Loss: 1.0512464046478271
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.699536323547363 | KNN Loss: 5.646448135375977 | BCE Loss: 1.0530880689620972
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.636556625366211 | KNN Loss: 5.590782165527344 | BCE Loss: 1.0457744598388672
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.629528522491455 | KNN Loss: 5.594355583190918 | BCE Loss: 1.035172939300537
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.672906875610352 | KNN Loss: 5.615921974182129 | BCE Loss: 1.0569850206375122
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.664638042449951 | KNN Loss: 5.611959457397461 | BCE Loss: 1.0526784658432007
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 6.716658115386963 | KNN Loss: 5.670037269592285 | BCE Loss: 1.0466207265853882
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 6.660101413726807 | KNN Loss: 5.63767147064209 | BC

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.5823,  3.8489,  2.5081,  3.2581,  3.1805,  0.6070,  2.5824,  2.1524,
          2.2921,  1.8750,  1.9751,  2.1245,  0.7700,  1.7622,  1.2701,  1.5015,
          2.4763,  2.8642,  2.7307,  2.2322,  1.7439,  2.9127,  2.2424,  2.2796,
          2.2368,  1.7092,  2.0704,  1.2598,  1.4428,  0.3424, -0.2011,  0.9850,
          0.2514,  1.0013,  1.5070,  1.4405,  0.9859,  3.0369,  0.7450,  1.2302,
          0.9363, -0.6310, -0.2804,  2.2822,  2.0094,  0.7214, -0.2034,  0.1042,
          1.4171,  2.4876,  1.8115,  0.1094,  1.3559,  0.4724, -0.6058,  1.1043,
          1.4397,  1.3207,  1.1768,  1.8007,  0.5496,  0.8158,  0.1961,  1.6543,
          1.2789,  1.6702, -1.8123,  0.2765,  1.9163,  2.1375,  2.2309,  0.4083,
          1.2641,  2.4029,  1.9281,  1.2849,  0.2437,  0.7234,  0.2041,  1.5392,
          0.0822,  0.4188,  1.5643, -0.4425,  0.2419, -1.0505, -2.2799, -0.3460,
          0.5775, -1.7547,  0.4210, -0.1554, -0.4991, -0.8217,  0.5404,  1.1915,
         -0.6738, -0.5723,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [12]:
dataset_ = [d[0].cpu() for d in dataset]

In [13]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:01<00:00, 12.72it/s]


In [14]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [15]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [16]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [18]:
tensor_dataset = torch.stack(dataset_)

In [19]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [20]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [21]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [22]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [23]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [24]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [25]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [26]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [27]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [28]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [29]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [30]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [31]:
losses = []
accs = []
sparsity = []

In [32]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 030 | Total loss: 9.652 | Reg loss: 0.009 | Tree loss: 9.652 | Accuracy: 0.000000 | 1.009 sec/iter
Epoch: 00 | Batch: 001 / 030 | Total loss: 9.636 | Reg loss: 0.009 | Tree loss: 9.636 | Accuracy: 0.000000 | 0.728 sec/iter
Epoch: 00 | Batch: 002 / 030 | Total loss: 9.621 | Reg loss: 0.008 | Tree loss: 9.621 | Accuracy: 0.000000 | 0.6 sec/iter
Epoch: 00 | Batch: 003 / 030 | Total loss: 9.605 | Reg loss: 0.008 | Tree loss: 9.605 | Accuracy: 0.000000 | 0.542 sec/iter
Epoch: 00 | Batch: 004 / 030 | Total loss: 9.591 | Reg loss: 0.008 | Tree loss: 9.591 | Accuracy: 0.000000 | 0.509 sec/iter
Epoch: 00 | Batch: 005 / 030 | Total loss: 9.575 | Reg loss: 0.008 | Tree loss: 9.575 | Accuracy: 0.000000 | 0.486 sec/iter
Epoch: 00 | Batch: 006 / 030 | Total loss: 9.561 | Reg loss: 0.008 | Tree loss: 9.561 | Accuracy: 0.000000 | 0.469 sec/iter
Epoch: 00 | Batch: 

Epoch: 02 | Batch: 002 / 030 | Total loss: 9.236 | Reg loss: 0.008 | Tree loss: 9.236 | Accuracy: 0.541016 | 0.359 sec/iter
Epoch: 02 | Batch: 003 / 030 | Total loss: 9.221 | Reg loss: 0.008 | Tree loss: 9.221 | Accuracy: 0.537109 | 0.358 sec/iter
Epoch: 02 | Batch: 004 / 030 | Total loss: 9.206 | Reg loss: 0.008 | Tree loss: 9.206 | Accuracy: 0.587891 | 0.357 sec/iter
Epoch: 02 | Batch: 005 / 030 | Total loss: 9.195 | Reg loss: 0.009 | Tree loss: 9.195 | Accuracy: 0.541016 | 0.356 sec/iter
Epoch: 02 | Batch: 006 / 030 | Total loss: 9.180 | Reg loss: 0.009 | Tree loss: 9.180 | Accuracy: 0.589844 | 0.356 sec/iter
Epoch: 02 | Batch: 007 / 030 | Total loss: 9.160 | Reg loss: 0.009 | Tree loss: 9.160 | Accuracy: 0.568359 | 0.355 sec/iter
Epoch: 02 | Batch: 008 / 030 | Total loss: 9.148 | Reg loss: 0.010 | Tree loss: 9.148 | Accuracy: 0.539062 | 0.356 sec/iter
Epoch: 02 | Batch: 009 / 030 | Total loss: 9.133 | Reg loss: 0.010 | Tree loss: 9.133 | Accuracy: 0.585938 | 0.357 sec/iter
Epoch: 0

Epoch: 04 | Batch: 005 / 030 | Total loss: 8.836 | Reg loss: 0.013 | Tree loss: 8.836 | Accuracy: 0.580078 | 0.347 sec/iter
Epoch: 04 | Batch: 006 / 030 | Total loss: 8.829 | Reg loss: 0.014 | Tree loss: 8.829 | Accuracy: 0.539062 | 0.347 sec/iter
Epoch: 04 | Batch: 007 / 030 | Total loss: 8.806 | Reg loss: 0.014 | Tree loss: 8.806 | Accuracy: 0.570312 | 0.347 sec/iter
Epoch: 04 | Batch: 008 / 030 | Total loss: 8.791 | Reg loss: 0.014 | Tree loss: 8.791 | Accuracy: 0.585938 | 0.347 sec/iter
Epoch: 04 | Batch: 009 / 030 | Total loss: 8.780 | Reg loss: 0.014 | Tree loss: 8.780 | Accuracy: 0.527344 | 0.348 sec/iter
Epoch: 04 | Batch: 010 / 030 | Total loss: 8.756 | Reg loss: 0.015 | Tree loss: 8.756 | Accuracy: 0.562500 | 0.348 sec/iter
Epoch: 04 | Batch: 011 / 030 | Total loss: 8.745 | Reg loss: 0.015 | Tree loss: 8.745 | Accuracy: 0.556641 | 0.348 sec/iter
Epoch: 04 | Batch: 012 / 030 | Total loss: 8.724 | Reg loss: 0.015 | Tree loss: 8.724 | Accuracy: 0.558594 | 0.348 sec/iter
Epoch: 0

Epoch: 06 | Batch: 008 / 030 | Total loss: 8.421 | Reg loss: 0.018 | Tree loss: 8.421 | Accuracy: 0.585938 | 0.351 sec/iter
Epoch: 06 | Batch: 009 / 030 | Total loss: 8.413 | Reg loss: 0.018 | Tree loss: 8.413 | Accuracy: 0.583984 | 0.351 sec/iter
Epoch: 06 | Batch: 010 / 030 | Total loss: 8.383 | Reg loss: 0.019 | Tree loss: 8.383 | Accuracy: 0.613281 | 0.351 sec/iter
Epoch: 06 | Batch: 011 / 030 | Total loss: 8.377 | Reg loss: 0.019 | Tree loss: 8.377 | Accuracy: 0.548828 | 0.351 sec/iter
Epoch: 06 | Batch: 012 / 030 | Total loss: 8.360 | Reg loss: 0.019 | Tree loss: 8.360 | Accuracy: 0.562500 | 0.351 sec/iter
Epoch: 06 | Batch: 013 / 030 | Total loss: 8.336 | Reg loss: 0.020 | Tree loss: 8.336 | Accuracy: 0.562500 | 0.351 sec/iter
Epoch: 06 | Batch: 014 / 030 | Total loss: 8.312 | Reg loss: 0.020 | Tree loss: 8.312 | Accuracy: 0.625000 | 0.351 sec/iter
Epoch: 06 | Batch: 015 / 030 | Total loss: 8.296 | Reg loss: 0.020 | Tree loss: 8.296 | Accuracy: 0.583984 | 0.351 sec/iter
Epoch: 0

Epoch: 08 | Batch: 011 / 030 | Total loss: 7.980 | Reg loss: 0.023 | Tree loss: 7.980 | Accuracy: 0.591797 | 0.353 sec/iter
Epoch: 08 | Batch: 012 / 030 | Total loss: 7.944 | Reg loss: 0.023 | Tree loss: 7.944 | Accuracy: 0.566406 | 0.354 sec/iter
Epoch: 08 | Batch: 013 / 030 | Total loss: 7.936 | Reg loss: 0.023 | Tree loss: 7.936 | Accuracy: 0.572266 | 0.354 sec/iter
Epoch: 08 | Batch: 014 / 030 | Total loss: 7.921 | Reg loss: 0.024 | Tree loss: 7.921 | Accuracy: 0.576172 | 0.354 sec/iter
Epoch: 08 | Batch: 015 / 030 | Total loss: 7.892 | Reg loss: 0.024 | Tree loss: 7.892 | Accuracy: 0.570312 | 0.353 sec/iter
Epoch: 08 | Batch: 016 / 030 | Total loss: 7.879 | Reg loss: 0.024 | Tree loss: 7.879 | Accuracy: 0.554688 | 0.353 sec/iter
Epoch: 08 | Batch: 017 / 030 | Total loss: 7.854 | Reg loss: 0.025 | Tree loss: 7.854 | Accuracy: 0.564453 | 0.353 sec/iter
Epoch: 08 | Batch: 018 / 030 | Total loss: 7.826 | Reg loss: 0.025 | Tree loss: 7.826 | Accuracy: 0.564453 | 0.353 sec/iter
Epoch: 0

Epoch: 10 | Batch: 014 / 030 | Total loss: 7.449 | Reg loss: 0.027 | Tree loss: 7.449 | Accuracy: 0.601562 | 0.356 sec/iter
Epoch: 10 | Batch: 015 / 030 | Total loss: 7.456 | Reg loss: 0.027 | Tree loss: 7.456 | Accuracy: 0.562500 | 0.356 sec/iter
Epoch: 10 | Batch: 016 / 030 | Total loss: 7.412 | Reg loss: 0.027 | Tree loss: 7.412 | Accuracy: 0.583984 | 0.356 sec/iter
Epoch: 10 | Batch: 017 / 030 | Total loss: 7.404 | Reg loss: 0.028 | Tree loss: 7.404 | Accuracy: 0.542969 | 0.356 sec/iter
Epoch: 10 | Batch: 018 / 030 | Total loss: 7.379 | Reg loss: 0.028 | Tree loss: 7.379 | Accuracy: 0.562500 | 0.356 sec/iter
Epoch: 10 | Batch: 019 / 030 | Total loss: 7.342 | Reg loss: 0.028 | Tree loss: 7.342 | Accuracy: 0.550781 | 0.357 sec/iter
Epoch: 10 | Batch: 020 / 030 | Total loss: 7.313 | Reg loss: 0.029 | Tree loss: 7.313 | Accuracy: 0.582031 | 0.356 sec/iter
Epoch: 10 | Batch: 021 / 030 | Total loss: 7.276 | Reg loss: 0.029 | Tree loss: 7.276 | Accuracy: 0.619141 | 0.357 sec/iter
Epoch: 1

Epoch: 12 | Batch: 017 / 030 | Total loss: 6.881 | Reg loss: 0.030 | Tree loss: 6.881 | Accuracy: 0.607422 | 0.36 sec/iter
Epoch: 12 | Batch: 018 / 030 | Total loss: 6.865 | Reg loss: 0.031 | Tree loss: 6.865 | Accuracy: 0.578125 | 0.36 sec/iter
Epoch: 12 | Batch: 019 / 030 | Total loss: 6.829 | Reg loss: 0.031 | Tree loss: 6.829 | Accuracy: 0.570312 | 0.36 sec/iter
Epoch: 12 | Batch: 020 / 030 | Total loss: 6.806 | Reg loss: 0.031 | Tree loss: 6.806 | Accuracy: 0.582031 | 0.36 sec/iter
Epoch: 12 | Batch: 021 / 030 | Total loss: 6.771 | Reg loss: 0.032 | Tree loss: 6.771 | Accuracy: 0.568359 | 0.36 sec/iter
Epoch: 12 | Batch: 022 / 030 | Total loss: 6.777 | Reg loss: 0.032 | Tree loss: 6.777 | Accuracy: 0.542969 | 0.36 sec/iter
Epoch: 12 | Batch: 023 / 030 | Total loss: 6.732 | Reg loss: 0.032 | Tree loss: 6.732 | Accuracy: 0.572266 | 0.36 sec/iter
Epoch: 12 | Batch: 024 / 030 | Total loss: 6.687 | Reg loss: 0.033 | Tree loss: 6.687 | Accuracy: 0.611328 | 0.36 sec/iter
Epoch: 12 | Batc

Epoch: 14 | Batch: 020 / 030 | Total loss: 6.270 | Reg loss: 0.034 | Tree loss: 6.270 | Accuracy: 0.599609 | 0.361 sec/iter
Epoch: 14 | Batch: 021 / 030 | Total loss: 6.228 | Reg loss: 0.034 | Tree loss: 6.228 | Accuracy: 0.570312 | 0.361 sec/iter
Epoch: 14 | Batch: 022 / 030 | Total loss: 6.225 | Reg loss: 0.035 | Tree loss: 6.225 | Accuracy: 0.544922 | 0.361 sec/iter
Epoch: 14 | Batch: 023 / 030 | Total loss: 6.182 | Reg loss: 0.035 | Tree loss: 6.182 | Accuracy: 0.562500 | 0.361 sec/iter
Epoch: 14 | Batch: 024 / 030 | Total loss: 6.164 | Reg loss: 0.036 | Tree loss: 6.164 | Accuracy: 0.570312 | 0.361 sec/iter
Epoch: 14 | Batch: 025 / 030 | Total loss: 6.115 | Reg loss: 0.036 | Tree loss: 6.115 | Accuracy: 0.589844 | 0.361 sec/iter
Epoch: 14 | Batch: 026 / 030 | Total loss: 6.142 | Reg loss: 0.036 | Tree loss: 6.142 | Accuracy: 0.550781 | 0.361 sec/iter
Epoch: 14 | Batch: 027 / 030 | Total loss: 6.071 | Reg loss: 0.037 | Tree loss: 6.071 | Accuracy: 0.568359 | 0.361 sec/iter
Epoch: 1

Epoch: 16 | Batch: 023 / 030 | Total loss: 5.668 | Reg loss: 0.037 | Tree loss: 5.668 | Accuracy: 0.578125 | 0.361 sec/iter
Epoch: 16 | Batch: 024 / 030 | Total loss: 5.680 | Reg loss: 0.038 | Tree loss: 5.680 | Accuracy: 0.560547 | 0.361 sec/iter
Epoch: 16 | Batch: 025 / 030 | Total loss: 5.627 | Reg loss: 0.038 | Tree loss: 5.627 | Accuracy: 0.562500 | 0.361 sec/iter
Epoch: 16 | Batch: 026 / 030 | Total loss: 5.610 | Reg loss: 0.038 | Tree loss: 5.610 | Accuracy: 0.570312 | 0.36 sec/iter
Epoch: 16 | Batch: 027 / 030 | Total loss: 5.583 | Reg loss: 0.039 | Tree loss: 5.583 | Accuracy: 0.568359 | 0.36 sec/iter
Epoch: 16 | Batch: 028 / 030 | Total loss: 5.580 | Reg loss: 0.039 | Tree loss: 5.580 | Accuracy: 0.544922 | 0.36 sec/iter
Epoch: 16 | Batch: 029 / 030 | Total loss: 5.432 | Reg loss: 0.039 | Tree loss: 5.432 | Accuracy: 0.611111 | 0.36 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.98

Epoch: 18 | Batch: 026 / 030 | Total loss: 5.164 | Reg loss: 0.040 | Tree loss: 5.164 | Accuracy: 0.542969 | 0.361 sec/iter
Epoch: 18 | Batch: 027 / 030 | Total loss: 5.097 | Reg loss: 0.040 | Tree loss: 5.097 | Accuracy: 0.582031 | 0.361 sec/iter
Epoch: 18 | Batch: 028 / 030 | Total loss: 5.013 | Reg loss: 0.041 | Tree loss: 5.013 | Accuracy: 0.601562 | 0.361 sec/iter
Epoch: 18 | Batch: 029 / 030 | Total loss: 5.049 | Reg loss: 0.041 | Tree loss: 5.049 | Accuracy: 0.509259 | 0.361 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 19 | Batch: 000 / 030 | Total loss: 5.853 | Reg loss: 0.034 | Tree loss: 5.853 | Accuracy: 0.589844 | 0.362 sec/iter
Epoch: 19 | Batch: 001 / 030 | Total loss: 5.809 | Reg loss: 0.034 | Tree loss: 5.809 | Accuracy: 0.591797 | 0.362 sec/iter
Epoch: 19 | Batch: 002

Epoch: 20 | Batch: 029 / 030 | Total loss: 4.683 | Reg loss: 0.044 | Tree loss: 4.683 | Accuracy: 0.518519 | 0.36 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 21 | Batch: 000 / 030 | Total loss: 5.437 | Reg loss: 0.036 | Tree loss: 5.437 | Accuracy: 0.587891 | 0.36 sec/iter
Epoch: 21 | Batch: 001 / 030 | Total loss: 5.452 | Reg loss: 0.036 | Tree loss: 5.452 | Accuracy: 0.556641 | 0.36 sec/iter
Epoch: 21 | Batch: 002 / 030 | Total loss: 5.372 | Reg loss: 0.036 | Tree loss: 5.372 | Accuracy: 0.560547 | 0.36 sec/iter
Epoch: 21 | Batch: 003 / 030 | Total loss: 5.335 | Reg loss: 0.036 | Tree loss: 5.335 | Accuracy: 0.570312 | 0.36 sec/iter
Epoch: 21 | Batch: 004 / 030 | Total loss: 5.274 | Reg loss: 0.036 | Tree loss: 5.274 | Accuracy: 0.587891 | 0.36 sec/iter
Epoch: 21 | Batch: 005 / 030

Epoch: 23 | Batch: 000 / 030 | Total loss: 4.918 | Reg loss: 0.039 | Tree loss: 4.918 | Accuracy: 0.580078 | 0.358 sec/iter
Epoch: 23 | Batch: 001 / 030 | Total loss: 4.899 | Reg loss: 0.039 | Tree loss: 4.899 | Accuracy: 0.568359 | 0.358 sec/iter
Epoch: 23 | Batch: 002 / 030 | Total loss: 4.910 | Reg loss: 0.039 | Tree loss: 4.910 | Accuracy: 0.560547 | 0.358 sec/iter
Epoch: 23 | Batch: 003 / 030 | Total loss: 4.789 | Reg loss: 0.039 | Tree loss: 4.789 | Accuracy: 0.587891 | 0.358 sec/iter
Epoch: 23 | Batch: 004 / 030 | Total loss: 4.780 | Reg loss: 0.039 | Tree loss: 4.780 | Accuracy: 0.585938 | 0.358 sec/iter
Epoch: 23 | Batch: 005 / 030 | Total loss: 4.760 | Reg loss: 0.039 | Tree loss: 4.760 | Accuracy: 0.593750 | 0.358 sec/iter
Epoch: 23 | Batch: 006 / 030 | Total loss: 4.729 | Reg loss: 0.039 | Tree loss: 4.729 | Accuracy: 0.542969 | 0.358 sec/iter
Epoch: 23 | Batch: 007 / 030 | Total loss: 4.625 | Reg loss: 0.039 | Tree loss: 4.625 | Accuracy: 0.578125 | 0.358 sec/iter
Epoch: 2

Epoch: 25 | Batch: 003 / 030 | Total loss: 4.357 | Reg loss: 0.040 | Tree loss: 4.357 | Accuracy: 0.558594 | 0.358 sec/iter
Epoch: 25 | Batch: 004 / 030 | Total loss: 4.244 | Reg loss: 0.041 | Tree loss: 4.244 | Accuracy: 0.580078 | 0.358 sec/iter
Epoch: 25 | Batch: 005 / 030 | Total loss: 4.178 | Reg loss: 0.041 | Tree loss: 4.178 | Accuracy: 0.597656 | 0.358 sec/iter
Epoch: 25 | Batch: 006 / 030 | Total loss: 4.171 | Reg loss: 0.041 | Tree loss: 4.171 | Accuracy: 0.574219 | 0.358 sec/iter
Epoch: 25 | Batch: 007 / 030 | Total loss: 4.137 | Reg loss: 0.041 | Tree loss: 4.137 | Accuracy: 0.576172 | 0.358 sec/iter
Epoch: 25 | Batch: 008 / 030 | Total loss: 4.087 | Reg loss: 0.041 | Tree loss: 4.087 | Accuracy: 0.568359 | 0.358 sec/iter
Epoch: 25 | Batch: 009 / 030 | Total loss: 4.063 | Reg loss: 0.041 | Tree loss: 4.063 | Accuracy: 0.554688 | 0.358 sec/iter
Epoch: 25 | Batch: 010 / 030 | Total loss: 3.987 | Reg loss: 0.041 | Tree loss: 3.987 | Accuracy: 0.556641 | 0.358 sec/iter
Epoch: 2

Epoch: 27 | Batch: 006 / 030 | Total loss: 3.629 | Reg loss: 0.041 | Tree loss: 3.629 | Accuracy: 0.601562 | 0.358 sec/iter
Epoch: 27 | Batch: 007 / 030 | Total loss: 3.616 | Reg loss: 0.041 | Tree loss: 3.616 | Accuracy: 0.568359 | 0.358 sec/iter
Epoch: 27 | Batch: 008 / 030 | Total loss: 3.573 | Reg loss: 0.042 | Tree loss: 3.573 | Accuracy: 0.582031 | 0.358 sec/iter
Epoch: 27 | Batch: 009 / 030 | Total loss: 3.518 | Reg loss: 0.042 | Tree loss: 3.518 | Accuracy: 0.574219 | 0.358 sec/iter
Epoch: 27 | Batch: 010 / 030 | Total loss: 3.494 | Reg loss: 0.042 | Tree loss: 3.494 | Accuracy: 0.546875 | 0.358 sec/iter
Epoch: 27 | Batch: 011 / 030 | Total loss: 3.467 | Reg loss: 0.042 | Tree loss: 3.467 | Accuracy: 0.552734 | 0.358 sec/iter
Epoch: 27 | Batch: 012 / 030 | Total loss: 3.419 | Reg loss: 0.042 | Tree loss: 3.419 | Accuracy: 0.570312 | 0.358 sec/iter
Epoch: 27 | Batch: 013 / 030 | Total loss: 3.336 | Reg loss: 0.042 | Tree loss: 3.336 | Accuracy: 0.568359 | 0.358 sec/iter
Epoch: 2

Epoch: 29 | Batch: 009 / 030 | Total loss: 3.068 | Reg loss: 0.042 | Tree loss: 3.068 | Accuracy: 0.562500 | 0.359 sec/iter
Epoch: 29 | Batch: 010 / 030 | Total loss: 3.081 | Reg loss: 0.042 | Tree loss: 3.081 | Accuracy: 0.568359 | 0.359 sec/iter
Epoch: 29 | Batch: 011 / 030 | Total loss: 3.018 | Reg loss: 0.042 | Tree loss: 3.018 | Accuracy: 0.572266 | 0.359 sec/iter
Epoch: 29 | Batch: 012 / 030 | Total loss: 2.999 | Reg loss: 0.042 | Tree loss: 2.999 | Accuracy: 0.582031 | 0.359 sec/iter
Epoch: 29 | Batch: 013 / 030 | Total loss: 2.898 | Reg loss: 0.042 | Tree loss: 2.898 | Accuracy: 0.589844 | 0.359 sec/iter
Epoch: 29 | Batch: 014 / 030 | Total loss: 2.936 | Reg loss: 0.042 | Tree loss: 2.936 | Accuracy: 0.572266 | 0.359 sec/iter
Epoch: 29 | Batch: 015 / 030 | Total loss: 2.918 | Reg loss: 0.043 | Tree loss: 2.918 | Accuracy: 0.566406 | 0.359 sec/iter
Epoch: 29 | Batch: 016 / 030 | Total loss: 2.915 | Reg loss: 0.043 | Tree loss: 2.915 | Accuracy: 0.505859 | 0.359 sec/iter
Epoch: 2

Epoch: 31 | Batch: 012 / 030 | Total loss: 2.653 | Reg loss: 0.042 | Tree loss: 2.653 | Accuracy: 0.554688 | 0.36 sec/iter
Epoch: 31 | Batch: 013 / 030 | Total loss: 2.548 | Reg loss: 0.042 | Tree loss: 2.548 | Accuracy: 0.623047 | 0.36 sec/iter
Epoch: 31 | Batch: 014 / 030 | Total loss: 2.590 | Reg loss: 0.042 | Tree loss: 2.590 | Accuracy: 0.554688 | 0.36 sec/iter
Epoch: 31 | Batch: 015 / 030 | Total loss: 2.540 | Reg loss: 0.042 | Tree loss: 2.540 | Accuracy: 0.558594 | 0.36 sec/iter
Epoch: 31 | Batch: 016 / 030 | Total loss: 2.460 | Reg loss: 0.042 | Tree loss: 2.460 | Accuracy: 0.607422 | 0.36 sec/iter
Epoch: 31 | Batch: 017 / 030 | Total loss: 2.477 | Reg loss: 0.042 | Tree loss: 2.477 | Accuracy: 0.576172 | 0.36 sec/iter
Epoch: 31 | Batch: 018 / 030 | Total loss: 2.444 | Reg loss: 0.043 | Tree loss: 2.444 | Accuracy: 0.585938 | 0.36 sec/iter
Epoch: 31 | Batch: 019 / 030 | Total loss: 2.445 | Reg loss: 0.043 | Tree loss: 2.445 | Accuracy: 0.566406 | 0.36 sec/iter
Epoch: 31 | Batc

Epoch: 33 | Batch: 015 / 030 | Total loss: 2.259 | Reg loss: 0.042 | Tree loss: 2.259 | Accuracy: 0.570312 | 0.36 sec/iter
Epoch: 33 | Batch: 016 / 030 | Total loss: 2.203 | Reg loss: 0.042 | Tree loss: 2.203 | Accuracy: 0.589844 | 0.36 sec/iter
Epoch: 33 | Batch: 017 / 030 | Total loss: 2.221 | Reg loss: 0.042 | Tree loss: 2.221 | Accuracy: 0.568359 | 0.36 sec/iter
Epoch: 33 | Batch: 018 / 030 | Total loss: 2.263 | Reg loss: 0.042 | Tree loss: 2.263 | Accuracy: 0.517578 | 0.36 sec/iter
Epoch: 33 | Batch: 019 / 030 | Total loss: 2.203 | Reg loss: 0.042 | Tree loss: 2.203 | Accuracy: 0.544922 | 0.36 sec/iter
Epoch: 33 | Batch: 020 / 030 | Total loss: 2.181 | Reg loss: 0.042 | Tree loss: 2.181 | Accuracy: 0.550781 | 0.36 sec/iter
Epoch: 33 | Batch: 021 / 030 | Total loss: 2.119 | Reg loss: 0.042 | Tree loss: 2.119 | Accuracy: 0.560547 | 0.36 sec/iter
Epoch: 33 | Batch: 022 / 030 | Total loss: 2.090 | Reg loss: 0.042 | Tree loss: 2.090 | Accuracy: 0.562500 | 0.36 sec/iter
Epoch: 33 | Batc

Epoch: 35 | Batch: 018 / 030 | Total loss: 2.003 | Reg loss: 0.041 | Tree loss: 2.003 | Accuracy: 0.560547 | 0.36 sec/iter
Epoch: 35 | Batch: 019 / 030 | Total loss: 1.979 | Reg loss: 0.041 | Tree loss: 1.979 | Accuracy: 0.556641 | 0.36 sec/iter
Epoch: 35 | Batch: 020 / 030 | Total loss: 1.905 | Reg loss: 0.041 | Tree loss: 1.905 | Accuracy: 0.583984 | 0.36 sec/iter
Epoch: 35 | Batch: 021 / 030 | Total loss: 1.886 | Reg loss: 0.041 | Tree loss: 1.886 | Accuracy: 0.593750 | 0.36 sec/iter
Epoch: 35 | Batch: 022 / 030 | Total loss: 1.943 | Reg loss: 0.041 | Tree loss: 1.943 | Accuracy: 0.527344 | 0.36 sec/iter
Epoch: 35 | Batch: 023 / 030 | Total loss: 1.855 | Reg loss: 0.041 | Tree loss: 1.855 | Accuracy: 0.566406 | 0.36 sec/iter
Epoch: 35 | Batch: 024 / 030 | Total loss: 1.808 | Reg loss: 0.042 | Tree loss: 1.808 | Accuracy: 0.613281 | 0.36 sec/iter
Epoch: 35 | Batch: 025 / 030 | Total loss: 1.823 | Reg loss: 0.042 | Tree loss: 1.823 | Accuracy: 0.589844 | 0.36 sec/iter
Epoch: 35 | Batc

Epoch: 37 | Batch: 021 / 030 | Total loss: 1.726 | Reg loss: 0.040 | Tree loss: 1.726 | Accuracy: 0.589844 | 0.36 sec/iter
Epoch: 37 | Batch: 022 / 030 | Total loss: 1.704 | Reg loss: 0.040 | Tree loss: 1.704 | Accuracy: 0.595703 | 0.36 sec/iter
Epoch: 37 | Batch: 023 / 030 | Total loss: 1.707 | Reg loss: 0.040 | Tree loss: 1.707 | Accuracy: 0.583984 | 0.36 sec/iter
Epoch: 37 | Batch: 024 / 030 | Total loss: 1.673 | Reg loss: 0.040 | Tree loss: 1.673 | Accuracy: 0.582031 | 0.36 sec/iter
Epoch: 37 | Batch: 025 / 030 | Total loss: 1.730 | Reg loss: 0.041 | Tree loss: 1.730 | Accuracy: 0.539062 | 0.36 sec/iter
Epoch: 37 | Batch: 026 / 030 | Total loss: 1.640 | Reg loss: 0.041 | Tree loss: 1.640 | Accuracy: 0.566406 | 0.36 sec/iter
Epoch: 37 | Batch: 027 / 030 | Total loss: 1.616 | Reg loss: 0.041 | Tree loss: 1.616 | Accuracy: 0.599609 | 0.36 sec/iter
Epoch: 37 | Batch: 028 / 030 | Total loss: 1.681 | Reg loss: 0.041 | Tree loss: 1.681 | Accuracy: 0.535156 | 0.36 sec/iter
Epoch: 37 | Batc

Epoch: 39 | Batch: 024 / 030 | Total loss: 1.536 | Reg loss: 0.039 | Tree loss: 1.536 | Accuracy: 0.550781 | 0.361 sec/iter
Epoch: 39 | Batch: 025 / 030 | Total loss: 1.538 | Reg loss: 0.040 | Tree loss: 1.538 | Accuracy: 0.574219 | 0.361 sec/iter
Epoch: 39 | Batch: 026 / 030 | Total loss: 1.511 | Reg loss: 0.040 | Tree loss: 1.511 | Accuracy: 0.591797 | 0.361 sec/iter
Epoch: 39 | Batch: 027 / 030 | Total loss: 1.533 | Reg loss: 0.040 | Tree loss: 1.533 | Accuracy: 0.542969 | 0.361 sec/iter
Epoch: 39 | Batch: 028 / 030 | Total loss: 1.533 | Reg loss: 0.040 | Tree loss: 1.533 | Accuracy: 0.548828 | 0.361 sec/iter
Epoch: 39 | Batch: 029 / 030 | Total loss: 1.422 | Reg loss: 0.040 | Tree loss: 1.422 | Accuracy: 0.629630 | 0.361 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 40 | Batch: 000

Epoch: 41 | Batch: 027 / 030 | Total loss: 1.409 | Reg loss: 0.039 | Tree loss: 1.409 | Accuracy: 0.550781 | 0.361 sec/iter
Epoch: 41 | Batch: 028 / 030 | Total loss: 1.409 | Reg loss: 0.039 | Tree loss: 1.409 | Accuracy: 0.558594 | 0.361 sec/iter
Epoch: 41 | Batch: 029 / 030 | Total loss: 1.370 | Reg loss: 0.039 | Tree loss: 1.370 | Accuracy: 0.564815 | 0.361 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 42 | Batch: 000 / 030 | Total loss: 1.723 | Reg loss: 0.037 | Tree loss: 1.723 | Accuracy: 0.593750 | 0.361 sec/iter
Epoch: 42 | Batch: 001 / 030 | Total loss: 1.749 | Reg loss: 0.037 | Tree loss: 1.749 | Accuracy: 0.539062 | 0.361 sec/iter
Epoch: 42 | Batch: 002 / 030 | Total loss: 1.702 | Reg loss: 0.037 | Tree loss: 1.702 | Accuracy: 0.597656 | 0.361 sec/iter
Epoch: 42 | Batch: 003

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 44 | Batch: 000 / 030 | Total loss: 1.622 | Reg loss: 0.037 | Tree loss: 1.622 | Accuracy: 0.599609 | 0.362 sec/iter
Epoch: 44 | Batch: 001 / 030 | Total loss: 1.616 | Reg loss: 0.037 | Tree loss: 1.616 | Accuracy: 0.583984 | 0.362 sec/iter
Epoch: 44 | Batch: 002 / 030 | Total loss: 1.597 | Reg loss: 0.036 | Tree loss: 1.597 | Accuracy: 0.601562 | 0.362 sec/iter
Epoch: 44 | Batch: 003 / 030 | Total loss: 1.596 | Reg loss: 0.036 | Tree loss: 1.596 | Accuracy: 0.542969 | 0.362 sec/iter
Epoch: 44 | Batch: 004 / 030 | Total loss: 1.578 | Reg loss: 0.036 | Tree loss: 1.578 | Accuracy: 0.533203 | 0.362 sec/iter
Epoch: 44 | Batch: 005 / 030 | Total loss: 1.495 | Reg loss: 0.036 | Tree loss: 1.495 | Accuracy: 0.634766 | 0.362 sec/iter
Epoch: 44 | Batch: 006

Epoch: 46 | Batch: 001 / 030 | Total loss: 1.500 | Reg loss: 0.036 | Tree loss: 1.500 | Accuracy: 0.591797 | 0.362 sec/iter
Epoch: 46 | Batch: 002 / 030 | Total loss: 1.480 | Reg loss: 0.036 | Tree loss: 1.480 | Accuracy: 0.615234 | 0.362 sec/iter
Epoch: 46 | Batch: 003 / 030 | Total loss: 1.492 | Reg loss: 0.036 | Tree loss: 1.492 | Accuracy: 0.542969 | 0.362 sec/iter
Epoch: 46 | Batch: 004 / 030 | Total loss: 1.487 | Reg loss: 0.036 | Tree loss: 1.487 | Accuracy: 0.548828 | 0.362 sec/iter
Epoch: 46 | Batch: 005 / 030 | Total loss: 1.459 | Reg loss: 0.036 | Tree loss: 1.459 | Accuracy: 0.587891 | 0.362 sec/iter
Epoch: 46 | Batch: 006 / 030 | Total loss: 1.443 | Reg loss: 0.036 | Tree loss: 1.443 | Accuracy: 0.566406 | 0.362 sec/iter
Epoch: 46 | Batch: 007 / 030 | Total loss: 1.450 | Reg loss: 0.036 | Tree loss: 1.450 | Accuracy: 0.568359 | 0.362 sec/iter
Epoch: 46 | Batch: 008 / 030 | Total loss: 1.421 | Reg loss: 0.036 | Tree loss: 1.421 | Accuracy: 0.546875 | 0.362 sec/iter
Epoch: 4

Epoch: 48 | Batch: 004 / 030 | Total loss: 1.373 | Reg loss: 0.035 | Tree loss: 1.373 | Accuracy: 0.582031 | 0.363 sec/iter
Epoch: 48 | Batch: 005 / 030 | Total loss: 1.362 | Reg loss: 0.035 | Tree loss: 1.362 | Accuracy: 0.580078 | 0.363 sec/iter
Epoch: 48 | Batch: 006 / 030 | Total loss: 1.392 | Reg loss: 0.035 | Tree loss: 1.392 | Accuracy: 0.546875 | 0.363 sec/iter
Epoch: 48 | Batch: 007 / 030 | Total loss: 1.346 | Reg loss: 0.035 | Tree loss: 1.346 | Accuracy: 0.578125 | 0.363 sec/iter
Epoch: 48 | Batch: 008 / 030 | Total loss: 1.323 | Reg loss: 0.035 | Tree loss: 1.323 | Accuracy: 0.609375 | 0.363 sec/iter
Epoch: 48 | Batch: 009 / 030 | Total loss: 1.319 | Reg loss: 0.035 | Tree loss: 1.319 | Accuracy: 0.603516 | 0.362 sec/iter
Epoch: 48 | Batch: 010 / 030 | Total loss: 1.306 | Reg loss: 0.035 | Tree loss: 1.306 | Accuracy: 0.568359 | 0.362 sec/iter
Epoch: 48 | Batch: 011 / 030 | Total loss: 1.310 | Reg loss: 0.035 | Tree loss: 1.310 | Accuracy: 0.550781 | 0.362 sec/iter
Epoch: 4

Epoch: 50 | Batch: 007 / 030 | Total loss: 1.280 | Reg loss: 0.034 | Tree loss: 1.280 | Accuracy: 0.572266 | 0.361 sec/iter
Epoch: 50 | Batch: 008 / 030 | Total loss: 1.289 | Reg loss: 0.034 | Tree loss: 1.289 | Accuracy: 0.550781 | 0.361 sec/iter
Epoch: 50 | Batch: 009 / 030 | Total loss: 1.242 | Reg loss: 0.034 | Tree loss: 1.242 | Accuracy: 0.595703 | 0.361 sec/iter
Epoch: 50 | Batch: 010 / 030 | Total loss: 1.262 | Reg loss: 0.034 | Tree loss: 1.262 | Accuracy: 0.578125 | 0.361 sec/iter
Epoch: 50 | Batch: 011 / 030 | Total loss: 1.244 | Reg loss: 0.034 | Tree loss: 1.244 | Accuracy: 0.574219 | 0.361 sec/iter
Epoch: 50 | Batch: 012 / 030 | Total loss: 1.222 | Reg loss: 0.034 | Tree loss: 1.222 | Accuracy: 0.582031 | 0.361 sec/iter
Epoch: 50 | Batch: 013 / 030 | Total loss: 1.235 | Reg loss: 0.034 | Tree loss: 1.235 | Accuracy: 0.546875 | 0.361 sec/iter
Epoch: 50 | Batch: 014 / 030 | Total loss: 1.182 | Reg loss: 0.035 | Tree loss: 1.182 | Accuracy: 0.601562 | 0.361 sec/iter
Epoch: 5

Epoch: 52 | Batch: 010 / 030 | Total loss: 1.203 | Reg loss: 0.034 | Tree loss: 1.203 | Accuracy: 0.554688 | 0.361 sec/iter
Epoch: 52 | Batch: 011 / 030 | Total loss: 1.191 | Reg loss: 0.034 | Tree loss: 1.191 | Accuracy: 0.583984 | 0.362 sec/iter
Epoch: 52 | Batch: 012 / 030 | Total loss: 1.193 | Reg loss: 0.034 | Tree loss: 1.193 | Accuracy: 0.576172 | 0.362 sec/iter
Epoch: 52 | Batch: 013 / 030 | Total loss: 1.174 | Reg loss: 0.034 | Tree loss: 1.174 | Accuracy: 0.560547 | 0.362 sec/iter
Epoch: 52 | Batch: 014 / 030 | Total loss: 1.129 | Reg loss: 0.034 | Tree loss: 1.129 | Accuracy: 0.611328 | 0.361 sec/iter
Epoch: 52 | Batch: 015 / 030 | Total loss: 1.162 | Reg loss: 0.034 | Tree loss: 1.162 | Accuracy: 0.550781 | 0.361 sec/iter
Epoch: 52 | Batch: 016 / 030 | Total loss: 1.137 | Reg loss: 0.034 | Tree loss: 1.137 | Accuracy: 0.599609 | 0.361 sec/iter
Epoch: 52 | Batch: 017 / 030 | Total loss: 1.127 | Reg loss: 0.034 | Tree loss: 1.127 | Accuracy: 0.578125 | 0.361 sec/iter
Epoch: 5

Epoch: 54 | Batch: 013 / 030 | Total loss: 1.122 | Reg loss: 0.033 | Tree loss: 1.122 | Accuracy: 0.595703 | 0.36 sec/iter
Epoch: 54 | Batch: 014 / 030 | Total loss: 1.104 | Reg loss: 0.033 | Tree loss: 1.104 | Accuracy: 0.615234 | 0.36 sec/iter
Epoch: 54 | Batch: 015 / 030 | Total loss: 1.109 | Reg loss: 0.033 | Tree loss: 1.109 | Accuracy: 0.583984 | 0.36 sec/iter
Epoch: 54 | Batch: 016 / 030 | Total loss: 1.095 | Reg loss: 0.033 | Tree loss: 1.095 | Accuracy: 0.593750 | 0.36 sec/iter
Epoch: 54 | Batch: 017 / 030 | Total loss: 1.084 | Reg loss: 0.033 | Tree loss: 1.084 | Accuracy: 0.587891 | 0.36 sec/iter
Epoch: 54 | Batch: 018 / 030 | Total loss: 1.104 | Reg loss: 0.033 | Tree loss: 1.104 | Accuracy: 0.537109 | 0.36 sec/iter
Epoch: 54 | Batch: 019 / 030 | Total loss: 1.078 | Reg loss: 0.033 | Tree loss: 1.078 | Accuracy: 0.576172 | 0.36 sec/iter
Epoch: 54 | Batch: 020 / 030 | Total loss: 1.081 | Reg loss: 0.033 | Tree loss: 1.081 | Accuracy: 0.574219 | 0.36 sec/iter
Epoch: 54 | Batc

Epoch: 56 | Batch: 016 / 030 | Total loss: 1.080 | Reg loss: 0.033 | Tree loss: 1.080 | Accuracy: 0.558594 | 0.359 sec/iter
Epoch: 56 | Batch: 017 / 030 | Total loss: 1.080 | Reg loss: 0.033 | Tree loss: 1.080 | Accuracy: 0.546875 | 0.359 sec/iter
Epoch: 56 | Batch: 018 / 030 | Total loss: 1.053 | Reg loss: 0.033 | Tree loss: 1.053 | Accuracy: 0.582031 | 0.359 sec/iter
Epoch: 56 | Batch: 019 / 030 | Total loss: 1.066 | Reg loss: 0.033 | Tree loss: 1.066 | Accuracy: 0.560547 | 0.359 sec/iter
Epoch: 56 | Batch: 020 / 030 | Total loss: 1.039 | Reg loss: 0.033 | Tree loss: 1.039 | Accuracy: 0.595703 | 0.358 sec/iter
Epoch: 56 | Batch: 021 / 030 | Total loss: 1.038 | Reg loss: 0.033 | Tree loss: 1.038 | Accuracy: 0.609375 | 0.358 sec/iter
Epoch: 56 | Batch: 022 / 030 | Total loss: 1.044 | Reg loss: 0.033 | Tree loss: 1.044 | Accuracy: 0.546875 | 0.358 sec/iter
Epoch: 56 | Batch: 023 / 030 | Total loss: 1.032 | Reg loss: 0.033 | Tree loss: 1.032 | Accuracy: 0.566406 | 0.358 sec/iter
Epoch: 5

Epoch: 58 | Batch: 019 / 030 | Total loss: 1.014 | Reg loss: 0.032 | Tree loss: 1.014 | Accuracy: 0.589844 | 0.357 sec/iter
Epoch: 58 | Batch: 020 / 030 | Total loss: 1.054 | Reg loss: 0.032 | Tree loss: 1.054 | Accuracy: 0.541016 | 0.357 sec/iter
Epoch: 58 | Batch: 021 / 030 | Total loss: 1.032 | Reg loss: 0.032 | Tree loss: 1.032 | Accuracy: 0.550781 | 0.357 sec/iter
Epoch: 58 | Batch: 022 / 030 | Total loss: 1.014 | Reg loss: 0.032 | Tree loss: 1.014 | Accuracy: 0.574219 | 0.357 sec/iter
Epoch: 58 | Batch: 023 / 030 | Total loss: 1.012 | Reg loss: 0.032 | Tree loss: 1.012 | Accuracy: 0.568359 | 0.357 sec/iter
Epoch: 58 | Batch: 024 / 030 | Total loss: 0.998 | Reg loss: 0.032 | Tree loss: 0.998 | Accuracy: 0.585938 | 0.357 sec/iter
Epoch: 58 | Batch: 025 / 030 | Total loss: 1.006 | Reg loss: 0.032 | Tree loss: 1.006 | Accuracy: 0.576172 | 0.357 sec/iter
Epoch: 58 | Batch: 026 / 030 | Total loss: 1.004 | Reg loss: 0.033 | Tree loss: 1.004 | Accuracy: 0.554688 | 0.357 sec/iter
Epoch: 5

Epoch: 60 | Batch: 022 / 030 | Total loss: 0.977 | Reg loss: 0.032 | Tree loss: 0.977 | Accuracy: 0.621094 | 0.358 sec/iter
Epoch: 60 | Batch: 023 / 030 | Total loss: 1.001 | Reg loss: 0.032 | Tree loss: 1.001 | Accuracy: 0.562500 | 0.358 sec/iter
Epoch: 60 | Batch: 024 / 030 | Total loss: 1.003 | Reg loss: 0.032 | Tree loss: 1.003 | Accuracy: 0.542969 | 0.358 sec/iter
Epoch: 60 | Batch: 025 / 030 | Total loss: 0.977 | Reg loss: 0.032 | Tree loss: 0.977 | Accuracy: 0.578125 | 0.358 sec/iter
Epoch: 60 | Batch: 026 / 030 | Total loss: 0.944 | Reg loss: 0.032 | Tree loss: 0.944 | Accuracy: 0.630859 | 0.358 sec/iter
Epoch: 60 | Batch: 027 / 030 | Total loss: 0.959 | Reg loss: 0.032 | Tree loss: 0.959 | Accuracy: 0.597656 | 0.358 sec/iter
Epoch: 60 | Batch: 028 / 030 | Total loss: 0.966 | Reg loss: 0.032 | Tree loss: 0.966 | Accuracy: 0.582031 | 0.358 sec/iter
Epoch: 60 | Batch: 029 / 030 | Total loss: 0.938 | Reg loss: 0.032 | Tree loss: 0.938 | Accuracy: 0.638889 | 0.358 sec/iter
Average 

Epoch: 62 | Batch: 025 / 030 | Total loss: 0.964 | Reg loss: 0.032 | Tree loss: 0.964 | Accuracy: 0.572266 | 0.358 sec/iter
Epoch: 62 | Batch: 026 / 030 | Total loss: 0.947 | Reg loss: 0.032 | Tree loss: 0.947 | Accuracy: 0.589844 | 0.358 sec/iter
Epoch: 62 | Batch: 027 / 030 | Total loss: 0.946 | Reg loss: 0.032 | Tree loss: 0.946 | Accuracy: 0.597656 | 0.358 sec/iter
Epoch: 62 | Batch: 028 / 030 | Total loss: 0.967 | Reg loss: 0.032 | Tree loss: 0.967 | Accuracy: 0.548828 | 0.358 sec/iter
Epoch: 62 | Batch: 029 / 030 | Total loss: 0.959 | Reg loss: 0.032 | Tree loss: 0.959 | Accuracy: 0.583333 | 0.358 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 63 | Batch: 000 / 030 | Total loss: 1.136 | Reg loss: 0.031 | Tree loss: 1.136 | Accuracy: 0.595703 | 0.358 sec/iter
Epoch: 63 | Batch: 001

Epoch: 64 | Batch: 028 / 030 | Total loss: 0.961 | Reg loss: 0.031 | Tree loss: 0.961 | Accuracy: 0.546875 | 0.359 sec/iter
Epoch: 64 | Batch: 029 / 030 | Total loss: 0.903 | Reg loss: 0.031 | Tree loss: 0.903 | Accuracy: 0.648148 | 0.359 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 65 | Batch: 000 / 030 | Total loss: 1.142 | Reg loss: 0.030 | Tree loss: 1.142 | Accuracy: 0.566406 | 0.359 sec/iter
Epoch: 65 | Batch: 001 / 030 | Total loss: 1.120 | Reg loss: 0.030 | Tree loss: 1.120 | Accuracy: 0.597656 | 0.359 sec/iter
Epoch: 65 | Batch: 002 / 030 | Total loss: 1.109 | Reg loss: 0.030 | Tree loss: 1.109 | Accuracy: 0.570312 | 0.359 sec/iter
Epoch: 65 | Batch: 003 / 030 | Total loss: 1.120 | Reg loss: 0.030 | Tree loss: 1.120 | Accuracy: 0.576172 | 0.359 sec/iter
Epoch: 65 | Batch: 004

Epoch: 67 | Batch: 000 / 030 | Total loss: 1.121 | Reg loss: 0.030 | Tree loss: 1.121 | Accuracy: 0.613281 | 0.359 sec/iter
Epoch: 67 | Batch: 001 / 030 | Total loss: 1.113 | Reg loss: 0.030 | Tree loss: 1.113 | Accuracy: 0.568359 | 0.359 sec/iter
Epoch: 67 | Batch: 002 / 030 | Total loss: 1.120 | Reg loss: 0.030 | Tree loss: 1.120 | Accuracy: 0.570312 | 0.359 sec/iter
Epoch: 67 | Batch: 003 / 030 | Total loss: 1.088 | Reg loss: 0.030 | Tree loss: 1.088 | Accuracy: 0.578125 | 0.359 sec/iter
Epoch: 67 | Batch: 004 / 030 | Total loss: 1.080 | Reg loss: 0.030 | Tree loss: 1.080 | Accuracy: 0.566406 | 0.359 sec/iter
Epoch: 67 | Batch: 005 / 030 | Total loss: 1.052 | Reg loss: 0.030 | Tree loss: 1.052 | Accuracy: 0.591797 | 0.359 sec/iter
Epoch: 67 | Batch: 006 / 030 | Total loss: 1.058 | Reg loss: 0.030 | Tree loss: 1.058 | Accuracy: 0.572266 | 0.359 sec/iter
Epoch: 67 | Batch: 007 / 030 | Total loss: 1.053 | Reg loss: 0.030 | Tree loss: 1.053 | Accuracy: 0.576172 | 0.359 sec/iter
Epoch: 6

Epoch: 69 | Batch: 003 / 030 | Total loss: 1.096 | Reg loss: 0.030 | Tree loss: 1.096 | Accuracy: 0.564453 | 0.359 sec/iter
Epoch: 69 | Batch: 004 / 030 | Total loss: 1.100 | Reg loss: 0.030 | Tree loss: 1.100 | Accuracy: 0.529297 | 0.359 sec/iter
Epoch: 69 | Batch: 005 / 030 | Total loss: 1.046 | Reg loss: 0.030 | Tree loss: 1.046 | Accuracy: 0.587891 | 0.359 sec/iter
Epoch: 69 | Batch: 006 / 030 | Total loss: 1.046 | Reg loss: 0.030 | Tree loss: 1.046 | Accuracy: 0.587891 | 0.359 sec/iter
Epoch: 69 | Batch: 007 / 030 | Total loss: 1.047 | Reg loss: 0.030 | Tree loss: 1.047 | Accuracy: 0.560547 | 0.359 sec/iter
Epoch: 69 | Batch: 008 / 030 | Total loss: 1.029 | Reg loss: 0.030 | Tree loss: 1.029 | Accuracy: 0.552734 | 0.359 sec/iter
Epoch: 69 | Batch: 009 / 030 | Total loss: 1.030 | Reg loss: 0.030 | Tree loss: 1.030 | Accuracy: 0.572266 | 0.359 sec/iter
Epoch: 69 | Batch: 010 / 030 | Total loss: 1.029 | Reg loss: 0.030 | Tree loss: 1.029 | Accuracy: 0.574219 | 0.359 sec/iter
Epoch: 6

Epoch: 71 | Batch: 006 / 030 | Total loss: 1.046 | Reg loss: 0.029 | Tree loss: 1.046 | Accuracy: 0.562500 | 0.359 sec/iter
Epoch: 71 | Batch: 007 / 030 | Total loss: 1.024 | Reg loss: 0.029 | Tree loss: 1.024 | Accuracy: 0.566406 | 0.359 sec/iter
Epoch: 71 | Batch: 008 / 030 | Total loss: 1.024 | Reg loss: 0.030 | Tree loss: 1.024 | Accuracy: 0.595703 | 0.359 sec/iter
Epoch: 71 | Batch: 009 / 030 | Total loss: 1.014 | Reg loss: 0.030 | Tree loss: 1.014 | Accuracy: 0.552734 | 0.359 sec/iter
Epoch: 71 | Batch: 010 / 030 | Total loss: 1.026 | Reg loss: 0.030 | Tree loss: 1.026 | Accuracy: 0.541016 | 0.359 sec/iter
Epoch: 71 | Batch: 011 / 030 | Total loss: 1.005 | Reg loss: 0.030 | Tree loss: 1.005 | Accuracy: 0.568359 | 0.359 sec/iter
Epoch: 71 | Batch: 012 / 030 | Total loss: 0.975 | Reg loss: 0.030 | Tree loss: 0.975 | Accuracy: 0.607422 | 0.359 sec/iter
Epoch: 71 | Batch: 013 / 030 | Total loss: 0.988 | Reg loss: 0.030 | Tree loss: 0.988 | Accuracy: 0.548828 | 0.359 sec/iter
Epoch: 7

Epoch: 73 | Batch: 009 / 030 | Total loss: 0.996 | Reg loss: 0.029 | Tree loss: 0.996 | Accuracy: 0.597656 | 0.359 sec/iter
Epoch: 73 | Batch: 010 / 030 | Total loss: 0.999 | Reg loss: 0.029 | Tree loss: 0.999 | Accuracy: 0.554688 | 0.359 sec/iter
Epoch: 73 | Batch: 011 / 030 | Total loss: 0.997 | Reg loss: 0.029 | Tree loss: 0.997 | Accuracy: 0.558594 | 0.359 sec/iter
Epoch: 73 | Batch: 012 / 030 | Total loss: 0.972 | Reg loss: 0.029 | Tree loss: 0.972 | Accuracy: 0.595703 | 0.359 sec/iter
Epoch: 73 | Batch: 013 / 030 | Total loss: 0.975 | Reg loss: 0.029 | Tree loss: 0.975 | Accuracy: 0.558594 | 0.359 sec/iter
Epoch: 73 | Batch: 014 / 030 | Total loss: 0.958 | Reg loss: 0.029 | Tree loss: 0.958 | Accuracy: 0.564453 | 0.359 sec/iter
Epoch: 73 | Batch: 015 / 030 | Total loss: 0.961 | Reg loss: 0.029 | Tree loss: 0.961 | Accuracy: 0.580078 | 0.359 sec/iter
Epoch: 73 | Batch: 016 / 030 | Total loss: 0.934 | Reg loss: 0.029 | Tree loss: 0.934 | Accuracy: 0.585938 | 0.359 sec/iter
Epoch: 7

Epoch: 75 | Batch: 012 / 030 | Total loss: 0.974 | Reg loss: 0.029 | Tree loss: 0.974 | Accuracy: 0.542969 | 0.359 sec/iter
Epoch: 75 | Batch: 013 / 030 | Total loss: 0.977 | Reg loss: 0.029 | Tree loss: 0.977 | Accuracy: 0.560547 | 0.359 sec/iter
Epoch: 75 | Batch: 014 / 030 | Total loss: 0.954 | Reg loss: 0.029 | Tree loss: 0.954 | Accuracy: 0.576172 | 0.359 sec/iter
Epoch: 75 | Batch: 015 / 030 | Total loss: 0.936 | Reg loss: 0.029 | Tree loss: 0.936 | Accuracy: 0.589844 | 0.359 sec/iter
Epoch: 75 | Batch: 016 / 030 | Total loss: 0.947 | Reg loss: 0.029 | Tree loss: 0.947 | Accuracy: 0.548828 | 0.359 sec/iter
Epoch: 75 | Batch: 017 / 030 | Total loss: 0.936 | Reg loss: 0.029 | Tree loss: 0.936 | Accuracy: 0.578125 | 0.359 sec/iter
Epoch: 75 | Batch: 018 / 030 | Total loss: 0.891 | Reg loss: 0.029 | Tree loss: 0.891 | Accuracy: 0.648438 | 0.359 sec/iter
Epoch: 75 | Batch: 019 / 030 | Total loss: 0.925 | Reg loss: 0.029 | Tree loss: 0.925 | Accuracy: 0.578125 | 0.359 sec/iter
Epoch: 7

Epoch: 77 | Batch: 015 / 030 | Total loss: 0.928 | Reg loss: 0.029 | Tree loss: 0.928 | Accuracy: 0.617188 | 0.36 sec/iter
Epoch: 77 | Batch: 016 / 030 | Total loss: 0.933 | Reg loss: 0.029 | Tree loss: 0.933 | Accuracy: 0.578125 | 0.36 sec/iter
Epoch: 77 | Batch: 017 / 030 | Total loss: 0.906 | Reg loss: 0.029 | Tree loss: 0.906 | Accuracy: 0.609375 | 0.36 sec/iter
Epoch: 77 | Batch: 018 / 030 | Total loss: 0.915 | Reg loss: 0.029 | Tree loss: 0.915 | Accuracy: 0.576172 | 0.36 sec/iter
Epoch: 77 | Batch: 019 / 030 | Total loss: 0.929 | Reg loss: 0.029 | Tree loss: 0.929 | Accuracy: 0.544922 | 0.36 sec/iter
Epoch: 77 | Batch: 020 / 030 | Total loss: 0.896 | Reg loss: 0.029 | Tree loss: 0.896 | Accuracy: 0.591797 | 0.36 sec/iter
Epoch: 77 | Batch: 021 / 030 | Total loss: 0.894 | Reg loss: 0.029 | Tree loss: 0.894 | Accuracy: 0.611328 | 0.36 sec/iter
Epoch: 77 | Batch: 022 / 030 | Total loss: 0.912 | Reg loss: 0.029 | Tree loss: 0.912 | Accuracy: 0.548828 | 0.36 sec/iter
Epoch: 77 | Batc

Epoch: 79 | Batch: 018 / 030 | Total loss: 0.913 | Reg loss: 0.029 | Tree loss: 0.913 | Accuracy: 0.589844 | 0.36 sec/iter
Epoch: 79 | Batch: 019 / 030 | Total loss: 0.899 | Reg loss: 0.029 | Tree loss: 0.899 | Accuracy: 0.580078 | 0.36 sec/iter
Epoch: 79 | Batch: 020 / 030 | Total loss: 0.904 | Reg loss: 0.029 | Tree loss: 0.904 | Accuracy: 0.576172 | 0.36 sec/iter
Epoch: 79 | Batch: 021 / 030 | Total loss: 0.916 | Reg loss: 0.029 | Tree loss: 0.916 | Accuracy: 0.527344 | 0.36 sec/iter
Epoch: 79 | Batch: 022 / 030 | Total loss: 0.888 | Reg loss: 0.029 | Tree loss: 0.888 | Accuracy: 0.568359 | 0.36 sec/iter
Epoch: 79 | Batch: 023 / 030 | Total loss: 0.879 | Reg loss: 0.029 | Tree loss: 0.879 | Accuracy: 0.587891 | 0.36 sec/iter
Epoch: 79 | Batch: 024 / 030 | Total loss: 0.880 | Reg loss: 0.029 | Tree loss: 0.880 | Accuracy: 0.580078 | 0.36 sec/iter
Epoch: 79 | Batch: 025 / 030 | Total loss: 0.877 | Reg loss: 0.029 | Tree loss: 0.877 | Accuracy: 0.599609 | 0.36 sec/iter
Epoch: 79 | Batc

Epoch: 81 | Batch: 021 / 030 | Total loss: 0.886 | Reg loss: 0.029 | Tree loss: 0.886 | Accuracy: 0.574219 | 0.36 sec/iter
Epoch: 81 | Batch: 022 / 030 | Total loss: 0.893 | Reg loss: 0.029 | Tree loss: 0.893 | Accuracy: 0.572266 | 0.36 sec/iter
Epoch: 81 | Batch: 023 / 030 | Total loss: 0.878 | Reg loss: 0.029 | Tree loss: 0.878 | Accuracy: 0.580078 | 0.36 sec/iter
Epoch: 81 | Batch: 024 / 030 | Total loss: 0.878 | Reg loss: 0.029 | Tree loss: 0.878 | Accuracy: 0.578125 | 0.36 sec/iter
Epoch: 81 | Batch: 025 / 030 | Total loss: 0.857 | Reg loss: 0.029 | Tree loss: 0.857 | Accuracy: 0.599609 | 0.36 sec/iter
Epoch: 81 | Batch: 026 / 030 | Total loss: 0.867 | Reg loss: 0.029 | Tree loss: 0.867 | Accuracy: 0.582031 | 0.36 sec/iter
Epoch: 81 | Batch: 027 / 030 | Total loss: 0.868 | Reg loss: 0.029 | Tree loss: 0.868 | Accuracy: 0.583984 | 0.36 sec/iter
Epoch: 81 | Batch: 028 / 030 | Total loss: 0.858 | Reg loss: 0.029 | Tree loss: 0.858 | Accuracy: 0.591797 | 0.36 sec/iter
Epoch: 81 | Batc

Epoch: 83 | Batch: 024 / 030 | Total loss: 0.866 | Reg loss: 0.029 | Tree loss: 0.866 | Accuracy: 0.597656 | 0.36 sec/iter
Epoch: 83 | Batch: 025 / 030 | Total loss: 0.872 | Reg loss: 0.029 | Tree loss: 0.872 | Accuracy: 0.562500 | 0.36 sec/iter
Epoch: 83 | Batch: 026 / 030 | Total loss: 0.875 | Reg loss: 0.029 | Tree loss: 0.875 | Accuracy: 0.556641 | 0.36 sec/iter
Epoch: 83 | Batch: 027 / 030 | Total loss: 0.881 | Reg loss: 0.029 | Tree loss: 0.881 | Accuracy: 0.535156 | 0.36 sec/iter
Epoch: 83 | Batch: 028 / 030 | Total loss: 0.847 | Reg loss: 0.029 | Tree loss: 0.847 | Accuracy: 0.591797 | 0.36 sec/iter
Epoch: 83 | Batch: 029 / 030 | Total loss: 0.820 | Reg loss: 0.029 | Tree loss: 0.820 | Accuracy: 0.648148 | 0.36 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 84 | Batch: 000 / 030

Epoch: 85 | Batch: 027 / 030 | Total loss: 0.844 | Reg loss: 0.029 | Tree loss: 0.844 | Accuracy: 0.607422 | 0.36 sec/iter
Epoch: 85 | Batch: 028 / 030 | Total loss: 0.853 | Reg loss: 0.029 | Tree loss: 0.853 | Accuracy: 0.582031 | 0.36 sec/iter
Epoch: 85 | Batch: 029 / 030 | Total loss: 0.833 | Reg loss: 0.029 | Tree loss: 0.833 | Accuracy: 0.629630 | 0.36 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 86 | Batch: 000 / 030 | Total loss: 1.046 | Reg loss: 0.028 | Tree loss: 1.046 | Accuracy: 0.570312 | 0.36 sec/iter
Epoch: 86 | Batch: 001 / 030 | Total loss: 1.037 | Reg loss: 0.028 | Tree loss: 1.037 | Accuracy: 0.603516 | 0.36 sec/iter
Epoch: 86 | Batch: 002 / 030 | Total loss: 1.030 | Reg loss: 0.028 | Tree loss: 1.030 | Accuracy: 0.576172 | 0.36 sec/iter
Epoch: 86 | Batch: 003 / 030

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 88 | Batch: 000 / 030 | Total loss: 1.041 | Reg loss: 0.028 | Tree loss: 1.041 | Accuracy: 0.537109 | 0.36 sec/iter
Epoch: 88 | Batch: 001 / 030 | Total loss: 1.040 | Reg loss: 0.028 | Tree loss: 1.040 | Accuracy: 0.585938 | 0.36 sec/iter
Epoch: 88 | Batch: 002 / 030 | Total loss: 1.023 | Reg loss: 0.028 | Tree loss: 1.023 | Accuracy: 0.585938 | 0.36 sec/iter
Epoch: 88 | Batch: 003 / 030 | Total loss: 1.007 | Reg loss: 0.028 | Tree loss: 1.007 | Accuracy: 0.609375 | 0.36 sec/iter
Epoch: 88 | Batch: 004 / 030 | Total loss: 1.027 | Reg loss: 0.028 | Tree loss: 1.027 | Accuracy: 0.556641 | 0.36 sec/iter
Epoch: 88 | Batch: 005 / 030 | Total loss: 0.997 | Reg loss: 0.028 | Tree loss: 0.997 | Accuracy: 0.591797 | 0.36 sec/iter
Epoch: 88 | Batch: 006 / 030

Epoch: 90 | Batch: 001 / 030 | Total loss: 1.019 | Reg loss: 0.028 | Tree loss: 1.019 | Accuracy: 0.585938 | 0.36 sec/iter
Epoch: 90 | Batch: 002 / 030 | Total loss: 0.999 | Reg loss: 0.028 | Tree loss: 0.999 | Accuracy: 0.609375 | 0.36 sec/iter
Epoch: 90 | Batch: 003 / 030 | Total loss: 1.019 | Reg loss: 0.028 | Tree loss: 1.019 | Accuracy: 0.566406 | 0.36 sec/iter
Epoch: 90 | Batch: 004 / 030 | Total loss: 0.997 | Reg loss: 0.028 | Tree loss: 0.997 | Accuracy: 0.580078 | 0.36 sec/iter
Epoch: 90 | Batch: 005 / 030 | Total loss: 0.997 | Reg loss: 0.028 | Tree loss: 0.997 | Accuracy: 0.564453 | 0.36 sec/iter
Epoch: 90 | Batch: 006 / 030 | Total loss: 0.991 | Reg loss: 0.028 | Tree loss: 0.991 | Accuracy: 0.556641 | 0.36 sec/iter
Epoch: 90 | Batch: 007 / 030 | Total loss: 0.982 | Reg loss: 0.028 | Tree loss: 0.982 | Accuracy: 0.572266 | 0.36 sec/iter
Epoch: 90 | Batch: 008 / 030 | Total loss: 0.947 | Reg loss: 0.028 | Tree loss: 0.947 | Accuracy: 0.601562 | 0.36 sec/iter
Epoch: 90 | Batc

Epoch: 92 | Batch: 004 / 030 | Total loss: 1.004 | Reg loss: 0.027 | Tree loss: 1.004 | Accuracy: 0.582031 | 0.361 sec/iter
Epoch: 92 | Batch: 005 / 030 | Total loss: 0.987 | Reg loss: 0.027 | Tree loss: 0.987 | Accuracy: 0.570312 | 0.361 sec/iter
Epoch: 92 | Batch: 006 / 030 | Total loss: 0.972 | Reg loss: 0.028 | Tree loss: 0.972 | Accuracy: 0.597656 | 0.361 sec/iter
Epoch: 92 | Batch: 007 / 030 | Total loss: 0.968 | Reg loss: 0.028 | Tree loss: 0.968 | Accuracy: 0.583984 | 0.361 sec/iter
Epoch: 92 | Batch: 008 / 030 | Total loss: 0.941 | Reg loss: 0.028 | Tree loss: 0.941 | Accuracy: 0.613281 | 0.361 sec/iter
Epoch: 92 | Batch: 009 / 030 | Total loss: 0.936 | Reg loss: 0.028 | Tree loss: 0.936 | Accuracy: 0.576172 | 0.361 sec/iter
Epoch: 92 | Batch: 010 / 030 | Total loss: 0.922 | Reg loss: 0.028 | Tree loss: 0.922 | Accuracy: 0.593750 | 0.361 sec/iter
Epoch: 92 | Batch: 011 / 030 | Total loss: 0.926 | Reg loss: 0.028 | Tree loss: 0.926 | Accuracy: 0.595703 | 0.361 sec/iter
Epoch: 9

Epoch: 94 | Batch: 007 / 030 | Total loss: 0.966 | Reg loss: 0.027 | Tree loss: 0.966 | Accuracy: 0.564453 | 0.361 sec/iter
Epoch: 94 | Batch: 008 / 030 | Total loss: 0.953 | Reg loss: 0.027 | Tree loss: 0.953 | Accuracy: 0.587891 | 0.361 sec/iter
Epoch: 94 | Batch: 009 / 030 | Total loss: 0.932 | Reg loss: 0.027 | Tree loss: 0.932 | Accuracy: 0.607422 | 0.361 sec/iter
Epoch: 94 | Batch: 010 / 030 | Total loss: 0.930 | Reg loss: 0.028 | Tree loss: 0.930 | Accuracy: 0.580078 | 0.361 sec/iter
Epoch: 94 | Batch: 011 / 030 | Total loss: 0.909 | Reg loss: 0.028 | Tree loss: 0.909 | Accuracy: 0.619141 | 0.361 sec/iter
Epoch: 94 | Batch: 012 / 030 | Total loss: 0.914 | Reg loss: 0.028 | Tree loss: 0.914 | Accuracy: 0.570312 | 0.361 sec/iter
Epoch: 94 | Batch: 013 / 030 | Total loss: 0.904 | Reg loss: 0.028 | Tree loss: 0.904 | Accuracy: 0.570312 | 0.361 sec/iter
Epoch: 94 | Batch: 014 / 030 | Total loss: 0.899 | Reg loss: 0.028 | Tree loss: 0.899 | Accuracy: 0.564453 | 0.361 sec/iter
Epoch: 9

Epoch: 96 | Batch: 010 / 030 | Total loss: 0.919 | Reg loss: 0.027 | Tree loss: 0.919 | Accuracy: 0.595703 | 0.361 sec/iter
Epoch: 96 | Batch: 011 / 030 | Total loss: 0.922 | Reg loss: 0.027 | Tree loss: 0.922 | Accuracy: 0.564453 | 0.361 sec/iter
Epoch: 96 | Batch: 012 / 030 | Total loss: 0.916 | Reg loss: 0.027 | Tree loss: 0.916 | Accuracy: 0.560547 | 0.361 sec/iter
Epoch: 96 | Batch: 013 / 030 | Total loss: 0.900 | Reg loss: 0.028 | Tree loss: 0.900 | Accuracy: 0.582031 | 0.361 sec/iter
Epoch: 96 | Batch: 014 / 030 | Total loss: 0.897 | Reg loss: 0.028 | Tree loss: 0.897 | Accuracy: 0.574219 | 0.361 sec/iter
Epoch: 96 | Batch: 015 / 030 | Total loss: 0.888 | Reg loss: 0.028 | Tree loss: 0.888 | Accuracy: 0.583984 | 0.361 sec/iter
Epoch: 96 | Batch: 016 / 030 | Total loss: 0.910 | Reg loss: 0.028 | Tree loss: 0.910 | Accuracy: 0.517578 | 0.361 sec/iter
Epoch: 96 | Batch: 017 / 030 | Total loss: 0.855 | Reg loss: 0.028 | Tree loss: 0.855 | Accuracy: 0.636719 | 0.361 sec/iter
Epoch: 9

Epoch: 98 | Batch: 013 / 030 | Total loss: 0.917 | Reg loss: 0.027 | Tree loss: 0.917 | Accuracy: 0.541016 | 0.361 sec/iter
Epoch: 98 | Batch: 014 / 030 | Total loss: 0.913 | Reg loss: 0.027 | Tree loss: 0.913 | Accuracy: 0.546875 | 0.361 sec/iter
Epoch: 98 | Batch: 015 / 030 | Total loss: 0.884 | Reg loss: 0.028 | Tree loss: 0.884 | Accuracy: 0.583984 | 0.361 sec/iter
Epoch: 98 | Batch: 016 / 030 | Total loss: 0.874 | Reg loss: 0.028 | Tree loss: 0.874 | Accuracy: 0.587891 | 0.361 sec/iter
Epoch: 98 | Batch: 017 / 030 | Total loss: 0.890 | Reg loss: 0.028 | Tree loss: 0.890 | Accuracy: 0.544922 | 0.361 sec/iter
Epoch: 98 | Batch: 018 / 030 | Total loss: 0.880 | Reg loss: 0.028 | Tree loss: 0.880 | Accuracy: 0.546875 | 0.361 sec/iter
Epoch: 98 | Batch: 019 / 030 | Total loss: 0.849 | Reg loss: 0.028 | Tree loss: 0.849 | Accuracy: 0.585938 | 0.361 sec/iter
Epoch: 98 | Batch: 020 / 030 | Total loss: 0.867 | Reg loss: 0.028 | Tree loss: 0.867 | Accuracy: 0.562500 | 0.361 sec/iter
Epoch: 9

In [33]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [34]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [35]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 7.083333333333333


# Extract Rules

# Accumulate samples in the leaves

In [36]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 96


In [37]:
method = 'greedy'

In [38]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [40]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

11646
1668
1642
Average comprehensibility: 35.4375
std comprehensibility: 4.887936894368966
var comprehensibility: 23.891927083333332
minimum comprehensibility: 20
maximum comprehensibility: 42
