## Module loading

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.keras as keras
import tensorflow as tf
import os
import nibabel as nib
import random
import re
from sklearn.model_selection import train_test_split
from natsort import natsorted
from collections import Counter
from tensorflow.keras.utils import plot_model
from tensorflow.keras import mixed_precision
import progressbar
from modules.generator import DataGenerator
from modules.model import Unet2D

## Data loading

The data must be in the following format :
- one **metadata.hdf5** file containing the following variables :
    - *"patientnames"*, a list with all patient identifiers
    - *"shape_x"*, the numpy shape of the X array - typically, (n, 256, 256, 27, 3)
    - *"shape_y"*, the numpy shape of the Y array - typically, (n, 256, 256, 27, 1)
    - *"shape_mask"*, the numpy shape of the Brain mask array - typically, (n, 256, 256, 27, 1)
- Three **data_?.dat** files stored as NumPy memmaps:
    - *"data_x.dat"* in float32 with the following sequences stored in this order: 
        - H0 DWI b1000 (normalized with centered mean and divided by standard deviation)
        - ADC (in 10⁻⁶ mm²/s)
        - TMax map (in seconds)
    - *"data_y.dat"* in float32 with H24 stroke segmentations (binary)
    - *"data_mask.dat"* in uint8 with the brain weighting sequence
        - value = 0 for out-of-brain voxels
        - value = 1 for in-brain voxels

In [None]:
sourcedir = "data/"     # directory containing metadata.hdf5 and the data_*.dat files
model_path = "models/"  # directory the model is saved to


def _read_only_memmap(filename, dtype, shape):
    """Open `filename` from `sourcedir` as a read-only NumPy memmap."""
    return np.memmap(os.path.join(sourcedir, filename), dtype=dtype, mode="r", shape=shape)


# Patient identifiers and the array shapes are read from the HDF5 metadata file.
metadata_file = os.path.join(sourcedir, "metadata.hdf5")
with h5py.File(metadata_file, "r") as data:
    # Identifiers are decoded from bytes to str.
    train_names = [raw_name.decode() for raw_name in data["patientnames"]]
    shape_x, shape_y, shape_mask = (
        tuple(data[key]) for key in ("shape_x", "shape_y", "shape_mask")
    )

# The arrays are memory-mapped rather than read into RAM.
datax = _read_only_memmap("data_x.dat", "float32", shape_x)
datay = _read_only_memmap("data_y.dat", "float32", shape_y)
datamask = _read_only_memmap("data_mask.dat", "uint8", shape_mask)


## Data splitting

The data are split into a training set and a test set.

In [None]:
# Parameters of the train/test split performed in the next cell.
TEST_SIZE = 0.2  # passed as `test_size` to train_test_split
RANDOM_SEED = 1000  # passed as `random_state` to train_test_split

In [None]:
# Split positions into `train_names` (not the names themselves), so the
# resulting indices can also be used to index the data arrays.
# No `stratify=` argument is passed, so this is a plain random split.
all_positions = range(len(train_names))
train_index, test_index = train_test_split(
    all_positions, test_size=TEST_SIZE, random_state=RANDOM_SEED
)

# Tally the patient identifiers that ended up in each subset.
print("Stratification count")
for label, subset in (("Training set: ", train_index), ("Test set: ", test_index)):
    print(label, Counter(train_names[k] for k in subset))

## Showing erratic data
Looks for volumes whose DWI voxel values fall below -5 or above 12 and shows the middle slice of each.

Please check the corresponding volumes of these patients

In [None]:
# Per-volume extrema of the DWI channel (channel 0 of datax).
dwi = datax[..., 0]
flatmax = dwi.max(axis=(1, 2, 3))
flatmin = dwi.min(axis=(1, 2, 3))

# Indices of the volumes whose DWI values leave the [-5, 12] range.
erratic = np.where(np.logical_or(flatmax > 12, flatmin < -5))[0]
if len(erratic) > 0:
    plt.rcParams['figure.figsize'] = [15, 5]
    print([train_names[i] for i in erratic])
    # Compute the middle slice from the data instead of hard-coding an index,
    # so the cell also works for volumes with a different number of slices.
    middle_slice = datax.shape[3] // 2
    for i, j in enumerate(erratic):
        plt.subplot(1, len(erratic), i + 1)
        plt.title(train_names[j])
        # Show the DWI image itself: it is the channel being checked,
        # whereas the segmentation (datay) says nothing about erratic values.
        plt.imshow(np.flipud(datax[j, :, :, middle_slice, 0].T), cmap='gray')

## Checking data generation

In [None]:
# Generator over all patients, with augmentation and input scaling enabled,
# used only to eyeball a few generated samples.
check_generator = DataGenerator(
    datax=datax,
    datay=datay,
    mask=datamask,
    indices=np.arange(len(train_names)),
    shuffle=True,
    flatten_output=False,
    batch_size=1,
    dim_z=1,
    augment=True,
    flipaugm=True,
    brightaugm=[True, True, False],
    gpu_augment=True,
    scale_input=True,
    scale_input_lim=[(-5, 12), (0, 7500.0), (-30, 120)],
    scale_input_clip=[True, True, True],
    only_stroke=True,
    give_mask=True,
)

check_gen_iter = check_generator.getnext()

plt.rcParams['figure.figsize'] = [15, 15]
n_row = 4
n_col = 5
for row in range(n_row):
    sampleX, sampleY = next(check_gen_iter)
    # One entry per column: (title, 2D image, vmin, vmax).
    panels = [
        ('Diffusion imaging', sampleX["img"][:, :, 0, 0], -0.8, 1),
        ('ADC', sampleX["img"][:, :, 0, 1], -0.8, 1),
        ('TMax', sampleX["img"][:, :, 0, 2], -1.2, 1),
        ('Mask', sampleX["mask"][:, :, 0], 0, 1),
        ('Final stroke segmentation', sampleY[:, :, 0], 0, 1),
    ]
    for col, (title, image, lower, upper) in enumerate(panels, start=1):
        plt.subplot(n_row, n_col, row * n_col + col)
        plt.title(title)
        # Transpose and flip vertically before display.
        plt.imshow(np.flipud(image.T), cmap='gray', vmin=lower, vmax=upper)

## Create and train model

In [None]:
# Mini-batch size applied by the tf.data pipeline. It must be defined before
# the dataset is built, because `.batch(batch_size)` is evaluated immediately.
batch_size = 16

# Name shared by the checkpoint file and the TensorBoard run directory.
lastcheckpoint = "checkpoint"

# Training generator restricted to the training indices, without augmentation.
# NOTE(review): the scaling keywords now match the ones used when checking the
# generator (`scale_input` / `scale_input_lim`); the original `scale_lim`
# keyword appears nowhere else - confirm against DataGenerator's signature.
train_generator = DataGenerator(datax, datay, indices=train_index, dim_z=3, batch_size=1, augment=False,
                                scale_input=True, scale_input_lim=[(-5, 12), (0, 7500.0), (-30, 120)],
                                brightaugm=False)

# Wrap the generator in a tf.data pipeline. The declared element shapes are
# those of a single sample; batching is done here by `.batch(batch_size)`.
float_type = keras.backend.floatx()
dsT = tf.data.Dataset.from_generator(
    train_generator.getnext,
    ({"img": float_type}, float_type),
    ({"img": (256, 256, 3, 3)}, (256 * 256,)),
).repeat().batch(batch_size).prefetch(16)

# Build the network. `Unet2D` is the name imported from modules.model.
input_img = keras.Input((256, 256, 3, 3), name='img')
model = Unet2D(input_img, n_filters=256, dropout=0.5, batchnorm=True)

model.compile(optimizer=keras.optimizers.Adam(0.003), loss="binary_crossentropy")

# The dataset repeats indefinitely, so the epoch length is set explicitly.
model.fit(dsT, epochs=20, steps_per_epoch=len(train_generator)//batch_size, callbacks=[
        keras.callbacks.ModelCheckpoint("output/checkpoints/" + lastcheckpoint, verbose=1, save_weights_only=True),
        keras.callbacks.TensorBoard(log_dir="output/tf_logs/" + lastcheckpoint + "/"),
])

## Model export

In [None]:
# Write `model` to the location given by `model_path`.
model.save(model_path)