In [None]:
import os
import torch
import torchio as tio
import numpy as np

# Data lives in a "data" folder next to the current working directory
# (i.e. inside the parent of the notebook's working directory).
data_dir = os.path.join(os.path.split(os.getcwd())[0], "data")

In [None]:
def to_pil(image):
    """Display an image inline in the notebook as an 8-bit picture.

    Singleton dimensions are squeezed away and the axis order is reversed
    with ``.T``. For channels-first data, this moves the channel axis last,
    where PIL expects it. Values are clipped to [0, 255] before the cast to
    ``uint8``, so out-of-range intensities saturate instead of wrapping around.

    Args:
        image: Object exposing ``.numpy()``, e.g. a TorchIO image or a torch
            tensor. Assumes the squeezed data is 2D (single channel) or 3D
            channels-first. TODO: confirm for other layouts.

    Returns:
        None. The picture is shown with ``IPython.display.display``.
    """
    from PIL import Image
    from IPython.display import display

    array = image.numpy().squeeze().T
    # Casting floats outside [0, 255] directly to uint8 wraps around (and is
    # platform-dependent), so saturate first.
    array = np.clip(array, 0, 255).astype(np.uint8)
    display(Image.fromarray(array))
    print()  # in case multiple images are being displayed


In [None]:
img = tio.ScalarImage(os.path.join(data_dir, 'vhp', 'a_vm1125.png'))
# NOTE(review): for a not-yet-loaded image, TorchIO may return a freshly read
# affine from `.affine` (or overwrite it when the data is loaded later). Either
# way, an in-place edit made before loading can be silently lost. Loading first
# makes the spacing edit below stick.
img.load()
img.affine[0, 0] = img.affine[1, 1] = 0.33  # according to https://www.nlm.nih.gov/databases/download/vhp.html
print(img)
img.as_pil()

### Solutions to the exercise

In [None]:
# Each transform gets its own name so the bonus exercise can reuse the same
# objects through the `transforms` list.
flip_ap = tio.RandomFlip(axes=['anteroposterior'], flip_probability=1)
flip_lr = tio.RandomFlip(axes=['lateral'], flip_probability=0.5)
crop = tio.CropOrPad((800, 800, 1))
resample = tio.Resample(0.75)
elastic = tio.RandomElasticDeformation(max_displacement=50)
affine = tio.RandomAffine()
# Pick one of the two spatial augmentations per call, weighted 0.6 / 0.4.
spatial = tio.OneOf({
    elastic: 0.6,
    affine: 0.4,
})
blur = tio.RandomBlur(p=0.75)
noise = tio.RandomNoise(mean=128, std=10)
rescale = tio.RescaleIntensity((0, 255))
# Rescale right after adding noise so intensities end up in [0, 255] for display.
noise_rescale = tio.Compose([noise, rescale])
# Average over dim 0 (the channels); keepdim preserves the channel dimension.
rgb2gray = tio.Lambda(lambda tensor: torch.mean(tensor, 0, keepdim=True))

transforms = [
    flip_ap,
    flip_lr,
    crop,
    resample,
    spatial,
    blur,
    noise_rescale,
    rgb2gray,
]

transform = tio.Compose(transforms)

NUM_SAMPLES = 10
torch.manual_seed(42)  # make the gallery reproducible under Restart & Run All
for sample_index in range(NUM_SAMPLES):
    to_pil(transform(img))

### Solutions to the bonus exercise

In [None]:
import time
import warnings

def apply_transforms(image, transforms, seed=42, show=False, exclude=None):
    """Run ``transforms`` one by one on ``image``, printing per-transform timings.

    Every transform is always executed, even when excluded, so the random
    state consumed is the same whatever is excluded. The output of an
    excluded transform is discarded, and the next transform receives that
    transform's input instead.

    Args:
        image: Input passed to the first transform (e.g. a TorchIO image).
        transforms: Iterable of callables exposing a ``name`` attribute.
        seed: Value passed to ``torch.manual_seed`` before the first transform.
        show: If True, display the intermediate result kept after each
            transform with ``to_pil``.
        exclude: A transform name, or an iterable of names, whose output
            should be discarded. ``None`` excludes nothing.

    Returns:
        None. Timings are printed, and intermediate images are shown only if
        ``show`` is True.
    """
    if exclude is None:
        excluded_names = frozenset()
    elif isinstance(exclude, str):
        # `name in 'some string'` is a substring test, which would also match
        # shorter names contained in it; treat a bare string as one name.
        excluded_names = frozenset([exclude])
    else:
        excluded_names = frozenset(exclude)

    torch.manual_seed(seed)
    results = []
    transformed = image
    # perf_counter is monotonic and high-resolution, unlike time.time.
    tic_all = time.perf_counter()
    for transform in transforms:
        tic = time.perf_counter()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = transform(transformed)
        # We want the transform to be applied even if it is excluded to ensure reproducibility
        if transform.name not in excluded_names:
            transformed = result
        millis = int((time.perf_counter() - tic) * 1000)
        print(f'{transform.name:12}{millis:>5} ms')
        results.append(transformed)
    millis = int((time.perf_counter() - tic_all) * 1000)
    print(f'{"TOTAL":12}{millis:>5} ms')
    if show:
        for intermediate in results:
            to_pil(intermediate)

# Compare runtimes when some transforms' outputs are dropped. Excluded
# transforms still run inside apply_transforms, to keep the random state
# in sync across scenarios.
for title, excluded in [
    ('All transforms:', None),
    ('\nWithout cropping:', ['CropOrPad']),
    ('\nWithout resampling:', ['Resample']),
    ('\nWithout cropping and resampling:', ['Resample', 'CropOrPad']),
]:
    print(title)
    apply_transforms(img, transforms, exclude=excluded)

# Cropping and resampling make our code about one order of magnitude faster