# Frontmatter

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
from matplotlib.colors import TwoSlopeNorm
from dipy.denoise.localpca import mppca

# Global matplotlib styling shared by every figure in this notebook.
# `font` is expanded by plt.rc into font.size / font.family / font.sans-serif.
font = dict(size=14, family='sans-serif', **{'sans-serif': ['Helvetica']})
plt.rc('font', **font)

plt.rcParams.update({
    # Large tick labels for publication-sized panels.
    'xtick.labelsize': 24,
    'ytick.labelsize': 24,
    # Use matplotlib's mathtext rather than an external LaTeX install, and pin
    # the family directly to Helvetica (this supersedes 'sans-serif' above).
    'text.usetex': False,
    'font.family': 'Helvetica',
    # Open-frame axes (no top/right spines) and frameless legends.
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
})


%load_ext autoreload
%autoreload 2
    
from dwMRI_BasicFuncs import *
from joblib import Parallel, delayed

from tqdm.auto import tqdm
from scipy.ndimage import gaussian_filter

from scipy.optimize import minimize
from scipy.special import i0

from scipy.optimize import least_squares
from scipy.special import j0, jv
from scipy.optimize import bisect

In [None]:
# --- Paths ------------------------------------------------------------------
network_path = './Networks/'   # cached posterior networks (pickle files)
image_path   = '../Figures/'   # root folder for the saved figure panels
# makedirs(exist_ok=True) also creates missing parent folders and cannot race
# with another process creating the folder, unlike an exists()/mkdir() pair.
os.makedirs(image_path, exist_ok=True)

# Values passed as `snr` to CustomSimulator (None presumably means
# noise-free — confirm against CustomSimulator).
NoiseLevels = [None,20,10,5,2]

TrainingSamples = 50000   # number of simulations used to train a posterior
InferSamples    = 500

# Bounds of the DTI + S0 prior (arguments of DTIPriorS0).
lower_abs,upper_abs = -0.07,0.07
lower_rest,upper_rest = -0.015,0.015
lower_S0 = 25
upper_S0 = 2000
Save = True   # write figure PDFs when True

# Plot colours.
TrueCol  = 'k'
NoisyCol = 'k'
WLSFit   = 'sandybrown'
SBIFit   = np.array([64,176,166])/255
NLLSFit  = np.array([225,190,106])/255

Errors_name = ['RK comparison','FA comparison','eig. comparison','Frobenius','Signal comparison','Correlation','Signal comparison','Correlation2']
custom_prior = DTIPriorS0(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0)
priorS0, *_ = process_prior(custom_prior)

In [None]:
def BoxPlots(y_data, positions, colors, colors2, ax,hatch = False,scatter=False,scatter_alpha=0.5, **kwargs):
    """Draw outline-only box plots on ``ax``, optionally overlaying the raw points.

    Parameters
    ----------
    y_data : array-like
        Either one 1-D sequence (a single box) or a sequence of 1-D sequences
        (one box per entry; entries may have different lengths). NaNs are
        removed before plotting.
    positions : sequence of float
        x-position of each box.
    colors : sequence of colors
        Edge and whisker colour of each box.
    colors2 : color or sequence of colors
        Scatter colour: a single colour for 1-D input, one per box otherwise.
    ax : matplotlib Axes
        Axes to draw on.
    hatch : bool
        If True, boxes get a '//' hatch.
    scatter : bool
        If True, overlay the NaN-free points with Student-t horizontal jitter.
    scatter_alpha : float
        Alpha of the scattered points.
    **kwargs
        Forwarded to ``ax.boxplot``.
    """
    GREY_DARK = "#747473"  # median line colour
    jitter = 0.02          # scale of the horizontal jitter for scatter points

    # One group or several? np.ndim() converts its argument to an array, and on
    # NumPy >= 1.24 that raises ValueError for ragged nested input. Ragged
    # input always means "several groups".
    try:
        is_1d = np.ndim(y_data) == 1
    except ValueError:
        is_1d = False

    # Drop NaNs (per group when there are several). np.asarray lets a plain
    # Python list be boolean-masked in the single-group case.
    if is_1d:
        values = np.asarray(y_data)
        cleaned_data = values[~np.isnan(values)]
    else:
        cleaned_data = [np.asarray(d)[~np.isnan(d)] for d in y_data]

    # Boxes are patches (patch_artist=True): no fill, and a placeholder edge
    # colour that is overwritten per box below.
    boxprops = dict(
        linewidth=2,
        facecolor='none',
        edgecolor='turquoise'
    )

    # Medians and whiskers are Line2D objects, so they take 'color'.
    medianprops = dict(
        linewidth=2,
        color=GREY_DARK,
        solid_capstyle="butt"
    )
    whiskerprops = dict(
        linewidth=2,
        color='turquoise'
    )

    bplot = ax.boxplot(
        cleaned_data,
        positions=positions,
        showfliers=False,
        showcaps=False,
        medianprops=medianprops,
        whiskerprops=whiskerprops,
        boxprops=boxprops,
        patch_artist=True,
        **kwargs
    )

    # Per-box edge colour, plus the optional hatch.
    for i, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[i])
        if hatch:
            box.set_hatch('//')

    # Each box owns two consecutive whiskers (lower, upper).
    for i in range(len(positions)):
        bplot['whiskers'][2*i].set_color(colors[i])
        bplot['whiskers'][2*i+1].set_color(colors[i])

    # Caps also come in pairs per box. With showcaps=False this list is empty,
    # but colouring stays correct if caps are re-enabled.
    for i, cap in enumerate(bplot.get('caps', [])):
        cap.set_color(colors[i//2])

    if scatter:
        if is_1d:
            x_data = np.array([positions[0]] * len(cleaned_data))
            x_jittered = x_data + stats.t(df=6, scale=jitter).rvs(len(x_data))
            ax.scatter(x_jittered, cleaned_data, s=100, color=colors2, alpha=scatter_alpha)
        else:
            x_data = [np.array([positions[i]] * len(d)) for i, d in enumerate(cleaned_data)]
            x_jittered = [x + stats.t(df=6, scale=jitter).rvs(len(x)) for x in x_data]
            for x, y, c in zip(x_jittered, cleaned_data, colors2):
                ax.scatter(x, y, s=100, color=c, alpha=scatter_alpha)

In [None]:
class ThinPatchHandler(HandlerPatch):
    """Legend handler that draws a patch as a thin, vertically centred bar.

    The bar copies the original handle's face/edge colour, hatch and line
    width, but is only 20% as tall as a normal legend entry.
    """

    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Return a single Rectangle standing in for ``orig_handle``."""
        bar_h = 0.2 * height
        # Offset the bar so that it sits in the vertical middle of the slot.
        y0 = ydescent + 0.5 * (height - bar_h)
        style = dict(
            facecolor=orig_handle.get_facecolor(),
            edgecolor=orig_handle.get_edgecolor(),
            hatch=orig_handle.get_hatch(),
            linewidth=orig_handle.get_linewidth(),
        )
        bar = Rectangle((xdescent, y0), width, bar_h, transform=trans, **style)
        return [bar]

In [None]:
# Override the Save flag from the configuration cell. While False, every
# `if Save: plt.savefig(...)` guard below is skipped and no PDFs are written.
Save = False

# Fig 1

In [None]:
# Output folder for the Fig. S1 panels. os.makedirs with exist_ok=True also
# creates missing parents and makes the cell safe to re-run.
FigLoc = image_path + 'Fig_S1/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Build synthetic single-shell acquisition schemes with 64, 19 and 6 diffusion
# directions. Directions are seeded from DIPY's 'small_64D' sample and spread
# over the hemisphere with disperse_charges (5000 iterations).
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_initial20 = HemiSphere(xyz=bvecs[1:20])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
hsph_updated,potentials = disperse_charges(hsph_initial,5000)
hsph_updated20,potentials = disperse_charges(hsph_initial20,5000)
hsph_updated7,potentials = disperse_charges(hsph_initial7,5000)


def _single_shell_gtab(hsph, bval=1000):
    """Gradient table with one b=0 volume, then one `bval` volume per vertex of `hsph`."""
    # The number of b=bval entries follows the hemisphere instead of being
    # hard-coded, so bvals and bvecs always have matching lengths.
    n_dirs = len(hsph.vertices)
    shell_bvals = np.array([0] + [bval] * n_dirs)
    shell_bvecs = np.vstack([[0, 0, 0], hsph.vertices])
    return gradient_table(shell_bvals, shell_bvecs)


gtabSimF = _single_shell_gtab(hsph_updated)
gtabSim20 = _single_shell_gtab(hsph_updated20)
gtabSim7 = _single_shell_gtab(hsph_updated7)

In [None]:
# Pool one axial slice from each of HCP subjects 1-5. For each subject: load
# the b=1000 data, mask it with median_otsu, keep voxels with a positive total
# and no negative samples, and fit a DTI model (NLLS, S0 estimated) to them.
FullDat = []
S0Full  = []
DTIFull = []
for i in tqdm(range(1,6)):
    fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    data, affine, img = load_nifti(fdwi, return_img=True)
    # NOTE(review): axial_middle is computed on the uncropped volume, but
    # autocrop=True crops maskdata to the mask's bounding box. The slice used
    # below is therefore shifted by the crop's z-offset — confirm the intended slice.
    axial_middle = data.shape[2] // 2
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    print('maskdata.shape (%d, %d, %d, %d)' % maskdata.shape)

    TestData = maskdata[:, :, axial_middle, :]
    # One row per in-plane voxel and one column per volume. The column count
    # comes from the data rather than being hard-coded.
    FlatTD = TestData.reshape(-1, TestData.shape[-1])
    FlatTD = FlatTD[FlatTD.sum(axis=-1)>0]               # drop voxels with non-positive total signal
    FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]    # drop voxels with any negative sample
    FullDat.append(FlatTD)
    # Fit the tensor model to the DWI data with return_S0_hat=True
    tenmodel = dti.TensorModel(gtabHCP, return_S0_hat=True,fit_method='NLLS')
    tenfit = tenmodel.fit(FlatTD)
    DTIHCP = tenfit.quadratic_form
    DTIFull.append(DTIHCP)
    # Get the estimated S0_hat values
    S0HCP = tenfit.S0_hat
    S0Full.append(S0HCP)
DTIFull = np.concatenate(DTIFull)
FullDat = np.concatenate(FullDat)
S0Full = np.hstack(S0Full)

In [None]:
# Fig. S1 data. Draw 5000 samples from the DTI + S0 prior and simulate signals
# on the 64-direction scheme at every NoiseLevels entry. Then summarise the
# simulated and HCP tensors by mean eigenvalue (RK*) and FracAni (FA*).
np.random.seed(0)
torch.manual_seed(0)
Samples  = []
DTISim = []
S0Sim    = []
# Prior bounds (same values as in the configuration cell).
lower_abs,upper_abs = -0.07,0.07
lower_rest,upper_rest = -0.015,0.015
lower_S0 = 25
upper_S0 = 2000

custom_prior = DTIPriorS0(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0)
prior, *_ = process_prior(custom_prior)

params = prior.sample([5000])
for i in tqdm(range(5000)):
    # The last prior entry is S0. Pass only the tensor entries to ComputeDTI,
    # as the posterior-training cells do with params[:-1].
    dt = ComputeDTI(params[i, :-1])
    dt = ForceLowFA(dt)
    DTISim.append(dt)
    S0Sim.append(params[i,-1])
    # NOTE(review): the signals use a fixed S0=200, while S0Sim records the
    # prior's S0 draw — confirm this is intended.
    Samples.append([CustomSimulator(dt,gtabSimF, S0=200,snr=scale) for scale in NoiseLevels])

Samples = np.array(Samples).squeeze()
Samples = np.moveaxis(Samples, 0, -1)   # move the sample axis to the end

DTISim = np.array(DTISim)

RKSim = [np.mean(np.linalg.eigh(B)[0]) for B in DTISim]
RKHCP = [np.mean(np.linalg.eigh(B)[0]) for B in DTIFull]

FASim = [FracAni(np.linalg.eigh(B)[0],m) for B,m in zip(DTISim,RKSim)]
FAHCP = [FracAni(np.linalg.eigh(B)[0],m) for B,m in zip(DTIFull,RKHCP)]

## a

In [None]:
# Fig. S1a: S0 values drawn from the prior (S0Sim) vs. the S0 estimated by the
# NLLS tensor fits of the pooled HCP voxels (S0Full).
plt.hist(S0Sim,density=True,stacked=True,alpha=0.75,label='Simulated',color=SBIFit,bins=100)
plt.hist(S0Full,density=True,stacked=True,alpha=0.75,label='HCP',color='gray',bins=100)
plt.legend(fontsize=32,loc=1,bbox_to_anchor=(0.95,1.),columnspacing=0.3,handlelength=0.8,handletextpad=0.1)
# Scientific notation on the density axis, with an enlarged offset label.
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.gca().yaxis.get_offset_text().set_fontsize(32)
plt.yticks(fontsize=32)
plt.xticks(fontsize=32)
plt.xlim(0,2000)
plt.xticks([0,1000])
if Save: plt.savefig(FigLoc+'S0Dist.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# Fig. S1b: mean eigenvalue of the simulated tensors (RKSim) vs. the HCP NLLS
# tensors (RKHCP). The x offset text is hidden because the ticks below carry
# explicit labels ('3e-3').
plt.hist(RKSim,density=True,stacked=True,label='Simulated samples',color=SBIFit,bins=100)
plt.hist(RKHCP,density=True,stacked=True,alpha=0.75,label='HCP subset',color='gray',bins=100)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.gca().ticklabel_format(axis='x',style='sci',scilimits=(-1,1))
plt.gca().xaxis.get_offset_text().set_visible(False)

plt.yticks(fontsize=32)
plt.xticks(fontsize=32)
plt.xticks([0,0.003],['0','3e-3'])
if Save: plt.savefig(FigLoc+'RKDist.pdf',format='pdf',bbox_inches='tight',transparent=True)

## c

In [None]:
# Fig. S1c: FracAni of the simulated tensors (FASim) vs. the HCP NLLS tensors (FAHCP).
plt.hist(FASim,density=True,label='Simulated samples',color=SBIFit,bins=100)
plt.hist(FAHCP,density=True,alpha=0.75,label='HCP subset',color='gray',bins=100)
plt.yticks(fontsize=32)
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'FADist.pdf',format='pdf',bbox_inches='tight',transparent=True)

## d

In [None]:
# Fig. S1d: histograms of the six unique tensor entries, simulated (DTISim,
# teal) vs. HCP (DTIFull, grey). Upper-triangular entries go on the 3x3 grid
# (indices 0,1,2,4,5,8); the lower-left panels (3,6,7) are switched off.
fig,axs = plt.subplots(3,3,figsize=(12,12))
ax = axs.ravel()
ax[0].hist(DTISim[:,0,0],density=True,color=SBIFit,bins=100)
ax[1].hist(DTISim[:,0,1],density=True,color=SBIFit,bins=100)
ax[2].hist(DTISim[:,0,2],density=True,color=SBIFit,bins=100)
ax[4].hist(DTISim[:,1,1],density=True,color=SBIFit,bins=100)
ax[5].hist(DTISim[:,1,2],density=True,color=SBIFit,bins=100)
ax[-1].hist(DTISim[:,2,2],density=True,color=SBIFit,bins=100)


ax[0].hist(DTIFull[:,0,0],density=True,alpha=0.75,color='gray',bins=100)
ax[1].hist(DTIFull[:,0,1],density=True,alpha=0.75,color='gray',bins=100)
ax[2].hist(DTIFull[:,0,2],density=True,alpha=0.75,color='gray',bins=100)
ax[4].hist(DTIFull[:,1,1],density=True,alpha=0.75,color='gray',bins=100)
ax[5].hist(DTIFull[:,1,2],density=True,alpha=0.75,color='gray',bins=100)
ax[-1].hist(DTIFull[:,2,2],density=True,alpha=0.75,color='gray',bins=100)
ax[3].axis('off')
ax[-2].axis('off')
ax[-3].axis('off')

# Shared tick styling: large labels, scientific notation on both axes.
for a in ax:
    a.tick_params(axis='x', labelsize=32)
    a.tick_params(axis='y', labelsize=32)
    a.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    a.ticklabel_format(axis='x',style='sci',scilimits=(-1,1))
    a.yaxis.get_offset_text().set_fontsize(32)
# Per-panel ranges and explicit tick labels.
ax[0].set_xticks([0,2.5e-3],['0','2.5e-3'])
ax[1].set_xlim([-1.2e-3,1e-3])
ax[1].set_xticks([-1e-3,0,1e-3],['-1e-3','0','1e-3'])

ax[2].set_xlim([-8e-4,8e-4])
ax[2].set_yticks([0,10000])
ax[2].set_xticks([-5e-4,0,5e-4],['-5e-4','0','5e-4'])

ax[4].set_xticks([0,2.5e-3],['0','2.5e-3'])

ax[5].set_xlim([-1.2e-3,1e-3])
ax[5].set_xticks([-1e-3,0,1e-3],['-1e-3','0','1e-3'])

ax[-1].set_xticks([0,2.5e-3],['0','2.5e-3'])

ax[0].set_xlabel('$D_{11}$',fontsize=32)
ax[1].set_xlabel('$D_{12}$',fontsize=32)
ax[2].set_xlabel('$D_{13}$',fontsize=32)
ax[4].set_xlabel('$D_{22}$',fontsize=32)
ax[5].set_xlabel('$D_{23}$',fontsize=32)
ax[-1].set_xlabel('$D_{33}$',fontsize=32)
plt.tight_layout()
if Save: plt.savefig(FigLoc+'DTDist.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Fig 2

In [None]:
# Output folder for the Fig. S2 panels (created with parents; safe to re-run).
FigLoc = image_path + 'Fig_S2/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:

# HCP subject 1, both shells. Reslice with zooms (1.5 -> 2.5 mm), crop to the
# bounding box of the b=1000 brain mask, and fit DKI to one axial slice. Voxels
# with extreme kurtosis estimates are removed; DKIMetrics and FA are then
# computed per remaining voxel.
i = 1
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined two-shell scheme: b=1000 volumes first, then b=3000 volumes.
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): axial_middle is taken from the uncropped volume, but maskdata
# is cropped below, so this index is shifted by min_coords[2] — confirm the
# intended slice.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=False, dilate=2)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# Get the indices of True values
true_indices = np.argwhere(mask)

# Determine the minimum and maximum indices along each dimension
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)

maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

# Number of b=1000 volumes, i.e. the leading block of columns in FlatTD.
n_vol_1k = maskdata.shape[-1]
TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
FlatTD = TestData.reshape(-1, TestData.shape[-1])
FlatTD = FlatTD[FlatTD[:, :n_vol_1k].sum(axis=-1)>0]   # voxels inside the (masked) b=1000 data
FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]      # no negative samples in either shell

dkimodel = dki.DiffusionKurtosisModel(gtabExt)
tenfit = dkimodel.fit(FlatTD)
DKIHCP = tenfit.kt
DTIHCP = tenfit.lower_triangular()
DKIFull = np.array(DKIHCP)
DTIFull = np.array(DTIHCP)

# Keep voxels whose kurtosis entries are all bounded (|kt| < 10), then those
# whose entries all exceed the lower bound -3/7.
DTIFilt1 = DTIFull[(abs(DKIFull)<10).all(axis=1)]
DKIFilt1 = DKIFull[(abs(DKIFull)<10).all(axis=1)]
DTIFilt = DTIFilt1[(DKIFilt1>-3/7).all(axis=1)]
DKIFilt = DKIFilt1[(DKIFilt1>-3/7).all(axis=1)]

TrueMets = []
FA       = []
for (dt,kt) in tqdm(zip(DTIFilt,DKIFilt)):
    TrueMets.append(DKIMetrics(dt,kt))
    # Eigenvalues are computed once and reused for both FA arguments.
    evals = np.linalg.eigh(vals_to_mat(dt))[0]
    FA.append(FracAni(evals, np.mean(evals)))
TrueMets = np.array(TrueMets)
TrueFA = np.array(FA)

In [None]:
# Fit the DT and KT parametric models (FitDT / FitKT) to the whole filtered
# subject-1 set and to subsets selected on DKIMetrics columns. The last column
# is the anisotropy metric these sections call FA; column 1 is "RK".
sel_lfa  = TrueMets[:, -1] < 0.3   # low FA
sel_hfa  = TrueMets[:, -1] > 0.7   # high FA
sel_ulfa = TrueMets[:, -1] < 0.1   # ultra-low FA
sel_hRK  = TrueMets[:, 1] > 0.9    # high RK

DT1_full, DT2_full = FitDT(DTIFilt, 1)
x4_full, R1_full, x2_full, R2_full = FitKT(DKIFilt, 1)

DT1_lfa, DT2_lfa = FitDT(DTIFilt[sel_lfa], 1)
x4_lfa, R1_lfa, x2_lfa, R2_lfa = FitKT(DKIFilt[sel_lfa], 1)

DT1_hfa, DT2_hfa = FitDT(DTIFilt[sel_hfa], 1)
x4_hfa, R1_hfa, x2_hfa, R2_hfa = FitKT(DKIFilt[sel_hfa], 1)

DT1_ulfa, DT2_ulfa = FitDT(DTIFilt[sel_ulfa], 1)
x4_ulfa, R1_ulfa, x2_ulfa, R2_ulfa = FitKT(DKIFilt[sel_ulfa], 1)

DT1_hRK, DT2_hRK = FitDT(DTIFilt[sel_hRK], 1)
x4_hRK, R1_hRK, x2_hRK, R2_hRK = FitKT(DKIFilt[sel_hRK], 1)

In [None]:
# Repeat the two-shell DKI pipeline for HCP subjects 1-5. Per subject:
# reslice both shells, crop to the b=1000 brain-mask bounding box, fit DKI to
# one axial slice, and pool the tensors (lower triangular) and kurtosis
# tensors across subjects.
FullDat = []
S0Full  = []
DKIFull = []
DTIFull = []
for i in tqdm(range(1,6)):
    fdwi    = f'./HCP_data/Pat{i}/diff_1k.nii.gz'
    bvalloc = f'./HCP_data/Pat{i}/bvals_1k.txt'
    bvecloc = f'./HCP_data/Pat{i}/bvecs_1k.txt'

    fdwi3    = f'./HCP_data/Pat{i}/diff_3k.nii.gz'
    bvalloc3 = f'./HCP_data/Pat{i}/bvals_3k.txt'
    bvecloc3 = f'./HCP_data/Pat{i}/bvecs_3k.txt'

    bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Combined scheme: b=1000 volumes first, then b=3000 volumes.
    gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)),
                             np.vstack((bvecsHCP, bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    axial_middle = data.shape[2] // 2
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

    # Bounding box of the brain mask, applied to both shells.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    bbox = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))
    maskdata  = maskdata[bbox]
    maskdata3 = data3[bbox]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],
                               maskdata3[:, :, axial_middle, :]], axis=-1)
    FlatTD = TestData.reshape(maskdata.shape[0] * maskdata.shape[1], 138)
    keep_vox = FlatTD[:, :69].sum(axis=-1) > 0     # positive total b=1000 signal
    FlatTD = FlatTD[keep_vox]
    FlatTD = FlatTD[~(FlatTD < 0).any(axis=-1)]     # no negative samples
    FullDat.append(FlatTD)

    dkimodel = dki.DiffusionKurtosisModel(gtabExt)
    tenfit = dkimodel.fit(FlatTD)
    DKIHCP = tenfit.kt
    DTIHCP = tenfit.lower_triangular()
    DTIFull.append(DTIHCP)
    DKIFull.append(DKIHCP)
    S0HCP = tenfit.S0_hat
    S0Full.append(S0HCP)
DKIFull = np.concatenate(DKIFull)
DTIFull = np.concatenate(DTIFull)

# Keep voxels whose kurtosis entries are all bounded (|kt| < 10) and all above
# -3/7; a voxel failing either condition is dropped from both arrays.
keep = (np.abs(DKIFull) < 10).all(axis=1) & (DKIFull > -3/7).all(axis=1)
DTIFilt_all = DTIFull[keep]
DKIFilt_all = DKIFull[keep]

## a

In [None]:
# Fig. S2a: tensor column 5 (lower-triangular order; presumably Dzz — confirm)
# for subject 1 alone (DTIFilt) vs. all five subjects (DTIFilt_all).
plt.hist(DTIFilt[:,5],bins=30,density=True,color=WLSFit,alpha=0.5,label='1 HCP Indv.')
plt.hist(DTIFilt_all[:,5],bins=30,density=True,color=SBIFit,alpha=0.5,label='All HCP')
plt.legend(fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.3,loc=1)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Comp1.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# Fig. S2b: parametric fits to the subject-1 entries.
# Panel 1: lognormal fit to tensor column 0, drawn over the histogram of that
# same column and the prior-simulated D[0,0] values.
data = DTIFilt[:,0]
shape,loc,scale = lognorm.fit(data)
# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
# Histogram the column that was fitted (0), so the overlaid PDF describes the
# plotted data.
plt.hist(DTIFilt[:,0],bins=30,density=True,color=WLSFit)
plt.hist(np.array(DTISim)[:,0,0],bins=30,density=True,alpha=0.5,color='gray')
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.0014,600,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Normal1.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()
# Panel 2: normal fit to tensor column 1 (an off-diagonal entry), compared with
# the prior-simulated D[1,0] values.
data = DTIFilt[:,1]
loc,scale = stats.norm.fit(data)

# Compute the fitted PDF
dti2_fitted = stats.norm(loc=loc, scale=scale)

x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.norm(loc=loc, scale=scale)
plt.hist(DTIFilt[:,1],bins=30,density=True,color=WLSFit,label='HCP data')
plt.hist(DTISim[:,1,0],bins=30,density=True,alpha=0.5,color='gray',label='DTI prior')
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit,label='stat. fit')
# "\\mu" / "\\sigma": a literal backslash for mathtext. A bare "\m" or "\s"
# is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
plt.text(0.00011,2000,"Normal, \n $\\mu$ = {:.2f},\n $\\sigma$ = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks([-0.0005,0,0.0005],[-5e-4,0,5e-4],fontsize=32)

plt.legend(fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1,loc=1,bbox_to_anchor=(0.52,1))
if Save: plt.savefig(FigLoc+'Normal2.pdf',format='pdf',bbox_inches='tight',transparent=True)

plt.show()

# Panel 3: lognormal fit to kurtosis column 0.
data = DKIFilt[:,0]
shape,loc,scale = lognorm.fit(data)
x4_fitted = stats.lognorm(shape, loc=loc, scale=scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DKIFilt[:,0],bins=30,density=True,color=WLSFit)
plt.plot(x,x4_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(1,0.8,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Normal3.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()


# Panel 4: normal fit to kurtosis column 3 (R1).
data = DKIFilt[:,3]
loc,scale = stats.norm.fit(data)
R1_fitted = stats.norm(loc,scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

plt.hist(DKIFilt[:,3],bins=30,density=True,color=WLSFit)
plt.plot(x,R1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.1,3,"Normal, \n $\\mu$ = {:.2f},\n $\\sigma$ = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Normal4.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()



## c

In [None]:

# Fig. S2c: the same four parametric fits, restricted to low-FA voxels
# (last DKIMetrics column < 0.3).
mask = TrueMets[:,-1]<0.3
data = DTIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
# Histogram the column that was fitted (0), so the overlaid PDF describes the
# plotted data.
plt.hist(DTIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.0014,600,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA1.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()
# Tensor column 1 (off-diagonal entry): normal fit.
data = DTIFilt[mask,1]
loc,scale = stats.norm.fit(data)

# Compute the fitted PDF
dti2_fitted = stats.norm(loc=loc, scale=scale)

x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.norm(loc=loc, scale=scale)
plt.hist(DTIFilt[mask,1],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
# "\\mu" / "\\sigma": literal backslashes for mathtext (avoids invalid escapes).
plt.text(0.00011,2000,"Normal, \n $\\mu$ = {:.2f},\n $\\sigma$ = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA2.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# Kurtosis column 0: lognormal fit.
data = DKIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
x4_fitted = stats.lognorm(shape, loc=loc, scale=scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DKIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,x4_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(1,0.8,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA3.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()


# Fitting R1
data = DKIFilt[mask,3]
loc,scale = stats.norm.fit(data)
R1_fitted = stats.norm(loc,scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

plt.hist(DKIFilt[mask,3],bins=30,density=True,color=WLSFit)
plt.plot(x,R1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.05,3,"Normal, \n $\\mu$ = {:.2f},\n $\\sigma$ = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA4.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()



## d

In [None]:
# Fig. S2d: the same four parametric fits, restricted to high-FA voxels
# (last DKIMetrics column > 0.7).
mask = TrueMets[:,-1]>0.7
data = DTIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DTIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.0008,1400,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=24)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA1.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()
# Tensor column 1 (off-diagonal entry): normal fit.
data = DTIFilt[mask,1]
loc,scale = stats.norm.fit(data)

# Compute the fitted PDF
dti2_fitted = stats.norm(loc=loc, scale=scale)

x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.norm(loc=loc, scale=scale)
plt.hist(DTIFilt[mask,1],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
# "\\mu" / "\\sigma": literal backslashes for mathtext (avoids invalid escapes).
plt.text(0.00013,1600,"Normal, \n $\\mu$ = {:.2f},\n $\\sigma$ = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA2.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# Kurtosis column 0: lognormal fit.
data = DKIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
x4_fitted = stats.lognorm(shape, loc=loc, scale=scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DKIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,x4_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(1.3,0.5,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=24)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA3.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()


# Fitting R1
data = DKIFilt[mask,3]
loc,scale = stats.norm.fit(data)
R1_fitted = stats.norm(loc,scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

plt.hist(DKIFilt[mask,3],bins=30,density=True,color=WLSFit)
plt.plot(x,R1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.3,1,"Normal, \n $\\mu$ = {:.2f},\n $\\sigma$ = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA4.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()



## e

In [None]:
# Fig. S2e (left): tensor column i vs. kurtosis column j for subject 1. Voxels
# are split by the last DKIMetrics column ("KFA" in the legend):
#   0.3-0.7 : base colour darkened by 0.5, default marker
#   > 0.7   : base colour, 'v' marker
#   < 0.3   : base colour darkened by 0.3, '^' marker
i,j=0,0
mask = (TrueMets[:,-1]<0.7)*(TrueMets[:,-1]>0.3)
plt.scatter(DTIFilt[mask,i],DKIFilt[mask,j],color=np.clip(np.array(col.to_rgb(WLSFit))-0.5,0,1),label='HCP data')
mask = TrueMets[:,-1]>0.7
plt.scatter(DTIFilt[mask,i],DKIFilt[mask,j],color=np.clip(np.array(col.to_rgb(WLSFit)),0,1),marker='v'
            ,label='HCP data (KFA$>$0.7)')
mask = TrueMets[:,-1]<0.3
plt.scatter(DTIFilt[mask,i],DKIFilt[mask,j],color=np.clip(np.array(col.to_rgb(WLSFit))-0.3,0,1),marker='^'
            ,label='HCP data (KFA$<$0.3)')
plt.yticks([])
plt.xticks([])
plt.legend(fontsize=20,loc=1,bbox_to_anchor=(1,1),handlelength=0.4,handletextpad=0.4,markerscale=2)
if Save: plt.savefig(FigLoc+'Scatter1Dat.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

In [None]:
# Fig. S2e (right): kurtosis column i vs. kurtosis column j, with the same
# three-way split by the last DKIMetrics column as the left panel.
# NOTE(review): the >0.7 group is drawn with the base colour lightened by 0.2
# here, but with the unmodified base colour in the left panel (whose legend
# this figure shares visually) — confirm which colour is intended.
i,j=9,0
mask = (TrueMets[:,-1]<0.7)*(TrueMets[:,-1]>0.3)
plt.scatter(DKIFilt[mask,i],DKIFilt[mask,j],color=np.clip(np.array(col.to_rgb(WLSFit))-0.5,0,1))
mask = TrueMets[:,-1]>0.7
plt.scatter(DKIFilt[mask,i],DKIFilt[mask,j],color=np.clip(np.array(col.to_rgb(WLSFit))+0.2,0,1),marker='v')
mask = TrueMets[:,-1]<0.3
plt.scatter(DKIFilt[mask,i],DKIFilt[mask,j],color=np.clip(np.array(col.to_rgb(WLSFit))-0.3,0,1),marker='^')
plt.yticks([])
plt.xticks([])
if Save: plt.savefig(FigLoc+'Scatter2Dat.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# Fig 3

In [None]:
# Output folder for the Fig. S3 panels (created with parents; safe to re-run).
FigLoc = image_path + 'Fig_S3/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Subject-1 b=1000 data for the Fig. S3 SBI vs. NLLS comparison.
fdwi = './HCP_data/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)

# Minimal 7-volume scheme: index 0 (presumably the leading b0 volume) plus 6
# diffusion directions picked by greedy farthest-point sampling, seeded with
# gradient index 1. Assumes bvecsHCP is (N, 3), as the row indexing below
# already requires.
#
# Gradient directions are antipodally symmetric: g and -g probe the same
# diffusion direction. The distance therefore uses min(|g_a - g_b|, |g_a + g_b|).
# With plain Euclidean distance, the point "furthest" from g lies near -g, so
# the greedy search would collect near-antipodal pairs. Those are effectively
# repeated directions that leave the tensor poorly determined.
diff_dist = np.linalg.norm(bvecsHCP[:, None, :] - bvecsHCP[None, :, :], axis=-1)
sum_dist = np.linalg.norm(bvecsHCP[:, None, :] + bvecsHCP[None, :, :], axis=-1)
distance_matrix = np.minimum(diff_dist, sum_dist)

# Only diffusion-weighted volumes are candidates, because b0 volumes carry no
# direction. 50 s/mm^2 matches DIPY's default b0_threshold.
candidate_indices = {int(k) for k in np.flatnonzero(bvalsHCP > 50)}
selected_indices = [1]
for _ in range(5):  # 1 seed + 5 picks = 6 diffusion directions
    remaining_indices = sorted(candidate_indices - set(selected_indices))
    # Distance from each remaining direction to its nearest selected direction.
    min_distances = distance_matrix[np.ix_(remaining_indices, selected_indices)].min(axis=1)
    # Take the direction whose nearest selected neighbour is furthest away.
    next_index = remaining_indices[int(np.argmax(min_distances))]
    selected_indices.append(next_index)

selected_indices = [0]+selected_indices

bvalsHCP7 = bvalsHCP[selected_indices]
bvecsHCP7 = bvecsHCP[selected_indices]
gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)

In [None]:
# Prior over the tensor entries and S0 (same bounds as priorS0), plus one
# extra parameter bounded by (0, 30). Given the class name, that parameter is
# presumably a noise level — confirm against DTIPriorS0Noise.
custom_prior = DTIPriorS0Noise(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0,0,30)
priorS0Noise, *_ = process_prior(custom_prior) 

## a

In [None]:
# Posterior for the full HCP b=1000 scheme. Reuse the cached network if it
# exists; otherwise simulate TrainingSamples (tensor, S0) pairs, train an NPE
# posterior (SNPE) and cache it under network_path.
# NOTE: pickle.load can execute arbitrary code — only load cache files that
# this notebook produced.
if os.path.exists(f"{network_path}/DTIHCPFull.pickle"):
    with open(f"{network_path}/DTIHCPFull.pickle", "rb") as handle:
        posterior2 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    bvals = gtabHCP.bvals
    bvecs = gtabHCP.bvecs
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        dt = ComputeDTI(params[:-1])   # last prior entry is S0
        dt = ForceLowFA(dt)
        # 4th argument (presumably the SNR) is uniform in [20, 50).
        Obs.append(CustomSimulator(dt,gtabHCP,params[-1],np.random.rand()*30 + 20))
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior2 = inference.build_posterior(density_estimator)
    # Cache under network_path, the same location the load branch reads from;
    # create the folder first so open(..., "wb") cannot fail on a missing directory.
    os.makedirs(network_path, exist_ok=True)
    if not os.path.exists(f"{network_path}/DTIHCPFull.pickle"):
        with open(f"{network_path}/DTIHCPFull.pickle", "wb") as handle:
            pickle.dump(posterior2, handle)

In [None]:
# Pixels of the middle axial slice whose signal summed over all volumes is
# non-zero (median_otsu zeroes the masked-out background in maskdata).
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# (row, col) index pairs of the pixels to process
indices = np.argwhere(mask)

# Posterior mean of the parameters for one pixel, from 1000 posterior samples.
# The seed is reset on every call, so each pixel's draw is reproducible
# regardless of which worker runs it or in which order.
def optimize_pixel(i, j):
    torch.manual_seed(10)  # If required
    posterior_samples_1 = posterior2.sample((1000,), x=maskdata[i,j,axial_middle, :],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

# Slice shape; NoiseEst is allocated from it in the next cell.
ArrShape = mask.shape

# Evaluate all pixels in parallel on every available core (n_jobs=-1).
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)


In [None]:

# Scatter the per-pixel posterior means into a (rows, cols, 7) map; pixels
# outside the mask stay zero.
NoiseEst = np.zeros(list(ArrShape) + [7])

# Assign the optimization results to NoiseEst
for i, j, x in results:
    NoiseEst[i, j] = x

# NOTE(review): posterior2 was trained on 7 parameters (mat_to_vals(dt) +
# S0 — assuming mat_to_vals returns 6 values), so indices -2 and -3 are
# tensor entries. Clipping them to [0, 100] and [0, 1] looks inherited from a
# model with extra (noise) parameters; for an off-diagonal entry it would
# zero out negative values. Confirm that this clipping is intended.
for i, j, x in results:
    NoiseEst[i, j,-2] = np.clip(NoiseEst[i, j,-2],0,100)
    NoiseEst[i, j,-3] = np.clip(NoiseEst[i, j,-3],0,1)

In [None]:

# Per pixel: clip negative eigenvalues of the posterior-mean tensor
# (clip_negative_eigenvalues), keep S0 as the last entry, and derive mean
# diffusivity (MD_SBI) and FracAni (FA_SBI) from the clipped tensor's
# eigenvalues. The loop extent comes from NoiseEst, so the maps always match
# the slice (and `mask`) shape instead of a fixed 94 x 104 grid.
n_rows, n_cols = NoiseEst.shape[:2]
NoiseEst2 = np.zeros_like(NoiseEst)
MD_SBI = np.zeros([n_rows, n_cols])
FA_SBI = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
        MD_SBI[i,j] = np.mean(Eigs)
        FA_SBI[i,j] = FracAni(Eigs,np.mean(Eigs))
# Background pixels have all-zero tensors, for which FA is undefined (NaN).
FA_SBI[np.isnan(FA_SBI)] = 0

In [None]:
# Fig. S3: SBI mean-diffusivity map. Pixels outside the mask are NaN so they
# render transparent; the image is transposed for display.
temp = np.copy(MD_SBI)
temp[~mask] = math.nan
img = plt.imshow(temp.T,cmap='hot')
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()

if Save: plt.savefig(FigLoc+'HCP_SBI_MD_US.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Reference NLLS tensor fit of the same axial slice using all HCP volumes;
# FA and MD are derived from the fitted eigenvalues.
tenmodel = dti.TensorModel(gtabHCP, return_S0_hat=True, fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:, :, axial_middle])
FAFull, MDFull = (
    dti.fractional_anisotropy(tenfit.evals),
    dti.mean_diffusivity(tenfit.evals),
)

In [None]:
# NLLS mean-diffusivity map with a colour bar in scientific notation.
temp = np.where(mask, MDFull, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()

if Save:
    plt.savefig(FigLoc + 'HCP_NLLS_MD_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Signed MD difference (SBI - NLLS). TwoSlopeNorm pins 0 to the centre of the
# diverging 'seismic' map so over- and under-estimation get opposite hues.
# NOTE(review): TwoSlopeNorm raises if the data do not straddle 0.
data = (MD_SBI - MDFull).T
data[~mask.T] = np.nan
lo, hi = np.nanmin(data), np.nanmax(data)
norm = TwoSlopeNorm(vmin=lo, vcenter=0, vmax=hi)
plt.imshow(data, cmap='seismic', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
ticks = [lo, 0, hi]
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))

if Save:
    plt.savefig(FigLoc + 'HCP_MD_Diff.pdf', format='pdf', bbox_inches='tight', transparent=True)

## b

In [None]:
# Amortised SBI posterior for the reduced HCP scheme gtabHCP7.
# A cached network is loaded from network_path when present; otherwise
# TrainingSamples synthetic (parameter, signal) pairs are simulated and an
# SNPE density estimator is trained on them.
if os.path.exists(f"{network_path}/DTIHCPMin.pickle"):
    # NOTE(review): pickle.load can execute arbitrary code - only load network
    # files from a trusted source.
    with open(f"{network_path}/DTIHCPMin.pickle", "rb") as handle:
        posterior7_2 = pickle.load(handle)
else:
    # Seed both RNGs so the simulated training set is reproducible; the order
    # of the random calls below matters.
    np.random.seed(1)
    torch.manual_seed(1)
    # NOTE(review): bvals/bvecs are not used in this cell and overwrite any
    # global variables of the same names.
    bvals = gtabHCP.bvals
    bvecs = gtabHCP.bvecs
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        # Last prior entry is used as S0; the rest parametrise the tensor.
        params = priorS0.sample()
        dt = ComputeDTI(params[:-1])
        dt = ForceLowFA(dt)
        # Signal on gtabHCP7 with S0 = params[-1] and an SNR drawn uniformly
        # from [20, 50).
        Obs.append(CustomSimulator(dt,gtabHCP7,params[-1],np.random.rand()*30 + 20))
        # Training target: the tensor's values (mat_to_vals) followed by S0.
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior7_2 = inference.build_posterior(density_estimator)
    # NOTE(review): the cache is read from network_path above but written to
    # save_path here; unless both name the same directory, the next run
    # retrains instead of loading this network (the Fig 4 cells use
    # network_path for both).
    if not os.path.exists(f"{save_path}/DTIHCPMin.pickle"):
        with open(f"{save_path}/DTIHCPMin.pickle", "wb") as handle:
            pickle.dump(posterior7_2, handle)

In [None]:
# Foreground mask of the axial slice (non-zero summed signal) and the
# coordinates of its pixels.
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)

def optimize_pixel(i, j):
    """Return ``(i, j, posterior_mean)`` for one pixel of the reduced scheme.

    Conditions ``posterior7_2`` on the pixel's signal restricted to the
    volumes in ``selected_indices`` (presumably the reduced acquisition that
    gtabHCP7 describes - confirm), draws 1000 samples after resetting the
    torch seed, and averages them along the sample axis.
    """
    torch.manual_seed(10)
    signal = maskdata[i, j, axial_middle, selected_indices]
    draws = posterior7_2.sample((1000,), x=signal, show_progress_bars=False)
    return i, j, draws.mean(axis=0)

# Slice shape for sizing the output maps, then parallel per-pixel sampling.
ArrShape = mask.shape

results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(row, col) for row, col in tqdm(indices)
)


In [None]:

# Scatter the parallel results into a (rows, cols, 7) map; background stays 0.
NoiseEst = np.zeros((*ArrShape, 7))

for i, j, x in results:
    NoiseEst[i, j] = x

# Clip channels -2 and -3; background zeros already lie inside both ranges,
# so clipping the whole map matches clipping only the fitted pixels.
# NOTE(review): these channels look like tensor components - confirm bounds.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2], 0, 100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3], 0, 1)

# MD/FA maps for the reduced-scheme SBI fit. Tensors first go through
# clip_negative_eigenvalues (presumably zeroing negative eigenvalues); MD is
# the mean eigenvalue and FA comes from FracAni. Loop bounds follow the slice
# shape instead of a hard-coded 94 x 104.
n_rows, n_cols = NoiseEst.shape[:2]

NoiseEst2 = np.zeros_like(NoiseEst)
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEst2[i, j] = np.hstack([
            mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))),
            NoiseEst[i, j, -1],
        ])

MD_SBI7 = np.zeros([n_rows, n_cols])
FA_SBI7 = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
        MD_SBI7[i, j] = np.mean(Eigs)
        FA_SBI7[i, j] = FracAni(Eigs, np.mean(Eigs))
# Pixels where FracAni returns NaN are shown as FA = 0.
FA_SBI7[np.isnan(FA_SBI7)] = 0

In [None]:
# Reduced-scheme SBI MD map on a fixed [0, 0.005] colour range.
temp = np.where(mask, MD_SBI7, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=0.005)
plt.axis('off')
vmin, vmax = img.get_clim()

if Save:
    plt.savefig(FigLoc + 'HCP_SBI_MD_7_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# NLLS fit restricted to the reduced set of volumes (selected_indices).
tenmodel = dti.TensorModel(gtabHCP7, return_S0_hat=True, fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:, :, axial_middle, selected_indices])
FA7, MD7 = (
    dti.fractional_anisotropy(tenfit.evals),
    dti.mean_diffusivity(tenfit.evals),
)

In [None]:
# Reduced-scheme NLLS MD map on the same fixed colour range as the SBI panel.
temp = np.where(mask, MD7, math.nan)
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=0.005)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))

if Save:
    plt.savefig(FigLoc + 'HCP_NLLS_MD_7_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Absolute MD change caused by dropping to the reduced scheme, first for NLLS
# and then for SBI. Both panels share one TwoSlopeNorm so the colours are
# directly comparable.
norm = TwoSlopeNorm(vmin=0, vcenter=0.00075, vmax=0.0015)

data = np.abs(MDFull - MD7).T
data[~mask.T] = np.nan
plt.imshow(data, cmap='Reds', norm=norm)
ticks = [0, 0.001]
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save:
    plt.savefig(FigLoc + 'DTI_MDWLSErr_US.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

data = np.abs(MD_SBI - MD_SBI7).T
data[~mask.T] = np.nan
plt.imshow(data, cmap='Reds', norm=norm)
plt.axis('off')
if Save:
    plt.savefig(FigLoc + 'DTI_MDSBIErr.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()




## c

In [None]:
# Full-scheme SBI FA map; background pixels blanked with NaN.
temp = np.where(mask, FA_SBI, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()

if Save:
    plt.savefig(FigLoc + 'HCP_SBI_FA_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Full-scheme NLLS FA map with colour bar.
temp = np.where(mask, FAFull, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()

if Save:
    plt.savefig(FigLoc + 'HCP_NLLS_FA_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Signed FA difference (SBI - NLLS) on a diverging map centred at 0; the upper
# limit is fixed at 1, the lower one follows the data.
# NOTE(review): TwoSlopeNorm raises if the minimum difference is not below 0.
data = (FA_SBI - FAFull).T
data[~mask.T] = np.nan
lo = np.nanmin(data)
norm = TwoSlopeNorm(vmin=lo, vcenter=0, vmax=1)
plt.imshow(data, cmap='seismic', norm=norm)
plt.axis('off')
cbar = plt.colorbar()
ticks = [lo, 0, 1]
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))

if Save:
    plt.savefig(FigLoc + 'HCP_FA_Diff.pdf', format='pdf', bbox_inches='tight', transparent=True)

## d

In [None]:
# Reduced-scheme SBI FA map.
temp = np.where(mask, FA_SBI7, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()

if Save:
    plt.savefig(FigLoc + 'HCP_SBI_FA_7_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Reduced-scheme NLLS FA map with colour bar.
temp = np.where(mask, FA7, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()

if Save:
    plt.savefig(FigLoc + 'HCP_NLLS_FA_7_US.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Absolute FA change caused by the reduced scheme (NLLS panel, then SBI
# panel), sharing one colour normalisation.
norm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)

data = np.abs(FAFull - FA7).T
data[~mask.T] = np.nan
plt.imshow(data, cmap='Reds', norm=norm)
ticks = [0, 1]
plt.axis('off')
cbar = plt.colorbar()
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save:
    plt.savefig(FigLoc + 'DTI_FAWLSErr_US.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

data = np.abs(FA_SBI - FA_SBI7).T
data[~mask.T] = np.nan
plt.imshow(data, cmap='Reds', norm=norm)
plt.axis('off')
if Save:
    plt.savefig(FigLoc + 'DTI_FASBIErr_US.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()




# Fig 4

In [None]:
# Output folder for this figure's panels, created on first run.
FigLoc = image_path + 'Fig_S4/'
needs_dir = not os.path.exists(FigLoc)
if needs_dir:
    os.mkdir(FigLoc)

## c

In [None]:
# Two-shell simulation gradient table built from dipy's 'small_64D' sample.
# The file's directions (rows 1: of bvecs) are spread out with
# disperse_charges (5000 iterations) and prefixed by a [0, 0, 0] row. The
# whole set is then used twice: once with the file's b-values and once with
# b = 3000.
# Index 65 is zeroed; assuming the file has 65 volumes, that entry is the
# [0, 0, 0] row of the second copy, so it becomes an extra b = 0 volume.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_updated, _ = disperse_charges(hsph_initial, 5000)
bvecs = np.concatenate([np.zeros((1, 3)), hsph_updated.vertices])
bvalsExt = np.concatenate([bvals, np.full_like(bvals, 3000)])
bvecsExt = np.tile(bvecs, (2, 1))
bvalsExt[65] = 0
gtabSim = gradient_table(bvalsExt, bvecsExt)

In [None]:
# Ten random 7-volume schemes: one b = 0 volume plus six b = 1000 directions
# picked from the dispersed hemisphere vertices.
# NOTE(review): np.random.choice samples with replacement by default, so a
# scheme may repeat a direction, and arange(1, 64) never picks vertex 0.
# Changing either alters the draws and would mismatch any cached
# DTISimMin_Dir*.pickle networks presumably trained on these tables.
np.random.seed(1)
gtabSimDirs = [
    gradient_table(
        np.array([0] + [1000] * 6).squeeze(),
        np.vstack([[0, 0, 0], hsph_updated.vertices[np.random.choice(np.arange(1, 64), 6)]]),
    )
    for _ in range(10)
]

In [None]:
# The same ten random index draws as gtabSimDirs (same seed, same calls),
# kept as raw vertex indices into 1..63.
np.random.seed(1)
RandDirs = [np.random.choice(np.arange(1, 64), 6) for _ in range(10)]

In [None]:
# Upper-hemisphere surface of radius 2 for the direction plot:
# u is the azimuthal angle, v the polar angle limited to [0, pi/2].
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi / 2, 100)

# Broadcasting a column against a row is exactly np.outer.
x1 = 2 * (np.cos(u)[:, None] * np.sin(v)[None, :])
y1 = 2 * (np.sin(u)[:, None] * np.sin(v)[None, :])
z1 = 2 * (np.ones_like(u)[:, None] * np.cos(v)[None, :])

In [None]:
# 3-D illustration: one random 6-direction scheme (red) against the optimised
# scheme gtabSim7 (black), drawn as length-2 arrows from the origin inside the
# translucent hemisphere (x1, y1, z1).
rand_vecs = 2 * gtabSimDirs[0].bvecs[1:]
opt_vecs = 2 * gtabSim7.bvecs[1:]

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
plt.quiver([0] * 6, [0] * 6, [0] * 6, rand_vecs[:, 0], rand_vecs[:, 1], rand_vecs[:, 2],
           color='red', lw=3, label='Random')
plt.quiver([0] * 6, [0] * 6, [0] * 6, opt_vecs[:, 0], opt_vecs[:, 1], opt_vecs[:, 2],
           color='k', lw=3, label='Optimal')
ax.plot_surface(x1, y1, z1, rstride=4, cstride=4, color=np.array([140, 100, 200]) / 255,
                linewidth=0, alpha=0.25)
plt.axis('off')
ax.set_box_aspect((1.8, 1.8, 1))
ax.view_init(elev=45., azim=-102)
plt.legend(loc=2, bbox_to_anchor=(0.1, 1.19), fontsize=36, columnspacing=0.3,
           handlelength=0.6, handletextpad=0.3)
if Save:
    plt.savefig(FigLoc + 'DirectionsEg.pdf', format='pdf', bbox_inches='tight', transparent=True)
    

## d

In [None]:
# SBI posterior for the optimised 7-volume simulation scheme gtabSim7. It is
# loaded from network_path when cached; otherwise it is trained on
# TrainingSamples simulated signals. Seeds are fixed so training is
# reproducible, and the order of random calls matters.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimMin.pickle"):
    # NOTE(review): pickle.load can execute arbitrary code - trusted files only.
    with open(f"{network_path}/DTISimMin.pickle", "rb") as handle:
        posterior7 = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        # NOTE(review): ComputeDTI receives the full prior sample including the
        # trailing noise entry (other cells pass params[:-1]); presumably it
        # only reads the leading tensor entries - confirm.
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        # Last prior entry is passed as CustomSimulator's 4th argument (the
        # SNR), with S0 fixed at 200; it is also the last training target, so
        # the posterior's last coordinate estimates it.
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSim7,200,a))
        Par.append(np.hstack([mat_to_vals(dt),a]))
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior7 = inference.build_posterior(density_estimator)
    if not os.path.exists(f"{network_path}/DTISimMin.pickle"):
        with open(f"{network_path}/DTISimMin.pickle", "wb") as handle:
            pickle.dump(posterior7, handle)

In [None]:
# One SBI posterior per random 7-volume scheme in gtabSimDirs (ten in total),
# loaded from network_path when cached and trained otherwise.
# Fix: training previously simulated on the undefined name `gtabSiRKirs`,
# which raised NameError for any uncached scheme; it now uses gtabSimDirs[kk].
np.random.seed(1)
torch.manual_seed(1)
posterior7_RD = []
for kk in range(10):
    pickle_file = f"{network_path}/DTISimMin_Dir{kk}.pickle"
    if os.path.exists(pickle_file):
        # NOTE(review): pickle.load can execute arbitrary code - trusted files only.
        with open(pickle_file, "rb") as handle:
            posterior7_RD.append(pickle.load(handle))
        continue

    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ForceLowFA(ComputeDTI(params))
        # Last prior entry is the SNR passed to the simulator and the last
        # training target.
        a = params[-1]
        Obs.append(CustomSimulator(dt, gtabSimDirs[kk], 200, a))
        Par.append(np.hstack([mat_to_vals(dt), a]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Neural posterior estimation (SNPE) on the simulated pairs.
    inference = SNPE(prior=priorNoise)
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train()
    posterior7_RD.append(inference.build_posterior(density_estimator))
    if not os.path.exists(pickle_file):
        with open(pickle_file, "wb") as handle:
            pickle.dump(posterior7_RD[-1], handle)

In [None]:
# Test set: 500 prior tensors (forced to low FA) and their simulated signals
# at every entry of NoiseLevels (S0 = 200). Signals are generated for the
# optimised scheme (Samples7) and for each of the ten random schemes
# (Samples7Dirs). Assuming CustomSimulator returns one value per volume, the
# final layouts are Samples7[noise, volume, sample] and
# Samples7Dirs[noise, volume, sample, scheme].
Samples7 = []
Samples7Dirs = []
DTISim = []

params = priorS0.sample([500])
for i in tqdm(range(500)):
    dt = ForceLowFA(ComputeDTI(params[i]))
    DTISim.append(dt)
    Samples7.append([CustomSimulator(dt, gtabSim7, S0=200, snr=scale) for scale in NoiseLevels])
    temp_Samples = np.moveaxis(
        np.array([
            [CustomSimulator(dt, gtabSimDirs[k], S0=200, snr=scale) for scale in NoiseLevels]
            for k in range(10)
        ]).squeeze(),
        0, -1,
    )
    Samples7Dirs.append(temp_Samples)

Samples7 = np.moveaxis(np.array(Samples7).squeeze(), 0, -1)
Samples7Dirs = np.moveaxis(np.array(Samples7Dirs), 0, -2)

In [None]:
# SBI errors on the optimised scheme for the first 200 test tensors at each of
# the five noise levels. Errors(...) compares the posterior-mean tensor with
# the truth, using the noise-free single_tensor signal as reference.
# NoiseApprox7[k, i] is the posterior mean of the last parameter (the noise
# level the network was trained to infer).
torch.manual_seed(10)
SNR = NoiseLevels
Error7 = []
NoiseApprox7 = []
for k in tqdm(range(5)):
    ErrorN2, ENoise = [], []
    for i in range(200):
        mat_true = vals_to_mat(mat_to_vals(DTISim[i]))
        tObs = Samples7[k, :, i]
        evals_true, evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim7, S0=200, evals=evals_true, evecs=evecs_true,
                                        snr=None)
        posterior_samples_1 = posterior7.sample((InferSamples,), x=tObs, show_progress_bars=False)
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess, mat_true, gtabSim7, true_signal_dti, tObs))
        ENoise.append(posterior_samples_1[:, -1].mean())
    NoiseApprox7.append(ENoise)
    Error7.append(ErrorN2)

NoiseApprox7 = np.array(NoiseApprox7)


In [None]:
# SBI errors on the first random scheme (gtabSimDirs[0]) for the same 200 test
# tensors and noise levels as Error7.
# Fixes:
#  * the noise-free reference signal is now simulated on gtabSimDirs[0], the
#    scheme the observations came from, instead of gtabSim[:7], whose
#    directions differ, which corrupted the signal-based error metrics;
#  * the stray append to the previous cell's leftover ENoise list is removed
#    (it raised NameError when that cell had not run).
torch.manual_seed(10)
SNR = NoiseLevels
Error7_RD = []
gtab_rd = gtabSimDirs[0]
for k in tqdm(range(5)):
    ErrorN2 = []
    for i in range(200):
        tparams = mat_to_vals(DTISim[i])
        tObs = Samples7Dirs[k, :, i, 0]
        mat_true = vals_to_mat(tparams)
        evals_true, evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtab_rd, S0=200, evals=evals_true, evecs=evecs_true,
                                        snr=None)
        posterior_samples_1 = posterior7_RD[0].sample((InferSamples,), x=tObs, show_progress_bars=False)
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess, mat_true, gtab_rd, true_signal_dti, tObs))
    Error7_RD.append(ErrorN2)


In [None]:
# NLLS baseline errors for the two 7-volume schemes over all 500 test
# tensors: scheme 0 is the first random scheme, scheme 1 the optimised one.
# Samps[0] (the NoiseLevels[0] = None entry) is used as the reference signal.
# After the axis swap, Error_s is indexed [noise level, scheme, sample, metric].
Error_s = []
for gtab, Samps in ((gtabSimDirs[0], Samples7Dirs[..., 0]), (gtabSim7, Samples7)):
    tenmodel = dti.TensorModel(gtab, fit_method='NLLS')
    Error_n = []
    for S, _noise in zip(Samps, NoiseLevels):
        Error = []
        for i in range(500):
            tenfit = tenmodel.fit(S[:, i])
            DT_test = vals_to_mat(dti.lower_triangular(tenfit.quadratic_form))
            Error.append(Errors(DT_test, DTISim[i], gtab, Samps[0][:, i], S[:, i]))
        Error_n.append(Error)
    Error_s.append(Error_n)
Error_s = np.swapaxes(np.array(Error_s), 0, 1)

In [None]:

# Box plots of the first two error metrics (zip stops after the two axes) at
# the noise levels NoiseLevels[1:]. Each noise level gets four boxes, 0.3
# apart: SBI optimised, SBI random (hatched), NLLS optimised, NLLS random
# (hatched).
fig, axs = plt.subplots(2, 1, figsize=(4.5, 6))
ax = axs.ravel()
legend_style = dict(loc=2, bbox_to_anchor=(-0.05, 1.15), fontsize=28,
                    columnspacing=0.3, handlelength=0.6, handletextpad=0.3)

for ll, (a, E, E2, t) in enumerate(zip(ax, np.array(Error7).T, np.array(Error7_RD).T, Errors_name)):
    plt.sca(a)
    BoxPlots(E[:, 1:].T, np.array([1, 3, 5, 7]),
             ['lightseagreen'] * 4, ['paleturquoise'] * 4, a, widths=0.3, scatter=False)
    BoxPlots(E2[:, 1:].T, np.array([1.3, 3.3, 5.3, 7.3]),
             ['mediumturquoise'] * 4, ['paleturquoise'] * 4, a, widths=0.3, scatter=False, hatch=True)
    BoxPlots(Error_s[1:, 1, :, ll], np.array([1.6, 3.6, 5.6, 7.6]),
             ['sandybrown'] * 4, ['peachpuff'] * 4, a, widths=0.3, scatter=False)
    BoxPlots(Error_s[1:, 0, :, ll], np.array([1.9, 3.9, 5.9, 7.9]),
             ['burlywood'] * 4, ['peachpuff'] * 4, a, widths=0.3, scatter=False, hatch='x')

    plt.sca(a)
    plt.xticks([1.45, 3.45, 5.45, 7.45], NoiseLevels[1:], fontsize=32)
    plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
    plt.yticks(fontsize=32)
    if ll == 0:
        # Top panel: NLLS line and hatched "random direction" swatch.
        handles = [
            Line2D([0], [0], color='sandybrown', lw=4, label='NLLS'),
            Rectangle((0, 0), 1, 0.2, facecolor='peachpuff', edgecolor='burlywood',
                      hatch='///', label='Rand. Dir'),
        ]
        plt.legend(handles=handles, handler_map={Rectangle: ThinPatchHandler()},
                   labelspacing=0.0, **legend_style)
    elif ll == 1:
        # Bottom panel: SBI line only.
        handles = [Line2D([0], [0], color=SBIFit, lw=4, label='SBI')]
        plt.legend(handles=handles, **legend_style)
plt.tight_layout()
if Save:
    plt.savefig(FigLoc + 'DirectionErrors.pdf', format='pdf', bbox_inches='tight', transparent=True)

## e

In [None]:
def _posterior_md_fa(posterior, obs):
    """Return (mean eigenvalue, FracAni) of the posterior-mean tensor for *obs*.

    Draws InferSamples samples from *posterior* conditioned on *obs*, averages
    them, converts the average with vals_to_mat and takes its eigenvalues.
    """
    draws = posterior.sample((InferSamples,), x=obs, show_progress_bars=False)
    eigvals = np.linalg.eigh(vals_to_mat(np.array(draws.mean(axis=0))))[0]
    md = np.mean(eigvals)
    return md, FracAni(eigvals, md)


# Direction sensitivity of SBI: for every noise level and test sample, the
# absolute MD and FA differences between the optimised-scheme estimate and
# each of the ten random-scheme estimates. Layout: [noise][sample][scheme] ->
# (|dMD|, |dFA|). Sampling order matches the torch seed set below.
torch.manual_seed(10)
SNR = NoiseLevels
Sensitivity = []
for k in tqdm(range(5)):
    ErrorN2 = []
    temp = []
    for i in tqdm(range(500)):
        RK_guess, FA_guess = _posterior_md_fa(posterior7, Samples7[k, :, i])
        temp2 = []
        for kk in range(10):
            RK_guess2, FA_guess2 = _posterior_md_fa(posterior7_RD[kk], Samples7Dirs[k, :, i, kk])
            temp2.append((np.abs(RK_guess - RK_guess2), np.abs(FA_guess - FA_guess2)))
        temp.append(temp2)
    Sensitivity.append(temp)

In [None]:
# Direction sensitivity of NLLS, with the same layout as Sensitivity:
# [noise][sample][scheme] -> (|dMD|, |dFA|) between the optimised-scheme fit
# and each random-scheme fit.
# Every TensorModel is now built once, outside the loops: the gradient tables
# never change, so rebuilding the ten random-scheme models
# 5 x 500 x 10 times was pure overhead. The unused ErrorN2 list is dropped.
torch.manual_seed(10)
SNR = NoiseLevels
Sensitivity_NLLS = []
tenmodel = dti.TensorModel(gtabSim7, fit_method='NLLS')
rd_models = [dti.TensorModel(gtabSimDirs[kk], fit_method='NLLS') for kk in range(10)]
for k in tqdm(range(5)):
    temp = []
    for i in tqdm(range(500)):
        tenfit = tenmodel.fit(Samples7[k, :, i])
        RK_guess = tenfit.md
        FA_guess = tenfit.fa
        temp2 = []
        for kk in range(10):
            tenfit = rd_models[kk].fit(Samples7Dirs[k, :, i, kk])
            temp2.append((np.abs(RK_guess - tenfit.md), np.abs(FA_guess - tenfit.fa)))
        temp.append(temp2)
    Sensitivity_NLLS.append(temp)

In [None]:
# Convert to arrays indexed [noise, sample, scheme, quantity] and split off
# the MD (quantity 0) and FA (quantity 1) differences.
Sensitivity = np.asarray(Sensitivity)
Sensitivity_NLLS = np.asarray(Sensitivity_NLLS)

S_RK, S_FA = Sensitivity[..., 0], Sensitivity[..., 1]
S_RK_NLLS, S_FA_NLLS = Sensitivity_NLLS[..., 0], Sensitivity_NLLS[..., 1]

In [None]:
# Direction-sensitivity summary for NoiseLevels[1:]. Each box shows, across
# the ten random schemes, the per-scheme mean (over test samples) absolute
# difference from the optimised scheme: MD on top (log scale), FA below.
# SBI is teal, NLLS orange.
# Each panel is now drawn once. Previously an unused `for i in range(1, 5)`
# loop redrew identical boxes and scatter points four times on top of each
# other.
fig, axs = plt.subplots(2, 1, figsize=(4.5, 6))
ax = axs.ravel()
sbi_pos = np.array([1, 2, 3, 4])
nlls_pos = np.array([1.3, 2.3, 3.3, 4.3])
tick_pos = [1.15, 2.15, 3.15, 4.15]

plt.sca(ax[0])
BoxPlots(S_RK[1:, :, :].mean(axis=1), sbi_pos,
         ['lightseagreen'] * 4, ['paleturquoise'] * 4, ax[0], widths=0.3, scatter=True)
BoxPlots(S_RK_NLLS[1:, :, :].mean(axis=1), nlls_pos,
         ['sandybrown'] * 4, ['peachpuff'] * 4, ax[0], widths=0.3, scatter=True)
plt.semilogy()
plt.xticks(tick_pos, NoiseLevels[1:], fontsize=32)
plt.yticks(rotation=90, va='center')
plt.grid()

plt.sca(ax[1])
BoxPlots(S_FA[1:, :, :].mean(axis=1), sbi_pos,
         ['lightseagreen'] * 4, ['paleturquoise'] * 4, ax[1], widths=0.3, scatter=True)
BoxPlots(S_FA_NLLS[1:, :, :].mean(axis=1), nlls_pos,
         ['sandybrown'] * 4, ['peachpuff'] * 4, ax[1], widths=0.3, scatter=True)
plt.xticks(tick_pos, NoiseLevels[1:], fontsize=32)
plt.grid()

plt.tight_layout()
if Save:
    plt.savefig(FigLoc + 'DirectionComp.pdf', format='pdf', bbox_inches='tight', transparent=True)

## f

In [None]:
# Load subject 1's b=1k diffusion data and gradients, reslice the volume from
# 1.5 mm to 2.5 mm isotropic voxels, and brain-mask it with median_otsu
# (volumes 10-49 used for the mask, autocropped, dilated twice).
# NOTE(review): axial_middle is computed before autocrop, so it indexes the
# cropped maskdata and may not be its middle slice - confirm intended.
patient_dir = './HCP_data/Pat' + str(1)
fdwi = patient_dir + '/diff_1k.nii.gz'
bvalloc = patient_dir + '/bvals_1k.txt'
bvecloc = patient_dir + '/bvecs_1k.txt'

bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)

In [None]:
# Volume indices for a random 7-volume HCP subset: index 0 (presumably a b=0
# volume) followed by six draws from 1..68.
np.random.seed(1)
picked = np.random.choice(np.arange(1, 69), 6)
RandomDir_HCP = np.concatenate((np.zeros(1, dtype=picked.dtype), picked))

In [None]:
# Gradient table restricted to the volumes chosen in RandomDir_HCP.
bvalsHCP7, bvecsHCP7 = bvalsHCP[RandomDir_HCP], bvecsHCP[RandomDir_HCP]
gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)

In [None]:
# Prior over tensor values, S0 and an extra parameter bounded by (0, 30)
# (presumably the noise level), wrapped for sbi with process_prior.
custom_prior = DTIPriorS0Noise(lower_abs, upper_abs, lower_rest, upper_rest,
                               lower_S0, upper_S0, 0, 30)
priorS0Noise = process_prior(custom_prior)[0]

In [None]:
# SBI posterior for the random 7-volume HCP scheme gtabHCP7, cached under
# network_path.
# Fix: the save path used the undefined name `snetwork_path`, which raised
# NameError after a full training run and lost the trained network. The path
# is now built once and shared by load and save. The unused bvals/bvecs
# assignments (which clobbered globals) are removed.
pickle_file = f"{network_path}/DTIHCPMin_RD.pickle"
if os.path.exists(pickle_file):
    # NOTE(review): pickle.load can execute arbitrary code - trusted files only.
    with open(pickle_file, "rb") as handle:
        posterior7_2 = pickle.load(handle)
else:
    # Seeded for a reproducible training set; the order of random calls matters.
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        # Last prior entry is S0; the rest parametrise the tensor.
        params = priorS0.sample()
        dt = ComputeDTI(params[:-1])
        dt = ForceLowFA(dt)
        # SNR drawn uniformly from [20, 50).
        Obs.append(CustomSimulator(dt, gtabHCP7, params[-1], np.random.rand() * 30 + 20))
        Par.append(np.hstack([mat_to_vals(dt), params[-1]]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Neural posterior estimation (SNPE) on the simulated pairs.
    inference = SNPE(prior=priorS0)
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train()
    posterior7_2 = inference.build_posterior(density_estimator)
    if not os.path.exists(pickle_file):
        with open(pickle_file, "wb") as handle:
            pickle.dump(posterior7_2, handle)

In [None]:
# Foreground mask of the axial slice (non-zero summed signal) and the
# coordinates of its pixels.
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)

def optimize_pixel(i, j):
    """Return ``(i, j, posterior_mean)`` for one pixel of the random HCP scheme.

    Conditions ``posterior7_2`` on the pixel's signal restricted to the
    volumes in ``RandomDir_HCP``, draws 1000 samples after resetting the torch
    seed, and averages them along the sample axis.
    """
    torch.manual_seed(10)
    signal = maskdata[i, j, axial_middle, RandomDir_HCP]
    draws = posterior7_2.sample((1000,), x=signal, show_progress_bars=False)
    return i, j, draws.mean(axis=0)

# Slice shape for sizing the output maps, then parallel per-pixel sampling.
ArrShape = mask.shape

results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(row, col) for row, col in tqdm(indices)
)


In [None]:

# Scatter the parallel results into a (rows, cols, 7) map; background stays 0.
NoiseEst = np.zeros((*ArrShape, 7))

for i, j, x in results:
    NoiseEst[i, j] = x

# Clip channels -2 and -3; background zeros already lie inside both ranges,
# so clipping the whole map matches clipping only the fitted pixels.
# NOTE(review): these channels look like tensor components - confirm bounds.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2], 0, 100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3], 0, 1)

In [None]:

# MD ("RK") and FA maps for the random-scheme SBI fit. Tensors first pass
# through clip_negative_eigenvalues (presumably zeroing negative
# eigenvalues); MD is the mean eigenvalue and FA comes from FracAni.
# Loop bounds follow the slice shape instead of a hard-coded 55 x 64, and
# the duplicated zeros_like line is removed.
n_rows, n_cols = NoiseEst.shape[:2]

NoiseEst2 = np.zeros_like(NoiseEst)
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEst2[i, j] = np.hstack([
            mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))),
            NoiseEst[i, j, -1],
        ])

RK_SBI7 = np.zeros([n_rows, n_cols])
FA_SBI7 = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i, j, :6]))[0]
        RK_SBI7[i, j] = np.mean(Eigs)
        FA_SBI7[i, j] = FracAni(Eigs, np.mean(Eigs))
# Pixels where FracAni returns NaN are shown as FA = 0.
FA_SBI7[np.isnan(FA_SBI7)] = 0

In [None]:
# NLLS fit restricted to the random HCP subset RandomDir_HCP.
tenmodel = dti.TensorModel(gtabHCP7, return_S0_hat=True, fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:, :, axial_middle, RandomDir_HCP])
FA7, RK7 = (
    dti.fractional_anisotropy(tenfit.evals),
    dti.mean_diffusivity(tenfit.evals),
)

In [None]:
# Force the NLLS FA to 0 wherever the slice has no signal at all
# (background). Vectorised over the whole slice instead of looping over a
# hard-coded 55 x 64 region.
FA7[np.sum(maskdata[:, :, axial_middle, :], axis=-1) == 0] = 0

In [None]:
# Random-scheme SBI MD map. Its colour limits are stored in vmin/vmax so the
# NLLS panel can reuse the same scale.
temp = np.where(mask, RK_SBI7, math.nan)
img = plt.imshow(temp.T, cmap='hot')
plt.axis('off')
vmin, vmax = img.get_clim()
if Save:
    plt.savefig(FigLoc + 'HCP_SBI_RK_7_RD.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# NLLS mean-diffusivity map on the SBI map's colour limits, with a colourbar whose
# labels use scientific notation.
temp = RK7.copy()
temp[~mask] = math.nan
img = plt.imshow(temp.T, cmap='hot', vmin=vmin, vmax=vmax)
cbar = plt.colorbar(fraction=0.032, pad=0.04)
cbar.formatter.set_powerlimits((0, 0))
plt.axis('off')
if Save:
    plt.savefig(f'{FigLoc}HCP_NLLS_RK_7_RD.pdf', format='pdf', bbox_inches='tight', transparent=True)

## g

In [None]:
# SBI fractional-anisotropy map on a fixed [0, 1] colour scale.
temp = FA_SBI7.copy()
temp[~mask] = math.nan
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=1)
plt.axis('off')
if Save:
    plt.savefig(f'{FigLoc}HCP_SBI_FA_7_RD.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# NLLS fractional-anisotropy map on a fixed [0, 1] scale, with a colourbar ticked
# every 0.2.
temp = FA7.copy()
temp[~mask] = math.nan
img = plt.imshow(temp.T, cmap='hot', vmin=0, vmax=1)
cbar = plt.colorbar(fraction=0.032, pad=0.04)
cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1])
plt.axis('off')
if Save:
    plt.savefig(f'{FigLoc}HCP_NLLS_FA_7_RD.pdf', format='pdf', bbox_inches='tight', transparent=True)

# Fig 5

In [None]:
# Output folder for this figure's panels.
FigLoc = f'{image_path}Fig_S3/'
if not os.path.exists(FigLoc):
    os.mkdir(FigLoc)

In [None]:
# Simulation gradient schemes: one b0 plus 64, 19 or 6 directions at b=1000. Initial
# directions come from dipy's 'small_64D' example data. disperse_charges then spreads
# them over the hemisphere.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_initial20 = HemiSphere(xyz=bvecs[1:20])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
hsph_updated, potentials = disperse_charges(hsph_initial, 5000)
hsph_updated20, potentials = disperse_charges(hsph_initial20, 5000)
hsph_updated7, potentials = disperse_charges(hsph_initial7, 5000)

# The b-value vector is sized from the vertex count so it always matches the directions.
gtabSimF = gradient_table(np.array([0] + [1000] * len(hsph_updated.vertices)).squeeze(),
                          np.vstack([[0, 0, 0], hsph_updated.vertices]))
gtabSim20 = gradient_table(np.array([0] + [1000] * len(hsph_updated20.vertices)).squeeze(),
                           np.vstack([[0, 0, 0], hsph_updated20.vertices]))
# gtabSim7 is required by the 7-volume simulations and posterior but was never
# built, even though hsph_updated7 is computed for it.
gtabSim7 = gradient_table(np.array([0] + [1000] * len(hsph_updated7.vertices)).squeeze(),
                          np.vstack([[0, 0, 0], hsph_updated7.vertices]))

In [None]:
def _simulate_test_set(gtab, n_samples=500):
    """Draw ``n_samples`` ground-truth tensors from ``priorS0`` and simulate signals.

    Each draw goes through ``ComputeDTI`` and ``ForceLowFA``. It is then simulated on
    ``gtab`` with S0=200 at every entry of ``NoiseLevels``.

    Returns:
        (samples, dti_sim, s0_sim): ``samples`` is the stacked signal array with the
        sample axis moved last. Assuming CustomSimulator returns a 1-D signal, its
        shape is (len(NoiseLevels), n_volumes, n_samples). ``dti_sim`` is the list of
        tensors; ``s0_sim`` is the list of each draw's last prior parameter.
    """
    params = priorS0.sample([n_samples])
    samples, dti_sim, s0_sim = [], [], []
    for i in tqdm(range(n_samples)):
        dt = ForceLowFA(ComputeDTI(params[i]))
        dti_sim.append(dt)
        s0_sim.append(params[i, -1])
        samples.append([CustomSimulator(dt, gtab, S0=200, snr=scale) for scale in NoiseLevels])
    samples = np.moveaxis(np.array(samples).squeeze(), 0, -1)
    return samples, dti_sim, s0_sim


# One seed for all three sets. The sets are drawn sequentially, so each consumes the
# RNG stream in the same order as three back-to-back loops would.
np.random.seed(0)
torch.manual_seed(0)
Samples, DTISim, S0Sim = _simulate_test_set(gtabSimF)
Samples20, DTISim20, S0Sim20 = _simulate_test_set(gtabSim20)
Samples7, DTISim7, S0Sim7 = _simulate_test_set(gtabSim7)

## a

In [None]:
np.random.seed(1)
torch.manual_seed(1)
# Load the cached 64-direction SBI posterior, or train and cache it. The cache is
# both read from and written to network_path. Writing to a different directory
# would mean it is never found again.
# NOTE: pickle.load runs arbitrary code; only load caches this project wrote itself.
if os.path.exists(f"{network_path}/DTISimFull.pickle"):
    with open(f"{network_path}/DTISimFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        # Last prior dimension is passed to CustomSimulator's 4th argument (presumably the SNR).
        a = params[-1]
        Obs.append(CustomSimulator(dt, gtabSimF, 200, a))
        Par.append(np.hstack([mat_to_vals(dt), a]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Neural posterior estimation (SNPE) on the simulated (parameter, signal) pairs.
    inference = SNPE(prior=priorNoise)
    inference = inference.append_simulations(Par, Obs)
    density_estimator = inference.train()
    posteriorFull = inference.build_posterior(density_estimator)
    if not os.path.exists(f"{network_path}/DTISimFull.pickle"):
        with open(f"{network_path}/DTISimFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)


In [None]:
# SBI evaluation on the 64-direction test set. For every SNR level and test signal:
# condition the posterior, take the posterior-mean tensor, clean it with
# clip_negative_eigenvalues, and score it against the ground truth with Errors().
torch.manual_seed(10)
SNR = NoiseLevels
ErrorFull, NoiseApproxFull = [], []
for k in tqdm(range(5)):
    ErrorN2, ENoise = [], []
    for i in range(500):
        mat_true = vals_to_mat(mat_to_vals(DTISim[i]))
        tObs = Samples[k, :, i]
        evals_true, evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSimF, S0=200, evals=evals_true, evecs=evecs_true, snr=None)
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs, show_progress_bars=False)
        mat_guess = clip_negative_eigenvalues(vals_to_mat(np.array(posterior_samples_1.mean(axis=0))))
        ErrorN2.append(Errors(mat_guess, mat_true, gtabSimF, true_signal_dti, tObs))
        # Posterior mean of the last parameter (the prior's noise dimension).
        ENoise.append(posterior_samples_1[:, -1].mean())
    NoiseApproxFull.append(ENoise)
    ErrorFull.append(ErrorN2)

NoiseApproxFull = np.array(NoiseApproxFull)


In [None]:
# NLLS baseline for the 64-direction test set. Each simulated signal is fitted at
# every SNR level and the fitted tensor is scored with Errors(). Samps[0] (the
# NoiseLevels[0] = None level) is the reference signal.
k, gtab, Samps, DTIS = 65, gtabSimF, Samples, DTISim
tenmodel = dti.TensorModel(gtab, fit_method='NLLS')
Error_n = np.array([
    [
        Errors(
            vals_to_mat(dti.lower_triangular(tenmodel.fit(S[:, i]).quadratic_form)),
            DTIS[i], gtab, Samps[0][:, i], S[:, i],
        )
        for i in range(500)
    ]
    for S, _ in zip(Samps, NoiseLevels)
])

In [None]:
# Boxplots of error metrics 2-7 for the 64-direction scheme at the four noisy SNR
# levels: SBI (teal, offset right) vs NLLS (orange).
fig, axs = plt.subplots(1, 6, figsize=(27, 3))
ax = axs.ravel()
sbi_errors = np.array(ErrorFull).T[2:]
for ll, (a, E, t) in enumerate(zip(ax, sbi_errors, Errors_name[2:])):
    # Column 0 of E is the noise-free level and is skipped.
    BoxPlots(E[:, 1:].T, np.array([1.3, 2.3, 3.3, 4.3]), ['lightseagreen'] * 4, ['paleturquoise'] * 4,
             a, widths=0.3, scatter=False)
    BoxPlots(Error_n[1:, :, ll + 2], np.array([1, 2, 3, 4]), ['sandybrown'] * 4, ['peachpuff'] * 4,
             a, widths=0.3, scatter=False)
    plt.sca(a)
    plt.xticks([1.15, 2.15, 3.15, 4.15], NoiseLevels[1:], fontsize=32)
    plt.yticks(fontsize=32)
    plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
    plt.grid(axis='y')
    # Single-entry legends: NLLS on the first panel, SBI on the second.
    legend_spec = {0: (WLSFit, 'NLLS', 1.15), 1: (SBIFit, 'SBI', 1.05)}.get(ll)
    if legend_spec is not None:
        colour, label, y_anchor = legend_spec
        plt.legend(handles=[Line2D([0], [0], color=colour, lw=4, label=label)], loc=2,
                   bbox_to_anchor=(0, y_anchor), fontsize=36, columnspacing=0.3,
                   handlelength=0.6, handletextpad=0.3)
plt.tight_layout()
if Save:
    plt.savefig(f'{FigLoc}SiRKatDTIErrors2.pdf', format='pdf', bbox_inches='tight', transparent=True)

## b

In [None]:
np.random.seed(1)
torch.manual_seed(1)
# Load the cached 19-direction SBI posterior, or train and cache it.
# NOTE: pickle.load runs arbitrary code; only load caches this project wrote itself.
cache_file = f"{network_path}/DTISimMid.pickle"
if os.path.exists(cache_file):
    with open(cache_file, "rb") as handle:
        posterior20 = pickle.load(handle)
else:
    Obs, Par = [], []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ForceLowFA(ComputeDTI(params))
        a = params[-1]
        Obs.append(CustomSimulator(dt, gtabSim20, 200, a))
        Par.append(np.hstack([mat_to_vals(dt), a]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Neural posterior estimation on the simulated (parameter, signal) pairs.
    inference = SNPE(prior=priorNoise).append_simulations(Par, Obs)
    density_estimator = inference.train()
    posterior20 = inference.build_posterior(density_estimator)
    if not os.path.exists(cache_file):
        with open(cache_file, "wb") as handle:
            pickle.dump(posterior20, handle)

In [None]:
# Rebuild the 19-direction test set from a fresh seed-0 stream.
np.random.seed(0)
torch.manual_seed(0)

params = priorS0.sample([500])
Samples20, DTISim20, S0Sim20 = [], [], []
for i in tqdm(range(500)):
    dt = ForceLowFA(ComputeDTI(params[i]))
    DTISim20.append(dt)
    S0Sim20.append(params[i, -1])
    Samples20.append([CustomSimulator(dt, gtabSim20, S0=200, snr=scale) for scale in NoiseLevels])

# Move the sample axis last: (noise level, volume, sample) for 1-D simulator output.
Samples20 = np.moveaxis(np.array(Samples20).squeeze(), 0, -1)


In [None]:
# SBI evaluation on the 19-direction test set (posterior-mean tensor vs ground truth).
torch.manual_seed(10)
SNR = NoiseLevels
Error20, NoiseApprox20 = [], []
for k in tqdm(range(5)):
    ErrorN2, ENoise = [], []
    for i in range(500):
        mat_true = vals_to_mat(mat_to_vals(DTISim20[i]))
        tObs = Samples20[k, :, i]
        evals_true, evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim20, S0=200, evals=evals_true, evecs=evecs_true, snr=None)
        posterior_samples_1 = posterior20.sample((InferSamples,), x=tObs, show_progress_bars=False)
        # NOTE(review): unlike the 64-direction evaluation, the estimate is not passed
        # through clip_negative_eigenvalues here -- confirm this is intended.
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess, mat_true, gtabSim20, true_signal_dti, tObs))
        ENoise.append(posterior_samples_1[:, -1].mean())
    NoiseApprox20.append(ENoise)
    Error20.append(ErrorN2)

NoiseApprox20 = np.array(NoiseApprox20)


In [None]:
# NLLS baseline for the 19-direction test set; Samps[0] is the reference signal.
k, gtab, Samps, DTIS = 20, gtabSim20, Samples20, DTISim20
tenmodel = dti.TensorModel(gtab, fit_method='NLLS')
Error_n = np.array([
    [
        Errors(
            vals_to_mat(dti.lower_triangular(tenmodel.fit(S[:, i]).quadratic_form)),
            DTIS[i], gtab, Samps[0][:, i], S[:, i],
        )
        for i in range(500)
    ]
    for S, _ in zip(Samps, NoiseLevels)
])

In [None]:
# Boxplots of error metrics 0-3 for the 19-direction scheme: SBI (teal) vs NLLS (orange).
fig, axs = plt.subplots(1, 4, figsize=(18, 3))
ax = axs.ravel()
for ll, (a, E, t) in enumerate(zip(ax, np.array(Error20).T, Errors_name)):
    BoxPlots(E[:, 1:].T, np.array([1.3, 2.3, 3.3, 4.3]), ['lightseagreen'] * 4, ['paleturquoise'] * 4,
             a, widths=0.3, scatter=False)
    BoxPlots(Error_n[1:, :, ll], np.array([1, 2, 3, 4]), ['sandybrown'] * 4, ['peachpuff'] * 4,
             a, widths=0.3, scatter=False)
    plt.sca(a)
    plt.xticks([1.15, 2.15, 3.15, 4.15], NoiseLevels[1:], fontsize=32)
    plt.yticks(fontsize=32)
    plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
    plt.grid(axis='y')
plt.tight_layout()
if Save:
    plt.savefig(f'{FigLoc}SiRKatDTIErrors1_20.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Boxplots of error metrics 4-7 for the 19-direction scheme: SBI (teal) vs NLLS (orange).
fig, axs = plt.subplots(1, 4, figsize=(18, 3))
ax = axs.ravel()
for ll, (a, E, t) in enumerate(zip(ax, np.array(Error20).T[4:], Errors_name[4:])):
    BoxPlots(E[:, 1:].T, np.array([1.3, 2.3, 3.3, 4.3]), ['lightseagreen'] * 4, ['paleturquoise'] * 4,
             a, widths=0.3, scatter=False)
    BoxPlots(Error_n[1:, :, ll + 4], np.array([1, 2, 3, 4]), ['sandybrown'] * 4, ['peachpuff'] * 4,
             a, widths=0.3, scatter=False)
    plt.sca(a)
    # Panels 0 and 2 get explicit y tick positions.
    if ll in (0, 2):
        plt.yticks([0, 20])
    plt.xticks([1.15, 2.15, 3.15, 4.15], NoiseLevels[1:], fontsize=32)
    plt.yticks(fontsize=32)
    plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
    plt.grid(axis='y')
plt.tight_layout()
if Save:
    plt.savefig(f'{FigLoc}SiRKatDTIErrors2_20.pdf', format='pdf', bbox_inches='tight', transparent=True)

## c

In [None]:
np.random.seed(1)
torch.manual_seed(1)
# Load the cached 6-direction SBI posterior, or train and cache it.
# NOTE: pickle.load runs arbitrary code; only load caches this project wrote itself.
cache_file = f"{network_path}/DTISimMin.pickle"
if os.path.exists(cache_file):
    with open(cache_file, "rb") as handle:
        posterior7 = pickle.load(handle)
else:
    Obs, Par = [], []
    for i in tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ForceLowFA(ComputeDTI(params))
        a = params[-1]
        Obs.append(CustomSimulator(dt, gtabSim7, 200, a))
        Par.append(np.hstack([mat_to_vals(dt), a]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    # Neural posterior estimation on the simulated (parameter, signal) pairs.
    inference = SNPE(prior=priorNoise).append_simulations(Par, Obs)
    density_estimator = inference.train()
    posterior7 = inference.build_posterior(density_estimator)
    if not os.path.exists(cache_file):
        with open(cache_file, "wb") as handle:
            pickle.dump(posterior7, handle)

In [None]:
# SBI evaluation on the 6-direction (+b0) test set.
torch.manual_seed(10)
SNR = NoiseLevels
Error7, NoiseApprox7 = [], []
for k in tqdm(range(5)):
    ErrorN2, ENoise = [], []
    for i in range(500):
        mat_true = vals_to_mat(mat_to_vals(DTISim7[i]))
        tObs = Samples7[k, :, i]
        evals_true, evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim7, S0=200, evals=evals_true, evecs=evecs_true, snr=None)
        posterior_samples_1 = posterior7.sample((InferSamples,), x=tObs, show_progress_bars=False)
        # NOTE(review): no clip_negative_eigenvalues here, unlike the 64-direction cell.
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess, mat_true, gtabSim7, true_signal_dti, tObs))
        ENoise.append(posterior_samples_1[:, -1].mean())
    NoiseApprox7.append(ENoise)
    Error7.append(ErrorN2)

NoiseApprox7 = np.array(NoiseApprox7)


In [None]:
# NLLS baseline for the 6-direction (+b0) test set; Samps[0] is the reference signal.
k, gtab, Samps, DTIS = 7, gtabSim7, Samples7, DTISim7
tenmodel = dti.TensorModel(gtab, fit_method='NLLS')
Error_n = np.array([
    [
        Errors(
            vals_to_mat(dti.lower_triangular(tenmodel.fit(S[:, i]).quadratic_form)),
            DTIS[i], gtab, Samps[0][:, i], S[:, i],
        )
        for i in range(500)
    ]
    for S, _ in zip(Samps, NoiseLevels)
])

In [None]:
# Boxplots of error metrics 2-7 for the 6-direction (+b0) scheme: SBI (teal) vs NLLS (orange).
fig, axs = plt.subplots(1, 6, figsize=(27, 3))
ax = axs.ravel()
# The SBI errors must come from the 7-volume evaluation (Error7) to match Error_n,
# which here holds the 7-volume NLLS errors. ErrorFull holds the 64-direction results.
for ll, (a, E, t) in enumerate(zip(ax, np.array(Error7).T[2:], Errors_name[2:])):
    y_data = E[:, 1:]
    g_pos = np.array([1.3, 2.3, 3.3, 4.3])
    colors = ['lightseagreen', 'lightseagreen', 'lightseagreen', 'lightseagreen']
    colors2 = ['paleturquoise', 'paleturquoise', 'paleturquoise', 'paleturquoise']

    BoxPlots(y_data.T, g_pos, colors, colors2, a, widths=0.3, scatter=False)
    y_data = Error_n[1:, :, ll + 2].T
    g_pos = np.array([1, 2, 3, 4])
    colors = ['sandybrown', 'sandybrown', 'sandybrown', 'sandybrown']
    colors2 = ['peachpuff', 'peachpuff', 'peachpuff', 'peachpuff']

    BoxPlots(y_data.T, g_pos, colors, colors2, a, widths=0.3, scatter=False)
    plt.sca(a)
    plt.xticks([1.15, 2.15, 3.15, 4.15], NoiseLevels[1:], fontsize=32)
    plt.yticks(fontsize=32)
    plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
    plt.grid(axis='y')
plt.tight_layout()
if Save: plt.savefig(FigLoc+'SiRKatDTIErrors2_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Fig 6

In [None]:
# Output folder for this figure's panels. os.makedirs is the intended call:
# os.mRKedirs does not exist and raised AttributeError.
FigLoc = image_path + 'Fig_S4/'
if not os.path.exists(FigLoc):
    os.makedirs(FigLoc)

In [None]:
# Fit the DT/KT generative models on HCP voxels with TrueMets[:, 1] > 0.99, draw
# 300 tensor pairs from them, and compute their DKI metrics.
high_rk = TrueMets[:, 1] > 0.99
DT1_hRK, DT2_hRK = FitDT(DTIFilt[high_rk, :], 1)
x4_hRK, R1_hRK, x2_hRK, R2_hRK = FitKT(DKIFilt[high_rk, :], 1)
DT5, KT5 = GenDTKT([DT1_hRK, DT2_hRK], [x4_hRK, R1_hRK, x2_hRK, R2_hRK], 12, 300)
ParMets = [DKIMetrics(d, k) for d, k in tqdm(zip(DT5, KT5))]
ParTest5 = np.array(ParMets)

In [None]:
# Fit the DT/KT generative models on several HCP voxel subsets, selected by
# TrueMets[:, -1] (FA) or TrueMets[:, 1] (RK).
lfa_sel = TrueMets[:, -1] < 0.3
hfa_sel = TrueMets[:, -1] > 0.7
ulfa_sel = TrueMets[:, -1] < 0.1
hrk_sel = TrueMets[:, 1] > 0.8

# All voxels
DT1_full, DT2_full = FitDT(DTIFilt, 1)
x4_full, R1_full, x2_full, R2_full = FitKT(DKIFilt, 1)

# Low FA (< 0.3)
DT1_lfa, DT2_lfa = FitDT(DTIFilt[lfa_sel, :], 1)
x4_lfa, R1_lfa, x2_lfa, R2_lfa = FitKT(DKIFilt[lfa_sel, :], 1)

# High FA (> 0.7)
DT1_hfa, DT2_hfa = FitDT(DTIFilt[hfa_sel, :], 1)
x4_hfa, R1_hfa, x2_hfa, R2_hfa = FitKT(DKIFilt[hfa_sel, :], 1)

# Ultra-low FA (< 0.1)
DT1_ulfa, DT2_ulfa = FitDT(DTIFilt[ulfa_sel, :], 1)
x4_ulfa, R1_ulfa, x2_ulfa, R2_ulfa = FitKT(DKIFilt[ulfa_sel, :], 1)

# Higher RK (> 0.8); replaces the > 0.99 fit made earlier under the same names.
DT1_hRK, DT2_hRK = FitDT(DTIFilt[hrk_sel, :], 1)
x4_hRK, R1_hRK, x2_hRK, R2_hRK = FitKT(DKIFilt[hrk_sel, :], 1)

In [None]:

# Draw synthetic DT/KT pairs from the low-FA, high-FA and high-RK fits, and pool them.
DT2, KT2 = GenDTKT([DT1_lfa, DT2_lfa], [x4_lfa, R1_lfa, x2_lfa, R2_lfa], 12, 500)
DT3, KT3 = GenDTKT([DT1_hfa, DT2_hfa], [x4_hfa, R1_hfa, x2_hfa, R2_hfa], 12, 500)
DT5, KT5 = GenDTKT([DT1_hRK, DT2_hRK], [x4_hRK, R1_hRK, x2_hRK, R2_hRK], 12, 1000)

DT = np.vstack([DT2, DT3, DT5])
KT = np.vstack([KT2, KT3, KT5])

# DKI metrics per generated set.
# NOTE(review): DT1/KT1 and DT4/KT4 are not generated in this cell (no GenDTKT call
# for the full or ultra-low-FA fits); they must exist from an earlier cell -- verify.
ParTest1 = np.array([DKIMetrics(d, k) for d, k in tqdm(zip(DT1, KT1))])
ParTest2 = np.array([DKIMetrics(d, k) for d, k in tqdm(zip(DT2, KT2))])
ParTest3 = np.array([DKIMetrics(d, k) for d, k in tqdm(zip(DT3, KT3))])
ParTest4 = np.array([DKIMetrics(d, k) for d, k in tqdm(zip(DT4, KT4))])
ParTest5 = np.array([DKIMetrics(d, k) for d, k in tqdm(zip(DT5, KT5))])
ParMets = [DKIMetrics(d, k) for d, k in tqdm(zip(DT, KT))]
ParTest = np.array(ParMets)

In [None]:
# Overlay simulated (ParTest) and HCP (TrueMets) histograms for the first five
# metric columns; the legend appears on the first plot only.
hist_kw = dict(density=True, range=[0, 1], bins=50)
for i in range(5):
    plt.hist(ParTest[:, i], color=SBIFit, label='Simulated', **hist_kw)
    plt.hist(TrueMets[:, i], alpha=0.8, color='gray', label='HCP', **hist_kw)
    if i == 0:
        plt.legend(fontsize=32, columnspacing=0.3, handlelength=0.4, handletextpad=0.1,
                   bbox_to_anchor=(0.7, 1), loc=1)
    plt.xticks(fontsize=32)
    plt.yticks(fontsize=32)
    if Save:
        plt.savefig(f'{FigLoc}EgMetricDKI_{i}.pdf', format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

# Fig 7

In [None]:
# HCP subject 1, b=1000 shell: diffusion volume and gradient table.
hcp_dir = '../HCP_data/Pat' + str(1)
fdwi = f'{hcp_dir}/diff_1k.nii.gz'
bvalloc = f'{hcp_dir}/bvals_1k.txt'
bvecloc = f'{hcp_dir}/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
# Reslice from voxel size (1.5, 1.5, 1.5) to (2.5, 2.5, 2.5).
data, affine = reslice(data, affine, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))
# NOTE(review): axial_middle is taken from the uncropped data but later indexes the
# autocropped maskdata, so it is not the cropped volume's middle slice -- confirm.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)

# Greedy farthest-point selection of volumes. Start from volume 1 and add 5 more,
# each maximising its minimum Euclidean distance to those already chosen.
# NOTE(review): Euclidean distance treats g and -g as far apart, though they are the
# same direction for a tensor model.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    min_distances = distance_matrix[remaining_indices][:, selected_indices].min(axis=1)
    selected_indices.append(remaining_indices[int(np.argmax(min_distances))])

# Prepend volume 0 (presumably a b=0 volume): 1 + 6 = 7 volumes in total.
selected_indices = [0] + selected_indices

bvalsHCP7 = bvalsHCP[selected_indices]
bvecsHCP7 = bvecsHCP[selected_indices]
gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)

In [None]:
# Load the cached posterior for the 7 selected HCP volumes (S0 inferred, simulated
# with a fixed noise argument of 50), or train and cache it.
# NOTE: pickle.load runs arbitrary code; only load caches this project wrote itself.
cache_file = f"{network_path}/DTIHCPMin_Denoise.pickle"
if os.path.exists(cache_file):
    with open(cache_file, "rb") as handle:
        posterior7_2 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    # These rebind the global bvals/bvecs; they are not used further in this cell.
    bvals = gtabHCP.bvals
    bvecs = gtabHCP.bvecs
    Obs, Par = [], []
    for i in tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        s0_draw = params[-1]
        dt = ForceLowFA(ComputeDTI(params[:-1]))
        Obs.append(CustomSimulator(dt, gtabHCP7, s0_draw, 50))
        Par.append(np.hstack([mat_to_vals(dt), s0_draw]))

    Obs = torch.tensor(np.array(Obs)).float()
    Par = torch.tensor(np.array(Par)).float()

    inference = SNPE(prior=priorS0).append_simulations(Par, Obs)
    density_estimator = inference.train(stop_after_epochs=100)
    posterior7_2 = inference.build_posterior(density_estimator)
    if not os.path.exists(cache_file):
        with open(cache_file, "wb") as handle:
            pickle.dump(posterior7_2, handle)

In [None]:
NoiseLevels = list(range(10, 60, 10))  # noise levels 10, 20, 30, 40, 50 for the denoising study

In [None]:
# Noise-free reference: middle axial slice at volume selected_indices[3], with the
# background set to NaN.
vol = selected_indices[3]
NoisyImgPlot = np.copy(maskdata[:, :, axial_middle])
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(NoisyImgPlot[..., vol].T, cmap='gray')
plt.axis('off')
if Save:
    plt.savefig(f'{image_path}Noise_base.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

In [None]:
# One noisy copy of the middle axial slice per noise level. Only voxels with a
# positive first volume get noise; the rest stay 0. The RNG is reseeded per level,
# so every level starts from the same random stream.
NoisyImgs = []
for N in NoiseLevels:
    np.random.seed(15)
    NoisyImg = np.zeros_like(maskdata[:, :, axial_middle])
    # Use the slice's real extent rather than a hard-coded 55x64 grid. Otherwise a
    # different crop either leaves voxels un-noised or indexes out of bounds.
    n_x, n_y = NoisyImg.shape[:2]
    for i in range(n_x):
        for j in range(n_y):
            if maskdata[i, j, axial_middle, 0] > 0:
                NoisyImg[i, j] = AddNoise(maskdata[i, j, axial_middle], maskdata[i, j, axial_middle, 0], N)
            else:
                NoisyImg[i, j] = 0
    NoisyImgs.append(NoisyImg)
    NoisyImgPlot = np.copy(NoisyImg)
    NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
    plt.imshow(NoisyImgPlot[:, :, selected_indices[3]].T, cmap='gray')
    plt.axis('off')
    if Save:
        plt.savefig(image_path + 'Noise_' + str(N) + '.pdf', format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

In [None]:
# For each noise level, build a noisy 3-slice slab (middle axial slice and its two
# neighbours) and denoise its selected volumes with MP-PCA. patch_radius=1 needs
# the neighbouring slices to form 3x3x3 patches.
Denoised_Arr = []
for N in NoiseLevels:
    # Seed per level, as for NoisyImgs, so the noise and MP-PCA output are
    # reproducible between runs. An unseeded RNG makes every run differ.
    np.random.seed(15)
    NoisyImg4D = np.zeros_like(maskdata[:, :, :3])
    # Real in-plane extent instead of a hard-coded 55x64 grid.
    n_x, n_y = NoisyImg4D.shape[:2]
    for i in range(n_x):
        for j in range(n_y):
            for ll, k in enumerate([axial_middle - 1, axial_middle, axial_middle + 1]):
                if maskdata[i, j, k, 0] > 0:
                    NoisyImg4D[i, j, ll] = AddNoise(maskdata[i, j, k], maskdata[i, j, k, 0], N)
                else:
                    NoisyImg4D[i, j, ll] = 0
    Denoised_Arr.append(mppca(NoisyImg4D[..., selected_indices], patch_radius=1, return_sigma=False))

In [None]:
# SBI "denoising": for each noisy image, infer the posterior-mean tensor + S0 per
# in-mask voxel from the 7 selected volumes, then re-simulate a signal from it with
# CustomSimulator on gtabHCP. One re-simulated image per noise level goes into sbi_Arr.
sbi_Arr = []
for Imgs in NoisyImgs:
    # Compute the mask where the sum is not zero
    # (the mask is the same every iteration; it depends only on maskdata)
    masks = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
    
    # Get the indices where mask is True
    indices = np.argwhere(masks)
    
    # Define the function for optimization
    # (redefined every iteration; the closure reads the current Imgs when joblib runs it
    # inside this iteration)
    def optimize_pixel(i, j):
        torch.manual_seed(10)  # If required
        posterior_samples_1 = posterior7_2.sample((1000,), x=Imgs[i,j,selected_indices],show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = masks.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )

    # 7 values per voxel: 6 tensor components + S0 (slot -1)
    NoiseEst = np.zeros(list(ArrShape) + [7])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x
    
    # NOTE(review): slots -2/-3 are tensor components here (they are read back as
    # NoiseEst2[..., :6]); the [0, 100] / [0, 1] clip bounds look inherited from another
    # parameterisation -- confirm against mat_to_vals' ordering.
    for i, j, x in results:
        NoiseEst[i, j,-2] = np.clip(NoiseEst[i, j,-2],0,100)
        NoiseEst[i, j,-3] = np.clip(NoiseEst[i, j,-3],0,1)
    # Tensor cleaned by clip_negative_eigenvalues, S0 kept in the last slot.
    # NOTE(review): the 55x64 grid is hard-coded; it must match ArrShape.
    NoiseEst2 =  np.zeros_like(NoiseEst)
    for i in range(55):
        for j in range(64):    
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
    # MD/FA maps are computed but not used in this cell; note that this overwrites the
    # global FA_SBI7 computed for the RandomDir_HCP inference.
    MD_SBI7 = np.zeros([55,64])
    FA_SBI7 = np.zeros([55,64])
    for i in range(55):
        for j in range(64):
            Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
            MD_SBI7[i,j] = np.mean(Eigs)
            FA_SBI7[i,j] = FracAni(Eigs,np.mean(Eigs))
    FA_SBI7[np.isnan(FA_SBI7)] = 0
    # Re-simulated signal on gtabHCP from each voxel's tensor and S0 (no snr argument is
    # passed). 69 is presumably the number of gtabHCP volumes -- TODO confirm.
    NoiseEst3_7 =  np.zeros((55,64,69))
    for i in range(55):
        for j in range(64):    
            NoiseEst3_7[i,j] = CustomSimulator(vals_to_mat(NoiseEst2[i,j,:-1]),gtabHCP, S0=NoiseEst2[i,j,-1])
    sbi_Arr.append(NoiseEst3_7)

In [None]:
# MP-PCA result for the middle slice (slab index 1) at the first and last noise
# levels. The image is the selected volume at position j, shown as 1 - signal.
# Saving honours the Save flag like every other figure cell.
j = 3
NoisyImgPlot = np.copy(Denoised_Arr[0][:, :, 1, j])
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(1 - NoisyImgPlot.T, cmap='Purples')
plt.axis('off')
if Save:
    plt.savefig(image_path + 'pca_denoise_10.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()
NoisyImgPlot = np.copy(Denoised_Arr[-1][:, :, 1, j])
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(1 - NoisyImgPlot.T, cmap='Purples')
plt.axis('off')
if Save:
    plt.savefig(image_path + 'pca_denoise_50.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
from matplotlib.colors import LinearSegmentedColormap

In [None]:
# Two light-to-dark sequential colormaps built from eight hex stops each and
# interpolated to 256 levels.
teal_stops = [
    "#eaf7f7",  # almost white with a hint of teal
    "#cceeee",
    "#a6e0e0",
    "#7acccc",
    "#4db8b8",
    "#249f9f",
    "#007f7f",
    "#005f5f",  # deep teal
]
gold_stops = [
    "#fff9e6",  # very light creamy gold
    "#fff0c2",
    "#ffe59e",
    "#ffda70",
    "#ffca38",
    "#ffba00",
    "#d4a700",
    "#a88000",  # deep gold/brown
]
teals = LinearSegmentedColormap.from_list("Teals", teal_stops, N=256)
golds = LinearSegmentedColormap.from_list("Golds", gold_stops, N=256)

In [None]:
# SBI re-simulated signal at volume selected_indices[j] for the first and last noise
# levels, shown as 1 - signal on the teal colormap. Saving honours the Save flag.
j = 3
NoisyImgPlot = np.copy(sbi_Arr[0][:, :, selected_indices[j]])
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(1 - NoisyImgPlot.T, cmap=teals)
plt.axis('off')
if Save:
    plt.savefig(image_path + 'sbi_denoise_10.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()
NoisyImgPlot = np.copy(sbi_Arr[-1][:, :, selected_indices[j]])
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(1 - NoisyImgPlot.T, cmap=teals)
plt.axis('off')
if Save:
    plt.savefig(image_path + 'sbi_denoise_50.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# NLLS fit on the 7 selected noisy volumes, used to predict the full gtabHCP signal.
# The predicted volume selected_indices[j] is shown as 1 - signal on the gold
# colormap. The teal image drawn beneath it was fully covered by the gold one, so
# only the gold image is drawn. Saving honours the Save flag.
j = 3
tenmodel = dti.TensorModel(gtabHCP7, return_S0_hat=True, fit_method='NLLS')
tenfit = tenmodel.fit(NoisyImgs[0][..., selected_indices])
NoisyImgPlot = tenfit.predict(gtabHCP)[..., selected_indices[j]]
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(1 - NoisyImgPlot.T, cmap=golds)
plt.axis('off')
if Save:
    plt.savefig(image_path + 'nlls_denoise_10.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()
tenfit = tenmodel.fit(NoisyImgs[-1][..., selected_indices])
NoisyImgPlot = tenfit.predict(gtabHCP)[..., selected_indices[j]]
NoisyImgPlot[~mask[:, :, axial_middle]] = math.nan
plt.imshow(1 - NoisyImgPlot.T, cmap=golds)
plt.axis('off')
if Save:
    plt.savefig(image_path + 'nlls_denoise_50.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# SSIM of each method's output against the noise-free middle slice, for every noise
# level k and selected volume j. data_range is fixed at 600 for all methods.
# Relies on `tenmodel` being the gtabHCP7 NLLS model set up for the NLLS-denoising plots.
pca_ssim_Arr = []
sbi_ssim_Arr = []
nlls_ssim_Arr = []
for k in range(len(NoisyImgs)):
    pca_ssim = []
    sbi_ssim = []
    nlls_ssim = []
    # One NLLS fit and one prediction per noise level, reused for every volume.
    tenfit = tenmodel.fit(NoisyImgs[k][..., selected_indices])
    predicted = tenfit.predict(gtabHCP)
    for j in range(len(selected_indices)):
        reference = maskdata[:, :, axial_middle, selected_indices[j]]
        pca_ssim.append(ssim(reference, Denoised_Arr[k][:, :, 1, j], data_range=600))
        sbi_ssim.append(ssim(reference, sbi_Arr[k][:, :, selected_indices[j]], data_range=600))
        NoiseN = predicted[..., selected_indices[j]]
        nlls_ssim.append(ssim(reference, NoiseN, data_range=600))
    pca_ssim_Arr.append(pca_ssim)
    sbi_ssim_Arr.append(sbi_ssim)
    nlls_ssim_Arr.append(nlls_ssim)

In [None]:
# Mean SSIM (over the 7 selected volumes) versus noise level for each method.
# Error bars are each method's own standard error across volumes at each noise
# level. Before, all three used a single scalar SEM of the MP-PCA means taken
# across noise levels.
noise_x = [1, 1.1, 1.2, 1.3, 1.4]
for scores, colour, label in (
    (np.array(pca_ssim_Arr), 'tab:purple', 'MP-PCA'),
    (np.array(sbi_ssim_Arr), SBIFit, 'SBI'),
    (np.array(nlls_ssim_Arr), NLLSFit, 'NLLS'),
):
    plt.errorbar(noise_x, scores.mean(axis=-1), yerr=stats.sem(scores, axis=-1),
                 lw=3, c=colour, label=label)
plt.grid()
plt.yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
plt.xticks(noise_x, [10, 20, 30, 40, 50])

plt.legend(loc=4,
           bbox_to_anchor=(1.05, -0.1),
           fontsize=32,
           columnspacing=0.3,
           handlelength=0.4,
           handletextpad=0.3,
           labelspacing=0.3,
           ncols=2)
if Save:
    plt.savefig(image_path + 'ssim_denoise.pdf', format='pdf', bbox_inches='tight', transparent=True)

# Fig 8

In [None]:
# Output folder for this figure's panels; created (with parents) if missing.
FigLoc = image_path + 'Fig_S6/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Load subject 3's b=1000 and b=3000 shells, resample both, brain-mask the
# b=1000 data, crop both to the mask's bounding box and build the test arrays.
i=3
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined protocol: all b=1000 volumes first, then all b=3000 volumes.
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

# Resample from (1.5,1.5,1.5) to (2.5,2.5,2.5) voxel sizes.
data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# Brain mask from volumes 10..49; `mask` keeps the full (uncropped) volume shape.
# NOTE(review): median_radius=4 here, while the per-subject loops use 3 — confirm.
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=False, dilate=2)
# Autocropped variant; mask2 is not used within this cell.
_, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=True, dilate=2)


data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# Get the indices of True values
true_indices = np.argwhere(mask)

# Determine the minimum and maximum indices along each dimension
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)

# Crop to the mask's bounding box. axial_middle indexes the CROPPED z axis.
# The b=3000 data is cropped but not multiplied by the mask.
maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
axial_middle = maskdata.shape[2] // 2
maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

# Middle axial slice of both shells (TestData) and the full cropped 4-D stack.
TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

In [None]:
# Build the reduced ("Min") protocol for subject 3: volume 0 plus 6 b=1000
# directions and 15 b=3000 directions, each chosen by greedy farthest-point
# sampling on the b-vectors.
i=3
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# Seed the selection with index 1 (arbitrary first direction).
selected_indices = [1]
# NOTE(review): any b0 (zero) rows of bvecsHCP are candidates here too; the
# b=3000 selection below filters them out first.
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):  # seed + 5 picks = 6 directions (7 volumes with volume 0)
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    # Distance from each remaining direction to its nearest selected one ...
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    # ... keep the direction for which that distance is largest.
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

# Prepend volume 0 (presumably a b0 — TODO confirm against bvals_1k.txt).
selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Built from the b=3000 arrays. This previously used bvalsHCP/bvecsHCP, which
# rebuilt the b=1000 table and overwrote the correct gtabHCP3.
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Farthest-point sampling on the non-zero b=3000 directions, seeded with 0.
selected_indices = [0]
temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(14):  # seed + 14 picks = 15 directions
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

# Map each chosen b=3000 direction back to its row in the 3k file, then shift
# by 69 into the combined (1k volumes first) volume list.
# NOTE(review): 69 is assumed to be the number of 1k volumes — confirm.
true_indx = []
for b in bvecsHCP7_3:
    true_indx.append(np.linalg.norm(b-bvecsHCP3,axis=1).argmin())
true_indx = selected_indices7+[t+69 for t in true_indx]
gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

In [None]:
# Load the SBI posterior for the reduced DKI protocol (gtabHCP7). If no pickle
# exists, simulate training pairs and train it with SNPE.
np.random.seed(1)
torch.manual_seed(1)
posterior_file = f"{network_path}/DKIHCPMin.pickle"
if os.path.exists(posterior_file):
    # NOTE: unpickling can execute code; only load pickles this notebook wrote.
    with open(posterior_file, "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    # Three DT/KT draws of sizes 3*13000, 3*13000 and 3*26000, which sum to
    # 3*52000 (the S0 count below).
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(3*13000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(3*13000))
    DT5,KT5 = GenDTKT([DT1_hRK,DT2_hRK],[x4_hRK,R1_hRK,x2_hRK,R2_hRK],12,int(3*26000))

    DT = np.vstack([DT2,DT3,DT5])
    KT = np.vstack([KT2,KT3,KT5])

    # One S0 per sample, uniform on [lower_S0, upper_S0].
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([3*52000])
    S0 = np.array(S0).reshape(len(S0),1)

    # Simulate every sample. Any sample whose signal is negative or exceeds
    # 4*S0 has its KT halved and its DT redrawn, and is re-simulated until
    # none remain.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabHCP7.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm(indx):
            # Last argument is drawn uniformly from [30, 50) (presumably the
            # SNR — TODO confirm against CustoRKKISimulator).
            Obs[i] = CustoRKKISimulator(DT[i],KT[i],gtabHCP7,S0[i],np.random.rand()*20 + 30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): one DT draw is broadcast to every rejected row in this
        # pass (kk presumably acts as a seed) — confirm intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    # Save to the same path that is loaded above. It was previously written to
    # DKIHCPMinA2.pickle, so the network was retrained on every run.
    with open(posterior_file, "wb") as handle:
        pickle.dump(posteriorFull, handle)


In [None]:
# Pixels of the middle axial slice with any non-zero signal. This uses its own
# name so the 3-D brain `mask` from the loading cell is not overwritten; the
# cutout cell indexes `mask` as a 3-D volume.
slice_mask = np.sum(TestData4D[:,:,axial_middle,:], axis=-1) != 0

# (row, col) of every pixel to infer
indices = np.argwhere(slice_mask)

def optimize_pixel(i, j):
    """Return ``(i, j, posterior_mean)`` for pixel (i, j) of the middle slice.

    No optimisation is performed: 500 samples are drawn from ``posteriorFull``
    conditioned on the pixel's signal at the ``true_indx`` volumes, then averaged.
    """
    torch.manual_seed(10)  # same seed for every pixel, independent of worker order
    posterior_samples_1 = posteriorFull.sample((500,), x=TestData4D[i,j,axial_middle, true_indx],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

ArrShape = slice_mask.shape

# Run the per-pixel inference in 8 worker processes.
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)

In [None]:
# Posterior-mean map for the middle slice: 6 DT + 15 KT + 1 S0 parameters per
# pixel (the order of the training Par = hstack([DT, KT, S0])). Pixels that
# were not inferred stay zero. Sized from the data instead of a hard-coded 62x68.
NoiseEst7 = np.zeros(list(TestData4D.shape[:2]) + [6 + 15 + 1])
for i, j, x in results:
    NoiseEst7[i, j] = x

In [None]:
# Same parameters with the diffusion tensor's negative eigenvalues clipped;
# KT (and S0) entries are copied unchanged. Loops over the map's actual size.
NoiseEst2 =  np.zeros_like(NoiseEst7)
for i in range(NoiseEst7.shape[0]):
    for j in range(NoiseEst7.shape[1]):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst7[i,j]))),NoiseEst7[i,j,6:]])

In [None]:
# DKI scalar maps (DKIMetrics outputs 0..4) from the clipped SBI estimates,
# sized from NoiseEst2 instead of a hard-coded 62x68.
map_shape = NoiseEst2.shape[:2]
MK_SBI7  = np.zeros(map_shape)
AK_SBI7  = np.zeros(map_shape)
RK_SBI7  = np.zeros(map_shape)
MKT_SBI7 = np.zeros(map_shape)
KFA_SBI7 = np.zeros(map_shape)
for i in tqdm(range(map_shape[0])):
    for j in range(map_shape[1]):
        # first 6 entries: DT; next 15: KT
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        MK_SBI7[i,j] = Metrics[0]
        AK_SBI7[i,j] = Metrics[1]
        RK_SBI7[i,j] = Metrics[2]
        MKT_SBI7[i,j] = Metrics[3]
        KFA_SBI7[i,j] = Metrics[4]

In [None]:
# Overwrite NaN KFA entries with 1, in place (the top of the [0, 1] colour
# range used when this map is plotted).
np.copyto(KFA_SBI7, 1, where=np.isnan(KFA_SBI7))

In [None]:
# Brain-mask slice matching the cropped SBI maps. Crop x/y to the bounding box
# and offset the axial index: `axial_middle` counts from min_coords[2] in the
# cropped volume, while `mask` spans the full volume (requires the 3-D
# median_otsu mask).
cutout = mask[min_coords[0]:max_coords[0]+1, min_coords[1]:max_coords[1]+1, min_coords[2]+axial_middle]

In [None]:
# RK map from SBI on the reduced protocol; pixels outside the brain cutout are
# set to NaN so imshow leaves them blank.
temp = np.where(cutout, RK_SBI7, math.nan)
plt.imshow(temp.T, cmap='hot', vmin=0, vmax=1)
plt.axis('off')
cbar = plt.colorbar(fraction=0.032, pad=0.04)
cbar.ax.set_ylim(0, 1)
if Save:
    plt.savefig('../Figures/Fig_3/RKSBI7.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# KFA map from SBI on the reduced protocol; pixels outside the brain cutout are
# set to NaN so imshow leaves them blank.
temp = np.where(cutout, KFA_SBI7, math.nan)
plt.imshow(temp.T, cmap='hot', vmin=0, vmax=1)
plt.axis('off')
cbar = plt.colorbar(fraction=0.032, pad=0.04)
cbar.ax.set_ylim(0, 1)
if Save:
    plt.savefig('../Figures/Fig_3/KFASBI7.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Folder holding the pre-computed metric maps (.npy files) loaded next.
DatFolder = './SavedDat/' 

In [None]:
# Pre-computed SBI metric maps, indexed per subject, for the Min / Mid / Full
# protocols. The MKT files are bound to the MKT*Arr names that the accuracy,
# SSIM and precision cells read; previously they were loaded only as RKT*Arr,
# which are kept as aliases for any code using the old names.
# allow_pickle=True: only load these trusted, locally generated files.
MKTMinArr = RKTMinArr = np.load(DatFolder+'Min_MKT_HCP.npy',allow_pickle=True)
MKTMidArr = RKTMidArr = np.load(DatFolder+'Mid_MKT_HCP.npy',allow_pickle=True)
MKTFullArr = RKTFullArr = np.load(DatFolder+'Full_MKT_HCP.npy',allow_pickle=True)

KFAMinArr = np.load(DatFolder+'Min_KFA_HCP.npy',allow_pickle=True)
KFAMidArr = np.load(DatFolder+'Mid_KFA_HCP.npy',allow_pickle=True)
KFAFullArr = np.load(DatFolder+'Full_KFA_HCP.npy',allow_pickle=True)

In [None]:
# Reduced ("Min") protocol indices from subject 1's gradient files: volume 0
# plus 6 b=1000 and 15 b=3000 directions, chosen by greedy farthest-point
# sampling. selected_indices7 ends up indexing the combined 1k+3k volume list.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# Seed with index 1 (arbitrary first direction).
# NOTE(review): any b0 (zero) rows of bvecsHCP are candidates here too.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):  # seed + 5 picks = 6 directions (7 volumes with volume 0)
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    # Distance from each remaining direction to its nearest selected one;
    # keep the direction for which that distance is largest.
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Built from the b=3000 arrays (was gradient_table(bvalsHCP, bvecsHCP), which
# silently produced the b=1000 table).
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Farthest-point sampling on the non-zero b=3000 directions, seeded with 0.
selected_indices = [0]
temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(14):  # seed + 14 picks = 15 directions
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Row of each chosen b=3000 direction in the 3k file, shifted by 69 into the
# combined (1k volumes first) list.
# NOTE(review): 69 is assumed to be the number of 1k volumes — confirm.
true_indx = []
for b in bvecsHCP7_3:
    true_indx.append(np.linalg.norm(b-bvecsHCP3,axis=1).argmin())
selected_indices7 = selected_indices7+[t+69 for t in true_indx]



In [None]:
# Intermediate ("Mid") protocol indices from subject 1's gradient files: a b0
# plus 19 b=1000 directions and 28 b=3000 directions, chosen by greedy
# farthest-point sampling. selected_indices20 indexes the combined 1k+3k list.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# Farthest-point sampling on the non-zero b=1000 directions, seeded with 1.
selected_indices = [1]
temp_bvecs = bvecsHCP[bvalsHCP>0]
temp_bvals = bvalsHCP[bvalsHCP>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(18):  # seed + 18 picks = 19 directions (20 volumes with the b0)
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    # Distance from each remaining direction to its nearest selected one;
    # keep the direction for which that distance is largest.
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

temp = selected_indices

# Prepend a b=0 entry with a zero b-vector.
bvalsHCP7_1 = np.insert(temp_bvals[temp],0,0)
bvecsHCP7_1 = np.insert(temp_bvecs[temp],0,[0,0,0],axis=0)

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Built from the b=3000 arrays (was gradient_table(bvalsHCP, bvecsHCP)).
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Farthest-point sampling on the non-zero b=3000 directions, seeded with 0.
selected_indices = [0]
temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(27):  # seed + 27 picks = 28 directions
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP20 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Map the chosen vectors back to file rows: b=1000 rows index directly, and
# b=3000 rows are shifted by 69 into the combined (1k volumes first) list.
# The zero vector maps to the first zero row of bvecsHCP.
# NOTE(review): 69 is assumed to be the number of 1k volumes — confirm.
true_indx_one = []
for b in bvecsHCP7_1:
    true_indx_one.append(np.linalg.norm(b-bvecsHCP,axis=1).argmin())
true_indx = []
for b in bvecsHCP7_3:
    true_indx.append(np.linalg.norm(b-bvecsHCP3,axis=1).argmin())
selected_indices20 = true_indx_one+[t+69 for t in true_indx]

In [None]:
# Per-subject gradient tables for all 32 subjects: the full 1k+3k protocol
# (gTabsF) and the subsets picked by selected_indices7 (gTabs7) and
# selected_indices20 (gTabs20).
gTabsF, gTabs7, gTabs20 = [], [], []

FullDat   = []

for i in tqdm(range(1,33)):
    subject_prefix = f'./HCP_data/Pat{i}/'
    fdwi = subject_prefix + 'diff_1k.nii.gz'
    bvalloc = subject_prefix + 'bvals_1k.txt'
    bvecloc = subject_prefix + 'bvecs_1k.txt'

    fdwi3 = subject_prefix + 'diff_3k.nii.gz'
    bvalloc3 = subject_prefix + 'bvals_3k.txt'
    bvecloc3 = subject_prefix + 'bvecs_3k.txt'

    bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Combined table: b=1000 volumes first, then b=3000.
    gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)), np.vstack((bvecsHCP, bvecsHCP3)))
    gTabsF.append(gtabExt)

    # Subset tables take their b-values/b-vectors from the combined table.
    bvalsHCP7, bvecsHCP7 = gtabExt.bvals[selected_indices7], gtabExt.bvecs[selected_indices7]
    gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)
    gTabs7.append(gtabHCP7)

    bvalsHCP20, bvecsHCP20 = gtabExt.bvals[selected_indices20], gtabExt.bvecs[selected_indices20]
    gtabHCP20 = gradient_table(bvalsHCP20, bvecsHCP20)
    gTabs20.append(gtabHCP20)

In [None]:
# Load all 32 subjects, crop each to its brain bounding box and collect:
#   TD            combined 1k+3k 4-D data (1k volumes first)
#   axial_middles middle axial index of the CROPPED volume
#   masks         brain-mask slice at that position, cropped in x/y
#   WMs           white-matter slice (c2 value > 0.8), cropped in x/y like masks
TD = []
axial_middles = []
masks = []
WMs = []
for kk in tqdm(range(32)):
    fdwi = './HCP_data/Pat'+str(kk+1)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk+1)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk+1)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk+1)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk+1)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk+1)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    # `mask` and the WM volume are not cropped in z, so the slice matching
    # maskdata[:, :, axial_middle] sits at min_coords[2] + axial_middle there.
    z_full = min_coords[2] + axial_middle
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middles.append(axial_middle)
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    TD.append(TestData4D)
    masks.append(mask[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,z_full])
    # The WM slice is now cropped like `masks` (it was previously left at full
    # size, so its shape did not match the metric maps it indexes).
    # NOTE(review): assumes the flipped WM volume is on the same resliced grid
    # as `data`, and that flipping after cropping is correct — confirm.
    WM, affine, img = load_nifti('./flipped/c2Pat'+str(kk+1)+'_FP.nii', return_img=True)
    WMs.append(np.fliplr(WM[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,z_full]>0.8))

In [None]:
# Load all 32 subjects, crop each to its brain bounding box and collect:
#   TD            combined 1k+3k 4-D data (1k volumes first)
#   axial_middles middle axial index of the CROPPED volume
#   masks         brain-mask slice at that position, cropped in x/y
#   WMs           white-matter slice (c2 value > 0.8), cropped in x/y
TD = []
axial_middles = []
masks = []
WMs = []
for kk in tqdm(range(32)):
    fdwi = './HCP_data/Pat'+str(kk+1)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk+1)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk+1)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk+1)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk+1)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk+1)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    # `mask` and the WM volume are not cropped in z, so the slice matching
    # maskdata[:, :, axial_middle] sits at min_coords[2] + axial_middle there.
    z_full = min_coords[2] + axial_middle
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middles.append(axial_middle)
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    TD.append(TestData4D)
    masks.append(mask[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,z_full])
    # NOTE(review): assumes the flipped WM volume is on the same resliced grid
    # as `data`, and that flipping after cropping is correct — confirm.
    WM, affine, img = load_nifti('./flipped/c2Pat'+str(kk+1)+'_FP.nii', return_img=True)
    WMs.append(np.fliplr(WM[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,z_full]>0.8))
    print(WMs[-1].shape)

In [None]:
# NLLS DKI fit on the full 1k+3k protocol for each subject's middle axial
# slice, then DKI scalar maps (DKIMetrics outputs 0..4) after clipping the
# diffusion tensor's negative eigenvalues.
MKFullNLArr = []
AKFullNLArr = []
RKFullNLArr = []
MKTFullNLArr = []
KFAFullNLArr = []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk]])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # Per voxel: 6 lower-triangular DT elements followed by the 15 KT elements,
    # taken from the whole fit at once instead of voxel by voxel.
    NoiseEst_NL = np.concatenate([dkifitNL.lower_triangular(), dkifitNL.kt], axis=-1)
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    # Each list gets its own map. Previously the MK/AK lists received RK_NL7,
    # and the MKT list appended an undefined RKT_NL7.
    MKFullNLArr.append(MK_NL7)
    AKFullNLArr.append(AK_NL7)
    RKFullNLArr.append(RK_NL7)
    MKTFullNLArr.append(MKT_NL7)
    KFAFullNLArr.append(KFA_NL7)

In [None]:
# NLLS DKI fit on the Mid protocol (selected_indices20 volumes) for each
# subject's middle axial slice, then DKI scalar maps (DKIMetrics outputs 0..4)
# after clipping the diffusion tensor's negative eigenvalues.
MKMidNLArr = []
AKMidNLArr = []
RKMidNLArr = []
MKTMidNLArr = []
KFAMidNLArr = []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs20[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices20])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # Per voxel: 6 lower-triangular DT elements followed by the 15 KT elements.
    NoiseEst_NL = np.concatenate([dkifitNL.lower_triangular(), dkifitNL.kt], axis=-1)
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    # Each list gets its own map (previously MK/AK got RK_NL7 and MKT used an
    # undefined RKT_NL7).
    MKMidNLArr.append(MK_NL7)
    AKMidNLArr.append(AK_NL7)
    RKMidNLArr.append(RK_NL7)
    MKTMidNLArr.append(MKT_NL7)
    KFAMidNLArr.append(KFA_NL7)

In [None]:
# NLLS DKI fit on the Min protocol (selected_indices7 volumes) for each
# subject's middle axial slice, then DKI scalar maps (DKIMetrics outputs 0..4)
# after clipping the diffusion tensor's negative eigenvalues.
MKMinNLArr = []
AKMinNLArr = []
RKMinNLArr = []
MKTMinNLArr = []
KFAMinNLArr = []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices7])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # Per voxel: 6 lower-triangular DT elements followed by the 15 KT elements.
    NoiseEst_NL = np.concatenate([dkifitNL.lower_triangular(), dkifitNL.kt], axis=-1)
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    # Each list gets its own map (previously MK/AK got RK_NL7 and MKT used an
    # undefined RKT_NL7).
    MKMinNLArr.append(MK_NL7)
    AKMinNLArr.append(AK_NL7)
    RKMinNLArr.append(RK_NL7)
    MKTMinNLArr.append(MKT_NL7)
    KFAMinNLArr.append(KFA_NL7)

In [None]:
# MKT accuracy: mean absolute difference inside each subject's brain-mask slice.
#   AccM7 / AccM20     : SBI Min / Mid vs SBI Full
#   AccMFulls          : SBI Full vs NLLS Full
#   AccM7NL / AccM20NL : NLLS Min / Mid vs NLLS Full
# np.nanmean is used throughout (as AccM20NL and the KFA cell already did), so
# a single NaN voxel no longer turns a subject's score into NaN.
AccM7 = []
for i in range(32):
    M7 =MKTMinArr[i]
    MF =MKTFullArr[i]
    Ma = masks[i]
    AccM7.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccM20 = []
for i in range(32):
    M7 =MKTMidArr[i]
    MF =MKTFullArr[i]
    Ma = masks[i]
    AccM20.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccMFulls = []
for i in range(32):
    M7 =MKTFullArr[i]
    MF =MKTFullNLArr[i]
    Ma = masks[i]
    AccMFulls.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccM7NL = []
for i in range(32):
    M7 =MKTMinNLArr[i]
    # In place: NaNs in MKTMinNLArr[i] become 0 and stay 0 for the SSIM below.
    M7[np.isnan(M7)] = 0
    MF =MKTFullNLArr[i]
    Ma = masks[i]
    AccM7NL.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccM20NL = []
for i in range(32):
    M7 =MKTMidNLArr[i]
    # In place, as above, for MKTMidNLArr[i].
    M7[np.isnan(M7)] = 0
    MF =MKTFullNLArr[i]
    Ma = masks[i]
    AccM20NL.append(np.nanmean(np.abs(M7-MF)[Ma]))

# Masked local SSIM (win_size=7, dat_range=1) between each map and its
# Full-protocol reference. Reduced-protocol maps are smoothed first with a
# sigma=0.5 Gaussian.
SSIM7 = []
SSIM20 = []
SSIMFulls = []

SSIM7NL = []
SSIM20NL = []
for i in tqdm(range(32)):
    NS1 =MKTMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKTFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7.append(result)

    NS1 =MKTMidArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKTFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20.append(result)

    NS1 =MKTFullArr[i]
    NS2 =MKTFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIMFulls.append(result)

    NS1 =MKTMinNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKTFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7NL.append(result)

    NS1 =MKTMidNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =MKTFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20NL.append(result)





In [None]:
# MKT accuracy figure: per-subject mean absolute error as box plots.
# Black: SBI Full vs NLLS Full; teal: SBI Mid/Min vs SBI Full;
# orange: NLLS Mid/Min vs NLLS Full.
fig, ax = plt.subplots(1,1, figsize=(3.2,4.8))
fig.subplots_adjust(hspace=0.05)

# (values, x position, first colour list, second colour list) for BoxPlots,
# drawn in the original order.
for y_data, g_pos, colors, colors2 in [
    (np.array(AccM7NL),   np.array([3.05]), ['burlywood'],       ['peachpuff']),
    (np.array(AccMFulls), np.array([0.8]),  ['black'],           ['gray']),
    (np.array(AccM20),    np.array([1.55]), ['lightseagreen'],   ['paleturquoise']),
    (np.array(AccM7),     np.array([1.95]), ['mediumturquoise'], ['paleturquoise']),
    (np.array(AccM20NL),  np.array([2.65]), ['sandybrown'],      ['peachpuff']),
]:
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

plt.yticks(fontsize=24)
ax.set_ylim(0,0.6)
# Tick positions match the boxes. An earlier xticks call with mismatched
# positions [1,1.7,2,2.8,3.1] was always overridden by this one and is removed.
plt.xticks([0.8,1.55,1.95,2.65,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax.set_xlim(0.5,3.5)

if Save: plt.savefig(FigLoc+'DKIHCP_Acc_MKT.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# MKT SSIM figure: per-subject masked SSIM as box plots, with the same groups
# and colours as the accuracy figure and a dashed horizontal line at 0.66.
fig, ax = plt.subplots(figsize=(3.2,4.8))
plt.sca(ax)

# (values, x position, colour lists, extra BoxPlots kwargs), drawn in the
# original order. A trailing y_data/g_pos reassignment that was never plotted
# has been removed.
for y_data, g_pos, colors, colors2, extra in [
    (np.array(SSIMFulls), np.array([0.8]),  ['black'],           ['gray'],          {}),
    (np.array(SSIM20),    np.array([1.5]),  ['lightseagreen'],   ['paleturquoise'], {'scatter_alpha': 0.5}),
    (np.array(SSIM7),     np.array([1.95]), ['mediumturquoise'], ['paleturquoise'], {}),
    (np.array(SSIM20NL),  np.array([2.6]),  ['sandybrown'],      ['peachpuff'],     {}),
    (np.array(SSIM7NL),   np.array([3.05]), ['burlywood'],       ['peachpuff'],     {}),
]:
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True,**extra)

plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8,1.5,1.95,2.6,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

# NOTE(review): leg_patch3 is not defined in this cell; it must come from an
# earlier cell.
ax.legend(
    handles=[leg_patch3],
    loc='lower left',
    frameon=False, ncols=1,
    fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,bbox_to_anchor= (-0.1,-0.05))

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_RKT.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Precision: per-subject standard deviation of each MKT map within the
# white-matter mask, for SBI and NLLS on the Min / Mid / Full protocols.
subjects = range(32)
Prec7_SBI = [np.std(MKTMinArr[s][WMs[s]]) for s in subjects]
Prec20_SBI = [np.std(MKTMidArr[s][WMs[s]]) for s in subjects]
PrecFull_SBI = [np.std(MKTFullArr[s][WMs[s]]) for s in subjects]

Prec7_NLLS = [np.std(MKTMinNLArr[s][WMs[s]]) for s in subjects]
Prec20_NLLS = [np.std(MKTMidNLArr[s][WMs[s]]) for s in subjects]
PrecFull_NLLS = [np.std(MKTFullNLArr[s][WMs[s]]) for s in subjects]


In [None]:
# Precision figure: per-subject white-matter std of MKT as box plots, SBI
# (teal) and NLLS (orange) for the Full / Mid / Min protocols. A hatched band
# across each method's three boxes marks the interquartile range of that
# method's Full-protocol values.
fig,ax1 = plt.subplots(1,1,figsize=(3.2,4.8))

sbi_colours = (['mediumturquoise'], ['paleturquoise'])
nlls_colours = (['sandybrown'], ['peachpuff'])
for y_data, g_pos, (colors, colors2) in [
    (np.array(PrecFull_SBI),  np.array([0.65]), sbi_colours),
    (np.array(Prec20_SBI),    np.array([1.0]),  sbi_colours),
    (np.array(Prec7_SBI),     np.array([1.35]), sbi_colours),
    (np.array(PrecFull_NLLS), np.array([1.8]),  nlls_colours),
    (np.array(Prec20_NLLS),   np.array([2.15]), nlls_colours),
    (np.array(Prec7_NLLS),    np.array([2.5]),  nlls_colours),
]:
    BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)
plt.xticks([0.65,1,1.35,1.8,2.15,2.5],['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)

# Band from the 25th to the 75th percentile (NaNs ignored). The upper bound
# was previously the 77th percentile, which is presumably a typo for the
# upper quartile.
for x, full_vals, band_colour in [(np.arange(1.7,2.6,0.05), PrecFull_NLLS, WLSFit),
                                  (np.arange(0.55,1.5,0.05), PrecFull_SBI, SBIFit)]:
    q25, q75 = np.nanpercentile(full_vals, [25, 75])
    y1 = np.ones_like(x)*q25
    y2 = np.ones_like(x)*q75
    plt.fill_between(x,y1,y2,color=band_colour,zorder=10,alpha=0.2,hatch='//')

ax1.set_xlim(0.3,2.8)

if Save: plt.savefig(FigLoc+'DKI_RKT_Prec.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# KFA accuracy: mean absolute difference inside each subject's brain-mask slice
# (NaN voxels ignored via np.nanmean).
#   AccM7 / AccM20     : SBI Min / Mid vs SBI Full
#   AccMFulls          : SBI Full vs NLLS Full
#   AccM7NL / AccM20NL : NLLS Min / Mid vs NLLS Full
AccM7 = []
for i in range(32):
    M7 =KFAMinArr[i]
    MF =KFAFullArr[i]
    Ma = masks[i]
    AccM7.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccM20 = []
for i in range(32):
    M7 =KFAMidArr[i]
    MF =KFAFullArr[i]
    Ma = masks[i]
    AccM20.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccMFulls = []
for i in range(32):
    M7 =KFAFullArr[i]
    MF =KFAFullNLArr[i]
    Ma = masks[i]
    AccMFulls.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccM7NL = []
for i in range(32):
    M7 =KFAMinNLArr[i]
    # In place: NaNs in KFAMinNLArr[i] become 0 and stay 0 for the SSIM loop below.
    M7[np.isnan(M7)] = 0
    MF =KFAFullNLArr[i]
    Ma = masks[i]
    AccM7NL.append(np.nanmean(np.abs(M7-MF)[Ma]))

AccM20NL = []
for i in range(32):
    M7 =KFAMidNLArr[i]
    # In place, as above, for KFAMidNLArr[i].
    M7[np.isnan(M7)] = 0
    MF =KFAFullNLArr[i]
    Ma = masks[i]
    AccM20NL.append(np.nanmean(np.abs(M7-MF)[Ma]))

# Masked local SSIM (masked_local_ssim with win_size=7, dat_range=1) between
# each map and its Full-protocol reference. Reduced-protocol maps are smoothed
# with a sigma=0.5 Gaussian first.
# NOTE(review): the reference (NS2) is also smoothed for the Min comparisons
# but not for the Mid ones, unlike the MKT version of this cell — confirm
# which is intended.
SSIM7 = []
SSIM20 = []
SSIMFulls = []

SSIM7NL = []
SSIM20NL = []
for i in tqdm(range(32)):
    NS1 =KFAMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =KFAFullArr[i]
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7.append(result)

    NS1 =KFAMidArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =KFAFullArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20.append(result)
    
    NS1 =KFAFullArr[i]
    NS2 =KFAFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIMFulls.append(result)

    NS1 =KFAMinNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =KFAFullNLArr[i]
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM7NL.append(result)

    NS1 =KFAMidNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =KFAFullNLArr[i]
    Ma = masks[i]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM20NL.append(result)

In [None]:
# KFA accuracy figure: per-subject mean absolute error as box plots.
# Black: SBI Full vs NLLS Full; teal: SBI Mid/Min vs SBI Full;
# orange: NLLS Mid/Min vs NLLS Full.
fig, ax = plt.subplots(1,1, figsize=(3.2,4.8))
fig.subplots_adjust(hspace=0.05)

# (values, x position, first colour list, second colour list) for BoxPlots,
# drawn in the original order.
for y_data, g_pos, colors, colors2 in [
    (np.array(AccM7NL),   np.array([3.05]), ['burlywood'],       ['peachpuff']),
    (np.array(AccMFulls), np.array([0.8]),  ['black'],           ['gray']),
    (np.array(AccM20),    np.array([1.55]), ['lightseagreen'],   ['paleturquoise']),
    (np.array(AccM7),     np.array([1.95]), ['mediumturquoise'], ['paleturquoise']),
    (np.array(AccM20NL),  np.array([2.65]), ['sandybrown'],      ['peachpuff']),
]:
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

plt.yticks(fontsize=24)
ax.set_ylim(0,0.6)
# Tick positions match the boxes. An earlier xticks call with mismatched
# positions [1,1.7,2,2.8,3.1] was always overridden by this one and is removed.
plt.xticks([0.8,1.55,1.95,2.65,3.05],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)
ax.set_xlim(0.5,3.5)

if Save: plt.savefig(FigLoc+'DKIHCP_Acc_KFA.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Box plots of per-subject masked local SSIM of KFA maps: SBI-full vs
# NLLS-full (SSIMFulls), SBI Mid/Min (SSIM20/SSIM7) and NLLS Mid/Min
# (SSIM20NL/SSIM7NL).
fig, ax = plt.subplots(figsize=(3.2, 4.8))

# (values, x-position, box colour, fill colour, extra BoxPlots kwargs)
ssim_series = [
    (SSIMFulls, 0.8,  'black',           'gray',          {}),
    (SSIM20,    1.5,  'lightseagreen',   'paleturquoise', {'scatter_alpha': 0.5}),
    (SSIM7,     1.95, 'mediumturquoise', 'paleturquoise', {}),
    (SSIM20NL,  2.6,  'sandybrown',      'peachpuff',     {}),
    (SSIM7NL,   3.05, 'burlywood',       'peachpuff',     {}),
]
for values, pos, box_col, fill_col, extra in ssim_series:
    BoxPlots(np.array(values), np.array([pos]), [box_col], [fill_col], ax,
             widths=0.2, scatter=True, **extra)

# Dashed horizontal reference line at SSIM = 0.66 (presumably the entry
# described by leg_patch3, defined elsewhere — confirm).
plt.axhline(0.66, lw=3, ls='--', c='k')
plt.xticks([0.8, 1.5, 1.95, 2.6, 3.05], ['Full', 'Mid', 'Min', 'Mid', 'Min'],
           fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

ax.legend(
    handles=[leg_patch3],
    loc='lower left',                # base location...
    bbox_to_anchor=(-0.1, -0.05),    # ...fine-tuned relative to the axes
    frameon=False, ncols=1,
    fontsize=32, columnspacing=0.3, handlelength=0.6, handletextpad=0.3,
)

if Save:
    plt.savefig(FigLoc + 'DKIHCP_SSIM_KFA.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Precision: per-subject standard deviation of KFA inside the white-matter
# mask WMs[i], for SBI and NLLS at each protocol (Full / Mid / Min).
# np.nanstd is used for both estimators so that a few NaN voxels do not turn
# a whole subject's value into NaN (previously only SBI ignored NaNs).
# Note: if the accuracy cell has run, NaNs in the NLLS Min/Mid maps were
# already set to 0 in place there and are counted as zeros here.
Prec7_SBI, Prec20_SBI, PrecFull_SBI = [], [], []
Prec7_NLLS, Prec20_NLLS, PrecFull_NLLS = [], [], []
for i in range(32):
    wm = WMs[i]
    Prec7_SBI.append(np.nanstd(KFAMinArr[i][wm]))
    Prec20_SBI.append(np.nanstd(KFAMidArr[i][wm]))
    PrecFull_SBI.append(np.nanstd(KFAFullArr[i][wm]))

    Prec7_NLLS.append(np.nanstd(KFAMinNLArr[i][wm]))
    Prec20_NLLS.append(np.nanstd(KFAMidNLArr[i][wm]))
    PrecFull_NLLS.append(np.nanstd(KFAFullNLArr[i][wm]))


In [None]:
# Box plots of per-subject KFA precision (WM standard deviation) for SBI and
# NLLS at each protocol, with the inter-quartile range of each method's
# full-protocol values shaded across its group for reference.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

# (values, x-position, box colour, fill colour), drawn in the original order.
prec_series = [
    (PrecFull_SBI,  0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec20_SBI,    1.0,  'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI,     1.35, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS, 1.8,  'sandybrown',      'peachpuff'),
    (Prec20_NLLS,   2.15, 'sandybrown',      'peachpuff'),
    (Prec7_NLLS,    2.5,  'sandybrown',      'peachpuff'),
]
for values, pos, box_col, fill_col in prec_series:
    BoxPlots(np.array(values), np.array([pos]), [box_col], [fill_col], ax1,
             widths=0.2, scatter=True)
plt.xticks([0.65, 1, 1.35, 1.8, 2.15, 2.5], ['Full', 'Mid', 'Min', 'Full', 'Mid', 'Min'],
           fontsize=32, rotation=90)


def _shade_iqr(values, x_start, x_stop, color):
    """Shade the 25th-75th percentile band of *values* (NaNs ignored) over [x_start, x_stop)."""
    x = np.arange(x_start, x_stop, 0.05)
    # 75th (not 77th) percentile so the band matches the inter-quartile box.
    q25, q75 = np.nanpercentile(np.asarray(values, dtype=float), [25, 75])
    plt.fill_between(x, np.ones_like(x) * q25, np.ones_like(x) * q75,
                     color=color, zorder=10, alpha=0.2, hatch='//')


_shade_iqr(PrecFull_NLLS, 1.7, 2.6, WLSFit)
_shade_iqr(PrecFull_SBI, 0.55, 1.5, SBIFit)

ax1.set_xlim(0.3, 2.8)

if Save:
    plt.savefig(FigLoc + 'DKI_KFA_Prec.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Build the reduced simulation protocol: one b=0 volume, 6 directions at
# b=1000 and 15 directions at b=3000. Direction sets start from subsets of
# the 'small_64D' gradient directions (skipping the first entry) and are
# spread out with dipy's disperse_charges (5000 iterations).
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)

hsph_initial15 = HemiSphere(xyz=bvecs[1:16])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
hsph_updated15, _ = disperse_charges(hsph_initial15, 5000)
hsph_updated7, _ = disperse_charges(hsph_initial7, 5000)

sim_bvals = np.repeat([0, 1000, 3000], [1, 6, 15])
sim_bvecs = np.vstack([np.zeros((1, 3)), hsph_updated7.vertices, hsph_updated15.vertices])
gtabSimSub = gradient_table(sim_bvals, sim_bvecs)

In [None]:
# Held-out test tensors, seeded for reproducibility. Each GenDTKT call draws
# diffusion (DT) and kurtosis (KT) tensors from one parameter set
# (full, lfa, hfa, ulfa, hRK); the last argument is presumably the number of
# samples per set (40 each).
# NOTE(review): the third argument is 1 for the first four sets but 12 for
# the hRK set. The training cell passes 12 for every set and its rejection
# loop passes an incrementing counter in this position, which suggests it
# acts as a seed; if so, the hRK test samples may coincide with hRK training
# samples — confirm.
torch.manual_seed(1)
np.random.seed(1)
DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],1,40)
DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],1,40)
DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],1,40)
DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],1,40)
DT5,KT5 = GenDTKT([DT1_hRK,DT2_hRK],[x4_hRK,R1_hRK,x2_hRK,R2_hRK],12,40)

# Stack the five sets into one test set; row i of SampsDT pairs with row i of SampsKT.
SampsDT = np.vstack([DT1,DT2,DT3,DT4,DT5])
SampsKT = np.vstack([KT1,KT2,KT3,KT4,KT5])

In [None]:
# Reseed so the simulated noisy signals are reproducible.
torch.manual_seed(1)
np.random.seed(1)

# Simulate every test tensor pair once per entry of NoiseLevels (passed as
# snr, S0=200). The result is indexed [test sample, noise level, signal sample],
# and simulator calls happen in the same order as a sample-major double loop.
Samples7 = np.array([
    [CustoRKKISimulator(dt_sample, kt_sample, gtabSimSub, S0=200, snr=snr_level)
     for snr_level in NoiseLevels]
    for dt_sample, kt_sample in zip(SampsDT, SampsKT)
])

In [None]:
# Load the cached SBI posterior for the reduced protocol (gtabSimSub) if it
# exists; otherwise simulate a training set, train it with SNPE and cache it.
posterior_path = f"{network_path}/DKISimMin.pickle"
if os.path.exists(posterior_path):
    # NOTE: pickle.load can execute arbitrary code; only load files produced by this notebook.
    with open(posterior_path, "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    # Training tensors from the five parameter sets (fourth argument is
    # presumably the number of samples drawn from each set).
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hRK,DT2_hRK],[x4_hRK,R1_hRK,x2_hRK,R2_hRK],12,int(2.5*6000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

    # Simulate one noisy signal per sample (positional arguments 4 and 5 are
    # presumably S0=200 and an SNR drawn uniformly from [0, 30)). Samples whose
    # signal goes negative or above 800 are re-simulated with their KT halved
    # and their DT redrawn, until every sample is in range.
    Obs = np.zeros([len(KT), len(gtabSimSub.bvecs)])
    indx = np.arange(len(KT))
    kk = 0
    while len(indx) > 0:
        for i in tqdm(indx):
            Obs[i] = CustoRKKISimulator(DT[i],KT[i],gtabSimSub,200,np.random.rand()*30)

        # Check all rows at once; rows accepted earlier are unchanged, so they stay accepted.
        indxNew = np.flatnonzero((Obs > 800).any(axis=1) | (Obs < 0).any(axis=1))
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): the replacement DT comes from the 'full' parameter set
        # regardless of which set the rejected sample came from, and
        # GenDTKT(..., kk, 1)[0] presumably yields a single tensor that is
        # broadcast to every rejected row — confirm this is intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk += 1

    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # Pass the simulations to the inference object.
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    os.system('say "Network done."')  # macOS text-to-speech notification
    if not os.path.exists(posterior_path):
        with open(posterior_path, "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Evaluate both estimators on the simulated test set. For every noise level k
# and test sample i, the estimate is scored against the ground-truth tensors
# with DKIErrors; the result lists are indexed [noise level][test sample].
torch.manual_seed(10)
n_test, n_levels = Samples7.shape[:2]

# SBI: point estimate = mean of InferSamples posterior draws. The first 6
# entries are the DT part of the parameter vector (training used
# Par = hstack([DT, KT])).
ErrorFull = []
for k in tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_test):
        tObs = Samples7[i, k, :]
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs, show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)
        ErrorN2.append(DKIErrors(GuessSBI[:6], GuessSBI[6:], SampsDT[i], SampsKT[i]))
    ErrorFull.append(ErrorN2)

# NLLS baseline: dipy DKI fit on the same signals.
Error_s = []
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub, fit_method='NLLS')
for k in tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_test):
        tObs = Samples7[i, k, :]
        tenfit = dkimodel.fit(tObs)
        ErrorN2.append(DKIErrors(tenfit.lower_triangular(), tenfit.kt, SampsDT[i], SampsKT[i]))
    Error_s.append(ErrorN2)



In [None]:
ErrorFull = np.array(ErrorFull)
Error_s = np.array(Error_s)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']

# Only the noisy conditions are plotted: noise index 0 is the NoiseLevels[0]
# (None) case and is skipped; the remaining levels label the x-axis.
noisy_levels = NoiseLevels[1:]
n_noisy = len(noisy_levels)
nlls_pos = np.arange(1, n_noisy + 1)   # NLLS boxes at 1, 2, ...
sbi_pos = nlls_pos + 0.3               # SBI boxes just to the right
tick_pos = nlls_pos + 0.15             # tick centred on each pair

fig, ax = plt.subplots(1, 2, figsize=(9, 3))
# Error columns 3 and 4 — presumably 'MKT Error' and 'KFA Error' in
# ErrorNames order (consistent with the output filename); confirm against DKIErrors.
for i, err_idx in enumerate((3, 4)):
    plt.sca(ax[i])
    BoxPlots(ErrorFull[1:, :, err_idx], sbi_pos, ['lightseagreen'] * n_noisy,
             ['paleturquoise'] * n_noisy, ax[i], widths=0.3, scatter=False)
    BoxPlots(Error_s[1:, :, err_idx], nlls_pos, ['sandybrown'] * n_noisy,
             ['peachpuff'] * n_noisy, ax[i], widths=0.3, scatter=False)
    plt.xticks(tick_pos, noisy_levels, fontsize=32)
    plt.gca().ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)
    if i == 0:
        handles = [
            Line2D([0], [0], color=SBIFit, lw=4, label='SBI'),
            Line2D([0], [0], color=WLSFit, lw=4, label='NLLS'),
        ]
        # Add the legend to the first panel only.
        plt.legend(handles=handles, loc=2, bbox_to_anchor=(-0.1, 1.1),
                   fontsize=36, columnspacing=0.3, handlelength=0.6,
                   handletextpad=0.3, labelspacing=0.1)

if Save:
    plt.savefig(FigLoc + 'ErrorsMin_MKTKFA.pdf', format='pdf', bbox_inches='tight', transparent=True)