# Training and test of MultiscaleSR model

## Imports 

In [1]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import zoom
from scipy.ndimage.filters import gaussian_filter
import SimpleITK as sitk
from tensorflow import image, pad
from tensorflow.keras.initializers import RandomNormal, Constant
from tensorflow.keras.layers import Add, Conv3D, Input, ReLU
from tensorflow.keras.models import Model

from adamLRM import AdamLRM
from patches import array_to_patches
from store2hdf5 import store2hdf53D
from utils3D import modcrop3D, shave3D, imadjust3D

## PSNR calculation functions

In [2]:
def psnr_model(y_pred, y_true):
    """PSNR metric for Keras training, returned as a NumPy value.

    Both ``.numpy()`` calls need eager tensors, so the model has to be compiled
    with ``run_eagerly=True``. The peak value handed to ``tf.image.psnr`` is the
    maximum of the first argument.

    NOTE(review): Keras invokes metric callables positionally as
    ``fn(y_true, y_pred)``, so when this is used via ``metrics=[...]`` the
    argument named ``y_pred`` presumably receives the labels — confirm.
    """
    first_array = y_pred.numpy()
    peak_value = np.max(first_array)
    return image.psnr(first_array, y_true, peak_value).numpy()

def psnr(y_pred, y_true):
    """Return ``tf.image.psnr`` of the two inputs, using ``max(y_pred)`` as the peak value."""
    peak_value = np.max(y_pred)
    return image.psnr(y_pred, y_true, peak_value)

## Model definition

In [3]:
def SRReCNN3D(input_shape, depth, nb_filters, kernel_size, padding, to_json=False):
    """Build a residual 3D convolutional network as a Keras ``Model``.

    The network stacks ``depth`` blocks of (zero-pad, Conv3D, ReLU) followed by
    one last (zero-pad, Conv3D) block with a single filter and no activation,
    i.e. ``depth + 1`` convolutions in total. The output of that stack is added
    to the network input.

    Parameters
    ----------
    input_shape : shape given to ``Input`` (batch axis excluded).
    depth : number of hidden convolution + ReLU blocks.
    nb_filters : number of filters of every hidden convolution.
    kernel_size : kernel size of every convolution; also used as a number in
        the initializer's standard deviation.
    padding : number of zeros added on each side of the three middle axes
        before every "valid" convolution.
    to_json : when true, ``model.to_json()`` is written to the file
        ``model.js`` in the current working directory.

    Returns
    -------
    The (uncompiled) Keras ``Model``.
    """
    network_input = Input(input_shape)

    # Pad only the three middle axes; the first and last axes are left untouched.
    paddings = [[0, 0]] + [[padding, padding] for _ in range(3)] + [[0, 0]]
    # Every convolution kernel is drawn from a zero-mean normal with this spread.
    kernel_stddev = np.sqrt(2.0 / float(nb_filters * kernel_size ** 3))

    features = network_input
    for layer_index in range(depth + 1):
        is_output_layer = layer_index == depth
        padded = pad(features, paddings)
        features = Conv3D(
            filters=1 if is_output_layer else nb_filters,
            kernel_size=kernel_size,
            strides=1,
            padding="valid",
            kernel_initializer=RandomNormal(mean=0, stddev=kernel_stddev),
            bias_initializer=Constant(0)
        )(padded)
        # The last convolution stays linear so it can produce signed residuals.
        if not is_output_layer:
            features = ReLU()(features)

    # Residual connection: the stack's output is summed with the input.
    residual_output = Add()([network_input, features])
    network = Model(network_input, residual_output)

    if to_json:
        with open("model.js", "w") as json_file:
            json_file.write(network.to_json())

    return network

## Load the NIfTI image and convert it to a NumPy array

In [4]:
# Read the reference volume with SimpleITK and convert it to a NumPy array.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
reference_nifti = sitk.ReadImage("/projets/srm4bmri/originals/Marmoset_T1w_mri/1010.nii")
reference_image = sitk.GetArrayFromImage(reference_nifti)

## Preprocessing to have a label and a low resolution image

`reference_image` is a "perfect" image. Before running the model we need to artificially degrade it.

### Definition of preprocessing parameters

In [5]:
# Preprocessing settings consumed by the cells below.
blur_sigma = 1  # sigma given to gaussian_filter before downsampling
downsampling_scale = (2, 2, 2)  # per-axis factor used by modcrop3D, the downsampling and the re-interpolation
shaving_border = (0, 0, 0)  # border argument given to shave3D
interpolation_order = 3  # spline order given to scipy.ndimage.zoom
patch_size = 21  # edge length of the cubic patches
patch_stride = 10  # extraction step given to array_to_patches
max_number_patches = 3200  # only used by the truncation that is commented out in the shuffle cell

### Swap axes

In [6]:
# Exchange the first and last axes and cast the volume to float32.
# NOTE(review): arrays from SimpleITK are presumably indexed (z, y, x), which would make this (x, y, z) — confirm.
reference_image = np.swapaxes(reference_image, 0, 2).astype('float32')

### Normalisation and modcrop

Modcrop crops the volume so that the size of each dimension is an exact multiple of the corresponding scale factor.

If a dimension contains 80 values and the corresponding scale is equal to 3, then the resulting dimension size is 78 : `80 - 80 % 3 = 78`.

This ensures that downsampling by the scale factor yields a whole number of voxels along every axis.

In [7]:
# Intensity adjustment via imadjust3D (presumably rescales to the [0, 1] range — confirm against utils3D).
reference_image = imadjust3D(reference_image, [0, 1])
# Crop so every dimension is compatible with downsampling_scale (see the explanation above).
reference_image = modcrop3D(reference_image, downsampling_scale)

### Blur and downsampling

In [8]:
# Gaussian-blur the reference volume, then shrink it by the per-axis downsampling factors.
blur_reference_image = gaussian_filter(reference_image, sigma=blur_sigma)
low_resolution_image = zoom(
    blur_reference_image,
    # Materialise the factors as a tuple: zoom documents a float or a sequence,
    # not a one-shot generator.
    zoom=tuple(1 / float(scale) for scale in downsampling_scale),
    order=interpolation_order
)

### Interpolation

In [9]:
# Upsample the low-resolution volume by the same per-axis factors with spline interpolation;
# this is the degraded volume the network receives as input.
interpolated_image = zoom(
    low_resolution_image, 
    zoom = downsampling_scale,
    order = interpolation_order
)

### Shaving

The edges of the image sometimes contain only black voxels. We remove them so that the model is not trained on that data.

In [10]:
# Apply the same border shaving to the reference (training target) and to the
# interpolated volume (network input) so both keep identical extents.
label_image = shave3D(reference_image, shaving_border)
data_image = shave3D(interpolated_image, shaving_border)

### Extract 3D patches

In [11]:
# Cut the degraded volume into cubic patches (network inputs).
data_patches = array_to_patches(
    data_image,
    patch_shape = (patch_size, patch_size, patch_size),
    extraction_step = patch_stride,
    normalization = False
)

# Cut the reference volume with the same settings (training targets); the two
# patch sets are presumably aligned index by index — confirm against patches.py.
labels_patches = array_to_patches(
    label_image,
    patch_shape = (patch_size, patch_size, patch_size),
    extraction_step = patch_stride,
    normalization = False
)

270 patches have been extracted
270 patches have been extracted




#### Add channel axis

In [12]:
# Append a trailing singleton axis so each patch carries an explicit channel dimension.
data_patches = data_patches[:, :, :, :, np.newaxis]
labels_patches = labels_patches[:, :, :, :, np.newaxis]

#### Randomly shuffle the patches (keeping only the first `max_number_patches` is currently disabled)

In [13]:
np.random.seed(0)  # makes the random numbers predictable
# One permutation of the patch indices, applied to both arrays below.
random_order = np.random.permutation(data_patches.shape[0])

# Reorder inputs and targets with the same permutation so pairs stay matched.
data_patches = data_patches[random_order, :, :, :, :]
labels_patches = labels_patches[random_order, :, :, :, :]

# Truncation to max_number_patches is disabled: every extracted patch is kept.
# data_patches = data_patches[:max_number_patches, :, :, :, :]
# labels_patches = labels_patches[:max_number_patches, :, :, :, :]

## Launch training

### Set training parameters

In [21]:
# Hyper-parameters forwarded to launch_training in the "Training execution" cell.
network_depth = 10  # passed as `depth`
nb_filters = 64  # filters per hidden convolution
kernel_size = 3  # convolution kernel size
conv_padding = 1  # passed as `padding` (zeros added before each convolution)
epochs = 10  # number of training epochs
batch_size = 64  # samples per gradient step
adam_learning_rate = 0.0001  # passed as `adam_lr`
residual_learning = False # Unused for the moment

### Compile and launch the training

In [18]:
def launch_training(
    data,
    labels,
    depth = 10,
    nb_filters = 64,
    kernel_size = 3,
    padding = 1,
    epochs = 20,
    batch_size = 4,
    adam_lr = 0.0001
):
    """Build an SRReCNN3D network, compile it and fit it on the given patches.

    Parameters
    ----------
    data : array of input patches; the first axis indexes the samples and the
        last axis holds the channels.
    labels : array of target patches, matched with ``data`` sample by sample.
    depth, nb_filters, kernel_size, padding : forwarded to ``SRReCNN3D``.
    epochs, batch_size : forwarded to ``model.fit``.
    adam_lr : learning rate given to the ``AdamLRM`` optimizer.

    Returns
    -------
    (model, history) : the fitted Keras model and the object returned by
        ``model.fit``.
    """
    # Fix only the channel count and leave every other (spatial) dimension open.
    # The network is fully convolutional, and the notebook later calls
    # `predict` on a whole volume whose size differs from the training patches;
    # a model built with the exact patch shape is rejected by Keras versions
    # that validate inputs against the declared shape.
    sample_shape = tuple(data[0].shape)
    input_shape = (None,) * (len(sample_shape) - 1) + (sample_shape[-1],)

    model = SRReCNN3D(input_shape, depth, nb_filters, kernel_size, padding)
    model.compile(
        optimizer=AdamLRM(learning_rate=adam_lr),
        loss="mse",
        metrics=[psnr_model],
        # psnr_model calls .numpy(), which is only available in eager mode.
        run_eagerly=True
    )
    history = model.fit(
        data,
        labels,
        batch_size=batch_size,
        epochs=epochs
    )

    return model, history

### Function for drawing the loss and the PSNR metric

In [19]:
def draw_loss_and_psnr(history):
    """Plot the training loss and the PSNR metric side by side, one point per epoch.

    ``history`` is the object returned by ``model.fit``; its ``epoch`` list is
    used for the x axis and its ``history`` dict supplies the ``'loss'`` and
    ``'psnr_model'`` curves.
    """
    # (key in history.history, panel title), in left-to-right order.
    panels = (
        ('loss', 'loss'),
        ('psnr_model', 'psnr'),
    )

    plt.figure(figsize=(11, 3))
    for position, (metric_key, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(history.epoch, history.history[metric_key])
        plt.title(panel_title)

### Training execution

In [None]:
# Train on the shuffled patches with the hyper-parameters from the
# "Set training parameters" cell; keeps the fitted model and its fit history.
model, history =  launch_training(
    data_patches,
    labels_patches, 
    depth = network_depth, 
    nb_filters = nb_filters,
    kernel_size = kernel_size,
    padding = conv_padding,
    epochs = epochs,
    batch_size = batch_size, 
    adam_lr = adam_learning_rate
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10

Saving weights: 

In [None]:
# Persist the trained weights to an HDF5 file.
model.save_weights("/projets/srm4bmri/weights/mulstiscaleSR_training.h5")

### Prediction

In [20]:
# Run the trained network on the whole interpolated volume, adding a leading
# batch axis and a trailing channel axis. This requires the model to accept
# inputs whose spatial size differs from the training patches.
predicted_image = model.predict(interpolated_image[np.newaxis, :,:,:,np.newaxis])

### Save output

In [22]:
# The volume was transposed with swapaxes(0, 2) right after loading, while
# SimpleITK expects arrays in (z, y, x) order: swap the axes back (and make the
# result contiguous) so the saved image has the same orientation as the source.
predicted_volume = np.ascontiguousarray(np.swapaxes(predicted_image[0, :, :, :, 0], 0, 2))
sitk_image = sitk.GetImageFromArray(predicted_volume)
# Carry over the physical metadata of the reference image, which
# GetImageFromArray does not know about. The setters are used instead of
# CopyInformation because modcrop may have changed the image size.
# NOTE(review): the origin stays valid only if modcrop3D/shave3D remove
# trailing voxels — confirm against utils3D.
sitk_image.SetSpacing(reference_nifti.GetSpacing())
sitk_image.SetOrigin(reference_nifti.GetOrigin())
sitk_image.SetDirection(reference_nifti.GetDirection())
sitk.WriteImage(sitk_image, "/projets/srm4bmri/outputs/output_multiscale_sr.nii" )

In [23]:
# The volume was transposed with swapaxes(0, 2) right after loading, while
# SimpleITK expects arrays in (z, y, x) order: swap the axes back (and make the
# result contiguous) so the saved image has the same orientation as the source.
interpolated_volume = np.ascontiguousarray(np.swapaxes(interpolated_image, 0, 2))
sitk_image = sitk.GetImageFromArray(interpolated_volume)
# Carry over the physical metadata of the reference image, which
# GetImageFromArray does not know about. The setters are used instead of
# CopyInformation because modcrop may have changed the image size.
# NOTE(review): the origin stays valid only if modcrop3D/shave3D remove
# trailing voxels — confirm against utils3D.
sitk_image.SetSpacing(reference_nifti.GetSpacing())
sitk_image.SetOrigin(reference_nifti.GetOrigin())
sitk_image.SetDirection(reference_nifti.GetDirection())
sitk.WriteImage(sitk_image, "/projets/srm4bmri/outputs/input_multiscale_sr.nii" )