In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [3]:
# Notebook-wide configuration.
k = 8            # forwarded to KNNLoss(k=k) in the training cell
tree_depth = 6   # not used in the cells shown here; presumably the SDT depth -- TODO confirm
# Prefer the GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA instead of failing on the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): absolute, machine-specific path -- adjust for your environment.
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [4]:
# Build the dataset object from the CSV at `dataset_path`. No transforms are
# passed here; `transform` / `target_transform` are assigned in a later cell.
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [5]:
# Optimisation hyperparameters consumed by the training cell.
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5  # print a progress line every `log_every` iterations

# Autoencoder whose first argument is the number of distinct items in the dataset.
# NOTE(review): the meaning of 50 and 4 is not visible here -- presumably
# hidden / latent sizes; confirm against network.auto_encoder.AutoEncoder.
model = AutoEncoder(dataset.n_items, 50, 4)
model = model.train().to(device)

In [6]:
# Input pipeline: RemoveItemsTransform(p=0.5) followed by binary encoding.
# NOTE(review): presumably this drops items at random so the model has to
# reconstruct the full basket from a partial one -- confirm in
# stream_generators.market_basket_dataset.
dataset.transform = torchvision.transforms.Compose(
    [
        RemoveItemsTransform(p=0.5),
        BinaryEncodingTransform(mapping=dataset.items_to_idx),
    ]
)

# Target pipeline: binary encoding only, with no item removal.
dataset.target_transform = torchvision.transforms.Compose(
    [BinaryEncodingTransform(mapping=dataset.items_to_idx)]
)

In [7]:
# Train the autoencoder on a weighted BCE reconstruction loss plus a KNN loss
# computed on the intermediate output returned by the model.
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=1,
                                        pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
# Scale the learning rate by 0.7 when the mean epoch loss stops improving.
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []  # mean total loss per epoch
alpha = 10/170
gamma = 2
# Per-element weights for the reconstruction loss; loop-invariant, so computed once.
zero_weight = alpha ** gamma        # applied where target == 0
one_weight = (1 - alpha) ** gamma   # applied where target == 1
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)

        # Despite the name, `mse_loss` holds a BCE-with-logits loss.
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = zero_weight
        mask[target == 1] = one_weight
        # Sum over the last dimension, then average over the batch.
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()

        # The KNN term is best-effort: when it is infinite or KNNLoss raises
        # ValueError, fall back to a zero *tensor* (same dtype/device as the
        # BCE term). A plain Python `0` here would break `knn_loss.item()` in
        # the log line below.
        # NOTE(review): the ValueError presumably comes from batches too small
        # for k neighbours -- confirm in losses.knn_loss.
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = mse_loss.new_zeros(())
        except ValueError:
            knn_loss = mse_loss.new_zeros(())

        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")

    # `iteration` is the last batch index, so iteration + 1 is the batch count.
    epoch_loss = total_loss / (iteration + 1)
    scheduler.step(epoch_loss)
    losses.append(epoch_loss)

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.22247314453125 | KNN Loss: 6.22360897064209 | BCE Loss: 1.998863935470581
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.218988418579102 | KNN Loss: 6.223231792449951 | BCE Loss: 1.9957571029663086
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.13967514038086 | KNN Loss: 6.222769737243652 | BCE Loss: 1.9169052839279175
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.164443969726562 | KNN Loss: 6.222145080566406 | BCE Loss: 1.9422986507415771
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.195916175842285 | KNN Loss: 6.222169399261475 | BCE Loss: 1.9737468957901
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.157309532165527 | KNN Loss: 6.221463680267334 | BCE Loss: 1.9358460903167725
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.15555191040039 | KNN Loss: 6.221240043640137 | BCE Loss: 1.934312105178833
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.12782096862793 | KNN Loss: 6.2208991050720215 | BCE Loss: 1.90692198276519

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.264458656311035 | KNN Loss: 4.099384307861328 | BCE Loss: 1.165074110031128
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.162677764892578 | KNN Loss: 4.019097805023193 | BCE Loss: 1.1435799598693848
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.036418914794922 | KNN Loss: 3.8690450191497803 | BCE Loss: 1.1673741340637207
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 4.942875862121582 | KNN Loss: 3.8023130893707275 | BCE Loss: 1.1405625343322754
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 4.766353607177734 | KNN Loss: 3.647848129272461 | BCE Loss: 1.1185054779052734
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 4.7572808265686035 | KNN Loss: 3.615638494491577 | BCE Loss: 1.141642451286316
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 4.646496772766113 | KNN Loss: 3.518486499786377 | BCE Loss: 1.1280100345611572
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 4.577668190002441 | KNN Loss: 3.443626880645752 | BCE Los

Epoch 21 / 500 | iteration 15 / 30 | Total Loss: 3.7771406173706055 | KNN Loss: 2.7218680381774902 | BCE Loss: 1.0552725791931152
Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 3.9323348999023438 | KNN Loss: 2.8805902004241943 | BCE Loss: 1.0517446994781494
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 3.7543439865112305 | KNN Loss: 2.7042951583862305 | BCE Loss: 1.0500487089157104
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 3.789013385772705 | KNN Loss: 2.7432188987731934 | BCE Loss: 1.0457944869995117
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 3.732511520385742 | KNN Loss: 2.7018165588378906 | BCE Loss: 1.0306950807571411
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 3.727170705795288 | KNN Loss: 2.662597417831421 | BCE Loss: 1.0645732879638672
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 3.7364912033081055 | KNN Loss: 2.7041547298431396 | BCE Loss: 1.0323365926742554
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 3.8142006397247314 | KNN Loss: 2.73337197303771

Epoch 32 / 500 | iteration 5 / 30 | Total Loss: 3.669727325439453 | KNN Loss: 2.654287099838257 | BCE Loss: 1.0154403448104858
Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 3.717719554901123 | KNN Loss: 2.6808359622955322 | BCE Loss: 1.0368835926055908
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 3.7720143795013428 | KNN Loss: 2.7191321849823 | BCE Loss: 1.052882194519043
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 3.696995735168457 | KNN Loss: 2.63246750831604 | BCE Loss: 1.0645281076431274
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 3.6508092880249023 | KNN Loss: 2.6109132766723633 | BCE Loss: 1.039896011352539
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 3.650197982788086 | KNN Loss: 2.6085948944091797 | BCE Loss: 1.0416029691696167
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 3.7031362056732178 | KNN Loss: 2.636103868484497 | BCE Loss: 1.0670323371887207
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 3.753002643585205 | KNN Loss: 2.708535671234131 | BCE Los

Epoch 42 / 500 | iteration 25 / 30 | Total Loss: 3.645094156265259 | KNN Loss: 2.617058277130127 | BCE Loss: 1.0280358791351318
Epoch 43 / 500 | iteration 0 / 30 | Total Loss: 3.6482696533203125 | KNN Loss: 2.6339664459228516 | BCE Loss: 1.0143030881881714
Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 3.686985492706299 | KNN Loss: 2.6422319412231445 | BCE Loss: 1.0447534322738647
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 3.5824618339538574 | KNN Loss: 2.5724573135375977 | BCE Loss: 1.0100046396255493
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 3.63242769241333 | KNN Loss: 2.593766927719116 | BCE Loss: 1.0386606454849243
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 3.6196696758270264 | KNN Loss: 2.5609395503997803 | BCE Loss: 1.058730125427246
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 3.5884146690368652 | KNN Loss: 2.5718188285827637 | BCE Loss: 1.0165958404541016
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 3.5769121646881104 | KNN Loss: 2.5403707027435303 |

Epoch 53 / 500 | iteration 15 / 30 | Total Loss: 3.6720290184020996 | KNN Loss: 2.624608039855957 | BCE Loss: 1.0474209785461426
Epoch 53 / 500 | iteration 20 / 30 | Total Loss: 3.5511789321899414 | KNN Loss: 2.509263277053833 | BCE Loss: 1.0419156551361084
Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 3.590444326400757 | KNN Loss: 2.5605013370513916 | BCE Loss: 1.0299429893493652
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 3.6001973152160645 | KNN Loss: 2.572086811065674 | BCE Loss: 1.0281105041503906
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 3.6469030380249023 | KNN Loss: 2.599118232727051 | BCE Loss: 1.0477849245071411
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 3.609492063522339 | KNN Loss: 2.5756914615631104 | BCE Loss: 1.0338006019592285
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 3.609755039215088 | KNN Loss: 2.5621235370635986 | BCE Loss: 1.0476315021514893
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 3.5554230213165283 | KNN Loss: 2.5264675617218018 

Epoch 64 / 500 | iteration 5 / 30 | Total Loss: 3.5127980709075928 | KNN Loss: 2.5170600414276123 | BCE Loss: 0.9957380890846252
Epoch 64 / 500 | iteration 10 / 30 | Total Loss: 3.532395839691162 | KNN Loss: 2.5240042209625244 | BCE Loss: 1.0083917379379272
Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 3.590880870819092 | KNN Loss: 2.5665719509124756 | BCE Loss: 1.0243088006973267
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 3.6283459663391113 | KNN Loss: 2.5616445541381836 | BCE Loss: 1.0667012929916382
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 3.556499719619751 | KNN Loss: 2.529820680618286 | BCE Loss: 1.0266790390014648
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 3.5364365577697754 | KNN Loss: 2.518228054046631 | BCE Loss: 1.018208622932434
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 3.5560708045959473 | KNN Loss: 2.529261350631714 | BCE Loss: 1.0268094539642334
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 3.556000232696533 | KNN Loss: 2.5382418632507324 | 

Epoch 74 / 500 | iteration 25 / 30 | Total Loss: 3.567671298980713 | KNN Loss: 2.5490541458129883 | BCE Loss: 1.0186171531677246
Epoch 75 / 500 | iteration 0 / 30 | Total Loss: 3.580475091934204 | KNN Loss: 2.5538806915283203 | BCE Loss: 1.0265944004058838
Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 3.523573875427246 | KNN Loss: 2.505808115005493 | BCE Loss: 1.0177656412124634
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 3.5787439346313477 | KNN Loss: 2.5474350452423096 | BCE Loss: 1.0313090085983276
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 3.542860507965088 | KNN Loss: 2.499234914779663 | BCE Loss: 1.0436254739761353
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 3.5579254627227783 | KNN Loss: 2.543264865875244 | BCE Loss: 1.0146605968475342
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 3.5468859672546387 | KNN Loss: 2.5268537998199463 | BCE Loss: 1.0200320482254028
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 3.5171244144439697 | KNN Loss: 2.4885737895965576 |

Epoch 85 / 500 | iteration 15 / 30 | Total Loss: 3.6263134479522705 | KNN Loss: 2.614882230758667 | BCE Loss: 1.0114312171936035
Epoch 85 / 500 | iteration 20 / 30 | Total Loss: 3.5793681144714355 | KNN Loss: 2.5530636310577393 | BCE Loss: 1.0263043642044067
Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 3.6473989486694336 | KNN Loss: 2.5911848545074463 | BCE Loss: 1.0562139749526978
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 3.534693717956543 | KNN Loss: 2.537569046020508 | BCE Loss: 0.9971246123313904
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 3.619964599609375 | KNN Loss: 2.582371711730957 | BCE Loss: 1.0375927686691284
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 3.5369973182678223 | KNN Loss: 2.508974075317383 | BCE Loss: 1.0280232429504395
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 3.624163866043091 | KNN Loss: 2.5709710121154785 | BCE Loss: 1.0531928539276123
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 3.593811511993408 | KNN Loss: 2.568632125854492 | 

Epoch 96 / 500 | iteration 5 / 30 | Total Loss: 3.5317749977111816 | KNN Loss: 2.5460212230682373 | BCE Loss: 0.9857537150382996
Epoch 96 / 500 | iteration 10 / 30 | Total Loss: 3.5643692016601562 | KNN Loss: 2.512495994567871 | BCE Loss: 1.0518733263015747
Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 3.5798068046569824 | KNN Loss: 2.551253318786621 | BCE Loss: 1.0285536050796509
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 3.5565826892852783 | KNN Loss: 2.5273914337158203 | BCE Loss: 1.029191255569458
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 3.5695745944976807 | KNN Loss: 2.5380046367645264 | BCE Loss: 1.0315699577331543
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 3.5376482009887695 | KNN Loss: 2.5291762351989746 | BCE Loss: 1.0084720849990845
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 3.6098623275756836 | KNN Loss: 2.5416746139526367 | BCE Loss: 1.0681875944137573
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 3.5206894874572754 | KNN Loss: 2.52030706405639

Epoch 106 / 500 | iteration 25 / 30 | Total Loss: 3.522369146347046 | KNN Loss: 2.5107922554016113 | BCE Loss: 1.0115768909454346
Epoch 107 / 500 | iteration 0 / 30 | Total Loss: 3.536799430847168 | KNN Loss: 2.528308153152466 | BCE Loss: 1.0084912776947021
Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 3.59501314163208 | KNN Loss: 2.581695079803467 | BCE Loss: 1.0133180618286133
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 3.5243887901306152 | KNN Loss: 2.4859836101531982 | BCE Loss: 1.0384052991867065
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 3.5179390907287598 | KNN Loss: 2.479325771331787 | BCE Loss: 1.0386133193969727
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 3.549006223678589 | KNN Loss: 2.520850658416748 | BCE Loss: 1.0281555652618408
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 3.4767372608184814 | KNN Loss: 2.4805407524108887 | BCE Loss: 0.9961965084075928
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 3.5234150886535645 | KNN Loss: 2.489650249481

Epoch 117 / 500 | iteration 10 / 30 | Total Loss: 3.5059053897857666 | KNN Loss: 2.4881539344787598 | BCE Loss: 1.0177514553070068
Epoch 117 / 500 | iteration 15 / 30 | Total Loss: 3.536095380783081 | KNN Loss: 2.492734432220459 | BCE Loss: 1.043360948562622
Epoch 117 / 500 | iteration 20 / 30 | Total Loss: 3.4987237453460693 | KNN Loss: 2.4899535179138184 | BCE Loss: 1.008770227432251
Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 3.5520997047424316 | KNN Loss: 2.5070548057556152 | BCE Loss: 1.0450448989868164
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 3.5129337310791016 | KNN Loss: 2.4810712337493896 | BCE Loss: 1.031862497329712
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 3.5607287883758545 | KNN Loss: 2.5459887981414795 | BCE Loss: 1.014739990234375
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 3.527527332305908 | KNN Loss: 2.5120439529418945 | BCE Loss: 1.0154833793640137
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 3.5499892234802246 | KNN Loss: 2.528313159

Epoch 128 / 500 | iteration 0 / 30 | Total Loss: 3.5169739723205566 | KNN Loss: 2.4830057621002197 | BCE Loss: 1.033968210220337
Epoch 128 / 500 | iteration 5 / 30 | Total Loss: 3.5188164710998535 | KNN Loss: 2.483337640762329 | BCE Loss: 1.0354788303375244
Epoch 128 / 500 | iteration 10 / 30 | Total Loss: 3.533853530883789 | KNN Loss: 2.5042545795440674 | BCE Loss: 1.0295988321304321
Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 3.5585851669311523 | KNN Loss: 2.494270086288452 | BCE Loss: 1.0643150806427002
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 3.4947192668914795 | KNN Loss: 2.489332437515259 | BCE Loss: 1.0053868293762207
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 3.5002036094665527 | KNN Loss: 2.507985830307007 | BCE Loss: 0.9922178387641907
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 3.5314083099365234 | KNN Loss: 2.4907000064849854 | BCE Loss: 1.0407081842422485
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 3.4951438903808594 | KNN Loss: 2.4878973960

Epoch 138 / 500 | iteration 15 / 30 | Total Loss: 3.4983348846435547 | KNN Loss: 2.5239028930664062 | BCE Loss: 0.9744319319725037
Epoch 138 / 500 | iteration 20 / 30 | Total Loss: 3.4749884605407715 | KNN Loss: 2.4800217151641846 | BCE Loss: 0.9949667453765869
Epoch 138 / 500 | iteration 25 / 30 | Total Loss: 3.597290515899658 | KNN Loss: 2.561082363128662 | BCE Loss: 1.036208152770996
Epoch 139 / 500 | iteration 0 / 30 | Total Loss: 3.49273681640625 | KNN Loss: 2.4890799522399902 | BCE Loss: 1.0036567449569702
Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 3.526437282562256 | KNN Loss: 2.478039026260376 | BCE Loss: 1.0483983755111694
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 3.553596019744873 | KNN Loss: 2.5263984203338623 | BCE Loss: 1.0271977186203003
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 3.5084726810455322 | KNN Loss: 2.482048511505127 | BCE Loss: 1.0264241695404053
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 3.53173565864563 | KNN Loss: 2.5077357292175

Epoch 149 / 500 | iteration 5 / 30 | Total Loss: 3.5511744022369385 | KNN Loss: 2.5059008598327637 | BCE Loss: 1.0452735424041748
Epoch 149 / 500 | iteration 10 / 30 | Total Loss: 3.457425355911255 | KNN Loss: 2.454634189605713 | BCE Loss: 1.002791166305542
Epoch 149 / 500 | iteration 15 / 30 | Total Loss: 3.542464256286621 | KNN Loss: 2.4967877864837646 | BCE Loss: 1.045676589012146
Epoch 149 / 500 | iteration 20 / 30 | Total Loss: 3.4780526161193848 | KNN Loss: 2.4730188846588135 | BCE Loss: 1.0050338506698608
Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 3.536585569381714 | KNN Loss: 2.4935142993927 | BCE Loss: 1.0430712699890137
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 3.5003035068511963 | KNN Loss: 2.4793686866760254 | BCE Loss: 1.020934820175171
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 3.5307364463806152 | KNN Loss: 2.518127202987671 | BCE Loss: 1.0126093626022339
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 3.481475353240967 | KNN Loss: 2.486847162246704

Epoch 159 / 500 | iteration 20 / 30 | Total Loss: 3.458702564239502 | KNN Loss: 2.4547579288482666 | BCE Loss: 1.0039445161819458
Epoch 159 / 500 | iteration 25 / 30 | Total Loss: 3.4868555068969727 | KNN Loss: 2.447715997695923 | BCE Loss: 1.0391395092010498
Epoch 160 / 500 | iteration 0 / 30 | Total Loss: 3.506044864654541 | KNN Loss: 2.491935968399048 | BCE Loss: 1.0141090154647827
Epoch 160 / 500 | iteration 5 / 30 | Total Loss: 3.468935012817383 | KNN Loss: 2.47986102104187 | BCE Loss: 0.9890739321708679
Epoch 160 / 500 | iteration 10 / 30 | Total Loss: 3.5129895210266113 | KNN Loss: 2.478776693344116 | BCE Loss: 1.0342128276824951
Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 3.515824317932129 | KNN Loss: 2.489143133163452 | BCE Loss: 1.0266811847686768
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 3.467853546142578 | KNN Loss: 2.4754226207733154 | BCE Loss: 0.9924309253692627
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 3.4697697162628174 | KNN Loss: 2.4713985919952

Epoch 170 / 500 | iteration 10 / 30 | Total Loss: 3.5078723430633545 | KNN Loss: 2.4744577407836914 | BCE Loss: 1.033414602279663
Epoch 170 / 500 | iteration 15 / 30 | Total Loss: 3.52605938911438 | KNN Loss: 2.498671531677246 | BCE Loss: 1.0273878574371338
Epoch 170 / 500 | iteration 20 / 30 | Total Loss: 3.5003716945648193 | KNN Loss: 2.477769613265991 | BCE Loss: 1.0226020812988281
Epoch 170 / 500 | iteration 25 / 30 | Total Loss: 3.4787964820861816 | KNN Loss: 2.4643561840057373 | BCE Loss: 1.0144402980804443
Epoch 171 / 500 | iteration 0 / 30 | Total Loss: 3.5042099952697754 | KNN Loss: 2.5056469440460205 | BCE Loss: 0.9985629916191101
Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 3.4526801109313965 | KNN Loss: 2.458416223526001 | BCE Loss: 0.9942638278007507
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 3.5223536491394043 | KNN Loss: 2.4803550243377686 | BCE Loss: 1.0419986248016357
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 3.5059032440185547 | KNN Loss: 2.49052429

Epoch 180 / 500 | iteration 25 / 30 | Total Loss: 3.533418655395508 | KNN Loss: 2.493548631668091 | BCE Loss: 1.0398701429367065
Epoch 181 / 500 | iteration 0 / 30 | Total Loss: 3.4665982723236084 | KNN Loss: 2.473998546600342 | BCE Loss: 0.9925997257232666
Epoch 181 / 500 | iteration 5 / 30 | Total Loss: 3.4580445289611816 | KNN Loss: 2.458850860595703 | BCE Loss: 0.9991936087608337
Epoch 181 / 500 | iteration 10 / 30 | Total Loss: 3.4841861724853516 | KNN Loss: 2.4744162559509277 | BCE Loss: 1.0097699165344238
Epoch 181 / 500 | iteration 15 / 30 | Total Loss: 3.5424728393554688 | KNN Loss: 2.505797863006592 | BCE Loss: 1.036674976348877
Epoch 181 / 500 | iteration 20 / 30 | Total Loss: 3.4985902309417725 | KNN Loss: 2.4562296867370605 | BCE Loss: 1.042360544204712
Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 3.5357394218444824 | KNN Loss: 2.4982433319091797 | BCE Loss: 1.0374960899353027
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 3.5101611614227295 | KNN Loss: 2.4732043743

Epoch 191 / 500 | iteration 10 / 30 | Total Loss: 3.488408088684082 | KNN Loss: 2.4597349166870117 | BCE Loss: 1.0286731719970703
Epoch 191 / 500 | iteration 15 / 30 | Total Loss: 3.4402153491973877 | KNN Loss: 2.420341968536377 | BCE Loss: 1.0198733806610107
Epoch 191 / 500 | iteration 20 / 30 | Total Loss: 3.546417236328125 | KNN Loss: 2.517951250076294 | BCE Loss: 1.028465986251831
Epoch 191 / 500 | iteration 25 / 30 | Total Loss: 3.5082762241363525 | KNN Loss: 2.4921274185180664 | BCE Loss: 1.0161488056182861
Epoch 192 / 500 | iteration 0 / 30 | Total Loss: 3.4955859184265137 | KNN Loss: 2.448272228240967 | BCE Loss: 1.0473136901855469
Epoch 192 / 500 | iteration 5 / 30 | Total Loss: 3.4917759895324707 | KNN Loss: 2.464336395263672 | BCE Loss: 1.0274395942687988
Epoch 192 / 500 | iteration 10 / 30 | Total Loss: 3.5170273780822754 | KNN Loss: 2.4934747219085693 | BCE Loss: 1.023552656173706
Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 3.488074779510498 | KNN Loss: 2.47661733627

Epoch 202 / 500 | iteration 0 / 30 | Total Loss: 3.4649863243103027 | KNN Loss: 2.4517998695373535 | BCE Loss: 1.0131865739822388
Epoch 202 / 500 | iteration 5 / 30 | Total Loss: 3.4816482067108154 | KNN Loss: 2.457880973815918 | BCE Loss: 1.0237672328948975
Epoch 202 / 500 | iteration 10 / 30 | Total Loss: 3.46248197555542 | KNN Loss: 2.4453697204589844 | BCE Loss: 1.0171122550964355
Epoch 202 / 500 | iteration 15 / 30 | Total Loss: 3.452958583831787 | KNN Loss: 2.4343101978302 | BCE Loss: 1.018648386001587
Epoch 202 / 500 | iteration 20 / 30 | Total Loss: 3.489839553833008 | KNN Loss: 2.4557693004608154 | BCE Loss: 1.0340701341629028
Epoch 202 / 500 | iteration 25 / 30 | Total Loss: 3.5080325603485107 | KNN Loss: 2.4709293842315674 | BCE Loss: 1.0371031761169434
Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 3.4818124771118164 | KNN Loss: 2.4641575813293457 | BCE Loss: 1.0176548957824707
Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 3.5288023948669434 | KNN Loss: 2.5167119503021

Epoch 212 / 500 | iteration 20 / 30 | Total Loss: 3.47157883644104 | KNN Loss: 2.450037717819214 | BCE Loss: 1.0215411186218262
Epoch 212 / 500 | iteration 25 / 30 | Total Loss: 3.5053529739379883 | KNN Loss: 2.5070388317108154 | BCE Loss: 0.9983142018318176
Epoch 213 / 500 | iteration 0 / 30 | Total Loss: 3.4697532653808594 | KNN Loss: 2.4658026695251465 | BCE Loss: 1.003950595855713
Epoch 213 / 500 | iteration 5 / 30 | Total Loss: 3.509462356567383 | KNN Loss: 2.4862470626831055 | BCE Loss: 1.0232152938842773
Epoch 213 / 500 | iteration 10 / 30 | Total Loss: 3.473299503326416 | KNN Loss: 2.4597206115722656 | BCE Loss: 1.0135788917541504
Epoch 213 / 500 | iteration 15 / 30 | Total Loss: 3.4736268520355225 | KNN Loss: 2.466838836669922 | BCE Loss: 1.0067880153656006
Epoch 213 / 500 | iteration 20 / 30 | Total Loss: 3.488909959793091 | KNN Loss: 2.461099624633789 | BCE Loss: 1.0278103351593018
Epoch 213 / 500 | iteration 25 / 30 | Total Loss: 3.48667311668396 | KNN Loss: 2.4552118778228

Epoch 223 / 500 | iteration 5 / 30 | Total Loss: 3.514244794845581 | KNN Loss: 2.4862587451934814 | BCE Loss: 1.0279860496520996
Epoch 223 / 500 | iteration 10 / 30 | Total Loss: 3.477745771408081 | KNN Loss: 2.4753332138061523 | BCE Loss: 1.0024125576019287
Epoch 223 / 500 | iteration 15 / 30 | Total Loss: 3.5421271324157715 | KNN Loss: 2.5129120349884033 | BCE Loss: 1.0292149782180786
Epoch 223 / 500 | iteration 20 / 30 | Total Loss: 3.5008695125579834 | KNN Loss: 2.4774293899536133 | BCE Loss: 1.0234401226043701
Epoch 223 / 500 | iteration 25 / 30 | Total Loss: 3.4838380813598633 | KNN Loss: 2.437594175338745 | BCE Loss: 1.0462437868118286
Epoch 224 / 500 | iteration 0 / 30 | Total Loss: 3.4738576412200928 | KNN Loss: 2.4507553577423096 | BCE Loss: 1.0231022834777832
Epoch 224 / 500 | iteration 5 / 30 | Total Loss: 3.470026731491089 | KNN Loss: 2.452702760696411 | BCE Loss: 1.0173239707946777
Epoch 224 / 500 | iteration 10 / 30 | Total Loss: 3.5128769874572754 | KNN Loss: 2.48346185

Epoch 233 / 500 | iteration 25 / 30 | Total Loss: 3.479222059249878 | KNN Loss: 2.4314968585968018 | BCE Loss: 1.0477252006530762
Epoch 234 / 500 | iteration 0 / 30 | Total Loss: 3.483564615249634 | KNN Loss: 2.462078094482422 | BCE Loss: 1.021486520767212
Epoch 234 / 500 | iteration 5 / 30 | Total Loss: 3.4659769535064697 | KNN Loss: 2.4426510334014893 | BCE Loss: 1.0233259201049805
Epoch 234 / 500 | iteration 10 / 30 | Total Loss: 3.45291805267334 | KNN Loss: 2.451627492904663 | BCE Loss: 1.0012905597686768
Epoch 234 / 500 | iteration 15 / 30 | Total Loss: 3.463162660598755 | KNN Loss: 2.4324777126312256 | BCE Loss: 1.0306849479675293
Epoch 234 / 500 | iteration 20 / 30 | Total Loss: 3.500356674194336 | KNN Loss: 2.4976255893707275 | BCE Loss: 1.0027310848236084
Epoch 234 / 500 | iteration 25 / 30 | Total Loss: 3.48954701423645 | KNN Loss: 2.4708573818206787 | BCE Loss: 1.0186896324157715
Epoch 235 / 500 | iteration 0 / 30 | Total Loss: 3.484165906906128 | KNN Loss: 2.483141660690307

Epoch 244 / 500 | iteration 15 / 30 | Total Loss: 3.4798247814178467 | KNN Loss: 2.4730236530303955 | BCE Loss: 1.0068011283874512
Epoch 244 / 500 | iteration 20 / 30 | Total Loss: 3.4971489906311035 | KNN Loss: 2.463555097579956 | BCE Loss: 1.0335938930511475
Epoch 244 / 500 | iteration 25 / 30 | Total Loss: 3.486262798309326 | KNN Loss: 2.4674971103668213 | BCE Loss: 1.0187656879425049
Epoch 245 / 500 | iteration 0 / 30 | Total Loss: 3.4814612865448 | KNN Loss: 2.4786269664764404 | BCE Loss: 1.0028343200683594
Epoch 245 / 500 | iteration 5 / 30 | Total Loss: 3.5019030570983887 | KNN Loss: 2.5055670738220215 | BCE Loss: 0.996336042881012
Epoch 245 / 500 | iteration 10 / 30 | Total Loss: 3.4595654010772705 | KNN Loss: 2.430117607116699 | BCE Loss: 1.0294477939605713
Epoch 245 / 500 | iteration 15 / 30 | Total Loss: 3.4510812759399414 | KNN Loss: 2.4557316303253174 | BCE Loss: 0.995349645614624
Epoch 245 / 500 | iteration 20 / 30 | Total Loss: 3.5019068717956543 | KNN Loss: 2.4492497444

Epoch 255 / 500 | iteration 0 / 30 | Total Loss: 3.4731388092041016 | KNN Loss: 2.452968120574951 | BCE Loss: 1.02017080783844
Epoch 255 / 500 | iteration 5 / 30 | Total Loss: 3.4379377365112305 | KNN Loss: 2.4313271045684814 | BCE Loss: 1.0066105127334595
Epoch 255 / 500 | iteration 10 / 30 | Total Loss: 3.4497933387756348 | KNN Loss: 2.4283478260040283 | BCE Loss: 1.021445393562317
Epoch 255 / 500 | iteration 15 / 30 | Total Loss: 3.4532909393310547 | KNN Loss: 2.4507863521575928 | BCE Loss: 1.0025044679641724
Epoch 255 / 500 | iteration 20 / 30 | Total Loss: 3.4763026237487793 | KNN Loss: 2.4828197956085205 | BCE Loss: 0.9934829473495483
Epoch 255 / 500 | iteration 25 / 30 | Total Loss: 3.4676640033721924 | KNN Loss: 2.4532649517059326 | BCE Loss: 1.0143990516662598
Epoch 256 / 500 | iteration 0 / 30 | Total Loss: 3.4915008544921875 | KNN Loss: 2.449420690536499 | BCE Loss: 1.0420801639556885
Epoch 256 / 500 | iteration 5 / 30 | Total Loss: 3.4514124393463135 | KNN Loss: 2.420502662

Epoch 265 / 500 | iteration 15 / 30 | Total Loss: 3.4767072200775146 | KNN Loss: 2.4463605880737305 | BCE Loss: 1.0303466320037842
Epoch 265 / 500 | iteration 20 / 30 | Total Loss: 3.472763776779175 | KNN Loss: 2.45454740524292 | BCE Loss: 1.0182163715362549
Epoch 265 / 500 | iteration 25 / 30 | Total Loss: 3.5096664428710938 | KNN Loss: 2.4719526767730713 | BCE Loss: 1.037713646888733
Epoch 266 / 500 | iteration 0 / 30 | Total Loss: 3.495213747024536 | KNN Loss: 2.454793691635132 | BCE Loss: 1.0404200553894043
Epoch 266 / 500 | iteration 5 / 30 | Total Loss: 3.4730498790740967 | KNN Loss: 2.4558212757110596 | BCE Loss: 1.017228603363037
Epoch 266 / 500 | iteration 10 / 30 | Total Loss: 3.4918785095214844 | KNN Loss: 2.4824416637420654 | BCE Loss: 1.009436845779419
Epoch 266 / 500 | iteration 15 / 30 | Total Loss: 3.4809868335723877 | KNN Loss: 2.4488232135772705 | BCE Loss: 1.0321636199951172
Epoch 266 / 500 | iteration 20 / 30 | Total Loss: 3.454820394515991 | KNN Loss: 2.45113229751

Epoch 276 / 500 | iteration 5 / 30 | Total Loss: 3.507936477661133 | KNN Loss: 2.482923746109009 | BCE Loss: 1.0250126123428345
Epoch 276 / 500 | iteration 10 / 30 | Total Loss: 3.4933950901031494 | KNN Loss: 2.4860658645629883 | BCE Loss: 1.0073292255401611
Epoch 276 / 500 | iteration 15 / 30 | Total Loss: 3.481349468231201 | KNN Loss: 2.4509952068328857 | BCE Loss: 1.030354380607605
Epoch 276 / 500 | iteration 20 / 30 | Total Loss: 3.464601516723633 | KNN Loss: 2.4575014114379883 | BCE Loss: 1.007100224494934
Epoch 276 / 500 | iteration 25 / 30 | Total Loss: 3.509666681289673 | KNN Loss: 2.4763643741607666 | BCE Loss: 1.0333023071289062
Epoch 277 / 500 | iteration 0 / 30 | Total Loss: 3.456704616546631 | KNN Loss: 2.464801788330078 | BCE Loss: 0.9919029474258423
Epoch 277 / 500 | iteration 5 / 30 | Total Loss: 3.55757474899292 | KNN Loss: 2.4835000038146973 | BCE Loss: 1.0740747451782227
Epoch 277 / 500 | iteration 10 / 30 | Total Loss: 3.512572765350342 | KNN Loss: 2.454874038696289

Epoch 286 / 500 | iteration 20 / 30 | Total Loss: 3.5066757202148438 | KNN Loss: 2.4585676193237305 | BCE Loss: 1.0481081008911133
Epoch 286 / 500 | iteration 25 / 30 | Total Loss: 3.466860055923462 | KNN Loss: 2.4403300285339355 | BCE Loss: 1.0265300273895264
Epoch 287 / 500 | iteration 0 / 30 | Total Loss: 3.4325897693634033 | KNN Loss: 2.4182283878326416 | BCE Loss: 1.0143613815307617
Epoch 287 / 500 | iteration 5 / 30 | Total Loss: 3.4706597328186035 | KNN Loss: 2.4441328048706055 | BCE Loss: 1.026526927947998
Epoch 287 / 500 | iteration 10 / 30 | Total Loss: 3.4781744480133057 | KNN Loss: 2.4808146953582764 | BCE Loss: 0.9973596930503845
Epoch 287 / 500 | iteration 15 / 30 | Total Loss: 3.477360963821411 | KNN Loss: 2.4580814838409424 | BCE Loss: 1.0192794799804688
Epoch 287 / 500 | iteration 20 / 30 | Total Loss: 3.4490253925323486 | KNN Loss: 2.448157548904419 | BCE Loss: 1.0008678436279297
Epoch 287 / 500 | iteration 25 / 30 | Total Loss: 3.460467576980591 | KNN Loss: 2.4243748

Epoch 297 / 500 | iteration 5 / 30 | Total Loss: 3.5022807121276855 | KNN Loss: 2.47072172164917 | BCE Loss: 1.0315591096878052
Epoch 297 / 500 | iteration 10 / 30 | Total Loss: 3.4401299953460693 | KNN Loss: 2.434986114501953 | BCE Loss: 1.0051438808441162
Epoch 297 / 500 | iteration 15 / 30 | Total Loss: 3.4728567600250244 | KNN Loss: 2.4496099948883057 | BCE Loss: 1.0232467651367188
Epoch 297 / 500 | iteration 20 / 30 | Total Loss: 3.455112934112549 | KNN Loss: 2.4276559352874756 | BCE Loss: 1.0274568796157837
Epoch 297 / 500 | iteration 25 / 30 | Total Loss: 3.4682841300964355 | KNN Loss: 2.416081190109253 | BCE Loss: 1.052202820777893
Epoch 298 / 500 | iteration 0 / 30 | Total Loss: 3.461200714111328 | KNN Loss: 2.4450738430023193 | BCE Loss: 1.0161268711090088
Epoch 298 / 500 | iteration 5 / 30 | Total Loss: 3.438159227371216 | KNN Loss: 2.425985097885132 | BCE Loss: 1.012174129486084
Epoch 298 / 500 | iteration 10 / 30 | Total Loss: 3.5314488410949707 | KNN Loss: 2.4761121273040

Epoch 307 / 500 | iteration 20 / 30 | Total Loss: 3.511638879776001 | KNN Loss: 2.447681427001953 | BCE Loss: 1.0639574527740479
Epoch 307 / 500 | iteration 25 / 30 | Total Loss: 3.4326210021972656 | KNN Loss: 2.44679856300354 | BCE Loss: 0.9858224391937256
Epoch 308 / 500 | iteration 0 / 30 | Total Loss: 3.4680984020233154 | KNN Loss: 2.4589455127716064 | BCE Loss: 1.009152889251709
Epoch 308 / 500 | iteration 5 / 30 | Total Loss: 3.4335241317749023 | KNN Loss: 2.425384759902954 | BCE Loss: 1.0081393718719482
Epoch 308 / 500 | iteration 10 / 30 | Total Loss: 3.437720775604248 | KNN Loss: 2.4233813285827637 | BCE Loss: 1.0143393278121948
Epoch 308 / 500 | iteration 15 / 30 | Total Loss: 3.513740301132202 | KNN Loss: 2.443140745162964 | BCE Loss: 1.0705995559692383
Epoch 308 / 500 | iteration 20 / 30 | Total Loss: 3.434210777282715 | KNN Loss: 2.427663803100586 | BCE Loss: 1.0065470933914185
Epoch 308 / 500 | iteration 25 / 30 | Total Loss: 3.4323015213012695 | KNN Loss: 2.4193806648254

Epoch 318 / 500 | iteration 10 / 30 | Total Loss: 3.4940080642700195 | KNN Loss: 2.468130350112915 | BCE Loss: 1.025877833366394
Epoch 318 / 500 | iteration 15 / 30 | Total Loss: 3.4352598190307617 | KNN Loss: 2.4331188201904297 | BCE Loss: 1.0021411180496216
Epoch 318 / 500 | iteration 20 / 30 | Total Loss: 3.449550151824951 | KNN Loss: 2.457087993621826 | BCE Loss: 0.9924620389938354
Epoch 318 / 500 | iteration 25 / 30 | Total Loss: 3.461981773376465 | KNN Loss: 2.4400851726531982 | BCE Loss: 1.0218967199325562
Epoch 319 / 500 | iteration 0 / 30 | Total Loss: 3.518134117126465 | KNN Loss: 2.4606852531433105 | BCE Loss: 1.0574488639831543
Epoch 319 / 500 | iteration 5 / 30 | Total Loss: 3.4512832164764404 | KNN Loss: 2.4509356021881104 | BCE Loss: 1.00034761428833
Epoch 319 / 500 | iteration 10 / 30 | Total Loss: 3.463332414627075 | KNN Loss: 2.4546899795532227 | BCE Loss: 1.0086424350738525
Epoch 319 / 500 | iteration 15 / 30 | Total Loss: 3.448024034500122 | KNN Loss: 2.460364818572

Epoch 328 / 500 | iteration 25 / 30 | Total Loss: 3.4476542472839355 | KNN Loss: 2.423076868057251 | BCE Loss: 1.0245774984359741
Epoch 329 / 500 | iteration 0 / 30 | Total Loss: 3.4436161518096924 | KNN Loss: 2.4490084648132324 | BCE Loss: 0.9946077466011047
Epoch 329 / 500 | iteration 5 / 30 | Total Loss: 3.5088818073272705 | KNN Loss: 2.477179765701294 | BCE Loss: 1.0317020416259766
Epoch 329 / 500 | iteration 10 / 30 | Total Loss: 3.4990687370300293 | KNN Loss: 2.4678211212158203 | BCE Loss: 1.0312477350234985
Epoch 329 / 500 | iteration 15 / 30 | Total Loss: 3.471590042114258 | KNN Loss: 2.456876277923584 | BCE Loss: 1.0147137641906738
Epoch 329 / 500 | iteration 20 / 30 | Total Loss: 3.489309072494507 | KNN Loss: 2.4697301387786865 | BCE Loss: 1.0195789337158203
Epoch 329 / 500 | iteration 25 / 30 | Total Loss: 3.5376129150390625 | KNN Loss: 2.488628387451172 | BCE Loss: 1.048984408378601
Epoch 330 / 500 | iteration 0 / 30 | Total Loss: 3.4759342670440674 | KNN Loss: 2.4914088249

Epoch 339 / 500 | iteration 15 / 30 | Total Loss: 3.448765277862549 | KNN Loss: 2.4675395488739014 | BCE Loss: 0.9812256693840027
Epoch 339 / 500 | iteration 20 / 30 | Total Loss: 3.473170280456543 | KNN Loss: 2.4496469497680664 | BCE Loss: 1.023523211479187
Epoch 339 / 500 | iteration 25 / 30 | Total Loss: 3.449988842010498 | KNN Loss: 2.435469150543213 | BCE Loss: 1.0145196914672852
Epoch 340 / 500 | iteration 0 / 30 | Total Loss: 3.4823596477508545 | KNN Loss: 2.4804975986480713 | BCE Loss: 1.0018620491027832
Epoch 340 / 500 | iteration 5 / 30 | Total Loss: 3.4629969596862793 | KNN Loss: 2.4424171447753906 | BCE Loss: 1.0205798149108887
Epoch 340 / 500 | iteration 10 / 30 | Total Loss: 3.457097053527832 | KNN Loss: 2.4482970237731934 | BCE Loss: 1.0087999105453491
Epoch 340 / 500 | iteration 15 / 30 | Total Loss: 3.465200185775757 | KNN Loss: 2.425084114074707 | BCE Loss: 1.0401160717010498
Epoch 340 / 500 | iteration 20 / 30 | Total Loss: 3.485048532485962 | KNN Loss: 2.44348239898

Epoch 350 / 500 | iteration 0 / 30 | Total Loss: 3.4532532691955566 | KNN Loss: 2.4620425701141357 | BCE Loss: 0.9912107586860657
Epoch 350 / 500 | iteration 5 / 30 | Total Loss: 3.472604751586914 | KNN Loss: 2.435816764831543 | BCE Loss: 1.0367878675460815
Epoch 350 / 500 | iteration 10 / 30 | Total Loss: 3.4543752670288086 | KNN Loss: 2.4492125511169434 | BCE Loss: 1.0051625967025757
Epoch 350 / 500 | iteration 15 / 30 | Total Loss: 3.4606523513793945 | KNN Loss: 2.452345609664917 | BCE Loss: 1.0083067417144775
Epoch 350 / 500 | iteration 20 / 30 | Total Loss: 3.515366792678833 | KNN Loss: 2.4772109985351562 | BCE Loss: 1.0381557941436768
Epoch 350 / 500 | iteration 25 / 30 | Total Loss: 3.452786922454834 | KNN Loss: 2.4418509006500244 | BCE Loss: 1.01093590259552
Epoch 351 / 500 | iteration 0 / 30 | Total Loss: 3.464845895767212 | KNN Loss: 2.4311182498931885 | BCE Loss: 1.0337276458740234
Epoch 351 / 500 | iteration 5 / 30 | Total Loss: 3.428709030151367 | KNN Loss: 2.4120509624481

Epoch 360 / 500 | iteration 15 / 30 | Total Loss: 3.435586929321289 | KNN Loss: 2.4119956493377686 | BCE Loss: 1.023591160774231
Epoch 360 / 500 | iteration 20 / 30 | Total Loss: 3.452620029449463 | KNN Loss: 2.433753252029419 | BCE Loss: 1.0188666582107544
Epoch 360 / 500 | iteration 25 / 30 | Total Loss: 3.502429246902466 | KNN Loss: 2.481525421142578 | BCE Loss: 1.0209038257598877
Epoch 361 / 500 | iteration 0 / 30 | Total Loss: 3.4538774490356445 | KNN Loss: 2.4292361736297607 | BCE Loss: 1.0246411561965942
Epoch 361 / 500 | iteration 5 / 30 | Total Loss: 3.503884792327881 | KNN Loss: 2.4544107913970947 | BCE Loss: 1.0494740009307861
Epoch 361 / 500 | iteration 10 / 30 | Total Loss: 3.4786605834960938 | KNN Loss: 2.455996036529541 | BCE Loss: 1.0226645469665527
Epoch 361 / 500 | iteration 15 / 30 | Total Loss: 3.447157382965088 | KNN Loss: 2.4301888942718506 | BCE Loss: 1.0169684886932373
Epoch 361 / 500 | iteration 20 / 30 | Total Loss: 3.472005844116211 | KNN Loss: 2.438847780227

Epoch 371 / 500 | iteration 0 / 30 | Total Loss: 3.478219985961914 | KNN Loss: 2.449655294418335 | BCE Loss: 1.0285645723342896
Epoch 371 / 500 | iteration 5 / 30 | Total Loss: 3.4414737224578857 | KNN Loss: 2.4435951709747314 | BCE Loss: 0.9978784918785095
Epoch 371 / 500 | iteration 10 / 30 | Total Loss: 3.469611883163452 | KNN Loss: 2.4687013626098633 | BCE Loss: 1.0009105205535889
Epoch 371 / 500 | iteration 15 / 30 | Total Loss: 3.4593541622161865 | KNN Loss: 2.4395668506622314 | BCE Loss: 1.019787311553955
Epoch 371 / 500 | iteration 20 / 30 | Total Loss: 3.462219476699829 | KNN Loss: 2.4310526847839355 | BCE Loss: 1.0311667919158936
Epoch 371 / 500 | iteration 25 / 30 | Total Loss: 3.4667530059814453 | KNN Loss: 2.4632203578948975 | BCE Loss: 1.0035325288772583
Epoch 372 / 500 | iteration 0 / 30 | Total Loss: 3.4662671089172363 | KNN Loss: 2.4473118782043457 | BCE Loss: 1.0189552307128906
Epoch 372 / 500 | iteration 5 / 30 | Total Loss: 3.4462549686431885 | KNN Loss: 2.435953140

Epoch 381 / 500 | iteration 15 / 30 | Total Loss: 3.422741651535034 | KNN Loss: 2.4392545223236084 | BCE Loss: 0.9834871292114258
Epoch 381 / 500 | iteration 20 / 30 | Total Loss: 3.4448494911193848 | KNN Loss: 2.4333763122558594 | BCE Loss: 1.0114731788635254
Epoch 381 / 500 | iteration 25 / 30 | Total Loss: 3.441282272338867 | KNN Loss: 2.4263429641723633 | BCE Loss: 1.014939308166504
Epoch 382 / 500 | iteration 0 / 30 | Total Loss: 3.5078442096710205 | KNN Loss: 2.4622044563293457 | BCE Loss: 1.0456397533416748
Epoch 382 / 500 | iteration 5 / 30 | Total Loss: 3.449286699295044 | KNN Loss: 2.445943593978882 | BCE Loss: 1.003343105316162
Epoch 382 / 500 | iteration 10 / 30 | Total Loss: 3.4708213806152344 | KNN Loss: 2.4548580646514893 | BCE Loss: 1.0159633159637451
Epoch 382 / 500 | iteration 15 / 30 | Total Loss: 3.4956393241882324 | KNN Loss: 2.449880361557007 | BCE Loss: 1.0457589626312256
Epoch 382 / 500 | iteration 20 / 30 | Total Loss: 3.450813055038452 | KNN Loss: 2.4651632308

Epoch 392 / 500 | iteration 0 / 30 | Total Loss: 3.4959592819213867 | KNN Loss: 2.4784324169158936 | BCE Loss: 1.0175267457962036
Epoch 392 / 500 | iteration 5 / 30 | Total Loss: 3.425710678100586 | KNN Loss: 2.4485795497894287 | BCE Loss: 0.9771310091018677
Epoch 392 / 500 | iteration 10 / 30 | Total Loss: 3.4792866706848145 | KNN Loss: 2.460533618927002 | BCE Loss: 1.018752932548523
Epoch 392 / 500 | iteration 15 / 30 | Total Loss: 3.441218852996826 | KNN Loss: 2.4427578449249268 | BCE Loss: 0.9984608888626099
Epoch 392 / 500 | iteration 20 / 30 | Total Loss: 3.4834706783294678 | KNN Loss: 2.467258930206299 | BCE Loss: 1.016211748123169
Epoch 392 / 500 | iteration 25 / 30 | Total Loss: 3.4701642990112305 | KNN Loss: 2.4438676834106445 | BCE Loss: 1.0262964963912964
Epoch 393 / 500 | iteration 0 / 30 | Total Loss: 3.489194869995117 | KNN Loss: 2.446622610092163 | BCE Loss: 1.0425723791122437
Epoch 393 / 500 | iteration 5 / 30 | Total Loss: 3.4822959899902344 | KNN Loss: 2.419707775115

Epoch 402 / 500 | iteration 20 / 30 | Total Loss: 3.423327684402466 | KNN Loss: 2.4228551387786865 | BCE Loss: 1.0004725456237793
Epoch 402 / 500 | iteration 25 / 30 | Total Loss: 3.4870615005493164 | KNN Loss: 2.447208881378174 | BCE Loss: 1.0398527383804321
Epoch   403: reducing learning rate of group 0 to 8.1421e-06.
Epoch 403 / 500 | iteration 0 / 30 | Total Loss: 3.5150034427642822 | KNN Loss: 2.4644601345062256 | BCE Loss: 1.0505433082580566
Epoch 403 / 500 | iteration 5 / 30 | Total Loss: 3.4819064140319824 | KNN Loss: 2.436379909515381 | BCE Loss: 1.0455265045166016
Epoch 403 / 500 | iteration 10 / 30 | Total Loss: 3.466797113418579 | KNN Loss: 2.469770908355713 | BCE Loss: 0.9970262050628662
Epoch 403 / 500 | iteration 15 / 30 | Total Loss: 3.436511278152466 | KNN Loss: 2.447770833969116 | BCE Loss: 0.9887404441833496
Epoch 403 / 500 | iteration 20 / 30 | Total Loss: 3.4776298999786377 | KNN Loss: 2.449188232421875 | BCE Loss: 1.0284416675567627
Epoch 403 / 500 | iteration 25 

Epoch 413 / 500 | iteration 5 / 30 | Total Loss: 3.47263765335083 | KNN Loss: 2.456786870956421 | BCE Loss: 1.0158509016036987
Epoch 413 / 500 | iteration 10 / 30 | Total Loss: 3.463237762451172 | KNN Loss: 2.452312707901001 | BCE Loss: 1.0109249353408813
Epoch 413 / 500 | iteration 15 / 30 | Total Loss: 3.50600528717041 | KNN Loss: 2.476257085800171 | BCE Loss: 1.0297480821609497
Epoch 413 / 500 | iteration 20 / 30 | Total Loss: 3.4290895462036133 | KNN Loss: 2.4215667247772217 | BCE Loss: 1.0075228214263916
Epoch 413 / 500 | iteration 25 / 30 | Total Loss: 3.4963715076446533 | KNN Loss: 2.461132526397705 | BCE Loss: 1.0352389812469482
Epoch   414: reducing learning rate of group 0 to 5.6994e-06.
Epoch 414 / 500 | iteration 0 / 30 | Total Loss: 3.458638906478882 | KNN Loss: 2.4385533332824707 | BCE Loss: 1.0200855731964111
Epoch 414 / 500 | iteration 5 / 30 | Total Loss: 3.4679973125457764 | KNN Loss: 2.445624589920044 | BCE Loss: 1.0223727226257324
Epoch 414 / 500 | iteration 10 / 30

Epoch 423 / 500 | iteration 20 / 30 | Total Loss: 3.456217050552368 | KNN Loss: 2.4527432918548584 | BCE Loss: 1.0034737586975098
Epoch 423 / 500 | iteration 25 / 30 | Total Loss: 3.4608864784240723 | KNN Loss: 2.4414494037628174 | BCE Loss: 1.0194370746612549
Epoch 424 / 500 | iteration 0 / 30 | Total Loss: 3.4647011756896973 | KNN Loss: 2.4589083194732666 | BCE Loss: 1.0057928562164307
Epoch 424 / 500 | iteration 5 / 30 | Total Loss: 3.488853693008423 | KNN Loss: 2.4441637992858887 | BCE Loss: 1.0446898937225342
Epoch 424 / 500 | iteration 10 / 30 | Total Loss: 3.4467687606811523 | KNN Loss: 2.4329352378845215 | BCE Loss: 1.0138334035873413
Epoch 424 / 500 | iteration 15 / 30 | Total Loss: 3.449246406555176 | KNN Loss: 2.4085404872894287 | BCE Loss: 1.0407060384750366
Epoch 424 / 500 | iteration 20 / 30 | Total Loss: 3.483002185821533 | KNN Loss: 2.4350340366363525 | BCE Loss: 1.0479681491851807
Epoch 424 / 500 | iteration 25 / 30 | Total Loss: 3.452014207839966 | KNN Loss: 2.4437315

Epoch 434 / 500 | iteration 5 / 30 | Total Loss: 3.4881834983825684 | KNN Loss: 2.44827938079834 | BCE Loss: 1.0399041175842285
Epoch 434 / 500 | iteration 10 / 30 | Total Loss: 3.4585931301116943 | KNN Loss: 2.43752121925354 | BCE Loss: 1.0210719108581543
Epoch 434 / 500 | iteration 15 / 30 | Total Loss: 3.4448769092559814 | KNN Loss: 2.430850028991699 | BCE Loss: 1.0140268802642822
Epoch 434 / 500 | iteration 20 / 30 | Total Loss: 3.4329426288604736 | KNN Loss: 2.441105604171753 | BCE Loss: 0.9918369650840759
Epoch 434 / 500 | iteration 25 / 30 | Total Loss: 3.487098217010498 | KNN Loss: 2.4562082290649414 | BCE Loss: 1.0308899879455566
Epoch 435 / 500 | iteration 0 / 30 | Total Loss: 3.4534294605255127 | KNN Loss: 2.4441022872924805 | BCE Loss: 1.0093271732330322
Epoch 435 / 500 | iteration 5 / 30 | Total Loss: 3.4872798919677734 | KNN Loss: 2.4596517086029053 | BCE Loss: 1.0276281833648682
Epoch 435 / 500 | iteration 10 / 30 | Total Loss: 3.481964588165283 | KNN Loss: 2.48844647407

Epoch 444 / 500 | iteration 20 / 30 | Total Loss: 3.475374221801758 | KNN Loss: 2.4733214378356934 | BCE Loss: 1.002052903175354
Epoch 444 / 500 | iteration 25 / 30 | Total Loss: 3.4991493225097656 | KNN Loss: 2.4720706939697266 | BCE Loss: 1.0270785093307495
Epoch 445 / 500 | iteration 0 / 30 | Total Loss: 3.495903730392456 | KNN Loss: 2.445645332336426 | BCE Loss: 1.0502583980560303
Epoch 445 / 500 | iteration 5 / 30 | Total Loss: 3.455580234527588 | KNN Loss: 2.4231808185577393 | BCE Loss: 1.032399296760559
Epoch 445 / 500 | iteration 10 / 30 | Total Loss: 3.4261634349823 | KNN Loss: 2.4190709590911865 | BCE Loss: 1.0070924758911133
Epoch 445 / 500 | iteration 15 / 30 | Total Loss: 3.4633331298828125 | KNN Loss: 2.470064878463745 | BCE Loss: 0.9932681322097778
Epoch 445 / 500 | iteration 20 / 30 | Total Loss: 3.4619057178497314 | KNN Loss: 2.428673267364502 | BCE Loss: 1.0332324504852295
Epoch 445 / 500 | iteration 25 / 30 | Total Loss: 3.4310970306396484 | KNN Loss: 2.4259560108184

Epoch 455 / 500 | iteration 5 / 30 | Total Loss: 3.497344493865967 | KNN Loss: 2.4596242904663086 | BCE Loss: 1.0377200841903687
Epoch 455 / 500 | iteration 10 / 30 | Total Loss: 3.4561691284179688 | KNN Loss: 2.4461653232574463 | BCE Loss: 1.0100038051605225
Epoch 455 / 500 | iteration 15 / 30 | Total Loss: 3.4925224781036377 | KNN Loss: 2.4930953979492188 | BCE Loss: 0.999427080154419
Epoch 455 / 500 | iteration 20 / 30 | Total Loss: 3.4616498947143555 | KNN Loss: 2.4440789222717285 | BCE Loss: 1.0175708532333374
Epoch 455 / 500 | iteration 25 / 30 | Total Loss: 3.4563753604888916 | KNN Loss: 2.427299737930298 | BCE Loss: 1.0290756225585938
Epoch 456 / 500 | iteration 0 / 30 | Total Loss: 3.45928955078125 | KNN Loss: 2.4431700706481934 | BCE Loss: 1.016119360923767
Epoch 456 / 500 | iteration 5 / 30 | Total Loss: 3.51631498336792 | KNN Loss: 2.4851460456848145 | BCE Loss: 1.0311689376831055
Epoch 456 / 500 | iteration 10 / 30 | Total Loss: 3.4680538177490234 | KNN Loss: 2.42786955833

Epoch 465 / 500 | iteration 20 / 30 | Total Loss: 3.423213481903076 | KNN Loss: 2.429776906967163 | BCE Loss: 0.9934365153312683
Epoch 465 / 500 | iteration 25 / 30 | Total Loss: 3.462337017059326 | KNN Loss: 2.4411447048187256 | BCE Loss: 1.0211924314498901
Epoch 466 / 500 | iteration 0 / 30 | Total Loss: 3.4683661460876465 | KNN Loss: 2.444000005722046 | BCE Loss: 1.0243662595748901
Epoch 466 / 500 | iteration 5 / 30 | Total Loss: 3.4846179485321045 | KNN Loss: 2.460521697998047 | BCE Loss: 1.0240962505340576
Epoch 466 / 500 | iteration 10 / 30 | Total Loss: 3.4645864963531494 | KNN Loss: 2.4379167556762695 | BCE Loss: 1.0266697406768799
Epoch 466 / 500 | iteration 15 / 30 | Total Loss: 3.4365944862365723 | KNN Loss: 2.4429233074188232 | BCE Loss: 0.9936712980270386
Epoch 466 / 500 | iteration 20 / 30 | Total Loss: 3.4956912994384766 | KNN Loss: 2.4752302169799805 | BCE Loss: 1.0204612016677856
Epoch 466 / 500 | iteration 25 / 30 | Total Loss: 3.4406561851501465 | KNN Loss: 2.4264559

Epoch 476 / 500 | iteration 5 / 30 | Total Loss: 3.445772647857666 | KNN Loss: 2.4483275413513184 | BCE Loss: 0.9974452257156372
Epoch 476 / 500 | iteration 10 / 30 | Total Loss: 3.5029916763305664 | KNN Loss: 2.4645087718963623 | BCE Loss: 1.038482904434204
Epoch 476 / 500 | iteration 15 / 30 | Total Loss: 3.471266984939575 | KNN Loss: 2.469269275665283 | BCE Loss: 1.001997709274292
Epoch 476 / 500 | iteration 20 / 30 | Total Loss: 3.5171613693237305 | KNN Loss: 2.491731882095337 | BCE Loss: 1.025429606437683
Epoch 476 / 500 | iteration 25 / 30 | Total Loss: 3.455288887023926 | KNN Loss: 2.4493396282196045 | BCE Loss: 1.0059492588043213
Epoch 477 / 500 | iteration 0 / 30 | Total Loss: 3.4545929431915283 | KNN Loss: 2.4520113468170166 | BCE Loss: 1.0025815963745117
Epoch 477 / 500 | iteration 5 / 30 | Total Loss: 3.4612174034118652 | KNN Loss: 2.428755283355713 | BCE Loss: 1.0324621200561523
Epoch 477 / 500 | iteration 10 / 30 | Total Loss: 3.4868459701538086 | KNN Loss: 2.441899061203

Epoch 486 / 500 | iteration 20 / 30 | Total Loss: 3.461009979248047 | KNN Loss: 2.439967632293701 | BCE Loss: 1.0210423469543457
Epoch 486 / 500 | iteration 25 / 30 | Total Loss: 3.4619665145874023 | KNN Loss: 2.424189329147339 | BCE Loss: 1.037777304649353
Epoch 487 / 500 | iteration 0 / 30 | Total Loss: 3.4783449172973633 | KNN Loss: 2.447495937347412 | BCE Loss: 1.0308489799499512
Epoch 487 / 500 | iteration 5 / 30 | Total Loss: 3.4994096755981445 | KNN Loss: 2.479006290435791 | BCE Loss: 1.0204033851623535
Epoch 487 / 500 | iteration 10 / 30 | Total Loss: 3.473900079727173 | KNN Loss: 2.4573276042938232 | BCE Loss: 1.0165724754333496
Epoch 487 / 500 | iteration 15 / 30 | Total Loss: 3.4627180099487305 | KNN Loss: 2.4414563179016113 | BCE Loss: 1.0212615728378296
Epoch 487 / 500 | iteration 20 / 30 | Total Loss: 3.4927616119384766 | KNN Loss: 2.458444118499756 | BCE Loss: 1.0343174934387207
Epoch 487 / 500 | iteration 25 / 30 | Total Loss: 3.4832425117492676 | KNN Loss: 2.4796440601

Epoch 497 / 500 | iteration 5 / 30 | Total Loss: 3.503479480743408 | KNN Loss: 2.483837842941284 | BCE Loss: 1.019641637802124
Epoch 497 / 500 | iteration 10 / 30 | Total Loss: 3.4388134479522705 | KNN Loss: 2.4380390644073486 | BCE Loss: 1.0007743835449219
Epoch 497 / 500 | iteration 15 / 30 | Total Loss: 3.498948335647583 | KNN Loss: 2.4851202964782715 | BCE Loss: 1.0138280391693115
Epoch 497 / 500 | iteration 20 / 30 | Total Loss: 3.4733033180236816 | KNN Loss: 2.4299356937408447 | BCE Loss: 1.0433675050735474
Epoch 497 / 500 | iteration 25 / 30 | Total Loss: 3.4800262451171875 | KNN Loss: 2.467148780822754 | BCE Loss: 1.0128774642944336
Epoch 498 / 500 | iteration 0 / 30 | Total Loss: 3.4869842529296875 | KNN Loss: 2.4547674655914307 | BCE Loss: 1.0322169065475464
Epoch 498 / 500 | iteration 5 / 30 | Total Loss: 3.5106704235076904 | KNN Loss: 2.454357624053955 | BCE Loss: 1.0563127994537354
Epoch 498 / 500 | iteration 10 / 30 | Total Loss: 3.4620797634124756 | KNN Loss: 2.441644191

In [8]:
# Sanity check: push one dataset item (index 67) through the auto-encoder as a
# batch of size 1 and look at the raw output values and their range.
# NOTE(review): the dataset transform set before training (RemoveItemsTransform
# with p=0.5) is still active at this point, so the input is a randomly
# corrupted basket and the printed values may differ between runs.
# `iterm` receives the second value returned when return_intermidiate=True; it is
# not used in this cell.
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.8027e+00,  2.0555e+00,  3.2644e+00,  4.4957e+00,  2.6340e+00,
          4.5206e-02,  1.7304e+00,  2.7014e+00,  2.9560e+00,  1.8826e+00,
          2.2371e+00,  2.4518e+00,  1.5381e+00,  1.9312e+00,  1.9301e+00,
          1.6787e+00,  3.6269e+00,  3.7735e+00,  1.8001e+00,  1.6535e+00,
          1.3731e+00,  2.3451e+00,  2.7495e+00,  2.8467e+00,  3.1070e+00,
          1.1494e+00,  2.9233e+00,  1.3491e+00,  1.7394e+00,  5.9677e-01,
         -2.3201e-01,  8.0719e-01,  5.5910e-01,  3.2041e-01,  1.9358e+00,
          1.9745e+00,  1.1563e+00,  1.8492e+00,  1.3308e+00,  1.3106e+00,
          4.0394e-02, -6.2011e-01, -2.3070e-01,  1.8453e+00,  2.2274e+00,
          1.0378e+00,  1.3771e-01, -2.7222e-01,  1.5964e+00,  2.1125e+00,
          1.7411e+00,  5.7037e-01,  1.7317e+00,  6.2665e-01, -5.4628e-01,
          1.6759e+00,  1.3695e+00,  1.7128e+00,  8.5411e-01,  1.7499e+00,
          7.0341e-01,  5.3553e-01,  4.1325e-01,  8.8071e-01,  1.5722e+00,
          1.2705e+00, -1.7034e+00,  4.

In [9]:
# Plot the values stored in `losses`.
# NOTE(review): presumably these are the total losses recorded by the
# auto-encoder training loop above — confirm what one entry corresponds to
# (iteration vs. epoch) before labelling the x-axis.
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# From here on the dataset is used for inference only: the input transform no
# longer removes items, so inputs and targets are both the plain binary encoding
# of each basket.
dataset.transform = torchvision.transforms.Compose(
    [BinaryEncodingTransform(mapping=dataset.items_to_idx)]
)
dataset.target_transform = torchvision.transforms.Compose(
    [BinaryEncodingTransform(mapping=dataset.items_to_idx)]
)

In [11]:
# Materialise element 0 (the transformed input) of every dataset item as a CPU
# tensor; the resulting list is fed to the encoder and later stacked into one tensor.
dataset_ = [d[0].to('cpu') for d in dataset]

In [12]:
# Put the auto-encoder in eval mode on the CPU and compute the projections that
# the distance analysis and clustering cells below operate on.
# NOTE(review): presumably calculate_intermidiate returns one low-dimensional
# embedding per sample — confirm against network/auto_encoder.py.
model = model.eval().to('cpu')
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 17.40it/s]


In [13]:
# Full pairwise distance matrix between all projections.
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()
# Zero distances (the diagonal and exact duplicates) are excluded from the histogram.
positive_distances = distances_f[distances_f > 0]

# Heat-map of the distance matrix.
plt.matshow(distances)
plt.colorbar()

# Histogram of the strictly positive distances.
plt.figure()
plt.hist(positive_distances, bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [14]:
# Cluster the projections with DBSCAN. Samples that end up in no cluster are
# labelled -1 (noise); the cells below filter them out with `clusters != -1`.
clusters = DBSCAN(eps=0.01, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [None]:
perplexity = 100

# Only samples that DBSCAN assigned to a cluster are visualised (label -1 is noise).
is_clustered = clusters != -1

p = reduce_dims_and_plot(
    projections[is_clustered],
    y=clusters[is_clustered],
    title=f'perplexity: {perplexity}',
    file_name=None,
    perplexity=perplexity,
    library='Multicore-TSNE',
    perform_PCA=False,
    projected=None,
    figure_type='2d',
    show_figure=True,
    close_figure=False,
    text=None,
)

In [None]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [None]:
# Stack the per-sample input tensors into a single tensor with one row per sample;
# its second dimension is later used as the soft decision tree's input size.
tensor_dataset = torch.stack(dataset_)

In [None]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [None]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [None]:
def get_rules(tree, feature_names, class_names):
    """Flatten a fitted scikit-learn decision tree into one rule per leaf.

    Parameters
    ----------
    tree : object
        Fitted estimator exposing the low-level ``tree_`` structure
        (``feature``, ``threshold``, ``children_left``, ``children_right``,
        ``value`` and ``n_node_samples`` arrays).
    feature_names : sequence
        Name of every input feature, indexed by feature id.
    class_names : sequence or None
        Name of every class, indexed by class id. Pass ``None`` for a
        regression tree.

    Returns
    -------
    list of dict
        One entry per leaf, ordered by decreasing number of samples in the
        leaf. Each entry has the keys

        * ``'rule'``   - list of ``(feature_name, '<=' | '>', threshold)``
          conditions on the path from the root to the leaf,
        * ``'target'`` - human-readable description of the leaf prediction,
        * ``'proba'``  - percentage of the winning class in the leaf, or
          ``None`` when ``class_names`` is ``None``.
    """
    tree_ = tree.tree_
    paths = []

    def recurse(node, path):
        left = tree_.children_left[node]
        right = tree_.children_right[node]
        # A split node has two distinct children; a leaf stores the same
        # sentinel in both slots. Using this avoids the private
        # ``sklearn.tree._tree`` constants.
        if left != right:
            name = feature_names[tree_.feature[node]]
            threshold = np.round(tree_.threshold[node], 3)
            recurse(left, path + [(name, '<=', threshold)])
            recurse(right, path + [(name, '>', threshold)])
        else:
            # The last element of every path describes the leaf itself.
            paths.append(path + [(tree_.value[node], tree_.n_node_samples[node])])

    recurse(0, [])

    # Most populated leaves first.
    samples_count = [p[-1][1] for p in paths]
    paths = [paths[i] for i in np.argsort(samples_count)[::-1]]

    rules = []
    for path in paths:
        conditions = list(path[:-1])
        value, n_samples = path[-1]

        if class_names is None:
            # Regression: there is no class distribution, hence no probability.
            proba = None
            target = " then response: " + str(np.round(value[0][0], 3))
        else:
            classes = value[0]
            best = np.argmax(classes)
            proba = np.round(100.0 * classes[best] / np.sum(classes), 2)
            target = f" then class: {class_names[best]} (proba: {proba}%)"

        target += f" | based on {int(n_samples):,} samples"
        rules.append({'target': target, 'rule': conditions, 'proba': proba})

    return rules

In [None]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [None]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [None]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [None]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [None]:
# Training set for the soft decision tree: every sample that DBSCAN assigned to a
# cluster (label -1 is noise), paired with its cluster id as the target.
is_clustered = clusters != -1
tree_dataset = list(zip(tensor_dataset[is_clustered], clusters[is_clustered]))

batch_size = 512
tree_loader = torch.utils.data.DataLoader(
    tree_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Define how we prune the weights of a node

In [None]:
def prune_node(node_weights, factor=1):
    """Zero, in place, every weight strictly inside ``mean +/- factor * std``.

    The mean and standard deviation are computed over all entries of
    ``node_weights``. The (modified) input tensor is returned.
    """
    values = node_weights.cpu().detach().numpy()
    mean_ = np.mean(values)
    std_ = np.std(values)

    lower = mean_ - std_ * factor
    upper = mean_ + std_ * factor

    # Weights close to the mean are considered uninformative and are dropped.
    inside_band = (lower < node_weights) & (node_weights < upper)
    node_weights[inside_band] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Keep the ``keep`` largest-magnitude weights of a node and zero the rest.

    Parameters
    ----------
    node_weights : torch.Tensor
        Weight vector of a single node; argsort is applied to it directly,
        so it is expected to be one-dimensional. Modified in place.
    keep : int
        Number of weights to preserve. ``0`` zeroes the whole vector; a value
        of at least ``len(node_weights)`` leaves it untouched.

    Returns
    -------
    torch.Tensor
        The same tensor object that was passed in.

    Raises
    ------
    ValueError
        If ``keep`` is negative.
    """
    if keep < 0:
        raise ValueError(f"keep must be non-negative, got {keep}")

    magnitudes = np.abs(node_weights.cpu().detach().numpy())
    # Computing the number of dropped entries explicitly avoids the `[:-0]`
    # pitfall, where a slice ending at -0 is empty and nothing would be pruned.
    n_drop = max(magnitudes.size - keep, 0)
    if n_drop:
        throw_idx = np.argsort(magnitudes)[:n_drop]
        node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Prune the weight vector of every inner node of ``tree_`` in place.

    Despite its name, ``factor`` is forwarded to ``prune_node_keep`` as the
    number of weights to keep per node.
    """
    with torch.no_grad():
        # Work on a detached copy so the pruning itself is not recorded by autograd.
        pruned = tree_.inner_nodes.weight.detach().clone()
        for node in range(pruned.shape[0]):
            pruned[node, :] = prune_node_keep(pruned[node, :], factor)

        tree_.inner_nodes.weight.copy_(pruned)
        
def sparseness(x):
    """Return the mean, over the rows of ``x``, of the fraction of zero entries."""
    rows = (x[i, :] for i in range(x.shape[0]))
    # The 0-"norm" counts the non-zero entries of a row.
    per_row = [(len(row) - torch.norm(row, 0).item()) / len(row) for row in rows]
    return np.mean(per_row)

def compute_regularization_by_level(tree):
    """Sum the L2 norms of all inner-node weight vectors, weighted by depth.

    Node ``i`` (breadth-first index, root = 0) sits at level
    ``floor(log2(i + 1))`` and contributes ``2 ** -level`` times the L2 norm of
    its weights, so nodes closer to the root are penalised more strongly.
    Returns ``0`` when the tree has no inner nodes.
    """
    weights = tree.inner_nodes.weight
    total_reg = 0
    for node_idx in range(weights.shape[0]):
        level = np.floor(np.log2(node_idx + 1))
        scale = 2 ** (-level)
        total_reg += scale * torch.norm(weights[node_idx].view(-1), 2)
    return total_reg

def show_sparseness(tree):
    """Print the weight sparseness of ``tree`` overall and per layer.

    The sparseness of a node is the fraction of its weights that are exactly
    zero. Nodes are assumed to be stored in breadth-first order (root = 0), so
    node ``i`` belongs to layer ``floor(log2(i + 1))``.

    Every layer is reported, including the deepest one.

    Returns
    -------
    float
        Average sparseness over all inner nodes.
    """
    weights = tree.inner_nodes.weight

    # Per-node sparseness; the 0-"norm" counts the non-zero entries of a row.
    node_sparsity = []
    for i in range(weights.shape[0]):
        x_ = weights[i, :]
        node_sparsity.append((len(x_) - torch.norm(x_, 0).item()) / len(x_))

    avg_sp = np.mean(node_sparsity)
    print(f"Average sparseness: {avg_sp}")

    # Group nodes by layer; (i + 1).bit_length() - 1 == floor(log2(i + 1)).
    by_layer = {}
    for i, sp in enumerate(node_sparsity):
        by_layer.setdefault((i + 1).bit_length() - 1, []).append(sp)

    for layer, sps in by_layer.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configurations

In [None]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    """Train ``model`` for one pass over ``loader``.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose forward pass returns ``(output, penalty)``.
    loader : iterable
        Yields ``(data, target)`` batches.
    device : str or torch.device
        Device the batches are moved to.
    log_interval : int
        A status line is printed every ``log_interval`` batches.
    losses, accs : list
        Per-iteration loss and accuracy are appended to these lists.
    epoch : int
        Current epoch number (used for logging only).
    iteration : int
        Number of iterations performed before this call.

    Returns
    -------
    int
        Updated iteration counter.

    Notes
    -----
    Relies on the notebook-level names ``criterion``, ``optimizer``,
    ``sparsity_lamda``, ``compute_regularization_by_level`` and ``start_time``.
    """
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)
        flat_target = target.view(-1)

        # Use the model that was passed in rather than a notebook global.
        output, penalty = model.forward(data)

        # Classification loss plus the penalty returned by the model.
        loss_tree = criterion(output, flat_target) + penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        # NOTE(review): the regularisation term is only reported in the log line;
        # it is not added to the optimised loss. Confirm whether that is intended.
        regularization = sparsity_lamda * compute_regularization_by_level(model)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(flat_target.data).sum()
        batch_acc = correct.item() / data.size()[0]
        accs.append(batch_acc)

        # Print training status
        if batch_idx % log_interval == 0:
            print(
                f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d}"
                f" | Total loss: {loss.item():.3f}"
                f" | Reg loss: {regularization.item():.3f}"
                f" | Tree loss: {loss_tree.item():.3f}"
                f" | Accuracy: {batch_acc:.3f}"
                f" | {round((time.time() - start_time) / iteration, 3)} sec/iter"
            )

    return iteration


In [None]:
# Hyper-parameters for training the soft decision tree.
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
# Number of target classes = number of real clusters. The noise label (-1) is
# excluded because noise samples are filtered out of the tree dataset.
output_dim = len(set(clusters[clusters != -1]))
log_interval = 1
use_cuda = device != 'cpu'

In [None]:
# The output layer is sized by the number of cluster classes (`output_dim` from
# the configuration cell), not by the number of samples.
tree = SDT(input_dim=tensor_dataset.shape[1],
           output_dim=output_dim,
           depth=tree_depth,
           lamda=1e-3,
           use_cuda=use_cuda)
# NOTE: this rebinds `optimizer`, which previously referred to the auto-encoder's optimiser.
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [None]:
# Fresh, independent histories for the soft-decision-tree run. `losses` is
# rebound here, so the values plotted earlier in the notebook are no longer
# reachable under that name.
losses, accs, sparsity = [], [], []

In [None]:
# Train the soft decision tree for `epochs` epochs.
# `start_time` is read inside do_epoch to report the average seconds per iteration.
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    # Report and record the sparseness of the tree before this epoch's updates.
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    # `epoch % 1 == 0` is always true, i.e. the tree is pruned after every epoch;
    # the modulus is the knob for pruning less often. The value 3 is passed on by
    # prune_tree as the number of weights to keep per inner node.
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

In [None]:
# Per-iteration training accuracy of the soft decision tree.
plt.figure(figsize=(10, 5))
plt.plot(accs, label='Accuracy vs iteration')
plt.xlabel('Iteration')
plt.ylabel("Accuracy")
plt.show()

In [None]:
# Per-iteration training loss of the soft decision tree (log scale).
plt.figure()
plt.plot(losses, label='Loss vs iteration')
plt.xlabel('Iteration')
plt.ylabel("Loss")
plt.yscale("log")
plt.show()

# Distribution of all inner-node weights, with mean +/- one standard deviation
# marked by red vertical lines.
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
weights_mean = np.mean(weights)
weights_std = np.std(weights)

plt.figure()
plt.hist(weights, bins=500)
for bound in (weights_mean + weights_std, weights_mean - weights_std):
    plt.axvline(bound, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
# Draw the trained tree. visualize() also returns `root`, the node object that
# the following cells use to reach the leaves and accumulate samples.
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
# Each leaf reachable from `root` is reported as one pattern.
print(f"Number of patterns: {len(root.get_leaves())}")

In [None]:
# Strategy name passed to root.accumulate_samples in the next cell.
# NOTE(review): the meaning of 'greedy' is defined in the soft_decision_tree package — confirm there.
method = 'greedy'

In [None]:
# Start from empty leaves, then feed every training batch through the tree
# structure so the samples are accumulated in its leaves.
root.clear_leaves_samples()

# Gradients are not needed for this pass.
# NOTE(review): `data` is not moved to `device` here, while `tree` was moved to
# `device` earlier — confirm that accumulate_samples works with CPU batches.
# `batch_idx` and `target` are not used in the loop body.
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
# Names used to express the path conditions of each leaf.
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
# NOTE(review): `sum_comprehensibility` is initialised but never used below.
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    # Rebuild the leaf's path and tighten it using the samples accumulated in the
    # previous cell, then read the resulting conditions.
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    # Only the pattern header is printed here; the conditions themselves are not.
    print(f"============== Pattern {pattern_counter + 1} ==============")
    # Comprehensibility of a pattern = sum over the conditions on its path.
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
# Summary statistics over all patterns.
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")