<h1>New Experiment 2: Random rejection of motion-free subjects</h1>
In this notebook, random rejection tests are conducted using data from S1 in the still condition, and the generalisation tests are conducted on S2. Each time, we fix the number of retained volumes N, randomly select N volumes from the subject's data, and train our proposed model on this sample. At test time, we apply the same rejection scheme to S2 and test on the resulting subset of S2.<br>
We would like to show that the proposed model is robust, i.e. that its performance is independent of (1) the proportion of discarded volumes and (2) the particular combination of b-values that is retained.<br>
<h2>Packages</h2>

In [None]:
import nibabel as nib
import numpy as np
"""
packages that does conventional model fitting
"""
import amico
"""
packages that generate train/test dataset
"""
from FormatData import generate_data, parser as data_parser
"""
packages that trains network
"""
from Training import train_network
from utils.model import parser as model_parser
"""
packages that test network
"""
from Testing import test_model
"""
packages that produce the rejection shceme
"""
from filter_qa import parser as filter_parser, load_eddy

In [None]:
"""
packages that handle graphs
"""
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm, LinearSegmentedColormap
from skimage.metrics import structural_similarity as compare_ssim
from utils import calc_ssim
%matplotlib inline
def plot_loss(cmd):
    """
    Train a network from a command-line string and plot its loss curves.

    The string is split on whitespace, parsed with the model parser and handed
    to ``train_network``; the 'loss' and 'val_loss' series of the returned
    history are then drawn on a single figure.
    Args:
        cmd: String, the command line as it would be typed in the terminal
    """
    parsed_args = model_parser().parse_args(cmd.split())
    curves = train_network(parsed_args).history
    # training curve first, validation curve second (matches the legend order)
    for series in ('loss', 'val_loss'):
        plt.plot(curves[series])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'validation'], loc='upper left')
    plt.show()

def show_slices(slices, grayscale=True):
    """
    Display 2D slices side by side with one shared colorbar.

    Args:
        slices (list): a list of 2d ndarray that contains the data to be displayed
        grayscale (bool, optional): True to draw with the 'gray' colormap binned
            over [0, 1) in steps of 0.01 (reference / predicted images); False to
            draw with the 'bwr' colormap binned over [-0.5, 0.5) in steps of 0.01
            (difference maps). Defaults to True.
    """
    # squeeze=False always returns a 2D array of axes, so a single slice works
    # as well (plt.subplots(1, 1) would otherwise return a bare Axes object).
    fig, axes = plt.subplots(1, len(slices), figsize=(10, 10), squeeze=False)
    cax = fig.add_axes([0, 0, .3, .3])

    # The colormap and the bin edges do not depend on the slice: build them once.
    if grayscale:
        base_cmap, lower, upper = plt.get_cmap('gray'), 0, 1.0
    else:
        base_cmap, lower, upper = plt.get_cmap('bwr'), -0.5, 0.5
    # extract all colours of the base map and rebuild it as a segmented map
    colours = [base_cmap(k) for k in range(base_cmap.N)]
    cmap = LinearSegmentedColormap.from_list('Custom cmap', colours, base_cmap.N)
    # define the bins and force 0 to be one of the bin edges
    bounds = np.arange(lower, upper, .01)
    bounds = np.insert(bounds, np.searchsorted(bounds, 0), 0)
    norm = BoundaryNorm(bounds, cmap.N)

    for axis, img in zip(axes[0], slices):
        im = axis.imshow(img.T, cmap=cmap, origin="lower", interpolation='none', norm=norm)
    # every image shares the same cmap/norm, so one colorbar describes them all
    fig.colorbar(im, cax=cax, orientation='vertical')

def scale(img):
    """
    Return the image unchanged.

    Hook applied to difference slices before they are displayed. An earlier
    per-pixel rescaling, ((value + 1) / 2) * 255, is disabled, so the very
    same object that was passed in is handed back.
    """
    return img

def compare_simi(pred, ref):
    """Return the result of ``calc_ssim`` for a predicted and a reference image."""
    similarity = calc_ssim(pred, ref)
    return similarity

In [None]:
def visualise(ref_ndi, ref_odi, ref_fwf, retained_vol, subject, model, layer, affine1, affine2, affine3):
    """
    Plot reference, predicted and difference images for NDI, ODI and FWF.

    For every parameter the three fixed slices [26, :, :], [:, 30, :] and
    [:, :, 16] are shown for the reference volume and for the prediction loaded
    from '../Net/nii/'. The SSIM score between prediction and reference is
    printed, and the full difference volume (reference - prediction) is saved
    next to the prediction with the suffix '_difference.nii'.

    Args:
        ref_ndi (ndarray): the reference NDI data
        ref_odi (ndarray): the reference ODI data
        ref_fwf (ndarray): the reference FWF data
        retained_vol (int): the number of volumes retained after rejection
        subject (string): the subject that is examined
        model (string): the model used
        layer (int): the number of layers for the network
        affine1: affine stored with the saved NDI difference image
        affine2: affine stored with the saved ODI difference image
        affine3: affine stored with the saved FWF difference image
    """
    # the file names carry patch_1 for the 'fc1d' model and patch_3 otherwise
    patch = 1 if model == 'fc1d' else 3

    print(patch)
    print(model)

    def centre_slices(volume):
        # the three orthogonal slices displayed for every volume
        return [volume[26, :, :], volume[:, 30, :], volume[:, :, 16]]

    references = {'NDI': ref_ndi, 'ODI': ref_odi, 'FWF': ref_fwf}
    affines = {'NDI': affine1, 'ODI': affine2, 'FWF': affine3}
    # common prefix of every prediction / difference file of this run
    stem = f'../Net/nii/{subject}-{retained_vol}-{model}-patch_{patch}-base_1-layer_{layer}-label_'

    # visualise the ref imgs
    for label, ref in references.items():
        show_slices(centre_slices(ref))
        plt.suptitle(f"Center slices for {label} reference image")

    # load all predictions before anything is drawn for them
    predictions = {label: nib.load(f'{stem}{label}.nii').get_fdata() for label in references}

    # visualise the pred imgs and report their similarity to the reference
    for label, pred in predictions.items():
        show_slices(centre_slices(pred))
        plt.suptitle(f'Center slices for {label} predicted image by {model}, input size={retained_vol}')
        # full=True also returns the SSIM map, which is not used here
        score, _ = compare_ssim(pred, references[label], full=True)
        print(f'{retained_vol}input size the ssim score for {label.lower()} is: {score!s}')

    # plot the difference maps and store the full difference volumes
    for label, ref in references.items():
        pred = predictions[label]
        diff_slices = [scale(r - p) for r, p in zip(centre_slices(ref), centre_slices(pred))]
        show_slices(diff_slices, grayscale=False)
        plt.suptitle(f"Difference map {label}")

        diff_img = nib.Nifti1Image(ref - pred, affines[label])
        nib.save(diff_img, f'{stem}{label}_difference.nii')

<h2>The class for parsing command line arguments</h2>

In [None]:
class Namespace:
    """
    A plain attribute container.

    Every keyword argument given to the constructor becomes an attribute of the
    instance, which lets an object stand in for parsed command-line arguments.
    """
    def __init__(self, **kwargs):
        attributes = self.__dict__
        for key, value in kwargs.items():
            attributes[key] = value

<h3>Data Preprocessing</h3>

In [None]:
# root folder that holds one sub-folder per subject
data_root = '/home/vw/Desktop/IndividualProject/MedICSS2021_/Data-NODDI/'

# folders of the two motion-free subjects
s01_path = data_root + 's01_still/'
s02_path = data_root + 's02_still/'

# motion-free target labels, one file per NODDI parameter and subject
s01_NDI_path = s01_path + 's01_still_NDI.nii'
s02_NDI_path = s02_path + 's02_still_NDI.nii'

s01_ODI_path = s01_path + 's01_still_ODI.nii'
s02_ODI_path = s02_path + 's02_still_ODI.nii'

s01_FWF_path = s01_path + 's01_still_FWF.nii'
s02_FWF_path = s02_path + 's02_still_FWF.nii'

In [None]:
def filter_mask(subpath, fwfpath, threshold=0.99):
    """
    Remove voxels without tissue from a subject's mask and save the result.

    By looking at the imgs generated, we have found out there are some regions that
    should not be included, since they have values higher than 1.0. We have also
    found voxels that have NDI and ODI values while their GROUND TRUTH FWF is 1.0,
    which indicates that such voxels should not be included in the training.
    Therefore each subject's mask ('mask-e.nii') is filtered with the corresponding
    GROUND TRUTH FWF: every voxel whose FWF is >= threshold is set to 0. The result
    is written to 'filtered_mask.nii' in the subject folder, and the number of
    non-zero voxels before and after filtering is printed.

    Args:
        subpath (string): the path of the subject folder (with or without a
                          trailing separator)
        fwfpath (string): the path of the corresponding fwf file
        threshold (float): the threshold used to filter the mask; the lower the
                           value, the more voxels are removed (0.9 is a stringent
                           choice, 1.0 the least stringent one).
                           By default, it is set to 0.99

    Raises:
        ValueError: if the FWF volume does not hold as many voxels as the mask.
    """
    import os  # local import so the function does not rely on another cell

    # fetch the mask data
    img_mask = nib.load(os.path.join(subpath, 'mask-e.nii'))
    original_mask = img_mask.get_fdata()
    origin_nonzeros = np.count_nonzero(original_mask)
    print('original mask has: ' + str(origin_nonzeros) + ' of nonzero voxels')

    # fetch the FWF data and lay it out on the grid of the mask
    fwf = nib.load(fwfpath).get_fdata().reshape(original_mask.shape)

    # a high FWF means there is no tissue, therefore the voxel is excluded;
    # work on a copy so the array held by the loaded image stays untouched
    mask = original_mask.copy()
    mask[fwf >= threshold] = 0.0

    filter_nonzeros = np.count_nonzero(mask)
    print('filtered mask has: ' +str(filter_nonzeros) + ' of nonzero voxels')

    # save the mask with the affine of the original mask
    filter_img = nib.Nifti1Image(mask, img_mask.affine)
    nib.save(filter_img, os.path.join(subpath, 'filtered_mask.nii'))

In [None]:
# Filter the mask of every subject with its ground-truth FWF; each call stores
# the result as filtered_mask.nii inside that subject's folder.
for subject_dir, subject_fwf in ((s01_path, s01_FWF_path), (s02_path, s02_FWF_path)):
    filter_mask(subject_dir, subject_fwf)

In [None]:
# location of the filtered mask inside each subject's folder
s01_mask_path, s02_mask_path = (
    f'/home/vw/Desktop/IndividualProject/MedICSS2021_/Data-NODDI/{subject}/filtered_mask.nii'
    for subject in ('s01_still', 's02_still')
)

In [None]:
"""
Generate the base dataset for s01_still and s02_still for all NODDI parameters.
"""
# command line handed to the data generator: base dataset, all labels, both subjects
cmd = ' '.join(['--base', '--label_type A', '--subjects s01_still s02_still'])
args = data_parser().parse_args(cmd.split())
generate_data(args)

In [None]:
# NODDI parameter(s) to estimate; 'A' stands for ALL
ltype = 'A'
# network type; also passed to the data generator as a '--<model>' flag
model = 'conv3d'
# The dataset for this network has to exist before training, so build the
# command line for the data generator and run it.
cnn3ddata_cmd = f'--subjects s01_still s02_still --label_type {ltype} --{model}'
cnn3ddata_args = data_parser().parse_args(cnn3ddata_cmd.split())
generate_data(cnn3ddata_args)

In [None]:
"""
Using nib to fetch the ground truth img
"""
# ground truth of subject 1: image objects, then their affines, then the voxel data
s01_NDI_img, s01_ODI_img, s01_FWF_img = (
    nib.load(path) for path in (s01_NDI_path, s01_ODI_path, s01_FWF_path)
)
s01_NDI_affine, s01_ODI_affine, s01_FWF_affine = (
    img.affine for img in (s01_NDI_img, s01_ODI_img, s01_FWF_img)
)
s01_NDI_img_data, s01_ODI_img_data, s01_FWF_img_data = (
    img.get_fdata() for img in (s01_NDI_img, s01_ODI_img, s01_FWF_img)
)
# ground truth of subject 2, loaded in the same order
s02_NDI_img, s02_ODI_img, s02_FWF_img = (
    nib.load(path) for path in (s02_NDI_path, s02_ODI_path, s02_FWF_path)
)
s02_NDI_affine, s02_ODI_affine, s02_FWF_affine = (
    img.affine for img in (s02_NDI_img, s02_ODI_img, s02_FWF_img)
)
s02_NDI_img_data, s02_ODI_img_data, s02_FWF_img_data = (
    img.get_fdata() for img in (s02_NDI_img, s02_ODI_img, s02_FWF_img)
)

------------------------
<h2>Random Rejection Scheme Setup</h2>
Let N be the number of retained volumes under test. For each N, 100 subsampled schemes are drawn randomly from the full scheme (with the first b=0 volume and at least two different b-values always included). Each subsampled scheme is used to evaluate both techniques. For both the 3D CNN and AMICO, N is 60, 40 and 30. Further undersampled schemes, with N equal to 20, 16 and 12, are additionally given to the 3D CNN.

In [None]:
# the retained-volume counts under test ...
N = [60, 40, 30, 20, 16, 12]
# ... and, for each of them, its own (initially empty) list of random schemes
random_dict = {retained: [] for retained in N}
random_dict

In [None]:
# Generate the random rejection schemes for each tested number of retained volumes.
# A scheme is a space-separated string of 96 flags, 1.0 = retained, 0.0 = rejected.
total_volumes = 96   # number of volumes in the full scheme
schemes_per_n = 100  # subsamples drawn for each retained-volume count
for n in N:
    # The first volume is always retained, so only the remaining n-1 retained
    # flags are shuffled over the other total_volumes-1 positions.
    # NOTE(review): the notebook text also asks for at least two different
    # b-values in every scheme; the b-values are not available in this cell,
    # so that constraint is NOT enforced here - confirm it is checked elsewhere.
    remaining_flags = np.concatenate([np.ones(n - 1), np.zeros(total_volumes - n)])
    # assign (rather than append) so that re-running the cell keeps exactly
    # schemes_per_n schemes per retained-volume count
    random_dict[n] = [
        ' '.join(map(str, np.concatenate([[1.0], np.random.permutation(remaining_flags)])))
        for _ in range(schemes_per_n)
    ]

In [None]:
# per retained-volume count: the RMSE and the SSIM results collected for the CNN
cnn_rmse_dict = {retained: [] for retained in (60, 40, 30, 20, 16, 12)}
cnn_ssim_dict = {retained: [] for retained in (60, 40, 30, 20, 16, 12)}

----------------------------
<h2>Training</h2>
Hyperparameters

In [None]:
# --- network hyperparameters ---
layers = 4       # number of hidden layers; this is the optimal number of layers
lr = 1e-4        # learning rate
patch_size = 3   # size of the patches for the 2D and 3D CNN
batch = 256      # batch size
epoch = 100      # number of training epochs

# --- subjects ---
# s02 is studied first, with s01 serving as a separate training dataset
sep_train_subject = 's01_still'  # the separate training candidate
study_subject = 's02_still'      # the subject we want to study

<h4>N = 60</h4>

In [None]:
# Train and evaluate the network once for every random scheme that retains 60 volumes
for movefile in random_dict[60]:
    # train on the separate training subject under this rejection scheme;
    # plot_loss runs the training and draws the loss curves
    cnntrain_cmd = (
        f'--train_subjects {sep_train_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --lr {lr} --batch {batch} --patch_size {patch_size}'
        f' --epoch {epoch} --movefile {movefile} --train'
    )
    plot_loss(cnntrain_cmd)
    # evaluate on the study subject under the same rejection scheme
    cnntest_cmd = (
        f'--test_subjects {study_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --movefile {movefile}'
    )
    cnntest_args = model_parser().parse_args(cnntest_cmd.split())
    RMSE, SSIM = test_model(cnntest_args)
    # store both scores under the 60-volume key
    cnn_rmse_dict[60].append(RMSE)
    cnn_ssim_dict[60].append(SSIM)

<h4>N = 40</h4>

In [None]:
# Train and evaluate the network once for every random scheme that retains 40 volumes
for movefile in random_dict[40]:
    # train on the separate training subject under this rejection scheme;
    # plot_loss runs the training and draws the loss curves
    cnntrain_cmd = (
        f'--train_subjects {sep_train_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --lr {lr} --batch {batch} --patch_size {patch_size}'
        f' --epoch {epoch} --movefile {movefile} --train'
    )
    plot_loss(cnntrain_cmd)
    # evaluate on the study subject under the same rejection scheme
    cnntest_cmd = (
        f'--test_subjects {study_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --movefile {movefile}'
    )
    cnntest_args = model_parser().parse_args(cnntest_cmd.split())
    RMSE, SSIM = test_model(cnntest_args)
    # store both scores under the 40-volume key
    cnn_rmse_dict[40].append(RMSE)
    cnn_ssim_dict[40].append(SSIM)

<h4>N = 30</h4>

In [None]:
# Train and evaluate the network once for every random scheme that retains 30 volumes
for movefile in random_dict[30]:
    # train on the separate training subject under this rejection scheme;
    # plot_loss runs the training and draws the loss curves
    cnntrain_cmd = (
        f'--train_subjects {sep_train_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --lr {lr} --batch {batch} --patch_size {patch_size}'
        f' --epoch {epoch} --movefile {movefile} --train'
    )
    plot_loss(cnntrain_cmd)
    # evaluate on the study subject under the same rejection scheme
    cnntest_cmd = (
        f'--test_subjects {study_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --movefile {movefile}'
    )
    cnntest_args = model_parser().parse_args(cnntest_cmd.split())
    RMSE, SSIM = test_model(cnntest_args)
    # store both scores under the 30-volume key
    cnn_rmse_dict[30].append(RMSE)
    cnn_ssim_dict[30].append(SSIM)

<h4>N = 20</h4>

In [None]:
# Train and evaluate the network once for every random scheme that retains 20 volumes
for movefile in random_dict[20]:
    # train on the separate training subject under this rejection scheme;
    # plot_loss runs the training and draws the loss curves
    cnntrain_cmd = (
        f'--train_subjects {sep_train_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --lr {lr} --batch {batch} --patch_size {patch_size}'
        f' --epoch {epoch} --movefile {movefile} --train'
    )
    plot_loss(cnntrain_cmd)
    # evaluate on the study subject under the same rejection scheme
    cnntest_cmd = (
        f'--test_subjects {study_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --movefile {movefile}'
    )
    cnntest_args = model_parser().parse_args(cnntest_cmd.split())
    RMSE, SSIM = test_model(cnntest_args)
    # store both scores under the 20-volume key
    cnn_rmse_dict[20].append(RMSE)
    cnn_ssim_dict[20].append(SSIM)

<h4>N = 16</h4>

In [None]:
# Train and evaluate the network once for every random scheme that retains 16 volumes
for movefile in random_dict[16]:
    # train on the separate training subject under this rejection scheme;
    # plot_loss runs the training and draws the loss curves
    cnntrain_cmd = (
        f'--train_subjects {sep_train_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --lr {lr} --batch {batch} --patch_size {patch_size}'
        f' --epoch {epoch} --movefile {movefile} --train'
    )
    plot_loss(cnntrain_cmd)
    # evaluate on the study subject under the same rejection scheme
    cnntest_cmd = (
        f'--test_subjects {study_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --movefile {movefile}'
    )
    cnntest_args = model_parser().parse_args(cnntest_cmd.split())
    RMSE, SSIM = test_model(cnntest_args)
    # store both scores under the 16-volume key
    cnn_rmse_dict[16].append(RMSE)
    cnn_ssim_dict[16].append(SSIM)

<h4>N = 12</h4>

In [None]:
# Train and evaluate the network once for every random scheme that retains 12 volumes
for movefile in random_dict[12]:
    # train on the separate training subject under this rejection scheme;
    # plot_loss runs the training and draws the loss curves
    cnntrain_cmd = (
        f'--train_subjects {sep_train_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --lr {lr} --batch {batch} --patch_size {patch_size}'
        f' --epoch {epoch} --movefile {movefile} --train'
    )
    plot_loss(cnntrain_cmd)
    # evaluate on the study subject under the same rejection scheme
    cnntest_cmd = (
        f'--test_subjects {study_subject} --model {model} --layer {layers}'
        f' --label_type {ltype} --movefile {movefile}'
    )
    cnntest_args = model_parser().parse_args(cnntest_cmd.split())
    RMSE, SSIM = test_model(cnntest_args)
    # store both scores under the 12-volume key
    cnn_rmse_dict[12].append(RMSE)
    cnn_ssim_dict[12].append(SSIM)