<a href="https://colab.research.google.com/github/GitTeaching/Predicting-using-Neural-ODE/blob/main/Neural_ODE_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Neural Ordinary Differential Equations, based on [Chen et al., 2018](https://arxiv.org/abs/1806.07366)

#### - Dataset used: ECG Heartbeat Categorization Dataset - https://www.kaggle.com/shayanfazeli/heartbeat
#### - Code adapted from:

https://medium.com/analytics-vidhya/intro-to-neural-odes-part-3-the-basics-9697b3bd1946

https://github.com/abaietto/neural_ode_classification/blob/master/ECG_Classification.ipynb

## Load Data from Kaggle

In [None]:
from google.colab import files
files.upload()

!pip install -q kaggle

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

!kaggle datasets download shayanfazeli/heartbeat

Saving kaggle.json to kaggle.json
Downloading heartbeat.zip to /content
 90% 89.0M/98.8M [00:01<00:00, 75.3MB/s]
100% 98.8M/98.8M [00:01<00:00, 92.8MB/s]


In [None]:
!ls /content
!unzip heartbeat.zip

heartbeat.zip  kaggle.json  sample_data
Archive:  heartbeat.zip
  inflating: mitbih_test.csv         
  inflating: mitbih_train.csv        
  inflating: ptbdb_abnormal.csv      
  inflating: ptbdb_normal.csv        


In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

%matplotlib inline
sns.set_style('darkgrid')

## Data loading and splitting

In [None]:
# The MIT-BIH CSVs have no header row, so columns are labelled 0..187.
mit_train, mit_test = (
    pd.read_csv(csv_path, header=None)
    for csv_path in ('mitbih_train.csv', 'mitbih_test.csv')
)

print(f"Train shape: {mit_train.shape}")
print(f"Test shape: {mit_test.shape}")

Train shape: (87554, 188)
Test shape: (21892, 188)


In [None]:
# Separate target from data: columns 0-186 are the features, column 187 is the class label.
X_train, y_train = mit_train.loc[:, :186], mit_train[187]
X_test, y_test = mit_test.loc[:, :186], mit_test[187]

In [None]:
# Proportions of each class in the training labels, rounded to 2 decimals.
# Class 0 dominates (see output), so raw accuracy should be read with that in mind.
mit_train[187].value_counts(normalize=True).sort_index().round(2)

0.0    0.83
1.0    0.03
2.0    0.07
3.0    0.01
4.0    0.07
Name: 187, dtype: float64

## Preparing data and converting it to PyTorch tensors

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

In [None]:
# pandas -> NumPy -> torch. torch.from_numpy shares memory with each array;
# dtype casting (float inputs, long labels) happens per batch in loss_batch.
X_train = torch.from_numpy(X_train.values)
y_train = torch.from_numpy(y_train.values)
X_test = torch.from_numpy(X_test.values)
y_test = torch.from_numpy(y_test.values)

In [None]:
# Convert to 3D tensor: insert a channel axis, (N, 187) -> (N, 1, 187),
# which is the (batch, channels, length) layout nn.Conv1d expects.
X_train, X_test = X_train[:, None, :], X_test[:, None, :]

In [None]:
# Batch size for training; evaluation batches are twice as large
# (validation in `fit` runs under torch.no_grad()).
bs = 128

train_ds, test_ds = TensorDataset(X_train, y_train), TensorDataset(X_test, y_test)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=2 * bs)

## Install torchdiffeq and define the models (ResBlock and Neural ODE block)

In [None]:
import time

In [None]:
!pip install torchdiffeq

Collecting torchdiffeq
  Downloading https://files.pythonhosted.org/packages/67/af/377e42c20058f4891dedc1827c1a6b7b16772a452d18097d05c1db06b338/torchdiffeq-0.1.1-py3-none-any.whl
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.1.1


### ResBlock class, ConcatConv1d class, ODEfunc class, ODENet class, and Flatten class

In [None]:
"""
ResNet and ODENet classes for ECG classification.
Code adapted from:
https://github.com/rtqichen/torchdiffeq/blob/master/examples/odenet_mnist.py
"""

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint


def norm(dim, max_groups=32):
    """
    Group normalization to improve model accuracy and training speed.

    nn.GroupNorm requires ``dim`` to be divisible by the number of groups, so
    this picks the largest group count <= ``max_groups`` (and <= ``dim``) that
    divides ``dim``. For every ``dim`` the former ``min(32, dim)`` rule
    accepted, the result is unchanged; other widths (e.g. 48) now work instead
    of raising.

    Parameters
    ----------
    dim : int
        Number of channels to normalize.
    max_groups : int
        Upper bound on the number of groups (default 32).

    Returns
    -------
    nn.GroupNorm

    Raises
    ------
    ValueError
        If ``dim`` or ``max_groups`` is smaller than 1.
    """
    if dim < 1:
        raise ValueError(f"dim must be a positive integer, got {dim}")
    if max_groups < 1:
        raise ValueError(f"max_groups must be a positive integer, got {max_groups}")
    groups = min(max_groups, dim)
    # Walk down to the nearest divisor; terminates at 1 at the latest.
    while dim % groups:
        groups -= 1
    return nn.GroupNorm(groups, dim)


class ResBlock(nn.Module):
    """
    Pre-activation residual block used to construct the ResNet baseline.

    Applies (GroupNorm -> ReLU -> Conv1d) twice and adds the input back.
    Both convolutions use kernel 3, padding 1, stride 1, so the channel count
    and sequence length are preserved and the identity shortcut fits.
    """
    def __init__(self, dim):
        super().__init__()
        # Submodules are created in this order so parameter init is reproducible under a fixed seed.
        self.gn1 = norm(dim)
        self.conv1 = nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=False)
        self.gn2 = norm(dim)
        self.conv2 = nn.Conv1d(dim, dim, kernel_size=3, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = x
        for gn, conv in ((self.gn1, self.conv1), (self.gn2, self.conv2)):
            # gn(...) returns a fresh tensor, so the in-place ReLU never touches `x`.
            out = conv(self.relu(gn(out)))
        # Identity shortcut.
        return out + x


class ConcatConv1d(nn.Module):
    """
    1d convolution concatenated with time for usage in ODENet.

    The scalar time ``t`` is broadcast into one extra input channel that is
    prepended to ``x`` before the convolution, so the layer sees ``dim_in + 1``
    channels. With ``transpose=True`` a ConvTranspose1d is used instead.
    """
    def __init__(self, dim_in, dim_out, kernel_size=3, stride=1, padding=0, bias=True, transpose=False):
        super().__init__()
        conv_cls = nn.ConvTranspose1d if transpose else nn.Conv1d
        # +1 input channel for the time value.
        self._layer = conv_cls(
            dim_in + 1,
            dim_out,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        )

    def forward(self, t, x):
        first_channel = x.narrow(1, 0, 1)
        # Multiplying (rather than filling) keeps `t` differentiable when it is a tensor.
        time_channel = t * torch.ones_like(first_channel)
        stacked = torch.cat((time_channel, x), dim=1)
        return self._layer(stacked)


class ODEfunc(nn.Module):
    """
    Network architecture for ODENet: the dynamics f(t, x).

    Two rounds of (GroupNorm -> ReLU -> time-concatenated Conv1d), followed by
    a final GroupNorm. Kernel 3 / stride 1 / padding 1 keep the shape of ``x``.
    Every call increments ``nfe`` (number of function evaluations), which
    ODENet exposes through its ``nfe`` property.
    """
    def __init__(self, dim):
        super().__init__()
        # Creation order fixed so parameter init matches under a given seed.
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv1d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv1d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)
        self.nfe = 0    # Number of function evaluations

    def forward(self, t, x):
        self.nfe += 1
        hidden = self.conv1(t, self.relu(self.norm1(x)))
        hidden = self.conv2(t, self.relu(self.norm2(hidden)))
        return self.norm3(hidden)


class ODENet(nn.Module):
    """
    Neural ODE.

    Integrates ``odefunc`` from t=0 to t=1, using the input as the initial
    state, and returns the state at t=1. Uses an ODE solver (dopri5 by
    default) to yield the model output. Backpropagation is done with the
    adjoint method as described in https://arxiv.org/abs/1806.07366.

    Parameters
    ----------
    odefunc : nn.Module
        network architecture, called as ``odefunc(t, x)``
    rtol : float
        relative tolerance of ODE solver
    atol : float
        absolute tolerance of ODE solver
    method : str or None
        solver name forwarded to torchdiffeq; None keeps the library default
    """
    def __init__(self, odefunc, rtol=1e-3, atol=1e-3, method=None):
        super(ODENet, self).__init__()
        self.odefunc = odefunc
        self.integration_time = torch.tensor([0, 1]).float()
        self.rtol = rtol
        self.atol = atol
        self.method = method

    def forward(self, x):
        # Not a registered buffer, so match the input's type (dtype/device) on every call.
        self.integration_time = self.integration_time.type_as(x)
        # Tolerances are passed by keyword: recent torchdiffeq releases declare
        # them keyword-only, and positional arguments raise a TypeError there.
        out = odeint_adjoint(
            self.odefunc,
            x,
            self.integration_time,
            rtol=self.rtol,
            atol=self.atol,
            method=self.method,
        )
        # `out` stacks the solution at each requested time; index 1 is t=1.
        return out[1]

    # Update number of function evaluations (nfe)
    @property
    def nfe(self):
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class Flatten(nn.Module):
    """
    Flatten feature maps for input to linear layer.

    Keeps the leading (batch) dimension and merges all others:
    (N, d1, d2, ...) -> (N, d1 * d2 * ...).
    """
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # torch.flatten also handles non-contiguous inputs (copying only when
        # needed), where `view` raises, and avoids building a shape tensor per call.
        return torch.flatten(x, start_dim=1)

    
def count_parameters(model):
    """
    Count number of tunable parameters.

    Sums the element counts of every parameter of ``model`` whose
    ``requires_grad`` flag is set; frozen parameters are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

### Helpers adapted from https://pytorch.org/tutorials/beginner/nn_tutorial.html

In [None]:
# Helpers adapted from https://pytorch.org/tutorials/beginner/nn_tutorial.html

def get_model(is_odenet=True, dim=64, adam=False, **kwargs):
    """
    Initialize ResNet or ODENet with optimizer.

    Architecture: a convolutional stem (1 -> ``dim`` channels, then two
    stride-2 convolutions that roughly halve the length each), a feature
    stage (one ODENet block, or six ResBlocks when ``is_odenet`` is False),
    and a head (norm, ReLU, global average pool, flatten, linear to 5 classes).

    Parameters
    ----------
    is_odenet : bool
        Use the Neural ODE feature block instead of the residual blocks.
    dim : int
        Channel width of all hidden layers.
    adam : bool
        Use Adam with library defaults instead of SGD(lr=0.1, momentum=0.9).
    **kwargs
        Forwarded to ``ODENet`` (e.g. ``rtol``, ``atol``); unused for the ResNet.

    Returns
    -------
    (nn.Sequential, torch.optim.Optimizer)
    """
    # Stem; layers are created in the same order as before so seeded init matches.
    stem = [nn.Conv1d(1, dim, 3, 1)]
    for _ in range(2):
        stem.extend([norm(dim), nn.ReLU(inplace=True), nn.Conv1d(dim, dim, 4, 2, 1)])

    if is_odenet:
        body = [ODENet(ODEfunc(dim), **kwargs)]
    else:
        body = [ResBlock(dim) for _ in range(6)]

    head = [
        norm(dim),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool1d(1),
        Flatten(),
        nn.Linear(dim, 5),
    ]

    model = nn.Sequential(*stem, *body, *head)

    if adam:
        opt = optim.Adam(model.parameters())
    else:
        opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, opt


def loss_batch(model, loss_func, xb, yb, opt=None):
    """
    Calculate loss and update weights if training.

    Inputs are cast to float and targets to long before the loss is computed.
    When ``opt`` is given, one optimizer step is taken and gradients are
    cleared afterwards.

    Returns
    -------
    (float, int)
        The batch loss as a Python number and the number of samples in ``xb``.
    """
    batch_loss = loss_func(model(xb.float()), yb.long())

    is_training = opt is not None
    if is_training:
        batch_loss.backward()
        opt.step()
        opt.zero_grad()

    return batch_loss.item(), len(xb)


def fit(epochs, model, loss_func, opt, train_dl, valid_dl):
    """
    Train neural network model.

    For each epoch: one pass over ``train_dl`` updating weights with ``opt``
    (printing a single self-overwriting progress line), then the
    sample-weighted mean loss over ``valid_dl`` is printed.

    Parameters
    ----------
    epochs : int
        Number of passes over the training data.
    model : nn.Module
    loss_func : callable
        Called as ``loss_func(logits, targets)``, e.g. ``F.cross_entropy``.
    opt : torch.optim.Optimizer
    train_dl : iterable of (xb, yb)
        Must support ``len()`` (used for the progress percentage).
    valid_dl : iterable of (xb, yb)
        May be empty; validation is then reported as unavailable.
    """
    num_batches = len(train_dl)

    for epoch in range(epochs):
        print(f"Training... epoch {epoch + 1}")

        model.train()   # Set model to training mode
        start = time.time()
        for batch_count, (xb, yb) in enumerate(train_dl, start=1):
            percent = round(batch_count / num_batches * 100, 1)
            elapsed = round((time.time() - start) / 60, 1)
            # end='\r' rewrites the same console line instead of one line per batch.
            print(f"    Percent trained: {percent}%  Time elapsed: {elapsed} min", end='\r')
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()    # Set model to validation mode
        with torch.no_grad():
            results = [loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]

        if not results:
            # zip(*[]) cannot be unpacked into two names; report instead of crashing.
            print("\n    val loss: n/a (empty validation set)\n")
            continue

        losses, nums = zip(*results)
        # Weight each batch's mean loss by its size so a short last batch is not over-counted.
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)

        print(f"\n    val loss: {round(val_loss, 2)}\n")

## ResNet modeling

#### Architecture: downsampling layers → ResBlock feature layers → fully connected layers

In [None]:
# Seed right before building so the ResNet's initial weights (and the
# shuffle order of the fit that follows) are reproducible on re-run.
torch.manual_seed(0)
resnet, resopt = get_model(is_odenet=False, adam=False)

In [None]:
# Note: test_dl doubles as the validation set, so "val loss" is measured on the test split.
fit(
    epochs=2,
    model=resnet,
    loss_func=F.cross_entropy,
    opt=resopt,
    train_dl=train_dl,
    valid_dl=test_dl,
)

Training... epoch 1

    val loss: 0.25

Training... epoch 2

    val loss: 0.14



## ODENet modeling

#### Architecture: downsampling layers → ODENet feature block → fully connected layers

In [None]:
# Seed right before building so the ODENet's initial weights (and the
# shuffle order of the fit that follows) are reproducible on re-run.
torch.manual_seed(0)
odenet, odeopt = get_model(adam=False, rtol=1e-3, atol=1e-3)

In [None]:
# Note: test_dl doubles as the validation set, so "val loss" is measured on the test split.
fit(
    epochs=2,
    model=odenet,
    loss_func=F.cross_entropy,
    opt=odeopt,
    train_dl=train_dl,
    valid_dl=test_dl,
)

Training... epoch 1

    val loss: 0.21

Training... epoch 2

    val loss: 0.12



## Accuracy and model size (number of tunable parameters)

In [None]:
def accuracy(model, X_test, y_test, batch_size=None):
    """
    Fraction of samples whose predicted class (argmax over dim 1 of the
    model output) equals the label in ``y_test``.

    Parameters
    ----------
    model : nn.Module
        Returns per-class scores with classes along dim 1; set to eval mode here.
    X_test : torch.Tensor
        Inputs; cast to float32 before the forward pass.
    y_test : torch.Tensor
        Class labels (compared numerically, so float-valued labels work).
    batch_size : int or None
        If given, run the forward pass in chunks of this many samples to bound
        peak memory; None (default) evaluates everything in a single pass.

    Returns
    -------
    numpy.float64
    """
    model.eval()
    inputs = X_test.float()
    chunks = (inputs,) if batch_size is None else torch.split(inputs, batch_size)
    with torch.no_grad():
        # argmax over raw logits: softmax is monotonic, so applying it first does
        # not change the predicted class. .cpu() so .numpy() works for GPU outputs.
        preds = torch.cat([model(chunk).argmax(dim=1).cpu() for chunk in chunks])
    return (preds.numpy() == y_test.cpu().numpy()).mean()

In [None]:
for model_name, model in (("ResNet", resnet), ("ODENet", odenet)):
    print(f"{model_name} accuracy: {round(accuracy(model, X_test, y_test), 3)}")

ResNet accuracy: 0.961
ODENet accuracy: 0.967


In [None]:
print("Number of tunable parameters in...")
for model_name, model in (("ResNet", resnet), ("ODENet", odenet)):
    print(f"    {model_name}: {count_parameters(model)}")

Number of tunable parameters in...
    ResNet: 182853
    ODENet: 59333
