In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.180998802185059 | KNN Loss: 6.227673530578613 | BCE Loss: 1.9533252716064453
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.19716739654541 | KNN Loss: 6.2274298667907715 | BCE Loss: 1.9697372913360596
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.151121139526367 | KNN Loss: 6.227524757385254 | BCE Loss: 1.9235968589782715
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.188850402832031 | KNN Loss: 6.2274041175842285 | BCE Loss: 1.9614464044570923
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.152887344360352 | KNN Loss: 6.227034568786621 | BCE Loss: 1.9258532524108887
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.164999008178711 | KNN Loss: 6.226687431335449 | BCE Loss: 1.9383113384246826
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.155362129211426 | KNN Loss: 6.226534366607666 | BCE Loss: 1.9288280010223389
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.13614273071289 | KNN Loss: 6.226588726043701 | BCE Loss: 1.90955

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.934737205505371 | KNN Loss: 4.850061893463135 | BCE Loss: 1.0846753120422363
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.811098098754883 | KNN Loss: 4.706234455108643 | BCE Loss: 1.1048638820648193
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.732649326324463 | KNN Loss: 4.618710994720459 | BCE Loss: 1.113938331604004
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.6056976318359375 | KNN Loss: 4.51455545425415 | BCE Loss: 1.0911420583724976
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.611255645751953 | KNN Loss: 4.4998064041137695 | BCE Loss: 1.1114490032196045
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.560354709625244 | KNN Loss: 4.41553258895874 | BCE Loss: 1.1448222398757935
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.4366607666015625 | KNN Loss: 4.3242902755737305 | BCE Loss: 1.112370491027832
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.374316692352295 | KNN Loss: 4.267876148223877 | BCE Loss

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.933305263519287 | KNN Loss: 3.8630523681640625 | BCE Loss: 1.070252776145935
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.93026876449585 | KNN Loss: 3.8572144508361816 | BCE Loss: 1.073054313659668
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.902609825134277 | KNN Loss: 3.868712902069092 | BCE Loss: 1.0338969230651855
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.924625396728516 | KNN Loss: 3.8884031772613525 | BCE Loss: 1.036222219467163
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.907341003417969 | KNN Loss: 3.8555479049682617 | BCE Loss: 1.0517933368682861
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.910784721374512 | KNN Loss: 3.857971429824829 | BCE Loss: 1.0528134107589722
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.968305587768555 | KNN Loss: 3.8832643032073975 | BCE Loss: 1.0850410461425781
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 5.030717372894287 | KNN Loss: 3.936378002166748 | BCE Lo

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.861330032348633 | KNN Loss: 3.83554744720459 | BCE Loss: 1.0257823467254639
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.830349445343018 | KNN Loss: 3.784989833831787 | BCE Loss: 1.04535973072052
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.912267208099365 | KNN Loss: 3.822322368621826 | BCE Loss: 1.089944839477539
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.833408355712891 | KNN Loss: 3.7764010429382324 | BCE Loss: 1.057007074356079
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.839536666870117 | KNN Loss: 3.7879533767700195 | BCE Loss: 1.0515832901000977
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.895272254943848 | KNN Loss: 3.8180091381073 | BCE Loss: 1.0772629976272583
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.921525001525879 | KNN Loss: 3.8660085201263428 | BCE Loss: 1.055516242980957
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.872917175292969 | KNN Loss: 3.8399155139923096 | BCE Loss: 1

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.803697109222412 | KNN Loss: 3.750314950942993 | BCE Loss: 1.0533822774887085
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.808242321014404 | KNN Loss: 3.760650634765625 | BCE Loss: 1.0475915670394897
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.770204544067383 | KNN Loss: 3.7297310829162598 | BCE Loss: 1.040473461151123
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.796553611755371 | KNN Loss: 3.752020835876465 | BCE Loss: 1.0445325374603271
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.824166774749756 | KNN Loss: 3.776170492172241 | BCE Loss: 1.0479962825775146
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.785524845123291 | KNN Loss: 3.7503693103790283 | BCE Loss: 1.0351555347442627
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.7650604248046875 | KNN Loss: 3.7463302612304688 | BCE Loss: 1.0187301635742188
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 4.760532379150391 | KNN Loss: 3.742922306060791 | BCE 

Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 4.719110012054443 | KNN Loss: 3.7152626514434814 | BCE Loss: 1.003847360610962
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.844790458679199 | KNN Loss: 3.7730729579925537 | BCE Loss: 1.0717175006866455
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.7858662605285645 | KNN Loss: 3.762510061264038 | BCE Loss: 1.0233561992645264
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.77788782119751 | KNN Loss: 3.7321507930755615 | BCE Loss: 1.0457369089126587
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.776598930358887 | KNN Loss: 3.746474266052246 | BCE Loss: 1.030124545097351
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.750310897827148 | KNN Loss: 3.7381091117858887 | BCE Loss: 1.0122017860412598
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.761354446411133 | KNN Loss: 3.742938995361328 | BCE Loss: 1.0184154510498047
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.790983200073242 | KNN Loss: 3.7687289714813232 | BCE Lo

Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 4.760372161865234 | KNN Loss: 3.718369722366333 | BCE Loss: 1.0420024394989014
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.731297492980957 | KNN Loss: 3.6951308250427246 | BCE Loss: 1.036166787147522
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.7566752433776855 | KNN Loss: 3.7244739532470703 | BCE Loss: 1.0322014093399048
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.7477216720581055 | KNN Loss: 3.7371108531951904 | BCE Loss: 1.0106109380722046
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.768034934997559 | KNN Loss: 3.7284927368164062 | BCE Loss: 1.039542317390442
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.728218078613281 | KNN Loss: 3.709228754043579 | BCE Loss: 1.0189893245697021
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.718601226806641 | KNN Loss: 3.6989688873291016 | BCE Loss: 1.019632339477539
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 4.711813449859619 | KNN Loss: 3.712442636489868 | BCE 

Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 4.7230544090271 | KNN Loss: 3.6981046199798584 | BCE Loss: 1.0249499082565308
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.7285261154174805 | KNN Loss: 3.6937403678894043 | BCE Loss: 1.0347857475280762
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.7381086349487305 | KNN Loss: 3.6927576065063477 | BCE Loss: 1.0453510284423828
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.750196933746338 | KNN Loss: 3.7426750659942627 | BCE Loss: 1.0075218677520752
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.773016452789307 | KNN Loss: 3.740229368209839 | BCE Loss: 1.0327869653701782
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.744694709777832 | KNN Loss: 3.7445616722106934 | BCE Loss: 1.0001327991485596
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.742612361907959 | KNN Loss: 3.7221782207489014 | BCE Loss: 1.0204341411590576
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.666397571563721 | KNN Loss: 3.6742031574249268 | 

Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 4.7436933517456055 | KNN Loss: 3.7249066829681396 | BCE Loss: 1.0187864303588867
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.7403717041015625 | KNN Loss: 3.7147579193115234 | BCE Loss: 1.025613784790039
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.728116989135742 | KNN Loss: 3.6967148780822754 | BCE Loss: 1.0314019918441772
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.7343878746032715 | KNN Loss: 3.707601308822632 | BCE Loss: 1.02678644657135
Epoch    87: reducing learning rate of group 0 to 3.5000e-03.
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.703076362609863 | KNN Loss: 3.6937506198883057 | BCE Loss: 1.009325623512268
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.7834577560424805 | KNN Loss: 3.723421812057495 | BCE Loss: 1.0600359439849854
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.7734904289245605 | KNN Loss: 3.7249653339385986 | BCE Loss: 1.048525094985962
Epoch 87 / 500 | iteration 15 / 30 | To

Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 4.7416276931762695 | KNN Loss: 3.710026502609253 | BCE Loss: 1.0316014289855957
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.699658393859863 | KNN Loss: 3.70951247215271 | BCE Loss: 0.9901460409164429
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.779877662658691 | KNN Loss: 3.74356746673584 | BCE Loss: 1.0363104343414307
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.720315933227539 | KNN Loss: 3.6791651248931885 | BCE Loss: 1.041150689125061
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.735020160675049 | KNN Loss: 3.7110471725463867 | BCE Loss: 1.0239731073379517
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.723799228668213 | KNN Loss: 3.68276047706604 | BCE Loss: 1.0410388708114624
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.699944972991943 | KNN Loss: 3.713561773300171 | BCE Loss: 0.986383318901062
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.738221168518066 | KNN Loss: 3.7093424797058105 | BCE Loss: 

Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 4.726795673370361 | KNN Loss: 3.7179653644561768 | BCE Loss: 1.0088303089141846
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.707529544830322 | KNN Loss: 3.6718239784240723 | BCE Loss: 1.03570556640625
Epoch   108: reducing learning rate of group 0 to 2.4500e-03.
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.67328405380249 | KNN Loss: 3.662954807281494 | BCE Loss: 1.010329246520996
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.7479166984558105 | KNN Loss: 3.7154738903045654 | BCE Loss: 1.0324428081512451
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.743028163909912 | KNN Loss: 3.768920421600342 | BCE Loss: 0.9741077423095703
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.756906986236572 | KNN Loss: 3.720902442932129 | BCE Loss: 1.0360045433044434
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.758960723876953 | KNN Loss: 3.736894130706787 | BCE Loss: 1.022066354751587
Epoch 108 / 500 | iteration 25 / 30 | 

Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 4.735187530517578 | KNN Loss: 3.6942498683929443 | BCE Loss: 1.040937900543213
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.779535293579102 | KNN Loss: 3.7382898330688477 | BCE Loss: 1.041245698928833
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.699033737182617 | KNN Loss: 3.687662363052368 | BCE Loss: 1.0113712549209595
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.693668365478516 | KNN Loss: 3.6779465675354004 | BCE Loss: 1.0157215595245361
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.739509582519531 | KNN Loss: 3.710425615310669 | BCE Loss: 1.0290839672088623
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.727892875671387 | KNN Loss: 3.722181797027588 | BCE Loss: 1.0057108402252197
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.707941055297852 | KNN Loss: 3.680720567703247 | BCE Loss: 1.0272204875946045
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.722529888153076 | KNN Loss: 3.7007174491882324

Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 4.675452709197998 | KNN Loss: 3.662527084350586 | BCE Loss: 1.012925624847412
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.748711109161377 | KNN Loss: 3.710463285446167 | BCE Loss: 1.03824782371521
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.781057357788086 | KNN Loss: 3.7316596508026123 | BCE Loss: 1.0493974685668945
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.707805633544922 | KNN Loss: 3.674666404724121 | BCE Loss: 1.0331394672393799
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.733466148376465 | KNN Loss: 3.7168054580688477 | BCE Loss: 1.0166606903076172
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.77614688873291 | KNN Loss: 3.7193779945373535 | BCE Loss: 1.0567686557769775
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.753278732299805 | KNN Loss: 3.6971356868743896 | BCE Loss: 1.056143045425415
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.717169284820557 | KNN Loss: 3.6845569610595703 | B

Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 4.736207485198975 | KNN Loss: 3.7151901721954346 | BCE Loss: 1.02101731300354
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.728896141052246 | KNN Loss: 3.708296060562134 | BCE Loss: 1.0206003189086914
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.690422058105469 | KNN Loss: 3.67264723777771 | BCE Loss: 1.0177748203277588
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.715219497680664 | KNN Loss: 3.70550537109375 | BCE Loss: 1.009713888168335
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.674323081970215 | KNN Loss: 3.682277202606201 | BCE Loss: 0.9920459389686584
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.7396674156188965 | KNN Loss: 3.706331491470337 | BCE Loss: 1.0333359241485596
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.705522060394287 | KNN Loss: 3.671377182006836 | BCE Loss: 1.0341448783874512
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.704563140869141 | KNN Loss: 3.6625020503997803 | B

Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 4.725688457489014 | KNN Loss: 3.680338144302368 | BCE Loss: 1.0453503131866455
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.719351768493652 | KNN Loss: 3.665105104446411 | BCE Loss: 1.0542466640472412
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.748157978057861 | KNN Loss: 3.710408926010132 | BCE Loss: 1.0377490520477295
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.717916488647461 | KNN Loss: 3.6924781799316406 | BCE Loss: 1.0254383087158203
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.733489513397217 | KNN Loss: 3.7075681686401367 | BCE Loss: 1.0259212255477905
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.739500045776367 | KNN Loss: 3.6978673934936523 | BCE Loss: 1.0416326522827148
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.742783546447754 | KNN Loss: 3.6902246475219727 | BCE Loss: 1.0525590181350708
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.714107036590576 | KNN Loss: 3.6824774742126

Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 4.769018173217773 | KNN Loss: 3.7405524253845215 | BCE Loss: 1.028465986251831
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.781330585479736 | KNN Loss: 3.7522664070129395 | BCE Loss: 1.0290642976760864
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.72381591796875 | KNN Loss: 3.6829776763916016 | BCE Loss: 1.040838360786438
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.712642192840576 | KNN Loss: 3.7055134773254395 | BCE Loss: 1.0071285963058472
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.76686429977417 | KNN Loss: 3.756225824356079 | BCE Loss: 1.0106384754180908
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.7425150871276855 | KNN Loss: 3.6990272998809814 | BCE Loss: 1.0434879064559937
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.705746650695801 | KNN Loss: 3.657306671142578 | BCE Loss: 1.0484399795532227
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 4.700399875640869 | KNN Loss: 3.6840932369232178 

Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 4.714395999908447 | KNN Loss: 3.689777135848999 | BCE Loss: 1.0246188640594482
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 4.806906223297119 | KNN Loss: 3.7575912475585938 | BCE Loss: 1.0493149757385254
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.789076328277588 | KNN Loss: 3.7383737564086914 | BCE Loss: 1.0507025718688965
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.670455455780029 | KNN Loss: 3.677607297897339 | BCE Loss: 0.9928482174873352
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.700645446777344 | KNN Loss: 3.672260284423828 | BCE Loss: 1.0283854007720947
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.76609992980957 | KNN Loss: 3.7262001037597656 | BCE Loss: 1.0398995876312256
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.716073036193848 | KNN Loss: 3.688462495803833 | BCE Loss: 1.0276105403900146
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.71012020111084 | KNN Loss: 3.688098907470703 

Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 4.739169120788574 | KNN Loss: 3.6828365325927734 | BCE Loss: 1.0563325881958008
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 4.734935760498047 | KNN Loss: 3.699075698852539 | BCE Loss: 1.0358598232269287
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 4.703526020050049 | KNN Loss: 3.672065019607544 | BCE Loss: 1.0314610004425049
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.665281295776367 | KNN Loss: 3.646109104156494 | BCE Loss: 1.019171953201294
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.724977970123291 | KNN Loss: 3.6787548065185547 | BCE Loss: 1.0462230443954468
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.695399284362793 | KNN Loss: 3.685948371887207 | BCE Loss: 1.0094510316848755
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.718674659729004 | KNN Loss: 3.6991236209869385 | BCE Loss: 1.0195512771606445
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.743906021118164 | KNN Loss: 3.715883493423462 |

Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 4.7211456298828125 | KNN Loss: 3.68992018699646 | BCE Loss: 1.031225562095642
Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 4.700775146484375 | KNN Loss: 3.655952215194702 | BCE Loss: 1.0448229312896729
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 4.671133041381836 | KNN Loss: 3.652195453643799 | BCE Loss: 1.018937587738037
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.70466947555542 | KNN Loss: 3.697723150253296 | BCE Loss: 1.006946325302124
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.7451171875 | KNN Loss: 3.6882615089416504 | BCE Loss: 1.0568557977676392
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.671378135681152 | KNN Loss: 3.6710267066955566 | BCE Loss: 1.0003514289855957
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.720921993255615 | KNN Loss: 3.6986963748931885 | BCE Loss: 1.0222254991531372
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.698876857757568 | KNN Loss: 3.6827828884124756 | BCE 

Epoch 203 / 500 | iteration 10 / 30 | Total Loss: 4.716550827026367 | KNN Loss: 3.6833183765411377 | BCE Loss: 1.033232569694519
Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 4.729662895202637 | KNN Loss: 3.730487108230591 | BCE Loss: 0.9991757869720459
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 4.779178619384766 | KNN Loss: 3.722198486328125 | BCE Loss: 1.0569798946380615
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.687356948852539 | KNN Loss: 3.6964638233184814 | BCE Loss: 0.9908931255340576
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.743438720703125 | KNN Loss: 3.713141679763794 | BCE Loss: 1.030297040939331
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.700294494628906 | KNN Loss: 3.6910719871520996 | BCE Loss: 1.0092225074768066
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.732605457305908 | KNN Loss: 3.710179567337036 | BCE Loss: 1.022425889968872
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.719783782958984 | KNN Loss: 3.7004141807556152 

Epoch 214 / 500 | iteration 0 / 30 | Total Loss: 4.733996391296387 | KNN Loss: 3.722764730453491 | BCE Loss: 1.0112314224243164
Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 4.68924617767334 | KNN Loss: 3.662137508392334 | BCE Loss: 1.0271084308624268
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 4.674508094787598 | KNN Loss: 3.669020891189575 | BCE Loss: 1.005487084388733
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.706817626953125 | KNN Loss: 3.6736128330230713 | BCE Loss: 1.0332046747207642
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.731132507324219 | KNN Loss: 3.7301411628723145 | BCE Loss: 1.0009913444519043
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.696763515472412 | KNN Loss: 3.6721482276916504 | BCE Loss: 1.0246152877807617
Epoch   215: reducing learning rate of group 0 to 4.1177e-04.
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.731341361999512 | KNN Loss: 3.7105159759521484 | BCE Loss: 1.0208253860473633
Epoch 215 / 500 | iteration 5 / 30 |

Epoch 224 / 500 | iteration 15 / 30 | Total Loss: 4.707334518432617 | KNN Loss: 3.6952459812164307 | BCE Loss: 1.0120882987976074
Epoch 224 / 500 | iteration 20 / 30 | Total Loss: 4.7220025062561035 | KNN Loss: 3.6695873737335205 | BCE Loss: 1.0524152517318726
Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 4.783000469207764 | KNN Loss: 3.742971658706665 | BCE Loss: 1.040028691291809
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 4.696141719818115 | KNN Loss: 3.708183765411377 | BCE Loss: 0.9879578351974487
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.677399158477783 | KNN Loss: 3.652381420135498 | BCE Loss: 1.0250177383422852
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.723726749420166 | KNN Loss: 3.706437110900879 | BCE Loss: 1.017289638519287
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.670502662658691 | KNN Loss: 3.6692519187927246 | BCE Loss: 1.001250982284546
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.725288391113281 | KNN Loss: 3.6887595653533936

Epoch 235 / 500 | iteration 5 / 30 | Total Loss: 4.6887922286987305 | KNN Loss: 3.6575610637664795 | BCE Loss: 1.03123140335083
Epoch 235 / 500 | iteration 10 / 30 | Total Loss: 4.759778022766113 | KNN Loss: 3.72322154045105 | BCE Loss: 1.0365564823150635
Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 4.737240791320801 | KNN Loss: 3.7101662158966064 | BCE Loss: 1.0270744562149048
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.725982189178467 | KNN Loss: 3.7253103256225586 | BCE Loss: 1.0006718635559082
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.702294826507568 | KNN Loss: 3.6671929359436035 | BCE Loss: 1.0351020097732544
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.773587226867676 | KNN Loss: 3.7160604000091553 | BCE Loss: 1.0575265884399414
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.66157341003418 | KNN Loss: 3.6657235622406006 | BCE Loss: 0.9958497285842896
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.692432403564453 | KNN Loss: 3.671902179718017

Epoch 245 / 500 | iteration 25 / 30 | Total Loss: 4.711059093475342 | KNN Loss: 3.6831564903259277 | BCE Loss: 1.0279024839401245
Epoch 246 / 500 | iteration 0 / 30 | Total Loss: 4.713857173919678 | KNN Loss: 3.69326114654541 | BCE Loss: 1.0205961465835571
Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 4.673521995544434 | KNN Loss: 3.6751444339752197 | BCE Loss: 0.9983777403831482
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.6681318283081055 | KNN Loss: 3.6483070850372314 | BCE Loss: 1.0198249816894531
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.672610282897949 | KNN Loss: 3.6553807258605957 | BCE Loss: 1.0172297954559326
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.700257301330566 | KNN Loss: 3.6917738914489746 | BCE Loss: 1.0084832906723022
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.701848983764648 | KNN Loss: 3.659717082977295 | BCE Loss: 1.0421316623687744
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.737268924713135 | KNN Loss: 3.7271587848663

Epoch 256 / 500 | iteration 15 / 30 | Total Loss: 4.703512191772461 | KNN Loss: 3.698789119720459 | BCE Loss: 1.0047228336334229
Epoch 256 / 500 | iteration 20 / 30 | Total Loss: 4.7712507247924805 | KNN Loss: 3.701223373413086 | BCE Loss: 1.0700273513793945
Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 4.738153457641602 | KNN Loss: 3.6909780502319336 | BCE Loss: 1.047175407409668
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.703385353088379 | KNN Loss: 3.657529830932617 | BCE Loss: 1.0458554029464722
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.6957902908325195 | KNN Loss: 3.6990087032318115 | BCE Loss: 0.9967814683914185
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.7409515380859375 | KNN Loss: 3.6940677165985107 | BCE Loss: 1.0468838214874268
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.695403575897217 | KNN Loss: 3.6973190307617188 | BCE Loss: 0.998084545135498
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.68182373046875 | KNN Loss: 3.6619474887847

Epoch 267 / 500 | iteration 0 / 30 | Total Loss: 4.71190881729126 | KNN Loss: 3.6798558235168457 | BCE Loss: 1.032052993774414
Epoch 267 / 500 | iteration 5 / 30 | Total Loss: 4.698554992675781 | KNN Loss: 3.6765825748443604 | BCE Loss: 1.02197265625
Epoch 267 / 500 | iteration 10 / 30 | Total Loss: 4.718596458435059 | KNN Loss: 3.676518201828003 | BCE Loss: 1.0420780181884766
Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 4.708215236663818 | KNN Loss: 3.6877501010894775 | BCE Loss: 1.0204652547836304
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.637581825256348 | KNN Loss: 3.64620304107666 | BCE Loss: 0.9913787245750427
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.704524517059326 | KNN Loss: 3.694326162338257 | BCE Loss: 1.0101983547210693
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.708281993865967 | KNN Loss: 3.7090396881103516 | BCE Loss: 0.9992421269416809
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.66422700881958 | KNN Loss: 3.6736865043640137 | BCE L

Epoch 277 / 500 | iteration 20 / 30 | Total Loss: 4.755423545837402 | KNN Loss: 3.701122760772705 | BCE Loss: 1.0543010234832764
Epoch 277 / 500 | iteration 25 / 30 | Total Loss: 4.687032699584961 | KNN Loss: 3.692953586578369 | BCE Loss: 0.9940793514251709
Epoch 278 / 500 | iteration 0 / 30 | Total Loss: 4.7337141036987305 | KNN Loss: 3.730053186416626 | BCE Loss: 1.0036606788635254
Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 4.704862117767334 | KNN Loss: 3.6801252365112305 | BCE Loss: 1.0247368812561035
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.7160868644714355 | KNN Loss: 3.6975960731506348 | BCE Loss: 1.0184907913208008
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.678962707519531 | KNN Loss: 3.666198968887329 | BCE Loss: 1.012763500213623
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.760176658630371 | KNN Loss: 3.705671787261963 | BCE Loss: 1.054504632949829
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.719836235046387 | KNN Loss: 3.676417112350464

Epoch 288 / 500 | iteration 10 / 30 | Total Loss: 4.728771209716797 | KNN Loss: 3.706209421157837 | BCE Loss: 1.0225619077682495
Epoch 288 / 500 | iteration 15 / 30 | Total Loss: 4.711973190307617 | KNN Loss: 3.70750093460083 | BCE Loss: 1.0044723749160767
Epoch 288 / 500 | iteration 20 / 30 | Total Loss: 4.727872848510742 | KNN Loss: 3.6813645362854004 | BCE Loss: 1.046508550643921
Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 4.736748695373535 | KNN Loss: 3.698896884918213 | BCE Loss: 1.0378520488739014
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.6785383224487305 | KNN Loss: 3.659581422805786 | BCE Loss: 1.0189566612243652
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.711480140686035 | KNN Loss: 3.673100709915161 | BCE Loss: 1.038379192352295
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.695392608642578 | KNN Loss: 3.6869466304779053 | BCE Loss: 1.008446216583252
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.718842506408691 | KNN Loss: 3.691472053527832 | 

Epoch 299 / 500 | iteration 0 / 30 | Total Loss: 4.708618640899658 | KNN Loss: 3.7068488597869873 | BCE Loss: 1.0017696619033813
Epoch 299 / 500 | iteration 5 / 30 | Total Loss: 4.709845066070557 | KNN Loss: 3.7064685821533203 | BCE Loss: 1.0033764839172363
Epoch 299 / 500 | iteration 10 / 30 | Total Loss: 4.74594783782959 | KNN Loss: 3.708810567855835 | BCE Loss: 1.0371373891830444
Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 4.681051254272461 | KNN Loss: 3.6460604667663574 | BCE Loss: 1.0349905490875244
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.736692428588867 | KNN Loss: 3.7289984226226807 | BCE Loss: 1.0076942443847656
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.689509868621826 | KNN Loss: 3.6938366889953613 | BCE Loss: 0.9956731200218201
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.740659236907959 | KNN Loss: 3.6989645957946777 | BCE Loss: 1.0416946411132812
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.747907638549805 | KNN Loss: 3.70043015480041

Epoch 309 / 500 | iteration 15 / 30 | Total Loss: 4.73975133895874 | KNN Loss: 3.715670108795166 | BCE Loss: 1.0240812301635742
Epoch 309 / 500 | iteration 20 / 30 | Total Loss: 4.695099830627441 | KNN Loss: 3.692857027053833 | BCE Loss: 1.0022426843643188
Epoch 309 / 500 | iteration 25 / 30 | Total Loss: 4.7131476402282715 | KNN Loss: 3.6822092533111572 | BCE Loss: 1.0309383869171143
Epoch 310 / 500 | iteration 0 / 30 | Total Loss: 4.69252872467041 | KNN Loss: 3.669090747833252 | BCE Loss: 1.0234382152557373
Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 4.71948766708374 | KNN Loss: 3.657866954803467 | BCE Loss: 1.061620831489563
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.772586822509766 | KNN Loss: 3.7172834873199463 | BCE Loss: 1.0553035736083984
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.681685447692871 | KNN Loss: 3.6680397987365723 | BCE Loss: 1.013645887374878
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.675919055938721 | KNN Loss: 3.6670525074005127 |

Epoch 320 / 500 | iteration 5 / 30 | Total Loss: 4.730607986450195 | KNN Loss: 3.6972365379333496 | BCE Loss: 1.0333714485168457
Epoch 320 / 500 | iteration 10 / 30 | Total Loss: 4.660962104797363 | KNN Loss: 3.6573567390441895 | BCE Loss: 1.0036052465438843
Epoch 320 / 500 | iteration 15 / 30 | Total Loss: 4.693453788757324 | KNN Loss: 3.6762590408325195 | BCE Loss: 1.0171947479248047
Epoch 320 / 500 | iteration 20 / 30 | Total Loss: 4.739112854003906 | KNN Loss: 3.7203009128570557 | BCE Loss: 1.0188121795654297
Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 4.717229843139648 | KNN Loss: 3.6764776706695557 | BCE Loss: 1.0407520532608032
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.728982925415039 | KNN Loss: 3.686455249786377 | BCE Loss: 1.042527437210083
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.708450794219971 | KNN Loss: 3.6810147762298584 | BCE Loss: 1.0274360179901123
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.683122634887695 | KNN Loss: 3.6621155738830

Epoch 330 / 500 | iteration 20 / 30 | Total Loss: 4.647955417633057 | KNN Loss: 3.661982536315918 | BCE Loss: 0.985973060131073
Epoch 330 / 500 | iteration 25 / 30 | Total Loss: 4.687281608581543 | KNN Loss: 3.673563241958618 | BCE Loss: 1.0137184858322144
Epoch 331 / 500 | iteration 0 / 30 | Total Loss: 4.708765983581543 | KNN Loss: 3.6791622638702393 | BCE Loss: 1.0296039581298828
Epoch 331 / 500 | iteration 5 / 30 | Total Loss: 4.716582775115967 | KNN Loss: 3.6822543144226074 | BCE Loss: 1.034328579902649
Epoch 331 / 500 | iteration 10 / 30 | Total Loss: 4.731959342956543 | KNN Loss: 3.7215702533721924 | BCE Loss: 1.0103893280029297
Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.6975274085998535 | KNN Loss: 3.707078218460083 | BCE Loss: 0.9904493093490601
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.734625339508057 | KNN Loss: 3.6975345611572266 | BCE Loss: 1.03709077835083
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.687020301818848 | KNN Loss: 3.676769495010376 

Epoch 341 / 500 | iteration 10 / 30 | Total Loss: 4.761072158813477 | KNN Loss: 3.7229440212249756 | BCE Loss: 1.0381282567977905
Epoch 341 / 500 | iteration 15 / 30 | Total Loss: 4.707817077636719 | KNN Loss: 3.6950271129608154 | BCE Loss: 1.0127899646759033
Epoch 341 / 500 | iteration 20 / 30 | Total Loss: 4.721379280090332 | KNN Loss: 3.6951565742492676 | BCE Loss: 1.0262227058410645
Epoch 341 / 500 | iteration 25 / 30 | Total Loss: 4.764959335327148 | KNN Loss: 3.731513261795044 | BCE Loss: 1.033445954322815
Epoch 342 / 500 | iteration 0 / 30 | Total Loss: 4.691410064697266 | KNN Loss: 3.687889575958252 | BCE Loss: 1.0035207271575928
Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.759422302246094 | KNN Loss: 3.7196271419525146 | BCE Loss: 1.039794921875
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.680342674255371 | KNN Loss: 3.6862599849700928 | BCE Loss: 0.9940825700759888
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.707197189331055 | KNN Loss: 3.6826043128967285 

Epoch 352 / 500 | iteration 0 / 30 | Total Loss: 4.727409362792969 | KNN Loss: 3.696769952774048 | BCE Loss: 1.030639410018921
Epoch 352 / 500 | iteration 5 / 30 | Total Loss: 4.745514869689941 | KNN Loss: 3.717909812927246 | BCE Loss: 1.0276048183441162
Epoch 352 / 500 | iteration 10 / 30 | Total Loss: 4.725104808807373 | KNN Loss: 3.6868364810943604 | BCE Loss: 1.0382683277130127
Epoch 352 / 500 | iteration 15 / 30 | Total Loss: 4.6664652824401855 | KNN Loss: 3.6726791858673096 | BCE Loss: 0.9937859773635864
Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.759120464324951 | KNN Loss: 3.714693069458008 | BCE Loss: 1.0444272756576538
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.744544982910156 | KNN Loss: 3.7120320796966553 | BCE Loss: 1.03251314163208
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.697330474853516 | KNN Loss: 3.6638405323028564 | BCE Loss: 1.0334899425506592
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.731172561645508 | KNN Loss: 3.694321393966675 |

Epoch 362 / 500 | iteration 20 / 30 | Total Loss: 4.7355546951293945 | KNN Loss: 3.690927743911743 | BCE Loss: 1.0446271896362305
Epoch 362 / 500 | iteration 25 / 30 | Total Loss: 4.736426830291748 | KNN Loss: 3.6933820247650146 | BCE Loss: 1.0430446863174438
Epoch 363 / 500 | iteration 0 / 30 | Total Loss: 4.707645893096924 | KNN Loss: 3.70054292678833 | BCE Loss: 1.0071029663085938
Epoch 363 / 500 | iteration 5 / 30 | Total Loss: 4.683151721954346 | KNN Loss: 3.6527087688446045 | BCE Loss: 1.0304430723190308
Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.737252235412598 | KNN Loss: 3.7109155654907227 | BCE Loss: 1.026336431503296
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.796880722045898 | KNN Loss: 3.7512059211730957 | BCE Loss: 1.0456750392913818
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.676353454589844 | KNN Loss: 3.6539981365203857 | BCE Loss: 1.022355318069458
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.700132369995117 | KNN Loss: 3.68155789375305

Epoch 373 / 500 | iteration 10 / 30 | Total Loss: 4.703348159790039 | KNN Loss: 3.6668667793273926 | BCE Loss: 1.0364813804626465
Epoch 373 / 500 | iteration 15 / 30 | Total Loss: 4.722272872924805 | KNN Loss: 3.7025625705718994 | BCE Loss: 1.0197105407714844
Epoch 373 / 500 | iteration 20 / 30 | Total Loss: 4.696122646331787 | KNN Loss: 3.6575300693511963 | BCE Loss: 1.0385925769805908
Epoch 373 / 500 | iteration 25 / 30 | Total Loss: 4.72933292388916 | KNN Loss: 3.7062461376190186 | BCE Loss: 1.0230870246887207
Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.753696918487549 | KNN Loss: 3.732390880584717 | BCE Loss: 1.021306037902832
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.737729072570801 | KNN Loss: 3.699394941329956 | BCE Loss: 1.0383341312408447
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.7061872482299805 | KNN Loss: 3.680590867996216 | BCE Loss: 1.0255961418151855
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.765904903411865 | KNN Loss: 3.74319982528686

Epoch 384 / 500 | iteration 0 / 30 | Total Loss: 4.689777851104736 | KNN Loss: 3.66196608543396 | BCE Loss: 1.0278117656707764
Epoch 384 / 500 | iteration 5 / 30 | Total Loss: 4.719875335693359 | KNN Loss: 3.691113233566284 | BCE Loss: 1.0287621021270752
Epoch 384 / 500 | iteration 10 / 30 | Total Loss: 4.70038366317749 | KNN Loss: 3.6917409896850586 | BCE Loss: 1.0086426734924316
Epoch 384 / 500 | iteration 15 / 30 | Total Loss: 4.702990531921387 | KNN Loss: 3.702233076095581 | BCE Loss: 1.0007572174072266
Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.702019691467285 | KNN Loss: 3.698847532272339 | BCE Loss: 1.0031723976135254
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.731573104858398 | KNN Loss: 3.7131006717681885 | BCE Loss: 1.01847243309021
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.720511436462402 | KNN Loss: 3.703808307647705 | BCE Loss: 1.0167031288146973
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.687342643737793 | KNN Loss: 3.674984931945801 | BCE

Epoch 394 / 500 | iteration 20 / 30 | Total Loss: 4.687875747680664 | KNN Loss: 3.679211378097534 | BCE Loss: 1.008664608001709
Epoch 394 / 500 | iteration 25 / 30 | Total Loss: 4.673733234405518 | KNN Loss: 3.669077157974243 | BCE Loss: 1.004656195640564
Epoch 395 / 500 | iteration 0 / 30 | Total Loss: 4.79136323928833 | KNN Loss: 3.7708029747009277 | BCE Loss: 1.0205602645874023
Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.729035377502441 | KNN Loss: 3.6943230628967285 | BCE Loss: 1.034712553024292
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.778444290161133 | KNN Loss: 3.7483811378479004 | BCE Loss: 1.0300633907318115
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.731167793273926 | KNN Loss: 3.691953659057617 | BCE Loss: 1.039214015007019
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.701098918914795 | KNN Loss: 3.6822140216827393 | BCE Loss: 1.0188848972320557
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.76290225982666 | KNN Loss: 3.714416027069092 | B

Epoch 405 / 500 | iteration 10 / 30 | Total Loss: 4.810971260070801 | KNN Loss: 3.74129056930542 | BCE Loss: 1.06968092918396
Epoch 405 / 500 | iteration 15 / 30 | Total Loss: 4.755423545837402 | KNN Loss: 3.7108614444732666 | BCE Loss: 1.0445623397827148
Epoch 405 / 500 | iteration 20 / 30 | Total Loss: 4.628955841064453 | KNN Loss: 3.644928455352783 | BCE Loss: 0.984027624130249
Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.719236373901367 | KNN Loss: 3.690999746322632 | BCE Loss: 1.0282363891601562
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.7324724197387695 | KNN Loss: 3.6964871883392334 | BCE Loss: 1.0359851121902466
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.7882490158081055 | KNN Loss: 3.7428901195526123 | BCE Loss: 1.0453590154647827
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.724094867706299 | KNN Loss: 3.725226402282715 | BCE Loss: 0.9988684058189392
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.711143493652344 | KNN Loss: 3.681917667388916 

Epoch 416 / 500 | iteration 0 / 30 | Total Loss: 4.684908866882324 | KNN Loss: 3.662559747695923 | BCE Loss: 1.0223493576049805
Epoch 416 / 500 | iteration 5 / 30 | Total Loss: 4.7171783447265625 | KNN Loss: 3.684363842010498 | BCE Loss: 1.032814621925354
Epoch 416 / 500 | iteration 10 / 30 | Total Loss: 4.682908535003662 | KNN Loss: 3.674025297164917 | BCE Loss: 1.0088833570480347
Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.692817211151123 | KNN Loss: 3.6803700923919678 | BCE Loss: 1.0124471187591553
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.732419013977051 | KNN Loss: 3.6941699981689453 | BCE Loss: 1.038248896598816
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.702141284942627 | KNN Loss: 3.6764721870422363 | BCE Loss: 1.0256690979003906
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.671271324157715 | KNN Loss: 3.6699178218841553 | BCE Loss: 1.0013535022735596
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.748252868652344 | KNN Loss: 3.741379499435425 

Epoch 426 / 500 | iteration 20 / 30 | Total Loss: 4.705377101898193 | KNN Loss: 3.6603331565856934 | BCE Loss: 1.0450438261032104
Epoch 426 / 500 | iteration 25 / 30 | Total Loss: 4.7135844230651855 | KNN Loss: 3.716151475906372 | BCE Loss: 0.9974328875541687
Epoch 427 / 500 | iteration 0 / 30 | Total Loss: 4.737201690673828 | KNN Loss: 3.733917713165283 | BCE Loss: 1.0032837390899658
Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.707576751708984 | KNN Loss: 3.685514211654663 | BCE Loss: 1.0220624208450317
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.718943119049072 | KNN Loss: 3.7135863304138184 | BCE Loss: 1.0053566694259644
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.658566474914551 | KNN Loss: 3.649756669998169 | BCE Loss: 1.0088098049163818
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.688265800476074 | KNN Loss: 3.68343186378479 | BCE Loss: 1.0048339366912842
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.697124481201172 | KNN Loss: 3.706594467163086

Epoch 437 / 500 | iteration 10 / 30 | Total Loss: 4.740231513977051 | KNN Loss: 3.7092111110687256 | BCE Loss: 1.0310204029083252
Epoch 437 / 500 | iteration 15 / 30 | Total Loss: 4.740328788757324 | KNN Loss: 3.7277040481567383 | BCE Loss: 1.0126245021820068
Epoch 437 / 500 | iteration 20 / 30 | Total Loss: 4.723419189453125 | KNN Loss: 3.7056901454925537 | BCE Loss: 1.0177291631698608
Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.709522247314453 | KNN Loss: 3.6893835067749023 | BCE Loss: 1.0201387405395508
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.759857177734375 | KNN Loss: 3.7159714698791504 | BCE Loss: 1.0438858270645142
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.73109245300293 | KNN Loss: 3.7003979682922363 | BCE Loss: 1.0306947231292725
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.711421966552734 | KNN Loss: 3.6733806133270264 | BCE Loss: 1.038041591644287
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.681744575500488 | KNN Loss: 3.687158107757

Epoch 448 / 500 | iteration 0 / 30 | Total Loss: 4.736355781555176 | KNN Loss: 3.720778465270996 | BCE Loss: 1.0155770778656006
Epoch 448 / 500 | iteration 5 / 30 | Total Loss: 4.739109039306641 | KNN Loss: 3.7167129516601562 | BCE Loss: 1.0223963260650635
Epoch 448 / 500 | iteration 10 / 30 | Total Loss: 4.711991786956787 | KNN Loss: 3.691059112548828 | BCE Loss: 1.0209325551986694
Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.701459884643555 | KNN Loss: 3.704157590866089 | BCE Loss: 0.9973020553588867
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.701648712158203 | KNN Loss: 3.671044111251831 | BCE Loss: 1.030604600906372
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.761867523193359 | KNN Loss: 3.7083561420440674 | BCE Loss: 1.0535115003585815
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.688438415527344 | KNN Loss: 3.6693243980407715 | BCE Loss: 1.0191140174865723
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.73390007019043 | KNN Loss: 3.678891181945801 | 

Epoch 458 / 500 | iteration 20 / 30 | Total Loss: 4.711792945861816 | KNN Loss: 3.7024762630462646 | BCE Loss: 1.0093164443969727
Epoch 458 / 500 | iteration 25 / 30 | Total Loss: 4.749874114990234 | KNN Loss: 3.71945858001709 | BCE Loss: 1.030415415763855
Epoch 459 / 500 | iteration 0 / 30 | Total Loss: 4.678637504577637 | KNN Loss: 3.655155658721924 | BCE Loss: 1.0234819650650024
Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.727014541625977 | KNN Loss: 3.693755865097046 | BCE Loss: 1.0332589149475098
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.726995468139648 | KNN Loss: 3.689645290374756 | BCE Loss: 1.0373499393463135
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.734009742736816 | KNN Loss: 3.701765537261963 | BCE Loss: 1.0322442054748535
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.679590702056885 | KNN Loss: 3.6703453063964844 | BCE Loss: 1.0092453956604004
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.7523698806762695 | KNN Loss: 3.700995683670044 

Epoch 469 / 500 | iteration 10 / 30 | Total Loss: 4.683858871459961 | KNN Loss: 3.6704912185668945 | BCE Loss: 1.0133675336837769
Epoch 469 / 500 | iteration 15 / 30 | Total Loss: 4.695420265197754 | KNN Loss: 3.6804099082946777 | BCE Loss: 1.0150105953216553
Epoch 469 / 500 | iteration 20 / 30 | Total Loss: 4.741342544555664 | KNN Loss: 3.7153501510620117 | BCE Loss: 1.0259926319122314
Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.7017364501953125 | KNN Loss: 3.69100022315979 | BCE Loss: 1.0107359886169434
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.734338283538818 | KNN Loss: 3.6815943717956543 | BCE Loss: 1.052743911743164
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.714425563812256 | KNN Loss: 3.6779732704162598 | BCE Loss: 1.0364524126052856
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.703989505767822 | KNN Loss: 3.6965994834899902 | BCE Loss: 1.0073901414871216
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.701154708862305 | KNN Loss: 3.666209697723

Epoch 480 / 500 | iteration 0 / 30 | Total Loss: 4.7608184814453125 | KNN Loss: 3.7304775714874268 | BCE Loss: 1.0303411483764648
Epoch 480 / 500 | iteration 5 / 30 | Total Loss: 4.711490631103516 | KNN Loss: 3.676882743835449 | BCE Loss: 1.0346076488494873
Epoch 480 / 500 | iteration 10 / 30 | Total Loss: 4.69797420501709 | KNN Loss: 3.6910603046417236 | BCE Loss: 1.006913661956787
Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.758475303649902 | KNN Loss: 3.7306647300720215 | BCE Loss: 1.0278103351593018
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.7572455406188965 | KNN Loss: 3.742795705795288 | BCE Loss: 1.014449954032898
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.720937728881836 | KNN Loss: 3.6801161766052246 | BCE Loss: 1.0408214330673218
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.767848491668701 | KNN Loss: 3.7440948486328125 | BCE Loss: 1.0237535238265991
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.680918216705322 | KNN Loss: 3.698016405105591

Epoch 490 / 500 | iteration 20 / 30 | Total Loss: 4.692294120788574 | KNN Loss: 3.6814565658569336 | BCE Loss: 1.0108373165130615
Epoch 490 / 500 | iteration 25 / 30 | Total Loss: 4.7550048828125 | KNN Loss: 3.698802947998047 | BCE Loss: 1.0562021732330322
Epoch 491 / 500 | iteration 0 / 30 | Total Loss: 4.733845233917236 | KNN Loss: 3.694500684738159 | BCE Loss: 1.0393446683883667
Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.711842060089111 | KNN Loss: 3.6803207397460938 | BCE Loss: 1.031521201133728
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.733510494232178 | KNN Loss: 3.692293643951416 | BCE Loss: 1.0412169694900513
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.728346824645996 | KNN Loss: 3.7061851024627686 | BCE Loss: 1.022161602973938
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.738253116607666 | KNN Loss: 3.701615810394287 | BCE Loss: 1.036637306213379
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.7347564697265625 | KNN Loss: 3.7237050533294678 |

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.4957e+00,  2.2320e+00,  3.2307e+00,  1.8412e+00,  4.1501e+00,
          1.0916e+00,  1.3062e+00,  2.3809e+00,  1.4414e+00,  2.5832e+00,
          2.7500e+00,  2.0658e+00,  8.5737e-01,  2.2502e+00,  1.5931e+00,
          2.1076e+00,  1.5929e+00,  3.8431e+00,  1.5590e+00,  2.7104e+00,
          1.8890e+00,  2.1590e+00,  1.8157e+00,  1.7746e+00,  2.1227e+00,
          2.1804e+00,  2.2866e+00,  1.7074e+00,  1.2637e+00,  7.7037e-01,
         -7.2234e-03,  1.2120e+00,  4.1476e-01,  1.2188e+00,  1.2754e+00,
          1.6367e+00,  1.2974e+00,  3.5390e+00,  8.6547e-01,  1.8028e+00,
          1.3451e+00, -7.5467e-01,  2.4334e-01,  2.4294e+00,  2.7278e+00,
          1.0045e+00, -5.0151e-01,  9.3344e-02,  1.2230e+00,  1.9194e+00,
          2.4503e+00, -1.1153e-01,  9.0962e-01,  3.8952e-01, -4.6239e-01,
          9.9071e-01,  1.9920e+00,  1.6375e+00,  1.2462e+00,  1.6857e+00,
          1.9570e-01,  6.7696e-01,  5.1394e-01,  1.1245e+00,  1.8090e+00,
          1.5869e+00, -1.8152e+00,  1.

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [12]:
dataset_ = [d[0].cpu() for d in dataset]

In [13]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 18.22it/s]


In [14]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [15]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [16]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [18]:
tensor_dataset = torch.stack(dataset_)

In [19]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [20]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [21]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [22]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [23]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [24]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [25]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [26]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [27]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [28]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [29]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [30]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [31]:
losses = []
accs = []
sparsity = []

In [32]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 027 | Total loss: 9.630 | Reg loss: 0.009 | Tree loss: 9.630 | Accuracy: 0.000000 | 0.58 sec/iter
Epoch: 00 | Batch: 001 / 027 | Total loss: 9.626 | Reg loss: 0.009 | Tree loss: 9.626 | Accuracy: 0.000000 | 0.434 sec/iter
Epoch: 00 | Batch: 002 / 027 | Total loss: 9.620 | Reg loss: 0.008 | Tree loss: 9.620 | Accuracy: 0.000000 | 0.387 sec/iter
Epoch: 00 | Batch: 003 / 027 | Total loss: 9.616 | Reg loss: 0.008 | Tree loss: 9.616 | Accuracy: 0.000000 | 0.365 sec/iter
Epoch: 00 | Batch: 004 / 027 | Total loss: 9.606 | Reg loss: 0.008 | Tree loss: 9.606 | Accuracy: 0.000000 | 0.351 sec/iter
Epoch: 00 | Batch: 005 / 027 | Total loss: 9.603 | Reg loss: 0.008 | Tree loss: 9.603 | Accuracy: 0.000000 | 0.343 sec/iter
Epoch: 00 | Batch: 006 / 027 | Total loss: 9.594 | Reg loss: 0.007 | Tree loss: 9.594 | Accuracy: 0.000000 | 0.346 sec/iter
Epoch: 00 | Batch:

Epoch: 02 | Batch: 008 / 027 | Total loss: 9.426 | Reg loss: 0.008 | Tree loss: 9.426 | Accuracy: 0.105469 | 0.328 sec/iter
Epoch: 02 | Batch: 009 / 027 | Total loss: 9.433 | Reg loss: 0.008 | Tree loss: 9.433 | Accuracy: 0.087891 | 0.328 sec/iter
Epoch: 02 | Batch: 010 / 027 | Total loss: 9.418 | Reg loss: 0.008 | Tree loss: 9.418 | Accuracy: 0.097656 | 0.329 sec/iter
Epoch: 02 | Batch: 011 / 027 | Total loss: 9.411 | Reg loss: 0.008 | Tree loss: 9.411 | Accuracy: 0.105469 | 0.329 sec/iter
Epoch: 02 | Batch: 012 / 027 | Total loss: 9.403 | Reg loss: 0.009 | Tree loss: 9.403 | Accuracy: 0.119141 | 0.329 sec/iter
Epoch: 02 | Batch: 013 / 027 | Total loss: 9.394 | Reg loss: 0.009 | Tree loss: 9.394 | Accuracy: 0.134766 | 0.329 sec/iter
Epoch: 02 | Batch: 014 / 027 | Total loss: 9.398 | Reg loss: 0.009 | Tree loss: 9.398 | Accuracy: 0.101562 | 0.329 sec/iter
Epoch: 02 | Batch: 015 / 027 | Total loss: 9.391 | Reg loss: 0.010 | Tree loss: 9.391 | Accuracy: 0.103516 | 0.329 sec/iter
Epoch: 0

Epoch: 04 | Batch: 017 / 027 | Total loss: 9.175 | Reg loss: 0.014 | Tree loss: 9.175 | Accuracy: 0.109375 | 0.331 sec/iter
Epoch: 04 | Batch: 018 / 027 | Total loss: 9.161 | Reg loss: 0.014 | Tree loss: 9.161 | Accuracy: 0.113281 | 0.331 sec/iter
Epoch: 04 | Batch: 019 / 027 | Total loss: 9.165 | Reg loss: 0.014 | Tree loss: 9.165 | Accuracy: 0.097656 | 0.331 sec/iter
Epoch: 04 | Batch: 020 / 027 | Total loss: 9.144 | Reg loss: 0.015 | Tree loss: 9.144 | Accuracy: 0.099609 | 0.331 sec/iter
Epoch: 04 | Batch: 021 / 027 | Total loss: 9.134 | Reg loss: 0.015 | Tree loss: 9.134 | Accuracy: 0.134766 | 0.331 sec/iter
Epoch: 04 | Batch: 022 / 027 | Total loss: 9.128 | Reg loss: 0.015 | Tree loss: 9.128 | Accuracy: 0.113281 | 0.332 sec/iter
Epoch: 04 | Batch: 023 / 027 | Total loss: 9.119 | Reg loss: 0.016 | Tree loss: 9.119 | Accuracy: 0.105469 | 0.332 sec/iter
Epoch: 04 | Batch: 024 / 027 | Total loss: 9.115 | Reg loss: 0.016 | Tree loss: 9.115 | Accuracy: 0.093750 | 0.331 sec/iter
Epoch: 0

Epoch: 06 | Batch: 026 / 027 | Total loss: 8.677 | Reg loss: 0.023 | Tree loss: 8.677 | Accuracy: 0.083333 | 0.328 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 07 | Batch: 000 / 027 | Total loss: 8.944 | Reg loss: 0.016 | Tree loss: 8.944 | Accuracy: 0.111328 | 0.33 sec/iter
Epoch: 07 | Batch: 001 / 027 | Total loss: 8.928 | Reg loss: 0.016 | Tree loss: 8.928 | Accuracy: 0.099609 | 0.329 sec/iter
Epoch: 07 | Batch: 002 / 027 | Total loss: 8.905 | Reg loss: 0.016 | Tree loss: 8.905 | Accuracy: 0.125000 | 0.329 sec/iter
Epoch: 07 | Batch: 003 / 027 | Total loss: 8.897 | Reg loss: 0.017 | Tree loss: 8.897 | Accuracy: 0.103516 | 0.329 sec/iter
Epoch: 07 | Batch: 004 / 027 | Total loss: 8.897 | Reg loss: 0.017 | Tree loss: 8.897 | Accuracy: 0.091797 | 0.329 sec/iter
Epoch: 07 | Batch: 005 

Epoch: 09 | Batch: 006 / 027 | Total loss: 8.447 | Reg loss: 0.022 | Tree loss: 8.447 | Accuracy: 0.105469 | 0.333 sec/iter
Epoch: 09 | Batch: 007 / 027 | Total loss: 8.458 | Reg loss: 0.022 | Tree loss: 8.458 | Accuracy: 0.082031 | 0.333 sec/iter
Epoch: 09 | Batch: 008 / 027 | Total loss: 8.404 | Reg loss: 0.023 | Tree loss: 8.404 | Accuracy: 0.089844 | 0.333 sec/iter
Epoch: 09 | Batch: 009 / 027 | Total loss: 8.377 | Reg loss: 0.023 | Tree loss: 8.377 | Accuracy: 0.111328 | 0.333 sec/iter
Epoch: 09 | Batch: 010 / 027 | Total loss: 8.342 | Reg loss: 0.023 | Tree loss: 8.342 | Accuracy: 0.099609 | 0.333 sec/iter
Epoch: 09 | Batch: 011 / 027 | Total loss: 8.327 | Reg loss: 0.024 | Tree loss: 8.327 | Accuracy: 0.121094 | 0.333 sec/iter
Epoch: 09 | Batch: 012 / 027 | Total loss: 8.309 | Reg loss: 0.024 | Tree loss: 8.309 | Accuracy: 0.125000 | 0.333 sec/iter
Epoch: 09 | Batch: 013 / 027 | Total loss: 8.277 | Reg loss: 0.024 | Tree loss: 8.277 | Accuracy: 0.105469 | 0.333 sec/iter
Epoch: 0

Epoch: 11 | Batch: 015 / 027 | Total loss: 7.717 | Reg loss: 0.028 | Tree loss: 7.717 | Accuracy: 0.109375 | 0.338 sec/iter
Epoch: 11 | Batch: 016 / 027 | Total loss: 7.720 | Reg loss: 0.028 | Tree loss: 7.720 | Accuracy: 0.089844 | 0.338 sec/iter
Epoch: 11 | Batch: 017 / 027 | Total loss: 7.693 | Reg loss: 0.028 | Tree loss: 7.693 | Accuracy: 0.097656 | 0.339 sec/iter
Epoch: 11 | Batch: 018 / 027 | Total loss: 7.647 | Reg loss: 0.028 | Tree loss: 7.647 | Accuracy: 0.126953 | 0.339 sec/iter
Epoch: 11 | Batch: 019 / 027 | Total loss: 7.672 | Reg loss: 0.029 | Tree loss: 7.672 | Accuracy: 0.082031 | 0.339 sec/iter
Epoch: 11 | Batch: 020 / 027 | Total loss: 7.640 | Reg loss: 0.029 | Tree loss: 7.640 | Accuracy: 0.103516 | 0.339 sec/iter
Epoch: 11 | Batch: 021 / 027 | Total loss: 7.608 | Reg loss: 0.029 | Tree loss: 7.608 | Accuracy: 0.109375 | 0.339 sec/iter
Epoch: 11 | Batch: 022 / 027 | Total loss: 7.583 | Reg loss: 0.029 | Tree loss: 7.583 | Accuracy: 0.107422 | 0.339 sec/iter
Epoch: 1

Epoch: 13 | Batch: 024 / 027 | Total loss: 7.099 | Reg loss: 0.031 | Tree loss: 7.099 | Accuracy: 0.099609 | 0.343 sec/iter
Epoch: 13 | Batch: 025 / 027 | Total loss: 7.060 | Reg loss: 0.031 | Tree loss: 7.060 | Accuracy: 0.095703 | 0.343 sec/iter
Epoch: 13 | Batch: 026 / 027 | Total loss: 7.194 | Reg loss: 0.031 | Tree loss: 7.194 | Accuracy: 0.250000 | 0.344 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 14 | Batch: 000 / 027 | Total loss: 7.301 | Reg loss: 0.028 | Tree loss: 7.301 | Accuracy: 0.103516 | 0.344 sec/iter
Epoch: 14 | Batch: 001 / 027 | Total loss: 7.322 | Reg loss: 0.029 | Tree loss: 7.322 | Accuracy: 0.091797 | 0.344 sec/iter
Epoch: 14 | Batch: 002 / 027 | Total loss: 7.276 | Reg loss: 0.029 | Tree loss: 7.276 | Accuracy: 0.109375 | 0.344 sec/iter
Epoch: 14 | Batch: 003

Epoch: 16 | Batch: 004 / 027 | Total loss: 6.799 | Reg loss: 0.030 | Tree loss: 6.799 | Accuracy: 0.103516 | 0.346 sec/iter
Epoch: 16 | Batch: 005 / 027 | Total loss: 6.778 | Reg loss: 0.030 | Tree loss: 6.778 | Accuracy: 0.103516 | 0.346 sec/iter
Epoch: 16 | Batch: 006 / 027 | Total loss: 6.754 | Reg loss: 0.030 | Tree loss: 6.754 | Accuracy: 0.089844 | 0.346 sec/iter
Epoch: 16 | Batch: 007 / 027 | Total loss: 6.691 | Reg loss: 0.030 | Tree loss: 6.691 | Accuracy: 0.109375 | 0.346 sec/iter
Epoch: 16 | Batch: 008 / 027 | Total loss: 6.724 | Reg loss: 0.030 | Tree loss: 6.724 | Accuracy: 0.087891 | 0.346 sec/iter
Epoch: 16 | Batch: 009 / 027 | Total loss: 6.668 | Reg loss: 0.030 | Tree loss: 6.668 | Accuracy: 0.115234 | 0.346 sec/iter
Epoch: 16 | Batch: 010 / 027 | Total loss: 6.665 | Reg loss: 0.030 | Tree loss: 6.665 | Accuracy: 0.097656 | 0.346 sec/iter
Epoch: 16 | Batch: 011 / 027 | Total loss: 6.657 | Reg loss: 0.030 | Tree loss: 6.657 | Accuracy: 0.109375 | 0.346 sec/iter
Epoch: 1

Epoch: 18 | Batch: 013 / 027 | Total loss: 6.198 | Reg loss: 0.031 | Tree loss: 6.198 | Accuracy: 0.095703 | 0.346 sec/iter
Epoch: 18 | Batch: 014 / 027 | Total loss: 6.157 | Reg loss: 0.031 | Tree loss: 6.157 | Accuracy: 0.117188 | 0.346 sec/iter
Epoch: 18 | Batch: 015 / 027 | Total loss: 6.145 | Reg loss: 0.031 | Tree loss: 6.145 | Accuracy: 0.087891 | 0.346 sec/iter
Epoch: 18 | Batch: 016 / 027 | Total loss: 6.121 | Reg loss: 0.031 | Tree loss: 6.121 | Accuracy: 0.113281 | 0.346 sec/iter
Epoch: 18 | Batch: 017 / 027 | Total loss: 6.177 | Reg loss: 0.031 | Tree loss: 6.177 | Accuracy: 0.095703 | 0.346 sec/iter
Epoch: 18 | Batch: 018 / 027 | Total loss: 6.137 | Reg loss: 0.031 | Tree loss: 6.137 | Accuracy: 0.082031 | 0.346 sec/iter
Epoch: 18 | Batch: 019 / 027 | Total loss: 6.110 | Reg loss: 0.031 | Tree loss: 6.110 | Accuracy: 0.105469 | 0.346 sec/iter
Epoch: 18 | Batch: 020 / 027 | Total loss: 6.087 | Reg loss: 0.031 | Tree loss: 6.087 | Accuracy: 0.109375 | 0.346 sec/iter
Epoch: 1

Epoch: 20 | Batch: 022 / 027 | Total loss: 5.639 | Reg loss: 0.032 | Tree loss: 5.639 | Accuracy: 0.089844 | 0.347 sec/iter
Epoch: 20 | Batch: 023 / 027 | Total loss: 5.640 | Reg loss: 0.032 | Tree loss: 5.640 | Accuracy: 0.126953 | 0.347 sec/iter
Epoch: 20 | Batch: 024 / 027 | Total loss: 5.633 | Reg loss: 0.032 | Tree loss: 5.633 | Accuracy: 0.095703 | 0.347 sec/iter
Epoch: 20 | Batch: 025 / 027 | Total loss: 5.619 | Reg loss: 0.032 | Tree loss: 5.619 | Accuracy: 0.105469 | 0.347 sec/iter
Epoch: 20 | Batch: 026 / 027 | Total loss: 5.578 | Reg loss: 0.032 | Tree loss: 5.578 | Accuracy: 0.166667 | 0.347 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 21 | Batch: 000 / 027 | Total loss: 5.854 | Reg loss: 0.030 | Tree loss: 5.854 | Accuracy: 0.101562 | 0.348 sec/iter
Epoch: 21 | Batch: 001

Epoch: 23 | Batch: 002 / 027 | Total loss: 5.488 | Reg loss: 0.030 | Tree loss: 5.488 | Accuracy: 0.103516 | 0.348 sec/iter
Epoch: 23 | Batch: 003 / 027 | Total loss: 5.483 | Reg loss: 0.030 | Tree loss: 5.483 | Accuracy: 0.109375 | 0.348 sec/iter
Epoch: 23 | Batch: 004 / 027 | Total loss: 5.467 | Reg loss: 0.030 | Tree loss: 5.467 | Accuracy: 0.103516 | 0.348 sec/iter
Epoch: 23 | Batch: 005 / 027 | Total loss: 5.426 | Reg loss: 0.030 | Tree loss: 5.426 | Accuracy: 0.103516 | 0.348 sec/iter
Epoch: 23 | Batch: 006 / 027 | Total loss: 5.411 | Reg loss: 0.030 | Tree loss: 5.411 | Accuracy: 0.101562 | 0.348 sec/iter
Epoch: 23 | Batch: 007 / 027 | Total loss: 5.388 | Reg loss: 0.030 | Tree loss: 5.388 | Accuracy: 0.101562 | 0.348 sec/iter
Epoch: 23 | Batch: 008 / 027 | Total loss: 5.367 | Reg loss: 0.030 | Tree loss: 5.367 | Accuracy: 0.111328 | 0.348 sec/iter
Epoch: 23 | Batch: 009 / 027 | Total loss: 5.358 | Reg loss: 0.030 | Tree loss: 5.358 | Accuracy: 0.093750 | 0.348 sec/iter
Epoch: 2

Epoch: 25 | Batch: 011 / 027 | Total loss: 4.922 | Reg loss: 0.031 | Tree loss: 4.922 | Accuracy: 0.095703 | 0.35 sec/iter
Epoch: 25 | Batch: 012 / 027 | Total loss: 4.862 | Reg loss: 0.031 | Tree loss: 4.862 | Accuracy: 0.109375 | 0.35 sec/iter
Epoch: 25 | Batch: 013 / 027 | Total loss: 4.843 | Reg loss: 0.032 | Tree loss: 4.843 | Accuracy: 0.091797 | 0.35 sec/iter
Epoch: 25 | Batch: 014 / 027 | Total loss: 4.802 | Reg loss: 0.032 | Tree loss: 4.802 | Accuracy: 0.093750 | 0.35 sec/iter
Epoch: 25 | Batch: 015 / 027 | Total loss: 4.818 | Reg loss: 0.032 | Tree loss: 4.818 | Accuracy: 0.093750 | 0.35 sec/iter
Epoch: 25 | Batch: 016 / 027 | Total loss: 4.870 | Reg loss: 0.032 | Tree loss: 4.870 | Accuracy: 0.091797 | 0.35 sec/iter
Epoch: 25 | Batch: 017 / 027 | Total loss: 4.748 | Reg loss: 0.032 | Tree loss: 4.748 | Accuracy: 0.107422 | 0.35 sec/iter
Epoch: 25 | Batch: 018 / 027 | Total loss: 4.763 | Reg loss: 0.032 | Tree loss: 4.763 | Accuracy: 0.099609 | 0.35 sec/iter
Epoch: 25 | Batc

Epoch: 27 | Batch: 020 / 027 | Total loss: 4.400 | Reg loss: 0.034 | Tree loss: 4.400 | Accuracy: 0.101562 | 0.35 sec/iter
Epoch: 27 | Batch: 021 / 027 | Total loss: 4.406 | Reg loss: 0.034 | Tree loss: 4.406 | Accuracy: 0.076172 | 0.35 sec/iter
Epoch: 27 | Batch: 022 / 027 | Total loss: 4.398 | Reg loss: 0.034 | Tree loss: 4.398 | Accuracy: 0.089844 | 0.35 sec/iter
Epoch: 27 | Batch: 023 / 027 | Total loss: 4.362 | Reg loss: 0.034 | Tree loss: 4.362 | Accuracy: 0.087891 | 0.35 sec/iter
Epoch: 27 | Batch: 024 / 027 | Total loss: 4.298 | Reg loss: 0.034 | Tree loss: 4.298 | Accuracy: 0.099609 | 0.35 sec/iter
Epoch: 27 | Batch: 025 / 027 | Total loss: 4.246 | Reg loss: 0.034 | Tree loss: 4.246 | Accuracy: 0.117188 | 0.35 sec/iter
Epoch: 27 | Batch: 026 / 027 | Total loss: 4.230 | Reg loss: 0.034 | Tree loss: 4.230 | Accuracy: 0.166667 | 0.35 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.98214

Epoch: 30 | Batch: 000 / 027 | Total loss: 4.253 | Reg loss: 0.033 | Tree loss: 4.253 | Accuracy: 0.080078 | 0.352 sec/iter
Epoch: 30 | Batch: 001 / 027 | Total loss: 4.216 | Reg loss: 0.033 | Tree loss: 4.216 | Accuracy: 0.109375 | 0.352 sec/iter
Epoch: 30 | Batch: 002 / 027 | Total loss: 4.243 | Reg loss: 0.033 | Tree loss: 4.243 | Accuracy: 0.107422 | 0.352 sec/iter
Epoch: 30 | Batch: 003 / 027 | Total loss: 4.199 | Reg loss: 0.033 | Tree loss: 4.199 | Accuracy: 0.091797 | 0.352 sec/iter
Epoch: 30 | Batch: 004 / 027 | Total loss: 4.192 | Reg loss: 0.033 | Tree loss: 4.192 | Accuracy: 0.113281 | 0.352 sec/iter
Epoch: 30 | Batch: 005 / 027 | Total loss: 4.170 | Reg loss: 0.033 | Tree loss: 4.170 | Accuracy: 0.111328 | 0.352 sec/iter
Epoch: 30 | Batch: 006 / 027 | Total loss: 4.160 | Reg loss: 0.033 | Tree loss: 4.160 | Accuracy: 0.091797 | 0.352 sec/iter
Epoch: 30 | Batch: 007 / 027 | Total loss: 4.148 | Reg loss: 0.034 | Tree loss: 4.148 | Accuracy: 0.103516 | 0.352 sec/iter
Epoch: 3

Epoch: 32 | Batch: 009 / 027 | Total loss: 3.971 | Reg loss: 0.034 | Tree loss: 3.971 | Accuracy: 0.076172 | 0.352 sec/iter
Epoch: 32 | Batch: 010 / 027 | Total loss: 3.888 | Reg loss: 0.034 | Tree loss: 3.888 | Accuracy: 0.083984 | 0.352 sec/iter
Epoch: 32 | Batch: 011 / 027 | Total loss: 3.934 | Reg loss: 0.034 | Tree loss: 3.934 | Accuracy: 0.097656 | 0.352 sec/iter
Epoch: 32 | Batch: 012 / 027 | Total loss: 3.822 | Reg loss: 0.034 | Tree loss: 3.822 | Accuracy: 0.107422 | 0.352 sec/iter
Epoch: 32 | Batch: 013 / 027 | Total loss: 3.955 | Reg loss: 0.034 | Tree loss: 3.955 | Accuracy: 0.082031 | 0.352 sec/iter
Epoch: 32 | Batch: 014 / 027 | Total loss: 3.883 | Reg loss: 0.035 | Tree loss: 3.883 | Accuracy: 0.085938 | 0.352 sec/iter
Epoch: 32 | Batch: 015 / 027 | Total loss: 3.833 | Reg loss: 0.035 | Tree loss: 3.833 | Accuracy: 0.093750 | 0.352 sec/iter
Epoch: 32 | Batch: 016 / 027 | Total loss: 3.859 | Reg loss: 0.035 | Tree loss: 3.859 | Accuracy: 0.097656 | 0.352 sec/iter
Epoch: 3

Epoch: 34 | Batch: 018 / 027 | Total loss: 3.662 | Reg loss: 0.035 | Tree loss: 3.662 | Accuracy: 0.078125 | 0.352 sec/iter
Epoch: 34 | Batch: 019 / 027 | Total loss: 3.672 | Reg loss: 0.035 | Tree loss: 3.672 | Accuracy: 0.105469 | 0.352 sec/iter
Epoch: 34 | Batch: 020 / 027 | Total loss: 3.695 | Reg loss: 0.035 | Tree loss: 3.695 | Accuracy: 0.085938 | 0.352 sec/iter
Epoch: 34 | Batch: 021 / 027 | Total loss: 3.624 | Reg loss: 0.035 | Tree loss: 3.624 | Accuracy: 0.095703 | 0.351 sec/iter
Epoch: 34 | Batch: 022 / 027 | Total loss: 3.623 | Reg loss: 0.035 | Tree loss: 3.623 | Accuracy: 0.089844 | 0.351 sec/iter
Epoch: 34 | Batch: 023 / 027 | Total loss: 3.638 | Reg loss: 0.035 | Tree loss: 3.638 | Accuracy: 0.095703 | 0.351 sec/iter
Epoch: 34 | Batch: 024 / 027 | Total loss: 3.601 | Reg loss: 0.035 | Tree loss: 3.601 | Accuracy: 0.097656 | 0.351 sec/iter
Epoch: 34 | Batch: 025 / 027 | Total loss: 3.563 | Reg loss: 0.036 | Tree loss: 3.563 | Accuracy: 0.111328 | 0.351 sec/iter
Epoch: 3

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 37 | Batch: 000 / 027 | Total loss: 3.693 | Reg loss: 0.035 | Tree loss: 3.693 | Accuracy: 0.091797 | 0.352 sec/iter
Epoch: 37 | Batch: 001 / 027 | Total loss: 3.662 | Reg loss: 0.035 | Tree loss: 3.662 | Accuracy: 0.119141 | 0.351 sec/iter
Epoch: 37 | Batch: 002 / 027 | Total loss: 3.635 | Reg loss: 0.035 | Tree loss: 3.635 | Accuracy: 0.097656 | 0.352 sec/iter
Epoch: 37 | Batch: 003 / 027 | Total loss: 3.647 | Reg loss: 0.035 | Tree loss: 3.647 | Accuracy: 0.103516 | 0.352 sec/iter
Epoch: 37 | Batch: 004 / 027 | Total loss: 3.660 | Reg loss: 0.035 | Tree loss: 3.660 | Accuracy: 0.101562 | 0.352 sec/iter
Epoch: 37 | Batch: 005 / 027 | Total loss: 3.637 | Reg loss: 0.035 | Tree loss: 3.637 | Accuracy: 0.082031 | 0.352 sec/iter
Epoch: 37 | Batch: 006

Epoch: 39 | Batch: 007 / 027 | Total loss: 3.521 | Reg loss: 0.035 | Tree loss: 3.521 | Accuracy: 0.101562 | 0.352 sec/iter
Epoch: 39 | Batch: 008 / 027 | Total loss: 3.522 | Reg loss: 0.035 | Tree loss: 3.522 | Accuracy: 0.076172 | 0.352 sec/iter
Epoch: 39 | Batch: 009 / 027 | Total loss: 3.435 | Reg loss: 0.035 | Tree loss: 3.435 | Accuracy: 0.107422 | 0.352 sec/iter
Epoch: 39 | Batch: 010 / 027 | Total loss: 3.461 | Reg loss: 0.035 | Tree loss: 3.461 | Accuracy: 0.125000 | 0.352 sec/iter
Epoch: 39 | Batch: 011 / 027 | Total loss: 3.471 | Reg loss: 0.035 | Tree loss: 3.471 | Accuracy: 0.082031 | 0.352 sec/iter
Epoch: 39 | Batch: 012 / 027 | Total loss: 3.472 | Reg loss: 0.035 | Tree loss: 3.472 | Accuracy: 0.087891 | 0.352 sec/iter
Epoch: 39 | Batch: 013 / 027 | Total loss: 3.434 | Reg loss: 0.035 | Tree loss: 3.434 | Accuracy: 0.091797 | 0.352 sec/iter
Epoch: 39 | Batch: 014 / 027 | Total loss: 3.467 | Reg loss: 0.035 | Tree loss: 3.467 | Accuracy: 0.089844 | 0.352 sec/iter
Epoch: 3

Epoch: 41 | Batch: 016 / 027 | Total loss: 3.368 | Reg loss: 0.035 | Tree loss: 3.368 | Accuracy: 0.082031 | 0.351 sec/iter
Epoch: 41 | Batch: 017 / 027 | Total loss: 3.294 | Reg loss: 0.035 | Tree loss: 3.294 | Accuracy: 0.111328 | 0.351 sec/iter
Epoch: 41 | Batch: 018 / 027 | Total loss: 3.314 | Reg loss: 0.035 | Tree loss: 3.314 | Accuracy: 0.107422 | 0.351 sec/iter
Epoch: 41 | Batch: 019 / 027 | Total loss: 3.280 | Reg loss: 0.035 | Tree loss: 3.280 | Accuracy: 0.093750 | 0.351 sec/iter
Epoch: 41 | Batch: 020 / 027 | Total loss: 3.290 | Reg loss: 0.035 | Tree loss: 3.290 | Accuracy: 0.103516 | 0.351 sec/iter
Epoch: 41 | Batch: 021 / 027 | Total loss: 3.317 | Reg loss: 0.035 | Tree loss: 3.317 | Accuracy: 0.115234 | 0.35 sec/iter
Epoch: 41 | Batch: 022 / 027 | Total loss: 3.348 | Reg loss: 0.035 | Tree loss: 3.348 | Accuracy: 0.091797 | 0.35 sec/iter
Epoch: 41 | Batch: 023 / 027 | Total loss: 3.297 | Reg loss: 0.035 | Tree loss: 3.297 | Accuracy: 0.085938 | 0.35 sec/iter
Epoch: 41 |

Epoch: 43 | Batch: 025 / 027 | Total loss: 3.200 | Reg loss: 0.035 | Tree loss: 3.200 | Accuracy: 0.103516 | 0.349 sec/iter
Epoch: 43 | Batch: 026 / 027 | Total loss: 3.005 | Reg loss: 0.035 | Tree loss: 3.005 | Accuracy: 0.000000 | 0.349 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 44 | Batch: 000 / 027 | Total loss: 3.424 | Reg loss: 0.035 | Tree loss: 3.424 | Accuracy: 0.083984 | 0.349 sec/iter
Epoch: 44 | Batch: 001 / 027 | Total loss: 3.379 | Reg loss: 0.035 | Tree loss: 3.379 | Accuracy: 0.119141 | 0.349 sec/iter
Epoch: 44 | Batch: 002 / 027 | Total loss: 3.406 | Reg loss: 0.035 | Tree loss: 3.406 | Accuracy: 0.095703 | 0.349 sec/iter
Epoch: 44 | Batch: 003 / 027 | Total loss: 3.368 | Reg loss: 0.035 | Tree loss: 3.368 | Accuracy: 0.085938 | 0.349 sec/iter
Epoch: 44 | Batch: 004

Epoch: 46 | Batch: 005 / 027 | Total loss: 3.324 | Reg loss: 0.035 | Tree loss: 3.324 | Accuracy: 0.095703 | 0.349 sec/iter
Epoch: 46 | Batch: 006 / 027 | Total loss: 3.330 | Reg loss: 0.035 | Tree loss: 3.330 | Accuracy: 0.099609 | 0.349 sec/iter
Epoch: 46 | Batch: 007 / 027 | Total loss: 3.264 | Reg loss: 0.035 | Tree loss: 3.264 | Accuracy: 0.093750 | 0.349 sec/iter
Epoch: 46 | Batch: 008 / 027 | Total loss: 3.272 | Reg loss: 0.035 | Tree loss: 3.272 | Accuracy: 0.083984 | 0.349 sec/iter
Epoch: 46 | Batch: 009 / 027 | Total loss: 3.275 | Reg loss: 0.035 | Tree loss: 3.275 | Accuracy: 0.109375 | 0.349 sec/iter
Epoch: 46 | Batch: 010 / 027 | Total loss: 3.293 | Reg loss: 0.035 | Tree loss: 3.293 | Accuracy: 0.111328 | 0.35 sec/iter
Epoch: 46 | Batch: 011 / 027 | Total loss: 3.254 | Reg loss: 0.035 | Tree loss: 3.254 | Accuracy: 0.101562 | 0.35 sec/iter
Epoch: 46 | Batch: 012 / 027 | Total loss: 3.237 | Reg loss: 0.035 | Tree loss: 3.237 | Accuracy: 0.095703 | 0.349 sec/iter
Epoch: 46 

Epoch: 48 | Batch: 014 / 027 | Total loss: 3.276 | Reg loss: 0.035 | Tree loss: 3.276 | Accuracy: 0.083984 | 0.35 sec/iter
Epoch: 48 | Batch: 015 / 027 | Total loss: 3.148 | Reg loss: 0.035 | Tree loss: 3.148 | Accuracy: 0.115234 | 0.35 sec/iter
Epoch: 48 | Batch: 016 / 027 | Total loss: 3.158 | Reg loss: 0.035 | Tree loss: 3.158 | Accuracy: 0.085938 | 0.35 sec/iter
Epoch: 48 | Batch: 017 / 027 | Total loss: 3.155 | Reg loss: 0.035 | Tree loss: 3.155 | Accuracy: 0.093750 | 0.35 sec/iter
Epoch: 48 | Batch: 018 / 027 | Total loss: 3.157 | Reg loss: 0.035 | Tree loss: 3.157 | Accuracy: 0.105469 | 0.35 sec/iter
Epoch: 48 | Batch: 019 / 027 | Total loss: 3.105 | Reg loss: 0.035 | Tree loss: 3.105 | Accuracy: 0.123047 | 0.35 sec/iter
Epoch: 48 | Batch: 020 / 027 | Total loss: 3.066 | Reg loss: 0.035 | Tree loss: 3.066 | Accuracy: 0.105469 | 0.35 sec/iter
Epoch: 48 | Batch: 021 / 027 | Total loss: 3.128 | Reg loss: 0.035 | Tree loss: 3.128 | Accuracy: 0.095703 | 0.35 sec/iter
Epoch: 48 | Batc

Epoch: 50 | Batch: 023 / 027 | Total loss: 3.069 | Reg loss: 0.035 | Tree loss: 3.069 | Accuracy: 0.089844 | 0.35 sec/iter
Epoch: 50 | Batch: 024 / 027 | Total loss: 3.100 | Reg loss: 0.035 | Tree loss: 3.100 | Accuracy: 0.080078 | 0.35 sec/iter
Epoch: 50 | Batch: 025 / 027 | Total loss: 3.065 | Reg loss: 0.035 | Tree loss: 3.065 | Accuracy: 0.105469 | 0.35 sec/iter
Epoch: 50 | Batch: 026 / 027 | Total loss: 3.202 | Reg loss: 0.035 | Tree loss: 3.202 | Accuracy: 0.083333 | 0.35 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 51 | Batch: 000 / 027 | Total loss: 3.227 | Reg loss: 0.034 | Tree loss: 3.227 | Accuracy: 0.101562 | 0.351 sec/iter
Epoch: 51 | Batch: 001 / 027 | Total loss: 3.305 | Reg loss: 0.034 | Tree loss: 3.305 | Accuracy: 0.097656 | 0.351 sec/iter
Epoch: 51 | Batch: 002 / 0

Epoch: 53 | Batch: 003 / 027 | Total loss: 3.217 | Reg loss: 0.034 | Tree loss: 3.217 | Accuracy: 0.101562 | 0.351 sec/iter
Epoch: 53 | Batch: 004 / 027 | Total loss: 3.206 | Reg loss: 0.034 | Tree loss: 3.206 | Accuracy: 0.089844 | 0.351 sec/iter
Epoch: 53 | Batch: 005 / 027 | Total loss: 3.187 | Reg loss: 0.034 | Tree loss: 3.187 | Accuracy: 0.095703 | 0.351 sec/iter
Epoch: 53 | Batch: 006 / 027 | Total loss: 3.155 | Reg loss: 0.034 | Tree loss: 3.155 | Accuracy: 0.125000 | 0.351 sec/iter
Epoch: 53 | Batch: 007 / 027 | Total loss: 3.194 | Reg loss: 0.034 | Tree loss: 3.194 | Accuracy: 0.091797 | 0.351 sec/iter
Epoch: 53 | Batch: 008 / 027 | Total loss: 3.228 | Reg loss: 0.034 | Tree loss: 3.228 | Accuracy: 0.083984 | 0.351 sec/iter
Epoch: 53 | Batch: 009 / 027 | Total loss: 3.157 | Reg loss: 0.034 | Tree loss: 3.157 | Accuracy: 0.091797 | 0.351 sec/iter
Epoch: 53 | Batch: 010 / 027 | Total loss: 3.142 | Reg loss: 0.034 | Tree loss: 3.142 | Accuracy: 0.091797 | 0.351 sec/iter
Epoch: 5

Epoch: 55 | Batch: 012 / 027 | Total loss: 3.141 | Reg loss: 0.034 | Tree loss: 3.141 | Accuracy: 0.091797 | 0.351 sec/iter
Epoch: 55 | Batch: 013 / 027 | Total loss: 3.112 | Reg loss: 0.034 | Tree loss: 3.112 | Accuracy: 0.097656 | 0.351 sec/iter
Epoch: 55 | Batch: 014 / 027 | Total loss: 3.092 | Reg loss: 0.034 | Tree loss: 3.092 | Accuracy: 0.095703 | 0.351 sec/iter
Epoch: 55 | Batch: 015 / 027 | Total loss: 3.066 | Reg loss: 0.034 | Tree loss: 3.066 | Accuracy: 0.089844 | 0.351 sec/iter
Epoch: 55 | Batch: 016 / 027 | Total loss: 3.084 | Reg loss: 0.035 | Tree loss: 3.084 | Accuracy: 0.115234 | 0.351 sec/iter
Epoch: 55 | Batch: 017 / 027 | Total loss: 3.036 | Reg loss: 0.035 | Tree loss: 3.036 | Accuracy: 0.126953 | 0.351 sec/iter
Epoch: 55 | Batch: 018 / 027 | Total loss: 3.045 | Reg loss: 0.035 | Tree loss: 3.045 | Accuracy: 0.095703 | 0.351 sec/iter
Epoch: 55 | Batch: 019 / 027 | Total loss: 3.036 | Reg loss: 0.035 | Tree loss: 3.036 | Accuracy: 0.089844 | 0.351 sec/iter
Epoch: 5

Epoch: 57 | Batch: 021 / 027 | Total loss: 3.039 | Reg loss: 0.035 | Tree loss: 3.039 | Accuracy: 0.113281 | 0.351 sec/iter
Epoch: 57 | Batch: 022 / 027 | Total loss: 3.022 | Reg loss: 0.035 | Tree loss: 3.022 | Accuracy: 0.101562 | 0.351 sec/iter
Epoch: 57 | Batch: 023 / 027 | Total loss: 2.986 | Reg loss: 0.035 | Tree loss: 2.986 | Accuracy: 0.103516 | 0.351 sec/iter
Epoch: 57 | Batch: 024 / 027 | Total loss: 3.015 | Reg loss: 0.035 | Tree loss: 3.015 | Accuracy: 0.099609 | 0.351 sec/iter
Epoch: 57 | Batch: 025 / 027 | Total loss: 2.981 | Reg loss: 0.035 | Tree loss: 2.981 | Accuracy: 0.091797 | 0.351 sec/iter
Epoch: 57 | Batch: 026 / 027 | Total loss: 2.969 | Reg loss: 0.035 | Tree loss: 2.969 | Accuracy: 0.000000 | 0.351 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 58 | Batch: 000

Epoch: 60 | Batch: 001 / 027 | Total loss: 3.205 | Reg loss: 0.034 | Tree loss: 3.205 | Accuracy: 0.099609 | 0.352 sec/iter
Epoch: 60 | Batch: 002 / 027 | Total loss: 3.179 | Reg loss: 0.034 | Tree loss: 3.179 | Accuracy: 0.095703 | 0.352 sec/iter
Epoch: 60 | Batch: 003 / 027 | Total loss: 3.153 | Reg loss: 0.034 | Tree loss: 3.153 | Accuracy: 0.099609 | 0.352 sec/iter
Epoch: 60 | Batch: 004 / 027 | Total loss: 3.120 | Reg loss: 0.034 | Tree loss: 3.120 | Accuracy: 0.111328 | 0.352 sec/iter
Epoch: 60 | Batch: 005 / 027 | Total loss: 3.112 | Reg loss: 0.034 | Tree loss: 3.112 | Accuracy: 0.115234 | 0.352 sec/iter
Epoch: 60 | Batch: 006 / 027 | Total loss: 3.115 | Reg loss: 0.034 | Tree loss: 3.115 | Accuracy: 0.093750 | 0.352 sec/iter
Epoch: 60 | Batch: 007 / 027 | Total loss: 3.125 | Reg loss: 0.034 | Tree loss: 3.125 | Accuracy: 0.080078 | 0.352 sec/iter
Epoch: 60 | Batch: 008 / 027 | Total loss: 3.129 | Reg loss: 0.034 | Tree loss: 3.129 | Accuracy: 0.125000 | 0.352 sec/iter
Epoch: 6

Epoch: 62 | Batch: 010 / 027 | Total loss: 3.099 | Reg loss: 0.034 | Tree loss: 3.099 | Accuracy: 0.095703 | 0.352 sec/iter
Epoch: 62 | Batch: 011 / 027 | Total loss: 3.107 | Reg loss: 0.034 | Tree loss: 3.107 | Accuracy: 0.103516 | 0.352 sec/iter
Epoch: 62 | Batch: 012 / 027 | Total loss: 3.028 | Reg loss: 0.034 | Tree loss: 3.028 | Accuracy: 0.103516 | 0.352 sec/iter
Epoch: 62 | Batch: 013 / 027 | Total loss: 3.075 | Reg loss: 0.034 | Tree loss: 3.075 | Accuracy: 0.119141 | 0.352 sec/iter
Epoch: 62 | Batch: 014 / 027 | Total loss: 3.064 | Reg loss: 0.034 | Tree loss: 3.064 | Accuracy: 0.070312 | 0.352 sec/iter
Epoch: 62 | Batch: 015 / 027 | Total loss: 2.985 | Reg loss: 0.034 | Tree loss: 2.985 | Accuracy: 0.091797 | 0.352 sec/iter
Epoch: 62 | Batch: 016 / 027 | Total loss: 3.000 | Reg loss: 0.034 | Tree loss: 3.000 | Accuracy: 0.091797 | 0.352 sec/iter
Epoch: 62 | Batch: 017 / 027 | Total loss: 3.030 | Reg loss: 0.034 | Tree loss: 3.030 | Accuracy: 0.085938 | 0.352 sec/iter
Epoch: 6

Epoch: 64 | Batch: 019 / 027 | Total loss: 3.014 | Reg loss: 0.034 | Tree loss: 3.014 | Accuracy: 0.076172 | 0.352 sec/iter
Epoch: 64 | Batch: 020 / 027 | Total loss: 2.964 | Reg loss: 0.034 | Tree loss: 2.964 | Accuracy: 0.128906 | 0.352 sec/iter
Epoch: 64 | Batch: 021 / 027 | Total loss: 2.981 | Reg loss: 0.034 | Tree loss: 2.981 | Accuracy: 0.117188 | 0.352 sec/iter
Epoch: 64 | Batch: 022 / 027 | Total loss: 2.970 | Reg loss: 0.034 | Tree loss: 2.970 | Accuracy: 0.105469 | 0.352 sec/iter
Epoch: 64 | Batch: 023 / 027 | Total loss: 2.965 | Reg loss: 0.034 | Tree loss: 2.965 | Accuracy: 0.113281 | 0.352 sec/iter
Epoch: 64 | Batch: 024 / 027 | Total loss: 3.018 | Reg loss: 0.034 | Tree loss: 3.018 | Accuracy: 0.101562 | 0.352 sec/iter
Epoch: 64 | Batch: 025 / 027 | Total loss: 2.940 | Reg loss: 0.034 | Tree loss: 2.940 | Accuracy: 0.078125 | 0.352 sec/iter
Epoch: 64 | Batch: 026 / 027 | Total loss: 2.823 | Reg loss: 0.035 | Tree loss: 2.823 | Accuracy: 0.083333 | 0.352 sec/iter
Average 

Epoch: 67 | Batch: 000 / 027 | Total loss: 3.115 | Reg loss: 0.034 | Tree loss: 3.115 | Accuracy: 0.083984 | 0.353 sec/iter
Epoch: 67 | Batch: 001 / 027 | Total loss: 3.117 | Reg loss: 0.034 | Tree loss: 3.117 | Accuracy: 0.087891 | 0.353 sec/iter
Epoch: 67 | Batch: 002 / 027 | Total loss: 3.112 | Reg loss: 0.034 | Tree loss: 3.112 | Accuracy: 0.113281 | 0.353 sec/iter
Epoch: 67 | Batch: 003 / 027 | Total loss: 3.070 | Reg loss: 0.034 | Tree loss: 3.070 | Accuracy: 0.117188 | 0.353 sec/iter
Epoch: 67 | Batch: 004 / 027 | Total loss: 3.077 | Reg loss: 0.034 | Tree loss: 3.077 | Accuracy: 0.105469 | 0.353 sec/iter
Epoch: 67 | Batch: 005 / 027 | Total loss: 3.085 | Reg loss: 0.034 | Tree loss: 3.085 | Accuracy: 0.101562 | 0.353 sec/iter
Epoch: 67 | Batch: 006 / 027 | Total loss: 3.017 | Reg loss: 0.034 | Tree loss: 3.017 | Accuracy: 0.089844 | 0.353 sec/iter
Epoch: 67 | Batch: 007 / 027 | Total loss: 3.132 | Reg loss: 0.034 | Tree loss: 3.132 | Accuracy: 0.095703 | 0.353 sec/iter
Epoch: 6

Epoch: 69 | Batch: 009 / 027 | Total loss: 3.054 | Reg loss: 0.034 | Tree loss: 3.054 | Accuracy: 0.091797 | 0.353 sec/iter
Epoch: 69 | Batch: 010 / 027 | Total loss: 3.008 | Reg loss: 0.034 | Tree loss: 3.008 | Accuracy: 0.132812 | 0.353 sec/iter
Epoch: 69 | Batch: 011 / 027 | Total loss: 3.042 | Reg loss: 0.034 | Tree loss: 3.042 | Accuracy: 0.103516 | 0.353 sec/iter
Epoch: 69 | Batch: 012 / 027 | Total loss: 3.008 | Reg loss: 0.034 | Tree loss: 3.008 | Accuracy: 0.107422 | 0.353 sec/iter
Epoch: 69 | Batch: 013 / 027 | Total loss: 2.997 | Reg loss: 0.034 | Tree loss: 2.997 | Accuracy: 0.103516 | 0.353 sec/iter
Epoch: 69 | Batch: 014 / 027 | Total loss: 2.987 | Reg loss: 0.034 | Tree loss: 2.987 | Accuracy: 0.132812 | 0.353 sec/iter
Epoch: 69 | Batch: 015 / 027 | Total loss: 2.979 | Reg loss: 0.034 | Tree loss: 2.979 | Accuracy: 0.089844 | 0.353 sec/iter
Epoch: 69 | Batch: 016 / 027 | Total loss: 3.004 | Reg loss: 0.034 | Tree loss: 3.004 | Accuracy: 0.097656 | 0.353 sec/iter
Epoch: 6

Epoch: 71 | Batch: 018 / 027 | Total loss: 2.922 | Reg loss: 0.034 | Tree loss: 2.922 | Accuracy: 0.091797 | 0.353 sec/iter
Epoch: 71 | Batch: 019 / 027 | Total loss: 2.940 | Reg loss: 0.034 | Tree loss: 2.940 | Accuracy: 0.121094 | 0.353 sec/iter
Epoch: 71 | Batch: 020 / 027 | Total loss: 2.981 | Reg loss: 0.034 | Tree loss: 2.981 | Accuracy: 0.095703 | 0.353 sec/iter
Epoch: 71 | Batch: 021 / 027 | Total loss: 2.936 | Reg loss: 0.034 | Tree loss: 2.936 | Accuracy: 0.091797 | 0.353 sec/iter
Epoch: 71 | Batch: 022 / 027 | Total loss: 2.921 | Reg loss: 0.034 | Tree loss: 2.921 | Accuracy: 0.074219 | 0.353 sec/iter
Epoch: 71 | Batch: 023 / 027 | Total loss: 2.980 | Reg loss: 0.034 | Tree loss: 2.980 | Accuracy: 0.085938 | 0.353 sec/iter
Epoch: 71 | Batch: 024 / 027 | Total loss: 2.942 | Reg loss: 0.034 | Tree loss: 2.942 | Accuracy: 0.093750 | 0.353 sec/iter
Epoch: 71 | Batch: 025 / 027 | Total loss: 2.932 | Reg loss: 0.034 | Tree loss: 2.932 | Accuracy: 0.097656 | 0.353 sec/iter
Epoch: 7

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 74 | Batch: 000 / 027 | Total loss: 3.080 | Reg loss: 0.033 | Tree loss: 3.080 | Accuracy: 0.107422 | 0.353 sec/iter
Epoch: 74 | Batch: 001 / 027 | Total loss: 3.115 | Reg loss: 0.033 | Tree loss: 3.115 | Accuracy: 0.107422 | 0.353 sec/iter
Epoch: 74 | Batch: 002 / 027 | Total loss: 3.112 | Reg loss: 0.033 | Tree loss: 3.112 | Accuracy: 0.095703 | 0.353 sec/iter
Epoch: 74 | Batch: 003 / 027 | Total loss: 3.152 | Reg loss: 0.033 | Tree loss: 3.152 | Accuracy: 0.080078 | 0.353 sec/iter
Epoch: 74 | Batch: 004 / 027 | Total loss: 3.041 | Reg loss: 0.033 | Tree loss: 3.041 | Accuracy: 0.097656 | 0.353 sec/iter
Epoch: 74 | Batch: 005 / 027 | Total loss: 3.023 | Reg loss: 0.033 | Tree loss: 3.023 | Accuracy: 0.101562 | 0.353 sec/iter
Epoch: 74 | Batch: 006

Epoch: 76 | Batch: 007 / 027 | Total loss: 3.051 | Reg loss: 0.033 | Tree loss: 3.051 | Accuracy: 0.146484 | 0.353 sec/iter
Epoch: 76 | Batch: 008 / 027 | Total loss: 3.050 | Reg loss: 0.033 | Tree loss: 3.050 | Accuracy: 0.109375 | 0.353 sec/iter
Epoch: 76 | Batch: 009 / 027 | Total loss: 3.029 | Reg loss: 0.033 | Tree loss: 3.029 | Accuracy: 0.107422 | 0.353 sec/iter
Epoch: 76 | Batch: 010 / 027 | Total loss: 3.095 | Reg loss: 0.033 | Tree loss: 3.095 | Accuracy: 0.070312 | 0.353 sec/iter
Epoch: 76 | Batch: 011 / 027 | Total loss: 3.027 | Reg loss: 0.033 | Tree loss: 3.027 | Accuracy: 0.109375 | 0.353 sec/iter
Epoch: 76 | Batch: 012 / 027 | Total loss: 2.987 | Reg loss: 0.033 | Tree loss: 2.987 | Accuracy: 0.107422 | 0.353 sec/iter
Epoch: 76 | Batch: 013 / 027 | Total loss: 2.998 | Reg loss: 0.033 | Tree loss: 2.998 | Accuracy: 0.117188 | 0.353 sec/iter
Epoch: 76 | Batch: 014 / 027 | Total loss: 3.025 | Reg loss: 0.034 | Tree loss: 3.025 | Accuracy: 0.111328 | 0.354 sec/iter
Epoch: 7

Epoch: 78 | Batch: 016 / 027 | Total loss: 2.992 | Reg loss: 0.033 | Tree loss: 2.992 | Accuracy: 0.105469 | 0.353 sec/iter
Epoch: 78 | Batch: 017 / 027 | Total loss: 2.955 | Reg loss: 0.034 | Tree loss: 2.955 | Accuracy: 0.126953 | 0.353 sec/iter
Epoch: 78 | Batch: 018 / 027 | Total loss: 2.971 | Reg loss: 0.034 | Tree loss: 2.971 | Accuracy: 0.101562 | 0.353 sec/iter
Epoch: 78 | Batch: 019 / 027 | Total loss: 2.943 | Reg loss: 0.034 | Tree loss: 2.943 | Accuracy: 0.080078 | 0.353 sec/iter
Epoch: 78 | Batch: 020 / 027 | Total loss: 2.945 | Reg loss: 0.034 | Tree loss: 2.945 | Accuracy: 0.115234 | 0.353 sec/iter
Epoch: 78 | Batch: 021 / 027 | Total loss: 2.907 | Reg loss: 0.034 | Tree loss: 2.907 | Accuracy: 0.097656 | 0.353 sec/iter
Epoch: 78 | Batch: 022 / 027 | Total loss: 2.879 | Reg loss: 0.034 | Tree loss: 2.879 | Accuracy: 0.109375 | 0.353 sec/iter
Epoch: 78 | Batch: 023 / 027 | Total loss: 2.899 | Reg loss: 0.034 | Tree loss: 2.899 | Accuracy: 0.107422 | 0.353 sec/iter
Epoch: 7

Epoch: 80 | Batch: 025 / 027 | Total loss: 2.906 | Reg loss: 0.034 | Tree loss: 2.906 | Accuracy: 0.099609 | 0.353 sec/iter
Epoch: 80 | Batch: 026 / 027 | Total loss: 3.005 | Reg loss: 0.034 | Tree loss: 3.005 | Accuracy: 0.083333 | 0.353 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 81 | Batch: 000 / 027 | Total loss: 3.091 | Reg loss: 0.033 | Tree loss: 3.091 | Accuracy: 0.089844 | 0.354 sec/iter
Epoch: 81 | Batch: 001 / 027 | Total loss: 3.077 | Reg loss: 0.033 | Tree loss: 3.077 | Accuracy: 0.125000 | 0.354 sec/iter
Epoch: 81 | Batch: 002 / 027 | Total loss: 3.073 | Reg loss: 0.033 | Tree loss: 3.073 | Accuracy: 0.095703 | 0.354 sec/iter
Epoch: 81 | Batch: 003 / 027 | Total loss: 3.132 | Reg loss: 0.033 | Tree loss: 3.132 | Accuracy: 0.121094 | 0.354 sec/iter
Epoch: 81 | Batch: 004

Epoch: 83 | Batch: 005 / 027 | Total loss: 3.091 | Reg loss: 0.033 | Tree loss: 3.091 | Accuracy: 0.085938 | 0.354 sec/iter
Epoch: 83 | Batch: 006 / 027 | Total loss: 2.982 | Reg loss: 0.033 | Tree loss: 2.982 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 83 | Batch: 007 / 027 | Total loss: 3.001 | Reg loss: 0.033 | Tree loss: 3.001 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 83 | Batch: 008 / 027 | Total loss: 3.057 | Reg loss: 0.033 | Tree loss: 3.057 | Accuracy: 0.113281 | 0.354 sec/iter
Epoch: 83 | Batch: 009 / 027 | Total loss: 3.044 | Reg loss: 0.033 | Tree loss: 3.044 | Accuracy: 0.113281 | 0.354 sec/iter
Epoch: 83 | Batch: 010 / 027 | Total loss: 2.976 | Reg loss: 0.033 | Tree loss: 2.976 | Accuracy: 0.091797 | 0.354 sec/iter
Epoch: 83 | Batch: 011 / 027 | Total loss: 2.969 | Reg loss: 0.033 | Tree loss: 2.969 | Accuracy: 0.091797 | 0.354 sec/iter
Epoch: 83 | Batch: 012 / 027 | Total loss: 2.967 | Reg loss: 0.033 | Tree loss: 2.967 | Accuracy: 0.095703 | 0.354 sec/iter
Epoch: 8

Epoch: 85 | Batch: 014 / 027 | Total loss: 2.962 | Reg loss: 0.033 | Tree loss: 2.962 | Accuracy: 0.125000 | 0.354 sec/iter
Epoch: 85 | Batch: 015 / 027 | Total loss: 2.919 | Reg loss: 0.033 | Tree loss: 2.919 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 85 | Batch: 016 / 027 | Total loss: 2.945 | Reg loss: 0.033 | Tree loss: 2.945 | Accuracy: 0.099609 | 0.354 sec/iter
Epoch: 85 | Batch: 017 / 027 | Total loss: 2.908 | Reg loss: 0.033 | Tree loss: 2.908 | Accuracy: 0.103516 | 0.354 sec/iter
Epoch: 85 | Batch: 018 / 027 | Total loss: 2.888 | Reg loss: 0.033 | Tree loss: 2.888 | Accuracy: 0.097656 | 0.354 sec/iter
Epoch: 85 | Batch: 019 / 027 | Total loss: 2.986 | Reg loss: 0.033 | Tree loss: 2.986 | Accuracy: 0.089844 | 0.354 sec/iter
Epoch: 85 | Batch: 020 / 027 | Total loss: 2.926 | Reg loss: 0.033 | Tree loss: 2.926 | Accuracy: 0.115234 | 0.354 sec/iter
Epoch: 85 | Batch: 021 / 027 | Total loss: 2.928 | Reg loss: 0.033 | Tree loss: 2.928 | Accuracy: 0.095703 | 0.354 sec/iter
Epoch: 8

Epoch: 87 | Batch: 023 / 027 | Total loss: 2.888 | Reg loss: 0.033 | Tree loss: 2.888 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 87 | Batch: 024 / 027 | Total loss: 2.879 | Reg loss: 0.033 | Tree loss: 2.879 | Accuracy: 0.138672 | 0.354 sec/iter
Epoch: 87 | Batch: 025 / 027 | Total loss: 2.912 | Reg loss: 0.033 | Tree loss: 2.912 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 87 | Batch: 026 / 027 | Total loss: 2.768 | Reg loss: 0.033 | Tree loss: 2.768 | Accuracy: 0.000000 | 0.354 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 88 | Batch: 000 / 027 | Total loss: 3.080 | Reg loss: 0.033 | Tree loss: 3.080 | Accuracy: 0.103516 | 0.354 sec/iter
Epoch: 88 | Batch: 001 / 027 | Total loss: 3.061 | Reg loss: 0.033 | Tree loss: 3.061 | Accuracy: 0.074219 | 0.354 sec/iter
Epoch: 88 | Batch: 002

Epoch: 90 | Batch: 003 / 027 | Total loss: 3.029 | Reg loss: 0.033 | Tree loss: 3.029 | Accuracy: 0.105469 | 0.354 sec/iter
Epoch: 90 | Batch: 004 / 027 | Total loss: 3.026 | Reg loss: 0.033 | Tree loss: 3.026 | Accuracy: 0.082031 | 0.354 sec/iter
Epoch: 90 | Batch: 005 / 027 | Total loss: 3.037 | Reg loss: 0.033 | Tree loss: 3.037 | Accuracy: 0.126953 | 0.354 sec/iter
Epoch: 90 | Batch: 006 / 027 | Total loss: 2.983 | Reg loss: 0.033 | Tree loss: 2.983 | Accuracy: 0.101562 | 0.354 sec/iter
Epoch: 90 | Batch: 007 / 027 | Total loss: 2.987 | Reg loss: 0.033 | Tree loss: 2.987 | Accuracy: 0.105469 | 0.354 sec/iter
Epoch: 90 | Batch: 008 / 027 | Total loss: 3.000 | Reg loss: 0.033 | Tree loss: 3.000 | Accuracy: 0.119141 | 0.354 sec/iter
Epoch: 90 | Batch: 009 / 027 | Total loss: 3.001 | Reg loss: 0.033 | Tree loss: 3.001 | Accuracy: 0.105469 | 0.354 sec/iter
Epoch: 90 | Batch: 010 / 027 | Total loss: 2.973 | Reg loss: 0.033 | Tree loss: 2.973 | Accuracy: 0.123047 | 0.354 sec/iter
Epoch: 9

Epoch: 92 | Batch: 012 / 027 | Total loss: 2.961 | Reg loss: 0.033 | Tree loss: 2.961 | Accuracy: 0.113281 | 0.354 sec/iter
Epoch: 92 | Batch: 013 / 027 | Total loss: 2.939 | Reg loss: 0.033 | Tree loss: 2.939 | Accuracy: 0.113281 | 0.354 sec/iter
Epoch: 92 | Batch: 014 / 027 | Total loss: 2.882 | Reg loss: 0.033 | Tree loss: 2.882 | Accuracy: 0.123047 | 0.354 sec/iter
Epoch: 92 | Batch: 015 / 027 | Total loss: 2.927 | Reg loss: 0.033 | Tree loss: 2.927 | Accuracy: 0.087891 | 0.354 sec/iter
Epoch: 92 | Batch: 016 / 027 | Total loss: 2.908 | Reg loss: 0.033 | Tree loss: 2.908 | Accuracy: 0.097656 | 0.354 sec/iter
Epoch: 92 | Batch: 017 / 027 | Total loss: 2.956 | Reg loss: 0.033 | Tree loss: 2.956 | Accuracy: 0.123047 | 0.354 sec/iter
Epoch: 92 | Batch: 018 / 027 | Total loss: 2.900 | Reg loss: 0.033 | Tree loss: 2.900 | Accuracy: 0.101562 | 0.354 sec/iter
Epoch: 92 | Batch: 019 / 027 | Total loss: 2.902 | Reg loss: 0.033 | Tree loss: 2.902 | Accuracy: 0.082031 | 0.354 sec/iter
Epoch: 9

Epoch: 94 | Batch: 021 / 027 | Total loss: 2.906 | Reg loss: 0.033 | Tree loss: 2.906 | Accuracy: 0.099609 | 0.354 sec/iter
Epoch: 94 | Batch: 022 / 027 | Total loss: 2.881 | Reg loss: 0.033 | Tree loss: 2.881 | Accuracy: 0.113281 | 0.354 sec/iter
Epoch: 94 | Batch: 023 / 027 | Total loss: 2.867 | Reg loss: 0.033 | Tree loss: 2.867 | Accuracy: 0.097656 | 0.354 sec/iter
Epoch: 94 | Batch: 024 / 027 | Total loss: 2.889 | Reg loss: 0.033 | Tree loss: 2.889 | Accuracy: 0.117188 | 0.354 sec/iter
Epoch: 94 | Batch: 025 / 027 | Total loss: 2.862 | Reg loss: 0.033 | Tree loss: 2.862 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 94 | Batch: 026 / 027 | Total loss: 2.712 | Reg loss: 0.033 | Tree loss: 2.712 | Accuracy: 0.166667 | 0.354 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 95 | Batch: 000

Epoch: 97 | Batch: 001 / 027 | Total loss: 3.044 | Reg loss: 0.032 | Tree loss: 3.044 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 97 | Batch: 002 / 027 | Total loss: 3.063 | Reg loss: 0.032 | Tree loss: 3.063 | Accuracy: 0.101562 | 0.354 sec/iter
Epoch: 97 | Batch: 003 / 027 | Total loss: 3.054 | Reg loss: 0.032 | Tree loss: 3.054 | Accuracy: 0.093750 | 0.354 sec/iter
Epoch: 97 | Batch: 004 / 027 | Total loss: 3.054 | Reg loss: 0.032 | Tree loss: 3.054 | Accuracy: 0.089844 | 0.354 sec/iter
Epoch: 97 | Batch: 005 / 027 | Total loss: 3.000 | Reg loss: 0.032 | Tree loss: 3.000 | Accuracy: 0.107422 | 0.354 sec/iter
Epoch: 97 | Batch: 006 / 027 | Total loss: 3.048 | Reg loss: 0.032 | Tree loss: 3.048 | Accuracy: 0.095703 | 0.354 sec/iter
Epoch: 97 | Batch: 007 / 027 | Total loss: 2.966 | Reg loss: 0.032 | Tree loss: 2.966 | Accuracy: 0.119141 | 0.354 sec/iter
Epoch: 97 | Batch: 008 / 027 | Total loss: 2.990 | Reg loss: 0.032 | Tree loss: 2.990 | Accuracy: 0.123047 | 0.354 sec/iter
Epoch: 9

Epoch: 99 | Batch: 010 / 027 | Total loss: 3.009 | Reg loss: 0.032 | Tree loss: 3.009 | Accuracy: 0.126953 | 0.354 sec/iter
Epoch: 99 | Batch: 011 / 027 | Total loss: 2.961 | Reg loss: 0.032 | Tree loss: 2.961 | Accuracy: 0.103516 | 0.354 sec/iter
Epoch: 99 | Batch: 012 / 027 | Total loss: 2.970 | Reg loss: 0.032 | Tree loss: 2.970 | Accuracy: 0.095703 | 0.354 sec/iter
Epoch: 99 | Batch: 013 / 027 | Total loss: 2.915 | Reg loss: 0.032 | Tree loss: 2.915 | Accuracy: 0.101562 | 0.354 sec/iter
Epoch: 99 | Batch: 014 / 027 | Total loss: 2.969 | Reg loss: 0.032 | Tree loss: 2.969 | Accuracy: 0.105469 | 0.354 sec/iter
Epoch: 99 | Batch: 015 / 027 | Total loss: 2.939 | Reg loss: 0.032 | Tree loss: 2.939 | Accuracy: 0.115234 | 0.354 sec/iter
Epoch: 99 | Batch: 016 / 027 | Total loss: 2.939 | Reg loss: 0.032 | Tree loss: 2.939 | Accuracy: 0.091797 | 0.354 sec/iter
Epoch: 99 | Batch: 017 / 027 | Total loss: 2.899 | Reg loss: 0.032 | Tree loss: 2.899 | Accuracy: 0.109375 | 0.354 sec/iter
Epoch: 9

In [33]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [34]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [35]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 7.625


# Extract Rules

# Accumulate samples in the leaves

In [36]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 160


In [37]:
method = 'greedy'

In [38]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [39]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


10132
3192
Average comprehensibility: 36.4
std comprehensibility: 4.7791212581394085
var comprehensibility: 22.84
minimum comprehensibility: 18
maximum comprehensibility: 44
