In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 128
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.153863906860352 | KNN Loss: 6.230323791503906 | BCE Loss: 1.9235402345657349
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.229302406311035 | KNN Loss: 6.230156898498535 | BCE Loss: 1.999145746231079
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.185433387756348 | KNN Loss: 6.230011940002441 | BCE Loss: 1.9554216861724854
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.193683624267578 | KNN Loss: 6.230076313018799 | BCE Loss: 1.9636069536209106
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.184626579284668 | KNN Loss: 6.2299370765686035 | BCE Loss: 1.9546892642974854
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.163071632385254 | KNN Loss: 6.229653358459473 | BCE Loss: 1.9334180355072021
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.137682914733887 | KNN Loss: 6.229648113250732 | BCE Loss: 1.9080350399017334
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.162541389465332 | KNN Loss: 6.229461669921875 | BCE Loss: 1.93307

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.0865478515625 | KNN Loss: 5.937524795532227 | BCE Loss: 1.1490232944488525
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.04889440536499 | KNN Loss: 5.923275470733643 | BCE Loss: 1.1256190538406372
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.000793933868408 | KNN Loss: 5.882357597351074 | BCE Loss: 1.118436336517334
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.964555740356445 | KNN Loss: 5.85343074798584 | BCE Loss: 1.1111252307891846
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.921308517456055 | KNN Loss: 5.802789211273193 | BCE Loss: 1.1185193061828613
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.896669864654541 | KNN Loss: 5.802331447601318 | BCE Loss: 1.0943384170532227
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.820957183837891 | KNN Loss: 5.726874828338623 | BCE Loss: 1.0940823554992676
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.794429779052734 | KNN Loss: 5.678097248077393 | BCE Loss: 1.1

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.182769298553467 | KNN Loss: 5.102409362792969 | BCE Loss: 1.0803598165512085
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.118555068969727 | KNN Loss: 5.076013088226318 | BCE Loss: 1.042541742324829
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.144984245300293 | KNN Loss: 5.0769243240356445 | BCE Loss: 1.0680601596832275
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.151422500610352 | KNN Loss: 5.071323394775391 | BCE Loss: 1.080099105834961
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.122405052185059 | KNN Loss: 5.060118675231934 | BCE Loss: 1.062286138534546
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.124342918395996 | KNN Loss: 5.078787326812744 | BCE Loss: 1.0455553531646729
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.1400980949401855 | KNN Loss: 5.086660861968994 | BCE Loss: 1.0534371137619019
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.175093173980713 | KNN Loss: 5.101563930511475 | BCE Loss

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.176115989685059 | KNN Loss: 5.059687614440918 | BCE Loss: 1.1164283752441406
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.091367721557617 | KNN Loss: 5.044297218322754 | BCE Loss: 1.0470705032348633
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.087880611419678 | KNN Loss: 5.050999641418457 | BCE Loss: 1.0368809700012207
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.07302188873291 | KNN Loss: 5.0403151512146 | BCE Loss: 1.0327069759368896
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.10439395904541 | KNN Loss: 5.044567108154297 | BCE Loss: 1.0598266124725342
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.067069053649902 | KNN Loss: 5.060004711151123 | BCE Loss: 1.0070645809173584
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.131958484649658 | KNN Loss: 5.058707237243652 | BCE Loss: 1.0732513666152954
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.127476692199707 | KNN Loss: 5.06211519241333 | BCE Loss: 1.

Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 6.08839225769043 | KNN Loss: 5.038623809814453 | BCE Loss: 1.0497686862945557
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.06320333480835 | KNN Loss: 5.032927989959717 | BCE Loss: 1.0302753448486328
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.117987632751465 | KNN Loss: 5.043788909912109 | BCE Loss: 1.0741989612579346
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.118149280548096 | KNN Loss: 5.034641265869141 | BCE Loss: 1.0835081338882446
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.092503547668457 | KNN Loss: 5.027092933654785 | BCE Loss: 1.0654107332229614
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.100611686706543 | KNN Loss: 5.050936222076416 | BCE Loss: 1.0496752262115479
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.132219314575195 | KNN Loss: 5.0503973960876465 | BCE Loss: 1.0818217992782593
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.072502136230469 | KNN Loss: 5.048242568969727 | BCE Loss:

Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 6.051682949066162 | KNN Loss: 5.028656959533691 | BCE Loss: 1.0230261087417603
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.083219528198242 | KNN Loss: 5.0402140617370605 | BCE Loss: 1.0430057048797607
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.18244743347168 | KNN Loss: 5.131138801574707 | BCE Loss: 1.0513088703155518
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.079896450042725 | KNN Loss: 5.0313334465026855 | BCE Loss: 1.0485631227493286
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.06731653213501 | KNN Loss: 5.034237861633301 | BCE Loss: 1.0330787897109985
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.079414367675781 | KNN Loss: 5.027195453643799 | BCE Loss: 1.0522189140319824
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.058646202087402 | KNN Loss: 5.027067184448242 | BCE Loss: 1.031578779220581
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.072388648986816 | KNN Loss: 5.057621955871582 | BCE Loss: 

Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 6.081455230712891 | KNN Loss: 5.0207953453063965 | BCE Loss: 1.0606601238250732
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.071977615356445 | KNN Loss: 5.035012722015381 | BCE Loss: 1.0369646549224854
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.0942583084106445 | KNN Loss: 5.040613651275635 | BCE Loss: 1.0536445379257202
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.0997796058654785 | KNN Loss: 5.057556629180908 | BCE Loss: 1.0422230958938599
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.0909223556518555 | KNN Loss: 5.046536922454834 | BCE Loss: 1.0443851947784424
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.096246242523193 | KNN Loss: 5.0279083251953125 | BCE Loss: 1.0683379173278809
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.063228607177734 | KNN Loss: 5.024147987365723 | BCE Loss: 1.0390805006027222
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.145074367523193 | KNN Loss: 5.084218502044678 | BCE

Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 6.081385135650635 | KNN Loss: 5.027251243591309 | BCE Loss: 1.0541338920593262
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.100461483001709 | KNN Loss: 5.051543712615967 | BCE Loss: 1.0489176511764526
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.106805801391602 | KNN Loss: 5.0235209465026855 | BCE Loss: 1.0832850933074951
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.104742050170898 | KNN Loss: 5.03533935546875 | BCE Loss: 1.0694029331207275
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.057454586029053 | KNN Loss: 5.023622512817383 | BCE Loss: 1.03383207321167
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.130719184875488 | KNN Loss: 5.086127758026123 | BCE Loss: 1.0445916652679443
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.0970258712768555 | KNN Loss: 5.046356678009033 | BCE Loss: 1.0506694316864014
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 6.068598747253418 | KNN Loss: 5.029611110687256 | BCE Loss

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.077889442443848 | KNN Loss: 5.0281291007995605 | BCE Loss: 1.0497605800628662
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.073517322540283 | KNN Loss: 5.028764247894287 | BCE Loss: 1.0447531938552856
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.07501220703125 | KNN Loss: 5.019989967346191 | BCE Loss: 1.055022120475769
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.068759918212891 | KNN Loss: 5.01613187789917 | BCE Loss: 1.0526278018951416
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.067033767700195 | KNN Loss: 5.030699729919434 | BCE Loss: 1.0363339185714722
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.070094108581543 | KNN Loss: 5.020450115203857 | BCE Loss: 1.0496439933776855
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 6.068534851074219 | KNN Loss: 5.034926414489746 | BCE Loss: 1.0336081981658936
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 6.074002265930176 | KNN Loss: 5.04095458984375 | BCE Loss: 

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 6.062525749206543 | KNN Loss: 5.021864891052246 | BCE Loss: 1.0406606197357178
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.047896385192871 | KNN Loss: 5.01005220413208 | BCE Loss: 1.037843942642212
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.08022403717041 | KNN Loss: 5.022010326385498 | BCE Loss: 1.0582139492034912
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.079639434814453 | KNN Loss: 5.029745101928711 | BCE Loss: 1.0498942136764526
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.05873441696167 | KNN Loss: 5.007699012756348 | BCE Loss: 1.0510354042053223
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.02940559387207 | KNN Loss: 5.0172505378723145 | BCE Loss: 1.0121550559997559
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 6.0686235427856445 | KNN Loss: 5.0143232345581055 | BCE Loss: 1.05430006980896
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 6.061220169067383 | KNN Loss: 5.025240898132324 | BCE Loss: 1.

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 6.063564777374268 | KNN Loss: 5.0297393798828125 | BCE Loss: 1.033825397491455
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.081810474395752 | KNN Loss: 5.050789833068848 | BCE Loss: 1.0310207605361938
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.0637922286987305 | KNN Loss: 5.02841329574585 | BCE Loss: 1.0353788137435913
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.06151008605957 | KNN Loss: 5.005863189697266 | BCE Loss: 1.0556470155715942
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.0906243324279785 | KNN Loss: 5.015073299407959 | BCE Loss: 1.0755510330200195
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.058589458465576 | KNN Loss: 5.010581970214844 | BCE Loss: 1.048007607460022
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 6.1202311515808105 | KNN Loss: 5.082742691040039 | BCE Loss: 1.0374884605407715
Epoch   109: reducing learning rate of group 0 to 1.7150e-03.
Epoch 109 / 500 | iteration 0 / 30 | 

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 6.073353290557861 | KNN Loss: 5.014732837677002 | BCE Loss: 1.0586204528808594
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.047101974487305 | KNN Loss: 5.023948669433594 | BCE Loss: 1.0231531858444214
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.05424690246582 | KNN Loss: 5.0101118087768555 | BCE Loss: 1.0441348552703857
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.054615497589111 | KNN Loss: 5.0177106857299805 | BCE Loss: 1.0369046926498413
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.044525623321533 | KNN Loss: 5.0132575035095215 | BCE Loss: 1.0312682390213013
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.043309211730957 | KNN Loss: 4.997982501983643 | BCE Loss: 1.045326828956604
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 6.057426452636719 | KNN Loss: 5.015480995178223 | BCE Loss: 1.0419456958770752
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 6.053518295288086 | KNN Loss: 5.036980628967285 

Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 6.066904067993164 | KNN Loss: 5.023560047149658 | BCE Loss: 1.043344259262085
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.065980911254883 | KNN Loss: 5.008481979370117 | BCE Loss: 1.0574991703033447
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.0414018630981445 | KNN Loss: 5.012219429016113 | BCE Loss: 1.0291826725006104
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.047889709472656 | KNN Loss: 5.012051105499268 | BCE Loss: 1.0358388423919678
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.05029296875 | KNN Loss: 5.018232822418213 | BCE Loss: 1.0320602655410767
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 6.030124664306641 | KNN Loss: 5.004165172576904 | BCE Loss: 1.0259597301483154
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 6.0670671463012695 | KNN Loss: 5.02721643447876 | BCE Loss: 1.0398504734039307
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 6.115882873535156 | KNN Loss: 5.082324504852295 | BCE 

Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 6.033400535583496 | KNN Loss: 5.008886337280273 | BCE Loss: 1.0245139598846436
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.037459850311279 | KNN Loss: 5.001664161682129 | BCE Loss: 1.0357956886291504
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.042452812194824 | KNN Loss: 5.025990009307861 | BCE Loss: 1.016462802886963
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.055884838104248 | KNN Loss: 5.005781650543213 | BCE Loss: 1.0501031875610352
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.05653190612793 | KNN Loss: 5.019332408905029 | BCE Loss: 1.0371993780136108
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.030691146850586 | KNN Loss: 5.0150675773620605 | BCE Loss: 1.0156235694885254
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 6.0553717613220215 | KNN Loss: 5.010310173034668 | BCE Loss: 1.0450615882873535
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 6.07576322555542 | KNN Loss: 5.068942546844482 | B

Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 6.052748680114746 | KNN Loss: 5.014116287231445 | BCE Loss: 1.0386326313018799
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.062108993530273 | KNN Loss: 5.021981239318848 | BCE Loss: 1.0401275157928467
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.048460960388184 | KNN Loss: 5.0100250244140625 | BCE Loss: 1.0384360551834106
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.045115947723389 | KNN Loss: 5.008676528930664 | BCE Loss: 1.0364394187927246
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.060601234436035 | KNN Loss: 5.034937858581543 | BCE Loss: 1.0256634950637817
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.066564559936523 | KNN Loss: 5.035656452178955 | BCE Loss: 1.0309078693389893
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 6.053614616394043 | KNN Loss: 5.010173797607422 | BCE Loss: 1.0434409379959106
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 6.041092395782471 | KNN Loss: 5.017368316650391 

Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 6.048313617706299 | KNN Loss: 5.020328521728516 | BCE Loss: 1.0279849767684937
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.0476393699646 | KNN Loss: 5.01847505569458 | BCE Loss: 1.02916419506073
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.050023078918457 | KNN Loss: 5.015835285186768 | BCE Loss: 1.0341877937316895
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.037081718444824 | KNN Loss: 5.018431186676025 | BCE Loss: 1.018650770187378
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.0707855224609375 | KNN Loss: 5.024952411651611 | BCE Loss: 1.0458333492279053
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.058978080749512 | KNN Loss: 5.018710613250732 | BCE Loss: 1.0402674674987793
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 6.032491683959961 | KNN Loss: 5.0050177574157715 | BCE Loss: 1.0274739265441895
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 6.016548156738281 | KNN Loss: 5.016944885253906 | BCE 

Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 6.039434909820557 | KNN Loss: 4.999897480010986 | BCE Loss: 1.0395375490188599
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.071295261383057 | KNN Loss: 5.017430305480957 | BCE Loss: 1.05386483669281
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.0592122077941895 | KNN Loss: 5.0185980796813965 | BCE Loss: 1.040614128112793
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.067696571350098 | KNN Loss: 5.010709285736084 | BCE Loss: 1.0569872856140137
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.070507049560547 | KNN Loss: 5.0191731452941895 | BCE Loss: 1.0513336658477783
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.051589488983154 | KNN Loss: 5.007046699523926 | BCE Loss: 1.044542908668518
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 6.032484531402588 | KNN Loss: 5.005403995513916 | BCE Loss: 1.0270806550979614
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 6.059352397918701 | KNN Loss: 5.01471471786499 | BC

Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 6.051867961883545 | KNN Loss: 5.025279998779297 | BCE Loss: 1.0265878438949585
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.048334121704102 | KNN Loss: 5.00639009475708 | BCE Loss: 1.0419442653656006
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.048981189727783 | KNN Loss: 5.022535800933838 | BCE Loss: 1.0264452695846558
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.0886383056640625 | KNN Loss: 5.065757751464844 | BCE Loss: 1.0228805541992188
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.04541015625 | KNN Loss: 5.014416694641113 | BCE Loss: 1.0309933423995972
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.0607171058654785 | KNN Loss: 5.014969348907471 | BCE Loss: 1.0457477569580078
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 6.062707424163818 | KNN Loss: 5.0274338722229 | BCE Loss: 1.035273551940918
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 6.028414726257324 | KNN Loss: 5.012734889984131 | BCE L

Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 6.04957389831543 | KNN Loss: 5.013742446899414 | BCE Loss: 1.0358312129974365
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.062358856201172 | KNN Loss: 5.032676696777344 | BCE Loss: 1.0296823978424072
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.057497978210449 | KNN Loss: 4.992732048034668 | BCE Loss: 1.0647661685943604
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.031486988067627 | KNN Loss: 5.010270118713379 | BCE Loss: 1.021216869354248
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.059513092041016 | KNN Loss: 5.015792369842529 | BCE Loss: 1.0437204837799072
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.049061298370361 | KNN Loss: 5.013402938842773 | BCE Loss: 1.035658359527588
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 6.07982063293457 | KNN Loss: 5.01749324798584 | BCE Loss: 1.0623271465301514
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 6.075778961181641 | KNN Loss: 5.025097846984863 | BCE L

Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 6.076411247253418 | KNN Loss: 5.041466236114502 | BCE Loss: 1.0349451303482056
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.043013572692871 | KNN Loss: 5.000006198883057 | BCE Loss: 1.0430076122283936
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.077662467956543 | KNN Loss: 5.026898384094238 | BCE Loss: 1.0507640838623047
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.078774452209473 | KNN Loss: 5.0721893310546875 | BCE Loss: 1.006584882736206
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.088666915893555 | KNN Loss: 5.054577827453613 | BCE Loss: 1.0340888500213623
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.059669494628906 | KNN Loss: 5.041635990142822 | BCE Loss: 1.018033504486084
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 6.085874557495117 | KNN Loss: 5.027984142303467 | BCE Loss: 1.0578904151916504
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 6.056391716003418 | KNN Loss: 5.010897159576416 | B

Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 6.107327461242676 | KNN Loss: 5.058116912841797 | BCE Loss: 1.049210786819458
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.0383782386779785 | KNN Loss: 5.016050815582275 | BCE Loss: 1.0223274230957031
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.040165901184082 | KNN Loss: 5.0120463371276855 | BCE Loss: 1.0281195640563965
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.037905693054199 | KNN Loss: 5.023917198181152 | BCE Loss: 1.0139882564544678
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.030048370361328 | KNN Loss: 5.010637283325195 | BCE Loss: 1.0194108486175537
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.077371120452881 | KNN Loss: 5.01474142074585 | BCE Loss: 1.0626298189163208
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 6.027690887451172 | KNN Loss: 5.010293483734131 | BCE Loss: 1.017397165298462
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 6.046992301940918 | KNN Loss: 4.997705936431885 | 

Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 6.075874328613281 | KNN Loss: 5.043334484100342 | BCE Loss: 1.0325398445129395
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.046077251434326 | KNN Loss: 5.01769495010376 | BCE Loss: 1.0283823013305664
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.115257263183594 | KNN Loss: 5.076364517211914 | BCE Loss: 1.0388925075531006
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.017975330352783 | KNN Loss: 5.009749889373779 | BCE Loss: 1.008225440979004
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.046792984008789 | KNN Loss: 5.017925262451172 | BCE Loss: 1.0288677215576172
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.091151714324951 | KNN Loss: 5.052041053771973 | BCE Loss: 1.039110541343689
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 6.066736221313477 | KNN Loss: 5.037837982177734 | BCE Loss: 1.0288984775543213
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 6.074506759643555 | KNN Loss: 5.015486717224121 | BCE

Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 6.100404739379883 | KNN Loss: 5.044919013977051 | BCE Loss: 1.0554859638214111
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.03847599029541 | KNN Loss: 4.996885299682617 | BCE Loss: 1.041590690612793
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.087399005889893 | KNN Loss: 5.049571514129639 | BCE Loss: 1.037827491760254
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.013916969299316 | KNN Loss: 4.998538494110107 | BCE Loss: 1.0153785943984985
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.03904914855957 | KNN Loss: 5.014948844909668 | BCE Loss: 1.0241000652313232
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.052900314331055 | KNN Loss: 5.003842353820801 | BCE Loss: 1.049057960510254
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 6.0802903175354 | KNN Loss: 5.021372318267822 | BCE Loss: 1.0589181184768677
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 6.082122802734375 | KNN Loss: 5.017178535461426 | BCE Los

Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 6.076076507568359 | KNN Loss: 5.017348766326904 | BCE Loss: 1.058727502822876
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.054228782653809 | KNN Loss: 5.003281593322754 | BCE Loss: 1.0509474277496338
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.058218479156494 | KNN Loss: 5.0096049308776855 | BCE Loss: 1.0486135482788086
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.064805030822754 | KNN Loss: 5.00779390335083 | BCE Loss: 1.057011365890503
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.037850856781006 | KNN Loss: 5.013367652893066 | BCE Loss: 1.024483323097229
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.062819480895996 | KNN Loss: 5.0047831535339355 | BCE Loss: 1.0580365657806396
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 6.045943260192871 | KNN Loss: 5.017796039581299 | BCE Loss: 1.0281474590301514
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 6.049020767211914 | KNN Loss: 5.00939416885376 | BC

Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 6.0404276847839355 | KNN Loss: 5.000545978546143 | BCE Loss: 1.039881706237793
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.00697660446167 | KNN Loss: 5.000911235809326 | BCE Loss: 1.0060654878616333
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.116912841796875 | KNN Loss: 5.053160667419434 | BCE Loss: 1.0637521743774414
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.043320655822754 | KNN Loss: 5.029742240905762 | BCE Loss: 1.0135784149169922
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.0607075691223145 | KNN Loss: 5.013584136962891 | BCE Loss: 1.0471233129501343
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.0650248527526855 | KNN Loss: 5.014961242675781 | BCE Loss: 1.0500637292861938
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 6.068080902099609 | KNN Loss: 5.01578950881958 | BCE Loss: 1.0522913932800293
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 6.0836005210876465 | KNN Loss: 5.024886608123779 |

Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 6.04449462890625 | KNN Loss: 5.001646518707275 | BCE Loss: 1.0428483486175537
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.021522045135498 | KNN Loss: 5.005043983459473 | BCE Loss: 1.0164779424667358
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.012213706970215 | KNN Loss: 5.006466865539551 | BCE Loss: 1.0057470798492432
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.079848289489746 | KNN Loss: 5.036489963531494 | BCE Loss: 1.043358564376831
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.043050289154053 | KNN Loss: 5.009645462036133 | BCE Loss: 1.03340482711792
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.0369133949279785 | KNN Loss: 5.024896621704102 | BCE Loss: 1.0120166540145874
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 6.0317463874816895 | KNN Loss: 5.00429630279541 | BCE Loss: 1.0274499654769897
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 6.057718276977539 | KNN Loss: 5.016977310180664 | BCE

Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 6.031328201293945 | KNN Loss: 5.010830402374268 | BCE Loss: 1.0204976797103882
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.067998886108398 | KNN Loss: 5.040978908538818 | BCE Loss: 1.0270200967788696
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.113465785980225 | KNN Loss: 5.053899765014648 | BCE Loss: 1.0595660209655762
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.045914173126221 | KNN Loss: 4.998716831207275 | BCE Loss: 1.0471973419189453
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.039690017700195 | KNN Loss: 5.006494045257568 | BCE Loss: 1.033195972442627
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.049155235290527 | KNN Loss: 5.013298988342285 | BCE Loss: 1.0358562469482422
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 6.065033912658691 | KNN Loss: 5.043952941894531 | BCE Loss: 1.0210812091827393
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 6.081622123718262 | KNN Loss: 5.042003631591797 | 

Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 6.084571838378906 | KNN Loss: 5.0586628913879395 | BCE Loss: 1.025909185409546
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.065525054931641 | KNN Loss: 5.025134563446045 | BCE Loss: 1.0403902530670166
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.049985408782959 | KNN Loss: 5.028318881988525 | BCE Loss: 1.0216666460037231
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.048116207122803 | KNN Loss: 5.0104241371154785 | BCE Loss: 1.0376920700073242
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.03642463684082 | KNN Loss: 5.006043434143066 | BCE Loss: 1.0303810834884644
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.052399635314941 | KNN Loss: 5.0444865226745605 | BCE Loss: 1.0079128742218018
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 6.075341701507568 | KNN Loss: 5.013823509216309 | BCE Loss: 1.0615183115005493
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 6.096243858337402 | KNN Loss: 5.049172401428223 |

Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 6.061354637145996 | KNN Loss: 5.0059614181518555 | BCE Loss: 1.055393099784851
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.077587127685547 | KNN Loss: 5.029234886169434 | BCE Loss: 1.0483524799346924
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.0340070724487305 | KNN Loss: 5.02151346206665 | BCE Loss: 1.01249361038208
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.073529243469238 | KNN Loss: 5.00510835647583 | BCE Loss: 1.068420648574829
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.057488918304443 | KNN Loss: 5.019248008728027 | BCE Loss: 1.0382410287857056
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.062248229980469 | KNN Loss: 5.009426593780518 | BCE Loss: 1.0528218746185303
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 6.032425880432129 | KNN Loss: 5.002382755279541 | BCE Loss: 1.030043363571167
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 6.032160758972168 | KNN Loss: 5.00398588180542 | BCE Lo

Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 6.050946235656738 | KNN Loss: 5.010064601898193 | BCE Loss: 1.040881872177124
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.08287239074707 | KNN Loss: 5.050397872924805 | BCE Loss: 1.0324745178222656
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.0667266845703125 | KNN Loss: 5.0120062828063965 | BCE Loss: 1.0547206401824951
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.068122386932373 | KNN Loss: 5.00664758682251 | BCE Loss: 1.0614749193191528
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.060889720916748 | KNN Loss: 5.00822114944458 | BCE Loss: 1.0526686906814575
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.016558647155762 | KNN Loss: 5.019903182983398 | BCE Loss: 0.9966555833816528
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 6.031740665435791 | KNN Loss: 5.019370079040527 | BCE Loss: 1.0123704671859741
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 6.070547580718994 | KNN Loss: 5.025796890258789 | B

Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 6.0484490394592285 | KNN Loss: 5.014145851135254 | BCE Loss: 1.034303069114685
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.083335876464844 | KNN Loss: 5.063794136047363 | BCE Loss: 1.01954185962677
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.043814659118652 | KNN Loss: 5.00516939163208 | BCE Loss: 1.0386455059051514
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.097609996795654 | KNN Loss: 5.064860820770264 | BCE Loss: 1.0327491760253906
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.043438911437988 | KNN Loss: 5.002867221832275 | BCE Loss: 1.0405714511871338
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.0720391273498535 | KNN Loss: 5.010136604309082 | BCE Loss: 1.0619025230407715
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 6.026899337768555 | KNN Loss: 5.004232406616211 | BCE Loss: 1.0226666927337646
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 6.087615489959717 | KNN Loss: 5.036663055419922 | BC

Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 6.065479755401611 | KNN Loss: 5.022884845733643 | BCE Loss: 1.0425949096679688
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.044946193695068 | KNN Loss: 5.010969638824463 | BCE Loss: 1.033976435661316
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.0918989181518555 | KNN Loss: 5.042323589324951 | BCE Loss: 1.0495752096176147
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.0352606773376465 | KNN Loss: 5.007987976074219 | BCE Loss: 1.0272728204727173
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.070341110229492 | KNN Loss: 5.015366077423096 | BCE Loss: 1.0549752712249756
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.039263725280762 | KNN Loss: 5.014280319213867 | BCE Loss: 1.0249831676483154
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 6.06982946395874 | KNN Loss: 5.021162033081055 | BCE Loss: 1.0486674308776855
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 6.045799255371094 | KNN Loss: 5.017697811126709 | 

Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 6.044277191162109 | KNN Loss: 5.002429962158203 | BCE Loss: 1.0418472290039062
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.087509632110596 | KNN Loss: 5.021456718444824 | BCE Loss: 1.066052794456482
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.054943561553955 | KNN Loss: 4.995555877685547 | BCE Loss: 1.0593876838684082
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.101193904876709 | KNN Loss: 5.061398506164551 | BCE Loss: 1.0397953987121582
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.058645248413086 | KNN Loss: 5.0242767333984375 | BCE Loss: 1.034368634223938
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.0269317626953125 | KNN Loss: 5.0104804039001465 | BCE Loss: 1.0164514780044556
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 6.0362701416015625 | KNN Loss: 5.021463871002197 | BCE Loss: 1.0148060321807861
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 6.062179088592529 | KNN Loss: 5.004223823547363

Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 6.074705123901367 | KNN Loss: 5.00678014755249 | BCE Loss: 1.0679250955581665
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.069143295288086 | KNN Loss: 5.030079364776611 | BCE Loss: 1.0390639305114746
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.079097270965576 | KNN Loss: 5.022827625274658 | BCE Loss: 1.056269645690918
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.037054538726807 | KNN Loss: 4.996549606323242 | BCE Loss: 1.0405049324035645
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.071198463439941 | KNN Loss: 5.032068252563477 | BCE Loss: 1.0391300916671753
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.01717472076416 | KNN Loss: 5.002819061279297 | BCE Loss: 1.0143558979034424
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 6.054311752319336 | KNN Loss: 5.0203752517700195 | BCE Loss: 1.0339365005493164
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 6.072945594787598 | KNN Loss: 5.025470733642578 | BC

Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 6.025862693786621 | KNN Loss: 5.005902290344238 | BCE Loss: 1.0199604034423828
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.053094387054443 | KNN Loss: 5.010839939117432 | BCE Loss: 1.0422545671463013
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.08757209777832 | KNN Loss: 5.0403289794921875 | BCE Loss: 1.0472432374954224
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.008630752563477 | KNN Loss: 4.998585224151611 | BCE Loss: 1.0100452899932861
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.059114933013916 | KNN Loss: 5.005693435668945 | BCE Loss: 1.0534216165542603
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.034046649932861 | KNN Loss: 5.00489616394043 | BCE Loss: 1.0291506052017212
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 6.055006980895996 | KNN Loss: 4.999687671661377 | BCE Loss: 1.0553191900253296
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 6.057852745056152 | KNN Loss: 5.0117597579956055 | 

Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 6.0482683181762695 | KNN Loss: 5.018329620361328 | BCE Loss: 1.0299386978149414
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.034018039703369 | KNN Loss: 5.0100483894348145 | BCE Loss: 1.0239696502685547
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.033761978149414 | KNN Loss: 5.003570556640625 | BCE Loss: 1.0301913022994995
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.045807838439941 | KNN Loss: 4.99943733215332 | BCE Loss: 1.046370267868042
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.078805923461914 | KNN Loss: 5.029544353485107 | BCE Loss: 1.0492613315582275
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.02915096282959 | KNN Loss: 5.011593341827393 | BCE Loss: 1.0175575017929077
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 6.038719654083252 | KNN Loss: 5.006389617919922 | BCE Loss: 1.0323301553726196
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 6.044210433959961 | KNN Loss: 4.995811462402344 | 

Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 6.038246154785156 | KNN Loss: 5.018630027770996 | BCE Loss: 1.0196163654327393
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.063436508178711 | KNN Loss: 5.01180362701416 | BCE Loss: 1.0516328811645508
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.043227672576904 | KNN Loss: 5.009031295776367 | BCE Loss: 1.034196376800537
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.082441329956055 | KNN Loss: 5.039413928985596 | BCE Loss: 1.0430275201797485
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.0567779541015625 | KNN Loss: 5.029412746429443 | BCE Loss: 1.0273652076721191
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.055022716522217 | KNN Loss: 5.007907390594482 | BCE Loss: 1.0471152067184448
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 6.0579328536987305 | KNN Loss: 5.01771879196167 | BCE Loss: 1.0402143001556396
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 6.091553211212158 | KNN Loss: 5.040466785430908 | B

Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 6.059661865234375 | KNN Loss: 5.016458034515381 | BCE Loss: 1.0432038307189941
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.026154041290283 | KNN Loss: 5.010557651519775 | BCE Loss: 1.0155963897705078
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.109240531921387 | KNN Loss: 5.084547996520996 | BCE Loss: 1.0246927738189697
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.033666610717773 | KNN Loss: 5.0179572105407715 | BCE Loss: 1.015709638595581
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.05462646484375 | KNN Loss: 5.015382289886475 | BCE Loss: 1.0392440557479858
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.010775566101074 | KNN Loss: 5.001898288726807 | BCE Loss: 1.0088770389556885
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 6.0647873878479 | KNN Loss: 5.006306171417236 | BCE Loss: 1.058481216430664
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 6.055685520172119 | KNN Loss: 5.032402038574219 | BCE 

Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 6.032354354858398 | KNN Loss: 5.010759353637695 | BCE Loss: 1.0215950012207031
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.065715789794922 | KNN Loss: 5.002964973449707 | BCE Loss: 1.0627508163452148
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.03916072845459 | KNN Loss: 5.018344402313232 | BCE Loss: 1.020816445350647
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.052970886230469 | KNN Loss: 5.029860019683838 | BCE Loss: 1.0231107473373413
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.091856956481934 | KNN Loss: 5.0056610107421875 | BCE Loss: 1.086195945739746
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.0555548667907715 | KNN Loss: 5.015420913696289 | BCE Loss: 1.0401338338851929
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 6.064697265625 | KNN Loss: 4.999995708465576 | BCE Loss: 1.0647015571594238
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 6.042248725891113 | KNN Loss: 5.036573886871338 | BCE

Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 6.074240207672119 | KNN Loss: 5.030275821685791 | BCE Loss: 1.0439645051956177
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.041131973266602 | KNN Loss: 4.997561931610107 | BCE Loss: 1.043569803237915
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.030750751495361 | KNN Loss: 4.999375820159912 | BCE Loss: 1.0313748121261597
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.072046279907227 | KNN Loss: 5.025304794311523 | BCE Loss: 1.046741247177124
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.048450469970703 | KNN Loss: 5.025455951690674 | BCE Loss: 1.0229946374893188
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.081450462341309 | KNN Loss: 4.997242450714111 | BCE Loss: 1.0842081308364868
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 6.0456109046936035 | KNN Loss: 5.006505966186523 | BCE Loss: 1.03910493850708
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 6.02640438079834 | KNN Loss: 5.01438045501709 | BCE L

Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 6.038956642150879 | KNN Loss: 5.009698867797852 | BCE Loss: 1.0292580127716064
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.060854911804199 | KNN Loss: 5.0085530281066895 | BCE Loss: 1.0523018836975098
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.032310485839844 | KNN Loss: 5.006734371185303 | BCE Loss: 1.0255763530731201
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.07066011428833 | KNN Loss: 5.064802646636963 | BCE Loss: 1.0058575868606567
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.044219017028809 | KNN Loss: 5.006101608276367 | BCE Loss: 1.0381174087524414
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.09729528427124 | KNN Loss: 5.043058395385742 | BCE Loss: 1.054236888885498
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 6.025849342346191 | KNN Loss: 5.008677005767822 | BCE Loss: 1.0171725749969482
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 6.078924179077148 | KNN Loss: 5.033099174499512 | BC

Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 6.029055118560791 | KNN Loss: 5.0061211585998535 | BCE Loss: 1.0229339599609375
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.094236373901367 | KNN Loss: 5.068244457244873 | BCE Loss: 1.0259921550750732
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.056457996368408 | KNN Loss: 5.011569976806641 | BCE Loss: 1.0448880195617676
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.049798011779785 | KNN Loss: 5.009235382080078 | BCE Loss: 1.0405628681182861
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.040761947631836 | KNN Loss: 5.00249719619751 | BCE Loss: 1.0382648706436157
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.032158851623535 | KNN Loss: 5.015150547027588 | BCE Loss: 1.0170080661773682
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 6.0204758644104 | KNN Loss: 5.023492336273193 | BCE Loss: 0.996983528137207
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 6.043042182922363 | KNN Loss: 4.9979705810546875 | B

Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 6.048320770263672 | KNN Loss: 5.015078544616699 | BCE Loss: 1.033242106437683
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.057239532470703 | KNN Loss: 5.008315563201904 | BCE Loss: 1.0489239692687988
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.038585662841797 | KNN Loss: 5.007270336151123 | BCE Loss: 1.0313154458999634
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.0623016357421875 | KNN Loss: 5.033756256103516 | BCE Loss: 1.028545618057251
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.041375160217285 | KNN Loss: 5.020190715789795 | BCE Loss: 1.0211846828460693
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 6.0682454109191895 | KNN Loss: 5.044124126434326 | BCE Loss: 1.0241211652755737
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 6.0270466804504395 | KNN Loss: 5.012785911560059 | BCE Loss: 1.0142606496810913
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 6.047994136810303 | KNN Loss: 5.017503261566162 |

Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 6.041540145874023 | KNN Loss: 5.014594078063965 | BCE Loss: 1.0269463062286377
Epoch   460: reducing learning rate of group 0 to 3.8655e-08.
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.061495304107666 | KNN Loss: 5.018918514251709 | BCE Loss: 1.042576789855957
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.020200252532959 | KNN Loss: 5.017342567443848 | BCE Loss: 1.0028576850891113
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.063215255737305 | KNN Loss: 5.003366470336914 | BCE Loss: 1.0598487854003906
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.055025100708008 | KNN Loss: 5.043058395385742 | BCE Loss: 1.0119664669036865
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.0618367195129395 | KNN Loss: 5.034313201904297 | BCE Loss: 1.0275235176086426
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 6.017180919647217 | KNN Loss: 4.99492883682251 | BCE Loss: 1.022252082824707
Epoch 461 / 500 | iteration 0 / 30 | To

Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 6.034457206726074 | KNN Loss: 4.995636940002441 | BCE Loss: 1.038820505142212
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.063051700592041 | KNN Loss: 5.032251358032227 | BCE Loss: 1.0308003425598145
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.077188491821289 | KNN Loss: 5.054178237915039 | BCE Loss: 1.023010492324829
Epoch   471: reducing learning rate of group 0 to 2.7058e-08.
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.084349632263184 | KNN Loss: 5.0446295738220215 | BCE Loss: 1.039720058441162
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.040957927703857 | KNN Loss: 5.012941360473633 | BCE Loss: 1.0280165672302246
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.046894550323486 | KNN Loss: 5.032792091369629 | BCE Loss: 1.0141024589538574
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 6.071711540222168 | KNN Loss: 5.06589412689209 | BCE Loss: 1.0058176517486572
Epoch 471 / 500 | iteration 20 / 30 | To

Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 6.049658298492432 | KNN Loss: 5.0222487449646 | BCE Loss: 1.027409553527832
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.035114288330078 | KNN Loss: 5.024486541748047 | BCE Loss: 1.0106276273727417
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.125916481018066 | KNN Loss: 5.033674716949463 | BCE Loss: 1.0922417640686035
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.008894443511963 | KNN Loss: 4.996998310089111 | BCE Loss: 1.0118961334228516
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.024177551269531 | KNN Loss: 5.020304203033447 | BCE Loss: 1.0038731098175049
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.044114112854004 | KNN Loss: 5.003721237182617 | BCE Loss: 1.0403928756713867
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 6.095734596252441 | KNN Loss: 5.062690734863281 | BCE Loss: 1.0330440998077393
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 6.028258323669434 | KNN Loss: 5.004518032073975 | BCE

Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 6.071241855621338 | KNN Loss: 5.025339126586914 | BCE Loss: 1.0459027290344238
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.095396041870117 | KNN Loss: 5.031296253204346 | BCE Loss: 1.0640995502471924
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.084624767303467 | KNN Loss: 5.0562663078308105 | BCE Loss: 1.0283584594726562
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.022395133972168 | KNN Loss: 5.016458511352539 | BCE Loss: 1.005936622619629
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.068243026733398 | KNN Loss: 5.025578498840332 | BCE Loss: 1.0426645278930664
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.061152458190918 | KNN Loss: 5.007332801818848 | BCE Loss: 1.0538197755813599
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 6.076411247253418 | KNN Loss: 5.056914329528809 | BCE Loss: 1.0194966793060303
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 6.053962707519531 | KNN Loss: 5.027441024780273 | 

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.0169,  3.7089,  2.5876,  3.0771,  3.3849,  0.7724,  2.6627,  2.2804,
          2.2103,  2.0220,  2.1847,  2.2892,  0.7973,  1.7530,  1.2190,  1.4315,
          2.7459,  2.7280,  2.2349,  1.8349,  1.7019,  2.5745,  2.2448,  2.6221,
          2.5252,  1.7932,  2.1448,  1.4807,  1.5369,  0.4111, -0.1998,  1.0286,
          0.2046,  0.9527,  1.4791,  1.4321,  1.0842,  2.8008,  0.6504,  1.2694,
          0.9280, -0.6987, -0.2524,  2.3327,  1.7224,  0.7449, -0.1590,  0.1547,
          1.4286,  2.5432,  1.3626,  0.1273,  1.3694,  0.5756, -0.5356,  1.0705,
          1.4649,  1.3859,  1.3082,  1.8057,  0.5729,  0.8208,  0.1695,  1.7984,
          1.3383,  1.7248, -1.7942,  0.3608,  2.3759,  2.2260,  2.4704,  0.4697,
          1.3686,  2.4387,  1.9007,  1.1780,  0.3153,  0.7410,  0.2198,  1.5819,
          0.0308,  0.4000,  1.8466, -0.3640,  0.2809, -1.0456, -2.3745, -0.1721,
          0.4682, -1.8568,  0.5005, -0.0813, -0.4984, -1.0866,  0.5996,  1.2267,
         -0.7037, -0.6617,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [12]:
dataset_ = [d[0].cpu() for d in dataset]

In [13]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 86.80it/s]


In [14]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [15]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [16]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [18]:
tensor_dataset = torch.stack(dataset_)

In [19]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [20]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [21]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [22]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [23]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [24]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [25]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [26]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [27]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [28]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [29]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [30]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [31]:
losses = []
accs = []
sparsity = []

In [32]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 029 | Total loss: 9.645 | Reg loss: 0.009 | Tree loss: 9.645 | Accuracy: 0.000000 | 0.291 sec/iter
Epoch: 00 | Batch: 001 / 029 | Total loss: 9.629 | Reg loss: 0.009 | Tree loss: 9.629 | Accuracy: 0.000000 | 0.262 sec/iter
Epoch: 00 | Batch: 002 / 029 | Total loss: 9.620 | Reg loss: 0.008 | Tree loss: 9.620 | Accuracy: 0.000000 | 0.248 sec/iter
Epoch: 00 | Batch: 003 / 029 | Total loss: 9.608 | Reg loss: 0.008 | Tree loss: 9.608 | Accuracy: 0.000000 | 0.242 sec/iter
Epoch: 00 | Batch: 004 / 029 | Total loss: 9.592 | Reg loss: 0.008 | Tree loss: 9.592 | Accuracy: 0.000000 | 0.24 sec/iter
Epoch: 00 | Batch: 005 / 029 | Total loss: 9.586 | Reg loss: 0.008 | Tree loss: 9.586 | Accuracy: 0.000000 | 0.239 sec/iter
Epoch: 00 | Batch: 006 / 029 | Total loss: 9.574 | Reg loss: 0.007 | Tree loss: 9.574 | Accuracy: 0.000000 | 0.238 sec/iter
Epoch: 00 | Batch:

Epoch: 02 | Batch: 004 / 029 | Total loss: 9.271 | Reg loss: 0.008 | Tree loss: 9.271 | Accuracy: 0.177734 | 0.238 sec/iter
Epoch: 02 | Batch: 005 / 029 | Total loss: 9.249 | Reg loss: 0.008 | Tree loss: 9.249 | Accuracy: 0.187500 | 0.238 sec/iter
Epoch: 02 | Batch: 006 / 029 | Total loss: 9.252 | Reg loss: 0.008 | Tree loss: 9.252 | Accuracy: 0.195312 | 0.238 sec/iter
Epoch: 02 | Batch: 007 / 029 | Total loss: 9.249 | Reg loss: 0.009 | Tree loss: 9.249 | Accuracy: 0.191406 | 0.237 sec/iter
Epoch: 02 | Batch: 008 / 029 | Total loss: 9.213 | Reg loss: 0.009 | Tree loss: 9.213 | Accuracy: 0.220703 | 0.237 sec/iter
Epoch: 02 | Batch: 009 / 029 | Total loss: 9.202 | Reg loss: 0.009 | Tree loss: 9.202 | Accuracy: 0.205078 | 0.237 sec/iter
Epoch: 02 | Batch: 010 / 029 | Total loss: 9.199 | Reg loss: 0.010 | Tree loss: 9.199 | Accuracy: 0.199219 | 0.237 sec/iter
Epoch: 02 | Batch: 011 / 029 | Total loss: 9.189 | Reg loss: 0.010 | Tree loss: 9.189 | Accuracy: 0.181641 | 0.237 sec/iter
Epoch: 0

Epoch: 04 | Batch: 009 / 029 | Total loss: 8.869 | Reg loss: 0.014 | Tree loss: 8.869 | Accuracy: 0.193359 | 0.237 sec/iter
Epoch: 04 | Batch: 010 / 029 | Total loss: 8.850 | Reg loss: 0.015 | Tree loss: 8.850 | Accuracy: 0.191406 | 0.237 sec/iter
Epoch: 04 | Batch: 011 / 029 | Total loss: 8.814 | Reg loss: 0.015 | Tree loss: 8.814 | Accuracy: 0.214844 | 0.237 sec/iter
Epoch: 04 | Batch: 012 / 029 | Total loss: 8.806 | Reg loss: 0.015 | Tree loss: 8.806 | Accuracy: 0.218750 | 0.237 sec/iter
Epoch: 04 | Batch: 013 / 029 | Total loss: 8.788 | Reg loss: 0.016 | Tree loss: 8.788 | Accuracy: 0.191406 | 0.237 sec/iter
Epoch: 04 | Batch: 014 / 029 | Total loss: 8.774 | Reg loss: 0.016 | Tree loss: 8.774 | Accuracy: 0.236328 | 0.237 sec/iter
Epoch: 04 | Batch: 015 / 029 | Total loss: 8.762 | Reg loss: 0.017 | Tree loss: 8.762 | Accuracy: 0.224609 | 0.236 sec/iter
Epoch: 04 | Batch: 016 / 029 | Total loss: 8.756 | Reg loss: 0.017 | Tree loss: 8.756 | Accuracy: 0.246094 | 0.236 sec/iter
Epoch: 0

Epoch: 06 | Batch: 014 / 029 | Total loss: 8.319 | Reg loss: 0.020 | Tree loss: 8.319 | Accuracy: 0.314453 | 0.237 sec/iter
Epoch: 06 | Batch: 015 / 029 | Total loss: 8.297 | Reg loss: 0.020 | Tree loss: 8.297 | Accuracy: 0.298828 | 0.237 sec/iter
Epoch: 06 | Batch: 016 / 029 | Total loss: 8.303 | Reg loss: 0.021 | Tree loss: 8.303 | Accuracy: 0.283203 | 0.237 sec/iter
Epoch: 06 | Batch: 017 / 029 | Total loss: 8.281 | Reg loss: 0.021 | Tree loss: 8.281 | Accuracy: 0.263672 | 0.237 sec/iter
Epoch: 06 | Batch: 018 / 029 | Total loss: 8.255 | Reg loss: 0.022 | Tree loss: 8.255 | Accuracy: 0.304688 | 0.237 sec/iter
Epoch: 06 | Batch: 019 / 029 | Total loss: 8.248 | Reg loss: 0.022 | Tree loss: 8.248 | Accuracy: 0.302734 | 0.237 sec/iter
Epoch: 06 | Batch: 020 / 029 | Total loss: 8.234 | Reg loss: 0.022 | Tree loss: 8.234 | Accuracy: 0.304688 | 0.237 sec/iter
Epoch: 06 | Batch: 021 / 029 | Total loss: 8.214 | Reg loss: 0.023 | Tree loss: 8.214 | Accuracy: 0.302734 | 0.237 sec/iter
Epoch: 0

Epoch: 08 | Batch: 019 / 029 | Total loss: 7.618 | Reg loss: 0.025 | Tree loss: 7.618 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 08 | Batch: 020 / 029 | Total loss: 7.597 | Reg loss: 0.026 | Tree loss: 7.597 | Accuracy: 0.283203 | 0.238 sec/iter
Epoch: 08 | Batch: 021 / 029 | Total loss: 7.554 | Reg loss: 0.026 | Tree loss: 7.554 | Accuracy: 0.328125 | 0.238 sec/iter
Epoch: 08 | Batch: 022 / 029 | Total loss: 7.569 | Reg loss: 0.026 | Tree loss: 7.569 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 08 | Batch: 023 / 029 | Total loss: 7.570 | Reg loss: 0.027 | Tree loss: 7.570 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 08 | Batch: 024 / 029 | Total loss: 7.517 | Reg loss: 0.027 | Tree loss: 7.517 | Accuracy: 0.320312 | 0.238 sec/iter
Epoch: 08 | Batch: 025 / 029 | Total loss: 7.483 | Reg loss: 0.027 | Tree loss: 7.483 | Accuracy: 0.328125 | 0.238 sec/iter
Epoch: 08 | Batch: 026 / 029 | Total loss: 7.507 | Reg loss: 0.028 | Tree loss: 7.507 | Accuracy: 0.242188 | 0.238 sec/iter
Epoch: 0

Epoch: 10 | Batch: 024 / 029 | Total loss: 6.979 | Reg loss: 0.028 | Tree loss: 6.979 | Accuracy: 0.263672 | 0.238 sec/iter
Epoch: 10 | Batch: 025 / 029 | Total loss: 6.923 | Reg loss: 0.028 | Tree loss: 6.923 | Accuracy: 0.289062 | 0.238 sec/iter
Epoch: 10 | Batch: 026 / 029 | Total loss: 6.890 | Reg loss: 0.029 | Tree loss: 6.890 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 10 | Batch: 027 / 029 | Total loss: 6.862 | Reg loss: 0.029 | Tree loss: 6.862 | Accuracy: 0.308594 | 0.238 sec/iter
Epoch: 10 | Batch: 028 / 029 | Total loss: 6.856 | Reg loss: 0.029 | Tree loss: 6.856 | Accuracy: 0.297571 | 0.238 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 11 | Batch: 000 / 029 | Total loss: 7.262 | Reg loss: 0.025 | Tree loss: 7.262 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 11 | Batch: 001

Epoch: 13 | Batch: 000 / 029 | Total loss: 6.752 | Reg loss: 0.026 | Tree loss: 6.752 | Accuracy: 0.263672 | 0.237 sec/iter
Epoch: 13 | Batch: 001 / 029 | Total loss: 6.678 | Reg loss: 0.026 | Tree loss: 6.678 | Accuracy: 0.308594 | 0.237 sec/iter
Epoch: 13 | Batch: 002 / 029 | Total loss: 6.673 | Reg loss: 0.026 | Tree loss: 6.673 | Accuracy: 0.322266 | 0.237 sec/iter
Epoch: 13 | Batch: 003 / 029 | Total loss: 6.622 | Reg loss: 0.026 | Tree loss: 6.622 | Accuracy: 0.275391 | 0.237 sec/iter
Epoch: 13 | Batch: 004 / 029 | Total loss: 6.611 | Reg loss: 0.026 | Tree loss: 6.611 | Accuracy: 0.302734 | 0.237 sec/iter
Epoch: 13 | Batch: 005 / 029 | Total loss: 6.580 | Reg loss: 0.027 | Tree loss: 6.580 | Accuracy: 0.310547 | 0.237 sec/iter
Epoch: 13 | Batch: 006 / 029 | Total loss: 6.590 | Reg loss: 0.027 | Tree loss: 6.590 | Accuracy: 0.287109 | 0.237 sec/iter
Epoch: 13 | Batch: 007 / 029 | Total loss: 6.549 | Reg loss: 0.027 | Tree loss: 6.549 | Accuracy: 0.298828 | 0.237 sec/iter
Epoch: 1

Epoch: 15 | Batch: 005 / 029 | Total loss: 6.101 | Reg loss: 0.028 | Tree loss: 6.101 | Accuracy: 0.302734 | 0.237 sec/iter
Epoch: 15 | Batch: 006 / 029 | Total loss: 6.094 | Reg loss: 0.028 | Tree loss: 6.094 | Accuracy: 0.287109 | 0.237 sec/iter
Epoch: 15 | Batch: 007 / 029 | Total loss: 6.090 | Reg loss: 0.028 | Tree loss: 6.090 | Accuracy: 0.279297 | 0.237 sec/iter
Epoch: 15 | Batch: 008 / 029 | Total loss: 6.069 | Reg loss: 0.028 | Tree loss: 6.069 | Accuracy: 0.253906 | 0.237 sec/iter
Epoch: 15 | Batch: 009 / 029 | Total loss: 6.042 | Reg loss: 0.028 | Tree loss: 6.042 | Accuracy: 0.287109 | 0.237 sec/iter
Epoch: 15 | Batch: 010 / 029 | Total loss: 5.996 | Reg loss: 0.028 | Tree loss: 5.996 | Accuracy: 0.281250 | 0.237 sec/iter
Epoch: 15 | Batch: 011 / 029 | Total loss: 5.964 | Reg loss: 0.028 | Tree loss: 5.964 | Accuracy: 0.300781 | 0.237 sec/iter
Epoch: 15 | Batch: 012 / 029 | Total loss: 5.957 | Reg loss: 0.028 | Tree loss: 5.957 | Accuracy: 0.324219 | 0.237 sec/iter
Epoch: 1

Epoch: 17 | Batch: 010 / 029 | Total loss: 5.531 | Reg loss: 0.029 | Tree loss: 5.531 | Accuracy: 0.314453 | 0.238 sec/iter
Epoch: 17 | Batch: 011 / 029 | Total loss: 5.525 | Reg loss: 0.029 | Tree loss: 5.525 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 17 | Batch: 012 / 029 | Total loss: 5.520 | Reg loss: 0.029 | Tree loss: 5.520 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 17 | Batch: 013 / 029 | Total loss: 5.523 | Reg loss: 0.029 | Tree loss: 5.523 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 17 | Batch: 014 / 029 | Total loss: 5.497 | Reg loss: 0.029 | Tree loss: 5.497 | Accuracy: 0.271484 | 0.238 sec/iter
Epoch: 17 | Batch: 015 / 029 | Total loss: 5.453 | Reg loss: 0.029 | Tree loss: 5.453 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 17 | Batch: 016 / 029 | Total loss: 5.424 | Reg loss: 0.029 | Tree loss: 5.424 | Accuracy: 0.320312 | 0.238 sec/iter
Epoch: 17 | Batch: 017 / 029 | Total loss: 5.443 | Reg loss: 0.029 | Tree loss: 5.443 | Accuracy: 0.289062 | 0.238 sec/iter
Epoch: 1

Epoch: 19 | Batch: 015 / 029 | Total loss: 5.042 | Reg loss: 0.029 | Tree loss: 5.042 | Accuracy: 0.292969 | 0.238 sec/iter
Epoch: 19 | Batch: 016 / 029 | Total loss: 5.057 | Reg loss: 0.029 | Tree loss: 5.057 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 19 | Batch: 017 / 029 | Total loss: 5.043 | Reg loss: 0.029 | Tree loss: 5.043 | Accuracy: 0.277344 | 0.238 sec/iter
Epoch: 19 | Batch: 018 / 029 | Total loss: 5.031 | Reg loss: 0.029 | Tree loss: 5.031 | Accuracy: 0.273438 | 0.238 sec/iter
Epoch: 19 | Batch: 019 / 029 | Total loss: 4.999 | Reg loss: 0.030 | Tree loss: 4.999 | Accuracy: 0.273438 | 0.238 sec/iter
Epoch: 19 | Batch: 020 / 029 | Total loss: 4.991 | Reg loss: 0.030 | Tree loss: 4.991 | Accuracy: 0.318359 | 0.238 sec/iter
Epoch: 19 | Batch: 021 / 029 | Total loss: 4.909 | Reg loss: 0.030 | Tree loss: 4.909 | Accuracy: 0.324219 | 0.238 sec/iter
Epoch: 19 | Batch: 022 / 029 | Total loss: 4.969 | Reg loss: 0.030 | Tree loss: 4.969 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 1

Epoch: 21 | Batch: 020 / 029 | Total loss: 4.617 | Reg loss: 0.030 | Tree loss: 4.617 | Accuracy: 0.271484 | 0.238 sec/iter
Epoch: 21 | Batch: 021 / 029 | Total loss: 4.575 | Reg loss: 0.030 | Tree loss: 4.575 | Accuracy: 0.259766 | 0.238 sec/iter
Epoch: 21 | Batch: 022 / 029 | Total loss: 4.587 | Reg loss: 0.030 | Tree loss: 4.587 | Accuracy: 0.296875 | 0.238 sec/iter
Epoch: 21 | Batch: 023 / 029 | Total loss: 4.574 | Reg loss: 0.030 | Tree loss: 4.574 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 21 | Batch: 024 / 029 | Total loss: 4.554 | Reg loss: 0.030 | Tree loss: 4.554 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 21 | Batch: 025 / 029 | Total loss: 4.547 | Reg loss: 0.030 | Tree loss: 4.547 | Accuracy: 0.308594 | 0.238 sec/iter
Epoch: 21 | Batch: 026 / 029 | Total loss: 4.521 | Reg loss: 0.030 | Tree loss: 4.521 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 21 | Batch: 027 / 029 | Total loss: 4.542 | Reg loss: 0.030 | Tree loss: 4.542 | Accuracy: 0.283203 | 0.238 sec/iter
Epoch: 2

Epoch: 23 | Batch: 025 / 029 | Total loss: 4.204 | Reg loss: 0.030 | Tree loss: 4.204 | Accuracy: 0.279297 | 0.238 sec/iter
Epoch: 23 | Batch: 026 / 029 | Total loss: 4.165 | Reg loss: 0.030 | Tree loss: 4.165 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 23 | Batch: 027 / 029 | Total loss: 4.192 | Reg loss: 0.030 | Tree loss: 4.192 | Accuracy: 0.306641 | 0.238 sec/iter
Epoch: 23 | Batch: 028 / 029 | Total loss: 4.147 | Reg loss: 0.030 | Tree loss: 4.147 | Accuracy: 0.321862 | 0.238 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 24 | Batch: 000 / 029 | Total loss: 4.352 | Reg loss: 0.029 | Tree loss: 4.352 | Accuracy: 0.279297 | 0.239 sec/iter
Epoch: 24 | Batch: 001 / 029 | Total loss: 4.328 | Reg loss: 0.029 | Tree loss: 4.328 | Accuracy: 0.287109 | 0.239 sec/iter
Epoch: 24 | Batch: 002

Epoch: 26 | Batch: 000 / 029 | Total loss: 4.039 | Reg loss: 0.029 | Tree loss: 4.039 | Accuracy: 0.273438 | 0.238 sec/iter
Epoch: 26 | Batch: 001 / 029 | Total loss: 3.983 | Reg loss: 0.029 | Tree loss: 3.983 | Accuracy: 0.314453 | 0.238 sec/iter
Epoch: 26 | Batch: 002 / 029 | Total loss: 4.046 | Reg loss: 0.029 | Tree loss: 4.046 | Accuracy: 0.246094 | 0.238 sec/iter
Epoch: 26 | Batch: 003 / 029 | Total loss: 3.975 | Reg loss: 0.029 | Tree loss: 3.975 | Accuracy: 0.273438 | 0.238 sec/iter
Epoch: 26 | Batch: 004 / 029 | Total loss: 4.002 | Reg loss: 0.029 | Tree loss: 4.002 | Accuracy: 0.271484 | 0.238 sec/iter
Epoch: 26 | Batch: 005 / 029 | Total loss: 3.958 | Reg loss: 0.029 | Tree loss: 3.958 | Accuracy: 0.261719 | 0.238 sec/iter
Epoch: 26 | Batch: 006 / 029 | Total loss: 3.951 | Reg loss: 0.029 | Tree loss: 3.951 | Accuracy: 0.292969 | 0.238 sec/iter
Epoch: 26 | Batch: 007 / 029 | Total loss: 3.966 | Reg loss: 0.029 | Tree loss: 3.966 | Accuracy: 0.277344 | 0.238 sec/iter
Epoch: 2

Epoch: 28 | Batch: 005 / 029 | Total loss: 3.682 | Reg loss: 0.029 | Tree loss: 3.682 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 28 | Batch: 006 / 029 | Total loss: 3.633 | Reg loss: 0.029 | Tree loss: 3.633 | Accuracy: 0.291016 | 0.238 sec/iter
Epoch: 28 | Batch: 007 / 029 | Total loss: 3.614 | Reg loss: 0.029 | Tree loss: 3.614 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 28 | Batch: 008 / 029 | Total loss: 3.638 | Reg loss: 0.029 | Tree loss: 3.638 | Accuracy: 0.296875 | 0.238 sec/iter
Epoch: 28 | Batch: 009 / 029 | Total loss: 3.625 | Reg loss: 0.029 | Tree loss: 3.625 | Accuracy: 0.275391 | 0.238 sec/iter
Epoch: 28 | Batch: 010 / 029 | Total loss: 3.613 | Reg loss: 0.029 | Tree loss: 3.613 | Accuracy: 0.261719 | 0.238 sec/iter
Epoch: 28 | Batch: 011 / 029 | Total loss: 3.612 | Reg loss: 0.029 | Tree loss: 3.612 | Accuracy: 0.275391 | 0.238 sec/iter
Epoch: 28 | Batch: 012 / 029 | Total loss: 3.536 | Reg loss: 0.029 | Tree loss: 3.536 | Accuracy: 0.308594 | 0.238 sec/iter
Epoch: 2

Epoch: 30 | Batch: 010 / 029 | Total loss: 3.349 | Reg loss: 0.029 | Tree loss: 3.349 | Accuracy: 0.251953 | 0.238 sec/iter
Epoch: 30 | Batch: 011 / 029 | Total loss: 3.366 | Reg loss: 0.029 | Tree loss: 3.366 | Accuracy: 0.253906 | 0.238 sec/iter
Epoch: 30 | Batch: 012 / 029 | Total loss: 3.335 | Reg loss: 0.029 | Tree loss: 3.335 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 30 | Batch: 013 / 029 | Total loss: 3.315 | Reg loss: 0.029 | Tree loss: 3.315 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 30 | Batch: 014 / 029 | Total loss: 3.261 | Reg loss: 0.029 | Tree loss: 3.261 | Accuracy: 0.320312 | 0.238 sec/iter
Epoch: 30 | Batch: 015 / 029 | Total loss: 3.228 | Reg loss: 0.029 | Tree loss: 3.228 | Accuracy: 0.322266 | 0.238 sec/iter
Epoch: 30 | Batch: 016 / 029 | Total loss: 3.233 | Reg loss: 0.029 | Tree loss: 3.233 | Accuracy: 0.316406 | 0.238 sec/iter
Epoch: 30 | Batch: 017 / 029 | Total loss: 3.256 | Reg loss: 0.030 | Tree loss: 3.256 | Accuracy: 0.285156 | 0.238 sec/iter
Epoch: 3

Epoch: 32 | Batch: 015 / 029 | Total loss: 3.010 | Reg loss: 0.029 | Tree loss: 3.010 | Accuracy: 0.312500 | 0.238 sec/iter
Epoch: 32 | Batch: 016 / 029 | Total loss: 3.061 | Reg loss: 0.029 | Tree loss: 3.061 | Accuracy: 0.275391 | 0.238 sec/iter
Epoch: 32 | Batch: 017 / 029 | Total loss: 3.017 | Reg loss: 0.029 | Tree loss: 3.017 | Accuracy: 0.271484 | 0.238 sec/iter
Epoch: 32 | Batch: 018 / 029 | Total loss: 3.012 | Reg loss: 0.029 | Tree loss: 3.012 | Accuracy: 0.289062 | 0.238 sec/iter
Epoch: 32 | Batch: 019 / 029 | Total loss: 2.992 | Reg loss: 0.029 | Tree loss: 2.992 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 32 | Batch: 020 / 029 | Total loss: 3.000 | Reg loss: 0.029 | Tree loss: 3.000 | Accuracy: 0.289062 | 0.238 sec/iter
Epoch: 32 | Batch: 021 / 029 | Total loss: 2.999 | Reg loss: 0.029 | Tree loss: 2.999 | Accuracy: 0.291016 | 0.238 sec/iter
Epoch: 32 | Batch: 022 / 029 | Total loss: 2.940 | Reg loss: 0.030 | Tree loss: 2.940 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 3

Epoch: 34 | Batch: 020 / 029 | Total loss: 2.745 | Reg loss: 0.029 | Tree loss: 2.745 | Accuracy: 0.328125 | 0.238 sec/iter
Epoch: 34 | Batch: 021 / 029 | Total loss: 2.798 | Reg loss: 0.029 | Tree loss: 2.798 | Accuracy: 0.289062 | 0.238 sec/iter
Epoch: 34 | Batch: 022 / 029 | Total loss: 2.757 | Reg loss: 0.029 | Tree loss: 2.757 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 34 | Batch: 023 / 029 | Total loss: 2.777 | Reg loss: 0.029 | Tree loss: 2.777 | Accuracy: 0.285156 | 0.238 sec/iter
Epoch: 34 | Batch: 024 / 029 | Total loss: 2.762 | Reg loss: 0.029 | Tree loss: 2.762 | Accuracy: 0.285156 | 0.238 sec/iter
Epoch: 34 | Batch: 025 / 029 | Total loss: 2.701 | Reg loss: 0.029 | Tree loss: 2.701 | Accuracy: 0.343750 | 0.238 sec/iter
Epoch: 34 | Batch: 026 / 029 | Total loss: 2.777 | Reg loss: 0.029 | Tree loss: 2.777 | Accuracy: 0.269531 | 0.238 sec/iter
Epoch: 34 | Batch: 027 / 029 | Total loss: 2.713 | Reg loss: 0.029 | Tree loss: 2.713 | Accuracy: 0.318359 | 0.238 sec/iter
Epoch: 3

Epoch: 36 | Batch: 025 / 029 | Total loss: 2.608 | Reg loss: 0.029 | Tree loss: 2.608 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 36 | Batch: 026 / 029 | Total loss: 2.577 | Reg loss: 0.029 | Tree loss: 2.577 | Accuracy: 0.279297 | 0.238 sec/iter
Epoch: 36 | Batch: 027 / 029 | Total loss: 2.588 | Reg loss: 0.029 | Tree loss: 2.588 | Accuracy: 0.302734 | 0.238 sec/iter
Epoch: 36 | Batch: 028 / 029 | Total loss: 2.571 | Reg loss: 0.029 | Tree loss: 2.571 | Accuracy: 0.295547 | 0.238 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 37 | Batch: 000 / 029 | Total loss: 2.732 | Reg loss: 0.029 | Tree loss: 2.732 | Accuracy: 0.322266 | 0.238 sec/iter
Epoch: 37 | Batch: 001 / 029 | Total loss: 2.672 | Reg loss: 0.029 | Tree loss: 2.672 | Accuracy: 0.337891 | 0.238 sec/iter
Epoch: 37 | Batch: 002

Epoch: 39 | Batch: 000 / 029 | Total loss: 2.555 | Reg loss: 0.028 | Tree loss: 2.555 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 39 | Batch: 001 / 029 | Total loss: 2.568 | Reg loss: 0.028 | Tree loss: 2.568 | Accuracy: 0.275391 | 0.238 sec/iter
Epoch: 39 | Batch: 002 / 029 | Total loss: 2.511 | Reg loss: 0.028 | Tree loss: 2.511 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 39 | Batch: 003 / 029 | Total loss: 2.526 | Reg loss: 0.028 | Tree loss: 2.526 | Accuracy: 0.314453 | 0.238 sec/iter
Epoch: 39 | Batch: 004 / 029 | Total loss: 2.538 | Reg loss: 0.028 | Tree loss: 2.538 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 39 | Batch: 005 / 029 | Total loss: 2.530 | Reg loss: 0.028 | Tree loss: 2.530 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 39 | Batch: 006 / 029 | Total loss: 2.527 | Reg loss: 0.028 | Tree loss: 2.527 | Accuracy: 0.296875 | 0.238 sec/iter
Epoch: 39 | Batch: 007 / 029 | Total loss: 2.505 | Reg loss: 0.028 | Tree loss: 2.505 | Accuracy: 0.281250 | 0.238 sec/iter
Epoch: 3

Epoch: 41 | Batch: 005 / 029 | Total loss: 2.414 | Reg loss: 0.028 | Tree loss: 2.414 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 41 | Batch: 006 / 029 | Total loss: 2.409 | Reg loss: 0.028 | Tree loss: 2.409 | Accuracy: 0.302734 | 0.238 sec/iter
Epoch: 41 | Batch: 007 / 029 | Total loss: 2.414 | Reg loss: 0.028 | Tree loss: 2.414 | Accuracy: 0.279297 | 0.238 sec/iter
Epoch: 41 | Batch: 008 / 029 | Total loss: 2.420 | Reg loss: 0.028 | Tree loss: 2.420 | Accuracy: 0.248047 | 0.238 sec/iter
Epoch: 41 | Batch: 009 / 029 | Total loss: 2.371 | Reg loss: 0.028 | Tree loss: 2.371 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 41 | Batch: 010 / 029 | Total loss: 2.345 | Reg loss: 0.028 | Tree loss: 2.345 | Accuracy: 0.326172 | 0.238 sec/iter
Epoch: 41 | Batch: 011 / 029 | Total loss: 2.352 | Reg loss: 0.028 | Tree loss: 2.352 | Accuracy: 0.291016 | 0.238 sec/iter
Epoch: 41 | Batch: 012 / 029 | Total loss: 2.362 | Reg loss: 0.028 | Tree loss: 2.362 | Accuracy: 0.281250 | 0.238 sec/iter
Epoch: 4

Epoch: 43 | Batch: 010 / 029 | Total loss: 2.253 | Reg loss: 0.028 | Tree loss: 2.253 | Accuracy: 0.285156 | 0.238 sec/iter
Epoch: 43 | Batch: 011 / 029 | Total loss: 2.224 | Reg loss: 0.028 | Tree loss: 2.224 | Accuracy: 0.320312 | 0.238 sec/iter
Epoch: 43 | Batch: 012 / 029 | Total loss: 2.221 | Reg loss: 0.028 | Tree loss: 2.221 | Accuracy: 0.324219 | 0.238 sec/iter
Epoch: 43 | Batch: 013 / 029 | Total loss: 2.263 | Reg loss: 0.028 | Tree loss: 2.263 | Accuracy: 0.285156 | 0.238 sec/iter
Epoch: 43 | Batch: 014 / 029 | Total loss: 2.256 | Reg loss: 0.028 | Tree loss: 2.256 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 43 | Batch: 015 / 029 | Total loss: 2.246 | Reg loss: 0.028 | Tree loss: 2.246 | Accuracy: 0.279297 | 0.238 sec/iter
Epoch: 43 | Batch: 016 / 029 | Total loss: 2.195 | Reg loss: 0.028 | Tree loss: 2.195 | Accuracy: 0.273438 | 0.238 sec/iter
Epoch: 43 | Batch: 017 / 029 | Total loss: 2.203 | Reg loss: 0.028 | Tree loss: 2.203 | Accuracy: 0.330078 | 0.238 sec/iter
Epoch: 4

Epoch: 45 | Batch: 015 / 029 | Total loss: 2.091 | Reg loss: 0.028 | Tree loss: 2.091 | Accuracy: 0.316406 | 0.238 sec/iter
Epoch: 45 | Batch: 016 / 029 | Total loss: 2.077 | Reg loss: 0.028 | Tree loss: 2.077 | Accuracy: 0.333984 | 0.238 sec/iter
Epoch: 45 | Batch: 017 / 029 | Total loss: 2.105 | Reg loss: 0.028 | Tree loss: 2.105 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 45 | Batch: 018 / 029 | Total loss: 2.100 | Reg loss: 0.028 | Tree loss: 2.100 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 45 | Batch: 019 / 029 | Total loss: 2.125 | Reg loss: 0.028 | Tree loss: 2.125 | Accuracy: 0.269531 | 0.238 sec/iter
Epoch: 45 | Batch: 020 / 029 | Total loss: 2.111 | Reg loss: 0.028 | Tree loss: 2.111 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 45 | Batch: 021 / 029 | Total loss: 2.135 | Reg loss: 0.028 | Tree loss: 2.135 | Accuracy: 0.291016 | 0.238 sec/iter
Epoch: 45 | Batch: 022 / 029 | Total loss: 2.130 | Reg loss: 0.028 | Tree loss: 2.130 | Accuracy: 0.287109 | 0.238 sec/iter
Epoch: 4

Epoch: 47 | Batch: 020 / 029 | Total loss: 2.059 | Reg loss: 0.028 | Tree loss: 2.059 | Accuracy: 0.271484 | 0.238 sec/iter
Epoch: 47 | Batch: 021 / 029 | Total loss: 2.014 | Reg loss: 0.028 | Tree loss: 2.014 | Accuracy: 0.318359 | 0.238 sec/iter
Epoch: 47 | Batch: 022 / 029 | Total loss: 2.016 | Reg loss: 0.028 | Tree loss: 2.016 | Accuracy: 0.292969 | 0.238 sec/iter
Epoch: 47 | Batch: 023 / 029 | Total loss: 2.036 | Reg loss: 0.028 | Tree loss: 2.036 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 47 | Batch: 024 / 029 | Total loss: 2.013 | Reg loss: 0.028 | Tree loss: 2.013 | Accuracy: 0.314453 | 0.238 sec/iter
Epoch: 47 | Batch: 025 / 029 | Total loss: 2.012 | Reg loss: 0.028 | Tree loss: 2.012 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 47 | Batch: 026 / 029 | Total loss: 2.012 | Reg loss: 0.028 | Tree loss: 2.012 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 47 | Batch: 027 / 029 | Total loss: 2.034 | Reg loss: 0.028 | Tree loss: 2.034 | Accuracy: 0.273438 | 0.238 sec/iter
Epoch: 4

Epoch: 49 | Batch: 025 / 029 | Total loss: 1.981 | Reg loss: 0.027 | Tree loss: 1.981 | Accuracy: 0.316406 | 0.238 sec/iter
Epoch: 49 | Batch: 026 / 029 | Total loss: 1.976 | Reg loss: 0.027 | Tree loss: 1.976 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 49 | Batch: 027 / 029 | Total loss: 1.935 | Reg loss: 0.027 | Tree loss: 1.935 | Accuracy: 0.271484 | 0.238 sec/iter
Epoch: 49 | Batch: 028 / 029 | Total loss: 1.945 | Reg loss: 0.028 | Tree loss: 1.945 | Accuracy: 0.289474 | 0.238 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 50 | Batch: 000 / 029 | Total loss: 2.094 | Reg loss: 0.027 | Tree loss: 2.094 | Accuracy: 0.292969 | 0.238 sec/iter
Epoch: 50 | Batch: 001 / 029 | Total loss: 2.085 | Reg loss: 0.027 | Tree loss: 2.085 | Accuracy: 0.281250 | 0.238 sec/iter
Epoch: 50 | Batch: 002

Epoch: 52 | Batch: 000 / 029 | Total loss: 2.017 | Reg loss: 0.027 | Tree loss: 2.017 | Accuracy: 0.326172 | 0.238 sec/iter
Epoch: 52 | Batch: 001 / 029 | Total loss: 2.038 | Reg loss: 0.027 | Tree loss: 2.038 | Accuracy: 0.296875 | 0.238 sec/iter
Epoch: 52 | Batch: 002 / 029 | Total loss: 2.021 | Reg loss: 0.027 | Tree loss: 2.021 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 52 | Batch: 003 / 029 | Total loss: 1.999 | Reg loss: 0.027 | Tree loss: 1.999 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 52 | Batch: 004 / 029 | Total loss: 1.996 | Reg loss: 0.027 | Tree loss: 1.996 | Accuracy: 0.302734 | 0.238 sec/iter
Epoch: 52 | Batch: 005 / 029 | Total loss: 1.980 | Reg loss: 0.027 | Tree loss: 1.980 | Accuracy: 0.318359 | 0.238 sec/iter
Epoch: 52 | Batch: 006 / 029 | Total loss: 2.002 | Reg loss: 0.027 | Tree loss: 2.002 | Accuracy: 0.324219 | 0.238 sec/iter
Epoch: 52 | Batch: 007 / 029 | Total loss: 1.972 | Reg loss: 0.027 | Tree loss: 1.972 | Accuracy: 0.300781 | 0.238 sec/iter
Epoch: 5

Epoch: 54 | Batch: 005 / 029 | Total loss: 1.938 | Reg loss: 0.027 | Tree loss: 1.938 | Accuracy: 0.337891 | 0.238 sec/iter
Epoch: 54 | Batch: 006 / 029 | Total loss: 1.923 | Reg loss: 0.027 | Tree loss: 1.923 | Accuracy: 0.320312 | 0.238 sec/iter
Epoch: 54 | Batch: 007 / 029 | Total loss: 1.926 | Reg loss: 0.027 | Tree loss: 1.926 | Accuracy: 0.328125 | 0.238 sec/iter
Epoch: 54 | Batch: 008 / 029 | Total loss: 1.931 | Reg loss: 0.027 | Tree loss: 1.931 | Accuracy: 0.308594 | 0.238 sec/iter
Epoch: 54 | Batch: 009 / 029 | Total loss: 1.954 | Reg loss: 0.027 | Tree loss: 1.954 | Accuracy: 0.332031 | 0.238 sec/iter
Epoch: 54 | Batch: 010 / 029 | Total loss: 1.964 | Reg loss: 0.027 | Tree loss: 1.964 | Accuracy: 0.314453 | 0.238 sec/iter
Epoch: 54 | Batch: 011 / 029 | Total loss: 1.942 | Reg loss: 0.027 | Tree loss: 1.942 | Accuracy: 0.337891 | 0.238 sec/iter
Epoch: 54 | Batch: 012 / 029 | Total loss: 1.937 | Reg loss: 0.027 | Tree loss: 1.937 | Accuracy: 0.304688 | 0.238 sec/iter
Epoch: 5

Epoch: 56 | Batch: 010 / 029 | Total loss: 1.947 | Reg loss: 0.027 | Tree loss: 1.947 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 56 | Batch: 011 / 029 | Total loss: 1.891 | Reg loss: 0.027 | Tree loss: 1.891 | Accuracy: 0.291016 | 0.238 sec/iter
Epoch: 56 | Batch: 012 / 029 | Total loss: 1.918 | Reg loss: 0.027 | Tree loss: 1.918 | Accuracy: 0.335938 | 0.238 sec/iter
Epoch: 56 | Batch: 013 / 029 | Total loss: 1.888 | Reg loss: 0.027 | Tree loss: 1.888 | Accuracy: 0.302734 | 0.238 sec/iter
Epoch: 56 | Batch: 014 / 029 | Total loss: 1.894 | Reg loss: 0.027 | Tree loss: 1.894 | Accuracy: 0.335938 | 0.238 sec/iter
Epoch: 56 | Batch: 015 / 029 | Total loss: 1.858 | Reg loss: 0.027 | Tree loss: 1.858 | Accuracy: 0.298828 | 0.238 sec/iter
Epoch: 56 | Batch: 016 / 029 | Total loss: 1.903 | Reg loss: 0.027 | Tree loss: 1.903 | Accuracy: 0.289062 | 0.238 sec/iter
Epoch: 56 | Batch: 017 / 029 | Total loss: 1.906 | Reg loss: 0.027 | Tree loss: 1.906 | Accuracy: 0.283203 | 0.238 sec/iter
Epoch: 5

Epoch: 58 | Batch: 015 / 029 | Total loss: 1.833 | Reg loss: 0.027 | Tree loss: 1.833 | Accuracy: 0.339844 | 0.238 sec/iter
Epoch: 58 | Batch: 016 / 029 | Total loss: 1.841 | Reg loss: 0.027 | Tree loss: 1.841 | Accuracy: 0.330078 | 0.238 sec/iter
Epoch: 58 | Batch: 017 / 029 | Total loss: 1.877 | Reg loss: 0.027 | Tree loss: 1.877 | Accuracy: 0.310547 | 0.238 sec/iter
Epoch: 58 | Batch: 018 / 029 | Total loss: 1.861 | Reg loss: 0.027 | Tree loss: 1.861 | Accuracy: 0.269531 | 0.238 sec/iter
Epoch: 58 | Batch: 019 / 029 | Total loss: 1.797 | Reg loss: 0.027 | Tree loss: 1.797 | Accuracy: 0.294922 | 0.238 sec/iter
Epoch: 58 | Batch: 020 / 029 | Total loss: 1.821 | Reg loss: 0.027 | Tree loss: 1.821 | Accuracy: 0.283203 | 0.239 sec/iter
Epoch: 58 | Batch: 021 / 029 | Total loss: 1.842 | Reg loss: 0.027 | Tree loss: 1.842 | Accuracy: 0.339844 | 0.239 sec/iter
Epoch: 58 | Batch: 022 / 029 | Total loss: 1.843 | Reg loss: 0.027 | Tree loss: 1.843 | Accuracy: 0.292969 | 0.239 sec/iter
Epoch: 5

Epoch: 60 | Batch: 020 / 029 | Total loss: 1.858 | Reg loss: 0.027 | Tree loss: 1.858 | Accuracy: 0.263672 | 0.239 sec/iter
Epoch: 60 | Batch: 021 / 029 | Total loss: 1.806 | Reg loss: 0.027 | Tree loss: 1.806 | Accuracy: 0.322266 | 0.239 sec/iter
Epoch: 60 | Batch: 022 / 029 | Total loss: 1.805 | Reg loss: 0.027 | Tree loss: 1.805 | Accuracy: 0.296875 | 0.239 sec/iter
Epoch: 60 | Batch: 023 / 029 | Total loss: 1.802 | Reg loss: 0.027 | Tree loss: 1.802 | Accuracy: 0.316406 | 0.239 sec/iter
Epoch: 60 | Batch: 024 / 029 | Total loss: 1.790 | Reg loss: 0.027 | Tree loss: 1.790 | Accuracy: 0.322266 | 0.239 sec/iter
Epoch: 60 | Batch: 025 / 029 | Total loss: 1.776 | Reg loss: 0.027 | Tree loss: 1.776 | Accuracy: 0.279297 | 0.239 sec/iter
Epoch: 60 | Batch: 026 / 029 | Total loss: 1.804 | Reg loss: 0.027 | Tree loss: 1.804 | Accuracy: 0.320312 | 0.239 sec/iter
Epoch: 60 | Batch: 027 / 029 | Total loss: 1.818 | Reg loss: 0.027 | Tree loss: 1.818 | Accuracy: 0.287109 | 0.239 sec/iter
Epoch: 6

Epoch: 62 | Batch: 025 / 029 | Total loss: 1.827 | Reg loss: 0.027 | Tree loss: 1.827 | Accuracy: 0.316406 | 0.239 sec/iter
Epoch: 62 | Batch: 026 / 029 | Total loss: 1.802 | Reg loss: 0.027 | Tree loss: 1.802 | Accuracy: 0.283203 | 0.239 sec/iter
Epoch: 62 | Batch: 027 / 029 | Total loss: 1.743 | Reg loss: 0.027 | Tree loss: 1.743 | Accuracy: 0.351562 | 0.239 sec/iter
Epoch: 62 | Batch: 028 / 029 | Total loss: 1.810 | Reg loss: 0.027 | Tree loss: 1.810 | Accuracy: 0.275304 | 0.239 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 63 | Batch: 000 / 029 | Total loss: 1.848 | Reg loss: 0.026 | Tree loss: 1.848 | Accuracy: 0.283203 | 0.239 sec/iter
Epoch: 63 | Batch: 001 / 029 | Total loss: 1.844 | Reg loss: 0.026 | Tree loss: 1.844 | Accuracy: 0.335938 | 0.239 sec/iter
Epoch: 63 | Batch: 002

Epoch: 65 | Batch: 000 / 029 | Total loss: 1.840 | Reg loss: 0.026 | Tree loss: 1.840 | Accuracy: 0.318359 | 0.239 sec/iter
Epoch: 65 | Batch: 001 / 029 | Total loss: 1.849 | Reg loss: 0.026 | Tree loss: 1.849 | Accuracy: 0.330078 | 0.239 sec/iter
Epoch: 65 | Batch: 002 / 029 | Total loss: 1.866 | Reg loss: 0.026 | Tree loss: 1.866 | Accuracy: 0.320312 | 0.239 sec/iter
Epoch: 65 | Batch: 003 / 029 | Total loss: 1.892 | Reg loss: 0.026 | Tree loss: 1.892 | Accuracy: 0.302734 | 0.239 sec/iter
Epoch: 65 | Batch: 004 / 029 | Total loss: 1.872 | Reg loss: 0.026 | Tree loss: 1.872 | Accuracy: 0.306641 | 0.239 sec/iter
Epoch: 65 | Batch: 005 / 029 | Total loss: 1.827 | Reg loss: 0.026 | Tree loss: 1.827 | Accuracy: 0.300781 | 0.239 sec/iter
Epoch: 65 | Batch: 006 / 029 | Total loss: 1.786 | Reg loss: 0.026 | Tree loss: 1.786 | Accuracy: 0.351562 | 0.239 sec/iter
Epoch: 65 | Batch: 007 / 029 | Total loss: 1.838 | Reg loss: 0.026 | Tree loss: 1.838 | Accuracy: 0.337891 | 0.239 sec/iter
Epoch: 6

Epoch: 67 | Batch: 005 / 029 | Total loss: 1.853 | Reg loss: 0.026 | Tree loss: 1.853 | Accuracy: 0.279297 | 0.239 sec/iter
Epoch: 67 | Batch: 006 / 029 | Total loss: 1.849 | Reg loss: 0.026 | Tree loss: 1.849 | Accuracy: 0.314453 | 0.239 sec/iter
Epoch: 67 | Batch: 007 / 029 | Total loss: 1.811 | Reg loss: 0.026 | Tree loss: 1.811 | Accuracy: 0.337891 | 0.239 sec/iter
Epoch: 67 | Batch: 008 / 029 | Total loss: 1.795 | Reg loss: 0.026 | Tree loss: 1.795 | Accuracy: 0.332031 | 0.239 sec/iter
Epoch: 67 | Batch: 009 / 029 | Total loss: 1.764 | Reg loss: 0.026 | Tree loss: 1.764 | Accuracy: 0.310547 | 0.239 sec/iter
Epoch: 67 | Batch: 010 / 029 | Total loss: 1.831 | Reg loss: 0.026 | Tree loss: 1.831 | Accuracy: 0.292969 | 0.239 sec/iter
Epoch: 67 | Batch: 011 / 029 | Total loss: 1.812 | Reg loss: 0.026 | Tree loss: 1.812 | Accuracy: 0.306641 | 0.239 sec/iter
Epoch: 67 | Batch: 012 / 029 | Total loss: 1.841 | Reg loss: 0.026 | Tree loss: 1.841 | Accuracy: 0.322266 | 0.239 sec/iter
Epoch: 6

Epoch: 69 | Batch: 010 / 029 | Total loss: 1.819 | Reg loss: 0.026 | Tree loss: 1.819 | Accuracy: 0.279297 | 0.239 sec/iter
Epoch: 69 | Batch: 011 / 029 | Total loss: 1.776 | Reg loss: 0.026 | Tree loss: 1.776 | Accuracy: 0.302734 | 0.239 sec/iter
Epoch: 69 | Batch: 012 / 029 | Total loss: 1.843 | Reg loss: 0.026 | Tree loss: 1.843 | Accuracy: 0.292969 | 0.239 sec/iter
Epoch: 69 | Batch: 013 / 029 | Total loss: 1.818 | Reg loss: 0.026 | Tree loss: 1.818 | Accuracy: 0.306641 | 0.239 sec/iter
Epoch: 69 | Batch: 014 / 029 | Total loss: 1.784 | Reg loss: 0.026 | Tree loss: 1.784 | Accuracy: 0.332031 | 0.239 sec/iter
Epoch: 69 | Batch: 015 / 029 | Total loss: 1.841 | Reg loss: 0.026 | Tree loss: 1.841 | Accuracy: 0.271484 | 0.239 sec/iter
Epoch: 69 | Batch: 016 / 029 | Total loss: 1.786 | Reg loss: 0.026 | Tree loss: 1.786 | Accuracy: 0.328125 | 0.239 sec/iter
Epoch: 69 | Batch: 017 / 029 | Total loss: 1.736 | Reg loss: 0.026 | Tree loss: 1.736 | Accuracy: 0.320312 | 0.239 sec/iter
Epoch: 6

Epoch: 71 | Batch: 015 / 029 | Total loss: 1.757 | Reg loss: 0.026 | Tree loss: 1.757 | Accuracy: 0.316406 | 0.24 sec/iter
Epoch: 71 | Batch: 016 / 029 | Total loss: 1.727 | Reg loss: 0.026 | Tree loss: 1.727 | Accuracy: 0.328125 | 0.24 sec/iter
Epoch: 71 | Batch: 017 / 029 | Total loss: 1.747 | Reg loss: 0.026 | Tree loss: 1.747 | Accuracy: 0.333984 | 0.24 sec/iter
Epoch: 71 | Batch: 018 / 029 | Total loss: 1.781 | Reg loss: 0.026 | Tree loss: 1.781 | Accuracy: 0.283203 | 0.24 sec/iter
Epoch: 71 | Batch: 019 / 029 | Total loss: 1.770 | Reg loss: 0.026 | Tree loss: 1.770 | Accuracy: 0.322266 | 0.24 sec/iter
Epoch: 71 | Batch: 020 / 029 | Total loss: 1.746 | Reg loss: 0.027 | Tree loss: 1.746 | Accuracy: 0.326172 | 0.24 sec/iter
Epoch: 71 | Batch: 021 / 029 | Total loss: 1.735 | Reg loss: 0.027 | Tree loss: 1.735 | Accuracy: 0.316406 | 0.24 sec/iter
Epoch: 71 | Batch: 022 / 029 | Total loss: 1.739 | Reg loss: 0.027 | Tree loss: 1.739 | Accuracy: 0.312500 | 0.24 sec/iter
Epoch: 71 | Batc

Epoch: 73 | Batch: 020 / 029 | Total loss: 1.764 | Reg loss: 0.026 | Tree loss: 1.764 | Accuracy: 0.285156 | 0.24 sec/iter
Epoch: 73 | Batch: 021 / 029 | Total loss: 1.770 | Reg loss: 0.026 | Tree loss: 1.770 | Accuracy: 0.330078 | 0.24 sec/iter
Epoch: 73 | Batch: 022 / 029 | Total loss: 1.730 | Reg loss: 0.027 | Tree loss: 1.730 | Accuracy: 0.332031 | 0.24 sec/iter
Epoch: 73 | Batch: 023 / 029 | Total loss: 1.772 | Reg loss: 0.027 | Tree loss: 1.772 | Accuracy: 0.312500 | 0.24 sec/iter
Epoch: 73 | Batch: 024 / 029 | Total loss: 1.750 | Reg loss: 0.027 | Tree loss: 1.750 | Accuracy: 0.312500 | 0.24 sec/iter
Epoch: 73 | Batch: 025 / 029 | Total loss: 1.774 | Reg loss: 0.027 | Tree loss: 1.774 | Accuracy: 0.269531 | 0.24 sec/iter
Epoch: 73 | Batch: 026 / 029 | Total loss: 1.685 | Reg loss: 0.027 | Tree loss: 1.685 | Accuracy: 0.353516 | 0.24 sec/iter
Epoch: 73 | Batch: 027 / 029 | Total loss: 1.747 | Reg loss: 0.027 | Tree loss: 1.747 | Accuracy: 0.320312 | 0.24 sec/iter
Epoch: 73 | Batc

Epoch: 75 | Batch: 025 / 029 | Total loss: 1.717 | Reg loss: 0.027 | Tree loss: 1.717 | Accuracy: 0.345703 | 0.24 sec/iter
Epoch: 75 | Batch: 026 / 029 | Total loss: 1.718 | Reg loss: 0.027 | Tree loss: 1.718 | Accuracy: 0.341797 | 0.24 sec/iter
Epoch: 75 | Batch: 027 / 029 | Total loss: 1.677 | Reg loss: 0.027 | Tree loss: 1.677 | Accuracy: 0.332031 | 0.24 sec/iter
Epoch: 75 | Batch: 028 / 029 | Total loss: 1.684 | Reg loss: 0.027 | Tree loss: 1.684 | Accuracy: 0.319838 | 0.24 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 76 | Batch: 000 / 029 | Total loss: 1.801 | Reg loss: 0.026 | Tree loss: 1.801 | Accuracy: 0.320312 | 0.24 sec/iter
Epoch: 76 | Batch: 001 / 029 | Total loss: 1.849 | Reg loss: 0.026 | Tree loss: 1.849 | Accuracy: 0.310547 | 0.24 sec/iter
Epoch: 76 | Batch: 002 / 029

Epoch: 78 | Batch: 000 / 029 | Total loss: 1.786 | Reg loss: 0.026 | Tree loss: 1.786 | Accuracy: 0.337891 | 0.24 sec/iter
Epoch: 78 | Batch: 001 / 029 | Total loss: 1.798 | Reg loss: 0.026 | Tree loss: 1.798 | Accuracy: 0.308594 | 0.24 sec/iter
Epoch: 78 | Batch: 002 / 029 | Total loss: 1.790 | Reg loss: 0.026 | Tree loss: 1.790 | Accuracy: 0.343750 | 0.24 sec/iter
Epoch: 78 | Batch: 003 / 029 | Total loss: 1.781 | Reg loss: 0.026 | Tree loss: 1.781 | Accuracy: 0.320312 | 0.24 sec/iter
Epoch: 78 | Batch: 004 / 029 | Total loss: 1.772 | Reg loss: 0.026 | Tree loss: 1.772 | Accuracy: 0.361328 | 0.24 sec/iter
Epoch: 78 | Batch: 005 / 029 | Total loss: 1.805 | Reg loss: 0.026 | Tree loss: 1.805 | Accuracy: 0.294922 | 0.24 sec/iter
Epoch: 78 | Batch: 006 / 029 | Total loss: 1.778 | Reg loss: 0.026 | Tree loss: 1.778 | Accuracy: 0.312500 | 0.24 sec/iter
Epoch: 78 | Batch: 007 / 029 | Total loss: 1.791 | Reg loss: 0.026 | Tree loss: 1.791 | Accuracy: 0.302734 | 0.24 sec/iter
Epoch: 78 | Batc

Epoch: 80 | Batch: 005 / 029 | Total loss: 1.776 | Reg loss: 0.026 | Tree loss: 1.776 | Accuracy: 0.322266 | 0.24 sec/iter
Epoch: 80 | Batch: 006 / 029 | Total loss: 1.821 | Reg loss: 0.026 | Tree loss: 1.821 | Accuracy: 0.298828 | 0.24 sec/iter
Epoch: 80 | Batch: 007 / 029 | Total loss: 1.773 | Reg loss: 0.026 | Tree loss: 1.773 | Accuracy: 0.308594 | 0.24 sec/iter
Epoch: 80 | Batch: 008 / 029 | Total loss: 1.770 | Reg loss: 0.026 | Tree loss: 1.770 | Accuracy: 0.281250 | 0.24 sec/iter
Epoch: 80 | Batch: 009 / 029 | Total loss: 1.765 | Reg loss: 0.026 | Tree loss: 1.765 | Accuracy: 0.328125 | 0.24 sec/iter
Epoch: 80 | Batch: 010 / 029 | Total loss: 1.815 | Reg loss: 0.026 | Tree loss: 1.815 | Accuracy: 0.289062 | 0.24 sec/iter
Epoch: 80 | Batch: 011 / 029 | Total loss: 1.735 | Reg loss: 0.026 | Tree loss: 1.735 | Accuracy: 0.347656 | 0.24 sec/iter
Epoch: 80 | Batch: 012 / 029 | Total loss: 1.758 | Reg loss: 0.026 | Tree loss: 1.758 | Accuracy: 0.332031 | 0.24 sec/iter
Epoch: 80 | Batc

Epoch: 82 | Batch: 010 / 029 | Total loss: 1.764 | Reg loss: 0.026 | Tree loss: 1.764 | Accuracy: 0.310547 | 0.24 sec/iter
Epoch: 82 | Batch: 011 / 029 | Total loss: 1.786 | Reg loss: 0.026 | Tree loss: 1.786 | Accuracy: 0.291016 | 0.24 sec/iter
Epoch: 82 | Batch: 012 / 029 | Total loss: 1.737 | Reg loss: 0.026 | Tree loss: 1.737 | Accuracy: 0.312500 | 0.24 sec/iter
Epoch: 82 | Batch: 013 / 029 | Total loss: 1.773 | Reg loss: 0.026 | Tree loss: 1.773 | Accuracy: 0.265625 | 0.24 sec/iter
Epoch: 82 | Batch: 014 / 029 | Total loss: 1.735 | Reg loss: 0.026 | Tree loss: 1.735 | Accuracy: 0.287109 | 0.24 sec/iter
Epoch: 82 | Batch: 015 / 029 | Total loss: 1.714 | Reg loss: 0.026 | Tree loss: 1.714 | Accuracy: 0.330078 | 0.24 sec/iter
Epoch: 82 | Batch: 016 / 029 | Total loss: 1.728 | Reg loss: 0.026 | Tree loss: 1.728 | Accuracy: 0.298828 | 0.24 sec/iter
Epoch: 82 | Batch: 017 / 029 | Total loss: 1.709 | Reg loss: 0.026 | Tree loss: 1.709 | Accuracy: 0.347656 | 0.24 sec/iter
Epoch: 82 | Batc

Epoch: 84 | Batch: 015 / 029 | Total loss: 1.749 | Reg loss: 0.026 | Tree loss: 1.749 | Accuracy: 0.287109 | 0.24 sec/iter
Epoch: 84 | Batch: 016 / 029 | Total loss: 1.718 | Reg loss: 0.026 | Tree loss: 1.718 | Accuracy: 0.292969 | 0.24 sec/iter
Epoch: 84 | Batch: 017 / 029 | Total loss: 1.733 | Reg loss: 0.026 | Tree loss: 1.733 | Accuracy: 0.294922 | 0.24 sec/iter
Epoch: 84 | Batch: 018 / 029 | Total loss: 1.758 | Reg loss: 0.026 | Tree loss: 1.758 | Accuracy: 0.312500 | 0.24 sec/iter
Epoch: 84 | Batch: 019 / 029 | Total loss: 1.728 | Reg loss: 0.026 | Tree loss: 1.728 | Accuracy: 0.326172 | 0.24 sec/iter
Epoch: 84 | Batch: 020 / 029 | Total loss: 1.747 | Reg loss: 0.026 | Tree loss: 1.747 | Accuracy: 0.322266 | 0.24 sec/iter
Epoch: 84 | Batch: 021 / 029 | Total loss: 1.656 | Reg loss: 0.026 | Tree loss: 1.656 | Accuracy: 0.369141 | 0.24 sec/iter
Epoch: 84 | Batch: 022 / 029 | Total loss: 1.723 | Reg loss: 0.026 | Tree loss: 1.723 | Accuracy: 0.292969 | 0.24 sec/iter
Epoch: 84 | Batc

Epoch: 86 | Batch: 020 / 029 | Total loss: 1.743 | Reg loss: 0.026 | Tree loss: 1.743 | Accuracy: 0.255859 | 0.24 sec/iter
Epoch: 86 | Batch: 021 / 029 | Total loss: 1.710 | Reg loss: 0.026 | Tree loss: 1.710 | Accuracy: 0.314453 | 0.24 sec/iter
Epoch: 86 | Batch: 022 / 029 | Total loss: 1.700 | Reg loss: 0.026 | Tree loss: 1.700 | Accuracy: 0.302734 | 0.24 sec/iter
Epoch: 86 | Batch: 023 / 029 | Total loss: 1.689 | Reg loss: 0.026 | Tree loss: 1.689 | Accuracy: 0.330078 | 0.24 sec/iter
Epoch: 86 | Batch: 024 / 029 | Total loss: 1.699 | Reg loss: 0.026 | Tree loss: 1.699 | Accuracy: 0.324219 | 0.24 sec/iter
Epoch: 86 | Batch: 025 / 029 | Total loss: 1.744 | Reg loss: 0.026 | Tree loss: 1.744 | Accuracy: 0.308594 | 0.24 sec/iter
Epoch: 86 | Batch: 026 / 029 | Total loss: 1.711 | Reg loss: 0.026 | Tree loss: 1.711 | Accuracy: 0.318359 | 0.24 sec/iter
Epoch: 86 | Batch: 027 / 029 | Total loss: 1.709 | Reg loss: 0.026 | Tree loss: 1.709 | Accuracy: 0.298828 | 0.24 sec/iter
Epoch: 86 | Batc

Epoch: 88 | Batch: 025 / 029 | Total loss: 1.683 | Reg loss: 0.026 | Tree loss: 1.683 | Accuracy: 0.322266 | 0.24 sec/iter
Epoch: 88 | Batch: 026 / 029 | Total loss: 1.688 | Reg loss: 0.026 | Tree loss: 1.688 | Accuracy: 0.318359 | 0.24 sec/iter
Epoch: 88 | Batch: 027 / 029 | Total loss: 1.701 | Reg loss: 0.026 | Tree loss: 1.701 | Accuracy: 0.332031 | 0.24 sec/iter
Epoch: 88 | Batch: 028 / 029 | Total loss: 1.701 | Reg loss: 0.026 | Tree loss: 1.701 | Accuracy: 0.311741 | 0.24 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 89 | Batch: 000 / 029 | Total loss: 1.764 | Reg loss: 0.026 | Tree loss: 1.764 | Accuracy: 0.351562 | 0.24 sec/iter
Epoch: 89 | Batch: 001 / 029 | Total loss: 1.803 | Reg loss: 0.026 | Tree loss: 1.803 | Accuracy: 0.287109 | 0.24 sec/iter
Epoch: 89 | Batch: 002 / 029

Epoch: 91 | Batch: 000 / 029 | Total loss: 1.791 | Reg loss: 0.026 | Tree loss: 1.791 | Accuracy: 0.324219 | 0.24 sec/iter
Epoch: 91 | Batch: 001 / 029 | Total loss: 1.819 | Reg loss: 0.026 | Tree loss: 1.819 | Accuracy: 0.310547 | 0.24 sec/iter
Epoch: 91 | Batch: 002 / 029 | Total loss: 1.809 | Reg loss: 0.026 | Tree loss: 1.809 | Accuracy: 0.324219 | 0.24 sec/iter
Epoch: 91 | Batch: 003 / 029 | Total loss: 1.769 | Reg loss: 0.026 | Tree loss: 1.769 | Accuracy: 0.337891 | 0.24 sec/iter
Epoch: 91 | Batch: 004 / 029 | Total loss: 1.759 | Reg loss: 0.026 | Tree loss: 1.759 | Accuracy: 0.316406 | 0.24 sec/iter
Epoch: 91 | Batch: 005 / 029 | Total loss: 1.814 | Reg loss: 0.026 | Tree loss: 1.814 | Accuracy: 0.298828 | 0.24 sec/iter
Epoch: 91 | Batch: 006 / 029 | Total loss: 1.800 | Reg loss: 0.026 | Tree loss: 1.800 | Accuracy: 0.298828 | 0.24 sec/iter
Epoch: 91 | Batch: 007 / 029 | Total loss: 1.741 | Reg loss: 0.026 | Tree loss: 1.741 | Accuracy: 0.320312 | 0.24 sec/iter
Epoch: 91 | Batc

Epoch: 93 | Batch: 005 / 029 | Total loss: 1.769 | Reg loss: 0.026 | Tree loss: 1.769 | Accuracy: 0.296875 | 0.239 sec/iter
Epoch: 93 | Batch: 006 / 029 | Total loss: 1.795 | Reg loss: 0.026 | Tree loss: 1.795 | Accuracy: 0.289062 | 0.239 sec/iter
Epoch: 93 | Batch: 007 / 029 | Total loss: 1.819 | Reg loss: 0.026 | Tree loss: 1.819 | Accuracy: 0.292969 | 0.239 sec/iter
Epoch: 93 | Batch: 008 / 029 | Total loss: 1.733 | Reg loss: 0.026 | Tree loss: 1.733 | Accuracy: 0.335938 | 0.239 sec/iter
Epoch: 93 | Batch: 009 / 029 | Total loss: 1.749 | Reg loss: 0.026 | Tree loss: 1.749 | Accuracy: 0.324219 | 0.239 sec/iter
Epoch: 93 | Batch: 010 / 029 | Total loss: 1.792 | Reg loss: 0.026 | Tree loss: 1.792 | Accuracy: 0.312500 | 0.239 sec/iter
Epoch: 93 | Batch: 011 / 029 | Total loss: 1.755 | Reg loss: 0.026 | Tree loss: 1.755 | Accuracy: 0.312500 | 0.239 sec/iter
Epoch: 93 | Batch: 012 / 029 | Total loss: 1.740 | Reg loss: 0.026 | Tree loss: 1.740 | Accuracy: 0.283203 | 0.239 sec/iter
Epoch: 9

Epoch: 95 | Batch: 010 / 029 | Total loss: 1.693 | Reg loss: 0.026 | Tree loss: 1.693 | Accuracy: 0.318359 | 0.239 sec/iter
Epoch: 95 | Batch: 011 / 029 | Total loss: 1.722 | Reg loss: 0.026 | Tree loss: 1.722 | Accuracy: 0.308594 | 0.239 sec/iter
Epoch: 95 | Batch: 012 / 029 | Total loss: 1.758 | Reg loss: 0.026 | Tree loss: 1.758 | Accuracy: 0.353516 | 0.239 sec/iter
Epoch: 95 | Batch: 013 / 029 | Total loss: 1.721 | Reg loss: 0.026 | Tree loss: 1.721 | Accuracy: 0.287109 | 0.239 sec/iter
Epoch: 95 | Batch: 014 / 029 | Total loss: 1.742 | Reg loss: 0.026 | Tree loss: 1.742 | Accuracy: 0.318359 | 0.239 sec/iter
Epoch: 95 | Batch: 015 / 029 | Total loss: 1.730 | Reg loss: 0.026 | Tree loss: 1.730 | Accuracy: 0.316406 | 0.239 sec/iter
Epoch: 95 | Batch: 016 / 029 | Total loss: 1.737 | Reg loss: 0.026 | Tree loss: 1.737 | Accuracy: 0.308594 | 0.239 sec/iter
Epoch: 95 | Batch: 017 / 029 | Total loss: 1.705 | Reg loss: 0.026 | Tree loss: 1.705 | Accuracy: 0.294922 | 0.239 sec/iter
Epoch: 9

Epoch: 97 | Batch: 015 / 029 | Total loss: 1.769 | Reg loss: 0.026 | Tree loss: 1.769 | Accuracy: 0.285156 | 0.239 sec/iter
Epoch: 97 | Batch: 016 / 029 | Total loss: 1.693 | Reg loss: 0.026 | Tree loss: 1.693 | Accuracy: 0.328125 | 0.239 sec/iter
Epoch: 97 | Batch: 017 / 029 | Total loss: 1.747 | Reg loss: 0.026 | Tree loss: 1.747 | Accuracy: 0.302734 | 0.239 sec/iter
Epoch: 97 | Batch: 018 / 029 | Total loss: 1.723 | Reg loss: 0.026 | Tree loss: 1.723 | Accuracy: 0.316406 | 0.239 sec/iter
Epoch: 97 | Batch: 019 / 029 | Total loss: 1.697 | Reg loss: 0.026 | Tree loss: 1.697 | Accuracy: 0.347656 | 0.239 sec/iter
Epoch: 97 | Batch: 020 / 029 | Total loss: 1.697 | Reg loss: 0.026 | Tree loss: 1.697 | Accuracy: 0.289062 | 0.239 sec/iter
Epoch: 97 | Batch: 021 / 029 | Total loss: 1.627 | Reg loss: 0.026 | Tree loss: 1.627 | Accuracy: 0.359375 | 0.239 sec/iter
Epoch: 97 | Batch: 022 / 029 | Total loss: 1.644 | Reg loss: 0.026 | Tree loss: 1.644 | Accuracy: 0.337891 | 0.239 sec/iter
Epoch: 9

Epoch: 99 | Batch: 020 / 029 | Total loss: 1.726 | Reg loss: 0.026 | Tree loss: 1.726 | Accuracy: 0.279297 | 0.239 sec/iter
Epoch: 99 | Batch: 021 / 029 | Total loss: 1.719 | Reg loss: 0.026 | Tree loss: 1.719 | Accuracy: 0.330078 | 0.239 sec/iter
Epoch: 99 | Batch: 022 / 029 | Total loss: 1.674 | Reg loss: 0.026 | Tree loss: 1.674 | Accuracy: 0.339844 | 0.239 sec/iter
Epoch: 99 | Batch: 023 / 029 | Total loss: 1.735 | Reg loss: 0.026 | Tree loss: 1.735 | Accuracy: 0.287109 | 0.239 sec/iter
Epoch: 99 | Batch: 024 / 029 | Total loss: 1.682 | Reg loss: 0.026 | Tree loss: 1.682 | Accuracy: 0.330078 | 0.239 sec/iter
Epoch: 99 | Batch: 025 / 029 | Total loss: 1.694 | Reg loss: 0.026 | Tree loss: 1.694 | Accuracy: 0.322266 | 0.239 sec/iter
Epoch: 99 | Batch: 026 / 029 | Total loss: 1.732 | Reg loss: 0.026 | Tree loss: 1.732 | Accuracy: 0.281250 | 0.239 sec/iter
Epoch: 99 | Batch: 027 / 029 | Total loss: 1.698 | Reg loss: 0.026 | Tree loss: 1.698 | Accuracy: 0.314453 | 0.239 sec/iter
Epoch: 9

In [33]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [34]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [35]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 7.048192771084337


# Extract Rules

# Accumulate samples in the leaves

In [36]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 83


In [37]:
method = 'greedy'

In [38]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [39]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

13121
1709


  return np.log(1 / (1 - x))


Average comprehensibility: 30.771084337349397
std comprehensibility: 5.962382682950454
var comprehensibility: 35.550007257947456
minimum comprehensibility: 18
maximum comprehensibility: 40
