In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

In [2]:
k = 64
tree_depth = 12
batch_size = 512
device = 'cpu'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [3]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [4]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [5]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [6]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [7]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

  return torch._C._cuda_getDeviceCount() > 0


Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.531494617462158 | KNN Loss: 5.871538162231445 | CLS Loss: 1.6599565744400024
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.513195037841797 | KNN Loss: 4.814934253692627 | CLS Loss: 0.6982605457305908
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 5.213438987731934 | KNN Loss: 4.526139736175537 | CLS Loss: 0.6872991323471069
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 5.109689712524414 | KNN Loss: 4.477982997894287 | CLS Loss: 0.631706953048706
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 4.985007286071777 | KNN Loss: 4.419722557067871 | CLS Loss: 0.5652847290039062
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 5.026525974273682 | KNN Loss: 4.453143119812012 | CLS Loss: 0.5733826756477356
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.984081745147705 | KNN Loss: 4.435934066772461 | CLS Loss: 0.5481476187705994
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.842815399169922 | KNN Loss: 4.402174472808838 | CLS Los

Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 4.3926873207092285 | KNN Loss: 4.253031253814697 | CLS Loss: 0.1396559327840805
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 4.483060359954834 | KNN Loss: 4.3383259773254395 | CLS Loss: 0.14473433792591095
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 4.440836429595947 | KNN Loss: 4.299612045288086 | CLS Loss: 0.14122438430786133
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 4.440685749053955 | KNN Loss: 4.326118469238281 | CLS Loss: 0.11456704884767532
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 4.37481164932251 | KNN Loss: 4.264281272888184 | CLS Loss: 0.11053016036748886
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 4.391666412353516 | KNN Loss: 4.29433012008667 | CLS Loss: 0.0973360538482666
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 4.4143781661987305 | KNN Loss: 4.301726341247559 | CLS Loss: 0.11265181005001068
Epoch 4 / 200 | iteration 160 / 171 | Total Loss: 4.412785053253174 | KNN Loss: 4.2735900878

Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 4.310531139373779 | KNN Loss: 4.249791622161865 | CLS Loss: 0.06073939800262451
Epoch: 007, Loss: 4.3683, Train: 0.9783, Valid: 0.9751, Best: 0.9751
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 4.277074813842773 | KNN Loss: 4.203298568725586 | CLS Loss: 0.07377618551254272
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 4.31929874420166 | KNN Loss: 4.248308181762695 | CLS Loss: 0.07099057734012604
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 4.40557861328125 | KNN Loss: 4.308181285858154 | CLS Loss: 0.09739750623703003
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 4.342289924621582 | KNN Loss: 4.222406387329102 | CLS Loss: 0.11988357454538345
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 4.311273097991943 | KNN Loss: 4.241094589233398 | CLS Loss: 0.07017853856086731
Epoch 8 / 200 | iteration 50 / 171 | Total Loss: 4.384051322937012 | KNN Loss: 4.28169059753418 | CLS Loss: 0.10236073285341263
Epoch 8 / 200 | iteration 60 / 

Epoch 11 / 200 | iteration 60 / 171 | Total Loss: 4.304662704467773 | KNN Loss: 4.250932216644287 | CLS Loss: 0.05373039469122887
Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 4.341189384460449 | KNN Loss: 4.271505832672119 | CLS Loss: 0.06968340277671814
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 4.320568084716797 | KNN Loss: 4.246573448181152 | CLS Loss: 0.07399451732635498
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 4.314630031585693 | KNN Loss: 4.247573375701904 | CLS Loss: 0.06705685704946518
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 4.307745456695557 | KNN Loss: 4.227632522583008 | CLS Loss: 0.08011271804571152
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 4.377388000488281 | KNN Loss: 4.2904052734375 | CLS Loss: 0.08698272705078125
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 4.331469535827637 | KNN Loss: 4.27426815032959 | CLS Loss: 0.05720130726695061
Epoch 11 / 200 | iteration 130 / 171 | Total Loss: 4.290808200836182 | KNN Loss: 4.2009387

Epoch 14 / 200 | iteration 140 / 171 | Total Loss: 4.30272912979126 | KNN Loss: 4.275683879852295 | CLS Loss: 0.027045099064707756
Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 4.2770795822143555 | KNN Loss: 4.199060916900635 | CLS Loss: 0.07801886647939682
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 4.277829647064209 | KNN Loss: 4.22072172164917 | CLS Loss: 0.057107966393232346
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 4.290104389190674 | KNN Loss: 4.221899509429932 | CLS Loss: 0.06820482760667801
Epoch: 014, Loss: 4.2912, Train: 0.9864, Valid: 0.9821, Best: 0.9821
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 4.274198055267334 | KNN Loss: 4.247531414031982 | CLS Loss: 0.026666713878512383
Epoch 15 / 200 | iteration 10 / 171 | Total Loss: 4.339036464691162 | KNN Loss: 4.263019561767578 | CLS Loss: 0.07601696252822876
Epoch 15 / 200 | iteration 20 / 171 | Total Loss: 4.276369571685791 | KNN Loss: 4.238150596618652 | CLS Loss: 0.038218822330236435
Epoch 15 / 200 

Epoch 18 / 200 | iteration 30 / 171 | Total Loss: 4.300478458404541 | KNN Loss: 4.260936260223389 | CLS Loss: 0.03954201936721802
Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 4.229725360870361 | KNN Loss: 4.186895370483398 | CLS Loss: 0.04282982274889946
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 4.31614351272583 | KNN Loss: 4.250302314758301 | CLS Loss: 0.06584139913320541
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 4.270289421081543 | KNN Loss: 4.2217583656311035 | CLS Loss: 0.04853105917572975
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 4.284079074859619 | KNN Loss: 4.241131782531738 | CLS Loss: 0.042947061359882355
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 4.313508033752441 | KNN Loss: 4.248310089111328 | CLS Loss: 0.06519799679517746
Epoch 18 / 200 | iteration 90 / 171 | Total Loss: 4.228884696960449 | KNN Loss: 4.191882133483887 | CLS Loss: 0.03700267896056175
Epoch 18 / 200 | iteration 100 / 171 | Total Loss: 4.2878546714782715 | KNN Loss: 4.22772

Epoch 21 / 200 | iteration 110 / 171 | Total Loss: 4.242568492889404 | KNN Loss: 4.19084358215332 | CLS Loss: 0.05172489210963249
Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 4.26951789855957 | KNN Loss: 4.204338073730469 | CLS Loss: 0.06517962366342545
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 4.252274513244629 | KNN Loss: 4.210559368133545 | CLS Loss: 0.041715025901794434
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 4.244344234466553 | KNN Loss: 4.211349010467529 | CLS Loss: 0.03299541398882866
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 4.223918914794922 | KNN Loss: 4.208064556121826 | CLS Loss: 0.01585458032786846
Epoch 21 / 200 | iteration 160 / 171 | Total Loss: 4.237831115722656 | KNN Loss: 4.218386650085449 | CLS Loss: 0.019444482401013374
Epoch 21 / 200 | iteration 170 / 171 | Total Loss: 4.250600337982178 | KNN Loss: 4.207740783691406 | CLS Loss: 0.04285946115851402
Epoch: 021, Loss: 4.2551, Train: 0.9901, Valid: 0.9843, Best: 0.9843
Epoch 22 / 200

Epoch: 024, Loss: 4.2406, Train: 0.9907, Valid: 0.9845, Best: 0.9845
Epoch 25 / 200 | iteration 0 / 171 | Total Loss: 4.207675933837891 | KNN Loss: 4.183061122894287 | CLS Loss: 0.024614786729216576
Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 4.246267795562744 | KNN Loss: 4.224581718444824 | CLS Loss: 0.021685931831598282
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 4.279109477996826 | KNN Loss: 4.247529983520508 | CLS Loss: 0.03157961741089821
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 4.267364501953125 | KNN Loss: 4.233802318572998 | CLS Loss: 0.03356197103857994
Epoch 25 / 200 | iteration 40 / 171 | Total Loss: 4.251857280731201 | KNN Loss: 4.219854831695557 | CLS Loss: 0.03200257942080498
Epoch 25 / 200 | iteration 50 / 171 | Total Loss: 4.244296073913574 | KNN Loss: 4.202948093414307 | CLS Loss: 0.041347846388816833
Epoch 25 / 200 | iteration 60 / 171 | Total Loss: 4.264890193939209 | KNN Loss: 4.198514938354492 | CLS Loss: 0.06637528538703918
Epoch 25 / 200 | it

Epoch 28 / 200 | iteration 70 / 171 | Total Loss: 4.2436323165893555 | KNN Loss: 4.2053399085998535 | CLS Loss: 0.03829244151711464
Epoch 28 / 200 | iteration 80 / 171 | Total Loss: 4.230072498321533 | KNN Loss: 4.197988986968994 | CLS Loss: 0.03208357095718384
Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 4.234865665435791 | KNN Loss: 4.207112789154053 | CLS Loss: 0.02775268629193306
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 4.284020900726318 | KNN Loss: 4.208919525146484 | CLS Loss: 0.0751013234257698
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 4.203401565551758 | KNN Loss: 4.1948933601379395 | CLS Loss: 0.008508346974849701
Epoch 28 / 200 | iteration 120 / 171 | Total Loss: 4.279885292053223 | KNN Loss: 4.220667362213135 | CLS Loss: 0.05921813100576401
Epoch 28 / 200 | iteration 130 / 171 | Total Loss: 4.213982582092285 | KNN Loss: 4.182875633239746 | CLS Loss: 0.031106863170862198
Epoch 28 / 200 | iteration 140 / 171 | Total Loss: 4.274678707122803 | KNN Loss: 4

Epoch 31 / 200 | iteration 140 / 171 | Total Loss: 4.243635177612305 | KNN Loss: 4.2295145988464355 | CLS Loss: 0.01412065140902996
Epoch 31 / 200 | iteration 150 / 171 | Total Loss: 4.250445365905762 | KNN Loss: 4.179553985595703 | CLS Loss: 0.0708916187286377
Epoch 31 / 200 | iteration 160 / 171 | Total Loss: 4.22918176651001 | KNN Loss: 4.187734127044678 | CLS Loss: 0.04144773259758949
Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 4.221603870391846 | KNN Loss: 4.197984218597412 | CLS Loss: 0.023619798943400383
Epoch: 031, Loss: 4.2214, Train: 0.9924, Valid: 0.9854, Best: 0.9854
Epoch 32 / 200 | iteration 0 / 171 | Total Loss: 4.279277801513672 | KNN Loss: 4.232723236083984 | CLS Loss: 0.046554360538721085
Epoch 32 / 200 | iteration 10 / 171 | Total Loss: 4.216104507446289 | KNN Loss: 4.1956634521484375 | CLS Loss: 0.02044091187417507
Epoch 32 / 200 | iteration 20 / 171 | Total Loss: 4.202889919281006 | KNN Loss: 4.1732988357543945 | CLS Loss: 0.029590968042612076
Epoch 32 / 200

Epoch 35 / 200 | iteration 30 / 171 | Total Loss: 4.223684310913086 | KNN Loss: 4.208004474639893 | CLS Loss: 0.015679875388741493
Epoch 35 / 200 | iteration 40 / 171 | Total Loss: 4.1940083503723145 | KNN Loss: 4.175229549407959 | CLS Loss: 0.018778743222355843
Epoch 35 / 200 | iteration 50 / 171 | Total Loss: 4.2106170654296875 | KNN Loss: 4.187617778778076 | CLS Loss: 0.022999344393610954
Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 4.218364238739014 | KNN Loss: 4.182963848114014 | CLS Loss: 0.0354003831744194
Epoch 35 / 200 | iteration 70 / 171 | Total Loss: 4.194406986236572 | KNN Loss: 4.168439865112305 | CLS Loss: 0.025967281311750412
Epoch 35 / 200 | iteration 80 / 171 | Total Loss: 4.277900218963623 | KNN Loss: 4.234869480133057 | CLS Loss: 0.04303088039159775
Epoch 35 / 200 | iteration 90 / 171 | Total Loss: 4.208922863006592 | KNN Loss: 4.195460796356201 | CLS Loss: 0.013462303206324577
Epoch 35 / 200 | iteration 100 / 171 | Total Loss: 4.239328384399414 | KNN Loss: 4.2

Epoch 38 / 200 | iteration 100 / 171 | Total Loss: 4.246324062347412 | KNN Loss: 4.219326019287109 | CLS Loss: 0.026997875422239304
Epoch 38 / 200 | iteration 110 / 171 | Total Loss: 4.198949813842773 | KNN Loss: 4.184017658233643 | CLS Loss: 0.014931933023035526
Epoch 38 / 200 | iteration 120 / 171 | Total Loss: 4.225663661956787 | KNN Loss: 4.193836212158203 | CLS Loss: 0.031827352941036224
Epoch 38 / 200 | iteration 130 / 171 | Total Loss: 4.20611047744751 | KNN Loss: 4.194192886352539 | CLS Loss: 0.01191764697432518
Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 4.259815216064453 | KNN Loss: 4.221579074859619 | CLS Loss: 0.03823606297373772
Epoch 38 / 200 | iteration 150 / 171 | Total Loss: 4.221337795257568 | KNN Loss: 4.190531253814697 | CLS Loss: 0.03080640733242035
Epoch 38 / 200 | iteration 160 / 171 | Total Loss: 4.226406097412109 | KNN Loss: 4.202661991119385 | CLS Loss: 0.023743966594338417
Epoch 38 / 200 | iteration 170 / 171 | Total Loss: 4.256950378417969 | KNN Loss:

Epoch 41 / 200 | iteration 170 / 171 | Total Loss: 4.2383131980896 | KNN Loss: 4.208118438720703 | CLS Loss: 0.030194832012057304
Epoch: 041, Loss: 4.2181, Train: 0.9940, Valid: 0.9863, Best: 0.9866
Epoch 42 / 200 | iteration 0 / 171 | Total Loss: 4.219261646270752 | KNN Loss: 4.211283206939697 | CLS Loss: 0.007978660054504871
Epoch 42 / 200 | iteration 10 / 171 | Total Loss: 4.20481014251709 | KNN Loss: 4.200500011444092 | CLS Loss: 0.004310272168368101
Epoch 42 / 200 | iteration 20 / 171 | Total Loss: 4.217003345489502 | KNN Loss: 4.188265800476074 | CLS Loss: 0.028737446293234825
Epoch 42 / 200 | iteration 30 / 171 | Total Loss: 4.198896408081055 | KNN Loss: 4.182306289672852 | CLS Loss: 0.016590293496847153
Epoch 42 / 200 | iteration 40 / 171 | Total Loss: 4.225890636444092 | KNN Loss: 4.198854446411133 | CLS Loss: 0.027036169543862343
Epoch 42 / 200 | iteration 50 / 171 | Total Loss: 4.241453647613525 | KNN Loss: 4.208463668823242 | CLS Loss: 0.03299001604318619
Epoch 42 / 200 | i

Epoch 45 / 200 | iteration 60 / 171 | Total Loss: 4.225077152252197 | KNN Loss: 4.1896491050720215 | CLS Loss: 0.03542814403772354
Epoch 45 / 200 | iteration 70 / 171 | Total Loss: 4.214105606079102 | KNN Loss: 4.202685832977295 | CLS Loss: 0.011419560760259628
Epoch 45 / 200 | iteration 80 / 171 | Total Loss: 4.1733808517456055 | KNN Loss: 4.158309459686279 | CLS Loss: 0.015071537345647812
Epoch 45 / 200 | iteration 90 / 171 | Total Loss: 4.203010559082031 | KNN Loss: 4.181915283203125 | CLS Loss: 0.021095316857099533
Epoch 45 / 200 | iteration 100 / 171 | Total Loss: 4.207612037658691 | KNN Loss: 4.16044807434082 | CLS Loss: 0.04716384410858154
Epoch 45 / 200 | iteration 110 / 171 | Total Loss: 4.259705543518066 | KNN Loss: 4.214901924133301 | CLS Loss: 0.04480363428592682
Epoch 45 / 200 | iteration 120 / 171 | Total Loss: 4.184332370758057 | KNN Loss: 4.16845703125 | CLS Loss: 0.01587517559528351
Epoch 45 / 200 | iteration 130 / 171 | Total Loss: 4.227839946746826 | KNN Loss: 4.2021

Epoch 48 / 200 | iteration 130 / 171 | Total Loss: 4.189859390258789 | KNN Loss: 4.166978359222412 | CLS Loss: 0.022881262004375458
Epoch 48 / 200 | iteration 140 / 171 | Total Loss: 4.185413837432861 | KNN Loss: 4.1628737449646 | CLS Loss: 0.022540142759680748
Epoch 48 / 200 | iteration 150 / 171 | Total Loss: 4.170398235321045 | KNN Loss: 4.161129951477051 | CLS Loss: 0.0092684431001544
Epoch 48 / 200 | iteration 160 / 171 | Total Loss: 4.205132484436035 | KNN Loss: 4.196259021759033 | CLS Loss: 0.008873231709003448
Epoch 48 / 200 | iteration 170 / 171 | Total Loss: 4.196340560913086 | KNN Loss: 4.182790756225586 | CLS Loss: 0.013549677096307278
Epoch: 048, Loss: 4.2077, Train: 0.9942, Valid: 0.9852, Best: 0.9866
Epoch 49 / 200 | iteration 0 / 171 | Total Loss: 4.1939921379089355 | KNN Loss: 4.182600498199463 | CLS Loss: 0.011391740292310715
Epoch 49 / 200 | iteration 10 / 171 | Total Loss: 4.217296600341797 | KNN Loss: 4.188128471374512 | CLS Loss: 0.02916790172457695
Epoch 49 / 200

Epoch 52 / 200 | iteration 20 / 171 | Total Loss: 4.1702799797058105 | KNN Loss: 4.148709297180176 | CLS Loss: 0.02157084457576275
Epoch 52 / 200 | iteration 30 / 171 | Total Loss: 4.214503288269043 | KNN Loss: 4.184453964233398 | CLS Loss: 0.030049467459321022
Epoch 52 / 200 | iteration 40 / 171 | Total Loss: 4.189785003662109 | KNN Loss: 4.168893337249756 | CLS Loss: 0.02089146338403225
Epoch 52 / 200 | iteration 50 / 171 | Total Loss: 4.182406425476074 | KNN Loss: 4.168633937835693 | CLS Loss: 0.013772501610219479
Epoch 52 / 200 | iteration 60 / 171 | Total Loss: 4.188601970672607 | KNN Loss: 4.151299476623535 | CLS Loss: 0.03730225935578346
Epoch 52 / 200 | iteration 70 / 171 | Total Loss: 4.210635662078857 | KNN Loss: 4.193665504455566 | CLS Loss: 0.016970157623291016
Epoch 52 / 200 | iteration 80 / 171 | Total Loss: 4.154101371765137 | KNN Loss: 4.144765853881836 | CLS Loss: 0.009335282258689404
Epoch 52 / 200 | iteration 90 / 171 | Total Loss: 4.1909871101379395 | KNN Loss: 4.18

Epoch 55 / 200 | iteration 90 / 171 | Total Loss: 4.181483745574951 | KNN Loss: 4.147748947143555 | CLS Loss: 0.033735014498233795
Epoch 55 / 200 | iteration 100 / 171 | Total Loss: 4.148003101348877 | KNN Loss: 4.136166572570801 | CLS Loss: 0.0118366414681077
Epoch 55 / 200 | iteration 110 / 171 | Total Loss: 4.198420524597168 | KNN Loss: 4.159153461456299 | CLS Loss: 0.03926724195480347
Epoch 55 / 200 | iteration 120 / 171 | Total Loss: 4.213823318481445 | KNN Loss: 4.202559947967529 | CLS Loss: 0.011263209395110607
Epoch 55 / 200 | iteration 130 / 171 | Total Loss: 4.202659606933594 | KNN Loss: 4.1685872077941895 | CLS Loss: 0.034072376787662506
Epoch 55 / 200 | iteration 140 / 171 | Total Loss: 4.193799018859863 | KNN Loss: 4.180274486541748 | CLS Loss: 0.013524500653147697
Epoch 55 / 200 | iteration 150 / 171 | Total Loss: 4.159256458282471 | KNN Loss: 4.1519622802734375 | CLS Loss: 0.0072944085113704205
Epoch 55 / 200 | iteration 160 / 171 | Total Loss: 4.164046764373779 | KNN Lo

Epoch 58 / 200 | iteration 160 / 171 | Total Loss: 4.215922832489014 | KNN Loss: 4.194922924041748 | CLS Loss: 0.020999914035201073
Epoch 58 / 200 | iteration 170 / 171 | Total Loss: 4.204423427581787 | KNN Loss: 4.199481964111328 | CLS Loss: 0.004941476043313742
Epoch: 058, Loss: 4.1947, Train: 0.9947, Valid: 0.9853, Best: 0.9871
Epoch 59 / 200 | iteration 0 / 171 | Total Loss: 4.221714973449707 | KNN Loss: 4.19355583190918 | CLS Loss: 0.028159206733107567
Epoch 59 / 200 | iteration 10 / 171 | Total Loss: 4.179389476776123 | KNN Loss: 4.159837245941162 | CLS Loss: 0.01955222710967064
Epoch 59 / 200 | iteration 20 / 171 | Total Loss: 4.170622825622559 | KNN Loss: 4.161334037780762 | CLS Loss: 0.009288894012570381
Epoch 59 / 200 | iteration 30 / 171 | Total Loss: 4.214389801025391 | KNN Loss: 4.201254367828369 | CLS Loss: 0.013135419227182865
Epoch 59 / 200 | iteration 40 / 171 | Total Loss: 4.2412190437316895 | KNN Loss: 4.2173380851745605 | CLS Loss: 0.023881033062934875
Epoch 59 / 20

Epoch 62 / 200 | iteration 50 / 171 | Total Loss: 4.165596008300781 | KNN Loss: 4.16331148147583 | CLS Loss: 0.002284666523337364
Epoch 62 / 200 | iteration 60 / 171 | Total Loss: 4.158426761627197 | KNN Loss: 4.157233715057373 | CLS Loss: 0.001193030970171094
Epoch 62 / 200 | iteration 70 / 171 | Total Loss: 4.257723331451416 | KNN Loss: 4.219242572784424 | CLS Loss: 0.0384807325899601
Epoch 62 / 200 | iteration 80 / 171 | Total Loss: 4.183809280395508 | KNN Loss: 4.1766252517700195 | CLS Loss: 0.007184003479778767
Epoch 62 / 200 | iteration 90 / 171 | Total Loss: 4.178555011749268 | KNN Loss: 4.160491943359375 | CLS Loss: 0.01806303672492504
Epoch 62 / 200 | iteration 100 / 171 | Total Loss: 4.188300132751465 | KNN Loss: 4.170668601989746 | CLS Loss: 0.017631743103265762
Epoch 62 / 200 | iteration 110 / 171 | Total Loss: 4.186277866363525 | KNN Loss: 4.169018268585205 | CLS Loss: 0.017259778454899788
Epoch 62 / 200 | iteration 120 / 171 | Total Loss: 4.15166711807251 | KNN Loss: 4.14

Epoch 65 / 200 | iteration 120 / 171 | Total Loss: 4.2054572105407715 | KNN Loss: 4.190098285675049 | CLS Loss: 0.015359107404947281
Epoch 65 / 200 | iteration 130 / 171 | Total Loss: 4.203043460845947 | KNN Loss: 4.165285587310791 | CLS Loss: 0.03775806725025177
Epoch 65 / 200 | iteration 140 / 171 | Total Loss: 4.232099533081055 | KNN Loss: 4.195133686065674 | CLS Loss: 0.03696604445576668
Epoch 65 / 200 | iteration 150 / 171 | Total Loss: 4.164560794830322 | KNN Loss: 4.1569647789001465 | CLS Loss: 0.007595815230160952
Epoch 65 / 200 | iteration 160 / 171 | Total Loss: 4.187994956970215 | KNN Loss: 4.162508010864258 | CLS Loss: 0.025487020611763
Epoch 65 / 200 | iteration 170 / 171 | Total Loss: 4.232361793518066 | KNN Loss: 4.178335666656494 | CLS Loss: 0.05402589216828346
Epoch: 065, Loss: 4.1911, Train: 0.9955, Valid: 0.9865, Best: 0.9871
Epoch 66 / 200 | iteration 0 / 171 | Total Loss: 4.16154670715332 | KNN Loss: 4.155331134796143 | CLS Loss: 0.006215378642082214
Epoch 66 / 200

Epoch 69 / 200 | iteration 10 / 171 | Total Loss: 4.231618881225586 | KNN Loss: 4.22288179397583 | CLS Loss: 0.008737083524465561
Epoch 69 / 200 | iteration 20 / 171 | Total Loss: 4.187717437744141 | KNN Loss: 4.164570331573486 | CLS Loss: 0.02314707450568676
Epoch 69 / 200 | iteration 30 / 171 | Total Loss: 4.168126583099365 | KNN Loss: 4.159274101257324 | CLS Loss: 0.008852286264300346
Epoch 69 / 200 | iteration 40 / 171 | Total Loss: 4.220248222351074 | KNN Loss: 4.179140567779541 | CLS Loss: 0.041107773780822754
Epoch 69 / 200 | iteration 50 / 171 | Total Loss: 4.194516658782959 | KNN Loss: 4.1766133308410645 | CLS Loss: 0.017903273925185204
Epoch 69 / 200 | iteration 60 / 171 | Total Loss: 4.203122615814209 | KNN Loss: 4.177933692932129 | CLS Loss: 0.02518889307975769
Epoch 69 / 200 | iteration 70 / 171 | Total Loss: 4.203919410705566 | KNN Loss: 4.185119152069092 | CLS Loss: 0.018800247460603714
Epoch 69 / 200 | iteration 80 / 171 | Total Loss: 4.175041198730469 | KNN Loss: 4.168

Epoch 72 / 200 | iteration 80 / 171 | Total Loss: 4.180418014526367 | KNN Loss: 4.1602067947387695 | CLS Loss: 0.020211130380630493
Epoch 72 / 200 | iteration 90 / 171 | Total Loss: 4.218597888946533 | KNN Loss: 4.202602863311768 | CLS Loss: 0.01599503681063652
Epoch 72 / 200 | iteration 100 / 171 | Total Loss: 4.178929805755615 | KNN Loss: 4.15957498550415 | CLS Loss: 0.019354816526174545
Epoch 72 / 200 | iteration 110 / 171 | Total Loss: 4.216339111328125 | KNN Loss: 4.183912754058838 | CLS Loss: 0.032426562160253525
Epoch 72 / 200 | iteration 120 / 171 | Total Loss: 4.148946762084961 | KNN Loss: 4.133347511291504 | CLS Loss: 0.01559903472661972
Epoch 72 / 200 | iteration 130 / 171 | Total Loss: 4.189474582672119 | KNN Loss: 4.172037601470947 | CLS Loss: 0.01743677817285061
Epoch 72 / 200 | iteration 140 / 171 | Total Loss: 4.195927619934082 | KNN Loss: 4.191999912261963 | CLS Loss: 0.0039279405027627945
Epoch 72 / 200 | iteration 150 / 171 | Total Loss: 4.171967506408691 | KNN Loss:

Epoch 75 / 200 | iteration 150 / 171 | Total Loss: 4.215743541717529 | KNN Loss: 4.198930740356445 | CLS Loss: 0.016812657937407494
Epoch 75 / 200 | iteration 160 / 171 | Total Loss: 4.2382636070251465 | KNN Loss: 4.18907356262207 | CLS Loss: 0.049190256744623184
Epoch 75 / 200 | iteration 170 / 171 | Total Loss: 4.166059494018555 | KNN Loss: 4.150218486785889 | CLS Loss: 0.015840888023376465
Epoch: 075, Loss: 4.1784, Train: 0.9956, Valid: 0.9859, Best: 0.9882
Epoch 76 / 200 | iteration 0 / 171 | Total Loss: 4.171874523162842 | KNN Loss: 4.155615329742432 | CLS Loss: 0.01625923439860344
Epoch 76 / 200 | iteration 10 / 171 | Total Loss: 4.162373065948486 | KNN Loss: 4.1519904136657715 | CLS Loss: 0.010382474400103092
Epoch 76 / 200 | iteration 20 / 171 | Total Loss: 4.185617923736572 | KNN Loss: 4.174278736114502 | CLS Loss: 0.011339057236909866
Epoch 76 / 200 | iteration 30 / 171 | Total Loss: 4.141385078430176 | KNN Loss: 4.137580871582031 | CLS Loss: 0.0038040040526539087
Epoch 76 / 

Epoch 79 / 200 | iteration 40 / 171 | Total Loss: 4.22163200378418 | KNN Loss: 4.187371253967285 | CLS Loss: 0.03426066413521767
Epoch 79 / 200 | iteration 50 / 171 | Total Loss: 4.1714982986450195 | KNN Loss: 4.169597625732422 | CLS Loss: 0.001900878269225359
Epoch 79 / 200 | iteration 60 / 171 | Total Loss: 4.23212194442749 | KNN Loss: 4.210782527923584 | CLS Loss: 0.021339282393455505
Epoch 79 / 200 | iteration 70 / 171 | Total Loss: 4.201274871826172 | KNN Loss: 4.177408218383789 | CLS Loss: 0.02386642061173916
Epoch 79 / 200 | iteration 80 / 171 | Total Loss: 4.187958717346191 | KNN Loss: 4.170316696166992 | CLS Loss: 0.017641974613070488
Epoch 79 / 200 | iteration 90 / 171 | Total Loss: 4.225302696228027 | KNN Loss: 4.212100982666016 | CLS Loss: 0.013201511465013027
Epoch 79 / 200 | iteration 100 / 171 | Total Loss: 4.1854777336120605 | KNN Loss: 4.173969745635986 | CLS Loss: 0.011508071795105934
Epoch 79 / 200 | iteration 110 / 171 | Total Loss: 4.159139156341553 | KNN Loss: 4.1

Epoch 82 / 200 | iteration 110 / 171 | Total Loss: 4.202474117279053 | KNN Loss: 4.1924214363098145 | CLS Loss: 0.010052827186882496
Epoch 82 / 200 | iteration 120 / 171 | Total Loss: 4.163381576538086 | KNN Loss: 4.14677095413208 | CLS Loss: 0.016610395163297653
Epoch 82 / 200 | iteration 130 / 171 | Total Loss: 4.161956310272217 | KNN Loss: 4.155298233032227 | CLS Loss: 0.006658002734184265
Epoch 82 / 200 | iteration 140 / 171 | Total Loss: 4.190673351287842 | KNN Loss: 4.182836055755615 | CLS Loss: 0.007837367244064808
Epoch 82 / 200 | iteration 150 / 171 | Total Loss: 4.1866960525512695 | KNN Loss: 4.165360450744629 | CLS Loss: 0.021335672587156296
Epoch 82 / 200 | iteration 160 / 171 | Total Loss: 4.19980525970459 | KNN Loss: 4.171334266662598 | CLS Loss: 0.02847079187631607
Epoch 82 / 200 | iteration 170 / 171 | Total Loss: 4.149526596069336 | KNN Loss: 4.142782211303711 | CLS Loss: 0.006744252517819405
Epoch: 082, Loss: 4.1796, Train: 0.9970, Valid: 0.9873, Best: 0.9882
Epoch 83

Epoch: 085, Loss: 4.1828, Train: 0.9967, Valid: 0.9861, Best: 0.9882
Epoch 86 / 200 | iteration 0 / 171 | Total Loss: 4.190130233764648 | KNN Loss: 4.1751532554626465 | CLS Loss: 0.014976749196648598
Epoch 86 / 200 | iteration 10 / 171 | Total Loss: 4.171115398406982 | KNN Loss: 4.158337593078613 | CLS Loss: 0.01277759950608015
Epoch 86 / 200 | iteration 20 / 171 | Total Loss: 4.154627323150635 | KNN Loss: 4.151052951812744 | CLS Loss: 0.0035743520129472017
Epoch 86 / 200 | iteration 30 / 171 | Total Loss: 4.193169593811035 | KNN Loss: 4.183012008666992 | CLS Loss: 0.010157492011785507
Epoch 86 / 200 | iteration 40 / 171 | Total Loss: 4.161135673522949 | KNN Loss: 4.155674934387207 | CLS Loss: 0.005460788030177355
Epoch 86 / 200 | iteration 50 / 171 | Total Loss: 4.141238212585449 | KNN Loss: 4.129074573516846 | CLS Loss: 0.012163816951215267
Epoch 86 / 200 | iteration 60 / 171 | Total Loss: 4.136568069458008 | KNN Loss: 4.132318496704102 | CLS Loss: 0.004249591380357742
Epoch 86 / 200

KeyboardInterrupt: 

In [8]:
test(model, test_data_iter, device)

tensor(0.9873)

In [9]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [12]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [13]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [14]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.9535425517335891


In [15]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [16]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [17]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [18]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [19]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [20]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [21]:
losses = []
accs = []
sparsity = []

In [None]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
layer 9: 0.0
layer 10: 0.0
Epoch: 00 | Batch: 000 / 041 | Total loss: 1.802 | Reg loss: 0.014 | Tree loss: 1.802 | Accuracy: 0.035156 | 7.411 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 01 | Batch: 000 / 041 | Total loss: 1.663 | Reg loss: 0.005 | Tree loss: 1.663 | Accuracy: 0.906250 | 6.875 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
l

Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 18 | Batch: 000 / 041 | Total loss: 0.590 | Reg loss: 0.017 | Tree loss: 0.590 | Accuracy: 0.896484 | 6.979 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
layer 9: 0.9840425531914895
layer 10: 0.9840425531914895
Epoch: 19 | Batch: 000 / 041 | Total loss: 0.558 | Reg loss: 0.017 | Tree loss: 0.558 | Accuracy: 0.912109 | 6.984 sec/iter
Average sparseness: 0.9840425531914895
layer 0: 0.984042

In [None]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
print(f"Number of patterns: {len(root.get_leaves())}")

In [None]:
method = 'greedy'

In [None]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")