In [10]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import pytorch_lightning as pl

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torchdyn.core import NeuralODE
from torchdyn.nn import Augmenter
from torchdyn.utils import *

import matplotlib.pyplot as plt

In [39]:
class Learner(pl.LightningModule):
    """Lightning module that trains a classifier with cross-entropy loss.

    Every step (train / validation / test) computes the loss plus accuracy and
    F1 on the current batch and logs them under ``"<stage>/loss"``,
    ``"<stage>/acc"`` and ``"<stage>/f1"``.

    Args:
        model: Module called as ``model(x, t_span)``; its output is passed
            directly to ``F.cross_entropy`` together with the targets.
        lr: Learning rate handed to ``AdamW``.
        l2: Weight-decay coefficient handed to ``AdamW``.
        t_span: Value forwarded as the second argument of ``model`` on every
            call. ``None`` (the default) is forwarded unchanged, so ``model``
            must be able to handle it.
    """

    def __init__(self, model: nn.Module, lr: float, l2: float, t_span: torch.Tensor = None):
        super().__init__()

        self.model = model
        self.t_span = t_span

        self.lr = lr
        self.l2 = l2

        # The same two metric objects are used by all three stages.
        self.accuracy = torchmetrics.Accuracy(multiclass=True, num_classes=10)
        self.f1_score = torchmetrics.F1Score(multiclass=True, num_classes=10)

    def forward(self, x: torch.Tensor):
        """Run the wrapped model on ``x`` using the stored ``t_span``."""
        return self.model(x, self.t_span)

    def _shared_step(self, batch, stage: str):
        """Compute loss and metrics for one batch and log them under ``stage``.

        Args:
            batch: An ``(inputs, targets)`` pair.
            stage: Prefix for the logged keys (``"train"``, ``"val"`` or ``"test"``).

        Returns:
            The cross-entropy loss for the batch.
        """
        x, y = batch

        y_hat = self.forward(x)

        loss = F.cross_entropy(y_hat, y)
        acc = self.accuracy(y_hat, y)
        f1 = self.f1_score(y_hat, y)

        self.log_dict(
            {
                f"{stage}/loss": loss,
                f"{stage}/acc": acc,
                f"{stage}/f1": f1
            }
        )

        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: return the loss for one training batch."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: return the loss for one validation batch."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Lightning hook: return the loss for one test batch."""
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        """Build an AdamW optimizer over the wrapped model's parameters."""
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.l2)

        return optimizer


In [36]:
class NODEClassifier(nn.Module):
    """Classifier built as: preprocessing -> Neural ODE -> classification head.

    Args:
        preprocessing_layer: Module applied to the raw input before integration.
        vector_field: Module wrapped by ``NeuralODE`` as the ODE's vector field.
        classifier: Module applied to the final integrated state.
        solver: Solver name forwarded to ``NeuralODE``. Defaults to ``"dopri5"``.
        sensitivity: Sensitivity method forwarded to ``NeuralODE``. Defaults to
            ``"adjoint"``.
    """

    def __init__(
        self,
        preprocessing_layer: nn.Module,
        vector_field: nn.Module,
        classifier: nn.Module,
        solver: str = "dopri5",
        sensitivity: str = "adjoint",
    ):
        super().__init__()

        self.preprocessing_layer = preprocessing_layer
        self.node = NeuralODE(vector_field, solver=solver, order=1, sensitivity=sensitivity)
        self.classifier = classifier

    def forward(self, x: torch.Tensor, t_span: torch.Tensor):
        """Preprocess ``x``, integrate it over ``t_span`` and classify the result.

        Args:
            x: Input batch.
            t_span: Passed unchanged to the Neural ODE as its second argument.

        Returns:
            The output of ``classifier`` applied to the last entry of the
            integrated trajectory.
        """
        x = self.preprocessing_layer(x)

        # The Neural ODE returns a pair; only the second element is used here.
        # It is indexed along its first axis below - presumably the time axis,
        # so that [-1] is the state at the end of the interval (confirm against
        # the torchdyn version in use).
        _, x = self.node(x, t_span)

        return self.classifier(x[-1])

In [4]:
class ConvBlock(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and a ``SiLU`` activation.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution; also the
            feature count of the batch-norm layer.
        **kwargs: Extra keyword arguments forwarded to ``nn.Conv2d``
            (e.g. ``kernel_size``, ``padding``).
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__()

        # The convolution is created without a bias term; it is immediately
        # followed by batch normalisation.
        convolution = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        normalisation = nn.BatchNorm2d(out_channels)
        activation = nn.SiLU()

        self.conv_block = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x: torch.Tensor):
        """Apply convolution, normalisation and activation to ``x``."""
        features = self.conv_block(x)
        return features

In [37]:
# Channel count used by every block of the vector field and by the head's input.
STATE_CHANNELS = 16
# Convolution settings shared by all ConvBlocks below.
CONV_3X3 = dict(kernel_size=3, padding=1)

# Augment along dim 1 with STATE_CHANNELS - 1 = 15 extra dims - presumably so a
# single-channel image ends up with STATE_CHANNELS channels (confirm against
# torchdyn's Augmenter).
preprocessing_layer = Augmenter(augment_idx=1, augment_dims=STATE_CHANNELS - 1)

# Vector field: three identical channel-preserving conv blocks.
f = nn.Sequential(
    *(ConvBlock(STATE_CHANNELS, STATE_CHANNELS, **CONV_3X3) for _ in range(3))
)

# Head: reduce to one channel, flatten, then map 784 features to 10 outputs.
# 784 = 28 * 28, i.e. this assumes 28x28 spatial inputs.
classifier = nn.Sequential(
    ConvBlock(STATE_CHANNELS, 1, **CONV_3X3),
    nn.Flatten(),
    nn.Linear(28 * 28, 10)
)

model = NODEClassifier(preprocessing_layer, f, classifier)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


In [29]:
class Transforms:
    """Callable image transform: albumentations ``Normalize`` then ``ToTensorV2``.

    Args:
        mean: Mean passed to ``A.Normalize``. Defaults to 0.5.
        std: Standard deviation passed to ``A.Normalize``. Defaults to 0.5.
    """

    def __init__(self, mean: float = 0.5, std: float = 0.5):
        self.transforms = A.Compose([
            A.Normalize(mean=mean, std=std),
            ToTensorV2()
        ])

    def __call__(self, img):
        """Convert ``img`` to a NumPy array, run the pipeline and return its ``"image"`` entry."""
        # The pipeline is called with the image as a keyword argument and
        # returns a mapping; only its "image" entry is needed.
        # `np` must be imported explicitly as `import numpy as np` at the top
        # of the notebook rather than relied on via a star import.
        return self.transforms(image=np.array(img))["image"]

transform = Transforms()

In [30]:
# Location where the MNIST files are stored / downloaded to.
DATA_ROOT = "dataset/"
# Settings shared by both data loaders.
LOADER_OPTIONS = dict(batch_size=32, num_workers=8)

# Note: the "validation" set is MNIST's `train=False` split.
mnist_train = MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
mnist_val = MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training loader shuffles its samples.
loader_train = DataLoader(mnist_train, shuffle=True, **LOADER_OPTIONS)
loader_val = DataLoader(mnist_val, shuffle=False, **LOADER_OPTIONS)

In [40]:
# Wrap the model for training. No t_span is given, so Learner forwards None to the model.
learner = Learner(model, lr=3e-4, l2=0.001)

# `Trainer(gpus=-1)` is deprecated in recent Lightning releases and removed in 2.0;
# `accelerator="gpu", devices=-1` is the equivalent request for all visible GPUs.
trainer = pl.Trainer(accelerator="gpu", devices=-1)
trainer.fit(learner, train_dataloaders=loader_train, val_dataloaders=loader_val)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /home/jasiek/python-projects/deep-learning-in-pytorch/node_mnist_classification/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type           | Params
--------------------------------------------
0 | model    | NODEClassifier | 15.0 K
1 | accuracy | Accuracy       | 0     
2 | f1_score | F1Score        | 0     
--------------------------------------------
15.0 K    Trainable params
0         Non-trainable params
15.0 K    Total params
0.060     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
