In [None]:
"""
Este script tiene por objetivo realizar data augmentation de datos de resonancia magnetica estructural mediante la
librería torchIO. TorchIO es una libreria versatil que permite realizar data augmentation en 3D para despues introducirlo
facilmente en una red neuronal implementada en Pytorch. Ademas, permite preparar los datos sin leerlos desde el disco 
hasta el momento del entrenamiento para evitar problemas de memoria RAM. Ya que la pipeline implementada en este github
utiliza una red neuronal en keras, las imagenes se leen, se augmentan y se lleva a cabo histogram matching,
y despues se guardan los archivos generados.

TorchIO esta preparada para tareas donde las etiquetas no son un solo numero, sino imagenes con unos para 
la region de interes y ceros para el resto. No estan directamente implementadas las etiquetas binarias
para la clasificacion entre pacientes y sujetos sanos utilizando toda la imagen. Por eso las imagenes de ambos
grupos se leen separadamente para este script. Cabe destacar que el autor de este codigo no es un programador,
por lo que sin duda el codigo puede ser simplificado por aquellos que sepan como hacerlo (por ejemplo, creando funciones
heredadas). A pesar de no ser optimo, funciona.

Documentacion de torchIO con ejemplos y tutoriales: https://torchio.readthedocs.io/
Informacion sobre permitir etiquetas que no sean imagenes: https://github.com/fepegar/torchio/issues/112

This script performs data augmentation on structural magnetic resonance images using torchIO.
TorchIO is a library for 3D data augmentation whose output can easily be fed to a neural network
implemented in PyTorch. Furthermore, it can set up the data without reading them from disk until
training time, avoiding memory issues. As the code in this GitHub repository is implemented in Keras, this script
reads the files, performs data augmentation and histogram matching, and saves the results to disk.

TorchIO is set up for tasks where the label is not a single number, but an image with ones for a region of interest
and zeros for the rest. Binary labels for whole-image classification (patients vs. healthy subjects) are not
directly supported, so the two groups are read separately and labelled manually. It is worth mentioning that
the author of this code is not a programmer, so this code can be simplified and optimized. Even though it is
not perfect, it does work.

Documentation about torchIO with examples and tutorials: https://torchio.readthedocs.io/
Info about allowing labels that are not images: https://github.com/fepegar/torchio/issues/112
"""

In [None]:
import enum
import time
import multiprocessing
from pathlib import Path
import os

import torch
import torchvision
import torchio as tio
import torch.nn.functional as F

import numpy as np
from unet import UNet
from scipy import stats
import matplotlib.pyplot as plt

from IPython import display
from tqdm.notebook import tqdm

import nibabel as nib

#evitar un warning
from scipy.ndimage import measurements
from scipy.ndimage import zoom

In [None]:
"""
Set variables:
"""
histogram_landmarks_path = 'landmarks.npy'
#crop = (15, 8, 15, 4, 1, 14) si se quiere recortar

In [None]:
OneGroupPath = "PATH"
AnotherGroupPath = "AnotherPATH"

In [None]:
"""
Crear clase subject para ambos grupos (para poder etiquetarlos) y despues crear un solo SubjectsDataset
Create subject class for both groups (so you can label then) and then create only one SubjectDataset
"""

subjects = []

# Usar zip si el formato de las imagenes es .nii.gz
# Zip is used in case images are in .nii.gz format
for OneGroupPath in zip(OneGroupPath):
    subject = tio.Subject(
    mri = tio.ScalarImage(OneGroupPath),
    label = torch.tensor(1)
    )
    subjects.append(subject)
    
for AnotherGroupPath in zip(AnotherGroupPath):
    subject = tio.Subject(
    mri = tio.ScalarImage(AnotherGroupPath),
    label = torch.tensor(0)
    )
    subjects.append(subject)
    
dataset = tio.SubjectsDataset(subjects)
print("Dataset size: ", len(dataset), " subjects")

In [None]:
"""
Calcular landmarks para histogram standarization
Calculate landmarks for histogram standarization
"""

landmarks = tio.HistogramStandardization.train(
    whole_image_path,
    output_path = histogram_landmarks_path,
)
np.set_printoptions(suppress=True, precision=3)
print('\nTrained landmarks:', landmarks)

In [None]:
"""
Transformaciones antes de augmentar o validar
Transformations done before augmentation or validation
"""

landmarks_dict = {"mri" : landmarks}
pre_augmentation_transfrom = tio.Compose([
    tio.HistogramStandardization({"mri":landmarks}),
    #tio.Crop(crop), para cortar las imagenes / in case you want to crop images
    #tio.Resample(1.8) para modificar la resolucion / in case you want to resize the images
])

In [None]:
"""
Crear train set y validation set para el data augmentation
Create train set and validation set before data augmentation
"""

# Generar los sets de entrenamiento y validacion
# Generate training and validation sets

training_split_ratio = 0.8
num_subjects = len(dataset)
num_training_subjects = int(training_split_ratio * num_subjects)
num_validation_subjects = num_subjects - num_training_subjects

num_split_subjects = num_training_subjects, num_validation_subjects

training_subjects, validation_subjects = torch.utils.data.random_split(
    subjects, num_split_subjects, generator=torch.Generator().manual_seed(17))

# Realizar las transformaciones previas al augmentation (histogram mactching, crop, resample)
# Perform transforms before augmentation (histogram mactching, crop, resample)
training_set = tio.SubjectsDataset(
    training_subjects, transform=pre_augmentation_transfrom)

validation_set = tio.SubjectsDataset(
    validation_subjects, transform=pre_augmentation_transfrom)

In [None]:
"""
Guardar sets para controlar que imagenes estan en que set
Save sets to have control of which images are in which set
"""

contador = 1
for subject in validation_set:
    data, affine = subject.mri.data, subject.mri.affine
    data = data.numpy()
    data = data.squeeze(axis=0)
    val_nii = nib.Nifti1Image(data, affine)
    nib.save(val_nii, "path_to_save_val_set/SubjectNumber" +str(contador)+ "/struct.nii.gz")
    torch.save(subject.label, "path_to_save_val_Set/SubjectNumber" +str(contador)+ "/label")
    contador +=1
    
contador = 1
for subject in training_set:
    data, affine = subject.mri.data, subject.mri.affine
    data = data.numpy()
    data = data.squeeze(axis=0)
    train_nii = nib.Nifti1Image(data, affine)
    nib.save(val_nii, "path_to_save_train_set/SubjectNumber" +str(contador)+ "/struct.nii.gz")
    torch.save(subject.label, "path_to_save_train_Set/SubjectNumber" +str(contador)+ "/label")
    contador +=1

In [None]:
"""
Para utilizar el codigo literalmente debes crear estas carpetas
To use this code literally you have to create this folders
"data_augmentations/gamma/gammasubj1"
"data_augmentations/gamma/gammasubj2"
...
"data_augmentations/noise/aug_noisesubj1"
"data_augmentations/noise/aug_noisesubj2"
...
y asi sucesivamente
and so on
"""

gamma_subj_path = os.listdir("data_augmentations/gamma/")
noise_subj_path = os.listdir("data_augmentations/noise/")
rotation_subj_path = os.listdir("data_augmentations/rotation/")
scaling_subj_path = os.listdir("data_augmentations/scaling/")
shift_subj_path = os.listdir("data_augmentations/shift/")

In [None]:
"""
Data augmentation para el training set y guardarlo en disco. El path para guardar las etiquetas y las imagenes depende
de como se quiere que sea la estructura de carpetas. La usada aqui es solo una opcion.
Data augmentation for training set and save it to disk. The paths were to save the augmented images and
labels depends on how you want the folder structure. The one here is just an option.
"""

contador_subj = 1
for subject in training_set:
    #Gamma
    torch.save(subject.label, "data_augmentations/gamma/gammasubj"+str(contador_subj)+ "/label")
    gammas = tio.Gamma(gamma = None)
    gamma_aug = []
    value = 0.7
    while value <= 1.3:
        gammas.gamma = value
        augmentation_gamma = gammas(subject.mri)
        gamma_aug.append(augmentation_gamma)
        value += 0.2
        contador_img = 1
    for augmentation in gamma_aug:
        data, affine = augmentation.data, augmentation.affine
        data = data.numpy()
        data = data.squeeze(axis=0)
        gamma_nii = nib.Nifti1Image(data, affine)
        nib.save(gamma_nii,
                    "data_augmentations/gamma/gammasubj"+str(contador_subj)+ "/gamma"+str(contador_img)+".nii.gz")
        contador_img +=1

    #Gaussian Noise
    torch.save(subject.label, "data_augmentations/noise/noisesubj"+str(contador_subj)+ "/label")
    noise = tio.Noise(mean=0, std=0.1, seed=None)
    noise_aug = []
    seed = 0
    while seed <= 30:
        noise.seed = seed
        augmentetation_noise = noise(subject.mri)
        noise_aug.append(augmentetation_noise) 
        seed += 10
        contador_img = 1
    for augmentation in noise_aug:
        data, affine = augmentation.data, augmentation.affine
        data = data.numpy()
        data = data.squeeze(axis=0)
        noise_nii = nib.Nifti1Image(data, affine)
        nib.save(noise_nii,
                    "data_augmentations/noise/noisesubj"+str(contador_subj)+ "/noise"+str(contador_img)+".nii.gz")
        contador_img +=1
        
    #Image rotation:
    torch.save(subject.label, "data_augmentations/rotation/rotationsubj"+str(contador_subj)+ "/label")
    rotation_aug = []
    angle = -16
    while angle <= 16:
        rotation = tio.Affine(scales=1, degrees=angle, translation=0)
        augmentetation_rotation = rotation(subject.mri)
        rotation_aug.append(augmentetation_rotation) 
        angle += 8
        contador_img = 1
    for augmentation in rotation_aug:
        data, affine = augmentation.data, augmentation.affine
        data = data.numpy()
        data = data.squeeze(axis=0)
        rotation_nii = nib.Nifti1Image(data, affine)
        nib.save(rotation_nii,
                    "data_augmentations/rotation/rotationsubj"+str(contador_subj)+ "/rotation"+str(contador_img)+".nii.gz")
        contador_img +=1
        
    #Scaling
    torch.save(subject.label, "data_augmentations/scaling/scalingsubj"+str(contador_subj)+ "/label")
    scaling_aug = []
    scaling_factor = 0.7
    while scaling_factor <= 1.3:
        scaling = tio.Affine(scales=scaling_factor, degrees=0, translation=(0,0,0))
        augmentation_scaling = scaling(subject.mri)
        scaling_aug.append(augmentation_scaling)
        scaling_factor += 0.2
        contador_img = 1
    for augmentation in scaling_aug:
        data, affine = augmentation.data, augmentation.affine
        data = data.numpy()
        data = data.squeeze(axis=0)
        scaling_nii = nib.Nifti1Image(data, affine)
        nib.save(scaling_nii,
                    "data_augmentations/scaling/scalingsubj"+str(contador_subj)+ "/scaling"+str(contador_img)+".nii.gz")
        contador_img +=1
    
    #shift
    torch.save(subject.label, "data_augmentations/shift/shiftsubj"+str(contador_subj)+ "/label")
    shift = tio.RandomAffine(scales=0, degrees=0, translation=(70))
    shift_aug = []
    numero = 0
    while numero <= 30:
        augmentation_shift = shift(subject.mri)
        shift_aug.append(augmentation_shift)
        numero +=10
        contador_img = 1
    for augmentation in shift_aug:
        data, affine = augmentation.data, augmentation.affine
        data = data.numpy()
        data = data.squeeze(axis=0)
        shift_nii = nib.Nifti1Image(data, affine)
        nib.save(shift_nii,
                    "data_augmentations/shift/shiftsubj"+str(contador_subj)+ "/shift"+str(contador_img)+".nii.gz")
        contador_img +=1
    
    contador_subj += 1

In [8]:
"""
Comprobar informacion de un sujeto en el dataset
Check the info for one subject of the dataset
"""

one_subject = dataset[78]
print(one_subject)
print(one_subject.mri)
print(one_subject.label)

Subject(Keys: ('mri', 'label'); images: 1)
ScalarImage(shape: (1, 193, 229, 193); spacing: (1.00, 1.00, 1.00); orientation: RAS+; memory: 32.5 MiB; dtype: torch.FloatTensor)
tensor(0)
