# Imports

In [None]:
from AxCaliber_funcs import *

%load_ext autoreload
%autoreload 2

## network

In [None]:
# Reuse a previously trained posterior if one is cached on disk; otherwise
# simulate a training set and train the density estimator from scratch.
if os.path.exists(NetworkDir+"AxCal_Network.pickle"):
    # NOTE: unpickling executes arbitrary code -- only load a file you trust.
    with open(NetworkDir+"AxCal_Network.pickle", "rb") as handle:
        Network = pickle.load(handle)
        print('loaded')
else:
    NumSamps = 2_000_000
    np.random.seed(12)

    # Random unit vectors (normalised Gaussian draws), converted with SpherAng.
    x1  = np.random.randn(NumSamps)
    y1  = np.random.randn(NumSamps)
    z1  =  np.random.randn(NumSamps)
    VS = np.vstack([x1,y1,z1])
    VS = (VS/np.linalg.norm(VS,axis=0)).T
    Angs = np.array([SpherAng(v) for v in VS])

    #Diffusion of restricted
    Dpar  = np.random.rand(NumSamps)*5e-3
    Dperp = np.random.rand(NumSamps)*5e-3

    #Diffusion of hindered - needs to be updated
    MD_prior = np.random.rand(NumSamps)*0.005
    FA_prior = np.random.rand(NumSamps)*0.999
    DHind = [mat_to_vals(random_diffusion_tensor(m, f)) for m,f in zip(MD_prior,FA_prior)]

    mean = np.random.rand(NumSamps)*0.005+1e-4
    # NOTE(review): sig2 is drawn but not used below; the draw is kept so the
    # random stream (and hence the simulated noise) is unchanged.
    sig2 = np.random.rand(NumSamps) * (4e-7 - 9e-8) + 9e-8

    #Fraction of hindered
    frac  = np.random.rand(NumSamps)
    #frac  = np.hstack([np.random.rand(NumSamps//5)*0.5+0.5,np.random.rand(4*NumSamps//5)])
    S0Rand =np.random.rand(NumSamps)*2475+25
    TrainParams = np.column_stack([Angs,Dpar,Dperp,DHind,frac,mean,S0Rand])

    # Simulate one noisy signal per parameter draw. The acquisition is split
    # into three consecutive groups of n_pts+1 measurements, one per Delta.
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps),position=0,leave=True):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]

        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]

        a = mean[i]
        s0 = S0Rand[i]

        Noise = 50#np.random.rand()*30 + 20

        TrainSig1 = AxCaliber(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = AxCaliber(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = AxCaliber(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)

    Obs = torch.tensor(NoisyTrainSig).float()

    # For every sample pick a random sub-protocol: index 0 plus three entries
    # drawn from b2000 and three from b4000, repeated at the three offsets.
    offsets = [0, 91, 182]
    Vecs = []
    for _ in range(NumSamps):
        X = np.hstack([
                [0],                                    # force first element = 0
                np.random.choice(b2000, 3, replace=False),
                np.random.choice(b4000, 3, replace=False),
            ])
        Vecs.append(np.hstack([X + off for off in offsets]))

    # Everything below runs ONCE, after all sub-protocols have been drawn.
    # (It used to sit inside the loop above, so training started on the first
    # iteration with a single feature row against NumSamps parameter rows.)
    features = np.array([
        AxcaliberFeatures(bvecs[v], bvals[v], Deltas[v], TS[v])
        for v,TS in zip(Vecs,NoisyTrainSig)
    ])
    Par = torch.tensor(TrainParams).float()

    # Prior box: per-parameter min/max of the training draws, widened.
    low = Par.min(axis=0)[0] - 10*torch.sign(Par.min(axis=0)[0])*Par.min(axis=0)[0]
    # NOTE(review): with a_min=low and a_max=-1 this evaluates to
    # minimum(low, -1), i.e. every lower bound ends up <= -1 -- confirm intent.
    low = np.clip(low,low,-1)
    high = Par.max(axis=0)[0] + 10*Par.max(axis=0)[0]

    Obs_feats = torch.tensor(features).float()

    inference = SNPE(device='mps')

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs_feats,data_device='cpu')

    # train the density estimator and build the posterior
    density_estimator = inference.train(training_batch_size = 1024)
    prior_bounds = BoxUniform(low=low, high=high)
    Network = DirectPosterior(density_estimator.cpu(), prior=prior_bounds)

# Figure 5

In [None]:
def _farthest_point_subset(points, n_extra, start=0):
    """Greedy farthest-point selection of rows of `points`.

    Starts from row `start` and `n_extra` times adds the row whose minimum
    Euclidean distance to the rows already chosen is largest.
    Returns the chosen row indices (into `points`) in selection order.
    """
    chosen = [start]
    distance_matrix = squareform(pdist(points))
    for _ in range(n_extra):
        remaining = list(set(range(len(points))) - set(chosen))
        # Minimum distance to the chosen rows, for each remaining row
        min_distances = np.min(distance_matrix[remaining][:, chosen], axis=1)
        chosen.append(remaining[np.argmax(min_distances)])
    return chosen


Dir = MSDir+'/Ctrl055_R01_28/'
dat = pmt.read_mat(Dir+'data_loaded.mat')
bvecs = dat['direction']
bvals = dat['bval']

n_pts = 90

# Diffusion time for every measurement: three groups of n_pts+1 entries.
Deltas = np.concatenate([
    np.full(n_pts + 1, Delta[0]),
    np.full(n_pts + 1, Delta[1]),
    np.full(n_pts + 1, Delta[2]),
])

# Reduced protocol: within the first 91 measurements keep 3 well-separated
# directions per shell (row 0 of the shell plus 2 farthest-point picks).
# Together with index 0 that gives 7 measurements per diffusion time.
true_indices = []

bvecs2000 = bvecs[:91][bvals[:91]==2000]
selected_indices = _farthest_point_subset(bvecs2000, 2)
bvecs2000_selected = bvecs2000[selected_indices]
for b in bvecs2000_selected:
    # Map the selected direction back to its first matching row in bvecs
    true_indices.append(np.where((b == bvecs).all(axis=1))[0][0])

bvecs4000 = bvecs[:91][bvals[:91]==4000]
selected_indices = _farthest_point_subset(bvecs4000, 2)
bvecs4000_selected = bvecs4000[selected_indices]
for b in bvecs4000_selected:
    true_indices.append(np.where((b == bvecs).all(axis=1))[0][0])

MinIdices = np.hstack([0,true_indices])
# Same 7 indices repeated in each of the three 91-long groups
DevilIndices = np.hstack([MinIdices,MinIdices+91,MinIdices+182])
bvecs_Dev = bvecs[DevilIndices]
bvals_Dev = bvals[DevilIndices]

bve_split = [bvecs[:(n_pts+1)],bvecs[(n_pts+1):2*(n_pts+1)],bvecs[2*(n_pts+1):]]
bva_split = [bvals[:(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],bvals[2*(n_pts+1):]]

## a

In [None]:
# Simulate TestSamps ground-truth parameter sets and, for each, the clean
# signal plus four noisy copies (one per noise setting passed to AddNoise).
# The order of the random draws below determines the data for a fixed seed.
np.random.seed(12)
TestSamps = 20

# Directions: normalised Gaussian draws (unit vectors), converted by SpherAng
x1  = np.random.randn(TestSamps)
y1  = np.random.randn(TestSamps)
z1  =  np.random.randn(TestSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
Dpar  = np.random.rand(TestSamps)*5e-3
Dperp = np.random.rand(TestSamps)*5e-3

#Diffusion of hindered
MD_prior = np.random.rand(TestSamps)*0.005
FA_prior = np.random.rand(TestSamps)*0.999
DHind = [mat_to_vals(random_diffusion_tensor(m, f)) for m,f in zip(MD_prior,FA_prior)]

#Fraction of hindered
frac  = np.random.rand(TestSamps)

mean = np.random.rand(TestSamps)*0.005+1e-4
sig2 = np.random.rand(TestSamps) * (4e-7 - 9e-8) + 9e-8

# NOTE(review): S0Rand is not read again in this cell (s0 is fixed to 200 below).
S0Rand =np.ones(TestSamps)

# Ground-truth table; unlike the training table it has no S0 column.
TestParams = np.column_stack([Angs,Dpar,Dperp,DHind,frac,mean])

TestSig = []
NoisyTestSig = []
for i in tqdm(range(TestSamps)):
    v = np.array([Angs[i]])
    dpar = Dpar[i]
    dperp = Dperp[i]
    
    dh   = DHind[i]
    f    = [frac[i],1-frac[i]]

    a = mean[i]
    s = sig2[i]
    # Gamma distribution with mean a and variance s (shape a^2/s, scale s/a).
    # NOTE(review): rv / R / weights are computed but never passed to
    # AxCaliber in this cell -- only `a` is.
    alpha     = a * a / s
    scale = s / a
    rv = stats.gamma(a=alpha,scale=scale)
    
    R = np.linspace(0.0001,0.005, 30)
    weights = rv.pdf(R)
    weights = weights/np.sum(weights)
    s0 = 200

    # One AxCaliber call per group of n_pts+1 measurements (one per Delta)
    TestSig1 = AxCaliber(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig2 = AxCaliber(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig3 = AxCaliber(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
    TestSig.append(np.hstack([TestSig1,TestSig2,TestSig3]))
    Noisy = []
    # Later plots label these four settings as SNR 2, 10, 20, 30
    for Noise in [2,10,20,30]:
        Noisy.append(AddNoise(TestSig[-1],s0,Noise))
    NoisyTestSig.append(Noisy)
NoisyTestSig = np.array(NoisyTestSig)
# Reorder to (noise setting, sample, measurement)
NoisyTestSig = np.swapaxes(NoisyTestSig,0,1)
TestSig = np.array(TestSig)

In [None]:
# Random starting point and box constraints for the 12-parameter NLLS fit.
# The random draws are kept in their original order so the seed reproduces
# the same initial guess.
np.random.seed(10)
mean = np.random.rand(1)*0.005+1e-4
MD_prior = np.random.rand(1)*0.005
FA_prior = np.random.rand(1)*0.999
DHind_guess = [
    mat_to_vals(random_diffusion_tensor(m, f))
    for m,f in zip(MD_prior,FA_prior)
]

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s

# Fixed starting orientation: cos(theta) = 0 and phi = 0
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T

mean_guess = np.random.rand()*0.005 + 1e-4

frac_guess = np.random.rand()
guess = np.column_stack(
    [Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess]
).squeeze()

# One [lower, upper] pair per parameter, in the column order of `guess`;
# transposed so that row 0 holds the lower and row 1 the upper bounds.
bounds = np.array(
    [
        [0,np.pi/2],            # theta
        [-np.pi,np.pi],         # phi
        [0,5e-3],               # Dpar
        [0,5e-3],               # Dperp
    ]
    + [[-5e-3,5e-3]]*6          # hindered-tensor entries (columns 4-9)
    + [
        [0,1],                  # frac
        [1e-4,0.005+1e-4],      # mean
    ],
    dtype=float,
).T

In [None]:
# NLLS baseline: fit every noisy test signal with the full protocol, then
# with the reduced (DevilIndices) protocol, and score both against the truth.
_n_shell = n_pts + 1
bve_split = [bvecs[:_n_shell],bvecs[_n_shell:2*_n_shell],bvecs[2*_n_shell:]]
bva_split = [bvals[:_n_shell],bvals[_n_shell:2*_n_shell],bvals[2*_n_shell:]]

# Solver settings shared by both fits
_ls_kwargs = dict(bounds=bounds, xtol=1e-12, gtol=1e-12, ftol=1e-12, jac='3-point')

# --- full protocol ---
LS_result = np.zeros([4,20,12])
for i in tqdm(range(20),position = 0, leave = True):
    for j in range(4):
        result = sp.optimize.least_squares(
            residuals, guess,
            args=[NoisyTestSig[j,i],bve_split,bva_split,Delta,False],
            verbose=0, **_ls_kwargs,
        )
        LS_result[j,i] = result.x

LS_Errors = np.array([
    [
        AxCal_Errors(sig,n_true,n_guess,Delta,bve_split,bva_split)
        for n_guess,n_true,sig in zip(N,TestParams,TestSig)
    ]
    for N in tqdm(LS_result,position = 0, leave = True)
])

# --- reduced protocol: 7 measurements per diffusion time ---
bve_splitd = [bvecs_Dev[:7],bvecs_Dev[7:14],bvecs_Dev[14:]]
bva_splitd = [bvals_Dev[:7],bvals_Dev[7:14],bvals_Dev[14:]]
for i in tqdm(range(20),position = 0, leave = True):
    for j in range(4):
        # LS_result is reused; the full-protocol errors are already stored.
        result = sp.optimize.least_squares(
            residuals, guess,
            args=[NoisyTestSig[j,i][DevilIndices],bve_splitd,bva_splitd,Delta],
            verbose=1, **_ls_kwargs,
        )
        LS_result[j,i] = result.x

# Errors are still evaluated against the full-protocol splits.
LS_Errors_Min = np.array([
    [
        AxCal_Errors(sig,n_true,n_guess,Delta,bve_split,bva_split)
        for n_guess,n_true,sig in zip(N,TestParams,TestSig)
    ]
    for N in tqdm(LS_result,position = 0, leave = True)
])

In [None]:
# Define the function for optimization
def fit_SBI(i,j):
    """Posterior-mean estimate for NoisyTestSig[i, j] using the full protocol.

    Draws 1000 samples from `Network` conditioned on the features of the
    signal and returns ``(i, j, mean of the samples)``.
    """
    torch.manual_seed(10)  # same seed for every call
    summary = AxcaliberFeatures(bvecs,bvals,Deltas,NoisyTestSig[i,j])
    draws = Network.sample((1000,), x=summary, show_progress_bars=False)
    return i, j, draws.mean(axis=0)

# All (noise setting, sample) index pairs: 4 settings x 20 samples
y_indx = np.repeat(np.arange(20),4)
x_indx = np.tile(np.arange(4),20)
indices = np.column_stack([x_indx,y_indx])

# Use joblib to run the posterior sampling in parallel
results = Parallel(n_jobs=8)(
    delayed(fit_SBI)(i, j) for i, j in tqdm(indices)
)

# Collect the 13-parameter posterior means; the second-to-last entry is
# limited to [0, 100].
SBI_Res = np.zeros([4,20,13])
for i, j, x in results:
    SBI_Res[i, j] = x
    SBI_Res[i, j,-2] = np.clip(SBI_Res[i, j,-2],0,100)

SBI_Errors = np.array([
    [
        AxCal_Errors(sig,n_true,n_guess,Delta,bve_split,bva_split)
        for n_guess,n_true,sig in zip(N,TestParams,TestSig)
    ]
    for N in tqdm(SBI_Res)
])

In [None]:
# Define the function for optimization
def fit_SBI(i,j):
    """Posterior-mean estimate for NoisyTestSig[i, j] using the reduced protocol.

    Only the measurements listed in `DevilIndices` are used. Draws 1000
    samples from `Network` and returns ``(i, j, mean of the samples)``.
    """
    torch.manual_seed(10)  # same seed for every call
    summary = AxcaliberFeatures(
        bvecs[DevilIndices],
        bvals[DevilIndices],
        Deltas[DevilIndices],
        NoisyTestSig[i,j,DevilIndices],
    )
    draws = Network.sample((1000,), x=summary, show_progress_bars=False)
    return i, j, draws.mean(axis=0)

# All (noise setting, sample) index pairs: 4 settings x 20 samples
y_indx = np.repeat(np.arange(20),4)
x_indx = np.tile(np.arange(4),20)
indices = np.column_stack([x_indx,y_indx])

# Use joblib to run the posterior sampling in parallel
results = Parallel(n_jobs=8)(
    delayed(fit_SBI)(i, j) for i, j in tqdm(indices)
)

# Collect the 13-parameter posterior means; the second-to-last entry is
# limited to [0, 100].
SBI_Res = np.zeros([4,20,13])
for i, j, x in results:
    SBI_Res[i, j] = x
    SBI_Res[i, j,-2] = np.clip(SBI_Res[i, j,-2],0,100)

# Errors are evaluated against the full-protocol splits.
SBI_Errors_Min = np.array([
    [
        AxCal_Errors(sig,n_true,n_guess,Delta,bve_split,bva_split)
        for n_guess,n_true,sig in zip(N,TestParams,TestSig)
    ]
    for N in tqdm(SBI_Res)
])
        

In [None]:


# -----------------------------
# Helpers
# -----------------------------
def _circle_on_sphere(center, u_axis, w_axis, angle, n_points):
    """Return (n_points, 3) points at angular distance `angle` from `center`.

    For a unit vector `center` and orthonormal tangent vectors u, w:
        P(t) = cos(angle)*center + sin(angle)*(cos(t)*u + sin(t)*w)
    with t evenly spaced over [0, 2*pi].
    """
    return np.array([
        np.cos(angle) * center + np.sin(angle) * (np.cos(t) * u_axis + np.sin(t) * w_axis)
        for t in np.linspace(0, 2 * np.pi, n_points)
    ])


def _plot_sphere_region(ax, mesh, hidden, **surface_kwargs):
    """Draw the sphere mesh with the grid points flagged in `hidden` set to NaN."""
    xs, ys, zs = (np.where(hidden, np.nan, comp) for comp in mesh)
    ax.plot_surface(xs, ys, zs, **surface_kwargs)


# -----------------------------
# Parameters
# -----------------------------
r = 1.0  # sphere radius
vector = np.array([-0.5, -1, 1])   # arbitrary vector
n = vector / np.linalg.norm(vector)  # unit vector in the direction of 'vector'
intersection = n * r  # intersection of the vector with the sphere

# Angular radius of each circle: mean of error column 2 over the test samples
# at the last noise setting. (The figure title calls this the average angle
# difference; the value is fed to cos/sin, so it is treated as radians.)
alpha1 = SBI_Errors[-1][:, 2].mean()       # full SBI
alpha2 = SBI_Errors_Min[-1][:, 2].mean()   # reduced SBI
alpha3 = LS_Errors[-1][:, 2].mean()        # full NLLS
alpha4 = LS_Errors_Min[-1][:, 2].mean()    # reduced NLLS

# -----------------------------
# Tangent basis at n
# -----------------------------
# u is any vector perpendicular to n.
# (If n is parallel to the z-axis, choose a different axis to avoid the zero vector.)
if np.allclose(n, [0, 0, 1]):
    u = np.array([1, 0, 0])
else:
    u = np.cross(n, [0, 0, 1])
    u = u / np.linalg.norm(u)

# w is perpendicular to both n and u.
w = np.cross(n, u)

# -----------------------------
# Create the sphere mesh
# -----------------------------
phi = np.linspace(0, 2 * np.pi, 500)  # azimuthal angle
theta = np.linspace(0, np.pi, 500)      # polar angle

phi, theta = np.meshgrid(phi, theta)
x_sphere = r * np.sin(theta) * np.cos(phi)
y_sphere = r * np.sin(theta) * np.sin(phi)
z_sphere = r * np.cos(theta)
mesh = (x_sphere, y_sphere, z_sphere)

# -----------------------------
# Plot everything
# -----------------------------
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

# Plot the vector (using quiver)
ax.quiver(0, 0, 0, intersection[0], intersection[1], intersection[2],
          color='r', linewidth=2, arrow_length_ratio=0.1)

# Dashed outline of each method's error circle
for angle, n_points, outline in [
    (alpha1, 200, 'paleturquoise'),
    (alpha2, 100, 'lightseagreen'),
    (alpha3, 100, 'sandybrown'),
    (alpha4, 100, 'darkorange'),
]:
    circle_points = _circle_on_sphere(n, u, w, angle, n_points)
    ax.plot(circle_points[:, 0], circle_points[:, 1], circle_points[:, 2],
            color=outline, linewidth=2, ls='--')

# Same limits on all three axes
max_range = r * 1.2
for axis in 'xyz':
    getattr(ax, 'set_{}lim'.format(axis))((-max_range, max_range))

# Cosine of the angle between every mesh point and n; a point lies within
# angular distance a of n exactly when dot >= cos(a).
dot = x_sphere * n[0] + y_sphere * n[1] + z_sphere * n[2]

band_style = dict(shade=False, alpha=0.5, rstride=2, cstride=2, edgecolor='none')

# Ring between the alpha3 and alpha4 circles
_plot_sphere_region(ax, mesh, (dot > np.cos(alpha3)) | (dot < np.cos(alpha4)),
                    color='darkorange', **band_style)
# Ring between the alpha2 and alpha3 circles
_plot_sphere_region(ax, mesh, (dot > np.cos(alpha2)) | (dot < np.cos(alpha3)),
                    color='sandybrown', **band_style)
# Ring between the alpha1 and alpha2 circles
_plot_sphere_region(ax, mesh, (dot > np.cos(alpha1)) | (dot < np.cos(alpha2)),
                    color='lightseagreen', **band_style)
# Cap inside the alpha1 circle
_plot_sphere_region(ax, mesh, dot < np.cos(alpha1),
                    color='paleturquoise', alpha=0.5, linewidth=0,
                    rstride=1, cstride=1, shade=False)
# Rest of the sphere, outside the alpha4 circle
_plot_sphere_region(ax, mesh, dot > np.cos(alpha4),
                    color='gray', alpha=0.2, rstride=2, cstride=2, edgecolor='none')

ax.axis('equal')
ax.axis('off')
ax.view_init(elev=20, azim=-85)

minLS_patch = mpatches.Patch(color='darkorange', label='Reduced NLLS')
fullLS_patch = mpatches.Patch(color='sandybrown', label='Full NLLS')

minSBI_patch = mpatches.Patch(color='lightseagreen', label='Reduced SBI')
fullSBI_patch = mpatches.Patch(color='paleturquoise', label='Full SBI')

ax.legend(
    handles=[minLS_patch,minSBI_patch,fullLS_patch,fullSBI_patch],
    loc='lower left',         # base location; bbox_to_anchor fine-tunes it
    frameon=False, ncols=2,
    bbox_to_anchor=(0.18, 0.09),fontsize=18,
    columnspacing=0.5,
    handlelength=0.8,
)
ax.set_title('Average angle diff.',x=0.52, y=0.825,fontsize=24)

## b

In [None]:
# Panel b: box plots of error column 1 for the four methods, grouped by SNR.
fig,ax = plt.subplots(figsize=(8,4))
g_pos = np.array([0, 2, 4,6])*2

# (error array, x positions, first colour, second colour) for each method;
# the two colours are repeated once per SNR group and handed to BoxPlots.
_series = [
    (SBI_Errors,     g_pos,       'lightseagreen', 'paleturquoise'),
    (SBI_Errors_Min, g_pos + 0.5, 'cadetblue',     'darkturquoise'),
    (LS_Errors,      g_pos + 1.5, 'sandybrown',    'peachpuff'),
    (LS_Errors_Min,  g_pos + 2,   'darkorange',    'orange'),
]
for _errs, POSITIONS, _col_a, _col_b in _series:
    colors = [_col_a] * 4
    colors2 = [_col_b] * 4
    BoxPlots(_errs[:,:,1],POSITIONS,colors,colors2,ax,widths=0.5,scatter=True)

# One tick per SNR group, scientific notation on the y axis
ax.set_xticks([1,5,9,13],['2','10','20','30'],fontsize=24)
ax.set_xlabel('SNR',fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
ax.tick_params(axis='y', labelsize=24,)
ax.yaxis.get_offset_text().set_fontsize(24)

minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = [
    mpatches.Patch(color=_c, label=_l)
    for _c, _l in [
        ('darkorange', 'Reduced NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Reduced SBI'),
        ('paleturquoise', 'Full SBI'),
    ]
]

# This panel carries the legend entries for the two reduced-protocol methods.
ax.legend(
    handles=[minLS_patch,minSBI_patch],
    loc='lower left',         # base location; bbox_to_anchor fine-tunes it
    frameon=False, ncols=2,
    bbox_to_anchor=(0.12, 0.88),fontsize=24,
    columnspacing=0.5,
    handlelength=0.8,
)

## c

In [None]:
# Panel c: box plots of the last error column for the four methods, by SNR.
fig,ax = plt.subplots(figsize=(8,4))
g_pos = np.array([0, 2, 4,6])*2

# (error array, x positions, first colour, second colour) for each method;
# the two colours are repeated once per SNR group and handed to BoxPlots.
_series = [
    (SBI_Errors,     g_pos,       'lightseagreen', 'paleturquoise'),
    (SBI_Errors_Min, g_pos + 0.5, 'cadetblue',     'darkturquoise'),
    (LS_Errors,      g_pos + 1.5, 'sandybrown',    'peachpuff'),
    (LS_Errors_Min,  g_pos + 2,   'darkorange',    'orange'),
]
for _errs, POSITIONS, _col_a, _col_b in _series:
    colors = [_col_a] * 4
    colors2 = [_col_b] * 4
    BoxPlots(_errs[:,:,-1],POSITIONS,colors,colors2,ax,widths=0.5,scatter=True)

# One tick per SNR group, scientific notation on the y axis
ax.set_xticks([1,5,9,13],['2','10','20','30'],fontsize=24)
ax.set_xlabel('SNR',fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
ax.tick_params(axis='y', labelsize=24,)
ax.yaxis.get_offset_text().set_fontsize(24)

minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = [
    mpatches.Patch(color=_c, label=_l)
    for _c, _l in [
        ('darkorange', 'Reduced NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Reduced SBI'),
        ('paleturquoise', 'Full SBI'),
    ]
]

# This panel carries the legend entries for the two full-protocol methods.
ax.legend(
    handles=[fullLS_patch,fullSBI_patch],
    loc='lower left',         # base location; bbox_to_anchor fine-tunes it
    frameon=False, ncols=2,
    bbox_to_anchor=(0.18, 0.88),fontsize=24,
    columnspacing=0.5,
    handlelength=0.8,
)

## d

In [None]:
# Panel d: box plots of error column -5 for the four methods, grouped by SNR.
fig,ax = plt.subplots(figsize=(8,4))
g_pos = np.array([0, 2, 4,6])*2

# (error array, x positions, first colour, second colour) for each method;
# the two colours are repeated once per SNR group and handed to BoxPlots.
_series = [
    (SBI_Errors,     g_pos,       'lightseagreen', 'paleturquoise'),
    (SBI_Errors_Min, g_pos + 0.5, 'cadetblue',     'darkturquoise'),
    (LS_Errors,      g_pos + 1.5, 'sandybrown',    'peachpuff'),
    (LS_Errors_Min,  g_pos + 2,   'darkorange',    'orange'),
]
for _errs, POSITIONS, _col_a, _col_b in _series:
    colors = [_col_a] * 4
    colors2 = [_col_b] * 4
    BoxPlots(_errs[:,:,-5],POSITIONS,colors,colors2,ax,widths=0.5,scatter=True)

# One tick per SNR group, scientific notation on the y axis
ax.set_xticks([1,5,9,13],['2','10','20','30'],fontsize=24)
ax.set_xlabel('SNR',fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
ax.tick_params(axis='y', labelsize=24,)
ax.yaxis.get_offset_text().set_fontsize(24)

# Legend handles are created here but no legend is drawn on this panel.
minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = [
    mpatches.Patch(color=_c, label=_l)
    for _c, _l in [
        ('darkorange', 'Reduced NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Reduced SBI'),
        ('paleturquoise', 'Full SBI'),
    ]
]

## e

In [None]:
# Panel e: box plots of error column -3 for the four methods, grouped by SNR.
fig,ax = plt.subplots(figsize=(8,4))
g_pos = np.array([0, 2, 4,6])*2

# (error array, x positions, first colour, second colour) for each method;
# the two colours are repeated once per SNR group and handed to BoxPlots.
_series = [
    (SBI_Errors,     g_pos,       'lightseagreen', 'paleturquoise'),
    (SBI_Errors_Min, g_pos + 0.5, 'cadetblue',     'darkturquoise'),
    (LS_Errors,      g_pos + 1.5, 'sandybrown',    'peachpuff'),
    (LS_Errors_Min,  g_pos + 2,   'darkorange',    'orange'),
]
for _errs, POSITIONS, _col_a, _col_b in _series:
    colors = [_col_a] * 4
    colors2 = [_col_b] * 4
    BoxPlots(_errs[:,:,-3],POSITIONS,colors,colors2,ax,widths=0.5,scatter=True)

# One tick per SNR group, scientific notation on the y axis
ax.set_xticks([1,5,9,13],['2','10','20','30'],fontsize=24)
ax.set_xlabel('SNR',fontsize=32)
ax.tick_params(axis='x', labelsize=24)
ax.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
ax.tick_params(axis='y', labelsize=24,)
ax.yaxis.get_offset_text().set_fontsize(24)

# Legend handles are created here but no legend is drawn on this panel.
minLS_patch, fullLS_patch, minSBI_patch, fullSBI_patch = [
    mpatches.Patch(color=_c, label=_l)
    for _c, _l in [
        ('darkorange', 'Reduced NLLS'),
        ('sandybrown', 'Full NLLS'),
        ('lightseagreen', 'Reduced SBI'),
        ('paleturquoise', 'Full SBI'),
    ]
]


## f

In [None]:
# Load one subject's acquisition, define the known sequence timings, and
# prepare a brain-masked, strictly-shifted copy of the data.
Dir = MSDir+'/Ctrl055_R01_28/'
dat = pmt.read_mat(Dir+'data_loaded.mat')
bvecs, bvals = dat['direction'], dat['bval']

# Sequence timings are known for this acquisition
Delta = [0.017,0.035,0.061]
delta = 0.007
n_pts = 90
FixedParams = {
    'bvals':bvals,
    'bvecs':bvecs,
    'Delta':[0.017,0.035,0.061],
    'delta':0.007,
}

data = dat['data']
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(
    data,
    vol_idx=range(0, 10),
    median_radius=5,
    numpass=1,
    autocrop=False,
    dilate=2,
)

S_mask, _, _ = load_nifti(Dir+'mask_055.nii.gz', return_img=True)

# Binary version of slice 54 (second axis) of the loaded mask
_mask_slice = S_mask[:,54,:]
mask1 = np.ones_like(_mask_slice)
mask1[_mask_slice==0] = 0
structure = np.ones((3, 3), dtype=bool)

# Per-voxel minimum over the last axis, capped at 0. Wherever it is <= 0 the
# voxel's values are raised by |minimum| + 1e-5.
floor = np.clip(maskdata.min(axis=-1),-np.inf,0)
_needs_shift = floor <= 0
maskdata_2 = np.copy(maskdata)
maskdata_2[_needs_shift] = maskdata[_needs_shift] + abs(floor)[_needs_shift,None] + 1e-5

# Apply dilation. Increase 'iterations' to make the mask even fatter.
fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

In [None]:
# Compute the mask where the sum is not zero
# Pixels of the middle axial slice whose signal does not sum to zero
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)


def optimize_chunk(pixels):
    """Posterior-mean parameters for each (i, j) pixel in `pixels`.

    Uses every measurement of the pixel in the middle axial slice, draws 1000
    posterior samples from `Network`, and returns a list of
    ``(i, j, sample mean)`` tuples in input order.
    """
    return [
        (
            i,
            j,
            Network.sample(
                (1000,),
                x=AxcaliberFeatures(bvecs[:],bvals[:],Deltas[:],maskdata_2[i, j,axial_middle, :]),
                show_progress_bars=False,
            ).mean(axis=0),
        )
        for i, j in pixels
    ]


# Split the pixel list into chunks of ChunkSize and process them in parallel
chunked_indices = [
    indices[start:start+ChunkSize]
    for start in range(0, len(indices), ChunkSize)
]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

# Parameter map: one 13-vector per pixel of the slice
ArrShape = mask.shape
NoiseEst = np.zeros(list(ArrShape) + [13])

for chunk in results:
    for i, j, x in chunk:
        NoiseEst[i, j] = x
        # Limit the second-to-last entry to [0, 100] and the third-to-last to [0, 1]
        NoiseEst[i, j,-2] = np.clip(NoiseEst[i, j,-2],0,100)
        NoiseEst[i, j,-3] = np.clip(NoiseEst[i, j,-3],0,1)

# Display copy: NaN outside the mask in every channel ...
NoiseEst2 = np.copy(NoiseEst)
NoiseEst2[~mask] = math.nan

# ... and NaN in channel -2 wherever 1 - channel[-3] is below 0.3
NoiseEst2[(1-NoiseEst2[...,-3])<0.3,-2] = math.nan

In [None]:
# Show 1 - (third-from-last channel) of the SBI map on a 'hot' colour scale
# limited to [0.1, 1]. In the training parameter table that channel is `frac`.
plt.subplots(figsize=(12,12))
im = plt.imshow(1-NoiseEst2[...,-3],vmin=0.1,vmax=1,cmap='hot')
plt.axis('off')
#plt.colorbar()

In [None]:
# NOTE(review): not read anywhere in the visible cells -- presumably gates
# writing outputs to disk; confirm against the rest of the notebook.
Save = False

In [None]:
# Random 13-parameter starting point (the 12 model parameters plus S0) for
# the in-vivo NLLS fit. The draw order fixes the values for this seed.
# NOTE(review): the next cell repeats these statements with the same seed and
# then adds the bounds, so this cell looks redundant.
np.random.seed(133)
S0 = 2000  # NOTE(review): not used below
MD_prior = np.random.rand(1)*0.005
FA_prior = np.random.rand(1)*0.999
DHind_guess = [mat_to_vals(random_diffusion_tensor(m, f)) for m,f in zip(MD_prior,FA_prior)]

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
# Fixed starting orientation: cos(theta) = 0 and phi = 0
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T
S0_guess =np.random.rand()*2475+25

frac_guess = np.random.rand()
# mean_guess is NOT redrawn here; it must already exist from an earlier cell.
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess,S0_guess]).squeeze()

In [None]:
# Starting point and box constraints for the 13-parameter in-vivo NLLS fit
# (the 12 model parameters plus S0). Random draws keep their original order.
np.random.seed(133)
S0 = 2000
MD_prior = np.random.rand(1)*0.005
FA_prior = np.random.rand(1)*0.999
DHind_guess = [
    mat_to_vals(random_diffusion_tensor(m, f))
    for m,f in zip(MD_prior,FA_prior)
]

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s

# Fixed starting orientation: cos(theta) = 0 and phi = 0
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T
S0_guess =np.random.rand()*2475+25

frac_guess = np.random.rand()
# mean_guess is taken from an earlier cell; it is not redrawn here.
guess = np.column_stack(
    [Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess,S0_guess]
).squeeze()

# One [lower, upper] pair per parameter, in the column order of `guess`;
# transposed so that row 0 holds the lower and row 1 the upper bounds.
bounds = np.array(
    [
        [0,np.pi/2],            # theta
        [-np.pi,np.pi],         # phi
        [0,5e-3],               # Dpar
        [0,5e-3],               # Dperp
    ]
    + [[-5e-3,5e-3]]*6          # hindered-tensor entries (columns 4-9)
    + [
        [0,1],                  # frac
        [1e-4,0.005+1e-4],      # mean
        [25,2500],              # S0
    ],
    dtype=float,
).T

# Per-Delta groups of n_pts+1 measurements
_n_shell = n_pts + 1
bve_split = [bvecs[:_n_shell],bvecs[_n_shell:2*_n_shell],bvecs[2*_n_shell:]]
bva_split = [bvals[:_n_shell],bvals[_n_shell:2*_n_shell],bvals[2*_n_shell:]]

In [None]:
# Compute the mask where the sum is not zero
# Pixels of the middle axial slice whose signal does not sum to zero
mask = maskdata[:, :, axial_middle, :].sum(axis=-1) != 0
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded least-squares fit of pixel (i, j) of the middle axial slice.

    Starts from the shared `guess` and returns ``(i, j, fitted parameters)``.
    """
    pixel_signal = maskdata_2[i, j, axial_middle, :]
    fit = sp.optimize.least_squares(
        residuals,
        guess,
        args=[pixel_signal,bve_split,bva_split,Delta,True],
        bounds=bounds,
        verbose=0,
        xtol=1e-12,
        gtol=1e-12,
        ftol=1e-12,
        jac='3-point',
    )
    return i, j, fit.x


ArrShape = mask.shape

# Use joblib to run the per-pixel fits in parallel
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices,position=0,leave=True)
)

# Parameter map: one 13-vector per pixel of the slice
NoiseEst_LS = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_LS[i, j] = x

# Display copy: NaN outside the mask in every channel ...
NoiseEst2_LS = np.copy(NoiseEst_LS)
NoiseEst2_LS[~mask] = math.nan

# ... and NaN in channel -2 wherever 1 - channel[-3] is below 0.3
NoiseEst2_LS[(1-NoiseEst2_LS[...,-3])<0.3,-2] = math.nan

In [None]:
# Map of 1 - parameter[-3] from the full-protocol least-squares fit.
fig, ax = plt.subplots(figsize=(12, 12))
fraction_map = 1 - NoiseEst2_LS[..., -3]
im = ax.imshow(fraction_map, vmin=0.0, vmax=1, cmap='hot')
cbar = plt.colorbar(im, fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=32)
ax.axis('off')

In [None]:
# Voxels of the middle axial slice that carry any non-zero signal.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_chunk(pixels):
    """Posterior-mean estimate for each (i, j) in `pixels`.

    Only the measurements listed in DevilIndices are passed to the network.
    Returns a list of (i, j, mean of 1000 posterior samples).
    """
    out = []
    for i, j in pixels:
        features = AxcaliberFeatures(
            bvecs[DevilIndices],
            bvals[DevilIndices],
            Deltas[DevilIndices],
            maskdata_2[i, j, axial_middle, DevilIndices],
        )
        samples = Network.sample((1000,), x=features, show_progress_bars=False)
        out.append((i, j, samples.mean(axis=0)))
    return out


# Hand each worker a batch of ChunkSize voxels instead of one voxel per task.
chunked_indices = [
    indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)
]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape
NoiseEst = np.zeros(list(ArrShape) + [13])

# Scatter the per-voxel estimates back into an image-shaped array.
for chunk in results:
    for i, j, x in chunk:
        NoiseEst[i, j] = x

# Clamp parameters -2 and -3 to [0, 100] and [0, 1]; untouched voxels are 0
# and stay 0 under both clips.
NoiseEst[..., -2] = np.clip(NoiseEst[..., -2], 0, 100)
NoiseEst[..., -3] = np.clip(NoiseEst[..., -3], 0, 1)

# Display copy: blank everything outside the mask, then hide parameter -2
# wherever 1 - parameter[-3] falls below 0.3.
NoiseEst2_min = np.copy(NoiseEst)
NoiseEst2_min[~mask] = math.nan
NoiseEst2_min[(1 - NoiseEst2_min[..., -3]) < 0.3, -2] = math.nan

In [None]:
# Lightly smoothed map of 1 - parameter[-3] from the reduced-protocol SBI fit.
fig, ax = plt.subplots(figsize=(12, 12))
temp = gaussian_filter(1 - NoiseEst2_min[..., -3], sigma=0.5)
im = ax.imshow(temp, vmin=0.1, vmax=1, cmap='hot')
ax.axis('off')
#plt.colorbar()

In [None]:
# Reduced protocol split into three blocks of 7 volumes each.
bve_splitD = [bvecs_Dev[:7],bvecs_Dev[7:14],bvecs_Dev[14:]]
bva_splitD = [bvals_Dev[:7],bvals_Dev[7:14],bvals_Dev[14:]]

# Voxels of the middle axial slice that carry any non-zero signal.
mask = np.sum(maskdata[:, :, axial_middle, :], axis=-1) != 0

# Get the indices where mask is True
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded least-squares fit of voxel (i, j) using the reduced protocol.

    Returns (i, j, fitted 13-parameter vector).
    """
    # The trailing True is the fifth `residuals` argument, as in the
    # full-protocol fit. It used to sit inside the index expression, where
    # NumPy silently accepted it as a scalar boolean index and the flag was
    # never passed on.
    # NOTE(review): the full-protocol fit and the reduced SBI fit for this
    # slice read maskdata_2, this one reads maskdata -- confirm intended.
    result = sp.optimize.least_squares(residuals, guess, args=[maskdata[i, j,axial_middle, DevilIndices],bve_splitD,bva_splitD,Delta,True],
                          bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x


ArrShape = mask.shape

# Use joblib to parallelize the optimization tasks
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices,position=0,leave = True)
)

# Scatter the per-voxel fits back into an image-shaped array.
NoiseEst_LS_Min = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_LS_Min[i, j] = x

In [None]:
# Lightly smoothed map of 1 - parameter[-3] from the reduced-protocol LS fit.
fig, ax = plt.subplots(figsize=(12, 12))
temp = gaussian_filter(1 - NoiseEst_LS_Min[..., -3], sigma=0.5)
im = ax.imshow(temp, vmin=0.0, vmax=1, cmap='hot')
cbar = plt.colorbar(im, fraction=0.035, pad=-0.1)
cbar.ax.tick_params(labelsize=32)
ax.axis('off')

## g

In [None]:
# Voxels of slice 54 (second axis) that carry any non-zero signal.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_chunk(pixels):
    """Posterior-mean estimate for each (i, j) in `pixels`, full protocol.

    Returns a list of (i, j, mean of 1000 posterior samples).
    """
    out = []
    for i, j in pixels:
        features = AxcaliberFeatures(bvecs[:], bvals[:], Deltas[:], maskdata[i, 54, j, :])
        posterior_samples_1 = Network.sample((1000,), x=features, show_progress_bars=False)
        out.append((i, j, posterior_samples_1.mean(axis=0)))
    return out


# Hand each worker a batch of ChunkSize voxels instead of one voxel per task.
chunked_indices = [
    indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)
]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape
NoiseEst_CC = np.zeros(list(ArrShape) + [13])

# Scatter the per-voxel estimates back into an image-shaped array.
for chunk in results:
    for i, j, x in chunk:
        NoiseEst_CC[i, j] = x

# Clamp parameters -2 and -3 to [0, 100] and [0, 1]; untouched voxels are 0
# and stay 0 under both clips.
NoiseEst_CC[..., -2] = np.clip(NoiseEst_CC[..., -2], 0, 100)
NoiseEst_CC[..., -3] = np.clip(NoiseEst_CC[..., -3], 0, 1)
NoiseEst2_CC = np.copy(NoiseEst_CC)

# Both masks are derived before NaNs are written, so the comparisons see
# real values everywhere. mask1 comes from another cell.
one_minus_frac = 1 - NoiseEst2_CC[..., -3]
comb_mask = mask1.astype(bool) * (one_minus_frac > 0.1)
mask_CC = one_minus_frac < 0.3

# Display copy: blank everything outside the signal mask, then hide
# parameter -2 outside comb_mask.
NoiseEst2_CC[~mask] = math.nan
NoiseEst2_CC[~comb_mask, -2] = math.nan

In [None]:
# Voxels of slice 54 (second axis) that carry any non-zero signal.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0
indices = np.argwhere(mask)


def optimize_chunk(pixels):
    """Posterior-mean estimate for each (i, j) in `pixels`, reduced protocol.

    Only the measurements listed in DevilIndices are passed to the network.
    Returns a list of (i, j, mean of 1000 posterior samples).
    """
    out = []
    for i, j in pixels:
        features = AxcaliberFeatures(
            bvecs[DevilIndices],
            bvals[DevilIndices],
            Deltas[DevilIndices],
            maskdata[i, 54, j, DevilIndices],
        )
        posterior_samples_1 = Network.sample((1000,), x=features, show_progress_bars=False)
        out.append((i, j, posterior_samples_1.mean(axis=0)))
    return out


# Hand each worker a batch of ChunkSize voxels instead of one voxel per task.
chunked_indices = [
    indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)
]
results = Parallel(n_jobs=8)(
    delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
)

ArrShape = mask.shape
NoiseEst_Min_CC = np.zeros(list(ArrShape) + [13])

# Scatter the per-voxel estimates back into an image-shaped array.
for chunk in results:
    for i, j, x in chunk:
        NoiseEst_Min_CC[i, j] = x

# Clamp parameters -2 and -3 to [0, 100] and [0, 1]; untouched voxels are 0
# and stay 0 under both clips.
NoiseEst_Min_CC[..., -2] = np.clip(NoiseEst_Min_CC[..., -2], 0, 100)
NoiseEst_Min_CC[..., -3] = np.clip(NoiseEst_Min_CC[..., -3], 0, 1)

# Blank everything outside the signal mask, then hide parameter -2 outside
# comb_mask (computed in the full-protocol SBI cell).
NoiseEst_Min_CC[~mask] = math.nan
NoiseEst_Min_CC[~comb_mask, -2] = math.nan

In [None]:
# Parameter -2 (hot colormap) overlaid on parameter -1 (gray) for the
# full-protocol SBI fit; both are transposed and flipped for display.
fig, ax = plt.subplots(figsize=(12, 12))
background = np.flipud(NoiseEst2_CC[..., -1].T)
overlay = np.flipud(NoiseEst2_CC[..., -2].T)
ax.imshow(background, cmap='gray')
im = ax.imshow(overlay, cmap='hot', vmin=0.001, vmax=0.006)
#cbar = plt.colorbar(im,fraction=0.035, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
#cbar.ax.tick_params(labelsize=14)
ax.axis('off')

In [None]:
# Parameter -2 (hot colormap, two-slope normalisation centred at 0.0042)
# overlaid on parameter -1 (gray) for the reduced-protocol SBI fit.
fig, ax = plt.subplots(figsize=(12, 12))
background = np.flipud(NoiseEst_Min_CC[..., -1].T)
overlay = np.flipud(NoiseEst_Min_CC[..., -2].T)
ax.imshow(background, cmap='gray')
norm = TwoSlopeNorm(vmin=0, vcenter=0.0042, vmax=0.006)
im = ax.imshow(overlay, cmap='hot', norm=norm)
#cbar = plt.colorbar(im,fraction=0.035, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
#cbar.ax.tick_params(labelsize=14)
ax.axis('off')


In [None]:
# Voxels of slice 54 (second axis) that carry any non-zero signal.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Recompute the voxel list from this cell's mask. Previously `indices` was
# whatever the last executed cell left behind, so running cells out of order
# could fit voxels belonging to a different slice or mask.
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded least-squares fit of voxel (i, 54, j) using the full protocol.

    Returns (i, j, fitted 13-parameter vector).
    """
    result = sp.optimize.least_squares(residuals, guess, args=[maskdata[i, 54, j, :],bve_split,bva_split,Delta,True],
                              bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x


ArrShape = mask.shape

# Use joblib to parallelize the optimization tasks
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices,position=0,leave=True)
)

# Scatter the per-voxel fits back into an image-shaped array.
NoiseEst_LS_CC = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_LS_CC[i, j] = x

# Display copy: blank everything outside the signal mask, then hide
# parameter -2 outside comb_mask (computed in the full-protocol SBI cell).
NoiseEst2_LS_CC = np.copy(NoiseEst_LS_CC)
NoiseEst2_LS_CC[~mask] = math.nan
NoiseEst2_LS_CC[~comb_mask,-2] = math.nan

In [None]:
# Voxels of slice 54 (second axis) that carry any non-zero signal.
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Get the indices where mask is True
indices = np.argwhere(mask)


def optimize_pixel_LS(i, j):
    """Bounded least-squares fit of voxel (i, 54, j) using the reduced protocol.

    Returns (i, j, fitted 13-parameter vector).
    """
    # The trailing True is the fifth `residuals` argument, as in the
    # full-protocol fit. It used to sit inside the index expression, where
    # NumPy silently accepted it as a scalar boolean index and the flag was
    # never passed on.
    result = sp.optimize.least_squares(residuals, guess, args=[maskdata[i,54,j, DevilIndices],bve_splitD,bva_splitD,Delta,True],
                          bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x


ArrShape = mask.shape

# Use joblib to parallelize the optimization tasks
results = Parallel(n_jobs=8)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices,position=0,leave=True)
)

# Scatter the per-voxel fits back into an image-shaped array.
NoiseEst_LS_Min_CC = np.zeros(list(ArrShape) + [13])
for i, j, x in results:
    NoiseEst_LS_Min_CC[i, j] = x

# Display copy: blank everything outside the signal mask, then hide
# parameter -2 outside comb_mask (computed in the full-protocol SBI cell).
NoiseEst2_LS_Min_CC = np.copy(NoiseEst_LS_Min_CC)
NoiseEst2_LS_Min_CC[~mask] = math.nan
NoiseEst2_LS_Min_CC[~comb_mask,-2] = math.nan

In [None]:
# Parameter -2 (scaled by 1000, hot colormap) overlaid on parameter -1
# (gray) for the full-protocol least-squares fit.
fig, ax = plt.subplots(figsize=(12, 12))
background = np.flipud(NoiseEst2_LS_CC[..., -1].T)
overlay = np.flipud(NoiseEst2_LS_CC[..., -2].T) * 1000
ax.imshow(background, cmap='gray')
im = ax.imshow(overlay, cmap='hot', vmin=0, vmax=6)
cbar = plt.colorbar(im, fraction=0.03, pad=0.01)
cbar.ax.tick_params(labelsize=32)
ax.axis('off')

In [None]:
# Parameter -2 (scaled by 1000, hot colormap) overlaid on parameter -1
# (gray) for the reduced-protocol least-squares fit.
fig, ax = plt.subplots(figsize=(12, 12))
background = np.flipud(NoiseEst2_LS_Min_CC[..., -1].T)
overlay = np.flipud(NoiseEst2_LS_Min_CC[..., -2].T) * 1000
ax.imshow(background, cmap='gray')
im = ax.imshow(overlay, cmap='hot', vmin=0, vmax=6)
cbar = plt.colorbar(im, fraction=0.03, pad=0.01)
cbar.ax.tick_params(labelsize=32)
ax.axis('off')


## h

In [None]:
# Load three control subjects: gradient table, a per-subject NIfTI mask, and
# the median-Otsu-masked diffusion data with its brain outline.
Dirs = ['Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
Masks = ['mask_055.nii.gz','mask_056.nii.gz','mask_057.nii.gz']
BVecs, BVals, S_masks, Datas, Outlines = [], [], [], [], []

for D, M in tqdm(zip(Dirs, Masks)):
    subject_dir = MSDir + D + '/'

    dat = pmt.read_mat(subject_dir + 'data_loaded.mat')
    BVecs.append(dat['direction'])
    BVals.append(dat['bval'])

    # Only the mask array is kept; the affine and image object are discarded.
    m, _, _ = load_nifti(subject_dir + M, return_img=True)
    S_masks.append(m)

    data = dat['data']
    axial_middle = data.shape[2] // 2

    # md: masked data, mk: the binary mask produced by median_otsu.
    md, mk = median_otsu(
        data,
        vol_idx=range(0, 10),
        median_radius=5,
        numpass=1,
        autocrop=False,
        dilate=2,
    )
    Datas.append(md)
    Outlines.append(mk)

In [None]:
def _pick_spread_directions(bve, bva, start, stop, bval, n_select=3):
    """Pick `n_select` well-separated directions of one b-value in bve[start:stop].

    Greedy farthest-point selection: start from the first matching direction,
    then repeatedly add the candidate whose smallest Euclidean distance to the
    already selected directions is largest.

    Parameters:
        bve: gradient directions, one row per volume.
        bva: b-values, one entry per volume.
        start, stop: bounds of the acquisition block (stop may be None).
        bval: b-value to select from.
        n_select: number of directions to return.

    Returns:
        List of `n_select` integer indices into the full `bve`/`bva` arrays.

    Raises:
        ValueError: if the block has fewer than `n_select` matching volumes.
    """
    # Positions are tracked explicitly rather than recovered by comparing
    # vectors against the whole table: a value match returns the first
    # occurrence, which is the wrong volume whenever a direction is repeated
    # in another block or shell.
    candidates = start + np.flatnonzero(bva[slice(start, stop)] == bval)
    if len(candidates) < n_select:
        raise ValueError(
            f"need {n_select} volumes with b={bval} in [{start}:{stop}], found {len(candidates)}"
        )
    distance_matrix = squareform(pdist(bve[candidates]))

    selected = [0]
    while len(selected) < n_select:
        remaining = [k for k in range(len(candidates)) if k not in selected]
        # Distance from each remaining candidate to its nearest selected one.
        min_distances = distance_matrix[np.ix_(remaining, selected)].min(axis=1)
        selected.append(remaining[int(np.argmax(min_distances))])
    return [int(candidates[k]) for k in selected]


# Build, per subject, a reduced protocol: for each of the three acquisition
# blocks ([0:91], [91:182], [182:]) keep 3 directions at b=2000 and 3 at
# b=4000, plus one extra volume per block (7 volumes per block, 21 in total).
IndxArr  = []
BVecsDev = []
BValsDev = []
for bve,bva in zip(BVecs,BVals):
    block_picks = []
    for start, stop in [(0, 91), (91, 182), (182, None)]:
        picks = []
        for bval in (2000, 4000):
            picks += _pick_spread_directions(bve, bva, start, stop, bval)
        block_picks.append(picks)
    true_indices1, true_indices2, true_indices3 = block_picks

    # NOTE(review): the extra volumes are taken at 0, n_pts and n_pts+1,
    # whereas the blocks above start at 0, 91 and 182 -- confirm these are
    # the intended per-block reference volumes.
    DevIndices = [0] + true_indices1 + [n_pts] +  true_indices2 + [n_pts+1] + true_indices3
    bvecs_Dev = bve[DevIndices]
    bvals_Dev = bva[DevIndices]

    IndxArr.append(DevIndices)
    BVecsDev.append(bvecs_Dev)
    BValsDev.append(bvals_Dev)

In [None]:
# Full-protocol SBI estimates for each subject on its chosen slice.
Full_SBI = []
for kk, (D, sl, sma) in enumerate(zip(Datas, [54, 52, 54], S_masks)):
    # Subject-specific mask restricted to slice `sl` of the second axis.
    mask = sma[:, sl, :]
    indices = np.argwhere(mask)

    def optimize_pixel(i, j):
        """Return (i, j, mean of 1000 posterior samples) for voxel (i, sl, j)."""
        torch.manual_seed(10)  # If required
        features = AxcaliberFeatures(BVecs[kk], BVals[kk], Deltas, D[i, sl, j, :])
        posterior_samples_1 = Network.sample((1000,), x=features, show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)

    ArrShape = mask.shape

    # One joblib task per voxel.
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )

    # Scatter the per-voxel estimates back into an image-shaped array.
    NoiseEst = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)

In [None]:
# Reduced-protocol SBI estimates for each subject on its chosen slice.
Min_SBI = []
for kk, (D, sl, sma) in enumerate(zip(Datas, [54, 52, 54], S_masks)):
    # Subject-specific mask restricted to slice `sl` of the second axis.
    mask = sma[:, sl, :]
    indices = np.argwhere(mask)

    def optimize_pixel(i, j):
        """Return (i, j, mean of 1000 posterior samples) for voxel (i, sl, j).

        Only the volumes listed in IndxArr[kk] are passed to the network.
        """
        torch.manual_seed(10)  # If required
        keep = IndxArr[kk]
        features = AxcaliberFeatures(BVecs[kk][keep], BVals[kk][keep], Deltas[keep], D[i, sl, j, keep])
        posterior_samples_1 = Network.sample((1000,), x=features, show_progress_bars=False)
        return i, j, posterior_samples_1.mean(axis=0)

    ArrShape = mask.shape

    # One joblib task per voxel.
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel)(i, j)
        for i, j in tqdm(indices, position=0, leave=True)
    )

    # Scatter the per-voxel estimates back into an image-shaped array.
    NoiseEst = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst[i, j] = x

    Min_SBI.append(NoiseEst)

In [None]:
# Full-protocol least-squares fits for each subject on its chosen slice.
Full_LS = []
for kk, (D, sl, sma) in enumerate(zip(Datas, [54, 52, 54], S_masks)):
    # Subject-specific mask restricted to slice `sl` of the second axis.
    mask = sma[:, sl, :]
    indices = np.argwhere(mask)

    # Split this subject's gradient table into three consecutive blocks of
    # n_pts+1 volumes.
    first_end = n_pts + 1
    second_end = 2 * (n_pts + 1)
    bve_split_kk = [BVecs[kk][:first_end], BVecs[kk][first_end:second_end], BVecs[kk][second_end:]]
    bva_split_kk = [BVals[kk][:first_end], BVals[kk][first_end:second_end], BVals[kk][second_end:]]

    def optimize_pixel_LS(i, j):
        """Bounded least-squares fit of voxel (i, sl, j); returns (i, j, params)."""
        voxel_signal = D[i, sl, j, :]
        result = sp.optimize.least_squares(
            residuals,
            guess,
            args=[voxel_signal, bve_split_kk, bva_split_kk, Delta, True],
            bounds=bounds,
            verbose=0,
            xtol=1e-12,
            gtol=1e-12,
            ftol=1e-12,
            jac='3-point',
        )
        return i, j, result.x

    ArrShape = mask.shape

    # One joblib task per voxel.
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j)
        for i, j in tqdm(indices, position=0, leave=True)
    )

    # Scatter the per-voxel fits back into an image-shaped array.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Full_LS.append(NoiseEst_LS)

In [None]:
# Reduced-protocol least-squares fits for each subject on its chosen slice.
Min_LS = []
for kk,(D,sl,sma) in enumerate(zip(Datas,[54,52,54],S_masks)):
    # Subject-specific mask restricted to slice `sl` of the second axis.
    mask = sma[:,sl,:]

    # Get the indices where mask is True
    indices = np.argwhere(mask)

    # The reduced protocol holds 21 volumes, 7 per block, so the block
    # boundaries are 7 and 14. The previous 7:13 / 13: split produced blocks
    # of 7, 6 and 8 volumes and paired one volume with the wrong Delta.
    bve_splitd_kk = [BVecsDev[kk][:7],BVecsDev[kk][7:14],BVecsDev[kk][14:]]
    bva_splitd_kk = [BValsDev[kk][:7],BValsDev[kk][7:14],BValsDev[kk][14:]]

    def optimize_pixel_LS(i, j):
        """Bounded least-squares fit of voxel (i, sl, j); returns (i, j, params).

        Only the volumes listed in IndxArr[kk] are fitted.
        """
        result = sp.optimize.least_squares(residuals, guess, args=[D[i, sl, j, IndxArr[kk]],bve_splitd_kk,bva_splitd_kk,Delta,True],
                                  bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
        return i, j, result.x

    ArrShape = mask.shape

    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices,position=0,leave=True)
    )

    # Scatter the per-voxel fits back into an image-shaped array.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Min_LS.append(NoiseEst_LS)

In [None]:
# Per-subject comparison masks: dilated subject mask AND a threshold on
# parameter -4 of the full-protocol SBI estimate.
# Each entry is (subject index, slice index, threshold on 1 - parameter[-4]).
# NOTE(review): other cells threshold parameter -3; confirm -4 is intended.
CMasks = []
for kk, d, threshold in [(0, 54, 0.1), (1, 52, 0), (2, 54, 0.3)]:
    # Blank the estimate outside the median-Otsu outline so that the
    # comparisons below are False there.
    temp = np.copy(Full_SBI[kk])
    temp[~Outlines[kk][:, d, :]] = math.nan

    # Binarise the subject mask on this slice.
    subject_slice = S_masks[kk][:, d, :]
    mask1 = np.ones_like(subject_slice)
    mask1[subject_slice == 0] = 0

    # Apply dilation. Increase 'iterations' to make the mask even fatter.
    fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

    param = temp[..., -4]
    CMasks.append(fat_mask * ((1 - param) > threshold) * (param > 0))

In [None]:
# Box plots of |full - reduced| for parameter -2 (scaled by 1000) inside each
# subject's comparison mask: SBI on the left, least squares on the right.
fig, ax = plt.subplots(figsize=(8, 4))

# SBI group.
g_pos = np.array([0, 0.25, 0.5])
colors = ['lightseagreen'] * 3
colors2 = ['paleturquoise'] * 3
for i, (pos, edge, fill) in enumerate(zip(g_pos, colors, colors2)):
    reduced = Min_SBI[i][CMasks[i]][:, -2]
    full = Full_SBI[i][CMasks[i]][:, -2]
    y_dat = 1000 * abs(reduced - full).squeeze()
    BoxPlots(y_dat, [pos], [edge], [fill], ax, scatter=True)

# Least-squares group.
g_pos = np.array([2, 2.25, 2.5])
colors = ['darkorange'] * 3
colors2 = ['peachpuff'] * 3
y_data = [
    1000 * abs(Full_LS[s][CMasks[s]][:, -2] - Min_LS[s][CMasks[s]][:, -2])
    for s in range(3)
]
for i, (pos, edge, fill) in enumerate(zip(g_pos, colors, colors2)):
    y_dat = y_data[i]
    BoxPlots(y_dat, [pos], [edge], [fill], ax, scatter=True)

ax.set_xticks([0.25, 2.25], ['SBI Comp', 'NLLS Comp'], fontsize=36)

# Scientific notation on the y axis, with matching font sizes.
ax.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax.tick_params(axis='y', labelsize=24)
ax.yaxis.get_offset_text().set_fontsize(24)

## i

In [None]:
# Load the extended cohort: gradient tables plus resliced, median-Otsu-masked
# data shifted so that every voxel's signal is strictly positive.
Dirs = ['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
BVecs, BVals, S_masks, Datas, Outlines, axial_middles = [], [], [], [], [], []

for D in tqdm(Dirs, position=0, leave=True):
    F = pmt.read_mat(MSDir + D + '/data_loaded.mat')
    BVecs.append(F['direction'])
    BVals.append(F['bval'])

    # Placeholder affine handed to reslice; the returned affine is not used
    # again in this cell.
    affine = np.ones((4, 4))
    data, affine = reslice(F['data'], affine, (2, 2, 2), (2.5, 2.5, 2.5))
    axial_middle = data.shape[2] // 2

    # md: masked data, mk: the binary mask produced by median_otsu.
    md, mk = median_otsu(
        data,
        vol_idx=range(0, 10),
        median_radius=5,
        numpass=1,
        autocrop=False,
        dilate=2,
    )

    # Per-voxel minimum over volumes, capped at 0 from above; voxels where it
    # is <= 0 are lifted by its magnitude plus a small epsilon.
    floor = np.clip(md.min(axis=-1), -np.inf, 0)
    needs_shift = floor <= 0
    maskdata_2 = np.copy(md)
    maskdata_2[needs_shift] = md[needs_shift] + np.abs(floor)[needs_shift, None] + 1e-5

    Datas.append(maskdata_2)
    axial_middles.append(axial_middle)
    Outlines.append(mk)

In [None]:
def _pick_spread_directions(bve, bva, start, stop, bval, n_select=3):
    """Pick `n_select` well-separated directions of one b-value in bve[start:stop].

    Greedy farthest-point selection: start from the first matching direction,
    then repeatedly add the candidate whose smallest Euclidean distance to the
    already selected directions is largest.

    Parameters:
        bve: gradient directions, one row per volume.
        bva: b-values, one entry per volume.
        start, stop: bounds of the acquisition block (stop may be None).
        bval: b-value to select from.
        n_select: number of directions to return.

    Returns:
        List of `n_select` integer indices into the full `bve`/`bva` arrays.

    Raises:
        ValueError: if the block has fewer than `n_select` matching volumes.
    """
    # Positions are tracked explicitly rather than recovered by comparing
    # vectors against the whole table: a value match returns the first
    # occurrence, which is the wrong volume whenever a direction is repeated
    # in another block or shell.
    candidates = start + np.flatnonzero(bva[slice(start, stop)] == bval)
    if len(candidates) < n_select:
        raise ValueError(
            f"need {n_select} volumes with b={bval} in [{start}:{stop}], found {len(candidates)}"
        )
    # Distances must be computed between the candidates themselves. This cell
    # previously used pdist(bve) over the whole table while indexing it with
    # subset-relative positions, so it compared unrelated volumes.
    distance_matrix = squareform(pdist(bve[candidates]))

    selected = [0]
    while len(selected) < n_select:
        remaining = [k for k in range(len(candidates)) if k not in selected]
        # Distance from each remaining candidate to its nearest selected one.
        min_distances = distance_matrix[np.ix_(remaining, selected)].min(axis=1)
        selected.append(remaining[int(np.argmax(min_distances))])
    return [int(candidates[k]) for k in selected]


# Build, per subject, a reduced protocol: for each of the three acquisition
# blocks ([0:91], [91:182], [182:]) keep 3 directions at b=2000 and 3 at
# b=4000, plus one extra volume per block (7 volumes per block, 21 in total).
IndxArr  = []
BVecsDev = []
BValsDev = []
for bve,bva in zip(BVecs,BVals):
    block_picks = []
    for start, stop in [(0, 91), (91, 182), (182, None)]:
        picks = []
        for bval in (2000, 4000):
            picks += _pick_spread_directions(bve, bva, start, stop, bval)
        block_picks.append(picks)
    true_indices1, true_indices2, true_indices3 = block_picks

    # NOTE(review): the extra volumes are taken at 0, n_pts and n_pts+1,
    # whereas the blocks above start at 0, 91 and 182 -- confirm these are
    # the intended per-block reference volumes.
    DevIndices = [0] + true_indices1 + [n_pts] +  true_indices2 + [n_pts+1] + true_indices3
    bvecs_Dev = bve[DevIndices]
    bvals_Dev = bva[DevIndices]

    IndxArr.append(DevIndices)
    BVecsDev.append(bvecs_Dev)
    BValsDev.append(bvals_Dev)

In [None]:
# Full-protocol SBI estimates on axial slice 48 for every subject in Datas.
Full_SBI_Extra = []
for kk, (D, sl) in enumerate(zip(Datas, [48]*8)):
    # Voxels of the slice that carry any non-zero signal.
    mask = np.sum(D[:, :, sl, :], axis=-1) != 0
    indices = np.argwhere(mask)

    def optimize_chunk(pixels):
        """Return a list of (i, j, mean of 1000 posterior samples) for `pixels`."""
        out = []
        for i, j in pixels:
            features = AxcaliberFeatures(BVecs[kk], BVals[kk], Deltas, D[i, j, sl, :])
            posterior_samples_1 = Network.sample((1000,), x=features, show_progress_bars=False)
            out.append((i, j, posterior_samples_1.mean(axis=0)))
        return out

    # Hand each worker a batch of ChunkSize voxels instead of one per task.
    chunked_indices = [
        indices[start:start+ChunkSize] for start in range(0, len(indices), ChunkSize)
    ]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    ArrShape = mask.shape
    NoiseEst = np.zeros(list(ArrShape) + [13])

    # Scatter the per-voxel estimates back into an image-shaped array.
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    Full_SBI_Extra.append(NoiseEst)

In [None]:
# Same posterior-mean SBI maps as the full-protocol run, but restricted to the
# reduced set of measurements listed in IndxArr[kk].
Min_SBI_Extra = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Foreground is still defined from the complete signal, not the subset.
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # Sub-select the protocol and the slice data once per dataset.
    Arr = D[:, :, sl, IndxArr[kk]]
    bvecs_sub = BVecs[kk][IndxArr[kk]]
    bvals_sub = BVals[kk][IndxArr[kk]]
    deltas_sub = Deltas[IndxArr[kk]]

    def optimize_chunk(pixels):
        """Return (i, j, posterior mean) for every pixel coordinate in `pixels`."""
        out = []
        for i, j in pixels:
            features = AxcaliberFeatures(bvecs_sub, bvals_sub, deltas_sub, Arr[i, j, :])
            draws = Network.sample((1000,), x=features, show_progress_bars=False)
            out.append((i, j, draws.mean(axis=0)))
        return out

    chunked_indices = [
        indices[start:start + ChunkSize]
        for start in range(0, len(indices), ChunkSize)
    ]
    results = Parallel(n_jobs=8)(
        delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices)
    )

    # Background pixels keep the zero fill value.
    ArrShape = mask.shape
    NoiseEst = np.zeros(list(ArrShape) + [13])
    for chunk in results:
        for i, j, x in chunk:
            NoiseEst[i, j] = x

    Min_SBI_Extra.append(NoiseEst)

In [None]:
# Bounded non-linear least-squares fit of every foreground pixel in slice 48
# of each dataset, using every measurement of the protocol.
Full_LS_extra = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Foreground = pixels whose signal summed over all measurements is non-zero.
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # The protocol is cut into three consecutive segments; the first two are
    # n_pts + 1 measurements long and the third takes whatever remains.
    seg_len = n_pts + 1
    segments = (
        slice(None, seg_len),
        slice(seg_len, 2 * seg_len),
        slice(2 * seg_len, None),
    )
    bve_split_kk = [BVecs[kk][seg] for seg in segments]
    bva_split_kk = [BVals[kk][seg] for seg in segments]

    def optimize_pixel_LS(i, j):
        """Fit pixel (i, j) starting from `guess`; return (i, j, fitted parameters)."""
        fit = sp.optimize.least_squares(
            residuals,
            guess,
            args=[D[i, j, sl, :], bve_split_kk, bva_split_kk, Delta, True],
            bounds=bounds,
            verbose=0,
            jac='3-point',
        )
        return i, j, fit.x

    ArrShape = mask.shape

    # One joblib task per foreground pixel.
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices, position=0, leave=True)
    )

    # Background pixels keep the zero fill value.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Full_LS_extra.append(NoiseEst_LS)

In [None]:
# Bounded non-linear least-squares fit on the reduced protocol (measurements
# IndxArr[kk]) for slice 48 of each dataset, with tight solver tolerances.
Min_LS_extra = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Foreground is defined from the complete signal, not the subset.
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # NOTE(review): the reduced protocol is cut at positions 7 and 13 here,
    # whereas the test-retest cell of this notebook cuts the same kind of
    # index list at 7 and 14 -- confirm which split matches BVecsDev.
    cuts = (slice(None, 7), slice(7, 13), slice(13, None))
    bve_splitd_kk = [BVecsDev[kk][c] for c in cuts]
    bva_splitd_kk = [BValsDev[kk][c] for c in cuts]

    def optimize_pixel_LS(i, j):
        """Fit pixel (i, j) on the reduced protocol; return (i, j, fitted parameters)."""
        fit = sp.optimize.least_squares(
            residuals,
            guess,
            args=[D[i, j, sl, IndxArr[kk]], bve_splitd_kk, bva_splitd_kk, Delta, True],
            bounds=bounds,
            verbose=0,
            xtol=1e-12,
            gtol=1e-12,
            ftol=1e-12,
            jac='3-point',
        )
        return i, j, fit.x

    ArrShape = mask.shape

    # One joblib task per foreground pixel.
    results = Parallel(n_jobs=8)(
        delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices, position=0, leave=True)
    )

    # Background pixels keep the zero fill value.
    NoiseEst_LS = np.zeros(list(ArrShape) + [13])
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Min_LS_extra.append(NoiseEst_LS)

In [None]:
# Load one white-matter mask per subject, reorient it and reslice it from
# 2 mm to 2.5 mm isotropic voxels.
WMDir = '../../MS_data/WM_masks/'
WMs = []
subject_names = ['NMSS_11_1year', 'NMSS_15', 'NMSS_16', 'NMSS_18', 'NMSS_19',
                 'Ctrl055_R01_28', 'Ctrl056_R01_29', 'Ctrl057_R01_30']
for i, Name in tqdm(enumerate(subject_names)):

    # Every file whose name contains the subject identifier is loaded.
    for x in os.listdir(WMDir):
        if Name not in x:
            continue
        WM, affine, img = load_nifti(WMDir + x, return_img=True)
        #WM, affine = reslice(WM, affine, (2,2,2), (2.5,2.5,2.5))
        # The first five subjects only need the axis swap and a left-right
        # flip; the remaining ones are flipped up-down as well.
        swapped = np.swapaxes(WM, 0, 1)
        if i >= 5:
            swapped = np.flipud(swapped)
        WM_t = np.fliplr(swapped)
        WM_t, _ = reslice(WM_t, affine, (2, 2, 2), (2.5, 2.5, 2.5))
        WMs.append(WM_t)

In [None]:
# Build the 13-element starting point for the least-squares fits. The seed is
# fixed, so the order of the random draws below must not change.
np.random.seed(133)
S0 = 2000

# Hindered-compartment tensor drawn from a random (MD, FA) pair.
MD_prior = 0.005 * np.random.rand(1)
FA_prior = 0.999 * np.random.rand(1)
DHind_guess = [
    mat_to_vals(random_diffusion_tensor(md, fa))
    for md, fa in zip(MD_prior, FA_prior)
]

# Restricted-compartment diffusivities.
Dpar_guess = 1e-3 * np.random.rand()            # mm^2/s
Dperp_guess = 1e-3 * np.random.rand()           # mm^2/s

# Fixed starting orientation: cos(theta) = 0 gives theta = pi/2, with phi = 0.
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.array([[theta, phi]])

S0_guess = 25 + 2475 * np.random.rand()
mean_guess = 1e-4 + 0.005 * np.random.rand()
frac_guess = np.random.rand()

# Column order: angles (2), Dpar, Dperp, hindered tensor values, fraction,
# mean, S0.
guess = np.column_stack([
    Angs_guess,
    Dpar_guess,
    Dperp_guess,
    DHind_guess,
    frac_guess,
    mean_guess,
    S0_guess,
]).squeeze()
# Box constraints for the 13 fitted parameters, one (lower, upper) pair per
# parameter, transposed into the (2, 13) lower-row / upper-row layout.
bounds = np.array(
    [
        (0, np.pi / 2),             # 0
        (-np.pi, np.pi),            # 1
        (0, 5e-3),                  # 2
        (0, 5e-3),                  # 3
        *[(-5e-3, 5e-3)] * 6,       # 4-9
        (0, 1),                     # 10
        (1e-4, 0.005 + 1e-4),       # 11
        (25, 2500),                 # 12
    ],
    dtype=float,
).T

In [None]:
# Test-retest analysis. For every subject in Names: load all sessions found in
# RetestDir, rigidly register them to the first session, and estimate the 13
# model parameters on slice AM (halfway through the brain mask's extent along
# the last axis) with SBI and with NLLS, on both the full protocol and a
# reduced protocol of 7 measurements per diffusion-time block.
#
# Results are collected per subject in FullSBI, MinSBI, FullNLLS and MinNLLS;
# each entry is a list with one (X, Y, 13) map per session.


def _farthest_point_subset(vectors, n_select=3):
    """Greedily pick `n_select` well-separated rows of `vectors`.

    Row 0 is always taken first; afterwards the row whose smallest Euclidean
    distance to the rows chosen so far is largest is added, until `n_select`
    rows are chosen. Ties go to the lowest row position.

    Returns the chosen row positions (into `vectors`) in selection order.
    Raises ValueError when `vectors` has fewer than `n_select` rows.
    """
    n_vec = len(vectors)
    if n_vec < n_select:
        raise ValueError(
            f"need at least {n_select} gradient directions, got {n_vec}"
        )
    # Distances are computed between the candidate rows themselves, so the
    # row positions used below index this matrix directly.
    dist = squareform(pdist(vectors))
    chosen = [0]
    while len(chosen) < n_select:
        remaining = [k for k in range(n_vec) if k not in chosen]
        gap = dist[np.ix_(remaining, chosen)].min(axis=1)
        chosen.append(remaining[int(np.argmax(gap))])
    return chosen


def _reduced_protocol_indices(bve, bva, n_pts, shells=(2000, 4000), n_per_shell=3):
    """Return the measurement indices that make up the reduced protocol.

    The protocol is treated as three consecutive blocks: the first two are
    `n_pts` measurements long and the third runs to the end. For each block
    the result holds the block's first measurement followed, for every b-value
    in `shells`, by `n_per_shell` spread-out directions of that shell.

    Indices are computed from the position of each measurement inside its own
    block and shell, so a direction that also occurs in another block or shell
    cannot be mistaken for it.
    """
    picked = []
    for block in range(3):
        start = int(block * n_pts)
        stop = start + int(n_pts) if block < 2 else len(bva)
        picked.append(start)
        for shell in shells:
            shell_idx = start + np.flatnonzero(bva[start:stop] == shell)
            local = _farthest_point_subset(bve[shell_idx], n_per_shell)
            picked.extend(int(k) for k in shell_idx[local])
    return picked


FullSBI = []
FullNLLS = []
MinSBI = []
MinNLLS = []
for jj, N in enumerate(Names):
    # Every directory of this subject; sorted so the first one is the reference.
    Subfiles = []
    for x in os.listdir(RetestDir):
        if N in x:
            print(x)
            Subfiles.append(x)
    Subfiles = sorted(Subfiles)

    # ------------------------------------------------------------------
    # Reference (first) session
    # ------------------------------------------------------------------
    S = Subfiles[0]
    MatDir = RetestDir + S
    F = pmt.read_mat(MatDir + '/data_loaded.mat')

    bvecs = F['direction']
    bvals = F['bval']
    Deltas = F['deltas']
    data = F['data']

    # Drop every measurement with Delta == 42, then rescale Delta by 1e-3
    # (presumably ms -> s; confirm against the acquisition).
    keep = Deltas != 42
    bvecs = bvecs[keep]
    bvals = bvals[keep]
    data = data[..., keep]
    Deltas = Deltas[keep] * 0.001
    n_pts = np.sum(Deltas == Deltas[0])
    bve = np.copy(bvecs)
    bva = np.copy(bvals)
    affine1 = np.eye(4)

    data, affine1 = reslice(data, affine1, (2, 2, 2), (2.5, 2.5, 2.5))
    _, maskCut = median_otsu(data, vol_idx=range(10, 80), autocrop=False)

    # Slice to analyse: midpoint of the mask's extent along the last axis.
    mask_coords = np.argwhere(maskCut)
    min_coords = mask_coords.min(axis=0)
    max_coords = mask_coords.max(axis=0)
    AM = (max_coords[-1] + min_coords[-1]) // 2

    # Unit-normalise the directions; zero-length rows become NaN and are
    # reset to zero.
    bvecs = (bvecs.T / np.linalg.norm(bvecs, axis=1)).T
    bvecs[np.isnan(bvecs)] = 0

    IndxArr = [_reduced_protocol_indices(bve, bva, n_pts)]

    gtabs = [gradient_table(bvals=bvals, bvecs=bvecs)]
    Delts = [Deltas]
    Dats = []

    # ------------------------------------------------------------------
    # Remaining sessions
    # ------------------------------------------------------------------
    for i, S in enumerate(Subfiles[1:]):
        MatDir = RetestDir + S
        F = pmt.read_mat(MatDir + '/data_loaded.mat')

        bvecs = F['direction']
        bvals = F['bval']
        Deltas = F['deltas']
        data1 = F['data']

        keep = Deltas != 42
        bvecs = bvecs[keep]
        bvals = bvals[keep]
        data1 = data1[..., keep]
        Deltas = Deltas[keep] * 0.001
        Delts.append(Deltas)
        n_pts = np.sum(Deltas == Deltas[0])
        bve = np.copy(bvecs)
        bva = np.copy(bvals)

        data1, affine1 = reslice(data1, affine1, (2, 2, 2), (2.5, 2.5, 2.5))
        bvecs = (bvecs.T / np.linalg.norm(bvecs, axis=1)).T
        bvecs[np.isnan(bvecs)] = 0

        # Some sessions are stored mirrored along the second axis.
        if jj == 0:
            flip = i < 2
        elif jj in (1, 2, 3, 4):
            flip = i < 3
        elif jj == 5:
            flip = 0 < i < 3
        else:
            flip = False
        if flip:
            data1 = data1[:, ::-1]

        Dats.append(data1)
        gtabs.append(gradient_table(bvals=bvals, bvecs=bvecs))
        IndxArr.append(_reduced_protocol_indices(bve, bva, n_pts))

    # ------------------------------------------------------------------
    # Register every later session to the reference via the mean b=0 image
    # ------------------------------------------------------------------
    NewDats = [data]
    reference_b0 = data[..., gtabs[0].bvals == 0].mean(axis=-1)
    for d, gt in zip(Dats, gtabs[1:]):
        affine_map = rigid_register(reference_b0, d[..., gt.bvals == 0].mean(axis=-1), affine1, affine1)
        data2_warp = np.array([
            affine_map.transform(d[:, :, :, vol], interpolation="linear")
            for vol in range(len(gt.bvals))
        ])
        # Volumes were stacked on axis 0; move them back to the last axis.
        data2_warp = np.rollaxis(data2_warp, 0, data2_warp.ndim)
        NewDats.append(data2_warp)

    NewDats_masked = [ND * maskCut[..., None] for ND in NewDats]

    # ------------------------------------------------------------------
    # SBI, full protocol
    # ------------------------------------------------------------------
    Full_SBI_Arr = []
    for ND, gt, Del in tqdm(zip(NewDats_masked, gtabs, Delts), position=0, leave=True):
        mask = np.sum(ND[:, :, AM, :], axis=-1) != 0
        indices = np.argwhere(mask)
        # Shift each voxel so its smallest measurement is slightly positive.
        floor = np.clip(ND.min(axis=-1), -np.inf, 0)
        dat = ND + abs(floor)[:, :, :, None] + 1e-5

        def optimize_chunk(pixels):
            """Return (i, j, posterior mean) for every pixel in `pixels`."""
            results = []
            for i, j in pixels:
                samples = Network.sample((1000,), x=AxcaliberFeatures(gt.bvecs, gt.bvals, Del, dat[i, j, AM, :]), show_progress_bars=False)
                results.append((i, j, samples.mean(axis=0)))
            return results

        chunked_indices = [indices[k:k + ChunkSize] for k in range(0, len(indices), ChunkSize)]
        results = Parallel(n_jobs=6)(
            delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices, position=0, leave=True)
        )

        ArrShape = maskCut[..., AM].shape
        NoiseEst = np.zeros(list(ArrShape) + [13])
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x
                # Clamp the last-but-one and last-but-two parameters.
                NoiseEst[i, j, -2] = np.clip(NoiseEst[i, j, -2], 0, 100)
                NoiseEst[i, j, -3] = np.clip(NoiseEst[i, j, -3], 0, 1)
        NoiseEst2 = np.copy(NoiseEst)

        # NaN everywhere outside the brain mask ...
        for p in range(13):
            NoiseEst2[~maskCut[..., AM], p] = math.nan
        # ... and NaN parameter -2 wherever 1 - parameter -3 is below 0.3.
        NoiseEst2[(1 - NoiseEst2[..., -3]) < 0.3, -2] = math.nan
        Full_SBI_Arr.append(NoiseEst2)
    FullSBI.append(Full_SBI_Arr)

    # ------------------------------------------------------------------
    # SBI, reduced protocol
    # ------------------------------------------------------------------
    Min_SBI_Arr = []
    for ND, gt, Del, idxs in tqdm(zip(NewDats_masked, gtabs, Delts, IndxArr), position=0, leave=True):
        mask = np.sum(ND[:, :, AM, :], axis=-1) != 0
        indices = np.argwhere(mask)
        floor = np.clip(ND.min(axis=-1), -np.inf, 0)
        dat = ND + abs(floor)[:, :, :, None] + 1e-5

        def optimize_chunk(pixels):
            """Return (i, j, posterior mean) on the reduced protocol."""
            results = []
            for i, j in pixels:
                samples = Network.sample((1000,), x=AxcaliberFeatures(gt.bvecs[idxs], gt.bvals[idxs], Del[idxs], dat[i, j, AM, idxs]), show_progress_bars=False)
                results.append((i, j, samples.mean(axis=0)))
            return results

        chunked_indices = [indices[k:k + ChunkSize] for k in range(0, len(indices), ChunkSize)]
        results = Parallel(n_jobs=6)(
            delayed(optimize_chunk)(chunk) for chunk in tqdm(chunked_indices, position=0, leave=True)
        )

        ArrShape = maskCut[..., AM].shape
        NoiseEst = np.zeros(list(ArrShape) + [13])
        for chunk in results:
            for i, j, x in chunk:
                NoiseEst[i, j] = x
                NoiseEst[i, j, -2] = np.clip(NoiseEst[i, j, -2], 0, 100)
                NoiseEst[i, j, -3] = np.clip(NoiseEst[i, j, -3], 0, 1)
        NoiseEst2 = np.copy(NoiseEst)

        for p in range(13):
            NoiseEst2[~maskCut[..., AM], p] = math.nan
        NoiseEst2[(1 - NoiseEst2[..., -3]) < 0.3, -2] = math.nan
        Min_SBI_Arr.append(NoiseEst2)
    MinSBI.append(Min_SBI_Arr)

    # ------------------------------------------------------------------
    # NLLS, full protocol
    # ------------------------------------------------------------------
    Full_NLLS_Arr = []
    for ND, gt, Del in tqdm(zip(NewDats_masked, gtabs, Delts), position=0, leave=True):
        mask = np.sum(ND[:, :, AM, :], axis=-1) != 0
        indices = np.argwhere(mask)
        floor = np.clip(ND.min(axis=-1), -np.inf, 0)
        dat = ND + abs(floor)[:, :, :, None] + 1e-5

        # One segment of n_pts measurements per block; the last takes the rest.
        bve_split_kk = [gt.bvecs[:(n_pts)], gt.bvecs[(n_pts):2 * (n_pts)], gt.bvecs[2 * (n_pts):]]
        bva_split_kk = [gt.bvals[:(n_pts)], gt.bvals[(n_pts):2 * (n_pts)], gt.bvals[2 * (n_pts):]]

        # NOTE(review): this fit passes the per-session array `Del` to
        # `residuals`, while the reduced-protocol fit below passes the global
        # `Delta` -- confirm which one `residuals` expects.
        def optimize_pixel_LS(i, j):
            """Fit pixel (i, j); return (i, j, fitted parameters)."""
            result = sp.optimize.least_squares(residuals, guess, args=[dat[i, j, AM, :], bve_split_kk, bva_split_kk, Del, True],
                                               bounds=bounds, verbose=0, jac='3-point')
            return i, j, result.x

        ArrShape = mask.shape

        results = Parallel(n_jobs=6)(
            delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices, position=0, leave=True)
        )

        # Background pixels keep the zero fill value.
        NoiseEst_LS = np.zeros(list(ArrShape) + [13])
        for i, j, x in results:
            NoiseEst_LS[i, j] = x

        Full_NLLS_Arr.append(NoiseEst_LS)
    FullNLLS.append(Full_NLLS_Arr)

    # ------------------------------------------------------------------
    # NLLS, reduced protocol
    # ------------------------------------------------------------------
    Min_NLLS_Arr = []
    for ND, gt, Del, idxs in tqdm(zip(NewDats_masked, gtabs, Delts, IndxArr), position=0, leave=True):
        mask = np.sum(ND[:, :, AM, :], axis=-1) != 0
        indices = np.argwhere(mask)
        floor = np.clip(ND.min(axis=-1), -np.inf, 0)
        dat = ND + abs(floor)[:, :, :, None] + 1e-5

        # idxs holds 7 measurements per block, in block order.
        bve_split_kk = [gt.bvecs[idxs[:7]], gt.bvecs[idxs[7:14]], gt.bvecs[idxs[14:]]]
        bva_split_kk = [gt.bvals[idxs[:7]], gt.bvals[idxs[7:14]], gt.bvals[idxs[14:]]]

        def optimize_pixel_LS(i, j):
            """Fit pixel (i, j) on the reduced protocol; return (i, j, fitted parameters)."""
            result = sp.optimize.least_squares(residuals, guess, args=[dat[i, j, AM, idxs], bve_split_kk, bva_split_kk, Delta, True],
                                               bounds=bounds, verbose=0, jac='3-point')
            return i, j, result.x

        ArrShape = mask.shape

        results = Parallel(n_jobs=6)(
            delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices, position=0, leave=True)
        )

        NoiseEst_LS = np.zeros(list(ArrShape) + [13])
        for i, j, x in results:
            NoiseEst_LS[i, j] = x

        Min_NLLS_Arr.append(NoiseEst_LS)
    MinNLLS.append(Min_NLLS_Arr)

In [None]:
# Flatten the per-subject lists of per-session maps into four flat lists that
# share the same (subject, session) ordering.
FullNLLS_list = []
MinNLLS_list = []
FullSBI_list = []
MinSBI_list = []


# FullNLLS dictates how many sessions each subject contributes; the other
# three containers are indexed with the same (kk, i) pair.
for kk, full_nlls_sessions in enumerate(FullNLLS):
    for i, full_nlls_map in enumerate(full_nlls_sessions):
        FullNLLS_list.append(full_nlls_map)
        MinNLLS_list.append(MinNLLS[kk][i])
        FullSBI_list.append(FullSBI[kk][i])
        MinSBI_list.append(MinSBI[kk][i])

In [None]:
# Masked local SSIM between the full-protocol and reduced-protocol maps of one
# parameter (column jj), for the first 29 test-retest maps.
# NOTE(review): the variable names say "Frac", but the test-retest cell clips
# column -3 to [0, 1]; column -4 may be a different parameter -- confirm.
jj = -4

SBI_comp_Frac_RT = []
for i in range(29):
    NS1 = np.copy(FullSBI_list[i][..., jj])
    NS2 = np.copy(MinSBI_list[i][..., jj])
    # Valid pixels are those that are not NaN in the full-protocol map; both
    # maps are zeroed everywhere else before comparison.
    outside = np.isnan(NS1)
    Ma = ~outside
    NS1[outside] = 0
    NS2[outside] = 0
    SBI_comp_Frac_RT.append(masked_local_ssim(NS1, NS2, Ma, win_size=7))

LS_comp_Frac_RT = []
for i in range(29):
    NS1 = np.copy(FullNLLS_list[i][..., jj])
    NS2 = np.copy(MinNLLS_list[i][..., jj])
    # NOTE(review): if the NLLS maps are zero-filled rather than NaN-filled
    # outside the brain, this mask is True everywhere -- confirm.
    outside = np.isnan(NS1)
    Ma = ~outside
    NS1[outside] = 0
    NS2[outside] = 0
    LS_comp_Frac_RT.append(masked_local_ssim(NS1, NS2, Ma, win_size=7))

In [None]:
# Figure: SSIM between full- and reduced-protocol maps, one box per method
# (SBI at x = 1.3, NLLS at x = 1.9). Each box pools the main cohort and the
# test-retest cohort; the individual points of the two cohorts are overlaid
# with different markers.


def _jittered_scatter(axis, values, x_center, **scatter_kw):
    """Scatter `values` around `x_center` with Student-t horizontal jitter."""
    y = np.array(values)
    x = x_center * np.ones_like(y)
    x += stats.t(df=6, scale=0.02).rvs(len(x))
    axis.scatter(x, y, s=100, **scatter_kw)


fig, ax = plt.subplots(figsize=(3.2,4.8))#, sharex=True)

# --- SBI ---
y_data = np.array(SBI_comp_Frac+SBI_comp_Frac_RT)
g_pos = np.array([1.3])
colors = ['lightseagreen']
colors2 = ['paleturquoise']
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=False,scatter_alpha=0.5)
_jittered_scatter(ax, SBI_comp_Frac, 1.3, marker='o', color=colors2, alpha=0.8)
_jittered_scatter(ax, SBI_comp_Frac_RT, 1.3, marker='s', color='darkcyan', alpha=0.5)

# --- NLLS ---
y_data = np.array(LS_comp_Frac+LS_comp_Frac_RT)
g_pos = np.array([1.9])
colors = ['sandybrown']
colors2 = ['peachpuff']
BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=False)
_jittered_scatter(ax, LS_comp_Frac, 1.9, marker='o', color=colors2, alpha=0.8)
_jittered_scatter(ax, LS_comp_Frac_RT, 1.9, marker='s', color='chocolate', alpha=0.5)

# The least-squares fits are nonlinear least squares (NLLS); "NNLS" would
# denote non-negative least squares, a different method.
plt.xticks([1.3,1.9],['SBI','NLLS'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

if Save: plt.savefig(FigLoc+'MS_Ax_SSIM_Frac.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Figure: precision values for SBI (turquoise) and NLLS (brown), each with the
# full and the reduced protocol, plus a hatched reference band per method
# derived from the full-protocol values.
fig,ax1 = plt.subplots(1,1,figsize=(3.2,4.8))

sbi_colors = (['mediumturquoise'], ['paleturquoise'])
nlls_colors = (['sandybrown'], ['peachpuff'])
box_specs = [
    (PrecFull_SBI_Frac, 0.65, sbi_colors),
    (Prec7_SBI_Frac, 1.1, sbi_colors),
    (PrecFull_NLLS_Frac, 1.8, nlls_colors),
    (Prec7_NLLS_Frac, 2.15, nlls_colors),
]
for values, position, (colors, colors2) in box_specs:
    y_data = np.array(values)
    g_pos = np.array([position])
    BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)

plt.xticks([0.65,1.1,1.8,2.15],['Full','Red.','Full','Red.'],fontsize=32,rotation=90)

# Reference bands spanning the 25th to 77th percentile of the non-NaN
# full-protocol values.
# NOTE(review): 77 looks like a typo for 75 (inter-quartile range) -- confirm.
band_specs = [
    (PrecFull_NLLS_Frac, (1.7, 2.3), 'sandybrown'),
    (PrecFull_SBI_Frac, (0.55, 1.25), 'mediumturquoise'),
]
for values, (x_start, x_stop), band_color in band_specs:
    x = np.arange(x_start, x_stop, 0.05)
    finite_values = np.array(values)[~np.isnan(values)]
    y1 = np.ones_like(x)*np.percentile(finite_values, 25)
    y2 = np.ones_like(x)*np.percentile(finite_values, 77)
    plt.fill_between(x,y1,y2,color=band_color,zorder=10,alpha=0.2,hatch='//')

#ax1.set_xlim(0.3,2.8)
ax1.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
#ax1.set_yticks([0,0.1,0.2])
if Save: plt.savefig(FigLoc+'MS_Ax_Prec_Frac.pdf',format='PDF',transparent=True,bbox_inches='tight')

## j

In [None]:
KK = [48]*8  # index along the third axis of Outlines, one entry per dataset


def _fa_md_slice_maps(param_maps):
    """Evaluate Par_frac on every masked pixel of each dataset's slice.

    Parameters
    ----------
    param_maps : sequence
        One array per dataset; ``param_maps[n][..., 4:10]`` is handed to
        Par_frac (presumably the six tensor components -- TODO confirm).

    Returns
    -------
    (first_maps, second_maps) : tuple of lists
        For each dataset, two arrays of shape ``ArrShape`` holding ``k[0]``
        and ``k[1]`` of every Par_frac result ``(i, j, k)``; pixels outside
        the mask stay 0.
    """
    first_maps, second_maps = [], []
    for idx in range(len(KK)):
        # Pixels to process: non-zero entries of the outline at slice KK[idx].
        mask = Outlines[idx][:,:,KK[idx]]
        indices = np.argwhere(mask)

        # Loop-invariant slice, taken once instead of once per pixel.
        tensor_vals = param_maps[idx][...,4:10]
        results = Parallel(n_jobs=8)(
            delayed(Par_frac)(i, j, tensor_vals) for i, j in tqdm(indices,position=0,leave=True)
        )

        first = np.zeros(list(ArrShape))
        second = np.zeros(list(ArrShape))
        # Scatter the per-pixel results back into image-shaped arrays.
        for i, j, k in results:
            first[i, j] = k[0]
            second[i, j] = k[1]

        first_maps.append(first)
        second_maps.append(second)
    return first_maps, second_maps


# k[0] is stored under the FA_* names and k[1] under the MD_* names, exactly
# as before. NOTE(review): confirm this ordering against Par_frac.
FA_Full_SBI, MD_Full_SBI = _fa_md_slice_maps(Full_SBI_Extra)
FA_Min_SBI, MD_Min_SBI = _fa_md_slice_maps(Min_SBI_Extra)
FA_Full_LS, MD_Full_LS = _fa_md_slice_maps(Full_LS_extra)
FA_Min_LS, MD_Min_LS = _fa_md_slice_maps(Min_LS_extra)


In [None]:
# Per-pixel MD maps for the 29 entries of FullSBI_list and FullNLLS_list.
# MD is element 0 of MD_FA(...) applied to the matrix built by vals_to_mat
# from parameters 4:10; pixels whose first parameter is NaN stay NaN.
MD_Full_SBI_RT = []
MD_Full_NLLS_RT = []
for md_maps, fits in ((MD_Full_SBI_RT, FullSBI_list), (MD_Full_NLLS_RT, FullNLLS_list)):
    for kk in tqdm(range(29),position=0,leave=True):
        fit = fits[kk]
        first_param = fit[:,:,0]
        # Start from an all-NaN image of the same shape.
        md_map = np.zeros_like(first_param)*math.nan
        for row, col in np.argwhere(~np.isnan(first_param)):
            md_map[row,col] = MD_FA(vals_to_mat(fit[row,col,4:10]))[0]
        md_maps.append(md_map)

In [None]:
# Per-pixel MD maps for the 29 entries of MinSBI_list and MinNLLS_list.
# MD is element 0 of MD_FA(...) applied to the matrix built by vals_to_mat
# from parameters 4:10; pixels whose first parameter is NaN stay NaN.
MD_Min_SBI_RT = []
MD_Min_NLLS_RT = []
for md_maps, fits in ((MD_Min_SBI_RT, MinSBI_list), (MD_Min_NLLS_RT, MinNLLS_list)):
    for kk in tqdm(range(29),position=0,leave=True):
        fit = fits[kk]
        first_param = fit[:,:,0]
        # Start from an all-NaN image of the same shape.
        md_map = np.zeros_like(first_param)*math.nan
        for row, col in np.argwhere(~np.isnan(first_param)):
            md_map[row,col] = MD_FA(vals_to_mat(fit[row,col,4:10]))[0]
        md_maps.append(md_map)

In [None]:
KK = [48]*8  # index along the third axis of Outlines / WMs, one entry per dataset


def _masked_mean_ssim(map_a, map_b, mask):
    """Return the mean of the SSIM map of two 2-D arrays over ``mask``.

    ``data_range`` is the joint value range of both inputs (largest maximum
    minus smallest minimum) and the SSIM window is 7 pixels wide.
    """
    joint_range = max([map_a.max(), map_b.max()]) - min([map_a.min(), map_b.min()])
    _, ssim_map = ssim(map_a, map_b, data_range=joint_range, full=True, win_size=7)
    return ssim_map[mask].mean()


# Min vs Full MD maps per method. Only the Min map is smoothed
# (gaussian_filter, sigma=0.5) before the comparison.
SBI_comp_MD = [
    _masked_mean_ssim(gaussian_filter(MD_Min_SBI[i], sigma=0.5), MD_Full_SBI[i], Outlines[i][:,:,KK[i]])
    for i in range(8)
]
LS_comp_MD = [
    _masked_mean_ssim(gaussian_filter(MD_Min_LS[i], sigma=0.5), MD_Full_LS[i], Outlines[i][:,:,KK[i]])
    for i in range(8)
]

# Full SBI vs Full LS, no smoothing on either side.
SBI_LS_comp_MD = [
    _masked_mean_ssim(MD_Full_SBI[i], MD_Full_LS[i], Outlines[i][:,:,KK[i]])
    for i in range(8)
]

# Spread of MD inside the WMs mask: standard deviation per dataset.
Prec7_SBI_MD = []
PrecFull_SBI_MD = []

Prec7_NLLS_MD = []
PrecFull_NLLS_MD = []
for i in range(8):
    # Same slice as the outlines (KK[i] == 48); mask built once per dataset.
    wm = WMs[i].astype(bool)[:,:,KK[i]]
    Prec7_SBI_MD.append(np.std(MD_Min_SBI[i][wm]))
    PrecFull_SBI_MD.append(np.std(MD_Full_SBI[i][wm]))

    Prec7_NLLS_MD.append(np.std(MD_Min_LS[i][wm]))
    PrecFull_NLLS_MD.append(np.std(MD_Full_LS[i][wm]))


In [None]:
jj = -4  # not read anywhere in this cell; assignment kept as-is

# Masked local SSIM between the Full and Min MD maps of the 29 RT datasets.
# The mask is the set of non-NaN pixels of the Full map; pixels outside it
# are zeroed in both maps before masked_local_ssim is called.
# NOTE(review): NaNs in the Min map that lie inside that mask are left
# untouched -- confirm the two maps always share the same NaN pattern.
SBI_comp_MD_RT = []
LS_comp_MD_RT = []
for scores, full_maps, min_maps in (
    (SBI_comp_MD_RT, MD_Full_SBI_RT, MD_Min_SBI_RT),
    (LS_comp_MD_RT, MD_Full_NLLS_RT, MD_Min_NLLS_RT),
):
    for i in range(29):
        full_map = np.copy(full_maps[i])
        min_map = np.copy(min_maps[i])
        valid = ~np.isnan(full_map)
        outside = ~valid
        full_map[outside] = 0
        min_map[outside] = 0
        scores.append(masked_local_ssim(full_map, min_map, valid, win_size=7))

In [None]:
# SSIM box plot for MD: one box per method over the pooled scores
# (8-dataset list + RT list), with the two score sets overlaid as
# jittered points using different markers.
fig, ax = plt.subplots(figsize=(3.2,4.8))#, sharex=True)
plt.sca(ax)

# Per method: x position, box colours, point colours, the two score lists,
# colour of the RT points, and extra keyword arguments for BoxPlots.
method_specs = [
    (1.3, ['lightseagreen'], ['paleturquoise'], SBI_comp_MD, SBI_comp_MD_RT, 'darkcyan', {'scatter_alpha': 0.5}),
    (1.9, ['sandybrown'], ['peachpuff'], LS_comp_MD, LS_comp_MD_RT, 'chocolate', {}),
]
for centre, colors, colors2, base_scores, rt_scores, rt_colour, box_kwargs in method_specs:
    g_pos = np.array([centre])
    BoxPlots(np.array(base_scores+rt_scores),g_pos,colors,colors2,ax,widths=0.2,scatter=False,**box_kwargs)

    # Points are drawn here rather than by BoxPlots so each set gets its own marker.
    for scores, marker, colour, alpha in ((base_scores, 'o', colors2, 0.8), (rt_scores, 's', rt_colour, 0.5)):
        y_data = np.array(scores)
        x_data = g_pos*np.ones_like(y_data)
        # Horizontal jitter drawn from a Student-t distribution (df=6, scale=0.02).
        x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
        ax.scatter(x_data,y_data,marker=marker,color=colour,s=100,alpha=alpha)

# NOTE(review): the tick label reads 'NNLS' while the variables are named
# NLLS/LS -- confirm which spelling is intended.
plt.xticks([1.3,1.9],['SBI','NNLS'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

if Save: plt.savefig(FigLoc+'MS_Ax_SSIM_MD.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Box plots of the MD standard deviations (PrecFull_*_MD / Prec7_*_MD) for
# both fitting methods, with a hatched reference band per method, saved as
# MS_Ax_Prec_MD.pdf when Save is set.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

# One row per box: x position, values, box colour, point colour.
box_specs = [
    (0.65, PrecFull_SBI_MD, 'mediumturquoise', 'paleturquoise'),
    (1.1, Prec7_SBI_MD, 'mediumturquoise', 'paleturquoise'),
    (1.8, PrecFull_NLLS_MD, 'sandybrown', 'peachpuff'),
    (2.15, Prec7_NLLS_MD, 'sandybrown', 'peachpuff'),
]
for x_pos, values, box_colour, point_colour in box_specs:
    BoxPlots(np.array(values), np.array([x_pos]), [box_colour], [point_colour],
             ax1, widths=0.2, scatter=True)
plt.xticks([0.65,1.1,1.8,2.15],['Full','Red.','Full','Red.'],fontsize=32,rotation=90)

# Hatched band = interquartile range (25th-75th percentile) of the "Full"
# values, drawn across both boxes of the same method as a visual reference.
# The upper bound was previously the 77th percentile, which does not match
# the Q1-Q3 extent of a box plot; it is now the 75th.
band_specs = [
    (np.arange(1.7, 2.3, 0.05), PrecFull_NLLS_MD, 'sandybrown'),
    (np.arange(0.55, 1.25, 0.05), PrecFull_SBI_MD, 'mediumturquoise'),
]
for x, values, band_colour in band_specs:
    values = np.array(values)
    # NaN entries are dropped before the percentiles are taken.
    q1, q3 = np.percentile(values[~np.isnan(values)], [25, 75])
    y1 = np.ones_like(x)*q1
    y2 = np.ones_like(x)*q3
    plt.fill_between(x,y1,y2,color=band_colour,zorder=10,alpha=0.2,hatch='//')

#ax1.set_xlim(0.3,2.8)
ax1.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
ax1.set_yticks([0.0004,0.0008,0.0012])
if Save: plt.savefig(FigLoc+'MS_Ax_Prec_MD.pdf',format='PDF',transparent=True,bbox_inches='tight')

## k

In [None]:
jj = 2  # parameter index along the last axis (presumably Dpar, per the Dp/Dpar naming -- TODO confirm)
KK = [48]*8  # index along the third axis of Outlines / WMs, one entry per dataset


def _masked_mean_ssim(map_a, map_b, mask):
    """Return the mean of the SSIM map of two 2-D arrays over ``mask``.

    ``data_range`` is the joint value range of both inputs (largest maximum
    minus smallest minimum) and the SSIM window is 7 pixels wide.
    """
    joint_range = max([map_a.max(), map_b.max()]) - min([map_a.min(), map_b.min()])
    _, ssim_map = ssim(map_a, map_b, data_range=joint_range, full=True, win_size=7)
    return ssim_map[mask].mean()


# Min vs Full maps of parameter jj per method. Only the Min map is smoothed
# (gaussian_filter, sigma=0.5) before the comparison.
SBI_comp_Dp = [
    _masked_mean_ssim(gaussian_filter(Min_SBI_Extra[i][...,jj], sigma=0.5), Full_SBI_Extra[i][...,jj], Outlines[i][:,:,KK[i]])
    for i in range(8)
]
LS_comp_Dp = [
    _masked_mean_ssim(gaussian_filter(Min_LS_extra[i][...,jj], sigma=0.5), Full_LS_extra[i][...,jj], Outlines[i][:,:,KK[i]])
    for i in range(8)
]

# Full SBI vs Full LS, no smoothing on either side.
SBI_LS_comp_Dp = [
    _masked_mean_ssim(Full_SBI_Extra[i][...,jj], Full_LS_extra[i][...,jj], Outlines[i][:,:,KK[i]])
    for i in range(8)
]

# Spread of parameter jj inside the WMs mask: standard deviation per dataset.
Prec7_SBI_Dp = []
PrecFull_SBI_Dp = []

Prec7_NLLS_Dp = []
PrecFull_NLLS_Dp = []
for i in range(8):
    # Same slice as the outlines (KK[i] == 48); mask built once per dataset.
    wm = WMs[i].astype(bool)[:,:,KK[i]]
    Prec7_SBI_Dp.append(np.std(Min_SBI_Extra[i][...,jj][wm]))
    PrecFull_SBI_Dp.append(np.std(Full_SBI_Extra[i][...,jj][wm]))

    Prec7_NLLS_Dp.append(np.std(Min_LS_extra[i][...,jj][wm]))
    PrecFull_NLLS_Dp.append(np.std(Full_LS_extra[i][...,jj][wm]))


In [None]:
jj = 2  # parameter index along the last axis of each fit array

# Masked local SSIM between the Full and Min maps of parameter jj for the
# 29 RT datasets. The mask is the set of non-NaN pixels of the Full map;
# pixels outside it are zeroed in both maps before masked_local_ssim runs.
# NOTE(review): NaNs in the Min map that lie inside that mask are left
# untouched -- confirm the two maps always share the same NaN pattern.
SBI_comp_Dp_RT = []
LS_comp_Dp_RT = []
for scores, full_fits, min_fits in (
    (SBI_comp_Dp_RT, FullSBI_list, MinSBI_list),
    (LS_comp_Dp_RT, FullNLLS_list, MinNLLS_list),
):
    for i in range(29):
        full_map = np.copy(full_fits[i][...,jj])
        min_map = np.copy(min_fits[i][...,jj])
        valid = ~np.isnan(full_map)
        outside = ~valid
        full_map[outside] = 0
        min_map[outside] = 0
        scores.append(masked_local_ssim(full_map, min_map, valid, win_size=7))

In [None]:
# SSIM box plot for the Dp scores: one box per method over the pooled
# scores (8-dataset list + RT list), with the two score sets overlaid as
# jittered points using different markers.
fig, ax = plt.subplots(figsize=(3.2,4.8))#, sharex=True)
plt.sca(ax)

# Per method: x position, box colours, point colours, the two score lists,
# colour of the RT points, and extra keyword arguments for BoxPlots.
method_specs = [
    (1.3, ['lightseagreen'], ['paleturquoise'], SBI_comp_Dp, SBI_comp_Dp_RT, 'darkcyan', {'scatter_alpha': 0.5}),
    (1.9, ['sandybrown'], ['peachpuff'], LS_comp_Dp, LS_comp_Dp_RT, 'chocolate', {}),
]
for centre, colors, colors2, base_scores, rt_scores, rt_colour, box_kwargs in method_specs:
    g_pos = np.array([centre])
    # scatter=False for BOTH methods: the points are drawn manually below.
    # The second box used to pass scatter=True, which asked BoxPlots to draw
    # points as well, so the same scores appeared twice on that box only.
    BoxPlots(np.array(base_scores+rt_scores),g_pos,colors,colors2,ax,widths=0.2,scatter=False,**box_kwargs)

    for scores, marker, colour, alpha in ((base_scores, 'o', colors2, 0.8), (rt_scores, 's', rt_colour, 0.5)):
        y_data = np.array(scores)
        x_data = g_pos*np.ones_like(y_data)
        # Horizontal jitter drawn from a Student-t distribution (df=6, scale=0.02).
        x_data += stats.t(df=6, scale=0.02).rvs(len(x_data))
        ax.scatter(x_data,y_data,marker=marker,color=colour,s=100,alpha=alpha)

# NOTE(review): the tick label reads 'NNLS' while the variables are named
# NLLS/LS -- confirm which spelling is intended.
plt.xticks([1.3,1.9],['SBI','NNLS'],fontsize=32,rotation=90)
ax.set_yticks([0,0.2,0.4,0.6,0.8,1.0])
ax.set_ylim(-0.1,1)

if Save: plt.savefig(FigLoc+'MS_Ax_SSIM_Dpar.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Box plots of the Dp standard deviations (PrecFull_*_Dp / Prec7_*_Dp) for
# both fitting methods, with a hatched reference band per method, saved as
# MS_Ax_Prec_Dpar.pdf when Save is set.
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

# One row per box: x position, values, box colour, point colour.
box_specs = [
    (0.65, PrecFull_SBI_Dp, 'mediumturquoise', 'paleturquoise'),
    (1.1, Prec7_SBI_Dp, 'mediumturquoise', 'paleturquoise'),
    (1.8, PrecFull_NLLS_Dp, 'sandybrown', 'peachpuff'),
    (2.15, Prec7_NLLS_Dp, 'sandybrown', 'peachpuff'),
]
for x_pos, values, box_colour, point_colour in box_specs:
    BoxPlots(np.array(values), np.array([x_pos]), [box_colour], [point_colour],
             ax1, widths=0.2, scatter=True)
plt.xticks([0.65,1.1,1.8,2.15],['Full','Red.','Full','Red.'],fontsize=32,rotation=90)

# Hatched band = interquartile range (25th-75th percentile) of the "Full"
# values, drawn across both boxes of the same method as a visual reference.
# The upper bound was previously the 77th percentile, which does not match
# the Q1-Q3 extent of a box plot; it is now the 75th.
band_specs = [
    (np.arange(1.7, 2.3, 0.05), PrecFull_NLLS_Dp, 'sandybrown'),
    (np.arange(0.55, 1.25, 0.05), PrecFull_SBI_Dp, 'mediumturquoise'),
]
for x, values, band_colour in band_specs:
    values = np.array(values)
    # NaN entries are dropped before the percentiles are taken.
    q1, q3 = np.percentile(values[~np.isnan(values)], [25, 75])
    y1 = np.ones_like(x)*q1
    y2 = np.ones_like(x)*q3
    plt.fill_between(x,y1,y2,color=band_colour,zorder=10,alpha=0.2,hatch='//')

#ax1.set_xlim(0.3,2.8)
ax1.ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
if Save: plt.savefig(FigLoc+'MS_Ax_Prec_Dpar.pdf',format='PDF',transparent=True,bbox_inches='tight')