In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 256
tree_depth = 10
device = 'cuda'
dataset_path = r"/mnt/qnap/ekosman/Groceries_dataset.csv"

## Create the market basket dataset and use one-hot encoding for items

In [3]:
dataset = MarketBasketDataset(dataset_path=dataset_path)

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4).train().to(device)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=k).to(device)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        batch = batch.to(device)
        target = target.to(device)
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(iterm)
            if torch.isinf(knn_loss):
                knn_loss = 0
        except ValueError:
            knn_loss = torch.tensor(0)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.164484977722168 | KNN Loss: 6.233161449432373 | BCE Loss: 1.931323766708374
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.180647850036621 | KNN Loss: 6.2330732345581055 | BCE Loss: 1.9475743770599365
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.156098365783691 | KNN Loss: 6.233052730560303 | BCE Loss: 1.9230459928512573
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.168828964233398 | KNN Loss: 6.232822418212891 | BCE Loss: 1.936007022857666
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.191194534301758 | KNN Loss: 6.232855319976807 | BCE Loss: 1.9583394527435303
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.178317070007324 | KNN Loss: 6.232844352722168 | BCE Loss: 1.9454725980758667
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.167888641357422 | KNN Loss: 6.2326531410217285 | BCE Loss: 1.935235619544983
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.10910415649414 | KNN Loss: 6.23281192779541 | BCE Loss: 1.87629258

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 7.361576080322266 | KNN Loss: 6.212650775909424 | BCE Loss: 1.148925542831421
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 7.327655792236328 | KNN Loss: 6.208916664123535 | BCE Loss: 1.118739128112793
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 7.329785346984863 | KNN Loss: 6.207803249359131 | BCE Loss: 1.1219818592071533
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 7.311932563781738 | KNN Loss: 6.206353664398193 | BCE Loss: 1.105578899383545
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 7.303010940551758 | KNN Loss: 6.206175327301025 | BCE Loss: 1.0968353748321533
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 7.301403045654297 | KNN Loss: 6.20162296295166 | BCE Loss: 1.0997802019119263
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 7.295412540435791 | KNN Loss: 6.200352191925049 | BCE Loss: 1.0950604677200317
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 7.2853102684021 | KNN Loss: 6.198827266693115 | BCE Loss: 1.08

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 6.777968406677246 | KNN Loss: 5.711792945861816 | BCE Loss: 1.0661754608154297
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 6.72640323638916 | KNN Loss: 5.690771579742432 | BCE Loss: 1.0356318950653076
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 6.781209945678711 | KNN Loss: 5.708673000335693 | BCE Loss: 1.0725367069244385
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 6.753395080566406 | KNN Loss: 5.69069766998291 | BCE Loss: 1.0626976490020752
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 6.7406487464904785 | KNN Loss: 5.6938862800598145 | BCE Loss: 1.0467623472213745
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 6.773418426513672 | KNN Loss: 5.704648971557617 | BCE Loss: 1.0687696933746338
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 6.787770748138428 | KNN Loss: 5.717114448547363 | BCE Loss: 1.0706562995910645
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 6.706891059875488 | KNN Loss: 5.66427755355835 | BCE Loss

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 6.650391101837158 | KNN Loss: 5.624721050262451 | BCE Loss: 1.0256701707839966
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 6.719789028167725 | KNN Loss: 5.673291206359863 | BCE Loss: 1.0464977025985718
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 6.676051139831543 | KNN Loss: 5.617538928985596 | BCE Loss: 1.0585122108459473
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 6.666265487670898 | KNN Loss: 5.619668483734131 | BCE Loss: 1.046596884727478
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 6.716798782348633 | KNN Loss: 5.635405540466309 | BCE Loss: 1.0813934803009033
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 6.654577732086182 | KNN Loss: 5.638788223266602 | BCE Loss: 1.01578950881958
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 6.737710952758789 | KNN Loss: 5.6912736892700195 | BCE Loss: 1.046437382698059
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 6.695381164550781 | KNN Loss: 5.65086555480957 | BCE Loss: 1

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 6.675329208374023 | KNN Loss: 5.626804828643799 | BCE Loss: 1.0485246181488037
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 6.788908958435059 | KNN Loss: 5.719067096710205 | BCE Loss: 1.0698418617248535
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 6.6926984786987305 | KNN Loss: 5.659687042236328 | BCE Loss: 1.0330116748809814
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 6.709036350250244 | KNN Loss: 5.644835472106934 | BCE Loss: 1.0642008781433105
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 6.69319486618042 | KNN Loss: 5.628274917602539 | BCE Loss: 1.0649198293685913
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 6.7110772132873535 | KNN Loss: 5.625705242156982 | BCE Loss: 1.085371971130371
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 6.660208225250244 | KNN Loss: 5.618839740753174 | BCE Loss: 1.0413684844970703
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 6.687494277954102 | KNN Loss: 5.6368865966796875 | BCE Lo

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 6.667268753051758 | KNN Loss: 5.6239237785339355 | BCE Loss: 1.0433449745178223
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 6.678684711456299 | KNN Loss: 5.625412464141846 | BCE Loss: 1.0532722473144531
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 6.744381904602051 | KNN Loss: 5.686888694763184 | BCE Loss: 1.0574932098388672
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 6.666080474853516 | KNN Loss: 5.619416236877441 | BCE Loss: 1.0466644763946533
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 6.723796367645264 | KNN Loss: 5.659749984741211 | BCE Loss: 1.0640463829040527
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 6.65497350692749 | KNN Loss: 5.604964733123779 | BCE Loss: 1.0500088930130005
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 6.676738739013672 | KNN Loss: 5.629705905914307 | BCE Loss: 1.0470328330993652
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 6.687737941741943 | KNN Loss: 5.671445846557617 | BCE Loss

Epoch 64 / 500 | iteration 25 / 30 | Total Loss: 6.685281753540039 | KNN Loss: 5.612063407897949 | BCE Loss: 1.0732183456420898
Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 6.726638317108154 | KNN Loss: 5.666286945343018 | BCE Loss: 1.0603513717651367
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 6.692787170410156 | KNN Loss: 5.638363361358643 | BCE Loss: 1.0544236898422241
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 6.717414855957031 | KNN Loss: 5.65566873550415 | BCE Loss: 1.06174635887146
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 6.657205104827881 | KNN Loss: 5.624779224395752 | BCE Loss: 1.032425880432129
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 6.687451362609863 | KNN Loss: 5.6171555519104 | BCE Loss: 1.0702956914901733
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 6.72709846496582 | KNN Loss: 5.661682605743408 | BCE Loss: 1.065415620803833
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 6.644546985626221 | KNN Loss: 5.6008830070495605 | BCE Loss: 1.043

Epoch 75 / 500 | iteration 15 / 30 | Total Loss: 6.773924827575684 | KNN Loss: 5.709336757659912 | BCE Loss: 1.0645878314971924
Epoch 75 / 500 | iteration 20 / 30 | Total Loss: 6.703548908233643 | KNN Loss: 5.641159534454346 | BCE Loss: 1.0623892545700073
Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 6.657436847686768 | KNN Loss: 5.605563640594482 | BCE Loss: 1.0518732070922852
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 6.646927833557129 | KNN Loss: 5.598230838775635 | BCE Loss: 1.0486972332000732
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 6.710811614990234 | KNN Loss: 5.649511337280273 | BCE Loss: 1.06130051612854
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 6.77372932434082 | KNN Loss: 5.7055559158325195 | BCE Loss: 1.0681736469268799
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 6.673055171966553 | KNN Loss: 5.621921062469482 | BCE Loss: 1.0511339902877808
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 6.67494010925293 | KNN Loss: 5.614178657531738 | BCE Loss: 

Epoch 86 / 500 | iteration 5 / 30 | Total Loss: 6.7097487449646 | KNN Loss: 5.640717506408691 | BCE Loss: 1.0690313577651978
Epoch 86 / 500 | iteration 10 / 30 | Total Loss: 6.644429683685303 | KNN Loss: 5.5963006019592285 | BCE Loss: 1.0481290817260742
Epoch 86 / 500 | iteration 15 / 30 | Total Loss: 6.676712989807129 | KNN Loss: 5.64622688293457 | BCE Loss: 1.0304858684539795
Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 6.676248550415039 | KNN Loss: 5.6216278076171875 | BCE Loss: 1.0546207427978516
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 6.660740852355957 | KNN Loss: 5.619797706604004 | BCE Loss: 1.040942907333374
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 6.682594299316406 | KNN Loss: 5.64323091506958 | BCE Loss: 1.0393636226654053
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 6.641056060791016 | KNN Loss: 5.615077495574951 | BCE Loss: 1.0259785652160645
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 6.66670036315918 | KNN Loss: 5.618285179138184 | BCE Loss: 1.

Epoch 96 / 500 | iteration 25 / 30 | Total Loss: 6.675928115844727 | KNN Loss: 5.6319499015808105 | BCE Loss: 1.043978214263916
Epoch 97 / 500 | iteration 0 / 30 | Total Loss: 6.682690143585205 | KNN Loss: 5.6225786209106445 | BCE Loss: 1.060111403465271
Epoch 97 / 500 | iteration 5 / 30 | Total Loss: 6.8442206382751465 | KNN Loss: 5.748156547546387 | BCE Loss: 1.0960640907287598
Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 6.668389320373535 | KNN Loss: 5.611451148986816 | BCE Loss: 1.0569381713867188
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 6.693245887756348 | KNN Loss: 5.627976417541504 | BCE Loss: 1.0652692317962646
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 6.662095069885254 | KNN Loss: 5.625235080718994 | BCE Loss: 1.0368598699569702
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 6.646385192871094 | KNN Loss: 5.615398406982422 | BCE Loss: 1.030987024307251
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 6.7444658279418945 | KNN Loss: 5.680598735809326 | BCE Los

Epoch 107 / 500 | iteration 15 / 30 | Total Loss: 6.733925819396973 | KNN Loss: 5.682546138763428 | BCE Loss: 1.0513794422149658
Epoch 107 / 500 | iteration 20 / 30 | Total Loss: 6.773036479949951 | KNN Loss: 5.717731475830078 | BCE Loss: 1.0553048849105835
Epoch 107 / 500 | iteration 25 / 30 | Total Loss: 6.671138763427734 | KNN Loss: 5.612814903259277 | BCE Loss: 1.0583237409591675
Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 6.693307876586914 | KNN Loss: 5.627022743225098 | BCE Loss: 1.0662853717803955
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 6.632239818572998 | KNN Loss: 5.6046319007873535 | BCE Loss: 1.027607798576355
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 6.6662702560424805 | KNN Loss: 5.615923881530762 | BCE Loss: 1.0503461360931396
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 6.706708908081055 | KNN Loss: 5.655028343200684 | BCE Loss: 1.051680326461792
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 6.68624210357666 | KNN Loss: 5.662965774536133 | 

Epoch 118 / 500 | iteration 5 / 30 | Total Loss: 6.690972328186035 | KNN Loss: 5.633313179016113 | BCE Loss: 1.057659387588501
Epoch 118 / 500 | iteration 10 / 30 | Total Loss: 6.655465602874756 | KNN Loss: 5.597350597381592 | BCE Loss: 1.0581148862838745
Epoch 118 / 500 | iteration 15 / 30 | Total Loss: 6.643095970153809 | KNN Loss: 5.589795112609863 | BCE Loss: 1.0533010959625244
Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 6.687753677368164 | KNN Loss: 5.645359992980957 | BCE Loss: 1.0423938035964966
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 6.676441192626953 | KNN Loss: 5.600766181945801 | BCE Loss: 1.0756747722625732
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 6.698936462402344 | KNN Loss: 5.625314235687256 | BCE Loss: 1.073622465133667
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 6.703783988952637 | KNN Loss: 5.646705150604248 | BCE Loss: 1.0570785999298096
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 6.687049388885498 | KNN Loss: 5.648042678833008 | BC

Epoch 128 / 500 | iteration 25 / 30 | Total Loss: 6.705039978027344 | KNN Loss: 5.68022346496582 | BCE Loss: 1.0248167514801025
Epoch   129: reducing learning rate of group 0 to 4.1177e-04.
Epoch 129 / 500 | iteration 0 / 30 | Total Loss: 6.657160758972168 | KNN Loss: 5.610757827758789 | BCE Loss: 1.046403169631958
Epoch 129 / 500 | iteration 5 / 30 | Total Loss: 6.646566867828369 | KNN Loss: 5.61391019821167 | BCE Loss: 1.0326565504074097
Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 6.717144966125488 | KNN Loss: 5.698241233825684 | BCE Loss: 1.0189037322998047
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 6.739503860473633 | KNN Loss: 5.692935466766357 | BCE Loss: 1.046568512916565
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 6.642569541931152 | KNN Loss: 5.597662925720215 | BCE Loss: 1.044906497001648
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 6.669338226318359 | KNN Loss: 5.624062538146973 | BCE Loss: 1.0452758073806763
Epoch 130 / 500 | iteration 0 / 30 | Total

Epoch 139 / 500 | iteration 15 / 30 | Total Loss: 6.665410041809082 | KNN Loss: 5.608802795410156 | BCE Loss: 1.0566071271896362
Epoch 139 / 500 | iteration 20 / 30 | Total Loss: 6.633964538574219 | KNN Loss: 5.602878093719482 | BCE Loss: 1.0310863256454468
Epoch 139 / 500 | iteration 25 / 30 | Total Loss: 6.724414825439453 | KNN Loss: 5.664705276489258 | BCE Loss: 1.0597097873687744
Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 6.637143611907959 | KNN Loss: 5.619965076446533 | BCE Loss: 1.0171785354614258
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 6.742663860321045 | KNN Loss: 5.677785396575928 | BCE Loss: 1.0648784637451172
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 6.692472457885742 | KNN Loss: 5.632762908935547 | BCE Loss: 1.0597093105316162
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 6.73406982421875 | KNN Loss: 5.661208152770996 | BCE Loss: 1.072861909866333
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 6.7213850021362305 | KNN Loss: 5.658390045166016 | 

Epoch 150 / 500 | iteration 5 / 30 | Total Loss: 6.680704116821289 | KNN Loss: 5.618550777435303 | BCE Loss: 1.0621533393859863
Epoch 150 / 500 | iteration 10 / 30 | Total Loss: 6.6456403732299805 | KNN Loss: 5.617905616760254 | BCE Loss: 1.0277347564697266
Epoch 150 / 500 | iteration 15 / 30 | Total Loss: 6.703197002410889 | KNN Loss: 5.628095626831055 | BCE Loss: 1.0751012563705444
Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 6.682847499847412 | KNN Loss: 5.633007526397705 | BCE Loss: 1.049839973449707
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 6.686898231506348 | KNN Loss: 5.640496730804443 | BCE Loss: 1.0464016199111938
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 6.712689399719238 | KNN Loss: 5.67117977142334 | BCE Loss: 1.0415095090866089
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 6.645291328430176 | KNN Loss: 5.596818923950195 | BCE Loss: 1.0484721660614014
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 6.711159706115723 | KNN Loss: 5.631913661956787 | B

Epoch 160 / 500 | iteration 25 / 30 | Total Loss: 6.612914562225342 | KNN Loss: 5.594588279724121 | BCE Loss: 1.0183262825012207
Epoch 161 / 500 | iteration 0 / 30 | Total Loss: 6.6717939376831055 | KNN Loss: 5.6015143394470215 | BCE Loss: 1.0702794790267944
Epoch 161 / 500 | iteration 5 / 30 | Total Loss: 6.6479411125183105 | KNN Loss: 5.614067077636719 | BCE Loss: 1.0338740348815918
Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 6.656642913818359 | KNN Loss: 5.597691059112549 | BCE Loss: 1.0589516162872314
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 6.7353386878967285 | KNN Loss: 5.640489101409912 | BCE Loss: 1.0948495864868164
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 6.638606071472168 | KNN Loss: 5.623013019561768 | BCE Loss: 1.0155930519104004
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 6.6692023277282715 | KNN Loss: 5.608076095581055 | BCE Loss: 1.0611262321472168
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 6.699557781219482 | KNN Loss: 5.6558203697204

Epoch 171 / 500 | iteration 15 / 30 | Total Loss: 6.67775821685791 | KNN Loss: 5.641509532928467 | BCE Loss: 1.036248803138733
Epoch 171 / 500 | iteration 20 / 30 | Total Loss: 6.671025276184082 | KNN Loss: 5.6101555824279785 | BCE Loss: 1.0608699321746826
Epoch 171 / 500 | iteration 25 / 30 | Total Loss: 6.734766006469727 | KNN Loss: 5.6900787353515625 | BCE Loss: 1.044687032699585
Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 6.689410209655762 | KNN Loss: 5.611419200897217 | BCE Loss: 1.0779911279678345
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 6.635105133056641 | KNN Loss: 5.596543788909912 | BCE Loss: 1.0385611057281494
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 6.841293811798096 | KNN Loss: 5.768370151519775 | BCE Loss: 1.0729236602783203
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 6.74234676361084 | KNN Loss: 5.680455684661865 | BCE Loss: 1.0618908405303955
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 6.657424449920654 | KNN Loss: 5.617287635803223 | B

Epoch 182 / 500 | iteration 5 / 30 | Total Loss: 6.651412010192871 | KNN Loss: 5.6018195152282715 | BCE Loss: 1.0495922565460205
Epoch 182 / 500 | iteration 10 / 30 | Total Loss: 6.647519111633301 | KNN Loss: 5.61251163482666 | BCE Loss: 1.0350077152252197
Epoch 182 / 500 | iteration 15 / 30 | Total Loss: 6.702489852905273 | KNN Loss: 5.669201374053955 | BCE Loss: 1.0332882404327393
Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 6.726588726043701 | KNN Loss: 5.670341491699219 | BCE Loss: 1.0562472343444824
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 6.694048881530762 | KNN Loss: 5.6458048820495605 | BCE Loss: 1.048243761062622
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 6.750754356384277 | KNN Loss: 5.671268939971924 | BCE Loss: 1.0794851779937744
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 6.6449079513549805 | KNN Loss: 5.614165782928467 | BCE Loss: 1.0307421684265137
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 6.756049633026123 | KNN Loss: 5.684309482574463 |

Epoch 192 / 500 | iteration 25 / 30 | Total Loss: 6.677616119384766 | KNN Loss: 5.613646030426025 | BCE Loss: 1.0639702081680298
Epoch 193 / 500 | iteration 0 / 30 | Total Loss: 6.707725524902344 | KNN Loss: 5.670441627502441 | BCE Loss: 1.0372836589813232
Epoch 193 / 500 | iteration 5 / 30 | Total Loss: 6.633491039276123 | KNN Loss: 5.600869178771973 | BCE Loss: 1.03262197971344
Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 6.727217674255371 | KNN Loss: 5.666467666625977 | BCE Loss: 1.0607502460479736
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 6.638514518737793 | KNN Loss: 5.6036810874938965 | BCE Loss: 1.034833550453186
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 6.653772830963135 | KNN Loss: 5.6080403327941895 | BCE Loss: 1.0457323789596558
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 6.69126558303833 | KNN Loss: 5.625704765319824 | BCE Loss: 1.0655606985092163
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 6.671272277832031 | KNN Loss: 5.638047218322754 | BC

Epoch 203 / 500 | iteration 15 / 30 | Total Loss: 6.724959373474121 | KNN Loss: 5.652915000915527 | BCE Loss: 1.0720441341400146
Epoch 203 / 500 | iteration 20 / 30 | Total Loss: 6.6394805908203125 | KNN Loss: 5.609781265258789 | BCE Loss: 1.0296993255615234
Epoch 203 / 500 | iteration 25 / 30 | Total Loss: 6.677958011627197 | KNN Loss: 5.6209635734558105 | BCE Loss: 1.0569944381713867
Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 6.652341842651367 | KNN Loss: 5.620996952056885 | BCE Loss: 1.0313451290130615
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 6.716287136077881 | KNN Loss: 5.656905651092529 | BCE Loss: 1.0593814849853516
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 6.696322441101074 | KNN Loss: 5.644847393035889 | BCE Loss: 1.0514748096466064
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 6.734745502471924 | KNN Loss: 5.676151752471924 | BCE Loss: 1.0585936307907104
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 6.672548294067383 | KNN Loss: 5.610565185546875

Epoch 214 / 500 | iteration 5 / 30 | Total Loss: 6.7908759117126465 | KNN Loss: 5.709783554077148 | BCE Loss: 1.081092357635498
Epoch 214 / 500 | iteration 10 / 30 | Total Loss: 6.6126627922058105 | KNN Loss: 5.599368095397949 | BCE Loss: 1.0132948160171509
Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 6.672238826751709 | KNN Loss: 5.605334281921387 | BCE Loss: 1.0669045448303223
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 6.667666435241699 | KNN Loss: 5.623753070831299 | BCE Loss: 1.0439136028289795
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 6.726398944854736 | KNN Loss: 5.660894870758057 | BCE Loss: 1.0655040740966797
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 6.646176338195801 | KNN Loss: 5.617110252380371 | BCE Loss: 1.0290660858154297
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 6.659642219543457 | KNN Loss: 5.6121416091918945 | BCE Loss: 1.0475003719329834
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 6.707139015197754 | KNN Loss: 5.6739630699157715

Epoch 224 / 500 | iteration 25 / 30 | Total Loss: 6.663578987121582 | KNN Loss: 5.6256513595581055 | BCE Loss: 1.0379278659820557
Epoch 225 / 500 | iteration 0 / 30 | Total Loss: 6.642977714538574 | KNN Loss: 5.610443115234375 | BCE Loss: 1.0325344800949097
Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 6.660835266113281 | KNN Loss: 5.607780456542969 | BCE Loss: 1.053054928779602
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 6.733063220977783 | KNN Loss: 5.67727518081665 | BCE Loss: 1.0557880401611328
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 6.730774879455566 | KNN Loss: 5.68580436706543 | BCE Loss: 1.0449705123901367
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 6.6830949783325195 | KNN Loss: 5.656432151794434 | BCE Loss: 1.026663064956665
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 6.7849836349487305 | KNN Loss: 5.719709396362305 | BCE Loss: 1.0652741193771362
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 6.803089141845703 | KNN Loss: 5.719107151031494 | B

Epoch 235 / 500 | iteration 15 / 30 | Total Loss: 6.691985607147217 | KNN Loss: 5.617237567901611 | BCE Loss: 1.074748158454895
Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 6.690059661865234 | KNN Loss: 5.649933815002441 | BCE Loss: 1.0401256084442139
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 6.705994606018066 | KNN Loss: 5.646218776702881 | BCE Loss: 1.0597758293151855
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 6.616951942443848 | KNN Loss: 5.601201057434082 | BCE Loss: 1.0157508850097656
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 6.677674293518066 | KNN Loss: 5.629677772521973 | BCE Loss: 1.0479965209960938
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 6.647692680358887 | KNN Loss: 5.608232021331787 | BCE Loss: 1.0394604206085205
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 6.667832851409912 | KNN Loss: 5.644240379333496 | BCE Loss: 1.023592472076416
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 6.713674545288086 | KNN Loss: 5.65777063369751 | BC

Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 6.6944146156311035 | KNN Loss: 5.655323028564453 | BCE Loss: 1.0390915870666504
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 6.670848846435547 | KNN Loss: 5.634727478027344 | BCE Loss: 1.0361213684082031
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 6.644354820251465 | KNN Loss: 5.612605094909668 | BCE Loss: 1.0317498445510864
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 6.702712535858154 | KNN Loss: 5.651481628417969 | BCE Loss: 1.0512309074401855
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 6.640658855438232 | KNN Loss: 5.616275310516357 | BCE Loss: 1.0243836641311646
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 6.682989597320557 | KNN Loss: 5.650564193725586 | BCE Loss: 1.0324254035949707
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 6.63939905166626 | KNN Loss: 5.599682807922363 | BCE Loss: 1.0397162437438965
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 6.716461658477783 | KNN Loss: 5.66948127746582 | B

Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 6.657812118530273 | KNN Loss: 5.602855205535889 | BCE Loss: 1.0549571514129639
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 6.6382341384887695 | KNN Loss: 5.605401039123535 | BCE Loss: 1.032833218574524
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 6.656836032867432 | KNN Loss: 5.612751007080078 | BCE Loss: 1.0440850257873535
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 6.648529052734375 | KNN Loss: 5.615478515625 | BCE Loss: 1.033050298690796
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 6.6168999671936035 | KNN Loss: 5.607003211975098 | BCE Loss: 1.0098967552185059
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 6.6676435470581055 | KNN Loss: 5.622049808502197 | BCE Loss: 1.0455937385559082
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 6.68458890914917 | KNN Loss: 5.627805233001709 | BCE Loss: 1.0567837953567505
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 6.723412036895752 | KNN Loss: 5.6740498542785645 | BC

Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 6.672837257385254 | KNN Loss: 5.607428073883057 | BCE Loss: 1.0654094219207764
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 6.648698806762695 | KNN Loss: 5.599656105041504 | BCE Loss: 1.0490429401397705
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 6.69840145111084 | KNN Loss: 5.632015705108643 | BCE Loss: 1.0663855075836182
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 6.664953231811523 | KNN Loss: 5.635253429412842 | BCE Loss: 1.0297000408172607
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 6.713943958282471 | KNN Loss: 5.678980827331543 | BCE Loss: 1.0349632501602173
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 6.682404041290283 | KNN Loss: 5.6667962074279785 | BCE Loss: 1.0156078338623047
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 6.651074409484863 | KNN Loss: 5.612020969390869 | BCE Loss: 1.039053201675415
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 6.6873297691345215 | KNN Loss: 5.620919227600098 |

Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 6.659604072570801 | KNN Loss: 5.625123500823975 | BCE Loss: 1.0344805717468262
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 6.674835205078125 | KNN Loss: 5.6158318519592285 | BCE Loss: 1.0590031147003174
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 6.741918563842773 | KNN Loss: 5.694171905517578 | BCE Loss: 1.0477468967437744
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 6.673861026763916 | KNN Loss: 5.608985424041748 | BCE Loss: 1.064875602722168
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 6.650136947631836 | KNN Loss: 5.604300022125244 | BCE Loss: 1.0458369255065918
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 6.615821838378906 | KNN Loss: 5.6042609214782715 | BCE Loss: 1.0115611553192139
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 6.648772239685059 | KNN Loss: 5.602992534637451 | BCE Loss: 1.0457799434661865
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 6.711928844451904 | KNN Loss: 5.639384746551514 |

Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 6.695335388183594 | KNN Loss: 5.658448219299316 | BCE Loss: 1.0368874073028564
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 6.712217330932617 | KNN Loss: 5.66303825378418 | BCE Loss: 1.0491788387298584
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 6.705946922302246 | KNN Loss: 5.654061317443848 | BCE Loss: 1.0518854856491089
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 6.682750701904297 | KNN Loss: 5.609925270080566 | BCE Loss: 1.0728254318237305
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 6.662166118621826 | KNN Loss: 5.616677284240723 | BCE Loss: 1.0454888343811035
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 6.689208030700684 | KNN Loss: 5.65395975112915 | BCE Loss: 1.035248041152954
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 6.703148365020752 | KNN Loss: 5.64250373840332 | BCE Loss: 1.0606447458267212
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 6.661749362945557 | KNN Loss: 5.609600067138672 | BCE 

Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 6.700072288513184 | KNN Loss: 5.637031555175781 | BCE Loss: 1.0630409717559814
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 6.698983669281006 | KNN Loss: 5.643558502197266 | BCE Loss: 1.0554251670837402
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 6.674539566040039 | KNN Loss: 5.618405342102051 | BCE Loss: 1.0561344623565674
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 6.68757438659668 | KNN Loss: 5.665266990661621 | BCE Loss: 1.0223076343536377
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 6.6887359619140625 | KNN Loss: 5.615756511688232 | BCE Loss: 1.0729796886444092
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 6.677878379821777 | KNN Loss: 5.636390209197998 | BCE Loss: 1.0414880514144897
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 6.664612293243408 | KNN Loss: 5.626784801483154 | BCE Loss: 1.037827491760254
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 6.682285308837891 | KNN Loss: 5.617672443389893 | 

Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 6.699848175048828 | KNN Loss: 5.639166831970215 | BCE Loss: 1.0606813430786133
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 6.656959056854248 | KNN Loss: 5.60576057434082 | BCE Loss: 1.0511984825134277
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 6.655472278594971 | KNN Loss: 5.604700565338135 | BCE Loss: 1.050771713256836
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 6.6800689697265625 | KNN Loss: 5.610294342041016 | BCE Loss: 1.0697745084762573
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 6.683248519897461 | KNN Loss: 5.634279727935791 | BCE Loss: 1.048969030380249
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 6.658801078796387 | KNN Loss: 5.616293907165527 | BCE Loss: 1.0425069332122803
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 6.645790100097656 | KNN Loss: 5.608813285827637 | BCE Loss: 1.0369765758514404
Epoch 311 / 500 | iteration 10 / 30 | Total Loss: 6.67181396484375 | KNN Loss: 5.612786293029785 | BCE

Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 6.68550968170166 | KNN Loss: 5.6514506340026855 | BCE Loss: 1.0340590476989746
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 6.697296142578125 | KNN Loss: 5.636620998382568 | BCE Loss: 1.0606751441955566
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 6.6705427169799805 | KNN Loss: 5.61781644821167 | BCE Loss: 1.052726149559021
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 6.667546272277832 | KNN Loss: 5.653421401977539 | BCE Loss: 1.0141249895095825
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 6.659872055053711 | KNN Loss: 5.611155033111572 | BCE Loss: 1.0487172603607178
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 6.671212196350098 | KNN Loss: 5.612053394317627 | BCE Loss: 1.0591585636138916
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 6.650148391723633 | KNN Loss: 5.625871658325195 | BCE Loss: 1.024276614189148
Epoch 322 / 500 | iteration 0 / 30 | Total Loss: 6.645895481109619 | KNN Loss: 5.631824016571045 | BC

Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 6.640481948852539 | KNN Loss: 5.621628761291504 | BCE Loss: 1.0188534259796143
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 6.6848273277282715 | KNN Loss: 5.608381271362305 | BCE Loss: 1.0764459371566772
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 6.666048049926758 | KNN Loss: 5.645476818084717 | BCE Loss: 1.0205714702606201
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 6.643307685852051 | KNN Loss: 5.614668369293213 | BCE Loss: 1.028639316558838
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 6.653417110443115 | KNN Loss: 5.599168300628662 | BCE Loss: 1.0542488098144531
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 6.667682647705078 | KNN Loss: 5.597994327545166 | BCE Loss: 1.0696882009506226
Epoch 332 / 500 | iteration 15 / 30 | Total Loss: 6.632024765014648 | KNN Loss: 5.605566501617432 | BCE Loss: 1.0264580249786377
Epoch 332 / 500 | iteration 20 / 30 | Total Loss: 6.685467720031738 | KNN Loss: 5.640016078948975 |

Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 6.702144622802734 | KNN Loss: 5.638906002044678 | BCE Loss: 1.0632383823394775
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 6.697080135345459 | KNN Loss: 5.62741231918335 | BCE Loss: 1.0696678161621094
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 6.746339797973633 | KNN Loss: 5.701863765716553 | BCE Loss: 1.04447603225708
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 6.698399543762207 | KNN Loss: 5.645187854766846 | BCE Loss: 1.0532116889953613
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 6.739989280700684 | KNN Loss: 5.704217433929443 | BCE Loss: 1.0357716083526611
Epoch 343 / 500 | iteration 0 / 30 | Total Loss: 6.669395923614502 | KNN Loss: 5.614299297332764 | BCE Loss: 1.0550966262817383
Epoch 343 / 500 | iteration 5 / 30 | Total Loss: 6.655205726623535 | KNN Loss: 5.639425754547119 | BCE Loss: 1.015779972076416
Epoch 343 / 500 | iteration 10 / 30 | Total Loss: 6.658651828765869 | KNN Loss: 5.630436420440674 | BCE 

Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 6.695579528808594 | KNN Loss: 5.6214704513549805 | BCE Loss: 1.0741093158721924
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 6.737795352935791 | KNN Loss: 5.681590557098389 | BCE Loss: 1.0562047958374023
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 6.747066974639893 | KNN Loss: 5.67075777053833 | BCE Loss: 1.076309084892273
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 6.6765851974487305 | KNN Loss: 5.6425042152404785 | BCE Loss: 1.034081220626831
Epoch 353 / 500 | iteration 15 / 30 | Total Loss: 6.687179088592529 | KNN Loss: 5.636568546295166 | BCE Loss: 1.0506104230880737
Epoch 353 / 500 | iteration 20 / 30 | Total Loss: 6.628294944763184 | KNN Loss: 5.6126389503479 | BCE Loss: 1.0156559944152832
Epoch 353 / 500 | iteration 25 / 30 | Total Loss: 6.666232109069824 | KNN Loss: 5.623336315155029 | BCE Loss: 1.042895793914795
Epoch 354 / 500 | iteration 0 / 30 | Total Loss: 6.660545825958252 | KNN Loss: 5.62753963470459 | BCE 

Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 6.684708118438721 | KNN Loss: 5.643117904663086 | BCE Loss: 1.0415902137756348
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 6.659982681274414 | KNN Loss: 5.632632732391357 | BCE Loss: 1.0273501873016357
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 6.689106464385986 | KNN Loss: 5.635043621063232 | BCE Loss: 1.0540629625320435
Epoch 364 / 500 | iteration 0 / 30 | Total Loss: 6.699442386627197 | KNN Loss: 5.647862911224365 | BCE Loss: 1.051579475402832
Epoch 364 / 500 | iteration 5 / 30 | Total Loss: 6.638387680053711 | KNN Loss: 5.6126708984375 | BCE Loss: 1.025716781616211
Epoch 364 / 500 | iteration 10 / 30 | Total Loss: 6.709596633911133 | KNN Loss: 5.636284351348877 | BCE Loss: 1.073312520980835
Epoch 364 / 500 | iteration 15 / 30 | Total Loss: 6.690317630767822 | KNN Loss: 5.64185094833374 | BCE Loss: 1.0484668016433716
Epoch 364 / 500 | iteration 20 / 30 | Total Loss: 6.731231689453125 | KNN Loss: 5.649199962615967 | BCE L

Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 6.657561302185059 | KNN Loss: 5.61644983291626 | BCE Loss: 1.041111707687378
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 6.661438941955566 | KNN Loss: 5.606329441070557 | BCE Loss: 1.0551093816757202
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 6.664345741271973 | KNN Loss: 5.604970455169678 | BCE Loss: 1.059375286102295
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 6.641066551208496 | KNN Loss: 5.618422031402588 | BCE Loss: 1.0226445198059082
Epoch 374 / 500 | iteration 25 / 30 | Total Loss: 6.684414863586426 | KNN Loss: 5.6556782722473145 | BCE Loss: 1.0287363529205322
Epoch 375 / 500 | iteration 0 / 30 | Total Loss: 6.700927257537842 | KNN Loss: 5.642199993133545 | BCE Loss: 1.0587271451950073
Epoch 375 / 500 | iteration 5 / 30 | Total Loss: 6.705483436584473 | KNN Loss: 5.661896705627441 | BCE Loss: 1.0435867309570312
Epoch 375 / 500 | iteration 10 / 30 | Total Loss: 6.67356014251709 | KNN Loss: 5.643574237823486 | BCE

Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 6.713467597961426 | KNN Loss: 5.680002212524414 | BCE Loss: 1.0334655046463013
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 6.670261383056641 | KNN Loss: 5.607552528381348 | BCE Loss: 1.062709093093872
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 6.681747913360596 | KNN Loss: 5.635590553283691 | BCE Loss: 1.0461574792861938
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 6.688969135284424 | KNN Loss: 5.64265775680542 | BCE Loss: 1.0463112592697144
Epoch 385 / 500 | iteration 15 / 30 | Total Loss: 6.640787124633789 | KNN Loss: 5.6023454666137695 | BCE Loss: 1.0384414196014404
Epoch 385 / 500 | iteration 20 / 30 | Total Loss: 6.720127582550049 | KNN Loss: 5.6583356857299805 | BCE Loss: 1.0617918968200684
Epoch 385 / 500 | iteration 25 / 30 | Total Loss: 6.681565284729004 | KNN Loss: 5.5961012840271 | BCE Loss: 1.0854637622833252
Epoch 386 / 500 | iteration 0 / 30 | Total Loss: 6.627669334411621 | KNN Loss: 5.597972393035889 | BC

Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 6.6987128257751465 | KNN Loss: 5.642199993133545 | BCE Loss: 1.056512713432312
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 6.684201240539551 | KNN Loss: 5.613606929779053 | BCE Loss: 1.070594310760498
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 6.676967620849609 | KNN Loss: 5.63129997253418 | BCE Loss: 1.0456678867340088
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 6.698118209838867 | KNN Loss: 5.6217241287231445 | BCE Loss: 1.0763940811157227
Epoch 396 / 500 | iteration 5 / 30 | Total Loss: 6.666566848754883 | KNN Loss: 5.608283996582031 | BCE Loss: 1.0582828521728516
Epoch 396 / 500 | iteration 10 / 30 | Total Loss: 6.641812324523926 | KNN Loss: 5.6164045333862305 | BCE Loss: 1.0254075527191162
Epoch 396 / 500 | iteration 15 / 30 | Total Loss: 6.696036338806152 | KNN Loss: 5.622082233428955 | BCE Loss: 1.0739542245864868
Epoch 396 / 500 | iteration 20 / 30 | Total Loss: 6.692816257476807 | KNN Loss: 5.640416145324707 |

Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 6.660782814025879 | KNN Loss: 5.630023002624512 | BCE Loss: 1.0307598114013672
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 6.725547790527344 | KNN Loss: 5.635313510894775 | BCE Loss: 1.090234398841858
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 6.670483589172363 | KNN Loss: 5.610896587371826 | BCE Loss: 1.0595868825912476
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 6.642815113067627 | KNN Loss: 5.604629993438721 | BCE Loss: 1.0381850004196167
Epoch 406 / 500 | iteration 25 / 30 | Total Loss: 6.663578510284424 | KNN Loss: 5.64147424697876 | BCE Loss: 1.0221043825149536
Epoch 407 / 500 | iteration 0 / 30 | Total Loss: 6.687146186828613 | KNN Loss: 5.636116027832031 | BCE Loss: 1.0510303974151611
Epoch 407 / 500 | iteration 5 / 30 | Total Loss: 6.680265426635742 | KNN Loss: 5.6080498695373535 | BCE Loss: 1.0722153186798096
Epoch 407 / 500 | iteration 10 / 30 | Total Loss: 6.681831359863281 | KNN Loss: 5.6144609451293945 | 

Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 6.662468433380127 | KNN Loss: 5.6211724281311035 | BCE Loss: 1.0412960052490234
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 6.740629196166992 | KNN Loss: 5.683051586151123 | BCE Loss: 1.0575777292251587
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 6.644497871398926 | KNN Loss: 5.602814197540283 | BCE Loss: 1.0416839122772217
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 6.684334754943848 | KNN Loss: 5.643134593963623 | BCE Loss: 1.0411999225616455
Epoch 417 / 500 | iteration 15 / 30 | Total Loss: 6.765189170837402 | KNN Loss: 5.70035982131958 | BCE Loss: 1.0648292303085327
Epoch 417 / 500 | iteration 20 / 30 | Total Loss: 6.679823398590088 | KNN Loss: 5.6223673820495605 | BCE Loss: 1.0574560165405273
Epoch 417 / 500 | iteration 25 / 30 | Total Loss: 6.6344170570373535 | KNN Loss: 5.600895404815674 | BCE Loss: 1.0335215330123901
Epoch 418 / 500 | iteration 0 / 30 | Total Loss: 6.668328285217285 | KNN Loss: 5.633907318115234 

Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 6.734557151794434 | KNN Loss: 5.708924293518066 | BCE Loss: 1.0256330966949463
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 6.654829025268555 | KNN Loss: 5.6163649559021 | BCE Loss: 1.0384643077850342
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 6.65602970123291 | KNN Loss: 5.600555896759033 | BCE Loss: 1.0554739236831665
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 6.622618198394775 | KNN Loss: 5.600786209106445 | BCE Loss: 1.0218318700790405
Epoch 428 / 500 | iteration 5 / 30 | Total Loss: 6.678993225097656 | KNN Loss: 5.610630989074707 | BCE Loss: 1.0683622360229492
Epoch 428 / 500 | iteration 10 / 30 | Total Loss: 6.718871593475342 | KNN Loss: 5.6543660163879395 | BCE Loss: 1.0645055770874023
Epoch 428 / 500 | iteration 15 / 30 | Total Loss: 6.674161434173584 | KNN Loss: 5.634287357330322 | BCE Loss: 1.0398739576339722
Epoch 428 / 500 | iteration 20 / 30 | Total Loss: 6.674372673034668 | KNN Loss: 5.616437911987305 | B

Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 6.7002787590026855 | KNN Loss: 5.645789623260498 | BCE Loss: 1.0544891357421875
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 6.699611663818359 | KNN Loss: 5.6629557609558105 | BCE Loss: 1.0366556644439697
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 6.684426307678223 | KNN Loss: 5.6183881759643555 | BCE Loss: 1.0660383701324463
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 6.675989151000977 | KNN Loss: 5.623812675476074 | BCE Loss: 1.0521767139434814
Epoch 438 / 500 | iteration 25 / 30 | Total Loss: 6.678472518920898 | KNN Loss: 5.646635055541992 | BCE Loss: 1.0318377017974854
Epoch 439 / 500 | iteration 0 / 30 | Total Loss: 6.706539154052734 | KNN Loss: 5.634246349334717 | BCE Loss: 1.0722928047180176
Epoch 439 / 500 | iteration 5 / 30 | Total Loss: 6.640414714813232 | KNN Loss: 5.60800838470459 | BCE Loss: 1.0324063301086426
Epoch 439 / 500 | iteration 10 / 30 | Total Loss: 6.63592529296875 | KNN Loss: 5.605398178100586 |

Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 6.693940162658691 | KNN Loss: 5.643997669219971 | BCE Loss: 1.0499427318572998
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 6.675162315368652 | KNN Loss: 5.601166248321533 | BCE Loss: 1.07399582862854
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 6.631659984588623 | KNN Loss: 5.604273319244385 | BCE Loss: 1.0273866653442383
Epoch 449 / 500 | iteration 10 / 30 | Total Loss: 6.649306297302246 | KNN Loss: 5.607953071594238 | BCE Loss: 1.0413532257080078
Epoch 449 / 500 | iteration 15 / 30 | Total Loss: 6.697667598724365 | KNN Loss: 5.652248859405518 | BCE Loss: 1.0454188585281372
Epoch 449 / 500 | iteration 20 / 30 | Total Loss: 6.626442909240723 | KNN Loss: 5.607489585876465 | BCE Loss: 1.0189530849456787
Epoch 449 / 500 | iteration 25 / 30 | Total Loss: 6.717464447021484 | KNN Loss: 5.687590599060059 | BCE Loss: 1.0298739671707153
Epoch 450 / 500 | iteration 0 / 30 | Total Loss: 6.659090042114258 | KNN Loss: 5.6447601318359375 | B

Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 6.695256233215332 | KNN Loss: 5.636814117431641 | BCE Loss: 1.0584423542022705
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 6.671474933624268 | KNN Loss: 5.605099678039551 | BCE Loss: 1.0663752555847168
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 6.645374774932861 | KNN Loss: 5.601078033447266 | BCE Loss: 1.0442967414855957
Epoch 460 / 500 | iteration 0 / 30 | Total Loss: 6.737398624420166 | KNN Loss: 5.681279182434082 | BCE Loss: 1.056119441986084
Epoch 460 / 500 | iteration 5 / 30 | Total Loss: 6.646589279174805 | KNN Loss: 5.604992389678955 | BCE Loss: 1.0415971279144287
Epoch 460 / 500 | iteration 10 / 30 | Total Loss: 6.741371154785156 | KNN Loss: 5.673711776733398 | BCE Loss: 1.067659616470337
Epoch 460 / 500 | iteration 15 / 30 | Total Loss: 6.6905717849731445 | KNN Loss: 5.645541667938232 | BCE Loss: 1.045029878616333
Epoch 460 / 500 | iteration 20 / 30 | Total Loss: 6.740664482116699 | KNN Loss: 5.651845932006836 | B

Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 6.6612348556518555 | KNN Loss: 5.609503746032715 | BCE Loss: 1.0517313480377197
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 6.639632225036621 | KNN Loss: 5.60650634765625 | BCE Loss: 1.0331257581710815
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 6.713498115539551 | KNN Loss: 5.690204620361328 | BCE Loss: 1.0232934951782227
Epoch 470 / 500 | iteration 20 / 30 | Total Loss: 6.689447402954102 | KNN Loss: 5.658040523529053 | BCE Loss: 1.031407117843628
Epoch 470 / 500 | iteration 25 / 30 | Total Loss: 6.685724258422852 | KNN Loss: 5.6364359855651855 | BCE Loss: 1.0492883920669556
Epoch 471 / 500 | iteration 0 / 30 | Total Loss: 6.699162483215332 | KNN Loss: 5.618096828460693 | BCE Loss: 1.0810654163360596
Epoch 471 / 500 | iteration 5 / 30 | Total Loss: 6.6565656661987305 | KNN Loss: 5.613607406616211 | BCE Loss: 1.0429580211639404
Epoch 471 / 500 | iteration 10 / 30 | Total Loss: 6.690430164337158 | KNN Loss: 5.645288944244385 |

Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 6.666149616241455 | KNN Loss: 5.60941743850708 | BCE Loss: 1.056732177734375
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 6.658149719238281 | KNN Loss: 5.624458312988281 | BCE Loss: 1.0336915254592896
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 6.652418613433838 | KNN Loss: 5.60536527633667 | BCE Loss: 1.0470534563064575
Epoch 481 / 500 | iteration 10 / 30 | Total Loss: 6.755138397216797 | KNN Loss: 5.71406364440918 | BCE Loss: 1.0410747528076172
Epoch 481 / 500 | iteration 15 / 30 | Total Loss: 6.677190780639648 | KNN Loss: 5.620389461517334 | BCE Loss: 1.0568015575408936
Epoch 481 / 500 | iteration 20 / 30 | Total Loss: 6.673486709594727 | KNN Loss: 5.606107234954834 | BCE Loss: 1.0673794746398926
Epoch 481 / 500 | iteration 25 / 30 | Total Loss: 6.66376256942749 | KNN Loss: 5.61820650100708 | BCE Loss: 1.0455561876296997
Epoch 482 / 500 | iteration 0 / 30 | Total Loss: 6.731047630310059 | KNN Loss: 5.666060447692871 | BCE Lo

Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 6.668487548828125 | KNN Loss: 5.611492156982422 | BCE Loss: 1.0569953918457031
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 6.717453956604004 | KNN Loss: 5.6381916999816895 | BCE Loss: 1.0792620182037354
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 6.6683526039123535 | KNN Loss: 5.634644031524658 | BCE Loss: 1.0337084531784058
Epoch 492 / 500 | iteration 0 / 30 | Total Loss: 6.681865692138672 | KNN Loss: 5.645390033721924 | BCE Loss: 1.0364757776260376
Epoch 492 / 500 | iteration 5 / 30 | Total Loss: 6.676312446594238 | KNN Loss: 5.6179609298706055 | BCE Loss: 1.0583512783050537
Epoch 492 / 500 | iteration 10 / 30 | Total Loss: 6.69525146484375 | KNN Loss: 5.647629737854004 | BCE Loss: 1.047621488571167
Epoch 492 / 500 | iteration 15 / 30 | Total Loss: 6.663219451904297 | KNN Loss: 5.6334547996521 | BCE Loss: 1.0297648906707764
Epoch 492 / 500 | iteration 20 / 30 | Total Loss: 6.682117462158203 | KNN Loss: 5.639921188354492 | 

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0).to(device), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 3.0258,  3.9439,  2.6377,  3.6324,  3.5231,  0.6968,  2.7194,  2.2294,
          2.3823,  2.0247,  2.2558,  2.2561,  0.7992,  1.9055,  1.3400,  1.4725,
          2.8682,  3.3029,  2.8738,  2.3719,  1.7438,  3.0170,  2.3850,  2.6840,
          2.6220,  1.7341,  2.1848,  1.4313,  1.5433,  0.3170, -0.2358,  0.9944,
          0.1910,  0.9046,  1.5290,  1.4944,  0.9864,  3.3969,  0.8239,  1.3225,
          0.9583, -0.7440, -0.2944,  2.3565,  2.2295,  0.7441, -0.2419,  0.1186,
          1.5079,  2.5429,  1.8524,  0.0965,  1.4037,  0.5762, -0.6752,  1.1513,
          1.4866,  1.3908,  1.3941,  1.9165,  0.5962,  0.8472,  0.1199,  1.7733,
          1.3448,  1.6887, -1.9134,  0.2970,  2.3554,  2.2118,  2.6246,  0.4108,
          1.3821,  2.5102,  2.0509,  1.3257,  0.2335,  0.7602,  0.2090,  1.6089,
          0.0213,  0.3772,  1.8677, -0.4050,  0.2225, -1.1053, -2.4161, -0.2930,
          0.5203, -1.8782,  0.4301, -0.1449, -0.6335, -0.9953,  0.5809,  1.2839,
         -0.7181, -0.7370,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0].cpu() for d in dataset]

In [11]:
model = model.eval().cpu()
projections = model.calculate_intermidiate(dataset_)

100%|██████████| 15/15 [00:00<00:00, 83.52it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
# from sklearn.tree import DecisionTreeClassifier
# from sklearn import tree
# from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [17]:
# clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
# clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
# print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
# print(clf.get_depth())

In [18]:
# scores = []
# for min_samples in range(1,50, 1):
#     clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
#     clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
#     scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
# plt.figure()
# plt.plot(list(range(1,50, 1)), scores)
# plt.show()

In [19]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [20]:
# rules = get_rules(clf, dataset.items, clusters[clusters != -1])

# for rule in rules:
#     n_pos = 0
#     for c,p,v in rule['rule']:
#         if p == '>':
#             n_pos += 1
#     rule['pos'] = n_pos

In [21]:
# plt.figure()
# probs = [r['proba'] for r in rules]
# plt.hist(probs, bins = 100)
# plt.show()

In [22]:
# rules = sorted(rules, key=lambda x:x['pos'])
# rules = [r for r in rules if r['proba'] > 50]
# print(len(rules))

In [23]:
# for i in range(17):
#     r_i = rules[i]
#     print(f"------------- rule {i} length {len(r_i)} -------------")
#     print(r_i)

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [24]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [25]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [26]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [27]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
use_cuda = device != 'cpu'

In [28]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [29]:
losses = []
accs = []
sparsity = []

In [30]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 030 | Total loss: 9.618 | Reg loss: 0.012 | Tree loss: 9.618 | Accuracy: 0.000000 | 1.236 sec/iter
Epoch: 00 | Batch: 001 / 030 | Total loss: 9.605 | Reg loss: 0.011 | Tree loss: 9.605 | Accuracy: 0.000000 | 1.038 sec/iter
Epoch: 00 | Batch: 002 / 030 | Total loss: 9.592 | Reg loss: 0.010 | Tree loss: 9.592 | Accuracy: 0.000000 | 0.977 sec/iter
Epoch: 00 | Batch: 003 / 030 | Total loss: 9.579 | Reg loss: 0.010 | Tree loss: 9.579 | Accuracy: 0.000000 | 0.948 sec/iter
Epoch: 00 | Batch: 004 / 030 | Total loss: 9.566 | Reg loss: 0.009 | Tree loss: 9.566 | Accuracy: 0.000000 | 0.933 sec/iter
Epoch: 00 | Batch: 005 / 030 | Total loss: 9.554 | Reg loss: 0.009 | Tree loss: 9.554 | Accuracy: 0.009766 | 0.925 sec/iter
Epoch: 00 | Batch: 006 / 030 | Total loss: 9.542 | Reg loss: 0.009 | Tree loss: 9.542 | Accuracy: 0.052734 | 0.918 

Epoch: 02 | Batch: 001 / 030 | Total loss: 9.272 | Reg loss: 0.009 | Tree loss: 9.272 | Accuracy: 0.554688 | 0.908 sec/iter
Epoch: 02 | Batch: 002 / 030 | Total loss: 9.260 | Reg loss: 0.009 | Tree loss: 9.260 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 02 | Batch: 003 / 030 | Total loss: 9.249 | Reg loss: 0.010 | Tree loss: 9.249 | Accuracy: 0.603516 | 0.907 sec/iter
Epoch: 02 | Batch: 004 / 030 | Total loss: 9.234 | Reg loss: 0.010 | Tree loss: 9.234 | Accuracy: 0.599609 | 0.906 sec/iter
Epoch: 02 | Batch: 005 / 030 | Total loss: 9.230 | Reg loss: 0.010 | Tree loss: 9.230 | Accuracy: 0.552734 | 0.906 sec/iter
Epoch: 02 | Batch: 006 / 030 | Total loss: 9.220 | Reg loss: 0.010 | Tree loss: 9.220 | Accuracy: 0.542969 | 0.906 sec/iter
Epoch: 02 | Batch: 007 / 030 | Total loss: 9.199 | Reg loss: 0.011 | Tree loss: 9.199 | Accuracy: 0.593750 | 0.905 sec/iter
Epoch: 02 | Batch: 008 / 030 | Total loss: 9.190 | Reg loss: 0.011 | Tree loss: 9.190 | Accuracy: 0.589844 | 0.905 sec/iter
Epoch: 0

Epoch: 04 | Batch: 003 / 030 | Total loss: 8.926 | Reg loss: 0.016 | Tree loss: 8.926 | Accuracy: 0.589844 | 0.907 sec/iter
Epoch: 04 | Batch: 004 / 030 | Total loss: 8.912 | Reg loss: 0.016 | Tree loss: 8.912 | Accuracy: 0.556641 | 0.906 sec/iter
Epoch: 04 | Batch: 005 / 030 | Total loss: 8.893 | Reg loss: 0.016 | Tree loss: 8.893 | Accuracy: 0.552734 | 0.906 sec/iter
Epoch: 04 | Batch: 006 / 030 | Total loss: 8.877 | Reg loss: 0.017 | Tree loss: 8.877 | Accuracy: 0.574219 | 0.906 sec/iter
Epoch: 04 | Batch: 007 / 030 | Total loss: 8.852 | Reg loss: 0.017 | Tree loss: 8.852 | Accuracy: 0.593750 | 0.906 sec/iter
Epoch: 04 | Batch: 008 / 030 | Total loss: 8.838 | Reg loss: 0.017 | Tree loss: 8.838 | Accuracy: 0.570312 | 0.906 sec/iter
Epoch: 04 | Batch: 009 / 030 | Total loss: 8.819 | Reg loss: 0.018 | Tree loss: 8.819 | Accuracy: 0.578125 | 0.906 sec/iter
Epoch: 04 | Batch: 010 / 030 | Total loss: 8.809 | Reg loss: 0.018 | Tree loss: 8.809 | Accuracy: 0.554688 | 0.905 sec/iter
Epoch: 0

Epoch: 06 | Batch: 005 / 030 | Total loss: 8.455 | Reg loss: 0.022 | Tree loss: 8.455 | Accuracy: 0.593750 | 0.907 sec/iter
Epoch: 06 | Batch: 006 / 030 | Total loss: 8.424 | Reg loss: 0.023 | Tree loss: 8.424 | Accuracy: 0.566406 | 0.906 sec/iter
Epoch: 06 | Batch: 007 / 030 | Total loss: 8.386 | Reg loss: 0.023 | Tree loss: 8.386 | Accuracy: 0.560547 | 0.906 sec/iter
Epoch: 06 | Batch: 008 / 030 | Total loss: 8.360 | Reg loss: 0.023 | Tree loss: 8.360 | Accuracy: 0.611328 | 0.906 sec/iter
Epoch: 06 | Batch: 009 / 030 | Total loss: 8.345 | Reg loss: 0.024 | Tree loss: 8.345 | Accuracy: 0.562500 | 0.906 sec/iter
Epoch: 06 | Batch: 010 / 030 | Total loss: 8.310 | Reg loss: 0.024 | Tree loss: 8.310 | Accuracy: 0.574219 | 0.906 sec/iter
Epoch: 06 | Batch: 011 / 030 | Total loss: 8.291 | Reg loss: 0.025 | Tree loss: 8.291 | Accuracy: 0.583984 | 0.906 sec/iter
Epoch: 06 | Batch: 012 / 030 | Total loss: 8.285 | Reg loss: 0.025 | Tree loss: 8.285 | Accuracy: 0.535156 | 0.906 sec/iter
Epoch: 0

Epoch: 08 | Batch: 007 / 030 | Total loss: 7.881 | Reg loss: 0.028 | Tree loss: 7.881 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 08 | Batch: 008 / 030 | Total loss: 7.853 | Reg loss: 0.028 | Tree loss: 7.853 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 08 | Batch: 009 / 030 | Total loss: 7.826 | Reg loss: 0.029 | Tree loss: 7.826 | Accuracy: 0.548828 | 0.907 sec/iter
Epoch: 08 | Batch: 010 / 030 | Total loss: 7.792 | Reg loss: 0.029 | Tree loss: 7.792 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 08 | Batch: 011 / 030 | Total loss: 7.710 | Reg loss: 0.030 | Tree loss: 7.710 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 08 | Batch: 012 / 030 | Total loss: 7.729 | Reg loss: 0.030 | Tree loss: 7.729 | Accuracy: 0.578125 | 0.906 sec/iter
Epoch: 08 | Batch: 013 / 030 | Total loss: 7.670 | Reg loss: 0.030 | Tree loss: 7.670 | Accuracy: 0.552734 | 0.906 sec/iter
Epoch: 08 | Batch: 014 / 030 | Total loss: 7.642 | Reg loss: 0.031 | Tree loss: 7.642 | Accuracy: 0.580078 | 0.906 sec/iter
Epoch: 0

Epoch: 10 | Batch: 009 / 030 | Total loss: 7.265 | Reg loss: 0.033 | Tree loss: 7.265 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 10 | Batch: 010 / 030 | Total loss: 7.250 | Reg loss: 0.033 | Tree loss: 7.250 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 10 | Batch: 011 / 030 | Total loss: 7.192 | Reg loss: 0.034 | Tree loss: 7.192 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 10 | Batch: 012 / 030 | Total loss: 7.168 | Reg loss: 0.034 | Tree loss: 7.168 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 10 | Batch: 013 / 030 | Total loss: 7.125 | Reg loss: 0.034 | Tree loss: 7.125 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 10 | Batch: 014 / 030 | Total loss: 7.094 | Reg loss: 0.035 | Tree loss: 7.094 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 10 | Batch: 015 / 030 | Total loss: 7.064 | Reg loss: 0.035 | Tree loss: 7.064 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 10 | Batch: 016 / 030 | Total loss: 7.047 | Reg loss: 0.035 | Tree loss: 7.047 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 1

Epoch: 12 | Batch: 011 / 030 | Total loss: 6.677 | Reg loss: 0.037 | Tree loss: 6.677 | Accuracy: 0.619141 | 0.907 sec/iter
Epoch: 12 | Batch: 012 / 030 | Total loss: 6.734 | Reg loss: 0.037 | Tree loss: 6.734 | Accuracy: 0.501953 | 0.907 sec/iter
Epoch: 12 | Batch: 013 / 030 | Total loss: 6.684 | Reg loss: 0.037 | Tree loss: 6.684 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 12 | Batch: 014 / 030 | Total loss: 6.588 | Reg loss: 0.037 | Tree loss: 6.588 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 12 | Batch: 015 / 030 | Total loss: 6.569 | Reg loss: 0.038 | Tree loss: 6.569 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 12 | Batch: 016 / 030 | Total loss: 6.522 | Reg loss: 0.038 | Tree loss: 6.522 | Accuracy: 0.566406 | 0.907 sec/iter
Epoch: 12 | Batch: 017 / 030 | Total loss: 6.516 | Reg loss: 0.038 | Tree loss: 6.516 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 12 | Batch: 018 / 030 | Total loss: 6.470 | Reg loss: 0.039 | Tree loss: 6.470 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 1

Epoch: 14 | Batch: 013 / 030 | Total loss: 6.157 | Reg loss: 0.039 | Tree loss: 6.157 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 14 | Batch: 014 / 030 | Total loss: 6.097 | Reg loss: 0.039 | Tree loss: 6.097 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 14 | Batch: 015 / 030 | Total loss: 6.100 | Reg loss: 0.039 | Tree loss: 6.100 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 14 | Batch: 016 / 030 | Total loss: 6.074 | Reg loss: 0.040 | Tree loss: 6.074 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 14 | Batch: 017 / 030 | Total loss: 6.040 | Reg loss: 0.040 | Tree loss: 6.040 | Accuracy: 0.546875 | 0.907 sec/iter
Epoch: 14 | Batch: 018 / 030 | Total loss: 6.001 | Reg loss: 0.040 | Tree loss: 6.001 | Accuracy: 0.550781 | 0.907 sec/iter
Epoch: 14 | Batch: 019 / 030 | Total loss: 5.980 | Reg loss: 0.041 | Tree loss: 5.980 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 14 | Batch: 020 / 030 | Total loss: 5.909 | Reg loss: 0.041 | Tree loss: 5.909 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 1

Epoch: 16 | Batch: 015 / 030 | Total loss: 5.621 | Reg loss: 0.041 | Tree loss: 5.621 | Accuracy: 0.552734 | 0.907 sec/iter
Epoch: 16 | Batch: 016 / 030 | Total loss: 5.595 | Reg loss: 0.041 | Tree loss: 5.595 | Accuracy: 0.542969 | 0.907 sec/iter
Epoch: 16 | Batch: 017 / 030 | Total loss: 5.507 | Reg loss: 0.042 | Tree loss: 5.507 | Accuracy: 0.589844 | 0.907 sec/iter
Epoch: 16 | Batch: 018 / 030 | Total loss: 5.555 | Reg loss: 0.042 | Tree loss: 5.555 | Accuracy: 0.542969 | 0.907 sec/iter
Epoch: 16 | Batch: 019 / 030 | Total loss: 5.493 | Reg loss: 0.042 | Tree loss: 5.493 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 16 | Batch: 020 / 030 | Total loss: 5.449 | Reg loss: 0.043 | Tree loss: 5.449 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 16 | Batch: 021 / 030 | Total loss: 5.416 | Reg loss: 0.043 | Tree loss: 5.416 | Accuracy: 0.552734 | 0.907 sec/iter
Epoch: 16 | Batch: 022 / 030 | Total loss: 5.357 | Reg loss: 0.043 | Tree loss: 5.357 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 1

Epoch: 18 | Batch: 017 / 030 | Total loss: 5.020 | Reg loss: 0.044 | Tree loss: 5.020 | Accuracy: 0.595703 | 0.907 sec/iter
Epoch: 18 | Batch: 018 / 030 | Total loss: 4.999 | Reg loss: 0.044 | Tree loss: 4.999 | Accuracy: 0.521484 | 0.907 sec/iter
Epoch: 18 | Batch: 019 / 030 | Total loss: 4.932 | Reg loss: 0.044 | Tree loss: 4.932 | Accuracy: 0.595703 | 0.907 sec/iter
Epoch: 18 | Batch: 020 / 030 | Total loss: 4.941 | Reg loss: 0.045 | Tree loss: 4.941 | Accuracy: 0.546875 | 0.907 sec/iter
Epoch: 18 | Batch: 021 / 030 | Total loss: 4.879 | Reg loss: 0.045 | Tree loss: 4.879 | Accuracy: 0.535156 | 0.907 sec/iter
Epoch: 18 | Batch: 022 / 030 | Total loss: 4.793 | Reg loss: 0.046 | Tree loss: 4.793 | Accuracy: 0.607422 | 0.907 sec/iter
Epoch: 18 | Batch: 023 / 030 | Total loss: 4.772 | Reg loss: 0.046 | Tree loss: 4.772 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 18 | Batch: 024 / 030 | Total loss: 4.788 | Reg loss: 0.046 | Tree loss: 4.788 | Accuracy: 0.593750 | 0.907 sec/iter
Epoch: 1

Epoch: 20 | Batch: 019 / 030 | Total loss: 4.561 | Reg loss: 0.045 | Tree loss: 4.561 | Accuracy: 0.566406 | 0.907 sec/iter
Epoch: 20 | Batch: 020 / 030 | Total loss: 4.464 | Reg loss: 0.046 | Tree loss: 4.464 | Accuracy: 0.542969 | 0.907 sec/iter
Epoch: 20 | Batch: 021 / 030 | Total loss: 4.476 | Reg loss: 0.046 | Tree loss: 4.476 | Accuracy: 0.560547 | 0.907 sec/iter
Epoch: 20 | Batch: 022 / 030 | Total loss: 4.425 | Reg loss: 0.047 | Tree loss: 4.425 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 20 | Batch: 023 / 030 | Total loss: 4.408 | Reg loss: 0.047 | Tree loss: 4.408 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 20 | Batch: 024 / 030 | Total loss: 4.320 | Reg loss: 0.047 | Tree loss: 4.320 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 20 | Batch: 025 / 030 | Total loss: 4.341 | Reg loss: 0.048 | Tree loss: 4.341 | Accuracy: 0.533203 | 0.907 sec/iter
Epoch: 20 | Batch: 026 / 030 | Total loss: 4.274 | Reg loss: 0.048 | Tree loss: 4.274 | Accuracy: 0.599609 | 0.907 sec/iter
Epoch: 2

Epoch: 22 | Batch: 021 / 030 | Total loss: 4.123 | Reg loss: 0.047 | Tree loss: 4.123 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 22 | Batch: 022 / 030 | Total loss: 4.042 | Reg loss: 0.048 | Tree loss: 4.042 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 22 | Batch: 023 / 030 | Total loss: 4.025 | Reg loss: 0.048 | Tree loss: 4.025 | Accuracy: 0.566406 | 0.907 sec/iter
Epoch: 22 | Batch: 024 / 030 | Total loss: 3.990 | Reg loss: 0.048 | Tree loss: 3.990 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 22 | Batch: 025 / 030 | Total loss: 3.902 | Reg loss: 0.049 | Tree loss: 3.902 | Accuracy: 0.597656 | 0.907 sec/iter
Epoch: 22 | Batch: 026 / 030 | Total loss: 3.932 | Reg loss: 0.049 | Tree loss: 3.932 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 22 | Batch: 027 / 030 | Total loss: 3.877 | Reg loss: 0.050 | Tree loss: 3.877 | Accuracy: 0.585938 | 0.907 sec/iter
Epoch: 22 | Batch: 028 / 030 | Total loss: 3.801 | Reg loss: 0.050 | Tree loss: 3.801 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 2

Epoch: 24 | Batch: 023 / 030 | Total loss: 3.643 | Reg loss: 0.050 | Tree loss: 3.643 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 24 | Batch: 024 / 030 | Total loss: 3.606 | Reg loss: 0.050 | Tree loss: 3.606 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 24 | Batch: 025 / 030 | Total loss: 3.607 | Reg loss: 0.050 | Tree loss: 3.607 | Accuracy: 0.576172 | 0.907 sec/iter
Epoch: 24 | Batch: 026 / 030 | Total loss: 3.549 | Reg loss: 0.051 | Tree loss: 3.549 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 24 | Batch: 027 / 030 | Total loss: 3.521 | Reg loss: 0.051 | Tree loss: 3.521 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 24 | Batch: 028 / 030 | Total loss: 3.504 | Reg loss: 0.051 | Tree loss: 3.504 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 24 | Batch: 029 / 030 | Total loss: 3.517 | Reg loss: 0.051 | Tree loss: 3.517 | Accuracy: 0.514286 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 26 | Batch: 025 / 030 | Total loss: 3.213 | Reg loss: 0.051 | Tree loss: 3.213 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 26 | Batch: 026 / 030 | Total loss: 3.241 | Reg loss: 0.052 | Tree loss: 3.241 | Accuracy: 0.548828 | 0.907 sec/iter
Epoch: 26 | Batch: 027 / 030 | Total loss: 3.213 | Reg loss: 0.052 | Tree loss: 3.213 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 26 | Batch: 028 / 030 | Total loss: 3.160 | Reg loss: 0.052 | Tree loss: 3.160 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 26 | Batch: 029 / 030 | Total loss: 3.023 | Reg loss: 0.053 | Tree loss: 3.023 | Accuracy: 0.638095 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 27 | Batch: 000 / 030 | Total loss: 4.138 | Reg loss: 0.047 | Tree loss: 4.138 | Ac

Epoch: 28 | Batch: 027 / 030 | Total loss: 2.909 | Reg loss: 0.053 | Tree loss: 2.909 | Accuracy: 0.585938 | 0.907 sec/iter
Epoch: 28 | Batch: 028 / 030 | Total loss: 2.820 | Reg loss: 0.054 | Tree loss: 2.820 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 28 | Batch: 029 / 030 | Total loss: 2.838 | Reg loss: 0.054 | Tree loss: 2.838 | Accuracy: 0.619048 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 29 | Batch: 000 / 030 | Total loss: 3.819 | Reg loss: 0.049 | Tree loss: 3.819 | Accuracy: 0.589844 | 0.908 sec/iter
Epoch: 29 | Batch: 001 / 030 | Total loss: 3.782 | Reg loss: 0.049 | Tree loss: 3.782 | Accuracy: 0.552734 | 0.908 sec/iter
Epoch: 29 | Batch: 002 / 030 | Total loss: 3.704 | Reg loss: 0.049 | Tree loss: 3.704 | Ac

Epoch: 30 | Batch: 029 / 030 | Total loss: 2.434 | Reg loss: 0.055 | Tree loss: 2.434 | Accuracy: 0.638095 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 31 | Batch: 000 / 030 | Total loss: 3.471 | Reg loss: 0.050 | Tree loss: 3.471 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 31 | Batch: 001 / 030 | Total loss: 3.404 | Reg loss: 0.050 | Tree loss: 3.404 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 31 | Batch: 002 / 030 | Total loss: 3.425 | Reg loss: 0.050 | Tree loss: 3.425 | Accuracy: 0.564453 | 0.907 sec/iter
Epoch: 31 | Batch: 003 / 030 | Total loss: 3.411 | Reg loss: 0.050 | Tree loss: 3.411 | Accuracy: 0.521484 | 0.907 sec/iter
Epoch: 31 | Batch: 004 / 030 | Total loss: 3.396 | Reg loss: 0.050 | Tree loss: 3.396 | Ac

Epoch: 33 | Batch: 000 / 030 | Total loss: 3.226 | Reg loss: 0.051 | Tree loss: 3.226 | Accuracy: 0.589844 | 0.907 sec/iter
Epoch: 33 | Batch: 001 / 030 | Total loss: 3.158 | Reg loss: 0.051 | Tree loss: 3.158 | Accuracy: 0.542969 | 0.907 sec/iter
Epoch: 33 | Batch: 002 / 030 | Total loss: 3.154 | Reg loss: 0.051 | Tree loss: 3.154 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 33 | Batch: 003 / 030 | Total loss: 3.174 | Reg loss: 0.051 | Tree loss: 3.174 | Accuracy: 0.564453 | 0.907 sec/iter
Epoch: 33 | Batch: 004 / 030 | Total loss: 3.101 | Reg loss: 0.051 | Tree loss: 3.101 | Accuracy: 0.576172 | 0.907 sec/iter
Epoch: 33 | Batch: 005 / 030 | Total loss: 3.048 | Reg loss: 0.051 | Tree loss: 3.048 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 33 | Batch: 006 / 030 | Total loss: 3.001 | Reg loss: 0.051 | Tree loss: 3.001 | Accuracy: 0.599609 | 0.907 sec/iter
Epoch: 33 | Batch: 007 / 030 | Total loss: 3.008 | Reg loss: 0.052 | Tree loss: 3.008 | Accuracy: 0.558594 | 0.907 sec/iter
Epoch: 3

Epoch: 35 | Batch: 002 / 030 | Total loss: 2.971 | Reg loss: 0.052 | Tree loss: 2.971 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 35 | Batch: 003 / 030 | Total loss: 2.923 | Reg loss: 0.052 | Tree loss: 2.923 | Accuracy: 0.593750 | 0.907 sec/iter
Epoch: 35 | Batch: 004 / 030 | Total loss: 2.863 | Reg loss: 0.052 | Tree loss: 2.863 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 35 | Batch: 005 / 030 | Total loss: 2.857 | Reg loss: 0.052 | Tree loss: 2.857 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 35 | Batch: 006 / 030 | Total loss: 2.783 | Reg loss: 0.052 | Tree loss: 2.783 | Accuracy: 0.595703 | 0.907 sec/iter
Epoch: 35 | Batch: 007 / 030 | Total loss: 2.786 | Reg loss: 0.052 | Tree loss: 2.786 | Accuracy: 0.560547 | 0.907 sec/iter
Epoch: 35 | Batch: 008 / 030 | Total loss: 2.692 | Reg loss: 0.052 | Tree loss: 2.692 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 35 | Batch: 009 / 030 | Total loss: 2.635 | Reg loss: 0.052 | Tree loss: 2.635 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 3

Epoch: 37 | Batch: 004 / 030 | Total loss: 2.675 | Reg loss: 0.053 | Tree loss: 2.675 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 37 | Batch: 005 / 030 | Total loss: 2.673 | Reg loss: 0.053 | Tree loss: 2.673 | Accuracy: 0.576172 | 0.907 sec/iter
Epoch: 37 | Batch: 006 / 030 | Total loss: 2.571 | Reg loss: 0.053 | Tree loss: 2.571 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 37 | Batch: 007 / 030 | Total loss: 2.524 | Reg loss: 0.053 | Tree loss: 2.524 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 37 | Batch: 008 / 030 | Total loss: 2.488 | Reg loss: 0.053 | Tree loss: 2.488 | Accuracy: 0.603516 | 0.907 sec/iter
Epoch: 37 | Batch: 009 / 030 | Total loss: 2.523 | Reg loss: 0.053 | Tree loss: 2.523 | Accuracy: 0.537109 | 0.907 sec/iter
Epoch: 37 | Batch: 010 / 030 | Total loss: 2.431 | Reg loss: 0.053 | Tree loss: 2.431 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 37 | Batch: 011 / 030 | Total loss: 2.420 | Reg loss: 0.053 | Tree loss: 2.420 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 3

Epoch: 39 | Batch: 006 / 030 | Total loss: 2.455 | Reg loss: 0.053 | Tree loss: 2.455 | Accuracy: 0.548828 | 0.907 sec/iter
Epoch: 39 | Batch: 007 / 030 | Total loss: 2.395 | Reg loss: 0.053 | Tree loss: 2.395 | Accuracy: 0.595703 | 0.907 sec/iter
Epoch: 39 | Batch: 008 / 030 | Total loss: 2.384 | Reg loss: 0.053 | Tree loss: 2.384 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 39 | Batch: 009 / 030 | Total loss: 2.372 | Reg loss: 0.053 | Tree loss: 2.372 | Accuracy: 0.544922 | 0.907 sec/iter
Epoch: 39 | Batch: 010 / 030 | Total loss: 2.311 | Reg loss: 0.053 | Tree loss: 2.311 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 39 | Batch: 011 / 030 | Total loss: 2.256 | Reg loss: 0.053 | Tree loss: 2.256 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 39 | Batch: 012 / 030 | Total loss: 2.225 | Reg loss: 0.054 | Tree loss: 2.225 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 39 | Batch: 013 / 030 | Total loss: 2.156 | Reg loss: 0.054 | Tree loss: 2.156 | Accuracy: 0.597656 | 0.907 sec/iter
Epoch: 3

Epoch: 41 | Batch: 008 / 030 | Total loss: 2.258 | Reg loss: 0.053 | Tree loss: 2.258 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 41 | Batch: 009 / 030 | Total loss: 2.198 | Reg loss: 0.053 | Tree loss: 2.198 | Accuracy: 0.550781 | 0.907 sec/iter
Epoch: 41 | Batch: 010 / 030 | Total loss: 2.126 | Reg loss: 0.054 | Tree loss: 2.126 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 41 | Batch: 011 / 030 | Total loss: 2.113 | Reg loss: 0.054 | Tree loss: 2.113 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 41 | Batch: 012 / 030 | Total loss: 2.088 | Reg loss: 0.054 | Tree loss: 2.088 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 41 | Batch: 013 / 030 | Total loss: 2.044 | Reg loss: 0.054 | Tree loss: 2.044 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 41 | Batch: 014 / 030 | Total loss: 2.012 | Reg loss: 0.054 | Tree loss: 2.012 | Accuracy: 0.535156 | 0.907 sec/iter
Epoch: 41 | Batch: 015 / 030 | Total loss: 1.994 | Reg loss: 0.054 | Tree loss: 1.994 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 4

Epoch: 43 | Batch: 010 / 030 | Total loss: 2.003 | Reg loss: 0.054 | Tree loss: 2.003 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 43 | Batch: 011 / 030 | Total loss: 2.001 | Reg loss: 0.054 | Tree loss: 2.001 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 43 | Batch: 012 / 030 | Total loss: 1.945 | Reg loss: 0.054 | Tree loss: 1.945 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 43 | Batch: 013 / 030 | Total loss: 1.978 | Reg loss: 0.054 | Tree loss: 1.978 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 43 | Batch: 014 / 030 | Total loss: 1.905 | Reg loss: 0.054 | Tree loss: 1.905 | Accuracy: 0.597656 | 0.907 sec/iter
Epoch: 43 | Batch: 015 / 030 | Total loss: 1.882 | Reg loss: 0.054 | Tree loss: 1.882 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 43 | Batch: 016 / 030 | Total loss: 1.830 | Reg loss: 0.054 | Tree loss: 1.830 | Accuracy: 0.558594 | 0.907 sec/iter
Epoch: 43 | Batch: 017 / 030 | Total loss: 1.777 | Reg loss: 0.055 | Tree loss: 1.777 | Accuracy: 0.558594 | 0.907 sec/iter
Epoch: 4

Epoch: 45 | Batch: 012 / 030 | Total loss: 1.865 | Reg loss: 0.054 | Tree loss: 1.865 | Accuracy: 0.582031 | 0.908 sec/iter
Epoch: 45 | Batch: 013 / 030 | Total loss: 1.838 | Reg loss: 0.054 | Tree loss: 1.838 | Accuracy: 0.585938 | 0.907 sec/iter
Epoch: 45 | Batch: 014 / 030 | Total loss: 1.749 | Reg loss: 0.054 | Tree loss: 1.749 | Accuracy: 0.632812 | 0.907 sec/iter
Epoch: 45 | Batch: 015 / 030 | Total loss: 1.749 | Reg loss: 0.054 | Tree loss: 1.749 | Accuracy: 0.550781 | 0.907 sec/iter
Epoch: 45 | Batch: 016 / 030 | Total loss: 1.737 | Reg loss: 0.054 | Tree loss: 1.737 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 45 | Batch: 017 / 030 | Total loss: 1.689 | Reg loss: 0.054 | Tree loss: 1.689 | Accuracy: 0.597656 | 0.907 sec/iter
Epoch: 45 | Batch: 018 / 030 | Total loss: 1.658 | Reg loss: 0.055 | Tree loss: 1.658 | Accuracy: 0.589844 | 0.907 sec/iter
Epoch: 45 | Batch: 019 / 030 | Total loss: 1.667 | Reg loss: 0.055 | Tree loss: 1.667 | Accuracy: 0.537109 | 0.907 sec/iter
Epoch: 4

Epoch: 47 | Batch: 014 / 030 | Total loss: 1.689 | Reg loss: 0.054 | Tree loss: 1.689 | Accuracy: 0.560547 | 0.908 sec/iter
Epoch: 47 | Batch: 015 / 030 | Total loss: 1.700 | Reg loss: 0.054 | Tree loss: 1.700 | Accuracy: 0.548828 | 0.908 sec/iter
Epoch: 47 | Batch: 016 / 030 | Total loss: 1.693 | Reg loss: 0.054 | Tree loss: 1.693 | Accuracy: 0.542969 | 0.908 sec/iter
Epoch: 47 | Batch: 017 / 030 | Total loss: 1.606 | Reg loss: 0.054 | Tree loss: 1.606 | Accuracy: 0.568359 | 0.908 sec/iter
Epoch: 47 | Batch: 018 / 030 | Total loss: 1.625 | Reg loss: 0.054 | Tree loss: 1.625 | Accuracy: 0.550781 | 0.908 sec/iter
Epoch: 47 | Batch: 019 / 030 | Total loss: 1.561 | Reg loss: 0.054 | Tree loss: 1.561 | Accuracy: 0.611328 | 0.908 sec/iter
Epoch: 47 | Batch: 020 / 030 | Total loss: 1.559 | Reg loss: 0.055 | Tree loss: 1.559 | Accuracy: 0.566406 | 0.908 sec/iter
Epoch: 47 | Batch: 021 / 030 | Total loss: 1.529 | Reg loss: 0.055 | Tree loss: 1.529 | Accuracy: 0.544922 | 0.908 sec/iter
Epoch: 4

Epoch: 49 | Batch: 016 / 030 | Total loss: 1.534 | Reg loss: 0.054 | Tree loss: 1.534 | Accuracy: 0.601562 | 0.908 sec/iter
Epoch: 49 | Batch: 017 / 030 | Total loss: 1.542 | Reg loss: 0.054 | Tree loss: 1.542 | Accuracy: 0.597656 | 0.908 sec/iter
Epoch: 49 | Batch: 018 / 030 | Total loss: 1.518 | Reg loss: 0.054 | Tree loss: 1.518 | Accuracy: 0.599609 | 0.908 sec/iter
Epoch: 49 | Batch: 019 / 030 | Total loss: 1.536 | Reg loss: 0.054 | Tree loss: 1.536 | Accuracy: 0.578125 | 0.908 sec/iter
Epoch: 49 | Batch: 020 / 030 | Total loss: 1.532 | Reg loss: 0.054 | Tree loss: 1.532 | Accuracy: 0.541016 | 0.908 sec/iter
Epoch: 49 | Batch: 021 / 030 | Total loss: 1.454 | Reg loss: 0.054 | Tree loss: 1.454 | Accuracy: 0.562500 | 0.908 sec/iter
Epoch: 49 | Batch: 022 / 030 | Total loss: 1.412 | Reg loss: 0.054 | Tree loss: 1.412 | Accuracy: 0.583984 | 0.908 sec/iter
Epoch: 49 | Batch: 023 / 030 | Total loss: 1.392 | Reg loss: 0.055 | Tree loss: 1.392 | Accuracy: 0.556641 | 0.908 sec/iter
Epoch: 4

Epoch: 51 | Batch: 018 / 030 | Total loss: 1.492 | Reg loss: 0.054 | Tree loss: 1.492 | Accuracy: 0.558594 | 0.908 sec/iter
Epoch: 51 | Batch: 019 / 030 | Total loss: 1.441 | Reg loss: 0.054 | Tree loss: 1.441 | Accuracy: 0.597656 | 0.908 sec/iter
Epoch: 51 | Batch: 020 / 030 | Total loss: 1.350 | Reg loss: 0.054 | Tree loss: 1.350 | Accuracy: 0.589844 | 0.908 sec/iter
Epoch: 51 | Batch: 021 / 030 | Total loss: 1.396 | Reg loss: 0.054 | Tree loss: 1.396 | Accuracy: 0.578125 | 0.908 sec/iter
Epoch: 51 | Batch: 022 / 030 | Total loss: 1.405 | Reg loss: 0.054 | Tree loss: 1.405 | Accuracy: 0.517578 | 0.908 sec/iter
Epoch: 51 | Batch: 023 / 030 | Total loss: 1.334 | Reg loss: 0.054 | Tree loss: 1.334 | Accuracy: 0.589844 | 0.908 sec/iter
Epoch: 51 | Batch: 024 / 030 | Total loss: 1.341 | Reg loss: 0.054 | Tree loss: 1.341 | Accuracy: 0.546875 | 0.908 sec/iter
Epoch: 51 | Batch: 025 / 030 | Total loss: 1.296 | Reg loss: 0.054 | Tree loss: 1.296 | Accuracy: 0.597656 | 0.908 sec/iter
Epoch: 5

Epoch: 53 | Batch: 020 / 030 | Total loss: 1.345 | Reg loss: 0.053 | Tree loss: 1.345 | Accuracy: 0.595703 | 0.908 sec/iter
Epoch: 53 | Batch: 021 / 030 | Total loss: 1.343 | Reg loss: 0.054 | Tree loss: 1.343 | Accuracy: 0.578125 | 0.908 sec/iter
Epoch: 53 | Batch: 022 / 030 | Total loss: 1.305 | Reg loss: 0.054 | Tree loss: 1.305 | Accuracy: 0.580078 | 0.908 sec/iter
Epoch: 53 | Batch: 023 / 030 | Total loss: 1.292 | Reg loss: 0.054 | Tree loss: 1.292 | Accuracy: 0.583984 | 0.908 sec/iter
Epoch: 53 | Batch: 024 / 030 | Total loss: 1.280 | Reg loss: 0.054 | Tree loss: 1.280 | Accuracy: 0.576172 | 0.908 sec/iter
Epoch: 53 | Batch: 025 / 030 | Total loss: 1.232 | Reg loss: 0.054 | Tree loss: 1.232 | Accuracy: 0.607422 | 0.908 sec/iter
Epoch: 53 | Batch: 026 / 030 | Total loss: 1.240 | Reg loss: 0.054 | Tree loss: 1.240 | Accuracy: 0.583984 | 0.908 sec/iter
Epoch: 53 | Batch: 027 / 030 | Total loss: 1.186 | Reg loss: 0.054 | Tree loss: 1.186 | Accuracy: 0.609375 | 0.908 sec/iter
Epoch: 5

Epoch: 55 | Batch: 022 / 030 | Total loss: 1.252 | Reg loss: 0.053 | Tree loss: 1.252 | Accuracy: 0.611328 | 0.908 sec/iter
Epoch: 55 | Batch: 023 / 030 | Total loss: 1.275 | Reg loss: 0.053 | Tree loss: 1.275 | Accuracy: 0.562500 | 0.908 sec/iter
Epoch: 55 | Batch: 024 / 030 | Total loss: 1.267 | Reg loss: 0.053 | Tree loss: 1.267 | Accuracy: 0.548828 | 0.908 sec/iter
Epoch: 55 | Batch: 025 / 030 | Total loss: 1.231 | Reg loss: 0.054 | Tree loss: 1.231 | Accuracy: 0.574219 | 0.908 sec/iter
Epoch: 55 | Batch: 026 / 030 | Total loss: 1.190 | Reg loss: 0.054 | Tree loss: 1.190 | Accuracy: 0.613281 | 0.908 sec/iter
Epoch: 55 | Batch: 027 / 030 | Total loss: 1.198 | Reg loss: 0.054 | Tree loss: 1.198 | Accuracy: 0.574219 | 0.908 sec/iter
Epoch: 55 | Batch: 028 / 030 | Total loss: 1.174 | Reg loss: 0.054 | Tree loss: 1.174 | Accuracy: 0.576172 | 0.908 sec/iter
Epoch: 55 | Batch: 029 / 030 | Total loss: 1.113 | Reg loss: 0.054 | Tree loss: 1.113 | Accuracy: 0.638095 | 0.908 sec/iter
Average 

Epoch: 57 | Batch: 024 / 030 | Total loss: 1.234 | Reg loss: 0.053 | Tree loss: 1.234 | Accuracy: 0.558594 | 0.908 sec/iter
Epoch: 57 | Batch: 025 / 030 | Total loss: 1.198 | Reg loss: 0.053 | Tree loss: 1.198 | Accuracy: 0.585938 | 0.908 sec/iter
Epoch: 57 | Batch: 026 / 030 | Total loss: 1.209 | Reg loss: 0.053 | Tree loss: 1.209 | Accuracy: 0.546875 | 0.908 sec/iter
Epoch: 57 | Batch: 027 / 030 | Total loss: 1.151 | Reg loss: 0.053 | Tree loss: 1.151 | Accuracy: 0.591797 | 0.908 sec/iter
Epoch: 57 | Batch: 028 / 030 | Total loss: 1.167 | Reg loss: 0.053 | Tree loss: 1.167 | Accuracy: 0.597656 | 0.908 sec/iter
Epoch: 57 | Batch: 029 / 030 | Total loss: 1.072 | Reg loss: 0.053 | Tree loss: 1.072 | Accuracy: 0.685714 | 0.908 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 59 | Batch: 026 / 030 | Total loss: 1.141 | Reg loss: 0.053 | Tree loss: 1.141 | Accuracy: 0.578125 | 0.908 sec/iter
Epoch: 59 | Batch: 027 / 030 | Total loss: 1.121 | Reg loss: 0.053 | Tree loss: 1.121 | Accuracy: 0.572266 | 0.908 sec/iter
Epoch: 59 | Batch: 028 / 030 | Total loss: 1.106 | Reg loss: 0.053 | Tree loss: 1.106 | Accuracy: 0.603516 | 0.908 sec/iter
Epoch: 59 | Batch: 029 / 030 | Total loss: 1.083 | Reg loss: 0.053 | Tree loss: 1.083 | Accuracy: 0.666667 | 0.908 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 60 | Batch: 000 / 030 | Total loss: 1.835 | Reg loss: 0.050 | Tree loss: 1.835 | Accuracy: 0.582031 | 0.908 sec/iter
Epoch: 60 | Batch: 001 / 030 | Total loss: 1.828 | Reg loss: 0.050 | Tree loss: 1.828 | Ac

Epoch: 61 | Batch: 028 / 030 | Total loss: 1.076 | Reg loss: 0.053 | Tree loss: 1.076 | Accuracy: 0.585938 | 0.908 sec/iter
Epoch: 61 | Batch: 029 / 030 | Total loss: 1.064 | Reg loss: 0.053 | Tree loss: 1.064 | Accuracy: 0.628571 | 0.908 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 62 | Batch: 000 / 030 | Total loss: 1.787 | Reg loss: 0.050 | Tree loss: 1.787 | Accuracy: 0.574219 | 0.908 sec/iter
Epoch: 62 | Batch: 001 / 030 | Total loss: 1.717 | Reg loss: 0.050 | Tree loss: 1.717 | Accuracy: 0.630859 | 0.908 sec/iter
Epoch: 62 | Batch: 002 / 030 | Total loss: 1.705 | Reg loss: 0.050 | Tree loss: 1.705 | Accuracy: 0.576172 | 0.908 sec/iter
Epoch: 62 | Batch: 003 / 030 | Total loss: 1.686 | Reg loss: 0.050 | Tree loss: 1.686 | Ac

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 64 | Batch: 000 / 030 | Total loss: 1.812 | Reg loss: 0.050 | Tree loss: 1.812 | Accuracy: 0.593750 | 0.908 sec/iter
Epoch: 64 | Batch: 001 / 030 | Total loss: 1.750 | Reg loss: 0.050 | Tree loss: 1.750 | Accuracy: 0.572266 | 0.908 sec/iter
Epoch: 64 | Batch: 002 / 030 | Total loss: 1.768 | Reg loss: 0.050 | Tree loss: 1.768 | Accuracy: 0.552734 | 0.908 sec/iter
Epoch: 64 | Batch: 003 / 030 | Total loss: 1.629 | Reg loss: 0.050 | Tree loss: 1.629 | Accuracy: 0.611328 | 0.908 sec/iter
Epoch: 64 | Batch: 004 / 030 | Total loss: 1.677 | Reg loss: 0.050 | Tree loss: 1.677 | Accuracy: 0.560547 | 0.908 sec/iter
Epoch: 64 | Batch: 005 / 030 | Total loss: 1.645 | Reg loss: 0.050 | Tree loss: 1.645 | Ac

Epoch: 66 | Batch: 000 / 030 | Total loss: 1.753 | Reg loss: 0.049 | Tree loss: 1.753 | Accuracy: 0.570312 | 0.908 sec/iter
Epoch: 66 | Batch: 001 / 030 | Total loss: 1.689 | Reg loss: 0.049 | Tree loss: 1.689 | Accuracy: 0.572266 | 0.908 sec/iter
Epoch: 66 | Batch: 002 / 030 | Total loss: 1.666 | Reg loss: 0.049 | Tree loss: 1.666 | Accuracy: 0.595703 | 0.908 sec/iter
Epoch: 66 | Batch: 003 / 030 | Total loss: 1.642 | Reg loss: 0.049 | Tree loss: 1.642 | Accuracy: 0.583984 | 0.908 sec/iter
Epoch: 66 | Batch: 004 / 030 | Total loss: 1.672 | Reg loss: 0.049 | Tree loss: 1.672 | Accuracy: 0.527344 | 0.908 sec/iter
Epoch: 66 | Batch: 005 / 030 | Total loss: 1.671 | Reg loss: 0.049 | Tree loss: 1.671 | Accuracy: 0.541016 | 0.908 sec/iter
Epoch: 66 | Batch: 006 / 030 | Total loss: 1.661 | Reg loss: 0.049 | Tree loss: 1.661 | Accuracy: 0.574219 | 0.908 sec/iter
Epoch: 66 | Batch: 007 / 030 | Total loss: 1.512 | Reg loss: 0.050 | Tree loss: 1.512 | Accuracy: 0.574219 | 0.908 sec/iter
Epoch: 6

Epoch: 68 | Batch: 002 / 030 | Total loss: 1.756 | Reg loss: 0.049 | Tree loss: 1.756 | Accuracy: 0.558594 | 0.908 sec/iter
Epoch: 68 | Batch: 003 / 030 | Total loss: 1.703 | Reg loss: 0.049 | Tree loss: 1.703 | Accuracy: 0.566406 | 0.908 sec/iter
Epoch: 68 | Batch: 004 / 030 | Total loss: 1.648 | Reg loss: 0.049 | Tree loss: 1.648 | Accuracy: 0.550781 | 0.908 sec/iter
Epoch: 68 | Batch: 005 / 030 | Total loss: 1.583 | Reg loss: 0.049 | Tree loss: 1.583 | Accuracy: 0.541016 | 0.908 sec/iter
Epoch: 68 | Batch: 006 / 030 | Total loss: 1.609 | Reg loss: 0.049 | Tree loss: 1.609 | Accuracy: 0.576172 | 0.908 sec/iter
Epoch: 68 | Batch: 007 / 030 | Total loss: 1.513 | Reg loss: 0.049 | Tree loss: 1.513 | Accuracy: 0.601562 | 0.908 sec/iter
Epoch: 68 | Batch: 008 / 030 | Total loss: 1.481 | Reg loss: 0.049 | Tree loss: 1.481 | Accuracy: 0.574219 | 0.908 sec/iter
Epoch: 68 | Batch: 009 / 030 | Total loss: 1.490 | Reg loss: 0.049 | Tree loss: 1.490 | Accuracy: 0.542969 | 0.908 sec/iter
Epoch: 6

Epoch: 70 | Batch: 004 / 030 | Total loss: 1.672 | Reg loss: 0.049 | Tree loss: 1.672 | Accuracy: 0.585938 | 0.908 sec/iter
Epoch: 70 | Batch: 005 / 030 | Total loss: 1.552 | Reg loss: 0.049 | Tree loss: 1.552 | Accuracy: 0.582031 | 0.908 sec/iter
Epoch: 70 | Batch: 006 / 030 | Total loss: 1.528 | Reg loss: 0.049 | Tree loss: 1.528 | Accuracy: 0.597656 | 0.908 sec/iter
Epoch: 70 | Batch: 007 / 030 | Total loss: 1.501 | Reg loss: 0.049 | Tree loss: 1.501 | Accuracy: 0.582031 | 0.908 sec/iter
Epoch: 70 | Batch: 008 / 030 | Total loss: 1.555 | Reg loss: 0.049 | Tree loss: 1.555 | Accuracy: 0.556641 | 0.908 sec/iter
Epoch: 70 | Batch: 009 / 030 | Total loss: 1.434 | Reg loss: 0.049 | Tree loss: 1.434 | Accuracy: 0.603516 | 0.908 sec/iter
Epoch: 70 | Batch: 010 / 030 | Total loss: 1.446 | Reg loss: 0.049 | Tree loss: 1.446 | Accuracy: 0.568359 | 0.908 sec/iter
Epoch: 70 | Batch: 011 / 030 | Total loss: 1.363 | Reg loss: 0.049 | Tree loss: 1.363 | Accuracy: 0.552734 | 0.908 sec/iter
Epoch: 7

Epoch: 72 | Batch: 006 / 030 | Total loss: 1.518 | Reg loss: 0.049 | Tree loss: 1.518 | Accuracy: 0.566406 | 0.908 sec/iter
Epoch: 72 | Batch: 007 / 030 | Total loss: 1.540 | Reg loss: 0.049 | Tree loss: 1.540 | Accuracy: 0.599609 | 0.908 sec/iter
Epoch: 72 | Batch: 008 / 030 | Total loss: 1.540 | Reg loss: 0.049 | Tree loss: 1.540 | Accuracy: 0.558594 | 0.908 sec/iter
Epoch: 72 | Batch: 009 / 030 | Total loss: 1.432 | Reg loss: 0.049 | Tree loss: 1.432 | Accuracy: 0.578125 | 0.908 sec/iter
Epoch: 72 | Batch: 010 / 030 | Total loss: 1.454 | Reg loss: 0.049 | Tree loss: 1.454 | Accuracy: 0.570312 | 0.908 sec/iter
Epoch: 72 | Batch: 011 / 030 | Total loss: 1.385 | Reg loss: 0.049 | Tree loss: 1.385 | Accuracy: 0.605469 | 0.907 sec/iter
Epoch: 72 | Batch: 012 / 030 | Total loss: 1.360 | Reg loss: 0.049 | Tree loss: 1.360 | Accuracy: 0.550781 | 0.907 sec/iter
Epoch: 72 | Batch: 013 / 030 | Total loss: 1.361 | Reg loss: 0.049 | Tree loss: 1.361 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 7

Epoch: 74 | Batch: 008 / 030 | Total loss: 1.473 | Reg loss: 0.048 | Tree loss: 1.473 | Accuracy: 0.589844 | 0.907 sec/iter
Epoch: 74 | Batch: 009 / 030 | Total loss: 1.408 | Reg loss: 0.048 | Tree loss: 1.408 | Accuracy: 0.542969 | 0.907 sec/iter
Epoch: 74 | Batch: 010 / 030 | Total loss: 1.323 | Reg loss: 0.049 | Tree loss: 1.323 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 74 | Batch: 011 / 030 | Total loss: 1.352 | Reg loss: 0.049 | Tree loss: 1.352 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 74 | Batch: 012 / 030 | Total loss: 1.353 | Reg loss: 0.049 | Tree loss: 1.353 | Accuracy: 0.558594 | 0.907 sec/iter
Epoch: 74 | Batch: 013 / 030 | Total loss: 1.296 | Reg loss: 0.049 | Tree loss: 1.296 | Accuracy: 0.615234 | 0.907 sec/iter
Epoch: 74 | Batch: 014 / 030 | Total loss: 1.305 | Reg loss: 0.049 | Tree loss: 1.305 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 74 | Batch: 015 / 030 | Total loss: 1.220 | Reg loss: 0.049 | Tree loss: 1.220 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 7

Epoch: 76 | Batch: 010 / 030 | Total loss: 1.366 | Reg loss: 0.048 | Tree loss: 1.366 | Accuracy: 0.591797 | 0.907 sec/iter
Epoch: 76 | Batch: 011 / 030 | Total loss: 1.403 | Reg loss: 0.048 | Tree loss: 1.403 | Accuracy: 0.560547 | 0.907 sec/iter
Epoch: 76 | Batch: 012 / 030 | Total loss: 1.330 | Reg loss: 0.049 | Tree loss: 1.330 | Accuracy: 0.576172 | 0.907 sec/iter
Epoch: 76 | Batch: 013 / 030 | Total loss: 1.318 | Reg loss: 0.049 | Tree loss: 1.318 | Accuracy: 0.560547 | 0.907 sec/iter
Epoch: 76 | Batch: 014 / 030 | Total loss: 1.268 | Reg loss: 0.049 | Tree loss: 1.268 | Accuracy: 0.595703 | 0.907 sec/iter
Epoch: 76 | Batch: 015 / 030 | Total loss: 1.244 | Reg loss: 0.049 | Tree loss: 1.244 | Accuracy: 0.585938 | 0.907 sec/iter
Epoch: 76 | Batch: 016 / 030 | Total loss: 1.183 | Reg loss: 0.049 | Tree loss: 1.183 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 76 | Batch: 017 / 030 | Total loss: 1.218 | Reg loss: 0.049 | Tree loss: 1.218 | Accuracy: 0.542969 | 0.907 sec/iter
Epoch: 7

Epoch: 78 | Batch: 012 / 030 | Total loss: 1.298 | Reg loss: 0.048 | Tree loss: 1.298 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 78 | Batch: 013 / 030 | Total loss: 1.321 | Reg loss: 0.048 | Tree loss: 1.321 | Accuracy: 0.546875 | 0.907 sec/iter
Epoch: 78 | Batch: 014 / 030 | Total loss: 1.246 | Reg loss: 0.049 | Tree loss: 1.246 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 78 | Batch: 015 / 030 | Total loss: 1.236 | Reg loss: 0.049 | Tree loss: 1.236 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 78 | Batch: 016 / 030 | Total loss: 1.216 | Reg loss: 0.049 | Tree loss: 1.216 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 78 | Batch: 017 / 030 | Total loss: 1.160 | Reg loss: 0.049 | Tree loss: 1.160 | Accuracy: 0.613281 | 0.907 sec/iter
Epoch: 78 | Batch: 018 / 030 | Total loss: 1.173 | Reg loss: 0.049 | Tree loss: 1.173 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 78 | Batch: 019 / 030 | Total loss: 1.168 | Reg loss: 0.049 | Tree loss: 1.168 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 7

Epoch: 80 | Batch: 014 / 030 | Total loss: 1.225 | Reg loss: 0.048 | Tree loss: 1.225 | Accuracy: 0.576172 | 0.907 sec/iter
Epoch: 80 | Batch: 015 / 030 | Total loss: 1.229 | Reg loss: 0.048 | Tree loss: 1.229 | Accuracy: 0.615234 | 0.907 sec/iter
Epoch: 80 | Batch: 016 / 030 | Total loss: 1.209 | Reg loss: 0.049 | Tree loss: 1.209 | Accuracy: 0.550781 | 0.907 sec/iter
Epoch: 80 | Batch: 017 / 030 | Total loss: 1.200 | Reg loss: 0.049 | Tree loss: 1.200 | Accuracy: 0.544922 | 0.907 sec/iter
Epoch: 80 | Batch: 018 / 030 | Total loss: 1.152 | Reg loss: 0.049 | Tree loss: 1.152 | Accuracy: 0.544922 | 0.907 sec/iter
Epoch: 80 | Batch: 019 / 030 | Total loss: 1.129 | Reg loss: 0.049 | Tree loss: 1.129 | Accuracy: 0.615234 | 0.907 sec/iter
Epoch: 80 | Batch: 020 / 030 | Total loss: 1.126 | Reg loss: 0.049 | Tree loss: 1.126 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 80 | Batch: 021 / 030 | Total loss: 1.114 | Reg loss: 0.049 | Tree loss: 1.114 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 8

Epoch: 82 | Batch: 016 / 030 | Total loss: 1.189 | Reg loss: 0.048 | Tree loss: 1.189 | Accuracy: 0.597656 | 0.907 sec/iter
Epoch: 82 | Batch: 017 / 030 | Total loss: 1.197 | Reg loss: 0.048 | Tree loss: 1.197 | Accuracy: 0.576172 | 0.907 sec/iter
Epoch: 82 | Batch: 018 / 030 | Total loss: 1.169 | Reg loss: 0.049 | Tree loss: 1.169 | Accuracy: 0.550781 | 0.907 sec/iter
Epoch: 82 | Batch: 019 / 030 | Total loss: 1.130 | Reg loss: 0.049 | Tree loss: 1.130 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 82 | Batch: 020 / 030 | Total loss: 1.097 | Reg loss: 0.049 | Tree loss: 1.097 | Accuracy: 0.605469 | 0.907 sec/iter
Epoch: 82 | Batch: 021 / 030 | Total loss: 1.094 | Reg loss: 0.049 | Tree loss: 1.094 | Accuracy: 0.554688 | 0.907 sec/iter
Epoch: 82 | Batch: 022 / 030 | Total loss: 1.098 | Reg loss: 0.049 | Tree loss: 1.098 | Accuracy: 0.527344 | 0.907 sec/iter
Epoch: 82 | Batch: 023 / 030 | Total loss: 1.054 | Reg loss: 0.049 | Tree loss: 1.054 | Accuracy: 0.548828 | 0.907 sec/iter
Epoch: 8

Epoch: 84 | Batch: 018 / 030 | Total loss: 1.133 | Reg loss: 0.048 | Tree loss: 1.133 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 84 | Batch: 019 / 030 | Total loss: 1.139 | Reg loss: 0.048 | Tree loss: 1.139 | Accuracy: 0.560547 | 0.907 sec/iter
Epoch: 84 | Batch: 020 / 030 | Total loss: 1.136 | Reg loss: 0.049 | Tree loss: 1.136 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 84 | Batch: 021 / 030 | Total loss: 1.100 | Reg loss: 0.049 | Tree loss: 1.100 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 84 | Batch: 022 / 030 | Total loss: 1.075 | Reg loss: 0.049 | Tree loss: 1.075 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 84 | Batch: 023 / 030 | Total loss: 1.040 | Reg loss: 0.049 | Tree loss: 1.040 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 84 | Batch: 024 / 030 | Total loss: 1.068 | Reg loss: 0.049 | Tree loss: 1.068 | Accuracy: 0.601562 | 0.907 sec/iter
Epoch: 84 | Batch: 025 / 030 | Total loss: 1.007 | Reg loss: 0.049 | Tree loss: 1.007 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 8

Epoch: 86 | Batch: 020 / 030 | Total loss: 1.112 | Reg loss: 0.048 | Tree loss: 1.112 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 86 | Batch: 021 / 030 | Total loss: 1.102 | Reg loss: 0.049 | Tree loss: 1.102 | Accuracy: 0.574219 | 0.907 sec/iter
Epoch: 86 | Batch: 022 / 030 | Total loss: 1.094 | Reg loss: 0.049 | Tree loss: 1.094 | Accuracy: 0.539062 | 0.907 sec/iter
Epoch: 86 | Batch: 023 / 030 | Total loss: 1.052 | Reg loss: 0.049 | Tree loss: 1.052 | Accuracy: 0.556641 | 0.907 sec/iter
Epoch: 86 | Batch: 024 / 030 | Total loss: 1.009 | Reg loss: 0.049 | Tree loss: 1.009 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 86 | Batch: 025 / 030 | Total loss: 1.018 | Reg loss: 0.049 | Tree loss: 1.018 | Accuracy: 0.570312 | 0.907 sec/iter
Epoch: 86 | Batch: 026 / 030 | Total loss: 0.994 | Reg loss: 0.049 | Tree loss: 0.994 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 86 | Batch: 027 / 030 | Total loss: 0.979 | Reg loss: 0.049 | Tree loss: 0.979 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 8

Epoch: 88 | Batch: 022 / 030 | Total loss: 1.076 | Reg loss: 0.048 | Tree loss: 1.076 | Accuracy: 0.548828 | 0.907 sec/iter
Epoch: 88 | Batch: 023 / 030 | Total loss: 1.035 | Reg loss: 0.049 | Tree loss: 1.035 | Accuracy: 0.585938 | 0.907 sec/iter
Epoch: 88 | Batch: 024 / 030 | Total loss: 1.016 | Reg loss: 0.049 | Tree loss: 1.016 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 88 | Batch: 025 / 030 | Total loss: 1.010 | Reg loss: 0.049 | Tree loss: 1.010 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 88 | Batch: 026 / 030 | Total loss: 0.993 | Reg loss: 0.049 | Tree loss: 0.993 | Accuracy: 0.578125 | 0.907 sec/iter
Epoch: 88 | Batch: 027 / 030 | Total loss: 1.009 | Reg loss: 0.049 | Tree loss: 1.009 | Accuracy: 0.552734 | 0.907 sec/iter
Epoch: 88 | Batch: 028 / 030 | Total loss: 0.968 | Reg loss: 0.049 | Tree loss: 0.968 | Accuracy: 0.589844 | 0.907 sec/iter
Epoch: 88 | Batch: 029 / 030 | Total loss: 0.962 | Reg loss: 0.049 | Tree loss: 0.962 | Accuracy: 0.533333 | 0.907 sec/iter
Average 

Epoch: 90 | Batch: 024 / 030 | Total loss: 1.019 | Reg loss: 0.049 | Tree loss: 1.019 | Accuracy: 0.593750 | 0.907 sec/iter
Epoch: 90 | Batch: 025 / 030 | Total loss: 1.017 | Reg loss: 0.049 | Tree loss: 1.017 | Accuracy: 0.566406 | 0.907 sec/iter
Epoch: 90 | Batch: 026 / 030 | Total loss: 1.012 | Reg loss: 0.049 | Tree loss: 1.012 | Accuracy: 0.566406 | 0.907 sec/iter
Epoch: 90 | Batch: 027 / 030 | Total loss: 0.967 | Reg loss: 0.049 | Tree loss: 0.967 | Accuracy: 0.587891 | 0.907 sec/iter
Epoch: 90 | Batch: 028 / 030 | Total loss: 0.970 | Reg loss: 0.049 | Tree loss: 0.970 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 90 | Batch: 029 / 030 | Total loss: 0.966 | Reg loss: 0.049 | Tree loss: 0.966 | Accuracy: 0.571429 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.98214285714

Epoch: 92 | Batch: 026 / 030 | Total loss: 0.975 | Reg loss: 0.049 | Tree loss: 0.975 | Accuracy: 0.595703 | 0.907 sec/iter
Epoch: 92 | Batch: 027 / 030 | Total loss: 0.984 | Reg loss: 0.049 | Tree loss: 0.984 | Accuracy: 0.568359 | 0.907 sec/iter
Epoch: 92 | Batch: 028 / 030 | Total loss: 0.953 | Reg loss: 0.049 | Tree loss: 0.953 | Accuracy: 0.607422 | 0.907 sec/iter
Epoch: 92 | Batch: 029 / 030 | Total loss: 0.929 | Reg loss: 0.049 | Tree loss: 0.929 | Accuracy: 0.580952 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 93 | Batch: 000 / 030 | Total loss: 1.690 | Reg loss: 0.046 | Tree loss: 1.690 | Accuracy: 0.582031 | 0.907 sec/iter
Epoch: 93 | Batch: 001 / 030 | Total loss: 1.675 | Reg loss: 0.046 | Tree loss: 1.675 | Ac

Epoch: 94 | Batch: 028 / 030 | Total loss: 0.965 | Reg loss: 0.049 | Tree loss: 0.965 | Accuracy: 0.580078 | 0.907 sec/iter
Epoch: 94 | Batch: 029 / 030 | Total loss: 0.962 | Reg loss: 0.049 | Tree loss: 0.962 | Accuracy: 0.561905 | 0.907 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 95 | Batch: 000 / 030 | Total loss: 1.688 | Reg loss: 0.046 | Tree loss: 1.688 | Accuracy: 0.583984 | 0.907 sec/iter
Epoch: 95 | Batch: 001 / 030 | Total loss: 1.598 | Reg loss: 0.046 | Tree loss: 1.598 | Accuracy: 0.572266 | 0.907 sec/iter
Epoch: 95 | Batch: 002 / 030 | Total loss: 1.625 | Reg loss: 0.046 | Tree loss: 1.625 | Accuracy: 0.562500 | 0.907 sec/iter
Epoch: 95 | Batch: 003 / 030 | Total loss: 1.628 | Reg loss: 0.046 | Tree loss: 1.628 | Ac

Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 97 | Batch: 000 / 030 | Total loss: 1.613 | Reg loss: 0.046 | Tree loss: 1.613 | Accuracy: 0.554688 | 0.906 sec/iter
Epoch: 97 | Batch: 001 / 030 | Total loss: 1.663 | Reg loss: 0.046 | Tree loss: 1.663 | Accuracy: 0.574219 | 0.906 sec/iter
Epoch: 97 | Batch: 002 / 030 | Total loss: 1.588 | Reg loss: 0.046 | Tree loss: 1.588 | Accuracy: 0.566406 | 0.906 sec/iter
Epoch: 97 | Batch: 003 / 030 | Total loss: 1.571 | Reg loss: 0.046 | Tree loss: 1.571 | Accuracy: 0.578125 | 0.906 sec/iter
Epoch: 97 | Batch: 004 / 030 | Total loss: 1.510 | Reg loss: 0.046 | Tree loss: 1.510 | Accuracy: 0.572266 | 0.906 sec/iter
Epoch: 97 | Batch: 005 / 030 | Total loss: 1.469 | Reg loss: 0.046 | Tree loss: 1.469 | Ac

Epoch: 99 | Batch: 000 / 030 | Total loss: 1.669 | Reg loss: 0.046 | Tree loss: 1.669 | Accuracy: 0.583984 | 0.906 sec/iter
Epoch: 99 | Batch: 001 / 030 | Total loss: 1.619 | Reg loss: 0.046 | Tree loss: 1.619 | Accuracy: 0.582031 | 0.906 sec/iter
Epoch: 99 | Batch: 002 / 030 | Total loss: 1.657 | Reg loss: 0.046 | Tree loss: 1.657 | Accuracy: 0.576172 | 0.906 sec/iter
Epoch: 99 | Batch: 003 / 030 | Total loss: 1.579 | Reg loss: 0.046 | Tree loss: 1.579 | Accuracy: 0.574219 | 0.906 sec/iter
Epoch: 99 | Batch: 004 / 030 | Total loss: 1.505 | Reg loss: 0.046 | Tree loss: 1.505 | Accuracy: 0.568359 | 0.906 sec/iter
Epoch: 99 | Batch: 005 / 030 | Total loss: 1.545 | Reg loss: 0.046 | Tree loss: 1.545 | Accuracy: 0.562500 | 0.906 sec/iter
Epoch: 99 | Batch: 006 / 030 | Total loss: 1.448 | Reg loss: 0.046 | Tree loss: 1.448 | Accuracy: 0.601562 | 0.906 sec/iter
Epoch: 99 | Batch: 007 / 030 | Total loss: 1.429 | Reg loss: 0.046 | Tree loss: 1.429 | Accuracy: 0.578125 | 0.906 sec/iter
Epoch: 9

In [31]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [32]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [33]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 9.191358024691358


# Extract Rules

# Accumulate samples in the leaves

In [34]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 324


In [35]:
method = 'greedy'

In [36]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)



# Tighten boundaries

In [37]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

  return np.log(1 / (1 - x))


9468
3127
1188
103
120
947
Average comprehensibility: 48.18518518518518
std comprehensibility: 6.052468348851946
var comprehensibility: 36.6323731138546
minimum comprehensibility: 24
maximum comprehensibility: 58
