# Augmentation

### Import useful libraries

In [1]:
from typing import Tuple, List
from pathlib import Path
from tqdm.notebook import tqdm

import numpy as np
import matplotlib.pyplot as plt

import kornia as kn
import kornia.augmentation as K

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, random_split, DataLoader

from kedro.extras.datasets.pickle import PickleDataSet
from kedro.config import ConfigLoader

In [3]:
import os
import sys

# Make the project's `src` directory importable from this notebook. The membership
# check keeps `sys.path` free of duplicate entries when the cell is run more than once.
src_dir = os.path.abspath('../src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

from tagseg.data.acdc_dataset import AcdcDataSet

### Set up connection to dataset

Raw images and labels (of varying sizes) are stored in the lists `ims` and `las`, respectively; both lists are filled in the loading cell below.

In [4]:
# Read the Kedro data catalog from the project's configuration folders
# (paths are relative to the notebook's working directory).
conf_paths = ["../conf/base", "../conf/local"]
conf_loader = ConfigLoader(conf_paths)
conf_catalog = conf_loader.get("catalog*", "catalog*/**")

# Fail with an actionable message when the entry is missing from the catalog,
# rather than a bare KeyError that does not say which entries exist.
catalog_entry = 'acdc_data_tagged'
if catalog_entry not in conf_catalog:
    raise KeyError(
        f"Catalog entry {catalog_entry!r} not found; "
        f"available entries: {sorted(conf_catalog)}"
    )

# The catalog stores the file path relative to the project root, one level above this notebook.
# NOTE(review): loading a pickle executes arbitrary code — only load files from a trusted source.
dataset = PickleDataSet(filepath='../' + conf_catalog[catalog_entry]['filepath']).load()

KeyError: 'acdc_data_tagged'

Fetch all images and their labels from the raw ACDC patient folders, without applying any preprocessing or augmentation.

In [None]:
verbose: bool = True

# NOTE(review): `acdc_path` and `Patient` are not defined or imported in the visible
# cells of this notebook — confirm they are available before this cell runs.
# Every sub-directory of the raw ACDC download is treated as one patient folder.
patient_paths = [entry for entry in Path(acdc_path).iterdir() if entry.is_dir()]

ims: List[np.ndarray] = []
las: List[np.ndarray] = []

accepted_classes: set = {0., 1., 2., 3.}

# Walk through the patient folders, reporting progress on the bar
patients_pbar = tqdm(patient_paths, leave=True)
for ppath in patients_pbar:
    if verbose:
        patients_pbar.set_description(f'Processing {ppath.name}...')

    # Loading of the .nii.gz files is handled by the `Patient` class
    patient = Patient(ppath)
    assert len(patient.images) == len(patient.masks)

    # Keep every image/mask pair of the patient (around 10 per patient) as float64
    for raw_image, raw_mask in zip(patient.images, patient.masks):
        ims.append(raw_image.astype(np.float64))
        las.append(raw_mask.astype(np.float64))

### Test out and visualize different augmentations

In [None]:
# Spatial size that both pipelines resize to.
_target_size = (256, 256)

# Image pipeline: apply `SimulateTags`, convert to a tensor, normalize with a
# scalar mean/std, then resize.
# NOTE(review): `SimulateTags` is not defined or imported in the visible cells — confirm.
_image_steps = [
    SimulateTags(),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.456, std=0.224),
    transforms.Resize(_target_size),
]
_preprocess_image = transforms.Compose(_image_steps)

# Label pipeline: convert to a tensor and resize with nearest-neighbour
# interpolation, which picks existing values instead of blending neighbours.
_label_steps = [
    transforms.ToTensor(),
    transforms.Resize(_target_size, interpolation=transforms.InterpolationMode.NEAREST),
]
_preprocess_label = transforms.Compose(_label_steps)

In [None]:
# Subset of the raw samples to preprocess and visualize.
selection = slice(0, 10)

selected_ims, selected_las = ims[selection], las[selection]

# Collect per-sample tensors and concatenate once at the end; calling torch.cat
# inside the loop re-copies the whole batch on every iteration.
image_batches: List[torch.Tensor] = []
label_batches: List[torch.Tensor] = []

for im, la in tqdm(zip(selected_ims, selected_las), total=len(selected_ims)):

    # Scale to a maximum of 1 on a new array, so the raw arrays held in `ims`
    # are not modified as a side effect of running this cell.
    image = im / im.max()
    image = _preprocess_image(image).unsqueeze(0)
    # NOTE(review): this ADDS the minimum to every pixel. If the goal was to shift
    # the minimum to zero it should subtract instead — confirm the intent.
    image = image + image.min()
    label = _preprocess_label(la)

    image_batches.append(image)
    label_batches.append(label)

# Fall back to an empty tensor when the selection contains no samples.
images: torch.Tensor = torch.cat(image_batches, dim=0) if image_batches else torch.Tensor()
labels: torch.Tensor = torch.cat(label_batches, dim=0) if label_batches else torch.Tensor()

In [None]:
# Display the shapes of the stacked image and label tensors as a sanity check
images.shape, labels.shape

In [None]:
# Probability with which each individual augmentation is applied.
proba: float = 0.2

# Transforms that move pixels around (flips and elastic deformation).
_geometric_augs = [
    K.RandomHorizontalFlip(p=proba),
    K.RandomVerticalFlip(p=proba),
    K.RandomElasticTransform(p=proba),
]

# Transforms that alter pixel values (noise, sharpness, blur).
_intensity_augs = [
    K.RandomGaussianNoise(p=proba),
    K.RandomSharpness(p=proba),
    K.RandomGaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.1), p=proba),
]

# Chain all transforms in the listed order; `data_keys` declares that the
# pipeline is called with an input batch followed by its mask batch.
train_aug = K.AugmentationSequential(
    *_geometric_augs,
    *_intensity_augs,
    data_keys=["input", "mask"],
)

In [None]:
# Run the `train_aug` pipeline on the images together with their labels. The call
# previously used `augment`, a name that is never defined in this notebook.
# The labels get a singleton dimension inserted at index 1 before being passed as the mask.
aug_ims, aug_las = train_aug(images, labels.unsqueeze(1))

In [None]:
# Display the size of the augmented labels once the dimension at index 1 is squeezed out
aug_las.squeeze(1).size()

In [None]:
# One row per preprocessed sample, so the grid follows `selection` instead of
# assuming exactly 10 samples. `squeeze=False` keeps `axes` two-dimensional
# even when there is a single row.
n_rows = len(images)
column_titles = [
    'Raw image', 'Raw label',
    'Preprocessed image', 'Preprocessed label',
    'Augmented image', 'Augmented label',
]
fig, axes = plt.subplots(n_rows, len(column_titles), figsize=(20, 4 * n_rows), squeeze=False)

# Index the raw data through `selection` so that every row shows the same sample
# in all six columns, even when the selection does not start at 0.
raw_ims, raw_las = ims[selection], las[selection]

for i in range(n_rows):
    # (array to draw, colormap) for each column; None keeps matplotlib's default colormap
    panels = [
        (raw_ims[i], 'gray'),
        (raw_las[i], None),
        (images[i, 0], 'gray'),
        (labels[i], None),
        (aug_ims[i, 0], 'gray'),
        (aug_las[i, 0], None),
    ]
    for ax, (panel, cmap) in zip(axes[i], panels):
        ax.imshow(panel, cmap=cmap)
        ax.axis('off')

# Label the columns once, on the first row
for ax, title in zip(axes[0], column_titles):
    ax.set_title(title)

NameError: name 'plt' is not defined