In [89]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.market_basket_dataset import MarketBasketDataset, BinaryEncodingTransform, RemoveItemsTransform
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

## Create the market basket dataset and use one-hot encoding for items

In [2]:
dataset = MarketBasketDataset(dataset_path=r"C:\Users\eitan\PycharmProjects\EntangledExplainableClustering\data\Groceries_dataset.csv")

In [4]:
model = AutoEncoder(dataset.n_items, 50, 4)
epochs = 500
lr = 5e-3
batch_size = 512
log_every = 5

In [5]:
dataset.transform = torchvision.transforms.Compose([
    RemoveItemsTransform(p=0.5),
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [6]:
model.train()
data_iter = torch.utils.data.DataLoader(dataset,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=1,
                                     pin_memory=True)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, factor=0.7, threshold=1e-4)
knn_crt = KNNLoss(k=32)
losses = []
alpha = 10/170
gamma = 2
for epoch in range(epochs):
    total_loss = 0
    for iteration, (batch, target) in enumerate(data_iter):
        outputs, iterm = model(batch, return_intermidiate=True)
        mse_loss = F.binary_cross_entropy_with_logits(outputs, target, reduction='none')
        mask = torch.ones_like(mse_loss)
        mask[target == 0] = alpha ** gamma
        mask[target == 1] = (1 - alpha) ** gamma
        mse_loss = (mse_loss * mask).sum(dim=-1).mean()
        knn_loss = knn_crt(iterm)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(data_iter)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | BCE Loss: {mse_loss.item()}")
    
    scheduler.step(total_loss / (iteration + 1))
    losses.append(total_loss / (iteration + 1))

Epoch 0 / 500 | iteration 0 / 30 | Total Loss: 8.115936279296875 | KNN Loss: 6.224985599517822 | BCE Loss: 1.8909504413604736
Epoch 0 / 500 | iteration 5 / 30 | Total Loss: 8.152900695800781 | KNN Loss: 6.225427150726318 | BCE Loss: 1.927473783493042
Epoch 0 / 500 | iteration 10 / 30 | Total Loss: 8.18317985534668 | KNN Loss: 6.225057125091553 | BCE Loss: 1.9581222534179688
Epoch 0 / 500 | iteration 15 / 30 | Total Loss: 8.14844799041748 | KNN Loss: 6.224520683288574 | BCE Loss: 1.9239273071289062
Epoch 0 / 500 | iteration 20 / 30 | Total Loss: 8.107247352600098 | KNN Loss: 6.224483489990234 | BCE Loss: 1.8827638626098633
Epoch 0 / 500 | iteration 25 / 30 | Total Loss: 8.133842468261719 | KNN Loss: 6.224823474884033 | BCE Loss: 1.9090194702148438
Epoch 1 / 500 | iteration 0 / 30 | Total Loss: 8.103509902954102 | KNN Loss: 6.223139762878418 | BCE Loss: 1.880369782447815
Epoch 1 / 500 | iteration 5 / 30 | Total Loss: 8.12332534790039 | KNN Loss: 6.223299980163574 | BCE Loss: 1.9000256061

Epoch 10 / 500 | iteration 25 / 30 | Total Loss: 5.691650390625 | KNN Loss: 4.585964202880859 | BCE Loss: 1.1056861877441406
Epoch 11 / 500 | iteration 0 / 30 | Total Loss: 5.657222747802734 | KNN Loss: 4.540197849273682 | BCE Loss: 1.1170251369476318
Epoch 11 / 500 | iteration 5 / 30 | Total Loss: 5.559140682220459 | KNN Loss: 4.475207328796387 | BCE Loss: 1.0839334726333618
Epoch 11 / 500 | iteration 10 / 30 | Total Loss: 5.476250648498535 | KNN Loss: 4.389369010925293 | BCE Loss: 1.0868817567825317
Epoch 11 / 500 | iteration 15 / 30 | Total Loss: 5.420239448547363 | KNN Loss: 4.333858013153076 | BCE Loss: 1.0863815546035767
Epoch 11 / 500 | iteration 20 / 30 | Total Loss: 5.372244834899902 | KNN Loss: 4.272647857666016 | BCE Loss: 1.0995970964431763
Epoch 11 / 500 | iteration 25 / 30 | Total Loss: 5.269759178161621 | KNN Loss: 4.20658540725708 | BCE Loss: 1.0631736516952515
Epoch 12 / 500 | iteration 0 / 30 | Total Loss: 5.299194812774658 | KNN Loss: 4.192719459533691 | BCE Loss: 1.

Epoch 21 / 500 | iteration 20 / 30 | Total Loss: 4.917761325836182 | KNN Loss: 3.854158878326416 | BCE Loss: 1.0636024475097656
Epoch 21 / 500 | iteration 25 / 30 | Total Loss: 4.8506550788879395 | KNN Loss: 3.80037522315979 | BCE Loss: 1.0502798557281494
Epoch 22 / 500 | iteration 0 / 30 | Total Loss: 4.865542888641357 | KNN Loss: 3.8176093101501465 | BCE Loss: 1.0479336977005005
Epoch 22 / 500 | iteration 5 / 30 | Total Loss: 4.878241539001465 | KNN Loss: 3.829946279525757 | BCE Loss: 1.0482953786849976
Epoch 22 / 500 | iteration 10 / 30 | Total Loss: 4.888167858123779 | KNN Loss: 3.831289768218994 | BCE Loss: 1.0568782091140747
Epoch 22 / 500 | iteration 15 / 30 | Total Loss: 4.8718719482421875 | KNN Loss: 3.817835569381714 | BCE Loss: 1.0540366172790527
Epoch 22 / 500 | iteration 20 / 30 | Total Loss: 4.849645614624023 | KNN Loss: 3.808359146118164 | BCE Loss: 1.0412867069244385
Epoch 22 / 500 | iteration 25 / 30 | Total Loss: 4.8606719970703125 | KNN Loss: 3.8155460357666016 | BCE

Epoch 32 / 500 | iteration 15 / 30 | Total Loss: 4.803418159484863 | KNN Loss: 3.7790985107421875 | BCE Loss: 1.0243196487426758
Epoch 32 / 500 | iteration 20 / 30 | Total Loss: 4.8621320724487305 | KNN Loss: 3.814321994781494 | BCE Loss: 1.0478098392486572
Epoch 32 / 500 | iteration 25 / 30 | Total Loss: 4.788928031921387 | KNN Loss: 3.738187551498413 | BCE Loss: 1.0507407188415527
Epoch 33 / 500 | iteration 0 / 30 | Total Loss: 4.772955894470215 | KNN Loss: 3.7529213428497314 | BCE Loss: 1.020034670829773
Epoch 33 / 500 | iteration 5 / 30 | Total Loss: 4.833101749420166 | KNN Loss: 3.7818121910095215 | BCE Loss: 1.051289439201355
Epoch 33 / 500 | iteration 10 / 30 | Total Loss: 4.817929267883301 | KNN Loss: 3.795948028564453 | BCE Loss: 1.0219810009002686
Epoch 33 / 500 | iteration 15 / 30 | Total Loss: 4.77934455871582 | KNN Loss: 3.7568111419677734 | BCE Loss: 1.0225335359573364
Epoch 33 / 500 | iteration 20 / 30 | Total Loss: 4.878941535949707 | KNN Loss: 3.80889630317688 | BCE Lo

Epoch 43 / 500 | iteration 10 / 30 | Total Loss: 4.8286662101745605 | KNN Loss: 3.7770891189575195 | BCE Loss: 1.0515769720077515
Epoch 43 / 500 | iteration 15 / 30 | Total Loss: 4.798503398895264 | KNN Loss: 3.756279706954956 | BCE Loss: 1.042223572731018
Epoch 43 / 500 | iteration 20 / 30 | Total Loss: 4.760849952697754 | KNN Loss: 3.7435646057128906 | BCE Loss: 1.0172855854034424
Epoch 43 / 500 | iteration 25 / 30 | Total Loss: 4.74405574798584 | KNN Loss: 3.736921548843384 | BCE Loss: 1.007134199142456
Epoch 44 / 500 | iteration 0 / 30 | Total Loss: 4.801478862762451 | KNN Loss: 3.750027894973755 | BCE Loss: 1.0514508485794067
Epoch 44 / 500 | iteration 5 / 30 | Total Loss: 4.815277099609375 | KNN Loss: 3.7805542945861816 | BCE Loss: 1.0347228050231934
Epoch 44 / 500 | iteration 10 / 30 | Total Loss: 4.7980265617370605 | KNN Loss: 3.7462282180786133 | BCE Loss: 1.0517983436584473
Epoch 44 / 500 | iteration 15 / 30 | Total Loss: 4.757270336151123 | KNN Loss: 3.745525360107422 | BCE 

Epoch 54 / 500 | iteration 5 / 30 | Total Loss: 4.752335548400879 | KNN Loss: 3.7268314361572266 | BCE Loss: 1.025504231452942
Epoch 54 / 500 | iteration 10 / 30 | Total Loss: 4.776770114898682 | KNN Loss: 3.730642318725586 | BCE Loss: 1.0461277961730957
Epoch 54 / 500 | iteration 15 / 30 | Total Loss: 4.769738674163818 | KNN Loss: 3.7460007667541504 | BCE Loss: 1.0237380266189575
Epoch 54 / 500 | iteration 20 / 30 | Total Loss: 4.758370399475098 | KNN Loss: 3.729994535446167 | BCE Loss: 1.0283759832382202
Epoch 54 / 500 | iteration 25 / 30 | Total Loss: 4.734597206115723 | KNN Loss: 3.7082557678222656 | BCE Loss: 1.026341438293457
Epoch 55 / 500 | iteration 0 / 30 | Total Loss: 4.761995315551758 | KNN Loss: 3.7163920402526855 | BCE Loss: 1.0456032752990723
Epoch 55 / 500 | iteration 5 / 30 | Total Loss: 4.832346439361572 | KNN Loss: 3.7850418090820312 | BCE Loss: 1.047304630279541
Epoch 55 / 500 | iteration 10 / 30 | Total Loss: 4.786372661590576 | KNN Loss: 3.743522882461548 | BCE Lo

Epoch 65 / 500 | iteration 0 / 30 | Total Loss: 4.734915733337402 | KNN Loss: 3.7143349647521973 | BCE Loss: 1.0205806493759155
Epoch 65 / 500 | iteration 5 / 30 | Total Loss: 4.7693681716918945 | KNN Loss: 3.749528646469116 | BCE Loss: 1.0198394060134888
Epoch 65 / 500 | iteration 10 / 30 | Total Loss: 4.743477821350098 | KNN Loss: 3.7292909622192383 | BCE Loss: 1.0141867399215698
Epoch 65 / 500 | iteration 15 / 30 | Total Loss: 4.7527570724487305 | KNN Loss: 3.714845657348633 | BCE Loss: 1.037911295890808
Epoch 65 / 500 | iteration 20 / 30 | Total Loss: 4.719400882720947 | KNN Loss: 3.6981663703918457 | BCE Loss: 1.0212345123291016
Epoch 65 / 500 | iteration 25 / 30 | Total Loss: 4.689562797546387 | KNN Loss: 3.6784324645996094 | BCE Loss: 1.0111305713653564
Epoch 66 / 500 | iteration 0 / 30 | Total Loss: 4.686245441436768 | KNN Loss: 3.6960532665252686 | BCE Loss: 0.9901921153068542
Epoch 66 / 500 | iteration 5 / 30 | Total Loss: 4.756433963775635 | KNN Loss: 3.7147462368011475 | BC

Epoch 75 / 500 | iteration 25 / 30 | Total Loss: 4.747470855712891 | KNN Loss: 3.7174339294433594 | BCE Loss: 1.0300371646881104
Epoch 76 / 500 | iteration 0 / 30 | Total Loss: 4.704168319702148 | KNN Loss: 3.698162317276001 | BCE Loss: 1.006005883216858
Epoch 76 / 500 | iteration 5 / 30 | Total Loss: 4.730522155761719 | KNN Loss: 3.7140674591064453 | BCE Loss: 1.0164546966552734
Epoch 76 / 500 | iteration 10 / 30 | Total Loss: 4.772841930389404 | KNN Loss: 3.7469029426574707 | BCE Loss: 1.025938868522644
Epoch 76 / 500 | iteration 15 / 30 | Total Loss: 4.81037712097168 | KNN Loss: 3.77256441116333 | BCE Loss: 1.0378127098083496
Epoch 76 / 500 | iteration 20 / 30 | Total Loss: 4.759209632873535 | KNN Loss: 3.7385241985321045 | BCE Loss: 1.0206855535507202
Epoch 76 / 500 | iteration 25 / 30 | Total Loss: 4.692880630493164 | KNN Loss: 3.676941156387329 | BCE Loss: 1.0159395933151245
Epoch 77 / 500 | iteration 0 / 30 | Total Loss: 4.740056037902832 | KNN Loss: 3.727343797683716 | BCE Loss

Epoch 86 / 500 | iteration 20 / 30 | Total Loss: 4.7309064865112305 | KNN Loss: 3.705573081970215 | BCE Loss: 1.0253334045410156
Epoch 86 / 500 | iteration 25 / 30 | Total Loss: 4.747256278991699 | KNN Loss: 3.710092306137085 | BCE Loss: 1.0371637344360352
Epoch 87 / 500 | iteration 0 / 30 | Total Loss: 4.806870460510254 | KNN Loss: 3.7454051971435547 | BCE Loss: 1.0614655017852783
Epoch 87 / 500 | iteration 5 / 30 | Total Loss: 4.725842475891113 | KNN Loss: 3.720979690551758 | BCE Loss: 1.004862904548645
Epoch 87 / 500 | iteration 10 / 30 | Total Loss: 4.717747688293457 | KNN Loss: 3.6982581615448 | BCE Loss: 1.0194897651672363
Epoch 87 / 500 | iteration 15 / 30 | Total Loss: 4.732027053833008 | KNN Loss: 3.6981732845306396 | BCE Loss: 1.0338537693023682
Epoch 87 / 500 | iteration 20 / 30 | Total Loss: 4.715744972229004 | KNN Loss: 3.692932367324829 | BCE Loss: 1.0228126049041748
Epoch 87 / 500 | iteration 25 / 30 | Total Loss: 4.781083583831787 | KNN Loss: 3.7257699966430664 | BCE Lo

Epoch 97 / 500 | iteration 10 / 30 | Total Loss: 4.726659774780273 | KNN Loss: 3.697092294692993 | BCE Loss: 1.0295674800872803
Epoch 97 / 500 | iteration 15 / 30 | Total Loss: 4.79210090637207 | KNN Loss: 3.75777268409729 | BCE Loss: 1.0343281030654907
Epoch 97 / 500 | iteration 20 / 30 | Total Loss: 4.7152581214904785 | KNN Loss: 3.710667133331299 | BCE Loss: 1.0045909881591797
Epoch 97 / 500 | iteration 25 / 30 | Total Loss: 4.714756965637207 | KNN Loss: 3.689267158508301 | BCE Loss: 1.0254895687103271
Epoch 98 / 500 | iteration 0 / 30 | Total Loss: 4.711281776428223 | KNN Loss: 3.6811327934265137 | BCE Loss: 1.0301487445831299
Epoch 98 / 500 | iteration 5 / 30 | Total Loss: 4.758164882659912 | KNN Loss: 3.739976644515991 | BCE Loss: 1.0181881189346313
Epoch 98 / 500 | iteration 10 / 30 | Total Loss: 4.711976051330566 | KNN Loss: 3.7014362812042236 | BCE Loss: 1.0105395317077637
Epoch 98 / 500 | iteration 15 / 30 | Total Loss: 4.69168758392334 | KNN Loss: 3.7096893787384033 | BCE Lo

Epoch 108 / 500 | iteration 0 / 30 | Total Loss: 4.672432899475098 | KNN Loss: 3.656273365020752 | BCE Loss: 1.0161592960357666
Epoch 108 / 500 | iteration 5 / 30 | Total Loss: 4.719854354858398 | KNN Loss: 3.708113193511963 | BCE Loss: 1.0117411613464355
Epoch 108 / 500 | iteration 10 / 30 | Total Loss: 4.710195064544678 | KNN Loss: 3.682309865951538 | BCE Loss: 1.0278851985931396
Epoch 108 / 500 | iteration 15 / 30 | Total Loss: 4.722254753112793 | KNN Loss: 3.703908920288086 | BCE Loss: 1.018345594406128
Epoch 108 / 500 | iteration 20 / 30 | Total Loss: 4.760400772094727 | KNN Loss: 3.734800338745117 | BCE Loss: 1.0256006717681885
Epoch 108 / 500 | iteration 25 / 30 | Total Loss: 4.755642414093018 | KNN Loss: 3.6941463947296143 | BCE Loss: 1.0614960193634033
Epoch 109 / 500 | iteration 0 / 30 | Total Loss: 4.730020523071289 | KNN Loss: 3.6985361576080322 | BCE Loss: 1.0314841270446777
Epoch 109 / 500 | iteration 5 / 30 | Total Loss: 4.752190589904785 | KNN Loss: 3.7291274070739746 |

Epoch 118 / 500 | iteration 20 / 30 | Total Loss: 4.7846903800964355 | KNN Loss: 3.715196132659912 | BCE Loss: 1.0694941282272339
Epoch 118 / 500 | iteration 25 / 30 | Total Loss: 4.709635257720947 | KNN Loss: 3.712170124053955 | BCE Loss: 0.9974650740623474
Epoch 119 / 500 | iteration 0 / 30 | Total Loss: 4.7561163902282715 | KNN Loss: 3.711711883544922 | BCE Loss: 1.0444045066833496
Epoch 119 / 500 | iteration 5 / 30 | Total Loss: 4.736172676086426 | KNN Loss: 3.684737205505371 | BCE Loss: 1.0514353513717651
Epoch 119 / 500 | iteration 10 / 30 | Total Loss: 4.700291633605957 | KNN Loss: 3.7006659507751465 | BCE Loss: 0.9996259212493896
Epoch 119 / 500 | iteration 15 / 30 | Total Loss: 4.763790130615234 | KNN Loss: 3.7218568325042725 | BCE Loss: 1.0419334173202515
Epoch 119 / 500 | iteration 20 / 30 | Total Loss: 4.69743537902832 | KNN Loss: 3.6858415603637695 | BCE Loss: 1.0115939378738403
Epoch 119 / 500 | iteration 25 / 30 | Total Loss: 4.726835250854492 | KNN Loss: 3.6979422569274

Epoch 129 / 500 | iteration 10 / 30 | Total Loss: 4.721139907836914 | KNN Loss: 3.6969926357269287 | BCE Loss: 1.0241472721099854
Epoch 129 / 500 | iteration 15 / 30 | Total Loss: 4.726146221160889 | KNN Loss: 3.702768564224243 | BCE Loss: 1.023377776145935
Epoch 129 / 500 | iteration 20 / 30 | Total Loss: 4.717582702636719 | KNN Loss: 3.6911728382110596 | BCE Loss: 1.02640962600708
Epoch 129 / 500 | iteration 25 / 30 | Total Loss: 4.732463359832764 | KNN Loss: 3.716783046722412 | BCE Loss: 1.0156803131103516
Epoch 130 / 500 | iteration 0 / 30 | Total Loss: 4.725879192352295 | KNN Loss: 3.704484224319458 | BCE Loss: 1.0213948488235474
Epoch 130 / 500 | iteration 5 / 30 | Total Loss: 4.7041449546813965 | KNN Loss: 3.6880252361297607 | BCE Loss: 1.0161197185516357
Epoch 130 / 500 | iteration 10 / 30 | Total Loss: 4.7561445236206055 | KNN Loss: 3.70953369140625 | BCE Loss: 1.0466105937957764
Epoch 130 / 500 | iteration 15 / 30 | Total Loss: 4.744792461395264 | KNN Loss: 3.698913335800171 

Epoch 140 / 500 | iteration 0 / 30 | Total Loss: 4.706297397613525 | KNN Loss: 3.695859909057617 | BCE Loss: 1.0104373693466187
Epoch 140 / 500 | iteration 5 / 30 | Total Loss: 4.726670742034912 | KNN Loss: 3.6705219745635986 | BCE Loss: 1.0561487674713135
Epoch 140 / 500 | iteration 10 / 30 | Total Loss: 4.730268478393555 | KNN Loss: 3.7021918296813965 | BCE Loss: 1.0280765295028687
Epoch 140 / 500 | iteration 15 / 30 | Total Loss: 4.704137325286865 | KNN Loss: 3.700000047683716 | BCE Loss: 1.0041372776031494
Epoch 140 / 500 | iteration 20 / 30 | Total Loss: 4.699563503265381 | KNN Loss: 3.6826281547546387 | BCE Loss: 1.0169353485107422
Epoch 140 / 500 | iteration 25 / 30 | Total Loss: 4.741819381713867 | KNN Loss: 3.6940155029296875 | BCE Loss: 1.0478041172027588
Epoch 141 / 500 | iteration 0 / 30 | Total Loss: 4.686394691467285 | KNN Loss: 3.6744132041931152 | BCE Loss: 1.011981725692749
Epoch 141 / 500 | iteration 5 / 30 | Total Loss: 4.7060394287109375 | KNN Loss: 3.68632292747497

Epoch 150 / 500 | iteration 20 / 30 | Total Loss: 4.758878707885742 | KNN Loss: 3.700343132019043 | BCE Loss: 1.0585356950759888
Epoch 150 / 500 | iteration 25 / 30 | Total Loss: 4.734784126281738 | KNN Loss: 3.7115416526794434 | BCE Loss: 1.0232425928115845
Epoch 151 / 500 | iteration 0 / 30 | Total Loss: 4.756502628326416 | KNN Loss: 3.7077901363372803 | BCE Loss: 1.0487124919891357
Epoch 151 / 500 | iteration 5 / 30 | Total Loss: 4.788705825805664 | KNN Loss: 3.769190549850464 | BCE Loss: 1.0195152759552002
Epoch 151 / 500 | iteration 10 / 30 | Total Loss: 4.681026458740234 | KNN Loss: 3.6699345111846924 | BCE Loss: 1.011091709136963
Epoch 151 / 500 | iteration 15 / 30 | Total Loss: 4.7496538162231445 | KNN Loss: 3.7344610691070557 | BCE Loss: 1.0151928663253784
Epoch 151 / 500 | iteration 20 / 30 | Total Loss: 4.742631435394287 | KNN Loss: 3.6847493648529053 | BCE Loss: 1.0578820705413818
Epoch 151 / 500 | iteration 25 / 30 | Total Loss: 4.704981803894043 | KNN Loss: 3.718004703521

Epoch 161 / 500 | iteration 10 / 30 | Total Loss: 4.737645149230957 | KNN Loss: 3.7296502590179443 | BCE Loss: 1.0079951286315918
Epoch 161 / 500 | iteration 15 / 30 | Total Loss: 4.715563774108887 | KNN Loss: 3.686504364013672 | BCE Loss: 1.0290594100952148
Epoch 161 / 500 | iteration 20 / 30 | Total Loss: 4.7935471534729 | KNN Loss: 3.741656541824341 | BCE Loss: 1.0518906116485596
Epoch 161 / 500 | iteration 25 / 30 | Total Loss: 4.704719066619873 | KNN Loss: 3.703517198562622 | BCE Loss: 1.001201868057251
Epoch 162 / 500 | iteration 0 / 30 | Total Loss: 4.721606731414795 | KNN Loss: 3.7056291103363037 | BCE Loss: 1.0159775018692017
Epoch 162 / 500 | iteration 5 / 30 | Total Loss: 4.732273101806641 | KNN Loss: 3.7069590091705322 | BCE Loss: 1.0253138542175293
Epoch 162 / 500 | iteration 10 / 30 | Total Loss: 4.764775276184082 | KNN Loss: 3.7357966899871826 | BCE Loss: 1.0289788246154785
Epoch 162 / 500 | iteration 15 / 30 | Total Loss: 4.785272121429443 | KNN Loss: 3.7388341426849365

Epoch 172 / 500 | iteration 0 / 30 | Total Loss: 4.663305282592773 | KNN Loss: 3.6540133953094482 | BCE Loss: 1.009291648864746
Epoch 172 / 500 | iteration 5 / 30 | Total Loss: 4.718652725219727 | KNN Loss: 3.712099075317383 | BCE Loss: 1.0065534114837646
Epoch 172 / 500 | iteration 10 / 30 | Total Loss: 4.698805809020996 | KNN Loss: 3.6795685291290283 | BCE Loss: 1.0192371606826782
Epoch 172 / 500 | iteration 15 / 30 | Total Loss: 4.670957565307617 | KNN Loss: 3.667003631591797 | BCE Loss: 1.0039538145065308
Epoch 172 / 500 | iteration 20 / 30 | Total Loss: 4.6957526206970215 | KNN Loss: 3.670919179916382 | BCE Loss: 1.02483332157135
Epoch 172 / 500 | iteration 25 / 30 | Total Loss: 4.711999893188477 | KNN Loss: 3.702652931213379 | BCE Loss: 1.0093467235565186
Epoch 173 / 500 | iteration 0 / 30 | Total Loss: 4.7630767822265625 | KNN Loss: 3.7098870277404785 | BCE Loss: 1.0531895160675049
Epoch 173 / 500 | iteration 5 / 30 | Total Loss: 4.769441604614258 | KNN Loss: 3.734678268432617 |

Epoch 182 / 500 | iteration 20 / 30 | Total Loss: 4.740086078643799 | KNN Loss: 3.6989693641662598 | BCE Loss: 1.0411168336868286
Epoch 182 / 500 | iteration 25 / 30 | Total Loss: 4.734223365783691 | KNN Loss: 3.696307420730591 | BCE Loss: 1.0379161834716797
Epoch 183 / 500 | iteration 0 / 30 | Total Loss: 4.7114152908325195 | KNN Loss: 3.6851110458374023 | BCE Loss: 1.0263042449951172
Epoch 183 / 500 | iteration 5 / 30 | Total Loss: 4.679931163787842 | KNN Loss: 3.699855327606201 | BCE Loss: 0.980076014995575
Epoch 183 / 500 | iteration 10 / 30 | Total Loss: 4.713754177093506 | KNN Loss: 3.6949334144592285 | BCE Loss: 1.018820881843567
Epoch 183 / 500 | iteration 15 / 30 | Total Loss: 4.775022506713867 | KNN Loss: 3.764246940612793 | BCE Loss: 1.0107755661010742
Epoch 183 / 500 | iteration 20 / 30 | Total Loss: 4.732752799987793 | KNN Loss: 3.715127468109131 | BCE Loss: 1.017625093460083
Epoch 183 / 500 | iteration 25 / 30 | Total Loss: 4.701748371124268 | KNN Loss: 3.6657233238220215

Epoch 193 / 500 | iteration 10 / 30 | Total Loss: 4.723447322845459 | KNN Loss: 3.6960062980651855 | BCE Loss: 1.0274409055709839
Epoch 193 / 500 | iteration 15 / 30 | Total Loss: 4.733638286590576 | KNN Loss: 3.676910161972046 | BCE Loss: 1.0567281246185303
Epoch 193 / 500 | iteration 20 / 30 | Total Loss: 4.7095489501953125 | KNN Loss: 3.6756317615509033 | BCE Loss: 1.0339170694351196
Epoch 193 / 500 | iteration 25 / 30 | Total Loss: 4.704442501068115 | KNN Loss: 3.692896604537964 | BCE Loss: 1.0115458965301514
Epoch 194 / 500 | iteration 0 / 30 | Total Loss: 4.6921162605285645 | KNN Loss: 3.688992738723755 | BCE Loss: 1.0031235218048096
Epoch 194 / 500 | iteration 5 / 30 | Total Loss: 4.690792083740234 | KNN Loss: 3.66967511177063 | BCE Loss: 1.0211167335510254
Epoch 194 / 500 | iteration 10 / 30 | Total Loss: 4.671422004699707 | KNN Loss: 3.666356086730957 | BCE Loss: 1.005065679550171
Epoch 194 / 500 | iteration 15 / 30 | Total Loss: 4.724226474761963 | KNN Loss: 3.686429262161255

Epoch 204 / 500 | iteration 0 / 30 | Total Loss: 4.7221245765686035 | KNN Loss: 3.692314386367798 | BCE Loss: 1.0298100709915161
Epoch 204 / 500 | iteration 5 / 30 | Total Loss: 4.7668938636779785 | KNN Loss: 3.7012853622436523 | BCE Loss: 1.0656083822250366
Epoch 204 / 500 | iteration 10 / 30 | Total Loss: 4.678687572479248 | KNN Loss: 3.671180009841919 | BCE Loss: 1.0075074434280396
Epoch 204 / 500 | iteration 15 / 30 | Total Loss: 4.723668098449707 | KNN Loss: 3.697073459625244 | BCE Loss: 1.0265947580337524
Epoch 204 / 500 | iteration 20 / 30 | Total Loss: 4.741018295288086 | KNN Loss: 3.7199437618255615 | BCE Loss: 1.0210745334625244
Epoch 204 / 500 | iteration 25 / 30 | Total Loss: 4.709545135498047 | KNN Loss: 3.688100814819336 | BCE Loss: 1.0214444398880005
Epoch 205 / 500 | iteration 0 / 30 | Total Loss: 4.7265825271606445 | KNN Loss: 3.7112069129943848 | BCE Loss: 1.0153757333755493
Epoch 205 / 500 | iteration 5 / 30 | Total Loss: 4.644457817077637 | KNN Loss: 3.6718425750732

Epoch 214 / 500 | iteration 15 / 30 | Total Loss: 4.727537155151367 | KNN Loss: 3.6818385124206543 | BCE Loss: 1.0456985235214233
Epoch 214 / 500 | iteration 20 / 30 | Total Loss: 4.741756439208984 | KNN Loss: 3.6893279552459717 | BCE Loss: 1.0524282455444336
Epoch 214 / 500 | iteration 25 / 30 | Total Loss: 4.702582836151123 | KNN Loss: 3.695657730102539 | BCE Loss: 1.0069252252578735
Epoch 215 / 500 | iteration 0 / 30 | Total Loss: 4.744598865509033 | KNN Loss: 3.693526029586792 | BCE Loss: 1.0510729551315308
Epoch 215 / 500 | iteration 5 / 30 | Total Loss: 4.73369026184082 | KNN Loss: 3.6861016750335693 | BCE Loss: 1.047588586807251
Epoch 215 / 500 | iteration 10 / 30 | Total Loss: 4.729886054992676 | KNN Loss: 3.6959750652313232 | BCE Loss: 1.0339107513427734
Epoch 215 / 500 | iteration 15 / 30 | Total Loss: 4.771734714508057 | KNN Loss: 3.701991319656372 | BCE Loss: 1.0697435140609741
Epoch 215 / 500 | iteration 20 / 30 | Total Loss: 4.730463981628418 | KNN Loss: 3.701320171356201

Epoch 225 / 500 | iteration 5 / 30 | Total Loss: 4.728729724884033 | KNN Loss: 3.6937434673309326 | BCE Loss: 1.0349862575531006
Epoch 225 / 500 | iteration 10 / 30 | Total Loss: 4.686501502990723 | KNN Loss: 3.6707472801208496 | BCE Loss: 1.0157541036605835
Epoch 225 / 500 | iteration 15 / 30 | Total Loss: 4.7180562019348145 | KNN Loss: 3.710627794265747 | BCE Loss: 1.0074282884597778
Epoch 225 / 500 | iteration 20 / 30 | Total Loss: 4.73882532119751 | KNN Loss: 3.7127084732055664 | BCE Loss: 1.0261167287826538
Epoch 225 / 500 | iteration 25 / 30 | Total Loss: 4.724706172943115 | KNN Loss: 3.697829008102417 | BCE Loss: 1.0268771648406982
Epoch 226 / 500 | iteration 0 / 30 | Total Loss: 4.711949348449707 | KNN Loss: 3.706040382385254 | BCE Loss: 1.0059088468551636
Epoch 226 / 500 | iteration 5 / 30 | Total Loss: 4.693180084228516 | KNN Loss: 3.6841881275177 | BCE Loss: 1.0089917182922363
Epoch 226 / 500 | iteration 10 / 30 | Total Loss: 4.749368667602539 | KNN Loss: 3.71199893951416 | 

Epoch 235 / 500 | iteration 20 / 30 | Total Loss: 4.700353145599365 | KNN Loss: 3.6845409870147705 | BCE Loss: 1.0158121585845947
Epoch 235 / 500 | iteration 25 / 30 | Total Loss: 4.663250923156738 | KNN Loss: 3.6644036769866943 | BCE Loss: 0.998847484588623
Epoch 236 / 500 | iteration 0 / 30 | Total Loss: 4.671493053436279 | KNN Loss: 3.647700548171997 | BCE Loss: 1.0237923860549927
Epoch 236 / 500 | iteration 5 / 30 | Total Loss: 4.7289652824401855 | KNN Loss: 3.6981823444366455 | BCE Loss: 1.03078293800354
Epoch 236 / 500 | iteration 10 / 30 | Total Loss: 4.646731853485107 | KNN Loss: 3.6616616249084473 | BCE Loss: 0.9850702285766602
Epoch 236 / 500 | iteration 15 / 30 | Total Loss: 4.747117042541504 | KNN Loss: 3.7046167850494385 | BCE Loss: 1.0425004959106445
Epoch 236 / 500 | iteration 20 / 30 | Total Loss: 4.716935157775879 | KNN Loss: 3.7105023860931396 | BCE Loss: 1.0064327716827393
Epoch 236 / 500 | iteration 25 / 30 | Total Loss: 4.762528419494629 | KNN Loss: 3.7039144039154

Epoch 246 / 500 | iteration 5 / 30 | Total Loss: 4.705700397491455 | KNN Loss: 3.6888933181762695 | BCE Loss: 1.0168070793151855
Epoch 246 / 500 | iteration 10 / 30 | Total Loss: 4.714965343475342 | KNN Loss: 3.6797373294830322 | BCE Loss: 1.0352281332015991
Epoch 246 / 500 | iteration 15 / 30 | Total Loss: 4.742198467254639 | KNN Loss: 3.706108808517456 | BCE Loss: 1.0360896587371826
Epoch 246 / 500 | iteration 20 / 30 | Total Loss: 4.704590320587158 | KNN Loss: 3.693385362625122 | BCE Loss: 1.0112050771713257
Epoch 246 / 500 | iteration 25 / 30 | Total Loss: 4.738190174102783 | KNN Loss: 3.6975948810577393 | BCE Loss: 1.0405954122543335
Epoch 247 / 500 | iteration 0 / 30 | Total Loss: 4.692266464233398 | KNN Loss: 3.6873929500579834 | BCE Loss: 1.004873514175415
Epoch 247 / 500 | iteration 5 / 30 | Total Loss: 4.699589729309082 | KNN Loss: 3.678147792816162 | BCE Loss: 1.02144193649292
Epoch 247 / 500 | iteration 10 / 30 | Total Loss: 4.678483486175537 | KNN Loss: 3.663785457611084 |

Epoch 256 / 500 | iteration 25 / 30 | Total Loss: 4.698558330535889 | KNN Loss: 3.668684720993042 | BCE Loss: 1.0298734903335571
Epoch 257 / 500 | iteration 0 / 30 | Total Loss: 4.717868328094482 | KNN Loss: 3.6943488121032715 | BCE Loss: 1.0235196352005005
Epoch 257 / 500 | iteration 5 / 30 | Total Loss: 4.696666717529297 | KNN Loss: 3.6820380687713623 | BCE Loss: 1.0146284103393555
Epoch 257 / 500 | iteration 10 / 30 | Total Loss: 4.7254862785339355 | KNN Loss: 3.7186121940612793 | BCE Loss: 1.0068742036819458
Epoch 257 / 500 | iteration 15 / 30 | Total Loss: 4.715353012084961 | KNN Loss: 3.6962459087371826 | BCE Loss: 1.0191069841384888
Epoch 257 / 500 | iteration 20 / 30 | Total Loss: 4.715731143951416 | KNN Loss: 3.6618754863739014 | BCE Loss: 1.0538557767868042
Epoch 257 / 500 | iteration 25 / 30 | Total Loss: 4.718195915222168 | KNN Loss: 3.6890907287597656 | BCE Loss: 1.0291049480438232
Epoch 258 / 500 | iteration 0 / 30 | Total Loss: 4.715391159057617 | KNN Loss: 3.71378135681

Epoch 267 / 500 | iteration 15 / 30 | Total Loss: 4.707636833190918 | KNN Loss: 3.7067086696624756 | BCE Loss: 1.0009281635284424
Epoch 267 / 500 | iteration 20 / 30 | Total Loss: 4.678200721740723 | KNN Loss: 3.6721320152282715 | BCE Loss: 1.0060688257217407
Epoch 267 / 500 | iteration 25 / 30 | Total Loss: 4.6675591468811035 | KNN Loss: 3.654594898223877 | BCE Loss: 1.0129642486572266
Epoch 268 / 500 | iteration 0 / 30 | Total Loss: 4.7280120849609375 | KNN Loss: 3.7046127319335938 | BCE Loss: 1.0233993530273438
Epoch 268 / 500 | iteration 5 / 30 | Total Loss: 4.694061279296875 | KNN Loss: 3.703223466873169 | BCE Loss: 0.990837574005127
Epoch 268 / 500 | iteration 10 / 30 | Total Loss: 4.730989456176758 | KNN Loss: 3.7316179275512695 | BCE Loss: 0.9993715882301331
Epoch 268 / 500 | iteration 15 / 30 | Total Loss: 4.689515113830566 | KNN Loss: 3.674952745437622 | BCE Loss: 1.0145623683929443
Epoch 268 / 500 | iteration 20 / 30 | Total Loss: 4.720033645629883 | KNN Loss: 3.677161693572

Epoch 278 / 500 | iteration 5 / 30 | Total Loss: 4.730372905731201 | KNN Loss: 3.706259250640869 | BCE Loss: 1.0241137742996216
Epoch 278 / 500 | iteration 10 / 30 | Total Loss: 4.7361836433410645 | KNN Loss: 3.7037582397460938 | BCE Loss: 1.0324252843856812
Epoch 278 / 500 | iteration 15 / 30 | Total Loss: 4.695908546447754 | KNN Loss: 3.6684470176696777 | BCE Loss: 1.0274617671966553
Epoch 278 / 500 | iteration 20 / 30 | Total Loss: 4.7384209632873535 | KNN Loss: 3.727468967437744 | BCE Loss: 1.0109518766403198
Epoch 278 / 500 | iteration 25 / 30 | Total Loss: 4.697651386260986 | KNN Loss: 3.6647002696990967 | BCE Loss: 1.0329509973526
Epoch 279 / 500 | iteration 0 / 30 | Total Loss: 4.68834924697876 | KNN Loss: 3.6803154945373535 | BCE Loss: 1.0080338716506958
Epoch 279 / 500 | iteration 5 / 30 | Total Loss: 4.721467018127441 | KNN Loss: 3.700702667236328 | BCE Loss: 1.0207645893096924
Epoch 279 / 500 | iteration 10 / 30 | Total Loss: 4.708103179931641 | KNN Loss: 3.6669704914093018

Epoch 288 / 500 | iteration 25 / 30 | Total Loss: 4.699319839477539 | KNN Loss: 3.6849000453948975 | BCE Loss: 1.0144195556640625
Epoch 289 / 500 | iteration 0 / 30 | Total Loss: 4.678578853607178 | KNN Loss: 3.657196521759033 | BCE Loss: 1.0213823318481445
Epoch 289 / 500 | iteration 5 / 30 | Total Loss: 4.734014987945557 | KNN Loss: 3.701911211013794 | BCE Loss: 1.0321038961410522
Epoch 289 / 500 | iteration 10 / 30 | Total Loss: 4.737849235534668 | KNN Loss: 3.722419023513794 | BCE Loss: 1.0154300928115845
Epoch 289 / 500 | iteration 15 / 30 | Total Loss: 4.739881992340088 | KNN Loss: 3.7007198333740234 | BCE Loss: 1.039162278175354
Epoch 289 / 500 | iteration 20 / 30 | Total Loss: 4.7268171310424805 | KNN Loss: 3.7190210819244385 | BCE Loss: 1.007796287536621
Epoch 289 / 500 | iteration 25 / 30 | Total Loss: 4.680582046508789 | KNN Loss: 3.669222354888916 | BCE Loss: 1.0113599300384521
Epoch 290 / 500 | iteration 0 / 30 | Total Loss: 4.692950248718262 | KNN Loss: 3.6754825115203857

Epoch 299 / 500 | iteration 15 / 30 | Total Loss: 4.6643805503845215 | KNN Loss: 3.660395383834839 | BCE Loss: 1.003985047340393
Epoch 299 / 500 | iteration 20 / 30 | Total Loss: 4.664844512939453 | KNN Loss: 3.6575748920440674 | BCE Loss: 1.0072696208953857
Epoch 299 / 500 | iteration 25 / 30 | Total Loss: 4.73767614364624 | KNN Loss: 3.7242987155914307 | BCE Loss: 1.0133775472640991
Epoch 300 / 500 | iteration 0 / 30 | Total Loss: 4.740182876586914 | KNN Loss: 3.6962735652923584 | BCE Loss: 1.0439090728759766
Epoch 300 / 500 | iteration 5 / 30 | Total Loss: 4.70028829574585 | KNN Loss: 3.647413492202759 | BCE Loss: 1.0528748035430908
Epoch 300 / 500 | iteration 10 / 30 | Total Loss: 4.725006103515625 | KNN Loss: 3.697453498840332 | BCE Loss: 1.0275527238845825
Epoch 300 / 500 | iteration 15 / 30 | Total Loss: 4.720986843109131 | KNN Loss: 3.67313814163208 | BCE Loss: 1.0478485822677612
Epoch 300 / 500 | iteration 20 / 30 | Total Loss: 4.742677688598633 | KNN Loss: 3.7246346473693848 

Epoch 310 / 500 | iteration 0 / 30 | Total Loss: 4.707760810852051 | KNN Loss: 3.672421932220459 | BCE Loss: 1.0353388786315918
Epoch 310 / 500 | iteration 5 / 30 | Total Loss: 4.724827289581299 | KNN Loss: 3.6814639568328857 | BCE Loss: 1.0433634519577026
Epoch 310 / 500 | iteration 10 / 30 | Total Loss: 4.732338905334473 | KNN Loss: 3.6983423233032227 | BCE Loss: 1.0339964628219604
Epoch 310 / 500 | iteration 15 / 30 | Total Loss: 4.681587219238281 | KNN Loss: 3.673604965209961 | BCE Loss: 1.0079824924468994
Epoch 310 / 500 | iteration 20 / 30 | Total Loss: 4.736474990844727 | KNN Loss: 3.693772077560425 | BCE Loss: 1.0427027940750122
Epoch 310 / 500 | iteration 25 / 30 | Total Loss: 4.702083587646484 | KNN Loss: 3.690814971923828 | BCE Loss: 1.0112687349319458
Epoch 311 / 500 | iteration 0 / 30 | Total Loss: 4.727140426635742 | KNN Loss: 3.704580545425415 | BCE Loss: 1.0225598812103271
Epoch 311 / 500 | iteration 5 / 30 | Total Loss: 4.6939311027526855 | KNN Loss: 3.6808786392211914

Epoch 320 / 500 | iteration 20 / 30 | Total Loss: 4.713349342346191 | KNN Loss: 3.694822072982788 | BCE Loss: 1.0185270309448242
Epoch 320 / 500 | iteration 25 / 30 | Total Loss: 4.733984470367432 | KNN Loss: 3.710693836212158 | BCE Loss: 1.0232906341552734
Epoch 321 / 500 | iteration 0 / 30 | Total Loss: 4.720214366912842 | KNN Loss: 3.684551477432251 | BCE Loss: 1.0356627702713013
Epoch 321 / 500 | iteration 5 / 30 | Total Loss: 4.770184516906738 | KNN Loss: 3.7196998596191406 | BCE Loss: 1.0504846572875977
Epoch 321 / 500 | iteration 10 / 30 | Total Loss: 4.754359722137451 | KNN Loss: 3.7077207565307617 | BCE Loss: 1.0466389656066895
Epoch 321 / 500 | iteration 15 / 30 | Total Loss: 4.6976494789123535 | KNN Loss: 3.7114882469177246 | BCE Loss: 0.9861613512039185
Epoch 321 / 500 | iteration 20 / 30 | Total Loss: 4.7822370529174805 | KNN Loss: 3.7580761909484863 | BCE Loss: 1.0241608619689941
Epoch 321 / 500 | iteration 25 / 30 | Total Loss: 4.731814384460449 | KNN Loss: 3.69696974754

Epoch 331 / 500 | iteration 5 / 30 | Total Loss: 4.767207145690918 | KNN Loss: 3.699498414993286 | BCE Loss: 1.0677084922790527
Epoch 331 / 500 | iteration 10 / 30 | Total Loss: 4.7347187995910645 | KNN Loss: 3.696225166320801 | BCE Loss: 1.0384935140609741
Epoch 331 / 500 | iteration 15 / 30 | Total Loss: 4.773586273193359 | KNN Loss: 3.732598304748535 | BCE Loss: 1.0409879684448242
Epoch 331 / 500 | iteration 20 / 30 | Total Loss: 4.709778308868408 | KNN Loss: 3.681497097015381 | BCE Loss: 1.0282812118530273
Epoch 331 / 500 | iteration 25 / 30 | Total Loss: 4.703167915344238 | KNN Loss: 3.6793127059936523 | BCE Loss: 1.023855447769165
Epoch 332 / 500 | iteration 0 / 30 | Total Loss: 4.744091510772705 | KNN Loss: 3.7277071475982666 | BCE Loss: 1.016384243965149
Epoch 332 / 500 | iteration 5 / 30 | Total Loss: 4.726428031921387 | KNN Loss: 3.6827516555786133 | BCE Loss: 1.0436763763427734
Epoch 332 / 500 | iteration 10 / 30 | Total Loss: 4.6956562995910645 | KNN Loss: 3.653133869171142

Epoch 341 / 500 | iteration 25 / 30 | Total Loss: 4.64122200012207 | KNN Loss: 3.65218448638916 | BCE Loss: 0.9890373349189758
Epoch   342: reducing learning rate of group 0 to 3.9896e-06.
Epoch 342 / 500 | iteration 0 / 30 | Total Loss: 4.728679656982422 | KNN Loss: 3.7024009227752686 | BCE Loss: 1.0262789726257324
Epoch 342 / 500 | iteration 5 / 30 | Total Loss: 4.704317092895508 | KNN Loss: 3.6968178749084473 | BCE Loss: 1.0074992179870605
Epoch 342 / 500 | iteration 10 / 30 | Total Loss: 4.727871894836426 | KNN Loss: 3.670689105987549 | BCE Loss: 1.057182788848877
Epoch 342 / 500 | iteration 15 / 30 | Total Loss: 4.725987434387207 | KNN Loss: 3.6991994380950928 | BCE Loss: 1.0267881155014038
Epoch 342 / 500 | iteration 20 / 30 | Total Loss: 4.7078447341918945 | KNN Loss: 3.688105821609497 | BCE Loss: 1.0197389125823975
Epoch 342 / 500 | iteration 25 / 30 | Total Loss: 4.7216315269470215 | KNN Loss: 3.702838659286499 | BCE Loss: 1.018792748451233
Epoch 343 / 500 | iteration 0 / 30 |

Epoch 352 / 500 | iteration 10 / 30 | Total Loss: 4.763634204864502 | KNN Loss: 3.70817494392395 | BCE Loss: 1.0554591417312622
Epoch 352 / 500 | iteration 15 / 30 | Total Loss: 4.733028888702393 | KNN Loss: 3.6872875690460205 | BCE Loss: 1.045741319656372
Epoch 352 / 500 | iteration 20 / 30 | Total Loss: 4.73328161239624 | KNN Loss: 3.696939706802368 | BCE Loss: 1.0363420248031616
Epoch 352 / 500 | iteration 25 / 30 | Total Loss: 4.719284534454346 | KNN Loss: 3.6873185634613037 | BCE Loss: 1.031965970993042
Epoch   353: reducing learning rate of group 0 to 2.7927e-06.
Epoch 353 / 500 | iteration 0 / 30 | Total Loss: 4.757143497467041 | KNN Loss: 3.705756187438965 | BCE Loss: 1.0513871908187866
Epoch 353 / 500 | iteration 5 / 30 | Total Loss: 4.666490077972412 | KNN Loss: 3.650517702102661 | BCE Loss: 1.015972375869751
Epoch 353 / 500 | iteration 10 / 30 | Total Loss: 4.712058067321777 | KNN Loss: 3.685976982116699 | BCE Loss: 1.0260813236236572
Epoch 353 / 500 | iteration 15 / 30 | To

Epoch 362 / 500 | iteration 25 / 30 | Total Loss: 4.698068618774414 | KNN Loss: 3.6694324016571045 | BCE Loss: 1.02863609790802
Epoch 363 / 500 | iteration 0 / 30 | Total Loss: 4.7259016036987305 | KNN Loss: 3.683556079864502 | BCE Loss: 1.0423457622528076
Epoch 363 / 500 | iteration 5 / 30 | Total Loss: 4.723474502563477 | KNN Loss: 3.67948317527771 | BCE Loss: 1.0439913272857666
Epoch 363 / 500 | iteration 10 / 30 | Total Loss: 4.710903167724609 | KNN Loss: 3.6999237537384033 | BCE Loss: 1.0109796524047852
Epoch 363 / 500 | iteration 15 / 30 | Total Loss: 4.6990966796875 | KNN Loss: 3.6848998069763184 | BCE Loss: 1.0141966342926025
Epoch 363 / 500 | iteration 20 / 30 | Total Loss: 4.719333171844482 | KNN Loss: 3.681230306625366 | BCE Loss: 1.0381029844284058
Epoch 363 / 500 | iteration 25 / 30 | Total Loss: 4.683670997619629 | KNN Loss: 3.667605400085449 | BCE Loss: 1.0160653591156006
Epoch   364: reducing learning rate of group 0 to 1.9549e-06.
Epoch 364 / 500 | iteration 0 / 30 | T

Epoch 373 / 500 | iteration 15 / 30 | Total Loss: 4.732789039611816 | KNN Loss: 3.7025368213653564 | BCE Loss: 1.03025221824646
Epoch 373 / 500 | iteration 20 / 30 | Total Loss: 4.736299991607666 | KNN Loss: 3.6944735050201416 | BCE Loss: 1.0418264865875244
Epoch 373 / 500 | iteration 25 / 30 | Total Loss: 4.729044437408447 | KNN Loss: 3.6954715251922607 | BCE Loss: 1.033573031425476
Epoch 374 / 500 | iteration 0 / 30 | Total Loss: 4.685482025146484 | KNN Loss: 3.6864864826202393 | BCE Loss: 0.9989956021308899
Epoch 374 / 500 | iteration 5 / 30 | Total Loss: 4.71626091003418 | KNN Loss: 3.683424472808838 | BCE Loss: 1.0328365564346313
Epoch 374 / 500 | iteration 10 / 30 | Total Loss: 4.678765296936035 | KNN Loss: 3.661109685897827 | BCE Loss: 1.017655611038208
Epoch 374 / 500 | iteration 15 / 30 | Total Loss: 4.729079246520996 | KNN Loss: 3.6952106952667236 | BCE Loss: 1.0338685512542725
Epoch 374 / 500 | iteration 20 / 30 | Total Loss: 4.72708797454834 | KNN Loss: 3.6894328594207764 |

Epoch 384 / 500 | iteration 5 / 30 | Total Loss: 4.66320276260376 | KNN Loss: 3.667816400527954 | BCE Loss: 0.9953862428665161
Epoch 384 / 500 | iteration 10 / 30 | Total Loss: 4.751406192779541 | KNN Loss: 3.7286503314971924 | BCE Loss: 1.0227559804916382
Epoch 384 / 500 | iteration 15 / 30 | Total Loss: 4.720571517944336 | KNN Loss: 3.703176259994507 | BCE Loss: 1.017395257949829
Epoch 384 / 500 | iteration 20 / 30 | Total Loss: 4.756339073181152 | KNN Loss: 3.732741355895996 | BCE Loss: 1.0235974788665771
Epoch 384 / 500 | iteration 25 / 30 | Total Loss: 4.729302406311035 | KNN Loss: 3.689518690109253 | BCE Loss: 1.0397834777832031
Epoch 385 / 500 | iteration 0 / 30 | Total Loss: 4.731909275054932 | KNN Loss: 3.6979174613952637 | BCE Loss: 1.0339919328689575
Epoch 385 / 500 | iteration 5 / 30 | Total Loss: 4.70383358001709 | KNN Loss: 3.6813950538635254 | BCE Loss: 1.0224382877349854
Epoch 385 / 500 | iteration 10 / 30 | Total Loss: 4.7687296867370605 | KNN Loss: 3.719104766845703 |

Epoch 394 / 500 | iteration 25 / 30 | Total Loss: 4.76121711730957 | KNN Loss: 3.7254230976104736 | BCE Loss: 1.0357942581176758
Epoch 395 / 500 | iteration 0 / 30 | Total Loss: 4.70926570892334 | KNN Loss: 3.6817100048065186 | BCE Loss: 1.0275557041168213
Epoch 395 / 500 | iteration 5 / 30 | Total Loss: 4.73371696472168 | KNN Loss: 3.706780433654785 | BCE Loss: 1.0269367694854736
Epoch 395 / 500 | iteration 10 / 30 | Total Loss: 4.725086212158203 | KNN Loss: 3.71771502494812 | BCE Loss: 1.0073710680007935
Epoch 395 / 500 | iteration 15 / 30 | Total Loss: 4.682054042816162 | KNN Loss: 3.680115222930908 | BCE Loss: 1.0019389390945435
Epoch 395 / 500 | iteration 20 / 30 | Total Loss: 4.712334632873535 | KNN Loss: 3.704024314880371 | BCE Loss: 1.0083105564117432
Epoch 395 / 500 | iteration 25 / 30 | Total Loss: 4.697964191436768 | KNN Loss: 3.688457727432251 | BCE Loss: 1.0095064640045166
Epoch 396 / 500 | iteration 0 / 30 | Total Loss: 4.704105377197266 | KNN Loss: 3.6885366439819336 | B

Epoch 405 / 500 | iteration 15 / 30 | Total Loss: 4.739022254943848 | KNN Loss: 3.689577341079712 | BCE Loss: 1.0494447946548462
Epoch 405 / 500 | iteration 20 / 30 | Total Loss: 4.695044994354248 | KNN Loss: 3.6684489250183105 | BCE Loss: 1.026595950126648
Epoch 405 / 500 | iteration 25 / 30 | Total Loss: 4.687225818634033 | KNN Loss: 3.671999931335449 | BCE Loss: 1.015225887298584
Epoch 406 / 500 | iteration 0 / 30 | Total Loss: 4.764193058013916 | KNN Loss: 3.709041118621826 | BCE Loss: 1.0551519393920898
Epoch 406 / 500 | iteration 5 / 30 | Total Loss: 4.722105979919434 | KNN Loss: 3.7030696868896484 | BCE Loss: 1.0190365314483643
Epoch 406 / 500 | iteration 10 / 30 | Total Loss: 4.715198516845703 | KNN Loss: 3.6727888584136963 | BCE Loss: 1.0424094200134277
Epoch 406 / 500 | iteration 15 / 30 | Total Loss: 4.762514114379883 | KNN Loss: 3.715512990951538 | BCE Loss: 1.0470011234283447
Epoch 406 / 500 | iteration 20 / 30 | Total Loss: 4.763265609741211 | KNN Loss: 3.7025747299194336

Epoch 416 / 500 | iteration 5 / 30 | Total Loss: 4.718160152435303 | KNN Loss: 3.6901931762695312 | BCE Loss: 1.0279669761657715
Epoch 416 / 500 | iteration 10 / 30 | Total Loss: 4.725253105163574 | KNN Loss: 3.699852705001831 | BCE Loss: 1.025400161743164
Epoch 416 / 500 | iteration 15 / 30 | Total Loss: 4.72184944152832 | KNN Loss: 3.6811981201171875 | BCE Loss: 1.0406512022018433
Epoch 416 / 500 | iteration 20 / 30 | Total Loss: 4.7003397941589355 | KNN Loss: 3.670062303543091 | BCE Loss: 1.0302776098251343
Epoch 416 / 500 | iteration 25 / 30 | Total Loss: 4.666968822479248 | KNN Loss: 3.6545517444610596 | BCE Loss: 1.0124170780181885
Epoch 417 / 500 | iteration 0 / 30 | Total Loss: 4.67454195022583 | KNN Loss: 3.664839744567871 | BCE Loss: 1.009702205657959
Epoch 417 / 500 | iteration 5 / 30 | Total Loss: 4.732693672180176 | KNN Loss: 3.7023942470550537 | BCE Loss: 1.0302995443344116
Epoch 417 / 500 | iteration 10 / 30 | Total Loss: 4.762372970581055 | KNN Loss: 3.717921018600464 |

Epoch 426 / 500 | iteration 25 / 30 | Total Loss: 4.708044528961182 | KNN Loss: 3.6804416179656982 | BCE Loss: 1.0276027917861938
Epoch 427 / 500 | iteration 0 / 30 | Total Loss: 4.7303924560546875 | KNN Loss: 3.7076668739318848 | BCE Loss: 1.0227257013320923
Epoch 427 / 500 | iteration 5 / 30 | Total Loss: 4.747120380401611 | KNN Loss: 3.725834608078003 | BCE Loss: 1.0212856531143188
Epoch 427 / 500 | iteration 10 / 30 | Total Loss: 4.731727123260498 | KNN Loss: 3.687772512435913 | BCE Loss: 1.0439544916152954
Epoch 427 / 500 | iteration 15 / 30 | Total Loss: 4.66817569732666 | KNN Loss: 3.695002794265747 | BCE Loss: 0.9731730818748474
Epoch 427 / 500 | iteration 20 / 30 | Total Loss: 4.719873905181885 | KNN Loss: 3.692668914794922 | BCE Loss: 1.027204990386963
Epoch 427 / 500 | iteration 25 / 30 | Total Loss: 4.693912982940674 | KNN Loss: 3.65382981300354 | BCE Loss: 1.0400830507278442
Epoch 428 / 500 | iteration 0 / 30 | Total Loss: 4.770133018493652 | KNN Loss: 3.726199150085449 | 

Epoch 437 / 500 | iteration 15 / 30 | Total Loss: 4.734422206878662 | KNN Loss: 3.696061134338379 | BCE Loss: 1.0383610725402832
Epoch 437 / 500 | iteration 20 / 30 | Total Loss: 4.743606090545654 | KNN Loss: 3.7316668033599854 | BCE Loss: 1.0119394063949585
Epoch 437 / 500 | iteration 25 / 30 | Total Loss: 4.706932067871094 | KNN Loss: 3.681630849838257 | BCE Loss: 1.0253009796142578
Epoch 438 / 500 | iteration 0 / 30 | Total Loss: 4.6917924880981445 | KNN Loss: 3.6965365409851074 | BCE Loss: 0.9952561855316162
Epoch 438 / 500 | iteration 5 / 30 | Total Loss: 4.6929192543029785 | KNN Loss: 3.7015035152435303 | BCE Loss: 0.9914158582687378
Epoch 438 / 500 | iteration 10 / 30 | Total Loss: 4.736515522003174 | KNN Loss: 3.6732630729675293 | BCE Loss: 1.063252568244934
Epoch 438 / 500 | iteration 15 / 30 | Total Loss: 4.750938415527344 | KNN Loss: 3.715184450149536 | BCE Loss: 1.0357542037963867
Epoch 438 / 500 | iteration 20 / 30 | Total Loss: 4.672362804412842 | KNN Loss: 3.681021690368

Epoch 448 / 500 | iteration 0 / 30 | Total Loss: 4.732075214385986 | KNN Loss: 3.7153801918029785 | BCE Loss: 1.0166950225830078
Epoch 448 / 500 | iteration 5 / 30 | Total Loss: 4.716768264770508 | KNN Loss: 3.6851818561553955 | BCE Loss: 1.0315861701965332
Epoch 448 / 500 | iteration 10 / 30 | Total Loss: 4.7052998542785645 | KNN Loss: 3.690096616744995 | BCE Loss: 1.0152032375335693
Epoch 448 / 500 | iteration 15 / 30 | Total Loss: 4.742417335510254 | KNN Loss: 3.6937782764434814 | BCE Loss: 1.0486392974853516
Epoch 448 / 500 | iteration 20 / 30 | Total Loss: 4.696242809295654 | KNN Loss: 3.697507619857788 | BCE Loss: 0.9987351298332214
Epoch 448 / 500 | iteration 25 / 30 | Total Loss: 4.72514009475708 | KNN Loss: 3.6809780597686768 | BCE Loss: 1.0441621541976929
Epoch 449 / 500 | iteration 0 / 30 | Total Loss: 4.748376846313477 | KNN Loss: 3.7178196907043457 | BCE Loss: 1.0305571556091309
Epoch 449 / 500 | iteration 5 / 30 | Total Loss: 4.729743480682373 | KNN Loss: 3.70274162292480

Epoch 458 / 500 | iteration 20 / 30 | Total Loss: 4.7442216873168945 | KNN Loss: 3.72982120513916 | BCE Loss: 1.014400601387024
Epoch 458 / 500 | iteration 25 / 30 | Total Loss: 4.720094680786133 | KNN Loss: 3.694822311401367 | BCE Loss: 1.025272250175476
Epoch 459 / 500 | iteration 0 / 30 | Total Loss: 4.7125420570373535 | KNN Loss: 3.689485788345337 | BCE Loss: 1.0230562686920166
Epoch 459 / 500 | iteration 5 / 30 | Total Loss: 4.674315929412842 | KNN Loss: 3.6623287200927734 | BCE Loss: 1.011987328529358
Epoch 459 / 500 | iteration 10 / 30 | Total Loss: 4.732183456420898 | KNN Loss: 3.7102231979370117 | BCE Loss: 1.0219604969024658
Epoch 459 / 500 | iteration 15 / 30 | Total Loss: 4.6692352294921875 | KNN Loss: 3.6817195415496826 | BCE Loss: 0.987515926361084
Epoch 459 / 500 | iteration 20 / 30 | Total Loss: 4.71513557434082 | KNN Loss: 3.6937007904052734 | BCE Loss: 1.021435022354126
Epoch 459 / 500 | iteration 25 / 30 | Total Loss: 4.742697238922119 | KNN Loss: 3.6982126235961914 

Epoch 469 / 500 | iteration 10 / 30 | Total Loss: 4.695771217346191 | KNN Loss: 3.6765801906585693 | BCE Loss: 1.019190788269043
Epoch 469 / 500 | iteration 15 / 30 | Total Loss: 4.709070205688477 | KNN Loss: 3.6821446418762207 | BCE Loss: 1.0269253253936768
Epoch 469 / 500 | iteration 20 / 30 | Total Loss: 4.658243656158447 | KNN Loss: 3.664552927017212 | BCE Loss: 0.9936908483505249
Epoch 469 / 500 | iteration 25 / 30 | Total Loss: 4.726993560791016 | KNN Loss: 3.6653761863708496 | BCE Loss: 1.0616176128387451
Epoch 470 / 500 | iteration 0 / 30 | Total Loss: 4.704527378082275 | KNN Loss: 3.685822010040283 | BCE Loss: 1.0187053680419922
Epoch 470 / 500 | iteration 5 / 30 | Total Loss: 4.696495056152344 | KNN Loss: 3.66705584526062 | BCE Loss: 1.0294393301010132
Epoch 470 / 500 | iteration 10 / 30 | Total Loss: 4.683419704437256 | KNN Loss: 3.684457540512085 | BCE Loss: 0.9989620447158813
Epoch 470 / 500 | iteration 15 / 30 | Total Loss: 4.679872035980225 | KNN Loss: 3.6961405277252197

Epoch 480 / 500 | iteration 0 / 30 | Total Loss: 4.680174350738525 | KNN Loss: 3.695300579071045 | BCE Loss: 0.9848736524581909
Epoch 480 / 500 | iteration 5 / 30 | Total Loss: 4.6911163330078125 | KNN Loss: 3.6690118312835693 | BCE Loss: 1.0221046209335327
Epoch 480 / 500 | iteration 10 / 30 | Total Loss: 4.690283298492432 | KNN Loss: 3.699307680130005 | BCE Loss: 0.9909757971763611
Epoch 480 / 500 | iteration 15 / 30 | Total Loss: 4.677189826965332 | KNN Loss: 3.6597557067871094 | BCE Loss: 1.0174338817596436
Epoch 480 / 500 | iteration 20 / 30 | Total Loss: 4.69784688949585 | KNN Loss: 3.669581890106201 | BCE Loss: 1.0282648801803589
Epoch 480 / 500 | iteration 25 / 30 | Total Loss: 4.691289901733398 | KNN Loss: 3.673945188522339 | BCE Loss: 1.0173447132110596
Epoch 481 / 500 | iteration 0 / 30 | Total Loss: 4.698472023010254 | KNN Loss: 3.668053150177002 | BCE Loss: 1.030419111251831
Epoch 481 / 500 | iteration 5 / 30 | Total Loss: 4.7260236740112305 | KNN Loss: 3.7319328784942627 

Epoch 490 / 500 | iteration 20 / 30 | Total Loss: 4.658114910125732 | KNN Loss: 3.671670913696289 | BCE Loss: 0.9864441752433777
Epoch 490 / 500 | iteration 25 / 30 | Total Loss: 4.715171813964844 | KNN Loss: 3.6964666843414307 | BCE Loss: 1.018704891204834
Epoch 491 / 500 | iteration 0 / 30 | Total Loss: 4.735581398010254 | KNN Loss: 3.679997682571411 | BCE Loss: 1.0555838346481323
Epoch 491 / 500 | iteration 5 / 30 | Total Loss: 4.728246212005615 | KNN Loss: 3.735576629638672 | BCE Loss: 0.9926697015762329
Epoch 491 / 500 | iteration 10 / 30 | Total Loss: 4.7432451248168945 | KNN Loss: 3.7044808864593506 | BCE Loss: 1.038764238357544
Epoch 491 / 500 | iteration 15 / 30 | Total Loss: 4.702518463134766 | KNN Loss: 3.6970551013946533 | BCE Loss: 1.0054631233215332
Epoch 491 / 500 | iteration 20 / 30 | Total Loss: 4.74367094039917 | KNN Loss: 3.6656463146209717 | BCE Loss: 1.0780247449874878
Epoch 491 / 500 | iteration 25 / 30 | Total Loss: 4.733044624328613 | KNN Loss: 3.684960603713989

In [7]:
outputs, iterm = model(dataset[67][0].unsqueeze(0), return_intermidiate=True)
print(outputs)
print(outputs.min())
print(outputs.max())

tensor([[ 1.7604,  4.4755,  3.0771,  3.9097,  4.0269,  0.5695,  1.7009,  2.7212,
          1.9785,  1.8334,  2.6973,  2.0479,  0.5888,  2.0219,  1.4895,  0.5669,
          1.7798,  2.0688,  3.1418,  1.2685,  2.0295,  1.9587,  1.8234,  2.0338,
          1.7594,  1.3992,  2.6015,  0.9626,  1.8034,  0.1201, -0.0493,  0.8887,
          0.1603,  0.8327,  1.9717,  1.0101,  0.7896,  1.8603,  0.6187,  1.4763,
          0.8109, -0.5784, -0.3982,  2.8505,  1.4647,  0.7623, -0.0900,  0.2955,
          1.6988,  2.8949,  2.0963,  0.3493,  1.5421,  0.9165, -0.3905,  0.9805,
          1.5791,  1.6771,  1.2341,  1.6186,  0.4137,  0.5113, -0.0765,  1.7783,
          0.6330,  1.9503, -1.7280,  0.0204,  1.8930,  2.5013,  2.9008, -0.0242,
          1.5629,  2.4864,  2.0565,  1.5185,  0.1033,  0.7187, -0.1259,  1.9349,
          0.0165, -0.0965,  2.1189, -0.2565,  0.5036, -1.0598, -2.2981, -0.0859,
          0.8270, -1.8429,  0.0118, -0.0174, -0.5732, -0.6276,  0.3815,  1.0661,
         -0.6287, -0.8592,  

In [8]:
plt.figure()
plt.plot(losses)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [9]:
dataset.transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)
dataset.target_transform = torchvision.transforms.Compose([
    BinaryEncodingTransform(mapping=dataset.items_to_idx),
]
)

In [10]:
dataset_ = [d[0] for d in dataset]

In [11]:
projections = model.calculate_intermidiate(dataset_)

100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [00:03<00:00,  4.22it/s]


In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Fit DBSCAN and calculate indices


In [13]:
clusters = DBSCAN(eps=0.2, min_samples=80).fit_predict(projections)
# scores = []
# best_score = float('inf')
# clusters = None
# range_ = list(range(5, 20))
# for k in tqdm(range_):
#     y = GaussianMixture(n_components=k).fit_predict(projections)
#     cur_score = davies_bouldin_score(projections, y)
#     scores.append(cur_score)
    
#     if cur_score < best_score:
#         best_score = cur_score
#         clusters = y

In [14]:
perplexity = 100

p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [15]:
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.tree import _tree

In [16]:
tensor_dataset = torch.stack(dataset_)

In [63]:
clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=5)
clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
print(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
print(clf.get_depth())

0.22159491884262528
132


In [62]:
scores = []
for min_samples in range(1,50, 1):
    clf = DecisionTreeClassifier(max_depth=200, min_samples_leaf=min_samples)
    clf = clf.fit(tensor_dataset[clusters!=-1], clusters[clusters != -1])
    scores.append(clf.score(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
    
plt.figure()
plt.plot(list(range(1,50, 1)), scores)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [64]:
def get_rules(tree, feature_names, class_names):
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]

    paths = []
    path = []
    
    def recurse(node, path, paths):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            p1, p2 = list(path), list(path)
#             p1 += [f"({name} <= {np.round(threshold, 3)})"]
            p1 += [(name, '<=', np.round(threshold, 3))]
            recurse(tree_.children_left[node], p1, paths)
            p2 += [(name, '>', np.round(threshold, 3))]
            recurse(tree_.children_right[node], p2, paths)
        else:
            path += [(tree_.value[node], tree_.n_node_samples[node])]
            paths += [path]
            
    recurse(0, path, paths)

    # sort by samples count
    samples_count = [p[-1][1] for p in paths]
    ii = list(np.argsort(samples_count))
    paths = [paths[i] for i in reversed(ii)]
    
    rules = []
    for path in paths:
        rule = []
        
        for p in path[:-1]:
            rule += [p]
        target = " then "
        if class_names is None:
            target += "response: "+str(np.round(path[-1][0][0][0],3))
        else:
            classes = path[-1][0][0]
            l = np.argmax(classes)
            target += f"class: {class_names[l]} (proba: {np.round(100.0*classes[l]/np.sum(classes),2)}%)"
           
        proba = np.round(100.0*classes[l]/np.sum(classes),2)
        target += f" | based on {path[-1][1]:,} samples"
        rule_wrapper = {'target': target, 'rule': rule, 'proba': proba}
        rules += [rule_wrapper]
        
    return rules

In [65]:
rules = get_rules(clf, dataset.items, clusters[clusters != -1])

for rule in rules:
    n_pos = 0
    for c,p,v in rule['rule']:
        if p == '>':
            n_pos += 1
    rule['pos'] = n_pos

In [66]:
plt.figure()
probs = [r['proba'] for r in rules]
plt.hist(probs, bins = 100)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [67]:
rules = sorted(rules, key=lambda x:x['pos'])
rules = [r for r in rules if r['proba'] > 50]
print(len(rules))

17


In [69]:
for i in range(17):
    r_i = rules[i]
    print(f"------------- rule {i} length {len(r_i)} -------------")
    print(r_i)

------------- rule 0 length 4 -------------
{'target': ' then class: 3 (proba: 55.56%) | based on 9 samples', 'rule': [('oil', '<=', 0.5), ('ketchup', '<=', 0.5), ('sliced cheese', '<=', 0.5), ('dish cleaner', '<=', 0.5), ('sparkling wine', '<=', 0.5), ('dental care', '<=', 0.5), ('specialty chocolate', '<=', 0.5), ('bottled beer', '<=', 0.5), ('pastry', '<=', 0.5), ('pasta', '<=', 0.5), ('liqueur', '<=', 0.5), ('frozen fruits', '<=', 0.5), ('yogurt', '<=', 0.5), ('mayonnaise', '<=', 0.5), ('processed cheese', '<=', 0.5), ('fish', '<=', 0.5), ('berries', '<=', 0.5), ('soap', '<=', 0.5), ('hamburger meat', '<=', 0.5), ('onions', '<=', 0.5), ('rum', '<=', 0.5), ('frozen dessert', '<=', 0.5), ('mustard', '<=', 0.5), ('baking powder', '<=', 0.5), ('rolls/buns', '<=', 0.5), ('white bread', '<=', 0.5), ('condensed milk', '<=', 0.5), ('flower soil/fertilizer', '<=', 0.5), ('decalcifier', '<=', 0.5), ('coffee', '<=', 0.5), ('misc. beverages', '<=', 0.5), ('artif. sweetener', '<=', 0.5), ('suga

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [70]:
tree_dataset = list(zip(tensor_dataset[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [71]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [72]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [73]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 100
output_dim = len(set(clusters))
log_interval = 1
tree_depth = 10
device = 'cpu'
use_cuda = device != 'cpu'

In [74]:
tree = SDT(input_dim=tensor_dataset.shape[1], output_dim=len(clusters - 1), depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [75]:
losses = []
accs = []
sparsity = []

In [76]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 025 | Total loss: 9.621 | Reg loss: 0.012 | Tree loss: 9.621 | Accuracy: 0.000000 | 6.428 sec/iter
Epoch: 00 | Batch: 001 / 025 | Total loss: 9.618 | Reg loss: 0.011 | Tree loss: 9.618 | Accuracy: 0.000000 | 6.431 sec/iter
Epoch: 00 | Batch: 002 / 025 | Total loss: 9.613 | Reg loss: 0.010 | Tree loss: 9.613 | Accuracy: 0.000000 | 6.465 sec/iter
Epoch: 00 | Batch: 003 / 025 | Total loss: 9.611 | Reg loss: 0.009 | Tree loss: 9.611 | Accuracy: 0.000000 | 6.474 sec/iter
Epoch: 00 | Batch: 004 / 025 | Total loss: 9.610 | Reg loss: 0.009 | Tree loss: 9.610 | Accuracy: 0.000000 | 6.494 sec/iter
Epoch: 00 | Batch: 005 / 025 | Total loss: 9.603 | Reg loss: 0.008 | Tree loss: 9.603 | Accuracy: 0.000000 | 6.525 sec/iter
Epoch: 00 | Batch: 006 / 025 | Total loss: 9.601 | Reg loss: 0.008 | Tree loss: 9.601 | Accuracy: 0.000000 | 6.486 

Epoch: 02 | Batch: 011 / 025 | Total loss: 9.527 | Reg loss: 0.007 | Tree loss: 9.527 | Accuracy: 0.085938 | 7.031 sec/iter
Epoch: 02 | Batch: 012 / 025 | Total loss: 9.527 | Reg loss: 0.007 | Tree loss: 9.527 | Accuracy: 0.082031 | 7.071 sec/iter
Epoch: 02 | Batch: 013 / 025 | Total loss: 9.523 | Reg loss: 0.007 | Tree loss: 9.523 | Accuracy: 0.072266 | 7.104 sec/iter
Epoch: 02 | Batch: 014 / 025 | Total loss: 9.521 | Reg loss: 0.008 | Tree loss: 9.521 | Accuracy: 0.076172 | 7.106 sec/iter
Epoch: 02 | Batch: 015 / 025 | Total loss: 9.520 | Reg loss: 0.008 | Tree loss: 9.520 | Accuracy: 0.064453 | 7.106 sec/iter
Epoch: 02 | Batch: 016 / 025 | Total loss: 9.518 | Reg loss: 0.008 | Tree loss: 9.518 | Accuracy: 0.085938 | 7.101 sec/iter
Epoch: 02 | Batch: 017 / 025 | Total loss: 9.517 | Reg loss: 0.009 | Tree loss: 9.517 | Accuracy: 0.042969 | 7.093 sec/iter
Epoch: 02 | Batch: 018 / 025 | Total loss: 9.510 | Reg loss: 0.009 | Tree loss: 9.510 | Accuracy: 0.060547 | 7.089 sec/iter
Epoch: 0

Epoch: 04 | Batch: 023 / 025 | Total loss: 9.343 | Reg loss: 0.017 | Tree loss: 9.343 | Accuracy: 0.070312 | 6.977 sec/iter
Epoch: 04 | Batch: 024 / 025 | Total loss: 9.333 | Reg loss: 0.018 | Tree loss: 9.333 | Accuracy: 0.068817 | 6.964 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 05 | Batch: 000 / 025 | Total loss: 9.423 | Reg loss: 0.012 | Tree loss: 9.423 | Accuracy: 0.078125 | 6.963 sec/iter
Epoch: 05 | Batch: 001 / 025 | Total loss: 9.420 | Reg loss: 0.012 | Tree loss: 9.420 | Accuracy: 0.089844 | 6.959 sec/iter
Epoch: 05 | Batch: 002 / 025 | Total loss: 9.409 | Reg loss: 0.012 | Tree loss: 9.409 | Accuracy: 0.076172 | 6.956 sec/iter
Epoch: 05 | Batch: 003 / 025 | Total loss: 9.396 | Reg loss: 0.012 | Tree loss: 9.396 | Ac

Epoch: 07 | Batch: 008 / 025 | Total loss: 9.049 | Reg loss: 0.019 | Tree loss: 9.049 | Accuracy: 0.058594 | 6.861 sec/iter
Epoch: 07 | Batch: 009 / 025 | Total loss: 9.045 | Reg loss: 0.019 | Tree loss: 9.045 | Accuracy: 0.091797 | 6.858 sec/iter
Epoch: 07 | Batch: 010 / 025 | Total loss: 9.014 | Reg loss: 0.019 | Tree loss: 9.014 | Accuracy: 0.099609 | 6.855 sec/iter
Epoch: 07 | Batch: 011 / 025 | Total loss: 8.995 | Reg loss: 0.020 | Tree loss: 8.995 | Accuracy: 0.091797 | 6.853 sec/iter
Epoch: 07 | Batch: 012 / 025 | Total loss: 8.980 | Reg loss: 0.020 | Tree loss: 8.980 | Accuracy: 0.076172 | 6.849 sec/iter
Epoch: 07 | Batch: 013 / 025 | Total loss: 8.960 | Reg loss: 0.020 | Tree loss: 8.960 | Accuracy: 0.085938 | 6.846 sec/iter
Epoch: 07 | Batch: 014 / 025 | Total loss: 8.937 | Reg loss: 0.021 | Tree loss: 8.937 | Accuracy: 0.076172 | 6.843 sec/iter
Epoch: 07 | Batch: 015 / 025 | Total loss: 8.921 | Reg loss: 0.021 | Tree loss: 8.921 | Accuracy: 0.078125 | 6.84 sec/iter
Epoch: 07

Epoch: 09 | Batch: 020 / 025 | Total loss: 8.342 | Reg loss: 0.027 | Tree loss: 8.342 | Accuracy: 0.095703 | 6.851 sec/iter
Epoch: 09 | Batch: 021 / 025 | Total loss: 8.291 | Reg loss: 0.027 | Tree loss: 8.291 | Accuracy: 0.105469 | 6.85 sec/iter
Epoch: 09 | Batch: 022 / 025 | Total loss: 8.311 | Reg loss: 0.027 | Tree loss: 8.311 | Accuracy: 0.070312 | 6.853 sec/iter
Epoch: 09 | Batch: 023 / 025 | Total loss: 8.267 | Reg loss: 0.028 | Tree loss: 8.267 | Accuracy: 0.074219 | 6.851 sec/iter
Epoch: 09 | Batch: 024 / 025 | Total loss: 8.271 | Reg loss: 0.028 | Tree loss: 8.271 | Accuracy: 0.094624 | 6.847 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 10 | Batch: 000 / 025 | Total loss: 8.588 | Reg loss: 0.023 | Tree loss: 8.588 | Acc

Epoch: 12 | Batch: 005 / 025 | Total loss: 8.002 | Reg loss: 0.026 | Tree loss: 8.002 | Accuracy: 0.080078 | 6.788 sec/iter
Epoch: 12 | Batch: 006 / 025 | Total loss: 7.970 | Reg loss: 0.026 | Tree loss: 7.970 | Accuracy: 0.080078 | 6.791 sec/iter
Epoch: 12 | Batch: 007 / 025 | Total loss: 7.996 | Reg loss: 0.027 | Tree loss: 7.996 | Accuracy: 0.074219 | 6.797 sec/iter
Epoch: 12 | Batch: 008 / 025 | Total loss: 7.919 | Reg loss: 0.027 | Tree loss: 7.919 | Accuracy: 0.097656 | 6.797 sec/iter
Epoch: 12 | Batch: 009 / 025 | Total loss: 7.902 | Reg loss: 0.027 | Tree loss: 7.902 | Accuracy: 0.103516 | 6.797 sec/iter
Epoch: 12 | Batch: 010 / 025 | Total loss: 7.893 | Reg loss: 0.027 | Tree loss: 7.893 | Accuracy: 0.062500 | 6.796 sec/iter
Epoch: 12 | Batch: 011 / 025 | Total loss: 7.833 | Reg loss: 0.027 | Tree loss: 7.833 | Accuracy: 0.107422 | 6.795 sec/iter
Epoch: 12 | Batch: 012 / 025 | Total loss: 7.842 | Reg loss: 0.028 | Tree loss: 7.842 | Accuracy: 0.082031 | 6.794 sec/iter
Epoch: 1

Epoch: 14 | Batch: 017 / 025 | Total loss: 7.260 | Reg loss: 0.030 | Tree loss: 7.260 | Accuracy: 0.083984 | 6.743 sec/iter
Epoch: 14 | Batch: 018 / 025 | Total loss: 7.265 | Reg loss: 0.030 | Tree loss: 7.265 | Accuracy: 0.074219 | 6.743 sec/iter
Epoch: 14 | Batch: 019 / 025 | Total loss: 7.199 | Reg loss: 0.031 | Tree loss: 7.199 | Accuracy: 0.093750 | 6.743 sec/iter
Epoch: 14 | Batch: 020 / 025 | Total loss: 7.161 | Reg loss: 0.031 | Tree loss: 7.161 | Accuracy: 0.099609 | 6.743 sec/iter
Epoch: 14 | Batch: 021 / 025 | Total loss: 7.172 | Reg loss: 0.031 | Tree loss: 7.172 | Accuracy: 0.091797 | 6.744 sec/iter
Epoch: 14 | Batch: 022 / 025 | Total loss: 7.173 | Reg loss: 0.031 | Tree loss: 7.173 | Accuracy: 0.082031 | 6.743 sec/iter
Epoch: 14 | Batch: 023 / 025 | Total loss: 7.148 | Reg loss: 0.031 | Tree loss: 7.148 | Accuracy: 0.085938 | 6.743 sec/iter
Epoch: 14 | Batch: 024 / 025 | Total loss: 7.126 | Reg loss: 0.032 | Tree loss: 7.126 | Accuracy: 0.079570 | 6.738 sec/iter
Average 

Epoch: 17 | Batch: 002 / 025 | Total loss: 6.983 | Reg loss: 0.031 | Tree loss: 6.983 | Accuracy: 0.080078 | 6.761 sec/iter
Epoch: 17 | Batch: 003 / 025 | Total loss: 6.950 | Reg loss: 0.031 | Tree loss: 6.950 | Accuracy: 0.082031 | 6.765 sec/iter
Epoch: 17 | Batch: 004 / 025 | Total loss: 6.914 | Reg loss: 0.031 | Tree loss: 6.914 | Accuracy: 0.097656 | 6.769 sec/iter
Epoch: 17 | Batch: 005 / 025 | Total loss: 6.898 | Reg loss: 0.031 | Tree loss: 6.898 | Accuracy: 0.103516 | 6.775 sec/iter
Epoch: 17 | Batch: 006 / 025 | Total loss: 6.913 | Reg loss: 0.031 | Tree loss: 6.913 | Accuracy: 0.089844 | 6.777 sec/iter
Epoch: 17 | Batch: 007 / 025 | Total loss: 6.888 | Reg loss: 0.031 | Tree loss: 6.888 | Accuracy: 0.085938 | 6.78 sec/iter
Epoch: 17 | Batch: 008 / 025 | Total loss: 6.861 | Reg loss: 0.031 | Tree loss: 6.861 | Accuracy: 0.072266 | 6.78 sec/iter
Epoch: 17 | Batch: 009 / 025 | Total loss: 6.861 | Reg loss: 0.031 | Tree loss: 6.861 | Accuracy: 0.080078 | 6.779 sec/iter
Epoch: 17 

Epoch: 19 | Batch: 014 / 025 | Total loss: 6.345 | Reg loss: 0.033 | Tree loss: 6.345 | Accuracy: 0.068359 | 6.765 sec/iter
Epoch: 19 | Batch: 015 / 025 | Total loss: 6.321 | Reg loss: 0.033 | Tree loss: 6.321 | Accuracy: 0.093750 | 6.764 sec/iter
Epoch: 19 | Batch: 016 / 025 | Total loss: 6.321 | Reg loss: 0.033 | Tree loss: 6.321 | Accuracy: 0.078125 | 6.764 sec/iter
Epoch: 19 | Batch: 017 / 025 | Total loss: 6.333 | Reg loss: 0.033 | Tree loss: 6.333 | Accuracy: 0.082031 | 6.763 sec/iter
Epoch: 19 | Batch: 018 / 025 | Total loss: 6.290 | Reg loss: 0.033 | Tree loss: 6.290 | Accuracy: 0.089844 | 6.763 sec/iter
Epoch: 19 | Batch: 019 / 025 | Total loss: 6.274 | Reg loss: 0.034 | Tree loss: 6.274 | Accuracy: 0.074219 | 6.762 sec/iter
Epoch: 19 | Batch: 020 / 025 | Total loss: 6.217 | Reg loss: 0.034 | Tree loss: 6.217 | Accuracy: 0.103516 | 6.762 sec/iter
Epoch: 19 | Batch: 021 / 025 | Total loss: 6.226 | Reg loss: 0.034 | Tree loss: 6.226 | Accuracy: 0.085938 | 6.761 sec/iter
Epoch: 1

Epoch: 22 | Batch: 000 / 025 | Total loss: 6.123 | Reg loss: 0.034 | Tree loss: 6.123 | Accuracy: 0.078125 | 6.745 sec/iter
Epoch: 22 | Batch: 001 / 025 | Total loss: 6.047 | Reg loss: 0.034 | Tree loss: 6.047 | Accuracy: 0.056641 | 6.745 sec/iter
Epoch: 22 | Batch: 002 / 025 | Total loss: 6.058 | Reg loss: 0.034 | Tree loss: 6.058 | Accuracy: 0.087891 | 6.745 sec/iter
Epoch: 22 | Batch: 003 / 025 | Total loss: 6.060 | Reg loss: 0.034 | Tree loss: 6.060 | Accuracy: 0.111328 | 6.746 sec/iter
Epoch: 22 | Batch: 004 / 025 | Total loss: 5.983 | Reg loss: 0.034 | Tree loss: 5.983 | Accuracy: 0.109375 | 6.746 sec/iter
Epoch: 22 | Batch: 005 / 025 | Total loss: 5.996 | Reg loss: 0.034 | Tree loss: 5.996 | Accuracy: 0.089844 | 6.747 sec/iter
Epoch: 22 | Batch: 006 / 025 | Total loss: 5.995 | Reg loss: 0.034 | Tree loss: 5.995 | Accuracy: 0.078125 | 6.748 sec/iter
Epoch: 22 | Batch: 007 / 025 | Total loss: 5.935 | Reg loss: 0.034 | Tree loss: 5.935 | Accuracy: 0.107422 | 6.747 sec/iter
Epoch: 2

Epoch: 24 | Batch: 012 / 025 | Total loss: 5.571 | Reg loss: 0.035 | Tree loss: 5.571 | Accuracy: 0.093750 | 6.779 sec/iter
Epoch: 24 | Batch: 013 / 025 | Total loss: 5.564 | Reg loss: 0.035 | Tree loss: 5.564 | Accuracy: 0.095703 | 6.779 sec/iter
Epoch: 24 | Batch: 014 / 025 | Total loss: 5.538 | Reg loss: 0.035 | Tree loss: 5.538 | Accuracy: 0.074219 | 6.778 sec/iter
Epoch: 24 | Batch: 015 / 025 | Total loss: 5.485 | Reg loss: 0.035 | Tree loss: 5.485 | Accuracy: 0.099609 | 6.778 sec/iter
Epoch: 24 | Batch: 016 / 025 | Total loss: 5.533 | Reg loss: 0.035 | Tree loss: 5.533 | Accuracy: 0.078125 | 6.778 sec/iter
Epoch: 24 | Batch: 017 / 025 | Total loss: 5.468 | Reg loss: 0.035 | Tree loss: 5.468 | Accuracy: 0.087891 | 6.778 sec/iter
Epoch: 24 | Batch: 018 / 025 | Total loss: 5.467 | Reg loss: 0.035 | Tree loss: 5.467 | Accuracy: 0.076172 | 6.778 sec/iter
Epoch: 24 | Batch: 019 / 025 | Total loss: 5.436 | Reg loss: 0.035 | Tree loss: 5.436 | Accuracy: 0.089844 | 6.777 sec/iter
Epoch: 2

Epoch: 26 | Batch: 024 / 025 | Total loss: 5.026 | Reg loss: 0.036 | Tree loss: 5.026 | Accuracy: 0.101075 | 6.773 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 27 | Batch: 000 / 025 | Total loss: 5.367 | Reg loss: 0.035 | Tree loss: 5.367 | Accuracy: 0.082031 | 6.78 sec/iter
Epoch: 27 | Batch: 001 / 025 | Total loss: 5.332 | Reg loss: 0.035 | Tree loss: 5.332 | Accuracy: 0.078125 | 6.782 sec/iter
Epoch: 27 | Batch: 002 / 025 | Total loss: 5.308 | Reg loss: 0.035 | Tree loss: 5.308 | Accuracy: 0.082031 | 6.782 sec/iter
Epoch: 27 | Batch: 003 / 025 | Total loss: 5.297 | Reg loss: 0.035 | Tree loss: 5.297 | Accuracy: 0.099609 | 6.785 sec/iter
Epoch: 27 | Batch: 004 / 025 | Total loss: 5.281 | Reg loss: 0.035 | Tree loss: 5.281 | Acc

Epoch: 29 | Batch: 009 / 025 | Total loss: 4.955 | Reg loss: 0.036 | Tree loss: 4.955 | Accuracy: 0.119141 | 6.801 sec/iter
Epoch: 29 | Batch: 010 / 025 | Total loss: 4.915 | Reg loss: 0.036 | Tree loss: 4.915 | Accuracy: 0.080078 | 6.802 sec/iter
Epoch: 29 | Batch: 011 / 025 | Total loss: 4.915 | Reg loss: 0.036 | Tree loss: 4.915 | Accuracy: 0.091797 | 6.803 sec/iter
Epoch: 29 | Batch: 012 / 025 | Total loss: 4.939 | Reg loss: 0.036 | Tree loss: 4.939 | Accuracy: 0.076172 | 6.803 sec/iter
Epoch: 29 | Batch: 013 / 025 | Total loss: 4.925 | Reg loss: 0.036 | Tree loss: 4.925 | Accuracy: 0.083984 | 6.803 sec/iter
Epoch: 29 | Batch: 014 / 025 | Total loss: 4.856 | Reg loss: 0.036 | Tree loss: 4.856 | Accuracy: 0.085938 | 6.804 sec/iter
Epoch: 29 | Batch: 015 / 025 | Total loss: 4.812 | Reg loss: 0.036 | Tree loss: 4.812 | Accuracy: 0.105469 | 6.806 sec/iter
Epoch: 29 | Batch: 016 / 025 | Total loss: 4.903 | Reg loss: 0.036 | Tree loss: 4.903 | Accuracy: 0.066406 | 6.806 sec/iter
Epoch: 2

Epoch: 31 | Batch: 021 / 025 | Total loss: 4.553 | Reg loss: 0.037 | Tree loss: 4.553 | Accuracy: 0.109375 | 6.779 sec/iter
Epoch: 31 | Batch: 022 / 025 | Total loss: 4.582 | Reg loss: 0.037 | Tree loss: 4.582 | Accuracy: 0.085938 | 6.779 sec/iter
Epoch: 31 | Batch: 023 / 025 | Total loss: 4.516 | Reg loss: 0.037 | Tree loss: 4.516 | Accuracy: 0.082031 | 6.778 sec/iter
Epoch: 31 | Batch: 024 / 025 | Total loss: 4.553 | Reg loss: 0.037 | Tree loss: 4.553 | Accuracy: 0.079570 | 6.775 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 32 | Batch: 000 / 025 | Total loss: 4.770 | Reg loss: 0.036 | Tree loss: 4.770 | Accuracy: 0.093750 | 6.774 sec/iter
Epoch: 32 | Batch: 001 / 025 | Total loss: 4.754 | Reg loss: 0.036 | Tree loss: 4.754 | Ac

Epoch: 34 | Batch: 006 / 025 | Total loss: 4.560 | Reg loss: 0.036 | Tree loss: 4.560 | Accuracy: 0.083984 | 6.756 sec/iter
Epoch: 34 | Batch: 007 / 025 | Total loss: 4.483 | Reg loss: 0.036 | Tree loss: 4.483 | Accuracy: 0.099609 | 6.756 sec/iter
Epoch: 34 | Batch: 008 / 025 | Total loss: 4.477 | Reg loss: 0.036 | Tree loss: 4.477 | Accuracy: 0.078125 | 6.756 sec/iter
Epoch: 34 | Batch: 009 / 025 | Total loss: 4.432 | Reg loss: 0.036 | Tree loss: 4.432 | Accuracy: 0.091797 | 6.755 sec/iter
Epoch: 34 | Batch: 010 / 025 | Total loss: 4.443 | Reg loss: 0.036 | Tree loss: 4.443 | Accuracy: 0.093750 | 6.755 sec/iter
Epoch: 34 | Batch: 011 / 025 | Total loss: 4.394 | Reg loss: 0.036 | Tree loss: 4.394 | Accuracy: 0.095703 | 6.755 sec/iter
Epoch: 34 | Batch: 012 / 025 | Total loss: 4.410 | Reg loss: 0.036 | Tree loss: 4.410 | Accuracy: 0.083984 | 6.755 sec/iter
Epoch: 34 | Batch: 013 / 025 | Total loss: 4.382 | Reg loss: 0.036 | Tree loss: 4.382 | Accuracy: 0.103516 | 6.755 sec/iter
Epoch: 3

Epoch: 36 | Batch: 018 / 025 | Total loss: 4.179 | Reg loss: 0.037 | Tree loss: 4.179 | Accuracy: 0.083984 | 6.739 sec/iter
Epoch: 36 | Batch: 019 / 025 | Total loss: 4.171 | Reg loss: 0.037 | Tree loss: 4.171 | Accuracy: 0.107422 | 6.739 sec/iter
Epoch: 36 | Batch: 020 / 025 | Total loss: 4.154 | Reg loss: 0.037 | Tree loss: 4.154 | Accuracy: 0.078125 | 6.739 sec/iter
Epoch: 36 | Batch: 021 / 025 | Total loss: 4.163 | Reg loss: 0.037 | Tree loss: 4.163 | Accuracy: 0.087891 | 6.739 sec/iter
Epoch: 36 | Batch: 022 / 025 | Total loss: 4.147 | Reg loss: 0.037 | Tree loss: 4.147 | Accuracy: 0.091797 | 6.738 sec/iter
Epoch: 36 | Batch: 023 / 025 | Total loss: 4.080 | Reg loss: 0.037 | Tree loss: 4.080 | Accuracy: 0.115234 | 6.738 sec/iter
Epoch: 36 | Batch: 024 / 025 | Total loss: 4.094 | Reg loss: 0.037 | Tree loss: 4.094 | Accuracy: 0.068817 | 6.736 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 39 | Batch: 003 / 025 | Total loss: 4.208 | Reg loss: 0.036 | Tree loss: 4.208 | Accuracy: 0.091797 | 6.722 sec/iter
Epoch: 39 | Batch: 004 / 025 | Total loss: 4.163 | Reg loss: 0.036 | Tree loss: 4.163 | Accuracy: 0.083984 | 6.721 sec/iter
Epoch: 39 | Batch: 005 / 025 | Total loss: 4.176 | Reg loss: 0.036 | Tree loss: 4.176 | Accuracy: 0.101562 | 6.721 sec/iter
Epoch: 39 | Batch: 006 / 025 | Total loss: 4.125 | Reg loss: 0.036 | Tree loss: 4.125 | Accuracy: 0.089844 | 6.721 sec/iter
Epoch: 39 | Batch: 007 / 025 | Total loss: 4.113 | Reg loss: 0.036 | Tree loss: 4.113 | Accuracy: 0.093750 | 6.721 sec/iter
Epoch: 39 | Batch: 008 / 025 | Total loss: 4.104 | Reg loss: 0.036 | Tree loss: 4.104 | Accuracy: 0.089844 | 6.72 sec/iter
Epoch: 39 | Batch: 009 / 025 | Total loss: 4.069 | Reg loss: 0.036 | Tree loss: 4.069 | Accuracy: 0.097656 | 6.72 sec/iter
Epoch: 39 | Batch: 010 / 025 | Total loss: 4.090 | Reg loss: 0.036 | Tree loss: 4.090 | Accuracy: 0.080078 | 6.72 sec/iter
Epoch: 39 |

Epoch: 41 | Batch: 015 / 025 | Total loss: 3.880 | Reg loss: 0.036 | Tree loss: 3.880 | Accuracy: 0.109375 | 6.706 sec/iter
Epoch: 41 | Batch: 016 / 025 | Total loss: 3.899 | Reg loss: 0.036 | Tree loss: 3.899 | Accuracy: 0.097656 | 6.706 sec/iter
Epoch: 41 | Batch: 017 / 025 | Total loss: 3.830 | Reg loss: 0.036 | Tree loss: 3.830 | Accuracy: 0.087891 | 6.706 sec/iter
Epoch: 41 | Batch: 018 / 025 | Total loss: 3.859 | Reg loss: 0.036 | Tree loss: 3.859 | Accuracy: 0.091797 | 6.706 sec/iter
Epoch: 41 | Batch: 019 / 025 | Total loss: 3.898 | Reg loss: 0.036 | Tree loss: 3.898 | Accuracy: 0.119141 | 6.706 sec/iter
Epoch: 41 | Batch: 020 / 025 | Total loss: 3.857 | Reg loss: 0.036 | Tree loss: 3.857 | Accuracy: 0.076172 | 6.705 sec/iter
Epoch: 41 | Batch: 021 / 025 | Total loss: 3.809 | Reg loss: 0.036 | Tree loss: 3.809 | Accuracy: 0.091797 | 6.705 sec/iter
Epoch: 41 | Batch: 022 / 025 | Total loss: 3.830 | Reg loss: 0.036 | Tree loss: 3.830 | Accuracy: 0.083984 | 6.705 sec/iter
Epoch: 4

Epoch: 44 | Batch: 000 / 025 | Total loss: 3.960 | Reg loss: 0.035 | Tree loss: 3.960 | Accuracy: 0.097656 | 6.693 sec/iter
Epoch: 44 | Batch: 001 / 025 | Total loss: 3.957 | Reg loss: 0.035 | Tree loss: 3.957 | Accuracy: 0.083984 | 6.692 sec/iter
Epoch: 44 | Batch: 002 / 025 | Total loss: 4.018 | Reg loss: 0.035 | Tree loss: 4.018 | Accuracy: 0.078125 | 6.692 sec/iter
Epoch: 44 | Batch: 003 / 025 | Total loss: 3.939 | Reg loss: 0.035 | Tree loss: 3.939 | Accuracy: 0.085938 | 6.691 sec/iter
Epoch: 44 | Batch: 004 / 025 | Total loss: 3.887 | Reg loss: 0.035 | Tree loss: 3.887 | Accuracy: 0.078125 | 6.69 sec/iter
Epoch: 44 | Batch: 005 / 025 | Total loss: 3.903 | Reg loss: 0.035 | Tree loss: 3.903 | Accuracy: 0.097656 | 6.69 sec/iter
Epoch: 44 | Batch: 006 / 025 | Total loss: 3.794 | Reg loss: 0.035 | Tree loss: 3.794 | Accuracy: 0.107422 | 6.69 sec/iter
Epoch: 44 | Batch: 007 / 025 | Total loss: 3.880 | Reg loss: 0.035 | Tree loss: 3.880 | Accuracy: 0.091797 | 6.689 sec/iter
Epoch: 44 |

Epoch: 46 | Batch: 012 / 025 | Total loss: 3.693 | Reg loss: 0.035 | Tree loss: 3.693 | Accuracy: 0.095703 | 6.662 sec/iter
Epoch: 46 | Batch: 013 / 025 | Total loss: 3.709 | Reg loss: 0.035 | Tree loss: 3.709 | Accuracy: 0.097656 | 6.661 sec/iter
Epoch: 46 | Batch: 014 / 025 | Total loss: 3.705 | Reg loss: 0.035 | Tree loss: 3.705 | Accuracy: 0.101562 | 6.661 sec/iter
Epoch: 46 | Batch: 015 / 025 | Total loss: 3.685 | Reg loss: 0.035 | Tree loss: 3.685 | Accuracy: 0.076172 | 6.661 sec/iter
Epoch: 46 | Batch: 016 / 025 | Total loss: 3.690 | Reg loss: 0.035 | Tree loss: 3.690 | Accuracy: 0.107422 | 6.66 sec/iter
Epoch: 46 | Batch: 017 / 025 | Total loss: 3.663 | Reg loss: 0.035 | Tree loss: 3.663 | Accuracy: 0.085938 | 6.66 sec/iter
Epoch: 46 | Batch: 018 / 025 | Total loss: 3.638 | Reg loss: 0.035 | Tree loss: 3.638 | Accuracy: 0.093750 | 6.659 sec/iter
Epoch: 46 | Batch: 019 / 025 | Total loss: 3.620 | Reg loss: 0.035 | Tree loss: 3.620 | Accuracy: 0.091797 | 6.659 sec/iter
Epoch: 46 

Epoch: 48 | Batch: 024 / 025 | Total loss: 3.466 | Reg loss: 0.035 | Tree loss: 3.466 | Accuracy: 0.105376 | 6.632 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 49 | Batch: 000 / 025 | Total loss: 3.733 | Reg loss: 0.034 | Tree loss: 3.733 | Accuracy: 0.105469 | 6.632 sec/iter
Epoch: 49 | Batch: 001 / 025 | Total loss: 3.742 | Reg loss: 0.034 | Tree loss: 3.742 | Accuracy: 0.087891 | 6.631 sec/iter
Epoch: 49 | Batch: 002 / 025 | Total loss: 3.783 | Reg loss: 0.034 | Tree loss: 3.783 | Accuracy: 0.091797 | 6.631 sec/iter
Epoch: 49 | Batch: 003 / 025 | Total loss: 3.713 | Reg loss: 0.034 | Tree loss: 3.713 | Accuracy: 0.105469 | 6.631 sec/iter
Epoch: 49 | Batch: 004 / 025 | Total loss: 3.712 | Reg loss: 0.034 | Tree loss: 3.712 | Ac

Epoch: 51 | Batch: 009 / 025 | Total loss: 3.556 | Reg loss: 0.034 | Tree loss: 3.556 | Accuracy: 0.085938 | 6.607 sec/iter
Epoch: 51 | Batch: 010 / 025 | Total loss: 3.643 | Reg loss: 0.034 | Tree loss: 3.643 | Accuracy: 0.082031 | 6.607 sec/iter
Epoch: 51 | Batch: 011 / 025 | Total loss: 3.555 | Reg loss: 0.034 | Tree loss: 3.555 | Accuracy: 0.093750 | 6.608 sec/iter
Epoch: 51 | Batch: 012 / 025 | Total loss: 3.554 | Reg loss: 0.034 | Tree loss: 3.554 | Accuracy: 0.085938 | 6.607 sec/iter
Epoch: 51 | Batch: 013 / 025 | Total loss: 3.520 | Reg loss: 0.034 | Tree loss: 3.520 | Accuracy: 0.107422 | 6.608 sec/iter
Epoch: 51 | Batch: 014 / 025 | Total loss: 3.481 | Reg loss: 0.034 | Tree loss: 3.481 | Accuracy: 0.109375 | 6.609 sec/iter
Epoch: 51 | Batch: 015 / 025 | Total loss: 3.507 | Reg loss: 0.034 | Tree loss: 3.507 | Accuracy: 0.117188 | 6.608 sec/iter
Epoch: 51 | Batch: 016 / 025 | Total loss: 3.452 | Reg loss: 0.034 | Tree loss: 3.452 | Accuracy: 0.099609 | 6.608 sec/iter
Epoch: 5

Epoch: 53 | Batch: 021 / 025 | Total loss: 3.418 | Reg loss: 0.034 | Tree loss: 3.418 | Accuracy: 0.091797 | 6.599 sec/iter
Epoch: 53 | Batch: 022 / 025 | Total loss: 3.378 | Reg loss: 0.034 | Tree loss: 3.378 | Accuracy: 0.085938 | 6.598 sec/iter
Epoch: 53 | Batch: 023 / 025 | Total loss: 3.378 | Reg loss: 0.034 | Tree loss: 3.378 | Accuracy: 0.107422 | 6.598 sec/iter
Epoch: 53 | Batch: 024 / 025 | Total loss: 3.425 | Reg loss: 0.034 | Tree loss: 3.425 | Accuracy: 0.083871 | 6.596 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 54 | Batch: 000 / 025 | Total loss: 3.695 | Reg loss: 0.033 | Tree loss: 3.695 | Accuracy: 0.089844 | 6.596 sec/iter
Epoch: 54 | Batch: 001 / 025 | Total loss: 3.581 | Reg loss: 0.033 | Tree loss: 3.581 | Ac

Epoch: 56 | Batch: 006 / 025 | Total loss: 3.476 | Reg loss: 0.033 | Tree loss: 3.476 | Accuracy: 0.117188 | 6.576 sec/iter
Epoch: 56 | Batch: 007 / 025 | Total loss: 3.476 | Reg loss: 0.033 | Tree loss: 3.476 | Accuracy: 0.113281 | 6.576 sec/iter
Epoch: 56 | Batch: 008 / 025 | Total loss: 3.496 | Reg loss: 0.033 | Tree loss: 3.496 | Accuracy: 0.093750 | 6.576 sec/iter
Epoch: 56 | Batch: 009 / 025 | Total loss: 3.448 | Reg loss: 0.033 | Tree loss: 3.448 | Accuracy: 0.074219 | 6.575 sec/iter
Epoch: 56 | Batch: 010 / 025 | Total loss: 3.405 | Reg loss: 0.033 | Tree loss: 3.405 | Accuracy: 0.123047 | 6.575 sec/iter
Epoch: 56 | Batch: 011 / 025 | Total loss: 3.431 | Reg loss: 0.033 | Tree loss: 3.431 | Accuracy: 0.091797 | 6.575 sec/iter
Epoch: 56 | Batch: 012 / 025 | Total loss: 3.456 | Reg loss: 0.033 | Tree loss: 3.456 | Accuracy: 0.097656 | 6.574 sec/iter
Epoch: 56 | Batch: 013 / 025 | Total loss: 3.391 | Reg loss: 0.033 | Tree loss: 3.391 | Accuracy: 0.093750 | 6.574 sec/iter
Epoch: 5

Epoch: 58 | Batch: 018 / 025 | Total loss: 3.277 | Reg loss: 0.033 | Tree loss: 3.277 | Accuracy: 0.087891 | 6.555 sec/iter
Epoch: 58 | Batch: 019 / 025 | Total loss: 3.383 | Reg loss: 0.033 | Tree loss: 3.383 | Accuracy: 0.083984 | 6.555 sec/iter
Epoch: 58 | Batch: 020 / 025 | Total loss: 3.269 | Reg loss: 0.033 | Tree loss: 3.269 | Accuracy: 0.105469 | 6.555 sec/iter
Epoch: 58 | Batch: 021 / 025 | Total loss: 3.334 | Reg loss: 0.033 | Tree loss: 3.334 | Accuracy: 0.101562 | 6.554 sec/iter
Epoch: 58 | Batch: 022 / 025 | Total loss: 3.249 | Reg loss: 0.033 | Tree loss: 3.249 | Accuracy: 0.083984 | 6.554 sec/iter
Epoch: 58 | Batch: 023 / 025 | Total loss: 3.256 | Reg loss: 0.033 | Tree loss: 3.256 | Accuracy: 0.107422 | 6.554 sec/iter
Epoch: 58 | Batch: 024 / 025 | Total loss: 3.288 | Reg loss: 0.033 | Tree loss: 3.288 | Accuracy: 0.068817 | 6.553 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 61 | Batch: 003 / 025 | Total loss: 3.376 | Reg loss: 0.032 | Tree loss: 3.376 | Accuracy: 0.111328 | 6.537 sec/iter
Epoch: 61 | Batch: 004 / 025 | Total loss: 3.269 | Reg loss: 0.032 | Tree loss: 3.269 | Accuracy: 0.111328 | 6.536 sec/iter
Epoch: 61 | Batch: 005 / 025 | Total loss: 3.292 | Reg loss: 0.032 | Tree loss: 3.292 | Accuracy: 0.119141 | 6.536 sec/iter
Epoch: 61 | Batch: 006 / 025 | Total loss: 3.317 | Reg loss: 0.032 | Tree loss: 3.317 | Accuracy: 0.109375 | 6.536 sec/iter
Epoch: 61 | Batch: 007 / 025 | Total loss: 3.368 | Reg loss: 0.032 | Tree loss: 3.368 | Accuracy: 0.103516 | 6.536 sec/iter
Epoch: 61 | Batch: 008 / 025 | Total loss: 3.362 | Reg loss: 0.032 | Tree loss: 3.362 | Accuracy: 0.087891 | 6.535 sec/iter
Epoch: 61 | Batch: 009 / 025 | Total loss: 3.318 | Reg loss: 0.032 | Tree loss: 3.318 | Accuracy: 0.099609 | 6.535 sec/iter
Epoch: 61 | Batch: 010 / 025 | Total loss: 3.290 | Reg loss: 0.032 | Tree loss: 3.290 | Accuracy: 0.115234 | 6.535 sec/iter
Epoch: 6

Epoch: 63 | Batch: 015 / 025 | Total loss: 3.238 | Reg loss: 0.032 | Tree loss: 3.238 | Accuracy: 0.091797 | 6.519 sec/iter
Epoch: 63 | Batch: 016 / 025 | Total loss: 3.175 | Reg loss: 0.032 | Tree loss: 3.175 | Accuracy: 0.113281 | 6.519 sec/iter
Epoch: 63 | Batch: 017 / 025 | Total loss: 3.171 | Reg loss: 0.032 | Tree loss: 3.171 | Accuracy: 0.082031 | 6.519 sec/iter
Epoch: 63 | Batch: 018 / 025 | Total loss: 3.260 | Reg loss: 0.032 | Tree loss: 3.260 | Accuracy: 0.087891 | 6.519 sec/iter
Epoch: 63 | Batch: 019 / 025 | Total loss: 3.151 | Reg loss: 0.032 | Tree loss: 3.151 | Accuracy: 0.087891 | 6.519 sec/iter
Epoch: 63 | Batch: 020 / 025 | Total loss: 3.080 | Reg loss: 0.032 | Tree loss: 3.080 | Accuracy: 0.113281 | 6.518 sec/iter
Epoch: 63 | Batch: 021 / 025 | Total loss: 3.146 | Reg loss: 0.033 | Tree loss: 3.146 | Accuracy: 0.121094 | 6.518 sec/iter
Epoch: 63 | Batch: 022 / 025 | Total loss: 3.055 | Reg loss: 0.033 | Tree loss: 3.055 | Accuracy: 0.119141 | 6.518 sec/iter
Epoch: 6

Epoch: 66 | Batch: 000 / 025 | Total loss: 3.305 | Reg loss: 0.032 | Tree loss: 3.305 | Accuracy: 0.097656 | 6.502 sec/iter
Epoch: 66 | Batch: 001 / 025 | Total loss: 3.285 | Reg loss: 0.032 | Tree loss: 3.285 | Accuracy: 0.091797 | 6.502 sec/iter
Epoch: 66 | Batch: 002 / 025 | Total loss: 3.271 | Reg loss: 0.032 | Tree loss: 3.271 | Accuracy: 0.091797 | 6.501 sec/iter
Epoch: 66 | Batch: 003 / 025 | Total loss: 3.179 | Reg loss: 0.032 | Tree loss: 3.179 | Accuracy: 0.113281 | 6.501 sec/iter
Epoch: 66 | Batch: 004 / 025 | Total loss: 3.208 | Reg loss: 0.032 | Tree loss: 3.208 | Accuracy: 0.107422 | 6.501 sec/iter
Epoch: 66 | Batch: 005 / 025 | Total loss: 3.150 | Reg loss: 0.032 | Tree loss: 3.150 | Accuracy: 0.107422 | 6.501 sec/iter
Epoch: 66 | Batch: 006 / 025 | Total loss: 3.124 | Reg loss: 0.032 | Tree loss: 3.124 | Accuracy: 0.087891 | 6.501 sec/iter
Epoch: 66 | Batch: 007 / 025 | Total loss: 3.215 | Reg loss: 0.032 | Tree loss: 3.215 | Accuracy: 0.091797 | 6.501 sec/iter
Epoch: 6

Epoch: 68 | Batch: 012 / 025 | Total loss: 3.097 | Reg loss: 0.031 | Tree loss: 3.097 | Accuracy: 0.101562 | 6.489 sec/iter
Epoch: 68 | Batch: 013 / 025 | Total loss: 3.044 | Reg loss: 0.031 | Tree loss: 3.044 | Accuracy: 0.097656 | 6.489 sec/iter
Epoch: 68 | Batch: 014 / 025 | Total loss: 3.121 | Reg loss: 0.031 | Tree loss: 3.121 | Accuracy: 0.097656 | 6.489 sec/iter
Epoch: 68 | Batch: 015 / 025 | Total loss: 3.120 | Reg loss: 0.032 | Tree loss: 3.120 | Accuracy: 0.093750 | 6.488 sec/iter
Epoch: 68 | Batch: 016 / 025 | Total loss: 3.117 | Reg loss: 0.032 | Tree loss: 3.117 | Accuracy: 0.097656 | 6.488 sec/iter
Epoch: 68 | Batch: 017 / 025 | Total loss: 3.126 | Reg loss: 0.032 | Tree loss: 3.126 | Accuracy: 0.097656 | 6.488 sec/iter
Epoch: 68 | Batch: 018 / 025 | Total loss: 3.097 | Reg loss: 0.032 | Tree loss: 3.097 | Accuracy: 0.111328 | 6.488 sec/iter
Epoch: 68 | Batch: 019 / 025 | Total loss: 3.041 | Reg loss: 0.032 | Tree loss: 3.041 | Accuracy: 0.111328 | 6.488 sec/iter
Epoch: 6

Epoch: 70 | Batch: 024 / 025 | Total loss: 3.049 | Reg loss: 0.031 | Tree loss: 3.049 | Accuracy: 0.092473 | 6.474 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 71 | Batch: 000 / 025 | Total loss: 3.195 | Reg loss: 0.031 | Tree loss: 3.195 | Accuracy: 0.097656 | 6.474 sec/iter
Epoch: 71 | Batch: 001 / 025 | Total loss: 3.186 | Reg loss: 0.031 | Tree loss: 3.186 | Accuracy: 0.111328 | 6.474 sec/iter
Epoch: 71 | Batch: 002 / 025 | Total loss: 3.173 | Reg loss: 0.031 | Tree loss: 3.173 | Accuracy: 0.097656 | 6.474 sec/iter
Epoch: 71 | Batch: 003 / 025 | Total loss: 3.179 | Reg loss: 0.031 | Tree loss: 3.179 | Accuracy: 0.095703 | 6.474 sec/iter
Epoch: 71 | Batch: 004 / 025 | Total loss: 3.177 | Reg loss: 0.031 | Tree loss: 3.177 | Ac

Epoch: 73 | Batch: 009 / 025 | Total loss: 3.135 | Reg loss: 0.031 | Tree loss: 3.135 | Accuracy: 0.093750 | 6.461 sec/iter
Epoch: 73 | Batch: 010 / 025 | Total loss: 3.055 | Reg loss: 0.031 | Tree loss: 3.055 | Accuracy: 0.103516 | 6.461 sec/iter
Epoch: 73 | Batch: 011 / 025 | Total loss: 3.070 | Reg loss: 0.031 | Tree loss: 3.070 | Accuracy: 0.099609 | 6.461 sec/iter
Epoch: 73 | Batch: 012 / 025 | Total loss: 3.050 | Reg loss: 0.031 | Tree loss: 3.050 | Accuracy: 0.111328 | 6.461 sec/iter
Epoch: 73 | Batch: 013 / 025 | Total loss: 3.051 | Reg loss: 0.031 | Tree loss: 3.051 | Accuracy: 0.126953 | 6.461 sec/iter
Epoch: 73 | Batch: 014 / 025 | Total loss: 3.016 | Reg loss: 0.031 | Tree loss: 3.016 | Accuracy: 0.103516 | 6.461 sec/iter
Epoch: 73 | Batch: 015 / 025 | Total loss: 3.016 | Reg loss: 0.031 | Tree loss: 3.016 | Accuracy: 0.105469 | 6.46 sec/iter
Epoch: 73 | Batch: 016 / 025 | Total loss: 3.007 | Reg loss: 0.031 | Tree loss: 3.007 | Accuracy: 0.103516 | 6.46 sec/iter
Epoch: 73 

Epoch: 75 | Batch: 021 / 025 | Total loss: 3.012 | Reg loss: 0.031 | Tree loss: 3.012 | Accuracy: 0.105469 | 6.449 sec/iter
Epoch: 75 | Batch: 022 / 025 | Total loss: 3.006 | Reg loss: 0.031 | Tree loss: 3.006 | Accuracy: 0.101562 | 6.448 sec/iter
Epoch: 75 | Batch: 023 / 025 | Total loss: 2.957 | Reg loss: 0.031 | Tree loss: 2.957 | Accuracy: 0.117188 | 6.448 sec/iter
Epoch: 75 | Batch: 024 / 025 | Total loss: 2.921 | Reg loss: 0.031 | Tree loss: 2.921 | Accuracy: 0.109677 | 6.447 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 76 | Batch: 000 / 025 | Total loss: 3.156 | Reg loss: 0.030 | Tree loss: 3.156 | Accuracy: 0.105469 | 6.447 sec/iter
Epoch: 76 | Batch: 001 / 025 | Total loss: 3.169 | Reg loss: 0.030 | Tree loss: 3.169 | Ac

Epoch: 78 | Batch: 006 / 025 | Total loss: 3.035 | Reg loss: 0.030 | Tree loss: 3.035 | Accuracy: 0.111328 | 6.436 sec/iter
Epoch: 78 | Batch: 007 / 025 | Total loss: 3.114 | Reg loss: 0.030 | Tree loss: 3.114 | Accuracy: 0.083984 | 6.436 sec/iter
Epoch: 78 | Batch: 008 / 025 | Total loss: 3.062 | Reg loss: 0.030 | Tree loss: 3.062 | Accuracy: 0.076172 | 6.435 sec/iter
Epoch: 78 | Batch: 009 / 025 | Total loss: 3.102 | Reg loss: 0.030 | Tree loss: 3.102 | Accuracy: 0.111328 | 6.435 sec/iter
Epoch: 78 | Batch: 010 / 025 | Total loss: 3.025 | Reg loss: 0.030 | Tree loss: 3.025 | Accuracy: 0.107422 | 6.435 sec/iter
Epoch: 78 | Batch: 011 / 025 | Total loss: 3.077 | Reg loss: 0.030 | Tree loss: 3.077 | Accuracy: 0.093750 | 6.435 sec/iter
Epoch: 78 | Batch: 012 / 025 | Total loss: 3.029 | Reg loss: 0.030 | Tree loss: 3.029 | Accuracy: 0.085938 | 6.435 sec/iter
Epoch: 78 | Batch: 013 / 025 | Total loss: 3.027 | Reg loss: 0.030 | Tree loss: 3.027 | Accuracy: 0.105469 | 6.434 sec/iter
Epoch: 7

Epoch: 80 | Batch: 018 / 025 | Total loss: 2.938 | Reg loss: 0.030 | Tree loss: 2.938 | Accuracy: 0.111328 | 6.424 sec/iter
Epoch: 80 | Batch: 019 / 025 | Total loss: 2.940 | Reg loss: 0.030 | Tree loss: 2.940 | Accuracy: 0.107422 | 6.424 sec/iter
Epoch: 80 | Batch: 020 / 025 | Total loss: 2.999 | Reg loss: 0.030 | Tree loss: 2.999 | Accuracy: 0.074219 | 6.424 sec/iter
Epoch: 80 | Batch: 021 / 025 | Total loss: 2.986 | Reg loss: 0.030 | Tree loss: 2.986 | Accuracy: 0.091797 | 6.424 sec/iter
Epoch: 80 | Batch: 022 / 025 | Total loss: 2.973 | Reg loss: 0.030 | Tree loss: 2.973 | Accuracy: 0.089844 | 6.424 sec/iter
Epoch: 80 | Batch: 023 / 025 | Total loss: 2.920 | Reg loss: 0.030 | Tree loss: 2.920 | Accuracy: 0.101562 | 6.423 sec/iter
Epoch: 80 | Batch: 024 / 025 | Total loss: 2.933 | Reg loss: 0.030 | Tree loss: 2.933 | Accuracy: 0.116129 | 6.422 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 

Epoch: 83 | Batch: 003 / 025 | Total loss: 3.091 | Reg loss: 0.029 | Tree loss: 3.091 | Accuracy: 0.103516 | 6.415 sec/iter
Epoch: 83 | Batch: 004 / 025 | Total loss: 3.112 | Reg loss: 0.029 | Tree loss: 3.112 | Accuracy: 0.105469 | 6.415 sec/iter
Epoch: 83 | Batch: 005 / 025 | Total loss: 3.107 | Reg loss: 0.029 | Tree loss: 3.107 | Accuracy: 0.105469 | 6.415 sec/iter
Epoch: 83 | Batch: 006 / 025 | Total loss: 3.030 | Reg loss: 0.029 | Tree loss: 3.030 | Accuracy: 0.107422 | 6.415 sec/iter
Epoch: 83 | Batch: 007 / 025 | Total loss: 3.020 | Reg loss: 0.029 | Tree loss: 3.020 | Accuracy: 0.080078 | 6.415 sec/iter
Epoch: 83 | Batch: 008 / 025 | Total loss: 3.008 | Reg loss: 0.029 | Tree loss: 3.008 | Accuracy: 0.117188 | 6.414 sec/iter
Epoch: 83 | Batch: 009 / 025 | Total loss: 3.017 | Reg loss: 0.029 | Tree loss: 3.017 | Accuracy: 0.113281 | 6.414 sec/iter
Epoch: 83 | Batch: 010 / 025 | Total loss: 3.053 | Reg loss: 0.029 | Tree loss: 3.053 | Accuracy: 0.091797 | 6.414 sec/iter
Epoch: 8

Epoch: 85 | Batch: 015 / 025 | Total loss: 2.938 | Reg loss: 0.029 | Tree loss: 2.938 | Accuracy: 0.093750 | 6.405 sec/iter
Epoch: 85 | Batch: 016 / 025 | Total loss: 2.948 | Reg loss: 0.029 | Tree loss: 2.948 | Accuracy: 0.103516 | 6.405 sec/iter
Epoch: 85 | Batch: 017 / 025 | Total loss: 2.985 | Reg loss: 0.029 | Tree loss: 2.985 | Accuracy: 0.091797 | 6.405 sec/iter
Epoch: 85 | Batch: 018 / 025 | Total loss: 2.978 | Reg loss: 0.029 | Tree loss: 2.978 | Accuracy: 0.109375 | 6.405 sec/iter
Epoch: 85 | Batch: 019 / 025 | Total loss: 2.944 | Reg loss: 0.029 | Tree loss: 2.944 | Accuracy: 0.121094 | 6.404 sec/iter
Epoch: 85 | Batch: 020 / 025 | Total loss: 3.029 | Reg loss: 0.029 | Tree loss: 3.029 | Accuracy: 0.097656 | 6.404 sec/iter
Epoch: 85 | Batch: 021 / 025 | Total loss: 3.011 | Reg loss: 0.029 | Tree loss: 3.011 | Accuracy: 0.089844 | 6.404 sec/iter
Epoch: 85 | Batch: 022 / 025 | Total loss: 2.912 | Reg loss: 0.029 | Tree loss: 2.912 | Accuracy: 0.103516 | 6.404 sec/iter
Epoch: 8

Epoch: 88 | Batch: 000 / 025 | Total loss: 3.147 | Reg loss: 0.029 | Tree loss: 3.147 | Accuracy: 0.091797 | 6.396 sec/iter
Epoch: 88 | Batch: 001 / 025 | Total loss: 3.002 | Reg loss: 0.029 | Tree loss: 3.002 | Accuracy: 0.130859 | 6.396 sec/iter
Epoch: 88 | Batch: 002 / 025 | Total loss: 3.073 | Reg loss: 0.029 | Tree loss: 3.073 | Accuracy: 0.095703 | 6.395 sec/iter
Epoch: 88 | Batch: 003 / 025 | Total loss: 3.028 | Reg loss: 0.029 | Tree loss: 3.028 | Accuracy: 0.117188 | 6.395 sec/iter
Epoch: 88 | Batch: 004 / 025 | Total loss: 3.129 | Reg loss: 0.029 | Tree loss: 3.129 | Accuracy: 0.105469 | 6.395 sec/iter
Epoch: 88 | Batch: 005 / 025 | Total loss: 3.047 | Reg loss: 0.029 | Tree loss: 3.047 | Accuracy: 0.126953 | 6.395 sec/iter
Epoch: 88 | Batch: 006 / 025 | Total loss: 3.059 | Reg loss: 0.029 | Tree loss: 3.059 | Accuracy: 0.089844 | 6.395 sec/iter
Epoch: 88 | Batch: 007 / 025 | Total loss: 3.020 | Reg loss: 0.029 | Tree loss: 3.020 | Accuracy: 0.099609 | 6.395 sec/iter
Epoch: 8

Epoch: 90 | Batch: 012 / 025 | Total loss: 2.963 | Reg loss: 0.029 | Tree loss: 2.963 | Accuracy: 0.128906 | 6.387 sec/iter
Epoch: 90 | Batch: 013 / 025 | Total loss: 3.012 | Reg loss: 0.029 | Tree loss: 3.012 | Accuracy: 0.107422 | 6.387 sec/iter
Epoch: 90 | Batch: 014 / 025 | Total loss: 2.970 | Reg loss: 0.029 | Tree loss: 2.970 | Accuracy: 0.095703 | 6.387 sec/iter
Epoch: 90 | Batch: 015 / 025 | Total loss: 3.032 | Reg loss: 0.029 | Tree loss: 3.032 | Accuracy: 0.091797 | 6.387 sec/iter
Epoch: 90 | Batch: 016 / 025 | Total loss: 3.002 | Reg loss: 0.029 | Tree loss: 3.002 | Accuracy: 0.115234 | 6.387 sec/iter
Epoch: 90 | Batch: 017 / 025 | Total loss: 2.931 | Reg loss: 0.029 | Tree loss: 2.931 | Accuracy: 0.125000 | 6.387 sec/iter
Epoch: 90 | Batch: 018 / 025 | Total loss: 2.991 | Reg loss: 0.029 | Tree loss: 2.991 | Accuracy: 0.097656 | 6.387 sec/iter
Epoch: 90 | Batch: 019 / 025 | Total loss: 2.960 | Reg loss: 0.029 | Tree loss: 2.960 | Accuracy: 0.105469 | 6.387 sec/iter
Epoch: 9

Epoch: 92 | Batch: 024 / 025 | Total loss: 2.929 | Reg loss: 0.029 | Tree loss: 2.929 | Accuracy: 0.105376 | 6.38 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 93 | Batch: 000 / 025 | Total loss: 3.094 | Reg loss: 0.028 | Tree loss: 3.094 | Accuracy: 0.134766 | 6.38 sec/iter
Epoch: 93 | Batch: 001 / 025 | Total loss: 3.079 | Reg loss: 0.028 | Tree loss: 3.079 | Accuracy: 0.107422 | 6.379 sec/iter
Epoch: 93 | Batch: 002 / 025 | Total loss: 3.107 | Reg loss: 0.028 | Tree loss: 3.107 | Accuracy: 0.089844 | 6.379 sec/iter
Epoch: 93 | Batch: 003 / 025 | Total loss: 3.082 | Reg loss: 0.028 | Tree loss: 3.082 | Accuracy: 0.103516 | 6.379 sec/iter
Epoch: 93 | Batch: 004 / 025 | Total loss: 3.015 | Reg loss: 0.028 | Tree loss: 3.015 | Accu

Epoch: 95 | Batch: 009 / 025 | Total loss: 3.056 | Reg loss: 0.028 | Tree loss: 3.056 | Accuracy: 0.093750 | 6.372 sec/iter
Epoch: 95 | Batch: 010 / 025 | Total loss: 3.006 | Reg loss: 0.028 | Tree loss: 3.006 | Accuracy: 0.105469 | 6.372 sec/iter
Epoch: 95 | Batch: 011 / 025 | Total loss: 2.986 | Reg loss: 0.028 | Tree loss: 2.986 | Accuracy: 0.109375 | 6.372 sec/iter
Epoch: 95 | Batch: 012 / 025 | Total loss: 3.008 | Reg loss: 0.028 | Tree loss: 3.008 | Accuracy: 0.072266 | 6.372 sec/iter
Epoch: 95 | Batch: 013 / 025 | Total loss: 3.000 | Reg loss: 0.028 | Tree loss: 3.000 | Accuracy: 0.091797 | 6.371 sec/iter
Epoch: 95 | Batch: 014 / 025 | Total loss: 3.014 | Reg loss: 0.028 | Tree loss: 3.014 | Accuracy: 0.082031 | 6.371 sec/iter
Epoch: 95 | Batch: 015 / 025 | Total loss: 2.980 | Reg loss: 0.028 | Tree loss: 2.980 | Accuracy: 0.132812 | 6.371 sec/iter
Epoch: 95 | Batch: 016 / 025 | Total loss: 2.960 | Reg loss: 0.028 | Tree loss: 2.960 | Accuracy: 0.125000 | 6.371 sec/iter
Epoch: 9

Epoch: 97 | Batch: 021 / 025 | Total loss: 3.009 | Reg loss: 0.028 | Tree loss: 3.009 | Accuracy: 0.085938 | 6.366 sec/iter
Epoch: 97 | Batch: 022 / 025 | Total loss: 2.968 | Reg loss: 0.028 | Tree loss: 2.968 | Accuracy: 0.099609 | 6.366 sec/iter
Epoch: 97 | Batch: 023 / 025 | Total loss: 2.915 | Reg loss: 0.028 | Tree loss: 2.915 | Accuracy: 0.082031 | 6.366 sec/iter
Epoch: 97 | Batch: 024 / 025 | Total loss: 2.906 | Reg loss: 0.028 | Tree loss: 2.906 | Accuracy: 0.092473 | 6.365 sec/iter
Average sparseness: 0.9821428571428573
layer 0: 0.9821428571428571
layer 1: 0.9821428571428571
layer 2: 0.9821428571428571
layer 3: 0.9821428571428571
layer 4: 0.9821428571428571
layer 5: 0.9821428571428571
layer 6: 0.982142857142857
layer 7: 0.9821428571428573
layer 8: 0.9821428571428573
Epoch: 98 | Batch: 000 / 025 | Total loss: 3.064 | Reg loss: 0.028 | Tree loss: 3.064 | Accuracy: 0.107422 | 6.365 sec/iter
Epoch: 98 | Batch: 001 / 025 | Total loss: 3.066 | Reg loss: 0.028 | Tree loss: 3.066 | Ac

In [77]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [78]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 9.966666666666667


# Extract Rules

# Accumulate samples in the leaves

In [91]:
print(f"Number of patterns: {len(root.get_leaves())}")

Number of patterns: 990


In [112]:
method = 'greedy'

In [113]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [114]:
attr_names = dataset.items

# print(attr_names)
leaves = root.get_leaves()
sum_comprehensibility = 0

for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    if len(conds) == 0:
        continue
    
    comp = sum([cond.comprehensibility for cond in conds])
    print(f"============== Pattern {pattern_counter + 1} | comprehensibility: {comp} ==============")
    print(conds)
    print()
    print()
    print()
    sum_comprehensibility += comp
    
print(f"Average comprehensibility: {sum_comprehensibility / len(leaves)}")

12753
[-0.22315509617328644 * whole milk + -0.1576598435640335 * yogurt >= tensor(5.8347), -0.22315509617328644 * whole milk + -0.1576598435640335 * yogurt <= tensor(5.8356), 0.22214579582214355 * whole milk + 0.15752269327640533 * yogurt >= tensor(0.0029), 0.22214579582214355 * whole milk + 0.15752269327640533 * yogurt <= tensor(0.3817), -0.22095395624637604 * whole milk + -0.15674421191215515 * yogurt >= tensor(5.8491), -0.22095395624637604 * whole milk + -0.15674421191215515 * yogurt <= tensor(5.8500), 0.21899883449077606 * whole milk + 0.42691516876220703 * rolls/buns >= tensor(0.0030), 0.21899883449077606 * whole milk + 0.42691516876220703 * rolls/buns <= tensor(0.6474), -0.15359069406986237 * yogurt + -0.3081378936767578 * soda >= tensor(5.8474), -0.15359069406986237 * yogurt + -0.3081378936767578 * soda <= tensor(5.8485), -0.20576642453670502 * whole milk + -0.14803259074687958 * yogurt >= tensor(5.8460), -0.20576642453670502 * whole milk + -0.14803259074687958 * yogurt <= tenso