In [6]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
k = 4
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [13]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [14]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [15]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [16]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 2.0072975158691406 | KNN Loss: 6.227202892303467 | BCE Loss: 2.0072975158691406
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 1.9876428842544556 | KNN Loss: 6.2271728515625 | BCE Loss: 1.9876428842544556
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 1.9654991626739502 | KNN Loss: 6.227078914642334 | BCE Loss: 1.9654991626739502
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 2.007986545562744 | KNN Loss: 6.227046489715576 | BCE Loss: 2.007986545562744
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 1.93807053565979 | KNN Loss: 6.227117538452148 | BCE Loss: 1.93807053565979
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 1.890448808670044 | KNN Loss: 6.2272539138793945 | BCE Loss: 1.890448808670044
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 1.8803540468215942 | KNN Loss: 6.227087020874023 | BCE Loss: 1.8803540468215942
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 1.8928595781326294 | KNN Loss: 6.227339267730713 | BCE Loss: 1.892859

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 1.1139508485794067 | KNN Loss: 6.227087020874023 | BCE Loss: 1.1139508485794067
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 1.1409372091293335 | KNN Loss: 6.227234840393066 | BCE Loss: 1.1409372091293335
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 1.1305830478668213 | KNN Loss: 6.227318286895752 | BCE Loss: 1.1305830478668213
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 1.154988408088684 | KNN Loss: 6.227245807647705 | BCE Loss: 1.154988408088684
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 1.1592546701431274 | KNN Loss: 6.227466106414795 | BCE Loss: 1.1592546701431274
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 1.1000138521194458 | KNN Loss: 6.227219104766846 | BCE Loss: 1.1000138521194458
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 1.1004137992858887 | KNN Loss: 6.227105617523193 | BCE Loss: 1.1004137992858887
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 1.113079309463501 | KNN Loss: 6.2272515296936035 | BC

Epoch 21 / 500 | iteration 15 / 30 | Total Loss: 1.0572924613952637 | KNN Loss: 6.226929664611816 | BCE Loss: 1.0572924613952637
Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 1.0835916996002197 | KNN Loss: 6.226919174194336 | BCE Loss: 1.0835916996002197
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 1.0618128776550293 | KNN Loss: 6.227316379547119 | BCE Loss: 1.0618128776550293
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 1.0492719411849976 | KNN Loss: 6.227105617523193 | BCE Loss: 1.0492719411849976
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 1.0456345081329346 | KNN Loss: 6.227390289306641 | BCE Loss: 1.0456345081329346
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 1.0723578929901123 | KNN Loss: 6.227113723754883 | BCE Loss: 1.0723578929901123
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 1.0711735486984253 | KNN Loss: 6.227255821228027 | BCE Loss: 1.0711735486984253
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 1.0376567840576172 | KNN Loss: 6.227641582489014 |

Epoch 32 / 500 | iteration 5 / 30 | Total Loss: 1.0753090381622314 | KNN Loss: 6.227157115936279 | BCE Loss: 1.0753090381622314
Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 1.029982089996338 | KNN Loss: 6.227000713348389 | BCE Loss: 1.029982089996338
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 1.0515832901000977 | KNN Loss: 6.227030277252197 | BCE Loss: 1.0515832901000977
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 1.0583832263946533 | KNN Loss: 6.2273454666137695 | BCE Loss: 1.0583832263946533
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 1.0431360006332397 | KNN Loss: 6.22691535949707 | BCE Loss: 1.0431360006332397
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 1.018707513809204 | KNN Loss: 6.227166652679443 | BCE Loss: 1.018707513809204
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 1.0483719110488892 | KNN Loss: 6.227583885192871 | BCE Loss: 1.0483719110488892
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 1.065049171447754 | KNN Loss: 6.227123737335205 | BCE L

Epoch 42 / 500 | iteration 25 / 30 | Total Loss: 1.0463571548461914 | KNN Loss: 6.22697114944458 | BCE Loss: 1.0463571548461914
Epoch 43 / 500 | iteration 0 / 30 | Total Loss: 1.0512670278549194 | KNN Loss: 6.227015495300293 | BCE Loss: 1.0512670278549194
Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 1.04410719871521 | KNN Loss: 6.226932048797607 | BCE Loss: 1.04410719871521
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 1.048304557800293 | KNN Loss: 6.226841449737549 | BCE Loss: 1.048304557800293
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 1.0363131761550903 | KNN Loss: 6.22735071182251 | BCE Loss: 1.0363131761550903
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 1.0411527156829834 | KNN Loss: 6.226919174194336 | BCE Loss: 1.0411527156829834
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 1.091540813446045 | KNN Loss: 6.227146625518799 | BCE Loss: 1.091540813446045
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 1.0506095886230469 | KNN Loss: 6.22714900970459 | BCE Loss: 1

Epoch 53 / 500 | iteration 20 / 30 | Total Loss: 1.0615265369415283 | KNN Loss: 6.227290153503418 | BCE Loss: 1.0615265369415283
Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 1.0531423091888428 | KNN Loss: 6.227272033691406 | BCE Loss: 1.0531423091888428
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 1.0578547716140747 | KNN Loss: 6.227115631103516 | BCE Loss: 1.0578547716140747
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 1.0479074716567993 | KNN Loss: 6.227134704589844 | BCE Loss: 1.0479074716567993
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 1.0590124130249023 | KNN Loss: 6.227303981781006 | BCE Loss: 1.0590124130249023
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 1.05381441116333 | KNN Loss: 6.226988315582275 | BCE Loss: 1.05381441116333
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 1.0622897148132324 | KNN Loss: 6.226889133453369 | BCE Loss: 1.0622897148132324
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 1.0240764617919922 | KNN Loss: 6.2272491455078125 | BC

Epoch 64 / 500 | iteration 10 / 30 | Total Loss: 1.0658824443817139 | KNN Loss: 6.227335453033447 | BCE Loss: 1.0658824443817139
Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 1.0483509302139282 | KNN Loss: 6.227010250091553 | BCE Loss: 1.0483509302139282
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 1.0740257501602173 | KNN Loss: 6.22695779800415 | BCE Loss: 1.0740257501602173
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 1.026416301727295 | KNN Loss: 6.22708797454834 | BCE Loss: 1.026416301727295
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 1.071574330329895 | KNN Loss: 6.226841926574707 | BCE Loss: 1.071574330329895
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 1.0505361557006836 | KNN Loss: 6.226873874664307 | BCE Loss: 1.0505361557006836
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 1.0773199796676636 | KNN Loss: 6.227167129516602 | BCE Loss: 1.0773199796676636
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 1.0297532081604004 | KNN Loss: 6.2273783683776855 | BCE 

Epoch 75 / 500 | iteration 0 / 30 | Total Loss: 1.058895468711853 | KNN Loss: 6.226919651031494 | BCE Loss: 1.058895468711853
Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 1.0337564945220947 | KNN Loss: 6.226998805999756 | BCE Loss: 1.0337564945220947
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 1.0652885437011719 | KNN Loss: 6.227047443389893 | BCE Loss: 1.0652885437011719
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 1.0697295665740967 | KNN Loss: 6.226920127868652 | BCE Loss: 1.0697295665740967
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 1.0246686935424805 | KNN Loss: 6.2267889976501465 | BCE Loss: 1.0246686935424805
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 1.080754280090332 | KNN Loss: 6.227147102355957 | BCE Loss: 1.080754280090332
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 1.0415258407592773 | KNN Loss: 6.226583957672119 | BCE Loss: 1.0415258407592773
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 1.0682053565979004 | KNN Loss: 6.227074146270752 | BCE 

Epoch 85 / 500 | iteration 20 / 30 | Total Loss: 1.0773146152496338 | KNN Loss: 6.226704120635986 | BCE Loss: 1.0773146152496338
Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 1.0384821891784668 | KNN Loss: 6.226852893829346 | BCE Loss: 1.0384821891784668
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 1.066483974456787 | KNN Loss: 6.226875305175781 | BCE Loss: 1.066483974456787
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 1.0413341522216797 | KNN Loss: 6.227021217346191 | BCE Loss: 1.0413341522216797
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 1.0411936044692993 | KNN Loss: 6.22705602645874 | BCE Loss: 1.0411936044692993
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 1.084040641784668 | KNN Loss: 6.227148056030273 | BCE Loss: 1.084040641784668
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 1.0611236095428467 | KNN Loss: 6.2269768714904785 | BCE Loss: 1.0611236095428467
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 1.0695083141326904 | KNN Loss: 6.227015495300293 | BCE

Epoch 96 / 500 | iteration 10 / 30 | Total Loss: 1.0603450536727905 | KNN Loss: 6.22675085067749 | BCE Loss: 1.0603450536727905
Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 1.050800085067749 | KNN Loss: 6.2269086837768555 | BCE Loss: 1.050800085067749
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 1.0563969612121582 | KNN Loss: 6.226683139801025 | BCE Loss: 1.0563969612121582
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 1.0643882751464844 | KNN Loss: 6.226886749267578 | BCE Loss: 1.0643882751464844
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 1.0618863105773926 | KNN Loss: 6.226844310760498 | BCE Loss: 1.0618863105773926
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 1.0626922845840454 | KNN Loss: 6.226853847503662 | BCE Loss: 1.0626922845840454
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 1.0466418266296387 | KNN Loss: 6.22709321975708 | BCE Loss: 1.0466418266296387
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 1.0199904441833496 | KNN Loss: 6.227088928222656 | BC

Epoch 107 / 500 | iteration 0 / 30 | Total Loss: 1.045440912246704 | KNN Loss: 6.226930618286133 | BCE Loss: 1.045440912246704
Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 1.036458969116211 | KNN Loss: 6.226829528808594 | BCE Loss: 1.036458969116211
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 1.0731292963027954 | KNN Loss: 6.226868152618408 | BCE Loss: 1.0731292963027954
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 1.0812146663665771 | KNN Loss: 6.2269978523254395 | BCE Loss: 1.0812146663665771
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 1.0514521598815918 | KNN Loss: 6.227014064788818 | BCE Loss: 1.0514521598815918
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 1.0046002864837646 | KNN Loss: 6.226696014404297 | BCE Loss: 1.0046002864837646
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 1.0442957878112793 | KNN Loss: 6.227020263671875 | BCE Loss: 1.0442957878112793
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 1.0481359958648682 | KNN Loss: 6.22673892974853

Epoch 117 / 500 | iteration 15 / 30 | Total Loss: 1.044250249862671 | KNN Loss: 6.226868629455566 | BCE Loss: 1.044250249862671
Epoch 117 / 500 | iteration 20 / 30 | Total Loss: 1.0394551753997803 | KNN Loss: 6.227016448974609 | BCE Loss: 1.0394551753997803
Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 1.0605182647705078 | KNN Loss: 6.2265448570251465 | BCE Loss: 1.0605182647705078
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 1.0465468168258667 | KNN Loss: 6.226943016052246 | BCE Loss: 1.0465468168258667
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 1.057612657546997 | KNN Loss: 6.226848125457764 | BCE Loss: 1.057612657546997
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 1.040324091911316 | KNN Loss: 6.226984977722168 | BCE Loss: 1.040324091911316
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 1.057418704032898 | KNN Loss: 6.226921558380127 | BCE Loss: 1.057418704032898
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 1.0508637428283691 | KNN Loss: 6.227057456970215 

Epoch 128 / 500 | iteration 0 / 30 | Total Loss: 1.042443037033081 | KNN Loss: 6.2268805503845215 | BCE Loss: 1.042443037033081
Epoch 128 / 500 | iteration 5 / 30 | Total Loss: 1.0773723125457764 | KNN Loss: 6.226934432983398 | BCE Loss: 1.0773723125457764
Epoch 128 / 500 | iteration 10 / 30 | Total Loss: 1.0399131774902344 | KNN Loss: 6.227065086364746 | BCE Loss: 1.0399131774902344
Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 1.060936450958252 | KNN Loss: 6.227071285247803 | BCE Loss: 1.060936450958252
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 1.035548210144043 | KNN Loss: 6.226974010467529 | BCE Loss: 1.035548210144043
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 1.0282609462738037 | KNN Loss: 6.227021217346191 | BCE Loss: 1.0282609462738037
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 1.0470993518829346 | KNN Loss: 6.2268967628479 | BCE Loss: 1.0470993518829346
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 1.0671563148498535 | KNN Loss: 6.226925849914551 | 

Epoch 138 / 500 | iteration 20 / 30 | Total Loss: 1.0495957136154175 | KNN Loss: 6.226792335510254 | BCE Loss: 1.0495957136154175
Epoch 138 / 500 | iteration 25 / 30 | Total Loss: 1.0606262683868408 | KNN Loss: 6.227038860321045 | BCE Loss: 1.0606262683868408
Epoch   139: reducing learning rate of group 0 to 2.8824e-04.
Epoch 139 / 500 | iteration 0 / 30 | Total Loss: 1.0573172569274902 | KNN Loss: 6.227077960968018 | BCE Loss: 1.0573172569274902
Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 1.0607187747955322 | KNN Loss: 6.226583480834961 | BCE Loss: 1.0607187747955322
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 1.0605781078338623 | KNN Loss: 6.227216720581055 | BCE Loss: 1.0605781078338623
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 1.06512451171875 | KNN Loss: 6.226833343505859 | BCE Loss: 1.06512451171875
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 1.0430717468261719 | KNN Loss: 6.2271270751953125 | BCE Loss: 1.0430717468261719
Epoch 139 / 500 | iteration 25 / 

Epoch 149 / 500 | iteration 10 / 30 | Total Loss: 1.0562978982925415 | KNN Loss: 6.227269649505615 | BCE Loss: 1.0562978982925415
Epoch 149 / 500 | iteration 15 / 30 | Total Loss: 1.0376942157745361 | KNN Loss: 6.226743221282959 | BCE Loss: 1.0376942157745361
Epoch 149 / 500 | iteration 20 / 30 | Total Loss: 1.0255951881408691 | KNN Loss: 6.226550102233887 | BCE Loss: 1.0255951881408691
Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 1.0463883876800537 | KNN Loss: 6.226866722106934 | BCE Loss: 1.0463883876800537
Epoch   150: reducing learning rate of group 0 to 2.0177e-04.
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 1.0331904888153076 | KNN Loss: 6.2269110679626465 | BCE Loss: 1.0331904888153076
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 1.0676355361938477 | KNN Loss: 6.226658344268799 | BCE Loss: 1.0676355361938477
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 1.0651873350143433 | KNN Loss: 6.2268829345703125 | BCE Loss: 1.0651873350143433
Epoch 150 / 500 | iteration 

Epoch 159 / 500 | iteration 25 / 30 | Total Loss: 1.076477289199829 | KNN Loss: 6.227128982543945 | BCE Loss: 1.076477289199829
Epoch 160 / 500 | iteration 0 / 30 | Total Loss: 1.0566004514694214 | KNN Loss: 6.227077960968018 | BCE Loss: 1.0566004514694214
Epoch 160 / 500 | iteration 5 / 30 | Total Loss: 1.0471282005310059 | KNN Loss: 6.226767539978027 | BCE Loss: 1.0471282005310059
Epoch 160 / 500 | iteration 10 / 30 | Total Loss: 1.0086981058120728 | KNN Loss: 6.226433753967285 | BCE Loss: 1.0086981058120728
Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 1.0347754955291748 | KNN Loss: 6.226790428161621 | BCE Loss: 1.0347754955291748
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 1.0571165084838867 | KNN Loss: 6.226821422576904 | BCE Loss: 1.0571165084838867
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 1.0350394248962402 | KNN Loss: 6.227270126342773 | BCE Loss: 1.0350394248962402
Epoch   161: reducing learning rate of group 0 to 1.4124e-04.
Epoch 161 / 500 | iteration 0 / 

Epoch 170 / 500 | iteration 10 / 30 | Total Loss: 1.0616375207901 | KNN Loss: 6.2268595695495605 | BCE Loss: 1.0616375207901
Epoch 170 / 500 | iteration 15 / 30 | Total Loss: 1.07437002658844 | KNN Loss: 6.226970672607422 | BCE Loss: 1.07437002658844
Epoch 170 / 500 | iteration 20 / 30 | Total Loss: 1.0691665410995483 | KNN Loss: 6.22697639465332 | BCE Loss: 1.0691665410995483
Epoch 170 / 500 | iteration 25 / 30 | Total Loss: 1.04924476146698 | KNN Loss: 6.226975917816162 | BCE Loss: 1.04924476146698
Epoch 171 / 500 | iteration 0 / 30 | Total Loss: 1.049546241760254 | KNN Loss: 6.226875305175781 | BCE Loss: 1.049546241760254
Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 1.0546324253082275 | KNN Loss: 6.2266387939453125 | BCE Loss: 1.0546324253082275
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 1.0438556671142578 | KNN Loss: 6.226565361022949 | BCE Loss: 1.0438556671142578
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 1.0355594158172607 | KNN Loss: 6.226954460144043 | BCE Lo

Epoch 181 / 500 | iteration 0 / 30 | Total Loss: 1.0623252391815186 | KNN Loss: 6.226928234100342 | BCE Loss: 1.0623252391815186
Epoch 181 / 500 | iteration 5 / 30 | Total Loss: 1.070289969444275 | KNN Loss: 6.22699499130249 | BCE Loss: 1.070289969444275
Epoch 181 / 500 | iteration 10 / 30 | Total Loss: 1.02224600315094 | KNN Loss: 6.226816654205322 | BCE Loss: 1.02224600315094
Epoch 181 / 500 | iteration 15 / 30 | Total Loss: 1.0373693704605103 | KNN Loss: 6.226861000061035 | BCE Loss: 1.0373693704605103
Epoch 181 / 500 | iteration 20 / 30 | Total Loss: 1.0419515371322632 | KNN Loss: 6.2267680168151855 | BCE Loss: 1.0419515371322632
Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 1.0479635000228882 | KNN Loss: 6.226902484893799 | BCE Loss: 1.0479635000228882
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 1.0527113676071167 | KNN Loss: 6.2272114753723145 | BCE Loss: 1.0527113676071167
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 1.066657543182373 | KNN Loss: 6.226917743682861 |

Epoch 191 / 500 | iteration 15 / 30 | Total Loss: 1.0322407484054565 | KNN Loss: 6.227099418640137 | BCE Loss: 1.0322407484054565
Epoch 191 / 500 | iteration 20 / 30 | Total Loss: 1.0836859941482544 | KNN Loss: 6.22698450088501 | BCE Loss: 1.0836859941482544
Epoch 191 / 500 | iteration 25 / 30 | Total Loss: 1.0590853691101074 | KNN Loss: 6.22674036026001 | BCE Loss: 1.0590853691101074
Epoch 192 / 500 | iteration 0 / 30 | Total Loss: 1.0102930068969727 | KNN Loss: 6.2269110679626465 | BCE Loss: 1.0102930068969727
Epoch 192 / 500 | iteration 5 / 30 | Total Loss: 1.0650807619094849 | KNN Loss: 6.226867198944092 | BCE Loss: 1.0650807619094849
Epoch 192 / 500 | iteration 10 / 30 | Total Loss: 1.0559502840042114 | KNN Loss: 6.226860523223877 | BCE Loss: 1.0559502840042114
Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 1.0615205764770508 | KNN Loss: 6.226774215698242 | BCE Loss: 1.0615205764770508
Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 1.0490975379943848 | KNN Loss: 6.2269496917

Epoch 202 / 500 | iteration 0 / 30 | Total Loss: 1.0554401874542236 | KNN Loss: 6.226815700531006 | BCE Loss: 1.0554401874542236
Epoch 202 / 500 | iteration 5 / 30 | Total Loss: 1.070899248123169 | KNN Loss: 6.2269368171691895 | BCE Loss: 1.070899248123169
Epoch 202 / 500 | iteration 10 / 30 | Total Loss: 1.049095630645752 | KNN Loss: 6.22702169418335 | BCE Loss: 1.049095630645752
Epoch 202 / 500 | iteration 15 / 30 | Total Loss: 1.040338158607483 | KNN Loss: 6.226525783538818 | BCE Loss: 1.040338158607483
Epoch 202 / 500 | iteration 20 / 30 | Total Loss: 1.0707011222839355 | KNN Loss: 6.226597785949707 | BCE Loss: 1.0707011222839355
Epoch 202 / 500 | iteration 25 / 30 | Total Loss: 1.0314799547195435 | KNN Loss: 6.226867198944092 | BCE Loss: 1.0314799547195435
Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 1.0740795135498047 | KNN Loss: 6.226897239685059 | BCE Loss: 1.0740795135498047
Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 1.0412496328353882 | KNN Loss: 6.227119445800781 |

Epoch 212 / 500 | iteration 15 / 30 | Total Loss: 1.0406173467636108 | KNN Loss: 6.226927757263184 | BCE Loss: 1.0406173467636108
Epoch 212 / 500 | iteration 20 / 30 | Total Loss: 1.044654369354248 | KNN Loss: 6.2265472412109375 | BCE Loss: 1.044654369354248
Epoch 212 / 500 | iteration 25 / 30 | Total Loss: 1.0590059757232666 | KNN Loss: 6.226870536804199 | BCE Loss: 1.0590059757232666
Epoch 213 / 500 | iteration 0 / 30 | Total Loss: 1.036024808883667 | KNN Loss: 6.227085113525391 | BCE Loss: 1.036024808883667
Epoch 213 / 500 | iteration 5 / 30 | Total Loss: 1.0623430013656616 | KNN Loss: 6.226907253265381 | BCE Loss: 1.0623430013656616
Epoch 213 / 500 | iteration 10 / 30 | Total Loss: 1.050119400024414 | KNN Loss: 6.226696491241455 | BCE Loss: 1.050119400024414
Epoch 213 / 500 | iteration 15 / 30 | Total Loss: 1.0394859313964844 | KNN Loss: 6.2268147468566895 | BCE Loss: 1.0394859313964844
Epoch 213 / 500 | iteration 20 / 30 | Total Loss: 1.0428714752197266 | KNN Loss: 6.2267804145812

Epoch 223 / 500 | iteration 0 / 30 | Total Loss: 1.0365090370178223 | KNN Loss: 6.226977825164795 | BCE Loss: 1.0365090370178223
Epoch 223 / 500 | iteration 5 / 30 | Total Loss: 1.0410178899765015 | KNN Loss: 6.22702169418335 | BCE Loss: 1.0410178899765015
Epoch 223 / 500 | iteration 10 / 30 | Total Loss: 1.010920524597168 | KNN Loss: 6.226714611053467 | BCE Loss: 1.010920524597168
Epoch 223 / 500 | iteration 15 / 30 | Total Loss: 1.0478696823120117 | KNN Loss: 6.226889133453369 | BCE Loss: 1.0478696823120117
Epoch 223 / 500 | iteration 20 / 30 | Total Loss: 1.0529683828353882 | KNN Loss: 6.226987361907959 | BCE Loss: 1.0529683828353882
Epoch 223 / 500 | iteration 25 / 30 | Total Loss: 1.043717622756958 | KNN Loss: 6.2266459465026855 | BCE Loss: 1.043717622756958
Epoch 224 / 500 | iteration 0 / 30 | Total Loss: 1.0927985906600952 | KNN Loss: 6.226918697357178 | BCE Loss: 1.0927985906600952
Epoch 224 / 500 | iteration 5 / 30 | Total Loss: 1.043976902961731 | KNN Loss: 6.226938247680664 

Epoch 233 / 500 | iteration 20 / 30 | Total Loss: 1.0241713523864746 | KNN Loss: 6.2269134521484375 | BCE Loss: 1.0241713523864746
Epoch 233 / 500 | iteration 25 / 30 | Total Loss: 1.0648051500320435 | KNN Loss: 6.226967811584473 | BCE Loss: 1.0648051500320435
Epoch 234 / 500 | iteration 0 / 30 | Total Loss: 1.0805978775024414 | KNN Loss: 6.226791858673096 | BCE Loss: 1.0805978775024414
Epoch 234 / 500 | iteration 5 / 30 | Total Loss: 1.0411412715911865 | KNN Loss: 6.226834297180176 | BCE Loss: 1.0411412715911865
Epoch 234 / 500 | iteration 10 / 30 | Total Loss: 1.0600465536117554 | KNN Loss: 6.226978302001953 | BCE Loss: 1.0600465536117554
Epoch 234 / 500 | iteration 15 / 30 | Total Loss: 1.0166292190551758 | KNN Loss: 6.226944923400879 | BCE Loss: 1.0166292190551758
Epoch 234 / 500 | iteration 20 / 30 | Total Loss: 1.0721124410629272 | KNN Loss: 6.227112293243408 | BCE Loss: 1.0721124410629272
Epoch 234 / 500 | iteration 25 / 30 | Total Loss: 1.0578986406326294 | KNN Loss: 6.22703313

Epoch 244 / 500 | iteration 5 / 30 | Total Loss: 1.065881371498108 | KNN Loss: 6.227015972137451 | BCE Loss: 1.065881371498108
Epoch 244 / 500 | iteration 10 / 30 | Total Loss: 1.05806303024292 | KNN Loss: 6.227002143859863 | BCE Loss: 1.05806303024292
Epoch 244 / 500 | iteration 15 / 30 | Total Loss: 1.0601227283477783 | KNN Loss: 6.226663112640381 | BCE Loss: 1.0601227283477783
Epoch 244 / 500 | iteration 20 / 30 | Total Loss: 1.0578258037567139 | KNN Loss: 6.226873874664307 | BCE Loss: 1.0578258037567139
Epoch 244 / 500 | iteration 25 / 30 | Total Loss: 1.048478603363037 | KNN Loss: 6.2269134521484375 | BCE Loss: 1.048478603363037
Epoch 245 / 500 | iteration 0 / 30 | Total Loss: 1.0473240613937378 | KNN Loss: 6.22663688659668 | BCE Loss: 1.0473240613937378
Epoch 245 / 500 | iteration 5 / 30 | Total Loss: 1.0441851615905762 | KNN Loss: 6.22693395614624 | BCE Loss: 1.0441851615905762
Epoch 245 / 500 | iteration 10 / 30 | Total Loss: 1.0393147468566895 | KNN Loss: 6.226897716522217 | B

Epoch 254 / 500 | iteration 25 / 30 | Total Loss: 1.075943946838379 | KNN Loss: 6.226963043212891 | BCE Loss: 1.075943946838379
Epoch 255 / 500 | iteration 0 / 30 | Total Loss: 1.0674834251403809 | KNN Loss: 6.226994514465332 | BCE Loss: 1.0674834251403809
Epoch 255 / 500 | iteration 5 / 30 | Total Loss: 1.0355924367904663 | KNN Loss: 6.227250099182129 | BCE Loss: 1.0355924367904663
Epoch 255 / 500 | iteration 10 / 30 | Total Loss: 1.0497645139694214 | KNN Loss: 6.226935386657715 | BCE Loss: 1.0497645139694214
Epoch 255 / 500 | iteration 15 / 30 | Total Loss: 1.0508261919021606 | KNN Loss: 6.226415634155273 | BCE Loss: 1.0508261919021606
Epoch 255 / 500 | iteration 20 / 30 | Total Loss: 1.063277244567871 | KNN Loss: 6.226919174194336 | BCE Loss: 1.063277244567871
Epoch 255 / 500 | iteration 25 / 30 | Total Loss: 1.0400593280792236 | KNN Loss: 6.226562023162842 | BCE Loss: 1.0400593280792236
Epoch 256 / 500 | iteration 0 / 30 | Total Loss: 1.0698119401931763 | KNN Loss: 6.22683572769165

Epoch 265 / 500 | iteration 10 / 30 | Total Loss: 1.0564132928848267 | KNN Loss: 6.226837635040283 | BCE Loss: 1.0564132928848267
Epoch 265 / 500 | iteration 15 / 30 | Total Loss: 1.0412437915802002 | KNN Loss: 6.226728916168213 | BCE Loss: 1.0412437915802002
Epoch 265 / 500 | iteration 20 / 30 | Total Loss: 1.0547711849212646 | KNN Loss: 6.226656436920166 | BCE Loss: 1.0547711849212646
Epoch 265 / 500 | iteration 25 / 30 | Total Loss: 1.0370988845825195 | KNN Loss: 6.226762771606445 | BCE Loss: 1.0370988845825195
Epoch 266 / 500 | iteration 0 / 30 | Total Loss: 1.0700490474700928 | KNN Loss: 6.22704553604126 | BCE Loss: 1.0700490474700928
Epoch 266 / 500 | iteration 5 / 30 | Total Loss: 1.048577070236206 | KNN Loss: 6.226825714111328 | BCE Loss: 1.048577070236206
Epoch 266 / 500 | iteration 10 / 30 | Total Loss: 1.0356485843658447 | KNN Loss: 6.226738452911377 | BCE Loss: 1.0356485843658447
Epoch 266 / 500 | iteration 15 / 30 | Total Loss: 1.0310592651367188 | KNN Loss: 6.226920127868

Epoch 275 / 500 | iteration 25 / 30 | Total Loss: 1.0494046211242676 | KNN Loss: 6.227067947387695 | BCE Loss: 1.0494046211242676
Epoch 276 / 500 | iteration 0 / 30 | Total Loss: 1.0572210550308228 | KNN Loss: 6.227041721343994 | BCE Loss: 1.0572210550308228
Epoch 276 / 500 | iteration 5 / 30 | Total Loss: 1.0414752960205078 | KNN Loss: 6.226471900939941 | BCE Loss: 1.0414752960205078
Epoch 276 / 500 | iteration 10 / 30 | Total Loss: 1.0326828956604004 | KNN Loss: 6.226677417755127 | BCE Loss: 1.0326828956604004
Epoch 276 / 500 | iteration 15 / 30 | Total Loss: 1.026505708694458 | KNN Loss: 6.226903915405273 | BCE Loss: 1.026505708694458
Epoch 276 / 500 | iteration 20 / 30 | Total Loss: 1.0510408878326416 | KNN Loss: 6.227099418640137 | BCE Loss: 1.0510408878326416
Epoch 276 / 500 | iteration 25 / 30 | Total Loss: 1.0704619884490967 | KNN Loss: 6.226993560791016 | BCE Loss: 1.0704619884490967
Epoch 277 / 500 | iteration 0 / 30 | Total Loss: 1.0390326976776123 | KNN Loss: 6.226654529571

Epoch 286 / 500 | iteration 10 / 30 | Total Loss: 1.0557183027267456 | KNN Loss: 6.227101802825928 | BCE Loss: 1.0557183027267456
Epoch 286 / 500 | iteration 15 / 30 | Total Loss: 1.0336003303527832 | KNN Loss: 6.2269206047058105 | BCE Loss: 1.0336003303527832
Epoch 286 / 500 | iteration 20 / 30 | Total Loss: 1.0574395656585693 | KNN Loss: 6.226722240447998 | BCE Loss: 1.0574395656585693
Epoch 286 / 500 | iteration 25 / 30 | Total Loss: 1.031369686126709 | KNN Loss: 6.226720333099365 | BCE Loss: 1.031369686126709
Epoch 287 / 500 | iteration 0 / 30 | Total Loss: 1.071319580078125 | KNN Loss: 6.226885795593262 | BCE Loss: 1.071319580078125
Epoch 287 / 500 | iteration 5 / 30 | Total Loss: 1.043230652809143 | KNN Loss: 6.226687908172607 | BCE Loss: 1.043230652809143
Epoch 287 / 500 | iteration 10 / 30 | Total Loss: 1.0752365589141846 | KNN Loss: 6.227033615112305 | BCE Loss: 1.0752365589141846
Epoch 287 / 500 | iteration 15 / 30 | Total Loss: 1.0289485454559326 | KNN Loss: 6.22705078125 | 

Epoch 296 / 500 | iteration 25 / 30 | Total Loss: 1.0411936044692993 | KNN Loss: 6.22706413269043 | BCE Loss: 1.0411936044692993
Epoch 297 / 500 | iteration 0 / 30 | Total Loss: 1.055984616279602 | KNN Loss: 6.227064609527588 | BCE Loss: 1.055984616279602
Epoch 297 / 500 | iteration 5 / 30 | Total Loss: 1.0472931861877441 | KNN Loss: 6.226950645446777 | BCE Loss: 1.0472931861877441
Epoch 297 / 500 | iteration 10 / 30 | Total Loss: 1.00771164894104 | KNN Loss: 6.227148056030273 | BCE Loss: 1.00771164894104
Epoch 297 / 500 | iteration 15 / 30 | Total Loss: 1.0632202625274658 | KNN Loss: 6.226799011230469 | BCE Loss: 1.0632202625274658
Epoch 297 / 500 | iteration 20 / 30 | Total Loss: 1.030023217201233 | KNN Loss: 6.226908206939697 | BCE Loss: 1.030023217201233
Epoch 297 / 500 | iteration 25 / 30 | Total Loss: 1.0804083347320557 | KNN Loss: 6.226727485656738 | BCE Loss: 1.0804083347320557
Epoch 298 / 500 | iteration 0 / 30 | Total Loss: 1.0481376647949219 | KNN Loss: 6.226726055145264 | B

Epoch 307 / 500 | iteration 15 / 30 | Total Loss: 1.062995433807373 | KNN Loss: 6.227226734161377 | BCE Loss: 1.062995433807373
Epoch 307 / 500 | iteration 20 / 30 | Total Loss: 1.0576951503753662 | KNN Loss: 6.226957321166992 | BCE Loss: 1.0576951503753662
Epoch 307 / 500 | iteration 25 / 30 | Total Loss: 1.057387113571167 | KNN Loss: 6.227044105529785 | BCE Loss: 1.057387113571167
Epoch 308 / 500 | iteration 0 / 30 | Total Loss: 1.0620652437210083 | KNN Loss: 6.226962089538574 | BCE Loss: 1.0620652437210083
Epoch 308 / 500 | iteration 5 / 30 | Total Loss: 1.0354008674621582 | KNN Loss: 6.226673603057861 | BCE Loss: 1.0354008674621582
Epoch 308 / 500 | iteration 10 / 30 | Total Loss: 1.067190170288086 | KNN Loss: 6.227007865905762 | BCE Loss: 1.067190170288086
Epoch 308 / 500 | iteration 15 / 30 | Total Loss: 1.0475444793701172 | KNN Loss: 6.22662353515625 | BCE Loss: 1.0475444793701172
Epoch 308 / 500 | iteration 20 / 30 | Total Loss: 1.0546373128890991 | KNN Loss: 6.2266845703125 | 

Epoch 318 / 500 | iteration 5 / 30 | Total Loss: 1.084618091583252 | KNN Loss: 6.226693153381348 | BCE Loss: 1.084618091583252
Epoch 318 / 500 | iteration 10 / 30 | Total Loss: 1.0383810997009277 | KNN Loss: 6.226623058319092 | BCE Loss: 1.0383810997009277
Epoch 318 / 500 | iteration 15 / 30 | Total Loss: 1.0358498096466064 | KNN Loss: 6.2269134521484375 | BCE Loss: 1.0358498096466064
Epoch 318 / 500 | iteration 20 / 30 | Total Loss: 1.0525397062301636 | KNN Loss: 6.226765155792236 | BCE Loss: 1.0525397062301636
Epoch 318 / 500 | iteration 25 / 30 | Total Loss: 1.0377707481384277 | KNN Loss: 6.226866245269775 | BCE Loss: 1.0377707481384277
Epoch 319 / 500 | iteration 0 / 30 | Total Loss: 1.0531634092330933 | KNN Loss: 6.226933002471924 | BCE Loss: 1.0531634092330933
Epoch 319 / 500 | iteration 5 / 30 | Total Loss: 1.0302256345748901 | KNN Loss: 6.226992607116699 | BCE Loss: 1.0302256345748901
Epoch 319 / 500 | iteration 10 / 30 | Total Loss: 1.046430230140686 | KNN Loss: 6.226705551147

Epoch 328 / 500 | iteration 20 / 30 | Total Loss: 1.0383208990097046 | KNN Loss: 6.226868629455566 | BCE Loss: 1.0383208990097046
Epoch 328 / 500 | iteration 25 / 30 | Total Loss: 1.0299129486083984 | KNN Loss: 6.227171897888184 | BCE Loss: 1.0299129486083984
Epoch 329 / 500 | iteration 0 / 30 | Total Loss: 1.0493028163909912 | KNN Loss: 6.227002143859863 | BCE Loss: 1.0493028163909912
Epoch 329 / 500 | iteration 5 / 30 | Total Loss: 1.0455604791641235 | KNN Loss: 6.226907253265381 | BCE Loss: 1.0455604791641235
Epoch 329 / 500 | iteration 10 / 30 | Total Loss: 1.0510529279708862 | KNN Loss: 6.226870536804199 | BCE Loss: 1.0510529279708862
Epoch 329 / 500 | iteration 15 / 30 | Total Loss: 1.0584723949432373 | KNN Loss: 6.226937770843506 | BCE Loss: 1.0584723949432373
Epoch 329 / 500 | iteration 20 / 30 | Total Loss: 1.0623018741607666 | KNN Loss: 6.227108955383301 | BCE Loss: 1.0623018741607666
Epoch 329 / 500 | iteration 25 / 30 | Total Loss: 1.0264055728912354 | KNN Loss: 6.226986408

Epoch 339 / 500 | iteration 5 / 30 | Total Loss: 1.0491628646850586 | KNN Loss: 6.22667121887207 | BCE Loss: 1.0491628646850586
Epoch 339 / 500 | iteration 10 / 30 | Total Loss: 1.0504395961761475 | KNN Loss: 6.226828098297119 | BCE Loss: 1.0504395961761475
Epoch 339 / 500 | iteration 15 / 30 | Total Loss: 1.0080511569976807 | KNN Loss: 6.226879596710205 | BCE Loss: 1.0080511569976807
Epoch 339 / 500 | iteration 20 / 30 | Total Loss: 1.0510786771774292 | KNN Loss: 6.2266645431518555 | BCE Loss: 1.0510786771774292
Epoch 339 / 500 | iteration 25 / 30 | Total Loss: 1.0249855518341064 | KNN Loss: 6.2270660400390625 | BCE Loss: 1.0249855518341064
Epoch 340 / 500 | iteration 0 / 30 | Total Loss: 1.0923376083374023 | KNN Loss: 6.226940631866455 | BCE Loss: 1.0923376083374023
Epoch 340 / 500 | iteration 5 / 30 | Total Loss: 1.0316810607910156 | KNN Loss: 6.226646900177002 | BCE Loss: 1.0316810607910156
Epoch 340 / 500 | iteration 10 / 30 | Total Loss: 1.0597476959228516 | KNN Loss: 6.227002620

Epoch 349 / 500 | iteration 20 / 30 | Total Loss: 1.0686166286468506 | KNN Loss: 6.226937770843506 | BCE Loss: 1.0686166286468506
Epoch 349 / 500 | iteration 25 / 30 | Total Loss: 1.1126576662063599 | KNN Loss: 6.227221965789795 | BCE Loss: 1.1126576662063599
Epoch 350 / 500 | iteration 0 / 30 | Total Loss: 1.066407561302185 | KNN Loss: 6.226893424987793 | BCE Loss: 1.066407561302185
Epoch 350 / 500 | iteration 5 / 30 | Total Loss: 1.030342936515808 | KNN Loss: 6.226999759674072 | BCE Loss: 1.030342936515808
Epoch 350 / 500 | iteration 10 / 30 | Total Loss: 1.0249712467193604 | KNN Loss: 6.226402759552002 | BCE Loss: 1.0249712467193604
Epoch 350 / 500 | iteration 15 / 30 | Total Loss: 1.03749680519104 | KNN Loss: 6.2267961502075195 | BCE Loss: 1.03749680519104
Epoch 350 / 500 | iteration 20 / 30 | Total Loss: 1.0640727281570435 | KNN Loss: 6.227158546447754 | BCE Loss: 1.0640727281570435
Epoch 350 / 500 | iteration 25 / 30 | Total Loss: 1.048877239227295 | KNN Loss: 6.2268595695495605 

Epoch 360 / 500 | iteration 10 / 30 | Total Loss: 1.011671781539917 | KNN Loss: 6.226860046386719 | BCE Loss: 1.011671781539917
Epoch 360 / 500 | iteration 15 / 30 | Total Loss: 1.0535238981246948 | KNN Loss: 6.226858615875244 | BCE Loss: 1.0535238981246948
Epoch 360 / 500 | iteration 20 / 30 | Total Loss: 1.0723010301589966 | KNN Loss: 6.227019786834717 | BCE Loss: 1.0723010301589966
Epoch 360 / 500 | iteration 25 / 30 | Total Loss: 1.06814444065094 | KNN Loss: 6.226754665374756 | BCE Loss: 1.06814444065094
Epoch 361 / 500 | iteration 0 / 30 | Total Loss: 1.0720678567886353 | KNN Loss: 6.2269606590271 | BCE Loss: 1.0720678567886353
Epoch 361 / 500 | iteration 5 / 30 | Total Loss: 1.0251882076263428 | KNN Loss: 6.227046966552734 | BCE Loss: 1.0251882076263428
Epoch 361 / 500 | iteration 10 / 30 | Total Loss: 1.0776755809783936 | KNN Loss: 6.2266621589660645 | BCE Loss: 1.0776755809783936
Epoch 361 / 500 | iteration 15 / 30 | Total Loss: 1.052001953125 | KNN Loss: 6.226693630218506 | BC

Epoch 371 / 500 | iteration 0 / 30 | Total Loss: 1.0554600954055786 | KNN Loss: 6.226508617401123 | BCE Loss: 1.0554600954055786
Epoch 371 / 500 | iteration 5 / 30 | Total Loss: 1.0589137077331543 | KNN Loss: 6.226742267608643 | BCE Loss: 1.0589137077331543
Epoch 371 / 500 | iteration 10 / 30 | Total Loss: 1.0744062662124634 | KNN Loss: 6.226871013641357 | BCE Loss: 1.0744062662124634
Epoch 371 / 500 | iteration 15 / 30 | Total Loss: 1.0271823406219482 | KNN Loss: 6.2269287109375 | BCE Loss: 1.0271823406219482
Epoch 371 / 500 | iteration 20 / 30 | Total Loss: 1.051118016242981 | KNN Loss: 6.2270402908325195 | BCE Loss: 1.051118016242981
Epoch 371 / 500 | iteration 25 / 30 | Total Loss: 1.0544267892837524 | KNN Loss: 6.2272629737854 | BCE Loss: 1.0544267892837524
Epoch 372 / 500 | iteration 0 / 30 | Total Loss: 1.0607872009277344 | KNN Loss: 6.2267842292785645 | BCE Loss: 1.0607872009277344
Epoch 372 / 500 | iteration 5 / 30 | Total Loss: 1.0319286584854126 | KNN Loss: 6.226847648620605

Epoch 381 / 500 | iteration 20 / 30 | Total Loss: 1.0663492679595947 | KNN Loss: 6.226807117462158 | BCE Loss: 1.0663492679595947
Epoch 381 / 500 | iteration 25 / 30 | Total Loss: 1.0421209335327148 | KNN Loss: 6.226771354675293 | BCE Loss: 1.0421209335327148
Epoch 382 / 500 | iteration 0 / 30 | Total Loss: 1.0716760158538818 | KNN Loss: 6.226850986480713 | BCE Loss: 1.0716760158538818
Epoch 382 / 500 | iteration 5 / 30 | Total Loss: 1.0899244546890259 | KNN Loss: 6.22730016708374 | BCE Loss: 1.0899244546890259
Epoch 382 / 500 | iteration 10 / 30 | Total Loss: 1.0272246599197388 | KNN Loss: 6.226653099060059 | BCE Loss: 1.0272246599197388
Epoch 382 / 500 | iteration 15 / 30 | Total Loss: 1.053100347518921 | KNN Loss: 6.2268877029418945 | BCE Loss: 1.053100347518921
Epoch 382 / 500 | iteration 20 / 30 | Total Loss: 1.0508770942687988 | KNN Loss: 6.226931095123291 | BCE Loss: 1.0508770942687988
Epoch 382 / 500 | iteration 25 / 30 | Total Loss: 1.0225175619125366 | KNN Loss: 6.22670269012

Epoch 392 / 500 | iteration 5 / 30 | Total Loss: 1.0249288082122803 | KNN Loss: 6.226827621459961 | BCE Loss: 1.0249288082122803
Epoch 392 / 500 | iteration 10 / 30 | Total Loss: 1.0448682308197021 | KNN Loss: 6.2268385887146 | BCE Loss: 1.0448682308197021
Epoch 392 / 500 | iteration 15 / 30 | Total Loss: 1.0646051168441772 | KNN Loss: 6.226797580718994 | BCE Loss: 1.0646051168441772
Epoch 392 / 500 | iteration 20 / 30 | Total Loss: 1.059933066368103 | KNN Loss: 6.226742267608643 | BCE Loss: 1.059933066368103
Epoch 392 / 500 | iteration 25 / 30 | Total Loss: 1.069427251815796 | KNN Loss: 6.226789951324463 | BCE Loss: 1.069427251815796
Epoch 393 / 500 | iteration 0 / 30 | Total Loss: 1.0437631607055664 | KNN Loss: 6.226999282836914 | BCE Loss: 1.0437631607055664
Epoch 393 / 500 | iteration 5 / 30 | Total Loss: 1.0496611595153809 | KNN Loss: 6.226742267608643 | BCE Loss: 1.0496611595153809
Epoch 393 / 500 | iteration 10 / 30 | Total Loss: 1.05628502368927 | KNN Loss: 6.226803779602051 | 

Epoch 402 / 500 | iteration 25 / 30 | Total Loss: 1.0361123085021973 | KNN Loss: 6.226984024047852 | BCE Loss: 1.0361123085021973
Epoch   403: reducing learning rate of group 0 to 5.5221e-08.
Epoch 403 / 500 | iteration 0 / 30 | Total Loss: 1.066566824913025 | KNN Loss: 6.226778984069824 | BCE Loss: 1.066566824913025
Epoch 403 / 500 | iteration 5 / 30 | Total Loss: 1.0680656433105469 | KNN Loss: 6.226938724517822 | BCE Loss: 1.0680656433105469
Epoch 403 / 500 | iteration 10 / 30 | Total Loss: 1.0470733642578125 | KNN Loss: 6.226911544799805 | BCE Loss: 1.0470733642578125
Epoch 403 / 500 | iteration 15 / 30 | Total Loss: 1.0465502738952637 | KNN Loss: 6.226790428161621 | BCE Loss: 1.0465502738952637
Epoch 403 / 500 | iteration 20 / 30 | Total Loss: 1.0392248630523682 | KNN Loss: 6.226500511169434 | BCE Loss: 1.0392248630523682
Epoch 403 / 500 | iteration 25 / 30 | Total Loss: 1.0770124197006226 | KNN Loss: 6.226837158203125 | BCE Loss: 1.0770124197006226
Epoch 404 / 500 | iteration 0 / 

Epoch 413 / 500 | iteration 10 / 30 | Total Loss: 1.0608344078063965 | KNN Loss: 6.226782321929932 | BCE Loss: 1.0608344078063965
Epoch 413 / 500 | iteration 15 / 30 | Total Loss: 1.0264697074890137 | KNN Loss: 6.226752758026123 | BCE Loss: 1.0264697074890137
Epoch 413 / 500 | iteration 20 / 30 | Total Loss: 1.0338995456695557 | KNN Loss: 6.226859092712402 | BCE Loss: 1.0338995456695557
Epoch 413 / 500 | iteration 25 / 30 | Total Loss: 1.0504870414733887 | KNN Loss: 6.22652530670166 | BCE Loss: 1.0504870414733887
Epoch   414: reducing learning rate of group 0 to 3.8655e-08.
Epoch 414 / 500 | iteration 0 / 30 | Total Loss: 1.0361324548721313 | KNN Loss: 6.226386547088623 | BCE Loss: 1.0361324548721313
Epoch 414 / 500 | iteration 5 / 30 | Total Loss: 1.0342727899551392 | KNN Loss: 6.227099895477295 | BCE Loss: 1.0342727899551392
Epoch 414 / 500 | iteration 10 / 30 | Total Loss: 1.04404616355896 | KNN Loss: 6.226900577545166 | BCE Loss: 1.04404616355896
Epoch 414 / 500 | iteration 15 / 30

Epoch 423 / 500 | iteration 25 / 30 | Total Loss: 1.0463674068450928 | KNN Loss: 6.227054595947266 | BCE Loss: 1.0463674068450928
Epoch 424 / 500 | iteration 0 / 30 | Total Loss: 1.0453095436096191 | KNN Loss: 6.226705074310303 | BCE Loss: 1.0453095436096191
Epoch 424 / 500 | iteration 5 / 30 | Total Loss: 1.033537745475769 | KNN Loss: 6.227057933807373 | BCE Loss: 1.033537745475769
Epoch 424 / 500 | iteration 10 / 30 | Total Loss: 1.056617021560669 | KNN Loss: 6.226995944976807 | BCE Loss: 1.056617021560669
Epoch 424 / 500 | iteration 15 / 30 | Total Loss: 1.0635597705841064 | KNN Loss: 6.226958751678467 | BCE Loss: 1.0635597705841064
Epoch 424 / 500 | iteration 20 / 30 | Total Loss: 1.0578558444976807 | KNN Loss: 6.226738452911377 | BCE Loss: 1.0578558444976807
Epoch 424 / 500 | iteration 25 / 30 | Total Loss: 1.0769426822662354 | KNN Loss: 6.226658821105957 | BCE Loss: 1.0769426822662354
Epoch   425: reducing learning rate of group 0 to 2.7058e-08.
Epoch 425 / 500 | iteration 0 / 30

Epoch 434 / 500 | iteration 10 / 30 | Total Loss: 1.0547343492507935 | KNN Loss: 6.226904392242432 | BCE Loss: 1.0547343492507935
Epoch 434 / 500 | iteration 15 / 30 | Total Loss: 1.0579932928085327 | KNN Loss: 6.226836681365967 | BCE Loss: 1.0579932928085327
Epoch 434 / 500 | iteration 20 / 30 | Total Loss: 1.0171520709991455 | KNN Loss: 6.226900577545166 | BCE Loss: 1.0171520709991455
Epoch 434 / 500 | iteration 25 / 30 | Total Loss: 1.0413484573364258 | KNN Loss: 6.226826190948486 | BCE Loss: 1.0413484573364258
Epoch 435 / 500 | iteration 0 / 30 | Total Loss: 1.027597427368164 | KNN Loss: 6.2268853187561035 | BCE Loss: 1.027597427368164
Epoch 435 / 500 | iteration 5 / 30 | Total Loss: 1.0344082117080688 | KNN Loss: 6.226993560791016 | BCE Loss: 1.0344082117080688
Epoch 435 / 500 | iteration 10 / 30 | Total Loss: 1.0293371677398682 | KNN Loss: 6.226876258850098 | BCE Loss: 1.0293371677398682
Epoch 435 / 500 | iteration 15 / 30 | Total Loss: 1.0478893518447876 | KNN Loss: 6.2268095016

Epoch 445 / 500 | iteration 0 / 30 | Total Loss: 1.0739376544952393 | KNN Loss: 6.22672176361084 | BCE Loss: 1.0739376544952393
Epoch 445 / 500 | iteration 5 / 30 | Total Loss: 1.0409018993377686 | KNN Loss: 6.226966857910156 | BCE Loss: 1.0409018993377686
Epoch 445 / 500 | iteration 10 / 30 | Total Loss: 1.0651123523712158 | KNN Loss: 6.2271728515625 | BCE Loss: 1.0651123523712158
Epoch 445 / 500 | iteration 15 / 30 | Total Loss: 1.0619744062423706 | KNN Loss: 6.2267913818359375 | BCE Loss: 1.0619744062423706
Epoch 445 / 500 | iteration 20 / 30 | Total Loss: 1.0516583919525146 | KNN Loss: 6.226952075958252 | BCE Loss: 1.0516583919525146
Epoch 445 / 500 | iteration 25 / 30 | Total Loss: 1.0842089653015137 | KNN Loss: 6.226693153381348 | BCE Loss: 1.0842089653015137
Epoch 446 / 500 | iteration 0 / 30 | Total Loss: 1.0410763025283813 | KNN Loss: 6.226855278015137 | BCE Loss: 1.0410763025283813
Epoch 446 / 500 | iteration 5 / 30 | Total Loss: 1.0736632347106934 | KNN Loss: 6.2269711494445

Epoch 455 / 500 | iteration 20 / 30 | Total Loss: 1.0498290061950684 | KNN Loss: 6.226836204528809 | BCE Loss: 1.0498290061950684
Epoch 455 / 500 | iteration 25 / 30 | Total Loss: 1.0303428173065186 | KNN Loss: 6.227087497711182 | BCE Loss: 1.0303428173065186
Epoch 456 / 500 | iteration 0 / 30 | Total Loss: 1.0508776903152466 | KNN Loss: 6.226822376251221 | BCE Loss: 1.0508776903152466
Epoch 456 / 500 | iteration 5 / 30 | Total Loss: 1.0384671688079834 | KNN Loss: 6.2269062995910645 | BCE Loss: 1.0384671688079834
Epoch 456 / 500 | iteration 10 / 30 | Total Loss: 1.049192190170288 | KNN Loss: 6.226996421813965 | BCE Loss: 1.049192190170288
Epoch 456 / 500 | iteration 15 / 30 | Total Loss: 1.0369064807891846 | KNN Loss: 6.22691535949707 | BCE Loss: 1.0369064807891846
Epoch 456 / 500 | iteration 20 / 30 | Total Loss: 1.0283255577087402 | KNN Loss: 6.226684093475342 | BCE Loss: 1.0283255577087402
Epoch 456 / 500 | iteration 25 / 30 | Total Loss: 1.0855860710144043 | KNN Loss: 6.22651338577

Epoch 466 / 500 | iteration 10 / 30 | Total Loss: 1.0367319583892822 | KNN Loss: 6.226821422576904 | BCE Loss: 1.0367319583892822
Epoch 466 / 500 | iteration 15 / 30 | Total Loss: 1.0570251941680908 | KNN Loss: 6.2269287109375 | BCE Loss: 1.0570251941680908
Epoch 466 / 500 | iteration 20 / 30 | Total Loss: 1.0570476055145264 | KNN Loss: 6.227108001708984 | BCE Loss: 1.0570476055145264
Epoch 466 / 500 | iteration 25 / 30 | Total Loss: 1.0634664297103882 | KNN Loss: 6.22666072845459 | BCE Loss: 1.0634664297103882
Epoch 467 / 500 | iteration 0 / 30 | Total Loss: 1.037036418914795 | KNN Loss: 6.226956844329834 | BCE Loss: 1.037036418914795
Epoch 467 / 500 | iteration 5 / 30 | Total Loss: 1.06328547000885 | KNN Loss: 6.227017879486084 | BCE Loss: 1.06328547000885
Epoch 467 / 500 | iteration 10 / 30 | Total Loss: 1.060294270515442 | KNN Loss: 6.226694107055664 | BCE Loss: 1.060294270515442
Epoch 467 / 500 | iteration 15 / 30 | Total Loss: 1.04581618309021 | KNN Loss: 6.226778984069824 | BCE 

Epoch 477 / 500 | iteration 0 / 30 | Total Loss: 1.0306501388549805 | KNN Loss: 6.22696590423584 | BCE Loss: 1.0306501388549805
Epoch 477 / 500 | iteration 5 / 30 | Total Loss: 1.0561633110046387 | KNN Loss: 6.226858615875244 | BCE Loss: 1.0561633110046387
Epoch 477 / 500 | iteration 10 / 30 | Total Loss: 1.0134172439575195 | KNN Loss: 6.226982116699219 | BCE Loss: 1.0134172439575195
Epoch 477 / 500 | iteration 15 / 30 | Total Loss: 1.048302173614502 | KNN Loss: 6.2264885902404785 | BCE Loss: 1.048302173614502
Epoch 477 / 500 | iteration 20 / 30 | Total Loss: 1.075028419494629 | KNN Loss: 6.227017402648926 | BCE Loss: 1.075028419494629
Epoch 477 / 500 | iteration 25 / 30 | Total Loss: 1.039601445198059 | KNN Loss: 6.227046012878418 | BCE Loss: 1.039601445198059
Epoch 478 / 500 | iteration 0 / 30 | Total Loss: 1.0403178930282593 | KNN Loss: 6.226950168609619 | BCE Loss: 1.0403178930282593
Epoch 478 / 500 | iteration 5 / 30 | Total Loss: 1.065673589706421 | KNN Loss: 6.226911544799805 | 

Epoch 487 / 500 | iteration 20 / 30 | Total Loss: 1.0551360845565796 | KNN Loss: 6.226767539978027 | BCE Loss: 1.0551360845565796
Epoch 487 / 500 | iteration 25 / 30 | Total Loss: 1.0185307264328003 | KNN Loss: 6.226561546325684 | BCE Loss: 1.0185307264328003
Epoch 488 / 500 | iteration 0 / 30 | Total Loss: 1.0295162200927734 | KNN Loss: 6.226572513580322 | BCE Loss: 1.0295162200927734
Epoch 488 / 500 | iteration 5 / 30 | Total Loss: 1.0006952285766602 | KNN Loss: 6.226861953735352 | BCE Loss: 1.0006952285766602
Epoch 488 / 500 | iteration 10 / 30 | Total Loss: 1.0619583129882812 | KNN Loss: 6.226715564727783 | BCE Loss: 1.0619583129882812
Epoch 488 / 500 | iteration 15 / 30 | Total Loss: 1.0766658782958984 | KNN Loss: 6.226724147796631 | BCE Loss: 1.0766658782958984
Epoch 488 / 500 | iteration 20 / 30 | Total Loss: 1.0658035278320312 | KNN Loss: 6.226887226104736 | BCE Loss: 1.0658035278320312
Epoch 488 / 500 | iteration 25 / 30 | Total Loss: 1.0724034309387207 | KNN Loss: 6.226710319

Epoch 498 / 500 | iteration 10 / 30 | Total Loss: 1.0423955917358398 | KNN Loss: 6.226851463317871 | BCE Loss: 1.0423955917358398
Epoch 498 / 500 | iteration 15 / 30 | Total Loss: 1.053356647491455 | KNN Loss: 6.227038860321045 | BCE Loss: 1.053356647491455
Epoch 498 / 500 | iteration 20 / 30 | Total Loss: 1.0328301191329956 | KNN Loss: 6.2270331382751465 | BCE Loss: 1.0328301191329956
Epoch 498 / 500 | iteration 25 / 30 | Total Loss: 1.0257878303527832 | KNN Loss: 6.226958274841309 | BCE Loss: 1.0257878303527832
Epoch 499 / 500 | iteration 0 / 30 | Total Loss: 1.0898840427398682 | KNN Loss: 6.227123737335205 | BCE Loss: 1.0898840427398682
Epoch 499 / 500 | iteration 5 / 30 | Total Loss: 1.0536696910858154 | KNN Loss: 6.226707935333252 | BCE Loss: 1.0536696910858154
Epoch 499 / 500 | iteration 10 / 30 | Total Loss: 1.0438902378082275 | KNN Loss: 6.226748466491699 | BCE Loss: 1.0438902378082275
Epoch 499 / 500 | iteration 15 / 30 | Total Loss: 1.0359838008880615 | KNN Loss: 6.2267909049

In [17]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.9260,  3.8751,  2.5815,  3.5753,  3.4575,  0.7073,  2.6673,  2.1985,
          2.3084,  1.9959,  2.2377,  2.2015,  0.7877,  1.8225,  1.2886,  1.5240,
          2.8093,  3.1821,  2.8012,  2.3064,  1.7457,  2.9548,  2.2895,  2.6397,
          2.5373,  1.7417,  2.1251,  1.4121,  1.4929,  0.3215, -0.2411,  0.9961,
          0.2093,  0.9268,  1.5295,  1.4739,  1.0045,  3.3175,  0.8013,  1.3213,
          0.9638, -0.7019, -0.2381,  2.3380,  2.1909,  0.7373, -0.1839,  0.0959,
          1.4636,  2.4984,  1.8227,  0.1449,  1.4272,  0.5246, -0.6349,  1.1086,
          1.4808,  1.3742,  1.3425,  1.8266,  0.5733,  0.8346,  0.1406,  1.7248,
          1.3218,  1.6694, -1.8264,  0.3082,  2.2918,  2.1461,  2.5514,  0.4280,
          1.3545,  2.4628,  2.0004,  1.2954,  0.2204,  0.7366,  0.2152,  1.5899,
          0.0288,  0.3820,  1.8390, -0.3779,  0.2434, -1.0678, -2.2868, -0.2612,
          0.5480, -1.8348,  0.4684, -0.1306, -0.5724, -0.9397,  0.5634,  1.2697,
         -0.6998, -0.6983,  

In [18]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [21]:
dataset_ = [d[0].cpu() for d in dataset]

In [22]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 78.51it/s]


In [23]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [45]:
clusters = DBSCAN(eps=0.01, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [46]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [47]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [48]:
tensor_dataset = torch.stack(dataset_)

In [49]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [50]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [51]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [52]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [53]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [54]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [55]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [56]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [57]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [58]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [59]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [60]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [61]:
losses = []
accs = []
sparsity = []

In [62]:
set(clusters)

{-1, 0}

In [42]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 030 | Total loss: 9.569 | Reg loss: 0.009 | Tree loss: 9.569 | Accuracy: 0.000000 | 0.307 sec/iter
Epoch: 00 | Batch: 001 / 030 | Total loss: 9.549 | Reg loss: 0.009 | Tree loss: 9.549 | Accuracy: 0.000000 | 0.278 sec/iter
Epoch: 00 | Batch: 002 / 030 | Total loss: 9.530 | Reg loss: 0.009 | Tree loss: 9.530 | Accuracy: 0.000000 | 0.263 sec/iter
Epoch: 00 | Batch: 003 / 030 | Total loss: 9.511 | Reg loss: 0.008 | Tree loss: 9.511 | Accuracy: 0.000000 | 0.255 sec/iter
Epoch: 00 | Batch: 004 / 030 | Total loss: 9.492 | Reg loss: 0.008 | Tree loss: 9.492 | Accuracy: 0.000000 | 0.251 sec/iter
Epoch: 00 | Batch: 005 / 030 | Total loss: 9.472 | Reg loss: 0.008 | Tree loss: 9.472 | Accuracy: 0.000000 | 0.248 sec/iter
Epoch: 00 | Batch: 006 / 030 | Total loss: 9.450 | Reg loss: 0.008 | Tree loss: 9.450 | Accuracy: 0.021484 | 0.249 sec/iter
Epoch: 00 | Batch

KeyboardInterrupt: 

In [None]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
print(f"Number of patterns: {len(root.get_leaves())}")

In [None]:
method = 'greedy'

In [None]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    for cond in conds:
        cond.weights = cond.weights / normalizers
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")