In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 8
tree_depth = 10
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.206886291503906 | KNN Loss: 6.2244954109191895 | BCE Loss: 1.9823912382125854
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.132919311523438 | KNN Loss: 6.224689483642578 | BCE Loss: 1.9082294702529907
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.164770126342773 | KNN Loss: 6.224438190460205 | BCE Loss: 1.9403314590454102
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.090507507324219 | KNN Loss: 6.223604202270508 | BCE Loss: 1.8669028282165527
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.19241714477539 | KNN Loss: 6.222877025604248 | BCE Loss: 1.9695402383804321
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.14727783203125 | KNN Loss: 6.222434997558594 | BCE Loss: 1.9248433113098145
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.140485763549805 | KNN Loss: 6.221712589263916 | BCE Loss: 1.9187731742858887
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.113654136657715 | KNN Loss: 6.220485687255859 | BCE Loss: 1.893168

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 4.986264228820801 | KNN Loss: 3.8754231929779053 | BCE Loss: 1.1108407974243164
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 4.874971389770508 | KNN Loss: 3.7803409099578857 | BCE Loss: 1.094630241394043
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 4.799947738647461 | KNN Loss: 3.687481641769409 | BCE Loss: 1.1124663352966309
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 4.735187530517578 | KNN Loss: 3.6294658184051514 | BCE Loss: 1.1057215929031372
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 4.600740432739258 | KNN Loss: 3.5104541778564453 | BCE Loss: 1.090286135673523
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 4.573747158050537 | KNN Loss: 3.4555721282958984 | BCE Loss: 1.1181750297546387
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 4.487905979156494 | KNN Loss: 3.3704724311828613 | BCE Loss: 1.1174335479736328
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 4.407155990600586 | KNN Loss: 3.301198720932007 | BCE 

Epoch 21 / 500 | iteration 15 / 30 | Total Loss: 3.745793104171753 | KNN Loss: 2.7118005752563477 | BCE Loss: 1.0339925289154053
Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 3.7594337463378906 | KNN Loss: 2.709543466567993 | BCE Loss: 1.049890398979187
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 3.7088677883148193 | KNN Loss: 2.6439218521118164 | BCE Loss: 1.064945936203003
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 3.787175178527832 | KNN Loss: 2.700822591781616 | BCE Loss: 1.0863525867462158
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 3.7421951293945312 | KNN Loss: 2.6866912841796875 | BCE Loss: 1.0555038452148438
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 3.7248144149780273 | KNN Loss: 2.6636526584625244 | BCE Loss: 1.0611616373062134
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 3.732130289077759 | KNN Loss: 2.6874048709869385 | BCE Loss: 1.0447254180908203
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 3.753727674484253 | KNN Loss: 2.7074761390686035 |

Epoch 32 / 500 | iteration 5 / 30 | Total Loss: 3.6425485610961914 | KNN Loss: 2.6196279525756836 | BCE Loss: 1.0229206085205078
Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 3.6593141555786133 | KNN Loss: 2.5833194255828857 | BCE Loss: 1.0759947299957275
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 3.5900542736053467 | KNN Loss: 2.5517096519470215 | BCE Loss: 1.0383446216583252
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 3.6452982425689697 | KNN Loss: 2.5981411933898926 | BCE Loss: 1.0471570491790771
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 3.558393716812134 | KNN Loss: 2.5428829193115234 | BCE Loss: 1.0155107975006104
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 3.5885677337646484 | KNN Loss: 2.533275842666626 | BCE Loss: 1.0552918910980225
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 3.6900157928466797 | KNN Loss: 2.650986671447754 | BCE Loss: 1.0390291213989258
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 3.6802167892456055 | KNN Loss: 2.64267873764038

Epoch 42 / 500 | iteration 25 / 30 | Total Loss: 3.618813991546631 | KNN Loss: 2.572598457336426 | BCE Loss: 1.046215534210205
Epoch 43 / 500 | iteration 0 / 30 | Total Loss: 3.6307477951049805 | KNN Loss: 2.569920539855957 | BCE Loss: 1.0608272552490234
Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 3.6177334785461426 | KNN Loss: 2.5625157356262207 | BCE Loss: 1.0552177429199219
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 3.590005874633789 | KNN Loss: 2.553828716278076 | BCE Loss: 1.036177158355713
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 3.61464262008667 | KNN Loss: 2.6170718669891357 | BCE Loss: 0.9975708723068237
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 3.642763614654541 | KNN Loss: 2.5822277069091797 | BCE Loss: 1.0605359077453613
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 3.558879852294922 | KNN Loss: 2.5587778091430664 | BCE Loss: 1.0001020431518555
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 3.5804293155670166 | KNN Loss: 2.5366017818450928 | BCE

Epoch 53 / 500 | iteration 15 / 30 | Total Loss: 3.5945534706115723 | KNN Loss: 2.554644823074341 | BCE Loss: 1.0399086475372314
Epoch 53 / 500 | iteration 20 / 30 | Total Loss: 3.570094585418701 | KNN Loss: 2.533419132232666 | BCE Loss: 1.0366753339767456
Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 3.58140230178833 | KNN Loss: 2.5324249267578125 | BCE Loss: 1.0489773750305176
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 3.5327978134155273 | KNN Loss: 2.5236971378326416 | BCE Loss: 1.0091006755828857
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 3.5243842601776123 | KNN Loss: 2.516045331954956 | BCE Loss: 1.0083389282226562
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 3.6296427249908447 | KNN Loss: 2.563041925430298 | BCE Loss: 1.0666007995605469
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 3.588491916656494 | KNN Loss: 2.5645387172698975 | BCE Loss: 1.0239530801773071
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 3.616178512573242 | KNN Loss: 2.58389949798584 | BC

Epoch 64 / 500 | iteration 5 / 30 | Total Loss: 3.6058554649353027 | KNN Loss: 2.599349021911621 | BCE Loss: 1.0065064430236816
Epoch 64 / 500 | iteration 10 / 30 | Total Loss: 3.5871920585632324 | KNN Loss: 2.5305991172790527 | BCE Loss: 1.0565928220748901
Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 3.581334114074707 | KNN Loss: 2.5197882652282715 | BCE Loss: 1.0615458488464355
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 3.5523552894592285 | KNN Loss: 2.511107921600342 | BCE Loss: 1.0412473678588867
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 3.5439321994781494 | KNN Loss: 2.5204761028289795 | BCE Loss: 1.02345609664917
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 3.5868639945983887 | KNN Loss: 2.555251121520996 | BCE Loss: 1.0316128730773926
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 3.5048317909240723 | KNN Loss: 2.4932897090911865 | BCE Loss: 1.0115419626235962
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 3.5775623321533203 | KNN Loss: 2.5503509044647217 

Epoch 74 / 500 | iteration 25 / 30 | Total Loss: 3.5620498657226562 | KNN Loss: 2.523221254348755 | BCE Loss: 1.0388284921646118
Epoch 75 / 500 | iteration 0 / 30 | Total Loss: 3.5502638816833496 | KNN Loss: 2.5138258934020996 | BCE Loss: 1.0364378690719604
Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 3.5414342880249023 | KNN Loss: 2.5267117023468018 | BCE Loss: 1.0147227048873901
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 3.581902265548706 | KNN Loss: 2.5560367107391357 | BCE Loss: 1.0258655548095703
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 3.600609540939331 | KNN Loss: 2.578886032104492 | BCE Loss: 1.0217235088348389
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 3.589806079864502 | KNN Loss: 2.5425126552581787 | BCE Loss: 1.0472933053970337
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 3.543903112411499 | KNN Loss: 2.526517391204834 | BCE Loss: 1.017385721206665
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 3.5625858306884766 | KNN Loss: 2.5251457691192627 | 

Epoch 85 / 500 | iteration 15 / 30 | Total Loss: 3.5573840141296387 | KNN Loss: 2.5480592250823975 | BCE Loss: 1.0093249082565308
Epoch 85 / 500 | iteration 20 / 30 | Total Loss: 3.517387866973877 | KNN Loss: 2.503220319747925 | BCE Loss: 1.0141675472259521
Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 3.579596996307373 | KNN Loss: 2.5244650840759277 | BCE Loss: 1.0551319122314453
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 3.58402419090271 | KNN Loss: 2.5128889083862305 | BCE Loss: 1.0711352825164795
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 3.5946221351623535 | KNN Loss: 2.5446226596832275 | BCE Loss: 1.0499993562698364
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 3.5960631370544434 | KNN Loss: 2.5850255489349365 | BCE Loss: 1.0110375881195068
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 3.520005464553833 | KNN Loss: 2.5244715213775635 | BCE Loss: 0.9955338835716248
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 3.572930097579956 | KNN Loss: 2.525005340576172 |

Epoch 96 / 500 | iteration 5 / 30 | Total Loss: 3.5368642807006836 | KNN Loss: 2.5356953144073486 | BCE Loss: 1.0011690855026245
Epoch 96 / 500 | iteration 10 / 30 | Total Loss: 3.5779151916503906 | KNN Loss: 2.537531852722168 | BCE Loss: 1.0403833389282227
Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 3.532299041748047 | KNN Loss: 2.51983380317688 | BCE Loss: 1.012465238571167
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 3.562077760696411 | KNN Loss: 2.5262606143951416 | BCE Loss: 1.0358171463012695
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 3.5010180473327637 | KNN Loss: 2.4703311920166016 | BCE Loss: 1.0306869745254517
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 3.5403804779052734 | KNN Loss: 2.4941577911376953 | BCE Loss: 1.0462225675582886
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 3.541656255722046 | KNN Loss: 2.531221866607666 | BCE Loss: 1.0104343891143799
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 3.5166468620300293 | KNN Loss: 2.4980311393737793 | 

Epoch 106 / 500 | iteration 25 / 30 | Total Loss: 3.5235190391540527 | KNN Loss: 2.501573085784912 | BCE Loss: 1.021945834159851
Epoch 107 / 500 | iteration 0 / 30 | Total Loss: 3.477726936340332 | KNN Loss: 2.458770275115967 | BCE Loss: 1.0189567804336548
Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 3.5238239765167236 | KNN Loss: 2.492575168609619 | BCE Loss: 1.0312488079071045
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 3.53653621673584 | KNN Loss: 2.516874074935913 | BCE Loss: 1.0196621417999268
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 3.4995992183685303 | KNN Loss: 2.486217498779297 | BCE Loss: 1.0133817195892334
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 3.4829905033111572 | KNN Loss: 2.4659807682037354 | BCE Loss: 1.0170097351074219
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 3.4860730171203613 | KNN Loss: 2.4955649375915527 | BCE Loss: 0.9905081987380981
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 3.5263938903808594 | KNN Loss: 2.514806032180

Epoch 117 / 500 | iteration 10 / 30 | Total Loss: 3.516819477081299 | KNN Loss: 2.4686055183410645 | BCE Loss: 1.0482139587402344
Epoch 117 / 500 | iteration 15 / 30 | Total Loss: 3.5354788303375244 | KNN Loss: 2.5247154235839844 | BCE Loss: 1.01076340675354
Epoch 117 / 500 | iteration 20 / 30 | Total Loss: 3.5077829360961914 | KNN Loss: 2.481762647628784 | BCE Loss: 1.0260204076766968
Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 3.5098652839660645 | KNN Loss: 2.4714391231536865 | BCE Loss: 1.038426160812378
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 3.472809314727783 | KNN Loss: 2.4850997924804688 | BCE Loss: 0.9877094030380249
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 3.562181234359741 | KNN Loss: 2.5617263317108154 | BCE Loss: 1.0004549026489258
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 3.5523910522460938 | KNN Loss: 2.532783269882202 | BCE Loss: 1.0196079015731812
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 3.491797685623169 | KNN Loss: 2.49460053443

Epoch 128 / 500 | iteration 0 / 30 | Total Loss: 3.4790713787078857 | KNN Loss: 2.482372999191284 | BCE Loss: 0.9966983199119568
Epoch 128 / 500 | iteration 5 / 30 | Total Loss: 3.540515899658203 | KNN Loss: 2.486239194869995 | BCE Loss: 1.054276704788208
Epoch 128 / 500 | iteration 10 / 30 | Total Loss: 3.512518882751465 | KNN Loss: 2.4698238372802734 | BCE Loss: 1.0426949262619019
Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 3.5281336307525635 | KNN Loss: 2.4905502796173096 | BCE Loss: 1.037583351135254
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 3.5035970211029053 | KNN Loss: 2.519258975982666 | BCE Loss: 0.984338104724884
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 3.501039505004883 | KNN Loss: 2.4848451614379883 | BCE Loss: 1.0161943435668945
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 3.522975444793701 | KNN Loss: 2.5222346782684326 | BCE Loss: 1.000740885734558
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 3.4650051593780518 | KNN Loss: 2.481488227844238

Epoch 138 / 500 | iteration 20 / 30 | Total Loss: 3.5242679119110107 | KNN Loss: 2.5149524211883545 | BCE Loss: 1.0093154907226562
Epoch 138 / 500 | iteration 25 / 30 | Total Loss: 3.5178122520446777 | KNN Loss: 2.5162384510040283 | BCE Loss: 1.0015736818313599
Epoch 139 / 500 | iteration 0 / 30 | Total Loss: 3.47829270362854 | KNN Loss: 2.46437668800354 | BCE Loss: 1.013916015625
Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 3.548544406890869 | KNN Loss: 2.5250816345214844 | BCE Loss: 1.0234627723693848
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 3.496713638305664 | KNN Loss: 2.5032238960266113 | BCE Loss: 0.9934898614883423
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 3.49935245513916 | KNN Loss: 2.4931440353393555 | BCE Loss: 1.0062083005905151
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 3.5211572647094727 | KNN Loss: 2.493079662322998 | BCE Loss: 1.028077483177185
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 3.5016798973083496 | KNN Loss: 2.4818763732910156

Epoch 149 / 500 | iteration 5 / 30 | Total Loss: 3.5108957290649414 | KNN Loss: 2.5202534198760986 | BCE Loss: 0.9906424283981323
Epoch 149 / 500 | iteration 10 / 30 | Total Loss: 3.495527744293213 | KNN Loss: 2.477023124694824 | BCE Loss: 1.0185045003890991
Epoch 149 / 500 | iteration 15 / 30 | Total Loss: 3.5112931728363037 | KNN Loss: 2.50504732131958 | BCE Loss: 1.0062458515167236
Epoch 149 / 500 | iteration 20 / 30 | Total Loss: 3.5018668174743652 | KNN Loss: 2.460127592086792 | BCE Loss: 1.0417392253875732
Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 3.5237412452697754 | KNN Loss: 2.500802755355835 | BCE Loss: 1.0229384899139404
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 3.5362730026245117 | KNN Loss: 2.5196688175201416 | BCE Loss: 1.0166040658950806
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 3.5245282649993896 | KNN Loss: 2.4881811141967773 | BCE Loss: 1.0363471508026123
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 3.5430285930633545 | KNN Loss: 2.511057615

Epoch 159 / 500 | iteration 20 / 30 | Total Loss: 3.4902005195617676 | KNN Loss: 2.4705018997192383 | BCE Loss: 1.0196987390518188
Epoch 159 / 500 | iteration 25 / 30 | Total Loss: 3.485769271850586 | KNN Loss: 2.4574851989746094 | BCE Loss: 1.0282840728759766
Epoch 160 / 500 | iteration 0 / 30 | Total Loss: 3.524022102355957 | KNN Loss: 2.493981122970581 | BCE Loss: 1.030040979385376
Epoch 160 / 500 | iteration 5 / 30 | Total Loss: 3.4356391429901123 | KNN Loss: 2.457348108291626 | BCE Loss: 0.9782910346984863
Epoch 160 / 500 | iteration 10 / 30 | Total Loss: 3.5191798210144043 | KNN Loss: 2.4984936714172363 | BCE Loss: 1.020686149597168
Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 3.49527645111084 | KNN Loss: 2.482365369796753 | BCE Loss: 1.0129109621047974
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 3.486483573913574 | KNN Loss: 2.488933563232422 | BCE Loss: 0.9975500702857971
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 3.5069117546081543 | KNN Loss: 2.4892706871032

Epoch 170 / 500 | iteration 10 / 30 | Total Loss: 3.489757537841797 | KNN Loss: 2.4722371101379395 | BCE Loss: 1.017520546913147
Epoch 170 / 500 | iteration 15 / 30 | Total Loss: 3.495889663696289 | KNN Loss: 2.4910104274749756 | BCE Loss: 1.0048792362213135
Epoch 170 / 500 | iteration 20 / 30 | Total Loss: 3.5067591667175293 | KNN Loss: 2.4838387966156006 | BCE Loss: 1.0229203701019287
Epoch 170 / 500 | iteration 25 / 30 | Total Loss: 3.490398406982422 | KNN Loss: 2.4731409549713135 | BCE Loss: 1.0172574520111084
Epoch 171 / 500 | iteration 0 / 30 | Total Loss: 3.472928047180176 | KNN Loss: 2.476259231567383 | BCE Loss: 0.9966689348220825
Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 3.509270191192627 | KNN Loss: 2.478125810623169 | BCE Loss: 1.031144380569458
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 3.511808156967163 | KNN Loss: 2.486237049102783 | BCE Loss: 1.0255711078643799
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 3.481624126434326 | KNN Loss: 2.49203181266784

Epoch 181 / 500 | iteration 0 / 30 | Total Loss: 3.509568929672241 | KNN Loss: 2.5079567432403564 | BCE Loss: 1.0016121864318848
Epoch 181 / 500 | iteration 5 / 30 | Total Loss: 3.528179168701172 | KNN Loss: 2.5269272327423096 | BCE Loss: 1.0012519359588623
Epoch 181 / 500 | iteration 10 / 30 | Total Loss: 3.516800880432129 | KNN Loss: 2.4994313716888428 | BCE Loss: 1.0173693895339966
Epoch 181 / 500 | iteration 15 / 30 | Total Loss: 3.5282649993896484 | KNN Loss: 2.521934986114502 | BCE Loss: 1.0063300132751465
Epoch 181 / 500 | iteration 20 / 30 | Total Loss: 3.506962299346924 | KNN Loss: 2.4782400131225586 | BCE Loss: 1.0287224054336548
Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 3.523735761642456 | KNN Loss: 2.478365898132324 | BCE Loss: 1.0453698635101318
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 3.521090269088745 | KNN Loss: 2.487921953201294 | BCE Loss: 1.0331683158874512
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 3.5310826301574707 | KNN Loss: 2.5038719177246

Epoch 191 / 500 | iteration 15 / 30 | Total Loss: 3.493678569793701 | KNN Loss: 2.46427321434021 | BCE Loss: 1.0294053554534912
Epoch 191 / 500 | iteration 20 / 30 | Total Loss: 3.5171847343444824 | KNN Loss: 2.478317975997925 | BCE Loss: 1.0388667583465576
Epoch 191 / 500 | iteration 25 / 30 | Total Loss: 3.4994025230407715 | KNN Loss: 2.4978771209716797 | BCE Loss: 1.0015255212783813
Epoch 192 / 500 | iteration 0 / 30 | Total Loss: 3.4948947429656982 | KNN Loss: 2.467897653579712 | BCE Loss: 1.0269970893859863
Epoch 192 / 500 | iteration 5 / 30 | Total Loss: 3.498119354248047 | KNN Loss: 2.4874825477600098 | BCE Loss: 1.0106369256973267
Epoch 192 / 500 | iteration 10 / 30 | Total Loss: 3.5351531505584717 | KNN Loss: 2.5009243488311768 | BCE Loss: 1.034228801727295
Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 3.454158067703247 | KNN Loss: 2.4661145210266113 | BCE Loss: 0.9880436062812805
Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 3.4966297149658203 | KNN Loss: 2.4917263984

Epoch 202 / 500 | iteration 0 / 30 | Total Loss: 3.4852137565612793 | KNN Loss: 2.469900369644165 | BCE Loss: 1.0153132677078247
Epoch 202 / 500 | iteration 5 / 30 | Total Loss: 3.506436347961426 | KNN Loss: 2.481440305709839 | BCE Loss: 1.0249959230422974
Epoch 202 / 500 | iteration 10 / 30 | Total Loss: 3.485416889190674 | KNN Loss: 2.4587881565093994 | BCE Loss: 1.0266287326812744
Epoch 202 / 500 | iteration 15 / 30 | Total Loss: 3.506577253341675 | KNN Loss: 2.4914398193359375 | BCE Loss: 1.0151374340057373
Epoch 202 / 500 | iteration 20 / 30 | Total Loss: 3.50209379196167 | KNN Loss: 2.495753526687622 | BCE Loss: 1.0063402652740479
Epoch 202 / 500 | iteration 25 / 30 | Total Loss: 3.4758317470550537 | KNN Loss: 2.4733033180236816 | BCE Loss: 1.002528429031372
Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 3.4789905548095703 | KNN Loss: 2.444753408432007 | BCE Loss: 1.0342371463775635
Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 3.4938454627990723 | KNN Loss: 2.46647262573242

Epoch 212 / 500 | iteration 20 / 30 | Total Loss: 3.4997994899749756 | KNN Loss: 2.480531692504883 | BCE Loss: 1.0192677974700928
Epoch 212 / 500 | iteration 25 / 30 | Total Loss: 3.475440502166748 | KNN Loss: 2.4762654304504395 | BCE Loss: 0.9991750121116638
Epoch 213 / 500 | iteration 0 / 30 | Total Loss: 3.491387367248535 | KNN Loss: 2.4822981357574463 | BCE Loss: 1.0090892314910889
Epoch 213 / 500 | iteration 5 / 30 | Total Loss: 3.4993159770965576 | KNN Loss: 2.447472333908081 | BCE Loss: 1.0518436431884766
Epoch 213 / 500 | iteration 10 / 30 | Total Loss: 3.532686948776245 | KNN Loss: 2.50586199760437 | BCE Loss: 1.026824951171875
Epoch 213 / 500 | iteration 15 / 30 | Total Loss: 3.4772043228149414 | KNN Loss: 2.4476232528686523 | BCE Loss: 1.0295811891555786
Epoch 213 / 500 | iteration 20 / 30 | Total Loss: 3.489011287689209 | KNN Loss: 2.4446699619293213 | BCE Loss: 1.0443414449691772
Epoch 213 / 500 | iteration 25 / 30 | Total Loss: 3.4883484840393066 | KNN Loss: 2.47473859786

Epoch 223 / 500 | iteration 10 / 30 | Total Loss: 3.48211932182312 | KNN Loss: 2.4569292068481445 | BCE Loss: 1.0251901149749756
Epoch 223 / 500 | iteration 15 / 30 | Total Loss: 3.4521803855895996 | KNN Loss: 2.4801275730133057 | BCE Loss: 0.972052812576294
Epoch 223 / 500 | iteration 20 / 30 | Total Loss: 3.4767961502075195 | KNN Loss: 2.451019525527954 | BCE Loss: 1.0257765054702759
Epoch 223 / 500 | iteration 25 / 30 | Total Loss: 3.50527286529541 | KNN Loss: 2.473127603530884 | BCE Loss: 1.0321452617645264
Epoch 224 / 500 | iteration 0 / 30 | Total Loss: 3.4720869064331055 | KNN Loss: 2.4628007411956787 | BCE Loss: 1.0092862844467163
Epoch 224 / 500 | iteration 5 / 30 | Total Loss: 3.488077163696289 | KNN Loss: 2.473170042037964 | BCE Loss: 1.0149070024490356
Epoch 224 / 500 | iteration 10 / 30 | Total Loss: 3.4708476066589355 | KNN Loss: 2.4547536373138428 | BCE Loss: 1.0160940885543823
Epoch 224 / 500 | iteration 15 / 30 | Total Loss: 3.517003059387207 | KNN Loss: 2.487504005432

Epoch 233 / 500 | iteration 25 / 30 | Total Loss: 3.4762868881225586 | KNN Loss: 2.4581706523895264 | BCE Loss: 1.0181162357330322
Epoch 234 / 500 | iteration 0 / 30 | Total Loss: 3.47845721244812 | KNN Loss: 2.479384422302246 | BCE Loss: 0.9990727305412292
Epoch 234 / 500 | iteration 5 / 30 | Total Loss: 3.484494686126709 | KNN Loss: 2.4609875679016113 | BCE Loss: 1.0235071182250977
Epoch 234 / 500 | iteration 10 / 30 | Total Loss: 3.4246671199798584 | KNN Loss: 2.4542312622070312 | BCE Loss: 0.9704358577728271
Epoch 234 / 500 | iteration 15 / 30 | Total Loss: 3.5096564292907715 | KNN Loss: 2.5033481121063232 | BCE Loss: 1.0063083171844482
Epoch 234 / 500 | iteration 20 / 30 | Total Loss: 3.464242935180664 | KNN Loss: 2.4319450855255127 | BCE Loss: 1.0322977304458618
Epoch 234 / 500 | iteration 25 / 30 | Total Loss: 3.4825663566589355 | KNN Loss: 2.4584901332855225 | BCE Loss: 1.0240763425827026
Epoch 235 / 500 | iteration 0 / 30 | Total Loss: 3.4697084426879883 | KNN Loss: 2.46713328

Epoch 244 / 500 | iteration 15 / 30 | Total Loss: 3.574016571044922 | KNN Loss: 2.5398333072662354 | BCE Loss: 1.034183144569397
Epoch 244 / 500 | iteration 20 / 30 | Total Loss: 3.437544345855713 | KNN Loss: 2.460221290588379 | BCE Loss: 0.9773229956626892
Epoch 244 / 500 | iteration 25 / 30 | Total Loss: 3.470226287841797 | KNN Loss: 2.4775280952453613 | BCE Loss: 0.9926981925964355
Epoch 245 / 500 | iteration 0 / 30 | Total Loss: 3.4323129653930664 | KNN Loss: 2.4348177909851074 | BCE Loss: 0.9974952936172485
Epoch 245 / 500 | iteration 5 / 30 | Total Loss: 3.449394464492798 | KNN Loss: 2.451300621032715 | BCE Loss: 0.9980939030647278
Epoch 245 / 500 | iteration 10 / 30 | Total Loss: 3.4762632846832275 | KNN Loss: 2.457416534423828 | BCE Loss: 1.0188467502593994
Epoch 245 / 500 | iteration 15 / 30 | Total Loss: 3.488162040710449 | KNN Loss: 2.459139823913574 | BCE Loss: 1.029022216796875
Epoch 245 / 500 | iteration 20 / 30 | Total Loss: 3.5028584003448486 | KNN Loss: 2.4790284633636

Epoch 255 / 500 | iteration 0 / 30 | Total Loss: 3.429410696029663 | KNN Loss: 2.442307233810425 | BCE Loss: 0.9871034026145935
Epoch 255 / 500 | iteration 5 / 30 | Total Loss: 3.551013946533203 | KNN Loss: 2.4945268630981445 | BCE Loss: 1.0564870834350586
Epoch 255 / 500 | iteration 10 / 30 | Total Loss: 3.4544155597686768 | KNN Loss: 2.4412341117858887 | BCE Loss: 1.013181447982788
Epoch 255 / 500 | iteration 15 / 30 | Total Loss: 3.54347825050354 | KNN Loss: 2.5289134979248047 | BCE Loss: 1.0145647525787354
Epoch 255 / 500 | iteration 20 / 30 | Total Loss: 3.4640920162200928 | KNN Loss: 2.4676883220672607 | BCE Loss: 0.996403694152832
Epoch 255 / 500 | iteration 25 / 30 | Total Loss: 3.48828125 | KNN Loss: 2.468348979949951 | BCE Loss: 1.0199322700500488
Epoch 256 / 500 | iteration 0 / 30 | Total Loss: 3.474470615386963 | KNN Loss: 2.4420266151428223 | BCE Loss: 1.032443881034851
Epoch 256 / 500 | iteration 5 / 30 | Total Loss: 3.4802210330963135 | KNN Loss: 2.4563381671905518 | BCE

Epoch 265 / 500 | iteration 15 / 30 | Total Loss: 3.4921107292175293 | KNN Loss: 2.4892146587371826 | BCE Loss: 1.0028961896896362
Epoch 265 / 500 | iteration 20 / 30 | Total Loss: 3.4874446392059326 | KNN Loss: 2.452831983566284 | BCE Loss: 1.0346126556396484
Epoch 265 / 500 | iteration 25 / 30 | Total Loss: 3.4669370651245117 | KNN Loss: 2.456583023071289 | BCE Loss: 1.0103540420532227
Epoch 266 / 500 | iteration 0 / 30 | Total Loss: 3.451132297515869 | KNN Loss: 2.461760997772217 | BCE Loss: 0.9893711805343628
Epoch 266 / 500 | iteration 5 / 30 | Total Loss: 3.4446444511413574 | KNN Loss: 2.4657046794891357 | BCE Loss: 0.9789396524429321
Epoch 266 / 500 | iteration 10 / 30 | Total Loss: 3.482452869415283 | KNN Loss: 2.4561593532562256 | BCE Loss: 1.0262935161590576
Epoch 266 / 500 | iteration 15 / 30 | Total Loss: 3.453648090362549 | KNN Loss: 2.429795503616333 | BCE Loss: 1.0238525867462158
Epoch 266 / 500 | iteration 20 / 30 | Total Loss: 3.472653865814209 | KNN Loss: 2.4637365341

Epoch 276 / 500 | iteration 5 / 30 | Total Loss: 3.510298728942871 | KNN Loss: 2.470303773880005 | BCE Loss: 1.0399950742721558
Epoch 276 / 500 | iteration 10 / 30 | Total Loss: 3.5128023624420166 | KNN Loss: 2.480029582977295 | BCE Loss: 1.0327727794647217
Epoch 276 / 500 | iteration 15 / 30 | Total Loss: 3.441740036010742 | KNN Loss: 2.4424004554748535 | BCE Loss: 0.9993396401405334
Epoch 276 / 500 | iteration 20 / 30 | Total Loss: 3.4356942176818848 | KNN Loss: 2.4335148334503174 | BCE Loss: 1.0021793842315674
Epoch 276 / 500 | iteration 25 / 30 | Total Loss: 3.4870920181274414 | KNN Loss: 2.454796552658081 | BCE Loss: 1.03229558467865
Epoch 277 / 500 | iteration 0 / 30 | Total Loss: 3.447707176208496 | KNN Loss: 2.4312632083892822 | BCE Loss: 1.0164440870285034
Epoch 277 / 500 | iteration 5 / 30 | Total Loss: 3.5148987770080566 | KNN Loss: 2.5022695064544678 | BCE Loss: 1.0126291513442993
Epoch 277 / 500 | iteration 10 / 30 | Total Loss: 3.424781560897827 | KNN Loss: 2.410676002502

Epoch 286 / 500 | iteration 20 / 30 | Total Loss: 3.4823737144470215 | KNN Loss: 2.461528778076172 | BCE Loss: 1.0208449363708496
Epoch 286 / 500 | iteration 25 / 30 | Total Loss: 3.485708236694336 | KNN Loss: 2.4700653553009033 | BCE Loss: 1.0156430006027222
Epoch 287 / 500 | iteration 0 / 30 | Total Loss: 3.476346492767334 | KNN Loss: 2.468585729598999 | BCE Loss: 1.007760763168335
Epoch 287 / 500 | iteration 5 / 30 | Total Loss: 3.450305938720703 | KNN Loss: 2.446007251739502 | BCE Loss: 1.0042986869812012
Epoch 287 / 500 | iteration 10 / 30 | Total Loss: 3.4579336643218994 | KNN Loss: 2.4602715969085693 | BCE Loss: 0.9976620674133301
Epoch 287 / 500 | iteration 15 / 30 | Total Loss: 3.516634225845337 | KNN Loss: 2.496877670288086 | BCE Loss: 1.019756555557251
Epoch 287 / 500 | iteration 20 / 30 | Total Loss: 3.4482507705688477 | KNN Loss: 2.442563533782959 | BCE Loss: 1.0056873559951782
Epoch 287 / 500 | iteration 25 / 30 | Total Loss: 3.4620256423950195 | KNN Loss: 2.4473774433135

Epoch 297 / 500 | iteration 5 / 30 | Total Loss: 3.493544340133667 | KNN Loss: 2.486621856689453 | BCE Loss: 1.0069224834442139
Epoch 297 / 500 | iteration 10 / 30 | Total Loss: 3.4447829723358154 | KNN Loss: 2.447726249694824 | BCE Loss: 0.9970566630363464
Epoch 297 / 500 | iteration 15 / 30 | Total Loss: 3.4541220664978027 | KNN Loss: 2.4394986629486084 | BCE Loss: 1.0146234035491943
Epoch 297 / 500 | iteration 20 / 30 | Total Loss: 3.4610347747802734 | KNN Loss: 2.4465529918670654 | BCE Loss: 1.014481782913208
Epoch 297 / 500 | iteration 25 / 30 | Total Loss: 3.4590916633605957 | KNN Loss: 2.443528890609741 | BCE Loss: 1.0155627727508545
Epoch 298 / 500 | iteration 0 / 30 | Total Loss: 3.4590492248535156 | KNN Loss: 2.4450724124908447 | BCE Loss: 1.013976812362671
Epoch 298 / 500 | iteration 5 / 30 | Total Loss: 3.4908151626586914 | KNN Loss: 2.456394910812378 | BCE Loss: 1.0344202518463135
Epoch 298 / 500 | iteration 10 / 30 | Total Loss: 3.4581780433654785 | KNN Loss: 2.4371409416

Epoch 307 / 500 | iteration 25 / 30 | Total Loss: 3.483628273010254 | KNN Loss: 2.480242967605591 | BCE Loss: 1.003385305404663
Epoch 308 / 500 | iteration 0 / 30 | Total Loss: 3.459257125854492 | KNN Loss: 2.433971643447876 | BCE Loss: 1.0252854824066162
Epoch 308 / 500 | iteration 5 / 30 | Total Loss: 3.478952646255493 | KNN Loss: 2.4826176166534424 | BCE Loss: 0.9963350892066956
Epoch 308 / 500 | iteration 10 / 30 | Total Loss: 3.4686474800109863 | KNN Loss: 2.454085350036621 | BCE Loss: 1.0145621299743652
Epoch 308 / 500 | iteration 15 / 30 | Total Loss: 3.435175895690918 | KNN Loss: 2.4399054050445557 | BCE Loss: 0.9952703714370728
Epoch 308 / 500 | iteration 20 / 30 | Total Loss: 3.4631505012512207 | KNN Loss: 2.4547712802886963 | BCE Loss: 1.0083792209625244
Epoch 308 / 500 | iteration 25 / 30 | Total Loss: 3.4567677974700928 | KNN Loss: 2.4304306507110596 | BCE Loss: 1.0263371467590332
Epoch 309 / 500 | iteration 0 / 30 | Total Loss: 3.4513254165649414 | KNN Loss: 2.43362140655

Epoch 318 / 500 | iteration 10 / 30 | Total Loss: 3.483295440673828 | KNN Loss: 2.460132598876953 | BCE Loss: 1.023162841796875
Epoch 318 / 500 | iteration 15 / 30 | Total Loss: 3.439887285232544 | KNN Loss: 2.427950620651245 | BCE Loss: 1.0119366645812988
Epoch 318 / 500 | iteration 20 / 30 | Total Loss: 3.464860200881958 | KNN Loss: 2.4646353721618652 | BCE Loss: 1.0002248287200928
Epoch 318 / 500 | iteration 25 / 30 | Total Loss: 3.458597183227539 | KNN Loss: 2.4608895778656006 | BCE Loss: 0.9977074861526489
Epoch 319 / 500 | iteration 0 / 30 | Total Loss: 3.491584300994873 | KNN Loss: 2.5006511211395264 | BCE Loss: 0.9909331202507019
Epoch 319 / 500 | iteration 5 / 30 | Total Loss: 3.4923014640808105 | KNN Loss: 2.490429162979126 | BCE Loss: 1.001872181892395
Epoch 319 / 500 | iteration 10 / 30 | Total Loss: 3.5051076412200928 | KNN Loss: 2.467672824859619 | BCE Loss: 1.0374348163604736
Epoch 319 / 500 | iteration 15 / 30 | Total Loss: 3.4858639240264893 | KNN Loss: 2.4609615802764

Epoch 328 / 500 | iteration 25 / 30 | Total Loss: 3.466801643371582 | KNN Loss: 2.441115617752075 | BCE Loss: 1.0256860256195068
Epoch 329 / 500 | iteration 0 / 30 | Total Loss: 3.4257402420043945 | KNN Loss: 2.4347081184387207 | BCE Loss: 0.9910321235656738
Epoch 329 / 500 | iteration 5 / 30 | Total Loss: 3.4547743797302246 | KNN Loss: 2.451307773590088 | BCE Loss: 1.0034664869308472
Epoch 329 / 500 | iteration 10 / 30 | Total Loss: 3.469719886779785 | KNN Loss: 2.455233573913574 | BCE Loss: 1.014486312866211
Epoch 329 / 500 | iteration 15 / 30 | Total Loss: 3.4591481685638428 | KNN Loss: 2.4565937519073486 | BCE Loss: 1.0025544166564941
Epoch 329 / 500 | iteration 20 / 30 | Total Loss: 3.4851789474487305 | KNN Loss: 2.4718945026397705 | BCE Loss: 1.01328444480896
Epoch 329 / 500 | iteration 25 / 30 | Total Loss: 3.4757766723632812 | KNN Loss: 2.4731595516204834 | BCE Loss: 1.0026172399520874
Epoch 330 / 500 | iteration 0 / 30 | Total Loss: 3.5093994140625 | KNN Loss: 2.47561144828796

Epoch 339 / 500 | iteration 10 / 30 | Total Loss: 3.4653587341308594 | KNN Loss: 2.474175453186035 | BCE Loss: 0.9911832213401794
Epoch 339 / 500 | iteration 15 / 30 | Total Loss: 3.4786505699157715 | KNN Loss: 2.444380760192871 | BCE Loss: 1.03426992893219
Epoch 339 / 500 | iteration 20 / 30 | Total Loss: 3.4617252349853516 | KNN Loss: 2.4548959732055664 | BCE Loss: 1.0068292617797852
Epoch 339 / 500 | iteration 25 / 30 | Total Loss: 3.4984078407287598 | KNN Loss: 2.4705822467803955 | BCE Loss: 1.0278254747390747
Epoch 340 / 500 | iteration 0 / 30 | Total Loss: 3.4546103477478027 | KNN Loss: 2.452166795730591 | BCE Loss: 1.002443552017212
Epoch 340 / 500 | iteration 5 / 30 | Total Loss: 3.4487791061401367 | KNN Loss: 2.4425504207611084 | BCE Loss: 1.0062286853790283
Epoch 340 / 500 | iteration 10 / 30 | Total Loss: 3.507533073425293 | KNN Loss: 2.469407558441162 | BCE Loss: 1.0381255149841309
Epoch 340 / 500 | iteration 15 / 30 | Total Loss: 3.4662647247314453 | KNN Loss: 2.4357957839

Epoch 349 / 500 | iteration 25 / 30 | Total Loss: 3.4649462699890137 | KNN Loss: 2.447113275527954 | BCE Loss: 1.0178329944610596
Epoch 350 / 500 | iteration 0 / 30 | Total Loss: 3.4282803535461426 | KNN Loss: 2.448997735977173 | BCE Loss: 0.9792827367782593
Epoch 350 / 500 | iteration 5 / 30 | Total Loss: 3.42014217376709 | KNN Loss: 2.41744327545166 | BCE Loss: 1.0026988983154297
Epoch 350 / 500 | iteration 10 / 30 | Total Loss: 3.465639352798462 | KNN Loss: 2.444157123565674 | BCE Loss: 1.021482229232788
Epoch 350 / 500 | iteration 15 / 30 | Total Loss: 3.4842031002044678 | KNN Loss: 2.469306707382202 | BCE Loss: 1.0148963928222656
Epoch 350 / 500 | iteration 20 / 30 | Total Loss: 3.453723192214966 | KNN Loss: 2.4256370067596436 | BCE Loss: 1.0280861854553223
Epoch 350 / 500 | iteration 25 / 30 | Total Loss: 3.4656639099121094 | KNN Loss: 2.465247631072998 | BCE Loss: 1.0004163980484009
Epoch 351 / 500 | iteration 0 / 30 | Total Loss: 3.4614477157592773 | KNN Loss: 2.422975778579712

Epoch 360 / 500 | iteration 10 / 30 | Total Loss: 3.442615032196045 | KNN Loss: 2.430036783218384 | BCE Loss: 1.0125782489776611
Epoch 360 / 500 | iteration 15 / 30 | Total Loss: 3.413458824157715 | KNN Loss: 2.4186089038848877 | BCE Loss: 0.9948499202728271
Epoch 360 / 500 | iteration 20 / 30 | Total Loss: 3.5144314765930176 | KNN Loss: 2.482703924179077 | BCE Loss: 1.03172767162323
Epoch 360 / 500 | iteration 25 / 30 | Total Loss: 3.4646236896514893 | KNN Loss: 2.4516429901123047 | BCE Loss: 1.0129806995391846
Epoch 361 / 500 | iteration 0 / 30 | Total Loss: 3.5039286613464355 | KNN Loss: 2.4803898334503174 | BCE Loss: 1.0235388278961182
Epoch 361 / 500 | iteration 5 / 30 | Total Loss: 3.4024136066436768 | KNN Loss: 2.4189484119415283 | BCE Loss: 0.9834652543067932
Epoch 361 / 500 | iteration 10 / 30 | Total Loss: 3.475543975830078 | KNN Loss: 2.477698802947998 | BCE Loss: 0.9978450536727905
Epoch 361 / 500 | iteration 15 / 30 | Total Loss: 3.463165283203125 | KNN Loss: 2.43229413032

Epoch   371: reducing learning rate of group 0 to 1.6616e-05.
Epoch 371 / 500 | iteration 0 / 30 | Total Loss: 3.4461357593536377 | KNN Loss: 2.433032989501953 | BCE Loss: 1.0131027698516846
Epoch 371 / 500 | iteration 5 / 30 | Total Loss: 3.5088186264038086 | KNN Loss: 2.47672700881958 | BCE Loss: 1.0320916175842285
Epoch 371 / 500 | iteration 10 / 30 | Total Loss: 3.442269802093506 | KNN Loss: 2.445831775665283 | BCE Loss: 0.9964380264282227
Epoch 371 / 500 | iteration 15 / 30 | Total Loss: 3.407585620880127 | KNN Loss: 2.4274895191192627 | BCE Loss: 0.9800959825515747
Epoch 371 / 500 | iteration 20 / 30 | Total Loss: 3.5030508041381836 | KNN Loss: 2.4553632736206055 | BCE Loss: 1.0476876497268677
Epoch 371 / 500 | iteration 25 / 30 | Total Loss: 3.4787588119506836 | KNN Loss: 2.4788661003112793 | BCE Loss: 0.9998928308486938
Epoch 372 / 500 | iteration 0 / 30 | Total Loss: 3.4816625118255615 | KNN Loss: 2.4454808235168457 | BCE Loss: 1.0361816883087158
Epoch 372 / 500 | iteration 5 

Epoch 381 / 500 | iteration 15 / 30 | Total Loss: 3.4975507259368896 | KNN Loss: 2.4710843563079834 | BCE Loss: 1.0264663696289062
Epoch 381 / 500 | iteration 20 / 30 | Total Loss: 3.489368438720703 | KNN Loss: 2.464390993118286 | BCE Loss: 1.0249775648117065
Epoch 381 / 500 | iteration 25 / 30 | Total Loss: 3.441452741622925 | KNN Loss: 2.459196090698242 | BCE Loss: 0.9822566509246826
Epoch   382: reducing learning rate of group 0 to 1.1632e-05.
Epoch 382 / 500 | iteration 0 / 30 | Total Loss: 3.4230570793151855 | KNN Loss: 2.423502206802368 | BCE Loss: 0.9995548725128174
Epoch 382 / 500 | iteration 5 / 30 | Total Loss: 3.463864803314209 | KNN Loss: 2.440539598464966 | BCE Loss: 1.0233252048492432
Epoch 382 / 500 | iteration 10 / 30 | Total Loss: 3.420931816101074 | KNN Loss: 2.441098213195801 | BCE Loss: 0.9798336029052734
Epoch 382 / 500 | iteration 15 / 30 | Total Loss: 3.4768993854522705 | KNN Loss: 2.4749045372009277 | BCE Loss: 1.0019948482513428
Epoch 382 / 500 | iteration 20 /

Epoch 392 / 500 | iteration 0 / 30 | Total Loss: 3.4671812057495117 | KNN Loss: 2.4407200813293457 | BCE Loss: 1.0264610052108765
Epoch 392 / 500 | iteration 5 / 30 | Total Loss: 3.426889419555664 | KNN Loss: 2.4338812828063965 | BCE Loss: 0.9930081963539124
Epoch 392 / 500 | iteration 10 / 30 | Total Loss: 3.4560275077819824 | KNN Loss: 2.4413228034973145 | BCE Loss: 1.014704704284668
Epoch 392 / 500 | iteration 15 / 30 | Total Loss: 3.4801790714263916 | KNN Loss: 2.470973014831543 | BCE Loss: 1.0092060565948486
Epoch 392 / 500 | iteration 20 / 30 | Total Loss: 3.497077703475952 | KNN Loss: 2.4936747550964355 | BCE Loss: 1.0034029483795166
Epoch 392 / 500 | iteration 25 / 30 | Total Loss: 3.487529754638672 | KNN Loss: 2.466723918914795 | BCE Loss: 1.0208057165145874
Epoch   393: reducing learning rate of group 0 to 8.1421e-06.
Epoch 393 / 500 | iteration 0 / 30 | Total Loss: 3.4667410850524902 | KNN Loss: 2.4551947116851807 | BCE Loss: 1.0115463733673096
Epoch 393 / 500 | iteration 5 

Epoch 402 / 500 | iteration 15 / 30 | Total Loss: 3.4513816833496094 | KNN Loss: 2.4249818325042725 | BCE Loss: 1.026399850845337
Epoch 402 / 500 | iteration 20 / 30 | Total Loss: 3.47743821144104 | KNN Loss: 2.424076557159424 | BCE Loss: 1.0533616542816162
Epoch 402 / 500 | iteration 25 / 30 | Total Loss: 3.446455955505371 | KNN Loss: 2.477160692214966 | BCE Loss: 0.96929532289505
Epoch 403 / 500 | iteration 0 / 30 | Total Loss: 3.480337619781494 | KNN Loss: 2.448223114013672 | BCE Loss: 1.0321145057678223
Epoch 403 / 500 | iteration 5 / 30 | Total Loss: 3.4990200996398926 | KNN Loss: 2.4762418270111084 | BCE Loss: 1.0227781534194946
Epoch 403 / 500 | iteration 10 / 30 | Total Loss: 3.4707510471343994 | KNN Loss: 2.4629528522491455 | BCE Loss: 1.007798194885254
Epoch 403 / 500 | iteration 15 / 30 | Total Loss: 3.4756102561950684 | KNN Loss: 2.47312593460083 | BCE Loss: 1.0024843215942383
Epoch 403 / 500 | iteration 20 / 30 | Total Loss: 3.5126137733459473 | KNN Loss: 2.488327026367187

Epoch 413 / 500 | iteration 5 / 30 | Total Loss: 3.4677071571350098 | KNN Loss: 2.45778751373291 | BCE Loss: 1.0099196434020996
Epoch 413 / 500 | iteration 10 / 30 | Total Loss: 3.472837448120117 | KNN Loss: 2.4448535442352295 | BCE Loss: 1.0279839038848877
Epoch 413 / 500 | iteration 15 / 30 | Total Loss: 3.4805819988250732 | KNN Loss: 2.4786388874053955 | BCE Loss: 1.0019431114196777
Epoch 413 / 500 | iteration 20 / 30 | Total Loss: 3.4549567699432373 | KNN Loss: 2.4527933597564697 | BCE Loss: 1.0021634101867676
Epoch 413 / 500 | iteration 25 / 30 | Total Loss: 3.462805986404419 | KNN Loss: 2.4537529945373535 | BCE Loss: 1.0090529918670654
Epoch 414 / 500 | iteration 0 / 30 | Total Loss: 3.4220261573791504 | KNN Loss: 2.43038010597229 | BCE Loss: 0.9916461110115051
Epoch 414 / 500 | iteration 5 / 30 | Total Loss: 3.4872870445251465 | KNN Loss: 2.4481289386749268 | BCE Loss: 1.0391581058502197
Epoch 414 / 500 | iteration 10 / 30 | Total Loss: 3.5133628845214844 | KNN Loss: 2.488320112

Epoch 423 / 500 | iteration 20 / 30 | Total Loss: 3.412867784500122 | KNN Loss: 2.417356252670288 | BCE Loss: 0.995511531829834
Epoch 423 / 500 | iteration 25 / 30 | Total Loss: 3.4994609355926514 | KNN Loss: 2.478468894958496 | BCE Loss: 1.0209920406341553
Epoch 424 / 500 | iteration 0 / 30 | Total Loss: 3.459414482116699 | KNN Loss: 2.4588637351989746 | BCE Loss: 1.000550627708435
Epoch 424 / 500 | iteration 5 / 30 | Total Loss: 3.415191173553467 | KNN Loss: 2.41640043258667 | BCE Loss: 0.9987908601760864
Epoch 424 / 500 | iteration 10 / 30 | Total Loss: 3.480778694152832 | KNN Loss: 2.469973564147949 | BCE Loss: 1.0108051300048828
Epoch 424 / 500 | iteration 15 / 30 | Total Loss: 3.4566850662231445 | KNN Loss: 2.4379448890686035 | BCE Loss: 1.018740177154541
Epoch 424 / 500 | iteration 20 / 30 | Total Loss: 3.447805404663086 | KNN Loss: 2.4166512489318848 | BCE Loss: 1.0311541557312012
Epoch 424 / 500 | iteration 25 / 30 | Total Loss: 3.4985873699188232 | KNN Loss: 2.459856033325195

Epoch 434 / 500 | iteration 5 / 30 | Total Loss: 3.496371269226074 | KNN Loss: 2.46533203125 | BCE Loss: 1.0310393571853638
Epoch 434 / 500 | iteration 10 / 30 | Total Loss: 3.450620174407959 | KNN Loss: 2.459688186645508 | BCE Loss: 0.9909321069717407
Epoch 434 / 500 | iteration 15 / 30 | Total Loss: 3.4168701171875 | KNN Loss: 2.4359307289123535 | BCE Loss: 0.980939507484436
Epoch 434 / 500 | iteration 20 / 30 | Total Loss: 3.5168821811676025 | KNN Loss: 2.4583756923675537 | BCE Loss: 1.0585064888000488
Epoch 434 / 500 | iteration 25 / 30 | Total Loss: 3.4602153301239014 | KNN Loss: 2.4537246227264404 | BCE Loss: 1.006490707397461
Epoch 435 / 500 | iteration 0 / 30 | Total Loss: 3.467649459838867 | KNN Loss: 2.4405202865600586 | BCE Loss: 1.0271291732788086
Epoch 435 / 500 | iteration 5 / 30 | Total Loss: 3.4647293090820312 | KNN Loss: 2.442828893661499 | BCE Loss: 1.0219004154205322
Epoch 435 / 500 | iteration 10 / 30 | Total Loss: 3.462029218673706 | KNN Loss: 2.4423508644104004 | 

Epoch 444 / 500 | iteration 20 / 30 | Total Loss: 3.4964520931243896 | KNN Loss: 2.4649245738983154 | BCE Loss: 1.0315275192260742
Epoch 444 / 500 | iteration 25 / 30 | Total Loss: 3.439265727996826 | KNN Loss: 2.4136388301849365 | BCE Loss: 1.0256270170211792
Epoch 445 / 500 | iteration 0 / 30 | Total Loss: 3.471036195755005 | KNN Loss: 2.4681406021118164 | BCE Loss: 1.0028955936431885
Epoch 445 / 500 | iteration 5 / 30 | Total Loss: 3.4477429389953613 | KNN Loss: 2.438974618911743 | BCE Loss: 1.0087684392929077
Epoch 445 / 500 | iteration 10 / 30 | Total Loss: 3.4605307579040527 | KNN Loss: 2.4481873512268066 | BCE Loss: 1.0123432874679565
Epoch 445 / 500 | iteration 15 / 30 | Total Loss: 3.410706043243408 | KNN Loss: 2.438807487487793 | BCE Loss: 0.9718984961509705
Epoch 445 / 500 | iteration 20 / 30 | Total Loss: 3.47870135307312 | KNN Loss: 2.4583938121795654 | BCE Loss: 1.0203075408935547
Epoch 445 / 500 | iteration 25 / 30 | Total Loss: 3.456961154937744 | KNN Loss: 2.4455015659

Epoch 455 / 500 | iteration 5 / 30 | Total Loss: 3.506587505340576 | KNN Loss: 2.482028007507324 | BCE Loss: 1.0245593786239624
Epoch 455 / 500 | iteration 10 / 30 | Total Loss: 3.4664971828460693 | KNN Loss: 2.4481492042541504 | BCE Loss: 1.018347978591919
Epoch 455 / 500 | iteration 15 / 30 | Total Loss: 3.4151315689086914 | KNN Loss: 2.4421021938323975 | BCE Loss: 0.9730292558670044
Epoch 455 / 500 | iteration 20 / 30 | Total Loss: 3.4269938468933105 | KNN Loss: 2.422834634780884 | BCE Loss: 1.0041590929031372
Epoch 455 / 500 | iteration 25 / 30 | Total Loss: 3.4638149738311768 | KNN Loss: 2.476300001144409 | BCE Loss: 0.9875149130821228
Epoch 456 / 500 | iteration 0 / 30 | Total Loss: 3.5066330432891846 | KNN Loss: 2.4695498943328857 | BCE Loss: 1.0370831489562988
Epoch 456 / 500 | iteration 5 / 30 | Total Loss: 3.490180730819702 | KNN Loss: 2.455274820327759 | BCE Loss: 1.0349059104919434
Epoch 456 / 500 | iteration 10 / 30 | Total Loss: 3.4882771968841553 | KNN Loss: 2.4746177196

Epoch 465 / 500 | iteration 20 / 30 | Total Loss: 3.490302085876465 | KNN Loss: 2.4632110595703125 | BCE Loss: 1.0270910263061523
Epoch 465 / 500 | iteration 25 / 30 | Total Loss: 3.4911444187164307 | KNN Loss: 2.46443247795105 | BCE Loss: 1.0267119407653809
Epoch 466 / 500 | iteration 0 / 30 | Total Loss: 3.4638800621032715 | KNN Loss: 2.4437785148620605 | BCE Loss: 1.020101547241211
Epoch 466 / 500 | iteration 5 / 30 | Total Loss: 3.459263563156128 | KNN Loss: 2.4353206157684326 | BCE Loss: 1.0239429473876953
Epoch 466 / 500 | iteration 10 / 30 | Total Loss: 3.4581027030944824 | KNN Loss: 2.446424961090088 | BCE Loss: 1.0116777420043945
Epoch 466 / 500 | iteration 15 / 30 | Total Loss: 3.4539642333984375 | KNN Loss: 2.453716993331909 | BCE Loss: 1.0002471208572388
Epoch 466 / 500 | iteration 20 / 30 | Total Loss: 3.4721603393554688 | KNN Loss: 2.465996265411377 | BCE Loss: 1.0061639547348022
Epoch 466 / 500 | iteration 25 / 30 | Total Loss: 3.4632041454315186 | KNN Loss: 2.4549984931

Epoch 476 / 500 | iteration 5 / 30 | Total Loss: 3.435896158218384 | KNN Loss: 2.4375128746032715 | BCE Loss: 0.9983832240104675
Epoch 476 / 500 | iteration 10 / 30 | Total Loss: 3.4385604858398438 | KNN Loss: 2.4401745796203613 | BCE Loss: 0.9983858466148376
Epoch 476 / 500 | iteration 15 / 30 | Total Loss: 3.512115240097046 | KNN Loss: 2.4807064533233643 | BCE Loss: 1.0314087867736816
Epoch 476 / 500 | iteration 20 / 30 | Total Loss: 3.46596097946167 | KNN Loss: 2.4419243335723877 | BCE Loss: 1.0240366458892822
Epoch 476 / 500 | iteration 25 / 30 | Total Loss: 3.487907886505127 | KNN Loss: 2.4812941551208496 | BCE Loss: 1.0066137313842773
Epoch 477 / 500 | iteration 0 / 30 | Total Loss: 3.494438648223877 | KNN Loss: 2.4712114334106445 | BCE Loss: 1.0232272148132324
Epoch 477 / 500 | iteration 5 / 30 | Total Loss: 3.5344932079315186 | KNN Loss: 2.517587423324585 | BCE Loss: 1.0169057846069336
Epoch 477 / 500 | iteration 10 / 30 | Total Loss: 3.449920654296875 | KNN Loss: 2.44377946853

Epoch 486 / 500 | iteration 20 / 30 | Total Loss: 3.497659206390381 | KNN Loss: 2.462867498397827 | BCE Loss: 1.0347917079925537
Epoch 486 / 500 | iteration 25 / 30 | Total Loss: 3.446634531021118 | KNN Loss: 2.426318883895874 | BCE Loss: 1.0203156471252441
Epoch 487 / 500 | iteration 0 / 30 | Total Loss: 3.4695658683776855 | KNN Loss: 2.4240596294403076 | BCE Loss: 1.0455061197280884
Epoch 487 / 500 | iteration 5 / 30 | Total Loss: 3.4758872985839844 | KNN Loss: 2.467365026473999 | BCE Loss: 1.0085222721099854
Epoch 487 / 500 | iteration 10 / 30 | Total Loss: 3.442422866821289 | KNN Loss: 2.435115098953247 | BCE Loss: 1.007307767868042
Epoch 487 / 500 | iteration 15 / 30 | Total Loss: 3.4759304523468018 | KNN Loss: 2.4420318603515625 | BCE Loss: 1.0338985919952393
Epoch 487 / 500 | iteration 20 / 30 | Total Loss: 3.4397456645965576 | KNN Loss: 2.422353982925415 | BCE Loss: 1.0173916816711426
Epoch 487 / 500 | iteration 25 / 30 | Total Loss: 3.460203170776367 | KNN Loss: 2.462445497512

Epoch 497 / 500 | iteration 5 / 30 | Total Loss: 3.4742002487182617 | KNN Loss: 2.4685566425323486 | BCE Loss: 1.0056437253952026
Epoch 497 / 500 | iteration 10 / 30 | Total Loss: 3.459146738052368 | KNN Loss: 2.4547905921936035 | BCE Loss: 1.0043561458587646
Epoch 497 / 500 | iteration 15 / 30 | Total Loss: 3.43691349029541 | KNN Loss: 2.4403629302978516 | BCE Loss: 0.996550440788269
Epoch 497 / 500 | iteration 20 / 30 | Total Loss: 3.407719373703003 | KNN Loss: 2.4244635105133057 | BCE Loss: 0.9832558631896973
Epoch 497 / 500 | iteration 25 / 30 | Total Loss: 3.4951319694519043 | KNN Loss: 2.4849376678466797 | BCE Loss: 1.0101943016052246
Epoch 498 / 500 | iteration 0 / 30 | Total Loss: 3.4707562923431396 | KNN Loss: 2.453183650970459 | BCE Loss: 1.0175726413726807
Epoch 498 / 500 | iteration 5 / 30 | Total Loss: 3.4635910987854004 | KNN Loss: 2.4739856719970703 | BCE Loss: 0.9896055459976196
Epoch 498 / 500 | iteration 10 / 30 | Total Loss: 3.4412012100219727 | KNN Loss: 2.442117929

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.8485,  4.5208,  3.0494,  2.1339,  3.6342,  0.9707,  1.8464,  2.4946,
          2.8739,  1.1835,  2.6940,  2.8041,  0.4477,  2.2759,  0.6752,  1.7183,
          1.1334,  3.7789,  2.5588,  0.8831,  1.1705,  1.5609,  2.6076,  2.4545,
          3.0419,  0.7232,  2.4458,  0.8959,  1.3016,  0.3329, -0.1169,  0.7758,
          0.1671, -0.0449,  0.5050,  0.4908,  0.8311,  1.7182,  1.0354,  0.3517,
          0.7787, -0.2847,  0.0153,  1.3845,  2.5655,  0.6838, -0.5005, -0.4219,
          1.6837,  2.9835,  2.2762, -0.1934,  0.7293,  0.9027, -0.5613,  0.6813,
          1.0101,  0.4168,  1.8063,  1.5576, -0.4391,  0.3404,  0.1638,  1.2698,
          1.8243,  0.3636, -1.4577, -0.0348,  2.2464,  2.4268,  1.2849,  0.8021,
          0.1023,  2.4085,  1.5305,  1.6561,  0.4774,  0.2325,  0.6024,  1.7814,
          0.1514,  0.6269,  0.3638, -0.0341,  0.1555, -0.8348, -2.1540, -0.0870,
          0.7430, -2.2588,  0.6612, -0.3409, -0.6863, -0.9409,  0.9034,  1.6642,
         -0.6596, -0.6228,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 77.00it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 002 | Total loss: 9.645 | Reg loss: 0.012 | Tree loss: 9.645 | Accuracy: 0.000000 | 1.038 sec/iter
Epoch: 00 | Batch: 001 / 002 | Total loss: 9.639 | Reg loss: 0.011 | Tree loss: 9.639 | Accuracy: 0.000000 | 0.952 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 01 | Batch: 000 / 002 | Total loss: 9.635 | Reg loss: 0.003 | Tree loss: 9.635 | Accuracy: 0.000000 | 1.088 sec/iter
Epoch: 01 | Batch: 001 / 002 | Total loss: 9.628 | Reg loss: 0.003 | Tree loss: 9.628 | Accuracy: 0.000000 | 1.028 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 

Epoch: 15 | Batch: 000 / 002 | Total loss: 9.578 | Reg loss: 0.004 | Tree loss: 9.578 | Accuracy: 0.000000 | 1.092 sec/iter
Epoch: 15 | Batch: 001 / 002 | Total loss: 9.568 | Reg loss: 0.004 | Tree loss: 9.568 | Accuracy: 0.009390 | 1.084 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 16 | Batch: 000 / 002 | Total loss: 9.572 | Reg loss: 0.004 | Tree loss: 9.572 | Accuracy: 0.001953 | 1.093 sec/iter
Epoch: 16 | Batch: 001 / 002 | Total loss: 9.567 | Reg loss: 0.004 | Tree loss: 9.567 | Accuracy: 0.007042 | 1.086 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 31 | Batch: 000 / 002 | Total loss: 9.514 | Reg loss: 0.006 | Tree loss: 9.514 | Accuracy: 0.269531 | 1.093 sec/iter
Epoch: 31 | Batch: 001 / 002 | Total loss: 9.505 | Reg loss: 0.006 | Tree loss: 9.505 | Accuracy: 0.295775 | 1.09 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 32 | Batch: 000 / 002 | Total loss: 9.510 | Reg loss: 0.006 | Tree loss: 9.510 | Accuracy: 0.259766 | 1.093 sec/iter
Epoch: 32 | Batch: 001 / 002 | Total loss: 9.503 

Epoch: 46 | Batch: 000 / 002 | Total loss: 9.457 | Reg loss: 0.007 | Tree loss: 9.457 | Accuracy: 0.267578 | 1.092 sec/iter
Epoch: 46 | Batch: 001 / 002 | Total loss: 9.446 | Reg loss: 0.007 | Tree loss: 9.446 | Accuracy: 0.298122 | 1.089 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 47 | Batch: 000 / 002 | Total loss: 9.448 | Reg loss: 0.007 | Tree loss: 9.448 | Accuracy: 0.298828 | 1.092 sec/iter
Epoch: 47 | Batch: 001 / 002 | Total loss: 9.448 | Reg loss: 0.007 | Tree loss: 9.448 | Accuracy: 0.260563 | 1.09 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6:

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 62 | Batch: 000 / 002 | Total loss: 9.389 | Reg loss: 0.009 | Tree loss: 9.389 | Accuracy: 0.263672 | 1.092 sec/iter
Epoch: 62 | Batch: 001 / 002 | Total loss: 9.374 | Reg loss: 0.009 | Tree loss: 9.374 | Accuracy: 0.302817 | 1.09 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 63 | Batch: 000 / 002 | Total loss: 9.377 | Reg loss: 0.009 | Tree loss: 9.377 | Accuracy: 0.296875 | 1.092 sec/iter
Epoch: 63 | Batch: 001 / 002 | Total loss: 9.379 

Epoch: 77 | Batch: 000 / 002 | Total loss: 9.308 | Reg loss: 0.010 | Tree loss: 9.308 | Accuracy: 0.273438 | 1.091 sec/iter
Epoch: 77 | Batch: 001 / 002 | Total loss: 9.289 | Reg loss: 0.010 | Tree loss: 9.289 | Accuracy: 0.291080 | 1.09 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 78 | Batch: 000 / 002 | Total loss: 9.301 | Reg loss: 0.010 | Tree loss: 9.301 | Accuracy: 0.263672 | 1.092 sec/iter
Epoch: 78 | Batch: 001 / 002 | Total loss: 9.284 | Reg loss: 0.011 | Tree loss: 9.284 | Accuracy: 0.302817 | 1.09 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 93 | Batch: 000 / 002 | Total loss: 9.177 | Reg loss: 0.013 | Tree loss: 9.177 | Accuracy: 0.292969 | 1.092 sec/iter
Epoch: 93 | Batch: 001 / 002 | Total loss: 9.166 | Reg loss: 0.013 | Tree loss: 9.166 | Accuracy: 0.267606 | 1.091 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 94 | Batch: 000 / 002 | Total loss: 9.170 | Reg loss: 0.013 | Tree loss: 9.170 | Accuracy: 0.279297 | 1.093 sec/iter
Epoch: 94 | Batch: 001 / 002 | Total loss: 9.154

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 10.0


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 1024


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

501
188
42

  return np.log(1 / (1 - x))



7
1
1
136


3
3
2
1


18
1
15
9
5
5


Average comprehensibility: 50.94921875
std comprehensibility: 2.7834189883394194
var comprehensibility: 7.7474212646484375
minimum comprehensibility: 44
maximum comprehensibility: 60
