In [7]:
%load_ext autoreload
%autoreload 2
%matplotlib widget
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import optim
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import pairwise_distances
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm
import os
import sys
import time
import torch.nn as nn
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from stream_generators.mit_bih import MITBIH
from utils.MatplotlibUtils import reduce_dims_and_plot
from network.auto_encoder import AutoEncoder
from losses.knn_loss import KNNLoss, ClassificationKNNLoss
from sklearn.cluster import DBSCAN
from sklearn.mixture import GaussianMixture
from soft_decision_tree.sdt_model import SDT
from sklearn.metrics import davies_bouldin_score

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
k = 64
tree_depth = 10
batch_size = 512
device = 'cpu'
train_data_path = r'<>/mitbih_train.csv'  # replace <> with the correct path of the dataset
test_data_path = r'<>/mitbih_test.csv'  # replace <> with the correct path of the dataset

In [9]:
train_data_iter = torch.utils.data.DataLoader(MITBIH(train_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True,
                                             drop_last=True)

test_data_iter = torch.utils.data.DataLoader(MITBIH(test_data_path),
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1,
                                             pin_memory=True)

In [10]:
class ConvBlock(nn.Module):
    def __init__(self):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 32, kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool1d(kernel_size=5, stride=2)
        
    def forward(self, x):
        y = x
        y = self.conv1(y)
        y = self.relu1(y)
        y = self.conv2(y)
        y = y + x
        y = self.relu2(y)
        y = self.pool(y)
        return y


class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=1)
        self.block1 = ConvBlock()
        self.block2 = ConvBlock()
        self.block3 = ConvBlock()
        self.block4 = ConvBlock()
        self.block5 = ConvBlock()
        self.fc1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 5)

    def forward(self, x, return_interm=False):
        x = self.conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        interm = x.flatten(1)
        x = self.fc1(interm)
        x = self.relu(x)
        x = self.fc2(x)
        
        if return_interm:
            return x, interm
        
        return x

In [11]:
knn_crt = ClassificationKNNLoss(k=k).to(device)

def train(model, loader, optimizer, device):
    model.train()

    total_loss = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        outputs, interm = model(batch, return_interm=True)
        mse_loss = F.cross_entropy(outputs, target)
        mse_loss = mse_loss.sum(dim=-1).mean()
        try:
            knn_loss = knn_crt(interm, target)
            if torch.isinf(knn_loss):
                knn_loss = torch.tensor(0).to(device)
        except ValueError:
            knn_loss = torch.tensor(0).to(device)
        loss = mse_loss + knn_loss
        total_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if iteration % log_every == 0:
            print(f"Epoch {epoch} / {epochs} | iteration {iteration} / {len(loader)} | Total Loss: {loss.item()} | KNN Loss: {knn_loss.item()} | CLS Loss: {mse_loss.item()}")

    return total_loss / len(loader)

@torch.no_grad()
def test(model, loader, device):
    model.eval()
    
    correct = 0
    for iteration, (batch, target) in enumerate(loader):
        batch = batch.to(device)
        target = target.to(device)
        y_pred = model(batch).argmax(dim=-1)
        correct += y_pred.eq(target.view(-1).data).sum()
    
    return correct / len(loader.dataset)

In [12]:
epochs = 200
lr = 1e-3
log_every = 10

model = ECGModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
num_params = sum([p.numel() for p in model.parameters()])
print(f'#Params: {num_params}')

#Params: 53957


In [13]:
best_valid_acc = 0
losses = []
train_accs = []
val_accs = []
for epoch in range(1, epochs + 1):
    loss = train(model, train_data_iter, optimizer, device)
#     print(f"Loss: {loss} =============================")
    losses.append(loss)
    train_acc = test(model, train_data_iter, device)
    train_accs.append(train_acc)
    valid_acc = test(model, test_data_iter, device)
    val_accs.append(valid_acc)
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Valid: {valid_acc:.4f}, '
              f'Best: {best_valid_acc:.4f}')

  return torch._C._cuda_getDeviceCount() > 0


Epoch 1 / 200 | iteration 0 / 171 | Total Loss: 7.611621379852295 | KNN Loss: 5.851576805114746 | CLS Loss: 1.7600445747375488
Epoch 1 / 200 | iteration 10 / 171 | Total Loss: 5.5411458015441895 | KNN Loss: 4.718804836273193 | CLS Loss: 0.8223410844802856
Epoch 1 / 200 | iteration 20 / 171 | Total Loss: 5.31278133392334 | KNN Loss: 4.556290626525879 | CLS Loss: 0.7564908266067505
Epoch 1 / 200 | iteration 30 / 171 | Total Loss: 5.130352973937988 | KNN Loss: 4.530723571777344 | CLS Loss: 0.5996293425559998
Epoch 1 / 200 | iteration 40 / 171 | Total Loss: 5.090763092041016 | KNN Loss: 4.479553699493408 | CLS Loss: 0.6112095713615417
Epoch 1 / 200 | iteration 50 / 171 | Total Loss: 5.027774810791016 | KNN Loss: 4.468639850616455 | CLS Loss: 0.5591350793838501
Epoch 1 / 200 | iteration 60 / 171 | Total Loss: 4.985223770141602 | KNN Loss: 4.48076868057251 | CLS Loss: 0.504455029964447
Epoch 1 / 200 | iteration 70 / 171 | Total Loss: 4.821516990661621 | KNN Loss: 4.4435038566589355 | CLS Los

Epoch 4 / 200 | iteration 90 / 171 | Total Loss: 4.386694431304932 | KNN Loss: 4.247145175933838 | CLS Loss: 0.1395494043827057
Epoch 4 / 200 | iteration 100 / 171 | Total Loss: 4.374841690063477 | KNN Loss: 4.246660232543945 | CLS Loss: 0.12818162143230438
Epoch 4 / 200 | iteration 110 / 171 | Total Loss: 4.411810874938965 | KNN Loss: 4.282763481140137 | CLS Loss: 0.12904725968837738
Epoch 4 / 200 | iteration 120 / 171 | Total Loss: 4.402511119842529 | KNN Loss: 4.281371116638184 | CLS Loss: 0.12113998830318451
Epoch 4 / 200 | iteration 130 / 171 | Total Loss: 4.4582953453063965 | KNN Loss: 4.3242268562316895 | CLS Loss: 0.13406848907470703
Epoch 4 / 200 | iteration 140 / 171 | Total Loss: 4.4174981117248535 | KNN Loss: 4.270980358123779 | CLS Loss: 0.14651760458946228
Epoch 4 / 200 | iteration 150 / 171 | Total Loss: 4.4136857986450195 | KNN Loss: 4.308887958526611 | CLS Loss: 0.10479802638292313
Epoch 4 / 200 | iteration 160 / 171 | Total Loss: 4.330372333526611 | KNN Loss: 4.249135

Epoch 7 / 200 | iteration 170 / 171 | Total Loss: 4.397749900817871 | KNN Loss: 4.290170669555664 | CLS Loss: 0.10757926106452942
Epoch: 007, Loss: 4.3257, Train: 0.9791, Valid: 0.9756, Best: 0.9756
Epoch 8 / 200 | iteration 0 / 171 | Total Loss: 4.312385082244873 | KNN Loss: 4.230754852294922 | CLS Loss: 0.08163004368543625
Epoch 8 / 200 | iteration 10 / 171 | Total Loss: 4.3073225021362305 | KNN Loss: 4.232346534729004 | CLS Loss: 0.07497601956129074
Epoch 8 / 200 | iteration 20 / 171 | Total Loss: 4.376765727996826 | KNN Loss: 4.256069183349609 | CLS Loss: 0.12069671601057053
Epoch 8 / 200 | iteration 30 / 171 | Total Loss: 4.250857830047607 | KNN Loss: 4.192142486572266 | CLS Loss: 0.0587153285741806
Epoch 8 / 200 | iteration 40 / 171 | Total Loss: 4.308696269989014 | KNN Loss: 4.215068340301514 | CLS Loss: 0.09362789988517761
Epoch 8 / 200 | iteration 50 / 171 | Total Loss: 4.2926836013793945 | KNN Loss: 4.228835582733154 | CLS Loss: 0.06384791433811188
Epoch 8 / 200 | iteration 6

Epoch 11 / 200 | iteration 70 / 171 | Total Loss: 4.251927375793457 | KNN Loss: 4.207891941070557 | CLS Loss: 0.044035378843545914
Epoch 11 / 200 | iteration 80 / 171 | Total Loss: 4.266667366027832 | KNN Loss: 4.218547344207764 | CLS Loss: 0.048120222985744476
Epoch 11 / 200 | iteration 90 / 171 | Total Loss: 4.2895026206970215 | KNN Loss: 4.229923248291016 | CLS Loss: 0.05957914888858795
Epoch 11 / 200 | iteration 100 / 171 | Total Loss: 4.284947395324707 | KNN Loss: 4.229155540466309 | CLS Loss: 0.05579190328717232
Epoch 11 / 200 | iteration 110 / 171 | Total Loss: 4.2707695960998535 | KNN Loss: 4.1840410232543945 | CLS Loss: 0.08672859519720078
Epoch 11 / 200 | iteration 120 / 171 | Total Loss: 4.274397373199463 | KNN Loss: 4.226527690887451 | CLS Loss: 0.04786955937743187
Epoch 11 / 200 | iteration 130 / 171 | Total Loss: 4.230456829071045 | KNN Loss: 4.206214427947998 | CLS Loss: 0.024242430925369263
Epoch 11 / 200 | iteration 140 / 171 | Total Loss: 4.3067755699157715 | KNN Loss

Epoch 14 / 200 | iteration 150 / 171 | Total Loss: 4.281956672668457 | KNN Loss: 4.200355529785156 | CLS Loss: 0.08160095661878586
Epoch 14 / 200 | iteration 160 / 171 | Total Loss: 4.22836971282959 | KNN Loss: 4.184549331665039 | CLS Loss: 0.043820470571517944
Epoch 14 / 200 | iteration 170 / 171 | Total Loss: 4.241146087646484 | KNN Loss: 4.2008819580078125 | CLS Loss: 0.04026389122009277
Epoch: 014, Loss: 4.2620, Train: 0.9862, Valid: 0.9815, Best: 0.9815
Epoch 15 / 200 | iteration 0 / 171 | Total Loss: 4.280323505401611 | KNN Loss: 4.223146438598633 | CLS Loss: 0.057177189737558365
Epoch 15 / 200 | iteration 10 / 171 | Total Loss: 4.249734878540039 | KNN Loss: 4.201269149780273 | CLS Loss: 0.048465877771377563
Epoch 15 / 200 | iteration 20 / 171 | Total Loss: 4.280198574066162 | KNN Loss: 4.216894626617432 | CLS Loss: 0.06330376863479614
Epoch 15 / 200 | iteration 30 / 171 | Total Loss: 4.236361980438232 | KNN Loss: 4.186278343200684 | CLS Loss: 0.050083499401807785
Epoch 15 / 200 

Epoch 18 / 200 | iteration 40 / 171 | Total Loss: 4.221866607666016 | KNN Loss: 4.176784038543701 | CLS Loss: 0.045082446187734604
Epoch 18 / 200 | iteration 50 / 171 | Total Loss: 4.240352630615234 | KNN Loss: 4.170080661773682 | CLS Loss: 0.07027196884155273
Epoch 18 / 200 | iteration 60 / 171 | Total Loss: 4.244841575622559 | KNN Loss: 4.174323081970215 | CLS Loss: 0.07051845639944077
Epoch 18 / 200 | iteration 70 / 171 | Total Loss: 4.205639839172363 | KNN Loss: 4.166610240936279 | CLS Loss: 0.03902976214885712
Epoch 18 / 200 | iteration 80 / 171 | Total Loss: 4.27803897857666 | KNN Loss: 4.193373203277588 | CLS Loss: 0.08466577529907227
Epoch 18 / 200 | iteration 90 / 171 | Total Loss: 4.326822280883789 | KNN Loss: 4.255382061004639 | CLS Loss: 0.07144034653902054
Epoch 18 / 200 | iteration 100 / 171 | Total Loss: 4.2749223709106445 | KNN Loss: 4.216742038726807 | CLS Loss: 0.058180514723062515
Epoch 18 / 200 | iteration 110 / 171 | Total Loss: 4.212435722351074 | KNN Loss: 4.1630

Epoch 21 / 200 | iteration 120 / 171 | Total Loss: 4.198505401611328 | KNN Loss: 4.171528339385986 | CLS Loss: 0.026977157220244408
Epoch 21 / 200 | iteration 130 / 171 | Total Loss: 4.295622825622559 | KNN Loss: 4.226320743560791 | CLS Loss: 0.06930216401815414
Epoch 21 / 200 | iteration 140 / 171 | Total Loss: 4.1923346519470215 | KNN Loss: 4.14118766784668 | CLS Loss: 0.05114692822098732
Epoch 21 / 200 | iteration 150 / 171 | Total Loss: 4.224180698394775 | KNN Loss: 4.167666435241699 | CLS Loss: 0.05651447921991348
Epoch 21 / 200 | iteration 160 / 171 | Total Loss: 4.245595932006836 | KNN Loss: 4.19228458404541 | CLS Loss: 0.05331148952245712
Epoch 21 / 200 | iteration 170 / 171 | Total Loss: 4.250962257385254 | KNN Loss: 4.175938606262207 | CLS Loss: 0.07502368092536926
Epoch: 021, Loss: 4.2320, Train: 0.9897, Valid: 0.9838, Best: 0.9838
Epoch 22 / 200 | iteration 0 / 171 | Total Loss: 4.24685001373291 | KNN Loss: 4.210028171539307 | CLS Loss: 0.0368216298520565
Epoch 22 / 200 | i

Epoch 25 / 200 | iteration 10 / 171 | Total Loss: 4.202946186065674 | KNN Loss: 4.176621913909912 | CLS Loss: 0.02632424794137478
Epoch 25 / 200 | iteration 20 / 171 | Total Loss: 4.242764949798584 | KNN Loss: 4.215856552124023 | CLS Loss: 0.026908554136753082
Epoch 25 / 200 | iteration 30 / 171 | Total Loss: 4.211588382720947 | KNN Loss: 4.167508602142334 | CLS Loss: 0.04407999664545059
Epoch 25 / 200 | iteration 40 / 171 | Total Loss: 4.235064506530762 | KNN Loss: 4.2154541015625 | CLS Loss: 0.019610607996582985
Epoch 25 / 200 | iteration 50 / 171 | Total Loss: 4.226382255554199 | KNN Loss: 4.202508449554443 | CLS Loss: 0.023873774334788322
Epoch 25 / 200 | iteration 60 / 171 | Total Loss: 4.229248046875 | KNN Loss: 4.195174217224121 | CLS Loss: 0.034073878079652786
Epoch 25 / 200 | iteration 70 / 171 | Total Loss: 4.186060905456543 | KNN Loss: 4.170558452606201 | CLS Loss: 0.015502363443374634
Epoch 25 / 200 | iteration 80 / 171 | Total Loss: 4.224132061004639 | KNN Loss: 4.18850088

Epoch 28 / 200 | iteration 90 / 171 | Total Loss: 4.209473609924316 | KNN Loss: 4.193514823913574 | CLS Loss: 0.01595856063067913
Epoch 28 / 200 | iteration 100 / 171 | Total Loss: 4.179968357086182 | KNN Loss: 4.150774955749512 | CLS Loss: 0.0291933324187994
Epoch 28 / 200 | iteration 110 / 171 | Total Loss: 4.2003173828125 | KNN Loss: 4.181091785430908 | CLS Loss: 0.019225429743528366
Epoch 28 / 200 | iteration 120 / 171 | Total Loss: 4.22584867477417 | KNN Loss: 4.196103572845459 | CLS Loss: 0.02974509634077549
Epoch 28 / 200 | iteration 130 / 171 | Total Loss: 4.169257640838623 | KNN Loss: 4.151370525360107 | CLS Loss: 0.01788724772632122
Epoch 28 / 200 | iteration 140 / 171 | Total Loss: 4.220255374908447 | KNN Loss: 4.201069355010986 | CLS Loss: 0.019185852259397507
Epoch 28 / 200 | iteration 150 / 171 | Total Loss: 4.1873908042907715 | KNN Loss: 4.162753105163574 | CLS Loss: 0.024637851864099503
Epoch 28 / 200 | iteration 160 / 171 | Total Loss: 4.215841770172119 | KNN Loss: 4.1

Epoch 31 / 200 | iteration 170 / 171 | Total Loss: 4.329753875732422 | KNN Loss: 4.283032417297363 | CLS Loss: 0.046721626073122025
Epoch: 031, Loss: 4.2095, Train: 0.9917, Valid: 0.9859, Best: 0.9859
Epoch 32 / 200 | iteration 0 / 171 | Total Loss: 4.196259498596191 | KNN Loss: 4.168260097503662 | CLS Loss: 0.027999596670269966
Epoch 32 / 200 | iteration 10 / 171 | Total Loss: 4.195666790008545 | KNN Loss: 4.160972595214844 | CLS Loss: 0.03469400480389595
Epoch 32 / 200 | iteration 20 / 171 | Total Loss: 4.257388591766357 | KNN Loss: 4.225433826446533 | CLS Loss: 0.03195473924279213
Epoch 32 / 200 | iteration 30 / 171 | Total Loss: 4.1703410148620605 | KNN Loss: 4.151196002960205 | CLS Loss: 0.019144900143146515
Epoch 32 / 200 | iteration 40 / 171 | Total Loss: 4.197514533996582 | KNN Loss: 4.166368007659912 | CLS Loss: 0.031146572902798653
Epoch 32 / 200 | iteration 50 / 171 | Total Loss: 4.257805824279785 | KNN Loss: 4.229037284851074 | CLS Loss: 0.028768394142389297
Epoch 32 / 200 

Epoch 35 / 200 | iteration 60 / 171 | Total Loss: 4.223788261413574 | KNN Loss: 4.2151007652282715 | CLS Loss: 0.008687641471624374
Epoch 35 / 200 | iteration 70 / 171 | Total Loss: 4.207919597625732 | KNN Loss: 4.192777633666992 | CLS Loss: 0.015141968615353107
Epoch 35 / 200 | iteration 80 / 171 | Total Loss: 4.206507682800293 | KNN Loss: 4.170849800109863 | CLS Loss: 0.0356581024825573
Epoch 35 / 200 | iteration 90 / 171 | Total Loss: 4.210572242736816 | KNN Loss: 4.16250467300415 | CLS Loss: 0.048067398369312286
Epoch 35 / 200 | iteration 100 / 171 | Total Loss: 4.238564491271973 | KNN Loss: 4.198464393615723 | CLS Loss: 0.04010014608502388
Epoch 35 / 200 | iteration 110 / 171 | Total Loss: 4.197249889373779 | KNN Loss: 4.1406731605529785 | CLS Loss: 0.056576717644929886
Epoch 35 / 200 | iteration 120 / 171 | Total Loss: 4.21688985824585 | KNN Loss: 4.1817851066589355 | CLS Loss: 0.03510491922497749
Epoch 35 / 200 | iteration 130 / 171 | Total Loss: 4.194021701812744 | KNN Loss: 4.

Epoch 38 / 200 | iteration 140 / 171 | Total Loss: 4.200791835784912 | KNN Loss: 4.183178901672363 | CLS Loss: 0.01761278137564659
Epoch 38 / 200 | iteration 150 / 171 | Total Loss: 4.150178909301758 | KNN Loss: 4.14164924621582 | CLS Loss: 0.008529637940227985
Epoch 38 / 200 | iteration 160 / 171 | Total Loss: 4.218690395355225 | KNN Loss: 4.183375358581543 | CLS Loss: 0.035315223038196564
Epoch 38 / 200 | iteration 170 / 171 | Total Loss: 4.195539951324463 | KNN Loss: 4.17203950881958 | CLS Loss: 0.02350040338933468
Epoch: 038, Loss: 4.2027, Train: 0.9932, Valid: 0.9866, Best: 0.9874
Epoch 39 / 200 | iteration 0 / 171 | Total Loss: 4.191890239715576 | KNN Loss: 4.1646599769592285 | CLS Loss: 0.027230199426412582
Epoch 39 / 200 | iteration 10 / 171 | Total Loss: 4.1915082931518555 | KNN Loss: 4.167344570159912 | CLS Loss: 0.024163732305169106
Epoch 39 / 200 | iteration 20 / 171 | Total Loss: 4.181209564208984 | KNN Loss: 4.139123916625977 | CLS Loss: 0.04208579286932945
Epoch 39 / 200

Epoch 42 / 200 | iteration 30 / 171 | Total Loss: 4.242067337036133 | KNN Loss: 4.20757532119751 | CLS Loss: 0.03449203446507454
Epoch 42 / 200 | iteration 40 / 171 | Total Loss: 4.190289497375488 | KNN Loss: 4.175161361694336 | CLS Loss: 0.015128017403185368
Epoch 42 / 200 | iteration 50 / 171 | Total Loss: 4.1797590255737305 | KNN Loss: 4.159085750579834 | CLS Loss: 0.02067323960363865
Epoch 42 / 200 | iteration 60 / 171 | Total Loss: 4.178226947784424 | KNN Loss: 4.153156757354736 | CLS Loss: 0.025070106610655785
Epoch 42 / 200 | iteration 70 / 171 | Total Loss: 4.206882476806641 | KNN Loss: 4.193690776824951 | CLS Loss: 0.013191461563110352
Epoch 42 / 200 | iteration 80 / 171 | Total Loss: 4.233734130859375 | KNN Loss: 4.177765369415283 | CLS Loss: 0.055968642234802246
Epoch 42 / 200 | iteration 90 / 171 | Total Loss: 4.192840576171875 | KNN Loss: 4.185133934020996 | CLS Loss: 0.007706467993557453
Epoch 42 / 200 | iteration 100 / 171 | Total Loss: 4.163606643676758 | KNN Loss: 4.14

Epoch 45 / 200 | iteration 110 / 171 | Total Loss: 4.192473888397217 | KNN Loss: 4.160407543182373 | CLS Loss: 0.032066185027360916
Epoch 45 / 200 | iteration 120 / 171 | Total Loss: 4.1660895347595215 | KNN Loss: 4.140162944793701 | CLS Loss: 0.025926820933818817
Epoch 45 / 200 | iteration 130 / 171 | Total Loss: 4.177515029907227 | KNN Loss: 4.141493320465088 | CLS Loss: 0.036021821200847626
Epoch 45 / 200 | iteration 140 / 171 | Total Loss: 4.1983418464660645 | KNN Loss: 4.179731369018555 | CLS Loss: 0.018610574305057526
Epoch 45 / 200 | iteration 150 / 171 | Total Loss: 4.187421798706055 | KNN Loss: 4.176373481750488 | CLS Loss: 0.011048156768083572
Epoch 45 / 200 | iteration 160 / 171 | Total Loss: 4.149960041046143 | KNN Loss: 4.128811359405518 | CLS Loss: 0.021148812025785446
Epoch 45 / 200 | iteration 170 / 171 | Total Loss: 4.187256336212158 | KNN Loss: 4.165436744689941 | CLS Loss: 0.021819591522216797
Epoch: 045, Loss: 4.1894, Train: 0.9941, Valid: 0.9860, Best: 0.9874
Epoch

Epoch: 048, Loss: 4.1907, Train: 0.9950, Valid: 0.9879, Best: 0.9879
Epoch 49 / 200 | iteration 0 / 171 | Total Loss: 4.163189888000488 | KNN Loss: 4.160344123840332 | CLS Loss: 0.002845947165042162
Epoch 49 / 200 | iteration 10 / 171 | Total Loss: 4.2866668701171875 | KNN Loss: 4.2625346183776855 | CLS Loss: 0.024132052436470985
Epoch 49 / 200 | iteration 20 / 171 | Total Loss: 4.204080104827881 | KNN Loss: 4.172611236572266 | CLS Loss: 0.031468715518713
Epoch 49 / 200 | iteration 30 / 171 | Total Loss: 4.177768230438232 | KNN Loss: 4.155491352081299 | CLS Loss: 0.02227671816945076
Epoch 49 / 200 | iteration 40 / 171 | Total Loss: 4.18756628036499 | KNN Loss: 4.157102584838867 | CLS Loss: 0.030463803559541702
Epoch 49 / 200 | iteration 50 / 171 | Total Loss: 4.2115559577941895 | KNN Loss: 4.168483734130859 | CLS Loss: 0.04307224228978157
Epoch 49 / 200 | iteration 60 / 171 | Total Loss: 4.175419330596924 | KNN Loss: 4.1609063148498535 | CLS Loss: 0.014512856490910053
Epoch 49 / 200 | 

Epoch 52 / 200 | iteration 70 / 171 | Total Loss: 4.197727680206299 | KNN Loss: 4.181432247161865 | CLS Loss: 0.016295595094561577
Epoch 52 / 200 | iteration 80 / 171 | Total Loss: 4.183227062225342 | KNN Loss: 4.17805814743042 | CLS Loss: 0.005169024225324392
Epoch 52 / 200 | iteration 90 / 171 | Total Loss: 4.170871734619141 | KNN Loss: 4.148606777191162 | CLS Loss: 0.022264782339334488
Epoch 52 / 200 | iteration 100 / 171 | Total Loss: 4.15255069732666 | KNN Loss: 4.133375644683838 | CLS Loss: 0.019175026565790176
Epoch 52 / 200 | iteration 110 / 171 | Total Loss: 4.185458660125732 | KNN Loss: 4.17549467086792 | CLS Loss: 0.009963990189135075
Epoch 52 / 200 | iteration 120 / 171 | Total Loss: 4.177595138549805 | KNN Loss: 4.153993606567383 | CLS Loss: 0.023601680994033813
Epoch 52 / 200 | iteration 130 / 171 | Total Loss: 4.218579292297363 | KNN Loss: 4.176449775695801 | CLS Loss: 0.042129743844270706
Epoch 52 / 200 | iteration 140 / 171 | Total Loss: 4.171987533569336 | KNN Loss: 4

Epoch 55 / 200 | iteration 140 / 171 | Total Loss: 4.197900772094727 | KNN Loss: 4.166614055633545 | CLS Loss: 0.03128694370388985
Epoch 55 / 200 | iteration 150 / 171 | Total Loss: 4.180901527404785 | KNN Loss: 4.161993503570557 | CLS Loss: 0.018908172845840454
Epoch 55 / 200 | iteration 160 / 171 | Total Loss: 4.196558475494385 | KNN Loss: 4.165285587310791 | CLS Loss: 0.03127269074320793
Epoch 55 / 200 | iteration 170 / 171 | Total Loss: 4.185582637786865 | KNN Loss: 4.162908554077148 | CLS Loss: 0.02267426624894142
Epoch: 055, Loss: 4.1812, Train: 0.9953, Valid: 0.9866, Best: 0.9879
Epoch 56 / 200 | iteration 0 / 171 | Total Loss: 4.168360233306885 | KNN Loss: 4.162323474884033 | CLS Loss: 0.006036869715899229
Epoch 56 / 200 | iteration 10 / 171 | Total Loss: 4.1695098876953125 | KNN Loss: 4.158447742462158 | CLS Loss: 0.011062183417379856
Epoch 56 / 200 | iteration 20 / 171 | Total Loss: 4.168417930603027 | KNN Loss: 4.160828590393066 | CLS Loss: 0.007589166518300772
Epoch 56 / 20

Epoch 59 / 200 | iteration 30 / 171 | Total Loss: 4.235450267791748 | KNN Loss: 4.224225997924805 | CLS Loss: 0.011224178597331047
Epoch 59 / 200 | iteration 40 / 171 | Total Loss: 4.189977169036865 | KNN Loss: 4.181004047393799 | CLS Loss: 0.008973284624516964
Epoch 59 / 200 | iteration 50 / 171 | Total Loss: 4.2686333656311035 | KNN Loss: 4.242259502410889 | CLS Loss: 0.026374027132987976
Epoch 59 / 200 | iteration 60 / 171 | Total Loss: 4.149801254272461 | KNN Loss: 4.1427435874938965 | CLS Loss: 0.007057898677885532
Epoch 59 / 200 | iteration 70 / 171 | Total Loss: 4.195149898529053 | KNN Loss: 4.171325206756592 | CLS Loss: 0.023824622854590416
Epoch 59 / 200 | iteration 80 / 171 | Total Loss: 4.235574245452881 | KNN Loss: 4.228626728057861 | CLS Loss: 0.00694735161960125
Epoch 59 / 200 | iteration 90 / 171 | Total Loss: 4.159236431121826 | KNN Loss: 4.141646862030029 | CLS Loss: 0.017589615657925606
Epoch 59 / 200 | iteration 100 / 171 | Total Loss: 4.204322814941406 | KNN Loss: 4

Epoch 62 / 200 | iteration 100 / 171 | Total Loss: 4.20235538482666 | KNN Loss: 4.177860736846924 | CLS Loss: 0.024494634941220284
Epoch 62 / 200 | iteration 110 / 171 | Total Loss: 4.153744220733643 | KNN Loss: 4.148408889770508 | CLS Loss: 0.005335189867764711
Epoch 62 / 200 | iteration 120 / 171 | Total Loss: 4.176042556762695 | KNN Loss: 4.166662216186523 | CLS Loss: 0.009380178526043892
Epoch 62 / 200 | iteration 130 / 171 | Total Loss: 4.2133307456970215 | KNN Loss: 4.1831183433532715 | CLS Loss: 0.030212225392460823
Epoch 62 / 200 | iteration 140 / 171 | Total Loss: 4.197668075561523 | KNN Loss: 4.183560848236084 | CLS Loss: 0.014107326976954937
Epoch 62 / 200 | iteration 150 / 171 | Total Loss: 4.1816301345825195 | KNN Loss: 4.1582489013671875 | CLS Loss: 0.023381469771265984
Epoch 62 / 200 | iteration 160 / 171 | Total Loss: 4.152663230895996 | KNN Loss: 4.140800476074219 | CLS Loss: 0.011862615123391151
Epoch 62 / 200 | iteration 170 / 171 | Total Loss: 4.2052083015441895 | K

Epoch 65 / 200 | iteration 170 / 171 | Total Loss: 4.169375896453857 | KNN Loss: 4.127748489379883 | CLS Loss: 0.041627321392297745
Epoch: 065, Loss: 4.1843, Train: 0.9954, Valid: 0.9860, Best: 0.9879
Epoch 66 / 200 | iteration 0 / 171 | Total Loss: 4.190199851989746 | KNN Loss: 4.179450511932373 | CLS Loss: 0.010749546810984612
Epoch 66 / 200 | iteration 10 / 171 | Total Loss: 4.161864280700684 | KNN Loss: 4.119619369506836 | CLS Loss: 0.04224475100636482
Epoch 66 / 200 | iteration 20 / 171 | Total Loss: 4.153781890869141 | KNN Loss: 4.14592170715332 | CLS Loss: 0.007860290817916393
Epoch 66 / 200 | iteration 30 / 171 | Total Loss: 4.176140308380127 | KNN Loss: 4.16296911239624 | CLS Loss: 0.013171353377401829
Epoch 66 / 200 | iteration 40 / 171 | Total Loss: 4.151883602142334 | KNN Loss: 4.1373162269592285 | CLS Loss: 0.014567478559911251
Epoch 66 / 200 | iteration 50 / 171 | Total Loss: 4.145975112915039 | KNN Loss: 4.138625144958496 | CLS Loss: 0.007349858991801739
Epoch 66 / 200 |

Epoch 69 / 200 | iteration 60 / 171 | Total Loss: 4.210141181945801 | KNN Loss: 4.1895246505737305 | CLS Loss: 0.020616691559553146
Epoch 69 / 200 | iteration 70 / 171 | Total Loss: 4.157844543457031 | KNN Loss: 4.152379035949707 | CLS Loss: 0.005465322639793158
Epoch 69 / 200 | iteration 80 / 171 | Total Loss: 4.168301105499268 | KNN Loss: 4.15464973449707 | CLS Loss: 0.013651560992002487
Epoch 69 / 200 | iteration 90 / 171 | Total Loss: 4.1590681076049805 | KNN Loss: 4.123043060302734 | CLS Loss: 0.03602497652173042
Epoch 69 / 200 | iteration 100 / 171 | Total Loss: 4.204143524169922 | KNN Loss: 4.197665691375732 | CLS Loss: 0.006477800663560629
Epoch 69 / 200 | iteration 110 / 171 | Total Loss: 4.153521537780762 | KNN Loss: 4.14231538772583 | CLS Loss: 0.011206322349607944
Epoch 69 / 200 | iteration 120 / 171 | Total Loss: 4.212551593780518 | KNN Loss: 4.192112922668457 | CLS Loss: 0.020438862964510918
Epoch 69 / 200 | iteration 130 / 171 | Total Loss: 4.154801845550537 | KNN Loss: 

Epoch 72 / 200 | iteration 130 / 171 | Total Loss: 4.1536078453063965 | KNN Loss: 4.136358737945557 | CLS Loss: 0.017249252647161484
Epoch 72 / 200 | iteration 140 / 171 | Total Loss: 4.202422618865967 | KNN Loss: 4.176987648010254 | CLS Loss: 0.025435088202357292
Epoch 72 / 200 | iteration 150 / 171 | Total Loss: 4.134022235870361 | KNN Loss: 4.112412929534912 | CLS Loss: 0.021609248593449593
Epoch 72 / 200 | iteration 160 / 171 | Total Loss: 4.159771919250488 | KNN Loss: 4.128117561340332 | CLS Loss: 0.031654298305511475
Epoch 72 / 200 | iteration 170 / 171 | Total Loss: 4.199166297912598 | KNN Loss: 4.157289028167725 | CLS Loss: 0.041877347975969315
Epoch: 072, Loss: 4.1673, Train: 0.9959, Valid: 0.9864, Best: 0.9879
Epoch 73 / 200 | iteration 0 / 171 | Total Loss: 4.168333053588867 | KNN Loss: 4.157445907592773 | CLS Loss: 0.010887343436479568
Epoch 73 / 200 | iteration 10 / 171 | Total Loss: 4.177974224090576 | KNN Loss: 4.144679546356201 | CLS Loss: 0.03329487517476082
Epoch 73 /

Epoch 76 / 200 | iteration 20 / 171 | Total Loss: 4.139902591705322 | KNN Loss: 4.133307456970215 | CLS Loss: 0.006595138926059008
Epoch 76 / 200 | iteration 30 / 171 | Total Loss: 4.160104751586914 | KNN Loss: 4.153448581695557 | CLS Loss: 0.006656260695308447
Epoch 76 / 200 | iteration 40 / 171 | Total Loss: 4.169928073883057 | KNN Loss: 4.166134357452393 | CLS Loss: 0.003793500829488039
Epoch 76 / 200 | iteration 50 / 171 | Total Loss: 4.214899063110352 | KNN Loss: 4.194516181945801 | CLS Loss: 0.020383082330226898
Epoch 76 / 200 | iteration 60 / 171 | Total Loss: 4.175389289855957 | KNN Loss: 4.13417387008667 | CLS Loss: 0.04121527820825577
Epoch 76 / 200 | iteration 70 / 171 | Total Loss: 4.181911945343018 | KNN Loss: 4.163629531860352 | CLS Loss: 0.018282609060406685
Epoch 76 / 200 | iteration 80 / 171 | Total Loss: 4.220448970794678 | KNN Loss: 4.213099002838135 | CLS Loss: 0.007349777966737747
Epoch 76 / 200 | iteration 90 / 171 | Total Loss: 4.17477560043335 | KNN Loss: 4.1420

Epoch 79 / 200 | iteration 90 / 171 | Total Loss: 4.150318622589111 | KNN Loss: 4.1280107498168945 | CLS Loss: 0.022307857871055603
Epoch 79 / 200 | iteration 100 / 171 | Total Loss: 4.149371147155762 | KNN Loss: 4.111486911773682 | CLS Loss: 0.03788416460156441
Epoch 79 / 200 | iteration 110 / 171 | Total Loss: 4.168167591094971 | KNN Loss: 4.155006408691406 | CLS Loss: 0.013161257840692997
Epoch 79 / 200 | iteration 120 / 171 | Total Loss: 4.1702680587768555 | KNN Loss: 4.153422832489014 | CLS Loss: 0.01684541627764702
Epoch 79 / 200 | iteration 130 / 171 | Total Loss: 4.148903846740723 | KNN Loss: 4.144194602966309 | CLS Loss: 0.004709448665380478
Epoch 79 / 200 | iteration 140 / 171 | Total Loss: 4.147111415863037 | KNN Loss: 4.144194602966309 | CLS Loss: 0.0029165949672460556
Epoch 79 / 200 | iteration 150 / 171 | Total Loss: 4.185890197753906 | KNN Loss: 4.164979934692383 | CLS Loss: 0.02091006375849247
Epoch 79 / 200 | iteration 160 / 171 | Total Loss: 4.15160608291626 | KNN Los

Epoch 82 / 200 | iteration 160 / 171 | Total Loss: 4.130731582641602 | KNN Loss: 4.118416786193848 | CLS Loss: 0.012314657680690289
Epoch 82 / 200 | iteration 170 / 171 | Total Loss: 4.144529342651367 | KNN Loss: 4.137655735015869 | CLS Loss: 0.006873531267046928
Epoch: 082, Loss: 4.1675, Train: 0.9973, Valid: 0.9870, Best: 0.9879
Epoch 83 / 200 | iteration 0 / 171 | Total Loss: 4.1322736740112305 | KNN Loss: 4.130181312561035 | CLS Loss: 0.0020925994031131268
Epoch 83 / 200 | iteration 10 / 171 | Total Loss: 4.155758857727051 | KNN Loss: 4.1355695724487305 | CLS Loss: 0.020189112052321434
Epoch 83 / 200 | iteration 20 / 171 | Total Loss: 4.145876884460449 | KNN Loss: 4.139819622039795 | CLS Loss: 0.006057152524590492
Epoch 83 / 200 | iteration 30 / 171 | Total Loss: 4.153669834136963 | KNN Loss: 4.14516544342041 | CLS Loss: 0.008504399098455906
Epoch 83 / 200 | iteration 40 / 171 | Total Loss: 4.1693220138549805 | KNN Loss: 4.161662578582764 | CLS Loss: 0.007659585680812597
Epoch 83 /

Epoch 86 / 200 | iteration 50 / 171 | Total Loss: 4.142706394195557 | KNN Loss: 4.13645601272583 | CLS Loss: 0.006250463426113129
Epoch 86 / 200 | iteration 60 / 171 | Total Loss: 4.159805774688721 | KNN Loss: 4.154379844665527 | CLS Loss: 0.005426023155450821
Epoch 86 / 200 | iteration 70 / 171 | Total Loss: 4.187921047210693 | KNN Loss: 4.169433116912842 | CLS Loss: 0.018488062545657158
Epoch 86 / 200 | iteration 80 / 171 | Total Loss: 4.153057098388672 | KNN Loss: 4.150600433349609 | CLS Loss: 0.0024564729537814856
Epoch 86 / 200 | iteration 90 / 171 | Total Loss: 4.174063682556152 | KNN Loss: 4.159762859344482 | CLS Loss: 0.014300926588475704
Epoch 86 / 200 | iteration 100 / 171 | Total Loss: 4.151689052581787 | KNN Loss: 4.1401824951171875 | CLS Loss: 0.011506634764373302
Epoch 86 / 200 | iteration 110 / 171 | Total Loss: 4.157598972320557 | KNN Loss: 4.133328914642334 | CLS Loss: 0.024270083755254745
Epoch 86 / 200 | iteration 120 / 171 | Total Loss: 4.167727947235107 | KNN Loss:

Epoch 89 / 200 | iteration 120 / 171 | Total Loss: 4.14461088180542 | KNN Loss: 4.140966892242432 | CLS Loss: 0.003643893403932452
Epoch 89 / 200 | iteration 130 / 171 | Total Loss: 4.188840389251709 | KNN Loss: 4.163620471954346 | CLS Loss: 0.02521994151175022
Epoch 89 / 200 | iteration 140 / 171 | Total Loss: 4.14785099029541 | KNN Loss: 4.139376163482666 | CLS Loss: 0.008474699221551418
Epoch 89 / 200 | iteration 150 / 171 | Total Loss: 4.1736555099487305 | KNN Loss: 4.167597770690918 | CLS Loss: 0.006057725753635168
Epoch 89 / 200 | iteration 160 / 171 | Total Loss: 4.142537593841553 | KNN Loss: 4.139855861663818 | CLS Loss: 0.0026817191392183304
Epoch 89 / 200 | iteration 170 / 171 | Total Loss: 4.176389217376709 | KNN Loss: 4.148622512817383 | CLS Loss: 0.027766840532422066
Epoch: 089, Loss: 4.1597, Train: 0.9971, Valid: 0.9867, Best: 0.9879
Epoch 90 / 200 | iteration 0 / 171 | Total Loss: 4.143197059631348 | KNN Loss: 4.127331733703613 | CLS Loss: 0.015865514054894447
Epoch 90 /

KeyboardInterrupt: 

In [14]:
test(model, test_data_iter, device)

tensor(0.9865)

In [15]:
plt.figure()
plt.plot(losses, label='train loss')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
plt.figure()
plt.plot(train_accs, label='train accuracy')
plt.plot(val_accs, label='validation accuracy')
plt.legend()
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
test_samples = torch.tensor([])
projections = torch.tensor([])
labels = torch.tensor([])

with torch.no_grad():
    for x, y in tqdm(test_data_iter):
        test_samples = torch.cat([test_samples, x])
        labels = torch.cat([labels, y])
        x = x.to(device)
        _, interm = model(x, True)
        projections = torch.cat([projections, interm.detach().cpu().flatten(1)])

  0%|          | 0/43 [00:00<?, ?it/s]

In [18]:
distances = pairwise_distances(projections)
# distances = np.triu(distances)
distances_f = distances.flatten()

plt.matshow(distances)
plt.colorbar()
plt.figure()
plt.hist(distances_f[distances_f > 0], bins=1000)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [19]:
clusters = DBSCAN(eps=2, min_samples=10).fit_predict(projections)

In [20]:
print(f"Number of inliers: {sum(clusters != -1) / len(clusters)}")

Number of inliers: 0.9410716732903933


In [21]:
perplexity = 100
p = reduce_dims_and_plot(projections[clusters != -1],
                         y=clusters[clusters != -1],
                         title=f'perplexity: {perplexity}',
                         file_name=None,
                         perplexity=perplexity,
                         library='Multicore-TSNE',
                         perform_PCA=False,
                         projected=None,
                         figure_type='2d',
                         show_figure=True,
                         close_figure=False,
                         text=None)    

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

# Train a Soft-Decision-Tree given the self-labels

## Prepare the dataset

In [22]:
tree_dataset = list(zip(test_samples.flatten(1)[clusters!=-1], clusters[clusters != -1]))
batch_size = 512
tree_loader = torch.utils.data.DataLoader(tree_dataset, batch_size=batch_size, shuffle=True)

# Define how we prune the weights of a node

In [23]:
def prune_node(node_weights, factor=1):
    w = node_weights.cpu().detach().numpy()
    mean_ = np.mean(w)
    std_ = np.std(w)
    node_weights[((mean_ - std_ * factor) < node_weights) & (node_weights < (mean_ + std_ * factor))] = 0
    return node_weights

def prune_node_keep(node_weights, keep=4):
    w = node_weights.cpu().detach().numpy()
    throw_idx = np.argsort(abs(w))[:-keep]
    node_weights[throw_idx] = 0
    return node_weights

def prune_tree(tree_, factor):
    new_weights = tree_.inner_nodes.weight.clone()
    for i in range(new_weights.shape[0]):
        res = prune_node_keep(new_weights[i, :], factor)
        new_weights[i, :] = res

    with torch.no_grad():
        tree_.inner_nodes.weight.copy_(new_weights)
        
def sparseness(x):
    s = []
    for i in range(x.shape[0]):
        x_ = x[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        s.append(sp)
    return np.mean(s)

def compute_regularization_by_level(tree):
    total_reg = 0
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_level = np.floor(np.log2(i+1))
        node_reg = torch.norm(tree.inner_nodes.weight[i].view(-1), 2)
        total_reg += 2**(-cur_level) * node_reg
    return total_reg

def show_sparseness(tree):
    avg_sp = sparseness(tree.inner_nodes.weight)
    print(f"Average sparseness: {avg_sp}")
    layer = 0
    sps = []
    for i in range(tree.inner_nodes.weight.shape[0]):
        cur_layer = int(np.floor(np.log2(i+1)))
        if cur_layer != layer:
            print(f"layer {layer}: {np.mean(sps)}")
            sps = []
            layer = cur_layer

        x_ = tree.inner_nodes.weight[i, :]
        sp = (len(x_) - torch.norm(x_, 0).item()) / len(x_)
        sps.append(sp)
        
    return avg_sp

## Training configurations

In [24]:
def do_epoch(model, loader, device, log_interval, losses, accs, epoch, iteration):
    model = model.train()
    for batch_idx, (data, target) in enumerate(loader):
        iteration += 1
        data, target = data.to(device), target.to(device)

        output, penalty = tree.forward(data)

        # Loss
        loss_tree = criterion(output, target.view(-1))

        # Penalty
        loss_tree += penalty

        # Sparse regularization
#         fc_params = torch.cat([x.view(-1) for x in tree.inner_nodes.parameters()])
#         regularization = sparsity_lamda * torch.norm(fc_params, 2)
        regularization = sparsity_lamda * compute_regularization_by_level(tree)
        loss = loss_tree

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

        pred = output.data.max(1)[1]
        correct = pred.eq(target.view(-1).data).sum()
        accs.append(correct.item() / data.size()[0])

        # Print training status
        if batch_idx % log_interval == 0:
            print(f"Epoch: {epoch:02d} | Batch: {batch_idx:03d} / {len(loader):03d} | Total loss: {loss.item():.3f} | Reg loss: {regularization.item():.3f} | Tree loss: {loss_tree.item():.3f} | Accuracy: {correct.item() / data.size()[0]:03f} | {round((time.time() - start_time) / iteration, 3)} sec/iter")
            
    return iteration


In [25]:
lr = 5e-3
weight_decay = 5e-4
sparsity_lamda = 2e-3
epochs = 400
log_interval = 100
use_cuda = device != 'cpu'

In [26]:
tree = SDT(input_dim=test_samples.shape[2], output_dim=len(set(clusters)) - 1, depth=tree_depth, lamda=1e-3, use_cuda=use_cuda)
optimizer = torch.optim.Adam(tree.parameters(),
                                 lr=lr,
                                 weight_decay=weight_decay)
criterion = nn.CrossEntropyLoss()
tree = tree.to(device)

In [27]:
losses = []
accs = []
sparsity = []

In [None]:
start_time = time.time()
iteration = 0
for epoch in range(epochs):
    # Training
    avg_sp = show_sparseness(tree)
    sparsity.append(avg_sp)
    iteration = do_epoch(tree, tree_loader, device, log_interval, losses, accs, epoch, iteration)
    
    if epoch % 1 == 0:
        prune_tree(tree, factor=3)
        

Average sparseness: 0.0
layer 0: 0.0
layer 1: 0.0
layer 2: 0.0
layer 3: 0.0
layer 4: 0.0
layer 5: 0.0
layer 6: 0.0
layer 7: 0.0
layer 8: 0.0
Epoch: 00 | Batch: 000 / 041 | Total loss: 1.587 | Reg loss: 0.012 | Tree loss: 1.587 | Accuracy: 0.691406 | 1.298 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 01 | Batch: 000 / 041 | Total loss: 1.455 | Reg loss: 0.005 | Tree loss: 1.455 | Accuracy: 0.677734 | 2.015 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 02 | Batch: 000 / 041 

Epoch: 20 | Batch: 000 / 041 | Total loss: 0.872 | Reg loss: 0.021 | Tree loss: 0.872 | Accuracy: 0.695312 | 1.863 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 21 | Batch: 000 / 041 | Total loss: 0.838 | Reg loss: 0.021 | Tree loss: 0.838 | Accuracy: 0.708984 | 1.864 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 22 | Batch: 000 / 041 | Total loss: 0.850 | Reg loss: 0.021 | Tree loss: 0.850 | Accuracy: 0.705078 | 1.863 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 40 | Batch: 000 / 041 | Total loss: 0.840 | Reg loss: 0.023 | Tree loss: 0.840 | Accuracy: 0.708984 | 1.853 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 41 | Batch: 000 / 041 | Total loss: 0.921 | Reg loss: 0.023 | Tree loss: 0.921 | Accuracy: 0.681641 | 1.854 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 42 | Batch: 000 / 041 | Total loss: 0.921 | Reg loss: 0.023 | Tree loss: 0.921 | Accuracy: 0.691406 | 1.852 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 60 | Batch: 000 / 041 | Total loss: 0.710 | Reg loss: 0.024 | Tree loss: 0.710 | Accuracy: 0.748047 | 1.85 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 61 | Batch: 000 / 041 | Total loss: 0.805 | Reg loss: 0.024 | Tree loss: 0.805 | Accuracy: 0.708984 | 1.853 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 62 | Batch: 000 / 041 | Total loss: 0.702 | Reg loss: 0.024 | Tree loss: 0.702 | Accuracy: 0.740234 | 1.856 sec/iter
Average sparseness: 0.9840425531914894
layer 0:

Epoch: 80 | Batch: 000 / 041 | Total loss: 0.671 | Reg loss: 0.024 | Tree loss: 0.671 | Accuracy: 0.765625 | 1.878 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 81 | Batch: 000 / 041 | Total loss: 0.725 | Reg loss: 0.024 | Tree loss: 0.725 | Accuracy: 0.730469 | 1.878 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 82 | Batch: 000 / 041 | Total loss: 0.717 | Reg loss: 0.024 | Tree loss: 0.717 | Accuracy: 0.736328 | 1.881 sec/iter
Average sparseness: 0.9840425531914894
layer 0

Epoch: 100 | Batch: 000 / 041 | Total loss: 0.789 | Reg loss: 0.025 | Tree loss: 0.789 | Accuracy: 0.726562 | 1.89 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 101 | Batch: 000 / 041 | Total loss: 0.755 | Reg loss: 0.025 | Tree loss: 0.755 | Accuracy: 0.705078 | 1.888 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 102 | Batch: 000 / 041 | Total loss: 0.692 | Reg loss: 0.025 | Tree loss: 0.692 | Accuracy: 0.714844 | 1.89 sec/iter
Average sparseness: 0.9840425531914894
layer 

Epoch: 120 | Batch: 000 / 041 | Total loss: 0.700 | Reg loss: 0.025 | Tree loss: 0.700 | Accuracy: 0.742188 | 1.894 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 121 | Batch: 000 / 041 | Total loss: 0.782 | Reg loss: 0.025 | Tree loss: 0.782 | Accuracy: 0.693359 | 1.894 sec/iter
Average sparseness: 0.9840425531914894
layer 0: 0.9840425531914894
layer 1: 0.9840425531914894
layer 2: 0.9840425531914894
layer 3: 0.9840425531914894
layer 4: 0.9840425531914894
layer 5: 0.9840425531914894
layer 6: 0.9840425531914895
layer 7: 0.9840425531914895
layer 8: 0.9840425531914895
Epoch: 122 | Batch: 000 / 041 | Total loss: 0.832 | Reg loss: 0.025 | Tree loss: 0.832 | Accuracy: 0.673828 | 1.897 sec/iter
Average sparseness: 0.9840425531914894
laye

In [None]:
plt.figure(figsize=(10, 5))
plt.ylabel("Accuracy")
plt.xlabel('Iteration')
plt.plot(accs, label='Accuracy vs iteration')
plt.show()

In [None]:
plt.figure()
plt.ylabel("Loss")
plt.xlabel('Iteration')
plt.plot(losses, label='Loss vs iteration')
# plt.yscale("log")
plt.show()

plt.figure()
weights = tree.inner_nodes.weight.cpu().detach().numpy().flatten()
plt.hist(weights, bins=500)
weights_std = np.std(weights)
weights_mean = np.mean(weights)
plt.axvline(weights_mean + weights_std, color='r')
plt.axvline(weights_mean - weights_std, color='r')
plt.title(f"Mean: {weights_mean}   |   STD: {weights_std}")
plt.yscale("log")
plt.show()

# Tree Visualization

In [None]:
plt.figure(figsize=(15, 10), dpi=80)
avg_height, root = tree.visualize()

# Extract Rules

# Accumulate samples in the leaves

In [None]:
print(f"Number of patterns: {len(root.get_leaves())}")

In [None]:
method = 'greedy'

In [None]:
root.clear_leaves_samples()

with torch.no_grad():
    for batch_idx, (data, target) in enumerate(tree_loader):
        root.accumulate_samples(data, method)

# Tighten boundaries

In [None]:
attr_names = [f"T_{i}" for i in range(test_samples.shape[2])]
leaves = root.get_leaves()
sum_comprehensibility = 0
comprehensibilities = []
for pattern_counter, leaf in enumerate(leaves):
    leaf.reset_path()
    leaf.tighten_with_accumulated_samples()
    conds = leaf.get_path_conditions(attr_names)
    print(f"============== Pattern {pattern_counter + 1} ==============")
    comprehensibilities.append(sum([cond.comprehensibility for cond in conds]))
    
print(f"Average comprehensibility: {np.mean(comprehensibilities)}")
print(f"std comprehensibility: {np.std(comprehensibilities)}")
print(f"var comprehensibility: {np.var(comprehensibilities)}")
print(f"minimum comprehensibility: {np.min(comprehensibilities)}")
print(f"maximum comprehensibility: {np.max(comprehensibilities)}")