In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 16
tree_depth = 6
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.210041999816895 | KNN Loss: 6.227721214294434 | BCE Loss: 1.98232102394104
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.172706604003906 | KNN Loss: 6.227598190307617 | BCE Loss: 1.9451082944869995
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.17885684967041 | KNN Loss: 6.227121353149414 | BCE Loss: 1.9517356157302856
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.124919891357422 | KNN Loss: 6.22711181640625 | BCE Loss: 1.8978075981140137
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.122896194458008 | KNN Loss: 6.22642707824707 | BCE Loss: 1.896469235420227
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.110965728759766 | KNN Loss: 6.225857257843018 | BCE Loss: 1.8851089477539062
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.14926528930664 | KNN Loss: 6.225307464599609 | BCE Loss: 1.9239575862884521
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.112594604492188 | KNN Loss: 6.2245941162109375 | BCE Loss: 1.88800001144

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 6.615598201751709 | KNN Loss: 5.525787353515625 | BCE Loss: 1.089810848236084
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 6.5095672607421875 | KNN Loss: 5.394506454467773 | BCE Loss: 1.115060567855835
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 6.456127166748047 | KNN Loss: 5.326240062713623 | BCE Loss: 1.1298872232437134
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.258234977722168 | KNN Loss: 5.168399810791016 | BCE Loss: 1.0898349285125732
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.153439521789551 | KNN Loss: 5.073712348937988 | BCE Loss: 1.0797271728515625
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.1178364753723145 | KNN Loss: 4.985212326049805 | BCE Loss: 1.1326241493225098
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.943134307861328 | KNN Loss: 4.838218688964844 | BCE Loss: 1.1049153804779053
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.827757835388184 | KNN Loss: 4.71121072769165 | BCE Loss:

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.275885581970215 | KNN Loss: 3.237459421157837 | BCE Loss: 1.0384260416030884
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.278428554534912 | KNN Loss: 3.2170639038085938 | BCE Loss: 1.0613646507263184
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.236767768859863 | KNN Loss: 3.1986160278320312 | BCE Loss: 1.0381519794464111
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.295625686645508 | KNN Loss: 3.219217300415039 | BCE Loss: 1.0764086246490479
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.2434821128845215 | KNN Loss: 3.205167531967163 | BCE Loss: 1.038314700126648
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.253411293029785 | KNN Loss: 3.2115750312805176 | BCE Loss: 1.0418360233306885
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.2888031005859375 | KNN Loss: 3.2246108055114746 | BCE Loss: 1.064192533493042
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.273770332336426 | KNN Loss: 3.2208034992218018 | BC

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.146391868591309 | KNN Loss: 3.1157655715942383 | BCE Loss: 1.0306265354156494
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.150697231292725 | KNN Loss: 3.132648468017578 | BCE Loss: 1.0180487632751465
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.123737335205078 | KNN Loss: 3.1169633865356445 | BCE Loss: 1.0067739486694336
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.14653205871582 | KNN Loss: 3.1173195838928223 | BCE Loss: 1.0292123556137085
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.153446197509766 | KNN Loss: 3.1160006523132324 | BCE Loss: 1.0374457836151123
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.164685249328613 | KNN Loss: 3.132547616958618 | BCE Loss: 1.0321377515792847
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.202447414398193 | KNN Loss: 3.1936850547790527 | BCE Loss: 1.0087623596191406
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.221055507659912 | KNN Loss: 3.13733172416687 | BCE 

Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 4.191071510314941 | KNN Loss: 3.13553786277771 | BCE Loss: 1.055533528327942
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.151248931884766 | KNN Loss: 3.1207659244537354 | BCE Loss: 1.0304827690124512
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.1442718505859375 | KNN Loss: 3.1346371173858643 | BCE Loss: 1.0096344947814941
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.177558898925781 | KNN Loss: 3.1473333835601807 | BCE Loss: 1.0302252769470215
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.125032901763916 | KNN Loss: 3.108682632446289 | BCE Loss: 1.016350269317627
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.140707015991211 | KNN Loss: 3.1019270420074463 | BCE Loss: 1.0387797355651855
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.171140193939209 | KNN Loss: 3.1368513107299805 | BCE Loss: 1.0342888832092285
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.199906349182129 | KNN Loss: 3.154616594314575 | BCE L

Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 4.139155864715576 | KNN Loss: 3.0930471420288086 | BCE Loss: 1.046108603477478
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 4.133483409881592 | KNN Loss: 3.113740921020508 | BCE Loss: 1.0197426080703735
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.1554155349731445 | KNN Loss: 3.1146538257598877 | BCE Loss: 1.0407614707946777
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.141876220703125 | KNN Loss: 3.1183860301971436 | BCE Loss: 1.0234904289245605
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.105504989624023 | KNN Loss: 3.0996224880218506 | BCE Loss: 1.0058822631835938
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.167599201202393 | KNN Loss: 3.135105609893799 | BCE Loss: 1.0324937105178833
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.158660888671875 | KNN Loss: 3.1281347274780273 | BCE Loss: 1.0305259227752686
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.115081787109375 | KNN Loss: 3.0937089920043945 | BC

Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 4.135832786560059 | KNN Loss: 3.1009888648986816 | BCE Loss: 1.0348438024520874
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 4.106701850891113 | KNN Loss: 3.088606357574463 | BCE Loss: 1.0180954933166504
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 4.1288251876831055 | KNN Loss: 3.0770132541656494 | BCE Loss: 1.051811933517456
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.142237186431885 | KNN Loss: 3.1151328086853027 | BCE Loss: 1.027104377746582
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.1665472984313965 | KNN Loss: 3.1144158840179443 | BCE Loss: 1.0521315336227417
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.144742965698242 | KNN Loss: 3.1329174041748047 | BCE Loss: 1.0118255615234375
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.131760597229004 | KNN Loss: 3.119110584259033 | BCE Loss: 1.0126498937606812
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.157909393310547 | KNN Loss: 3.124241352081299 | BC

Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 4.125607490539551 | KNN Loss: 3.1087586879730225 | BCE Loss: 1.0168486833572388
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 4.123086452484131 | KNN Loss: 3.1121490001678467 | BCE Loss: 1.0109375715255737
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 4.100192070007324 | KNN Loss: 3.0890591144561768 | BCE Loss: 1.0111329555511475
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.167860984802246 | KNN Loss: 3.118926525115967 | BCE Loss: 1.0489342212677002
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.104001045227051 | KNN Loss: 3.07480525970459 | BCE Loss: 1.0291955471038818
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.150948524475098 | KNN Loss: 3.088179588317871 | BCE Loss: 1.0627690553665161
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.140726089477539 | KNN Loss: 3.09153151512146 | BCE Loss: 1.049194574356079
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.098135471343994 | KNN Loss: 3.070617437362671 | BCE Loss

Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 4.119589328765869 | KNN Loss: 3.0916733741760254 | BCE Loss: 1.0279159545898438
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 4.151787757873535 | KNN Loss: 3.115217685699463 | BCE Loss: 1.0365698337554932
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 4.180403709411621 | KNN Loss: 3.1145479679107666 | BCE Loss: 1.0658559799194336
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.167194366455078 | KNN Loss: 3.1339354515075684 | BCE Loss: 1.0332587957382202
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.0770111083984375 | KNN Loss: 3.0747954845428467 | BCE Loss: 1.0022157430648804
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.0813140869140625 | KNN Loss: 3.0861048698425293 | BCE Loss: 0.9952092170715332
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.140326499938965 | KNN Loss: 3.091454029083252 | BCE Loss: 1.048872709274292
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.075321674346924 | KNN Loss: 3.0539586544036865 | BC

Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 4.152474403381348 | KNN Loss: 3.1283633708953857 | BCE Loss: 1.0241107940673828
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 4.121400833129883 | KNN Loss: 3.1151132583618164 | BCE Loss: 1.0062874555587769
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 4.149895191192627 | KNN Loss: 3.139977216720581 | BCE Loss: 1.009917974472046
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.170612335205078 | KNN Loss: 3.137399673461914 | BCE Loss: 1.033212661743164
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.135873317718506 | KNN Loss: 3.1026041507720947 | BCE Loss: 1.0332691669464111
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.139756202697754 | KNN Loss: 3.085341215133667 | BCE Loss: 1.054415225982666
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.106787204742432 | KNN Loss: 3.083601236343384 | BCE Loss: 1.0231859683990479
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.129411220550537 | KNN Loss: 3.100297451019287 | BCE Los

Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 4.162463188171387 | KNN Loss: 3.1162052154541016 | BCE Loss: 1.0462580919265747
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 4.099435329437256 | KNN Loss: 3.073514938354492 | BCE Loss: 1.0259203910827637
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 4.131713390350342 | KNN Loss: 3.1121251583099365 | BCE Loss: 1.0195882320404053
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.130538463592529 | KNN Loss: 3.0965569019317627 | BCE Loss: 1.0339815616607666
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.085044860839844 | KNN Loss: 3.093395471572876 | BCE Loss: 0.9916493892669678
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.1126298904418945 | KNN Loss: 3.1008825302124023 | BCE Loss: 1.0117472410202026
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.131220817565918 | KNN Loss: 3.103867769241333 | BCE Loss: 1.027353048324585
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.147238731384277 | KNN Loss: 3.1005184650421

Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 4.1111860275268555 | KNN Loss: 3.0812578201293945 | BCE Loss: 1.0299279689788818
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 4.108087062835693 | KNN Loss: 3.1178455352783203 | BCE Loss: 0.990241527557373
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 4.099915504455566 | KNN Loss: 3.0915944576263428 | BCE Loss: 1.0083210468292236
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.100764751434326 | KNN Loss: 3.069977045059204 | BCE Loss: 1.0307875871658325
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.126580238342285 | KNN Loss: 3.1091701984405518 | BCE Loss: 1.0174102783203125
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.078463554382324 | KNN Loss: 3.0735013484954834 | BCE Loss: 1.0049623250961304
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.062147617340088 | KNN Loss: 3.0353941917419434 | BCE Loss: 1.0267534255981445
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.1114702224731445 | KNN Loss: 3.070156812667

Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 4.091609001159668 | KNN Loss: 3.044062852859497 | BCE Loss: 1.04754638671875
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 4.114238262176514 | KNN Loss: 3.0691139698028564 | BCE Loss: 1.0451242923736572
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 4.1301069259643555 | KNN Loss: 3.0957961082458496 | BCE Loss: 1.0343109369277954
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.105341911315918 | KNN Loss: 3.099771738052368 | BCE Loss: 1.0055701732635498
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.133753776550293 | KNN Loss: 3.1015212535858154 | BCE Loss: 1.0322322845458984
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.095797538757324 | KNN Loss: 3.047170400619507 | BCE Loss: 1.0486270189285278
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.125022888183594 | KNN Loss: 3.1062185764312744 | BCE Loss: 1.0188043117523193
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.086182117462158 | KNN Loss: 3.07660007476806

Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 4.079288959503174 | KNN Loss: 3.070758581161499 | BCE Loss: 1.0085302591323853
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 4.08736515045166 | KNN Loss: 3.087836980819702 | BCE Loss: 0.9995284080505371
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 4.0745134353637695 | KNN Loss: 3.0727200508117676 | BCE Loss: 1.001793384552002
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.080516815185547 | KNN Loss: 3.086458683013916 | BCE Loss: 0.9940581917762756
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.069692611694336 | KNN Loss: 3.0683250427246094 | BCE Loss: 1.0013676881790161
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.059331893920898 | KNN Loss: 3.061249256134033 | BCE Loss: 0.9980824589729309
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.127124786376953 | KNN Loss: 3.114865303039551 | BCE Loss: 1.0122597217559814
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.1382598876953125 | KNN Loss: 3.086788654327392

Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 4.097102165222168 | KNN Loss: 3.075808525085449 | BCE Loss: 1.0212934017181396
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 4.078984260559082 | KNN Loss: 3.0572636127471924 | BCE Loss: 1.0217204093933105
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 4.123972415924072 | KNN Loss: 3.091752767562866 | BCE Loss: 1.032219648361206
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.081401824951172 | KNN Loss: 3.0490920543670654 | BCE Loss: 1.0323097705841064
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.056092739105225 | KNN Loss: 3.0398778915405273 | BCE Loss: 1.0162148475646973
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.085788726806641 | KNN Loss: 3.053041696548462 | BCE Loss: 1.0327467918395996
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.060632228851318 | KNN Loss: 3.053041934967041 | BCE Loss: 1.0075902938842773
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.103055953979492 | KNN Loss: 3.0907936096191406 

Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 4.079030513763428 | KNN Loss: 3.0839531421661377 | BCE Loss: 0.9950774908065796
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 4.097964286804199 | KNN Loss: 3.0737640857696533 | BCE Loss: 1.0241999626159668
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 4.105476379394531 | KNN Loss: 3.0887420177459717 | BCE Loss: 1.0167344808578491
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.075845718383789 | KNN Loss: 3.056375741958618 | BCE Loss: 1.0194697380065918
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.1096510887146 | KNN Loss: 3.073732852935791 | BCE Loss: 1.0359183549880981
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.0920257568359375 | KNN Loss: 3.0931198596954346 | BCE Loss: 0.9989060163497925
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.068831920623779 | KNN Loss: 3.0522940158843994 | BCE Loss: 1.0165379047393799
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.122918605804443 | KNN Loss: 3.0838129520416

Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 4.132905960083008 | KNN Loss: 3.0833632946014404 | BCE Loss: 1.0495426654815674
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 4.038260459899902 | KNN Loss: 3.0444841384887695 | BCE Loss: 0.993776261806488
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 4.109405517578125 | KNN Loss: 3.061530828475952 | BCE Loss: 1.0478746891021729
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 4.061433792114258 | KNN Loss: 3.0315868854522705 | BCE Loss: 1.0298469066619873
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.1186628341674805 | KNN Loss: 3.0880637168884277 | BCE Loss: 1.0305991172790527
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.073940277099609 | KNN Loss: 3.0552175045013428 | BCE Loss: 1.0187228918075562
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.057873249053955 | KNN Loss: 3.0278451442718506 | BCE Loss: 1.030028223991394
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.080341339111328 | KNN Loss: 3.0712866783142

Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 4.098759651184082 | KNN Loss: 3.0786404609680176 | BCE Loss: 1.020119309425354
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 4.0855326652526855 | KNN Loss: 3.062173843383789 | BCE Loss: 1.023358702659607
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 4.079740524291992 | KNN Loss: 3.054070472717285 | BCE Loss: 1.0256702899932861
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 4.116410255432129 | KNN Loss: 3.075468063354492 | BCE Loss: 1.0409424304962158
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.053072929382324 | KNN Loss: 3.015141010284424 | BCE Loss: 1.03793203830719
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.076537609100342 | KNN Loss: 3.0754430294036865 | BCE Loss: 1.0010946989059448
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.041234970092773 | KNN Loss: 3.054825782775879 | BCE Loss: 0.9864093065261841
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.013128280639648 | KNN Loss: 3.042893409729004 | B

Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 4.070567607879639 | KNN Loss: 3.064936876296997 | BCE Loss: 1.0056307315826416
Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 4.099040985107422 | KNN Loss: 3.0886318683624268 | BCE Loss: 1.010408878326416
Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 4.061402797698975 | KNN Loss: 3.048452377319336 | BCE Loss: 1.0129504203796387
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 4.109785079956055 | KNN Loss: 3.0636441707611084 | BCE Loss: 1.0461411476135254
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.073755264282227 | KNN Loss: 3.053832769393921 | BCE Loss: 1.0199224948883057
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.095701694488525 | KNN Loss: 3.082777261734009 | BCE Loss: 1.0129244327545166
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.049081802368164 | KNN Loss: 3.0294525623321533 | BCE Loss: 1.0196294784545898
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.121363162994385 | KNN Loss: 3.092241525650024

Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 4.118307113647461 | KNN Loss: 3.099837064743042 | BCE Loss: 1.0184698104858398
Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 4.082684516906738 | KNN Loss: 3.0623574256896973 | BCE Loss: 1.020326852798462
Epoch 203 / 500 | iteration 10 / 30 | Total Loss: 4.06756591796875 | KNN Loss: 3.0335588455200195 | BCE Loss: 1.0340068340301514
Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 4.056091785430908 | KNN Loss: 3.0439164638519287 | BCE Loss: 1.01217520236969
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 4.111517906188965 | KNN Loss: 3.0536048412323 | BCE Loss: 1.057912826538086
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.061524391174316 | KNN Loss: 3.045161724090576 | BCE Loss: 1.0163625478744507
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.067591667175293 | KNN Loss: 3.0468993186950684 | BCE Loss: 1.020692229270935
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.079185485839844 | KNN Loss: 3.0642359256744385 | BCE 

Epoch 213 / 500 | iteration 20 / 30 | Total Loss: 4.03912353515625 | KNN Loss: 3.0429651737213135 | BCE Loss: 0.9961584806442261
Epoch 213 / 500 | iteration 25 / 30 | Total Loss: 4.074508190155029 | KNN Loss: 3.0805087089538574 | BCE Loss: 0.9939993619918823
Epoch 214 / 500 | iteration 0 / 30 | Total Loss: 4.083720684051514 | KNN Loss: 3.05077862739563 | BCE Loss: 1.0329420566558838
Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 4.0949273109436035 | KNN Loss: 3.0743186473846436 | BCE Loss: 1.0206087827682495
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 4.110234260559082 | KNN Loss: 3.0588455200195312 | BCE Loss: 1.0513886213302612
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.031532287597656 | KNN Loss: 3.032217502593994 | BCE Loss: 0.9993147253990173
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.060774803161621 | KNN Loss: 3.0407092571258545 | BCE Loss: 1.0200657844543457
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.075634956359863 | KNN Loss: 3.0570747852325

Epoch 224 / 500 | iteration 10 / 30 | Total Loss: 4.067376136779785 | KNN Loss: 3.0481350421905518 | BCE Loss: 1.0192409753799438
Epoch 224 / 500 | iteration 15 / 30 | Total Loss: 4.119959831237793 | KNN Loss: 3.0695080757141113 | BCE Loss: 1.0504515171051025
Epoch 224 / 500 | iteration 20 / 30 | Total Loss: 4.108808517456055 | KNN Loss: 3.0777597427368164 | BCE Loss: 1.0310487747192383
Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 4.042533874511719 | KNN Loss: 3.0351309776306152 | BCE Loss: 1.0074028968811035
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 4.070557117462158 | KNN Loss: 3.0749635696411133 | BCE Loss: 0.9955934286117554
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.1121039390563965 | KNN Loss: 3.1139438152313232 | BCE Loss: 0.9981602430343628
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.0951385498046875 | KNN Loss: 3.09186053276062 | BCE Loss: 1.003278136253357
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.094564437866211 | KNN Loss: 3.04061126708

Epoch 235 / 500 | iteration 0 / 30 | Total Loss: 4.06601619720459 | KNN Loss: 3.0709445476531982 | BCE Loss: 0.9950718879699707
Epoch 235 / 500 | iteration 5 / 30 | Total Loss: 4.1497578620910645 | KNN Loss: 3.096395969390869 | BCE Loss: 1.0533617734909058
Epoch 235 / 500 | iteration 10 / 30 | Total Loss: 4.094093322753906 | KNN Loss: 3.077775001525879 | BCE Loss: 1.0163183212280273
Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 4.051914215087891 | KNN Loss: 3.075132131576538 | BCE Loss: 0.9767818450927734
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.080904006958008 | KNN Loss: 3.0622057914733887 | BCE Loss: 1.01869797706604
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.054556846618652 | KNN Loss: 3.0456247329711914 | BCE Loss: 1.00893235206604
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.122065544128418 | KNN Loss: 3.085259437561035 | BCE Loss: 1.0368061065673828
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.1009674072265625 | KNN Loss: 3.092278003692627 | B

Epoch 245 / 500 | iteration 20 / 30 | Total Loss: 4.092745304107666 | KNN Loss: 3.0629520416259766 | BCE Loss: 1.029793381690979
Epoch 245 / 500 | iteration 25 / 30 | Total Loss: 4.09627103805542 | KNN Loss: 3.0671637058258057 | BCE Loss: 1.0291073322296143
Epoch 246 / 500 | iteration 0 / 30 | Total Loss: 4.116108417510986 | KNN Loss: 3.0819621086120605 | BCE Loss: 1.0341463088989258
Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 4.1358723640441895 | KNN Loss: 3.088146924972534 | BCE Loss: 1.0477254390716553
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.155173301696777 | KNN Loss: 3.119962692260742 | BCE Loss: 1.0352104902267456
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.098528861999512 | KNN Loss: 3.052542209625244 | BCE Loss: 1.0459866523742676
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.1012115478515625 | KNN Loss: 3.073317289352417 | BCE Loss: 1.0278942584991455
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.093395233154297 | KNN Loss: 3.0665283203125 

Epoch 256 / 500 | iteration 10 / 30 | Total Loss: 4.121406078338623 | KNN Loss: 3.079143524169922 | BCE Loss: 1.0422626733779907
Epoch 256 / 500 | iteration 15 / 30 | Total Loss: 4.051721096038818 | KNN Loss: 3.049287796020508 | BCE Loss: 1.0024333000183105
Epoch 256 / 500 | iteration 20 / 30 | Total Loss: 4.115566253662109 | KNN Loss: 3.0803139209747314 | BCE Loss: 1.0352520942687988
Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 4.069867134094238 | KNN Loss: 3.05568265914917 | BCE Loss: 1.0141843557357788
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.062214374542236 | KNN Loss: 3.0572218894958496 | BCE Loss: 1.0049926042556763
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.060173511505127 | KNN Loss: 3.048293352127075 | BCE Loss: 1.0118802785873413
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.121858596801758 | KNN Loss: 3.095674514770508 | BCE Loss: 1.0261842012405396
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.082003116607666 | KNN Loss: 3.0496068000793457

Epoch 267 / 500 | iteration 0 / 30 | Total Loss: 4.087404251098633 | KNN Loss: 3.083503484725952 | BCE Loss: 1.0039010047912598
Epoch 267 / 500 | iteration 5 / 30 | Total Loss: 4.100742340087891 | KNN Loss: 3.069709062576294 | BCE Loss: 1.0310335159301758
Epoch 267 / 500 | iteration 10 / 30 | Total Loss: 4.065103530883789 | KNN Loss: 3.0453507900238037 | BCE Loss: 1.0197527408599854
Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 4.119292259216309 | KNN Loss: 3.0666182041168213 | BCE Loss: 1.0526740550994873
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.068384170532227 | KNN Loss: 3.054471492767334 | BCE Loss: 1.0139127969741821
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.084905624389648 | KNN Loss: 3.0917134284973145 | BCE Loss: 0.9931920766830444
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.088055610656738 | KNN Loss: 3.0530197620391846 | BCE Loss: 1.0350360870361328
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.071255207061768 | KNN Loss: 3.042615413665771

Epoch 277 / 500 | iteration 20 / 30 | Total Loss: 4.075835704803467 | KNN Loss: 3.0584115982055664 | BCE Loss: 1.0174241065979004
Epoch 277 / 500 | iteration 25 / 30 | Total Loss: 4.054621696472168 | KNN Loss: 3.030271291732788 | BCE Loss: 1.024350643157959
Epoch 278 / 500 | iteration 0 / 30 | Total Loss: 4.063046455383301 | KNN Loss: 3.0533480644226074 | BCE Loss: 1.009698510169983
Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 4.140239238739014 | KNN Loss: 3.120058059692383 | BCE Loss: 1.0201812982559204
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.077526092529297 | KNN Loss: 3.06807541847229 | BCE Loss: 1.009450912475586
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.066543102264404 | KNN Loss: 3.0354506969451904 | BCE Loss: 1.0310924053192139
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.089733123779297 | KNN Loss: 3.051955461502075 | BCE Loss: 1.0377776622772217
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.086796760559082 | KNN Loss: 3.0740296840667725 |

Epoch 288 / 500 | iteration 10 / 30 | Total Loss: 4.066539764404297 | KNN Loss: 3.0562262535095215 | BCE Loss: 1.0103133916854858
Epoch 288 / 500 | iteration 15 / 30 | Total Loss: 4.100150108337402 | KNN Loss: 3.0956203937530518 | BCE Loss: 1.0045299530029297
Epoch 288 / 500 | iteration 20 / 30 | Total Loss: 4.098986625671387 | KNN Loss: 3.046168565750122 | BCE Loss: 1.052817940711975
Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 4.035426139831543 | KNN Loss: 3.0289559364318848 | BCE Loss: 1.0064704418182373
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.111501693725586 | KNN Loss: 3.0806198120117188 | BCE Loss: 1.030881643295288
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.087698936462402 | KNN Loss: 3.068136215209961 | BCE Loss: 1.0195629596710205
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.036451816558838 | KNN Loss: 3.0418875217437744 | BCE Loss: 0.9945644736289978
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.057404041290283 | KNN Loss: 3.04875540733337

Epoch 299 / 500 | iteration 0 / 30 | Total Loss: 4.1044464111328125 | KNN Loss: 3.078803777694702 | BCE Loss: 1.0256426334381104
Epoch 299 / 500 | iteration 5 / 30 | Total Loss: 4.040267467498779 | KNN Loss: 3.037220001220703 | BCE Loss: 1.0030474662780762
Epoch 299 / 500 | iteration 10 / 30 | Total Loss: 4.0728230476379395 | KNN Loss: 3.063408613204956 | BCE Loss: 1.009414553642273
Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 4.089882850646973 | KNN Loss: 3.0767266750335693 | BCE Loss: 1.0131564140319824
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.052104473114014 | KNN Loss: 3.034122943878174 | BCE Loss: 1.0179816484451294
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.053789138793945 | KNN Loss: 3.049194097518921 | BCE Loss: 1.0045948028564453
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.108120918273926 | KNN Loss: 3.0961012840270996 | BCE Loss: 1.012019395828247
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.086413383483887 | KNN Loss: 3.0748889446258545 

Epoch 309 / 500 | iteration 20 / 30 | Total Loss: 4.0407395362854 | KNN Loss: 3.025773048400879 | BCE Loss: 1.0149664878845215
Epoch 309 / 500 | iteration 25 / 30 | Total Loss: 4.068409442901611 | KNN Loss: 3.055664539337158 | BCE Loss: 1.0127449035644531
Epoch 310 / 500 | iteration 0 / 30 | Total Loss: 4.109676361083984 | KNN Loss: 3.090897560119629 | BCE Loss: 1.0187790393829346
Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 4.047722339630127 | KNN Loss: 3.057224750518799 | BCE Loss: 0.9904977083206177
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.035503387451172 | KNN Loss: 3.031756639480591 | BCE Loss: 1.0037468671798706
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.065025806427002 | KNN Loss: 3.052802801132202 | BCE Loss: 1.0122230052947998
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.140834808349609 | KNN Loss: 3.069694757461548 | BCE Loss: 1.0711400508880615
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.116403102874756 | KNN Loss: 3.0655839443206787 | 

Epoch 320 / 500 | iteration 10 / 30 | Total Loss: 4.0772504806518555 | KNN Loss: 3.0411808490753174 | BCE Loss: 1.0360698699951172
Epoch 320 / 500 | iteration 15 / 30 | Total Loss: 4.073978424072266 | KNN Loss: 3.0452327728271484 | BCE Loss: 1.0287456512451172
Epoch 320 / 500 | iteration 20 / 30 | Total Loss: 4.067128658294678 | KNN Loss: 3.046569347381592 | BCE Loss: 1.0205594301223755
Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 4.069195747375488 | KNN Loss: 3.0501182079315186 | BCE Loss: 1.0190777778625488
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.038179874420166 | KNN Loss: 3.0461928844451904 | BCE Loss: 0.9919869899749756
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.111441612243652 | KNN Loss: 3.0863091945648193 | BCE Loss: 1.025132656097412
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.10088586807251 | KNN Loss: 3.1105129718780518 | BCE Loss: 0.9903728365898132
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.069765090942383 | KNN Loss: 3.062263250350

Epoch 330 / 500 | iteration 25 / 30 | Total Loss: 4.119354248046875 | KNN Loss: 3.1064727306365967 | BCE Loss: 1.0128815174102783
Epoch 331 / 500 | iteration 0 / 30 | Total Loss: 4.112932205200195 | KNN Loss: 3.095289468765259 | BCE Loss: 1.0176424980163574
Epoch 331 / 500 | iteration 5 / 30 | Total Loss: 4.0741190910339355 | KNN Loss: 3.0485615730285645 | BCE Loss: 1.025557518005371
Epoch 331 / 500 | iteration 10 / 30 | Total Loss: 4.160473346710205 | KNN Loss: 3.0666799545288086 | BCE Loss: 1.0937933921813965
Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.115240573883057 | KNN Loss: 3.058366537094116 | BCE Loss: 1.05687415599823
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.138321876525879 | KNN Loss: 3.10361385345459 | BCE Loss: 1.0347082614898682
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.082952976226807 | KNN Loss: 3.0658583641052246 | BCE Loss: 1.017094612121582
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.059178352355957 | KNN Loss: 3.0628654956817627 |

Epoch 341 / 500 | iteration 15 / 30 | Total Loss: 4.106731414794922 | KNN Loss: 3.081313371658325 | BCE Loss: 1.0254181623458862
Epoch 341 / 500 | iteration 20 / 30 | Total Loss: 4.084405422210693 | KNN Loss: 3.052861213684082 | BCE Loss: 1.0315440893173218
Epoch 341 / 500 | iteration 25 / 30 | Total Loss: 4.122188091278076 | KNN Loss: 3.0916287899017334 | BCE Loss: 1.0305593013763428
Epoch 342 / 500 | iteration 0 / 30 | Total Loss: 4.083104610443115 | KNN Loss: 3.0706849098205566 | BCE Loss: 1.012419581413269
Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.105259895324707 | KNN Loss: 3.079033613204956 | BCE Loss: 1.02622652053833
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.050450325012207 | KNN Loss: 3.0529236793518066 | BCE Loss: 0.9975264072418213
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.088236331939697 | KNN Loss: 3.0754010677337646 | BCE Loss: 1.0128352642059326
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.0567851066589355 | KNN Loss: 3.045332193374634

Epoch 352 / 500 | iteration 5 / 30 | Total Loss: 4.077304840087891 | KNN Loss: 3.067039966583252 | BCE Loss: 1.0102648735046387
Epoch 352 / 500 | iteration 10 / 30 | Total Loss: 4.093416690826416 | KNN Loss: 3.0683653354644775 | BCE Loss: 1.0250513553619385
Epoch 352 / 500 | iteration 15 / 30 | Total Loss: 4.045453071594238 | KNN Loss: 3.0505919456481934 | BCE Loss: 0.9948611259460449
Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.085879802703857 | KNN Loss: 3.0567078590393066 | BCE Loss: 1.0291719436645508
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.075547695159912 | KNN Loss: 3.0738303661346436 | BCE Loss: 1.0017173290252686
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.078812122344971 | KNN Loss: 3.053678035736084 | BCE Loss: 1.0251339673995972
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.125414848327637 | KNN Loss: 3.0956194400787354 | BCE Loss: 1.0297956466674805
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.072569847106934 | KNN Loss: 3.0511705875396

Epoch 362 / 500 | iteration 25 / 30 | Total Loss: 4.123946189880371 | KNN Loss: 3.072725534439087 | BCE Loss: 1.0512207746505737
Epoch 363 / 500 | iteration 0 / 30 | Total Loss: 4.049709320068359 | KNN Loss: 3.045161485671997 | BCE Loss: 1.0045475959777832
Epoch 363 / 500 | iteration 5 / 30 | Total Loss: 4.064548492431641 | KNN Loss: 3.035670280456543 | BCE Loss: 1.028878092765808
Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.082615852355957 | KNN Loss: 3.053337335586548 | BCE Loss: 1.0292786359786987
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.0843186378479 | KNN Loss: 3.0676252841949463 | BCE Loss: 1.0166934728622437
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.100780963897705 | KNN Loss: 3.0641186237335205 | BCE Loss: 1.0366623401641846
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.079537391662598 | KNN Loss: 3.0617446899414062 | BCE Loss: 1.0177929401397705
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.135427951812744 | KNN Loss: 3.105100393295288 | 

Epoch 373 / 500 | iteration 15 / 30 | Total Loss: 4.077079772949219 | KNN Loss: 3.048875331878662 | BCE Loss: 1.0282044410705566
Epoch 373 / 500 | iteration 20 / 30 | Total Loss: 4.113705635070801 | KNN Loss: 3.0762808322906494 | BCE Loss: 1.0374246835708618
Epoch 373 / 500 | iteration 25 / 30 | Total Loss: 4.07818078994751 | KNN Loss: 3.077791213989258 | BCE Loss: 1.000389575958252
Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.085204601287842 | KNN Loss: 3.080888271331787 | BCE Loss: 1.0043163299560547
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.062018394470215 | KNN Loss: 3.0552868843078613 | BCE Loss: 1.0067312717437744
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.046649932861328 | KNN Loss: 3.0492103099823 | BCE Loss: 0.9974395036697388
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.079708576202393 | KNN Loss: 3.0769402980804443 | BCE Loss: 1.0027683973312378
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.092165470123291 | KNN Loss: 3.054375410079956 | 

Epoch 384 / 500 | iteration 5 / 30 | Total Loss: 4.063702583312988 | KNN Loss: 3.0532169342041016 | BCE Loss: 1.0104856491088867
Epoch 384 / 500 | iteration 10 / 30 | Total Loss: 4.079419136047363 | KNN Loss: 3.0547778606414795 | BCE Loss: 1.0246412754058838
Epoch 384 / 500 | iteration 15 / 30 | Total Loss: 4.058086395263672 | KNN Loss: 3.0505003929138184 | BCE Loss: 1.0075857639312744
Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.090585708618164 | KNN Loss: 3.0575969219207764 | BCE Loss: 1.0329887866973877
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.1158857345581055 | KNN Loss: 3.0565316677093506 | BCE Loss: 1.0593538284301758
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.118457794189453 | KNN Loss: 3.062504768371582 | BCE Loss: 1.0559532642364502
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.095338821411133 | KNN Loss: 3.0789670944213867 | BCE Loss: 1.016371488571167
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.069477081298828 | KNN Loss: 3.065480709075

Epoch 394 / 500 | iteration 20 / 30 | Total Loss: 4.120943546295166 | KNN Loss: 3.109389305114746 | BCE Loss: 1.0115541219711304
Epoch 394 / 500 | iteration 25 / 30 | Total Loss: 4.090363502502441 | KNN Loss: 3.077444076538086 | BCE Loss: 1.0129191875457764
Epoch 395 / 500 | iteration 0 / 30 | Total Loss: 4.07285213470459 | KNN Loss: 3.046447515487671 | BCE Loss: 1.0264043807983398
Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.096619129180908 | KNN Loss: 3.0946435928344727 | BCE Loss: 1.001975655555725
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.078974723815918 | KNN Loss: 3.0472629070281982 | BCE Loss: 1.0317116975784302
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.114849090576172 | KNN Loss: 3.079066276550293 | BCE Loss: 1.035782814025879
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.036906719207764 | KNN Loss: 3.030313014984131 | BCE Loss: 1.0065938234329224
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.116434574127197 | KNN Loss: 3.059527635574341 | 

Epoch 405 / 500 | iteration 10 / 30 | Total Loss: 4.068042755126953 | KNN Loss: 3.046755790710449 | BCE Loss: 1.021286964416504
Epoch 405 / 500 | iteration 15 / 30 | Total Loss: 4.120184421539307 | KNN Loss: 3.0868475437164307 | BCE Loss: 1.0333367586135864
Epoch 405 / 500 | iteration 20 / 30 | Total Loss: 4.073808670043945 | KNN Loss: 3.06144118309021 | BCE Loss: 1.0123677253723145
Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.043251991271973 | KNN Loss: 3.048901319503784 | BCE Loss: 0.994350790977478
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.099719524383545 | KNN Loss: 3.0810751914978027 | BCE Loss: 1.0186443328857422
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.093502998352051 | KNN Loss: 3.0630900859832764 | BCE Loss: 1.0304131507873535
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.115434169769287 | KNN Loss: 3.067732095718384 | BCE Loss: 1.0477020740509033
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.104669094085693 | KNN Loss: 3.1048500537872314 

Epoch 416 / 500 | iteration 0 / 30 | Total Loss: 4.081423759460449 | KNN Loss: 3.0442419052124023 | BCE Loss: 1.0371817350387573
Epoch 416 / 500 | iteration 5 / 30 | Total Loss: 4.0857133865356445 | KNN Loss: 3.070901870727539 | BCE Loss: 1.0148117542266846
Epoch 416 / 500 | iteration 10 / 30 | Total Loss: 4.059178829193115 | KNN Loss: 3.054871082305908 | BCE Loss: 1.004307746887207
Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.07711124420166 | KNN Loss: 3.0650458335876465 | BCE Loss: 1.0120652914047241
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.08748722076416 | KNN Loss: 3.0678727626800537 | BCE Loss: 1.0196146965026855
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.0726823806762695 | KNN Loss: 3.0769901275634766 | BCE Loss: 0.995692253112793
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.075466156005859 | KNN Loss: 3.068356990814209 | BCE Loss: 1.00710928440094
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.062924385070801 | KNN Loss: 3.0262625217437744 | 

Epoch 426 / 500 | iteration 20 / 30 | Total Loss: 4.068086624145508 | KNN Loss: 3.0495188236236572 | BCE Loss: 1.018567681312561
Epoch 426 / 500 | iteration 25 / 30 | Total Loss: 4.062498092651367 | KNN Loss: 3.049300193786621 | BCE Loss: 1.0131981372833252
Epoch 427 / 500 | iteration 0 / 30 | Total Loss: 4.11154842376709 | KNN Loss: 3.0783779621124268 | BCE Loss: 1.0331707000732422
Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.056885719299316 | KNN Loss: 3.038106918334961 | BCE Loss: 1.0187785625457764
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.104218006134033 | KNN Loss: 3.092028856277466 | BCE Loss: 1.012189269065857
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.0536723136901855 | KNN Loss: 3.0459482669830322 | BCE Loss: 1.0077241659164429
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.096477031707764 | KNN Loss: 3.077831268310547 | BCE Loss: 1.0186456441879272
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.044920921325684 | KNN Loss: 3.0555672645568848

Epoch 437 / 500 | iteration 10 / 30 | Total Loss: 4.0638628005981445 | KNN Loss: 3.057809591293335 | BCE Loss: 1.0060529708862305
Epoch 437 / 500 | iteration 15 / 30 | Total Loss: 4.075961589813232 | KNN Loss: 3.0576694011688232 | BCE Loss: 1.0182920694351196
Epoch 437 / 500 | iteration 20 / 30 | Total Loss: 4.070868015289307 | KNN Loss: 3.0687456130981445 | BCE Loss: 1.002122402191162
Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.07330322265625 | KNN Loss: 3.0627620220184326 | BCE Loss: 1.0105412006378174
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.084773540496826 | KNN Loss: 3.057629346847534 | BCE Loss: 1.0271443128585815
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.05793571472168 | KNN Loss: 3.0460476875305176 | BCE Loss: 1.011888027191162
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.082894802093506 | KNN Loss: 3.0746188163757324 | BCE Loss: 1.0082759857177734
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.098641395568848 | KNN Loss: 3.077237129211426

Epoch 448 / 500 | iteration 0 / 30 | Total Loss: 4.125553131103516 | KNN Loss: 3.0818722248077393 | BCE Loss: 1.0436809062957764
Epoch 448 / 500 | iteration 5 / 30 | Total Loss: 4.056644439697266 | KNN Loss: 3.0316333770751953 | BCE Loss: 1.0250113010406494
Epoch 448 / 500 | iteration 10 / 30 | Total Loss: 4.055647373199463 | KNN Loss: 3.04300594329834 | BCE Loss: 1.012641429901123
Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.073719024658203 | KNN Loss: 3.0572991371154785 | BCE Loss: 1.0164196491241455
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.058377265930176 | KNN Loss: 3.069915533065796 | BCE Loss: 0.9884615540504456
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.1179118156433105 | KNN Loss: 3.0713958740234375 | BCE Loss: 1.0465160608291626
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.121257305145264 | KNN Loss: 3.0971975326538086 | BCE Loss: 1.024059772491455
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.09030818939209 | KNN Loss: 3.055854558944702 |

Epoch 458 / 500 | iteration 20 / 30 | Total Loss: 4.085227966308594 | KNN Loss: 3.048189640045166 | BCE Loss: 1.0370380878448486
Epoch 458 / 500 | iteration 25 / 30 | Total Loss: 4.100359916687012 | KNN Loss: 3.0805132389068604 | BCE Loss: 1.019846796989441
Epoch 459 / 500 | iteration 0 / 30 | Total Loss: 4.06060791015625 | KNN Loss: 3.0455482006073 | BCE Loss: 1.0150599479675293
Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.077927589416504 | KNN Loss: 3.0534939765930176 | BCE Loss: 1.0244338512420654
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.124232769012451 | KNN Loss: 3.058230400085449 | BCE Loss: 1.066002368927002
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.094308376312256 | KNN Loss: 3.0752480030059814 | BCE Loss: 1.0190603733062744
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.099851608276367 | KNN Loss: 3.10188627243042 | BCE Loss: 0.9979653358459473
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.0859375 | KNN Loss: 3.0637705326080322 | BCE Loss:

Epoch 469 / 500 | iteration 10 / 30 | Total Loss: 4.0941572189331055 | KNN Loss: 3.043483257293701 | BCE Loss: 1.0506737232208252
Epoch 469 / 500 | iteration 15 / 30 | Total Loss: 4.117823123931885 | KNN Loss: 3.0962131023406982 | BCE Loss: 1.021609902381897
Epoch 469 / 500 | iteration 20 / 30 | Total Loss: 4.04673957824707 | KNN Loss: 3.030496120452881 | BCE Loss: 1.016243577003479
Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.1146559715271 | KNN Loss: 3.081186532974243 | BCE Loss: 1.033469557762146
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.117110252380371 | KNN Loss: 3.0767993927001953 | BCE Loss: 1.0403110980987549
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.091004371643066 | KNN Loss: 3.0672128200531006 | BCE Loss: 1.023791790008545
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.059557914733887 | KNN Loss: 3.0564918518066406 | BCE Loss: 1.0030661821365356
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.03196907043457 | KNN Loss: 3.050567388534546 | BC

Epoch 480 / 500 | iteration 0 / 30 | Total Loss: 4.090553283691406 | KNN Loss: 3.075035810470581 | BCE Loss: 1.0155173540115356
Epoch 480 / 500 | iteration 5 / 30 | Total Loss: 4.072437286376953 | KNN Loss: 3.0551910400390625 | BCE Loss: 1.0172460079193115
Epoch 480 / 500 | iteration 10 / 30 | Total Loss: 4.121167182922363 | KNN Loss: 3.074995994567871 | BCE Loss: 1.0461711883544922
Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.090211868286133 | KNN Loss: 3.0612542629241943 | BCE Loss: 1.0289576053619385
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.082388877868652 | KNN Loss: 3.0901005268096924 | BCE Loss: 0.9922885894775391
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.121117115020752 | KNN Loss: 3.07875657081604 | BCE Loss: 1.0423604249954224
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.074170112609863 | KNN Loss: 3.0565381050109863 | BCE Loss: 1.017632246017456
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.0654191970825195 | KNN Loss: 3.064542531967163 

Epoch 490 / 500 | iteration 20 / 30 | Total Loss: 4.043079376220703 | KNN Loss: 3.027883291244507 | BCE Loss: 1.0151962041854858
Epoch 490 / 500 | iteration 25 / 30 | Total Loss: 4.08083438873291 | KNN Loss: 3.0618417263031006 | BCE Loss: 1.01899254322052
Epoch 491 / 500 | iteration 0 / 30 | Total Loss: 4.067022323608398 | KNN Loss: 3.072608470916748 | BCE Loss: 0.9944136738777161
Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.042782783508301 | KNN Loss: 3.047851324081421 | BCE Loss: 0.9949314594268799
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.111721038818359 | KNN Loss: 3.0793263912200928 | BCE Loss: 1.0323948860168457
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.1432647705078125 | KNN Loss: 3.1072874069213867 | BCE Loss: 1.0359772443771362
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.0979437828063965 | KNN Loss: 3.1002767086029053 | BCE Loss: 0.9976669549942017
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.115687370300293 | KNN Loss: 3.07952213287353

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.1004e+00,  2.8424e+00,  2.8250e+00,  3.0073e+00,  2.9105e+00,
          6.3167e-01,  2.4144e+00,  2.4124e+00,  2.2677e+00,  1.7566e+00,
          2.4531e+00,  2.0375e+00,  9.6963e-01,  1.9713e+00,  1.1873e+00,
          1.7947e+00,  2.8798e+00,  2.7177e+00,  2.5930e+00,  2.2776e+00,
          1.6449e+00,  3.1229e+00,  2.5383e+00,  2.7050e+00,  2.7791e+00,
          1.9863e+00,  2.1137e+00,  1.4696e+00,  1.4666e+00,  4.0106e-01,
         -2.4083e-01,  1.1589e+00,  3.1289e-01,  1.0797e+00,  1.6853e+00,
          1.5168e+00,  7.8847e-01,  3.2470e+00,  6.9854e-01,  1.3243e+00,
          1.1100e+00, -5.2067e-01, -3.9265e-02,  2.4637e+00,  2.3308e+00,
          6.6560e-01, -1.2188e-02,  2.0878e-01,  1.6554e+00,  2.5232e+00,
          1.8229e+00,  1.0992e-01,  1.2048e+00,  6.4537e-01, -4.4819e-01,
          1.3678e+00,  1.6303e+00,  1.3435e+00,  1.4171e+00,  2.0011e+00,
          8.6431e-01,  8.2060e-01,  2.8669e-01,  1.8782e+00,  1.3908e+00,
          1.7580e+00, -1.9129e+00,  3.

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].to('cpu') for d in dataset]

In [11]:
model = model.eval().to('cpu')
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 102.32it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 016 | Total loss: 9.646 | Reg loss: 0.007 | Tree loss: 9.646 | Accuracy: 0.000000 | 0.108 sec/iter
Epoch: 00 | Batch: 001 / 016 | Total loss: 9.650 | Reg loss: 0.007 | Tree loss: 9.650 | Accuracy: 0.000000 | 0.09 sec/iter
Epoch: 00 | Batch: 002 / 016 | Total loss: 9.630 | Reg loss: 0.007 | Tree loss: 9.630 | Accuracy: 0.000000 | 0.081 sec/iter
Epoch: 00 | Batch: 003 / 016 | Total loss: 9.629 | Reg loss: 0.006 | Tree loss: 9.629 | Accuracy: 0.000000 | 0.078 sec/iter
Epoch: 00 | Batch: 004 / 016 | Total loss: 9.625 | Reg loss: 0.006 | Tree loss: 9.625 | Accuracy: 0.000000 | 0.074 sec/iter
Epoch: 00 | Batch: 005 / 016 | Total loss: 9.605 | Reg loss: 0.006 | Tree loss: 9.605 | Accuracy: 0.000000 | 0.071 sec/iter
Epoch: 00 | Batch: 006 / 016 | Total loss: 9.603 | Reg loss: 0.006 | Tree loss: 9.603 | Accuracy: 0.000000 | 0.069 sec/iter
Epoch: 00 | Batch: 007 / 016 | Total loss: 9

Epoch: 03 | Batch: 014 / 016 | Total loss: 9.360 | Reg loss: 0.007 | Tree loss: 9.360 | Accuracy: 0.115234 | 0.058 sec/iter
Epoch: 03 | Batch: 015 / 016 | Total loss: 9.366 | Reg loss: 0.007 | Tree loss: 9.366 | Accuracy: 0.106838 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 04 | Batch: 000 / 016 | Total loss: 9.391 | Reg loss: 0.005 | Tree loss: 9.391 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 04 | Batch: 001 / 016 | Total loss: 9.373 | Reg loss: 0.005 | Tree loss: 9.373 | Accuracy: 0.125000 | 0.058 sec/iter
Epoch: 04 | Batch: 002 / 016 | Total loss: 9.383 | Reg loss: 0.005 | Tree loss: 9.383 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 04 | Batch: 003 / 016 | Total loss: 9.364 | Reg loss: 0.005 | Tree loss: 9.364 | Accuracy: 0.132812 | 0.058 sec/iter
Epoch: 04 | Batch: 004 / 016 | Total loss: 9.367 | Reg loss: 0.006 | Tree los

Epoch: 07 | Batch: 013 / 016 | Total loss: 9.084 | Reg loss: 0.011 | Tree loss: 9.084 | Accuracy: 0.087891 | 0.058 sec/iter
Epoch: 07 | Batch: 014 / 016 | Total loss: 9.072 | Reg loss: 0.011 | Tree loss: 9.072 | Accuracy: 0.121094 | 0.058 sec/iter
Epoch: 07 | Batch: 015 / 016 | Total loss: 9.063 | Reg loss: 0.011 | Tree loss: 9.063 | Accuracy: 0.123932 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 08 | Batch: 000 / 016 | Total loss: 9.107 | Reg loss: 0.009 | Tree loss: 9.107 | Accuracy: 0.121094 | 0.058 sec/iter
Epoch: 08 | Batch: 001 / 016 | Total loss: 9.108 | Reg loss: 0.009 | Tree loss: 9.108 | Accuracy: 0.119141 | 0.059 sec/iter
Epoch: 08 | Batch: 002 / 016 | Total loss: 9.108 | Reg loss: 0.010 | Tree loss: 9.108 | Accuracy: 0.111328 | 0.059 sec/iter
Epoch: 08 | Batch: 003 / 016 | Total loss: 9.089 | Reg loss: 0.010 | Tree los

Epoch: 11 | Batch: 010 / 016 | Total loss: 8.765 | Reg loss: 0.014 | Tree loss: 8.765 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 11 | Batch: 011 / 016 | Total loss: 8.719 | Reg loss: 0.015 | Tree loss: 8.719 | Accuracy: 0.132812 | 0.058 sec/iter
Epoch: 11 | Batch: 012 / 016 | Total loss: 8.702 | Reg loss: 0.015 | Tree loss: 8.702 | Accuracy: 0.113281 | 0.058 sec/iter
Epoch: 11 | Batch: 013 / 016 | Total loss: 8.717 | Reg loss: 0.015 | Tree loss: 8.717 | Accuracy: 0.113281 | 0.058 sec/iter
Epoch: 11 | Batch: 014 / 016 | Total loss: 8.704 | Reg loss: 0.015 | Tree loss: 8.704 | Accuracy: 0.117188 | 0.058 sec/iter
Epoch: 11 | Batch: 015 / 016 | Total loss: 8.673 | Reg loss: 0.016 | Tree loss: 8.673 | Accuracy: 0.121795 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 12 | Batch: 000 / 016 | Total loss: 8.786 | Reg loss: 0.014 | Tree los

Epoch: 15 | Batch: 010 / 016 | Total loss: 8.314 | Reg loss: 0.018 | Tree loss: 8.314 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 15 | Batch: 011 / 016 | Total loss: 8.309 | Reg loss: 0.018 | Tree loss: 8.309 | Accuracy: 0.125000 | 0.058 sec/iter
Epoch: 15 | Batch: 012 / 016 | Total loss: 8.258 | Reg loss: 0.018 | Tree loss: 8.258 | Accuracy: 0.109375 | 0.058 sec/iter
Epoch: 15 | Batch: 013 / 016 | Total loss: 8.287 | Reg loss: 0.018 | Tree loss: 8.287 | Accuracy: 0.097656 | 0.058 sec/iter
Epoch: 15 | Batch: 014 / 016 | Total loss: 8.257 | Reg loss: 0.018 | Tree loss: 8.257 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 15 | Batch: 015 / 016 | Total loss: 8.222 | Reg loss: 0.019 | Tree loss: 8.222 | Accuracy: 0.117521 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 16 | Batch: 000 / 016 | Total loss: 8.400 | Reg loss: 0.017 | Tree los

Epoch: 19 | Batch: 010 / 016 | Total loss: 7.854 | Reg loss: 0.020 | Tree loss: 7.854 | Accuracy: 0.107422 | 0.058 sec/iter
Epoch: 19 | Batch: 011 / 016 | Total loss: 7.856 | Reg loss: 0.021 | Tree loss: 7.856 | Accuracy: 0.093750 | 0.058 sec/iter
Epoch: 19 | Batch: 012 / 016 | Total loss: 7.787 | Reg loss: 0.021 | Tree loss: 7.787 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 19 | Batch: 013 / 016 | Total loss: 7.755 | Reg loss: 0.021 | Tree loss: 7.755 | Accuracy: 0.109375 | 0.058 sec/iter
Epoch: 19 | Batch: 014 / 016 | Total loss: 7.799 | Reg loss: 0.021 | Tree loss: 7.799 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 19 | Batch: 015 / 016 | Total loss: 7.779 | Reg loss: 0.021 | Tree loss: 7.779 | Accuracy: 0.094017 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 20 | Batch: 000 / 016 | Total loss: 7.943 | Reg loss: 0.020 | Tree los

Epoch: 23 | Batch: 010 / 016 | Total loss: 7.406 | Reg loss: 0.023 | Tree loss: 7.406 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 23 | Batch: 011 / 016 | Total loss: 7.387 | Reg loss: 0.023 | Tree loss: 7.387 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 23 | Batch: 012 / 016 | Total loss: 7.334 | Reg loss: 0.023 | Tree loss: 7.334 | Accuracy: 0.072266 | 0.058 sec/iter
Epoch: 23 | Batch: 013 / 016 | Total loss: 7.345 | Reg loss: 0.023 | Tree loss: 7.345 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 23 | Batch: 014 / 016 | Total loss: 7.295 | Reg loss: 0.023 | Tree loss: 7.295 | Accuracy: 0.078125 | 0.058 sec/iter
Epoch: 23 | Batch: 015 / 016 | Total loss: 7.286 | Reg loss: 0.023 | Tree loss: 7.286 | Accuracy: 0.083333 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 24 | Batch: 000 / 016 | Total loss: 7.455 | Reg loss: 0.022 | Tree los

Epoch: 27 | Batch: 007 / 016 | Total loss: 6.987 | Reg loss: 0.024 | Tree loss: 6.987 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 27 | Batch: 008 / 016 | Total loss: 6.974 | Reg loss: 0.024 | Tree loss: 6.974 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 27 | Batch: 009 / 016 | Total loss: 7.004 | Reg loss: 0.024 | Tree loss: 7.004 | Accuracy: 0.087891 | 0.058 sec/iter
Epoch: 27 | Batch: 010 / 016 | Total loss: 6.950 | Reg loss: 0.024 | Tree loss: 6.950 | Accuracy: 0.085938 | 0.058 sec/iter
Epoch: 27 | Batch: 011 / 016 | Total loss: 6.928 | Reg loss: 0.025 | Tree loss: 6.928 | Accuracy: 0.087891 | 0.058 sec/iter
Epoch: 27 | Batch: 012 / 016 | Total loss: 6.951 | Reg loss: 0.025 | Tree loss: 6.951 | Accuracy: 0.080078 | 0.058 sec/iter
Epoch: 27 | Batch: 013 / 016 | Total loss: 6.890 | Reg loss: 0.025 | Tree loss: 6.890 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 27 | Batch: 014 / 016 | Total loss: 6.849 | Reg loss: 0.025 | Tree loss: 6.849 | Accuracy: 0.080078 | 0.058 sec/iter
Epoch: 2

Epoch: 31 | Batch: 007 / 016 | Total loss: 6.618 | Reg loss: 0.026 | Tree loss: 6.618 | Accuracy: 0.072266 | 0.058 sec/iter
Epoch: 31 | Batch: 008 / 016 | Total loss: 6.581 | Reg loss: 0.026 | Tree loss: 6.581 | Accuracy: 0.087891 | 0.058 sec/iter
Epoch: 31 | Batch: 009 / 016 | Total loss: 6.549 | Reg loss: 0.026 | Tree loss: 6.549 | Accuracy: 0.089844 | 0.058 sec/iter
Epoch: 31 | Batch: 010 / 016 | Total loss: 6.560 | Reg loss: 0.026 | Tree loss: 6.560 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 31 | Batch: 011 / 016 | Total loss: 6.529 | Reg loss: 0.026 | Tree loss: 6.529 | Accuracy: 0.072266 | 0.058 sec/iter
Epoch: 31 | Batch: 012 / 016 | Total loss: 6.520 | Reg loss: 0.026 | Tree loss: 6.520 | Accuracy: 0.068359 | 0.058 sec/iter
Epoch: 31 | Batch: 013 / 016 | Total loss: 6.469 | Reg loss: 0.026 | Tree loss: 6.469 | Accuracy: 0.085938 | 0.058 sec/iter
Epoch: 31 | Batch: 014 / 016 | Total loss: 6.531 | Reg loss: 0.026 | Tree loss: 6.531 | Accuracy: 0.068359 | 0.058 sec/iter
Epoch: 3

Epoch: 35 | Batch: 007 / 016 | Total loss: 6.243 | Reg loss: 0.027 | Tree loss: 6.243 | Accuracy: 0.080078 | 0.058 sec/iter
Epoch: 35 | Batch: 008 / 016 | Total loss: 6.220 | Reg loss: 0.027 | Tree loss: 6.220 | Accuracy: 0.087891 | 0.058 sec/iter
Epoch: 35 | Batch: 009 / 016 | Total loss: 6.183 | Reg loss: 0.027 | Tree loss: 6.183 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 35 | Batch: 010 / 016 | Total loss: 6.176 | Reg loss: 0.027 | Tree loss: 6.176 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 35 | Batch: 011 / 016 | Total loss: 6.216 | Reg loss: 0.027 | Tree loss: 6.216 | Accuracy: 0.058594 | 0.058 sec/iter
Epoch: 35 | Batch: 012 / 016 | Total loss: 6.185 | Reg loss: 0.027 | Tree loss: 6.185 | Accuracy: 0.080078 | 0.058 sec/iter
Epoch: 35 | Batch: 013 / 016 | Total loss: 6.157 | Reg loss: 0.027 | Tree loss: 6.157 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 35 | Batch: 014 / 016 | Total loss: 6.136 | Reg loss: 0.027 | Tree loss: 6.136 | Accuracy: 0.076172 | 0.058 sec/iter
Epoch: 3

Epoch: 39 | Batch: 006 / 016 | Total loss: 5.928 | Reg loss: 0.028 | Tree loss: 5.928 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 39 | Batch: 007 / 016 | Total loss: 5.908 | Reg loss: 0.028 | Tree loss: 5.908 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 39 | Batch: 008 / 016 | Total loss: 5.917 | Reg loss: 0.028 | Tree loss: 5.917 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 39 | Batch: 009 / 016 | Total loss: 5.871 | Reg loss: 0.028 | Tree loss: 5.871 | Accuracy: 0.107422 | 0.058 sec/iter
Epoch: 39 | Batch: 010 / 016 | Total loss: 5.863 | Reg loss: 0.028 | Tree loss: 5.863 | Accuracy: 0.072266 | 0.058 sec/iter
Epoch: 39 | Batch: 011 / 016 | Total loss: 5.888 | Reg loss: 0.028 | Tree loss: 5.888 | Accuracy: 0.085938 | 0.058 sec/iter
Epoch: 39 | Batch: 012 / 016 | Total loss: 5.814 | Reg loss: 0.028 | Tree loss: 5.814 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 39 | Batch: 013 / 016 | Total loss: 5.834 | Reg loss: 0.028 | Tree loss: 5.834 | Accuracy: 0.078125 | 0.058 sec/iter
Epoch: 3

Epoch: 43 | Batch: 006 / 016 | Total loss: 5.629 | Reg loss: 0.029 | Tree loss: 5.629 | Accuracy: 0.111328 | 0.058 sec/iter
Epoch: 43 | Batch: 007 / 016 | Total loss: 5.615 | Reg loss: 0.029 | Tree loss: 5.615 | Accuracy: 0.093750 | 0.058 sec/iter
Epoch: 43 | Batch: 008 / 016 | Total loss: 5.567 | Reg loss: 0.029 | Tree loss: 5.567 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 43 | Batch: 009 / 016 | Total loss: 5.596 | Reg loss: 0.029 | Tree loss: 5.596 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 43 | Batch: 010 / 016 | Total loss: 5.616 | Reg loss: 0.029 | Tree loss: 5.616 | Accuracy: 0.080078 | 0.058 sec/iter
Epoch: 43 | Batch: 011 / 016 | Total loss: 5.580 | Reg loss: 0.029 | Tree loss: 5.580 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 43 | Batch: 012 / 016 | Total loss: 5.579 | Reg loss: 0.029 | Tree loss: 5.579 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 43 | Batch: 013 / 016 | Total loss: 5.544 | Reg loss: 0.029 | Tree loss: 5.544 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 4

Epoch: 47 | Batch: 004 / 016 | Total loss: 5.435 | Reg loss: 0.030 | Tree loss: 5.435 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 47 | Batch: 005 / 016 | Total loss: 5.421 | Reg loss: 0.030 | Tree loss: 5.421 | Accuracy: 0.085938 | 0.058 sec/iter
Epoch: 47 | Batch: 006 / 016 | Total loss: 5.400 | Reg loss: 0.030 | Tree loss: 5.400 | Accuracy: 0.093750 | 0.058 sec/iter
Epoch: 47 | Batch: 007 / 016 | Total loss: 5.436 | Reg loss: 0.030 | Tree loss: 5.436 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 47 | Batch: 008 / 016 | Total loss: 5.331 | Reg loss: 0.030 | Tree loss: 5.331 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 47 | Batch: 009 / 016 | Total loss: 5.310 | Reg loss: 0.030 | Tree loss: 5.310 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 47 | Batch: 010 / 016 | Total loss: 5.271 | Reg loss: 0.030 | Tree loss: 5.271 | Accuracy: 0.101562 | 0.058 sec/iter
Epoch: 47 | Batch: 011 / 016 | Total loss: 5.342 | Reg loss: 0.030 | Tree loss: 5.342 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 4

Epoch: 51 | Batch: 001 / 016 | Total loss: 5.239 | Reg loss: 0.030 | Tree loss: 5.239 | Accuracy: 0.115234 | 0.057 sec/iter
Epoch: 51 | Batch: 002 / 016 | Total loss: 5.206 | Reg loss: 0.030 | Tree loss: 5.206 | Accuracy: 0.115234 | 0.057 sec/iter
Epoch: 51 | Batch: 003 / 016 | Total loss: 5.195 | Reg loss: 0.030 | Tree loss: 5.195 | Accuracy: 0.119141 | 0.057 sec/iter
Epoch: 51 | Batch: 004 / 016 | Total loss: 5.174 | Reg loss: 0.030 | Tree loss: 5.174 | Accuracy: 0.128906 | 0.057 sec/iter
Epoch: 51 | Batch: 005 / 016 | Total loss: 5.189 | Reg loss: 0.030 | Tree loss: 5.189 | Accuracy: 0.085938 | 0.057 sec/iter
Epoch: 51 | Batch: 006 / 016 | Total loss: 5.228 | Reg loss: 0.030 | Tree loss: 5.228 | Accuracy: 0.095703 | 0.058 sec/iter
Epoch: 51 | Batch: 007 / 016 | Total loss: 5.134 | Reg loss: 0.030 | Tree loss: 5.134 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 51 | Batch: 008 / 016 | Total loss: 5.130 | Reg loss: 0.030 | Tree loss: 5.130 | Accuracy: 0.128906 | 0.058 sec/iter
Epoch: 5

Epoch: 55 | Batch: 002 / 016 | Total loss: 5.071 | Reg loss: 0.031 | Tree loss: 5.071 | Accuracy: 0.101562 | 0.058 sec/iter
Epoch: 55 | Batch: 003 / 016 | Total loss: 5.047 | Reg loss: 0.031 | Tree loss: 5.047 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 55 | Batch: 004 / 016 | Total loss: 5.003 | Reg loss: 0.031 | Tree loss: 5.003 | Accuracy: 0.123047 | 0.058 sec/iter
Epoch: 55 | Batch: 005 / 016 | Total loss: 4.946 | Reg loss: 0.031 | Tree loss: 4.946 | Accuracy: 0.111328 | 0.058 sec/iter
Epoch: 55 | Batch: 006 / 016 | Total loss: 5.013 | Reg loss: 0.031 | Tree loss: 5.013 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 55 | Batch: 007 / 016 | Total loss: 4.903 | Reg loss: 0.031 | Tree loss: 4.903 | Accuracy: 0.130859 | 0.058 sec/iter
Epoch: 55 | Batch: 008 / 016 | Total loss: 4.968 | Reg loss: 0.031 | Tree loss: 4.968 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 55 | Batch: 009 / 016 | Total loss: 4.997 | Reg loss: 0.031 | Tree loss: 4.997 | Accuracy: 0.080078 | 0.058 sec/iter
Epoch: 5

Epoch: 59 | Batch: 002 / 016 | Total loss: 4.935 | Reg loss: 0.031 | Tree loss: 4.935 | Accuracy: 0.107422 | 0.057 sec/iter
Epoch: 59 | Batch: 003 / 016 | Total loss: 4.820 | Reg loss: 0.031 | Tree loss: 4.820 | Accuracy: 0.105469 | 0.057 sec/iter
Epoch: 59 | Batch: 004 / 016 | Total loss: 4.859 | Reg loss: 0.031 | Tree loss: 4.859 | Accuracy: 0.119141 | 0.057 sec/iter
Epoch: 59 | Batch: 005 / 016 | Total loss: 4.809 | Reg loss: 0.031 | Tree loss: 4.809 | Accuracy: 0.111328 | 0.057 sec/iter
Epoch: 59 | Batch: 006 / 016 | Total loss: 4.741 | Reg loss: 0.031 | Tree loss: 4.741 | Accuracy: 0.121094 | 0.057 sec/iter
Epoch: 59 | Batch: 007 / 016 | Total loss: 4.801 | Reg loss: 0.031 | Tree loss: 4.801 | Accuracy: 0.103516 | 0.057 sec/iter
Epoch: 59 | Batch: 008 / 016 | Total loss: 4.782 | Reg loss: 0.031 | Tree loss: 4.782 | Accuracy: 0.095703 | 0.057 sec/iter
Epoch: 59 | Batch: 009 / 016 | Total loss: 4.793 | Reg loss: 0.031 | Tree loss: 4.793 | Accuracy: 0.109375 | 0.057 sec/iter
Epoch: 5

Epoch: 63 | Batch: 002 / 016 | Total loss: 4.735 | Reg loss: 0.031 | Tree loss: 4.735 | Accuracy: 0.103516 | 0.057 sec/iter
Epoch: 63 | Batch: 003 / 016 | Total loss: 4.726 | Reg loss: 0.031 | Tree loss: 4.726 | Accuracy: 0.123047 | 0.057 sec/iter
Epoch: 63 | Batch: 004 / 016 | Total loss: 4.737 | Reg loss: 0.031 | Tree loss: 4.737 | Accuracy: 0.119141 | 0.057 sec/iter
Epoch: 63 | Batch: 005 / 016 | Total loss: 4.674 | Reg loss: 0.031 | Tree loss: 4.674 | Accuracy: 0.113281 | 0.057 sec/iter
Epoch: 63 | Batch: 006 / 016 | Total loss: 4.660 | Reg loss: 0.031 | Tree loss: 4.660 | Accuracy: 0.093750 | 0.057 sec/iter
Epoch: 63 | Batch: 007 / 016 | Total loss: 4.716 | Reg loss: 0.031 | Tree loss: 4.716 | Accuracy: 0.083984 | 0.057 sec/iter
Epoch: 63 | Batch: 008 / 016 | Total loss: 4.619 | Reg loss: 0.031 | Tree loss: 4.619 | Accuracy: 0.097656 | 0.057 sec/iter
Epoch: 63 | Batch: 009 / 016 | Total loss: 4.630 | Reg loss: 0.031 | Tree loss: 4.630 | Accuracy: 0.123047 | 0.057 sec/iter
Epoch: 6

Epoch: 67 | Batch: 002 / 016 | Total loss: 4.633 | Reg loss: 0.031 | Tree loss: 4.633 | Accuracy: 0.099609 | 0.057 sec/iter
Epoch: 67 | Batch: 003 / 016 | Total loss: 4.657 | Reg loss: 0.031 | Tree loss: 4.657 | Accuracy: 0.109375 | 0.057 sec/iter
Epoch: 67 | Batch: 004 / 016 | Total loss: 4.523 | Reg loss: 0.031 | Tree loss: 4.523 | Accuracy: 0.123047 | 0.057 sec/iter
Epoch: 67 | Batch: 005 / 016 | Total loss: 4.596 | Reg loss: 0.031 | Tree loss: 4.596 | Accuracy: 0.101562 | 0.057 sec/iter
Epoch: 67 | Batch: 006 / 016 | Total loss: 4.602 | Reg loss: 0.032 | Tree loss: 4.602 | Accuracy: 0.130859 | 0.057 sec/iter
Epoch: 67 | Batch: 007 / 016 | Total loss: 4.633 | Reg loss: 0.032 | Tree loss: 4.633 | Accuracy: 0.115234 | 0.057 sec/iter
Epoch: 67 | Batch: 008 / 016 | Total loss: 4.585 | Reg loss: 0.032 | Tree loss: 4.585 | Accuracy: 0.107422 | 0.057 sec/iter
Epoch: 67 | Batch: 009 / 016 | Total loss: 4.563 | Reg loss: 0.032 | Tree loss: 4.563 | Accuracy: 0.082031 | 0.057 sec/iter
Epoch: 6

Epoch: 71 | Batch: 002 / 016 | Total loss: 4.614 | Reg loss: 0.032 | Tree loss: 4.614 | Accuracy: 0.087891 | 0.057 sec/iter
Epoch: 71 | Batch: 003 / 016 | Total loss: 4.588 | Reg loss: 0.032 | Tree loss: 4.588 | Accuracy: 0.095703 | 0.057 sec/iter
Epoch: 71 | Batch: 004 / 016 | Total loss: 4.514 | Reg loss: 0.032 | Tree loss: 4.514 | Accuracy: 0.125000 | 0.057 sec/iter
Epoch: 71 | Batch: 005 / 016 | Total loss: 4.508 | Reg loss: 0.032 | Tree loss: 4.508 | Accuracy: 0.099609 | 0.057 sec/iter
Epoch: 71 | Batch: 006 / 016 | Total loss: 4.476 | Reg loss: 0.032 | Tree loss: 4.476 | Accuracy: 0.107422 | 0.057 sec/iter
Epoch: 71 | Batch: 007 / 016 | Total loss: 4.478 | Reg loss: 0.032 | Tree loss: 4.478 | Accuracy: 0.103516 | 0.057 sec/iter
Epoch: 71 | Batch: 008 / 016 | Total loss: 4.409 | Reg loss: 0.032 | Tree loss: 4.409 | Accuracy: 0.115234 | 0.057 sec/iter
Epoch: 71 | Batch: 009 / 016 | Total loss: 4.382 | Reg loss: 0.032 | Tree loss: 4.382 | Accuracy: 0.115234 | 0.057 sec/iter
Epoch: 7

Epoch: 75 | Batch: 002 / 016 | Total loss: 4.477 | Reg loss: 0.032 | Tree loss: 4.477 | Accuracy: 0.097656 | 0.057 sec/iter
Epoch: 75 | Batch: 003 / 016 | Total loss: 4.480 | Reg loss: 0.032 | Tree loss: 4.480 | Accuracy: 0.111328 | 0.057 sec/iter
Epoch: 75 | Batch: 004 / 016 | Total loss: 4.484 | Reg loss: 0.032 | Tree loss: 4.484 | Accuracy: 0.093750 | 0.057 sec/iter
Epoch: 75 | Batch: 005 / 016 | Total loss: 4.400 | Reg loss: 0.032 | Tree loss: 4.400 | Accuracy: 0.132812 | 0.057 sec/iter
Epoch: 75 | Batch: 006 / 016 | Total loss: 4.516 | Reg loss: 0.032 | Tree loss: 4.516 | Accuracy: 0.072266 | 0.057 sec/iter
Epoch: 75 | Batch: 007 / 016 | Total loss: 4.378 | Reg loss: 0.032 | Tree loss: 4.378 | Accuracy: 0.097656 | 0.057 sec/iter
Epoch: 75 | Batch: 008 / 016 | Total loss: 4.461 | Reg loss: 0.032 | Tree loss: 4.461 | Accuracy: 0.103516 | 0.057 sec/iter
Epoch: 75 | Batch: 009 / 016 | Total loss: 4.326 | Reg loss: 0.032 | Tree loss: 4.326 | Accuracy: 0.119141 | 0.057 sec/iter
Epoch: 7

Epoch: 79 | Batch: 002 / 016 | Total loss: 4.363 | Reg loss: 0.032 | Tree loss: 4.363 | Accuracy: 0.121094 | 0.058 sec/iter
Epoch: 79 | Batch: 003 / 016 | Total loss: 4.440 | Reg loss: 0.032 | Tree loss: 4.440 | Accuracy: 0.113281 | 0.058 sec/iter
Epoch: 79 | Batch: 004 / 016 | Total loss: 4.341 | Reg loss: 0.032 | Tree loss: 4.341 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 79 | Batch: 005 / 016 | Total loss: 4.411 | Reg loss: 0.032 | Tree loss: 4.411 | Accuracy: 0.087891 | 0.058 sec/iter
Epoch: 79 | Batch: 006 / 016 | Total loss: 4.393 | Reg loss: 0.032 | Tree loss: 4.393 | Accuracy: 0.091797 | 0.058 sec/iter
Epoch: 79 | Batch: 007 / 016 | Total loss: 4.297 | Reg loss: 0.032 | Tree loss: 4.297 | Accuracy: 0.111328 | 0.058 sec/iter
Epoch: 79 | Batch: 008 / 016 | Total loss: 4.362 | Reg loss: 0.032 | Tree loss: 4.362 | Accuracy: 0.111328 | 0.058 sec/iter
Epoch: 79 | Batch: 009 / 016 | Total loss: 4.199 | Reg loss: 0.032 | Tree loss: 4.199 | Accuracy: 0.148438 | 0.058 sec/iter
Epoch: 7

Epoch: 83 | Batch: 001 / 016 | Total loss: 4.320 | Reg loss: 0.032 | Tree loss: 4.320 | Accuracy: 0.107422 | 0.058 sec/iter
Epoch: 83 | Batch: 002 / 016 | Total loss: 4.347 | Reg loss: 0.032 | Tree loss: 4.347 | Accuracy: 0.109375 | 0.058 sec/iter
Epoch: 83 | Batch: 003 / 016 | Total loss: 4.353 | Reg loss: 0.032 | Tree loss: 4.353 | Accuracy: 0.107422 | 0.058 sec/iter
Epoch: 83 | Batch: 004 / 016 | Total loss: 4.281 | Reg loss: 0.032 | Tree loss: 4.281 | Accuracy: 0.125000 | 0.058 sec/iter
Epoch: 83 | Batch: 005 / 016 | Total loss: 4.320 | Reg loss: 0.032 | Tree loss: 4.320 | Accuracy: 0.113281 | 0.058 sec/iter
Epoch: 83 | Batch: 006 / 016 | Total loss: 4.321 | Reg loss: 0.032 | Tree loss: 4.321 | Accuracy: 0.097656 | 0.058 sec/iter
Epoch: 83 | Batch: 007 / 016 | Total loss: 4.226 | Reg loss: 0.032 | Tree loss: 4.226 | Accuracy: 0.125000 | 0.058 sec/iter
Epoch: 83 | Batch: 008 / 016 | Total loss: 4.352 | Reg loss: 0.032 | Tree loss: 4.352 | Accuracy: 0.097656 | 0.058 sec/iter
Epoch: 8

Epoch: 86 | Batch: 015 / 016 | Total loss: 4.110 | Reg loss: 0.032 | Tree loss: 4.110 | Accuracy: 0.119658 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 87 | Batch: 000 / 016 | Total loss: 4.260 | Reg loss: 0.032 | Tree loss: 4.260 | Accuracy: 0.113281 | 0.058 sec/iter
Epoch: 87 | Batch: 001 / 016 | Total loss: 4.387 | Reg loss: 0.032 | Tree loss: 4.387 | Accuracy: 0.083984 | 0.058 sec/iter
Epoch: 87 | Batch: 002 / 016 | Total loss: 4.282 | Reg loss: 0.032 | Tree loss: 4.282 | Accuracy: 0.132812 | 0.058 sec/iter
Epoch: 87 | Batch: 003 / 016 | Total loss: 4.293 | Reg loss: 0.032 | Tree loss: 4.293 | Accuracy: 0.126953 | 0.058 sec/iter
Epoch: 87 | Batch: 004 / 016 | Total loss: 4.365 | Reg loss: 0.032 | Tree loss: 4.365 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 87 | Batch: 005 / 016 | Total loss: 4.278 | Reg loss: 0.032 | Tree los

Epoch: 90 | Batch: 015 / 016 | Total loss: 4.153 | Reg loss: 0.032 | Tree loss: 4.153 | Accuracy: 0.102564 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 91 | Batch: 000 / 016 | Total loss: 4.259 | Reg loss: 0.032 | Tree loss: 4.259 | Accuracy: 0.119141 | 0.058 sec/iter
Epoch: 91 | Batch: 001 / 016 | Total loss: 4.310 | Reg loss: 0.032 | Tree loss: 4.310 | Accuracy: 0.097656 | 0.058 sec/iter
Epoch: 91 | Batch: 002 / 016 | Total loss: 4.260 | Reg loss: 0.032 | Tree loss: 4.260 | Accuracy: 0.093750 | 0.058 sec/iter
Epoch: 91 | Batch: 003 / 016 | Total loss: 4.216 | Reg loss: 0.032 | Tree loss: 4.216 | Accuracy: 0.125000 | 0.058 sec/iter
Epoch: 91 | Batch: 004 / 016 | Total loss: 4.278 | Reg loss: 0.032 | Tree loss: 4.278 | Accuracy: 0.085938 | 0.058 sec/iter
Epoch: 91 | Batch: 005 / 016 | Total loss: 4.231 | Reg loss: 0.032 | Tree los

Epoch: 94 | Batch: 015 / 016 | Total loss: 4.160 | Reg loss: 0.032 | Tree loss: 4.160 | Accuracy: 0.094017 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 95 | Batch: 000 / 016 | Total loss: 4.283 | Reg loss: 0.032 | Tree loss: 4.283 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 95 | Batch: 001 / 016 | Total loss: 4.199 | Reg loss: 0.032 | Tree loss: 4.199 | Accuracy: 0.103516 | 0.058 sec/iter
Epoch: 95 | Batch: 002 / 016 | Total loss: 4.191 | Reg loss: 0.032 | Tree loss: 4.191 | Accuracy: 0.115234 | 0.058 sec/iter
Epoch: 95 | Batch: 003 / 016 | Total loss: 4.202 | Reg loss: 0.032 | Tree loss: 4.202 | Accuracy: 0.107422 | 0.058 sec/iter
Epoch: 95 | Batch: 004 / 016 | Total loss: 4.152 | Reg loss: 0.032 | Tree loss: 4.152 | Accuracy: 0.109375 | 0.058 sec/iter
Epoch: 95 | Batch: 005 / 016 | Total loss: 4.247 | Reg loss: 0.032 | Tree los

Epoch: 98 | Batch: 013 / 016 | Total loss: 3.986 | Reg loss: 0.032 | Tree loss: 3.986 | Accuracy: 0.123047 | 0.058 sec/iter
Epoch: 98 | Batch: 014 / 016 | Total loss: 4.051 | Reg loss: 0.032 | Tree loss: 4.051 | Accuracy: 0.105469 | 0.058 sec/iter
Epoch: 98 | Batch: 015 / 016 | Total loss: 4.014 | Reg loss: 0.032 | Tree loss: 4.014 | Accuracy: 0.094017 | 0.058 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 99 | Batch: 000 / 016 | Total loss: 4.145 | Reg loss: 0.032 | Tree loss: 4.145 | Accuracy: 0.119141 | 0.058 sec/iter
Epoch: 99 | Batch: 001 / 016 | Total loss: 4.189 | Reg loss: 0.032 | Tree loss: 4.189 | Accuracy: 0.089844 | 0.058 sec/iter
Epoch: 99 | Batch: 002 / 016 | Total loss: 4.146 | Reg loss: 0.032 | Tree loss: 4.146 | Accuracy: 0.113281 | 0.058 sec/iter
Epoch: 99 | Batch: 003 / 016 | Total loss: 4.172 | Reg loss: 0.032 | Tree los

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 5.892857142857143


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 56


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

6672
1476
Average comprehensibility: 27.857142857142858
std comprehensibility: 2.2632827882506943
var comprehensibility: 5.122448979591836
minimum comprehensibility: 20
maximum comprehensibility: 32


  return np.log(1 / (1 - x))
