In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 128
tree_depth = 6
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.180225372314453 | KNN Loss: 6.229950904846191 | BCE Loss: 1.9502747058868408
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.183635711669922 | KNN Loss: 6.229683876037598 | BCE Loss: 1.9539520740509033
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.158685684204102 | KNN Loss: 6.2299699783325195 | BCE Loss: 1.928715467453003
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.198995590209961 | KNN Loss: 6.229693412780762 | BCE Loss: 1.9693019390106201
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.157793998718262 | KNN Loss: 6.229772567749023 | BCE Loss: 1.9280214309692383
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.0819673538208 | KNN Loss: 6.229701042175293 | BCE Loss: 1.8522663116455078
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.16765022277832 | KNN Loss: 6.229369640350342 | BCE Loss: 1.9382810592651367
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.128780364990234 | KNN Loss: 6.229136943817139 | BCE Loss: 1.89964354

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.07936954498291 | KNN Loss: 5.899815082550049 | BCE Loss: 1.1795545816421509
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.034160614013672 | KNN Loss: 5.883945465087891 | BCE Loss: 1.1502151489257812
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 6.996936798095703 | KNN Loss: 5.832756519317627 | BCE Loss: 1.1641805171966553
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.9601149559021 | KNN Loss: 5.7869720458984375 | BCE Loss: 1.1731430292129517
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.9008636474609375 | KNN Loss: 5.745500564575195 | BCE Loss: 1.155362844467163
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.829466819763184 | KNN Loss: 5.708059787750244 | BCE Loss: 1.1214070320129395
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.806700706481934 | KNN Loss: 5.70042610168457 | BCE Loss: 1.1062748432159424
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.7193603515625 | KNN Loss: 5.615592002868652 | BCE Loss: 1.1

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.136683464050293 | KNN Loss: 5.08859395980835 | BCE Loss: 1.0480892658233643
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.154345512390137 | KNN Loss: 5.095767498016357 | BCE Loss: 1.0585782527923584
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.227887153625488 | KNN Loss: 5.146585464477539 | BCE Loss: 1.0813019275665283
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.109425067901611 | KNN Loss: 5.08043098449707 | BCE Loss: 1.028994083404541
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.174731731414795 | KNN Loss: 5.090341091156006 | BCE Loss: 1.084390640258789
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.132198333740234 | KNN Loss: 5.077462673187256 | BCE Loss: 1.054735779762268
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.128279209136963 | KNN Loss: 5.067727565765381 | BCE Loss: 1.060551643371582
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.169456958770752 | KNN Loss: 5.112816333770752 | BCE Loss: 1.0

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.130435943603516 | KNN Loss: 5.046717166900635 | BCE Loss: 1.0837187767028809
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.0779924392700195 | KNN Loss: 5.037525653839111 | BCE Loss: 1.0404667854309082
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.108156204223633 | KNN Loss: 5.0492377281188965 | BCE Loss: 1.0589187145233154
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.203396797180176 | KNN Loss: 5.093068599700928 | BCE Loss: 1.1103284358978271
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.112232208251953 | KNN Loss: 5.073187351226807 | BCE Loss: 1.0390450954437256
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.130401611328125 | KNN Loss: 5.056979179382324 | BCE Loss: 1.0734226703643799
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.106178283691406 | KNN Loss: 5.061697006225586 | BCE Loss: 1.0444811582565308
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.100916862487793 | KNN Loss: 5.063270568847656 | BCE L

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.0896077156066895 | KNN Loss: 5.037966728210449 | BCE Loss: 1.0516409873962402
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.083768844604492 | KNN Loss: 5.039661884307861 | BCE Loss: 1.04410719871521
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.115143775939941 | KNN Loss: 5.054600238800049 | BCE Loss: 1.0605437755584717
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.11208438873291 | KNN Loss: 5.058661937713623 | BCE Loss: 1.0534226894378662
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.074542999267578 | KNN Loss: 5.037545680999756 | BCE Loss: 1.0369970798492432
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.1285080909729 | KNN Loss: 5.04819917678833 | BCE Loss: 1.0803089141845703
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.118646621704102 | KNN Loss: 5.072786331176758 | BCE Loss: 1.0458600521087646
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 6.0933332443237305 | KNN Loss: 5.0483551025390625 | BCE Loss: 

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.0696024894714355 | KNN Loss: 5.0296831130981445 | BCE Loss: 1.039919376373291
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.080321788787842 | KNN Loss: 5.031465530395508 | BCE Loss: 1.048856258392334
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.100520610809326 | KNN Loss: 5.021414756774902 | BCE Loss: 1.0791059732437134
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.1001458168029785 | KNN Loss: 5.057011127471924 | BCE Loss: 1.0431345701217651
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.1169891357421875 | KNN Loss: 5.054841041564941 | BCE Loss: 1.062147855758667
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.090317726135254 | KNN Loss: 5.046672821044922 | BCE Loss: 1.0436450242996216
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.137221813201904 | KNN Loss: 5.048509120941162 | BCE Loss: 1.0887126922607422
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 6.166006088256836 | KNN Loss: 5.112795352935791 | BCE Los

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.041706085205078 | KNN Loss: 5.017374515533447 | BCE Loss: 1.0243315696716309
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.0404157638549805 | KNN Loss: 5.0175700187683105 | BCE Loss: 1.02284574508667
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.049448013305664 | KNN Loss: 5.01605224609375 | BCE Loss: 1.033395528793335
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.062939643859863 | KNN Loss: 5.035002708435059 | BCE Loss: 1.0279369354248047
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.055649757385254 | KNN Loss: 5.013498306274414 | BCE Loss: 1.0421514511108398
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.091758728027344 | KNN Loss: 5.031163692474365 | BCE Loss: 1.0605950355529785
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.097441673278809 | KNN Loss: 5.0313920974731445 | BCE Loss: 1.066049337387085
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 6.106651306152344 | KNN Loss: 5.041915416717529 | BCE Loss: 1

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.06343936920166 | KNN Loss: 5.036379337310791 | BCE Loss: 1.02705979347229
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.0555315017700195 | KNN Loss: 5.015176773071289 | BCE Loss: 1.0403549671173096
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.082566261291504 | KNN Loss: 5.032643795013428 | BCE Loss: 1.049922227859497
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.067405700683594 | KNN Loss: 5.012475490570068 | BCE Loss: 1.0549304485321045
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.104670524597168 | KNN Loss: 5.027459621429443 | BCE Loss: 1.0772110223770142
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.140169143676758 | KNN Loss: 5.109612464904785 | BCE Loss: 1.0305566787719727
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 6.049584865570068 | KNN Loss: 5.01343297958374 | BCE Loss: 1.0361520051956177
Epoch 77 / 500 | iteration 0 / 30 | Total Loss: 6.075631141662598 | KNN Loss: 5.02353572845459 | BCE Loss: 1.0

Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.059650421142578 | KNN Loss: 5.025633811950684 | BCE Loss: 1.0340168476104736
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.03624153137207 | KNN Loss: 5.004266738891602 | BCE Loss: 1.0319749116897583
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.02328634262085 | KNN Loss: 5.006771087646484 | BCE Loss: 1.0165152549743652
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.0826520919799805 | KNN Loss: 5.031464099884033 | BCE Loss: 1.0511879920959473
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.055439472198486 | KNN Loss: 5.00681734085083 | BCE Loss: 1.0486220121383667
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 6.092889308929443 | KNN Loss: 5.05412483215332 | BCE Loss: 1.0387643575668335
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 6.040262222290039 | KNN Loss: 5.015324592590332 | BCE Loss: 1.024937629699707
Epoch 87 / 500 | iteration 25 / 30 | Total Loss: 6.037534713745117 | KNN Loss: 4.999297618865967 | BCE Loss: 1

Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.028935432434082 | KNN Loss: 5.014415264129639 | BCE Loss: 1.0145201683044434
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.0476765632629395 | KNN Loss: 5.018375396728516 | BCE Loss: 1.0293011665344238
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.0189313888549805 | KNN Loss: 4.999255180358887 | BCE Loss: 1.0196764469146729
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.04156494140625 | KNN Loss: 5.020716667175293 | BCE Loss: 1.0208485126495361
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 6.042229652404785 | KNN Loss: 5.020939350128174 | BCE Loss: 1.0212905406951904
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 6.038056373596191 | KNN Loss: 5.0220537185668945 | BCE Loss: 1.016002893447876
Epoch 98 / 500 | iteration 15 / 30 | Total Loss: 6.033816814422607 | KNN Loss: 5.031329154968262 | BCE Loss: 1.0024877786636353
Epoch 98 / 500 | iteration 20 / 30 | Total Loss: 6.054437637329102 | KNN Loss: 5.027155876159668 | BCE Lo

Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.0429534912109375 | KNN Loss: 5.013554573059082 | BCE Loss: 1.0293986797332764
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.009341239929199 | KNN Loss: 4.989260673522949 | BCE Loss: 1.020080804824829
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.020467281341553 | KNN Loss: 5.001977920532227 | BCE Loss: 1.0184894800186157
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.085610866546631 | KNN Loss: 5.029135704040527 | BCE Loss: 1.056475281715393
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 6.028777599334717 | KNN Loss: 4.985438346862793 | BCE Loss: 1.0433392524719238
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 6.029018878936768 | KNN Loss: 5.047170162200928 | BCE Loss: 0.9818486571311951
Epoch 109 / 500 | iteration 5 / 30 | Total Loss: 6.033347129821777 | KNN Loss: 5.006369590759277 | BCE Loss: 1.0269775390625
Epoch 109 / 500 | iteration 10 / 30 | Total Loss: 6.041639804840088 | KNN Loss: 4.99087381362915 | BCE L

Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.039344787597656 | KNN Loss: 5.014034748077393 | BCE Loss: 1.0253098011016846
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.056617259979248 | KNN Loss: 4.993488788604736 | BCE Loss: 1.0631283521652222
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.022832870483398 | KNN Loss: 4.99938440322876 | BCE Loss: 1.0234485864639282
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.041475296020508 | KNN Loss: 5.001172065734863 | BCE Loss: 1.040303349494934
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 6.025002479553223 | KNN Loss: 4.995077133178711 | BCE Loss: 1.0299252271652222
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 6.082763671875 | KNN Loss: 5.066988468170166 | BCE Loss: 1.0157749652862549
Epoch 119 / 500 | iteration 25 / 30 | Total Loss: 6.021142482757568 | KNN Loss: 5.0023579597473145 | BCE Loss: 1.0187846422195435
Epoch 120 / 500 | iteration 0 / 30 | Total Loss: 6.016135215759277 | KNN Loss: 4.994772911071777 | BCE 

Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.042799472808838 | KNN Loss: 4.97893762588501 | BCE Loss: 1.0638618469238281
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.04741096496582 | KNN Loss: 5.011135578155518 | BCE Loss: 1.0362751483917236
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 5.997302055358887 | KNN Loss: 4.984949588775635 | BCE Loss: 1.0123522281646729
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 6.0244340896606445 | KNN Loss: 4.98615026473999 | BCE Loss: 1.0382840633392334
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 6.0276079177856445 | KNN Loss: 4.988617897033691 | BCE Loss: 1.0389901399612427
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 6.055008888244629 | KNN Loss: 5.041882514953613 | BCE Loss: 1.0131261348724365
Epoch 130 / 500 | iteration 15 / 30 | Total Loss: 6.05958366394043 | KNN Loss: 5.048610687255859 | BCE Loss: 1.0109727382659912
Epoch 130 / 500 | iteration 20 / 30 | Total Loss: 6.0307393074035645 | KNN Loss: 5.018115520477295 | 

Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.061967849731445 | KNN Loss: 5.016207218170166 | BCE Loss: 1.0457608699798584
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.068056106567383 | KNN Loss: 5.016519069671631 | BCE Loss: 1.051537275314331
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.016406059265137 | KNN Loss: 4.98006534576416 | BCE Loss: 1.0363404750823975
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.089438438415527 | KNN Loss: 5.06364631652832 | BCE Loss: 1.0257923603057861
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 6.032982349395752 | KNN Loss: 5.0310492515563965 | BCE Loss: 1.001933217048645
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 6.099852561950684 | KNN Loss: 5.047641277313232 | BCE Loss: 1.0522111654281616
Epoch 141 / 500 | iteration 5 / 30 | Total Loss: 6.1062188148498535 | KNN Loss: 5.032506942749023 | BCE Loss: 1.07371187210083
Epoch 141 / 500 | iteration 10 / 30 | Total Loss: 6.0432586669921875 | KNN Loss: 5.015885353088379 | BCE

Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.00905179977417 | KNN Loss: 4.994162559509277 | BCE Loss: 1.0148893594741821
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.037932395935059 | KNN Loss: 5.019835948944092 | BCE Loss: 1.0180965662002563
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.0169267654418945 | KNN Loss: 5.010152339935303 | BCE Loss: 1.0067743062973022
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.08650016784668 | KNN Loss: 5.043859004974365 | BCE Loss: 1.0426409244537354
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 6.03225564956665 | KNN Loss: 5.002765655517578 | BCE Loss: 1.0294899940490723
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 6.038332939147949 | KNN Loss: 5.00814151763916 | BCE Loss: 1.0301915407180786
Epoch 151 / 500 | iteration 25 / 30 | Total Loss: 6.0367889404296875 | KNN Loss: 5.026343822479248 | BCE Loss: 1.0104448795318604
Epoch 152 / 500 | iteration 0 / 30 | Total Loss: 6.032766819000244 | KNN Loss: 5.008983135223389 | BC

Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.049095630645752 | KNN Loss: 5.015454292297363 | BCE Loss: 1.0336413383483887
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.022940635681152 | KNN Loss: 4.9958696365356445 | BCE Loss: 1.0270711183547974
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.0734710693359375 | KNN Loss: 5.02890682220459 | BCE Loss: 1.0445644855499268
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.045058250427246 | KNN Loss: 5.035076141357422 | BCE Loss: 1.0099823474884033
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 6.112523555755615 | KNN Loss: 5.0693359375 | BCE Loss: 1.0431876182556152
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 6.032015323638916 | KNN Loss: 4.998532772064209 | BCE Loss: 1.033482551574707
Epoch 162 / 500 | iteration 15 / 30 | Total Loss: 6.052684783935547 | KNN Loss: 5.023055553436279 | BCE Loss: 1.0296292304992676
Epoch 162 / 500 | iteration 20 / 30 | Total Loss: 6.057159423828125 | KNN Loss: 5.026776313781738 | BCE 

Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.031352996826172 | KNN Loss: 5.0040106773376465 | BCE Loss: 1.0273425579071045
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.039177894592285 | KNN Loss: 5.012136459350586 | BCE Loss: 1.0270415544509888
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.042013168334961 | KNN Loss: 5.027540683746338 | BCE Loss: 1.014472484588623
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.011489391326904 | KNN Loss: 4.990205764770508 | BCE Loss: 1.021283745765686
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 6.043184280395508 | KNN Loss: 5.000129699707031 | BCE Loss: 1.0430548191070557
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 6.073986053466797 | KNN Loss: 5.0400872230529785 | BCE Loss: 1.0338985919952393
Epoch 173 / 500 | iteration 5 / 30 | Total Loss: 6.0549421310424805 | KNN Loss: 5.002638339996338 | BCE Loss: 1.0523035526275635
Epoch 173 / 500 | iteration 10 / 30 | Total Loss: 6.084356784820557 | KNN Loss: 5.032279968261719 |

Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.055254936218262 | KNN Loss: 5.032163619995117 | BCE Loss: 1.0230915546417236
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.007072925567627 | KNN Loss: 4.979960918426514 | BCE Loss: 1.0271121263504028
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.043704509735107 | KNN Loss: 5.002006530761719 | BCE Loss: 1.0416978597640991
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.087887763977051 | KNN Loss: 5.046228408813477 | BCE Loss: 1.0416591167449951
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 6.014974117279053 | KNN Loss: 4.987212657928467 | BCE Loss: 1.027761459350586
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 6.049522399902344 | KNN Loss: 5.01785945892334 | BCE Loss: 1.0316628217697144
Epoch 183 / 500 | iteration 25 / 30 | Total Loss: 6.043483734130859 | KNN Loss: 5.009125709533691 | BCE Loss: 1.0343579053878784
Epoch 184 / 500 | iteration 0 / 30 | Total Loss: 6.023841857910156 | KNN Loss: 4.994759559631348 | BC

Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.008994102478027 | KNN Loss: 4.99589204788208 | BCE Loss: 1.0131018161773682
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.047043323516846 | KNN Loss: 4.9933037757873535 | BCE Loss: 1.0537395477294922
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.0696306228637695 | KNN Loss: 5.016899108886719 | BCE Loss: 1.0527315139770508
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.021917343139648 | KNN Loss: 5.024557590484619 | BCE Loss: 0.9973598122596741
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 6.0365800857543945 | KNN Loss: 5.05091667175293 | BCE Loss: 0.9856632947921753
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 6.030801296234131 | KNN Loss: 5.0106706619262695 | BCE Loss: 1.0201306343078613
Epoch 194 / 500 | iteration 15 / 30 | Total Loss: 6.0315937995910645 | KNN Loss: 5.011815547943115 | BCE Loss: 1.0197783708572388
Epoch 194 / 500 | iteration 20 / 30 | Total Loss: 6.042860984802246 | KNN Loss: 4.99204730987548

Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.036676406860352 | KNN Loss: 5.009092807769775 | BCE Loss: 1.027583360671997
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.128673553466797 | KNN Loss: 5.068080902099609 | BCE Loss: 1.060592770576477
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.012908935546875 | KNN Loss: 4.996050834655762 | BCE Loss: 1.0168578624725342
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 5.999847888946533 | KNN Loss: 4.9896135330200195 | BCE Loss: 1.0102342367172241
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 6.040980339050293 | KNN Loss: 5.001227378845215 | BCE Loss: 1.0397531986236572
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 6.069392204284668 | KNN Loss: 5.024964809417725 | BCE Loss: 1.0444276332855225
Epoch 205 / 500 | iteration 5 / 30 | Total Loss: 6.0454912185668945 | KNN Loss: 5.018773555755615 | BCE Loss: 1.0267175436019897
Epoch 205 / 500 | iteration 10 / 30 | Total Loss: 6.00527286529541 | KNN Loss: 5.00943660736084 | BC

Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.027600288391113 | KNN Loss: 5.015091896057129 | BCE Loss: 1.0125083923339844
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.012761116027832 | KNN Loss: 4.984920978546143 | BCE Loss: 1.0278403759002686
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.075695991516113 | KNN Loss: 5.048672676086426 | BCE Loss: 1.027023434638977
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.060834884643555 | KNN Loss: 5.039547920227051 | BCE Loss: 1.021287202835083
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 6.066418170928955 | KNN Loss: 5.037097454071045 | BCE Loss: 1.0293205976486206
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 6.041723251342773 | KNN Loss: 5.031862735748291 | BCE Loss: 1.0098605155944824
Epoch 215 / 500 | iteration 25 / 30 | Total Loss: 6.034419536590576 | KNN Loss: 5.002009391784668 | BCE Loss: 1.0324100255966187
Epoch 216 / 500 | iteration 0 / 30 | Total Loss: 6.119030952453613 | KNN Loss: 5.082966327667236 | BC

Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.042793273925781 | KNN Loss: 5.020571231842041 | BCE Loss: 1.0222220420837402
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.069212913513184 | KNN Loss: 5.0487470626831055 | BCE Loss: 1.020465612411499
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.078712463378906 | KNN Loss: 5.047940254211426 | BCE Loss: 1.0307719707489014
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.048078536987305 | KNN Loss: 4.995428562164307 | BCE Loss: 1.0526502132415771
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 5.992036819458008 | KNN Loss: 4.984274864196777 | BCE Loss: 1.0077621936798096
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 5.9871320724487305 | KNN Loss: 4.971126079559326 | BCE Loss: 1.0160062313079834
Epoch 226 / 500 | iteration 15 / 30 | Total Loss: 6.040619373321533 | KNN Loss: 5.026799201965332 | BCE Loss: 1.0138201713562012
Epoch 226 / 500 | iteration 20 / 30 | Total Loss: 6.0364556312561035 | KNN Loss: 5.006229877471924

Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.05882453918457 | KNN Loss: 5.016063690185547 | BCE Loss: 1.0427608489990234
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.0089006423950195 | KNN Loss: 4.986953258514404 | BCE Loss: 1.0219471454620361
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.0318474769592285 | KNN Loss: 5.017027854919434 | BCE Loss: 1.014819622039795
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.040718078613281 | KNN Loss: 5.023330211639404 | BCE Loss: 1.017388105392456
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 5.992644786834717 | KNN Loss: 4.9713521003723145 | BCE Loss: 1.021292805671692
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 6.056281089782715 | KNN Loss: 5.027825355529785 | BCE Loss: 1.0284559726715088
Epoch 237 / 500 | iteration 5 / 30 | Total Loss: 5.999690532684326 | KNN Loss: 4.989825248718262 | BCE Loss: 1.009865164756775
Epoch 237 / 500 | iteration 10 / 30 | Total Loss: 6.0180463790893555 | KNN Loss: 4.979657173156738 | B

Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.013554096221924 | KNN Loss: 5.013326168060303 | BCE Loss: 1.000227928161621
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.02556037902832 | KNN Loss: 4.976141452789307 | BCE Loss: 1.0494186878204346
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.031200408935547 | KNN Loss: 5.0092549324035645 | BCE Loss: 1.0219454765319824
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 5.988860130310059 | KNN Loss: 4.977952003479004 | BCE Loss: 1.0109078884124756
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 6.054637908935547 | KNN Loss: 5.026487350463867 | BCE Loss: 1.0281503200531006
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 6.046933174133301 | KNN Loss: 5.024504661560059 | BCE Loss: 1.0224287509918213
Epoch 247 / 500 | iteration 25 / 30 | Total Loss: 6.0405097007751465 | KNN Loss: 4.99690055847168 | BCE Loss: 1.0436090230941772
Epoch 248 / 500 | iteration 0 / 30 | Total Loss: 6.052636623382568 | KNN Loss: 5.0313286781311035 | 

Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.024074554443359 | KNN Loss: 5.001955986022949 | BCE Loss: 1.0221185684204102
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.058073043823242 | KNN Loss: 5.005248069763184 | BCE Loss: 1.0528247356414795
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.004766464233398 | KNN Loss: 4.99049711227417 | BCE Loss: 1.014269471168518
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.052371501922607 | KNN Loss: 4.997718334197998 | BCE Loss: 1.0546531677246094
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 6.034512519836426 | KNN Loss: 5.0196533203125 | BCE Loss: 1.0148591995239258
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 6.039517879486084 | KNN Loss: 5.013401985168457 | BCE Loss: 1.026115894317627
Epoch 258 / 500 | iteration 15 / 30 | Total Loss: 6.064043998718262 | KNN Loss: 5.0011491775512695 | BCE Loss: 1.0628949403762817
Epoch 258 / 500 | iteration 20 / 30 | Total Loss: 6.037271022796631 | KNN Loss: 5.024835586547852 | BCE

Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.013460636138916 | KNN Loss: 4.999653339385986 | BCE Loss: 1.0138071775436401
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.032060623168945 | KNN Loss: 5.018365383148193 | BCE Loss: 1.013695478439331
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.06089973449707 | KNN Loss: 5.012787342071533 | BCE Loss: 1.0481125116348267
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.036672592163086 | KNN Loss: 5.012868881225586 | BCE Loss: 1.0238037109375
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 6.066684722900391 | KNN Loss: 5.035377502441406 | BCE Loss: 1.0313074588775635
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 6.0298357009887695 | KNN Loss: 5.021240711212158 | BCE Loss: 1.0085949897766113
Epoch 269 / 500 | iteration 5 / 30 | Total Loss: 6.070206642150879 | KNN Loss: 5.028445720672607 | BCE Loss: 1.0417611598968506
Epoch 269 / 500 | iteration 10 / 30 | Total Loss: 5.967853546142578 | KNN Loss: 4.977558135986328 | BCE 

Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.003467559814453 | KNN Loss: 4.980100154876709 | BCE Loss: 1.0233672857284546
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.0261688232421875 | KNN Loss: 5.010121822357178 | BCE Loss: 1.0160467624664307
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.041759490966797 | KNN Loss: 5.025346279144287 | BCE Loss: 1.0164132118225098
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.015995979309082 | KNN Loss: 5.003365993499756 | BCE Loss: 1.0126302242279053
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 6.063911437988281 | KNN Loss: 5.007852077484131 | BCE Loss: 1.0560591220855713
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 6.055912971496582 | KNN Loss: 5.036261558532715 | BCE Loss: 1.019651174545288
Epoch 279 / 500 | iteration 25 / 30 | Total Loss: 6.029735088348389 | KNN Loss: 5.008279323577881 | BCE Loss: 1.0214556455612183
Epoch 280 / 500 | iteration 0 / 30 | Total Loss: 6.022757053375244 | KNN Loss: 4.9821696281433105 |

Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.03170108795166 | KNN Loss: 5.0049967765808105 | BCE Loss: 1.0267040729522705
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.0596513748168945 | KNN Loss: 5.0432939529418945 | BCE Loss: 1.0163575410842896
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 5.988691806793213 | KNN Loss: 4.971010208129883 | BCE Loss: 1.01768159866333
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.040393352508545 | KNN Loss: 5.01356315612793 | BCE Loss: 1.0268300771713257
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 6.044535160064697 | KNN Loss: 5.017277717590332 | BCE Loss: 1.0272573232650757
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 6.014586925506592 | KNN Loss: 5.008232593536377 | BCE Loss: 1.0063544511795044
Epoch 290 / 500 | iteration 15 / 30 | Total Loss: 6.0018463134765625 | KNN Loss: 4.9893293380737305 | BCE Loss: 1.012516975402832
Epoch 290 / 500 | iteration 20 / 30 | Total Loss: 6.024381637573242 | KNN Loss: 4.995099067687988 |

Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.046483993530273 | KNN Loss: 5.026795387268066 | BCE Loss: 1.019688367843628
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.0946784019470215 | KNN Loss: 5.087503433227539 | BCE Loss: 1.0071749687194824
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.039053916931152 | KNN Loss: 4.999898433685303 | BCE Loss: 1.0391556024551392
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.017491817474365 | KNN Loss: 5.005316257476807 | BCE Loss: 1.0121755599975586
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 6.046480655670166 | KNN Loss: 5.0136260986328125 | BCE Loss: 1.0328545570373535
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 6.039165019989014 | KNN Loss: 5.008707046508789 | BCE Loss: 1.0304579734802246
Epoch 301 / 500 | iteration 5 / 30 | Total Loss: 6.046754837036133 | KNN Loss: 5.010242938995361 | BCE Loss: 1.0365121364593506
Epoch 301 / 500 | iteration 10 / 30 | Total Loss: 6.0189924240112305 | KNN Loss: 4.981760025024414 

Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.019783020019531 | KNN Loss: 5.008553504943848 | BCE Loss: 1.0112295150756836
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.032553672790527 | KNN Loss: 4.996654510498047 | BCE Loss: 1.0358991622924805
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.059474468231201 | KNN Loss: 5.0266218185424805 | BCE Loss: 1.0328525304794312
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.071877479553223 | KNN Loss: 5.011009216308594 | BCE Loss: 1.060868501663208
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 6.030112266540527 | KNN Loss: 5.020129680633545 | BCE Loss: 1.0099824666976929
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 6.060830116271973 | KNN Loss: 5.029746055603027 | BCE Loss: 1.0310842990875244
Epoch 311 / 500 | iteration 25 / 30 | Total Loss: 6.013570308685303 | KNN Loss: 4.995748996734619 | BCE Loss: 1.0178213119506836
Epoch 312 / 500 | iteration 0 / 30 | Total Loss: 6.033949851989746 | KNN Loss: 5.020389080047607 | 

Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.021280288696289 | KNN Loss: 5.013329029083252 | BCE Loss: 1.0079514980316162
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.030664443969727 | KNN Loss: 5.000585556030273 | BCE Loss: 1.0300791263580322
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 5.997589111328125 | KNN Loss: 4.997008323669434 | BCE Loss: 1.0005810260772705
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.016063213348389 | KNN Loss: 4.985025405883789 | BCE Loss: 1.03103768825531
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 6.033472061157227 | KNN Loss: 4.96937370300293 | BCE Loss: 1.0640981197357178
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 6.0244293212890625 | KNN Loss: 4.991325378417969 | BCE Loss: 1.0331039428710938
Epoch 322 / 500 | iteration 15 / 30 | Total Loss: 6.00136661529541 | KNN Loss: 4.997824192047119 | BCE Loss: 1.0035425424575806
Epoch 322 / 500 | iteration 20 / 30 | Total Loss: 5.994441509246826 | KNN Loss: 4.995545864105225 | BC

Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.077615737915039 | KNN Loss: 5.045828819274902 | BCE Loss: 1.0317871570587158
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.036493301391602 | KNN Loss: 5.000519275665283 | BCE Loss: 1.0359737873077393
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.121911525726318 | KNN Loss: 5.084228038787842 | BCE Loss: 1.037683367729187
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.02011251449585 | KNN Loss: 5.001737594604492 | BCE Loss: 1.018375039100647
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 6.065311431884766 | KNN Loss: 5.0178117752075195 | BCE Loss: 1.0474998950958252
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 6.021064758300781 | KNN Loss: 4.995584011077881 | BCE Loss: 1.0254809856414795
Epoch 333 / 500 | iteration 5 / 30 | Total Loss: 6.030372619628906 | KNN Loss: 5.007797718048096 | BCE Loss: 1.0225751399993896
Epoch 333 / 500 | iteration 10 / 30 | Total Loss: 6.0334367752075195 | KNN Loss: 5.005491256713867 | B

Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.004743576049805 | KNN Loss: 4.9858174324035645 | BCE Loss: 1.0189261436462402
Epoch   343: reducing learning rate of group 0 to 3.9896e-06.
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.0359086990356445 | KNN Loss: 5.0205888748168945 | BCE Loss: 1.015319585800171
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.067837238311768 | KNN Loss: 5.06294059753418 | BCE Loss: 1.004896640777588
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.021771430969238 | KNN Loss: 4.9852142333984375 | BCE Loss: 1.0365574359893799
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 6.072725296020508 | KNN Loss: 5.056077480316162 | BCE Loss: 1.0166479349136353
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 6.002506732940674 | KNN Loss: 4.998241424560547 | BCE Loss: 1.0042654275894165
Epoch 343 / 500 | iteration 25 / 30 | Total Loss: 6.047629356384277 | KNN Loss: 4.996552467346191 | BCE Loss: 1.0510767698287964
Epoch 344 / 500 | iteration 0 / 30 |

Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.092792510986328 | KNN Loss: 5.040063381195068 | BCE Loss: 1.0527291297912598
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.007585048675537 | KNN Loss: 5.013785362243652 | BCE Loss: 0.9937997460365295
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.024809837341309 | KNN Loss: 5.006701469421387 | BCE Loss: 1.0181083679199219
Epoch   354: reducing learning rate of group 0 to 2.7927e-06.
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.022641181945801 | KNN Loss: 4.999680995941162 | BCE Loss: 1.0229599475860596
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 6.03481912612915 | KNN Loss: 4.995436191558838 | BCE Loss: 1.039382815361023
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 6.034597873687744 | KNN Loss: 4.994095325469971 | BCE Loss: 1.040502667427063
Epoch 354 / 500 | iteration 15 / 30 | Total Loss: 6.003665924072266 | KNN Loss: 4.987577438354492 | BCE Loss: 1.0160882472991943
Epoch 354 / 500 | iteration 20 / 30 | To

Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.010943412780762 | KNN Loss: 4.977572917938232 | BCE Loss: 1.0333702564239502
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.026951789855957 | KNN Loss: 4.998095989227295 | BCE Loss: 1.028855562210083
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.016975402832031 | KNN Loss: 4.996631622314453 | BCE Loss: 1.0203437805175781
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.097758769989014 | KNN Loss: 5.06287956237793 | BCE Loss: 1.0348793268203735
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 6.00969123840332 | KNN Loss: 4.987896919250488 | BCE Loss: 1.0217944383621216
Epoch   365: reducing learning rate of group 0 to 1.9549e-06.
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 6.023990631103516 | KNN Loss: 5.000502586364746 | BCE Loss: 1.0234878063201904
Epoch 365 / 500 | iteration 5 / 30 | Total Loss: 6.050355911254883 | KNN Loss: 4.991152286529541 | BCE Loss: 1.0592033863067627
Epoch 365 / 500 | iteration 10 / 30 | Tot

Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.023799419403076 | KNN Loss: 4.982768535614014 | BCE Loss: 1.0410308837890625
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.081629753112793 | KNN Loss: 5.032063961029053 | BCE Loss: 1.0495657920837402
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.037136077880859 | KNN Loss: 5.00532865524292 | BCE Loss: 1.0318071842193604
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.009263038635254 | KNN Loss: 5.007815837860107 | BCE Loss: 1.0014472007751465
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 6.007588863372803 | KNN Loss: 4.988133907318115 | BCE Loss: 1.0194549560546875
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 6.011289596557617 | KNN Loss: 5.004359245300293 | BCE Loss: 1.0069303512573242
Epoch 375 / 500 | iteration 25 / 30 | Total Loss: 6.03176212310791 | KNN Loss: 4.974371433258057 | BCE Loss: 1.0573906898498535
Epoch   376: reducing learning rate of group 0 to 1.3684e-06.
Epoch 376 / 500 | iteration 0 / 30 | To

Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.037484169006348 | KNN Loss: 5.0340471267700195 | BCE Loss: 1.0034370422363281
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.028408050537109 | KNN Loss: 5.0018391609191895 | BCE Loss: 1.026569128036499
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.029114723205566 | KNN Loss: 5.038174152374268 | BCE Loss: 0.9909406900405884
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.024418830871582 | KNN Loss: 4.993710994720459 | BCE Loss: 1.0307080745697021
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 5.996101379394531 | KNN Loss: 4.991487979888916 | BCE Loss: 1.0046131610870361
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 6.047181129455566 | KNN Loss: 5.006996154785156 | BCE Loss: 1.0401849746704102
Epoch 386 / 500 | iteration 15 / 30 | Total Loss: 6.061966419219971 | KNN Loss: 5.043574333190918 | BCE Loss: 1.0183922052383423
Epoch 386 / 500 | iteration 20 / 30 | Total Loss: 6.0193986892700195 | KNN Loss: 4.989124774932861

Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.026158332824707 | KNN Loss: 4.99375581741333 | BCE Loss: 1.0324022769927979
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.023304462432861 | KNN Loss: 5.019600868225098 | BCE Loss: 1.0037037134170532
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.030067443847656 | KNN Loss: 5.000393390655518 | BCE Loss: 1.0296742916107178
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.047136306762695 | KNN Loss: 5.0218186378479 | BCE Loss: 1.0253174304962158
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 5.991802215576172 | KNN Loss: 4.976236343383789 | BCE Loss: 1.0155659914016724
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 6.072108268737793 | KNN Loss: 5.007017612457275 | BCE Loss: 1.065090537071228
Epoch 397 / 500 | iteration 5 / 30 | Total Loss: 6.001628875732422 | KNN Loss: 4.997427463531494 | BCE Loss: 1.0042016506195068
Epoch 397 / 500 | iteration 10 / 30 | Total Loss: 6.043091297149658 | KNN Loss: 5.0255022048950195 | BCE

Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.00821590423584 | KNN Loss: 4.98282527923584 | BCE Loss: 1.0253905057907104
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.069305419921875 | KNN Loss: 5.0304341316223145 | BCE Loss: 1.038871169090271
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.032289505004883 | KNN Loss: 5.043704986572266 | BCE Loss: 0.9885842800140381
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.060688018798828 | KNN Loss: 5.041126728057861 | BCE Loss: 1.0195614099502563
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 6.018387794494629 | KNN Loss: 4.992721080780029 | BCE Loss: 1.02566659450531
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 6.046230316162109 | KNN Loss: 5.019248008728027 | BCE Loss: 1.0269824266433716
Epoch 407 / 500 | iteration 25 / 30 | Total Loss: 6.075963973999023 | KNN Loss: 5.055668354034424 | BCE Loss: 1.0202953815460205
Epoch 408 / 500 | iteration 0 / 30 | Total Loss: 6.033123016357422 | KNN Loss: 5.0040483474731445 | BCE

Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.099130630493164 | KNN Loss: 5.0404582023620605 | BCE Loss: 1.0586724281311035
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.067590713500977 | KNN Loss: 5.033631801605225 | BCE Loss: 1.033959150314331
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.040345191955566 | KNN Loss: 5.0184431076049805 | BCE Loss: 1.0219018459320068
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.008139610290527 | KNN Loss: 4.979063034057617 | BCE Loss: 1.029076337814331
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 6.053720951080322 | KNN Loss: 5.012526988983154 | BCE Loss: 1.041193962097168
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 6.072053909301758 | KNN Loss: 5.031579971313477 | BCE Loss: 1.0404741764068604
Epoch 418 / 500 | iteration 15 / 30 | Total Loss: 6.044543743133545 | KNN Loss: 5.027968883514404 | BCE Loss: 1.0165748596191406
Epoch 418 / 500 | iteration 20 / 30 | Total Loss: 6.020480155944824 | KNN Loss: 5.001270771026611 | 

Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.044663429260254 | KNN Loss: 5.004051208496094 | BCE Loss: 1.040611982345581
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.009640216827393 | KNN Loss: 4.9978179931640625 | BCE Loss: 1.0118221044540405
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.015079498291016 | KNN Loss: 5.002893447875977 | BCE Loss: 1.0121862888336182
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.072293281555176 | KNN Loss: 5.046955585479736 | BCE Loss: 1.0253374576568604
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 6.008981704711914 | KNN Loss: 4.986256122589111 | BCE Loss: 1.0227253437042236
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 6.040383815765381 | KNN Loss: 5.014544486999512 | BCE Loss: 1.0258394479751587
Epoch 429 / 500 | iteration 5 / 30 | Total Loss: 6.039428234100342 | KNN Loss: 5.008208274841309 | BCE Loss: 1.0312198400497437
Epoch 429 / 500 | iteration 10 / 30 | Total Loss: 6.046314716339111 | KNN Loss: 5.020675182342529 | 

Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.029807090759277 | KNN Loss: 5.002346515655518 | BCE Loss: 1.0274605751037598
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.013274192810059 | KNN Loss: 5.011310577392578 | BCE Loss: 1.0019633769989014
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.071589469909668 | KNN Loss: 5.042746067047119 | BCE Loss: 1.0288431644439697
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.000012397766113 | KNN Loss: 4.99124813079834 | BCE Loss: 1.0087642669677734
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 6.06846809387207 | KNN Loss: 5.021121501922607 | BCE Loss: 1.047346591949463
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 6.0055084228515625 | KNN Loss: 4.982028007507324 | BCE Loss: 1.0234801769256592
Epoch 439 / 500 | iteration 25 / 30 | Total Loss: 6.038318634033203 | KNN Loss: 4.984135627746582 | BCE Loss: 1.0541828870773315
Epoch 440 / 500 | iteration 0 / 30 | Total Loss: 6.065736770629883 | KNN Loss: 5.024531841278076 | BC

Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 5.998149394989014 | KNN Loss: 4.985795974731445 | BCE Loss: 1.0123534202575684
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 5.988480567932129 | KNN Loss: 4.9815239906311035 | BCE Loss: 1.0069568157196045
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 5.9784064292907715 | KNN Loss: 4.990254878997803 | BCE Loss: 0.9881516098976135
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 6.009145259857178 | KNN Loss: 4.9896345138549805 | BCE Loss: 1.0195107460021973
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 6.014015197753906 | KNN Loss: 5.000187873840332 | BCE Loss: 1.0138274431228638
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 6.029138565063477 | KNN Loss: 4.995615482330322 | BCE Loss: 1.0335230827331543
Epoch 450 / 500 | iteration 15 / 30 | Total Loss: 6.070472717285156 | KNN Loss: 5.054712772369385 | BCE Loss: 1.0157597064971924
Epoch 450 / 500 | iteration 20 / 30 | Total Loss: 6.025610446929932 | KNN Loss: 5.00410985946655

Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.076190948486328 | KNN Loss: 5.036505699157715 | BCE Loss: 1.0396854877471924
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 5.98680305480957 | KNN Loss: 4.9798102378845215 | BCE Loss: 1.0069929361343384
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.058855056762695 | KNN Loss: 5.041995048522949 | BCE Loss: 1.016859769821167
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.000916481018066 | KNN Loss: 4.994308948516846 | BCE Loss: 1.0066072940826416
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 6.026651859283447 | KNN Loss: 5.011651992797852 | BCE Loss: 1.0149997472763062
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 6.0110392570495605 | KNN Loss: 4.977228164672852 | BCE Loss: 1.033811092376709
Epoch 461 / 500 | iteration 5 / 30 | Total Loss: 6.058467864990234 | KNN Loss: 5.020451545715332 | BCE Loss: 1.0380160808563232
Epoch 461 / 500 | iteration 10 / 30 | Total Loss: 6.073695182800293 | KNN Loss: 5.017346382141113 | B

Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.022159099578857 | KNN Loss: 4.992762088775635 | BCE Loss: 1.0293970108032227
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.073005676269531 | KNN Loss: 5.009349822998047 | BCE Loss: 1.0636560916900635
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.052948951721191 | KNN Loss: 5.012661933898926 | BCE Loss: 1.0402872562408447
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.033650875091553 | KNN Loss: 5.002420425415039 | BCE Loss: 1.0312303304672241
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 6.046316623687744 | KNN Loss: 5.022641181945801 | BCE Loss: 1.0236754417419434
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 6.02301025390625 | KNN Loss: 5.001128196716309 | BCE Loss: 1.0218818187713623
Epoch 471 / 500 | iteration 25 / 30 | Total Loss: 6.0301618576049805 | KNN Loss: 5.008731365203857 | BCE Loss: 1.021430253982544
Epoch 472 / 500 | iteration 0 / 30 | Total Loss: 6.019319534301758 | KNN Loss: 5.002521991729736 | B

Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.055817127227783 | KNN Loss: 5.0468010902404785 | BCE Loss: 1.0090160369873047
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.0202507972717285 | KNN Loss: 4.999390125274658 | BCE Loss: 1.0208606719970703
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.018528461456299 | KNN Loss: 5.000139236450195 | BCE Loss: 1.018389344215393
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.071181774139404 | KNN Loss: 5.0631022453308105 | BCE Loss: 1.0080796480178833
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 6.069654941558838 | KNN Loss: 5.0339765548706055 | BCE Loss: 1.0356782674789429
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 6.013039588928223 | KNN Loss: 5.013042449951172 | BCE Loss: 0.9999969005584717
Epoch 482 / 500 | iteration 15 / 30 | Total Loss: 6.068719863891602 | KNN Loss: 5.020889759063721 | BCE Loss: 1.0478301048278809
Epoch 482 / 500 | iteration 20 / 30 | Total Loss: 6.03743839263916 | KNN Loss: 4.987823009490967

Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.004672527313232 | KNN Loss: 5.000941276550293 | BCE Loss: 1.003731369972229
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.00396728515625 | KNN Loss: 4.997598648071289 | BCE Loss: 1.00636887550354
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.017963886260986 | KNN Loss: 5.013737201690674 | BCE Loss: 1.004226803779602
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.02263069152832 | KNN Loss: 5.001904487609863 | BCE Loss: 1.020725965499878
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 6.031940460205078 | KNN Loss: 5.015169143676758 | BCE Loss: 1.0167710781097412
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 6.021288871765137 | KNN Loss: 4.998803615570068 | BCE Loss: 1.0224854946136475
Epoch 493 / 500 | iteration 5 / 30 | Total Loss: 6.0339555740356445 | KNN Loss: 4.987903118133545 | BCE Loss: 1.0460525751113892
Epoch 493 / 500 | iteration 10 / 30 | Total Loss: 6.079273223876953 | KNN Loss: 5.01887845993042 | BCE Los

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.2265,  3.8615,  2.5736,  2.9090,  2.8241,  0.6991,  2.6035,  1.6462,
          2.3869,  2.0602,  2.1964,  2.0600,  0.6111,  1.8932,  1.3220,  1.4583,
          2.9208,  2.6505,  2.8708,  2.3284,  1.7294,  3.1018,  2.3613,  2.6003,
          1.9957,  1.8294,  2.0026,  1.2849,  1.4417,  0.4059, -0.1320,  1.0266,
          0.2938,  0.9174,  1.6701,  1.4917,  1.0858,  3.2999,  0.7243,  1.4092,
          0.8491, -0.9957, -0.2098,  2.2923,  2.1144,  0.7912, -0.1759,  0.0579,
          1.4154,  1.9824,  1.9156,  0.2031,  1.2618,  0.4512, -0.5825,  1.0971,
          1.3924,  1.4737,  1.2272,  1.9045,  0.5839,  0.9296,  0.1225,  1.7981,
          1.3353,  1.7348, -1.8844,  0.4145,  2.3761,  2.0298,  2.5776,  0.4516,
          1.4844,  2.5514,  1.8596,  1.3033,  0.3155,  0.7763,  0.1432,  1.6417,
          0.0722,  0.5226,  1.9932, -0.3415,  0.2397, -1.0558, -2.4648, -0.2253,
          0.5830, -1.8698,  0.5544, -0.1350, -0.5117, -0.9409,  0.6519,  1.3184,
         -0.6455, -0.7293,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].to('cpu') for d in dataset]

In [11]:
model = model.eval().to('cpu')
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 88.94it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.1, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
# tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tensor_dataset = torch.stack(dataset_)

In [25]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [26]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [27]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [28]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [29]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [30]:
losses = []
accs = []
sparsity = []

In [31]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 029 | Total loss: 9.575 | Reg loss: 0.007 | Tree loss: 9.575 | Accuracy: 0.000000 | 0.098 sec/iter
Epoch: 00 | Batch: 001 / 029 | Total loss: 9.570 | Reg loss: 0.007 | Tree loss: 9.570 | Accuracy: 0.000000 | 0.086 sec/iter
Epoch: 00 | Batch: 002 / 029 | Total loss: 9.548 | Reg loss: 0.007 | Tree loss: 9.548 | Accuracy: 0.000000 | 0.079 sec/iter
Epoch: 00 | Batch: 003 / 029 | Total loss: 9.531 | Reg loss: 0.007 | Tree loss: 9.531 | Accuracy: 0.000000 | 0.076 sec/iter
Epoch: 00 | Batch: 004 / 029 | Total loss: 9.536 | Reg loss: 0.007 | Tree loss: 9.536 | Accuracy: 0.000000 | 0.074 sec/iter
Epoch: 00 | Batch: 005 / 029 | Total loss: 9.522 | Reg loss: 0.007 | Tree loss: 9.522 | Accuracy: 0.000000 | 0.073 sec/iter
Epoch: 00 | Batch: 006 / 029 | Total loss: 9.507 | Reg loss: 0.007 | Tree loss: 9.507 | Accuracy: 0.000000 | 0.072 sec/iter
Epoch: 00 | Batch: 007 / 029 | Total loss: 

Epoch: 02 | Batch: 007 / 029 | Total loss: 9.179 | Reg loss: 0.007 | Tree loss: 9.179 | Accuracy: 0.238281 | 0.065 sec/iter
Epoch: 02 | Batch: 008 / 029 | Total loss: 9.143 | Reg loss: 0.007 | Tree loss: 9.143 | Accuracy: 0.312500 | 0.065 sec/iter
Epoch: 02 | Batch: 009 / 029 | Total loss: 9.134 | Reg loss: 0.007 | Tree loss: 9.134 | Accuracy: 0.271484 | 0.065 sec/iter
Epoch: 02 | Batch: 010 / 029 | Total loss: 9.140 | Reg loss: 0.008 | Tree loss: 9.140 | Accuracy: 0.259766 | 0.065 sec/iter
Epoch: 02 | Batch: 011 / 029 | Total loss: 9.093 | Reg loss: 0.008 | Tree loss: 9.093 | Accuracy: 0.355469 | 0.065 sec/iter
Epoch: 02 | Batch: 012 / 029 | Total loss: 9.109 | Reg loss: 0.008 | Tree loss: 9.109 | Accuracy: 0.279297 | 0.065 sec/iter
Epoch: 02 | Batch: 013 / 029 | Total loss: 9.074 | Reg loss: 0.008 | Tree loss: 9.074 | Accuracy: 0.304688 | 0.065 sec/iter
Epoch: 02 | Batch: 014 / 029 | Total loss: 9.088 | Reg loss: 0.009 | Tree loss: 9.088 | Accuracy: 0.285156 | 0.065 sec/iter
Epoch: 0

Epoch: 04 | Batch: 013 / 029 | Total loss: 8.747 | Reg loss: 0.012 | Tree loss: 8.747 | Accuracy: 0.314453 | 0.065 sec/iter
Epoch: 04 | Batch: 014 / 029 | Total loss: 8.735 | Reg loss: 0.012 | Tree loss: 8.735 | Accuracy: 0.304688 | 0.065 sec/iter
Epoch: 04 | Batch: 015 / 029 | Total loss: 8.733 | Reg loss: 0.013 | Tree loss: 8.733 | Accuracy: 0.318359 | 0.065 sec/iter
Epoch: 04 | Batch: 016 / 029 | Total loss: 8.729 | Reg loss: 0.013 | Tree loss: 8.729 | Accuracy: 0.294922 | 0.065 sec/iter
Epoch: 04 | Batch: 017 / 029 | Total loss: 8.744 | Reg loss: 0.013 | Tree loss: 8.744 | Accuracy: 0.250000 | 0.065 sec/iter
Epoch: 04 | Batch: 018 / 029 | Total loss: 8.720 | Reg loss: 0.013 | Tree loss: 8.720 | Accuracy: 0.273438 | 0.065 sec/iter
Epoch: 04 | Batch: 019 / 029 | Total loss: 8.696 | Reg loss: 0.014 | Tree loss: 8.696 | Accuracy: 0.285156 | 0.065 sec/iter
Epoch: 04 | Batch: 020 / 029 | Total loss: 8.695 | Reg loss: 0.014 | Tree loss: 8.695 | Accuracy: 0.263672 | 0.065 sec/iter
Epoch: 0

Epoch: 06 | Batch: 020 / 029 | Total loss: 8.309 | Reg loss: 0.018 | Tree loss: 8.309 | Accuracy: 0.283203 | 0.065 sec/iter
Epoch: 06 | Batch: 021 / 029 | Total loss: 8.285 | Reg loss: 0.019 | Tree loss: 8.285 | Accuracy: 0.289062 | 0.065 sec/iter
Epoch: 06 | Batch: 022 / 029 | Total loss: 8.269 | Reg loss: 0.019 | Tree loss: 8.269 | Accuracy: 0.304688 | 0.065 sec/iter
Epoch: 06 | Batch: 023 / 029 | Total loss: 8.266 | Reg loss: 0.019 | Tree loss: 8.266 | Accuracy: 0.281250 | 0.065 sec/iter
Epoch: 06 | Batch: 024 / 029 | Total loss: 8.240 | Reg loss: 0.020 | Tree loss: 8.240 | Accuracy: 0.296875 | 0.065 sec/iter
Epoch: 06 | Batch: 025 / 029 | Total loss: 8.235 | Reg loss: 0.020 | Tree loss: 8.235 | Accuracy: 0.289062 | 0.065 sec/iter
Epoch: 06 | Batch: 026 / 029 | Total loss: 8.226 | Reg loss: 0.020 | Tree loss: 8.226 | Accuracy: 0.289062 | 0.065 sec/iter
Epoch: 06 | Batch: 027 / 029 | Total loss: 8.193 | Reg loss: 0.021 | Tree loss: 8.193 | Accuracy: 0.296875 | 0.065 sec/iter
Epoch: 0

Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 09 | Batch: 000 / 029 | Total loss: 8.103 | Reg loss: 0.019 | Tree loss: 8.103 | Accuracy: 0.267578 | 0.065 sec/iter
Epoch: 09 | Batch: 001 / 029 | Total loss: 8.030 | Reg loss: 0.019 | Tree loss: 8.030 | Accuracy: 0.310547 | 0.065 sec/iter
Epoch: 09 | Batch: 002 / 029 | Total loss: 8.023 | Reg loss: 0.019 | Tree loss: 8.023 | Accuracy: 0.283203 | 0.065 sec/iter
Epoch: 09 | Batch: 003 / 029 | Total loss: 8.029 | Reg loss: 0.020 | Tree loss: 8.029 | Accuracy: 0.287109 | 0.065 sec/iter
Epoch: 09 | Batch: 004 / 029 | Total loss: 7.997 | Reg loss: 0.020 | Tree loss: 7.997 | Accuracy: 0.281250 | 0.065 sec/iter
Epoch: 09 | Batch: 005 / 029 | Total loss: 7.975 | Reg loss: 0.020 | Tree loss: 7.975 | Accuracy: 0.300781 | 0.065 sec/iter
Epoch: 09 | Batch: 006 / 029 | Total loss: 7.964 | Reg loss: 0.020 | Tree los

Epoch: 11 | Batch: 004 / 029 | Total loss: 7.644 | Reg loss: 0.023 | Tree loss: 7.644 | Accuracy: 0.291016 | 0.065 sec/iter
Epoch: 11 | Batch: 005 / 029 | Total loss: 7.589 | Reg loss: 0.023 | Tree loss: 7.589 | Accuracy: 0.304688 | 0.065 sec/iter
Epoch: 11 | Batch: 006 / 029 | Total loss: 7.598 | Reg loss: 0.023 | Tree loss: 7.598 | Accuracy: 0.273438 | 0.065 sec/iter
Epoch: 11 | Batch: 007 / 029 | Total loss: 7.530 | Reg loss: 0.023 | Tree loss: 7.530 | Accuracy: 0.310547 | 0.065 sec/iter
Epoch: 11 | Batch: 008 / 029 | Total loss: 7.551 | Reg loss: 0.023 | Tree loss: 7.551 | Accuracy: 0.283203 | 0.065 sec/iter
Epoch: 11 | Batch: 009 / 029 | Total loss: 7.458 | Reg loss: 0.024 | Tree loss: 7.458 | Accuracy: 0.335938 | 0.065 sec/iter
Epoch: 11 | Batch: 010 / 029 | Total loss: 7.536 | Reg loss: 0.024 | Tree loss: 7.536 | Accuracy: 0.253906 | 0.065 sec/iter
Epoch: 11 | Batch: 011 / 029 | Total loss: 7.469 | Reg loss: 0.024 | Tree loss: 7.469 | Accuracy: 0.322266 | 0.065 sec/iter
Epoch: 1

Epoch: 13 | Batch: 010 / 029 | Total loss: 7.113 | Reg loss: 0.027 | Tree loss: 7.113 | Accuracy: 0.300781 | 0.065 sec/iter
Epoch: 13 | Batch: 011 / 029 | Total loss: 7.048 | Reg loss: 0.027 | Tree loss: 7.048 | Accuracy: 0.316406 | 0.065 sec/iter
Epoch: 13 | Batch: 012 / 029 | Total loss: 7.005 | Reg loss: 0.027 | Tree loss: 7.005 | Accuracy: 0.357422 | 0.065 sec/iter
Epoch: 13 | Batch: 013 / 029 | Total loss: 7.051 | Reg loss: 0.027 | Tree loss: 7.051 | Accuracy: 0.316406 | 0.065 sec/iter
Epoch: 13 | Batch: 014 / 029 | Total loss: 7.042 | Reg loss: 0.028 | Tree loss: 7.042 | Accuracy: 0.304688 | 0.065 sec/iter
Epoch: 13 | Batch: 015 / 029 | Total loss: 7.039 | Reg loss: 0.028 | Tree loss: 7.039 | Accuracy: 0.259766 | 0.065 sec/iter
Epoch: 13 | Batch: 016 / 029 | Total loss: 6.965 | Reg loss: 0.028 | Tree loss: 6.965 | Accuracy: 0.298828 | 0.065 sec/iter
Epoch: 13 | Batch: 017 / 029 | Total loss: 7.004 | Reg loss: 0.028 | Tree loss: 7.004 | Accuracy: 0.257812 | 0.065 sec/iter
Epoch: 1

Epoch: 15 | Batch: 017 / 029 | Total loss: 6.616 | Reg loss: 0.031 | Tree loss: 6.616 | Accuracy: 0.273438 | 0.065 sec/iter
Epoch: 15 | Batch: 018 / 029 | Total loss: 6.590 | Reg loss: 0.031 | Tree loss: 6.590 | Accuracy: 0.263672 | 0.065 sec/iter
Epoch: 15 | Batch: 019 / 029 | Total loss: 6.518 | Reg loss: 0.031 | Tree loss: 6.518 | Accuracy: 0.314453 | 0.065 sec/iter
Epoch: 15 | Batch: 020 / 029 | Total loss: 6.546 | Reg loss: 0.032 | Tree loss: 6.546 | Accuracy: 0.285156 | 0.065 sec/iter
Epoch: 15 | Batch: 021 / 029 | Total loss: 6.489 | Reg loss: 0.032 | Tree loss: 6.489 | Accuracy: 0.279297 | 0.065 sec/iter
Epoch: 15 | Batch: 022 / 029 | Total loss: 6.439 | Reg loss: 0.032 | Tree loss: 6.439 | Accuracy: 0.312500 | 0.065 sec/iter
Epoch: 15 | Batch: 023 / 029 | Total loss: 6.458 | Reg loss: 0.032 | Tree loss: 6.458 | Accuracy: 0.281250 | 0.065 sec/iter
Epoch: 15 | Batch: 024 / 029 | Total loss: 6.446 | Reg loss: 0.033 | Tree loss: 6.446 | Accuracy: 0.273438 | 0.065 sec/iter
Epoch: 1

Epoch: 17 | Batch: 023 / 029 | Total loss: 6.101 | Reg loss: 0.034 | Tree loss: 6.101 | Accuracy: 0.292969 | 0.066 sec/iter
Epoch: 17 | Batch: 024 / 029 | Total loss: 6.082 | Reg loss: 0.035 | Tree loss: 6.082 | Accuracy: 0.277344 | 0.066 sec/iter
Epoch: 17 | Batch: 025 / 029 | Total loss: 6.010 | Reg loss: 0.035 | Tree loss: 6.010 | Accuracy: 0.312500 | 0.066 sec/iter
Epoch: 17 | Batch: 026 / 029 | Total loss: 6.047 | Reg loss: 0.035 | Tree loss: 6.047 | Accuracy: 0.281250 | 0.066 sec/iter
Epoch: 17 | Batch: 027 / 029 | Total loss: 6.012 | Reg loss: 0.035 | Tree loss: 6.012 | Accuracy: 0.285156 | 0.066 sec/iter
Epoch: 17 | Batch: 028 / 029 | Total loss: 6.070 | Reg loss: 0.035 | Tree loss: 6.070 | Accuracy: 0.246753 | 0.066 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 18 | Batch: 000 / 029 | Total loss: 6.433 | Reg loss: 0.031 | Tree los

Epoch: 20 | Batch: 000 / 029 | Total loss: 6.158 | Reg loss: 0.033 | Tree loss: 6.158 | Accuracy: 0.289062 | 0.067 sec/iter
Epoch: 20 | Batch: 001 / 029 | Total loss: 6.072 | Reg loss: 0.033 | Tree loss: 6.072 | Accuracy: 0.343750 | 0.067 sec/iter
Epoch: 20 | Batch: 002 / 029 | Total loss: 6.082 | Reg loss: 0.033 | Tree loss: 6.082 | Accuracy: 0.316406 | 0.067 sec/iter
Epoch: 20 | Batch: 003 / 029 | Total loss: 6.076 | Reg loss: 0.033 | Tree loss: 6.076 | Accuracy: 0.281250 | 0.067 sec/iter
Epoch: 20 | Batch: 004 / 029 | Total loss: 6.013 | Reg loss: 0.033 | Tree loss: 6.013 | Accuracy: 0.294922 | 0.067 sec/iter
Epoch: 20 | Batch: 005 / 029 | Total loss: 6.000 | Reg loss: 0.033 | Tree loss: 6.000 | Accuracy: 0.318359 | 0.067 sec/iter
Epoch: 20 | Batch: 006 / 029 | Total loss: 5.987 | Reg loss: 0.033 | Tree loss: 5.987 | Accuracy: 0.273438 | 0.067 sec/iter
Epoch: 20 | Batch: 007 / 029 | Total loss: 5.995 | Reg loss: 0.034 | Tree loss: 5.995 | Accuracy: 0.277344 | 0.067 sec/iter
Epoch: 2

Epoch: 22 | Batch: 006 / 029 | Total loss: 5.620 | Reg loss: 0.035 | Tree loss: 5.620 | Accuracy: 0.304688 | 0.067 sec/iter
Epoch: 22 | Batch: 007 / 029 | Total loss: 5.629 | Reg loss: 0.035 | Tree loss: 5.629 | Accuracy: 0.283203 | 0.067 sec/iter
Epoch: 22 | Batch: 008 / 029 | Total loss: 5.597 | Reg loss: 0.035 | Tree loss: 5.597 | Accuracy: 0.289062 | 0.067 sec/iter
Epoch: 22 | Batch: 009 / 029 | Total loss: 5.551 | Reg loss: 0.035 | Tree loss: 5.551 | Accuracy: 0.298828 | 0.067 sec/iter
Epoch: 22 | Batch: 010 / 029 | Total loss: 5.572 | Reg loss: 0.035 | Tree loss: 5.572 | Accuracy: 0.294922 | 0.067 sec/iter
Epoch: 22 | Batch: 011 / 029 | Total loss: 5.559 | Reg loss: 0.036 | Tree loss: 5.559 | Accuracy: 0.277344 | 0.067 sec/iter
Epoch: 22 | Batch: 012 / 029 | Total loss: 5.534 | Reg loss: 0.036 | Tree loss: 5.534 | Accuracy: 0.279297 | 0.067 sec/iter
Epoch: 22 | Batch: 013 / 029 | Total loss: 5.516 | Reg loss: 0.036 | Tree loss: 5.516 | Accuracy: 0.275391 | 0.067 sec/iter
Epoch: 2

Epoch: 24 | Batch: 013 / 029 | Total loss: 5.122 | Reg loss: 0.037 | Tree loss: 5.122 | Accuracy: 0.314453 | 0.068 sec/iter
Epoch: 24 | Batch: 014 / 029 | Total loss: 5.132 | Reg loss: 0.037 | Tree loss: 5.132 | Accuracy: 0.289062 | 0.068 sec/iter
Epoch: 24 | Batch: 015 / 029 | Total loss: 5.124 | Reg loss: 0.037 | Tree loss: 5.124 | Accuracy: 0.281250 | 0.068 sec/iter
Epoch: 24 | Batch: 016 / 029 | Total loss: 5.077 | Reg loss: 0.037 | Tree loss: 5.077 | Accuracy: 0.314453 | 0.068 sec/iter
Epoch: 24 | Batch: 017 / 029 | Total loss: 5.069 | Reg loss: 0.038 | Tree loss: 5.069 | Accuracy: 0.283203 | 0.068 sec/iter
Epoch: 24 | Batch: 018 / 029 | Total loss: 5.046 | Reg loss: 0.038 | Tree loss: 5.046 | Accuracy: 0.285156 | 0.068 sec/iter
Epoch: 24 | Batch: 019 / 029 | Total loss: 5.008 | Reg loss: 0.038 | Tree loss: 5.008 | Accuracy: 0.285156 | 0.068 sec/iter
Epoch: 24 | Batch: 020 / 029 | Total loss: 4.983 | Reg loss: 0.038 | Tree loss: 4.983 | Accuracy: 0.296875 | 0.068 sec/iter
Epoch: 2

Epoch: 26 | Batch: 021 / 029 | Total loss: 4.601 | Reg loss: 0.039 | Tree loss: 4.601 | Accuracy: 0.275391 | 0.068 sec/iter
Epoch: 26 | Batch: 022 / 029 | Total loss: 4.621 | Reg loss: 0.039 | Tree loss: 4.621 | Accuracy: 0.287109 | 0.068 sec/iter
Epoch: 26 | Batch: 023 / 029 | Total loss: 4.629 | Reg loss: 0.039 | Tree loss: 4.629 | Accuracy: 0.259766 | 0.068 sec/iter
Epoch: 26 | Batch: 024 / 029 | Total loss: 4.548 | Reg loss: 0.040 | Tree loss: 4.548 | Accuracy: 0.287109 | 0.068 sec/iter
Epoch: 26 | Batch: 025 / 029 | Total loss: 4.512 | Reg loss: 0.040 | Tree loss: 4.512 | Accuracy: 0.314453 | 0.068 sec/iter
Epoch: 26 | Batch: 026 / 029 | Total loss: 4.484 | Reg loss: 0.040 | Tree loss: 4.484 | Accuracy: 0.296875 | 0.068 sec/iter
Epoch: 26 | Batch: 027 / 029 | Total loss: 4.564 | Reg loss: 0.040 | Tree loss: 4.564 | Accuracy: 0.261719 | 0.068 sec/iter
Epoch: 26 | Batch: 028 / 029 | Total loss: 4.455 | Reg loss: 0.040 | Tree loss: 4.455 | Accuracy: 0.285714 | 0.068 sec/iter
Average 

Epoch: 28 | Batch: 028 / 029 | Total loss: 4.104 | Reg loss: 0.040 | Tree loss: 4.104 | Accuracy: 0.311688 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 29 | Batch: 000 / 029 | Total loss: 4.678 | Reg loss: 0.038 | Tree loss: 4.678 | Accuracy: 0.248047 | 0.069 sec/iter
Epoch: 29 | Batch: 001 / 029 | Total loss: 4.560 | Reg loss: 0.038 | Tree loss: 4.560 | Accuracy: 0.316406 | 0.069 sec/iter
Epoch: 29 | Batch: 002 / 029 | Total loss: 4.582 | Reg loss: 0.038 | Tree loss: 4.582 | Accuracy: 0.306641 | 0.069 sec/iter
Epoch: 29 | Batch: 003 / 029 | Total loss: 4.553 | Reg loss: 0.038 | Tree loss: 4.553 | Accuracy: 0.300781 | 0.069 sec/iter
Epoch: 29 | Batch: 004 / 029 | Total loss: 4.582 | Reg loss: 0.038 | Tree loss: 4.582 | Accuracy: 0.314453 | 0.069 sec/iter
Epoch: 29 | Batch: 005 / 029 | Total loss: 4.559 | Reg loss: 0.038 | Tree los

Epoch: 31 | Batch: 005 / 029 | Total loss: 4.205 | Reg loss: 0.038 | Tree loss: 4.205 | Accuracy: 0.312500 | 0.069 sec/iter
Epoch: 31 | Batch: 006 / 029 | Total loss: 4.278 | Reg loss: 0.038 | Tree loss: 4.278 | Accuracy: 0.257812 | 0.069 sec/iter
Epoch: 31 | Batch: 007 / 029 | Total loss: 4.204 | Reg loss: 0.038 | Tree loss: 4.204 | Accuracy: 0.253906 | 0.069 sec/iter
Epoch: 31 | Batch: 008 / 029 | Total loss: 4.137 | Reg loss: 0.038 | Tree loss: 4.137 | Accuracy: 0.296875 | 0.069 sec/iter
Epoch: 31 | Batch: 009 / 029 | Total loss: 4.170 | Reg loss: 0.038 | Tree loss: 4.170 | Accuracy: 0.267578 | 0.069 sec/iter
Epoch: 31 | Batch: 010 / 029 | Total loss: 4.076 | Reg loss: 0.038 | Tree loss: 4.076 | Accuracy: 0.304688 | 0.069 sec/iter
Epoch: 31 | Batch: 011 / 029 | Total loss: 4.068 | Reg loss: 0.038 | Tree loss: 4.068 | Accuracy: 0.287109 | 0.069 sec/iter
Epoch: 31 | Batch: 012 / 029 | Total loss: 4.065 | Reg loss: 0.038 | Tree loss: 4.065 | Accuracy: 0.283203 | 0.069 sec/iter
Epoch: 3

Epoch: 33 | Batch: 013 / 029 | Total loss: 3.791 | Reg loss: 0.038 | Tree loss: 3.791 | Accuracy: 0.302734 | 0.068 sec/iter
Epoch: 33 | Batch: 014 / 029 | Total loss: 3.710 | Reg loss: 0.038 | Tree loss: 3.710 | Accuracy: 0.285156 | 0.068 sec/iter
Epoch: 33 | Batch: 015 / 029 | Total loss: 3.770 | Reg loss: 0.038 | Tree loss: 3.770 | Accuracy: 0.273438 | 0.068 sec/iter
Epoch: 33 | Batch: 016 / 029 | Total loss: 3.694 | Reg loss: 0.038 | Tree loss: 3.694 | Accuracy: 0.283203 | 0.068 sec/iter
Epoch: 33 | Batch: 017 / 029 | Total loss: 3.683 | Reg loss: 0.039 | Tree loss: 3.683 | Accuracy: 0.283203 | 0.068 sec/iter
Epoch: 33 | Batch: 018 / 029 | Total loss: 3.681 | Reg loss: 0.039 | Tree loss: 3.681 | Accuracy: 0.265625 | 0.069 sec/iter
Epoch: 33 | Batch: 019 / 029 | Total loss: 3.552 | Reg loss: 0.039 | Tree loss: 3.552 | Accuracy: 0.326172 | 0.069 sec/iter
Epoch: 33 | Batch: 020 / 029 | Total loss: 3.598 | Reg loss: 0.039 | Tree loss: 3.598 | Accuracy: 0.294922 | 0.069 sec/iter
Epoch: 3

Epoch: 35 | Batch: 019 / 029 | Total loss: 3.397 | Reg loss: 0.038 | Tree loss: 3.397 | Accuracy: 0.279297 | 0.068 sec/iter
Epoch: 35 | Batch: 020 / 029 | Total loss: 3.335 | Reg loss: 0.038 | Tree loss: 3.335 | Accuracy: 0.265625 | 0.068 sec/iter
Epoch: 35 | Batch: 021 / 029 | Total loss: 3.340 | Reg loss: 0.038 | Tree loss: 3.340 | Accuracy: 0.271484 | 0.068 sec/iter
Epoch: 35 | Batch: 022 / 029 | Total loss: 3.278 | Reg loss: 0.038 | Tree loss: 3.278 | Accuracy: 0.296875 | 0.068 sec/iter
Epoch: 35 | Batch: 023 / 029 | Total loss: 3.320 | Reg loss: 0.039 | Tree loss: 3.320 | Accuracy: 0.232422 | 0.068 sec/iter
Epoch: 35 | Batch: 024 / 029 | Total loss: 3.238 | Reg loss: 0.039 | Tree loss: 3.238 | Accuracy: 0.273438 | 0.068 sec/iter
Epoch: 35 | Batch: 025 / 029 | Total loss: 3.271 | Reg loss: 0.039 | Tree loss: 3.271 | Accuracy: 0.253906 | 0.068 sec/iter
Epoch: 35 | Batch: 026 / 029 | Total loss: 3.250 | Reg loss: 0.039 | Tree loss: 3.250 | Accuracy: 0.253906 | 0.068 sec/iter
Epoch: 3

Epoch: 37 | Batch: 025 / 029 | Total loss: 2.992 | Reg loss: 0.039 | Tree loss: 2.992 | Accuracy: 0.279297 | 0.068 sec/iter
Epoch: 37 | Batch: 026 / 029 | Total loss: 2.992 | Reg loss: 0.039 | Tree loss: 2.992 | Accuracy: 0.251953 | 0.068 sec/iter
Epoch: 37 | Batch: 027 / 029 | Total loss: 2.997 | Reg loss: 0.040 | Tree loss: 2.997 | Accuracy: 0.251953 | 0.068 sec/iter
Epoch: 37 | Batch: 028 / 029 | Total loss: 2.969 | Reg loss: 0.040 | Tree loss: 2.969 | Accuracy: 0.233766 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 38 | Batch: 000 / 029 | Total loss: 3.491 | Reg loss: 0.037 | Tree loss: 3.491 | Accuracy: 0.271484 | 0.068 sec/iter
Epoch: 38 | Batch: 001 / 029 | Total loss: 3.488 | Reg loss: 0.037 | Tree loss: 3.488 | Accuracy: 0.265625 | 0.068 sec/iter
Epoch: 38 | Batch: 002 / 029 | Total loss: 3.460 | Reg loss: 0.037 | Tree los

Epoch: 40 | Batch: 001 / 029 | Total loss: 3.224 | Reg loss: 0.038 | Tree loss: 3.224 | Accuracy: 0.320312 | 0.068 sec/iter
Epoch: 40 | Batch: 002 / 029 | Total loss: 3.248 | Reg loss: 0.038 | Tree loss: 3.248 | Accuracy: 0.285156 | 0.068 sec/iter
Epoch: 40 | Batch: 003 / 029 | Total loss: 3.189 | Reg loss: 0.038 | Tree loss: 3.189 | Accuracy: 0.281250 | 0.068 sec/iter
Epoch: 40 | Batch: 004 / 029 | Total loss: 3.177 | Reg loss: 0.038 | Tree loss: 3.177 | Accuracy: 0.255859 | 0.068 sec/iter
Epoch: 40 | Batch: 005 / 029 | Total loss: 3.226 | Reg loss: 0.038 | Tree loss: 3.226 | Accuracy: 0.281250 | 0.068 sec/iter
Epoch: 40 | Batch: 006 / 029 | Total loss: 3.112 | Reg loss: 0.038 | Tree loss: 3.112 | Accuracy: 0.285156 | 0.068 sec/iter
Epoch: 40 | Batch: 007 / 029 | Total loss: 3.049 | Reg loss: 0.038 | Tree loss: 3.049 | Accuracy: 0.283203 | 0.068 sec/iter
Epoch: 40 | Batch: 008 / 029 | Total loss: 3.080 | Reg loss: 0.038 | Tree loss: 3.080 | Accuracy: 0.292969 | 0.068 sec/iter
Epoch: 4

Epoch: 42 | Batch: 010 / 029 | Total loss: 2.832 | Reg loss: 0.039 | Tree loss: 2.832 | Accuracy: 0.291016 | 0.068 sec/iter
Epoch: 42 | Batch: 011 / 029 | Total loss: 2.822 | Reg loss: 0.039 | Tree loss: 2.822 | Accuracy: 0.259766 | 0.068 sec/iter
Epoch: 42 | Batch: 012 / 029 | Total loss: 2.782 | Reg loss: 0.039 | Tree loss: 2.782 | Accuracy: 0.281250 | 0.068 sec/iter
Epoch: 42 | Batch: 013 / 029 | Total loss: 2.811 | Reg loss: 0.039 | Tree loss: 2.811 | Accuracy: 0.294922 | 0.068 sec/iter
Epoch: 42 | Batch: 014 / 029 | Total loss: 2.759 | Reg loss: 0.039 | Tree loss: 2.759 | Accuracy: 0.304688 | 0.068 sec/iter
Epoch: 42 | Batch: 015 / 029 | Total loss: 2.781 | Reg loss: 0.039 | Tree loss: 2.781 | Accuracy: 0.287109 | 0.068 sec/iter
Epoch: 42 | Batch: 016 / 029 | Total loss: 2.756 | Reg loss: 0.039 | Tree loss: 2.756 | Accuracy: 0.291016 | 0.068 sec/iter
Epoch: 42 | Batch: 017 / 029 | Total loss: 2.701 | Reg loss: 0.039 | Tree loss: 2.701 | Accuracy: 0.267578 | 0.068 sec/iter
Epoch: 4

Epoch: 44 | Batch: 019 / 029 | Total loss: 2.516 | Reg loss: 0.040 | Tree loss: 2.516 | Accuracy: 0.287109 | 0.068 sec/iter
Epoch: 44 | Batch: 020 / 029 | Total loss: 2.475 | Reg loss: 0.040 | Tree loss: 2.475 | Accuracy: 0.306641 | 0.068 sec/iter
Epoch: 44 | Batch: 021 / 029 | Total loss: 2.570 | Reg loss: 0.040 | Tree loss: 2.570 | Accuracy: 0.261719 | 0.068 sec/iter
Epoch: 44 | Batch: 022 / 029 | Total loss: 2.485 | Reg loss: 0.040 | Tree loss: 2.485 | Accuracy: 0.287109 | 0.068 sec/iter
Epoch: 44 | Batch: 023 / 029 | Total loss: 2.523 | Reg loss: 0.040 | Tree loss: 2.523 | Accuracy: 0.269531 | 0.068 sec/iter
Epoch: 44 | Batch: 024 / 029 | Total loss: 2.490 | Reg loss: 0.040 | Tree loss: 2.490 | Accuracy: 0.273438 | 0.068 sec/iter
Epoch: 44 | Batch: 025 / 029 | Total loss: 2.512 | Reg loss: 0.040 | Tree loss: 2.512 | Accuracy: 0.271484 | 0.068 sec/iter
Epoch: 44 | Batch: 026 / 029 | Total loss: 2.457 | Reg loss: 0.040 | Tree loss: 2.457 | Accuracy: 0.291016 | 0.068 sec/iter
Epoch: 4

Epoch: 46 | Batch: 028 / 029 | Total loss: 2.317 | Reg loss: 0.041 | Tree loss: 2.317 | Accuracy: 0.298701 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 47 | Batch: 000 / 029 | Total loss: 2.826 | Reg loss: 0.039 | Tree loss: 2.826 | Accuracy: 0.269531 | 0.068 sec/iter
Epoch: 47 | Batch: 001 / 029 | Total loss: 2.728 | Reg loss: 0.039 | Tree loss: 2.728 | Accuracy: 0.318359 | 0.068 sec/iter
Epoch: 47 | Batch: 002 / 029 | Total loss: 2.739 | Reg loss: 0.039 | Tree loss: 2.739 | Accuracy: 0.271484 | 0.068 sec/iter
Epoch: 47 | Batch: 003 / 029 | Total loss: 2.699 | Reg loss: 0.039 | Tree loss: 2.699 | Accuracy: 0.291016 | 0.068 sec/iter
Epoch: 47 | Batch: 004 / 029 | Total loss: 2.698 | Reg loss: 0.039 | Tree loss: 2.698 | Accuracy: 0.310547 | 0.068 sec/iter
Epoch: 47 | Batch: 005 / 029 | Total loss: 2.680 | Reg loss: 0.039 | Tree los

Epoch: 49 | Batch: 004 / 029 | Total loss: 2.581 | Reg loss: 0.039 | Tree loss: 2.581 | Accuracy: 0.324219 | 0.069 sec/iter
Epoch: 49 | Batch: 005 / 029 | Total loss: 2.575 | Reg loss: 0.039 | Tree loss: 2.575 | Accuracy: 0.302734 | 0.069 sec/iter
Epoch: 49 | Batch: 006 / 029 | Total loss: 2.596 | Reg loss: 0.039 | Tree loss: 2.596 | Accuracy: 0.273438 | 0.069 sec/iter
Epoch: 49 | Batch: 007 / 029 | Total loss: 2.556 | Reg loss: 0.039 | Tree loss: 2.556 | Accuracy: 0.281250 | 0.069 sec/iter
Epoch: 49 | Batch: 008 / 029 | Total loss: 2.536 | Reg loss: 0.039 | Tree loss: 2.536 | Accuracy: 0.267578 | 0.069 sec/iter
Epoch: 49 | Batch: 009 / 029 | Total loss: 2.548 | Reg loss: 0.039 | Tree loss: 2.548 | Accuracy: 0.298828 | 0.069 sec/iter
Epoch: 49 | Batch: 010 / 029 | Total loss: 2.530 | Reg loss: 0.039 | Tree loss: 2.530 | Accuracy: 0.246094 | 0.069 sec/iter
Epoch: 49 | Batch: 011 / 029 | Total loss: 2.514 | Reg loss: 0.039 | Tree loss: 2.514 | Accuracy: 0.263672 | 0.069 sec/iter
Epoch: 4

Epoch: 51 | Batch: 010 / 029 | Total loss: 2.366 | Reg loss: 0.039 | Tree loss: 2.366 | Accuracy: 0.291016 | 0.069 sec/iter
Epoch: 51 | Batch: 011 / 029 | Total loss: 2.424 | Reg loss: 0.039 | Tree loss: 2.424 | Accuracy: 0.273438 | 0.069 sec/iter
Epoch: 51 | Batch: 012 / 029 | Total loss: 2.386 | Reg loss: 0.039 | Tree loss: 2.386 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 51 | Batch: 013 / 029 | Total loss: 2.444 | Reg loss: 0.039 | Tree loss: 2.444 | Accuracy: 0.285156 | 0.069 sec/iter
Epoch: 51 | Batch: 014 / 029 | Total loss: 2.329 | Reg loss: 0.039 | Tree loss: 2.329 | Accuracy: 0.314453 | 0.069 sec/iter
Epoch: 51 | Batch: 015 / 029 | Total loss: 2.340 | Reg loss: 0.039 | Tree loss: 2.340 | Accuracy: 0.248047 | 0.069 sec/iter
Epoch: 51 | Batch: 016 / 029 | Total loss: 2.330 | Reg loss: 0.039 | Tree loss: 2.330 | Accuracy: 0.253906 | 0.069 sec/iter
Epoch: 51 | Batch: 017 / 029 | Total loss: 2.329 | Reg loss: 0.039 | Tree loss: 2.329 | Accuracy: 0.261719 | 0.069 sec/iter
Epoch: 5

Epoch: 53 | Batch: 016 / 029 | Total loss: 2.278 | Reg loss: 0.039 | Tree loss: 2.278 | Accuracy: 0.267578 | 0.069 sec/iter
Epoch: 53 | Batch: 017 / 029 | Total loss: 2.253 | Reg loss: 0.039 | Tree loss: 2.253 | Accuracy: 0.308594 | 0.069 sec/iter
Epoch: 53 | Batch: 018 / 029 | Total loss: 2.193 | Reg loss: 0.039 | Tree loss: 2.193 | Accuracy: 0.265625 | 0.069 sec/iter
Epoch: 53 | Batch: 019 / 029 | Total loss: 2.191 | Reg loss: 0.039 | Tree loss: 2.191 | Accuracy: 0.287109 | 0.069 sec/iter
Epoch: 53 | Batch: 020 / 029 | Total loss: 2.241 | Reg loss: 0.039 | Tree loss: 2.241 | Accuracy: 0.314453 | 0.069 sec/iter
Epoch: 53 | Batch: 021 / 029 | Total loss: 2.190 | Reg loss: 0.040 | Tree loss: 2.190 | Accuracy: 0.279297 | 0.069 sec/iter
Epoch: 53 | Batch: 022 / 029 | Total loss: 2.219 | Reg loss: 0.040 | Tree loss: 2.219 | Accuracy: 0.259766 | 0.069 sec/iter
Epoch: 53 | Batch: 023 / 029 | Total loss: 2.168 | Reg loss: 0.040 | Tree loss: 2.168 | Accuracy: 0.273438 | 0.069 sec/iter
Epoch: 5

Epoch: 55 | Batch: 024 / 029 | Total loss: 2.049 | Reg loss: 0.040 | Tree loss: 2.049 | Accuracy: 0.279297 | 0.07 sec/iter
Epoch: 55 | Batch: 025 / 029 | Total loss: 2.128 | Reg loss: 0.040 | Tree loss: 2.128 | Accuracy: 0.296875 | 0.07 sec/iter
Epoch: 55 | Batch: 026 / 029 | Total loss: 2.114 | Reg loss: 0.040 | Tree loss: 2.114 | Accuracy: 0.318359 | 0.07 sec/iter
Epoch: 55 | Batch: 027 / 029 | Total loss: 2.118 | Reg loss: 0.040 | Tree loss: 2.118 | Accuracy: 0.279297 | 0.07 sec/iter
Epoch: 55 | Batch: 028 / 029 | Total loss: 2.039 | Reg loss: 0.040 | Tree loss: 2.039 | Accuracy: 0.279221 | 0.07 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 56 | Batch: 000 / 029 | Total loss: 2.436 | Reg loss: 0.038 | Tree loss: 2.436 | Accuracy: 0.302734 | 0.07 sec/iter
Epoch: 56 | Batch: 001 / 029 | Total loss: 2.459 | Reg loss: 0.038 | Tree loss: 2.4

Epoch: 58 | Batch: 003 / 029 | Total loss: 2.462 | Reg loss: 0.038 | Tree loss: 2.462 | Accuracy: 0.296875 | 0.07 sec/iter
Epoch: 58 | Batch: 004 / 029 | Total loss: 2.421 | Reg loss: 0.038 | Tree loss: 2.421 | Accuracy: 0.253906 | 0.07 sec/iter
Epoch: 58 | Batch: 005 / 029 | Total loss: 2.349 | Reg loss: 0.038 | Tree loss: 2.349 | Accuracy: 0.296875 | 0.07 sec/iter
Epoch: 58 | Batch: 006 / 029 | Total loss: 2.381 | Reg loss: 0.038 | Tree loss: 2.381 | Accuracy: 0.292969 | 0.07 sec/iter
Epoch: 58 | Batch: 007 / 029 | Total loss: 2.245 | Reg loss: 0.038 | Tree loss: 2.245 | Accuracy: 0.318359 | 0.07 sec/iter
Epoch: 58 | Batch: 008 / 029 | Total loss: 2.234 | Reg loss: 0.038 | Tree loss: 2.234 | Accuracy: 0.316406 | 0.07 sec/iter
Epoch: 58 | Batch: 009 / 029 | Total loss: 2.215 | Reg loss: 0.038 | Tree loss: 2.215 | Accuracy: 0.304688 | 0.07 sec/iter
Epoch: 58 | Batch: 010 / 029 | Total loss: 2.209 | Reg loss: 0.038 | Tree loss: 2.209 | Accuracy: 0.294922 | 0.07 sec/iter
Epoch: 58 | Batc

Epoch: 60 | Batch: 009 / 029 | Total loss: 2.235 | Reg loss: 0.038 | Tree loss: 2.235 | Accuracy: 0.298828 | 0.069 sec/iter
Epoch: 60 | Batch: 010 / 029 | Total loss: 2.153 | Reg loss: 0.038 | Tree loss: 2.153 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 60 | Batch: 011 / 029 | Total loss: 2.250 | Reg loss: 0.038 | Tree loss: 2.250 | Accuracy: 0.257812 | 0.069 sec/iter
Epoch: 60 | Batch: 012 / 029 | Total loss: 2.259 | Reg loss: 0.038 | Tree loss: 2.259 | Accuracy: 0.259766 | 0.069 sec/iter
Epoch: 60 | Batch: 013 / 029 | Total loss: 2.180 | Reg loss: 0.038 | Tree loss: 2.180 | Accuracy: 0.296875 | 0.069 sec/iter
Epoch: 60 | Batch: 014 / 029 | Total loss: 2.112 | Reg loss: 0.038 | Tree loss: 2.112 | Accuracy: 0.265625 | 0.069 sec/iter
Epoch: 60 | Batch: 015 / 029 | Total loss: 2.193 | Reg loss: 0.038 | Tree loss: 2.193 | Accuracy: 0.251953 | 0.069 sec/iter
Epoch: 60 | Batch: 016 / 029 | Total loss: 2.164 | Reg loss: 0.038 | Tree loss: 2.164 | Accuracy: 0.294922 | 0.069 sec/iter
Epoch: 6

Epoch: 62 | Batch: 016 / 029 | Total loss: 2.075 | Reg loss: 0.038 | Tree loss: 2.075 | Accuracy: 0.246094 | 0.069 sec/iter
Epoch: 62 | Batch: 017 / 029 | Total loss: 2.143 | Reg loss: 0.038 | Tree loss: 2.143 | Accuracy: 0.273438 | 0.069 sec/iter
Epoch: 62 | Batch: 018 / 029 | Total loss: 2.061 | Reg loss: 0.038 | Tree loss: 2.061 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 62 | Batch: 019 / 029 | Total loss: 2.037 | Reg loss: 0.038 | Tree loss: 2.037 | Accuracy: 0.308594 | 0.069 sec/iter
Epoch: 62 | Batch: 020 / 029 | Total loss: 2.058 | Reg loss: 0.038 | Tree loss: 2.058 | Accuracy: 0.304688 | 0.069 sec/iter
Epoch: 62 | Batch: 021 / 029 | Total loss: 2.098 | Reg loss: 0.039 | Tree loss: 2.098 | Accuracy: 0.277344 | 0.069 sec/iter
Epoch: 62 | Batch: 022 / 029 | Total loss: 2.120 | Reg loss: 0.039 | Tree loss: 2.120 | Accuracy: 0.285156 | 0.069 sec/iter
Epoch: 62 | Batch: 023 / 029 | Total loss: 2.020 | Reg loss: 0.039 | Tree loss: 2.020 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 6

Epoch: 64 | Batch: 024 / 029 | Total loss: 2.003 | Reg loss: 0.039 | Tree loss: 2.003 | Accuracy: 0.291016 | 0.069 sec/iter
Epoch: 64 | Batch: 025 / 029 | Total loss: 1.973 | Reg loss: 0.039 | Tree loss: 1.973 | Accuracy: 0.298828 | 0.069 sec/iter
Epoch: 64 | Batch: 026 / 029 | Total loss: 1.975 | Reg loss: 0.039 | Tree loss: 1.975 | Accuracy: 0.310547 | 0.069 sec/iter
Epoch: 64 | Batch: 027 / 029 | Total loss: 1.965 | Reg loss: 0.039 | Tree loss: 1.965 | Accuracy: 0.275391 | 0.069 sec/iter
Epoch: 64 | Batch: 028 / 029 | Total loss: 1.872 | Reg loss: 0.039 | Tree loss: 1.872 | Accuracy: 0.370130 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 65 | Batch: 000 / 029 | Total loss: 2.308 | Reg loss: 0.037 | Tree loss: 2.308 | Accuracy: 0.332031 | 0.069 sec/iter
Epoch: 65 | Batch: 001 / 029 | Total loss: 2.289 | Reg loss: 0.037 | Tree los

Epoch: 67 | Batch: 001 / 029 | Total loss: 2.277 | Reg loss: 0.037 | Tree loss: 2.277 | Accuracy: 0.310547 | 0.07 sec/iter
Epoch: 67 | Batch: 002 / 029 | Total loss: 2.273 | Reg loss: 0.037 | Tree loss: 2.273 | Accuracy: 0.269531 | 0.07 sec/iter
Epoch: 67 | Batch: 003 / 029 | Total loss: 2.274 | Reg loss: 0.037 | Tree loss: 2.274 | Accuracy: 0.277344 | 0.07 sec/iter
Epoch: 67 | Batch: 004 / 029 | Total loss: 2.231 | Reg loss: 0.037 | Tree loss: 2.231 | Accuracy: 0.314453 | 0.07 sec/iter
Epoch: 67 | Batch: 005 / 029 | Total loss: 2.248 | Reg loss: 0.037 | Tree loss: 2.248 | Accuracy: 0.265625 | 0.07 sec/iter
Epoch: 67 | Batch: 006 / 029 | Total loss: 2.153 | Reg loss: 0.037 | Tree loss: 2.153 | Accuracy: 0.306641 | 0.07 sec/iter
Epoch: 67 | Batch: 007 / 029 | Total loss: 2.191 | Reg loss: 0.037 | Tree loss: 2.191 | Accuracy: 0.294922 | 0.07 sec/iter
Epoch: 67 | Batch: 008 / 029 | Total loss: 2.145 | Reg loss: 0.037 | Tree loss: 2.145 | Accuracy: 0.328125 | 0.07 sec/iter
Epoch: 67 | Batc

Epoch: 69 | Batch: 007 / 029 | Total loss: 2.204 | Reg loss: 0.037 | Tree loss: 2.204 | Accuracy: 0.330078 | 0.07 sec/iter
Epoch: 69 | Batch: 008 / 029 | Total loss: 2.173 | Reg loss: 0.037 | Tree loss: 2.173 | Accuracy: 0.269531 | 0.07 sec/iter
Epoch: 69 | Batch: 009 / 029 | Total loss: 2.186 | Reg loss: 0.037 | Tree loss: 2.186 | Accuracy: 0.277344 | 0.07 sec/iter
Epoch: 69 | Batch: 010 / 029 | Total loss: 2.156 | Reg loss: 0.037 | Tree loss: 2.156 | Accuracy: 0.269531 | 0.07 sec/iter
Epoch: 69 | Batch: 011 / 029 | Total loss: 2.107 | Reg loss: 0.037 | Tree loss: 2.107 | Accuracy: 0.300781 | 0.07 sec/iter
Epoch: 69 | Batch: 012 / 029 | Total loss: 2.091 | Reg loss: 0.037 | Tree loss: 2.091 | Accuracy: 0.287109 | 0.07 sec/iter
Epoch: 69 | Batch: 013 / 029 | Total loss: 2.087 | Reg loss: 0.037 | Tree loss: 2.087 | Accuracy: 0.291016 | 0.07 sec/iter
Epoch: 69 | Batch: 014 / 029 | Total loss: 2.024 | Reg loss: 0.037 | Tree loss: 2.024 | Accuracy: 0.294922 | 0.07 sec/iter
Epoch: 69 | Batc

Epoch: 71 | Batch: 014 / 029 | Total loss: 2.043 | Reg loss: 0.037 | Tree loss: 2.043 | Accuracy: 0.308594 | 0.07 sec/iter
Epoch: 71 | Batch: 015 / 029 | Total loss: 2.041 | Reg loss: 0.037 | Tree loss: 2.041 | Accuracy: 0.296875 | 0.07 sec/iter
Epoch: 71 | Batch: 016 / 029 | Total loss: 2.050 | Reg loss: 0.037 | Tree loss: 2.050 | Accuracy: 0.291016 | 0.07 sec/iter
Epoch: 71 | Batch: 017 / 029 | Total loss: 2.050 | Reg loss: 0.037 | Tree loss: 2.050 | Accuracy: 0.314453 | 0.07 sec/iter
Epoch: 71 | Batch: 018 / 029 | Total loss: 1.983 | Reg loss: 0.037 | Tree loss: 1.983 | Accuracy: 0.306641 | 0.07 sec/iter
Epoch: 71 | Batch: 019 / 029 | Total loss: 1.971 | Reg loss: 0.038 | Tree loss: 1.971 | Accuracy: 0.275391 | 0.07 sec/iter
Epoch: 71 | Batch: 020 / 029 | Total loss: 2.008 | Reg loss: 0.038 | Tree loss: 2.008 | Accuracy: 0.277344 | 0.07 sec/iter
Epoch: 71 | Batch: 021 / 029 | Total loss: 1.967 | Reg loss: 0.038 | Tree loss: 1.967 | Accuracy: 0.289062 | 0.07 sec/iter
Epoch: 71 | Batc

Epoch: 73 | Batch: 021 / 029 | Total loss: 1.954 | Reg loss: 0.038 | Tree loss: 1.954 | Accuracy: 0.292969 | 0.07 sec/iter
Epoch: 73 | Batch: 022 / 029 | Total loss: 1.944 | Reg loss: 0.038 | Tree loss: 1.944 | Accuracy: 0.298828 | 0.07 sec/iter
Epoch: 73 | Batch: 023 / 029 | Total loss: 1.920 | Reg loss: 0.038 | Tree loss: 1.920 | Accuracy: 0.263672 | 0.07 sec/iter
Epoch: 73 | Batch: 024 / 029 | Total loss: 1.949 | Reg loss: 0.038 | Tree loss: 1.949 | Accuracy: 0.310547 | 0.07 sec/iter
Epoch: 73 | Batch: 025 / 029 | Total loss: 1.906 | Reg loss: 0.038 | Tree loss: 1.906 | Accuracy: 0.273438 | 0.07 sec/iter
Epoch: 73 | Batch: 026 / 029 | Total loss: 1.951 | Reg loss: 0.038 | Tree loss: 1.951 | Accuracy: 0.269531 | 0.07 sec/iter
Epoch: 73 | Batch: 027 / 029 | Total loss: 1.941 | Reg loss: 0.038 | Tree loss: 1.941 | Accuracy: 0.308594 | 0.07 sec/iter
Epoch: 73 | Batch: 028 / 029 | Total loss: 1.835 | Reg loss: 0.038 | Tree loss: 1.835 | Accuracy: 0.363636 | 0.07 sec/iter
Average sparsene

Epoch: 76 | Batch: 000 / 029 | Total loss: 2.234 | Reg loss: 0.036 | Tree loss: 2.234 | Accuracy: 0.271484 | 0.07 sec/iter
Epoch: 76 | Batch: 001 / 029 | Total loss: 2.229 | Reg loss: 0.036 | Tree loss: 2.229 | Accuracy: 0.277344 | 0.07 sec/iter
Epoch: 76 | Batch: 002 / 029 | Total loss: 2.255 | Reg loss: 0.036 | Tree loss: 2.255 | Accuracy: 0.292969 | 0.07 sec/iter
Epoch: 76 | Batch: 003 / 029 | Total loss: 2.186 | Reg loss: 0.036 | Tree loss: 2.186 | Accuracy: 0.291016 | 0.07 sec/iter
Epoch: 76 | Batch: 004 / 029 | Total loss: 2.161 | Reg loss: 0.036 | Tree loss: 2.161 | Accuracy: 0.277344 | 0.07 sec/iter
Epoch: 76 | Batch: 005 / 029 | Total loss: 2.160 | Reg loss: 0.036 | Tree loss: 2.160 | Accuracy: 0.265625 | 0.07 sec/iter
Epoch: 76 | Batch: 006 / 029 | Total loss: 2.195 | Reg loss: 0.036 | Tree loss: 2.195 | Accuracy: 0.263672 | 0.07 sec/iter
Epoch: 76 | Batch: 007 / 029 | Total loss: 2.065 | Reg loss: 0.036 | Tree loss: 2.065 | Accuracy: 0.298828 | 0.07 sec/iter
Epoch: 76 | Batc

Epoch: 78 | Batch: 007 / 029 | Total loss: 2.089 | Reg loss: 0.036 | Tree loss: 2.089 | Accuracy: 0.296875 | 0.07 sec/iter
Epoch: 78 | Batch: 008 / 029 | Total loss: 2.128 | Reg loss: 0.036 | Tree loss: 2.128 | Accuracy: 0.316406 | 0.07 sec/iter
Epoch: 78 | Batch: 009 / 029 | Total loss: 2.088 | Reg loss: 0.036 | Tree loss: 2.088 | Accuracy: 0.294922 | 0.07 sec/iter
Epoch: 78 | Batch: 010 / 029 | Total loss: 2.108 | Reg loss: 0.036 | Tree loss: 2.108 | Accuracy: 0.259766 | 0.07 sec/iter
Epoch: 78 | Batch: 011 / 029 | Total loss: 2.018 | Reg loss: 0.036 | Tree loss: 2.018 | Accuracy: 0.289062 | 0.07 sec/iter
Epoch: 78 | Batch: 012 / 029 | Total loss: 2.024 | Reg loss: 0.036 | Tree loss: 2.024 | Accuracy: 0.289062 | 0.07 sec/iter
Epoch: 78 | Batch: 013 / 029 | Total loss: 2.080 | Reg loss: 0.037 | Tree loss: 2.080 | Accuracy: 0.294922 | 0.07 sec/iter
Epoch: 78 | Batch: 014 / 029 | Total loss: 1.927 | Reg loss: 0.037 | Tree loss: 1.927 | Accuracy: 0.281250 | 0.07 sec/iter
Epoch: 78 | Batc

Epoch: 80 | Batch: 014 / 029 | Total loss: 2.007 | Reg loss: 0.037 | Tree loss: 2.007 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 80 | Batch: 015 / 029 | Total loss: 1.980 | Reg loss: 0.037 | Tree loss: 1.980 | Accuracy: 0.302734 | 0.069 sec/iter
Epoch: 80 | Batch: 016 / 029 | Total loss: 1.985 | Reg loss: 0.037 | Tree loss: 1.985 | Accuracy: 0.281250 | 0.069 sec/iter
Epoch: 80 | Batch: 017 / 029 | Total loss: 2.043 | Reg loss: 0.037 | Tree loss: 2.043 | Accuracy: 0.291016 | 0.069 sec/iter
Epoch: 80 | Batch: 018 / 029 | Total loss: 1.883 | Reg loss: 0.037 | Tree loss: 1.883 | Accuracy: 0.298828 | 0.069 sec/iter
Epoch: 80 | Batch: 019 / 029 | Total loss: 1.971 | Reg loss: 0.037 | Tree loss: 1.971 | Accuracy: 0.310547 | 0.069 sec/iter
Epoch: 80 | Batch: 020 / 029 | Total loss: 1.914 | Reg loss: 0.037 | Tree loss: 1.914 | Accuracy: 0.289062 | 0.069 sec/iter
Epoch: 80 | Batch: 021 / 029 | Total loss: 1.958 | Reg loss: 0.037 | Tree loss: 1.958 | Accuracy: 0.322266 | 0.069 sec/iter
Epoch: 8

Epoch: 82 | Batch: 021 / 029 | Total loss: 1.901 | Reg loss: 0.037 | Tree loss: 1.901 | Accuracy: 0.320312 | 0.069 sec/iter
Epoch: 82 | Batch: 022 / 029 | Total loss: 1.861 | Reg loss: 0.037 | Tree loss: 1.861 | Accuracy: 0.271484 | 0.069 sec/iter
Epoch: 82 | Batch: 023 / 029 | Total loss: 1.932 | Reg loss: 0.037 | Tree loss: 1.932 | Accuracy: 0.257812 | 0.069 sec/iter
Epoch: 82 | Batch: 024 / 029 | Total loss: 1.813 | Reg loss: 0.037 | Tree loss: 1.813 | Accuracy: 0.304688 | 0.069 sec/iter
Epoch: 82 | Batch: 025 / 029 | Total loss: 1.832 | Reg loss: 0.037 | Tree loss: 1.832 | Accuracy: 0.277344 | 0.069 sec/iter
Epoch: 82 | Batch: 026 / 029 | Total loss: 1.826 | Reg loss: 0.037 | Tree loss: 1.826 | Accuracy: 0.306641 | 0.069 sec/iter
Epoch: 82 | Batch: 027 / 029 | Total loss: 1.892 | Reg loss: 0.037 | Tree loss: 1.892 | Accuracy: 0.296875 | 0.069 sec/iter
Epoch: 82 | Batch: 028 / 029 | Total loss: 1.757 | Reg loss: 0.037 | Tree loss: 1.757 | Accuracy: 0.240260 | 0.069 sec/iter
Average 

Epoch: 84 | Batch: 027 / 029 | Total loss: 1.881 | Reg loss: 0.037 | Tree loss: 1.881 | Accuracy: 0.265625 | 0.069 sec/iter
Epoch: 84 | Batch: 028 / 029 | Total loss: 1.907 | Reg loss: 0.037 | Tree loss: 1.907 | Accuracy: 0.337662 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 85 | Batch: 000 / 029 | Total loss: 2.211 | Reg loss: 0.036 | Tree loss: 2.211 | Accuracy: 0.300781 | 0.069 sec/iter
Epoch: 85 | Batch: 001 / 029 | Total loss: 2.144 | Reg loss: 0.036 | Tree loss: 2.144 | Accuracy: 0.298828 | 0.069 sec/iter
Epoch: 85 | Batch: 002 / 029 | Total loss: 2.231 | Reg loss: 0.036 | Tree loss: 2.231 | Accuracy: 0.279297 | 0.069 sec/iter
Epoch: 85 | Batch: 003 / 029 | Total loss: 2.141 | Reg loss: 0.036 | Tree loss: 2.141 | Accuracy: 0.267578 | 0.069 sec/iter
Epoch: 85 | Batch: 004 / 029 | Total loss: 2.095 | Reg loss: 0.036 | Tree los

Epoch: 87 | Batch: 003 / 029 | Total loss: 2.112 | Reg loss: 0.035 | Tree loss: 2.112 | Accuracy: 0.287109 | 0.069 sec/iter
Epoch: 87 | Batch: 004 / 029 | Total loss: 2.191 | Reg loss: 0.035 | Tree loss: 2.191 | Accuracy: 0.257812 | 0.069 sec/iter
Epoch: 87 | Batch: 005 / 029 | Total loss: 2.093 | Reg loss: 0.036 | Tree loss: 2.093 | Accuracy: 0.306641 | 0.069 sec/iter
Epoch: 87 | Batch: 006 / 029 | Total loss: 2.161 | Reg loss: 0.036 | Tree loss: 2.161 | Accuracy: 0.263672 | 0.069 sec/iter
Epoch: 87 | Batch: 007 / 029 | Total loss: 2.027 | Reg loss: 0.036 | Tree loss: 2.027 | Accuracy: 0.285156 | 0.069 sec/iter
Epoch: 87 | Batch: 008 / 029 | Total loss: 1.975 | Reg loss: 0.036 | Tree loss: 1.975 | Accuracy: 0.287109 | 0.069 sec/iter
Epoch: 87 | Batch: 009 / 029 | Total loss: 2.075 | Reg loss: 0.036 | Tree loss: 2.075 | Accuracy: 0.283203 | 0.069 sec/iter
Epoch: 87 | Batch: 010 / 029 | Total loss: 2.020 | Reg loss: 0.036 | Tree loss: 2.020 | Accuracy: 0.265625 | 0.069 sec/iter
Epoch: 8

Epoch: 89 | Batch: 011 / 029 | Total loss: 1.959 | Reg loss: 0.036 | Tree loss: 1.959 | Accuracy: 0.322266 | 0.069 sec/iter
Epoch: 89 | Batch: 012 / 029 | Total loss: 1.968 | Reg loss: 0.036 | Tree loss: 1.968 | Accuracy: 0.322266 | 0.069 sec/iter
Epoch: 89 | Batch: 013 / 029 | Total loss: 1.996 | Reg loss: 0.036 | Tree loss: 1.996 | Accuracy: 0.259766 | 0.069 sec/iter
Epoch: 89 | Batch: 014 / 029 | Total loss: 1.982 | Reg loss: 0.036 | Tree loss: 1.982 | Accuracy: 0.255859 | 0.069 sec/iter
Epoch: 89 | Batch: 015 / 029 | Total loss: 1.949 | Reg loss: 0.036 | Tree loss: 1.949 | Accuracy: 0.279297 | 0.069 sec/iter
Epoch: 89 | Batch: 016 / 029 | Total loss: 1.917 | Reg loss: 0.036 | Tree loss: 1.917 | Accuracy: 0.296875 | 0.069 sec/iter
Epoch: 89 | Batch: 017 / 029 | Total loss: 1.956 | Reg loss: 0.036 | Tree loss: 1.956 | Accuracy: 0.283203 | 0.069 sec/iter
Epoch: 89 | Batch: 018 / 029 | Total loss: 1.957 | Reg loss: 0.036 | Tree loss: 1.957 | Accuracy: 0.283203 | 0.069 sec/iter
Epoch: 8

Epoch: 91 | Batch: 019 / 029 | Total loss: 1.930 | Reg loss: 0.036 | Tree loss: 1.930 | Accuracy: 0.306641 | 0.069 sec/iter
Epoch: 91 | Batch: 020 / 029 | Total loss: 1.907 | Reg loss: 0.036 | Tree loss: 1.907 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 91 | Batch: 021 / 029 | Total loss: 1.961 | Reg loss: 0.036 | Tree loss: 1.961 | Accuracy: 0.289062 | 0.069 sec/iter
Epoch: 91 | Batch: 022 / 029 | Total loss: 1.885 | Reg loss: 0.036 | Tree loss: 1.885 | Accuracy: 0.281250 | 0.069 sec/iter
Epoch: 91 | Batch: 023 / 029 | Total loss: 1.854 | Reg loss: 0.036 | Tree loss: 1.854 | Accuracy: 0.292969 | 0.069 sec/iter
Epoch: 91 | Batch: 024 / 029 | Total loss: 1.905 | Reg loss: 0.037 | Tree loss: 1.905 | Accuracy: 0.291016 | 0.069 sec/iter
Epoch: 91 | Batch: 025 / 029 | Total loss: 1.935 | Reg loss: 0.037 | Tree loss: 1.935 | Accuracy: 0.277344 | 0.069 sec/iter
Epoch: 91 | Batch: 026 / 029 | Total loss: 1.787 | Reg loss: 0.037 | Tree loss: 1.787 | Accuracy: 0.277344 | 0.069 sec/iter
Epoch: 9

Epoch: 93 | Batch: 026 / 029 | Total loss: 1.811 | Reg loss: 0.037 | Tree loss: 1.811 | Accuracy: 0.300781 | 0.069 sec/iter
Epoch: 93 | Batch: 027 / 029 | Total loss: 1.893 | Reg loss: 0.037 | Tree loss: 1.893 | Accuracy: 0.287109 | 0.069 sec/iter
Epoch: 93 | Batch: 028 / 029 | Total loss: 1.895 | Reg loss: 0.037 | Tree loss: 1.895 | Accuracy: 0.253247 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 94 | Batch: 000 / 029 | Total loss: 2.163 | Reg loss: 0.035 | Tree loss: 2.163 | Accuracy: 0.271484 | 0.069 sec/iter
Epoch: 94 | Batch: 001 / 029 | Total loss: 2.150 | Reg loss: 0.035 | Tree loss: 2.150 | Accuracy: 0.308594 | 0.069 sec/iter
Epoch: 94 | Batch: 002 / 029 | Total loss: 2.127 | Reg loss: 0.035 | Tree loss: 2.127 | Accuracy: 0.294922 | 0.069 sec/iter
Epoch: 94 | Batch: 003 / 029 | Total loss: 2.119 | Reg loss: 0.035 | Tree los

Epoch: 96 | Batch: 003 / 029 | Total loss: 2.133 | Reg loss: 0.035 | Tree loss: 2.133 | Accuracy: 0.281250 | 0.069 sec/iter
Epoch: 96 | Batch: 004 / 029 | Total loss: 2.096 | Reg loss: 0.035 | Tree loss: 2.096 | Accuracy: 0.267578 | 0.069 sec/iter
Epoch: 96 | Batch: 005 / 029 | Total loss: 2.115 | Reg loss: 0.035 | Tree loss: 2.115 | Accuracy: 0.257812 | 0.069 sec/iter
Epoch: 96 | Batch: 006 / 029 | Total loss: 2.028 | Reg loss: 0.035 | Tree loss: 2.028 | Accuracy: 0.302734 | 0.069 sec/iter
Epoch: 96 | Batch: 007 / 029 | Total loss: 2.031 | Reg loss: 0.035 | Tree loss: 2.031 | Accuracy: 0.261719 | 0.069 sec/iter
Epoch: 96 | Batch: 008 / 029 | Total loss: 2.025 | Reg loss: 0.035 | Tree loss: 2.025 | Accuracy: 0.320312 | 0.069 sec/iter
Epoch: 96 | Batch: 009 / 029 | Total loss: 2.035 | Reg loss: 0.035 | Tree loss: 2.035 | Accuracy: 0.281250 | 0.069 sec/iter
Epoch: 96 | Batch: 010 / 029 | Total loss: 2.009 | Reg loss: 0.035 | Tree loss: 2.009 | Accuracy: 0.294922 | 0.069 sec/iter
Epoch: 9

Epoch: 98 | Batch: 010 / 029 | Total loss: 2.010 | Reg loss: 0.035 | Tree loss: 2.010 | Accuracy: 0.308594 | 0.069 sec/iter
Epoch: 98 | Batch: 011 / 029 | Total loss: 2.027 | Reg loss: 0.035 | Tree loss: 2.027 | Accuracy: 0.306641 | 0.069 sec/iter
Epoch: 98 | Batch: 012 / 029 | Total loss: 2.059 | Reg loss: 0.035 | Tree loss: 2.059 | Accuracy: 0.289062 | 0.069 sec/iter
Epoch: 98 | Batch: 013 / 029 | Total loss: 1.990 | Reg loss: 0.035 | Tree loss: 1.990 | Accuracy: 0.289062 | 0.069 sec/iter
Epoch: 98 | Batch: 014 / 029 | Total loss: 1.944 | Reg loss: 0.035 | Tree loss: 1.944 | Accuracy: 0.281250 | 0.069 sec/iter
Epoch: 98 | Batch: 015 / 029 | Total loss: 1.937 | Reg loss: 0.036 | Tree loss: 1.937 | Accuracy: 0.285156 | 0.069 sec/iter
Epoch: 98 | Batch: 016 / 029 | Total loss: 1.861 | Reg loss: 0.036 | Tree loss: 1.861 | Accuracy: 0.296875 | 0.069 sec/iter
Epoch: 98 | Batch: 017 / 029 | Total loss: 1.926 | Reg loss: 0.036 | Tree loss: 1.926 | Accuracy: 0.265625 | 0.069 sec/iter
Epoch: 9

In [32]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [33]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [34]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.185185185185185


# Extract Rules

# Accumulate samples in the leaves

In [35]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 27


In [36]:
method = 'greedy'

In [37]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [38]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

7804
1659
2261
345
2098
56
267
Average comprehensibility: 26.444444444444443
std comprehensibility: 4.755763235340599
var comprehensibility: 22.61728395061729
minimum comprehensibility: 16
maximum comprehensibility: 32


  return np.log(1 / (1 - x))
