In [7]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
k = 64
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [9]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [10]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [11]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [12]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.174827575683594 | KNN Loss: 6.228297710418701 | BCE Loss: 1.9465301036834717
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.177190780639648 | KNN Loss: 6.22810173034668 | BCE Loss: 1.9490892887115479
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.218077659606934 | KNN Loss: 6.228068828582764 | BCE Loss: 1.99000883102417
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.136442184448242 | KNN Loss: 6.2280378341674805 | BCE Loss: 1.90840482711792
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.170071601867676 | KNN Loss: 6.2273850440979 | BCE Loss: 1.9426862001419067
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.154052734375 | KNN Loss: 6.227167129516602 | BCE Loss: 1.9268858432769775
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.193419456481934 | KNN Loss: 6.226790428161621 | BCE Loss: 1.9666287899017334
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.15809440612793 | KNN Loss: 6.226618766784668 | BCE Loss: 1.93147611618042


Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.972324371337891 | KNN Loss: 4.83497428894043 | BCE Loss: 1.1373498439788818
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.987344741821289 | KNN Loss: 4.841643810272217 | BCE Loss: 1.1457008123397827
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.96442174911499 | KNN Loss: 4.795790195465088 | BCE Loss: 1.1686314344406128
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.8169708251953125 | KNN Loss: 4.7224884033203125 | BCE Loss: 1.094482421875
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.812456130981445 | KNN Loss: 4.7007646560668945 | BCE Loss: 1.1116917133331299
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.810124397277832 | KNN Loss: 4.6933417320251465 | BCE Loss: 1.116782546043396
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.793524265289307 | KNN Loss: 4.673505783081055 | BCE Loss: 1.1200186014175415
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.723931789398193 | KNN Loss: 4.6097893714904785 | BCE Loss: 

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 5.5076117515563965 | KNN Loss: 4.457967281341553 | BCE Loss: 1.0496445894241333
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 5.538027763366699 | KNN Loss: 4.485864639282227 | BCE Loss: 1.0521628856658936
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 5.51624870300293 | KNN Loss: 4.441437244415283 | BCE Loss: 1.0748112201690674
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 5.4546966552734375 | KNN Loss: 4.411798000335693 | BCE Loss: 1.042898416519165
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 5.493555068969727 | KNN Loss: 4.437297821044922 | BCE Loss: 1.0562574863433838
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 5.537965774536133 | KNN Loss: 4.481687545776367 | BCE Loss: 1.0562784671783447
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 5.524158000946045 | KNN Loss: 4.479379177093506 | BCE Loss: 1.044778823852539
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 5.468611717224121 | KNN Loss: 4.432379722595215 | BCE Loss

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 5.4424638748168945 | KNN Loss: 4.4166083335876465 | BCE Loss: 1.0258557796478271
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 5.4416279792785645 | KNN Loss: 4.402205467224121 | BCE Loss: 1.0394225120544434
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 5.496085166931152 | KNN Loss: 4.437416076660156 | BCE Loss: 1.058669090270996
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 5.45173454284668 | KNN Loss: 4.408928394317627 | BCE Loss: 1.0428060293197632
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 5.485454082489014 | KNN Loss: 4.417438983917236 | BCE Loss: 1.068015217781067
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 5.485142230987549 | KNN Loss: 4.42959451675415 | BCE Loss: 1.055547833442688
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 5.494961738586426 | KNN Loss: 4.429432392120361 | BCE Loss: 1.0655295848846436
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 5.482602119445801 | KNN Loss: 4.441196918487549 | BCE Loss:

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 5.439749717712402 | KNN Loss: 4.4092912673950195 | BCE Loss: 1.0304585695266724
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 5.3913493156433105 | KNN Loss: 4.368643283843994 | BCE Loss: 1.0227059125900269
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 5.417226791381836 | KNN Loss: 4.386479377746582 | BCE Loss: 1.0307471752166748
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 5.456398963928223 | KNN Loss: 4.393693447113037 | BCE Loss: 1.062705397605896
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 5.4906840324401855 | KNN Loss: 4.430553436279297 | BCE Loss: 1.0601304769515991
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 5.418916702270508 | KNN Loss: 4.407166957855225 | BCE Loss: 1.0117496252059937
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 5.456377983093262 | KNN Loss: 4.4151201248168945 | BCE Loss: 1.0412578582763672
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 5.4453582763671875 | KNN Loss: 4.38686990737915 | BCE 

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 5.387113571166992 | KNN Loss: 4.374667167663574 | BCE Loss: 1.012446641921997
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 5.446811676025391 | KNN Loss: 4.403135299682617 | BCE Loss: 1.0436766147613525
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 5.404544830322266 | KNN Loss: 4.384265899658203 | BCE Loss: 1.0202791690826416
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 5.426520347595215 | KNN Loss: 4.360771656036377 | BCE Loss: 1.0657484531402588
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 5.397165298461914 | KNN Loss: 4.364058971405029 | BCE Loss: 1.0331063270568848
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 5.428272247314453 | KNN Loss: 4.379443645477295 | BCE Loss: 1.0488286018371582
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 5.4153242111206055 | KNN Loss: 4.3819260597229 | BCE Loss: 1.0333982706069946
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 5.497959613800049 | KNN Loss: 4.4260993003845215 | BCE Loss:

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 5.398621082305908 | KNN Loss: 4.346743583679199 | BCE Loss: 1.051877498626709
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 5.383538246154785 | KNN Loss: 4.367569923400879 | BCE Loss: 1.0159683227539062
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 5.487311363220215 | KNN Loss: 4.44534158706665 | BCE Loss: 1.0419695377349854
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 5.4748735427856445 | KNN Loss: 4.421506881713867 | BCE Loss: 1.0533666610717773
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 5.366055488586426 | KNN Loss: 4.353030204772949 | BCE Loss: 1.0130252838134766
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 5.407138347625732 | KNN Loss: 4.385196208953857 | BCE Loss: 1.021942138671875
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 5.4341230392456055 | KNN Loss: 4.402965068817139 | BCE Loss: 1.0311577320098877
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 5.397789001464844 | KNN Loss: 4.380267143249512 | BCE Loss: 

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 5.3887505531311035 | KNN Loss: 4.3645405769348145 | BCE Loss: 1.024209976196289
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 5.381153106689453 | KNN Loss: 4.365494728088379 | BCE Loss: 1.0156583786010742
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 5.454333305358887 | KNN Loss: 4.397769927978516 | BCE Loss: 1.056563138961792
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 5.467281818389893 | KNN Loss: 4.417895317077637 | BCE Loss: 1.0493865013122559
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 5.420456886291504 | KNN Loss: 4.390741348266602 | BCE Loss: 1.0297152996063232
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 5.444135665893555 | KNN Loss: 4.409120559692383 | BCE Loss: 1.0350151062011719
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 5.411436557769775 | KNN Loss: 4.382104396820068 | BCE Loss: 1.0293322801589966
Epoch 77 / 500 | iteration 0 / 30 | Total Loss: 5.400620460510254 | KNN Loss: 4.336050510406494 | BCE Loss

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 5.3990888595581055 | KNN Loss: 4.374813556671143 | BCE Loss: 1.024275541305542
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 5.4385786056518555 | KNN Loss: 4.387181282043457 | BCE Loss: 1.0513975620269775
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 5.411357402801514 | KNN Loss: 4.365033149719238 | BCE Loss: 1.046324372291565
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 5.385058403015137 | KNN Loss: 4.356067657470703 | BCE Loss: 1.0289907455444336
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 5.423992156982422 | KNN Loss: 4.40303897857666 | BCE Loss: 1.0209534168243408
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 5.400088787078857 | KNN Loss: 4.374690055847168 | BCE Loss: 1.025398850440979
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 5.42087459564209 | KNN Loss: 4.387348175048828 | BCE Loss: 1.0335263013839722
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 5.422621250152588 | KNN Loss: 4.409318923950195 | BCE Loss: 

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 5.412693023681641 | KNN Loss: 4.365521430969238 | BCE Loss: 1.0471715927124023
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 5.388228416442871 | KNN Loss: 4.382907390594482 | BCE Loss: 1.0053210258483887
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 5.412227153778076 | KNN Loss: 4.390500545501709 | BCE Loss: 1.0217267274856567
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 5.362514972686768 | KNN Loss: 4.355202674865723 | BCE Loss: 1.007312297821045
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 5.40207052230835 | KNN Loss: 4.37186861038208 | BCE Loss: 1.030202031135559
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 5.390743732452393 | KNN Loss: 4.364981651306152 | BCE Loss: 1.0257620811462402
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 5.395610809326172 | KNN Loss: 4.366288661956787 | BCE Loss: 1.0293219089508057
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 5.432631969451904 | KNN Loss: 4.399496555328369 | BCE Loss: 1.

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 5.489634990692139 | KNN Loss: 4.466172218322754 | BCE Loss: 1.0234627723693848
Epoch   108: reducing learning rate of group 0 to 1.7150e-03.
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 5.396200180053711 | KNN Loss: 4.382324695587158 | BCE Loss: 1.0138752460479736
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 5.420331954956055 | KNN Loss: 4.367659568786621 | BCE Loss: 1.0526723861694336
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 5.431967735290527 | KNN Loss: 4.3840179443359375 | BCE Loss: 1.047950029373169
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 5.4168901443481445 | KNN Loss: 4.37197208404541 | BCE Loss: 1.0449178218841553
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 5.420536994934082 | KNN Loss: 4.394333839416504 | BCE Loss: 1.0262031555175781
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 5.393564701080322 | KNN Loss: 4.366558074951172 | BCE Loss: 1.0270065069198608
Epoch 109 / 500 | iteration 0 / 30 | 

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 5.4399824142456055 | KNN Loss: 4.39484167098999 | BCE Loss: 1.0451407432556152
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 5.3714189529418945 | KNN Loss: 4.349240779876709 | BCE Loss: 1.0221779346466064
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 5.427596569061279 | KNN Loss: 4.382777214050293 | BCE Loss: 1.0448194742202759
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 5.366668701171875 | KNN Loss: 4.378049850463867 | BCE Loss: 0.9886189103126526
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 5.370730400085449 | KNN Loss: 4.3556342124938965 | BCE Loss: 1.0150964260101318
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 5.422802925109863 | KNN Loss: 4.372765064239502 | BCE Loss: 1.0500377416610718
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 5.4230241775512695 | KNN Loss: 4.3888983726501465 | BCE Loss: 1.034125804901123
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 5.433114528656006 | KNN Loss: 4.40408658981323

Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 5.401525020599365 | KNN Loss: 4.372458457946777 | BCE Loss: 1.029066562652588
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 5.400982856750488 | KNN Loss: 4.3750529289245605 | BCE Loss: 1.0259300470352173
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 5.39568567276001 | KNN Loss: 4.378140449523926 | BCE Loss: 1.0175453424453735
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 5.411736488342285 | KNN Loss: 4.376201629638672 | BCE Loss: 1.0355347394943237
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 5.5009050369262695 | KNN Loss: 4.433978080749512 | BCE Loss: 1.066927194595337
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 5.404448509216309 | KNN Loss: 4.362146377563477 | BCE Loss: 1.042301893234253
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 5.410538196563721 | KNN Loss: 4.364859580993652 | BCE Loss: 1.0456786155700684
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 5.415460109710693 | KNN Loss: 4.373363971710205 | BC

Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 5.46508264541626 | KNN Loss: 4.4123358726501465 | BCE Loss: 1.0527468919754028
Epoch   140: reducing learning rate of group 0 to 8.4035e-04.
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 5.3770575523376465 | KNN Loss: 4.357994079589844 | BCE Loss: 1.0190633535385132
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 5.377534866333008 | KNN Loss: 4.362887382507324 | BCE Loss: 1.0146477222442627
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 5.402493953704834 | KNN Loss: 4.35179328918457 | BCE Loss: 1.0507006645202637
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 5.417993068695068 | KNN Loss: 4.349936008453369 | BCE Loss: 1.0680571794509888
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 5.434355735778809 | KNN Loss: 4.37747049331665 | BCE Loss: 1.0568854808807373
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 5.364978313446045 | KNN Loss: 4.356512546539307 | BCE Loss: 1.0084656476974487
Epoch 141 / 500 | iteration 0 / 30 | T

Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 5.381557464599609 | KNN Loss: 4.361090660095215 | BCE Loss: 1.0204668045043945
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 5.441189765930176 | KNN Loss: 4.418916702270508 | BCE Loss: 1.0222729444503784
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 5.433194160461426 | KNN Loss: 4.379369258880615 | BCE Loss: 1.0538249015808105
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 5.396218299865723 | KNN Loss: 4.360690593719482 | BCE Loss: 1.0355274677276611
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 5.433990478515625 | KNN Loss: 4.390612602233887 | BCE Loss: 1.0433779954910278
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 5.35537576675415 | KNN Loss: 4.355792045593262 | BCE Loss: 0.9995837211608887
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 5.406432151794434 | KNN Loss: 4.374722003936768 | BCE Loss: 1.031709909439087
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 5.4137773513793945 | KNN Loss: 4.35959529876709 | B

Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 5.374227523803711 | KNN Loss: 4.332513332366943 | BCE Loss: 1.0417141914367676
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 5.466549873352051 | KNN Loss: 4.419331073760986 | BCE Loss: 1.0472185611724854
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 5.41642427444458 | KNN Loss: 4.367187023162842 | BCE Loss: 1.0492371320724487
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 5.372176647186279 | KNN Loss: 4.364610195159912 | BCE Loss: 1.0075664520263672
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 5.442203521728516 | KNN Loss: 4.395930290222168 | BCE Loss: 1.0462734699249268
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 5.4208831787109375 | KNN Loss: 4.380406379699707 | BCE Loss: 1.04047691822052
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 5.370200157165527 | KNN Loss: 4.351572036743164 | BCE Loss: 1.0186280012130737
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 5.415696144104004 | KNN Loss: 4.398961544036865 | BC

Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 5.36831521987915 | KNN Loss: 4.35055685043335 | BCE Loss: 1.0177583694458008
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 5.39324426651001 | KNN Loss: 4.377485275268555 | BCE Loss: 1.015758991241455
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 5.434553146362305 | KNN Loss: 4.412121295928955 | BCE Loss: 1.0224320888519287
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 5.408450603485107 | KNN Loss: 4.386533737182617 | BCE Loss: 1.0219169855117798
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 5.4300312995910645 | KNN Loss: 4.383147716522217 | BCE Loss: 1.046883463859558
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 5.394927978515625 | KNN Loss: 4.348963260650635 | BCE Loss: 1.0459644794464111
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 5.398402214050293 | KNN Loss: 4.359979152679443 | BCE Loss: 1.0384230613708496
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 5.427352428436279 | KNN Loss: 4.399730682373047 | BCE 

Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 5.415480613708496 | KNN Loss: 4.400487422943115 | BCE Loss: 1.0149930715560913
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 5.410183906555176 | KNN Loss: 4.366849422454834 | BCE Loss: 1.0433344841003418
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 5.380431652069092 | KNN Loss: 4.378454685211182 | BCE Loss: 1.0019769668579102
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 5.374417304992676 | KNN Loss: 4.356521129608154 | BCE Loss: 1.0178959369659424
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 5.435103893280029 | KNN Loss: 4.4098992347717285 | BCE Loss: 1.0252045392990112
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 5.423579216003418 | KNN Loss: 4.38866662979126 | BCE Loss: 1.0349128246307373
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 5.356777191162109 | KNN Loss: 4.344326496124268 | BCE Loss: 1.0124504566192627
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 5.420635223388672 | KNN Loss: 4.357544898986816 |

Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 5.371336460113525 | KNN Loss: 4.355814456939697 | BCE Loss: 1.0155220031738281
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 5.439828395843506 | KNN Loss: 4.3843488693237305 | BCE Loss: 1.0554795265197754
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 5.3485426902771 | KNN Loss: 4.339293956756592 | BCE Loss: 1.0092486143112183
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 5.382338523864746 | KNN Loss: 4.373823165893555 | BCE Loss: 1.0085151195526123
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 5.3646440505981445 | KNN Loss: 4.35214376449585 | BCE Loss: 1.012500286102295
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 5.423469543457031 | KNN Loss: 4.366580963134766 | BCE Loss: 1.0568888187408447
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 5.3954315185546875 | KNN Loss: 4.371198654174805 | BCE Loss: 1.0242326259613037
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 5.379033088684082 | KNN Loss: 4.36260461807251 | BC

Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 5.44553279876709 | KNN Loss: 4.397899150848389 | BCE Loss: 1.0476338863372803
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 5.368404388427734 | KNN Loss: 4.355591773986816 | BCE Loss: 1.012812852859497
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 5.454561233520508 | KNN Loss: 4.416686534881592 | BCE Loss: 1.037874698638916
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 5.42636775970459 | KNN Loss: 4.403219223022461 | BCE Loss: 1.023148536682129
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 5.3892083168029785 | KNN Loss: 4.387063980102539 | BCE Loss: 1.00214421749115
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 5.385799884796143 | KNN Loss: 4.385186195373535 | BCE Loss: 1.0006136894226074
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 5.41087532043457 | KNN Loss: 4.355223655700684 | BCE Loss: 1.0556514263153076
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 5.432422637939453 | KNN Loss: 4.38545036315918 | BCE Loss

Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 5.376770973205566 | KNN Loss: 4.347792148590088 | BCE Loss: 1.0289790630340576
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 5.424764633178711 | KNN Loss: 4.390745639801025 | BCE Loss: 1.0340189933776855
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 5.350327968597412 | KNN Loss: 4.352342128753662 | BCE Loss: 0.9979859590530396
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 5.405364513397217 | KNN Loss: 4.3691182136535645 | BCE Loss: 1.0362461805343628
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 5.40249490737915 | KNN Loss: 4.342508792877197 | BCE Loss: 1.0599861145019531
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 5.360642433166504 | KNN Loss: 4.352747917175293 | BCE Loss: 1.00789475440979
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 5.38388204574585 | KNN Loss: 4.366787910461426 | BCE Loss: 1.0170940160751343
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 5.415053367614746 | KNN Loss: 4.366507053375244 | BC

Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 5.358302593231201 | KNN Loss: 4.346256256103516 | BCE Loss: 1.012046217918396
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 5.388148307800293 | KNN Loss: 4.3635382652282715 | BCE Loss: 1.0246098041534424
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 5.390926361083984 | KNN Loss: 4.373022079467773 | BCE Loss: 1.0179040431976318
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 5.3877034187316895 | KNN Loss: 4.3659443855285645 | BCE Loss: 1.021759033203125
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 5.373426914215088 | KNN Loss: 4.356732368469238 | BCE Loss: 1.0166945457458496
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 5.404093265533447 | KNN Loss: 4.3650898933410645 | BCE Loss: 1.0390032529830933
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 5.36570930480957 | KNN Loss: 4.37091588973999 | BCE Loss: 0.9947935342788696
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 5.3829545974731445 | KNN Loss: 4.365495204925537 |

Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 5.4504804611206055 | KNN Loss: 4.406428813934326 | BCE Loss: 1.0440518856048584
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 5.372734546661377 | KNN Loss: 4.341066360473633 | BCE Loss: 1.0316680669784546
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 5.429498672485352 | KNN Loss: 4.371689319610596 | BCE Loss: 1.057809591293335
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 5.393152713775635 | KNN Loss: 4.396761417388916 | BCE Loss: 0.9963914155960083
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 5.451716899871826 | KNN Loss: 4.389090538024902 | BCE Loss: 1.0626263618469238
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 5.348982334136963 | KNN Loss: 4.354413032531738 | BCE Loss: 0.9945691823959351
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 5.354949474334717 | KNN Loss: 4.348024845123291 | BCE Loss: 1.0069246292114258
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 5.377147674560547 | KNN Loss: 4.344461917877197 | 

Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 5.405891418457031 | KNN Loss: 4.342324256896973 | BCE Loss: 1.0635673999786377
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 5.397876739501953 | KNN Loss: 4.367661476135254 | BCE Loss: 1.0302150249481201
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 5.434875965118408 | KNN Loss: 4.390585422515869 | BCE Loss: 1.0442904233932495
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 5.432035446166992 | KNN Loss: 4.413301467895508 | BCE Loss: 1.0187337398529053
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 5.403598785400391 | KNN Loss: 4.38935661315918 | BCE Loss: 1.01424241065979
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 5.3720831871032715 | KNN Loss: 4.34710168838501 | BCE Loss: 1.0249816179275513
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 5.411288738250732 | KNN Loss: 4.359678268432617 | BCE Loss: 1.0516104698181152
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 5.4093523025512695 | KNN Loss: 4.368225574493408 | B

Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 5.399740219116211 | KNN Loss: 4.375702381134033 | BCE Loss: 1.0240377187728882
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 5.42772102355957 | KNN Loss: 4.405219078063965 | BCE Loss: 1.0225019454956055
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 5.3969316482543945 | KNN Loss: 4.373394966125488 | BCE Loss: 1.0235364437103271
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 5.416792392730713 | KNN Loss: 4.379599571228027 | BCE Loss: 1.037192702293396
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 5.404210090637207 | KNN Loss: 4.389542579650879 | BCE Loss: 1.0146675109863281
Epoch   258: reducing learning rate of group 0 to 3.3911e-05.
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 5.468334674835205 | KNN Loss: 4.430617332458496 | BCE Loss: 1.0377172231674194
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 5.462325096130371 | KNN Loss: 4.402691841125488 | BCE Loss: 1.059633493423462
Epoch 258 / 500 | iteration 10 / 30 | To

Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 5.400324821472168 | KNN Loss: 4.360675811767578 | BCE Loss: 1.039649248123169
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 5.376794815063477 | KNN Loss: 4.35478401184082 | BCE Loss: 1.0220110416412354
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 5.384860038757324 | KNN Loss: 4.344861030578613 | BCE Loss: 1.0399987697601318
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 5.397288799285889 | KNN Loss: 4.365914821624756 | BCE Loss: 1.0313740968704224
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 5.394014835357666 | KNN Loss: 4.353979587554932 | BCE Loss: 1.0400352478027344
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 5.36430549621582 | KNN Loss: 4.345435619354248 | BCE Loss: 1.0188698768615723
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 5.384111404418945 | KNN Loss: 4.359631061553955 | BCE Loss: 1.0244803428649902
Epoch   269: reducing learning rate of group 0 to 2.3738e-05.
Epoch 269 / 500 | iteration 0 / 30 | Tot

Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 5.388853549957275 | KNN Loss: 4.373793125152588 | BCE Loss: 1.015060544013977
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 5.3996429443359375 | KNN Loss: 4.396242141723633 | BCE Loss: 1.0034005641937256
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 5.392556190490723 | KNN Loss: 4.384540557861328 | BCE Loss: 1.0080156326293945
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 5.398913383483887 | KNN Loss: 4.3583784103393555 | BCE Loss: 1.0405352115631104
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 5.405203819274902 | KNN Loss: 4.370213031768799 | BCE Loss: 1.0349905490875244
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 5.397258758544922 | KNN Loss: 4.376405239105225 | BCE Loss: 1.0208537578582764
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 5.428540229797363 | KNN Loss: 4.3618316650390625 | BCE Loss: 1.0667083263397217
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 5.369275093078613 | KNN Loss: 4.341807842254639

Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 5.380566596984863 | KNN Loss: 4.344812870025635 | BCE Loss: 1.0357539653778076
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 5.433135032653809 | KNN Loss: 4.394646167755127 | BCE Loss: 1.0384891033172607
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 5.355498790740967 | KNN Loss: 4.341590404510498 | BCE Loss: 1.0139082670211792
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 5.3794026374816895 | KNN Loss: 4.360113620758057 | BCE Loss: 1.0192891359329224
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 5.427642822265625 | KNN Loss: 4.418810844421387 | BCE Loss: 1.0088319778442383
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 5.415714263916016 | KNN Loss: 4.382504463195801 | BCE Loss: 1.0332099199295044
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 5.489725112915039 | KNN Loss: 4.473855972290039 | BCE Loss: 1.0158692598342896
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 5.429081916809082 | KNN Loss: 4.40468692779541 | 

Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 5.402722358703613 | KNN Loss: 4.363831043243408 | BCE Loss: 1.038891315460205
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 5.426999568939209 | KNN Loss: 4.386746406555176 | BCE Loss: 1.0402531623840332
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 5.384426593780518 | KNN Loss: 4.363600254058838 | BCE Loss: 1.0208264589309692
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 5.3980207443237305 | KNN Loss: 4.37890625 | BCE Loss: 1.0191142559051514
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 5.495565414428711 | KNN Loss: 4.469158172607422 | BCE Loss: 1.026407241821289
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 5.411613941192627 | KNN Loss: 4.373678684234619 | BCE Loss: 1.0379351377487183
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 5.397150039672852 | KNN Loss: 4.370487689971924 | BCE Loss: 1.0266621112823486
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 5.426206111907959 | KNN Loss: 4.3838911056518555 | BCE Los

Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 5.4004292488098145 | KNN Loss: 4.34157133102417 | BCE Loss: 1.0588579177856445
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 5.428223609924316 | KNN Loss: 4.386580944061279 | BCE Loss: 1.0416425466537476
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 5.385236740112305 | KNN Loss: 4.357962608337402 | BCE Loss: 1.027274250984192
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 5.369210243225098 | KNN Loss: 4.35513973236084 | BCE Loss: 1.0140706300735474
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 5.376352310180664 | KNN Loss: 4.364897727966309 | BCE Loss: 1.0114543437957764
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 5.383723258972168 | KNN Loss: 4.367844581604004 | BCE Loss: 1.015878438949585
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 5.401627540588379 | KNN Loss: 4.374452114105225 | BCE Loss: 1.0271755456924438
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 5.420899868011475 | KNN Loss: 4.371690273284912 | BC

Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 5.38386344909668 | KNN Loss: 4.342007637023926 | BCE Loss: 1.041855812072754
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 5.460561275482178 | KNN Loss: 4.411808490753174 | BCE Loss: 1.0487526655197144
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 5.411027431488037 | KNN Loss: 4.394240379333496 | BCE Loss: 1.016787052154541
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 5.392977237701416 | KNN Loss: 4.3748884201049805 | BCE Loss: 1.0180888175964355
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 5.395218372344971 | KNN Loss: 4.3823347091674805 | BCE Loss: 1.0128835439682007
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 5.418330192565918 | KNN Loss: 4.40716552734375 | BCE Loss: 1.0111644268035889
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 5.407203674316406 | KNN Loss: 4.37639045715332 | BCE Loss: 1.030813217163086
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 5.383775234222412 | KNN Loss: 4.348777770996094 | BCE 

Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 5.446134090423584 | KNN Loss: 4.411021709442139 | BCE Loss: 1.0351123809814453
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 5.407476902008057 | KNN Loss: 4.366114139556885 | BCE Loss: 1.0413627624511719
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 5.422477722167969 | KNN Loss: 4.3732194900512695 | BCE Loss: 1.0492579936981201
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 5.433860778808594 | KNN Loss: 4.405467987060547 | BCE Loss: 1.028393030166626
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 5.423696517944336 | KNN Loss: 4.408357620239258 | BCE Loss: 1.015338659286499
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 5.436310291290283 | KNN Loss: 4.393581867218018 | BCE Loss: 1.0427284240722656
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 5.403205394744873 | KNN Loss: 4.362895488739014 | BCE Loss: 1.0403097867965698
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 5.359805107116699 | KNN Loss: 4.340354919433594 | B

Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 5.449846267700195 | KNN Loss: 4.4107160568237305 | BCE Loss: 1.0391300916671753
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 5.355711460113525 | KNN Loss: 4.359804153442383 | BCE Loss: 0.9959074854850769
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 5.4524431228637695 | KNN Loss: 4.402338027954102 | BCE Loss: 1.050105094909668
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 5.374826908111572 | KNN Loss: 4.371763229370117 | BCE Loss: 1.003063678741455
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 5.434188365936279 | KNN Loss: 4.382929801940918 | BCE Loss: 1.0512584447860718
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 5.376378059387207 | KNN Loss: 4.356935024261475 | BCE Loss: 1.0194430351257324
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 5.396095275878906 | KNN Loss: 4.36132287979126 | BCE Loss: 1.0347723960876465
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 5.379488945007324 | KNN Loss: 4.367842197418213 | 

Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 5.343733787536621 | KNN Loss: 4.328978061676025 | BCE Loss: 1.0147559642791748
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 5.3775739669799805 | KNN Loss: 4.347911834716797 | BCE Loss: 1.0296621322631836
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 5.395534038543701 | KNN Loss: 4.375138759613037 | BCE Loss: 1.0203951597213745
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 5.37571382522583 | KNN Loss: 4.348917484283447 | BCE Loss: 1.0267962217330933
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 5.369720458984375 | KNN Loss: 4.368724346160889 | BCE Loss: 1.0009958744049072
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 5.381543159484863 | KNN Loss: 4.357430458068848 | BCE Loss: 1.0241127014160156
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 5.3741254806518555 | KNN Loss: 4.366258144378662 | BCE Loss: 1.0078675746917725
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 5.385378360748291 | KNN Loss: 4.367142677307129 |

Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 5.407011985778809 | KNN Loss: 4.392210483551025 | BCE Loss: 1.0148017406463623
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 5.409515380859375 | KNN Loss: 4.377837657928467 | BCE Loss: 1.031677484512329
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 5.403809070587158 | KNN Loss: 4.39387845993042 | BCE Loss: 1.0099307298660278
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 5.474316596984863 | KNN Loss: 4.43327522277832 | BCE Loss: 1.0410414934158325
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 5.4332756996154785 | KNN Loss: 4.3868865966796875 | BCE Loss: 1.046389102935791
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 5.426518440246582 | KNN Loss: 4.395773410797119 | BCE Loss: 1.030745029449463
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 5.413646697998047 | KNN Loss: 4.368694305419922 | BCE Loss: 1.044952392578125
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 5.385837078094482 | KNN Loss: 4.37705135345459 | BCE L

Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 5.411194801330566 | KNN Loss: 4.385404109954834 | BCE Loss: 1.0257906913757324
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 5.441565990447998 | KNN Loss: 4.388049602508545 | BCE Loss: 1.0535163879394531
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 5.418768882751465 | KNN Loss: 4.385362148284912 | BCE Loss: 1.0334067344665527
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 5.374481201171875 | KNN Loss: 4.347359657287598 | BCE Loss: 1.027121663093567
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 5.363791465759277 | KNN Loss: 4.356779098510742 | BCE Loss: 1.0070123672485352
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 5.434025287628174 | KNN Loss: 4.372797966003418 | BCE Loss: 1.0612273216247559
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 5.363168239593506 | KNN Loss: 4.351267337799072 | BCE Loss: 1.0119010210037231
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 5.3906354904174805 | KNN Loss: 4.375050067901611 |

Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 5.421492099761963 | KNN Loss: 4.417176723480225 | BCE Loss: 1.0043152570724487
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 5.388345718383789 | KNN Loss: 4.371342182159424 | BCE Loss: 1.0170034170150757
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 5.433860778808594 | KNN Loss: 4.3962507247924805 | BCE Loss: 1.0376101732254028
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 5.438559055328369 | KNN Loss: 4.406874179840088 | BCE Loss: 1.0316849946975708
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 5.422153472900391 | KNN Loss: 4.391071319580078 | BCE Loss: 1.0310821533203125
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 5.433164596557617 | KNN Loss: 4.383158206939697 | BCE Loss: 1.05000638961792
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 5.4013752937316895 | KNN Loss: 4.374415397644043 | BCE Loss: 1.0269598960876465
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 5.420682907104492 | KNN Loss: 4.400352954864502 | 

Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 5.360018253326416 | KNN Loss: 4.376079559326172 | BCE Loss: 0.9839386343955994
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 5.405110836029053 | KNN Loss: 4.389258861541748 | BCE Loss: 1.0158518552780151
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 5.391532897949219 | KNN Loss: 4.373805522918701 | BCE Loss: 1.0177271366119385
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 5.367646217346191 | KNN Loss: 4.3445892333984375 | BCE Loss: 1.0230567455291748
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 5.360358238220215 | KNN Loss: 4.332310199737549 | BCE Loss: 1.028047800064087
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 5.413618087768555 | KNN Loss: 4.380745887756348 | BCE Loss: 1.0328723192214966
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 5.445387840270996 | KNN Loss: 4.381087779998779 | BCE Loss: 1.0643000602722168
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 5.389957904815674 | KNN Loss: 4.355424404144287 | 

Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 5.4188551902771 | KNN Loss: 4.404438495635986 | BCE Loss: 1.0144168138504028
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 5.377321243286133 | KNN Loss: 4.359106540679932 | BCE Loss: 1.018214464187622
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 5.39042854309082 | KNN Loss: 4.364990234375 | BCE Loss: 1.0254380702972412
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 5.430169105529785 | KNN Loss: 4.408502578735352 | BCE Loss: 1.0216665267944336
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 5.390746116638184 | KNN Loss: 4.359445571899414 | BCE Loss: 1.0313005447387695
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 5.367217063903809 | KNN Loss: 4.368931770324707 | BCE Loss: 0.9982852339744568
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 5.391040325164795 | KNN Loss: 4.3637375831604 | BCE Loss: 1.0273027420043945
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 5.412813186645508 | KNN Loss: 4.372248649597168 | BCE Loss

Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 5.375885009765625 | KNN Loss: 4.356439113616943 | BCE Loss: 1.0194460153579712
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 5.4187331199646 | KNN Loss: 4.389726638793945 | BCE Loss: 1.0290064811706543
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 5.416496276855469 | KNN Loss: 4.386774063110352 | BCE Loss: 1.0297220945358276
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 5.368431091308594 | KNN Loss: 4.361526012420654 | BCE Loss: 1.0069053173065186
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 5.48036003112793 | KNN Loss: 4.4274702072143555 | BCE Loss: 1.0528895854949951
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 5.371793270111084 | KNN Loss: 4.346191883087158 | BCE Loss: 1.0256013870239258
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 5.397812843322754 | KNN Loss: 4.396267890930176 | BCE Loss: 1.001544713973999
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 5.39389181137085 | KNN Loss: 4.367108345031738 | BCE 

Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 5.397862434387207 | KNN Loss: 4.37176513671875 | BCE Loss: 1.026097059249878
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 5.4377665519714355 | KNN Loss: 4.401102542877197 | BCE Loss: 1.0366640090942383
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 5.402417182922363 | KNN Loss: 4.383434772491455 | BCE Loss: 1.0189826488494873
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 5.340904235839844 | KNN Loss: 4.329569339752197 | BCE Loss: 1.011334776878357
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 5.431185245513916 | KNN Loss: 4.398580074310303 | BCE Loss: 1.0326051712036133
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 5.5016703605651855 | KNN Loss: 4.452282905578613 | BCE Loss: 1.0493875741958618
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 5.378060340881348 | KNN Loss: 4.3455400466918945 | BCE Loss: 1.0325204133987427
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 5.3853440284729 | KNN Loss: 4.387049674987793 | BC

Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 5.367556571960449 | KNN Loss: 4.350622177124023 | BCE Loss: 1.0169346332550049
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 5.426031589508057 | KNN Loss: 4.388668537139893 | BCE Loss: 1.037363052368164
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 5.385848045349121 | KNN Loss: 4.380543231964111 | BCE Loss: 1.0053046941757202
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 5.348920822143555 | KNN Loss: 4.359360694885254 | BCE Loss: 0.9895601272583008
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 5.410243034362793 | KNN Loss: 4.376736640930176 | BCE Loss: 1.0335063934326172
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 5.3865861892700195 | KNN Loss: 4.375328540802002 | BCE Loss: 1.0112576484680176
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 5.378180980682373 | KNN Loss: 4.356457233428955 | BCE Loss: 1.021723747253418
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 5.376687049865723 | KNN Loss: 4.3452959060668945 |

Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 5.415876388549805 | KNN Loss: 4.390794277191162 | BCE Loss: 1.025081992149353
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 5.352718353271484 | KNN Loss: 4.345103740692139 | BCE Loss: 1.0076146125793457
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 5.429523944854736 | KNN Loss: 4.387258529663086 | BCE Loss: 1.0422654151916504
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 5.395442962646484 | KNN Loss: 4.361802577972412 | BCE Loss: 1.0336403846740723
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 5.387103080749512 | KNN Loss: 4.352215766906738 | BCE Loss: 1.034887433052063
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 5.407593727111816 | KNN Loss: 4.380495548248291 | BCE Loss: 1.0270979404449463
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 5.409254550933838 | KNN Loss: 4.3655805587768555 | BCE Loss: 1.043674111366272
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 5.382393836975098 | KNN Loss: 4.349435806274414 | BC

Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 5.3846659660339355 | KNN Loss: 4.376435279846191 | BCE Loss: 1.0082306861877441
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 5.380387306213379 | KNN Loss: 4.3619232177734375 | BCE Loss: 1.0184638500213623
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 5.428438186645508 | KNN Loss: 4.376776218414307 | BCE Loss: 1.0516619682312012
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 5.414963722229004 | KNN Loss: 4.393837928771973 | BCE Loss: 1.0211255550384521
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 5.447108268737793 | KNN Loss: 4.378981590270996 | BCE Loss: 1.0681264400482178
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 5.400277614593506 | KNN Loss: 4.39302396774292 | BCE Loss: 1.0072537660598755
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 5.373232841491699 | KNN Loss: 4.364997863769531 | BCE Loss: 1.0082347393035889
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 5.353822708129883 | KNN Loss: 4.343193054199219 |

Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 5.379047393798828 | KNN Loss: 4.343362808227539 | BCE Loss: 1.0356848239898682
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 5.382643699645996 | KNN Loss: 4.343164920806885 | BCE Loss: 1.0394785404205322
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 5.374365329742432 | KNN Loss: 4.361779689788818 | BCE Loss: 1.0125857591629028
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 5.399739742279053 | KNN Loss: 4.361929416656494 | BCE Loss: 1.037810206413269
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 5.390772819519043 | KNN Loss: 4.3784918785095215 | BCE Loss: 1.012281060218811
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 5.411612510681152 | KNN Loss: 4.391783714294434 | BCE Loss: 1.0198289155960083
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 5.435630798339844 | KNN Loss: 4.365993499755859 | BCE Loss: 1.0696371793746948
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 5.437350273132324 | KNN Loss: 4.367086887359619 | 

Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 5.416067123413086 | KNN Loss: 4.398866176605225 | BCE Loss: 1.0172009468078613
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 5.411890029907227 | KNN Loss: 4.358218193054199 | BCE Loss: 1.0536715984344482
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 5.356133460998535 | KNN Loss: 4.349209785461426 | BCE Loss: 1.0069239139556885
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 5.366866111755371 | KNN Loss: 4.337185859680176 | BCE Loss: 1.0296804904937744
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 5.434194564819336 | KNN Loss: 4.369448184967041 | BCE Loss: 1.0647461414337158
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 5.406996250152588 | KNN Loss: 4.394435405731201 | BCE Loss: 1.0125607252120972
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 5.428384780883789 | KNN Loss: 4.412336826324463 | BCE Loss: 1.0160481929779053
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 5.390324592590332 | KNN Loss: 4.360025405883789 | 

Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 5.432872772216797 | KNN Loss: 4.369320869445801 | BCE Loss: 1.0635517835617065
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 5.410745143890381 | KNN Loss: 4.365313529968262 | BCE Loss: 1.0454316139221191
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 5.4033637046813965 | KNN Loss: 4.357863426208496 | BCE Loss: 1.0455002784729004
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 5.490099906921387 | KNN Loss: 4.437799453735352 | BCE Loss: 1.052300214767456
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 5.450606346130371 | KNN Loss: 4.41609525680542 | BCE Loss: 1.0345113277435303
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 5.399035930633545 | KNN Loss: 4.3543572425842285 | BCE Loss: 1.044678807258606
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 5.3918023109436035 | KNN Loss: 4.384982585906982 | BCE Loss: 1.006819725036621
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 5.442098617553711 | KNN Loss: 4.405357360839844 | B

In [13]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.4010e+00,  2.2324e+00,  1.9259e+00,  4.4305e+00,  4.3083e+00,
          8.2646e-01,  2.8776e+00,  2.2525e+00,  1.6196e+00,  2.5330e+00,
          2.7708e+00,  2.3513e+00,  8.7957e-01,  1.9581e+00,  1.4848e+00,
          1.2184e+00,  3.0277e+00,  1.7433e+00,  2.0688e+00,  1.9962e+00,
          1.8550e+00,  1.6882e+00,  2.8604e+00,  3.1876e+00,  1.7777e+00,
          1.9339e+00,  1.3438e+00,  9.6325e-01,  1.2171e+00,  3.6746e-01,
          7.2058e-02,  8.8119e-01,  2.8268e-01,  1.0370e+00,  1.0608e+00,
          1.9995e+00,  1.0214e+00,  4.1382e+00,  4.7325e-01,  1.4965e+00,
          1.0061e+00, -6.7459e-01, -6.6089e-01,  2.7963e+00,  1.9189e+00,
          3.0279e-01, -2.4250e-02,  2.4552e-02,  1.9926e+00,  1.7274e+00,
          1.0519e+00, -2.2927e-02,  1.2995e+00,  7.2269e-01, -4.8415e-01,
          1.5939e+00,  1.9762e+00,  1.1411e+00,  1.4182e+00,  1.0440e+00,
          4.4690e-01,  9.6211e-01,  2.4799e-01,  1.5313e+00,  1.4831e+00,
          1.7758e+00, -1.8357e+00,  5.

In [14]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [18]:
dataset_ = [d[0].cpu() for d in dataset]

In [19]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 76.28it/s]


In [20]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [21]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [22]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [23]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [24]:
tensor_dataset = torch.stack(dataset_)

In [25]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [26]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [27]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [28]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [29]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [30]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [31]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [32]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [33]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [34]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [35]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [36]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [37]:
losses = []
accs = []
sparsity = []

In [38]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 029 | Total loss: 9.632 | Reg loss: 0.009 | Tree loss: 9.632 | Accuracy: 0.000000 | 0.279 sec/iter
Epoch: 00 | Batch: 001 / 029 | Total loss: 9.619 | Reg loss: 0.009 | Tree loss: 9.619 | Accuracy: 0.000000 | 0.254 sec/iter
Epoch: 00 | Batch: 002 / 029 | Total loss: 9.609 | Reg loss: 0.008 | Tree loss: 9.609 | Accuracy: 0.000000 | 0.247 sec/iter
Epoch: 00 | Batch: 003 / 029 | Total loss: 9.603 | Reg loss: 0.008 | Tree loss: 9.603 | Accuracy: 0.000000 | 0.243 sec/iter
Epoch: 00 | Batch: 004 / 029 | Total loss: 9.588 | Reg loss: 0.008 | Tree loss: 9.588 | Accuracy: 0.000000 | 0.243 sec/iter
Epoch: 00 | Batch: 005 / 029 | Total loss: 9.586 | Reg loss: 0.008 | Tree loss: 9.586 | Accuracy: 0.000000 | 0.242 sec/iter
Epoch: 00 | Batch: 006 / 029 | Total loss: 9.570 | Reg loss: 0.008 | Tree loss: 9.570 | Accuracy: 0.003906 | 0.241 sec/iter
Epoch: 00 | Batch

Epoch: 02 | Batch: 004 / 029 | Total loss: 9.314 | Reg loss: 0.007 | Tree loss: 9.314 | Accuracy: 0.275391 | 0.241 sec/iter
Epoch: 02 | Batch: 005 / 029 | Total loss: 9.293 | Reg loss: 0.007 | Tree loss: 9.293 | Accuracy: 0.302734 | 0.241 sec/iter
Epoch: 02 | Batch: 006 / 029 | Total loss: 9.302 | Reg loss: 0.007 | Tree loss: 9.302 | Accuracy: 0.253906 | 0.24 sec/iter
Epoch: 02 | Batch: 007 / 029 | Total loss: 9.295 | Reg loss: 0.008 | Tree loss: 9.295 | Accuracy: 0.257812 | 0.24 sec/iter
Epoch: 02 | Batch: 008 / 029 | Total loss: 9.274 | Reg loss: 0.008 | Tree loss: 9.274 | Accuracy: 0.292969 | 0.24 sec/iter
Epoch: 02 | Batch: 009 / 029 | Total loss: 9.279 | Reg loss: 0.008 | Tree loss: 9.279 | Accuracy: 0.250000 | 0.24 sec/iter
Epoch: 02 | Batch: 010 / 029 | Total loss: 9.252 | Reg loss: 0.009 | Tree loss: 9.252 | Accuracy: 0.269531 | 0.24 sec/iter
Epoch: 02 | Batch: 011 / 029 | Total loss: 9.251 | Reg loss: 0.009 | Tree loss: 9.251 | Accuracy: 0.259766 | 0.24 sec/iter
Epoch: 02 | Ba

Epoch: 04 | Batch: 009 / 029 | Total loss: 8.926 | Reg loss: 0.013 | Tree loss: 8.926 | Accuracy: 0.306641 | 0.241 sec/iter
Epoch: 04 | Batch: 010 / 029 | Total loss: 8.932 | Reg loss: 0.014 | Tree loss: 8.932 | Accuracy: 0.271484 | 0.241 sec/iter
Epoch: 04 | Batch: 011 / 029 | Total loss: 8.923 | Reg loss: 0.014 | Tree loss: 8.923 | Accuracy: 0.273438 | 0.241 sec/iter
Epoch: 04 | Batch: 012 / 029 | Total loss: 8.900 | Reg loss: 0.014 | Tree loss: 8.900 | Accuracy: 0.281250 | 0.24 sec/iter
Epoch: 04 | Batch: 013 / 029 | Total loss: 8.894 | Reg loss: 0.015 | Tree loss: 8.894 | Accuracy: 0.279297 | 0.241 sec/iter
Epoch: 04 | Batch: 014 / 029 | Total loss: 8.882 | Reg loss: 0.015 | Tree loss: 8.882 | Accuracy: 0.269531 | 0.241 sec/iter
Epoch: 04 | Batch: 015 / 029 | Total loss: 8.828 | Reg loss: 0.016 | Tree loss: 8.828 | Accuracy: 0.316406 | 0.241 sec/iter
Epoch: 04 | Batch: 016 / 029 | Total loss: 8.831 | Reg loss: 0.016 | Tree loss: 8.831 | Accuracy: 0.289062 | 0.241 sec/iter
Epoch: 04

Epoch: 06 | Batch: 014 / 029 | Total loss: 8.419 | Reg loss: 0.019 | Tree loss: 8.419 | Accuracy: 0.320312 | 0.242 sec/iter
Epoch: 06 | Batch: 015 / 029 | Total loss: 8.414 | Reg loss: 0.019 | Tree loss: 8.414 | Accuracy: 0.312500 | 0.242 sec/iter
Epoch: 06 | Batch: 016 / 029 | Total loss: 8.423 | Reg loss: 0.020 | Tree loss: 8.423 | Accuracy: 0.287109 | 0.242 sec/iter
Epoch: 06 | Batch: 017 / 029 | Total loss: 8.404 | Reg loss: 0.020 | Tree loss: 8.404 | Accuracy: 0.261719 | 0.242 sec/iter
Epoch: 06 | Batch: 018 / 029 | Total loss: 8.401 | Reg loss: 0.020 | Tree loss: 8.401 | Accuracy: 0.232422 | 0.242 sec/iter
Epoch: 06 | Batch: 019 / 029 | Total loss: 8.357 | Reg loss: 0.021 | Tree loss: 8.357 | Accuracy: 0.267578 | 0.242 sec/iter
Epoch: 06 | Batch: 020 / 029 | Total loss: 8.367 | Reg loss: 0.021 | Tree loss: 8.367 | Accuracy: 0.248047 | 0.242 sec/iter
Epoch: 06 | Batch: 021 / 029 | Total loss: 8.320 | Reg loss: 0.022 | Tree loss: 8.320 | Accuracy: 0.285156 | 0.242 sec/iter
Epoch: 0

Epoch: 08 | Batch: 019 / 029 | Total loss: 7.875 | Reg loss: 0.023 | Tree loss: 7.875 | Accuracy: 0.277344 | 0.243 sec/iter
Epoch: 08 | Batch: 020 / 029 | Total loss: 7.883 | Reg loss: 0.024 | Tree loss: 7.883 | Accuracy: 0.250000 | 0.243 sec/iter
Epoch: 08 | Batch: 021 / 029 | Total loss: 7.838 | Reg loss: 0.024 | Tree loss: 7.838 | Accuracy: 0.251953 | 0.243 sec/iter
Epoch: 08 | Batch: 022 / 029 | Total loss: 7.825 | Reg loss: 0.024 | Tree loss: 7.825 | Accuracy: 0.263672 | 0.243 sec/iter
Epoch: 08 | Batch: 023 / 029 | Total loss: 7.785 | Reg loss: 0.025 | Tree loss: 7.785 | Accuracy: 0.292969 | 0.242 sec/iter
Epoch: 08 | Batch: 024 / 029 | Total loss: 7.783 | Reg loss: 0.025 | Tree loss: 7.783 | Accuracy: 0.257812 | 0.242 sec/iter
Epoch: 08 | Batch: 025 / 029 | Total loss: 7.721 | Reg loss: 0.025 | Tree loss: 7.721 | Accuracy: 0.292969 | 0.242 sec/iter
Epoch: 08 | Batch: 026 / 029 | Total loss: 7.718 | Reg loss: 0.025 | Tree loss: 7.718 | Accuracy: 0.296875 | 0.242 sec/iter
Epoch: 0

Epoch: 10 | Batch: 024 / 029 | Total loss: 7.196 | Reg loss: 0.027 | Tree loss: 7.196 | Accuracy: 0.304688 | 0.25 sec/iter
Epoch: 10 | Batch: 025 / 029 | Total loss: 7.207 | Reg loss: 0.027 | Tree loss: 7.207 | Accuracy: 0.265625 | 0.25 sec/iter
Epoch: 10 | Batch: 026 / 029 | Total loss: 7.158 | Reg loss: 0.027 | Tree loss: 7.158 | Accuracy: 0.287109 | 0.251 sec/iter
Epoch: 10 | Batch: 027 / 029 | Total loss: 7.162 | Reg loss: 0.028 | Tree loss: 7.162 | Accuracy: 0.279297 | 0.251 sec/iter
Epoch: 10 | Batch: 028 / 029 | Total loss: 7.358 | Reg loss: 0.028 | Tree loss: 7.358 | Accuracy: 0.230769 | 0.251 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 11 | Batch: 000 / 029 | Total loss: 7.446 | Reg loss: 0.024 | Tree loss: 7.446 | Accuracy: 0.298828 | 0.251 sec/iter
Epoch: 11 | Batch: 001 /

Epoch: 13 | Batch: 000 / 029 | Total loss: 6.866 | Reg loss: 0.025 | Tree loss: 6.866 | Accuracy: 0.269531 | 0.253 sec/iter
Epoch: 13 | Batch: 001 / 029 | Total loss: 6.854 | Reg loss: 0.025 | Tree loss: 6.854 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 13 | Batch: 002 / 029 | Total loss: 6.848 | Reg loss: 0.025 | Tree loss: 6.848 | Accuracy: 0.246094 | 0.253 sec/iter
Epoch: 13 | Batch: 003 / 029 | Total loss: 6.777 | Reg loss: 0.025 | Tree loss: 6.777 | Accuracy: 0.289062 | 0.253 sec/iter
Epoch: 13 | Batch: 004 / 029 | Total loss: 6.775 | Reg loss: 0.026 | Tree loss: 6.775 | Accuracy: 0.273438 | 0.253 sec/iter
Epoch: 13 | Batch: 005 / 029 | Total loss: 6.733 | Reg loss: 0.026 | Tree loss: 6.733 | Accuracy: 0.265625 | 0.253 sec/iter
Epoch: 13 | Batch: 006 / 029 | Total loss: 6.725 | Reg loss: 0.026 | Tree loss: 6.725 | Accuracy: 0.269531 | 0.253 sec/iter
Epoch: 13 | Batch: 007 / 029 | Total loss: 6.737 | Reg loss: 0.026 | Tree loss: 6.737 | Accuracy: 0.271484 | 0.253 sec/iter
Epoch: 1

Epoch: 15 | Batch: 005 / 029 | Total loss: 6.191 | Reg loss: 0.026 | Tree loss: 6.191 | Accuracy: 0.291016 | 0.255 sec/iter
Epoch: 15 | Batch: 006 / 029 | Total loss: 6.181 | Reg loss: 0.026 | Tree loss: 6.181 | Accuracy: 0.269531 | 0.255 sec/iter
Epoch: 15 | Batch: 007 / 029 | Total loss: 6.135 | Reg loss: 0.026 | Tree loss: 6.135 | Accuracy: 0.277344 | 0.255 sec/iter
Epoch: 15 | Batch: 008 / 029 | Total loss: 6.109 | Reg loss: 0.026 | Tree loss: 6.109 | Accuracy: 0.310547 | 0.255 sec/iter
Epoch: 15 | Batch: 009 / 029 | Total loss: 6.097 | Reg loss: 0.026 | Tree loss: 6.097 | Accuracy: 0.283203 | 0.255 sec/iter
Epoch: 15 | Batch: 010 / 029 | Total loss: 6.081 | Reg loss: 0.027 | Tree loss: 6.081 | Accuracy: 0.291016 | 0.255 sec/iter
Epoch: 15 | Batch: 011 / 029 | Total loss: 6.075 | Reg loss: 0.027 | Tree loss: 6.075 | Accuracy: 0.287109 | 0.255 sec/iter
Epoch: 15 | Batch: 012 / 029 | Total loss: 6.054 | Reg loss: 0.027 | Tree loss: 6.054 | Accuracy: 0.255859 | 0.255 sec/iter
Epoch: 1

Epoch: 17 | Batch: 010 / 029 | Total loss: 5.596 | Reg loss: 0.026 | Tree loss: 5.596 | Accuracy: 0.244141 | 0.256 sec/iter
Epoch: 17 | Batch: 011 / 029 | Total loss: 5.587 | Reg loss: 0.027 | Tree loss: 5.587 | Accuracy: 0.248047 | 0.256 sec/iter
Epoch: 17 | Batch: 012 / 029 | Total loss: 5.539 | Reg loss: 0.027 | Tree loss: 5.539 | Accuracy: 0.240234 | 0.256 sec/iter
Epoch: 17 | Batch: 013 / 029 | Total loss: 5.539 | Reg loss: 0.027 | Tree loss: 5.539 | Accuracy: 0.220703 | 0.256 sec/iter
Epoch: 17 | Batch: 014 / 029 | Total loss: 5.485 | Reg loss: 0.027 | Tree loss: 5.485 | Accuracy: 0.236328 | 0.256 sec/iter
Epoch: 17 | Batch: 015 / 029 | Total loss: 5.473 | Reg loss: 0.027 | Tree loss: 5.473 | Accuracy: 0.257812 | 0.256 sec/iter
Epoch: 17 | Batch: 016 / 029 | Total loss: 5.449 | Reg loss: 0.027 | Tree loss: 5.449 | Accuracy: 0.220703 | 0.256 sec/iter
Epoch: 17 | Batch: 017 / 029 | Total loss: 5.442 | Reg loss: 0.027 | Tree loss: 5.442 | Accuracy: 0.246094 | 0.256 sec/iter
Epoch: 1

Epoch: 19 | Batch: 015 / 029 | Total loss: 5.012 | Reg loss: 0.026 | Tree loss: 5.012 | Accuracy: 0.181641 | 0.257 sec/iter
Epoch: 19 | Batch: 016 / 029 | Total loss: 4.997 | Reg loss: 0.026 | Tree loss: 4.997 | Accuracy: 0.144531 | 0.257 sec/iter
Epoch: 19 | Batch: 017 / 029 | Total loss: 4.967 | Reg loss: 0.026 | Tree loss: 4.967 | Accuracy: 0.160156 | 0.258 sec/iter
Epoch: 19 | Batch: 018 / 029 | Total loss: 4.945 | Reg loss: 0.026 | Tree loss: 4.945 | Accuracy: 0.160156 | 0.258 sec/iter
Epoch: 19 | Batch: 019 / 029 | Total loss: 4.933 | Reg loss: 0.026 | Tree loss: 4.933 | Accuracy: 0.150391 | 0.258 sec/iter
Epoch: 19 | Batch: 020 / 029 | Total loss: 4.940 | Reg loss: 0.026 | Tree loss: 4.940 | Accuracy: 0.166016 | 0.258 sec/iter
Epoch: 19 | Batch: 021 / 029 | Total loss: 4.904 | Reg loss: 0.027 | Tree loss: 4.904 | Accuracy: 0.160156 | 0.258 sec/iter
Epoch: 19 | Batch: 022 / 029 | Total loss: 4.876 | Reg loss: 0.027 | Tree loss: 4.876 | Accuracy: 0.166016 | 0.258 sec/iter
Epoch: 1

Epoch: 21 | Batch: 020 / 029 | Total loss: 4.416 | Reg loss: 0.026 | Tree loss: 4.416 | Accuracy: 0.130859 | 0.258 sec/iter
Epoch: 21 | Batch: 021 / 029 | Total loss: 4.455 | Reg loss: 0.026 | Tree loss: 4.455 | Accuracy: 0.115234 | 0.258 sec/iter
Epoch: 21 | Batch: 022 / 029 | Total loss: 4.412 | Reg loss: 0.026 | Tree loss: 4.412 | Accuracy: 0.119141 | 0.258 sec/iter
Epoch: 21 | Batch: 023 / 029 | Total loss: 4.389 | Reg loss: 0.026 | Tree loss: 4.389 | Accuracy: 0.128906 | 0.258 sec/iter
Epoch: 21 | Batch: 024 / 029 | Total loss: 4.386 | Reg loss: 0.027 | Tree loss: 4.386 | Accuracy: 0.126953 | 0.258 sec/iter
Epoch: 21 | Batch: 025 / 029 | Total loss: 4.343 | Reg loss: 0.027 | Tree loss: 4.343 | Accuracy: 0.117188 | 0.258 sec/iter
Epoch: 21 | Batch: 026 / 029 | Total loss: 4.320 | Reg loss: 0.027 | Tree loss: 4.320 | Accuracy: 0.166016 | 0.258 sec/iter
Epoch: 21 | Batch: 027 / 029 | Total loss: 4.360 | Reg loss: 0.027 | Tree loss: 4.360 | Accuracy: 0.111328 | 0.258 sec/iter
Epoch: 2

Epoch: 23 | Batch: 025 / 029 | Total loss: 3.854 | Reg loss: 0.027 | Tree loss: 3.854 | Accuracy: 0.099609 | 0.257 sec/iter
Epoch: 23 | Batch: 026 / 029 | Total loss: 3.845 | Reg loss: 0.027 | Tree loss: 3.845 | Accuracy: 0.138672 | 0.257 sec/iter
Epoch: 23 | Batch: 027 / 029 | Total loss: 3.807 | Reg loss: 0.027 | Tree loss: 3.807 | Accuracy: 0.126953 | 0.257 sec/iter
Epoch: 23 | Batch: 028 / 029 | Total loss: 3.850 | Reg loss: 0.027 | Tree loss: 3.850 | Accuracy: 0.076923 | 0.257 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 24 | Batch: 000 / 029 | Total loss: 3.935 | Reg loss: 0.025 | Tree loss: 3.935 | Accuracy: 0.128906 | 0.257 sec/iter
Epoch: 24 | Batch: 001 / 029 | Total loss: 3.907 | Reg loss: 0.025 | Tree loss: 3.907 | Accuracy: 0.128906 | 0.257 sec/iter
Epoch: 24 | Batch: 002

Epoch: 26 | Batch: 000 / 029 | Total loss: 3.518 | Reg loss: 0.026 | Tree loss: 3.518 | Accuracy: 0.105469 | 0.257 sec/iter
Epoch: 26 | Batch: 001 / 029 | Total loss: 3.536 | Reg loss: 0.026 | Tree loss: 3.536 | Accuracy: 0.115234 | 0.257 sec/iter
Epoch: 26 | Batch: 002 / 029 | Total loss: 3.473 | Reg loss: 0.026 | Tree loss: 3.473 | Accuracy: 0.117188 | 0.257 sec/iter
Epoch: 26 | Batch: 003 / 029 | Total loss: 3.499 | Reg loss: 0.026 | Tree loss: 3.499 | Accuracy: 0.138672 | 0.257 sec/iter
Epoch: 26 | Batch: 004 / 029 | Total loss: 3.525 | Reg loss: 0.026 | Tree loss: 3.525 | Accuracy: 0.136719 | 0.257 sec/iter
Epoch: 26 | Batch: 005 / 029 | Total loss: 3.478 | Reg loss: 0.026 | Tree loss: 3.478 | Accuracy: 0.148438 | 0.257 sec/iter
Epoch: 26 | Batch: 006 / 029 | Total loss: 3.451 | Reg loss: 0.026 | Tree loss: 3.451 | Accuracy: 0.105469 | 0.257 sec/iter
Epoch: 26 | Batch: 007 / 029 | Total loss: 3.477 | Reg loss: 0.026 | Tree loss: 3.477 | Accuracy: 0.101562 | 0.257 sec/iter
Epoch: 2

Epoch: 28 | Batch: 005 / 029 | Total loss: 3.129 | Reg loss: 0.026 | Tree loss: 3.129 | Accuracy: 0.144531 | 0.256 sec/iter
Epoch: 28 | Batch: 006 / 029 | Total loss: 3.136 | Reg loss: 0.026 | Tree loss: 3.136 | Accuracy: 0.128906 | 0.256 sec/iter
Epoch: 28 | Batch: 007 / 029 | Total loss: 3.067 | Reg loss: 0.026 | Tree loss: 3.067 | Accuracy: 0.142578 | 0.256 sec/iter
Epoch: 28 | Batch: 008 / 029 | Total loss: 3.147 | Reg loss: 0.026 | Tree loss: 3.147 | Accuracy: 0.126953 | 0.256 sec/iter
Epoch: 28 | Batch: 009 / 029 | Total loss: 3.106 | Reg loss: 0.026 | Tree loss: 3.106 | Accuracy: 0.144531 | 0.256 sec/iter
Epoch: 28 | Batch: 010 / 029 | Total loss: 3.133 | Reg loss: 0.026 | Tree loss: 3.133 | Accuracy: 0.113281 | 0.256 sec/iter
Epoch: 28 | Batch: 011 / 029 | Total loss: 3.092 | Reg loss: 0.026 | Tree loss: 3.092 | Accuracy: 0.148438 | 0.256 sec/iter
Epoch: 28 | Batch: 012 / 029 | Total loss: 3.139 | Reg loss: 0.026 | Tree loss: 3.139 | Accuracy: 0.123047 | 0.256 sec/iter
Epoch: 2

Epoch: 30 | Batch: 010 / 029 | Total loss: 2.859 | Reg loss: 0.026 | Tree loss: 2.859 | Accuracy: 0.152344 | 0.256 sec/iter
Epoch: 30 | Batch: 011 / 029 | Total loss: 2.861 | Reg loss: 0.026 | Tree loss: 2.861 | Accuracy: 0.136719 | 0.256 sec/iter
Epoch: 30 | Batch: 012 / 029 | Total loss: 2.847 | Reg loss: 0.026 | Tree loss: 2.847 | Accuracy: 0.154297 | 0.256 sec/iter
Epoch: 30 | Batch: 013 / 029 | Total loss: 2.914 | Reg loss: 0.026 | Tree loss: 2.914 | Accuracy: 0.121094 | 0.256 sec/iter
Epoch: 30 | Batch: 014 / 029 | Total loss: 2.871 | Reg loss: 0.026 | Tree loss: 2.871 | Accuracy: 0.130859 | 0.255 sec/iter
Epoch: 30 | Batch: 015 / 029 | Total loss: 2.867 | Reg loss: 0.026 | Tree loss: 2.867 | Accuracy: 0.125000 | 0.255 sec/iter
Epoch: 30 | Batch: 016 / 029 | Total loss: 2.852 | Reg loss: 0.026 | Tree loss: 2.852 | Accuracy: 0.148438 | 0.255 sec/iter
Epoch: 30 | Batch: 017 / 029 | Total loss: 2.851 | Reg loss: 0.026 | Tree loss: 2.851 | Accuracy: 0.150391 | 0.255 sec/iter
Epoch: 3

Epoch: 32 | Batch: 015 / 029 | Total loss: 2.671 | Reg loss: 0.026 | Tree loss: 2.671 | Accuracy: 0.136719 | 0.255 sec/iter
Epoch: 32 | Batch: 016 / 029 | Total loss: 2.680 | Reg loss: 0.026 | Tree loss: 2.680 | Accuracy: 0.138672 | 0.255 sec/iter
Epoch: 32 | Batch: 017 / 029 | Total loss: 2.681 | Reg loss: 0.026 | Tree loss: 2.681 | Accuracy: 0.125000 | 0.255 sec/iter
Epoch: 32 | Batch: 018 / 029 | Total loss: 2.698 | Reg loss: 0.026 | Tree loss: 2.698 | Accuracy: 0.136719 | 0.255 sec/iter
Epoch: 32 | Batch: 019 / 029 | Total loss: 2.648 | Reg loss: 0.026 | Tree loss: 2.648 | Accuracy: 0.152344 | 0.255 sec/iter
Epoch: 32 | Batch: 020 / 029 | Total loss: 2.665 | Reg loss: 0.026 | Tree loss: 2.665 | Accuracy: 0.160156 | 0.255 sec/iter
Epoch: 32 | Batch: 021 / 029 | Total loss: 2.724 | Reg loss: 0.026 | Tree loss: 2.724 | Accuracy: 0.095703 | 0.255 sec/iter
Epoch: 32 | Batch: 022 / 029 | Total loss: 2.657 | Reg loss: 0.026 | Tree loss: 2.657 | Accuracy: 0.171875 | 0.255 sec/iter
Epoch: 3

Epoch: 34 | Batch: 020 / 029 | Total loss: 2.547 | Reg loss: 0.026 | Tree loss: 2.547 | Accuracy: 0.146484 | 0.255 sec/iter
Epoch: 34 | Batch: 021 / 029 | Total loss: 2.541 | Reg loss: 0.026 | Tree loss: 2.541 | Accuracy: 0.119141 | 0.255 sec/iter
Epoch: 34 | Batch: 022 / 029 | Total loss: 2.527 | Reg loss: 0.026 | Tree loss: 2.527 | Accuracy: 0.111328 | 0.255 sec/iter
Epoch: 34 | Batch: 023 / 029 | Total loss: 2.528 | Reg loss: 0.026 | Tree loss: 2.528 | Accuracy: 0.111328 | 0.255 sec/iter
Epoch: 34 | Batch: 024 / 029 | Total loss: 2.548 | Reg loss: 0.026 | Tree loss: 2.548 | Accuracy: 0.136719 | 0.255 sec/iter
Epoch: 34 | Batch: 025 / 029 | Total loss: 2.556 | Reg loss: 0.026 | Tree loss: 2.556 | Accuracy: 0.119141 | 0.255 sec/iter
Epoch: 34 | Batch: 026 / 029 | Total loss: 2.559 | Reg loss: 0.026 | Tree loss: 2.559 | Accuracy: 0.119141 | 0.255 sec/iter
Epoch: 34 | Batch: 027 / 029 | Total loss: 2.562 | Reg loss: 0.026 | Tree loss: 2.562 | Accuracy: 0.132812 | 0.255 sec/iter
Epoch: 3

Epoch: 36 | Batch: 025 / 029 | Total loss: 2.434 | Reg loss: 0.026 | Tree loss: 2.434 | Accuracy: 0.275391 | 0.255 sec/iter
Epoch: 36 | Batch: 026 / 029 | Total loss: 2.464 | Reg loss: 0.026 | Tree loss: 2.464 | Accuracy: 0.259766 | 0.254 sec/iter
Epoch: 36 | Batch: 027 / 029 | Total loss: 2.443 | Reg loss: 0.026 | Tree loss: 2.443 | Accuracy: 0.263672 | 0.254 sec/iter
Epoch: 36 | Batch: 028 / 029 | Total loss: 2.345 | Reg loss: 0.026 | Tree loss: 2.345 | Accuracy: 0.538462 | 0.254 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 37 | Batch: 000 / 029 | Total loss: 2.493 | Reg loss: 0.025 | Tree loss: 2.493 | Accuracy: 0.236328 | 0.255 sec/iter
Epoch: 37 | Batch: 001 / 029 | Total loss: 2.452 | Reg loss: 0.025 | Tree loss: 2.452 | Accuracy: 0.257812 | 0.255 sec/iter
Epoch: 37 | Batch: 002

Epoch: 39 | Batch: 000 / 029 | Total loss: 2.333 | Reg loss: 0.025 | Tree loss: 2.333 | Accuracy: 0.300781 | 0.254 sec/iter
Epoch: 39 | Batch: 001 / 029 | Total loss: 2.347 | Reg loss: 0.025 | Tree loss: 2.347 | Accuracy: 0.300781 | 0.254 sec/iter
Epoch: 39 | Batch: 002 / 029 | Total loss: 2.386 | Reg loss: 0.025 | Tree loss: 2.386 | Accuracy: 0.234375 | 0.254 sec/iter
Epoch: 39 | Batch: 003 / 029 | Total loss: 2.344 | Reg loss: 0.025 | Tree loss: 2.344 | Accuracy: 0.306641 | 0.254 sec/iter
Epoch: 39 | Batch: 004 / 029 | Total loss: 2.362 | Reg loss: 0.025 | Tree loss: 2.362 | Accuracy: 0.261719 | 0.254 sec/iter
Epoch: 39 | Batch: 005 / 029 | Total loss: 2.371 | Reg loss: 0.025 | Tree loss: 2.371 | Accuracy: 0.265625 | 0.254 sec/iter
Epoch: 39 | Batch: 006 / 029 | Total loss: 2.349 | Reg loss: 0.025 | Tree loss: 2.349 | Accuracy: 0.304688 | 0.254 sec/iter
Epoch: 39 | Batch: 007 / 029 | Total loss: 2.362 | Reg loss: 0.025 | Tree loss: 2.362 | Accuracy: 0.271484 | 0.254 sec/iter
Epoch: 3

Epoch: 41 | Batch: 005 / 029 | Total loss: 2.363 | Reg loss: 0.025 | Tree loss: 2.363 | Accuracy: 0.253906 | 0.254 sec/iter
Epoch: 41 | Batch: 006 / 029 | Total loss: 2.291 | Reg loss: 0.025 | Tree loss: 2.291 | Accuracy: 0.291016 | 0.254 sec/iter
Epoch: 41 | Batch: 007 / 029 | Total loss: 2.340 | Reg loss: 0.025 | Tree loss: 2.340 | Accuracy: 0.283203 | 0.254 sec/iter
Epoch: 41 | Batch: 008 / 029 | Total loss: 2.280 | Reg loss: 0.025 | Tree loss: 2.280 | Accuracy: 0.283203 | 0.254 sec/iter
Epoch: 41 | Batch: 009 / 029 | Total loss: 2.281 | Reg loss: 0.025 | Tree loss: 2.281 | Accuracy: 0.263672 | 0.254 sec/iter
Epoch: 41 | Batch: 010 / 029 | Total loss: 2.333 | Reg loss: 0.025 | Tree loss: 2.333 | Accuracy: 0.291016 | 0.254 sec/iter
Epoch: 41 | Batch: 011 / 029 | Total loss: 2.274 | Reg loss: 0.025 | Tree loss: 2.274 | Accuracy: 0.296875 | 0.254 sec/iter
Epoch: 41 | Batch: 012 / 029 | Total loss: 2.291 | Reg loss: 0.025 | Tree loss: 2.291 | Accuracy: 0.269531 | 0.254 sec/iter
Epoch: 4

Epoch: 43 | Batch: 010 / 029 | Total loss: 2.262 | Reg loss: 0.025 | Tree loss: 2.262 | Accuracy: 0.279297 | 0.254 sec/iter
Epoch: 43 | Batch: 011 / 029 | Total loss: 2.240 | Reg loss: 0.025 | Tree loss: 2.240 | Accuracy: 0.267578 | 0.254 sec/iter
Epoch: 43 | Batch: 012 / 029 | Total loss: 2.256 | Reg loss: 0.025 | Tree loss: 2.256 | Accuracy: 0.267578 | 0.254 sec/iter
Epoch: 43 | Batch: 013 / 029 | Total loss: 2.211 | Reg loss: 0.025 | Tree loss: 2.211 | Accuracy: 0.292969 | 0.254 sec/iter
Epoch: 43 | Batch: 014 / 029 | Total loss: 2.241 | Reg loss: 0.025 | Tree loss: 2.241 | Accuracy: 0.279297 | 0.254 sec/iter
Epoch: 43 | Batch: 015 / 029 | Total loss: 2.261 | Reg loss: 0.025 | Tree loss: 2.261 | Accuracy: 0.281250 | 0.254 sec/iter
Epoch: 43 | Batch: 016 / 029 | Total loss: 2.236 | Reg loss: 0.025 | Tree loss: 2.236 | Accuracy: 0.273438 | 0.254 sec/iter
Epoch: 43 | Batch: 017 / 029 | Total loss: 2.279 | Reg loss: 0.025 | Tree loss: 2.279 | Accuracy: 0.287109 | 0.254 sec/iter
Epoch: 4

Epoch: 45 | Batch: 015 / 029 | Total loss: 2.235 | Reg loss: 0.025 | Tree loss: 2.235 | Accuracy: 0.277344 | 0.254 sec/iter
Epoch: 45 | Batch: 016 / 029 | Total loss: 2.224 | Reg loss: 0.025 | Tree loss: 2.224 | Accuracy: 0.253906 | 0.254 sec/iter
Epoch: 45 | Batch: 017 / 029 | Total loss: 2.207 | Reg loss: 0.025 | Tree loss: 2.207 | Accuracy: 0.318359 | 0.254 sec/iter
Epoch: 45 | Batch: 018 / 029 | Total loss: 2.221 | Reg loss: 0.025 | Tree loss: 2.221 | Accuracy: 0.267578 | 0.254 sec/iter
Epoch: 45 | Batch: 019 / 029 | Total loss: 2.216 | Reg loss: 0.025 | Tree loss: 2.216 | Accuracy: 0.300781 | 0.254 sec/iter
Epoch: 45 | Batch: 020 / 029 | Total loss: 2.160 | Reg loss: 0.025 | Tree loss: 2.160 | Accuracy: 0.265625 | 0.254 sec/iter
Epoch: 45 | Batch: 021 / 029 | Total loss: 2.209 | Reg loss: 0.025 | Tree loss: 2.209 | Accuracy: 0.300781 | 0.254 sec/iter
Epoch: 45 | Batch: 022 / 029 | Total loss: 2.230 | Reg loss: 0.025 | Tree loss: 2.230 | Accuracy: 0.269531 | 0.254 sec/iter
Epoch: 4

Epoch: 47 | Batch: 020 / 029 | Total loss: 2.164 | Reg loss: 0.025 | Tree loss: 2.164 | Accuracy: 0.320312 | 0.254 sec/iter
Epoch: 47 | Batch: 021 / 029 | Total loss: 2.226 | Reg loss: 0.025 | Tree loss: 2.226 | Accuracy: 0.277344 | 0.254 sec/iter
Epoch: 47 | Batch: 022 / 029 | Total loss: 2.166 | Reg loss: 0.025 | Tree loss: 2.166 | Accuracy: 0.287109 | 0.254 sec/iter
Epoch: 47 | Batch: 023 / 029 | Total loss: 2.209 | Reg loss: 0.025 | Tree loss: 2.209 | Accuracy: 0.289062 | 0.254 sec/iter
Epoch: 47 | Batch: 024 / 029 | Total loss: 2.205 | Reg loss: 0.025 | Tree loss: 2.205 | Accuracy: 0.240234 | 0.254 sec/iter
Epoch: 47 | Batch: 025 / 029 | Total loss: 2.189 | Reg loss: 0.025 | Tree loss: 2.189 | Accuracy: 0.275391 | 0.254 sec/iter
Epoch: 47 | Batch: 026 / 029 | Total loss: 2.157 | Reg loss: 0.025 | Tree loss: 2.157 | Accuracy: 0.271484 | 0.254 sec/iter
Epoch: 47 | Batch: 027 / 029 | Total loss: 2.134 | Reg loss: 0.025 | Tree loss: 2.134 | Accuracy: 0.310547 | 0.254 sec/iter
Epoch: 4

Epoch: 49 | Batch: 025 / 029 | Total loss: 2.167 | Reg loss: 0.025 | Tree loss: 2.167 | Accuracy: 0.267578 | 0.254 sec/iter
Epoch: 49 | Batch: 026 / 029 | Total loss: 2.133 | Reg loss: 0.025 | Tree loss: 2.133 | Accuracy: 0.281250 | 0.254 sec/iter
Epoch: 49 | Batch: 027 / 029 | Total loss: 2.171 | Reg loss: 0.025 | Tree loss: 2.171 | Accuracy: 0.283203 | 0.254 sec/iter
Epoch: 49 | Batch: 028 / 029 | Total loss: 2.026 | Reg loss: 0.025 | Tree loss: 2.026 | Accuracy: 0.307692 | 0.254 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 50 | Batch: 000 / 029 | Total loss: 2.176 | Reg loss: 0.025 | Tree loss: 2.176 | Accuracy: 0.265625 | 0.254 sec/iter
Epoch: 50 | Batch: 001 / 029 | Total loss: 2.174 | Reg loss: 0.025 | Tree loss: 2.174 | Accuracy: 0.271484 | 0.254 sec/iter
Epoch: 50 | Batch: 002

Epoch: 52 | Batch: 000 / 029 | Total loss: 2.146 | Reg loss: 0.025 | Tree loss: 2.146 | Accuracy: 0.312500 | 0.254 sec/iter
Epoch: 52 | Batch: 001 / 029 | Total loss: 2.205 | Reg loss: 0.025 | Tree loss: 2.205 | Accuracy: 0.250000 | 0.254 sec/iter
Epoch: 52 | Batch: 002 / 029 | Total loss: 2.127 | Reg loss: 0.025 | Tree loss: 2.127 | Accuracy: 0.283203 | 0.254 sec/iter
Epoch: 52 | Batch: 003 / 029 | Total loss: 2.154 | Reg loss: 0.025 | Tree loss: 2.154 | Accuracy: 0.263672 | 0.254 sec/iter
Epoch: 52 | Batch: 004 / 029 | Total loss: 2.166 | Reg loss: 0.025 | Tree loss: 2.166 | Accuracy: 0.269531 | 0.254 sec/iter
Epoch: 52 | Batch: 005 / 029 | Total loss: 2.136 | Reg loss: 0.025 | Tree loss: 2.136 | Accuracy: 0.257812 | 0.254 sec/iter
Epoch: 52 | Batch: 006 / 029 | Total loss: 2.150 | Reg loss: 0.025 | Tree loss: 2.150 | Accuracy: 0.285156 | 0.254 sec/iter
Epoch: 52 | Batch: 007 / 029 | Total loss: 2.169 | Reg loss: 0.025 | Tree loss: 2.169 | Accuracy: 0.259766 | 0.254 sec/iter
Epoch: 5

Epoch: 54 | Batch: 005 / 029 | Total loss: 2.142 | Reg loss: 0.025 | Tree loss: 2.142 | Accuracy: 0.263672 | 0.254 sec/iter
Epoch: 54 | Batch: 006 / 029 | Total loss: 2.168 | Reg loss: 0.025 | Tree loss: 2.168 | Accuracy: 0.250000 | 0.254 sec/iter
Epoch: 54 | Batch: 007 / 029 | Total loss: 2.142 | Reg loss: 0.025 | Tree loss: 2.142 | Accuracy: 0.289062 | 0.254 sec/iter
Epoch: 54 | Batch: 008 / 029 | Total loss: 2.114 | Reg loss: 0.025 | Tree loss: 2.114 | Accuracy: 0.291016 | 0.254 sec/iter
Epoch: 54 | Batch: 009 / 029 | Total loss: 2.157 | Reg loss: 0.025 | Tree loss: 2.157 | Accuracy: 0.240234 | 0.254 sec/iter
Epoch: 54 | Batch: 010 / 029 | Total loss: 2.178 | Reg loss: 0.025 | Tree loss: 2.178 | Accuracy: 0.275391 | 0.254 sec/iter
Epoch: 54 | Batch: 011 / 029 | Total loss: 2.135 | Reg loss: 0.025 | Tree loss: 2.135 | Accuracy: 0.269531 | 0.254 sec/iter
Epoch: 54 | Batch: 012 / 029 | Total loss: 2.123 | Reg loss: 0.025 | Tree loss: 2.123 | Accuracy: 0.271484 | 0.254 sec/iter
Epoch: 5

Epoch: 56 | Batch: 010 / 029 | Total loss: 2.174 | Reg loss: 0.025 | Tree loss: 2.174 | Accuracy: 0.238281 | 0.253 sec/iter
Epoch: 56 | Batch: 011 / 029 | Total loss: 2.106 | Reg loss: 0.025 | Tree loss: 2.106 | Accuracy: 0.275391 | 0.253 sec/iter
Epoch: 56 | Batch: 012 / 029 | Total loss: 2.098 | Reg loss: 0.025 | Tree loss: 2.098 | Accuracy: 0.314453 | 0.253 sec/iter
Epoch: 56 | Batch: 013 / 029 | Total loss: 2.114 | Reg loss: 0.025 | Tree loss: 2.114 | Accuracy: 0.273438 | 0.253 sec/iter
Epoch: 56 | Batch: 014 / 029 | Total loss: 2.156 | Reg loss: 0.025 | Tree loss: 2.156 | Accuracy: 0.275391 | 0.253 sec/iter
Epoch: 56 | Batch: 015 / 029 | Total loss: 2.100 | Reg loss: 0.025 | Tree loss: 2.100 | Accuracy: 0.291016 | 0.253 sec/iter
Epoch: 56 | Batch: 016 / 029 | Total loss: 2.158 | Reg loss: 0.025 | Tree loss: 2.158 | Accuracy: 0.269531 | 0.253 sec/iter
Epoch: 56 | Batch: 017 / 029 | Total loss: 2.128 | Reg loss: 0.025 | Tree loss: 2.128 | Accuracy: 0.253906 | 0.253 sec/iter
Epoch: 5

Epoch: 58 | Batch: 015 / 029 | Total loss: 2.106 | Reg loss: 0.025 | Tree loss: 2.106 | Accuracy: 0.285156 | 0.253 sec/iter
Epoch: 58 | Batch: 016 / 029 | Total loss: 2.151 | Reg loss: 0.025 | Tree loss: 2.151 | Accuracy: 0.251953 | 0.253 sec/iter
Epoch: 58 | Batch: 017 / 029 | Total loss: 2.128 | Reg loss: 0.025 | Tree loss: 2.128 | Accuracy: 0.306641 | 0.253 sec/iter
Epoch: 58 | Batch: 018 / 029 | Total loss: 2.126 | Reg loss: 0.025 | Tree loss: 2.126 | Accuracy: 0.244141 | 0.253 sec/iter
Epoch: 58 | Batch: 019 / 029 | Total loss: 2.090 | Reg loss: 0.025 | Tree loss: 2.090 | Accuracy: 0.318359 | 0.253 sec/iter
Epoch: 58 | Batch: 020 / 029 | Total loss: 2.114 | Reg loss: 0.025 | Tree loss: 2.114 | Accuracy: 0.304688 | 0.253 sec/iter
Epoch: 58 | Batch: 021 / 029 | Total loss: 2.121 | Reg loss: 0.025 | Tree loss: 2.121 | Accuracy: 0.273438 | 0.253 sec/iter
Epoch: 58 | Batch: 022 / 029 | Total loss: 2.056 | Reg loss: 0.025 | Tree loss: 2.056 | Accuracy: 0.298828 | 0.253 sec/iter
Epoch: 5

Epoch: 60 | Batch: 020 / 029 | Total loss: 2.115 | Reg loss: 0.025 | Tree loss: 2.115 | Accuracy: 0.265625 | 0.253 sec/iter
Epoch: 60 | Batch: 021 / 029 | Total loss: 2.152 | Reg loss: 0.025 | Tree loss: 2.152 | Accuracy: 0.246094 | 0.253 sec/iter
Epoch: 60 | Batch: 022 / 029 | Total loss: 2.111 | Reg loss: 0.025 | Tree loss: 2.111 | Accuracy: 0.285156 | 0.253 sec/iter
Epoch: 60 | Batch: 023 / 029 | Total loss: 2.077 | Reg loss: 0.025 | Tree loss: 2.077 | Accuracy: 0.312500 | 0.253 sec/iter
Epoch: 60 | Batch: 024 / 029 | Total loss: 2.037 | Reg loss: 0.025 | Tree loss: 2.037 | Accuracy: 0.296875 | 0.253 sec/iter
Epoch: 60 | Batch: 025 / 029 | Total loss: 2.099 | Reg loss: 0.025 | Tree loss: 2.099 | Accuracy: 0.267578 | 0.253 sec/iter
Epoch: 60 | Batch: 026 / 029 | Total loss: 2.104 | Reg loss: 0.025 | Tree loss: 2.104 | Accuracy: 0.261719 | 0.253 sec/iter
Epoch: 60 | Batch: 027 / 029 | Total loss: 2.112 | Reg loss: 0.025 | Tree loss: 2.112 | Accuracy: 0.250000 | 0.253 sec/iter
Epoch: 6

Epoch: 62 | Batch: 025 / 029 | Total loss: 2.087 | Reg loss: 0.025 | Tree loss: 2.087 | Accuracy: 0.277344 | 0.253 sec/iter
Epoch: 62 | Batch: 026 / 029 | Total loss: 2.097 | Reg loss: 0.025 | Tree loss: 2.097 | Accuracy: 0.289062 | 0.253 sec/iter
Epoch: 62 | Batch: 027 / 029 | Total loss: 2.073 | Reg loss: 0.025 | Tree loss: 2.073 | Accuracy: 0.292969 | 0.253 sec/iter
Epoch: 62 | Batch: 028 / 029 | Total loss: 2.018 | Reg loss: 0.025 | Tree loss: 2.018 | Accuracy: 0.230769 | 0.253 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 63 | Batch: 000 / 029 | Total loss: 2.149 | Reg loss: 0.025 | Tree loss: 2.149 | Accuracy: 0.251953 | 0.253 sec/iter
Epoch: 63 | Batch: 001 / 029 | Total loss: 2.076 | Reg loss: 0.025 | Tree loss: 2.076 | Accuracy: 0.292969 | 0.253 sec/iter
Epoch: 63 | Batch: 002

Epoch: 65 | Batch: 000 / 029 | Total loss: 2.108 | Reg loss: 0.024 | Tree loss: 2.108 | Accuracy: 0.269531 | 0.253 sec/iter
Epoch: 65 | Batch: 001 / 029 | Total loss: 2.063 | Reg loss: 0.025 | Tree loss: 2.063 | Accuracy: 0.279297 | 0.253 sec/iter
Epoch: 65 | Batch: 002 / 029 | Total loss: 2.121 | Reg loss: 0.025 | Tree loss: 2.121 | Accuracy: 0.248047 | 0.253 sec/iter
Epoch: 65 | Batch: 003 / 029 | Total loss: 2.050 | Reg loss: 0.025 | Tree loss: 2.050 | Accuracy: 0.294922 | 0.253 sec/iter
Epoch: 65 | Batch: 004 / 029 | Total loss: 2.089 | Reg loss: 0.025 | Tree loss: 2.089 | Accuracy: 0.261719 | 0.253 sec/iter
Epoch: 65 | Batch: 005 / 029 | Total loss: 2.185 | Reg loss: 0.025 | Tree loss: 2.185 | Accuracy: 0.224609 | 0.253 sec/iter
Epoch: 65 | Batch: 006 / 029 | Total loss: 2.105 | Reg loss: 0.025 | Tree loss: 2.105 | Accuracy: 0.267578 | 0.253 sec/iter
Epoch: 65 | Batch: 007 / 029 | Total loss: 2.031 | Reg loss: 0.025 | Tree loss: 2.031 | Accuracy: 0.296875 | 0.253 sec/iter
Epoch: 6

Epoch: 67 | Batch: 005 / 029 | Total loss: 2.054 | Reg loss: 0.024 | Tree loss: 2.054 | Accuracy: 0.287109 | 0.253 sec/iter
Epoch: 67 | Batch: 006 / 029 | Total loss: 2.026 | Reg loss: 0.024 | Tree loss: 2.026 | Accuracy: 0.310547 | 0.253 sec/iter
Epoch: 67 | Batch: 007 / 029 | Total loss: 2.059 | Reg loss: 0.024 | Tree loss: 2.059 | Accuracy: 0.296875 | 0.253 sec/iter
Epoch: 67 | Batch: 008 / 029 | Total loss: 2.100 | Reg loss: 0.024 | Tree loss: 2.100 | Accuracy: 0.269531 | 0.253 sec/iter
Epoch: 67 | Batch: 009 / 029 | Total loss: 2.075 | Reg loss: 0.024 | Tree loss: 2.075 | Accuracy: 0.285156 | 0.253 sec/iter
Epoch: 67 | Batch: 010 / 029 | Total loss: 2.085 | Reg loss: 0.024 | Tree loss: 2.085 | Accuracy: 0.263672 | 0.253 sec/iter
Epoch: 67 | Batch: 011 / 029 | Total loss: 2.104 | Reg loss: 0.024 | Tree loss: 2.104 | Accuracy: 0.240234 | 0.253 sec/iter
Epoch: 67 | Batch: 012 / 029 | Total loss: 2.102 | Reg loss: 0.024 | Tree loss: 2.102 | Accuracy: 0.257812 | 0.253 sec/iter
Epoch: 6

Epoch: 69 | Batch: 010 / 029 | Total loss: 2.038 | Reg loss: 0.024 | Tree loss: 2.038 | Accuracy: 0.287109 | 0.253 sec/iter
Epoch: 69 | Batch: 011 / 029 | Total loss: 2.096 | Reg loss: 0.024 | Tree loss: 2.096 | Accuracy: 0.275391 | 0.253 sec/iter
Epoch: 69 | Batch: 012 / 029 | Total loss: 2.001 | Reg loss: 0.024 | Tree loss: 2.001 | Accuracy: 0.300781 | 0.253 sec/iter
Epoch: 69 | Batch: 013 / 029 | Total loss: 2.077 | Reg loss: 0.024 | Tree loss: 2.077 | Accuracy: 0.312500 | 0.253 sec/iter
Epoch: 69 | Batch: 014 / 029 | Total loss: 2.079 | Reg loss: 0.024 | Tree loss: 2.079 | Accuracy: 0.250000 | 0.253 sec/iter
Epoch: 69 | Batch: 015 / 029 | Total loss: 2.048 | Reg loss: 0.024 | Tree loss: 2.048 | Accuracy: 0.279297 | 0.253 sec/iter
Epoch: 69 | Batch: 016 / 029 | Total loss: 2.069 | Reg loss: 0.024 | Tree loss: 2.069 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 69 | Batch: 017 / 029 | Total loss: 2.065 | Reg loss: 0.024 | Tree loss: 2.065 | Accuracy: 0.253906 | 0.253 sec/iter
Epoch: 6

Epoch: 71 | Batch: 015 / 029 | Total loss: 2.023 | Reg loss: 0.024 | Tree loss: 2.023 | Accuracy: 0.289062 | 0.253 sec/iter
Epoch: 71 | Batch: 016 / 029 | Total loss: 2.051 | Reg loss: 0.024 | Tree loss: 2.051 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 71 | Batch: 017 / 029 | Total loss: 2.053 | Reg loss: 0.024 | Tree loss: 2.053 | Accuracy: 0.304688 | 0.253 sec/iter
Epoch: 71 | Batch: 018 / 029 | Total loss: 2.053 | Reg loss: 0.024 | Tree loss: 2.053 | Accuracy: 0.294922 | 0.253 sec/iter
Epoch: 71 | Batch: 019 / 029 | Total loss: 2.070 | Reg loss: 0.024 | Tree loss: 2.070 | Accuracy: 0.292969 | 0.253 sec/iter
Epoch: 71 | Batch: 020 / 029 | Total loss: 2.048 | Reg loss: 0.024 | Tree loss: 2.048 | Accuracy: 0.277344 | 0.253 sec/iter
Epoch: 71 | Batch: 021 / 029 | Total loss: 2.080 | Reg loss: 0.024 | Tree loss: 2.080 | Accuracy: 0.283203 | 0.253 sec/iter
Epoch: 71 | Batch: 022 / 029 | Total loss: 2.061 | Reg loss: 0.024 | Tree loss: 2.061 | Accuracy: 0.271484 | 0.253 sec/iter
Epoch: 7

Epoch: 73 | Batch: 020 / 029 | Total loss: 2.089 | Reg loss: 0.024 | Tree loss: 2.089 | Accuracy: 0.261719 | 0.252 sec/iter
Epoch: 73 | Batch: 021 / 029 | Total loss: 2.072 | Reg loss: 0.024 | Tree loss: 2.072 | Accuracy: 0.250000 | 0.252 sec/iter
Epoch: 73 | Batch: 022 / 029 | Total loss: 2.098 | Reg loss: 0.024 | Tree loss: 2.098 | Accuracy: 0.250000 | 0.252 sec/iter
Epoch: 73 | Batch: 023 / 029 | Total loss: 2.031 | Reg loss: 0.024 | Tree loss: 2.031 | Accuracy: 0.283203 | 0.252 sec/iter
Epoch: 73 | Batch: 024 / 029 | Total loss: 2.023 | Reg loss: 0.024 | Tree loss: 2.023 | Accuracy: 0.310547 | 0.252 sec/iter
Epoch: 73 | Batch: 025 / 029 | Total loss: 2.033 | Reg loss: 0.024 | Tree loss: 2.033 | Accuracy: 0.302734 | 0.252 sec/iter
Epoch: 73 | Batch: 026 / 029 | Total loss: 2.092 | Reg loss: 0.024 | Tree loss: 2.092 | Accuracy: 0.279297 | 0.252 sec/iter
Epoch: 73 | Batch: 027 / 029 | Total loss: 2.060 | Reg loss: 0.024 | Tree loss: 2.060 | Accuracy: 0.273438 | 0.252 sec/iter
Epoch: 7

Epoch: 75 | Batch: 025 / 029 | Total loss: 2.032 | Reg loss: 0.024 | Tree loss: 2.032 | Accuracy: 0.277344 | 0.252 sec/iter
Epoch: 75 | Batch: 026 / 029 | Total loss: 2.061 | Reg loss: 0.024 | Tree loss: 2.061 | Accuracy: 0.285156 | 0.252 sec/iter
Epoch: 75 | Batch: 027 / 029 | Total loss: 2.085 | Reg loss: 0.024 | Tree loss: 2.085 | Accuracy: 0.281250 | 0.252 sec/iter
Epoch: 75 | Batch: 028 / 029 | Total loss: 1.847 | Reg loss: 0.024 | Tree loss: 1.847 | Accuracy: 0.384615 | 0.252 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 76 | Batch: 000 / 029 | Total loss: 2.033 | Reg loss: 0.024 | Tree loss: 2.033 | Accuracy: 0.283203 | 0.253 sec/iter
Epoch: 76 | Batch: 001 / 029 | Total loss: 2.074 | Reg loss: 0.024 | Tree loss: 2.074 | Accuracy: 0.275391 | 0.253 sec/iter
Epoch: 76 | Batch: 002

Epoch: 78 | Batch: 000 / 029 | Total loss: 2.117 | Reg loss: 0.024 | Tree loss: 2.117 | Accuracy: 0.230469 | 0.253 sec/iter
Epoch: 78 | Batch: 001 / 029 | Total loss: 2.047 | Reg loss: 0.024 | Tree loss: 2.047 | Accuracy: 0.279297 | 0.253 sec/iter
Epoch: 78 | Batch: 002 / 029 | Total loss: 2.101 | Reg loss: 0.024 | Tree loss: 2.101 | Accuracy: 0.265625 | 0.253 sec/iter
Epoch: 78 | Batch: 003 / 029 | Total loss: 2.067 | Reg loss: 0.024 | Tree loss: 2.067 | Accuracy: 0.279297 | 0.253 sec/iter
Epoch: 78 | Batch: 004 / 029 | Total loss: 2.055 | Reg loss: 0.024 | Tree loss: 2.055 | Accuracy: 0.257812 | 0.252 sec/iter
Epoch: 78 | Batch: 005 / 029 | Total loss: 2.057 | Reg loss: 0.024 | Tree loss: 2.057 | Accuracy: 0.302734 | 0.252 sec/iter
Epoch: 78 | Batch: 006 / 029 | Total loss: 2.089 | Reg loss: 0.024 | Tree loss: 2.089 | Accuracy: 0.259766 | 0.252 sec/iter
Epoch: 78 | Batch: 007 / 029 | Total loss: 2.029 | Reg loss: 0.024 | Tree loss: 2.029 | Accuracy: 0.283203 | 0.252 sec/iter
Epoch: 7

Epoch: 80 | Batch: 005 / 029 | Total loss: 2.004 | Reg loss: 0.024 | Tree loss: 2.004 | Accuracy: 0.285156 | 0.252 sec/iter
Epoch: 80 | Batch: 006 / 029 | Total loss: 1.995 | Reg loss: 0.024 | Tree loss: 1.995 | Accuracy: 0.302734 | 0.252 sec/iter
Epoch: 80 | Batch: 007 / 029 | Total loss: 2.057 | Reg loss: 0.024 | Tree loss: 2.057 | Accuracy: 0.277344 | 0.252 sec/iter
Epoch: 80 | Batch: 008 / 029 | Total loss: 2.044 | Reg loss: 0.024 | Tree loss: 2.044 | Accuracy: 0.285156 | 0.252 sec/iter
Epoch: 80 | Batch: 009 / 029 | Total loss: 2.074 | Reg loss: 0.024 | Tree loss: 2.074 | Accuracy: 0.275391 | 0.252 sec/iter
Epoch: 80 | Batch: 010 / 029 | Total loss: 2.065 | Reg loss: 0.024 | Tree loss: 2.065 | Accuracy: 0.255859 | 0.252 sec/iter
Epoch: 80 | Batch: 011 / 029 | Total loss: 2.072 | Reg loss: 0.024 | Tree loss: 2.072 | Accuracy: 0.279297 | 0.252 sec/iter
Epoch: 80 | Batch: 012 / 029 | Total loss: 2.071 | Reg loss: 0.024 | Tree loss: 2.071 | Accuracy: 0.265625 | 0.252 sec/iter
Epoch: 8

Epoch: 82 | Batch: 010 / 029 | Total loss: 2.068 | Reg loss: 0.024 | Tree loss: 2.068 | Accuracy: 0.277344 | 0.252 sec/iter
Epoch: 82 | Batch: 011 / 029 | Total loss: 2.006 | Reg loss: 0.024 | Tree loss: 2.006 | Accuracy: 0.292969 | 0.252 sec/iter
Epoch: 82 | Batch: 012 / 029 | Total loss: 2.068 | Reg loss: 0.024 | Tree loss: 2.068 | Accuracy: 0.257812 | 0.252 sec/iter
Epoch: 82 | Batch: 013 / 029 | Total loss: 2.073 | Reg loss: 0.024 | Tree loss: 2.073 | Accuracy: 0.259766 | 0.252 sec/iter
Epoch: 82 | Batch: 014 / 029 | Total loss: 2.057 | Reg loss: 0.024 | Tree loss: 2.057 | Accuracy: 0.287109 | 0.252 sec/iter
Epoch: 82 | Batch: 015 / 029 | Total loss: 2.058 | Reg loss: 0.024 | Tree loss: 2.058 | Accuracy: 0.291016 | 0.252 sec/iter
Epoch: 82 | Batch: 016 / 029 | Total loss: 2.007 | Reg loss: 0.024 | Tree loss: 2.007 | Accuracy: 0.310547 | 0.252 sec/iter
Epoch: 82 | Batch: 017 / 029 | Total loss: 2.039 | Reg loss: 0.024 | Tree loss: 2.039 | Accuracy: 0.271484 | 0.252 sec/iter
Epoch: 8

Epoch: 84 | Batch: 015 / 029 | Total loss: 1.996 | Reg loss: 0.024 | Tree loss: 1.996 | Accuracy: 0.312500 | 0.252 sec/iter
Epoch: 84 | Batch: 016 / 029 | Total loss: 2.068 | Reg loss: 0.024 | Tree loss: 2.068 | Accuracy: 0.251953 | 0.252 sec/iter
Epoch: 84 | Batch: 017 / 029 | Total loss: 2.048 | Reg loss: 0.024 | Tree loss: 2.048 | Accuracy: 0.291016 | 0.252 sec/iter
Epoch: 84 | Batch: 018 / 029 | Total loss: 2.049 | Reg loss: 0.024 | Tree loss: 2.049 | Accuracy: 0.267578 | 0.252 sec/iter
Epoch: 84 | Batch: 019 / 029 | Total loss: 2.041 | Reg loss: 0.024 | Tree loss: 2.041 | Accuracy: 0.287109 | 0.252 sec/iter
Epoch: 84 | Batch: 020 / 029 | Total loss: 2.074 | Reg loss: 0.024 | Tree loss: 2.074 | Accuracy: 0.248047 | 0.252 sec/iter
Epoch: 84 | Batch: 021 / 029 | Total loss: 2.073 | Reg loss: 0.024 | Tree loss: 2.073 | Accuracy: 0.291016 | 0.252 sec/iter
Epoch: 84 | Batch: 022 / 029 | Total loss: 2.071 | Reg loss: 0.024 | Tree loss: 2.071 | Accuracy: 0.259766 | 0.252 sec/iter
Epoch: 8

Epoch: 86 | Batch: 020 / 029 | Total loss: 2.070 | Reg loss: 0.024 | Tree loss: 2.070 | Accuracy: 0.271484 | 0.252 sec/iter
Epoch: 86 | Batch: 021 / 029 | Total loss: 2.056 | Reg loss: 0.024 | Tree loss: 2.056 | Accuracy: 0.273438 | 0.252 sec/iter
Epoch: 86 | Batch: 022 / 029 | Total loss: 2.049 | Reg loss: 0.024 | Tree loss: 2.049 | Accuracy: 0.265625 | 0.252 sec/iter
Epoch: 86 | Batch: 023 / 029 | Total loss: 2.071 | Reg loss: 0.024 | Tree loss: 2.071 | Accuracy: 0.267578 | 0.252 sec/iter
Epoch: 86 | Batch: 024 / 029 | Total loss: 2.031 | Reg loss: 0.024 | Tree loss: 2.031 | Accuracy: 0.291016 | 0.252 sec/iter
Epoch: 86 | Batch: 025 / 029 | Total loss: 2.023 | Reg loss: 0.024 | Tree loss: 2.023 | Accuracy: 0.271484 | 0.252 sec/iter
Epoch: 86 | Batch: 026 / 029 | Total loss: 2.057 | Reg loss: 0.024 | Tree loss: 2.057 | Accuracy: 0.269531 | 0.252 sec/iter
Epoch: 86 | Batch: 027 / 029 | Total loss: 2.073 | Reg loss: 0.024 | Tree loss: 2.073 | Accuracy: 0.251953 | 0.252 sec/iter
Epoch: 8

Epoch: 88 | Batch: 025 / 029 | Total loss: 2.035 | Reg loss: 0.024 | Tree loss: 2.035 | Accuracy: 0.285156 | 0.252 sec/iter
Epoch: 88 | Batch: 026 / 029 | Total loss: 2.052 | Reg loss: 0.024 | Tree loss: 2.052 | Accuracy: 0.267578 | 0.252 sec/iter
Epoch: 88 | Batch: 027 / 029 | Total loss: 2.047 | Reg loss: 0.024 | Tree loss: 2.047 | Accuracy: 0.267578 | 0.252 sec/iter
Epoch: 88 | Batch: 028 / 029 | Total loss: 2.234 | Reg loss: 0.024 | Tree loss: 2.234 | Accuracy: 0.153846 | 0.252 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 89 | Batch: 000 / 029 | Total loss: 2.031 | Reg loss: 0.024 | Tree loss: 2.031 | Accuracy: 0.253906 | 0.253 sec/iter
Epoch: 89 | Batch: 001 / 029 | Total loss: 2.066 | Reg loss: 0.024 | Tree loss: 2.066 | Accuracy: 0.257812 | 0.253 sec/iter
Epoch: 89 | Batch: 002

Epoch: 91 | Batch: 000 / 029 | Total loss: 2.023 | Reg loss: 0.024 | Tree loss: 2.023 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 91 | Batch: 001 / 029 | Total loss: 2.015 | Reg loss: 0.024 | Tree loss: 2.015 | Accuracy: 0.273438 | 0.253 sec/iter
Epoch: 91 | Batch: 002 / 029 | Total loss: 1.990 | Reg loss: 0.024 | Tree loss: 1.990 | Accuracy: 0.277344 | 0.253 sec/iter
Epoch: 91 | Batch: 003 / 029 | Total loss: 2.008 | Reg loss: 0.024 | Tree loss: 2.008 | Accuracy: 0.296875 | 0.253 sec/iter
Epoch: 91 | Batch: 004 / 029 | Total loss: 2.060 | Reg loss: 0.024 | Tree loss: 2.060 | Accuracy: 0.250000 | 0.253 sec/iter
Epoch: 91 | Batch: 005 / 029 | Total loss: 2.062 | Reg loss: 0.024 | Tree loss: 2.062 | Accuracy: 0.287109 | 0.253 sec/iter
Epoch: 91 | Batch: 006 / 029 | Total loss: 2.055 | Reg loss: 0.024 | Tree loss: 2.055 | Accuracy: 0.291016 | 0.253 sec/iter
Epoch: 91 | Batch: 007 / 029 | Total loss: 2.026 | Reg loss: 0.024 | Tree loss: 2.026 | Accuracy: 0.283203 | 0.253 sec/iter
Epoch: 9

Epoch: 93 | Batch: 005 / 029 | Total loss: 2.065 | Reg loss: 0.024 | Tree loss: 2.065 | Accuracy: 0.275391 | 0.253 sec/iter
Epoch: 93 | Batch: 006 / 029 | Total loss: 2.024 | Reg loss: 0.024 | Tree loss: 2.024 | Accuracy: 0.273438 | 0.253 sec/iter
Epoch: 93 | Batch: 007 / 029 | Total loss: 2.014 | Reg loss: 0.024 | Tree loss: 2.014 | Accuracy: 0.285156 | 0.253 sec/iter
Epoch: 93 | Batch: 008 / 029 | Total loss: 2.005 | Reg loss: 0.024 | Tree loss: 2.005 | Accuracy: 0.271484 | 0.253 sec/iter
Epoch: 93 | Batch: 009 / 029 | Total loss: 1.999 | Reg loss: 0.024 | Tree loss: 1.999 | Accuracy: 0.296875 | 0.253 sec/iter
Epoch: 93 | Batch: 010 / 029 | Total loss: 2.023 | Reg loss: 0.024 | Tree loss: 2.023 | Accuracy: 0.302734 | 0.253 sec/iter
Epoch: 93 | Batch: 011 / 029 | Total loss: 2.045 | Reg loss: 0.024 | Tree loss: 2.045 | Accuracy: 0.285156 | 0.253 sec/iter
Epoch: 93 | Batch: 012 / 029 | Total loss: 2.034 | Reg loss: 0.024 | Tree loss: 2.034 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 9

Epoch: 95 | Batch: 010 / 029 | Total loss: 2.074 | Reg loss: 0.023 | Tree loss: 2.074 | Accuracy: 0.242188 | 0.253 sec/iter
Epoch: 95 | Batch: 011 / 029 | Total loss: 2.027 | Reg loss: 0.023 | Tree loss: 2.027 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 95 | Batch: 012 / 029 | Total loss: 2.002 | Reg loss: 0.023 | Tree loss: 2.002 | Accuracy: 0.285156 | 0.253 sec/iter
Epoch: 95 | Batch: 013 / 029 | Total loss: 2.040 | Reg loss: 0.023 | Tree loss: 2.040 | Accuracy: 0.279297 | 0.253 sec/iter
Epoch: 95 | Batch: 014 / 029 | Total loss: 2.048 | Reg loss: 0.023 | Tree loss: 2.048 | Accuracy: 0.263672 | 0.253 sec/iter
Epoch: 95 | Batch: 015 / 029 | Total loss: 2.046 | Reg loss: 0.023 | Tree loss: 2.046 | Accuracy: 0.291016 | 0.253 sec/iter
Epoch: 95 | Batch: 016 / 029 | Total loss: 2.046 | Reg loss: 0.023 | Tree loss: 2.046 | Accuracy: 0.261719 | 0.253 sec/iter
Epoch: 95 | Batch: 017 / 029 | Total loss: 2.031 | Reg loss: 0.023 | Tree loss: 2.031 | Accuracy: 0.281250 | 0.253 sec/iter
Epoch: 9

Epoch: 97 | Batch: 015 / 029 | Total loss: 2.021 | Reg loss: 0.023 | Tree loss: 2.021 | Accuracy: 0.265625 | 0.253 sec/iter
Epoch: 97 | Batch: 016 / 029 | Total loss: 2.035 | Reg loss: 0.023 | Tree loss: 2.035 | Accuracy: 0.267578 | 0.253 sec/iter
Epoch: 97 | Batch: 017 / 029 | Total loss: 2.035 | Reg loss: 0.023 | Tree loss: 2.035 | Accuracy: 0.285156 | 0.252 sec/iter
Epoch: 97 | Batch: 018 / 029 | Total loss: 2.013 | Reg loss: 0.023 | Tree loss: 2.013 | Accuracy: 0.277344 | 0.252 sec/iter
Epoch: 97 | Batch: 019 / 029 | Total loss: 2.018 | Reg loss: 0.023 | Tree loss: 2.018 | Accuracy: 0.248047 | 0.252 sec/iter
Epoch: 97 | Batch: 020 / 029 | Total loss: 2.018 | Reg loss: 0.023 | Tree loss: 2.018 | Accuracy: 0.281250 | 0.252 sec/iter
Epoch: 97 | Batch: 021 / 029 | Total loss: 2.021 | Reg loss: 0.023 | Tree loss: 2.021 | Accuracy: 0.271484 | 0.252 sec/iter
Epoch: 97 | Batch: 022 / 029 | Total loss: 2.047 | Reg loss: 0.023 | Tree loss: 2.047 | Accuracy: 0.275391 | 0.252 sec/iter
Epoch: 9

Epoch: 99 | Batch: 020 / 029 | Total loss: 1.987 | Reg loss: 0.023 | Tree loss: 1.987 | Accuracy: 0.312500 | 0.252 sec/iter
Epoch: 99 | Batch: 021 / 029 | Total loss: 2.042 | Reg loss: 0.023 | Tree loss: 2.042 | Accuracy: 0.275391 | 0.252 sec/iter
Epoch: 99 | Batch: 022 / 029 | Total loss: 2.025 | Reg loss: 0.023 | Tree loss: 2.025 | Accuracy: 0.269531 | 0.252 sec/iter
Epoch: 99 | Batch: 023 / 029 | Total loss: 2.006 | Reg loss: 0.023 | Tree loss: 2.006 | Accuracy: 0.292969 | 0.252 sec/iter
Epoch: 99 | Batch: 024 / 029 | Total loss: 2.000 | Reg loss: 0.023 | Tree loss: 2.000 | Accuracy: 0.291016 | 0.252 sec/iter
Epoch: 99 | Batch: 025 / 029 | Total loss: 1.985 | Reg loss: 0.023 | Tree loss: 1.985 | Accuracy: 0.312500 | 0.252 sec/iter
Epoch: 99 | Batch: 026 / 029 | Total loss: 2.071 | Reg loss: 0.023 | Tree loss: 2.071 | Accuracy: 0.267578 | 0.252 sec/iter
Epoch: 99 | Batch: 027 / 029 | Total loss: 2.016 | Reg loss: 0.023 | Tree loss: 2.016 | Accuracy: 0.271484 | 0.252 sec/iter
Epoch: 9

In [39]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [40]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [41]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 7.426229508196721


# Extract Rules

# Accumulate samples in the leaves

In [42]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 122


In [43]:
method = 'greedy'

In [44]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [45]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")



  return np.log(1 / (1 - x))


14349
Average comprehensibility: 35.24590163934426
std comprehensibility: 5.664545292199045
var comprehensibility: 32.08707336737436
minimum comprehensibility: 16
maximum comprehensibility: 44
