In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 8
tree_depth = 8
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.164026260375977 | KNN Loss: 6.224818229675293 | BCE Loss: 1.9392085075378418
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.184050559997559 | KNN Loss: 6.224360942840576 | BCE Loss: 1.9596898555755615
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.180439949035645 | KNN Loss: 6.224224090576172 | BCE Loss: 1.956215739250183
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.135895729064941 | KNN Loss: 6.223940849304199 | BCE Loss: 1.9119547605514526
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.149843215942383 | KNN Loss: 6.2236647605896 | BCE Loss: 1.9261789321899414
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.154428482055664 | KNN Loss: 6.223090648651123 | BCE Loss: 1.931337833404541
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.125213623046875 | KNN Loss: 6.22247838973999 | BCE Loss: 1.9027354717254639
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.126628875732422 | KNN Loss: 6.222614765167236 | BCE Loss: 1.9040144681

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 6.929956436157227 | KNN Loss: 5.7535014152526855 | BCE Loss: 1.1764552593231201
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 6.860462665557861 | KNN Loss: 5.719170093536377 | BCE Loss: 1.141292691230774
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 6.783545970916748 | KNN Loss: 5.627161502838135 | BCE Loss: 1.1563844680786133
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.727011203765869 | KNN Loss: 5.563403606414795 | BCE Loss: 1.1636075973510742
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 6.586681842803955 | KNN Loss: 5.452396869659424 | BCE Loss: 1.1342850923538208
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 6.4945807456970215 | KNN Loss: 5.36434268951416 | BCE Loss: 1.1302380561828613
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 6.422830581665039 | KNN Loss: 5.273220062255859 | BCE Loss: 1.1496105194091797
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 6.283742427825928 | KNN Loss: 5.15598201751709 | BCE Loss:

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 3.7727081775665283 | KNN Loss: 2.703932285308838 | BCE Loss: 1.0687758922576904
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 3.7198593616485596 | KNN Loss: 2.672633171081543 | BCE Loss: 1.0472261905670166
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 3.7044451236724854 | KNN Loss: 2.654899835586548 | BCE Loss: 1.0495452880859375
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 3.72906494140625 | KNN Loss: 2.709667921066284 | BCE Loss: 1.0193970203399658
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 3.7616140842437744 | KNN Loss: 2.7004189491271973 | BCE Loss: 1.0611951351165771
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 3.775270462036133 | KNN Loss: 2.7181968688964844 | BCE Loss: 1.0570734739303589
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 3.724064826965332 | KNN Loss: 2.66469407081604 | BCE Loss: 1.059370756149292
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 3.689032793045044 | KNN Loss: 2.62568736076355 | BCE L

Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 3.6349692344665527 | KNN Loss: 2.5865023136138916 | BCE Loss: 1.0484668016433716
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 3.709860324859619 | KNN Loss: 2.6696319580078125 | BCE Loss: 1.0402283668518066
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 3.582477569580078 | KNN Loss: 2.54008150100708 | BCE Loss: 1.0423959493637085
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 3.575697898864746 | KNN Loss: 2.571176767349243 | BCE Loss: 1.004521131515503
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 3.658238410949707 | KNN Loss: 2.632403612136841 | BCE Loss: 1.0258347988128662
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 3.5832250118255615 | KNN Loss: 2.5571794509887695 | BCE Loss: 1.026045560836792
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 3.619227647781372 | KNN Loss: 2.572521686553955 | BCE Loss: 1.046705961227417
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 3.5785794258117676 | KNN Loss: 2.559694528579712 | BCE L

Epoch 43 / 500 | iteration 0 / 30 | Total Loss: 3.610761880874634 | KNN Loss: 2.56288743019104 | BCE Loss: 1.0478744506835938
Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 3.5444395542144775 | KNN Loss: 2.53094220161438 | BCE Loss: 1.0134973526000977
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 3.677459716796875 | KNN Loss: 2.616302013397217 | BCE Loss: 1.0611577033996582
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 3.6485652923583984 | KNN Loss: 2.6028311252593994 | BCE Loss: 1.0457342863082886
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 3.5360727310180664 | KNN Loss: 2.542710304260254 | BCE Loss: 0.993362307548523
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 3.5642385482788086 | KNN Loss: 2.5197808742523193 | BCE Loss: 1.0444576740264893
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 3.6286654472351074 | KNN Loss: 2.5832479000091553 | BCE Loss: 1.0454174280166626
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 3.601832866668701 | KNN Loss: 2.5612523555755615 | BCE

Epoch 53 / 500 | iteration 20 / 30 | Total Loss: 3.5895376205444336 | KNN Loss: 2.5695345401763916 | BCE Loss: 1.0200029611587524
Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 3.5572471618652344 | KNN Loss: 2.551034450531006 | BCE Loss: 1.0062127113342285
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 3.576753616333008 | KNN Loss: 2.546872615814209 | BCE Loss: 1.0298810005187988
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 3.6599910259246826 | KNN Loss: 2.6025888919830322 | BCE Loss: 1.0574021339416504
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 3.6276564598083496 | KNN Loss: 2.5744032859802246 | BCE Loss: 1.0532532930374146
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 3.587306499481201 | KNN Loss: 2.5536444187164307 | BCE Loss: 1.03366219997406
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 3.5825631618499756 | KNN Loss: 2.5526204109191895 | BCE Loss: 1.0299427509307861
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 3.598851203918457 | KNN Loss: 2.569169521331787 |

Epoch 64 / 500 | iteration 10 / 30 | Total Loss: 3.585465908050537 | KNN Loss: 2.552129030227661 | BCE Loss: 1.0333369970321655
Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 3.5608396530151367 | KNN Loss: 2.5394670963287354 | BCE Loss: 1.0213725566864014
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 3.5875821113586426 | KNN Loss: 2.570178508758545 | BCE Loss: 1.017403483390808
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 3.5298104286193848 | KNN Loss: 2.4983975887298584 | BCE Loss: 1.0314128398895264
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 3.5792436599731445 | KNN Loss: 2.533919095993042 | BCE Loss: 1.045324444770813
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 3.516679525375366 | KNN Loss: 2.5165629386901855 | BCE Loss: 1.0001165866851807
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 3.5751900672912598 | KNN Loss: 2.519615650177002 | BCE Loss: 1.0555744171142578
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 3.564438819885254 | KNN Loss: 2.549682855606079 | B

Epoch 75 / 500 | iteration 0 / 30 | Total Loss: 3.5113983154296875 | KNN Loss: 2.489837884902954 | BCE Loss: 1.021560549736023
Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 3.551206588745117 | KNN Loss: 2.510192394256592 | BCE Loss: 1.041014313697815
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 3.5672738552093506 | KNN Loss: 2.5317435264587402 | BCE Loss: 1.0355303287506104
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 3.5376336574554443 | KNN Loss: 2.539167881011963 | BCE Loss: 0.9984658360481262
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 3.5633864402770996 | KNN Loss: 2.520132541656494 | BCE Loss: 1.0432538986206055
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 3.5208120346069336 | KNN Loss: 2.514152765274048 | BCE Loss: 1.0066593885421753
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 3.5505008697509766 | KNN Loss: 2.507652997970581 | BCE Loss: 1.042847990989685
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 3.588606357574463 | KNN Loss: 2.574306011199951 | BCE L

Epoch 85 / 500 | iteration 20 / 30 | Total Loss: 3.5259575843811035 | KNN Loss: 2.514822006225586 | BCE Loss: 1.0111355781555176
Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 3.5052566528320312 | KNN Loss: 2.4733643531799316 | BCE Loss: 1.0318922996520996
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 3.5190553665161133 | KNN Loss: 2.5123367309570312 | BCE Loss: 1.0067187547683716
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 3.552382469177246 | KNN Loss: 2.5506584644317627 | BCE Loss: 1.0017238855361938
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 3.5013747215270996 | KNN Loss: 2.486071825027466 | BCE Loss: 1.0153028964996338
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 3.475728988647461 | KNN Loss: 2.478623628616333 | BCE Loss: 0.9971053004264832
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 3.5276479721069336 | KNN Loss: 2.4986956119537354 | BCE Loss: 1.0289524793624878
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 3.5383169651031494 | KNN Loss: 2.510927677154541

Epoch 96 / 500 | iteration 10 / 30 | Total Loss: 3.506221055984497 | KNN Loss: 2.4730634689331055 | BCE Loss: 1.0331575870513916
Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 3.562993288040161 | KNN Loss: 2.5314979553222656 | BCE Loss: 1.0314953327178955
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 3.5189082622528076 | KNN Loss: 2.532003164291382 | BCE Loss: 0.986905038356781
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 3.5038681030273438 | KNN Loss: 2.493849039077759 | BCE Loss: 1.010019063949585
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 3.542483329772949 | KNN Loss: 2.507343292236328 | BCE Loss: 1.035140037536621
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 3.5304887294769287 | KNN Loss: 2.466078996658325 | BCE Loss: 1.0644097328186035
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 3.5315475463867188 | KNN Loss: 2.49580454826355 | BCE Loss: 1.035742998123169
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 3.5670032501220703 | KNN Loss: 2.52927303314209 | BCE Lo

Epoch 107 / 500 | iteration 0 / 30 | Total Loss: 3.4873452186584473 | KNN Loss: 2.485896110534668 | BCE Loss: 1.0014492273330688
Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 3.567662000656128 | KNN Loss: 2.553964376449585 | BCE Loss: 1.013697624206543
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 3.5459632873535156 | KNN Loss: 2.5208258628845215 | BCE Loss: 1.0251374244689941
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 3.5103864669799805 | KNN Loss: 2.4966509342193604 | BCE Loss: 1.0137355327606201
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 3.5276284217834473 | KNN Loss: 2.4909420013427734 | BCE Loss: 1.0366865396499634
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 3.4602839946746826 | KNN Loss: 2.4719700813293457 | BCE Loss: 0.9883139729499817
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 3.5015876293182373 | KNN Loss: 2.4686832427978516 | BCE Loss: 1.0329043865203857
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 3.533907175064087 | KNN Loss: 2.530115365

Epoch 117 / 500 | iteration 20 / 30 | Total Loss: 3.581278085708618 | KNN Loss: 2.549314498901367 | BCE Loss: 1.031963586807251
Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 3.484386444091797 | KNN Loss: 2.486278772354126 | BCE Loss: 0.9981077909469604
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 3.5107364654541016 | KNN Loss: 2.481480360031128 | BCE Loss: 1.0292561054229736
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 3.498415231704712 | KNN Loss: 2.484563112258911 | BCE Loss: 1.0138521194458008
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 3.453669548034668 | KNN Loss: 2.4866321086883545 | BCE Loss: 0.9670374989509583
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 3.5156712532043457 | KNN Loss: 2.493459463119507 | BCE Loss: 1.0222117900848389
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 3.5361135005950928 | KNN Loss: 2.494555950164795 | BCE Loss: 1.0415575504302979
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 3.521599531173706 | KNN Loss: 2.47942566871643

Epoch 128 / 500 | iteration 10 / 30 | Total Loss: 3.5212621688842773 | KNN Loss: 2.4974770545959473 | BCE Loss: 1.0237849950790405
Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 3.527390956878662 | KNN Loss: 2.516160249710083 | BCE Loss: 1.0112305879592896
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 3.5220448970794678 | KNN Loss: 2.502805471420288 | BCE Loss: 1.0192394256591797
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 3.519540786743164 | KNN Loss: 2.470280408859253 | BCE Loss: 1.0492603778839111
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 3.4957807064056396 | KNN Loss: 2.495668888092041 | BCE Loss: 1.0001118183135986
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 3.5080034732818604 | KNN Loss: 2.4827423095703125 | BCE Loss: 1.0252611637115479
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 3.496786117553711 | KNN Loss: 2.4712905883789062 | BCE Loss: 1.0254954099655151
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 3.5710787773132324 | KNN Loss: 2.508934020

Epoch 139 / 500 | iteration 0 / 30 | Total Loss: 3.482316732406616 | KNN Loss: 2.4791340827941895 | BCE Loss: 1.0031826496124268
Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 3.5040557384490967 | KNN Loss: 2.473863363265991 | BCE Loss: 1.0301923751831055
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 3.5037801265716553 | KNN Loss: 2.484387159347534 | BCE Loss: 1.019392967224121
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 3.482200860977173 | KNN Loss: 2.4948575496673584 | BCE Loss: 0.9873433709144592
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 3.522592544555664 | KNN Loss: 2.492755651473999 | BCE Loss: 1.029836893081665
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 3.4807376861572266 | KNN Loss: 2.4747965335845947 | BCE Loss: 1.0059412717819214
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 3.516231060028076 | KNN Loss: 2.4755067825317383 | BCE Loss: 1.0407243967056274
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 3.5319228172302246 | KNN Loss: 2.5263152122497

Epoch 149 / 500 | iteration 20 / 30 | Total Loss: 3.4819135665893555 | KNN Loss: 2.476623058319092 | BCE Loss: 1.0052905082702637
Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 3.504093647003174 | KNN Loss: 2.474837303161621 | BCE Loss: 1.0292563438415527
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 3.522616386413574 | KNN Loss: 2.49013090133667 | BCE Loss: 1.0324856042861938
Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 3.452866554260254 | KNN Loss: 2.4645397663116455 | BCE Loss: 0.9883266687393188
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 3.4894609451293945 | KNN Loss: 2.501044511795044 | BCE Loss: 0.9884165525436401
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 3.5418457984924316 | KNN Loss: 2.4853861331939697 | BCE Loss: 1.0564595460891724
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 3.491000175476074 | KNN Loss: 2.461798906326294 | BCE Loss: 1.0292011499404907
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 3.4834799766540527 | KNN Loss: 2.474443674087

Epoch 160 / 500 | iteration 10 / 30 | Total Loss: 3.4842495918273926 | KNN Loss: 2.4631786346435547 | BCE Loss: 1.0210708379745483
Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 3.490623712539673 | KNN Loss: 2.4754111766815186 | BCE Loss: 1.0152125358581543
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 3.5053629875183105 | KNN Loss: 2.498751640319824 | BCE Loss: 1.0066113471984863
Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 3.5257866382598877 | KNN Loss: 2.5125467777252197 | BCE Loss: 1.013239860534668
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 3.520066499710083 | KNN Loss: 2.488219976425171 | BCE Loss: 1.031846523284912
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 3.4922516345977783 | KNN Loss: 2.471696138381958 | BCE Loss: 1.0205554962158203
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 3.4958648681640625 | KNN Loss: 2.4635989665985107 | BCE Loss: 1.0322657823562622
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 3.5132272243499756 | KNN Loss: 2.490487813

Epoch 170 / 500 | iteration 25 / 30 | Total Loss: 3.536038398742676 | KNN Loss: 2.486983060836792 | BCE Loss: 1.0490552186965942
Epoch 171 / 500 | iteration 0 / 30 | Total Loss: 3.4819111824035645 | KNN Loss: 2.4913947582244873 | BCE Loss: 0.9905164837837219
Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 3.5052528381347656 | KNN Loss: 2.503960371017456 | BCE Loss: 1.00129234790802
Epoch 171 / 500 | iteration 10 / 30 | Total Loss: 3.4983720779418945 | KNN Loss: 2.4889581203460693 | BCE Loss: 1.0094139575958252
Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 3.45648455619812 | KNN Loss: 2.4666526317596436 | BCE Loss: 0.9898319244384766
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 3.4699037075042725 | KNN Loss: 2.455061197280884 | BCE Loss: 1.0148425102233887
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 3.455717086791992 | KNN Loss: 2.4700863361358643 | BCE Loss: 0.9856308102607727
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 3.487621784210205 | KNN Loss: 2.4631655216217

Epoch 181 / 500 | iteration 10 / 30 | Total Loss: 3.5404510498046875 | KNN Loss: 2.5114963054656982 | BCE Loss: 1.0289547443389893
Epoch 181 / 500 | iteration 15 / 30 | Total Loss: 3.489990234375 | KNN Loss: 2.4874119758605957 | BCE Loss: 1.0025782585144043
Epoch 181 / 500 | iteration 20 / 30 | Total Loss: 3.5134968757629395 | KNN Loss: 2.477597951889038 | BCE Loss: 1.0358989238739014
Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 3.494162082672119 | KNN Loss: 2.4808385372161865 | BCE Loss: 1.0133235454559326
Epoch 182 / 500 | iteration 0 / 30 | Total Loss: 3.4722747802734375 | KNN Loss: 2.467616081237793 | BCE Loss: 1.004658579826355
Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 3.467479944229126 | KNN Loss: 2.4614691734313965 | BCE Loss: 1.0060107707977295
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 3.527548313140869 | KNN Loss: 2.48189377784729 | BCE Loss: 1.045654535293579
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 3.4685022830963135 | KNN Loss: 2.481451749801635

Epoch 192 / 500 | iteration 0 / 30 | Total Loss: 3.466597557067871 | KNN Loss: 2.4687345027923584 | BCE Loss: 0.9978631734848022
Epoch 192 / 500 | iteration 5 / 30 | Total Loss: 3.4962000846862793 | KNN Loss: 2.489098072052002 | BCE Loss: 1.007102131843567
Epoch 192 / 500 | iteration 10 / 30 | Total Loss: 3.475496530532837 | KNN Loss: 2.4516568183898926 | BCE Loss: 1.0238397121429443
Epoch 192 / 500 | iteration 15 / 30 | Total Loss: 3.4903697967529297 | KNN Loss: 2.4665439128875732 | BCE Loss: 1.0238258838653564
Epoch 192 / 500 | iteration 20 / 30 | Total Loss: 3.477313995361328 | KNN Loss: 2.4598562717437744 | BCE Loss: 1.0174577236175537
Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 3.482335090637207 | KNN Loss: 2.4766244888305664 | BCE Loss: 1.0057107210159302
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 3.510077476501465 | KNN Loss: 2.5040981769561768 | BCE Loss: 1.0059791803359985
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 3.4522902965545654 | KNN Loss: 2.46310400962

Epoch 202 / 500 | iteration 20 / 30 | Total Loss: 3.521244525909424 | KNN Loss: 2.501488208770752 | BCE Loss: 1.0197563171386719
Epoch 202 / 500 | iteration 25 / 30 | Total Loss: 3.544469118118286 | KNN Loss: 2.489837169647217 | BCE Loss: 1.0546319484710693
Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 3.4819629192352295 | KNN Loss: 2.455026149749756 | BCE Loss: 1.0269367694854736
Epoch 203 / 500 | iteration 5 / 30 | Total Loss: 3.5292177200317383 | KNN Loss: 2.4953784942626953 | BCE Loss: 1.033839225769043
Epoch 203 / 500 | iteration 10 / 30 | Total Loss: 3.414332389831543 | KNN Loss: 2.428968667984009 | BCE Loss: 0.9853636026382446
Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 3.491642713546753 | KNN Loss: 2.494187831878662 | BCE Loss: 0.997454822063446
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 3.450909376144409 | KNN Loss: 2.4501454830169678 | BCE Loss: 1.0007638931274414
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 3.4956445693969727 | KNN Loss: 2.49507212638855

Epoch 213 / 500 | iteration 5 / 30 | Total Loss: 3.4666881561279297 | KNN Loss: 2.4429380893707275 | BCE Loss: 1.0237501859664917
Epoch 213 / 500 | iteration 10 / 30 | Total Loss: 3.475761890411377 | KNN Loss: 2.4680514335632324 | BCE Loss: 1.0077104568481445
Epoch 213 / 500 | iteration 15 / 30 | Total Loss: 3.4060983657836914 | KNN Loss: 2.4355077743530273 | BCE Loss: 0.9705907106399536
Epoch 213 / 500 | iteration 20 / 30 | Total Loss: 3.4767813682556152 | KNN Loss: 2.473816394805908 | BCE Loss: 1.002964973449707
Epoch 213 / 500 | iteration 25 / 30 | Total Loss: 3.500514030456543 | KNN Loss: 2.4512157440185547 | BCE Loss: 1.0492981672286987
Epoch 214 / 500 | iteration 0 / 30 | Total Loss: 3.5339837074279785 | KNN Loss: 2.500840425491333 | BCE Loss: 1.033143162727356
Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 3.4960975646972656 | KNN Loss: 2.456702709197998 | BCE Loss: 1.0393948554992676
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 3.477818489074707 | KNN Loss: 2.45334315299

Epoch 223 / 500 | iteration 25 / 30 | Total Loss: 3.4682388305664062 | KNN Loss: 2.4789772033691406 | BCE Loss: 0.9892617464065552
Epoch 224 / 500 | iteration 0 / 30 | Total Loss: 3.4530858993530273 | KNN Loss: 2.4361634254455566 | BCE Loss: 1.0169223546981812
Epoch 224 / 500 | iteration 5 / 30 | Total Loss: 3.5178747177124023 | KNN Loss: 2.4939587116241455 | BCE Loss: 1.0239160060882568
Epoch 224 / 500 | iteration 10 / 30 | Total Loss: 3.5003507137298584 | KNN Loss: 2.469301700592041 | BCE Loss: 1.0310490131378174
Epoch 224 / 500 | iteration 15 / 30 | Total Loss: 3.481722354888916 | KNN Loss: 2.4644556045532227 | BCE Loss: 1.0172666311264038
Epoch 224 / 500 | iteration 20 / 30 | Total Loss: 3.4827194213867188 | KNN Loss: 2.471250295639038 | BCE Loss: 1.0114692449569702
Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 3.510706663131714 | KNN Loss: 2.4650845527648926 | BCE Loss: 1.0456221103668213
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 3.5160083770751953 | KNN Loss: 2.4760878

Epoch 234 / 500 | iteration 15 / 30 | Total Loss: 3.487150192260742 | KNN Loss: 2.457684278488159 | BCE Loss: 1.0294660329818726
Epoch 234 / 500 | iteration 20 / 30 | Total Loss: 3.516845703125 | KNN Loss: 2.489135265350342 | BCE Loss: 1.0277105569839478
Epoch 234 / 500 | iteration 25 / 30 | Total Loss: 3.480289936065674 | KNN Loss: 2.4600515365600586 | BCE Loss: 1.0202382802963257
Epoch 235 / 500 | iteration 0 / 30 | Total Loss: 3.5107548236846924 | KNN Loss: 2.4997782707214355 | BCE Loss: 1.0109765529632568
Epoch 235 / 500 | iteration 5 / 30 | Total Loss: 3.4701790809631348 | KNN Loss: 2.4584739208221436 | BCE Loss: 1.0117051601409912
Epoch 235 / 500 | iteration 10 / 30 | Total Loss: 3.49613094329834 | KNN Loss: 2.455707550048828 | BCE Loss: 1.0404233932495117
Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 3.466824769973755 | KNN Loss: 2.4414374828338623 | BCE Loss: 1.0253872871398926
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 3.4956817626953125 | KNN Loss: 2.48965430259704

Epoch 245 / 500 | iteration 0 / 30 | Total Loss: 3.43994140625 | KNN Loss: 2.448589324951172 | BCE Loss: 0.9913519620895386
Epoch 245 / 500 | iteration 5 / 30 | Total Loss: 3.4686684608459473 | KNN Loss: 2.433504581451416 | BCE Loss: 1.0351638793945312
Epoch 245 / 500 | iteration 10 / 30 | Total Loss: 3.4746437072753906 | KNN Loss: 2.466587781906128 | BCE Loss: 1.0080558061599731
Epoch 245 / 500 | iteration 15 / 30 | Total Loss: 3.526972770690918 | KNN Loss: 2.4982352256774902 | BCE Loss: 1.0287376642227173
Epoch 245 / 500 | iteration 20 / 30 | Total Loss: 3.49519419670105 | KNN Loss: 2.4866888523101807 | BCE Loss: 1.0085053443908691
Epoch 245 / 500 | iteration 25 / 30 | Total Loss: 3.454439163208008 | KNN Loss: 2.4561712741851807 | BCE Loss: 0.9982680082321167
Epoch 246 / 500 | iteration 0 / 30 | Total Loss: 3.4670865535736084 | KNN Loss: 2.4497711658477783 | BCE Loss: 1.01731538772583
Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 3.469473361968994 | KNN Loss: 2.448495388031006 | B

Epoch 255 / 500 | iteration 15 / 30 | Total Loss: 3.500812530517578 | KNN Loss: 2.4851438999176025 | BCE Loss: 1.015668511390686
Epoch 255 / 500 | iteration 20 / 30 | Total Loss: 3.4908697605133057 | KNN Loss: 2.4316227436065674 | BCE Loss: 1.0592470169067383
Epoch 255 / 500 | iteration 25 / 30 | Total Loss: 3.5236706733703613 | KNN Loss: 2.494957447052002 | BCE Loss: 1.028713345527649
Epoch 256 / 500 | iteration 0 / 30 | Total Loss: 3.4648239612579346 | KNN Loss: 2.446715831756592 | BCE Loss: 1.0181081295013428
Epoch 256 / 500 | iteration 5 / 30 | Total Loss: 3.4858226776123047 | KNN Loss: 2.46500825881958 | BCE Loss: 1.0208144187927246
Epoch 256 / 500 | iteration 10 / 30 | Total Loss: 3.4975953102111816 | KNN Loss: 2.4745688438415527 | BCE Loss: 1.023026466369629
Epoch 256 / 500 | iteration 15 / 30 | Total Loss: 3.493410110473633 | KNN Loss: 2.4550998210906982 | BCE Loss: 1.0383102893829346
Epoch 256 / 500 | iteration 20 / 30 | Total Loss: 3.4657177925109863 | KNN Loss: 2.45739459991

Epoch 266 / 500 | iteration 5 / 30 | Total Loss: 3.4793930053710938 | KNN Loss: 2.4455349445343018 | BCE Loss: 1.0338579416275024
Epoch 266 / 500 | iteration 10 / 30 | Total Loss: 3.494457960128784 | KNN Loss: 2.453348159790039 | BCE Loss: 1.0411098003387451
Epoch 266 / 500 | iteration 15 / 30 | Total Loss: 3.4696717262268066 | KNN Loss: 2.4808340072631836 | BCE Loss: 0.9888378381729126
Epoch 266 / 500 | iteration 20 / 30 | Total Loss: 3.4969067573547363 | KNN Loss: 2.4618964195251465 | BCE Loss: 1.0350102186203003
Epoch 266 / 500 | iteration 25 / 30 | Total Loss: 3.4299352169036865 | KNN Loss: 2.4306607246398926 | BCE Loss: 0.9992745518684387
Epoch 267 / 500 | iteration 0 / 30 | Total Loss: 3.444842576980591 | KNN Loss: 2.4461121559143066 | BCE Loss: 0.9987304210662842
Epoch 267 / 500 | iteration 5 / 30 | Total Loss: 3.45875883102417 | KNN Loss: 2.4498467445373535 | BCE Loss: 1.0089120864868164
Epoch 267 / 500 | iteration 10 / 30 | Total Loss: 3.4739418029785156 | KNN Loss: 2.44504880

Epoch 276 / 500 | iteration 25 / 30 | Total Loss: 3.4779164791107178 | KNN Loss: 2.4540045261383057 | BCE Loss: 1.023911952972412
Epoch 277 / 500 | iteration 0 / 30 | Total Loss: 3.432441473007202 | KNN Loss: 2.4330930709838867 | BCE Loss: 0.9993483424186707
Epoch 277 / 500 | iteration 5 / 30 | Total Loss: 3.4827635288238525 | KNN Loss: 2.4495465755462646 | BCE Loss: 1.033216953277588
Epoch 277 / 500 | iteration 10 / 30 | Total Loss: 3.4427547454833984 | KNN Loss: 2.4392154216766357 | BCE Loss: 1.0035394430160522
Epoch 277 / 500 | iteration 15 / 30 | Total Loss: 3.4746475219726562 | KNN Loss: 2.4598522186279297 | BCE Loss: 1.014795184135437
Epoch 277 / 500 | iteration 20 / 30 | Total Loss: 3.4226436614990234 | KNN Loss: 2.4271819591522217 | BCE Loss: 0.9954617023468018
Epoch 277 / 500 | iteration 25 / 30 | Total Loss: 3.4650306701660156 | KNN Loss: 2.4415688514709473 | BCE Loss: 1.0234618186950684
Epoch 278 / 500 | iteration 0 / 30 | Total Loss: 3.4120399951934814 | KNN Loss: 2.4128892

Epoch 287 / 500 | iteration 10 / 30 | Total Loss: 3.484673500061035 | KNN Loss: 2.4439029693603516 | BCE Loss: 1.040770411491394
Epoch 287 / 500 | iteration 15 / 30 | Total Loss: 3.453619956970215 | KNN Loss: 2.4371225833892822 | BCE Loss: 1.016497254371643
Epoch 287 / 500 | iteration 20 / 30 | Total Loss: 3.4600253105163574 | KNN Loss: 2.4508986473083496 | BCE Loss: 1.0091267824172974
Epoch 287 / 500 | iteration 25 / 30 | Total Loss: 3.474038600921631 | KNN Loss: 2.4546124935150146 | BCE Loss: 1.0194261074066162
Epoch 288 / 500 | iteration 0 / 30 | Total Loss: 3.4919047355651855 | KNN Loss: 2.4647786617279053 | BCE Loss: 1.0271261930465698
Epoch 288 / 500 | iteration 5 / 30 | Total Loss: 3.4771616458892822 | KNN Loss: 2.4328725337982178 | BCE Loss: 1.0442891120910645
Epoch 288 / 500 | iteration 10 / 30 | Total Loss: 3.4543333053588867 | KNN Loss: 2.454338788986206 | BCE Loss: 0.9999945759773254
Epoch 288 / 500 | iteration 15 / 30 | Total Loss: 3.460477113723755 | KNN Loss: 2.437691211

Epoch 298 / 500 | iteration 0 / 30 | Total Loss: 3.4963135719299316 | KNN Loss: 2.467194080352783 | BCE Loss: 1.0291194915771484
Epoch 298 / 500 | iteration 5 / 30 | Total Loss: 3.4381637573242188 | KNN Loss: 2.4402761459350586 | BCE Loss: 0.9978876113891602
Epoch 298 / 500 | iteration 10 / 30 | Total Loss: 3.4735970497131348 | KNN Loss: 2.441227674484253 | BCE Loss: 1.0323693752288818
Epoch 298 / 500 | iteration 15 / 30 | Total Loss: 3.425041913986206 | KNN Loss: 2.422712802886963 | BCE Loss: 1.0023291110992432
Epoch 298 / 500 | iteration 20 / 30 | Total Loss: 3.456709861755371 | KNN Loss: 2.4374852180480957 | BCE Loss: 1.0192246437072754
Epoch 298 / 500 | iteration 25 / 30 | Total Loss: 3.4864721298217773 | KNN Loss: 2.4347658157348633 | BCE Loss: 1.0517061948776245
Epoch 299 / 500 | iteration 0 / 30 | Total Loss: 3.5064194202423096 | KNN Loss: 2.476707696914673 | BCE Loss: 1.0297117233276367
Epoch 299 / 500 | iteration 5 / 30 | Total Loss: 3.46610689163208 | KNN Loss: 2.446050643920

Epoch 308 / 500 | iteration 15 / 30 | Total Loss: 3.44573974609375 | KNN Loss: 2.4388303756713867 | BCE Loss: 1.0069092512130737
Epoch 308 / 500 | iteration 20 / 30 | Total Loss: 3.495628595352173 | KNN Loss: 2.4559507369995117 | BCE Loss: 1.0396778583526611
Epoch 308 / 500 | iteration 25 / 30 | Total Loss: 3.474010467529297 | KNN Loss: 2.47218656539917 | BCE Loss: 1.0018237829208374
Epoch 309 / 500 | iteration 0 / 30 | Total Loss: 3.4567947387695312 | KNN Loss: 2.4463653564453125 | BCE Loss: 1.0104292631149292
Epoch 309 / 500 | iteration 5 / 30 | Total Loss: 3.453601360321045 | KNN Loss: 2.44804310798645 | BCE Loss: 1.0055583715438843
Epoch 309 / 500 | iteration 10 / 30 | Total Loss: 3.480644941329956 | KNN Loss: 2.460764169692993 | BCE Loss: 1.019880771636963
Epoch 309 / 500 | iteration 15 / 30 | Total Loss: 3.502180576324463 | KNN Loss: 2.4563088417053223 | BCE Loss: 1.0458717346191406
Epoch 309 / 500 | iteration 20 / 30 | Total Loss: 3.4481253623962402 | KNN Loss: 2.432525157928467

Epoch 319 / 500 | iteration 0 / 30 | Total Loss: 3.4951343536376953 | KNN Loss: 2.4988577365875244 | BCE Loss: 0.9962764978408813
Epoch 319 / 500 | iteration 5 / 30 | Total Loss: 3.463897228240967 | KNN Loss: 2.4661757946014404 | BCE Loss: 0.9977214336395264
Epoch 319 / 500 | iteration 10 / 30 | Total Loss: 3.4419898986816406 | KNN Loss: 2.4483163356781006 | BCE Loss: 0.9936736822128296
Epoch 319 / 500 | iteration 15 / 30 | Total Loss: 3.455347776412964 | KNN Loss: 2.4182567596435547 | BCE Loss: 1.0370910167694092
Epoch 319 / 500 | iteration 20 / 30 | Total Loss: 3.4829821586608887 | KNN Loss: 2.469203233718872 | BCE Loss: 1.0137790441513062
Epoch 319 / 500 | iteration 25 / 30 | Total Loss: 3.4745564460754395 | KNN Loss: 2.4561941623687744 | BCE Loss: 1.0183621644973755
Epoch 320 / 500 | iteration 0 / 30 | Total Loss: 3.450375556945801 | KNN Loss: 2.4403364658355713 | BCE Loss: 1.01003897190094
Epoch 320 / 500 | iteration 5 / 30 | Total Loss: 3.459965229034424 | KNN Loss: 2.44052958488

Epoch 329 / 500 | iteration 15 / 30 | Total Loss: 3.487696409225464 | KNN Loss: 2.4417428970336914 | BCE Loss: 1.0459535121917725
Epoch 329 / 500 | iteration 20 / 30 | Total Loss: 3.5294198989868164 | KNN Loss: 2.4545137882232666 | BCE Loss: 1.0749059915542603
Epoch 329 / 500 | iteration 25 / 30 | Total Loss: 3.4387598037719727 | KNN Loss: 2.4374871253967285 | BCE Loss: 1.0012726783752441
Epoch 330 / 500 | iteration 0 / 30 | Total Loss: 3.433992862701416 | KNN Loss: 2.43955135345459 | BCE Loss: 0.9944414496421814
Epoch 330 / 500 | iteration 5 / 30 | Total Loss: 3.461209297180176 | KNN Loss: 2.4528281688690186 | BCE Loss: 1.0083811283111572
Epoch 330 / 500 | iteration 10 / 30 | Total Loss: 3.4461655616760254 | KNN Loss: 2.429147720336914 | BCE Loss: 1.0170178413391113
Epoch 330 / 500 | iteration 15 / 30 | Total Loss: 3.4675862789154053 | KNN Loss: 2.4561471939086914 | BCE Loss: 1.0114390850067139
Epoch 330 / 500 | iteration 20 / 30 | Total Loss: 3.4564170837402344 | KNN Loss: 2.43917870

Epoch 340 / 500 | iteration 0 / 30 | Total Loss: 3.4628241062164307 | KNN Loss: 2.450666666030884 | BCE Loss: 1.0121574401855469
Epoch 340 / 500 | iteration 5 / 30 | Total Loss: 3.4313652515411377 | KNN Loss: 2.440866231918335 | BCE Loss: 0.9904990196228027
Epoch 340 / 500 | iteration 10 / 30 | Total Loss: 3.440450668334961 | KNN Loss: 2.4465572834014893 | BCE Loss: 0.9938932657241821
Epoch 340 / 500 | iteration 15 / 30 | Total Loss: 3.4485034942626953 | KNN Loss: 2.4247398376464844 | BCE Loss: 1.0237637758255005
Epoch 340 / 500 | iteration 20 / 30 | Total Loss: 3.4983177185058594 | KNN Loss: 2.480395793914795 | BCE Loss: 1.017922043800354
Epoch 340 / 500 | iteration 25 / 30 | Total Loss: 3.4470086097717285 | KNN Loss: 2.4370381832122803 | BCE Loss: 1.0099705457687378
Epoch 341 / 500 | iteration 0 / 30 | Total Loss: 3.430767774581909 | KNN Loss: 2.416963815689087 | BCE Loss: 1.0138039588928223
Epoch 341 / 500 | iteration 5 / 30 | Total Loss: 3.433671712875366 | KNN Loss: 2.429371595382

Epoch 350 / 500 | iteration 15 / 30 | Total Loss: 3.451826810836792 | KNN Loss: 2.4359049797058105 | BCE Loss: 1.0159218311309814
Epoch 350 / 500 | iteration 20 / 30 | Total Loss: 3.4243640899658203 | KNN Loss: 2.4164812564849854 | BCE Loss: 1.007882833480835
Epoch 350 / 500 | iteration 25 / 30 | Total Loss: 3.4216995239257812 | KNN Loss: 2.421969175338745 | BCE Loss: 0.9997302889823914
Epoch 351 / 500 | iteration 0 / 30 | Total Loss: 3.451133966445923 | KNN Loss: 2.454793691635132 | BCE Loss: 0.996340274810791
Epoch 351 / 500 | iteration 5 / 30 | Total Loss: 3.4513702392578125 | KNN Loss: 2.4259397983551025 | BCE Loss: 1.0254305601119995
Epoch 351 / 500 | iteration 10 / 30 | Total Loss: 3.4961676597595215 | KNN Loss: 2.4557600021362305 | BCE Loss: 1.040407657623291
Epoch 351 / 500 | iteration 15 / 30 | Total Loss: 3.4682564735412598 | KNN Loss: 2.4321274757385254 | BCE Loss: 1.0361289978027344
Epoch 351 / 500 | iteration 20 / 30 | Total Loss: 3.4320778846740723 | KNN Loss: 2.440474987

Epoch 361 / 500 | iteration 0 / 30 | Total Loss: 3.4307641983032227 | KNN Loss: 2.4071671962738037 | BCE Loss: 1.0235971212387085
Epoch 361 / 500 | iteration 5 / 30 | Total Loss: 3.4367942810058594 | KNN Loss: 2.428792953491211 | BCE Loss: 1.0080012083053589
Epoch 361 / 500 | iteration 10 / 30 | Total Loss: 3.4668169021606445 | KNN Loss: 2.4456634521484375 | BCE Loss: 1.0211533308029175
Epoch 361 / 500 | iteration 15 / 30 | Total Loss: 3.473691940307617 | KNN Loss: 2.4372177124023438 | BCE Loss: 1.0364742279052734
Epoch 361 / 500 | iteration 20 / 30 | Total Loss: 3.4736852645874023 | KNN Loss: 2.481696128845215 | BCE Loss: 0.991989254951477
Epoch 361 / 500 | iteration 25 / 30 | Total Loss: 3.4969851970672607 | KNN Loss: 2.4619152545928955 | BCE Loss: 1.0350699424743652
Epoch 362 / 500 | iteration 0 / 30 | Total Loss: 3.4558610916137695 | KNN Loss: 2.4420931339263916 | BCE Loss: 1.013767957687378
Epoch 362 / 500 | iteration 5 / 30 | Total Loss: 3.4741268157958984 | KNN Loss: 2.482291698

Epoch 371 / 500 | iteration 20 / 30 | Total Loss: 3.45676851272583 | KNN Loss: 2.422055721282959 | BCE Loss: 1.0347126722335815
Epoch 371 / 500 | iteration 25 / 30 | Total Loss: 3.452685832977295 | KNN Loss: 2.4204068183898926 | BCE Loss: 1.0322790145874023
Epoch   372: reducing learning rate of group 0 to 3.3911e-05.
Epoch 372 / 500 | iteration 0 / 30 | Total Loss: 3.456789493560791 | KNN Loss: 2.434203863143921 | BCE Loss: 1.0225855112075806
Epoch 372 / 500 | iteration 5 / 30 | Total Loss: 3.464006185531616 | KNN Loss: 2.4412076473236084 | BCE Loss: 1.0227985382080078
Epoch 372 / 500 | iteration 10 / 30 | Total Loss: 3.4395852088928223 | KNN Loss: 2.4570276737213135 | BCE Loss: 0.9825576543807983
Epoch 372 / 500 | iteration 15 / 30 | Total Loss: 3.465245246887207 | KNN Loss: 2.460357189178467 | BCE Loss: 1.0048880577087402
Epoch 372 / 500 | iteration 20 / 30 | Total Loss: 3.4620110988616943 | KNN Loss: 2.4449644088745117 | BCE Loss: 1.0170466899871826
Epoch 372 / 500 | iteration 25 /

Epoch 382 / 500 | iteration 5 / 30 | Total Loss: 3.4356589317321777 | KNN Loss: 2.4302480220794678 | BCE Loss: 1.0054107904434204
Epoch 382 / 500 | iteration 10 / 30 | Total Loss: 3.4638590812683105 | KNN Loss: 2.4563090801239014 | BCE Loss: 1.0075498819351196
Epoch 382 / 500 | iteration 15 / 30 | Total Loss: 3.4861347675323486 | KNN Loss: 2.4605329036712646 | BCE Loss: 1.025601863861084
Epoch 382 / 500 | iteration 20 / 30 | Total Loss: 3.49409556388855 | KNN Loss: 2.462716579437256 | BCE Loss: 1.031378984451294
Epoch 382 / 500 | iteration 25 / 30 | Total Loss: 3.4988198280334473 | KNN Loss: 2.4902713298797607 | BCE Loss: 1.008548378944397
Epoch   383: reducing learning rate of group 0 to 2.3738e-05.
Epoch 383 / 500 | iteration 0 / 30 | Total Loss: 3.426450252532959 | KNN Loss: 2.4382364749908447 | BCE Loss: 0.9882137775421143
Epoch 383 / 500 | iteration 5 / 30 | Total Loss: 3.465665578842163 | KNN Loss: 2.4518051147460938 | BCE Loss: 1.0138604640960693
Epoch 383 / 500 | iteration 10 /

Epoch 392 / 500 | iteration 20 / 30 | Total Loss: 3.470994472503662 | KNN Loss: 2.4397103786468506 | BCE Loss: 1.0312840938568115
Epoch 392 / 500 | iteration 25 / 30 | Total Loss: 3.501615047454834 | KNN Loss: 2.4897356033325195 | BCE Loss: 1.0118794441223145
Epoch 393 / 500 | iteration 0 / 30 | Total Loss: 3.460090398788452 | KNN Loss: 2.4271278381347656 | BCE Loss: 1.0329625606536865
Epoch 393 / 500 | iteration 5 / 30 | Total Loss: 3.4711830615997314 | KNN Loss: 2.4630308151245117 | BCE Loss: 1.0081522464752197
Epoch 393 / 500 | iteration 10 / 30 | Total Loss: 3.429441452026367 | KNN Loss: 2.425044536590576 | BCE Loss: 1.004396915435791
Epoch 393 / 500 | iteration 15 / 30 | Total Loss: 3.465471029281616 | KNN Loss: 2.451066732406616 | BCE Loss: 1.014404296875
Epoch 393 / 500 | iteration 20 / 30 | Total Loss: 3.4621341228485107 | KNN Loss: 2.4483399391174316 | BCE Loss: 1.013794183731079
Epoch 393 / 500 | iteration 25 / 30 | Total Loss: 3.508707046508789 | KNN Loss: 2.461805820465088 

Epoch 403 / 500 | iteration 10 / 30 | Total Loss: 3.4628725051879883 | KNN Loss: 2.4468088150024414 | BCE Loss: 1.0160636901855469
Epoch 403 / 500 | iteration 15 / 30 | Total Loss: 3.476689100265503 | KNN Loss: 2.4794442653656006 | BCE Loss: 0.9972448348999023
Epoch 403 / 500 | iteration 20 / 30 | Total Loss: 3.473046064376831 | KNN Loss: 2.4464361667633057 | BCE Loss: 1.0266098976135254
Epoch 403 / 500 | iteration 25 / 30 | Total Loss: 3.480100631713867 | KNN Loss: 2.466578483581543 | BCE Loss: 1.0135222673416138
Epoch 404 / 500 | iteration 0 / 30 | Total Loss: 3.4549713134765625 | KNN Loss: 2.442111015319824 | BCE Loss: 1.0128604173660278
Epoch 404 / 500 | iteration 5 / 30 | Total Loss: 3.4482712745666504 | KNN Loss: 2.4365503787994385 | BCE Loss: 1.011720895767212
Epoch 404 / 500 | iteration 10 / 30 | Total Loss: 3.456695079803467 | KNN Loss: 2.4581100940704346 | BCE Loss: 0.9985849261283875
Epoch 404 / 500 | iteration 15 / 30 | Total Loss: 3.47969388961792 | KNN Loss: 2.45687675476

Epoch 413 / 500 | iteration 25 / 30 | Total Loss: 3.4526660442352295 | KNN Loss: 2.4355456829071045 | BCE Loss: 1.017120361328125
Epoch 414 / 500 | iteration 0 / 30 | Total Loss: 3.4548628330230713 | KNN Loss: 2.4479410648345947 | BCE Loss: 1.0069217681884766
Epoch 414 / 500 | iteration 5 / 30 | Total Loss: 3.447676181793213 | KNN Loss: 2.432644844055176 | BCE Loss: 1.0150312185287476
Epoch 414 / 500 | iteration 10 / 30 | Total Loss: 3.5004615783691406 | KNN Loss: 2.454744815826416 | BCE Loss: 1.0457167625427246
Epoch 414 / 500 | iteration 15 / 30 | Total Loss: 3.43498158454895 | KNN Loss: 2.4332122802734375 | BCE Loss: 1.0017693042755127
Epoch 414 / 500 | iteration 20 / 30 | Total Loss: 3.4490609169006348 | KNN Loss: 2.4323794841766357 | BCE Loss: 1.0166815519332886
Epoch 414 / 500 | iteration 25 / 30 | Total Loss: 3.443997383117676 | KNN Loss: 2.431135654449463 | BCE Loss: 1.0128618478775024
Epoch 415 / 500 | iteration 0 / 30 | Total Loss: 3.4691872596740723 | KNN Loss: 2.43173313140

Epoch 424 / 500 | iteration 10 / 30 | Total Loss: 3.4722704887390137 | KNN Loss: 2.4437897205352783 | BCE Loss: 1.0284807682037354
Epoch 424 / 500 | iteration 15 / 30 | Total Loss: 3.4408159255981445 | KNN Loss: 2.4402902126312256 | BCE Loss: 1.0005258321762085
Epoch 424 / 500 | iteration 20 / 30 | Total Loss: 3.465956211090088 | KNN Loss: 2.4784810543060303 | BCE Loss: 0.9874750375747681
Epoch 424 / 500 | iteration 25 / 30 | Total Loss: 3.461963653564453 | KNN Loss: 2.449181079864502 | BCE Loss: 1.0127826929092407
Epoch 425 / 500 | iteration 0 / 30 | Total Loss: 3.458625555038452 | KNN Loss: 2.4408457279205322 | BCE Loss: 1.01777982711792
Epoch 425 / 500 | iteration 5 / 30 | Total Loss: 3.4682774543762207 | KNN Loss: 2.432546854019165 | BCE Loss: 1.0357304811477661
Epoch 425 / 500 | iteration 10 / 30 | Total Loss: 3.471376895904541 | KNN Loss: 2.4650843143463135 | BCE Loss: 1.006292700767517
Epoch 425 / 500 | iteration 15 / 30 | Total Loss: 3.4297189712524414 | KNN Loss: 2.42405939102

Epoch 434 / 500 | iteration 25 / 30 | Total Loss: 3.429617404937744 | KNN Loss: 2.42323637008667 | BCE Loss: 1.0063809156417847
Epoch 435 / 500 | iteration 0 / 30 | Total Loss: 3.4497780799865723 | KNN Loss: 2.466322183609009 | BCE Loss: 0.9834557771682739
Epoch 435 / 500 | iteration 5 / 30 | Total Loss: 3.4866178035736084 | KNN Loss: 2.45715594291687 | BCE Loss: 1.0294618606567383
Epoch 435 / 500 | iteration 10 / 30 | Total Loss: 3.478959083557129 | KNN Loss: 2.4637222290039062 | BCE Loss: 1.0152369737625122
Epoch 435 / 500 | iteration 15 / 30 | Total Loss: 3.526686191558838 | KNN Loss: 2.4768800735473633 | BCE Loss: 1.0498061180114746
Epoch 435 / 500 | iteration 20 / 30 | Total Loss: 3.4363651275634766 | KNN Loss: 2.4207119941711426 | BCE Loss: 1.015653133392334
Epoch 435 / 500 | iteration 25 / 30 | Total Loss: 3.460972547531128 | KNN Loss: 2.4455013275146484 | BCE Loss: 1.0154712200164795
Epoch 436 / 500 | iteration 0 / 30 | Total Loss: 3.5019798278808594 | KNN Loss: 2.4697473049163

Epoch 445 / 500 | iteration 10 / 30 | Total Loss: 3.4624907970428467 | KNN Loss: 2.463801145553589 | BCE Loss: 0.9986896514892578
Epoch 445 / 500 | iteration 15 / 30 | Total Loss: 3.420482635498047 | KNN Loss: 2.4239392280578613 | BCE Loss: 0.996543288230896
Epoch 445 / 500 | iteration 20 / 30 | Total Loss: 3.4632930755615234 | KNN Loss: 2.435424566268921 | BCE Loss: 1.027868390083313
Epoch 445 / 500 | iteration 25 / 30 | Total Loss: 3.4622347354888916 | KNN Loss: 2.436458110809326 | BCE Loss: 1.0257766246795654
Epoch 446 / 500 | iteration 0 / 30 | Total Loss: 3.397132396697998 | KNN Loss: 2.425917863845825 | BCE Loss: 0.9712145328521729
Epoch 446 / 500 | iteration 5 / 30 | Total Loss: 3.4580068588256836 | KNN Loss: 2.4454846382141113 | BCE Loss: 1.0125222206115723
Epoch 446 / 500 | iteration 10 / 30 | Total Loss: 3.5098018646240234 | KNN Loss: 2.4948716163635254 | BCE Loss: 1.0149301290512085
Epoch 446 / 500 | iteration 15 / 30 | Total Loss: 3.4408397674560547 | KNN Loss: 2.4229085445

Epoch 455 / 500 | iteration 25 / 30 | Total Loss: 3.4797115325927734 | KNN Loss: 2.4524600505828857 | BCE Loss: 1.0272513628005981
Epoch 456 / 500 | iteration 0 / 30 | Total Loss: 3.490490436553955 | KNN Loss: 2.4655580520629883 | BCE Loss: 1.0249323844909668
Epoch 456 / 500 | iteration 5 / 30 | Total Loss: 3.450888156890869 | KNN Loss: 2.4461400508880615 | BCE Loss: 1.0047482252120972
Epoch 456 / 500 | iteration 10 / 30 | Total Loss: 3.4608447551727295 | KNN Loss: 2.4407060146331787 | BCE Loss: 1.0201387405395508
Epoch 456 / 500 | iteration 15 / 30 | Total Loss: 3.427189350128174 | KNN Loss: 2.4095282554626465 | BCE Loss: 1.0176609754562378
Epoch 456 / 500 | iteration 20 / 30 | Total Loss: 3.440563201904297 | KNN Loss: 2.4440314769744873 | BCE Loss: 0.9965318441390991
Epoch 456 / 500 | iteration 25 / 30 | Total Loss: 3.4479925632476807 | KNN Loss: 2.4262053966522217 | BCE Loss: 1.021787166595459
Epoch 457 / 500 | iteration 0 / 30 | Total Loss: 3.444863796234131 | KNN Loss: 2.438103914

Epoch 466 / 500 | iteration 10 / 30 | Total Loss: 3.4270777702331543 | KNN Loss: 2.43412709236145 | BCE Loss: 0.9929506778717041
Epoch 466 / 500 | iteration 15 / 30 | Total Loss: 3.4769558906555176 | KNN Loss: 2.4305708408355713 | BCE Loss: 1.0463851690292358
Epoch 466 / 500 | iteration 20 / 30 | Total Loss: 3.5045700073242188 | KNN Loss: 2.4780433177948 | BCE Loss: 1.0265265703201294
Epoch 466 / 500 | iteration 25 / 30 | Total Loss: 3.434802532196045 | KNN Loss: 2.4517154693603516 | BCE Loss: 0.9830871224403381
Epoch 467 / 500 | iteration 0 / 30 | Total Loss: 3.5032684803009033 | KNN Loss: 2.4783828258514404 | BCE Loss: 1.024885654449463
Epoch 467 / 500 | iteration 5 / 30 | Total Loss: 3.4402925968170166 | KNN Loss: 2.4393630027770996 | BCE Loss: 1.000929594039917
Epoch 467 / 500 | iteration 10 / 30 | Total Loss: 3.4974653720855713 | KNN Loss: 2.4704699516296387 | BCE Loss: 1.0269954204559326
Epoch 467 / 500 | iteration 15 / 30 | Total Loss: 3.4298720359802246 | KNN Loss: 2.4396941661

Epoch 476 / 500 | iteration 25 / 30 | Total Loss: 3.4186818599700928 | KNN Loss: 2.420632839202881 | BCE Loss: 0.9980490803718567
Epoch 477 / 500 | iteration 0 / 30 | Total Loss: 3.4992356300354004 | KNN Loss: 2.478489875793457 | BCE Loss: 1.0207457542419434
Epoch 477 / 500 | iteration 5 / 30 | Total Loss: 3.456890821456909 | KNN Loss: 2.46777081489563 | BCE Loss: 0.9891200065612793
Epoch 477 / 500 | iteration 10 / 30 | Total Loss: 3.455470561981201 | KNN Loss: 2.433079719543457 | BCE Loss: 1.0223909616470337
Epoch 477 / 500 | iteration 15 / 30 | Total Loss: 3.4551830291748047 | KNN Loss: 2.440864086151123 | BCE Loss: 1.0143189430236816
Epoch 477 / 500 | iteration 20 / 30 | Total Loss: 3.474964141845703 | KNN Loss: 2.452160120010376 | BCE Loss: 1.0228039026260376
Epoch 477 / 500 | iteration 25 / 30 | Total Loss: 3.4621801376342773 | KNN Loss: 2.4839096069335938 | BCE Loss: 0.9782705307006836
Epoch 478 / 500 | iteration 0 / 30 | Total Loss: 3.4525279998779297 | KNN Loss: 2.4353661537170

Epoch 487 / 500 | iteration 10 / 30 | Total Loss: 3.4130349159240723 | KNN Loss: 2.427854061126709 | BCE Loss: 0.9851807355880737
Epoch 487 / 500 | iteration 15 / 30 | Total Loss: 3.469191312789917 | KNN Loss: 2.467212677001953 | BCE Loss: 1.0019786357879639
Epoch 487 / 500 | iteration 20 / 30 | Total Loss: 3.43693470954895 | KNN Loss: 2.429675340652466 | BCE Loss: 1.0072593688964844
Epoch 487 / 500 | iteration 25 / 30 | Total Loss: 3.4785704612731934 | KNN Loss: 2.4568064212799072 | BCE Loss: 1.0217640399932861
Epoch 488 / 500 | iteration 0 / 30 | Total Loss: 3.4128501415252686 | KNN Loss: 2.4294497966766357 | BCE Loss: 0.983400285243988
Epoch 488 / 500 | iteration 5 / 30 | Total Loss: 3.4198803901672363 | KNN Loss: 2.432467222213745 | BCE Loss: 0.987413227558136
Epoch 488 / 500 | iteration 10 / 30 | Total Loss: 3.455881118774414 | KNN Loss: 2.426189661026001 | BCE Loss: 1.029691457748413
Epoch 488 / 500 | iteration 15 / 30 | Total Loss: 3.426206588745117 | KNN Loss: 2.421798229217529

Epoch   498: reducing learning rate of group 0 to 9.5791e-07.
Epoch 498 / 500 | iteration 0 / 30 | Total Loss: 3.493215560913086 | KNN Loss: 2.460930585861206 | BCE Loss: 1.0322850942611694
Epoch 498 / 500 | iteration 5 / 30 | Total Loss: 3.4536564350128174 | KNN Loss: 2.4292213916778564 | BCE Loss: 1.024435043334961
Epoch 498 / 500 | iteration 10 / 30 | Total Loss: 3.451491594314575 | KNN Loss: 2.449463367462158 | BCE Loss: 1.002028226852417
Epoch 498 / 500 | iteration 15 / 30 | Total Loss: 3.4769890308380127 | KNN Loss: 2.4451828002929688 | BCE Loss: 1.031806230545044
Epoch 498 / 500 | iteration 20 / 30 | Total Loss: 3.489982843399048 | KNN Loss: 2.4779584407806396 | BCE Loss: 1.0120244026184082
Epoch 498 / 500 | iteration 25 / 30 | Total Loss: 3.477921724319458 | KNN Loss: 2.450702428817749 | BCE Loss: 1.027219295501709
Epoch 499 / 500 | iteration 0 / 30 | Total Loss: 3.42901873588562 | KNN Loss: 2.4322667121887207 | BCE Loss: 0.9967520833015442
Epoch 499 / 500 | iteration 5 / 30 | 

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.0974,  2.4199,  2.8636,  3.6784,  2.6929,  0.8784,  2.8699,  2.5165,
          2.5702,  2.1222,  2.4783,  2.5003,  1.1700,  2.0119,  1.4261,  1.7646,
          1.8027,  3.2493,  2.1962,  2.4891,  1.9998,  2.9664,  2.0832,  2.3126,
          2.5504,  1.9909,  2.3046,  1.7974,  1.6798,  0.1771, -0.2215,  0.4746,
          0.3704,  1.1309,  1.7533,  1.6911,  1.4606,  2.5366,  0.3809,  1.5309,
          0.8706, -0.5006,  0.1272,  2.5845,  1.7429,  0.9854, -0.0687,  0.3883,
          1.7073,  2.3993,  1.8991,  0.4401,  1.8039,  0.7181, -0.4625,  1.3088,
          0.9047,  1.4864,  1.5912,  2.0448,  0.6664,  0.2319, -0.2534,  1.8781,
          1.6133,  2.0089, -2.0149,  0.5438,  2.2521,  2.2253,  1.9007,  0.7418,
          1.6033,  2.0876,  1.7641,  1.5441,  0.2970,  0.8440,  0.4793,  1.8122,
          0.1397,  0.6163,  1.9695, -0.3058,  0.2172, -0.6535, -2.0342, -0.1630,
          0.7243, -2.3721,  0.7076, -0.0236, -0.5952, -0.6824,  0.4202,  1.4988,
         -0.9907, -0.6045,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [14]:
dataset_ = [d[0].cpu() for d in dataset]

In [15]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 91.38it/s]


In [16]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [17]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [18]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [21]:
tensor_dataset = torch.stack(dataset_)

In [22]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [23]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [24]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [25]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [26]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [27]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [28]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [29]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [30]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [31]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [32]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [33]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [34]:
losses = []
accs = []
sparsity = []

In [35]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 002 | Total loss: 9.605 | Reg loss: 0.009 | Tree loss: 9.605 | Accuracy: 0.000000 | 0.245 sec/iter
Epoch: 00 | Batch: 001 / 002 | Total loss: 9.591 | Reg loss: 0.009 | Tree loss: 9.591 | Accuracy: 0.000000 | 0.216 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 01 | Batch: 000 / 002 | Total loss: 9.589 | Reg loss: 0.002 | Tree loss: 9.589 | Accuracy: 0.000000 | 0.221 sec/iter
Epoch: 01 | Batch: 001 / 002 | Total loss: 9.578 | Reg loss: 0.002 | Tree loss: 9.578 | Accuracy: 0.000000 | 0.216 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
laye

Epoch: 17 | Batch: 000 / 002 | Total loss: 9.449 | Reg loss: 0.005 | Tree loss: 9.449 | Accuracy: 0.140625 | 0.214 sec/iter
Epoch: 17 | Batch: 001 / 002 | Total loss: 9.439 | Reg loss: 0.005 | Tree loss: 9.439 | Accuracy: 0.166292 | 0.214 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 18 | Batch: 000 / 002 | Total loss: 9.441 | Reg loss: 0.005 | Tree loss: 9.441 | Accuracy: 0.148438 | 0.215 sec/iter
Epoch: 18 | Batch: 001 / 002 | Total loss: 9.429 | Reg loss: 0.005 | Tree loss: 9.429 | Accuracy: 0.184270 | 0.214 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 19 | Batch: 000 / 002 | Total

Epoch: 34 | Batch: 000 / 002 | Total loss: 9.307 | Reg loss: 0.006 | Tree loss: 9.307 | Accuracy: 0.306641 | 0.219 sec/iter
Epoch: 34 | Batch: 001 / 002 | Total loss: 9.301 | Reg loss: 0.007 | Tree loss: 9.301 | Accuracy: 0.323596 | 0.218 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 35 | Batch: 000 / 002 | Total loss: 9.295 | Reg loss: 0.007 | Tree loss: 9.295 | Accuracy: 0.326172 | 0.219 sec/iter
Epoch: 35 | Batch: 001 / 002 | Total loss: 9.298 | Reg loss: 0.007 | Tree loss: 9.298 | Accuracy: 0.301124 | 0.219 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 36 | Batch: 000 / 002 | Total

Epoch: 51 | Batch: 001 / 002 | Total loss: 9.163 | Reg loss: 0.008 | Tree loss: 9.163 | Accuracy: 0.312360 | 0.224 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 52 | Batch: 000 / 002 | Total loss: 9.161 | Reg loss: 0.008 | Tree loss: 9.161 | Accuracy: 0.347656 | 0.224 sec/iter
Epoch: 52 | Batch: 001 / 002 | Total loss: 9.164 | Reg loss: 0.008 | Tree loss: 9.164 | Accuracy: 0.276404 | 0.224 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 53 | Batch: 000 / 002 | Total loss: 9.159 | Reg loss: 0.008 | Tree loss: 9.159 | Accuracy: 0.320312 | 0.225 sec/iter
Epoch: 53 | Batch: 001 / 002 | Total

Epoch: 68 | Batch: 001 / 002 | Total loss: 9.026 | Reg loss: 0.010 | Tree loss: 9.026 | Accuracy: 0.332584 | 0.229 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 69 | Batch: 000 / 002 | Total loss: 9.030 | Reg loss: 0.010 | Tree loss: 9.030 | Accuracy: 0.335938 | 0.23 sec/iter
Epoch: 69 | Batch: 001 / 002 | Total loss: 9.025 | Reg loss: 0.010 | Tree loss: 9.025 | Accuracy: 0.289888 | 0.229 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 70 | Batch: 000 / 002 | Total loss: 9.030 | Reg loss: 0.010 | Tree loss: 9.030 | Accuracy: 0.306641 | 0.23 sec/iter
Epoch: 70 | Batch: 001 / 002 | Total l

Epoch: 86 | Batch: 000 / 002 | Total loss: 8.889 | Reg loss: 0.011 | Tree loss: 8.889 | Accuracy: 0.316406 | 0.232 sec/iter
Epoch: 86 | Batch: 001 / 002 | Total loss: 8.884 | Reg loss: 0.011 | Tree loss: 8.884 | Accuracy: 0.312360 | 0.232 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 87 | Batch: 000 / 002 | Total loss: 8.886 | Reg loss: 0.011 | Tree loss: 8.886 | Accuracy: 0.306641 | 0.233 sec/iter
Epoch: 87 | Batch: 001 / 002 | Total loss: 8.868 | Reg loss: 0.011 | Tree loss: 8.868 | Accuracy: 0.323596 | 0.232 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
Epoch: 88 | Batch: 000 / 002 | Total

In [36]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [37]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [38]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 8.0


# Extract Rules

# Accumulate samples in the leaves

In [39]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 256


In [40]:
method = 'greedy'

In [41]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [43]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


486
147
20
27
3
16
3
46
8
1
89
7
4
1
93
5
1
Average comprehensibility: 45.640625
std comprehensibility: 1.8613434958048447
var comprehensibility: 3.464599609375
minimum comprehensibility: 42
maximum comprehensibility: 48
