## Module loading

In [None]:
import h5py
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.keras as keras
import tensorflow as tf
import os
import nibabel as nib
import random
import re
from sklearn.model_selection import train_test_split
from natsort import natsorted
from collections import Counter
from tensorflow.keras.utils import plot_model
from tensorflow.keras import mixed_precision
import progressbar
from modules.generator import DataGenerator
from modules.model import Generator, Discriminator
from modules.losses import generator_loss, discriminator_loss
from modules.figures import figure

## Define GPU Strategy

In [None]:
# Run-time switches read by the cells below:
# - multiGPU: build and train the models inside a tf.distribute.MirroredStrategy scope.
# - mixedPrecision: set the global Keras policy to 'mixed_float16' and wrap the
#   optimizers in LossScaleOptimizer.
multiGPU = True
mixedPrecision = True

In [None]:
if multiGPU:
    # Synchronous data-parallel strategy over the GPUs visible to TensorFlow.
    GPUstrategy = tf.distribute.MirroredStrategy()

    def parallelize(func):
        """Call ``func()`` inside the distribution strategy scope.

        Returns whatever ``func`` returns, so callers can use the result
        instead of relying only on globals set by ``func``.
        """
        with GPUstrategy.scope():
            return func()
else:
    def parallelize(func):
        """Call ``func()`` directly (no distribution strategy) and return its result."""
        return func()

if mixedPrecision:
    # Global Keras policy: layers created after this point use 'mixed_float16'.
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)

## Data loading

The data must be in the following format :
- one **metadata.hdf5** file containing the following variables :
    - *"patientnames"*, a list with all patient identifiers
    - *"shape_x"*, the numpy shape of the X array - typically, (n, 256, 256, 25, 3)
    - *"shape_y"*, the numpy shape of the Y array - typically, (n, 256, 256, 25, 1)
    - *"shape_mask"*, the numpy shape of the Brain mask array - typically, (n, 256, 256, 25, 1)
    - *"shape_meta"*, the numpy shape of the Metadata array - typically, (n, 2)
- Four **data_?.dat** files, each a numpy memmap:
    - *"data_x.dat"* in float32 with the following sequences stored in this order: 
        - b0 DWI (normalized with centered mean and divided by standard deviation)
        - b1000 DWI (normalized with centered mean and divided by standard deviation)
        - ADC, expressed in 10⁻⁶ mm²/s
    - *"data_y.dat"* in float32 with the realFLAIR sequences (normalized)
    - *"data_mask.dat"* in uint8 with the brain weighting sequence
        - value = 0 for out-of-brain voxels
        - value = 1 for in-brain voxels
        - value = 2 for stroke-region voxels
    - *"data_meta.dat"* in float32 containing for each datapoint :
        - the corresponding quality (0, 1, 2, 3) 
        - the corresponding timepoint (0, 1) 

In [None]:
sourcedir = "data/" # Data directory


def _open_memmap(filename, dtype, shape):
    """Open one of the raw data files in ``sourcedir`` as a read-only numpy memmap."""
    return np.memmap(os.path.join(sourcedir, filename), dtype=dtype, mode="r", shape=shape)


# Patient identifiers and array shapes are stored in the HDF5 metadata file.
with h5py.File(os.path.join(sourcedir, "metadata.hdf5"), "r") as metadata:
    train_names = [name.decode() for name in metadata["patientnames"]]
    shape_x, shape_y, shape_mask, shape_meta = (
        tuple(metadata[key]) for key in ("shape_x", "shape_y", "shape_mask", "shape_meta")
    )

# The arrays themselves stay on disk and are read lazily through memmaps.
datax = _open_memmap("data_x.dat", "float32", shape_x)
datay = _open_memmap("data_y.dat", "float32", shape_y)
datamask = _open_memmap("data_mask.dat", "uint8", shape_mask)
datameta = _open_memmap("data_meta.dat", "float32", shape_meta)

## Stratified data splitting

The data is split into training and test sets, stratified on quality and timepoint.
The training data is then split into training and validation sets.

In [None]:
# Fraction of all volumes held out as the test set (train_test_split test_size).
TEST_SIZE = 0.2
# Fraction of the remaining training volumes held out for validation.
VALIDATION_SIZE = 0.2
# random_state shared by both splits so they are reproducible.
RANDOM_SEED = 1000

In [None]:
# One "<meta column 0>_<meta column 1>" label per volume, used as the stratification key.
total_stratmri = [
    f"{int(datameta[i, 0])}_{int(datameta[i, 1])}"
    for i in range(shape_x[0])
]

all_indices = range(len(total_stratmri))
train_index, test_index = train_test_split(
    all_indices,
    stratify=total_stratmri,
    test_size=TEST_SIZE,
    random_state=RANDOM_SEED,
)

# Report how many volumes of each stratum ended up in each subset.
print("Stratification count")
for label, subset in (("Training set: ", train_index), ("Test set: ", test_index)):
    print(label, Counter(total_stratmri[i] for i in subset))

In [None]:
# Second split: carve the validation set out of the training indices.
# NOTE(review): no `stratify` argument is passed here, so unlike the train/test
# split above this one is a plain shuffled split.
small_train_index, valid_index = train_test_split(train_index, test_size=VALIDATION_SIZE, 
                                                  shuffle=True, random_state=RANDOM_SEED)

## Showing erratic data
Looks for volumes containing voxel values below -50 or above 50 and displays one slice (index 16) from each of them.

Please check the corresponding volumes of these patients

In [None]:
# Per-volume extrema of the target images (every axis except the first is reduced).
flatmax = datay[..., 0:2].max(axis=(1, 2, 3, 4))
flatmin = datay[...].min(axis=(1, 2, 3, 4))

# Indices of the volumes whose intensities leave the [-50, 50] range.
erratic = np.flatnonzero((flatmax > 50) | (flatmin < -50))
if erratic.size:
    plt.rcParams['figure.figsize'] = [15, 5]
    print([train_names[idx] for idx in erratic])
    # One panel per suspicious volume, showing slice 16 of channel 0.
    for panel, j in enumerate(erratic, start=1):
        plt.subplot(1, len(erratic), panel)
        plt.imshow(np.flipud(datay[j, :, :, 16, 0].T), cmap='gray')

## Checking data generation

In [None]:
# Generator over every volume, used only to eyeball the augmented samples.
check_generator = DataGenerator(
    datax=datax,
    datay=datay,
    datac=datameta.astype(np.uint8),
    mask=datamask,
    indices=np.arange(len(train_names)),
    shuffle=True,
    flatten_output=False,
    batch_size=1,
    dim_z=1,
    augment=True,
    flipaugm=True,
    brightaugm=[True, True, False],
    gpu_augment=True,
    scale_input=True,
    scale_input_lim=[(-5, 12), (-5, 12), (0, 7500.0)],
    scale_input_clip=[True, True, False],
    scale_output=True,
    scale_output_lim=(-5, 10),
    scale_output_clip=True,
    only_stroke=True,
    give_mask=True,
    give_meta=True,
)

check_gen_iter = check_generator.getnext()

plt.rcParams['figure.figsize'] = [15, 15]
n_row = 4
n_col = 5
for i in range(n_row):
    sampleX, sampleY = next(check_gen_iter)
    # (title, 2D image, vmin, vmax) for each column of the current row.
    panels = [
        ('Diffusion imaging (b0)', sampleX["img"][:, :, 0, 0], -0.8, 1),
        ('Diffusion imaging (b1000)', sampleX["img"][:, :, 0, 1], -0.8, 1),
        ('ADC', sampleX["img"][:, :, 0, 2], -1.2, 1),
        ('Mask', sampleX["mask"][:, :, 0], 0, 2),
        ('real FLAIR', sampleY[:, :, 0], -0.8, 1),
    ]
    for col, (title, image, vmin, vmax) in enumerate(panels, start=1):
        plt.subplot(n_row, n_col, i * n_col + col)
        plt.title(title)
        plt.imshow(np.flipud(image.T), cmap='gray', vmin=vmin, vmax=vmax)

## Create model

In [None]:
# Weight handed to generator_loss together with the L1 inputs.
LAMBDA_L1 = 100
# Upper bound of the per-epoch value passed to the training step as lambda_sobel.
MAX_LAMBDA_EGDE = 100
# Exponent (in powers of ten) applied to the brain mask; smaller under mixed precision.
MASK_WEIGHTING = 3 if mixedPrecision else 7

# Output locations.
figures_dir, log_dir, model_dir = "figures", "logs", "models"
model_name = "synthflair"
checkpoint_prefix = os.path.join(model_dir, model_name)

# First epoch to run; overwritten when a checkpoint is restored.
epoch = 0

In [None]:
def defModel():
    """Create the models, optimizers, checkpoint object and summary writer.

    Everything is published through module-level globals (declared below) so
    that the later notebook cells can use them; the same objects are also
    returned as a tuple.

    Returns:
        (generator, discriminator, generator_optimizer,
        discriminator_optimizer, checkpoint, summary_writer)
    """
    global generator, discriminator, generator_optimizer, discriminator_optimizer, checkpoint, summary_writer
    generator = Generator()
    discriminator = Discriminator()
    generator_optimizer = tf.keras.optimizers.Adam(1e-5, beta_1=0.5)
    discriminator_optimizer = tf.keras.optimizers.Adam(1e-5, beta_1=0.5)
    # The checkpoint is built before the optimizers are (optionally) wrapped
    # below, so it tracks the plain Adam optimizers, not the LossScaleOptimizer
    # wrappers.
    checkpoint = tf.train.Checkpoint(generator=generator,
                                     discriminator=discriminator,
                                     generator_optimizer=generator_optimizer,
                                     discriminator_optimizer=discriminator_optimizer)
    # TensorBoard logs go to <log_dir>/fit/<model_name>.
    summary_writer = tf.summary.create_file_writer(log_dir + "/fit/" + model_name)
    if mixedPrecision:
        # Loss scaling is needed by the train step (get_scaled_loss /
        # get_unscaled_gradients) when running in mixed precision.
        generator_optimizer = mixed_precision.LossScaleOptimizer(generator_optimizer)
        discriminator_optimizer = mixed_precision.LossScaleOptimizer(discriminator_optimizer)

    return generator, discriminator, generator_optimizer, discriminator_optimizer, checkpoint, summary_writer

# Build everything through parallelize(), i.e. inside the strategy scope when multiGPU is on.
parallelize(defModel)

In [None]:
# Load model if already trained
def loadModel():
    """Restore the latest checkpoint from ``model_dir`` and resume after its epoch.

    The epoch number is parsed from the checkpoint path, which is expected to
    contain ``_epoch<N>-`` (the pattern produced by the ``checkpoint.save``
    calls in this notebook). When there is no checkpoint, or its name does not
    match that pattern, nothing is restored and the global ``epoch`` is left
    unchanged.
    """
    global epoch
    last_checkpoint = tf.train.latest_checkpoint(model_dir)
    if not last_checkpoint:
        return

    # Raw string: '\d' and '\-' are regex escapes, not Python string escapes.
    match = re.search(r'_epoch(\d+)\-', last_checkpoint)
    if match is None:
        # Checkpoint not written by this notebook's naming scheme: ignore it
        # rather than crash on match.group().
        return

    checkpoint.restore(last_checkpoint)
    # Resume with the epoch following the one that was saved.
    epoch = int(match.group(1)) + 1

parallelize(loadModel)

## Test model and figure export

In [None]:
# Preview the current generator on 2 validation patients; the figure is
# displayed (show=True) and not written to disk (save=False).
figure(datax, datay, datameta, datamask, generator, train_names, valid_index,
       multiquality=False, slices=(3,5), save=False, show=True, n_patients=2)

## Define training and validation functions

In [None]:
def train_step_(inputs, lambda_sobel, apply_gradients=True):
    """Run one forward pass of generator and discriminator and optionally update them.

    Args:
        inputs: tuple ``(predictors, real_flair)`` where ``predictors`` is a
            dict with the keys ``"img"``, ``"mask"`` and ``"meta"``.
        lambda_sobel: value forwarded to ``generator_loss`` right after the
            Sobel maps (presumably the weight of the edge term — confirm in
            modules.losses).
        apply_gradients: when False, the losses are computed but no gradient
            is taken or applied (this is how validation uses the function).

    Returns:
        dict with the keys ``gen_total_loss``, ``gen_gan_loss``,
        ``gen_l1_loss``, ``gen_edge_loss`` and ``disc_loss``.
    """
    predictors, real_flair = inputs
    # First metadata column (documented above as the quality score).
    quality = predictors["meta"][...,0]
    # Keep only index 0 of the mask's 4th axis, then restore a trailing axis of size 1.
    diffusion, mask = predictors["img"], predictors["mask"][:,:,:,0]
    mask = tf.reshape(mask, mask.shape+(1,))

    # Added under the square roots below so their argument is never exactly 0.
    # Note that 10e-7 == 1e-6 and 10e-12 == 1e-11.
    if mixedPrecision:
        epsilon = 10e-7
    else:
        epsilon = 10e-12
    # Exponential weighting: mask value m becomes 10**((m-1)*MASK_WEIGHTING),
    # i.e. m == 1 maps to a weight of 1 (assumes the mask still holds the
    # 0/1/2 labels described in the data section — TODO confirm).
    weighted_mask = 10**((mask-1)*MASK_WEIGHTING)
    # Max-pooled (pool size 32) version of the weights, passed to both loss
    # functions (presumably to match the discriminator output size — confirm).
    mask_pooled = tf.keras.layers.MaxPool2D(32)(weighted_mask)

    # Both tapes record the same forward pass; one is used for each network.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        pseudo_flair = generator([diffusion, quality], training=True)
        # Sobel gradient magnitude of the generated and of the real image.
        pseudo_filtered = tf.image.sobel_edges(pseudo_flair)
        pseudo_sobel = tf.math.sqrt(tf.square(pseudo_filtered[...,0]) + tf.square(pseudo_filtered[...,1]) + epsilon)
        real_filtered = tf.image.sobel_edges(real_flair)
        real_sobel = tf.math.sqrt(tf.square(real_filtered[...,0]) + tf.square(real_filtered[...,1]) + epsilon)

        # The discriminator sees the inputs, the edge map, the quality and the
        # (real or generated) FLAIR image.
        disc_real_output = discriminator([diffusion, real_sobel, quality, real_flair], training=True)
        disc_generated_output = discriminator([diffusion, pseudo_sobel, quality, pseudo_flair], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss, gen_edge_loss = generator_loss(disc_generated_output, 
                                                                               pseudo_flair, real_flair, LAMBDA_L1,
                                                                               pseudo_sobel, real_sobel, lambda_sobel,
                                                                               weighted_mask, mask_pooled)

        disc_loss = discriminator_loss(disc_real_output, disc_generated_output, weighted_mask, mask_pooled)
        
        if mixedPrecision:
            # Scale the losses inside the tape so the scaling is part of the
            # recorded computation; gradients are unscaled again below.
            scaled_gen_loss = generator_optimizer.get_scaled_loss(gen_total_loss)
            scaled_disc_loss = discriminator_optimizer.get_scaled_loss(disc_loss)

        

    if apply_gradients:
        if mixedPrecision:
            scaled_generator_gradients = gen_tape.gradient(scaled_gen_loss, generator.trainable_variables)
            generator_gradients = generator_optimizer.get_unscaled_gradients(scaled_generator_gradients)
            scaled_discriminator_gradients = disc_tape.gradient(scaled_disc_loss, discriminator.trainable_variables)
            discriminator_gradients = discriminator_optimizer.get_unscaled_gradients(scaled_discriminator_gradients)
        else:
            generator_gradients = gen_tape.gradient(gen_total_loss,
                                                    generator.trainable_variables)
            discriminator_gradients = disc_tape.gradient(disc_loss,
                                                       discriminator.trainable_variables)

        generator_optimizer.apply_gradients(zip(generator_gradients,
                                               generator.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                    discriminator.trainable_variables))

    # The unscaled losses are reported, whether or not gradients were applied.
    return {'gen_total_loss': gen_total_loss, 'gen_gan_loss': gen_gan_loss, 
            'gen_l1_loss': gen_l1_loss, 'gen_edge_loss': gen_edge_loss, 'disc_loss': disc_loss}

def validation_step_(inputs, lambda_sobel):
    """Compute the same losses as ``train_step_`` without updating any weights."""
    return train_step_(inputs, lambda_sobel, apply_gradients=False)

if multiGPU:
    with GPUstrategy.scope():
        def _reduce_replica_losses(per_replica_losses):
            """Sum every loss over axis 0 and over all replicas, then divide by BATCH_SIZE."""
            return {
                name: GPUstrategy.reduce(tf.distribute.ReduceOp.SUM, value, axis=(0,))/BATCH_SIZE
                for name, value in per_replica_losses.items()
            }

        @tf.function
        def train_step(dataset_inputs, epoch):
            """Distributed training step.

            ``epoch`` is forwarded to ``train_step_`` as its ``lambda_sobel``
            argument. Returns a dict of losses reduced across replicas.
            """
            per_replica_losses = GPUstrategy.run(train_step_, args=(dataset_inputs,epoch))
            return _reduce_replica_losses(per_replica_losses)

        @tf.function
        def validation_step(dataset_inputs, epoch):
            """Distributed validation step (no weight update).

            ``epoch`` is forwarded to ``validation_step_`` as its
            ``lambda_sobel`` argument. Returns a dict of losses reduced across
            replicas.
            """
            per_replica_losses = GPUstrategy.run(validation_step_, args=(dataset_inputs,epoch))
            return _reduce_replica_losses(per_replica_losses)
else:
    @tf.function
    def train_step(inputs, lambda_sobel, apply_gradients=True):
        """Single-device training step; returns the loss dict of ``train_step_``."""
        # Forward the caller's choice instead of always applying the gradients.
        return train_step_(inputs, lambda_sobel, apply_gradients=apply_gradients)
    
    @tf.function
    def validation_step(inputs, lambda_sobel):
        """Single-device validation step; returns the loss dict of ``validation_step_``."""
        return validation_step_(inputs, lambda_sobel)

## Train/Validation split and Generators definition

In [None]:
def _make_generator(indices, **augmentation):
    """Build a single-slice DataGenerator over ``indices``.

    Scaling and output options are shared by the training and validation
    generators; only the augmentation switches differ and are passed through
    ``augmentation``.
    """
    return DataGenerator(datax, datay, datameta.astype(np.uint8),
                         mask=datamask, indices=indices, shuffle=True,
                         flatten_output=False, batch_size=1, dim_z=1,
                         augment=True,
                         scale_input=True,
                         scale_input_lim=[(-5,12),(-5,12),(0,7500.0)],
                         scale_input_clip=[False,False,False],
                         scale_output=True, scale_output_lim=(-5,10), scale_output_clip=True,
                         only_stroke=False, give_mask=True, give_meta=True,
                         **augmentation)


def _as_dataset(data_generator):
    """Wrap ``data_generator.getnext`` in a tf.data.Dataset of ({img, mask, meta}, target) pairs."""
    float_type = keras.backend.floatx()
    element_types = ({"img": float_type, "mask": float_type, "meta": float_type}, float_type)
    element_shapes = ({"img":(256,256,1,3), "mask":(256,256,1), "meta":(1,2)}, (256,256,1))
    return tf.data.Dataset.from_generator(data_generator.getnext, element_types, element_shapes)


# Validation: every augmentation switch is off.
valid_generator = _make_generator(valid_index,
                                  brightaugm=[False,False,False], shapeaugm=False,
                                  flipaugm=False, gpu_augment=False)

# Training: shape, flip and brightness augmentation enabled.
train_generator = _make_generator(small_train_index,
                                  shapeaugm=True, brightaugm=[True,True,False],
                                  flipaugm=True, gpu_augment=True)

dsT_pre = _as_dataset(train_generator)
dsV_pre = _as_dataset(valid_generator)

## Training loop

In [None]:
# Last epoch index to run (the training loop iterates up to NB_EPOCHS inclusive).
NB_EPOCHS = 5000
BATCH_SIZE = 128 # Maximize the batch size for your GPU
# Number of full batches per training epoch / per validation pass.
# Relies on len() of the generators (presumably the number of samples, since
# they are built with batch_size=1 — TODO confirm in modules.generator).
NB_SUBEPOCHS = len(train_generator)//BATCH_SIZE
VALIDATION_STEPS = len(valid_generator)//BATCH_SIZE
# Validation runs on epochs that are multiples of this value.
VALIDATION_EACH_EPOCH = 10

# Figures are exported on epochs that are multiples of FIGURE_EACH_EPOCH.
saveFigures = True
FIGURE_EACH_EPOCH = 10

# Checkpoints are written on epochs that are multiples of MODELSAVE_EACH_EPOCH.
saveModels = True
MODELSAVE_EACH_EPOCH = 50

In [None]:
def _batched(dataset):
    """Group ``dataset`` into batches of BATCH_SIZE and prefetch up to 256 of them."""
    return dataset.batch(BATCH_SIZE).prefetch(256)


if not multiGPU:
    dsT = _batched(dsT_pre)
    dsV = _batched(dsV_pre)
else:
    # Let the strategy split each batch across the replicas.
    with GPUstrategy.scope():
        dsT = GPUstrategy.experimental_distribute_dataset(_batched(dsT_pre))
        dsV = GPUstrategy.experimental_distribute_dataset(_batched(dsV_pre))

def train_loop():
    """Train from the current global ``epoch`` up to ``NB_EPOCHS`` (inclusive).

    For every epoch: run up to ``NB_SUBEPOCHS`` training steps while logging the
    losses to TensorBoard and to a progress bar, then periodically save a
    checkpoint, export validation figures and run a validation pass. A final
    checkpoint is written once the loop ends. The global ``epoch`` is updated
    as the loop advances.
    """
    global epoch
    loss_names = ["gen_total_loss", "gen_gan_loss", "gen_l1_loss", "gen_edge_loss", "disc_loss"]

    def is_empty_batch(batch):
        """Return True when the (last local) image batch holds no sample."""
        images = batch[0]["img"]
        if multiGPU:
            images = GPUstrategy.experimental_local_results(images)[-1]
        return images.shape[0] == 0

    def batch_means(losses):
        """Convert a loss dict of tensors into a dict of scalar means."""
        return {name: losses[name].numpy().mean() for name in loss_names}

    for epoch in range(epoch, NB_EPOCHS+1):
        print("Epoch", epoch)
        # The edge weight follows the epoch number, capped at MAX_LAMBDA_EGDE.
        # The same value is used for training and validation so that both
        # report comparable losses.
        lambda_sobel = epoch if epoch < MAX_LAMBDA_EGDE else MAX_LAMBDA_EGDE

        widgets = [progressbar.Percentage(), progressbar.Bar()]
        for name in loss_names:
            widgets += ["    ",
                        progressbar.DynamicMessage(name, format="{name}: {formatted_value}", precision=4)]
        widgets += ["    ", progressbar.ETA()]
        progbar = progressbar.ProgressBar(max_value=NB_SUBEPOCHS, widgets=widgets, term_width=150)

        subepoch = 0
        for dat in dsT:
            if is_empty_batch(dat):
                break
            train_means = batch_means(train_step(dat, lambda_sobel))
            step = epoch * NB_SUBEPOCHS + subepoch
            with summary_writer.as_default():
                for name, value in train_means.items():
                    tf.summary.scalar(name, value, step=step)
            progbar.update(subepoch, **train_means)
            subepoch += 1
            if subepoch >= NB_SUBEPOCHS:
                break

        # Saving Trained Model
        if saveModels and epoch % MODELSAVE_EACH_EPOCH == 0:
            checkpoint.save(file_prefix = checkpoint_prefix+"_epoch"+str(epoch))

        # Saving validation figures
        if saveFigures and epoch % FIGURE_EACH_EPOCH == 0:
            figure(datax, datay, datameta, datamask, generator, train_names, valid_index,
                   multiquality=True, save=True, show=False, n_patients=10, epoch=epoch,
                   output=figures_dir)

        # Validation step
        if epoch % VALIDATION_EACH_EPOCH == 0:
            all_losses = []
            print("VALIDATING at epoch", epoch)
            progbar = tf.keras.utils.Progbar(VALIDATION_STEPS)
            subepoch = 0
            for dat in dsV:
                if is_empty_batch(dat):
                    break
                valid_means = batch_means(validation_step(dat, lambda_sobel))
                all_losses.append(valid_means)
                # Report the validation losses of this batch (not the last
                # training losses).
                progbar.update(subepoch, list(valid_means.items()))
                subepoch += 1
                if subepoch >= VALIDATION_STEPS:
                    break
            mean_losses = {name: np.mean([batch[name] for batch in all_losses])
                           for name in loss_names}
            print("")
            with summary_writer.as_default():
                for name in loss_names:
                    tf.summary.scalar('valid_' + name, mean_losses[name], step=epoch*NB_SUBEPOCHS)
    checkpoint.save(file_prefix = checkpoint_prefix+"_epoch"+str(epoch))
    
parallelize(train_loop)

In [None]:
# Manual save of the current state, e.g. after interrupting the training cell.
checkpoint.save(file_prefix=f"{checkpoint_prefix}_epoch{epoch}")

## Test export

In [None]:
# Export one figure per test volume, saved under "output" and labelled with the patient name.
for test_idx in test_index:
    patient = train_names[test_idx]
    figure(datax, datay, datameta, datamask, generator, train_names, [test_idx],
           multiquality=False, save=True, show=False, n_patients=1, show_outline=True,
           output="output", savestr=patient)

## Model export

In [None]:
# Export the trained generator alone (the discriminator is not saved) to 'saved_generator'.
generator.save('saved_generator')