# Imports

In [None]:
from DTI_funcs import *

%load_ext autoreload
%autoreload 2

In [None]:
# Run-wide settings shared by the cells below.
# NOTE(review): `Save` is not read by any of the visible cells — confirm it is still needed.
Save = False

InferSamples = 1_000  # posterior samples drawn per signal in every `Network.sample` call
TestSamples = 200     # random test tensors per SNR level in Figure 2c/2d

## Network

In [None]:
# HCP subject 1, b=1000 shell: full gradient table plus a 7-volume subset
# (volume 0 followed by 6 well-spread directions) for the "minimal" acquisition.
fdwi = HCPDir + '/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = HCPDir + '/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = HCPDir + '/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvals = bvalsHCP, bvecs = bvecsHCP)

# Greedy farthest-point sampling: start from volume 1, then repeatedly add the
# candidate whose closest already-selected b-vector is farthest away.
# Volume 0 is prepended unconditionally afterwards, so it is never a candidate here
# (otherwise it could be picked and appear twice in the subset).
# Candidates are kept in ascending order so argmax ties resolve to the lowest index.
# NOTE(review): distances are Euclidean on signed b-vectors, while the tensor signal is
# symmetric under g -> -g, so a near-antipodal pick adds little information — confirm intended.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):  # 1 seed + 5 picks = 6 directions; volume 0 makes 7 volumes in total
    remaining_indices = [k for k in range(len(bvecsHCP)) if k != 0 and k not in selected_indices]

    # Distance from each candidate to its nearest already-selected point.
    min_distances = np.min(distance_matrix[np.ix_(remaining_indices, selected_indices)], axis=1)

    # Pick the candidate that maximises that nearest-point distance.
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices = [0]+selected_indices

bvalsHCP7 = bvalsHCP[selected_indices]
bvecsHCP7 = bvecsHCP[selected_indices]
gtabHCP7 = gradient_table(bvals = bvalsHCP7, bvecs = bvecsHCP7)

In [None]:
# Load the trained SBI posterior from NetworkDir if it is cached; otherwise simulate a
# training set, train NPE on it and cache the resulting posterior there.
# NOTE(review): pickle.load can execute arbitrary code — only load a file you produced.
if os.path.exists(NetworkDir+"DTI_Network.pickle"):
    with open(NetworkDir+"DTI_Network.pickle", "rb") as handle:
        Network = pickle.load(handle)
    print('loaded')
else:
    TrainingSamples = 1_000_000
    Obs = []
    Par = []
    D_prior = []
    # Prior draws: MD in [0, 0.005), FA in [0, 0.999), S0 in [20, 2500).
    MD_prior = np.random.rand(TrainingSamples)*0.005
    FA_prior = np.random.rand(TrainingSamples)*0.999
    S0_prior = np.random.uniform(20,2500,TrainingSamples)

    # Number of diffusion directions per training sample: 80% use exactly 6,
    # the remaining 20% use between 6 and 67 (np.arange(6,68)).
    R1 = (np.ones(int(TrainingSamples*0.8))*6).astype(int)
    R2 = np.random.choice(np.arange(6,68),int(TrainingSamples*0.2))

    # Each sample uses volume 0 plus r distinct volumes drawn from 1..68.
    R = [np.insert(np.random.choice(np.arange(1,69),r,replace=False),0,0) for r in np.hstack([R1,R2])]
    # Per-sample multiplier applied to the HCP b-values (1x or 2x).
    bval_choice = np.random.choice([1,2],TrainingSamples,replace=True)

    for m,f,r,S,bv in tqdm(zip(MD_prior,FA_prior,R,S0_prior,bval_choice),position=0,leave=True):
        dt = random_diffusion_tensor(m, f)
        D_prior.append(mat_to_vals(dt))
        gtab = gradient_table(bvals = gtabHCP.bvals[r]*bv,bvecs = gtabHCP.bvecs[r])
        # 4th positional argument of CustomSimulator — presumably the SNR, as the
        # later cells pass the SNR level in that slot.
        a = 50
        Obs.append(DTIFeatures(gtab.bvecs,gtab.bvals,CustomSimulator(dt,gtab,S,a)))
        # Training targets: 6 tensor values followed by S0.
        Par.append(np.hstack([mat_to_vals(dt),S]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs= torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Prefer Apple's MPS backend, but fall back to CUDA or CPU so training also
    # works on machines without MPS (a hard-coded 'mps' device fails there).
    if torch.backends.mps.is_available():
        device = 'mps'
    elif torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Create inference object. Here, NPE is used.
    inference = SNPE(device=device)

    # Pass the simulations to the inference object (kept on the CPU).
    inference = inference.append_simulations(Par, Obs,data_device='cpu')

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(training_batch_size = 512)

    # Box bounds for the 7 parameters: 6 tensor values followed by S0.
    # The +-1e-2 entries at positions 0, 2 and 5 look like the diagonal — TODO confirm
    # against mat_to_vals' ordering.
    low = torch.tensor([-1e-2,-5e-3,-1e-2,-5e-3,-5e-3,-1e-2,0])
    high = torch.tensor([1e-2,5e-3,1e-2,5e-3,5e-3,1e-2,3000])

    prior_bounds = BoxUniform(low=low, high=high)
    Network = DirectPosterior(density_estimator.cpu(), prior=prior_bounds)
    # Re-check before writing so a file cached by another run during training is not clobbered.
    if not os.path.exists(NetworkDir+"DTI_Network.pickle"):
        with open(NetworkDir+"DTI_Network.pickle", "wb") as handle:
            pickle.dump(Network, handle)

# Figure 2

In [None]:
# Figure 2 setup: the same gradient tables as in the Network section — the full HCP
# table plus a 7-volume subset (volume 0 followed by 6 well-spread directions).
fdwi = HCPDir + '/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = HCPDir + '/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = HCPDir + '/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvals = bvalsHCP, bvecs = bvecsHCP)

# Greedy farthest-point sampling: start from volume 1, then repeatedly add the
# candidate whose closest already-selected b-vector is farthest away.
# Volume 0 is prepended unconditionally afterwards, so it is never a candidate here
# (otherwise it could be picked and appear twice in the subset).
# Candidates are kept in ascending order so argmax ties resolve to the lowest index.
# NOTE(review): distances are Euclidean on signed b-vectors, while the tensor signal is
# symmetric under g -> -g, so a near-antipodal pick adds little information — confirm intended.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):  # 1 seed + 5 picks = 6 directions; volume 0 makes 7 volumes in total
    remaining_indices = [k for k in range(len(bvecsHCP)) if k != 0 and k not in selected_indices]

    # Distance from each candidate to its nearest already-selected point.
    min_distances = np.min(distance_matrix[np.ix_(remaining_indices, selected_indices)], axis=1)

    # Pick the candidate that maximises that nearest-point distance.
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices = [0]+selected_indices

bvalsHCP7 = bvalsHCP[selected_indices]
bvecsHCP7 = bvecsHCP[selected_indices]
gtabHCP7 = gradient_table(bvals = bvalsHCP7, bvecs = bvecsHCP7)

## a

In [None]:
# Ground-truth tensor for Figure 2a/2b.
MD_truth = 0.001
FA_truth = 0.5
# Seed before drawing so the ground-truth tensor is the same on every run
# (assumes random_diffusion_tensor uses NumPy's global RNG — TODO confirm).
np.random.seed(13)
dtTruth = random_diffusion_tensor(MD_truth, FA_truth)

In [None]:
# Noise-free signal plus one noisy realisation per SNR level (20, 10, 5, 2),
# all drawn after a fixed seed.
np.random.seed(13)
Truth = CustomSimulator(dtTruth, gtabHCP, S0=200, snr=None)
SNR = []
for scale in (20, 10, 5, 2):
    SNR.append(CustomSimulator(dtTruth, gtabHCP, S0=200, snr=scale))


In [None]:
# Figure 2a: noise-free signal (black) vs. one noisy realisation (grey, dashed),
# one thin strip per SNR level; only the first strip carries a legend.
for idx in range(4):
    plt.subplots(figsize=(6,1))
    if idx == 0:
        plt.plot(Truth,'k',lw=2,label='True signal')
        plt.plot(SNR[idx],'gray',lw=2,ls='--',label='Noisy signal')
    else:
        plt.plot(Truth,'k',lw=2)
        plt.plot(SNR[idx],'gray',lw=2,ls='--')
    plt.axis('off')
    if idx == 0:
        legend = plt.legend(ncols=2, loc=1, bbox_to_anchor=(1.03,1.7), fontsize=26,
                            columnspacing=0.3, handlelength=0.4, handletextpad=0.1)
        # Thicker legend lines than the plotted ones, for readability.
        for handle in legend.get_lines():
            handle.set_linewidth(6)
    plt.show()

## b

In [None]:
# Figure 2b inputs: 200 noisy repetitions of the ground-truth signal per SNR level
# (drawn in the order 20, 10, 5, 2), each fitted with DIPY's NLLS tensor model.
SNR20, SNR10, SNR5, SNR2 = [
    np.vstack([CustomSimulator(dtTruth, gtabHCP, S0=200, snr=level) for _ in range(200)])
    for level in (20, 10, 5, 2)
]

tenmodel = dti.TensorModel(gtabHCP, return_S0_hat=True, fit_method='NLLS')
fa_md_by_snr = {}
for level, signals in zip((20, 10, 5, 2), (SNR20, SNR10, SNR5, SNR2)):
    tenfit = tenmodel.fit(signals)
    fa_md_by_snr[level] = (dti.fractional_anisotropy(tenfit.evals),
                           dti.mean_diffusivity(tenfit.evals))
(FA20, MD20), (FA10, MD10), (FA5, MD5), (FA2, MD2) = (fa_md_by_snr[level] for level in (20, 10, 5, 2))

In [None]:
# Figure 2b, SBI estimate #1: for each noisy signal, average the (MD, FA) pairs
# computed from every posterior sample. Each SNR level restarts from the same seeds.
# NOTE(review): the next cell assigns the same MD*_SBI / FA*_SBI names, so under a
# linear Run All these values are overwritten before the plotting cells read them.
def _sbi_md_fa_sample_average(signals):
    """Return (MD list, FA list) of posterior-sample-averaged MD/FA, one entry per signal row."""
    torch.manual_seed(2)
    np.random.seed(2)
    md_values, fa_values = [], []
    for S in tqdm(signals, position=0, leave=True):
        samples = Network.sample((InferSamples,), x=DTIFeatures(gtabHCP.bvecs, gtabHCP.bvals, S),
                                 show_progress_bars=False)
        per_sample = np.array([MD_FA(vals_to_mat(p)) for p in samples])
        md, fa = np.mean(per_sample, axis=0)
        md_values.append(md)
        fa_values.append(fa)
    return md_values, fa_values

gtab = gtabHCP
MD20_SBI, FA20_SBI = _sbi_md_fa_sample_average(SNR20)
MD10_SBI, FA10_SBI = _sbi_md_fa_sample_average(SNR10)
MD5_SBI, FA5_SBI = _sbi_md_fa_sample_average(SNR5)
MD2_SBI, FA2_SBI = _sbi_md_fa_sample_average(SNR2)

In [None]:
# Figure 2b, SBI estimate #2: average the posterior samples into one tensor, clip its
# negative eigenvalues, then take MD/FA of that tensor. Each SNR level restarts from
# the same seeds. These names are also assigned by the previous cell.
def _sbi_md_fa_mean_tensor(signals):
    """Return (MD list, FA list) of the eigenvalue-clipped posterior-mean tensor, one entry per signal row."""
    torch.manual_seed(2)
    np.random.seed(2)
    md_values, fa_values = [], []
    for S in tqdm(signals, position=0, leave=True):
        samples = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs, gtab.bvals, S),
                                 show_progress_bars=False)
        mean_tensor = clip_negative_eigenvalues(vals_to_mat(samples.mean(axis=0)))
        md, fa = MD_FA(mean_tensor)
        md_values.append(md)
        fa_values.append(fa)
    return md_values, fa_values

gtab = gtabHCP
MD20_SBI, FA20_SBI = _sbi_md_fa_mean_tensor(SNR20)
MD10_SBI, FA10_SBI = _sbi_md_fa_mean_tensor(SNR10)
MD5_SBI, FA5_SBI = _sbi_md_fa_mean_tensor(SNR5)
MD2_SBI, FA2_SBI = _sbi_md_fa_mean_tensor(SNR2)

In [None]:
# Figure 2b, MD panel: NLLS boxes at x = 1..4 (SNR 20, 10, 5, 2), SBI boxes 0.3 to
# their right, dashed line = true MD. This panel's legend shows only the NLLS entry.
fig, ax = plt.subplots(figsize=(6.4,2.4))
sbi_md = np.array([MD20_SBI, MD10_SBI, MD5_SBI, MD2_SBI])
nlls_md = np.array([MD20, MD10, MD5, MD2])
BoxPlots(sbi_md, np.array([1.3,2.3,3.3,4.3]), ['lightseagreen']*4, ['paleturquoise']*4,
         ax, widths=0.3, scatter=True)
BoxPlots(nlls_md, np.array([1,2,3,4]), ['sandybrown']*4, ['peachpuff']*4,
         ax, widths=0.3, scatter=True)

ax.axhline(MD_truth, c='k', lw=3, ls='--', label='True MD')
ax.set_xticks([])
ax.legend(handles=[mpatches.Patch(color='sandybrown', label='NLLS Fit')],
          loc='upper left', bbox_to_anchor=(0,0.5), frameon=False, ncols=1,
          fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1,1))
plt.yticks([0.001])

In [None]:
# Figure 2b, MD panel: NLLS boxes at x = 1..4 (SNR 20, 10, 5, 2), SBI boxes 0.3 to
# their right, dashed line = true MD. This panel's legend shows only the NLLS entry.
fig, ax = plt.subplots(figsize=(6.4,2.4))
sbi_md = np.array([MD20_SBI, MD10_SBI, MD5_SBI, MD2_SBI])
nlls_md = np.array([MD20, MD10, MD5, MD2])
BoxPlots(sbi_md, np.array([1.3,2.3,3.3,4.3]), ['lightseagreen']*4, ['paleturquoise']*4,
         ax, widths=0.3, scatter=True)
BoxPlots(nlls_md, np.array([1,2,3,4]), ['sandybrown']*4, ['peachpuff']*4,
         ax, widths=0.3, scatter=True)

ax.axhline(MD_truth, c='k', lw=3, ls='--', label='True MD')
ax.set_xticks([])
ax.legend(handles=[mpatches.Patch(color='sandybrown', label='NLLS Fit')],
          loc='upper left', bbox_to_anchor=(0,0.5), frameon=False, ncols=1,
          fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1,1))
plt.yticks([0.001])

In [None]:
# Figure 2b, FA panel: NLLS boxes at x = 1..4, SBI boxes 0.3 to their right, dashed
# line = true FA, x labels = SNR. This panel's legend shows only the SBI entry.
fig, ax = plt.subplots(figsize=(6.4,2.4))
sbi_fa = np.array([FA20_SBI, FA10_SBI, FA5_SBI, FA2_SBI])
nlls_fa = np.array([FA20, FA10, FA5, FA2])
BoxPlots(sbi_fa, np.array([1.3,2.3,3.3,4.3]), ['lightseagreen']*4, ['paleturquoise']*4,
         ax, widths=0.3, scatter=True)
BoxPlots(nlls_fa, np.array([1,2,3,4]), ['sandybrown']*4, ['peachpuff']*4,
         ax, widths=0.3, scatter=True)

ax.axhline(FA_truth, c='k', lw=3, ls='--', label='True FA')
plt.xticks([1,2,3,4], [20,10,5,2], fontsize=28)
ax.legend(handles=[mpatches.Patch(color='lightseagreen', label='SBI Fit')],
          loc='upper left', bbox_to_anchor=(0,0.5), frameon=False, ncols=1,
          fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3)
plt.yticks([0,1])

In [None]:
# Figure 2b, FA panel: NLLS boxes at x = 1..4, SBI boxes 0.3 to their right, dashed
# line = true FA, x labels = SNR. This panel's legend shows only the SBI entry.
fig, ax = plt.subplots(figsize=(6.4,2.4))
sbi_fa = np.array([FA20_SBI, FA10_SBI, FA5_SBI, FA2_SBI])
nlls_fa = np.array([FA20, FA10, FA5, FA2])
BoxPlots(sbi_fa, np.array([1.3,2.3,3.3,4.3]), ['lightseagreen']*4, ['paleturquoise']*4,
         ax, widths=0.3, scatter=True)
BoxPlots(nlls_fa, np.array([1,2,3,4]), ['sandybrown']*4, ['peachpuff']*4,
         ax, widths=0.3, scatter=True)

ax.axhline(FA_truth, c='k', lw=3, ls='--', label='True FA')
plt.xticks([1,2,3,4], [20,10,5,2], fontsize=28)
ax.legend(handles=[mpatches.Patch(color='lightseagreen', label='SBI Fit')],
          loc='upper left', bbox_to_anchor=(0,0.5), frameon=False, ncols=1,
          fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3)
plt.yticks([0,1])

## c

In [None]:
# Figure 2c: absolute MD (left) and FA (right) errors of SBI vs. NLLS on TestSamples
# random tensors, full HCP gradient table, SNR = 20, 10, 5, 2.
# SBI estimate here: mean of the per-posterior-sample (MD, FA) values.
fig,ax = plt.subplots(1,2,figsize=(9,3))
gtab = gtabHCP
torch.manual_seed(10)
# The NLLS model depends only on the gradient table, so it is built once.
tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
for i,kk in enumerate([20,10,5,2]):
    # Re-seeding per SNR level redraws the same MD/FA test values every time;
    # only the noise level kk changes between boxes.
    np.random.seed(10)
    print('---Starting process---')
    MD_prior = np.random.uniform(0.0005,0.003,TestSamples)
    FA_prior = np.random.rand(TestSamples)*0.999
    S0_prior = np.ones(TestSamples)*200
    Noise_prior = np.ones(TestSamples)*kk

    print('---Generating observations and parameters---')
    D_prior  = [random_diffusion_tensor(m, f) for m,f in zip(MD_prior,FA_prior)]
    Obs_test  = np.array([CustomSimulator(dt,gtab,S,n) for dt,S,n in zip(D_prior,S0_prior,Noise_prior)])
    Pars_test = np.column_stack([mat_to_vals(dt) for dt in D_prior]).T
    True_MD_FA = np.array([MD_FA(vals_to_mat(P)) for P in Pars_test]).T

    # NLLS fit; rows are (MD, FA).
    NLLSFit = tenmodel.fit(Obs_test)
    MD_FA_NLLS = np.array([NLLSFit.md,NLLSFit.fa])

    # SBI fit; rows are (MD, FA).
    MD_FA_SBI= []
    for O in tqdm(Obs_test,position=0,leave=True):
        posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,O),show_progress_bars=False)
        MD,FA = np.array([MD_FA(vals_to_mat(p)) for p in posterior_samples_1]).mean(axis=0)
        MD_FA_SBI.append([MD,FA])
    MD_FA_SBI = np.array(MD_FA_SBI).T

    # Row 0 (MD) goes to ax[0], row 1 (FA) to ax[1]; the SBI box sits at 1.3 + i
    # and the NLLS box at 1 + i.
    for row, axis in enumerate(ax):
        BoxPlots(abs(MD_FA_SBI[row,:]-True_MD_FA[row,:]), np.array([1.3 + i]),
                 ['lightseagreen']*4, ['paleturquoise']*4, axis, widths=0.3, scatter=False)
        BoxPlots(abs(MD_FA_NLLS[row,:]-True_MD_FA[row,:]), np.array([1+i]),
                 ['sandybrown']*4, ['peachpuff']*4, axis, widths=0.3, scatter=False)

# Limits and grids only need setting once, after all boxes are drawn.
ax[1].set_ylim([0,1])
ax[0].yaxis.grid(True)
ax[1].yaxis.grid(True)
plt.tight_layout()
for ll,a in enumerate(ax):
    plt.sca(a)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    # Labels centred between each NLLS (x) / SBI (x + 0.3) pair.
    plt.xticks([1.15, 2.15, 3.15, 4.15], [20,10,5,2],fontsize=32)
    # NOTE(review): legend colours come from SBIFit / WLSFit, while the boxes are drawn in
    # 'lightseagreen' / 'sandybrown' — confirm they match.
    if ll == 0:
        handles = [Line2D([0], [0], color=SBIFit, lw=4, label='SBI')]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
        plt.ylim([-0.00025242580421944123, 0.005317867177548843])
    elif ll == 1:
        handles = [Line2D([0], [0], color=WLSFit, lw=4, label='NLLS')]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
        plt.ylim([-0.05,1])
        plt.yticks([0,1])

In [None]:
# Figure 2c: absolute MD (left) and FA (right) errors of SBI vs. NLLS on TestSamples
# random tensors, full HCP gradient table, SNR = 20, 10, 5, 2.
# SBI estimate here: MD/FA of the posterior-mean tensor with negative eigenvalues clipped.
fig,ax = plt.subplots(1,2,figsize=(9,3))
gtab = gtabHCP
torch.manual_seed(10)
# The NLLS model depends only on the gradient table, so it is built once.
tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
for i,kk in enumerate([20,10,5,2]):
    # Re-seeding per SNR level redraws the same MD/FA test values every time;
    # only the noise level kk changes between boxes.
    np.random.seed(10)
    print('---Starting process---')
    MD_prior = np.random.uniform(0.0005,0.003,TestSamples)
    FA_prior = np.random.rand(TestSamples)*0.999
    S0_prior = np.ones(TestSamples)*200
    Noise_prior = np.ones(TestSamples)*kk

    print('---Generating observations and parameters---')
    D_prior  = [random_diffusion_tensor(m, f) for m,f in zip(MD_prior,FA_prior)]
    Obs_test  = np.array([CustomSimulator(dt,gtab,S,n) for dt,S,n in zip(D_prior,S0_prior,Noise_prior)])
    Pars_test = np.column_stack([mat_to_vals(dt) for dt in D_prior]).T
    True_MD_FA = np.array([MD_FA(vals_to_mat(P)) for P in Pars_test]).T

    # NLLS fit; rows are (MD, FA).
    NLLSFit = tenmodel.fit(Obs_test)
    MD_FA_NLLS = np.array([NLLSFit.md,NLLSFit.fa])

    # SBI fit; rows are (MD, FA).
    MD_FA_SBI= []
    for O in tqdm(Obs_test,position=0,leave=True):
        posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,O),show_progress_bars=False)
        mat_guess = clip_negative_eigenvalues(vals_to_mat(posterior_samples_1.mean(axis=0)))
        MD,FA = MD_FA(mat_guess)
        MD_FA_SBI.append([MD,FA])
    MD_FA_SBI = np.array(MD_FA_SBI).T

    # Row 0 (MD) goes to ax[0], row 1 (FA) to ax[1]; the SBI box sits at 1.3 + i
    # and the NLLS box at 1 + i.
    for row, axis in enumerate(ax):
        BoxPlots(abs(MD_FA_SBI[row,:]-True_MD_FA[row,:]), np.array([1.3 + i]),
                 ['lightseagreen']*4, ['paleturquoise']*4, axis, widths=0.3, scatter=False)
        BoxPlots(abs(MD_FA_NLLS[row,:]-True_MD_FA[row,:]), np.array([1+i]),
                 ['sandybrown']*4, ['peachpuff']*4, axis, widths=0.3, scatter=False)

# Limits and grids only need setting once, after all boxes are drawn.
ax[1].set_ylim([0,1])
ax[0].yaxis.grid(True)
ax[1].yaxis.grid(True)
plt.tight_layout()
for ll,a in enumerate(ax):
    plt.sca(a)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    # Labels centred between each NLLS (x) / SBI (x + 0.3) pair.
    plt.xticks([1.15, 2.15, 3.15, 4.15], [20,10,5,2],fontsize=32)
    # NOTE(review): legend colours come from SBIFit / WLSFit, while the boxes are drawn in
    # 'lightseagreen' / 'sandybrown' — confirm they match.
    if ll == 0:
        handles = [Line2D([0], [0], color=SBIFit, lw=4, label='SBI')]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
        plt.ylim([-0.00025242580421944123, 0.005317867177548843])
    elif ll == 1:
        handles = [Line2D([0], [0], color=WLSFit, lw=4, label='NLLS')]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
        plt.ylim([-0.05,1])
        plt.yticks([0,1])

## d

In [None]:
# Figure 2d: same comparison as Figure 2c, but on the 7-volume gradient subset
# (gtabHCP7). SBI estimate: MD/FA of the clipped posterior-mean tensor.
fig,ax = plt.subplots(1,2,figsize=(9,3))
gtab = gtabHCP7
torch.manual_seed(10)
# The NLLS model depends only on the gradient table, so it is built once.
tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
for i,kk in enumerate([20,10,5,2]):
    # Re-seeding per SNR level redraws the same MD/FA test values every time;
    # only the noise level kk changes between boxes.
    np.random.seed(10)
    print('---Starting process---')
    MD_prior = np.random.uniform(0.0005,0.003,TestSamples)
    FA_prior = np.random.rand(TestSamples)*0.999
    S0_prior = np.ones(TestSamples)*200
    Noise_prior = np.ones(TestSamples)*kk

    print('---Generating observations and parameters---')
    D_prior  = [random_diffusion_tensor(m, f) for m,f in zip(MD_prior,FA_prior)]
    Obs_test  = np.array([CustomSimulator(dt,gtab,S,n) for dt,S,n in zip(D_prior,S0_prior,Noise_prior)])
    Pars_test = np.column_stack([mat_to_vals(dt) for dt in D_prior]).T
    True_MD_FA = np.array([MD_FA(vals_to_mat(P)) for P in Pars_test]).T

    # NLLS fit; rows are (MD, FA).
    NLLSFit = tenmodel.fit(Obs_test)
    MD_FA_NLLS = np.array([NLLSFit.md,NLLSFit.fa])

    # SBI fit; rows are (MD, FA).
    MD_FA_SBI= []
    for O in tqdm(Obs_test,position=0,leave=True):
        posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,O),show_progress_bars=False)
        mat_guess = clip_negative_eigenvalues(vals_to_mat(posterior_samples_1.mean(axis=0)))
        MD,FA = MD_FA(mat_guess)
        MD_FA_SBI.append([MD,FA])
    MD_FA_SBI = np.array(MD_FA_SBI).T

    # Row 0 (MD) goes to ax[0], row 1 (FA) to ax[1]; the SBI box sits at 1.3 + i
    # and the NLLS box at 1 + i.
    for row, axis in enumerate(ax):
        BoxPlots(abs(MD_FA_SBI[row,:]-True_MD_FA[row,:]), np.array([1.3 + i]),
                 ['lightseagreen']*4, ['paleturquoise']*4, axis, widths=0.3, scatter=False)
        BoxPlots(abs(MD_FA_NLLS[row,:]-True_MD_FA[row,:]), np.array([1+i]),
                 ['sandybrown']*4, ['peachpuff']*4, axis, widths=0.3, scatter=False)

# Limits and grids only need setting once, after all boxes are drawn.
ax[1].set_ylim([0,1])
ax[0].yaxis.grid(True)
ax[1].yaxis.grid(True)
plt.tight_layout()
for ll,a in enumerate(ax):
    plt.sca(a)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    # Labels centred between each NLLS (x) / SBI (x + 0.3) pair.
    plt.xticks([1.15, 2.15, 3.15, 4.15], [20,10,5,2],fontsize=32)
    if ll == 0:
        plt.ylim([-0.00025242580421944123, 0.005317867177548843])
    elif ll == 1:
        plt.ylim([-0.05,1])
        plt.yticks([0,1])

# Figure 3

In [None]:
# Figure 3 input: subject-1 b=1000 DWI, resliced from a declared 1.5 mm to 2.5 mm
# isotropic voxel size, shifted to strictly positive values and brain-masked.
# Read from HCPDir, the same directory the gradient table gtabHCP was loaded from,
# so the volumes and gradients are guaranteed to belong together.
fdwi = HCPDir + '/Pat'+str(1)+'/diff_1k.nii.gz'

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# Per-voxel minimum over volumes, clipped to <= 0; adding its magnitude plus 1e-5
# makes every signal strictly positive.
floor = np.clip(data.min(axis=-1),-np.inf,0)
data2 = data + abs(floor)[:,:,:,None] + 1e-5
axial_middle = data2.shape[2] // 2
maskdata, mask = median_otsu(data2, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)
# NOTE(review): axial_middle is computed on the uncropped volume but indexes the
# autocropped mask/maskdata — confirm this is the intended slice.
mask_cutout = np.copy(mask[:,:,axial_middle])


In [None]:
# Full-data SBI maps for the axial slice: posterior-mean parameters per brain pixel
# (parallelised with joblib), then MD and FA from the tensor eigenvalues.
# Pixels whose signals sum to zero in this slice are skipped.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# Get the indices where mask is True
indices = np.argwhere(mask)

def optimize_pixel(i, j):
    """Return (i, j, posterior-mean parameter vector) for pixel (i, j) of the axial slice."""
    # Same seed for every pixel, so a pixel's result does not depend on which
    # worker runs it or in which order.
    torch.manual_seed(10)
    posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtabHCP.bvecs,gtabHCP.bvals,maskdata[i, j,axial_middle, :]),show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

ArrShape = mask.shape
print('start')
# tqdm wraps the task generator, so it tracks dispatch rather than completion.
results = Parallel(n_jobs=24)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices,position=0,leave=True)
)
print('end')

# 7 parameters per pixel (6 tensor values + S0); pixels outside the mask stay 0.
InferredParams = np.zeros(list(ArrShape) + [7])
for i, j, x in results:
    InferredParams[i, j] = x

# Size the maps from the slice itself instead of a hard-coded 55 x 64 grid.
MD_SBIFull = np.zeros(ArrShape)
FA_SBIFull = np.zeros(ArrShape)
for i, j in np.ndindex(*ArrShape):
    Eigs = np.linalg.eigh(vals_to_mat(InferredParams[i,j,:6]))[0]
    MD_SBIFull[i,j] = np.mean(Eigs)
    FA_SBIFull[i,j] = FracAni(Eigs,np.mean(Eigs))
# All-zero (unmasked) pixels presumably give FA = NaN; display them as 0.
FA_SBIFull[np.isnan(FA_SBIFull )] = 0

In [None]:
# Reference NLLS tensor fit of the full-data axial slice, compared with SBI in Figure 3a/3c.
tenmodel = dti.TensorModel(gtabHCP, fit_method='NLLS', return_S0_hat=True)
slice_signals = maskdata[:, :, axial_middle, :]
tenfit = tenmodel.fit(slice_signals)

## a

In [None]:
# Figure 3a: full-data SBI MD map, background set to NaN (blank).
temp = np.where(mask_cutout, MD_SBIFull, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()


In [None]:
# Figure 3a: full-data NLLS MD map with colourbar.
# NOTE(review): this panel and the SBI MD panel autoscale independently, so the single
# colourbar may not describe both — consider passing a shared vmin/vmax.
temp = np.where(mask_cutout, tenfit.md, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))

In [None]:
# Figure 3a, difference: SBI minus NLLS MD, diverging colour map centred on 0.
# NOTE(review): TwoSlopeNorm raises unless min < 0 < max — holds only if the difference changes sign.
data = np.where(mask_cutout.T, MD_SBIFull.T - tenfit.md.T, np.nan)
diff_lo, diff_hi = np.nanmin(data), np.nanmax(data)
norm = TwoSlopeNorm(vmin=diff_lo, vcenter=0, vmax=diff_hi)
plt.imshow(data, cmap='seismic', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
ticks = [diff_lo, 0, diff_hi]
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))

## c

In [None]:
# Figure 3c: full-data SBI FA map on a fixed 0-1 colour scale.
temp = np.where(mask_cutout, FA_SBIFull, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=1)
plt.axis('off')
vmin, vmax = img.get_clim()

In [None]:
# Figure 3c: full-data NLLS FA map. Uses the same fixed 0-1 colour scale as the SBI FA
# panel (and as the colourbar ticks below), so the panels and colourbar match;
# autoscaling would stretch the map to this slice's own FA range.
temp = np.copy(tenfit.fa)

temp[~mask_cutout] = math.nan
img = plt.imshow(temp.T,cmap='hot',vmin=0,vmax=1)
plt.axis('off')
vmin, vmax = img.get_clim()
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
cbar.set_ticks([0,0.2,0.4,0.6,0.8,1])

In [None]:
# Figure 3c, difference: SBI minus NLLS FA on a fixed symmetric [-1, 1] scale.
data = np.where(mask_cutout.T, FA_SBIFull.T - tenfit.fa.T, np.nan)
plt.imshow(data, cmap='seismic', vmin=-1, vmax=1)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
ticks = [-1, 0, 1]
cbar.set_ticks(ticks)

## b

In [None]:
# Minimal-acquisition SBI for the same slice: same pipeline as the full-data cell, but
# each pixel only sees the 7 volumes in `selected_indices` (gradient table gtabHCP7).
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)

def optimize_pixel(i, j):
    """Return (i, j, posterior-mean parameters) for pixel (i, j) using only the 7 selected volumes."""
    torch.manual_seed(10)  # identical seed for every pixel
    signal = maskdata[i, j, axial_middle, selected_indices]
    features = DTIFeatures(gtabHCP7.bvecs, gtabHCP7.bvals, signal)
    samples = Network.sample((InferSamples,), x=features, show_progress_bars=False)
    return i, j, samples.mean(axis=0)

ArrShape = mask.shape
print('start')
results = Parallel(n_jobs=24)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices, position=0, leave=True)
)
print('end')

# Scatter the per-pixel results into a (rows, cols, 7) array; unmasked pixels stay 0.
InferredParams = np.zeros((*ArrShape, 7))
for i, j, x in results:
    InferredParams[i, j] = x

In [None]:
# MD / FA maps of the minimal-acquisition SBI estimate, sized from InferredParams
# itself rather than a hard-coded 55 x 64 grid.
map_shape = InferredParams.shape[:2]
MD_SBIMin = np.zeros(map_shape)
FA_SBIMin = np.zeros(map_shape)
for i, j in np.ndindex(*map_shape):
    Eigs = np.linalg.eigh(vals_to_mat(InferredParams[i,j,:6]))[0]
    MD_SBIMin[i,j] = np.mean(Eigs)
    FA_SBIMin[i,j] = FracAni(Eigs,np.mean(Eigs))
# All-zero (unmasked) pixels presumably give FA = NaN; display them as 0.
FA_SBIMin[np.isnan(FA_SBIMin )] = 0

In [None]:
# Figure 3b: SBI MD map from the 7 selected volumes, fixed 0-4e-3 colour range.
temp = np.where(mask_cutout, MD_SBIMin, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=4e-3)
plt.axis('off')

In [None]:
# NLLS tensor fit of the same slice restricted to the 7 selected volumes.
tenmodel = dti.TensorModel(gtabHCP7, fit_method='NLLS', return_S0_hat=True)
tenfit_min = tenmodel.fit(maskdata[:, :, axial_middle, selected_indices])

In [None]:
# Figure 3b: NLLS MD map from the 7 selected volumes (0-3.5e-3 colour range) with colourbar.
# NOTE(review): the SBI counterpart uses vmax=4e-3 — confirm the differing ranges are intended.
temp = np.where(mask, tenfit_min.md, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=3.5e-3)
plt.axis('off')
vmin, vmax = img.get_clim()
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))

In [None]:
# |MD difference| between MD_SBIFull and the 7-volume MD_SBIMin (first panel), then
# between the full and 7-volume NLLS fits (second panel). Both panels use the colour
# scale derived from the SBI difference so they are directly comparable.
data = np.abs(MD_SBIFull.T - MD_SBIMin.T)
data[~mask.T] = np.nan
md_diff_max = np.nanmax(data)
norm = TwoSlopeNorm(vmin=0, vcenter=md_diff_max / 2, vmax=md_diff_max)
plt.imshow(data, cmap='Reds', norm=norm)
plt.axis('off')
plt.show()

# Upper colourbar tick: the shared maximum floored to 2 significant figures.
# Rounding to a fixed 3 decimals collapsed MD-scale maxima (~1e-3) to 0.0, or rounded
# them above vmax, where matplotlib does not draw the tick.
_tick_step = 10.0 ** (np.floor(np.log10(md_diff_max)) - 1)
ticks = [0, np.floor(md_diff_max / _tick_step) * _tick_step]

data = np.abs(tenfit.md.T - tenfit_min.md.T)
data[~mask.T] = np.nan
norm = TwoSlopeNorm(vmin=0, vcenter=md_diff_max / 2, vmax=md_diff_max)
plt.imshow(data, cmap='Reds', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))

## d — Fractional anisotropy from the 7-volume subset: SBI vs NLLS

In [None]:
# SBI FA map from the 7-volume subset, lightly Gaussian-smoothed and blanked outside mask_cutout.
smoothing_sigma = 0.51
temp = gaussian_filter(FA_SBIMin, sigma=smoothing_sigma).T
temp[~mask_cutout.T] = math.nan
plt.imshow(temp, cmap='hot', vmin=0, vmax=1)
plt.axis('off')


In [None]:
# NLLS FA map from the 7-volume fit, blanked outside the brain mask.
temp = tenfit_min.fa.copy()
temp[~mask] = math.nan
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=1)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))


In [None]:
# |FA difference| between full-protocol and 7-volume estimates: SBI first, then NLLS,
# both on the same fixed 0–1 scale centred at 0.5.
fa_diff_norm = dict(vmin=0, vcenter=0.5, vmax=1)

data = np.abs(FA_SBIFull.T - FA_SBIMin.T)
data[~mask.T] = np.nan
norm = TwoSlopeNorm(**fa_diff_norm)
plt.imshow(data, cmap='Reds', norm=norm)
plt.axis('off')
plt.show()

data = np.abs(tenfit.fa.T - tenfit_min.fa.T)
data[~mask.T] = np.nan
norm = TwoSlopeNorm(**fa_diff_norm)
ticks = [0, 1]
plt.imshow(data, cmap='Reds', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks(ticks)

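## e — Cohort evaluation (32 HCP and 8 MS-study subjects): accuracy, SSIM and precision for the full, 20- and 7-volume protocols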

In [None]:
# Choose a 20-volume subset of the HCP b=1000 protocol (Pat1's gradient scheme; the
# same indices are applied to every HCP subject below) by greedy farthest-point
# sampling on the gradient vectors.
fdwi = '../../HCP_data/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = '../../HCP_data/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = '../../HCP_data/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)


def farthest_point_indices(vectors, n_points, start=0):
    """Greedy farthest-point selection of row indices of ``vectors``.

    Starting from row ``start``, repeatedly add the remaining row whose minimum
    Euclidean distance to the already-selected rows is largest.

    Parameters
    ----------
    vectors : array-like, shape (N, D)
        Points to choose from (here: gradient directions, one per row).
    n_points : int
        Total number of indices to return, including ``start``; capped at N.
    start : int, optional
        Index of the first selected row (default 0).

    Returns
    -------
    list of int
        Selected row indices in selection order (always begins with ``start``).
    """
    vectors = np.asarray(vectors)
    dist = squareform(pdist(vectors))
    selected = [start]
    n_points = min(n_points, len(vectors))
    while len(selected) < n_points:
        chosen = set(selected)
        # Ascending candidate order makes the argmax tie-break explicit: e.g. every
        # unit vector is equidistant from a [0, 0, 0] (b=0) starting row.
        remaining = [k for k in range(len(vectors)) if k not in chosen]
        min_dist = dist[np.ix_(remaining, selected)].min(axis=1)
        selected.append(remaining[int(np.argmax(min_dist))])
    return selected


# Volume 0 plus 19 farthest points -> 20 volumes in total.
selected_indices20 = farthest_point_indices(bvecsHCP, 20, start=0)

In [None]:
# Load the 32 HCP subjects:
#   - reslice 1.5 mm -> 2.5 mm;
#   - brain-mask and crop with median_otsu;
#   - keep the middle axial slice;
#   - build gradient tables for the full, 7-volume (selected_indices) and 20-volume
#     (selected_indices20) protocols.
# Per-subject temporaries carry a *_kk suffix so this loop no longer overwrites the
# Pat1 globals (maskdata, mask, axial_middle, gtabHCP7, data, ...) that the panel
# cells above read — re-running those cells after this one used to plot Pat32 data.
Masks = []
maskdatas = []
axial_middles = []
WMs = []

gTabsF = []
gTabs7 = []
gTabs20 = []

FullDat = []  # NOTE(review): never filled in this cell — remove if unused downstream.
for kk in tqdm(range(32), position=0, leave=True):
    pat_dir = '../../HCP_data/Pat' + str(kk + 1)
    bvals_kk = np.loadtxt(pat_dir + '/bvals_1k.txt')
    bvecs_kk = np.loadtxt(pat_dir + '/bvecs_1k.txt')
    gTabsF.append(gradient_table(bvals=bvals_kk, bvecs=bvecs_kk))

    dwi_kk, affine_kk = load_nifti(pat_dir + '/diff_1k.nii.gz')
    dwi_kk, affine_kk = reslice(dwi_kk, affine_kk, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))
    # Shift each voxel's signal so its minimum over volumes is at least 1e-5.
    floor_kk = np.clip(dwi_kk.min(axis=-1), -np.inf, 0)
    shifted_kk = dwi_kk + np.abs(floor_kk)[:, :, :, None] + 1e-5
    maskdata_kk, _ = median_otsu(shifted_kk, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    am_kk = maskdata_kk.shape[2] // 2
    # Brain mask on the slice: voxels whose signals sum to a non-zero value.
    Masks.append(np.sum(maskdata_kk[:, :, am_kk, :], axis=-1) != 0)
    maskdatas.append(maskdata_kk[:, :, am_kk])
    axial_middles.append(am_kk)

    wm_kk, _ = load_nifti('../../HCP_data/WM_Masks/c2Pat' + str(kk + 1) + '_FP.nii')
    # NOTE(review): the WM map is indexed with the slice index of the resliced/cropped
    # DWI volume; this assumes the c2*_FP masks are already in that space — confirm.
    WMs.append(np.fliplr(wm_kk[:, :, am_kk] > 0.8))

    gTabs7.append(gradient_table(bvals=bvals_kk[selected_indices],
                                 bvecs=bvecs_kk[selected_indices]))
    gTabs20.append(gradient_table(bvals=bvals_kk[selected_indices20],
                                  bvecs=bvecs_kk[selected_indices20]))


In [None]:
# SBI posterior-mean MD/FA maps for every HCP subject using all volumes (the SBI reference).
# Per-subject results use local names (mask_kk, md_kk, fa_kk) so the panel cells'
# `mask` / `MD_SBIFull` / `FA_SBIFull` globals are not overwritten.
MDFullArr = []
FAFullArr = []
for kk in tqdm(range(32), position=0, leave=True):
    dat = maskdatas[kk]
    mask_kk = np.sum(dat, axis=-1) != 0
    gtab = gTabsF[kk]
    indices = np.argwhere(mask_kk)

    def optimize_chunk(pixels):
        """Return [(i, j, posterior mean), ...] for a chunk of voxels of the current subject."""
        results = []
        for i, j in pixels:
            # Reseed per voxel, as the other SBI cells do, so results are reproducible
            # and independent of how voxels are split across chunks/workers.
            torch.manual_seed(10)
            posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs, gtab.bvals, dat[i, j]), show_progress_bars=False)
            results.append((i, j, posterior_samples_1.mean(axis=0)))
        return results

    chunked_indices = [indices[c:c + ChunkSize] for c in range(0, len(indices), ChunkSize)]
    results = Parallel(n_jobs=8)(delayed(optimize_chunk)(chunk) for chunk in chunked_indices)

    ArrShape = mask_kk.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    # Clip negative tensor eigenvalues (keeping the last parameter as-is), then derive MD/FA.
    NoiseEst2 = np.zeros_like(NoiseEst)
    md_kk = np.zeros(ArrShape)
    fa_kk = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i, j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))), NoiseEst[i, j, -1:]])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
            md_kk[i, j] = np.mean(Eigs)
            fa_kk[i, j] = FracAni(Eigs, np.mean(Eigs))
    fa_kk[np.isnan(fa_kk)] = 0
    MDFullArr.append(md_kk)
    FAFullArr.append(fa_kk)

In [None]:
# SBI posterior-mean MD/FA maps for every HCP subject using the 20-volume subset.
# Per-subject results use local names so the panel cells' `mask` / `MD_SBI*` globals
# are not overwritten.
MDMidArr = []
FAMidArr = []
for kk in tqdm(range(32), position=0, leave=True):
    dat = maskdatas[kk]
    mask_kk = np.sum(dat, axis=-1) != 0
    gtab = gTabs20[kk]
    indices = np.argwhere(mask_kk)

    def optimize_chunk(pixels):
        """Return [(i, j, posterior mean), ...] for a chunk of voxels (20-volume subset)."""
        results = []
        for i, j in pixels:
            # Reseed per voxel, as the other SBI cells do, so results are reproducible
            # and independent of how voxels are split across chunks/workers.
            torch.manual_seed(10)
            posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs, gtab.bvals, dat[i, j, selected_indices20]), show_progress_bars=False)
            results.append((i, j, posterior_samples_1.mean(axis=0)))
        return results

    chunked_indices = [indices[c:c + ChunkSize] for c in range(0, len(indices), ChunkSize)]
    results = Parallel(n_jobs=8)(delayed(optimize_chunk)(chunk) for chunk in chunked_indices)

    ArrShape = mask_kk.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    # Clip negative tensor eigenvalues (keeping the last parameter as-is), then derive MD/FA.
    NoiseEst2 = np.zeros_like(NoiseEst)
    md_kk = np.zeros(ArrShape)
    fa_kk = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i, j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))), NoiseEst[i, j, -1:]])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
            md_kk[i, j] = np.mean(Eigs)
            fa_kk[i, j] = FracAni(Eigs, np.mean(Eigs))
    fa_kk[np.isnan(fa_kk)] = 0
    MDMidArr.append(md_kk)
    FAMidArr.append(fa_kk)

In [None]:
# SBI posterior-mean MD/FA maps for every HCP subject using the 7-volume subset.
# The per-subject maps were previously stored in globals misleadingly named
# MD_SBIFull/FA_SBIFull, overwriting the full-protocol maps used by the panel cells;
# they now use local names.
MDMinArr = []
FAMinArr = []
for kk in tqdm(range(32), position=0, leave=True):
    dat = maskdatas[kk]
    mask_kk = np.sum(dat, axis=-1) != 0
    gtab = gTabs7[kk]
    indices = np.argwhere(mask_kk)

    def optimize_chunk(pixels):
        """Return [(i, j, posterior mean), ...] for a chunk of voxels (7-volume subset)."""
        results = []
        for i, j in pixels:
            # Reseed per voxel, as the other SBI cells do, so results are reproducible
            # and independent of how voxels are split across chunks/workers.
            torch.manual_seed(10)
            posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs, gtab.bvals, dat[i, j, selected_indices]), show_progress_bars=False)
            results.append((i, j, posterior_samples_1.mean(axis=0)))
        return results

    chunked_indices = [indices[c:c + ChunkSize] for c in range(0, len(indices), ChunkSize)]
    results = Parallel(n_jobs=8)(delayed(optimize_chunk)(chunk) for chunk in chunked_indices)

    ArrShape = mask_kk.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    # Clip negative tensor eigenvalues (keeping the last parameter as-is), then derive MD/FA.
    NoiseEst2 = np.zeros_like(NoiseEst)
    md_kk = np.zeros(ArrShape)
    fa_kk = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i, j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))), NoiseEst[i, j, -1:]])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
            md_kk[i, j] = np.mean(Eigs)
            fa_kk[i, j] = FracAni(Eigs, np.mean(Eigs))
    fa_kk[np.isnan(fa_kk)] = 0
    MDMinArr.append(md_kk)
    FAMinArr.append(fa_kk)

In [None]:
# NLLS tensor fits for every HCP subject on the full, 20- and 7-volume protocols.
# The fit objects use local names so the Pat1 `tenmodel`/`tenfit` read by the panel
# cells are not overwritten by the last subject's 7-volume fit.
MDFullNLArr = []
FAFullNLArr = []

MDMidNLArr = []
FAMidNLArr = []

MDMinNLArr = []
FAMinNLArr = []
for kk in tqdm(range(32), position=0, leave=True):
    dat = maskdatas[kk]
    # (gradient table, volume selection, MD output list, FA output list)
    protocols = [
        (gTabsF[kk], slice(None), MDFullNLArr, FAFullNLArr),
        (gTabs20[kk], selected_indices20, MDMidNLArr, FAMidNLArr),
        (gTabs7[kk], selected_indices, MDMinNLArr, FAMinNLArr),
    ]
    for gtab_p, vols, md_out, fa_out in protocols:
        fit_p = dti.TensorModel(gtab_p, return_S0_hat=True, fit_method='NLLS').fit(dat[..., vols])
        fa_out.append(dti.fractional_anisotropy(fit_p.evals))
        md_out.append(dti.mean_diffusivity(fit_p.evals))

In [None]:
# HCP MD metrics per subject:
#   Acc*  : mean |estimate - reference| inside the brain mask Masks[i]
#   SSIM* : masked_local_ssim(estimate, reference, Masks[i], win_size=7)
#   Prec* : standard deviation of the MD map inside the WM mask WMs[i]
# SBI subsets are compared with the full-protocol SBI map, NLLS subsets with the
# full-protocol NLLS map, and "Fulls" compares full SBI against full NLLS.
AccM7_MD = []
AccM20_MD = []
AccMFulls_MD = []

AccM7NL_MD = []
AccM20NL_MD = []

SSIM7_MD = []
SSIM20_MD = []
SSIMFulls_MD = []

SSIM7NL_MD = []
SSIM20NL_MD = []

Prec7_SBI_MD = []
Prec20_SBI_MD = []
PrecFull_SBI_MD = []

Prec7_NLLS_MD = []
Prec20_NLLS_MD = []
PrecFull_NLLS_MD = []
for i in tqdm(range(32), position=0, leave=True):
    Ma = Masks[i]
    wm = WMs[i]
    # (accuracy list, SSIM list, estimate, reference)
    comparisons = [
        (AccM7_MD, SSIM7_MD, MDMinArr[i], MDFullArr[i]),
        (AccM20_MD, SSIM20_MD, MDMidArr[i], MDFullArr[i]),
        (AccMFulls_MD, SSIMFulls_MD, MDFullArr[i], MDFullNLArr[i]),
        (AccM7NL_MD, SSIM7NL_MD, MDMinNLArr[i], MDFullNLArr[i]),
        (AccM20NL_MD, SSIM20NL_MD, MDMidNLArr[i], MDFullNLArr[i]),
    ]
    for acc_out, ssim_out, est, ref in comparisons:
        acc_out.append(np.mean(np.abs(est - ref)[Ma]))
        ssim_out.append(masked_local_ssim(est, ref, Ma, win_size=7))

    # (precision list, MD map)
    spreads = [
        (Prec7_SBI_MD, MDMinArr[i]),
        (Prec20_SBI_MD, MDMidArr[i]),
        (PrecFull_SBI_MD, MDFullArr[i]),
        (Prec7_NLLS_MD, MDMinNLArr[i]),
        (Prec20_NLLS_MD, MDMidNLArr[i]),
        (PrecFull_NLLS_MD, MDFullNLArr[i]),
    ]
    for prec_out, md_map in spreads:
        prec_out.append(np.std(md_map[wm]))



In [None]:
# HCP FA metrics per subject:
#   Acc*  : mean |estimate - reference| inside the brain mask Masks[i]
#   SSIM* : masked_local_ssim(estimate, reference, Masks[i], win_size=7)
#   Prec* : standard deviation of the FA map inside the WM mask WMs[i]
# SBI subsets are compared with the full-protocol SBI map, NLLS subsets with the
# full-protocol NLLS map, and "Fulls" compares full SBI against full NLLS.
AccM7_FA = []
AccM20_FA = []
AccMFulls_FA = []

AccM7NL_FA = []
AccM20NL_FA = []

SSIM7_FA = []
SSIM20_FA = []
SSIMFulls_FA = []

SSIM7NL_FA = []
SSIM20NL_FA = []

Prec7_SBI_FA = []
Prec20_SBI_FA = []
PrecFull_SBI_FA = []

Prec7_NLLS_FA = []
Prec20_NLLS_FA = []
PrecFull_NLLS_FA = []
for i in range(32):
    Ma = Masks[i]
    wm = WMs[i]
    # (accuracy list, SSIM list, estimate, reference)
    comparisons = [
        (AccM7_FA, SSIM7_FA, FAMinArr[i], FAFullArr[i]),
        (AccM20_FA, SSIM20_FA, FAMidArr[i], FAFullArr[i]),
        (AccMFulls_FA, SSIMFulls_FA, FAFullArr[i], FAFullNLArr[i]),
        (AccM7NL_FA, SSIM7NL_FA, FAMinNLArr[i], FAFullNLArr[i]),
        (AccM20NL_FA, SSIM20NL_FA, FAMidNLArr[i], FAFullNLArr[i]),
    ]
    for acc_out, ssim_out, est, ref in comparisons:
        acc_out.append(np.mean(np.abs(est - ref)[Ma]))
        ssim_out.append(masked_local_ssim(est, ref, Ma, win_size=7))

    # (precision list, FA map)
    spreads = [
        (Prec7_SBI_FA, FAMinArr[i]),
        (Prec20_SBI_FA, FAMidArr[i]),
        (PrecFull_SBI_FA, FAFullArr[i]),
        (Prec7_NLLS_FA, FAMinNLArr[i]),
        (Prec20_NLLS_FA, FAMidNLArr[i]),
        (PrecFull_NLLS_FA, FAFullNLArr[i]),
    ]
    for prec_out, fa_map in spreads:
        prec_out.append(np.std(fa_map[wm]))


In [None]:
# Load the 8 MS-study subjects:
#   - reslice 2 mm -> 2.5 mm;
#   - brain-mask and crop with median_otsu;
#   - keep volume 0 plus all b=2000 volumes;
#   - build full, 7- and 20-volume gradient tables.
# The 7/20-volume subsets are chosen once, on the first subject, and reused for all subjects.
Dats_MS   = []
gTabs7_MS = []
gTabs20_MS = []
gTabsF_MS = []
Masks_MS   = []
TrueIndxs = []  # NOTE(review): not filled in this cell.
axial_middles_MS = []
for i,Name in tqdm(enumerate(['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']),position = 0,leave = True):
    MatDir = '../../MS_data/'+Name

    F = pmt.read_mat(MatDir+'/data_loaded.mat')
    # Identity placeholder affine (was np.ones((4, 4)), which is not a valid affine);
    # reslice only uses it to compute the returned affine. Matches the later RT cell.
    affine = np.eye(4)

    data, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 80), autocrop=True)
    Masks_MS.append(mask)
    axial_middle = maskdata.shape[2] // 2
    axial_middles_MS.append(axial_middle)
    # Unit-normalise the directions; zero-norm rows give NaN, which are set to 0.
    bvecs = (F['direction'].T/np.linalg.norm(F['direction'],axis=1)).T
    bvecs[np.isnan(bvecs)] = 0
    bvals = F['bval']
    b2000_idx = np.where(bvals == 2000)[0]

    # Protocol: volume 0 (treated as b=0 — assumes the first volume is a b=0 image,
    # TODO confirm) followed by all b=2000 volumes.
    bvals2000 = np.array([0] + list(bvals[b2000_idx]))
    bvecs2000 = np.vstack([[0,0,0],bvecs[b2000_idx]])

    Dats_MS.append(maskdata[:,:,:,np.hstack([0,b2000_idx])])

    gTabsF_MS.append(gradient_table(bvals = bvals2000,bvecs = bvecs2000))

    if(i == 0):
        # Greedy farthest-point selection on the first subject's directions, starting
        # from the b=0 row. Candidates are visited in ascending index order so argmax
        # ties (all unit vectors are equidistant from [0, 0, 0]) break deterministically.
        distance_matrix = squareform(pdist(bvecs2000))

        selected_indices_MS = [0]
        for _ in range(6):  # 7 points in total, one already selected
            remaining_indices = sorted(set(range(len(bvecs2000))) - set(selected_indices_MS))
            min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices_MS], axis=1)
            next_index = remaining_indices[np.argmax(min_distances)]
            selected_indices_MS.append(next_index)

        selected_indices20_MS = [0]
        for _ in range(19):  # 20 points in total, one already selected
            remaining_indices = sorted(set(range(len(bvecs2000))) - set(selected_indices20_MS))
            min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices20_MS], axis=1)
            next_index = remaining_indices[np.argmax(min_distances)]
            selected_indices20_MS.append(next_index)

    gTabs7_MS.append(gradient_table(bvals = bvals2000[selected_indices_MS], bvecs = bvecs2000[selected_indices_MS]))
    gTabs20_MS.append(gradient_table(bvals = bvals2000[selected_indices20_MS], bvecs = bvecs2000[selected_indices20_MS]))

In [None]:
# Display the MS data root directory (MSDir); the next cell builds its WM-mask path from it.
MSDir

In [None]:
# Load each MS subject's WM map, reslice it 2 mm -> 2.5 mm, and reorient it to the DWI
# array layout. Then crop it to the bounding box of the median_otsu brain mask, so that
# WMs_MS[i] lines up with the cropped Dats_MS[i].
WMDir = MSDir+'WM_masks/'
WMs_MS = []
wm_files = sorted(os.listdir(WMDir))
for i,Name in tqdm(enumerate(['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']),position = 0,leave = True):
    MatDir = MSDir+Name
    F = pmt.read_mat(MatDir+'/data_loaded.mat')
    # Identity placeholder affine (was np.ones((4, 4))); reslice only uses it for the returned affine.
    affine = np.eye(4)

    data, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))
    _, maskCut = median_otsu(data, vol_idx=range(10, 80), autocrop=False)

    # Inclusive bounding box of the brain mask along each dimension.
    true_indices = np.argwhere(maskCut)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    # Exactly one WM file per subject keeps WMs_MS index-aligned with Dats_MS;
    # previously a missing or duplicated match silently shifted every later subject.
    matches = [x for x in wm_files if Name in x]
    if len(matches) != 1:
        raise ValueError(f"expected exactly one WM mask for {Name} in {WMDir}, found {matches}")
    print(Name)
    WM, affine, img = load_nifti(WMDir+matches[0], return_img=True)
    WM, affine = reslice(WM, affine, (2,2,2), (2.5,2.5,2.5))
    if(i<5):
        # NMSS subjects
        WM_t = np.fliplr(np.swapaxes(WM,0,1))
    else:
        # Ctrl subjects additionally need an up-down flip
        WM_t = np.fliplr(np.flipud(np.swapaxes(WM,0,1)))
    WM_t  = WM_t[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    WMs_MS.append(WM_t)

In [None]:
# SBI posterior-mean MD/FA on the middle axial slice of each MS subject, all volumes.
# Results previously went through globals named MD_SBIMin/FA_SBIMin/mask, overwriting
# the panel-b/d maps; they now use local names.
MDFullArr_MS = []
FAFullArr_MS = []
for kk in tqdm(range(8),position=0,leave=True):
    Dat = Dats_MS[kk]
    AM = axial_middles_MS[kk]
    # Brain mask on the slice: voxels whose signals sum to a non-zero value.
    mask_kk = np.sum(Dat[:, :, AM, :], axis=-1) != 0
    gtab = gTabsF_MS[kk]
    indices = np.argwhere(mask_kk)

    def optimize_pixel(i, j):
        """Return (i, j, posterior mean) for voxel (i, j) of slice AM; torch is reseeded per voxel."""
        torch.manual_seed(10)
        posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,Dat[i, j,AM, :]),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)

    results = Parallel(n_jobs=8)(delayed(optimize_pixel)(i, j) for i, j in indices)

    ArrShape = mask_kk.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for i, j, x in results:
        NoiseEst[i, j] = x

    # Clip negative tensor eigenvalues (keeping the last parameter as-is), then derive MD/FA.
    NoiseEst2 = np.zeros_like(NoiseEst)
    md_kk = np.zeros(ArrShape)
    fa_kk = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1:]])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            md_kk[i,j] = np.mean(Eigs)
            fa_kk[i,j] = FracAni(Eigs,np.mean(Eigs))
    fa_kk[np.isnan(fa_kk)] = 0
    MDFullArr_MS.append(md_kk)
    FAFullArr_MS.append(fa_kk)

In [None]:
# SBI posterior-mean MD/FA on the middle axial slice of each MS subject, 7-volume subset.
# Local names avoid overwriting the panel-b/d globals MD_SBIMin/FA_SBIMin/mask.
MDMinArr_MS = []
FAMinArr_MS = []
for kk in tqdm(range(8),position=0,leave=True):
    Dat = Dats_MS[kk]
    AM = axial_middles_MS[kk]
    # Brain mask on the slice: voxels whose signals sum to a non-zero value.
    mask_kk = np.sum(Dat[:, :, AM, :], axis=-1) != 0
    gtab = gTabs7_MS[kk]
    indices = np.argwhere(mask_kk)

    def optimize_pixel(i, j):
        """Return (i, j, posterior mean) for voxel (i, j) of slice AM (7 volumes); reseeded per voxel."""
        torch.manual_seed(10)
        posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,Dat[i, j,AM, selected_indices_MS]),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)

    results = Parallel(n_jobs=8)(delayed(optimize_pixel)(i, j) for i, j in indices)

    ArrShape = mask_kk.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for i, j, x in results:
        NoiseEst[i, j] = x

    # Clip negative tensor eigenvalues (keeping the last parameter as-is), then derive MD/FA.
    NoiseEst2 = np.zeros_like(NoiseEst)
    md_kk = np.zeros(ArrShape)
    fa_kk = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1:]])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            md_kk[i,j] = np.mean(Eigs)
            fa_kk[i,j] = FracAni(Eigs,np.mean(Eigs))
    fa_kk[np.isnan(fa_kk)] = 0
    MDMinArr_MS.append(md_kk)
    FAMinArr_MS.append(fa_kk)

In [None]:
# SBI posterior-mean MD/FA on the middle axial slice of each MS subject, 20-volume subset.
# Local names avoid overwriting the panel-b/d globals MD_SBIMin/FA_SBIMin/mask.
MDMidArr_MS = []
FAMidArr_MS = []
for kk in tqdm(range(8),position=0,leave=True):
    Dat = Dats_MS[kk]
    AM = axial_middles_MS[kk]
    # Brain mask on the slice: voxels whose signals sum to a non-zero value.
    mask_kk = np.sum(Dat[:, :, AM, :], axis=-1) != 0
    gtab = gTabs20_MS[kk]
    indices = np.argwhere(mask_kk)

    def optimize_pixel(i, j):
        """Return (i, j, posterior mean) for voxel (i, j) of slice AM (20 volumes); reseeded per voxel."""
        torch.manual_seed(10)
        posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,Dat[i, j,AM, selected_indices20_MS]),show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)

    results = Parallel(n_jobs=8)(delayed(optimize_pixel)(i, j) for i, j in indices)

    ArrShape = mask_kk.shape
    NoiseEst = np.zeros(list(ArrShape) + [7])
    for i, j, x in results:
        NoiseEst[i, j] = x

    # Clip negative tensor eigenvalues (keeping the last parameter as-is), then derive MD/FA.
    NoiseEst2 = np.zeros_like(NoiseEst)
    md_kk = np.zeros(ArrShape)
    fa_kk = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1:]])
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            md_kk[i,j] = np.mean(Eigs)
            fa_kk[i,j] = FracAni(Eigs,np.mean(Eigs))
    fa_kk[np.isnan(fa_kk)] = 0
    MDMidArr_MS.append(md_kk)
    FAMidArr_MS.append(fa_kk)

In [None]:
# NLLS tensor fits on the middle axial slice of each MS subject for the full, 7- and
# 20-volume protocols. Fit objects use local names so the Pat1 `tenmodel`/`tenfit`
# read by the panel cells are not overwritten.
MDFullNLArr_MS = []
FAFullNLArr_MS = []
MDMinNLArr_MS = []
FAMinNLArr_MS = []
MDMidNLArr_MS = []
FAMidNLArr_MS = []
for kk in range(8):
    slice_kk = Dats_MS[kk][:, :, axial_middles_MS[kk], :]
    # (gradient table, volume selection, MD output list, FA output list)
    protocols = [
        (gTabsF_MS[kk], slice(None), MDFullNLArr_MS, FAFullNLArr_MS),
        (gTabs7_MS[kk], selected_indices_MS, MDMinNLArr_MS, FAMinNLArr_MS),
        (gTabs20_MS[kk], selected_indices20_MS, MDMidNLArr_MS, FAMidNLArr_MS),
    ]
    for gtab_p, vols, md_out, fa_out in protocols:
        fit_p = dti.TensorModel(gtab_p, return_S0_hat=True, fit_method='NLLS').fit(slice_kk[..., vols])
        fa_out.append(dti.fractional_anisotropy(fit_p.evals))
        md_out.append(dti.mean_diffusivity(fit_p.evals))

In [None]:
# MS-cohort MD metrics per subject, on the middle axial slice:
#   Acc*  : mean |estimate - reference| inside the brain mask
#   SSIM* : masked_local_ssim(estimate, reference, mask, win_size=7)
#   Prec* : standard deviation of the MD map where the WM map exceeds 0.8
# SBI subsets are compared with the full-protocol SBI map, NLLS subsets with the
# full-protocol NLLS map, and "Fulls" compares full SBI against full NLLS.
AccM7_MD_MS = []
AccM20_MD_MS = []
AccMFulls_MD_MS = []

AccM7NL_MD_MS = []
AccM20NL_MD_MS = []

SSIM7_MD_MS = []
SSIM20_MD_MS = []
SSIMFulls_MD_MS = []

SSIM7NL_MD_MS = []
SSIM20NL_MD_MS = []

Prec7_SBI_MD_MS = []
Prec20_SBI_MD_MS = []
PrecFull_SBI_MD_MS = []

Prec7_NLLS_MD_MS = []
Prec20_NLLS_MD_MS = []
PrecFull_NLLS_MD_MS = []
for i in range(8):
    am = axial_middles_MS[i]
    Ma = Masks_MS[i][:, :, am]
    wm = WMs_MS[i][:, :, am] > 0.8
    # (accuracy list, SSIM list, estimate, reference)
    comparisons = [
        (AccM7_MD_MS, SSIM7_MD_MS, MDMinArr_MS[i], MDFullArr_MS[i]),
        (AccM20_MD_MS, SSIM20_MD_MS, MDMidArr_MS[i], MDFullArr_MS[i]),
        (AccMFulls_MD_MS, SSIMFulls_MD_MS, MDFullArr_MS[i], MDFullNLArr_MS[i]),
        (AccM7NL_MD_MS, SSIM7NL_MD_MS, MDMinNLArr_MS[i], MDFullNLArr_MS[i]),
        (AccM20NL_MD_MS, SSIM20NL_MD_MS, MDMidNLArr_MS[i], MDFullNLArr_MS[i]),
    ]
    for acc_out, ssim_out, est, ref in comparisons:
        acc_out.append(np.mean(np.abs(est - ref)[Ma]))
        ssim_out.append(masked_local_ssim(est, ref, Ma, win_size=7))

    # (precision list, MD map)
    spreads = [
        (Prec7_SBI_MD_MS, MDMinArr_MS[i]),
        (Prec20_SBI_MD_MS, MDMidArr_MS[i]),
        (PrecFull_SBI_MD_MS, MDFullArr_MS[i]),
        (Prec7_NLLS_MD_MS, MDMinNLArr_MS[i]),
        (Prec20_NLLS_MD_MS, MDMidNLArr_MS[i]),
        (PrecFull_NLLS_MD_MS, MDFullNLArr_MS[i]),
    ]
    for prec_out, md_map in spreads:
        prec_out.append(np.std(md_map[wm]))


In [None]:
# MS-cohort FA metrics per subject, on the middle axial slice:
#   Acc*  : mean |estimate - reference| inside the brain mask
#   SSIM* : masked_local_ssim(estimate, reference, mask, win_size=7)
#   Prec* : standard deviation of the FA map where the WM map exceeds 0.8
# SBI subsets are compared with the full-protocol SBI map, NLLS subsets with the
# full-protocol NLLS map, and "Fulls" compares full SBI against full NLLS.
AccM7_FA_MS = []
AccM20_FA_MS = []
AccMFulls_FA_MS = []

AccM7NL_FA_MS = []
AccM20NL_FA_MS = []

SSIM7_FA_MS = []
SSIM20_FA_MS = []
SSIMFulls_FA_MS = []

SSIM7NL_FA_MS = []
SSIM20NL_FA_MS = []

Prec7_SBI_FA_MS = []
Prec20_SBI_FA_MS = []
PrecFull_SBI_FA_MS = []

Prec7_NLLS_FA_MS = []
Prec20_NLLS_FA_MS = []
PrecFull_NLLS_FA_MS = []
for i in range(8):
    am = axial_middles_MS[i]
    Ma = Masks_MS[i][:, :, am]
    wm = WMs_MS[i][:, :, am] > 0.8
    # (accuracy list, SSIM list, estimate, reference)
    comparisons = [
        (AccM7_FA_MS, SSIM7_FA_MS, FAMinArr_MS[i], FAFullArr_MS[i]),
        (AccM20_FA_MS, SSIM20_FA_MS, FAMidArr_MS[i], FAFullArr_MS[i]),
        (AccMFulls_FA_MS, SSIMFulls_FA_MS, FAFullArr_MS[i], FAFullNLArr_MS[i]),
        (AccM7NL_FA_MS, SSIM7NL_FA_MS, FAMinNLArr_MS[i], FAFullNLArr_MS[i]),
        (AccM20NL_FA_MS, SSIM20NL_FA_MS, FAMidNLArr_MS[i], FAFullNLArr_MS[i]),
    ]
    for acc_out, ssim_out, est, ref in comparisons:
        acc_out.append(np.mean(np.abs(est - ref)[Ma]))
        ssim_out.append(masked_local_ssim(est, ref, Ma, win_size=7))

    # (precision list, FA map)
    spreads = [
        (Prec7_SBI_FA_MS, FAMinArr_MS[i]),
        (Prec20_SBI_FA_MS, FAMidArr_MS[i]),
        (PrecFull_SBI_FA_MS, FAFullArr_MS[i]),
        (Prec7_NLLS_FA_MS, FAMinNLArr_MS[i]),
        (Prec20_NLLS_FA_MS, FAMidNLArr_MS[i]),
        (PrecFull_NLLS_FA_MS, FAFullNLArr_MS[i]),
    ]
    for prec_out, fa_map in spreads:
        prec_out.append(np.std(fa_map[wm]))


In [None]:
MD_RT_SBI = []
FA_RT_SBI = []
MD_RT_NLLS  =[]
FA_RT_NLLS  =[]

MD_RT_SBI_Min = []
FA_RT_SBI_Min = []
MD_RT_NLLS_Min  =[]
FA_RT_NLLS_Min  =[]

for jj,N in enumerate(RTNames):
    Subfiles = []
    for k,x in enumerate(os.listdir(BaseDir)):
        if N in x:
            print(x)
            Subfiles.append(x)
    Subfiles = sorted(Subfiles)

    S = Subfiles[0]
    MatDir = BaseDir+S
    F = pmt.read_mat(MatDir+'/data_loaded.mat')
    affine = np.eye(4)

    data, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))
    _, maskCut = median_otsu(data, vol_idx=range(10, 80), autocrop=False)
    true_indices = np.argwhere(maskCut)

    # Determine the minimum and maximum indices along each dimension
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    AM = (max_coords[-1]+min_coords[-1])//2
    bvecs = (F['direction'].T/np.linalg.norm(F['direction'],axis=1)).T
    bvecs[np.isnan(bvecs)] = 0
    bvals = F['bval']
    data = data[...,np.logical_or(bvals==2000,bvals == 0)]
    bvecs2000 = bvecs[np.logical_or(bvals==2000,bvals == 0)]

    bvals2000 = np.array(list(bvals[np.logical_or(bvals==2000,bvals == 0)]))

    gtabs = [gradient_table(bvals = bvals2000,bvecs = bvecs2000)]
    Dats = []
    for i,S in enumerate(Subfiles[1:]):
        MatDir = BaseDir+S
        F = pmt.read_mat(MatDir+'/data_loaded.mat')
        affine = np.eye(4)

        data1, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))

        bvecs = (F['direction'].T/np.linalg.norm(F['direction'],axis=1)).T
        bvecs[np.isnan(bvecs)] = 0
        bvals = F['bval']
        data1 = data1[...,np.logical_or(bvals==2000,bvals == 0)]
        if(jj == 0):
            if(i < 2):
                data1 = data1[:,::-1]
        elif(jj == 1 or jj == 2 or jj == 3 or jj == 4):
            if(i<3):
                data1 = data1[:,::-1]
        elif(jj==5):
            if(i>0 and i < 3):
                data1 = data1[:,::-1]
        Dats.append(data1)
        bvecs2000 = bvecs[np.logical_or(bvals==2000,bvals == 0)]

        bvals2000 = np.array(list(bvals[np.logical_or(bvals==2000,bvals == 0)]))

        gtabs.append(gradient_table(bvals = bvals2000,bvecs = bvecs2000))
        
        selected_indices_MS = [0]
        bvecs_temp = np.copy(gtab1.bvecs)
        distance_matrix = squareform(pdist(bvecs_temp))
        # Iteratively select the point furthest from the current selection
        for _ in range(6):  # We need 7 points in total, and one is already selected
            remaining_indices = list(set(range(len(bvecs_temp))) - set(selected_indices_MS))

            # Calculate the minimum distance to the selected points for each remaining point
            min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices_MS], axis=1)

            # Select the point with the maximum minimum distance
            next_index = remaining_indices[np.argmax(min_distances)]
            selected_indices_MS.append(next_index)

        bvalsHCP7 = gtab1.bvals[selected_indices1]
        bvecsHCP7 = gtab1.bvecs[selected_indices1]


        gtabs7.append(gradient_table(bvals = bvalsHCP7, bvecs = bvecsHCP7))
        Indxs7.append(selected_indices_MS)
        
    NewDats = [data]
    for d,gt in zip(Dats,gtabs[1:]):
        affine_map = rigid_register(data[...,gtabs[0].bvals==0].mean(axis=-1),d[...,gt.bvals==0].mean(axis=-1),affine1,affine1)
        data2_warp = np.array([affine_map.transform(d[:,:,:,i], interpolation="linear") for i in range(len(gt.bvals))])
        data2_warp = np.rollaxis(data2_warp, 0, data2_warp.ndim)
        NewDats.append(data2_warp)

    NewDats_masked = [ND*maskCut[...,None] for ND in NewDats]
    MD_arr = []
    FA_arr = []
    for ND,gt in tqdm(zip(NewDats_masked,gtabs),position=0,leave=True):
        mask = np.sum(ND[:, :, 42, :], axis=-1) != 0
        gtab = gt
        # Get the indices where mask is True
        indices = np.argwhere(mask)
        floor = np.clip(ND.min(axis=-1),-np.inf,0)
        dat = ND + abs(floor)[:,:,:,None] + 1e-5
        # Define the function for optimization
        def optimize_pixel(i, j):
            torch.manual_seed(10)  # If required
            posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,dat[i, j,AM, :]),show_progress_bars=False)
            return i, j, posterior_samples_1.mean(axis=0)

        # Initialize NoiseEst with the appropriate shape
        ArrShape = mask.shape

        # Use joblib to parallelize the optimization tasks
        results = Parallel(n_jobs=8)(
            delayed(optimize_pixel)(i, j) for i, j in tqdm(indices,position=0,leave=True)
        )

        NoiseEst = np.zeros(list(ArrShape) + [7])

        # Assign the optimization results to NoiseEst
        for i, j, x in results:
            NoiseEst[i, j] = x

        NoiseEst2 =  np.zeros_like(NoiseEst)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):    
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1:]])
        MD_SBIMin = np.zeros(ArrShape)
        FA_SBIMin = np.zeros(ArrShape)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]): 
                Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
                MD_SBIMin[i,j] = np.mean(Eigs)
                FA_SBIMin[i,j] = FracAni(Eigs,np.mean(Eigs))
        FA_SBIMin[np.isnan(FA_SBIMin)] = 0

        MD_arr.append(MD_SBIMin)
        FA_arr.append(FA_SBIMin)

    FA_NLS_arr = []
    MD_NLS_arr = []
    for d,gt in zip(NewDats_masked,gtabs):
        tenmodel = dti.TensorModel(gt,return_S0_hat = True)
        tenfit = tenmodel.fit(d[:,:,AM])

        FA_NLS_arr.append(tenfit.fa)
        MD_NLS_arr.append(tenfit.md)
    
    MD_RT_SBI_Full.append(MD_arr)
    FA_RT_SBI_Full.append(FA_arr)

    ND_RT_NLLS_Full.append(MD_NLS_arr)
    FA_RT_NLLS_Full.append(FA_NLS_arr)
    
    MD_arr = []
    FA_arr = []
    for ND,gt,idx in tqdm(zip(NewDats_masked,gtabs7,Indxs7),position=0,leave=True):
        mask = np.sum(ND[:, :, 42, :], axis=-1) != 0
        gtab = gt
        # Get the indices where mask is True
        indices = np.argwhere(mask)
        floor = np.clip(ND.min(axis=-1),-np.inf,0)
        dat = ND + abs(floor)[:,:,:,None] + 1e-5
        dat = dat[...,idx]
        # Define the function for optimization
        def optimize_pixel(i, j):
            torch.manual_seed(10)  # If required
            posterior_samples_1 = Network.sample((InferSamples,), x=DTIFeatures(gtab.bvecs,gtab.bvals,dat[i, j,AM, :]),show_progress_bars=False)
            return i, j, posterior_samples_1.mean(axis=0)

        # Initialize NoiseEst with the appropriate shape
        ArrShape = mask.shape

        # Use joblib to parallelize the optimization tasks
        results = Parallel(n_jobs=8)(
            delayed(optimize_pixel)(i, j) for i, j in tqdm(indices,position=0,leave=True)
        )

        NoiseEst = np.zeros(list(ArrShape) + [7])

        # Assign the optimization results to NoiseEst
        for i, j, x in results:
            NoiseEst[i, j] = x

        NoiseEst2 =  np.zeros_like(NoiseEst)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):    
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1:]])
        MD_SBIMin = np.zeros(ArrShape)
        FA_SBIMin = np.zeros(ArrShape)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]): 
                Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
                MD_SBIMin[i,j] = np.mean(Eigs)
                FA_SBIMin[i,j] = FracAni(Eigs,np.mean(Eigs))
        FA_SBIMin[np.isnan(FA_SBIMin)] = 0

        MD_arr.append(MD_SBIMin)
        FA_arr.append(FA_SBIMin)

    FA_NLS_arr = []
    MD_NLS_arr = []
    for d,gt,idx in zip(NewDats_masked,gtabs7,Indxs7):
        tenmodel = dti.TensorModel(gt,return_S0_hat = True)
        tenfit = tenmodel.fit(d[:,:,AM,idx])

        FA_NLS_arr.append(tenfit.fa)
        MD_NLS_arr.append(tenfit.md)
    
    MD_RT_SBI_Min.append(MD_arr)
    FA_RT_SBI_Min.append(FA_arr)

    MD_RT_NLLS_Min.append(MD_NLS_arr)
    FA_RT_NLLS_Min.append(FA_NLS_arr)

In [None]:
# Flatten the nested RT results into one flat list per map type.
# Outer index kk: one entry per pass of the RT processing loop in the previous cell;
# inner index i: one map per registered dataset within that pass. The inner length is
# taken from MD_RT_SBI_Min[kk]; the other seven nested lists are read at the same (kk, i).
rt_flat_index = [
    (kk, i)
    for kk in range(len(MD_RT_SBI_Min))
    for i in range(len(MD_RT_SBI_Min[kk]))
]

MD_RT_SBI_Min_list = [MD_RT_SBI_Min[kk][i] for kk, i in rt_flat_index]
FA_RT_SBI_Min_list = [FA_RT_SBI_Min[kk][i] for kk, i in rt_flat_index]
MD_RT_SBI_Full_list = [MD_RT_SBI_Full[kk][i] for kk, i in rt_flat_index]
FA_RT_SBI_Full_list = [FA_RT_SBI_Full[kk][i] for kk, i in rt_flat_index]

MD_RT_NLLS_Min_list = [MD_RT_NLLS_Min[kk][i] for kk, i in rt_flat_index]
FA_RT_NLLS_Min_list = [FA_RT_NLLS_Min[kk][i] for kk, i in rt_flat_index]
MD_RT_NLLS_Full_list = [MD_RT_NLLS_Full[kk][i] for kk, i in rt_flat_index]
FA_RT_NLLS_Full_list = [FA_RT_NLLS_Full[kk][i] for kk, i in rt_flat_index]

In [None]:
# Per-subject agreement metrics for MD on the RT data. Each comparison yields a masked
# mean absolute error (Acc*) and a masked local SSIM with a 7-pixel window (SSIM*):
#   *7*     : SBI on the minimal subset  vs SBI on all volumes
#   *Fulls* : SBI on all volumes         vs dti.TensorModel ("NLLS") on all volumes
#   *7NL*   : NLLS on the minimal subset vs NLLS on all volumes
AccM7_MD_RT = []
AccMFulls_MD_RT = []
AccM7NL_MD_RT = []

SSIM7_MD_RT = []
SSIMFulls_MD_RT = []
SSIM7NL_MD_RT = []

# (estimate maps, reference maps, MAE output list, SSIM output list)
md_rt_comparisons = [
    (MD_RT_SBI_Min_list, MD_RT_SBI_Full_list, AccM7_MD_RT, SSIM7_MD_RT),
    (MD_RT_SBI_Full_list, MD_RT_NLLS_Full_list, AccMFulls_MD_RT, SSIMFulls_MD_RT),
    (MD_RT_NLLS_Min_list, MD_RT_NLLS_Full_list, AccM7NL_MD_RT, SSIM7NL_MD_RT),
]

for subj in tqdm(range(29), position=0, leave=True):
    # Shared evaluation mask: voxels whose SBI minimal-subset MD is not exactly 1e-5
    # (presumably the value left in background voxels — TODO confirm).
    Ma = MD_RT_SBI_Min_list[subj] != 1e-5

    # All three MAEs first, then all three SSIMs (same order as before).
    for estimate, reference, mae_out, _ in md_rt_comparisons:
        abs_err = np.abs(estimate[subj] - reference[subj])
        mae_out.append(np.mean(abs_err[Ma]))

    for estimate, reference, _, ssim_out in md_rt_comparisons:
        ssim_out.append(masked_local_ssim(estimate[subj], reference[subj], Ma, win_size=7))



In [None]:
# Per-subject agreement metrics for FA on the RT data. Each comparison yields a masked
# mean absolute error (Acc*) and a masked local SSIM with a 7-pixel window (SSIM*):
#   *7*     : SBI on the minimal subset  vs SBI on all volumes
#   *Fulls* : SBI on all volumes         vs dti.TensorModel ("NLLS") on all volumes
#   *7NL*   : NLLS on the minimal subset vs NLLS on all volumes
AccM7_FA_RT = []
AccMFulls_FA_RT = []
AccM7NL_FA_RT = []

SSIM7_FA_RT = []
SSIMFulls_FA_RT = []
SSIM7NL_FA_RT = []

# (estimate maps, reference maps, MAE output list, SSIM output list)
fa_rt_comparisons = [
    (FA_RT_SBI_Min_list, FA_RT_SBI_Full_list, AccM7_FA_RT, SSIM7_FA_RT),
    (FA_RT_SBI_Full_list, FA_RT_NLLS_Full_list, AccMFulls_FA_RT, SSIMFulls_FA_RT),
    (FA_RT_NLLS_Min_list, FA_RT_NLLS_Full_list, AccM7NL_FA_RT, SSIM7NL_FA_RT),
]

for subj in tqdm(range(29), position=0, leave=True):
    # The mask comes from the SBI minimal-subset *MD* map, so FA is scored on the same
    # voxels as MD (presumably intentional — TODO confirm).
    Ma = MD_RT_SBI_Min_list[subj] != 1e-5

    # All three MAEs first, then all three SSIMs (same order as before).
    for estimate, reference, mae_out, _ in fa_rt_comparisons:
        abs_err = np.abs(estimate[subj] - reference[subj])
        mae_out.append(np.mean(abs_err[Ma]))

    for estimate, reference, _, ssim_out in fa_rt_comparisons:
        ssim_out.append(masked_local_ssim(estimate[subj], reference[subj], Ma, win_size=7))



In [None]:
# MD accuracy figure: per-subject masked MAE of MD for each comparison, one box per
# comparison pooling HCP, MS and RT subjects, overlaid with jittered per-subject points:
#   light 'o'        -> HCP subjects
#   dark 'o' / '^'   -> first five / remaining MS subjects
#   's'              -> RT subjects (labelled 'GH data')
# This figure has a single Axes, ax2, and every call below targets it explicitly, so no
# Axes belonging to another figure is modified.
fig, ax2 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)

# --- Min subset, NLLS (x = 3.1) ---
y_data = np.array(AccM7NL_MD + AccM7NL_MD_MS + AccM7NL_MD_RT)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM7NL_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.8)

y_data = np.array(AccM7NL_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

y_data = np.array(AccM7NL_MD_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='s', color='chocolate', s=100, alpha=0.5)

# --- Full protocol, SBI vs NLLS (x = 1) ---
y_data = np.array(AccMFulls_MD + AccMFulls_MD_MS + AccMFulls_MD_RT)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccMFulls_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color='gray', s=100, alpha=0.5)

y_data = np.array(AccMFulls_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='black', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='black', s=100, alpha=0.5)

y_data = np.array(AccMFulls_MD_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='s', color='gray', s=100, alpha=0.5)

# --- Mid subset (AccM20_*), SBI (x = 1.7); no RT data for this subset ---
y_data = np.array(AccM20_MD + AccM20_MD_MS)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM20_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(AccM20_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Min subset, SBI (x = 2) ---
y_data = np.array(AccM7_MD + AccM7_MD_MS + AccM7_MD_RT)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM7_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(AccM7_MD_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='s', color='mediumaquamarine', s=100, alpha=0.5, label='GH data')

y_data = np.array(AccM7_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Mid subset (AccM20NL_*), NLLS (x = 2.8); no RT data for this subset ---
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

y_data = np.array(AccM20NL_MD + AccM20NL_MD_MS)

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM20NL_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(AccM20NL_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Axis formatting (set once, after all boxes are drawn) ---
ax2.tick_params(axis='y', labelsize=24)
# Fixed 1e-3 multiplier on the y axis, with ticks at 0, 1e-4, ..., 4e-4.
ax2.ticklabel_format(axis='y', style='sci', scilimits=(-3, -3))
ax2.yaxis.set_ticks(np.arange(0, 0.0005, 0.0001))

ax2.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax2.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

leg = ax2.legend(
    loc='upper left',
    frameon=False, ncols=1,
    fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3,
    bbox_to_anchor=(-0.05, 1.0), markerscale=1.5)

# Draw legend markers fully opaque even though the scatter points are semi-transparent;
# handles without set_alpha are left as they are.
for h in leg.legend_handles:
    try:
        h.set_alpha(1)
    except AttributeError:
        pass

In [None]:
# MD SSIM figure: per-subject masked local SSIM of MD for each comparison, one box per
# comparison pooling HCP, MS and RT subjects, overlaid with jittered per-subject points:
#   light 'o'        -> HCP subjects
#   dark 'o' / '^'   -> first five / remaining MS subjects
#   's'              -> RT subjects (same marker as in the accuracy figures)
# This figure has a single Axes, ax1, and every call below targets it explicitly, so no
# Axes belonging to another figure is modified.
fig, ax1 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)  # adjust space between Axes

# --- Full protocol, SBI vs NLLS (x = 1) ---
y_data = np.array(SSIMFulls_MD + SSIMFulls_MD_MS + SSIMFulls_MD_RT)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=False)

y_data = np.array(SSIMFulls_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color='gray', s=100, alpha=0.5)

y_data = np.array(SSIMFulls_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='black', s=100, alpha=0.7)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='black', s=100, alpha=0.7)

# RT subjects use squares so they are distinguishable from the gray HCP circles.
y_data = np.array(SSIMFulls_MD_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='s', color='gray', s=100, alpha=0.5)

# --- Mid subset (SSIM20_*), SBI (x = 1.7); no RT data for this subset ---
y_data = np.array(SSIM20_MD + SSIM20_MD_MS)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=False, scatter_alpha=0.5)

y_data = np.array(SSIM20_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5, label='HCP')

y_data = np.array(SSIM20_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5, label='MS-lesions')
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5, label='MS-ctrl')

# --- Min subset, SBI (x = 2) ---
y_data = np.array(SSIM7_MD + SSIM7_MD_MS + SSIM7_MD_RT)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

y_data = np.array(SSIM7_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM7_MD_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='s', color='mediumaquamarine', s=100, alpha=0.5)

y_data = np.array(SSIM7_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Mid subset (SSIM20NL_*), NLLS (x = 2.8); no RT data for this subset ---
y_data = np.array(SSIM20NL_MD + SSIM20NL_MD_MS)
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

y_data = np.array(SSIM20NL_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM20NL_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Min subset, NLLS (x = 3.1) ---
y_data = np.array(SSIM7NL_MD + SSIM7NL_MD_MS + SSIM7NL_MD_RT)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

y_data = np.array(SSIM7NL_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM7NL_MD_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='s', color='chocolate', s=100, alpha=0.5)

y_data = np.array(SSIM7NL_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Axis formatting ---
ax1.set_xlim(ax1.get_xlim())  # freeze the current x-limits (turns off x autoscaling)
ax1.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax1.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)

leg = ax1.legend(
    loc='upper left',
    frameon=False, ncols=1,
    fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3,
    bbox_to_anchor=(-0.0, 0.6), markerscale=1.5)

# Draw legend markers fully opaque even though the scatter points are semi-transparent;
# handles without set_alpha are left as they are.
for h in leg.legend_handles:
    try:
        h.set_alpha(1)
    except AttributeError:
        pass

In [None]:
# MD precision figure: boxes for SBI (x = 1, 1.2, 1.4) and NLLS (x = 2, 2.2, 2.4) at the
# Full / Mid / Min protocols. Every box pools HCP and MS subjects and is overlaid with
# jittered per-subject points (light 'o' = HCP; dark 'o' / '^' = first five / remaining
# MS subjects). A hatched band across each method's group marks a percentile range of
# that method's Full-protocol values.
y_data = np.array(PrecFull_SBI_MD + PrecFull_SBI_MD_MS)
g_pos = np.array([1])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']
fig, ax = plt.subplots()
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=False)

# NOTE(review): the band's upper edge is the 77th percentile; an interquartile band would
# use 75 — confirm which is intended.
x = np.arange(0.85, 1.6, 0.05)
y1 = np.ones_like(x)*np.percentile(y_data[~np.isnan(y_data)], 25)
y2 = np.ones_like(x)*np.percentile(y_data[~np.isnan(y_data)], 77)
ax.fill_between(x, y1, y2, color=SBIFit, zorder=10, alpha=0.2, hatch='//')

y_data = np.array(PrecFull_SBI_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(PrecFull_SBI_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- SBI, Mid (x = 1.2) ---
y_data = np.array(Prec20_SBI_MD + Prec20_SBI_MD_MS)
g_pos = np.array([1.2])
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=False)

y_data = np.array(Prec20_SBI_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(Prec20_SBI_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- SBI, Min (x = 1.4) ---
y_data = np.array(Prec7_SBI_MD + Prec7_SBI_MD_MS)
g_pos = np.array([1.4])
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=False)

y_data = np.array(Prec7_SBI_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(Prec7_SBI_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- NLLS, Full (x = 2) ---
colors = ['sandybrown']
colors2 = ['peachpuff']
y_data = np.array(PrecFull_NLLS_MD + PrecFull_NLLS_MD_MS)
g_pos = np.array([2])
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=False)

# NOTE(review): unlike the SBI band, this band is computed from the HCP subjects only
# (PrecFull_NLLS_MD), not from the pooled HCP+MS values in the box — confirm intended.
# The 77th-percentile upper edge matches the SBI band (see note above).
prec_full_nlls_hcp = np.array(PrecFull_NLLS_MD)
x = np.arange(1.85, 2.6, 0.05)
y1 = np.ones_like(x)*np.percentile(prec_full_nlls_hcp[~np.isnan(prec_full_nlls_hcp)], 25)
y2 = np.ones_like(x)*np.percentile(prec_full_nlls_hcp[~np.isnan(prec_full_nlls_hcp)], 77)
ax.fill_between(x, y1, y2, color=WLSFit, zorder=10, alpha=0.2, hatch='//')

y_data = np.array(PrecFull_NLLS_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(PrecFull_NLLS_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- NLLS, Mid (x = 2.2) ---
# Pool HCP and MS subjects like every other box in this figure, and use scatter=False
# like the others since the individual points are drawn explicitly below.
y_data = np.array(Prec20_NLLS_MD + Prec20_NLLS_MD_MS)
g_pos = np.array([2.2])
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=False)

y_data = np.array(Prec20_NLLS_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(Prec20_NLLS_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- NLLS, Min (x = 2.4) ---
y_data = np.array(Prec7_NLLS_MD + Prec7_NLLS_MD_MS)
g_pos = np.array([2.4])
BoxPlots(y_data, g_pos, colors, colors2, ax, widths=0.2, scatter=False)

y_data = np.array(Prec7_NLLS_MD)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(Prec7_NLLS_MD_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Axis formatting ---
ax.set_xticks([1, 1.2, 1.4, 2, 2.2, 2.4])
ax.set_xticklabels(['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'], fontsize=32, rotation=90)
ax.ticklabel_format(axis='y', style='sci', scilimits=(-0.5, 3))
ax.tick_params(axis='y', labelsize=24)

## f: FA metrics (accuracy, SSIM and precision)

In [None]:
# FA accuracy figure: per-subject masked MAE of FA for each comparison, one box per
# comparison pooling HCP, MS and RT subjects, overlaid with jittered per-subject points:
#   light 'o'        -> HCP subjects
#   dark 'o' / '^'   -> first five / remaining MS subjects
#   's'              -> RT subjects
# This figure has a single Axes, ax2, and every call below targets it explicitly, so no
# Axes belonging to another figure is modified.
fig, ax2 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)

# --- Min subset, NLLS (x = 3.1) ---
y_data = np.array(AccM7NL_FA + AccM7NL_FA_MS + AccM7NL_FA_RT)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM7NL_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.8)

y_data = np.array(AccM7NL_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

y_data = np.array(AccM7NL_FA_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='s', color='chocolate', s=100, alpha=0.5)

# --- Full protocol, SBI vs NLLS (x = 1) ---
y_data = np.array(AccMFulls_FA + AccMFulls_FA_MS + AccMFulls_FA_RT)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccMFulls_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color='gray', s=100, alpha=0.5)

y_data = np.array(AccMFulls_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='black', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='black', s=100, alpha=0.5)

y_data = np.array(AccMFulls_FA_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='s', color='gray', s=100, alpha=0.5)

# --- Mid subset (AccM20_*), SBI (x = 1.7); no RT data for this subset ---
y_data = np.array(AccM20_FA + AccM20_FA_MS)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM20_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(AccM20_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Min subset, SBI (x = 2) ---
y_data = np.array(AccM7_FA + AccM7_FA_MS + AccM7_FA_RT)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM7_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(AccM7_FA_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='s', color='mediumaquamarine', s=100, alpha=0.5)

y_data = np.array(AccM7_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Mid subset (AccM20NL_*), NLLS (x = 2.8); no RT data for this subset ---
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

y_data = np.array(AccM20NL_FA + AccM20NL_FA_MS)

BoxPlots(y_data, g_pos, colors, colors2, ax2, widths=0.2, scatter=False)

y_data = np.array(AccM20NL_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(AccM20NL_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax2.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax2.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Axis formatting (set once, after all boxes are drawn) ---
ax2.ticklabel_format(axis='y', style='sci', scilimits=(-0.5, 3))
ax2.tick_params(axis='y', labelsize=24)

ax2.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax2.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)


In [None]:
# FA SSIM figure: per-subject masked local SSIM of FA for each comparison, one box per
# comparison pooling HCP, MS and RT subjects, overlaid with jittered per-subject points:
#   light 'o'        -> HCP subjects
#   dark 'o' / '^'   -> first five / remaining MS subjects
#   's'              -> RT subjects (same marker as in the accuracy figures)
# This figure has a single Axes, ax1, and every call below targets it explicitly, so no
# Axes belonging to another figure is modified.
fig, ax1 = plt.subplots(1, 1)
fig.subplots_adjust(hspace=0.05)  # adjust space between Axes

# --- Full protocol, SBI vs NLLS (x = 1) ---
y_data = np.array(SSIMFulls_FA + SSIMFulls_FA_MS + SSIMFulls_FA_RT)
g_pos = np.array([1])
colors = ['black']
colors2 = ['gray']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=False)

y_data = np.array(SSIMFulls_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color='gray', s=100, alpha=0.5)

y_data = np.array(SSIMFulls_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='black', s=100, alpha=0.7)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='black', s=100, alpha=0.7)

# RT subjects use squares so they are distinguishable from the gray HCP circles.
y_data = np.array(SSIMFulls_FA_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='s', color='gray', s=100, alpha=0.5)

# --- Mid subset (SSIM20_*), SBI (x = 1.7); no RT data for this subset ---
y_data = np.array(SSIM20_FA + SSIM20_FA_MS)
g_pos = np.array([1.7])
colors = ['lightseagreen']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=False, scatter_alpha=0.5)

y_data = np.array(SSIM20_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM20_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Min subset, SBI (x = 2) ---
y_data = np.array(SSIM7_FA + SSIM7_FA_MS + SSIM7_FA_RT)
g_pos = np.array([2])
colors = ['mediumturquoise']
colors2 = ['paleturquoise']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

y_data = np.array(SSIM7_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM7_FA_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='s', color='mediumaquamarine', s=100, alpha=0.5)

y_data = np.array(SSIM7_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkcyan', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkcyan', s=100, alpha=0.5)

# --- Mid subset (SSIM20NL_*), NLLS (x = 2.8); no RT data for this subset ---
y_data = np.array(SSIM20NL_FA + SSIM20NL_FA_MS)
g_pos = np.array([2.8])
colors = ['sandybrown']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

y_data = np.array(SSIM20NL_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM20NL_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Min subset, NLLS (x = 3.1) ---
y_data = np.array(SSIM7NL_FA + SSIM7NL_FA_MS + SSIM7NL_FA_RT)
g_pos = np.array([3.1])
colors = ['burlywood']
colors2 = ['peachpuff']

BoxPlots(y_data, g_pos, colors, colors2, ax1, widths=0.2, scatter=True)

y_data = np.array(SSIM7NL_FA)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='o', color=colors2, s=100, alpha=0.5)

y_data = np.array(SSIM7NL_FA_RT)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data, y_data, marker='s', color='chocolate', s=100, alpha=0.5)

y_data = np.array(SSIM7NL_FA_MS)
x_data = g_pos*np.ones_like(y_data)
x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
ax1.scatter(x_data[:5], y_data[:5], marker='o', color='darkorange', s=100, alpha=0.5)
ax1.scatter(x_data[5:], y_data[5:], marker='^', color='darkorange', s=100, alpha=0.5)

# --- Axis formatting ---
ax1.set_xlim(ax1.get_xlim())  # freeze the current x-limits (turns off x autoscaling)
ax1.set_xticks([1, 1.7, 2, 2.8, 3.1])
ax1.set_xticklabels(['Full', 'Mid', 'Min', 'Mid', 'Min'], fontsize=32, rotation=90)



In [None]:
# Precision of the FA estimates for the SBI (turquoise, x~1) and NLLS (orange, x~2) fits,
# for the Full, Mid (the `20` results) and Min (the `7` results) acquisitions.
# Each box pools the single-shell and multi-shell (MS) values. Points are overlaid with
# horizontal jitter: light circles = single-shell, dark circles = first five MS values,
# dark triangles = remaining MS values. A hatched band marks the interquartile range of
# each method's Full box and extends across its Mid/Min boxes for comparison.


def _jittered_positions(values, pos):
    """x-coordinates at `pos` with Student-t (df=6, scale=0.02) jitter, one per value."""
    return pos + stats.t(df=6, scale=0.02).rvs(len(values))


def _overlay_points(ax, ss_values, ms_values, pos, light, dark):
    """Scatter single-shell values as light circles, MS values as dark circles (first 5) / triangles."""
    ss_values = np.array(ss_values)
    ax.scatter(_jittered_positions(ss_values, pos), ss_values,
               marker='o', color=light, s=100, alpha=0.5)
    ms_values = np.array(ms_values)
    ms_x = _jittered_positions(ms_values, pos)
    ax.scatter(ms_x[:5], ms_values[:5], marker='o', color=dark, s=100, alpha=0.5)
    ax.scatter(ms_x[5:], ms_values[5:], marker='^', color=dark, s=100, alpha=0.5)


def _iqr_band(ax, values, x_start, x_stop, color):
    """Shade the 25th-75th percentile range of `values` (NaNs ignored) from x_start to x_stop."""
    q25, q75 = np.nanpercentile(values, [25, 75])
    ax.fill_between([x_start, x_stop], q25, q75, color=color, zorder=10, alpha=0.2, hatch='//')


fig, ax = plt.subplots()

# (box colours, light point colours, dark point colour, band colour,
#  [(x position, single-shell values, MS values), ...]) per fitting method.
method_groups = [
    (['mediumturquoise'], ['paleturquoise'], 'darkcyan', SBIFit, [
        (1.0, PrecFull_SBI_FA, PrecFull_SBI_FA_MS),
        (1.2, Prec20_SBI_FA, Prec20_SBI_FA_MS),
        (1.4, Prec7_SBI_FA, Prec7_SBI_FA_MS),
    ]),
    (['sandybrown'], ['peachpuff'], 'darkorange', WLSFit, [
        (2.0, PrecFull_NLLS_FA, PrecFull_NLLS_FA_MS),
        (2.2, Prec20_NLLS_FA, Prec20_NLLS_FA_MS),
        (2.4, Prec7_NLLS_FA, Prec7_NLLS_FA_MS),
    ]),
]

for box_colors, light, dark, band_color, groups in method_groups:
    for rank, (pos, fa_ss, fa_ms) in enumerate(groups):
        # Every box pools single-shell and MS values (`+` concatenates the lists).
        pooled = np.array(fa_ss + fa_ms)
        BoxPlots(pooled, np.array([pos]), box_colors, light, ax, widths=0.2, scatter=False)
        if rank == 0:
            # Interquartile band of the pooled Full box, with identical extent for both
            # methods (explicit endpoints avoid float-step np.arange endpoint surprises).
            _iqr_band(ax, pooled, pos - 0.15, pos + 0.55, band_color)
        _overlay_points(ax, fa_ss, fa_ms, pos, light, dark)

ax.set_xticks([1, 1.2, 1.4, 2, 2.2, 2.4])
ax.set_xticklabels(['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'], fontsize=32, rotation=90)
# scilimits takes integers; (-1, 3) matches the former (-0.5, 3) because the
# formatter's orders of magnitude are integers.
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 3))
ax.tick_params(axis='y', labelsize=24)


# High Resolution

In [None]:
# Subject 1, `_1k` diffusion files. Paths are built from HCPDir, as in the Network
# section, rather than a hard-coded './HCP_data'.
fdwi = HCPDir + '/Pat' + str(1) + '/diff_1k.nii.gz'
bvalloc = HCPDir + '/Pat' + str(1) + '/bvals_1k.txt'
bvecloc = HCPDir + '/Pat' + str(1) + '/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvals=bvalsHCP, bvecs=bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
# Shift each voxel's signal so its minimum is >= 0 (plus a small epsilon); the shifted
# copy is used only for computing the brain mask below.
floor = np.clip(data.min(axis=-1), -np.inf, 0)
data2 = data + abs(floor)[:, :, :, None] + 1e-5
# NOTE(review): axial_middle is computed from the uncropped `data` but is used to index
# the autocropped `maskdata`, so it is not necessarily that array's middle slice — confirm.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data2, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)


def select_spread_directions(bvecs, first=1, n_more=5):
    """Greedy farthest-point selection of gradient directions.

    Starts from row `first` of `bvecs` and repeatedly adds the row whose minimum
    Euclidean distance to the rows already chosen is largest.

    Returns the chosen row indices (length 1 + n_more) in selection order.

    NOTE(review): Euclidean distance treats g and -g as maximally different, although
    they probe the same diffusion direction; an antipodally symmetric distance
    (e.g. min(|g-h|, |g+h|)) may give a better-conditioned subset — confirm.
    """
    distances = squareform(pdist(bvecs))
    chosen = [first]
    for _ in range(n_more):
        candidates = list(set(range(len(bvecs))) - set(chosen))
        # Distance from each candidate to its nearest already-chosen direction.
        nearest = np.min(distances[candidates][:, chosen], axis=1)
        chosen.append(candidates[np.argmax(nearest)])
    return chosen


# Volume 0 (presumably a b0) plus six greedily spread directions: 7 volumes in total.
selected_indices = [0] + select_spread_directions(bvecsHCP)

bvalsHCP7 = bvalsHCP[selected_indices]
bvecsHCP7 = bvecsHCP[selected_indices]
gtabHCP7 = gradient_table(bvals=bvalsHCP7, bvecs=bvecsHCP7)

In [None]:
# Voxels of the analysed axial slice whose signal summed over all volumes is non-zero
# (presumably those inside the median_otsu brain mask). This 2-D `mask` replaces the
# 3-D mask returned by median_otsu.
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)  # (row, col) of each voxel to estimate
ArrShape = mask.shape


def optimize_pixel(i, j):
    """Return (i, j, posterior mean) of the SBI network for voxel (i, j), using every volume.

    Nothing is optimised: the amortised posterior is sampled 1000 times and averaged.
    The torch seed is reset on every call, so a voxel's estimate does not depend on
    which task ran before it.
    """
    torch.manual_seed(10)
    voxel_signal = maskdata[i, j, axial_middle, :]
    summary = DTIFeatures(gtabHCP.bvecs, gtabHCP.bvals, voxel_signal)
    draws = Network.sample((1000,), x=summary, show_progress_bars=False)
    return i, j, draws.mean(axis=0)


# One joblib task per masked voxel, spread over 8 workers.
results = Parallel(n_jobs=8)(delayed(optimize_pixel)(i, j) for i, j in tqdm(indices))

In [None]:

# Scatter the per-voxel posterior means back into a (rows, cols, 7) map; voxels that
# were not sampled stay 0.
NoiseEst = np.zeros((*ArrShape, 7))
for i, j, x in results:
    NoiseEst[i, j] = x

# Clamp entry -2 to [0, 100] and entry -3 to [0, 1]. Unsampled voxels are 0 and stay 0,
# so clipping the whole map equals clipping only the sampled voxels.
# NOTE(review): the next cell feeds these vectors to vals_to_mat and keeps only entry -1
# aside, so entries -2/-3 look like tensor components — confirm these bounds are intended.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2], 0, 100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3], 0, 1)

In [None]:

# For every voxel, clip negative tensor eigenvalues (via clip_negative_eigenvalues),
# keep the last estimated parameter unchanged, then derive the MD and FA maps from the
# eigenvalues. The map size comes from NoiseEst instead of a hard-coded 94 x 104.
n_rows, n_cols = NoiseEst.shape[:2]
NoiseEst2 = np.zeros_like(NoiseEst)
MD_SBI = np.zeros([n_rows, n_cols])
FA_SBI = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        clipped = clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))
        NoiseEst2[i, j] = np.hstack([mat_to_vals(clipped), NoiseEst[i, j, -1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
        MD_SBI[i, j] = np.mean(Eigs)
        FA_SBI[i, j] = FracAni(Eigs, np.mean(Eigs))
# Voxels where FracAni is undefined (e.g. all-zero tensors outside the mask) are set to 0.
FA_SBI[np.isnan(FA_SBI)] = 0

In [None]:
# MD map of the SBI fit (all volumes); voxels outside `mask` are left blank (NaN).
temp = np.where(mask, MD_SBI, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()

In [None]:
# FA map of the SBI fit (all volumes); voxels outside `mask` are left blank (NaN).
temp = np.where(mask, FA_SBI, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()


In [None]:
# Same slice mask as for the all-volume fit: voxels with a non-zero summed signal.
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)  # (row, col) of each voxel to estimate
ArrShape = mask.shape


def optimize_pixel(i, j):
    """Return (i, j, posterior mean) of the SBI network for voxel (i, j), using only the
    volumes listed in the global `selected_indices` (looked up at call time).

    The amortised posterior is sampled 1000 times and averaged. The torch seed is reset
    on every call, so each voxel's estimate does not depend on task order.
    """
    torch.manual_seed(10)
    chosen = selected_indices
    summary = DTIFeatures(gtabHCP.bvecs[chosen], gtabHCP.bvals[chosen],
                          maskdata[i, j, axial_middle, chosen])
    draws = Network.sample((1000,), x=summary, show_progress_bars=False)
    return i, j, draws.mean(axis=0)


# One joblib task per masked voxel, spread over 8 workers.
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel)(i, j)
    for i, j in tqdm(indices, position=0, leave=True)
)


In [None]:

# Scatter the 7-volume posterior means back into a (rows, cols, 7) map; voxels that
# were not sampled stay 0.
NoiseEst = np.zeros((*ArrShape, 7))
for i, j, x in results:
    NoiseEst[i, j] = x

# Clamp entry -2 to [0, 100] and entry -3 to [0, 1]. Unsampled voxels are 0 and stay 0,
# so clipping the whole map equals clipping only the sampled voxels.
# NOTE(review): the next cell feeds these vectors to vals_to_mat and keeps only entry -1
# aside, so entries -2/-3 look like tensor components — confirm these bounds are intended.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2], 0, 100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3], 0, 1)

In [None]:

# 7-volume SBI fit: for every voxel, clip negative tensor eigenvalues (via
# clip_negative_eigenvalues), keep the last estimated parameter unchanged, then derive
# MD and FA from the eigenvalues. The map size comes from NoiseEst instead of a
# hard-coded 94 x 104.
n_rows, n_cols = NoiseEst.shape[:2]
NoiseEst2 = np.zeros_like(NoiseEst)
MD_SBI7 = np.zeros([n_rows, n_cols])
FA_SBI7 = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        clipped = clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))
        NoiseEst2[i, j] = np.hstack([mat_to_vals(clipped), NoiseEst[i, j, -1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
        MD_SBI7[i, j] = np.mean(Eigs)
        FA_SBI7[i, j] = FracAni(Eigs, np.mean(Eigs))
# Voxels where FracAni is undefined (e.g. all-zero tensors outside the mask) are set to 0.
FA_SBI7[np.isnan(FA_SBI7)] = 0

In [None]:
# MD map of the 7-volume SBI fit on a fixed [0, 0.005] colour range; voxels outside
# `mask` are left blank (NaN).
temp = np.where(mask, MD_SBI7, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=0.005)
plt.axis('off')
vmin, vmax = img.get_clim()


In [None]:
# FA map of the 7-volume SBI fit, lightly smoothed (Gaussian, sigma=0.5 px) and then masked.
# NOTE(review): the other FA maps in this section are shown unsmoothed, and smoothing
# before masking blends the zeroed background into edge voxels — confirm this is intended.
temp = gaussian_filter(FA_SBI7, sigma=0.5)  # returns a new array; FA_SBI7 is untouched
temp[~mask] = math.nan
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()


In [None]:
# Reference NLLS tensor fit of the analysed axial slice using every volume.
tenmodel = dti.TensorModel(gtabHCP, fit_method='NLLS', return_S0_hat=True)
tenfit = tenmodel.fit(maskdata[:, :, axial_middle])
MDFull = dti.mean_diffusivity(tenfit.evals)
FAFull = dti.fractional_anisotropy(tenfit.evals)

In [None]:
# NLLS tensor fit of the same slice restricted to the 7 selected volumes.
tenmodel = dti.TensorModel(gtabHCP7, fit_method='NLLS', return_S0_hat=True)
tenfit = tenmodel.fit(maskdata[:, :, axial_middle, selected_indices])
MD7 = dti.mean_diffusivity(tenfit.evals)
FA7 = dti.fractional_anisotropy(tenfit.evals)

In [None]:
# MD map of the all-volume NLLS fit with a scientific-notation colour bar; voxels
# outside `mask` are left blank (NaN).
temp = np.where(mask, MDFull, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()


In [None]:
# FA map of the all-volume NLLS fit with a scientific-notation colour bar; voxels
# outside `mask` are left blank (NaN).
temp = np.where(mask, FAFull, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()


In [None]:
# Signed MD difference, SBI minus NLLS (both with every volume), on a diverging colour map
# centred at 0. Stored as `md_diff` rather than `data`, so the raw DWI array loaded at the
# start of this section is not overwritten (re-running that preprocessing stays valid).
md_diff = MD_SBI.T - MDFull.T
md_diff[~mask.T] = np.nan
diff_min, diff_max = np.nanmin(md_diff), np.nanmax(md_diff)
norm = TwoSlopeNorm(vmin=diff_min, vcenter=0, vmax=diff_max)
plt.imshow(md_diff, cmap='seismic', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks([diff_min, 0, diff_max])
cbar.formatter.set_powerlimits((0, 0))


In [None]:
# Signed FA difference, SBI minus NLLS (both with every volume), on a diverging colour map
# centred at 0 with the upper end fixed at 1. Stored as `fa_diff` rather than `data`, so
# the raw DWI array loaded at the start of this section is not overwritten.
fa_diff = FA_SBI.T - FAFull.T
fa_diff[~mask.T] = np.nan
diff_min = np.nanmin(fa_diff)
norm = TwoSlopeNorm(vmin=diff_min, vcenter=0, vmax=1)
plt.imshow(fa_diff, cmap='seismic', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks([diff_min, 0, 1])
cbar.formatter.set_powerlimits((0, 0))


In [None]:
# MD map of the 7-volume NLLS fit on a fixed [0, 0.005] colour range with a
# scientific-notation colour bar; voxels outside `mask` are left blank (NaN).
temp = np.where(mask, MD7, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=0.005)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))


In [None]:
# Absolute MD error of the 7-volume fit relative to the all-volume fit: NLLS first (with
# colour bar), then SBI with the same `norm` so the two images are directly comparable.
# Uses descriptive names rather than `data`, which holds the raw DWI array loaded at the
# start of this section.
norm = TwoSlopeNorm(vmin=0, vcenter=0.00075, vmax=0.0015)

nlls_md_err = np.abs(MDFull.T - MD7.T)
nlls_md_err[~mask.T] = np.nan
plt.imshow(nlls_md_err, cmap='Reds', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks([0, 0.001])
cbar.formatter.set_powerlimits((0, 0))
plt.show()

sbi_md_err = np.abs(MD_SBI.T - MD_SBI7.T)
sbi_md_err[~mask.T] = np.nan
plt.imshow(sbi_md_err, cmap='Reds', norm=norm)
plt.axis('off')
plt.show()




In [None]:
# FA map of the 7-volume NLLS fit with a scientific-notation colour bar; voxels outside
# `mask` are left blank (NaN).
temp = np.where(mask, FA7, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()


In [None]:
# Absolute FA error of the 7-volume fit relative to the all-volume fit: NLLS first (with
# colour bar), then SBI with the same `norm` so the two images are directly comparable.
# Uses descriptive names rather than `data`, which holds the raw DWI array loaded at the
# start of this section.
norm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)

nlls_fa_err = np.abs(FAFull.T - FA7.T)
nlls_fa_err[~mask.T] = np.nan
plt.imshow(nlls_fa_err, cmap='Reds', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks([0, 1])
cbar.formatter.set_powerlimits((0, 0))
plt.show()

sbi_fa_err = np.abs(FA_SBI.T - FA_SBI7.T)
sbi_fa_err[~mask.T] = np.nan
plt.imshow(sbi_fa_err, cmap='Reds', norm=norm)
plt.axis('off')
plt.show()


