In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
# --- Experiment configuration ---------------------------------------------
# Number of neighbours passed to ClassificationKNNLoss(k=k) in the training cell.
k = 64
# Depth for the soft decision tree; not referenced by the cells visible here
# (presumably consumed where the SDT is constructed — TODO confirm).
tree_depth = 8
# Mini-batch size used by the train/test DataLoaders.
batch_size = 512
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [3]:
# Training loader: reshuffled every epoch; drop_last=True discards the final
# incomplete batch, so every training batch holds exactly `batch_size` samples.
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

# Test loader: also shuffled, but keeps the last (possibly smaller) batch so
# every test sample is visited once per pass.
test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    """Residual 1-D convolution block followed by max-pool down-sampling.

    Two 32-channel, kernel-5 convolutions (padding 2, so the length is kept)
    are applied with a ReLU in between; the block input is added back before
    the second ReLU, and the result is max-pooled with kernel 5 and stride 2.
    """

    def __init__(self):
        super().__init__()
        # Both convolutions share one configuration; attribute names are kept
        # so the module's state_dict keys stay the same.
        conv_kwargs = dict(kernel_size=5, stride=1, padding=2)
        self.conv1 = nn.Conv1d(32, 32, **conv_kwargs)
        self.conv2 = nn.Conv1d(32, 32, **conv_kwargs)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)

    def forward(self, x):
        """Return the pooled residual features for ``x`` (32 input channels)."""
        hidden = self.relu1(self.conv1(x))
        # Skip connection: add the untouched input to the second conv output.
        summed = self.conv2(hidden) + x
        return self.pool(self.relu2(summed))


class ECGModel(nn.Module):
    """1-D CNN classifier: a stem convolution, five ConvBlocks and a 2-layer head.

    The head's first linear layer takes 64 features, i.e. the flattened output
    of the last block must hold 32 channels x 2 time steps. The final layer
    emits 5 logits.
    """

    def __init__(self):
        super().__init__()
        # Stem: lifts the single input channel to 32 channels. With kernel 5
        # and padding 1 it shortens the sequence by 2 samples.
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        # Classification head.
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        """Compute class logits for ``x``.

        When ``return_interm`` is true, a ``(logits, interm)`` pair is returned,
        where ``interm`` is the flattened output of the last ConvBlock (the
        input of the classification head); otherwise only the logits.
        """
        features = self.conv(x)
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            features = block(features)
        interm = features.flatten(1)
        logits = self.fc2(self.relu(self.fc1(interm)))
        return (logits, interm) if return_interm else logits

In [5]:
# Auxiliary KNN criterion used by train(); built once from the notebook-level
# `k` and moved to `device`.
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    The per-batch loss is the cross-entropy of the model outputs plus
    ``knn_crt`` evaluated on the intermediate features. The KNN term is
    replaced by zero when it is infinite or when the criterion raises
    ``ValueError``.

    Besides its arguments the function reads the notebook globals ``knn_crt``,
    ``log_every``, ``epoch`` and ``epochs`` (the last three only for the
    progress line), so those must be defined before it is called.
    """
    model.train()

    running_loss = 0
    for step, (inputs, labels) in enumerate(loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits, features = model(inputs, return_interm=True)
        # cross_entropy already reduces to a scalar; the chained reductions
        # leave that value unchanged.
        cls_loss = F.cross_entropy(logits, labels).sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(features, labels)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        batch_loss = cls_loss + knn_loss
        running_loss += batch_loss.item()
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        if step % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {step} / {len(loader)} | Total Loss: {batch_loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {cls_loss.item()}")

    return running_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
# Training hyper-parameters. `epochs` and `log_every` are also read as globals
# inside train() for its progress line.
epochs = 200
lr = 1e-3
log_every = 10

# Fresh model and Adam optimizer on the configured device.
model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Total number of scalar parameters in the model.
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [None]:
# Main training loop: one train() pass per epoch, then accuracy on the training
# loader and on the test loader (reported as "Valid" below).
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
# NOTE: the loop variable must stay named `epoch`; train() reads it as a global
# for its progress line.
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
    losses.append(loss)
    # test() returns a 0-dim tensor on `device`. Store plain floats so the
    # histories can be plotted or converted by NumPy without a device-to-host
    # conversion error.
    train_acc = float(test(model, train_data_iter, device))
    train_accs.append(train_acc)
    valid_acc = float(test(model, test_data_iter, device))
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
          f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
          f'Best: {best_valid_acc:.4f}')

Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.37387752532959 | KNN Loss: 5.745442867279053 | CLS Loss: 1.6284348964691162
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.379044532775879 | KNN Loss: 4.717100143432617 | CLS Loss: 0.6619442105293274
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 5.141284465789795 | KNN Loss: 4.541954517364502 | CLS Loss: 0.5993297696113586
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 5.106218338012695 | KNN Loss: 4.494356632232666 | CLS Loss: 0.6118617057800293
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 5.020920753479004 | KNN Loss: 4.459837436676025 | CLS Loss: 0.5610834360122681
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 4.974459171295166 | KNN Loss: 4.389924049377441 | CLS Loss: 0.5845351815223694
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.929912567138672 | KNN Loss: 4.4062910079956055 | CLS Loss: 0.5236214995384216
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.792877674102783 | KNN Loss: 4.397512912750244 | CLS Lo

Epoch 4 / 200 | iteration 80 / 171 | Total Loss: 4.4375739097595215 | KNN Loss: 4.3133392333984375 | CLS Loss: 0.12423483282327652
Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 4.4066643714904785 | KNN Loss: 4.28235387802124 | CLS Loss: 0.12431031465530396
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 4.430000305175781 | KNN Loss: 4.27595853805542 | CLS Loss: 0.15404152870178223
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 4.413160800933838 | KNN Loss: 4.305107116699219 | CLS Loss: 0.10805387049913406
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 4.368527412414551 | KNN Loss: 4.257289409637451 | CLS Loss: 0.11123808473348618
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 4.407856464385986 | KNN Loss: 4.308313369750977 | CLS Loss: 0.09954296797513962
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 4.343050956726074 | KNN Loss: 4.2276611328125 | CLS Loss: 0.11538973450660706
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 4.367811679840088 | KNN Loss: 4.22594022750

Epoch 7 / 200 | iteration 160 / 171 | Total Loss: 4.337508201599121 | KNN Loss: 4.26438570022583 | CLS Loss: 0.07312242686748505
Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 4.38505744934082 | KNN Loss: 4.29645299911499 | CLS Loss: 0.08860433101654053
Epoch: 007, Loss: 4.3477, Train: 0.9780, Valid: 0.9746, Best: 0.9746
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 4.310878753662109 | KNN Loss: 4.237785339355469 | CLS Loss: 0.0730932205915451
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 4.317103385925293 | KNN Loss: 4.243649959564209 | CLS Loss: 0.07345327734947205
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 4.363016128540039 | KNN Loss: 4.258892059326172 | CLS Loss: 0.1041242927312851
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 4.30294132232666 | KNN Loss: 4.233722686767578 | CLS Loss: 0.06921858340501785
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 4.381361484527588 | KNN Loss: 4.2399582862854 | CLS Loss: 0.1414029896259308
Epoch 8 / 200 | iteration 50 / 171 |

Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 4.265674591064453 | KNN Loss: 4.198460578918457 | CLS Loss: 0.06721407920122147
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 4.292954444885254 | KNN Loss: 4.233286380767822 | CLS Loss: 0.05966807156801224
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 4.345195293426514 | KNN Loss: 4.264439582824707 | CLS Loss: 0.08075588196516037
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 4.255561828613281 | KNN Loss: 4.189026832580566 | CLS Loss: 0.06653519719839096
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 4.286505699157715 | KNN Loss: 4.235074996948242 | CLS Loss: 0.05143052339553833
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 4.33653450012207 | KNN Loss: 4.242529392242432 | CLS Loss: 0.09400498867034912
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 4.2691545486450195 | KNN Loss: 4.221701622009277 | CLS Loss: 0.047452885657548904
Epoch 11 / 200 | iteration 130 / 171 | Total Loss: 4.316824913024902 | KNN Loss: 4.262

Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 4.3268327713012695 | KNN Loss: 4.242607116699219 | CLS Loss: 0.08422550559043884
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 4.250906944274902 | KNN Loss: 4.21697473526001 | CLS Loss: 0.033932238817214966
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 4.29229736328125 | KNN Loss: 4.2303314208984375 | CLS Loss: 0.06196609139442444
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 4.245389938354492 | KNN Loss: 4.221269607543945 | CLS Loss: 0.024120260030031204
Epoch: 014, Loss: 4.2697, Train: 0.9867, Valid: 0.9825, Best: 0.9825
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 4.3214945793151855 | KNN Loss: 4.242618083953857 | CLS Loss: 0.07887637615203857
Epoch 15 / 200 | iteration 10 / 171 | Total Loss: 4.244252681732178 | KNN Loss: 4.169705867767334 | CLS Loss: 0.07454696297645569
Epoch 15 / 200 | iteration 20 / 171 | Total Loss: 4.253611087799072 | KNN Loss: 4.202136993408203 | CLS Loss: 0.05147401615977287
Epoch 15 / 200 

Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 4.268491744995117 | KNN Loss: 4.179671764373779 | CLS Loss: 0.08882015198469162
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 4.274069309234619 | KNN Loss: 4.242717266082764 | CLS Loss: 0.03135199844837189
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 4.261744976043701 | KNN Loss: 4.19309663772583 | CLS Loss: 0.06864825636148453
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 4.265486717224121 | KNN Loss: 4.206356525421143 | CLS Loss: 0.0591299831867218
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 4.2632646560668945 | KNN Loss: 4.203118801116943 | CLS Loss: 0.06014568358659744
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 4.288754463195801 | KNN Loss: 4.255762100219727 | CLS Loss: 0.03299219161272049
Epoch 18 / 200 | iteration 90 / 171 | Total Loss: 4.256245136260986 | KNN Loss: 4.218506336212158 | CLS Loss: 0.037738800048828125
Epoch 18 / 200 | iteration 100 / 171 | Total Loss: 4.250636577606201 | KNN Loss: 4.2108726

Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 4.291011810302734 | KNN Loss: 4.203880786895752 | CLS Loss: 0.08713094890117645
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 4.262795925140381 | KNN Loss: 4.209388732910156 | CLS Loss: 0.053407035768032074
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 4.259098529815674 | KNN Loss: 4.2263593673706055 | CLS Loss: 0.03273935988545418
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 4.1981964111328125 | KNN Loss: 4.181526184082031 | CLS Loss: 0.01667002961039543
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 4.229837417602539 | KNN Loss: 4.208588123321533 | CLS Loss: 0.021249476820230484
Epoch 21 / 200 | iteration 160 / 171 | Total Loss: 4.2756195068359375 | KNN Loss: 4.238901138305664 | CLS Loss: 0.03671824559569359
Epoch 21 / 200 | iteration 170 / 171 | Total Loss: 4.293495178222656 | KNN Loss: 4.2436604499816895 | CLS Loss: 0.04983476921916008
Epoch: 021, Loss: 4.2423, Train: 0.9878, Valid: 0.9839, Best: 0.9839
Epoch 22

Epoch: 024, Loss: 4.2356, Train: 0.9896, Valid: 0.9845, Best: 0.9848
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 4.186246395111084 | KNN Loss: 4.164303779602051 | CLS Loss: 0.021942852064967155
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 4.168835639953613 | KNN Loss: 4.150075912475586 | CLS Loss: 0.018759753555059433
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 4.218719482421875 | KNN Loss: 4.181945323944092 | CLS Loss: 0.036774229258298874
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 4.247707843780518 | KNN Loss: 4.198183059692383 | CLS Loss: 0.049524638801813126
Epoch 25 / 200 | iteration 40 / 171 | Total Loss: 4.2186126708984375 | KNN Loss: 4.18963623046875 | CLS Loss: 0.02897651493549347
Epoch 25 / 200 | iteration 50 / 171 | Total Loss: 4.2046709060668945 | KNN Loss: 4.158708095550537 | CLS Loss: 0.04596279561519623
Epoch 25 / 200 | iteration 60 / 171 | Total Loss: 4.228909492492676 | KNN Loss: 4.199152946472168 | CLS Loss: 0.029756421223282814
Epoch 25 / 200 |

Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 4.265418529510498 | KNN Loss: 4.1763596534729 | CLS Loss: 0.08905909210443497
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 4.174824237823486 | KNN Loss: 4.154087543487549 | CLS Loss: 0.0207368154078722
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 4.168670654296875 | KNN Loss: 4.1440277099609375 | CLS Loss: 0.02464302070438862
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 4.200687408447266 | KNN Loss: 4.175668239593506 | CLS Loss: 0.02501937374472618
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 4.261422157287598 | KNN Loss: 4.236147403717041 | CLS Loss: 0.025274906307458878
Epoch 28 / 200 | iteration 120 / 171 | Total Loss: 4.206640720367432 | KNN Loss: 4.184180736541748 | CLS Loss: 0.022459914907813072
Epoch 28 / 200 | iteration 130 / 171 | Total Loss: 4.206204414367676 | KNN Loss: 4.1753339767456055 | CLS Loss: 0.030870487913489342
Epoch 28 / 200 | iteration 140 / 171 | Total Loss: 4.25545072555542 | KNN Loss: 4.19

Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 4.227519512176514 | KNN Loss: 4.202352523803711 | CLS Loss: 0.025166820734739304
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 4.203866004943848 | KNN Loss: 4.190859317779541 | CLS Loss: 0.013006502762436867
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 4.2616729736328125 | KNN Loss: 4.1755523681640625 | CLS Loss: 0.08612081408500671
Epoch: 031, Loss: 4.2113, Train: 0.9911, Valid: 0.9845, Best: 0.9863
Epoch 32 / 200 | iteration 0 / 171 | Total Loss: 4.196518898010254 | KNN Loss: 4.182875156402588 | CLS Loss: 0.013643544167280197
Epoch 32 / 200 | iteration 10 / 171 | Total Loss: 4.194110870361328 | KNN Loss: 4.180301189422607 | CLS Loss: 0.013809784315526485
Epoch 32 / 200 | iteration 20 / 171 | Total Loss: 4.194121360778809 | KNN Loss: 4.173984050750732 | CLS Loss: 0.020137427374720573
Epoch 32 / 200 | iteration 30 / 171 | Total Loss: 4.182323932647705 | KNN Loss: 4.1590962409973145 | CLS Loss: 0.023227810859680176
Epoch 32 /

Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 4.200634479522705 | KNN Loss: 4.173842906951904 | CLS Loss: 0.026791518554091454
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 4.175117492675781 | KNN Loss: 4.145471096038818 | CLS Loss: 0.029646478593349457
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 4.202883720397949 | KNN Loss: 4.1836724281311035 | CLS Loss: 0.019211314618587494
Epoch 35 / 200 | iteration 70 / 171 | Total Loss: 4.230337142944336 | KNN Loss: 4.192715167999268 | CLS Loss: 0.03762213513255119
Epoch 35 / 200 | iteration 80 / 171 | Total Loss: 4.137632846832275 | KNN Loss: 4.116912841796875 | CLS Loss: 0.02072017453610897
Epoch 35 / 200 | iteration 90 / 171 | Total Loss: 4.1775054931640625 | KNN Loss: 4.134088039398193 | CLS Loss: 0.043417561799287796
Epoch 35 / 200 | iteration 100 / 171 | Total Loss: 4.188200950622559 | KNN Loss: 4.172223091125488 | CLS Loss: 0.01597805880010128
Epoch 35 / 200 | iteration 110 / 171 | Total Loss: 4.187870979309082 | KNN Loss: 4.

Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 4.208645343780518 | KNN Loss: 4.191484451293945 | CLS Loss: 0.01716066710650921
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 4.175397872924805 | KNN Loss: 4.131710529327393 | CLS Loss: 0.04368741437792778
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 4.2022786140441895 | KNN Loss: 4.165480613708496 | CLS Loss: 0.03679820895195007
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 4.199082851409912 | KNN Loss: 4.186020374298096 | CLS Loss: 0.013062701560556889
Epoch 38 / 200 | iteration 150 / 171 | Total Loss: 4.1941046714782715 | KNN Loss: 4.159545421600342 | CLS Loss: 0.03455926850438118
Epoch 38 / 200 | iteration 160 / 171 | Total Loss: 4.205271244049072 | KNN Loss: 4.168568134307861 | CLS Loss: 0.036703210324048996
Epoch 38 / 200 | iteration 170 / 171 | Total Loss: 4.1857008934021 | KNN Loss: 4.1565446853637695 | CLS Loss: 0.02915632352232933
Epoch: 038, Loss: 4.1978, Train: 0.9924, Valid: 0.9856, Best: 0.9863
Epoch 39 / 

Epoch: 041, Loss: 4.1919, Train: 0.9928, Valid: 0.9869, Best: 0.9869
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 4.199117183685303 | KNN Loss: 4.171749114990234 | CLS Loss: 0.027368221431970596
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 4.185390472412109 | KNN Loss: 4.147750377655029 | CLS Loss: 0.03764018043875694
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 4.162098407745361 | KNN Loss: 4.134032249450684 | CLS Loss: 0.0280663650482893
Epoch 42 / 200 | iteration 30 / 171 | Total Loss: 4.173219680786133 | KNN Loss: 4.157772541046143 | CLS Loss: 0.015447165817022324
Epoch 42 / 200 | iteration 40 / 171 | Total Loss: 4.1575236320495605 | KNN Loss: 4.150005340576172 | CLS Loss: 0.0075180851854383945
Epoch 42 / 200 | iteration 50 / 171 | Total Loss: 4.194015026092529 | KNN Loss: 4.1768622398376465 | CLS Loss: 0.017152631655335426
Epoch 42 / 200 | iteration 60 / 171 | Total Loss: 4.164169788360596 | KNN Loss: 4.13985013961792 | CLS Loss: 0.024319753050804138
Epoch 42 / 200 |

Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 4.195741176605225 | KNN Loss: 4.174771308898926 | CLS Loss: 0.02096995897591114
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 4.240171909332275 | KNN Loss: 4.206181049346924 | CLS Loss: 0.03399086743593216
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 4.183694362640381 | KNN Loss: 4.151634216308594 | CLS Loss: 0.032060373574495316
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 4.1404709815979 | KNN Loss: 4.127050876617432 | CLS Loss: 0.01341994572430849
Epoch 45 / 200 | iteration 110 / 171 | Total Loss: 4.148603916168213 | KNN Loss: 4.131430625915527 | CLS Loss: 0.01717349700629711
Epoch 45 / 200 | iteration 120 / 171 | Total Loss: 4.173666000366211 | KNN Loss: 4.138299942016602 | CLS Loss: 0.03536606580018997
Epoch 45 / 200 | iteration 130 / 171 | Total Loss: 4.18219518661499 | KNN Loss: 4.1590447425842285 | CLS Loss: 0.02315058745443821
Epoch 45 / 200 | iteration 140 / 171 | Total Loss: 4.176926612854004 | KNN Loss: 4.1576

Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 4.214667320251465 | KNN Loss: 4.183526515960693 | CLS Loss: 0.031140761449933052
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 4.174677848815918 | KNN Loss: 4.163699626922607 | CLS Loss: 0.010978439822793007
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 4.226525783538818 | KNN Loss: 4.206968784332275 | CLS Loss: 0.019557001069188118
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 4.168240070343018 | KNN Loss: 4.135058879852295 | CLS Loss: 0.03318142145872116
Epoch: 048, Loss: 4.1807, Train: 0.9944, Valid: 0.9862, Best: 0.9869
Epoch 49 / 200 | iteration 0 / 171 | Total Loss: 4.194820880889893 | KNN Loss: 4.185375690460205 | CLS Loss: 0.009445064701139927
Epoch 49 / 200 | iteration 10 / 171 | Total Loss: 4.1533589363098145 | KNN Loss: 4.146146297454834 | CLS Loss: 0.0072127413004636765
Epoch 49 / 200 | iteration 20 / 171 | Total Loss: 4.1549530029296875 | KNN Loss: 4.149294853210449 | CLS Loss: 0.005658128298819065
Epoch 49 

Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 4.196008682250977 | KNN Loss: 4.175873279571533 | CLS Loss: 0.020135505124926567
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 4.173285007476807 | KNN Loss: 4.135589122772217 | CLS Loss: 0.03769611194729805
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 4.125834941864014 | KNN Loss: 4.106276988983154 | CLS Loss: 0.019558018073439598
Epoch 52 / 200 | iteration 60 / 171 | Total Loss: 4.18631649017334 | KNN Loss: 4.166380405426025 | CLS Loss: 0.019936062395572662
Epoch 52 / 200 | iteration 70 / 171 | Total Loss: 4.168275833129883 | KNN Loss: 4.138324737548828 | CLS Loss: 0.029951177537441254
Epoch 52 / 200 | iteration 80 / 171 | Total Loss: 4.1514482498168945 | KNN Loss: 4.132675647735596 | CLS Loss: 0.018772650510072708
Epoch 52 / 200 | iteration 90 / 171 | Total Loss: 4.154571056365967 | KNN Loss: 4.1483845710754395 | CLS Loss: 0.006186395883560181
Epoch 52 / 200 | iteration 100 / 171 | Total Loss: 4.1694231033325195 | KNN Loss: 4

Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 4.155078411102295 | KNN Loss: 4.135499000549316 | CLS Loss: 0.01957947388291359
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 4.178250312805176 | KNN Loss: 4.141901016235352 | CLS Loss: 0.036349281668663025
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 4.185174465179443 | KNN Loss: 4.161858558654785 | CLS Loss: 0.02331603690981865
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 4.171797275543213 | KNN Loss: 4.151957988739014 | CLS Loss: 0.019839385524392128
Epoch 55 / 200 | iteration 140 / 171 | Total Loss: 4.169252872467041 | KNN Loss: 4.16044807434082 | CLS Loss: 0.008804881945252419
Epoch 55 / 200 | iteration 150 / 171 | Total Loss: 4.143223285675049 | KNN Loss: 4.126646041870117 | CLS Loss: 0.016577070578932762
Epoch 55 / 200 | iteration 160 / 171 | Total Loss: 4.174678325653076 | KNN Loss: 4.156030654907227 | CLS Loss: 0.018647542223334312
Epoch 55 / 200 | iteration 170 / 171 | Total Loss: 4.200728416442871 | KNN Loss

Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 4.156519412994385 | KNN Loss: 4.12684440612793 | CLS Loss: 0.029674824327230453
Epoch: 058, Loss: 4.1752, Train: 0.9953, Valid: 0.9870, Best: 0.9870
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 4.228031635284424 | KNN Loss: 4.2226080894470215 | CLS Loss: 0.005423477850854397
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 4.157278060913086 | KNN Loss: 4.120929718017578 | CLS Loss: 0.03634849563241005
Epoch 59 / 200 | iteration 20 / 171 | Total Loss: 4.184452056884766 | KNN Loss: 4.171756267547607 | CLS Loss: 0.012695957906544209
Epoch 59 / 200 | iteration 30 / 171 | Total Loss: 4.147996425628662 | KNN Loss: 4.115864276885986 | CLS Loss: 0.032132092863321304
Epoch 59 / 200 | iteration 40 / 171 | Total Loss: 4.15817928314209 | KNN Loss: 4.138386249542236 | CLS Loss: 0.01979309320449829
Epoch 59 / 200 | iteration 50 / 171 | Total Loss: 4.180298328399658 | KNN Loss: 4.169809341430664 | CLS Loss: 0.010488756000995636
Epoch 59 / 200 | 

Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 4.244478225708008 | KNN Loss: 4.214363098144531 | CLS Loss: 0.030115094035863876
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 4.1530656814575195 | KNN Loss: 4.145619869232178 | CLS Loss: 0.007445921655744314
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 4.192887783050537 | KNN Loss: 4.178734302520752 | CLS Loss: 0.014153496362268925
Epoch 62 / 200 | iteration 90 / 171 | Total Loss: 4.157037734985352 | KNN Loss: 4.148182392120361 | CLS Loss: 0.008855515159666538
Epoch 62 / 200 | iteration 100 / 171 | Total Loss: 4.22615909576416 | KNN Loss: 4.188061714172363 | CLS Loss: 0.03809729963541031
Epoch 62 / 200 | iteration 110 / 171 | Total Loss: 4.161397457122803 | KNN Loss: 4.142932891845703 | CLS Loss: 0.018464647233486176
Epoch 62 / 200 | iteration 120 / 171 | Total Loss: 4.121703624725342 | KNN Loss: 4.11791467666626 | CLS Loss: 0.0037890055682510138
Epoch 62 / 200 | iteration 130 / 171 | Total Loss: 4.186787128448486 | KNN Loss: 

Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 4.148693084716797 | KNN Loss: 4.131814956665039 | CLS Loss: 0.016878211870789528
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 4.119271278381348 | KNN Loss: 4.104356288909912 | CLS Loss: 0.014915036037564278
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 4.254508018493652 | KNN Loss: 4.225152492523193 | CLS Loss: 0.029355740174651146
Epoch 65 / 200 | iteration 160 / 171 | Total Loss: 4.1954264640808105 | KNN Loss: 4.18572473526001 | CLS Loss: 0.009701497852802277
Epoch 65 / 200 | iteration 170 / 171 | Total Loss: 4.151735782623291 | KNN Loss: 4.1447434425354 | CLS Loss: 0.006992333568632603
Epoch: 065, Loss: 4.1721, Train: 0.9947, Valid: 0.9851, Best: 0.9871
Epoch 66 / 200 | iteration 0 / 171 | Total Loss: 4.20805025100708 | KNN Loss: 4.175268650054932 | CLS Loss: 0.032781705260276794
Epoch 66 / 200 | iteration 10 / 171 | Total Loss: 4.166605472564697 | KNN Loss: 4.160793304443359 | CLS Loss: 0.005812346935272217
Epoch 66 / 20

Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 4.1215715408325195 | KNN Loss: 4.08599853515625 | CLS Loss: 0.03557300195097923
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 4.150624752044678 | KNN Loss: 4.1423773765563965 | CLS Loss: 0.008247347548604012
Epoch 69 / 200 | iteration 40 / 171 | Total Loss: 4.166671276092529 | KNN Loss: 4.147459506988525 | CLS Loss: 0.019211871549487114
Epoch 69 / 200 | iteration 50 / 171 | Total Loss: 4.177854061126709 | KNN Loss: 4.1699442863464355 | CLS Loss: 0.007909662090241909
Epoch 69 / 200 | iteration 60 / 171 | Total Loss: 4.138155937194824 | KNN Loss: 4.1307759284973145 | CLS Loss: 0.0073798056691884995
Epoch 69 / 200 | iteration 70 / 171 | Total Loss: 4.173851490020752 | KNN Loss: 4.156491279602051 | CLS Loss: 0.017360180616378784
Epoch 69 / 200 | iteration 80 / 171 | Total Loss: 4.178874969482422 | KNN Loss: 4.16314697265625 | CLS Loss: 0.015727803111076355
Epoch 69 / 200 | iteration 90 / 171 | Total Loss: 4.151392459869385 | KNN Loss: 4

Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 4.145737648010254 | KNN Loss: 4.141435623168945 | CLS Loss: 0.004302092362195253
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 4.175812244415283 | KNN Loss: 4.157259464263916 | CLS Loss: 0.018552837893366814
Epoch 72 / 200 | iteration 110 / 171 | Total Loss: 4.169170379638672 | KNN Loss: 4.148924350738525 | CLS Loss: 0.020245933905243874
Epoch 72 / 200 | iteration 120 / 171 | Total Loss: 4.164619445800781 | KNN Loss: 4.154659748077393 | CLS Loss: 0.009959859773516655
Epoch 72 / 200 | iteration 130 / 171 | Total Loss: 4.264278411865234 | KNN Loss: 4.203948497772217 | CLS Loss: 0.06032988801598549
Epoch 72 / 200 | iteration 140 / 171 | Total Loss: 4.160919666290283 | KNN Loss: 4.14577579498291 | CLS Loss: 0.015143807046115398
Epoch 72 / 200 | iteration 150 / 171 | Total Loss: 4.183973789215088 | KNN Loss: 4.153616428375244 | CLS Loss: 0.030357303097844124
Epoch 72 / 200 | iteration 160 / 171 | Total Loss: 4.164931297302246 | KNN Loss

Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 4.222021102905273 | KNN Loss: 4.214169979095459 | CLS Loss: 0.00785116944462061
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 4.131816864013672 | KNN Loss: 4.116368770599365 | CLS Loss: 0.015447967685759068
Epoch: 075, Loss: 4.1815, Train: 0.9948, Valid: 0.9859, Best: 0.9871
Epoch 76 / 200 | iteration 0 / 171 | Total Loss: 4.248978614807129 | KNN Loss: 4.218338966369629 | CLS Loss: 0.030639667063951492
Epoch 76 / 200 | iteration 10 / 171 | Total Loss: 4.192474365234375 | KNN Loss: 4.165014743804932 | CLS Loss: 0.027459723874926567
Epoch 76 / 200 | iteration 20 / 171 | Total Loss: 4.174904823303223 | KNN Loss: 4.144264221191406 | CLS Loss: 0.030640719458460808
Epoch 76 / 200 | iteration 30 / 171 | Total Loss: 4.182766437530518 | KNN Loss: 4.168643951416016 | CLS Loss: 0.014122456312179565
Epoch 76 / 200 | iteration 40 / 171 | Total Loss: 4.168797016143799 | KNN Loss: 4.156519412994385 | CLS Loss: 0.012277767062187195
Epoch 76 / 200

In [8]:
# Accuracy of the current model on the test loader (displayed as the cell output).
test(model, test_data_iter, device)

tensor(0.9884, device='cuda:0')

In [9]:
# Plot the per-epoch mean training loss collected in `losses`.
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
# Plot the per-epoch accuracy histories; the "validation" curve comes from
# `val_accs`, which the training loop fills from the test loader.
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [12]:
# Gather, for every test batch, the raw samples, their labels and the model's
# flattened intermediate features (`projections`, clustered in later cells).
# Chunks are collected in lists and concatenated once; calling torch.cat inside
# the loop re-copies the ever-growing tensors on every batch.
sample_chunks = []
label_chunks = []
projection_chunks = []

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        sample_chunks.append(x)
        label_chunks.append(y)
        x = x.to(device)
        _, interm = model(x, True)
        projection_chunks.append(interm.detach().cpu().flatten(1))

# Each concatenation starts from an empty float tensor, exactly like the
# incremental version did, so the resulting dtypes (type promotion) and the
# empty-loader result are unchanged.
test_samples = torch.cat([torch.tensor([]), *sample_chunks])
labels = torch.cat([torch.tensor([]), *label_chunks])
projections = torch.cat([torch.tensor([]), *projection_chunks])

  0%|          | 0/43 [00:00<?, ?it/s]

In [13]:
# Full pairwise distance matrix between all projected test samples
# (pairwise_distances is called with its default metric).
distances = pairwise_distances(projections)
# distances = np.triu(distances)
# Flattened copy for the histogram. Because the full symmetric matrix is used,
# every pair appears twice.
distances_f = distances.flatten()

# Heat-map of the distance matrix.
plt.matshow(distances)
plt.colorbar()
# Histogram of the strictly positive distances (drops the zero diagonal and
# any exact duplicates).
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
# Cluster the projections with DBSCAN; the following cells treat label -1 as
# "not assigned to any cluster" and filter those samples out.
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [15]:
# The mean of the boolean mask is the share of samples that DBSCAN assigned to
# a cluster (label != -1). This is a fraction in [0, 1], not a count.
print(f"Fraction of inliers: {np.mean(clusters != -1)}")

Number of inliers: 0.8989539079987209


In [16]:
# 2-D embedding plot of the clustered projections (samples labelled -1 are
# excluded), coloured by cluster id. Judging by the `library` argument this
# uses Multicore-TSNE — see utils.MatplotlibUtils.reduce_dims_and_plot for the
# exact behaviour of each option.
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [17]:
# Pair each flattened raw test sample with its DBSCAN cluster id (the
# "self-label"), skipping samples labelled -1.
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
# NOTE: re-assigns the notebook-level `batch_size` (same value as the config cell).
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [18]:
def prune_node(node_weights, factor=1):
    """Zero, in place, the weights lying close to the mean of ``node_weights``.

    Every entry strictly inside ``(mean - factor * std, mean + factor * std)``
    is set to 0, where mean and std are NumPy's mean and (population) standard
    deviation of the tensor values. The same tensor object is returned.
    """
    values = node_weights.cpu().detach().numpy()
    center = np.mean(values)
    half_width = np.std(values) * factor
    lower, upper = center - half_width, center + half_width
    near_mean = (lower < node_weights) & (node_weights < upper)
    node_weights[near_mean] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    """Zero, in place, all but the ``keep`` largest-magnitude weights.

    Args:
        node_weights: 1-D weight tensor; it is modified in place.
        keep: number of entries (by absolute value) to leave untouched.
            Values of 0 or less prune everything; values at or above the
            tensor length prune nothing.

    Returns:
        The same ``node_weights`` tensor object.
    """
    w = node_weights.cpu().detach().numpy()
    # Count the entries to drop explicitly. Slicing with `[:-keep]` selects
    # nothing when keep == 0 (because -0 == 0), which silently skipped pruning.
    n_drop = min(max(len(w) - keep, 0), len(w))
    # argsort is ascending, so the first n_drop indices are the smallest magnitudes.
    throw_idx = np.argsort(abs(w))[:n_drop]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    """Prune the inner-node weights of ``tree_`` in place, row by row.

    Each row of ``tree_.inner_nodes.weight`` is passed to ``prune_node_keep``.
    Despite its name, ``factor`` is forwarded as that function's ``keep``
    argument, i.e. the number of weights left non-zero per node.

    The work is done on a detached copy under ``torch.no_grad()`` so that the
    in-place row edits are not recorded by autograd; the pruned copy is then
    written back into the parameter. Returns ``None``.
    """
    with torch.no_grad():
        new_weights = tree_.inner_nodes.weight.detach().clone()
        for i in range(new_weights.shape[0]):
            # prune_node_keep edits the row view in place, which updates
            # `new_weights` directly — no copy-back of the row is needed.
            prune_node_keep(new_weights[i, :], factor)
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    """Return the mean, over the rows of 2-D tensor `x`, of the fraction of
    entries in the row that are exactly zero."""

    def _zero_fraction(row):
        # The 0-"norm" counts the non-zero entries of the row.
        return (len(row) - torch.norm(row, 0).item()) / len(row)

    return np.mean([_zero_fraction(x[i, :]) for i in range(x.shape[0])])

def compute_regularization_by_level(tree):
    """Sum the L2 norms of the rows of `tree.inner_nodes.weight`, scaling the
    norm of row `i` by 2 ** -floor(log2(i + 1)).

    Returns:
        The accumulated penalty (the int 0 if there are no rows).
    """
    weights = tree.inner_nodes.weight
    penalty = 0
    for node_idx in range(weights.shape[0]):
        # For a positive integer n, n.bit_length() - 1 == floor(log2(n)).
        level = (node_idx + 1).bit_length() - 1
        penalty += 2.0 ** (-level) * torch.norm(weights[node_idx].view(-1), 2)
    return penalty

def show_sparseness(tree):
    """Print the average and per-layer weight sparseness of the inner nodes.

    Row `i` of `tree.inner_nodes.weight` is assigned to layer
    floor(log2(i + 1)). The sparseness of a row is the fraction of its
    entries that are exactly zero.

    Args:
        tree: model exposing `inner_nodes.weight` as a 2-D tensor.

    Returns:
        The mean row sparseness over all inner nodes.
    """
    weights = tree.inner_nodes.weight
    # The 0-"norm" counts the non-zero entries of each row.
    node_sparsity = [
        (len(weights[i, :]) - torch.norm(weights[i, :], 0).item()) / len(weights[i, :])
        for i in range(weights.shape[0])
    ]
    avg_sp = np.mean(node_sparsity)
    print(f"Average sparseness: {avg_sp}")

    # Group by layer first and print every group afterwards, so the deepest
    # layer is reported as well (flushing only on a layer change skips it).
    by_layer = {}
    for i, sp in enumerate(node_sparsity):
        cur_layer = int(np.floor(np.log2(i + 1)))
        by_layer.setdefault(cur_layer, []).append(sp)
    for layer, sps in by_layer.items():
        print(f"layer {layer}: {np.mean(sps)}")

    return avg_sp

## Training configurations

In [19]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    """Train `model` for one pass over `loader`.

    Relies on these notebook-level names being defined before the call:
    `criterion`, `optimizer`, `sparsity_lamda`, `start_time` and
    `compute_regularization_by_level`.

    Args:
        model: network whose forward pass returns `(output, penalty)`.
        loader: iterable of `(data, target)` batches.
        device: device the batches are moved to.
        log_interval: a status line is printed when
            `batch_idx % log_interval == 0`.
        losses: list; the loss of every batch is appended.
        accs: list; the accuracy of every batch is appended.
        epoch: epoch number, used only in the status line.
        iteration: number of batches processed before this call.

    Returns:
        The updated iteration count.
    """
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)
        labels = target.view(-1)

        # Use the `model` argument, not whatever global `tree` happens to be.
        output, penalty = model(data)

        # Classification loss plus the penalty returned by the model.
        loss_tree = criterion(output, labels) + penalty

        # NOTE(review): the level-weighted regularisation is reported but has
        # never been part of the optimised loss (`loss` is just `loss_tree`).
        # If it is meant to be, add it to `loss` and compute it with gradients
        # for every batch. Until then it is evaluated only for logging,
        # before the optimiser step, and without building an autograd graph.
        should_log = batch_idx % log_interval == 0
        if should_log:
            with torch.no_grad():
                regularization = sparsity_lamda * compute_regularization_by_level(model)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.argmax(dim=1)
        correct = pred.eq(labels).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if should_log:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")

    return iteration


In [20]:
# Hyper-parameters for training the soft decision tree.
lr = 5e-3  # learning rate passed to the Adam optimizer
weight_decay = 5e-4  # weight decay passed to the Adam optimizer
sparsity_lamda = 2e-3  # scale applied to compute_regularization_by_level(...) in do_epoch
epochs = 400  # number of iterations of the training loop
log_interval = 100  # do_epoch prints a status line when batch_idx % log_interval == 0
use_cuda = device != 'cpu'  # forwarded to SDT(use_cuda=...)

In [21]:
# One output class per DBSCAN cluster. The noise label (-1) is removed
# explicitly rather than by subtracting 1, which is off by one whenever
# DBSCAN produces no noise points.
n_clusters = len(set(clusters) - {-1})
tree = SDT(input_dim=test_samples.shape[2], output_dim=n_clusters, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
# Move the model to its device before building the optimizer, as the
# torch.optim documentation recommends.
tree = tree.to(device)
optimizer = torch.optim.Adam(tree.parameters(),
                             lr=lr,
                             weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()

In [22]:
# Histories filled during training: per-batch loss, per-batch accuracy, and
# the average sparseness recorded at the start of every epoch.
losses, accs, sparsity = [], [], []

In [23]:
# Each epoch: report and record the current sparseness, train for one pass,
# then prune. `factor=3` is forwarded by prune_tree to prune_node_keep as
# the number of weights kept per node.
iteration = 0
start_time = time.time()  # read by do_epoch for the sec/iter figure
for epoch in range(epochs):
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)

    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)

    # Pruning runs after every epoch.
    prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
Epoch: 00 | Batch: 000 / 039 | Total loss: 2.083 | Reg loss: 0.009 | Tree loss: 2.083 | Accuracy: 0.005859 | 0.245 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 01 | Batch: 000 / 039 | Total loss: 1.779 | Reg loss: 0.005 | Tree loss: 1.779 | Accuracy: 0.894531 | 0.219 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 02 | Batch: 000 / 039 | Total loss: 1.589 | Reg loss: 0.007 | Tree loss: 1.589 | Accuracy: 0.896484 | 0.22 sec/iter
Average sparseness: 0.9840425531914894
layer

Epoch: 23 | Batch: 000 / 039 | Total loss: 0.460 | Reg loss: 0.016 | Tree loss: 0.460 | Accuracy: 0.900391 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 24 | Batch: 000 / 039 | Total loss: 0.407 | Reg loss: 0.016 | Tree loss: 0.407 | Accuracy: 0.917969 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 25 | Batch: 000 / 039 | Total loss: 0.416 | Reg loss: 0.016 | Tree loss: 0.416 | Accuracy: 0.912109 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 46 | Batch: 000 / 039 | Total loss: 0.368 | Reg loss: 0.015 | Tree loss: 0.368 | Accuracy: 0.914062 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 47 | Batch: 000 / 039 | Total loss: 0.345 | Reg loss: 0.015 | Tree loss: 0.345 | Accuracy: 0.921875 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 48 | Batch: 000 / 039 | Total loss: 0.424 | Reg loss: 0.015 | Tree loss: 0.424 | Accuracy: 0.882812 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 69 | Batch: 000 / 039 | Total loss: 0.344 | Reg loss: 0.015 | Tree loss: 0.344 | Accuracy: 0.947266 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 70 | Batch: 000 / 039 | Total loss: 0.402 | Reg loss: 0.015 | Tree loss: 0.402 | Accuracy: 0.925781 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 71 | Batch: 000 / 039 | Total loss: 0.358 | Reg loss: 0.015 | Tree loss: 0.358 | Accuracy: 0.929688 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Epoch: 92 | Batch: 000 / 039 | Total loss: 0.340 | Reg loss: 0.015 | Tree loss: 0.340 | Accuracy: 0.937500 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 93 | Batch: 000 / 039 | Total loss: 0.342 | Reg loss: 0.015 | Tree loss: 0.342 | Accuracy: 0.945312 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 94 | Batch: 000 / 039 | Total loss: 0.345 | Reg loss: 0.015 | Tree loss: 0.345 | Accuracy: 0.935547 | 0.217 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 116 | Batch: 000 / 039 | Total loss: 0.347 | Reg loss: 0.015 | Tree loss: 0.347 | Accuracy: 0.931641 | 0.215 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 117 | Batch: 000 / 039 | Total loss: 0.430 | Reg loss: 0.015 | Tree loss: 0.430 | Accuracy: 0.919922 | 0.215 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 118 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 139 | Batch: 000 / 039 | Total loss: 0.372 | Reg loss: 0.015 | Tree loss: 0.372 | Accuracy: 0.923828 | 0.212 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 140 | Batch: 000 / 039 | Total loss: 0.361 | Reg loss: 0.015 | Tree loss: 0.361 | Accuracy: 0.941406 | 0.212 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 141 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 162 | Batch: 000 / 039 | Total loss: 0.344 | Reg loss: 0.015 | Tree loss: 0.344 | Accuracy: 0.933594 | 0.209 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 163 | Batch: 000 / 039 | Total loss: 0.338 | Reg loss: 0.015 | Tree loss: 0.338 | Accuracy: 0.945312 | 0.209 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 164 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 185 | Batch: 000 / 039 | Total loss: 0.371 | Reg loss: 0.015 | Tree loss: 0.371 | Accuracy: 0.914062 | 0.207 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 186 | Batch: 000 / 039 | Total loss: 0.345 | Reg loss: 0.015 | Tree loss: 0.345 | Accuracy: 0.935547 | 0.207 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 187 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 208 | Batch: 000 / 039 | Total loss: 0.376 | Reg loss: 0.015 | Tree loss: 0.376 | Accuracy: 0.921875 | 0.206 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 209 | Batch: 000 / 039 | Total loss: 0.386 | Reg loss: 0.014 | Tree loss: 0.386 | Accuracy: 0.917969 | 0.206 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 210 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 231 | Batch: 000 / 039 | Total loss: 0.340 | Reg loss: 0.014 | Tree loss: 0.340 | Accuracy: 0.929688 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 232 | Batch: 000 / 039 | Total loss: 0.391 | Reg loss: 0.014 | Tree loss: 0.391 | Accuracy: 0.914062 | 0.205 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 233 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 254 | Batch: 000 / 039 | Total loss: 0.419 | Reg loss: 0.014 | Tree loss: 0.419 | Accuracy: 0.894531 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 255 | Batch: 000 / 039 | Total loss: 0.339 | Reg loss: 0.014 | Tree loss: 0.339 | Accuracy: 0.935547 | 0.204 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 256 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 277 | Batch: 000 / 039 | Total loss: 0.328 | Reg loss: 0.014 | Tree loss: 0.328 | Accuracy: 0.939453 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 278 | Batch: 000 / 039 | Total loss: 0.331 | Reg loss: 0.014 | Tree loss: 0.331 | Accuracy: 0.945312 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 279 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 300 | Batch: 000 / 039 | Total loss: 0.358 | Reg loss: 0.014 | Tree loss: 0.358 | Accuracy: 0.925781 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 301 | Batch: 000 / 039 | Total loss: 0.378 | Reg loss: 0.014 | Tree loss: 0.378 | Accuracy: 0.914062 | 0.203 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 302 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 323 | Batch: 000 / 039 | Total loss: 0.286 | Reg loss: 0.014 | Tree loss: 0.286 | Accuracy: 0.947266 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 324 | Batch: 000 / 039 | Total loss: 0.345 | Reg loss: 0.014 | Tree loss: 0.345 | Accuracy: 0.933594 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 325 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 346 | Batch: 000 / 039 | Total loss: 0.383 | Reg loss: 0.014 | Tree loss: 0.383 | Accuracy: 0.910156 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 347 | Batch: 000 / 039 | Total loss: 0.398 | Reg loss: 0.014 | Tree loss: 0.398 | Accuracy: 0.906250 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 348 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 369 | Batch: 000 / 039 | Total loss: 0.361 | Reg loss: 0.014 | Tree loss: 0.361 | Accuracy: 0.925781 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 370 | Batch: 000 / 039 | Total loss: 0.361 | Reg loss: 0.014 | Tree loss: 0.361 | Accuracy: 0.925781 | 0.202 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 371 | Batch: 000 / 039 | Total loss: 0

Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 392 | Batch: 000 / 039 | Total loss: 0.350 | Reg loss: 0.014 | Tree loss: 0.350 | Accuracy: 0.927734 | 0.201 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 393 | Batch: 000 / 039 | Total loss: 0.390 | Reg loss: 0.014 | Tree loss: 0.390 | Accuracy: 0.908203 | 0.201 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
Epoch: 394 | Batch: 000 / 039 | Total loss: 0

In [24]:
# Per-batch training accuracy over all iterations.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_ylabel("Accuracy")
ax.set_xlabel('Iteration')
ax.plot(accs, label='Accuracy vs iteration')
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [25]:
# Figure 1: per-batch training loss over all iterations.
loss_fig, loss_ax = plt.subplots()
loss_ax.set_ylabel("Loss")
loss_ax.set_xlabel('Iteration')
loss_ax.plot(losses, label='Loss vs iteration')
plt.show()

# Figure 2: histogram (log-scaled counts) of all inner-node weights, with
# red lines one standard deviation either side of the mean.
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
weights_std = np.std(weights)
weights_mean = np.mean(weights)

hist_fig, hist_ax = plt.subplots()
hist_ax.hist(weights, bins=500)
for boundary in (weights_mean + weights_std, weights_mean - weights_std):
    hist_ax.axvline(boundary, color='r')
hist_ax.set_title(f"Mean: {weights_mean}   |   STD: {weights_std}")
hist_ax.set_yscale("log")
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Tree Visualization

In [26]:
# Open a new figure before calling visualize().
# NOTE(review): tree.visualize() presumably draws into the current figure — confirm in SDT.
plt.figure(figsize=(15, 10), dpi=80)
# visualize() returns a pair; `root` is queried by later cells (get_leaves,
# clear_leaves_samples, accumulate_samples).
avg_height, root = tree.visualize()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Average height: 4.888888888888889


# Extract Rules

# Accumulate samples in the leaves

In [27]:
# Each leaf returned by root.get_leaves() is counted as one pattern.
pattern_count = len(root.get_leaves())
print(f"Number of patterns: {pattern_count}")

Number of patterns: 9


In [28]:
# Strategy name passed as the second argument to root.accumulate_samples(...).
method = 'greedy'

In [29]:
# Start from empty leaves, then feed every batch of the tree dataset through
# root.accumulate_samples. Labels are not needed here and no gradients are
# tracked.
root.clear_leaves_samples()

with torch.no_grad():
    for data, _target in tree_loader:
        root.accumulate_samples(data, method)



# Tighten boundaries

In [30]:
# One attribute name per input feature: T_0, T_1, ...
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0  # not read anywhere in this cell
comprehensibilities = []

# For every leaf: reset its path, tighten it with the accumulated samples,
# and total the comprehensibility of its path conditions.
for pattern_number, leaf in enumerate(leaves, start=1):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_number} ==============")
    leaf_total = sum(cond.comprehensibility for cond in conds)
    comprehensibilities.append(leaf_total)

# Summary statistics over all leaves.
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")

372
11490
7817
Average comprehensibility: 26.666666666666668
std comprehensibility: 13.063945294843617
var comprehensibility: 170.66666666666666
minimum comprehensibility: 6
maximum comprehensibility: 44


  return np.log(1 / (1 - x))
