In [1]:
from time import time
from batchgenerators.augmentations.crop_and_pad_augmentations import crop
from batchgenerators.dataloading import MultiThreadedAugmenter, SingleThreadedAugmenter
from config import brats_preprocessed_folder, num_threads_for_brats_example
# from batchgenerators.examples.brats2017.config import brats_preprocessed_folder, num_threads_for_brats_example
from batchgenerators.transforms import Compose
from batchgenerators.utilities.data_splitting import get_split_deterministic
from batchgenerators.utilities.file_and_folder_operations import *
import numpy as np
from batchgenerators.dataloading.data_loader import DataLoader
from batchgenerators.augmentations.utils import pad_nd_image
from batchgenerators.augmentations.spatial_transformations import augment_resize
from batchgenerators.transforms.spatial_transforms import SpatialTransform_2, MirrorTransform
from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, GammaTransform, BrightnessTransform
from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform, RicianNoiseTransform

import nibabel as nib

from brats_data_loader import BRATSDataLoader
import matplotlib.pyplot as plt


In [16]:
channel_indices = {
    't1': 0,
    't1c': 1,
    't2': 2,
    'flair': 3,
    'seg': 4
}

def get_train_transform(patch_size):
    # we now create a list of transforms. These are not necessarily the best transforms to use for BraTS, this is just
    # to showcase some things
    tr_transforms = []

    # the first thing we want to run is the SpatialTransform. It reduces the size of our data to patch_size and thus
    # also reduces the computational cost of all subsequent operations. All subsequent operations do not modify the
    # shape and do not transform spatially, so no border artifacts will be introduced
    # Here we use the new SpatialTransform_2 which uses a new way of parameterizing elastic_deform
    # We use all spatial transformations with a probability of 0.2 per sample. This means that 1 - (1 - 0.1) ** 3 = 27%
    # of samples will be augmented, the rest will just be cropped

    tr_transforms.append(
        SpatialTransform_2(
            patch_size, [i // 2 for i in patch_size],
            do_elastic_deform=True, deformation_scale=(0, 0.25),
            do_rotation=True,
            angle_x=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_y=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            angle_z=(- 15 / 360. * 2 * np.pi, 15 / 360. * 2 * np.pi),
            do_scale=True, scale=(0.65, 1.60),
            border_mode_data='constant', border_cval_data=0,
            border_mode_seg='constant', border_cval_seg=0,
            order_seg=1, order_data=3,
            random_crop=True,
            p_el_per_sample=0.1, p_rot_per_sample=0.1, p_scale_per_sample=0.1
        )
    )

    # now we mirror along all axes
    tr_transforms.append(MirrorTransform(axes=(0, 1, 2)))

    # COLOR TRANSFORMS - brightness transform for 15% of samples 
    # tr_transforms.append(BrightnessMultiplicativeTransform((0.7, 1.5), per_channel=True, p_per_sample=0.15))
    tr_transforms.append(BrightnessTransform(mu=0, sigma=0.5, p_per_sample=0.3))

    # COLOR TRANSFORMS - gamma transform. This is a nonlinear transformation of intensity values
    # (https://en.wikipedia.org/wiki/Gamma_correction)
    tr_transforms.append(GammaTransform(gamma_range=(0.5, 2), invert_image=False, per_channel=True, p_per_sample=0.15))
    # we can also invert the image, apply the transform and then invert back
    tr_transforms.append(GammaTransform(gamma_range=(0.5, 2), invert_image=True, per_channel=True, p_per_sample=0.15))

    # Gaussian Noise and RicianNoise
    # tr_transforms.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))
    # tr_transforms.append(GaussianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))
    tr_transforms.append(RicianNoiseTransform(noise_variance=(0, 0.05), p_per_sample=0.15))

    # blurring. Some BraTS cases have very blurry modalities. This can simulate more patients with this problem and
    # thus make the model more robust to it
    # tr_transforms.append(GaussianBlurTransform(blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
    #                                            p_per_channel=0.5, p_per_sample=0.15))
    tr_transforms.append(GaussianBlurTransform(blur_sigma=(0.5, 1.5), different_sigma_per_channel=True,
                                               p_per_channel=0.7, p_per_sample=0.30))
    

    # now we compose these transforms together
    tr_transforms = Compose(tr_transforms)
    return tr_transforms


def get_list_of_patients(preprocessed_data_folder):
    npy_files = subfiles(preprocessed_data_folder, suffix=".npy", join=True)
    # remove npy file extension
    patients = [i[:-4] for i in npy_files]
    return patients

#NEW
def iterate_through_patients(patients, in_channels):
    in_channels = [channel_indices[i] for i in in_channels]
    
    for p in patients:
        patient_data, meta_data = BRATSDataLoader.load_patient(p)
        
        # patient_data = BRATSDataLoader.load_patient(p)[0][in_channels][None]
        # meta_data = BRATSDataLoader.load_patient(p)[1]
        yield (patient_data[in_channels][None], meta_data)
        
def iterate_through_patients_transforms(patients, in_channels):
    patient_data_ls = []
    meta_data_ls = []
    in_channels = [channel_indices[i] for i in in_channels]
    
    for p in patients:
        patient_data, meta_data = BRATSDataLoader.load_patient(p)
        
        # patient_data = BRATSDataLoader.load_patient(p)[0][in_channels][None]
        # meta_data = BRATSDataLoader.load_patient(p)[1]

        
        
        
        yield (patient_data[in_channels][None], meta_data)        
        return 

# Load in patient data here 

In [101]:
base = 'brats_data_preprocessed/'

patients = get_list_of_patients(base+"Brats20TrainingData")
patients_train = patients
print(f"The number of training patients: {len(patients_train)}")


test_patients = get_list_of_patients(base + "Brats20ValidationData")
target_patients = test_patients[0:5]
print(f"The number of test patients: {len(target_patients)}")


batch_size = 12
patch_size = [24, 128, 128]
in_channels = ['t1c', 't2', 'flair']

train_dl = BRATSDataLoader(
    patients_train,
    batch_size=batch_size,
    patch_size=patch_size,
    in_channels=in_channels
)


tr_transforms = get_train_transform(patch_size)

tr_gen = MultiThreadedAugmenter(train_dl, tr_transforms, num_processes=4, # num_processes=4
                                num_cached_per_queue=3,
                                seeds=None, pin_memory=False)


#Target shape in predict patient in patches is: [1, 3, 144, 192, 192]
# Training data shape is (12, 3, 24, 128, 128), (batch_size, channels, depth, height, width)

The number of training patients: 10
The number of test patients: 5


In [119]:
patdata, metadata = BRATSDataLoader.load_patient(target_patients[0])

In [99]:
metadata

{'spacing': array([1., 1., 1.]),
 'direction': (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0),
 'origin': (-0.0, -239.0, 0.0),
 'original_shape': (155, 240, 240),
 'nonzero_region': array([[  0, 139],
        [ 35, 220],
        [ 52, 184]], dtype=int64)}

In [100]:
meta_data

{'spacing': array([1., 1., 1.]),
 'direction': (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0),
 'origin': (-0.0, -239.0, 0.0),
 'original_shape': (155, 240, 240),
 'nonzero_region': array([[  2, 135],
        [ 48, 207],
        [ 46, 191]], dtype=int64)}

In [130]:
from tqdm import tqdm
tr_transforms = get_train_transform(patch_size)

for patient in tqdm(target_patients):
    
    original_size, metadata_old = BRATSDataLoader.load_patient(patient)
    patch_size = list(original_size[0,:,:,:].shape)
    print(f"Patch size is: {patch_size}")
    
    name = patient.split("\\")[-1]
    print(f"Patient {name}")
    
    test_dl_new = BRATSDataLoader(
    [patient],
    batch_size=1,
    patch_size=patch_size,
    in_channels=in_channels
    ) 

    # What if we apply the same process to the test data first? 
    test_gen_new =  MultiThreadedAugmenter(test_dl_new, tr_transforms, num_processes=4, # num_processes=4
                                    num_cached_per_queue=3,
                                    seeds=None, pin_memory=False)

#     test_gen_new.restart()
    batch = next(test_gen_new)
    patient_data = batch["data"]
    meta_data = batch["metadata"][0]
    name = batch["names"]
    print(name, patient_data.shape)
    
    test_gen_new._finish()
    del test_dl_new
    continue

  0%|          | 0/5 [00:00<?, ?it/s]

Patch size is: [140, 186, 133]
Patient BraTS20_Validation_001
['brats_data_preprocessed/Brats20ValidationData\\BraTS20_Validation_001'] (1, 3, 140, 186, 133)


 20%|██        | 1/5 [07:57<31:50, 477.53s/it]

Patch size is: [142, 190, 143]
Patient BraTS20_Validation_002
['brats_data_preprocessed/Brats20ValidationData\\BraTS20_Validation_002'] (1, 3, 140, 186, 133)


 40%|████      | 2/5 [14:10<20:48, 416.21s/it]

Patch size is: [140, 165, 125]
Patient BraTS20_Validation_003
['brats_data_preprocessed/Brats20ValidationData\\BraTS20_Validation_003'] (1, 3, 140, 186, 133)


 60%|██████    | 3/5 [20:56<13:42, 411.27s/it]

Patch size is: [138, 180, 132]
Patient BraTS20_Validation_004
['brats_data_preprocessed/Brats20ValidationData\\BraTS20_Validation_004'] (1, 3, 140, 186, 133)


 80%|████████  | 4/5 [35:23<09:51, 591.45s/it]

Patch size is: [139, 163, 147]
Patient BraTS20_Validation_005
['brats_data_preprocessed/Brats20ValidationData\\BraTS20_Validation_005'] (1, 3, 140, 186, 133)


100%|██████████| 5/5 [41:36<00:00, 499.39s/it]


In [109]:
patient_data.shape

(1, 3, 24, 128, 128)

In [108]:
batch["metadata"][0]

{'spacing': array([1., 1., 1.]),
 'direction': (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0),
 'origin': (-0.0, -239.0, 0.0),
 'original_shape': (155, 240, 240),
 'nonzero_region': array([[  0, 138],
        [ 45, 207],
        [ 44, 190]], dtype=int64)}

In [71]:
test_dl_new = BRATSDataLoader(
target_patients,
batch_size=1,
patch_size=patch_size,
in_channels=in_channels
) 

# What if we apply the same process to the test data first? 
test_gen_new =  MultiThreadedAugmenter(test_dl_new, tr_transforms, num_processes=4, # num_processes=4
                                num_cached_per_queue=3,
                                seeds=None, pin_memory=False)

for patient in range(len(target_patients)):
#     name = patient.split("\\")[-1]
#     print(f"Patient {name}")

    batch = next(test_gen_new)
    patient_data = batch["data"]
    meta_data = batch["metadata"]
    name = batch["names"]
    print(name, patient_data.shape)

    continue

ERROR:root:MultiThreadedGenerator: caught exception: (<class 'KeyboardInterrupt'>, KeyboardInterrupt(), <traceback object at 0x0000025D49827F00>)


KeyboardInterrupt: 

In [None]:
stringlol = "brats_data_preprocessed/Brats20ValidationData/BraTS20_Validation_001"

stringlol.split('/')[-1]

In [None]:
target_patients

In [None]:
next(test_gen)['data'].shape

In [None]:
test_gen2 = SingleThreadedAugmenter()

In [None]:
for idx, batch in enumerate(test_gen):
    print(idx,batch["names"])

In [None]:
batch.keys()

In [None]:
for idx, i in enumerate(test_gen):
    print(i)

In [None]:
batch = next(test_dl)

In [None]:
batch = next(tr_gen)
batch["data"][0].shape

In [None]:
batch = next(tr_gen)
batch["data"][0].shape

In [None]:
some_pat = batch["data"][0][0]
some_pat = np.moveaxis(some_pat, 0, -1)
print(some_pat.shape)

In [None]:
def show_slices(slices):
    """ Function to display row of image slices """
    fig, axes = plt.subplots(1, len(slices))
    for i, slice in enumerate(slices):
        axes[i].imshow(slice.T, cmap="gray", origin="lower")

In [None]:
slice_0 = some_pat[65, :, :]
slice_1 = some_pat[:, 65, :]
slice_2 = some_pat[:, :, 5]

show_slices([slice_0,slice_1,slice_2])


In [None]:
slice_0 = some_pat[65, :, :]
slice_1 = some_pat[:, 65, :]
slice_2 = some_pat[:, :, 15]

show_slices([slice_0,slice_1,slice_2])


In [None]:
train_dl.generate_train_batch()["data"].shape

In [None]:
next(iterate_through_patients(target_patients, in_channels))[0].shape

In [None]:
import nibabel as nib

nib.load("MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1.nii.gz").get_fdata().shape

In [None]:
data , metadata = BRATSDataLoader.load_patient(base+"Brats20TrainingData/BraTS20_Training_001")
data.shape

In [None]:
data , metadata = BRATSDataLoader.load_patient(base+"Brats20ValidationData/BraTS20_Validation_001")
data.shape

In [None]:
pad_nd_image(data, [144, 192, 192]).shape