In [None]:
import os
import random
import h5py
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# funky libraries for segmentation
import segmentation_models_3D as sm
from patchify import patchify, unpatchify

In [None]:
# Folder that holds the dataset files (one file per sample).
PATH_DATASET='./challenge_dataset/'
# False: train on all data (each volume is split into 64-voxel cubes, 9 per volume
# according to the original note). True: train only on the (64, 64, 64) cube around
# the aneurysm, i.e. less data.
CENTER_CUBE_ONLY = False
# Fraction of the full dataset reserved for test samples.
TEST_SIZE = 0.2
# Fraction of the training samples kept aside for the validation metrics.
VAL_SPLIT = 0.2
# Edge length of the crop window; the centre crop spans [CROP, 2*CROP) on the last two axes.
CROP = 64

In [None]:
# Gather the dataset files. os.listdir returns entries in arbitrary order, so the
# list is sorted to keep sample indices (e.g. raw_data[4]) reproducible between runs.
file_names = sorted(os.listdir(PATH_DATASET))
N = len(file_names)
print(f'{N} samples in dataset.')

# Inputs and target masks, one array per file until they are merged below.
# `names` keeps one entry per file, so it does not line up with the sample axis
# of raw_data / labels when every file contributes several patches.
raw_data = []
labels = []
names = []

patch_shape = (CROP, CROP, CROP)

for file_name in tqdm(file_names):
    # The context manager closes the HDF5 handle once both datasets are in memory.
    with h5py.File(os.path.join(PATH_DATASET, file_name), 'r') as f:
        X, Y = np.array(f['raw']), np.array(f['label'])

    if CENTER_CUBE_ONLY:
        # Keep only the central window [CROP, 2*CROP) of the last two axes; axis 0 is
        # kept whole. The new leading axis makes this a "stack of one" sample.
        X = X[np.newaxis, :, CROP:2*CROP, CROP:2*CROP]
        Y = Y[np.newaxis, :, CROP:2*CROP, CROP:2*CROP]
    else:
        # Keep everything: cut the volume into non-overlapping cubes (step == patch
        # size) and flatten the patch grid into a single leading axis.
        X = patchify(X, patch_shape, step=CROP).reshape((-1,) + patch_shape)
        Y = patchify(Y, patch_shape, step=CROP).reshape((-1,) + patch_shape)

    raw_data.append(X)
    labels.append(Y)
    names.append(file_name)

# Merge the per-file stacks along the sample axis. Unlike np.array(...) followed by
# a reshape, this also works when files do not all yield the same number of patches.
raw_data = np.concatenate(raw_data, axis=0)
labels = np.concatenate(labels, axis=0)

# check shapes
print(raw_data.shape)
print(labels.shape)

In [None]:
# Sample to preview and the index of the slice taken along the sample's first axis.
ID = 4
slice_idx = 32  # deliberately not called `slice`, which would shadow the builtin

# Input slice on the left, matching target mask on the right.
fig, (ax_raw, ax_label) = plt.subplots(1, 2)
ax_raw.imshow(raw_data[ID, slice_idx])
ax_raw.set_title(f'raw (sample {ID}, slice {slice_idx})')
ax_label.imshow(labels[ID, slice_idx])
ax_label.set_title('label')
plt.show()

# sequence

In [None]:
from sequence import DataGenerator
from volumentations import *

def get_augmentation(patch_size):
    """Assemble the volumentations pipeline handed to the DataGenerator.

    Args:
        patch_size: output size given to ``Resize`` (this notebook passes
            ``(64, 64, 64)``).

    Returns:
        A ``Compose`` object built with a pipeline-level ``p=0.8``.
    """
    transforms = [
        Rotate((-15, 15), (0, 0), (0, 0), p=0.5),
        ElasticTransform((0, 0.25), interpolation=2, p=0.1),
        # Resize is flagged always_apply so the output size stays patch_size.
        Resize(patch_size, interpolation=1, resize_type=0, always_apply=True, p=0.5),
        Flip(0, p=0.5),
        Flip(1, p=0.5),
        Flip(2, p=0.5),
        RandomRotate90((1, 2), p=0.5),
    ]
    # Transforms currently left out of this pipeline:
    #   RandomCropFromBorders(crop_value=0.1, p=0.5)
    #   GaussianNoise(var_limit=(0, 5), p=0.2)
    #   RandomGamma(gamma_limit=(80, 120), p=0.2)
    return Compose(transforms, p=0.8)

# Augmentation pipeline producing (64, 64, 64) outputs.
aug = get_augmentation((64, 64, 64))

BATCH_SIZE = 5
x, y = raw_data, labels

# Batch generator over the whole dataset; shuffle is off, so the sample order is
# not randomised by this flag.
gen = DataGenerator(
    raw_data=x,
    labels=y,
    augmentator=aug,
    batch_size=BATCH_SIZE,
    input_shape=(64, 64, 64),
    shuffle=False,
)

In [None]:
# Smoke test: iterate over the generator and print the shape of each input batch.
# NOTE(review): this terminates only if iterating `gen` is finite — confirm in sequence.DataGenerator.
for batch_x, batch_y in gen:
    print(batch_x.shape)

In [None]:
# Slice and sample to inspect. Named so they do not shadow the builtins `slice` and `id`.
slice_idx = 32
sample_id = 4

# First batch of the generator. The comparison below presumes that element
# `sample_id` of this batch is sample `sample_id` of the dataset — the generator is
# created with shuffle=False, and sample_id must stay below the batch size;
# verify against sequence.DataGenerator.
batch_x, batch_y = gen[0]

img = raw_data[sample_id, slice_idx]
lbl = labels[sample_id, slice_idx]

aug_img = batch_x[sample_id, slice_idx]
aug_lbl = batch_y[sample_id, slice_idx]

print('batch shape', batch_x.shape)

# Top row: original input / mask. Bottom row: what the generator returned.
panels = [
    ('raw', img),
    ('label', lbl),
    ('generator output: raw', aug_img),
    ('generator output: label', aug_lbl),
]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax, (title, panel) in zip(axes.flat, panels):
    ax.imshow(panel)
    ax.set_title(title)
plt.show()

# augmented_generator

In [None]:
from augmented_generator import AugmentedPairGenerator, CustomImageDataGenerator

# Augmentation settings forwarded to AugmentedPairGenerator through `data_gen_params`.
data_gen_params = {
    'validation_split': VAL_SPLIT,
    'rotation_range': 180,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1,
    'zoom_range': 0.1,
    'horizontal_flip': True,
    'vertical_flip': True,
    'fill_mode': 'nearest',
    'data_format': 'channels_first',
    # 'shear_range': 45,  # in degrees counterclockwise (disabled)
}

# Generator built from the inputs and their masks, with a fixed seed.
augmented_gen = AugmentedPairGenerator(
    x=raw_data,
    y=labels,
    data_gen_params=data_gen_params,
    batch_size=16,
    seed=0,
)

# Iterator used by the following cells to draw (image batch, mask batch) pairs.
pair_gen = augmented_gen.pair_generator

In [None]:
# Draw a single batch from the paired generator and report both shapes.
batch = next(pair_gen)
x, y = batch
print(x.shape, y.shape, sep='\n')

# Display slice 32 of the first image in the batch.
fig, ax = plt.subplots()
ax.imshow(x[0, 32])
plt.show()

We can also check that the masks undergo the same transformations as the images.

In [None]:
# Slice to inspect (not named `slice`, which would shadow the builtin) and the
# number of examples to display before stopping.
slice_idx = 32
max_examples = 20

# Show image/mask pairs whose mask contains the value 1 on the inspected slice.
# The counter is incremented for every displayed pair and both loops are left once
# the limit is reached; without that, the cell never finishes when `pair_gen`
# yields batches endlessly.
count = 0
for img_batch, lbl_batch in pair_gen:
    for volume, mask in zip(img_batch, lbl_batch):
        if not np.any(mask[slice_idx] == 1):
            continue

        fig, (ax_img, ax_mask) = plt.subplots(1, 2)
        ax_img.imshow(volume[slice_idx])
        ax_img.set_title('augmented image')
        ax_mask.imshow(mask[slice_idx])
        ax_mask.set_title('augmented mask')
        plt.show()

        count += 1
        if count == max_examples:
            break

    # The inner `break` only leaves the per-batch loop; stop the generator loop too.
    if count == max_examples:
        break

# CustomImageDataGenerator

In [None]:
# Augmentation settings, unpacked as keyword arguments of CustomImageDataGenerator.
data_gen_params = {
    'validation_split': VAL_SPLIT,
    'rotation_range': 180,
    'width_shift_range': 0.1,
    'height_shift_range': 0.1,
    'zoom_range': 0.1,
    'horizontal_flip': True,
    'vertical_flip': True,
    'fill_mode': 'nearest',
    'data_format': 'channels_first',
    # 'shear_range': 45,  # in degrees counterclockwise (disabled)
}

data_gen = CustomImageDataGenerator(**data_gen_params)

In [None]:
# Flow over the inputs only: the `y=labels` argument is intentionally left out here.
data_generator = data_gen.flow(x=raw_data, batch_size=16, seed=0)

In [None]:
# Draw one batch from the generator. The call is a bare expression and nothing is
# assigned, so the batch is not kept in a variable.
next(data_generator)
# Disabled inspection code (it expects the batch to have been stored in `batch` first):
# x = batch
# print(x.shape)
# print(y.shape)

# plt.figure()
# plt.imshow(x[0,32])
# plt.show()

# Volumentations-3D

https://github.com/ZFTurbo/volumentations#volumentations-3d

In [None]:
!pip install volumentations-3D

In [None]:
from volumentations import *

In [None]:
def get_augmentation(patch_size):
    """Build the full volumentations pipeline (geometric and intensity transforms).

    Args:
        patch_size: output size given to ``Resize`` (this notebook passes
            ``(64, 64, 64)``).

    Returns:
        A ``Compose`` object built with a pipeline-level ``p=1.0``.
    """
    geometric = [
        Rotate((-15, 15), (0, 0), (0, 0), p=0.5),
        RandomCropFromBorders(crop_value=0.1, p=0.5),
        ElasticTransform((0, 0.25), interpolation=2, p=0.1),
        # Resize is flagged always_apply so the output size stays patch_size.
        Resize(patch_size, interpolation=1, resize_type=0, always_apply=True, p=1.0),
        Flip(0, p=0.5),
        Flip(1, p=0.5),
        Flip(2, p=0.5),
        RandomRotate90((1, 2), p=0.5),
    ]
    intensity = [
        GaussianNoise(var_limit=(0, 5), p=0.2),
        RandomGamma(gamma_limit=(80, 120), p=0.2),
    ]
    return Compose(geometric + intensity, p=1.0)


# Pipeline instance used by the following cells.
aug = get_augmentation((64, 64, 64))

In [None]:
# Sample 4 of the dataset: input volume and its target mask.
img = raw_data[4]
lbl = labels[4]

# Augment image and mask together so both receive the same transformation.
# (To augment an image alone, pass only the 'image' key and read it back.)
data = {'image': img, 'mask': lbl}
aug_data = aug(**data)
aug_img, aug_lbl = aug_data['image'], aug_data['mask']

# Slice to display; not named `slice`, which would shadow the builtin.
slice_idx = 32

# Top row: original input / mask. Bottom row: augmented input / mask.
panels = [
    ('raw', img),
    ('label', lbl),
    ('augmented raw', aug_img),
    ('augmented label', aug_lbl),
]
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
for ax, (title, volume) in zip(axes.flat, panels):
    ax.imshow(volume[slice_idx])
    ax.set_title(title)
plt.show()

In [None]:
# Build the augmentation pipeline once instead of on every mapped element.
patch_shape = (64, 64, 64)
tf_aug = get_augmentation(patch_shape)


def _augment_volume(volume):
    """Run the NumPy-based augmentation on a single volume (image only, no mask)."""
    data = {'image': volume}
    aug_data = tf_aug(**data)
    # Cast explicitly so the dtype always matches the Tout declared in tf.numpy_function.
    return aug_data['image'].astype(np.float32)


def data_augmentation(image):
    """Augment one dataset element inside a tf.data pipeline.

    Args:
        image: tensor holding one volume of the dataset.

    Returns:
        A float32 tensor of shape `patch_shape` with the augmented volume.
    """
    # Dataset.map traces this function, so `image` is a symbolic tensor here.
    # tf.numpy_function hands the actual NumPy array to the augmentation at run time.
    # The original body ignored `image` and augmented the global `img` instead.
    augmented_image = tf.numpy_function(_augment_volume, [image], tf.float32)
    # numpy_function drops static shape information; restore it for downstream ops.
    augmented_image.set_shape(patch_shape)
    return augmented_image


# Create a dataset with one element per volume. Slicing `raw_data[4]` instead would
# yield the 2-D slices of a single volume, which the 3-D pipeline cannot process.
dataset = tf.data.Dataset.from_tensor_slices(raw_data)

# Shuffle individual samples first (shuffling after batching only reorders whole
# batches), then augment, batch and repeat.
dataset = dataset.shuffle(buffer_size=1024)
dataset = dataset.map(data_augmentation)
dataset = dataset.batch(32).repeat()