In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 64
tree_depth = 6
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.20648193359375 | KNN Loss: 6.229707717895508 | BCE Loss: 1.9767743349075317
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.153791427612305 | KNN Loss: 6.229312896728516 | BCE Loss: 1.9244780540466309
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.150104522705078 | KNN Loss: 6.22951078414917 | BCE Loss: 1.9205937385559082
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.113521575927734 | KNN Loss: 6.229531288146973 | BCE Loss: 1.8839905261993408
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.16419506072998 | KNN Loss: 6.228993892669678 | BCE Loss: 1.9352010488510132
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.130905151367188 | KNN Loss: 6.228990077972412 | BCE Loss: 1.9019153118133545
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.154356956481934 | KNN Loss: 6.228734970092773 | BCE Loss: 1.9256222248077393
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.132513046264648 | KNN Loss: 6.228709697723389 | BCE Loss: 1.90380287

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.164376258850098 | KNN Loss: 5.898101329803467 | BCE Loss: 1.2662748098373413
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.144925117492676 | KNN Loss: 5.868739604949951 | BCE Loss: 1.2761855125427246
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.037556171417236 | KNN Loss: 5.836838245391846 | BCE Loss: 1.2007179260253906
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.977527618408203 | KNN Loss: 5.7611799240112305 | BCE Loss: 1.2163479328155518
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.82755184173584 | KNN Loss: 5.637224197387695 | BCE Loss: 1.1903276443481445
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.773477554321289 | KNN Loss: 5.586361885070801 | BCE Loss: 1.1871156692504883
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.676501274108887 | KNN Loss: 5.483994007110596 | BCE Loss: 1.192507028579712
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.55367374420166 | KNN Loss: 5.393153667449951 | BCE Loss: 

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 5.5549540519714355 | KNN Loss: 4.488445281982422 | BCE Loss: 1.0665087699890137
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 5.548538684844971 | KNN Loss: 4.491568088531494 | BCE Loss: 1.0569705963134766
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 5.554873466491699 | KNN Loss: 4.491104602813721 | BCE Loss: 1.0637686252593994
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 5.548341751098633 | KNN Loss: 4.495208740234375 | BCE Loss: 1.053133249282837
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 5.5611186027526855 | KNN Loss: 4.513321399688721 | BCE Loss: 1.0477972030639648
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 5.552997589111328 | KNN Loss: 4.482912063598633 | BCE Loss: 1.0700852870941162
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 5.53004264831543 | KNN Loss: 4.47650146484375 | BCE Loss: 1.0535414218902588
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 5.532318592071533 | KNN Loss: 4.472842216491699 | BCE Loss

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 5.465092658996582 | KNN Loss: 4.412795543670654 | BCE Loss: 1.0522971153259277
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 5.471210479736328 | KNN Loss: 4.431530475616455 | BCE Loss: 1.039679765701294
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 5.5225958824157715 | KNN Loss: 4.42623233795166 | BCE Loss: 1.0963634252548218
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 5.483162879943848 | KNN Loss: 4.425594329833984 | BCE Loss: 1.0575687885284424
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 5.465487480163574 | KNN Loss: 4.40500020980835 | BCE Loss: 1.0604875087738037
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 5.452884197235107 | KNN Loss: 4.401498317718506 | BCE Loss: 1.0513858795166016
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 5.537181854248047 | KNN Loss: 4.447823524475098 | BCE Loss: 1.0893582105636597
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 5.439001083374023 | KNN Loss: 4.420112609863281 | BCE Loss:

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 5.449646949768066 | KNN Loss: 4.367004871368408 | BCE Loss: 1.0826420783996582
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 5.438197135925293 | KNN Loss: 4.406385898590088 | BCE Loss: 1.031810998916626
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 5.438326358795166 | KNN Loss: 4.400182723999023 | BCE Loss: 1.038143515586853
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 5.430689811706543 | KNN Loss: 4.380826473236084 | BCE Loss: 1.0498634576797485
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 5.429266929626465 | KNN Loss: 4.39487361907959 | BCE Loss: 1.034393072128296
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 5.439969062805176 | KNN Loss: 4.393905162811279 | BCE Loss: 1.0460641384124756
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 5.422880172729492 | KNN Loss: 4.380152225494385 | BCE Loss: 1.0427279472351074
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 5.443521499633789 | KNN Loss: 4.396193027496338 | BCE Loss: 1

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 5.449002265930176 | KNN Loss: 4.384759426116943 | BCE Loss: 1.064242959022522
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 5.4345502853393555 | KNN Loss: 4.393093109130859 | BCE Loss: 1.041456937789917
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 5.435512542724609 | KNN Loss: 4.39929723739624 | BCE Loss: 1.03621506690979
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 5.455670356750488 | KNN Loss: 4.377146244049072 | BCE Loss: 1.078524112701416
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 5.399627208709717 | KNN Loss: 4.3628764152526855 | BCE Loss: 1.0367507934570312
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 5.409934997558594 | KNN Loss: 4.399448871612549 | BCE Loss: 1.0104858875274658
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 5.371906280517578 | KNN Loss: 4.363832950592041 | BCE Loss: 1.0080735683441162
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 5.45555305480957 | KNN Loss: 4.432048797607422 | BCE Loss: 1.0

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 5.391382217407227 | KNN Loss: 4.372014999389648 | BCE Loss: 1.019366979598999
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 5.449484825134277 | KNN Loss: 4.38981819152832 | BCE Loss: 1.0596665143966675
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 5.417436599731445 | KNN Loss: 4.3928608894348145 | BCE Loss: 1.0245754718780518
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 5.410167694091797 | KNN Loss: 4.370477199554443 | BCE Loss: 1.0396902561187744
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 5.371586799621582 | KNN Loss: 4.362719535827637 | BCE Loss: 1.0088670253753662
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 5.4390363693237305 | KNN Loss: 4.396903991699219 | BCE Loss: 1.0421321392059326
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 5.366359710693359 | KNN Loss: 4.362687587738037 | BCE Loss: 1.0036718845367432
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 5.437038898468018 | KNN Loss: 4.4082255363464355 | BCE Loss

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 5.38702392578125 | KNN Loss: 4.331572532653809 | BCE Loss: 1.0554511547088623
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 5.380053520202637 | KNN Loss: 4.339132785797119 | BCE Loss: 1.0409208536148071
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 5.417142391204834 | KNN Loss: 4.351643085479736 | BCE Loss: 1.0654993057250977
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 5.421656608581543 | KNN Loss: 4.388874053955078 | BCE Loss: 1.032782793045044
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 5.431421279907227 | KNN Loss: 4.381196022033691 | BCE Loss: 1.0502252578735352
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 5.425710678100586 | KNN Loss: 4.402492523193359 | BCE Loss: 1.0232183933258057
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 5.429288387298584 | KNN Loss: 4.405551910400391 | BCE Loss: 1.023736596107483
Epoch 77 / 500 | iteration 0 / 30 | Total Loss: 5.424142837524414 | KNN Loss: 4.382375240325928 | BCE Loss: 1

Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 5.451026916503906 | KNN Loss: 4.4105448722839355 | BCE Loss: 1.0404820442199707
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 5.399460315704346 | KNN Loss: 4.376109600067139 | BCE Loss: 1.023350715637207
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 5.343379020690918 | KNN Loss: 4.310970783233643 | BCE Loss: 1.0324079990386963
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 5.369244575500488 | KNN Loss: 4.3475751876831055 | BCE Loss: 1.021669626235962
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 5.364074230194092 | KNN Loss: 4.318551063537598 | BCE Loss: 1.0455231666564941
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 5.330653190612793 | KNN Loss: 4.3121795654296875 | BCE Loss: 1.0184736251831055
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 5.37527322769165 | KNN Loss: 4.33676290512085 | BCE Loss: 1.0385102033615112
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 5.41039514541626 | KNN Loss: 4.382853984832764 | BCE Loss:

Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 5.375739097595215 | KNN Loss: 4.346031188964844 | BCE Loss: 1.0297081470489502
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 5.378602504730225 | KNN Loss: 4.352100849151611 | BCE Loss: 1.0265016555786133
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 5.336757183074951 | KNN Loss: 4.333893299102783 | BCE Loss: 1.002863883972168
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 5.354245185852051 | KNN Loss: 4.33836555480957 | BCE Loss: 1.0158798694610596
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 5.350261211395264 | KNN Loss: 4.34181022644043 | BCE Loss: 1.0084511041641235
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 5.39853572845459 | KNN Loss: 4.36923885345459 | BCE Loss: 1.029296636581421
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 5.356836318969727 | KNN Loss: 4.332324504852295 | BCE Loss: 1.0245115756988525
Epoch 98 / 500 | iteration 15 / 30 | Total Loss: 5.3523383140563965 | KNN Loss: 4.3183674812316895 | BCE Loss: 1

Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 5.446849822998047 | KNN Loss: 4.411737442016602 | BCE Loss: 1.0351126194000244
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 5.4523444175720215 | KNN Loss: 4.451930046081543 | BCE Loss: 1.0004143714904785
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 5.371026039123535 | KNN Loss: 4.361382961273193 | BCE Loss: 1.009643316268921
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 5.370291233062744 | KNN Loss: 4.333562850952148 | BCE Loss: 1.0367285013198853
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 5.396492004394531 | KNN Loss: 4.350986480712891 | BCE Loss: 1.0455056428909302
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 5.3906168937683105 | KNN Loss: 4.359808921813965 | BCE Loss: 1.0308078527450562
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 5.367439270019531 | KNN Loss: 4.340241432189941 | BCE Loss: 1.0271977186203003
Epoch 109 / 500 | iteration 5 / 30 | Total Loss: 5.378763675689697 | KNN Loss: 4.3424153327941895 |

Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 5.362718105316162 | KNN Loss: 4.340510368347168 | BCE Loss: 1.0222077369689941
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 5.405222415924072 | KNN Loss: 4.36403751373291 | BCE Loss: 1.0411847829818726
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 5.386414527893066 | KNN Loss: 4.36256217956543 | BCE Loss: 1.0238525867462158
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 5.413398265838623 | KNN Loss: 4.3725104331970215 | BCE Loss: 1.0408879518508911
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 5.375942230224609 | KNN Loss: 4.380314350128174 | BCE Loss: 0.9956281185150146
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 5.398951530456543 | KNN Loss: 4.359055995941162 | BCE Loss: 1.0398955345153809
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 5.4364776611328125 | KNN Loss: 4.378838539123535 | BCE Loss: 1.0576390027999878
Epoch 119 / 500 | iteration 25 / 30 | Total Loss: 5.4148054122924805 | KNN Loss: 4.36395788192749 |

Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 5.376575946807861 | KNN Loss: 4.338107109069824 | BCE Loss: 1.038468837738037
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 5.3478922843933105 | KNN Loss: 4.3298821449279785 | BCE Loss: 1.018010139465332
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 5.397371292114258 | KNN Loss: 4.370826244354248 | BCE Loss: 1.0265452861785889
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 5.37315559387207 | KNN Loss: 4.348732948303223 | BCE Loss: 1.0244224071502686
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 5.414874076843262 | KNN Loss: 4.368021488189697 | BCE Loss: 1.046852707862854
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 5.34259557723999 | KNN Loss: 4.333748817443848 | BCE Loss: 1.0088468790054321
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 5.333478927612305 | KNN Loss: 4.347184658050537 | BCE Loss: 0.9862940311431885
Epoch 130 / 500 | iteration 15 / 30 | Total Loss: 5.400967597961426 | KNN Loss: 4.388238430023193 | BC

Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 5.3428754806518555 | KNN Loss: 4.329902172088623 | BCE Loss: 1.0129730701446533
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 5.370046138763428 | KNN Loss: 4.342462062835693 | BCE Loss: 1.0275839567184448
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 5.386135101318359 | KNN Loss: 4.365938186645508 | BCE Loss: 1.0201969146728516
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 5.3488945960998535 | KNN Loss: 4.322516441345215 | BCE Loss: 1.0263781547546387
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 5.366454601287842 | KNN Loss: 4.33605432510376 | BCE Loss: 1.0304001569747925
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 5.373658657073975 | KNN Loss: 4.337007522583008 | BCE Loss: 1.0366511344909668
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 5.374581336975098 | KNN Loss: 4.349555492401123 | BCE Loss: 1.0250260829925537
Epoch 141 / 500 | iteration 5 / 30 | Total Loss: 5.349714756011963 | KNN Loss: 4.31928014755249 | B

Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 5.320639610290527 | KNN Loss: 4.330503940582275 | BCE Loss: 0.990135908126831
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 5.372696876525879 | KNN Loss: 4.336907386779785 | BCE Loss: 1.0357896089553833
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 5.386903762817383 | KNN Loss: 4.376443386077881 | BCE Loss: 1.010460376739502
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 5.39193058013916 | KNN Loss: 4.3518829345703125 | BCE Loss: 1.0400474071502686
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 5.398073196411133 | KNN Loss: 4.340515613555908 | BCE Loss: 1.0575578212738037
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 5.389028072357178 | KNN Loss: 4.367123603820801 | BCE Loss: 1.0219045877456665
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 5.4321208000183105 | KNN Loss: 4.403107166290283 | BCE Loss: 1.0290136337280273
Epoch 151 / 500 | iteration 25 / 30 | Total Loss: 5.408133506774902 | KNN Loss: 4.38257360458374 | B

Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 5.352229118347168 | KNN Loss: 4.324623107910156 | BCE Loss: 1.0276062488555908
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 5.341828346252441 | KNN Loss: 4.319684028625488 | BCE Loss: 1.022144079208374
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 5.400810241699219 | KNN Loss: 4.381416320800781 | BCE Loss: 1.0193936824798584
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 5.377937316894531 | KNN Loss: 4.355761528015137 | BCE Loss: 1.022175669670105
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 5.390307903289795 | KNN Loss: 4.3600592613220215 | BCE Loss: 1.0302486419677734
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 5.381618499755859 | KNN Loss: 4.345142841339111 | BCE Loss: 1.036475658416748
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 5.416598320007324 | KNN Loss: 4.419052600860596 | BCE Loss: 0.9975454807281494
Epoch 162 / 500 | iteration 15 / 30 | Total Loss: 5.362450122833252 | KNN Loss: 4.339688301086426 | B

Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 5.34600830078125 | KNN Loss: 4.309240818023682 | BCE Loss: 1.0367674827575684
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 5.418545722961426 | KNN Loss: 4.388669967651367 | BCE Loss: 1.0298758745193481
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 5.389286518096924 | KNN Loss: 4.359756946563721 | BCE Loss: 1.0295295715332031
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 5.428611755371094 | KNN Loss: 4.395956516265869 | BCE Loss: 1.0326553583145142
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 5.3387908935546875 | KNN Loss: 4.326706886291504 | BCE Loss: 1.0120837688446045
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 5.347718238830566 | KNN Loss: 4.327335357666016 | BCE Loss: 1.0203826427459717
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 5.366621017456055 | KNN Loss: 4.3409552574157715 | BCE Loss: 1.0256658792495728
Epoch 173 / 500 | iteration 5 / 30 | Total Loss: 5.402732849121094 | KNN Loss: 4.374973297119141 | 

Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 5.363295555114746 | KNN Loss: 4.338576316833496 | BCE Loss: 1.024718999862671
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 5.350245475769043 | KNN Loss: 4.35152530670166 | BCE Loss: 0.9987201690673828
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 5.352194786071777 | KNN Loss: 4.339862823486328 | BCE Loss: 1.0123322010040283
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 5.396963119506836 | KNN Loss: 4.372692584991455 | BCE Loss: 1.02427077293396
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 5.404500961303711 | KNN Loss: 4.373344421386719 | BCE Loss: 1.0311565399169922
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 5.339503765106201 | KNN Loss: 4.3305277824401855 | BCE Loss: 1.008975863456726
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 5.3696746826171875 | KNN Loss: 4.339744567871094 | BCE Loss: 1.0299301147460938
Epoch 183 / 500 | iteration 25 / 30 | Total Loss: 5.399806022644043 | KNN Loss: 4.373886585235596 | BC

Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 5.422121047973633 | KNN Loss: 4.388824939727783 | BCE Loss: 1.0332963466644287
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 5.366623401641846 | KNN Loss: 4.3213419914245605 | BCE Loss: 1.0452812910079956
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 5.3987135887146 | KNN Loss: 4.381022930145264 | BCE Loss: 1.0176905393600464
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 5.38063907623291 | KNN Loss: 4.34373140335083 | BCE Loss: 1.036907434463501
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 5.374362468719482 | KNN Loss: 4.344920635223389 | BCE Loss: 1.0294419527053833
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 5.408993721008301 | KNN Loss: 4.383683204650879 | BCE Loss: 1.025310754776001
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 5.3824968338012695 | KNN Loss: 4.349276065826416 | BCE Loss: 1.0332205295562744
Epoch 194 / 500 | iteration 15 / 30 | Total Loss: 5.442544937133789 | KNN Loss: 4.394824504852295 | BCE

Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 5.348715782165527 | KNN Loss: 4.320167064666748 | BCE Loss: 1.0285487174987793
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 5.4134016036987305 | KNN Loss: 4.3543009757995605 | BCE Loss: 1.059100866317749
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 5.359466552734375 | KNN Loss: 4.353501319885254 | BCE Loss: 1.0059654712677002
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 5.384145736694336 | KNN Loss: 4.339105129241943 | BCE Loss: 1.0450408458709717
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 5.352299213409424 | KNN Loss: 4.337447643280029 | BCE Loss: 1.0148515701293945
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 5.327912330627441 | KNN Loss: 4.311300754547119 | BCE Loss: 1.0166113376617432
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 5.382155418395996 | KNN Loss: 4.333001136779785 | BCE Loss: 1.049154281616211
Epoch 205 / 500 | iteration 5 / 30 | Total Loss: 5.397663116455078 | KNN Loss: 4.369076728820801 | B

Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 5.411269664764404 | KNN Loss: 4.3769097328186035 | BCE Loss: 1.0343598127365112
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 5.409047603607178 | KNN Loss: 4.370877742767334 | BCE Loss: 1.0381697416305542
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 5.363781929016113 | KNN Loss: 4.362726211547852 | BCE Loss: 1.0010559558868408
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 5.374637603759766 | KNN Loss: 4.3541364669799805 | BCE Loss: 1.020500898361206
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 5.38748836517334 | KNN Loss: 4.374902725219727 | BCE Loss: 1.0125858783721924
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 5.422930717468262 | KNN Loss: 4.409876823425293 | BCE Loss: 1.0130541324615479
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 5.412209510803223 | KNN Loss: 4.361316680908203 | BCE Loss: 1.0508928298950195
Epoch 215 / 500 | iteration 25 / 30 | Total Loss: 5.416443824768066 | KNN Loss: 4.371646881103516 |

Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 5.378992557525635 | KNN Loss: 4.326544761657715 | BCE Loss: 1.0524479150772095
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 5.3792924880981445 | KNN Loss: 4.349289417266846 | BCE Loss: 1.0300031900405884
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 5.3625359535217285 | KNN Loss: 4.318331718444824 | BCE Loss: 1.0442043542861938
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 5.359462261199951 | KNN Loss: 4.3393378257751465 | BCE Loss: 1.0201244354248047
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 5.340794563293457 | KNN Loss: 4.312992095947266 | BCE Loss: 1.027802586555481
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 5.435834884643555 | KNN Loss: 4.394080638885498 | BCE Loss: 1.0417542457580566
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 5.371147632598877 | KNN Loss: 4.328165531158447 | BCE Loss: 1.0429822206497192
Epoch 226 / 500 | iteration 15 / 30 | Total Loss: 5.376134395599365 | KNN Loss: 4.35329532623291 

Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 5.339364528656006 | KNN Loss: 4.337825775146484 | BCE Loss: 1.001538634300232
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 5.416865348815918 | KNN Loss: 4.333803653717041 | BCE Loss: 1.0830614566802979
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 5.37565803527832 | KNN Loss: 4.3523759841918945 | BCE Loss: 1.0232818126678467
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 5.332500457763672 | KNN Loss: 4.334814071655273 | BCE Loss: 0.9976861476898193
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 5.3318281173706055 | KNN Loss: 4.311900615692139 | BCE Loss: 1.019927740097046
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 5.401526927947998 | KNN Loss: 4.381556510925293 | BCE Loss: 1.019970417022705
Epoch 237 / 500 | iteration 0 / 30 | Total Loss: 5.421878337860107 | KNN Loss: 4.350667953491211 | BCE Loss: 1.071210503578186
Epoch 237 / 500 | iteration 5 / 30 | Total Loss: 5.396313667297363 | KNN Loss: 4.38993501663208 | BCE L

Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 5.348838806152344 | KNN Loss: 4.314168930053711 | BCE Loss: 1.0346699953079224
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 5.391701698303223 | KNN Loss: 4.327085494995117 | BCE Loss: 1.0646159648895264
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 5.3645524978637695 | KNN Loss: 4.341398239135742 | BCE Loss: 1.0231542587280273
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 5.436985015869141 | KNN Loss: 4.397541046142578 | BCE Loss: 1.0394442081451416
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 5.332683086395264 | KNN Loss: 4.321549415588379 | BCE Loss: 1.0111335515975952
Epoch 247 / 500 | iteration 15 / 30 | Total Loss: 5.445439338684082 | KNN Loss: 4.404435157775879 | BCE Loss: 1.0410044193267822
Epoch 247 / 500 | iteration 20 / 30 | Total Loss: 5.358083724975586 | KNN Loss: 4.334383964538574 | BCE Loss: 1.0236995220184326
Epoch 247 / 500 | iteration 25 / 30 | Total Loss: 5.383734226226807 | KNN Loss: 4.3540239334106445

Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 5.371140003204346 | KNN Loss: 4.33725118637085 | BCE Loss: 1.033888816833496
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 5.382960319519043 | KNN Loss: 4.357584476470947 | BCE Loss: 1.0253756046295166
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 5.3643364906311035 | KNN Loss: 4.336507797241211 | BCE Loss: 1.027828574180603
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 5.354681968688965 | KNN Loss: 4.333842754364014 | BCE Loss: 1.020838975906372
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 5.365898132324219 | KNN Loss: 4.344149112701416 | BCE Loss: 1.0217489004135132
Epoch 258 / 500 | iteration 5 / 30 | Total Loss: 5.329577445983887 | KNN Loss: 4.335717678070068 | BCE Loss: 0.9938599467277527
Epoch 258 / 500 | iteration 10 / 30 | Total Loss: 5.380795478820801 | KNN Loss: 4.344150543212891 | BCE Loss: 1.0366451740264893
Epoch 258 / 500 | iteration 15 / 30 | Total Loss: 5.432803153991699 | KNN Loss: 4.409760475158691 | BC

Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 5.369577407836914 | KNN Loss: 4.339098930358887 | BCE Loss: 1.0304783582687378
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 5.427004337310791 | KNN Loss: 4.39735746383667 | BCE Loss: 1.0296469926834106
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 5.390960693359375 | KNN Loss: 4.351238250732422 | BCE Loss: 1.039722204208374
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 5.336404800415039 | KNN Loss: 4.338916301727295 | BCE Loss: 0.9974887371063232
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 5.385190486907959 | KNN Loss: 4.336864471435547 | BCE Loss: 1.0483261346817017
Epoch 268 / 500 | iteration 25 / 30 | Total Loss: 5.360833168029785 | KNN Loss: 4.339285850524902 | BCE Loss: 1.0215471982955933
Epoch 269 / 500 | iteration 0 / 30 | Total Loss: 5.384693622589111 | KNN Loss: 4.354675769805908 | BCE Loss: 1.0300178527832031
Epoch 269 / 500 | iteration 5 / 30 | Total Loss: 5.376913547515869 | KNN Loss: 4.348777770996094 | BCE

Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 5.315619468688965 | KNN Loss: 4.301762580871582 | BCE Loss: 1.0138567686080933
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 5.373989582061768 | KNN Loss: 4.359185218811035 | BCE Loss: 1.0148043632507324
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 5.34412956237793 | KNN Loss: 4.317481994628906 | BCE Loss: 1.0266475677490234
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 5.348244667053223 | KNN Loss: 4.3188042640686035 | BCE Loss: 1.02944016456604
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 5.349989891052246 | KNN Loss: 4.342126846313477 | BCE Loss: 1.007863163948059
Epoch 279 / 500 | iteration 15 / 30 | Total Loss: 5.356930255889893 | KNN Loss: 4.337635040283203 | BCE Loss: 1.0192952156066895
Epoch 279 / 500 | iteration 20 / 30 | Total Loss: 5.4054460525512695 | KNN Loss: 4.350876331329346 | BCE Loss: 1.0545694828033447
Epoch 279 / 500 | iteration 25 / 30 | Total Loss: 5.433729648590088 | KNN Loss: 4.394647121429443 | B

Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 5.368314266204834 | KNN Loss: 4.356307029724121 | BCE Loss: 1.012007236480713
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 5.421264171600342 | KNN Loss: 4.370146751403809 | BCE Loss: 1.0511174201965332
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 5.426236152648926 | KNN Loss: 4.3680925369262695 | BCE Loss: 1.0581433773040771
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 5.383167266845703 | KNN Loss: 4.391207695007324 | BCE Loss: 0.9919596314430237
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 5.386585235595703 | KNN Loss: 4.347933292388916 | BCE Loss: 1.0386518239974976
Epoch 290 / 500 | iteration 5 / 30 | Total Loss: 5.3985466957092285 | KNN Loss: 4.370584487915039 | BCE Loss: 1.0279620885849
Epoch 290 / 500 | iteration 10 / 30 | Total Loss: 5.38459587097168 | KNN Loss: 4.3313727378845215 | BCE Loss: 1.053222894668579
Epoch 290 / 500 | iteration 15 / 30 | Total Loss: 5.326060771942139 | KNN Loss: 4.335299968719482 | BC

Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 5.389990329742432 | KNN Loss: 4.355979919433594 | BCE Loss: 1.034010410308838
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 5.348206520080566 | KNN Loss: 4.313880443572998 | BCE Loss: 1.0343259572982788
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 5.396929740905762 | KNN Loss: 4.364871978759766 | BCE Loss: 1.0320580005645752
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 5.388006210327148 | KNN Loss: 4.359795093536377 | BCE Loss: 1.0282108783721924
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 5.370822906494141 | KNN Loss: 4.342879295349121 | BCE Loss: 1.0279433727264404
Epoch 300 / 500 | iteration 25 / 30 | Total Loss: 5.386289119720459 | KNN Loss: 4.3584465980529785 | BCE Loss: 1.0278425216674805
Epoch 301 / 500 | iteration 0 / 30 | Total Loss: 5.387313365936279 | KNN Loss: 4.361653804779053 | BCE Loss: 1.025659441947937
Epoch 301 / 500 | iteration 5 / 30 | Total Loss: 5.371978759765625 | KNN Loss: 4.334102153778076 | BC

Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 5.410435676574707 | KNN Loss: 4.364666938781738 | BCE Loss: 1.0457689762115479
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 5.3740153312683105 | KNN Loss: 4.344790935516357 | BCE Loss: 1.0292242765426636
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 5.342670440673828 | KNN Loss: 4.315915107727051 | BCE Loss: 1.0267555713653564
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 5.338126182556152 | KNN Loss: 4.32811164855957 | BCE Loss: 1.0100146532058716
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 5.386592864990234 | KNN Loss: 4.353062629699707 | BCE Loss: 1.0335302352905273
Epoch 311 / 500 | iteration 15 / 30 | Total Loss: 5.349861145019531 | KNN Loss: 4.331956386566162 | BCE Loss: 1.0179049968719482
Epoch 311 / 500 | iteration 20 / 30 | Total Loss: 5.402221202850342 | KNN Loss: 4.38935661315918 | BCE Loss: 1.012864589691162
Epoch 311 / 500 | iteration 25 / 30 | Total Loss: 5.369329452514648 | KNN Loss: 4.338865280151367 | B

Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 5.34492301940918 | KNN Loss: 4.324484348297119 | BCE Loss: 1.0204386711120605
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 5.350344657897949 | KNN Loss: 4.334980010986328 | BCE Loss: 1.0153645277023315
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 5.345221042633057 | KNN Loss: 4.309000492095947 | BCE Loss: 1.036220669746399
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 5.355144500732422 | KNN Loss: 4.345033645629883 | BCE Loss: 1.0101110935211182
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 5.343729496002197 | KNN Loss: 4.363171100616455 | BCE Loss: 0.9805582165718079
Epoch 322 / 500 | iteration 5 / 30 | Total Loss: 5.39039421081543 | KNN Loss: 4.363812446594238 | BCE Loss: 1.026581883430481
Epoch 322 / 500 | iteration 10 / 30 | Total Loss: 5.392969131469727 | KNN Loss: 4.34699821472168 | BCE Loss: 1.045971155166626
Epoch 322 / 500 | iteration 15 / 30 | Total Loss: 5.393810272216797 | KNN Loss: 4.371192932128906 | BCE L

Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 5.411312103271484 | KNN Loss: 4.355129241943359 | BCE Loss: 1.056182622909546
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 5.346630096435547 | KNN Loss: 4.33907413482666 | BCE Loss: 1.0075558423995972
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 5.349249362945557 | KNN Loss: 4.351034641265869 | BCE Loss: 0.9982146620750427
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 5.372572898864746 | KNN Loss: 4.352540969848633 | BCE Loss: 1.0200319290161133
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 5.382381439208984 | KNN Loss: 4.331045627593994 | BCE Loss: 1.0513358116149902
Epoch 332 / 500 | iteration 25 / 30 | Total Loss: 5.355215072631836 | KNN Loss: 4.350736141204834 | BCE Loss: 1.0044786930084229
Epoch 333 / 500 | iteration 0 / 30 | Total Loss: 5.344328880310059 | KNN Loss: 4.329905033111572 | BCE Loss: 1.0144238471984863
Epoch 333 / 500 | iteration 5 / 30 | Total Loss: 5.351188659667969 | KNN Loss: 4.339887619018555 | BCE

Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 5.451674461364746 | KNN Loss: 4.418697357177734 | BCE Loss: 1.0329771041870117
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 5.389314651489258 | KNN Loss: 4.3579511642456055 | BCE Loss: 1.031363606452942
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 5.3628411293029785 | KNN Loss: 4.350743293762207 | BCE Loss: 1.0120978355407715
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 5.405381202697754 | KNN Loss: 4.340137958526611 | BCE Loss: 1.0652433633804321
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 5.397469520568848 | KNN Loss: 4.361158847808838 | BCE Loss: 1.0363106727600098
Epoch 343 / 500 | iteration 15 / 30 | Total Loss: 5.3735671043396 | KNN Loss: 4.334228992462158 | BCE Loss: 1.039338231086731
Epoch 343 / 500 | iteration 20 / 30 | Total Loss: 5.379185676574707 | KNN Loss: 4.369607448577881 | BCE Loss: 1.0095784664154053
Epoch 343 / 500 | iteration 25 / 30 | Total Loss: 5.327301979064941 | KNN Loss: 4.30942964553833 | BC

Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 5.3717942237854 | KNN Loss: 4.3435750007629395 | BCE Loss: 1.028219223022461
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 5.387276649475098 | KNN Loss: 4.355654239654541 | BCE Loss: 1.0316225290298462
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 5.374594688415527 | KNN Loss: 4.328200340270996 | BCE Loss: 1.0463945865631104
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 5.344537734985352 | KNN Loss: 4.313028335571289 | BCE Loss: 1.0315093994140625
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 5.407098293304443 | KNN Loss: 4.381347179412842 | BCE Loss: 1.0257512331008911
Epoch 354 / 500 | iteration 5 / 30 | Total Loss: 5.389979362487793 | KNN Loss: 4.335143566131592 | BCE Loss: 1.0548357963562012
Epoch 354 / 500 | iteration 10 / 30 | Total Loss: 5.4219841957092285 | KNN Loss: 4.376189231872559 | BCE Loss: 1.04579496383667
Epoch 354 / 500 | iteration 15 / 30 | Total Loss: 5.379724025726318 | KNN Loss: 4.372725963592529 | BC

Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 5.372774124145508 | KNN Loss: 4.351986885070801 | BCE Loss: 1.0207874774932861
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 5.461465835571289 | KNN Loss: 4.410369873046875 | BCE Loss: 1.0510962009429932
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 5.372284889221191 | KNN Loss: 4.3514723777771 | BCE Loss: 1.020812749862671
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 5.372580528259277 | KNN Loss: 4.3545823097229 | BCE Loss: 1.0179980993270874
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 5.368582248687744 | KNN Loss: 4.327709674835205 | BCE Loss: 1.040872573852539
Epoch 364 / 500 | iteration 25 / 30 | Total Loss: 5.31877326965332 | KNN Loss: 4.322635173797607 | BCE Loss: 0.996138334274292
Epoch 365 / 500 | iteration 0 / 30 | Total Loss: 5.4661865234375 | KNN Loss: 4.416163921356201 | BCE Loss: 1.0500223636627197
Epoch 365 / 500 | iteration 5 / 30 | Total Loss: 5.442360877990723 | KNN Loss: 4.391425132751465 | BCE Loss: 1

Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 5.353998184204102 | KNN Loss: 4.3004374504089355 | BCE Loss: 1.053560495376587
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 5.361170768737793 | KNN Loss: 4.346921920776367 | BCE Loss: 1.0142486095428467
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 5.425289154052734 | KNN Loss: 4.3696112632751465 | BCE Loss: 1.055677890777588
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 5.354892730712891 | KNN Loss: 4.347204685211182 | BCE Loss: 1.0076879262924194
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 5.319975852966309 | KNN Loss: 4.3120598793029785 | BCE Loss: 1.0079158544540405
Epoch 375 / 500 | iteration 15 / 30 | Total Loss: 5.350371837615967 | KNN Loss: 4.323099136352539 | BCE Loss: 1.0272728204727173
Epoch 375 / 500 | iteration 20 / 30 | Total Loss: 5.418903827667236 | KNN Loss: 4.399919509887695 | BCE Loss: 1.018984317779541
Epoch 375 / 500 | iteration 25 / 30 | Total Loss: 5.390944004058838 | KNN Loss: 4.364144325256348 |

Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 5.3813276290893555 | KNN Loss: 4.34381103515625 | BCE Loss: 1.0375168323516846
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 5.357524871826172 | KNN Loss: 4.324836254119873 | BCE Loss: 1.032688856124878
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 5.404170989990234 | KNN Loss: 4.361038684844971 | BCE Loss: 1.0431320667266846
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 5.346055030822754 | KNN Loss: 4.334372520446777 | BCE Loss: 1.0116827487945557
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 5.354611396789551 | KNN Loss: 4.330931186676025 | BCE Loss: 1.0236800909042358
Epoch 386 / 500 | iteration 5 / 30 | Total Loss: 5.421602249145508 | KNN Loss: 4.379310607910156 | BCE Loss: 1.0422918796539307
Epoch 386 / 500 | iteration 10 / 30 | Total Loss: 5.409458160400391 | KNN Loss: 4.35283088684082 | BCE Loss: 1.0566273927688599
Epoch 386 / 500 | iteration 15 / 30 | Total Loss: 5.345491886138916 | KNN Loss: 4.327455520629883 | B

Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 5.479053020477295 | KNN Loss: 4.424442768096924 | BCE Loss: 1.054610252380371
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 5.4039812088012695 | KNN Loss: 4.381107330322266 | BCE Loss: 1.022873878479004
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 5.393434524536133 | KNN Loss: 4.372860431671143 | BCE Loss: 1.0205740928649902
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 5.357295513153076 | KNN Loss: 4.340730667114258 | BCE Loss: 1.016564965248108
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 5.375602722167969 | KNN Loss: 4.356148719787598 | BCE Loss: 1.019453763961792
Epoch 396 / 500 | iteration 25 / 30 | Total Loss: 5.38551139831543 | KNN Loss: 4.362536430358887 | BCE Loss: 1.022975206375122
Epoch 397 / 500 | iteration 0 / 30 | Total Loss: 5.35002326965332 | KNN Loss: 4.322474956512451 | BCE Loss: 1.0275481939315796
Epoch 397 / 500 | iteration 5 / 30 | Total Loss: 5.322835445404053 | KNN Loss: 4.341619491577148 | BCE Los

Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 5.38449239730835 | KNN Loss: 4.349279880523682 | BCE Loss: 1.035212516784668
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 5.40925407409668 | KNN Loss: 4.359587669372559 | BCE Loss: 1.049666166305542
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 5.45833158493042 | KNN Loss: 4.370003700256348 | BCE Loss: 1.0883278846740723
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 5.374407768249512 | KNN Loss: 4.3740925788879395 | BCE Loss: 1.0003154277801514
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 5.381504058837891 | KNN Loss: 4.365171432495117 | BCE Loss: 1.0163328647613525
Epoch 407 / 500 | iteration 15 / 30 | Total Loss: 5.40058708190918 | KNN Loss: 4.402865886688232 | BCE Loss: 0.9977211952209473
Epoch 407 / 500 | iteration 20 / 30 | Total Loss: 5.348052978515625 | KNN Loss: 4.322151184082031 | BCE Loss: 1.0259015560150146
Epoch 407 / 500 | iteration 25 / 30 | Total Loss: 5.398111343383789 | KNN Loss: 4.368443489074707 | BCE 

Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 5.352338790893555 | KNN Loss: 4.3151655197143555 | BCE Loss: 1.0371730327606201
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 5.377419471740723 | KNN Loss: 4.3636345863342285 | BCE Loss: 1.0137851238250732
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 5.376640319824219 | KNN Loss: 4.3538713455200195 | BCE Loss: 1.0227687358856201
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 5.358232498168945 | KNN Loss: 4.3286237716674805 | BCE Loss: 1.0296088457107544
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 5.374948501586914 | KNN Loss: 4.335185527801514 | BCE Loss: 1.0397629737854004
Epoch 418 / 500 | iteration 5 / 30 | Total Loss: 5.433009624481201 | KNN Loss: 4.401450157165527 | BCE Loss: 1.0315595865249634
Epoch 418 / 500 | iteration 10 / 30 | Total Loss: 5.339648723602295 | KNN Loss: 4.330431938171387 | BCE Loss: 1.0092166662216187
Epoch 418 / 500 | iteration 15 / 30 | Total Loss: 5.378896713256836 | KNN Loss: 4.3469910621643

Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 5.421289920806885 | KNN Loss: 4.422662734985352 | BCE Loss: 0.9986270070075989
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 5.395833969116211 | KNN Loss: 4.362049102783203 | BCE Loss: 1.0337848663330078
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 5.390296936035156 | KNN Loss: 4.336329936981201 | BCE Loss: 1.0539672374725342
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 5.3510026931762695 | KNN Loss: 4.337388038635254 | BCE Loss: 1.0136144161224365
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 5.383535385131836 | KNN Loss: 4.361435413360596 | BCE Loss: 1.0220997333526611
Epoch 428 / 500 | iteration 25 / 30 | Total Loss: 5.352415561676025 | KNN Loss: 4.311427116394043 | BCE Loss: 1.040988564491272
Epoch 429 / 500 | iteration 0 / 30 | Total Loss: 5.419510364532471 | KNN Loss: 4.354089736938477 | BCE Loss: 1.0654206275939941
Epoch 429 / 500 | iteration 5 / 30 | Total Loss: 5.428005695343018 | KNN Loss: 4.3959856033325195 | 

Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 5.4169511795043945 | KNN Loss: 4.378000259399414 | BCE Loss: 1.0389506816864014
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 5.374704360961914 | KNN Loss: 4.360998630523682 | BCE Loss: 1.0137059688568115
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 5.400293350219727 | KNN Loss: 4.336090564727783 | BCE Loss: 1.0642027854919434
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 5.392294883728027 | KNN Loss: 4.387814998626709 | BCE Loss: 1.0044798851013184
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 5.423892498016357 | KNN Loss: 4.417397499084473 | BCE Loss: 1.0064949989318848
Epoch 439 / 500 | iteration 15 / 30 | Total Loss: 5.377165794372559 | KNN Loss: 4.338637828826904 | BCE Loss: 1.0385277271270752
Epoch 439 / 500 | iteration 20 / 30 | Total Loss: 5.317832946777344 | KNN Loss: 4.313119411468506 | BCE Loss: 1.004713773727417
Epoch 439 / 500 | iteration 25 / 30 | Total Loss: 5.35314416885376 | KNN Loss: 4.353750705718994 | 

Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 5.423148155212402 | KNN Loss: 4.35068416595459 | BCE Loss: 1.0724639892578125
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 5.461158275604248 | KNN Loss: 4.419189453125 | BCE Loss: 1.0419687032699585
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 5.389246463775635 | KNN Loss: 4.370909690856934 | BCE Loss: 1.0183367729187012
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 5.395580768585205 | KNN Loss: 4.372767448425293 | BCE Loss: 1.0228132009506226
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 5.377960205078125 | KNN Loss: 4.34993839263916 | BCE Loss: 1.0280218124389648
Epoch 450 / 500 | iteration 5 / 30 | Total Loss: 5.393111228942871 | KNN Loss: 4.352309703826904 | BCE Loss: 1.0408015251159668
Epoch 450 / 500 | iteration 10 / 30 | Total Loss: 5.331720352172852 | KNN Loss: 4.331087589263916 | BCE Loss: 1.0006327629089355
Epoch 450 / 500 | iteration 15 / 30 | Total Loss: 5.380775451660156 | KNN Loss: 4.344493389129639 | BCE 

Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 5.386999607086182 | KNN Loss: 4.380929946899414 | BCE Loss: 1.0060697793960571
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 5.334714889526367 | KNN Loss: 4.324290752410889 | BCE Loss: 1.0104238986968994
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 5.385002613067627 | KNN Loss: 4.346275806427002 | BCE Loss: 1.0387266874313354
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 5.4249653816223145 | KNN Loss: 4.3769330978393555 | BCE Loss: 1.0480321645736694
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 5.290914535522461 | KNN Loss: 4.295884132385254 | BCE Loss: 0.9950305819511414
Epoch 460 / 500 | iteration 25 / 30 | Total Loss: 5.386981010437012 | KNN Loss: 4.3875532150268555 | BCE Loss: 0.9994277954101562
Epoch 461 / 500 | iteration 0 / 30 | Total Loss: 5.370757579803467 | KNN Loss: 4.334097385406494 | BCE Loss: 1.0366601943969727
Epoch 461 / 500 | iteration 5 / 30 | Total Loss: 5.3638505935668945 | KNN Loss: 4.335745811462402

Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 5.423464298248291 | KNN Loss: 4.391591548919678 | BCE Loss: 1.0318726301193237
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 5.368599891662598 | KNN Loss: 4.357300281524658 | BCE Loss: 1.01129949092865
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 5.375969886779785 | KNN Loss: 4.3642730712890625 | BCE Loss: 1.0116968154907227
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 5.359688758850098 | KNN Loss: 4.33807897567749 | BCE Loss: 1.0216097831726074
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 5.340062141418457 | KNN Loss: 4.32853889465332 | BCE Loss: 1.0115234851837158
Epoch 471 / 500 | iteration 15 / 30 | Total Loss: 5.390985488891602 | KNN Loss: 4.370636463165283 | BCE Loss: 1.0203490257263184
Epoch 471 / 500 | iteration 20 / 30 | Total Loss: 5.374786853790283 | KNN Loss: 4.332377910614014 | BCE Loss: 1.0424089431762695
Epoch 471 / 500 | iteration 25 / 30 | Total Loss: 5.377124309539795 | KNN Loss: 4.357393741607666 | BC

Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 5.449256420135498 | KNN Loss: 4.416590213775635 | BCE Loss: 1.0326663255691528
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 5.333684921264648 | KNN Loss: 4.340879917144775 | BCE Loss: 0.9928052425384521
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 5.391486167907715 | KNN Loss: 4.374935150146484 | BCE Loss: 1.0165507793426514
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 5.351450443267822 | KNN Loss: 4.331624984741211 | BCE Loss: 1.0198253393173218
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 5.394283294677734 | KNN Loss: 4.3771209716796875 | BCE Loss: 1.0171622037887573
Epoch 482 / 500 | iteration 5 / 30 | Total Loss: 5.38389778137207 | KNN Loss: 4.361886024475098 | BCE Loss: 1.0220119953155518
Epoch 482 / 500 | iteration 10 / 30 | Total Loss: 5.368283271789551 | KNN Loss: 4.354225158691406 | BCE Loss: 1.0140578746795654
Epoch 482 / 500 | iteration 15 / 30 | Total Loss: 5.428168296813965 | KNN Loss: 4.3803253173828125 

Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 5.382512092590332 | KNN Loss: 4.365140438079834 | BCE Loss: 1.0173718929290771
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 5.305492401123047 | KNN Loss: 4.301912784576416 | BCE Loss: 1.0035796165466309
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 5.386650085449219 | KNN Loss: 4.372052192687988 | BCE Loss: 1.0145981311798096
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 5.3594818115234375 | KNN Loss: 4.342615127563477 | BCE Loss: 1.016866683959961
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 5.398952960968018 | KNN Loss: 4.367936611175537 | BCE Loss: 1.0310163497924805
Epoch 492 / 500 | iteration 25 / 30 | Total Loss: 5.428266525268555 | KNN Loss: 4.396023750305176 | BCE Loss: 1.032242774963379
Epoch 493 / 500 | iteration 0 / 30 | Total Loss: 5.367700099945068 | KNN Loss: 4.355287075042725 | BCE Loss: 1.0124129056930542
Epoch 493 / 500 | iteration 5 / 30 | Total Loss: 5.358769416809082 | KNN Loss: 4.3501362800598145 | B

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.7441,  2.6732,  3.0950,  4.1396,  4.0467,  0.6395,  1.9095,  2.7127,
          1.7671,  1.6543,  2.7099,  1.4446,  1.1443,  2.1694,  1.2966,  2.0321,
          2.4437,  1.9823,  1.9205,  2.8463,  1.7464,  3.4970,  1.5933,  3.1552,
          2.5494,  1.8534,  1.7290,  1.4775,  1.4493,  0.4399,  0.1224,  1.0770,
          0.2949,  1.1134,  1.3717,  1.1682,  1.3365,  3.8645,  0.7598,  1.3859,
          1.3416, -0.6291, -0.3325,  2.8607,  2.1404,  0.6125, -0.1257, -0.0143,
          1.4580,  2.1000,  2.1628,  0.0753,  1.7593,  0.5025, -0.3609,  1.1956,
          1.8183,  1.4060,  1.6863,  1.7672,  0.6736,  0.8092,  0.1594,  1.3707,
          1.1880,  1.9930, -1.7575,  0.3536,  1.8959,  1.8832,  3.0835,  0.2219,
          1.3353,  1.6725,  1.6955,  1.6607,  0.1663,  0.6747,  0.0909,  1.3627,
         -0.0319,  0.7510,  2.2128, -0.3403,  0.3755, -1.0097, -2.2813, -0.5493,
          0.6126, -1.6691,  0.3744, -0.0092, -0.4956, -0.7589,  0.5502,  1.3111,
         -0.6540, -1.2032,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].to('cpu') for d in dataset]

In [11]:
model = model.eval().to('cpu')
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 92.22it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.1, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [17]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [18]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [19]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [20]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [21]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [22]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [23]:
tensor_dataset = torch.stack(dataset_)

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 025 | Total loss: 9.590 | Reg loss: 0.007 | Tree loss: 9.590 | Accuracy: 0.000000 | 0.097 sec/iter
Epoch: 00 | Batch: 001 / 025 | Total loss: 9.587 | Reg loss: 0.007 | Tree loss: 9.587 | Accuracy: 0.000000 | 0.086 sec/iter
Epoch: 00 | Batch: 002 / 025 | Total loss: 9.571 | Reg loss: 0.007 | Tree loss: 9.571 | Accuracy: 0.000000 | 0.079 sec/iter
Epoch: 00 | Batch: 003 / 025 | Total loss: 9.558 | Reg loss: 0.007 | Tree loss: 9.558 | Accuracy: 0.000000 | 0.075 sec/iter
Epoch: 00 | Batch: 004 / 025 | Total loss: 9.561 | Reg loss: 0.006 | Tree loss: 9.561 | Accuracy: 0.000000 | 0.072 sec/iter
Epoch: 00 | Batch: 005 / 025 | Total loss: 9.550 | Reg loss: 0.006 | Tree loss: 9.550 | Accuracy: 0.000000 | 0.07 sec/iter
Epoch: 00 | Batch: 006 / 025 | Total loss: 9.535 | Reg loss: 0.006 | Tree loss: 9.535 | Accuracy: 0.000000 | 0.068 sec/iter
Epoch: 00 | Batch: 007 / 025 | Total loss: 9

Epoch: 02 | Batch: 015 / 025 | Total loss: 9.186 | Reg loss: 0.008 | Tree loss: 9.186 | Accuracy: 0.101562 | 0.069 sec/iter
Epoch: 02 | Batch: 016 / 025 | Total loss: 9.175 | Reg loss: 0.008 | Tree loss: 9.175 | Accuracy: 0.111328 | 0.069 sec/iter
Epoch: 02 | Batch: 017 / 025 | Total loss: 9.168 | Reg loss: 0.008 | Tree loss: 9.168 | Accuracy: 0.093750 | 0.069 sec/iter
Epoch: 02 | Batch: 018 / 025 | Total loss: 9.159 | Reg loss: 0.009 | Tree loss: 9.159 | Accuracy: 0.109375 | 0.069 sec/iter
Epoch: 02 | Batch: 019 / 025 | Total loss: 9.142 | Reg loss: 0.009 | Tree loss: 9.142 | Accuracy: 0.109375 | 0.069 sec/iter
Epoch: 02 | Batch: 020 / 025 | Total loss: 9.141 | Reg loss: 0.009 | Tree loss: 9.141 | Accuracy: 0.099609 | 0.069 sec/iter
Epoch: 02 | Batch: 021 / 025 | Total loss: 9.129 | Reg loss: 0.010 | Tree loss: 9.129 | Accuracy: 0.089844 | 0.069 sec/iter
Epoch: 02 | Batch: 022 / 025 | Total loss: 9.128 | Reg loss: 0.010 | Tree loss: 9.128 | Accuracy: 0.115234 | 0.069 sec/iter
Epoch: 0

Epoch: 05 | Batch: 003 / 025 | Total loss: 8.896 | Reg loss: 0.010 | Tree loss: 8.896 | Accuracy: 0.117188 | 0.07 sec/iter
Epoch: 05 | Batch: 004 / 025 | Total loss: 8.892 | Reg loss: 0.010 | Tree loss: 8.892 | Accuracy: 0.078125 | 0.07 sec/iter
Epoch: 05 | Batch: 005 / 025 | Total loss: 8.877 | Reg loss: 0.010 | Tree loss: 8.877 | Accuracy: 0.111328 | 0.07 sec/iter
Epoch: 05 | Batch: 006 / 025 | Total loss: 8.872 | Reg loss: 0.010 | Tree loss: 8.872 | Accuracy: 0.111328 | 0.07 sec/iter
Epoch: 05 | Batch: 007 / 025 | Total loss: 8.862 | Reg loss: 0.010 | Tree loss: 8.862 | Accuracy: 0.105469 | 0.07 sec/iter
Epoch: 05 | Batch: 008 / 025 | Total loss: 8.851 | Reg loss: 0.011 | Tree loss: 8.851 | Accuracy: 0.130859 | 0.07 sec/iter
Epoch: 05 | Batch: 009 / 025 | Total loss: 8.839 | Reg loss: 0.011 | Tree loss: 8.839 | Accuracy: 0.107422 | 0.07 sec/iter
Epoch: 05 | Batch: 010 / 025 | Total loss: 8.832 | Reg loss: 0.011 | Tree loss: 8.832 | Accuracy: 0.101562 | 0.07 sec/iter
Epoch: 05 | Batc

Epoch: 07 | Batch: 017 / 025 | Total loss: 8.425 | Reg loss: 0.017 | Tree loss: 8.425 | Accuracy: 0.101562 | 0.069 sec/iter
Epoch: 07 | Batch: 018 / 025 | Total loss: 8.410 | Reg loss: 0.017 | Tree loss: 8.410 | Accuracy: 0.119141 | 0.069 sec/iter
Epoch: 07 | Batch: 019 / 025 | Total loss: 8.395 | Reg loss: 0.017 | Tree loss: 8.395 | Accuracy: 0.115234 | 0.069 sec/iter
Epoch: 07 | Batch: 020 / 025 | Total loss: 8.386 | Reg loss: 0.018 | Tree loss: 8.386 | Accuracy: 0.093750 | 0.069 sec/iter
Epoch: 07 | Batch: 021 / 025 | Total loss: 8.383 | Reg loss: 0.018 | Tree loss: 8.383 | Accuracy: 0.076172 | 0.069 sec/iter
Epoch: 07 | Batch: 022 / 025 | Total loss: 8.365 | Reg loss: 0.018 | Tree loss: 8.365 | Accuracy: 0.099609 | 0.069 sec/iter
Epoch: 07 | Batch: 023 / 025 | Total loss: 8.344 | Reg loss: 0.019 | Tree loss: 8.344 | Accuracy: 0.107422 | 0.069 sec/iter
Epoch: 07 | Batch: 024 / 025 | Total loss: 8.332 | Reg loss: 0.019 | Tree loss: 8.332 | Accuracy: 0.103125 | 0.069 sec/iter
Average 

Epoch: 10 | Batch: 007 / 025 | Total loss: 8.009 | Reg loss: 0.020 | Tree loss: 8.009 | Accuracy: 0.121094 | 0.068 sec/iter
Epoch: 10 | Batch: 008 / 025 | Total loss: 8.001 | Reg loss: 0.020 | Tree loss: 8.001 | Accuracy: 0.109375 | 0.068 sec/iter
Epoch: 10 | Batch: 009 / 025 | Total loss: 7.980 | Reg loss: 0.020 | Tree loss: 7.980 | Accuracy: 0.123047 | 0.068 sec/iter
Epoch: 10 | Batch: 010 / 025 | Total loss: 7.955 | Reg loss: 0.020 | Tree loss: 7.955 | Accuracy: 0.113281 | 0.068 sec/iter
Epoch: 10 | Batch: 011 / 025 | Total loss: 7.922 | Reg loss: 0.021 | Tree loss: 7.922 | Accuracy: 0.101562 | 0.068 sec/iter
Epoch: 10 | Batch: 012 / 025 | Total loss: 7.910 | Reg loss: 0.021 | Tree loss: 7.910 | Accuracy: 0.121094 | 0.068 sec/iter
Epoch: 10 | Batch: 013 / 025 | Total loss: 7.901 | Reg loss: 0.021 | Tree loss: 7.901 | Accuracy: 0.099609 | 0.068 sec/iter
Epoch: 10 | Batch: 014 / 025 | Total loss: 7.861 | Reg loss: 0.021 | Tree loss: 7.861 | Accuracy: 0.101562 | 0.068 sec/iter
Epoch: 1

Epoch: 12 | Batch: 023 / 025 | Total loss: 7.269 | Reg loss: 0.027 | Tree loss: 7.269 | Accuracy: 0.113281 | 0.069 sec/iter
Epoch: 12 | Batch: 024 / 025 | Total loss: 7.244 | Reg loss: 0.027 | Tree loss: 7.244 | Accuracy: 0.140625 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 13 | Batch: 000 / 025 | Total loss: 7.575 | Reg loss: 0.023 | Tree loss: 7.575 | Accuracy: 0.136719 | 0.069 sec/iter
Epoch: 13 | Batch: 001 / 025 | Total loss: 7.567 | Reg loss: 0.023 | Tree loss: 7.567 | Accuracy: 0.138672 | 0.069 sec/iter
Epoch: 13 | Batch: 002 / 025 | Total loss: 7.520 | Reg loss: 0.023 | Tree loss: 7.520 | Accuracy: 0.126953 | 0.069 sec/iter
Epoch: 13 | Batch: 003 / 025 | Total loss: 7.503 | Reg loss: 0.024 | Tree loss: 7.503 | Accuracy: 0.123047 | 0.069 sec/iter
Epoch: 13 | Batch: 004 / 025 | Total loss: 7.511 | Reg loss: 0.024 | Tree los

Epoch: 15 | Batch: 010 / 025 | Total loss: 6.974 | Reg loss: 0.027 | Tree loss: 6.974 | Accuracy: 0.113281 | 0.069 sec/iter
Epoch: 15 | Batch: 011 / 025 | Total loss: 6.956 | Reg loss: 0.027 | Tree loss: 6.956 | Accuracy: 0.134766 | 0.069 sec/iter
Epoch: 15 | Batch: 012 / 025 | Total loss: 6.920 | Reg loss: 0.027 | Tree loss: 6.920 | Accuracy: 0.115234 | 0.069 sec/iter
Epoch: 15 | Batch: 013 / 025 | Total loss: 6.948 | Reg loss: 0.027 | Tree loss: 6.948 | Accuracy: 0.083984 | 0.069 sec/iter
Epoch: 15 | Batch: 014 / 025 | Total loss: 6.914 | Reg loss: 0.027 | Tree loss: 6.914 | Accuracy: 0.132812 | 0.069 sec/iter
Epoch: 15 | Batch: 015 / 025 | Total loss: 6.880 | Reg loss: 0.027 | Tree loss: 6.880 | Accuracy: 0.107422 | 0.069 sec/iter
Epoch: 15 | Batch: 016 / 025 | Total loss: 6.859 | Reg loss: 0.028 | Tree loss: 6.859 | Accuracy: 0.109375 | 0.069 sec/iter
Epoch: 15 | Batch: 017 / 025 | Total loss: 6.821 | Reg loss: 0.028 | Tree loss: 6.821 | Accuracy: 0.121094 | 0.069 sec/iter
Epoch: 1

Epoch: 18 | Batch: 000 / 025 | Total loss: 6.668 | Reg loss: 0.028 | Tree loss: 6.668 | Accuracy: 0.138672 | 0.069 sec/iter
Epoch: 18 | Batch: 001 / 025 | Total loss: 6.595 | Reg loss: 0.028 | Tree loss: 6.595 | Accuracy: 0.142578 | 0.069 sec/iter
Epoch: 18 | Batch: 002 / 025 | Total loss: 6.622 | Reg loss: 0.028 | Tree loss: 6.622 | Accuracy: 0.134766 | 0.069 sec/iter
Epoch: 18 | Batch: 003 / 025 | Total loss: 6.564 | Reg loss: 0.028 | Tree loss: 6.564 | Accuracy: 0.150391 | 0.069 sec/iter
Epoch: 18 | Batch: 004 / 025 | Total loss: 6.543 | Reg loss: 0.028 | Tree loss: 6.543 | Accuracy: 0.134766 | 0.069 sec/iter
Epoch: 18 | Batch: 005 / 025 | Total loss: 6.567 | Reg loss: 0.028 | Tree loss: 6.567 | Accuracy: 0.126953 | 0.069 sec/iter
Epoch: 18 | Batch: 006 / 025 | Total loss: 6.522 | Reg loss: 0.028 | Tree loss: 6.522 | Accuracy: 0.132812 | 0.069 sec/iter
Epoch: 18 | Batch: 007 / 025 | Total loss: 6.510 | Reg loss: 0.028 | Tree loss: 6.510 | Accuracy: 0.107422 | 0.069 sec/iter
Epoch: 1

Epoch: 20 | Batch: 015 / 025 | Total loss: 5.977 | Reg loss: 0.030 | Tree loss: 5.977 | Accuracy: 0.134766 | 0.069 sec/iter
Epoch: 20 | Batch: 016 / 025 | Total loss: 5.974 | Reg loss: 0.030 | Tree loss: 5.974 | Accuracy: 0.126953 | 0.069 sec/iter
Epoch: 20 | Batch: 017 / 025 | Total loss: 5.944 | Reg loss: 0.030 | Tree loss: 5.944 | Accuracy: 0.134766 | 0.069 sec/iter
Epoch: 20 | Batch: 018 / 025 | Total loss: 5.894 | Reg loss: 0.030 | Tree loss: 5.894 | Accuracy: 0.121094 | 0.069 sec/iter
Epoch: 20 | Batch: 019 / 025 | Total loss: 5.861 | Reg loss: 0.031 | Tree loss: 5.861 | Accuracy: 0.132812 | 0.069 sec/iter
Epoch: 20 | Batch: 020 / 025 | Total loss: 5.819 | Reg loss: 0.031 | Tree loss: 5.819 | Accuracy: 0.162109 | 0.069 sec/iter
Epoch: 20 | Batch: 021 / 025 | Total loss: 5.849 | Reg loss: 0.031 | Tree loss: 5.849 | Accuracy: 0.111328 | 0.069 sec/iter
Epoch: 20 | Batch: 022 / 025 | Total loss: 5.837 | Reg loss: 0.031 | Tree loss: 5.837 | Accuracy: 0.111328 | 0.069 sec/iter
Epoch: 2

Epoch: 23 | Batch: 005 / 025 | Total loss: 5.599 | Reg loss: 0.030 | Tree loss: 5.599 | Accuracy: 0.130859 | 0.07 sec/iter
Epoch: 23 | Batch: 006 / 025 | Total loss: 5.559 | Reg loss: 0.030 | Tree loss: 5.559 | Accuracy: 0.132812 | 0.07 sec/iter
Epoch: 23 | Batch: 007 / 025 | Total loss: 5.530 | Reg loss: 0.030 | Tree loss: 5.530 | Accuracy: 0.136719 | 0.07 sec/iter
Epoch: 23 | Batch: 008 / 025 | Total loss: 5.492 | Reg loss: 0.030 | Tree loss: 5.492 | Accuracy: 0.132812 | 0.07 sec/iter
Epoch: 23 | Batch: 009 / 025 | Total loss: 5.471 | Reg loss: 0.030 | Tree loss: 5.471 | Accuracy: 0.121094 | 0.069 sec/iter
Epoch: 23 | Batch: 010 / 025 | Total loss: 5.484 | Reg loss: 0.031 | Tree loss: 5.484 | Accuracy: 0.123047 | 0.069 sec/iter
Epoch: 23 | Batch: 011 / 025 | Total loss: 5.460 | Reg loss: 0.031 | Tree loss: 5.460 | Accuracy: 0.144531 | 0.069 sec/iter
Epoch: 23 | Batch: 012 / 025 | Total loss: 5.398 | Reg loss: 0.031 | Tree loss: 5.398 | Accuracy: 0.115234 | 0.069 sec/iter
Epoch: 23 | 

Epoch: 25 | Batch: 021 / 025 | Total loss: 4.964 | Reg loss: 0.032 | Tree loss: 4.964 | Accuracy: 0.132812 | 0.069 sec/iter
Epoch: 25 | Batch: 022 / 025 | Total loss: 4.905 | Reg loss: 0.032 | Tree loss: 4.905 | Accuracy: 0.121094 | 0.069 sec/iter
Epoch: 25 | Batch: 023 / 025 | Total loss: 4.858 | Reg loss: 0.032 | Tree loss: 4.858 | Accuracy: 0.136719 | 0.069 sec/iter
Epoch: 25 | Batch: 024 / 025 | Total loss: 4.848 | Reg loss: 0.032 | Tree loss: 4.848 | Accuracy: 0.115625 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 26 | Batch: 000 / 025 | Total loss: 5.117 | Reg loss: 0.031 | Tree loss: 5.117 | Accuracy: 0.148438 | 0.069 sec/iter
Epoch: 26 | Batch: 001 / 025 | Total loss: 5.094 | Reg loss: 0.031 | Tree loss: 5.094 | Accuracy: 0.142578 | 0.069 sec/iter
Epoch: 26 | Batch: 002 / 025 | Total loss: 5.134 | Reg loss: 0.031 | Tree los

Epoch: 28 | Batch: 009 / 025 | Total loss: 4.537 | Reg loss: 0.030 | Tree loss: 4.537 | Accuracy: 0.134766 | 0.07 sec/iter
Epoch: 28 | Batch: 010 / 025 | Total loss: 4.600 | Reg loss: 0.030 | Tree loss: 4.600 | Accuracy: 0.113281 | 0.07 sec/iter
Epoch: 28 | Batch: 011 / 025 | Total loss: 4.535 | Reg loss: 0.030 | Tree loss: 4.535 | Accuracy: 0.117188 | 0.07 sec/iter
Epoch: 28 | Batch: 012 / 025 | Total loss: 4.623 | Reg loss: 0.030 | Tree loss: 4.623 | Accuracy: 0.134766 | 0.07 sec/iter
Epoch: 28 | Batch: 013 / 025 | Total loss: 4.526 | Reg loss: 0.030 | Tree loss: 4.526 | Accuracy: 0.128906 | 0.07 sec/iter
Epoch: 28 | Batch: 014 / 025 | Total loss: 4.562 | Reg loss: 0.030 | Tree loss: 4.562 | Accuracy: 0.089844 | 0.07 sec/iter
Epoch: 28 | Batch: 015 / 025 | Total loss: 4.523 | Reg loss: 0.030 | Tree loss: 4.523 | Accuracy: 0.123047 | 0.07 sec/iter
Epoch: 28 | Batch: 016 / 025 | Total loss: 4.590 | Reg loss: 0.030 | Tree loss: 4.590 | Accuracy: 0.119141 | 0.07 sec/iter
Epoch: 28 | Batc

Epoch: 31 | Batch: 000 / 025 | Total loss: 4.310 | Reg loss: 0.028 | Tree loss: 4.310 | Accuracy: 0.138672 | 0.07 sec/iter
Epoch: 31 | Batch: 001 / 025 | Total loss: 4.349 | Reg loss: 0.028 | Tree loss: 4.349 | Accuracy: 0.095703 | 0.07 sec/iter
Epoch: 31 | Batch: 002 / 025 | Total loss: 4.324 | Reg loss: 0.028 | Tree loss: 4.324 | Accuracy: 0.125000 | 0.07 sec/iter
Epoch: 31 | Batch: 003 / 025 | Total loss: 4.305 | Reg loss: 0.028 | Tree loss: 4.305 | Accuracy: 0.121094 | 0.07 sec/iter
Epoch: 31 | Batch: 004 / 025 | Total loss: 4.312 | Reg loss: 0.028 | Tree loss: 4.312 | Accuracy: 0.138672 | 0.07 sec/iter
Epoch: 31 | Batch: 005 / 025 | Total loss: 4.243 | Reg loss: 0.028 | Tree loss: 4.243 | Accuracy: 0.119141 | 0.07 sec/iter
Epoch: 31 | Batch: 006 / 025 | Total loss: 4.216 | Reg loss: 0.028 | Tree loss: 4.216 | Accuracy: 0.148438 | 0.07 sec/iter
Epoch: 31 | Batch: 007 / 025 | Total loss: 4.206 | Reg loss: 0.028 | Tree loss: 4.206 | Accuracy: 0.132812 | 0.07 sec/iter
Epoch: 31 | Batc

Epoch: 33 | Batch: 016 / 025 | Total loss: 3.790 | Reg loss: 0.027 | Tree loss: 3.790 | Accuracy: 0.140625 | 0.071 sec/iter
Epoch: 33 | Batch: 017 / 025 | Total loss: 3.770 | Reg loss: 0.027 | Tree loss: 3.770 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 33 | Batch: 018 / 025 | Total loss: 3.777 | Reg loss: 0.027 | Tree loss: 3.777 | Accuracy: 0.140625 | 0.071 sec/iter
Epoch: 33 | Batch: 019 / 025 | Total loss: 3.777 | Reg loss: 0.028 | Tree loss: 3.777 | Accuracy: 0.107422 | 0.071 sec/iter
Epoch: 33 | Batch: 020 / 025 | Total loss: 3.612 | Reg loss: 0.028 | Tree loss: 3.612 | Accuracy: 0.171875 | 0.071 sec/iter
Epoch: 33 | Batch: 021 / 025 | Total loss: 3.677 | Reg loss: 0.028 | Tree loss: 3.677 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 33 | Batch: 022 / 025 | Total loss: 3.696 | Reg loss: 0.028 | Tree loss: 3.696 | Accuracy: 0.128906 | 0.071 sec/iter
Epoch: 33 | Batch: 023 / 025 | Total loss: 3.643 | Reg loss: 0.028 | Tree loss: 3.643 | Accuracy: 0.160156 | 0.071 sec/iter
Epoch: 3

Epoch: 36 | Batch: 006 / 025 | Total loss: 3.647 | Reg loss: 0.025 | Tree loss: 3.647 | Accuracy: 0.123047 | 0.071 sec/iter
Epoch: 36 | Batch: 007 / 025 | Total loss: 3.575 | Reg loss: 0.026 | Tree loss: 3.575 | Accuracy: 0.105469 | 0.071 sec/iter
Epoch: 36 | Batch: 008 / 025 | Total loss: 3.577 | Reg loss: 0.026 | Tree loss: 3.577 | Accuracy: 0.138672 | 0.071 sec/iter
Epoch: 36 | Batch: 009 / 025 | Total loss: 3.510 | Reg loss: 0.026 | Tree loss: 3.510 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 36 | Batch: 010 / 025 | Total loss: 3.548 | Reg loss: 0.026 | Tree loss: 3.548 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 36 | Batch: 011 / 025 | Total loss: 3.514 | Reg loss: 0.026 | Tree loss: 3.514 | Accuracy: 0.121094 | 0.071 sec/iter
Epoch: 36 | Batch: 012 / 025 | Total loss: 3.458 | Reg loss: 0.026 | Tree loss: 3.458 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 36 | Batch: 013 / 025 | Total loss: 3.416 | Reg loss: 0.026 | Tree loss: 3.416 | Accuracy: 0.126953 | 0.071 sec/iter
Epoch: 3

Epoch: 38 | Batch: 021 / 025 | Total loss: 3.222 | Reg loss: 0.028 | Tree loss: 3.222 | Accuracy: 0.109375 | 0.071 sec/iter
Epoch: 38 | Batch: 022 / 025 | Total loss: 3.263 | Reg loss: 0.028 | Tree loss: 3.263 | Accuracy: 0.123047 | 0.071 sec/iter
Epoch: 38 | Batch: 023 / 025 | Total loss: 3.218 | Reg loss: 0.028 | Tree loss: 3.218 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 38 | Batch: 024 / 025 | Total loss: 3.179 | Reg loss: 0.029 | Tree loss: 3.179 | Accuracy: 0.128125 | 0.071 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 39 | Batch: 000 / 025 | Total loss: 3.452 | Reg loss: 0.026 | Tree loss: 3.452 | Accuracy: 0.126953 | 0.071 sec/iter
Epoch: 39 | Batch: 001 / 025 | Total loss: 3.416 | Reg loss: 0.026 | Tree loss: 3.416 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 39 | Batch: 002 / 025 | Total loss: 3.366 | Reg loss: 0.026 | Tree los

Epoch: 41 | Batch: 011 / 025 | Total loss: 3.136 | Reg loss: 0.028 | Tree loss: 3.136 | Accuracy: 0.142578 | 0.07 sec/iter
Epoch: 41 | Batch: 012 / 025 | Total loss: 3.071 | Reg loss: 0.028 | Tree loss: 3.071 | Accuracy: 0.121094 | 0.07 sec/iter
Epoch: 41 | Batch: 013 / 025 | Total loss: 3.108 | Reg loss: 0.028 | Tree loss: 3.108 | Accuracy: 0.150391 | 0.07 sec/iter
Epoch: 41 | Batch: 014 / 025 | Total loss: 3.112 | Reg loss: 0.028 | Tree loss: 3.112 | Accuracy: 0.146484 | 0.07 sec/iter
Epoch: 41 | Batch: 015 / 025 | Total loss: 3.043 | Reg loss: 0.028 | Tree loss: 3.043 | Accuracy: 0.115234 | 0.07 sec/iter
Epoch: 41 | Batch: 016 / 025 | Total loss: 3.045 | Reg loss: 0.028 | Tree loss: 3.045 | Accuracy: 0.119141 | 0.07 sec/iter
Epoch: 41 | Batch: 017 / 025 | Total loss: 3.083 | Reg loss: 0.028 | Tree loss: 3.083 | Accuracy: 0.138672 | 0.07 sec/iter
Epoch: 41 | Batch: 018 / 025 | Total loss: 2.985 | Reg loss: 0.028 | Tree loss: 2.985 | Accuracy: 0.123047 | 0.07 sec/iter
Epoch: 41 | Batc

Epoch: 44 | Batch: 001 / 025 | Total loss: 3.083 | Reg loss: 0.028 | Tree loss: 3.083 | Accuracy: 0.119141 | 0.071 sec/iter
Epoch: 44 | Batch: 002 / 025 | Total loss: 3.061 | Reg loss: 0.028 | Tree loss: 3.061 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 44 | Batch: 003 / 025 | Total loss: 3.055 | Reg loss: 0.028 | Tree loss: 3.055 | Accuracy: 0.142578 | 0.071 sec/iter
Epoch: 44 | Batch: 004 / 025 | Total loss: 3.013 | Reg loss: 0.028 | Tree loss: 3.013 | Accuracy: 0.179688 | 0.071 sec/iter
Epoch: 44 | Batch: 005 / 025 | Total loss: 3.054 | Reg loss: 0.028 | Tree loss: 3.054 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 44 | Batch: 006 / 025 | Total loss: 3.033 | Reg loss: 0.028 | Tree loss: 3.033 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 44 | Batch: 007 / 025 | Total loss: 3.022 | Reg loss: 0.028 | Tree loss: 3.022 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 44 | Batch: 008 / 025 | Total loss: 3.003 | Reg loss: 0.028 | Tree loss: 3.003 | Accuracy: 0.125000 | 0.071 sec/iter
Epoch: 4

Epoch: 46 | Batch: 016 / 025 | Total loss: 2.798 | Reg loss: 0.029 | Tree loss: 2.798 | Accuracy: 0.138672 | 0.071 sec/iter
Epoch: 46 | Batch: 017 / 025 | Total loss: 2.855 | Reg loss: 0.029 | Tree loss: 2.855 | Accuracy: 0.107422 | 0.071 sec/iter
Epoch: 46 | Batch: 018 / 025 | Total loss: 2.830 | Reg loss: 0.029 | Tree loss: 2.830 | Accuracy: 0.109375 | 0.071 sec/iter
Epoch: 46 | Batch: 019 / 025 | Total loss: 2.784 | Reg loss: 0.029 | Tree loss: 2.784 | Accuracy: 0.166016 | 0.071 sec/iter
Epoch: 46 | Batch: 020 / 025 | Total loss: 2.818 | Reg loss: 0.029 | Tree loss: 2.818 | Accuracy: 0.121094 | 0.071 sec/iter
Epoch: 46 | Batch: 021 / 025 | Total loss: 2.766 | Reg loss: 0.030 | Tree loss: 2.766 | Accuracy: 0.136719 | 0.071 sec/iter
Epoch: 46 | Batch: 022 / 025 | Total loss: 2.739 | Reg loss: 0.030 | Tree loss: 2.739 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 46 | Batch: 023 / 025 | Total loss: 2.770 | Reg loss: 0.030 | Tree loss: 2.770 | Accuracy: 0.130859 | 0.071 sec/iter
Epoch: 4

Epoch: 49 | Batch: 003 / 025 | Total loss: 2.851 | Reg loss: 0.029 | Tree loss: 2.851 | Accuracy: 0.134766 | 0.071 sec/iter
Epoch: 49 | Batch: 004 / 025 | Total loss: 2.885 | Reg loss: 0.029 | Tree loss: 2.885 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 49 | Batch: 005 / 025 | Total loss: 2.817 | Reg loss: 0.029 | Tree loss: 2.817 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 49 | Batch: 006 / 025 | Total loss: 2.839 | Reg loss: 0.029 | Tree loss: 2.839 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 49 | Batch: 007 / 025 | Total loss: 2.817 | Reg loss: 0.029 | Tree loss: 2.817 | Accuracy: 0.126953 | 0.071 sec/iter
Epoch: 49 | Batch: 008 / 025 | Total loss: 2.821 | Reg loss: 0.029 | Tree loss: 2.821 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 49 | Batch: 009 / 025 | Total loss: 2.813 | Reg loss: 0.029 | Tree loss: 2.813 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 49 | Batch: 010 / 025 | Total loss: 2.760 | Reg loss: 0.029 | Tree loss: 2.760 | Accuracy: 0.142578 | 0.071 sec/iter
Epoch: 4

Epoch: 51 | Batch: 019 / 025 | Total loss: 2.641 | Reg loss: 0.030 | Tree loss: 2.641 | Accuracy: 0.134766 | 0.071 sec/iter
Epoch: 51 | Batch: 020 / 025 | Total loss: 2.659 | Reg loss: 0.030 | Tree loss: 2.659 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 51 | Batch: 021 / 025 | Total loss: 2.610 | Reg loss: 0.030 | Tree loss: 2.610 | Accuracy: 0.136719 | 0.071 sec/iter
Epoch: 51 | Batch: 022 / 025 | Total loss: 2.642 | Reg loss: 0.030 | Tree loss: 2.642 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 51 | Batch: 023 / 025 | Total loss: 2.628 | Reg loss: 0.030 | Tree loss: 2.628 | Accuracy: 0.140625 | 0.071 sec/iter
Epoch: 51 | Batch: 024 / 025 | Total loss: 2.605 | Reg loss: 0.030 | Tree loss: 2.605 | Accuracy: 0.109375 | 0.071 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 52 | Batch: 000 / 025 | Total loss: 2.798 | Reg loss: 0.029 | Tree los

Epoch: 54 | Batch: 007 / 025 | Total loss: 2.684 | Reg loss: 0.030 | Tree loss: 2.684 | Accuracy: 0.142578 | 0.071 sec/iter
Epoch: 54 | Batch: 008 / 025 | Total loss: 2.648 | Reg loss: 0.030 | Tree loss: 2.648 | Accuracy: 0.136719 | 0.071 sec/iter
Epoch: 54 | Batch: 009 / 025 | Total loss: 2.694 | Reg loss: 0.030 | Tree loss: 2.694 | Accuracy: 0.103516 | 0.071 sec/iter
Epoch: 54 | Batch: 010 / 025 | Total loss: 2.660 | Reg loss: 0.030 | Tree loss: 2.660 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 54 | Batch: 011 / 025 | Total loss: 2.680 | Reg loss: 0.030 | Tree loss: 2.680 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 54 | Batch: 012 / 025 | Total loss: 2.609 | Reg loss: 0.030 | Tree loss: 2.609 | Accuracy: 0.125000 | 0.071 sec/iter
Epoch: 54 | Batch: 013 / 025 | Total loss: 2.609 | Reg loss: 0.030 | Tree loss: 2.609 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 54 | Batch: 014 / 025 | Total loss: 2.640 | Reg loss: 0.030 | Tree loss: 2.640 | Accuracy: 0.134766 | 0.071 sec/iter
Epoch: 5

Epoch: 56 | Batch: 021 / 025 | Total loss: 2.548 | Reg loss: 0.030 | Tree loss: 2.548 | Accuracy: 0.130859 | 0.071 sec/iter
Epoch: 56 | Batch: 022 / 025 | Total loss: 2.517 | Reg loss: 0.030 | Tree loss: 2.517 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 56 | Batch: 023 / 025 | Total loss: 2.540 | Reg loss: 0.030 | Tree loss: 2.540 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 56 | Batch: 024 / 025 | Total loss: 2.571 | Reg loss: 0.030 | Tree loss: 2.571 | Accuracy: 0.125000 | 0.071 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 57 | Batch: 000 / 025 | Total loss: 2.654 | Reg loss: 0.030 | Tree loss: 2.654 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 57 | Batch: 001 / 025 | Total loss: 2.697 | Reg loss: 0.030 | Tree loss: 2.697 | Accuracy: 0.166016 | 0.071 sec/iter
Epoch: 57 | Batch: 002 / 025 | Total loss: 2.673 | Reg loss: 0.030 | Tree los

Epoch: 59 | Batch: 010 / 025 | Total loss: 2.585 | Reg loss: 0.030 | Tree loss: 2.585 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 59 | Batch: 011 / 025 | Total loss: 2.554 | Reg loss: 0.030 | Tree loss: 2.554 | Accuracy: 0.166016 | 0.071 sec/iter
Epoch: 59 | Batch: 012 / 025 | Total loss: 2.548 | Reg loss: 0.030 | Tree loss: 2.548 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 59 | Batch: 013 / 025 | Total loss: 2.582 | Reg loss: 0.030 | Tree loss: 2.582 | Accuracy: 0.177734 | 0.071 sec/iter
Epoch: 59 | Batch: 014 / 025 | Total loss: 2.571 | Reg loss: 0.030 | Tree loss: 2.571 | Accuracy: 0.156250 | 0.071 sec/iter
Epoch: 59 | Batch: 015 / 025 | Total loss: 2.533 | Reg loss: 0.030 | Tree loss: 2.533 | Accuracy: 0.156250 | 0.071 sec/iter
Epoch: 59 | Batch: 016 / 025 | Total loss: 2.540 | Reg loss: 0.030 | Tree loss: 2.540 | Accuracy: 0.160156 | 0.071 sec/iter
Epoch: 59 | Batch: 017 / 025 | Total loss: 2.515 | Reg loss: 0.030 | Tree loss: 2.515 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 5

Epoch: 62 | Batch: 001 / 025 | Total loss: 2.603 | Reg loss: 0.030 | Tree loss: 2.603 | Accuracy: 0.207031 | 0.071 sec/iter
Epoch: 62 | Batch: 002 / 025 | Total loss: 2.592 | Reg loss: 0.030 | Tree loss: 2.592 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 62 | Batch: 003 / 025 | Total loss: 2.617 | Reg loss: 0.030 | Tree loss: 2.617 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 62 | Batch: 004 / 025 | Total loss: 2.597 | Reg loss: 0.030 | Tree loss: 2.597 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 62 | Batch: 005 / 025 | Total loss: 2.583 | Reg loss: 0.030 | Tree loss: 2.583 | Accuracy: 0.166016 | 0.071 sec/iter
Epoch: 62 | Batch: 006 / 025 | Total loss: 2.544 | Reg loss: 0.030 | Tree loss: 2.544 | Accuracy: 0.175781 | 0.071 sec/iter
Epoch: 62 | Batch: 007 / 025 | Total loss: 2.573 | Reg loss: 0.030 | Tree loss: 2.573 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 62 | Batch: 008 / 025 | Total loss: 2.555 | Reg loss: 0.030 | Tree loss: 2.555 | Accuracy: 0.164062 | 0.071 sec/iter
Epoch: 6

Epoch: 64 | Batch: 015 / 025 | Total loss: 2.484 | Reg loss: 0.030 | Tree loss: 2.484 | Accuracy: 0.177734 | 0.071 sec/iter
Epoch: 64 | Batch: 016 / 025 | Total loss: 2.471 | Reg loss: 0.030 | Tree loss: 2.471 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 64 | Batch: 017 / 025 | Total loss: 2.469 | Reg loss: 0.030 | Tree loss: 2.469 | Accuracy: 0.144531 | 0.071 sec/iter
Epoch: 64 | Batch: 018 / 025 | Total loss: 2.487 | Reg loss: 0.030 | Tree loss: 2.487 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 64 | Batch: 019 / 025 | Total loss: 2.443 | Reg loss: 0.030 | Tree loss: 2.443 | Accuracy: 0.166016 | 0.071 sec/iter
Epoch: 64 | Batch: 020 / 025 | Total loss: 2.473 | Reg loss: 0.030 | Tree loss: 2.473 | Accuracy: 0.146484 | 0.071 sec/iter
Epoch: 64 | Batch: 021 / 025 | Total loss: 2.433 | Reg loss: 0.030 | Tree loss: 2.433 | Accuracy: 0.181641 | 0.071 sec/iter
Epoch: 64 | Batch: 022 / 025 | Total loss: 2.448 | Reg loss: 0.030 | Tree loss: 2.448 | Accuracy: 0.146484 | 0.071 sec/iter
Epoch: 6

Epoch: 67 | Batch: 002 / 025 | Total loss: 2.524 | Reg loss: 0.030 | Tree loss: 2.524 | Accuracy: 0.160156 | 0.071 sec/iter
Epoch: 67 | Batch: 003 / 025 | Total loss: 2.529 | Reg loss: 0.030 | Tree loss: 2.529 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 67 | Batch: 004 / 025 | Total loss: 2.557 | Reg loss: 0.030 | Tree loss: 2.557 | Accuracy: 0.132812 | 0.071 sec/iter
Epoch: 67 | Batch: 005 / 025 | Total loss: 2.516 | Reg loss: 0.030 | Tree loss: 2.516 | Accuracy: 0.179688 | 0.071 sec/iter
Epoch: 67 | Batch: 006 / 025 | Total loss: 2.503 | Reg loss: 0.030 | Tree loss: 2.503 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 67 | Batch: 007 / 025 | Total loss: 2.485 | Reg loss: 0.030 | Tree loss: 2.485 | Accuracy: 0.162109 | 0.071 sec/iter
Epoch: 67 | Batch: 008 / 025 | Total loss: 2.444 | Reg loss: 0.030 | Tree loss: 2.444 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 67 | Batch: 009 / 025 | Total loss: 2.484 | Reg loss: 0.030 | Tree loss: 2.484 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 6

Epoch: 69 | Batch: 016 / 025 | Total loss: 2.458 | Reg loss: 0.030 | Tree loss: 2.458 | Accuracy: 0.140625 | 0.071 sec/iter
Epoch: 69 | Batch: 017 / 025 | Total loss: 2.433 | Reg loss: 0.030 | Tree loss: 2.433 | Accuracy: 0.187500 | 0.071 sec/iter
Epoch: 69 | Batch: 018 / 025 | Total loss: 2.447 | Reg loss: 0.030 | Tree loss: 2.447 | Accuracy: 0.142578 | 0.071 sec/iter
Epoch: 69 | Batch: 019 / 025 | Total loss: 2.457 | Reg loss: 0.030 | Tree loss: 2.457 | Accuracy: 0.169922 | 0.071 sec/iter
Epoch: 69 | Batch: 020 / 025 | Total loss: 2.378 | Reg loss: 0.030 | Tree loss: 2.378 | Accuracy: 0.189453 | 0.071 sec/iter
Epoch: 69 | Batch: 021 / 025 | Total loss: 2.367 | Reg loss: 0.030 | Tree loss: 2.367 | Accuracy: 0.164062 | 0.071 sec/iter
Epoch: 69 | Batch: 022 / 025 | Total loss: 2.404 | Reg loss: 0.030 | Tree loss: 2.404 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 69 | Batch: 023 / 025 | Total loss: 2.383 | Reg loss: 0.030 | Tree loss: 2.383 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 6

Epoch: 72 | Batch: 003 / 025 | Total loss: 2.489 | Reg loss: 0.030 | Tree loss: 2.489 | Accuracy: 0.160156 | 0.071 sec/iter
Epoch: 72 | Batch: 004 / 025 | Total loss: 2.477 | Reg loss: 0.030 | Tree loss: 2.477 | Accuracy: 0.208984 | 0.071 sec/iter
Epoch: 72 | Batch: 005 / 025 | Total loss: 2.463 | Reg loss: 0.030 | Tree loss: 2.463 | Accuracy: 0.171875 | 0.071 sec/iter
Epoch: 72 | Batch: 006 / 025 | Total loss: 2.492 | Reg loss: 0.030 | Tree loss: 2.492 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 72 | Batch: 007 / 025 | Total loss: 2.493 | Reg loss: 0.030 | Tree loss: 2.493 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 72 | Batch: 008 / 025 | Total loss: 2.468 | Reg loss: 0.030 | Tree loss: 2.468 | Accuracy: 0.156250 | 0.071 sec/iter
Epoch: 72 | Batch: 009 / 025 | Total loss: 2.478 | Reg loss: 0.030 | Tree loss: 2.478 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 72 | Batch: 010 / 025 | Total loss: 2.405 | Reg loss: 0.030 | Tree loss: 2.405 | Accuracy: 0.158203 | 0.071 sec/iter
Epoch: 7

Epoch: 74 | Batch: 017 / 025 | Total loss: 2.464 | Reg loss: 0.030 | Tree loss: 2.464 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 74 | Batch: 018 / 025 | Total loss: 2.380 | Reg loss: 0.030 | Tree loss: 2.380 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 74 | Batch: 019 / 025 | Total loss: 2.397 | Reg loss: 0.030 | Tree loss: 2.397 | Accuracy: 0.156250 | 0.071 sec/iter
Epoch: 74 | Batch: 020 / 025 | Total loss: 2.363 | Reg loss: 0.030 | Tree loss: 2.363 | Accuracy: 0.181641 | 0.071 sec/iter
Epoch: 74 | Batch: 021 / 025 | Total loss: 2.355 | Reg loss: 0.030 | Tree loss: 2.355 | Accuracy: 0.207031 | 0.071 sec/iter
Epoch: 74 | Batch: 022 / 025 | Total loss: 2.367 | Reg loss: 0.030 | Tree loss: 2.367 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 74 | Batch: 023 / 025 | Total loss: 2.330 | Reg loss: 0.030 | Tree loss: 2.330 | Accuracy: 0.187500 | 0.071 sec/iter
Epoch: 74 | Batch: 024 / 025 | Total loss: 2.364 | Reg loss: 0.030 | Tree loss: 2.364 | Accuracy: 0.165625 | 0.071 sec/iter
Average 

Epoch: 77 | Batch: 005 / 025 | Total loss: 2.441 | Reg loss: 0.029 | Tree loss: 2.441 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 77 | Batch: 006 / 025 | Total loss: 2.433 | Reg loss: 0.029 | Tree loss: 2.433 | Accuracy: 0.162109 | 0.071 sec/iter
Epoch: 77 | Batch: 007 / 025 | Total loss: 2.417 | Reg loss: 0.029 | Tree loss: 2.417 | Accuracy: 0.164062 | 0.071 sec/iter
Epoch: 77 | Batch: 008 / 025 | Total loss: 2.465 | Reg loss: 0.029 | Tree loss: 2.465 | Accuracy: 0.193359 | 0.071 sec/iter
Epoch: 77 | Batch: 009 / 025 | Total loss: 2.447 | Reg loss: 0.030 | Tree loss: 2.447 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 77 | Batch: 010 / 025 | Total loss: 2.435 | Reg loss: 0.030 | Tree loss: 2.435 | Accuracy: 0.193359 | 0.071 sec/iter
Epoch: 77 | Batch: 011 / 025 | Total loss: 2.392 | Reg loss: 0.030 | Tree loss: 2.392 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 77 | Batch: 012 / 025 | Total loss: 2.402 | Reg loss: 0.030 | Tree loss: 2.402 | Accuracy: 0.156250 | 0.071 sec/iter
Epoch: 7

Epoch: 79 | Batch: 022 / 025 | Total loss: 2.381 | Reg loss: 0.030 | Tree loss: 2.381 | Accuracy: 0.154297 | 0.071 sec/iter
Epoch: 79 | Batch: 023 / 025 | Total loss: 2.346 | Reg loss: 0.030 | Tree loss: 2.346 | Accuracy: 0.181641 | 0.071 sec/iter
Epoch: 79 | Batch: 024 / 025 | Total loss: 2.369 | Reg loss: 0.030 | Tree loss: 2.369 | Accuracy: 0.184375 | 0.071 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 80 | Batch: 000 / 025 | Total loss: 2.462 | Reg loss: 0.029 | Tree loss: 2.462 | Accuracy: 0.156250 | 0.071 sec/iter
Epoch: 80 | Batch: 001 / 025 | Total loss: 2.458 | Reg loss: 0.029 | Tree loss: 2.458 | Accuracy: 0.191406 | 0.071 sec/iter
Epoch: 80 | Batch: 002 / 025 | Total loss: 2.464 | Reg loss: 0.029 | Tree loss: 2.464 | Accuracy: 0.158203 | 0.071 sec/iter
Epoch: 80 | Batch: 003 / 025 | Total loss: 2.509 | Reg loss: 0.029 | Tree los

Epoch: 82 | Batch: 011 / 025 | Total loss: 2.451 | Reg loss: 0.029 | Tree loss: 2.451 | Accuracy: 0.171875 | 0.071 sec/iter
Epoch: 82 | Batch: 012 / 025 | Total loss: 2.383 | Reg loss: 0.029 | Tree loss: 2.383 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 82 | Batch: 013 / 025 | Total loss: 2.386 | Reg loss: 0.029 | Tree loss: 2.386 | Accuracy: 0.158203 | 0.071 sec/iter
Epoch: 82 | Batch: 014 / 025 | Total loss: 2.406 | Reg loss: 0.029 | Tree loss: 2.406 | Accuracy: 0.134766 | 0.071 sec/iter
Epoch: 82 | Batch: 015 / 025 | Total loss: 2.348 | Reg loss: 0.029 | Tree loss: 2.348 | Accuracy: 0.146484 | 0.071 sec/iter
Epoch: 82 | Batch: 016 / 025 | Total loss: 2.357 | Reg loss: 0.029 | Tree loss: 2.357 | Accuracy: 0.191406 | 0.071 sec/iter
Epoch: 82 | Batch: 017 / 025 | Total loss: 2.319 | Reg loss: 0.029 | Tree loss: 2.319 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 82 | Batch: 018 / 025 | Total loss: 2.344 | Reg loss: 0.029 | Tree loss: 2.344 | Accuracy: 0.183594 | 0.071 sec/iter
Epoch: 8

Epoch: 85 | Batch: 001 / 025 | Total loss: 2.448 | Reg loss: 0.029 | Tree loss: 2.448 | Accuracy: 0.166016 | 0.071 sec/iter
Epoch: 85 | Batch: 002 / 025 | Total loss: 2.394 | Reg loss: 0.029 | Tree loss: 2.394 | Accuracy: 0.183594 | 0.071 sec/iter
Epoch: 85 | Batch: 003 / 025 | Total loss: 2.450 | Reg loss: 0.029 | Tree loss: 2.450 | Accuracy: 0.138672 | 0.071 sec/iter
Epoch: 85 | Batch: 004 / 025 | Total loss: 2.483 | Reg loss: 0.029 | Tree loss: 2.483 | Accuracy: 0.158203 | 0.071 sec/iter
Epoch: 85 | Batch: 005 / 025 | Total loss: 2.396 | Reg loss: 0.029 | Tree loss: 2.396 | Accuracy: 0.171875 | 0.071 sec/iter
Epoch: 85 | Batch: 006 / 025 | Total loss: 2.448 | Reg loss: 0.029 | Tree loss: 2.448 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 85 | Batch: 007 / 025 | Total loss: 2.394 | Reg loss: 0.029 | Tree loss: 2.394 | Accuracy: 0.169922 | 0.071 sec/iter
Epoch: 85 | Batch: 008 / 025 | Total loss: 2.405 | Reg loss: 0.029 | Tree loss: 2.405 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 8

Epoch: 87 | Batch: 016 / 025 | Total loss: 2.355 | Reg loss: 0.029 | Tree loss: 2.355 | Accuracy: 0.175781 | 0.071 sec/iter
Epoch: 87 | Batch: 017 / 025 | Total loss: 2.334 | Reg loss: 0.029 | Tree loss: 2.334 | Accuracy: 0.160156 | 0.071 sec/iter
Epoch: 87 | Batch: 018 / 025 | Total loss: 2.317 | Reg loss: 0.029 | Tree loss: 2.317 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 87 | Batch: 019 / 025 | Total loss: 2.341 | Reg loss: 0.029 | Tree loss: 2.341 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 87 | Batch: 020 / 025 | Total loss: 2.353 | Reg loss: 0.029 | Tree loss: 2.353 | Accuracy: 0.177734 | 0.071 sec/iter
Epoch: 87 | Batch: 021 / 025 | Total loss: 2.351 | Reg loss: 0.029 | Tree loss: 2.351 | Accuracy: 0.162109 | 0.071 sec/iter
Epoch: 87 | Batch: 022 / 025 | Total loss: 2.337 | Reg loss: 0.029 | Tree loss: 2.337 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 87 | Batch: 023 / 025 | Total loss: 2.287 | Reg loss: 0.029 | Tree loss: 2.287 | Accuracy: 0.214844 | 0.071 sec/iter
Epoch: 8

Epoch: 90 | Batch: 003 / 025 | Total loss: 2.386 | Reg loss: 0.029 | Tree loss: 2.386 | Accuracy: 0.175781 | 0.071 sec/iter
Epoch: 90 | Batch: 004 / 025 | Total loss: 2.413 | Reg loss: 0.029 | Tree loss: 2.413 | Accuracy: 0.177734 | 0.071 sec/iter
Epoch: 90 | Batch: 005 / 025 | Total loss: 2.438 | Reg loss: 0.029 | Tree loss: 2.438 | Accuracy: 0.140625 | 0.071 sec/iter
Epoch: 90 | Batch: 006 / 025 | Total loss: 2.374 | Reg loss: 0.029 | Tree loss: 2.374 | Accuracy: 0.167969 | 0.071 sec/iter
Epoch: 90 | Batch: 007 / 025 | Total loss: 2.356 | Reg loss: 0.029 | Tree loss: 2.356 | Accuracy: 0.175781 | 0.071 sec/iter
Epoch: 90 | Batch: 008 / 025 | Total loss: 2.390 | Reg loss: 0.029 | Tree loss: 2.390 | Accuracy: 0.158203 | 0.071 sec/iter
Epoch: 90 | Batch: 009 / 025 | Total loss: 2.401 | Reg loss: 0.029 | Tree loss: 2.401 | Accuracy: 0.175781 | 0.071 sec/iter
Epoch: 90 | Batch: 010 / 025 | Total loss: 2.379 | Reg loss: 0.029 | Tree loss: 2.379 | Accuracy: 0.173828 | 0.071 sec/iter
Epoch: 9

Epoch: 92 | Batch: 019 / 025 | Total loss: 2.336 | Reg loss: 0.029 | Tree loss: 2.336 | Accuracy: 0.146484 | 0.071 sec/iter
Epoch: 92 | Batch: 020 / 025 | Total loss: 2.363 | Reg loss: 0.029 | Tree loss: 2.363 | Accuracy: 0.152344 | 0.071 sec/iter
Epoch: 92 | Batch: 021 / 025 | Total loss: 2.297 | Reg loss: 0.029 | Tree loss: 2.297 | Accuracy: 0.148438 | 0.071 sec/iter
Epoch: 92 | Batch: 022 / 025 | Total loss: 2.311 | Reg loss: 0.029 | Tree loss: 2.311 | Accuracy: 0.183594 | 0.071 sec/iter
Epoch: 92 | Batch: 023 / 025 | Total loss: 2.303 | Reg loss: 0.029 | Tree loss: 2.303 | Accuracy: 0.150391 | 0.071 sec/iter
Epoch: 92 | Batch: 024 / 025 | Total loss: 2.285 | Reg loss: 0.029 | Tree loss: 2.285 | Accuracy: 0.193750 | 0.071 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 93 | Batch: 000 / 025 | Total loss: 2.425 | Reg loss: 0.029 | Tree los

Epoch: 95 | Batch: 007 / 025 | Total loss: 2.368 | Reg loss: 0.029 | Tree loss: 2.368 | Accuracy: 0.189453 | 0.072 sec/iter
Epoch: 95 | Batch: 008 / 025 | Total loss: 2.404 | Reg loss: 0.029 | Tree loss: 2.404 | Accuracy: 0.158203 | 0.072 sec/iter
Epoch: 95 | Batch: 009 / 025 | Total loss: 2.338 | Reg loss: 0.029 | Tree loss: 2.338 | Accuracy: 0.193359 | 0.072 sec/iter
Epoch: 95 | Batch: 010 / 025 | Total loss: 2.348 | Reg loss: 0.029 | Tree loss: 2.348 | Accuracy: 0.199219 | 0.072 sec/iter
Epoch: 95 | Batch: 011 / 025 | Total loss: 2.421 | Reg loss: 0.029 | Tree loss: 2.421 | Accuracy: 0.160156 | 0.072 sec/iter
Epoch: 95 | Batch: 012 / 025 | Total loss: 2.354 | Reg loss: 0.029 | Tree loss: 2.354 | Accuracy: 0.158203 | 0.072 sec/iter
Epoch: 95 | Batch: 013 / 025 | Total loss: 2.334 | Reg loss: 0.029 | Tree loss: 2.334 | Accuracy: 0.160156 | 0.072 sec/iter
Epoch: 95 | Batch: 014 / 025 | Total loss: 2.305 | Reg loss: 0.029 | Tree loss: 2.305 | Accuracy: 0.185547 | 0.072 sec/iter
Epoch: 9

Epoch: 97 | Batch: 021 / 025 | Total loss: 2.277 | Reg loss: 0.029 | Tree loss: 2.277 | Accuracy: 0.177734 | 0.072 sec/iter
Epoch: 97 | Batch: 022 / 025 | Total loss: 2.266 | Reg loss: 0.029 | Tree loss: 2.266 | Accuracy: 0.189453 | 0.072 sec/iter
Epoch: 97 | Batch: 023 / 025 | Total loss: 2.287 | Reg loss: 0.029 | Tree loss: 2.287 | Accuracy: 0.152344 | 0.072 sec/iter
Epoch: 97 | Batch: 024 / 025 | Total loss: 2.272 | Reg loss: 0.029 | Tree loss: 2.272 | Accuracy: 0.181250 | 0.072 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 98 | Batch: 000 / 025 | Total loss: 2.456 | Reg loss: 0.028 | Tree loss: 2.456 | Accuracy: 0.164062 | 0.072 sec/iter
Epoch: 98 | Batch: 001 / 025 | Total loss: 2.457 | Reg loss: 0.028 | Tree loss: 2.457 | Accuracy: 0.160156 | 0.072 sec/iter
Epoch: 98 | Batch: 002 / 025 | Total loss: 2.412 | Reg loss: 0.028 | Tree los

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 4.333333333333333


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 12


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

11084
1524
Average comprehensibility: 21.0
std comprehensibility: 7.14142842854285
var comprehensibility: 51.0
minimum comprehensibility: 10
maximum comprehensibility: 30


  return np.log(1 / (1 - x))
