In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 16
tree_depth = 12
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.197713851928711 | KNN Loss: 6.226964950561523 | BCE Loss: 1.9707493782043457
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.225666999816895 | KNN Loss: 6.226774215698242 | BCE Loss: 1.9988924264907837
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.174530029296875 | KNN Loss: 6.2269487380981445 | BCE Loss: 1.94758141040802
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.190489768981934 | KNN Loss: 6.226901054382324 | BCE Loss: 1.9635889530181885
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.177996635437012 | KNN Loss: 6.226274013519287 | BCE Loss: 1.9517223834991455
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.151252746582031 | KNN Loss: 6.225853443145752 | BCE Loss: 1.9253997802734375
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.15381145477295 | KNN Loss: 6.225696086883545 | BCE Loss: 1.9281151294708252
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.157215118408203 | KNN Loss: 6.225313186645508 | BCE Loss: 1.9319019

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 6.8714399337768555 | KNN Loss: 5.739673137664795 | BCE Loss: 1.13176691532135
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 6.798647880554199 | KNN Loss: 5.666825771331787 | BCE Loss: 1.131821870803833
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 6.681827545166016 | KNN Loss: 5.56065034866333 | BCE Loss: 1.1211771965026855
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.610039710998535 | KNN Loss: 5.499035835266113 | BCE Loss: 1.1110038757324219
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.4639201164245605 | KNN Loss: 5.373536109924316 | BCE Loss: 1.0903841257095337
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.341156959533691 | KNN Loss: 5.257620811462402 | BCE Loss: 1.0835363864898682
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.213466644287109 | KNN Loss: 5.108219623565674 | BCE Loss: 1.1052470207214355
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.144377708435059 | KNN Loss: 5.05210018157959 | BCE Loss: 1

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.320324897766113 | KNN Loss: 3.2698252201080322 | BCE Loss: 1.050499439239502
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.358704090118408 | KNN Loss: 3.313469409942627 | BCE Loss: 1.0452346801757812
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.304399013519287 | KNN Loss: 3.252042055130005 | BCE Loss: 1.0523569583892822
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.2353925704956055 | KNN Loss: 3.204836130142212 | BCE Loss: 1.030556559562683
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.277061462402344 | KNN Loss: 3.227719306945801 | BCE Loss: 1.0493419170379639
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.300971031188965 | KNN Loss: 3.247703790664673 | BCE Loss: 1.053267240524292
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.378009796142578 | KNN Loss: 3.311530113220215 | BCE Loss: 1.0664796829223633
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.334157943725586 | KNN Loss: 3.281982421875 | BCE Loss: 1

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.237036228179932 | KNN Loss: 3.198315143585205 | BCE Loss: 1.038720965385437
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.232545852661133 | KNN Loss: 3.216519832611084 | BCE Loss: 1.016026258468628
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.190826892852783 | KNN Loss: 3.1581554412841797 | BCE Loss: 1.0326714515686035
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.197216987609863 | KNN Loss: 3.1687138080596924 | BCE Loss: 1.02850341796875
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.229626655578613 | KNN Loss: 3.184157609939575 | BCE Loss: 1.045469045639038
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.245745658874512 | KNN Loss: 3.210498571395874 | BCE Loss: 1.0352468490600586
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.27773380279541 | KNN Loss: 3.2555296421051025 | BCE Loss: 1.0222043991088867
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.198144435882568 | KNN Loss: 3.184840440750122 | BCE Loss: 

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.156114101409912 | KNN Loss: 3.1511781215667725 | BCE Loss: 1.0049360990524292
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.241332530975342 | KNN Loss: 3.19590163230896 | BCE Loss: 1.0454308986663818
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.158176422119141 | KNN Loss: 3.1401865482330322 | BCE Loss: 1.0179896354675293
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.214005470275879 | KNN Loss: 3.1279385089874268 | BCE Loss: 1.0860671997070312
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.188873291015625 | KNN Loss: 3.1579582691192627 | BCE Loss: 1.0309152603149414
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.160538196563721 | KNN Loss: 3.1324663162231445 | BCE Loss: 1.0280717611312866
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.171976089477539 | KNN Loss: 3.175236225128174 | BCE Loss: 0.9967400431632996
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 4.1338043212890625 | KNN Loss: 3.134368419647217 | BC

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.209314346313477 | KNN Loss: 3.1803693771362305 | BCE Loss: 1.028944969177246
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.224369049072266 | KNN Loss: 3.1429290771484375 | BCE Loss: 1.0814399719238281
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.168971538543701 | KNN Loss: 3.137782573699951 | BCE Loss: 1.03118896484375
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.174190521240234 | KNN Loss: 3.1545588970184326 | BCE Loss: 1.0196316242218018
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.156680107116699 | KNN Loss: 3.1302263736724854 | BCE Loss: 1.0264536142349243
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.129638671875 | KNN Loss: 3.118009567260742 | BCE Loss: 1.0116291046142578
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.117300987243652 | KNN Loss: 3.1140735149383545 | BCE Loss: 1.003227710723877
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 4.158773422241211 | KNN Loss: 3.143685817718506 | BCE Loss: 

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.167259216308594 | KNN Loss: 3.107916831970215 | BCE Loss: 1.059342384338379
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.157262802124023 | KNN Loss: 3.1473915576934814 | BCE Loss: 1.009871482849121
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.100696563720703 | KNN Loss: 3.0978970527648926 | BCE Loss: 1.0027992725372314
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.148861408233643 | KNN Loss: 3.1161396503448486 | BCE Loss: 1.0327218770980835
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.183882236480713 | KNN Loss: 3.1561295986175537 | BCE Loss: 1.0277525186538696
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.107870101928711 | KNN Loss: 3.0830211639404297 | BCE Loss: 1.0248491764068604
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 4.134270668029785 | KNN Loss: 3.1169486045837402 | BCE Loss: 1.017322301864624
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 4.132001876831055 | KNN Loss: 3.119136333465576 | BCE Lo

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.115258693695068 | KNN Loss: 3.0931437015533447 | BCE Loss: 1.022114872932434
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.11326789855957 | KNN Loss: 3.0994009971618652 | BCE Loss: 1.013866662979126
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.157391548156738 | KNN Loss: 3.124852180480957 | BCE Loss: 1.0325392484664917
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.120510101318359 | KNN Loss: 3.1057159900665283 | BCE Loss: 1.014794111251831
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.126264572143555 | KNN Loss: 3.100647449493408 | BCE Loss: 1.025617241859436
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.141498565673828 | KNN Loss: 3.118583917617798 | BCE Loss: 1.0229146480560303
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 4.108288764953613 | KNN Loss: 3.081177234649658 | BCE Loss: 1.027111291885376
Epoch 77 / 500 | iteration 0 / 30 | Total Loss: 4.127962112426758 | KNN Loss: 3.1061434745788574 | BCE Loss: 

Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.145899295806885 | KNN Loss: 3.097020387649536 | BCE Loss: 1.048878788948059
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.130449295043945 | KNN Loss: 3.128154754638672 | BCE Loss: 1.0022944211959839
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.140462875366211 | KNN Loss: 3.108340263366699 | BCE Loss: 1.0321226119995117
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.137731075286865 | KNN Loss: 3.1127171516418457 | BCE Loss: 1.0250139236450195
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.126206398010254 | KNN Loss: 3.1050164699554443 | BCE Loss: 1.0211896896362305
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 4.109657287597656 | KNN Loss: 3.0867059230804443 | BCE Loss: 1.022951602935791
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 4.141723155975342 | KNN Loss: 3.1193838119506836 | BCE Loss: 1.0223392248153687
Epoch 87 / 500 | iteration 25 / 30 | Total Loss: 4.126799583435059 | KNN Loss: 3.098886728286743 | BCE L

Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.104032516479492 | KNN Loss: 3.079009532928467 | BCE Loss: 1.025023102760315
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.056861400604248 | KNN Loss: 3.0668978691101074 | BCE Loss: 0.9899635910987854
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.093735694885254 | KNN Loss: 3.072929859161377 | BCE Loss: 1.0208059549331665
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.113432884216309 | KNN Loss: 3.0902647972106934 | BCE Loss: 1.0231680870056152
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.096861362457275 | KNN Loss: 3.0889532566070557 | BCE Loss: 1.0079081058502197
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.11738395690918 | KNN Loss: 3.1147632598876953 | BCE Loss: 1.0026204586029053
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 4.152283668518066 | KNN Loss: 3.1218464374542236 | BCE Loss: 1.0304373502731323
Epoch 98 / 500 | iteration 15 / 30 | Total Loss: 4.1653218269348145 | KNN Loss: 3.1318163871765137 | BC

Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.064780235290527 | KNN Loss: 3.083144426345825 | BCE Loss: 0.981635570526123
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.112067222595215 | KNN Loss: 3.080502510070801 | BCE Loss: 1.0315649509429932
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.134092807769775 | KNN Loss: 3.1187565326690674 | BCE Loss: 1.015336275100708
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.138516902923584 | KNN Loss: 3.1067519187927246 | BCE Loss: 1.0317649841308594
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.074275016784668 | KNN Loss: 3.0694282054901123 | BCE Loss: 1.0048470497131348
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 4.129911422729492 | KNN Loss: 3.101876974105835 | BCE Loss: 1.0280343294143677
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 4.121542453765869 | KNN Loss: 3.0900018215179443 | BCE Loss: 1.0315407514572144
Epoch 109 / 500 | iteration 5 / 30 | Total Loss: 4.131997108459473 | KNN Loss: 3.0987720489501953 

Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.07534122467041 | KNN Loss: 3.061506748199463 | BCE Loss: 1.0138345956802368
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.095791339874268 | KNN Loss: 3.0686893463134766 | BCE Loss: 1.0271021127700806
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.0810699462890625 | KNN Loss: 3.0565662384033203 | BCE Loss: 1.0245038270950317
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.141618728637695 | KNN Loss: 3.1109230518341064 | BCE Loss: 1.0306955575942993
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.101593017578125 | KNN Loss: 3.0896291732788086 | BCE Loss: 1.0119637250900269
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.140707015991211 | KNN Loss: 3.0879030227661133 | BCE Loss: 1.0528037548065186
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 4.098947048187256 | KNN Loss: 3.0683906078338623 | BCE Loss: 1.0305564403533936
Epoch 119 / 500 | iteration 25 / 30 | Total Loss: 4.074775218963623 | KNN Loss: 3.06634974479

Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.127488136291504 | KNN Loss: 3.0967600345611572 | BCE Loss: 1.0307281017303467
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.08768892288208 | KNN Loss: 3.0845706462860107 | BCE Loss: 1.0031182765960693
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.069705486297607 | KNN Loss: 3.0448086261749268 | BCE Loss: 1.0248967409133911
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.103206634521484 | KNN Loss: 3.0649547576904297 | BCE Loss: 1.0382518768310547
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.0718159675598145 | KNN Loss: 3.0555579662323 | BCE Loss: 1.0162581205368042
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.059959411621094 | KNN Loss: 3.064202308654785 | BCE Loss: 0.9957571029663086
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 4.129817008972168 | KNN Loss: 3.095857620239258 | BCE Loss: 1.0339596271514893
Epoch 130 / 500 | iteration 15 / 30 | Total Loss: 4.087767601013184 | KNN Loss: 3.097030878067016

Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.087669849395752 | KNN Loss: 3.072147846221924 | BCE Loss: 1.0155220031738281
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.066332817077637 | KNN Loss: 3.042588710784912 | BCE Loss: 1.0237441062927246
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.075277328491211 | KNN Loss: 3.0483484268188477 | BCE Loss: 1.0269291400909424
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.139771938323975 | KNN Loss: 3.120537757873535 | BCE Loss: 1.019234299659729
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.068439960479736 | KNN Loss: 3.0506632328033447 | BCE Loss: 1.0177768468856812
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.093867301940918 | KNN Loss: 3.0898005962371826 | BCE Loss: 1.0040664672851562
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 4.105763912200928 | KNN Loss: 3.1022231578826904 | BCE Loss: 1.0035407543182373
Epoch 141 / 500 | iteration 5 / 30 | Total Loss: 4.129912376403809 | KNN Loss: 3.090496301651001 

Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.112130641937256 | KNN Loss: 3.090085983276367 | BCE Loss: 1.0220447778701782
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.1304144859313965 | KNN Loss: 3.0965538024902344 | BCE Loss: 1.033860683441162
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.101290225982666 | KNN Loss: 3.096068859100342 | BCE Loss: 1.0052212476730347
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.12487268447876 | KNN Loss: 3.096296787261963 | BCE Loss: 1.0285760164260864
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.109187126159668 | KNN Loss: 3.067141056060791 | BCE Loss: 1.042046070098877
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.133791923522949 | KNN Loss: 3.114473342895508 | BCE Loss: 1.0193188190460205
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 4.106403827667236 | KNN Loss: 3.0850844383239746 | BCE Loss: 1.0213193893432617
Epoch 151 / 500 | iteration 25 / 30 | Total Loss: 4.1623640060424805 | KNN Loss: 3.130061626434326 

Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.118540287017822 | KNN Loss: 3.0823333263397217 | BCE Loss: 1.036206841468811
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.107219696044922 | KNN Loss: 3.086304187774658 | BCE Loss: 1.0209152698516846
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.0495100021362305 | KNN Loss: 3.0572421550750732 | BCE Loss: 0.9922676682472229
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.108197212219238 | KNN Loss: 3.083733320236206 | BCE Loss: 1.0244641304016113
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.090916633605957 | KNN Loss: 3.1008825302124023 | BCE Loss: 0.9900342226028442
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 4.107399940490723 | KNN Loss: 3.077390670776367 | BCE Loss: 1.0300090312957764
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 4.107733726501465 | KNN Loss: 3.115868091583252 | BCE Loss: 0.991865873336792
Epoch 162 / 500 | iteration 15 / 30 | Total Loss: 4.129289627075195 | KNN Loss: 3.080485105514526

Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.082667350769043 | KNN Loss: 3.070828437805176 | BCE Loss: 1.011838674545288
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.105261325836182 | KNN Loss: 3.081597089767456 | BCE Loss: 1.0236642360687256
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.077622413635254 | KNN Loss: 3.059946298599243 | BCE Loss: 1.0176758766174316
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.094514846801758 | KNN Loss: 3.0599417686462402 | BCE Loss: 1.0345730781555176
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.062402248382568 | KNN Loss: 3.0536224842071533 | BCE Loss: 1.008779764175415
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 4.062097072601318 | KNN Loss: 3.0508034229278564 | BCE Loss: 1.0112935304641724
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 4.069813251495361 | KNN Loss: 3.0651464462280273 | BCE Loss: 1.0046666860580444
Epoch 173 / 500 | iteration 5 / 30 | Total Loss: 4.118173122406006 | KNN Loss: 3.06518816947937 | 

Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.077742576599121 | KNN Loss: 3.089813470840454 | BCE Loss: 0.9879289269447327
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.156537055969238 | KNN Loss: 3.1254093647003174 | BCE Loss: 1.031127691268921
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.09780740737915 | KNN Loss: 3.0823864936828613 | BCE Loss: 1.015420913696289
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.075613021850586 | KNN Loss: 3.0519683361053467 | BCE Loss: 1.0236444473266602
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 4.097132205963135 | KNN Loss: 3.0830821990966797 | BCE Loss: 1.0140501260757446
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 4.087172508239746 | KNN Loss: 3.086259603500366 | BCE Loss: 1.000913143157959
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 4.071048736572266 | KNN Loss: 3.057013750076294 | BCE Loss: 1.0140352249145508
Epoch 183 / 500 | iteration 25 / 30 | Total Loss: 4.075319766998291 | KNN Loss: 3.0735678672790527 

Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.075530052185059 | KNN Loss: 3.041816234588623 | BCE Loss: 1.0337138175964355
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.071645259857178 | KNN Loss: 3.048696517944336 | BCE Loss: 1.0229487419128418
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.098415374755859 | KNN Loss: 3.069889545440674 | BCE Loss: 1.0285258293151855
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.090778350830078 | KNN Loss: 3.0762500762939453 | BCE Loss: 1.0145280361175537
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 4.054068565368652 | KNN Loss: 3.038609027862549 | BCE Loss: 1.015459656715393
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 4.079634666442871 | KNN Loss: 3.064450979232788 | BCE Loss: 1.015183687210083
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 4.129920959472656 | KNN Loss: 3.093766927719116 | BCE Loss: 1.036153793334961
Epoch 194 / 500 | iteration 15 / 30 | Total Loss: 4.096051216125488 | KNN Loss: 3.101250410079956 | B

Epoch   204: reducing learning rate of group 0 to 2.8824e-04.
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.125419616699219 | KNN Loss: 3.086040496826172 | BCE Loss: 1.0393792390823364
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.092495441436768 | KNN Loss: 3.0590319633483887 | BCE Loss: 1.0334633588790894
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.099667549133301 | KNN Loss: 3.0697333812713623 | BCE Loss: 1.0299339294433594
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.077828884124756 | KNN Loss: 3.0676543712615967 | BCE Loss: 1.0101745128631592
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 4.065883636474609 | KNN Loss: 3.0601091384887695 | BCE Loss: 1.005774736404419
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 4.1445441246032715 | KNN Loss: 3.111586332321167 | BCE Loss: 1.0329577922821045
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 4.07792854309082 | KNN Loss: 3.0669896602630615 | BCE Loss: 1.010939121246338
Epoch 205 / 500 | iteration 5 / 30 

Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.0959672927856445 | KNN Loss: 3.0575151443481445 | BCE Loss: 1.0384521484375
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.0950775146484375 | KNN Loss: 3.080497980117798 | BCE Loss: 1.0145795345306396
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.116948127746582 | KNN Loss: 3.110846519470215 | BCE Loss: 1.0061014890670776
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 4.124973297119141 | KNN Loss: 3.1150460243225098 | BCE Loss: 1.0099270343780518
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 4.1122636795043945 | KNN Loss: 3.088801622390747 | BCE Loss: 1.0234622955322266
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 4.1079816818237305 | KNN Loss: 3.0836756229400635 | BCE Loss: 1.024305820465088
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 4.08434534072876 | KNN Loss: 3.064077377319336 | BCE Loss: 1.0202678442001343
Epoch 215 / 500 | iteration 25 / 30 | Total Loss: 4.054769515991211 | KNN Loss: 3.049403905868530

Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.079341411590576 | KNN Loss: 3.073218822479248 | BCE Loss: 1.0061225891113281
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.080904006958008 | KNN Loss: 3.080845594406128 | BCE Loss: 1.0000581741333008
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.0407538414001465 | KNN Loss: 3.0450000762939453 | BCE Loss: 0.9957536458969116
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 4.089763164520264 | KNN Loss: 3.0705320835113525 | BCE Loss: 1.0192310810089111
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 4.060108184814453 | KNN Loss: 3.058717966079712 | BCE Loss: 1.0013902187347412
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 4.033927917480469 | KNN Loss: 3.0242152214050293 | BCE Loss: 1.0097126960754395
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 4.05644416809082 | KNN Loss: 3.0364699363708496 | BCE Loss: 1.0199744701385498
Epoch 226 / 500 | iteration 15 / 30 | Total Loss: 4.069771766662598 | KNN Loss: 3.0681011676788

Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.123483657836914 | KNN Loss: 3.0767765045166016 | BCE Loss: 1.0467071533203125
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.099625587463379 | KNN Loss: 3.077615261077881 | BCE Loss: 1.0220104455947876
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.054172515869141 | KNN Loss: 3.0681746006011963 | BCE Loss: 0.9859981536865234
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 4.105992794036865 | KNN Loss: 3.090855598449707 | BCE Loss: 1.0151371955871582
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 4.063653945922852 | KNN Loss: 3.048473596572876 | BCE Loss: 1.0151804685592651
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 4.1444244384765625 | KNN Loss: 3.1009156703948975 | BCE Loss: 1.0435090065002441
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 4.147364616394043 | KNN Loss: 3.0983364582061768 | BCE Loss: 1.0490283966064453
Epoch 237 / 500 | iteration 5 / 30 | Total Loss: 4.0580525398254395 | KNN Loss: 3.0640513896942

Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.095270156860352 | KNN Loss: 3.0891337394714355 | BCE Loss: 1.006136417388916
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.084768772125244 | KNN Loss: 3.071667432785034 | BCE Loss: 1.0131014585494995
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.1899213790893555 | KNN Loss: 3.1274044513702393 | BCE Loss: 1.062516689300537
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 4.106707572937012 | KNN Loss: 3.070157527923584 | BCE Loss: 1.0365499258041382
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 4.08260440826416 | KNN Loss: 3.062119722366333 | BCE Loss: 1.0204845666885376
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 4.152750492095947 | KNN Loss: 3.099821090698242 | BCE Loss: 1.0529295206069946
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 4.082065105438232 | KNN Loss: 3.078385829925537 | BCE Loss: 1.0036792755126953
Epoch 247 / 500 | iteration 25 / 30 | Total Loss: 4.065840721130371 | KNN Loss: 3.0683465003967285 

Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.070108413696289 | KNN Loss: 3.037182331085205 | BCE Loss: 1.0329259634017944
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.140316486358643 | KNN Loss: 3.114244222640991 | BCE Loss: 1.0260722637176514
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.135322570800781 | KNN Loss: 3.096376657485962 | BCE Loss: 1.0389459133148193
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 4.143105506896973 | KNN Loss: 3.0671355724334717 | BCE Loss: 1.07597017288208
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 4.090673446655273 | KNN Loss: 3.0576834678649902 | BCE Loss: 1.0329902172088623
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 4.109383583068848 | KNN Loss: 3.0917980670928955 | BCE Loss: 1.0175857543945312
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 4.052361011505127 | KNN Loss: 3.0545129776000977 | BCE Loss: 0.9978480339050293
Epoch 258 / 500 | iteration 15 / 30 | Total Loss: 4.101809024810791 | KNN Loss: 3.076626539230346

Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.100186347961426 | KNN Loss: 3.05822491645813 | BCE Loss: 1.0419611930847168
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.088730812072754 | KNN Loss: 3.0611038208007812 | BCE Loss: 1.0276272296905518
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.037891864776611 | KNN Loss: 3.029721736907959 | BCE Loss: 1.0081701278686523
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 4.092559337615967 | KNN Loss: 3.08541202545166 | BCE Loss: 1.0071474313735962
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 4.059804916381836 | KNN Loss: 3.0516316890716553 | BCE Loss: 1.0081733465194702
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 4.051492691040039 | KNN Loss: 3.0716278553009033 | BCE Loss: 0.9798646569252014
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 4.081490516662598 | KNN Loss: 3.0276541709899902 | BCE Loss: 1.0538363456726074
Epoch 269 / 500 | iteration 5 / 30 | Total Loss: 4.041841506958008 | KNN Loss: 3.030160665512085 |

Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.07474422454834 | KNN Loss: 3.043532371520996 | BCE Loss: 1.0312120914459229
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.093918800354004 | KNN Loss: 3.05249285697937 | BCE Loss: 1.041426181793213
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.113008499145508 | KNN Loss: 3.0846848487854004 | BCE Loss: 1.0283238887786865
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 4.090731143951416 | KNN Loss: 3.051161527633667 | BCE Loss: 1.039569616317749
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 4.082348823547363 | KNN Loss: 3.055863857269287 | BCE Loss: 1.0264849662780762
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 4.068127632141113 | KNN Loss: 3.054783582687378 | BCE Loss: 1.0133440494537354
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 4.050754547119141 | KNN Loss: 3.0480031967163086 | BCE Loss: 1.0027514696121216
Epoch 279 / 500 | iteration 25 / 30 | Total Loss: 4.101797103881836 | KNN Loss: 3.0572683811187744 | 

Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.04249382019043 | KNN Loss: 3.031182050704956 | BCE Loss: 1.0113115310668945
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.05267333984375 | KNN Loss: 3.068523406982422 | BCE Loss: 0.984149694442749
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.067854881286621 | KNN Loss: 3.0508241653442383 | BCE Loss: 1.0170304775238037
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 4.091040134429932 | KNN Loss: 3.0988481044769287 | BCE Loss: 0.9921919703483582
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 4.085176467895508 | KNN Loss: 3.0615639686584473 | BCE Loss: 1.0236122608184814
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 4.060710906982422 | KNN Loss: 3.057844638824463 | BCE Loss: 1.0028663873672485
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 4.091562747955322 | KNN Loss: 3.087888717651367 | BCE Loss: 1.003674030303955
Epoch 290 / 500 | iteration 15 / 30 | Total Loss: 4.131926536560059 | KNN Loss: 3.0802345275878906 |

Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.03791618347168 | KNN Loss: 3.0237276554107666 | BCE Loss: 1.0141886472702026
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.053903579711914 | KNN Loss: 3.036454916000366 | BCE Loss: 1.017448902130127
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.1013946533203125 | KNN Loss: 3.069600820541382 | BCE Loss: 1.0317939519882202
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 4.0716166496276855 | KNN Loss: 3.0533080101013184 | BCE Loss: 1.0183086395263672
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 4.098880767822266 | KNN Loss: 3.095897674560547 | BCE Loss: 1.0029830932617188
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 4.079333305358887 | KNN Loss: 3.0556344985961914 | BCE Loss: 1.0236989259719849
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 4.1141510009765625 | KNN Loss: 3.0491294860839844 | BCE Loss: 1.065021276473999
Epoch 301 / 500 | iteration 5 / 30 | Total Loss: 4.093747615814209 | KNN Loss: 3.079596042633056

Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.0688676834106445 | KNN Loss: 3.054885149002075 | BCE Loss: 1.0139825344085693
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.069663047790527 | KNN Loss: 3.0568366050720215 | BCE Loss: 1.0128264427185059
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.105458736419678 | KNN Loss: 3.0780348777770996 | BCE Loss: 1.0274237394332886
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 4.0656561851501465 | KNN Loss: 3.0675697326660156 | BCE Loss: 0.9980863332748413
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 4.098581314086914 | KNN Loss: 3.0619049072265625 | BCE Loss: 1.0366766452789307
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 4.042320251464844 | KNN Loss: 3.022893190383911 | BCE Loss: 1.0194268226623535
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 4.037567138671875 | KNN Loss: 3.0338921546936035 | BCE Loss: 1.0036749839782715
Epoch 311 / 500 | iteration 25 / 30 | Total Loss: 4.072630405426025 | KNN Loss: 3.0723655223

Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.089625358581543 | KNN Loss: 3.05407452583313 | BCE Loss: 1.0355510711669922
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.054434299468994 | KNN Loss: 3.027496099472046 | BCE Loss: 1.0269383192062378
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.070505142211914 | KNN Loss: 3.0712969303131104 | BCE Loss: 0.9992079734802246
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 4.0474348068237305 | KNN Loss: 3.0566258430480957 | BCE Loss: 0.9908088445663452
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 4.1109619140625 | KNN Loss: 3.073997735977173 | BCE Loss: 1.0369642972946167
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 4.1141533851623535 | KNN Loss: 3.076200246810913 | BCE Loss: 1.0379530191421509
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 4.134049892425537 | KNN Loss: 3.0792510509490967 | BCE Loss: 1.0547988414764404
Epoch 322 / 500 | iteration 15 / 30 | Total Loss: 4.055889129638672 | KNN Loss: 3.051246643066406

Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.0422868728637695 | KNN Loss: 3.037440061569214 | BCE Loss: 1.0048468112945557
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.054657936096191 | KNN Loss: 3.055222749710083 | BCE Loss: 0.9994350075721741
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.098073959350586 | KNN Loss: 3.0644638538360596 | BCE Loss: 1.0336103439331055
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 4.08916711807251 | KNN Loss: 3.07126522064209 | BCE Loss: 1.0179020166397095
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 4.099562644958496 | KNN Loss: 3.0796992778778076 | BCE Loss: 1.0198631286621094
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 4.133970260620117 | KNN Loss: 3.1102516651153564 | BCE Loss: 1.0237188339233398
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 4.1077070236206055 | KNN Loss: 3.0923643112182617 | BCE Loss: 1.0153424739837646
Epoch 333 / 500 | iteration 5 / 30 | Total Loss: 4.0902419090271 | KNN Loss: 3.0819711685180664 

Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.122278690338135 | KNN Loss: 3.080533742904663 | BCE Loss: 1.0417449474334717
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.056013107299805 | KNN Loss: 3.0413880348205566 | BCE Loss: 1.0146253108978271
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 4.137190818786621 | KNN Loss: 3.083211660385132 | BCE Loss: 1.0539790391921997
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 4.1156158447265625 | KNN Loss: 3.0589075088500977 | BCE Loss: 1.056708574295044
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 4.043389797210693 | KNN Loss: 3.0465331077575684 | BCE Loss: 0.996856689453125
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 4.051224231719971 | KNN Loss: 3.0441367626190186 | BCE Loss: 1.0070874691009521
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 4.089864730834961 | KNN Loss: 3.0485119819641113 | BCE Loss: 1.04135262966156
Epoch 343 / 500 | iteration 25 / 30 | Total Loss: 4.054082870483398 | KNN Loss: 3.038988113403320

Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.097019195556641 | KNN Loss: 3.074462413787842 | BCE Loss: 1.0225566625595093
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 4.085872650146484 | KNN Loss: 3.069568634033203 | BCE Loss: 1.0163038969039917
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 4.0974225997924805 | KNN Loss: 3.0741922855377197 | BCE Loss: 1.0232301950454712
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 4.092455863952637 | KNN Loss: 3.0785021781921387 | BCE Loss: 1.0139538049697876
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 4.056356430053711 | KNN Loss: 3.047128200531006 | BCE Loss: 1.0092281103134155
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 4.112671852111816 | KNN Loss: 3.0718259811401367 | BCE Loss: 1.0408458709716797
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 4.09751558303833 | KNN Loss: 3.079820156097412 | BCE Loss: 1.017695426940918
Epoch 354 / 500 | iteration 15 / 30 | Total Loss: 4.140211582183838 | KNN Loss: 3.137942314147949

Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.09181022644043 | KNN Loss: 3.063035249710083 | BCE Loss: 1.0287747383117676
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.050440788269043 | KNN Loss: 3.049584150314331 | BCE Loss: 1.0008567571640015
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 4.089094638824463 | KNN Loss: 3.0801994800567627 | BCE Loss: 1.0088951587677002
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 4.087156295776367 | KNN Loss: 3.0584588050842285 | BCE Loss: 1.0286972522735596
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 4.084129333496094 | KNN Loss: 3.0673892498016357 | BCE Loss: 1.016740322113037
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 4.062153339385986 | KNN Loss: 3.047149181365967 | BCE Loss: 1.0150041580200195
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 4.113218307495117 | KNN Loss: 3.107050657272339 | BCE Loss: 1.0061675310134888
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 4.04009485244751 | KNN Loss: 3.0605435371398926 |

Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.126917839050293 | KNN Loss: 3.1130356788635254 | BCE Loss: 1.0138819217681885
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.104763031005859 | KNN Loss: 3.078443765640259 | BCE Loss: 1.0263195037841797
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 4.115032196044922 | KNN Loss: 3.07304310798645 | BCE Loss: 1.0419893264770508
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 4.061326026916504 | KNN Loss: 3.0720436573028564 | BCE Loss: 0.989282488822937
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 4.078152656555176 | KNN Loss: 3.0504777431488037 | BCE Loss: 1.027674913406372
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 4.116270065307617 | KNN Loss: 3.0732333660125732 | BCE Loss: 1.0430364608764648
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 4.064121246337891 | KNN Loss: 3.045297145843506 | BCE Loss: 1.0188243389129639
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 4.106471061706543 | KNN Loss: 3.063218832015991 

Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.129976272583008 | KNN Loss: 3.1048688888549805 | BCE Loss: 1.0251073837280273
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.064586639404297 | KNN Loss: 3.055338144302368 | BCE Loss: 1.0092484951019287
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 4.046426296234131 | KNN Loss: 3.042177200317383 | BCE Loss: 1.0042489767074585
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 4.09766960144043 | KNN Loss: 3.0776567459106445 | BCE Loss: 1.0200127363204956
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 4.068721771240234 | KNN Loss: 3.0444581508636475 | BCE Loss: 1.024263858795166
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 4.124364376068115 | KNN Loss: 3.0850038528442383 | BCE Loss: 1.039360523223877
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 4.096399307250977 | KNN Loss: 3.065905809402466 | BCE Loss: 1.0304932594299316
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 4.106784343719482 | KNN Loss: 3.0762219429016113 

Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.082831382751465 | KNN Loss: 3.0514848232269287 | BCE Loss: 1.0313466787338257
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.0340423583984375 | KNN Loss: 3.0247106552124023 | BCE Loss: 1.0093317031860352
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 4.067949295043945 | KNN Loss: 3.0632834434509277 | BCE Loss: 1.0046658515930176
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 4.078846454620361 | KNN Loss: 3.0822558403015137 | BCE Loss: 0.9965904355049133
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 4.064353942871094 | KNN Loss: 3.0487091541290283 | BCE Loss: 1.0156450271606445
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 4.141764163970947 | KNN Loss: 3.0865445137023926 | BCE Loss: 1.0552196502685547
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 4.078712463378906 | KNN Loss: 3.0742363929748535 | BCE Loss: 1.0044763088226318
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 4.101964473724365 | KNN Loss: 3.0839352607

Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.090633869171143 | KNN Loss: 3.1064648628234863 | BCE Loss: 0.9841690063476562
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.092097282409668 | KNN Loss: 3.066734790802002 | BCE Loss: 1.0253627300262451
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 4.071992874145508 | KNN Loss: 3.0668017864227295 | BCE Loss: 1.0051910877227783
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 4.057834148406982 | KNN Loss: 3.0643906593322754 | BCE Loss: 0.9934436082839966
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 4.059449195861816 | KNN Loss: 3.0701775550842285 | BCE Loss: 0.9892714023590088
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 4.111280918121338 | KNN Loss: 3.0892813205718994 | BCE Loss: 1.021999478340149
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 4.080898761749268 | KNN Loss: 3.0628952980041504 | BCE Loss: 1.0180034637451172
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 4.090489387512207 | KNN Loss: 3.061770200729

Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.050157070159912 | KNN Loss: 3.0386476516723633 | BCE Loss: 1.0115095376968384
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.112070083618164 | KNN Loss: 3.088057518005371 | BCE Loss: 1.0240126848220825
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 4.081592559814453 | KNN Loss: 3.0780200958251953 | BCE Loss: 1.0035722255706787
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 4.071453094482422 | KNN Loss: 3.0735833644866943 | BCE Loss: 0.9978697299957275
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 4.114165306091309 | KNN Loss: 3.045379161834717 | BCE Loss: 1.0687860250473022
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 4.12565803527832 | KNN Loss: 3.118760824203491 | BCE Loss: 1.0068974494934082
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 4.043552875518799 | KNN Loss: 3.0345213413238525 | BCE Loss: 1.0090315341949463
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 4.083132743835449 | KNN Loss: 3.073400497436523

Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.063131809234619 | KNN Loss: 3.060817241668701 | BCE Loss: 1.0023144483566284
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.0294671058654785 | KNN Loss: 3.0181758403778076 | BCE Loss: 1.011291265487671
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 4.068187713623047 | KNN Loss: 3.0853111743927 | BCE Loss: 0.9828767776489258
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 4.115523338317871 | KNN Loss: 3.0644357204437256 | BCE Loss: 1.0510876178741455
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 4.01731014251709 | KNN Loss: 3.032503128051758 | BCE Loss: 0.9848072528839111
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 4.1256513595581055 | KNN Loss: 3.0785109996795654 | BCE Loss: 1.04714035987854
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 4.085188388824463 | KNN Loss: 3.0763819217681885 | BCE Loss: 1.0088064670562744
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 4.111865520477295 | KNN Loss: 3.09560227394104 | B

Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.070826530456543 | KNN Loss: 3.068413496017456 | BCE Loss: 1.0024127960205078
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.064877986907959 | KNN Loss: 3.0348563194274902 | BCE Loss: 1.0300217866897583
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 4.051326751708984 | KNN Loss: 3.0578482151031494 | BCE Loss: 0.9934786558151245
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 4.058355331420898 | KNN Loss: 3.0422353744506836 | BCE Loss: 1.016120195388794
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 4.057433128356934 | KNN Loss: 3.0393223762512207 | BCE Loss: 1.018110752105713
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 4.105820655822754 | KNN Loss: 3.0920188426971436 | BCE Loss: 1.0138015747070312
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 4.0294389724731445 | KNN Loss: 3.038922071456909 | BCE Loss: 0.9905167818069458
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 4.088231086730957 | KNN Loss: 3.0647845268249

Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.074820518493652 | KNN Loss: 3.059860944747925 | BCE Loss: 1.0149593353271484
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.077322959899902 | KNN Loss: 3.0536935329437256 | BCE Loss: 1.0236293077468872
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 4.085323333740234 | KNN Loss: 3.068756341934204 | BCE Loss: 1.0165669918060303
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 4.125814437866211 | KNN Loss: 3.0888099670410156 | BCE Loss: 1.0370047092437744
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 4.084348678588867 | KNN Loss: 3.0829086303710938 | BCE Loss: 1.0014402866363525
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 4.132314682006836 | KNN Loss: 3.0944225788116455 | BCE Loss: 1.0378918647766113
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 4.035597324371338 | KNN Loss: 3.0204126834869385 | BCE Loss: 1.0151846408843994
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 4.066457748413086 | KNN Loss: 3.07345724105834

Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.079638957977295 | KNN Loss: 3.042768955230713 | BCE Loss: 1.0368698835372925
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.109965801239014 | KNN Loss: 3.082084894180298 | BCE Loss: 1.0278807878494263
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 4.082767486572266 | KNN Loss: 3.072995662689209 | BCE Loss: 1.0097715854644775
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 4.054298400878906 | KNN Loss: 3.0378551483154297 | BCE Loss: 1.0164430141448975
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 4.093414783477783 | KNN Loss: 3.0742766857147217 | BCE Loss: 1.019137978553772
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 4.090819835662842 | KNN Loss: 3.07895827293396 | BCE Loss: 1.0118614435195923
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 4.145552158355713 | KNN Loss: 3.1148338317871094 | BCE Loss: 1.030718207359314
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 4.0599870681762695 | KNN Loss: 3.039482355117798 

Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.104984283447266 | KNN Loss: 3.056821823120117 | BCE Loss: 1.0481626987457275
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.079991340637207 | KNN Loss: 3.051877498626709 | BCE Loss: 1.0281140804290771
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 4.057466983795166 | KNN Loss: 3.0654470920562744 | BCE Loss: 0.9920198917388916
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 4.107001781463623 | KNN Loss: 3.0615482330322266 | BCE Loss: 1.045453429222107
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 4.0921549797058105 | KNN Loss: 3.0709269046783447 | BCE Loss: 1.0212280750274658
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 4.114859580993652 | KNN Loss: 3.079730749130249 | BCE Loss: 1.0351288318634033
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 4.035788059234619 | KNN Loss: 3.0500328540802 | BCE Loss: 0.985755205154419
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 4.117266654968262 | KNN Loss: 3.072608470916748 |

Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.073573589324951 | KNN Loss: 3.070242404937744 | BCE Loss: 1.003331184387207
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.109986782073975 | KNN Loss: 3.0789897441864014 | BCE Loss: 1.0309971570968628
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 4.059248924255371 | KNN Loss: 3.0500409603118896 | BCE Loss: 1.0092079639434814
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 4.138001918792725 | KNN Loss: 3.11407732963562 | BCE Loss: 1.023924708366394
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 4.1047563552856445 | KNN Loss: 3.072605609893799 | BCE Loss: 1.0321505069732666
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 4.089522838592529 | KNN Loss: 3.062883138656616 | BCE Loss: 1.026639699935913
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 4.08200740814209 | KNN Loss: 3.0660314559936523 | BCE Loss: 1.0159759521484375
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 4.118971824645996 | KNN Loss: 3.0805106163024902 | B

Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.110325813293457 | KNN Loss: 3.0628061294555664 | BCE Loss: 1.0475199222564697
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.0865702629089355 | KNN Loss: 3.091723918914795 | BCE Loss: 0.9948463439941406
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 4.10693359375 | KNN Loss: 3.0725910663604736 | BCE Loss: 1.0343425273895264
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 4.1229658126831055 | KNN Loss: 3.0546011924743652 | BCE Loss: 1.0683645009994507
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 4.117815971374512 | KNN Loss: 3.0638058185577393 | BCE Loss: 1.0540101528167725
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 4.043185710906982 | KNN Loss: 3.0612683296203613 | BCE Loss: 0.9819175004959106
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 4.104577541351318 | KNN Loss: 3.0563957691192627 | BCE Loss: 1.0481818914413452
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 4.127537727355957 | KNN Loss: 3.0996971130371

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.9690,  3.3708,  2.5624,  3.2352,  3.2558,  0.8220,  2.5128,  2.0602,
          2.2463,  2.0747,  2.2316,  2.1456,  0.8478,  1.8535,  1.3311,  1.4486,
          2.7928,  2.8470,  2.7278,  2.3734,  1.8174,  2.6831,  2.3625,  2.3584,
          2.3992,  1.5953,  1.8204,  1.4668,  1.5452,  0.4173, -0.2491,  1.0625,
          0.1785,  0.9827,  1.6118,  1.5006,  0.9838,  2.9433,  0.8661,  1.3627,
          1.0244, -0.8252, -0.2011,  2.2390,  2.2572,  0.7698, -0.1759,  0.1700,
          1.4033,  2.3332,  1.8797,  0.0729,  1.3771,  0.5723, -0.6208,  1.2612,
          0.9353,  1.3836,  1.3940,  1.8973,  0.7152,  0.9429,  0.3115,  1.7001,
          1.3407,  1.5566, -1.8287,  0.3413,  2.3174,  2.1991,  2.4107,  0.5554,
          1.3664,  2.2806,  2.0282,  1.3582,  0.1327,  0.7336,  0.2443,  1.6848,
          0.0856,  0.4232,  1.6518, -0.3664,  0.2874, -1.0850, -2.3750, -0.1772,
          0.6430, -1.8912,  0.5034, -0.2220, -0.5476, -0.9107,  0.6913,  1.2768,
         -0.6644, -0.7303,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 81.65it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 018 | Total loss: 9.632 | Reg loss: 0.014 | Tree loss: 9.632 | Accuracy: 0.000000 | 4.111 sec/iter
Epoch: 00 | Batch: 001 / 018 | Total loss: 9.630 | Reg loss: 0.013 | Tree loss: 9.630 | Accuracy: 0.000000 | 4.112 sec/iter
Epoch: 00 | Batch: 002 / 018 | Total loss: 9.630 | Reg loss: 0.012 | Tree loss: 9.630 | Accuracy: 0.007812 | 4.084 sec/iter
Epoch: 00 | Batch: 003 / 018 | Total loss: 9.629 | Reg loss: 0.011 | Tree loss: 9.629 | Accuracy: 0.035156 | 4.089 sec/iter
Epoch: 00 | Batch: 004 / 018 | Total loss: 9.627 | Reg loss: 0.010 | Tree loss: 9.627 | Accuracy: 0.060547 | 4.094 sec/iter
Epoch: 00 | Batch: 005 / 018 | Total loss: 9.625 | Reg loss: 0.009 | Tree loss: 9.625 | Accuracy: 0.062500 | 4.102 sec/iter
Epoch: 00 | Batch: 006 / 018 | Total loss: 9.625 | Reg loss: 0.008 | Tree loss: 9.625 | 

Epoch: 03 | Batch: 003 / 018 | Total loss: 9.602 | Reg loss: 0.005 | Tree loss: 9.602 | Accuracy: 0.074219 | 4.972 sec/iter
Epoch: 03 | Batch: 004 / 018 | Total loss: 9.602 | Reg loss: 0.005 | Tree loss: 9.602 | Accuracy: 0.076172 | 4.972 sec/iter
Epoch: 03 | Batch: 005 / 018 | Total loss: 9.600 | Reg loss: 0.005 | Tree loss: 9.600 | Accuracy: 0.080078 | 4.967 sec/iter
Epoch: 03 | Batch: 006 / 018 | Total loss: 9.599 | Reg loss: 0.005 | Tree loss: 9.599 | Accuracy: 0.058594 | 4.966 sec/iter
Epoch: 03 | Batch: 007 / 018 | Total loss: 9.597 | Reg loss: 0.005 | Tree loss: 9.597 | Accuracy: 0.072266 | 4.965 sec/iter
Epoch: 03 | Batch: 008 / 018 | Total loss: 9.598 | Reg loss: 0.005 | Tree loss: 9.598 | Accuracy: 0.080078 | 4.964 sec/iter
Epoch: 03 | Batch: 009 / 018 | Total loss: 9.597 | Reg loss: 0.006 | Tree loss: 9.597 | Accuracy: 0.082031 | 4.962 sec/iter
Epoch: 03 | Batch: 010 / 018 | Total loss: 9.597 | Reg loss: 0.006 | Tree loss: 9.597 | Accuracy: 0.070312 | 4.958 sec/iter
Epoch: 0

Epoch: 06 | Batch: 007 / 018 | Total loss: 9.576 | Reg loss: 0.009 | Tree loss: 9.576 | Accuracy: 0.052734 | 5.064 sec/iter
Epoch: 06 | Batch: 008 / 018 | Total loss: 9.569 | Reg loss: 0.009 | Tree loss: 9.569 | Accuracy: 0.085938 | 5.063 sec/iter
Epoch: 06 | Batch: 009 / 018 | Total loss: 9.576 | Reg loss: 0.009 | Tree loss: 9.576 | Accuracy: 0.076172 | 5.061 sec/iter
Epoch: 06 | Batch: 010 / 018 | Total loss: 9.564 | Reg loss: 0.009 | Tree loss: 9.564 | Accuracy: 0.064453 | 5.058 sec/iter
Epoch: 06 | Batch: 011 / 018 | Total loss: 9.567 | Reg loss: 0.009 | Tree loss: 9.567 | Accuracy: 0.087891 | 5.055 sec/iter
Epoch: 06 | Batch: 012 / 018 | Total loss: 9.565 | Reg loss: 0.010 | Tree loss: 9.565 | Accuracy: 0.062500 | 5.052 sec/iter
Epoch: 06 | Batch: 013 / 018 | Total loss: 9.563 | Reg loss: 0.010 | Tree loss: 9.563 | Accuracy: 0.074219 | 5.049 sec/iter
Epoch: 06 | Batch: 014 / 018 | Total loss: 9.557 | Reg loss: 0.010 | Tree loss: 9.557 | Accuracy: 0.064453 | 5.046 sec/iter
Epoch: 0

Epoch: 09 | Batch: 011 / 018 | Total loss: 9.366 | Reg loss: 0.013 | Tree loss: 9.366 | Accuracy: 0.074219 | 5.085 sec/iter
Epoch: 09 | Batch: 012 / 018 | Total loss: 9.355 | Reg loss: 0.014 | Tree loss: 9.355 | Accuracy: 0.076172 | 5.082 sec/iter
Epoch: 09 | Batch: 013 / 018 | Total loss: 9.355 | Reg loss: 0.014 | Tree loss: 9.355 | Accuracy: 0.054688 | 5.08 sec/iter
Epoch: 09 | Batch: 014 / 018 | Total loss: 9.343 | Reg loss: 0.014 | Tree loss: 9.343 | Accuracy: 0.076172 | 5.078 sec/iter
Epoch: 09 | Batch: 015 / 018 | Total loss: 9.317 | Reg loss: 0.015 | Tree loss: 9.317 | Accuracy: 0.082031 | 5.076 sec/iter
Epoch: 09 | Batch: 016 / 018 | Total loss: 9.317 | Reg loss: 0.015 | Tree loss: 9.317 | Accuracy: 0.074219 | 5.075 sec/iter
Epoch: 09 | Batch: 017 / 018 | Total loss: 9.294 | Reg loss: 0.015 | Tree loss: 9.294 | Accuracy: 0.073171 | 5.066 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0

Epoch: 12 | Batch: 015 / 018 | Total loss: 8.846 | Reg loss: 0.017 | Tree loss: 8.846 | Accuracy: 0.064453 | 5.059 sec/iter
Epoch: 12 | Batch: 016 / 018 | Total loss: 8.819 | Reg loss: 0.017 | Tree loss: 8.819 | Accuracy: 0.099609 | 5.058 sec/iter
Epoch: 12 | Batch: 017 / 018 | Total loss: 8.788 | Reg loss: 0.017 | Tree loss: 8.788 | Accuracy: 0.121951 | 5.052 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 13 | Batch: 000 / 018 | Total loss: 8.983 | Reg loss: 0.015 | Tree loss: 8.983 | Accuracy: 0.076172 | 5.068 sec/iter
Epoch: 13 | Batch: 001 / 018 | Total loss: 8.941 | Reg loss: 0.015 | Tree loss: 8.941 | Accuracy: 0.078125 | 5.07 sec/iter
Epoch: 13 | Batch: 002 / 018 | Tot

Epoch: 16 | Batch: 000 / 018 | Total loss: 8.393 | Reg loss: 0.017 | Tree loss: 8.393 | Accuracy: 0.070312 | 5.111 sec/iter
Epoch: 16 | Batch: 001 / 018 | Total loss: 8.401 | Reg loss: 0.017 | Tree loss: 8.401 | Accuracy: 0.070312 | 5.112 sec/iter
Epoch: 16 | Batch: 002 / 018 | Total loss: 8.366 | Reg loss: 0.017 | Tree loss: 8.366 | Accuracy: 0.068359 | 5.112 sec/iter
Epoch: 16 | Batch: 003 / 018 | Total loss: 8.306 | Reg loss: 0.017 | Tree loss: 8.306 | Accuracy: 0.085938 | 5.112 sec/iter
Epoch: 16 | Batch: 004 / 018 | Total loss: 8.319 | Reg loss: 0.017 | Tree loss: 8.319 | Accuracy: 0.085938 | 5.112 sec/iter
Epoch: 16 | Batch: 005 / 018 | Total loss: 8.305 | Reg loss: 0.017 | Tree loss: 8.305 | Accuracy: 0.078125 | 5.11 sec/iter
Epoch: 16 | Batch: 006 / 018 | Total loss: 8.251 | Reg loss: 0.017 | Tree loss: 8.251 | Accuracy: 0.080078 | 5.106 sec/iter
Epoch: 16 | Batch: 007 / 018 | Total loss: 8.246 | Reg loss: 0.018 | Tree loss: 8.246 | Accuracy: 0.072266 | 5.107 sec/iter
Epoch: 16

Epoch: 19 | Batch: 004 / 018 | Total loss: 7.723 | Reg loss: 0.018 | Tree loss: 7.723 | Accuracy: 0.060547 | 5.11 sec/iter
Epoch: 19 | Batch: 005 / 018 | Total loss: 7.737 | Reg loss: 0.018 | Tree loss: 7.737 | Accuracy: 0.062500 | 5.11 sec/iter
Epoch: 19 | Batch: 006 / 018 | Total loss: 7.710 | Reg loss: 0.019 | Tree loss: 7.710 | Accuracy: 0.068359 | 5.11 sec/iter
Epoch: 19 | Batch: 007 / 018 | Total loss: 7.669 | Reg loss: 0.019 | Tree loss: 7.669 | Accuracy: 0.062500 | 5.11 sec/iter
Epoch: 19 | Batch: 008 / 018 | Total loss: 7.620 | Reg loss: 0.019 | Tree loss: 7.620 | Accuracy: 0.076172 | 5.109 sec/iter
Epoch: 19 | Batch: 009 / 018 | Total loss: 7.615 | Reg loss: 0.019 | Tree loss: 7.615 | Accuracy: 0.078125 | 5.109 sec/iter
Epoch: 19 | Batch: 010 / 018 | Total loss: 7.593 | Reg loss: 0.019 | Tree loss: 7.593 | Accuracy: 0.099609 | 5.108 sec/iter
Epoch: 19 | Batch: 011 / 018 | Total loss: 7.588 | Reg loss: 0.019 | Tree loss: 7.588 | Accuracy: 0.082031 | 5.106 sec/iter
Epoch: 19 | 

Epoch: 22 | Batch: 008 / 018 | Total loss: 7.117 | Reg loss: 0.019 | Tree loss: 7.117 | Accuracy: 0.078125 | 5.132 sec/iter
Epoch: 22 | Batch: 009 / 018 | Total loss: 7.103 | Reg loss: 0.019 | Tree loss: 7.103 | Accuracy: 0.076172 | 5.131 sec/iter
Epoch: 22 | Batch: 010 / 018 | Total loss: 7.103 | Reg loss: 0.020 | Tree loss: 7.103 | Accuracy: 0.052734 | 5.13 sec/iter
Epoch: 22 | Batch: 011 / 018 | Total loss: 7.040 | Reg loss: 0.020 | Tree loss: 7.040 | Accuracy: 0.056641 | 5.129 sec/iter
Epoch: 22 | Batch: 012 / 018 | Total loss: 7.058 | Reg loss: 0.020 | Tree loss: 7.058 | Accuracy: 0.054688 | 5.128 sec/iter
Epoch: 22 | Batch: 013 / 018 | Total loss: 7.021 | Reg loss: 0.020 | Tree loss: 7.021 | Accuracy: 0.080078 | 5.127 sec/iter
Epoch: 22 | Batch: 014 / 018 | Total loss: 6.988 | Reg loss: 0.020 | Tree loss: 6.988 | Accuracy: 0.083984 | 5.126 sec/iter
Epoch: 22 | Batch: 015 / 018 | Total loss: 6.950 | Reg loss: 0.020 | Tree loss: 6.950 | Accuracy: 0.087891 | 5.125 sec/iter
Epoch: 22

Epoch: 25 | Batch: 012 / 018 | Total loss: 6.531 | Reg loss: 0.020 | Tree loss: 6.531 | Accuracy: 0.083984 | 5.118 sec/iter
Epoch: 25 | Batch: 013 / 018 | Total loss: 6.555 | Reg loss: 0.020 | Tree loss: 6.555 | Accuracy: 0.082031 | 5.117 sec/iter
Epoch: 25 | Batch: 014 / 018 | Total loss: 6.537 | Reg loss: 0.020 | Tree loss: 6.537 | Accuracy: 0.070312 | 5.116 sec/iter
Epoch: 25 | Batch: 015 / 018 | Total loss: 6.593 | Reg loss: 0.020 | Tree loss: 6.593 | Accuracy: 0.058594 | 5.116 sec/iter
Epoch: 25 | Batch: 016 / 018 | Total loss: 6.535 | Reg loss: 0.020 | Tree loss: 6.535 | Accuracy: 0.060547 | 5.115 sec/iter
Epoch: 25 | Batch: 017 / 018 | Total loss: 6.514 | Reg loss: 0.020 | Tree loss: 6.514 | Accuracy: 0.073171 | 5.111 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 28 | Batch: 016 / 018 | Total loss: 6.085 | Reg loss: 0.021 | Tree loss: 6.085 | Accuracy: 0.070312 | 5.114 sec/iter
Epoch: 28 | Batch: 017 / 018 | Total loss: 6.001 | Reg loss: 0.021 | Tree loss: 6.001 | Accuracy: 0.048780 | 5.111 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 29 | Batch: 000 / 018 | Total loss: 6.209 | Reg loss: 0.021 | Tree loss: 6.209 | Accuracy: 0.072266 | 5.118 sec/iter
Epoch: 29 | Batch: 001 / 018 | Total loss: 6.190 | Reg loss: 0.021 | Tree loss: 6.190 | Accuracy: 0.062500 | 5.118 sec/iter
Epoch: 29 | Batch: 002 / 018 | Total loss: 6.160 | Reg loss: 0.021 | Tree loss: 6.160 | Accuracy: 0.062500 | 5.118 sec/iter
Epoch: 29 | Batch: 003 / 018 | To

Epoch: 32 | Batch: 000 / 018 | Total loss: 5.794 | Reg loss: 0.021 | Tree loss: 5.794 | Accuracy: 0.083984 | 5.117 sec/iter
Epoch: 32 | Batch: 001 / 018 | Total loss: 5.782 | Reg loss: 0.021 | Tree loss: 5.782 | Accuracy: 0.068359 | 5.117 sec/iter
Epoch: 32 | Batch: 002 / 018 | Total loss: 5.765 | Reg loss: 0.021 | Tree loss: 5.765 | Accuracy: 0.072266 | 5.116 sec/iter
Epoch: 32 | Batch: 003 / 018 | Total loss: 5.727 | Reg loss: 0.021 | Tree loss: 5.727 | Accuracy: 0.076172 | 5.116 sec/iter
Epoch: 32 | Batch: 004 / 018 | Total loss: 5.718 | Reg loss: 0.021 | Tree loss: 5.718 | Accuracy: 0.078125 | 5.116 sec/iter
Epoch: 32 | Batch: 005 / 018 | Total loss: 5.743 | Reg loss: 0.021 | Tree loss: 5.743 | Accuracy: 0.076172 | 5.117 sec/iter
Epoch: 32 | Batch: 006 / 018 | Total loss: 5.743 | Reg loss: 0.021 | Tree loss: 5.743 | Accuracy: 0.091797 | 5.116 sec/iter
Epoch: 32 | Batch: 007 / 018 | Total loss: 5.702 | Reg loss: 0.021 | Tree loss: 5.702 | Accuracy: 0.056641 | 5.116 sec/iter
Epoch: 3

Epoch: 35 | Batch: 004 / 018 | Total loss: 5.465 | Reg loss: 0.021 | Tree loss: 5.465 | Accuracy: 0.062500 | 5.109 sec/iter
Epoch: 35 | Batch: 005 / 018 | Total loss: 5.398 | Reg loss: 0.021 | Tree loss: 5.398 | Accuracy: 0.099609 | 5.109 sec/iter
Epoch: 35 | Batch: 006 / 018 | Total loss: 5.355 | Reg loss: 0.021 | Tree loss: 5.355 | Accuracy: 0.082031 | 5.108 sec/iter
Epoch: 35 | Batch: 007 / 018 | Total loss: 5.379 | Reg loss: 0.021 | Tree loss: 5.379 | Accuracy: 0.078125 | 5.108 sec/iter
Epoch: 35 | Batch: 008 / 018 | Total loss: 5.368 | Reg loss: 0.021 | Tree loss: 5.368 | Accuracy: 0.066406 | 5.107 sec/iter
Epoch: 35 | Batch: 009 / 018 | Total loss: 5.423 | Reg loss: 0.021 | Tree loss: 5.423 | Accuracy: 0.050781 | 5.106 sec/iter
Epoch: 35 | Batch: 010 / 018 | Total loss: 5.321 | Reg loss: 0.021 | Tree loss: 5.321 | Accuracy: 0.083984 | 5.106 sec/iter
Epoch: 35 | Batch: 011 / 018 | Total loss: 5.382 | Reg loss: 0.021 | Tree loss: 5.382 | Accuracy: 0.070312 | 5.105 sec/iter
Epoch: 3

Epoch: 38 | Batch: 008 / 018 | Total loss: 5.092 | Reg loss: 0.022 | Tree loss: 5.092 | Accuracy: 0.091797 | 5.112 sec/iter
Epoch: 38 | Batch: 009 / 018 | Total loss: 5.098 | Reg loss: 0.022 | Tree loss: 5.098 | Accuracy: 0.066406 | 5.112 sec/iter
Epoch: 38 | Batch: 010 / 018 | Total loss: 5.107 | Reg loss: 0.022 | Tree loss: 5.107 | Accuracy: 0.070312 | 5.111 sec/iter
Epoch: 38 | Batch: 011 / 018 | Total loss: 5.060 | Reg loss: 0.022 | Tree loss: 5.060 | Accuracy: 0.078125 | 5.111 sec/iter
Epoch: 38 | Batch: 012 / 018 | Total loss: 5.079 | Reg loss: 0.022 | Tree loss: 5.079 | Accuracy: 0.074219 | 5.11 sec/iter
Epoch: 38 | Batch: 013 / 018 | Total loss: 5.059 | Reg loss: 0.022 | Tree loss: 5.059 | Accuracy: 0.064453 | 5.11 sec/iter
Epoch: 38 | Batch: 014 / 018 | Total loss: 5.065 | Reg loss: 0.022 | Tree loss: 5.065 | Accuracy: 0.068359 | 5.109 sec/iter
Epoch: 38 | Batch: 015 / 018 | Total loss: 5.032 | Reg loss: 0.022 | Tree loss: 5.032 | Accuracy: 0.072266 | 5.109 sec/iter
Epoch: 38 

Epoch: 41 | Batch: 012 / 018 | Total loss: 4.823 | Reg loss: 0.022 | Tree loss: 4.823 | Accuracy: 0.052734 | 5.11 sec/iter
Epoch: 41 | Batch: 013 / 018 | Total loss: 4.883 | Reg loss: 0.022 | Tree loss: 4.883 | Accuracy: 0.072266 | 5.11 sec/iter
Epoch: 41 | Batch: 014 / 018 | Total loss: 4.784 | Reg loss: 0.022 | Tree loss: 4.784 | Accuracy: 0.091797 | 5.11 sec/iter
Epoch: 41 | Batch: 015 / 018 | Total loss: 4.799 | Reg loss: 0.022 | Tree loss: 4.799 | Accuracy: 0.076172 | 5.109 sec/iter
Epoch: 41 | Batch: 016 / 018 | Total loss: 4.768 | Reg loss: 0.022 | Tree loss: 4.768 | Accuracy: 0.078125 | 5.109 sec/iter
Epoch: 41 | Batch: 017 / 018 | Total loss: 4.889 | Reg loss: 0.022 | Tree loss: 4.889 | Accuracy: 0.024390 | 5.107 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714285

Epoch: 44 | Batch: 016 / 018 | Total loss: 4.627 | Reg loss: 0.022 | Tree loss: 4.627 | Accuracy: 0.076172 | 5.112 sec/iter
Epoch: 44 | Batch: 017 / 018 | Total loss: 4.467 | Reg loss: 0.022 | Tree loss: 4.467 | Accuracy: 0.073171 | 5.11 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 45 | Batch: 000 / 018 | Total loss: 4.645 | Reg loss: 0.022 | Tree loss: 4.645 | Accuracy: 0.095703 | 5.114 sec/iter
Epoch: 45 | Batch: 001 / 018 | Total loss: 4.636 | Reg loss: 0.022 | Tree loss: 4.636 | Accuracy: 0.080078 | 5.114 sec/iter
Epoch: 45 | Batch: 002 / 018 | Total loss: 4.625 | Reg loss: 0.022 | Tree loss: 4.625 | Accuracy: 0.082031 | 5.113 sec/iter
Epoch: 45 | Batch: 003 / 018 | Tot

Epoch: 48 | Batch: 000 / 018 | Total loss: 4.516 | Reg loss: 0.022 | Tree loss: 4.516 | Accuracy: 0.097656 | 5.118 sec/iter
Epoch: 48 | Batch: 001 / 018 | Total loss: 4.469 | Reg loss: 0.022 | Tree loss: 4.469 | Accuracy: 0.076172 | 5.118 sec/iter
Epoch: 48 | Batch: 002 / 018 | Total loss: 4.506 | Reg loss: 0.022 | Tree loss: 4.506 | Accuracy: 0.058594 | 5.118 sec/iter
Epoch: 48 | Batch: 003 / 018 | Total loss: 4.438 | Reg loss: 0.022 | Tree loss: 4.438 | Accuracy: 0.083984 | 5.118 sec/iter
Epoch: 48 | Batch: 004 / 018 | Total loss: 4.490 | Reg loss: 0.022 | Tree loss: 4.490 | Accuracy: 0.066406 | 5.117 sec/iter
Epoch: 48 | Batch: 005 / 018 | Total loss: 4.504 | Reg loss: 0.022 | Tree loss: 4.504 | Accuracy: 0.070312 | 5.117 sec/iter
Epoch: 48 | Batch: 006 / 018 | Total loss: 4.407 | Reg loss: 0.022 | Tree loss: 4.407 | Accuracy: 0.080078 | 5.117 sec/iter
Epoch: 48 | Batch: 007 / 018 | Total loss: 4.374 | Reg loss: 0.022 | Tree loss: 4.374 | Accuracy: 0.099609 | 5.116 sec/iter
Epoch: 4

Epoch: 51 | Batch: 004 / 018 | Total loss: 4.360 | Reg loss: 0.022 | Tree loss: 4.360 | Accuracy: 0.072266 | 5.12 sec/iter
Epoch: 51 | Batch: 005 / 018 | Total loss: 4.332 | Reg loss: 0.022 | Tree loss: 4.332 | Accuracy: 0.064453 | 5.12 sec/iter
Epoch: 51 | Batch: 006 / 018 | Total loss: 4.333 | Reg loss: 0.022 | Tree loss: 4.333 | Accuracy: 0.062500 | 5.12 sec/iter
Epoch: 51 | Batch: 007 / 018 | Total loss: 4.266 | Reg loss: 0.022 | Tree loss: 4.266 | Accuracy: 0.074219 | 5.119 sec/iter
Epoch: 51 | Batch: 008 / 018 | Total loss: 4.278 | Reg loss: 0.022 | Tree loss: 4.278 | Accuracy: 0.076172 | 5.119 sec/iter
Epoch: 51 | Batch: 009 / 018 | Total loss: 4.342 | Reg loss: 0.023 | Tree loss: 4.342 | Accuracy: 0.062500 | 5.118 sec/iter
Epoch: 51 | Batch: 010 / 018 | Total loss: 4.276 | Reg loss: 0.023 | Tree loss: 4.276 | Accuracy: 0.080078 | 5.118 sec/iter
Epoch: 51 | Batch: 011 / 018 | Total loss: 4.278 | Reg loss: 0.023 | Tree loss: 4.278 | Accuracy: 0.068359 | 5.117 sec/iter
Epoch: 51 |

Epoch: 54 | Batch: 008 / 018 | Total loss: 4.164 | Reg loss: 0.023 | Tree loss: 4.164 | Accuracy: 0.080078 | 5.116 sec/iter
Epoch: 54 | Batch: 009 / 018 | Total loss: 4.216 | Reg loss: 0.023 | Tree loss: 4.216 | Accuracy: 0.074219 | 5.117 sec/iter
Epoch: 54 | Batch: 010 / 018 | Total loss: 4.171 | Reg loss: 0.023 | Tree loss: 4.171 | Accuracy: 0.056641 | 5.117 sec/iter
Epoch: 54 | Batch: 011 / 018 | Total loss: 4.101 | Reg loss: 0.023 | Tree loss: 4.101 | Accuracy: 0.072266 | 5.117 sec/iter
Epoch: 54 | Batch: 012 / 018 | Total loss: 4.160 | Reg loss: 0.023 | Tree loss: 4.160 | Accuracy: 0.083984 | 5.116 sec/iter
Epoch: 54 | Batch: 013 / 018 | Total loss: 4.173 | Reg loss: 0.023 | Tree loss: 4.173 | Accuracy: 0.091797 | 5.116 sec/iter
Epoch: 54 | Batch: 014 / 018 | Total loss: 4.174 | Reg loss: 0.023 | Tree loss: 4.174 | Accuracy: 0.064453 | 5.116 sec/iter
Epoch: 54 | Batch: 015 / 018 | Total loss: 4.162 | Reg loss: 0.023 | Tree loss: 4.162 | Accuracy: 0.085938 | 5.116 sec/iter
Epoch: 5

Epoch: 57 | Batch: 012 / 018 | Total loss: 4.082 | Reg loss: 0.023 | Tree loss: 4.082 | Accuracy: 0.080078 | 5.124 sec/iter
Epoch: 57 | Batch: 013 / 018 | Total loss: 4.114 | Reg loss: 0.023 | Tree loss: 4.114 | Accuracy: 0.078125 | 5.124 sec/iter
Epoch: 57 | Batch: 014 / 018 | Total loss: 4.044 | Reg loss: 0.023 | Tree loss: 4.044 | Accuracy: 0.087891 | 5.124 sec/iter
Epoch: 57 | Batch: 015 / 018 | Total loss: 4.118 | Reg loss: 0.023 | Tree loss: 4.118 | Accuracy: 0.054688 | 5.124 sec/iter
Epoch: 57 | Batch: 016 / 018 | Total loss: 4.054 | Reg loss: 0.023 | Tree loss: 4.054 | Accuracy: 0.070312 | 5.124 sec/iter
Epoch: 57 | Batch: 017 / 018 | Total loss: 3.908 | Reg loss: 0.023 | Tree loss: 3.908 | Accuracy: 0.097561 | 5.122 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 60 | Batch: 016 / 018 | Total loss: 4.008 | Reg loss: 0.023 | Tree loss: 4.008 | Accuracy: 0.087891 | 5.126 sec/iter
Epoch: 60 | Batch: 017 / 018 | Total loss: 3.972 | Reg loss: 0.023 | Tree loss: 3.972 | Accuracy: 0.024390 | 5.125 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 61 | Batch: 000 / 018 | Total loss: 4.006 | Reg loss: 0.023 | Tree loss: 4.006 | Accuracy: 0.070312 | 5.125 sec/iter
Epoch: 61 | Batch: 001 / 018 | Total loss: 4.004 | Reg loss: 0.023 | Tree loss: 4.004 | Accuracy: 0.074219 | 5.125 sec/iter
Epoch: 61 | Batch: 002 / 018 | Total loss: 4.024 | Reg loss: 0.023 | Tree loss: 4.024 | Accuracy: 0.078125 | 5.125 sec/iter
Epoch: 61 | Batch: 003 / 018 | To

Epoch: 64 | Batch: 000 / 018 | Total loss: 3.987 | Reg loss: 0.023 | Tree loss: 3.987 | Accuracy: 0.078125 | 5.122 sec/iter
Epoch: 64 | Batch: 001 / 018 | Total loss: 3.926 | Reg loss: 0.023 | Tree loss: 3.926 | Accuracy: 0.068359 | 5.122 sec/iter
Epoch: 64 | Batch: 002 / 018 | Total loss: 3.949 | Reg loss: 0.023 | Tree loss: 3.949 | Accuracy: 0.078125 | 5.122 sec/iter
Epoch: 64 | Batch: 003 / 018 | Total loss: 3.994 | Reg loss: 0.023 | Tree loss: 3.994 | Accuracy: 0.078125 | 5.123 sec/iter
Epoch: 64 | Batch: 004 / 018 | Total loss: 3.986 | Reg loss: 0.023 | Tree loss: 3.986 | Accuracy: 0.072266 | 5.123 sec/iter
Epoch: 64 | Batch: 005 / 018 | Total loss: 3.958 | Reg loss: 0.023 | Tree loss: 3.958 | Accuracy: 0.089844 | 5.123 sec/iter
Epoch: 64 | Batch: 006 / 018 | Total loss: 3.937 | Reg loss: 0.023 | Tree loss: 3.937 | Accuracy: 0.091797 | 5.122 sec/iter
Epoch: 64 | Batch: 007 / 018 | Total loss: 3.882 | Reg loss: 0.023 | Tree loss: 3.882 | Accuracy: 0.087891 | 5.122 sec/iter
Epoch: 6

Epoch: 67 | Batch: 004 / 018 | Total loss: 3.934 | Reg loss: 0.023 | Tree loss: 3.934 | Accuracy: 0.080078 | 5.12 sec/iter
Epoch: 67 | Batch: 005 / 018 | Total loss: 3.909 | Reg loss: 0.023 | Tree loss: 3.909 | Accuracy: 0.076172 | 5.12 sec/iter
Epoch: 67 | Batch: 006 / 018 | Total loss: 3.891 | Reg loss: 0.023 | Tree loss: 3.891 | Accuracy: 0.074219 | 5.12 sec/iter
Epoch: 67 | Batch: 007 / 018 | Total loss: 3.896 | Reg loss: 0.023 | Tree loss: 3.896 | Accuracy: 0.089844 | 5.12 sec/iter
Epoch: 67 | Batch: 008 / 018 | Total loss: 3.907 | Reg loss: 0.023 | Tree loss: 3.907 | Accuracy: 0.089844 | 5.12 sec/iter
Epoch: 67 | Batch: 009 / 018 | Total loss: 3.859 | Reg loss: 0.023 | Tree loss: 3.859 | Accuracy: 0.078125 | 5.121 sec/iter
Epoch: 67 | Batch: 010 / 018 | Total loss: 3.841 | Reg loss: 0.023 | Tree loss: 3.841 | Accuracy: 0.070312 | 5.121 sec/iter
Epoch: 67 | Batch: 011 / 018 | Total loss: 3.866 | Reg loss: 0.023 | Tree loss: 3.866 | Accuracy: 0.083984 | 5.121 sec/iter
Epoch: 67 | B

Epoch: 70 | Batch: 008 / 018 | Total loss: 3.911 | Reg loss: 0.023 | Tree loss: 3.911 | Accuracy: 0.060547 | 5.121 sec/iter
Epoch: 70 | Batch: 009 / 018 | Total loss: 3.801 | Reg loss: 0.023 | Tree loss: 3.801 | Accuracy: 0.085938 | 5.121 sec/iter
Epoch: 70 | Batch: 010 / 018 | Total loss: 3.830 | Reg loss: 0.023 | Tree loss: 3.830 | Accuracy: 0.089844 | 5.121 sec/iter
Epoch: 70 | Batch: 011 / 018 | Total loss: 3.851 | Reg loss: 0.023 | Tree loss: 3.851 | Accuracy: 0.091797 | 5.121 sec/iter
Epoch: 70 | Batch: 012 / 018 | Total loss: 3.863 | Reg loss: 0.023 | Tree loss: 3.863 | Accuracy: 0.076172 | 5.121 sec/iter
Epoch: 70 | Batch: 013 / 018 | Total loss: 3.847 | Reg loss: 0.023 | Tree loss: 3.847 | Accuracy: 0.076172 | 5.121 sec/iter
Epoch: 70 | Batch: 014 / 018 | Total loss: 3.890 | Reg loss: 0.023 | Tree loss: 3.890 | Accuracy: 0.070312 | 5.121 sec/iter
Epoch: 70 | Batch: 015 / 018 | Total loss: 3.836 | Reg loss: 0.023 | Tree loss: 3.836 | Accuracy: 0.080078 | 5.121 sec/iter
Epoch: 7

Epoch: 73 | Batch: 012 / 018 | Total loss: 3.816 | Reg loss: 0.023 | Tree loss: 3.816 | Accuracy: 0.080078 | 5.123 sec/iter
Epoch: 73 | Batch: 013 / 018 | Total loss: 3.795 | Reg loss: 0.023 | Tree loss: 3.795 | Accuracy: 0.093750 | 5.123 sec/iter
Epoch: 73 | Batch: 014 / 018 | Total loss: 3.796 | Reg loss: 0.023 | Tree loss: 3.796 | Accuracy: 0.076172 | 5.123 sec/iter
Epoch: 73 | Batch: 015 / 018 | Total loss: 3.788 | Reg loss: 0.023 | Tree loss: 3.788 | Accuracy: 0.082031 | 5.123 sec/iter
Epoch: 73 | Batch: 016 / 018 | Total loss: 3.792 | Reg loss: 0.023 | Tree loss: 3.792 | Accuracy: 0.078125 | 5.123 sec/iter
Epoch: 73 | Batch: 017 / 018 | Total loss: 3.930 | Reg loss: 0.023 | Tree loss: 3.930 | Accuracy: 0.097561 | 5.122 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 76 | Batch: 016 / 018 | Total loss: 3.763 | Reg loss: 0.023 | Tree loss: 3.763 | Accuracy: 0.085938 | 5.126 sec/iter
Epoch: 76 | Batch: 017 / 018 | Total loss: 3.851 | Reg loss: 0.023 | Tree loss: 3.851 | Accuracy: 0.024390 | 5.125 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 77 | Batch: 000 / 018 | Total loss: 3.793 | Reg loss: 0.023 | Tree loss: 3.793 | Accuracy: 0.099609 | 5.128 sec/iter
Epoch: 77 | Batch: 001 / 018 | Total loss: 3.812 | Reg loss: 0.023 | Tree loss: 3.812 | Accuracy: 0.068359 | 5.127 sec/iter
Epoch: 77 | Batch: 002 / 018 | Total loss: 3.816 | Reg loss: 0.023 | Tree loss: 3.816 | Accuracy: 0.060547 | 5.127 sec/iter
Epoch: 77 | Batch: 003 / 018 | To

Epoch: 80 | Batch: 000 / 018 | Total loss: 3.784 | Reg loss: 0.023 | Tree loss: 3.784 | Accuracy: 0.074219 | 5.122 sec/iter
Epoch: 80 | Batch: 001 / 018 | Total loss: 3.780 | Reg loss: 0.023 | Tree loss: 3.780 | Accuracy: 0.078125 | 5.122 sec/iter
Epoch: 80 | Batch: 002 / 018 | Total loss: 3.783 | Reg loss: 0.023 | Tree loss: 3.783 | Accuracy: 0.082031 | 5.122 sec/iter
Epoch: 80 | Batch: 003 / 018 | Total loss: 3.723 | Reg loss: 0.023 | Tree loss: 3.723 | Accuracy: 0.085938 | 5.122 sec/iter
Epoch: 80 | Batch: 004 / 018 | Total loss: 3.729 | Reg loss: 0.023 | Tree loss: 3.729 | Accuracy: 0.064453 | 5.123 sec/iter
Epoch: 80 | Batch: 005 / 018 | Total loss: 3.734 | Reg loss: 0.023 | Tree loss: 3.734 | Accuracy: 0.080078 | 5.123 sec/iter
Epoch: 80 | Batch: 006 / 018 | Total loss: 3.785 | Reg loss: 0.023 | Tree loss: 3.785 | Accuracy: 0.085938 | 5.123 sec/iter
Epoch: 80 | Batch: 007 / 018 | Total loss: 3.739 | Reg loss: 0.023 | Tree loss: 3.739 | Accuracy: 0.080078 | 5.123 sec/iter
Epoch: 8

Epoch: 83 | Batch: 004 / 018 | Total loss: 3.721 | Reg loss: 0.023 | Tree loss: 3.721 | Accuracy: 0.082031 | 5.129 sec/iter
Epoch: 83 | Batch: 005 / 018 | Total loss: 3.717 | Reg loss: 0.023 | Tree loss: 3.717 | Accuracy: 0.091797 | 5.13 sec/iter
Epoch: 83 | Batch: 006 / 018 | Total loss: 3.736 | Reg loss: 0.023 | Tree loss: 3.736 | Accuracy: 0.062500 | 5.13 sec/iter
Epoch: 83 | Batch: 007 / 018 | Total loss: 3.795 | Reg loss: 0.023 | Tree loss: 3.795 | Accuracy: 0.070312 | 5.13 sec/iter
Epoch: 83 | Batch: 008 / 018 | Total loss: 3.662 | Reg loss: 0.023 | Tree loss: 3.662 | Accuracy: 0.091797 | 5.13 sec/iter
Epoch: 83 | Batch: 009 / 018 | Total loss: 3.740 | Reg loss: 0.023 | Tree loss: 3.740 | Accuracy: 0.082031 | 5.13 sec/iter
Epoch: 83 | Batch: 010 / 018 | Total loss: 3.772 | Reg loss: 0.023 | Tree loss: 3.772 | Accuracy: 0.085938 | 5.131 sec/iter
Epoch: 83 | Batch: 011 / 018 | Total loss: 3.770 | Reg loss: 0.023 | Tree loss: 3.770 | Accuracy: 0.070312 | 5.131 sec/iter
Epoch: 83 | B

Epoch: 86 | Batch: 008 / 018 | Total loss: 3.694 | Reg loss: 0.023 | Tree loss: 3.694 | Accuracy: 0.070312 | 5.131 sec/iter
Epoch: 86 | Batch: 009 / 018 | Total loss: 3.767 | Reg loss: 0.023 | Tree loss: 3.767 | Accuracy: 0.085938 | 5.132 sec/iter
Epoch: 86 | Batch: 010 / 018 | Total loss: 3.633 | Reg loss: 0.023 | Tree loss: 3.633 | Accuracy: 0.103516 | 5.132 sec/iter
Epoch: 86 | Batch: 011 / 018 | Total loss: 3.693 | Reg loss: 0.023 | Tree loss: 3.693 | Accuracy: 0.068359 | 5.132 sec/iter
Epoch: 86 | Batch: 012 / 018 | Total loss: 3.641 | Reg loss: 0.023 | Tree loss: 3.641 | Accuracy: 0.082031 | 5.132 sec/iter
Epoch: 86 | Batch: 013 / 018 | Total loss: 3.758 | Reg loss: 0.023 | Tree loss: 3.758 | Accuracy: 0.076172 | 5.132 sec/iter
Epoch: 86 | Batch: 014 / 018 | Total loss: 3.727 | Reg loss: 0.023 | Tree loss: 3.727 | Accuracy: 0.076172 | 5.132 sec/iter
Epoch: 86 | Batch: 015 / 018 | Total loss: 3.695 | Reg loss: 0.023 | Tree loss: 3.695 | Accuracy: 0.091797 | 5.132 sec/iter
Epoch: 8

Epoch: 89 | Batch: 012 / 018 | Total loss: 3.689 | Reg loss: 0.024 | Tree loss: 3.689 | Accuracy: 0.085938 | 5.13 sec/iter
Epoch: 89 | Batch: 013 / 018 | Total loss: 3.696 | Reg loss: 0.024 | Tree loss: 3.696 | Accuracy: 0.085938 | 5.13 sec/iter
Epoch: 89 | Batch: 014 / 018 | Total loss: 3.656 | Reg loss: 0.024 | Tree loss: 3.656 | Accuracy: 0.083984 | 5.129 sec/iter
Epoch: 89 | Batch: 015 / 018 | Total loss: 3.639 | Reg loss: 0.024 | Tree loss: 3.639 | Accuracy: 0.082031 | 5.129 sec/iter
Epoch: 89 | Batch: 016 / 018 | Total loss: 3.719 | Reg loss: 0.024 | Tree loss: 3.719 | Accuracy: 0.082031 | 5.129 sec/iter
Epoch: 89 | Batch: 017 / 018 | Total loss: 3.680 | Reg loss: 0.024 | Tree loss: 3.680 | Accuracy: 0.195122 | 5.128 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428

Epoch: 92 | Batch: 016 / 018 | Total loss: 3.661 | Reg loss: 0.024 | Tree loss: 3.661 | Accuracy: 0.078125 | 5.136 sec/iter
Epoch: 92 | Batch: 017 / 018 | Total loss: 3.719 | Reg loss: 0.024 | Tree loss: 3.719 | Accuracy: 0.024390 | 5.135 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 93 | Batch: 000 / 018 | Total loss: 3.696 | Reg loss: 0.024 | Tree loss: 3.696 | Accuracy: 0.076172 | 5.136 sec/iter
Epoch: 93 | Batch: 001 / 018 | Total loss: 3.717 | Reg loss: 0.024 | Tree loss: 3.717 | Accuracy: 0.093750 | 5.136 sec/iter
Epoch: 93 | Batch: 002 / 018 | Total loss: 3.706 | Reg loss: 0.024 | Tree loss: 3.706 | Accuracy: 0.085938 | 5.136 sec/iter
Epoch: 93 | Batch: 003 / 018 | To

Epoch: 96 | Batch: 000 / 018 | Total loss: 3.744 | Reg loss: 0.024 | Tree loss: 3.744 | Accuracy: 0.080078 | 5.142 sec/iter
Epoch: 96 | Batch: 001 / 018 | Total loss: 3.639 | Reg loss: 0.024 | Tree loss: 3.639 | Accuracy: 0.105469 | 5.142 sec/iter
Epoch: 96 | Batch: 002 / 018 | Total loss: 3.680 | Reg loss: 0.024 | Tree loss: 3.680 | Accuracy: 0.083984 | 5.142 sec/iter
Epoch: 96 | Batch: 003 / 018 | Total loss: 3.692 | Reg loss: 0.024 | Tree loss: 3.692 | Accuracy: 0.074219 | 5.143 sec/iter
Epoch: 96 | Batch: 004 / 018 | Total loss: 3.705 | Reg loss: 0.024 | Tree loss: 3.705 | Accuracy: 0.072266 | 5.143 sec/iter
Epoch: 96 | Batch: 005 / 018 | Total loss: 3.701 | Reg loss: 0.024 | Tree loss: 3.701 | Accuracy: 0.072266 | 5.143 sec/iter
Epoch: 96 | Batch: 006 / 018 | Total loss: 3.624 | Reg loss: 0.024 | Tree loss: 3.624 | Accuracy: 0.093750 | 5.143 sec/iter
Epoch: 96 | Batch: 007 / 018 | Total loss: 3.628 | Reg loss: 0.024 | Tree loss: 3.628 | Accuracy: 0.078125 | 5.144 sec/iter
Epoch: 9

Epoch: 99 | Batch: 004 / 018 | Total loss: 3.672 | Reg loss: 0.024 | Tree loss: 3.672 | Accuracy: 0.080078 | 5.145 sec/iter
Epoch: 99 | Batch: 005 / 018 | Total loss: 3.619 | Reg loss: 0.024 | Tree loss: 3.619 | Accuracy: 0.072266 | 5.145 sec/iter
Epoch: 99 | Batch: 006 / 018 | Total loss: 3.718 | Reg loss: 0.024 | Tree loss: 3.718 | Accuracy: 0.089844 | 5.145 sec/iter
Epoch: 99 | Batch: 007 / 018 | Total loss: 3.681 | Reg loss: 0.024 | Tree loss: 3.681 | Accuracy: 0.082031 | 5.145 sec/iter
Epoch: 99 | Batch: 008 / 018 | Total loss: 3.689 | Reg loss: 0.024 | Tree loss: 3.689 | Accuracy: 0.085938 | 5.146 sec/iter
Epoch: 99 | Batch: 009 / 018 | Total loss: 3.727 | Reg loss: 0.024 | Tree loss: 3.727 | Accuracy: 0.085938 | 5.146 sec/iter
Epoch: 99 | Batch: 010 / 018 | Total loss: 3.691 | Reg loss: 0.024 | Tree loss: 3.691 | Accuracy: 0.064453 | 5.146 sec/iter
Epoch: 99 | Batch: 011 / 018 | Total loss: 3.631 | Reg loss: 0.024 | Tree loss: 3.631 | Accuracy: 0.091797 | 5.146 sec/iter
Epoch: 9

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 12.0


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 4096


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")



  return np.log(1 / (1 - x))


5662
3083












Average comprehensibility: 54.7666015625
std comprehensibility: 3.3709781843513253
var comprehensibility: 11.363493919372559
minimum comprehensibility: 48
maximum comprehensibility: 68
