In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 8
tree_depth = 12
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.198968887329102 | KNN Loss: 6.225536346435547 | BCE Loss: 1.9734323024749756
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.135589599609375 | KNN Loss: 6.225613117218018 | BCE Loss: 1.9099769592285156
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.183895111083984 | KNN Loss: 6.225164413452148 | BCE Loss: 1.9587311744689941
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.114386558532715 | KNN Loss: 6.22507905960083 | BCE Loss: 1.8893071413040161
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.159957885742188 | KNN Loss: 6.224706172943115 | BCE Loss: 1.9352517127990723
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.136695861816406 | KNN Loss: 6.224417686462402 | BCE Loss: 1.912278652191162
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.143320083618164 | KNN Loss: 6.2238640785217285 | BCE Loss: 1.9194557666778564
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.132482528686523 | KNN Loss: 6.223404407501221 | BCE Loss: 1.909078

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.453589916229248 | KNN Loss: 4.331330299377441 | BCE Loss: 1.122259497642517
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.370828628540039 | KNN Loss: 4.244254112243652 | BCE Loss: 1.1265745162963867
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.28737735748291 | KNN Loss: 4.171576499938965 | BCE Loss: 1.1158009767532349
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.168726921081543 | KNN Loss: 4.064449787139893 | BCE Loss: 1.1042771339416504
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.064929008483887 | KNN Loss: 3.9546024799346924 | BCE Loss: 1.1103267669677734
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 4.961371421813965 | KNN Loss: 3.8593170642852783 | BCE Loss: 1.1020543575286865
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 4.923951148986816 | KNN Loss: 3.8265035152435303 | BCE Loss: 1.0974478721618652
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 4.831162452697754 | KNN Loss: 3.720099925994873 | BCE Los

Epoch 21 / 500 | iteration 15 / 30 | Total Loss: 3.6977899074554443 | KNN Loss: 2.6339762210845947 | BCE Loss: 1.0638136863708496
Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 3.7353272438049316 | KNN Loss: 2.6816258430480957 | BCE Loss: 1.0537012815475464
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 3.670689105987549 | KNN Loss: 2.639291763305664 | BCE Loss: 1.0313974618911743
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 3.6490042209625244 | KNN Loss: 2.616891860961914 | BCE Loss: 1.0321123600006104
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 3.6471664905548096 | KNN Loss: 2.5883514881134033 | BCE Loss: 1.0588150024414062
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 3.649770498275757 | KNN Loss: 2.6246516704559326 | BCE Loss: 1.0251188278198242
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 3.6736395359039307 | KNN Loss: 2.6033623218536377 | BCE Loss: 1.070277214050293
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 3.6453888416290283 | KNN Loss: 2.597358226776123

Epoch 32 / 500 | iteration 5 / 30 | Total Loss: 3.618807077407837 | KNN Loss: 2.5717806816101074 | BCE Loss: 1.0470263957977295
Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 3.569949150085449 | KNN Loss: 2.5322418212890625 | BCE Loss: 1.0377073287963867
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 3.577805995941162 | KNN Loss: 2.5482277870178223 | BCE Loss: 1.0295783281326294
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 3.6160964965820312 | KNN Loss: 2.5657830238342285 | BCE Loss: 1.0503135919570923
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 3.622591495513916 | KNN Loss: 2.5976924896240234 | BCE Loss: 1.0248990058898926
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 3.5767431259155273 | KNN Loss: 2.555893898010254 | BCE Loss: 1.0208492279052734
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 3.610532760620117 | KNN Loss: 2.5846874713897705 | BCE Loss: 1.0258452892303467
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 3.701009511947632 | KNN Loss: 2.645540952682495 | 

Epoch 42 / 500 | iteration 25 / 30 | Total Loss: 3.599447250366211 | KNN Loss: 2.5760529041290283 | BCE Loss: 1.0233944654464722
Epoch 43 / 500 | iteration 0 / 30 | Total Loss: 3.576629161834717 | KNN Loss: 2.5417699813842773 | BCE Loss: 1.0348591804504395
Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 3.612381935119629 | KNN Loss: 2.572984457015991 | BCE Loss: 1.0393973588943481
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 3.5771279335021973 | KNN Loss: 2.5571000576019287 | BCE Loss: 1.0200278759002686
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 3.6161727905273438 | KNN Loss: 2.5712759494781494 | BCE Loss: 1.0448968410491943
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 3.5584521293640137 | KNN Loss: 2.5300369262695312 | BCE Loss: 1.0284152030944824
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 3.539503335952759 | KNN Loss: 2.492119789123535 | BCE Loss: 1.0473835468292236
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 3.569394588470459 | KNN Loss: 2.5309765338897705 |

Epoch 53 / 500 | iteration 15 / 30 | Total Loss: 3.511972665786743 | KNN Loss: 2.4963977336883545 | BCE Loss: 1.0155749320983887
Epoch 53 / 500 | iteration 20 / 30 | Total Loss: 3.5446720123291016 | KNN Loss: 2.494915008544922 | BCE Loss: 1.0497570037841797
Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 3.548696756362915 | KNN Loss: 2.5070040225982666 | BCE Loss: 1.0416927337646484
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 3.548344850540161 | KNN Loss: 2.5258564949035645 | BCE Loss: 1.0224883556365967
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 3.5514042377471924 | KNN Loss: 2.5236122608184814 | BCE Loss: 1.027791976928711
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 3.547266960144043 | KNN Loss: 2.526203155517578 | BCE Loss: 1.0210636854171753
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 3.556274890899658 | KNN Loss: 2.538482904434204 | BCE Loss: 1.0177921056747437
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 3.562323570251465 | KNN Loss: 2.5182886123657227 | B

Epoch 64 / 500 | iteration 5 / 30 | Total Loss: 3.5428080558776855 | KNN Loss: 2.51823353767395 | BCE Loss: 1.0245743989944458
Epoch 64 / 500 | iteration 10 / 30 | Total Loss: 3.54805850982666 | KNN Loss: 2.524513006210327 | BCE Loss: 1.0235453844070435
Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 3.565990447998047 | KNN Loss: 2.51900053024292 | BCE Loss: 1.0469897985458374
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 3.50648832321167 | KNN Loss: 2.4750001430511475 | BCE Loss: 1.031488060951233
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 3.54034686088562 | KNN Loss: 2.5017125606536865 | BCE Loss: 1.0386343002319336
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 3.5117130279541016 | KNN Loss: 2.4871654510498047 | BCE Loss: 1.0245476961135864
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 3.6115224361419678 | KNN Loss: 2.575932502746582 | BCE Loss: 1.0355899333953857
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 3.5500192642211914 | KNN Loss: 2.523078441619873 | BCE Los

Epoch 74 / 500 | iteration 25 / 30 | Total Loss: 3.516369342803955 | KNN Loss: 2.4834420680999756 | BCE Loss: 1.0329272747039795
Epoch 75 / 500 | iteration 0 / 30 | Total Loss: 3.5317420959472656 | KNN Loss: 2.501016139984131 | BCE Loss: 1.0307259559631348
Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 3.5574615001678467 | KNN Loss: 2.533857583999634 | BCE Loss: 1.023603916168213
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 3.511458396911621 | KNN Loss: 2.511655569076538 | BCE Loss: 0.9998027086257935
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 3.4992496967315674 | KNN Loss: 2.4767322540283203 | BCE Loss: 1.022517442703247
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 3.52164626121521 | KNN Loss: 2.4742960929870605 | BCE Loss: 1.0473501682281494
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 3.556767225265503 | KNN Loss: 2.51163387298584 | BCE Loss: 1.045133352279663
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 3.5352797508239746 | KNN Loss: 2.519198417663574 | BCE Lo

Epoch 85 / 500 | iteration 15 / 30 | Total Loss: 3.48433780670166 | KNN Loss: 2.487887144088745 | BCE Loss: 0.9964507818222046
Epoch 85 / 500 | iteration 20 / 30 | Total Loss: 3.518157720565796 | KNN Loss: 2.4706478118896484 | BCE Loss: 1.0475099086761475
Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 3.581376314163208 | KNN Loss: 2.5050547122955322 | BCE Loss: 1.0763216018676758
Epoch    86: reducing learning rate of group 0 to 2.4500e-03.
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 3.4986252784729004 | KNN Loss: 2.470700740814209 | BCE Loss: 1.027924656867981
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 3.5552327632904053 | KNN Loss: 2.530477285385132 | BCE Loss: 1.0247554779052734
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 3.4889910221099854 | KNN Loss: 2.502777576446533 | BCE Loss: 0.9862135052680969
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 3.523919105529785 | KNN Loss: 2.4937007427215576 | BCE Loss: 1.030218243598938
Epoch 86 / 500 | iteration 20 / 30 | Tota

Epoch 96 / 500 | iteration 5 / 30 | Total Loss: 3.5028865337371826 | KNN Loss: 2.476839303970337 | BCE Loss: 1.0260472297668457
Epoch 96 / 500 | iteration 10 / 30 | Total Loss: 3.5230791568756104 | KNN Loss: 2.511836290359497 | BCE Loss: 1.0112428665161133
Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 3.5682125091552734 | KNN Loss: 2.499727725982666 | BCE Loss: 1.0684847831726074
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 3.4807815551757812 | KNN Loss: 2.4838593006134033 | BCE Loss: 0.9969223737716675
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 3.513824939727783 | KNN Loss: 2.514911651611328 | BCE Loss: 0.9989132881164551
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 3.5171098709106445 | KNN Loss: 2.4811134338378906 | BCE Loss: 1.0359965562820435
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 3.5042455196380615 | KNN Loss: 2.498847007751465 | BCE Loss: 1.0053985118865967
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 3.4921188354492188 | KNN Loss: 2.48435378074646 | 

Epoch 106 / 500 | iteration 25 / 30 | Total Loss: 3.5105643272399902 | KNN Loss: 2.4997990131378174 | BCE Loss: 1.0107653141021729
Epoch 107 / 500 | iteration 0 / 30 | Total Loss: 3.5715267658233643 | KNN Loss: 2.5322704315185547 | BCE Loss: 1.0392563343048096
Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 3.523423194885254 | KNN Loss: 2.495837450027466 | BCE Loss: 1.0275858640670776
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 3.507741928100586 | KNN Loss: 2.4925918579101562 | BCE Loss: 1.0151499509811401
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 3.497525691986084 | KNN Loss: 2.4803102016448975 | BCE Loss: 1.017215371131897
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 3.494323492050171 | KNN Loss: 2.477142572402954 | BCE Loss: 1.0171809196472168
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 3.4929041862487793 | KNN Loss: 2.4785025119781494 | BCE Loss: 1.0144016742706299
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 3.5089058876037598 | KNN Loss: 2.4958498477

Epoch 117 / 500 | iteration 15 / 30 | Total Loss: 3.527435779571533 | KNN Loss: 2.4842774868011475 | BCE Loss: 1.0431582927703857
Epoch 117 / 500 | iteration 20 / 30 | Total Loss: 3.508188009262085 | KNN Loss: 2.5029351711273193 | BCE Loss: 1.0052528381347656
Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 3.4684646129608154 | KNN Loss: 2.4497814178466797 | BCE Loss: 1.0186831951141357
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 3.5249102115631104 | KNN Loss: 2.504594564437866 | BCE Loss: 1.0203156471252441
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 3.494647741317749 | KNN Loss: 2.49564790725708 | BCE Loss: 0.9989997744560242
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 3.508951187133789 | KNN Loss: 2.506401777267456 | BCE Loss: 1.002549409866333
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 3.524812698364258 | KNN Loss: 2.4995064735412598 | BCE Loss: 1.025306224822998
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 3.5236291885375977 | KNN Loss: 2.5103526115417

Epoch 128 / 500 | iteration 0 / 30 | Total Loss: 3.492321491241455 | KNN Loss: 2.478208303451538 | BCE Loss: 1.014113187789917
Epoch 128 / 500 | iteration 5 / 30 | Total Loss: 3.4962241649627686 | KNN Loss: 2.4720571041107178 | BCE Loss: 1.0241670608520508
Epoch 128 / 500 | iteration 10 / 30 | Total Loss: 3.4817395210266113 | KNN Loss: 2.474567413330078 | BCE Loss: 1.0071719884872437
Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 3.4710018634796143 | KNN Loss: 2.479167938232422 | BCE Loss: 0.9918339252471924
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 3.5352823734283447 | KNN Loss: 2.501796245574951 | BCE Loss: 1.0334861278533936
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 3.4706225395202637 | KNN Loss: 2.4541523456573486 | BCE Loss: 1.0164700746536255
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 3.484872341156006 | KNN Loss: 2.4847629070281982 | BCE Loss: 1.0001094341278076
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 3.499220371246338 | KNN Loss: 2.491783857345

Epoch 138 / 500 | iteration 20 / 30 | Total Loss: 3.4935483932495117 | KNN Loss: 2.495151996612549 | BCE Loss: 0.9983965158462524
Epoch 138 / 500 | iteration 25 / 30 | Total Loss: 3.4936907291412354 | KNN Loss: 2.4574263095855713 | BCE Loss: 1.036264419555664
Epoch 139 / 500 | iteration 0 / 30 | Total Loss: 3.500603675842285 | KNN Loss: 2.487312078475952 | BCE Loss: 1.013291597366333
Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 3.4797768592834473 | KNN Loss: 2.466181993484497 | BCE Loss: 1.0135948657989502
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 3.5796473026275635 | KNN Loss: 2.538252592086792 | BCE Loss: 1.0413947105407715
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 3.505741596221924 | KNN Loss: 2.4781224727630615 | BCE Loss: 1.0276191234588623
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 3.5333735942840576 | KNN Loss: 2.504798412322998 | BCE Loss: 1.0285751819610596
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 3.4414145946502686 | KNN Loss: 2.45444512367

Epoch 149 / 500 | iteration 10 / 30 | Total Loss: 3.517822265625 | KNN Loss: 2.4781246185302734 | BCE Loss: 1.0396976470947266
Epoch 149 / 500 | iteration 15 / 30 | Total Loss: 3.5461506843566895 | KNN Loss: 2.516878366470337 | BCE Loss: 1.0292723178863525
Epoch 149 / 500 | iteration 20 / 30 | Total Loss: 3.486924648284912 | KNN Loss: 2.4669923782348633 | BCE Loss: 1.0199322700500488
Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 3.4714860916137695 | KNN Loss: 2.476755380630493 | BCE Loss: 0.9947308301925659
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 3.4871912002563477 | KNN Loss: 2.475466012954712 | BCE Loss: 1.0117251873016357
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 3.4891839027404785 | KNN Loss: 2.4775004386901855 | BCE Loss: 1.0116833448410034
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 3.5118227005004883 | KNN Loss: 2.477031707763672 | BCE Loss: 1.0347909927368164
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 3.4957919120788574 | KNN Loss: 2.47924947738

Epoch 159 / 500 | iteration 25 / 30 | Total Loss: 3.5687875747680664 | KNN Loss: 2.5259628295898438 | BCE Loss: 1.0428248643875122
Epoch 160 / 500 | iteration 0 / 30 | Total Loss: 3.4703948497772217 | KNN Loss: 2.4795725345611572 | BCE Loss: 0.9908223748207092
Epoch 160 / 500 | iteration 5 / 30 | Total Loss: 3.502793073654175 | KNN Loss: 2.482849597930908 | BCE Loss: 1.0199434757232666
Epoch 160 / 500 | iteration 10 / 30 | Total Loss: 3.480257034301758 | KNN Loss: 2.479675531387329 | BCE Loss: 1.0005815029144287
Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 3.538160562515259 | KNN Loss: 2.5098910331726074 | BCE Loss: 1.0282695293426514
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 3.4917283058166504 | KNN Loss: 2.489123821258545 | BCE Loss: 1.002604603767395
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 3.4330670833587646 | KNN Loss: 2.451557159423828 | BCE Loss: 0.9815099239349365
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 3.5235633850097656 | KNN Loss: 2.49537825584

Epoch 170 / 500 | iteration 10 / 30 | Total Loss: 3.492882251739502 | KNN Loss: 2.4751317501068115 | BCE Loss: 1.0177505016326904
Epoch 170 / 500 | iteration 15 / 30 | Total Loss: 3.510105609893799 | KNN Loss: 2.4771065711975098 | BCE Loss: 1.032999038696289
Epoch 170 / 500 | iteration 20 / 30 | Total Loss: 3.464184522628784 | KNN Loss: 2.48067045211792 | BCE Loss: 0.9835140705108643
Epoch 170 / 500 | iteration 25 / 30 | Total Loss: 3.5360872745513916 | KNN Loss: 2.5044565200805664 | BCE Loss: 1.0316307544708252
Epoch 171 / 500 | iteration 0 / 30 | Total Loss: 3.4814095497131348 | KNN Loss: 2.4773924350738525 | BCE Loss: 1.0040171146392822
Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 3.5272955894470215 | KNN Loss: 2.501317262649536 | BCE Loss: 1.0259783267974854
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 3.5181872844696045 | KNN Loss: 2.4573137760162354 | BCE Loss: 1.0608735084533691
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 3.4752697944641113 | KNN Loss: 2.467348098

Epoch 181 / 500 | iteration 0 / 30 | Total Loss: 3.477966785430908 | KNN Loss: 2.4622085094451904 | BCE Loss: 1.0157583951950073
Epoch 181 / 500 | iteration 5 / 30 | Total Loss: 3.507354259490967 | KNN Loss: 2.4803056716918945 | BCE Loss: 1.0270485877990723
Epoch 181 / 500 | iteration 10 / 30 | Total Loss: 3.494075298309326 | KNN Loss: 2.50473952293396 | BCE Loss: 0.9893357157707214
Epoch 181 / 500 | iteration 15 / 30 | Total Loss: 3.4678003787994385 | KNN Loss: 2.4632842540740967 | BCE Loss: 1.0045161247253418
Epoch 181 / 500 | iteration 20 / 30 | Total Loss: 3.461928367614746 | KNN Loss: 2.43854022026062 | BCE Loss: 1.0233880281448364
Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 3.5018322467803955 | KNN Loss: 2.491797924041748 | BCE Loss: 1.0100343227386475
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 3.5308685302734375 | KNN Loss: 2.4839329719543457 | BCE Loss: 1.0469355583190918
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 3.4266469478607178 | KNN Loss: 2.4288527965545

Epoch 191 / 500 | iteration 20 / 30 | Total Loss: 3.5003325939178467 | KNN Loss: 2.4927446842193604 | BCE Loss: 1.0075879096984863
Epoch 191 / 500 | iteration 25 / 30 | Total Loss: 3.500047206878662 | KNN Loss: 2.4765772819519043 | BCE Loss: 1.0234699249267578
Epoch 192 / 500 | iteration 0 / 30 | Total Loss: 3.47090482711792 | KNN Loss: 2.4573416709899902 | BCE Loss: 1.0135632753372192
Epoch 192 / 500 | iteration 5 / 30 | Total Loss: 3.471834182739258 | KNN Loss: 2.464101791381836 | BCE Loss: 1.0077323913574219
Epoch 192 / 500 | iteration 10 / 30 | Total Loss: 3.5221688747406006 | KNN Loss: 2.4798760414123535 | BCE Loss: 1.042292833328247
Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 3.4889001846313477 | KNN Loss: 2.452802896499634 | BCE Loss: 1.0360972881317139
Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 3.4980108737945557 | KNN Loss: 2.470329761505127 | BCE Loss: 1.0276811122894287
Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 3.499168634414673 | KNN Loss: 2.47619748115

Epoch 202 / 500 | iteration 5 / 30 | Total Loss: 3.5068349838256836 | KNN Loss: 2.46734881401062 | BCE Loss: 1.0394861698150635
Epoch 202 / 500 | iteration 10 / 30 | Total Loss: 3.499793291091919 | KNN Loss: 2.462345600128174 | BCE Loss: 1.0374476909637451
Epoch 202 / 500 | iteration 15 / 30 | Total Loss: 3.481081962585449 | KNN Loss: 2.454221725463867 | BCE Loss: 1.0268603563308716
Epoch 202 / 500 | iteration 20 / 30 | Total Loss: 3.496720314025879 | KNN Loss: 2.5101540088653564 | BCE Loss: 0.9865663051605225
Epoch 202 / 500 | iteration 25 / 30 | Total Loss: 3.471802234649658 | KNN Loss: 2.462562084197998 | BCE Loss: 1.0092402696609497
Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 3.5006136894226074 | KNN Loss: 2.489772081375122 | BCE Loss: 1.0108414888381958
Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 3.467963933944702 | KNN Loss: 2.4658584594726562 | BCE Loss: 1.002105474472046
Epoch 203 / 500 | iteration 10 / 30 | Total Loss: 3.4860987663269043 | KNN Loss: 2.479560136795044

Epoch 212 / 500 | iteration 20 / 30 | Total Loss: 3.462161064147949 | KNN Loss: 2.4326059818267822 | BCE Loss: 1.029555082321167
Epoch 212 / 500 | iteration 25 / 30 | Total Loss: 3.4633519649505615 | KNN Loss: 2.461988687515259 | BCE Loss: 1.0013632774353027
Epoch 213 / 500 | iteration 0 / 30 | Total Loss: 3.486776351928711 | KNN Loss: 2.466282367706299 | BCE Loss: 1.0204941034317017
Epoch 213 / 500 | iteration 5 / 30 | Total Loss: 3.528419256210327 | KNN Loss: 2.475679874420166 | BCE Loss: 1.0527393817901611
Epoch 213 / 500 | iteration 10 / 30 | Total Loss: 3.497161388397217 | KNN Loss: 2.490509271621704 | BCE Loss: 1.0066522359848022
Epoch 213 / 500 | iteration 15 / 30 | Total Loss: 3.4970736503601074 | KNN Loss: 2.4694621562957764 | BCE Loss: 1.027611494064331
Epoch 213 / 500 | iteration 20 / 30 | Total Loss: 3.500676155090332 | KNN Loss: 2.461939573287964 | BCE Loss: 1.0387364625930786
Epoch 213 / 500 | iteration 25 / 30 | Total Loss: 3.479408025741577 | KNN Loss: 2.464498996734619

Epoch 223 / 500 | iteration 5 / 30 | Total Loss: 3.4995951652526855 | KNN Loss: 2.4769914150238037 | BCE Loss: 1.0226038694381714
Epoch 223 / 500 | iteration 10 / 30 | Total Loss: 3.472158193588257 | KNN Loss: 2.4525303840637207 | BCE Loss: 1.0196278095245361
Epoch 223 / 500 | iteration 15 / 30 | Total Loss: 3.4825668334960938 | KNN Loss: 2.4722959995269775 | BCE Loss: 1.0102709531784058
Epoch 223 / 500 | iteration 20 / 30 | Total Loss: 3.499417304992676 | KNN Loss: 2.474975109100342 | BCE Loss: 1.024442195892334
Epoch 223 / 500 | iteration 25 / 30 | Total Loss: 3.504148006439209 | KNN Loss: 2.466987133026123 | BCE Loss: 1.037160873413086
Epoch 224 / 500 | iteration 0 / 30 | Total Loss: 3.5083110332489014 | KNN Loss: 2.481292486190796 | BCE Loss: 1.0270185470581055
Epoch 224 / 500 | iteration 5 / 30 | Total Loss: 3.5025506019592285 | KNN Loss: 2.465989589691162 | BCE Loss: 1.036561131477356
Epoch 224 / 500 | iteration 10 / 30 | Total Loss: 3.51839542388916 | KNN Loss: 2.488136768341064

Epoch 233 / 500 | iteration 25 / 30 | Total Loss: 3.434509754180908 | KNN Loss: 2.4173152446746826 | BCE Loss: 1.017194390296936
Epoch 234 / 500 | iteration 0 / 30 | Total Loss: 3.4717183113098145 | KNN Loss: 2.4419703483581543 | BCE Loss: 1.0297480821609497
Epoch 234 / 500 | iteration 5 / 30 | Total Loss: 3.4583611488342285 | KNN Loss: 2.4240524768829346 | BCE Loss: 1.034308671951294
Epoch 234 / 500 | iteration 10 / 30 | Total Loss: 3.5188021659851074 | KNN Loss: 2.4815423488616943 | BCE Loss: 1.0372599363327026
Epoch 234 / 500 | iteration 15 / 30 | Total Loss: 3.4368722438812256 | KNN Loss: 2.4405100345611572 | BCE Loss: 0.9963622689247131
Epoch 234 / 500 | iteration 20 / 30 | Total Loss: 3.4783968925476074 | KNN Loss: 2.475135087966919 | BCE Loss: 1.0032618045806885
Epoch 234 / 500 | iteration 25 / 30 | Total Loss: 3.539263963699341 | KNN Loss: 2.4771318435668945 | BCE Loss: 1.0621321201324463
Epoch 235 / 500 | iteration 0 / 30 | Total Loss: 3.5090279579162598 | KNN Loss: 2.45754146

Epoch 244 / 500 | iteration 15 / 30 | Total Loss: 3.453809976577759 | KNN Loss: 2.4478659629821777 | BCE Loss: 1.005944013595581
Epoch 244 / 500 | iteration 20 / 30 | Total Loss: 3.4986910820007324 | KNN Loss: 2.467963457107544 | BCE Loss: 1.030727744102478
Epoch 244 / 500 | iteration 25 / 30 | Total Loss: 3.490947961807251 | KNN Loss: 2.4728846549987793 | BCE Loss: 1.0180633068084717
Epoch 245 / 500 | iteration 0 / 30 | Total Loss: 3.5086991786956787 | KNN Loss: 2.481510639190674 | BCE Loss: 1.0271885395050049
Epoch 245 / 500 | iteration 5 / 30 | Total Loss: 3.4248032569885254 | KNN Loss: 2.431762456893921 | BCE Loss: 0.993040919303894
Epoch 245 / 500 | iteration 10 / 30 | Total Loss: 3.4481449127197266 | KNN Loss: 2.4581263065338135 | BCE Loss: 0.9900187253952026
Epoch 245 / 500 | iteration 15 / 30 | Total Loss: 3.4930777549743652 | KNN Loss: 2.4678890705108643 | BCE Loss: 1.025188684463501
Epoch 245 / 500 | iteration 20 / 30 | Total Loss: 3.491400718688965 | KNN Loss: 2.465294361114

Epoch 255 / 500 | iteration 0 / 30 | Total Loss: 3.478372097015381 | KNN Loss: 2.4450743198394775 | BCE Loss: 1.0332977771759033
Epoch 255 / 500 | iteration 5 / 30 | Total Loss: 3.46812105178833 | KNN Loss: 2.446790933609009 | BCE Loss: 1.0213301181793213
Epoch 255 / 500 | iteration 10 / 30 | Total Loss: 3.5101044178009033 | KNN Loss: 2.510512351989746 | BCE Loss: 0.9995920062065125
Epoch 255 / 500 | iteration 15 / 30 | Total Loss: 3.4798073768615723 | KNN Loss: 2.4878506660461426 | BCE Loss: 0.9919568300247192
Epoch 255 / 500 | iteration 20 / 30 | Total Loss: 3.4724669456481934 | KNN Loss: 2.46370792388916 | BCE Loss: 1.0087590217590332
Epoch 255 / 500 | iteration 25 / 30 | Total Loss: 3.530513286590576 | KNN Loss: 2.478782892227173 | BCE Loss: 1.0517303943634033
Epoch 256 / 500 | iteration 0 / 30 | Total Loss: 3.498157024383545 | KNN Loss: 2.4786581993103027 | BCE Loss: 1.0194987058639526
Epoch 256 / 500 | iteration 5 / 30 | Total Loss: 3.4699110984802246 | KNN Loss: 2.45302009582519

Epoch 265 / 500 | iteration 15 / 30 | Total Loss: 3.426283359527588 | KNN Loss: 2.4487533569335938 | BCE Loss: 0.9775299429893494
Epoch 265 / 500 | iteration 20 / 30 | Total Loss: 3.4668264389038086 | KNN Loss: 2.457836627960205 | BCE Loss: 1.008989691734314
Epoch 265 / 500 | iteration 25 / 30 | Total Loss: 3.51741886138916 | KNN Loss: 2.4636075496673584 | BCE Loss: 1.0538113117218018
Epoch 266 / 500 | iteration 0 / 30 | Total Loss: 3.4980392456054688 | KNN Loss: 2.471996545791626 | BCE Loss: 1.0260426998138428
Epoch 266 / 500 | iteration 5 / 30 | Total Loss: 3.4218153953552246 | KNN Loss: 2.430177688598633 | BCE Loss: 0.9916378259658813
Epoch 266 / 500 | iteration 10 / 30 | Total Loss: 3.4970548152923584 | KNN Loss: 2.4707229137420654 | BCE Loss: 1.026331901550293
Epoch 266 / 500 | iteration 15 / 30 | Total Loss: 3.473012924194336 | KNN Loss: 2.476369619369507 | BCE Loss: 0.9966433048248291
Epoch 266 / 500 | iteration 20 / 30 | Total Loss: 3.475454330444336 | KNN Loss: 2.4399857521057

Epoch 276 / 500 | iteration 0 / 30 | Total Loss: 3.478055953979492 | KNN Loss: 2.4417800903320312 | BCE Loss: 1.0362757444381714
Epoch 276 / 500 | iteration 5 / 30 | Total Loss: 3.4640164375305176 | KNN Loss: 2.451784133911133 | BCE Loss: 1.0122324228286743
Epoch 276 / 500 | iteration 10 / 30 | Total Loss: 3.47721266746521 | KNN Loss: 2.4417550563812256 | BCE Loss: 1.0354576110839844
Epoch 276 / 500 | iteration 15 / 30 | Total Loss: 3.4294252395629883 | KNN Loss: 2.4271347522735596 | BCE Loss: 1.0022904872894287
Epoch 276 / 500 | iteration 20 / 30 | Total Loss: 3.4510223865509033 | KNN Loss: 2.4492104053497314 | BCE Loss: 1.0018119812011719
Epoch 276 / 500 | iteration 25 / 30 | Total Loss: 3.499720573425293 | KNN Loss: 2.472156524658203 | BCE Loss: 1.0275640487670898
Epoch 277 / 500 | iteration 0 / 30 | Total Loss: 3.467503070831299 | KNN Loss: 2.4519171714782715 | BCE Loss: 1.0155858993530273
Epoch 277 / 500 | iteration 5 / 30 | Total Loss: 3.4813172817230225 | KNN Loss: 2.48212099075

Epoch 286 / 500 | iteration 20 / 30 | Total Loss: 3.4972965717315674 | KNN Loss: 2.476848602294922 | BCE Loss: 1.0204479694366455
Epoch 286 / 500 | iteration 25 / 30 | Total Loss: 3.4601330757141113 | KNN Loss: 2.4301607608795166 | BCE Loss: 1.0299721956253052
Epoch 287 / 500 | iteration 0 / 30 | Total Loss: 3.519158124923706 | KNN Loss: 2.491647958755493 | BCE Loss: 1.027510166168213
Epoch 287 / 500 | iteration 5 / 30 | Total Loss: 3.4610531330108643 | KNN Loss: 2.4582440853118896 | BCE Loss: 1.0028090476989746
Epoch 287 / 500 | iteration 10 / 30 | Total Loss: 3.4943246841430664 | KNN Loss: 2.492602586746216 | BCE Loss: 1.001721978187561
Epoch 287 / 500 | iteration 15 / 30 | Total Loss: 3.4334664344787598 | KNN Loss: 2.4535648822784424 | BCE Loss: 0.9799015522003174
Epoch 287 / 500 | iteration 20 / 30 | Total Loss: 3.4592132568359375 | KNN Loss: 2.4530551433563232 | BCE Loss: 1.0061581134796143
Epoch 287 / 500 | iteration 25 / 30 | Total Loss: 3.4757306575775146 | KNN Loss: 2.44038200

Epoch 297 / 500 | iteration 5 / 30 | Total Loss: 3.4732837677001953 | KNN Loss: 2.4351251125335693 | BCE Loss: 1.038158655166626
Epoch 297 / 500 | iteration 10 / 30 | Total Loss: 3.520873785018921 | KNN Loss: 2.508338212966919 | BCE Loss: 1.012535572052002
Epoch 297 / 500 | iteration 15 / 30 | Total Loss: 3.4514873027801514 | KNN Loss: 2.4574573040008545 | BCE Loss: 0.9940299987792969
Epoch 297 / 500 | iteration 20 / 30 | Total Loss: 3.513929843902588 | KNN Loss: 2.4694881439208984 | BCE Loss: 1.0444415807724
Epoch 297 / 500 | iteration 25 / 30 | Total Loss: 3.454925537109375 | KNN Loss: 2.4508113861083984 | BCE Loss: 1.0041142702102661
Epoch 298 / 500 | iteration 0 / 30 | Total Loss: 3.495342969894409 | KNN Loss: 2.4714572429656982 | BCE Loss: 1.023885726928711
Epoch 298 / 500 | iteration 5 / 30 | Total Loss: 3.5227975845336914 | KNN Loss: 2.4586291313171387 | BCE Loss: 1.0641684532165527
Epoch 298 / 500 | iteration 10 / 30 | Total Loss: 3.469710350036621 | KNN Loss: 2.453362464904785

Epoch 307 / 500 | iteration 20 / 30 | Total Loss: 3.468215227127075 | KNN Loss: 2.475731372833252 | BCE Loss: 0.992483913898468
Epoch 307 / 500 | iteration 25 / 30 | Total Loss: 3.533130407333374 | KNN Loss: 2.4826128482818604 | BCE Loss: 1.0505175590515137
Epoch 308 / 500 | iteration 0 / 30 | Total Loss: 3.495448589324951 | KNN Loss: 2.447493314743042 | BCE Loss: 1.0479552745819092
Epoch 308 / 500 | iteration 5 / 30 | Total Loss: 3.508626699447632 | KNN Loss: 2.489301919937134 | BCE Loss: 1.019324779510498
Epoch 308 / 500 | iteration 10 / 30 | Total Loss: 3.5021634101867676 | KNN Loss: 2.4864394664764404 | BCE Loss: 1.0157240629196167
Epoch 308 / 500 | iteration 15 / 30 | Total Loss: 3.4730708599090576 | KNN Loss: 2.4594860076904297 | BCE Loss: 1.013584852218628
Epoch 308 / 500 | iteration 20 / 30 | Total Loss: 3.4839322566986084 | KNN Loss: 2.47733473777771 | BCE Loss: 1.0065975189208984
Epoch 308 / 500 | iteration 25 / 30 | Total Loss: 3.4828743934631348 | KNN Loss: 2.45712018013000

Epoch 318 / 500 | iteration 5 / 30 | Total Loss: 3.4578561782836914 | KNN Loss: 2.4376814365386963 | BCE Loss: 1.0201748609542847
Epoch 318 / 500 | iteration 10 / 30 | Total Loss: 3.466712474822998 | KNN Loss: 2.4801025390625 | BCE Loss: 0.9866100549697876
Epoch 318 / 500 | iteration 15 / 30 | Total Loss: 3.475964307785034 | KNN Loss: 2.469437599182129 | BCE Loss: 1.0065267086029053
Epoch 318 / 500 | iteration 20 / 30 | Total Loss: 3.495105743408203 | KNN Loss: 2.4960546493530273 | BCE Loss: 0.999051034450531
Epoch 318 / 500 | iteration 25 / 30 | Total Loss: 3.4792048931121826 | KNN Loss: 2.440356492996216 | BCE Loss: 1.0388484001159668
Epoch 319 / 500 | iteration 0 / 30 | Total Loss: 3.4917168617248535 | KNN Loss: 2.4539999961853027 | BCE Loss: 1.0377168655395508
Epoch 319 / 500 | iteration 5 / 30 | Total Loss: 3.4845714569091797 | KNN Loss: 2.460747718811035 | BCE Loss: 1.023823857307434
Epoch 319 / 500 | iteration 10 / 30 | Total Loss: 3.4620461463928223 | KNN Loss: 2.46889448165893

Epoch 328 / 500 | iteration 20 / 30 | Total Loss: 3.4735074043273926 | KNN Loss: 2.4690756797790527 | BCE Loss: 1.0044316053390503
Epoch 328 / 500 | iteration 25 / 30 | Total Loss: 3.4629359245300293 | KNN Loss: 2.483506679534912 | BCE Loss: 0.9794292449951172
Epoch 329 / 500 | iteration 0 / 30 | Total Loss: 3.472550392150879 | KNN Loss: 2.446959972381592 | BCE Loss: 1.025590419769287
Epoch 329 / 500 | iteration 5 / 30 | Total Loss: 3.5078933238983154 | KNN Loss: 2.470393419265747 | BCE Loss: 1.0374999046325684
Epoch 329 / 500 | iteration 10 / 30 | Total Loss: 3.4523143768310547 | KNN Loss: 2.447262763977051 | BCE Loss: 1.0050514936447144
Epoch 329 / 500 | iteration 15 / 30 | Total Loss: 3.508495569229126 | KNN Loss: 2.444810628890991 | BCE Loss: 1.0636849403381348
Epoch 329 / 500 | iteration 20 / 30 | Total Loss: 3.4900126457214355 | KNN Loss: 2.489133834838867 | BCE Loss: 1.0008786916732788
Epoch 329 / 500 | iteration 25 / 30 | Total Loss: 3.5150203704833984 | KNN Loss: 2.48931789398

Epoch 339 / 500 | iteration 5 / 30 | Total Loss: 3.4773764610290527 | KNN Loss: 2.453782320022583 | BCE Loss: 1.0235941410064697
Epoch 339 / 500 | iteration 10 / 30 | Total Loss: 3.446526050567627 | KNN Loss: 2.4267659187316895 | BCE Loss: 1.019760251045227
Epoch 339 / 500 | iteration 15 / 30 | Total Loss: 3.4576079845428467 | KNN Loss: 2.4474525451660156 | BCE Loss: 1.010155439376831
Epoch 339 / 500 | iteration 20 / 30 | Total Loss: 3.4715757369995117 | KNN Loss: 2.4687118530273438 | BCE Loss: 1.0028637647628784
Epoch 339 / 500 | iteration 25 / 30 | Total Loss: 3.4999752044677734 | KNN Loss: 2.4689481258392334 | BCE Loss: 1.0310271978378296
Epoch 340 / 500 | iteration 0 / 30 | Total Loss: 3.458094835281372 | KNN Loss: 2.43710994720459 | BCE Loss: 1.0209848880767822
Epoch 340 / 500 | iteration 5 / 30 | Total Loss: 3.4802255630493164 | KNN Loss: 2.4444146156311035 | BCE Loss: 1.035810947418213
Epoch 340 / 500 | iteration 10 / 30 | Total Loss: 3.482548713684082 | KNN Loss: 2.459063291549

Epoch 349 / 500 | iteration 20 / 30 | Total Loss: 3.497955560684204 | KNN Loss: 2.4773476123809814 | BCE Loss: 1.0206079483032227
Epoch 349 / 500 | iteration 25 / 30 | Total Loss: 3.4847192764282227 | KNN Loss: 2.494239330291748 | BCE Loss: 0.9904800653457642
Epoch 350 / 500 | iteration 0 / 30 | Total Loss: 3.466337203979492 | KNN Loss: 2.4480340480804443 | BCE Loss: 1.0183031558990479
Epoch 350 / 500 | iteration 5 / 30 | Total Loss: 3.4647974967956543 | KNN Loss: 2.4471867084503174 | BCE Loss: 1.017610788345337
Epoch 350 / 500 | iteration 10 / 30 | Total Loss: 3.445967197418213 | KNN Loss: 2.442220449447632 | BCE Loss: 1.003746747970581
Epoch 350 / 500 | iteration 15 / 30 | Total Loss: 3.458871603012085 | KNN Loss: 2.441967248916626 | BCE Loss: 1.016904354095459
Epoch 350 / 500 | iteration 20 / 30 | Total Loss: 3.4091196060180664 | KNN Loss: 2.4076266288757324 | BCE Loss: 1.0014928579330444
Epoch 350 / 500 | iteration 25 / 30 | Total Loss: 3.4765677452087402 | KNN Loss: 2.445722579956

Epoch 360 / 500 | iteration 5 / 30 | Total Loss: 3.497386932373047 | KNN Loss: 2.449998140335083 | BCE Loss: 1.0473887920379639
Epoch 360 / 500 | iteration 10 / 30 | Total Loss: 3.4849693775177 | KNN Loss: 2.448499917984009 | BCE Loss: 1.0364694595336914
Epoch 360 / 500 | iteration 15 / 30 | Total Loss: 3.457946300506592 | KNN Loss: 2.4493463039398193 | BCE Loss: 1.008600115776062
Epoch 360 / 500 | iteration 20 / 30 | Total Loss: 3.478201389312744 | KNN Loss: 2.450784683227539 | BCE Loss: 1.0274165868759155
Epoch 360 / 500 | iteration 25 / 30 | Total Loss: 3.4787216186523438 | KNN Loss: 2.4399068355560303 | BCE Loss: 1.038814902305603
Epoch 361 / 500 | iteration 0 / 30 | Total Loss: 3.4922869205474854 | KNN Loss: 2.46174693107605 | BCE Loss: 1.0305399894714355
Epoch 361 / 500 | iteration 5 / 30 | Total Loss: 3.4928555488586426 | KNN Loss: 2.4699490070343018 | BCE Loss: 1.0229065418243408
Epoch 361 / 500 | iteration 10 / 30 | Total Loss: 3.477891445159912 | KNN Loss: 2.452258348464966 |

Epoch 370 / 500 | iteration 25 / 30 | Total Loss: 3.437398672103882 | KNN Loss: 2.4345626831054688 | BCE Loss: 1.002835988998413
Epoch   371: reducing learning rate of group 0 to 5.6994e-06.
Epoch 371 / 500 | iteration 0 / 30 | Total Loss: 3.4330203533172607 | KNN Loss: 2.45805287361145 | BCE Loss: 0.9749674797058105
Epoch 371 / 500 | iteration 5 / 30 | Total Loss: 3.4354896545410156 | KNN Loss: 2.4290504455566406 | BCE Loss: 1.006439208984375
Epoch 371 / 500 | iteration 10 / 30 | Total Loss: 3.511507749557495 | KNN Loss: 2.493105173110962 | BCE Loss: 1.0184025764465332
Epoch 371 / 500 | iteration 15 / 30 | Total Loss: 3.4580392837524414 | KNN Loss: 2.4449856281280518 | BCE Loss: 1.0130537748336792
Epoch 371 / 500 | iteration 20 / 30 | Total Loss: 3.4615907669067383 | KNN Loss: 2.4442012310028076 | BCE Loss: 1.0173895359039307
Epoch 371 / 500 | iteration 25 / 30 | Total Loss: 3.4890990257263184 | KNN Loss: 2.4638450145721436 | BCE Loss: 1.0252541303634644
Epoch 372 / 500 | iteration 0 

Epoch 381 / 500 | iteration 10 / 30 | Total Loss: 3.473507881164551 | KNN Loss: 2.446810007095337 | BCE Loss: 1.0266978740692139
Epoch 381 / 500 | iteration 15 / 30 | Total Loss: 3.463573694229126 | KNN Loss: 2.440509796142578 | BCE Loss: 1.0230638980865479
Epoch 381 / 500 | iteration 20 / 30 | Total Loss: 3.500549077987671 | KNN Loss: 2.464442491531372 | BCE Loss: 1.0361065864562988
Epoch 381 / 500 | iteration 25 / 30 | Total Loss: 3.4594476222991943 | KNN Loss: 2.4439120292663574 | BCE Loss: 1.015535593032837
Epoch   382: reducing learning rate of group 0 to 3.9896e-06.
Epoch 382 / 500 | iteration 0 / 30 | Total Loss: 3.5113072395324707 | KNN Loss: 2.4863266944885254 | BCE Loss: 1.0249804258346558
Epoch 382 / 500 | iteration 5 / 30 | Total Loss: 3.4707465171813965 | KNN Loss: 2.4653351306915283 | BCE Loss: 1.0054112672805786
Epoch 382 / 500 | iteration 10 / 30 | Total Loss: 3.438887119293213 | KNN Loss: 2.4046576023101807 | BCE Loss: 1.0342293977737427
Epoch 382 / 500 | iteration 15 

Epoch 391 / 500 | iteration 25 / 30 | Total Loss: 3.487957000732422 | KNN Loss: 2.487403154373169 | BCE Loss: 1.000553846359253
Epoch 392 / 500 | iteration 0 / 30 | Total Loss: 3.4655370712280273 | KNN Loss: 2.4671952724456787 | BCE Loss: 0.9983417987823486
Epoch 392 / 500 | iteration 5 / 30 | Total Loss: 3.490931510925293 | KNN Loss: 2.456716299057007 | BCE Loss: 1.0342150926589966
Epoch 392 / 500 | iteration 10 / 30 | Total Loss: 3.444540500640869 | KNN Loss: 2.4568886756896973 | BCE Loss: 0.9876517653465271
Epoch 392 / 500 | iteration 15 / 30 | Total Loss: 3.4635839462280273 | KNN Loss: 2.4381258487701416 | BCE Loss: 1.0254582166671753
Epoch 392 / 500 | iteration 20 / 30 | Total Loss: 3.4503414630889893 | KNN Loss: 2.431884527206421 | BCE Loss: 1.0184569358825684
Epoch 392 / 500 | iteration 25 / 30 | Total Loss: 3.4753336906433105 | KNN Loss: 2.430148124694824 | BCE Loss: 1.0451855659484863
Epoch   393: reducing learning rate of group 0 to 2.7927e-06.
Epoch 393 / 500 | iteration 0 /

Epoch 402 / 500 | iteration 10 / 30 | Total Loss: 3.4767160415649414 | KNN Loss: 2.45137882232666 | BCE Loss: 1.0253373384475708
Epoch 402 / 500 | iteration 15 / 30 | Total Loss: 3.5057761669158936 | KNN Loss: 2.5011417865753174 | BCE Loss: 1.0046343803405762
Epoch 402 / 500 | iteration 20 / 30 | Total Loss: 3.4773662090301514 | KNN Loss: 2.4580719470977783 | BCE Loss: 1.019294261932373
Epoch 402 / 500 | iteration 25 / 30 | Total Loss: 3.490044593811035 | KNN Loss: 2.4456827640533447 | BCE Loss: 1.0443618297576904
Epoch 403 / 500 | iteration 0 / 30 | Total Loss: 3.489441156387329 | KNN Loss: 2.4473602771759033 | BCE Loss: 1.0420808792114258
Epoch 403 / 500 | iteration 5 / 30 | Total Loss: 3.4443180561065674 | KNN Loss: 2.4254701137542725 | BCE Loss: 1.018847942352295
Epoch 403 / 500 | iteration 10 / 30 | Total Loss: 3.438912868499756 | KNN Loss: 2.430589437484741 | BCE Loss: 1.008323311805725
Epoch 403 / 500 | iteration 15 / 30 | Total Loss: 3.4608047008514404 | KNN Loss: 2.47040867805

Epoch 412 / 500 | iteration 25 / 30 | Total Loss: 3.5131309032440186 | KNN Loss: 2.4810879230499268 | BCE Loss: 1.0320429801940918
Epoch 413 / 500 | iteration 0 / 30 | Total Loss: 3.466654062271118 | KNN Loss: 2.4532883167266846 | BCE Loss: 1.0133657455444336
Epoch 413 / 500 | iteration 5 / 30 | Total Loss: 3.483829975128174 | KNN Loss: 2.4658114910125732 | BCE Loss: 1.0180184841156006
Epoch 413 / 500 | iteration 10 / 30 | Total Loss: 3.467008113861084 | KNN Loss: 2.447864532470703 | BCE Loss: 1.0191435813903809
Epoch 413 / 500 | iteration 15 / 30 | Total Loss: 3.456756114959717 | KNN Loss: 2.4543142318725586 | BCE Loss: 1.0024418830871582
Epoch 413 / 500 | iteration 20 / 30 | Total Loss: 3.4565889835357666 | KNN Loss: 2.457531452178955 | BCE Loss: 0.9990575313568115
Epoch 413 / 500 | iteration 25 / 30 | Total Loss: 3.4298341274261475 | KNN Loss: 2.4092140197753906 | BCE Loss: 1.0206201076507568
Epoch 414 / 500 | iteration 0 / 30 | Total Loss: 3.480556011199951 | KNN Loss: 2.4865055084

Epoch 423 / 500 | iteration 10 / 30 | Total Loss: 3.4923417568206787 | KNN Loss: 2.461247682571411 | BCE Loss: 1.0310940742492676
Epoch 423 / 500 | iteration 15 / 30 | Total Loss: 3.495669364929199 | KNN Loss: 2.46765398979187 | BCE Loss: 1.028015375137329
Epoch 423 / 500 | iteration 20 / 30 | Total Loss: 3.4793505668640137 | KNN Loss: 2.4492850303649902 | BCE Loss: 1.0300655364990234
Epoch 423 / 500 | iteration 25 / 30 | Total Loss: 3.4683918952941895 | KNN Loss: 2.461319923400879 | BCE Loss: 1.0070719718933105
Epoch 424 / 500 | iteration 0 / 30 | Total Loss: 3.4450132846832275 | KNN Loss: 2.4509239196777344 | BCE Loss: 0.9940893054008484
Epoch 424 / 500 | iteration 5 / 30 | Total Loss: 3.461169719696045 | KNN Loss: 2.4487059116363525 | BCE Loss: 1.0124636888504028
Epoch 424 / 500 | iteration 10 / 30 | Total Loss: 3.469176769256592 | KNN Loss: 2.469172954559326 | BCE Loss: 1.0000039339065552
Epoch 424 / 500 | iteration 15 / 30 | Total Loss: 3.443237781524658 | KNN Loss: 2.435686111450

Epoch 433 / 500 | iteration 25 / 30 | Total Loss: 3.4620742797851562 | KNN Loss: 2.4270942211151123 | BCE Loss: 1.034980058670044
Epoch 434 / 500 | iteration 0 / 30 | Total Loss: 3.4873569011688232 | KNN Loss: 2.4771249294281006 | BCE Loss: 1.0102319717407227
Epoch 434 / 500 | iteration 5 / 30 | Total Loss: 3.46212100982666 | KNN Loss: 2.431734085083008 | BCE Loss: 1.0303869247436523
Epoch 434 / 500 | iteration 10 / 30 | Total Loss: 3.48880672454834 | KNN Loss: 2.4545857906341553 | BCE Loss: 1.0342210531234741
Epoch 434 / 500 | iteration 15 / 30 | Total Loss: 3.4438915252685547 | KNN Loss: 2.4465861320495605 | BCE Loss: 0.9973052740097046
Epoch 434 / 500 | iteration 20 / 30 | Total Loss: 3.4529895782470703 | KNN Loss: 2.434837818145752 | BCE Loss: 1.0181517601013184
Epoch 434 / 500 | iteration 25 / 30 | Total Loss: 3.461735248565674 | KNN Loss: 2.4463207721710205 | BCE Loss: 1.0154144763946533
Epoch 435 / 500 | iteration 0 / 30 | Total Loss: 3.4560623168945312 | KNN Loss: 2.45732069015

Epoch 444 / 500 | iteration 10 / 30 | Total Loss: 3.449402093887329 | KNN Loss: 2.4508118629455566 | BCE Loss: 0.9985901713371277
Epoch 444 / 500 | iteration 15 / 30 | Total Loss: 3.4989774227142334 | KNN Loss: 2.4647326469421387 | BCE Loss: 1.0342447757720947
Epoch 444 / 500 | iteration 20 / 30 | Total Loss: 3.493281364440918 | KNN Loss: 2.472656011581421 | BCE Loss: 1.0206252336502075
Epoch 444 / 500 | iteration 25 / 30 | Total Loss: 3.5535709857940674 | KNN Loss: 2.5044548511505127 | BCE Loss: 1.0491161346435547
Epoch 445 / 500 | iteration 0 / 30 | Total Loss: 3.527285575866699 | KNN Loss: 2.5087225437164307 | BCE Loss: 1.0185630321502686
Epoch 445 / 500 | iteration 5 / 30 | Total Loss: 3.4583017826080322 | KNN Loss: 2.4532322883605957 | BCE Loss: 1.0050694942474365
Epoch 445 / 500 | iteration 10 / 30 | Total Loss: 3.4738872051239014 | KNN Loss: 2.452670097351074 | BCE Loss: 1.0212171077728271
Epoch 445 / 500 | iteration 15 / 30 | Total Loss: 3.4590539932250977 | KNN Loss: 2.4362812

Epoch 454 / 500 | iteration 25 / 30 | Total Loss: 3.488745927810669 | KNN Loss: 2.463747024536133 | BCE Loss: 1.0249989032745361
Epoch 455 / 500 | iteration 0 / 30 | Total Loss: 3.4412522315979004 | KNN Loss: 2.463918924331665 | BCE Loss: 0.9773333072662354
Epoch 455 / 500 | iteration 5 / 30 | Total Loss: 3.467306613922119 | KNN Loss: 2.4757308959960938 | BCE Loss: 0.9915756583213806
Epoch 455 / 500 | iteration 10 / 30 | Total Loss: 3.512146234512329 | KNN Loss: 2.493724822998047 | BCE Loss: 1.0184214115142822
Epoch 455 / 500 | iteration 15 / 30 | Total Loss: 3.498267650604248 | KNN Loss: 2.504296064376831 | BCE Loss: 0.9939714670181274
Epoch 455 / 500 | iteration 20 / 30 | Total Loss: 3.4648666381835938 | KNN Loss: 2.4375252723693848 | BCE Loss: 1.0273414850234985
Epoch 455 / 500 | iteration 25 / 30 | Total Loss: 3.4987854957580566 | KNN Loss: 2.466909170150757 | BCE Loss: 1.0318762063980103
Epoch 456 / 500 | iteration 0 / 30 | Total Loss: 3.4794301986694336 | KNN Loss: 2.469111204147

Epoch 465 / 500 | iteration 10 / 30 | Total Loss: 3.4834883213043213 | KNN Loss: 2.4896440505981445 | BCE Loss: 0.9938442707061768
Epoch 465 / 500 | iteration 15 / 30 | Total Loss: 3.5023272037506104 | KNN Loss: 2.4970390796661377 | BCE Loss: 1.0052881240844727
Epoch 465 / 500 | iteration 20 / 30 | Total Loss: 3.4285550117492676 | KNN Loss: 2.429630756378174 | BCE Loss: 0.9989243149757385
Epoch 465 / 500 | iteration 25 / 30 | Total Loss: 3.4695472717285156 | KNN Loss: 2.4654178619384766 | BCE Loss: 1.0041295289993286
Epoch 466 / 500 | iteration 0 / 30 | Total Loss: 3.4450631141662598 | KNN Loss: 2.422218084335327 | BCE Loss: 1.0228450298309326
Epoch 466 / 500 | iteration 5 / 30 | Total Loss: 3.459489107131958 | KNN Loss: 2.458935022354126 | BCE Loss: 1.000554084777832
Epoch 466 / 500 | iteration 10 / 30 | Total Loss: 3.4836134910583496 | KNN Loss: 2.457202911376953 | BCE Loss: 1.0264105796813965
Epoch 466 / 500 | iteration 15 / 30 | Total Loss: 3.495404005050659 | KNN Loss: 2.492884159

Epoch 475 / 500 | iteration 25 / 30 | Total Loss: 3.5007741451263428 | KNN Loss: 2.466007947921753 | BCE Loss: 1.0347661972045898
Epoch 476 / 500 | iteration 0 / 30 | Total Loss: 3.478994846343994 | KNN Loss: 2.447139263153076 | BCE Loss: 1.0318554639816284
Epoch 476 / 500 | iteration 5 / 30 | Total Loss: 3.5263655185699463 | KNN Loss: 2.5072953701019287 | BCE Loss: 1.0190701484680176
Epoch 476 / 500 | iteration 10 / 30 | Total Loss: 3.481199264526367 | KNN Loss: 2.435866355895996 | BCE Loss: 1.0453327894210815
Epoch 476 / 500 | iteration 15 / 30 | Total Loss: 3.447815179824829 | KNN Loss: 2.4393181800842285 | BCE Loss: 1.0084969997406006
Epoch 476 / 500 | iteration 20 / 30 | Total Loss: 3.46799373626709 | KNN Loss: 2.4431240558624268 | BCE Loss: 1.0248697996139526
Epoch 476 / 500 | iteration 25 / 30 | Total Loss: 3.462148427963257 | KNN Loss: 2.473351240158081 | BCE Loss: 0.988797128200531
Epoch 477 / 500 | iteration 0 / 30 | Total Loss: 3.4834818840026855 | KNN Loss: 2.47166061401367

Epoch 486 / 500 | iteration 10 / 30 | Total Loss: 3.4681904315948486 | KNN Loss: 2.44805908203125 | BCE Loss: 1.0201313495635986
Epoch 486 / 500 | iteration 15 / 30 | Total Loss: 3.4739770889282227 | KNN Loss: 2.4568309783935547 | BCE Loss: 1.017146110534668
Epoch 486 / 500 | iteration 20 / 30 | Total Loss: 3.452244997024536 | KNN Loss: 2.422790050506592 | BCE Loss: 1.0294549465179443
Epoch 486 / 500 | iteration 25 / 30 | Total Loss: 3.4274847507476807 | KNN Loss: 2.4373679161071777 | BCE Loss: 0.9901167750358582
Epoch 487 / 500 | iteration 0 / 30 | Total Loss: 3.5117275714874268 | KNN Loss: 2.470431089401245 | BCE Loss: 1.0412964820861816
Epoch 487 / 500 | iteration 5 / 30 | Total Loss: 3.469747543334961 | KNN Loss: 2.4631640911102295 | BCE Loss: 1.0065834522247314
Epoch 487 / 500 | iteration 10 / 30 | Total Loss: 3.432394504547119 | KNN Loss: 2.424036741256714 | BCE Loss: 1.0083577632904053
Epoch 487 / 500 | iteration 15 / 30 | Total Loss: 3.464578151702881 | KNN Loss: 2.436367988586

Epoch 496 / 500 | iteration 25 / 30 | Total Loss: 3.4983267784118652 | KNN Loss: 2.4660351276397705 | BCE Loss: 1.0322915315628052
Epoch 497 / 500 | iteration 0 / 30 | Total Loss: 3.459026336669922 | KNN Loss: 2.448936700820923 | BCE Loss: 1.010089635848999
Epoch 497 / 500 | iteration 5 / 30 | Total Loss: 3.456577777862549 | KNN Loss: 2.4369843006134033 | BCE Loss: 1.019593358039856
Epoch 497 / 500 | iteration 10 / 30 | Total Loss: 3.4633026123046875 | KNN Loss: 2.4450020790100098 | BCE Loss: 1.0183006525039673
Epoch 497 / 500 | iteration 15 / 30 | Total Loss: 3.49482798576355 | KNN Loss: 2.481560707092285 | BCE Loss: 1.0132672786712646
Epoch 497 / 500 | iteration 20 / 30 | Total Loss: 3.4193570613861084 | KNN Loss: 2.4318349361419678 | BCE Loss: 0.9875220656394958
Epoch 497 / 500 | iteration 25 / 30 | Total Loss: 3.443190097808838 | KNN Loss: 2.4283461570739746 | BCE Loss: 1.0148439407348633
Epoch 498 / 500 | iteration 0 / 30 | Total Loss: 3.477756977081299 | KNN Loss: 2.4426386356353

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.4856,  2.3745,  2.7505,  3.4871,  3.6530,  0.4236,  2.7899,  2.3621,
          2.5124,  2.1661,  2.3326,  2.4144,  0.6715,  1.7256,  1.4388,  1.7747,
          2.8992,  3.2629,  2.5353,  2.0810,  1.7299,  3.1325,  2.4733,  2.7931,
          2.7131,  1.8447,  1.3999,  1.1521,  1.0779,  0.2516, -0.5336,  1.1434,
          0.1360,  1.0365,  1.1041,  1.1722,  1.2571,  2.9358,  0.5095,  0.9384,
          0.7459, -0.5871, -0.1879,  2.4866,  2.4053,  0.7438, -0.2547, -0.0358,
          1.0731,  2.3665,  2.0010,  0.1006,  1.6326,  0.4142, -0.5522,  0.5195,
          1.5745,  1.3536,  1.5414,  1.1375,  0.4956,  1.0106,  0.2676,  1.3019,
          1.4890,  1.8231, -2.0561,  0.0387,  2.0633,  1.9456,  2.1997,  0.2808,
          1.5345,  2.7063,  1.9748,  1.0967,  0.2422,  0.7985,  0.2194,  1.6211,
          0.1200,  0.5925,  2.0239, -0.3642,  0.3667, -0.9796, -2.3407, -0.5386,
          0.2856, -1.8904,  0.4330, -0.1589, -0.9170, -0.7434,  0.4221,  1.2528,
         -1.2124, -0.8169,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 79.45it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 003 | Total loss: 9.629 | Reg loss: 0.014 | Tree loss: 9.629 | Accuracy: 0.000000 | 3.81 sec/iter
Epoch: 00 | Batch: 001 / 003 | Total loss: 9.627 | Reg loss: 0.013 | Tree loss: 9.627 | Accuracy: 0.000000 | 3.735 sec/iter
Epoch: 00 | Batch: 002 / 003 | Total loss: 9.623 | Reg loss: 0.012 | Tree loss: 9.623 | Accuracy: 0.000000 | 3.564 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 01 | Batch: 000 / 003 | Total loss: 9.623 | Reg loss: 0.003 | Tree loss: 9.623 | Accuracy: 0.000000 | 3.722

Epoch: 11 | Batch: 001 / 003 | Total loss: 9.586 | Reg loss: 0.004 | Tree loss: 9.586 | Accuracy: 0.234375 | 3.685 sec/iter
Epoch: 11 | Batch: 002 / 003 | Total loss: 9.583 | Reg loss: 0.004 | Tree loss: 9.583 | Accuracy: 0.229814 | 3.672 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 12 | Batch: 000 / 003 | Total loss: 9.585 | Reg loss: 0.004 | Tree loss: 9.585 | Accuracy: 0.230469 | 3.687 sec/iter
Epoch: 12 | Batch: 001 / 003 | Total loss: 9.583 | Reg loss: 0.004 | Tree loss: 9.583 | Accuracy: 0.230469 | 3.685 sec/iter
Epoch: 12 | Batch: 002 / 003 | Total loss: 9.580 | Reg loss: 0.004 | Tree loss: 9.580 | Accuracy: 0.245342 | 3.673 sec/iter
Average sparseness: 0.98214285714

Epoch: 23 | Batch: 000 / 003 | Total loss: 9.562 | Reg loss: 0.006 | Tree loss: 9.562 | Accuracy: 0.234375 | 3.687 sec/iter
Epoch: 23 | Batch: 001 / 003 | Total loss: 9.562 | Reg loss: 0.006 | Tree loss: 9.562 | Accuracy: 0.226562 | 3.686 sec/iter
Epoch: 23 | Batch: 002 / 003 | Total loss: 9.556 | Reg loss: 0.006 | Tree loss: 9.556 | Accuracy: 0.245342 | 3.68 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 24 | Batch: 000 / 003 | Total loss: 9.563 | Reg loss: 0.006 | Tree loss: 9.563 | Accuracy: 0.212891 | 3.687 sec/iter
Epoch: 24 | Batch: 001 / 003 | Total loss: 9.557 | Reg loss: 0.006 | Tree loss: 9.557 | Accuracy: 0.242188 | 3.686 sec/iter
Epoch: 24 | Batch: 002 / 003 | Tot

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 35 | Batch: 000 / 003 | Total loss: 9.527 | Reg loss: 0.008 | Tree loss: 9.527 | Accuracy: 0.255859 | 3.685 sec/iter
Epoch: 35 | Batch: 001 / 003 | Total loss: 9.527 | Reg loss: 0.009 | Tree loss: 9.527 | Accuracy: 0.232422 | 3.684 sec/iter
Epoch: 35 | Batch: 002 / 003 | Total loss: 9.524 | Reg loss: 0.009 | Tree loss: 9.524 | Accuracy: 0.201863 | 3.68 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.982142857

Epoch: 46 | Batch: 000 / 003 | Total loss: 9.456 | Reg loss: 0.011 | Tree loss: 9.456 | Accuracy: 0.214844 | 3.682 sec/iter
Epoch: 46 | Batch: 001 / 003 | Total loss: 9.435 | Reg loss: 0.011 | Tree loss: 9.435 | Accuracy: 0.253906 | 3.682 sec/iter
Epoch: 46 | Batch: 002 / 003 | Total loss: 9.422 | Reg loss: 0.011 | Tree loss: 9.422 | Accuracy: 0.232919 | 3.678 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 47 | Batch: 000 / 003 | Total loss: 9.435 | Reg loss: 0.011 | Tree loss: 9.435 | Accuracy: 0.244141 | 3.682 sec/iter
Epoch: 47 | Batch: 001 / 003 | Total loss: 9.425 | Reg loss: 0.011 | Tree loss: 9.425 | Accuracy: 0.228516 | 3.682 sec/iter
Epoch: 47 | Batch: 002 / 003 | To

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 58 | Batch: 000 / 003 | Total loss: 9.240 | Reg loss: 0.013 | Tree loss: 9.240 | Accuracy: 0.226562 | 3.682 sec/iter
Epoch: 58 | Batch: 001 / 003 | Total loss: 9.212 | Reg loss: 0.013 | Tree loss: 9.212 | Accuracy: 0.236328 | 3.681 sec/iter
Epoch: 58 | Batch: 002 / 003 | Total loss: 9.186 | Reg loss: 0.013 | Tree loss: 9.186 | Accuracy: 0.242236 | 3.679 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.98214285

Epoch: 69 | Batch: 000 / 003 | Total loss: 8.895 | Reg loss: 0.015 | Tree loss: 8.895 | Accuracy: 0.253906 | 3.683 sec/iter
Epoch: 69 | Batch: 001 / 003 | Total loss: 8.876 | Reg loss: 0.015 | Tree loss: 8.876 | Accuracy: 0.234375 | 3.683 sec/iter
Epoch: 69 | Batch: 002 / 003 | Total loss: 8.857 | Reg loss: 0.015 | Tree loss: 8.857 | Accuracy: 0.198758 | 3.68 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 70 | Batch: 000 / 003 | Total loss: 8.872 | Reg loss: 0.015 | Tree loss: 8.872 | Accuracy: 0.246094 | 3.683 sec/iter
Epoch: 70 | Batch: 001 / 003 | Total loss: 8.832 | Reg loss: 0.015 | Tree loss: 8.832 | Accuracy: 0.234375 | 3.683 sec/iter
Epoch: 70 | Batch: 002 / 003 | Tot

Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 81 | Batch: 000 / 003 | Total loss: 8.444 | Reg loss: 0.017 | Tree loss: 8.444 | Accuracy: 0.242188 | 3.682 sec/iter
Epoch: 81 | Batch: 001 / 003 | Total loss: 8.419 | Reg loss: 0.017 | Tree loss: 8.419 | Accuracy: 0.220703 | 3.682 sec/iter
Epoch: 81 | Batch: 002 / 003 | Total loss: 8.377 | Reg loss: 0.017 | Tree loss: 8.377 | Accuracy: 0.229814 | 3.68 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.982142857

Epoch: 92 | Batch: 000 / 003 | Total loss: 7.985 | Reg loss: 0.018 | Tree loss: 7.985 | Accuracy: 0.244141 | 3.682 sec/iter
Epoch: 92 | Batch: 001 / 003 | Total loss: 7.962 | Reg loss: 0.018 | Tree loss: 7.962 | Accuracy: 0.205078 | 3.682 sec/iter
Epoch: 92 | Batch: 002 / 003 | Total loss: 7.943 | Reg loss: 0.018 | Tree loss: 7.943 | Accuracy: 0.195652 | 3.68 sec/iter
Average sparseness: 0.9821428571428572
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
layer 9: 0.9821428571428573
layer 10: 0.9821428571428573
Epoch: 93 | Batch: 000 / 003 | Total loss: 7.948 | Reg loss: 0.018 | Tree loss: 7.948 | Accuracy: 0.220703 | 3.682 sec/iter
Epoch: 93 | Batch: 001 / 003 | Total loss: 7.936 | Reg loss: 0.018 | Tree loss: 7.936 | Accuracy: 0.205078 | 3.682 sec/iter
Epoch: 93 | Batch: 002 / 003 | Tot

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 11.999511480214949


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 4094


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")



  return np.log(1 / (1 - x))


1346












Average comprehensibility: 53.49438202247191
std comprehensibility: 2.2039052908750985
var comprehensibility: 4.857198531147253
minimum comprehensibility: 44
maximum comprehensibility: 60
