In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 256
tree_depth = 12
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.182477951049805 | KNN Loss: 6.232606410980225 | BCE Loss: 1.94987154006958
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.211150169372559 | KNN Loss: 6.232600688934326 | BCE Loss: 1.9785493612289429
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.215576171875 | KNN Loss: 6.232440948486328 | BCE Loss: 1.98313570022583
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.188455581665039 | KNN Loss: 6.232435703277588 | BCE Loss: 1.9560197591781616
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.175670623779297 | KNN Loss: 6.232313632965088 | BCE Loss: 1.9433574676513672
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.12234878540039 | KNN Loss: 6.232375621795654 | BCE Loss: 1.8899729251861572
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.198297500610352 | KNN Loss: 6.232083797454834 | BCE Loss: 1.9662132263183594
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.153447151184082 | KNN Loss: 6.232396602630615 | BCE Loss: 1.9210506677627

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.318770408630371 | KNN Loss: 6.189958572387695 | BCE Loss: 1.1288115978240967
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.308002471923828 | KNN Loss: 6.188998222351074 | BCE Loss: 1.119004487991333
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.333943843841553 | KNN Loss: 6.183051586151123 | BCE Loss: 1.1508922576904297
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 7.269381523132324 | KNN Loss: 6.177389144897461 | BCE Loss: 1.0919926166534424
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 7.314318656921387 | KNN Loss: 6.177396774291992 | BCE Loss: 1.136922001838684
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 7.283210754394531 | KNN Loss: 6.16886043548584 | BCE Loss: 1.1143503189086914
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 7.2974090576171875 | KNN Loss: 6.164889812469482 | BCE Loss: 1.1325193643569946
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 7.250821590423584 | KNN Loss: 6.15731143951416 | BCE Loss: 1

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.800716876983643 | KNN Loss: 5.751204490661621 | BCE Loss: 1.0495123863220215
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.809625625610352 | KNN Loss: 5.739058017730713 | BCE Loss: 1.0705674886703491
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.784581184387207 | KNN Loss: 5.72567081451416 | BCE Loss: 1.0589104890823364
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.7950334548950195 | KNN Loss: 5.731474876403809 | BCE Loss: 1.0635583400726318
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.8036208152771 | KNN Loss: 5.738285064697266 | BCE Loss: 1.0653358697891235
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.782561779022217 | KNN Loss: 5.733485221862793 | BCE Loss: 1.0490766763687134
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.798057556152344 | KNN Loss: 5.735341548919678 | BCE Loss: 1.062716007232666
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.756974697113037 | KNN Loss: 5.7236127853393555 | BCE Loss:

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.715782642364502 | KNN Loss: 5.660604476928711 | BCE Loss: 1.0551780462265015
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.747676849365234 | KNN Loss: 5.711723327636719 | BCE Loss: 1.0359532833099365
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.675665378570557 | KNN Loss: 5.6262688636779785 | BCE Loss: 1.0493966341018677
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.724010467529297 | KNN Loss: 5.652440547943115 | BCE Loss: 1.0715696811676025
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.698491096496582 | KNN Loss: 5.65304708480835 | BCE Loss: 1.0454440116882324
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.679124355316162 | KNN Loss: 5.627102851867676 | BCE Loss: 1.0520216226577759
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.677981853485107 | KNN Loss: 5.628269672393799 | BCE Loss: 1.0497121810913086
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.7041120529174805 | KNN Loss: 5.648830890655518 | BCE Lo

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.689226150512695 | KNN Loss: 5.645828723907471 | BCE Loss: 1.0433974266052246
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.715989589691162 | KNN Loss: 5.688199520111084 | BCE Loss: 1.0277901887893677
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.751047134399414 | KNN Loss: 5.682516098022461 | BCE Loss: 1.0685312747955322
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.694040775299072 | KNN Loss: 5.638171195983887 | BCE Loss: 1.0558695793151855
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.681375503540039 | KNN Loss: 5.631553649902344 | BCE Loss: 1.0498217344284058
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.676664352416992 | KNN Loss: 5.613317966461182 | BCE Loss: 1.0633466243743896
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.730112075805664 | KNN Loss: 5.698916435241699 | BCE Loss: 1.0311956405639648
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 6.651851654052734 | KNN Loss: 5.611672878265381 | BCE Los

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.7017822265625 | KNN Loss: 5.63620662689209 | BCE Loss: 1.0655755996704102
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.650291442871094 | KNN Loss: 5.609005451202393 | BCE Loss: 1.0412859916687012
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.724653720855713 | KNN Loss: 5.651760101318359 | BCE Loss: 1.072893738746643
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.634978294372559 | KNN Loss: 5.607069969177246 | BCE Loss: 1.0279083251953125
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.656762599945068 | KNN Loss: 5.621818542480469 | BCE Loss: 1.0349440574645996
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.708695411682129 | KNN Loss: 5.6456146240234375 | BCE Loss: 1.0630805492401123
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.669973373413086 | KNN Loss: 5.634293556213379 | BCE Loss: 1.0356799364089966
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 6.646852493286133 | KNN Loss: 5.611717224121094 | BCE Loss: 1

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.736957550048828 | KNN Loss: 5.690012454986572 | BCE Loss: 1.046945333480835
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.6618971824646 | KNN Loss: 5.604918003082275 | BCE Loss: 1.0569792985916138
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.634238243103027 | KNN Loss: 5.608509540557861 | BCE Loss: 1.0257289409637451
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.633997440338135 | KNN Loss: 5.595698356628418 | BCE Loss: 1.0382990837097168
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.66232967376709 | KNN Loss: 5.597160816192627 | BCE Loss: 1.0651686191558838
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.634922504425049 | KNN Loss: 5.600239276885986 | BCE Loss: 1.0346832275390625
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.635773658752441 | KNN Loss: 5.593874931335449 | BCE Loss: 1.0418987274169922
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 6.655123710632324 | KNN Loss: 5.616445541381836 | BCE Loss: 1.0

Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 6.667660713195801 | KNN Loss: 5.6298017501831055 | BCE Loss: 1.0378589630126953
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.658428192138672 | KNN Loss: 5.603091716766357 | BCE Loss: 1.0553362369537354
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.694732666015625 | KNN Loss: 5.659072399139404 | BCE Loss: 1.0356600284576416
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.6935133934021 | KNN Loss: 5.621359825134277 | BCE Loss: 1.0721535682678223
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.727113723754883 | KNN Loss: 5.693207740783691 | BCE Loss: 1.0339057445526123
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.6981401443481445 | KNN Loss: 5.662759780883789 | BCE Loss: 1.0353806018829346
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.658168315887451 | KNN Loss: 5.62257194519043 | BCE Loss: 1.035596489906311
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 6.688642978668213 | KNN Loss: 5.671849727630615 | BCE Loss:

Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 6.74531364440918 | KNN Loss: 5.696653366088867 | BCE Loss: 1.0486605167388916
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.667490482330322 | KNN Loss: 5.6081719398498535 | BCE Loss: 1.0593184232711792
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.663773536682129 | KNN Loss: 5.611473083496094 | BCE Loss: 1.052300214767456
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.668682098388672 | KNN Loss: 5.627532958984375 | BCE Loss: 1.041149377822876
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.7231645584106445 | KNN Loss: 5.671037673950195 | BCE Loss: 1.0521271228790283
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.650536060333252 | KNN Loss: 5.609241008758545 | BCE Loss: 1.0412949323654175
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.635078430175781 | KNN Loss: 5.594808101654053 | BCE Loss: 1.0402705669403076
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 6.7622246742248535 | KNN Loss: 5.687489032745361 | BCE Los

Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 6.735601902008057 | KNN Loss: 5.686628818511963 | BCE Loss: 1.0489730834960938
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 6.700606346130371 | KNN Loss: 5.67059326171875 | BCE Loss: 1.0300133228302002
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.711550235748291 | KNN Loss: 5.680934906005859 | BCE Loss: 1.0306153297424316
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.705007076263428 | KNN Loss: 5.681201457977295 | BCE Loss: 1.0238054990768433
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.7050275802612305 | KNN Loss: 5.662910461425781 | BCE Loss: 1.0421173572540283
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.66266393661499 | KNN Loss: 5.610222816467285 | BCE Loss: 1.052441120147705
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.671745300292969 | KNN Loss: 5.601093292236328 | BCE Loss: 1.0706522464752197
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 6.631798267364502 | KNN Loss: 5.594435691833496 | BCE Loss: 1

Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 6.642368316650391 | KNN Loss: 5.614650249481201 | BCE Loss: 1.0277180671691895
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 6.661529541015625 | KNN Loss: 5.622446060180664 | BCE Loss: 1.03908371925354
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.723221778869629 | KNN Loss: 5.6963300704956055 | BCE Loss: 1.0268919467926025
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.674414157867432 | KNN Loss: 5.621864318847656 | BCE Loss: 1.0525498390197754
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.699146747589111 | KNN Loss: 5.639377593994141 | BCE Loss: 1.0597691535949707
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.672486305236816 | KNN Loss: 5.6222944259643555 | BCE Loss: 1.05019211769104
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.6783671379089355 | KNN Loss: 5.62606143951416 | BCE Loss: 1.0523056983947754
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 6.655538082122803 | KNN Loss: 5.595443248748779 | B

Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 6.654528617858887 | KNN Loss: 5.597446441650391 | BCE Loss: 1.057082176208496
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 6.6501688957214355 | KNN Loss: 5.618995189666748 | BCE Loss: 1.0311737060546875
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.6280622482299805 | KNN Loss: 5.590504169464111 | BCE Loss: 1.0375579595565796
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.690496921539307 | KNN Loss: 5.634491443634033 | BCE Loss: 1.0560054779052734
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.752155303955078 | KNN Loss: 5.720042705535889 | BCE Loss: 1.0321123600006104
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.664290428161621 | KNN Loss: 5.6135969161987305 | BCE Loss: 1.0506936311721802
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.708821773529053 | KNN Loss: 5.636144161224365 | BCE Loss: 1.072677493095398
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 6.670483589172363 | KNN Loss: 5.6188225746154785

Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 6.710764408111572 | KNN Loss: 5.639047622680664 | BCE Loss: 1.0717167854309082
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 6.645365238189697 | KNN Loss: 5.604659080505371 | BCE Loss: 1.0407062768936157
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.652410507202148 | KNN Loss: 5.593637466430664 | BCE Loss: 1.0587730407714844
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.699736595153809 | KNN Loss: 5.66482400894165 | BCE Loss: 1.0349125862121582
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.66864013671875 | KNN Loss: 5.603036880493164 | BCE Loss: 1.065603256225586
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.642287731170654 | KNN Loss: 5.59523868560791 | BCE Loss: 1.0470490455627441
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 6.662049770355225 | KNN Loss: 5.635184288024902 | BCE Loss: 1.0268653631210327
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 6.7150373458862305 | KNN Loss: 5.647828102111816 | BCE 

Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 6.655292510986328 | KNN Loss: 5.612277984619141 | BCE Loss: 1.0430145263671875
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 6.682567596435547 | KNN Loss: 5.64408016204834 | BCE Loss: 1.038487434387207
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.677860260009766 | KNN Loss: 5.640701770782471 | BCE Loss: 1.0371583700180054
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.692222595214844 | KNN Loss: 5.628222942352295 | BCE Loss: 1.0639994144439697
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.6839470863342285 | KNN Loss: 5.628842830657959 | BCE Loss: 1.0551042556762695
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.660314083099365 | KNN Loss: 5.618414878845215 | BCE Loss: 1.0418992042541504
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.633325576782227 | KNN Loss: 5.598674774169922 | BCE Loss: 1.0346505641937256
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 6.625510215759277 | KNN Loss: 5.595191955566406 | 

Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 6.6834821701049805 | KNN Loss: 5.620075702667236 | BCE Loss: 1.0634067058563232
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 6.673224925994873 | KNN Loss: 5.6281538009643555 | BCE Loss: 1.0450712442398071
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.62900447845459 | KNN Loss: 5.6066203117370605 | BCE Loss: 1.0223839282989502
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.703327655792236 | KNN Loss: 5.621386528015137 | BCE Loss: 1.08194100856781
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.762264728546143 | KNN Loss: 5.725063800811768 | BCE Loss: 1.0372010469436646
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.696443557739258 | KNN Loss: 5.658609390258789 | BCE Loss: 1.0378340482711792
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.6345367431640625 | KNN Loss: 5.602720260620117 | BCE Loss: 1.0318162441253662
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 6.686225891113281 | KNN Loss: 5.617949962615967 

Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 6.643012046813965 | KNN Loss: 5.593273162841797 | BCE Loss: 1.049739122390747
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 6.7042155265808105 | KNN Loss: 5.6444549560546875 | BCE Loss: 1.059760570526123
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.652129173278809 | KNN Loss: 5.60851526260376 | BCE Loss: 1.043614149093628
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.691746234893799 | KNN Loss: 5.631418704986572 | BCE Loss: 1.0603275299072266
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.6456475257873535 | KNN Loss: 5.615178108215332 | BCE Loss: 1.0304694175720215
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.675307273864746 | KNN Loss: 5.622501850128174 | BCE Loss: 1.0528051853179932
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.678328514099121 | KNN Loss: 5.60626220703125 | BCE Loss: 1.072066068649292
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 6.654244422912598 | KNN Loss: 5.621358394622803 | BCE 

Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 6.627845287322998 | KNN Loss: 5.598957538604736 | BCE Loss: 1.0288878679275513
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 6.625782012939453 | KNN Loss: 5.590404510498047 | BCE Loss: 1.0353777408599854
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.681704998016357 | KNN Loss: 5.623116493225098 | BCE Loss: 1.0585883855819702
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.660345554351807 | KNN Loss: 5.609485626220703 | BCE Loss: 1.0508599281311035
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.68867826461792 | KNN Loss: 5.630383491516113 | BCE Loss: 1.0582947731018066
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.661768913269043 | KNN Loss: 5.612947463989258 | BCE Loss: 1.0488216876983643
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.63248872756958 | KNN Loss: 5.606969833374023 | BCE Loss: 1.0255188941955566
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 6.677586078643799 | KNN Loss: 5.62206506729126 | BC

Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 6.681608200073242 | KNN Loss: 5.62089204788208 | BCE Loss: 1.0607163906097412
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 6.642932891845703 | KNN Loss: 5.600748538970947 | BCE Loss: 1.0421844720840454
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.660069465637207 | KNN Loss: 5.605281829833984 | BCE Loss: 1.0547873973846436
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.716601848602295 | KNN Loss: 5.672891616821289 | BCE Loss: 1.0437101125717163
Epoch   183: reducing learning rate of group 0 to 2.0177e-04.
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.65536642074585 | KNN Loss: 5.608720779418945 | BCE Loss: 1.0466456413269043
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.672023773193359 | KNN Loss: 5.629927635192871 | BCE Loss: 1.0420962572097778
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.65839147567749 | KNN Loss: 5.632142066955566 | BCE Loss: 1.0262494087219238
Epoch 183 / 500 | iteration 15 / 30 | To

Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 6.706310749053955 | KNN Loss: 5.662971496582031 | BCE Loss: 1.0433393716812134
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 6.690129280090332 | KNN Loss: 5.653651714324951 | BCE Loss: 1.0364773273468018
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.638817310333252 | KNN Loss: 5.598745822906494 | BCE Loss: 1.0400714874267578
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.653109073638916 | KNN Loss: 5.6159491539001465 | BCE Loss: 1.03715980052948
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.654963493347168 | KNN Loss: 5.621153831481934 | BCE Loss: 1.0338096618652344
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.737525939941406 | KNN Loss: 5.679196357727051 | BCE Loss: 1.0583298206329346
Epoch   194: reducing learning rate of group 0 to 1.4124e-04.
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.645843505859375 | KNN Loss: 5.597964286804199 | BCE Loss: 1.0478792190551758
Epoch 194 / 500 | iteration 5 / 30 | To

Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 6.63314151763916 | KNN Loss: 5.5923285484313965 | BCE Loss: 1.0408128499984741
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 6.744845390319824 | KNN Loss: 5.714748382568359 | BCE Loss: 1.030097246170044
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.656296253204346 | KNN Loss: 5.6369709968566895 | BCE Loss: 1.0193252563476562
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.642267227172852 | KNN Loss: 5.602746486663818 | BCE Loss: 1.0395206212997437
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.686194896697998 | KNN Loss: 5.618226528167725 | BCE Loss: 1.0679683685302734
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.702664375305176 | KNN Loss: 5.674350738525391 | BCE Loss: 1.028313398361206
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.676589012145996 | KNN Loss: 5.624922275543213 | BCE Loss: 1.0516666173934937
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 6.745023727416992 | KNN Loss: 5.671672344207764 | 

Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 6.65761661529541 | KNN Loss: 5.591122627258301 | BCE Loss: 1.0664937496185303
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 6.626102447509766 | KNN Loss: 5.5947699546813965 | BCE Loss: 1.0313327312469482
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.708847522735596 | KNN Loss: 5.641964912414551 | BCE Loss: 1.0668824911117554
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.6160125732421875 | KNN Loss: 5.5957465171813965 | BCE Loss: 1.020265817642212
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.690831661224365 | KNN Loss: 5.642663478851318 | BCE Loss: 1.0481681823730469
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.671355247497559 | KNN Loss: 5.611227512359619 | BCE Loss: 1.0601274967193604
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.667659282684326 | KNN Loss: 5.60070276260376 | BCE Loss: 1.0669564008712769
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 6.693853378295898 | KNN Loss: 5.625725269317627 |

Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 6.7488203048706055 | KNN Loss: 5.6866888999938965 | BCE Loss: 1.0621312856674194
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 6.63664436340332 | KNN Loss: 5.601637363433838 | BCE Loss: 1.0350072383880615
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.680408477783203 | KNN Loss: 5.6106696128845215 | BCE Loss: 1.069738745689392
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.670209884643555 | KNN Loss: 5.620833396911621 | BCE Loss: 1.049376368522644
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.681789875030518 | KNN Loss: 5.611025333404541 | BCE Loss: 1.0707645416259766
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.673892974853516 | KNN Loss: 5.638190269470215 | BCE Loss: 1.0357025861740112
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.705818176269531 | KNN Loss: 5.663251876831055 | BCE Loss: 1.0425662994384766
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 6.725537300109863 | KNN Loss: 5.6648664474487305 | 

Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 6.676481246948242 | KNN Loss: 5.637205123901367 | BCE Loss: 1.039275884628296
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 6.6517229080200195 | KNN Loss: 5.590928554534912 | BCE Loss: 1.0607943534851074
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.660060405731201 | KNN Loss: 5.597643852233887 | BCE Loss: 1.0624165534973145
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.682954788208008 | KNN Loss: 5.612449645996094 | BCE Loss: 1.0705053806304932
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.718941688537598 | KNN Loss: 5.681159019470215 | BCE Loss: 1.037782907485962
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.616495132446289 | KNN Loss: 5.595297336578369 | BCE Loss: 1.0211975574493408
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.6557536125183105 | KNN Loss: 5.62200403213501 | BCE Loss: 1.0337495803833008
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 6.655881881713867 | KNN Loss: 5.6051788330078125 |

Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 6.679064750671387 | KNN Loss: 5.6255292892456055 | BCE Loss: 1.0535354614257812
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 6.640125751495361 | KNN Loss: 5.611093521118164 | BCE Loss: 1.0290322303771973
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.654121398925781 | KNN Loss: 5.597908973693848 | BCE Loss: 1.056212306022644
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.663610935211182 | KNN Loss: 5.62062931060791 | BCE Loss: 1.042981505393982
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.671507358551025 | KNN Loss: 5.643348693847656 | BCE Loss: 1.0281585454940796
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.617877006530762 | KNN Loss: 5.598084926605225 | BCE Loss: 1.019791841506958
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.672032356262207 | KNN Loss: 5.607692241668701 | BCE Loss: 1.0643399953842163
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 6.6544013023376465 | KNN Loss: 5.597583770751953 | B

Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 6.647035598754883 | KNN Loss: 5.5970282554626465 | BCE Loss: 1.0500075817108154
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 6.68233585357666 | KNN Loss: 5.617194175720215 | BCE Loss: 1.0651416778564453
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.690984725952148 | KNN Loss: 5.653266906738281 | BCE Loss: 1.0377180576324463
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.6489577293396 | KNN Loss: 5.59602689743042 | BCE Loss: 1.0529308319091797
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.697073936462402 | KNN Loss: 5.661931037902832 | BCE Loss: 1.0351431369781494
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.632165908813477 | KNN Loss: 5.593078136444092 | BCE Loss: 1.0390878915786743
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.637179851531982 | KNN Loss: 5.594545364379883 | BCE Loss: 1.0426344871520996
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 6.656878471374512 | KNN Loss: 5.594600677490234 | BCE 

Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 6.689516544342041 | KNN Loss: 5.6289472579956055 | BCE Loss: 1.0605692863464355
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 6.746791362762451 | KNN Loss: 5.7048773765563965 | BCE Loss: 1.0419139862060547
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.665771007537842 | KNN Loss: 5.611189842224121 | BCE Loss: 1.0545812845230103
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.712515830993652 | KNN Loss: 5.6839375495910645 | BCE Loss: 1.0285780429840088
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.671202182769775 | KNN Loss: 5.594386100769043 | BCE Loss: 1.0768160820007324
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.713964462280273 | KNN Loss: 5.637327194213867 | BCE Loss: 1.0766371488571167
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.661250114440918 | KNN Loss: 5.639242172241211 | BCE Loss: 1.022007703781128
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 6.665184020996094 | KNN Loss: 5.611840724945068

Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 6.702492713928223 | KNN Loss: 5.662189960479736 | BCE Loss: 1.0403025150299072
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 6.716117858886719 | KNN Loss: 5.6305389404296875 | BCE Loss: 1.0855786800384521
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.6608076095581055 | KNN Loss: 5.601461887359619 | BCE Loss: 1.0593459606170654
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.646828651428223 | KNN Loss: 5.595900058746338 | BCE Loss: 1.0509283542633057
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.735372543334961 | KNN Loss: 5.673603534698486 | BCE Loss: 1.0617690086364746
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.704068660736084 | KNN Loss: 5.630214214324951 | BCE Loss: 1.0738544464111328
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.738284111022949 | KNN Loss: 5.676184177398682 | BCE Loss: 1.062099814414978
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 6.7389068603515625 | KNN Loss: 5.672243595123291

Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 6.693600177764893 | KNN Loss: 5.630528450012207 | BCE Loss: 1.063071608543396
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 6.691220283508301 | KNN Loss: 5.634036064147949 | BCE Loss: 1.0571839809417725
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.711946964263916 | KNN Loss: 5.664770126342773 | BCE Loss: 1.047176718711853
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.671706199645996 | KNN Loss: 5.615844249725342 | BCE Loss: 1.0558621883392334
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.6622633934021 | KNN Loss: 5.617457389831543 | BCE Loss: 1.044805884361267
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.689056396484375 | KNN Loss: 5.633315563201904 | BCE Loss: 1.0557408332824707
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.659138202667236 | KNN Loss: 5.609602928161621 | BCE Loss: 1.0495352745056152
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 6.652096748352051 | KNN Loss: 5.59266471862793 | BCE Los

Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 6.613533020019531 | KNN Loss: 5.592226028442383 | BCE Loss: 1.0213069915771484
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 6.691777229309082 | KNN Loss: 5.679767608642578 | BCE Loss: 1.012009859085083
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.648251533508301 | KNN Loss: 5.597432613372803 | BCE Loss: 1.0508191585540771
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.65736198425293 | KNN Loss: 5.602949142456055 | BCE Loss: 1.054412841796875
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.666622638702393 | KNN Loss: 5.631566524505615 | BCE Loss: 1.035056233406067
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.679257869720459 | KNN Loss: 5.633421421051025 | BCE Loss: 1.0458364486694336
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.6732048988342285 | KNN Loss: 5.608818531036377 | BCE Loss: 1.0643863677978516
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 6.686565399169922 | KNN Loss: 5.632523059844971 | BC

Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 6.648674964904785 | KNN Loss: 5.6248626708984375 | BCE Loss: 1.023812174797058
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 6.650419235229492 | KNN Loss: 5.613167762756348 | BCE Loss: 1.0372512340545654
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.65997314453125 | KNN Loss: 5.597779273986816 | BCE Loss: 1.0621936321258545
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.744662761688232 | KNN Loss: 5.690520763397217 | BCE Loss: 1.0541419982910156
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.638860702514648 | KNN Loss: 5.593068599700928 | BCE Loss: 1.0457919836044312
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.660667419433594 | KNN Loss: 5.603341102600098 | BCE Loss: 1.057326316833496
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.643486022949219 | KNN Loss: 5.606847763061523 | BCE Loss: 1.0366381406784058
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 6.693835258483887 | KNN Loss: 5.638180255889893 | B

Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 6.6820268630981445 | KNN Loss: 5.620161056518555 | BCE Loss: 1.0618658065795898
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 6.637747764587402 | KNN Loss: 5.607336521148682 | BCE Loss: 1.0304112434387207
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.674618721008301 | KNN Loss: 5.632928848266602 | BCE Loss: 1.0416898727416992
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.654331684112549 | KNN Loss: 5.59250020980835 | BCE Loss: 1.0618314743041992
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.641572952270508 | KNN Loss: 5.595075607299805 | BCE Loss: 1.0464974641799927
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.684658050537109 | KNN Loss: 5.6262006759643555 | BCE Loss: 1.058457374572754
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.677923679351807 | KNN Loss: 5.6006622314453125 | BCE Loss: 1.0772615671157837
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 6.7279953956604 | KNN Loss: 5.667718410491943 | BC

Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 6.7066330909729 | KNN Loss: 5.652093887329102 | BCE Loss: 1.0545392036437988
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 6.701410293579102 | KNN Loss: 5.661324501037598 | BCE Loss: 1.0400855541229248
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.634898662567139 | KNN Loss: 5.609177589416504 | BCE Loss: 1.0257210731506348
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.735024929046631 | KNN Loss: 5.679218292236328 | BCE Loss: 1.0558066368103027
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.6408891677856445 | KNN Loss: 5.616164207458496 | BCE Loss: 1.0247249603271484
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.727060794830322 | KNN Loss: 5.638340950012207 | BCE Loss: 1.0887197256088257
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.675785064697266 | KNN Loss: 5.621368885040283 | BCE Loss: 1.054416298866272
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 6.689411640167236 | KNN Loss: 5.639071941375732 | B

Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 6.669587135314941 | KNN Loss: 5.617282867431641 | BCE Loss: 1.0523042678833008
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 6.65382194519043 | KNN Loss: 5.600876331329346 | BCE Loss: 1.052945613861084
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.653424263000488 | KNN Loss: 5.621708869934082 | BCE Loss: 1.0317153930664062
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.665713310241699 | KNN Loss: 5.6117658615112305 | BCE Loss: 1.0539476871490479
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.65416955947876 | KNN Loss: 5.605318546295166 | BCE Loss: 1.0488508939743042
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.677487373352051 | KNN Loss: 5.637402057647705 | BCE Loss: 1.0400854349136353
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.69460916519165 | KNN Loss: 5.632408618927002 | BCE Loss: 1.0622005462646484
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 6.640981674194336 | KNN Loss: 5.60848331451416 | BCE

Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 6.652334213256836 | KNN Loss: 5.625709056854248 | BCE Loss: 1.0266252756118774
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 6.793625831604004 | KNN Loss: 5.735267162322998 | BCE Loss: 1.0583585500717163
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.6278252601623535 | KNN Loss: 5.595922946929932 | BCE Loss: 1.0319023132324219
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.75868034362793 | KNN Loss: 5.689879417419434 | BCE Loss: 1.0688011646270752
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.699810981750488 | KNN Loss: 5.596123695373535 | BCE Loss: 1.1036872863769531
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.683258056640625 | KNN Loss: 5.651554584503174 | BCE Loss: 1.0317037105560303
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.620163917541504 | KNN Loss: 5.601367473602295 | BCE Loss: 1.018796682357788
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 6.659585952758789 | KNN Loss: 5.61619758605957 | BCE

Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 6.652896881103516 | KNN Loss: 5.610687732696533 | BCE Loss: 1.0422090291976929
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 6.658353805541992 | KNN Loss: 5.626745223999023 | BCE Loss: 1.0316088199615479
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.647310733795166 | KNN Loss: 5.6042399406433105 | BCE Loss: 1.043070912361145
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.6379876136779785 | KNN Loss: 5.599636077880859 | BCE Loss: 1.0383514165878296
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.7761430740356445 | KNN Loss: 5.715860366821289 | BCE Loss: 1.060282588005066
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.662109375 | KNN Loss: 5.622045040130615 | BCE Loss: 1.0400643348693848
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.733790874481201 | KNN Loss: 5.675692558288574 | BCE Loss: 1.0580984354019165
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 6.651797294616699 | KNN Loss: 5.597598075866699 | BCE 

Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 6.644322395324707 | KNN Loss: 5.604955673217773 | BCE Loss: 1.0393668413162231
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 6.648681163787842 | KNN Loss: 5.599183559417725 | BCE Loss: 1.0494976043701172
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.654484272003174 | KNN Loss: 5.609583377838135 | BCE Loss: 1.0449007749557495
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.679830551147461 | KNN Loss: 5.626730918884277 | BCE Loss: 1.0530993938446045
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.65690803527832 | KNN Loss: 5.647232532501221 | BCE Loss: 1.0096756219863892
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.722081661224365 | KNN Loss: 5.663464069366455 | BCE Loss: 1.0586174726486206
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.637582778930664 | KNN Loss: 5.607883930206299 | BCE Loss: 1.0296987295150757
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 6.772745132446289 | KNN Loss: 5.732587814331055 | 

Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 6.785067558288574 | KNN Loss: 5.706675052642822 | BCE Loss: 1.0783923864364624
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 6.8033294677734375 | KNN Loss: 5.736708164215088 | BCE Loss: 1.06662118434906
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.664425849914551 | KNN Loss: 5.626609802246094 | BCE Loss: 1.0378159284591675
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.6142425537109375 | KNN Loss: 5.601129055023193 | BCE Loss: 1.0131137371063232
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.698152542114258 | KNN Loss: 5.66974401473999 | BCE Loss: 1.0284085273742676
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.649499416351318 | KNN Loss: 5.591833591461182 | BCE Loss: 1.0576657056808472
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.650694847106934 | KNN Loss: 5.595646381378174 | BCE Loss: 1.0550482273101807
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 6.6485915184021 | KNN Loss: 5.59659481048584 | BCE L

Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 6.674430847167969 | KNN Loss: 5.632808208465576 | BCE Loss: 1.0416227579116821
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 6.726423263549805 | KNN Loss: 5.688623905181885 | BCE Loss: 1.037799596786499
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.682547569274902 | KNN Loss: 5.6457085609436035 | BCE Loss: 1.0368390083312988
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.681992530822754 | KNN Loss: 5.64169454574585 | BCE Loss: 1.0402981042861938
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.680999755859375 | KNN Loss: 5.64704704284668 | BCE Loss: 1.0339527130126953
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.698088645935059 | KNN Loss: 5.653799533843994 | BCE Loss: 1.0442888736724854
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.665309429168701 | KNN Loss: 5.602972030639648 | BCE Loss: 1.0623373985290527
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 6.674775123596191 | KNN Loss: 5.61165714263916 | BC

Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 6.667544364929199 | KNN Loss: 5.611945629119873 | BCE Loss: 1.055598497390747
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 6.681788444519043 | KNN Loss: 5.6136794090271 | BCE Loss: 1.0681092739105225
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.682497978210449 | KNN Loss: 5.637839317321777 | BCE Loss: 1.0446586608886719
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.687449932098389 | KNN Loss: 5.612906455993652 | BCE Loss: 1.0745434761047363
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.674480438232422 | KNN Loss: 5.640607833862305 | BCE Loss: 1.0338727235794067
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.705484390258789 | KNN Loss: 5.678563594818115 | BCE Loss: 1.026921033859253
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.649173736572266 | KNN Loss: 5.591495513916016 | BCE Loss: 1.057677984237671
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 6.700611114501953 | KNN Loss: 5.653326034545898 | BCE 

Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 6.725886344909668 | KNN Loss: 5.654155731201172 | BCE Loss: 1.071730375289917
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 6.694731712341309 | KNN Loss: 5.669042110443115 | BCE Loss: 1.0256898403167725
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.68310546875 | KNN Loss: 5.6468682289123535 | BCE Loss: 1.0362374782562256
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.662860870361328 | KNN Loss: 5.625968933105469 | BCE Loss: 1.036892056465149
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.661364555358887 | KNN Loss: 5.608191013336182 | BCE Loss: 1.0531737804412842
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.633981227874756 | KNN Loss: 5.59464168548584 | BCE Loss: 1.0393396615982056
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.658901214599609 | KNN Loss: 5.598267078399658 | BCE Loss: 1.060633897781372
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 6.655983924865723 | KNN Loss: 5.6329345703125 | BCE Loss: 

Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 6.676930904388428 | KNN Loss: 5.6487202644348145 | BCE Loss: 1.0282105207443237
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 6.688145637512207 | KNN Loss: 5.639692306518555 | BCE Loss: 1.0484533309936523
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.681848526000977 | KNN Loss: 5.622967720031738 | BCE Loss: 1.0588806867599487
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.642606735229492 | KNN Loss: 5.616644382476807 | BCE Loss: 1.0259621143341064
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.659778594970703 | KNN Loss: 5.6185407638549805 | BCE Loss: 1.0412378311157227
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.622081756591797 | KNN Loss: 5.602856636047363 | BCE Loss: 1.0192253589630127
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.61286735534668 | KNN Loss: 5.591809272766113 | BCE Loss: 1.0210579633712769
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 6.680312633514404 | KNN Loss: 5.598023414611816 

Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 6.641079902648926 | KNN Loss: 5.6287102699279785 | BCE Loss: 1.0123693943023682
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 6.644201278686523 | KNN Loss: 5.606711387634277 | BCE Loss: 1.037489652633667
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.706498622894287 | KNN Loss: 5.660494804382324 | BCE Loss: 1.046003818511963
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.643157958984375 | KNN Loss: 5.606011390686035 | BCE Loss: 1.0371466875076294
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.699958801269531 | KNN Loss: 5.663671970367432 | BCE Loss: 1.0362865924835205
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.73884391784668 | KNN Loss: 5.681560039520264 | BCE Loss: 1.0572839975357056
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.711134910583496 | KNN Loss: 5.643615245819092 | BCE Loss: 1.0675196647644043
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 6.686648845672607 | KNN Loss: 5.633330345153809 | B

Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 6.671750068664551 | KNN Loss: 5.6568074226379395 | BCE Loss: 1.0149428844451904
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 6.816885471343994 | KNN Loss: 5.7464599609375 | BCE Loss: 1.0704253911972046
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.6697187423706055 | KNN Loss: 5.591679573059082 | BCE Loss: 1.0780394077301025
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.688450813293457 | KNN Loss: 5.608777046203613 | BCE Loss: 1.0796736478805542
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.66083288192749 | KNN Loss: 5.606150150299072 | BCE Loss: 1.054682731628418
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.629759311676025 | KNN Loss: 5.604778289794922 | BCE Loss: 1.0249810218811035
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 6.653360366821289 | KNN Loss: 5.59627628326416 | BCE Loss: 1.057084083557129
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 6.643348693847656 | KNN Loss: 5.606680870056152 | BCE L

Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 6.697253227233887 | KNN Loss: 5.628466606140137 | BCE Loss: 1.068786859512329
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 6.63754940032959 | KNN Loss: 5.617853164672852 | BCE Loss: 1.0196962356567383
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.67661190032959 | KNN Loss: 5.646628379821777 | BCE Loss: 1.0299835205078125
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.713959693908691 | KNN Loss: 5.652726650238037 | BCE Loss: 1.0612328052520752
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.686912536621094 | KNN Loss: 5.632821559906006 | BCE Loss: 1.054091215133667
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.7007598876953125 | KNN Loss: 5.641837120056152 | BCE Loss: 1.0589227676391602
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.736947059631348 | KNN Loss: 5.665217399597168 | BCE Loss: 1.0717295408248901
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 6.636678695678711 | KNN Loss: 5.590761661529541 | BC

Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 6.671650409698486 | KNN Loss: 5.602357387542725 | BCE Loss: 1.0692931413650513
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 6.641423225402832 | KNN Loss: 5.590216159820557 | BCE Loss: 1.0512068271636963
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.626206874847412 | KNN Loss: 5.591987609863281 | BCE Loss: 1.0342191457748413
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.725613594055176 | KNN Loss: 5.67105770111084 | BCE Loss: 1.0545557737350464
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.614880561828613 | KNN Loss: 5.602985858917236 | BCE Loss: 1.0118944644927979
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.7131028175354 | KNN Loss: 5.665530681610107 | BCE Loss: 1.0475720167160034
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.661980628967285 | KNN Loss: 5.599809646606445 | BCE Loss: 1.0621707439422607
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 6.632842063903809 | KNN Loss: 5.598044395446777 | BC

Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 6.684351921081543 | KNN Loss: 5.649083137512207 | BCE Loss: 1.0352685451507568
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 6.659478664398193 | KNN Loss: 5.603353977203369 | BCE Loss: 1.0561246871948242
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.66899299621582 | KNN Loss: 5.6213250160217285 | BCE Loss: 1.0476679801940918
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.660287857055664 | KNN Loss: 5.603072643280029 | BCE Loss: 1.0572152137756348
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.61958646774292 | KNN Loss: 5.595616340637207 | BCE Loss: 1.0239700078964233
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.686114311218262 | KNN Loss: 5.6205549240112305 | BCE Loss: 1.0655591487884521
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.647724628448486 | KNN Loss: 5.606989860534668 | BCE Loss: 1.0407347679138184
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 6.663987159729004 | KNN Loss: 5.633882522583008 | B

Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 6.70100212097168 | KNN Loss: 5.629997253417969 | BCE Loss: 1.0710049867630005
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 6.667936325073242 | KNN Loss: 5.608964920043945 | BCE Loss: 1.0589715242385864
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.705996513366699 | KNN Loss: 5.622093200683594 | BCE Loss: 1.0839033126831055
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.620012283325195 | KNN Loss: 5.596725940704346 | BCE Loss: 1.0232864618301392
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.694492816925049 | KNN Loss: 5.644543170928955 | BCE Loss: 1.0499496459960938
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.641535758972168 | KNN Loss: 5.601391792297363 | BCE Loss: 1.0401442050933838
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.67543363571167 | KNN Loss: 5.627509117126465 | BCE Loss: 1.0479243993759155
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 6.679417133331299 | KNN Loss: 5.596860408782959 | B

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.8217,  3.6125,  2.5152,  3.3619,  3.3702,  0.6665,  2.4555,  1.9713,
          2.2067,  1.9203,  2.0551,  2.0604,  0.7355,  1.7611,  1.2391,  1.4528,
          2.7113,  3.0813,  2.5865,  2.2356,  1.6618,  2.8783,  2.2024,  2.5269,
          2.3406,  1.6601,  2.0477,  1.3680,  1.4390,  0.3112, -0.2259,  0.9774,
          0.2020,  0.8970,  1.4721,  1.3503,  0.9529,  3.0087,  0.7396,  1.2688,
          0.9522, -0.7081, -0.2568,  2.2572,  2.0168,  0.6907, -0.2034,  0.0830,
          1.2786,  2.3213,  1.7738,  0.1577,  1.3278,  0.5104, -0.6213,  1.0817,
          1.4182,  1.2694,  1.2766,  1.7575,  0.5297,  0.8090,  0.1186,  1.6593,
          1.2432,  1.6040, -1.7948,  0.2945,  2.2321,  2.0777,  2.4627,  0.3653,
          1.2580,  2.2580,  1.9304,  1.2560,  0.2079,  0.7257,  0.2168,  1.5396,
          0.0038,  0.3593,  1.7438, -0.3736,  0.2129, -1.0330, -2.3198, -0.2471,
          0.5458, -1.7872,  0.4219, -0.1311, -0.5523, -0.8504,  0.5359,  1.2067,
         -0.6807, -0.6634,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 79.20it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 030 | Total loss: 9.630 | Reg loss: 0.014 | Tree loss: 9.630 | Accuracy: 0.000000 | 12.32 sec/iter
Epoch: 00 | Batch: 001 / 030 | Total loss: 9.622 | Reg loss: 0.013 | Tree loss: 9.622 | Accuracy: 0.000000 | 9.791 sec/iter
Epoch: 00 | Batch: 002 / 030 | Total loss: 9.615 | Reg loss: 0.012 | Tree loss: 9.615 | Accuracy: 0.000000 | 8.947 sec/iter
Epoch: 00 | Batch: 003 / 030 | Total loss: 9.607 | Reg loss: 0.011 | Tree loss: 9.607 | Accuracy: 0.000000 | 8.54 sec/iter
Epoch: 00 | Batch: 004 / 030 | Total loss: 9.600 | Reg loss: 0.010 | Tree loss: 9.600 | Accuracy: 0.005859 | 8.291 sec/iter
Epoch: 00 | Batch: 005 / 030 | Total loss: 9.592 | Reg loss: 0.010 | Tree loss: 9.592 | Accuracy: 0.019531 | 8.133 sec/iter
Epoch: 00 | Batch: 006 / 030 | Total loss: 9.584 | Reg loss: 0.009 | Tree loss: 9.584 | A

layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 02 | Batch: 000 / 030 | Total loss: 9.430 | Reg loss: 0.009 | Tree loss: 9.430 | Accuracy: 0.578125 | 8.098 sec/iter
Epoch: 02 | Batch: 001 / 030 | Total loss: 9.427 | Reg loss: 0.009 | Tree loss: 9.427 | Accuracy: 0.576172 | 8.081 sec/iter
Epoch: 02 | Batch: 002 / 030 | Total loss: 9.422 | Reg loss: 0.009 | Tree loss: 9.422 | Accuracy: 0.552734 | 8.064 sec/iter
Epoch: 02 | Batch: 003 / 030 | Total loss: 9.416 | Reg loss: 0.010 | Tree loss: 9.416 | Accuracy: 0.593750 | 8.047 sec/iter
Epoch: 02 | Batch: 004 / 030 | Total loss: 9.408 | Reg loss: 0.010 | Tree loss: 9.408 | Accuracy: 0.601562 | 8.031 sec/iter
Epoch: 02 | Batch: 005 / 030 | Total loss: 9.401 | Reg loss: 0.010 | Tree loss: 9.401 | Accuracy: 0.576172 | 7.986 sec/iter
Epoch: 02 | Batch: 006 / 030 | Total loss: 9.396 | Reg loss: 0.010 | Tree loss: 9.396 | Accuracy: 0.585938 | 7.94 sec/iter
Epoch: 02 | Batch: 007 / 030 | Total loss: 9.392

Epoch: 04 | Batch: 000 / 030 | Total loss: 9.251 | Reg loss: 0.015 | Tree loss: 9.251 | Accuracy: 0.560547 | 7.648 sec/iter
Epoch: 04 | Batch: 001 / 030 | Total loss: 9.238 | Reg loss: 0.015 | Tree loss: 9.238 | Accuracy: 0.585938 | 7.626 sec/iter
Epoch: 04 | Batch: 002 / 030 | Total loss: 9.225 | Reg loss: 0.015 | Tree loss: 9.225 | Accuracy: 0.605469 | 7.621 sec/iter
Epoch: 04 | Batch: 003 / 030 | Total loss: 9.219 | Reg loss: 0.015 | Tree loss: 9.219 | Accuracy: 0.570312 | 7.614 sec/iter
Epoch: 04 | Batch: 004 / 030 | Total loss: 9.201 | Reg loss: 0.015 | Tree loss: 9.201 | Accuracy: 0.589844 | 7.595 sec/iter
Epoch: 04 | Batch: 005 / 030 | Total loss: 9.186 | Reg loss: 0.016 | Tree loss: 9.186 | Accuracy: 0.595703 | 7.574 sec/iter
Epoch: 04 | Batch: 006 / 030 | Total loss: 9.179 | Reg loss: 0.016 | Tree loss: 9.179 | Accuracy: 0.564453 | 7.552 sec/iter
Epoch: 04 | Batch: 007 / 030 | Total loss: 9.150 | Reg loss: 0.016 | Tree loss: 9.150 | Accuracy: 0.603516 | 7.53 sec/iter
Epoch: 04

Epoch: 06 | Batch: 001 / 030 | Total loss: 8.816 | Reg loss: 0.021 | Tree loss: 8.816 | Accuracy: 0.578125 | 7.315 sec/iter
Epoch: 06 | Batch: 002 / 030 | Total loss: 8.819 | Reg loss: 0.021 | Tree loss: 8.819 | Accuracy: 0.568359 | 7.314 sec/iter
Epoch: 06 | Batch: 003 / 030 | Total loss: 8.783 | Reg loss: 0.021 | Tree loss: 8.783 | Accuracy: 0.574219 | 7.313 sec/iter
Epoch: 06 | Batch: 004 / 030 | Total loss: 8.783 | Reg loss: 0.021 | Tree loss: 8.783 | Accuracy: 0.572266 | 7.312 sec/iter
Epoch: 06 | Batch: 005 / 030 | Total loss: 8.733 | Reg loss: 0.021 | Tree loss: 8.733 | Accuracy: 0.593750 | 7.312 sec/iter
Epoch: 06 | Batch: 006 / 030 | Total loss: 8.719 | Reg loss: 0.022 | Tree loss: 8.719 | Accuracy: 0.583984 | 7.311 sec/iter
Epoch: 06 | Batch: 007 / 030 | Total loss: 8.683 | Reg loss: 0.022 | Tree loss: 8.683 | Accuracy: 0.580078 | 7.31 sec/iter
Epoch: 06 | Batch: 008 / 030 | Total loss: 8.636 | Reg loss: 0.023 | Tree loss: 8.636 | Accuracy: 0.601562 | 7.308 sec/iter
Epoch: 06

Epoch: 08 | Batch: 002 / 030 | Total loss: 8.262 | Reg loss: 0.025 | Tree loss: 8.262 | Accuracy: 0.570312 | 7.184 sec/iter
Epoch: 08 | Batch: 003 / 030 | Total loss: 8.246 | Reg loss: 0.026 | Tree loss: 8.246 | Accuracy: 0.574219 | 7.183 sec/iter
Epoch: 08 | Batch: 004 / 030 | Total loss: 8.200 | Reg loss: 0.026 | Tree loss: 8.200 | Accuracy: 0.566406 | 7.183 sec/iter
Epoch: 08 | Batch: 005 / 030 | Total loss: 8.150 | Reg loss: 0.026 | Tree loss: 8.150 | Accuracy: 0.593750 | 7.182 sec/iter
Epoch: 08 | Batch: 006 / 030 | Total loss: 8.137 | Reg loss: 0.026 | Tree loss: 8.137 | Accuracy: 0.593750 | 7.182 sec/iter
Epoch: 08 | Batch: 007 / 030 | Total loss: 8.100 | Reg loss: 0.026 | Tree loss: 8.100 | Accuracy: 0.595703 | 7.181 sec/iter
Epoch: 08 | Batch: 008 / 030 | Total loss: 8.053 | Reg loss: 0.027 | Tree loss: 8.053 | Accuracy: 0.560547 | 7.181 sec/iter
Epoch: 08 | Batch: 009 / 030 | Total loss: 8.004 | Reg loss: 0.027 | Tree loss: 8.004 | Accuracy: 0.597656 | 7.18 sec/iter
Epoch: 08

Epoch: 10 | Batch: 003 / 030 | Total loss: 7.610 | Reg loss: 0.029 | Tree loss: 7.610 | Accuracy: 0.589844 | 7.154 sec/iter
Epoch: 10 | Batch: 004 / 030 | Total loss: 7.615 | Reg loss: 0.030 | Tree loss: 7.615 | Accuracy: 0.558594 | 7.152 sec/iter
Epoch: 10 | Batch: 005 / 030 | Total loss: 7.568 | Reg loss: 0.030 | Tree loss: 7.568 | Accuracy: 0.595703 | 7.15 sec/iter
Epoch: 10 | Batch: 006 / 030 | Total loss: 7.514 | Reg loss: 0.030 | Tree loss: 7.514 | Accuracy: 0.597656 | 7.148 sec/iter
Epoch: 10 | Batch: 007 / 030 | Total loss: 7.484 | Reg loss: 0.030 | Tree loss: 7.484 | Accuracy: 0.583984 | 7.147 sec/iter
Epoch: 10 | Batch: 008 / 030 | Total loss: 7.433 | Reg loss: 0.030 | Tree loss: 7.433 | Accuracy: 0.568359 | 7.146 sec/iter
Epoch: 10 | Batch: 009 / 030 | Total loss: 7.397 | Reg loss: 0.031 | Tree loss: 7.397 | Accuracy: 0.578125 | 7.145 sec/iter
Epoch: 10 | Batch: 010 / 030 | Total loss: 7.343 | Reg loss: 0.031 | Tree loss: 7.343 | Accuracy: 0.630859 | 7.144 sec/iter
Epoch: 10

Epoch: 12 | Batch: 004 / 030 | Total loss: 6.945 | Reg loss: 0.033 | Tree loss: 6.945 | Accuracy: 0.599609 | 7.119 sec/iter
Epoch: 12 | Batch: 005 / 030 | Total loss: 6.885 | Reg loss: 0.033 | Tree loss: 6.885 | Accuracy: 0.585938 | 7.117 sec/iter
Epoch: 12 | Batch: 006 / 030 | Total loss: 6.905 | Reg loss: 0.033 | Tree loss: 6.905 | Accuracy: 0.585938 | 7.116 sec/iter
Epoch: 12 | Batch: 007 / 030 | Total loss: 6.851 | Reg loss: 0.034 | Tree loss: 6.851 | Accuracy: 0.576172 | 7.115 sec/iter
Epoch: 12 | Batch: 008 / 030 | Total loss: 6.796 | Reg loss: 0.034 | Tree loss: 6.796 | Accuracy: 0.580078 | 7.114 sec/iter
Epoch: 12 | Batch: 009 / 030 | Total loss: 6.774 | Reg loss: 0.034 | Tree loss: 6.774 | Accuracy: 0.615234 | 7.114 sec/iter
Epoch: 12 | Batch: 010 / 030 | Total loss: 6.765 | Reg loss: 0.034 | Tree loss: 6.765 | Accuracy: 0.554688 | 7.112 sec/iter
Epoch: 12 | Batch: 011 / 030 | Total loss: 6.708 | Reg loss: 0.034 | Tree loss: 6.708 | Accuracy: 0.595703 | 7.111 sec/iter
Epoch: 1

Epoch: 14 | Batch: 005 / 030 | Total loss: 6.333 | Reg loss: 0.036 | Tree loss: 6.333 | Accuracy: 0.576172 | 7.081 sec/iter
Epoch: 14 | Batch: 006 / 030 | Total loss: 6.268 | Reg loss: 0.036 | Tree loss: 6.268 | Accuracy: 0.628906 | 7.08 sec/iter
Epoch: 14 | Batch: 007 / 030 | Total loss: 6.264 | Reg loss: 0.036 | Tree loss: 6.264 | Accuracy: 0.550781 | 7.079 sec/iter
Epoch: 14 | Batch: 008 / 030 | Total loss: 6.236 | Reg loss: 0.036 | Tree loss: 6.236 | Accuracy: 0.580078 | 7.079 sec/iter
Epoch: 14 | Batch: 009 / 030 | Total loss: 6.190 | Reg loss: 0.036 | Tree loss: 6.190 | Accuracy: 0.628906 | 7.079 sec/iter
Epoch: 14 | Batch: 010 / 030 | Total loss: 6.211 | Reg loss: 0.036 | Tree loss: 6.211 | Accuracy: 0.582031 | 7.079 sec/iter
Epoch: 14 | Batch: 011 / 030 | Total loss: 6.171 | Reg loss: 0.037 | Tree loss: 6.171 | Accuracy: 0.603516 | 7.08 sec/iter
Epoch: 14 | Batch: 012 / 030 | Total loss: 6.097 | Reg loss: 0.037 | Tree loss: 6.097 | Accuracy: 0.578125 | 7.08 sec/iter
Epoch: 14 |

Epoch: 16 | Batch: 006 / 030 | Total loss: 5.722 | Reg loss: 0.037 | Tree loss: 5.722 | Accuracy: 0.609375 | 7.154 sec/iter
Epoch: 16 | Batch: 007 / 030 | Total loss: 5.691 | Reg loss: 0.037 | Tree loss: 5.691 | Accuracy: 0.578125 | 7.149 sec/iter
Epoch: 16 | Batch: 008 / 030 | Total loss: 5.689 | Reg loss: 0.038 | Tree loss: 5.689 | Accuracy: 0.560547 | 7.145 sec/iter
Epoch: 16 | Batch: 009 / 030 | Total loss: 5.631 | Reg loss: 0.038 | Tree loss: 5.631 | Accuracy: 0.615234 | 7.14 sec/iter
Epoch: 16 | Batch: 010 / 030 | Total loss: 5.648 | Reg loss: 0.038 | Tree loss: 5.648 | Accuracy: 0.554688 | 7.135 sec/iter
Epoch: 16 | Batch: 011 / 030 | Total loss: 5.585 | Reg loss: 0.038 | Tree loss: 5.585 | Accuracy: 0.603516 | 7.135 sec/iter
Epoch: 16 | Batch: 012 / 030 | Total loss: 5.566 | Reg loss: 0.038 | Tree loss: 5.566 | Accuracy: 0.576172 | 7.135 sec/iter
Epoch: 16 | Batch: 013 / 030 | Total loss: 5.549 | Reg loss: 0.038 | Tree loss: 5.549 | Accuracy: 0.583984 | 7.135 sec/iter
Epoch: 16

Epoch: 18 | Batch: 007 / 030 | Total loss: 5.157 | Reg loss: 0.039 | Tree loss: 5.157 | Accuracy: 0.593750 | 7.096 sec/iter
Epoch: 18 | Batch: 008 / 030 | Total loss: 5.148 | Reg loss: 0.039 | Tree loss: 5.148 | Accuracy: 0.572266 | 7.096 sec/iter
Epoch: 18 | Batch: 009 / 030 | Total loss: 5.091 | Reg loss: 0.039 | Tree loss: 5.091 | Accuracy: 0.597656 | 7.096 sec/iter
Epoch: 18 | Batch: 010 / 030 | Total loss: 5.097 | Reg loss: 0.039 | Tree loss: 5.097 | Accuracy: 0.564453 | 7.096 sec/iter
Epoch: 18 | Batch: 011 / 030 | Total loss: 5.032 | Reg loss: 0.039 | Tree loss: 5.032 | Accuracy: 0.582031 | 7.096 sec/iter
Epoch: 18 | Batch: 012 / 030 | Total loss: 5.001 | Reg loss: 0.039 | Tree loss: 5.001 | Accuracy: 0.587891 | 7.096 sec/iter
Epoch: 18 | Batch: 013 / 030 | Total loss: 5.034 | Reg loss: 0.039 | Tree loss: 5.034 | Accuracy: 0.535156 | 7.096 sec/iter
Epoch: 18 | Batch: 014 / 030 | Total loss: 4.948 | Reg loss: 0.040 | Tree loss: 4.948 | Accuracy: 0.617188 | 7.095 sec/iter
Epoch: 1

Epoch: 20 | Batch: 008 / 030 | Total loss: 4.626 | Reg loss: 0.040 | Tree loss: 4.626 | Accuracy: 0.587891 | 7.053 sec/iter
Epoch: 20 | Batch: 009 / 030 | Total loss: 4.627 | Reg loss: 0.040 | Tree loss: 4.627 | Accuracy: 0.564453 | 7.053 sec/iter
Epoch: 20 | Batch: 010 / 030 | Total loss: 4.542 | Reg loss: 0.040 | Tree loss: 4.542 | Accuracy: 0.621094 | 7.052 sec/iter
Epoch: 20 | Batch: 011 / 030 | Total loss: 4.518 | Reg loss: 0.040 | Tree loss: 4.518 | Accuracy: 0.609375 | 7.052 sec/iter
Epoch: 20 | Batch: 012 / 030 | Total loss: 4.541 | Reg loss: 0.040 | Tree loss: 4.541 | Accuracy: 0.556641 | 7.051 sec/iter
Epoch: 20 | Batch: 013 / 030 | Total loss: 4.448 | Reg loss: 0.040 | Tree loss: 4.448 | Accuracy: 0.634766 | 7.051 sec/iter
Epoch: 20 | Batch: 014 / 030 | Total loss: 4.466 | Reg loss: 0.041 | Tree loss: 4.466 | Accuracy: 0.583984 | 7.051 sec/iter
Epoch: 20 | Batch: 015 / 030 | Total loss: 4.440 | Reg loss: 0.041 | Tree loss: 4.440 | Accuracy: 0.583984 | 7.05 sec/iter
Epoch: 20

Epoch: 22 | Batch: 009 / 030 | Total loss: 4.097 | Reg loss: 0.041 | Tree loss: 4.097 | Accuracy: 0.585938 | 7.137 sec/iter
Epoch: 22 | Batch: 010 / 030 | Total loss: 4.058 | Reg loss: 0.041 | Tree loss: 4.058 | Accuracy: 0.576172 | 7.137 sec/iter
Epoch: 22 | Batch: 011 / 030 | Total loss: 4.059 | Reg loss: 0.041 | Tree loss: 4.059 | Accuracy: 0.582031 | 7.137 sec/iter
Epoch: 22 | Batch: 012 / 030 | Total loss: 4.048 | Reg loss: 0.041 | Tree loss: 4.048 | Accuracy: 0.544922 | 7.137 sec/iter
Epoch: 22 | Batch: 013 / 030 | Total loss: 3.995 | Reg loss: 0.041 | Tree loss: 3.995 | Accuracy: 0.587891 | 7.138 sec/iter
Epoch: 22 | Batch: 014 / 030 | Total loss: 3.959 | Reg loss: 0.041 | Tree loss: 3.959 | Accuracy: 0.626953 | 7.138 sec/iter
Epoch: 22 | Batch: 015 / 030 | Total loss: 3.938 | Reg loss: 0.042 | Tree loss: 3.938 | Accuracy: 0.599609 | 7.138 sec/iter
Epoch: 22 | Batch: 016 / 030 | Total loss: 3.924 | Reg loss: 0.042 | Tree loss: 3.924 | Accuracy: 0.603516 | 7.138 sec/iter
Epoch: 2

Epoch: 24 | Batch: 010 / 030 | Total loss: 3.671 | Reg loss: 0.042 | Tree loss: 3.671 | Accuracy: 0.578125 | 7.125 sec/iter
Epoch: 24 | Batch: 011 / 030 | Total loss: 3.594 | Reg loss: 0.042 | Tree loss: 3.594 | Accuracy: 0.597656 | 7.121 sec/iter
Epoch: 24 | Batch: 012 / 030 | Total loss: 3.597 | Reg loss: 0.042 | Tree loss: 3.597 | Accuracy: 0.583984 | 7.118 sec/iter
Epoch: 24 | Batch: 013 / 030 | Total loss: 3.526 | Reg loss: 0.042 | Tree loss: 3.526 | Accuracy: 0.605469 | 7.115 sec/iter
Epoch: 24 | Batch: 014 / 030 | Total loss: 3.523 | Reg loss: 0.043 | Tree loss: 3.523 | Accuracy: 0.603516 | 7.112 sec/iter
Epoch: 24 | Batch: 015 / 030 | Total loss: 3.504 | Reg loss: 0.043 | Tree loss: 3.504 | Accuracy: 0.574219 | 7.109 sec/iter
Epoch: 24 | Batch: 016 / 030 | Total loss: 3.514 | Reg loss: 0.043 | Tree loss: 3.514 | Accuracy: 0.552734 | 7.109 sec/iter
Epoch: 24 | Batch: 017 / 030 | Total loss: 3.459 | Reg loss: 0.043 | Tree loss: 3.459 | Accuracy: 0.566406 | 7.109 sec/iter
Epoch: 2

Epoch: 26 | Batch: 011 / 030 | Total loss: 3.247 | Reg loss: 0.043 | Tree loss: 3.247 | Accuracy: 0.574219 | 7.107 sec/iter
Epoch: 26 | Batch: 012 / 030 | Total loss: 3.225 | Reg loss: 0.043 | Tree loss: 3.225 | Accuracy: 0.568359 | 7.107 sec/iter
Epoch: 26 | Batch: 013 / 030 | Total loss: 3.182 | Reg loss: 0.043 | Tree loss: 3.182 | Accuracy: 0.576172 | 7.106 sec/iter
Epoch: 26 | Batch: 014 / 030 | Total loss: 3.160 | Reg loss: 0.044 | Tree loss: 3.160 | Accuracy: 0.574219 | 7.105 sec/iter
Epoch: 26 | Batch: 015 / 030 | Total loss: 3.165 | Reg loss: 0.044 | Tree loss: 3.165 | Accuracy: 0.564453 | 7.102 sec/iter
Epoch: 26 | Batch: 016 / 030 | Total loss: 3.111 | Reg loss: 0.044 | Tree loss: 3.111 | Accuracy: 0.574219 | 7.099 sec/iter
Epoch: 26 | Batch: 017 / 030 | Total loss: 3.085 | Reg loss: 0.044 | Tree loss: 3.085 | Accuracy: 0.562500 | 7.096 sec/iter
Epoch: 26 | Batch: 018 / 030 | Total loss: 3.063 | Reg loss: 0.044 | Tree loss: 3.063 | Accuracy: 0.566406 | 7.093 sec/iter
Epoch: 2

Epoch: 28 | Batch: 012 / 030 | Total loss: 2.888 | Reg loss: 0.044 | Tree loss: 2.888 | Accuracy: 0.617188 | 7.056 sec/iter
Epoch: 28 | Batch: 013 / 030 | Total loss: 2.885 | Reg loss: 0.044 | Tree loss: 2.885 | Accuracy: 0.585938 | 7.056 sec/iter
Epoch: 28 | Batch: 014 / 030 | Total loss: 2.882 | Reg loss: 0.044 | Tree loss: 2.882 | Accuracy: 0.542969 | 7.056 sec/iter
Epoch: 28 | Batch: 015 / 030 | Total loss: 2.826 | Reg loss: 0.044 | Tree loss: 2.826 | Accuracy: 0.583984 | 7.056 sec/iter
Epoch: 28 | Batch: 016 / 030 | Total loss: 2.812 | Reg loss: 0.045 | Tree loss: 2.812 | Accuracy: 0.560547 | 7.056 sec/iter
Epoch: 28 | Batch: 017 / 030 | Total loss: 2.748 | Reg loss: 0.045 | Tree loss: 2.748 | Accuracy: 0.568359 | 7.055 sec/iter
Epoch: 28 | Batch: 018 / 030 | Total loss: 2.731 | Reg loss: 0.045 | Tree loss: 2.731 | Accuracy: 0.582031 | 7.055 sec/iter
Epoch: 28 | Batch: 019 / 030 | Total loss: 2.685 | Reg loss: 0.045 | Tree loss: 2.685 | Accuracy: 0.625000 | 7.055 sec/iter
Epoch: 2

Epoch: 30 | Batch: 013 / 030 | Total loss: 2.593 | Reg loss: 0.045 | Tree loss: 2.593 | Accuracy: 0.611328 | 7.09 sec/iter
Epoch: 30 | Batch: 014 / 030 | Total loss: 2.590 | Reg loss: 0.045 | Tree loss: 2.590 | Accuracy: 0.515625 | 7.089 sec/iter
Epoch: 30 | Batch: 015 / 030 | Total loss: 2.514 | Reg loss: 0.045 | Tree loss: 2.514 | Accuracy: 0.619141 | 7.088 sec/iter
Epoch: 30 | Batch: 016 / 030 | Total loss: 2.503 | Reg loss: 0.045 | Tree loss: 2.503 | Accuracy: 0.576172 | 7.085 sec/iter
Epoch: 30 | Batch: 017 / 030 | Total loss: 2.488 | Reg loss: 0.045 | Tree loss: 2.488 | Accuracy: 0.597656 | 7.085 sec/iter
Epoch: 30 | Batch: 018 / 030 | Total loss: 2.451 | Reg loss: 0.045 | Tree loss: 2.451 | Accuracy: 0.607422 | 7.085 sec/iter
Epoch: 30 | Batch: 019 / 030 | Total loss: 2.489 | Reg loss: 0.045 | Tree loss: 2.489 | Accuracy: 0.544922 | 7.085 sec/iter
Epoch: 30 | Batch: 020 / 030 | Total loss: 2.427 | Reg loss: 0.045 | Tree loss: 2.427 | Accuracy: 0.578125 | 7.084 sec/iter
Epoch: 30

Epoch: 32 | Batch: 014 / 030 | Total loss: 2.334 | Reg loss: 0.045 | Tree loss: 2.334 | Accuracy: 0.593750 | 7.131 sec/iter
Epoch: 32 | Batch: 015 / 030 | Total loss: 2.317 | Reg loss: 0.045 | Tree loss: 2.317 | Accuracy: 0.593750 | 7.131 sec/iter
Epoch: 32 | Batch: 016 / 030 | Total loss: 2.322 | Reg loss: 0.045 | Tree loss: 2.322 | Accuracy: 0.574219 | 7.131 sec/iter
Epoch: 32 | Batch: 017 / 030 | Total loss: 2.276 | Reg loss: 0.045 | Tree loss: 2.276 | Accuracy: 0.558594 | 7.131 sec/iter
Epoch: 32 | Batch: 018 / 030 | Total loss: 2.226 | Reg loss: 0.045 | Tree loss: 2.226 | Accuracy: 0.621094 | 7.131 sec/iter
Epoch: 32 | Batch: 019 / 030 | Total loss: 2.227 | Reg loss: 0.045 | Tree loss: 2.227 | Accuracy: 0.621094 | 7.131 sec/iter
Epoch: 32 | Batch: 020 / 030 | Total loss: 2.198 | Reg loss: 0.046 | Tree loss: 2.198 | Accuracy: 0.574219 | 7.131 sec/iter
Epoch: 32 | Batch: 021 / 030 | Total loss: 2.226 | Reg loss: 0.046 | Tree loss: 2.226 | Accuracy: 0.566406 | 7.13 sec/iter
Epoch: 32

Epoch: 34 | Batch: 015 / 030 | Total loss: 2.111 | Reg loss: 0.045 | Tree loss: 2.111 | Accuracy: 0.585938 | 7.095 sec/iter
Epoch: 34 | Batch: 016 / 030 | Total loss: 2.105 | Reg loss: 0.045 | Tree loss: 2.105 | Accuracy: 0.558594 | 7.093 sec/iter
Epoch: 34 | Batch: 017 / 030 | Total loss: 2.067 | Reg loss: 0.045 | Tree loss: 2.067 | Accuracy: 0.597656 | 7.091 sec/iter
Epoch: 34 | Batch: 018 / 030 | Total loss: 2.081 | Reg loss: 0.045 | Tree loss: 2.081 | Accuracy: 0.546875 | 7.088 sec/iter
Epoch: 34 | Batch: 019 / 030 | Total loss: 2.031 | Reg loss: 0.045 | Tree loss: 2.031 | Accuracy: 0.574219 | 7.086 sec/iter
Epoch: 34 | Batch: 020 / 030 | Total loss: 2.042 | Reg loss: 0.046 | Tree loss: 2.042 | Accuracy: 0.566406 | 7.084 sec/iter
Epoch: 34 | Batch: 021 / 030 | Total loss: 1.996 | Reg loss: 0.046 | Tree loss: 1.996 | Accuracy: 0.611328 | 7.082 sec/iter
Epoch: 34 | Batch: 022 / 030 | Total loss: 1.993 | Reg loss: 0.046 | Tree loss: 1.993 | Accuracy: 0.580078 | 7.082 sec/iter
Epoch: 3

Epoch: 36 | Batch: 016 / 030 | Total loss: 1.930 | Reg loss: 0.045 | Tree loss: 1.930 | Accuracy: 0.583984 | 7.092 sec/iter
Epoch: 36 | Batch: 017 / 030 | Total loss: 1.905 | Reg loss: 0.045 | Tree loss: 1.905 | Accuracy: 0.568359 | 7.091 sec/iter
Epoch: 36 | Batch: 018 / 030 | Total loss: 1.900 | Reg loss: 0.045 | Tree loss: 1.900 | Accuracy: 0.576172 | 7.089 sec/iter
Epoch: 36 | Batch: 019 / 030 | Total loss: 1.884 | Reg loss: 0.045 | Tree loss: 1.884 | Accuracy: 0.599609 | 7.086 sec/iter
Epoch: 36 | Batch: 020 / 030 | Total loss: 1.842 | Reg loss: 0.045 | Tree loss: 1.842 | Accuracy: 0.585938 | 7.085 sec/iter
Epoch: 36 | Batch: 021 / 030 | Total loss: 1.875 | Reg loss: 0.045 | Tree loss: 1.875 | Accuracy: 0.580078 | 7.083 sec/iter
Epoch: 36 | Batch: 022 / 030 | Total loss: 1.843 | Reg loss: 0.046 | Tree loss: 1.843 | Accuracy: 0.587891 | 7.083 sec/iter
Epoch: 36 | Batch: 023 / 030 | Total loss: 1.814 | Reg loss: 0.046 | Tree loss: 1.814 | Accuracy: 0.572266 | 7.082 sec/iter
Epoch: 3

Epoch: 38 | Batch: 017 / 030 | Total loss: 1.776 | Reg loss: 0.045 | Tree loss: 1.776 | Accuracy: 0.552734 | 7.087 sec/iter
Epoch: 38 | Batch: 018 / 030 | Total loss: 1.756 | Reg loss: 0.045 | Tree loss: 1.756 | Accuracy: 0.601562 | 7.087 sec/iter
Epoch: 38 | Batch: 019 / 030 | Total loss: 1.762 | Reg loss: 0.045 | Tree loss: 1.762 | Accuracy: 0.578125 | 7.087 sec/iter
Epoch: 38 | Batch: 020 / 030 | Total loss: 1.717 | Reg loss: 0.045 | Tree loss: 1.717 | Accuracy: 0.595703 | 7.087 sec/iter
Epoch: 38 | Batch: 021 / 030 | Total loss: 1.688 | Reg loss: 0.045 | Tree loss: 1.688 | Accuracy: 0.599609 | 7.088 sec/iter
Epoch: 38 | Batch: 022 / 030 | Total loss: 1.716 | Reg loss: 0.045 | Tree loss: 1.716 | Accuracy: 0.580078 | 7.088 sec/iter
Epoch: 38 | Batch: 023 / 030 | Total loss: 1.667 | Reg loss: 0.045 | Tree loss: 1.667 | Accuracy: 0.591797 | 7.088 sec/iter
Epoch: 38 | Batch: 024 / 030 | Total loss: 1.667 | Reg loss: 0.046 | Tree loss: 1.667 | Accuracy: 0.595703 | 7.088 sec/iter
Epoch: 3

Epoch: 40 | Batch: 018 / 030 | Total loss: 1.614 | Reg loss: 0.045 | Tree loss: 1.614 | Accuracy: 0.603516 | 7.122 sec/iter
Epoch: 40 | Batch: 019 / 030 | Total loss: 1.587 | Reg loss: 0.045 | Tree loss: 1.587 | Accuracy: 0.595703 | 7.122 sec/iter
Epoch: 40 | Batch: 020 / 030 | Total loss: 1.621 | Reg loss: 0.045 | Tree loss: 1.621 | Accuracy: 0.541016 | 7.122 sec/iter
Epoch: 40 | Batch: 021 / 030 | Total loss: 1.608 | Reg loss: 0.045 | Tree loss: 1.608 | Accuracy: 0.552734 | 7.121 sec/iter
Epoch: 40 | Batch: 022 / 030 | Total loss: 1.590 | Reg loss: 0.045 | Tree loss: 1.590 | Accuracy: 0.568359 | 7.121 sec/iter
Epoch: 40 | Batch: 023 / 030 | Total loss: 1.555 | Reg loss: 0.045 | Tree loss: 1.555 | Accuracy: 0.583984 | 7.121 sec/iter
Epoch: 40 | Batch: 024 / 030 | Total loss: 1.580 | Reg loss: 0.045 | Tree loss: 1.580 | Accuracy: 0.574219 | 7.121 sec/iter
Epoch: 40 | Batch: 025 / 030 | Total loss: 1.522 | Reg loss: 0.045 | Tree loss: 1.522 | Accuracy: 0.621094 | 7.121 sec/iter
Epoch: 4

Epoch: 42 | Batch: 019 / 030 | Total loss: 1.502 | Reg loss: 0.045 | Tree loss: 1.502 | Accuracy: 0.568359 | 7.118 sec/iter
Epoch: 42 | Batch: 020 / 030 | Total loss: 1.475 | Reg loss: 0.045 | Tree loss: 1.475 | Accuracy: 0.623047 | 7.119 sec/iter
Epoch: 42 | Batch: 021 / 030 | Total loss: 1.465 | Reg loss: 0.045 | Tree loss: 1.465 | Accuracy: 0.585938 | 7.119 sec/iter
Epoch: 42 | Batch: 022 / 030 | Total loss: 1.485 | Reg loss: 0.045 | Tree loss: 1.485 | Accuracy: 0.566406 | 7.119 sec/iter
Epoch: 42 | Batch: 023 / 030 | Total loss: 1.462 | Reg loss: 0.045 | Tree loss: 1.462 | Accuracy: 0.585938 | 7.119 sec/iter
Epoch: 42 | Batch: 024 / 030 | Total loss: 1.433 | Reg loss: 0.045 | Tree loss: 1.433 | Accuracy: 0.560547 | 7.119 sec/iter
Epoch: 42 | Batch: 025 / 030 | Total loss: 1.461 | Reg loss: 0.046 | Tree loss: 1.461 | Accuracy: 0.525391 | 7.119 sec/iter
Epoch: 42 | Batch: 026 / 030 | Total loss: 1.425 | Reg loss: 0.046 | Tree loss: 1.425 | Accuracy: 0.585938 | 7.119 sec/iter
Epoch: 4

Epoch: 44 | Batch: 020 / 030 | Total loss: 1.406 | Reg loss: 0.045 | Tree loss: 1.406 | Accuracy: 0.562500 | 7.114 sec/iter
Epoch: 44 | Batch: 021 / 030 | Total loss: 1.378 | Reg loss: 0.045 | Tree loss: 1.378 | Accuracy: 0.582031 | 7.114 sec/iter
Epoch: 44 | Batch: 022 / 030 | Total loss: 1.365 | Reg loss: 0.045 | Tree loss: 1.365 | Accuracy: 0.568359 | 7.114 sec/iter
Epoch: 44 | Batch: 023 / 030 | Total loss: 1.332 | Reg loss: 0.045 | Tree loss: 1.332 | Accuracy: 0.630859 | 7.114 sec/iter
Epoch: 44 | Batch: 024 / 030 | Total loss: 1.347 | Reg loss: 0.045 | Tree loss: 1.347 | Accuracy: 0.566406 | 7.113 sec/iter
Epoch: 44 | Batch: 025 / 030 | Total loss: 1.333 | Reg loss: 0.045 | Tree loss: 1.333 | Accuracy: 0.582031 | 7.113 sec/iter
Epoch: 44 | Batch: 026 / 030 | Total loss: 1.340 | Reg loss: 0.045 | Tree loss: 1.340 | Accuracy: 0.566406 | 7.112 sec/iter
Epoch: 44 | Batch: 027 / 030 | Total loss: 1.328 | Reg loss: 0.046 | Tree loss: 1.328 | Accuracy: 0.537109 | 7.11 sec/iter
Epoch: 44

Epoch: 46 | Batch: 021 / 030 | Total loss: 1.285 | Reg loss: 0.045 | Tree loss: 1.285 | Accuracy: 0.589844 | 7.11 sec/iter
Epoch: 46 | Batch: 022 / 030 | Total loss: 1.287 | Reg loss: 0.045 | Tree loss: 1.287 | Accuracy: 0.568359 | 7.11 sec/iter
Epoch: 46 | Batch: 023 / 030 | Total loss: 1.288 | Reg loss: 0.045 | Tree loss: 1.288 | Accuracy: 0.583984 | 7.109 sec/iter
Epoch: 46 | Batch: 024 / 030 | Total loss: 1.266 | Reg loss: 0.045 | Tree loss: 1.266 | Accuracy: 0.564453 | 7.107 sec/iter
Epoch: 46 | Batch: 025 / 030 | Total loss: 1.262 | Reg loss: 0.045 | Tree loss: 1.262 | Accuracy: 0.580078 | 7.106 sec/iter
Epoch: 46 | Batch: 026 / 030 | Total loss: 1.251 | Reg loss: 0.045 | Tree loss: 1.251 | Accuracy: 0.585938 | 7.104 sec/iter
Epoch: 46 | Batch: 027 / 030 | Total loss: 1.246 | Reg loss: 0.045 | Tree loss: 1.246 | Accuracy: 0.572266 | 7.102 sec/iter
Epoch: 46 | Batch: 028 / 030 | Total loss: 1.248 | Reg loss: 0.045 | Tree loss: 1.248 | Accuracy: 0.578125 | 7.101 sec/iter
Epoch: 46 

Epoch: 48 | Batch: 022 / 030 | Total loss: 1.250 | Reg loss: 0.045 | Tree loss: 1.250 | Accuracy: 0.562500 | 7.094 sec/iter
Epoch: 48 | Batch: 023 / 030 | Total loss: 1.191 | Reg loss: 0.045 | Tree loss: 1.191 | Accuracy: 0.632812 | 7.094 sec/iter
Epoch: 48 | Batch: 024 / 030 | Total loss: 1.201 | Reg loss: 0.045 | Tree loss: 1.201 | Accuracy: 0.605469 | 7.094 sec/iter
Epoch: 48 | Batch: 025 / 030 | Total loss: 1.225 | Reg loss: 0.045 | Tree loss: 1.225 | Accuracy: 0.574219 | 7.094 sec/iter
Epoch: 48 | Batch: 026 / 030 | Total loss: 1.205 | Reg loss: 0.045 | Tree loss: 1.205 | Accuracy: 0.570312 | 7.093 sec/iter
Epoch: 48 | Batch: 027 / 030 | Total loss: 1.175 | Reg loss: 0.045 | Tree loss: 1.175 | Accuracy: 0.591797 | 7.092 sec/iter
Epoch: 48 | Batch: 028 / 030 | Total loss: 1.169 | Reg loss: 0.045 | Tree loss: 1.169 | Accuracy: 0.587891 | 7.09 sec/iter
Epoch: 48 | Batch: 029 / 030 | Total loss: 1.119 | Reg loss: 0.045 | Tree loss: 1.119 | Accuracy: 0.631068 | 7.087 sec/iter
Average s

Epoch: 50 | Batch: 023 / 030 | Total loss: 1.192 | Reg loss: 0.044 | Tree loss: 1.192 | Accuracy: 0.554688 | 7.09 sec/iter
Epoch: 50 | Batch: 024 / 030 | Total loss: 1.188 | Reg loss: 0.044 | Tree loss: 1.188 | Accuracy: 0.539062 | 7.088 sec/iter
Epoch: 50 | Batch: 025 / 030 | Total loss: 1.176 | Reg loss: 0.044 | Tree loss: 1.176 | Accuracy: 0.548828 | 7.087 sec/iter
Epoch: 50 | Batch: 026 / 030 | Total loss: 1.143 | Reg loss: 0.044 | Tree loss: 1.143 | Accuracy: 0.593750 | 7.085 sec/iter
Epoch: 50 | Batch: 027 / 030 | Total loss: 1.130 | Reg loss: 0.044 | Tree loss: 1.130 | Accuracy: 0.597656 | 7.084 sec/iter
Epoch: 50 | Batch: 028 / 030 | Total loss: 1.133 | Reg loss: 0.044 | Tree loss: 1.133 | Accuracy: 0.570312 | 7.082 sec/iter
Epoch: 50 | Batch: 029 / 030 | Total loss: 1.095 | Reg loss: 0.045 | Tree loss: 1.095 | Accuracy: 0.592233 | 7.079 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0

Epoch: 52 | Batch: 024 / 030 | Total loss: 1.121 | Reg loss: 0.044 | Tree loss: 1.121 | Accuracy: 0.568359 | 7.084 sec/iter
Epoch: 52 | Batch: 025 / 030 | Total loss: 1.118 | Reg loss: 0.044 | Tree loss: 1.118 | Accuracy: 0.589844 | 7.083 sec/iter
Epoch: 52 | Batch: 026 / 030 | Total loss: 1.091 | Reg loss: 0.044 | Tree loss: 1.091 | Accuracy: 0.597656 | 7.082 sec/iter
Epoch: 52 | Batch: 027 / 030 | Total loss: 1.098 | Reg loss: 0.044 | Tree loss: 1.098 | Accuracy: 0.560547 | 7.082 sec/iter
Epoch: 52 | Batch: 028 / 030 | Total loss: 1.074 | Reg loss: 0.044 | Tree loss: 1.074 | Accuracy: 0.603516 | 7.082 sec/iter
Epoch: 52 | Batch: 029 / 030 | Total loss: 1.050 | Reg loss: 0.044 | Tree loss: 1.050 | Accuracy: 0.631068 | 7.08 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.982142857142

Epoch: 54 | Batch: 025 / 030 | Total loss: 1.073 | Reg loss: 0.044 | Tree loss: 1.073 | Accuracy: 0.583984 | 7.074 sec/iter
Epoch: 54 | Batch: 026 / 030 | Total loss: 1.042 | Reg loss: 0.044 | Tree loss: 1.042 | Accuracy: 0.613281 | 7.073 sec/iter
Epoch: 54 | Batch: 027 / 030 | Total loss: 1.072 | Reg loss: 0.044 | Tree loss: 1.072 | Accuracy: 0.544922 | 7.073 sec/iter
Epoch: 54 | Batch: 028 / 030 | Total loss: 1.053 | Reg loss: 0.044 | Tree loss: 1.053 | Accuracy: 0.566406 | 7.073 sec/iter
Epoch: 54 | Batch: 029 / 030 | Total loss: 1.053 | Reg loss: 0.044 | Tree loss: 1.053 | Accuracy: 0.611650 | 7.071 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 55 | Batch: 000 / 030 | To

Epoch: 56 | Batch: 026 / 030 | Total loss: 1.030 | Reg loss: 0.044 | Tree loss: 1.030 | Accuracy: 0.626953 | 7.072 sec/iter
Epoch: 56 | Batch: 027 / 030 | Total loss: 1.043 | Reg loss: 0.044 | Tree loss: 1.043 | Accuracy: 0.558594 | 7.072 sec/iter
Epoch: 56 | Batch: 028 / 030 | Total loss: 1.038 | Reg loss: 0.044 | Tree loss: 1.038 | Accuracy: 0.550781 | 7.072 sec/iter
Epoch: 56 | Batch: 029 / 030 | Total loss: 1.037 | Reg loss: 0.044 | Tree loss: 1.037 | Accuracy: 0.543689 | 7.07 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 57 | Batch: 000 / 030 | Total loss: 1.377 | Reg loss: 0.042 | Tree loss: 1.377 | Accuracy: 0.597656 | 7.087 sec/iter
Epoch: 57 | Batch: 001 / 030 | Tot

Epoch: 58 | Batch: 027 / 030 | Total loss: 1.015 | Reg loss: 0.043 | Tree loss: 1.015 | Accuracy: 0.601562 | 7.078 sec/iter
Epoch: 58 | Batch: 028 / 030 | Total loss: 0.988 | Reg loss: 0.044 | Tree loss: 0.988 | Accuracy: 0.621094 | 7.078 sec/iter
Epoch: 58 | Batch: 029 / 030 | Total loss: 1.039 | Reg loss: 0.044 | Tree loss: 1.039 | Accuracy: 0.553398 | 7.076 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 59 | Batch: 000 / 030 | Total loss: 1.342 | Reg loss: 0.042 | Tree loss: 1.342 | Accuracy: 0.560547 | 7.101 sec/iter
Epoch: 59 | Batch: 001 / 030 | Total loss: 1.339 | Reg loss: 0.042 | Tree loss: 1.339 | Accuracy: 0.578125 | 7.101 sec/iter
Epoch: 59 | Batch: 002 / 030 | To

Epoch: 60 | Batch: 028 / 030 | Total loss: 0.989 | Reg loss: 0.043 | Tree loss: 0.989 | Accuracy: 0.548828 | 7.096 sec/iter
Epoch: 60 | Batch: 029 / 030 | Total loss: 0.973 | Reg loss: 0.043 | Tree loss: 0.973 | Accuracy: 0.592233 | 7.094 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 61 | Batch: 000 / 030 | Total loss: 1.333 | Reg loss: 0.042 | Tree loss: 1.333 | Accuracy: 0.570312 | 7.11 sec/iter
Epoch: 61 | Batch: 001 / 030 | Total loss: 1.350 | Reg loss: 0.042 | Tree loss: 1.350 | Accuracy: 0.550781 | 7.109 sec/iter
Epoch: 61 | Batch: 002 / 030 | Total loss: 1.343 | Reg loss: 0.042 | Tree loss: 1.343 | Accuracy: 0.558594 | 7.107 sec/iter
Epoch: 61 | Batch: 003 / 030 | Tot

Epoch: 62 | Batch: 029 / 030 | Total loss: 0.971 | Reg loss: 0.043 | Tree loss: 0.971 | Accuracy: 0.563107 | 7.102 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 63 | Batch: 000 / 030 | Total loss: 1.366 | Reg loss: 0.041 | Tree loss: 1.366 | Accuracy: 0.566406 | 7.107 sec/iter
Epoch: 63 | Batch: 001 / 030 | Total loss: 1.279 | Reg loss: 0.041 | Tree loss: 1.279 | Accuracy: 0.601562 | 7.106 sec/iter
Epoch: 63 | Batch: 002 / 030 | Total loss: 1.273 | Reg loss: 0.041 | Tree loss: 1.273 | Accuracy: 0.580078 | 7.106 sec/iter
Epoch: 63 | Batch: 003 / 030 | Total loss: 1.292 | Reg loss: 0.041 | Tree loss: 1.292 | Accuracy: 0.578125 | 7.106 sec/iter
Epoch: 63 | Batch: 004 / 030 | To

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 65 | Batch: 000 / 030 | Total loss: 1.309 | Reg loss: 0.041 | Tree loss: 1.309 | Accuracy: 0.585938 | 7.102 sec/iter
Epoch: 65 | Batch: 001 / 030 | Total loss: 1.315 | Reg loss: 0.041 | Tree loss: 1.315 | Accuracy: 0.595703 | 7.102 sec/iter
Epoch: 65 | Batch: 002 / 030 | Total loss: 1.295 | Reg loss: 0.041 | Tree loss: 1.295 | Accuracy: 0.613281 | 7.103 sec/iter
Epoch: 65 | Batch: 003 / 030 | Total loss: 1.263 | Reg loss: 0.041 | Tree loss: 1.263 | Accuracy: 0.601562 | 7.103 sec/iter
Epoch: 65 | Batch: 004 / 030 | Total loss: 1.264 | Reg loss: 0.041 | Tree loss: 1.264 | Accuracy: 0.585938 | 7.103 sec/iter
Epoch: 65 | Batch: 005 / 030 | To

layer 10: 0.9821428571428573
Epoch: 67 | Batch: 000 / 030 | Total loss: 1.325 | Reg loss: 0.041 | Tree loss: 1.325 | Accuracy: 0.548828 | 7.097 sec/iter
Epoch: 67 | Batch: 001 / 030 | Total loss: 1.304 | Reg loss: 0.041 | Tree loss: 1.304 | Accuracy: 0.574219 | 7.097 sec/iter
Epoch: 67 | Batch: 002 / 030 | Total loss: 1.279 | Reg loss: 0.041 | Tree loss: 1.279 | Accuracy: 0.556641 | 7.097 sec/iter
Epoch: 67 | Batch: 003 / 030 | Total loss: 1.284 | Reg loss: 0.041 | Tree loss: 1.284 | Accuracy: 0.599609 | 7.097 sec/iter
Epoch: 67 | Batch: 004 / 030 | Total loss: 1.249 | Reg loss: 0.041 | Tree loss: 1.249 | Accuracy: 0.589844 | 7.097 sec/iter
Epoch: 67 | Batch: 005 / 030 | Total loss: 1.227 | Reg loss: 0.041 | Tree loss: 1.227 | Accuracy: 0.583984 | 7.097 sec/iter
Epoch: 67 | Batch: 006 / 030 | Total loss: 1.220 | Reg loss: 0.041 | Tree loss: 1.220 | Accuracy: 0.564453 | 7.097 sec/iter
Epoch: 67 | Batch: 007 / 030 | Total loss: 1.194 | Reg loss: 0.041 | Tree loss: 1.194 | Accuracy: 0.566

Epoch: 69 | Batch: 001 / 030 | Total loss: 1.275 | Reg loss: 0.041 | Tree loss: 1.275 | Accuracy: 0.591797 | 7.094 sec/iter
Epoch: 69 | Batch: 002 / 030 | Total loss: 1.262 | Reg loss: 0.041 | Tree loss: 1.262 | Accuracy: 0.597656 | 7.094 sec/iter
Epoch: 69 | Batch: 003 / 030 | Total loss: 1.245 | Reg loss: 0.041 | Tree loss: 1.245 | Accuracy: 0.611328 | 7.094 sec/iter
Epoch: 69 | Batch: 004 / 030 | Total loss: 1.262 | Reg loss: 0.041 | Tree loss: 1.262 | Accuracy: 0.531250 | 7.094 sec/iter
Epoch: 69 | Batch: 005 / 030 | Total loss: 1.204 | Reg loss: 0.041 | Tree loss: 1.204 | Accuracy: 0.580078 | 7.094 sec/iter
Epoch: 69 | Batch: 006 / 030 | Total loss: 1.248 | Reg loss: 0.041 | Tree loss: 1.248 | Accuracy: 0.525391 | 7.094 sec/iter
Epoch: 69 | Batch: 007 / 030 | Total loss: 1.155 | Reg loss: 0.041 | Tree loss: 1.155 | Accuracy: 0.591797 | 7.094 sec/iter
Epoch: 69 | Batch: 008 / 030 | Total loss: 1.170 | Reg loss: 0.041 | Tree loss: 1.170 | Accuracy: 0.562500 | 7.094 sec/iter
Epoch: 6

Epoch: 71 | Batch: 002 / 030 | Total loss: 1.268 | Reg loss: 0.041 | Tree loss: 1.268 | Accuracy: 0.591797 | 7.09 sec/iter
Epoch: 71 | Batch: 003 / 030 | Total loss: 1.242 | Reg loss: 0.041 | Tree loss: 1.242 | Accuracy: 0.582031 | 7.09 sec/iter
Epoch: 71 | Batch: 004 / 030 | Total loss: 1.230 | Reg loss: 0.041 | Tree loss: 1.230 | Accuracy: 0.572266 | 7.09 sec/iter
Epoch: 71 | Batch: 005 / 030 | Total loss: 1.205 | Reg loss: 0.041 | Tree loss: 1.205 | Accuracy: 0.585938 | 7.09 sec/iter
Epoch: 71 | Batch: 006 / 030 | Total loss: 1.189 | Reg loss: 0.041 | Tree loss: 1.189 | Accuracy: 0.580078 | 7.09 sec/iter
Epoch: 71 | Batch: 007 / 030 | Total loss: 1.185 | Reg loss: 0.041 | Tree loss: 1.185 | Accuracy: 0.576172 | 7.09 sec/iter
Epoch: 71 | Batch: 008 / 030 | Total loss: 1.154 | Reg loss: 0.041 | Tree loss: 1.154 | Accuracy: 0.574219 | 7.09 sec/iter
Epoch: 71 | Batch: 009 / 030 | Total loss: 1.134 | Reg loss: 0.041 | Tree loss: 1.134 | Accuracy: 0.572266 | 7.089 sec/iter
Epoch: 71 | Bat

Epoch: 73 | Batch: 003 / 030 | Total loss: 1.271 | Reg loss: 0.040 | Tree loss: 1.271 | Accuracy: 0.570312 | 7.084 sec/iter
Epoch: 73 | Batch: 004 / 030 | Total loss: 1.266 | Reg loss: 0.040 | Tree loss: 1.266 | Accuracy: 0.568359 | 7.083 sec/iter
Epoch: 73 | Batch: 005 / 030 | Total loss: 1.212 | Reg loss: 0.040 | Tree loss: 1.212 | Accuracy: 0.582031 | 7.083 sec/iter
Epoch: 73 | Batch: 006 / 030 | Total loss: 1.201 | Reg loss: 0.040 | Tree loss: 1.201 | Accuracy: 0.566406 | 7.083 sec/iter
Epoch: 73 | Batch: 007 / 030 | Total loss: 1.154 | Reg loss: 0.040 | Tree loss: 1.154 | Accuracy: 0.582031 | 7.083 sec/iter
Epoch: 73 | Batch: 008 / 030 | Total loss: 1.127 | Reg loss: 0.040 | Tree loss: 1.127 | Accuracy: 0.597656 | 7.082 sec/iter
Epoch: 73 | Batch: 009 / 030 | Total loss: 1.144 | Reg loss: 0.041 | Tree loss: 1.144 | Accuracy: 0.539062 | 7.082 sec/iter
Epoch: 73 | Batch: 010 / 030 | Total loss: 1.106 | Reg loss: 0.041 | Tree loss: 1.106 | Accuracy: 0.593750 | 7.082 sec/iter
Epoch: 7

Epoch: 75 | Batch: 004 / 030 | Total loss: 1.191 | Reg loss: 0.040 | Tree loss: 1.191 | Accuracy: 0.589844 | 7.081 sec/iter
Epoch: 75 | Batch: 005 / 030 | Total loss: 1.176 | Reg loss: 0.040 | Tree loss: 1.176 | Accuracy: 0.601562 | 7.081 sec/iter
Epoch: 75 | Batch: 006 / 030 | Total loss: 1.184 | Reg loss: 0.040 | Tree loss: 1.184 | Accuracy: 0.568359 | 7.081 sec/iter
Epoch: 75 | Batch: 007 / 030 | Total loss: 1.163 | Reg loss: 0.040 | Tree loss: 1.163 | Accuracy: 0.572266 | 7.081 sec/iter
Epoch: 75 | Batch: 008 / 030 | Total loss: 1.114 | Reg loss: 0.040 | Tree loss: 1.114 | Accuracy: 0.603516 | 7.081 sec/iter
Epoch: 75 | Batch: 009 / 030 | Total loss: 1.117 | Reg loss: 0.040 | Tree loss: 1.117 | Accuracy: 0.582031 | 7.08 sec/iter
Epoch: 75 | Batch: 010 / 030 | Total loss: 1.128 | Reg loss: 0.040 | Tree loss: 1.128 | Accuracy: 0.560547 | 7.08 sec/iter
Epoch: 75 | Batch: 011 / 030 | Total loss: 1.072 | Reg loss: 0.040 | Tree loss: 1.072 | Accuracy: 0.564453 | 7.08 sec/iter
Epoch: 75 |

Epoch: 77 | Batch: 005 / 030 | Total loss: 1.196 | Reg loss: 0.040 | Tree loss: 1.196 | Accuracy: 0.582031 | 7.084 sec/iter
Epoch: 77 | Batch: 006 / 030 | Total loss: 1.148 | Reg loss: 0.040 | Tree loss: 1.148 | Accuracy: 0.582031 | 7.084 sec/iter
Epoch: 77 | Batch: 007 / 030 | Total loss: 1.146 | Reg loss: 0.040 | Tree loss: 1.146 | Accuracy: 0.607422 | 7.084 sec/iter
Epoch: 77 | Batch: 008 / 030 | Total loss: 1.125 | Reg loss: 0.040 | Tree loss: 1.125 | Accuracy: 0.580078 | 7.084 sec/iter
Epoch: 77 | Batch: 009 / 030 | Total loss: 1.085 | Reg loss: 0.040 | Tree loss: 1.085 | Accuracy: 0.605469 | 7.084 sec/iter
Epoch: 77 | Batch: 010 / 030 | Total loss: 1.096 | Reg loss: 0.040 | Tree loss: 1.096 | Accuracy: 0.572266 | 7.084 sec/iter
Epoch: 77 | Batch: 011 / 030 | Total loss: 1.092 | Reg loss: 0.040 | Tree loss: 1.092 | Accuracy: 0.513672 | 7.084 sec/iter
Epoch: 77 | Batch: 012 / 030 | Total loss: 1.058 | Reg loss: 0.040 | Tree loss: 1.058 | Accuracy: 0.589844 | 7.084 sec/iter
Epoch: 7

Epoch: 79 | Batch: 006 / 030 | Total loss: 1.155 | Reg loss: 0.040 | Tree loss: 1.155 | Accuracy: 0.585938 | 7.099 sec/iter
Epoch: 79 | Batch: 007 / 030 | Total loss: 1.168 | Reg loss: 0.040 | Tree loss: 1.168 | Accuracy: 0.570312 | 7.099 sec/iter
Epoch: 79 | Batch: 008 / 030 | Total loss: 1.131 | Reg loss: 0.040 | Tree loss: 1.131 | Accuracy: 0.578125 | 7.1 sec/iter
Epoch: 79 | Batch: 009 / 030 | Total loss: 1.100 | Reg loss: 0.040 | Tree loss: 1.100 | Accuracy: 0.582031 | 7.1 sec/iter
Epoch: 79 | Batch: 010 / 030 | Total loss: 1.067 | Reg loss: 0.040 | Tree loss: 1.067 | Accuracy: 0.595703 | 7.1 sec/iter
Epoch: 79 | Batch: 011 / 030 | Total loss: 1.061 | Reg loss: 0.040 | Tree loss: 1.061 | Accuracy: 0.583984 | 7.1 sec/iter
Epoch: 79 | Batch: 012 / 030 | Total loss: 1.052 | Reg loss: 0.040 | Tree loss: 1.052 | Accuracy: 0.585938 | 7.1 sec/iter
Epoch: 79 | Batch: 013 / 030 | Total loss: 1.010 | Reg loss: 0.040 | Tree loss: 1.010 | Accuracy: 0.591797 | 7.1 sec/iter
Epoch: 79 | Batch: 0

Epoch: 81 | Batch: 007 / 030 | Total loss: 1.133 | Reg loss: 0.040 | Tree loss: 1.133 | Accuracy: 0.574219 | 7.097 sec/iter
Epoch: 81 | Batch: 008 / 030 | Total loss: 1.125 | Reg loss: 0.040 | Tree loss: 1.125 | Accuracy: 0.572266 | 7.097 sec/iter
Epoch: 81 | Batch: 009 / 030 | Total loss: 1.065 | Reg loss: 0.040 | Tree loss: 1.065 | Accuracy: 0.611328 | 7.097 sec/iter
Epoch: 81 | Batch: 010 / 030 | Total loss: 1.061 | Reg loss: 0.040 | Tree loss: 1.061 | Accuracy: 0.570312 | 7.097 sec/iter
Epoch: 81 | Batch: 011 / 030 | Total loss: 1.068 | Reg loss: 0.040 | Tree loss: 1.068 | Accuracy: 0.578125 | 7.097 sec/iter
Epoch: 81 | Batch: 012 / 030 | Total loss: 1.060 | Reg loss: 0.040 | Tree loss: 1.060 | Accuracy: 0.597656 | 7.097 sec/iter
Epoch: 81 | Batch: 013 / 030 | Total loss: 1.059 | Reg loss: 0.040 | Tree loss: 1.059 | Accuracy: 0.556641 | 7.097 sec/iter
Epoch: 81 | Batch: 014 / 030 | Total loss: 1.020 | Reg loss: 0.040 | Tree loss: 1.020 | Accuracy: 0.558594 | 7.097 sec/iter
Epoch: 8

Epoch: 83 | Batch: 008 / 030 | Total loss: 1.072 | Reg loss: 0.040 | Tree loss: 1.072 | Accuracy: 0.605469 | 7.097 sec/iter
Epoch: 83 | Batch: 009 / 030 | Total loss: 1.093 | Reg loss: 0.040 | Tree loss: 1.093 | Accuracy: 0.541016 | 7.097 sec/iter
Epoch: 83 | Batch: 010 / 030 | Total loss: 1.059 | Reg loss: 0.040 | Tree loss: 1.059 | Accuracy: 0.582031 | 7.097 sec/iter
Epoch: 83 | Batch: 011 / 030 | Total loss: 1.046 | Reg loss: 0.040 | Tree loss: 1.046 | Accuracy: 0.601562 | 7.097 sec/iter
Epoch: 83 | Batch: 012 / 030 | Total loss: 1.011 | Reg loss: 0.040 | Tree loss: 1.011 | Accuracy: 0.585938 | 7.097 sec/iter
Epoch: 83 | Batch: 013 / 030 | Total loss: 1.020 | Reg loss: 0.040 | Tree loss: 1.020 | Accuracy: 0.595703 | 7.097 sec/iter
Epoch: 83 | Batch: 014 / 030 | Total loss: 1.012 | Reg loss: 0.040 | Tree loss: 1.012 | Accuracy: 0.597656 | 7.097 sec/iter
Epoch: 83 | Batch: 015 / 030 | Total loss: 0.995 | Reg loss: 0.040 | Tree loss: 0.995 | Accuracy: 0.578125 | 7.097 sec/iter
Epoch: 8

Epoch: 85 | Batch: 009 / 030 | Total loss: 1.098 | Reg loss: 0.039 | Tree loss: 1.098 | Accuracy: 0.591797 | 7.092 sec/iter
Epoch: 85 | Batch: 010 / 030 | Total loss: 1.085 | Reg loss: 0.040 | Tree loss: 1.085 | Accuracy: 0.546875 | 7.092 sec/iter
Epoch: 85 | Batch: 011 / 030 | Total loss: 1.050 | Reg loss: 0.040 | Tree loss: 1.050 | Accuracy: 0.583984 | 7.092 sec/iter
Epoch: 85 | Batch: 012 / 030 | Total loss: 1.037 | Reg loss: 0.040 | Tree loss: 1.037 | Accuracy: 0.552734 | 7.092 sec/iter
Epoch: 85 | Batch: 013 / 030 | Total loss: 1.010 | Reg loss: 0.040 | Tree loss: 1.010 | Accuracy: 0.566406 | 7.092 sec/iter
Epoch: 85 | Batch: 014 / 030 | Total loss: 1.005 | Reg loss: 0.040 | Tree loss: 1.005 | Accuracy: 0.595703 | 7.092 sec/iter
Epoch: 85 | Batch: 015 / 030 | Total loss: 1.014 | Reg loss: 0.040 | Tree loss: 1.014 | Accuracy: 0.558594 | 7.092 sec/iter
Epoch: 85 | Batch: 016 / 030 | Total loss: 0.944 | Reg loss: 0.040 | Tree loss: 0.944 | Accuracy: 0.636719 | 7.092 sec/iter
Epoch: 8

Epoch: 87 | Batch: 010 / 030 | Total loss: 1.064 | Reg loss: 0.039 | Tree loss: 1.064 | Accuracy: 0.568359 | 7.091 sec/iter
Epoch: 87 | Batch: 011 / 030 | Total loss: 1.050 | Reg loss: 0.039 | Tree loss: 1.050 | Accuracy: 0.556641 | 7.091 sec/iter
Epoch: 87 | Batch: 012 / 030 | Total loss: 1.047 | Reg loss: 0.039 | Tree loss: 1.047 | Accuracy: 0.566406 | 7.091 sec/iter
Epoch: 87 | Batch: 013 / 030 | Total loss: 1.005 | Reg loss: 0.040 | Tree loss: 1.005 | Accuracy: 0.625000 | 7.091 sec/iter
Epoch: 87 | Batch: 014 / 030 | Total loss: 0.993 | Reg loss: 0.040 | Tree loss: 0.993 | Accuracy: 0.582031 | 7.091 sec/iter
Epoch: 87 | Batch: 015 / 030 | Total loss: 0.976 | Reg loss: 0.040 | Tree loss: 0.976 | Accuracy: 0.593750 | 7.091 sec/iter
Epoch: 87 | Batch: 016 / 030 | Total loss: 0.970 | Reg loss: 0.040 | Tree loss: 0.970 | Accuracy: 0.585938 | 7.091 sec/iter
Epoch: 87 | Batch: 017 / 030 | Total loss: 0.954 | Reg loss: 0.040 | Tree loss: 0.954 | Accuracy: 0.621094 | 7.091 sec/iter
Epoch: 8

Epoch: 89 | Batch: 011 / 030 | Total loss: 1.053 | Reg loss: 0.039 | Tree loss: 1.053 | Accuracy: 0.583984 | 7.084 sec/iter
Epoch: 89 | Batch: 012 / 030 | Total loss: 1.040 | Reg loss: 0.039 | Tree loss: 1.040 | Accuracy: 0.546875 | 7.084 sec/iter
Epoch: 89 | Batch: 013 / 030 | Total loss: 0.992 | Reg loss: 0.039 | Tree loss: 0.992 | Accuracy: 0.591797 | 7.084 sec/iter
Epoch: 89 | Batch: 014 / 030 | Total loss: 0.991 | Reg loss: 0.040 | Tree loss: 0.991 | Accuracy: 0.583984 | 7.084 sec/iter
Epoch: 89 | Batch: 015 / 030 | Total loss: 0.982 | Reg loss: 0.040 | Tree loss: 0.982 | Accuracy: 0.550781 | 7.084 sec/iter
Epoch: 89 | Batch: 016 / 030 | Total loss: 0.954 | Reg loss: 0.040 | Tree loss: 0.954 | Accuracy: 0.587891 | 7.084 sec/iter
Epoch: 89 | Batch: 017 / 030 | Total loss: 0.948 | Reg loss: 0.040 | Tree loss: 0.948 | Accuracy: 0.589844 | 7.084 sec/iter
Epoch: 89 | Batch: 018 / 030 | Total loss: 0.962 | Reg loss: 0.040 | Tree loss: 0.962 | Accuracy: 0.562500 | 7.084 sec/iter
Epoch: 8

Epoch: 91 | Batch: 012 / 030 | Total loss: 1.026 | Reg loss: 0.039 | Tree loss: 1.026 | Accuracy: 0.580078 | 7.08 sec/iter
Epoch: 91 | Batch: 013 / 030 | Total loss: 1.000 | Reg loss: 0.039 | Tree loss: 1.000 | Accuracy: 0.591797 | 7.08 sec/iter
Epoch: 91 | Batch: 014 / 030 | Total loss: 0.981 | Reg loss: 0.039 | Tree loss: 0.981 | Accuracy: 0.593750 | 7.08 sec/iter
Epoch: 91 | Batch: 015 / 030 | Total loss: 0.983 | Reg loss: 0.039 | Tree loss: 0.983 | Accuracy: 0.576172 | 7.08 sec/iter
Epoch: 91 | Batch: 016 / 030 | Total loss: 0.959 | Reg loss: 0.040 | Tree loss: 0.959 | Accuracy: 0.591797 | 7.08 sec/iter
Epoch: 91 | Batch: 017 / 030 | Total loss: 0.983 | Reg loss: 0.040 | Tree loss: 0.983 | Accuracy: 0.541016 | 7.08 sec/iter
Epoch: 91 | Batch: 018 / 030 | Total loss: 0.920 | Reg loss: 0.040 | Tree loss: 0.920 | Accuracy: 0.630859 | 7.08 sec/iter
Epoch: 91 | Batch: 019 / 030 | Total loss: 0.937 | Reg loss: 0.040 | Tree loss: 0.937 | Accuracy: 0.560547 | 7.08 sec/iter
Epoch: 91 | Batc

Epoch: 93 | Batch: 013 / 030 | Total loss: 0.997 | Reg loss: 0.039 | Tree loss: 0.997 | Accuracy: 0.583984 | 7.063 sec/iter
Epoch: 93 | Batch: 014 / 030 | Total loss: 0.971 | Reg loss: 0.039 | Tree loss: 0.971 | Accuracy: 0.580078 | 7.062 sec/iter
Epoch: 93 | Batch: 015 / 030 | Total loss: 0.986 | Reg loss: 0.039 | Tree loss: 0.986 | Accuracy: 0.546875 | 7.061 sec/iter
Epoch: 93 | Batch: 016 / 030 | Total loss: 0.964 | Reg loss: 0.039 | Tree loss: 0.964 | Accuracy: 0.580078 | 7.061 sec/iter
Epoch: 93 | Batch: 017 / 030 | Total loss: 0.928 | Reg loss: 0.039 | Tree loss: 0.928 | Accuracy: 0.591797 | 7.06 sec/iter
Epoch: 93 | Batch: 018 / 030 | Total loss: 0.927 | Reg loss: 0.040 | Tree loss: 0.927 | Accuracy: 0.593750 | 7.059 sec/iter
Epoch: 93 | Batch: 019 / 030 | Total loss: 0.925 | Reg loss: 0.040 | Tree loss: 0.925 | Accuracy: 0.582031 | 7.058 sec/iter
Epoch: 93 | Batch: 020 / 030 | Total loss: 0.911 | Reg loss: 0.040 | Tree loss: 0.911 | Accuracy: 0.589844 | 7.057 sec/iter
Epoch: 93

Epoch: 95 | Batch: 014 / 030 | Total loss: 0.965 | Reg loss: 0.039 | Tree loss: 0.965 | Accuracy: 0.630859 | 7.018 sec/iter
Epoch: 95 | Batch: 015 / 030 | Total loss: 0.962 | Reg loss: 0.039 | Tree loss: 0.962 | Accuracy: 0.574219 | 7.017 sec/iter
Epoch: 95 | Batch: 016 / 030 | Total loss: 0.932 | Reg loss: 0.039 | Tree loss: 0.932 | Accuracy: 0.601562 | 7.016 sec/iter
Epoch: 95 | Batch: 017 / 030 | Total loss: 0.933 | Reg loss: 0.039 | Tree loss: 0.933 | Accuracy: 0.583984 | 7.015 sec/iter
Epoch: 95 | Batch: 018 / 030 | Total loss: 0.935 | Reg loss: 0.039 | Tree loss: 0.935 | Accuracy: 0.599609 | 7.015 sec/iter
Epoch: 95 | Batch: 019 / 030 | Total loss: 0.950 | Reg loss: 0.039 | Tree loss: 0.950 | Accuracy: 0.509766 | 7.014 sec/iter
Epoch: 95 | Batch: 020 / 030 | Total loss: 0.908 | Reg loss: 0.040 | Tree loss: 0.908 | Accuracy: 0.585938 | 7.013 sec/iter
Epoch: 95 | Batch: 021 / 030 | Total loss: 0.885 | Reg loss: 0.040 | Tree loss: 0.885 | Accuracy: 0.607422 | 7.012 sec/iter
Epoch: 9

Epoch: 97 | Batch: 015 / 030 | Total loss: 0.979 | Reg loss: 0.039 | Tree loss: 0.979 | Accuracy: 0.582031 | 6.964 sec/iter
Epoch: 97 | Batch: 016 / 030 | Total loss: 0.941 | Reg loss: 0.039 | Tree loss: 0.941 | Accuracy: 0.587891 | 6.963 sec/iter
Epoch: 97 | Batch: 017 / 030 | Total loss: 0.961 | Reg loss: 0.039 | Tree loss: 0.961 | Accuracy: 0.560547 | 6.962 sec/iter
Epoch: 97 | Batch: 018 / 030 | Total loss: 0.907 | Reg loss: 0.039 | Tree loss: 0.907 | Accuracy: 0.628906 | 6.961 sec/iter
Epoch: 97 | Batch: 019 / 030 | Total loss: 0.923 | Reg loss: 0.039 | Tree loss: 0.923 | Accuracy: 0.580078 | 6.96 sec/iter
Epoch: 97 | Batch: 020 / 030 | Total loss: 0.918 | Reg loss: 0.039 | Tree loss: 0.918 | Accuracy: 0.546875 | 6.959 sec/iter
Epoch: 97 | Batch: 021 / 030 | Total loss: 0.893 | Reg loss: 0.040 | Tree loss: 0.893 | Accuracy: 0.615234 | 6.958 sec/iter
Epoch: 97 | Batch: 022 / 030 | Total loss: 0.882 | Reg loss: 0.040 | Tree loss: 0.882 | Accuracy: 0.591797 | 6.956 sec/iter
Epoch: 97

Epoch: 99 | Batch: 016 / 030 | Total loss: 0.961 | Reg loss: 0.039 | Tree loss: 0.961 | Accuracy: 0.548828 | 6.894 sec/iter
Epoch: 99 | Batch: 017 / 030 | Total loss: 0.945 | Reg loss: 0.039 | Tree loss: 0.945 | Accuracy: 0.578125 | 6.893 sec/iter
Epoch: 99 | Batch: 018 / 030 | Total loss: 0.923 | Reg loss: 0.039 | Tree loss: 0.923 | Accuracy: 0.601562 | 6.892 sec/iter
Epoch: 99 | Batch: 019 / 030 | Total loss: 0.914 | Reg loss: 0.039 | Tree loss: 0.914 | Accuracy: 0.580078 | 6.891 sec/iter
Epoch: 99 | Batch: 020 / 030 | Total loss: 0.918 | Reg loss: 0.039 | Tree loss: 0.918 | Accuracy: 0.570312 | 6.89 sec/iter
Epoch: 99 | Batch: 021 / 030 | Total loss: 0.878 | Reg loss: 0.039 | Tree loss: 0.878 | Accuracy: 0.597656 | 6.889 sec/iter
Epoch: 99 | Batch: 022 / 030 | Total loss: 0.889 | Reg loss: 0.039 | Tree loss: 0.889 | Accuracy: 0.585938 | 6.887 sec/iter
Epoch: 99 | Batch: 023 / 030 | Total loss: 0.874 | Reg loss: 0.040 | Tree loss: 0.874 | Accuracy: 0.595703 | 6.886 sec/iter
Epoch: 99

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 11.736542962219993


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 2991


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


14951






Average comprehensibility: 60.04948177866934
std comprehensibility: 4.502708301128793
var comprehensibility: 20.274382045054143
minimum comprehensibility: 38
maximum comprehensibility: 70
