In [None]:
# Make the LoRE-SD sources importable from this notebook.
import sys

# `src` is registered before the repository root, as in the original cell.
for _lore_path in (
    '/home/sleyse4/repos/LoRE_SD/LoRE-SD/src',
    '/home/sleyse4/repos/LoRE_SD/LoRE-SD',
):
    # Re-running the cell must not pile up duplicate sys.path entries.
    if _lore_path not in sys.path:
        sys.path.append(_lore_path)

In [None]:
import os
import subprocess
import shutil

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from utils import gradient_utils, plot_utils, io_utils, math_utils
from optimisation import optimise

import contrasts

from mrtrix_io.io import load_mrtrix, save_mrtrix
from mrtrix_io.io.image import Image

import utils.SphericalHarmonics as sh

from matplotlib.colors import LinearSegmentedColormap
from matplotlib.transforms import Bbox

# Constants

In [None]:
# Scalar settings first, disk access last.
# NOTE(review): lmax is presumably the maximum spherical-harmonic order — confirm.
lmax = 8

# Number of worker processes; multiprocessing shortens the simulation runs.
cores = 150

# Matrix stored in Q_odf.npy (the same file is passed to LoRE-SD as --eval_matrix).
Q = np.load('/LOCALDATA/sleyse4/Q_odf.npy')

In [None]:
# LoRE-SD parameters
# Default regularisation weight.
reg = 1e-3

# AD and RD are sampled on the same 10-point grid over [0, 4e-3];
# two independent arrays are created so neither aliases the other.
# NOTE(review): presumably axial / radial diffusivity — confirm units with the caller.
AD, RD = (np.linspace(0, 4e-3, 10) for _ in range(2))

In [None]:
def add_noise(noise_free, grad, snr, mask, rng=None):
    """
    Corrupt a noise-free signal with Rician noise.

    Two independent zero-mean Gaussian noise fields with standard deviation
    ``mean_b0 / snr`` are drawn; the first is added to the signal and the
    magnitude ``sqrt((signal + n1)**2 + n2**2)`` is returned. ``mean_b0`` is
    the mean of the b=0 volumes inside ``mask``.

    Args:
        noise_free (numpy.ndarray): Noise-free signal; the last axis indexes
            the volumes described by ``grad``.
        grad (numpy.ndarray): Gradient table. Rows whose last column equals 0
            are treated as b=0 volumes.
        snr (float): Signal-to-noise ratio relative to the mean b=0 signal.
            ``np.inf`` gives a noise standard deviation of 0.
        mask (numpy.ndarray): Boolean array indexing the leading (spatial)
            axes of ``noise_free``; selects the voxels used for ``mean_b0``.
        rng (numpy.random.Generator, optional): Random generator used to draw
            the noise. When omitted, the global ``np.random`` state is used,
            exactly as before this parameter existed.

    Returns:
        numpy.ndarray: Array with the shape of ``noise_free`` holding the
        Rician-corrupted signal.

    Raises:
        ValueError: If ``grad`` contains no b=0 volume, because the noise
            level would be undefined (NaN).
    """
    b0_volumes = noise_free[..., grad[..., -1] == 0]
    if b0_volumes.shape[-1] == 0:
        raise ValueError('Cannot derive the noise level: grad contains no b=0 volume.')

    mean_b0 = np.mean(b0_volumes[mask])
    std_noise = mean_b0 / snr

    # Fall back to the legacy global state so existing callers are unaffected.
    draw = np.random.normal if rng is None else rng.normal
    noise1 = draw(0, std_noise, size=noise_free.shape)
    noise2 = draw(0, std_noise, size=noise_free.shape)

    # Magnitude of a complex signal with noise on both channels (Rician).
    return np.sqrt((noise_free + noise1)**2 + noise2**2)

In [None]:
def run_lore_sd(dwi, mask, out_dir, gt_odf, reg=1e-3, grid_size=10, slice=None, cores=100):
    """
    Run the LoRE-SD decomposition script and store RMSE / ACC maps next to its output.

    The external ``dwi2decomposition.py`` script is executed on ``dwi``; its
    ``reconstructed.mif`` and ``odf.mif`` outputs are then compared against the
    input data (via ``math_utils.rmse``) and against ``gt_odf`` (via
    ``sh.angularCorrelation``). The results are written to
    ``{out_dir}/rmse.mif`` and ``{out_dir}/acc.mif``.

    Args:
        dwi: Path to the DWI image passed to the script.
        mask: Path to the mask image passed to the script (thresholded at 0.5
            for the RMSE computation).
        out_dir: Output directory of the script; also receives the metric maps.
        gt_odf (numpy.ndarray): Reference ODF used for the angular correlation.
        reg: Value forwarded as ``--reg``.
        grid_size: Value forwarded as ``--grid_size``.
        slice: Optional index along the third axis. When given it is forwarded
            as ``--slice`` and the data, mask and ``gt_odf`` are cropped to
            that single slice before computing the metrics.
        cores: Value forwarded as ``--cores`` (default 100, the value that was
            previously hard-coded).

    Raises:
        subprocess.CalledProcessError: If the decomposition script exits with a
            non-zero status, so stale outputs are never evaluated silently.
    """
    # Argument list instead of a shell string: no quoting issues with paths.
    lore_cmd = [
        'python', '/home/sleyse4/repos/LoRE_SD/LoRE-SD/dwi2decomposition.py',
        str(dwi), str(out_dir),
        '--cores', str(cores),
        '--mask', str(mask),
        '--reg', str(reg),
        '--grid_size', str(grid_size),
        '--eval_matrix', '/LOCALDATA/sleyse4/Q_odf.npy',
    ]
    if slice is not None:
        lore_cmd += ['--slice', str(slice)]
    subprocess.run(lore_cmd, check=True)

    # Load the input image once; both its data and voxel size are needed below.
    dwi_image = load_mrtrix(dwi)
    data = dwi_image.data
    recon = load_mrtrix(f'{out_dir}/reconstructed.mif').data
    odf = load_mrtrix(f'{out_dir}/odf.mif').data
    mask = load_mrtrix(mask).data > .5

    if slice is not None:
        # Keep a singleton third axis so the arrays line up with the script output.
        mask = mask[:,:, slice:slice+1]
        data = data[:,:, slice:slice+1]
        gt_odf = gt_odf[:,:, slice:slice+1]

    rmse = math_utils.rmse(data, recon, mask)
    acc = sh.angularCorrelation(gt_odf, odf)
    save_mrtrix(f'{out_dir}/rmse.mif', Image(rmse, vox=dwi_image.vox, comments=''))
    save_mrtrix(f'{out_dir}/acc.mif', Image(acc, vox=dwi_image.vox, comments=''))

def run_msmt_csd(dwi, mask, out_dir, gt_odf):
    """
    Run MSMT-CSD through the MRtrix command-line tools and store RMSE / ACC maps.

    ``dwi2response dhollander`` writes the response functions to ``out_dir``
    (``wm.txt``, ``gm.txt``, ``csf.txt``); ``dwi2fod msmt_csd`` then writes
    ``wm.mif``, ``gm.mif``, ``csf.mif`` and the predicted signal ``pred.mif``.
    The predicted signal is compared to the input data (``math_utils.rmse``)
    and ``wm.mif`` to ``gt_odf`` (``sh.angularCorrelation``); the results are
    saved as ``{out_dir}/rmse.mif`` and ``{out_dir}/acc.mif``.

    Args:
        dwi: Path to the DWI image.
        mask: Path to the mask image (thresholded at 0.5 for the RMSE).
        out_dir: Directory receiving all intermediate and metric files.
        gt_odf (numpy.ndarray): Reference ODF used for the angular correlation.

    Raises:
        subprocess.CalledProcessError: If either MRtrix command fails, instead
            of continuing with missing or stale output files.
    """
    wm_txt, gm_txt, csf_txt = (f'{out_dir}/{tissue}.txt' for tissue in ('wm', 'gm', 'csf'))
    wm_mif, gm_mif, csf_mif = (f'{out_dir}/{tissue}.mif' for tissue in ('wm', 'gm', 'csf'))
    pred_mif = f'{out_dir}/pred.mif'

    # Step 1: estimate the tissue response functions.
    response_cmd = [
        'dwi2response', 'dhollander', str(dwi), wm_txt, gm_txt, csf_txt,
        '--mask', str(mask), '--force',
    ]
    subprocess.run(response_cmd, check=True)

    # Step 2: multi-tissue deconvolution, also exporting the predicted signal.
    fod_cmd = [
        'dwi2fod', 'msmt_csd', str(dwi),
        wm_txt, wm_mif, gm_txt, gm_mif, csf_txt, csf_mif,
        '--mask', str(mask), '--predicted_signal', pred_mif, '--force',
    ]
    subprocess.run(fod_cmd, check=True)

    # Load the input image once; both its data and voxel size are needed below.
    dwi_image = load_mrtrix(dwi)
    recon = load_mrtrix(pred_mif).data
    rmse = math_utils.rmse(dwi_image.data, recon, load_mrtrix(mask).data > .5)
    acc = sh.angularCorrelation(gt_odf, load_mrtrix(wm_mif).data)
    save_mrtrix(f'{out_dir}/rmse.mif', Image(rmse, vox=dwi_image.vox, comments=''))
    save_mrtrix(f'{out_dir}/acc.mif', Image(acc, vox=dwi_image.vox, comments=''))

In [None]:
# Location of the simulated data set.
READ_DIR = '/DATASERVER/MIC/GENERAL/STAFF/sleyse4/u0152170/Blind_Deconvolution/Preprocessing/simulations'

# Index along the third axis that is analysed; set to None to keep the full volume.
SLICE = 40

# Sub-directory name -> SNR level (np.inf marks the noise-free data).
snr_dict = {
    'snr10': 10,
    'snr20': 20,
    'snr50': 50,
    'noise_free': np.inf
}

# Hyperparameter values swept in the experiments below.
reg_vals = [0, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 10]
grid_sizes = [3,5,7,10,15,20]

# Read the noise-free image once and take data, gradient table and voxel size from it.
noise_free_image = load_mrtrix(os.path.join(READ_DIR, 'noise_free.mif'))
noise_free = noise_free_image.data
grad = noise_free_image.grad
vox = noise_free_image.vox
wm_mask = load_mrtrix(os.path.join(READ_DIR, 'wm_mask.mif')).data > .5

if SLICE is not None:
    # Integer indexing drops the third axis for both arrays.
    noise_free = noise_free[:,:, SLICE]
    wm_mask = wm_mask[:,:, SLICE]
    

# Hyperparameter Tuning: Regularisation parameter $\lambda$

In [None]:
# for snr_name, snr in snr_dict.items():
#     snr_dir = os.path.join(READ_DIR, snr_name)
#     if os.path.exists(os.path.join(snr_dir, f'{snr_name}.mif')):
#         print(f'Loading {snr_name}')
#         noisy_dwi = load_mrtrix(os.path.join(snr_dir, f'{snr_name}.mif')).data
#         gt_odf = load_mrtrix(os.path.join(READ_DIR, 'gt_odf.mif')).data
#     else:
#         print(f'Adding noise to {snr_name}')
#         noisy_dwi = add_noise(noise_free, grad, snr, wm_mask)
#         gt_odf = load_mrtrix(os.path.join(READ_DIR, 'gt_odf.mif')).data
#         print(f'Saving {snr_name}')
#         save_mrtrix(os.path.join(snr_dir, f'{snr_name}.mif'), Image(noisy_dwi, vox=vox, grad=grad))

#     if SLICE is not None:
#         noisy_dwi = noisy_dwi[:,:, SLICE:SLICE+1]
#         gt_odf = gt_odf[:,:, SLICE:SLICE+1]

#     for reg in reg_vals:
#         lore_dir = os.path.join(snr_dir, 'LoRE', f'reg_{reg}')
#         if not os.path.exists(lore_dir):
#             os.makedirs(lore_dir)

#         print(f'Running LoRE-SD on {snr_name} with reg {reg}')
#         run_lore_sd(os.path.join(snr_dir, f'{snr_name}.mif'), os.path.join(READ_DIR, 'mask.mif'), lore_dir, gt_odf, reg=reg, slice=SLICE)
#     for grid_size in grid_sizes:
#         lore_dir = os.path.join(snr_dir, 'LoRE', f'grid_{grid_size}')
#         if not os.path.exists(lore_dir):
#             os.makedirs(lore_dir)

#         print(f'Running LoRE-SD on {snr_name} with grid size {grid_size}')
#         run_lore_sd(os.path.join(snr_dir, f'{snr_name}.mif'), os.path.join(READ_DIR, 'mask.mif'), lore_dir, gt_odf, reg=1e-3, grid_size=grid_size, slice=SLICE)

In [None]:
# Regularisation sweep: box plots of the angular correlation (left column) and
# RMSE (right column) for every regularisation value, one row per SNR level.
# MSMT-CSD is shown as reference in both columns; the RMSE column additionally
# shows the error between the noisy and the noise-free signal.
fig = plt.figure(figsize=(14,14))

gs = fig.add_gridspec(len(snr_dict), 2, wspace=.2, hspace=1)

wm_mask = load_mrtrix(os.path.join(READ_DIR, 'wm_mask.mif')).data > .5
if SLICE is not None:
    wm_mask = wm_mask[:,:,SLICE]

mask = load_mrtrix(os.path.join(READ_DIR, 'mask.mif')).data > .5
if SLICE is not None:
    mask = mask[:,:,SLICE]

n_reg = len(reg_vals)
reg_labels = [0] + [f'$10^{{{np.log10(r):.0f}}}$' for r in reg_vals[1:]]
box_kws = dict(width=.5, saturation=1, showfliers=False)

for i, (snr_name, snr) in enumerate(snr_dict.items()):
    snr_dir = os.path.join(READ_DIR, snr_name)

    # ---- Angular correlation, evaluated inside the white-matter mask ----
    acc_ax = fig.add_subplot(gs[i, 0])

    # Green band over x in [2.5, 3.5], i.e. behind the box at position 3.
    acc_rect = plt.Rectangle((2.5, 0), 1, 2, facecolor='green', alpha=.6)
    acc_ax.add_patch(acc_rect)
    acc_ax.set_ylim(.7, 1)

    # `reg_val` (not `reg`) so the notebook-level default `reg` is not overwritten.
    acc_lore = np.array([
        load_mrtrix(os.path.join(snr_dir, 'LoRE', f'reg_{reg_val}', 'acc.mif')).data[wm_mask].flatten()
        for reg_val in reg_vals
    ]).T
    sns.boxplot(data=acc_lore, ax=acc_ax, color='blue', **box_kws)

    acc_mtcsd = load_mrtrix(os.path.join(snr_dir, 'MTCSD', 'acc.mif')).data
    if SLICE is not None:
        acc_mtcsd = acc_mtcsd[:,:,SLICE]
    acc_mtcsd = acc_mtcsd[wm_mask].flatten()
    # Explicit position: without it the box is drawn at x=0 on top of the first
    # LoRE-SD box instead of above the 'MSMT-CSD' tick (as done for the RMSE panel).
    sns.boxplot(data=acc_mtcsd, ax=acc_ax, color='orange', positions=[n_reg], **box_kws)

    acc_ax.set_xlim(-0.5, n_reg + .5)
    acc_ax.set_xticks(range(n_reg + 1))
    acc_ax.set_xticklabels(reg_labels + ['MSMT-CSD'])
    acc_ax.get_xticklabels()[-1].set_rotation(90)

    # ---- RMSE, evaluated inside the full mask ----
    rmse_ax = fig.add_subplot(gs[i, 1])
    rmse_rect = plt.Rectangle((2.5, 0), 1, 500, facecolor='green', alpha=.6)
    rmse_ax.add_patch(rmse_rect)
    rmse_ax.set_ylim(0,500)

    rmse_lore = np.array([
        load_mrtrix(os.path.join(snr_dir, 'LoRE', f'reg_{reg_val}', 'rmse.mif')).data[mask].flatten()
        for reg_val in reg_vals
    ]).T
    sns.boxplot(data=rmse_lore, ax=rmse_ax, color='blue', positions=np.arange(n_reg), **box_kws)

    rmse_mtcsd = load_mrtrix(os.path.join(snr_dir, 'MTCSD', 'rmse.mif')).data
    if SLICE is not None:
        rmse_mtcsd = rmse_mtcsd[:,:,SLICE]
    rmse_mtcsd = rmse_mtcsd[mask].flatten()
    sns.boxplot(data=rmse_mtcsd, ax=rmse_ax, color='orange', positions=[n_reg], **box_kws)

    # Reference: error of the noisy input itself with respect to the noise-free signal.
    dwi_gt = load_mrtrix(os.path.join(snr_dir, f'{snr_name}.mif')).data
    if SLICE is not None:
        dwi_gt = dwi_gt[:,:,SLICE]
    rmse_gt = math_utils.rmse(dwi_gt, noise_free, mask)
    rmse_gt = rmse_gt[mask].flatten()
    sns.boxplot(data=rmse_gt, ax=rmse_ax, color='green', positions=[n_reg + 1], **box_kws)

    rmse_ax.set_xlim(-0.5, n_reg + 1.5)
    rmse_ax.set_xticks(range(n_reg + 2))
    rmse_ax.set_xticklabels(reg_labels + ['MSMT-CSD', 'Ground Truth'])
    rmse_ax.get_xticklabels()[-1].set_rotation(90)
    rmse_ax.get_xticklabels()[-2].set_rotation(90)

    # Add a vertical line to separate the LoRE-SD and MSMT-CSD results
    acc_ax.axvline(n_reg - .5, color='black', linestyle='--')
    rmse_ax.axvline(n_reg - .5, color='black', linestyle='--')
    acc_ax.set_xlabel('Regularisation', labelpad=-40)

    if i == 0:
        acc_ax.set_title('Angular Correlation')
        rmse_ax.set_title('RMSE')

    # Label by value rather than by position in the dict.
    if np.isinf(snr):
        acc_ax.set_ylabel('Noise Free')
    else:
        acc_ax.set_ylabel(f'SNR {snr}')

# Write the file once, after all rows are drawn (it used to be rewritten every iteration).
fig.savefig('/home/sleyse4/reg.png', dpi=300)
    

In [None]:
# HEALTHY_DIR = '/DATASERVER/MIC/GENERAL/STAFF/sleyse4/u0152170/DATA/philips/PREPROC'
# PREPROC_DIR = '/DATASERVER/MIC/GENERAL/STAFF/sleyse4/u0152170/Blind_Deconvolution/Preprocessing/philips'

# gt_odf = load_mrtrix(os.path.join(PREPROC_DIR, 'MTCSD/wm.mif')).data
# for reg in reg_vals:
#     lore_dir = os.path.join(PREPROC_DIR, 'LoRE', f'reg_{reg}')
#     if not os.path.exists(lore_dir):
#         os.makedirs(lore_dir)

#     print(f'Running LoRE-SD on healthy with reg {reg}')
#     run_lore_sd(os.path.join(HEALTHY_DIR, 'dwi.mif'), 
#                 os.path.join(HEALTHY_DIR, 'mask.mif'), lore_dir, 
#                 gt_odf, reg=reg, slice=33)

# for g in grid_sizes:
#     lore_dir = os.path.join(PREPROC_DIR, 'LoRE', f'grid_{g}')
#     if not os.path.exists(lore_dir):
#         os.makedirs(lore_dir)

#     print(f'Running LoRE-SD on healthy with grid size {g}')
#     run_lore_sd(os.path.join(HEALTHY_DIR, 'dwi.mif'), 
#                 os.path.join(HEALTHY_DIR, 'mask.mif'), lore_dir, 
#                 gt_odf, reg=1e-3, grid_size=g, slice=33)

In [None]:
from matplotlib.colors import PowerNorm
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend
# NOTE(review): with a non-interactive backend the plt.show() at the end of this
# cell may not display anything, and the backend switch also applies to later
# cells — confirm this is intended.

# ACC and RMSE maps on the in-vivo data for a subset of regularisation values.

# These directories used to be defined only in a commented-out cell, so this
# cell raised NameError on a fresh kernel.
HEALTHY_DIR = '/DATASERVER/MIC/GENERAL/STAFF/sleyse4/u0152170/DATA/philips/PREPROC'
PREPROC_DIR = '/DATASERVER/MIC/GENERAL/STAFF/sleyse4/u0152170/Blind_Deconvolution/Preprocessing/philips'

HEALTHY_SLICE = 33                  # index along the third axis that is displayed
CROP = np.s_[20:-20, 20:-15]        # in-plane crop applied to every displayed map

reg_int = [0, 1e-5, 1e-3, 1e-1, 10]

rows = 2
cols = len(reg_int)+1               # one extra, narrow column for the colorbars
fig = plt.figure(figsize=(cols*4, rows*4))
gs = fig.add_gridspec(rows, cols, wspace=0.2, hspace=0.2, width_ratios=[1]*(cols-1)+[.1])
mask = load_mrtrix(os.path.join(HEALTHY_DIR, 'mask.mif')).data[:,:,HEALTHY_SLICE] > .5

# The mask doubles as alpha channel so voxels outside it are transparent.
mask_alpha = np.rot90(mask[CROP]).astype(float)

# Green rectangle around the third column
# (no transform is set on the patch, so these are raw figure coordinates
# tuned by hand for this figure size).
rect = plt.Rectangle((782, 43), 250, 630, facecolor='green', alpha=.6, zorder=-100)
fig.patches.append(rect)

# `reg_val` (not `reg`) so the notebook-level default `reg` is not overwritten.
for i, reg_val in enumerate(reg_int):
    lore_dir = os.path.join(PREPROC_DIR, 'LoRE', f'reg_{reg_val}')
    print(f'Loading {lore_dir}')
    rmse = np.squeeze(load_mrtrix(os.path.join(lore_dir, 'rmse.mif')).data)[CROP]
    acc = np.squeeze(load_mrtrix(os.path.join(lore_dir, 'acc.mif')).data)[CROP]

    ax_acc = fig.add_subplot(gs[0, i])
    # PowerNorm with gamma=6 stretches the upper end of the [0, 1] range.
    im_acc = ax_acc.imshow(np.rot90(acc), cmap='copper', norm=PowerNorm(vmin=0, vmax=1, gamma=6), alpha=mask_alpha, zorder=100)
    if i == 0:
        ax_acc.set_title('No regularisation', fontsize=20)
    else:
        ax_acc.set_title(fr'$\lambda = 10^{{{np.log10(reg_val):.0f}}}$', fontsize=20)

    ax_rmse = fig.add_subplot(gs[1, i])
    im_rmse = ax_rmse.imshow(np.rot90(rmse), cmap='inferno', alpha=mask_alpha, vmin=0, vmax=400, zorder=100)

    if i == 0:
        # First column keeps its axes (without ticks or spines) to carry the row labels.
        for row_ax, row_label in ((ax_acc, 'ACC'), (ax_rmse, 'RMSE')):
            row_ax.set_ylabel(row_label, fontsize=20)
            row_ax.set_xticks([])
            row_ax.set_yticks([])
            for spine in row_ax.spines.values():
                spine.set_visible(False)
    else:
        ax_acc.set_axis_off()
        ax_rmse.set_axis_off()


# Colorbars use the images of the last column; all columns share the same scaling.
ax_cbar_acc = fig.add_subplot(gs[0, -1])
cbar_acc = plt.colorbar(im_acc, cax=ax_cbar_acc)
cbar_acc.ax.set_yticks([0, 0.8, 0.9, 0.95, 1])

ax_cbar_res = fig.add_subplot(gs[1, -1])
cbar_res = plt.colorbar(im_rmse, cax=ax_cbar_res)
cbar_res.ax.set_yticks([0, 100, 200, 300, 400])
plt.show()



In [None]:
# for i, (snr_name, snr) in enumerate(snr_dict.items()):
#     snr_dir = os.path.join(READ_DIR, snr_name)
#     lore_dir = os.path.join(snr_dir, 'LoRE')
#     to_contrast_cmd = f'python /home/sleyse4/repos/LoRE_SD/LoRE-SD/scripts/decomposition2contrast.py ' \
#     f'{os.path.join(lore_dir, "gaussian_fractions.mif")} {os.path.join(lore_dir, "response.mif")} {os.path.join(lore_dir, "contrasts")}'

#     subprocess.run(to_contrast_cmd, shell=True)

# for reg in reg_vals:
#     for snr_name in snr_dict.keys():
#         lore_dir = os.path.join(READ_DIR, snr_name, 'LoRE', f'reg_{reg}')
#         to_contrast_cmd = f'python /home/sleyse4/repos/LoRE_SD/LoRE-SD/scripts/decomposition2contrast.py ' \
#         f'{os.path.join(lore_dir, "gaussian_fractions.mif")} {os.path.join(lore_dir, "response.mif")} {os.path.join(lore_dir, "contrasts")}'

#         subprocess.run(to_contrast_cmd, shell=True)

# for g in grid_sizes:
#     for snr_name in snr_dict.keys():
#         lore_dir = os.path.join(READ_DIR, snr_name, 'LoRE', f'grid_{g}')
#         to_contrast_cmd = f'python /home/sleyse4/repos/LoRE_SD/LoRE-SD/scripts/decomposition2contrast.py ' \
#         f'{os.path.join(lore_dir, "gaussian_fractions.mif")} {os.path.join(lore_dir, "response.mif")} {os.path.join(lore_dir, "contrasts")}'

#         subprocess.run(to_contrast_cmd, shell=True)

In [None]:
# Grid-size sweep: box plots of the angular correlation (left column) and RMSE
# (right column) for every grid size, one row per SNR level. MSMT-CSD is shown
# as reference in both columns; the RMSE column additionally shows the error
# between the noisy and the noise-free signal.
fig = plt.figure(figsize=(14,14))

gs = fig.add_gridspec(len(snr_dict), 2, wspace=.2, hspace=1)

wm_mask = load_mrtrix(os.path.join(READ_DIR, 'wm_mask.mif')).data > .5
if SLICE is not None:
    wm_mask = wm_mask[:,:,SLICE]

mask = load_mrtrix(os.path.join(READ_DIR, 'mask.mif')).data > .5
if SLICE is not None:
    mask = mask[:,:,SLICE]

n_grid = len(grid_sizes)
grid_labels = [f'${r}$' for r in grid_sizes]
box_kws = dict(width=.5, saturation=1, showfliers=False)

for i, (snr_name, snr) in enumerate(snr_dict.items()):
    snr_dir = os.path.join(READ_DIR, snr_name)

    # ---- Angular correlation, evaluated inside the white-matter mask ----
    acc_ax = fig.add_subplot(gs[i, 0])

    # Green band over x in [2.5, 3.5], i.e. behind the box at position 3.
    acc_rect = plt.Rectangle((2.5, 0), 1, 2, facecolor='green', alpha=.6)
    acc_ax.add_patch(acc_rect)
    acc_ax.set_ylim(.7, 1)

    acc_lore = np.array([
        load_mrtrix(os.path.join(snr_dir, 'LoRE', f'grid_{g}', 'acc.mif')).data[wm_mask].flatten()
        for g in grid_sizes
    ]).T
    sns.boxplot(data=acc_lore, ax=acc_ax, color='blue', **box_kws)

    acc_mtcsd = load_mrtrix(os.path.join(snr_dir, 'MTCSD', 'acc.mif')).data
    if SLICE is not None:
        acc_mtcsd = acc_mtcsd[:,:,SLICE]
    acc_mtcsd = acc_mtcsd[wm_mask].flatten()
    # Explicit position: without it the box is drawn at x=0 on top of the first
    # LoRE-SD box instead of above the 'MSMT-CSD' tick (as done for the RMSE panel).
    sns.boxplot(data=acc_mtcsd, ax=acc_ax, color='orange', positions=[n_grid], **box_kws)

    acc_ax.set_xlim(-0.5, n_grid + .5)
    acc_ax.set_xticks(range(n_grid + 1))
    acc_ax.set_xticklabels(grid_labels + ['MSMT-CSD'])
    acc_ax.get_xticklabels()[-1].set_rotation(90)

    # ---- RMSE, evaluated inside the full mask ----
    rmse_ax = fig.add_subplot(gs[i, 1])
    rmse_rect = plt.Rectangle((2.5, 0), 1, 500, facecolor='green', alpha=.6)
    rmse_ax.add_patch(rmse_rect)
    rmse_ax.set_ylim(0,400)

    rmse_lore = np.array([
        load_mrtrix(os.path.join(snr_dir, 'LoRE', f'grid_{g}', 'rmse.mif')).data[mask].flatten()
        for g in grid_sizes
    ]).T
    sns.boxplot(data=rmse_lore, ax=rmse_ax, color='blue', positions=np.arange(n_grid), **box_kws)

    rmse_mtcsd = load_mrtrix(os.path.join(snr_dir, 'MTCSD', 'rmse.mif')).data
    if SLICE is not None:
        rmse_mtcsd = rmse_mtcsd[:,:,SLICE]
    rmse_mtcsd = rmse_mtcsd[mask].flatten()
    sns.boxplot(data=rmse_mtcsd, ax=rmse_ax, color='orange', positions=[n_grid], **box_kws)

    # Reference: error of the noisy input itself with respect to the noise-free signal.
    dwi_gt = load_mrtrix(os.path.join(snr_dir, f'{snr_name}.mif')).data
    if SLICE is not None:
        dwi_gt = dwi_gt[:,:,SLICE]
    rmse_gt = math_utils.rmse(dwi_gt, noise_free, mask)
    rmse_gt = rmse_gt[mask].flatten()
    sns.boxplot(data=rmse_gt, ax=rmse_ax, color='green', positions=[n_grid + 1], **box_kws)

    rmse_ax.set_xlim(-0.5, n_grid + 1.5)
    rmse_ax.set_xticks(range(n_grid + 2))
    rmse_ax.set_xticklabels(grid_labels + ['MSMT-CSD', 'Ground Truth'])
    rmse_ax.get_xticklabels()[-1].set_rotation(90)
    rmse_ax.get_xticklabels()[-2].set_rotation(90)

    # Add a vertical line to separate the LoRE-SD and MSMT-CSD results
    acc_ax.axvline(n_grid - .5, color='black', linestyle='--')
    rmse_ax.axvline(n_grid - .5, color='black', linestyle='--')
    acc_ax.set_xlabel(r'Grid size $N \times N$', labelpad=-40)

    if i == 0:
        acc_ax.set_title('ACC')
        rmse_ax.set_title('RMSE')

    # Label by value rather than by position in the dict.
    if np.isinf(snr):
        acc_ax.set_ylabel('Noise Free')
    else:
        acc_ax.set_ylabel(f'SNR {snr}')

# Write the file once, after all rows are drawn (it used to be rewritten every iteration).
fig.savefig('/home/sleyse4/grid.png', dpi=300)

In [None]:
from matplotlib.gridspec import GridSpec

# Plot the FA contrast for a subset of regularisation values (columns) and
# SNR levels (rows).
reg_int = [0, 1e-5, 1e-3, 1e-1, 10]
sub_dict = {
    'snr50': 50,
}

# Size the layout from what is actually drawn: one row per entry of sub_dict,
# one column per regularisation value plus a narrow colorbar column.
n_rows, n_cols = len(sub_dict), len(reg_int)
fig = plt.figure(figsize=(5*n_cols, 5*n_rows))
gs = fig.add_gridspec(n_rows, n_cols+1, wspace=0, hspace=0, width_ratios=[1]*n_cols+[.1])

mask = load_mrtrix(os.path.join(READ_DIR, 'mask.mif')).data > .5
if SLICE is not None:
    mask = mask[:,:,SLICE]

# Reference FA map; loaded here but not displayed in this figure.
gt_fa = load_mrtrix(os.path.join(READ_DIR, 'fa.mif')).data
if SLICE is not None:
    gt_fa = gt_fa[:,:,SLICE]

# The cropped mask doubles as alpha channel so voxels outside it are transparent.
mask_alpha = np.rot90(mask[20:-20,20:-15]).astype(float)

# `reg_val` (not `reg`) so the notebook-level default `reg` is not overwritten.
for i, (snr_name, snr) in enumerate(sub_dict.items()):
    for j, reg_val in enumerate(reg_int):
        fa = np.squeeze(load_mrtrix(os.path.join(READ_DIR, snr_name, 'LoRE', f'reg_{reg_val}', 'contrasts', 'fa.mif')).data)
        ax = fig.add_subplot(gs[i, j])
        ax.imshow(np.rot90(fa[20:-20,20:-15]), cmap='gray', vmin=0, vmax=1, alpha=mask_alpha)

        ax.set_xticks([])
        ax.set_yticks([])

        if i == 0:
            if j == 0:
                ax.set_title('No regularisation')
            else:
                ax.set_title(f'$10^{{{np.log10(reg_val):.0f}}}$')
        if j == 0:
            # Value test instead of an identity test against np.inf.
            if not np.isinf(snr):
                ax.set_ylabel(f'SNR {snr}')
            else:
                ax.set_ylabel('Noise Free')
# Add a colorbar
cax = fig.add_subplot(gs[0, -1])
cmap = LinearSegmentedColormap.from_list('mycmap', [(0, 'black'), (1, 'white')])
norm = plt.Normalize(0, 1)
cb1 = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax)
cb1.set_label('FA')



In [None]:
from matplotlib.gridspec import GridSpec

# Plot the FA contrast for every grid size (columns) and a subset of SNR
# levels (rows).
grid_int = [3, 5, 7, 10, 15, 20]
sub_dict = {
    'snr20': 20,
}

# The layout must follow grid_int: sizing it from reg_int (one entry fewer)
# pushed the last image into the narrow colorbar column, where the colorbar
# was then drawn on top of it.
n_rows, n_cols = len(sub_dict), len(grid_int)
fig = plt.figure(figsize=(5*n_cols, 5*n_rows))
gs = fig.add_gridspec(n_rows, n_cols+1, wspace=0, hspace=0, width_ratios=[1]*n_cols+[.1])

mask = load_mrtrix(os.path.join(READ_DIR, 'mask.mif')).data > .5
if SLICE is not None:
    mask = mask[:,:,SLICE]

# Reference FA map; loaded here but not displayed in this figure.
gt_fa = load_mrtrix(os.path.join(READ_DIR, 'fa.mif')).data
if SLICE is not None:
    gt_fa = gt_fa[:,:,SLICE]

# The cropped mask doubles as alpha channel so voxels outside it are transparent.
mask_alpha = np.rot90(mask[20:-20,20:-15]).astype(float)

for i, (snr_name, snr) in enumerate(sub_dict.items()):
    for j, g in enumerate(grid_int):
        fa = np.squeeze(load_mrtrix(os.path.join(READ_DIR, snr_name, 'LoRE', f'grid_{g}', 'contrasts', 'fa.mif')).data)
        ax = fig.add_subplot(gs[i, j])
        ax.imshow(np.rot90(fa[20:-20,20:-15]), cmap='gray', vmin=0, vmax=1, alpha=mask_alpha)

        ax.set_xticks([])
        ax.set_yticks([])

        if i == 0:
            ax.set_title(f'{g}x{g}')
        if j == 0:
            # Value test instead of an identity test against np.inf.
            if not np.isinf(snr):
                ax.set_ylabel(f'SNR {snr}')
            else:
                ax.set_ylabel('Noise Free')
# Add a colorbar
cax = fig.add_subplot(gs[0, -1])
cmap = LinearSegmentedColormap.from_list('mycmap', [(0, 'black'), (1, 'white')])
norm = plt.Normalize(0, 1)
cb1 = plt.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=cax)
cb1.set_label('FA')

