In [35]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
k = 256
tree_depth = 6
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.214082717895508 | KNN Loss: 6.232590198516846 | BCE Loss: 1.9814924001693726
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.18209171295166 | KNN Loss: 6.232646942138672 | BCE Loss: 1.9494444131851196
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.180830955505371 | KNN Loss: 6.232651233673096 | BCE Loss: 1.948180079460144
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.165058135986328 | KNN Loss: 6.232604503631592 | BCE Loss: 1.9324537515640259
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.141905784606934 | KNN Loss: 6.232545852661133 | BCE Loss: 1.9093595743179321
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.150469779968262 | KNN Loss: 6.232403755187988 | BCE Loss: 1.9180659055709839
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.167644500732422 | KNN Loss: 6.232351779937744 | BCE Loss: 1.9352922439575195
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.160579681396484 | KNN Loss: 6.232356071472168 | BCE Loss: 1.9282238

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.382832050323486 | KNN Loss: 6.2128143310546875 | BCE Loss: 1.1700177192687988
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.365210056304932 | KNN Loss: 6.208114147186279 | BCE Loss: 1.1570959091186523
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.319311141967773 | KNN Loss: 6.20708703994751 | BCE Loss: 1.1122238636016846
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 7.334786415100098 | KNN Loss: 6.203920841217041 | BCE Loss: 1.1308658123016357
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 7.328578948974609 | KNN Loss: 6.200413227081299 | BCE Loss: 1.1281657218933105
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 7.3075175285339355 | KNN Loss: 6.196969509124756 | BCE Loss: 1.1105480194091797
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 7.304122447967529 | KNN Loss: 6.196120738983154 | BCE Loss: 1.1080015897750854
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 7.274090766906738 | KNN Loss: 6.193741321563721 | BCE Los

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.817303657531738 | KNN Loss: 5.757551193237305 | BCE Loss: 1.0597525835037231
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.8298659324646 | KNN Loss: 5.766357898712158 | BCE Loss: 1.0635080337524414
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.808271408081055 | KNN Loss: 5.757097244262695 | BCE Loss: 1.0511741638183594
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.804651737213135 | KNN Loss: 5.737254619598389 | BCE Loss: 1.067397117614746
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.8376383781433105 | KNN Loss: 5.74813985824585 | BCE Loss: 1.0894984006881714
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.7630462646484375 | KNN Loss: 5.725612640380859 | BCE Loss: 1.037433385848999
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.782855987548828 | KNN Loss: 5.726426124572754 | BCE Loss: 1.0564301013946533
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.758255958557129 | KNN Loss: 5.711081504821777 | BCE Loss: 

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.665799617767334 | KNN Loss: 5.621352672576904 | BCE Loss: 1.0444470643997192
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.676266670227051 | KNN Loss: 5.643748760223389 | BCE Loss: 1.0325177907943726
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.670083045959473 | KNN Loss: 5.616166591644287 | BCE Loss: 1.0539166927337646
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.766299247741699 | KNN Loss: 5.700475215911865 | BCE Loss: 1.0658239126205444
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.715966701507568 | KNN Loss: 5.669598579406738 | BCE Loss: 1.04636812210083
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.698056221008301 | KNN Loss: 5.666419982910156 | BCE Loss: 1.0316359996795654
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.733022212982178 | KNN Loss: 5.679352760314941 | BCE Loss: 1.0536694526672363
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.645362377166748 | KNN Loss: 5.630128383636475 | BCE Loss:

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.698171615600586 | KNN Loss: 5.619533538818359 | BCE Loss: 1.0786380767822266
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.653359413146973 | KNN Loss: 5.622973442077637 | BCE Loss: 1.0303857326507568
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.652497291564941 | KNN Loss: 5.626805305480957 | BCE Loss: 1.025692105293274
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.656472206115723 | KNN Loss: 5.606506824493408 | BCE Loss: 1.0499656200408936
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.697007179260254 | KNN Loss: 5.646609783172607 | BCE Loss: 1.0503973960876465
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.674751281738281 | KNN Loss: 5.63383150100708 | BCE Loss: 1.0409200191497803
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.68321418762207 | KNN Loss: 5.612985134124756 | BCE Loss: 1.0702292919158936
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 6.666189193725586 | KNN Loss: 5.630638122558594 | BCE Loss: 

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.629484176635742 | KNN Loss: 5.595648765563965 | BCE Loss: 1.0338356494903564
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.626345634460449 | KNN Loss: 5.591723918914795 | BCE Loss: 1.0346214771270752
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.645418167114258 | KNN Loss: 5.607794761657715 | BCE Loss: 1.037623405456543
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.714426517486572 | KNN Loss: 5.657421588897705 | BCE Loss: 1.0570048093795776
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.633336067199707 | KNN Loss: 5.59903621673584 | BCE Loss: 1.0342998504638672
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.661944389343262 | KNN Loss: 5.621927738189697 | BCE Loss: 1.0400164127349854
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.698716163635254 | KNN Loss: 5.634062767028809 | BCE Loss: 1.0646536350250244
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 6.7162041664123535 | KNN Loss: 5.650175094604492 | BCE Loss:

Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 6.693574905395508 | KNN Loss: 5.631575107574463 | BCE Loss: 1.062000036239624
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.6827006340026855 | KNN Loss: 5.635884761810303 | BCE Loss: 1.0468159914016724
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.69195556640625 | KNN Loss: 5.6459879875183105 | BCE Loss: 1.0459678173065186
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.64630126953125 | KNN Loss: 5.613107204437256 | BCE Loss: 1.0331941843032837
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.692854881286621 | KNN Loss: 5.65071964263916 | BCE Loss: 1.0421350002288818
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.668036460876465 | KNN Loss: 5.6315155029296875 | BCE Loss: 1.0365209579467773
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.661787033081055 | KNN Loss: 5.6033549308776855 | BCE Loss: 1.0584322214126587
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.6590962409973145 | KNN Loss: 5.595683574676514 | BCE Los

Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 6.672493934631348 | KNN Loss: 5.631849765777588 | BCE Loss: 1.0406444072723389
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 6.68533182144165 | KNN Loss: 5.639110565185547 | BCE Loss: 1.0462212562561035
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.659724712371826 | KNN Loss: 5.6058831214904785 | BCE Loss: 1.0538417100906372
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.640063285827637 | KNN Loss: 5.595051288604736 | BCE Loss: 1.0450122356414795
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.692999839782715 | KNN Loss: 5.6311163902282715 | BCE Loss: 1.0618832111358643
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.657026290893555 | KNN Loss: 5.604086875915527 | BCE Loss: 1.0529394149780273
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.680678844451904 | KNN Loss: 5.628357887268066 | BCE Loss: 1.052320957183838
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.696191787719727 | KNN Loss: 5.639223098754883 | BCE Los

Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 6.664100170135498 | KNN Loss: 5.623289108276367 | BCE Loss: 1.0408110618591309
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 6.647641181945801 | KNN Loss: 5.602823734283447 | BCE Loss: 1.044817328453064
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.627521514892578 | KNN Loss: 5.59689474105835 | BCE Loss: 1.030626893043518
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.645143985748291 | KNN Loss: 5.594326972961426 | BCE Loss: 1.0508170127868652
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.694102764129639 | KNN Loss: 5.666804790496826 | BCE Loss: 1.027298092842102
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.763680458068848 | KNN Loss: 5.734431266784668 | BCE Loss: 1.0292491912841797
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.7032976150512695 | KNN Loss: 5.658468723297119 | BCE Loss: 1.0448286533355713
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.6810078620910645 | KNN Loss: 5.637966632843018 | BCE Loss: 

Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 6.683658123016357 | KNN Loss: 5.6288743019104 | BCE Loss: 1.0547837018966675
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 6.722046852111816 | KNN Loss: 5.663301467895508 | BCE Loss: 1.0587456226348877
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 6.70414924621582 | KNN Loss: 5.654600143432617 | BCE Loss: 1.0495493412017822
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.641899108886719 | KNN Loss: 5.594277381896973 | BCE Loss: 1.047621726989746
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.705392837524414 | KNN Loss: 5.687831401824951 | BCE Loss: 1.017561435699463
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.659313201904297 | KNN Loss: 5.62041711807251 | BCE Loss: 1.038896083831787
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.730095863342285 | KNN Loss: 5.668910980224609 | BCE Loss: 1.0611846446990967
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.609463691711426 | KNN Loss: 5.599545955657959 | BCE Loss: 1.009

Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 6.642608165740967 | KNN Loss: 5.6060051918029785 | BCE Loss: 1.0366028547286987
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 6.741386413574219 | KNN Loss: 5.677004337310791 | BCE Loss: 1.0643818378448486
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 6.639183044433594 | KNN Loss: 5.590975761413574 | BCE Loss: 1.0482075214385986
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.624815940856934 | KNN Loss: 5.597254753112793 | BCE Loss: 1.0275614261627197
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.695888519287109 | KNN Loss: 5.635781764984131 | BCE Loss: 1.0601067543029785
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.660679340362549 | KNN Loss: 5.596388816833496 | BCE Loss: 1.0642904043197632
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.769515514373779 | KNN Loss: 5.711103916168213 | BCE Loss: 1.0584114789962769
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.655019760131836 | KNN Loss: 5.611400604248047 

Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 6.678328514099121 | KNN Loss: 5.629777431488037 | BCE Loss: 1.048551082611084
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 6.655971527099609 | KNN Loss: 5.613584995269775 | BCE Loss: 1.0423866510391235
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 6.6489949226379395 | KNN Loss: 5.593647480010986 | BCE Loss: 1.0553474426269531
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.629653453826904 | KNN Loss: 5.592667579650879 | BCE Loss: 1.0369858741760254
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.696819305419922 | KNN Loss: 5.645941734313965 | BCE Loss: 1.050877332687378
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.667258262634277 | KNN Loss: 5.598446846008301 | BCE Loss: 1.0688115358352661
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.6602325439453125 | KNN Loss: 5.6095051765441895 | BCE Loss: 1.0507274866104126
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.633196830749512 | KNN Loss: 5.603232383728027 |

Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 6.6838274002075195 | KNN Loss: 5.640278339385986 | BCE Loss: 1.0435492992401123
Epoch   129: reducing learning rate of group 0 to 5.8824e-04.
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 6.678582191467285 | KNN Loss: 5.610419273376465 | BCE Loss: 1.0681626796722412
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 6.720404148101807 | KNN Loss: 5.661571025848389 | BCE Loss: 1.0588332414627075
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.707280158996582 | KNN Loss: 5.63295841217041 | BCE Loss: 1.0743218660354614
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.618492603302002 | KNN Loss: 5.598173141479492 | BCE Loss: 1.0203193426132202
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.632275581359863 | KNN Loss: 5.598455905914307 | BCE Loss: 1.0338199138641357
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.628574848175049 | KNN Loss: 5.598291873931885 | BCE Loss: 1.030282974243164
Epoch 130 / 500 | iteration 0 / 30 | T

Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 6.7515716552734375 | KNN Loss: 5.6953020095825195 | BCE Loss: 1.0562697649002075
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 6.707491874694824 | KNN Loss: 5.649453163146973 | BCE Loss: 1.0580384731292725
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 6.681277275085449 | KNN Loss: 5.642602443695068 | BCE Loss: 1.03867506980896
Epoch   140: reducing learning rate of group 0 to 4.1177e-04.
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.649196624755859 | KNN Loss: 5.592665672302246 | BCE Loss: 1.0565309524536133
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.717127323150635 | KNN Loss: 5.6642937660217285 | BCE Loss: 1.0528335571289062
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.63722562789917 | KNN Loss: 5.596168518066406 | BCE Loss: 1.0410571098327637
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.703810214996338 | KNN Loss: 5.6550798416137695 | BCE Loss: 1.0487302541732788
Epoch 140 / 500 | iteration 20 / 30 

Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 6.66447114944458 | KNN Loss: 5.623578071594238 | BCE Loss: 1.0408930778503418
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 6.648674011230469 | KNN Loss: 5.615788459777832 | BCE Loss: 1.0328853130340576
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 6.769718647003174 | KNN Loss: 5.743386745452881 | BCE Loss: 1.026331901550293
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.6746368408203125 | KNN Loss: 5.597639083862305 | BCE Loss: 1.0769977569580078
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.636116981506348 | KNN Loss: 5.593885898590088 | BCE Loss: 1.0422309637069702
Epoch   151: reducing learning rate of group 0 to 2.8824e-04.
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.675058841705322 | KNN Loss: 5.6185808181762695 | BCE Loss: 1.0564779043197632
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.666398048400879 | KNN Loss: 5.64592170715332 | BCE Loss: 1.0204764604568481
Epoch 151 / 500 | iteration 10 / 30 | T

Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 6.736371994018555 | KNN Loss: 5.700174331665039 | BCE Loss: 1.0361976623535156
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 6.644770622253418 | KNN Loss: 5.600738048553467 | BCE Loss: 1.044032335281372
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 6.67933988571167 | KNN Loss: 5.614498615264893 | BCE Loss: 1.0648412704467773
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.73223876953125 | KNN Loss: 5.656249523162842 | BCE Loss: 1.075989007949829
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.681676864624023 | KNN Loss: 5.630151748657227 | BCE Loss: 1.051525354385376
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.716960906982422 | KNN Loss: 5.653599262237549 | BCE Loss: 1.0633618831634521
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.643570899963379 | KNN Loss: 5.605375289916992 | BCE Loss: 1.0381954908370972
Epoch   162: reducing learning rate of group 0 to 2.0177e-04.
Epoch 162 / 500 | iteration 0 / 30 | Total

Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 6.672609329223633 | KNN Loss: 5.614200592041016 | BCE Loss: 1.0584087371826172
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 6.671189308166504 | KNN Loss: 5.604657173156738 | BCE Loss: 1.0665323734283447
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 6.685840606689453 | KNN Loss: 5.62651252746582 | BCE Loss: 1.0593281984329224
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.709750175476074 | KNN Loss: 5.644826889038086 | BCE Loss: 1.0649232864379883
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.632444381713867 | KNN Loss: 5.5940728187561035 | BCE Loss: 1.0383718013763428
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.663352012634277 | KNN Loss: 5.617891788482666 | BCE Loss: 1.0454599857330322
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.643396377563477 | KNN Loss: 5.596458435058594 | BCE Loss: 1.0469379425048828
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.7014055252075195 | KNN Loss: 5.6603922843933105

Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 6.6579790115356445 | KNN Loss: 5.60611629486084 | BCE Loss: 1.0518629550933838
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 6.65657901763916 | KNN Loss: 5.617825031280518 | BCE Loss: 1.0387539863586426
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 6.627588748931885 | KNN Loss: 5.598017692565918 | BCE Loss: 1.0295711755752563
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.6400251388549805 | KNN Loss: 5.601463794708252 | BCE Loss: 1.038561224937439
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.749052047729492 | KNN Loss: 5.688633441925049 | BCE Loss: 1.0604186058044434
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.694167613983154 | KNN Loss: 5.645124912261963 | BCE Loss: 1.0490427017211914
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.670742034912109 | KNN Loss: 5.617212772369385 | BCE Loss: 1.0535290241241455
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.670945167541504 | KNN Loss: 5.626399993896484 | B

Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 6.779690742492676 | KNN Loss: 5.72421932220459 | BCE Loss: 1.055471658706665
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 6.629465103149414 | KNN Loss: 5.598260402679443 | BCE Loss: 1.0312045812606812
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 6.6223859786987305 | KNN Loss: 5.592004299163818 | BCE Loss: 1.0303819179534912
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.634976863861084 | KNN Loss: 5.60638952255249 | BCE Loss: 1.0285872220993042
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.666520595550537 | KNN Loss: 5.605529308319092 | BCE Loss: 1.0609912872314453
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.65282678604126 | KNN Loss: 5.595981121063232 | BCE Loss: 1.0568456649780273
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.672272205352783 | KNN Loss: 5.628424167633057 | BCE Loss: 1.0438481569290161
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.707898139953613 | KNN Loss: 5.636188507080078 | BCE

Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 6.7042083740234375 | KNN Loss: 5.667392253875732 | BCE Loss: 1.0368163585662842
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 6.643509387969971 | KNN Loss: 5.610869884490967 | BCE Loss: 1.0326393842697144
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 6.670121192932129 | KNN Loss: 5.634101390838623 | BCE Loss: 1.036020040512085
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.665125846862793 | KNN Loss: 5.618555068969727 | BCE Loss: 1.0465705394744873
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.727392196655273 | KNN Loss: 5.683957099914551 | BCE Loss: 1.0434348583221436
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.670970916748047 | KNN Loss: 5.602113246917725 | BCE Loss: 1.0688579082489014
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.6792216300964355 | KNN Loss: 5.635622024536133 | BCE Loss: 1.0435994863510132
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.770442485809326 | KNN Loss: 5.7156147956848145

Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 6.657923698425293 | KNN Loss: 5.608719348907471 | BCE Loss: 1.0492043495178223
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 6.640993118286133 | KNN Loss: 5.596783638000488 | BCE Loss: 1.0442092418670654
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 6.7549238204956055 | KNN Loss: 5.7053141593933105 | BCE Loss: 1.049609661102295
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.639946937561035 | KNN Loss: 5.597443580627441 | BCE Loss: 1.0425031185150146
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.697096824645996 | KNN Loss: 5.647896766662598 | BCE Loss: 1.0491999387741089
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.641568660736084 | KNN Loss: 5.60659122467041 | BCE Loss: 1.0349774360656738
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.70533561706543 | KNN Loss: 5.639509201049805 | BCE Loss: 1.065826416015625
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.696913719177246 | KNN Loss: 5.615418910980225 | BC

Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 6.689937591552734 | KNN Loss: 5.631084442138672 | BCE Loss: 1.058853268623352
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 6.650698661804199 | KNN Loss: 5.60921049118042 | BCE Loss: 1.0414884090423584
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 6.738119125366211 | KNN Loss: 5.667929172515869 | BCE Loss: 1.070190191268921
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.695267677307129 | KNN Loss: 5.671316146850586 | BCE Loss: 1.0239512920379639
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.704167366027832 | KNN Loss: 5.607880115509033 | BCE Loss: 1.0962872505187988
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.656418800354004 | KNN Loss: 5.602319240570068 | BCE Loss: 1.0540997982025146
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.713015556335449 | KNN Loss: 5.676568508148193 | BCE Loss: 1.0364471673965454
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.6995391845703125 | KNN Loss: 5.6663994789123535 | B

Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 6.648608684539795 | KNN Loss: 5.6197428703308105 | BCE Loss: 1.0288658142089844
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 6.810937881469727 | KNN Loss: 5.752829074859619 | BCE Loss: 1.058108925819397
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 6.680927276611328 | KNN Loss: 5.615092754364014 | BCE Loss: 1.0658347606658936
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.7010650634765625 | KNN Loss: 5.643383026123047 | BCE Loss: 1.0576822757720947
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.706835746765137 | KNN Loss: 5.63114070892334 | BCE Loss: 1.075695276260376
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.728301048278809 | KNN Loss: 5.6612653732299805 | BCE Loss: 1.067035436630249
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.652581214904785 | KNN Loss: 5.594162940979004 | BCE Loss: 1.0584180355072021
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.668451309204102 | KNN Loss: 5.625827789306641 | 

Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 6.6854400634765625 | KNN Loss: 5.608436584472656 | BCE Loss: 1.0770034790039062
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 6.6690473556518555 | KNN Loss: 5.621182918548584 | BCE Loss: 1.0478641986846924
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 6.645016193389893 | KNN Loss: 5.594892978668213 | BCE Loss: 1.0501232147216797
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.6126909255981445 | KNN Loss: 5.5986328125 | BCE Loss: 1.0140578746795654
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.687328815460205 | KNN Loss: 5.64163875579834 | BCE Loss: 1.0456899404525757
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.720671653747559 | KNN Loss: 5.638343334197998 | BCE Loss: 1.0823283195495605
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.695827484130859 | KNN Loss: 5.638443946838379 | BCE Loss: 1.0573837757110596
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.623012065887451 | KNN Loss: 5.617652416229248 | BCE

Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 6.657567977905273 | KNN Loss: 5.606690406799316 | BCE Loss: 1.0508778095245361
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 6.669547080993652 | KNN Loss: 5.639656066894531 | BCE Loss: 1.029890775680542
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 6.631656646728516 | KNN Loss: 5.597186088562012 | BCE Loss: 1.0344703197479248
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.69064474105835 | KNN Loss: 5.653416633605957 | BCE Loss: 1.0372281074523926
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.7557268142700195 | KNN Loss: 5.697240829467773 | BCE Loss: 1.058485746383667
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.673306941986084 | KNN Loss: 5.616492748260498 | BCE Loss: 1.056814193725586
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.661456108093262 | KNN Loss: 5.616448402404785 | BCE Loss: 1.0450079441070557
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.6306352615356445 | KNN Loss: 5.603365421295166 | BC

Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 6.6496477127075195 | KNN Loss: 5.595839977264404 | BCE Loss: 1.0538079738616943
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 6.645152568817139 | KNN Loss: 5.618465423583984 | BCE Loss: 1.0266871452331543
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 6.624072074890137 | KNN Loss: 5.59575080871582 | BCE Loss: 1.0283215045928955
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.6714959144592285 | KNN Loss: 5.6270751953125 | BCE Loss: 1.044420838356018
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.670822620391846 | KNN Loss: 5.594536304473877 | BCE Loss: 1.0762863159179688
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.654727935791016 | KNN Loss: 5.601110935211182 | BCE Loss: 1.053617000579834
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.651573181152344 | KNN Loss: 5.612483978271484 | BCE Loss: 1.0390890836715698
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.658444404602051 | KNN Loss: 5.6038665771484375 | B

Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 6.64456033706665 | KNN Loss: 5.612809181213379 | BCE Loss: 1.031751036643982
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 6.633317947387695 | KNN Loss: 5.6027679443359375 | BCE Loss: 1.030550241470337
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 6.683040618896484 | KNN Loss: 5.617974758148193 | BCE Loss: 1.0650660991668701
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.65429162979126 | KNN Loss: 5.619722843170166 | BCE Loss: 1.0345687866210938
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.691657543182373 | KNN Loss: 5.634257793426514 | BCE Loss: 1.0573997497558594
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.649815559387207 | KNN Loss: 5.618764400482178 | BCE Loss: 1.0310509204864502
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.642840385437012 | KNN Loss: 5.615389347076416 | BCE Loss: 1.0274512767791748
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.696717262268066 | KNN Loss: 5.637483596801758 | BCE

Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 6.771689414978027 | KNN Loss: 5.730332374572754 | BCE Loss: 1.0413570404052734
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 6.709698677062988 | KNN Loss: 5.6244940757751465 | BCE Loss: 1.0852043628692627
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 6.7591233253479 | KNN Loss: 5.699336528778076 | BCE Loss: 1.0597867965698242
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.687190055847168 | KNN Loss: 5.620724678039551 | BCE Loss: 1.0664656162261963
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.637509346008301 | KNN Loss: 5.600642204284668 | BCE Loss: 1.036867380142212
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.731346607208252 | KNN Loss: 5.693033218383789 | BCE Loss: 1.038313388824463
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.664642810821533 | KNN Loss: 5.597926139831543 | BCE Loss: 1.0667165517807007
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.682305335998535 | KNN Loss: 5.6266913414001465 | BC

Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 6.6573591232299805 | KNN Loss: 5.621152400970459 | BCE Loss: 1.036206603050232
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 6.644556999206543 | KNN Loss: 5.608725547790527 | BCE Loss: 1.0358316898345947
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 6.765723705291748 | KNN Loss: 5.718708038330078 | BCE Loss: 1.04701566696167
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.657115936279297 | KNN Loss: 5.594414234161377 | BCE Loss: 1.062701940536499
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.647456169128418 | KNN Loss: 5.620770454406738 | BCE Loss: 1.0266857147216797
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.661035060882568 | KNN Loss: 5.615723133087158 | BCE Loss: 1.0453120470046997
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.67130184173584 | KNN Loss: 5.6194987297058105 | BCE Loss: 1.0518031120300293
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.653002738952637 | KNN Loss: 5.616968631744385 | BC

Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 6.633763313293457 | KNN Loss: 5.593297958374023 | BCE Loss: 1.0404654741287231
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 6.754829406738281 | KNN Loss: 5.69495964050293 | BCE Loss: 1.059869647026062
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 6.694135665893555 | KNN Loss: 5.628715991973877 | BCE Loss: 1.0654199123382568
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.6176676750183105 | KNN Loss: 5.596700191497803 | BCE Loss: 1.0209676027297974
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.694865703582764 | KNN Loss: 5.660048961639404 | BCE Loss: 1.0348167419433594
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.619037628173828 | KNN Loss: 5.601258277893066 | BCE Loss: 1.0177795886993408
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.774803638458252 | KNN Loss: 5.738369464874268 | BCE Loss: 1.0364341735839844
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.612616539001465 | KNN Loss: 5.597124099731445 | B

Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 6.708527088165283 | KNN Loss: 5.649270534515381 | BCE Loss: 1.059256672859192
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 6.645556449890137 | KNN Loss: 5.595376014709473 | BCE Loss: 1.050180196762085
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 6.6319379806518555 | KNN Loss: 5.604128837585449 | BCE Loss: 1.0278091430664062
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.673250675201416 | KNN Loss: 5.598572731018066 | BCE Loss: 1.0746779441833496
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.7458906173706055 | KNN Loss: 5.669541835784912 | BCE Loss: 1.0763490200042725
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.619791030883789 | KNN Loss: 5.611082553863525 | BCE Loss: 1.0087083578109741
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.68602180480957 | KNN Loss: 5.656915664672852 | BCE Loss: 1.0291062593460083
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.647005558013916 | KNN Loss: 5.624006271362305 | B

Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 6.7315473556518555 | KNN Loss: 5.689108848571777 | BCE Loss: 1.0424386262893677
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 6.647971153259277 | KNN Loss: 5.602514743804932 | BCE Loss: 1.0454561710357666
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 6.70372200012207 | KNN Loss: 5.6313066482543945 | BCE Loss: 1.0724151134490967
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.655806541442871 | KNN Loss: 5.602598667144775 | BCE Loss: 1.0532081127166748
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.665623664855957 | KNN Loss: 5.626487731933594 | BCE Loss: 1.0391356945037842
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.710417747497559 | KNN Loss: 5.678835391998291 | BCE Loss: 1.0315825939178467
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.741229057312012 | KNN Loss: 5.651371479034424 | BCE Loss: 1.089857578277588
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.70876407623291 | KNN Loss: 5.661553382873535 | 

Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 6.6407976150512695 | KNN Loss: 5.592962741851807 | BCE Loss: 1.047835111618042
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 6.6979241371154785 | KNN Loss: 5.63923978805542 | BCE Loss: 1.0586844682693481
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 6.700593948364258 | KNN Loss: 5.651535987854004 | BCE Loss: 1.0490577220916748
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.670642852783203 | KNN Loss: 5.606019973754883 | BCE Loss: 1.0646231174468994
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.668867588043213 | KNN Loss: 5.600662708282471 | BCE Loss: 1.0682048797607422
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.66922664642334 | KNN Loss: 5.607205867767334 | BCE Loss: 1.062021017074585
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.658748149871826 | KNN Loss: 5.604900360107422 | BCE Loss: 1.0538479089736938
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.762999534606934 | KNN Loss: 5.697632312774658 | BC

Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 6.6544647216796875 | KNN Loss: 5.615653991699219 | BCE Loss: 1.0388107299804688
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 6.666289806365967 | KNN Loss: 5.6069254875183105 | BCE Loss: 1.0593643188476562
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 6.682476997375488 | KNN Loss: 5.646657943725586 | BCE Loss: 1.035819172859192
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.756270408630371 | KNN Loss: 5.703603267669678 | BCE Loss: 1.0526673793792725
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.6673479080200195 | KNN Loss: 5.609593868255615 | BCE Loss: 1.0577540397644043
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.625819206237793 | KNN Loss: 5.597403526306152 | BCE Loss: 1.0284159183502197
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.714117050170898 | KNN Loss: 5.66993522644043 | BCE Loss: 1.0441818237304688
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.725008010864258 | KNN Loss: 5.658242702484131 |

Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 6.748856544494629 | KNN Loss: 5.713974952697754 | BCE Loss: 1.0348814725875854
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 6.731478691101074 | KNN Loss: 5.663523197174072 | BCE Loss: 1.0679552555084229
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 6.686615943908691 | KNN Loss: 5.676231384277344 | BCE Loss: 1.010384440422058
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.656398773193359 | KNN Loss: 5.6079254150390625 | BCE Loss: 1.0484731197357178
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.695870876312256 | KNN Loss: 5.652743339538574 | BCE Loss: 1.0431275367736816
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.622135639190674 | KNN Loss: 5.591168403625488 | BCE Loss: 1.030967354774475
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.6399736404418945 | KNN Loss: 5.591217517852783 | BCE Loss: 1.0487563610076904
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.634133338928223 | KNN Loss: 5.611938953399658 |

Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 6.652017116546631 | KNN Loss: 5.598711967468262 | BCE Loss: 1.0533052682876587
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 6.752880096435547 | KNN Loss: 5.684060573577881 | BCE Loss: 1.068819284439087
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 6.690841197967529 | KNN Loss: 5.640128135681152 | BCE Loss: 1.050713062286377
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.649293422698975 | KNN Loss: 5.597191333770752 | BCE Loss: 1.0521020889282227
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.703794956207275 | KNN Loss: 5.662990093231201 | BCE Loss: 1.0408047437667847
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.639779090881348 | KNN Loss: 5.619815349578857 | BCE Loss: 1.0199635028839111
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.6576056480407715 | KNN Loss: 5.597208023071289 | BCE Loss: 1.060397744178772
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.750003814697266 | KNN Loss: 5.682858943939209 | BC

Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 6.6073126792907715 | KNN Loss: 5.604601860046387 | BCE Loss: 1.0027108192443848
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 6.708436965942383 | KNN Loss: 5.645650863647461 | BCE Loss: 1.062786340713501
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 6.681609153747559 | KNN Loss: 5.634986877441406 | BCE Loss: 1.0466221570968628
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.784908294677734 | KNN Loss: 5.7243218421936035 | BCE Loss: 1.06058669090271
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.7209062576293945 | KNN Loss: 5.648448944091797 | BCE Loss: 1.0724575519561768
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.649965286254883 | KNN Loss: 5.621917247772217 | BCE Loss: 1.028047800064087
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.666098594665527 | KNN Loss: 5.623793601989746 | BCE Loss: 1.0423049926757812
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.6520185470581055 | KNN Loss: 5.636840343475342 | 

Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 6.664648056030273 | KNN Loss: 5.601632118225098 | BCE Loss: 1.0630161762237549
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 6.735088348388672 | KNN Loss: 5.659753322601318 | BCE Loss: 1.0753347873687744
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 6.685731410980225 | KNN Loss: 5.635324954986572 | BCE Loss: 1.0504064559936523
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.645852565765381 | KNN Loss: 5.64146614074707 | BCE Loss: 1.0043865442276
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.66829252243042 | KNN Loss: 5.60400915145874 | BCE Loss: 1.0642833709716797
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.7177276611328125 | KNN Loss: 5.633070468902588 | BCE Loss: 1.0846569538116455
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.699888706207275 | KNN Loss: 5.634925842285156 | BCE Loss: 1.0649627447128296
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.615666389465332 | KNN Loss: 5.592231273651123 | BCE 

Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 6.66158390045166 | KNN Loss: 5.600961685180664 | BCE Loss: 1.060621976852417
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 6.678431510925293 | KNN Loss: 5.6140546798706055 | BCE Loss: 1.064376950263977
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 6.646390914916992 | KNN Loss: 5.597099304199219 | BCE Loss: 1.0492918491363525
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.648588180541992 | KNN Loss: 5.602494239807129 | BCE Loss: 1.0460937023162842
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.659420490264893 | KNN Loss: 5.5979838371276855 | BCE Loss: 1.061436653137207
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.691220283508301 | KNN Loss: 5.651219367980957 | BCE Loss: 1.0400011539459229
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.658292293548584 | KNN Loss: 5.615538597106934 | BCE Loss: 1.0427536964416504
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.682666778564453 | KNN Loss: 5.622068405151367 | BC

Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 6.676491737365723 | KNN Loss: 5.619062423706055 | BCE Loss: 1.0574290752410889
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 6.673384666442871 | KNN Loss: 5.642085552215576 | BCE Loss: 1.031299352645874
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 6.652766227722168 | KNN Loss: 5.604642868041992 | BCE Loss: 1.0481233596801758
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.658535003662109 | KNN Loss: 5.612541198730469 | BCE Loss: 1.0459939241409302
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.6328864097595215 | KNN Loss: 5.593512535095215 | BCE Loss: 1.0393739938735962
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.700678825378418 | KNN Loss: 5.637027263641357 | BCE Loss: 1.063651442527771
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.645586013793945 | KNN Loss: 5.613584041595459 | BCE Loss: 1.0320019721984863
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.699481010437012 | KNN Loss: 5.65203857421875 | BC

Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 6.667026519775391 | KNN Loss: 5.599555015563965 | BCE Loss: 1.0674715042114258
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 6.639980316162109 | KNN Loss: 5.5939531326293945 | BCE Loss: 1.0460271835327148
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 6.638461112976074 | KNN Loss: 5.618919372558594 | BCE Loss: 1.01954185962677
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.645467281341553 | KNN Loss: 5.609829425811768 | BCE Loss: 1.0356378555297852
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.8503851890563965 | KNN Loss: 5.8300909996032715 | BCE Loss: 1.020294189453125
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.652978897094727 | KNN Loss: 5.600870609283447 | BCE Loss: 1.0521084070205688
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.624970436096191 | KNN Loss: 5.5981831550598145 | BCE Loss: 1.026787519454956
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.645907402038574 | KNN Loss: 5.614967346191406 |

Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 6.674930095672607 | KNN Loss: 5.633243560791016 | BCE Loss: 1.0416865348815918
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 6.750466823577881 | KNN Loss: 5.698458194732666 | BCE Loss: 1.0520086288452148
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 6.657831192016602 | KNN Loss: 5.606848239898682 | BCE Loss: 1.0509827136993408
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.655025959014893 | KNN Loss: 5.593140602111816 | BCE Loss: 1.0618853569030762
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.741613388061523 | KNN Loss: 5.717184066772461 | BCE Loss: 1.0244293212890625
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.710047245025635 | KNN Loss: 5.659897327423096 | BCE Loss: 1.0501497983932495
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.647607803344727 | KNN Loss: 5.5958075523376465 | BCE Loss: 1.05180025100708
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.679595947265625 | KNN Loss: 5.613131523132324 | B

Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 6.635524749755859 | KNN Loss: 5.599367618560791 | BCE Loss: 1.0361573696136475
Epoch   449: reducing learning rate of group 0 to 2.7058e-08.
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 6.635662078857422 | KNN Loss: 5.605652809143066 | BCE Loss: 1.0300092697143555
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 6.693984508514404 | KNN Loss: 5.647451400756836 | BCE Loss: 1.0465331077575684
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.695448398590088 | KNN Loss: 5.6347479820251465 | BCE Loss: 1.0607002973556519
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.683749198913574 | KNN Loss: 5.626462459564209 | BCE Loss: 1.0572868585586548
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.656353950500488 | KNN Loss: 5.63012170791626 | BCE Loss: 1.0262320041656494
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.671597480773926 | KNN Loss: 5.615428924560547 | BCE Loss: 1.056168794631958
Epoch 450 / 500 | iteration 0 / 30 | T

Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 6.671774387359619 | KNN Loss: 5.628738880157471 | BCE Loss: 1.0430353879928589
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 6.710239410400391 | KNN Loss: 5.652767658233643 | BCE Loss: 1.0574719905853271
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 6.665363788604736 | KNN Loss: 5.596570014953613 | BCE Loss: 1.0687936544418335
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.672429084777832 | KNN Loss: 5.633646488189697 | BCE Loss: 1.0387825965881348
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.689088821411133 | KNN Loss: 5.6199631690979 | BCE Loss: 1.0691254138946533
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.671342849731445 | KNN Loss: 5.636999607086182 | BCE Loss: 1.0343433618545532
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.694055557250977 | KNN Loss: 5.662044048309326 | BCE Loss: 1.0320112705230713
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.6909894943237305 | KNN Loss: 5.671223163604736 | 

Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 6.73414945602417 | KNN Loss: 5.693599700927734 | BCE Loss: 1.0405497550964355
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 6.708990573883057 | KNN Loss: 5.671413421630859 | BCE Loss: 1.0375771522521973
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 6.662715911865234 | KNN Loss: 5.614140510559082 | BCE Loss: 1.048575520515442
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.632256031036377 | KNN Loss: 5.601239204406738 | BCE Loss: 1.0310168266296387
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.696747303009033 | KNN Loss: 5.6356282234191895 | BCE Loss: 1.0611190795898438
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.6831254959106445 | KNN Loss: 5.640750885009766 | BCE Loss: 1.042374849319458
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.6779913902282715 | KNN Loss: 5.625983238220215 | BCE Loss: 1.052008032798767
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.68925666809082 | KNN Loss: 5.647762775421143 | BC

Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 6.637829303741455 | KNN Loss: 5.596517562866211 | BCE Loss: 1.0413117408752441
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 6.676870822906494 | KNN Loss: 5.602390289306641 | BCE Loss: 1.0744805335998535
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 6.654966831207275 | KNN Loss: 5.598052024841309 | BCE Loss: 1.0569148063659668
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.650898456573486 | KNN Loss: 5.617271900177002 | BCE Loss: 1.033626675605774
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.660549163818359 | KNN Loss: 5.61375617980957 | BCE Loss: 1.046792984008789
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.6108293533325195 | KNN Loss: 5.593349933624268 | BCE Loss: 1.017479658126831
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.678884506225586 | KNN Loss: 5.632691383361816 | BCE Loss: 1.0461933612823486
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.627462863922119 | KNN Loss: 5.601944923400879 | BCE

Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 6.679714202880859 | KNN Loss: 5.628486156463623 | BCE Loss: 1.0512281656265259
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 6.605490684509277 | KNN Loss: 5.593974590301514 | BCE Loss: 1.0115163326263428
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 6.646239757537842 | KNN Loss: 5.609392166137695 | BCE Loss: 1.0368475914001465
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.792389869689941 | KNN Loss: 5.7402024269104 | BCE Loss: 1.052187204360962
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.681451797485352 | KNN Loss: 5.624104022979736 | BCE Loss: 1.0573480129241943
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.635441780090332 | KNN Loss: 5.600375175476074 | BCE Loss: 1.0350664854049683
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.64028263092041 | KNN Loss: 5.606632232666016 | BCE Loss: 1.0336506366729736
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.675051689147949 | KNN Loss: 5.650125026702881 | BCE

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.8459,  3.5631,  2.5001,  3.1537,  3.3568,  0.6708,  2.3711,  2.0827,
          2.1993,  1.9800,  2.0103,  1.9847,  0.8091,  1.7728,  1.2656,  1.4792,
          2.4599,  2.8532,  2.6886,  2.2220,  1.6471,  2.9088,  2.0390,  2.3761,
          2.4719,  1.5146,  1.8861,  1.3733,  1.5239,  0.2685, -0.1907,  0.9555,
          0.2066,  0.9097,  1.4664,  1.3153,  0.9876,  3.2347,  0.8022,  1.3234,
          0.9608, -0.6582, -0.2546,  2.2673,  1.9412,  0.6751, -0.1418,  0.0990,
          1.4441,  2.3865,  1.7711,  0.0801,  1.4148,  0.4678, -0.5583,  1.0986,
          1.4249,  1.3179,  1.2942,  1.7721,  0.5641,  0.8216,  0.1511,  1.6695,
          1.2416,  1.6097, -1.7406,  0.2939,  2.2443,  2.0747,  2.4529,  0.4525,
          1.3193,  2.3935,  1.7749,  1.2830,  0.2490,  0.7379,  0.2319,  1.5311,
          0.0470,  0.3363,  1.7889, -0.3315,  0.2074, -1.0918, -2.2267, -0.2803,
          0.5183, -1.7710,  0.4681, -0.1521, -0.5530, -0.8952,  0.5180,  1.2434,
         -0.6065, -0.6744,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].to('cpu') for d in dataset]

In [11]:
model = model.eval().to('cpu')
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:02<00:00,  5.63it/s]


In [41]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [48]:
clusters = DBSCAN(eps=0.02, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [49]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [50]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [51]:
tensor_dataset = torch.stack(dataset_)

In [52]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [53]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [54]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [55]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [56]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [57]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [58]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [59]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [60]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [61]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [62]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [63]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [64]:
losses = []
accs = []
sparsity = []

In [65]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 021 | Total loss: 9.760 | Reg loss: 0.007 | Tree loss: 9.760 | Accuracy: 0.000000 | 0.634 sec/iter
Epoch: 00 | Batch: 001 / 021 | Total loss: 9.752 | Reg loss: 0.007 | Tree loss: 9.752 | Accuracy: 0.000000 | 0.369 sec/iter
Epoch: 00 | Batch: 002 / 021 | Total loss: 9.742 | Reg loss: 0.007 | Tree loss: 9.742 | Accuracy: 0.000000 | 0.279 sec/iter
Epoch: 00 | Batch: 003 / 021 | Total loss: 9.728 | Reg loss: 0.006 | Tree loss: 9.728 | Accuracy: 0.000000 | 0.234 sec/iter
Epoch: 00 | Batch: 004 / 021 | Total loss: 9.717 | Reg loss: 0.006 | Tree loss: 9.717 | Accuracy: 0.000000 | 0.208 sec/iter
Epoch: 00 | Batch: 005 / 021 | Total loss: 9.703 | Reg loss: 0.006 | Tree loss: 9.703 | Accuracy: 0.000000 | 0.191 sec/iter
Epoch: 00 | Batch: 006 / 021 | Total loss: 9.693 | Reg loss: 0.006 | Tree loss: 9.693 | Accuracy: 0.000000 | 0.178 sec/iter
Epoch: 00 | Batch: 007 / 021 | Total loss: 

Epoch: 03 | Batch: 000 / 021 | Total loss: 9.402 | Reg loss: 0.006 | Tree loss: 9.402 | Accuracy: 0.000000 | 0.114 sec/iter
Epoch: 03 | Batch: 001 / 021 | Total loss: 9.391 | Reg loss: 0.006 | Tree loss: 9.391 | Accuracy: 0.000000 | 0.114 sec/iter
Epoch: 03 | Batch: 002 / 021 | Total loss: 9.380 | Reg loss: 0.007 | Tree loss: 9.380 | Accuracy: 0.001953 | 0.114 sec/iter
Epoch: 03 | Batch: 003 / 021 | Total loss: 9.366 | Reg loss: 0.007 | Tree loss: 9.366 | Accuracy: 0.000000 | 0.113 sec/iter
Epoch: 03 | Batch: 004 / 021 | Total loss: 9.349 | Reg loss: 0.007 | Tree loss: 9.349 | Accuracy: 0.003906 | 0.113 sec/iter
Epoch: 03 | Batch: 005 / 021 | Total loss: 9.341 | Reg loss: 0.007 | Tree loss: 9.341 | Accuracy: 0.011719 | 0.113 sec/iter
Epoch: 03 | Batch: 006 / 021 | Total loss: 9.331 | Reg loss: 0.007 | Tree loss: 9.331 | Accuracy: 0.011719 | 0.113 sec/iter
Epoch: 03 | Batch: 007 / 021 | Total loss: 9.315 | Reg loss: 0.008 | Tree loss: 9.315 | Accuracy: 0.023438 | 0.112 sec/iter
Epoch: 0

Epoch: 06 | Batch: 000 / 021 | Total loss: 9.043 | Reg loss: 0.011 | Tree loss: 9.043 | Accuracy: 0.376953 | 0.104 sec/iter
Epoch: 06 | Batch: 001 / 021 | Total loss: 9.036 | Reg loss: 0.011 | Tree loss: 9.036 | Accuracy: 0.345703 | 0.104 sec/iter
Epoch: 06 | Batch: 002 / 021 | Total loss: 9.022 | Reg loss: 0.011 | Tree loss: 9.022 | Accuracy: 0.349609 | 0.104 sec/iter
Epoch: 06 | Batch: 003 / 021 | Total loss: 8.997 | Reg loss: 0.011 | Tree loss: 8.997 | Accuracy: 0.414062 | 0.104 sec/iter
Epoch: 06 | Batch: 004 / 021 | Total loss: 8.988 | Reg loss: 0.011 | Tree loss: 8.988 | Accuracy: 0.371094 | 0.104 sec/iter
Epoch: 06 | Batch: 005 / 021 | Total loss: 8.987 | Reg loss: 0.011 | Tree loss: 8.987 | Accuracy: 0.330078 | 0.104 sec/iter
Epoch: 06 | Batch: 006 / 021 | Total loss: 8.973 | Reg loss: 0.011 | Tree loss: 8.973 | Accuracy: 0.335938 | 0.104 sec/iter
Epoch: 06 | Batch: 007 / 021 | Total loss: 8.956 | Reg loss: 0.012 | Tree loss: 8.956 | Accuracy: 0.353516 | 0.104 sec/iter
Epoch: 0

Epoch: 09 | Batch: 001 / 021 | Total loss: 8.653 | Reg loss: 0.015 | Tree loss: 8.653 | Accuracy: 0.386719 | 0.103 sec/iter
Epoch: 09 | Batch: 002 / 021 | Total loss: 8.640 | Reg loss: 0.015 | Tree loss: 8.640 | Accuracy: 0.367188 | 0.103 sec/iter
Epoch: 09 | Batch: 003 / 021 | Total loss: 8.628 | Reg loss: 0.015 | Tree loss: 8.628 | Accuracy: 0.341797 | 0.103 sec/iter
Epoch: 09 | Batch: 004 / 021 | Total loss: 8.610 | Reg loss: 0.015 | Tree loss: 8.610 | Accuracy: 0.392578 | 0.103 sec/iter
Epoch: 09 | Batch: 005 / 021 | Total loss: 8.588 | Reg loss: 0.015 | Tree loss: 8.588 | Accuracy: 0.380859 | 0.103 sec/iter
Epoch: 09 | Batch: 006 / 021 | Total loss: 8.590 | Reg loss: 0.015 | Tree loss: 8.590 | Accuracy: 0.355469 | 0.103 sec/iter
Epoch: 09 | Batch: 007 / 021 | Total loss: 8.575 | Reg loss: 0.015 | Tree loss: 8.575 | Accuracy: 0.341797 | 0.103 sec/iter
Epoch: 09 | Batch: 008 / 021 | Total loss: 8.558 | Reg loss: 0.016 | Tree loss: 8.558 | Accuracy: 0.365234 | 0.103 sec/iter
Epoch: 0

Epoch: 12 | Batch: 000 / 021 | Total loss: 8.283 | Reg loss: 0.018 | Tree loss: 8.283 | Accuracy: 0.396484 | 0.103 sec/iter
Epoch: 12 | Batch: 001 / 021 | Total loss: 8.252 | Reg loss: 0.018 | Tree loss: 8.252 | Accuracy: 0.398438 | 0.103 sec/iter
Epoch: 12 | Batch: 002 / 021 | Total loss: 8.250 | Reg loss: 0.018 | Tree loss: 8.250 | Accuracy: 0.375000 | 0.103 sec/iter
Epoch: 12 | Batch: 003 / 021 | Total loss: 8.225 | Reg loss: 0.018 | Tree loss: 8.225 | Accuracy: 0.359375 | 0.102 sec/iter
Epoch: 12 | Batch: 004 / 021 | Total loss: 8.202 | Reg loss: 0.018 | Tree loss: 8.202 | Accuracy: 0.373047 | 0.102 sec/iter
Epoch: 12 | Batch: 005 / 021 | Total loss: 8.204 | Reg loss: 0.018 | Tree loss: 8.204 | Accuracy: 0.353516 | 0.102 sec/iter
Epoch: 12 | Batch: 006 / 021 | Total loss: 8.177 | Reg loss: 0.018 | Tree loss: 8.177 | Accuracy: 0.369141 | 0.102 sec/iter
Epoch: 12 | Batch: 007 / 021 | Total loss: 8.175 | Reg loss: 0.018 | Tree loss: 8.175 | Accuracy: 0.333984 | 0.102 sec/iter
Epoch: 1

Epoch: 15 | Batch: 000 / 021 | Total loss: 7.873 | Reg loss: 0.021 | Tree loss: 7.873 | Accuracy: 0.335938 | 0.103 sec/iter
Epoch: 15 | Batch: 001 / 021 | Total loss: 7.853 | Reg loss: 0.021 | Tree loss: 7.853 | Accuracy: 0.363281 | 0.103 sec/iter
Epoch: 15 | Batch: 002 / 021 | Total loss: 7.821 | Reg loss: 0.021 | Tree loss: 7.821 | Accuracy: 0.382812 | 0.103 sec/iter
Epoch: 15 | Batch: 003 / 021 | Total loss: 7.824 | Reg loss: 0.021 | Tree loss: 7.824 | Accuracy: 0.367188 | 0.103 sec/iter
Epoch: 15 | Batch: 004 / 021 | Total loss: 7.790 | Reg loss: 0.022 | Tree loss: 7.790 | Accuracy: 0.375000 | 0.103 sec/iter
Epoch: 15 | Batch: 005 / 021 | Total loss: 7.776 | Reg loss: 0.022 | Tree loss: 7.776 | Accuracy: 0.376953 | 0.103 sec/iter
Epoch: 15 | Batch: 006 / 021 | Total loss: 7.751 | Reg loss: 0.022 | Tree loss: 7.751 | Accuracy: 0.349609 | 0.103 sec/iter
Epoch: 15 | Batch: 007 / 021 | Total loss: 7.759 | Reg loss: 0.022 | Tree loss: 7.759 | Accuracy: 0.349609 | 0.103 sec/iter
Epoch: 1

Epoch: 18 | Batch: 001 / 021 | Total loss: 7.468 | Reg loss: 0.025 | Tree loss: 7.468 | Accuracy: 0.375000 | 0.104 sec/iter
Epoch: 18 | Batch: 002 / 021 | Total loss: 7.427 | Reg loss: 0.025 | Tree loss: 7.427 | Accuracy: 0.355469 | 0.104 sec/iter
Epoch: 18 | Batch: 003 / 021 | Total loss: 7.436 | Reg loss: 0.025 | Tree loss: 7.436 | Accuracy: 0.349609 | 0.104 sec/iter
Epoch: 18 | Batch: 004 / 021 | Total loss: 7.377 | Reg loss: 0.025 | Tree loss: 7.377 | Accuracy: 0.378906 | 0.104 sec/iter
Epoch: 18 | Batch: 005 / 021 | Total loss: 7.369 | Reg loss: 0.025 | Tree loss: 7.369 | Accuracy: 0.355469 | 0.104 sec/iter
Epoch: 18 | Batch: 006 / 021 | Total loss: 7.331 | Reg loss: 0.025 | Tree loss: 7.331 | Accuracy: 0.371094 | 0.104 sec/iter
Epoch: 18 | Batch: 007 / 021 | Total loss: 7.321 | Reg loss: 0.026 | Tree loss: 7.321 | Accuracy: 0.390625 | 0.104 sec/iter
Epoch: 18 | Batch: 008 / 021 | Total loss: 7.291 | Reg loss: 0.026 | Tree loss: 7.291 | Accuracy: 0.335938 | 0.104 sec/iter
Epoch: 1

Epoch: 21 | Batch: 001 / 021 | Total loss: 7.033 | Reg loss: 0.028 | Tree loss: 7.033 | Accuracy: 0.419922 | 0.105 sec/iter
Epoch: 21 | Batch: 002 / 021 | Total loss: 7.028 | Reg loss: 0.028 | Tree loss: 7.028 | Accuracy: 0.365234 | 0.105 sec/iter
Epoch: 21 | Batch: 003 / 021 | Total loss: 7.005 | Reg loss: 0.028 | Tree loss: 7.005 | Accuracy: 0.367188 | 0.105 sec/iter
Epoch: 21 | Batch: 004 / 021 | Total loss: 7.035 | Reg loss: 0.028 | Tree loss: 7.035 | Accuracy: 0.324219 | 0.105 sec/iter
Epoch: 21 | Batch: 005 / 021 | Total loss: 6.990 | Reg loss: 0.028 | Tree loss: 6.990 | Accuracy: 0.343750 | 0.105 sec/iter
Epoch: 21 | Batch: 006 / 021 | Total loss: 6.963 | Reg loss: 0.028 | Tree loss: 6.963 | Accuracy: 0.363281 | 0.105 sec/iter
Epoch: 21 | Batch: 007 / 021 | Total loss: 6.919 | Reg loss: 0.028 | Tree loss: 6.919 | Accuracy: 0.337891 | 0.105 sec/iter
Epoch: 21 | Batch: 008 / 021 | Total loss: 6.922 | Reg loss: 0.029 | Tree loss: 6.922 | Accuracy: 0.376953 | 0.105 sec/iter
Epoch: 2

Epoch: 24 | Batch: 000 / 021 | Total loss: 6.694 | Reg loss: 0.030 | Tree loss: 6.694 | Accuracy: 0.353516 | 0.105 sec/iter
Epoch: 24 | Batch: 001 / 021 | Total loss: 6.699 | Reg loss: 0.030 | Tree loss: 6.699 | Accuracy: 0.388672 | 0.105 sec/iter
Epoch: 24 | Batch: 002 / 021 | Total loss: 6.617 | Reg loss: 0.030 | Tree loss: 6.617 | Accuracy: 0.425781 | 0.105 sec/iter
Epoch: 24 | Batch: 003 / 021 | Total loss: 6.634 | Reg loss: 0.030 | Tree loss: 6.634 | Accuracy: 0.361328 | 0.105 sec/iter
Epoch: 24 | Batch: 004 / 021 | Total loss: 6.627 | Reg loss: 0.030 | Tree loss: 6.627 | Accuracy: 0.392578 | 0.105 sec/iter
Epoch: 24 | Batch: 005 / 021 | Total loss: 6.614 | Reg loss: 0.031 | Tree loss: 6.614 | Accuracy: 0.347656 | 0.105 sec/iter
Epoch: 24 | Batch: 006 / 021 | Total loss: 6.567 | Reg loss: 0.031 | Tree loss: 6.567 | Accuracy: 0.394531 | 0.105 sec/iter
Epoch: 24 | Batch: 007 / 021 | Total loss: 6.545 | Reg loss: 0.031 | Tree loss: 6.545 | Accuracy: 0.339844 | 0.105 sec/iter
Epoch: 2

Epoch: 27 | Batch: 000 / 021 | Total loss: 6.337 | Reg loss: 0.032 | Tree loss: 6.337 | Accuracy: 0.376953 | 0.106 sec/iter
Epoch: 27 | Batch: 001 / 021 | Total loss: 6.324 | Reg loss: 0.032 | Tree loss: 6.324 | Accuracy: 0.365234 | 0.106 sec/iter
Epoch: 27 | Batch: 002 / 021 | Total loss: 6.308 | Reg loss: 0.032 | Tree loss: 6.308 | Accuracy: 0.369141 | 0.106 sec/iter
Epoch: 27 | Batch: 003 / 021 | Total loss: 6.274 | Reg loss: 0.032 | Tree loss: 6.274 | Accuracy: 0.353516 | 0.106 sec/iter
Epoch: 27 | Batch: 004 / 021 | Total loss: 6.232 | Reg loss: 0.032 | Tree loss: 6.232 | Accuracy: 0.394531 | 0.106 sec/iter
Epoch: 27 | Batch: 005 / 021 | Total loss: 6.254 | Reg loss: 0.032 | Tree loss: 6.254 | Accuracy: 0.351562 | 0.106 sec/iter
Epoch: 27 | Batch: 006 / 021 | Total loss: 6.199 | Reg loss: 0.032 | Tree loss: 6.199 | Accuracy: 0.363281 | 0.106 sec/iter
Epoch: 27 | Batch: 007 / 021 | Total loss: 6.178 | Reg loss: 0.033 | Tree loss: 6.178 | Accuracy: 0.373047 | 0.106 sec/iter
Epoch: 2

Epoch: 30 | Batch: 002 / 021 | Total loss: 5.925 | Reg loss: 0.034 | Tree loss: 5.925 | Accuracy: 0.371094 | 0.106 sec/iter
Epoch: 30 | Batch: 003 / 021 | Total loss: 5.896 | Reg loss: 0.034 | Tree loss: 5.896 | Accuracy: 0.375000 | 0.106 sec/iter
Epoch: 30 | Batch: 004 / 021 | Total loss: 5.903 | Reg loss: 0.034 | Tree loss: 5.903 | Accuracy: 0.361328 | 0.106 sec/iter
Epoch: 30 | Batch: 005 / 021 | Total loss: 5.876 | Reg loss: 0.034 | Tree loss: 5.876 | Accuracy: 0.369141 | 0.106 sec/iter
Epoch: 30 | Batch: 006 / 021 | Total loss: 5.846 | Reg loss: 0.034 | Tree loss: 5.846 | Accuracy: 0.359375 | 0.105 sec/iter
Epoch: 30 | Batch: 007 / 021 | Total loss: 5.815 | Reg loss: 0.034 | Tree loss: 5.815 | Accuracy: 0.396484 | 0.105 sec/iter
Epoch: 30 | Batch: 008 / 021 | Total loss: 5.791 | Reg loss: 0.034 | Tree loss: 5.791 | Accuracy: 0.355469 | 0.105 sec/iter
Epoch: 30 | Batch: 009 / 021 | Total loss: 5.790 | Reg loss: 0.034 | Tree loss: 5.790 | Accuracy: 0.382812 | 0.105 sec/iter
Epoch: 3

Epoch: 33 | Batch: 001 / 021 | Total loss: 5.587 | Reg loss: 0.035 | Tree loss: 5.587 | Accuracy: 0.369141 | 0.105 sec/iter
Epoch: 33 | Batch: 002 / 021 | Total loss: 5.526 | Reg loss: 0.035 | Tree loss: 5.526 | Accuracy: 0.363281 | 0.105 sec/iter
Epoch: 33 | Batch: 003 / 021 | Total loss: 5.519 | Reg loss: 0.035 | Tree loss: 5.519 | Accuracy: 0.365234 | 0.105 sec/iter
Epoch: 33 | Batch: 004 / 021 | Total loss: 5.508 | Reg loss: 0.035 | Tree loss: 5.508 | Accuracy: 0.378906 | 0.105 sec/iter
Epoch: 33 | Batch: 005 / 021 | Total loss: 5.477 | Reg loss: 0.035 | Tree loss: 5.477 | Accuracy: 0.384766 | 0.105 sec/iter
Epoch: 33 | Batch: 006 / 021 | Total loss: 5.491 | Reg loss: 0.035 | Tree loss: 5.491 | Accuracy: 0.378906 | 0.105 sec/iter
Epoch: 33 | Batch: 007 / 021 | Total loss: 5.425 | Reg loss: 0.036 | Tree loss: 5.425 | Accuracy: 0.367188 | 0.105 sec/iter
Epoch: 33 | Batch: 008 / 021 | Total loss: 5.454 | Reg loss: 0.036 | Tree loss: 5.454 | Accuracy: 0.363281 | 0.105 sec/iter
Epoch: 3

Epoch: 36 | Batch: 001 / 021 | Total loss: 5.200 | Reg loss: 0.036 | Tree loss: 5.200 | Accuracy: 0.400391 | 0.106 sec/iter
Epoch: 36 | Batch: 002 / 021 | Total loss: 5.135 | Reg loss: 0.036 | Tree loss: 5.135 | Accuracy: 0.367188 | 0.106 sec/iter
Epoch: 36 | Batch: 003 / 021 | Total loss: 5.153 | Reg loss: 0.036 | Tree loss: 5.153 | Accuracy: 0.388672 | 0.106 sec/iter
Epoch: 36 | Batch: 004 / 021 | Total loss: 5.148 | Reg loss: 0.036 | Tree loss: 5.148 | Accuracy: 0.349609 | 0.106 sec/iter
Epoch: 36 | Batch: 005 / 021 | Total loss: 5.078 | Reg loss: 0.036 | Tree loss: 5.078 | Accuracy: 0.394531 | 0.106 sec/iter
Epoch: 36 | Batch: 006 / 021 | Total loss: 5.080 | Reg loss: 0.036 | Tree loss: 5.080 | Accuracy: 0.378906 | 0.106 sec/iter
Epoch: 36 | Batch: 007 / 021 | Total loss: 5.009 | Reg loss: 0.036 | Tree loss: 5.009 | Accuracy: 0.371094 | 0.106 sec/iter
Epoch: 36 | Batch: 008 / 021 | Total loss: 5.021 | Reg loss: 0.037 | Tree loss: 5.021 | Accuracy: 0.386719 | 0.106 sec/iter
Epoch: 3

Epoch: 39 | Batch: 000 / 021 | Total loss: 4.881 | Reg loss: 0.036 | Tree loss: 4.881 | Accuracy: 0.380859 | 0.106 sec/iter
Epoch: 39 | Batch: 001 / 021 | Total loss: 4.811 | Reg loss: 0.036 | Tree loss: 4.811 | Accuracy: 0.376953 | 0.106 sec/iter
Epoch: 39 | Batch: 002 / 021 | Total loss: 4.783 | Reg loss: 0.036 | Tree loss: 4.783 | Accuracy: 0.373047 | 0.106 sec/iter
Epoch: 39 | Batch: 003 / 021 | Total loss: 4.752 | Reg loss: 0.037 | Tree loss: 4.752 | Accuracy: 0.429688 | 0.106 sec/iter
Epoch: 39 | Batch: 004 / 021 | Total loss: 4.770 | Reg loss: 0.037 | Tree loss: 4.770 | Accuracy: 0.355469 | 0.106 sec/iter
Epoch: 39 | Batch: 005 / 021 | Total loss: 4.752 | Reg loss: 0.037 | Tree loss: 4.752 | Accuracy: 0.375000 | 0.106 sec/iter
Epoch: 39 | Batch: 006 / 021 | Total loss: 4.706 | Reg loss: 0.037 | Tree loss: 4.706 | Accuracy: 0.365234 | 0.106 sec/iter
Epoch: 39 | Batch: 007 / 021 | Total loss: 4.712 | Reg loss: 0.037 | Tree loss: 4.712 | Accuracy: 0.332031 | 0.106 sec/iter
Epoch: 3

Epoch: 42 | Batch: 001 / 021 | Total loss: 4.396 | Reg loss: 0.036 | Tree loss: 4.396 | Accuracy: 0.396484 | 0.106 sec/iter
Epoch: 42 | Batch: 002 / 021 | Total loss: 4.422 | Reg loss: 0.036 | Tree loss: 4.422 | Accuracy: 0.384766 | 0.106 sec/iter
Epoch: 42 | Batch: 003 / 021 | Total loss: 4.419 | Reg loss: 0.036 | Tree loss: 4.419 | Accuracy: 0.339844 | 0.106 sec/iter
Epoch: 42 | Batch: 004 / 021 | Total loss: 4.374 | Reg loss: 0.036 | Tree loss: 4.374 | Accuracy: 0.373047 | 0.106 sec/iter
Epoch: 42 | Batch: 005 / 021 | Total loss: 4.347 | Reg loss: 0.036 | Tree loss: 4.347 | Accuracy: 0.355469 | 0.106 sec/iter
Epoch: 42 | Batch: 006 / 021 | Total loss: 4.304 | Reg loss: 0.036 | Tree loss: 4.304 | Accuracy: 0.351562 | 0.106 sec/iter
Epoch: 42 | Batch: 007 / 021 | Total loss: 4.252 | Reg loss: 0.036 | Tree loss: 4.252 | Accuracy: 0.398438 | 0.107 sec/iter
Epoch: 42 | Batch: 008 / 021 | Total loss: 4.281 | Reg loss: 0.036 | Tree loss: 4.281 | Accuracy: 0.355469 | 0.107 sec/iter
Epoch: 4

Epoch: 45 | Batch: 000 / 021 | Total loss: 4.126 | Reg loss: 0.035 | Tree loss: 4.126 | Accuracy: 0.310547 | 0.107 sec/iter
Epoch: 45 | Batch: 001 / 021 | Total loss: 4.098 | Reg loss: 0.035 | Tree loss: 4.098 | Accuracy: 0.341797 | 0.107 sec/iter
Epoch: 45 | Batch: 002 / 021 | Total loss: 4.008 | Reg loss: 0.035 | Tree loss: 4.008 | Accuracy: 0.404297 | 0.107 sec/iter
Epoch: 45 | Batch: 003 / 021 | Total loss: 4.026 | Reg loss: 0.035 | Tree loss: 4.026 | Accuracy: 0.380859 | 0.107 sec/iter
Epoch: 45 | Batch: 004 / 021 | Total loss: 4.028 | Reg loss: 0.035 | Tree loss: 4.028 | Accuracy: 0.355469 | 0.107 sec/iter
Epoch: 45 | Batch: 005 / 021 | Total loss: 3.953 | Reg loss: 0.035 | Tree loss: 3.953 | Accuracy: 0.351562 | 0.107 sec/iter
Epoch: 45 | Batch: 006 / 021 | Total loss: 3.876 | Reg loss: 0.035 | Tree loss: 3.876 | Accuracy: 0.380859 | 0.107 sec/iter
Epoch: 45 | Batch: 007 / 021 | Total loss: 3.871 | Reg loss: 0.035 | Tree loss: 3.871 | Accuracy: 0.371094 | 0.107 sec/iter
Epoch: 4

Epoch: 48 | Batch: 001 / 021 | Total loss: 3.698 | Reg loss: 0.033 | Tree loss: 3.698 | Accuracy: 0.400391 | 0.107 sec/iter
Epoch: 48 | Batch: 002 / 021 | Total loss: 3.659 | Reg loss: 0.033 | Tree loss: 3.659 | Accuracy: 0.380859 | 0.107 sec/iter
Epoch: 48 | Batch: 003 / 021 | Total loss: 3.671 | Reg loss: 0.033 | Tree loss: 3.671 | Accuracy: 0.341797 | 0.107 sec/iter
Epoch: 48 | Batch: 004 / 021 | Total loss: 3.621 | Reg loss: 0.033 | Tree loss: 3.621 | Accuracy: 0.369141 | 0.107 sec/iter
Epoch: 48 | Batch: 005 / 021 | Total loss: 3.618 | Reg loss: 0.033 | Tree loss: 3.618 | Accuracy: 0.369141 | 0.107 sec/iter
Epoch: 48 | Batch: 006 / 021 | Total loss: 3.614 | Reg loss: 0.033 | Tree loss: 3.614 | Accuracy: 0.386719 | 0.107 sec/iter
Epoch: 48 | Batch: 007 / 021 | Total loss: 3.544 | Reg loss: 0.033 | Tree loss: 3.544 | Accuracy: 0.388672 | 0.107 sec/iter
Epoch: 48 | Batch: 008 / 021 | Total loss: 3.472 | Reg loss: 0.034 | Tree loss: 3.472 | Accuracy: 0.363281 | 0.107 sec/iter
Epoch: 4

Epoch: 51 | Batch: 001 / 021 | Total loss: 3.336 | Reg loss: 0.032 | Tree loss: 3.336 | Accuracy: 0.425781 | 0.108 sec/iter
Epoch: 51 | Batch: 002 / 021 | Total loss: 3.372 | Reg loss: 0.032 | Tree loss: 3.372 | Accuracy: 0.367188 | 0.108 sec/iter
Epoch: 51 | Batch: 003 / 021 | Total loss: 3.347 | Reg loss: 0.032 | Tree loss: 3.347 | Accuracy: 0.404297 | 0.108 sec/iter
Epoch: 51 | Batch: 004 / 021 | Total loss: 3.283 | Reg loss: 0.032 | Tree loss: 3.283 | Accuracy: 0.384766 | 0.108 sec/iter
Epoch: 51 | Batch: 005 / 021 | Total loss: 3.271 | Reg loss: 0.032 | Tree loss: 3.271 | Accuracy: 0.390625 | 0.108 sec/iter
Epoch: 51 | Batch: 006 / 021 | Total loss: 3.267 | Reg loss: 0.032 | Tree loss: 3.267 | Accuracy: 0.365234 | 0.108 sec/iter
Epoch: 51 | Batch: 007 / 021 | Total loss: 3.219 | Reg loss: 0.032 | Tree loss: 3.219 | Accuracy: 0.382812 | 0.108 sec/iter
Epoch: 51 | Batch: 008 / 021 | Total loss: 3.231 | Reg loss: 0.032 | Tree loss: 3.231 | Accuracy: 0.369141 | 0.108 sec/iter
Epoch: 5

Epoch: 54 | Batch: 000 / 021 | Total loss: 3.161 | Reg loss: 0.033 | Tree loss: 3.161 | Accuracy: 0.359375 | 0.108 sec/iter
Epoch: 54 | Batch: 001 / 021 | Total loss: 3.085 | Reg loss: 0.033 | Tree loss: 3.085 | Accuracy: 0.402344 | 0.108 sec/iter
Epoch: 54 | Batch: 002 / 021 | Total loss: 3.112 | Reg loss: 0.033 | Tree loss: 3.112 | Accuracy: 0.378906 | 0.108 sec/iter
Epoch: 54 | Batch: 003 / 021 | Total loss: 3.015 | Reg loss: 0.033 | Tree loss: 3.015 | Accuracy: 0.429688 | 0.108 sec/iter
Epoch: 54 | Batch: 004 / 021 | Total loss: 3.000 | Reg loss: 0.033 | Tree loss: 3.000 | Accuracy: 0.359375 | 0.108 sec/iter
Epoch: 54 | Batch: 005 / 021 | Total loss: 3.015 | Reg loss: 0.033 | Tree loss: 3.015 | Accuracy: 0.355469 | 0.108 sec/iter
Epoch: 54 | Batch: 006 / 021 | Total loss: 3.006 | Reg loss: 0.033 | Tree loss: 3.006 | Accuracy: 0.371094 | 0.108 sec/iter
Epoch: 54 | Batch: 007 / 021 | Total loss: 2.876 | Reg loss: 0.033 | Tree loss: 2.876 | Accuracy: 0.398438 | 0.108 sec/iter
Epoch: 5

Epoch: 57 | Batch: 001 / 021 | Total loss: 2.926 | Reg loss: 0.035 | Tree loss: 2.926 | Accuracy: 0.371094 | 0.109 sec/iter
Epoch: 57 | Batch: 002 / 021 | Total loss: 2.841 | Reg loss: 0.035 | Tree loss: 2.841 | Accuracy: 0.380859 | 0.109 sec/iter
Epoch: 57 | Batch: 003 / 021 | Total loss: 2.789 | Reg loss: 0.035 | Tree loss: 2.789 | Accuracy: 0.369141 | 0.109 sec/iter
Epoch: 57 | Batch: 004 / 021 | Total loss: 2.802 | Reg loss: 0.035 | Tree loss: 2.802 | Accuracy: 0.355469 | 0.109 sec/iter
Epoch: 57 | Batch: 005 / 021 | Total loss: 2.723 | Reg loss: 0.035 | Tree loss: 2.723 | Accuracy: 0.388672 | 0.109 sec/iter
Epoch: 57 | Batch: 006 / 021 | Total loss: 2.773 | Reg loss: 0.035 | Tree loss: 2.773 | Accuracy: 0.390625 | 0.109 sec/iter
Epoch: 57 | Batch: 007 / 021 | Total loss: 2.730 | Reg loss: 0.035 | Tree loss: 2.730 | Accuracy: 0.404297 | 0.109 sec/iter
Epoch: 57 | Batch: 008 / 021 | Total loss: 2.670 | Reg loss: 0.035 | Tree loss: 2.670 | Accuracy: 0.365234 | 0.109 sec/iter
Epoch: 5

Epoch: 60 | Batch: 000 / 021 | Total loss: 2.705 | Reg loss: 0.036 | Tree loss: 2.705 | Accuracy: 0.367188 | 0.109 sec/iter
Epoch: 60 | Batch: 001 / 021 | Total loss: 2.691 | Reg loss: 0.036 | Tree loss: 2.691 | Accuracy: 0.384766 | 0.109 sec/iter
Epoch: 60 | Batch: 002 / 021 | Total loss: 2.663 | Reg loss: 0.036 | Tree loss: 2.663 | Accuracy: 0.351562 | 0.109 sec/iter
Epoch: 60 | Batch: 003 / 021 | Total loss: 2.610 | Reg loss: 0.036 | Tree loss: 2.610 | Accuracy: 0.363281 | 0.109 sec/iter
Epoch: 60 | Batch: 004 / 021 | Total loss: 2.620 | Reg loss: 0.036 | Tree loss: 2.620 | Accuracy: 0.390625 | 0.109 sec/iter
Epoch: 60 | Batch: 005 / 021 | Total loss: 2.585 | Reg loss: 0.036 | Tree loss: 2.585 | Accuracy: 0.388672 | 0.109 sec/iter
Epoch: 60 | Batch: 006 / 021 | Total loss: 2.526 | Reg loss: 0.036 | Tree loss: 2.526 | Accuracy: 0.375000 | 0.109 sec/iter
Epoch: 60 | Batch: 007 / 021 | Total loss: 2.533 | Reg loss: 0.036 | Tree loss: 2.533 | Accuracy: 0.341797 | 0.109 sec/iter
Epoch: 6

Epoch: 63 | Batch: 000 / 021 | Total loss: 2.516 | Reg loss: 0.037 | Tree loss: 2.516 | Accuracy: 0.367188 | 0.109 sec/iter
Epoch: 63 | Batch: 001 / 021 | Total loss: 2.506 | Reg loss: 0.037 | Tree loss: 2.506 | Accuracy: 0.408203 | 0.109 sec/iter
Epoch: 63 | Batch: 002 / 021 | Total loss: 2.494 | Reg loss: 0.037 | Tree loss: 2.494 | Accuracy: 0.369141 | 0.109 sec/iter
Epoch: 63 | Batch: 003 / 021 | Total loss: 2.520 | Reg loss: 0.037 | Tree loss: 2.520 | Accuracy: 0.337891 | 0.109 sec/iter
Epoch: 63 | Batch: 004 / 021 | Total loss: 2.446 | Reg loss: 0.037 | Tree loss: 2.446 | Accuracy: 0.361328 | 0.109 sec/iter
Epoch: 63 | Batch: 005 / 021 | Total loss: 2.431 | Reg loss: 0.037 | Tree loss: 2.431 | Accuracy: 0.363281 | 0.109 sec/iter
Epoch: 63 | Batch: 006 / 021 | Total loss: 2.477 | Reg loss: 0.037 | Tree loss: 2.477 | Accuracy: 0.363281 | 0.109 sec/iter
Epoch: 63 | Batch: 007 / 021 | Total loss: 2.382 | Reg loss: 0.037 | Tree loss: 2.382 | Accuracy: 0.378906 | 0.109 sec/iter
Epoch: 6

Epoch: 66 | Batch: 000 / 021 | Total loss: 2.394 | Reg loss: 0.037 | Tree loss: 2.394 | Accuracy: 0.388672 | 0.109 sec/iter
Epoch: 66 | Batch: 001 / 021 | Total loss: 2.390 | Reg loss: 0.037 | Tree loss: 2.390 | Accuracy: 0.373047 | 0.109 sec/iter
Epoch: 66 | Batch: 002 / 021 | Total loss: 2.383 | Reg loss: 0.037 | Tree loss: 2.383 | Accuracy: 0.380859 | 0.109 sec/iter
Epoch: 66 | Batch: 003 / 021 | Total loss: 2.329 | Reg loss: 0.037 | Tree loss: 2.329 | Accuracy: 0.363281 | 0.109 sec/iter
Epoch: 66 | Batch: 004 / 021 | Total loss: 2.297 | Reg loss: 0.037 | Tree loss: 2.297 | Accuracy: 0.349609 | 0.109 sec/iter
Epoch: 66 | Batch: 005 / 021 | Total loss: 2.301 | Reg loss: 0.037 | Tree loss: 2.301 | Accuracy: 0.351562 | 0.109 sec/iter
Epoch: 66 | Batch: 006 / 021 | Total loss: 2.220 | Reg loss: 0.037 | Tree loss: 2.220 | Accuracy: 0.400391 | 0.109 sec/iter
Epoch: 66 | Batch: 007 / 021 | Total loss: 2.207 | Reg loss: 0.038 | Tree loss: 2.207 | Accuracy: 0.376953 | 0.109 sec/iter
Epoch: 6

Epoch: 69 | Batch: 000 / 021 | Total loss: 2.338 | Reg loss: 0.038 | Tree loss: 2.338 | Accuracy: 0.376953 | 0.109 sec/iter
Epoch: 69 | Batch: 001 / 021 | Total loss: 2.227 | Reg loss: 0.038 | Tree loss: 2.227 | Accuracy: 0.384766 | 0.109 sec/iter
Epoch: 69 | Batch: 002 / 021 | Total loss: 2.236 | Reg loss: 0.038 | Tree loss: 2.236 | Accuracy: 0.369141 | 0.109 sec/iter
Epoch: 69 | Batch: 003 / 021 | Total loss: 2.207 | Reg loss: 0.038 | Tree loss: 2.207 | Accuracy: 0.390625 | 0.109 sec/iter
Epoch: 69 | Batch: 004 / 021 | Total loss: 2.180 | Reg loss: 0.038 | Tree loss: 2.180 | Accuracy: 0.386719 | 0.109 sec/iter
Epoch: 69 | Batch: 005 / 021 | Total loss: 2.149 | Reg loss: 0.038 | Tree loss: 2.149 | Accuracy: 0.373047 | 0.109 sec/iter
Epoch: 69 | Batch: 006 / 021 | Total loss: 2.122 | Reg loss: 0.038 | Tree loss: 2.122 | Accuracy: 0.373047 | 0.109 sec/iter
Epoch: 69 | Batch: 007 / 021 | Total loss: 2.119 | Reg loss: 0.038 | Tree loss: 2.119 | Accuracy: 0.363281 | 0.109 sec/iter
Epoch: 6

Epoch: 72 | Batch: 000 / 021 | Total loss: 2.179 | Reg loss: 0.038 | Tree loss: 2.179 | Accuracy: 0.390625 | 0.109 sec/iter
Epoch: 72 | Batch: 001 / 021 | Total loss: 2.211 | Reg loss: 0.038 | Tree loss: 2.211 | Accuracy: 0.332031 | 0.109 sec/iter
Epoch: 72 | Batch: 002 / 021 | Total loss: 2.143 | Reg loss: 0.038 | Tree loss: 2.143 | Accuracy: 0.412109 | 0.109 sec/iter
Epoch: 72 | Batch: 003 / 021 | Total loss: 2.087 | Reg loss: 0.038 | Tree loss: 2.087 | Accuracy: 0.371094 | 0.109 sec/iter
Epoch: 72 | Batch: 004 / 021 | Total loss: 2.067 | Reg loss: 0.038 | Tree loss: 2.067 | Accuracy: 0.386719 | 0.109 sec/iter
Epoch: 72 | Batch: 005 / 021 | Total loss: 2.122 | Reg loss: 0.038 | Tree loss: 2.122 | Accuracy: 0.324219 | 0.109 sec/iter
Epoch: 72 | Batch: 006 / 021 | Total loss: 2.103 | Reg loss: 0.038 | Tree loss: 2.103 | Accuracy: 0.339844 | 0.109 sec/iter
Epoch: 72 | Batch: 007 / 021 | Total loss: 2.051 | Reg loss: 0.038 | Tree loss: 2.051 | Accuracy: 0.363281 | 0.109 sec/iter
Epoch: 7

Epoch: 75 | Batch: 000 / 021 | Total loss: 2.099 | Reg loss: 0.038 | Tree loss: 2.099 | Accuracy: 0.365234 | 0.11 sec/iter
Epoch: 75 | Batch: 001 / 021 | Total loss: 2.107 | Reg loss: 0.038 | Tree loss: 2.107 | Accuracy: 0.386719 | 0.11 sec/iter
Epoch: 75 | Batch: 002 / 021 | Total loss: 2.069 | Reg loss: 0.038 | Tree loss: 2.069 | Accuracy: 0.375000 | 0.11 sec/iter
Epoch: 75 | Batch: 003 / 021 | Total loss: 1.983 | Reg loss: 0.038 | Tree loss: 1.983 | Accuracy: 0.388672 | 0.11 sec/iter
Epoch: 75 | Batch: 004 / 021 | Total loss: 1.987 | Reg loss: 0.038 | Tree loss: 1.987 | Accuracy: 0.392578 | 0.11 sec/iter
Epoch: 75 | Batch: 005 / 021 | Total loss: 2.006 | Reg loss: 0.038 | Tree loss: 2.006 | Accuracy: 0.404297 | 0.11 sec/iter
Epoch: 75 | Batch: 006 / 021 | Total loss: 2.022 | Reg loss: 0.038 | Tree loss: 2.022 | Accuracy: 0.333984 | 0.11 sec/iter
Epoch: 75 | Batch: 007 / 021 | Total loss: 1.953 | Reg loss: 0.038 | Tree loss: 1.953 | Accuracy: 0.417969 | 0.11 sec/iter
Epoch: 75 | Batc

Epoch: 78 | Batch: 001 / 021 | Total loss: 1.985 | Reg loss: 0.038 | Tree loss: 1.985 | Accuracy: 0.388672 | 0.11 sec/iter
Epoch: 78 | Batch: 002 / 021 | Total loss: 1.977 | Reg loss: 0.038 | Tree loss: 1.977 | Accuracy: 0.398438 | 0.11 sec/iter
Epoch: 78 | Batch: 003 / 021 | Total loss: 1.951 | Reg loss: 0.038 | Tree loss: 1.951 | Accuracy: 0.386719 | 0.11 sec/iter
Epoch: 78 | Batch: 004 / 021 | Total loss: 1.957 | Reg loss: 0.038 | Tree loss: 1.957 | Accuracy: 0.433594 | 0.11 sec/iter
Epoch: 78 | Batch: 005 / 021 | Total loss: 1.915 | Reg loss: 0.038 | Tree loss: 1.915 | Accuracy: 0.359375 | 0.11 sec/iter
Epoch: 78 | Batch: 006 / 021 | Total loss: 1.941 | Reg loss: 0.038 | Tree loss: 1.941 | Accuracy: 0.365234 | 0.11 sec/iter
Epoch: 78 | Batch: 007 / 021 | Total loss: 1.956 | Reg loss: 0.038 | Tree loss: 1.956 | Accuracy: 0.347656 | 0.11 sec/iter
Epoch: 78 | Batch: 008 / 021 | Total loss: 1.840 | Reg loss: 0.038 | Tree loss: 1.840 | Accuracy: 0.380859 | 0.11 sec/iter
Epoch: 78 | Batc

Epoch: 81 | Batch: 001 / 021 | Total loss: 1.893 | Reg loss: 0.037 | Tree loss: 1.893 | Accuracy: 0.445312 | 0.109 sec/iter
Epoch: 81 | Batch: 002 / 021 | Total loss: 1.842 | Reg loss: 0.037 | Tree loss: 1.842 | Accuracy: 0.437500 | 0.109 sec/iter
Epoch: 81 | Batch: 003 / 021 | Total loss: 2.043 | Reg loss: 0.037 | Tree loss: 2.043 | Accuracy: 0.380859 | 0.109 sec/iter
Epoch: 81 | Batch: 004 / 021 | Total loss: 1.894 | Reg loss: 0.037 | Tree loss: 1.894 | Accuracy: 0.375000 | 0.109 sec/iter
Epoch: 81 | Batch: 005 / 021 | Total loss: 1.811 | Reg loss: 0.037 | Tree loss: 1.811 | Accuracy: 0.410156 | 0.109 sec/iter
Epoch: 81 | Batch: 006 / 021 | Total loss: 1.827 | Reg loss: 0.038 | Tree loss: 1.827 | Accuracy: 0.419922 | 0.109 sec/iter
Epoch: 81 | Batch: 007 / 021 | Total loss: 1.870 | Reg loss: 0.038 | Tree loss: 1.870 | Accuracy: 0.402344 | 0.109 sec/iter
Epoch: 81 | Batch: 008 / 021 | Total loss: 1.824 | Reg loss: 0.038 | Tree loss: 1.824 | Accuracy: 0.410156 | 0.109 sec/iter
Epoch: 8

Epoch: 84 | Batch: 001 / 021 | Total loss: 1.880 | Reg loss: 0.037 | Tree loss: 1.880 | Accuracy: 0.431641 | 0.109 sec/iter
Epoch: 84 | Batch: 002 / 021 | Total loss: 1.868 | Reg loss: 0.037 | Tree loss: 1.868 | Accuracy: 0.376953 | 0.109 sec/iter
Epoch: 84 | Batch: 003 / 021 | Total loss: 1.832 | Reg loss: 0.037 | Tree loss: 1.832 | Accuracy: 0.427734 | 0.109 sec/iter
Epoch: 84 | Batch: 004 / 021 | Total loss: 1.815 | Reg loss: 0.037 | Tree loss: 1.815 | Accuracy: 0.439453 | 0.109 sec/iter
Epoch: 84 | Batch: 005 / 021 | Total loss: 1.787 | Reg loss: 0.037 | Tree loss: 1.787 | Accuracy: 0.423828 | 0.109 sec/iter
Epoch: 84 | Batch: 006 / 021 | Total loss: 1.716 | Reg loss: 0.037 | Tree loss: 1.716 | Accuracy: 0.410156 | 0.109 sec/iter
Epoch: 84 | Batch: 007 / 021 | Total loss: 1.778 | Reg loss: 0.037 | Tree loss: 1.778 | Accuracy: 0.375000 | 0.109 sec/iter
Epoch: 84 | Batch: 008 / 021 | Total loss: 1.779 | Reg loss: 0.037 | Tree loss: 1.779 | Accuracy: 0.410156 | 0.109 sec/iter
Epoch: 8

Epoch: 87 | Batch: 000 / 021 | Total loss: 1.850 | Reg loss: 0.037 | Tree loss: 1.850 | Accuracy: 0.427734 | 0.109 sec/iter
Epoch: 87 | Batch: 001 / 021 | Total loss: 1.905 | Reg loss: 0.037 | Tree loss: 1.905 | Accuracy: 0.437500 | 0.109 sec/iter
Epoch: 87 | Batch: 002 / 021 | Total loss: 1.794 | Reg loss: 0.037 | Tree loss: 1.794 | Accuracy: 0.437500 | 0.109 sec/iter
Epoch: 87 | Batch: 003 / 021 | Total loss: 1.759 | Reg loss: 0.037 | Tree loss: 1.759 | Accuracy: 0.408203 | 0.109 sec/iter
Epoch: 87 | Batch: 004 / 021 | Total loss: 1.750 | Reg loss: 0.037 | Tree loss: 1.750 | Accuracy: 0.392578 | 0.109 sec/iter
Epoch: 87 | Batch: 005 / 021 | Total loss: 1.764 | Reg loss: 0.037 | Tree loss: 1.764 | Accuracy: 0.431641 | 0.109 sec/iter
Epoch: 87 | Batch: 006 / 021 | Total loss: 1.771 | Reg loss: 0.037 | Tree loss: 1.771 | Accuracy: 0.447266 | 0.109 sec/iter
Epoch: 87 | Batch: 007 / 021 | Total loss: 1.713 | Reg loss: 0.037 | Tree loss: 1.713 | Accuracy: 0.437500 | 0.109 sec/iter
Epoch: 8

Epoch: 90 | Batch: 001 / 021 | Total loss: 1.767 | Reg loss: 0.037 | Tree loss: 1.767 | Accuracy: 0.414062 | 0.11 sec/iter
Epoch: 90 | Batch: 002 / 021 | Total loss: 1.775 | Reg loss: 0.037 | Tree loss: 1.775 | Accuracy: 0.449219 | 0.11 sec/iter
Epoch: 90 | Batch: 003 / 021 | Total loss: 1.748 | Reg loss: 0.037 | Tree loss: 1.748 | Accuracy: 0.412109 | 0.11 sec/iter
Epoch: 90 | Batch: 004 / 021 | Total loss: 1.751 | Reg loss: 0.037 | Tree loss: 1.751 | Accuracy: 0.388672 | 0.11 sec/iter
Epoch: 90 | Batch: 005 / 021 | Total loss: 1.683 | Reg loss: 0.037 | Tree loss: 1.683 | Accuracy: 0.429688 | 0.11 sec/iter
Epoch: 90 | Batch: 006 / 021 | Total loss: 1.685 | Reg loss: 0.037 | Tree loss: 1.685 | Accuracy: 0.410156 | 0.11 sec/iter
Epoch: 90 | Batch: 007 / 021 | Total loss: 1.696 | Reg loss: 0.037 | Tree loss: 1.696 | Accuracy: 0.412109 | 0.11 sec/iter
Epoch: 90 | Batch: 008 / 021 | Total loss: 1.671 | Reg loss: 0.037 | Tree loss: 1.671 | Accuracy: 0.402344 | 0.11 sec/iter
Epoch: 90 | Batc

Epoch: 93 | Batch: 002 / 021 | Total loss: 1.714 | Reg loss: 0.037 | Tree loss: 1.714 | Accuracy: 0.419922 | 0.11 sec/iter
Epoch: 93 | Batch: 003 / 021 | Total loss: 1.696 | Reg loss: 0.037 | Tree loss: 1.696 | Accuracy: 0.455078 | 0.11 sec/iter
Epoch: 93 | Batch: 004 / 021 | Total loss: 1.755 | Reg loss: 0.037 | Tree loss: 1.755 | Accuracy: 0.441406 | 0.11 sec/iter
Epoch: 93 | Batch: 005 / 021 | Total loss: 1.711 | Reg loss: 0.037 | Tree loss: 1.711 | Accuracy: 0.406250 | 0.11 sec/iter
Epoch: 93 | Batch: 006 / 021 | Total loss: 1.679 | Reg loss: 0.037 | Tree loss: 1.679 | Accuracy: 0.439453 | 0.11 sec/iter
Epoch: 93 | Batch: 007 / 021 | Total loss: 1.700 | Reg loss: 0.037 | Tree loss: 1.700 | Accuracy: 0.402344 | 0.11 sec/iter
Epoch: 93 | Batch: 008 / 021 | Total loss: 1.642 | Reg loss: 0.037 | Tree loss: 1.642 | Accuracy: 0.423828 | 0.11 sec/iter
Epoch: 93 | Batch: 009 / 021 | Total loss: 1.687 | Reg loss: 0.037 | Tree loss: 1.687 | Accuracy: 0.384766 | 0.11 sec/iter
Epoch: 93 | Batc

Epoch: 96 | Batch: 003 / 021 | Total loss: 1.697 | Reg loss: 0.036 | Tree loss: 1.697 | Accuracy: 0.443359 | 0.11 sec/iter
Epoch: 96 | Batch: 004 / 021 | Total loss: 1.661 | Reg loss: 0.036 | Tree loss: 1.661 | Accuracy: 0.425781 | 0.11 sec/iter
Epoch: 96 | Batch: 005 / 021 | Total loss: 1.644 | Reg loss: 0.036 | Tree loss: 1.644 | Accuracy: 0.417969 | 0.11 sec/iter
Epoch: 96 | Batch: 006 / 021 | Total loss: 1.612 | Reg loss: 0.036 | Tree loss: 1.612 | Accuracy: 0.417969 | 0.11 sec/iter
Epoch: 96 | Batch: 007 / 021 | Total loss: 1.633 | Reg loss: 0.036 | Tree loss: 1.633 | Accuracy: 0.439453 | 0.11 sec/iter
Epoch: 96 | Batch: 008 / 021 | Total loss: 1.629 | Reg loss: 0.036 | Tree loss: 1.629 | Accuracy: 0.453125 | 0.11 sec/iter
Epoch: 96 | Batch: 009 / 021 | Total loss: 1.643 | Reg loss: 0.036 | Tree loss: 1.643 | Accuracy: 0.435547 | 0.11 sec/iter
Epoch: 96 | Batch: 010 / 021 | Total loss: 1.596 | Reg loss: 0.036 | Tree loss: 1.596 | Accuracy: 0.386719 | 0.11 sec/iter
Epoch: 96 | Batc

Epoch: 99 | Batch: 003 / 021 | Total loss: 1.633 | Reg loss: 0.036 | Tree loss: 1.633 | Accuracy: 0.445312 | 0.11 sec/iter
Epoch: 99 | Batch: 004 / 021 | Total loss: 1.672 | Reg loss: 0.036 | Tree loss: 1.672 | Accuracy: 0.439453 | 0.11 sec/iter
Epoch: 99 | Batch: 005 / 021 | Total loss: 1.665 | Reg loss: 0.036 | Tree loss: 1.665 | Accuracy: 0.423828 | 0.11 sec/iter
Epoch: 99 | Batch: 006 / 021 | Total loss: 1.574 | Reg loss: 0.036 | Tree loss: 1.574 | Accuracy: 0.441406 | 0.11 sec/iter
Epoch: 99 | Batch: 007 / 021 | Total loss: 1.611 | Reg loss: 0.036 | Tree loss: 1.611 | Accuracy: 0.394531 | 0.11 sec/iter
Epoch: 99 | Batch: 008 / 021 | Total loss: 1.561 | Reg loss: 0.036 | Tree loss: 1.561 | Accuracy: 0.453125 | 0.11 sec/iter
Epoch: 99 | Batch: 009 / 021 | Total loss: 1.655 | Reg loss: 0.036 | Tree loss: 1.655 | Accuracy: 0.394531 | 0.11 sec/iter
Epoch: 99 | Batch: 010 / 021 | Total loss: 1.536 | Reg loss: 0.036 | Tree loss: 1.536 | Accuracy: 0.451172 | 0.11 sec/iter
Epoch: 99 | Batc

In [66]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [67]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [68]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.393939393939394


# Extract Rules

# Accumulate samples in the leaves

In [69]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 33


In [70]:
method = 'greedy'

In [71]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [72]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


8869
542
211
749
Average comprehensibility: 29.09090909090909
std comprehensibility: 4.647342912165046
var comprehensibility: 21.59779614325069
minimum comprehensibility: 16
maximum comprehensibility: 34
