In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 32
tree_depth = 10
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.133718490600586 | KNN Loss: 6.226861000061035 | BCE Loss: 1.9068574905395508
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.137260437011719 | KNN Loss: 6.226616859436035 | BCE Loss: 1.9106433391571045
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.188157081604004 | KNN Loss: 6.226138114929199 | BCE Loss: 1.9620192050933838
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.124198913574219 | KNN Loss: 6.226280212402344 | BCE Loss: 1.8979182243347168
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.152647018432617 | KNN Loss: 6.226075172424316 | BCE Loss: 1.9265716075897217
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.139322280883789 | KNN Loss: 6.2255659103393555 | BCE Loss: 1.913756012916565
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.129941940307617 | KNN Loss: 6.225582599639893 | BCE Loss: 1.9043588638305664
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.077471733093262 | KNN Loss: 6.2255635261535645 | BCE Loss: 1.8519

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 6.2970476150512695 | KNN Loss: 5.1975531578063965 | BCE Loss: 1.0994946956634521
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 6.278271198272705 | KNN Loss: 5.143256664276123 | BCE Loss: 1.135014533996582
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 6.123623847961426 | KNN Loss: 5.0088276863098145 | BCE Loss: 1.1147961616516113
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 6.025326251983643 | KNN Loss: 4.92909574508667 | BCE Loss: 1.0962305068969727
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.927366256713867 | KNN Loss: 4.832589149475098 | BCE Loss: 1.0947771072387695
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.831093788146973 | KNN Loss: 4.722562313079834 | BCE Loss: 1.1085315942764282
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.719992637634277 | KNN Loss: 4.626026630401611 | BCE Loss: 1.0939661264419556
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.601950168609619 | KNN Loss: 4.527592182159424 | BCE Los

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.895739555358887 | KNN Loss: 3.8490068912506104 | BCE Loss: 1.046732783317566
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.902154922485352 | KNN Loss: 3.8195180892944336 | BCE Loss: 1.082637071609497
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.926777362823486 | KNN Loss: 3.8551180362701416 | BCE Loss: 1.0716592073440552
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.889044284820557 | KNN Loss: 3.8556411266326904 | BCE Loss: 1.0334031581878662
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.888866901397705 | KNN Loss: 3.8489649295806885 | BCE Loss: 1.0399020910263062
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.8839311599731445 | KNN Loss: 3.8232359886169434 | BCE Loss: 1.0606951713562012
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.8758344650268555 | KNN Loss: 3.8018832206726074 | BCE Loss: 1.0739511251449585
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.916973114013672 | KNN Loss: 3.8316996097564697 |

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.830670356750488 | KNN Loss: 3.7763350009918213 | BCE Loss: 1.0543352365493774
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.85447883605957 | KNN Loss: 3.7854175567626953 | BCE Loss: 1.069061279296875
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.841700077056885 | KNN Loss: 3.8065178394317627 | BCE Loss: 1.035182237625122
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.8088531494140625 | KNN Loss: 3.7778830528259277 | BCE Loss: 1.0309698581695557
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.791323184967041 | KNN Loss: 3.7804763317108154 | BCE Loss: 1.0108469724655151
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.834181785583496 | KNN Loss: 3.7814431190490723 | BCE Loss: 1.0527385473251343
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.873087406158447 | KNN Loss: 3.8162765502929688 | BCE Loss: 1.056810736656189
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.836986541748047 | KNN Loss: 3.7850193977355957 | BC

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.8183817863464355 | KNN Loss: 3.7670352458953857 | BCE Loss: 1.0513464212417603
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.799330711364746 | KNN Loss: 3.7650609016418457 | BCE Loss: 1.0342698097229004
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.820438861846924 | KNN Loss: 3.7949063777923584 | BCE Loss: 1.0255324840545654
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.801769733428955 | KNN Loss: 3.762892723083496 | BCE Loss: 1.038877010345459
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.841640472412109 | KNN Loss: 3.7723915576934814 | BCE Loss: 1.0692486763000488
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.806844711303711 | KNN Loss: 3.7636656761169434 | BCE Loss: 1.0431790351867676
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.792999267578125 | KNN Loss: 3.770460605621338 | BCE Loss: 1.0225389003753662
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 4.766674041748047 | KNN Loss: 3.7495651245117188 | B

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.746033668518066 | KNN Loss: 3.7426271438598633 | BCE Loss: 1.0034065246582031
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.759019374847412 | KNN Loss: 3.7413265705108643 | BCE Loss: 1.0176926851272583
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.778705596923828 | KNN Loss: 3.761247396469116 | BCE Loss: 1.017458438873291
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.833542823791504 | KNN Loss: 3.795400857925415 | BCE Loss: 1.0381418466567993
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.789176940917969 | KNN Loss: 3.7545673847198486 | BCE Loss: 1.0346095561981201
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.811786651611328 | KNN Loss: 3.773365020751953 | BCE Loss: 1.038421630859375
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.768449306488037 | KNN Loss: 3.762960195541382 | BCE Loss: 1.0054891109466553
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 4.768016815185547 | KNN Loss: 3.727332830429077 | BCE Los

Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 4.804323673248291 | KNN Loss: 3.7712912559509277 | BCE Loss: 1.0330322980880737
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.743203163146973 | KNN Loss: 3.724619150161743 | BCE Loss: 1.0185837745666504
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.767424583435059 | KNN Loss: 3.7458789348602295 | BCE Loss: 1.021545648574829
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.785834312438965 | KNN Loss: 3.766901731491089 | BCE Loss: 1.018932580947876
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.73980188369751 | KNN Loss: 3.7337684631347656 | BCE Loss: 1.0060333013534546
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.798725605010986 | KNN Loss: 3.7558059692382812 | BCE Loss: 1.042919635772705
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.806766510009766 | KNN Loss: 3.750889301300049 | BCE Loss: 1.0558772087097168
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 4.716067790985107 | KNN Loss: 3.7115399837493896 | BCE Los

Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 4.778682231903076 | KNN Loss: 3.7395260334014893 | BCE Loss: 1.0391563177108765
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.757467746734619 | KNN Loss: 3.7363107204437256 | BCE Loss: 1.021157145500183
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.747154235839844 | KNN Loss: 3.7119104862213135 | BCE Loss: 1.0352437496185303
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.7187180519104 | KNN Loss: 3.7200770378112793 | BCE Loss: 0.9986411929130554
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.768254280090332 | KNN Loss: 3.735257148742676 | BCE Loss: 1.0329973697662354
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.743958950042725 | KNN Loss: 3.7258899211883545 | BCE Loss: 1.0180689096450806
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.759041786193848 | KNN Loss: 3.7442972660064697 | BCE Loss: 1.0147442817687988
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 4.7481584548950195 | KNN Loss: 3.71254563331604 | BCE 

Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 4.732298374176025 | KNN Loss: 3.727149248123169 | BCE Loss: 1.0051491260528564
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 4.73263692855835 | KNN Loss: 3.7160277366638184 | BCE Loss: 1.0166091918945312
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.771571636199951 | KNN Loss: 3.7526919841766357 | BCE Loss: 1.0188796520233154
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.807049751281738 | KNN Loss: 3.767599105834961 | BCE Loss: 1.039450764656067
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.786619186401367 | KNN Loss: 3.753934860229492 | BCE Loss: 1.032684564590454
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.76007080078125 | KNN Loss: 3.7363476753234863 | BCE Loss: 1.0237228870391846
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.695275783538818 | KNN Loss: 3.698026180267334 | BCE Loss: 0.9972495436668396
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 4.7546281814575195 | KNN Loss: 3.723904848098755 | BCE Los

Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 4.744203567504883 | KNN Loss: 3.728470802307129 | BCE Loss: 1.0157325267791748
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.791128158569336 | KNN Loss: 3.7581899166107178 | BCE Loss: 1.0329383611679077
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.785099506378174 | KNN Loss: 3.7327120304107666 | BCE Loss: 1.0523874759674072
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.823698043823242 | KNN Loss: 3.756364107131958 | BCE Loss: 1.067333698272705
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.748905181884766 | KNN Loss: 3.7328269481658936 | BCE Loss: 1.0160784721374512
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.755220413208008 | KNN Loss: 3.7157833576202393 | BCE Loss: 1.039436936378479
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.740978717803955 | KNN Loss: 3.7173399925231934 | BCE Loss: 1.0236387252807617
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 4.73416805267334 | KNN Loss: 3.7425291538238525 | BCE L

Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 4.766383171081543 | KNN Loss: 3.721853256225586 | BCE Loss: 1.0445301532745361
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.790277481079102 | KNN Loss: 3.7404749393463135 | BCE Loss: 1.049802541732788
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.7837700843811035 | KNN Loss: 3.7410778999328613 | BCE Loss: 1.0426921844482422
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.756117820739746 | KNN Loss: 3.7103755474090576 | BCE Loss: 1.045742392539978
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.707571983337402 | KNN Loss: 3.7029592990875244 | BCE Loss: 1.0046124458312988
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.696730136871338 | KNN Loss: 3.7123231887817383 | BCE Loss: 0.9844069480895996
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 4.7585577964782715 | KNN Loss: 3.6969218254089355 | BCE Loss: 1.0616360902786255
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 4.741058826446533 | KNN Loss: 3.738237857818

Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 4.726959705352783 | KNN Loss: 3.7319440841674805 | BCE Loss: 0.9950157999992371
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.769983291625977 | KNN Loss: 3.7198033332824707 | BCE Loss: 1.050180196762085
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.776616096496582 | KNN Loss: 3.710921049118042 | BCE Loss: 1.0656951665878296
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.756170749664307 | KNN Loss: 3.717165231704712 | BCE Loss: 1.0390053987503052
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.712133884429932 | KNN Loss: 3.7022809982299805 | BCE Loss: 1.0098527669906616
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.762120246887207 | KNN Loss: 3.7132728099823 | BCE Loss: 1.0488471984863281
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.71661376953125 | KNN Loss: 3.6954264640808105 | BCE Loss: 1.0211875438690186
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 4.768210411071777 | KNN Loss: 3.756453037261963 |

Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 4.81531286239624 | KNN Loss: 3.7517008781433105 | BCE Loss: 1.0636119842529297
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.734622001647949 | KNN Loss: 3.7043774127960205 | BCE Loss: 1.0302447080612183
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.729386329650879 | KNN Loss: 3.694829225540161 | BCE Loss: 1.0345571041107178
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.717040538787842 | KNN Loss: 3.720607042312622 | BCE Loss: 0.9964333772659302
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.73325252532959 | KNN Loss: 3.745354652404785 | BCE Loss: 0.9878976345062256
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.732753753662109 | KNN Loss: 3.714066743850708 | BCE Loss: 1.0186870098114014
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.747666835784912 | KNN Loss: 3.690275192260742 | BCE Loss: 1.05739164352417
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 4.709174156188965 | KNN Loss: 3.696916103363037 | BC

Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 4.719339370727539 | KNN Loss: 3.6985154151916504 | BCE Loss: 1.0208237171173096
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.711606502532959 | KNN Loss: 3.7011332511901855 | BCE Loss: 1.0104732513427734
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.741427421569824 | KNN Loss: 3.7226762771606445 | BCE Loss: 1.0187511444091797
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.689827919006348 | KNN Loss: 3.6908771991729736 | BCE Loss: 0.9989506602287292
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.698273658752441 | KNN Loss: 3.6775872707366943 | BCE Loss: 1.0206865072250366
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.690425872802734 | KNN Loss: 3.721332311630249 | BCE Loss: 0.9690937995910645
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.759272575378418 | KNN Loss: 3.7386832237243652 | BCE Loss: 1.0205894708633423
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 4.714505195617676 | KNN Loss: 3.697933912277

Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 4.756929397583008 | KNN Loss: 3.7393643856048584 | BCE Loss: 1.0175647735595703
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 4.754549980163574 | KNN Loss: 3.719489097595215 | BCE Loss: 1.0350606441497803
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.734681129455566 | KNN Loss: 3.716872453689575 | BCE Loss: 1.017808437347412
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.71866512298584 | KNN Loss: 3.683283567428589 | BCE Loss: 1.0353814363479614
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.7091569900512695 | KNN Loss: 3.7152724266052246 | BCE Loss: 0.9938847422599792
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.7798943519592285 | KNN Loss: 3.7471213340759277 | BCE Loss: 1.0327730178833008
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.720847129821777 | KNN Loss: 3.71753191947937 | BCE Loss: 1.0033149719238281
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.769082069396973 | KNN Loss: 3.740628719329834

Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 4.753200531005859 | KNN Loss: 3.727186918258667 | BCE Loss: 1.0260136127471924
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 4.717901706695557 | KNN Loss: 3.688567876815796 | BCE Loss: 1.0293339490890503
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.754419326782227 | KNN Loss: 3.728102445602417 | BCE Loss: 1.02631676197052
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.766294002532959 | KNN Loss: 3.7201483249664307 | BCE Loss: 1.0461457967758179
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.765169143676758 | KNN Loss: 3.7180089950561523 | BCE Loss: 1.047160267829895
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.73515510559082 | KNN Loss: 3.7042880058288574 | BCE Loss: 1.0308669805526733
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.720388889312744 | KNN Loss: 3.6718897819519043 | BCE Loss: 1.0484991073608398
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 4.717018127441406 | KNN Loss: 3.710132122039795 | B

Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 4.780730724334717 | KNN Loss: 3.729328155517578 | BCE Loss: 1.0514025688171387
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 4.683350086212158 | KNN Loss: 3.679535388946533 | BCE Loss: 1.0038148164749146
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.708678245544434 | KNN Loss: 3.717698335647583 | BCE Loss: 0.9909801483154297
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.731808662414551 | KNN Loss: 3.6848747730255127 | BCE Loss: 1.0469340085983276
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.716100215911865 | KNN Loss: 3.6944806575775146 | BCE Loss: 1.0216196775436401
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.724859237670898 | KNN Loss: 3.6820766925811768 | BCE Loss: 1.0427826642990112
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.738103866577148 | KNN Loss: 3.7123820781707764 | BCE Loss: 1.025721549987793
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 4.715538501739502 | KNN Loss: 3.71619677543640

Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 4.716207027435303 | KNN Loss: 3.7061922550201416 | BCE Loss: 1.0100146532058716
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 4.746191024780273 | KNN Loss: 3.7029364109039307 | BCE Loss: 1.0432543754577637
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.789287567138672 | KNN Loss: 3.7146055698394775 | BCE Loss: 1.0746822357177734
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.7031474113464355 | KNN Loss: 3.6847736835479736 | BCE Loss: 1.0183738470077515
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.7702484130859375 | KNN Loss: 3.7433109283447266 | BCE Loss: 1.026937484741211
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.699399471282959 | KNN Loss: 3.6949753761291504 | BCE Loss: 1.0044242143630981
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 4.704554080963135 | KNN Loss: 3.7164793014526367 | BCE Loss: 0.9880748987197876
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 4.745429515838623 | KNN Loss: 3.717572927

Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 4.713300704956055 | KNN Loss: 3.688751697540283 | BCE Loss: 1.0245490074157715
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 4.723952293395996 | KNN Loss: 3.714909076690674 | BCE Loss: 1.0090430974960327
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 4.708863258361816 | KNN Loss: 3.7277944087982178 | BCE Loss: 0.981069028377533
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.756234645843506 | KNN Loss: 3.720839262008667 | BCE Loss: 1.0353952646255493
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.751707077026367 | KNN Loss: 3.7351186275482178 | BCE Loss: 1.0165886878967285
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.728169918060303 | KNN Loss: 3.7104337215423584 | BCE Loss: 1.0177360773086548
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.722468376159668 | KNN Loss: 3.699815034866333 | BCE Loss: 1.022653341293335
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 4.726489543914795 | KNN Loss: 3.7234554290771484 

Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 4.708405017852783 | KNN Loss: 3.7081358432769775 | BCE Loss: 1.0002692937850952
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 4.768553733825684 | KNN Loss: 3.729435682296753 | BCE Loss: 1.0391178131103516
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 4.720483779907227 | KNN Loss: 3.7037880420684814 | BCE Loss: 1.016695499420166
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.730031967163086 | KNN Loss: 3.7115371227264404 | BCE Loss: 1.018494963645935
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.7294230461120605 | KNN Loss: 3.693521738052368 | BCE Loss: 1.0359013080596924
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.683966159820557 | KNN Loss: 3.6786420345306396 | BCE Loss: 1.005324125289917
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.739405632019043 | KNN Loss: 3.7060437202453613 | BCE Loss: 1.0333621501922607
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 4.713843822479248 | KNN Loss: 3.70433449745178

Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 4.764874458312988 | KNN Loss: 3.7142367362976074 | BCE Loss: 1.0506376028060913
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 4.754358768463135 | KNN Loss: 3.727163791656494 | BCE Loss: 1.0271949768066406
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.7194318771362305 | KNN Loss: 3.703477382659912 | BCE Loss: 1.015954613685608
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.709132671356201 | KNN Loss: 3.6900923252105713 | BCE Loss: 1.0190404653549194
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.708712577819824 | KNN Loss: 3.6976027488708496 | BCE Loss: 1.0111100673675537
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.748821258544922 | KNN Loss: 3.741931200027466 | BCE Loss: 1.0068902969360352
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 4.7399187088012695 | KNN Loss: 3.728058099746704 | BCE Loss: 1.0118608474731445
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 4.717723846435547 | KNN Loss: 3.70345711708068

Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 4.706443786621094 | KNN Loss: 3.6974732875823975 | BCE Loss: 1.0089704990386963
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 4.730773448944092 | KNN Loss: 3.704136848449707 | BCE Loss: 1.0266367197036743
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.755996227264404 | KNN Loss: 3.7244274616241455 | BCE Loss: 1.0315686464309692
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.7377753257751465 | KNN Loss: 3.7246792316436768 | BCE Loss: 1.0130960941314697
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.77932596206665 | KNN Loss: 3.7532176971435547 | BCE Loss: 1.0261081457138062
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.741094589233398 | KNN Loss: 3.700617551803589 | BCE Loss: 1.0404767990112305
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 4.732032299041748 | KNN Loss: 3.705249071121216 | BCE Loss: 1.0267832279205322
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 4.762996673583984 | KNN Loss: 3.73645377159118

Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 4.765722274780273 | KNN Loss: 3.724414825439453 | BCE Loss: 1.0413073301315308
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.672757148742676 | KNN Loss: 3.668394088745117 | BCE Loss: 1.0043628215789795
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.731295585632324 | KNN Loss: 3.7123425006866455 | BCE Loss: 1.0189533233642578
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.731834888458252 | KNN Loss: 3.6939644813537598 | BCE Loss: 1.0378704071044922
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.758278846740723 | KNN Loss: 3.7176353931427 | BCE Loss: 1.0406434535980225
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.75576114654541 | KNN Loss: 3.7350966930389404 | BCE Loss: 1.0206644535064697
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 4.7541961669921875 | KNN Loss: 3.7285399436950684 | BCE Loss: 1.0256564617156982
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 4.684233665466309 | KNN Loss: 3.686177730560302

Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 4.711285591125488 | KNN Loss: 3.690167188644409 | BCE Loss: 1.0211186408996582
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.708611965179443 | KNN Loss: 3.6842329502105713 | BCE Loss: 1.0243788957595825
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.726962089538574 | KNN Loss: 3.7061986923217773 | BCE Loss: 1.0207631587982178
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.72654390335083 | KNN Loss: 3.7096428871154785 | BCE Loss: 1.016900897026062
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.762096405029297 | KNN Loss: 3.719595432281494 | BCE Loss: 1.0425007343292236
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.755941867828369 | KNN Loss: 3.7365715503692627 | BCE Loss: 1.019370436668396
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 4.743181228637695 | KNN Loss: 3.689335823059082 | BCE Loss: 1.0538454055786133
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 4.727602958679199 | KNN Loss: 3.694179058074951 |

Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 4.723313331604004 | KNN Loss: 3.6996958255767822 | BCE Loss: 1.0236172676086426
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.7329888343811035 | KNN Loss: 3.7090611457824707 | BCE Loss: 1.0239276885986328
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.743885040283203 | KNN Loss: 3.693159580230713 | BCE Loss: 1.0507256984710693
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.718471527099609 | KNN Loss: 3.7119100093841553 | BCE Loss: 1.0065617561340332
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.800650596618652 | KNN Loss: 3.725152015686035 | BCE Loss: 1.075498342514038
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.682872295379639 | KNN Loss: 3.700822591781616 | BCE Loss: 0.982049822807312
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 4.744189739227295 | KNN Loss: 3.7414488792419434 | BCE Loss: 1.0027409791946411
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 4.725719451904297 | KNN Loss: 3.68992018699646 

Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 4.713976860046387 | KNN Loss: 3.717313766479492 | BCE Loss: 0.996662974357605
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.707190036773682 | KNN Loss: 3.684222936630249 | BCE Loss: 1.0229671001434326
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.704684257507324 | KNN Loss: 3.6773910522460938 | BCE Loss: 1.0272934436798096
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.7339863777160645 | KNN Loss: 3.6897499561309814 | BCE Loss: 1.044236421585083
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.736303806304932 | KNN Loss: 3.729484796524048 | BCE Loss: 1.0068188905715942
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.690225124359131 | KNN Loss: 3.691316604614258 | BCE Loss: 0.9989083409309387
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 4.743015289306641 | KNN Loss: 3.716688394546509 | BCE Loss: 1.026327133178711
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 4.698496341705322 | KNN Loss: 3.699834108352661 |

Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 4.7282891273498535 | KNN Loss: 3.6897709369659424 | BCE Loss: 1.0385181903839111
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.727675437927246 | KNN Loss: 3.684317111968994 | BCE Loss: 1.0433580875396729
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.756896495819092 | KNN Loss: 3.726398468017578 | BCE Loss: 1.0304981470108032
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.740072727203369 | KNN Loss: 3.694629430770874 | BCE Loss: 1.0454431772232056
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.71905517578125 | KNN Loss: 3.6993558406829834 | BCE Loss: 1.0196993350982666
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.702778339385986 | KNN Loss: 3.6962029933929443 | BCE Loss: 1.006575345993042
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 4.701593399047852 | KNN Loss: 3.7081174850463867 | BCE Loss: 0.9934757947921753
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 4.713642120361328 | KNN Loss: 3.705548286437988

Epoch 288 / 500 | iteration 20 / 30 | Total Loss: 4.7249908447265625 | KNN Loss: 3.6939847469329834 | BCE Loss: 1.031006097793579
Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 4.720367431640625 | KNN Loss: 3.721616744995117 | BCE Loss: 0.9987505674362183
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.660045623779297 | KNN Loss: 3.6917829513549805 | BCE Loss: 0.9682624340057373
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.712893962860107 | KNN Loss: 3.700392246246338 | BCE Loss: 1.0125017166137695
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.739830493927002 | KNN Loss: 3.7237322330474854 | BCE Loss: 1.0160982608795166
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.728182315826416 | KNN Loss: 3.687965154647827 | BCE Loss: 1.0402171611785889
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.708847999572754 | KNN Loss: 3.687863349914551 | BCE Loss: 1.0209846496582031
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 4.700455665588379 | KNN Loss: 3.65868043899536

Epoch 299 / 500 | iteration 10 / 30 | Total Loss: 4.705644130706787 | KNN Loss: 3.704404592514038 | BCE Loss: 1.001239538192749
Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 4.730780601501465 | KNN Loss: 3.71010684967041 | BCE Loss: 1.0206736326217651
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.734591007232666 | KNN Loss: 3.7141847610473633 | BCE Loss: 1.0204062461853027
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.731437683105469 | KNN Loss: 3.7091715335845947 | BCE Loss: 1.0222663879394531
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.754538536071777 | KNN Loss: 3.751600503921509 | BCE Loss: 1.002937912940979
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.721158981323242 | KNN Loss: 3.690776824951172 | BCE Loss: 1.0303820371627808
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.759917736053467 | KNN Loss: 3.7397727966308594 | BCE Loss: 1.0201449394226074
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 4.750723361968994 | KNN Loss: 3.702622890472412 |

Epoch 310 / 500 | iteration 0 / 30 | Total Loss: 4.710790634155273 | KNN Loss: 3.684206485748291 | BCE Loss: 1.0265841484069824
Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 4.718181610107422 | KNN Loss: 3.703930139541626 | BCE Loss: 1.014251470565796
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.72463846206665 | KNN Loss: 3.7226040363311768 | BCE Loss: 1.0020344257354736
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.705843448638916 | KNN Loss: 3.7026607990264893 | BCE Loss: 1.0031825304031372
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.711874008178711 | KNN Loss: 3.6819443702697754 | BCE Loss: 1.029929518699646
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.721822261810303 | KNN Loss: 3.6949875354766846 | BCE Loss: 1.0268347263336182
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.758835792541504 | KNN Loss: 3.6966676712036133 | BCE Loss: 1.0621683597564697
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 4.779362678527832 | KNN Loss: 3.741718292236328 |

Epoch 320 / 500 | iteration 20 / 30 | Total Loss: 4.736855983734131 | KNN Loss: 3.724478244781494 | BCE Loss: 1.0123776197433472
Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 4.745788097381592 | KNN Loss: 3.7305126190185547 | BCE Loss: 1.015275478363037
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.706618309020996 | KNN Loss: 3.687767505645752 | BCE Loss: 1.018850564956665
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.713393688201904 | KNN Loss: 3.687021017074585 | BCE Loss: 1.0263726711273193
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.737730979919434 | KNN Loss: 3.7323317527770996 | BCE Loss: 1.005399465560913
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.754436016082764 | KNN Loss: 3.7122859954833984 | BCE Loss: 1.0421500205993652
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.700850963592529 | KNN Loss: 3.6889078617095947 | BCE Loss: 1.0119431018829346
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 4.694737911224365 | KNN Loss: 3.6849381923675537

Epoch 331 / 500 | iteration 5 / 30 | Total Loss: 4.747918605804443 | KNN Loss: 3.704313039779663 | BCE Loss: 1.0436055660247803
Epoch 331 / 500 | iteration 10 / 30 | Total Loss: 4.689865589141846 | KNN Loss: 3.6858856678009033 | BCE Loss: 1.0039799213409424
Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.693680286407471 | KNN Loss: 3.6991167068481445 | BCE Loss: 0.9945634007453918
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.7407026290893555 | KNN Loss: 3.7134084701538086 | BCE Loss: 1.0272942781448364
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.706449508666992 | KNN Loss: 3.6962218284606934 | BCE Loss: 1.0102277994155884
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.723599910736084 | KNN Loss: 3.6896913051605225 | BCE Loss: 1.033908724784851
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.70393180847168 | KNN Loss: 3.7126567363739014 | BCE Loss: 0.9912748336791992
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.7195258140563965 | KNN Loss: 3.698156595230

Epoch 341 / 500 | iteration 20 / 30 | Total Loss: 4.7226881980896 | KNN Loss: 3.708848714828491 | BCE Loss: 1.0138394832611084
Epoch 341 / 500 | iteration 25 / 30 | Total Loss: 4.672100067138672 | KNN Loss: 3.6779356002807617 | BCE Loss: 0.9941645860671997
Epoch 342 / 500 | iteration 0 / 30 | Total Loss: 4.723851680755615 | KNN Loss: 3.7013914585113525 | BCE Loss: 1.0224602222442627
Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.720301151275635 | KNN Loss: 3.671569585800171 | BCE Loss: 1.0487314462661743
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.747027397155762 | KNN Loss: 3.750065803527832 | BCE Loss: 0.9969615936279297
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.733222961425781 | KNN Loss: 3.701648473739624 | BCE Loss: 1.0315744876861572
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.701747894287109 | KNN Loss: 3.689814567565918 | BCE Loss: 1.0119333267211914
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.699629783630371 | KNN Loss: 3.7087562084198 | B

Epoch 352 / 500 | iteration 10 / 30 | Total Loss: 4.6947760581970215 | KNN Loss: 3.6749894618988037 | BCE Loss: 1.0197867155075073
Epoch 352 / 500 | iteration 15 / 30 | Total Loss: 4.760902404785156 | KNN Loss: 3.74523663520813 | BCE Loss: 1.015665888786316
Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.774825096130371 | KNN Loss: 3.7355313301086426 | BCE Loss: 1.0392935276031494
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.720942497253418 | KNN Loss: 3.7027907371520996 | BCE Loss: 1.0181517601013184
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.737308979034424 | KNN Loss: 3.710371732711792 | BCE Loss: 1.0269372463226318
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.728839874267578 | KNN Loss: 3.7028276920318604 | BCE Loss: 1.0260121822357178
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.686656951904297 | KNN Loss: 3.6773147583007812 | BCE Loss: 1.0093424320220947
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 4.700539588928223 | KNN Loss: 3.6829805374145

Epoch 363 / 500 | iteration 0 / 30 | Total Loss: 4.722850322723389 | KNN Loss: 3.708188533782959 | BCE Loss: 1.0146616697311401
Epoch 363 / 500 | iteration 5 / 30 | Total Loss: 4.7519073486328125 | KNN Loss: 3.7114431858062744 | BCE Loss: 1.0404644012451172
Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.716756820678711 | KNN Loss: 3.7218017578125 | BCE Loss: 0.9949550628662109
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.7560319900512695 | KNN Loss: 3.720256805419922 | BCE Loss: 1.0357754230499268
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.779790878295898 | KNN Loss: 3.7272849082946777 | BCE Loss: 1.0525059700012207
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.745134353637695 | KNN Loss: 3.7289514541625977 | BCE Loss: 1.0161831378936768
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 4.747990608215332 | KNN Loss: 3.7096035480499268 | BCE Loss: 1.0383872985839844
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 4.780076026916504 | KNN Loss: 3.750590085983276

Epoch 373 / 500 | iteration 20 / 30 | Total Loss: 4.708643913269043 | KNN Loss: 3.6894748210906982 | BCE Loss: 1.0191688537597656
Epoch 373 / 500 | iteration 25 / 30 | Total Loss: 4.73189640045166 | KNN Loss: 3.7173259258270264 | BCE Loss: 1.0145704746246338
Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.745631217956543 | KNN Loss: 3.7286622524261475 | BCE Loss: 1.0169689655303955
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.734236717224121 | KNN Loss: 3.730241537094116 | BCE Loss: 1.0039949417114258
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.695277214050293 | KNN Loss: 3.68479323387146 | BCE Loss: 1.0104838609695435
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.719077110290527 | KNN Loss: 3.708599328994751 | BCE Loss: 1.0104777812957764
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.7434916496276855 | KNN Loss: 3.7338762283325195 | BCE Loss: 1.0096155405044556
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 4.736710548400879 | KNN Loss: 3.70434761047363

Epoch 384 / 500 | iteration 10 / 30 | Total Loss: 4.725707054138184 | KNN Loss: 3.6938588619232178 | BCE Loss: 1.0318479537963867
Epoch 384 / 500 | iteration 15 / 30 | Total Loss: 4.705301761627197 | KNN Loss: 3.6808857917785645 | BCE Loss: 1.0244159698486328
Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.7588348388671875 | KNN Loss: 3.7512500286102295 | BCE Loss: 1.007584810256958
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.716248512268066 | KNN Loss: 3.687664747238159 | BCE Loss: 1.0285835266113281
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.756243705749512 | KNN Loss: 3.7247273921966553 | BCE Loss: 1.0315163135528564
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.699306964874268 | KNN Loss: 3.697979211807251 | BCE Loss: 1.0013277530670166
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.739497184753418 | KNN Loss: 3.703273057937622 | BCE Loss: 1.0362238883972168
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 4.73333215713501 | KNN Loss: 3.70688676834106

Epoch 395 / 500 | iteration 0 / 30 | Total Loss: 4.728349685668945 | KNN Loss: 3.71309232711792 | BCE Loss: 1.0152573585510254
Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.735073566436768 | KNN Loss: 3.713737726211548 | BCE Loss: 1.0213358402252197
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.696706295013428 | KNN Loss: 3.681657314300537 | BCE Loss: 1.0150489807128906
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.7346320152282715 | KNN Loss: 3.7257397174835205 | BCE Loss: 1.0088921785354614
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.738533020019531 | KNN Loss: 3.7023513317108154 | BCE Loss: 1.0361816883087158
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.736053943634033 | KNN Loss: 3.693506956100464 | BCE Loss: 1.0425468683242798
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.716411113739014 | KNN Loss: 3.685304641723633 | BCE Loss: 1.0311065912246704
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 4.726442337036133 | KNN Loss: 3.7217354774475098 

Epoch 405 / 500 | iteration 20 / 30 | Total Loss: 4.747786998748779 | KNN Loss: 3.698589324951172 | BCE Loss: 1.0491976737976074
Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.723831653594971 | KNN Loss: 3.726553440093994 | BCE Loss: 0.9972783327102661
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.7312517166137695 | KNN Loss: 3.7239890098571777 | BCE Loss: 1.0072624683380127
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.723222255706787 | KNN Loss: 3.719123601913452 | BCE Loss: 1.0040987730026245
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.71256160736084 | KNN Loss: 3.6877942085266113 | BCE Loss: 1.0247673988342285
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.734351634979248 | KNN Loss: 3.7214672565460205 | BCE Loss: 1.012884497642517
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.70643424987793 | KNN Loss: 3.68967342376709 | BCE Loss: 1.0167605876922607
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 4.746150493621826 | KNN Loss: 3.7100167274475098 

Epoch 416 / 500 | iteration 10 / 30 | Total Loss: 4.705761909484863 | KNN Loss: 3.7020103931427 | BCE Loss: 1.003751516342163
Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.718288421630859 | KNN Loss: 3.6987667083740234 | BCE Loss: 1.019521951675415
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.730375289916992 | KNN Loss: 3.707319974899292 | BCE Loss: 1.0230554342269897
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.7313642501831055 | KNN Loss: 3.7257070541381836 | BCE Loss: 1.005657434463501
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.709545612335205 | KNN Loss: 3.7183635234832764 | BCE Loss: 0.9911822080612183
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.6713361740112305 | KNN Loss: 3.6820507049560547 | BCE Loss: 0.9892855882644653
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.73178243637085 | KNN Loss: 3.692059278488159 | BCE Loss: 1.0397231578826904
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 4.707244396209717 | KNN Loss: 3.694640874862671 |

Epoch 427 / 500 | iteration 0 / 30 | Total Loss: 4.7531418800354 | KNN Loss: 3.7138781547546387 | BCE Loss: 1.0392636060714722
Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.743902683258057 | KNN Loss: 3.710149049758911 | BCE Loss: 1.033753514289856
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.715070724487305 | KNN Loss: 3.7111356258392334 | BCE Loss: 1.0039349794387817
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.768191814422607 | KNN Loss: 3.745206594467163 | BCE Loss: 1.0229852199554443
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.698960781097412 | KNN Loss: 3.6769661903381348 | BCE Loss: 1.0219945907592773
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.7117390632629395 | KNN Loss: 3.7186484336853027 | BCE Loss: 0.9930906891822815
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.717439651489258 | KNN Loss: 3.697937250137329 | BCE Loss: 1.0195026397705078
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 4.7236833572387695 | KNN Loss: 3.7015960216522217

Epoch 437 / 500 | iteration 20 / 30 | Total Loss: 4.741335868835449 | KNN Loss: 3.726243257522583 | BCE Loss: 1.0150928497314453
Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.785909175872803 | KNN Loss: 3.7275731563568115 | BCE Loss: 1.0583359003067017
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.708998203277588 | KNN Loss: 3.6945319175720215 | BCE Loss: 1.0144662857055664
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.715358734130859 | KNN Loss: 3.6989035606384277 | BCE Loss: 1.0164551734924316
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.741302967071533 | KNN Loss: 3.7196907997131348 | BCE Loss: 1.0216121673583984
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.79486608505249 | KNN Loss: 3.7653303146362305 | BCE Loss: 1.0295357704162598
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.73509407043457 | KNN Loss: 3.7317440509796143 | BCE Loss: 1.0033502578735352
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 4.703238010406494 | KNN Loss: 3.6846051216125

Epoch 448 / 500 | iteration 10 / 30 | Total Loss: 4.7324347496032715 | KNN Loss: 3.7232372760772705 | BCE Loss: 1.009197473526001
Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.746771812438965 | KNN Loss: 3.7235639095306396 | BCE Loss: 1.0232077836990356
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.696274280548096 | KNN Loss: 3.687047004699707 | BCE Loss: 1.0092272758483887
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.720311641693115 | KNN Loss: 3.698303699493408 | BCE Loss: 1.0220078229904175
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.745702743530273 | KNN Loss: 3.7228331565856934 | BCE Loss: 1.022869348526001
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.699822425842285 | KNN Loss: 3.6966190338134766 | BCE Loss: 1.0032036304473877
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 4.670045375823975 | KNN Loss: 3.6735451221466064 | BCE Loss: 0.9965003728866577
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 4.763278961181641 | KNN Loss: 3.7159793376922

Epoch 458 / 500 | iteration 25 / 30 | Total Loss: 4.690435409545898 | KNN Loss: 3.6791489124298096 | BCE Loss: 1.011286735534668
Epoch 459 / 500 | iteration 0 / 30 | Total Loss: 4.689745903015137 | KNN Loss: 3.6812632083892822 | BCE Loss: 1.0084829330444336
Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.755011558532715 | KNN Loss: 3.717414140701294 | BCE Loss: 1.037597417831421
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.73530387878418 | KNN Loss: 3.7053306102752686 | BCE Loss: 1.0299732685089111
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.7699384689331055 | KNN Loss: 3.739673614501953 | BCE Loss: 1.0302648544311523
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.737393379211426 | KNN Loss: 3.72377610206604 | BCE Loss: 1.0136171579360962
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.752344131469727 | KNN Loss: 3.733727216720581 | BCE Loss: 1.0186166763305664
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 4.694576740264893 | KNN Loss: 3.697068929672241 | 

Epoch 469 / 500 | iteration 15 / 30 | Total Loss: 4.790313720703125 | KNN Loss: 3.750507116317749 | BCE Loss: 1.0398067235946655
Epoch 469 / 500 | iteration 20 / 30 | Total Loss: 4.7207794189453125 | KNN Loss: 3.701083183288574 | BCE Loss: 1.0196962356567383
Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.785818099975586 | KNN Loss: 3.729259490966797 | BCE Loss: 1.0565588474273682
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.736263275146484 | KNN Loss: 3.694779872894287 | BCE Loss: 1.0414836406707764
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.717832565307617 | KNN Loss: 3.720715284347534 | BCE Loss: 0.9971174597740173
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.718566417694092 | KNN Loss: 3.7001214027404785 | BCE Loss: 1.0184450149536133
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.682604789733887 | KNN Loss: 3.677612781524658 | BCE Loss: 1.004992127418518
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 4.745893478393555 | KNN Loss: 3.726593255996704 

Epoch 480 / 500 | iteration 5 / 30 | Total Loss: 4.699853897094727 | KNN Loss: 3.69313383102417 | BCE Loss: 1.0067198276519775
Epoch 480 / 500 | iteration 10 / 30 | Total Loss: 4.660924911499023 | KNN Loss: 3.6677608489990234 | BCE Loss: 0.9931640625
Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.731444358825684 | KNN Loss: 3.7100045680999756 | BCE Loss: 1.021439552307129
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.725133419036865 | KNN Loss: 3.7023613452911377 | BCE Loss: 1.0227720737457275
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.721109390258789 | KNN Loss: 3.695420742034912 | BCE Loss: 1.0256887674331665
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.707658767700195 | KNN Loss: 3.6958277225494385 | BCE Loss: 1.011831283569336
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.735317230224609 | KNN Loss: 3.6926376819610596 | BCE Loss: 1.0426793098449707
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 4.760739803314209 | KNN Loss: 3.712620496749878 | BCE 

Epoch 490 / 500 | iteration 25 / 30 | Total Loss: 4.691200256347656 | KNN Loss: 3.7076048851013184 | BCE Loss: 0.9835952520370483
Epoch 491 / 500 | iteration 0 / 30 | Total Loss: 4.730219841003418 | KNN Loss: 3.7243521213531494 | BCE Loss: 1.0058677196502686
Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.709720134735107 | KNN Loss: 3.6959407329559326 | BCE Loss: 1.0137794017791748
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.751674175262451 | KNN Loss: 3.746224880218506 | BCE Loss: 1.0054491758346558
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.745136260986328 | KNN Loss: 3.7450082302093506 | BCE Loss: 1.0001280307769775
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.755199432373047 | KNN Loss: 3.7198448181152344 | BCE Loss: 1.0353548526763916
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.71691370010376 | KNN Loss: 3.7100090980529785 | BCE Loss: 1.0069046020507812
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 4.70306921005249 | KNN Loss: 3.67787456512451

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 2.0376,  4.0441,  2.9325,  2.2331,  2.3031,  0.7054,  2.1160,  1.6226,
          2.6057,  1.9758,  1.6559,  1.7841,  0.4641,  1.8709,  1.3173,  1.0579,
          3.1501,  2.2467,  3.1790,  2.6883,  1.7776,  2.4152,  1.8910,  3.0578,
          2.8368,  1.8914,  2.1541,  1.4339,  1.7862,  0.6679, -0.0071,  0.7831,
          0.5321,  0.9999,  1.7067,  1.2602,  1.1641,  3.7554,  1.0405,  1.4261,
          0.7441, -1.0527, -0.0267,  2.4635,  2.0952,  0.8401, -0.3758,  0.2380,
          1.6674,  2.5574,  1.9121,  0.4660,  1.3182,  0.2007, -0.6321,  1.3695,
          1.4160,  1.5192,  1.3526,  1.9552,  0.1368,  1.1053,  0.3451,  1.2740,
          1.1374,  2.0587, -1.5341,  0.5228,  2.6743,  1.9066,  2.9133,  0.1733,
          1.4021,  2.7667,  1.9263,  1.1434,  0.5951,  0.6298,  0.0815,  1.9243,
          0.1050,  0.4397,  1.9906, -0.2688,  0.1968, -1.3471, -2.4413, -0.0647,
          0.7598, -1.8507,  0.6403,  0.1251, -0.5934, -1.3275,  0.4164,  1.4477,
         -0.3175, -0.4011,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 99.09it/s] 


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 026 | Total loss: 9.618 | Reg loss: 0.012 | Tree loss: 9.618 | Accuracy: 0.000000 | 0.887 sec/iter
Epoch: 00 | Batch: 001 / 026 | Total loss: 9.614 | Reg loss: 0.011 | Tree loss: 9.614 | Accuracy: 0.000000 | 0.824 sec/iter
Epoch: 00 | Batch: 002 / 026 | Total loss: 9.612 | Reg loss: 0.010 | Tree loss: 9.612 | Accuracy: 0.000000 | 0.808 sec/iter
Epoch: 00 | Batch: 003 / 026 | Total loss: 9.608 | Reg loss: 0.009 | Tree loss: 9.608 | Accuracy: 0.000000 | 0.798 sec/iter
Epoch: 00 | Batch: 004 / 026 | Total loss: 9.605 | Reg loss: 0.009 | Tree loss: 9.605 | Accuracy: 0.000000 | 0.792 sec/iter
Epoch: 00 | Batch: 005 / 026 | Total loss: 9.602 | Reg loss: 0.008 | Tree loss: 9.602 | Accuracy: 0.000000 | 0.79 sec/iter
Epoch: 00 | Batch: 006 / 026 | Total loss: 9.598 | Reg loss: 0.008 | Tree loss: 9.598 | Accuracy: 0.000000 | 0.788 s

Epoch: 02 | Batch: 009 / 026 | Total loss: 9.509 | Reg loss: 0.008 | Tree loss: 9.509 | Accuracy: 0.136719 | 0.827 sec/iter
Epoch: 02 | Batch: 010 / 026 | Total loss: 9.508 | Reg loss: 0.008 | Tree loss: 9.508 | Accuracy: 0.123047 | 0.827 sec/iter
Epoch: 02 | Batch: 011 / 026 | Total loss: 9.501 | Reg loss: 0.008 | Tree loss: 9.501 | Accuracy: 0.181641 | 0.827 sec/iter
Epoch: 02 | Batch: 012 / 026 | Total loss: 9.503 | Reg loss: 0.009 | Tree loss: 9.503 | Accuracy: 0.150391 | 0.827 sec/iter
Epoch: 02 | Batch: 013 / 026 | Total loss: 9.497 | Reg loss: 0.009 | Tree loss: 9.497 | Accuracy: 0.181641 | 0.827 sec/iter
Epoch: 02 | Batch: 014 / 026 | Total loss: 9.494 | Reg loss: 0.009 | Tree loss: 9.494 | Accuracy: 0.156250 | 0.827 sec/iter
Epoch: 02 | Batch: 015 / 026 | Total loss: 9.491 | Reg loss: 0.010 | Tree loss: 9.491 | Accuracy: 0.132812 | 0.826 sec/iter
Epoch: 02 | Batch: 016 / 026 | Total loss: 9.490 | Reg loss: 0.010 | Tree loss: 9.490 | Accuracy: 0.152344 | 0.826 sec/iter
Epoch: 0

Epoch: 04 | Batch: 019 / 026 | Total loss: 9.335 | Reg loss: 0.016 | Tree loss: 9.335 | Accuracy: 0.130859 | 0.839 sec/iter
Epoch: 04 | Batch: 020 / 026 | Total loss: 9.322 | Reg loss: 0.016 | Tree loss: 9.322 | Accuracy: 0.154297 | 0.839 sec/iter
Epoch: 04 | Batch: 021 / 026 | Total loss: 9.312 | Reg loss: 0.017 | Tree loss: 9.312 | Accuracy: 0.154297 | 0.839 sec/iter
Epoch: 04 | Batch: 022 / 026 | Total loss: 9.302 | Reg loss: 0.017 | Tree loss: 9.302 | Accuracy: 0.154297 | 0.838 sec/iter
Epoch: 04 | Batch: 023 / 026 | Total loss: 9.292 | Reg loss: 0.018 | Tree loss: 9.292 | Accuracy: 0.162109 | 0.838 sec/iter
Epoch: 04 | Batch: 024 / 026 | Total loss: 9.287 | Reg loss: 0.018 | Tree loss: 9.287 | Accuracy: 0.150391 | 0.838 sec/iter
Epoch: 04 | Batch: 025 / 026 | Total loss: 9.235 | Reg loss: 0.018 | Tree loss: 9.235 | Accuracy: 0.210526 | 0.838 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 07 | Batch: 001 / 026 | Total loss: 9.075 | Reg loss: 0.017 | Tree loss: 9.075 | Accuracy: 0.123047 | 0.852 sec/iter
Epoch: 07 | Batch: 002 / 026 | Total loss: 9.068 | Reg loss: 0.017 | Tree loss: 9.068 | Accuracy: 0.134766 | 0.851 sec/iter
Epoch: 07 | Batch: 003 / 026 | Total loss: 9.035 | Reg loss: 0.017 | Tree loss: 9.035 | Accuracy: 0.160156 | 0.851 sec/iter
Epoch: 07 | Batch: 004 / 026 | Total loss: 9.012 | Reg loss: 0.018 | Tree loss: 9.012 | Accuracy: 0.167969 | 0.851 sec/iter
Epoch: 07 | Batch: 005 / 026 | Total loss: 8.997 | Reg loss: 0.018 | Tree loss: 8.997 | Accuracy: 0.140625 | 0.852 sec/iter
Epoch: 07 | Batch: 006 / 026 | Total loss: 8.985 | Reg loss: 0.018 | Tree loss: 8.985 | Accuracy: 0.152344 | 0.852 sec/iter
Epoch: 07 | Batch: 007 / 026 | Total loss: 8.958 | Reg loss: 0.018 | Tree loss: 8.958 | Accuracy: 0.150391 | 0.852 sec/iter
Epoch: 07 | Batch: 008 / 026 | Total loss: 8.948 | Reg loss: 0.018 | Tree loss: 8.948 | Accuracy: 0.125000 | 0.852 sec/iter
Epoch: 0

Epoch: 09 | Batch: 011 / 026 | Total loss: 8.390 | Reg loss: 0.022 | Tree loss: 8.390 | Accuracy: 0.119141 | 0.854 sec/iter
Epoch: 09 | Batch: 012 / 026 | Total loss: 8.336 | Reg loss: 0.022 | Tree loss: 8.336 | Accuracy: 0.123047 | 0.854 sec/iter
Epoch: 09 | Batch: 013 / 026 | Total loss: 8.322 | Reg loss: 0.022 | Tree loss: 8.322 | Accuracy: 0.152344 | 0.854 sec/iter
Epoch: 09 | Batch: 014 / 026 | Total loss: 8.319 | Reg loss: 0.022 | Tree loss: 8.319 | Accuracy: 0.117188 | 0.854 sec/iter
Epoch: 09 | Batch: 015 / 026 | Total loss: 8.288 | Reg loss: 0.022 | Tree loss: 8.288 | Accuracy: 0.154297 | 0.854 sec/iter
Epoch: 09 | Batch: 016 / 026 | Total loss: 8.255 | Reg loss: 0.023 | Tree loss: 8.255 | Accuracy: 0.136719 | 0.854 sec/iter
Epoch: 09 | Batch: 017 / 026 | Total loss: 8.270 | Reg loss: 0.023 | Tree loss: 8.270 | Accuracy: 0.123047 | 0.854 sec/iter
Epoch: 09 | Batch: 018 / 026 | Total loss: 8.241 | Reg loss: 0.023 | Tree loss: 8.241 | Accuracy: 0.128906 | 0.854 sec/iter
Epoch: 0

Epoch: 11 | Batch: 021 / 026 | Total loss: 7.673 | Reg loss: 0.024 | Tree loss: 7.673 | Accuracy: 0.138672 | 0.856 sec/iter
Epoch: 11 | Batch: 022 / 026 | Total loss: 7.636 | Reg loss: 0.025 | Tree loss: 7.636 | Accuracy: 0.105469 | 0.856 sec/iter
Epoch: 11 | Batch: 023 / 026 | Total loss: 7.593 | Reg loss: 0.025 | Tree loss: 7.593 | Accuracy: 0.144531 | 0.856 sec/iter
Epoch: 11 | Batch: 024 / 026 | Total loss: 7.598 | Reg loss: 0.025 | Tree loss: 7.598 | Accuracy: 0.093750 | 0.855 sec/iter
Epoch: 11 | Batch: 025 / 026 | Total loss: 7.499 | Reg loss: 0.025 | Tree loss: 7.499 | Accuracy: 0.192982 | 0.855 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 12 | Batch: 000 / 026 | Total loss: 7.839 | Reg loss: 0.023 | Tree loss: 7.839 | Ac

Epoch: 14 | Batch: 003 / 026 | Total loss: 7.317 | Reg loss: 0.024 | Tree loss: 7.317 | Accuracy: 0.115234 | 0.858 sec/iter
Epoch: 14 | Batch: 004 / 026 | Total loss: 7.277 | Reg loss: 0.024 | Tree loss: 7.277 | Accuracy: 0.132812 | 0.858 sec/iter
Epoch: 14 | Batch: 005 / 026 | Total loss: 7.256 | Reg loss: 0.024 | Tree loss: 7.256 | Accuracy: 0.134766 | 0.858 sec/iter
Epoch: 14 | Batch: 006 / 026 | Total loss: 7.215 | Reg loss: 0.024 | Tree loss: 7.215 | Accuracy: 0.150391 | 0.858 sec/iter
Epoch: 14 | Batch: 007 / 026 | Total loss: 7.202 | Reg loss: 0.024 | Tree loss: 7.202 | Accuracy: 0.117188 | 0.858 sec/iter
Epoch: 14 | Batch: 008 / 026 | Total loss: 7.183 | Reg loss: 0.024 | Tree loss: 7.183 | Accuracy: 0.140625 | 0.858 sec/iter
Epoch: 14 | Batch: 009 / 026 | Total loss: 7.164 | Reg loss: 0.024 | Tree loss: 7.164 | Accuracy: 0.128906 | 0.858 sec/iter
Epoch: 14 | Batch: 010 / 026 | Total loss: 7.135 | Reg loss: 0.024 | Tree loss: 7.135 | Accuracy: 0.136719 | 0.858 sec/iter
Epoch: 1

Epoch: 16 | Batch: 013 / 026 | Total loss: 6.596 | Reg loss: 0.025 | Tree loss: 6.596 | Accuracy: 0.130859 | 0.86 sec/iter
Epoch: 16 | Batch: 014 / 026 | Total loss: 6.629 | Reg loss: 0.025 | Tree loss: 6.629 | Accuracy: 0.113281 | 0.86 sec/iter
Epoch: 16 | Batch: 015 / 026 | Total loss: 6.588 | Reg loss: 0.025 | Tree loss: 6.588 | Accuracy: 0.136719 | 0.86 sec/iter
Epoch: 16 | Batch: 016 / 026 | Total loss: 6.616 | Reg loss: 0.025 | Tree loss: 6.616 | Accuracy: 0.123047 | 0.86 sec/iter
Epoch: 16 | Batch: 017 / 026 | Total loss: 6.534 | Reg loss: 0.025 | Tree loss: 6.534 | Accuracy: 0.146484 | 0.859 sec/iter
Epoch: 16 | Batch: 018 / 026 | Total loss: 6.534 | Reg loss: 0.025 | Tree loss: 6.534 | Accuracy: 0.128906 | 0.859 sec/iter
Epoch: 16 | Batch: 019 / 026 | Total loss: 6.545 | Reg loss: 0.025 | Tree loss: 6.545 | Accuracy: 0.146484 | 0.859 sec/iter
Epoch: 16 | Batch: 020 / 026 | Total loss: 6.524 | Reg loss: 0.025 | Tree loss: 6.524 | Accuracy: 0.103516 | 0.859 sec/iter
Epoch: 16 | 

Epoch: 18 | Batch: 023 / 026 | Total loss: 6.064 | Reg loss: 0.025 | Tree loss: 6.064 | Accuracy: 0.125000 | 0.858 sec/iter
Epoch: 18 | Batch: 024 / 026 | Total loss: 6.053 | Reg loss: 0.026 | Tree loss: 6.053 | Accuracy: 0.123047 | 0.858 sec/iter
Epoch: 18 | Batch: 025 / 026 | Total loss: 5.975 | Reg loss: 0.026 | Tree loss: 5.975 | Accuracy: 0.140351 | 0.858 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 19 | Batch: 000 / 026 | Total loss: 6.178 | Reg loss: 0.025 | Tree loss: 6.178 | Accuracy: 0.146484 | 0.86 sec/iter
Epoch: 19 | Batch: 001 / 026 | Total loss: 6.174 | Reg loss: 0.025 | Tree loss: 6.174 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 19 | Batch: 002 / 026 | Total loss: 6.174 | Reg loss: 0.025 | Tree loss: 6.174 | Accu

Epoch: 21 | Batch: 005 / 026 | Total loss: 5.753 | Reg loss: 0.025 | Tree loss: 5.753 | Accuracy: 0.150391 | 0.86 sec/iter
Epoch: 21 | Batch: 006 / 026 | Total loss: 5.716 | Reg loss: 0.025 | Tree loss: 5.716 | Accuracy: 0.115234 | 0.86 sec/iter
Epoch: 21 | Batch: 007 / 026 | Total loss: 5.686 | Reg loss: 0.025 | Tree loss: 5.686 | Accuracy: 0.126953 | 0.859 sec/iter
Epoch: 21 | Batch: 008 / 026 | Total loss: 5.703 | Reg loss: 0.025 | Tree loss: 5.703 | Accuracy: 0.132812 | 0.859 sec/iter
Epoch: 21 | Batch: 009 / 026 | Total loss: 5.681 | Reg loss: 0.025 | Tree loss: 5.681 | Accuracy: 0.125000 | 0.859 sec/iter
Epoch: 21 | Batch: 010 / 026 | Total loss: 5.677 | Reg loss: 0.025 | Tree loss: 5.677 | Accuracy: 0.138672 | 0.859 sec/iter
Epoch: 21 | Batch: 011 / 026 | Total loss: 5.637 | Reg loss: 0.025 | Tree loss: 5.637 | Accuracy: 0.117188 | 0.859 sec/iter
Epoch: 21 | Batch: 012 / 026 | Total loss: 5.653 | Reg loss: 0.025 | Tree loss: 5.653 | Accuracy: 0.142578 | 0.859 sec/iter
Epoch: 21 

Epoch: 23 | Batch: 015 / 026 | Total loss: 5.263 | Reg loss: 0.025 | Tree loss: 5.263 | Accuracy: 0.134766 | 0.859 sec/iter
Epoch: 23 | Batch: 016 / 026 | Total loss: 5.276 | Reg loss: 0.025 | Tree loss: 5.276 | Accuracy: 0.113281 | 0.859 sec/iter
Epoch: 23 | Batch: 017 / 026 | Total loss: 5.225 | Reg loss: 0.025 | Tree loss: 5.225 | Accuracy: 0.146484 | 0.859 sec/iter
Epoch: 23 | Batch: 018 / 026 | Total loss: 5.245 | Reg loss: 0.025 | Tree loss: 5.245 | Accuracy: 0.140625 | 0.859 sec/iter
Epoch: 23 | Batch: 019 / 026 | Total loss: 5.255 | Reg loss: 0.025 | Tree loss: 5.255 | Accuracy: 0.107422 | 0.859 sec/iter
Epoch: 23 | Batch: 020 / 026 | Total loss: 5.239 | Reg loss: 0.025 | Tree loss: 5.239 | Accuracy: 0.148438 | 0.859 sec/iter
Epoch: 23 | Batch: 021 / 026 | Total loss: 5.200 | Reg loss: 0.025 | Tree loss: 5.200 | Accuracy: 0.113281 | 0.859 sec/iter
Epoch: 23 | Batch: 022 / 026 | Total loss: 5.163 | Reg loss: 0.025 | Tree loss: 5.163 | Accuracy: 0.142578 | 0.859 sec/iter
Epoch: 2

Epoch: 25 | Batch: 025 / 026 | Total loss: 4.840 | Reg loss: 0.025 | Tree loss: 4.840 | Accuracy: 0.122807 | 0.859 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 26 | Batch: 000 / 026 | Total loss: 4.955 | Reg loss: 0.025 | Tree loss: 4.955 | Accuracy: 0.123047 | 0.86 sec/iter
Epoch: 26 | Batch: 001 / 026 | Total loss: 4.945 | Reg loss: 0.025 | Tree loss: 4.945 | Accuracy: 0.148438 | 0.86 sec/iter
Epoch: 26 | Batch: 002 / 026 | Total loss: 4.927 | Reg loss: 0.025 | Tree loss: 4.927 | Accuracy: 0.128906 | 0.86 sec/iter
Epoch: 26 | Batch: 003 / 026 | Total loss: 4.958 | Reg loss: 0.025 | Tree loss: 4.958 | Accuracy: 0.128906 | 0.86 sec/iter
Epoch: 26 | Batch: 004 / 026 | Total loss: 4.939 | Reg loss: 0.025 | Tree loss: 4.939 | Accura

Epoch: 28 | Batch: 007 / 026 | Total loss: 4.683 | Reg loss: 0.025 | Tree loss: 4.683 | Accuracy: 0.107422 | 0.86 sec/iter
Epoch: 28 | Batch: 008 / 026 | Total loss: 4.611 | Reg loss: 0.025 | Tree loss: 4.611 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 28 | Batch: 009 / 026 | Total loss: 4.604 | Reg loss: 0.025 | Tree loss: 4.604 | Accuracy: 0.138672 | 0.86 sec/iter
Epoch: 28 | Batch: 010 / 026 | Total loss: 4.621 | Reg loss: 0.025 | Tree loss: 4.621 | Accuracy: 0.113281 | 0.86 sec/iter
Epoch: 28 | Batch: 011 / 026 | Total loss: 4.596 | Reg loss: 0.025 | Tree loss: 4.596 | Accuracy: 0.128906 | 0.86 sec/iter
Epoch: 28 | Batch: 012 / 026 | Total loss: 4.557 | Reg loss: 0.025 | Tree loss: 4.557 | Accuracy: 0.138672 | 0.86 sec/iter
Epoch: 28 | Batch: 013 / 026 | Total loss: 4.564 | Reg loss: 0.025 | Tree loss: 4.564 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 28 | Batch: 014 / 026 | Total loss: 4.585 | Reg loss: 0.025 | Tree loss: 4.585 | Accuracy: 0.119141 | 0.86 sec/iter
Epoch: 28 | Batc

Epoch: 30 | Batch: 017 / 026 | Total loss: 4.339 | Reg loss: 0.025 | Tree loss: 4.339 | Accuracy: 0.109375 | 0.861 sec/iter
Epoch: 30 | Batch: 018 / 026 | Total loss: 4.312 | Reg loss: 0.025 | Tree loss: 4.312 | Accuracy: 0.125000 | 0.861 sec/iter
Epoch: 30 | Batch: 019 / 026 | Total loss: 4.326 | Reg loss: 0.025 | Tree loss: 4.326 | Accuracy: 0.125000 | 0.861 sec/iter
Epoch: 30 | Batch: 020 / 026 | Total loss: 4.311 | Reg loss: 0.025 | Tree loss: 4.311 | Accuracy: 0.158203 | 0.861 sec/iter
Epoch: 30 | Batch: 021 / 026 | Total loss: 4.283 | Reg loss: 0.025 | Tree loss: 4.283 | Accuracy: 0.134766 | 0.861 sec/iter
Epoch: 30 | Batch: 022 / 026 | Total loss: 4.277 | Reg loss: 0.025 | Tree loss: 4.277 | Accuracy: 0.121094 | 0.861 sec/iter
Epoch: 30 | Batch: 023 / 026 | Total loss: 4.298 | Reg loss: 0.025 | Tree loss: 4.298 | Accuracy: 0.144531 | 0.861 sec/iter
Epoch: 30 | Batch: 024 / 026 | Total loss: 4.262 | Reg loss: 0.025 | Tree loss: 4.262 | Accuracy: 0.125000 | 0.861 sec/iter
Epoch: 3

layer 8: 0.9821428571428573
Epoch: 33 | Batch: 000 / 026 | Total loss: 4.063 | Reg loss: 0.025 | Tree loss: 4.063 | Accuracy: 0.169922 | 0.862 sec/iter
Epoch: 33 | Batch: 001 / 026 | Total loss: 4.126 | Reg loss: 0.025 | Tree loss: 4.126 | Accuracy: 0.125000 | 0.862 sec/iter
Epoch: 33 | Batch: 002 / 026 | Total loss: 4.068 | Reg loss: 0.025 | Tree loss: 4.068 | Accuracy: 0.144531 | 0.862 sec/iter
Epoch: 33 | Batch: 003 / 026 | Total loss: 4.131 | Reg loss: 0.025 | Tree loss: 4.131 | Accuracy: 0.142578 | 0.862 sec/iter
Epoch: 33 | Batch: 004 / 026 | Total loss: 4.073 | Reg loss: 0.025 | Tree loss: 4.073 | Accuracy: 0.126953 | 0.861 sec/iter
Epoch: 33 | Batch: 005 / 026 | Total loss: 4.094 | Reg loss: 0.025 | Tree loss: 4.094 | Accuracy: 0.125000 | 0.861 sec/iter
Epoch: 33 | Batch: 006 / 026 | Total loss: 4.095 | Reg loss: 0.025 | Tree loss: 4.095 | Accuracy: 0.130859 | 0.861 sec/iter
Epoch: 33 | Batch: 007 / 026 | Total loss: 4.093 | Reg loss: 0.025 | Tree loss: 4.093 | Accuracy: 0.1445

Epoch: 35 | Batch: 010 / 026 | Total loss: 3.834 | Reg loss: 0.026 | Tree loss: 3.834 | Accuracy: 0.138672 | 0.861 sec/iter
Epoch: 35 | Batch: 011 / 026 | Total loss: 3.842 | Reg loss: 0.026 | Tree loss: 3.842 | Accuracy: 0.123047 | 0.861 sec/iter
Epoch: 35 | Batch: 012 / 026 | Total loss: 3.810 | Reg loss: 0.026 | Tree loss: 3.810 | Accuracy: 0.130859 | 0.861 sec/iter
Epoch: 35 | Batch: 013 / 026 | Total loss: 3.803 | Reg loss: 0.026 | Tree loss: 3.803 | Accuracy: 0.138672 | 0.861 sec/iter
Epoch: 35 | Batch: 014 / 026 | Total loss: 3.796 | Reg loss: 0.026 | Tree loss: 3.796 | Accuracy: 0.140625 | 0.861 sec/iter
Epoch: 35 | Batch: 015 / 026 | Total loss: 3.794 | Reg loss: 0.026 | Tree loss: 3.794 | Accuracy: 0.121094 | 0.861 sec/iter
Epoch: 35 | Batch: 016 / 026 | Total loss: 3.776 | Reg loss: 0.026 | Tree loss: 3.776 | Accuracy: 0.158203 | 0.861 sec/iter
Epoch: 35 | Batch: 017 / 026 | Total loss: 3.839 | Reg loss: 0.026 | Tree loss: 3.839 | Accuracy: 0.125000 | 0.861 sec/iter
Epoch: 3

Epoch: 37 | Batch: 020 / 026 | Total loss: 3.594 | Reg loss: 0.027 | Tree loss: 3.594 | Accuracy: 0.119141 | 0.86 sec/iter
Epoch: 37 | Batch: 021 / 026 | Total loss: 3.602 | Reg loss: 0.027 | Tree loss: 3.602 | Accuracy: 0.125000 | 0.86 sec/iter
Epoch: 37 | Batch: 022 / 026 | Total loss: 3.575 | Reg loss: 0.027 | Tree loss: 3.575 | Accuracy: 0.142578 | 0.86 sec/iter
Epoch: 37 | Batch: 023 / 026 | Total loss: 3.549 | Reg loss: 0.027 | Tree loss: 3.549 | Accuracy: 0.154297 | 0.86 sec/iter
Epoch: 37 | Batch: 024 / 026 | Total loss: 3.508 | Reg loss: 0.027 | Tree loss: 3.508 | Accuracy: 0.154297 | 0.86 sec/iter
Epoch: 37 | Batch: 025 / 026 | Total loss: 3.585 | Reg loss: 0.027 | Tree loss: 3.585 | Accuracy: 0.228070 | 0.86 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573


Epoch: 40 | Batch: 002 / 026 | Total loss: 3.588 | Reg loss: 0.027 | Tree loss: 3.588 | Accuracy: 0.162109 | 0.861 sec/iter
Epoch: 40 | Batch: 003 / 026 | Total loss: 3.580 | Reg loss: 0.027 | Tree loss: 3.580 | Accuracy: 0.134766 | 0.861 sec/iter
Epoch: 40 | Batch: 004 / 026 | Total loss: 3.553 | Reg loss: 0.027 | Tree loss: 3.553 | Accuracy: 0.160156 | 0.861 sec/iter
Epoch: 40 | Batch: 005 / 026 | Total loss: 3.539 | Reg loss: 0.027 | Tree loss: 3.539 | Accuracy: 0.181641 | 0.861 sec/iter
Epoch: 40 | Batch: 006 / 026 | Total loss: 3.554 | Reg loss: 0.027 | Tree loss: 3.554 | Accuracy: 0.150391 | 0.86 sec/iter
Epoch: 40 | Batch: 007 / 026 | Total loss: 3.547 | Reg loss: 0.027 | Tree loss: 3.547 | Accuracy: 0.134766 | 0.86 sec/iter
Epoch: 40 | Batch: 008 / 026 | Total loss: 3.523 | Reg loss: 0.027 | Tree loss: 3.523 | Accuracy: 0.134766 | 0.86 sec/iter
Epoch: 40 | Batch: 009 / 026 | Total loss: 3.444 | Reg loss: 0.027 | Tree loss: 3.444 | Accuracy: 0.132812 | 0.86 sec/iter
Epoch: 40 | 

Epoch: 42 | Batch: 012 / 026 | Total loss: 3.352 | Reg loss: 0.027 | Tree loss: 3.352 | Accuracy: 0.156250 | 0.86 sec/iter
Epoch: 42 | Batch: 013 / 026 | Total loss: 3.373 | Reg loss: 0.028 | Tree loss: 3.373 | Accuracy: 0.158203 | 0.86 sec/iter
Epoch: 42 | Batch: 014 / 026 | Total loss: 3.394 | Reg loss: 0.028 | Tree loss: 3.394 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 42 | Batch: 015 / 026 | Total loss: 3.396 | Reg loss: 0.028 | Tree loss: 3.396 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 42 | Batch: 016 / 026 | Total loss: 3.355 | Reg loss: 0.028 | Tree loss: 3.355 | Accuracy: 0.146484 | 0.86 sec/iter
Epoch: 42 | Batch: 017 / 026 | Total loss: 3.306 | Reg loss: 0.028 | Tree loss: 3.306 | Accuracy: 0.181641 | 0.86 sec/iter
Epoch: 42 | Batch: 018 / 026 | Total loss: 3.360 | Reg loss: 0.028 | Tree loss: 3.360 | Accuracy: 0.156250 | 0.86 sec/iter
Epoch: 42 | Batch: 019 / 026 | Total loss: 3.385 | Reg loss: 0.028 | Tree loss: 3.385 | Accuracy: 0.154297 | 0.86 sec/iter
Epoch: 42 | Batc

Epoch: 44 | Batch: 022 / 026 | Total loss: 3.363 | Reg loss: 0.028 | Tree loss: 3.363 | Accuracy: 0.113281 | 0.86 sec/iter
Epoch: 44 | Batch: 023 / 026 | Total loss: 3.284 | Reg loss: 0.028 | Tree loss: 3.284 | Accuracy: 0.148438 | 0.86 sec/iter
Epoch: 44 | Batch: 024 / 026 | Total loss: 3.246 | Reg loss: 0.028 | Tree loss: 3.246 | Accuracy: 0.175781 | 0.86 sec/iter
Epoch: 44 | Batch: 025 / 026 | Total loss: 3.166 | Reg loss: 0.028 | Tree loss: 3.166 | Accuracy: 0.175439 | 0.86 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 45 | Batch: 000 / 026 | Total loss: 3.381 | Reg loss: 0.028 | Tree loss: 3.381 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 45 | Batch: 001 / 026 | Total loss: 3.395 | Reg loss: 0.028 | Tree loss: 3.395 | Accurac

Epoch: 47 | Batch: 004 / 026 | Total loss: 3.293 | Reg loss: 0.028 | Tree loss: 3.293 | Accuracy: 0.181641 | 0.861 sec/iter
Epoch: 47 | Batch: 005 / 026 | Total loss: 3.335 | Reg loss: 0.028 | Tree loss: 3.335 | Accuracy: 0.132812 | 0.861 sec/iter
Epoch: 47 | Batch: 006 / 026 | Total loss: 3.263 | Reg loss: 0.028 | Tree loss: 3.263 | Accuracy: 0.154297 | 0.861 sec/iter
Epoch: 47 | Batch: 007 / 026 | Total loss: 3.327 | Reg loss: 0.028 | Tree loss: 3.327 | Accuracy: 0.134766 | 0.861 sec/iter
Epoch: 47 | Batch: 008 / 026 | Total loss: 3.338 | Reg loss: 0.028 | Tree loss: 3.338 | Accuracy: 0.132812 | 0.86 sec/iter
Epoch: 47 | Batch: 009 / 026 | Total loss: 3.295 | Reg loss: 0.028 | Tree loss: 3.295 | Accuracy: 0.123047 | 0.86 sec/iter
Epoch: 47 | Batch: 010 / 026 | Total loss: 3.209 | Reg loss: 0.028 | Tree loss: 3.209 | Accuracy: 0.142578 | 0.86 sec/iter
Epoch: 47 | Batch: 011 / 026 | Total loss: 3.286 | Reg loss: 0.028 | Tree loss: 3.286 | Accuracy: 0.123047 | 0.86 sec/iter
Epoch: 47 | 

Epoch: 49 | Batch: 014 / 026 | Total loss: 3.187 | Reg loss: 0.028 | Tree loss: 3.187 | Accuracy: 0.136719 | 0.86 sec/iter
Epoch: 49 | Batch: 015 / 026 | Total loss: 3.154 | Reg loss: 0.028 | Tree loss: 3.154 | Accuracy: 0.152344 | 0.86 sec/iter
Epoch: 49 | Batch: 016 / 026 | Total loss: 3.213 | Reg loss: 0.028 | Tree loss: 3.213 | Accuracy: 0.152344 | 0.86 sec/iter
Epoch: 49 | Batch: 017 / 026 | Total loss: 3.167 | Reg loss: 0.028 | Tree loss: 3.167 | Accuracy: 0.138672 | 0.86 sec/iter
Epoch: 49 | Batch: 018 / 026 | Total loss: 3.147 | Reg loss: 0.028 | Tree loss: 3.147 | Accuracy: 0.136719 | 0.86 sec/iter
Epoch: 49 | Batch: 019 / 026 | Total loss: 3.123 | Reg loss: 0.028 | Tree loss: 3.123 | Accuracy: 0.177734 | 0.86 sec/iter
Epoch: 49 | Batch: 020 / 026 | Total loss: 3.144 | Reg loss: 0.029 | Tree loss: 3.144 | Accuracy: 0.158203 | 0.86 sec/iter
Epoch: 49 | Batch: 021 / 026 | Total loss: 3.126 | Reg loss: 0.029 | Tree loss: 3.126 | Accuracy: 0.187500 | 0.86 sec/iter
Epoch: 49 | Batc

Epoch: 51 | Batch: 024 / 026 | Total loss: 3.123 | Reg loss: 0.029 | Tree loss: 3.123 | Accuracy: 0.117188 | 0.86 sec/iter
Epoch: 51 | Batch: 025 / 026 | Total loss: 3.112 | Reg loss: 0.029 | Tree loss: 3.112 | Accuracy: 0.175439 | 0.86 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 52 | Batch: 000 / 026 | Total loss: 3.237 | Reg loss: 0.028 | Tree loss: 3.237 | Accuracy: 0.148438 | 0.86 sec/iter
Epoch: 52 | Batch: 001 / 026 | Total loss: 3.313 | Reg loss: 0.028 | Tree loss: 3.313 | Accuracy: 0.132812 | 0.86 sec/iter
Epoch: 52 | Batch: 002 / 026 | Total loss: 3.201 | Reg loss: 0.028 | Tree loss: 3.201 | Accuracy: 0.164062 | 0.86 sec/iter
Epoch: 52 | Batch: 003 / 026 | Total loss: 3.227 | Reg loss: 0.028 | Tree loss: 3.227 | Accurac

Epoch: 54 | Batch: 006 / 026 | Total loss: 3.231 | Reg loss: 0.028 | Tree loss: 3.231 | Accuracy: 0.126953 | 0.86 sec/iter
Epoch: 54 | Batch: 007 / 026 | Total loss: 3.128 | Reg loss: 0.028 | Tree loss: 3.128 | Accuracy: 0.158203 | 0.86 sec/iter
Epoch: 54 | Batch: 008 / 026 | Total loss: 3.173 | Reg loss: 0.028 | Tree loss: 3.173 | Accuracy: 0.132812 | 0.86 sec/iter
Epoch: 54 | Batch: 009 / 026 | Total loss: 3.163 | Reg loss: 0.028 | Tree loss: 3.163 | Accuracy: 0.126953 | 0.86 sec/iter
Epoch: 54 | Batch: 010 / 026 | Total loss: 3.074 | Reg loss: 0.029 | Tree loss: 3.074 | Accuracy: 0.171875 | 0.86 sec/iter
Epoch: 54 | Batch: 011 / 026 | Total loss: 3.173 | Reg loss: 0.029 | Tree loss: 3.173 | Accuracy: 0.140625 | 0.86 sec/iter
Epoch: 54 | Batch: 012 / 026 | Total loss: 3.118 | Reg loss: 0.029 | Tree loss: 3.118 | Accuracy: 0.138672 | 0.86 sec/iter
Epoch: 54 | Batch: 013 / 026 | Total loss: 3.085 | Reg loss: 0.029 | Tree loss: 3.085 | Accuracy: 0.169922 | 0.86 sec/iter
Epoch: 54 | Batc

Epoch: 56 | Batch: 016 / 026 | Total loss: 3.045 | Reg loss: 0.029 | Tree loss: 3.045 | Accuracy: 0.158203 | 0.86 sec/iter
Epoch: 56 | Batch: 017 / 026 | Total loss: 3.080 | Reg loss: 0.029 | Tree loss: 3.080 | Accuracy: 0.132812 | 0.86 sec/iter
Epoch: 56 | Batch: 018 / 026 | Total loss: 3.074 | Reg loss: 0.029 | Tree loss: 3.074 | Accuracy: 0.162109 | 0.86 sec/iter
Epoch: 56 | Batch: 019 / 026 | Total loss: 3.114 | Reg loss: 0.029 | Tree loss: 3.114 | Accuracy: 0.117188 | 0.86 sec/iter
Epoch: 56 | Batch: 020 / 026 | Total loss: 3.045 | Reg loss: 0.029 | Tree loss: 3.045 | Accuracy: 0.148438 | 0.86 sec/iter
Epoch: 56 | Batch: 021 / 026 | Total loss: 3.033 | Reg loss: 0.029 | Tree loss: 3.033 | Accuracy: 0.136719 | 0.86 sec/iter
Epoch: 56 | Batch: 022 / 026 | Total loss: 3.010 | Reg loss: 0.029 | Tree loss: 3.010 | Accuracy: 0.154297 | 0.86 sec/iter
Epoch: 56 | Batch: 023 / 026 | Total loss: 3.036 | Reg loss: 0.029 | Tree loss: 3.036 | Accuracy: 0.156250 | 0.86 sec/iter
Epoch: 56 | Batc

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 59 | Batch: 000 / 026 | Total loss: 3.192 | Reg loss: 0.029 | Tree loss: 3.192 | Accuracy: 0.125000 | 0.86 sec/iter
Epoch: 59 | Batch: 001 / 026 | Total loss: 3.167 | Reg loss: 0.029 | Tree loss: 3.167 | Accuracy: 0.136719 | 0.86 sec/iter
Epoch: 59 | Batch: 002 / 026 | Total loss: 3.150 | Reg loss: 0.029 | Tree loss: 3.150 | Accuracy: 0.156250 | 0.86 sec/iter
Epoch: 59 | Batch: 003 / 026 | Total loss: 3.166 | Reg loss: 0.029 | Tree loss: 3.166 | Accuracy: 0.167969 | 0.86 sec/iter
Epoch: 59 | Batch: 004 / 026 | Total loss: 3.088 | Reg loss: 0.029 | Tree loss: 3.088 | Accuracy: 0.152344 | 0.86 sec/iter
Epoch: 59 | Batch: 005 / 026 | Total loss: 3.085 | Reg loss: 0.029 | Tree loss: 3.085 | Accurac

Epoch: 61 | Batch: 008 / 026 | Total loss: 3.136 | Reg loss: 0.029 | Tree loss: 3.136 | Accuracy: 0.144531 | 0.86 sec/iter
Epoch: 61 | Batch: 009 / 026 | Total loss: 3.120 | Reg loss: 0.029 | Tree loss: 3.120 | Accuracy: 0.119141 | 0.859 sec/iter
Epoch: 61 | Batch: 010 / 026 | Total loss: 3.094 | Reg loss: 0.029 | Tree loss: 3.094 | Accuracy: 0.138672 | 0.859 sec/iter
Epoch: 61 | Batch: 011 / 026 | Total loss: 3.058 | Reg loss: 0.029 | Tree loss: 3.058 | Accuracy: 0.146484 | 0.859 sec/iter
Epoch: 61 | Batch: 012 / 026 | Total loss: 3.041 | Reg loss: 0.029 | Tree loss: 3.041 | Accuracy: 0.144531 | 0.859 sec/iter
Epoch: 61 | Batch: 013 / 026 | Total loss: 3.046 | Reg loss: 0.029 | Tree loss: 3.046 | Accuracy: 0.164062 | 0.859 sec/iter
Epoch: 61 | Batch: 014 / 026 | Total loss: 3.071 | Reg loss: 0.029 | Tree loss: 3.071 | Accuracy: 0.171875 | 0.859 sec/iter
Epoch: 61 | Batch: 015 / 026 | Total loss: 3.049 | Reg loss: 0.029 | Tree loss: 3.049 | Accuracy: 0.144531 | 0.859 sec/iter
Epoch: 61

Epoch: 63 | Batch: 018 / 026 | Total loss: 3.003 | Reg loss: 0.029 | Tree loss: 3.003 | Accuracy: 0.130859 | 0.859 sec/iter
Epoch: 63 | Batch: 019 / 026 | Total loss: 2.945 | Reg loss: 0.029 | Tree loss: 2.945 | Accuracy: 0.130859 | 0.859 sec/iter
Epoch: 63 | Batch: 020 / 026 | Total loss: 2.996 | Reg loss: 0.029 | Tree loss: 2.996 | Accuracy: 0.167969 | 0.859 sec/iter
Epoch: 63 | Batch: 021 / 026 | Total loss: 2.988 | Reg loss: 0.029 | Tree loss: 2.988 | Accuracy: 0.134766 | 0.859 sec/iter
Epoch: 63 | Batch: 022 / 026 | Total loss: 2.986 | Reg loss: 0.029 | Tree loss: 2.986 | Accuracy: 0.130859 | 0.859 sec/iter
Epoch: 63 | Batch: 023 / 026 | Total loss: 2.938 | Reg loss: 0.029 | Tree loss: 2.938 | Accuracy: 0.140625 | 0.859 sec/iter
Epoch: 63 | Batch: 024 / 026 | Total loss: 3.030 | Reg loss: 0.029 | Tree loss: 3.030 | Accuracy: 0.134766 | 0.859 sec/iter
Epoch: 63 | Batch: 025 / 026 | Total loss: 2.970 | Reg loss: 0.029 | Tree loss: 2.970 | Accuracy: 0.192982 | 0.859 sec/iter
Average 

Epoch: 66 | Batch: 000 / 026 | Total loss: 3.095 | Reg loss: 0.029 | Tree loss: 3.095 | Accuracy: 0.166016 | 0.859 sec/iter
Epoch: 66 | Batch: 001 / 026 | Total loss: 3.161 | Reg loss: 0.029 | Tree loss: 3.161 | Accuracy: 0.134766 | 0.859 sec/iter
Epoch: 66 | Batch: 002 / 026 | Total loss: 3.030 | Reg loss: 0.029 | Tree loss: 3.030 | Accuracy: 0.144531 | 0.859 sec/iter
Epoch: 66 | Batch: 003 / 026 | Total loss: 3.108 | Reg loss: 0.029 | Tree loss: 3.108 | Accuracy: 0.158203 | 0.859 sec/iter
Epoch: 66 | Batch: 004 / 026 | Total loss: 3.090 | Reg loss: 0.029 | Tree loss: 3.090 | Accuracy: 0.164062 | 0.859 sec/iter
Epoch: 66 | Batch: 005 / 026 | Total loss: 3.108 | Reg loss: 0.029 | Tree loss: 3.108 | Accuracy: 0.136719 | 0.859 sec/iter
Epoch: 66 | Batch: 006 / 026 | Total loss: 3.042 | Reg loss: 0.029 | Tree loss: 3.042 | Accuracy: 0.148438 | 0.859 sec/iter
Epoch: 66 | Batch: 007 / 026 | Total loss: 3.079 | Reg loss: 0.029 | Tree loss: 3.079 | Accuracy: 0.128906 | 0.859 sec/iter
Epoch: 6

Epoch: 68 | Batch: 010 / 026 | Total loss: 3.019 | Reg loss: 0.029 | Tree loss: 3.019 | Accuracy: 0.154297 | 0.859 sec/iter
Epoch: 68 | Batch: 011 / 026 | Total loss: 3.022 | Reg loss: 0.029 | Tree loss: 3.022 | Accuracy: 0.164062 | 0.859 sec/iter
Epoch: 68 | Batch: 012 / 026 | Total loss: 3.076 | Reg loss: 0.029 | Tree loss: 3.076 | Accuracy: 0.134766 | 0.859 sec/iter
Epoch: 68 | Batch: 013 / 026 | Total loss: 2.952 | Reg loss: 0.029 | Tree loss: 2.952 | Accuracy: 0.160156 | 0.859 sec/iter
Epoch: 68 | Batch: 014 / 026 | Total loss: 2.993 | Reg loss: 0.029 | Tree loss: 2.993 | Accuracy: 0.125000 | 0.859 sec/iter
Epoch: 68 | Batch: 015 / 026 | Total loss: 3.005 | Reg loss: 0.029 | Tree loss: 3.005 | Accuracy: 0.136719 | 0.859 sec/iter
Epoch: 68 | Batch: 016 / 026 | Total loss: 2.925 | Reg loss: 0.029 | Tree loss: 2.925 | Accuracy: 0.171875 | 0.859 sec/iter
Epoch: 68 | Batch: 017 / 026 | Total loss: 2.987 | Reg loss: 0.029 | Tree loss: 2.987 | Accuracy: 0.162109 | 0.859 sec/iter
Epoch: 6

Epoch: 70 | Batch: 020 / 026 | Total loss: 2.973 | Reg loss: 0.029 | Tree loss: 2.973 | Accuracy: 0.150391 | 0.858 sec/iter
Epoch: 70 | Batch: 021 / 026 | Total loss: 2.950 | Reg loss: 0.029 | Tree loss: 2.950 | Accuracy: 0.167969 | 0.858 sec/iter
Epoch: 70 | Batch: 022 / 026 | Total loss: 2.954 | Reg loss: 0.029 | Tree loss: 2.954 | Accuracy: 0.154297 | 0.858 sec/iter
Epoch: 70 | Batch: 023 / 026 | Total loss: 2.957 | Reg loss: 0.029 | Tree loss: 2.957 | Accuracy: 0.138672 | 0.858 sec/iter
Epoch: 70 | Batch: 024 / 026 | Total loss: 2.914 | Reg loss: 0.029 | Tree loss: 2.914 | Accuracy: 0.175781 | 0.858 sec/iter
Epoch: 70 | Batch: 025 / 026 | Total loss: 2.815 | Reg loss: 0.029 | Tree loss: 2.815 | Accuracy: 0.175439 | 0.858 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 73 | Batch: 002 / 026 | Total loss: 3.076 | Reg loss: 0.029 | Tree loss: 3.076 | Accuracy: 0.136719 | 0.857 sec/iter
Epoch: 73 | Batch: 003 / 026 | Total loss: 3.065 | Reg loss: 0.029 | Tree loss: 3.065 | Accuracy: 0.121094 | 0.857 sec/iter
Epoch: 73 | Batch: 004 / 026 | Total loss: 3.013 | Reg loss: 0.029 | Tree loss: 3.013 | Accuracy: 0.156250 | 0.857 sec/iter
Epoch: 73 | Batch: 005 / 026 | Total loss: 3.070 | Reg loss: 0.029 | Tree loss: 3.070 | Accuracy: 0.152344 | 0.857 sec/iter
Epoch: 73 | Batch: 006 / 026 | Total loss: 3.064 | Reg loss: 0.029 | Tree loss: 3.064 | Accuracy: 0.162109 | 0.857 sec/iter
Epoch: 73 | Batch: 007 / 026 | Total loss: 3.070 | Reg loss: 0.029 | Tree loss: 3.070 | Accuracy: 0.144531 | 0.856 sec/iter
Epoch: 73 | Batch: 008 / 026 | Total loss: 3.039 | Reg loss: 0.029 | Tree loss: 3.039 | Accuracy: 0.144531 | 0.856 sec/iter
Epoch: 73 | Batch: 009 / 026 | Total loss: 3.023 | Reg loss: 0.029 | Tree loss: 3.023 | Accuracy: 0.144531 | 0.856 sec/iter
Epoch: 7

Epoch: 75 | Batch: 012 / 026 | Total loss: 2.978 | Reg loss: 0.029 | Tree loss: 2.978 | Accuracy: 0.128906 | 0.854 sec/iter
Epoch: 75 | Batch: 013 / 026 | Total loss: 2.997 | Reg loss: 0.029 | Tree loss: 2.997 | Accuracy: 0.144531 | 0.854 sec/iter
Epoch: 75 | Batch: 014 / 026 | Total loss: 2.974 | Reg loss: 0.029 | Tree loss: 2.974 | Accuracy: 0.142578 | 0.854 sec/iter
Epoch: 75 | Batch: 015 / 026 | Total loss: 2.960 | Reg loss: 0.029 | Tree loss: 2.960 | Accuracy: 0.140625 | 0.854 sec/iter
Epoch: 75 | Batch: 016 / 026 | Total loss: 2.915 | Reg loss: 0.029 | Tree loss: 2.915 | Accuracy: 0.146484 | 0.854 sec/iter
Epoch: 75 | Batch: 017 / 026 | Total loss: 2.988 | Reg loss: 0.029 | Tree loss: 2.988 | Accuracy: 0.148438 | 0.854 sec/iter
Epoch: 75 | Batch: 018 / 026 | Total loss: 2.943 | Reg loss: 0.029 | Tree loss: 2.943 | Accuracy: 0.140625 | 0.854 sec/iter
Epoch: 75 | Batch: 019 / 026 | Total loss: 2.927 | Reg loss: 0.029 | Tree loss: 2.927 | Accuracy: 0.148438 | 0.854 sec/iter
Epoch: 7

Epoch: 77 | Batch: 022 / 026 | Total loss: 2.931 | Reg loss: 0.029 | Tree loss: 2.931 | Accuracy: 0.121094 | 0.852 sec/iter
Epoch: 77 | Batch: 023 / 026 | Total loss: 2.961 | Reg loss: 0.029 | Tree loss: 2.961 | Accuracy: 0.125000 | 0.852 sec/iter
Epoch: 77 | Batch: 024 / 026 | Total loss: 2.893 | Reg loss: 0.029 | Tree loss: 2.893 | Accuracy: 0.146484 | 0.852 sec/iter
Epoch: 77 | Batch: 025 / 026 | Total loss: 2.824 | Reg loss: 0.029 | Tree loss: 2.824 | Accuracy: 0.175439 | 0.852 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 78 | Batch: 000 / 026 | Total loss: 3.092 | Reg loss: 0.029 | Tree loss: 3.092 | Accuracy: 0.128906 | 0.852 sec/iter
Epoch: 78 | Batch: 001 / 026 | Total loss: 3.013 | Reg loss: 0.029 | Tree loss: 3.013 | Ac

Epoch: 80 | Batch: 004 / 026 | Total loss: 2.963 | Reg loss: 0.029 | Tree loss: 2.963 | Accuracy: 0.171875 | 0.85 sec/iter
Epoch: 80 | Batch: 005 / 026 | Total loss: 3.000 | Reg loss: 0.029 | Tree loss: 3.000 | Accuracy: 0.142578 | 0.85 sec/iter
Epoch: 80 | Batch: 006 / 026 | Total loss: 3.005 | Reg loss: 0.029 | Tree loss: 3.005 | Accuracy: 0.154297 | 0.85 sec/iter
Epoch: 80 | Batch: 007 / 026 | Total loss: 3.008 | Reg loss: 0.029 | Tree loss: 3.008 | Accuracy: 0.150391 | 0.85 sec/iter
Epoch: 80 | Batch: 008 / 026 | Total loss: 3.015 | Reg loss: 0.029 | Tree loss: 3.015 | Accuracy: 0.156250 | 0.85 sec/iter
Epoch: 80 | Batch: 009 / 026 | Total loss: 3.009 | Reg loss: 0.029 | Tree loss: 3.009 | Accuracy: 0.130859 | 0.85 sec/iter
Epoch: 80 | Batch: 010 / 026 | Total loss: 2.974 | Reg loss: 0.029 | Tree loss: 2.974 | Accuracy: 0.154297 | 0.85 sec/iter
Epoch: 80 | Batch: 011 / 026 | Total loss: 3.004 | Reg loss: 0.029 | Tree loss: 3.004 | Accuracy: 0.140625 | 0.85 sec/iter
Epoch: 80 | Batc

Epoch: 82 | Batch: 014 / 026 | Total loss: 2.883 | Reg loss: 0.029 | Tree loss: 2.883 | Accuracy: 0.158203 | 0.848 sec/iter
Epoch: 82 | Batch: 015 / 026 | Total loss: 2.937 | Reg loss: 0.029 | Tree loss: 2.937 | Accuracy: 0.173828 | 0.848 sec/iter
Epoch: 82 | Batch: 016 / 026 | Total loss: 2.954 | Reg loss: 0.029 | Tree loss: 2.954 | Accuracy: 0.099609 | 0.848 sec/iter
Epoch: 82 | Batch: 017 / 026 | Total loss: 2.905 | Reg loss: 0.029 | Tree loss: 2.905 | Accuracy: 0.148438 | 0.848 sec/iter
Epoch: 82 | Batch: 018 / 026 | Total loss: 2.920 | Reg loss: 0.029 | Tree loss: 2.920 | Accuracy: 0.156250 | 0.848 sec/iter
Epoch: 82 | Batch: 019 / 026 | Total loss: 2.984 | Reg loss: 0.029 | Tree loss: 2.984 | Accuracy: 0.150391 | 0.848 sec/iter
Epoch: 82 | Batch: 020 / 026 | Total loss: 2.889 | Reg loss: 0.029 | Tree loss: 2.889 | Accuracy: 0.169922 | 0.848 sec/iter
Epoch: 82 | Batch: 021 / 026 | Total loss: 2.933 | Reg loss: 0.029 | Tree loss: 2.933 | Accuracy: 0.140625 | 0.848 sec/iter
Epoch: 8

Epoch: 84 | Batch: 024 / 026 | Total loss: 2.860 | Reg loss: 0.029 | Tree loss: 2.860 | Accuracy: 0.167969 | 0.847 sec/iter
Epoch: 84 | Batch: 025 / 026 | Total loss: 2.850 | Reg loss: 0.029 | Tree loss: 2.850 | Accuracy: 0.157895 | 0.847 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 85 | Batch: 000 / 026 | Total loss: 3.049 | Reg loss: 0.029 | Tree loss: 3.049 | Accuracy: 0.150391 | 0.847 sec/iter
Epoch: 85 | Batch: 001 / 026 | Total loss: 3.064 | Reg loss: 0.029 | Tree loss: 3.064 | Accuracy: 0.148438 | 0.847 sec/iter
Epoch: 85 | Batch: 002 / 026 | Total loss: 3.008 | Reg loss: 0.029 | Tree loss: 3.008 | Accuracy: 0.167969 | 0.847 sec/iter
Epoch: 85 | Batch: 003 / 026 | Total loss: 3.039 | Reg loss: 0.029 | Tree loss: 3.039 | Ac

Epoch: 87 | Batch: 006 / 026 | Total loss: 3.039 | Reg loss: 0.029 | Tree loss: 3.039 | Accuracy: 0.144531 | 0.845 sec/iter
Epoch: 87 | Batch: 007 / 026 | Total loss: 2.991 | Reg loss: 0.029 | Tree loss: 2.991 | Accuracy: 0.142578 | 0.845 sec/iter
Epoch: 87 | Batch: 008 / 026 | Total loss: 3.014 | Reg loss: 0.029 | Tree loss: 3.014 | Accuracy: 0.130859 | 0.845 sec/iter
Epoch: 87 | Batch: 009 / 026 | Total loss: 2.937 | Reg loss: 0.029 | Tree loss: 2.937 | Accuracy: 0.150391 | 0.845 sec/iter
Epoch: 87 | Batch: 010 / 026 | Total loss: 2.964 | Reg loss: 0.029 | Tree loss: 2.964 | Accuracy: 0.130859 | 0.845 sec/iter
Epoch: 87 | Batch: 011 / 026 | Total loss: 2.984 | Reg loss: 0.029 | Tree loss: 2.984 | Accuracy: 0.150391 | 0.845 sec/iter
Epoch: 87 | Batch: 012 / 026 | Total loss: 2.994 | Reg loss: 0.029 | Tree loss: 2.994 | Accuracy: 0.146484 | 0.845 sec/iter
Epoch: 87 | Batch: 013 / 026 | Total loss: 2.986 | Reg loss: 0.029 | Tree loss: 2.986 | Accuracy: 0.134766 | 0.845 sec/iter
Epoch: 8

Epoch: 89 | Batch: 016 / 026 | Total loss: 2.928 | Reg loss: 0.029 | Tree loss: 2.928 | Accuracy: 0.136719 | 0.844 sec/iter
Epoch: 89 | Batch: 017 / 026 | Total loss: 2.944 | Reg loss: 0.029 | Tree loss: 2.944 | Accuracy: 0.154297 | 0.844 sec/iter
Epoch: 89 | Batch: 018 / 026 | Total loss: 2.887 | Reg loss: 0.029 | Tree loss: 2.887 | Accuracy: 0.177734 | 0.844 sec/iter
Epoch: 89 | Batch: 019 / 026 | Total loss: 2.840 | Reg loss: 0.029 | Tree loss: 2.840 | Accuracy: 0.175781 | 0.844 sec/iter
Epoch: 89 | Batch: 020 / 026 | Total loss: 2.950 | Reg loss: 0.029 | Tree loss: 2.950 | Accuracy: 0.134766 | 0.844 sec/iter
Epoch: 89 | Batch: 021 / 026 | Total loss: 2.875 | Reg loss: 0.029 | Tree loss: 2.875 | Accuracy: 0.185547 | 0.844 sec/iter
Epoch: 89 | Batch: 022 / 026 | Total loss: 2.945 | Reg loss: 0.029 | Tree loss: 2.945 | Accuracy: 0.140625 | 0.844 sec/iter
Epoch: 89 | Batch: 023 / 026 | Total loss: 2.917 | Reg loss: 0.029 | Tree loss: 2.917 | Accuracy: 0.115234 | 0.844 sec/iter
Epoch: 8

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 92 | Batch: 000 / 026 | Total loss: 3.022 | Reg loss: 0.028 | Tree loss: 3.022 | Accuracy: 0.138672 | 0.843 sec/iter
Epoch: 92 | Batch: 001 / 026 | Total loss: 3.036 | Reg loss: 0.028 | Tree loss: 3.036 | Accuracy: 0.146484 | 0.842 sec/iter
Epoch: 92 | Batch: 002 / 026 | Total loss: 3.011 | Reg loss: 0.028 | Tree loss: 3.011 | Accuracy: 0.115234 | 0.842 sec/iter
Epoch: 92 | Batch: 003 / 026 | Total loss: 2.959 | Reg loss: 0.028 | Tree loss: 2.959 | Accuracy: 0.185547 | 0.842 sec/iter
Epoch: 92 | Batch: 004 / 026 | Total loss: 3.092 | Reg loss: 0.028 | Tree loss: 3.092 | Accuracy: 0.121094 | 0.842 sec/iter
Epoch: 92 | Batch: 005 / 026 | Total loss: 3.053 | Reg loss: 0.028 | Tree loss: 3.053 | Ac

Epoch: 94 | Batch: 008 / 026 | Total loss: 2.970 | Reg loss: 0.029 | Tree loss: 2.970 | Accuracy: 0.136719 | 0.841 sec/iter
Epoch: 94 | Batch: 009 / 026 | Total loss: 2.943 | Reg loss: 0.029 | Tree loss: 2.943 | Accuracy: 0.146484 | 0.841 sec/iter
Epoch: 94 | Batch: 010 / 026 | Total loss: 2.912 | Reg loss: 0.029 | Tree loss: 2.912 | Accuracy: 0.134766 | 0.841 sec/iter
Epoch: 94 | Batch: 011 / 026 | Total loss: 3.013 | Reg loss: 0.029 | Tree loss: 3.013 | Accuracy: 0.128906 | 0.841 sec/iter
Epoch: 94 | Batch: 012 / 026 | Total loss: 2.931 | Reg loss: 0.029 | Tree loss: 2.931 | Accuracy: 0.136719 | 0.841 sec/iter
Epoch: 94 | Batch: 013 / 026 | Total loss: 2.998 | Reg loss: 0.029 | Tree loss: 2.998 | Accuracy: 0.146484 | 0.841 sec/iter
Epoch: 94 | Batch: 014 / 026 | Total loss: 2.860 | Reg loss: 0.029 | Tree loss: 2.860 | Accuracy: 0.173828 | 0.841 sec/iter
Epoch: 94 | Batch: 015 / 026 | Total loss: 2.945 | Reg loss: 0.029 | Tree loss: 2.945 | Accuracy: 0.154297 | 0.841 sec/iter
Epoch: 9

Epoch: 96 | Batch: 018 / 026 | Total loss: 2.937 | Reg loss: 0.029 | Tree loss: 2.937 | Accuracy: 0.152344 | 0.84 sec/iter
Epoch: 96 | Batch: 019 / 026 | Total loss: 2.963 | Reg loss: 0.029 | Tree loss: 2.963 | Accuracy: 0.142578 | 0.84 sec/iter
Epoch: 96 | Batch: 020 / 026 | Total loss: 2.893 | Reg loss: 0.029 | Tree loss: 2.893 | Accuracy: 0.152344 | 0.84 sec/iter
Epoch: 96 | Batch: 021 / 026 | Total loss: 2.948 | Reg loss: 0.029 | Tree loss: 2.948 | Accuracy: 0.132812 | 0.84 sec/iter
Epoch: 96 | Batch: 022 / 026 | Total loss: 2.890 | Reg loss: 0.029 | Tree loss: 2.890 | Accuracy: 0.138672 | 0.84 sec/iter
Epoch: 96 | Batch: 023 / 026 | Total loss: 2.899 | Reg loss: 0.029 | Tree loss: 2.899 | Accuracy: 0.158203 | 0.84 sec/iter
Epoch: 96 | Batch: 024 / 026 | Total loss: 2.859 | Reg loss: 0.029 | Tree loss: 2.859 | Accuracy: 0.144531 | 0.84 sec/iter
Epoch: 96 | Batch: 025 / 026 | Total loss: 2.801 | Reg loss: 0.029 | Tree loss: 2.801 | Accuracy: 0.157895 | 0.839 sec/iter
Average sparsen

Epoch: 99 | Batch: 000 / 026 | Total loss: 2.998 | Reg loss: 0.028 | Tree loss: 2.998 | Accuracy: 0.154297 | 0.839 sec/iter
Epoch: 99 | Batch: 001 / 026 | Total loss: 3.049 | Reg loss: 0.028 | Tree loss: 3.049 | Accuracy: 0.138672 | 0.838 sec/iter
Epoch: 99 | Batch: 002 / 026 | Total loss: 3.094 | Reg loss: 0.028 | Tree loss: 3.094 | Accuracy: 0.152344 | 0.838 sec/iter
Epoch: 99 | Batch: 003 / 026 | Total loss: 3.040 | Reg loss: 0.028 | Tree loss: 3.040 | Accuracy: 0.134766 | 0.838 sec/iter
Epoch: 99 | Batch: 004 / 026 | Total loss: 3.029 | Reg loss: 0.028 | Tree loss: 3.029 | Accuracy: 0.119141 | 0.838 sec/iter
Epoch: 99 | Batch: 005 / 026 | Total loss: 2.975 | Reg loss: 0.028 | Tree loss: 2.975 | Accuracy: 0.162109 | 0.838 sec/iter
Epoch: 99 | Batch: 006 / 026 | Total loss: 2.997 | Reg loss: 0.028 | Tree loss: 2.997 | Accuracy: 0.144531 | 0.838 sec/iter
Epoch: 99 | Batch: 007 / 026 | Total loss: 2.950 | Reg loss: 0.028 | Tree loss: 2.950 | Accuracy: 0.166016 | 0.838 sec/iter
Epoch: 9

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 9.948453608247423


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 970


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


9675
3182






Average comprehensibility: 47.45154639175258
std comprehensibility: 3.1044070238126458
var comprehensibility: 9.63734296949729
minimum comprehensibility: 36
maximum comprehensibility: 56
