In [14]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
k = 4
tree_depth = 6
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [21]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [22]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [23]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [24]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        loss = mse_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 1.9548475742340088 | KNN Loss: 6.224452495574951 | BCE Loss: 1.9548475742340088
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 1.9269031286239624 | KNN Loss: 6.223893642425537 | BCE Loss: 1.9269031286239624
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 1.930891513824463 | KNN Loss: 6.223770618438721 | BCE Loss: 1.930891513824463
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 1.908257246017456 | KNN Loss: 6.223718166351318 | BCE Loss: 1.908257246017456
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 1.9198890924453735 | KNN Loss: 6.224004745483398 | BCE Loss: 1.9198890924453735
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 1.941080093383789 | KNN Loss: 6.224161624908447 | BCE Loss: 1.941080093383789
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 1.9609971046447754 | KNN Loss: 6.224413871765137 | BCE Loss: 1.9609971046447754
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 1.9316774606704712 | KNN Loss: 6.2242889404296875 | BCE Loss: 1.93

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 1.1807327270507812 | KNN Loss: 6.2258710861206055 | BCE Loss: 1.1807327270507812
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 1.158316731452942 | KNN Loss: 6.226020336151123 | BCE Loss: 1.158316731452942
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 1.1760101318359375 | KNN Loss: 6.225729465484619 | BCE Loss: 1.1760101318359375
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 1.1324989795684814 | KNN Loss: 6.225797653198242 | BCE Loss: 1.1324989795684814
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 1.1305344104766846 | KNN Loss: 6.226112365722656 | BCE Loss: 1.1305344104766846
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 1.1277282238006592 | KNN Loss: 6.225693702697754 | BCE Loss: 1.1277282238006592
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 1.1142151355743408 | KNN Loss: 6.226020336151123 | BCE Loss: 1.1142151355743408
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 1.1002548933029175 | KNN Loss: 6.226152420043945 | B

Epoch 21 / 500 | iteration 15 / 30 | Total Loss: 1.0578217506408691 | KNN Loss: 6.225747108459473 | BCE Loss: 1.0578217506408691
Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 1.0658769607543945 | KNN Loss: 6.225818634033203 | BCE Loss: 1.0658769607543945
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 1.0609794855117798 | KNN Loss: 6.225801467895508 | BCE Loss: 1.0609794855117798
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 1.0351464748382568 | KNN Loss: 6.22580099105835 | BCE Loss: 1.0351464748382568
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 1.0949946641921997 | KNN Loss: 6.2258172035217285 | BCE Loss: 1.0949946641921997
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 1.0600221157073975 | KNN Loss: 6.225919723510742 | BCE Loss: 1.0600221157073975
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 1.069531798362732 | KNN Loss: 6.225830078125 | BCE Loss: 1.069531798362732
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 1.0701208114624023 | KNN Loss: 6.225661277770996 | BCE 

Epoch 32 / 500 | iteration 5 / 30 | Total Loss: 1.0552157163619995 | KNN Loss: 6.225578308105469 | BCE Loss: 1.0552157163619995
Epoch 32 / 500 | iteration 10 / 30 | Total Loss: 1.0483157634735107 | KNN Loss: 6.225674629211426 | BCE Loss: 1.0483157634735107
Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 1.036783218383789 | KNN Loss: 6.225553512573242 | BCE Loss: 1.036783218383789
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 1.0388524532318115 | KNN Loss: 6.225253105163574 | BCE Loss: 1.0388524532318115
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 1.0534021854400635 | KNN Loss: 6.225140571594238 | BCE Loss: 1.0534021854400635
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 1.070055603981018 | KNN Loss: 6.225834846496582 | BCE Loss: 1.070055603981018
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 1.0525712966918945 | KNN Loss: 6.225413799285889 | BCE Loss: 1.0525712966918945
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 1.0483182668685913 | KNN Loss: 6.225648880004883 | BCE 

Epoch 42 / 500 | iteration 25 / 30 | Total Loss: 1.0471668243408203 | KNN Loss: 6.225060939788818 | BCE Loss: 1.0471668243408203
Epoch 43 / 500 | iteration 0 / 30 | Total Loss: 1.0459469556808472 | KNN Loss: 6.225045204162598 | BCE Loss: 1.0459469556808472
Epoch 43 / 500 | iteration 5 / 30 | Total Loss: 1.0588915348052979 | KNN Loss: 6.225334644317627 | BCE Loss: 1.0588915348052979
Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 1.03245210647583 | KNN Loss: 6.2253241539001465 | BCE Loss: 1.03245210647583
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 1.0404412746429443 | KNN Loss: 6.225056171417236 | BCE Loss: 1.0404412746429443
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 1.077479362487793 | KNN Loss: 6.225242614746094 | BCE Loss: 1.077479362487793
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 1.0450937747955322 | KNN Loss: 6.225175380706787 | BCE Loss: 1.0450937747955322
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 1.0388567447662354 | KNN Loss: 6.225160121917725 | BCE L

Epoch 53 / 500 | iteration 15 / 30 | Total Loss: 1.0646257400512695 | KNN Loss: 6.225299835205078 | BCE Loss: 1.0646257400512695
Epoch 53 / 500 | iteration 20 / 30 | Total Loss: 1.0549118518829346 | KNN Loss: 6.224935054779053 | BCE Loss: 1.0549118518829346
Epoch 53 / 500 | iteration 25 / 30 | Total Loss: 1.0352615118026733 | KNN Loss: 6.225063323974609 | BCE Loss: 1.0352615118026733
Epoch 54 / 500 | iteration 0 / 30 | Total Loss: 1.0344198942184448 | KNN Loss: 6.225149631500244 | BCE Loss: 1.0344198942184448
Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 1.037271499633789 | KNN Loss: 6.224960803985596 | BCE Loss: 1.037271499633789
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 1.0486021041870117 | KNN Loss: 6.225105285644531 | BCE Loss: 1.0486021041870117
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 1.0494158267974854 | KNN Loss: 6.224609851837158 | BCE Loss: 1.0494158267974854
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 1.0824624300003052 | KNN Loss: 6.225467205047607 | B

Epoch 64 / 500 | iteration 5 / 30 | Total Loss: 1.0402413606643677 | KNN Loss: 6.224911689758301 | BCE Loss: 1.0402413606643677
Epoch 64 / 500 | iteration 10 / 30 | Total Loss: 1.0289549827575684 | KNN Loss: 6.2247819900512695 | BCE Loss: 1.0289549827575684
Epoch 64 / 500 | iteration 15 / 30 | Total Loss: 1.0621110200881958 | KNN Loss: 6.225223541259766 | BCE Loss: 1.0621110200881958
Epoch 64 / 500 | iteration 20 / 30 | Total Loss: 1.061340570449829 | KNN Loss: 6.224773406982422 | BCE Loss: 1.061340570449829
Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 1.0504777431488037 | KNN Loss: 6.225216865539551 | BCE Loss: 1.0504777431488037
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 1.056117057800293 | KNN Loss: 6.224765300750732 | BCE Loss: 1.056117057800293
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 1.023073434829712 | KNN Loss: 6.224705696105957 | BCE Loss: 1.023073434829712
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 1.0511085987091064 | KNN Loss: 6.225002765655518 | BCE L

Epoch 74 / 500 | iteration 25 / 30 | Total Loss: 1.07102370262146 | KNN Loss: 6.225177764892578 | BCE Loss: 1.07102370262146
Epoch 75 / 500 | iteration 0 / 30 | Total Loss: 1.053309440612793 | KNN Loss: 6.225134372711182 | BCE Loss: 1.053309440612793
Epoch 75 / 500 | iteration 5 / 30 | Total Loss: 1.0440880060195923 | KNN Loss: 6.2250823974609375 | BCE Loss: 1.0440880060195923
Epoch 75 / 500 | iteration 10 / 30 | Total Loss: 1.0464603900909424 | KNN Loss: 6.2247467041015625 | BCE Loss: 1.0464603900909424
Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 1.073265790939331 | KNN Loss: 6.224511623382568 | BCE Loss: 1.073265790939331
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 1.067850112915039 | KNN Loss: 6.224900722503662 | BCE Loss: 1.067850112915039
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 1.0611684322357178 | KNN Loss: 6.224756717681885 | BCE Loss: 1.0611684322357178
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 1.048614501953125 | KNN Loss: 6.224401473999023 | BCE Loss:

Epoch 85 / 500 | iteration 15 / 30 | Total Loss: 1.0765334367752075 | KNN Loss: 6.22444486618042 | BCE Loss: 1.0765334367752075
Epoch 85 / 500 | iteration 20 / 30 | Total Loss: 1.0580940246582031 | KNN Loss: 6.224793910980225 | BCE Loss: 1.0580940246582031
Epoch 85 / 500 | iteration 25 / 30 | Total Loss: 1.0606456995010376 | KNN Loss: 6.224514484405518 | BCE Loss: 1.0606456995010376
Epoch 86 / 500 | iteration 0 / 30 | Total Loss: 1.044396162033081 | KNN Loss: 6.224551677703857 | BCE Loss: 1.044396162033081
Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 1.033923864364624 | KNN Loss: 6.224844932556152 | BCE Loss: 1.033923864364624
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 1.0635799169540405 | KNN Loss: 6.224798202514648 | BCE Loss: 1.0635799169540405
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 1.0358927249908447 | KNN Loss: 6.224663257598877 | BCE Loss: 1.0358927249908447
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 1.0440495014190674 | KNN Loss: 6.224752902984619 | BCE 

Epoch 96 / 500 | iteration 5 / 30 | Total Loss: 1.0582730770111084 | KNN Loss: 6.224781513214111 | BCE Loss: 1.0582730770111084
Epoch 96 / 500 | iteration 10 / 30 | Total Loss: 1.0543168783187866 | KNN Loss: 6.224903106689453 | BCE Loss: 1.0543168783187866
Epoch 96 / 500 | iteration 15 / 30 | Total Loss: 1.0624308586120605 | KNN Loss: 6.224381446838379 | BCE Loss: 1.0624308586120605
Epoch 96 / 500 | iteration 20 / 30 | Total Loss: 1.0421669483184814 | KNN Loss: 6.224800109863281 | BCE Loss: 1.0421669483184814
Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 1.0481914281845093 | KNN Loss: 6.224433898925781 | BCE Loss: 1.0481914281845093
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 1.0442655086517334 | KNN Loss: 6.224995136260986 | BCE Loss: 1.0442655086517334
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 1.0551190376281738 | KNN Loss: 6.224913120269775 | BCE Loss: 1.0551190376281738
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 1.0600216388702393 | KNN Loss: 6.224800109863281 | 

Epoch 106 / 500 | iteration 25 / 30 | Total Loss: 1.0405049324035645 | KNN Loss: 6.2249298095703125 | BCE Loss: 1.0405049324035645
Epoch 107 / 500 | iteration 0 / 30 | Total Loss: 1.0596915483474731 | KNN Loss: 6.224476337432861 | BCE Loss: 1.0596915483474731
Epoch 107 / 500 | iteration 5 / 30 | Total Loss: 1.048195719718933 | KNN Loss: 6.22458553314209 | BCE Loss: 1.048195719718933
Epoch 107 / 500 | iteration 10 / 30 | Total Loss: 1.0707423686981201 | KNN Loss: 6.224493980407715 | BCE Loss: 1.0707423686981201
Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 1.0522383451461792 | KNN Loss: 6.224863529205322 | BCE Loss: 1.0522383451461792
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 1.078044056892395 | KNN Loss: 6.224877834320068 | BCE Loss: 1.078044056892395
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 1.0453369617462158 | KNN Loss: 6.224391937255859 | BCE Loss: 1.0453369617462158
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 1.0545514822006226 | KNN Loss: 6.22476625442504

Epoch 117 / 500 | iteration 10 / 30 | Total Loss: 1.0205494165420532 | KNN Loss: 6.224244594573975 | BCE Loss: 1.0205494165420532
Epoch 117 / 500 | iteration 15 / 30 | Total Loss: 1.0265007019042969 | KNN Loss: 6.224869728088379 | BCE Loss: 1.0265007019042969
Epoch 117 / 500 | iteration 20 / 30 | Total Loss: 1.0465185642242432 | KNN Loss: 6.224089622497559 | BCE Loss: 1.0465185642242432
Epoch 117 / 500 | iteration 25 / 30 | Total Loss: 1.0598138570785522 | KNN Loss: 6.22451639175415 | BCE Loss: 1.0598138570785522
Epoch 118 / 500 | iteration 0 / 30 | Total Loss: 1.026871681213379 | KNN Loss: 6.224137783050537 | BCE Loss: 1.026871681213379
Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 1.0496935844421387 | KNN Loss: 6.224656105041504 | BCE Loss: 1.0496935844421387
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 1.0394704341888428 | KNN Loss: 6.224643230438232 | BCE Loss: 1.0394704341888428
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 1.037872552871704 | KNN Loss: 6.2243289947509

Epoch 127 / 500 | iteration 25 / 30 | Total Loss: 1.0717620849609375 | KNN Loss: 6.224358081817627 | BCE Loss: 1.0717620849609375
Epoch 128 / 500 | iteration 0 / 30 | Total Loss: 1.0686118602752686 | KNN Loss: 6.22437858581543 | BCE Loss: 1.0686118602752686
Epoch 128 / 500 | iteration 5 / 30 | Total Loss: 1.0324101448059082 | KNN Loss: 6.2245192527771 | BCE Loss: 1.0324101448059082
Epoch 128 / 500 | iteration 10 / 30 | Total Loss: 1.072211742401123 | KNN Loss: 6.2246294021606445 | BCE Loss: 1.072211742401123
Epoch 128 / 500 | iteration 15 / 30 | Total Loss: 1.0662367343902588 | KNN Loss: 6.224266052246094 | BCE Loss: 1.0662367343902588
Epoch 128 / 500 | iteration 20 / 30 | Total Loss: 1.0588347911834717 | KNN Loss: 6.224305629730225 | BCE Loss: 1.0588347911834717
Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 1.0081779956817627 | KNN Loss: 6.224628448486328 | BCE Loss: 1.0081779956817627
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 1.0564218759536743 | KNN Loss: 6.22444581985473

Epoch 138 / 500 | iteration 10 / 30 | Total Loss: 1.0434868335723877 | KNN Loss: 6.224539756774902 | BCE Loss: 1.0434868335723877
Epoch 138 / 500 | iteration 15 / 30 | Total Loss: 1.0644848346710205 | KNN Loss: 6.224668502807617 | BCE Loss: 1.0644848346710205
Epoch 138 / 500 | iteration 20 / 30 | Total Loss: 1.0518426895141602 | KNN Loss: 6.224427700042725 | BCE Loss: 1.0518426895141602
Epoch 138 / 500 | iteration 25 / 30 | Total Loss: 1.0558515787124634 | KNN Loss: 6.224489212036133 | BCE Loss: 1.0558515787124634
Epoch 139 / 500 | iteration 0 / 30 | Total Loss: 1.0500545501708984 | KNN Loss: 6.2243733406066895 | BCE Loss: 1.0500545501708984
Epoch 139 / 500 | iteration 5 / 30 | Total Loss: 1.0232995748519897 | KNN Loss: 6.224390506744385 | BCE Loss: 1.0232995748519897
Epoch 139 / 500 | iteration 10 / 30 | Total Loss: 1.0811059474945068 | KNN Loss: 6.224963665008545 | BCE Loss: 1.0811059474945068
Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 1.0447630882263184 | KNN Loss: 6.22429275

Epoch 148 / 500 | iteration 25 / 30 | Total Loss: 1.05637788772583 | KNN Loss: 6.224045753479004 | BCE Loss: 1.05637788772583
Epoch 149 / 500 | iteration 0 / 30 | Total Loss: 1.0540565252304077 | KNN Loss: 6.224485397338867 | BCE Loss: 1.0540565252304077
Epoch 149 / 500 | iteration 5 / 30 | Total Loss: 1.0433368682861328 | KNN Loss: 6.224481105804443 | BCE Loss: 1.0433368682861328
Epoch 149 / 500 | iteration 10 / 30 | Total Loss: 1.062540054321289 | KNN Loss: 6.224402904510498 | BCE Loss: 1.062540054321289
Epoch 149 / 500 | iteration 15 / 30 | Total Loss: 1.0731459856033325 | KNN Loss: 6.224174976348877 | BCE Loss: 1.0731459856033325
Epoch 149 / 500 | iteration 20 / 30 | Total Loss: 1.0279265642166138 | KNN Loss: 6.224489688873291 | BCE Loss: 1.0279265642166138
Epoch 149 / 500 | iteration 25 / 30 | Total Loss: 1.0626847743988037 | KNN Loss: 6.224423885345459 | BCE Loss: 1.0626847743988037
Epoch 150 / 500 | iteration 0 / 30 | Total Loss: 1.0488964319229126 | KNN Loss: 6.224236488342285 

Epoch 159 / 500 | iteration 15 / 30 | Total Loss: 1.0526831150054932 | KNN Loss: 6.224292278289795 | BCE Loss: 1.0526831150054932
Epoch 159 / 500 | iteration 20 / 30 | Total Loss: 1.0371780395507812 | KNN Loss: 6.224307537078857 | BCE Loss: 1.0371780395507812
Epoch 159 / 500 | iteration 25 / 30 | Total Loss: 1.035780668258667 | KNN Loss: 6.224663257598877 | BCE Loss: 1.035780668258667
Epoch 160 / 500 | iteration 0 / 30 | Total Loss: 1.0621693134307861 | KNN Loss: 6.224734306335449 | BCE Loss: 1.0621693134307861
Epoch 160 / 500 | iteration 5 / 30 | Total Loss: 1.0234891176223755 | KNN Loss: 6.224315166473389 | BCE Loss: 1.0234891176223755
Epoch 160 / 500 | iteration 10 / 30 | Total Loss: 1.0591609477996826 | KNN Loss: 6.224399566650391 | BCE Loss: 1.0591609477996826
Epoch 160 / 500 | iteration 15 / 30 | Total Loss: 1.0416561365127563 | KNN Loss: 6.22465705871582 | BCE Loss: 1.0416561365127563
Epoch 160 / 500 | iteration 20 / 30 | Total Loss: 1.0567348003387451 | KNN Loss: 6.223995685577

Epoch 170 / 500 | iteration 0 / 30 | Total Loss: 1.0331776142120361 | KNN Loss: 6.2244648933410645 | BCE Loss: 1.0331776142120361
Epoch 170 / 500 | iteration 5 / 30 | Total Loss: 1.0659198760986328 | KNN Loss: 6.224669456481934 | BCE Loss: 1.0659198760986328
Epoch 170 / 500 | iteration 10 / 30 | Total Loss: 1.040419340133667 | KNN Loss: 6.224400043487549 | BCE Loss: 1.040419340133667
Epoch 170 / 500 | iteration 15 / 30 | Total Loss: 1.0286602973937988 | KNN Loss: 6.2245707511901855 | BCE Loss: 1.0286602973937988
Epoch 170 / 500 | iteration 20 / 30 | Total Loss: 1.0602948665618896 | KNN Loss: 6.224560260772705 | BCE Loss: 1.0602948665618896
Epoch 170 / 500 | iteration 25 / 30 | Total Loss: 1.049367070198059 | KNN Loss: 6.2246246337890625 | BCE Loss: 1.049367070198059
Epoch 171 / 500 | iteration 0 / 30 | Total Loss: 1.0278263092041016 | KNN Loss: 6.224634647369385 | BCE Loss: 1.0278263092041016
Epoch 171 / 500 | iteration 5 / 30 | Total Loss: 1.044264793395996 | KNN Loss: 6.2243552207946

Epoch 180 / 500 | iteration 20 / 30 | Total Loss: 1.0396220684051514 | KNN Loss: 6.224434852600098 | BCE Loss: 1.0396220684051514
Epoch 180 / 500 | iteration 25 / 30 | Total Loss: 1.0584218502044678 | KNN Loss: 6.224247932434082 | BCE Loss: 1.0584218502044678
Epoch 181 / 500 | iteration 0 / 30 | Total Loss: 1.037247896194458 | KNN Loss: 6.224360466003418 | BCE Loss: 1.037247896194458
Epoch 181 / 500 | iteration 5 / 30 | Total Loss: 1.043837070465088 | KNN Loss: 6.2248125076293945 | BCE Loss: 1.043837070465088
Epoch 181 / 500 | iteration 10 / 30 | Total Loss: 1.0840134620666504 | KNN Loss: 6.224400043487549 | BCE Loss: 1.0840134620666504
Epoch 181 / 500 | iteration 15 / 30 | Total Loss: 1.030479073524475 | KNN Loss: 6.224012851715088 | BCE Loss: 1.030479073524475
Epoch 181 / 500 | iteration 20 / 30 | Total Loss: 1.0713104009628296 | KNN Loss: 6.224464416503906 | BCE Loss: 1.0713104009628296
Epoch 181 / 500 | iteration 25 / 30 | Total Loss: 1.0437517166137695 | KNN Loss: 6.22450780868530

Epoch 191 / 500 | iteration 5 / 30 | Total Loss: 1.0423469543457031 | KNN Loss: 6.2241010665893555 | BCE Loss: 1.0423469543457031
Epoch 191 / 500 | iteration 10 / 30 | Total Loss: 1.0488966703414917 | KNN Loss: 6.224429130554199 | BCE Loss: 1.0488966703414917
Epoch 191 / 500 | iteration 15 / 30 | Total Loss: 1.029536247253418 | KNN Loss: 6.224433898925781 | BCE Loss: 1.029536247253418
Epoch 191 / 500 | iteration 20 / 30 | Total Loss: 1.0639927387237549 | KNN Loss: 6.224366664886475 | BCE Loss: 1.0639927387237549
Epoch 191 / 500 | iteration 25 / 30 | Total Loss: 1.047588586807251 | KNN Loss: 6.2245025634765625 | BCE Loss: 1.047588586807251
Epoch 192 / 500 | iteration 0 / 30 | Total Loss: 1.0649235248565674 | KNN Loss: 6.224679946899414 | BCE Loss: 1.0649235248565674
Epoch 192 / 500 | iteration 5 / 30 | Total Loss: 1.0534687042236328 | KNN Loss: 6.224210739135742 | BCE Loss: 1.0534687042236328
Epoch 192 / 500 | iteration 10 / 30 | Total Loss: 1.0691789388656616 | KNN Loss: 6.224896907806

Epoch 201 / 500 | iteration 25 / 30 | Total Loss: 1.0267778635025024 | KNN Loss: 6.224354267120361 | BCE Loss: 1.0267778635025024
Epoch 202 / 500 | iteration 0 / 30 | Total Loss: 1.0787684917449951 | KNN Loss: 6.224696636199951 | BCE Loss: 1.0787684917449951
Epoch 202 / 500 | iteration 5 / 30 | Total Loss: 1.078385829925537 | KNN Loss: 6.224146842956543 | BCE Loss: 1.078385829925537
Epoch 202 / 500 | iteration 10 / 30 | Total Loss: 1.0522165298461914 | KNN Loss: 6.224628448486328 | BCE Loss: 1.0522165298461914
Epoch 202 / 500 | iteration 15 / 30 | Total Loss: 1.0645942687988281 | KNN Loss: 6.2244133949279785 | BCE Loss: 1.0645942687988281
Epoch 202 / 500 | iteration 20 / 30 | Total Loss: 1.0524063110351562 | KNN Loss: 6.224375247955322 | BCE Loss: 1.0524063110351562
Epoch 202 / 500 | iteration 25 / 30 | Total Loss: 1.0243499279022217 | KNN Loss: 6.224273681640625 | BCE Loss: 1.0243499279022217
Epoch 203 / 500 | iteration 0 / 30 | Total Loss: 1.0561286211013794 | KNN Loss: 6.22485113143

Epoch 212 / 500 | iteration 10 / 30 | Total Loss: 1.0614166259765625 | KNN Loss: 6.224542140960693 | BCE Loss: 1.0614166259765625
Epoch 212 / 500 | iteration 15 / 30 | Total Loss: 1.0532965660095215 | KNN Loss: 6.224286079406738 | BCE Loss: 1.0532965660095215
Epoch 212 / 500 | iteration 20 / 30 | Total Loss: 1.0444316864013672 | KNN Loss: 6.224583625793457 | BCE Loss: 1.0444316864013672
Epoch 212 / 500 | iteration 25 / 30 | Total Loss: 1.06974196434021 | KNN Loss: 6.224419593811035 | BCE Loss: 1.06974196434021
Epoch 213 / 500 | iteration 0 / 30 | Total Loss: 1.0598831176757812 | KNN Loss: 6.2248406410217285 | BCE Loss: 1.0598831176757812
Epoch 213 / 500 | iteration 5 / 30 | Total Loss: 1.0313644409179688 | KNN Loss: 6.224653244018555 | BCE Loss: 1.0313644409179688
Epoch 213 / 500 | iteration 10 / 30 | Total Loss: 1.0352468490600586 | KNN Loss: 6.224531173706055 | BCE Loss: 1.0352468490600586
Epoch 213 / 500 | iteration 15 / 30 | Total Loss: 1.0814759731292725 | KNN Loss: 6.224477291107

Epoch 222 / 500 | iteration 25 / 30 | Total Loss: 1.0172940492630005 | KNN Loss: 6.224638938903809 | BCE Loss: 1.0172940492630005
Epoch 223 / 500 | iteration 0 / 30 | Total Loss: 1.060800313949585 | KNN Loss: 6.224909782409668 | BCE Loss: 1.060800313949585
Epoch 223 / 500 | iteration 5 / 30 | Total Loss: 1.047761082649231 | KNN Loss: 6.224462985992432 | BCE Loss: 1.047761082649231
Epoch 223 / 500 | iteration 10 / 30 | Total Loss: 1.0428802967071533 | KNN Loss: 6.224353313446045 | BCE Loss: 1.0428802967071533
Epoch 223 / 500 | iteration 15 / 30 | Total Loss: 1.0442523956298828 | KNN Loss: 6.224369049072266 | BCE Loss: 1.0442523956298828
Epoch 223 / 500 | iteration 20 / 30 | Total Loss: 1.027963399887085 | KNN Loss: 6.224691867828369 | BCE Loss: 1.027963399887085
Epoch 223 / 500 | iteration 25 / 30 | Total Loss: 1.020859956741333 | KNN Loss: 6.224618911743164 | BCE Loss: 1.020859956741333
Epoch 224 / 500 | iteration 0 / 30 | Total Loss: 1.0474138259887695 | KNN Loss: 6.224839687347412 | 

Epoch 233 / 500 | iteration 10 / 30 | Total Loss: 1.0416216850280762 | KNN Loss: 6.224234104156494 | BCE Loss: 1.0416216850280762
Epoch 233 / 500 | iteration 15 / 30 | Total Loss: 1.051363468170166 | KNN Loss: 6.224732398986816 | BCE Loss: 1.051363468170166
Epoch 233 / 500 | iteration 20 / 30 | Total Loss: 1.0091228485107422 | KNN Loss: 6.22451639175415 | BCE Loss: 1.0091228485107422
Epoch 233 / 500 | iteration 25 / 30 | Total Loss: 1.049600601196289 | KNN Loss: 6.224287986755371 | BCE Loss: 1.049600601196289
Epoch 234 / 500 | iteration 0 / 30 | Total Loss: 1.055942416191101 | KNN Loss: 6.224703311920166 | BCE Loss: 1.055942416191101
Epoch 234 / 500 | iteration 5 / 30 | Total Loss: 1.0682318210601807 | KNN Loss: 6.224812984466553 | BCE Loss: 1.0682318210601807
Epoch 234 / 500 | iteration 10 / 30 | Total Loss: 0.9930659532546997 | KNN Loss: 6.224452972412109 | BCE Loss: 0.9930659532546997
Epoch 234 / 500 | iteration 15 / 30 | Total Loss: 1.0639606714248657 | KNN Loss: 6.224636554718018 

Epoch   244: reducing learning rate of group 0 to 1.1632e-05.
Epoch 244 / 500 | iteration 0 / 30 | Total Loss: 1.0352814197540283 | KNN Loss: 6.224630355834961 | BCE Loss: 1.0352814197540283
Epoch 244 / 500 | iteration 5 / 30 | Total Loss: 1.031891107559204 | KNN Loss: 6.2245354652404785 | BCE Loss: 1.031891107559204
Epoch 244 / 500 | iteration 10 / 30 | Total Loss: 1.0435283184051514 | KNN Loss: 6.22435188293457 | BCE Loss: 1.0435283184051514
Epoch 244 / 500 | iteration 15 / 30 | Total Loss: 1.056640863418579 | KNN Loss: 6.224337100982666 | BCE Loss: 1.056640863418579
Epoch 244 / 500 | iteration 20 / 30 | Total Loss: 1.0300850868225098 | KNN Loss: 6.22428035736084 | BCE Loss: 1.0300850868225098
Epoch 244 / 500 | iteration 25 / 30 | Total Loss: 1.0617735385894775 | KNN Loss: 6.224442958831787 | BCE Loss: 1.0617735385894775
Epoch 245 / 500 | iteration 0 / 30 | Total Loss: 1.079470157623291 | KNN Loss: 6.224175453186035 | BCE Loss: 1.079470157623291
Epoch 245 / 500 | iteration 5 / 30 | T

Epoch 254 / 500 | iteration 15 / 30 | Total Loss: 1.0573482513427734 | KNN Loss: 6.224456310272217 | BCE Loss: 1.0573482513427734
Epoch 254 / 500 | iteration 20 / 30 | Total Loss: 1.0567731857299805 | KNN Loss: 6.224266052246094 | BCE Loss: 1.0567731857299805
Epoch 254 / 500 | iteration 25 / 30 | Total Loss: 1.073794960975647 | KNN Loss: 6.2245612144470215 | BCE Loss: 1.073794960975647
Epoch 255 / 500 | iteration 0 / 30 | Total Loss: 1.0609829425811768 | KNN Loss: 6.224381446838379 | BCE Loss: 1.0609829425811768
Epoch 255 / 500 | iteration 5 / 30 | Total Loss: 1.0225276947021484 | KNN Loss: 6.22417688369751 | BCE Loss: 1.0225276947021484
Epoch 255 / 500 | iteration 10 / 30 | Total Loss: 1.0486847162246704 | KNN Loss: 6.224466800689697 | BCE Loss: 1.0486847162246704
Epoch 255 / 500 | iteration 15 / 30 | Total Loss: 1.0587266683578491 | KNN Loss: 6.224488258361816 | BCE Loss: 1.0587266683578491
Epoch 255 / 500 | iteration 20 / 30 | Total Loss: 1.0480461120605469 | KNN Loss: 6.22424936294

Epoch 265 / 500 | iteration 0 / 30 | Total Loss: 1.043273687362671 | KNN Loss: 6.2243733406066895 | BCE Loss: 1.043273687362671
Epoch 265 / 500 | iteration 5 / 30 | Total Loss: 1.0259876251220703 | KNN Loss: 6.224399566650391 | BCE Loss: 1.0259876251220703
Epoch 265 / 500 | iteration 10 / 30 | Total Loss: 1.0471856594085693 | KNN Loss: 6.224631309509277 | BCE Loss: 1.0471856594085693
Epoch 265 / 500 | iteration 15 / 30 | Total Loss: 1.024215817451477 | KNN Loss: 6.2240376472473145 | BCE Loss: 1.024215817451477
Epoch 265 / 500 | iteration 20 / 30 | Total Loss: 1.0617256164550781 | KNN Loss: 6.224730491638184 | BCE Loss: 1.0617256164550781
Epoch 265 / 500 | iteration 25 / 30 | Total Loss: 1.0659470558166504 | KNN Loss: 6.224471092224121 | BCE Loss: 1.0659470558166504
Epoch 266 / 500 | iteration 0 / 30 | Total Loss: 1.0728554725646973 | KNN Loss: 6.224879264831543 | BCE Loss: 1.0728554725646973
Epoch 266 / 500 | iteration 5 / 30 | Total Loss: 1.0661587715148926 | KNN Loss: 6.2244014739990

Epoch 275 / 500 | iteration 15 / 30 | Total Loss: 1.0248640775680542 | KNN Loss: 6.224447250366211 | BCE Loss: 1.0248640775680542
Epoch 275 / 500 | iteration 20 / 30 | Total Loss: 1.0563808679580688 | KNN Loss: 6.224633693695068 | BCE Loss: 1.0563808679580688
Epoch 275 / 500 | iteration 25 / 30 | Total Loss: 1.064805507659912 | KNN Loss: 6.224029064178467 | BCE Loss: 1.064805507659912
Epoch 276 / 500 | iteration 0 / 30 | Total Loss: 1.0296330451965332 | KNN Loss: 6.224772930145264 | BCE Loss: 1.0296330451965332
Epoch 276 / 500 | iteration 5 / 30 | Total Loss: 1.026968240737915 | KNN Loss: 6.224341869354248 | BCE Loss: 1.026968240737915
Epoch 276 / 500 | iteration 10 / 30 | Total Loss: 1.03910493850708 | KNN Loss: 6.2240681648254395 | BCE Loss: 1.03910493850708
Epoch 276 / 500 | iteration 15 / 30 | Total Loss: 1.0757595300674438 | KNN Loss: 6.224577903747559 | BCE Loss: 1.0757595300674438
Epoch 276 / 500 | iteration 20 / 30 | Total Loss: 1.052459478378296 | KNN Loss: 6.224550247192383 |

Epoch 286 / 500 | iteration 0 / 30 | Total Loss: 1.0380558967590332 | KNN Loss: 6.22443151473999 | BCE Loss: 1.0380558967590332
Epoch 286 / 500 | iteration 5 / 30 | Total Loss: 1.0746194124221802 | KNN Loss: 6.224595546722412 | BCE Loss: 1.0746194124221802
Epoch 286 / 500 | iteration 10 / 30 | Total Loss: 1.0416345596313477 | KNN Loss: 6.224277496337891 | BCE Loss: 1.0416345596313477
Epoch 286 / 500 | iteration 15 / 30 | Total Loss: 1.052182912826538 | KNN Loss: 6.224298477172852 | BCE Loss: 1.052182912826538
Epoch 286 / 500 | iteration 20 / 30 | Total Loss: 1.071049690246582 | KNN Loss: 6.224433422088623 | BCE Loss: 1.071049690246582
Epoch 286 / 500 | iteration 25 / 30 | Total Loss: 1.0638822317123413 | KNN Loss: 6.224143028259277 | BCE Loss: 1.0638822317123413
Epoch 287 / 500 | iteration 0 / 30 | Total Loss: 1.0355387926101685 | KNN Loss: 6.224236488342285 | BCE Loss: 1.0355387926101685
Epoch 287 / 500 | iteration 5 / 30 | Total Loss: 1.078221082687378 | KNN Loss: 6.224668502807617 |

Epoch 296 / 500 | iteration 15 / 30 | Total Loss: 1.0803128480911255 | KNN Loss: 6.224167823791504 | BCE Loss: 1.0803128480911255
Epoch 296 / 500 | iteration 20 / 30 | Total Loss: 1.0536699295043945 | KNN Loss: 6.224960803985596 | BCE Loss: 1.0536699295043945
Epoch 296 / 500 | iteration 25 / 30 | Total Loss: 1.0311559438705444 | KNN Loss: 6.2247538566589355 | BCE Loss: 1.0311559438705444
Epoch 297 / 500 | iteration 0 / 30 | Total Loss: 1.0729475021362305 | KNN Loss: 6.224485397338867 | BCE Loss: 1.0729475021362305
Epoch 297 / 500 | iteration 5 / 30 | Total Loss: 1.0430800914764404 | KNN Loss: 6.224111080169678 | BCE Loss: 1.0430800914764404
Epoch 297 / 500 | iteration 10 / 30 | Total Loss: 1.042130708694458 | KNN Loss: 6.224938869476318 | BCE Loss: 1.042130708694458
Epoch 297 / 500 | iteration 15 / 30 | Total Loss: 1.0647881031036377 | KNN Loss: 6.224634647369385 | BCE Loss: 1.0647881031036377
Epoch 297 / 500 | iteration 20 / 30 | Total Loss: 1.0589311122894287 | KNN Loss: 6.2246956825

Epoch 307 / 500 | iteration 5 / 30 | Total Loss: 1.0271050930023193 | KNN Loss: 6.224339008331299 | BCE Loss: 1.0271050930023193
Epoch 307 / 500 | iteration 10 / 30 | Total Loss: 1.041398525238037 | KNN Loss: 6.22435188293457 | BCE Loss: 1.041398525238037
Epoch 307 / 500 | iteration 15 / 30 | Total Loss: 1.0410473346710205 | KNN Loss: 6.224340438842773 | BCE Loss: 1.0410473346710205
Epoch 307 / 500 | iteration 20 / 30 | Total Loss: 1.061018943786621 | KNN Loss: 6.224172115325928 | BCE Loss: 1.061018943786621
Epoch 307 / 500 | iteration 25 / 30 | Total Loss: 1.0444198846817017 | KNN Loss: 6.224419116973877 | BCE Loss: 1.0444198846817017
Epoch 308 / 500 | iteration 0 / 30 | Total Loss: 1.0373451709747314 | KNN Loss: 6.2242536544799805 | BCE Loss: 1.0373451709747314
Epoch 308 / 500 | iteration 5 / 30 | Total Loss: 1.091111660003662 | KNN Loss: 6.224403381347656 | BCE Loss: 1.091111660003662
Epoch 308 / 500 | iteration 10 / 30 | Total Loss: 1.0396782159805298 | KNN Loss: 6.2241010665893555

Epoch 317 / 500 | iteration 20 / 30 | Total Loss: 1.0304913520812988 | KNN Loss: 6.2244086265563965 | BCE Loss: 1.0304913520812988
Epoch 317 / 500 | iteration 25 / 30 | Total Loss: 1.022512674331665 | KNN Loss: 6.224454402923584 | BCE Loss: 1.022512674331665
Epoch 318 / 500 | iteration 0 / 30 | Total Loss: 1.042155146598816 | KNN Loss: 6.224437713623047 | BCE Loss: 1.042155146598816
Epoch 318 / 500 | iteration 5 / 30 | Total Loss: 1.051027774810791 | KNN Loss: 6.224869728088379 | BCE Loss: 1.051027774810791
Epoch 318 / 500 | iteration 10 / 30 | Total Loss: 1.1020526885986328 | KNN Loss: 6.224310874938965 | BCE Loss: 1.1020526885986328
Epoch 318 / 500 | iteration 15 / 30 | Total Loss: 1.0528674125671387 | KNN Loss: 6.2241740226745605 | BCE Loss: 1.0528674125671387
Epoch 318 / 500 | iteration 20 / 30 | Total Loss: 1.0457926988601685 | KNN Loss: 6.224430561065674 | BCE Loss: 1.0457926988601685
Epoch 318 / 500 | iteration 25 / 30 | Total Loss: 1.0376588106155396 | KNN Loss: 6.2245411872863

Epoch 328 / 500 | iteration 10 / 30 | Total Loss: 1.0698760747909546 | KNN Loss: 6.224327087402344 | BCE Loss: 1.0698760747909546
Epoch 328 / 500 | iteration 15 / 30 | Total Loss: 1.0441166162490845 | KNN Loss: 6.224541187286377 | BCE Loss: 1.0441166162490845
Epoch 328 / 500 | iteration 20 / 30 | Total Loss: 1.031194806098938 | KNN Loss: 6.224257946014404 | BCE Loss: 1.031194806098938
Epoch 328 / 500 | iteration 25 / 30 | Total Loss: 1.0331920385360718 | KNN Loss: 6.224069595336914 | BCE Loss: 1.0331920385360718
Epoch 329 / 500 | iteration 0 / 30 | Total Loss: 1.0517688989639282 | KNN Loss: 6.224905490875244 | BCE Loss: 1.0517688989639282
Epoch 329 / 500 | iteration 5 / 30 | Total Loss: 1.0591871738433838 | KNN Loss: 6.224545478820801 | BCE Loss: 1.0591871738433838
Epoch 329 / 500 | iteration 10 / 30 | Total Loss: 1.0593271255493164 | KNN Loss: 6.224287986755371 | BCE Loss: 1.0593271255493164
Epoch 329 / 500 | iteration 15 / 30 | Total Loss: 1.0481394529342651 | KNN Loss: 6.22423315048

Epoch 338 / 500 | iteration 25 / 30 | Total Loss: 1.0476807355880737 | KNN Loss: 6.22468900680542 | BCE Loss: 1.0476807355880737
Epoch 339 / 500 | iteration 0 / 30 | Total Loss: 1.0425565242767334 | KNN Loss: 6.224487781524658 | BCE Loss: 1.0425565242767334
Epoch 339 / 500 | iteration 5 / 30 | Total Loss: 1.0701920986175537 | KNN Loss: 6.224555015563965 | BCE Loss: 1.0701920986175537
Epoch 339 / 500 | iteration 10 / 30 | Total Loss: 1.0560481548309326 | KNN Loss: 6.224308013916016 | BCE Loss: 1.0560481548309326
Epoch 339 / 500 | iteration 15 / 30 | Total Loss: 1.0609067678451538 | KNN Loss: 6.224414348602295 | BCE Loss: 1.0609067678451538
Epoch 339 / 500 | iteration 20 / 30 | Total Loss: 1.0209779739379883 | KNN Loss: 6.224145889282227 | BCE Loss: 1.0209779739379883
Epoch 339 / 500 | iteration 25 / 30 | Total Loss: 1.042926549911499 | KNN Loss: 6.224164009094238 | BCE Loss: 1.042926549911499
Epoch 340 / 500 | iteration 0 / 30 | Total Loss: 1.0528521537780762 | KNN Loss: 6.2246489524841

Epoch 349 / 500 | iteration 10 / 30 | Total Loss: 1.0152555704116821 | KNN Loss: 6.224185466766357 | BCE Loss: 1.0152555704116821
Epoch 349 / 500 | iteration 15 / 30 | Total Loss: 1.0581552982330322 | KNN Loss: 6.224181652069092 | BCE Loss: 1.0581552982330322
Epoch 349 / 500 | iteration 20 / 30 | Total Loss: 1.0157139301300049 | KNN Loss: 6.2242960929870605 | BCE Loss: 1.0157139301300049
Epoch 349 / 500 | iteration 25 / 30 | Total Loss: 1.057373046875 | KNN Loss: 6.22415828704834 | BCE Loss: 1.057373046875
Epoch 350 / 500 | iteration 0 / 30 | Total Loss: 1.0690217018127441 | KNN Loss: 6.224431991577148 | BCE Loss: 1.0690217018127441
Epoch 350 / 500 | iteration 5 / 30 | Total Loss: 1.0632871389389038 | KNN Loss: 6.224470615386963 | BCE Loss: 1.0632871389389038
Epoch 350 / 500 | iteration 10 / 30 | Total Loss: 1.0469751358032227 | KNN Loss: 6.224776268005371 | BCE Loss: 1.0469751358032227
Epoch 350 / 500 | iteration 15 / 30 | Total Loss: 1.0382843017578125 | KNN Loss: 6.224638938903809 |

Epoch 359 / 500 | iteration 25 / 30 | Total Loss: 1.0313165187835693 | KNN Loss: 6.224365711212158 | BCE Loss: 1.0313165187835693
Epoch 360 / 500 | iteration 0 / 30 | Total Loss: 1.0807781219482422 | KNN Loss: 6.2242350578308105 | BCE Loss: 1.0807781219482422
Epoch 360 / 500 | iteration 5 / 30 | Total Loss: 1.0295381546020508 | KNN Loss: 6.224366188049316 | BCE Loss: 1.0295381546020508
Epoch 360 / 500 | iteration 10 / 30 | Total Loss: 1.0438203811645508 | KNN Loss: 6.224514484405518 | BCE Loss: 1.0438203811645508
Epoch 360 / 500 | iteration 15 / 30 | Total Loss: 1.0498225688934326 | KNN Loss: 6.224639892578125 | BCE Loss: 1.0498225688934326
Epoch 360 / 500 | iteration 20 / 30 | Total Loss: 1.0330480337142944 | KNN Loss: 6.224452018737793 | BCE Loss: 1.0330480337142944
Epoch 360 / 500 | iteration 25 / 30 | Total Loss: 1.0579442977905273 | KNN Loss: 6.224433898925781 | BCE Loss: 1.0579442977905273
Epoch 361 / 500 | iteration 0 / 30 | Total Loss: 1.027134895324707 | KNN Loss: 6.2248950004

Epoch 370 / 500 | iteration 10 / 30 | Total Loss: 1.0368140935897827 | KNN Loss: 6.2245283126831055 | BCE Loss: 1.0368140935897827
Epoch 370 / 500 | iteration 15 / 30 | Total Loss: 1.059569001197815 | KNN Loss: 6.224441051483154 | BCE Loss: 1.059569001197815
Epoch 370 / 500 | iteration 20 / 30 | Total Loss: 1.0367376804351807 | KNN Loss: 6.2246785163879395 | BCE Loss: 1.0367376804351807
Epoch 370 / 500 | iteration 25 / 30 | Total Loss: 1.0563509464263916 | KNN Loss: 6.224264621734619 | BCE Loss: 1.0563509464263916
Epoch 371 / 500 | iteration 0 / 30 | Total Loss: 1.0409603118896484 | KNN Loss: 6.224153995513916 | BCE Loss: 1.0409603118896484
Epoch 371 / 500 | iteration 5 / 30 | Total Loss: 1.0537822246551514 | KNN Loss: 6.224803447723389 | BCE Loss: 1.0537822246551514
Epoch 371 / 500 | iteration 10 / 30 | Total Loss: 1.0504586696624756 | KNN Loss: 6.224382400512695 | BCE Loss: 1.0504586696624756
Epoch 371 / 500 | iteration 15 / 30 | Total Loss: 1.0511090755462646 | KNN Loss: 6.224651336

Epoch 380 / 500 | iteration 25 / 30 | Total Loss: 1.086454153060913 | KNN Loss: 6.224172115325928 | BCE Loss: 1.086454153060913
Epoch 381 / 500 | iteration 0 / 30 | Total Loss: 1.0659332275390625 | KNN Loss: 6.224648475646973 | BCE Loss: 1.0659332275390625
Epoch 381 / 500 | iteration 5 / 30 | Total Loss: 1.036665678024292 | KNN Loss: 6.224678039550781 | BCE Loss: 1.036665678024292
Epoch 381 / 500 | iteration 10 / 30 | Total Loss: 1.0537291765213013 | KNN Loss: 6.22409200668335 | BCE Loss: 1.0537291765213013
Epoch 381 / 500 | iteration 15 / 30 | Total Loss: 1.0460681915283203 | KNN Loss: 6.224681854248047 | BCE Loss: 1.0460681915283203
Epoch 381 / 500 | iteration 20 / 30 | Total Loss: 1.090316891670227 | KNN Loss: 6.224748134613037 | BCE Loss: 1.090316891670227
Epoch 381 / 500 | iteration 25 / 30 | Total Loss: 1.0223219394683838 | KNN Loss: 6.224374294281006 | BCE Loss: 1.0223219394683838
Epoch 382 / 500 | iteration 0 / 30 | Total Loss: 1.0401746034622192 | KNN Loss: 6.224605560302734 |

Epoch 391 / 500 | iteration 10 / 30 | Total Loss: 1.0403945446014404 | KNN Loss: 6.224269390106201 | BCE Loss: 1.0403945446014404
Epoch 391 / 500 | iteration 15 / 30 | Total Loss: 1.0430713891983032 | KNN Loss: 6.224405288696289 | BCE Loss: 1.0430713891983032
Epoch 391 / 500 | iteration 20 / 30 | Total Loss: 1.0586649179458618 | KNN Loss: 6.224427223205566 | BCE Loss: 1.0586649179458618
Epoch 391 / 500 | iteration 25 / 30 | Total Loss: 1.07962167263031 | KNN Loss: 6.224656105041504 | BCE Loss: 1.07962167263031
Epoch 392 / 500 | iteration 0 / 30 | Total Loss: 1.071535348892212 | KNN Loss: 6.2246503829956055 | BCE Loss: 1.071535348892212
Epoch 392 / 500 | iteration 5 / 30 | Total Loss: 1.0645420551300049 | KNN Loss: 6.224515914916992 | BCE Loss: 1.0645420551300049
Epoch 392 / 500 | iteration 10 / 30 | Total Loss: 1.0727369785308838 | KNN Loss: 6.224330425262451 | BCE Loss: 1.0727369785308838
Epoch 392 / 500 | iteration 15 / 30 | Total Loss: 1.0274819135665894 | KNN Loss: 6.22418880462646

Epoch   402: reducing learning rate of group 0 to 7.8888e-08.
Epoch 402 / 500 | iteration 0 / 30 | Total Loss: 1.0489351749420166 | KNN Loss: 6.224586009979248 | BCE Loss: 1.0489351749420166
Epoch 402 / 500 | iteration 5 / 30 | Total Loss: 1.0453919172286987 | KNN Loss: 6.224823474884033 | BCE Loss: 1.0453919172286987
Epoch 402 / 500 | iteration 10 / 30 | Total Loss: 1.053794026374817 | KNN Loss: 6.224209308624268 | BCE Loss: 1.053794026374817
Epoch 402 / 500 | iteration 15 / 30 | Total Loss: 1.0373785495758057 | KNN Loss: 6.224127769470215 | BCE Loss: 1.0373785495758057
Epoch 402 / 500 | iteration 20 / 30 | Total Loss: 1.0833402872085571 | KNN Loss: 6.224686622619629 | BCE Loss: 1.0833402872085571
Epoch 402 / 500 | iteration 25 / 30 | Total Loss: 1.0511353015899658 | KNN Loss: 6.224090576171875 | BCE Loss: 1.0511353015899658
Epoch 403 / 500 | iteration 0 / 30 | Total Loss: 1.047454595565796 | KNN Loss: 6.224590301513672 | BCE Loss: 1.047454595565796
Epoch 403 / 500 | iteration 5 / 30 

Epoch 412 / 500 | iteration 15 / 30 | Total Loss: 1.0403016805648804 | KNN Loss: 6.22496223449707 | BCE Loss: 1.0403016805648804
Epoch 412 / 500 | iteration 20 / 30 | Total Loss: 1.0390263795852661 | KNN Loss: 6.224148750305176 | BCE Loss: 1.0390263795852661
Epoch 412 / 500 | iteration 25 / 30 | Total Loss: 1.0590797662734985 | KNN Loss: 6.224576473236084 | BCE Loss: 1.0590797662734985
Epoch   413: reducing learning rate of group 0 to 5.5221e-08.
Epoch 413 / 500 | iteration 0 / 30 | Total Loss: 1.071245551109314 | KNN Loss: 6.224816799163818 | BCE Loss: 1.071245551109314
Epoch 413 / 500 | iteration 5 / 30 | Total Loss: 1.0546715259552002 | KNN Loss: 6.224415302276611 | BCE Loss: 1.0546715259552002
Epoch 413 / 500 | iteration 10 / 30 | Total Loss: 1.034571886062622 | KNN Loss: 6.224503517150879 | BCE Loss: 1.034571886062622
Epoch 413 / 500 | iteration 15 / 30 | Total Loss: 1.034019947052002 | KNN Loss: 6.224691867828369 | BCE Loss: 1.034019947052002
Epoch 413 / 500 | iteration 20 / 30 |

Epoch 423 / 500 | iteration 0 / 30 | Total Loss: 1.0519174337387085 | KNN Loss: 6.2244672775268555 | BCE Loss: 1.0519174337387085
Epoch 423 / 500 | iteration 5 / 30 | Total Loss: 1.0655481815338135 | KNN Loss: 6.2244415283203125 | BCE Loss: 1.0655481815338135
Epoch 423 / 500 | iteration 10 / 30 | Total Loss: 1.0739216804504395 | KNN Loss: 6.224446773529053 | BCE Loss: 1.0739216804504395
Epoch 423 / 500 | iteration 15 / 30 | Total Loss: 1.022998332977295 | KNN Loss: 6.224568843841553 | BCE Loss: 1.022998332977295
Epoch 423 / 500 | iteration 20 / 30 | Total Loss: 1.0214616060256958 | KNN Loss: 6.224155426025391 | BCE Loss: 1.0214616060256958
Epoch 423 / 500 | iteration 25 / 30 | Total Loss: 1.0822575092315674 | KNN Loss: 6.224652290344238 | BCE Loss: 1.0822575092315674
Epoch   424: reducing learning rate of group 0 to 3.8655e-08.
Epoch 424 / 500 | iteration 0 / 30 | Total Loss: 1.0564122200012207 | KNN Loss: 6.224653244018555 | BCE Loss: 1.0564122200012207
Epoch 424 / 500 | iteration 5 /

Epoch 433 / 500 | iteration 20 / 30 | Total Loss: 1.0382535457611084 | KNN Loss: 6.2244343757629395 | BCE Loss: 1.0382535457611084
Epoch 433 / 500 | iteration 25 / 30 | Total Loss: 1.0269207954406738 | KNN Loss: 6.224828720092773 | BCE Loss: 1.0269207954406738
Epoch 434 / 500 | iteration 0 / 30 | Total Loss: 1.0532808303833008 | KNN Loss: 6.224363327026367 | BCE Loss: 1.0532808303833008
Epoch 434 / 500 | iteration 5 / 30 | Total Loss: 1.0783848762512207 | KNN Loss: 6.224705219268799 | BCE Loss: 1.0783848762512207
Epoch 434 / 500 | iteration 10 / 30 | Total Loss: 1.0590641498565674 | KNN Loss: 6.224483966827393 | BCE Loss: 1.0590641498565674
Epoch 434 / 500 | iteration 15 / 30 | Total Loss: 1.0717170238494873 | KNN Loss: 6.223955154418945 | BCE Loss: 1.0717170238494873
Epoch 434 / 500 | iteration 20 / 30 | Total Loss: 1.0330395698547363 | KNN Loss: 6.224670886993408 | BCE Loss: 1.0330395698547363
Epoch 434 / 500 | iteration 25 / 30 | Total Loss: 1.003517746925354 | KNN Loss: 6.223978996

Epoch 444 / 500 | iteration 5 / 30 | Total Loss: 1.0181970596313477 | KNN Loss: 6.2242560386657715 | BCE Loss: 1.0181970596313477
Epoch 444 / 500 | iteration 10 / 30 | Total Loss: 1.0399465560913086 | KNN Loss: 6.2241902351379395 | BCE Loss: 1.0399465560913086
Epoch 444 / 500 | iteration 15 / 30 | Total Loss: 1.040428638458252 | KNN Loss: 6.224527359008789 | BCE Loss: 1.040428638458252
Epoch 444 / 500 | iteration 20 / 30 | Total Loss: 1.0321407318115234 | KNN Loss: 6.223783493041992 | BCE Loss: 1.0321407318115234
Epoch 444 / 500 | iteration 25 / 30 | Total Loss: 1.0559033155441284 | KNN Loss: 6.224629878997803 | BCE Loss: 1.0559033155441284
Epoch 445 / 500 | iteration 0 / 30 | Total Loss: 1.0751242637634277 | KNN Loss: 6.224888801574707 | BCE Loss: 1.0751242637634277
Epoch 445 / 500 | iteration 5 / 30 | Total Loss: 1.0568912029266357 | KNN Loss: 6.224323272705078 | BCE Loss: 1.0568912029266357
Epoch 445 / 500 | iteration 10 / 30 | Total Loss: 1.0513652563095093 | KNN Loss: 6.2245631217

Epoch 454 / 500 | iteration 25 / 30 | Total Loss: 1.0386836528778076 | KNN Loss: 6.2243499755859375 | BCE Loss: 1.0386836528778076
Epoch 455 / 500 | iteration 0 / 30 | Total Loss: 1.0423836708068848 | KNN Loss: 6.224418640136719 | BCE Loss: 1.0423836708068848
Epoch 455 / 500 | iteration 5 / 30 | Total Loss: 1.0410542488098145 | KNN Loss: 6.2241291999816895 | BCE Loss: 1.0410542488098145
Epoch 455 / 500 | iteration 10 / 30 | Total Loss: 1.0504595041275024 | KNN Loss: 6.224740505218506 | BCE Loss: 1.0504595041275024
Epoch 455 / 500 | iteration 15 / 30 | Total Loss: 1.0450432300567627 | KNN Loss: 6.224768161773682 | BCE Loss: 1.0450432300567627
Epoch 455 / 500 | iteration 20 / 30 | Total Loss: 1.0475773811340332 | KNN Loss: 6.224482536315918 | BCE Loss: 1.0475773811340332
Epoch 455 / 500 | iteration 25 / 30 | Total Loss: 1.0611717700958252 | KNN Loss: 6.224587440490723 | BCE Loss: 1.0611717700958252
Epoch 456 / 500 | iteration 0 / 30 | Total Loss: 1.07692289352417 | KNN Loss: 6.2246189117

Epoch 465 / 500 | iteration 15 / 30 | Total Loss: 1.0223348140716553 | KNN Loss: 6.224613666534424 | BCE Loss: 1.0223348140716553
Epoch 465 / 500 | iteration 20 / 30 | Total Loss: 1.067919373512268 | KNN Loss: 6.224935531616211 | BCE Loss: 1.067919373512268
Epoch 465 / 500 | iteration 25 / 30 | Total Loss: 1.043061375617981 | KNN Loss: 6.224202632904053 | BCE Loss: 1.043061375617981
Epoch 466 / 500 | iteration 0 / 30 | Total Loss: 1.045883059501648 | KNN Loss: 6.2244791984558105 | BCE Loss: 1.045883059501648
Epoch 466 / 500 | iteration 5 / 30 | Total Loss: 1.0668728351593018 | KNN Loss: 6.224584102630615 | BCE Loss: 1.0668728351593018
Epoch 466 / 500 | iteration 10 / 30 | Total Loss: 1.0705169439315796 | KNN Loss: 6.224330425262451 | BCE Loss: 1.0705169439315796
Epoch 466 / 500 | iteration 15 / 30 | Total Loss: 1.0422401428222656 | KNN Loss: 6.224624156951904 | BCE Loss: 1.0422401428222656
Epoch 466 / 500 | iteration 20 / 30 | Total Loss: 1.044601559638977 | KNN Loss: 6.224548816680908

Epoch 476 / 500 | iteration 5 / 30 | Total Loss: 1.0375841856002808 | KNN Loss: 6.224062919616699 | BCE Loss: 1.0375841856002808
Epoch 476 / 500 | iteration 10 / 30 | Total Loss: 1.052534818649292 | KNN Loss: 6.224465847015381 | BCE Loss: 1.052534818649292
Epoch 476 / 500 | iteration 15 / 30 | Total Loss: 1.045093059539795 | KNN Loss: 6.224832057952881 | BCE Loss: 1.045093059539795
Epoch 476 / 500 | iteration 20 / 30 | Total Loss: 1.063495397567749 | KNN Loss: 6.224554538726807 | BCE Loss: 1.063495397567749
Epoch 476 / 500 | iteration 25 / 30 | Total Loss: 1.0082381963729858 | KNN Loss: 6.22435188293457 | BCE Loss: 1.0082381963729858
Epoch 477 / 500 | iteration 0 / 30 | Total Loss: 1.0363152027130127 | KNN Loss: 6.224386215209961 | BCE Loss: 1.0363152027130127
Epoch 477 / 500 | iteration 5 / 30 | Total Loss: 1.0466156005859375 | KNN Loss: 6.224661350250244 | BCE Loss: 1.0466156005859375
Epoch 477 / 500 | iteration 10 / 30 | Total Loss: 1.0606563091278076 | KNN Loss: 6.22452449798584 | 

Epoch 486 / 500 | iteration 25 / 30 | Total Loss: 1.055832862854004 | KNN Loss: 6.22460412979126 | BCE Loss: 1.055832862854004
Epoch 487 / 500 | iteration 0 / 30 | Total Loss: 1.053954839706421 | KNN Loss: 6.224515438079834 | BCE Loss: 1.053954839706421
Epoch 487 / 500 | iteration 5 / 30 | Total Loss: 1.0651004314422607 | KNN Loss: 6.224644660949707 | BCE Loss: 1.0651004314422607
Epoch 487 / 500 | iteration 10 / 30 | Total Loss: 1.057348370552063 | KNN Loss: 6.224420547485352 | BCE Loss: 1.057348370552063
Epoch 487 / 500 | iteration 15 / 30 | Total Loss: 1.0303072929382324 | KNN Loss: 6.224255084991455 | BCE Loss: 1.0303072929382324
Epoch 487 / 500 | iteration 20 / 30 | Total Loss: 1.054123878479004 | KNN Loss: 6.224785327911377 | BCE Loss: 1.054123878479004
Epoch 487 / 500 | iteration 25 / 30 | Total Loss: 1.0397799015045166 | KNN Loss: 6.224265098571777 | BCE Loss: 1.0397799015045166
Epoch 488 / 500 | iteration 0 / 30 | Total Loss: 1.0715899467468262 | KNN Loss: 6.2243733406066895 | 

Epoch 497 / 500 | iteration 15 / 30 | Total Loss: 1.0636194944381714 | KNN Loss: 6.224575996398926 | BCE Loss: 1.0636194944381714
Epoch 497 / 500 | iteration 20 / 30 | Total Loss: 1.0476899147033691 | KNN Loss: 6.2240729331970215 | BCE Loss: 1.0476899147033691
Epoch 497 / 500 | iteration 25 / 30 | Total Loss: 1.036081314086914 | KNN Loss: 6.224581241607666 | BCE Loss: 1.036081314086914
Epoch 498 / 500 | iteration 0 / 30 | Total Loss: 1.066345453262329 | KNN Loss: 6.224164009094238 | BCE Loss: 1.066345453262329
Epoch 498 / 500 | iteration 5 / 30 | Total Loss: 1.0668036937713623 | KNN Loss: 6.224806785583496 | BCE Loss: 1.0668036937713623
Epoch 498 / 500 | iteration 10 / 30 | Total Loss: 1.0526536703109741 | KNN Loss: 6.224017143249512 | BCE Loss: 1.0526536703109741
Epoch 498 / 500 | iteration 15 / 30 | Total Loss: 1.026498794555664 | KNN Loss: 6.224737167358398 | BCE Loss: 1.026498794555664
Epoch 498 / 500 | iteration 20 / 30 | Total Loss: 1.048725962638855 | KNN Loss: 6.22463321685791 

In [25]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.9231,  3.8722,  2.5812,  3.5745,  3.4543,  0.7017,  2.6651,  2.1970,
          2.3083,  1.9925,  2.2344,  2.1979,  0.7876,  1.8203,  1.2891,  1.5216,
          2.8064,  3.1788,  2.7991,  2.3030,  1.7450,  2.9516,  2.2904,  2.6379,
          2.5332,  1.7390,  2.1229,  1.4155,  1.4934,  0.3250, -0.2396,  0.9983,
          0.2212,  0.9240,  1.5308,  1.4778,  1.0049,  3.3144,  0.8001,  1.3198,
          0.9675, -0.7032, -0.2357,  2.3364,  2.1887,  0.7335, -0.2012,  0.0970,
          1.4607,  2.4952,  1.8230,  0.1381,  1.4258,  0.5204, -0.6363,  1.1095,
          1.4810,  1.3720,  1.3400,  1.8270,  0.5738,  0.8433,  0.1389,  1.7235,
          1.3158,  1.6659, -1.8257,  0.3069,  2.2878,  2.1419,  2.5500,  0.4267,
          1.3524,  2.4587,  1.9961,  1.2923,  0.2264,  0.7375,  0.2173,  1.5898,
          0.0257,  0.3779,  1.8381, -0.3734,  0.2388, -1.0732, -2.2649, -0.2561,
          0.5456, -1.8580,  0.4615, -0.1291, -0.5688, -0.9364,  0.5635,  1.2683,
         -0.6937, -0.7033,  

In [26]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [27]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [28]:
dataset_ = [d[0].cpu() for d in dataset]

In [29]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 72.73it/s]


In [30]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [107]:
clusters = DBSCAN(eps=0.01, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [108]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [109]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [110]:
tensor_dataset = torch.stack(dataset_)

In [111]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [112]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [113]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [114]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [115]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [116]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [117]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [118]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [119]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [120]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [121]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [122]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [123]:
losses = []
accs = []
sparsity = []

In [124]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
Epoch: 00 | Batch: 000 / 025 | Total loss: 9.612 | Reg loss: 0.007 | Tree loss: 9.612 | Accuracy: 0.000000 | 0.095 sec/iter
Epoch: 00 | Batch: 001 / 025 | Total loss: 9.592 | Reg loss: 0.007 | Tree loss: 9.592 | Accuracy: 0.000000 | 0.082 sec/iter
Epoch: 00 | Batch: 002 / 025 | Total loss: 9.572 | Reg loss: 0.007 | Tree loss: 9.572 | Accuracy: 0.000000 | 0.08 sec/iter
Epoch: 00 | Batch: 003 / 025 | Total loss: 9.553 | Reg loss: 0.007 | Tree loss: 9.553 | Accuracy: 0.000000 | 0.076 sec/iter
Epoch: 00 | Batch: 004 / 025 | Total loss: 9.531 | Reg loss: 0.007 | Tree loss: 9.531 | Accuracy: 0.000000 | 0.073 sec/iter
Epoch: 00 | Batch: 005 / 025 | Total loss: 9.512 | Reg loss: 0.007 | Tree loss: 9.512 | Accuracy: 0.000000 | 0.071 sec/iter
Epoch: 00 | Batch: 006 / 025 | Total loss: 9.492 | Reg loss: 0.007 | Tree loss: 9.492 | Accuracy: 0.000000 | 0.069 sec/iter
Epoch: 00 | Batch: 007 / 025 | Total loss: 9

Epoch: 02 | Batch: 015 / 025 | Total loss: 9.022 | Reg loss: 0.010 | Tree loss: 9.022 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 02 | Batch: 016 / 025 | Total loss: 9.010 | Reg loss: 0.011 | Tree loss: 9.010 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 02 | Batch: 017 / 025 | Total loss: 8.991 | Reg loss: 0.011 | Tree loss: 8.991 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 02 | Batch: 018 / 025 | Total loss: 8.972 | Reg loss: 0.011 | Tree loss: 8.972 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 02 | Batch: 019 / 025 | Total loss: 8.947 | Reg loss: 0.012 | Tree loss: 8.947 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 02 | Batch: 020 / 025 | Total loss: 8.933 | Reg loss: 0.012 | Tree loss: 8.933 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 02 | Batch: 021 / 025 | Total loss: 8.911 | Reg loss: 0.013 | Tree loss: 8.911 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 02 | Batch: 022 / 025 | Total loss: 8.882 | Reg loss: 0.013 | Tree loss: 8.882 | Accuracy: 1.000000 | 0.066 sec/iter
Epoch: 0

Epoch: 05 | Batch: 004 / 025 | Total loss: 8.814 | Reg loss: 0.012 | Tree loss: 8.814 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 005 / 025 | Total loss: 8.799 | Reg loss: 0.012 | Tree loss: 8.799 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 006 / 025 | Total loss: 8.774 | Reg loss: 0.013 | Tree loss: 8.774 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 007 / 025 | Total loss: 8.757 | Reg loss: 0.013 | Tree loss: 8.757 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 008 / 025 | Total loss: 8.737 | Reg loss: 0.013 | Tree loss: 8.737 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 009 / 025 | Total loss: 8.723 | Reg loss: 0.013 | Tree loss: 8.723 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 010 / 025 | Total loss: 8.691 | Reg loss: 0.014 | Tree loss: 8.691 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 05 | Batch: 011 / 025 | Total loss: 8.679 | Reg loss: 0.014 | Tree loss: 8.679 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 0

Epoch: 07 | Batch: 019 / 025 | Total loss: 8.229 | Reg loss: 0.020 | Tree loss: 8.229 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 07 | Batch: 020 / 025 | Total loss: 8.218 | Reg loss: 0.021 | Tree loss: 8.218 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 07 | Batch: 021 / 025 | Total loss: 8.197 | Reg loss: 0.021 | Tree loss: 8.197 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 07 | Batch: 022 / 025 | Total loss: 8.172 | Reg loss: 0.021 | Tree loss: 8.172 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 07 | Batch: 023 / 025 | Total loss: 8.163 | Reg loss: 0.022 | Tree loss: 8.163 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 07 | Batch: 024 / 025 | Total loss: 8.136 | Reg loss: 0.022 | Tree loss: 8.136 | Accuracy: 1.000000 | 0.067 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 08 | Batch: 000 / 025 | Total loss: 8.483 | Reg loss: 0.017 | Tree los

Epoch: 10 | Batch: 006 / 025 | Total loss: 8.088 | Reg loss: 0.020 | Tree loss: 8.088 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 007 / 025 | Total loss: 8.066 | Reg loss: 0.020 | Tree loss: 8.066 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 008 / 025 | Total loss: 8.040 | Reg loss: 0.020 | Tree loss: 8.040 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 009 / 025 | Total loss: 8.032 | Reg loss: 0.021 | Tree loss: 8.032 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 010 / 025 | Total loss: 7.996 | Reg loss: 0.021 | Tree loss: 7.996 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 011 / 025 | Total loss: 7.989 | Reg loss: 0.021 | Tree loss: 7.989 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 012 / 025 | Total loss: 7.953 | Reg loss: 0.021 | Tree loss: 7.953 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 10 | Batch: 013 / 025 | Total loss: 7.943 | Reg loss: 0.022 | Tree loss: 7.943 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 1

Epoch: 12 | Batch: 020 / 025 | Total loss: 7.477 | Reg loss: 0.026 | Tree loss: 7.477 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 12 | Batch: 021 / 025 | Total loss: 7.455 | Reg loss: 0.026 | Tree loss: 7.455 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 12 | Batch: 022 / 025 | Total loss: 7.425 | Reg loss: 0.027 | Tree loss: 7.425 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 12 | Batch: 023 / 025 | Total loss: 7.400 | Reg loss: 0.027 | Tree loss: 7.400 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 12 | Batch: 024 / 025 | Total loss: 7.368 | Reg loss: 0.027 | Tree loss: 7.368 | Accuracy: 1.000000 | 0.067 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 13 | Batch: 000 / 025 | Total loss: 7.773 | Reg loss: 0.023 | Tree loss: 7.773 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 13 | Batch: 001 / 025 | Total loss: 7.740 | Reg loss: 0.023 | Tree los

Epoch: 15 | Batch: 009 / 025 | Total loss: 7.226 | Reg loss: 0.025 | Tree loss: 7.226 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 010 / 025 | Total loss: 7.192 | Reg loss: 0.025 | Tree loss: 7.192 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 011 / 025 | Total loss: 7.162 | Reg loss: 0.025 | Tree loss: 7.162 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 012 / 025 | Total loss: 7.148 | Reg loss: 0.025 | Tree loss: 7.148 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 013 / 025 | Total loss: 7.104 | Reg loss: 0.026 | Tree loss: 7.104 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 014 / 025 | Total loss: 7.105 | Reg loss: 0.026 | Tree loss: 7.105 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 015 / 025 | Total loss: 7.060 | Reg loss: 0.026 | Tree loss: 7.060 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 15 | Batch: 016 / 025 | Total loss: 7.046 | Reg loss: 0.026 | Tree loss: 7.046 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 1

Epoch: 17 | Batch: 023 / 025 | Total loss: 6.509 | Reg loss: 0.029 | Tree loss: 6.509 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 17 | Batch: 024 / 025 | Total loss: 6.478 | Reg loss: 0.029 | Tree loss: 6.478 | Accuracy: 1.000000 | 0.067 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 18 | Batch: 000 / 025 | Total loss: 6.925 | Reg loss: 0.025 | Tree loss: 6.925 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 18 | Batch: 001 / 025 | Total loss: 6.903 | Reg loss: 0.025 | Tree loss: 6.903 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 18 | Batch: 002 / 025 | Total loss: 6.883 | Reg loss: 0.025 | Tree loss: 6.883 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 18 | Batch: 003 / 025 | Total loss: 6.857 | Reg loss: 0.025 | Tree loss: 6.857 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 18 | Batch: 004 / 025 | Total loss: 6.816 | Reg loss: 0.026 | Tree los

Epoch: 20 | Batch: 011 / 025 | Total loss: 6.296 | Reg loss: 0.027 | Tree loss: 6.296 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 012 / 025 | Total loss: 6.271 | Reg loss: 0.027 | Tree loss: 6.271 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 013 / 025 | Total loss: 6.249 | Reg loss: 0.027 | Tree loss: 6.249 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 014 / 025 | Total loss: 6.209 | Reg loss: 0.027 | Tree loss: 6.209 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 015 / 025 | Total loss: 6.193 | Reg loss: 0.027 | Tree loss: 6.193 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 016 / 025 | Total loss: 6.185 | Reg loss: 0.028 | Tree loss: 6.185 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 017 / 025 | Total loss: 6.140 | Reg loss: 0.028 | Tree loss: 6.140 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 20 | Batch: 018 / 025 | Total loss: 6.093 | Reg loss: 0.028 | Tree loss: 6.093 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 2

Epoch: 23 | Batch: 000 / 025 | Total loss: 6.078 | Reg loss: 0.027 | Tree loss: 6.078 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 001 / 025 | Total loss: 6.063 | Reg loss: 0.027 | Tree loss: 6.063 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 002 / 025 | Total loss: 6.042 | Reg loss: 0.027 | Tree loss: 6.042 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 003 / 025 | Total loss: 6.013 | Reg loss: 0.027 | Tree loss: 6.013 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 004 / 025 | Total loss: 5.976 | Reg loss: 0.027 | Tree loss: 5.976 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 005 / 025 | Total loss: 5.967 | Reg loss: 0.027 | Tree loss: 5.967 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 006 / 025 | Total loss: 5.935 | Reg loss: 0.027 | Tree loss: 5.935 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 23 | Batch: 007 / 025 | Total loss: 5.904 | Reg loss: 0.027 | Tree loss: 5.904 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 2

Epoch: 25 | Batch: 014 / 025 | Total loss: 5.418 | Reg loss: 0.028 | Tree loss: 5.418 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 015 / 025 | Total loss: 5.394 | Reg loss: 0.028 | Tree loss: 5.394 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 016 / 025 | Total loss: 5.383 | Reg loss: 0.028 | Tree loss: 5.383 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 017 / 025 | Total loss: 5.339 | Reg loss: 0.028 | Tree loss: 5.339 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 018 / 025 | Total loss: 5.309 | Reg loss: 0.028 | Tree loss: 5.309 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 019 / 025 | Total loss: 5.279 | Reg loss: 0.028 | Tree loss: 5.279 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 020 / 025 | Total loss: 5.269 | Reg loss: 0.028 | Tree loss: 5.269 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 25 | Batch: 021 / 025 | Total loss: 5.247 | Reg loss: 0.029 | Tree loss: 5.247 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 2

Epoch: 28 | Batch: 002 / 025 | Total loss: 5.240 | Reg loss: 0.027 | Tree loss: 5.240 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 003 / 025 | Total loss: 5.228 | Reg loss: 0.027 | Tree loss: 5.228 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 004 / 025 | Total loss: 5.187 | Reg loss: 0.027 | Tree loss: 5.187 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 005 / 025 | Total loss: 5.173 | Reg loss: 0.027 | Tree loss: 5.173 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 006 / 025 | Total loss: 5.150 | Reg loss: 0.027 | Tree loss: 5.150 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 007 / 025 | Total loss: 5.108 | Reg loss: 0.027 | Tree loss: 5.108 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 008 / 025 | Total loss: 5.099 | Reg loss: 0.027 | Tree loss: 5.099 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 28 | Batch: 009 / 025 | Total loss: 5.069 | Reg loss: 0.027 | Tree loss: 5.069 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 2

Epoch: 30 | Batch: 016 / 025 | Total loss: 4.611 | Reg loss: 0.028 | Tree loss: 4.611 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 017 / 025 | Total loss: 4.600 | Reg loss: 0.028 | Tree loss: 4.600 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 018 / 025 | Total loss: 4.585 | Reg loss: 0.028 | Tree loss: 4.585 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 019 / 025 | Total loss: 4.552 | Reg loss: 0.028 | Tree loss: 4.552 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 020 / 025 | Total loss: 4.537 | Reg loss: 0.029 | Tree loss: 4.537 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 021 / 025 | Total loss: 4.503 | Reg loss: 0.029 | Tree loss: 4.503 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 022 / 025 | Total loss: 4.485 | Reg loss: 0.029 | Tree loss: 4.485 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 30 | Batch: 023 / 025 | Total loss: 4.463 | Reg loss: 0.029 | Tree loss: 4.463 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 3

Epoch: 33 | Batch: 004 / 025 | Total loss: 4.462 | Reg loss: 0.027 | Tree loss: 4.462 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 005 / 025 | Total loss: 4.424 | Reg loss: 0.027 | Tree loss: 4.424 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 006 / 025 | Total loss: 4.405 | Reg loss: 0.028 | Tree loss: 4.405 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 007 / 025 | Total loss: 4.407 | Reg loss: 0.028 | Tree loss: 4.407 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 008 / 025 | Total loss: 4.366 | Reg loss: 0.028 | Tree loss: 4.366 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 009 / 025 | Total loss: 4.345 | Reg loss: 0.028 | Tree loss: 4.345 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 010 / 025 | Total loss: 4.330 | Reg loss: 0.028 | Tree loss: 4.330 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 33 | Batch: 011 / 025 | Total loss: 4.303 | Reg loss: 0.028 | Tree loss: 4.303 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 3

Epoch: 35 | Batch: 019 / 025 | Total loss: 3.835 | Reg loss: 0.029 | Tree loss: 3.835 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 35 | Batch: 020 / 025 | Total loss: 3.822 | Reg loss: 0.029 | Tree loss: 3.822 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 35 | Batch: 021 / 025 | Total loss: 3.804 | Reg loss: 0.029 | Tree loss: 3.804 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 35 | Batch: 022 / 025 | Total loss: 3.811 | Reg loss: 0.029 | Tree loss: 3.811 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 35 | Batch: 023 / 025 | Total loss: 3.778 | Reg loss: 0.029 | Tree loss: 3.778 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 35 | Batch: 024 / 025 | Total loss: 3.752 | Reg loss: 0.029 | Tree loss: 3.752 | Accuracy: 1.000000 | 0.067 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 36 | Batch: 000 / 025 | Total loss: 4.118 | Reg loss: 0.028 | Tree los

Epoch: 38 | Batch: 008 / 025 | Total loss: 3.679 | Reg loss: 0.028 | Tree loss: 3.679 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 009 / 025 | Total loss: 3.672 | Reg loss: 0.028 | Tree loss: 3.672 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 010 / 025 | Total loss: 3.642 | Reg loss: 0.028 | Tree loss: 3.642 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 011 / 025 | Total loss: 3.617 | Reg loss: 0.028 | Tree loss: 3.617 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 012 / 025 | Total loss: 3.599 | Reg loss: 0.028 | Tree loss: 3.599 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 013 / 025 | Total loss: 3.581 | Reg loss: 0.028 | Tree loss: 3.581 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 014 / 025 | Total loss: 3.547 | Reg loss: 0.028 | Tree loss: 3.547 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 38 | Batch: 015 / 025 | Total loss: 3.534 | Reg loss: 0.028 | Tree loss: 3.534 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 3

Epoch: 40 | Batch: 023 / 025 | Total loss: 3.109 | Reg loss: 0.029 | Tree loss: 3.109 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 40 | Batch: 024 / 025 | Total loss: 3.100 | Reg loss: 0.029 | Tree loss: 3.100 | Accuracy: 1.000000 | 0.067 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 41 | Batch: 000 / 025 | Total loss: 3.457 | Reg loss: 0.028 | Tree loss: 3.457 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 41 | Batch: 001 / 025 | Total loss: 3.427 | Reg loss: 0.028 | Tree loss: 3.427 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 41 | Batch: 002 / 025 | Total loss: 3.410 | Reg loss: 0.028 | Tree loss: 3.410 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 41 | Batch: 003 / 025 | Total loss: 3.396 | Reg loss: 0.028 | Tree loss: 3.396 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 41 | Batch: 004 / 025 | Total loss: 3.365 | Reg loss: 0.028 | Tree los

Epoch: 43 | Batch: 010 / 025 | Total loss: 3.004 | Reg loss: 0.028 | Tree loss: 3.004 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 011 / 025 | Total loss: 2.981 | Reg loss: 0.028 | Tree loss: 2.981 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 012 / 025 | Total loss: 2.942 | Reg loss: 0.028 | Tree loss: 2.942 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 013 / 025 | Total loss: 2.951 | Reg loss: 0.028 | Tree loss: 2.951 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 014 / 025 | Total loss: 2.924 | Reg loss: 0.028 | Tree loss: 2.924 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 015 / 025 | Total loss: 2.900 | Reg loss: 0.028 | Tree loss: 2.900 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 016 / 025 | Total loss: 2.873 | Reg loss: 0.029 | Tree loss: 2.873 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 43 | Batch: 017 / 025 | Total loss: 2.866 | Reg loss: 0.029 | Tree loss: 2.866 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 4

Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 46 | Batch: 000 / 025 | Total loss: 2.830 | Reg loss: 0.028 | Tree loss: 2.830 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 46 | Batch: 001 / 025 | Total loss: 2.813 | Reg loss: 0.028 | Tree loss: 2.813 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 46 | Batch: 002 / 025 | Total loss: 2.804 | Reg loss: 0.028 | Tree loss: 2.804 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 46 | Batch: 003 / 025 | Total loss: 2.776 | Reg loss: 0.028 | Tree loss: 2.776 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 46 | Batch: 004 / 025 | Total loss: 2.771 | Reg loss: 0.028 | Tree loss: 2.771 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 46 | Batch: 005 / 025 | Total loss: 2.722 | Reg loss: 0.028 | Tree loss: 2.722 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 46 | Batch: 006 / 025 | Total loss: 2.716 | Reg loss: 0.028 | Tree los

Epoch: 48 | Batch: 012 / 025 | Total loss: 2.380 | Reg loss: 0.028 | Tree loss: 2.380 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 013 / 025 | Total loss: 2.351 | Reg loss: 0.028 | Tree loss: 2.351 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 014 / 025 | Total loss: 2.348 | Reg loss: 0.028 | Tree loss: 2.348 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 015 / 025 | Total loss: 2.320 | Reg loss: 0.028 | Tree loss: 2.320 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 016 / 025 | Total loss: 2.314 | Reg loss: 0.029 | Tree loss: 2.314 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 017 / 025 | Total loss: 2.281 | Reg loss: 0.029 | Tree loss: 2.281 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 018 / 025 | Total loss: 2.271 | Reg loss: 0.029 | Tree loss: 2.271 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 48 | Batch: 019 / 025 | Total loss: 2.241 | Reg loss: 0.029 | Tree loss: 2.241 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 4

Epoch: 51 | Batch: 000 / 025 | Total loss: 2.255 | Reg loss: 0.028 | Tree loss: 2.255 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 001 / 025 | Total loss: 2.241 | Reg loss: 0.028 | Tree loss: 2.241 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 002 / 025 | Total loss: 2.217 | Reg loss: 0.028 | Tree loss: 2.217 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 003 / 025 | Total loss: 2.212 | Reg loss: 0.028 | Tree loss: 2.212 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 004 / 025 | Total loss: 2.200 | Reg loss: 0.028 | Tree loss: 2.200 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 005 / 025 | Total loss: 2.198 | Reg loss: 0.028 | Tree loss: 2.198 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 006 / 025 | Total loss: 2.178 | Reg loss: 0.028 | Tree loss: 2.178 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 51 | Batch: 007 / 025 | Total loss: 2.138 | Reg loss: 0.028 | Tree loss: 2.138 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 5

Epoch: 53 | Batch: 016 / 025 | Total loss: 1.803 | Reg loss: 0.028 | Tree loss: 1.803 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 017 / 025 | Total loss: 1.800 | Reg loss: 0.028 | Tree loss: 1.800 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 018 / 025 | Total loss: 1.776 | Reg loss: 0.028 | Tree loss: 1.776 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 019 / 025 | Total loss: 1.745 | Reg loss: 0.028 | Tree loss: 1.745 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 020 / 025 | Total loss: 1.754 | Reg loss: 0.029 | Tree loss: 1.754 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 021 / 025 | Total loss: 1.740 | Reg loss: 0.029 | Tree loss: 1.740 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 022 / 025 | Total loss: 1.721 | Reg loss: 0.029 | Tree loss: 1.721 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 53 | Batch: 023 / 025 | Total loss: 1.701 | Reg loss: 0.029 | Tree loss: 1.701 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 5

Epoch: 56 | Batch: 003 / 025 | Total loss: 1.738 | Reg loss: 0.028 | Tree loss: 1.738 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 004 / 025 | Total loss: 1.716 | Reg loss: 0.028 | Tree loss: 1.716 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 005 / 025 | Total loss: 1.696 | Reg loss: 0.028 | Tree loss: 1.696 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 006 / 025 | Total loss: 1.686 | Reg loss: 0.028 | Tree loss: 1.686 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 007 / 025 | Total loss: 1.673 | Reg loss: 0.028 | Tree loss: 1.673 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 008 / 025 | Total loss: 1.660 | Reg loss: 0.028 | Tree loss: 1.660 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 009 / 025 | Total loss: 1.655 | Reg loss: 0.028 | Tree loss: 1.655 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 56 | Batch: 010 / 025 | Total loss: 1.648 | Reg loss: 0.028 | Tree loss: 1.648 | Accuracy: 1.000000 | 0.067 sec/iter
Epoch: 5

Epoch: 58 | Batch: 019 / 025 | Total loss: 1.352 | Reg loss: 0.028 | Tree loss: 1.352 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 58 | Batch: 020 / 025 | Total loss: 1.349 | Reg loss: 0.028 | Tree loss: 1.349 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 58 | Batch: 021 / 025 | Total loss: 1.332 | Reg loss: 0.028 | Tree loss: 1.332 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 58 | Batch: 022 / 025 | Total loss: 1.323 | Reg loss: 0.028 | Tree loss: 1.323 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 58 | Batch: 023 / 025 | Total loss: 1.317 | Reg loss: 0.028 | Tree loss: 1.317 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 58 | Batch: 024 / 025 | Total loss: 1.299 | Reg loss: 0.028 | Tree loss: 1.299 | Accuracy: 1.000000 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 59 | Batch: 000 / 025 | Total loss: 1.539 | Reg loss: 0.027 | Tree los

Epoch: 61 | Batch: 006 / 025 | Total loss: 1.312 | Reg loss: 0.027 | Tree loss: 1.312 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 007 / 025 | Total loss: 1.316 | Reg loss: 0.027 | Tree loss: 1.316 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 008 / 025 | Total loss: 1.286 | Reg loss: 0.027 | Tree loss: 1.286 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 009 / 025 | Total loss: 1.271 | Reg loss: 0.027 | Tree loss: 1.271 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 010 / 025 | Total loss: 1.264 | Reg loss: 0.027 | Tree loss: 1.264 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 011 / 025 | Total loss: 1.255 | Reg loss: 0.027 | Tree loss: 1.255 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 012 / 025 | Total loss: 1.238 | Reg loss: 0.027 | Tree loss: 1.238 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 61 | Batch: 013 / 025 | Total loss: 1.226 | Reg loss: 0.027 | Tree loss: 1.226 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 6

Epoch: 63 | Batch: 021 / 025 | Total loss: 1.032 | Reg loss: 0.027 | Tree loss: 1.032 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 63 | Batch: 022 / 025 | Total loss: 1.017 | Reg loss: 0.027 | Tree loss: 1.017 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 63 | Batch: 023 / 025 | Total loss: 1.012 | Reg loss: 0.027 | Tree loss: 1.012 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 63 | Batch: 024 / 025 | Total loss: 1.000 | Reg loss: 0.027 | Tree loss: 1.000 | Accuracy: 1.000000 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 64 | Batch: 000 / 025 | Total loss: 1.175 | Reg loss: 0.026 | Tree loss: 1.175 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 64 | Batch: 001 / 025 | Total loss: 1.159 | Reg loss: 0.026 | Tree loss: 1.159 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 64 | Batch: 002 / 025 | Total loss: 1.167 | Reg loss: 0.026 | Tree los

Epoch: 66 | Batch: 008 / 025 | Total loss: 0.999 | Reg loss: 0.026 | Tree loss: 0.999 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 009 / 025 | Total loss: 0.989 | Reg loss: 0.026 | Tree loss: 0.989 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 010 / 025 | Total loss: 0.972 | Reg loss: 0.026 | Tree loss: 0.972 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 011 / 025 | Total loss: 0.977 | Reg loss: 0.026 | Tree loss: 0.977 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 012 / 025 | Total loss: 0.959 | Reg loss: 0.026 | Tree loss: 0.959 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 013 / 025 | Total loss: 0.950 | Reg loss: 0.026 | Tree loss: 0.950 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 014 / 025 | Total loss: 0.947 | Reg loss: 0.026 | Tree loss: 0.947 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 66 | Batch: 015 / 025 | Total loss: 0.928 | Reg loss: 0.026 | Tree loss: 0.928 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 6

Epoch: 68 | Batch: 022 / 025 | Total loss: 0.784 | Reg loss: 0.026 | Tree loss: 0.784 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 68 | Batch: 023 / 025 | Total loss: 0.785 | Reg loss: 0.026 | Tree loss: 0.785 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 68 | Batch: 024 / 025 | Total loss: 0.794 | Reg loss: 0.026 | Tree loss: 0.794 | Accuracy: 1.000000 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 69 | Batch: 000 / 025 | Total loss: 0.921 | Reg loss: 0.026 | Tree loss: 0.921 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 69 | Batch: 001 / 025 | Total loss: 0.920 | Reg loss: 0.026 | Tree loss: 0.920 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 69 | Batch: 002 / 025 | Total loss: 0.906 | Reg loss: 0.026 | Tree loss: 0.906 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 69 | Batch: 003 / 025 | Total loss: 0.891 | Reg loss: 0.026 | Tree los

Epoch: 71 | Batch: 009 / 025 | Total loss: 0.765 | Reg loss: 0.026 | Tree loss: 0.765 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 010 / 025 | Total loss: 0.760 | Reg loss: 0.026 | Tree loss: 0.760 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 011 / 025 | Total loss: 0.763 | Reg loss: 0.026 | Tree loss: 0.763 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 012 / 025 | Total loss: 0.748 | Reg loss: 0.026 | Tree loss: 0.748 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 013 / 025 | Total loss: 0.740 | Reg loss: 0.026 | Tree loss: 0.740 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 014 / 025 | Total loss: 0.738 | Reg loss: 0.026 | Tree loss: 0.738 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 015 / 025 | Total loss: 0.724 | Reg loss: 0.026 | Tree loss: 0.724 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 71 | Batch: 016 / 025 | Total loss: 0.715 | Reg loss: 0.026 | Tree loss: 0.715 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 7

Epoch: 73 | Batch: 024 / 025 | Total loss: 0.608 | Reg loss: 0.026 | Tree loss: 0.608 | Accuracy: 1.000000 | 0.069 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 74 | Batch: 000 / 025 | Total loss: 0.723 | Reg loss: 0.025 | Tree loss: 0.723 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 74 | Batch: 001 / 025 | Total loss: 0.727 | Reg loss: 0.025 | Tree loss: 0.727 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 74 | Batch: 002 / 025 | Total loss: 0.710 | Reg loss: 0.025 | Tree loss: 0.710 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 74 | Batch: 003 / 025 | Total loss: 0.707 | Reg loss: 0.025 | Tree loss: 0.707 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 74 | Batch: 004 / 025 | Total loss: 0.701 | Reg loss: 0.025 | Tree loss: 0.701 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 74 | Batch: 005 / 025 | Total loss: 0.699 | Reg loss: 0.025 | Tree los

Epoch: 76 | Batch: 013 / 025 | Total loss: 0.600 | Reg loss: 0.025 | Tree loss: 0.600 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 014 / 025 | Total loss: 0.586 | Reg loss: 0.025 | Tree loss: 0.586 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 015 / 025 | Total loss: 0.575 | Reg loss: 0.025 | Tree loss: 0.575 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 016 / 025 | Total loss: 0.576 | Reg loss: 0.025 | Tree loss: 0.576 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 017 / 025 | Total loss: 0.574 | Reg loss: 0.025 | Tree loss: 0.574 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 018 / 025 | Total loss: 0.568 | Reg loss: 0.025 | Tree loss: 0.568 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 019 / 025 | Total loss: 0.559 | Reg loss: 0.025 | Tree loss: 0.559 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 76 | Batch: 020 / 025 | Total loss: 0.560 | Reg loss: 0.025 | Tree loss: 0.560 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 7

Epoch: 79 | Batch: 002 / 025 | Total loss: 0.582 | Reg loss: 0.025 | Tree loss: 0.582 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 003 / 025 | Total loss: 0.570 | Reg loss: 0.025 | Tree loss: 0.570 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 004 / 025 | Total loss: 0.559 | Reg loss: 0.025 | Tree loss: 0.559 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 005 / 025 | Total loss: 0.570 | Reg loss: 0.025 | Tree loss: 0.570 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 006 / 025 | Total loss: 0.557 | Reg loss: 0.025 | Tree loss: 0.557 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 007 / 025 | Total loss: 0.548 | Reg loss: 0.025 | Tree loss: 0.548 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 008 / 025 | Total loss: 0.548 | Reg loss: 0.025 | Tree loss: 0.548 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 79 | Batch: 009 / 025 | Total loss: 0.542 | Reg loss: 0.025 | Tree loss: 0.542 | Accuracy: 1.000000 | 0.069 sec/iter
Epoch: 7

Epoch: 81 | Batch: 017 / 025 | Total loss: 0.463 | Reg loss: 0.025 | Tree loss: 0.463 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 018 / 025 | Total loss: 0.465 | Reg loss: 0.025 | Tree loss: 0.465 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 019 / 025 | Total loss: 0.463 | Reg loss: 0.025 | Tree loss: 0.463 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 020 / 025 | Total loss: 0.453 | Reg loss: 0.025 | Tree loss: 0.453 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 021 / 025 | Total loss: 0.454 | Reg loss: 0.025 | Tree loss: 0.454 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 022 / 025 | Total loss: 0.453 | Reg loss: 0.025 | Tree loss: 0.453 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 023 / 025 | Total loss: 0.444 | Reg loss: 0.025 | Tree loss: 0.444 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 81 | Batch: 024 / 025 | Total loss: 0.441 | Reg loss: 0.025 | Tree loss: 0.441 | Accuracy: 1.000000 | 0.068 sec/iter
Average 

Epoch: 84 | Batch: 006 / 025 | Total loss: 0.457 | Reg loss: 0.024 | Tree loss: 0.457 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 007 / 025 | Total loss: 0.457 | Reg loss: 0.024 | Tree loss: 0.457 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 008 / 025 | Total loss: 0.451 | Reg loss: 0.024 | Tree loss: 0.451 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 009 / 025 | Total loss: 0.451 | Reg loss: 0.024 | Tree loss: 0.451 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 010 / 025 | Total loss: 0.444 | Reg loss: 0.024 | Tree loss: 0.444 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 011 / 025 | Total loss: 0.440 | Reg loss: 0.024 | Tree loss: 0.440 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 012 / 025 | Total loss: 0.439 | Reg loss: 0.024 | Tree loss: 0.439 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 84 | Batch: 013 / 025 | Total loss: 0.436 | Reg loss: 0.024 | Tree loss: 0.436 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 8

Epoch: 86 | Batch: 020 / 025 | Total loss: 0.380 | Reg loss: 0.024 | Tree loss: 0.380 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 86 | Batch: 021 / 025 | Total loss: 0.381 | Reg loss: 0.024 | Tree loss: 0.381 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 86 | Batch: 022 / 025 | Total loss: 0.374 | Reg loss: 0.024 | Tree loss: 0.374 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 86 | Batch: 023 / 025 | Total loss: 0.374 | Reg loss: 0.024 | Tree loss: 0.374 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 86 | Batch: 024 / 025 | Total loss: 0.372 | Reg loss: 0.024 | Tree loss: 0.372 | Accuracy: 1.000000 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 87 | Batch: 000 / 025 | Total loss: 0.434 | Reg loss: 0.024 | Tree loss: 0.434 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 87 | Batch: 001 / 025 | Total loss: 0.433 | Reg loss: 0.024 | Tree los

Epoch: 89 | Batch: 008 / 025 | Total loss: 0.389 | Reg loss: 0.024 | Tree loss: 0.389 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 009 / 025 | Total loss: 0.384 | Reg loss: 0.024 | Tree loss: 0.384 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 010 / 025 | Total loss: 0.372 | Reg loss: 0.024 | Tree loss: 0.372 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 011 / 025 | Total loss: 0.371 | Reg loss: 0.024 | Tree loss: 0.371 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 012 / 025 | Total loss: 0.368 | Reg loss: 0.024 | Tree loss: 0.368 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 013 / 025 | Total loss: 0.363 | Reg loss: 0.024 | Tree loss: 0.363 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 014 / 025 | Total loss: 0.369 | Reg loss: 0.024 | Tree loss: 0.369 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 89 | Batch: 015 / 025 | Total loss: 0.360 | Reg loss: 0.024 | Tree loss: 0.360 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 8

Epoch: 91 | Batch: 023 / 025 | Total loss: 0.317 | Reg loss: 0.024 | Tree loss: 0.317 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 91 | Batch: 024 / 025 | Total loss: 0.310 | Reg loss: 0.024 | Tree loss: 0.310 | Accuracy: 1.000000 | 0.068 sec/iter
Average sparseness: 0.9821428571428567
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
Epoch: 92 | Batch: 000 / 025 | Total loss: 0.377 | Reg loss: 0.023 | Tree loss: 0.377 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 92 | Batch: 001 / 025 | Total loss: 0.373 | Reg loss: 0.023 | Tree loss: 0.373 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 92 | Batch: 002 / 025 | Total loss: 0.370 | Reg loss: 0.023 | Tree loss: 0.370 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 92 | Batch: 003 / 025 | Total loss: 0.361 | Reg loss: 0.023 | Tree loss: 0.361 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 92 | Batch: 004 / 025 | Total loss: 0.361 | Reg loss: 0.023 | Tree los

Epoch: 94 | Batch: 012 / 025 | Total loss: 0.319 | Reg loss: 0.023 | Tree loss: 0.319 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 013 / 025 | Total loss: 0.320 | Reg loss: 0.023 | Tree loss: 0.320 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 014 / 025 | Total loss: 0.318 | Reg loss: 0.023 | Tree loss: 0.318 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 015 / 025 | Total loss: 0.310 | Reg loss: 0.024 | Tree loss: 0.310 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 016 / 025 | Total loss: 0.313 | Reg loss: 0.024 | Tree loss: 0.313 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 017 / 025 | Total loss: 0.306 | Reg loss: 0.024 | Tree loss: 0.306 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 018 / 025 | Total loss: 0.305 | Reg loss: 0.024 | Tree loss: 0.305 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 94 | Batch: 019 / 025 | Total loss: 0.303 | Reg loss: 0.024 | Tree loss: 0.303 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 9

Epoch: 97 | Batch: 001 / 025 | Total loss: 0.325 | Reg loss: 0.023 | Tree loss: 0.325 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 002 / 025 | Total loss: 0.322 | Reg loss: 0.023 | Tree loss: 0.322 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 003 / 025 | Total loss: 0.321 | Reg loss: 0.023 | Tree loss: 0.321 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 004 / 025 | Total loss: 0.316 | Reg loss: 0.023 | Tree loss: 0.316 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 005 / 025 | Total loss: 0.315 | Reg loss: 0.023 | Tree loss: 0.315 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 006 / 025 | Total loss: 0.313 | Reg loss: 0.023 | Tree loss: 0.313 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 007 / 025 | Total loss: 0.313 | Reg loss: 0.023 | Tree loss: 0.313 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 97 | Batch: 008 / 025 | Total loss: 0.306 | Reg loss: 0.023 | Tree loss: 0.306 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 9

Epoch: 99 | Batch: 016 / 025 | Total loss: 0.276 | Reg loss: 0.023 | Tree loss: 0.276 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 017 / 025 | Total loss: 0.272 | Reg loss: 0.023 | Tree loss: 0.272 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 018 / 025 | Total loss: 0.271 | Reg loss: 0.023 | Tree loss: 0.271 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 019 / 025 | Total loss: 0.267 | Reg loss: 0.023 | Tree loss: 0.267 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 020 / 025 | Total loss: 0.264 | Reg loss: 0.023 | Tree loss: 0.264 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 021 / 025 | Total loss: 0.261 | Reg loss: 0.023 | Tree loss: 0.261 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 022 / 025 | Total loss: 0.263 | Reg loss: 0.023 | Tree loss: 0.263 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 99 | Batch: 023 / 025 | Total loss: 0.262 | Reg loss: 0.023 | Tree loss: 0.262 | Accuracy: 1.000000 | 0.068 sec/iter
Epoch: 9

In [125]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

  """Entry point for launching an IPython kernel.


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [126]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [127]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 0.0


# Extract Rules

# Accumulate samples in the leaves

In [128]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 1


In [129]:
method = 'greedy'

In [130]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [131]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    for cond in conds:
        cond.weights = cond.weights / normalizers
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

12584
Average comprehensibility: 0.0
std comprehensibility: 0.0
var comprehensibility: 0.0
minimum comprehensibility: 0
maximum comprehensibility: 0
