In [1]:
import multiprocessing
import os
import random
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import ruamel.yaml
import torch
import torchio as tio
import torchvision
from tqdm.notebook import tqdm

In [2]:
seed = 42
random.seed(seed)
torch.manual_seed(seed)
%config InlineBackend.figure_format = 'retina'
num_workers = multiprocessing.cpu_count()
plt.rcParams['figure.figsize'] = 12, 6

print('Last run on', time.ctime())
print('TorchIO version:', tio.__version__)

Last run on Sun May 22 11:48:34 2022
TorchIO version: 0.18.77


In [3]:
# YAML reader/writer for the experiment configuration.
yaml = ruamel.yaml.YAML()
# Tell the representer never to emit &anchor/*alias references; this only
# matters if this `yaml` object is later used to dump data.
yaml.representer.ignore_aliases = lambda *data: True

# Path of the configuration file, relative to the notebook's working directory.
DefaultConfigPath = "Config.yaml"

# `yaml_data` holds every setting read by the cells below.
with Path(DefaultConfigPath).open(encoding="utf-8") as inyaml:
    yaml_data = yaml.load(inyaml)

In [4]:
# Root folder of the IXI dataset, as configured in Config.yaml. It is handed to
# tio.datasets.IXI below, and each modality is later read from `dataset_dir / <modality>`.
dataset_dir = Path(yaml_data["dataset_dir_name"])

In [8]:
# Preprocessing applied to every IXI subject on load: reorient to RAS
# (ToCanonical), then resample to 1 mm isotropic spacing.
transforms = [tio.ToCanonical(), tio.Resample((1, 1, 1))]
ixi_dataset = tio.datasets.IXI(
    dataset_dir,
    modalities=yaml_data["modalities"],
    transform=tio.Compose(transforms),
    download=yaml_data["download_dataset"],
)
print('Number of subjects in dataset:', len(ixi_dataset))

# Load a single subject to sanity-check its keys and per-modality shapes.
sample_subject = ixi_dataset[0]
print('Keys in subject:', tuple(sample_subject.keys()))
for key in yaml_data["modalities"]:
    print(f'Shape of {key} data:', sample_subject[key].shape)

Number of subjects in dataset: 570
Keys in subject: ('subject_id', 'MRA')
Shape of MRA data: (1, 241, 240, 80)


In [9]:
# Root folder that receives every normalization / augmentation output.
folder = Path(yaml_data["output_augmentation_dir_name"])
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs;
# parents=True matches os.makedirs creating intermediate folders.
folder.mkdir(parents=True, exist_ok=True)

In [10]:
def img_transform(sample, transtpye):
    """Apply one named TorchIO augmentation to ``sample`` and save the result.

    The augmented ``mri`` image is written to
    ``<output_augmentation_dir_name>/<key>/<transtpye>/<PREFIX>-<source file name>``,
    where ``key`` is the module-level modality name set by the main processing
    loop. The destination folder must already exist.

    Args:
        sample: ``tio.Subject`` whose ``mri`` image carries the path of the
            file it was loaded from (used to name the output).
        transtpye: augmentation name; one of the keys of ``recipes`` below.

    Returns:
        The augmented ``tio.Subject``.

    Raises:
        ValueError: if ``transtpye`` is not a supported augmentation name.
    """
    # name -> (zero-argument factory for the transform, output filename prefix).
    # Factories keep construction lazy so only the requested transform is built.
    recipes = {
        'RandomFlip': (
            lambda: tio.RandomFlip(axes=['inferior-superior'], flip_probability=1),
            'RF',
        ),
        'RandomAffine': (tio.RandomAffine, 'RAff'),
        'RandomElasticDeformation': (
            lambda: tio.RandomElasticDeformation(max_displacement=(15, 10, 0)),
            'RE',
        ),
        'RandomAnisotropy': (tio.RandomAnisotropy, 'RA'),
        'RandomMotion': (tio.RandomMotion, 'RM'),
        'RandomGhosting': (lambda: tio.RandomGhosting(intensity=1.5), 'RG'),
        'RandomSpike': (tio.RandomSpike, 'RS'),
        'RandomBiasField': (lambda: tio.RandomBiasField(coefficients=1), 'RBias'),
        'RandomBlur': (tio.RandomBlur, 'RB'),
        'RandomNoise': (lambda: tio.RandomNoise(std=0.5), 'RN'),
        'RandomSwap': (tio.RandomSwap, 'RSwap'),
        # Shares the 'RG' prefix with RandomGhosting; the files do not collide
        # because each transform writes into its own sub-folder.
        'RandomGamma': (tio.RandomGamma, 'RG'),
        # NOTE(review): RandomLabelsToImage works from a label map, but the
        # subjects built in this notebook only hold a ScalarImage 'mri'. This
        # entry probably fails, or saves an unmodified 'mri'. Confirm before use.
        'RandomLabelsToImage': (tio.RandomLabelsToImage, 'RL2I'),
    }

    try:
        make_transform, prefix = recipes[transtpye]
    except KeyError:
        raise ValueError(
            f'Unsupported augmentation {transtpye!r}; expected one of {sorted(recipes)}'
        ) from None

    transformed = make_transform()(sample)

    # Name the output after the source file. Path(...).name is used because
    # str.lstrip(images_dir) strips a *set of characters*, not a prefix, and
    # could eat the start of the file name itself.
    source_name = Path(sample.mri.path).name
    out_path = Path(
        yaml_data["output_augmentation_dir_name"], key, transtpye, f'{prefix}-{source_name}'
    )
    transformed.mri.save(out_path)
    return transformed

In [11]:
def img_normalization(sample, transtpye, target_shape=(256, 256, 256)):
    """Reorient, crop/pad and intensity-normalize ``sample``, then save it.

    Steps: ``tio.ToCanonical`` (RAS orientation), then ``tio.CropOrPad`` to
    ``target_shape``, then the normalization named by ``transtpye``. The result
    is written to
    ``<output_augmentation_dir_name>/<key>/<transtpye>/<PREFIX>-<source file name>``,
    where ``key`` is the module-level modality set by the main processing loop.

    Args:
        sample: ``tio.Subject`` with an ``mri`` ScalarImage loaded from a file.
        transtpye: ``'HistogramStandardization'`` (reads landmarks from the
            module-level ``histogram_landmarks_path``) or ``'ZNormalization'``.
        target_shape: spatial shape passed to ``tio.CropOrPad``. The default is
            256 x 256 x 256.

    Returns:
        The normalized ``tio.Subject``.

    Raises:
        ValueError: if ``transtpye`` is not a supported normalization.
    """
    # Pick the normalization first so an unsupported name fails before any work.
    if transtpye == 'HistogramStandardization':
        landmarks = np.load(histogram_landmarks_path)
        # Landmarks are keyed by the subject's image name, 'mri'.
        normalize = tio.HistogramStandardization({'mri': landmarks})
        prefix = 'HS'
    elif transtpye == 'ZNormalization':
        normalize = tio.ZNormalization()
        prefix = 'ZN'
    else:
        raise ValueError(
            f"Unsupported normalization {transtpye!r}; expected "
            "'HistogramStandardization' or 'ZNormalization'"
        )

    canonical = tio.ToCanonical()(sample)
    cropped = tio.CropOrPad(target_shape)(canonical)
    transformed = normalize(cropped)

    # Save through the 'mri' key (the subject has no 'img' entry). Name the
    # file after the source image: Path(...).name rather than
    # str.lstrip(images_dir), which strips characters, not a prefix.
    source_name = Path(sample.mri.path).name
    transformed.mri.save(
        Path(yaml_data["output_augmentation_dir_name"], key, transtpye, f'{prefix}-{source_name}')
    )
    return transformed

In [None]:
# Main pipeline. For each modality: build the subject list, train the
# histogram-standardization landmarks if needed, then normalize every image
# and save one output per configured augmentation.
# NOTE: `key`, `images_dir` and `histogram_landmarks_path` stay module-level
# on purpose, because img_normalization / img_transform read them as globals.
output_root = Path(yaml_data["output_augmentation_dir_name"])
normalization = yaml_data["Normalization"]
augmentation_names = yaml_data["transform"]

for key in yaml_data["modalities"]:
    images_dir = dataset_dir / key
    image_paths = sorted(images_dir.glob('*.nii.gz'))
    histogram_landmarks_path = f'{yaml_data["output_augmentation_dir_name"]}//{key}//landmarks_{key}.npy'

    # Create every output folder once per modality rather than once per image.
    # This also creates `output_root / key`, where the landmarks file is written.
    for sub_folder in (normalization, *augmentation_names):
        (output_root / key / sub_folder).mkdir(parents=True, exist_ok=True)

    subjects = [tio.Subject(mri=tio.ScalarImage(path)) for path in image_paths]
    dataset = tio.SubjectsDataset(subjects)

    # Only HistogramStandardization consumes the landmarks. Training them makes
    # an extra pass over every image, so skip it for other normalizations.
    # Landmarks are cached on disk and reused on later runs.
    if normalization == 'HistogramStandardization' and not os.path.exists(histogram_landmarks_path):
        tio.HistogramStandardization.train(
            image_paths,
            output_path=histogram_landmarks_path,
        )

    for sample in tqdm(dataset, desc=key):
        sample_norm = img_normalization(sample, normalization)
        for transform_name in augmentation_names:
            img_transform(sample_norm, transform_name)

100%|██████████| 578/578 [02:24<00:00,  4.01it/s]


  0%|          | 0/578 [00:00<?, ?it/s]