# Frontmatter

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
from matplotlib.colors import TwoSlopeNorm
import matplotlib.ticker as ticker
from dwMRI_BasicFuncs import *
from tqdm.auto import tqdm
from joblib import Parallel, delayed

from scipy.ndimage import gaussian_filter,binary_dilation

# Font defaults (sans-serif family resolved to Helvetica, base size 14).
# Kept as a named dict so it remains available after this cell.
font = {'family': 'sans-serif', 'sans-serif': ['Helvetica'], 'size': 14}

# Collect every Matplotlib override and apply them in one call.
# Later keys win, so "font.family" ends up as "Helvetica", exactly as when
# the settings were applied one after another.
_rc_overrides = {f'font.{key}': value for key, value in font.items()}
_rc_overrides.update({
    # Tick label sizes
    'ytick.labelsize': 24,
    'xtick.labelsize': 24,
    # Plain-text rendering with Helvetica
    "text.usetex": False,
    "font.family": "Helvetica",
    # Open axes (no top/right spine) and frameless legends
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
})
plt.rcParams.update(_rc_overrides)

%load_ext autoreload
%autoreload 2
    

In [None]:
# Folder holding the pickled posterior networks
network_path = './Networks/'
# SNR levels handed to the simulators as `snr`; the leading None is the
# setting used for the reference ("Truth") signal
NoiseLevels = [None,20,10,5,2]

# Number of simulations used to train a posterior / samples drawn per inference
TrainingSamples = 50000
InferSamples    = 500

# Bounds passed to DTIPriorS0 when the prior is built
lower_abs,upper_abs = -0.07,0.07
lower_rest,upper_rest = -0.015,0.015
lower_S0 = 25
upper_S0 = 2000

# Plot colours
TrueCol  = 'k'
NoisyCol = 'k'
WLSFit   = 'sandybrown'
SBIFit   = np.array([64,176,166])/255

# Labels zipped against the error columns in the error figures
Errors_name = ['MD comparison','FA comparison','eig. comparison','Frobenius','Signal comparison','Correlation','Signal comparison','Correlation2']

DatFolder = './SavedDat/'
MSDir = './MS_data/'
# Figures are written to disk only when Save is True. This is the single
# definition; an earlier `Save = True` in this cell was always overridden.
# NOTE(review): the savefig calls use `FigLoc`, which is not defined in this
# cell - confirm it is defined elsewhere before enabling Save.
Save = False

ChunkSize = 128

## Plotting

In [None]:
def viol_plot(A,col,hatch=False,*,z_thresh=1000,**kwargs):
    """Draw one violin per column of ``A`` on the current axes.

    Each column is stripped of NaNs and of entries whose absolute z-score
    reaches ``z_thresh`` before plotting.

    Parameters
    ----------
    A : array-like
        Data; after ``np.transpose`` each row is plotted as one violin.
    col : colour
        Colour for the violin bodies and the bar/min/max lines.
    hatch : bool, optional
        If true, hatch the first violin body with ``'//'``.
    z_thresh : float, optional
        Entries with ``|z| >= z_thresh`` are dropped. The default of 1000
        effectively disables outlier removal; pass e.g. 3 for a 3-sigma cut.
    **kwargs
        Forwarded to ``plt.violinplot``.
    """
    filtered_A = []
    for column in np.transpose(A):
        # Remove NaNs
        column = column[~np.isnan(column)]
        if column.size:
            abs_z_scores = np.abs(stats.zscore(column))
            # Written as "not >=" instead of "<" on purpose: a zero-variance
            # column yields NaN z-scores, and NaN < threshold is False, which
            # would silently discard the entire column.
            column = column[~(abs_z_scores >= z_thresh)]
        filtered_A.append(column)

    vp = plt.violinplot(filtered_A,showmeans=True,**kwargs)
    for body in vp['bodies']:
        body.set_facecolor(col)
    for part in ('cbars', 'cmins', 'cmaxes'):
        vp[part].set_color(col)
    vp['cmeans'].set_color('black')
    if hatch:
        # NOTE(review): only the first violin is hatched - confirm this is
        # intended when A has several columns.
        vp['bodies'][0].set_hatch('//')
def BoxPlots(y_data, positions, colors, colors2, ax,hatch = False,scatter=False,scatter_alpha=0.5, **kwargs):
    """Draw outlined (unfilled) box plots on ``ax``, optionally with jittered points.

    NaNs are removed per group before plotting. Fliers and caps are not drawn.

    Parameters
    ----------
    y_data : array-like
        Either a 1-D array (a single box) or a sequence of groups, one box
        per group.
    positions : sequence of float
        x position of each box.
    colors : sequence of colours
        Edge/whisker colour of each box.
    colors2 : colour or sequence of colours
        Scatter colour: used as-is for 1-D ``y_data``, otherwise one entry
        per group.
    ax : axes object
        Target axes; its ``boxplot`` and ``scatter`` methods are used.
    hatch : bool, optional
        Hatch every box with ``'/'``.
    scatter : bool, optional
        Overlay the individual data points with horizontal jitter.
    scatter_alpha : float, optional
        Alpha of the scatter points.
    **kwargs
        Forwarded to ``ax.boxplot`` (e.g. ``widths``).
    """
    GREY_DARK = "#747473"
    jitter = 0.02

    # Decide once, from the input, whether a single group was passed. The
    # NaN-cleaned data can be ragged, and recent NumPy raises when asked for
    # the ndim of a ragged list, so it must not be re-inspected later.
    try:
        single_group = np.ndim(y_data) == 1
    except ValueError:
        single_group = False

    # Clean data to remove NaNs group-wise
    if single_group:
        values = np.asarray(y_data)
        cleaned_data = values[~np.isnan(values)]
    else:
        cleaned_data = [np.asarray(d)[~np.isnan(d)] for d in y_data]

    # Boxes are patch objects: no fill, coloured outline (recoloured below)
    boxprops = dict(linewidth=2, facecolor='none', edgecolor='turquoise')
    # Medians and whiskers are Line2D objects
    medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")
    whiskerprops = dict(linewidth=2, color='turquoise')

    bplot = ax.boxplot(
        cleaned_data,
        positions=positions,
        showfliers=False,
        showcaps = False,
        medianprops=medianprops,
        whiskerprops=whiskerprops,
        boxprops=boxprops,
        patch_artist=True,
        **kwargs
    )

    # Recolour each box outline
    for i, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[i])
        if hatch:
            box.set_hatch('/')

    # Two whiskers (and, if drawn, two caps) belong to each box
    for i, whisker in enumerate(bplot['whiskers']):
        whisker.set_color(colors[i//2])
    for i, cap in enumerate(bplot.get('caps', [])):
        cap.set_color(colors[i//2])

    if scatter:
        jitter_dist = stats.t(df=6, scale=jitter)
        if single_group:
            # Flat x coordinates so the jitter adds element-wise, and plot the
            # jittered coordinates (previously computed but never used).
            x_center = np.ravel(positions)[0]
            x_jittered = x_center + jitter_dist.rvs(len(cleaned_data))
            ax.scatter(x_jittered, cleaned_data, s=100, color=colors2, alpha=scatter_alpha)
        else:
            x_jittered = [positions[i] + jitter_dist.rvs(len(d)) for i, d in enumerate(cleaned_data)]
            # Plot the scatter points with jitter (using colors2)
            for x, y, c in zip(x_jittered, cleaned_data, colors2):
                ax.scatter(x, y, s=100, color=c, alpha=scatter_alpha)
def BoxPlots2(y_data, positions, colors, colors2, ax,hatch = False):
    """Draw outlined box plots with mean markers and overlay all points with jitter.

    Parameters
    ----------
    y_data : sequence of sequences
        One group of values per box.
    positions : sequence of float
        x position of each box.
    colors : sequence of colours
        Colour of each box outline, median, whiskers and caps.
    colors2 : sequence of colours
        Colour of the scatter points of each group.
    ax : axes object
        Target axes; its ``boxplot`` and ``scatter`` methods are used.
    hatch : bool, optional
        Hatch every box with ``'/'``.
    """
    import numpy as np
    from scipy import stats

    jitter_scale = 0.02
    line_width = 2

    # x coordinates of the scatter points: box position plus Student-t noise.
    # Drawn for every group up front, before anything is plotted.
    scatter_x = []
    for idx, group in enumerate(y_data):
        base = np.array([positions[idx]] * len(group))
        scatter_x.append(base + stats.t(df=6, scale=jitter_scale).rvs(len(base)))

    bplot = ax.boxplot(
        y_data,
        positions=positions,
        showfliers=False,
        showcaps=False,
        showmeans=True,
        # Line2D styling for medians and whiskers (recoloured per box below)
        medianprops={'linewidth': line_width, 'color': 'dimgray', 'solid_capstyle': "butt"},
        whiskerprops={'linewidth': line_width, 'color': 'turquoise'},
        # Patch styling for the boxes: outline only, no fill
        boxprops={'linewidth': line_width, 'facecolor': 'none', 'edgecolor': 'turquoise'},
        patch_artist=True
    )

    # Box outlines (patches)
    for idx, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[idx])
        if hatch:
            box.set_hatch('/')

    # Median lines
    for idx, median_line in enumerate(bplot['medians']):
        median_line.set_color(colors[idx])

    # Lower and upper whisker of every box
    for idx in range(len(positions)):
        for offset in (0, 1):
            bplot['whiskers'][2*idx + offset].set_color(colors[idx])

    # Caps come in pairs as well
    if 'caps' in bplot:
        for cap_idx, cap in enumerate(bplot['caps']):
            cap.set_color(colors[cap_idx//2])

    # Individual observations on top of the boxes
    for xs, ys, point_color in zip(scatter_x, y_data, colors2):
        ax.scatter(xs, ys, s=100, color=point_color, alpha=0.5)

## DKI Fit

In [None]:

# Load the two HCP acquisitions of one subject, fit a DKI model to one axial
# slice and keep the diffusion/kurtosis tensors with plausible kurtosis values.
i = 1
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined gradient table: first-file gradients followed by second-file gradients
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): axial_middle is taken from the un-cropped volume but is used
# below to index the cropped arrays - confirm it selects the intended slice.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=False, dilate=2)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

# Bounding box of the mask, used to crop both volumes identically
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)

maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
# NOTE(review): this crops data3 itself, i.e. the mask is not applied to it
maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

# One axial slice of both acquisitions, concatenated along the volume axis
TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
# Flatten to (voxels, volumes); the volume count is read from the data
# instead of being hard-coded
FlatTD = TestData.reshape(-1,TestData.shape[-1])
# Drop voxels whose first 69 volumes sum to zero
# (presumably the first acquisition's volumes - TODO confirm)
FlatTD = FlatTD[FlatTD[:,:69].sum(axis=-1)>0]
# Drop voxels containing any negative signal value
FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]

dkimodel = dki.DiffusionKurtosisModel(gtabExt)
tenfit = dkimodel.fit(FlatTD)
DKIHCP = tenfit.kt
DTIHCP = tenfit.lower_triangular()
DKIFull = np.array(DKIHCP)
DTIFull = np.array(DTIHCP)

# Keep voxels whose kurtosis elements all lie in (-10, 10) ...
kt_bounded = (abs(DKIFull)<10).all(axis=1)
DTIFilt1 = DTIFull[kt_bounded]
DKIFilt1 = DKIFull[kt_bounded]
# ... and are all above -3/7
kt_above_floor = (DKIFilt1>-3/7).all(axis=1)
DTIFilt = DTIFilt1[kt_above_floor]
DKIFilt = DKIFilt1[kt_above_floor]

TrueMets = []
FA       = []
for (dt,kt) in tqdm(zip(DTIFilt,DKIFilt),total=len(DTIFilt)):
    TrueMets.append(DKIMetrics(dt,kt))
    # Eigenvalues are computed once and reused for the mean
    evals = np.linalg.eigh(vals_to_mat(dt))[0]
    FA.append(FracAni(evals,np.mean(evals)))
TrueMets = np.array(TrueMets)
TrueFA = np.array(FA)

In [None]:
# Fit the DT and KT distribution parameters on all voxels and on several
# subsets. The last TrueMets column is thresholded for the sections labelled
# FA, column 1 for the section labelled AK.
_last_metric = TrueMets[:,-1]

# Full fit
DT1_full,DT2_full = FitDT(DTIFilt,1)
x4_full,R1_full,x2_full,R2_full = FitKT(DKIFilt,1)

# LowFA Fit (last metric below 0.3)
DT1_lfa,DT2_lfa = FitDT(DTIFilt[_last_metric<0.3,:],1)
x4_lfa,R1_lfa,x2_lfa,R2_lfa = FitKT(DKIFilt[_last_metric<0.3,:],1)

# HighFA Fit (last metric above 0.7)
DT1_hfa,DT2_hfa = FitDT(DTIFilt[_last_metric>0.7,:],1)
x4_hfa,R1_hfa,x2_hfa,R2_hfa = FitKT(DKIFilt[_last_metric>0.7,:],1)

# UltraLowFA Fit (last metric below 0.1)
DT1_ulfa,DT2_ulfa = FitDT(DTIFilt[_last_metric<0.1,:],1)
x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa = FitKT(DKIFilt[_last_metric<0.1,:],1)

# HigherAK Fit (metric column 1 above 0.9)
_second_metric = TrueMets[:,1]
DT1_hak,DT2_hak = FitDT(DTIFilt[_second_metric>0.9,:],1)
x4_hak,R1_hak,x2_hak,R2_hak = FitKT(DKIFilt[_second_metric>0.9,:],1)

# Figure 2

## DTI

In [None]:
# Build the DTI+S0 prior from the configured bounds.
custom_prior = DTIPriorS0(
    lower_abs, upper_abs,
    lower_rest, upper_rest,
    lower_S0, upper_S0,
)
# process_prior returns several values; only the first one is kept
priorS0 = process_prior(custom_prior)[0]

In [None]:
# Build the simulated acquisition schemes (1 b=0 + 64 directions, and
# 1 b=0 + 6 directions), then draw one ground-truth tensor and simulate its
# signal without noise and at each noise level.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
# Entry 0 is skipped; a zero vector is prepended again when the tables are built
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
# 5000 is the second argument of disperse_charges (presumably the iteration count - TODO confirm)
hsph_updated,potentials = disperse_charges(hsph_initial,5000)
hsph_updated7,potentials = disperse_charges(hsph_initial7,5000)

# b-values: one 0 followed by 1000 for every direction
gtabSimF = gradient_table(np.array([0]+[1000]*64).squeeze(), np.vstack([[0,0,0],hsph_updated.vertices]))
gtabSim7 = gradient_table(np.array([0]+[1000]*6).squeeze(), np.vstack([[0,0,0],hsph_updated7.vertices]))

np.random.seed(1)

# The full table plus four copies with the 65 measurements randomly permuted
gTabs = [gtabSimF]
for _ in range(4):
    x = np.random.permutation(np.arange(65))
    bvecs_shuffle = gtabSimF.bvecs[x]
    bvals_shuffle = gtabSimF.bvals[x]

    gTabs.append(gradient_table(bvals_shuffle, bvecs_shuffle))

# Re-seed both generators before drawing the example tensor
torch.manual_seed(0)
np.random.seed(0)

params = priorS0.sample()
dtTruth = ComputeDTI(params)
dtTruth = ForceLowFA(dtTruth)
# Reference signal, simulated with snr=None
Truth = CustomSimulator(dtTruth,gtabSimF,S0=200,snr=None)


# Eigen-decomposition of the ground-truth tensor (used for the true FA/MD lines)
dt_evals,dt_evecs = np.linalg.eigh(dtTruth)

# One simulated signal per non-None noise level
SNR = [CustomSimulator(dtTruth,gtabSimF, S0=200,snr=scale) for scale in NoiseLevels[1:]]

SNR = np.array(SNR)

In [None]:
# 200 simulated realisations of the ground-truth tensor at each SNR, fitted
# with an NLLS tensor model to obtain FA and MD per realisation.
np.random.seed(13)
_n_rep = 200
SNR20 = np.vstack([CustomSimulator(dtTruth,gtabSimF, S0=200,snr=20) for _ in range(_n_rep)])
SNR10 = np.vstack([CustomSimulator(dtTruth,gtabSimF, S0=200,snr=10) for _ in range(_n_rep)])
SNR5 = np.vstack([CustomSimulator(dtTruth,gtabSimF, S0=200,snr=5) for _ in range(_n_rep)])
SNR2 = np.vstack([CustomSimulator(dtTruth,gtabSimF, S0=200,snr=2) for _ in range(_n_rep)])

tenmodel = dti.TensorModel(gtabSimF,return_S0_hat = True,fit_method='NLLS')

# Fit in the order 20, 10, 5, 2; `tenfit` is left holding the last fit
_fa_md = []
for _signals in (SNR20, SNR10, SNR5, SNR2):
    tenfit = tenmodel.fit(_signals)
    _fa = dti.fractional_anisotropy(tenfit.evals)
    _md = dti.mean_diffusivity(tenfit.evals)
    _fa_md.append((_fa, _md))
(FA20, MD20), (FA10, MD10), (FA5, MD5), (FA2, MD2) = _fa_md


In [None]:
# Load the posterior for the 65-measurement scheme from disk, or train it
# from TrainingSamples simulations and cache it.
# NOTE(review): `priorNoise` is not defined in the visible part of the file -
# confirm where it comes from.
# NOTE(review): pickle.load executes arbitrary code; only load network files
# from a trusted source.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimFull.pickle"):
    with open(f"{network_path}/DTISimFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        # Last prior entry is passed to the simulator as its fourth argument
        # and appended to the training parameters
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSimF,200,a))
        Par.append(np.hstack([mat_to_vals(dt),a]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorFull = inference.build_posterior(density_estimator)
    # Re-checked after training so an existing file is never overwritten
    if not os.path.exists(f"{network_path}/DTISimFull.pickle"):
        with open(f"{network_path}/DTISimFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

# Same procedure for the 7-measurement scheme (gtabSim7)
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimMin.pickle"):
    with open(f"{network_path}/DTISimMin.pickle", "rb") as handle:
        posterior7 = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSim7,200,a))
        Par.append(np.hstack([mat_to_vals(dt),a]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior7 = inference.build_posterior(density_estimator)
    # Re-checked after training so an existing file is never overwritten
    if not os.path.exists(f"{network_path}/DTISimMin.pickle"):
        with open(f"{network_path}/DTISimMin.pickle", "wb") as handle:
            pickle.dump(posterior7, handle)

In [None]:
# SBI estimates of MD and FA for every simulated signal at each SNR.
# The tensor estimate is the posterior-sample mean with negative eigenvalues
# clipped. Both generators are re-seeded before each SNR level.
MD20_SBI, FA20_SBI = [], []
MD10_SBI, FA10_SBI = [], []
MD5_SBI, FA5_SBI = [], []
MD2_SBI, FA2_SBI = [], []

# (signals, MD output, FA output, print True if an eigenvalue is still negative)
_sbi_runs = (
    (SNR20, MD20_SBI, FA20_SBI, False),
    (SNR10, MD10_SBI, FA10_SBI, True),
    (SNR5, MD5_SBI, FA5_SBI, True),
    (SNR2, MD2_SBI, FA2_SBI, True),
)

for _signals, _md_out, _fa_out, _report_negative in _sbi_runs:
    torch.manual_seed(2)
    np.random.seed(2)
    for S in tqdm(_signals):
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=S,show_progress_bars=False)
        Guess = vals_to_mat(posterior_samples_1.mean(axis=0))
        Guess_clean = clip_negative_eigenvalues(Guess)
        evals_guess_raw,evecs_guess = np.linalg.eigh(Guess_clean)
        if _report_negative and (evals_guess_raw<0).any():
            print(True)
        _md_out.append(np.mean(evals_guess_raw))
        _fa_out.append(FracAni(evals_guess_raw,_md_out[-1]))

In [None]:
# Draw 500 tensors from the prior for each acquisition scheme and simulate
# every one of them at all noise levels.
np.random.seed(0)
torch.manual_seed(0)

Samples, DTISim, S0Sim = [], [], []
Samples7, DTISim7, S0Sim7 = [], [], []

# Full scheme first, then the 7-measurement scheme (the order fixes the RNG stream)
_schemes = (
    (gtabSimF, Samples, DTISim, S0Sim),
    (gtabSim7, Samples7, DTISim7, S0Sim7),
)
for _gtab, _signals_out, _tensors_out, _s0_out in _schemes:
    params = priorS0.sample([500])
    for i in tqdm(range(500)):
        dt = ForceLowFA(ComputeDTI(params[i]))
        _tensors_out.append(dt)
        _s0_out.append(params[i,-1])
        _signals_out.append([CustomSimulator(dt,_gtab, S0=200,snr=scale) for scale in NoiseLevels])

# Move the tensor axis to the end: indexing is then [noise level, measurement, tensor]
Samples = np.moveaxis(np.array(Samples).squeeze(), 0, -1)
Samples7 = np.moveaxis(np.array(Samples7).squeeze(), 0, -1)

In [None]:
# Error metrics of the SBI and NLLS tensor estimates for all 500 benchmark
# tensors at each of the 5 noise levels, for both acquisition schemes.

# --- SBI, 65-measurement scheme ---
torch.manual_seed(10)
ErrorFull = []
NoiseApproxFull = []
for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim[i])
        tObs = Samples[k,:,i]
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        # Signal of the true tensor simulated with snr=None
        true_signal_dti = single_tensor(gtabSimF, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        # Point estimate: histogram mode of every sampled parameter, then
        # negative eigenvalues are clipped
        mat_guess = vals_to_mat([histogram_mode(p) for p in posterior_samples_1.T])
        mat_guess = clip_negative_eigenvalues(mat_guess)
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSimF,true_signal_dti,tObs))
        # Posterior mean of the last sampled parameter
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApproxFull.append(ENoise)
    ErrorFull.append(ErrorN2)

NoiseApproxFull = np.array(NoiseApproxFull)

# --- NLLS, both schemes ---
Error_s = []
for k,gtab,Samps,DTIS in zip([65,7],[gtabSimF,gtabSim7],[Samples,Samples7],[DTISim,DTISim7]):
    tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
    Error_n = []
    for S,Noise in zip(Samps,NoiseLevels):
        Error = []
        for i in range(500):
            tenfit = tenmodel.fit(S[:,i])
            tensor_vals = dti.lower_triangular(tenfit.quadratic_form)
            DT_test = vals_to_mat(tensor_vals)
            # Samps[0] (first noise level, None) is used as the reference signal here
            Error.append(Errors(DT_test,DTIS[i],gtab,Samps[0][:,i],S[:,i]))
        Error_n.append(Error)
    Error_s.append(Error_n)
Error_s = np.array(Error_s)
# Reorder to [noise level, scheme, tensor, error]
Error_s = np.swapaxes(Error_s,0,1)

# --- SBI, 7-measurement scheme ---
# NOTE(review): unlike the 65-measurement loop above, the estimate here is the
# posterior mean and clip_negative_eigenvalues is not applied - confirm this
# difference is intended.
torch.manual_seed(10)
Error7 = []
NoiseApprox7 = []
for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim7[i])
        tObs = Samples7[k,:,i]
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim7, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posterior7.sample((InferSamples,), x=tObs,show_progress_bars=False)
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSim7,true_signal_dti,tObs))
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApprox7.append(ENoise)
    Error7.append(ErrorN2)

NoiseApprox7 = np.array(NoiseApprox7)


### a

In [None]:
# One small panel per noise level: true signal (solid black) against a noisy
# realisation (dashed grey). Only the first panel carries labels and a legend.
_panels = (
    (SNR[0], 'EgSig20.pdf'),
    (SNR[1], 'EgSig10.pdf'),
    (SNR[2], 'EgSig5.pdf'),
    (SNR[3], 'EgSig2.pdf'),
)
for _panel_idx, (_noisy, _fname) in enumerate(_panels):
    _with_legend = _panel_idx == 0
    plt.subplots(figsize=(6,1))
    if _with_legend:
        plt.plot(Truth,'k',lw=2,label='True signal')
        plt.plot(_noisy,'gray',lw=2,ls='--',label='Noisy signal')
    else:
        plt.plot(Truth,'k',lw=2)
        plt.plot(_noisy,'gray',lw=2,ls='--')
    plt.axis('off')
    if _with_legend:
        legend= plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.03,1.7),fontsize=26,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
        # Thicker legend handles than the plotted lines
        for handle in legend.get_lines():
            handle.set_linewidth(6)
    if Save:
        plt.savefig(FigLoc+_fname,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

### b

In [None]:
# FA per SNR level: SBI boxes (offset to the right) next to NLLS boxes,
# with the true FA as a dashed horizontal line.
fig,ax = plt.subplots(figsize=(6.4,2.4))

# (values per SNR, box positions, outline colour, point colour); SBI is drawn first
_fa_groups = (
    ([FA20_SBI,FA10_SBI,FA5_SBI,FA2_SBI], [1.3,2.3,3.3,4.3], 'lightseagreen', 'paleturquoise'),
    ([FA20,FA10,FA5,FA2], [1,2,3,4], 'sandybrown', 'peachpuff'),
)
for _values, _positions, _edge, _fill in _fa_groups:
    y_data = np.array(_values)
    g_pos = np.array(_positions)
    colors = [_edge]*4
    colors2 = [_fill]*4
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.3,scatter=True)

_true_fa = FracAni(dt_evals,np.mean(dt_evals))
l = plt.axhline(_true_fa,c='k',lw=3,ls='--',label='True FA')
plt.xticks([1,2,3,4],[20,10,5,2],fontsize=28)
plt.xticks(fontsize=28)

# Only the SBI patch is shown in this panel's legend
leg_patch1 = mpatches.Patch(color='lightseagreen', label='SBI Fit')
leg_patch2 = mpatches.Patch(color='sandybrown', label='NLLS Fit')
ax.legend(
    handles=[leg_patch1],
    loc='upper left',
    bbox_to_anchor=(0,1),
    frameon=False,
    ncols=1,
    fontsize=32,
    columnspacing=0.3,
    handlelength=0.6,
    handletextpad=0.3,
)
plt.yticks([0,1])
if Save:
    plt.savefig(FigLoc+'EgNoiseFA.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# MD per SNR level: SBI boxes (offset to the right) next to NLLS boxes,
# with the true MD as a dashed horizontal line.
fig,ax = plt.subplots(figsize=(6.4,2.4))

# (values per SNR, box positions, outline colour, point colour); SBI is drawn first
_md_groups = (
    ([MD20_SBI,MD10_SBI,MD5_SBI,MD2_SBI], [1.3,2.3,3.3,4.3], 'lightseagreen', 'paleturquoise'),
    ([MD20,MD10,MD5,MD2], [1,2,3,4], 'sandybrown', 'peachpuff'),
)
for _values, _positions, _edge, _fill in _md_groups:
    y_data = np.array(_values)
    g_pos = np.array(_positions)
    colors = [_edge]*4
    colors2 = [_fill]*4
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.3,scatter=True)

_true_md = np.mean(dt_evals)
l = plt.axhline(_true_md,c='k',lw=3,ls='--',label='True MD')
# No x tick marks or labels in this panel
plt.xticks([])

# Only the NLLS patch is shown in this panel's legend
leg_patch2 = mpatches.Patch(color='sandybrown', label='NLLS Fit')
ax.legend(
    handles=[leg_patch2],
    loc='upper left',
    bbox_to_anchor=(0,0.5),
    frameon=False,
    ncols=1,
    fontsize=32,
    columnspacing=0.3,
    handlelength=0.6,
    handletextpad=0.3,
)
plt.yticks([0,0.001,0.002])
plt.ylim((0, 0.0025))
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
if Save:
    plt.savefig(FigLoc+'EgNoiseMD.pdf',format='pdf',bbox_inches='tight',transparent=True)

### c

In [None]:
# Error box plots for the full acquisition scheme: one panel per error
# metric (the first two), SBI boxes next to NLLS boxes at every noisy SNR.
fig,axs = plt.subplots(1,2,figsize=(9,3),constrained_layout=True)
ax = axs.ravel()

_sbi_errors = np.array(ErrorFull).T
_snr_ticks = [1.15, 2.15, 3.15, 4.15]
# Legend entry per panel: panel 0 names SBI, panel 1 names NLLS
_legend_entry = {0: (SBIFit, 'SBI'), 1: (WLSFit, 'NLLS')}

for ll,(a,E,t) in enumerate(zip(ax,_sbi_errors,Errors_name)):
    # SBI errors, first noise level dropped
    y_data = E[:,1:]
    g_pos = np.array([1.3,2.3,3.3,4.3])
    colors = ['lightseagreen']*4
    colors2 = ['paleturquoise']*4
    BoxPlots(y_data.T,g_pos,colors,colors2,a,widths=0.3,scatter=False)

    # NLLS errors for scheme index 0, first noise level dropped
    y_data = Error_s[1:,0,:,ll].T
    g_pos = np.array([1,2,3,4])
    colors = ['sandybrown']*4
    colors2 = ['peachpuff']*4
    BoxPlots(y_data.T,g_pos,colors,colors2,a,widths=0.3,scatter=False)

    plt.sca(a)
    if ll > 3:
        plt.xlabel('SNR', fontsize=24)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    # Tick centred between the two boxes of each SNR level
    plt.xticks(_snr_ticks, NoiseLevels[1:],fontsize=32)

    if ll in _legend_entry:
        _line_color, _line_label = _legend_entry[ll]
        handles = [Line2D([0], [0], color=_line_color, lw=4, label=_line_label)]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    if ll == 1:
        # Fixed y range for the second panel
        plt.ylim([-0.05,1])
        plt.yticks([0,1])
    plt.grid()

if Save:
    plt.savefig(FigLoc+'SimDatDTIErrors1.pdf',format='pdf',bbox_inches='tight',transparent=True)

### d

In [None]:
# Error box plots for the 7-measurement scheme: one panel per error metric
# (the first two), SBI boxes next to NLLS boxes at every noisy SNR.
fig,axs = plt.subplots(1,2,figsize=(9,3),constrained_layout=True)
ax = axs.ravel()

_sbi_errors = np.array(Error7).T
_snr_ticks = [1.15, 2.15, 3.15, 4.15]
# Legend entry per panel: panel 0 names SBI, panel 1 names NLLS
_legend_entry = {0: (SBIFit, 'SBI'), 1: (WLSFit, 'NLLS')}

for ll,(a,E,t) in enumerate(zip(ax,_sbi_errors,Errors_name)):
    # SBI errors, first noise level dropped
    y_data = E[:,1:]
    g_pos = np.array([1.3,2.3,3.3,4.3])
    colors = ['lightseagreen']*4
    colors2 = ['paleturquoise']*4
    BoxPlots(y_data.T,g_pos,colors,colors2,a,widths=0.3,scatter=False)

    # NLLS errors for the last scheme index, first noise level dropped
    y_data = Error_s[1:,-1,:,ll].T
    g_pos = np.array([1,2,3,4])
    colors = ['sandybrown']*4
    colors2 = ['peachpuff']*4
    BoxPlots(y_data.T,g_pos,colors,colors2,a,widths=0.3,scatter=False)

    plt.sca(a)
    if ll > 3:
        plt.xlabel('SNR', fontsize=24)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    # Tick centred between the two boxes of each SNR level
    plt.xticks(_snr_ticks, NoiseLevels[1:],fontsize=32)

    if ll in _legend_entry:
        _line_color, _line_label = _legend_entry[ll]
        handles = [Line2D([0], [0], color=_line_color, lw=4, label=_line_label)]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    if ll == 1:
        # Fixed y range for the second panel
        plt.ylim([-0.05,1])
        plt.yticks([0,1])
    plt.grid()

if Save:
    plt.savefig(FigLoc+'SimDatDTIErrors2.pdf',format='pdf',bbox_inches='tight',transparent=True)

## DKI

In [None]:
# --- Full two-shell acquisition scheme used for the DKI simulations ---------
# Start from the directions shipped with the 'small_64D' dataset and spread
# them evenly over the hemisphere (5000 iterations of disperse_charges).
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_updated,_ = disperse_charges(hsph_initial,5000)

# One shell = a zero vector followed by the dispersed directions; the scheme
# is that shell repeated twice, the second copy at b=3000.
bvecsExt = np.vstack([[0,0,0],hsph_updated.vertices])
bvecsExt = np.vstack([bvecsExt, bvecsExt])
bvalsExt = np.hstack([bvals, 3000*np.ones_like(bvals)])
# Index 65 is the zero vector that opens the second copy, so it gets b=0
# (assumes `bvals` has 65 entries -- TODO confirm).
bvalsExt[65] = 0
gtabSim = gradient_table(bvalsExt, bvecsExt)


# --- Reduced scheme: 1 b=0, 6 directions at b=1000, 15 directions at b=3000 --
hsph_initial15 = HemiSphere(xyz=bvecs[1:16])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
hsph_updated15,_ = disperse_charges(hsph_initial15,5000)
hsph_updated7,_ = disperse_charges(hsph_initial7,5000)
sub_bvals = np.array([0]+[1000]*6+[3000]*15).squeeze()
sub_bvecs = np.vstack([[0,0,0],hsph_updated7.vertices,hsph_updated15.vertices])
gtabSimSub = gradient_table(sub_bvals, sub_bvecs)

In [None]:
# Build (a) one example DKI observation for the reconstruction figures and
# (b) a fixed test set simulated at every noise level in NoiseLevels.
# The statement order matters: every draw below consumes the seeded RNG streams.
torch.manual_seed(2)
np.random.seed(2)
j = 1
vL = torch.tensor([0.2*j])
vS = torch.tensor([0.01*j])  

# Pick one of four tensor families at random for the example voxel.
# NOTE(review): GenDTKT's last two arguments look like (seed, n_samples) -- confirm.
kk = np.random.randint(0,4)
if(kk==0):
    DT,KT = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],2,1)
elif(kk==1):
    DT,KT = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],2,1)
elif(kk==2):
    DT,KT = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],2,1)
elif(kk==3):
    DT,KT = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],2,1)

# Example signals: noisy on the full scheme, noisy on the reduced scheme, and
# a reference on the full scheme with the last argument set to None
# (presumably "no noise" -- confirm against CustomDKISimulator).
tObs = CustomDKISimulator(DT.squeeze(),KT.squeeze(),gtabSim,200,20)
tObs7 = CustomDKISimulator(np.squeeze(DT),np.squeeze(KT),gtabSimSub,200,20)
tTrue = CustomDKISimulator(DT.squeeze(),KT.squeeze(),gtabSim,200,None)

# Test set: tensors from the low-FA, high-FA and 'hak' families, stacked row-wise.
torch.manual_seed(1)
np.random.seed(1)
DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],1,50)
DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],1,50)
DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,100)

SampsDT = np.vstack([DT2,DT3,DT5])
SampsKT = np.vstack([KT2,KT3,KT5])

# Samples[i, k] is test tensor i simulated on the full scheme at NoiseLevels[k].
Samples  = []
for Sd,Sk in zip(SampsDT,SampsKT):
    Samples.append([CustomDKISimulator(Sd,Sk,gtabSim, S0=200,snr=scale) for scale in NoiseLevels])

Samples = np.array(Samples)

# Samples7[i, k] is the same tensor simulated on the reduced scheme.
Samples7  = []
for Sd,Sk in zip(SampsDT,SampsKT):
    Samples7.append([CustomDKISimulator(Sd,Sk,gtabSimSub, S0=200,snr=scale) for scale in NoiseLevels])

Samples7 = np.array(Samples7)

In [None]:
# Load the DKI posterior for the full acquisition scheme from the cache, or
# simulate a training set and train it when no cached network exists.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if os.path.exists(f"{network_path}/DKISimFull.pickle"):
    with open(f"{network_path}/DKISimFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []
    # Draw diffusion/kurtosis tensors from five tensor families.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    # S0 is still the empty list here, so this yields an empty (0, 1) array
    # that is not used again in this cell.
    S0 = np.array(S0).reshape(len(S0),1)
    
    # Rejection loop: simulate every pending row, then re-draw any row whose
    # signal leaves [0, 800] until no such row remains.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSim.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm(indx): 
            # The last argument is uniform in [0, 30) (presumably the SNR -- confirm).
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSim,200,np.random.rand()*30)
        
        # All rows are re-scanned, not only the ones just simulated.
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        # Rejected rows: halve the kurtosis tensor and replace the diffusion
        # tensor by a single fresh draw (the same draw for all rejected rows).
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    # Parameters are the diffusion tensor entries followed by the kurtosis entries.
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    
    # Spoken notification via the `say` shell command (a macOS utility);
    # os.system only returns a non-zero status where it is unavailable.
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKISimFull.pickle"):
        with open(f"{network_path}/DKISimFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

# Load the DKI posterior for the reduced acquisition scheme (gtabSimSub) from
# the cache, or simulate a training set and train it when none exists.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if os.path.exists(f"{network_path}/DKISimMin.pickle"):
    with open(f"{network_path}/DKISimMin.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    S0 = []
    # Draw diffusion/kurtosis tensors from five tensor families.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

    # S0 is still the empty list here, so this yields an empty (0, 1) array
    # that is not used again in this cell (kept so the global stays defined).
    S0 = np.array(S0).reshape(len(S0),1)

    # Rejection loop: simulate every pending row, then re-draw any row whose
    # signal leaves [0, 800] until no such row remains.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSimSub.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm(indx):
            # The last argument is uniform in [0, 30) (presumably the SNR -- confirm).
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSimSub,200,np.random.rand()*30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        # Rejected rows: halve the kurtosis tensor and replace the diffusion
        # tensor by a single fresh draw (the same draw for all rejected rows).
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    # Parameters are the diffusion tensor entries followed by the kurtosis entries.
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    # Fix: the result used to be bound to `posteriorFull`, which overwrote the
    # full-scheme network and left `posteriorMin` undefined for the dump below.
    posteriorMin = inference.build_posterior(density_estimator)

    # Spoken notification via the `say` shell command (a macOS utility);
    # os.system only returns a non-zero status where it is unavailable.
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKISimMin.pickle"):
        with open(f"{network_path}/DKISimMin.pickle", "wb") as handle:
            pickle.dump(posteriorMin, handle)
            

In [None]:
# Errors of the SBI and NLLS DKI estimates on the full acquisition scheme.
# Loop bounds come from the test array itself (Samples[i, k] is sample i at
# noise level k) instead of the previously hard-coded 200 and 5.
n_samples, n_levels = Samples.shape[:2]

# --- SBI: posterior mean from InferSamples draws per observation -----------
torch.manual_seed(10)
ErrorFull = []
for k in tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_samples):
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=Samples[i,k,:],show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)

        # First six entries are passed as the diffusion tensor, the rest as the kurtosis tensor.
        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorFull.append(ErrorN2)

# --- NLLS baseline on the same observations --------------------------------
Error_s = []
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='NLLS')

for k in tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_samples):
        tenfit = dkimodel.fit(Samples[i,k,:])

        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_s.append(ErrorN2)



In [None]:
# Errors of the SBI and NLLS DKI estimates on the reduced acquisition scheme.
# Loop bounds come from the test array itself (Samples7[i, k] is sample i at
# noise level k) instead of the previously hard-coded 200 and 5.
n_samples, n_levels = Samples7.shape[:2]

# --- SBI: posterior mean from InferSamples draws per observation -----------
torch.manual_seed(10)
ErrorMin = []
for k in tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_samples):
        posterior_samples_1 = posteriorMin.sample((InferSamples,), x=Samples7[i,k,:],show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)

        # First six entries are passed as the diffusion tensor, the rest as the kurtosis tensor.
        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorMin.append(ErrorN2)

# --- NLLS baseline on the same observations --------------------------------
Error_s_min = []
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')

for k in tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_samples):
        tenfit = dkimodel.fit(Samples7[i,k,:])

        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_s_min.append(ErrorN2)



### e

In [None]:
# Panel e (SBI): signal reconstructed from the posterior mean, drawn over the
# reference signal for the example voxel on the full scheme.
torch.manual_seed(1)
np.random.seed(1)
posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=True)
GuessDKI = posterior_samples_1.mean(axis=0)
# Re-simulate from the estimate; first six entries are the diffusion tensor,
# the remainder the kurtosis tensor.
GuessSig = CustomDKISimulator(GuessDKI[:6],GuessDKI[6:],gtabSim,200)

recon_fig, recon_ax = plt.subplots(figsize=(6,1))
recon_ax.plot(tTrue,lw=2,c='k',label='True signal')
recon_ax.plot(GuessSig,lw=2,c=SBIFit,ls='--',label='SBI Recon.')
recon_ax.axis('off')
legend = recon_ax.legend(ncols=2,loc=1,bbox_to_anchor =  (1,1.95),fontsize=26,
                         columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
# Thicker legend swatches than the plotted lines.
for handle in legend.get_lines():
    handle.set_linewidth(6)
if Save: recon_fig.savefig(FigLoc+'FullReconSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel e (NLLS): signal predicted by the NLLS DKI fit of the example voxel,
# drawn over the reference signal on the full scheme.
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='NLLS')
tenfit = dkimodel.fit(tObs)
nlls_signal = tenfit.predict(gtabSim,200)

recon_fig, recon_ax = plt.subplots(figsize=(6,1))
recon_ax.plot(tTrue,lw=2,c='k')
recon_ax.plot(nlls_signal,lw=2,c=WLSFit,ls='--',label='NLLS Recon.')
recon_ax.axis('off')
legend = recon_ax.legend(ncols=2,loc=1,bbox_to_anchor =  (0.9,1.95),fontsize=26,
                         columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
# Thicker legend swatches than the plotted lines.
for handle in legend.get_lines():
    handle.set_linewidth(6)
if Save: recon_fig.savefig(FigLoc+'FullReconWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)

### f

In [None]:
# Panel f: box plots of the first three DKI error measures per SNR level for
# the full scheme -- SBI (teal, shifted right) next to NLLS (orange).
ErrorFull = np.array(ErrorFull)
Error_s = np.array(Error_s)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']
fig,ax = plt.subplots(1,3,figsize=(13.5,3))
for i, panel in enumerate(ax):
    plt.sca(panel)

    # SBI boxes; index 0 along the first axis (first noise level) is skipped.
    g_pos = np.array([1.3,2.3,3.3,4.3])
    colors = ['lightseagreen']*4
    colors2 = ['paleturquoise']*4
    BoxPlots(ErrorFull[1:,:,i],g_pos,colors,colors2,panel,widths=0.3,scatter=False)

    # NLLS boxes at the integer positions.
    g_pos = np.array([1,2,3,4])
    colors = ['sandybrown']*4
    colors2 = ['peachpuff']*4
    BoxPlots(Error_s[1:,:,i],g_pos,colors,colors2,panel,widths=0.3,scatter=False)

    # One tick per SNR group, centred between the two boxes.
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)

    # Only the first panel carries the legend.
    if i == 0:
        handles = [
            Line2D([0], [0], color=SBIFit, lw=4, label='SBI'),
            Line2D([0], [0], color=WLSFit, lw=4, label='NLLS'),
        ]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(-0.1,1.1),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)

if Save: plt.savefig(FigLoc+'ErrorsFull.pdf',format='pdf',bbox_inches='tight',transparent=True)

### g

In [None]:
# Panel g (SBI): the estimate is inferred from the reduced-scheme observation
# (tObs7) but the signal is re-simulated on the full scheme for comparison.
torch.manual_seed(1)
np.random.seed(1)
posterior_samples_1 = posteriorMin.sample((InferSamples,), x=tObs7,show_progress_bars=True)
GuessDKI = posterior_samples_1.mean(axis=0)
GuessSig = CustomDKISimulator(GuessDKI[:6],GuessDKI[6:],gtabSim,200)

recon_fig, recon_ax = plt.subplots(figsize=(6,1))
recon_ax.plot(tTrue,lw=2,c='k',label='True signal')
recon_ax.plot(GuessSig,lw=2,c=SBIFit,ls='--',label='SBI Recon.')
recon_ax.axis('off')

# Grey bands over two index ranges of the x axis (presumably the measurements
# kept in the reduced scheme -- confirm).
band_y = np.arange(0,500,50)
recon_ax.fill_betweenx(band_y,0*np.ones(10),7*np.ones(10),color='gray',alpha=0.5)
recon_ax.fill_betweenx(band_y,64*np.ones(10),79*np.ones(10),color='gray',alpha=0.5)
# Fixed limits so the bands (drawn up to 450) do not stretch the y axis.
recon_ax.set_ylim(-9.996985449425491, 209.99985644997255)

legend = recon_ax.legend(ncols=2,loc=1,bbox_to_anchor =  (1,1.95),fontsize=26,
                         columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
for handle in legend.get_lines():
    handle.set_linewidth(6)
if Save: recon_fig.savefig(FigLoc+'7ReconSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel g (NLLS): fit on the reduced-scheme observation, then predict the
# signal on the full scheme for comparison with the reference.
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')
tenfit = dkimodel.fit(tObs7)
nlls_signal = tenfit.predict(gtabSim,200)

recon_fig, recon_ax = plt.subplots(figsize=(6,1))
recon_ax.plot(tTrue,lw=2,c='k')
recon_ax.plot(nlls_signal,lw=2,c=WLSFit,ls='--',label='NLLS Recon.')
recon_ax.axis('off')
legend = recon_ax.legend(ncols=2,loc=1,bbox_to_anchor =  (0.9,1.95),fontsize=26,
                         columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
for handle in legend.get_lines():
    handle.set_linewidth(6)

# Grey bands over two index ranges of the x axis (presumably the measurements
# kept in the reduced scheme -- confirm).
band_y = np.arange(0,500,50)
recon_ax.fill_betweenx(band_y,0*np.ones(10),7*np.ones(10),color='gray',alpha=0.5)
recon_ax.fill_betweenx(band_y,64*np.ones(10),79*np.ones(10),color='gray',alpha=0.5)
# Fixed limits so the bands (drawn up to 450) do not stretch the y axis.
recon_ax.set_ylim(-9.996985449425491, 209.99985644997255)
if Save: recon_fig.savefig(FigLoc+'7ReconWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)

### h

In [None]:
# Panel h: same layout as the full-scheme error panel, for the reduced scheme.
# NOTE: this rebinds the globals ErrorFull / Error_s to the reduced-scheme
# errors, so re-running the full-scheme plot cell afterwards plots these values.
ErrorFull = np.array(ErrorMin)
Error_s = np.array(Error_s_min)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']
fig,ax = plt.subplots(1,3,figsize=(13.5,3))
for i, panel in enumerate(ax):
    plt.sca(panel)

    # SBI boxes; index 0 along the first axis (first noise level) is skipped.
    g_pos = np.array([1.3,2.3,3.3,4.3])
    colors = ['lightseagreen']*4
    colors2 = ['paleturquoise']*4
    BoxPlots(ErrorFull[1:,:,i],g_pos,colors,colors2,panel,widths=0.3,scatter=False)

    # NLLS boxes at the integer positions.
    g_pos = np.array([1,2,3,4])
    colors = ['sandybrown']*4
    colors2 = ['peachpuff']*4
    BoxPlots(Error_s[1:,:,i],g_pos,colors,colors2,panel,widths=0.3,scatter=False)

    # One tick per SNR group, centred between the two boxes.
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)

    # Only the first panel carries the legend.
    if i == 0:
        handles = [
            Line2D([0], [0], color=SBIFit, lw=4, label='SBI'),
            Line2D([0], [0], color=WLSFit, lw=4, label='NLLS'),
        ]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(-0.1,1.1),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,labelspacing=0.1)
if Save: plt.savefig(FigLoc+'ErrorsMin.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Figure 3

In [None]:
# Prior used by the multi-subject networks, built from the bounds defined in
# the configuration cell. Only the first value returned by process_prior is kept.
# NOTE(review): code that samples this prior treats the last entry as an index
# into a list of gradient tables and the one before it as S0 -- confirm
# against DTIPriorS0Direc.
custom_prior = DTIPriorS0Direc(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0)
priorDirec, *_ = process_prior(custom_prior) 

## HCP

In [None]:
def _farthest_point_selection(distances, seed_indices, n_extra):
    """Greedily extend ``seed_indices`` with ``n_extra`` well-separated points.

    At every step the candidate whose smallest distance to the points chosen
    so far is largest gets appended (ties go to the first such candidate).

    Parameters
    ----------
    distances : square array of pairwise distances between all points.
    seed_indices : indices that are already selected; not modified.
    n_extra : number of additional indices to pick.

    Returns
    -------
    list with the seed indices followed by the newly chosen ones, in pick order.
    """
    chosen = list(seed_indices)
    for _ in range(n_extra):
        candidates = list(set(range(len(distances))) - set(chosen))
        # Smallest distance from each candidate to anything already chosen.
        gap_to_chosen = np.min(distances[candidates][:, chosen], axis=1)
        chosen.append(candidates[np.argmax(gap_to_chosen)])
    return chosen


# --- Load subject 1 ----------------------------------------------------------
fdwi = './HCP_data/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
# Resample from 1.5 mm to 2.5 mm isotropic voxels.
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): the slice index is taken from the volume *before* median_otsu
# crops it (autocrop=True) but is later used on the cropped `maskdata`;
# the multi-subject loader computes it from the cropped volume -- confirm
# which one is intended.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)
mask_cutout = np.copy(mask[:,:,axial_middle])

# --- Measurement subsets -----------------------------------------------------
# Pairwise distances between gradient vectors; shared by all selections below
# (previously recomputed before each one).
distance_matrix = squareform(pdist(bvecsHCP))

# "alt" subset: start from measurement 1, add 5 more (6 in total), then
# prepend measurement 0, giving 7 entries.
selected_indices_alt = _farthest_point_selection(distance_matrix, [1], 5)
selected_indices_alt = [0]+selected_indices_alt

bvalsHCP7_alt = bvalsHCP[selected_indices_alt]
bvecsHCP7_alt = bvecsHCP[selected_indices_alt]
gtabHCP7_alt = gradient_table(bvalsHCP7_alt, bvecsHCP7_alt)

custom_prior = DTIPriorS0Noise(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0,0,30)
priorS0Noise, *_ = process_prior(custom_prior) 

# 7-entry subset: measurement 0 takes part in the distance criterion itself.
selected_indices = _farthest_point_selection(distance_matrix, [0], 6)

# 20-entry subset built the same way.
selected_indices20 = _farthest_point_selection(distance_matrix, [0], 19)

In [None]:
# Load the single-subject DTI posterior for the full HCP scheme from the cache,
# or simulate a training set and train it when no cached network exists.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if os.path.exists(f"{network_path}/DTIHCPFull.pickle"):
    with open(f"{network_path}/DTIHCPFull.pickle", "rb") as handle:
        posterior2 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    bvals = gtabHCP.bvals
    bvecs = gtabHCP.bvecs
    Obs = []
    Par = []
    # NOTE(review): `priorS0` is not defined in the nearby cells (only
    # `priorS0Noise` is) -- confirm it is created earlier in the notebook.
    # Fix: `tqdm` is the class imported from tqdm.auto, so it is called
    # directly; `tqdm.tqdm(...)` raised AttributeError.
    for i in tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        # All but the last parameter describe the tensor; the last one is S0.
        dt = ComputeDTI(params[:-1])
        dt = ForceLowFA(dt)
        # The last simulator argument is uniform in [20, 50)
        # (presumably the SNR -- confirm).
        Obs.append(CustomSimulator(dt,gtabHCP,params[-1],np.random.rand()*30 + 20))
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior2 = inference.build_posterior(density_estimator)
    # Fix: write to the directory the cache is read from (`network_path`);
    # the original wrote under `save_path`, so the load branch above could
    # never find the freshly trained network.
    if not os.path.exists(f"{network_path}/DTIHCPFull.pickle"):
        with open(f"{network_path}/DTIHCPFull.pickle", "wb") as handle:
            pickle.dump(posterior2, handle)

In [None]:
# Compute the mask where the sum is not zero
# Foreground pixels of the middle axial slice: any pixel whose signal summed
# over all measurements is non-zero.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# (row, col) coordinates of the foreground pixels.
indices = np.argwhere(mask)
def optimize_chunk(pixels):
    """Estimate the parameters of every pixel in ``pixels`` with ``posterior2``.

    For each (i, j) pair, 1000 posterior samples are drawn conditioned on the
    first 91 measurements of that pixel, and every parameter is summarised by
    ``histogram_mode`` of its samples.

    Returns a list of ``(i, j, estimate)`` tuples in input order.
    """
    def _estimate(i, j):
        draws = posterior2.sample((1000,), x=maskdata[i, j,axial_middle, :91],show_progress_bars=False)
        return (i, j, np.array([histogram_mode(p) for p in draws.T]))

    return [_estimate(i, j) for i, j in pixels]

# Process the pixels in chunks of ChunkSize so each worker gets a batch.
chunked_indices = [indices[i:i+ChunkSize] for i in range(0, len(indices), ChunkSize)]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

# Output map: 7 values per pixel; background pixels stay at zero.
ArrShape = mask.shape
NoiseEst = np.zeros(list(ArrShape) + [7])

# Scatter the per-chunk results back into the map.
for chunk in results:
    for i, j, x in chunk:
        NoiseEst[i, j] = x

# Turn the per-pixel SBI estimates into MD/FA maps and compute the NLLS
# reference maps for the full acquisition.
# The slice size is read from the estimate array instead of the previously
# hard-coded 55 x 64, so other slices and subjects work unchanged.
n_rows, n_cols = NoiseEst.shape[:2]

NoiseEst2 =  np.zeros_like(NoiseEst)
MD_SBIFull = np.zeros([n_rows,n_cols])
FA_SBIFull = np.zeros([n_rows,n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        # Clip negative eigenvalues of the estimated tensor; the last entry of
        # the estimate is carried over unchanged.
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
        MD_SBIFull[i,j] = np.mean(Eigs)
        FA_SBIFull[i,j] = FracAni(Eigs,np.mean(Eigs))
# NaN FA values (e.g. from all-zero background tensors) are set to zero.
FA_SBIFull[np.isnan(FA_SBIFull)] = 0

# NLLS tensor fit on the same slice for comparison.
tenmodel = dti.TensorModel(gtabHCP,return_S0_hat = True,fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:,:,axial_middle])
FAFull = dti.fractional_anisotropy(tenfit.evals)
MDFull = dti.mean_diffusivity(tenfit.evals)

# Zero the FA of background pixels (signal sums to zero over all measurements).
empty_voxels = np.sum(maskdata[:,:,axial_middle,:], axis=-1) == 0
FAFull[empty_voxels] = 0

In [None]:
# Load the single-subject DTI posterior for the 7-measurement subset from the
# cache, or simulate a training set and train it when no cached network exists.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if os.path.exists(f"{network_path}/DTIHCPMin.pickle"):
    with open(f"{network_path}/DTIHCPMin.pickle", "rb") as handle:
        posterior7_2 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    bvals = gtabHCP.bvals
    bvecs = gtabHCP.bvecs
    Obs = []
    Par = []
    # NOTE(review): `priorS0` is not defined in the nearby cells (only
    # `priorS0Noise` is) -- confirm it is created earlier in the notebook.
    # Fix: `tqdm` is the class imported from tqdm.auto, so it is called
    # directly; `tqdm.tqdm(...)` raised AttributeError.
    for i in tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        # All but the last parameter describe the tensor; the last one is S0.
        dt = ComputeDTI(params[:-1])
        dt = ForceLowFA(dt)
        # Fix: simulate on `gtabHCP7_alt`, the table built from
        # `selected_indices_alt`. This network is applied to
        # maskdata[..., selected_indices_alt] and compared with an NLLS fit on
        # gtabHCP7_alt, whereas `gtabHCP7` is only created by the later
        # multi-subject loader (from a different index set).
        Obs.append(CustomSimulator(dt,gtabHCP7_alt,params[-1],np.random.rand()*30 + 20))
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior7_2 = inference.build_posterior(density_estimator)
    # Fix: write to the directory the cache is read from (`network_path`);
    # the original wrote under `save_path`, so the load branch above could
    # never find the freshly trained network.
    if not os.path.exists(f"{network_path}/DTIHCPMin.pickle"):
        with open(f"{network_path}/DTIHCPMin.pickle", "wb") as handle:
            pickle.dump(posterior7_2, handle)

In [None]:
# Compute the mask where the sum is not zero
# Foreground pixels of the middle axial slice: any pixel whose signal summed
# over all measurements is non-zero.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# (row, col) coordinates of the foreground pixels.
indices = np.argwhere(mask)

def optimize_pixel(i, j):
    """Estimate the parameters of pixel (i, j) from the 7-measurement subset.

    The torch seed is reset to 10 on every call, 1000 samples are drawn from
    ``posterior7_2`` conditioned on the measurements in
    ``selected_indices_alt``, and every parameter is summarised by
    ``histogram_mode`` of its samples.

    Returns ``(i, j, estimate)``.
    """
    torch.manual_seed(10)
    voxel_signal = maskdata[i, j,axial_middle, selected_indices_alt]
    draws = posterior7_2.sample((1000,), x=voxel_signal,show_progress_bars=False)
    estimate = np.array([histogram_mode(p) for p in draws.T])
    return i, j, estimate

ArrShape = mask.shape

# One joblib task per foreground pixel.
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)

# Output map: 7 values per pixel; background pixels stay at zero.
NoiseEst = np.zeros(list(ArrShape) + [7])

# Scatter the per-pixel results back into the map.
for i, j, x in results:
    NoiseEst[i, j] = x

# Turn the per-pixel SBI estimates into MD/FA maps and compute the NLLS
# reference maps for the 7-measurement subset.
# The slice size is read from the estimate array instead of the previously
# hard-coded 55 x 64, so other slices and subjects work unchanged.
n_rows, n_cols = NoiseEst.shape[:2]

NoiseEst2 =  np.zeros_like(NoiseEst)
MD_SBI7 = np.zeros([n_rows,n_cols])
FA_SBI7 = np.zeros([n_rows,n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        # Clip negative eigenvalues of the estimated tensor; the last entry of
        # the estimate is carried over unchanged.
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
        MD_SBI7[i,j] = np.mean(Eigs)
        FA_SBI7[i,j] = FracAni(Eigs,np.mean(Eigs))
# NaN FA values (e.g. from all-zero background tensors) are set to zero.
FA_SBI7[np.isnan(FA_SBI7)] = 0

# NLLS tensor fit restricted to the same 7 measurements for comparison.
tenmodel = dti.TensorModel(gtabHCP7_alt,return_S0_hat = True,fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:,:,axial_middle,selected_indices_alt])
FA7 = dti.fractional_anisotropy(tenfit.evals)
MD7 = dti.mean_diffusivity(tenfit.evals)

# Zero the FA of background pixels (signal sums to zero over all measurements).
empty_voxels = np.sum(maskdata[:,:,axial_middle,:], axis=-1) == 0
FA7[empty_voxels] = 0

In [None]:
# Per-subject containers, filled in subject order (index kk <-> Pat{kk+1}).
Masks, maskdatas, axial_middles, WMs = [], [], [], []

# Gradient tables per subject: full scheme, 7-entry and 20-entry subsets.
gTabsF, gTabs7, gTabs20 = [], [], []

FullDat   = []
for kk in tqdm(range(32)):
    pat_dir = './HCP_data/Pat'+str(kk+1)
    fdwi = pat_dir+'/diff_1k.nii.gz'
    bvalloc = pat_dir+'/bvals_1k.txt'
    bvecloc = pat_dir+'/bvecs_1k.txt'

    # Full gradient table of this subject.
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    gTabsF.append(gtabHCP)

    # Load, resample from 1.5 mm to 2.5 mm voxels, and brain-mask the data.
    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, _ = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    # Middle axial slice of the cropped, masked volume.
    axial_middle = maskdata.shape[2] // 2
    # Foreground = pixels whose signal summed over measurements is non-zero.
    mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
    Masks.append(mask)
    maskdatas.append(maskdata[:,:,axial_middle])
    axial_middles.append(axial_middle)

    # White-matter mask: same slice index, thresholded at 0.8 and flipped
    # left-right (note: this overwrites `affine` and `img`).
    wm_path = './HCP_data/WM_Masks/c2Pat'+str(kk+1)+'_FP.nii'
    WM, affine, img = load_nifti(wm_path, return_img=True)
    WMs.append(np.fliplr(WM[:,:,axial_middles[kk]]>0.8))

    # Subsampled gradient tables; the index sets were chosen on subject 1's
    # directions and are reused for every subject.
    bvalsHCP7 = bvalsHCP[selected_indices]
    bvecsHCP7 = bvecsHCP[selected_indices]
    gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)
    gTabs7.append(gtabHCP7)

    bvalsHCP20 = bvalsHCP[selected_indices20]
    bvecsHCP20 = bvecsHCP[selected_indices20]
    gtabHCP20 = gradient_table(bvalsHCP20, bvecsHCP20)
    gTabs20.append(gtabHCP20)

In [None]:
# Multi-subject DTI posterior for the full scheme: train when no cached
# network exists, otherwise load it.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if not os.path.exists(f"{network_path}/DTIMultiHCPFull_300.pickle"):
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        # All but the last two parameters describe the tensor.
        dt = ComputeDTI(params[:-2])
        dt = ForceLowFA(dt)
        # The last parameter selects which subject's gradient table is used;
        # it is appended to the observation so the network can condition on it.
        cG = gTabsF[int(params[-1])]
        signal = CustomSimulator(dt,cG,params[-2],50)
        Obs.append(np.hstack([signal,params[-1]]))
        # Targets: tensor values followed by the second-to-last parameter (S0).
        Par.append(np.hstack([mat_to_vals(dt),params[-2]]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Neural posterior estimation: register the simulations, train the
    # density estimator, then wrap it as a posterior.
    inference = SNPE()
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorFull = inference.build_posterior(density_estimator)
    if not os.path.exists(f"{network_path}/DTIMultiHCPFull_300.pickle"):
        with open(f"{network_path}/DTIMultiHCPFull_300.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)
else:
    with open(f"{network_path}/DTIMultiHCPFull_300.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)

# Multi-subject DTI posterior for the 7-entry subset: load the cached network,
# or simulate a training set and train it when none exists.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if os.path.exists(f"{network_path}/DTIMultiHCPMin_300.pickle"):
    with open(f"{network_path}/DTIMultiHCPMin_300.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    # Fix: `tqdm` is the class imported from tqdm.auto, so it is called
    # directly; `tqdm.tqdm(...)` raised AttributeError.
    for i in tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        # All but the last two parameters describe the tensor.
        dt = ComputeDTI(params[:-2])
        dt = ForceLowFA(dt)
        # The last parameter selects which subject's gradient table is used;
        # it is appended to the observation so the network can condition on it.
        cG = gTabs7[int(params[-1])]
        Obs.append(np.hstack([CustomSimulator(dt,cG,params[-2],50),params[-1]]))
        # Targets: tensor values followed by the second-to-last parameter (S0).
        Par.append(np.hstack([mat_to_vals(dt),params[-2]]))

    Obs = np.array(Obs)
    Par = np.array(Par)

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorMin = inference.build_posterior(density_estimator)
    if not os.path.exists(f"{network_path}/DTIMultiHCPMin_300.pickle"):
        with open(f"{network_path}/DTIMultiHCPMin_300.pickle", "wb") as handle:
            pickle.dump(posteriorMin, handle)


# Multi-subject DTI posterior for the 20-entry subset: load the cached
# network, or simulate a training set and train it when none exists.
# NOTE: pickle.load executes arbitrary code; only load network files you trust.
if os.path.exists(f"{network_path}/DTIMultiHCPMid_300.pickle"):
    with open(f"{network_path}/DTIMultiHCPMid_300.pickle", "rb") as handle:
        posteriorMid = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    # Fix: `tqdm` is the class imported from tqdm.auto, so it is called
    # directly; `tqdm.tqdm(...)` raised AttributeError.
    for i in tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        # All but the last two parameters describe the tensor.
        dt = ComputeDTI(params[:-2])
        dt = ForceLowFA(dt)
        # The last parameter selects which subject's gradient table is used;
        # it is appended to the observation so the network can condition on it.
        cG = gTabs20[int(params[-1])]
        Obs.append(np.hstack([CustomSimulator(dt,cG,params[-2],50),params[-1]]))
        # Targets: tensor values followed by the second-to-last parameter (S0).
        Par.append(np.hstack([mat_to_vals(dt),params[-2]]))

    Obs = np.array(Obs)
    Par = np.array(Par)

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorMid = inference.build_posterior(density_estimator)
    if not os.path.exists(f"{network_path}/DTIMultiHCPMid_300.pickle"):
        with open(f"{network_path}/DTIMultiHCPMid_300.pickle", "wb") as handle:
            pickle.dump(posteriorMid, handle)

In [None]:
# (row, col) coordinates of the foreground pixels of the current `mask`.
# NOTE(review): this cell reuses the single-subject network `posterior2` on
# whatever `maskdata` is currently loaded, and `results` is not written into
# NoiseEst here -- it looks like a leftover; confirm it is still needed.
indices = np.argwhere(mask)
def optimize_chunk(pixels):
    """Estimate the parameters of every pixel in ``pixels`` with ``posterior2``.

    For each (i, j) pair, 1000 posterior samples are drawn conditioned on the
    first 91 measurements of that pixel, and every parameter is summarised by
    ``histogram_mode`` of its samples.

    Returns a list of ``(i, j, estimate)`` tuples in input order.
    """
    def _estimate(i, j):
        draws = posterior2.sample((1000,), x=maskdata[i, j,axial_middle, :91],show_progress_bars=False)
        return (i, j, np.array([histogram_mode(p) for p in draws.T]))

    return [_estimate(i, j) for i, j in pixels]

# Process the pixels in chunks of ChunkSize so each worker gets a batch.
chunked_indices = [indices[i:i+ChunkSize] for i in range(0, len(indices), ChunkSize)]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

# Empty output map: 7 values per pixel.
ArrShape = mask.shape
NoiseEst = np.zeros(list(ArrShape) + [7])

In [None]:
# Run the posteriorMin network on the middle axial slice of each of the 32
# HCP subjects. NoiseEst is rebuilt on every iteration, so only the last
# subject's estimate remains after the loop.
for kk in tqdm(range(32)):
    # Subject folders are numbered from 1.
    fdwi = f'./HCP_data/Pat{kk+1}/diff_1k.nii.gz'
    bvalloc = f'./HCP_data/Pat{kk+1}/bvals_1k.txt'
    bvecloc = f'./HCP_data/Pat{kk+1}/bvecs_1k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    # Load the volume, resample it with the zooms given to reslice, then
    # apply median_otsu.
    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    axial_middle = data.shape[2] // 2
    maskdata, _ = median_otsu(
        data,
        vol_idx=range(10, 50),
        median_radius=3,
        numpass=1,
        autocrop=True,
        dilate=2,
    )

    # A pixel is in the mask when its summed signal in the slice is non-zero.
    mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

    # Signal restricted to the volumes listed in selected_indices.
    Arr = maskdata[:,:,axial_middle, selected_indices]
    indices = np.argwhere(mask)

    def optimize_chunk(pixels):
        """Return ``(i, j, modes)`` for each pixel, using 500 posterior draws."""
        out = []
        for i, j in pixels:
            # The subject index is appended to the signal as network input.
            draws = posteriorMin.sample(
                (500,),
                x=np.hstack([Arr[i,j],kk]),
                show_progress_bars=False,
            )
            modes = np.array([histogram_mode(column) for column in draws.T])
            out.append((i, j, modes))
        return out

    chunked_indices = [
        indices[start:start + ChunkSize]
        for start in range(0, len(indices), ChunkSize)
    ]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(batch) for batch in tqdm(chunked_indices)
    )

    # Scatter the per-pixel results back onto the slice grid.
    ArrShape = mask.shape
    NoiseEst = np.zeros((*ArrShape, 7))
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

In [None]:
# MD / FA maps from the posteriorFull network for the 32 HCP subjects.
# Cached arrays are loaded when present; otherwise each subject is processed.
if not os.path.exists(f"{DatFolder}/Full_MD_HCP.npy"):
    MDFullArr, FAFullArr = [], []
    for kk in tqdm(range(32)):
        # Subject folders are numbered from 1.
        fdwi = f'./HCP_data/Pat{kk+1}/diff_1k.nii.gz'
        bvalloc = f'./HCP_data/Pat{kk+1}/bvals_1k.txt'
        bvecloc = f'./HCP_data/Pat{kk+1}/bvecs_1k.txt'

        bvalsHCP = np.loadtxt(bvalloc)
        bvecsHCP = np.loadtxt(bvecloc)
        gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

        data, affine, img = load_nifti(fdwi, return_img=True)
        data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
        axial_middle = data.shape[2] // 2
        maskdata, _ = median_otsu(
            data,
            vol_idx=range(10, 50),
            median_radius=3,
            numpass=1,
            autocrop=True,
            dilate=2,
        )
        # A pixel is in the mask when its summed signal in the slice is non-zero.
        mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
        indices = np.argwhere(mask)

        def optimize_chunk(pixels):
            """Return ``(i, j, modes)`` for each pixel, using 500 posterior draws."""
            out = []
            for i, j in pixels:
                # All volumes of the pixel plus the subject index form the input.
                draws = posteriorFull.sample(
                    (500,),
                    x=np.hstack([maskdata[i,j,axial_middle, :],kk]),
                    show_progress_bars=False,
                )
                modes = np.array([histogram_mode(column) for column in draws.T])
                out.append((i, j, modes))
            return out

        chunked_indices = [
            indices[start:start + ChunkSize]
            for start in range(0, len(indices), ChunkSize)
        ]
        results = Parallel(n_jobs=8)(
            delayed(optimize_chunk)(batch) for batch in tqdm(chunked_indices)
        )

        # Scatter the per-pixel estimates back onto the slice grid.
        ArrShape = mask.shape
        NoiseEst = np.zeros((*ArrShape, 7))
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x

        # Clip negative eigenvalues of each tensor, keep the last entry as is,
        # and derive MD / FA from the eigenvalues of the cleaned tensor.
        NoiseEst2 = np.zeros_like(NoiseEst)
        MD_SBIFull = np.zeros(ArrShape)
        FA_SBIFull = np.zeros(ArrShape)
        for i, j in np.ndindex(*ArrShape):
            NoiseEst2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),
                NoiseEst[i,j,-1:],
            ])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            MD_SBIFull[i,j] = np.mean(Eigs)
            FA_SBIFull[i,j] = FracAni(Eigs,np.mean(Eigs))
        # NaN FA values are replaced by 0.
        FA_SBIFull[np.isnan(FA_SBIFull)] = 0
        MDFullArr.append(MD_SBIFull)
        FAFullArr.append(FA_SBIFull)
else:
    MDFullArr = np.load(f"{DatFolder}/Full_MD_HCP.npy",allow_pickle=True)
    FAFullArr = np.load(f"{DatFolder}/Full_FA_HCP.npy",allow_pickle=True)
# MD / FA maps from the posteriorMin network (volumes in selected_indices)
# for the 32 HCP subjects. Cached arrays are loaded when present.
if not os.path.exists(f"{DatFolder}/Min_MD_HCP.npy"):
    MDMinArr, FAMinArr = [], []
    for kk in tqdm(range(32)):
        # Subject folders are numbered from 1.
        fdwi = f'./HCP_data/Pat{kk+1}/diff_1k.nii.gz'
        bvalloc = f'./HCP_data/Pat{kk+1}/bvals_1k.txt'
        bvecloc = f'./HCP_data/Pat{kk+1}/bvecs_1k.txt'

        bvalsHCP = np.loadtxt(bvalloc)
        bvecsHCP = np.loadtxt(bvecloc)
        gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

        data, affine, img = load_nifti(fdwi, return_img=True)
        data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
        axial_middle = data.shape[2] // 2
        maskdata, _ = median_otsu(
            data,
            vol_idx=range(10, 50),
            median_radius=3,
            numpass=1,
            autocrop=True,
            dilate=2,
        )
        # A pixel is in the mask when its summed signal in the slice is non-zero.
        mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

        # Signal restricted to the volumes listed in selected_indices.
        Arr = maskdata[:,:,axial_middle, selected_indices]
        indices = np.argwhere(mask)

        def optimize_chunk(pixels):
            """Return ``(i, j, modes)`` for each pixel, using 500 posterior draws."""
            out = []
            for i, j in pixels:
                # The subject index is appended to the signal as network input.
                draws = posteriorMin.sample(
                    (500,),
                    x=np.hstack([Arr[i,j],kk]),
                    show_progress_bars=False,
                )
                modes = np.array([histogram_mode(column) for column in draws.T])
                out.append((i, j, modes))
            return out

        chunked_indices = [
            indices[start:start + ChunkSize]
            for start in range(0, len(indices), ChunkSize)
        ]
        results = Parallel(n_jobs=8)(
            delayed(optimize_chunk)(batch) for batch in tqdm(chunked_indices)
        )

        # Scatter the per-pixel estimates back onto the slice grid.
        ArrShape = mask.shape
        NoiseEst = np.zeros((*ArrShape, 7))
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x

        # Clip negative eigenvalues of each tensor, keep the last entry as is,
        # and derive MD / FA from the eigenvalues of the cleaned tensor.
        NoiseEst2 = np.zeros_like(NoiseEst)
        MD_SBIMin = np.zeros(ArrShape)
        FA_SBIMin = np.zeros(ArrShape)
        for i, j in np.ndindex(*ArrShape):
            NoiseEst2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),
                NoiseEst[i,j,-1:],
            ])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            MD_SBIMin[i,j] = np.mean(Eigs)
            FA_SBIMin[i,j] = FracAni(Eigs,np.mean(Eigs))
        # NaN FA values are replaced by 0.
        FA_SBIMin[np.isnan(FA_SBIMin)] = 0
        MDMinArr.append(MD_SBIMin)
        FAMinArr.append(FA_SBIMin)
else:
    MDMinArr = np.load(f"{DatFolder}/Min_MD_HCP.npy",allow_pickle=True)
    FAMinArr = np.load(f"{DatFolder}/Min_FA_HCP.npy",allow_pickle=True)
# MD / FA maps from the posteriorMid network (volumes in selected_indices20)
# for the 32 HCP subjects. Cached arrays are loaded when present.
if not os.path.exists(f"{DatFolder}/Mid_MD_HCP.npy"):
    MDMidArr, FAMidArr = [], []
    for kk in tqdm(range(32)):
        # Subject folders are numbered from 1.
        fdwi = f'./HCP_data/Pat{kk+1}/diff_1k.nii.gz'
        bvalloc = f'./HCP_data/Pat{kk+1}/bvals_1k.txt'
        bvecloc = f'./HCP_data/Pat{kk+1}/bvecs_1k.txt'

        bvalsHCP = np.loadtxt(bvalloc)
        bvecsHCP = np.loadtxt(bvecloc)
        gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

        data, affine, img = load_nifti(fdwi, return_img=True)
        data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
        axial_middle = data.shape[2] // 2
        maskdata, _ = median_otsu(
            data,
            vol_idx=range(10, 50),
            median_radius=3,
            numpass=1,
            autocrop=True,
            dilate=2,
        )
        # A pixel is in the mask when its summed signal in the slice is non-zero.
        mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

        # Signal restricted to the volumes listed in selected_indices20.
        Arr = maskdata[:,:,axial_middle, selected_indices20]
        indices = np.argwhere(mask)

        def optimize_chunk(pixels):
            """Return ``(i, j, modes)`` for each pixel, using 500 posterior draws."""
            out = []
            for i, j in pixels:
                # The subject index is appended to the signal as network input.
                draws = posteriorMid.sample(
                    (500,),
                    x=np.hstack([Arr[i,j],kk]),
                    show_progress_bars=False,
                )
                modes = np.array([histogram_mode(column) for column in draws.T])
                out.append((i, j, modes))
            return out

        chunked_indices = [
            indices[start:start + ChunkSize]
            for start in range(0, len(indices), ChunkSize)
        ]
        results = Parallel(n_jobs=8)(
            delayed(optimize_chunk)(batch) for batch in tqdm(chunked_indices)
        )

        # Scatter the per-pixel estimates back onto the slice grid.
        ArrShape = mask.shape
        NoiseEst = np.zeros((*ArrShape, 7))
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x

        # Clip negative eigenvalues of each tensor, keep the last entry as is,
        # and derive MD / FA from the eigenvalues of the cleaned tensor.
        NoiseEst2 = np.zeros_like(NoiseEst)
        MD_SBIMid = np.zeros(ArrShape)
        FA_SBIMid = np.zeros(ArrShape)
        for i, j in np.ndindex(*ArrShape):
            NoiseEst2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),
                NoiseEst[i,j,-1:],
            ])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            MD_SBIMid[i,j] = np.mean(Eigs)
            FA_SBIMid[i,j] = FracAni(Eigs,np.mean(Eigs))
        # NaN FA values are replaced by 0.
        FA_SBIMid[np.isnan(FA_SBIMid)] = 0
        MDMidArr.append(MD_SBIMid)
        FAMidArr.append(FA_SBIMid)
else:
    MDMidArr = np.load(f"{DatFolder}/Mid_MD_HCP.npy",allow_pickle=True)
    FAMidArr = np.load(f"{DatFolder}/Mid_FA_HCP.npy",allow_pickle=True)
# Non-linear least-squares (NLLS) tensor fits of the middle axial slice of
# every HCP subject, for three acquisitions: all volumes ("Full"), the
# volumes in selected_indices20 ("Mid") and those in selected_indices ("Min").
MDFullNLArr = []
FAFullNLArr = []

MDMidNLArr = []
FAMidNLArr = []

MDMinNLArr = []
FAMinNLArr = []
for kk in range(32):
    # Subject folders are numbered from 1.
    fdwi = './HCP_data/Pat'+str(kk+1)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk+1)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk+1)+'/bvecs_1k.txt'
    
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    
    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    axial_middle = data.shape[2] // 2
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    ArrShape = maskdata[:,:,axial_middle,0].shape

    # Full acquisition: the subject's own gradient table and all volumes.
    tenmodel = dti.TensorModel(gtabHCP,return_S0_hat = True,fit_method='NLLS')
    tenfit = tenmodel.fit(maskdata[:,:,axial_middle])
    FAFull_t = dti.fractional_anisotropy(tenfit.evals)
    MDFull_t = dti.mean_diffusivity(tenfit.evals)
    MDFullNLArr.append(MDFull_t)
    FAFullNLArr.append(FAFull_t)

    # Mid acquisition: gTabs20 table for this subject, selected_indices20 volumes.
    tenmodel = dti.TensorModel(gTabs20[kk],return_S0_hat = True,fit_method='NLLS')
    tenfit = tenmodel.fit(maskdata[:,:,axial_middle,selected_indices20])
    FAFull_t = dti.fractional_anisotropy(tenfit.evals)
    MDFull_t = dti.mean_diffusivity(tenfit.evals)
    MDMidNLArr.append(MDFull_t)
    FAMidNLArr.append(FAFull_t)
    
    # Min acquisition: gTabs7 table for this subject, selected_indices volumes.
    tenmodel = dti.TensorModel(gTabs7[kk],return_S0_hat = True,fit_method='NLLS')
    tenfit = tenmodel.fit(maskdata[:,:,axial_middle,selected_indices])
    FAFull_t = dti.fractional_anisotropy(tenfit.evals)
    MDFull_t = dti.mean_diffusivity(tenfit.evals)
    MDMinNLArr.append(MDFull_t)
    # Store the FA map just computed for this subject. The previous code
    # appended the unrelated outer variable `FAFull`, so every subject got the
    # same array.
    FAMinNLArr.append(FAFull_t)

In [None]:
# Per-subject mean-diffusivity (MD) metrics.
#   Acc*  : mean absolute difference inside Masks[i]
#   SSIM* : masked_local_ssim with a window size of 7
#   Prec* : standard deviation over the pixels selected by WMs[i]
AccM7_MD, AccM20_MD, AccMFulls_MD = [], [], []
AccM7NL_MD, AccM20NL_MD = [], []

SSIM7_MD, SSIM20_MD, SSIMFulls_MD = [], [], []
SSIM7NL_MD, SSIM20NL_MD = [], []

for i in tqdm(range(32)):
    Ma = Masks[i]
    sbi_min, sbi_mid, sbi_full = MDMinArr[i], MDMidArr[i], MDFullArr[i]
    nl_min, nl_mid, nl_full = MDMinNLArr[i], MDMidNLArr[i], MDFullNLArr[i]

    # (estimate, reference, accuracy list, SSIM list)
    comparisons = (
        (sbi_min, sbi_full, AccM7_MD, SSIM7_MD),
        (sbi_mid, sbi_full, AccM20_MD, SSIM20_MD),
        (sbi_full, nl_full, AccMFulls_MD, SSIMFulls_MD),
        (nl_min, nl_full, AccM7NL_MD, SSIM7NL_MD),
        (nl_mid, nl_full, AccM20NL_MD, SSIM20NL_MD),
    )
    for estimate, reference, acc_out, ssim_out in comparisons:
        acc_out.append(np.mean(np.abs(estimate - reference)[Ma]))
        ssim_out.append(masked_local_ssim(estimate, reference, Ma, win_size=7))

# Spread of each MD map over the WMs[i] pixels, one value per subject.
Prec7_SBI_MD = [np.std(MDMinArr[s][WMs[s]]) for s in range(32)]
Prec20_SBI_MD = [np.std(MDMidArr[s][WMs[s]]) for s in range(32)]
PrecFull_SBI_MD = [np.std(MDFullArr[s][WMs[s]]) for s in range(32)]

Prec7_NLLS_MD = [np.std(MDMinNLArr[s][WMs[s]]) for s in range(32)]
Prec20_NLLS_MD = [np.std(MDMidNLArr[s][WMs[s]]) for s in range(32)]
PrecFull_NLLS_MD = [np.std(MDFullNLArr[s][WMs[s]]) for s in range(32)]



In [None]:
# Per-subject fractional-anisotropy (FA) metrics.
#   Acc*  : mean absolute difference inside Masks[i]
#   SSIM* : masked_local_ssim with a window size of 7
#   Prec* : standard deviation over the pixels selected by WMs[i]
AccM7_FA, AccM20_FA, AccMFulls_FA = [], [], []
AccM7NL_FA, AccM20NL_FA = [], []

SSIM7_FA, SSIM20_FA, SSIMFulls_FA = [], [], []
SSIM7NL_FA, SSIM20NL_FA = [], []

for i in range(32):
    Ma = Masks[i]
    sbi_min, sbi_mid, sbi_full = FAMinArr[i], FAMidArr[i], FAFullArr[i]
    nl_min, nl_mid, nl_full = FAMinNLArr[i], FAMidNLArr[i], FAFullNLArr[i]

    # (estimate, reference, accuracy list, SSIM list)
    comparisons = (
        (sbi_min, sbi_full, AccM7_FA, SSIM7_FA),
        (sbi_mid, sbi_full, AccM20_FA, SSIM20_FA),
        (sbi_full, nl_full, AccMFulls_FA, SSIMFulls_FA),
        (nl_min, nl_full, AccM7NL_FA, SSIM7NL_FA),
        (nl_mid, nl_full, AccM20NL_FA, SSIM20NL_FA),
    )
    for estimate, reference, acc_out, ssim_out in comparisons:
        acc_out.append(np.mean(np.abs(estimate - reference)[Ma]))
        ssim_out.append(masked_local_ssim(estimate, reference, Ma, win_size=7))

# Spread of each FA map over the WMs[i] pixels, one value per subject.
Prec7_SBI_FA = [np.std(FAMinArr[s][WMs[s]]) for s in range(32)]
Prec20_SBI_FA = [np.std(FAMidArr[s][WMs[s]]) for s in range(32)]
PrecFull_SBI_FA = [np.std(FAFullArr[s][WMs[s]]) for s in range(32)]

Prec7_NLLS_FA = [np.std(FAMinNLArr[s][WMs[s]]) for s in range(32)]
Prec20_NLLS_FA = [np.std(FAMidNLArr[s][WMs[s]]) for s in range(32)]
PrecFull_NLLS_FA = [np.std(FAFullNLArr[s][WMs[s]]) for s in range(32)]


### a

In [None]:
# Panel: MD_SBIFull shown with the 'hot' colormap.
# Work on a copy so the stored map is not modified.
temp = np.copy(MD_SBIFull)

# Pixels outside mask_cutout are set to NaN.
temp[~mask_cutout] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
# Record the colour limits chosen by imshow.
vmin, vmax = img.get_clim()

if Save: plt.savefig(FigLoc+'HCP_SBI_MD.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel: MDFull (presumably the WLS fit, going by the output file name)
# shown with the 'hot' colormap and a colorbar.
temp = np.copy(MDFull)

# Pixels outside mask_cutout are set to NaN.
temp[~mask_cutout] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
# Record the colour limits chosen by imshow.
vmin, vmax = img.get_clim()
cbar = plt.colorbar()
# Force scientific notation on the colorbar tick labels.
cbar.formatter.set_powerlimits((0, 0))

if Save: plt.savefig(FigLoc+'HCP_WLS_MD.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel: signed difference MD_SBIFull - MDFull on a diverging colormap.
# NOTE: this rebinds the notebook-level name `data`.
data = MD_SBIFull.T-MDFull.T
data[~mask_cutout.T] = np.nan
# Centre the colour scale on zero, with the data extremes as limits.
norm = TwoSlopeNorm(vmin=np.nanmin(data), vcenter=0, vmax=np.nanmax(data))
plt.imshow(data,cmap='seismic',norm=norm)
plt.axis('off')
cbar = plt.colorbar()
# Tick only the minimum, zero and the maximum.
ticks = [np.nanmin(data), 0, np.nanmax(data)]  # Adjust the number of ticks as needed
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))

if Save: plt.savefig(FigLoc+'HCP_MD_Diff.pdf',format='pdf',bbox_inches='tight',transparent=True)

### b

In [None]:
# Panel: MD_SBI7 shown with a fixed colour range of 0 to 3.5e-3.
temp = np.copy(MD_SBI7)

# Pixels outside mask are set to NaN.
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot',vmin=0,vmax=3.5e-3)
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
# Record the colour limits in use.
vmin, vmax = img.get_clim()

if Save: plt.savefig(FigLoc+'HCP_SBI_MD_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel: MD7 shown with a fixed colour range of 0 to 3.5e-3 and a colorbar.
temp = np.copy(MD7)

# Pixels outside mask are set to NaN.
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot',vmin=0,vmax=3.5e-3)
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
# Record the colour limits in use.
vmin, vmax = img.get_clim()
cbar = plt.colorbar()
# Force scientific notation on the colorbar tick labels.
cbar.formatter.set_powerlimits((0, 0))

if Save: plt.savefig(FigLoc+'HCP_WLS_MD_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Two figures of absolute MD error: |MD_SBIFull - MD_SBI7|, then |MDFull - MD7|.
# NOTE: this rebinds the notebook-level name `data`.
data = np.abs(MD_SBIFull.T-MD_SBI7.T)
data[~mask.T] = np.nan
# Colour scale runs from 0 to the largest error of this first map.
norm = TwoSlopeNorm(vmin=0, vcenter=np.nanmax(data)/2, vmax=np.nanmax(data))
plt.imshow(data,cmap='Reds',norm=norm)
plt.axis('off')
if Save: plt.savefig(FigLoc+'DTI_MDSBIErr.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# norm and ticks are built while `data` still holds the first error map,
# so the second figure is drawn on the same colour scale as the first.
norm = TwoSlopeNorm(vmin=0, vcenter=np.nanmax(data)/2, vmax=np.nanmax(data))
ticks = [0, np.round(np.nanmax(data),3)]  # Adjust the number of ticks as needed
data = np.abs(MDFull.T-MD7.T)
data[~mask.T] = np.nan
plt.imshow(data,cmap='Reds',norm=norm)

plt.axis('off')
cbar = plt.colorbar()

cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'DTI_MDWLSErr.pdf',format='pdf',bbox_inches='tight',transparent=True)

### c

In [None]:
# Panel: FA_SBIFull shown with the 'hot' colormap.
temp = np.copy(FA_SBIFull)

# Pixels outside mask are set to NaN.
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
# Record the colour limits chosen by imshow.
vmin, vmax = img.get_clim()

if Save: plt.savefig(FigLoc+'HCP_SBI_FA.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel: FAFull (presumably the WLS fit, going by the output file name)
# shown with the 'hot' colormap and a colorbar.
temp = np.copy(FAFull)

# Pixels outside mask are set to NaN.
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
# Record the colour limits chosen by imshow.
vmin, vmax = img.get_clim()
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))

if Save: plt.savefig(FigLoc+'HCP_WLS_FA.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel: signed difference FA_SBIFull - FAFull on a fixed [-1, 1] scale.
# NOTE: this rebinds the notebook-level name `data`.
data = FA_SBIFull.T-FAFull.T
data[~mask.T] = np.nan
plt.imshow(data,cmap='seismic',vmin=-1, vmax=1)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
# Tick only the two limits and zero.
ticks = [-1, 0, 1]  # Adjust the number of ticks as needed
cbar.set_ticks(ticks)
if Save: plt.savefig(FigLoc+'HCP_FA_Diff.pdf',format='pdf',bbox_inches='tight',transparent=True)

### d

In [None]:
# Panel: FA_SBI7 shown with the 'hot' colormap (no colorbar).
temp = np.copy(FA_SBI7)

# Pixels outside mask are set to NaN.
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
# Record the colour limits chosen by imshow.
vmin, vmax = img.get_clim()
if Save: plt.savefig(FigLoc+'HCP_SBI_FA_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Panel: FA7 shown with the 'hot' colormap and a colorbar.
temp = np.copy(FA7)

# Pixels outside mask are set to NaN.
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'HCP_WLS_FA_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Two figures of absolute FA error: |FA_SBIFull - FA_SBI7|, then |FAFull - FA7|.
# Both use the same fixed 0-to-1 colour scale.
# NOTE: this rebinds the notebook-level name `data`.
data = np.abs(FA_SBIFull.T-FA_SBI7.T)
data[~mask.T] = np.nan
norm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
plt.imshow(data,cmap='Reds',norm=norm)
plt.axis('off')
if Save: plt.savefig(FigLoc+'DTI_FASBIErr.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# Second figure; only this one gets a colorbar.
norm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
ticks = [0, 1]  # Adjust the number of ticks as needed
data = np.abs(FAFull.T-FA7.T)
data[~mask.T] = np.nan
plt.imshow(data,cmap='Reds',norm=norm)

plt.axis('off')
cbar = plt.colorbar()

cbar.set_ticks(ticks)
if Save: plt.savefig(FigLoc+'DTI_FAWLSErr.pdf',format='pdf',bbox_inches='tight',transparent=True)

### e

In [None]:
# Broken-axis box plot of the per-subject MD accuracy values.
# ax1 (top) holds the AccM7NL_MD box; ax2 (bottom) holds the other four.
# Plot setup
fig, (ax1, ax2) = plt.subplots(2, 1)
fig.subplots_adjust(hspace=0.05)



# Top panel: AccM7NL_MD at x = 3.1.
y_data = np.array(AccM7NL_MD)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)
# NOTE(review): plt.sca(ax1) has not been called, so these two pyplot calls
# act on pyplot's current axes (presumably ax2, the last one created) -- confirm.
plt.xticks([1,1.7,2,2.8,3.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
#plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.yticks(fontsize=24)

ax1.set_ylim(0.0001, 2.5e-3)
# Express the y tick labels in units of 1e-3.
ax1.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))
#ax1.yaxis.set_ticks(np.arange(0.0005, 0.006, 0.002))
ax1.set_xticks([])

# Bottom panel: the remaining four boxes.
plt.sca(ax2)

y_data = np.array(AccMFulls_MD)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

y_data = np.array(AccM20_MD)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

y_data = np.array(AccM7_MD)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

y_data = np.array(AccM20NL_MD)
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

plt.xticks([1,1.7,2,2.8,3.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)

ax2.set_ylim(0, 0.00016)
# Same 1e-3 scaling as the top panel.
ax2.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))
ax2.yaxis.set_ticks(np.arange(0, 0.00018, 0.0001))

# Common x-ticks
ax2.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax2.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

# Adding broken axis effect
# Slanted markers at the bottom-left of ax1 and the top-left of ax2.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Give the top panel the same x range as the bottom one.
ax1.set_xlim(ax2.get_xlim())

# Hide the spines between ax and ax2
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # Hide the "1e-3" from ax2

# Legend handles; leg_patch3, leg_patch4 and leg_patch5 are also referenced
# by later cells in this chunk.
leg_patch1 = mpatches.Patch(color='gray', label='Full Comp.')
leg_patch2 = mpatches.Patch(color='lightseagreen', label='Mid. (SBI)')
leg_patch3 = mpatches.Patch(color='mediumturquoise', label='Min. (SBI)')
leg_patch4 = mpatches.Patch(color='sandybrown', label='Mid. (NLLS)')
leg_patch5 = mpatches.Patch(color='burlywood', label='Min. (NLLS)')


# Show plot
# NOTE(review): the figure is saved before the legend below is added, so the
# PDF has no legend -- confirm this is intended.
if Save:
    plt.savefig(FigLoc + 'DTIHCP_Acc_MD.pdf', format='PDF', transparent=True, bbox_inches='tight')

ax1.legend(
    handles=[leg_patch1,leg_patch2,leg_patch3],
    loc='upper left',         # base location  # fine-tune the legend's position
    frameon=False, ncols=1,
fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,bbox_to_anchor=(0,1.2))
plt.show()

In [None]:
# Broken-axis box plot of the per-subject MD SSIM values.
# ax1 (top) shows the 0.7-1.0 range and ax2 (bottom) the 0-0.7 range.
fig, (ax1, ax2) = plt.subplots(2, 1)#, sharex=True)
fig.subplots_adjust(hspace=0.05)  # adjust space between Axes

# Plotting on ax1
plt.sca(ax1)
y_data = np.array(SSIMFulls_MD)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)

y_data = np.array(SSIM20_MD)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True,scatter_alpha=0.5)

y_data = np.array(SSIM7_MD)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)

y_data = np.array(SSIM20NL_MD)
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)

y_data = np.array(SSIM7NL_MD)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)

# SSIM7NL_MD is drawn a second time, now on ax2, so values below the break
# also appear in the lower panel.
y_data = np.array(SSIM7NL_MD)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

# Hide the spines facing the break between the two panels.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()

ax1.set_ylim(.7, 1.)  # upper part of the SSIM range
ax2.set_ylim(0, .7)  # lower part of the SSIM range

ax1.set_xticks([]) 
ax2.set_xticks([]) 
# First call targets ax1 (made current above); second targets ax2.
plt.yticks(fontsize=32)
plt.sca(ax2)
plt.yticks(fontsize=32)

d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Give the bottom panel the same x range as the top one.
ax2.set_xlim(ax1.get_xlim())
# Dashed horizontal reference line at SSIM = 0.66.
ax2.axhline(0.66, lw=3, ls='--', c='k')
ax2.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax2.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

# leg_patch4 and leg_patch5 are not defined in this cell; they must already
# exist in the notebook namespace.
ax2.legend(
    handles=[leg_patch4,leg_patch5],
    loc='upper left',         # base location  # fine-tune the legend's position
    frameon=False, ncols=1,
fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,bbox_to_anchor= (-0.0,0.9))
if Save: plt.savefig(FigLoc+'DTI_MD_SSIMErr.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Box plot of the per-subject MD precision values: three SBI boxes at
# x = 1, 1.2, 1.4 and three NLLS boxes at x = 2, 2.2, 2.4.
y_data = np.array(PrecFull_SBI_MD)
g_pos = np.array([1])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']
fig,ax = plt.subplots()
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

y_data = np.array(Prec20_SBI_MD)
g_pos = np.array([1.2])
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

y_data = np.array(Prec7_SBI_MD)
g_pos = np.array([1.4])
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# NLLS group, drawn in the orange palette.
colors = ['sandybrown']
colors2 = ['peachpuff']
y_data = np.array(PrecFull_NLLS_MD)
g_pos = np.array([2])
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

y_data = np.array(Prec20_NLLS_MD)
g_pos = np.array([2.2])
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

y_data = np.array(Prec7_NLLS_MD)
g_pos = np.array([2.4])
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Hatched band across the NLLS group, between the 25th and 77th percentiles
# of the non-NaN PrecFull_NLLS_MD values.
# NOTE(review): 77 looks like a typo for 75 (interquartile range) -- confirm.
x = np.arange(1.85,2.6,0.05)
y1 = np.ones_like(x)*np.percentile(np.array(PrecFull_NLLS_MD)[~np.isnan(PrecFull_NLLS_MD)], 25)
y2 = np.ones_like(x)*np.percentile(np.array(PrecFull_NLLS_MD)[~np.isnan(PrecFull_NLLS_MD)], 77)
plt.fill_between(x,y1,y2,color=WLSFit,zorder=10,alpha=0.2,hatch='//')

# Same band for the SBI group, from PrecFull_SBI_MD (same 25/77 percentiles).
x = np.arange(0.85,1.6,0.05)
y1 = np.ones_like(x)*np.percentile(np.array(PrecFull_SBI_MD)[~np.isnan(PrecFull_SBI_MD)], 25)
y2 = np.ones_like(x)*np.percentile(np.array(PrecFull_SBI_MD)[~np.isnan(PrecFull_SBI_MD)], 77)
plt.fill_between(x,y1,y2,color=SBIFit,zorder=10,alpha=0.2,hatch='//')

plt.xticks([1,1.2,1.4,2,2.2,2.4],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)
if Save: plt.savefig(FigLoc+'DTI_MD_Prec.pdf',format='pdf',bbox_inches='tight',transparent=True)

### f

In [None]:
# Broken-axis box plot of the per-subject FA accuracy values.
# ax1 (top, y in 0.4-1) holds the AccM7NL_FA box; ax2 (bottom, y in 0-0.3)
# holds the other four.
# Plot setup
fig, (ax1, ax2) = plt.subplots(2, 1)
fig.subplots_adjust(hspace=0.05)



# Top panel: AccM7NL_FA at x = 3.1.
y_data = np.array(AccM7NL_FA)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)
# NOTE(review): plt.sca(ax1) has not been called, so these two pyplot calls
# act on pyplot's current axes (presumably ax2, the last one created) -- confirm.
plt.xticks([1,1.7,2,2.8,3.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
#plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.yticks(fontsize=24)

#ax1.yaxis.set_ticks(np.arange(0.0005, 0.006, 0.002))
ax1.set_xticks([])

# Bottom panel: the remaining four boxes.
plt.sca(ax2)

y_data = np.array(AccMFulls_FA)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

y_data = np.array(AccM20_FA)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

y_data = np.array(AccM7_FA)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

y_data = np.array(AccM20NL_FA)
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

# ax2 is the current axes here, so this sets its five tick positions and labels.
plt.xticks([1,1.7,2,2.8,3.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
#plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
ax1.set_ylim(0.4, 1)
#ax1.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))
ax1.yaxis.set_ticks([0.4,0.7,1])
ax1.set_xticks([])

plt.yticks(fontsize=24)

ax2.set_ylim(0.0,0.3)
#ax2.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))
#ax2.yaxis.set_ticks(np.arange(0, 0.0001, 0.00004))

# Common x-ticks
# The tick positions come from the plt.xticks call above; set_xticks is
# commented out here, so only the labels are re-applied.
#ax2.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax2.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

# Adding broken axis effect
# Slanted markers at the bottom-left of ax1 and the top-left of ax2.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Hide the spines between ax and ax2
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # Hide the axis offset text on ax2
# Give the top panel the same x range as the bottom one.
ax1.set_xlim(ax2.get_xlim())
# Show plot
if Save:
    plt.savefig(FigLoc + 'DTIHCP_Acc_FA.pdf', format='PDF', transparent=True, bbox_inches='tight')

plt.show()

In [None]:
# Box plot of the FA SSIM values for the five acquisition/fit groups.
fig, ax = plt.subplots()

# (values, x position, colors, colors2, extra BoxPlots keyword arguments).
# Each group is drawn exactly once; only the 'Mid' SBI group overrides scatter_alpha.
ssim_groups = (
    (SSIMFulls_FA, 1, 'black', 'gray', {}),
    (SSIM20_FA, 1.7, 'lightseagreen', 'paleturquoise', {'scatter_alpha': 0.5}),
    (SSIM7_FA, 2, 'mediumturquoise', 'paleturquoise', {}),
    (SSIM20NL_FA, 2.8, 'sandybrown', 'peachpuff', {}),
    (SSIM7NL_FA, 3.1, 'burlywood', 'peachpuff', {}),
)
for values, position, col, col2, extra in ssim_groups:
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [col]
    colors2 = [col2]
    BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True, **extra)

plt.axhline(0.66, lw=3, ls='--', c='k')
plt.yticks(fontsize=32)
plt.ylim([0, 1])

# Build the legend handle locally and attach it to this figure's axes, so the
# cell does not depend on handles or axes created by other cells.
nlls_mid_patch = mpatches.Patch(color='sandybrown', label='Mid. (NLLS)')
ax.legend(
    handles=[nlls_mid_patch],
    loc='lower left',
    frameon=False, ncols=2,
    fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3)
ax.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

if Save: plt.savefig(FigLoc+'DTI_FA_SSIMErr.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Box plot of the FA precision values: three SBI groups (x = 1, 1.2, 1.4)
# followed by three NLLS groups (x = 2, 2.2, 2.4).
fig, ax = plt.subplots()

precision_groups = (
    (PrecFull_SBI_FA, 1, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI_FA, 1.2, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_FA, 1.4, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_FA, 2, 'sandybrown', 'peachpuff'),
    (Prec20_NLLS_FA, 2.2, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_FA, 2.4, 'sandybrown', 'peachpuff'),
)
for values, position, col, col2 in precision_groups:
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [col]
    colors2 = [col2]
    BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True)

# Hatched reference band per method: the 25th-75th percentile range of the
# 'Full' values, ignoring NaNs. Both bands use the same percentiles.
for x_start, x_stop, full_values, band_col in (
    (1.85, 2.6, PrecFull_NLLS_FA, WLSFit),
    (0.85, 1.6, PrecFull_SBI_FA, SBIFit),
):
    x = np.arange(x_start, x_stop, 0.05)
    q_low, q_high = np.nanpercentile(np.array(full_values), [25, 75])
    y1 = np.ones_like(x)*q_low
    y2 = np.ones_like(x)*q_high
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

plt.xticks([1,1.2,1.4,2,2.2,2.4],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)
if Save: plt.savefig(FigLoc+'DTI_FA_Prec.pdf',format='pdf',bbox_inches='tight',transparent=True)

## MS

In [None]:
def _farthest_point_subset(points, n_points, first=0):
    """Greedily pick ``n_points`` row indices of ``points`` that are spread apart.

    Starts from row ``first`` and repeatedly adds the remaining row whose
    minimum Euclidean distance to the rows chosen so far is largest.
    Remaining rows are scanned in ascending index order, so ties resolve to
    the lowest index.

    Returns a list of ``n_points`` integer row indices, ``first`` included.
    """
    distance_matrix = squareform(pdist(points))
    chosen = [first]
    for _ in range(n_points - 1):
        remaining = np.setdiff1d(np.arange(len(points)), chosen)
        # Distance from every remaining row to its nearest already-chosen row.
        min_distances = np.min(distance_matrix[remaining][:, chosen], axis=1)
        chosen.append(int(remaining[np.argmax(min_distances)]))
    return chosen


# Per-subject containers, filled in the order of the subject list below.
Dats_MS   = []
gTabs7_MS = []
gTabs20_MS = []
gTabsF_MS = []
Masks_MS   = []
TrueIndxs = []
axial_middles_MS = []
for i,Name in enumerate(tqdm(['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30'])):
    MatDir = MSDir+Name

    F = pmt.read_mat(MatDir+'/data_loaded.mat')
    # Identity placeholder: no real affine is loaded here and the resliced
    # affine is not used afterwards.
    affine = np.eye(4)

    # Resample from 2 mm to 2.5 mm voxels, then brain-mask and crop to the mask.
    data, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 80), autocrop=True)
    Masks_MS.append(mask)
    axial_middle = maskdata.shape[2] // 2
    axial_middles_MS.append(axial_middle)

    # Normalise gradient directions to unit length; zero-length rows give NaN
    # in the division and are reset to 0.
    bvecs = (F['direction'].T/np.linalg.norm(F['direction'],axis=1)).T
    bvecs[np.isnan(bvecs)] = 0
    bvals = F['bval']
    bvecs4000 = bvecs[bvals==4000]  # not used in this cell; kept in case later cells read it

    # b=2000 shell with a zero b-value / zero direction entry prepended.
    # Volume 0 of the data is used as the matching first volume.
    bvals2000 = np.array([0] + list(bvals[bvals==2000]))
    bvecs2000 = np.vstack([[0,0,0],bvecs[bvals==2000]])

    Dats_MS.append(maskdata[:,:,:,np.hstack([0,np.where(bvals==2000)[0]])])

    gTabsF_MS.append(gradient_table(bvals2000,bvecs2000))

    if i == 0:
        # Pick the 7- and 20-volume subsets once, from the first subject's
        # directions, and reuse the same indices for every subject.
        # NOTE(review): assumes all subjects share the same gradient ordering — confirm.
        selected_indices_MS = _farthest_point_subset(bvecs2000, 7)
        selected_indices20_MS = _farthest_point_subset(bvecs2000, 20)

    bvals7 = bvals2000[selected_indices_MS]
    bvecs7 = bvecs2000[selected_indices_MS]
    gTabs7_MS.append(gradient_table(bvals7, bvecs7))

    bvals20 = bvals2000[selected_indices20_MS]
    bvecs20 = bvecs2000[selected_indices20_MS]
    gTabs20_MS.append(gradient_table(bvals20, bvecs20))

In [None]:
# Load one white-matter map per subject, resample it like the diffusion data
# and crop it to the bounding box of the subject's (uncropped) brain mask.
WMDir = MSDir+'WM_masks/'
WMs_MS = []
# Sorted so the file matching below does not depend on os.listdir ordering.
wm_files = sorted(os.listdir(WMDir))
for i,Name in enumerate(tqdm(['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30'])):
    MatDir = MSDir+Name
    F = pmt.read_mat(MatDir+'/data_loaded.mat')

    # Identity placeholder affine; only the resliced data is used.
    data, _ = reslice(F['data'], np.eye(4), (2,2,2), (2.5,2.5,2.5))
    _, maskCut = median_otsu(data, vol_idx=range(10, 80), autocrop=False)

    # Bounding box (inclusive) of the brain mask along each spatial axis.
    true_indices = np.argwhere(maskCut)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    # Exactly one mask file must match, otherwise WMs_MS would no longer be
    # aligned with the subject index used by the other per-subject lists.
    matches = [x for x in wm_files if Name in x]
    if not matches:
        raise FileNotFoundError(f"No white-matter mask containing '{Name}' found in {WMDir}")
    if len(matches) > 1:
        raise ValueError(f"Several white-matter masks match '{Name}' in {WMDir}: {matches}")

    print(Name)
    WM, wm_affine = load_nifti(WMDir+matches[0])
    WM, wm_affine = reslice(WM, wm_affine, (2,2,2), (2.5,2.5,2.5))
    # The first five subjects and the remaining ones need different flips
    # after the axis swap.
    # NOTE(review): presumably to match the orientation of the .mat data — confirm.
    if i < 5:
        WM_t = np.fliplr(np.swapaxes(WM,0,1))
    else:
        WM_t = np.fliplr(np.flipud(np.swapaxes(WM,0,1)))
    crop = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))
    WMs_MS.append(WM_t[crop])

In [None]:
def _load_or_train_posterior(filename, gtabs):
    """Return the posterior stored under ``network_path``, training it if absent.

    Parameters
    ----------
    filename : str
        Pickle file name inside ``network_path``.
    gtabs : list
        Gradient tables; each simulation uses ``gtabs[int(params[-1])]``.

    Returns
    -------
    The posterior object, either unpickled from disk or freshly built.
    """
    path = f"{network_path}/{filename}"
    if os.path.exists(path):
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # network files that were produced locally or come from a trusted source.
        with open(path, "rb") as handle:
            return pickle.load(handle)

    # Fixed seeds so the training set is reproducible.
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for _ in tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        # All but the last two entries parameterise the tensor, params[-2] is
        # passed to the simulator and params[-1] indexes the gradient table.
        dt = ForceLowFA(ComputeDTI(params[:-2]))
        cG = gtabs[int(params[-1])]
        # Observation = simulated signal + gradient-table index.
        # Target = tensor values + params[-2].
        Obs.append(np.hstack([CustomSimulator(dt,cG,params[-2],50),params[-1]]))
        Par.append(np.hstack([mat_to_vals(dt),params[-2]]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Create inference object (SNPE), attach the simulations, train the
    # density estimator and build the posterior.
    inference = SNPE(prior=priorS0)
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train(stop_after_epochs=100)
    posterior = inference.build_posterior(density_estimator)

    # Make sure the output directory exists so the trained network is not lost.
    os.makedirs(network_path, exist_ok=True)
    with open(path, "wb") as handle:
        pickle.dump(posterior, handle)
    return posterior


posteriorFull = _load_or_train_posterior("DTIMultiMSFull_300.pickle", gTabsF_MS)
posteriorMin = _load_or_train_posterior("DTIMultiMSMin_300.pickle", gTabs7_MS)
posteriorMid = _load_or_train_posterior("DTIMultiMSMid_300.pickle", gTabs20_MS)

In [None]:
def _sbi_md_fa_slice(posterior, kk, channels=None, n_samples=500, n_jobs=8):
    """Estimate MD and FA maps on the middle axial slice of subject ``kk``.

    Parameters
    ----------
    posterior : object
        Posterior with a ``sample(shape, x=..., show_progress_bars=...)`` method.
    kk : int
        Subject index into ``Dats_MS`` / ``axial_middles_MS``; it is also
        appended to every observation passed to the posterior.
    channels : sequence of int, optional
        Volume indices (last axis) fed to the posterior; ``None`` uses all.
    n_samples : int
        Number of posterior samples drawn per pixel.
    n_jobs : int
        Number of joblib workers.

    Returns
    -------
    (MD, FA) : two 2-D arrays with the shape of the slice. Pixels whose signal
    sums to zero keep an all-zero estimate; NaN FA values are set to 0.
    """
    slice_data = Dats_MS[kk][:, :, axial_middles_MS[kk], :]

    # Pixels to process: those whose signal over all volumes is not zero.
    mask = np.sum(slice_data, axis=-1) != 0
    indices = np.argwhere(mask)
    Arr = slice_data if channels is None else slice_data[:, :, channels]

    def optimize_pixel(i, j):
        # Same seed for every pixel, so each pixel's samples are reproducible.
        torch.manual_seed(10)
        samples = posterior.sample((n_samples,), x=np.hstack([Arr[i, j], kk]), show_progress_bars=False)
        # One point estimate per parameter: the histogram mode of its samples.
        return i, j, np.array([histogram_mode(p) for p in samples.T])

    results = Parallel(n_jobs=n_jobs)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )

    ArrShape = mask.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for i, j, x in results:
        NoiseEst[i, j] = x

    MD = np.zeros(ArrShape)
    FA = np.zeros(ArrShape)
    for i, j in np.ndindex(*ArrShape):
        # Clip negative eigenvalues of the estimated tensor, keeping the last
        # (seventh) estimated parameter unchanged.
        cleaned = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))), NoiseEst[i, j, -1:]])
        Eigs = np.linalg.eigh(vals_to_mat(cleaned[:6]))[0]
        MD[i, j] = np.mean(Eigs)
        FA[i, j] = FracAni(Eigs, np.mean(Eigs))
    FA[np.isnan(FA)] = 0
    return MD, FA


MDFullArr_MS, FAFullArr_MS = [], []
MDMinArr_MS, FAMinArr_MS = [], []
MDMidArr_MS, FAMidArr_MS = [], []

# Full acquisition first, then the 7-volume and 20-volume subsets.
for posterior, channels, md_maps, fa_maps in (
    (posteriorFull, None, MDFullArr_MS, FAFullArr_MS),
    (posteriorMin, selected_indices_MS, MDMinArr_MS, FAMinArr_MS),
    (posteriorMid, selected_indices20_MS, MDMidArr_MS, FAMidArr_MS),
):
    for kk in tqdm(range(len(Dats_MS))):
        md_map, fa_map = _sbi_md_fa_slice(posterior, kk, channels)
        md_maps.append(md_map)
        fa_maps.append(fa_map)

In [None]:
# NLLS tensor fits on each subject's middle axial slice, for the full
# acquisition and for the two volume subsets.
MDFullNLArr_MS, FAFullNLArr_MS = [], []
MDMinNLArr_MS, FAMinNLArr_MS = [], []
MDMidNLArr_MS, FAMidNLArr_MS = [], []

# (gradient tables, volume selection on the last axis, MD output, FA output)
nlls_setups = (
    (gTabsF_MS, slice(None), MDFullNLArr_MS, FAFullNLArr_MS),
    (gTabs7_MS, selected_indices_MS, MDMinNLArr_MS, FAMinNLArr_MS),
    (gTabs20_MS, selected_indices20_MS, MDMidNLArr_MS, FAMidNLArr_MS),
)
for gtabs, volumes, md_out, fa_out in nlls_setups:
    for kk in range(8):
        tenmodel = dti.TensorModel(gtabs[kk], return_S0_hat=True, fit_method='NLLS')
        tenfit = tenmodel.fit(Dats_MS[kk][:, :, axial_middles_MS[kk], volumes])
        md_out.append(dti.mean_diffusivity(tenfit.evals))
        fa_out.append(dti.fractional_anisotropy(tenfit.evals))

In [None]:
# Per-subject MD metrics on the middle axial slice.
# Acc*  : mean absolute difference to the reference map inside the brain mask.
# SSIM* : masked_local_ssim against the same reference (7-pixel window).
AccM7_MD_MS, AccM20_MD_MS, AccMFulls_MD_MS = [], [], []
AccM7NL_MD_MS, AccM20NL_MD_MS = [], []

SSIM7_MD_MS, SSIM20_MD_MS, SSIMFulls_MD_MS = [], [], []
SSIM7NL_MD_MS, SSIM20NL_MD_MS = [], []

for i in range(8):
    Ma = Masks_MS[i][:, :, axial_middles_MS[i]]
    # (estimate, reference, accuracy list, SSIM list)
    md_pairs = (
        (MDMinArr_MS[i], MDFullArr_MS[i], AccM7_MD_MS, SSIM7_MD_MS),
        (MDMidArr_MS[i], MDFullArr_MS[i], AccM20_MD_MS, SSIM20_MD_MS),
        (MDFullArr_MS[i], MDFullNLArr_MS[i], AccMFulls_MD_MS, SSIMFulls_MD_MS),
        (MDMinNLArr_MS[i], MDFullNLArr_MS[i], AccM7NL_MD_MS, SSIM7NL_MD_MS),
        (MDMidNLArr_MS[i], MDFullNLArr_MS[i], AccM20NL_MD_MS, SSIM20NL_MD_MS),
    )
    for estimate, reference, acc_out, ssim_out in md_pairs:
        acc_out.append(np.mean(np.abs(estimate - reference)[Ma]))
        ssim_out.append(masked_local_ssim(estimate, reference, Ma, win_size=7))


# Precision: standard deviation of MD over pixels whose white-matter value
# exceeds 0.8 on the same slice.
Prec7_SBI_MD_MS, Prec20_SBI_MD_MS, PrecFull_SBI_MD_MS = [], [], []
Prec7_NLLS_MD_MS, Prec20_NLLS_MD_MS, PrecFull_NLLS_MD_MS = [], [], []

for i in range(8):
    wm_pixels = WMs_MS[i][:, :, axial_middles_MS[i]] > 0.8
    for md_maps, prec_out in (
        (MDMinArr_MS, Prec7_SBI_MD_MS),
        (MDMidArr_MS, Prec20_SBI_MD_MS),
        (MDFullArr_MS, PrecFull_SBI_MD_MS),
        (MDMinNLArr_MS, Prec7_NLLS_MD_MS),
        (MDMidNLArr_MS, Prec20_NLLS_MD_MS),
        (MDFullNLArr_MS, PrecFull_NLLS_MD_MS),
    ):
        prec_out.append(np.std(md_maps[i][wm_pixels]))


In [None]:
# Per-subject FA metrics on the middle axial slice.
# Acc*  : mean absolute difference to the reference map inside the brain mask.
# SSIM* : masked_local_ssim against the same reference (7-pixel window).
AccM7_FA_MS, AccM20_FA_MS, AccMFulls_FA_MS = [], [], []
AccM7NL_FA_MS, AccM20NL_FA_MS = [], []

SSIM7_FA_MS, SSIM20_FA_MS, SSIMFulls_FA_MS = [], [], []
SSIM7NL_FA_MS, SSIM20NL_FA_MS = [], []

for i in range(8):
    Ma = Masks_MS[i][:, :, axial_middles_MS[i]]
    # (estimate, reference, accuracy list, SSIM list)
    fa_pairs = (
        (FAMinArr_MS[i], FAFullArr_MS[i], AccM7_FA_MS, SSIM7_FA_MS),
        (FAMidArr_MS[i], FAFullArr_MS[i], AccM20_FA_MS, SSIM20_FA_MS),
        (FAFullArr_MS[i], FAFullNLArr_MS[i], AccMFulls_FA_MS, SSIMFulls_FA_MS),
        (FAMinNLArr_MS[i], FAFullNLArr_MS[i], AccM7NL_FA_MS, SSIM7NL_FA_MS),
        (FAMidNLArr_MS[i], FAFullNLArr_MS[i], AccM20NL_FA_MS, SSIM20NL_FA_MS),
    )
    for estimate, reference, acc_out, ssim_out in fa_pairs:
        acc_out.append(np.mean(np.abs(estimate - reference)[Ma]))
        ssim_out.append(masked_local_ssim(estimate, reference, Ma, win_size=7))


# Precision: standard deviation of FA over pixels whose white-matter value
# exceeds 0.8 on the same slice.
Prec7_SBI_FA_MS, Prec20_SBI_FA_MS, PrecFull_SBI_FA_MS = [], [], []
Prec7_NLLS_FA_MS, Prec20_NLLS_FA_MS, PrecFull_NLLS_FA_MS = [], [], []

for i in range(8):
    wm_pixels = WMs_MS[i][:, :, axial_middles_MS[i]] > 0.8
    for fa_maps, prec_out in (
        (FAMinArr_MS, Prec7_SBI_FA_MS),
        (FAMidArr_MS, Prec20_SBI_FA_MS),
        (FAFullArr_MS, PrecFull_SBI_FA_MS),
        (FAMinNLArr_MS, Prec7_NLLS_FA_MS),
        (FAMidNLArr_MS, Prec20_NLLS_FA_MS),
        (FAFullNLArr_MS, PrecFull_NLLS_FA_MS),
    ):
        prec_out.append(np.std(fa_maps[i][wm_pixels]))


### g

In [None]:
# Box plot of the MD SSIM values for the MS dataset, with one jittered marker
# per subject: circles for the first five entries, triangles for the rest.
fig, ax1 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)
plt.sca(ax1)

# (values, x position, colors, colors2), drawn left to right.
for values, position, col, col2 in (
    (SSIMFulls_MD_MS, 1, 'black', 'gray'),
    (SSIM20_MD_MS, 1.7, 'lightseagreen', 'paleturquoise'),
    (SSIM7_MD_MS, 2, 'mediumturquoise', 'paleturquoise'),
    (SSIM20NL_MD_MS, 2.8, 'sandybrown', 'peachpuff'),
    (SSIM7NL_MD_MS, 3.1, 'burlywood', 'peachpuff'),
):
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [col]
    colors2 = [col2]

    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2)
    # Horizontal jitter drawn from a scaled Student-t distribution.
    x_data = g_pos*np.ones_like(y_data)
    x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
    plt.scatter(x_data[:5], y_data[:5], marker='o', color=colors2, s=100, alpha=0.5)
    plt.scatter(x_data[5:], y_data[5:], marker='^', color=colors2, s=100, alpha=0.5)

ax1.axhline(0.66, lw=3, ls='--', c='k')
ax1.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax1.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
ax1.set_ylim([0, 1])

# Legend handles for all five groups; only the two NLLS ones are shown here.
leg_patch1 = mpatches.Patch(color='lightseagreen', label='Mid. (SBI)')
leg_patch2 = mpatches.Patch(color='mediumturquoise', label='Min. (SBI)')
leg_patch3 = mpatches.Patch(color='sandybrown', label='Mid. (NLLS)')
leg_patch4 = mpatches.Patch(color='burlywood', label='Min. (NLLS)')
leg_patch5 = mpatches.Patch(color='gray', label='Full Comp.')

ax1.legend(
    handles=[leg_patch3, leg_patch4],
    loc='lower left',
    frameon=False, ncols=1,
    fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3,
    bbox_to_anchor=(-0.0, -0.05))

if Save: plt.savefig(FigLoc+'DTI_MD_SSIM_MS.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Box plot of the MD accuracy values for the MS dataset, with one jittered
# marker per subject: circles for the first five entries, triangles for the rest.
fig, ax1 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)
plt.sca(ax1)

# (values, x position, colors, colors2), drawn left to right.
for values, position, col, col2 in (
    (AccMFulls_MD_MS, 1, 'black', 'gray'),
    (AccM20_MD_MS, 1.7, 'lightseagreen', 'paleturquoise'),
    (AccM7_MD_MS, 2, 'mediumturquoise', 'paleturquoise'),
    (AccM20NL_MD_MS, 2.8, 'sandybrown', 'peachpuff'),
    (AccM7NL_MD_MS, 3.1, 'burlywood', 'peachpuff'),
):
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [col]
    colors2 = [col2]

    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2)
    # Horizontal jitter drawn from a scaled Student-t distribution.
    x_data = g_pos*np.ones_like(y_data)
    x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
    plt.scatter(x_data[:5], y_data[:5], marker='o', color=colors2, s=100, alpha=0.5)
    plt.scatter(x_data[5:], y_data[5:], marker='^', color=colors2, s=100, alpha=0.5)

ax1.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax1.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

# Legend handles for all five groups; the full comparison and the two SBI
# groups are shown here.
leg_patch1 = mpatches.Patch(color='lightseagreen', label='Mid. (SBI)')
leg_patch2 = mpatches.Patch(color='mediumturquoise', label='Min. (SBI)')
leg_patch3 = mpatches.Patch(color='sandybrown', label='Mid. (NLLS)')
leg_patch4 = mpatches.Patch(color='burlywood', label='Min. (NLLS)')
leg_patch5 = mpatches.Patch(color='gray', label='Full Comp.')

ax1.legend(
    handles=[leg_patch5, leg_patch1, leg_patch2],
    loc='upper left',
    frameon=False, ncols=1,
    fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3,
    bbox_to_anchor=(0, 1.1))
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))

if Save: plt.savefig(FigLoc+'DTI_MD_Acc_MS.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Box plot of the MD precision values for the MS dataset: three SBI groups
# (x = 1, 1.2, 1.4) followed by three NLLS groups (x = 2, 2.2, 2.4), with one
# jittered marker per subject.
fig, ax = plt.subplots()

for values, position, col, col2 in (
    (PrecFull_SBI_MD_MS, 1, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI_MD_MS, 1.2, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_MD_MS, 1.4, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_MD_MS, 2, 'sandybrown', 'peachpuff'),
    (Prec20_NLLS_MD_MS, 2.2, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_MD_MS, 2.4, 'sandybrown', 'peachpuff'),
):
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [col]
    colors2 = [col2]
    BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2)
    # Horizontal jitter drawn from a scaled Student-t distribution.
    x_data = g_pos*np.ones_like(y_data)
    x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
    # Circles: first five subjects; triangles: the remaining ones.
    plt.scatter(x_data[:5], y_data[:5], marker='o', color=colors2, s=100, alpha=0.5)
    plt.scatter(x_data[5:], y_data[5:], marker='^', color=colors2, s=100, alpha=0.5)

# Hatched reference band per method: the 25th-75th percentile range of the
# 'Full' values, ignoring NaNs.
for x_start, x_stop, full_values, band_col in (
    (1.85, 2.6, PrecFull_NLLS_MD_MS, WLSFit),
    (0.85, 1.6, PrecFull_SBI_MD_MS, SBIFit),
):
    x = np.arange(x_start, x_stop, 0.05)
    q_low, q_high = np.nanpercentile(np.array(full_values), [25, 75])
    y1 = np.ones_like(x)*q_low
    y2 = np.ones_like(x)*q_high
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

plt.xticks([1,1.2,1.4,2,2.2,2.4],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)

# Marker legend. The subject list is ordered NMSS_* (indices 0-4) then Ctrl*
# (indices 5-7), so circles are the MS subjects and triangles the controls.
# NOTE(review): assumes NMSS_* are the MS patients and Ctrl* the healthy controls — confirm.
leg_o = Line2D([0], [0],
               marker='o',
               color='w',                # no line
               markerfacecolor='lightseagreen',
               markersize=10,
               alpha=0.5,
               linestyle='None',
               label='MS indiv.')

leg_tri = Line2D([0], [0],
                 marker='^',
                 color='w',
                 markerfacecolor='lightseagreen',
                 markersize=10,
                 alpha=0.5,
                 linestyle='None',
                 label='Healthy indiv.')

ax.legend(handles=[leg_o, leg_tri],
          loc='upper left',fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,bbox_to_anchor= (-0.05,0.8))

if Save: plt.savefig(FigLoc+'DTI_MD_Prec_MS.pdf',format='pdf',bbox_inches='tight',transparent=True)

### h

In [None]:
# Box plots with per-subject points for the *_FA_MS accuracy arrays; written
# to DTI_FA_Acc_MS.pdf when Save is set.
fig, ax1 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)
plt.sca(ax1)  # the plt.scatter calls below draw on the current axes

# One entry per box, left to right:
# (values, x position, box colour, point colour).
fa_acc_groups = [
    (AccMFulls_FA_MS, 1, 'black', 'gray'),
    (AccM20_FA_MS, 1.7, 'lightseagreen', 'paleturquoise'),
    (AccM7_FA_MS, 2, 'mediumturquoise', 'paleturquoise'),
    (AccM20NL_FA_MS, 2.8, 'sandybrown', 'peachpuff'),
    (AccM7NL_FA_MS, 3.1, 'burlywood', 'peachpuff'),
]
for group_vals, group_x, box_col, point_col in fa_acc_groups:
    y_data = np.array(group_vals)
    g_pos = np.array([group_x])
    colors, colors2 = [box_col], [point_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2)

    # Spread the points horizontally with Student-t (df=6, scale=0.02) jitter
    # so that overlapping values stay visible.
    x_data = g_pos * np.ones_like(y_data)
    x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))

    # The first five entries are drawn as circles, the remaining as triangles.
    for marker_shape, subset in (('o', slice(None, 5)), ('^', slice(5, None))):
        plt.scatter(x_data[subset], y_data[subset], marker=marker_shape,
                    color=colors2, s=100, alpha=0.5)

ax1.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax1.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

# Colour patches matching the boxes above (created here, not attached to a
# legend in this cell).
leg_patch1, leg_patch2, leg_patch3, leg_patch4, leg_patch5 = (
    mpatches.Patch(color=patch_col, label=patch_label)
    for patch_col, patch_label in (
        ('lightseagreen', 'Mid. (SBI)'),
        ('mediumturquoise', 'Min. (SBI)'),
        ('sandybrown', 'Mid. (NLLS)'),
        ('burlywood', 'Min. (NLLS)'),
        ('gray', 'Full Comp.'),
    )
)

ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, -1))

if Save:
    plt.savefig(FigLoc+'DTI_FA_Acc_MS.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Box plots with per-subject points for the SSIM*_FA_MS arrays; written to
# DTI_FA_SSIM_MS.pdf when Save is set.
fig, ax1 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)
plt.sca(ax1)  # the plt.scatter calls below draw on the current axes

# One entry per box, left to right:
# (values, x position, box colour, point colour).
fa_ssim_groups = [
    (SSIMFulls_FA_MS, 1, 'black', 'gray'),
    (SSIM20_FA_MS, 1.7, 'lightseagreen', 'paleturquoise'),
    (SSIM7_FA_MS, 2, 'mediumturquoise', 'paleturquoise'),
    (SSIM20NL_FA_MS, 2.8, 'sandybrown', 'peachpuff'),
    (SSIM7NL_FA_MS, 3.1, 'burlywood', 'peachpuff'),
]
for group_vals, group_x, box_col, point_col in fa_ssim_groups:
    y_data = np.array(group_vals)
    g_pos = np.array([group_x])
    colors, colors2 = [box_col], [point_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2)

    # Spread the points horizontally with Student-t (df=6, scale=0.02) jitter
    # so that overlapping values stay visible.
    x_data = g_pos * np.ones_like(y_data)
    x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))

    # The first five entries are drawn as circles, the remaining as triangles.
    for marker_shape, subset in (('o', slice(None, 5)), ('^', slice(5, None))):
        plt.scatter(x_data[subset], y_data[subset], marker=marker_shape,
                    color=colors2, s=100, alpha=0.5)

# Dashed horizontal reference line at 0.66 across the whole axes.
ax1.axhline(0.66, lw=3, ls='--', c='k')
ax1.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax1.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
ax1.set_ylim([0, 1])

# Colour patches matching the boxes above (created here, not attached to a
# legend in this cell).
leg_patch1, leg_patch2, leg_patch3, leg_patch4, leg_patch5 = (
    mpatches.Patch(color=patch_col, label=patch_label)
    for patch_col, patch_label in (
        ('lightseagreen', 'Mid. (SBI)'),
        ('mediumturquoise', 'Min. (SBI)'),
        ('sandybrown', 'Mid. (NLLS)'),
        ('burlywood', 'Min. (NLLS)'),
        ('gray', 'Full Comp.'),
    )
)

if Save:
    plt.savefig(FigLoc+'DTI_FA_SSIM_MS.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Box plots with per-subject points for the Prec*_FA_MS arrays (SBI boxes at
# x = 1..1.4, NLLS boxes at x = 2..2.4); written to DTI_FA_Prec_MS.pdf when
# Save is set.
fig, ax = plt.subplots()

# One entry per box, left to right:
# (values, x position, box colour, point colour).
fa_prec_groups = [
    (PrecFull_SBI_FA_MS, 1, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI_FA_MS, 1.2, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_FA_MS, 1.4, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_FA_MS, 2, 'sandybrown', 'peachpuff'),
    (Prec20_NLLS_FA_MS, 2.2, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_FA_MS, 2.4, 'sandybrown', 'peachpuff'),
]
for group_vals, group_x, box_col, point_col in fa_prec_groups:
    y_data = np.array(group_vals)
    g_pos = np.array([group_x])
    colors, colors2 = [box_col], [point_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2)

    # Spread the points horizontally with Student-t (df=6, scale=0.02) jitter
    # so that overlapping values stay visible.
    x_data = g_pos * np.ones_like(y_data)
    x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))

    # First five entries -> circles, the rest -> triangles (the legend handles
    # below label these markers 'Healthy indiv.' and 'MS indiv.').
    for marker_shape, subset in (('o', slice(None, 5)), ('^', slice(5, None))):
        plt.scatter(x_data[subset], y_data[subset], marker=marker_shape,
                    color=colors2, s=100, alpha=0.5)

# Hatched reference bands: percentile range of the full-acquisition values,
# extended horizontally across each method's group of boxes.
# NOTE(review): the upper bound is the 77th percentile; 75 (the box-plot
# quartile) may have been intended -- confirm before changing.
for band_start, band_stop, full_vals, band_col in (
    (1.85, 2.6, PrecFull_NLLS_FA_MS, WLSFit),
    (0.85, 1.6, PrecFull_SBI_FA_MS, SBIFit),
):
    x = np.arange(band_start, band_stop, 0.05)
    finite_vals = np.array(full_vals)[~np.isnan(full_vals)]
    y1 = np.ones_like(x) * np.percentile(finite_vals, 25)
    y2 = np.ones_like(x) * np.percentile(finite_vals, 77)
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

plt.xticks([1, 1.2, 1.4, 2, 2.2, 2.4],
           ['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'],
           fontsize=32, rotation=90)
plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-0.5, 3))
plt.yticks(fontsize=24)

# Marker-only legend handles (white line colour hides the line itself).
# They are created here but not attached to a legend in this cell.
marker_handle_style = dict(color='w', markerfacecolor='lightseagreen',
                           markersize=10, alpha=0.5, linestyle='None')
leg_o = Line2D([0], [0], marker='o', label='Healthy indiv.', **marker_handle_style)
leg_tri = Line2D([0], [0], marker='^', label='MS indiv.', **marker_handle_style)

if Save:
    plt.savefig(FigLoc+'DTI_FA_Prec_MS.pdf', format='pdf', bbox_inches='tight', transparent=True)

# Figure 4

## HCP

In [None]:
# Load the "1k" and "3k" diffusion series of one HCP subject, crop both to the
# brain bounding box, and pick a reduced gradient subset from each series
# (combined into gtabHCP7 / true_indx).
i = 3
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# Fix: the original cell later re-read the 3k files and rebuilt gtabHCP3 from
# the *1k* bvals/bvecs (copy-paste slip), overwriting this correct table.
# The redundant reload has been removed, so gtabHCP3 stays the 3k table.
bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Joint gradient table: 1k volumes first, then 3k volumes.
gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)), np.vstack((bvecsHCP, bvecsHCP3)))

# 1k series: resample from 1.5 to 2.5 voxel size and compute a brain mask.
data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=False, dilate=2)
_, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                       numpass=1, autocrop=True, dilate=2)

# 3k series: same resampling; it is cropped with the 1k mask's bounding box
# below but is not itself masked.
data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))

# Bounding box of the True voxels of the mask along each spatial axis.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)
crop = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

maskdata = maskdata[crop]
axial_middle = maskdata.shape[2] // 2
maskdata3 = data3[crop]

# Central axial slice (TestData) and full volume (TestData4D), with the 1k
# and 3k volumes concatenated along the last axis.
TestData = np.concatenate([maskdata[:, :, axial_middle, :], maskdata3[:, :, axial_middle, :]], axis=-1)
TestData4D = np.concatenate([maskdata, maskdata3], axis=-1)

# --- 1k subset: greedy farthest-point selection over the bvec rows ---------
# Start from row 1 and add 5 more rows (6 in total); each step takes the row
# whose minimum distance to the rows already chosen is largest.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

# Prepend row 0 (presumably the b0 volume -- confirm) -> 7 rows of the 1k series.
selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

# --- 3k subset: same procedure on the rows with b > 0 ----------------------
# Start from row 0 of the filtered arrays and add 14 more (15 in total).
selected_indices = [0]

temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(14):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

# Reduced gradient table (built once; the original constructed it twice).
gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1, bvalsHCP7_3)), np.vstack((bvecsHCP7_1, bvecsHCP7_3)))

# Map the chosen 3k directions back to row numbers of the unfiltered 3k table
# (nearest bvec), then shift by 69 so they index the concatenated volume axis.
# NOTE(review): 69 is assumed to be the number of 1k volumes -- confirm.
true_indx = [np.linalg.norm(b-bvecsHCP3, axis=1).argmin() for b in bvecsHCP7_3]
true_indx = selected_indices7+[t+69 for t in true_indx]

# Pixels of the central slice whose first 69 volumes do not sum to zero.
cutout = np.sum(TestData4D[:, :, axial_middle, :69], axis=-1) != 0

In [None]:
# Load the cached posterior for the full gradient scheme (gtabExt) if present;
# otherwise simulate training data, train an SNPE density estimator and cache
# the resulting posterior as DKIHCPFull.pickle.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIHCPFull.pickle"):
    # NOTE: unpickling executes arbitrary code; only load trusted network files.
    with open(f"{network_path}/DKIHCPFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)

    # Parameter sets from three generator configurations (lfa / hfa / hak).
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(13000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(13000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(26000))

    DT = np.vstack([DT2,DT3,DT5])
    KT = np.vstack([KT2,KT3,KT5])

    # One uniformly distributed S0 per parameter set (was the literal 4*13000,
    # which equals the number of stacked rows).
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([len(DT)])
    S0 = np.array(S0).reshape(len(S0),1)

    # Simulate, then re-simulate rows whose signal is implausible (any value
    # below 0 or above 4*S0) until no such rows remain.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabExt.bvecs)])
    kk = 0
    while len(indx)>0:
        # Fix: `tqdm` is bound to the progress-bar callable by
        # `from tqdm.auto import tqdm`, so `tqdm.tqdm(...)` raised AttributeError.
        for i in tqdm(indx):
            # Last argument is drawn uniformly from [30, 50)
            # (presumably the simulated SNR -- confirm against CustomDKISimulator).
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabExt,S0[i],np.random.rand()*20 + 30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        # Rejected rows: halve the kurtosis parameters and replace the
        # diffusion parameters with a fresh draw.
        # NOTE(review): GenDTKT is asked for a single sample, so all rejected
        # rows appear to receive the same replacement -- confirm intended.
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # Pass the (parameter, observation) pairs to the inference object.
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    # Re-check so a file created in the meantime is not overwritten.
    if not os.path.exists(f"{network_path}/DKIHCPFull.pickle"):
        with open(f"{network_path}/DKIHCPFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)
# Load the cached posterior for the reduced gradient scheme (gtabHCP7) if
# present; otherwise simulate training data, train an SNPE density estimator
# and cache the resulting posterior as DKIHCPMin.pickle.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIHCPMin.pickle"):
    # NOTE: unpickling executes arbitrary code; only load trusted network files.
    with open(f"{network_path}/DKIHCPMin.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:
    # Parameter sets from three generator configurations (lfa / hfa / hak).
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(3*13000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(3*13000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(3*26000))

    DT = np.vstack([DT2,DT3,DT5])
    KT = np.vstack([KT2,KT3,KT5])

    # One uniformly distributed S0 per parameter set (was the literal 3*52000,
    # which equals the number of stacked rows).
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([len(DT)])
    S0 = np.array(S0).reshape(len(S0),1)

    # Simulate, then re-simulate rows whose signal is implausible (any value
    # below 0 or above 4*S0) until no such rows remain.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabHCP7.bvecs)])
    kk = 0
    while len(indx)>0:
        # Fix: `tqdm` is bound to the progress-bar callable by
        # `from tqdm.auto import tqdm`, so `tqdm.tqdm(...)` raised AttributeError.
        for i in tqdm(indx):
            # Last argument is drawn uniformly from [30, 50)
            # (presumably the simulated SNR -- confirm against CustomDKISimulator).
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabHCP7,S0[i],np.random.rand()*20 + 30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        # Rejected rows: halve the kurtosis parameters and replace the
        # diffusion parameters with a fresh draw.
        # NOTE(review): GenDTKT is asked for a single sample, so all rejected
        # rows appear to receive the same replacement -- confirm intended.
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # Pass the (parameter, observation) pairs to the inference object.
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(stop_after_epochs= 100)
    # Fix: the result was assigned to `posteriorFull`, which clobbered the
    # full-scheme posterior and left `posteriorMin` undefined for the dump below.
    posteriorMin = inference.build_posterior(density_estimator)

    # Re-check so a file created in the meantime is not overwritten.
    if not os.path.exists(f"{network_path}/DKIHCPMin.pickle"):
        with open(f"{network_path}/DKIHCPMin.pickle", "wb") as handle:
            pickle.dump(posteriorMin, handle)


In [None]:
# Posterior-mean parameter estimates with posteriorFull for every non-empty
# pixel of the central axial slice, followed by per-pixel DKI metric maps.

# Pixels whose signal, summed over all volumes, is non-zero.
mask = np.sum(TestData4D[:,:,axial_middle,:], axis=-1) != 0
indices = np.argwhere(mask)

def optimize_chunk(pixels):
    """Return a list of ``(i, j, posterior mean)`` for the given pixel indices.

    For each ``(i, j)`` pair, 500 samples are drawn from ``posteriorFull``
    conditioned on that pixel's signal in the central axial slice of
    ``TestData4D``; the samples are averaged over the sample axis.
    """
    return [
        (
            i,
            j,
            posteriorFull.sample(
                (500,), x=TestData4D[i, j, axial_middle, :], show_progress_bars=False
            ).mean(axis=0),
        )
        for i, j in pixels
    ]

# Work through the pixel list in blocks of ChunkSize on 8 joblib workers.
chunked_indices = [indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape

# Scatter the per-pixel means into a dense array (22 parameters per pixel);
# pixels outside the mask stay zero.
NoiseEst = np.zeros([62, 68, 22])
for chunk in results:
    for i, j, x in chunk:
        NoiseEst[i, j] = x

# Same estimates with the first 6 entries replaced by the values returned by
# mat_to_vals(clip_negative_eigenvalues(vals_to_mat(...))); entries 6+ are copied.
NoiseEst2 = np.zeros_like(NoiseEst)
for i, j in np.ndindex(62, 68):
    clipped_vals = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j])))
    NoiseEst2[i, j] = np.hstack([clipped_vals, NoiseEst[i, j, 6:]])

# Per-pixel maps of the first five values returned by DKIMetrics.
MK_SBIFull, AK_SBIFull, RK_SBIFull, MKT_SBIFull, KFA_SBIFull = (
    np.zeros([62, 68]) for _ in range(5)
)
for i in tqdm(range(62)):
    for j in range(68):
        pixel_params = NoiseEst2[i, j]
        Metrics = DKIMetrics(pixel_params[:6], pixel_params[6:21])
        (MK_SBIFull[i, j], AK_SBIFull[i, j], RK_SBIFull[i, j],
         MKT_SBIFull[i, j], KFA_SBIFull[i, j]) = (
            Metrics[0], Metrics[1], Metrics[2], Metrics[3], Metrics[4])
# NaN entries of the KFA map are set to 1.
KFA_SBIFull[np.isnan(KFA_SBIFull)] = 1

In [None]:
# Posterior-mean parameter estimates with posteriorMin (reduced gradient
# subset) for every non-empty pixel of the central axial slice, followed by
# per-pixel DKI metric maps.

# Pixels whose signal, summed over all volumes, is non-zero.
mask = np.sum(TestData4D[:,:,axial_middle,:], axis=-1) != 0
indices = np.argwhere(mask)

# Signals of the central slice restricted to the volumes listed in true_indx.
Arr = TestData4D[:,:,axial_middle, true_indx]

def optimize_chunk(pixels):
    """Return a list of ``(i, j, posterior mean)`` for the given pixel indices.

    For each ``(i, j)`` pair, 500 samples are drawn from ``posteriorMin``
    conditioned on ``Arr[i, j]`` and averaged over the sample axis.
    """
    results = []
    for i, j in pixels:
        samples = posteriorMin.sample((500,), x=Arr[i,j],show_progress_bars=False)
        results.append((i, j, samples.mean(axis=0)))
    return results

# Work through the pixel list in blocks of ChunkSize on 8 joblib workers.
chunked_indices = [indices[i:i+ChunkSize] for i in range(0, len(indices), ChunkSize)]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape

# Scatter the per-pixel means into a dense array (22 parameters per pixel);
# pixels outside the mask stay zero.
NoiseEst7 = np.zeros([62, 68 ,22])
for chunk in results:
    for i, j, x in chunk:
        NoiseEst7[i, j] = x

# Fix: the buffer was shaped from `NoiseEst`, an array created in a different
# cell; it is now shaped from NoiseEst7 so this cell no longer depends on it.
NoiseEst2 = np.zeros_like(NoiseEst7)
for i in range(62):
    for j in range(68):
        # First 6 entries are replaced by the values returned by
        # mat_to_vals(clip_negative_eigenvalues(vals_to_mat(...))); the rest are copied.
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst7[i,j]))),NoiseEst7[i,j,6:]])

# Per-pixel maps of the first five values returned by DKIMetrics.
MK_SBI7  = np.zeros([62, 68])
AK_SBI7  = np.zeros([62, 68])
RK_SBI7  = np.zeros([62, 68])
MKT_SBI7 = np.zeros([62, 68])
KFA_SBI7 = np.zeros([62, 68])
for i in tqdm(range(62)):
    for j in range(68):
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        MK_SBI7[i,j] = Metrics[0]
        AK_SBI7[i,j] = Metrics[1]
        RK_SBI7[i,j] = Metrics[2]
        MKT_SBI7[i,j] = Metrics[3]
        KFA_SBI7[i,j] = Metrics[4]
# NaN entries of the KFA map are set to 1.
KFA_SBI7[np.isnan(KFA_SBI7)] = 1

In [None]:
# NLLS diffusion-kurtosis fit of the central slice using the full gradient
# table, followed by per-pixel DKI metric maps.
dkimodelNL = dki.DiffusionKurtosisModel(gtabExt, fit_method='NLLS')
dkifitNL = dkimodelNL.fit(TestData[:, :, :])

# Per-pixel maps of the first five values returned by DKIMetrics.
MK_NLFull, AK_NLFull, RK_NLFull, MKT_NLFull, KFA_NLFull = (
    np.zeros([62, 68]) for _ in range(5)
)
for i, j in np.ndindex(62, 68):
    Metrics = DKIMetrics(dkifitNL[i, j].lower_triangular(), dkifitNL[i, j].kt)
    (MK_NLFull[i, j], AK_NLFull[i, j], RK_NLFull[i, j],
     MKT_NLFull[i, j], KFA_NLFull[i, j]) = (
        Metrics[0], Metrics[1], Metrics[2], Metrics[3], Metrics[4])

In [None]:
# NLLS diffusion-kurtosis fit of the central slice using only the volumes in
# true_indx (gradient table gtabHCP7), followed by per-pixel DKI metric maps.
dkimodelNL = dki.DiffusionKurtosisModel(gtabHCP7, fit_method='NLLS')
dkifitNL = dkimodelNL.fit(TestData[:, :, true_indx])

# Per-pixel maps of the first five values returned by DKIMetrics.
MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7 = (
    np.zeros([62, 68]) for _ in range(5)
)
for i, j in np.ndindex(62, 68):
    Metrics = DKIMetrics(dkifitNL[i, j].lower_triangular(), dkifitNL[i, j].kt)
    (MK_NL7[i, j], AK_NL7[i, j], RK_NL7[i, j],
     MKT_NL7[i, j], KFA_NL7[i, j]) = (
        Metrics[0], Metrics[1], Metrics[2], Metrics[3], Metrics[4])

In [None]:
# Derive two reduced gradient subsets from subject 1's gradient files:
#   selected_indices7  -- "minimal" subset (7 rows of the 1k table + 15 of the 3k table)
#   selected_indices20 -- "mid" subset (b0 + 19 rows of the 1k table + 28 of the 3k table)
# Both index the concatenated [1k, 3k] volume axis.
# Each subset is grown greedily: every step adds the bvec whose minimum
# distance to the already selected bvecs is largest.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# --- Minimal subset, 1k table: start from row 1, add 5 more (6 in total) ---
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

# Prepend row 0 (presumably the b0 volume -- confirm) -> 7 rows.
selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Fix: this table was built from the 1k bvals/bvecs (copy-paste slip).
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# --- Minimal subset, 3k table (rows with b > 0): row 0 + 14 more (15 total) -
selected_indices = [0]

temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(14):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Map the chosen 3k directions back to rows of the unfiltered 3k table
# (nearest bvec) and shift by 69 to index the concatenated volume axis.
# NOTE(review): 69 is assumed to be the number of 1k volumes -- confirm.
true_indx = [np.linalg.norm(b-bvecsHCP3,axis=1).argmin() for b in bvecsHCP7_3]
selected_indices7 = selected_indices7+[t+69 for t in true_indx]

# --- Mid subset, 1k table (rows with b > 0): row 1 + 18 more (19 total) ----
# (A distance matrix over the unfiltered bvecs used to be computed here and
# immediately overwritten; that dead computation has been removed.)
selected_indices = [1]

temp_bvecs = bvecsHCP[bvalsHCP>0]
temp_bvals = bvalsHCP[bvalsHCP>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(18):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

temp = selected_indices

# Put an explicit b = 0 entry with a zero gradient vector in front.
bvalsHCP7_1 = np.insert(temp_bvals[temp],0,0)
bvecsHCP7_1 = np.insert(temp_bvecs[temp],0,[0,0,0],axis=0)

# --- Mid subset, 3k table (rows with b > 0): row 0 + 27 more (28 total) ----
# bvalsHCP3 / bvecsHCP3 are still the 3k arrays loaded above; the second,
# identical read of the same files (and the second mis-built gtabHCP3) was dropped.
selected_indices = [0]

temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(27):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP20 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Map both parts back to rows of the unfiltered tables (nearest bvec); the
# 3k rows are again shifted by 69 onto the concatenated volume axis.
true_indx_one = [np.linalg.norm(b-bvecsHCP,axis=1).argmin() for b in bvecsHCP7_1]
true_indx = [np.linalg.norm(b-bvecsHCP3,axis=1).argmin() for b in bvecsHCP7_3]
selected_indices20 = true_indx_one+[t+69 for t in true_indx]

In [None]:
# Gradient tables for subjects 1..32: the concatenated 1k+3k table (gTabsF)
# and its restrictions to selected_indices7 (gTabs7) and selected_indices20
# (gTabs20).
gTabsF, gTabs7, gTabs20 = [], [], []

FullDat = []

for i in tqdm(range(1,33)):
    subject_dir = './HCP_data/Pat'+str(i)

    fdwi = subject_dir+'/diff_1k.nii.gz'
    bvalloc = subject_dir+'/bvals_1k.txt'
    bvecloc = subject_dir+'/bvecs_1k.txt'

    fdwi3 = subject_dir+'/diff_3k.nii.gz'
    bvalloc3 = subject_dir+'/bvals_3k.txt'
    bvecloc3 = subject_dir+'/bvecs_3k.txt'

    # Per-series gradient tables.
    bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Concatenated table: 1k entries first, then 3k entries.
    gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)),
                             np.vstack((bvecsHCP, bvecsHCP3)))
    gTabsF.append(gtabExt)

    # Restriction of the concatenated table to the minimal subset.
    bvalsHCP7, bvecsHCP7 = gtabExt.bvals[selected_indices7], gtabExt.bvecs[selected_indices7]
    gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)
    gTabs7.append(gtabHCP7)

    # Restriction of the concatenated table to the mid subset.
    bvalsHCP20, bvecsHCP20 = gtabExt.bvals[selected_indices20], gtabExt.bvecs[selected_indices20]
    gtabHCP20 = gradient_table(bvalsHCP20, bvecsHCP20)
    gTabs20.append(gtabHCP20)

In [None]:
# Load the cached multi-subject posterior for the full gradient scheme if
# present; otherwise simulate training data over the 32 per-subject gradient
# tables in gTabsF, train an SNPE density estimator and cache the posterior
# as DKIMultiHCPFull_300k.pickle.
if os.path.exists(f"{network_path}/DKIMultiHCPFull_300k.pickle"):
    # NOTE: unpickling executes arbitrary code; only load trusted network files.
    with open(f"{network_path}/DKIMultiHCPFull_300k.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    # NOTE(review): only NumPy is seeded here; S0 below is sampled through
    # torch, so it is presumably not reproducible across runs -- confirm
    # whether torch.manual_seed was intended as in the single-subject cells.
    np.random.seed(1)

    # Parameter sets from three generator configurations (lfa / hfa / hak).
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(75000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(75000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(150000))

    DT = np.vstack([DT5,DT2,DT3])
    KT = np.vstack([KT5,KT2,KT3])

    # One uniformly distributed S0 per parameter set.
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([DT.shape[0]])
    S0 = np.array(S0).reshape(len(S0),1)

    # Random gradient-table index (0..31) for every parameter set.
    A  = np.random.choice(32,DT.shape[0])

    # Simulate, then re-simulate rows whose signal is implausible (any value
    # below 0 or above 4*S0) until no such rows remain.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gTabsF[0].bvecs)])
    kk = 0
    while len(indx)>0:
        # Fix: `tqdm` is bound to the progress-bar callable by
        # `from tqdm.auto import tqdm`, so `tqdm.tqdm(...)` raised AttributeError.
        for i in tqdm(indx):
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsF[A[i]],S0[i],50)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        # Rejected rows: halve the kurtosis parameters and replace the
        # diffusion parameters with a fresh draw.
        # NOTE(review): GenDTKT is asked for a single sample, so all rejected
        # rows appear to receive the same replacement -- confirm intended.
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    # Append the gradient-table index as an extra observation column.
    Obs = np.hstack([Obs,np.expand_dims(A, axis=-1)])

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # Pass the (parameter, observation) pairs to the inference object.
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(stop_after_epochs= 50)
    posteriorFull = inference.build_posterior(density_estimator)
    # Re-check so a file created in the meantime is not overwritten.
    if not os.path.exists(f"{network_path}/DKIMultiHCPFull_300k.pickle"):
        with open(f"{network_path}/DKIMultiHCPFull_300k.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)
    # Spoken notification via the `say` command. The `import os` that used to
    # precede it was a no-op: `os` is already needed by the check at the top.
    os.system("say 'DKI network done'") # or '\7'

# Load the cached multi-subject posterior for the minimal gradient subset if
# present; otherwise simulate training data over the 32 per-subject gradient
# tables in gTabsF, keep only the columns in selected_indices7, train an SNPE
# density estimator and cache the posterior as DKIMultiHCPMin_300k.pickle.
if os.path.exists(f"{network_path}/DKIMultiHCPMin_300k.pickle"):
    # NOTE: unpickling executes arbitrary code; only load trusted network files.
    with open(f"{network_path}/DKIMultiHCPMin_300k.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:
    # NOTE(review): only NumPy is seeded here; S0 below is sampled through
    # torch, so it is presumably not reproducible across runs -- confirm
    # whether torch.manual_seed was intended as in the single-subject cells.
    np.random.seed(1)

    # Parameter sets from three generator configurations (lfa / hfa / hak).
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(75000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(75000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(150000))

    DT = np.vstack([DT5,DT2,DT3])
    KT = np.vstack([KT5,KT2,KT3])

    # One uniformly distributed S0 per parameter set.
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([DT.shape[0]])
    S0 = np.array(S0).reshape(len(S0),1)

    # Random gradient-table index (0..31) for every parameter set.
    A  = np.random.choice(32,DT.shape[0])

    # Simulate with the full tables, then re-simulate rows whose signal is
    # implausible (any value below 0 or above 4*S0) until no such rows remain.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gTabsF[0].bvecs)])
    kk = 0
    while len(indx)>0:
        # Fix: `tqdm` is bound to the progress-bar callable by
        # `from tqdm.auto import tqdm`, so `tqdm.tqdm(...)` raised AttributeError.
        for i in tqdm(indx):
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsF[A[i]],S0[i],50)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        # Rejected rows: halve the kurtosis parameters and replace the
        # diffusion parameters with a fresh draw.
        # NOTE(review): GenDTKT is asked for a single sample, so all rejected
        # rows appear to receive the same replacement -- confirm intended.
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    # Keep only the minimal-subset columns and append the gradient-table index
    # as an extra observation column.
    Obs = np.hstack([Obs[:,selected_indices7],np.expand_dims(A, axis=-1)])

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # Pass the (parameter, observation) pairs to the inference object.
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(stop_after_epochs= 50)
    posteriorMin = inference.build_posterior(density_estimator)
    # Re-check so a file created in the meantime is not overwritten.
    if not os.path.exists(f"{network_path}/DKIMultiHCPMin_300k.pickle"):
        with open(f"{network_path}/DKIMultiHCPMin_300k.pickle", "wb") as handle:
            pickle.dump(posteriorMin, handle)
    # Spoken notification via the `say` command. The `import os` that used to
    # precede it was a no-op: `os` is already needed by the check at the top.
    os.system("say 'DKI network done'") # or '\7'

# Load the cached "Mid" DKI posterior (trained on the selected_indices20
# acquisitions) or, when no cache exists, simulate training data and train it.
import os  # imported up front; the original imported it only after first use

if os.path.exists(f"{network_path}/DKIMultiHCPMid_300k.pickle"):
    # NOTE: pickle.load can execute arbitrary code stored in the file; only
    # load network files that you produced yourself.
    with open(f"{network_path}/DKIMultiHCPMid_300k.pickle", "rb") as handle:
        posteriorMid = pickle.load(handle)
else:
    # Bind the progress bar under its own name: with the notebook's
    # `from tqdm.auto import tqdm`, the original `tqdm.tqdm(...)` raises
    # AttributeError. This alias works however `tqdm` is currently bound.
    from tqdm.auto import tqdm as _progress_bar

    np.random.seed(1)
    # NOTE(review): S0 is sampled from BoxUniform below, which presumably uses
    # torch's RNG and is therefore not covered by np.random.seed -- confirm.

    # Draw tensors from the three priors (GenDTKT's last two arguments are
    # presumably seed and sample count -- verify against its definition).
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(75000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(75000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(150000))

    DT = np.vstack([DT5,DT2,DT3])
    KT = np.vstack([KT5,KT2,KT3])

    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([DT.shape[0]])
    S0 = np.array(S0).reshape(len(S0),1)

    # Gradient-table index (0..31) used to simulate each sample.
    A  = np.random.choice(32,DT.shape[0])

    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gTabsF[0].bvecs)])
    kk = 0
    # Rejection loop: re-simulate rows until every signal lies in [0, 4*S0].
    while len(indx)>0:
        for i in _progress_bar(indx):
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsF[A[i]],S0[i],50)

        # Rows with any signal above 4*S0 or below 0 (same test as the
        # original per-row Python loop, evaluated in one vectorised pass).
        indxNew = np.flatnonzero(((Obs > 4*S0) | (Obs < 0)).any(axis=1))
        # Rejected rows: halve the kurtosis tensor and replace the diffusion
        # tensor by one fresh draw, which is broadcast to all rejected rows.
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    # Keep only the selected_indices20 acquisitions and append the
    # gradient-table index as the last observation feature.
    Obs = np.hstack([Obs[:,selected_indices20],np.expand_dims(A, axis=-1)])

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 50)
    posteriorMid = inference.build_posterior(density_estimator)

    # Make sure the target folder exists so the trained network is not lost.
    os.makedirs(network_path, exist_ok=True)
    if not os.path.exists(f"{network_path}/DKIMultiHCPMid_300k.pickle"):
        with open(f"{network_path}/DKIMultiHCPMid_300k.pickle", "wb") as handle:
            pickle.dump(posteriorMid, handle)
    os.system("say 'DKI network done'") # or '\7'

In [None]:
# Load the 32 HCP subjects: reslice the 1k and 3k shells, crop both to the
# bounding box of the brain mask and collect, per subject,
#   TD            cropped 4D data (1k volumes followed by 3k volumes)
#   axial_middles index of the middle axial slice of the cropped volume
#   masks         brain mask of that slice
#   WMs           white-matter mask of that slice (probability > 0.8)
TD = []
axial_middles = []
masks = []
WMs = []
for kk in tqdm(range(32)):
    subject_dir = f'./HCP_data/Pat{kk+1}'
    fdwi = f'{subject_dir}/diff_1k.nii.gz'
    bvalloc = f'{subject_dir}/bvals_1k.txt'
    bvecloc = f'{subject_dir}/bvecs_1k.txt'

    fdwi3 = f'{subject_dir}/diff_3k.nii.gz'
    bvalloc3 = f'{subject_dir}/bvals_3k.txt'
    bvecloc3 = f'{subject_dir}/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Combined table in the same volume order as the concatenated data below.
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)
    # NOTE(review): mask2 is not used anywhere in this cell; kept in case a
    # later cell reads it, but it costs a second full masking pass.
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

    # Bounding box of the brain mask along each spatial axis.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    crop = tuple(slice(lo, hi+1) for lo, hi in zip(min_coords, max_coords))

    maskdata  = maskdata[crop]
    axial_middle = maskdata.shape[2] // 2
    # The 3k data is cropped with the same box (it is not multiplied by mask).
    maskdata3 = data3[crop]
    axial_middles.append(axial_middle)
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    TD.append(TestData4D)
    # axial_middle indexes the *cropped* volume, while mask is uncropped along
    # z: shift by the crop offset so the mask slice is the same anatomical
    # slice as TD[kk][:, :, axial_middle]. (The original used axial_middle
    # directly, which is only correct when min_coords[2] == 0.)
    masks.append(mask[crop[0], crop[1], min_coords[2]+axial_middle])
    # NOTE(review): the WM volume is indexed with the cropped slice index and
    # no crop is applied -- presumably the file is already in cropped space.
    WM, affine, img = load_nifti('./HCP_data/WM_Masks/c2Pat'+str(kk+1)+'_FP.nii', return_img=True)
    WMs.append(np.fliplr(WM[:,:,axial_middle]>0.8))

In [None]:
# Per-subject MK/AK/RK maps of the middle axial slice, estimated with the
# full-protocol posterior. Cached arrays are loaded when available.
if os.path.exists(f"{DatFolder}/Full_MK_HCP.npy"):
    MKFullArr = np.load(f"{DatFolder}/Full_MK_HCP.npy",allow_pickle=True)
    RKFullArr = np.load(f"{DatFolder}/Full_RK_HCP.npy",allow_pickle=True)
    AKFullArr = np.load(f"{DatFolder}/Full_AK_HCP.npy",allow_pickle=True)
else:
    MKFullArr, RKFullArr, AKFullArr = [], [], []
    for kk in tqdm(range(32)):

        # Pixels whose first 69 volumes are not all zero are processed.
        mask = np.sum(TD[kk][:, :, axial_middles[kk], :69], axis=-1) != 0
        indices = np.argwhere(mask)

        def optimize_chunk(pixels):
            """Return (i, j, per-parameter posterior mode) for each pixel."""
            results = []
            for i, j in pixels:
                # Observation = all signals of the pixel + subject index.
                observation = np.hstack([TD[kk][i, j, axial_middles[kk], :],kk])
                posterior_samples_1 = posteriorFull.sample((500,), x=observation,show_progress_bars=False)
                modes = np.array([histogram_mode(p) for p in posterior_samples_1.T])
                results.append((i, j, modes))
            return results

        # Split the pixel list into chunks of ChunkSize and run them in parallel.
        chunked_indices = [indices[start:start+ChunkSize]
                           for start in range(0, len(indices), ChunkSize)]
        results = Parallel(n_jobs=8)(
            delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
        )

        # Scatter the 22 estimated parameters back onto the slice grid.
        ArrShape = mask.shape
        NoiseEst = np.zeros((*ArrShape, 22))
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x

        MK_SBIFull  = np.zeros(NoiseEst.shape[:2])
        AK_SBIFull  = np.zeros(NoiseEst.shape[:2])
        RK_SBIFull  = np.zeros(NoiseEst.shape[:2])
        NoiseEst2 =  np.zeros_like(NoiseEst)
        for i, j in np.ndindex(*NoiseEst.shape[:2]):
            # Clip negative eigenvalues of the diffusion tensor, keep the rest.
            clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j])))
            NoiseEst2[i,j] = np.hstack([clipped_dt,NoiseEst[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK_SBIFull[i,j], AK_SBIFull[i,j], RK_SBIFull[i,j] = Metrics[0], Metrics[1], Metrics[2]

        MKFullArr.append(MK_SBIFull)
        RKFullArr.append(RK_SBIFull)
        AKFullArr.append(AK_SBIFull)

In [None]:
# Per-subject MK/AK/RK maps of the middle axial slice, estimated with the
# "Min" posterior from the selected_indices7 acquisitions only. Cached arrays
# are loaded when available.
if os.path.exists(f"{DatFolder}/Min_MK_HCP.npy"):
    MKMinArr = np.load(f"{DatFolder}/Min_MK_HCP.npy",allow_pickle=True)
    RKMinArr = np.load(f"{DatFolder}/Min_RK_HCP.npy",allow_pickle=True)
    AKMinArr = np.load(f"{DatFolder}/Min_AK_HCP.npy",allow_pickle=True)
else:
    MKMinArr = []
    RKMinArr = []
    AKMinArr = []
    for kk in tqdm(range(32)):

        # Pixels whose first 69 volumes are not all zero are processed.
        mask = np.sum(TD[kk][:, :, axial_middles[kk], :69], axis=-1) != 0

        # Get the indices where mask is True
        indices = np.argwhere(mask)
        Arr = TD[kk][:,:, axial_middles[kk], selected_indices7]
        def optimize_chunk(pixels):
            """Return (i, j, per-parameter posterior mode) for each pixel."""
            results = []
            for i, j in pixels:
                # Observation = selected signals of the pixel + subject index.
                posterior_samples_1 = posteriorMin.sample((500,), x=np.hstack([Arr[i,j],kk]),show_progress_bars=False)
                results.append((i, j, np.array([histogram_mode(p) for p in posterior_samples_1.T])))
            return results

        chunked_indices = [indices[i:i+ChunkSize] for i in range(0, len(indices), ChunkSize)]
        results = Parallel(n_jobs=8)(
            delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
        )

        # Scatter the 22 estimated parameters back onto the slice grid.
        ArrShape = mask.shape
        NoiseEst = np.zeros(list(ArrShape) + [22])
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x

        # The per-subject maps get their own names: the original reused
        # MK_SBIFull/AK_SBIFull/RK_SBIFull here, silently overwriting the
        # full-protocol maps that the plotting cells read.
        NoiseEst2 =  np.zeros_like(NoiseEst)
        MK_SBIMin  = np.zeros([NoiseEst.shape[0], NoiseEst.shape[1]])
        AK_SBIMin  = np.zeros([NoiseEst.shape[0], NoiseEst.shape[1]])
        RK_SBIMin  = np.zeros([NoiseEst.shape[0], NoiseEst.shape[1]])
        for i in range(NoiseEst.shape[0]):
            for j in range(NoiseEst.shape[1]):
                # Clip negative eigenvalues of the diffusion tensor, keep the rest.
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,6:]])
                Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
                MK_SBIMin[i,j] = Metrics[0]
                AK_SBIMin[i,j] = Metrics[1]
                RK_SBIMin[i,j] = Metrics[2]

        MKMinArr.append(MK_SBIMin)
        RKMinArr.append(RK_SBIMin)
        AKMinArr.append(AK_SBIMin)


In [None]:
# Per-subject MK/AK/RK maps of the middle axial slice, estimated with the
# "Mid" posterior from the selected_indices20 acquisitions only. Cached arrays
# are loaded when available.
if os.path.exists(f"{DatFolder}/Mid_MK_HCP.npy"):
    MKMidArr = np.load(f"{DatFolder}/Mid_MK_HCP.npy",allow_pickle=True)
    RKMidArr = np.load(f"{DatFolder}/Mid_RK_HCP.npy",allow_pickle=True)
    AKMidArr = np.load(f"{DatFolder}/Mid_AK_HCP.npy",allow_pickle=True)
else:
    MKMidArr = []
    RKMidArr = []
    AKMidArr = []
    for kk in tqdm(range(32)):

        # Pixels whose first 69 volumes are not all zero are processed.
        mask = np.sum(TD[kk][:, :, axial_middles[kk], :69], axis=-1) != 0

        # Get the indices where mask is True
        indices = np.argwhere(mask)
        Arr = TD[kk][:,:, axial_middles[kk], selected_indices20]
        def optimize_chunk(pixels):
            """Return (i, j, per-parameter posterior mode) for each pixel."""
            results = []
            for i, j in pixels:
                # Observation = selected signals of the pixel + subject index.
                posterior_samples_1 = posteriorMid.sample((500,), x=np.hstack([Arr[i,j],kk]),show_progress_bars=False)
                results.append((i, j, np.array([histogram_mode(p) for p in posterior_samples_1.T])))
            return results

        chunked_indices = [indices[i:i+ChunkSize] for i in range(0, len(indices), ChunkSize)]
        results = Parallel(n_jobs=8)(
            delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
        )

        # Scatter the 22 estimated parameters back onto the slice grid.
        ArrShape = mask.shape
        NoiseEst = np.zeros(list(ArrShape) + [22])
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x

        # The per-subject maps get their own names: the original reused
        # MK_SBIFull/AK_SBIFull/RK_SBIFull here, silently overwriting the
        # full-protocol maps that the plotting cells read.
        NoiseEst2 =  np.zeros_like(NoiseEst)
        MK_SBIMid  = np.zeros([NoiseEst.shape[0], NoiseEst.shape[1]])
        AK_SBIMid  = np.zeros([NoiseEst.shape[0], NoiseEst.shape[1]])
        RK_SBIMid  = np.zeros([NoiseEst.shape[0], NoiseEst.shape[1]])
        for i in range(NoiseEst.shape[0]):
            for j in range(NoiseEst.shape[1]):
                # Clip negative eigenvalues of the diffusion tensor, keep the rest.
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,6:]])
                Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
                MK_SBIMid[i,j] = Metrics[0]
                AK_SBIMid[i,j] = Metrics[1]
                RK_SBIMid[i,j] = Metrics[2]

        MKMidArr.append(MK_SBIMid)
        RKMidArr.append(RK_SBIMid)
        AKMidArr.append(AK_SBIMid)

In [None]:
# NLLS reference fit on the full protocol: per-subject MK/AK/RK maps of the
# middle axial slice.
MKFullNLArr, RKFullNLArr, AKFullNLArr = [], [], []
MKTFullNLArr, KFAFullNLArr = [], []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk]])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    NoiseEst_NL = np.zeros((*ArrShape, 21))
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    MK_NL7_t, AK_NL7_t, RK_NL7_t = (np.zeros(ArrShape) for _ in range(3))
    for i, j in np.ndindex(*ArrShape):
        # 6 diffusion-tensor values followed by the 15 kurtosis-tensor values.
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(),voxel_fit.kt])
        # Clip negative eigenvalues of the diffusion tensor, keep the rest.
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt,NoiseEst_NL[i,j,6:]])
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
        MK_NL7_t[i,j], AK_NL7_t[i,j], RK_NL7_t[i,j] = Metrics[0], Metrics[1], Metrics[2]
    MKFullNLArr.append(MK_NL7_t)
    RKFullNLArr.append(RK_NL7_t)
    AKFullNLArr.append(AK_NL7_t)

# NLLS fit on the selected_indices20 subset: per-subject MK/AK/RK maps of the
# middle axial slice.
MKMidNLArr, RKMidNLArr, AKMidNLArr = [], [], []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs20[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices20])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    NoiseEst_NL = np.zeros((*ArrShape, 21))
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    MK_NL7_t, AK_NL7_t, RK_NL7_t = (np.zeros(ArrShape) for _ in range(3))
    for i, j in np.ndindex(*ArrShape):
        # 6 diffusion-tensor values followed by the 15 kurtosis-tensor values.
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(),voxel_fit.kt])
        # Clip negative eigenvalues of the diffusion tensor, keep the rest.
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt,NoiseEst_NL[i,j,6:]])
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
        MK_NL7_t[i,j], AK_NL7_t[i,j], RK_NL7_t[i,j] = Metrics[0], Metrics[1], Metrics[2]
    MKMidNLArr.append(MK_NL7_t)
    RKMidNLArr.append(RK_NL7_t)
    AKMidNLArr.append(AK_NL7_t)

# NLLS fit on the selected_indices7 subset: per-subject MK/AK/RK maps of the
# middle axial slice.
MKMinNLArr, RKMinNLArr, AKMinNLArr = [], [], []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices7])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    NoiseEst_NL = np.zeros((*ArrShape, 21))
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    MK_NL7_t, AK_NL7_t, RK_NL7_t = (np.zeros(ArrShape) for _ in range(3))
    for i, j in np.ndindex(*ArrShape):
        # 6 diffusion-tensor values followed by the 15 kurtosis-tensor values.
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(),voxel_fit.kt])
        # Clip negative eigenvalues of the diffusion tensor, keep the rest.
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt,NoiseEst_NL[i,j,6:]])
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
        MK_NL7_t[i,j], AK_NL7_t[i,j], RK_NL7_t[i,j] = Metrics[0], Metrics[1], Metrics[2]
    MKMinNLArr.append(MK_NL7_t)
    RKMinNLArr.append(RK_NL7_t)
    AKMinNLArr.append(AK_NL7_t)

In [None]:
# Per-subject MK accuracy (mean absolute difference inside the brain mask) and
# masked local SSIM. SBI Min/Mid maps are compared with the SBI Full map; the
# SBI Full map and the NLLS Min/Mid maps are compared with the NLLS Full map.
AccM7_MK = []
AccM20_MK = []
AccMFulls_MK = []

AccM7NL_MK = []
AccM20NL_MK = []

SSIM7_MK = []
SSIM20_MK = []
SSIMFulls_MK = []

SSIM7NL_MK = []
SSIM20NL_MK = []
for i in range(32):
    # SBI Min vs SBI Full
    M7 =MKMinArr[i]
    MF =MKFullArr[i]
    Ma = masks[i]
    AccM7_MK.append(np.mean(np.abs(M7-MF)[Ma]))


    # SBI Mid vs SBI Full
    M7 =MKMidArr[i]
    MF =MKFullArr[i]
    AccM20_MK.append(np.mean(np.abs(M7-MF)[Ma]))


    # SBI Full vs NLLS Full
    M7 =MKFullArr[i]
    MF =MKFullNLArr[i]
    AccMFulls_MK.append(np.mean(np.abs(M7-MF)[Ma]))

    # NLLS Min vs NLLS Full.
    # NOTE: M7 aliases MKMinNLArr[i], so the NaN->0 replacement below modifies
    # the stored map in place and also affects the SSIM and precision
    # computations that read MKMinNLArr later.
    M7 =MKMinNLArr[i]
    M7[np.isnan(M7)] = 0
    MF =MKFullNLArr[i]
    AccM7NL_MK.append(np.mean(np.abs(M7-MF)[Ma]))


    # NLLS Mid vs NLLS Full (same in-place NaN replacement on MKMidNLArr[i]).
    # NOTE(review): this entry uses np.nanmean while the others use np.mean.
    M7 =MKMidNLArr[i]
    M7[np.isnan(M7)] = 0
    MF =MKFullNLArr[i]
    AccM20NL_MK.append(np.nanmean(np.abs(M7-MF)[Ma]))

    # SSIM, SBI Min (smoothed with sigma=0.5) vs SBI Full
    NS1 =MKMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7_MK.append(result)

    # SSIM, SBI Mid (smoothed) vs SBI Full
    NS1 =MKMidArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20_MK.append(result)
    
    # SSIM, SBI Full vs NLLS Full (no smoothing)
    NS1 =MKFullArr[i]
    NS2 =MKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIMFulls_MK.append(result)

    # SSIM, NLLS Min (smoothed) vs NLLS Full
    NS1 =MKMinNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7NL_MK.append(result)

    # SSIM, NLLS Mid (smoothed) vs NLLS Full
    NS1 =MKMidNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20NL_MK.append(result)


# Precision of MK: standard deviation of each map inside the subject's
# white-matter mask, for every method and protocol.
Prec7_SBI_MK = [np.std(MKMinArr[i][WMs[i]]) for i in range(32)]
Prec20_SBI_MK = [np.std(MKMidArr[i][WMs[i]]) for i in range(32)]
PrecFull_SBI_MK = [np.std(MKFullArr[i][WMs[i]]) for i in range(32)]

Prec7_NLLS_MK = [np.std(MKMinNLArr[i][WMs[i]]) for i in range(32)]
Prec20_NLLS_MK = [np.std(MKMidNLArr[i][WMs[i]]) for i in range(32)]
PrecFull_NLLS_MK = [np.std(MKFullNLArr[i][WMs[i]]) for i in range(32)]



In [None]:
# Per-subject AK accuracy (mean absolute difference inside the brain mask) and
# masked local SSIM. SBI Min/Mid maps are compared with the SBI Full map; the
# SBI Full map and the NLLS Min/Mid maps are compared with the NLLS Full map.
AccM7_AK = []
AccM20_AK = []
AccMFulls_AK = []

AccM7NL_AK = []
AccM20NL_AK = []

SSIM7_AK = []
SSIM20_AK = []
SSIMFulls_AK = []

SSIM7NL_AK = []
SSIM20NL_AK = []
for i in range(32):
    # SBI Min vs SBI Full
    M7 =AKMinArr[i]
    MF =AKFullArr[i]
    Ma = masks[i]
    AccM7_AK.append(np.mean(np.abs(M7-MF)[Ma]))


    # SBI Mid vs SBI Full
    M7 =AKMidArr[i]
    MF =AKFullArr[i]
    AccM20_AK.append(np.mean(np.abs(M7-MF)[Ma]))


    # SBI Full vs NLLS Full
    M7 =AKFullArr[i]
    MF =AKFullNLArr[i]
    AccMFulls_AK.append(np.mean(np.abs(M7-MF)[Ma]))

    # NLLS Min vs NLLS Full.
    # NOTE: M7 aliases AKMinNLArr[i], so the NaN->0 replacement below modifies
    # the stored map in place and also affects the SSIM and precision
    # computations that read AKMinNLArr later.
    M7 =AKMinNLArr[i]
    M7[np.isnan(M7)] = 0
    MF =AKFullNLArr[i]
    AccM7NL_AK.append(np.mean(np.abs(M7-MF)[Ma]))


    # NLLS Mid vs NLLS Full (same in-place NaN replacement on AKMidNLArr[i]).
    # NOTE(review): this entry uses np.nanmean while the others use np.mean.
    M7 =AKMidNLArr[i]
    M7[np.isnan(M7)] = 0
    MF =AKFullNLArr[i]
    AccM20NL_AK.append(np.nanmean(np.abs(M7-MF)[Ma]))

    
    # SSIM, SBI Min vs SBI Full.
    # NOTE(review): here both maps are smoothed, whereas the Mid comparisons
    # below and the MK/RK cells smooth only the first map -- confirm intent.
    NS1 =AKMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =AKFullArr[i]
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7_AK.append(result)

    # SSIM, SBI Mid (smoothed) vs SBI Full
    NS1 =AKMidArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =AKFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20_AK.append(result)
    
    # SSIM, SBI Full vs NLLS Full (no smoothing)
    NS1 =AKFullArr[i]
    NS2 =AKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIMFulls_AK.append(result)

    # SSIM, NLLS Min vs NLLS Full (both maps smoothed, see note above)
    NS1 =AKMinNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =AKFullNLArr[i]
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7NL_AK.append(result)

    # SSIM, NLLS Mid (smoothed) vs NLLS Full
    NS1 =AKMidNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =AKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20NL_AK.append(result)

# Precision of AK: standard deviation of each map inside the subject's
# white-matter mask, for every method and protocol.
Prec7_SBI_AK = [np.std(AKMinArr[i][WMs[i]]) for i in range(32)]
Prec20_SBI_AK = [np.std(AKMidArr[i][WMs[i]]) for i in range(32)]
PrecFull_SBI_AK = [np.std(AKFullArr[i][WMs[i]]) for i in range(32)]

Prec7_NLLS_AK = [np.std(AKMinNLArr[i][WMs[i]]) for i in range(32)]
Prec20_NLLS_AK = [np.std(AKMidNLArr[i][WMs[i]]) for i in range(32)]
PrecFull_NLLS_AK = [np.std(AKFullNLArr[i][WMs[i]]) for i in range(32)]




In [None]:
# Per-subject RK accuracy (mean absolute difference inside the brain mask) and
# masked local SSIM. SBI Min/Mid maps are compared with the SBI Full map; the
# SBI Full map and the NLLS Min/Mid maps are compared with the NLLS Full map.
AccM7_RK = []
AccM20_RK = []
AccMFulls_RK = []

AccM7NL_RK = []
AccM20NL_RK = []

SSIM7_RK = []
SSIM20_RK = []
SSIMFulls_RK = []

SSIM7NL_RK = []
SSIM20NL_RK = []
for i in range(32):
    # SBI Min vs SBI Full
    M7 =RKMinArr[i]
    MF =RKFullArr[i]
    Ma = masks[i]
    AccM7_RK.append(np.mean(np.abs(M7-MF)[Ma]))


    # SBI Mid vs SBI Full
    M7 =RKMidArr[i]
    MF =RKFullArr[i]
    AccM20_RK.append(np.mean(np.abs(M7-MF)[Ma]))


    # SBI Full vs NLLS Full
    M7 =RKFullArr[i]
    MF =RKFullNLArr[i]
    AccMFulls_RK.append(np.mean(np.abs(M7-MF)[Ma]))

    # NLLS Min vs NLLS Full.
    # NOTE: M7 aliases RKMinNLArr[i], so the NaN->0 replacement below modifies
    # the stored map in place and also affects the SSIM and precision
    # computations that read RKMinNLArr later.
    M7 =RKMinNLArr[i]
    M7[np.isnan(M7)] = 0
    MF =RKFullNLArr[i]
    AccM7NL_RK.append(np.mean(np.abs(M7-MF)[Ma]))


    # NLLS Mid vs NLLS Full (same in-place NaN replacement on RKMidNLArr[i]).
    # NOTE(review): this entry uses np.nanmean while the others use np.mean.
    M7 =RKMidNLArr[i]
    M7[np.isnan(M7)] = 0
    MF =RKFullNLArr[i]
    AccM20NL_RK.append(np.nanmean(np.abs(M7-MF)[Ma]))

    
    # SSIM, SBI Min (smoothed with sigma=0.5) vs SBI Full
    NS1 =RKMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =RKFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7_RK.append(result)

    # SSIM, SBI Mid (smoothed) vs SBI Full
    NS1 =RKMidArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =RKFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20_RK.append(result)
    
    # SSIM, SBI Full vs NLLS Full (no smoothing)
    NS1 =RKFullArr[i]
    NS2 =RKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIMFulls_RK.append(result)

    # SSIM, NLLS Min (smoothed) vs NLLS Full
    NS1 =RKMinNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =RKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7NL_RK.append(result)

    # SSIM, NLLS Mid (smoothed) vs NLLS Full
    NS1 =RKMidNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =RKFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20NL_RK.append(result)

# Precision of RK: standard deviation of each map inside the subject's
# white-matter mask, for every method and protocol.
Prec7_SBI_RK = [np.std(RKMinArr[i][WMs[i]]) for i in range(32)]
Prec20_SBI_RK = [np.std(RKMidArr[i][WMs[i]]) for i in range(32)]
PrecFull_SBI_RK = [np.std(RKFullArr[i][WMs[i]]) for i in range(32)]

Prec7_NLLS_RK = [np.std(RKMinNLArr[i][WMs[i]]) for i in range(32)]
Prec20_NLLS_RK = [np.std(RKMidNLArr[i][WMs[i]]) for i in range(32)]
PrecFull_NLLS_RK = [np.std(RKFullNLArr[i][WMs[i]]) for i in range(32)]


### a

In [None]:
# Show (and optionally save) the full-protocol SBI MK, AK and RK maps with
# everything outside `cutout` blanked; only the RK panel gets a colorbar.
tnorm = TwoSlopeNorm(vmin=-0,vcenter = 0.6,vmax=1.2)
for metric_map, pdf_name, with_colorbar in (
    (MK_SBIFull, 'MKSBIFull.pdf', False),
    (AK_SBIFull, 'AKSBIFull.pdf', False),
    (RK_SBIFull, 'RKSBIFull.pdf', True),
):
    temp = np.copy(metric_map)
    temp[~cutout] = math.nan
    plt.imshow(temp.T,norm=tnorm,cmap='hot')
    plt.axis('off')
    if with_colorbar:
        cbar = plt.colorbar(fraction=0.032, pad=0.04)
        cbar.ax.set_ylim(0,1)
    if Save:
        plt.savefig(FigLoc+pdf_name,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

In [None]:
# Show (and optionally save) the 7-acquisition SBI MK, AK and RK maps, lightly
# smoothed (sigma=0.5) and blanked outside `cutout`; only RK gets a colorbar.
tnorm = TwoSlopeNorm(vmin=-0,vcenter = 0.6,vmax=1.2)
for metric_map, pdf_name, with_colorbar in (
    (MK_SBI7, 'MKSBI7.pdf', False),
    (AK_SBI7, 'AKSBI7.pdf', False),
    (RK_SBI7, 'RKSBI7.pdf', True),
):
    temp = gaussian_filter(np.copy(metric_map), sigma=0.5)
    temp[~cutout] = math.nan
    plt.imshow(temp.T,norm=tnorm,cmap='hot')
    plt.axis('off')
    if with_colorbar:
        cbar = plt.colorbar(fraction=0.032, pad=0.04)
        cbar.ax.set_ylim(0,1)
    if Save:
        plt.savefig(FigLoc+pdf_name,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

In [None]:
# Display the shape of the full-protocol SBI MK map (notebook cell output).
MK_SBIFull.shape

In [None]:
# Display the shape of the 7-acquisition SBI MK map (notebook cell output).
MK_SBI7.shape

In [None]:
# Absolute difference between the full and 7-acquisition SBI maps inside
# `cutout`. MK and RK use a fixed 0-1-2 colour scale, AK is scaled to its own
# maximum; the RK colorbar keeps matplotlib's default ticks.
ticks = [0,1,2]
for full_map, reduced_map, fixed_scale, cbar_kwargs, pdf_name in (
    (MK_SBIFull, MK_SBI7, True, {'ticks': ticks}, 'MKDiffSBI.pdf'),
    (AK_SBIFull, AK_SBI7, False, {'ticks': ticks}, 'AKDiffSBI.pdf'),
    (RK_SBIFull, RK_SBI7, True, {}, 'RKDiffSBI.pdf'),
):
    data = np.abs((full_map-reduced_map)*cutout).T
    data[~cutout.T] = np.nan
    if fixed_scale:
        norm = TwoSlopeNorm(vmin=0, vcenter=1, vmax=2)
    else:
        norm = TwoSlopeNorm(vmin=0, vcenter=np.nanmax(data)/2, vmax=np.nanmax(data))
    plt.imshow(data,cmap='Reds',norm=norm)
    plt.axis('off')
    cbar = plt.colorbar(**cbar_kwargs)
    cbar.formatter.set_powerlimits((0, 0))
    if Save:
        plt.savefig(FigLoc+pdf_name,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

### b

In [None]:
# Show (and optionally save) the full-protocol NLLS MK, AK and RK maps with
# everything outside `cutout` blanked; only the RK panel gets a colorbar.
tnorm = TwoSlopeNorm(vmin=-0,vcenter = 0.6,vmax=1.2)
for metric_map, pdf_name, with_colorbar in (
    (MK_NLFull, 'MKNLFull.pdf', False),
    (AK_NLFull, 'AKNLFull.pdf', False),
    (RK_NLFull, 'RKNLFull.pdf', True),
):
    temp = np.copy(metric_map)
    temp[~cutout] = math.nan
    plt.imshow(temp.T,norm=tnorm,cmap='hot')
    plt.axis('off')
    if with_colorbar:
        cbar = plt.colorbar(fraction=0.032, pad=0.04)
        cbar.ax.set_ylim(0,1)
    if Save:
        plt.savefig(FigLoc+pdf_name,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

In [None]:
# Show (and optionally save) the 7-acquisition NLLS MK, AK and RK maps with
# everything outside `cutout` blanked; only the RK panel gets a colorbar.
tnorm = TwoSlopeNorm(vmin=-0,vcenter = 0.6,vmax=1.2)
for metric_map, pdf_name, with_colorbar in (
    (MK_NL7, 'MKNL7.pdf', False),
    (AK_NL7, 'AKNL7.pdf', False),
    (RK_NL7, 'RKNL7.pdf', True),
):
    temp = np.copy(metric_map)
    temp[~cutout] = math.nan
    plt.imshow(temp.T,norm=tnorm,cmap='hot')
    plt.axis('off')
    if with_colorbar:
        cbar = plt.colorbar(fraction=0.032, pad=0.04)
        cbar.ax.set_ylim(0,1)
    if Save:
        plt.savefig(FigLoc+pdf_name,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

In [None]:
# Absolute difference between the full and 7-acquisition NLLS maps inside
# `cutout`, plotted on the same colour scales as the SBI difference maps.

# Defined locally so this cell no longer depends on the SBI difference cell
# having been run first (same value as used there).
ticks = [0,1,2]

data = np.abs((MK_NLFull-MK_NL7)*cutout).T
data[~cutout.T] = np.nan
plt.imshow(data,cmap='Reds',vmin=0,vmax=2)
plt.axis('off')
cbar = plt.colorbar(ticks=ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'MKDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# The AK colour scale is taken from the SBI difference map so that the SBI and
# NLLS panels are directly comparable.
dat = np.abs((AK_SBIFull-AK_SBI7)*cutout).T
# Fix: the original masked `data` (the previous MK map) here instead of `dat`.
dat[~cutout.T] = np.nan
norm = TwoSlopeNorm(vmin=0, vcenter=np.nanmax(dat)/2, vmax=np.nanmax(dat))
data = np.abs((AK_NLFull-AK_NL7)*cutout).T
data[~cutout.T] = np.nan
plt.imshow(data,cmap='Reds',norm=norm)
plt.axis('off')
cbar = plt.colorbar(ticks=ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'AKDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# RK: fixed 0-1-2 scale, default colorbar ticks.
norm = TwoSlopeNorm(vmin=0, vcenter=1,vmax=2)
data = np.abs((RK_NLFull-RK_NL7)*cutout).T
data[~cutout.T] = np.nan
plt.imshow(data,cmap='Reds',norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'RKDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

### c

In [None]:
# Broken-axis box plot of the MK accuracy: the upper panel (ax1) shows the
# large NLLS-Min errors, the lower panel (ax2) shows all five groups.
fig, (ax1, ax2) = plt.subplots(2, 1,figsize=(3.2,4.8))
fig.subplots_adjust(hspace=0.05)

# NLLS Min is drawn on both panels so it is visible across the axis break.
y_data = np.array(AccM7NL_MK)
g_pos = np.array([3.05])
colors = ['burlywood']
colors2 = ['peachpuff']
for axis in (ax1, ax2):
    BoxPlots(y_data,g_pos,colors,colors2,axis,widths=0.2,scatter=True)
plt.xticks([1,1.7,2,2.8,3.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
plt.yticks(fontsize=24)
ax1.set_xticks([])

plt.sca(ax2)

# Remaining groups (lower panel only): values, x position, box colour, fill.
for values, position, edge_colour, fill_colour in (
    (AccMFulls_MK, 0.8, 'black', 'gray'),
    (AccM20_MK, 1.55, 'lightseagreen', 'paleturquoise'),
    (AccM7_MK, 1.95, 'mediumturquoise', 'paleturquoise'),
    (AccM20NL_MK, 2.65, 'sandybrown', 'peachpuff'),
):
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [edge_colour]
    colors2 = [fill_colour]
    BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

plt.xticks([0.8,1.55,1.95,2.65,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax1.set_xticks([])
plt.yticks(fontsize=24)

# Slanted markers on the left spine indicate the axis break.
d = .5
kwargs = {
    'marker': [(-1, -d), (1, d)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Hide the spines facing the break and keep tick labels only at the bottom.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)
ax1.set_xlim(np.array(ax2.get_xlim()))

# Value ranges of the two panels.
ax2.set_ylim(0,0.5)
ax2.set_yticks([0,0.2,0.4])
ax1.set_ylim(0.8,1.8)

leg_patch2 = mpatches.Patch(color='mediumturquoise', label='SBI')
# NOTE(review): the NLLS patch is created but not passed to the legend below.
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
leg_patch5 = mpatches.Patch(color='gray', label='Full \nComp.')

ax1.legend(
    handles=[leg_patch5,leg_patch2],
    loc='upper left',
    frameon=False,
    ncols=1,
    fontsize=28,
    columnspacing=0.3,
    handlelength=0.6,
    handletextpad=0.3,
    bbox_to_anchor=(-0.0,0.9),
)

if Save:
    plt.savefig(FigLoc+'DKIHCP_Acc_MK.pdf',format='PDF',transparent=True,bbox_inches='tight')
plt.show()

In [None]:
# Box plot of the MK SSIM values, one box per series, all on a single axes.
fig, ax = plt.subplots(figsize=(3.2, 4.8))
plt.sca(ax)

# (values, x position, colors, colors2, extra BoxPlots keyword arguments)
ssim_series = (
    (SSIMFulls_MK, 0.8, 'black', 'gray', {}),
    (SSIM20_MK, 1.5, 'lightseagreen', 'paleturquoise', {'scatter_alpha': 0.5}),
    (SSIM7_MK, 1.95, 'mediumturquoise', 'paleturquoise', {}),
    (SSIM20NL_MK, 2.6, 'sandybrown', 'peachpuff', {}),
    (SSIM7NL_MK, 3.05, 'burlywood', 'peachpuff', {}),
)
for values, position, main_col, accent_col, extra in ssim_series:
    BoxPlots(np.array(values), np.array([position]), [main_col], [accent_col],
             ax, widths=0.2, scatter=True, **extra)

# Dashed horizontal reference line at SSIM = 0.66.
plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8, 1.5, 1.95, 2.6, 3.05], ['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

# Build the legend patch here so the cell does not depend on another plotting
# cell having been run first.
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
ax.legend(
    handles=[leg_patch3],
    loc='lower left',
    bbox_to_anchor=(-0.1, -0.05),  # fine-tunes the position relative to `loc`
    frameon=False,
    ncols=1,
    fontsize=32,
    columnspacing=0.3,
    handlelength=0.6,
    handletextpad=0.3,
)

if Save:
    plt.savefig(FigLoc+'DKIHCP_SSIM_MK.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Broken-axis box plot of the MK precision values: ax1 is the upper panel and
# ax2 the lower panel; each gets its own y-range near the end of the cell.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.2, 4.8))
fig.subplots_adjust(hspace=0.05)

# Prec7_NLLS_MK is the only series drawn on the upper panel.
y_data = np.array(Prec7_NLLS_MK)
g_pos = np.array([2.5])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)
# NOTE(review): these tick positions do not match the box positions and look
# superseded by later tick calls -- confirm and drop if so.
plt.xticks([1, 1.7, 2, 2.8, 3.1], ['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
plt.yticks(fontsize=24)
ax1.set_xticks([])

plt.sca(ax2)

# All six series on the lower panel: (values, x position, colors, colors2)
lower_panel_series = (
    (PrecFull_SBI_MK, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI_MK, 1.0, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_MK, 1.35, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_MK, 1.8, 'sandybrown', 'peachpuff'),
    (Prec20_NLLS_MK, 2.15, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_MK, 2.5, 'sandybrown', 'peachpuff'),
)
for values, position, main_col, accent_col in lower_panel_series:
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [main_col]
    colors2 = [accent_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=True)

ax1.set_xticks([])

plt.yticks(fontsize=24)

# Slanted markers at the facing corners of the two panels indicate the axis break.
d = .5
kwargs = {
    'marker': [(-1, -d), (1, d)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Hide the spines that face each other across the break.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # no tick labels along the top edge
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the y-axis offset text on ax2
ax2.set_xlim([0.3, 2.7])
ax1.set_xlim(ax2.get_xlim())

# y-ranges of the lower and upper panels.
ax2.set_ylim(0, 0.9)
ax2.set_yticks([0, 0.4, 0.8])
ax1.set_ylim(1, 13)

# Hatched horizontal bands spanning each group of boxes, bounded by two
# percentiles of the NaN-free 'Full' values of that group.
# NOTE(review): the upper bound is the 77th percentile while the lower is the
# 25th; confirm that 77 (rather than 75) is intended.
reference_bands = (
    (PrecFull_NLLS_MK, 1.7, 2.6, WLSFit),
    (PrecFull_SBI_MK, 0.55, 1.5, SBIFit),
)
for reference, x_start, x_stop, band_col in reference_bands:
    finite_vals = np.array(reference)[~np.isnan(reference)]
    x = np.arange(x_start, x_stop, 0.05)
    y1 = np.ones_like(x)*np.percentile(finite_vals, 25)
    y2 = np.ones_like(x)*np.percentile(finite_vals, 77)
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

plt.xticks([0.65, 1, 1.35, 1.8, 2.15, 2.5], ['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'], fontsize=32, rotation=90)
if Save:
    plt.savefig(FigLoc+'DKI_MK_Prec.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

### d

In [None]:
# Box plot of the AK accuracy values, one box per series, on a single axes.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))
fig.subplots_adjust(hspace=0.05)

# (values, x position, colors, colors2)
accuracy_series = (
    (AccMFulls_AK, 0.8, 'black', 'gray'),
    (AccM20_AK, 1.5, 'lightseagreen', 'paleturquoise'),
    (AccM7_AK, 1.95, 'mediumturquoise', 'paleturquoise'),
    (AccM20NL_AK, 2.6, 'sandybrown', 'peachpuff'),
    (AccM7NL_AK, 3.05, 'burlywood', 'peachpuff'),
)
for values, position, main_col, accent_col in accuracy_series:
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [main_col]
    colors2 = [accent_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

plt.xticks([0.8, 1.5, 1.95, 2.6, 3.05], ['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
plt.yticks(fontsize=24)
ax1.set_ylim(0., 0.5)

if Save:
    plt.savefig(FigLoc+'DKIHCP_Acc_AK.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plot of the AK SSIM values, one box per series, all on a single axes.
fig, ax = plt.subplots(figsize=(3.2, 4.8))
plt.sca(ax)

# (values, x position, colors, colors2, extra BoxPlots keyword arguments)
ssim_series = (
    (SSIMFulls_AK, 0.8, 'black', 'gray', {}),
    (SSIM20_AK, 1.5, 'lightseagreen', 'paleturquoise', {'scatter_alpha': 0.5}),
    (SSIM7_AK, 1.95, 'mediumturquoise', 'paleturquoise', {}),
    (SSIM20NL_AK, 2.6, 'sandybrown', 'peachpuff', {}),
    (SSIM7NL_AK, 3.05, 'burlywood', 'peachpuff', {}),
)
for values, position, main_col, accent_col, extra in ssim_series:
    BoxPlots(np.array(values), np.array([position]), [main_col], [accent_col],
             ax, widths=0.2, scatter=True, **extra)

# Dashed horizontal reference line at SSIM = 0.66.
plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8, 1.5, 1.95, 2.6, 3.05], ['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc+'DKIHCP_SSIM_AK.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plot of the AK precision values: three SBI series followed by three
# NLLS series, all on a single axes.
fig, ax = plt.subplots(figsize=(3.2, 4.8))

# (values, x position, colors, colors2)
precision_series = (
    (PrecFull_SBI_AK, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI_AK, 1, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_AK, 1.35, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_AK, 1.8, 'sandybrown', 'peachpuff'),
    (Prec20_NLLS_AK, 2.15, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_AK, 2.5, 'sandybrown', 'peachpuff'),
)
for values, position, main_col, accent_col in precision_series:
    colors = [main_col]
    colors2 = [accent_col]
    y_data = np.array(values)
    g_pos = np.array([position])
    BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True)

# Hatched horizontal bands spanning each group of boxes, bounded by two
# percentiles of the NaN-free 'Full' values of that group.
# NOTE(review): the upper bound is the 77th percentile while the lower is the
# 25th; confirm that 77 (rather than 75) is intended.
reference_bands = (
    (PrecFull_NLLS_AK, 1.7, 2.6, WLSFit),
    (PrecFull_SBI_AK, 0.55, 1.5, SBIFit),
)
for reference, x_start, x_stop, band_col in reference_bands:
    finite_vals = np.array(reference)[~np.isnan(reference)]
    x = np.arange(x_start, x_stop, 0.05)
    y1 = np.ones_like(x)*np.percentile(finite_vals, 25)
    y2 = np.ones_like(x)*np.percentile(finite_vals, 77)
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

plt.xticks([0.65, 1, 1.35, 1.8, 2.15, 2.5], ['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'], fontsize=32, rotation=90)
plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-0.5, 3))
plt.yticks(fontsize=24)
ax.set_xlim([0.4, 2.7])
if Save:
    plt.savefig(FigLoc+'DKI_AK_Prec.pdf', format='pdf', bbox_inches='tight', transparent=True)

### e

In [None]:
# Broken-axis box plot of the RK accuracy values: ax1 is the upper panel and
# ax2 the lower panel; each gets its own y-range near the end of the cell.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.2, 4.8))
fig.subplots_adjust(hspace=0.05)

tick_labels = ['Full', 'Mid', 'Min', 'Mid', 'Min']

# The AccM7NL_RK series is drawn on both panels.
y_data = np.array(AccM7NL_RK)
g_pos = np.array([3.05])
colors = ['burlywood']
colors2 = ['peachpuff']
for panel in (ax1, ax2):
    BoxPlots(y_data, g_pos, colors, colors2, panel, widths=0.2, scatter=True)
# NOTE(review): these tick positions do not match the box positions and look
# superseded by the plt.xticks call further down -- confirm and drop if so.
plt.xticks([1, 1.7, 2, 2.8, 3.1], tick_labels, fontsize=32, rotation=90)
plt.yticks(fontsize=24)
ax1.set_xticks([])

plt.sca(ax2)

# Remaining series are drawn on the lower panel only:
# (values, x position, colors, colors2)
lower_panel_series = (
    (AccMFulls_RK, 0.8, 'black', 'gray'),
    (AccM20_RK, 1.55, 'lightseagreen', 'paleturquoise'),
    (AccM7_RK, 1.95, 'mediumturquoise', 'paleturquoise'),
    (AccM20NL_RK, 2.65, 'sandybrown', 'peachpuff'),
)
for values, position, main_col, accent_col in lower_panel_series:
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [main_col]
    colors2 = [accent_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=True)

plt.xticks([0.8, 1.55, 1.95, 2.65, 3.05], tick_labels, fontsize=32, rotation=90)
ax1.set_xticks([])

plt.yticks(fontsize=24)

# Slanted markers at the facing corners of the two panels indicate the axis break.
d = .5
kwargs = {
    'marker': [(-1, -d), (1, d)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Hide the spines that face each other across the break.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # no tick labels along the top edge
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the y-axis offset text on ax2
ax1.set_xlim(np.array(ax2.get_xlim()))

# y-ranges of the lower and upper panels.
ax2.set_ylim(0, 0.8)
ax2.set_yticks([0, 0.3, 0.6])
ax1.set_ylim(1, 15)

# Legend patches are created here but no legend is drawn in this cell.
leg_patch2 = mpatches.Patch(color='mediumturquoise', label='SBI')
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
leg_patch5 = mpatches.Patch(color='gray', label='Full \nComp.')

if Save:
    plt.savefig(FigLoc+'DKIHCP_Acc_RK.pdf', format='PDF', transparent=True, bbox_inches='tight')
plt.show()

In [None]:
# Box plot of the RK SSIM values, one box per series, all on a single axes.
fig, ax = plt.subplots(figsize=(3.2, 4.8))
plt.sca(ax)

# (values, x position, colors, colors2, extra BoxPlots keyword arguments)
ssim_series = (
    (SSIMFulls_RK, 0.8, 'black', 'gray', {}),
    (SSIM20_RK, 1.5, 'lightseagreen', 'paleturquoise', {'scatter_alpha': 0.5}),
    (SSIM7_RK, 1.95, 'mediumturquoise', 'paleturquoise', {}),
    (SSIM20NL_RK, 2.6, 'sandybrown', 'peachpuff', {}),
    (SSIM7NL_RK, 3.05, 'burlywood', 'peachpuff', {}),
)
for values, position, main_col, accent_col, extra in ssim_series:
    BoxPlots(np.array(values), np.array([position]), [main_col], [accent_col],
             ax, widths=0.2, scatter=True, **extra)

# Dashed horizontal reference line at SSIM = 0.66.
plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8, 1.5, 1.95, 2.6, 3.05], ['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc+'DKIHCP_SSIM_RK.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Broken-axis box plot of the RK precision values: ax1 is the upper panel and
# ax2 the lower panel; each gets its own y-range near the end of the cell.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.2, 4.8))
fig.subplots_adjust(hspace=0.05)

# Prec7_NLLS_RK is the only series drawn on the upper panel.
y_data = np.array(Prec7_NLLS_RK)
g_pos = np.array([2.5])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)
# NOTE(review): these tick positions do not match the box positions and look
# superseded by later tick calls -- confirm and drop if so.
plt.xticks([1, 1.7, 2, 2.8, 3.1], ['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)
plt.yticks(fontsize=24)
ax1.set_xticks([])

plt.sca(ax2)

# All six series on the lower panel: (values, x position, colors, colors2)
lower_panel_series = (
    (PrecFull_SBI_RK, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI_RK, 1.0, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_RK, 1.35, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_RK, 1.8, 'sandybrown', 'peachpuff'),
    (Prec20_NLLS_RK, 2.15, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_RK, 2.5, 'sandybrown', 'peachpuff'),
)
for values, position, main_col, accent_col in lower_panel_series:
    y_data = np.array(values)
    g_pos = np.array([position])
    colors = [main_col]
    colors2 = [accent_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=True)

ax1.set_xticks([])

plt.yticks(fontsize=24)

# Slanted markers at the facing corners of the two panels indicate the axis break.
d = .5
kwargs = {
    'marker': [(-1, -d), (1, d)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Hide the spines that face each other across the break.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # no tick labels along the top edge
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the y-axis offset text on ax2
ax2.set_xlim([0.3, 2.7])
ax1.set_xlim(ax2.get_xlim())

# y-ranges of the lower and upper panels.
ax2.set_ylim(0, 2)
ax2.set_yticks([0, 0.5, 1, 1.5])
ax1.set_ylim(1, 250)

# Hatched horizontal bands spanning each group of boxes, bounded by two
# percentiles of the NaN-free 'Full' values of that group.
# NOTE(review): the upper bound is the 77th percentile while the lower is the
# 25th; confirm that 77 (rather than 75) is intended.
reference_bands = (
    (PrecFull_NLLS_RK, 1.7, 2.6, WLSFit),
    (PrecFull_SBI_RK, 0.55, 1.5, SBIFit),
)
for reference, x_start, x_stop, band_col in reference_bands:
    finite_vals = np.array(reference)[~np.isnan(reference)]
    x = np.arange(x_start, x_stop, 0.05)
    y1 = np.ones_like(x)*np.percentile(finite_vals, 25)
    y2 = np.ones_like(x)*np.percentile(finite_vals, 77)
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

plt.xticks([0.65, 1, 1.35, 1.8, 2.15, 2.5], ['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'], fontsize=32, rotation=90)
if Save:
    plt.savefig(FigLoc+'DKI_RK_Prec.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

## MS

In [None]:
def _farthest_point_indices(points, n_extra):
    """Greedily pick rows of `points` that are far apart from each other.

    Starts from row 0 and, `n_extra` times, adds the not-yet-selected row
    whose minimum Euclidean distance to the already selected rows is largest.

    Parameters
    ----------
    points : array of shape (n_points, n_dims)
    n_extra : int
        Number of rows to add after the initial row 0.

    Returns
    -------
    list of int
        The `n_extra + 1` selected row indices, in order of selection.
    """
    distance_matrix = squareform(pdist(points))
    selected = [0]
    for _ in range(n_extra):
        remaining = list(set(range(len(points))) - set(selected))
        # Minimum distance from each remaining row to any selected row.
        min_distances = np.min(distance_matrix[remaining][:, selected], axis=1)
        selected.append(remaining[np.argmax(min_distances)])
    return selected


gTabsF = []
Dats   = []

gTabs7 = []
gTabs20 = []
Masks = []
WMDir = MSDir+'WM_masks/'
WMs = []
subject_names = ['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
# tqdm wraps the list (not the enumerate object) so it can show the total.
for i,Name in enumerate(tqdm(subject_names)):
    MatDir = MSDir+Name

    F = pmt.read_mat(MatDir+'/data_loaded.mat')
    # Placeholder affine: the affine returned by reslice is not used below
    # before it is overwritten.
    affine = np.ones((4,4))

    data, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))
    _, maskCut = median_otsu(data, vol_idx=range(10, 50),autocrop=False)

    true_indices = np.argwhere(maskCut)

    # Bounding box of the un-cropped mask, used to crop the WM masks below.
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    for x in os.listdir(WMDir):
        if Name in x:
            print(Name)
            WM, affine, img = load_nifti(WMDir+x, return_img=True)
            WM, affine = reslice(WM, affine, (2,2,2), (2.5,2.5,2.5))
            # The first five subjects and the remaining ones are re-oriented
            # differently.
            if(i<5):
                WM_t = np.fliplr(np.swapaxes(WM,0,1))
            else:
                WM_t = np.fliplr(np.flipud(np.swapaxes(WM,0,1)))
            WM_t  = WM_t[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
            WMs.append(WM_t)

    maskdata, mask = median_otsu(data, vol_idx=range(10, 50),autocrop=True)
    axial_middle = maskdata.shape[2] // 2
    Masks.append(mask)
    # Normalise directions to unit length; zero-length rows give NaN and are
    # reset to 0.
    bvecs = (F['direction'].T/np.linalg.norm(F['direction'],axis=1)).T
    bvecs[np.isnan(bvecs)] = 0
    bvals = F['bval']
    bvecs4000 = bvecs[bvals==4000]

    # The b=2000 set gets an extra leading all-zero direction (b=0 entry).
    bvals2000 = np.array([0] + list(bvals[bvals==2000]))
    bvecs2000 = np.vstack([[0,0,0],bvecs[bvals==2000]])

    Dats.append(maskdata[:,:,:,:])

    gTabsF.append(gradient_table(bvals,bvecs))
    if i == 0:
        # Direction subsets are chosen once, on the first subject, and the
        # same indices are reused for every subject.
        # Reduced "7" protocol: b0 row + 6 b=2000 directions, 15 b=4000 directions.
        selected_indices7 = _farthest_point_indices(bvecs2000, 6)
        selected_indices7_2 = _farthest_point_indices(bvecs4000, 14)
        # Reduced "20" protocol: b0 row + 19 b=2000 directions, 28 b=4000 directions.
        selected_indices20 = _farthest_point_indices(bvecs2000, 19)
        selected_indices20_2 = _farthest_point_indices(bvecs4000, 27)

        # Map shell-local indices back to volume indices of the full
        # acquisition. The "[1:]-1" drops the prepended zero row of bvecs2000
        # and undoes its offset.
        # NOTE(review): the leading 0 presumably selects the b=0 volume --
        # confirm that volume 0 is the b0 image.
        volumes2000 = np.where(bvals==2000)[0]
        volumes4000 = np.where(bvals==4000)[0]
        Indxs7 = np.hstack([0,volumes2000[np.array(selected_indices7)[1:]-1],volumes4000[selected_indices7_2]])
        Indxs20 = np.hstack([0,volumes2000[np.array(selected_indices20)[1:]-1],volumes4000[selected_indices20_2]])
    # b-value counts are derived from the selections so they cannot drift
    # apart from the number of selected directions.
    gTabs7.append(gradient_table(
        [0]+[2000]*(len(selected_indices7)-1) + [4000]*len(selected_indices7_2),
        np.vstack([bvecs2000[selected_indices7],bvecs4000[selected_indices7_2]])))
    gTabs20.append(gradient_table(
        [0]+[2000]*(len(selected_indices20)-1) + [4000]*len(selected_indices20_2),
        np.vstack([bvecs2000[selected_indices20],bvecs4000[selected_indices20_2]])))
# Fixed axial slice index used for every subject (the per-subject
# axial_middle computed above is not used here).
axial_middles = [32]*8

In [None]:
def _simulate_dki_training_set():
    """Simulate the DKI training set shared by the three MS networks.

    Draws tensors with GenDTKT and S0 from a BoxUniform over
    [lower_S0, upper_S0], assigns each sample one of the 8 full gradient
    tables at random, and simulates signals with CustomDKISimulator.
    Samples whose signal is negative or exceeds 4*S0 are re-simulated with
    halved KT and a newly drawn DT until none remain.

    Returns
    -------
    Par : ndarray
        DT, KT and S0 stacked column-wise, one row per sample.
    Obs : ndarray
        Simulated signals, one row per sample, one column per volume of the
        full acquisition.
    A : ndarray
        Index of the gradient table used for each sample.
    """
    np.random.seed(1)
    DT,KT = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,30000)

    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([DT.shape[0]])
    S0 = np.array(S0).reshape(len(S0),1)

    A  = np.random.choice(8,DT.shape[0])

    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gTabsF[0].bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm(indx):
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsF[A[i]],S0[i],50)

        # Collect every sample whose signal is out of range.
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): GenDTKT is asked for a single sample here; if it
        # returns one row, every rejected sample of this round receives the
        # same replacement DT -- confirm this is intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    return Par, Obs, A


def _load_or_train_posterior(pickle_path, columns=None, **train_kwargs):
    """Load a pickled posterior, or train one with SNPE and pickle it.

    Parameters
    ----------
    pickle_path : str
        Location of the pickled posterior.
    columns : callable or None
        Zero-argument callable returning the signal columns to keep. It is
        only called when training is needed, so the index array does not have
        to exist when the pickle is already present. None keeps all columns.
    **train_kwargs
        Forwarded to `inference.train`.

    Returns
    -------
    The posterior object.
    """
    if os.path.exists(pickle_path):
        # NOTE: pickle.load can execute arbitrary code from the file; only
        # load network files that were produced locally.
        with open(pickle_path, "rb") as handle:
            return pickle.load(handle)

    Par, Obs, A = _simulate_dki_training_set()
    if columns is not None:
        Obs = Obs[:,columns()]
    # The gradient-table index is appended as the last observation column.
    Obs = np.hstack([Obs,np.expand_dims(A, axis=-1)])

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # pass the simulations to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(**train_kwargs)
    posterior = inference.build_posterior(density_estimator)
    with open(pickle_path, "wb") as handle:
        pickle.dump(posterior, handle)
    # Spoken notification via the `say` shell command.
    os.system("say 'DKI network done'") # or '\7'
    return posterior


# Full acquisition: all volumes; this network alone passes stop_after_epochs=100.
posteriorFull = _load_or_train_posterior(
    f"{network_path}/DKIMultiMSFull_300k.pickle", stop_after_epochs=100)
# Reduced acquisitions: only the volumes listed in Indxs7 / Indxs20.
posteriorMin = _load_or_train_posterior(
    f"{network_path}/DKIMultiMSMin_300k.pickle", columns=lambda: Indxs7)
posteriorMid = _load_or_train_posterior(
    f"{network_path}/DKIMultiMSMid_300k.pickle", columns=lambda: Indxs20)

In [None]:
# Per-subject metric maps obtained with posteriorFull on one axial slice,
# using all acquired volumes.
MKFullArr, RKFullArr, AKFullArr, MKTFullArr, KFAFullArr = [], [], [], [], []
for kk in range(8):
    Dat = Dats[kk]
    ArrShape = Dat.shape[:2]

    # Pixels whose signal sums to exactly zero over all volumes are skipped.
    mask = np.sum(Dats[kk][:, :, axial_middles[kk], :], axis=-1) != 0

    # (row, column) coordinates of the pixels to process.
    indices = np.argwhere(mask)

    def optimize_chunk(pixels):
        """Return (i, j, estimate) for each pixel, where the estimate is the
        histogram mode of 500 posterior samples, per parameter."""
        estimates = []
        for i, j in pixels:
            observation = np.hstack([Dats[kk][i,j,axial_middles[kk], :],kk])
            posterior_samples_1 = posteriorFull.sample((500,), x=observation,show_progress_bars=False)
            modes = np.array([histogram_mode(p) for p in posterior_samples_1.T])
            estimates.append((i, j, modes))
        return estimates

    # Split the pixel list into chunks of ChunkSize and process them in parallel.
    chunked_indices = [indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    ArrShape = mask.shape

    # 22 estimated values per pixel; pixels outside the mask stay zero.
    NoiseEst = np.zeros(list(ArrShape) + [22])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    # Replace the first six values by the tensor with negative eigenvalues
    # clipped; the remaining values are copied unchanged.
    NoiseEst2 =  np.zeros_like(NoiseEst)
    for i, j in np.ndindex(*ArrShape):
        clipped = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j])))
        NoiseEst2[i,j] = np.hstack([clipped,NoiseEst[i,j,6:]])

    MK7  = np.zeros(ArrShape)
    RK7  = np.zeros(ArrShape)
    AK7  = np.zeros(ArrShape)
    MKT7 = np.zeros(ArrShape)
    KFA7 = np.zeros(ArrShape)
    # Position of each map's value in the DKIMetrics output.
    metric_slots = ((MK7, 0), (AK7, 1), (RK7, 2), (MKT7, 3), (KFA7, 4))
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            for metric_map, slot in metric_slots:
                metric_map[i,j] = Metrics[slot]

    MKFullArr.append(MK7)
    RKFullArr.append(RK7)
    AKFullArr.append(AK7)
    MKTFullArr.append(MKT7)
    KFAFullArr.append(KFA7)

    # Quick-look figure: the five maps side by side on a common [0, 1] scale.
    fig,ax = plt.subplots(1,5,figsize=(24,6.4))
    for panel, metric_map in zip(ax, (MK7, RK7, AK7, MKT7, KFA7)):
        plt.sca(panel)
        plt.imshow(metric_map.T,vmin=0,vmax=1)
    plt.show()

In [None]:
# SBI on the "Mid" acquisition (volumes selected by Indxs20): kurtosis maps of
# the middle axial slice of each of the 8 datasets, collected per metric.
MKMidArr, RKMidArr, AKMidArr, MKTMidArr, KFAMidArr = [], [], [], [], []
for kk in range(8):
    Dat = Dats[kk]
    ArrShape = Dat.shape[:2]

    # Pixels to process: signal summed over every volume of the slice is non-zero.
    mask = np.sum(Dats[kk][:, :, axial_middles[kk], :], axis=-1) != 0
    indices = np.argwhere(mask)

    # Slice restricted to the Indxs20 subset of volumes.
    Arr = Dats[kk][:,:,axial_middles[kk], Indxs20]

    def optimize_chunk(pixels):
        """Return a list of (i, j, estimate) for the pixel coordinates in `pixels`.

        For each pixel, 500 samples are drawn from posteriorMid conditioned on
        the pixel signal with kk appended; every sampled parameter is then
        reduced to one value with histogram_mode.
        """
        estimates = []
        for row, col in pixels:
            draws = posteriorMid.sample(
                (500,),
                x=np.hstack([Arr[row,col],kk]),
                show_progress_bars=False,
            )
            modes = np.array([histogram_mode(marginal) for marginal in draws.T])
            estimates.append((row, col, modes))
        return estimates

    # Pixels are handed to the workers in batches of ChunkSize.
    chunked_indices = [indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    # Scatter the 22-value estimates back onto the image grid; pixels outside
    # the mask stay at zero.
    NoiseEst = np.zeros(list(ArrShape) + [22])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    # The leading entries are rebuilt through
    # vals_to_mat -> clip_negative_eigenvalues -> mat_to_vals;
    # entries 6 onwards are copied unchanged.
    NoiseEst2 =  np.zeros_like(NoiseEst)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),
                NoiseEst[i,j,6:],
            ])

    # DKIMetrics is evaluated on every pixel of the slice (masked or not).
    # Its first five outputs are stored as MK, AK, RK, MKT, KFA in that order.
    MK7, RK7, AK7, MKT7, KFA7 = (np.zeros(ArrShape) for _ in range(5))
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK7[i,j], AK7[i,j], RK7[i,j], MKT7[i,j], KFA7[i,j] = (Metrics[n] for n in range(5))

    MKMidArr.append(MK7)
    RKMidArr.append(RK7)
    AKMidArr.append(AK7)
    MKTMidArr.append(MKT7)
    KFAMidArr.append(KFA7)

    # Quick-look figure, panels left to right: MK, RK, AK, MKT, KFA.
    fig,ax = plt.subplots(1,5,figsize=(24,6.4))
    for _axis, _kmap in zip(ax, (MK7, RK7, AK7, MKT7, KFA7)):
        plt.sca(_axis)
        plt.imshow(_kmap.T,vmin=0,vmax=1)
    plt.show()

In [None]:
# SBI on the "Min" acquisition (volumes selected by Indxs7): kurtosis maps of
# the middle axial slice of each of the 8 datasets, collected per metric.
MKMinArr, RKMinArr, AKMinArr, MKTMinArr, KFAMinArr = [], [], [], [], []
for kk in range(8):
    Dat = Dats[kk]
    ArrShape = Dat.shape[:2]

    # Pixels to process: signal summed over every volume of the slice is non-zero.
    mask = np.sum(Dats[kk][:, :, axial_middles[kk], :], axis=-1) != 0
    indices = np.argwhere(mask)

    # Slice restricted to the Indxs7 subset of volumes.
    Arr = Dats[kk][:,:,axial_middles[kk], Indxs7]

    def optimize_chunk(pixels):
        """Return a list of (i, j, estimate) for the pixel coordinates in `pixels`.

        For each pixel, 500 samples are drawn from posteriorMin conditioned on
        the pixel signal with kk appended; every sampled parameter is then
        reduced to one value with histogram_mode.
        """
        estimates = []
        for row, col in pixels:
            draws = posteriorMin.sample(
                (500,),
                x=np.hstack([Arr[row,col],kk]),
                show_progress_bars=False,
            )
            modes = np.array([histogram_mode(marginal) for marginal in draws.T])
            estimates.append((row, col, modes))
        return estimates

    # Pixels are handed to the workers in batches of ChunkSize.
    chunked_indices = [indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    # Scatter the 22-value estimates back onto the image grid; pixels outside
    # the mask stay at zero.
    NoiseEst = np.zeros(list(ArrShape) + [22])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    # The leading entries are rebuilt through
    # vals_to_mat -> clip_negative_eigenvalues -> mat_to_vals;
    # entries 6 onwards are copied unchanged.
    NoiseEst2 =  np.zeros_like(NoiseEst)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),
                NoiseEst[i,j,6:],
            ])

    # DKIMetrics is evaluated on every pixel of the slice (masked or not).
    # Its first five outputs are stored as MK, AK, RK, MKT, KFA in that order.
    MK7, RK7, AK7, MKT7, KFA7 = (np.zeros(ArrShape) for _ in range(5))
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK7[i,j], AK7[i,j], RK7[i,j], MKT7[i,j], KFA7[i,j] = (Metrics[n] for n in range(5))

    MKMinArr.append(MK7)
    RKMinArr.append(RK7)
    AKMinArr.append(AK7)
    MKTMinArr.append(MKT7)
    KFAMinArr.append(KFA7)

    # Quick-look figure, panels left to right: MK, RK, AK, MKT, KFA.
    fig,ax = plt.subplots(1,5,figsize=(24,6.4))
    for _axis, _kmap in zip(ax, (MK7, RK7, AK7, MKT7, KFA7)):
        plt.sca(_axis)
        plt.imshow(_kmap.T,vmin=0,vmax=1)
    plt.show()

In [None]:
# NLLS DKI fit on the full acquisition (gradient tables gTabsF): kurtosis maps
# of the middle axial slice of each of the 8 datasets, collected per metric.
MKFullNLArr, RKFullNLArr, AKFullNLArr, MKTFullNLArr, KFAFullNLArr = [], [], [], [], []
for kk in range(8):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(Dats[kk][:,:,axial_middles[kk],:])
    ArrShape = Dats[kk][:,:,axial_middles[kk],0].shape
    MK_NLFull, AK_NLFull, RK_NLFull, MKT_NLFull, KFA_NLFull = (np.zeros(ArrShape) for _ in range(5))

    # Per pixel, 21 values: the fit's lower_triangular() output followed by its kt.
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])

    # The leading entries are rebuilt through
    # vals_to_mat -> clip_negative_eigenvalues -> mat_to_vals;
    # entries 6 onwards are copied unchanged.
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),
                NoiseEst_NL[i,j,6:],
            ])

    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # Pixels whose summed signal is zero keep their zero metrics.
            if np.sum(Dats[kk][i,j,axial_middles[kk]]) == 0:
                continue
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            # First five DKIMetrics outputs are stored as MK, AK, RK, MKT, KFA.
            (MK_NLFull[i,j], AK_NLFull[i,j], RK_NLFull[i,j],
             MKT_NLFull[i,j], KFA_NLFull[i,j]) = (Metrics[n] for n in range(5))

    MKFullNLArr.append(MK_NLFull)
    RKFullNLArr.append(RK_NLFull)
    AKFullNLArr.append(AK_NLFull)
    MKTFullNLArr.append(MKT_NLFull)
    KFAFullNLArr.append(KFA_NLFull)

    # Quick-look figure, panels left to right: MK, RK, AK, MKT, KFA.
    fig,ax = plt.subplots(1,5,figsize=(24,6.4))
    for _axis, _kmap in zip(ax, (MK_NLFull, RK_NLFull, AK_NLFull, MKT_NLFull, KFA_NLFull)):
        plt.sca(_axis)
        plt.imshow(_kmap.T,vmin=0,vmax=1)
    plt.show()

In [None]:
# NLLS DKI fit on the "Mid" acquisition (gradient tables gTabs20, volumes
# Indxs20): kurtosis maps of the middle axial slice of each of the 8 datasets.
MK20NLArr, RK20NLArr, AK20NLArr, MKT20NLArr, KFA20NLArr = [], [], [], [], []
for kk in range(8):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs20[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(Dats[kk][:,:,axial_middles[kk],Indxs20])
    ArrShape = Dats[kk][:,:,axial_middles[kk],0].shape
    MK_NLFull, AK_NLFull, RK_NLFull, MKT_NLFull, KFA_NLFull = (np.zeros(ArrShape) for _ in range(5))

    # Per pixel, 21 values: the fit's lower_triangular() output followed by its kt.
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])

    # The leading entries are rebuilt through
    # vals_to_mat -> clip_negative_eigenvalues -> mat_to_vals;
    # entries 6 onwards are copied unchanged.
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([
                mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),
                NoiseEst_NL[i,j,6:],
            ])

    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # Pixels whose summed signal (over all volumes) is zero keep zero metrics.
            if np.sum(Dats[kk][i,j,axial_middles[kk]]) == 0:
                continue
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            # First five DKIMetrics outputs are stored as MK, AK, RK, MKT, KFA.
            (MK_NLFull[i,j], AK_NLFull[i,j], RK_NLFull[i,j],
             MKT_NLFull[i,j], KFA_NLFull[i,j]) = (Metrics[n] for n in range(5))

    MK20NLArr.append(MK_NLFull)
    RK20NLArr.append(RK_NLFull)
    AK20NLArr.append(AK_NLFull)
    MKT20NLArr.append(MKT_NLFull)
    KFA20NLArr.append(KFA_NLFull)

    # Quick-look figure, panels left to right: MK, RK, AK, MKT, KFA.
    fig,ax = plt.subplots(1,5,figsize=(24,6.4))
    for _axis, _kmap in zip(ax, (MK_NLFull, RK_NLFull, AK_NLFull, MKT_NLFull, KFA_NLFull)):
        plt.sca(_axis)
        plt.imshow(_kmap.T,vmin=0,vmax=1)
    plt.show()

In [None]:
# NLLS DKI fit on the "Min" acquisition (gradient tables gTabs7, volumes
# Indxs7): kurtosis maps of the middle axial slice of each of the 8 datasets,
# collected per metric.
MK7NLArr = []
RK7NLArr = []
AK7NLArr = []
MKT7NLArr = []
KFA7NLArr = []
for kk in range(8):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(Dats[kk][:,:,axial_middles[kk],Indxs7])
    ArrShape = Dats[kk][:,:,axial_middles[kk],0].shape
    MK_NLFull  = np.zeros(ArrShape)
    AK_NLFull  = np.zeros(ArrShape)
    RK_NLFull = np.zeros(ArrShape)
    MKT_NLFull = np.zeros(ArrShape)
    KFA_NLFull = np.zeros(ArrShape)

    # Per pixel, 21 values: the fit's lower_triangular() output followed by its kt.
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])

    # The leading entries are rebuilt through
    # vals_to_mat -> clip_negative_eigenvalues -> mat_to_vals;
    # entries 6 onwards are copied unchanged.
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])

    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # Pixels whose summed signal (over all volumes) is zero keep zero metrics.
            if np.sum(Dats[kk][i,j,axial_middles[kk]]) == 0:
                continue
            # First five DKIMetrics outputs are stored as MK, AK, RK, MKT, KFA.
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NLFull[i,j] = Metrics[0]
            AK_NLFull[i,j] = Metrics[1]
            RK_NLFull[i,j] = Metrics[2]
            MKT_NLFull[i,j] = Metrics[3]
            KFA_NLFull[i,j] = Metrics[4]

    MK7NLArr.append(MK_NLFull)
    RK7NLArr.append(RK_NLFull)
    AK7NLArr.append(AK_NLFull)
    MKT7NLArr.append(MKT_NLFull)
    KFA7NLArr.append(KFA_NLFull)

    # Quick-look figure, panels left to right: MK, RK, AK, MKT, KFA.
    # RK/AK were previously drawn in the opposite order here, unlike the other
    # SBI/NLLS cells; the panels carry no titles, so the order must match.
    fig,ax = plt.subplots(1,5,figsize=(24,6.4))
    for _axis, _kmap in zip(ax, (MK_NLFull, RK_NLFull, AK_NLFull, MKT_NLFull, KFA_NLFull)):
        plt.sca(_axis)
        plt.imshow(_kmap.T,vmin=0,vmax=1)
    plt.show()

In [None]:
# MK agreement metrics over the 8 datasets (middle axial slice of each).
#   Acc*  : mean absolute difference between two MK maps inside Masks.
#   SSIM* : masked_local_ssim of the same pairs (win_size=7).
#   Prec* : standard deviation of MK where the WMs map exceeds 0.8.
# Pairs: SBI Min/Mid vs SBI Full, SBI Full vs NLLS Full, NLLS Min/Mid vs NLLS Full.
AccM7_MK, AccM20_MK, AccMFulls_MK = [], [], []
AccM7NL_MK, AccM20NL_MK = [], []

SSIM7_MK, SSIM20_MK, SSIMFulls_MK = [], [], []
SSIM7NL_MK, SSIM20NL_MK = [], []

for i in range(8):
    Ma = Masks[i][:,:,axial_middles[i]]

    for acc_list, M7, MF in (
        (AccM7_MK,     MKMinArr[i],  MKFullArr[i]),
        (AccM20_MK,    MKMidArr[i],  MKFullArr[i]),
        (AccMFulls_MK, MKFullArr[i], MKFullNLArr[i]),
        (AccM7NL_MK,   MK7NLArr[i],  MKFullNLArr[i]),
        (AccM20NL_MK,  MK20NLArr[i], MKFullNLArr[i]),
    ):
        acc_list.append(np.mean(np.abs(M7-MF)[Ma]))

    for ssim_list, NS1, NS2 in (
        (SSIM7_MK,     MKMinArr[i],  MKFullArr[i]),
        (SSIM20_MK,    MKMidArr[i],  MKFullArr[i]),
        (SSIMFulls_MK, MKFullArr[i], MKFullNLArr[i]),
        (SSIM7NL_MK,   MK7NLArr[i],  MKFullNLArr[i]),
        (SSIM20NL_MK,  MK20NLArr[i], MKFullNLArr[i]),
    ):
        result = masked_local_ssim(NS1, NS2, Ma, win_size=7)
        ssim_list.append(result)


Prec7_SBI_MK, Prec20_SBI_MK, PrecFull_SBI_MK = [], [], []
Prec7_NLLS_MK, Prec20_NLLS_MK, PrecFull_NLLS_MK = [], [], []

for i in range(8):
    # Pixels of the slice where WMs is above 0.8 (presumably a white-matter
    # probability map - confirm against where WMs is loaded).
    wm_sel = WMs[i][:,:,axial_middles[i]]>0.8
    for prec_list, kmaps in (
        (Prec7_SBI_MK,     MKMinArr),
        (Prec20_SBI_MK,    MKMidArr),
        (PrecFull_SBI_MK,  MKFullArr),
        (Prec7_NLLS_MK,    MK7NLArr),
        (Prec20_NLLS_MK,   MK20NLArr),
        (PrecFull_NLLS_MK, MKFullNLArr),
    ):
        prec_list.append(np.std(kmaps[i][wm_sel]))


In [None]:
# AK agreement metrics over the 8 datasets (middle axial slice of each).
#   Acc*  : mean absolute difference between two AK maps inside Masks.
#   SSIM* : masked_local_ssim of the same pairs (win_size=7).
#   Prec* : standard deviation of AK where the WMs map exceeds 0.8.
# Pairs: SBI Min/Mid vs SBI Full, SBI Full vs NLLS Full, NLLS Min/Mid vs NLLS Full.
AccM7_AK = []
AccM20_AK = []
AccMFulls_AK = []

AccM7NL_AK = []
AccM20NL_AK = []

SSIM7_AK = []
SSIM20_AK = []
SSIMFulls_AK = []

SSIM7NL_AK = []
SSIM20NL_AK = []


def _clip_and_smooth(kmap):
    """Clip a map to [0, 1] and apply a Gaussian filter with sigma=0.5."""
    return gaussian_filter(np.clip(kmap,0,1), sigma=0.5)


for i in range(8):
    Ma = Masks[i][:,:,axial_middles[i]]

    for acc_list, M7, MF in (
        (AccM7_AK,     AKMinArr[i],  AKFullArr[i]),
        (AccM20_AK,    AKMidArr[i],  AKFullArr[i]),
        (AccMFulls_AK, AKFullArr[i], AKFullNLArr[i]),
        (AccM7NL_AK,   AK7NLArr[i],  AKFullNLArr[i]),
        (AccM20NL_AK,  AK20NLArr[i], AKFullNLArr[i]),
    ):
        acc_list.append(np.mean(np.abs(M7-MF)[Ma]))

    # Last field: whether both images are clipped and smoothed before SSIM.
    # The NLLS-Mid pair used to smooth only its first image; both images are
    # now preprocessed identically, as for the NLLS-Min pair.
    # NOTE(review): the SBI-Full vs NLLS-Full pair is compared on raw maps, as
    # in the original code - confirm that this is intentional.
    for ssim_list, NS1, NS2, preprocess in (
        (SSIM7_AK,     AKMinArr[i],  AKFullArr[i],   True),
        (SSIM20_AK,    AKMidArr[i],  AKFullArr[i],   True),
        (SSIMFulls_AK, AKFullArr[i], AKFullNLArr[i], False),
        (SSIM7NL_AK,   AK7NLArr[i],  AKFullNLArr[i], True),
        (SSIM20NL_AK,  AK20NLArr[i], AKFullNLArr[i], True),
    ):
        if preprocess:
            NS1 = _clip_and_smooth(NS1)
            NS2 = _clip_and_smooth(NS2)
        result = masked_local_ssim(NS1, NS2, Ma, win_size=7)
        ssim_list.append(result)

Prec7_SBI_AK = []
Prec20_SBI_AK = []
PrecFull_SBI_AK = []

Prec7_NLLS_AK = []
Prec20_NLLS_AK = []
PrecFull_NLLS_AK = []
for i in range(8):
    # Pixels of the slice where WMs is above 0.8 (presumably a white-matter
    # probability map - confirm against where WMs is loaded).
    wm_sel = WMs[i][:,:,axial_middles[i]]>0.8
    Prec7_SBI_AK.append(np.std(AKMinArr[i][wm_sel]))
    Prec20_SBI_AK.append(np.std(AKMidArr[i][wm_sel]))
    PrecFull_SBI_AK.append(np.std(AKFullArr[i][wm_sel]))

    Prec7_NLLS_AK.append(np.std(AK7NLArr[i][wm_sel]))
    Prec20_NLLS_AK.append(np.std(AK20NLArr[i][wm_sel]))
    PrecFull_NLLS_AK.append(np.std(AKFullNLArr[i][wm_sel]))



In [None]:
# RK agreement metrics over the 8 datasets (middle axial slice of each).
#   Acc*  : mean absolute difference between two RK maps inside Masks.
#   SSIM* : masked_local_ssim of the same pairs (win_size=7).
#   Prec* : standard deviation of RK where the WMs map exceeds 0.8.
# Pairs: SBI Min/Mid vs SBI Full, SBI Full vs NLLS Full, NLLS Min/Mid vs NLLS Full.
AccM7_RK, AccM20_RK, AccMFulls_RK = [], [], []
AccM7NL_RK, AccM20NL_RK = [], []

SSIM7_RK, SSIM20_RK, SSIMFulls_RK = [], [], []
SSIM7NL_RK, SSIM20NL_RK = [], []

for i in range(8):
    Ma = Masks[i][:,:,axial_middles[i]]

    for acc_list, M7, MF in (
        (AccM7_RK,     RKMinArr[i],  RKFullArr[i]),
        (AccM20_RK,    RKMidArr[i],  RKFullArr[i]),
        (AccMFulls_RK, RKFullArr[i], RKFullNLArr[i]),
        (AccM7NL_RK,   RK7NLArr[i],  RKFullNLArr[i]),
        (AccM20NL_RK,  RK20NLArr[i], RKFullNLArr[i]),
    ):
        acc_list.append(np.mean(np.abs(M7-MF)[Ma]))

    for ssim_list, NS1, NS2 in (
        (SSIM7_RK,     RKMinArr[i],  RKFullArr[i]),
        (SSIM20_RK,    RKMidArr[i],  RKFullArr[i]),
        (SSIMFulls_RK, RKFullArr[i], RKFullNLArr[i]),
        (SSIM7NL_RK,   RK7NLArr[i],  RKFullNLArr[i]),
        (SSIM20NL_RK,  RK20NLArr[i], RKFullNLArr[i]),
    ):
        result = masked_local_ssim(NS1, NS2, Ma, win_size=7)
        ssim_list.append(result)


Prec7_SBI_RK, Prec20_SBI_RK, PrecFull_SBI_RK = [], [], []
Prec7_NLLS_RK, Prec20_NLLS_RK, PrecFull_NLLS_RK = [], [], []

for i in range(8):
    # Pixels of the slice where WMs is above 0.8 (presumably a white-matter
    # probability map - confirm against where WMs is loaded).
    wm_sel = WMs[i][:,:,axial_middles[i]]>0.8
    for prec_list, kmaps in (
        (Prec7_SBI_RK,     RKMinArr),
        (Prec20_SBI_RK,    RKMidArr),
        (PrecFull_SBI_RK,  RKFullArr),
        (Prec7_NLLS_RK,    RK7NLArr),
        (Prec20_NLLS_RK,   RK20NLArr),
        (PrecFull_NLLS_RK, RKFullNLArr),
    ):
        prec_list.append(np.std(kmaps[i][wm_sel]))


### f

In [None]:
# MK accuracy figure: one box per comparison, each built from the 8
# per-dataset mean absolute differences in the AccM*_MK lists.
fig, ax2 = plt.subplots(1,1,figsize=(3.2,4.8))
plt.sca(ax2)

# (values, x position, colors, colors2) per box, in drawing order.
for box_vals, box_x, colors, colors2 in (
    (AccM7NL_MK,   3.05, ['burlywood'],       ['peachpuff']),      # NLLS Min vs NLLS Full
    (AccMFulls_MK, 0.8,  ['black'],           ['gray']),           # SBI Full vs NLLS Full
    (AccM20_MK,    1.55, ['lightseagreen'],   ['paleturquoise']),  # SBI Mid vs SBI Full
    (AccM7_MK,     1.95, ['mediumturquoise'], ['paleturquoise']),  # SBI Min vs SBI Full
    (AccM20NL_MK,  2.65, ['sandybrown'],      ['peachpuff']),      # NLLS Mid vs NLLS Full
):
    y_data = np.array(box_vals)
    g_pos = np.array([box_x])
    BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

# Tick positions coincide with the box positions above. An earlier xticks call
# with different positions was removed: it was overwritten by this one.
plt.xticks([0.8,1.55,1.95,2.65,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
plt.yticks(fontsize=24)

# Marker settings left over from a broken-axis layout; nothing in this cell
# uses them. Kept because notebook globals may be read by other cells.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, linestyle="none", color='k', mec='k', mew=1, clip_on=False)

ax2.spines.top.set_visible(False)
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the y-axis offset/exponent text

# Legend handles. They are not drawn here; the SSIM cell that follows passes
# leg_patch3 to ax.legend.
leg_patch2 = mpatches.Patch(color='mediumturquoise', label='SBI')
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
leg_patch5 = mpatches.Patch(color='gray', label='Full \nComp.')

ax2.set_xlim(0.5,3.5)
if Save:
    plt.savefig(FigLoc+'DKI_Acc_MK_MS.pdf',format='PDF',transparent=True,bbox_inches='tight')
plt.show()

In [None]:
# MK structural-similarity figure: one box per comparison, each built from the
# 8 per-dataset masked_local_ssim values in the SSIM*_MK lists.
fig, ax = plt.subplots(figsize=(3.2,4.8))
plt.sca(ax)

# (values, x position, colors, colors2, extra BoxPlots keywords) per box.
for box_vals, box_x, colors, colors2, extra in (
    (SSIMFulls_MK, 0.8,  ['black'],           ['gray'],          {}),
    (SSIM20_MK,    1.5,  ['lightseagreen'],   ['paleturquoise'], {'scatter_alpha': 0.5}),
    (SSIM7_MK,     1.95, ['mediumturquoise'], ['paleturquoise'], {}),
    (SSIM20NL_MK,  2.6,  ['sandybrown'],      ['peachpuff'],     {}),
    (SSIM7NL_MK,   3.05, ['burlywood'],       ['peachpuff'],     {}),
):
    y_data = np.array(box_vals)
    g_pos = np.array([box_x])
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True,**extra)

# NOTE(review): the meaning of the 0.66 reference level is not stated in this
# cell - document its origin.
plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8,1.5,1.95,2.6,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

# Build the legend handle here so the cell also runs when the accuracy cell
# that used to provide leg_patch3 has not been executed.
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
ax.legend(
    handles=[leg_patch3],
    loc='lower left',
    frameon=False, ncols=1,
    fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,
    bbox_to_anchor= (-0.1,-0.05))  # nudges the legend off the 'lower left' corner

# NOTE(review): the AK counterpart of this figure is saved with the prefix
# 'DKI_' instead of 'DKIHCP_' - confirm which name is intended.
if Save:
    plt.savefig(FigLoc+'DKIHCP_SSIM_MK_MS.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# MK precision figure: boxes of the per-dataset standard deviations stored in
# the Prec*_MK lists, SBI on the left and NLLS on the right.
fig,ax = plt.subplots(figsize=(3.2,4.8))

# (colors, colors2, ((values, x position), ...)) per method.
for colors, colors2, boxes in (
    (['mediumturquoise'], ['paleturquoise'],
     ((PrecFull_SBI_MK, 0.65), (Prec20_SBI_MK, 1), (Prec7_SBI_MK, 1.35))),
    (['sandybrown'], ['peachpuff'],
     ((PrecFull_NLLS_MK, 1.8), (Prec20_NLLS_MK, 2.15), (Prec7_NLLS_MK, 2.5))),
):
    for box_vals, box_x in boxes:
        y_data = np.array(box_vals)
        g_pos = np.array([box_x])
        BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Hatched band behind each method group, spanning two percentiles of that
# method's full-acquisition values (NaNs dropped first).
# NOTE(review): the upper percentile is 77, not 75 - confirm this is not a typo.
for band_vals, (x_start, x_stop), band_color in (
    (PrecFull_NLLS_MK, (1.7, 2.6),  WLSFit),
    (PrecFull_SBI_MK,  (0.55, 1.5), SBIFit),
):
    finite_vals = np.array(band_vals)[~np.isnan(band_vals)]
    x = np.arange(x_start,x_stop,0.05)
    y1 = np.ones_like(x)*np.percentile(finite_vals, 25)
    y2 = np.ones_like(x)*np.percentile(finite_vals, 77)
    plt.fill_between(x,y1,y2,color=band_color,zorder=10,alpha=0.2,hatch='//')

plt.xticks([0.65,1,1.35,1.8,2.15,2.5],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)
ax.set_xlim([0.4,2.7])
# The file name used to say 'RK' although the data plotted here are MK.
if Save:
    plt.savefig(FigLoc+'DKI_MK_Prec_MS.pdf',format='pdf',bbox_inches='tight',transparent=True)

### g

In [None]:
# AK accuracy figure: one box per comparison, each built from the 8
# per-dataset mean absolute differences in the AccM*_AK lists.
fig, ax2 = plt.subplots(1,1,figsize=(3.2,4.8))
plt.sca(ax2)

# (values, x position, colors, colors2) per box, in drawing order.
for box_vals, box_x, colors, colors2 in (
    (AccM7NL_AK,   3.05, ['burlywood'],       ['peachpuff']),      # NLLS Min vs NLLS Full
    (AccMFulls_AK, 0.8,  ['black'],           ['gray']),           # SBI Full vs NLLS Full
    (AccM20_AK,    1.55, ['lightseagreen'],   ['paleturquoise']),  # SBI Mid vs SBI Full
    (AccM7_AK,     1.95, ['mediumturquoise'], ['paleturquoise']),  # SBI Min vs SBI Full
    (AccM20NL_AK,  2.65, ['sandybrown'],      ['peachpuff']),      # NLLS Mid vs NLLS Full
):
    y_data = np.array(box_vals)
    g_pos = np.array([box_x])
    BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

# Tick positions coincide with the box positions above. An earlier xticks call
# with different positions was removed: it was overwritten by this one.
plt.xticks([0.8,1.55,1.95,2.65,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
plt.yticks(fontsize=24)

# Marker settings left over from a broken-axis layout; nothing in this cell
# uses them. Kept because notebook globals may be read by other cells.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, linestyle="none", color='k', mec='k', mew=1, clip_on=False)

ax2.spines.top.set_visible(False)
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the y-axis offset/exponent text

# Legend handles. They are not drawn here; the SSIM cell that follows passes
# leg_patch3 to ax.legend.
leg_patch2 = mpatches.Patch(color='mediumturquoise', label='SBI')
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
leg_patch5 = mpatches.Patch(color='gray', label='Full \nComp.')

ax2.set_xlim(0.5,3.5)
if Save:
    plt.savefig(FigLoc+'DKI_Acc_AK_MS.pdf',format='PDF',transparent=True,bbox_inches='tight')
plt.show()

In [None]:
# AK structural-similarity figure: one box per comparison, each built from the
# 8 per-dataset masked_local_ssim values in the SSIM*_AK lists.
fig, ax = plt.subplots(figsize=(3.2,4.8))
plt.sca(ax)

# (values, x position, colors, colors2, extra BoxPlots keywords) per box.
for box_vals, box_x, colors, colors2, extra in (
    (SSIMFulls_AK, 0.8,  ['black'],           ['gray'],          {}),
    (SSIM20_AK,    1.5,  ['lightseagreen'],   ['paleturquoise'], {'scatter_alpha': 0.5}),
    (SSIM7_AK,     1.95, ['mediumturquoise'], ['paleturquoise'], {}),
    (SSIM20NL_AK,  2.6,  ['sandybrown'],      ['peachpuff'],     {}),
    (SSIM7NL_AK,   3.05, ['burlywood'],       ['peachpuff'],     {}),
):
    y_data = np.array(box_vals)
    g_pos = np.array([box_x])
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True,**extra)

# NOTE(review): the meaning of the 0.66 reference level is not stated in this
# cell - document its origin.
plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8,1.5,1.95,2.6,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

# Build the legend handle here so the cell also runs when the accuracy cell
# that used to provide leg_patch3 has not been executed.
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
ax.legend(
    handles=[leg_patch3],
    loc='lower left',
    frameon=False, ncols=1,
    fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,
    bbox_to_anchor= (-0.1,-0.05))  # nudges the legend off the 'lower left' corner

if Save:
    plt.savefig(FigLoc+'DKI_SSIM_AK_MS.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# AK precision figure: boxes of the per-dataset standard deviations stored in
# the Prec*_AK lists, SBI on the left and NLLS on the right.
fig,ax = plt.subplots(figsize=(3.2,4.8))

# (colors, colors2, ((values, x position), ...)) per method.
for colors, colors2, boxes in (
    (['mediumturquoise'], ['paleturquoise'],
     ((PrecFull_SBI_AK, 0.65), (Prec20_SBI_AK, 1), (Prec7_SBI_AK, 1.35))),
    (['sandybrown'], ['peachpuff'],
     ((PrecFull_NLLS_AK, 1.8), (Prec20_NLLS_AK, 2.15), (Prec7_NLLS_AK, 2.5))),
):
    for box_vals, box_x in boxes:
        y_data = np.array(box_vals)
        g_pos = np.array([box_x])
        BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Hatched band behind each method group, spanning two percentiles of that
# method's full-acquisition values (NaNs dropped first).
# NOTE(review): the upper percentile is 77, not 75 - confirm this is not a typo.
for band_vals, (x_start, x_stop), band_color in (
    (PrecFull_NLLS_AK, (1.7, 2.6),  WLSFit),
    (PrecFull_SBI_AK,  (0.55, 1.5), SBIFit),
):
    finite_vals = np.array(band_vals)[~np.isnan(band_vals)]
    x = np.arange(x_start,x_stop,0.05)
    y1 = np.ones_like(x)*np.percentile(finite_vals, 25)
    y2 = np.ones_like(x)*np.percentile(finite_vals, 77)
    plt.fill_between(x,y1,y2,color=band_color,zorder=10,alpha=0.2,hatch='//')

plt.xticks([0.65,1,1.35,1.8,2.15,2.5],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)
ax.set_xlim([0.4,2.7])
if Save:
    plt.savefig(FigLoc+'DKI_AK_Prec_MS.pdf',format='pdf',bbox_inches='tight',transparent=True)

### h

In [None]:
# RK accuracy figure: one box per comparison, each built from the 8
# per-dataset mean absolute differences in the AccM*_RK lists.
fig, ax2 = plt.subplots(1,1,figsize=(3.2,4.8))
plt.sca(ax2)

# (values, x position, colors, colors2) per box, in drawing order.
for box_vals, box_x, colors, colors2 in (
    (AccM7NL_RK,   3.05, ['burlywood'],       ['peachpuff']),      # NLLS Min vs NLLS Full
    (AccMFulls_RK, 0.8,  ['black'],           ['gray']),           # SBI Full vs NLLS Full
    (AccM20_RK,    1.55, ['lightseagreen'],   ['paleturquoise']),  # SBI Mid vs SBI Full
    (AccM7_RK,     1.95, ['mediumturquoise'], ['paleturquoise']),  # SBI Min vs SBI Full
    (AccM20NL_RK,  2.65, ['sandybrown'],      ['peachpuff']),      # NLLS Mid vs NLLS Full
):
    y_data = np.array(box_vals)
    g_pos = np.array([box_x])
    BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

# Tick positions coincide with the box positions above. An earlier xticks call
# with different positions was removed: it was overwritten by this one.
plt.xticks([0.8,1.55,1.95,2.65,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
plt.yticks(fontsize=24)

# Marker settings left over from a broken-axis layout; nothing in this cell
# uses them. Kept because notebook globals may be read by other cells.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, linestyle="none", color='k', mec='k', mew=1, clip_on=False)

ax2.spines.top.set_visible(False)
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the y-axis offset/exponent text

# Legend handles. They are not drawn here; the SSIM cell that follows passes
# leg_patch3 to ax.legend.
leg_patch2 = mpatches.Patch(color='mediumturquoise', label='SBI')
leg_patch3 = mpatches.Patch(color='sandybrown', label='NLLS')
leg_patch5 = mpatches.Patch(color='gray', label='Full \nComp.')

ax2.set_xlim(0.5,3.5)
if Save:
    plt.savefig(FigLoc+'DKI_Acc_RK_MS.pdf',format='PDF',transparent=True,bbox_inches='tight')
plt.show()

In [None]:
# Box plots of the SSIM*_RK arrays (written to DKIHCP_SSIM_RK_MS.pdf when Save is set).
fig, ax = plt.subplots(figsize=(3.2,4.8))
plt.sca(ax)

# (values, x position, `colors`, `colors2`, extra BoxPlots keywords), in drawing order.
# The original cell ended with a second SSIM7NL_RK assignment at x=3.1 that was
# never passed to BoxPlots; that dead block has been dropped.
ssim_boxes = [
    (SSIMFulls_RK, 0.8,  'black',           'gray',          {}),
    (SSIM20_RK,    1.5,  'lightseagreen',   'paleturquoise', {'scatter_alpha': 0.5}),
    (SSIM7_RK,     1.95, 'mediumturquoise', 'paleturquoise', {}),
    (SSIM20NL_RK,  2.6,  'sandybrown',      'peachpuff',     {}),
    (SSIM7NL_RK,   3.05, 'burlywood',       'peachpuff',     {}),
]
for box_values, box_pos, box_col, box_col2, box_extra in ssim_boxes:
    y_data = np.array(box_values)
    g_pos = np.array([box_pos])
    colors = [box_col]
    colors2 = [box_col2]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True,**box_extra)

plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8,1.5,1.95,2.6,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

# Build the legend handle here so this cell does not depend on another
# plotting cell having been executed first.
nlls_patch = mpatches.Patch(color='sandybrown', label='NLLS')
ax.legend(
    handles=[nlls_patch],
    loc='lower left',
    frameon=False, ncols=1,
    fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,
    bbox_to_anchor=(-0.1,-0.05))

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_RK_MS.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Box plots of the Prec*_RK arrays: three SBI boxes followed by three NLLS
# boxes, plus a hatched horizontal band per group derived from the "Full" array.
fig,ax = plt.subplots(figsize=(3.2,4.8))

prec_groups = [
    (['mediumturquoise'], ['paleturquoise'],
     [(PrecFull_SBI_RK, 0.65), (Prec20_SBI_RK, 1), (Prec7_SBI_RK, 1.35)]),
    (['sandybrown'], ['peachpuff'],
     [(PrecFull_NLLS_RK, 1.8), (Prec20_NLLS_RK, 2.15), (Prec7_NLLS_RK, 2.5)]),
]
for colors, colors2, group_members in prec_groups:
    for member_values, member_pos in group_members:
        y_data = np.array(member_values)
        g_pos = np.array([member_pos])
        BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Reference bands spanning each group, computed from the NaN-free "Full" values.
# NOTE(review): the upper edge uses the 77th percentile; if an interquartile
# band was intended this should be 75 -- confirm before changing the figure.
prec_bands = [
    (np.arange(1.7,2.6,0.05),  PrecFull_NLLS_RK, WLSFit),
    (np.arange(0.55,1.5,0.05), PrecFull_SBI_RK,  SBIFit),
]
for x, band_source, band_color in prec_bands:
    band_values = np.array(band_source)[~np.isnan(band_source)]
    y1 = np.ones_like(x)*np.percentile(band_values, 25)
    y2 = np.ones_like(x)*np.percentile(band_values, 77)
    plt.fill_between(x,y1,y2,color=band_color,zorder=10,alpha=0.2,hatch='//')

plt.xticks([0.65,1,1.35,1.8,2.15,2.5],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-0.5,3))
plt.yticks(fontsize=24)
ax.set_xlim([0.4,2.7])
if Save: plt.savefig(FigLoc+'DKI_RK_Prec_MS.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Figure 5

In [None]:
def Par_frac(i,j,Mat):
    """Return ``(i, j, [FA, MD])`` for the tensor values stored at ``Mat[i, j]``.

    ``Mat[i, j]`` is converted with ``vals_to_mat``; MD is the mean of the
    eigenvalues returned by ``np.linalg.eigh`` and FA comes from ``FracAni``
    called with those eigenvalues and MD. The indices are echoed back so the
    caller can place the result.
    """
    eigenvalues = np.linalg.eigh(vals_to_mat(Mat[i,j]))[0]
    mean_diff = eigenvalues.mean()
    frac_aniso = FracAni(eigenvalues,mean_diff)
    return i, j, [frac_aniso,mean_diff]

In [None]:
# Timing parameters passed to CombSignal_poisson; Delta has one entry per
# acquisition block (blocks are simulated with Delta[0], Delta[1], Delta[2]).
# NOTE(review): both were labelled "ms", but the magnitudes look like seconds
# (i.e. 17/35/61 ms and 7 ms) -- confirm the unit CombSignal_poisson expects.
Delta = [0.017, 0.035, 0.061]             # labelled "ms" originally; see note above
delta = 0.007           # labelled "ms" originally; see note above

In [None]:
# Synthetic gradient table: n_pts random hemisphere directions are relaxed with
# disperse_charges, a zero vector is prepended, and the resulting 91-row block
# (b = 0, 30 x b = 2000, 60 x b = 4000) is repeated three times.
np.random.seed(10)
n_pts = 90
theta = np.pi * np.random.random(n_pts)
phi = 2 * np.pi * np.random.random(n_pts)
hsph_initial = HemiSphere(theta=theta, phi=phi)
hsph_updated, potential = disperse_charges(hsph_initial, 5000)
vertices = hsph_updated.vertices
values = np.ones(31)

# One acquisition block: zero vector first, then the dispersed directions.
block_bvecs = np.insert(np.vstack((vertices)), 0, np.array([0, 0, 0]), axis=0)
block_bvals = np.hstack((0,[2000] * 30,[4000]*60))

# Three identical blocks stacked one after another.
bvecs = np.vstack([block_bvecs] * 3)
bvals = np.hstack([block_bvals] * 3)

In [None]:
def _farthest_point_subset(points, n_select):
    """Greedily pick ``n_select`` mutually distant rows of ``points``.

    Selection starts from row 0; each further row is the one whose distance to
    its nearest already-selected row is largest (first such row on ties).
    Returns the selected row indices as a list, in selection order.
    """
    if n_select > len(points):
        raise ValueError("cannot select more rows than are available")
    pairwise = squareform(pdist(points))
    chosen = [0]
    while len(chosen) < n_select:
        nearest_chosen = pairwise[:, chosen].min(axis=1)
        nearest_chosen[chosen] = -np.inf  # rows already chosen are not candidates
        chosen.append(int(np.argmax(nearest_chosen)))
    return chosen


# Reduced protocol: 3 directions per shell (b = 2000 and b = 4000) taken from
# the first acquisition block. The earlier comments spoke of "7 points", but
# the loops only ever added two directions to the starting one.
block_len = n_pts + 1
true_indices = []
for shell_bval in (2000, 4000):
    # Row numbers of this shell inside the first block. Using them directly
    # avoids searching the tiled bvecs array for floating-point-equal rows.
    shell_rows = np.flatnonzero(bvals[:block_len] == shell_bval)
    selected_indices = _farthest_point_subset(bvecs[:block_len][shell_rows], 3)
    true_indices.extend(shell_rows[selected_indices].tolist())

MinIdices = np.array(true_indices)
# Same six directions in each of the three blocks, preceded by volume 0.
DevilIndices = np.hstack([MinIdices,MinIdices+block_len,MinIdices+2*block_len])
DevilIndices = np.hstack([0,DevilIndices])
bvecs_Dev = bvecs[DevilIndices]
bvals_Dev = bvals[DevilIndices]

# Full protocol split into its three acquisition blocks.
bve_split = [bvecs[:(n_pts+1)],bvecs[(n_pts+1):2*(n_pts+1)],bvecs[2*(n_pts+1):]]
bva_split = [bvals[:(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],bvals[2*(n_pts+1):]]

In [None]:
# Build the simulated test set: TestSamps random parameter vectors, their
# noise-free signals (TestSig) and noisy copies at four noise settings
# (NoisyTestSig). The order of the random draws below fixes the test set for
# the given seed, so statements must not be reordered.
np.random.seed(12)
TestSamps = 20

# Directions
# Random unit vectors (normalised Gaussian draws) converted with SpherAng.
x1  = np.random.randn(TestSamps)
y1  = np.random.randn(TestSamps)
z1  =  np.random.randn(TestSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
Dpar  = np.random.rand(TestSamps)*5e-3
Dperp = np.random.rand(TestSamps)*5e-3

#Diffusion of hindered
# Six uniform draws per sample are turned into a tensor by ComputeDTI, passed
# through ForceLowFA and stored as the six values returned by mat_to_vals.
Params_abc =  np.random.rand(TestSamps,3)*0.14-0.07
Params_rest =  np.random.rand(TestSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(TestSamps)

mean = np.random.rand(TestSamps)*0.005+1e-4
sig2 = np.random.rand(TestSamps) * (4e-7 - 9e-8) + 9e-8

S0Rand =np.ones(TestSamps)

# Ground-truth parameter table (12 columns): Angs (2), Dpar, Dperp, DHind (6), frac, mean.
TestParams = np.column_stack([Angs,Dpar,Dperp,DHind,frac,mean])

TestSig = []
NoisyTestSig = []
for i in tqdm(range(TestSamps)):
    v = np.array([Angs[i]])
    dpar = Dpar[i]
    dperp = Dperp[i]
    
    dh   = DHind[i]
    f    = [frac[i],1-frac[i]]

    a = mean[i]
    s = sig2[i]
    # NOTE(review): alpha, scale, rv, R and weights are computed here but are
    # not passed to CombSignal_poisson below, so they do not affect TestSig.
    alpha     = a * a / s
    scale = s / a
    rv = stats.gamma(a=alpha,scale=scale)
    
    R = np.linspace(0.0001,0.005, 30)
    weights = rv.pdf(R)
    weights = weights/np.sum(weights)
    s0 = 1

    # One simulated signal per acquisition block, each with its own Delta.
    TestSig1 = CombSignal_poisson(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig2 = CombSignal_poisson(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig3 = CombSignal_poisson(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig.append(np.hstack([TestSig1,TestSig2,TestSig3]))
    Noisy = []
    # Third AddNoise argument is the noise setting (presumably an SNR -- confirm in AddNoise).
    for Noise in [2,10,20,30]:
        Noisy.append(AddNoise(TestSig[-1],s0,Noise))
    NoisyTestSig.append(Noisy)
NoisyTestSig = np.array(NoisyTestSig)
# Reorder to (noise setting, sample, measurement).
NoisyTestSig = np.swapaxes(NoisyTestSig,0,1)
TestSig = np.array(TestSig)

In [None]:
# Load the full-protocol simulation posterior if it has been cached, otherwise
# simulate NumSamps training signals, train it and cache it.
#
# The cache used to be written WITHOUT the ".pickle" suffix while the existence
# check and the load looked for the suffixed name, so the cache was never found
# and the network was retrained on every run. A single path is now used for
# checking, loading and saving.
#
# NOTE(review): NumSamps, AngsS, DparS, DperpS, DHindS, fracS, meanS and
# TrainParams are only created in the training branch; the reduced-protocol
# training cell reads them. When this posterior is loaded from cache, no random
# numbers are consumed here either, so later unseeded draws differ from a
# training run.
full_sim_file = f"{network_path}/Full_Sim_50_100k_poisson.pickle"
if os.path.exists(full_sim_file):
    with open(full_sim_file, "rb") as handle:
        posterior = pickle.load(handle)
else:
    np.random.seed(10)
    NumSamps = 100000
    
    # Directions
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    VS = np.vstack([x1,y1,z1])
    VS = (VS/np.linalg.norm(VS,axis=0)).T
    AngsS = np.array([SpherAng(v) for v in VS])
    
    #Diffusion of restricted
    DparS  = np.random.rand(NumSamps)*5e-3
    DperpS = np.random.rand(NumSamps)*5e-3
    
    #Diffusion of hindered
    Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
    Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
    Params = np.hstack([Params_abc,Params_rest])
    DHindS = np.array([ComputeDTI(p) for p in Params])
    DHindS = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHindS])
    
    meanS = np.random.rand(NumSamps)*0.005+1e-4
    # sig2S is not used by the simulator call below, but the draw is kept so
    # the random stream (and therefore fracS) is unchanged.
    sig2S = np.random.rand(NumSamps) * (4e-7 - 9e-8) + 9e-8
    
    #Fraction of hindered
    fracS  = np.random.rand(NumSamps)
    # 12 columns: AngsS (2), DparS, DperpS, DHindS (6), fracS, meanS
    TrainParams = np.column_stack([AngsS,DparS,DperpS,DHindS,fracS,meanS])

    TrainSigS = []
    NoisyTrainSigS = []
    for i in tqdm(range(NumSamps)):
        v = np.array([AngsS[i]])
        dpar = DparS[i]
        dperp = DperpS[i]
        
        dh   = DHindS[i]
        f    = [fracS[i],1-fracS[i]]
    
        a = meanS[i]
        s0 = 1
        
        Noise = 50
        
        # One simulated signal per acquisition block, each with its own Delta.
        TrainSig1 = CombSignal_poisson(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSigS.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        
        NoisyTrainSigS.append(AddNoise(TrainSigS[-1],s0,Noise))
    NoisyTrainSigS = np.array(NoisyTrainSigS)

    Obs = torch.tensor(NoisyTrainSigS).float()
    Par = torch.tensor(TrainParams).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior = inference.build_posterior(density_estimator)

    with open(full_sim_file, "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
# Reduced-protocol simulation posterior: train and cache it when no cached copy
# exists, otherwise load the cached one.
# NOTE(review): the training branch reads NumSamps, AngsS, DparS, DperpS,
# DHindS, fracS, meanS and TrainParams, none of which are assigned in this cell.
dev_sim_file = f"{network_path}/Dev_Sim_50_100k_poisson.pickle"
if not os.path.exists(dev_sim_file):

    np.random.seed(10)
    torch.manual_seed(10)

    # (directions, b-values, Delta) for the three blocks of the reduced protocol
    dev_blocks = [
        (bvecs_Dev[:7],   bvals_Dev[:7],   Delta[0]),
        (bvecs_Dev[7:13], bvals_Dev[7:13], Delta[1]),
        (bvecs_Dev[13:],  bvals_Dev[13:],  Delta[2]),
    ]

    TrainSigS = []
    NoisyTrainSigS = []
    for i in tqdm(range(NumSamps)):
        v = np.array([AngsS[i]])
        dpar, dperp = DparS[i], DperpS[i]
        dh = DHindS[i]
        f = [fracS[i],1-fracS[i]]
        a = meanS[i]
        s0 = 1
        Noise = 50

        clean_signal = np.hstack([
            CombSignal_poisson(blk_vecs,blk_vals,blk_Delta,delta,[v,dpar,dperp,dh,f,a,s0])
            for blk_vecs, blk_vals, blk_Delta in dev_blocks
        ])
        TrainSigS.append(clean_signal)
        NoisyTrainSigS.append(AddNoise(clean_signal,s0,Noise))
    NoisyTrainSigS = np.array(NoisyTrainSigS)

    Obs = torch.tensor(NoisyTrainSigS).float()
    Par = torch.tensor(TrainParams).float()

    # Neural posterior estimation: register the simulations, train the density
    # estimator and wrap it as a posterior object.
    inference = SNPE()
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train()
    posteriorMin = inference.build_posterior(density_estimator)

    with open(dev_sim_file, "wb") as handle:
        pickle.dump(posteriorMin, handle)
else:
    with open(dev_sim_file, "rb") as handle:
        posteriorMin = pickle.load(handle)

In [None]:
def Errors(TrueSig,TrueParams,GuessParams,Delta,bvecs,bvals):
    """Compare an estimated 12-entry parameter vector with the true one.

    Returns the tuple
    ``(Res, alpha_err, angle_err1, angle_err2, Dpar_err, Dperp_err, MD_err, FA_err, Frac_err)``
    where ``Res`` is the norm of ``residuals(GuessParams, TrueSig, ...)`` and
    the remaining entries are absolute differences between guess and truth:
    entry 11, entries 0 and 1, entries 2 and 3, the MD and FA of the tensor
    stored in entries 4..9, and entry 10.
    """
    def _md_and_fa(param_vec):
        # Mean of the eigenvalues and FracAni of the tensor in entries 4..9.
        eigenvalues = np.linalg.eigh(vals_to_mat(param_vec[4:10]))[0]
        mean_diff = eigenvalues.mean()
        return mean_diff, FracAni(eigenvalues,mean_diff)

    Res = np.linalg.norm(residuals(GuessParams,TrueSig,bvecs,bvals,Delta))

    MD_guess, FA_guess = _md_and_fa(GuessParams)
    MD_true, FA_true = _md_and_fa(TrueParams)

    return (
        Res,
        np.abs(GuessParams[11]-TrueParams[11]),
        np.abs(GuessParams[0]-TrueParams[0]),
        np.abs(GuessParams[1]-TrueParams[1]),
        np.abs(TrueParams[2]-GuessParams[2]),
        np.abs(TrueParams[3]-GuessParams[3]),
        np.abs(MD_guess-MD_true),
        np.abs(FA_guess-FA_true),
        np.abs(TrueParams[10]-GuessParams[10]),
    )

In [None]:
def residuals(params,TrueSig,bvecs,bvals,Delta):
    """Return ``TrueSig`` minus the signal simulated from ``params`` with S0 fixed to 1."""
    return TrueSig - Simulator_new(params,bvecs,bvals,Delta,S0=1)

def residuals_S0(params,TrueSig,bvecs,bvals,Delta):
    """Return ``TrueSig`` minus the simulated signal, using ``params[-1]`` as S0."""
    fitted_S0 = params[-1]
    return TrueSig - Simulator_new(params,bvecs,bvals,Delta,S0=fitted_S0)


In [None]:
def _poisson_param_list(params,S0):
    """Repack a flat parameter vector into the list CombSignal_poisson is called with."""
    return [
        np.array([params[:2]]),      # entries 0-1, wrapped as a 1 x 2 array
        params[2],
        params[3],
        params[4:10],                # six tensor values
        [params[10],1-params[10]],   # entry 10 and its complement
        params[11],
        S0,
    ]

def Simulator_new(params,bvecs,bvals,Delta,S0=1):
    """Simulate one signal per (bvecs, bvals, Delta) block and concatenate them.

    ``bvecs``, ``bvals`` and ``Delta`` are iterated in lockstep; the global
    ``delta`` is used for every block.
    """
    new_params = _poisson_param_list(params,S0)
    return np.hstack([
        CombSignal_poisson(bve,bva,d,delta,new_params)
        for bve,bva,d in zip(bvecs,bvals,Delta)
    ])

def Simulator_Min(params,Delta,S0=1):
    """Simulate the reduced protocol held in the globals ``bvecs_Dev`` / ``bvals_Dev``.

    Rows 0-6, 7-12 and 13 onwards are simulated with ``Delta[0]``,
    ``Delta[1]`` and ``Delta[2]`` respectively and concatenated.
    """
    new_params = _poisson_param_list(params,S0)
    block_slices = (slice(None,7), slice(7,13), slice(13,None))
    return np.hstack([
        CombSignal_poisson(bvecs_Dev[rows],bvals_Dev[rows],Delta[k],delta,new_params)
        for k,rows in enumerate(block_slices)
    ])

In [None]:
# Non-linear least-squares fit of every noisy test signal (full protocol),
# always starting from the same random initial guess.
#
# Changes with respect to the earlier version of this cell:
#  * the stray `mean = np.random.rand(1)*...` line is gone -- it was never used
#    for the guess (mean_guess is) and it overwrote the per-sample `mean` array
#    of the test set;
#  * the guess is drawn after an explicit seed, so it no longer depends on
#    whether the preceding cells trained a network or loaded a cached one;
#  * loop ranges and the result array follow NoisyTestSig's shape.
np.random.seed(133)
Params_abc =  np.random.rand(1,3)*0.14-0.07
Params_rest =  np.random.rand(1,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind_guess = np.array([ComputeDTI(p) for p in Params])
DHind_guess = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind_guess])

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
phi = 0
cos_theta = 0
theta = np.arccos(cos_theta)         # = pi/2, i.e. on the upper bound of entry 0
Angs_guess = np.vstack([theta,phi]).T

mean_guess = np.random.rand()*0.005 + 1e-4

frac_guess = np.random.rand()
# 12 entries: angles (2), Dpar, Dperp, tensor values (6), fraction, mean
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess]).squeeze()

# Row 0 holds the lower bounds, row 1 the upper bounds, one column per entry of `guess`.
bounds = np.array([
    [0,       -np.pi, 0,    0,    -5e-3, -5e-3, -5e-3, -5e-3, -5e-3, -5e-3, 0, 1e-4],
    [np.pi/2,  np.pi, 5e-3, 5e-3,  5e-3,  5e-3,  5e-3,  5e-3,  5e-3,  5e-3, 1, 0.005+1e-4],
], dtype=float)

n_noise, n_test = NoisyTestSig.shape[:2]
LS_result = np.zeros([n_noise,n_test,12])
bve_split = [bvecs[:(n_pts+1)],bvecs[(n_pts+1):2*(n_pts+1)],bvecs[2*(n_pts+1):]]
bva_split = [bvals[:(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],bvals[2*(n_pts+1):]]
for i in tqdm(range(n_test)):
    for j in range(n_noise):
        result = sp.optimize.least_squares(residuals, guess, args=[NoisyTestSig[j,i],bve_split,bva_split,Delta],
                                      bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        LS_result[j,i] = result.x

In [None]:
# Error metrics of the least-squares estimates: one row of Errors(...) outputs
# per (noise setting, test sample), evaluated against the noise-free signal.
LS_Errors = np.array([
    [
        Errors(sig,n_true,n_guess,Delta,bve_split,bva_split)
        for n_guess,n_true,sig in zip(per_noise_fits,TestParams,TestSig)
    ]
    for per_noise_fits in tqdm(LS_result)
])

In [None]:
# Posterior-mean estimates from the full-protocol network for every
# (noise setting, test sample) pair, followed by their error metrics.
def fit_SBI(i,j):
    """Return ``(i, j, mean of 1000 posterior samples)`` for ``NoisyTestSig[i, j]``."""
    torch.manual_seed(10)  # same sampling seed for every signal
    posterior_samples_1 = posterior.sample((1000,), x=NoisyTestSig[i,j],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

n_noise, n_test = NoisyTestSig.shape[:2]
y_indx = np.repeat(np.arange(n_test),n_noise)
x_indx = np.tile(np.arange(n_noise),n_test)
indices = np.column_stack([x_indx,y_indx])  # rows are (noise index, sample index)

# Use joblib to parallelize the sampling tasks
results = Parallel(n_jobs=-1)(
    delayed(fit_SBI)(i, j) for i, j in tqdm(indices)
)

SBI_Res = np.zeros([n_noise,n_test,12])
for i, j, x in results:
    SBI_Res[i, j] = x

# Keep estimates in their valid ranges. In this 12-entry layout the fraction is
# entry 10 and the mean is entry 11 (there is no trailing S0 entry). The
# previous code clipped entry -2 -- the fraction -- to [0, 100], a range that
# only makes sense for the mean in the 13-entry layout used for measured data.
SBI_Res[..., 10] = np.clip(SBI_Res[..., 10],0,1)
SBI_Res[..., 11] = np.clip(SBI_Res[..., 11],0,100)

SBI_Errors = []
for N in tqdm(SBI_Res):
    temp = []
    for n_guess,n_true,sig in zip(N,TestParams,TestSig):
        temp.append(Errors(sig,n_true,n_guess,Delta,bve_split,bva_split))
    SBI_Errors.append(temp)
SBI_Errors = np.array(SBI_Errors)

In [None]:
# Posterior-mean estimates from the reduced-protocol network: each noisy test
# signal is sub-sampled with DevilIndices before being passed to posteriorMin.
def fit_SBI(i,j):
    """Return ``(i, j, mean of 1000 posterior samples)`` for the sub-sampled ``NoisyTestSig[i, j]``."""
    torch.manual_seed(10)  # same sampling seed for every signal
    posterior_samples_1 = posteriorMin.sample((1000,), x=NoisyTestSig[i,j][DevilIndices],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

n_noise, n_test = NoisyTestSig.shape[:2]
y_indx = np.repeat(np.arange(n_test),n_noise)
x_indx = np.tile(np.arange(n_noise),n_test)
indices = np.column_stack([x_indx,y_indx])  # rows are (noise index, sample index)

# Use joblib to parallelize the sampling tasks
results = Parallel(n_jobs=-1)(
    delayed(fit_SBI)(i, j) for i, j in tqdm(indices)
)

SBI_Res = np.zeros([n_noise,n_test,12])
for i, j, x in results:
    SBI_Res[i, j] = x

# Keep estimates in their valid ranges. In this 12-entry layout the fraction is
# entry 10 and the mean is entry 11 (there is no trailing S0 entry). The
# previous code clipped entry -2 -- the fraction -- to [0, 100], a range that
# only makes sense for the mean in the 13-entry layout used for measured data.
SBI_Res[..., 10] = np.clip(SBI_Res[..., 10],0,1)
SBI_Res[..., 11] = np.clip(SBI_Res[..., 11],0,100)

# Errors are evaluated against the full noise-free signal and full protocol.
SBI_Errors_Min = []
for N in tqdm(SBI_Res):
    temp = []
    for n_guess,n_true,sig in zip(N,TestParams,TestSig):
        temp.append(Errors(sig,n_true,n_guess,Delta,bve_split,bva_split))
    SBI_Errors_Min.append(temp)
SBI_Errors_Min = np.array(SBI_Errors_Min)

In [None]:
# Least-squares fit of the sub-sampled (reduced-protocol) noisy test signals.
# NOTE: the results are written into LS_result, replacing the full-protocol
# estimates stored there by the earlier fitting cell.
dev_row_slices = (slice(None,7), slice(7,13), slice(13,None))
bve_splitd = [bvecs_Dev[rows] for rows in dev_row_slices]
bva_splitd = [bvals_Dev[rows] for rows in dev_row_slices]

ls_options = dict(bounds=bounds,verbose=1,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
for i in tqdm(range(20)):
    for j in range(4):
        reduced_signal = NoisyTestSig[j,i][DevilIndices]
        result = sp.optimize.least_squares(residuals, guess,
                                           args=[reduced_signal,bve_splitd,bva_splitd,Delta],
                                           **ls_options)
        LS_result[j,i] = result.x

In [None]:
# Error metrics of the reduced-protocol least-squares estimates currently held
# in LS_result, evaluated against the full noise-free signal and full protocol.
LS_Errors_Min = np.array([
    [
        Errors(sig,n_true,n_guess,Delta,bve_split,bva_split)
        for n_guess,n_true,sig in zip(per_noise_fits,TestParams,TestSig)
    ]
    for per_noise_fits in tqdm(LS_result)
])

In [None]:
# Load one subject's data set and build the masks used by the fitting cells.
# From here on `bvecs` / `bvals` refer to the loaded acquisition, replacing the
# synthetic gradient table defined earlier.
Dir = MSDir+'/Ctrl055_R01_28/'
dat = pmt.read_mat(Dir+'data_loaded.mat')
bvecs = dat['direction']
bvals = dat['bval']

# Timing parameters are known for this acquisition.
Delta = [0.017,0.035,0.061]
delta = 0.007
FixedParams = {
    'bvals':bvals,
    'bvecs':bvecs,
    'Delta':list(Delta),
    'delta':delta,
}
n_pts = 90

data = dat['data']
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                             numpass=1, autocrop=False, dilate=2)

S_mask, _, _ = load_nifti(Dir+'mask_055.nii.gz', return_img=True)

# Binary version (same dtype as S_mask) of plane 54 along the second axis.
mask1 = (S_mask[:,54,:] != 0).astype(S_mask.dtype)

# Grow the binary mask by one dilation step with a full 3x3 structuring
# element; raise `iterations` for a wider mask.
structure = np.ones((3, 3), dtype=bool)
fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

In [None]:
# Full-protocol posterior for the measured data: load the cached network when
# the pickle exists, otherwise simulate training signals for the loaded
# acquisition (bvecs/bvals), train and cache it. Either way the result is bound
# to `posterior`, replacing the simulation-study posterior of the same name.
#
# NOTE(review): the training branch reads NumSamps, Angs, Dpar, Dperp, DHind,
# frac, mean, S0Rand and TrainParams, none of which are assigned in this cell.
# `TrainParams[:,3:]` further down drops three leading columns, which matches
# the 16-column table built in the reduced-protocol data training cell
# (V, Angs, ..., S0Rand) -- presumably that cell's training branch has to be
# run first; confirm the intended execution order.
if os.path.exists(f"{network_path}/Full_Dat_50_100k_poisson.pickle"):
    with open(f"{network_path}/Full_Dat_50_100k_poisson.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:

    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        s0 = S0Rand[i]
        
        Noise = 50#np.random.rand()*30 + 20
    
        # One simulated signal per acquisition block, each with its own Delta.
        TrainSig1 = CombSignal_poisson(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)



    Obs = torch.tensor(NoisyTrainSig).float()
    # Skip the first three columns of TrainParams when building the targets.
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior = inference.build_posterior(density_estimator)
    with open(f"{network_path}/Full_Dat_50_100k_poisson.pickle", "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
# SBI estimates for every masked voxel of the middle slice (third axis),
# using all measurements of the voxel.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0   # voxels with non-zero signal sum
indices = np.argwhere(mask)                                    # (row, column) pairs inside the mask

def optimize_chunk(pixels):
    """Return ``[(i, j, estimate), ...]`` for the voxel coordinates in ``pixels``.

    For each voxel 1000 posterior samples are drawn and every parameter is
    summarised with ``histogram_mode`` over its samples.
    """
    def _voxel_estimate(i, j):
        posterior_samples_1 = posterior.sample((1000,), x=maskdata[i, j,axial_middle, :],show_progress_bars=False)
        return np.array([histogram_mode(p) for p in posterior_samples_1.T])

    return [(i, j, _voxel_estimate(i, j)) for i, j in pixels]

# Hand the voxels to the workers in chunks of ChunkSize.
chunked_indices = [indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape

# Parameter map with 13 entries per voxel; voxels outside the mask stay zero.
NoiseEst = np.zeros(list(ArrShape) + [13])
for chunk in results:
    for i, j, x in chunk:
        NoiseEst[i, j] = x

# Range limits: entry -2 to [0, 100], entry -3 to [0, 1]. Voxels outside the
# mask hold zeros, which these limits leave untouched.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2],0,100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3],0,1)

# Display copy: NaN outside the mask, and entry -2 blanked wherever
# one minus entry -3 is below 0.3.
NoiseEst2 = np.copy(NoiseEst)
NoiseEst2[~mask] = math.nan
NoiseEst2[(1-NoiseEst2[...,-3])<0.3,-2] = math.nan

In [None]:
# Initial guess and box constraints for the 13-parameter least-squares fit of
# the measured data (the 13th entry is S0). The random draws are kept in their
# original order so the seeded guess is unchanged.
np.random.seed(133)
S0 = 2000
mean_guess = np.random.rand()*0.005+1e-4
Params_abc =  np.random.rand(1,3)*0.14-0.07
Params_rest =  np.random.rand(1,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind_guess = np.array([ComputeDTI(p) for p in Params])
DHind_guess = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind_guess])

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
phi = 0
cos_theta = 0
theta = np.arccos(cos_theta)         # = pi/2
Angs_guess = np.vstack([theta,phi]).T
S0_guess =np.random.rand()*2475+25

frac_guess = np.random.rand()
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess,S0_guess]).squeeze()

# Row 0: lower bounds, row 1: upper bounds; one column per entry of `guess`.
lower_bounds = [0,       -np.pi, 0,    0   ] + [-5e-3]*6 + [0, 1e-4,       25]
upper_bounds = [np.pi/2,  np.pi, 5e-3, 5e-3] + [ 5e-3]*6 + [1, 0.005+1e-4, 2500]
bounds = np.array([lower_bounds,upper_bounds],dtype=float)

# Loaded acquisition split into its three blocks of n_pts+1 volumes.
block_len = n_pts+1
bve_split = [bvecs[:block_len],bvecs[block_len:2*block_len],bvecs[2*block_len:]]
bva_split = [bvals[:block_len],bvals[block_len:2*block_len],bvals[2*block_len:]]

In [None]:
# Least-squares fit (with S0 as a free parameter) of every masked voxel of the
# middle slice, using all measurements of the voxel.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0   # voxels with non-zero signal sum
indices = np.argwhere(mask)
ArrShape = mask.shape

def optimize_pixel_LS(i, j):
    """Fit voxel ``(i, j)`` starting from the global ``guess``; return ``(i, j, fitted parameters)``."""
    voxel_signal = maskdata[i, j, axial_middle, :]
    result = sp.optimize.least_squares(residuals_S0, guess,
                                       args=[voxel_signal,bve_split,bva_split,Delta],
                                       bounds=bounds,verbose=0,
                                       xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x

# One joblib task per voxel.
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)

# Parameter map with 13 entries per voxel; voxels outside the mask stay zero.
NoiseEst_LS = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_LS[i, j] = x

# Display copy: NaN outside the mask, and entry -2 blanked wherever
# one minus entry -3 is below 0.3.
NoiseEst2_LS = np.copy(NoiseEst_LS)
NoiseEst2_LS[~mask] = math.nan
NoiseEst2_LS[(1-NoiseEst2_LS[...,-3])<0.3,-2] = math.nan

In [None]:
def _farthest_point_subset(points, n_select):
    """Greedily pick ``n_select`` mutually distant rows of ``points``.

    Selection starts from row 0; each further row is the one whose distance to
    its nearest already-selected row is largest (first such row on ties).
    Returns the selected row indices as a list, in selection order.
    """
    if n_select > len(points):
        raise ValueError("cannot select more rows than are available")
    pairwise = squareform(pdist(points))
    chosen = [0]
    while len(chosen) < n_select:
        nearest_chosen = pairwise[:, chosen].min(axis=1)
        nearest_chosen[chosen] = -np.inf  # rows already chosen are not candidates
        chosen.append(int(np.argmax(nearest_chosen)))
    return chosen


# Reduced protocol for the loaded acquisition: within each of the three blocks
# (rows 0-90, 91-181, 182 onwards) pick 3 directions at b = 2000 and 3 at
# b = 4000. The earlier comments spoke of "7 points", but the loops only ever
# added two directions to the starting one.
#
# Indices are derived from each shell's row numbers inside its own block.
# The earlier lookup searched the WHOLE bvecs array for an equal row and took
# the first hit, so if a later block repeats directions of an earlier one it
# returned the earlier block's index and the wrong measurements were selected.
block_starts = (0, 91, 182)
per_block_indices = []
for block_no, start in enumerate(block_starts):
    # The last block runs to the end of the arrays, as before.
    stop = block_starts[block_no + 1] if block_no + 1 < len(block_starts) else None
    block_bvecs = bvecs[start:stop]
    block_bvals = bvals[start:stop]

    true_indices = []
    for shell_bval in (2000, 4000):
        shell_rows = np.flatnonzero(block_bvals == shell_bval)
        selected_indices = _farthest_point_subset(block_bvecs[shell_rows], 3)
        true_indices.extend((start + shell_rows[selected_indices]).tolist())
    per_block_indices.append(true_indices)

true_indices1, true_indices2, true_indices3 = per_block_indices

# Volume 0 followed by the six selected volumes of each block.
DevIndices = [0] + true_indices1 + true_indices2 + true_indices3
bvecs_Dev = bvecs[DevIndices]
bvals_Dev = bvals[DevIndices]

In [None]:
# Reduced-protocol posterior for the measured data: load the cached network
# when the pickle exists, otherwise draw NumSamps parameter sets, simulate the
# reduced protocol (bvecs_Dev/bvals_Dev), train and cache the network.
#
# NOTE(review): the file name says "200k" while NumSamps is 100000 -- the name
# is left as is so an existing cache keeps being found; confirm which is meant.
# NOTE: the training branch rebinds Angs, Dpar, Dperp, DHind, frac, mean,
# S0Rand and TrainParams, names that are also used for the simulated test set.
if os.path.exists(f"{network_path}/Dev_Dat_50_200k_poisson.pickle"):
    with open(f"{network_path}/Dev_Dat_50_200k_poisson.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:

    np.random.seed(12)
    NumSamps = 100000
    
    # Directions
    # Random unit vectors (normalised Gaussian draws) converted with SpherAng.
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    V = np.vstack([x1,y1,z1])
    V = (V/np.linalg.norm(V,axis=0)).T
    Angs = np.array([SpherAng(v) for v in V])
    
    #Diffusion of restricted
    Dpar  = np.random.rand(NumSamps)*5e-3
    Dperp = np.random.rand(NumSamps)*5e-3
    
    #Diffusion of hindered
    Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
    Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
    Params = np.hstack([Params_abc,Params_rest])
    DHind = np.array([ComputeDTI(p) for p in Params])
    DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])
    
    #Fraction of hindered
    frac  = np.random.rand(NumSamps)
    
    mean = np.random.rand(NumSamps)*0.005+1e-4
    # `scale` is not used further down in this cell; the draw still advances
    # the random stream ahead of S0Rand.
    scale = np.random.rand(NumSamps)*0.0009+0.0001
    
    S0Rand =np.random.rand(NumSamps)*2475+25
    # 16 columns: V (3), Angs (2), Dpar, Dperp, DHind (6), frac, mean, S0Rand
    TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand])

    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        #s = sig2[i]
        s0 = S0Rand[i]
        
        Noise = 50#np.random.rand()*30 + 20
                
        # Rows 0-6, 7-12 and 13 onwards of the reduced protocol are simulated
        # with Delta[0], Delta[1] and Delta[2] respectively.
        TrainSig1 = CombSignal_poisson(bvecs_Dev[:7],bvals_Dev[:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs_Dev[7:13],bvals_Dev[7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs_Dev[13:],bvals_Dev[13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)


    Obs = torch.tensor(NoisyTrainSig).float()
    # Drop the three V columns: the network is trained on the remaining 13 parameters.
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorMin = inference.build_posterior(density_estimator)
    with open(f"{network_path}/Dev_Dat_50_200k_poisson.pickle", "wb") as handle:
        pickle.dump(posteriorMin, handle)

In [None]:
# SBI estimates for every masked voxel of the middle slice, using only the
# measurements listed in DevIndices (reduced protocol).
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0   # voxels with non-zero signal sum
indices = np.argwhere(mask)
ArrShape = mask.shape

def optimize_pixel(i, j):
    """Return ``(i, j, mean of 1000 posterior samples)`` for voxel ``(i, j)``."""
    torch.manual_seed(10)  # same sampling seed for every voxel
    reduced_signal = maskdata[i, j,axial_middle, DevIndices]
    posterior_samples_1 = posteriorMin.sample((1000,), x=reduced_signal,show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

# One joblib task per voxel.
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)

# Parameter map with 13 entries per voxel; voxels outside the mask stay zero.
NoiseEst_Min = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_Min[i, j] = x

# Range limits: entry -2 to [0, 100], entry -3 to [0, 1]. Voxels outside the
# mask hold zeros, which these limits leave untouched.
NoiseEst_Min[..., -2] = np.clip(NoiseEst_Min[..., -2],0,100)
NoiseEst_Min[..., -3] = np.clip(NoiseEst_Min[..., -3],0,1)

In [None]:
# Reduced gradient table split into three consecutive groups (7, 6 and the
# remaining entries), passed to the residual function together with Delta.
bve_splitD = [bvecs_Dev[:7], bvecs_Dev[7:13], bvecs_Dev[13:]]
bva_splitD = [bvals_Dev[:7], bvals_Dev[7:13], bvals_Dev[13:]]

# Foreground of the central axial slice (non-zero summed signal)
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# (row, col) coordinates of every foreground pixel
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Return ``(i, j, fitted_parameters)`` for one pixel of the axial slice.

    Runs a bounded ``scipy.optimize.least_squares`` fit of ``residuals_S0``
    to the signal at the ``DevIndices`` volumes, starting from ``guess``.
    """
    voxel_signal = maskdata[i, j, axial_middle, DevIndices]
    fit = sp.optimize.least_squares(
        residuals_S0,
        guess,
        args=[voxel_signal, bve_splitD, bva_splitD, Delta],
        bounds=bounds,
        verbose=0,
        xtol=1e-12,
        gtol=1e-12,
        ftol=1e-12,
        jac='3-point',
    )
    return i, j, fit.x


ArrShape = mask.shape

# One joblib task per foreground pixel
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)

# Parameter map: 13 channels per pixel, zero outside the mask
NoiseEst_LS_Min = np.zeros((*ArrShape, 13))
for i, j, x in results:
    NoiseEst_LS_Min[i, j] = x

In [None]:
# Foreground of slice 54 along the second axis (non-zero summed signal)
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Coordinates of every foreground pixel in that slice
indices = np.argwhere(mask)


def optimize_chunk(pixels):
    """Return ``[(i, j, modes), ...]`` for a batch of pixel coordinates.

    For each pixel, 1000 samples are drawn from ``posterior`` conditioned on
    the full signal vector, and ``histogram_mode`` is applied to the samples
    of each parameter separately.
    """
    def _pixel_modes(i, j):
        draws = posterior.sample((1000,), x=maskdata[i, 54, j, :], show_progress_bars=False)
        return np.array([histogram_mode(p) for p in draws.T])

    return [(i, j, _pixel_modes(i, j)) for i, j in pixels]


# Batch the pixels so each joblib task handles ChunkSize of them
chunked_indices = [indices[i:i+ChunkSize] for i in range(0, len(indices), ChunkSize)]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape

# Parameter map: 13 channels per pixel, zero outside the mask
NoiseEst_CC = np.zeros((*ArrShape, 13))
for chunk in results:
    for i, j, x in chunk:
        NoiseEst_CC[i, j] = x

# Clamp the second- and third-to-last channels; background zeros are unaffected
NoiseEst_CC[..., -2] = np.clip(NoiseEst_CC[..., -2], 0, 100)
NoiseEst_CC[..., -3] = np.clip(NoiseEst_CC[..., -3], 0, 1)

# Working copy that gets NaN-masked; NoiseEst_CC itself stays intact
NoiseEst2_CC = np.copy(NoiseEst_CC)

# Both masks are derived from the estimates before any NaN is written
comb_mask = fat_mask * ((1-NoiseEst2_CC[...,-3])>0.1)
mask_CC = (1-NoiseEst2_CC[...,-3])<0.3

# Blank every channel outside the foreground ...
NoiseEst2_CC[~mask] = math.nan

# ... and the second-to-last channel outside the combined mask
NoiseEst2_CC[~comb_mask,-2] = math.nan

In [None]:
# Foreground of slice 54 along the second axis (non-zero summed signal)
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Coordinates of every foreground pixel in that slice
indices = np.argwhere(mask)


def optimize_pixel(i, j):
    """Return ``(i, j, posterior_mean)`` for one pixel of slice 54.

    The signal at the ``DevIndices`` volumes conditions ``posteriorMin``;
    500 posterior samples are drawn and averaged per parameter. The torch
    seed is reset on every call, so every pixel starts from the same RNG state.
    """
    torch.manual_seed(10)
    voxel_signal = maskdata[i, 54, j, DevIndices]
    draws = posteriorMin.sample((500,), x=voxel_signal, show_progress_bars=False)
    return i, j, draws.mean(axis=0)


ArrShape = mask.shape

# One joblib task per foreground pixel, using every available core
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)

# Parameter map: 13 channels per pixel, zero outside the mask
NoiseEst_Min_CC = np.zeros((*ArrShape, 13))
for i, j, x in results:
    NoiseEst_Min_CC[i, j] = x

# Clamp the second- and third-to-last channels; background zeros are unaffected
NoiseEst_Min_CC[..., -2] = np.clip(NoiseEst_Min_CC[..., -2], 0, 100)
NoiseEst_Min_CC[..., -3] = np.clip(NoiseEst_Min_CC[..., -3], 0, 1)

# Working copy that gets NaN-masked
NoiseEst2_Min_CC = np.copy(NoiseEst_Min_CC)

# Blank every channel outside the foreground
NoiseEst2_Min_CC[~mask] = math.nan

# comb_mask is produced by the full-protocol SBI cell for the same slice
NoiseEst2_Min_CC[~comb_mask,-2] = math.nan

In [None]:
# Foreground of slice 54 along the second axis (non-zero summed signal)
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Derive the pixel list from this cell's own mask. Previously the cell reused
# whatever `indices` an earlier cell had left behind, so running it after one
# of the axial-slice cells would fit the wrong pixels (or index out of range).
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Return ``(i, j, fitted_parameters)`` for one pixel of slice 54.

    Runs a bounded ``scipy.optimize.least_squares`` fit of ``residuals_S0``
    to the full signal vector of the pixel, starting from ``guess`` and using
    the full-protocol gradient split ``bve_split`` / ``bva_split``.
    """
    result = sp.optimize.least_squares(residuals_S0, guess, args=[maskdata[i, 54, j, :],bve_split,bva_split,Delta],
                              bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x


ArrShape = mask.shape

# One joblib task per foreground pixel
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)

# Parameter map: 13 channels per pixel, zero outside the mask
NoiseEst_LS_CC = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_LS_CC[i, j] = x

# Working copy that gets NaN-masked
NoiseEst2_LS_CC = np.copy(NoiseEst_LS_CC)

# Blank every channel outside the foreground
for i in range(13):
    NoiseEst2_LS_CC[~mask,i] = math.nan

# comb_mask is produced by the full-protocol SBI cell for the same slice
NoiseEst2_LS_CC[~comb_mask,-2] = math.nan

In [None]:
# Reduced gradient table split into three consecutive groups (7, 6 and the
# remaining entries), passed to the residual function together with Delta.
bve_splitD = [bvecs_Dev[:7],bvecs_Dev[7:13],bvecs_Dev[13:]]
bva_splitD = [bvals_Dev[:7],bvals_Dev[7:13],bvals_Dev[13:]]

# The foreground mask is needed for the NaN-masking at the bottom of the cell
# on BOTH branches, so compute it up front. Previously it was only defined in
# the recompute branch and the cached branch relied on a stale global.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

cache_file = f"{DatFolder}/Temp_LS_Min_CC.npy"
if os.path.exists(cache_file):
    # np.load opens the path itself; no separate file handle is required.
    NoiseEst_LS_Min_CC = np.load(cache_file, allow_pickle=True)
else:
    # Coordinates of every foreground pixel in slice 54
    indices = np.argwhere(mask)

    def optimize_pixel_LS(i, j):
        """Return ``(i, j, fitted_parameters)`` for one pixel of slice 54.

        Bounded least-squares fit of ``residuals_S0`` to the signal at the
        ``DevIndices`` volumes, starting from ``guess``.
        """
        result = sp.optimize.least_squares(residuals_S0, guess, args=[maskdata[i,54,j, DevIndices],bve_splitD,bva_splitD,Delta],
                              bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        return i, j, result.x

    ArrShape = mask.shape

    # One joblib task per foreground pixel
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )

    # Parameter map: 13 channels per pixel, zero outside the mask
    NoiseEst_LS_Min_CC = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst_LS_Min_CC[i, j] = x
    # NOTE(review): the freshly computed map is not written to cache_file in
    # this cell, so the cached branch is only taken if the file is produced
    # elsewhere — confirm whether an np.save is intended here.

# Working copy that gets NaN-masked
NoiseEst2_LS_Min_CC = np.copy(NoiseEst_LS_Min_CC)

# Blank every channel outside the foreground
for i in range(13):
    NoiseEst2_LS_Min_CC[~mask,i] = math.nan

# comb_mask is produced by the full-protocol SBI cell for the same slice
NoiseEst2_LS_Min_CC[~comb_mask,-2] = math.nan

In [None]:
# Three control subjects and the mask file that belongs to each of them
Dirs = ['Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
Masks = ['mask_055.nii.gz','mask_056.nii.gz','mask_057.nii.gz']

# Per-subject containers, filled in the same order as Dirs
BVecs, BVals = [], []
Deltas, deltas = [], []
S_masks, Datas, Outlines = [], [], []

for D,M in tqdm(zip(Dirs,Masks)):
    dat = pmt.read_mat(f"{MSDir}{D}/data_loaded.mat")

    # Segmentation mask stored next to the subject's data
    m, _, _ = load_nifti(f"{MSDir}{D}/{M}", return_img=True)

    # Otsu-based masking using volumes 0-9; md is the masked data, mk the mask
    data = dat['data']
    axial_middle = data.shape[2] // 2
    md, mk = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                                 numpass=1, autocrop=False, dilate=2)

    # Gradient table of this subject
    BVecs.append(dat['direction'])
    BVals.append(dat['bval'])
    # Timing parameters are shared: the same FixedParams entry for every subject
    Deltas.append(FixedParams['Delta'])
    deltas.append(FixedParams['delta'])
    S_masks.append(m)
    Datas.append(md)
    Outlines.append(mk)

In [None]:
# Load the full-protocol posterior for the three subjects if a trained network
# is already on disk; otherwise simulate a training set and train it.
# NOTE(review): the existence check uses "./Networks/..." while the save below
# uses f"{network_path}/..."; these agree only while network_path points at
# ./Networks — confirm.
if os.path.exists(f"./Networks/3Indv_50_300k_poisson.pickle"):
    with open(f"./Networks/3Indv_50_300k_poisson.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:

    # Fixed seed: the order of the np.random draws below defines the training set
    np.random.seed(12)
    NumSamps = 300000
    
    # Directions: Gaussian draws normalised to unit length, one row per sample
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    V = np.vstack([x1,y1,z1])
    V = (V/np.linalg.norm(V,axis=0)).T
    # presumably the spherical angles of each unit vector — see SpherAng
    Angs = np.array([SpherAng(v) for v in V])
    
    # Diffusion of restricted compartment: uniform in [0, 5e-3)
    Dpar  = np.random.rand(NumSamps)*5e-3
    Dperp = np.random.rand(NumSamps)*5e-3
    
    # Diffusion of hindered compartment: six raw parameters, the first three
    # uniform in [-0.07, 0.07) and the last three in [-0.015, 0.015), turned
    # into a tensor by ComputeDTI and post-processed by ForceLowFA / mat_to_vals
    Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
    Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
    Params = np.hstack([Params_abc,Params_rest])
    DHind = np.array([ComputeDTI(p) for p in Params])
    DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])
    
    # Compartment fraction, uniform in [0, 1); passed on as [frac, 1-frac]
    frac  = np.random.rand(NumSamps)
    
    # Uniform in [1e-4, 5.1e-3)
    mean = np.random.rand(NumSamps)*0.005+1e-4
    
    # Baseline signal, uniform in [25, 2500)
    S0Rand =np.random.rand(NumSamps)*2475+25
    
    # Which subject's gradient directions each sample is simulated with (1-based)
    Choice = np.random.choice([1,2,3],NumSamps)
    TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand,Choice*100])
    
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]
        
        # Fixed noise setting handed to AddNoise (a random level was tried before)
        Noise = 50#np.random.rand()*30 + 20
    
        # One signal per Delta, each over a consecutive block of (n_pts+1) volumes.
        # NOTE(review): directions come from the chosen subject (BVecs[c-1]) but
        # b-values come from the global `bvals`, not BVals[c-1] — confirm that all
        # subjects share the same b-value table.
        TrainSig1 = CombSignal_poisson(BVecs[c-1][:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecs[c-1][(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecs[c-1][2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        # The subject code c*100 is appended as an extra observation entry
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1],s0,Noise),c*100))
    NoisyTrainSig = np.array(NoisyTrainSig)
    
    
    
    Obs = torch.tensor(NoisyTrainSig).float()
    # Drop the three Cartesian direction columns (V); the angles are kept instead
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior = inference.build_posterior(density_estimator)
    with open(f"{network_path}/3Indv_50_300k_poisson.pickle", "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
def _spread_subset(vectors, n_extra=2):
    """Greedy farthest-point selection over the rows of ``vectors``.

    Starts from row 0 and repeatedly adds the row whose minimum Euclidean
    distance to the rows already chosen is largest.

    Returns the list of chosen row positions (length ``n_extra + 1``).
    """
    chosen = [0]
    distance_matrix = squareform(pdist(vectors))
    for _ in range(n_extra):
        remaining = list(set(range(len(vectors))) - set(chosen))
        # Distance from every remaining row to its nearest already-chosen row
        nearest = np.min(distance_matrix[remaining][:, chosen], axis=1)
        chosen.append(remaining[np.argmax(nearest)])
    return chosen


def _segment_indices(bve, bva, start, stop, shells=(2000, 4000), n_extra=2):
    """Volume indices of well-spread directions inside ``bve[start:stop]``.

    For each b-value in ``shells`` the volumes of that shell within the
    segment are collected and ``n_extra + 1`` of them are picked with
    ``_spread_subset``. The returned indices refer to the full ``bve``/``bva``
    arrays, ordered shell by shell.

    The indices are computed directly from the shell positions. Looking the
    selected vectors up again by exact row equality in the whole table would
    return the first matching row, which is the wrong volume whenever a
    direction is repeated in another shell or segment.
    """
    segment_bvals = bva[start:stop]
    picked = []
    for shell in shells:
        positions = start + np.flatnonzero(segment_bvals == shell)
        local = _spread_subset(bve[positions], n_extra)
        picked.extend(positions[local].tolist())
    return picked


# The acquisition is treated as three consecutive segments of 91 volumes
# (the last one runs to the end of the table).
_SEGMENTS = ((0, 91), (91, 182), (182, None))

# Per-subject reduced protocol: volume indices and the matching gradient table
IndxArr  = []
BVecsDev = []
BValsDev = []
for bve,bva in zip(BVecs,BVals):
    # 3 directions per shell x 2 shells = 6 volumes per segment
    true_indices1, true_indices2, true_indices3 = (
        _segment_indices(bve, bva, lo, hi) for lo, hi in _SEGMENTS
    )

    # Volume 0 is always kept, giving 19 volumes in total; downstream cells
    # split them as [:7], [7:13], [13:].
    # NOTE(review): if any direction is repeated across shells/segments these
    # indices differ from the ones the old row-equality lookup produced, and
    # networks trained on the old indices must be retrained.
    DevIndices = [0] + true_indices1 + true_indices2 + true_indices3
    bvecs_Dev = bve[DevIndices]
    bvals_Dev = bva[DevIndices]

    IndxArr.append(DevIndices)
    BVecsDev.append(bvecs_Dev)
    BValsDev.append(bvals_Dev)

In [None]:
# Load the reduced-protocol posterior for the three subjects if a trained
# network is already on disk; otherwise simulate a training set and train it.
# NOTE(review): the check uses "./Networks/..." and the save "Networks/..."
# (both relative to the working directory), not network_path — confirm.
if os.path.exists(f"./Networks/Dev_3Indv_50_300k_poisson.pickle"):
    with open(f"./Networks/Dev_3Indv_50_300k_poisson.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:

    # Fixed seed: the order of the np.random draws below defines the training set
    np.random.seed(12)
    # NOTE(review): 600000 samples are simulated although the file name says "300k"
    NumSamps = 600000
    
    # Directions: Gaussian draws normalised to unit length, one row per sample
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    V = np.vstack([x1,y1,z1])
    V = (V/np.linalg.norm(V,axis=0)).T
    # presumably the spherical angles of each unit vector — see SpherAng
    Angs = np.array([SpherAng(v) for v in V])
    
    # Diffusion of restricted compartment: uniform in [0, 5e-3)
    Dpar  = np.random.rand(NumSamps)*5e-3
    Dperp = np.random.rand(NumSamps)*5e-3
    
    # Diffusion of hindered compartment: six raw parameters, the first three
    # uniform in [-0.07, 0.07) and the last three in [-0.015, 0.015), turned
    # into a tensor by ComputeDTI and post-processed by ForceLowFA / mat_to_vals
    Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
    Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
    Params = np.hstack([Params_abc,Params_rest])
    DHind = np.array([ComputeDTI(p) for p in Params])
    DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])
    
    # Compartment fraction, uniform in [0, 1); passed on as [frac, 1-frac]
    frac  = np.random.rand(NumSamps)
    
    # Uniform in [1e-4, 5.1e-3)
    mean = np.random.rand(NumSamps)*0.005+1e-4
    
    # Baseline signal, uniform in [25, 2500)
    S0Rand =np.random.rand(NumSamps)*2475+25
    
    # Which subject's reduced gradient table each sample is simulated with (1-based)
    Choice = np.random.choice([1,2,3],NumSamps)

    TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand,Choice*100])
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        #s = sig2[i]
        s0 = S0Rand[i]
        
        # Fixed noise setting handed to AddNoise (a random level was tried before)
        Noise = 50#np.random.rand()*30 + 20
        c = Choice[i]
        
        # NOTE(review): redundant repeat of the assignment a few lines above
        Noise = 50#np.random.rand()*30 + 20

        # One signal per Delta over the 7 / 6 / remaining selected volumes of
        # the chosen subject's reduced table
        TrainSig1 = CombSignal_poisson(BVecsDev[c-1][:7],BValsDev[c-1][:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecsDev[c-1][7:13],BValsDev[c-1][7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecsDev[c-1][13:],BValsDev[c-1][13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        # The subject code c*100 is appended as an extra observation entry
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1],s0,Noise),c*100))
    NoisyTrainSig = np.array(NoisyTrainSig)
    
    
    Obs = torch.tensor(NoisyTrainSig).float()
    # Drop the three Cartesian direction columns (V); the angles are kept instead
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorMin = inference.build_posterior(density_estimator)
    with open(f"Networks/Dev_3Indv_50_300k_poisson.pickle", "wb") as handle:
        pickle.dump(posteriorMin, handle)

In [None]:
# Full-protocol SBI parameter maps, one per subject, each on that subject's slice
Full_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # Pixels to fit: the subject's stored mask restricted to slice `sl`
    mask = sma[:,sl,:]
    indices = np.argwhere(mask)

    def optimize_pixel(i, j):
        """Return ``(i, j, posterior_mean)`` for one pixel of the current subject.

        The observation is the full signal vector with the subject code
        ``100*(kk+1)`` appended. 1000 posterior samples are averaged, and the
        torch seed is reset on every call.
        """
        torch.manual_seed(10)
        observation = np.append(D[i, sl, j, :], 100*(kk+1))
        draws = posterior.sample((1000,), x=observation, show_progress_bars=False)
        return i, j, draws.mean(axis=0)

    ArrShape = mask.shape

    # One joblib task per masked pixel
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )

    # 14 channels per pixel (the network also returns the subject code)
    NoiseEst = np.zeros((*ArrShape, 14))
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)

In [None]:
# Reduced-protocol SBI parameter maps, one per subject, each on that subject's slice
Min_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # Pixels to fit: the subject's stored mask restricted to slice `sl`
    mask = sma[:,sl,:]
    indices = np.argwhere(mask)

    def optimize_pixel(i, j):
        """Return ``(i, j, posterior_mean)`` for one pixel of the current subject.

        The observation is the signal at the subject's selected volumes
        (``IndxArr[kk]``) with the subject code ``100*(kk+1)`` appended.
        1000 posterior samples are averaged, and the torch seed is reset on
        every call.
        """
        torch.manual_seed(10)
        observation = np.append(D[i, sl, j, IndxArr[kk]], 100*(kk+1))
        draws = posteriorMin.sample((1000,), x=observation, show_progress_bars=False)
        return i, j, draws.mean(axis=0)

    ArrShape = mask.shape

    # One joblib task per masked pixel, using every available core
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )

    # 14 channels per pixel (the network also returns the subject code)
    NoiseEst = np.zeros((*ArrShape, 14))
    for i, j, x in results:
        NoiseEst[i, j] = x

    Min_SBI.append(NoiseEst)

In [None]:
# One combined mask per subject. The three subjects differ only in the slice
# that is used and in the lower bound applied to (1 - channel[-4]).
CMasks = []
for kk, d, min_complement in zip((0, 1, 2), (54, 52, 54), (0.1, 0, 0.3)):
    # Full-protocol SBI map with everything outside the Otsu outline blanked
    temp = np.copy(Full_SBI[kk])
    temp[~Outlines[kk][:,d,:]] = math.nan

    # Binarise the stored mask of this slice (non-zero -> 1)
    mask1 = np.ones_like(S_masks[kk][:,d,:])
    mask1[S_masks[kk][:,d,:]==0] = 0

    # Apply dilation. Increase 'iterations' to make the mask even fatter.
    fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

    # Keep dilated-mask pixels whose fourth-to-last channel is positive and
    # whose complement exceeds the subject-specific bound
    CMasks.append(fat_mask * ((1-temp[...,-4])>min_complement) * (temp[...,-4]>0))

In [None]:
# Full-protocol least-squares parameter maps, one per subject
Full_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # Pixels to fit: the subject's stored mask restricted to slice `sl`
    mask = sma[:,sl,:]
    indices = np.argwhere(mask)

    # Subject gradient table split into three consecutive blocks of
    # (n_pts+1) volumes (the last block runs to the end of the table)
    bve_split_kk = [BVecs[kk][:(n_pts+1)],
                    BVecs[kk][(n_pts+1):2*(n_pts+1)],
                    BVecs[kk][2*(n_pts+1):]]
    bva_split_kk = [BVals[kk][:(n_pts+1)],
                    BVals[kk][(n_pts+1):2*(n_pts+1)],
                    BVals[kk][2*(n_pts+1):]]

    def optimize_pixel_LS(i, j):
        """Return ``(i, j, fitted_parameters)`` for one pixel of the current subject.

        Bounded least-squares fit of ``residuals_S0`` to the full signal
        vector, starting from ``guess``.
        """
        voxel_signal = D[i, sl, j, :]
        fit = sp.optimize.least_squares(
            residuals_S0,
            guess,
            args=[voxel_signal, bve_split_kk, bva_split_kk, Delta],
            bounds=bounds,
            verbose=0,
            xtol=1e-12,
            gtol=1e-12,
            ftol=1e-12,
            jac='3-point',
        )
        return i, j, fit.x

    ArrShape = mask.shape

    # One joblib task per masked pixel
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )

    # Parameter map: 13 channels per pixel, zero outside the mask
    NoiseEst_LS = np.zeros((*ArrShape, 13))
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Full_LS.append(NoiseEst_LS)

In [None]:
# Reduced-protocol least-squares parameter maps, one per subject
Min_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # Pixels to fit: the subject's stored mask restricted to slice `sl`
    mask = sma[:,sl,:]
    indices = np.argwhere(mask)

    # Subject's reduced gradient table split into 7, 6 and the remaining entries
    bve_splitd_kk = [BVecsDev[kk][:7],
                     BVecsDev[kk][7:13],
                     BVecsDev[kk][13:]]
    bva_splitd_kk = [BValsDev[kk][:7],
                     BValsDev[kk][7:13],
                     BValsDev[kk][13:]]

    def optimize_pixel_LS(i, j):
        """Return ``(i, j, fitted_parameters)`` for one pixel of the current subject.

        Bounded least-squares fit of ``residuals_S0`` to the signal at the
        subject's selected volumes (``IndxArr[kk]``), starting from ``guess``.
        """
        voxel_signal = D[i, sl, j, IndxArr[kk]]
        fit = sp.optimize.least_squares(
            residuals_S0,
            guess,
            args=[voxel_signal, bve_splitd_kk, bva_splitd_kk, Delta],
            bounds=bounds,
            verbose=0,
            xtol=1e-12,
            gtol=1e-12,
            ftol=1e-12,
            jac='3-point',
        )
        return i, j, fit.x

    ArrShape = mask.shape

    # One joblib task per masked pixel
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
    )

    # Parameter map: 13 channels per pixel, zero outside the mask
    NoiseEst_LS = np.zeros((*ArrShape, 13))
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Min_LS.append(NoiseEst_LS)

In [None]:
# Eight subjects: five NMSS folders followed by the three controls
Dirs = ['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']

# Per-subject containers, filled in the same order as Dirs.
# S_masks is reset here and not filled by this cell.
BVecs, BVals = [], []
Deltas, deltas = [], []
S_masks, Datas, Outlines = [], [], []
axial_middles = []

for D in tqdm(Dirs):
    F = pmt.read_mat(f"{MSDir}{D}/data_loaded.mat")

    # Reslice from (2,2,2) to (2.5,2.5,2.5) voxels; an all-ones 4x4 matrix is
    # passed as the input affine and the returned one replaces it
    data, affine = reslice(F['data'], np.ones((4,4)), (2,2,2), (2.5,2.5,2.5))

    # Otsu-based masking using volumes 0-9; md is the masked data, mk the mask
    axial_middle = data.shape[2] // 2
    md, mk = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                                 numpass=1, autocrop=False, dilate=2)

    # Gradient table of this subject
    BVecs.append(F['direction'])
    BVals.append(F['bval'])
    # Timing parameters are shared: the same global Delta / delta for every subject
    Deltas.append(Delta)
    deltas.append(delta)
    Datas.append(md)
    axial_middles.append(axial_middle)
    Outlines.append(mk)

In [None]:
# Load the full-protocol posterior for the eight subjects if a trained network
# is already on disk; otherwise simulate a training set and train it.
if os.path.exists(f"{network_path}/8Indv_50_300k_poisson.pickle"):
    with open(f"{network_path}/8Indv_50_300k_poisson.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:

    # Fixed seed: the order of the np.random draws below defines the training set
    np.random.seed(12)
    NumSamps = 300000
    
    # Directions: Gaussian draws normalised to unit length, one row per sample
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    V = np.vstack([x1,y1,z1])
    V = (V/np.linalg.norm(V,axis=0)).T
    # presumably the spherical angles of each unit vector — see SpherAng
    Angs = np.array([SpherAng(v) for v in V])
    
    # Diffusion of restricted compartment: uniform in [0, 5e-3)
    Dpar  = np.random.rand(NumSamps)*5e-3
    Dperp = np.random.rand(NumSamps)*5e-3
    
    # Diffusion of hindered compartment: six raw parameters, the first three
    # uniform in [-0.07, 0.07) and the last three in [-0.015, 0.015), turned
    # into a tensor by ComputeDTI and post-processed by ForceLowFA / mat_to_vals
    Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
    Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
    Params = np.hstack([Params_abc,Params_rest])
    DHind = np.array([ComputeDTI(p) for p in Params])
    DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])
    
    # Compartment fraction, uniform in [0, 1); passed on as [frac, 1-frac]
    frac  = np.random.rand(NumSamps)
    
    # Uniform in [1e-4, 5.1e-3)
    mean = np.random.rand(NumSamps)*0.005+1e-4
    
    # Baseline signal, uniform in [25, 2500)
    S0Rand =np.random.rand(NumSamps)*2475+25
    
    # Which of the eight subjects' gradient directions each sample uses (1-based)
    Choice = np.random.choice([1,2,3,4,5,6,7,8],NumSamps)

    TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand,Choice*100])
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]
        
        # Fixed noise setting handed to AddNoise (a random level was tried before)
        Noise = 50#np.random.rand()*30 + 20
    
        # One signal per Delta, each over a consecutive block of (n_pts+1) volumes.
        # NOTE(review): directions come from the chosen subject (BVecs[c-1]) but
        # b-values come from the global `bvals`, not BVals[c-1] — confirm that all
        # subjects share the same b-value table.
        TrainSig1 = CombSignal_poisson(BVecs[c-1][:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecs[c-1][(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecs[c-1][2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        # The subject code c*100 is appended as an extra observation entry
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1],s0,Noise),c*100))
    NoisyTrainSig = np.array(NoisyTrainSig)
    
    
    
    Obs = torch.tensor(NoisyTrainSig).float()
    # Drop the three Cartesian direction columns (V); the angles are kept instead
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior; unlike the
    # three-subject cells, stop_after_epochs is set explicitly to 100 here
    density_estimator = inference.train(stop_after_epochs=100)
    posterior = inference.build_posterior(density_estimator)
    with open(f"{network_path}/8Indv_50_300k_poisson.pickle", "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
# Choose the first point (arbitrary starting point, e.g., the first gradient)
IndxArr  = []
BVecsDev = []
BValsDev = []
for bve,bva in zip(BVecs,BVals): 
    # Choose the first point (arbitrary starting point, e.g., the first gradient)
    selected_indices = [0]
    bvecs2000 = bve[:91][bva[:91]==2000]
    distance_matrix = squareform(pdist(bve))
    # Iteratively select the point furthest from the current selection
    for _ in range(2):  # We need 7 points in total, and one is already selected
        remaining_indices = list(set(range(len(bvecs2000))) - set(selected_indices))
        
        # Calculate the minimum distance to the selected points for each remaining point
        min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
        
        # Select the point with the maximum minimum distance
        next_index = remaining_indices[np.argmax(min_distances)]
        selected_indices.append(next_index)
    
    selected_indices = selected_indices
    bvecs2000_selected = bve[:91][bva[:91]==2000][selected_indices]
    true_indices = []
    for b in bvecs2000_selected:
        true_indices.append(np.where((b == bve).all(axis=1))[0][0])
    
    # Choose the first point (arbitrary starting point, e.g., the first gradient)
    selected_indices = [0]
    bvecs4000 = bve[:91][bva[:91]==4000]
    distance_matrix = squareform(pdist(bve))
    # Iteratively select the point furthest from the current selection
    for _ in range(2):  # We need 7 points in total, and one is already selected
        remaining_indices = list(set(range(len(bvecs4000))) - set(selected_indices))
        
        # Calculate the minimum distance to the selected points for each remaining point
        min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
        
        # Select the point with the maximum minimum distance
        next_index = remaining_indices[np.argmax(min_distances)]
        selected_indices.append(next_index)
    
    selected_indices = selected_indices
    bvecs4000_selected = bve[:91][bva[:91]==4000][selected_indices]
    for b in bvecs4000_selected:
        true_indices.append(np.where((b == bve).all(axis=1))[0][0])
    true_indices1 = true_indices
    
    # Choose the first point (arbitrary starting point, e.g., the first gradient)
    selected_indices = [0]
    bvecs2000 = bve[91:182][bva[91:182]==2000]
    distance_matrix = squareform(pdist(bve))
    # Iteratively select the point furthest from the current selection
    for _ in range(2):  # We need 7 points in total, and one is already selected
        remaining_indices = list(set(range(len(bvecs2000))) - set(selected_indices))
        
        # Calculate the minimum distance to the selected points for each remaining point
        min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
        
        # Select the point with the maximum minimum distance
        next_index = remaining_indices[np.argmax(min_distances)]
        selected_indices.append(next_index)
    
    selected_indices = selected_indices
    bvecs2000_selected = bve[91:182][bva[91:182]==2000][selected_indices]
    true_indices = []
    for b in bvecs2000_selected:
        true_indices.append(np.where((b == bve).all(axis=1))[0][0])
    
    # Choose the first point (arbitrary starting point, e.g., the first gradient)
    selected_indices = [0]
    bvecs4000 = bve[91:182][bva[91:182]==4000]
    distance_matrix = squareform(pdist(bve))
    # Iteratively select the point furthest from the current selection
    for _ in range(2):  # We need 7 points in total, and one is already selected
        remaining_indices = list(set(range(len(bvecs4000))) - set(selected_indices))
        
        # Calculate the minimum distance to the selected points for each remaining point
        min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
        
        # Select the point with the maximum minimum distance
        next_index = remaining_indices[np.argmax(min_distances)]
        selected_indices.append(next_index)
    
    selected_indices = selected_indices
    bvecs4000_selected = bve[91:182][bva[91:182]==4000][selected_indices]
    for b in bvecs4000_selected:
        true_indices.append(np.where((b == bve).all(axis=1))[0][0])
    true_indices2 = true_indices
    
    # Choose the first point (arbitrary starting point, e.g., the first gradient)
    selected_indices = [0]
    bvecs2000 = bve[182:][bva[182:]==2000]
    distance_matrix = squareform(pdist(bve))
    # Iteratively select the point furthest from the current selection
    for _ in range(2):  # We need 7 points in total, and one is already selected
        remaining_indices = list(set(range(len(bvecs2000))) - set(selected_indices))
        
        # Calculate the minimum distance to the selected points for each remaining point
        min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
        
        # Select the point with the maximum minimum distance
        next_index = remaining_indices[np.argmax(min_distances)]
        selected_indices.append(next_index)
    
    selected_indices = selected_indices
    bvecs2000_selected = bve[182:][bva[182:]==2000][selected_indices]
    true_indices = []
    for b in bvecs2000_selected:
        true_indices.append(np.where((b == bve).all(axis=1))[0][0])
    
    # Choose the first point (arbitrary starting point, e.g., the first gradient)
    selected_indices = [0]
    bvecs4000 = bve[182:][bva[182:]==4000]
    distance_matrix = squareform(pdist(bve))
    # Iteratively select the point furthest from the current selection
    for _ in range(2):  # We need 7 points in total, and one is already selected
        remaining_indices = list(set(range(len(bvecs4000))) - set(selected_indices))
        
        # Calculate the minimum distance to the selected points for each remaining point
        min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
        
        # Select the point with the maximum minimum distance
        next_index = remaining_indices[np.argmax(min_distances)]
        selected_indices.append(next_index)
    
    selected_indices = selected_indices
    bvecs4000_selected = bve[182:][bva[182:]==4000][selected_indices]
    for b in bvecs4000_selected:
        true_indices.append(np.where((b == bve).all(axis=1))[0][0])
    true_indices3 = true_indices
    
    DevIndices = [0] + true_indices1 + true_indices2 + true_indices3
    bvecs_Dev = bve[DevIndices]
    bvals_Dev = bva[DevIndices]

    IndxArr.append(DevIndices)
    BVecsDev.append(bvecs_Dev)
    BValsDev.append(bvals_Dev)

In [None]:
# Load the cached reduced-protocol posterior if the pickle exists; otherwise
# simulate a training set and train a new one, then cache it.
# NOTE(security): pickle.load can execute arbitrary code stored in the file;
# only load network files you produced yourself.
if os.path.exists(f"{network_path}/Dev_8Indv_50_300k_poisson.pickle"):
    with open(f"{network_path}/Dev_8Indv_50_300k_poisson.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)
else:

    # Fixed seed: the np.random draws below are consumed in sequence, so
    # reordering or removing any of them changes the simulated training set.
    np.random.seed(12)
    NumSamps = 300000
    
    # Directions: normalised Gaussian triples, one unit vector per row of V
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    V = np.vstack([x1,y1,z1])
    V = (V/np.linalg.norm(V,axis=0)).T
    # presumably converts each unit vector to spherical angles — confirm in SpherAng
    Angs = np.array([SpherAng(v) for v in V])
    
    # Diffusivities of the restricted compartment, each uniform in [0, 5e-3)
    Dpar  = np.random.rand(NumSamps)*5e-3
    Dperp = np.random.rand(NumSamps)*5e-3
    
    # Hindered compartment: six parameters per sample, the first three uniform
    # in [-0.07, 0.07) and the last three uniform in [-0.015, 0.015), passed
    # through ComputeDTI, ForceLowFA and mat_to_vals
    Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
    Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
    Params = np.hstack([Params_abc,Params_rest])
    DHind = np.array([ComputeDTI(p) for p in Params])
    DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])
    
    # Fraction of hindered, uniform in [0, 1); the complement is 1-frac (see f below)
    frac  = np.random.rand(NumSamps)
    
    # Uniform in [1e-4, 5.1e-3); passed to the signal model as `a`
    mean = np.random.rand(NumSamps)*0.005+1e-4
    
    # Unweighted signal level, uniform in [25, 2500)
    S0Rand =np.random.rand(NumSamps)*2475+25
    
    # Protocol label 1..8: selects entry c-1 of BVecsDev/BValsDev for each sample
    # and is stored (times 100) as the last parameter column
    Choice = np.random.choice([1,2,3,4,5,6,7,8],NumSamps)
    TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand,Choice*100])
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        #s = sig2[i]
        s0 = S0Rand[i]
        
        # Fixed noise setting passed to AddNoise
        Noise = 50#np.random.rand()*30 + 20
        c = Choice[i]
        
        # NOTE(review): redundant repeat of the assignment above
        Noise = 50#np.random.rand()*30 + 20

        # The selected protocol is split into three segments (first 7, next 6,
        # remainder), each simulated with its own entry of Delta
        TrainSig1 = CombSignal_poisson(BVecsDev[c-1][:7],BValsDev[c-1][:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecsDev[c-1][7:13],BValsDev[c-1][7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecsDev[c-1][13:],BValsDev[c-1][13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        # Observation = noisy signal with the protocol label (c*100) appended
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1],s0,Noise),c*100))
    NoisyTrainSig = np.array(NoisyTrainSig)
    
    
    Obs = torch.tensor(NoisyTrainSig).float()
    # Drop the first three columns (the Cartesian direction V); the angles in
    # Angs are used as the direction parameters instead
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorMin = inference.build_posterior(density_estimator)
    # Cache the trained posterior under the same file name checked above
    with open(f"{network_path}/Dev_8Indv_50_300k_poisson.pickle", "wb") as handle:
        pickle.dump(posteriorMin, handle)

In [None]:
# Posterior-mean parameter maps from the full-protocol network `posterior`,
# computed on slice 48 of each dataset (at most eight datasets).
Full_SBI_Extra = []
for kk, (D, sl) in enumerate(zip(Datas, [48]*8)):
    # Only voxels whose signal does not sum to zero are processed
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    indices = np.argwhere(mask)

    def optimize_chunk(pixels):
        """Return (i, j, posterior mean) for every pixel of one chunk."""
        chunk_out = []
        for i, j in pixels:
            # Observation = voxel signal followed by the dataset label 100*(kk+1)
            observation = np.append(D[i, j, sl, :], 100*(kk+1))
            draws = posterior.sample((1000,), x=observation, show_progress_bars=False)
            chunk_out.append((i, j, draws.mean(axis=0)))
        return chunk_out

    # Hand the voxels to the workers in chunks of ChunkSize
    chunked_indices = [
        indices[start:start+ChunkSize]
        for start in range(0, len(indices), ChunkSize)
    ]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    # One 14-component parameter vector per voxel; unprocessed voxels stay zero
    ArrShape = mask.shape
    NoiseEst = np.zeros((*ArrShape, 14))
    for i, j, x in (entry for chunk in results for entry in chunk):
        NoiseEst[i, j] = x

    Full_SBI_Extra.append(NoiseEst)

In [None]:
# Posterior-mean parameter maps from the reduced-protocol network
# `posteriorMin`, using only the measurements listed in IndxArr[kk].
Min_SBI_Extra = []
for kk, (D, sl) in enumerate(zip(Datas, [48]*8)):
    # Only voxels whose full signal does not sum to zero are processed
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    indices = np.argwhere(mask)

    # Slice restricted to the reduced set of measurements for this dataset
    Arr = D[:, :, sl, IndxArr[kk]]

    def optimize_chunk(pixels):
        """Return (i, j, posterior mean) for every pixel of one chunk."""
        chunk_out = []
        for i, j in pixels:
            # Observation = reduced voxel signal followed by the label 100*(kk+1)
            observation = np.append(Arr[i, j], 100*(kk+1))
            draws = posteriorMin.sample((1000,), x=observation, show_progress_bars=False)
            chunk_out.append((i, j, draws.mean(axis=0)))
        return chunk_out

    # Hand the voxels to the workers in chunks of ChunkSize
    chunked_indices = [
        indices[start:start+ChunkSize]
        for start in range(0, len(indices), ChunkSize)
    ]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    # One 14-component parameter vector per voxel; unprocessed voxels stay zero
    ArrShape = mask.shape
    NoiseEst = np.zeros((*ArrShape, 14))
    for i, j, x in (entry for chunk in results for entry in chunk):
        NoiseEst[i, j] = x

    Min_SBI_Extra.append(NoiseEst)

In [None]:
# Full-protocol NLLS fits on slice 48 of each dataset. A cached array in
# DatFolder is reused when present.
# NOTE(review): this cell never writes the cache file itself — presumably it is
# saved elsewhere; confirm.
_full_ls_cache = f"{DatFolder}/temp_Full_LS.npy"
if os.path.exists(_full_ls_cache):
    # np.load takes the path directly, so no separate file handle is opened,
    # and the path that was checked is the path that is loaded.
    Full_LS_extra = np.load(_full_ls_cache)
else:
    Full_LS_extra = []
    for kk, (D, sl) in enumerate(zip(Datas, [48]*8)):
        # Only voxels whose signal does not sum to zero are fitted
        mask = np.sum(D[:, :, sl, :], axis=-1) != 0
        indices = np.argwhere(mask)

        # Split the protocol into three consecutive segments of n_pts+1,
        # n_pts+1 and the remainder; residuals_S0 receives them with Delta
        seg = n_pts + 1
        bve_split_kk = [BVecs[kk][:seg], BVecs[kk][seg:2*seg], BVecs[kk][2*seg:]]
        bva_split_kk = [BVals[kk][:seg], BVals[kk][seg:2*seg], BVals[kk][2*seg:]]

        def optimize_pixel_LS(i, j):
            """Fit one voxel by bounded least squares; return (i, j, solution)."""
            result = sp.optimize.least_squares(
                residuals_S0, guess,
                args=[D[i, j, sl, :], bve_split_kk, bva_split_kk, Delta],
                bounds=bounds, verbose=0, jac='3-point',
            )
            return i, j, result.x

        ArrShape = mask.shape

        # One fit per voxel, spread over eight workers
        results = Parallel(n_jobs=8)(
            delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
        )

        # 13 fitted parameters per voxel; unfitted voxels stay zero
        NoiseEst_LS = np.zeros(list(ArrShape) + [13])
        for i, j, x in results:
            NoiseEst_LS[i, j] = x

        Full_LS_extra.append(NoiseEst_LS)

In [None]:
# Reduced-protocol NLLS fits on slice 48 of each dataset, using only the
# measurements in IndxArr[kk]. A cached array in DatFolder is reused when present.
# NOTE(review): this cell never writes the cache file itself — presumably it is
# saved elsewhere; confirm.
_min_ls_cache = f"{DatFolder}/temp_Min_LS.npy"
if os.path.exists(_min_ls_cache):
    # np.load takes the path directly, so no separate file handle is opened,
    # and the path that was checked is the path that is loaded.
    Min_LS_extra = np.load(_min_ls_cache)
else:
    Min_LS_extra = []
    for kk, (D, sl) in enumerate(zip(Datas, [48]*8)):
        # Only voxels whose full signal does not sum to zero are fitted
        mask = np.sum(D[:, :, sl, :], axis=-1) != 0
        indices = np.argwhere(mask)

        # Reduced protocol split into three segments: first 7, next 6, remainder
        bve_splitd_kk = [BVecsDev[kk][:7], BVecsDev[kk][7:13], BVecsDev[kk][13:]]
        bva_splitd_kk = [BValsDev[kk][:7], BValsDev[kk][7:13], BValsDev[kk][13:]]

        def optimize_pixel_LS(i, j):
            """Fit one voxel by bounded least squares; return (i, j, solution)."""
            result = sp.optimize.least_squares(
                residuals_S0, guess,
                args=[D[i, j, sl, IndxArr[kk]], bve_splitd_kk, bva_splitd_kk, Delta],
                bounds=bounds, verbose=0,
                # Tight tolerances, unlike the full-protocol fit which uses defaults
                xtol=1e-12, gtol=1e-12, ftol=1e-12, jac='3-point',
            )
            return i, j, result.x

        ArrShape = mask.shape

        # One fit per voxel, spread over eight workers
        results = Parallel(n_jobs=8)(
            delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
        )

        # 13 fitted parameters per voxel; unfitted voxels stay zero
        NoiseEst_LS = np.zeros(list(ArrShape) + [13])
        for i, j, x in results:
            NoiseEst_LS[i, j] = x

        Min_LS_extra.append(NoiseEst_LS)

In [None]:
# Load one mask volume per subject from WMDir, reorient it, and reslice it
# from (2, 2, 2) to (2.5, 2.5, 2.5) voxel size.
WMDir = './MS_data/WM_masks/'
_wm_subjects = ['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
# Read the directory once instead of once per subject
_wm_files = os.listdir(WMDir)

WMs = []
# tqdm wraps the list (not the enumerate object) so the bar knows its total
for i, Name in enumerate(tqdm(_wm_subjects)):
    # NOTE(review): every file whose name contains the subject name is appended,
    # so WMs stays aligned with the subject order only if exactly one file
    # matches per subject — confirm.
    for fname in _wm_files:
        if Name not in fname:
            continue
        WM, affine, _ = load_nifti(WMDir + fname, return_img=True)
        # All subjects get the axis swap and left-right flip; subjects from
        # index 5 onwards additionally get an up-down flip first.
        WM_t = np.swapaxes(WM, 0, 1)
        if i >= 5:
            WM_t = np.flipud(WM_t)
        WM_t = np.fliplr(WM_t)
        # NOTE(review): the affine passed here is the one loaded from disk, not
        # adjusted for the reorientation above; only the resliced data is kept.
        WM_t, _ = reslice(WM_t, affine, (2,2,2), (2.5,2.5,2.5))
        WMs.append(WM_t)

In [None]:
KK = [48]*8


def _fa_md_maps(param_maps, slice_indices):
    """Build per-subject maps from the two values Par_frac returns per voxel.

    For each of the eight subjects, Par_frac is called on every voxel inside
    Outlines[jj][:, :, slice_indices[jj]] with columns 4:10 of that subject's
    parameter map. The first returned component is written to the first map
    and the second to the second map; voxels outside the outline stay zero.

    Parameters
    ----------
    param_maps : sequence of arrays indexed [i, j, parameter]
    slice_indices : sequence of int, one slice index per subject

    Returns
    -------
    (list, list) : the first-component maps and the second-component maps
    """
    first_maps, second_maps = [], []
    for jj in range(8):
        mask = Outlines[jj][:, :, slice_indices[jj]]
        indices = np.argwhere(mask)
        tensor_params = param_maps[jj][..., 4:10]

        results = Parallel(n_jobs=8)(
            delayed(Par_frac)(i, j, tensor_params) for i, j in tqdm(indices)
        )

        # Size the outputs from this subject's own mask rather than from a
        # shape variable left over from an earlier cell
        first = np.zeros(mask.shape)
        second = np.zeros(mask.shape)
        for i, j, k in results:
            first[i, j] = k[0]
            second[i, j] = k[1]

        first_maps.append(first)
        second_maps.append(second)
    return first_maps, second_maps


# FA / MD maps for the four fits (full and reduced protocol, SBI and NLLS)
FA_Full_SBI, MD_Full_SBI = _fa_md_maps(Full_SBI_Extra, KK)
FA_Min_SBI, MD_Min_SBI = _fa_md_maps(Min_SBI_Extra, KK)
FA_Full_LS, MD_Full_LS = _fa_md_maps(Full_LS_extra, KK)
FA_Min_LS, MD_Min_LS = _fa_md_maps(Min_LS_extra, KK)


In [None]:

# Column of the fraction parameter. An absolute index is used on purpose:
# the SBI maps have 14 columns and the NLLS maps 13, so a negative index such
# as -4 selects column 10 for SBI but column 9 for NLLS (inside the 4:10
# tensor block). Column 10 for SBI is unchanged by this.
# NOTE(review): assumes the NLLS vector is ordered like the SBI one without
# the trailing protocol label — confirm against `guess` / residuals_S0.
jj = 10
KK = [48]*8

# Mean SSIM inside the outline: reduced-protocol SBI (lightly smoothed) vs full SBI
SBI_comp_Frac = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(Min_SBI_Extra[i][...,jj]), sigma=0.5)
    NS2 = np.copy(Full_SBI_Extra[i][...,jj])

    # data_range spans the joint min/max of both images
    value_range = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    _, ssim_map = ssim(NS1, NS2, data_range=value_range, full=True, win_size=15)
    SBI_comp_Frac.append(ssim_map[Outlines[i][:,:,KK[i]]].mean())

# Same comparison for NLLS: reduced (lightly smoothed) vs full
LS_comp_Frac = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(Min_LS_extra[i][...,jj]), sigma=0.5)
    NS2 = np.copy(Full_LS_extra[i][...,jj])

    value_range = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    _, ssim_map = ssim(NS1, NS2, data_range=value_range, full=True, win_size=15)
    LS_comp_Frac.append(ssim_map[Outlines[i][:,:,KK[i]]].mean())

# Full SBI vs full NLLS, without smoothing
SBI_LS_comp_Frac = []
for i in range(8):
    NS1 = np.copy(Full_SBI_Extra[i][...,jj])
    NS2 = np.copy(Full_LS_extra[i][...,jj])

    value_range = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    _, ssim_map = ssim(NS1, NS2, data_range=value_range, full=True, win_size=15)
    SBI_LS_comp_Frac.append(ssim_map[Outlines[i][:,:,KK[i]]].mean())

# Standard deviation of the parameter inside the WMs mask, per subject and method
Prec7_SBI_Frac = []
PrecFull_SBI_Frac = []

Prec7_NLLS_Frac = []
PrecFull_NLLS_Frac = []
for i in range(8):
    wm = WMs[i].astype(bool)[:,:,KK[i]]
    Prec7_SBI_Frac.append(np.std(Min_SBI_Extra[i][...,jj][wm]))
    PrecFull_SBI_Frac.append(np.std(Full_SBI_Extra[i][...,jj][wm]))

    Prec7_NLLS_Frac.append(np.std(Min_LS_extra[i][...,jj][wm]))
    PrecFull_NLLS_Frac.append(np.std(Full_LS_extra[i][...,jj][wm]))

In [None]:
KK = [48]*8

# Mean SSIM inside the outline: reduced-protocol SBI (lightly smoothed) vs full SBI
SBI_comp_MD = []
for subj in range(8):
    reduced = gaussian_filter(np.copy(MD_Min_SBI[subj]), sigma=0.5)
    full = np.copy(MD_Full_SBI[subj])

    # data_range spans the joint min/max of both images
    span = max(reduced.max(), full.max()) - min(reduced.min(), full.min())
    _, ssim_img = ssim(reduced, full, data_range=span, full=True, win_size=15)
    outline = Outlines[subj][:, :, KK[subj]]
    SBI_comp_MD.append(ssim_img[outline].mean())

# Same comparison for NLLS: reduced (lightly smoothed) vs full
LS_comp_MD = []
for subj in range(8):
    reduced = gaussian_filter(np.copy(MD_Min_LS[subj]), sigma=0.5)
    full = np.copy(MD_Full_LS[subj])

    span = max(reduced.max(), full.max()) - min(reduced.min(), full.min())
    _, ssim_img = ssim(reduced, full, data_range=span, full=True, win_size=15)
    outline = Outlines[subj][:, :, KK[subj]]
    LS_comp_MD.append(ssim_img[outline].mean())

# Full SBI vs full NLLS, without smoothing
SBI_LS_comp_MD = []
for subj in range(8):
    sbi_map = np.copy(MD_Full_SBI[subj])
    ls_map = np.copy(MD_Full_LS[subj])

    span = max(sbi_map.max(), ls_map.max()) - min(sbi_map.min(), ls_map.min())
    _, ssim_img = ssim(sbi_map, ls_map, data_range=span, full=True, win_size=15)
    outline = Outlines[subj][:, :, KK[subj]]
    SBI_LS_comp_MD.append(ssim_img[outline].mean())

# Standard deviation of MD inside the WMs mask (slice 48), per subject and method
Prec7_SBI_MD, PrecFull_SBI_MD = [], []
Prec7_NLLS_MD, PrecFull_NLLS_MD = [], []
for subj in range(8):
    wm = WMs[subj].astype(bool)[:, :, 48]
    Prec7_SBI_MD.append(np.std(MD_Min_SBI[subj][wm]))
    PrecFull_SBI_MD.append(np.std(MD_Full_SBI[subj][wm]))

    Prec7_NLLS_MD.append(np.std(MD_Min_LS[subj][wm]))
    PrecFull_NLLS_MD.append(np.std(MD_Full_LS[subj][wm]))


In [None]:

# NOTE: this cell recomputes the same fraction comparisons as the earlier
# cell that fills SBI_comp_Frac / LS_comp_Frac / SBI_LS_comp_Frac.
#
# Column of the fraction parameter. An absolute index is used on purpose:
# the SBI maps have 14 columns and the NLLS maps 13, so a negative index such
# as -4 selects column 10 for SBI but column 9 for NLLS (inside the 4:10
# tensor block). Column 10 for SBI is unchanged by this.
# NOTE(review): assumes the NLLS vector is ordered like the SBI one without
# the trailing protocol label — confirm against `guess` / residuals_S0.
jj = 10
KK = [48]*8

# Mean SSIM inside the outline: reduced-protocol SBI (lightly smoothed) vs full SBI
SBI_comp_Frac = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(Min_SBI_Extra[i][...,jj]), sigma=0.5)
    NS2 = np.copy(Full_SBI_Extra[i][...,jj])

    # data_range spans the joint min/max of both images
    value_range = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    _, ssim_map = ssim(NS1, NS2, data_range=value_range, full=True, win_size=15)
    SBI_comp_Frac.append(ssim_map[Outlines[i][:,:,KK[i]]].mean())

# Same comparison for NLLS: reduced (lightly smoothed) vs full
LS_comp_Frac = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(Min_LS_extra[i][...,jj]), sigma=0.5)
    NS2 = np.copy(Full_LS_extra[i][...,jj])

    value_range = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    _, ssim_map = ssim(NS1, NS2, data_range=value_range, full=True, win_size=15)
    LS_comp_Frac.append(ssim_map[Outlines[i][:,:,KK[i]]].mean())

# Full SBI vs full NLLS, without smoothing
SBI_LS_comp_Frac = []
for i in range(8):
    NS1 = np.copy(Full_SBI_Extra[i][...,jj])
    NS2 = np.copy(Full_LS_extra[i][...,jj])

    value_range = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    _, ssim_map = ssim(NS1, NS2, data_range=value_range, full=True, win_size=15)
    SBI_LS_comp_Frac.append(ssim_map[Outlines[i][:,:,KK[i]]].mean())

# Standard deviation of the parameter inside the WMs mask, per subject and method
Prec7_SBI_Frac = []
PrecFull_SBI_Frac = []

Prec7_NLLS_Frac = []
PrecFull_NLLS_Frac = []
for i in range(8):
    wm = WMs[i].astype(bool)[:,:,KK[i]]
    Prec7_SBI_Frac.append(np.std(Min_SBI_Extra[i][...,jj][wm]))
    PrecFull_SBI_Frac.append(np.std(Full_SBI_Extra[i][...,jj][wm]))

    Prec7_NLLS_Frac.append(np.std(Min_LS_extra[i][...,jj][wm]))
    PrecFull_NLLS_Frac.append(np.std(Full_LS_extra[i][...,jj][wm]))

In [None]:
# Parameter column 3 of the fitted maps is compared in this cell
jj = 3
KK = [48]*8

# Mean SSIM inside the outline: reduced-protocol SBI (lightly smoothed) vs full SBI
SBI_comp_Dp = []
for subj in range(8):
    reduced = gaussian_filter(np.copy(Min_SBI_Extra[subj][..., jj]), sigma=0.5)
    full = np.copy(Full_SBI_Extra[subj][..., jj])

    # data_range spans the joint min/max of both images
    span = max(reduced.max(), full.max()) - min(reduced.min(), full.min())
    _, ssim_img = ssim(reduced, full, data_range=span, full=True, win_size=15)
    outline = Outlines[subj][:, :, KK[subj]]
    SBI_comp_Dp.append(ssim_img[outline].mean())

# Same comparison for NLLS: reduced (lightly smoothed) vs full
LS_comp_Dp = []
for subj in range(8):
    reduced = gaussian_filter(np.copy(Min_LS_extra[subj][..., jj]), sigma=0.5)
    full = np.copy(Full_LS_extra[subj][..., jj])

    span = max(reduced.max(), full.max()) - min(reduced.min(), full.min())
    _, ssim_img = ssim(reduced, full, data_range=span, full=True, win_size=15)
    outline = Outlines[subj][:, :, KK[subj]]
    LS_comp_Dp.append(ssim_img[outline].mean())

# Full SBI vs full NLLS, without smoothing
SBI_LS_comp_Dp = []
for subj in range(8):
    sbi_map = np.copy(Full_SBI_Extra[subj][..., jj])
    ls_map = np.copy(Full_LS_extra[subj][..., jj])

    span = max(sbi_map.max(), ls_map.max()) - min(sbi_map.min(), ls_map.min())
    _, ssim_img = ssim(sbi_map, ls_map, data_range=span, full=True, win_size=15)
    outline = Outlines[subj][:, :, KK[subj]]
    SBI_LS_comp_Dp.append(ssim_img[outline].mean())

# Standard deviation of the parameter inside the WMs mask (slice 48)
Prec7_SBI_Dp, PrecFull_SBI_Dp = [], []
Prec7_NLLS_Dp, PrecFull_NLLS_Dp = [], []
for subj in range(8):
    wm = WMs[subj].astype(bool)[:, :, 48]
    Prec7_SBI_Dp.append(np.std(Min_SBI_Extra[subj][..., jj][wm]))
    PrecFull_SBI_Dp.append(np.std(Full_SBI_Extra[subj][..., jj][wm]))

    Prec7_NLLS_Dp.append(np.std(Min_LS_extra[subj][..., jj][wm]))
    PrecFull_NLLS_Dp.append(np.std(Full_LS_extra[subj][..., jj][wm]))


## a

In [None]:


# -----------------------------
# Parameters
# -----------------------------
r = 1.0  # sphere radius
vector = np.array([-0.5, -1, 1])   # arbitrary vector
n = vector / np.linalg.norm(vector)  # unit vector in the direction of 'vector'
intersection = n * r  # intersection of the vector with the sphere

# Angular radii of the four circles: mean of error column 2 over the last
# entry of each error array.
# NOTE(review): the values go straight into cos/sin, so they are assumed to be
# in radians — confirm.
alpha1 = SBI_Errors[-1][:, 2].mean()      # full SBI
alpha2 = SBI_Errors_Min[-1][:, 2].mean()  # reduced SBI
alpha3 = LS_Errors[-1][:, 2].mean()       # full NLLS
alpha4 = LS_Errors_Min[-1][:, 2].mean()   # reduced NLLS

# -----------------------------
# Tangent basis at n
# -----------------------------
# For a centre n on the unit sphere and an angular radius alpha, the circle is
#   P(t) = cos(alpha)*n + sin(alpha)*(cos(t)*u + sin(t)*w)
# where u and w are orthonormal vectors spanning the tangent plane at n.
# If n is (anti)parallel to the z-axis the cross product with z vanishes, so a
# fixed axis is used instead.
if np.isclose(abs(n[2]), 1.0):
    u = np.array([1, 0, 0])
else:
    u = np.cross(n, [0, 0, 1])
    u = u / np.linalg.norm(u)

# w is perpendicular to both n and u.
w = np.cross(n, u)


def _circle_on_sphere(center, u_vec, w_vec, alpha, n_points):
    """Return an (n_points, 3) array of points on the circle of angular radius alpha about center."""
    t = np.linspace(0, 2 * np.pi, n_points)
    return (np.cos(alpha) * center
            + np.sin(alpha) * (np.outer(np.cos(t), u_vec) + np.outer(np.sin(t), w_vec)))


def _plot_sphere_region(axes, grids, hidden, **surface_kwargs):
    """Draw the sphere mesh with every grid point where `hidden` is True set to NaN."""
    gx, gy, gz = grids
    axes.plot_surface(np.where(hidden, np.nan, gx),
                      np.where(hidden, np.nan, gy),
                      np.where(hidden, np.nan, gz),
                      **surface_kwargs)


# -----------------------------
# Create the sphere mesh
# -----------------------------
phi = np.linspace(0, 2 * np.pi, 500)  # azimuthal angle
theta = np.linspace(0, np.pi, 500)      # polar angle

phi, theta = np.meshgrid(phi, theta)
x_sphere = r * np.sin(theta) * np.cos(phi)
y_sphere = r * np.sin(theta) * np.sin(phi)
z_sphere = r * np.cos(theta)

# -----------------------------
# Plot everything
# -----------------------------
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot the vector (using quiver)
ax.quiver(0, 0, 0, intersection[0], intersection[1], intersection[2],
          color='r', linewidth=2, arrow_length_ratio=0.1)

# Dashed circles, one per method (the first is sampled with 200 points, the rest with 100)
for alpha, color, n_points in [
    (alpha1, 'paleturquoise', 200),
    (alpha2, 'lightseagreen', 100),
    (alpha3, 'sandybrown', 100),
    (alpha4, 'darkorange', 100),
]:
    circle_points = _circle_on_sphere(n, u, w, alpha, n_points)
    ax.plot(circle_points[:, 0], circle_points[:, 1], circle_points[:, 2],
            color=color, linewidth=2, ls='--')

# Set equal aspect ratio for all axes
max_range = r * 1.2
for axis in 'xyz':
    getattr(ax, 'set_{}lim'.format(axis))((-max_range, max_range))

# Dot product of every mesh point with n; a point lies within angle alpha of n
# exactly when dot >= cos(alpha).
dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]
grids = (x_sphere, y_sphere, z_sphere)
band_style = dict(shade=False, alpha=0.5, rstride=2, cstride=2, edgecolor='none')

# Rings between consecutive circles; a ring is non-empty only when its inner
# angle is smaller than its outer angle (alpha1 <= alpha2 <= alpha3 <= alpha4).
for inner, outer, color in [
    (alpha3, alpha4, 'darkorange'),
    (alpha2, alpha3, 'sandybrown'),
    (alpha1, alpha2, 'lightseagreen'),
]:
    hidden = (dot > np.cos(inner)) | (dot < np.cos(outer))
    _plot_sphere_region(ax, grids, hidden, color=color, **band_style)

# Innermost cap: everything within alpha1 of n
_plot_sphere_region(ax, grids, dot < np.cos(alpha1),
                    color='paleturquoise', alpha=0.5, linewidth=0,
                    rstride=1, cstride=1, shade=False)

# Rest of the sphere: everything further than alpha4 from n
_plot_sphere_region(ax, grids, dot > np.cos(alpha4),
                    color='gray', alpha=0.2, rstride=2, cstride=2, edgecolor='none')

ax.axis('equal')
ax.axis('off')
ax.view_init(elev=20, azim=-85)

minLS_patch = mpatches.Patch(color='darkorange', label='Reduced NLLS')
fullLS_patch = mpatches.Patch(color='sandybrown', label='Full NLLS')

minSBI_patch = mpatches.Patch(color='lightseagreen', label='Reduced SBI')
fullSBI_patch = mpatches.Patch(color='paleturquoise', label='Full SBI')

ax.legend(
    handles=[minLS_patch,minSBI_patch,fullLS_patch,fullSBI_patch],
    loc='lower left',         # base location; bbox_to_anchor fine-tunes the position
    frameon=False, ncols=2,
    bbox_to_anchor=(0.18, 0.09),fontsize=18,
    columnspacing=0.5,
    handlelength=0.8,
)
ax.set_title('Average angle diff.',x=0.52, y=0.825,fontsize=24)


if Save:
    plt.savefig(FigLoc+'AngleErr.pdf',bbox_inches='tight',format='pdf',transparent=True)

## b

In [None]:
g_pos = np.array([0, 2, 4,6])*2
POSITIONS = g_pos
jitter = 0.04
x_data = [np.array([POSITIONS[i]] * len(d)) for i, d in enumerate(SBI_Errors[:,:,1]*1000,)]
x_jittered = [x + stats.t(df=6, scale=jitter).rvs(len(x)) for x in x_data]

fig,ax = plt.subplots(figsize=(8,4))    

# Add boxplots ---------------------------------------------------
# Note that properties about the median and the box are passed
# as dictionaries.

# Colors
BG_WHITE = "#fbf9f4"
GREY_LIGHT = "#b4aea9"
GREY50 = "#7F7F7F"
BLUE_DARK = "#1B2838"
BLUE = "#2a475e"
BLACK = "#282724"
GREY_DARK = "#747473"
RED_DARK = "#850e00"

medianprops = dict(
    linewidth=2, 
    color=GREY_DARK,
    solid_capstyle="butt"
)
boxprops = dict(
    linewidth=2, 
    color='turquoise'
)

ax.boxplot(
    SBI_Errors[:,:,1].T*1000,
    positions=POSITIONS, 
    showfliers = False, # Do not show the outliers beyond the caps.
    showcaps = False,   # Do not show the caps
    medianprops = medianprops,
    whiskerprops = boxprops,
    boxprops = boxprops
)

# Add jittered dots ----------------------------------------------
for x, y in zip(x_jittered, SBI_Errors[:,:,1]*1000):
    ax.scatter(x, y, s = 100, color='paleturquoise', alpha=0.8)

POSITIONS = g_pos+0.5

jitter = 0.04
x_data = [np.array([POSITIONS[i]] * len(d)) for i, d in enumerate(SBI_Errors_Min[:,:,1],)]
x_jittered = [x + stats.t(df=6, scale=jitter).rvs(len(x)) for x in x_data]

    

# Add boxplots ---------------------------------------------------
# Note that properties about the median and the box are passed
# as dictionaries.

medianprops = dict(
    linewidth=2, 
    color=GREY_DARK,
    solid_capstyle="butt"
)
boxprops = dict(
    linewidth=2, 
    color='cadetblue'
)

ax.boxplot(
    SBI_Errors_Min[:,:,1].T*1000,
    positions=POSITIONS, 
    showfliers = False, # Do not show the outliers beyond the caps.
    showcaps = False,   # Do not show the caps
    medianprops = medianprops,
    whiskerprops = boxprops,
    boxprops = boxprops
)

# Add jittered dots ----------------------------------------------
for x, y in zip(x_jittered, SBI_Errors_Min[:,:,1]*1000):
    ax.scatter(x, y, s = 100, color='darkturquoise', alpha=0.8)


# NLLS series of this panel -------------------------------------------------
# Both series follow the same recipe, so they are described as data:
#   (offset added to g_pos, error slice, box/whisker colour, dot colour)
# and drawn in the listed order: a box plot without fliers or caps, then the
# individual values as horizontally jittered dots.
jitter = 0.04
medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

for offset, errors, box_color, dot_color in (
    (1.5, LS_Errors[:, :, 1], 'sandybrown', 'peachpuff'),
    (2, LS_Errors_Min[:, :, 1], 'darkorange', 'orange'),
):
    scaled = errors * 1000  # values are plotted multiplied by 1000
    POSITIONS = g_pos + offset

    # One x-coordinate per plotted value, spread around the box position
    # with Student-t (df=6) noise of scale `jitter`.
    x_data = [np.full(len(row), POSITIONS[i]) for i, row in enumerate(scaled)]
    x_jittered = [xs + stats.t(df=6, scale=jitter).rvs(len(xs)) for xs in x_data]

    # The same style dictionary is used for the box outline and the whiskers.
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        scaled.T,
        positions=POSITIONS,
        showfliers=False,  # no outlier markers
        showcaps=False,    # no whisker end caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )

    for xs, ys in zip(x_jittered, scaled):
        ax.scatter(xs, ys, s=100, color=dot_color, alpha=0.8)

# Axis cosmetics --------------------------------------------------------------
# Four labelled ticks on the x axis.
ax.set_xticks([1, 5, 9, 13], ['2', '10', '20', '30'], fontsize=24)
ax.set_xlabel('SNR', fontsize=32)

# Scientific notation on the y axis, with the offset text sized like the ticks.
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='both', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

# Legend ----------------------------------------------------------------------
# Colour swatches built by hand, since the legend is not derived from the
# plotted artists.
minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = (
    mpatches.Patch(color=swatch, label=text)
    for swatch, text in (
        ('darkorange', 'Reduced NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Reduced SBI'),
        ('paleturquoise', 'Full SBI'),
    )
)

# Two columns; the legend's lower-left corner is anchored at
# bbox_to_anchor=(0.12, 0.8).
ax.legend(
    handles=[minLS_patch, minSBI_patch, fullLS_patch, fullSBI_patch],
    loc='lower left',
    bbox_to_anchor=(0.12, 0.8),
    ncols=2,
    frameon=False,
    fontsize=24,
    columnspacing=0.5,
    handlelength=0.8,
)

## c

In [None]:
# Panel c: index -1 of the last axis of the four error arrays (full / reduced
# SBI and NLLS), one box per row with the individual values as jittered dots.
# NOTE(review): index -1 presumably corresponds to the last entry of
# Errors_name - confirm against where the error arrays are filled.
g_pos = np.array([0, 2, 4, 6]) * 2  # base x-position of each group: 0, 4, 8, 12
jitter = 0.04
medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

fig, ax = plt.subplots(figsize=(8, 4))

# (offset added to g_pos, error slice, box/whisker colour, dot colour),
# drawn in the listed order.
for offset, errors, box_color, dot_color in (
    (0, SBI_Errors[:, :, -1], 'turquoise', 'paleturquoise'),
    (0.5, SBI_Errors_Min[:, :, -1], 'cadetblue', 'darkturquoise'),
    (1.5, LS_Errors[:, :, -1], 'sandybrown', 'peachpuff'),
    (2, LS_Errors_Min[:, :, -1], 'darkorange', 'orange'),
):
    POSITIONS = g_pos + offset

    # One x-coordinate per plotted value, spread around the box position
    # with Student-t (df=6) noise of scale `jitter`.
    x_data = [np.full(len(row), POSITIONS[i]) for i, row in enumerate(errors)]
    x_jittered = [xs + stats.t(df=6, scale=jitter).rvs(len(xs)) for xs in x_data]

    # The same style dictionary is used for the box outline and the whiskers.
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        errors.T,
        positions=POSITIONS,
        showfliers=False,  # no outlier markers
        showcaps=False,    # no whisker end caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )

    for xs, ys in zip(x_jittered, errors):
        ax.scatter(xs, ys, s=100, color=dot_color, alpha=0.8)

# Axis cosmetics: four labelled ticks, scientific notation on y.
ax.set_xticks([1, 5, 9, 13], ['2', '10', '20', '30'], fontsize=24)
ax.set_xlabel('SNR', fontsize=32)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='both', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

# Legend swatches.
# NOTE(review): these handles are created but this cell never calls
# ax.legend(), so no legend is drawn here.
minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = (
    mpatches.Patch(color=swatch, label=text)
    for swatch, text in (
        ('darkorange', 'Minimum NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Minimum SBI'),
        ('paleturquoise', 'Full SBI'),
    )
)

plt.ylim([-0.1,1])


## d

In [None]:
# Panel d: index -4 of the last axis of the four error arrays (full / reduced
# SBI and NLLS), one box per row with the individual values as jittered dots.
# NOTE(review): index -4 presumably corresponds to an entry of Errors_name -
# confirm against where the error arrays are filled.
g_pos = np.array([0, 2, 4, 6]) * 2  # base x-position of each group: 0, 4, 8, 12
jitter = 0.04
medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

fig, ax = plt.subplots(figsize=(8, 4))

# (offset added to g_pos, error slice, box/whisker colour, dot colour),
# drawn in the listed order.
for offset, errors, box_color, dot_color in (
    (0, SBI_Errors[:, :, -4], 'turquoise', 'paleturquoise'),
    (0.5, SBI_Errors_Min[:, :, -4], 'cadetblue', 'darkturquoise'),
    (1.5, LS_Errors[:, :, -4], 'sandybrown', 'peachpuff'),
    (2, LS_Errors_Min[:, :, -4], 'darkorange', 'orange'),
):
    POSITIONS = g_pos + offset

    # One x-coordinate per plotted value, spread around the box position
    # with Student-t (df=6) noise of scale `jitter`.
    x_data = [np.full(len(row), POSITIONS[i]) for i, row in enumerate(errors)]
    x_jittered = [xs + stats.t(df=6, scale=jitter).rvs(len(xs)) for xs in x_data]

    # The same style dictionary is used for the box outline and the whiskers.
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        errors.T,
        positions=POSITIONS,
        showfliers=False,  # no outlier markers
        showcaps=False,    # no whisker end caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )

    for xs, ys in zip(x_jittered, errors):
        ax.scatter(xs, ys, s=100, color=dot_color, alpha=0.8)

# Axis cosmetics: four labelled ticks, scientific notation on y.
ax.set_xticks([1, 5, 9, 13], ['2', '10', '20', '30'], fontsize=24)
ax.set_xlabel('SNR', fontsize=32)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='both', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

# Legend swatches.
# NOTE(review): these handles are created but this cell never calls
# ax.legend(), so no legend is drawn here.
minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = (
    mpatches.Patch(color=swatch, label=text)
    for swatch, text in (
        ('darkorange', 'Minimum NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Minimum SBI'),
        ('paleturquoise', 'Full SBI'),
    )
)


## e

In [None]:
# Panel e: index -2 of the last axis of the four error arrays (full / reduced
# SBI and NLLS), one box per row with the individual values as jittered dots.
# NOTE(review): index -2 presumably corresponds to an entry of Errors_name -
# confirm against where the error arrays are filled.
g_pos = np.array([0, 2, 4, 6]) * 2  # base x-position of each group: 0, 4, 8, 12
jitter = 0.04
medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")

fig, ax = plt.subplots(figsize=(8, 4))

# (offset added to g_pos, error slice, box/whisker colour, dot colour),
# drawn in the listed order.
for offset, errors, box_color, dot_color in (
    (0, SBI_Errors[:, :, -2], 'turquoise', 'paleturquoise'),
    (0.5, SBI_Errors_Min[:, :, -2], 'cadetblue', 'darkturquoise'),
    (1.5, LS_Errors[:, :, -2], 'sandybrown', 'peachpuff'),
    (2, LS_Errors_Min[:, :, -2], 'darkorange', 'orange'),
):
    POSITIONS = g_pos + offset

    # One x-coordinate per plotted value, spread around the box position
    # with Student-t (df=6) noise of scale `jitter`.
    x_data = [np.full(len(row), POSITIONS[i]) for i, row in enumerate(errors)]
    x_jittered = [xs + stats.t(df=6, scale=jitter).rvs(len(xs)) for xs in x_data]

    # The same style dictionary is used for the box outline and the whiskers.
    boxprops = dict(linewidth=2, color=box_color)
    ax.boxplot(
        errors.T,
        positions=POSITIONS,
        showfliers=False,  # no outlier markers
        showcaps=False,    # no whisker end caps
        medianprops=medianprops,
        whiskerprops=boxprops,
        boxprops=boxprops,
    )

    for xs, ys in zip(x_jittered, errors):
        ax.scatter(xs, ys, s=100, color=dot_color, alpha=0.8)

# Axis cosmetics: four labelled ticks, scientific notation on y.
ax.set_xticks([1, 5, 9, 13], ['2', '10', '20', '30'], fontsize=24)
ax.set_xlabel('SNR', fontsize=32)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='both', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

# Legend swatches.
# NOTE(review): these handles are created but this cell never calls
# ax.legend(), so no legend is drawn here.
minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = (
    mpatches.Patch(color=swatch, label=text)
    for swatch, text in (
        ('darkorange', 'Minimum NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Minimum SBI'),
        ('paleturquoise', 'Full SBI'),
    )
)


## f

In [None]:
# Panel f map: 1 minus index -3 of the last axis of NoiseEst2, shown with the
# 'hot' colour map on a fixed 0-1 range. No colour bar on this one.
plt.subplots(figsize=(12, 12))
im = plt.imshow(
    1 - NoiseEst2[..., -3],
    cmap='hot',
    vmin=0,
    vmax=1,
)
plt.axis('off')

In [None]:
# Panel f map: 1 minus index -3 of the last axis of NoiseEst2_LS, shown with
# the 'hot' colour map on a fixed 0-1 range.
plt.subplots(figsize=(12, 12))
im = plt.imshow(
    1 - NoiseEst2_LS[..., -3],
    cmap='hot',
    vmin=0,
    vmax=1,
)

# Colour bar with large tick labels; the negative pad pulls it towards the
# image instead of leaving a gap.
cbar = plt.colorbar(im, fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=32)
plt.axis('off')

if Save:
    plt.savefig(
        FigLoc+'FullSize_LS.pdf',
        format='pdf',
        bbox_inches='tight',
        transparent=True,
    )

In [None]:
# Panel f map: 1 minus index -3 of the last axis of NoiseEst_Min, shown with
# the 'hot' colour map on a fixed 0-1 range. No colour bar on this one.
plt.subplots(figsize=(12, 12))
im = plt.imshow(
    1 - NoiseEst_Min[..., -3],
    vmin=0,
    vmax=1,
    cmap='hot',
)
plt.axis('off')

if Save:
    plt.savefig(
        FigLoc+'DevSize.pdf',
        format='pdf',
        bbox_inches='tight',
        transparent=True,
    )

In [None]:
# Panel f map: 1 minus index -3 of the last axis of NoiseEst_LS_Min, shown
# with the 'hot' colour map.
plt.subplots(figsize=(12, 12))

# Use the same fixed 0-1 colour range as the other three maps of this panel;
# without it imshow autoscales to this array and the colour bar below would
# not be comparable with the neighbouring maps.
im = plt.imshow(1 - NoiseEst_LS_Min[..., -3], cmap='hot', vmin=0, vmax=1)

# Colour bar with large tick labels; the negative pad pulls it towards the
# image instead of leaving a gap.
cbar = plt.colorbar(im, fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=32)
plt.axis('off')

# The original call opened two parentheses ("savefig((FigLoc+..."), which is
# a SyntaxError and prevented the whole cell from running.
if Save:
    plt.savefig(FigLoc+'DevSize_LS.pdf', bbox_inches='tight', format='pdf', transparent=True)

## g 

In [None]:
# Panel g overlay for NoiseEst2_CC. Each layer is transposed and then flipped
# vertically before display.
plt.subplots(figsize=(12, 12))

# Bottom layer: index -1 of the last axis in grayscale.
plt.imshow(np.flipud(NoiseEst2_CC[..., -1].T), cmap='gray')

# Top layer: index -2 with the 'hot' colour map, colour range fixed to
# [0, 0.005].
# NOTE(review): presumably this layer is NaN/masked outside the region of
# interest so the gray image stays visible - confirm.
im = plt.imshow(
    np.flipud(NoiseEst2_CC[..., -2].T),
    cmap='hot',
    vmin=0,
    vmax=0.005,
)

# Colour bar with tick labels in exponent notation without decimals.
cbar = plt.colorbar(
    im,
    fraction=0.035,
    pad=0.01,
    format=ticker.FormatStrFormatter('%.e'),
)
cbar.ax.tick_params(labelsize=14)
plt.axis('off')


In [None]:
# Panel g overlay for NoiseEst2_Min_CC. Each layer is transposed and then
# flipped vertically before display.
plt.subplots(figsize=(12, 12))

# Bottom layer: index -1 of the last axis in grayscale.
plt.imshow(np.flipud(NoiseEst2_Min_CC[..., -1].T), cmap='gray')

# Top layer: index -2 with the 'hot' colour map, colour range fixed to
# [0, 0.005].
# NOTE(review): presumably this layer is NaN/masked outside the region of
# interest so the gray image stays visible - confirm.
im = plt.imshow(
    np.flipud(NoiseEst2_Min_CC[..., -2].T),
    cmap='hot',
    vmin=0,
    vmax=0.005,
)

# Colour bar with tick labels in exponent notation without decimals.
cbar = plt.colorbar(
    im,
    fraction=0.035,
    pad=0.01,
    format=ticker.FormatStrFormatter('%.e'),
)
cbar.ax.tick_params(labelsize=14)
plt.axis('off')

In [None]:
# Panel g overlay for NoiseEst2_LS_CC. Each layer is transposed and then
# flipped vertically before display.
plt.subplots(figsize=(12, 12))

# Bottom layer: index -1 of the last axis in grayscale.
plt.imshow(np.flipud(NoiseEst2_LS_CC[..., -1].T), cmap='gray')

# Top layer: index -2 with the 'hot' colour map, colour range fixed to
# [0, 0.007].
# NOTE(review): presumably this layer is NaN/masked outside the region of
# interest so the gray image stays visible - confirm.
im = plt.imshow(np.flipud(NoiseEst2_LS_CC[..., -2].T), cmap='hot', vmin=0, vmax=0.007)

# Colour bar with tick labels in exponent notation without decimals.
cbar = plt.colorbar(im, fraction=0.03, pad=0.01, format=ticker.FormatStrFormatter('%2.e'))
cbar.ax.tick_params(labelsize=32)
plt.axis('off')

# The original call opened two parentheses ("savefig((FigLoc+..."), which is
# a SyntaxError and prevented the whole cell from running.
if Save:
    plt.savefig(FigLoc+'FullSize_CC_LS.pdf', bbox_inches='tight', format='pdf', transparent=True)

In [None]:
# Panel g overlay for NoiseEst2_LS_Min_CC. Each layer is transposed and then
# flipped vertically before display.
plt.subplots(figsize=(12, 12))

# Bottom layer: index -1 of the last axis in grayscale.
plt.imshow(np.flipud(NoiseEst2_LS_Min_CC[..., -1].T), cmap='gray')

# Top layer: index -2 with the 'hot' colour map, colour range fixed to
# [0, 0.007].
# NOTE(review): presumably this layer is NaN/masked outside the region of
# interest so the gray image stays visible - confirm.
im = plt.imshow(
    np.flipud(NoiseEst2_LS_Min_CC[..., -2].T),
    cmap='hot',
    vmin=0,
    vmax=0.007,
)

# Colour bar with tick labels in exponent notation without decimals.
cbar = plt.colorbar(
    im,
    fraction=0.03,
    pad=0.01,
    format=ticker.FormatStrFormatter('%.e'),
)
cbar.ax.tick_params(labelsize=32)
plt.axis('off')

if Save:
    plt.savefig(
        FigLoc+'MinSize_CC_LS.pdf',
        format='pdf',
        bbox_inches='tight',
        transparent=True,
    )

## h

In [None]:
# Panel h: absolute difference between the Min_* and Full_* results inside the
# CMasks masks, multiplied by 1000, with one box per entry i = 0, 1, 2.
fig, ax = plt.subplots(figsize=(8, 4))

# SBI group, boxes at x = 0, 0.25, 0.5 -----------------------------------
g_pos = np.array([0, 0.25, 0.5])
colors = ['lightseagreen'] * 3
colors2 = ['paleturquoise'] * 3
y_data = [
    1000 * abs(Min_SBI[i][CMasks[i]][:, -3] - Full_SBI[i][CMasks[i]][:, -3])
    for i in range(3)
]
BoxPlots2(y_data, g_pos, colors, colors2, ax)

# NLLS group, boxes at x = 2, 2.25, 2.5 ----------------------------------
# NOTE(review): this group reads column -2 while the SBI group reads
# column -3; confirm both select the same quantity.
g_pos = np.array([2, 2.25, 2.5])
colors = ['darkorange'] * 3
colors2 = ['peachpuff'] * 3
y_data = [
    1000 * abs(Full_LS[i][CMasks[i]][:, -2] - Min_LS[i][CMasks[i]][:, -2])
    for i in range(3)
]
BoxPlots2(y_data, g_pos, colors, colors2, ax)

# One tick label under the middle box of each group.
ax.set_xticks([0.25, 2.25], ['SBI Comp', 'NLLS Comp'], fontsize=24)

# Scientific notation on the y axis, offset text sized like the ticks.
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

if Save:
    plt.savefig(
        FigLoc+'CC_3Indv_Comp.pdf',
        format='pdf',
        bbox_inches='tight',
        transparent=True,
    )

## i-j

In [None]:
# Box plots of SBI_comp_Frac and LS_comp_Frac side by side, with a dashed
# horizontal reference line at 0.66.
fig, ax = plt.subplots(figsize=(3.2, 4.8))

# Make `ax` the current axes: the plt.* calls below act on it.
plt.sca(ax)

# SBI box at x = 1.3 (scatter points drawn with alpha 0.5).
y_data = np.array(SBI_comp_Frac)
g_pos = np.array([1.3])
colors = ['lightseagreen']
colors2 = ['paleturquoise']
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True, scatter_alpha=0.5)

# NLLS box at x = 1.9.
y_data = np.array(LS_comp_Frac)
g_pos = np.array([1.9])
colors = ['sandybrown']
colors2 = ['peachpuff']
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True)

plt.axhline(0.66, lw=3, ls='--', c='k')

# Tick label fixed from 'NNLS' to 'NLLS', the spelling used by every other
# label in this notebook ('Full NLLS', 'NLLS Comp').
plt.xticks([1.3, 1.9], ['SBI', 'NLLS'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc+'MS_Ax_SSIM_Frac.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of the Prec*_Frac values: Full and reduced ('Red.') results for
# SBI (left pair) and NLLS (right pair), with a hatched band marking the
# interquartile range of each Full result across its pair.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

# (values, x-position, box colour, scatter colour), drawn in this order.
for source, x_pos, box_col, dot_col in (
    (PrecFull_SBI_Frac, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_Frac, 1.1, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_Frac, 1.8, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_Frac, 2.15, 'sandybrown', 'peachpuff'),
):
    y_data = np.array(source)
    g_pos = np.array([x_pos])
    colors = [box_col]
    colors2 = [dot_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

plt.xticks([0.65, 1.1, 1.8, 2.15], ['Full', 'Red.', 'Full', 'Red.'], fontsize=32, rotation=90)

# Interquartile band of the Full result, spanning both boxes of the pair.
# The upper bound was the 77th percentile, which looks like a typo for the
# third quartile (75), the counterpart of the 25 used for the lower bound.
# NaNs are dropped before the percentiles are taken.
for full_values, x_start, x_stop, band_col in (
    (PrecFull_NLLS_Frac, 1.7, 2.3, 'sandybrown'),
    (PrecFull_SBI_Frac, 0.55, 1.25, 'mediumturquoise'),
):
    values = np.array(full_values)
    q1, q3 = np.percentile(values[~np.isnan(values)], [25, 75])
    x = np.arange(x_start, x_stop, 0.05)
    y1 = np.ones_like(x) * q1
    y2 = np.ones_like(x) * q3
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax1.set_yticks([0, 0.1, 0.2])
if Save:
    plt.savefig(FigLoc+'MS_Ax_Prec_Frac.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of SBI_comp_MD and LS_comp_MD side by side, with a dashed
# horizontal reference line at 0.66.
fig, ax = plt.subplots(figsize=(3.2, 4.8))

# Make `ax` the current axes: the plt.* calls below act on it.
plt.sca(ax)

# SBI box at x = 1.3 (scatter points drawn with alpha 0.5).
y_data = np.array(SBI_comp_MD)
g_pos = np.array([1.3])
colors = ['lightseagreen']
colors2 = ['paleturquoise']
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True, scatter_alpha=0.5)

# NLLS box at x = 1.9.
y_data = np.array(LS_comp_MD)
g_pos = np.array([1.9])
colors = ['sandybrown']
colors2 = ['peachpuff']
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True)

plt.axhline(0.66, lw=3, ls='--', c='k')

# Tick label fixed from 'NNLS' to 'NLLS', the spelling used by every other
# label in this notebook ('Full NLLS', 'NLLS Comp').
plt.xticks([1.3, 1.9], ['SBI', 'NLLS'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc+'MS_Ax_SSIM_MD.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of the Prec*_MD values: Full and reduced ('Red.') results for
# SBI (left pair) and NLLS (right pair), with a hatched band marking the
# interquartile range of each Full result across its pair.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

# (values, x-position, box colour, scatter colour), drawn in this order.
for source, x_pos, box_col, dot_col in (
    (PrecFull_SBI_MD, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_MD, 1.1, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_MD, 1.8, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_MD, 2.15, 'sandybrown', 'peachpuff'),
):
    y_data = np.array(source)
    g_pos = np.array([x_pos])
    colors = [box_col]
    colors2 = [dot_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

plt.xticks([0.65, 1.1, 1.8, 2.15], ['Full', 'Red.', 'Full', 'Red.'], fontsize=32, rotation=90)

# Interquartile band of the Full result, spanning both boxes of the pair.
# The upper bound was the 77th percentile, which looks like a typo for the
# third quartile (75), the counterpart of the 25 used for the lower bound.
# NaNs are dropped before the percentiles are taken.
for full_values, x_start, x_stop, band_col in (
    (PrecFull_NLLS_MD, 1.7, 2.3, 'sandybrown'),
    (PrecFull_SBI_MD, 0.55, 1.25, 'mediumturquoise'),
):
    values = np.array(full_values)
    q1, q3 = np.percentile(values[~np.isnan(values)], [25, 75])
    x = np.arange(x_start, x_stop, 0.05)
    y1 = np.ones_like(x) * q1
    y2 = np.ones_like(x) * q3
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
if Save:
    plt.savefig(FigLoc+'MS_Ax_Prec_MD.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of SBI_comp_Dp and LS_comp_Dp side by side, with a dashed
# horizontal reference line at 0.66.
fig, ax = plt.subplots(figsize=(3.2, 4.8))

# Make `ax` the current axes: the plt.* calls below act on it.
plt.sca(ax)

# SBI box at x = 1.3 (scatter points drawn with alpha 0.5).
y_data = np.array(SBI_comp_Dp)
g_pos = np.array([1.3])
colors = ['lightseagreen']
colors2 = ['paleturquoise']
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True, scatter_alpha=0.5)

# NLLS box at x = 1.9.
y_data = np.array(LS_comp_Dp)
g_pos = np.array([1.9])
colors = ['sandybrown']
colors2 = ['peachpuff']
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=True)

plt.axhline(0.66, lw=3, ls='--', c='k')

# Tick label fixed from 'NNLS' to 'NLLS', the spelling used by every other
# label in this notebook ('Full NLLS', 'NLLS Comp').
plt.xticks([1.3, 1.9], ['SBI', 'NLLS'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc+'MS_Ax_SSIM_Dperp.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of the Prec*_Dp values: Full and reduced ('Red.') results for
# SBI (left pair) and NLLS (right pair), with a hatched band marking the
# interquartile range of each Full result across its pair.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

# (values, x-position, box colour, scatter colour), drawn in this order.
for source, x_pos, box_col, dot_col in (
    (PrecFull_SBI_Dp, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI_Dp, 1.1, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS_Dp, 1.8, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS_Dp, 2.15, 'sandybrown', 'peachpuff'),
):
    y_data = np.array(source)
    g_pos = np.array([x_pos])
    colors = [box_col]
    colors2 = [dot_col]
    BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

plt.xticks([0.65, 1.1, 1.8, 2.15], ['Full', 'Red.', 'Full', 'Red.'], fontsize=32, rotation=90)

# Interquartile band of the Full result, spanning both boxes of the pair.
# The upper bound was the 77th percentile, which looks like a typo for the
# third quartile (75), the counterpart of the 25 used for the lower bound.
# NaNs are dropped before the percentiles are taken.
for full_values, x_start, x_stop, band_col in (
    (PrecFull_NLLS_Dp, 1.7, 2.3, 'sandybrown'),
    (PrecFull_SBI_Dp, 0.55, 1.25, 'mediumturquoise'),
):
    values = np.array(full_values)
    q1, q3 = np.percentile(values[~np.isnan(values)], [25, 75])
    x = np.arange(x_start, x_stop, 0.05)
    y1 = np.ones_like(x) * q1
    y2 = np.ones_like(x) * q3
    plt.fill_between(x, y1, y2, color=band_col, zorder=10, alpha=0.2, hatch='//')

ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
if Save:
    plt.savefig(FigLoc+'MS_Ax_Prec_Dperp.pdf', format='PDF', transparent=True, bbox_inches='tight')