# Momentum survival: MNIST use case

In this example, we use the PyTorch Lightning framework to further show how easy it is to use `torchsurv`.

In [1]:
import torch
import pytorch_lightning as L
from torchvision.models import resnet18
from torchvision.transforms import v2

from torchsurv.loss.cox import neg_partial_log_likelihood
from torchsurv.loss.momentum import Momentum

In [2]:
# For simplicity, the MNIST datamodule and the Lightning modules are already implemented in helpers_momentum; see that file for details
from helpers_momentum import MNISTDataModule, LitMomentum, LitMNIST

In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
BATCH_SIZE = 256

## Model backbone

First, we need to define our model backbone. Here we use ResNet-18.

In [5]:
resnet = resnet18(weights=None)
# Accept single-channel (grayscale) images instead of RGB
resnet.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
# Output a single log hazard per image
resnet.fc = torch.nn.Linear(in_features=resnet.fc.in_features, out_features=1)
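
To make "log hazards" concrete, the sketch below computes a Cox negative partial log-likelihood in plain Python for a handful of subjects. It only illustrates the kind of quantity a single-output network feeds into and is not torchsurv's `neg_partial_log_likelihood`: the function name `cox_neg_partial_log_likelihood` is hypothetical, tied event times are not handled, and averaging over observed events is an assumption that may differ from the library's normalization.

```python
import math


def cox_neg_partial_log_likelihood(log_hz, event, time):
    """Hypothetical helper: average negative log partial likelihood (no tie handling)."""
    total = 0.0
    n_events = 0
    for i in range(len(time)):
        if not event[i]:
            continue  # censored subjects only contribute through risk sets
        # Risk set: every subject still under observation at time[i]
        risk_set = [log_hz[j] for j in range(len(time)) if time[j] >= time[i]]
        # Numerically stable log-sum-exp over the risk set
        top = max(risk_set)
        log_denominator = top + math.log(sum(math.exp(r - top) for r in risk_set))
        total += log_hz[i] - log_denominator
        n_events += 1
    if n_events == 0:
        raise ValueError("At least one observed event is required.")
    return -total / n_events


time = [1.0, 2.0, 3.0]
event = [True, True, True]

# Higher log hazard for earlier events: ordering agrees with the data
loss_good = cox_neg_partial_log_likelihood([2.0, 1.0, 0.0], event, time)
# Reversed ordering: the loss should be larger
loss_bad = cox_neg_partial_log_likelihood([0.0, 1.0, 2.0], event, time)
# Uninformative predictions: all log hazards equal
loss_flat = cox_neg_partial_log_likelihood([0.0, 0.0, 0.0], event, time)
```

With equal log hazards, the partial likelihood of this toy example is 1/3 × 1/2 × 1, so `loss_flat` equals log(6) / 3. A model that ranks subjects correctly obtains a lower value, which is what training pushes toward.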

In [6]:
# Transform the images: convert to float32 tensors, then resize to 224x224
transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Resize(224, antialias=True),
    v2.Normalize(mean=(0,), std=(1,)),
])

# Sanity check
x = torch.randn((6, 1, 28, 28))
print(transforms(x).shape)
print(resnet(transforms(x)).shape)

torch.Size([6, 1, 224, 224])
torch.Size([6, 1])


## Regular training

In [7]:
model_regular = LitMNIST(resnet)


In [8]:
# Define experiment
datamodule = MNISTDataModule(batch_size=BATCH_SIZE, transforms=transforms)

In [9]:
# Define trainer
trainer = L.Trainer(
    accelerator="auto",
    logger=False,
    enable_checkpointing=False,
    fast_dev_run=5,  # Quick check: run only 5 batches per loop
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 5 batch(es). Logging and checkpointing is suppressed.


In [10]:
# Fit the model
trainer.fit(model_regular, datamodule)


  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.683    Total estimated model params size (MB)


Training: |                                                                      | 0/? [00:00<?, ?it/s]

Validation: |                                                                    | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=5` reached.


In [11]:
# Validate the model
trainer.validate(model_regular, datamodule)

Validation: |                                                                    | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      cindex_epoch           0.603299081325531
     val_loss_epoch         -68.12861633300781
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss_epoch': -68.12861633300781, 'cindex_epoch': 0.603299081325531}]
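
For readers unfamiliar with the concordance index, the sketch below shows one common definition (Harrell's C) in plain Python. It assumes that `cindex_epoch` is a concordance index in this usual sense; the helper code is not shown here, so this is not necessarily the exact estimator behind the number above, and the function name `harrell_c_index` is hypothetical.

```python
def harrell_c_index(risk, event, time):
    """Hypothetical helper: fraction of comparable pairs ranked in the right order."""
    concordant = 0.0
    comparable = 0
    n = len(time)
    for i in range(n):
        if not event[i]:
            continue  # a pair is comparable only if the earlier time is an observed event
        for j in range(n):
            if time[i] < time[j]:
                comparable += 1
                if risk[i] > risk[j]:
                    concordant += 1.0  # earlier event has the higher predicted risk
                elif risk[i] == risk[j]:
                    concordant += 0.5  # tied predictions count as half
    if comparable == 0:
        raise ValueError("No comparable pairs.")
    return concordant / comparable


time = [1.0, 2.0, 3.0]
event = [True, True, True]

c_perfect = harrell_c_index([3.0, 2.0, 1.0], event, time)
c_reversed = harrell_c_index([1.0, 2.0, 3.0], event, time)
c_constant = harrell_c_index([0.0, 0.0, 0.0], event, time)
```

Under this definition a value of 0.5 corresponds to uninformative rankings and 1.0 to a perfect ranking, which gives a rough scale for reading values around 0.6 after only five training batches.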

In [12]:
# Test the model
trainer.test(model_regular, datamodule)

Testing: |                                                                       | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      cindex_epoch          0.6242874264717102
     val_loss_epoch         -70.14149475097656
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss_epoch': -70.14149475097656, 'cindex_epoch': 0.6242874264717102}]

## Momentum training

By dividing the batch size by `factor` and passing the same value to `Momentum(n=factor)`, we should in theory provide the same amount of training information at each step while using a much smaller batch size. This is ideal for large datasets, such as medical images:

In [13]:
factor = 10
model_momentum = LitMomentum(backbone=Momentum(resnet, neg_partial_log_likelihood, n=factor, m=0.999))
# Define experiment
datamodule_momentum = MNISTDataModule(batch_size=BATCH_SIZE // factor, transforms=transforms)
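
The batch-size arithmetic behind this cell can be sketched in plain Python. The sketch assumes that `n` is the number of recent batches whose outputs are pooled before the loss is evaluated; the `memory_bank` deque is a hypothetical stand-in for that idea and not the internals of `Momentum`.

```python
from collections import deque

BATCH_SIZE = 256
factor = 10
small_batch_size = BATCH_SIZE // factor  # integer division

# Hypothetical memory bank: keeps only the `factor` most recent batches
memory_bank = deque(maxlen=factor)

pooled_sizes = []
for step in range(15):
    # Stand-in for the log hazards predicted on one small batch
    batch_log_hazards = [0.1 * step] * small_batch_size
    memory_bank.append(batch_log_hazards)
    # Samples available to the loss once the stored batches are pooled
    pooled = [value for batch in memory_bank for value in batch]
    pooled_sizes.append(len(pooled))

print(small_batch_size)
print(pooled_sizes[0], pooled_sizes[-1])
```

Each small batch holds 25 samples, and the pooled size grows from 25 to 250 once the bank is full. Because 256 is not divisible by 10, the pooled size is 250 rather than 256; a factor that divides the batch size, such as 8, would match it exactly.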

In [14]:
# Fit the model
trainer.fit(model_momentum, datamodule_momentum)


  | Name  | Type     | Params
-----------------------------------
0 | model | Momentum | 22.3 M
-----------------------------------
11.2 M    Trainable params
11.2 M    Non-trainable params
22.3 M    Total params
89.366    Total estimated model params size (MB)
`Trainer.fit` stopped: `max_steps=5` reached.
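
The parameter summary above lists as many non-trainable as trainable parameters, which is consistent with a frozen copy of the backbone kept alongside the trainable one. Assuming that `m=0.999` is the rate of an exponential moving average that moves the frozen copy toward the trained one (an assumption about `Momentum`, not something this notebook confirms), the update can be sketched in plain Python. The function `ema_update` and the toy weight lists are hypothetical.

```python
def ema_update(target, online, m):
    """Hypothetical helper: move target weights toward online weights at rate 1 - m."""
    return [m * t + (1.0 - m) * o for t, o in zip(target, online)]


target = [0.0, 0.0]   # frozen copy, never updated by gradients
online = [1.0, -2.0]  # trainable copy, held fixed here for illustration

for _ in range(5):
    target = ema_update(target, online, m=0.999)

# With fixed online weights, the gap shrinks by a factor m at every update
remaining_gap = online[0] - target[0]

# Two copies of the backbone double the estimated model size
single_copy_mb = 44.683
two_copies_mb = 2 * single_copy_mb
```

With a rate this close to 1, the frozen copy moves very slowly: after five updates it has closed only about 0.5% of the gap. The doubled size also agrees with the 89.366 MB reported in the summary.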


In [15]:
# Validate the model
trainer.validate(model_momentum, datamodule_momentum)

Validation: |                                                                    | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      cindex_epoch           0.658102810382843
     val_loss_epoch         21.859458923339844
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss_epoch': 21.859458923339844, 'cindex_epoch': 0.658102810382843}]

In [16]:
# Test the model
trainer.test(model_momentum, datamodule_momentum)

Testing: |                                                                       | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      cindex_epoch          0.6049648523330688
     val_loss_epoch          62.4331169128418
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss_epoch': 62.4331169128418, 'cindex_epoch': 0.6049648523330688}]