# Training a 2D U-Net model

In this notebook, we will demonstrate step by step how to set up and train a U-Net segmentation model. It is strongly encouraged to create a separate environment for this:

```
conda create -n pytorch-segmentation python=3.9
conda activate pytorch-segmentation
```
To install PyTorch correctly, find the matching command for your system on the [PyTorch page](https://pytorch.org/get-started/locally/):

![](./torch_get_started.png)

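After that is done, install the remaining packages (use `conda install` instead of `mamba install` if mamba is not available in your setup):

```
mamba install scikit-image albumentations segmentation-models-pytorch pytorch-lightning pandas matplotlib torchmetrics tensorboard requests -c conda-forge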
```



In [None]:
import os
from pathlib import Path
import numpy as np
import tqdm
from skimage import io
import pandas as pd
import matplotlib.pyplot as plt
import datetime
import yaml

import albumentations as albu
import torch, torchmetrics
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import segmentation_models_pytorch as smp
import requests
import zipfile

In [None]:
torch.cuda.is_available()

## Dataset preparation

The first thing we have to do for PyTorch training is to create a custom `Dataset` class for our data. This dataset object serves as the interface through which PyTorch accesses and loads the data from disk. During training we iterate over the dataset, so the `Dataset` implementation needs two important methods:

* `__len__()`: The DataLoader needs to know how *many* samples there are.
* `__getitem__()`: The DataLoader needs to be able to access the i-th sample of the whole dataset.
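
A minimal sketch of this protocol, using a hypothetical `ToyDataset` that holds synthetic in-memory NumPy arrays instead of reading files from disk (it only illustrates the two methods and is not used later on):

```python
import numpy as np


class ToyDataset:
    """Hypothetical in-memory dataset illustrating the __len__/__getitem__ protocol."""

    def __init__(self, n_samples=5, size=8, seed=0):
        rng = np.random.default_rng(seed)
        # Synthetic RGB images [N, 3, Y, X] and integer masks [N, 1, Y, X]
        self.images = rng.random((n_samples, 3, size, size), dtype=np.float32)
        self.masks = rng.integers(0, 3, size=(n_samples, 1, size, size))

    def __len__(self):
        # Number of samples the DataLoader can draw from
        return len(self.images)

    def __getitem__(self, i):
        # Return the i-th (image, mask) pair
        return self.images[i], self.masks[i]


toy = ToyDataset()
print(len(toy))  # 5
image, mask = toy[2]
print(image.shape, mask.shape)  # (3, 8, 8) (1, 8, 8)
```

For map-style datasets, PyTorch's default `DataLoader` setup only relies on these two methods, which is why the `Dataset` class below does not have to inherit from `torch.utils.data.Dataset` (although doing so is common practice).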

Inside our `Dataset`, we store the paths of all image/label pairs in a pandas DataFrame with two columns, `image_filenames` and `label_filenames`. Keeping the file pairs in a dataframe makes them easy to inspect and filter.

*Note 1*: We add an optional augmentation to the dataset. If we pass an augmentation function as an argument (it must accept the keyword arguments `image` and `mask` and return a dictionary with these two keys), the augmentations are applied on the fly whenever a sample is loaded. We will use the [albumentations package](https://albumentations.ai/) for this.

*Note 2*: We will use a pretrained encoder. Such pretrained models are typically trained on "normal" images, i.e. RGB images, so the model expects images with the dimensions `[3, Y, X]`. If we work with grayscale images, we need to stack the single channel three times to create an artificial RGB image. The `Dataset` below does not do this: it transposes each image from `[Y, X, C]` to `[C, Y, X]` and thus assumes that the images already have three color channels.
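
For a grayscale image, the channel stacking described in Note 2 is a one-liner in NumPy. A sketch with a small synthetic image:

```python
import numpy as np

# Synthetic grayscale image of shape [Y, X]
gray = np.arange(12, dtype=np.float32).reshape(3, 4) / 11

# Repeat the single channel three times along a new leading axis -> [3, Y, X]
rgb_like = np.stack([gray, gray, gray], axis=0)

print(rgb_like.shape)  # (3, 3, 4)
```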

In [None]:
class Dataset():
    def __init__(self, root, image_dir='images', label_dir='labels', augmentation=None):
        self.root = root
        self.augmentation = augmentation
        
        self.image_dir = os.path.join(root, image_dir)
        self.label_dir = os.path.join(root, label_dir)
        
        # Assuming image and label filenames match, we can just use one list for pairing
        filenames = sorted(os.listdir(self.image_dir))

        # Creating a DataFrame to store paired images and labels
        self.data = pd.DataFrame({
            'image_filenames': [os.path.join(self.image_dir, fname) for fname in filenames],
            'label_filenames': [os.path.join(self.label_dir, fname) for fname in filenames]
        })

    def __getitem__(self, i):
        image_filepath = self.data['image_filenames'].iloc[i]
        label_filepath = self.data['label_filenames'].iloc[i]

        image = io.imread(image_filepath) / 255
        mask = np.argmax(io.imread(label_filepath), axis=2)  # Assuming label images have multiple channels for classes

        if self.augmentation:
            sample = self.augmentation(image=image.astype(np.float32), mask=mask)
            image, mask = sample['image'], sample['mask']

        return image.astype(np.float32).transpose((2, 0, 1)), mask[None, :]

    def __len__(self):
        return len(self.data)
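
The two array manipulations in `__getitem__` can be illustrated on a tiny synthetic example. This sketch assumes, as the comment in the code does, that each label image stores one channel per class: `np.argmax(..., axis=2)` then picks the class index of every pixel, and `transpose((2, 0, 1))` moves the color channel of the image to the front, as PyTorch expects. Note that a pixel whose channels are all zero would silently become class 0.

```python
import numpy as np

# Synthetic 2x2 label image with 3 channels (one per class), shape [Y, X, C]
label_rgb = np.zeros((2, 2, 3), dtype=np.uint8)
label_rgb[0, 0, 0] = 255  # pixel (0, 0) belongs to class 0
label_rgb[0, 1, 1] = 255  # pixel (0, 1) belongs to class 1
label_rgb[1, :, 2] = 255  # bottom row belongs to class 2

mask = np.argmax(label_rgb, axis=2)  # [Y, X] class indices
print(mask)
# [[0 1]
#  [2 2]]

# Synthetic image in [Y, X, C] layout, as returned by skimage.io.imread for RGB files
image = np.random.default_rng(0).random((2, 2, 3)).astype(np.float32)
image_chw = image.transpose((2, 0, 1))  # [C, Y, X] layout expected by PyTorch
print(image_chw.shape, mask[None, :].shape)  # (3, 2, 2) (1, 2, 2)
```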

Download and extract the data (the download is skipped if `data.zip` already exists):

In [None]:
if not os.path.exists('data.zip'):
    r = requests.get(r'https://zenodo.org/record/7213527/files/HE_segmentation_data.zip')
    r.raise_for_status()  # stop here if the download failed
    with open("data.zip", 'wb') as f:
        f.write(r.content)
        
with zipfile.ZipFile('./data.zip', 'r') as zip_ref:
    zip_ref.extractall('./data')

Define the root directory of the extracted data:

In [None]:
root = os.path.abspath('./data')
root

Let's use this opportunity to quickly check what the `Dataset` class defined above does with this data. For this, we create an instance of the `Dataset` class pointing to the data root directory and then retrieve an arbitrary sample from it:

In [None]:
MyDataset = Dataset(root=root)
sample = MyDataset[120]

fig, axes = plt.subplots(ncols=2, figsize=(15, 15), sharex=True, sharey=True)
axes[0].imshow(sample[0].transpose((1,2,0)))  # need to transpose for matplotlib RGB format
axes[1].imshow(sample[1][0])

## Augmentation

For a typical segmentation task, it makes sense to augment the data to a certain degree. Albumentations allows us to compose several augmentations and conveniently applies them to the image and the mask alike, on the fly. We will use the following augmentations for training:
* Vertical flip: flip the image upside down in 50% of all calls
* Horizontal flip: same, but horizontally
* Random rotate: rotate image and mask by a random multiple of 90 degrees in 50% of all calls
* Random brightness/contrast: randomly change the brightness and contrast of the image (not the mask) in 20% of all calls
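
To see why the geometric augmentations must be applied to image and mask *together*, here is a simplified NumPy imitation of the flips and 90-degree rotations above (the helper `joint_flip_rotate` is hypothetical; albumentations does this for us). Because the mask in this toy example is derived from the image, we can check that the transformed mask still matches the transformed image pixel by pixel.

```python
import numpy as np


def joint_flip_rotate(image, mask, rng):
    """Hypothetical helper: apply the same random flips/rot90 to image [Y, X, C] and mask [Y, X]."""
    if rng.random() < 0.5:  # vertical flip
        image, mask = image[::-1], mask[::-1]
    if rng.random() < 0.5:  # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # rotate by a random multiple of 90 degrees
        k = int(rng.integers(0, 4))
        image, mask = np.rot90(image, k, axes=(0, 1)), np.rot90(mask, k, axes=(0, 1))
    # Copy to contiguous arrays (torch.from_numpy rejects negative strides)
    return np.ascontiguousarray(image), np.ascontiguousarray(mask)


rng = np.random.default_rng(1)
image = rng.random((4, 4, 3))
mask = (image[..., 0] > 0.5).astype(np.int64)  # toy mask derived from the image
aug_image, aug_mask = joint_flip_rotate(image, mask, rng)

# The mask still matches the transformed image pixel by pixel
print(np.array_equal(aug_mask, (aug_image[..., 0] > 0.5).astype(np.int64)))  # True
```

Had the image and mask been flipped or rotated independently, the labels would no longer line up with the tissue they annotate.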

In [None]:
aug_train = albu.Compose([
    albu.VerticalFlip(p=0.5),
    albu.HorizontalFlip(p=0.5),
    albu.RandomRotate90(p=0.5),
    albu.RandomBrightnessContrast(p=0.2)
])

### Train/test/validation split

Last but not least, we need to split our dataset into three subsets: a training, a test and a validation cohort. In every epoch of training (see below), the training process looks at all images in the training cohort and updates the model based on them. The model is then applied to the images in the test cohort without updating it, to see how well it is currently performing. Finally, after training, the model is applied to the images in the validation cohort to measure its performance on data that played no role during training. Here, we use `torch.utils.data.random_split` to hold out 10% of the samples for validation and to split the remaining samples 80/20 into training and test data. For cross-validation, scikit-learn's `KFold` offers a more systematic alternative.
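
The subset sizes computed in the next cell can be checked by hand. This sketch (with a hypothetical `split_sizes` helper and a made-up dataset size of 1000) reproduces the arithmetic: 10% for validation, then 80% of the remainder for training and the rest for testing, i.e. 72% / 18% / 10% overall. It then splits shuffled indices accordingly, which is essentially what `random_split` does.

```python
import numpy as np


def split_sizes(n_samples, val_fraction=0.1, train_fraction=0.8):
    """Reproduce the subset sizes used in the split cell."""
    val_size = int(val_fraction * n_samples)
    train_size = int(train_fraction * (n_samples - val_size))
    test_size = n_samples - val_size - train_size
    return train_size, test_size, val_size


sizes = split_sizes(1000)
print(sizes)  # (720, 180, 100)

# A random split assigns every index to exactly one subset
rng = np.random.default_rng(42)
perm = rng.permutation(1000)
train_idx, test_idx, val_idx = np.split(perm, np.cumsum(sizes)[:-1])
print(len(train_idx), len(test_idx), len(val_idx))  # 720 180 100
```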

In [None]:
val_size = int(0.1 * len(MyDataset))
train_size = int(0.8 * (len(MyDataset) - val_size))
test_size = len(MyDataset) - val_size - train_size
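# A seeded generator makes the random split reproducible
train_dataset, test_dataset, validation_dataset = torch.utils.data.random_split(
    MyDataset, [train_size, test_size, val_size], generator=torch.Generator().manual_seed(42))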

print('Samples in training set: ', len(train_dataset))
print('Samples in testing set: ', len(test_dataset))
print('Samples in validation set: ', len(validation_dataset))

## Model creation and preparation

Next, we create an instance of the model, which we will then train for a number of epochs. First, we set a few parameters, which are explained here:

* `n_classes`: The number of classes to segment. The label images used here distinguish three classes (background, necrosis and vital tissue, see the inference section below), so we set `n_classes = 3`.
* `epochs`: In one epoch, the DataLoader goes through the entire training set while the network's weights are updated; afterwards, the network's performance on the test set is checked. This is repeated `epochs` times.
* `batch_size`: During training, multiple images are stacked into a batch along an additional batch axis. Hence, images are provided in `[B, C, Y, X]` shape, with `B` being the batch dimension, `C` the channel dimension and `Y, X` the actual image dimensions. If `batch_size` is set too large, the batch may not fit into GPU memory anymore.
*Note*: The ResNet encoder used here contains batch normalization layers, which normalize the intermediate activations of each channel using the mean and standard deviation computed over the current batch. Making the batch too small makes these statistics noisy and can disrupt training.

* `learning_rate`: The step size of the optimizer, i.e. how strongly the weights of every layer are changed in every training step. See [here](https://twitter.com/marktenenholtz/status/1490309316347248646) for a nice explanation!
* `num_workers`: How many worker processes the DataLoaders use to load the data and feed it to the network (`0` means the data is loaded in the main process).
* `log_interval`: How often (every `log_interval` training steps) the training metrics are logged.
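
The two ideas in the `batch_size` entry can be sketched in NumPy with synthetic samples: stacking `[C, Y, X]` samples creates the `[B, C, Y, X]` batch, and batch normalization computes one mean and variance per channel over the batch and spatial axes. This is a simplification: real batch normalization layers act on intermediate feature maps, add a learnable scale and shift, and use running averages of these statistics at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic samples as the Dataset would return them: [C, Y, X] each
samples = [rng.random((3, 16, 16), dtype=np.float32) for _ in range(4)]

# Stacking along a new leading axis creates the batch dimension -> [B, C, Y, X]
batch = np.stack(samples, axis=0)
print(batch.shape)  # (4, 3, 16, 16)

# Batch-normalization-style statistics: one mean/variance per channel over B, Y and X
mean = batch.mean(axis=(0, 2, 3), keepdims=True)
var = batch.var(axis=(0, 2, 3), keepdims=True)
normalized = (batch - mean) / np.sqrt(var + 1e-5)  # small epsilon avoids division by zero

# Each channel now has (approximately) zero mean and unit standard deviation
print(normalized.mean(axis=(0, 2, 3)).round(4), normalized.std(axis=(0, 2, 3)).round(3))
```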

In [None]:
epochs = 100
log_interval = 10
n_classes = 3
batch_size = 24
learning_rate = 2e-5
num_workers = 0

## The loss

An aspect of paramount importance is the **loss function**. After all, deep learning is all about passing data through the network, evaluating the performance and then changing the weights accordingly. The loss function determines how exactly performance is measured. Torch offers a few different implementations, but you can basically implement any differentiable function that compares the prediction with the label image and returns a measure of their disagreement. A very common choice (and the one used here) is `CrossEntropyLoss()`, which computes, for every pixel, the [cross-entropy](https://en.wikipedia.org/wiki/Cross_entropy) between the predicted class probabilities (the softmax of the network output) and the true class, and averages it over all pixels. The cross-entropy is closely related to the [Kullback-Leibler divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).
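
As a worked example of the loss, here is a NumPy sketch of the pixel-wise cross-entropy, assuming the same inputs as `CrossEntropyLoss()` with its default settings: raw logits of shape `[B, C, Y, X]`, integer targets of shape `[B, Y, X]`, mean over all pixels and no class weights. For a pixel with true class `c` and predicted probabilities `p = softmax(logits)`, the loss is `-log(p[c])`. The helper name `pixelwise_cross_entropy` is hypothetical.

```python
import numpy as np


def pixelwise_cross_entropy(logits, target):
    """Mean cross-entropy over all pixels.

    logits: [B, C, Y, X] raw network output; target: [B, Y, X] integer class indices.
    """
    # Numerically stable log-softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class at every pixel
    true_class_log_probs = np.take_along_axis(log_probs, target[:, None], axis=1)
    return -true_class_log_probs.mean()


# With all-zero logits every class gets probability 1/3, so the loss equals log(3)
logits = np.zeros((2, 3, 4, 4))
target = np.random.default_rng(0).integers(0, 3, size=(2, 4, 4))
print(pixelwise_cross_entropy(logits, target), np.log(3))
```

A confident and correct prediction (a large logit for the true class) drives the loss towards 0, which is what the optimizer aims for.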

Since the cross-entropy is a bit abstract to interpret, we also monitor a more intuitive measure of performance on the held-out data during training: the per-class accuracy, i.e. the fraction of pixels of each class that are predicted correctly. Another popular choice for segmentation is the [Jaccard index](https://en.wikipedia.org/wiki/Jaccard_index). During training, we should observe that the cross-entropy goes down while the per-class accuracy approaches 1.
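
For reference, a NumPy sketch of the per-class Jaccard index (intersection over union) of a predicted and a reference label image. The helper `per_class_jaccard` is hypothetical and not used in the training code; `torchmetrics.functional.jaccard_index` provides a ready-made implementation.

```python
import numpy as np


def per_class_jaccard(pred, target, n_classes):
    """Intersection over union for every class (NaN if a class is absent from both images)."""
    scores = []
    for c in range(n_classes):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        intersection = np.logical_and(p, t).sum()
        scores.append(intersection / union if union > 0 else np.nan)
    return np.array(scores)


target = np.array([[0, 0, 1],
                   [1, 2, 2]])
pred = np.array([[0, 1, 1],
                 [1, 2, 0]])
print(per_class_jaccard(pred, target, n_classes=3))  # approximately [0.333 0.667 0.5]
```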

In [None]:
criterion_train = CrossEntropyLoss()
criterion_test = torchmetrics.functional.accuracy

In [None]:
class MyModel(pl.LightningModule):
    def __init__(
        self,
        loss_fn=CrossEntropyLoss(),
        loss_fn_test=torchmetrics.functional.accuracy,
        n_classes=3,
        learning_rate=2e-5):
        super(MyModel, self).__init__()

        # store some parameters
        self.learning_rate = learning_rate
        self.loss_fn = loss_fn
        self.loss_fn_test = loss_fn_test
        self.n_classes = n_classes

        self.model = smp.Unet(
            encoder_name='resnet50',
            encoder_weights='imagenet',
            classes=self.n_classes,
            activation=None,
        )

        self.encoder = self.model.encoder
        self.decoder = self.model.decoder

        # log hyperparameters
        self.save_hyperparameters()
    
    def forward(self, x):
        return self.model.forward(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        # Masks have shape [B, 1, Y, X]; CrossEntropyLoss expects [B, Y, X] class indices
        loss = self.loss_fn(y_hat, y.squeeze(1).long())
        self.log("train_loss", loss)

        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y = y.squeeze(1).long()  # [B, 1, Y, X] -> [B, Y, X], also safe for batches of size 1
        y_hat = self.forward(x)
        loss = self.loss_fn(y_hat, y)
        # Per-class accuracy of the predicted label image
        loss_accuracy = self.loss_fn_test(
            y_hat.argmax(dim=1), y, average=None,
            num_classes=self.n_classes, task="multiclass")

        log = {'accuracy {}'.format(i): loss_accuracy[i] for i in range(self.n_classes)}
        log["validation_loss"] = loss
        self.log_dict(log)

        # make a matplotlib figure of the prediction if batch index is zero
        if batch_idx == 0:
            fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5))
            axes[0].imshow(x[0].permute(1, 2, 0).cpu().numpy())  # RGB: [C, Y, X] -> [Y, X, C]
            axes[1].imshow(y[0].cpu().numpy())
            axes[2].imshow(y_hat.argmax(dim=1)[0].cpu().numpy())
            self.logger.experiment.add_figure(
                "validation prediction", fig, self.current_epoch)
            
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        return optimizer


### Training

Now we have everything at hand to actually start training! For this, we first create DataLoaders from our training, test and validation subsets. Let's not forget to apply the composed albumentations **only to the training data.** We could also apply the augmentations to the test and validation data, but it is preferable to have performance statistics on the real, unchanged image data.

In [None]:
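# Create dataloaders
# The training subset gets its own Dataset instance with augmentations,
# reusing the indices drawn by random_split above
augmented_train_dataset = torch.utils.data.Subset(
    Dataset(root=root, augmentation=aug_train), train_dataset.indices)
train_dataloader = DataLoader(dataset=augmented_train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, num_workers=num_workers)
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=batch_size, num_workers=num_workers)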

Every training epoch consists of the following steps:

* Set the model to training mode: this switches layers such as batch normalization (and dropout, if present) to their training behavior. Gradients themselves are tracked by PyTorch's autograd in both training and evaluation mode, unless tracking is disabled explicitly, e.g. with `torch.no_grad()`
* Reset the gradients left over from the previous training step
* Pass a batch of training data through the network and calculate the loss (deviation of the prediction from the correct mask)
* Back-propagate the loss through the network to calculate the gradients
* Let the optimizer update the weights of the network

The last four steps are repeated for every batch of the training set. PyTorch Lightning runs all of these steps for us in `trainer.fit()`; we only define in `training_step` how the loss is computed.
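
To make these steps concrete without a GPU or a deep-learning framework, here is a deliberately tiny sketch in plain NumPy: logistic regression on made-up 2D points, trained with mini-batch gradient descent. It is an illustration of the loop structure only, not of the U-Net. The numbered comments map onto the steps listed above (there is no counterpart of training mode here); in PyTorch, the backward pass and the update are `loss.backward()` and `optimizer.step()`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy problem: classify 2D points into 2 classes
x = rng.normal(size=(200, 2))
y = (x[:, 0] + x[:, 1] > 0).astype(np.float64)

w, b = np.zeros(2), 0.0  # model weights
learning_rate = 0.5
batch_size = 50


def forward(x, w, b):
    return 1.0 / (1.0 + np.exp(-(x @ w + b)))  # predicted probability of class 1


def bce_loss(p, y):
    eps = 1e-12  # avoids log(0)
    return -np.mean(y * np.log(p + eps) + (1 - y) * np.log(1 - p + eps))


losses = []
for epoch in range(20):
    for start in range(0, len(x), batch_size):
        xb, yb = x[start:start + batch_size], y[start:start + batch_size]
        # 1) "reset gradients": here they are recomputed from scratch for every batch
        # 2) forward pass and loss
        p = forward(xb, w, b)
        # 3) backward pass: analytic gradient of the mean binary cross-entropy
        grad_logits = (p - yb) / len(xb)
        grad_w, grad_b = xb.T @ grad_logits, grad_logits.sum()
        # 4) optimizer step: plain gradient descent
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    losses.append(bce_loss(forward(x, w, b), y))

print(losses[0] > losses[-1])  # True: the loss decreases over the epochs
```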

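To check the progress of the training, navigate to the working directory in a terminal and start TensorBoard, pointing it to the folder in which Lightning stores its logs (`lightning_outputs`, defined below):

```
tensorboard --logdir=lightning_outputs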
```

Then, navigate to http://localhost:6006/ in the browser.

In [None]:
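# Pass the parameters defined above (otherwise the defaults of MyModel are used)
model = MyModel(loss_fn=criterion_train, loss_fn_test=criterion_test,
                n_classes=n_classes, learning_rate=learning_rate)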
model.train(True)

In [None]:
logfolder = Path("lightning_outputs")
if not logfolder.is_dir():
    logfolder.mkdir()

ckpt_callback = ModelCheckpoint(
    filename='{epoch:03.0f}-{validation_loss:.3f}',
    save_last=True,
    save_top_k=1,
    monitor="validation_loss",  # keep the checkpoint with the lowest loss on held-out data
    mode="min",
    every_n_epochs=1
)

In [None]:
use_cuda = torch.cuda.is_available()
ndevices = torch.cuda.device_count()
trainer = pl.Trainer(
    default_root_dir=logfolder,
    max_epochs=epochs,
    log_every_n_steps=log_interval,
    accelerator="gpu" if use_cuda else "cpu",
    devices=ndevices if use_cuda else 1,
    num_sanity_val_steps=0,
    callbacks=[
        ckpt_callback
    ]
)

In [None]:
# Monitor the test cohort during training; the validation cohort is kept for the final evaluation
trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

## Validation

Last but not least, let's load the best model saved during training, apply it to an image from the validation cohort and look at the prediction.

In [None]:
# Load the best checkpoint written by the ModelCheckpoint callback
# (after a kernel restart, pass the path to a .ckpt file in lightning_outputs instead)
model = MyModel.load_from_checkpoint(ckpt_callback.best_model_path)
model.eval()

In [None]:
torch.tensor(validation_dataset[0][0][None, :])

In [None]:
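# Add a batch axis and move the image to the device the model lives on
image = torch.from_numpy(validation_dataset[0][0][None, :]).to(model.device)
with torch.no_grad():
    prediction = model(image)
prediction = prediction.cpu().numpy()[0].transpose((1, 2, 0))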

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(validation_dataset[0][0].transpose((1,2,0)))
axes[1].imshow(np.argmax(prediction, axis=2))

## Advanced implementation aspects

There are a few options we can use to improve the training process and make it more robust. Some of them are listed here; for the sake of simplicity, they are not implemented in this notebook but will be explored in more advanced notebooks.

* **Early Stopping:** This notebook simply trains for a fixed number of epochs. To effectively prevent overfitting, however, we should stop the training process as soon as the performance on held-out data is not improving anymore. A more suitable implementation therefore looks at the test performance within the last X epochs and interrupts the training if it has not improved (or has only fluctuated slightly) during that time.
* **Weighted sampling**: If we want to segment regions that are rare in our training data, the network can get away with not predicting these labels *at all*: after all, the prediction error stays small if these labels are sufficiently scarce. To counter this, we can introduce [weighted sampling](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler) to ensure that the network is equally exposed to all labels present in the image data.
* **Scheduling**: As the model converges closer to its optimum, it is not wise to keep updating the model with the same step size as during the first epochs. Instead, we can gradually decrease the `learning_rate` as we progress through the epochs, so that the steps become smaller and safer. PyTorch provides several such schedulers in `torch.optim.lr_scheduler`. See [this tweet](https://twitter.com/marktenenholtz/status/1490309316347248646) for a nice visualization.
* **Visualization**: It is good practice to visualize the training process through a side-by-side comparison of the image, the reference annotation and the predicted label image, as done in `validation_step` above (see the "validation prediction" figures in TensorBoard).
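
A minimal sketch of patience-based early stopping in plain Python, driven by a made-up sequence of test losses (the class `EarlyStopper` is hypothetical). With PyTorch Lightning, the built-in `pytorch_lightning.callbacks.EarlyStopping(monitor="validation_loss", patience=...)` could be added to the trainer's `callbacks` list instead.

```python
class EarlyStopper:
    """Hypothetical helper: stop when the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, current_loss):
        if current_loss < self.best - self.min_delta:
            # Improvement: remember the new best value and reset the counter
            self.best = current_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up test losses: the improvement stalls after the fourth epoch
losses = [0.9, 0.7, 0.6, 0.55, 0.56, 0.57, 0.55, 0.58, 0.60]
stopper = EarlyStopper(patience=3)
stop_epoch = None
for epoch, loss in enumerate(losses):
    if stopper.should_stop(loss):
        stop_epoch = epoch
        break
print(stop_epoch)  # 6
```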
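
A sketch of how sampling weights for weighted sampling could be computed, assuming we know for each image whether it contains a rare class (the flags below are made up). Each sample is weighted inversely to the size of its group, so both groups are drawn about equally often. In PyTorch, such a weight vector would be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)`, which is then given to the `DataLoader` via its `sampler` argument (instead of `shuffle=True`).

```python
import numpy as np

rng = np.random.default_rng(0)
# Made-up per-image flags: does the image contain the rare class?
contains_rare = np.array([False] * 90 + [True] * 10)

# Weight each sample inversely to the size of its group
group = contains_rare.astype(int)
counts = np.bincount(group)  # [90, 10]
weights = 1.0 / counts[group]
probabilities = weights / weights.sum()

# Draw an "epoch" of indices with replacement, as a weighted sampler would
drawn = rng.choice(len(contains_rare), size=10_000, replace=True, p=probabilities)
print(contains_rare[drawn].mean())  # close to 0.5 instead of 0.1
```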
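
As one example of scheduling, cosine annealing decays the learning rate smoothly from its initial value to a minimum. The plain-Python sketch below computes such a schedule for the values used in this notebook (`learning_rate = 2e-5`, 100 epochs); the function name `cosine_lr` is hypothetical. PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR` is documented with the same closed-form formula, and a Lightning module can return such a scheduler together with the optimizer from `configure_optimizers`.

```python
import math


def cosine_lr(epoch, total_epochs, base_lr=2e-5, min_lr=0.0):
    """Cosine annealing: decay the learning rate from base_lr to min_lr over total_epochs."""
    progress = epoch / total_epochs
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


schedule = [cosine_lr(e, total_epochs=100) for e in range(101)]
# Starts at base_lr, is halfway down at the midpoint and ends at min_lr
print(schedule[0], schedule[50], schedule[100])
```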

## Inference

Applying the model to data is also called inference. Let's try this on the images of the validation cohort. If you have not trained a model of your own previously, you can load the one that comes with the downloaded data using `MyModel.load_from_checkpoint`, as shown above:

In [None]:
model.eval()
device = model.device
with torch.no_grad():
    tk0 = tqdm.tqdm(validation_dataloader, total=len(validation_dataloader))
    for b_idx, (images, masks) in enumerate(tk0):

        # Move the images of this batch to the model's device (GPU if available)
        images = images.to(device)

        # Feed the batch through the network; the output has shape [B, n_classes, Y, X]
        predictions = model(images)

        for i in range(images.shape[0]):
            fig, axes = plt.subplots(ncols=4, figsize=(20, 5))
            axes[0].imshow(images[i].cpu().numpy().transpose((1,2,0)))
            axes[1].imshow(predictions[i][0].cpu().numpy())
            axes[2].imshow(predictions[i][1].cpu().numpy())
            axes[3].imshow(predictions[i][2].cpu().numpy())
            axes[0].set_title('Raw')
            axes[1].set_title('Background-ness')
            axes[2].set_title('Necrosis-ness')
            axes[3].set_title('Vital-ness')
            fig.tight_layout()
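
The "-ness" maps plotted above are raw network outputs (logits), which are not restricted to the range [0, 1]. A minimal NumPy sketch of converting such an output into per-pixel class probabilities (softmax over the class axis) and a label image; the helper `logits_to_probabilities` and the logits are made up for illustration. With PyTorch tensors, `torch.softmax(predictions, dim=1)` does the same for a whole batch.

```python
import numpy as np


def logits_to_probabilities(logits):
    """Softmax over the class axis of a [C, Y, X] network output."""
    shifted = logits - logits.max(axis=0, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=0, keepdims=True)


# Made-up raw output for 3 classes on a 2x2 image
logits = np.array([[[2.0, 0.0], [0.0, 0.0]],
                   [[0.0, 3.0], [0.0, 0.0]],
                   [[0.0, 0.0], [1.0, 0.0]]])
probs = logits_to_probabilities(logits)
label_image = probs.argmax(axis=0)  # same result as taking argmax of the logits
print(label_image)
# [[0 1]
#  [2 0]]
```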

In [None]:
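# Take one batch from the validation cohort and keep it in a dict for the following cells
images, masks = next(iter(validation_dataloader))
sample = {'image': images, 'mask': masks}
sample['image'].shape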

In [None]:
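# The U-Net's encoder returns the feature maps of all encoder stages as a list
with torch.no_grad():
    output = model.encoder(sample['image'].float().to(model.device))
output[0].shape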
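
To get a feeling for the feature-map sizes plotted below, here is a small sketch computing the expected output shapes of the encoder stages. It assumes the channel counts of a standard ResNet-50 (3, 64, 256, 512, 1024, 2048) and that each stage halves the spatial resolution; check these assumptions against `output[k].shape` on your own data. The total downsampling factor of 32 is also why image dimensions should be divisible by 32 for a U-Net with such a five-stage encoder. The helper `resnet_encoder_shapes` is hypothetical.

```python
def resnet_encoder_shapes(height, width, channels=(3, 64, 256, 512, 1024, 2048)):
    """Expected [C, Y, X] of each encoder output, assuming every stage halves the resolution."""
    shapes = []
    for stage, c in enumerate(channels):
        factor = 2 ** stage  # stage 0 is the input itself, stage 5 is downsampled 32x
        shapes.append((c, height // factor, width // factor))
    return shapes


for shape in resnet_encoder_shapes(256, 256):
    print(shape)
# (3, 256, 256)
# (64, 128, 128)
# (256, 64, 64)
# (512, 32, 32)
# (1024, 16, 16)
# (2048, 8, 8)
```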

In [None]:
ncols=10
fig, axes = plt.subplots(nrows=5, ncols=ncols, figsize=(15,8))

batch = 0
for k in range(1,6):
    for i in range(ncols):
        feature_maps = output[k][batch].detach().cpu().numpy()
        axes[k-1, i].imshow(feature_maps[i], cmap='gray')
        axes[k-1,0].set_ylabel('$N_{featuremaps}$ '+f'= {len(feature_maps)}' + '\n' + f'{feature_maps[i].shape[0]} x {feature_maps[i].shape[1]}')
fig.tight_layout()
fig.savefig('featuremaps.png')

In [None]:
plt.imshow(sample['image'].detach().numpy()[batch].transpose((1,2,0)))
plt.savefig('input.png')