MONAI API reference: https://docs.monai.io/en/stable/api.html

In [None]:
import torch
import monai
import matplotlib.pyplot as plt
import numpy as np

## Data

In [None]:
# Load the datalist (stored as a NumPy object array) and preview the first records.
# NOTE: allow_pickle=True unpickles arbitrary Python objects, which can execute
# code on load -- only use a datalist file from a trusted source.
path = '/mount/src/data/datalist.npy'
raw_records = np.load(path, allow_pickle=True)
datalist = [record for record in raw_records]
datalist[:3]

In [None]:
# Shuffle the datalist in place before splitting it. A dedicated, seeded
# generator makes the train/validation split reproducible across kernel
# restarts without touching NumPy's global random state.
seed = 0
rng = np.random.default_rng(seed)
rng.shuffle(datalist)
datalist[:3]

In [None]:
# Split the shuffled datalist: the first 800 records train, the rest validate.
n_train = 800
data_train, data_val = datalist[:n_train], datalist[n_train:]

## Transforms

In [None]:
# Settings shared by the transform pipelines below.
keys = ["img", "seg"]        # dictionary keys of the image and the segmentation
spatial_size = (256,) * 2    # target spatial size passed to Resized
prob = 0.7                   # probability of each random augmentation

def convert_mask(mask, threshold=100):
    """Binarize a channel-first segmentation mask.

    Only the first channel is read. Values strictly greater than ``threshold``
    become 1 and everything else becomes 0. The leading channel axis is
    kept with size 1.

    Works for NumPy arrays and for torch tensors (including MONAI's
    ``MetaTensor``, which is a ``torch.Tensor`` subclass).

    Args:
        mask: channel-first array/tensor; presumably (C, H, W) as produced by
            ``EnsureChannelFirstd`` -- only ``mask[0]`` is used.
        threshold: intensity cut-off. The default of 100 presumably targets
            8-bit (0-255) mask images -- confirm for other encodings.

    Returns:
        Integer (int64 for torch) mask of shape ``(1, *mask.shape[1:])``.
    """
    binary = mask[0:1] > threshold  # slicing (not indexing) keeps the channel axis
    if isinstance(binary, torch.Tensor):
        # Plain torch tensors have no ``astype``; ``long`` gives the int64 result.
        return binary.long()
    return binary.astype('int')

# Training pipeline: load -> pre-process -> random augmentation.
# The pre-processing steps are identical to `val_trans` below so both splits
# reach the network in the same format.
trans = monai.transforms.Compose([monai.transforms.LoadImaged(keys), # I/O
                                  monai.transforms.EnsureChannelFirstd(keys), # Pre-processing
                                  monai.transforms.Lambdad(keys='seg', func=convert_mask), # Pre-processing: binarize mask
                                  monai.transforms.ToDeviced(keys, device='cuda'), # Pre-processing
                                  # 'area' for the image, 'nearest' for the mask so the binary labels stay binary
                                  monai.transforms.Resized(keys, spatial_size=spatial_size, mode=['area', 'nearest']), # Pre-processing
                                  monai.transforms.NormalizeIntensityd(keys='img'), # Pre-processing
                                  monai.transforms.RandAdjustContrastd(keys='img', gamma=(0.8, 3.0), prob=prob), # Augmentation
                                  monai.transforms.RandFlipd(keys, prob=prob), # Augmentation
                                  # NOTE(review): `keys` includes 'seg', so dropout holes are also filled
                                  # with 0 in the label -- confirm this is intended.
                                  monai.transforms.RandCoarseDropoutd(keys, holes=1, max_holes=10,
                                                                      spatial_size=(32, 32), max_spatial_size=(96, 96),
                                                                      dropout_holes=True, fill_value=0, prob=prob) # Augmentation
                                 ])
# Validation pipeline: the deterministic pre-processing only, no augmentation.
val_trans = monai.transforms.Compose([monai.transforms.LoadImaged(keys), 
                                      monai.transforms.EnsureChannelFirstd(keys),
                                      monai.transforms.Lambdad(keys='seg', func=convert_mask),
                                      monai.transforms.ToDeviced(keys, device='cuda'),
                                      monai.transforms.Resized(keys, spatial_size=spatial_size, mode=['area', 'nearest']),
                                      monai.transforms.NormalizeIntensityd(keys='img')
                                     ])

# Post-processing of predictions. The network emits raw logits (the loss is
# DiceLoss(sigmoid=True)), so apply the sigmoid first and then threshold the
# probability at 0.5. Thresholding the logits directly would amount to a
# ~0.62 probability cut-off.
post_trans = monai.transforms.Compose([monai.transforms.Activations(sigmoid=True),
                                       monai.transforms.AsDiscrete(threshold=0.5)])

In [None]:
# Sanity-check the training pipeline on one sample: image (left), mask (right).
test = trans(data_train[0])
img_hwc = test['img'].cpu().numpy().transpose([1, 2, 0])  # channel-first -> channel-last for imshow
seg_hw = test['seg'][0].cpu()

plt.figure(figsize=(8, 4))
for position, (picture, cmap) in enumerate([(img_hwc, None), (seg_hw, 'gray')], start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture, cmap=cmap)
plt.show()

test['img'].shape, test['seg'].shape

## Dataset and Dataloader

In [None]:
batch_size = 32

# CacheDataset caches the deterministic part of the pipeline; the random
# augmentations in `trans` are still applied each time a sample is drawn.
# Keep num_workers at its default (0): the transforms already move tensors to CUDA.
ds_train = monai.data.CacheDataset(data_train, transform=trans)
# Reshuffle every epoch so batches are not always built from the same samples
# in the same order.
dl_train = torch.utils.data.DataLoader(ds_train, batch_size=batch_size, shuffle=True)

ds_val = monai.data.CacheDataset(data_val, transform=val_trans)
# Order is irrelevant for validation: loss and Dice are aggregated over the set.
dl_val = torch.utils.data.DataLoader(ds_val, batch_size=batch_size)

## Network, loss and optimizer

In [None]:
device = torch.device("cuda")

# 2D SegResNet mapping a 3-channel image to a single output channel.
net = monai.networks.nets.SegResNet(spatial_dims=2, in_channels=3,
                                    out_channels=1, dropout_prob=0.5)
net = net.to(device)

# Quick shape check on the preview sample (a batch axis is prepended).
sample_batch = test['img'][None, ...]
net(sample_batch).shape

In [None]:
# DiceLoss applies the sigmoid itself, so the network is trained on raw logits.
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)
# The network has a single output channel, so there is no separate background
# channel to exclude: MONAI's Dice computation ignores include_background in
# that case. State the effective setting explicitly.
dice_metric = monai.metrics.DiceMetric(include_background=True, reduction="mean")

## Train
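A classic PyTorch `for` loop: each epoch trains on every batch, then evaluates on the validation set and saves the checkpoint with the best Dice score.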

In [None]:
import os

max_epochs = 500
epoch_loss_values = []     # mean training loss of each epoch
epoch_valloss_values = []  # mean validation loss of each epoch
metric_values = []         # mean validation Dice of each epoch
best_metric = 0
best_metric_epoch = -1

# torch.save does not create missing directories, so make sure the checkpoint
# folder exists before the first save below.
os.makedirs('./checkpoints', exist_ok=True)

for epoch in range(max_epochs):
    print(f'Epoch: {epoch + 1}/{max_epochs}')
    # Train
    epoch_loss = 0
    net.train()
    ## Load data from training dataloader iteratively
    for batch_data in dl_train:
        inputs, labels = (
            batch_data["img"],
            batch_data["seg"]
        )
        optimizer.zero_grad()
        ## Forward pass
        outputs = net(inputs)
        ## Calculate loss
        loss = loss_function(outputs, labels)
        ## Backward
        loss.backward()
        ## Update model
        optimizer.step()
        epoch_loss += loss.item()
    # Mean loss over the number of batches
    epoch_loss /= len(dl_train)
    epoch_loss_values.append(epoch_loss)
    # Validation
    val_loss = 0
    net.eval()
    ## Disable gradient calculation
    with torch.no_grad():
        ## Load data from validation dataloader iteratively
        for batch_data in dl_val:
            val_inputs, val_labels = (
                batch_data["img"],
                batch_data["seg"]
            )
            ## Forward (Inference)
            val_outputs = net(val_inputs)
            ## Calculate loss on the raw network outputs
            loss = loss_function(val_outputs, val_labels)
            val_loss += loss.item()
            ## Dice score: post-process every sample of the batch separately
            val_outputs = [post_trans(i) for i in monai.data.decollate_batch(val_outputs)]
            val_labels = monai.data.decollate_batch(val_labels)
            dice_metric(y_pred=val_outputs, y=val_labels)
        val_loss /= len(dl_val)
        epoch_valloss_values.append(val_loss)
        ## Aggregate the final mean Dice result
        metric = dice_metric.aggregate().item()
        ## Reset the metric state for the next validation round
        dice_metric.reset()
        metric_values.append(metric)
        ## Save the model with the best metric (epochs are reported 1-based)
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(net.state_dict(), './checkpoints/best.pt')

    print(f'  Train_loss: {epoch_loss:.4f}')
    print(f'  Val_loss: {val_loss:.4f}', f', Val_dice: {metric:.4f}')
torch.save(net.state_dict(), './checkpoints/last.pt')

In [None]:
# Learning curves: train/val loss on top, validation Dice below.
plt.figure(figsize=(16, 8))
plt.subplot(2, 1, 1)
for series, colour in ((epoch_loss_values, 'b'), (epoch_valloss_values, 'r')):
    plt.plot(series, colour)
plt.legend(['Train_loss', 'Val_loss'])
plt.subplot(2, 1, 2)
plt.plot(metric_values, 'g')
plt.legend(['Val_dice'])
plt.show()

In [None]:
# Best validation Dice and the (1-based) epoch at which it was reached.
(best_metric, best_metric_epoch)

## Visualization

In [None]:
# Reload the weights saved at the epoch with the best validation Dice.
best_ckpt_path = './checkpoints/best.pt'
best_ckpt_dict = torch.load(best_ckpt_path)
net.load_state_dict(best_ckpt_dict)

In [None]:
# Run the restored model on one validation sample.
# `load_resize` only loads and resizes the raw image (no normalization) so it
# can be displayed next to the network input and the prediction.
load_resize = monai.transforms.Compose([monai.transforms.LoadImage(image_only=True),
                                        monai.transforms.EnsureChannelFirst(),
                                        monai.transforms.Resize(spatial_size=spatial_size, mode='area')])

fpath = data_val[0]
img = load_resize(fpath['img'])
data = val_trans(fpath)
inputs = data['img'][None, ...]  # prepend the batch axis
label = data['seg'][None, ...]
net.eval()
# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = net(inputs)
output = post_trans(output)
inputs.shape, label.shape, output.shape, img.shape

In [None]:
# Side by side: raw image, normalized network input, ground truth, prediction.
panels = [
    ('Image', img.numpy().transpose([1, 2, 0]).astype('uint8'), None),
    ('Pre-processed Image', inputs[0].cpu().numpy().transpose([1, 2, 0]), None),
    ('Ground Truth', label[0, 0].cpu().numpy(), 'gray'),
    ('Prediction', output[0, 0].cpu().numpy(), 'gray'),
]
plt.figure(figsize=(20, 5))
for idx, (title, picture, cmap) in enumerate(panels, start=1):
    plt.subplot(1, 4, idx)
    plt.imshow(picture, cmap=cmap)
    plt.title(title)
plt.show()

## Export
Export the model to ONNX format.

In [None]:
# Random input in the training batch layout: (batch, 3 input channels, *spatial_size).
# Deriving the spatial dims from `spatial_size` keeps this in sync with the transforms.
# It is also the example input for the ONNX export.
dummy_input = torch.randn(batch_size, 3, *spatial_size).to(device)

with torch.no_grad():
    print(net(dummy_input).shape)

In [None]:
# Export with named inputs/outputs and a dynamic batch axis, so the ONNX model
# is not locked to the batch size of `dummy_input`.
torch.onnx.export(
    net,
    dummy_input,
    'model.onnx',
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
)

## Exploration
Try different architectures, transformations, and hyperparameters to improve the model.