In [None]:
'''
Notebook to demonstrate the augmentations used in the
compare_augmentations.ipynb notebook. Copies one batch
(BATCH_SIZE image/mask pairs) into check_augmentations_dir/ex_images,
builds a data loader over that directory, applies each
augmentation to the batch, and saves the resulting image
grids to check_augmentations_dir/screenshots.
'''


import tbitk.ai.deep_learning as dl
import matplotlib.pyplot as plt
import itk
import scipy.ndimage
import numpy as np
from pathlib import Path
from monai.transforms import (
    RandRotated,
    RandScaleIntensityd,
    ThresholdIntensityd,
    RandAffined,
    LoadImaged,
    EnsureTyped,
    AddChanneld,
    Resized,
    ScaleIntensityd,
    RandFlipd,
    Compose,
    MapTransform
)
from tbitk.ai.transforms import eval_transforms, eval_transforms
from tbitk.ai.constants import DEFAULT_BEST_MODEL_NAME, BATCH_SIZE
from torchvision.utils import make_grid
NETWORK_INPUT_SHAPE = (256, 256)

# Augmentation names recognised by get_transforms, in the order they are applied.
KNOWN_AUGMENTATIONS = ("gain", "randflip", "randtranslate", "randrotate")


def get_transforms(l):
    """Build the MONAI preprocessing + augmentation pipeline for this demo.

    Every pipeline loads the ``"x"`` (image) and ``"y"`` (mask) entries,
    adds a channel dimension and resizes both to ``NETWORK_INPUT_SHAPE``
    with nearest-neighbour interpolation. The selected augmentations are
    then appended in a fixed order (gain, flip, translate, rotate),
    regardless of the order in which they appear in ``l``.

    Args:
        l: Iterable of augmentation names to enable, drawn from
            ``KNOWN_AUGMENTATIONS``. A single name passed as a bare string
            (e.g. ``"randrotate"``) is treated as a one-element list. An
            empty iterable yields the un-augmented pipeline.

    Returns:
        A ``monai.transforms.Compose`` of the selected transforms.

    Raises:
        ValueError: if ``l`` contains a name not in ``KNOWN_AUGMENTATIONS``.
    """
    # A bare string would otherwise be searched by substring
    # (``"gain" in "randrotate"``); wrap it so membership is exact.
    # list() also stops a generator from being consumed by the first check.
    names = [l] if isinstance(l, str) else list(l)
    unknown = sorted(set(names) - set(KNOWN_AUGMENTATIONS))
    if unknown:
        # Fail loudly: a typo would otherwise silently disable the augmentation.
        raise ValueError(
            f"Unknown augmentation name(s) {unknown}; "
            f"expected a subset of {KNOWN_AUGMENTATIONS}"
        )

    transforms = [
        LoadImaged(keys=["x", "y"], image_only=True),
        EnsureTyped(keys=["x", "y"]),
        AddChanneld(keys=["x", "y"]),
        # Nearest-neighbour resizing for both keys (presumably so mask labels
        # stay discrete; note it is applied to the image as well).
        Resized(keys=["x", "y"], spatial_size=NETWORK_INPUT_SHAPE, mode="nearest"),
    ]
    if "gain" in names:
        # Image only: scale intensities by a random factor (MONAI: v * (1 + f),
        # f drawn from [0, 0.75]), then replace values >= 1 with 1.
        transforms.append(RandScaleIntensityd(keys=["x"], prob=1, factors=(0, 0.75)))
        transforms.append(ThresholdIntensityd(keys=["x"], threshold=1, above=False, cval=1))
    if "randflip" in names:
        # Always (prob=1) flip image and mask together along spatial axis 1.
        transforms.append(RandFlipd(keys=["x", "y"], prob=1, spatial_axis=1))
    if "randtranslate" in names:
        # Per MONAI's translate_range convention: no shift on the first spatial
        # axis, up to +/-50 pixels on the second.
        # NOTE(review): the mask is resampled with RandAffined's default
        # (bilinear) interpolation; consider mode=("bilinear", "nearest")
        # if "y" is consumed downstream.
        transforms.append(RandAffined(keys=["x", "y"], prob=1, translate_range=(0, 50), padding_mode="zeros"))
    if "randrotate" in names:
        # Rotate by an angle drawn from (-0.35, 0.35) rad (about +/-20 degrees).
        # NOTE(review): same default-interpolation caveat for the mask.
        transforms.append(RandRotated(keys=["x", "y"], prob=1, range_x=0.35, padding_mode="zeros"))

    transforms.extend([
        # Round-trip "x" through NumPy, then convert both keys back to the
        # default (tensor) type.
        # NOTE(review): purpose of the round-trip is not evident here; kept as-is.
        EnsureTyped(keys=["x"], data_type="numpy"),
        # Intensity rescaling is currently disabled:
        # ScaleIntensityd(keys=["x"]),
        EnsureTyped(keys=["x", "y"]),
    ])

    return Compose(transforms)

In [None]:
# Source directory of extracted frames and masks. Files must be named
# img_{i}.mha / mask_{i}.mha with i starting at 0, because the next cell
# copies indices 0 .. BATCH_SIZE-1 from here.
# This points at the files extracted during a run of the main notebook.
INPUT_DATA_DIR: Path = Path("051622/data/train/")

In [None]:
# Copy the first BATCH_SIZE image/mask pairs from INPUT_DATA_DIR into a
# scratch directory, and make sure the screenshot output directory exists.
# NOTE(review): files left over from earlier runs are not removed.
check_augs_dir = Path("check_augmentations_dir/")
ex_images_dir = check_augs_dir / "ex_images"
screenshot_dir = check_augs_dir / "screenshots"

for out_dir in (ex_images_dir, screenshot_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

for idx in range(BATCH_SIZE):
    # Read both files of the pair before writing either of them.
    pair_names = (f"img_{idx}.mha", f"mask_{idx}.mha")
    pair_images = [itk.imread(str(INPUT_DATA_DIR / name)) for name in pair_names]
    for name, image in zip(pair_names, pair_images):
        itk.imwrite(image, str(ex_images_dir / name))

def save_transformed_batch(transform, fname, dir_=ex_images_dir, nrow=4, cmap=None):
    """Apply ``transform`` to the first batch read from ``dir_`` and save a grid image.

    Args:
        transform: Dict-based transform (e.g. from ``get_transforms``) that the
            data loader applies to each sample.
        fname: Output file name, written inside ``screenshot_dir``.
        dir_: Directory of image/mask files handed to ``dl.get_data_loader``.
        nrow: Number of images per row in the saved grid.
        cmap: Colormap forwarded to ``plt.imsave``. ``None`` (the default)
            uses matplotlib's ``image.cmap`` setting; pass ``"gray"`` for
            grayscale screenshots.
    """
    data_loader = dl.get_data_loader([dir_], transform, shuffle=False)
    batch = next(iter(data_loader))
    # make_grid tiles the batch into one (C, H, W) tensor; keep channel 0 so
    # imsave receives a 2-D array.
    grid = make_grid(batch["x"], nrow=nrow)[0]
    # Give matplotlib a plain CPU NumPy array rather than a torch tensor
    # (which may be a MetaTensor or live on another device).
    grid = grid.detach().cpu().numpy()
    plt.imsave(screenshot_dir / fname, grid, cmap=cmap)


save_transformed_batch(get_transforms([]), "no_aug.png")

In [None]:
# Random intensity gain, applied to the image only (the mask is untouched).
transform = get_transforms(["gain"])
save_transformed_batch(transform, fname="gain.png")

In [None]:
# Random translation, applied to image and mask together.
transform = get_transforms(["randtranslate"])
save_transformed_batch(transform, fname="translate.png")

In [None]:
# Random rotation, applied to image and mask together. get_transforms takes a
# list of augmentation names, as in the other cells.
transform = get_transforms(["randrotate"])
save_transformed_batch(transform, "rotate.png")