In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
import os 

In [2]:
os.chdir("/kaggle/input/ecg-signal-classification/ecg-neural-ode-master")

In [3]:
pip install torchdiffeq

Collecting torchdiffeq
  Downloading torchdiffeq-0.2.4-py3-none-any.whl.metadata (440 bytes)
Downloading torchdiffeq-0.2.4-py3-none-any.whl (32 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.4
Note: you may need to restart the kernel to use updated packages.


In [4]:
import torchdiffeq

In [5]:
import pandas as pd
import numpy as np
from collections import Counter

from imblearn.over_sampling import RandomOverSampler
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from utils import anderson, count_parameters, epoch, epoch_eval
from models import ResBlock, ODEfunc, ODENet

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0);

In [6]:
def get_model(net, device, n_channels=32, n_inner_channels=32, kernel_size=3, 
              num_groups=8, adam=False, **kwargs):
    """
    Initialize ResNet, ODENet, or DEQNet with optimizer.
    """
    downsampling_layers = [
        nn.Conv1d(1, n_channels, kernel_size=3, bias=True, padding="same"),
        nn.BatchNorm1d(n_channels)
    ]

    if net == 'ResNet':
        feature_layers = [ResBlock(n_channels, n_inner_channels, kernel_size, num_groups) for _ in range(3)]
    elif net == 'ODENet':
        feature_layers = [ODENet(ODEfunc(n_channels, n_inner_channels, kernel_size, num_groups), **kwargs)]
    else:
        return 0
        
    fc_layers = [
        nn.BatchNorm1d(n_channels), 
        nn.ReLU(inplace=True), 
        nn.AdaptiveAvgPool1d(1), 
        nn.Flatten(), 
        nn.Linear(n_channels, 5)
    ]

    model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers)

    opt = optim.Adam(model.parameters(), lr=1e-3) if adam else optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    return model.to(device), opt

In [7]:
np.random.seed(42)
path = '/kaggle/input/ecg-signal-classification/ecg-neural-ode-master/data/'

# Load MIT-BIH data
mit_train = pd.read_csv(path + "mitdb_360_train.csv", header=None)
mit_test = pd.read_csv(path + "mitdb_360_test.csv", header=None)

y_train = mit_train[360]
X_train = mit_train.loc[:, :359]
y_test = mit_test[360]
X_test = mit_test.loc[:, :359]
print('Train before:', Counter(y_train))
print('Test before:', Counter(y_test), end='\n\n')

# Oversample training set
ros = RandomOverSampler(random_state=0)
X_train_oversampled, y_train_oversampled = ros.fit_resample(X_train, y_train)
print('Train oversampled:', Counter(y_train_oversampled), end='\n\n')

# Split a validation set
X_train, X_val, y_train, y_val = train_test_split(
                                    X_train_oversampled, 
                                    y_train_oversampled, 
                                    test_size=0.1, 
                                    random_state=42)
print('Train after:', Counter(y_train))
print('Val after:', Counter(y_val))
print('Test after:', Counter(y_test))

# Convert to 3D tensor
X_train, y_train, X_val, y_val, X_test, y_test = map(
    torch.from_numpy, 
    (X_train.values, y_train.values, 
     X_val.values, y_val.values, 
     X_test.values, y_test.values)
)
X_train = X_train.unsqueeze(1).float()
X_val = X_val.unsqueeze(1).float()
X_test = X_test.unsqueeze(1).float()
y_train = y_train.long()
y_val = y_val.long()
y_test = y_test.long()

# Batch size
bs = 128

# Dataloaders
train_ds = TensorDataset(X_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=8)
val_ds = TensorDataset(X_val, y_val)
val_dl = DataLoader(val_ds, batch_size=bs * 2, shuffle=False, num_workers=8)
test_ds = TensorDataset(X_test, y_test)
test_dl = DataLoader(test_ds, batch_size=bs * 2, shuffle=False, num_workers=8)

Train before: Counter({0: 88069, 2: 7042, 4: 3625, 1: 3016, 3: 760})
Test before: Counter({0: 800, 1: 800, 2: 800, 4: 800, 3: 300})

Train oversampled: Counter({0: 88069, 1: 88069, 2: 88069, 4: 88069, 3: 88069})

Train after: Counter({0: 79358, 3: 79269, 4: 79264, 1: 79254, 2: 79165})
Val after: Counter({2: 8904, 1: 8815, 4: 8805, 3: 8800, 0: 8711})
Test after: Counter({0: 800, 1: 800, 2: 800, 4: 800, 3: 300})




In [24]:
# Initialize ResNet
resnet, resnetopt = get_model(net='ResNet', device=device,
                              n_channels=32, n_inner_channels=32, 
                              kernel_size=3, num_groups=8, adam=True)

# Training options
max_epochs = 5
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(resnetopt, mode='min', factor=0.1, patience=5)

# Training loop
for i in range(max_epochs):
    epoch(train_dl, resnet, device, resnetopt, scheduler, epoch=i+1)
    epoch_eval(val_dl, resnet, device)

Training... epoch 1




    Percent trained: 100.0%  Time elapsed: 0.37 min
    Train acc: 0.91
    Train loss: 0.262
Testing
    Test acc: 0.928
    Test loss: 0.204
              precision    recall  f1-score   support

           0       0.89      0.87      0.88      8711
           1       0.87      0.97      0.92      8815
           2       0.99      0.85      0.92      8904
           3       0.91      0.97      0.94      8800
           4       1.00      0.99      0.99      8805

    accuracy                           0.93     44035
   macro avg       0.93      0.93      0.93     44035
weighted avg       0.93      0.93      0.93     44035

[[7543  957   50  158    3]
 [ 287 8514   11    3    0]
 [ 341  240 7610  707    6]
 [ 232   60    0 8508    0]
 [  78    8    2   11 8706]]

Training... epoch 2




    Percent trained: 100.0%  Time elapsed: 0.37 min
    Train acc: 0.956
    Train loss: 0.127
Testing
    Test acc: 0.965
    Test loss: 0.102
              precision    recall  f1-score   support

           0       0.96      0.90      0.93      8711
           1       0.94      0.96      0.95      8815
           2       0.96      0.98      0.97      8904
           3       0.97      0.99      0.98      8800
           4       1.00      1.00      1.00      8805

    accuracy                           0.97     44035
   macro avg       0.97      0.96      0.96     44035
weighted avg       0.97      0.97      0.96     44035

[[7839  493  205  171    3]
 [ 266 8489   54    6    0]
 [  35   36 8695  132    6]
 [   0    0  105 8695    0]
 [   8    0   12    4 8781]]

Training... epoch 3




    Percent trained: 100.0%  Time elapsed: 0.38 min
    Train acc: 0.968
    Train loss: 0.093
Testing
    Test acc: 0.967
    Test loss: 0.098
              precision    recall  f1-score   support

           0       0.96      0.92      0.94      8711
           1       0.93      0.98      0.95      8815
           2       0.99      0.94      0.96      8904
           3       0.96      1.00      0.98      8800
           4       1.00      1.00      1.00      8805

    accuracy                           0.97     44035
   macro avg       0.97      0.97      0.97     44035
weighted avg       0.97      0.97      0.97     44035

[[8015  518   55  117    6]
 [ 175 8620   14    6    0]
 [ 113  145 8365  261   20]
 [  15   10    0 8775    0]
 [   7    0    2    8 8788]]

Training... epoch 4




    Percent trained: 100.0%  Time elapsed: 0.37 min
    Train acc: 0.974
    Train loss: 0.076
Testing
    Test acc: 0.972
    Test loss: 0.085
              precision    recall  f1-score   support

           0       0.97      0.92      0.95      8711
           1       0.92      0.99      0.95      8815
           2       0.99      0.96      0.98      8904
           3       0.98      0.99      0.99      8800
           4       1.00      0.99      1.00      8805

    accuracy                           0.97     44035
   macro avg       0.97      0.97      0.97     44035
weighted avg       0.97      0.97      0.97     44035

[[8042  571   51   46    1]
 [  93 8720    2    0    0]
 [  85  148 8554  117    0]
 [  36   34    9 8721    0]
 [  33    0   11    8 8753]]

Training... epoch 5




    Percent trained: 100.0%  Time elapsed: 0.38 min
    Train acc: 0.978
    Train loss: 0.065
Testing
    Test acc: 0.969
    Test loss: 0.091
              precision    recall  f1-score   support

           0       0.95      0.94      0.94      8711
           1       0.99      0.94      0.96      8815
           2       0.98      0.97      0.97      8904
           3       0.93      1.00      0.97      8800
           4       1.00      1.00      1.00      8805

    accuracy                           0.97     44035
   macro avg       0.97      0.97      0.97     44035
weighted avg       0.97      0.97      0.97     44035

[[8156  113   95  346    1]
 [ 406 8314   56   39    0]
 [  38    7 8606  253    0]
 [   0    0    0 8800    0]
 [   5    0    2    0 8798]]



In [25]:
# Test set
print("Number of Parameters:", count_parameters(resnet), end='\n\n')
epoch_eval(test_dl, resnet, device)

Number of Parameters: 19237

Testing




    Test acc: 0.958
    Test loss: 0.127
              precision    recall  f1-score   support

           0       0.94      0.93      0.93       800
           1       0.99      0.93      0.95       800
           2       0.98      0.96      0.97       800
           3       0.82      1.00      0.90       300
           4       1.00      1.00      1.00       800

    accuracy                           0.96      3500
   macro avg       0.94      0.96      0.95      3500
weighted avg       0.96      0.96      0.96      3500

[[745  11  12  32   0]
 [ 46 741   6   7   0]
 [  3   0 770  26   1]
 [  0   0   1 299   0]
 [  1   0   0   0 799]]



In [14]:
# Initialize ODENet
odenet, odeopt = get_model(net='ODENet',device=device,
                           n_channels=32, n_inner_channels=32, 
                           kernel_size=3, num_groups=8, adam=True,
                           rtol=1e-3, atol=1e-3)

# Training Options
max_epochs = 15
scheduler = optim.lr_scheduler.ReduceLROnPlateau(odeopt, mode='min', factor=0.1, patience=5)

# Training loop
for i in range(max_epochs):
    epoch(train_dl, odenet, device, odeopt, scheduler, epoch=i+1)
    epoch_eval(val_dl, odenet, device)

Training... epoch 1




    Percent trained: 100.0%  Time elapsed: 13.97 min
    Train acc: 0.865
    Train loss: 0.382
Testing
    Test acc: 0.91
    Test loss: 0.247
              precision    recall  f1-score   support

           0       0.90      0.78      0.83      8711
           1       0.85      0.93      0.89      8815
           2       0.90      0.92      0.91      8904
           3       0.90      0.93      0.91      8800
           4       1.00      0.99      1.00      8805

    accuracy                           0.91     44035
   macro avg       0.91      0.91      0.91     44035
weighted avg       0.91      0.91      0.91     44035

[[6761 1228  363  351    8]
 [ 454 8183  100   69    9]
 [  61   98 8227  509    9]
 [ 209   49  390 8152    0]
 [  17   16   24    0 8748]]

Training... epoch 2




    Percent trained: 100.0%  Time elapsed: 15.58 min
    Train acc: 0.916
    Train loss: 0.233
Testing
    Test acc: 0.905
    Test loss: 0.267
              precision    recall  f1-score   support

           0       0.80      0.89      0.84      8711
           1       0.97      0.78      0.87      8815
           2       0.89      0.94      0.91      8904
           3       0.89      0.92      0.90      8800
           4       1.00      0.99      1.00      8805

    accuracy                           0.90     44035
   macro avg       0.91      0.90      0.90     44035
weighted avg       0.91      0.90      0.90     44035

[[7749  188  265  501    8]
 [1634 6909  134  138    0]
 [ 124   20 8373  383    4]
 [ 118    0  619 8063    0]
 [  22    2   31    3 8747]]

Training... epoch 3




    Percent trained: 100.0%  Time elapsed: 18.3 minn
    Train acc: 0.911
    Train loss: 0.246
Testing
    Test acc: 0.908
    Test loss: 0.257
              precision    recall  f1-score   support

           0       0.86      0.84      0.85      8711
           1       0.92      0.87      0.89      8815
           2       0.85      0.97      0.90      8904
           3       0.94      0.87      0.90      8800
           4       1.00      0.99      0.99      8805

    accuracy                           0.91     44035
   macro avg       0.91      0.91      0.91     44035
weighted avg       0.91      0.91      0.91     44035

[[7336  593  419  354    9]
 [ 931 7647  211   26    0]
 [  79   60 8632  126    7]
 [ 182   32  929 7638   19]
 [  21   16   17    0 8751]]

Training... epoch 4




    Percent trained: 100.0%  Time elapsed: 18.35 min
    Train acc: 0.908
    Train loss: 0.256
Testing
    Test acc: 0.919
    Test loss: 0.224
              precision    recall  f1-score   support

           0       0.85      0.87      0.86      8711
           1       0.92      0.88      0.90      8815
           2       0.90      0.94      0.92      8904
           3       0.93      0.92      0.92      8800
           4       0.99      0.99      0.99      8805

    accuracy                           0.92     44035
   macro avg       0.92      0.92      0.92     44035
weighted avg       0.92      0.92      0.92     44035

[[7550  573  298  277   13]
 [ 951 7751   76   33    4]
 [ 163   70 8332  302   37]
 [ 176   32  505 8087    0]
 [  25    1   18   18 8743]]

Training... epoch 5




    Percent trained: 100.0%  Time elapsed: 20.56 min
    Train acc: 0.915
    Train loss: 0.234
Testing
    Test acc: 0.911
    Test loss: 0.242
              precision    recall  f1-score   support

           0       0.75      0.95      0.84      8711
           1       0.97      0.79      0.87      8815
           2       0.95      0.89      0.92      8904
           3       0.94      0.93      0.94      8800
           4       1.00      0.99      0.99      8805

    accuracy                           0.91     44035
   macro avg       0.92      0.91      0.91     44035
weighted avg       0.92      0.91      0.91     44035

[[8280  183   66  168   14]
 [1812 6948   19   32    4]
 [ 632   32 7914  305   21]
 [ 258    0  319 8223    0]
 [  20    4   14   21 8746]]

Training... epoch 6




    Percent trained: 100.0%  Time elapsed: 21.17 min
    Train acc: 0.92
    Train loss: 0.217
Testing
    Test acc: 0.877
    Test loss: 0.334
              precision    recall  f1-score   support

           0       0.97      0.55      0.70      8711
           1       0.79      0.95      0.86      8815
           2       0.92      0.92      0.92      8904
           3       0.80      0.97      0.88      8800
           4       0.99      0.99      0.99      8805

    accuracy                           0.88     44035
   macro avg       0.89      0.88      0.87     44035
weighted avg       0.89      0.88      0.87     44035

[[4773 2084  437 1390   27]
 [ 125 8333   89  264    4]
 [  22  108 8224  524   26]
 [   0   22  206 8560   12]
 [   5    4   26   23 8747]]

Training... epoch 7




    Percent trained: 100.0%  Time elapsed: 21.53 min
    Train acc: 0.924
    Train loss: 0.209
Testing
    Test acc: 0.92
    Test loss: 0.212
              precision    recall  f1-score   support

           0       0.83      0.89      0.86      8711
           1       0.91      0.88      0.90      8815
           2       0.95      0.89      0.92      8904
           3       0.92      0.94      0.93      8800
           4       1.00      1.00      1.00      8805

    accuracy                           0.92     44035
   macro avg       0.92      0.92      0.92     44035
weighted avg       0.92      0.92      0.92     44035

[[7724  623  143  214    7]
 [ 923 7784   62   46    0]
 [ 405   59 7966  461   13]
 [ 255   50  216 8267   12]
 [   5   14   25    0 8761]]

Training... epoch 8




    Percent trained: 100.0%  Time elapsed: 21.52 min
    Train acc: 0.929
    Train loss: 0.197
Testing
    Test acc: 0.92
    Test loss: 0.225
              precision    recall  f1-score   support

           0       0.80      0.92      0.86      8711
           1       0.95      0.83      0.89      8815
           2       0.97      0.90      0.93      8904
           3       0.91      0.96      0.93      8800
           4       1.00      0.99      0.99      8805

    accuracy                           0.92     44035
   macro avg       0.93      0.92      0.92     44035
weighted avg       0.93      0.92      0.92     44035

[[8038  299   80  288    6]
 [1379 7349   54   33    0]
 [ 352   51 7973  512   16]
 [ 241   21   91 8447    0]
 [  72    0   23    9 8701]]

Training... epoch 9




    Percent trained: 100.0%  Time elapsed: 22.12 min
    Train acc: 0.929
    Train loss: 0.198
Testing
    Test acc: 0.899
    Test loss: 0.279
              precision    recall  f1-score   support

           0       0.85      0.86      0.85      8711
           1       0.94      0.87      0.90      8815
           2       0.99      0.79      0.88      8904
           3       0.77      1.00      0.87      8800
           4       1.00      0.99      0.99      8805

    accuracy                           0.90     44035
   macro avg       0.91      0.90      0.90     44035
weighted avg       0.91      0.90      0.90     44035

[[7457  380   40  829    5]
 [ 957 7627   22  209    0]
 [ 288   98 7004 1507    7]
 [  24    9    0 8767    0]
 [  70    0   14    4 8717]]

Training... epoch 10




    Percent trained: 100.0%  Time elapsed: 22.91 min
    Train acc: 0.931
    Train loss: 0.192
Testing
    Test acc: 0.929
    Test loss: 0.194
              precision    recall  f1-score   support

           0       0.83      0.92      0.87      8711
           1       0.93      0.86      0.90      8815
           2       0.94      0.95      0.94      8904
           3       0.95      0.92      0.94      8800
           4       1.00      0.99      0.99      8805

    accuracy                           0.93     44035
   macro avg       0.93      0.93      0.93     44035
weighted avg       0.93      0.93      0.93     44035

[[8019  384  164  137    7]
 [1131 7614   60   10    0]
 [ 147   76 8434  237   10]
 [ 297   89  316 8098    0]
 [  30    4   38   11 8722]]

Training... epoch 11




    Percent trained: 100.0%  Time elapsed: 23.1 minn
    Train acc: 0.934
    Train loss: 0.183
Testing
    Test acc: 0.936
    Test loss: 0.176
              precision    recall  f1-score   support

           0       0.84      0.94      0.89      8711
           1       0.94      0.88      0.91      8815
           2       0.98      0.91      0.94      8904
           3       0.94      0.95      0.94      8800
           4       1.00      0.99      1.00      8805

    accuracy                           0.94     44035
   macro avg       0.94      0.94      0.94     44035
weighted avg       0.94      0.94      0.94     44035

[[8171  378   95   64    3]
 [ 945 7799   49   22    0]
 [ 224   89 8108  476    7]
 [ 375   29   31 8364    1]
 [  36    1   14    0 8754]]

Training... epoch 12




    Percent trained: 100.0%  Time elapsed: 23.74 min
    Train acc: 0.934
    Train loss: 0.181
Testing
    Test acc: 0.936
    Test loss: 0.178
              precision    recall  f1-score   support

           0       0.87      0.90      0.88      8711
           1       0.94      0.89      0.91      8815
           2       0.96      0.92      0.94      8904
           3       0.92      0.97      0.95      8800
           4       0.99      1.00      1.00      8805

    accuracy                           0.94     44035
   macro avg       0.94      0.94      0.94     44035
weighted avg       0.94      0.94      0.94     44035

[[7797  423  160  320   11]
 [ 848 7863   43   47   14]
 [ 228   96 8223  335   22]
 [  87    1  135 8577    0]
 [  11    0   21    4 8769]]

Training... epoch 13




    Percent trained: 100.0%  Time elapsed: 22.86 min
    Train acc: 0.932
    Train loss: 0.187
Testing
    Test acc: 0.913
    Test loss: 0.237
              precision    recall  f1-score   support

           0       0.76      0.94      0.84      8711
           1       0.96      0.77      0.86      8815
           2       0.93      0.94      0.94      8904
           3       0.96      0.91      0.94      8800
           4       1.00      1.00      1.00      8805

    accuracy                           0.91     44035
   macro avg       0.92      0.91      0.91     44035
weighted avg       0.92      0.91      0.91     44035

[[8175  213  127  184   12]
 [1884 6820   87   19    5]
 [ 325   28 8408  128   15]
 [ 354    9  410 8027    0]
 [  10    0   24    0 8771]]

Training... epoch 14




    Percent trained: 100.0%  Time elapsed: 22.29 min
    Train acc: 0.933
    Train loss: 0.185
Testing
    Test acc: 0.942
    Test loss: 0.161
              precision    recall  f1-score   support

           0       0.92      0.85      0.88      8711
           1       0.91      0.93      0.92      8815
           2       0.97      0.95      0.96      8904
           3       0.93      0.98      0.95      8800
           4       0.99      1.00      0.99      8805

    accuracy                           0.94     44035
   macro avg       0.94      0.94      0.94     44035
weighted avg       0.94      0.94      0.94     44035

[[7377  784  143  374   33]
 [ 504 8209   40   58    4]
 [ 100   60 8465  247   32]
 [  21    6  102 8659   12]
 [   7    2    8    9 8779]]

Training... epoch 15




    Percent trained: 100.0%  Time elapsed: 21.88 min
    Train acc: 0.935
    Train loss: 0.179
Testing
    Test acc: 0.933
    Test loss: 0.181
              precision    recall  f1-score   support

           0       0.84      0.91      0.88      8711
           1       0.93      0.88      0.90      8815
           2       0.95      0.94      0.94      8904
           3       0.95      0.95      0.95      8800
           4       1.00      0.98      0.99      8805

    accuracy                           0.93     44035
   macro avg       0.93      0.93      0.93     44035
weighted avg       0.93      0.93      0.93     44035

[[7970  476  156  108    1]
 [ 966 7738   93   18    0]
 [ 163   63 8364  305    9]
 [ 297   47  115 8341    0]
 [  62    9   72    6 8656]]



In [15]:
# Test set
print("Number of Parameters:", count_parameters(odenet), end='\n\n')
epoch_eval(test_dl, odenet, device)

Number of Parameters: 6885

Testing




    Test acc: 0.923
    Test loss: 0.207
              precision    recall  f1-score   support

           0       0.84      0.91      0.87       800
           1       0.93      0.86      0.90       800
           2       0.96      0.92      0.94       800
           3       0.84      0.96      0.90       300
           4       1.00      0.98      0.99       800

    accuracy                           0.92      3500
   macro avg       0.92      0.93      0.92      3500
weighted avg       0.93      0.92      0.92      3500

[[728  46  15  11   0]
 [103 689   5   3   0]
 [ 17   4 739  40   0]
 [ 11   0   2 287   0]
 [  7   0   6   0 787]]

