In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 128
tree_depth = 12
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.222156524658203 | KNN Loss: 6.2308759689331055 | BCE Loss: 1.9912800788879395
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.250527381896973 | KNN Loss: 6.230742454528809 | BCE Loss: 2.019784927368164
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.177160263061523 | KNN Loss: 6.230837821960449 | BCE Loss: 1.9463227987289429
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.158607482910156 | KNN Loss: 6.2306413650512695 | BCE Loss: 1.9279663562774658
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.189087867736816 | KNN Loss: 6.230562686920166 | BCE Loss: 1.9585249423980713
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.206348419189453 | KNN Loss: 6.230489253997803 | BCE Loss: 1.97585928440094
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.168549537658691 | KNN Loss: 6.230339050292969 | BCE Loss: 1.9382104873657227
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.182073593139648 | KNN Loss: 6.230133056640625 | BCE Loss: 1.951940

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.244460582733154 | KNN Loss: 6.120423793792725 | BCE Loss: 1.1240366697311401
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.239505767822266 | KNN Loss: 6.110269069671631 | BCE Loss: 1.1292364597320557
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.203968048095703 | KNN Loss: 6.089426040649414 | BCE Loss: 1.11454176902771
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 7.219274044036865 | KNN Loss: 6.091695785522461 | BCE Loss: 1.1275783777236938
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 7.180758476257324 | KNN Loss: 6.0650410652160645 | BCE Loss: 1.1157176494598389
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 7.187288284301758 | KNN Loss: 6.0549516677856445 | BCE Loss: 1.1323363780975342
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 7.148846626281738 | KNN Loss: 6.03350305557251 | BCE Loss: 1.115343689918518
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 7.079246997833252 | KNN Loss: 5.996521472930908 | BCE Loss: 

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.100081443786621 | KNN Loss: 5.0664381980896 | BCE Loss: 1.033643126487732
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.133056640625 | KNN Loss: 5.060835361480713 | BCE Loss: 1.0722215175628662
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.144747734069824 | KNN Loss: 5.062102794647217 | BCE Loss: 1.0826447010040283
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.1050519943237305 | KNN Loss: 5.070614814758301 | BCE Loss: 1.0344369411468506
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.1240739822387695 | KNN Loss: 5.068467617034912 | BCE Loss: 1.0556062459945679
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.120849609375 | KNN Loss: 5.079685688018799 | BCE Loss: 1.0411639213562012
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.108098983764648 | KNN Loss: 5.069944858551025 | BCE Loss: 1.038154125213623
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.146769046783447 | KNN Loss: 5.057798862457275 | BCE Loss: 1.088

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.132634162902832 | KNN Loss: 5.058141708374023 | BCE Loss: 1.0744925737380981
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.13221549987793 | KNN Loss: 5.069360256195068 | BCE Loss: 1.0628553628921509
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.0787672996521 | KNN Loss: 5.034670352935791 | BCE Loss: 1.0440969467163086
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.0907511711120605 | KNN Loss: 5.040624618530273 | BCE Loss: 1.050126552581787
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.109333515167236 | KNN Loss: 5.041493892669678 | BCE Loss: 1.067839503288269
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.110171318054199 | KNN Loss: 5.057861804962158 | BCE Loss: 1.052309274673462
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.108434677124023 | KNN Loss: 5.054862976074219 | BCE Loss: 1.0535714626312256
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.096025466918945 | KNN Loss: 5.044345855712891 | BCE Loss: 1.

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.101199626922607 | KNN Loss: 5.0482258796691895 | BCE Loss: 1.0529736280441284
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.084390163421631 | KNN Loss: 5.035100936889648 | BCE Loss: 1.0492892265319824
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.100293159484863 | KNN Loss: 5.055597305297852 | BCE Loss: 1.0446959733963013
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.0953521728515625 | KNN Loss: 5.036495208740234 | BCE Loss: 1.0588572025299072
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.095643997192383 | KNN Loss: 5.040556907653809 | BCE Loss: 1.0550869703292847
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.09891414642334 | KNN Loss: 5.032977104187012 | BCE Loss: 1.0659370422363281
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.089578628540039 | KNN Loss: 5.052615642547607 | BCE Loss: 1.036962866783142
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 6.122159004211426 | KNN Loss: 5.065701961517334 | BCE Los

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.062124252319336 | KNN Loss: 5.033090114593506 | BCE Loss: 1.029033899307251
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.083619117736816 | KNN Loss: 5.032845973968506 | BCE Loss: 1.0507731437683105
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.0724873542785645 | KNN Loss: 5.0229926109313965 | BCE Loss: 1.0494946241378784
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.0718302726745605 | KNN Loss: 5.022273063659668 | BCE Loss: 1.049557089805603
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.0956268310546875 | KNN Loss: 5.032552242279053 | BCE Loss: 1.0630748271942139
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.079458236694336 | KNN Loss: 5.040228843688965 | BCE Loss: 1.039229154586792
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.079995632171631 | KNN Loss: 5.024451732635498 | BCE Loss: 1.0555437803268433
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 6.05195951461792 | KNN Loss: 5.022393226623535 | BCE Loss

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.075306415557861 | KNN Loss: 5.015254020690918 | BCE Loss: 1.0600523948669434
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.068122863769531 | KNN Loss: 5.027667045593262 | BCE Loss: 1.040455937385559
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.065857410430908 | KNN Loss: 5.02804708480835 | BCE Loss: 1.037810206413269
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.082126617431641 | KNN Loss: 5.04140043258667 | BCE Loss: 1.0407260656356812
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.1417083740234375 | KNN Loss: 5.075972557067871 | BCE Loss: 1.0657358169555664
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.149733543395996 | KNN Loss: 5.080692768096924 | BCE Loss: 1.0690405368804932
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.125180721282959 | KNN Loss: 5.075532913208008 | BCE Loss: 1.0496478080749512
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 6.063169479370117 | KNN Loss: 5.040223121643066 | BCE Loss: 1.

Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 6.082497596740723 | KNN Loss: 5.045161724090576 | BCE Loss: 1.0373361110687256
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.061762809753418 | KNN Loss: 5.039983749389648 | BCE Loss: 1.0217788219451904
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.069907188415527 | KNN Loss: 5.0101399421691895 | BCE Loss: 1.0597673654556274
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.053818702697754 | KNN Loss: 5.036397933959961 | BCE Loss: 1.0174205303192139
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.0723395347595215 | KNN Loss: 5.036162853240967 | BCE Loss: 1.0361768007278442
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.066215991973877 | KNN Loss: 5.015169620513916 | BCE Loss: 1.0510462522506714
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.097414016723633 | KNN Loss: 5.041859149932861 | BCE Loss: 1.0555546283721924
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 6.119195461273193 | KNN Loss: 5.044164180755615 | BCE L

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.0418195724487305 | KNN Loss: 5.011046409606934 | BCE Loss: 1.0307729244232178
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.027312278747559 | KNN Loss: 5.017266750335693 | BCE Loss: 1.0100456476211548
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.09340763092041 | KNN Loss: 5.028271198272705 | BCE Loss: 1.065136432647705
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.071282863616943 | KNN Loss: 5.018850803375244 | BCE Loss: 1.0524321794509888
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.082579612731934 | KNN Loss: 5.024210453033447 | BCE Loss: 1.0583691596984863
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.052914619445801 | KNN Loss: 5.029180526733398 | BCE Loss: 1.0237343311309814
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 6.050444602966309 | KNN Loss: 5.01887321472168 | BCE Loss: 1.031571388244629
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 6.111946105957031 | KNN Loss: 5.057621479034424 | BCE Loss: 

Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.06715202331543 | KNN Loss: 5.02771520614624 | BCE Loss: 1.0394368171691895
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.055013179779053 | KNN Loss: 5.007699012756348 | BCE Loss: 1.0473142862319946
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.033638000488281 | KNN Loss: 5.0281291007995605 | BCE Loss: 1.0055087804794312
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.043074607849121 | KNN Loss: 5.022216320037842 | BCE Loss: 1.0208581686019897
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.109987735748291 | KNN Loss: 5.0679521560668945 | BCE Loss: 1.0420355796813965
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 6.096314907073975 | KNN Loss: 5.042057037353516 | BCE Loss: 1.0542577505111694
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 6.046238899230957 | KNN Loss: 5.041769504547119 | BCE Loss: 1.0044695138931274
Epoch 98 / 500 | iteration 15 / 30 | Total Loss: 6.0885114669799805 | KNN Loss: 5.041902542114258 | BCE Lo

Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.069351673126221 | KNN Loss: 5.03022575378418 | BCE Loss: 1.0391258001327515
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.065885543823242 | KNN Loss: 5.019021034240723 | BCE Loss: 1.0468647480010986
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.047036170959473 | KNN Loss: 5.007876873016357 | BCE Loss: 1.0391590595245361
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.0364556312561035 | KNN Loss: 5.002388000488281 | BCE Loss: 1.0340677499771118
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.079538345336914 | KNN Loss: 5.028631210327148 | BCE Loss: 1.0509068965911865
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 6.06098747253418 | KNN Loss: 5.022965908050537 | BCE Loss: 1.0380216836929321
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 6.075107097625732 | KNN Loss: 5.0250749588012695 | BCE Loss: 1.0500322580337524
Epoch 109 / 500 | iteration 5 / 30 | Total Loss: 6.062103748321533 | KNN Loss: 5.023305892944336 | B

Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.092315673828125 | KNN Loss: 5.0320563316345215 | BCE Loss: 1.060259461402893
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.067868232727051 | KNN Loss: 5.0366034507751465 | BCE Loss: 1.0312647819519043
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.031923770904541 | KNN Loss: 5.015937328338623 | BCE Loss: 1.0159865617752075
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.051070213317871 | KNN Loss: 5.031985759735107 | BCE Loss: 1.0190842151641846
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.126895904541016 | KNN Loss: 5.071123123168945 | BCE Loss: 1.0557730197906494
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 6.031171798706055 | KNN Loss: 5.012972354888916 | BCE Loss: 1.0181993246078491
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 6.123788356781006 | KNN Loss: 5.079809188842773 | BCE Loss: 1.0439791679382324
Epoch 119 / 500 | iteration 25 / 30 | Total Loss: 6.052117347717285 | KNN Loss: 5.0202717781066895

Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.035946846008301 | KNN Loss: 5.021030902862549 | BCE Loss: 1.014915943145752
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.060025215148926 | KNN Loss: 5.039061546325684 | BCE Loss: 1.020963430404663
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.035989761352539 | KNN Loss: 5.024127006530762 | BCE Loss: 1.0118627548217773
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.082831859588623 | KNN Loss: 5.0340657234191895 | BCE Loss: 1.0487661361694336
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 6.04049015045166 | KNN Loss: 5.009990692138672 | BCE Loss: 1.0304994583129883
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 6.04350471496582 | KNN Loss: 5.023612022399902 | BCE Loss: 1.0198924541473389
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 6.034991264343262 | KNN Loss: 5.013096332550049 | BCE Loss: 1.021894931793213
Epoch 130 / 500 | iteration 15 / 30 | Total Loss: 6.052977085113525 | KNN Loss: 5.031552314758301 | BCE

Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.071063995361328 | KNN Loss: 5.017164707183838 | BCE Loss: 1.0538992881774902
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.044007301330566 | KNN Loss: 5.017685890197754 | BCE Loss: 1.026321291923523
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.026648044586182 | KNN Loss: 5.009760856628418 | BCE Loss: 1.0168873071670532
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.05915641784668 | KNN Loss: 5.034256935119629 | BCE Loss: 1.0248994827270508
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.070659637451172 | KNN Loss: 5.028651237487793 | BCE Loss: 1.0420082807540894
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 6.077164649963379 | KNN Loss: 5.040660381317139 | BCE Loss: 1.0365042686462402
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 6.098871231079102 | KNN Loss: 5.047292709350586 | BCE Loss: 1.0515785217285156
Epoch 141 / 500 | iteration 5 / 30 | Total Loss: 6.071250915527344 | KNN Loss: 5.045877456665039 | BCE

Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.098261833190918 | KNN Loss: 5.060215473175049 | BCE Loss: 1.0380464792251587
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.112943172454834 | KNN Loss: 5.0787672996521 | BCE Loss: 1.0341758728027344
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.051935195922852 | KNN Loss: 5.019477367401123 | BCE Loss: 1.0324578285217285
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.022839546203613 | KNN Loss: 5.011096477508545 | BCE Loss: 1.0117428302764893
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.067074775695801 | KNN Loss: 5.031554698944092 | BCE Loss: 1.0355199575424194
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 6.080998420715332 | KNN Loss: 5.021121501922607 | BCE Loss: 1.0598770380020142
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 6.050303936004639 | KNN Loss: 5.024550914764404 | BCE Loss: 1.025753140449524
Epoch 151 / 500 | iteration 25 / 30 | Total Loss: 6.054780006408691 | KNN Loss: 5.0077080726623535 | B

Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.029268264770508 | KNN Loss: 5.0188069343566895 | BCE Loss: 1.0104610919952393
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.009766578674316 | KNN Loss: 5.0188307762146 | BCE Loss: 0.9909359216690063
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.075444221496582 | KNN Loss: 5.044018268585205 | BCE Loss: 1.031425952911377
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.0405120849609375 | KNN Loss: 5.004584312438965 | BCE Loss: 1.0359277725219727
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.047030448913574 | KNN Loss: 5.023823261260986 | BCE Loss: 1.0232073068618774
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 6.016084671020508 | KNN Loss: 5.027547359466553 | BCE Loss: 0.9885373115539551
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 6.045001983642578 | KNN Loss: 5.013946533203125 | BCE Loss: 1.0310556888580322
Epoch 162 / 500 | iteration 15 / 30 | Total Loss: 6.081828594207764 | KNN Loss: 5.041665077209473 | 

Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.048678398132324 | KNN Loss: 5.02221155166626 | BCE Loss: 1.0264666080474854
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.078570365905762 | KNN Loss: 5.0205559730529785 | BCE Loss: 1.0580143928527832
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.0734429359436035 | KNN Loss: 5.018912315368652 | BCE Loss: 1.0545306205749512
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.033078670501709 | KNN Loss: 5.019141674041748 | BCE Loss: 1.0139371156692505
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.043820381164551 | KNN Loss: 4.998090744018555 | BCE Loss: 1.0457298755645752
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 6.021022319793701 | KNN Loss: 5.012973308563232 | BCE Loss: 1.0080488920211792
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 6.018740653991699 | KNN Loss: 5.00325870513916 | BCE Loss: 1.015481948852539
Epoch 173 / 500 | iteration 5 / 30 | Total Loss: 6.057538986206055 | KNN Loss: 5.027087211608887 | BC

Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.021491050720215 | KNN Loss: 5.0070929527282715 | BCE Loss: 1.0143979787826538
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.053244113922119 | KNN Loss: 5.023173809051514 | BCE Loss: 1.030070424079895
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.028314113616943 | KNN Loss: 5.002669811248779 | BCE Loss: 1.0256441831588745
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.0571746826171875 | KNN Loss: 5.023507118225098 | BCE Loss: 1.0336676836013794
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.043197154998779 | KNN Loss: 5.002577781677246 | BCE Loss: 1.0406192541122437
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 6.066444396972656 | KNN Loss: 5.035834312438965 | BCE Loss: 1.0306100845336914
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 6.048348426818848 | KNN Loss: 5.005783557891846 | BCE Loss: 1.0425646305084229
Epoch 183 / 500 | iteration 25 / 30 | Total Loss: 6.046239852905273 | KNN Loss: 5.009491920471191 

Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.106555938720703 | KNN Loss: 5.02431583404541 | BCE Loss: 1.0822398662567139
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.028464317321777 | KNN Loss: 4.996397495269775 | BCE Loss: 1.032067060470581
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.021909713745117 | KNN Loss: 5.008660793304443 | BCE Loss: 1.0132488012313843
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.032464027404785 | KNN Loss: 5.013736724853516 | BCE Loss: 1.018727421760559
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.052937030792236 | KNN Loss: 5.023896217346191 | BCE Loss: 1.0290409326553345
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 6.097586631774902 | KNN Loss: 5.05354642868042 | BCE Loss: 1.0440400838851929
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 6.056812286376953 | KNN Loss: 5.034637451171875 | BCE Loss: 1.0221748352050781
Epoch 194 / 500 | iteration 15 / 30 | Total Loss: 6.045169830322266 | KNN Loss: 5.015415668487549 | BCE

Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.083514213562012 | KNN Loss: 5.053910255432129 | BCE Loss: 1.0296037197113037
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.088666915893555 | KNN Loss: 5.049191951751709 | BCE Loss: 1.0394752025604248
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.0593414306640625 | KNN Loss: 5.023741245269775 | BCE Loss: 1.0356003046035767
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.06182861328125 | KNN Loss: 5.030941009521484 | BCE Loss: 1.030887484550476
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.031280517578125 | KNN Loss: 4.998959541320801 | BCE Loss: 1.0323209762573242
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 6.077281951904297 | KNN Loss: 5.042796611785889 | BCE Loss: 1.0344853401184082
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 6.034902572631836 | KNN Loss: 4.999104976654053 | BCE Loss: 1.0357975959777832
Epoch 205 / 500 | iteration 5 / 30 | Total Loss: 6.070522308349609 | KNN Loss: 5.002176284790039 | BC

Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.049067497253418 | KNN Loss: 5.015918731689453 | BCE Loss: 1.0331487655639648
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.085432052612305 | KNN Loss: 5.0448527336120605 | BCE Loss: 1.040579080581665
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.0416951179504395 | KNN Loss: 5.0226287841796875 | BCE Loss: 1.0190664529800415
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.1185407638549805 | KNN Loss: 5.061473846435547 | BCE Loss: 1.0570666790008545
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.062761306762695 | KNN Loss: 5.032864093780518 | BCE Loss: 1.0298972129821777
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 6.04835319519043 | KNN Loss: 5.0264973640441895 | BCE Loss: 1.0218558311462402
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 6.0505146980285645 | KNN Loss: 5.015153408050537 | BCE Loss: 1.0353612899780273
Epoch 215 / 500 | iteration 25 / 30 | Total Loss: 6.068530082702637 | KNN Loss: 5.0397434234619

Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.049723148345947 | KNN Loss: 5.023898601531982 | BCE Loss: 1.0258245468139648
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.085808753967285 | KNN Loss: 5.028022766113281 | BCE Loss: 1.057785987854004
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.072114944458008 | KNN Loss: 5.026945114135742 | BCE Loss: 1.0451695919036865
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.075882911682129 | KNN Loss: 5.018072605133057 | BCE Loss: 1.0578104257583618
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.062229156494141 | KNN Loss: 5.0176897048950195 | BCE Loss: 1.044539451599121
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 6.046594619750977 | KNN Loss: 5.011147499084473 | BCE Loss: 1.035447120666504
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 6.030889511108398 | KNN Loss: 5.00692892074585 | BCE Loss: 1.0239605903625488
Epoch 226 / 500 | iteration 15 / 30 | Total Loss: 6.036469459533691 | KNN Loss: 5.012301445007324 | BC

Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.0289106369018555 | KNN Loss: 5.006335258483887 | BCE Loss: 1.0225754976272583
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.004672527313232 | KNN Loss: 4.996206760406494 | BCE Loss: 1.0084657669067383
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.0486836433410645 | KNN Loss: 5.033788204193115 | BCE Loss: 1.0148953199386597
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.032046794891357 | KNN Loss: 5.002834796905518 | BCE Loss: 1.0292118787765503
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.051270961761475 | KNN Loss: 5.013434410095215 | BCE Loss: 1.0378365516662598
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 6.066123008728027 | KNN Loss: 5.046825885772705 | BCE Loss: 1.0192968845367432
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 6.078777313232422 | KNN Loss: 5.017185688018799 | BCE Loss: 1.061591625213623
Epoch 237 / 500 | iteration 5 / 30 | Total Loss: 6.027895927429199 | KNN Loss: 5.010066032409668 | 

Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.067867279052734 | KNN Loss: 5.038124084472656 | BCE Loss: 1.0297430753707886
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.068826675415039 | KNN Loss: 5.039900302886963 | BCE Loss: 1.0289264917373657
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.0842437744140625 | KNN Loss: 5.02943229675293 | BCE Loss: 1.054811716079712
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.1016740798950195 | KNN Loss: 5.055055618286133 | BCE Loss: 1.0466185808181763
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.04471492767334 | KNN Loss: 5.0164361000061035 | BCE Loss: 1.0282788276672363
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 6.029384613037109 | KNN Loss: 5.002326965332031 | BCE Loss: 1.0270575284957886
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 6.017473220825195 | KNN Loss: 5.021862506866455 | BCE Loss: 0.9956104755401611
Epoch 247 / 500 | iteration 25 / 30 | Total Loss: 6.046443462371826 | KNN Loss: 5.009698867797852 |

Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.052661895751953 | KNN Loss: 5.01886510848999 | BCE Loss: 1.0337966680526733
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.035187721252441 | KNN Loss: 5.028391361236572 | BCE Loss: 1.0067964792251587
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.048271179199219 | KNN Loss: 5.015400409698486 | BCE Loss: 1.032870888710022
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.053013801574707 | KNN Loss: 5.009747505187988 | BCE Loss: 1.0432660579681396
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.090051174163818 | KNN Loss: 5.046652317047119 | BCE Loss: 1.0433989763259888
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 6.05143404006958 | KNN Loss: 5.032796382904053 | BCE Loss: 1.0186375379562378
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 6.054821968078613 | KNN Loss: 5.006478786468506 | BCE Loss: 1.0483429431915283
Epoch 258 / 500 | iteration 15 / 30 | Total Loss: 6.065473556518555 | KNN Loss: 5.016833305358887 | BC

Epoch   268: reducing learning rate of group 0 to 2.3738e-05.
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.013206481933594 | KNN Loss: 5.007236957550049 | BCE Loss: 1.0059692859649658
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.022522926330566 | KNN Loss: 5.009215831756592 | BCE Loss: 1.0133068561553955
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.049413681030273 | KNN Loss: 5.02718448638916 | BCE Loss: 1.0222289562225342
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.103693962097168 | KNN Loss: 5.0423359870910645 | BCE Loss: 1.0613579750061035
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.030168533325195 | KNN Loss: 5.006434440612793 | BCE Loss: 1.0237340927124023
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 6.044207572937012 | KNN Loss: 5.010908603668213 | BCE Loss: 1.033299207687378
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 5.994176864624023 | KNN Loss: 5.00463342666626 | BCE Loss: 0.9895434975624084
Epoch 269 / 500 | iteration 5 / 30 | Tot

Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.059895992279053 | KNN Loss: 5.015424728393555 | BCE Loss: 1.044471263885498
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.043234348297119 | KNN Loss: 5.009065628051758 | BCE Loss: 1.0341687202453613
Epoch   279: reducing learning rate of group 0 to 1.6616e-05.
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.080356597900391 | KNN Loss: 5.077766418457031 | BCE Loss: 1.0025901794433594
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.06008768081665 | KNN Loss: 5.022638320922852 | BCE Loss: 1.0374494791030884
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.027736186981201 | KNN Loss: 5.014758586883545 | BCE Loss: 1.0129776000976562
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 6.010735034942627 | KNN Loss: 5.001363754272461 | BCE Loss: 1.0093711614608765
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 6.073187351226807 | KNN Loss: 5.014804840087891 | BCE Loss: 1.058382511138916
Epoch 279 / 500 | iteration 25 / 30 | To

Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.102818965911865 | KNN Loss: 5.045796871185303 | BCE Loss: 1.0570220947265625
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.023585319519043 | KNN Loss: 5.017012119293213 | BCE Loss: 1.0065734386444092
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.042482376098633 | KNN Loss: 5.0027384757995605 | BCE Loss: 1.0397436618804932
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.041851043701172 | KNN Loss: 5.0196146965026855 | BCE Loss: 1.0222365856170654
Epoch   290: reducing learning rate of group 0 to 1.1632e-05.
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.042211055755615 | KNN Loss: 5.008718490600586 | BCE Loss: 1.0334926843643188
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 6.054404258728027 | KNN Loss: 5.028841018676758 | BCE Loss: 1.0255634784698486
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 6.045506477355957 | KNN Loss: 5.015398025512695 | BCE Loss: 1.0301082134246826
Epoch 290 / 500 | iteration 15 / 30

Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.036981105804443 | KNN Loss: 5.023275852203369 | BCE Loss: 1.0137051343917847
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.062219142913818 | KNN Loss: 5.021985054016113 | BCE Loss: 1.040234088897705
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.0905866622924805 | KNN Loss: 5.0359578132629395 | BCE Loss: 1.054628610610962
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.052672386169434 | KNN Loss: 5.053261756896973 | BCE Loss: 0.9994104504585266
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.05092716217041 | KNN Loss: 5.018309593200684 | BCE Loss: 1.032617449760437
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 6.086071968078613 | KNN Loss: 5.0401153564453125 | BCE Loss: 1.0459568500518799
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 6.045450687408447 | KNN Loss: 5.013944625854492 | BCE Loss: 1.031506061553955
Epoch 301 / 500 | iteration 5 / 30 | Total Loss: 6.048126220703125 | KNN Loss: 5.019384860992432 | BCE

Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.025239944458008 | KNN Loss: 5.018309116363525 | BCE Loss: 1.0069305896759033
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.069162845611572 | KNN Loss: 5.013923645019531 | BCE Loss: 1.055239200592041
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.052623271942139 | KNN Loss: 5.007540225982666 | BCE Loss: 1.0450830459594727
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.046937942504883 | KNN Loss: 5.0088791847229 | BCE Loss: 1.0380585193634033
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.105905055999756 | KNN Loss: 5.031917095184326 | BCE Loss: 1.0739879608154297
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 6.012267589569092 | KNN Loss: 5.016502857208252 | BCE Loss: 0.9957647919654846
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 6.043435573577881 | KNN Loss: 5.0229902267456055 | BCE Loss: 1.020445466041565
Epoch 311 / 500 | iteration 25 / 30 | Total Loss: 6.070878028869629 | KNN Loss: 5.058707237243652 | BC

Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.049729824066162 | KNN Loss: 5.013769149780273 | BCE Loss: 1.0359607934951782
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.026420593261719 | KNN Loss: 5.003382205963135 | BCE Loss: 1.0230382680892944
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.021750450134277 | KNN Loss: 5.003814220428467 | BCE Loss: 1.0179364681243896
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.07747745513916 | KNN Loss: 5.028421401977539 | BCE Loss: 1.049055814743042
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.00756311416626 | KNN Loss: 5.012650489807129 | BCE Loss: 0.9949126243591309
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 6.00161600112915 | KNN Loss: 5.0099663734436035 | BCE Loss: 0.9916495084762573
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 6.0433759689331055 | KNN Loss: 5.0170512199401855 | BCE Loss: 1.02632474899292
Epoch 322 / 500 | iteration 15 / 30 | Total Loss: 6.0418853759765625 | KNN Loss: 5.008951187133789 | B

Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.071387767791748 | KNN Loss: 5.019378185272217 | BCE Loss: 1.0520094633102417
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.028693675994873 | KNN Loss: 4.997572898864746 | BCE Loss: 1.031120777130127
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.041281700134277 | KNN Loss: 5.037398338317871 | BCE Loss: 1.0038833618164062
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.046102046966553 | KNN Loss: 5.016395568847656 | BCE Loss: 1.0297064781188965
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.025871753692627 | KNN Loss: 5.012014865875244 | BCE Loss: 1.0138567686080933
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 6.04915714263916 | KNN Loss: 5.024508476257324 | BCE Loss: 1.0246484279632568
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 6.047740936279297 | KNN Loss: 5.001276016235352 | BCE Loss: 1.0464651584625244
Epoch 333 / 500 | iteration 5 / 30 | Total Loss: 6.0127716064453125 | KNN Loss: 4.994142055511475 | BC

Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.048254013061523 | KNN Loss: 5.027513027191162 | BCE Loss: 1.0207409858703613
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.025027751922607 | KNN Loss: 5.022772789001465 | BCE Loss: 1.0022549629211426
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.032893180847168 | KNN Loss: 5.025766372680664 | BCE Loss: 1.007127046585083
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.0258469581604 | KNN Loss: 5.015528678894043 | BCE Loss: 1.010318398475647
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.054093360900879 | KNN Loss: 5.032325267791748 | BCE Loss: 1.0217678546905518
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 6.032695770263672 | KNN Loss: 5.011592864990234 | BCE Loss: 1.0211031436920166
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 6.042458534240723 | KNN Loss: 5.025774002075195 | BCE Loss: 1.0166842937469482
Epoch 343 / 500 | iteration 25 / 30 | Total Loss: 6.083264350891113 | KNN Loss: 5.021119117736816 | BCE

Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.060959815979004 | KNN Loss: 5.019456386566162 | BCE Loss: 1.041503667831421
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.05232048034668 | KNN Loss: 5.015036582946777 | BCE Loss: 1.0372841358184814
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.028977394104004 | KNN Loss: 5.006956100463867 | BCE Loss: 1.0220212936401367
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.021738052368164 | KNN Loss: 5.001918792724609 | BCE Loss: 1.0198190212249756
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.05396842956543 | KNN Loss: 5.034595012664795 | BCE Loss: 1.0193736553192139
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 6.030961036682129 | KNN Loss: 5.007833957672119 | BCE Loss: 1.0231270790100098
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 6.057252883911133 | KNN Loss: 5.042801380157471 | BCE Loss: 1.0144516229629517
Epoch 354 / 500 | iteration 15 / 30 | Total Loss: 6.0871782302856445 | KNN Loss: 5.049203395843506 | B

Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.016395568847656 | KNN Loss: 5.00892972946167 | BCE Loss: 1.0074658393859863
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.087184429168701 | KNN Loss: 5.036960124969482 | BCE Loss: 1.0502244234085083
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.091559886932373 | KNN Loss: 5.040876388549805 | BCE Loss: 1.0506833791732788
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.043370246887207 | KNN Loss: 5.012822151184082 | BCE Loss: 1.030548095703125
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.068790912628174 | KNN Loss: 5.006211280822754 | BCE Loss: 1.0625797510147095
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 6.036464214324951 | KNN Loss: 5.012161731719971 | BCE Loss: 1.0243024826049805
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 6.040223121643066 | KNN Loss: 5.032845973968506 | BCE Loss: 1.0073773860931396
Epoch 365 / 500 | iteration 5 / 30 | Total Loss: 6.0594682693481445 | KNN Loss: 5.018750190734863 | BC

Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.045515060424805 | KNN Loss: 5.020077228546143 | BCE Loss: 1.0254377126693726
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.053659439086914 | KNN Loss: 5.0265398025512695 | BCE Loss: 1.0271198749542236
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.048740386962891 | KNN Loss: 5.017756938934326 | BCE Loss: 1.0309836864471436
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 5.997547149658203 | KNN Loss: 4.988187789916992 | BCE Loss: 1.00935959815979
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.039884567260742 | KNN Loss: 5.011584281921387 | BCE Loss: 1.0283002853393555
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 6.013166904449463 | KNN Loss: 5.000722408294678 | BCE Loss: 1.0124443769454956
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 6.109011650085449 | KNN Loss: 5.045729160308838 | BCE Loss: 1.0632823705673218
Epoch 375 / 500 | iteration 25 / 30 | Total Loss: 6.010015487670898 | KNN Loss: 5.004150867462158 | 

Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.061944961547852 | KNN Loss: 5.025825023651123 | BCE Loss: 1.0361201763153076
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.045200347900391 | KNN Loss: 5.004926681518555 | BCE Loss: 1.0402734279632568
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.037946701049805 | KNN Loss: 5.001579761505127 | BCE Loss: 1.0363671779632568
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.0641961097717285 | KNN Loss: 5.021745204925537 | BCE Loss: 1.0424509048461914
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.062648773193359 | KNN Loss: 5.033629894256592 | BCE Loss: 1.0290188789367676
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 6.101574897766113 | KNN Loss: 5.0872697830200195 | BCE Loss: 1.0143052339553833
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 6.044249057769775 | KNN Loss: 5.012156963348389 | BCE Loss: 1.0320922136306763
Epoch 386 / 500 | iteration 15 / 30 | Total Loss: 6.064578056335449 | KNN Loss: 5.041496276855469

Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.055603504180908 | KNN Loss: 5.023997783660889 | BCE Loss: 1.031605839729309
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.07103157043457 | KNN Loss: 5.035933017730713 | BCE Loss: 1.035098671913147
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.050051212310791 | KNN Loss: 5.040304183959961 | BCE Loss: 1.0097471475601196
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.0619001388549805 | KNN Loss: 5.027299404144287 | BCE Loss: 1.0346004962921143
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.040177345275879 | KNN Loss: 5.004186630249023 | BCE Loss: 1.0359904766082764
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 6.036218166351318 | KNN Loss: 5.019582271575928 | BCE Loss: 1.0166358947753906
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 6.113120079040527 | KNN Loss: 5.0549774169921875 | BCE Loss: 1.058142900466919
Epoch 397 / 500 | iteration 5 / 30 | Total Loss: 6.069126605987549 | KNN Loss: 5.014901161193848 | BCE

Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.100532054901123 | KNN Loss: 5.059808254241943 | BCE Loss: 1.0407239198684692
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.067277431488037 | KNN Loss: 5.031207084655762 | BCE Loss: 1.0360702276229858
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.052545070648193 | KNN Loss: 5.044075012207031 | BCE Loss: 1.0084699392318726
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.014760971069336 | KNN Loss: 5.004143714904785 | BCE Loss: 1.0106174945831299
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.0492448806762695 | KNN Loss: 5.033742904663086 | BCE Loss: 1.0155020952224731
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 6.045732021331787 | KNN Loss: 5.014837741851807 | BCE Loss: 1.03089439868927
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 6.054770469665527 | KNN Loss: 5.023788928985596 | BCE Loss: 1.0309815406799316
Epoch 407 / 500 | iteration 25 / 30 | Total Loss: 6.05964469909668 | KNN Loss: 5.023557662963867 | B

Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.073976516723633 | KNN Loss: 5.0424933433532715 | BCE Loss: 1.0314830541610718
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.073605060577393 | KNN Loss: 5.048844337463379 | BCE Loss: 1.0247606039047241
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.070494651794434 | KNN Loss: 5.025375843048096 | BCE Loss: 1.045119047164917
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.064312934875488 | KNN Loss: 5.021456241607666 | BCE Loss: 1.0428569316864014
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.0180864334106445 | KNN Loss: 5.001406192779541 | BCE Loss: 1.0166804790496826
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 6.026587963104248 | KNN Loss: 5.023806571960449 | BCE Loss: 1.0027815103530884
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 6.039917469024658 | KNN Loss: 5.013638973236084 | BCE Loss: 1.0262783765792847
Epoch 418 / 500 | iteration 15 / 30 | Total Loss: 6.0996174812316895 | KNN Loss: 5.03252649307251 

Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.076841354370117 | KNN Loss: 5.057008266448975 | BCE Loss: 1.0198332071304321
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.023669242858887 | KNN Loss: 5.001410961151123 | BCE Loss: 1.0222581624984741
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.071459770202637 | KNN Loss: 5.016894817352295 | BCE Loss: 1.0545647144317627
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.073071479797363 | KNN Loss: 5.026872158050537 | BCE Loss: 1.0461993217468262
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.101875305175781 | KNN Loss: 5.075786590576172 | BCE Loss: 1.0260885953903198
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 6.150151252746582 | KNN Loss: 5.045929431915283 | BCE Loss: 1.1042218208312988
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 6.076177597045898 | KNN Loss: 5.016018390655518 | BCE Loss: 1.0601592063903809
Epoch 429 / 500 | iteration 5 / 30 | Total Loss: 6.046045303344727 | KNN Loss: 5.01049280166626 | BC

Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.008277893066406 | KNN Loss: 4.991856098175049 | BCE Loss: 1.0164217948913574
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.041222095489502 | KNN Loss: 5.026523113250732 | BCE Loss: 1.0146989822387695
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.041961669921875 | KNN Loss: 5.00853967666626 | BCE Loss: 1.0334219932556152
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.048149585723877 | KNN Loss: 5.025080680847168 | BCE Loss: 1.0230690240859985
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.0717058181762695 | KNN Loss: 5.030248641967773 | BCE Loss: 1.041456937789917
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 6.019771099090576 | KNN Loss: 5.012335300445557 | BCE Loss: 1.0074357986450195
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 6.0255889892578125 | KNN Loss: 5.029683589935303 | BCE Loss: 0.9959052801132202
Epoch 439 / 500 | iteration 25 / 30 | Total Loss: 6.047815322875977 | KNN Loss: 5.0191755294799805 

Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.052653789520264 | KNN Loss: 5.023716449737549 | BCE Loss: 1.0289373397827148
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.122894287109375 | KNN Loss: 5.069304943084717 | BCE Loss: 1.0535892248153687
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.049287796020508 | KNN Loss: 5.017678737640381 | BCE Loss: 1.031609296798706
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.062922477722168 | KNN Loss: 5.031908988952637 | BCE Loss: 1.0310134887695312
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 6.0268168449401855 | KNN Loss: 5.01405668258667 | BCE Loss: 1.0127602815628052
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 6.042359352111816 | KNN Loss: 5.026082992553711 | BCE Loss: 1.0162763595581055
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 6.0299601554870605 | KNN Loss: 5.010076522827148 | BCE Loss: 1.0198837518692017
Epoch 450 / 500 | iteration 15 / 30 | Total Loss: 6.058023929595947 | KNN Loss: 5.025526523590088 |

Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.033576965332031 | KNN Loss: 5.009700775146484 | BCE Loss: 1.0238761901855469
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.110397815704346 | KNN Loss: 5.0677337646484375 | BCE Loss: 1.0426640510559082
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.038633346557617 | KNN Loss: 5.005651950836182 | BCE Loss: 1.0329811573028564
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.037411689758301 | KNN Loss: 5.013095855712891 | BCE Loss: 1.024315595626831
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.04649019241333 | KNN Loss: 5.004507541656494 | BCE Loss: 1.041982650756836
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 6.074123859405518 | KNN Loss: 5.028656005859375 | BCE Loss: 1.045467734336853
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 6.081554889678955 | KNN Loss: 5.041596412658691 | BCE Loss: 1.0399584770202637
Epoch 461 / 500 | iteration 5 / 30 | Total Loss: 6.045753479003906 | KNN Loss: 5.027709007263184 | BCE 

Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.034987926483154 | KNN Loss: 5.00231409072876 | BCE Loss: 1.0326738357543945
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.0556440353393555 | KNN Loss: 5.0288166999816895 | BCE Loss: 1.0268275737762451
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.021641731262207 | KNN Loss: 5.002894401550293 | BCE Loss: 1.018747329711914
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.066617012023926 | KNN Loss: 5.026043891906738 | BCE Loss: 1.040573239326477
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.042706489562988 | KNN Loss: 5.02046012878418 | BCE Loss: 1.0222465991973877
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 6.03903865814209 | KNN Loss: 5.010788917541504 | BCE Loss: 1.0282495021820068
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 6.0509467124938965 | KNN Loss: 5.029558181762695 | BCE Loss: 1.0213885307312012
Epoch 471 / 500 | iteration 25 / 30 | Total Loss: 6.046949863433838 | KNN Loss: 5.011531352996826 | B

Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.034120082855225 | KNN Loss: 5.0099592208862305 | BCE Loss: 1.0241607427597046
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.052731513977051 | KNN Loss: 5.0125813484191895 | BCE Loss: 1.0401499271392822
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.090795040130615 | KNN Loss: 5.0254292488098145 | BCE Loss: 1.0653659105300903
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.014037132263184 | KNN Loss: 5.001946926116943 | BCE Loss: 1.0120902061462402
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.0226545333862305 | KNN Loss: 5.007966995239258 | BCE Loss: 1.0146877765655518
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 6.066821098327637 | KNN Loss: 5.047447204589844 | BCE Loss: 1.019373893737793
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 6.072332859039307 | KNN Loss: 5.006659507751465 | BCE Loss: 1.0656733512878418
Epoch 482 / 500 | iteration 15 / 30 | Total Loss: 6.0557942390441895 | KNN Loss: 5.0212788581848

Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.100092887878418 | KNN Loss: 5.060546875 | BCE Loss: 1.039546012878418
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.032808303833008 | KNN Loss: 5.01011848449707 | BCE Loss: 1.022689938545227
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.03571891784668 | KNN Loss: 5.011983871459961 | BCE Loss: 1.0237348079681396
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.067623138427734 | KNN Loss: 5.043232440948486 | BCE Loss: 1.024390697479248
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.050644874572754 | KNN Loss: 5.029106616973877 | BCE Loss: 1.0215381383895874
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 5.991954803466797 | KNN Loss: 5.000690460205078 | BCE Loss: 0.9912642240524292
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 6.029138565063477 | KNN Loss: 5.0149712562561035 | BCE Loss: 1.014167070388794
Epoch 493 / 500 | iteration 5 / 30 | Total Loss: 6.039982795715332 | KNN Loss: 4.9985671043396 | BCE Loss: 1.04

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.0212,  4.1637,  2.2793,  2.4290,  2.3282,  0.7648,  2.8325,  1.9331,
          1.9076,  2.1568,  2.3861,  2.4074,  0.6492,  1.8556,  1.3833,  1.1246,
          2.4802,  3.3786,  2.0662,  2.5517,  1.3852,  3.1902,  2.0022,  2.8558,
          2.2824,  1.8604,  2.3070,  1.5351,  1.6248,  0.4055, -0.2656,  0.5114,
          0.3085,  1.0195,  1.6012,  1.5376,  1.1399,  3.5268,  0.9174,  1.3869,
          1.0135, -0.7747, -0.2560,  2.5865,  2.2172,  0.8223, -0.1455,  0.1383,
          1.5418,  2.3296,  1.9156,  0.1940,  1.0721,  0.4884, -0.6085,  1.2113,
          1.6052,  1.4865,  0.9008,  1.9952,  0.6163,  0.9264,  0.1962,  1.7886,
          1.4019,  1.8067, -1.8601,  0.4402,  2.3858,  1.7615,  2.2435,  0.4592,
          1.4922,  2.6230,  2.1548,  1.4235,  0.3650,  0.8210,  0.3345,  1.1884,
          0.1844,  0.4736,  1.4085, -0.2858, -0.0079, -1.0756, -2.3881, -0.1343,
          0.5750, -1.8682,  0.5242, -0.0994, -0.5039, -0.8989,  0.6663,  1.3529,
         -0.6817, -0.5918,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 79.29it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 029 | Total loss: 9.631 | Reg loss: 0.014 | Tree loss: 9.631 | Accuracy: 0.000000 | 6.98 sec/iter
Epoch: 00 | Batch: 001 / 029 | Total loss: 9.625 | Reg loss: 0.013 | Tree loss: 9.625 | Accuracy: 0.000000 | 6.279 sec/iter
Epoch: 00 | Batch: 002 / 029 | Total loss: 9.619 | Reg loss: 0.012 | Tree loss: 9.619 | Accuracy: 0.000000 | 6.027 sec/iter
Epoch: 00 | Batch: 003 / 029 | Total loss: 9.614 | Reg loss: 0.011 | Tree loss: 9.614 | Accuracy: 0.000000 | 5.884 sec/iter
Epoch: 00 | Batch: 004 / 029 | Total loss: 9.608 | Reg loss: 0.010 | Tree loss: 9.608 | Accuracy: 0.000000 | 6.131 sec/iter
Epoch: 00 | Batch: 005 / 029 | Total loss: 9.602 | Reg loss: 0.009 | Tree loss: 9.602 | Accuracy: 0.000000 | 6.273 sec/iter
Epoch: 00 | Batch: 006 / 029 | Total loss: 9.596 | Reg loss: 0.009 | Tree loss: 9.596 | A

Epoch: 02 | Batch: 002 / 029 | Total loss: 9.463 | Reg loss: 0.009 | Tree loss: 9.463 | Accuracy: 0.529297 | 6.701 sec/iter
Epoch: 02 | Batch: 003 / 029 | Total loss: 9.463 | Reg loss: 0.009 | Tree loss: 9.463 | Accuracy: 0.492188 | 6.708 sec/iter
Epoch: 02 | Batch: 004 / 029 | Total loss: 9.457 | Reg loss: 0.009 | Tree loss: 9.457 | Accuracy: 0.474609 | 6.715 sec/iter
Epoch: 02 | Batch: 005 / 029 | Total loss: 9.450 | Reg loss: 0.009 | Tree loss: 9.450 | Accuracy: 0.505859 | 6.721 sec/iter
Epoch: 02 | Batch: 006 / 029 | Total loss: 9.443 | Reg loss: 0.010 | Tree loss: 9.443 | Accuracy: 0.498047 | 6.728 sec/iter
Epoch: 02 | Batch: 007 / 029 | Total loss: 9.443 | Reg loss: 0.010 | Tree loss: 9.443 | Accuracy: 0.458984 | 6.733 sec/iter
Epoch: 02 | Batch: 008 / 029 | Total loss: 9.431 | Reg loss: 0.010 | Tree loss: 9.431 | Accuracy: 0.517578 | 6.74 sec/iter
Epoch: 02 | Batch: 009 / 029 | Total loss: 9.427 | Reg loss: 0.011 | Tree loss: 9.427 | Accuracy: 0.484375 | 6.746 sec/iter
Epoch: 02

Epoch: 04 | Batch: 005 / 029 | Total loss: 9.220 | Reg loss: 0.015 | Tree loss: 9.220 | Accuracy: 0.505859 | 6.687 sec/iter
Epoch: 04 | Batch: 006 / 029 | Total loss: 9.204 | Reg loss: 0.015 | Tree loss: 9.204 | Accuracy: 0.521484 | 6.69 sec/iter
Epoch: 04 | Batch: 007 / 029 | Total loss: 9.183 | Reg loss: 0.015 | Tree loss: 9.183 | Accuracy: 0.527344 | 6.693 sec/iter
Epoch: 04 | Batch: 008 / 029 | Total loss: 9.179 | Reg loss: 0.016 | Tree loss: 9.179 | Accuracy: 0.476562 | 6.695 sec/iter
Epoch: 04 | Batch: 009 / 029 | Total loss: 9.157 | Reg loss: 0.016 | Tree loss: 9.157 | Accuracy: 0.498047 | 6.695 sec/iter
Epoch: 04 | Batch: 010 / 029 | Total loss: 9.139 | Reg loss: 0.017 | Tree loss: 9.139 | Accuracy: 0.515625 | 6.696 sec/iter
Epoch: 04 | Batch: 011 / 029 | Total loss: 9.125 | Reg loss: 0.017 | Tree loss: 9.125 | Accuracy: 0.482422 | 6.697 sec/iter
Epoch: 04 | Batch: 012 / 029 | Total loss: 9.099 | Reg loss: 0.017 | Tree loss: 9.099 | Accuracy: 0.515625 | 6.697 sec/iter
Epoch: 04

Epoch: 06 | Batch: 008 / 029 | Total loss: 8.658 | Reg loss: 0.020 | Tree loss: 8.658 | Accuracy: 0.505859 | 6.978 sec/iter
Epoch: 06 | Batch: 009 / 029 | Total loss: 8.648 | Reg loss: 0.020 | Tree loss: 8.648 | Accuracy: 0.494141 | 6.979 sec/iter
Epoch: 06 | Batch: 010 / 029 | Total loss: 8.626 | Reg loss: 0.021 | Tree loss: 8.626 | Accuracy: 0.474609 | 6.979 sec/iter
Epoch: 06 | Batch: 011 / 029 | Total loss: 8.596 | Reg loss: 0.021 | Tree loss: 8.596 | Accuracy: 0.503906 | 6.98 sec/iter
Epoch: 06 | Batch: 012 / 029 | Total loss: 8.548 | Reg loss: 0.021 | Tree loss: 8.548 | Accuracy: 0.535156 | 6.979 sec/iter
Epoch: 06 | Batch: 013 / 029 | Total loss: 8.515 | Reg loss: 0.022 | Tree loss: 8.515 | Accuracy: 0.509766 | 6.978 sec/iter
Epoch: 06 | Batch: 014 / 029 | Total loss: 8.493 | Reg loss: 0.022 | Tree loss: 8.493 | Accuracy: 0.472656 | 6.978 sec/iter
Epoch: 06 | Batch: 015 / 029 | Total loss: 8.478 | Reg loss: 0.022 | Tree loss: 8.478 | Accuracy: 0.466797 | 6.976 sec/iter
Epoch: 06

Epoch: 08 | Batch: 011 / 029 | Total loss: 7.915 | Reg loss: 0.024 | Tree loss: 7.915 | Accuracy: 0.509766 | 7.1 sec/iter
Epoch: 08 | Batch: 012 / 029 | Total loss: 7.900 | Reg loss: 0.025 | Tree loss: 7.900 | Accuracy: 0.503906 | 7.092 sec/iter
Epoch: 08 | Batch: 013 / 029 | Total loss: 7.852 | Reg loss: 0.025 | Tree loss: 7.852 | Accuracy: 0.484375 | 7.084 sec/iter
Epoch: 08 | Batch: 014 / 029 | Total loss: 7.857 | Reg loss: 0.025 | Tree loss: 7.857 | Accuracy: 0.474609 | 7.084 sec/iter
Epoch: 08 | Batch: 015 / 029 | Total loss: 7.794 | Reg loss: 0.026 | Tree loss: 7.794 | Accuracy: 0.496094 | 7.083 sec/iter
Epoch: 08 | Batch: 016 / 029 | Total loss: 7.736 | Reg loss: 0.026 | Tree loss: 7.736 | Accuracy: 0.513672 | 7.081 sec/iter
Epoch: 08 | Batch: 017 / 029 | Total loss: 7.737 | Reg loss: 0.026 | Tree loss: 7.737 | Accuracy: 0.474609 | 7.08 sec/iter
Epoch: 08 | Batch: 018 / 029 | Total loss: 7.729 | Reg loss: 0.027 | Tree loss: 7.729 | Accuracy: 0.480469 | 7.078 sec/iter
Epoch: 08 |

Epoch: 10 | Batch: 014 / 029 | Total loss: 7.176 | Reg loss: 0.029 | Tree loss: 7.176 | Accuracy: 0.480469 | 7.037 sec/iter
Epoch: 10 | Batch: 015 / 029 | Total loss: 7.123 | Reg loss: 0.029 | Tree loss: 7.123 | Accuracy: 0.472656 | 7.036 sec/iter
Epoch: 10 | Batch: 016 / 029 | Total loss: 7.111 | Reg loss: 0.029 | Tree loss: 7.111 | Accuracy: 0.523438 | 7.03 sec/iter
Epoch: 10 | Batch: 017 / 029 | Total loss: 7.099 | Reg loss: 0.030 | Tree loss: 7.099 | Accuracy: 0.464844 | 7.024 sec/iter
Epoch: 10 | Batch: 018 / 029 | Total loss: 7.043 | Reg loss: 0.030 | Tree loss: 7.043 | Accuracy: 0.521484 | 7.018 sec/iter
Epoch: 10 | Batch: 019 / 029 | Total loss: 7.013 | Reg loss: 0.030 | Tree loss: 7.013 | Accuracy: 0.505859 | 7.018 sec/iter
Epoch: 10 | Batch: 020 / 029 | Total loss: 7.020 | Reg loss: 0.030 | Tree loss: 7.020 | Accuracy: 0.476562 | 7.018 sec/iter
Epoch: 10 | Batch: 021 / 029 | Total loss: 7.014 | Reg loss: 0.031 | Tree loss: 7.014 | Accuracy: 0.490234 | 7.019 sec/iter
Epoch: 10

Epoch: 12 | Batch: 017 / 029 | Total loss: 6.456 | Reg loss: 0.031 | Tree loss: 6.456 | Accuracy: 0.539062 | 7.064 sec/iter
Epoch: 12 | Batch: 018 / 029 | Total loss: 6.432 | Reg loss: 0.032 | Tree loss: 6.432 | Accuracy: 0.519531 | 7.063 sec/iter
Epoch: 12 | Batch: 019 / 029 | Total loss: 6.425 | Reg loss: 0.032 | Tree loss: 6.425 | Accuracy: 0.490234 | 7.063 sec/iter
Epoch: 12 | Batch: 020 / 029 | Total loss: 6.404 | Reg loss: 0.032 | Tree loss: 6.404 | Accuracy: 0.474609 | 7.063 sec/iter
Epoch: 12 | Batch: 021 / 029 | Total loss: 6.411 | Reg loss: 0.032 | Tree loss: 6.411 | Accuracy: 0.462891 | 7.063 sec/iter
Epoch: 12 | Batch: 022 / 029 | Total loss: 6.394 | Reg loss: 0.033 | Tree loss: 6.394 | Accuracy: 0.474609 | 7.064 sec/iter
Epoch: 12 | Batch: 023 / 029 | Total loss: 6.369 | Reg loss: 0.033 | Tree loss: 6.369 | Accuracy: 0.460938 | 7.064 sec/iter
Epoch: 12 | Batch: 024 / 029 | Total loss: 6.289 | Reg loss: 0.033 | Tree loss: 6.289 | Accuracy: 0.541016 | 7.064 sec/iter
Epoch: 1

Epoch: 14 | Batch: 020 / 029 | Total loss: 5.871 | Reg loss: 0.033 | Tree loss: 5.871 | Accuracy: 0.500000 | 7.121 sec/iter
Epoch: 14 | Batch: 021 / 029 | Total loss: 5.831 | Reg loss: 0.033 | Tree loss: 5.831 | Accuracy: 0.513672 | 7.122 sec/iter
Epoch: 14 | Batch: 022 / 029 | Total loss: 5.839 | Reg loss: 0.033 | Tree loss: 5.839 | Accuracy: 0.455078 | 7.122 sec/iter
Epoch: 14 | Batch: 023 / 029 | Total loss: 5.776 | Reg loss: 0.033 | Tree loss: 5.776 | Accuracy: 0.498047 | 7.123 sec/iter
Epoch: 14 | Batch: 024 / 029 | Total loss: 5.814 | Reg loss: 0.034 | Tree loss: 5.814 | Accuracy: 0.480469 | 7.123 sec/iter
Epoch: 14 | Batch: 025 / 029 | Total loss: 5.748 | Reg loss: 0.034 | Tree loss: 5.748 | Accuracy: 0.486328 | 7.123 sec/iter
Epoch: 14 | Batch: 026 / 029 | Total loss: 5.721 | Reg loss: 0.034 | Tree loss: 5.721 | Accuracy: 0.468750 | 7.124 sec/iter
Epoch: 14 | Batch: 027 / 029 | Total loss: 5.681 | Reg loss: 0.034 | Tree loss: 5.681 | Accuracy: 0.527344 | 7.124 sec/iter
Epoch: 1

Epoch: 16 | Batch: 023 / 029 | Total loss: 5.292 | Reg loss: 0.034 | Tree loss: 5.292 | Accuracy: 0.476562 | 7.085 sec/iter
Epoch: 16 | Batch: 024 / 029 | Total loss: 5.268 | Reg loss: 0.034 | Tree loss: 5.268 | Accuracy: 0.474609 | 7.081 sec/iter
Epoch: 16 | Batch: 025 / 029 | Total loss: 5.270 | Reg loss: 0.034 | Tree loss: 5.270 | Accuracy: 0.503906 | 7.076 sec/iter
Epoch: 16 | Batch: 026 / 029 | Total loss: 5.240 | Reg loss: 0.034 | Tree loss: 5.240 | Accuracy: 0.490234 | 7.073 sec/iter
Epoch: 16 | Batch: 027 / 029 | Total loss: 5.221 | Reg loss: 0.034 | Tree loss: 5.221 | Accuracy: 0.486328 | 7.069 sec/iter
Epoch: 16 | Batch: 028 / 029 | Total loss: 5.204 | Reg loss: 0.034 | Tree loss: 5.204 | Accuracy: 0.499014 | 7.066 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 18 | Batch: 026 / 029 | Total loss: 4.796 | Reg loss: 0.034 | Tree loss: 4.796 | Accuracy: 0.486328 | 7.199 sec/iter
Epoch: 18 | Batch: 027 / 029 | Total loss: 4.769 | Reg loss: 0.034 | Tree loss: 4.769 | Accuracy: 0.457031 | 7.198 sec/iter
Epoch: 18 | Batch: 028 / 029 | Total loss: 4.751 | Reg loss: 0.034 | Tree loss: 4.751 | Accuracy: 0.491124 | 7.198 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 19 | Batch: 000 / 029 | Total loss: 5.069 | Reg loss: 0.033 | Tree loss: 5.069 | Accuracy: 0.423828 | 7.255 sec/iter
Epoch: 19 | Batch: 001 / 029 | Total loss: 5.050 | Reg loss: 0.033 | Tree loss: 5.050 | Accuracy: 0.464844 | 7.255 sec/iter
Epoch: 19 | Batch: 002 / 029 | To

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 21 | Batch: 000 / 029 | Total loss: 4.562 | Reg loss: 0.033 | Tree loss: 4.562 | Accuracy: 0.478516 | 7.257 sec/iter
Epoch: 21 | Batch: 001 / 029 | Total loss: 4.575 | Reg loss: 0.033 | Tree loss: 4.575 | Accuracy: 0.494141 | 7.256 sec/iter
Epoch: 21 | Batch: 002 / 029 | Total loss: 4.557 | Reg loss: 0.033 | Tree loss: 4.557 | Accuracy: 0.470703 | 7.256 sec/iter
Epoch: 21 | Batch: 003 / 029 | Total loss: 4.524 | Reg loss: 0.033 | Tree loss: 4.524 | Accuracy: 0.496094 | 7.254 sec/iter
Epoch: 21 | Batch: 004 / 029 | Total loss: 4.488 | Reg loss: 0.033 | Tree loss: 4.488 | Accuracy: 0.492188 | 7.252 sec/iter
Epoch: 21 | Batch: 005 / 029 | To

layer 10: 0.9821428571428573
Epoch: 23 | Batch: 000 / 029 | Total loss: 4.205 | Reg loss: 0.033 | Tree loss: 4.205 | Accuracy: 0.458984 | 7.198 sec/iter
Epoch: 23 | Batch: 001 / 029 | Total loss: 4.163 | Reg loss: 0.033 | Tree loss: 4.163 | Accuracy: 0.466797 | 7.197 sec/iter
Epoch: 23 | Batch: 002 / 029 | Total loss: 4.122 | Reg loss: 0.033 | Tree loss: 4.122 | Accuracy: 0.500000 | 7.197 sec/iter
Epoch: 23 | Batch: 003 / 029 | Total loss: 4.156 | Reg loss: 0.033 | Tree loss: 4.156 | Accuracy: 0.449219 | 7.197 sec/iter
Epoch: 23 | Batch: 004 / 029 | Total loss: 4.087 | Reg loss: 0.033 | Tree loss: 4.087 | Accuracy: 0.480469 | 7.197 sec/iter
Epoch: 23 | Batch: 005 / 029 | Total loss: 4.086 | Reg loss: 0.033 | Tree loss: 4.086 | Accuracy: 0.478516 | 7.197 sec/iter
Epoch: 23 | Batch: 006 / 029 | Total loss: 4.053 | Reg loss: 0.033 | Tree loss: 4.053 | Accuracy: 0.503906 | 7.197 sec/iter
Epoch: 23 | Batch: 007 / 029 | Total loss: 4.045 | Reg loss: 0.033 | Tree loss: 4.045 | Accuracy: 0.474

Epoch: 25 | Batch: 003 / 029 | Total loss: 3.726 | Reg loss: 0.033 | Tree loss: 3.726 | Accuracy: 0.480469 | 7.244 sec/iter
Epoch: 25 | Batch: 004 / 029 | Total loss: 3.731 | Reg loss: 0.033 | Tree loss: 3.731 | Accuracy: 0.498047 | 7.244 sec/iter
Epoch: 25 | Batch: 005 / 029 | Total loss: 3.719 | Reg loss: 0.033 | Tree loss: 3.719 | Accuracy: 0.501953 | 7.244 sec/iter
Epoch: 25 | Batch: 006 / 029 | Total loss: 3.697 | Reg loss: 0.033 | Tree loss: 3.697 | Accuracy: 0.464844 | 7.244 sec/iter
Epoch: 25 | Batch: 007 / 029 | Total loss: 3.673 | Reg loss: 0.033 | Tree loss: 3.673 | Accuracy: 0.464844 | 7.244 sec/iter
Epoch: 25 | Batch: 008 / 029 | Total loss: 3.655 | Reg loss: 0.033 | Tree loss: 3.655 | Accuracy: 0.507812 | 7.244 sec/iter
Epoch: 25 | Batch: 009 / 029 | Total loss: 3.626 | Reg loss: 0.033 | Tree loss: 3.626 | Accuracy: 0.507812 | 7.244 sec/iter
Epoch: 25 | Batch: 010 / 029 | Total loss: 3.593 | Reg loss: 0.033 | Tree loss: 3.593 | Accuracy: 0.501953 | 7.244 sec/iter
Epoch: 2

Epoch: 27 | Batch: 006 / 029 | Total loss: 3.364 | Reg loss: 0.033 | Tree loss: 3.364 | Accuracy: 0.460938 | 7.269 sec/iter
Epoch: 27 | Batch: 007 / 029 | Total loss: 3.279 | Reg loss: 0.033 | Tree loss: 3.279 | Accuracy: 0.517578 | 7.266 sec/iter
Epoch: 27 | Batch: 008 / 029 | Total loss: 3.290 | Reg loss: 0.033 | Tree loss: 3.290 | Accuracy: 0.498047 | 7.266 sec/iter
Epoch: 27 | Batch: 009 / 029 | Total loss: 3.281 | Reg loss: 0.033 | Tree loss: 3.281 | Accuracy: 0.492188 | 7.265 sec/iter
Epoch: 27 | Batch: 010 / 029 | Total loss: 3.300 | Reg loss: 0.033 | Tree loss: 3.300 | Accuracy: 0.462891 | 7.264 sec/iter
Epoch: 27 | Batch: 011 / 029 | Total loss: 3.257 | Reg loss: 0.033 | Tree loss: 3.257 | Accuracy: 0.488281 | 7.263 sec/iter
Epoch: 27 | Batch: 012 / 029 | Total loss: 3.265 | Reg loss: 0.033 | Tree loss: 3.265 | Accuracy: 0.462891 | 7.263 sec/iter
Epoch: 27 | Batch: 013 / 029 | Total loss: 3.238 | Reg loss: 0.033 | Tree loss: 3.238 | Accuracy: 0.462891 | 7.262 sec/iter
Epoch: 2

Epoch: 29 | Batch: 009 / 029 | Total loss: 2.975 | Reg loss: 0.033 | Tree loss: 2.975 | Accuracy: 0.478516 | 7.261 sec/iter
Epoch: 29 | Batch: 010 / 029 | Total loss: 2.923 | Reg loss: 0.033 | Tree loss: 2.923 | Accuracy: 0.480469 | 7.258 sec/iter
Epoch: 29 | Batch: 011 / 029 | Total loss: 2.941 | Reg loss: 0.033 | Tree loss: 2.941 | Accuracy: 0.486328 | 7.258 sec/iter
Epoch: 29 | Batch: 012 / 029 | Total loss: 2.934 | Reg loss: 0.033 | Tree loss: 2.934 | Accuracy: 0.488281 | 7.257 sec/iter
Epoch: 29 | Batch: 013 / 029 | Total loss: 2.901 | Reg loss: 0.033 | Tree loss: 2.901 | Accuracy: 0.494141 | 7.256 sec/iter
Epoch: 29 | Batch: 014 / 029 | Total loss: 2.920 | Reg loss: 0.033 | Tree loss: 2.920 | Accuracy: 0.458984 | 7.256 sec/iter
Epoch: 29 | Batch: 015 / 029 | Total loss: 2.895 | Reg loss: 0.033 | Tree loss: 2.895 | Accuracy: 0.496094 | 7.256 sec/iter
Epoch: 29 | Batch: 016 / 029 | Total loss: 2.892 | Reg loss: 0.033 | Tree loss: 2.892 | Accuracy: 0.476562 | 7.256 sec/iter
Epoch: 2

Epoch: 31 | Batch: 012 / 029 | Total loss: 2.630 | Reg loss: 0.033 | Tree loss: 2.630 | Accuracy: 0.492188 | 7.244 sec/iter
Epoch: 31 | Batch: 013 / 029 | Total loss: 2.649 | Reg loss: 0.033 | Tree loss: 2.649 | Accuracy: 0.484375 | 7.241 sec/iter
Epoch: 31 | Batch: 014 / 029 | Total loss: 2.619 | Reg loss: 0.033 | Tree loss: 2.619 | Accuracy: 0.494141 | 7.239 sec/iter
Epoch: 31 | Batch: 015 / 029 | Total loss: 2.576 | Reg loss: 0.033 | Tree loss: 2.576 | Accuracy: 0.521484 | 7.236 sec/iter
Epoch: 31 | Batch: 016 / 029 | Total loss: 2.602 | Reg loss: 0.033 | Tree loss: 2.602 | Accuracy: 0.470703 | 7.233 sec/iter
Epoch: 31 | Batch: 017 / 029 | Total loss: 2.591 | Reg loss: 0.033 | Tree loss: 2.591 | Accuracy: 0.474609 | 7.23 sec/iter
Epoch: 31 | Batch: 018 / 029 | Total loss: 2.559 | Reg loss: 0.033 | Tree loss: 2.559 | Accuracy: 0.501953 | 7.228 sec/iter
Epoch: 31 | Batch: 019 / 029 | Total loss: 2.527 | Reg loss: 0.033 | Tree loss: 2.527 | Accuracy: 0.521484 | 7.225 sec/iter
Epoch: 31

Epoch: 33 | Batch: 015 / 029 | Total loss: 2.383 | Reg loss: 0.033 | Tree loss: 2.383 | Accuracy: 0.496094 | 7.214 sec/iter
Epoch: 33 | Batch: 016 / 029 | Total loss: 2.373 | Reg loss: 0.033 | Tree loss: 2.373 | Accuracy: 0.496094 | 7.214 sec/iter
Epoch: 33 | Batch: 017 / 029 | Total loss: 2.327 | Reg loss: 0.033 | Tree loss: 2.327 | Accuracy: 0.509766 | 7.214 sec/iter
Epoch: 33 | Batch: 018 / 029 | Total loss: 2.359 | Reg loss: 0.033 | Tree loss: 2.359 | Accuracy: 0.464844 | 7.214 sec/iter
Epoch: 33 | Batch: 019 / 029 | Total loss: 2.319 | Reg loss: 0.033 | Tree loss: 2.319 | Accuracy: 0.496094 | 7.213 sec/iter
Epoch: 33 | Batch: 020 / 029 | Total loss: 2.351 | Reg loss: 0.033 | Tree loss: 2.351 | Accuracy: 0.447266 | 7.213 sec/iter
Epoch: 33 | Batch: 021 / 029 | Total loss: 2.312 | Reg loss: 0.033 | Tree loss: 2.312 | Accuracy: 0.484375 | 7.213 sec/iter
Epoch: 33 | Batch: 022 / 029 | Total loss: 2.287 | Reg loss: 0.033 | Tree loss: 2.287 | Accuracy: 0.505859 | 7.213 sec/iter
Epoch: 3

Epoch: 35 | Batch: 018 / 029 | Total loss: 2.140 | Reg loss: 0.033 | Tree loss: 2.140 | Accuracy: 0.529297 | 7.277 sec/iter
Epoch: 35 | Batch: 019 / 029 | Total loss: 2.138 | Reg loss: 0.033 | Tree loss: 2.138 | Accuracy: 0.511719 | 7.276 sec/iter
Epoch: 35 | Batch: 020 / 029 | Total loss: 2.119 | Reg loss: 0.033 | Tree loss: 2.119 | Accuracy: 0.496094 | 7.276 sec/iter
Epoch: 35 | Batch: 021 / 029 | Total loss: 2.119 | Reg loss: 0.033 | Tree loss: 2.119 | Accuracy: 0.505859 | 7.276 sec/iter
Epoch: 35 | Batch: 022 / 029 | Total loss: 2.136 | Reg loss: 0.033 | Tree loss: 2.136 | Accuracy: 0.460938 | 7.275 sec/iter
Epoch: 35 | Batch: 023 / 029 | Total loss: 2.111 | Reg loss: 0.033 | Tree loss: 2.111 | Accuracy: 0.488281 | 7.275 sec/iter
Epoch: 35 | Batch: 024 / 029 | Total loss: 2.116 | Reg loss: 0.033 | Tree loss: 2.116 | Accuracy: 0.470703 | 7.275 sec/iter
Epoch: 35 | Batch: 025 / 029 | Total loss: 2.078 | Reg loss: 0.033 | Tree loss: 2.078 | Accuracy: 0.529297 | 7.275 sec/iter
Epoch: 3

Epoch: 37 | Batch: 021 / 029 | Total loss: 1.983 | Reg loss: 0.033 | Tree loss: 1.983 | Accuracy: 0.500000 | 7.273 sec/iter
Epoch: 37 | Batch: 022 / 029 | Total loss: 1.973 | Reg loss: 0.033 | Tree loss: 1.973 | Accuracy: 0.488281 | 7.273 sec/iter
Epoch: 37 | Batch: 023 / 029 | Total loss: 1.905 | Reg loss: 0.033 | Tree loss: 1.905 | Accuracy: 0.544922 | 7.273 sec/iter
Epoch: 37 | Batch: 024 / 029 | Total loss: 1.933 | Reg loss: 0.033 | Tree loss: 1.933 | Accuracy: 0.521484 | 7.273 sec/iter
Epoch: 37 | Batch: 025 / 029 | Total loss: 1.949 | Reg loss: 0.033 | Tree loss: 1.949 | Accuracy: 0.484375 | 7.273 sec/iter
Epoch: 37 | Batch: 026 / 029 | Total loss: 1.971 | Reg loss: 0.033 | Tree loss: 1.971 | Accuracy: 0.460938 | 7.272 sec/iter
Epoch: 37 | Batch: 027 / 029 | Total loss: 1.955 | Reg loss: 0.033 | Tree loss: 1.955 | Accuracy: 0.474609 | 7.272 sec/iter
Epoch: 37 | Batch: 028 / 029 | Total loss: 1.931 | Reg loss: 0.033 | Tree loss: 1.931 | Accuracy: 0.504931 | 7.272 sec/iter
Average 

Epoch: 39 | Batch: 024 / 029 | Total loss: 1.821 | Reg loss: 0.033 | Tree loss: 1.821 | Accuracy: 0.517578 | 7.242 sec/iter
Epoch: 39 | Batch: 025 / 029 | Total loss: 1.846 | Reg loss: 0.033 | Tree loss: 1.846 | Accuracy: 0.472656 | 7.242 sec/iter
Epoch: 39 | Batch: 026 / 029 | Total loss: 1.835 | Reg loss: 0.033 | Tree loss: 1.835 | Accuracy: 0.482422 | 7.242 sec/iter
Epoch: 39 | Batch: 027 / 029 | Total loss: 1.822 | Reg loss: 0.033 | Tree loss: 1.822 | Accuracy: 0.496094 | 7.241 sec/iter
Epoch: 39 | Batch: 028 / 029 | Total loss: 1.814 | Reg loss: 0.033 | Tree loss: 1.814 | Accuracy: 0.518738 | 7.24 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 40 | Batch: 000 / 029 | Tot

Epoch: 41 | Batch: 027 / 029 | Total loss: 1.712 | Reg loss: 0.033 | Tree loss: 1.712 | Accuracy: 0.509766 | 7.229 sec/iter
Epoch: 41 | Batch: 028 / 029 | Total loss: 1.731 | Reg loss: 0.033 | Tree loss: 1.731 | Accuracy: 0.495069 | 7.229 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 42 | Batch: 000 / 029 | Total loss: 1.904 | Reg loss: 0.032 | Tree loss: 1.904 | Accuracy: 0.482422 | 7.235 sec/iter
Epoch: 42 | Batch: 001 / 029 | Total loss: 1.875 | Reg loss: 0.032 | Tree loss: 1.875 | Accuracy: 0.513672 | 7.235 sec/iter
Epoch: 42 | Batch: 002 / 029 | Total loss: 1.846 | Reg loss: 0.032 | Tree loss: 1.846 | Accuracy: 0.535156 | 7.235 sec/iter
Epoch: 42 | Batch: 003 / 029 | To

layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 44 | Batch: 000 / 029 | Total loss: 1.809 | Reg loss: 0.032 | Tree loss: 1.809 | Accuracy: 0.496094 | 7.235 sec/iter
Epoch: 44 | Batch: 001 / 029 | Total loss: 1.811 | Reg loss: 0.032 | Tree loss: 1.811 | Accuracy: 0.494141 | 7.235 sec/iter
Epoch: 44 | Batch: 002 / 029 | Total loss: 1.804 | Reg loss: 0.032 | Tree loss: 1.804 | Accuracy: 0.484375 | 7.235 sec/iter
Epoch: 44 | Batch: 003 / 029 | Total loss: 1.793 | Reg loss: 0.032 | Tree loss: 1.793 | Accuracy: 0.496094 | 7.235 sec/iter
Epoch: 44 | Batch: 004 / 029 | Total loss: 1.799 | Reg loss: 0.032 | Tree loss: 1.799 | Accuracy: 0.503906 | 7.235 sec/iter
Epoch: 44 | Batch: 005 / 029 | Total loss: 1.768 | Reg loss: 0.032 | Tree loss: 1.768 | Accuracy: 0.513672 | 7.235 sec/iter
Epoch: 44 | Batch: 006 / 029 | Total loss: 1.762 | Reg loss: 0.032 | Tree loss: 1.762 | 

Epoch: 46 | Batch: 002 / 029 | Total loss: 1.733 | Reg loss: 0.031 | Tree loss: 1.733 | Accuracy: 0.498047 | 7.223 sec/iter
Epoch: 46 | Batch: 003 / 029 | Total loss: 1.718 | Reg loss: 0.031 | Tree loss: 1.718 | Accuracy: 0.507812 | 7.223 sec/iter
Epoch: 46 | Batch: 004 / 029 | Total loss: 1.700 | Reg loss: 0.031 | Tree loss: 1.700 | Accuracy: 0.496094 | 7.223 sec/iter
Epoch: 46 | Batch: 005 / 029 | Total loss: 1.721 | Reg loss: 0.031 | Tree loss: 1.721 | Accuracy: 0.476562 | 7.223 sec/iter
Epoch: 46 | Batch: 006 / 029 | Total loss: 1.709 | Reg loss: 0.032 | Tree loss: 1.709 | Accuracy: 0.498047 | 7.223 sec/iter
Epoch: 46 | Batch: 007 / 029 | Total loss: 1.693 | Reg loss: 0.032 | Tree loss: 1.693 | Accuracy: 0.478516 | 7.223 sec/iter
Epoch: 46 | Batch: 008 / 029 | Total loss: 1.690 | Reg loss: 0.032 | Tree loss: 1.690 | Accuracy: 0.478516 | 7.223 sec/iter
Epoch: 46 | Batch: 009 / 029 | Total loss: 1.648 | Reg loss: 0.032 | Tree loss: 1.648 | Accuracy: 0.533203 | 7.223 sec/iter
Epoch: 4

Epoch: 48 | Batch: 005 / 029 | Total loss: 1.626 | Reg loss: 0.031 | Tree loss: 1.626 | Accuracy: 0.519531 | 7.216 sec/iter
Epoch: 48 | Batch: 006 / 029 | Total loss: 1.654 | Reg loss: 0.031 | Tree loss: 1.654 | Accuracy: 0.486328 | 7.216 sec/iter
Epoch: 48 | Batch: 007 / 029 | Total loss: 1.614 | Reg loss: 0.031 | Tree loss: 1.614 | Accuracy: 0.521484 | 7.216 sec/iter
Epoch: 48 | Batch: 008 / 029 | Total loss: 1.649 | Reg loss: 0.031 | Tree loss: 1.649 | Accuracy: 0.470703 | 7.216 sec/iter
Epoch: 48 | Batch: 009 / 029 | Total loss: 1.608 | Reg loss: 0.031 | Tree loss: 1.608 | Accuracy: 0.503906 | 7.216 sec/iter
Epoch: 48 | Batch: 010 / 029 | Total loss: 1.607 | Reg loss: 0.031 | Tree loss: 1.607 | Accuracy: 0.505859 | 7.216 sec/iter
Epoch: 48 | Batch: 011 / 029 | Total loss: 1.587 | Reg loss: 0.031 | Tree loss: 1.587 | Accuracy: 0.511719 | 7.216 sec/iter
Epoch: 48 | Batch: 012 / 029 | Total loss: 1.580 | Reg loss: 0.031 | Tree loss: 1.580 | Accuracy: 0.500000 | 7.215 sec/iter
Epoch: 4

Epoch: 50 | Batch: 008 / 029 | Total loss: 1.577 | Reg loss: 0.031 | Tree loss: 1.577 | Accuracy: 0.484375 | 7.205 sec/iter
Epoch: 50 | Batch: 009 / 029 | Total loss: 1.571 | Reg loss: 0.031 | Tree loss: 1.571 | Accuracy: 0.490234 | 7.204 sec/iter
Epoch: 50 | Batch: 010 / 029 | Total loss: 1.573 | Reg loss: 0.031 | Tree loss: 1.573 | Accuracy: 0.464844 | 7.204 sec/iter
Epoch: 50 | Batch: 011 / 029 | Total loss: 1.537 | Reg loss: 0.031 | Tree loss: 1.537 | Accuracy: 0.519531 | 7.204 sec/iter
Epoch: 50 | Batch: 012 / 029 | Total loss: 1.553 | Reg loss: 0.031 | Tree loss: 1.553 | Accuracy: 0.488281 | 7.203 sec/iter
Epoch: 50 | Batch: 013 / 029 | Total loss: 1.534 | Reg loss: 0.031 | Tree loss: 1.534 | Accuracy: 0.500000 | 7.203 sec/iter
Epoch: 50 | Batch: 014 / 029 | Total loss: 1.495 | Reg loss: 0.031 | Tree loss: 1.495 | Accuracy: 0.542969 | 7.203 sec/iter
Epoch: 50 | Batch: 015 / 029 | Total loss: 1.521 | Reg loss: 0.031 | Tree loss: 1.521 | Accuracy: 0.496094 | 7.202 sec/iter
Epoch: 5

Epoch: 52 | Batch: 011 / 029 | Total loss: 1.511 | Reg loss: 0.031 | Tree loss: 1.511 | Accuracy: 0.511719 | 7.189 sec/iter
Epoch: 52 | Batch: 012 / 029 | Total loss: 1.489 | Reg loss: 0.031 | Tree loss: 1.489 | Accuracy: 0.521484 | 7.189 sec/iter
Epoch: 52 | Batch: 013 / 029 | Total loss: 1.527 | Reg loss: 0.031 | Tree loss: 1.527 | Accuracy: 0.460938 | 7.189 sec/iter
Epoch: 52 | Batch: 014 / 029 | Total loss: 1.511 | Reg loss: 0.031 | Tree loss: 1.511 | Accuracy: 0.472656 | 7.189 sec/iter
Epoch: 52 | Batch: 015 / 029 | Total loss: 1.517 | Reg loss: 0.031 | Tree loss: 1.517 | Accuracy: 0.468750 | 7.189 sec/iter
Epoch: 52 | Batch: 016 / 029 | Total loss: 1.469 | Reg loss: 0.031 | Tree loss: 1.469 | Accuracy: 0.519531 | 7.188 sec/iter
Epoch: 52 | Batch: 017 / 029 | Total loss: 1.482 | Reg loss: 0.031 | Tree loss: 1.482 | Accuracy: 0.484375 | 7.188 sec/iter
Epoch: 52 | Batch: 018 / 029 | Total loss: 1.481 | Reg loss: 0.031 | Tree loss: 1.481 | Accuracy: 0.480469 | 7.188 sec/iter
Epoch: 5

Epoch: 54 | Batch: 014 / 029 | Total loss: 1.454 | Reg loss: 0.031 | Tree loss: 1.454 | Accuracy: 0.503906 | 7.187 sec/iter
Epoch: 54 | Batch: 015 / 029 | Total loss: 1.452 | Reg loss: 0.031 | Tree loss: 1.452 | Accuracy: 0.500000 | 7.187 sec/iter
Epoch: 54 | Batch: 016 / 029 | Total loss: 1.430 | Reg loss: 0.031 | Tree loss: 1.430 | Accuracy: 0.521484 | 7.187 sec/iter
Epoch: 54 | Batch: 017 / 029 | Total loss: 1.454 | Reg loss: 0.031 | Tree loss: 1.454 | Accuracy: 0.488281 | 7.187 sec/iter
Epoch: 54 | Batch: 018 / 029 | Total loss: 1.443 | Reg loss: 0.031 | Tree loss: 1.443 | Accuracy: 0.488281 | 7.187 sec/iter
Epoch: 54 | Batch: 019 / 029 | Total loss: 1.456 | Reg loss: 0.031 | Tree loss: 1.456 | Accuracy: 0.472656 | 7.187 sec/iter
Epoch: 54 | Batch: 020 / 029 | Total loss: 1.463 | Reg loss: 0.031 | Tree loss: 1.463 | Accuracy: 0.451172 | 7.187 sec/iter
Epoch: 54 | Batch: 021 / 029 | Total loss: 1.423 | Reg loss: 0.031 | Tree loss: 1.423 | Accuracy: 0.498047 | 7.187 sec/iter
Epoch: 5

Epoch: 56 | Batch: 017 / 029 | Total loss: 1.415 | Reg loss: 0.031 | Tree loss: 1.415 | Accuracy: 0.484375 | 7.181 sec/iter
Epoch: 56 | Batch: 018 / 029 | Total loss: 1.399 | Reg loss: 0.031 | Tree loss: 1.399 | Accuracy: 0.503906 | 7.181 sec/iter
Epoch: 56 | Batch: 019 / 029 | Total loss: 1.394 | Reg loss: 0.031 | Tree loss: 1.394 | Accuracy: 0.503906 | 7.181 sec/iter
Epoch: 56 | Batch: 020 / 029 | Total loss: 1.399 | Reg loss: 0.031 | Tree loss: 1.399 | Accuracy: 0.496094 | 7.181 sec/iter
Epoch: 56 | Batch: 021 / 029 | Total loss: 1.400 | Reg loss: 0.031 | Tree loss: 1.400 | Accuracy: 0.498047 | 7.181 sec/iter
Epoch: 56 | Batch: 022 / 029 | Total loss: 1.369 | Reg loss: 0.031 | Tree loss: 1.369 | Accuracy: 0.523438 | 7.181 sec/iter
Epoch: 56 | Batch: 023 / 029 | Total loss: 1.405 | Reg loss: 0.031 | Tree loss: 1.405 | Accuracy: 0.460938 | 7.181 sec/iter
Epoch: 56 | Batch: 024 / 029 | Total loss: 1.390 | Reg loss: 0.031 | Tree loss: 1.390 | Accuracy: 0.474609 | 7.181 sec/iter
Epoch: 5

Epoch: 58 | Batch: 020 / 029 | Total loss: 1.366 | Reg loss: 0.031 | Tree loss: 1.366 | Accuracy: 0.511719 | 7.17 sec/iter
Epoch: 58 | Batch: 021 / 029 | Total loss: 1.366 | Reg loss: 0.031 | Tree loss: 1.366 | Accuracy: 0.500000 | 7.17 sec/iter
Epoch: 58 | Batch: 022 / 029 | Total loss: 1.364 | Reg loss: 0.031 | Tree loss: 1.364 | Accuracy: 0.500000 | 7.17 sec/iter
Epoch: 58 | Batch: 023 / 029 | Total loss: 1.368 | Reg loss: 0.031 | Tree loss: 1.368 | Accuracy: 0.490234 | 7.17 sec/iter
Epoch: 58 | Batch: 024 / 029 | Total loss: 1.376 | Reg loss: 0.031 | Tree loss: 1.376 | Accuracy: 0.486328 | 7.17 sec/iter
Epoch: 58 | Batch: 025 / 029 | Total loss: 1.369 | Reg loss: 0.031 | Tree loss: 1.369 | Accuracy: 0.494141 | 7.169 sec/iter
Epoch: 58 | Batch: 026 / 029 | Total loss: 1.371 | Reg loss: 0.031 | Tree loss: 1.371 | Accuracy: 0.476562 | 7.168 sec/iter
Epoch: 58 | Batch: 027 / 029 | Total loss: 1.317 | Reg loss: 0.031 | Tree loss: 1.317 | Accuracy: 0.533203 | 7.167 sec/iter
Epoch: 58 | B

Epoch: 60 | Batch: 023 / 029 | Total loss: 1.347 | Reg loss: 0.031 | Tree loss: 1.347 | Accuracy: 0.488281 | 7.156 sec/iter
Epoch: 60 | Batch: 024 / 029 | Total loss: 1.349 | Reg loss: 0.031 | Tree loss: 1.349 | Accuracy: 0.482422 | 7.156 sec/iter
Epoch: 60 | Batch: 025 / 029 | Total loss: 1.354 | Reg loss: 0.031 | Tree loss: 1.354 | Accuracy: 0.466797 | 7.156 sec/iter
Epoch: 60 | Batch: 026 / 029 | Total loss: 1.349 | Reg loss: 0.031 | Tree loss: 1.349 | Accuracy: 0.474609 | 7.154 sec/iter
Epoch: 60 | Batch: 027 / 029 | Total loss: 1.356 | Reg loss: 0.031 | Tree loss: 1.356 | Accuracy: 0.451172 | 7.153 sec/iter
Epoch: 60 | Batch: 028 / 029 | Total loss: 1.336 | Reg loss: 0.031 | Tree loss: 1.336 | Accuracy: 0.487179 | 7.153 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 62 | Batch: 026 / 029 | Total loss: 1.294 | Reg loss: 0.031 | Tree loss: 1.294 | Accuracy: 0.515625 | 7.141 sec/iter
Epoch: 62 | Batch: 027 / 029 | Total loss: 1.311 | Reg loss: 0.031 | Tree loss: 1.311 | Accuracy: 0.494141 | 7.141 sec/iter
Epoch: 62 | Batch: 028 / 029 | Total loss: 1.306 | Reg loss: 0.031 | Tree loss: 1.306 | Accuracy: 0.512821 | 7.14 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 63 | Batch: 000 / 029 | Total loss: 1.454 | Reg loss: 0.030 | Tree loss: 1.454 | Accuracy: 0.486328 | 7.144 sec/iter
Epoch: 63 | Batch: 001 / 029 | Total loss: 1.415 | Reg loss: 0.030 | Tree loss: 1.415 | Accuracy: 0.533203 | 7.145 sec/iter
Epoch: 63 | Batch: 002 / 029 | Tot

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 65 | Batch: 000 / 029 | Total loss: 1.443 | Reg loss: 0.030 | Tree loss: 1.443 | Accuracy: 0.488281 | 7.143 sec/iter
Epoch: 65 | Batch: 001 / 029 | Total loss: 1.445 | Reg loss: 0.030 | Tree loss: 1.445 | Accuracy: 0.457031 | 7.143 sec/iter
Epoch: 65 | Batch: 002 / 029 | Total loss: 1.436 | Reg loss: 0.030 | Tree loss: 1.436 | Accuracy: 0.492188 | 7.143 sec/iter
Epoch: 65 | Batch: 003 / 029 | Total loss: 1.441 | Reg loss: 0.030 | Tree loss: 1.441 | Accuracy: 0.451172 | 7.144 sec/iter
Epoch: 65 | Batch: 004 / 029 | Total loss: 1.412 | Reg loss: 0.030 | Tree loss: 1.412 | Accuracy: 0.486328 | 7.144 sec/iter
Epoch: 65 | Batch: 005 / 029 | To

layer 10: 0.9821428571428573
Epoch: 67 | Batch: 000 / 029 | Total loss: 1.435 | Reg loss: 0.030 | Tree loss: 1.435 | Accuracy: 0.474609 | 7.142 sec/iter
Epoch: 67 | Batch: 001 / 029 | Total loss: 1.386 | Reg loss: 0.030 | Tree loss: 1.386 | Accuracy: 0.525391 | 7.142 sec/iter
Epoch: 67 | Batch: 002 / 029 | Total loss: 1.413 | Reg loss: 0.030 | Tree loss: 1.413 | Accuracy: 0.501953 | 7.142 sec/iter
Epoch: 67 | Batch: 003 / 029 | Total loss: 1.413 | Reg loss: 0.030 | Tree loss: 1.413 | Accuracy: 0.466797 | 7.142 sec/iter
Epoch: 67 | Batch: 004 / 029 | Total loss: 1.393 | Reg loss: 0.030 | Tree loss: 1.393 | Accuracy: 0.511719 | 7.142 sec/iter
Epoch: 67 | Batch: 005 / 029 | Total loss: 1.368 | Reg loss: 0.030 | Tree loss: 1.368 | Accuracy: 0.509766 | 7.142 sec/iter
Epoch: 67 | Batch: 006 / 029 | Total loss: 1.338 | Reg loss: 0.030 | Tree loss: 1.338 | Accuracy: 0.537109 | 7.142 sec/iter
Epoch: 67 | Batch: 007 / 029 | Total loss: 1.377 | Reg loss: 0.030 | Tree loss: 1.377 | Accuracy: 0.482

Epoch: 69 | Batch: 003 / 029 | Total loss: 1.390 | Reg loss: 0.030 | Tree loss: 1.390 | Accuracy: 0.490234 | 7.135 sec/iter
Epoch: 69 | Batch: 004 / 029 | Total loss: 1.394 | Reg loss: 0.030 | Tree loss: 1.394 | Accuracy: 0.464844 | 7.135 sec/iter
Epoch: 69 | Batch: 005 / 029 | Total loss: 1.400 | Reg loss: 0.030 | Tree loss: 1.400 | Accuracy: 0.451172 | 7.135 sec/iter
Epoch: 69 | Batch: 006 / 029 | Total loss: 1.364 | Reg loss: 0.030 | Tree loss: 1.364 | Accuracy: 0.501953 | 7.135 sec/iter
Epoch: 69 | Batch: 007 / 029 | Total loss: 1.345 | Reg loss: 0.030 | Tree loss: 1.345 | Accuracy: 0.509766 | 7.135 sec/iter
Epoch: 69 | Batch: 008 / 029 | Total loss: 1.328 | Reg loss: 0.030 | Tree loss: 1.328 | Accuracy: 0.535156 | 7.135 sec/iter
Epoch: 69 | Batch: 009 / 029 | Total loss: 1.314 | Reg loss: 0.030 | Tree loss: 1.314 | Accuracy: 0.515625 | 7.135 sec/iter
Epoch: 69 | Batch: 010 / 029 | Total loss: 1.334 | Reg loss: 0.030 | Tree loss: 1.334 | Accuracy: 0.505859 | 7.135 sec/iter
Epoch: 6

Epoch: 71 | Batch: 006 / 029 | Total loss: 1.375 | Reg loss: 0.030 | Tree loss: 1.375 | Accuracy: 0.464844 | 7.128 sec/iter
Epoch: 71 | Batch: 007 / 029 | Total loss: 1.367 | Reg loss: 0.030 | Tree loss: 1.367 | Accuracy: 0.480469 | 7.128 sec/iter
Epoch: 71 | Batch: 008 / 029 | Total loss: 1.297 | Reg loss: 0.030 | Tree loss: 1.297 | Accuracy: 0.527344 | 7.128 sec/iter
Epoch: 71 | Batch: 009 / 029 | Total loss: 1.343 | Reg loss: 0.030 | Tree loss: 1.343 | Accuracy: 0.482422 | 7.128 sec/iter
Epoch: 71 | Batch: 010 / 029 | Total loss: 1.310 | Reg loss: 0.030 | Tree loss: 1.310 | Accuracy: 0.521484 | 7.128 sec/iter
Epoch: 71 | Batch: 011 / 029 | Total loss: 1.341 | Reg loss: 0.030 | Tree loss: 1.341 | Accuracy: 0.462891 | 7.128 sec/iter
Epoch: 71 | Batch: 012 / 029 | Total loss: 1.328 | Reg loss: 0.030 | Tree loss: 1.328 | Accuracy: 0.470703 | 7.128 sec/iter
Epoch: 71 | Batch: 013 / 029 | Total loss: 1.302 | Reg loss: 0.030 | Tree loss: 1.302 | Accuracy: 0.492188 | 7.128 sec/iter
Epoch: 7

Epoch: 73 | Batch: 009 / 029 | Total loss: 1.343 | Reg loss: 0.030 | Tree loss: 1.343 | Accuracy: 0.466797 | 7.124 sec/iter
Epoch: 73 | Batch: 010 / 029 | Total loss: 1.330 | Reg loss: 0.030 | Tree loss: 1.330 | Accuracy: 0.458984 | 7.124 sec/iter
Epoch: 73 | Batch: 011 / 029 | Total loss: 1.305 | Reg loss: 0.030 | Tree loss: 1.305 | Accuracy: 0.492188 | 7.123 sec/iter
Epoch: 73 | Batch: 012 / 029 | Total loss: 1.273 | Reg loss: 0.030 | Tree loss: 1.273 | Accuracy: 0.511719 | 7.123 sec/iter
Epoch: 73 | Batch: 013 / 029 | Total loss: 1.323 | Reg loss: 0.030 | Tree loss: 1.323 | Accuracy: 0.458984 | 7.123 sec/iter
Epoch: 73 | Batch: 014 / 029 | Total loss: 1.296 | Reg loss: 0.030 | Tree loss: 1.296 | Accuracy: 0.494141 | 7.123 sec/iter
Epoch: 73 | Batch: 015 / 029 | Total loss: 1.277 | Reg loss: 0.030 | Tree loss: 1.277 | Accuracy: 0.507812 | 7.123 sec/iter
Epoch: 73 | Batch: 016 / 029 | Total loss: 1.260 | Reg loss: 0.030 | Tree loss: 1.260 | Accuracy: 0.519531 | 7.123 sec/iter
Epoch: 7

Epoch: 75 | Batch: 012 / 029 | Total loss: 1.311 | Reg loss: 0.030 | Tree loss: 1.311 | Accuracy: 0.480469 | 7.133 sec/iter
Epoch: 75 | Batch: 013 / 029 | Total loss: 1.268 | Reg loss: 0.030 | Tree loss: 1.268 | Accuracy: 0.523438 | 7.133 sec/iter
Epoch: 75 | Batch: 014 / 029 | Total loss: 1.301 | Reg loss: 0.030 | Tree loss: 1.301 | Accuracy: 0.482422 | 7.133 sec/iter
Epoch: 75 | Batch: 015 / 029 | Total loss: 1.290 | Reg loss: 0.030 | Tree loss: 1.290 | Accuracy: 0.476562 | 7.133 sec/iter
Epoch: 75 | Batch: 016 / 029 | Total loss: 1.264 | Reg loss: 0.030 | Tree loss: 1.264 | Accuracy: 0.507812 | 7.133 sec/iter
Epoch: 75 | Batch: 017 / 029 | Total loss: 1.265 | Reg loss: 0.030 | Tree loss: 1.265 | Accuracy: 0.496094 | 7.133 sec/iter
Epoch: 75 | Batch: 018 / 029 | Total loss: 1.280 | Reg loss: 0.030 | Tree loss: 1.280 | Accuracy: 0.486328 | 7.133 sec/iter
Epoch: 75 | Batch: 019 / 029 | Total loss: 1.277 | Reg loss: 0.030 | Tree loss: 1.277 | Accuracy: 0.480469 | 7.133 sec/iter
Epoch: 7

Epoch: 77 | Batch: 015 / 029 | Total loss: 1.253 | Reg loss: 0.030 | Tree loss: 1.253 | Accuracy: 0.519531 | 7.132 sec/iter
Epoch: 77 | Batch: 016 / 029 | Total loss: 1.249 | Reg loss: 0.030 | Tree loss: 1.249 | Accuracy: 0.529297 | 7.132 sec/iter
Epoch: 77 | Batch: 017 / 029 | Total loss: 1.262 | Reg loss: 0.030 | Tree loss: 1.262 | Accuracy: 0.496094 | 7.132 sec/iter
Epoch: 77 | Batch: 018 / 029 | Total loss: 1.261 | Reg loss: 0.030 | Tree loss: 1.261 | Accuracy: 0.490234 | 7.132 sec/iter
Epoch: 77 | Batch: 019 / 029 | Total loss: 1.249 | Reg loss: 0.030 | Tree loss: 1.249 | Accuracy: 0.486328 | 7.132 sec/iter
Epoch: 77 | Batch: 020 / 029 | Total loss: 1.281 | Reg loss: 0.030 | Tree loss: 1.281 | Accuracy: 0.458984 | 7.132 sec/iter
Epoch: 77 | Batch: 021 / 029 | Total loss: 1.260 | Reg loss: 0.030 | Tree loss: 1.260 | Accuracy: 0.484375 | 7.131 sec/iter
Epoch: 77 | Batch: 022 / 029 | Total loss: 1.242 | Reg loss: 0.030 | Tree loss: 1.242 | Accuracy: 0.498047 | 7.13 sec/iter
Epoch: 77

Epoch: 79 | Batch: 018 / 029 | Total loss: 1.277 | Reg loss: 0.029 | Tree loss: 1.277 | Accuracy: 0.457031 | 7.124 sec/iter
Epoch: 79 | Batch: 019 / 029 | Total loss: 1.252 | Reg loss: 0.030 | Tree loss: 1.252 | Accuracy: 0.478516 | 7.124 sec/iter
Epoch: 79 | Batch: 020 / 029 | Total loss: 1.234 | Reg loss: 0.030 | Tree loss: 1.234 | Accuracy: 0.505859 | 7.124 sec/iter
Epoch: 79 | Batch: 021 / 029 | Total loss: 1.241 | Reg loss: 0.030 | Tree loss: 1.241 | Accuracy: 0.490234 | 7.124 sec/iter
Epoch: 79 | Batch: 022 / 029 | Total loss: 1.228 | Reg loss: 0.030 | Tree loss: 1.228 | Accuracy: 0.515625 | 7.123 sec/iter
Epoch: 79 | Batch: 023 / 029 | Total loss: 1.217 | Reg loss: 0.030 | Tree loss: 1.217 | Accuracy: 0.505859 | 7.122 sec/iter
Epoch: 79 | Batch: 024 / 029 | Total loss: 1.214 | Reg loss: 0.030 | Tree loss: 1.214 | Accuracy: 0.519531 | 7.121 sec/iter
Epoch: 79 | Batch: 025 / 029 | Total loss: 1.239 | Reg loss: 0.030 | Tree loss: 1.239 | Accuracy: 0.480469 | 7.12 sec/iter
Epoch: 79

Epoch: 81 | Batch: 021 / 029 | Total loss: 1.238 | Reg loss: 0.029 | Tree loss: 1.238 | Accuracy: 0.482422 | 7.105 sec/iter
Epoch: 81 | Batch: 022 / 029 | Total loss: 1.252 | Reg loss: 0.029 | Tree loss: 1.252 | Accuracy: 0.460938 | 7.104 sec/iter
Epoch: 81 | Batch: 023 / 029 | Total loss: 1.205 | Reg loss: 0.029 | Tree loss: 1.205 | Accuracy: 0.525391 | 7.103 sec/iter
Epoch: 81 | Batch: 024 / 029 | Total loss: 1.200 | Reg loss: 0.030 | Tree loss: 1.200 | Accuracy: 0.527344 | 7.103 sec/iter
Epoch: 81 | Batch: 025 / 029 | Total loss: 1.224 | Reg loss: 0.030 | Tree loss: 1.224 | Accuracy: 0.490234 | 7.103 sec/iter
Epoch: 81 | Batch: 026 / 029 | Total loss: 1.210 | Reg loss: 0.030 | Tree loss: 1.210 | Accuracy: 0.503906 | 7.103 sec/iter
Epoch: 81 | Batch: 027 / 029 | Total loss: 1.214 | Reg loss: 0.030 | Tree loss: 1.214 | Accuracy: 0.496094 | 7.102 sec/iter
Epoch: 81 | Batch: 028 / 029 | Total loss: 1.214 | Reg loss: 0.030 | Tree loss: 1.214 | Accuracy: 0.495069 | 7.101 sec/iter
Average 

Epoch: 83 | Batch: 024 / 029 | Total loss: 1.200 | Reg loss: 0.029 | Tree loss: 1.200 | Accuracy: 0.505859 | 7.101 sec/iter
Epoch: 83 | Batch: 025 / 029 | Total loss: 1.229 | Reg loss: 0.029 | Tree loss: 1.229 | Accuracy: 0.470703 | 7.1 sec/iter
Epoch: 83 | Batch: 026 / 029 | Total loss: 1.202 | Reg loss: 0.029 | Tree loss: 1.202 | Accuracy: 0.511719 | 7.099 sec/iter
Epoch: 83 | Batch: 027 / 029 | Total loss: 1.175 | Reg loss: 0.029 | Tree loss: 1.175 | Accuracy: 0.548828 | 7.099 sec/iter
Epoch: 83 | Batch: 028 / 029 | Total loss: 1.221 | Reg loss: 0.029 | Tree loss: 1.221 | Accuracy: 0.471400 | 7.099 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 84 | Batch: 000 / 029 | Tota

Epoch: 85 | Batch: 027 / 029 | Total loss: 1.179 | Reg loss: 0.029 | Tree loss: 1.179 | Accuracy: 0.527344 | 7.104 sec/iter
Epoch: 85 | Batch: 028 / 029 | Total loss: 1.215 | Reg loss: 0.029 | Tree loss: 1.215 | Accuracy: 0.473373 | 7.104 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 86 | Batch: 000 / 029 | Total loss: 1.367 | Reg loss: 0.029 | Tree loss: 1.367 | Accuracy: 0.464844 | 7.11 sec/iter
Epoch: 86 | Batch: 001 / 029 | Total loss: 1.322 | Reg loss: 0.029 | Tree loss: 1.322 | Accuracy: 0.492188 | 7.11 sec/iter
Epoch: 86 | Batch: 002 / 029 | Total loss: 1.306 | Reg loss: 0.029 | Tree loss: 1.306 | Accuracy: 0.511719 | 7.11 sec/iter
Epoch: 86 | Batch: 003 / 029 | Total

layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 88 | Batch: 000 / 029 | Total loss: 1.312 | Reg loss: 0.029 | Tree loss: 1.312 | Accuracy: 0.521484 | 7.106 sec/iter
Epoch: 88 | Batch: 001 / 029 | Total loss: 1.371 | Reg loss: 0.029 | Tree loss: 1.371 | Accuracy: 0.464844 | 7.106 sec/iter
Epoch: 88 | Batch: 002 / 029 | Total loss: 1.311 | Reg loss: 0.029 | Tree loss: 1.311 | Accuracy: 0.500000 | 7.106 sec/iter
Epoch: 88 | Batch: 003 / 029 | Total loss: 1.283 | Reg loss: 0.029 | Tree loss: 1.283 | Accuracy: 0.525391 | 7.106 sec/iter
Epoch: 88 | Batch: 004 / 029 | Total loss: 1.311 | Reg loss: 0.029 | Tree loss: 1.311 | Accuracy: 0.472656 | 7.106 sec/iter
Epoch: 88 | Batch: 005 / 029 | Total loss: 1.326 | Reg loss: 0.029 | Tree loss: 1.326 | Accuracy: 0.470703 | 7.106 sec/iter
Epoch: 88 | Batch: 006 / 029 | Total loss: 1.298 | Reg loss: 0.029 | Tree loss: 1.298 | Accuracy: 0.480469 | 7.106 sec/iter
Epoch: 88 | Batch: 007 / 029 | Total loss: 1.262 | Reg loss: 0.029 | Tree l

Epoch: 90 | Batch: 003 / 029 | Total loss: 1.292 | Reg loss: 0.029 | Tree loss: 1.292 | Accuracy: 0.529297 | 7.103 sec/iter
Epoch: 90 | Batch: 004 / 029 | Total loss: 1.313 | Reg loss: 0.029 | Tree loss: 1.313 | Accuracy: 0.482422 | 7.103 sec/iter
Epoch: 90 | Batch: 005 / 029 | Total loss: 1.292 | Reg loss: 0.029 | Tree loss: 1.292 | Accuracy: 0.482422 | 7.104 sec/iter
Epoch: 90 | Batch: 006 / 029 | Total loss: 1.252 | Reg loss: 0.029 | Tree loss: 1.252 | Accuracy: 0.525391 | 7.104 sec/iter
Epoch: 90 | Batch: 007 / 029 | Total loss: 1.239 | Reg loss: 0.029 | Tree loss: 1.239 | Accuracy: 0.537109 | 7.104 sec/iter
Epoch: 90 | Batch: 008 / 029 | Total loss: 1.269 | Reg loss: 0.029 | Tree loss: 1.269 | Accuracy: 0.498047 | 7.103 sec/iter
Epoch: 90 | Batch: 009 / 029 | Total loss: 1.291 | Reg loss: 0.029 | Tree loss: 1.291 | Accuracy: 0.476562 | 7.103 sec/iter
Epoch: 90 | Batch: 010 / 029 | Total loss: 1.279 | Reg loss: 0.029 | Tree loss: 1.279 | Accuracy: 0.470703 | 7.103 sec/iter
Epoch: 9

Epoch: 92 | Batch: 006 / 029 | Total loss: 1.302 | Reg loss: 0.028 | Tree loss: 1.302 | Accuracy: 0.480469 | 7.099 sec/iter
Epoch: 92 | Batch: 007 / 029 | Total loss: 1.271 | Reg loss: 0.028 | Tree loss: 1.271 | Accuracy: 0.496094 | 7.099 sec/iter
Epoch: 92 | Batch: 008 / 029 | Total loss: 1.293 | Reg loss: 0.029 | Tree loss: 1.293 | Accuracy: 0.468750 | 7.099 sec/iter
Epoch: 92 | Batch: 009 / 029 | Total loss: 1.271 | Reg loss: 0.029 | Tree loss: 1.271 | Accuracy: 0.490234 | 7.099 sec/iter
Epoch: 92 | Batch: 010 / 029 | Total loss: 1.250 | Reg loss: 0.029 | Tree loss: 1.250 | Accuracy: 0.501953 | 7.099 sec/iter
Epoch: 92 | Batch: 011 / 029 | Total loss: 1.232 | Reg loss: 0.029 | Tree loss: 1.232 | Accuracy: 0.513672 | 7.099 sec/iter
Epoch: 92 | Batch: 012 / 029 | Total loss: 1.221 | Reg loss: 0.029 | Tree loss: 1.221 | Accuracy: 0.515625 | 7.099 sec/iter
Epoch: 92 | Batch: 013 / 029 | Total loss: 1.250 | Reg loss: 0.029 | Tree loss: 1.250 | Accuracy: 0.470703 | 7.099 sec/iter
Epoch: 9

Epoch: 94 | Batch: 009 / 029 | Total loss: 1.249 | Reg loss: 0.028 | Tree loss: 1.249 | Accuracy: 0.509766 | 7.092 sec/iter
Epoch: 94 | Batch: 010 / 029 | Total loss: 1.262 | Reg loss: 0.028 | Tree loss: 1.262 | Accuracy: 0.474609 | 7.092 sec/iter
Epoch: 94 | Batch: 011 / 029 | Total loss: 1.238 | Reg loss: 0.028 | Tree loss: 1.238 | Accuracy: 0.498047 | 7.092 sec/iter
Epoch: 94 | Batch: 012 / 029 | Total loss: 1.255 | Reg loss: 0.028 | Tree loss: 1.255 | Accuracy: 0.484375 | 7.092 sec/iter
Epoch: 94 | Batch: 013 / 029 | Total loss: 1.245 | Reg loss: 0.029 | Tree loss: 1.245 | Accuracy: 0.476562 | 7.092 sec/iter
Epoch: 94 | Batch: 014 / 029 | Total loss: 1.226 | Reg loss: 0.029 | Tree loss: 1.226 | Accuracy: 0.511719 | 7.092 sec/iter
Epoch: 94 | Batch: 015 / 029 | Total loss: 1.259 | Reg loss: 0.029 | Tree loss: 1.259 | Accuracy: 0.451172 | 7.092 sec/iter
Epoch: 94 | Batch: 016 / 029 | Total loss: 1.195 | Reg loss: 0.029 | Tree loss: 1.195 | Accuracy: 0.525391 | 7.092 sec/iter
Epoch: 9

Epoch: 96 | Batch: 012 / 029 | Total loss: 1.220 | Reg loss: 0.028 | Tree loss: 1.220 | Accuracy: 0.519531 | 7.08 sec/iter
Epoch: 96 | Batch: 013 / 029 | Total loss: 1.212 | Reg loss: 0.028 | Tree loss: 1.212 | Accuracy: 0.498047 | 7.079 sec/iter
Epoch: 96 | Batch: 014 / 029 | Total loss: 1.234 | Reg loss: 0.028 | Tree loss: 1.234 | Accuracy: 0.484375 | 7.079 sec/iter
Epoch: 96 | Batch: 015 / 029 | Total loss: 1.212 | Reg loss: 0.028 | Tree loss: 1.212 | Accuracy: 0.509766 | 7.078 sec/iter
Epoch: 96 | Batch: 016 / 029 | Total loss: 1.249 | Reg loss: 0.028 | Tree loss: 1.249 | Accuracy: 0.457031 | 7.077 sec/iter
Epoch: 96 | Batch: 017 / 029 | Total loss: 1.196 | Reg loss: 0.029 | Tree loss: 1.196 | Accuracy: 0.503906 | 7.076 sec/iter
Epoch: 96 | Batch: 018 / 029 | Total loss: 1.210 | Reg loss: 0.029 | Tree loss: 1.210 | Accuracy: 0.501953 | 7.075 sec/iter
Epoch: 96 | Batch: 019 / 029 | Total loss: 1.205 | Reg loss: 0.029 | Tree loss: 1.205 | Accuracy: 0.490234 | 7.074 sec/iter
Epoch: 96

Epoch: 98 | Batch: 015 / 029 | Total loss: 1.184 | Reg loss: 0.028 | Tree loss: 1.184 | Accuracy: 0.533203 | 7.035 sec/iter
Epoch: 98 | Batch: 016 / 029 | Total loss: 1.194 | Reg loss: 0.028 | Tree loss: 1.194 | Accuracy: 0.519531 | 7.034 sec/iter
Epoch: 98 | Batch: 017 / 029 | Total loss: 1.207 | Reg loss: 0.028 | Tree loss: 1.207 | Accuracy: 0.486328 | 7.033 sec/iter
Epoch: 98 | Batch: 018 / 029 | Total loss: 1.233 | Reg loss: 0.028 | Tree loss: 1.233 | Accuracy: 0.468750 | 7.032 sec/iter
Epoch: 98 | Batch: 019 / 029 | Total loss: 1.229 | Reg loss: 0.028 | Tree loss: 1.229 | Accuracy: 0.462891 | 7.031 sec/iter
Epoch: 98 | Batch: 020 / 029 | Total loss: 1.203 | Reg loss: 0.028 | Tree loss: 1.203 | Accuracy: 0.490234 | 7.031 sec/iter
Epoch: 98 | Batch: 021 / 029 | Total loss: 1.180 | Reg loss: 0.028 | Tree loss: 1.180 | Accuracy: 0.503906 | 7.03 sec/iter
Epoch: 98 | Batch: 022 / 029 | Total loss: 1.208 | Reg loss: 0.029 | Tree loss: 1.208 | Accuracy: 0.458984 | 7.029 sec/iter
Epoch: 98

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 11.886291486291487


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 3465


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


14843








Average comprehensibility: 60.38787878787879
std comprehensibility: 4.175330936554858


var comprehensibility: 17.433388429752068
minimum comprehensibility: 30
maximum comprehensibility: 68
