In [None]:
# Re-import edited modules automatically before every cell execution
# (%autoreload 2 reloads all modules) and render matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# Integrating External Frameworks into a `rising` Augmentation Pipeline

### Using transformations from external libraries inside `rising`
> Note: Some external augmentation libraries are only supported at the beginning of
the transformation pipeline. If a transformation you need is missing, please consider opening an issue in `rising`;
there is a good chance that we (or, if you prefer, you :) ) will add it in the future.

## 3D (Volumetric) Augmentation
The first part of this notebook focuses on frameworks that support volumetric transformations (as `rising` does). This means data with a shape of [C, D, H, W] (C = channels; D, H, W = spatial dimensions).

In [None]:
!pip install --quiet --upgrade SimpleITK
!git clone https://github.com/PhoenixDL/rising.git
!pip install --quiet --upgrade ./rising

In [None]:
# download some volumetric data (here MRI data) and extract the T1 image and
# its segmentation into the current working directory
from io import BytesIO
from zipfile import ZipFile
from urllib.request import urlopen

DATA_URL = "http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip"

# context managers close the HTTP response and the archive once we are done
with urlopen(DATA_URL) as resp:
    archive_bytes = BytesIO(resp.read())

with ZipFile(archive_bytes) as zipfile:
    # ZipFile.extract returns the path of the extracted file
    img_file = zipfile.extract("ExBox3/T1_brain.nii.gz")
    mask_file = zipfile.extract("ExBox3/T1_brain_seg.nii.gz")

In [None]:
import SimpleITK as sitk
import numpy as np

# read image and segmentation with SimpleITK and convert both to float32 arrays
img = sitk.GetArrayFromImage(sitk.ReadImage(img_file)).astype(np.float32)
mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file)).astype(np.float32)

# image and mask must have the same shape to be augmented together
assert mask.shape == img.shape, f"shape mismatch: image {img.shape} vs mask {mask.shape}"
print(f"Image shape {img.shape}")
print(f"Mask shape {mask.shape}")

### Integration of `batchgenerators` transformations
Note: when `batchgenerators` transformations are integrated, gradients cannot be propagated through
them.

`batchgenerators` transformations are based on numpy to be framework agnostic. They are also based
on dictionaries which are modified through the transformations.

Two steps are needed to integrate `batchgenerators` transforms into your
pipeline:

1. Exchange the `default_collate` function inside the dataloader with `numpy_collate`
2. When switching from `batchgenerators` transformations to `rising` transformations, insert a `ToTensor` transformation

In [None]:
# setup transforms: batchgenerators transforms work on numpy data, so the
# batch is converted to tensors before the rising transforms are applied
from rising.transforms import Mirror, Rot90, SeqToMap, ToTensor
from batchgenerators.transforms import ZeroMeanUnitVarianceTransform

transforms = [
    # convert the (data, label) tuple into a dict with keys "data" and "label"
    SeqToMap("data", "label"),
    # batchgenerators transforms
    ZeroMeanUnitVarianceTransform(),
    # ... additional batchgenerators transforms
    # convert to tensor
    ToTensor(),
    # rising transforms
    Rot90((0, 1)),
    Mirror(dims=(0, 1)),
]

In [None]:
from rising.loading import DataLoader, default_transform_call, numpy_collate
from rising.transforms import Compose

# On a fresh kernel `dataset` is not defined yet at this point (the torchio and
# MNIST sections define their own later), so build a minimal in-memory dataset
# from the volume loaded above: a single (image, mask) pair, each with a leading
# channel axis to get the [C, D, H, W] layout.
dataset = [(img[None], mask[None])]

composed = Compose(transforms, transform_call=default_transform_call)
# numpy_collate keeps batches as numpy arrays for the batchgenerators transforms
dataloader = DataLoader(dataset, batch_size=8, batch_transforms=composed,
                        num_workers=0, collate_fn=numpy_collate)
_iter = iter(dataloader)

In [None]:
import matplotlib.pyplot as plt

batch = next(_iter)
# `show_batch` is only defined later in the 2D section, so plot directly here:
# the central depth slice of the first sample's first channel.
# assumes batch["data"] is (B, C, D, H, W) — TODO confirm
volume = batch["data"][0, 0]
plt.imshow(volume[volume.shape[0] // 2], cmap="gray")
plt.show()

### Integration of `torchio`

In [None]:
!pip install --quiet --upgrade torchio

In [None]:
import torchio

subject_a = torchio.Subject(
    # reuse the path returned by the extraction step instead of repeating it
    t1=torchio.Image(img_file, torchio.INTENSITY),
)
rescale = torchio.transforms.RescaleIntensity((0, 1))
transform = torchio.transforms.Compose([rescale])

# Newer torchio releases renamed `ImagesDataset` to `SubjectsDataset`, and the
# install cell upgrades torchio without pinning a version, so accept both names.
# Both are subclasses of torch.utils.data.Dataset.
dataset_cls = getattr(torchio, "SubjectsDataset", None) or torchio.ImagesDataset
dataset = dataset_cls([subject_a], transform=transform)

In [None]:
# inspect the first (already transformed) subject of the torchio dataset
dataset[0]

In [None]:
# setup any additional rising transformations
from rising.transforms import *

class SelectKeys(AbstractTransform):
    """
    Replace each selected batch entry by its nested ``"data"`` entry.

    The torchio batch stores every image as a nested mapping (accessed via
    ``batch[key]["data"]``). This transform unwraps it so rising transforms can
    operate on the image data directly.

    Parameters
    ----------
    keys : sequence of str
        batch keys whose nested ``"data"`` entry replaces the entry itself
    """

    def __init__(self, keys=("t1",)):
        # immutable default: a list default would be shared by all instances
        super().__init__(grad=False)
        self.keys = keys

    def forward(self, **batch):
        for key in self.keys:
            batch[key] = batch[key]["data"]
        return batch

In [None]:
# spatial augmentations for the unwrapped torchio "t1" image
augment_t1 = [
    Rot90(keys=("t1",), dims=(0, 1)),
    Mirror(keys=("t1",), dims=(0, 1)),
]
# unwrap torchio's nested image entry first, then augment it
rising_transforms = [SelectKeys(keys=["t1"]), *augment_t1]
batch_transforms = Compose(rising_transforms)

In [None]:
from rising.loading import DataLoader

# rising's DataLoader replaces the native PyTorch one so that
# `batch_transforms` can be passed in and applied to the loaded batches
dataloader = DataLoader(
    dataset,
    batch_size=1,
    num_workers=4,
    batch_transforms=batch_transforms,
)
_iter = iter(dataloader)

In [None]:
batch = next(_iter)
# printing the whole batch would dump every voxel; summarize each entry by
# its shape (or its type name when it has no shape) instead
{key: getattr(value, "shape", type(value).__name__) for key, value in batch.items()}

## 2D Augmentation

In [None]:
# lets prepare a basic dataset (e.g. one from `torchvision`)
import os
import torchvision
import numpy as np
import torch

def to_array(inp):
    """
    Convert a single dataset sample to a numpy array.

    We need a small helper in this example because torchvision datasets output PIL
    images. When using them in combination with `rising`,
    just add `torchvision.transforms.ToTensor()` to the transform of the dataset.

    Parameters
    ----------
    inp : PIL.Image.Image, torch.Tensor or any
        sample to convert

    Returns
    -------
    numpy.ndarray
        converted data: PIL images become float32 arrays with an extra leading
        axis, tensors are detached and moved to the CPU; any other input is
        returned unchanged
    """
    try:
        from PIL import Image
    except ImportError:  # without PIL installed no input can be a PIL image
        Image = None

    if Image is not None and isinstance(inp, Image.Image):
        # np.asarray instead of np.array(..., copy=False): the uint8 -> float32
        # cast always needs a copy, which NumPy >= 2 rejects when copy=False
        return np.asarray(inp, dtype=np.float32)[None]
    if isinstance(inp, torch.Tensor):
        return inp.detach().cpu().numpy()
    return inp

# MNIST training split stored in (and downloaded to) the working directory;
# every image is passed through `to_array`
dataset = torchvision.datasets.MNIST(
    root=os.getcwd(),
    train=True,
    transform=to_array,
    download=True,
)

In [None]:
import matplotlib.pyplot as plt

# image part of the first sample: print its shape, then show its first channel
first_image = dataset[0][0]
print(first_image.shape)

fig, ax = plt.subplots()
shown = ax.imshow(first_image[0], cmap='gray')
fig.colorbar(shown, ax=ax)
plt.show()

In [None]:
# helper function to visualize batches of images
import torch

def show_batch(batch: torch.Tensor):
    """
    Plot a batch of images as a single grayscale grid.

    Parameters
    ----------
    batch : torch.Tensor or numpy.ndarray
        batch of 2D images shaped (B, C, H, W) or of volumes shaped
        (B, C, D, H, W); for volumes the central depth slice is shown
    """
    # accept numpy batches and tensors that require grad or live on a GPU
    batch = torch.as_tensor(batch).detach().cpu()
    if batch.ndim == 5:
        # make_grid expects (B, C, H, W), so reduce volumes to their central slice
        batch = batch[:, :, batch.shape[2] // 2]
    grid = torchvision.utils.make_grid(batch)
    # make_grid repeats single-channel images to 3 channels; plot one of them
    plt.imshow(grid[0], cmap='gray')
    plt.show()

### Integration of `albumentations`

In [None]:
!pip install --quiet --upgrade albumentations

In [None]:
# Import albumentations' Compose under an alias: a bare `Compose` would shadow
# rising's `Compose`, which the following cells use to chain rising transforms.
from albumentations import Compose as AlbuCompose, Flip, RandomRotate90

def aug(p=0.5):
    """Build an albumentations pipeline (RandomRotate90 + Flip) applied with probability ``p``."""
    return AlbuCompose([
        RandomRotate90(),
        Flip(),
    ], p=p)

# NOTE(review): `augmentation` is not used by the following cells yet
augmentation = aug(p=0.9)

In [None]:
# import rising's Compose explicitly: albumentations' Compose may have been
# imported into this namespace by the previous cell
from rising.transforms import Compose, Mirror, Rot90, SeqToMap

# MNIST samples are (image, target) tuples rather than torchio dicts, so name
# them first; Rot90 and Mirror then act on their default "data" key.
rising_transforms = [
    SeqToMap("data", "label"),
    Rot90(dims=(0, 1)),
    Mirror(dims=(0, 1)),
]
batch_transforms = Compose(rising_transforms)

In [None]:
from rising.loading import DataLoader

# swap the native PyTorch DataLoader for rising's, which accepts the
# `batch_transforms` built above
loader_kwargs = dict(batch_size=1, batch_transforms=batch_transforms, num_workers=4)
dataloader = DataLoader(dataset, **loader_kwargs)
_iter = iter(dataloader)

### Integration of `imgaug`

In [None]:
!pip install --quiet --upgrade imgaug

In [None]:
# TODO: the imgaug integration needs a rename transform (not implemented yet)

### Integration of `torchvision`

### You want a library which is not listed here? Just open an issue [here](https://github.com/PhoenixDL/rising/issues).
