In [14]:
%load_ext autoreload

%autoreload 2
    
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from pathlib import Path
DEST_DIR = Path('./MNIST_output')
PYTORCH_ENABLE_MPS_FALLBACK=1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from torchir.utils import IRDataSet

Then we create a specific dataset class that selects MNIST instances with a specific class label. Here we take all number 9s.

In [4]:
class MNISTSubSet(MNIST):
    '''
    A Dataset class that selects a single type of MNIST digit.
    '''
    def __init__(self, label, rng=np.random.default_rng(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert(label >= 0 and label <= 9)
        idcs = torch.where(self.targets == label)
        
        self.data = self.data[idcs]
        self.targets = self.targets[idcs]

        self.transform = transform
        self.rng = rng
        
    def __getitem__(self, idx):
        return super().__getitem__(idx)[0] # only return image


Next we set up the dataloaders.

In [5]:
# Create a new instance of the dataset
rng = np.random.default_rng(808)

# transforms.ToTensor(): Converts a PIL (Python Imaging Library) image or a numpy array to a PyTorch tensor. Additionally, it scales the 
# image's pixel intensity values to the range [0, 1].
# transforms.Normalize(): Normalizes a tensor image with mean and standard deviation. Here, it takes the values of the tensor 
# (originally in [0,1] due to ToTensor()) and transforms them to the range [-1,1]
transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.5), (0.5)),
                             ])
ds_train_subset = MNISTSubSet(label=9, rng=rng, root='./MNIST_datasets/',  transform=transform, download=True, train=True)

We reserve part of the training set for validation. Note that we only select a small number for validation. This will become clear immediately after this block.

In [6]:
val_set_size = 20
train_set_size = len(ds_train_subset) - val_set_size
ds_train_subset, ds_validation_subset = torch.utils.data.random_split(ds_train_subset, [train_set_size, val_set_size], 
                                                        generator=torch.Generator().manual_seed(808))
ds_test_subset = MNISTSubSet(label=9, rng=rng, root='./MNIST_datasets/',  transform=transform, download=True, train=False)
print(f'Training subset size: {len(ds_train_subset)}')
print(f'Validation subset size: {len(ds_validation_subset)}')
print(f'Test subset size: {len(ds_test_subset)}')

Training subset size: 5929
Validation subset size: 20
Test subset size: 1009


Now we will employ a convenience class to convert the training and validation data sets into image registration sets. The new class will provide all possible permutations of the input dataset. Note that this will heavily increase the number of training instances (i.e. there are many fixed and moving image permutations). Note that we will use the test set later in a different manner.

In [7]:
# Each entry in the dataset will be a dictionary containing a 'fixed' and a 'moving' image.
ds_train = IRDataSet(ds_train_subset)
ds_validation = IRDataSet(ds_validation_subset)

# IRDataSet will create pairs of x and all other items in ds to serve as fixed and moving images for registration tasks. 
# That's why the length of this dataset is len(ds) × len(ds).
print(f'Training IR set size: {len(ds_train)}')
print(f'Validation IR set size: {len(ds_validation)}')

Training IR set size: 35153041
Validation IR set size: 400


Note that each instance of an IR dataset is a permutation of a fixed and moving image. This results in a very high number of possible permutations. Since Pytorch Lightning does not play well with iterations and rather likes epochs, we limit the number of permutations per epoch by setting up our own data samplers:

In [8]:
batch_size = 32
training_batches = 100
validation_batches = 10

# RandomSampler: Samples elements randomly from a given list of indices, with replacement.
# num_samples: ensures that we get the desired number of training samples for each epoch.
train_sampler = torch.utils.data.RandomSampler(ds_train, replacement=True, 
                                               num_samples=training_batches*batch_size, 
                                               generator=torch.Generator().manual_seed(808))

# DataLoader: Combines a dataset and a sampler, and provides an iterable over the given dataset.
train_loader = torch.utils.data.DataLoader(ds_train, batch_size, sampler=train_sampler)

# Since no sampler is provided, it will simply iterate over the dataset in its original order.
val_loader = torch.utils.data.DataLoader(ds_validation, batch_size)

We have setup all our data classes and now we can start the image registration experiments.

If you want to inspect the data, use tensorboard. The logs are stored in the specified `DEST_DIR`.
> tensorboard --logdir=./output/lightning_logs

# DLIR Framework

In [2] I also demonstrated that multiple coarse-to-fine registration layers improve image registration. Now let's implement this using the DLIRFramework module. I chose a dynamic implementation where we add a layer. We train the layer. We add another layer and fix the weight of the previous layer. We train the new layer. Etc...

In [9]:
import pytorch_lightning as pl
from torchir.networks import DIRNet
from torchir.metrics import NCC
from torchir.transformers import BsplineTransformer
from torchir.dlir_framework import DLIRFramework

In [10]:
class LitDLIRFramework(pl.LightningModule):
    def __init__(self, only_last_trainable=True):
        super().__init__()
        self.dlir_framework = DLIRFramework(only_last_trainable=only_last_trainable)
        self.add_stage = self.dlir_framework.add_stage
        self.metric = NCC()
    
    def configure_optimizers(self):
        lr = 0.001
        weight_decay = 0
        optimizer = torch.optim.Adam(self.dlir_framework.parameters(), lr=lr, weight_decay=weight_decay, amsgrad=True)
        return {'optimizer': optimizer}

    def forward(self, fixed, moving):
        warped = self.dlir_framework(fixed, moving)
        return warped
    
    def training_step(self, batch, batch_idx):
        warped = self(batch['fixed'], batch['moving'])
        loss = self.metric(batch['fixed'], warped)
        self.log('NCC/training', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        warped = self(batch['fixed'], batch['moving'])
        loss = self.metric(batch['fixed'], warped)
        self.log('NCC/validation', loss)
        return loss  

Add a BSpline layer with an 8x8 grid:

In [15]:
model = LitDLIRFramework()
model.add_stage(network=DIRNet(grid_spacing=(8, 8), kernels=16, num_conv_layers=5, num_dense_layers=2),
                transformer=BsplineTransformer(ndim=2, upsampling_factors=(8, 8)))
trainer = pl.Trainer(default_root_dir=DEST_DIR,
                     log_every_n_steps=50,
                     val_check_interval=50,
                     max_epochs=100)
trainer.fit(model, train_loader, val_loader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name           | Type          | Params
-------------------------------------------------
0 | dlir_framework | DLIRFramework | 11.2 K
1 | metric         | NCC           | 0     
-------------------------------------------------
10.2 K    Trainable params
1.0 K     Non-trainable params
11.2 K    Total params
0.045     Total estimated model params size (MB)


Epoch 0:   0%|          | 0/100 [00:00<?, ?it/s]                           

NotImplementedError: The operator 'aten::grid_sampler_2d_backward' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

Add a finer 4x4 grid and train. The model, now augmented with the BSpline layer, is trained again on the same datasets.

In [None]:
model.add_stage(network=DIRNet(grid_spacing=(4, 4), kernels=16, num_conv_layers=5, num_dense_layers=2),
                transformer=BsplineTransformer(ndim=2, upsampling_factors=(4, 4)))
trainer = pl.Trainer(default_root_dir=DEST_DIR,
                     log_every_n_steps=50,
                     val_check_interval=50,
                     max_epochs=100,
                     gpus=1)
trainer.fit(model, train_loader, val_loader)

Let's store the checkpoint

In [None]:
trainer.save_checkpoint(DEST_DIR / 'mnist_dlir_8_4.ckpt')
torch.save(model.state_dict(), DEST_DIR / 'mnist_dlir_8_4.pt')

Now do ultra-fast coarse-to-fine image registration:

In [None]:
avg_moving = np.zeros((28, 28), dtype=float)
avg_warped = np.zeros((28, 28), dtype=float)
model = model.cuda()
fixed = ds_test_subset[0]
for moving in tqdm(ds_test_subset):
    warped = model(fixed[None].cuda(), moving[None].cuda()).detach().squeeze().cpu().numpy()
    avg_moving += moving.squeeze().cpu().numpy() / len(ds_test_subset)
    avg_warped += warped / len(ds_test_subset)

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(10, 3.5))
axarr[0].imshow(-fixed.squeeze().cpu().numpy(), cmap='gray', vmin=-1, vmax=1)
axarr[1].imshow(-avg_moving, cmap='gray', vmin=-1, vmax=1)
axarr[2].imshow(-avg_warped, cmap='gray', vmin=-1, vmax=1)
for ax in axarr.ravel():
    ax.axis('off')

axarr[0].set_title('fixed image')
axarr[1].set_title('images before registration')
axarr[2].set_title('images after registration')

fig.suptitle('DLIR Framework: coarse-to-fine b-splines');

# Loading a model

Loading a model for later use can be done as follows:

In [None]:
model = LitDLIRFramework()
model.add_stage(GlobalTransformer(AIRNet(kernels=16)))
model.add_stage(BsplineTransformer(DIRNet(grid_spacing=(8, 8), kernels=16, num_conv_layers=5, num_dense_layers=2)))
model.add_stage(BsplineTransformer(DIRNet(grid_spacing=(4, 4), kernels=16, num_conv_layers=5, num_dense_layers=2)))
model.load_state_dict(torch.load(DEST_DIR / 'mnist_dlir_affine_8_4.ckpt')['state_dict'])