# Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
from matplotlib.colors import TwoSlopeNorm
from scipy.ndimage import binary_dilation
# Define font properties
import matplotlib.ticker as ticker
# Global Matplotlib styling for every figure produced by this notebook.
font = {
    'family': 'sans-serif',  # Use sans-serif family
    'sans-serif': ['Helvetica'],  # Specify Helvetica as the sans-serif font
    'size': 14  # Set the default font size
}
plt.rc('font', **font)

# Set tick label sizes
plt.rc('ytick', labelsize=24)
plt.rc('xtick', labelsize=24)

# NOTE: this runs after plt.rc('font', ...) above, so "font.family" ends up being
# "Helvetica" rather than the 'sans-serif' value set there.
plt.rcParams.update({
    "text.usetex": False,
    "font.family": "Helvetica"
})
# Customize axes spines and legend appearance
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['legend.frameon'] = False


%load_ext autoreload
%autoreload 2
    
from dwMRI_BasicFuncs import *
from joblib import Parallel, delayed

from tqdm.auto import tqdm
from scipy.ndimage import gaussian_filter

In [None]:
import pymatreader as pmt

from dipy.io.image import load_nifti

In [None]:
# Root data folder; the cells below read '<subject>/data_loaded.mat' underneath it.
MSDir = './MS_data/'

In [None]:
# Load one control subject and build the brain-masked volume used by the
# single-subject fits further down.
# NOTE: MSDir already ends in '/', so this path contains a doubled slash.
Dir = MSDir+'/Ctrl055_R01_28/'
dat = pmt.read_mat(Dir+'data_loaded.mat')
bvecs = dat['direction']
bvals = dat['bval']
# Acquisition settings: three Delta values and a single delta
# (units are not stated in this file — presumably seconds; TODO confirm).
FixedParams = {
    'bvals':bvals,
    'bvecs':bvecs,
    'Delta':[0.017,0.035,0.061],
    'delta':0.007,
}
Delta = FixedParams['Delta']
delta = FixedParams['delta']
# Later cells slice the protocol into three consecutive segments, one per Delta:
# the first two hold (n_pts + 1) measurements each, the third takes the remainder.
n_pts = 90

# Redundant re-assignment: same values as the FixedParams entries read just above.
Delta = [0.017,0.035,0.061] # We know this 
delta = 0.007 # We know this 


data = dat['data']
axial_middle = data.shape[2] // 2
# median_otsu returns the masked data and the mask; its definition is not in this
# file (presumably dipy's, via the dwMRI_BasicFuncs star import — TODO confirm).
maskdata, mask = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                             numpass=1, autocrop=False, dilate=2)

# Quick visual check: index 41 along the third axis, first volume.
plt.imshow(maskdata[:,:,41,0])

In [None]:
# Folder the posterior pickles are read from / written to, and the root folder
# under which figure sub-folders are created.
network_path = './Networks/'
image_path   = '../Figures/'

In [None]:
# Output folder for this figure's panels. exist_ok=True replaces the
# check-then-create pattern, so re-running the cell is a no-op and there is no
# window in which the folder can appear between the check and the creation.
FigLoc = image_path + 'Fig_5/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Display-only expression (nothing is assigned): shows a pickle path.
# NOTE(review): this filename (..._200k.pickle) is not the one the training cell
# loads/saves (Full_Dat_50_100k_poisson.pickle) — probably a leftover.
f"{network_path}/Full_Dat_50_200k.pickle"

In [None]:
# Load the cached single-subject posterior if its pickle exists; otherwise simulate a
# training set with the three-Delta protocol, train an SNPE estimator and cache it.
# NOTE: the else-branch reads NumSamps, Angs, Dpar, Dperp, DHind, frac, mean, S0Rand and
# TrainParams, which are assigned in cells that appear further down in this file — those
# cells have to be run first.
# NOTE: pickle.load can execute arbitrary code; only load pickles you produced yourself.
if os.path.exists(f"{network_path}/Full_Dat_50_100k_poisson.pickle"):
    with open(f"{network_path}/Full_Dat_50_100k_poisson.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:

    
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        # Unpack the i-th prior draw into the 7-entry parameter list that
        # CombSignal_poisson unpacks (angles, Dpar, Dperp, D, fractions, mean, S0).
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        
        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]
    
        a = mean[i]
        s0 = S0Rand[i]
        
        # Fixed noise setting handed to AddNoise (defined outside this file's visible
        # part; presumably an SNR — TODO confirm).
        Noise = 50#np.random.rand()*30 + 20
    
        # One simulated segment per Delta value; the segments are concatenated in
        # acquisition order.
        TrainSig1 = CombSignal_poisson(bvecs[:(n_pts+1)],bvals[:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(bvecs[(n_pts+1):2*(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(bvecs[2*(n_pts+1):],bvals[2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)



    Obs = torch.tensor(NoisyTrainSig).float()
    # Columns 0-2 of TrainParams are the Cartesian direction V; they are dropped so the
    # network is trained on the angle parametrisation only.
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior = inference.build_posterior(density_estimator)
    with open(f"{network_path}/Full_Dat_50_100k_poisson.pickle", "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
# Per-pixel posterior-mean estimates on the plane at index 54 of the second axis.
# NOTE: this rebinds `mask`, replacing the mask returned by median_otsu with a 2-D
# mask of that plane.
# Compute the mask where the sum is not zero
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Get the indices where mask is True
indices = np.argwhere(mask)

# Define the function for optimization
def optimize_pixel(i, j):
    """Sample the posterior for pixel (i, j) of plane 54 and return (i, j, mean of the samples)."""
    # Re-seeding on every call makes each pixel use the same random stream.
    torch.manual_seed(10)  # If required
    posterior_samples_1 = posterior.sample((1000,), x=maskdata[i, 54, j, :],show_progress_bars=False)
    return i, j, posterior_samples_1.mean(axis=0)

# Initialize NoiseEst with the appropriate shape
ArrShape = mask.shape

# Use joblib to parallelize the optimization tasks
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
)


# 13 estimated parameters per pixel; pixels outside the mask stay at zero.
NoiseEst = np.zeros(list(ArrShape) + [13])

# Assign the optimization results to NoiseEst
for i, j, x in results:
    NoiseEst[i, j] = x

# Clamp two of the estimates for the fitted pixels. In the TrainParams[:, 3:] column
# order used for training, index -2 is `mean` and index -3 is `frac`.
for i, j, x in results:
    NoiseEst[i, j,-2] = np.clip(NoiseEst[i, j,-2],0,100)
    NoiseEst[i, j,-3] = np.clip(NoiseEst[i, j,-3],0,1)

In [None]:
# Build the display masks for the parameter maps.
# NOTE(review): S_mask is not assigned anywhere in the visible part of this file —
# presumably a structure mask loaded elsewhere; confirm before running.
NoiseEst2 = np.copy(NoiseEst)

# Binarise plane 54 of S_mask (non-zero -> 1) and grow it by one pixel in all
# eight directions (3x3 structuring element, one iteration).
mask1 = np.ones_like(S_mask[:,54,:])
mask1[S_mask[:,54,:]==0] = 0
structure = np.ones((3, 3), dtype=bool)

# Apply dilation. Increase 'iterations' to make the mask even fatter.
fat_mask = binary_dilation(mask1, structure=structure, iterations=1)

# Keep dilated-mask pixels whose (1 - frac) estimate exceeds 0.1; frac is the weight
# of the hindered compartment in the signal model, so 1 - frac weights the restricted one.
comb_mask = fat_mask * ((1-NoiseEst2[...,-3])>0.1)

# Not read again in the visible part of this file.
mask_CC = (1-NoiseEst2[...,-3])<0.3
# Blank every parameter outside the data mask so imshow leaves those pixels empty.
for i in range(13):
    NoiseEst2[~mask,i] = math.nan

# The `mean` map (index -2) is additionally restricted to comb_mask.
NoiseEst2[~comb_mask,-2] = math.nan

In [None]:
# Figure panel: `mean` map (index -2, 'hot' colormap) drawn over the S0 map
# (index -1, gray). The masking repeats the previous cell and relies on the
# comb_mask computed there.
NoiseEst2 = np.copy(NoiseEst)
mask_CC = (1-NoiseEst2[...,-3])<0.3
for i in range(13):
    NoiseEst2[~mask,i] = math.nan

NoiseEst2[~comb_mask,-2] = math.nan
plt.subplots(figsize=(12,12))
# Both images are transposed and flipped vertically for display.
plt.imshow(np.flipud(NoiseEst2[...,-1].T),cmap='gray')
# Colour scale spans 0 to 0.007 and is centred on 0.0035.
norm = TwoSlopeNorm(vmin=0, vcenter=0.0035, vmax=0.007)
# NaN pixels of the overlay are not drawn, so the gray S0 image shows through.
im = plt.imshow(np.flipud(NoiseEst2[...,-2].T),cmap='hot',norm=norm)
#cbar = plt.colorbar(im,fraction=0.035, pad=0.01,format=ticker.FormatStrFormatter('%.e'))
#cbar.ax.tick_params(labelsize=14)
plt.axis('off')

# NOTE: FigLoc already ends in '/', so the saved path contains a doubled slash.
plt.savefig(FigLoc+'/FullSize_CC.pdf',bbox_inches='tight',format='pdf',transparent=True)

In [None]:
def residuals_S0(params,TrueSig,bvecs,bvals,Delta):
    """Residual vector for least-squares fitting with S0 as a free parameter.

    The last entry of ``params`` is passed to ``Simulator_new`` as S0; the other
    arguments are forwarded unchanged. Returns ``TrueSig`` minus the simulated
    signal.
    """
    s0 = params[-1]
    predicted = Simulator_new(params, bvecs, bvals, Delta, S0=s0)
    residual = TrueSig - predicted
    return residual
def Simulator_new(params,bvecs,bvals,Delta,S0=1):
    """Simulate the full multi-Delta signal from a flat parameter vector.

    ``params`` is indexed as: [0:2] fibre angles, [2] Dpar, [3] Dperp,
    [4:10] hindered-tensor values, [10] hindered fraction, [11] mean.
    ``bvecs``, ``bvals`` and ``Delta`` are iterated in lockstep (one entry per
    segment); each segment is simulated with ``CombSignal_poisson`` using the
    module-level ``delta``, and the segments are concatenated.
    """
    hindered_fraction = params[10]
    model_params = [
        np.array([params[:2]]),
        params[2],
        params[3],
        params[4:10],
        [hindered_fraction, 1 - hindered_fraction],
        params[11],
        S0,
    ]
    segments = [
        CombSignal_poisson(seg_bvecs, seg_bvals, seg_Delta, delta, model_params)
        for seg_bvecs, seg_bvals, seg_Delta in zip(bvecs, bvals, Delta)
    ]
    return np.hstack(segments)

In [None]:
def CombSignal_poisson(bvecs, bvals, Delta, delta, params):
    """
    Compute the combined (hindered + restricted) diffusion signal in a vectorized way.

    The restricted compartment is averaged over a fixed grid of radii R with
    Poisson weights (rate ``mean * 10000``, evaluated at the integer
    ``R * 10000``); the hindered compartment is a Gaussian tensor signal.

    Parameters:
      bvecs  : (M,3) array of b-vectors.
      bvals  : (M,) array of b-values.
      Delta, delta : acquisition parameters (scalars)
      params : sequence of exactly seven entries:
          params[0] : fiber directions as an (N,2) array of spherical angles (theta, phi)
          params[1] : Dpar (scalar)
          params[2] : Dperp (scalar)
          params[3] : D (for hindered compartment; passed to vals_to_mat)
          params[4] : fractions as an (N+1,) sequence
                      (first element for hindered compartment, then one per fiber)
          params[5] : mean (scalar; mean * 10000 is the Poisson rate of the radius weights)
          params[6] : S0 (scalar)

    Returns:
      Signal : (M,) array of simulated signal values.

    Notes:
      Reads the module-level array ``Bessel_roots`` (first 10 entries) and calls
      ``vals_to_mat``; both must be defined before the first call.
    """
    # Unpack parameters
    V_angles, Dpar, Dperp, D, fracs, mean, S0 = params

    bvecs = np.asarray(bvecs)  # shape: (M, 3)
    bvals = np.asarray(bvals)  # shape: (M,)

    # --- 1. Compute fiber unit vectors from spherical angles ---
    # V_angles is an (N,2) array: each row is (theta, phi).
    theta_fibers = V_angles[:, 0]
    phi_fibers   = V_angles[:, 1]
    V_unit = np.column_stack((np.sin(theta_fibers) * np.cos(phi_fibers),
                              np.sin(theta_fibers) * np.sin(phi_fibers),
                              np.cos(theta_fibers)))  # shape: (N, 3)

    # --- 2. Compute angles between each fiber and each b-vector ---
    # Fibers are unit length, so only the b-vector norms are needed.
    bvec_norms = np.linalg.norm(bvecs, axis=1)
    zero_direction = bvec_norms == 0
    # Avoid division by zero:
    safe_bvec_norms = np.where(zero_direction, 1, bvec_norms)

    # (i, j) element = cos(angle between fiber i and b-vector j).
    cos_angles = np.clip((V_unit @ bvecs.T) / safe_bvec_norms, -1, 1)  # shape: (N, M)
    # Angles in [0, pi], folded into [0, pi/2] (direction sign is irrelevant).
    Angs = np.arccos(cos_angles)
    Angs = np.where(Angs > np.pi/2, np.pi - Angs, Angs)
    # Unweighted measurements (zero direction or b = 0) get angle zero.
    # This replaces an unconditional `Angs[:, 0] = 0`, which silently discarded
    # the direction of the first measurement even when it was diffusion-weighted
    # (e.g. when a sub-sampled protocol segment does not start with a b = 0 volume).
    # For b = 0 the angle has no effect on the signal, so nothing else changes.
    Angs[:, zero_direction | (bvals == 0)] = 0

    # --- 3. Precompute the Poisson weights for the integration over R ---
    lam = mean*10000
    # Radius grid: steps of 1e-4 starting at 1e-4 and ending just below 1e-2
    # (K points; the exact count is whatever np.arange yields for this float step).
    R_vals = np.arange(0.0001, 0.01, 0.0001)
    # Integer index of each radius (R * 1e4). Round to nearest: plain truncation
    # would turn e.g. 28.999999999999996 into 28 and mis-assign that weight.
    transR = np.rint(R_vals * 10000).astype(int)

    # Poisson pmf lam**k * exp(-lam) / k!, renormalised over the truncated grid.
    factorials = np.array([float(math.factorial(int(r))) for r in transR])
    weights = (lam**transR) * np.exp(-lam) / factorials
    weights /= np.sum(weights)

    # --- 4. Precompute the "sumterm" that appears in the restricted compartment ---
    # Uses the first m = 10 entries of the module-level Bessel_roots array.
    m = 10
    br = Bessel_roots[:m]  # shape: (m,)
    br2 = br**2
    br6 = br**6
    # Broadcast over the K radii (rows) and the m roots (columns).
    R2 = R_vals**2  # shape: (K,)
    # numerator: shape (K, m)
    num = (2 * Dperp * br2 * delta / R2[:, None] - 2 +
           2 * np.exp(-Dperp * br2 * delta / R2[:, None]) +
           2 * np.exp(-Dperp * br2 * Delta / R2[:, None]) -
           np.exp(-Dperp * br2 * (Delta - delta) / R2[:, None]) -
           np.exp(-Dperp * br2 * (Delta + delta) / R2[:, None]))
    # denominator: shape (K, m)
    den = (Dperp**2) * br6 * (br2 - 1) / (R_vals[:, None]**6)
    sumterm_R = np.sum(num / den, axis=1)  # shape: (K,)

    # --- 5. Compute the restricted compartment signal ---
    # For fiber i and measurement j:
    #   Restricted(b, theta, R) = exp(-b * cos(theta)**2 * Dpar) *
    #                             exp(-2 * b * sin(theta)**2 / ((Delta-delta/3)*delta**2) * sumterm)
    # Only the second exponential depends on R (through sumterm_R), so the first
    # factor is computed once and the second is averaged over R with `weights`.
    base = np.exp(-bvals * (np.cos(Angs)**2) * Dpar)  # shape: (N, M)
    # Factor multiplying sumterm_R inside the second exponential.
    x = -2 * bvals * (np.sin(Angs)**2) / ((Delta - delta/3) * delta**2)  # shape: (N, M)
    exp_term = np.exp(x[..., None] * sumterm_R)  # shape: (N, M, K)
    # Weighted sum over the radius axis.
    restricted_integral = np.tensordot(exp_term, weights, axes=([2], [0]))  # shape: (N, M)
    Res = base * restricted_integral  # shape: (N, M)
    # Combine the fibers, weighting each one by its fraction.
    restricted_signal = np.dot(fracs[1:], Res)  # shape: (M,)

    # --- 6. Compute the hindered compartment signal ---
    # Diffusion tensor built from D by vals_to_mat.
    dh = vals_to_mat(D)
    # Hindered signal exp(-b * g^T dh g) for every b-vector g.
    s = np.sum((bvecs @ dh) * bvecs, axis=1)  # shape: (M,)
    hindered_signal = np.exp(-bvals * s)  # shape: (M,)

    # --- 7. Combine compartments and scale by S0 ---
    Signal = fracs[0] * hindered_signal + restricted_signal
    return S0 * Signal

def SpherAng(v_in):
    """Convert a 3-vector to spherical angles on the upper hemisphere.

    Parameters
    ----------
    v_in : array-like of length 3
        Cartesian vector (x, y, z). Lists and tuples are accepted as well as
        arrays. It does not need to be normalised.

    Returns
    -------
    (theta, phi)
        ``theta`` is the polar angle in [0, pi/2] (vectors with z < 0 are
        flipped to -v first) and ``phi`` the azimuth in [-pi, pi].
        A zero vector returns ``(0.0, 0.0)``.
    """
    # np.asarray lets the sign flip below work for lists/tuples too
    # (unary minus is not defined for them).
    v_in = np.asarray(v_in)

    if v_in[2] < 0:
        v_in = -v_in  # Flip the vector to the top hemisphere

    x, y, z = v_in
    r = np.linalg.norm(v_in)
    if r == 0:
        # Degenerate vector: the angles are undefined, return a fixed convention.
        return 0.0, 0.0

    # Polar angle; z >= 0 after the flip, so this lies in [0, pi/2].
    # The clip guards arccos against z / r landing a rounding error above 1.
    theta = np.arccos(np.clip(z / r, -1.0, 1.0))

    # Azimuthal angle from arctan2, in [-pi, pi]
    phi = np.arctan2(y, x)

    return theta,phi

def j1_derivative(x):
    """Evaluate J1'(x) through the recurrence J1'(x) = (J0(x) - J2(x)) / 2."""
    difference = j0(x) - j2(x)
    return 0.5 * difference

def j2(x):
    """Return the order-2 Bessel function of the first kind, J_2, evaluated at x."""
    order = 2
    return jv(order, x)

def j1prime_zeros(n, x_max=100, step=0.1):
    """
    Find the first n positive roots of J1'(x) by scanning from x=0 to x_max.

    The scan looks for sign changes of ``j1_derivative`` between consecutive
    grid points ``i * step`` and refines each bracket with ``bisect``. The grid
    is walked lazily, so memory use does not grow with ``x_max / step``; an
    eagerly built grid for x_max=1e7 and step=0.01 would hold about 1e9 values
    even though the scan stops after a few hundred units of x.

    Parameters
    ----------
    n     : int
        Number of roots to find. The scan stops as soon as n roots have been
        collected; if n is never reached (including n <= 0) every root below
        x_max is returned.
    x_max : float
        Maximum x to search (exclusive)
    step  : float
        Step size for scanning sign changes

    Returns
    -------
    zeros : list of float
        List of the roots found (at most n when n > 0), in increasing order.
    """
    zeros = []

    # Same grid as np.arange(0.0, x_max, step): ceil(x_max / step) points i * step.
    n_points = int(np.ceil(x_max / step))
    if n_points <= 0:
        return zeros

    x_prev = 0.0
    f_prev = j1_derivative(x_prev)
    for i in range(1, n_points):
        x_curr = i * step
        f_curr = j1_derivative(x_curr)
        # Check for a sign change in [x_prev, x_curr]
        if f_prev * f_curr < 0:
            zeros.append(bisect(j1_derivative, x_prev, x_curr))
            if len(zeros) == n:
                break
        x_prev, f_prev = x_curr, f_curr

    return zeros

# Root table computed once when this cell runs; CombSignal_poisson reads its
# first 10 entries through the module-level name Bessel_roots.
n_roots = 100
# x_max=10e6 (i.e. 1e7) acts as an effectively unbounded search ceiling:
# j1prime_zeros stops scanning as soon as n_roots roots have been found.
Bessel_roots = np.array(j1prime_zeros(n_roots, x_max=10e6, step=0.01))
# Flag; not read anywhere in the visible part of this file.
Bessel = True

In [None]:
# Starting point, bounds and protocol split for the per-pixel least-squares fit.
# Parameter order (13 entries), as indexed by Simulator_new / residuals_S0:
#   0-1 fibre angles (theta, phi), 2 Dpar, 3 Dperp, 4-9 hindered-tensor values,
#   10 hindered fraction, 11 mean, 12 S0.
np.random.seed(133)
# Not read again in the visible part of this file (S0_guess is used instead).
S0 = 2000
mean_guess = np.random.rand()*0.005+1e-4
Params_abc =  np.random.rand(1,3)*0.14-0.07
Params_rest =  np.random.rand(1,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
# ComputeDTI, ForceLowFA and mat_to_vals are defined outside this file; the result is
# used as the six hindered-tensor values of the guess.
DHind_guess = np.array([ComputeDTI(p) for p in Params])
DHind_guess = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind_guess])

Dpar_guess = np.random.rand()*1e-3            # mm^2/s
Dperp_guess = np.random.rand()*1e-3             # mm^2/s
# The random draws are commented out: the starting direction is fixed at
# theta = arccos(0) = pi/2, phi = 0.
phi = 0#np.random.rand()*pi
cos_theta = 0#np.random.rand()  # uniform in [0,1]
theta = np.arccos(cos_theta)         # in [0, pi/2]
Angs_guess = np.vstack([theta,phi]).T
S0_guess =np.random.rand()*2475+25

frac_guess = np.random.rand()
guess = np.column_stack([Angs_guess,Dpar_guess,Dperp_guess,DHind_guess,frac_guess,mean_guess,S0_guess]).squeeze()
# bounds has shape (2, 13): row 0 holds lower limits, row 1 upper limits.
bounds = np.array([[-np.inf,np.inf]]*13).T
bounds[:,0] = [0,np.pi/2]
bounds[:,1] = [-np.pi,np.pi]
bounds[:,2] = [0,5e-3]
bounds[:,3] = [0,5e-3]
bounds[:,4] = [-5e-3,5e-3]
bounds[:,5] = [-5e-3,5e-3]
bounds[:,6] = [-5e-3,5e-3]
bounds[:,7] = [-5e-3,5e-3]
bounds[:,8] = [-5e-3,5e-3]
bounds[:,9] = [-5e-3,5e-3]
bounds[:,10] = [0,1]
bounds[:,11] = [1e-4,0.005+1e-4]
bounds[:,12] = [25,2500]

# Protocol split into the three Delta segments, in the list-of-segments form that
# Simulator_new zips over.
bve_split = [bvecs[:(n_pts+1)],bvecs[(n_pts+1):2*(n_pts+1)],bvecs[2*(n_pts+1):]]
bva_split = [bvals[:(n_pts+1)],bvals[(n_pts+1):2*(n_pts+1)],bvals[2*(n_pts+1):]]

In [None]:
# Least-squares counterpart of the SBI per-pixel fit, on the same plane (index 54).
# Compute the mask where the sum is not zero
mask = np.sum(maskdata[:, 54, :, :], axis=-1) != 0

# Define the function for optimization
def optimize_pixel_LS(i, j):
    """Fit pixel (i, j) of plane 54 by bounded least squares from the shared `guess`; return (i, j, fitted parameters)."""
    # `sp` is presumably scipy (imported outside the visible part of this file).
    result = sp.optimize.least_squares(residuals_S0, guess, args=[maskdata[i, 54, j, :],bve_split,bva_split,Delta],
                              bounds=bounds,verbose=0,xtol=1e-12,gtol=1e-12,ftol=1e-12,jac='3-point')
    return i, j, result.x

# Initialize NoiseEst with the appropriate shape
ArrShape = mask.shape

# NOTE: `indices` is not recomputed here; it is whatever an earlier cell left behind
# (np.argwhere of the same-plane mask if the SBI cell ran last).
# Use joblib to parallelize the optimization tasks
results = Parallel(n_jobs=-1)(
    delayed(optimize_pixel_LS)(i, j) for i, j in tqdm(indices)
)


NoiseEst_LS = np.zeros(list(ArrShape) + [13])

# Assign the optimization results to NoiseEst
for i, j, x in results:
    NoiseEst_LS[i, j] = x

In [None]:
def BoxPlots(y_data, positions, colors, colors2, ax,hatch = False,scatter=False,scatter_alpha=0.5, **kwargs):
    """Draw outline-only box plots on ``ax``, optionally with jittered data points.

    Parameters
    ----------
    y_data : 1-D array, or a sequence of 1-D series (2-D array / list of arrays)
        Data to plot. NaNs are dropped per series, so series may end up with
        different lengths.
    positions : sequence of float
        x position of each box (one entry for a single series).
    colors : sequence
        Edge / whisker colour of each box.
    colors2 : colour or sequence of colours
        Scatter colour: a single colour for a single series, otherwise one
        colour per series.
    ax : matplotlib Axes
        Axes to draw on.
    hatch : bool
        Fill the boxes with a '/' hatch pattern.
    scatter : bool
        Overlay the individual points with a small horizontal jitter.
    scatter_alpha : float
        Opacity of the scatter points.
    **kwargs
        Forwarded to ``ax.boxplot``.
    """
    GREY_DARK = "#747473"
    jitter = 0.02

    # Decide once, on the input, whether this is one series or several. Asking
    # np.ndim of the NaN-cleaned list instead fails on recent NumPy as soon as
    # the cleaned series have different lengths.
    single_series = len(y_data) == 0 or np.ndim(y_data[0]) == 0

    # Clean data to remove NaNs (per series)
    if single_series:
        series = np.asarray(y_data, dtype=float)
        cleaned_data = series[~np.isnan(series)]
    else:
        as_arrays = [np.asarray(d, dtype=float) for d in y_data]
        cleaned_data = [d[~np.isnan(d)] for d in as_arrays]

    # Boxes are patch objects: no fill, coloured outline (recoloured per box below).
    boxprops = dict(
        linewidth=2,
        facecolor='none',
        edgecolor='turquoise'
    )

    # Medians are Line2D objects.
    medianprops = dict(
        linewidth=2,
        color=GREY_DARK,
        solid_capstyle="butt"
    )

    # Whiskers are Line2D objects, so they take 'color'.
    whiskerprops = dict(
        linewidth=2,
        color='turquoise'
    )

    bplot = ax.boxplot(
        cleaned_data,
        positions=positions,
        showfliers=False,
        showcaps = False,
        medianprops=medianprops,
        whiskerprops=whiskerprops,
        boxprops=boxprops,
        patch_artist=True,
        **kwargs
    )

    # Update the color of each box (these are patch objects)
    for i, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[i])
        if hatch:
            box.set_hatch('/')

    # Update the color of the whiskers (each box has 2 whiskers)
    for i in range(len(positions)):
        bplot['whiskers'][2*i].set_color(colors[i])
        bplot['whiskers'][2*i+1].set_color(colors[i])

    # Caps are disabled above, so this list is normally empty; kept so the
    # colours stay right if caps are ever turned on.
    for i, cap in enumerate(bplot.get('caps', [])):
        cap.set_color(colors[i//2])  # two caps per box

    if scatter:
        jitter_dist = stats.t(df=6, scale=jitter)
        if single_series:
            # One x value per point, centred on the box and jittered. The jitter
            # used to be computed and then dropped, so points were plotted in a
            # single vertical line.
            x_center = float(np.ravel(positions)[0])
            x_jittered = x_center + jitter_dist.rvs(len(cleaned_data))
            ax.scatter(x_jittered, cleaned_data, s=100, color=colors2, alpha=scatter_alpha)
        else:
            # Jitter offsets are drawn for every series first, then plotted
            # (using colors2).
            x_jittered = [positions[i] + jitter_dist.rvs(len(d))
                          for i, d in enumerate(cleaned_data)]
            for x, y, c in zip(x_jittered, cleaned_data, colors2):
                ax.scatter(x, y, s=100, color=c, alpha=scatter_alpha)

## Multi

In [None]:
# Output folder for this section's figure panels. exist_ok=True replaces the
# check-then-create pattern, so re-running the cell is a no-op and there is no
# window in which the folder can appear between the check and the creation.
FigLoc = image_path + 'Fig_5/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Load all eight subjects, resample each volume and brain-mask it.
# The position of a subject in Dirs is the subject index used later as the extra
# network input (Choice at training time, kk at inference time).
Dirs = ['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
BVecs = []
BVals = []
Deltas = []
deltas = []
# Never filled in this cell.
S_masks = []
Datas = []
Outlines = []
axial_middles = []
# NOTE: the loop variable D is a folder name here; later cells reuse the name D
# for data arrays.
for D in tqdm(Dirs):
    F = pmt.read_mat(MSDir+D+'/data_loaded.mat')
    # Placeholder affine handed to reslice (all ones, not a real scanner affine).
    affine = np.ones((4,4))
    BVecs.append(F['direction'])
    BVals.append(F['bval'])
    # The same acquisition timings (FixedParams) are recorded for every subject.
    Deltas.append(FixedParams['Delta'])
    deltas.append(FixedParams['delta'])


    
    # Resample from (2,2,2) to (2.5,2.5,2.5) voxel size; reslice is defined outside
    # this file (presumably dipy.align.reslice — TODO confirm).
    data, affine = reslice(F['data'], affine, (2,2,2), (2.5,2.5,2.5))

    axial_middle = data.shape[2] // 2
    md, mk = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                                 numpass=1, autocrop=False, dilate=2)
    # Per subject: masked data, index of the middle slice along axis 2, and the mask.
    Datas.append(md)
    axial_middles.append(axial_middle)
    Outlines.append(mk)

In [None]:
# Prior draws for the training set of the full-protocol multi-subject network.
np.random.seed(12)
NumSamps = 10000

# Directions
# Normalised Gaussian draws give directions uniform on the sphere; SpherAng maps
# each one to (theta, phi) on the upper hemisphere.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
# Uniform in [0, 5e-3).
Dpar  = np.random.rand(NumSamps)*5e-3
Dperp = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six raw values per sample, turned into tensor values by ComputeDTI / ForceLowFA /
# mat_to_vals (all defined outside this file).
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(NumSamps)

# Uniform in [1e-4, 5.1e-3).
mean = np.random.rand(NumSamps)*0.005+1e-4

# Uniform in [25, 2500).
S0Rand =np.random.rand(NumSamps)*2475+25

# Subject index 0..7 per sample: selects which subject's protocol is simulated.
Choice = np.random.choice(np.arange(8),NumSamps)

In [None]:
# Column layout: 0-2 Cartesian direction V, 3-4 angles, 5 Dpar, 6 Dperp,
# 7-12 hindered-tensor values, 13 frac, 14 mean, 15 S0. The training cells use
# TrainParams[:, 3:] (13 columns) as the inference targets.
TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand])

In [None]:
# Load the cached full-protocol multi-subject posterior, or simulate a training set
# (one randomly chosen subject protocol per sample), train an SNPE estimator and cache it.
# Both branches bind the result to `posteriorFull`, the name the inference cell reads.
# The training branch used to bind `posterior` instead, which left `posteriorFull`
# undefined after a fresh training run and overwrote the single-subject posterior.
# NOTE: pickle.load can execute arbitrary code; only load pickles you produced yourself.
if os.path.exists(f"{network_path}/AxMultiMSFull_300.pickle"):
    with open(f"{network_path}/AxMultiMSFull_300.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        # Unpack the i-th prior draw into the parameter list CombSignal_poisson expects.
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]

        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]

        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]  # subject whose b-vectors / b-values are simulated

        # Fixed noise setting handed to AddNoise.
        Noise = 50#np.random.rand()*30 + 20

        # One simulated segment per Delta value, concatenated in acquisition order.
        TrainSig1 = CombSignal_poisson(BVecs[c][:(n_pts+1)],BVals[c][:(n_pts+1)],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(BVecs[c][(n_pts+1):2*(n_pts+1)],BVals[c][(n_pts+1):2*(n_pts+1)],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(BVecs[c][2*(n_pts+1):],BVals[c][2*(n_pts+1):],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)

    # The subject index is appended as the last observation feature, so the network
    # is conditioned on which protocol produced the signal.
    Obs = np.hstack([NoisyTrainSig,np.expand_dims(Choice, axis=-1)])
    Obs = torch.tensor(Obs).float()
    # Drop the three Cartesian direction columns; infer the remaining 13 parameters.
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorFull = inference.build_posterior(density_estimator)
    with open(f"{network_path}/AxMultiMSFull_300.pickle", "wb") as handle:
        pickle.dump(posteriorFull, handle)

In [None]:
# Build DevIndices, the reduced protocol: measurement 0 plus six directions from
# each of the three Delta segments, chosen by greedy farthest-point selection on
# the first subject's first-segment directions.
bve,bva = BVecs[0],BVals[0]
# Choose the first point (arbitrary starting point, e.g., the first gradient)
selected_indices = [0]
# Directions of the first segment (first 91 measurements) acquired at b = 2000.
bvecs2000 = bve[:91][bva[:91]==2000]
distance_matrix = squareform(pdist(bvecs2000))
# Iteratively select the point furthest from the current selection
for _ in range(9):  # 9 further picks, i.e. 10 directions in total (one is already selected)
    remaining_indices = list(set(range(len(bvecs2000))) - set(selected_indices))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices = selected_indices
bvecs2000_selected = bve[:91][bva[:91]==2000][selected_indices]
# Map each selected direction back to the index of its first exact match in the
# full b-vector table.
true_indices = []
for b in bvecs2000_selected:
    true_indices.append(np.where((b == bve).all(axis=1))[0][0])

# Same selection for the b = 4000 directions of the first segment.
# Choose the first point (arbitrary starting point, e.g., the first gradient)
selected_indices = [0]
bvecs4000 = bve[:91][bva[:91]==4000]
distance_matrix = squareform(pdist(bvecs4000))
# Iteratively select the point furthest from the current selection
for _ in range(9):  # 9 further picks, i.e. 10 directions in total (one is already selected)
    remaining_indices = list(set(range(len(bvecs4000))) - set(selected_indices))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices = selected_indices
bvecs4000_selected = bve[:91][bva[:91]==4000][selected_indices]
for b in bvecs4000_selected:
    true_indices.append(np.where((b == bve).all(axis=1))[0][0])
# true_indices1 holds 20 entries: positions 0-9 come from b = 2000, 10-19 from b = 4000.
true_indices1 = true_indices

# Offsets 91 and 182 move an index into the second / third Delta segment; this assumes
# every segment lists its directions in the same order — TODO confirm.
# NOTE(review): the slices [9:12], [12:15], [15:18] read as if each shell contributed
# 9 entries, but each contributes 10. As written, [9:12] takes the 10th b = 2000
# direction plus two b = 4000 directions, and the last b = 4000 picks are unused.
# The cached AxMultiMSDev network was trained with exactly these indices, so do not
# change them without retraining.
DevIndices = [0] + true_indices1[:3]+true_indices1[9:12] + [91+t for t in true_indices1[3:6]]+[91+t for t in true_indices1[12:15]]\
+ [182+t for t in true_indices1[6:9]]+[182+t for t in true_indices1[15:18]]

In [None]:
# Prior draws for the training set of the reduced-protocol (DevIndices) network.
# Same recipe and seed as the full-protocol draws, with 30000 samples.
np.random.seed(12)
NumSamps = 30000

# Directions
# Normalised Gaussian draws give directions uniform on the sphere; SpherAng maps
# each one to (theta, phi) on the upper hemisphere.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
# Uniform in [0, 5e-3).
Dpar  = np.random.rand(NumSamps)*5e-3
Dperp = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six raw values per sample, turned into tensor values by ComputeDTI / ForceLowFA /
# mat_to_vals (all defined outside this file).
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(NumSamps)

# Uniform in [1e-4, 5.1e-3).
mean = np.random.rand(NumSamps)*0.005+1e-4

# Uniform in [25, 2500).
S0Rand =np.random.rand(NumSamps)*2475+25

# Subject index 0..7 per sample: selects which subject's protocol is simulated.
Choice = np.random.choice(np.arange(8),NumSamps)

In [None]:
# Column layout: 0-2 Cartesian direction V, 3-4 angles, 5 Dpar, 6 Dperp,
# 7-12 hindered-tensor values, 13 frac, 14 mean, 15 S0. The training cells use
# TrainParams[:, 3:] (13 columns) as the inference targets.
TrainParams = np.column_stack([V,Angs,Dpar,Dperp,DHind,frac,mean,S0Rand])

In [None]:
# Display-only: b-vectors of reduced-protocol entries 7-12 (the second-Delta part).
# `c` is whatever subject index an earlier cell left behind.
BVecs[c][DevIndices][7:13]

In [None]:
# Display-only: b-values of the same reduced-protocol entries (7-12).
# `c` is whatever subject index an earlier cell left behind.
BVals[c][DevIndices[7:13]]

In [None]:
# Display-only: the 19 measurement indices that make up the reduced protocol.
DevIndices

In [None]:
# Signal of one voxel across all measurements, with the reduced-protocol
# measurements marked by red crosses at a fixed height of 400.
# NOTE: D and sl are whatever an earlier cell left behind (the per-subject
# inference loops bind them to a data volume and its middle-slice index).
plt.plot(D[30, 30, sl, :])
# The marker count follows DevIndices instead of a hard-coded 19, so the plot keeps
# working if the reduced protocol changes size.
plt.plot(DevIndices,np.full(len(DevIndices), 400),'xr')

In [None]:
# Load the cached reduced-protocol (DevIndices) posterior, or simulate a training set
# restricted to those measurements, train an SNPE estimator and cache it.
# Both branches bind the result to `posteriorDev`, the name the inference cell reads.
# Two defects of the previous version are fixed here:
#   * a stray, over-indented copy of the full-protocol simulation lines sat after the
#     loop, which made the cell fail to parse (IndentationError);
#   * the training branch bound `posterior` instead of `posteriorDev`.
# NOTE: pickle.load can execute arbitrary code; only load pickles you produced yourself.
if os.path.exists(f"{network_path}/AxMultiMSDev_30.pickle"):
    with open(f"{network_path}/AxMultiMSDev_30.pickle", "rb") as handle:
        posteriorDev = pickle.load(handle)
else:
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        # Unpack the i-th prior draw into the parameter list CombSignal_poisson expects.
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]

        dh   = DHind[i]
        f    = [frac[i],1-frac[i]]

        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]  # subject whose b-vectors / b-values are simulated

        # Fixed noise setting handed to AddNoise.
        Noise = 50#np.random.rand()*30 + 20

        # Reduced protocol for this subject: entries 0-6 belong to Delta[0],
        # 7-12 to Delta[1] and 13-18 to Delta[2].
        dev_bvecs = BVecs[c][DevIndices]
        dev_bvals = BVals[c][DevIndices]
        TrainSig1 = CombSignal_poisson(dev_bvecs[:7],dev_bvals[:7],Delta[0],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig2 = CombSignal_poisson(dev_bvecs[7:13],dev_bvals[7:13],Delta[1],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig3 = CombSignal_poisson(dev_bvecs[13:],dev_bvals[13:],Delta[2],delta,[v,dpar,dperp,dh,f,a,s0])
        TrainSig.append(np.hstack([TrainSig1,TrainSig2,TrainSig3]))
        NoisyTrainSig.append(AddNoise(TrainSig[-1],s0,Noise))
    NoisyTrainSig = np.array(NoisyTrainSig)

    # The subject index is appended as the last observation feature, so the network
    # is conditioned on which protocol produced the signal.
    Obs = np.hstack([NoisyTrainSig,np.expand_dims(Choice, axis=-1)])
    Obs = torch.tensor(Obs).float()
    # Drop the three Cartesian direction columns; infer the remaining 13 parameters.
    Par = torch.tensor(TrainParams[:,3:]).float()
    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorDev = inference.build_posterior(density_estimator)
    with open(f"{network_path}/AxMultiMSDev_30.pickle", "wb") as handle:
        pickle.dump(posteriorDev, handle)

In [None]:
# Full-protocol inference: for every subject, estimate the 13 parameters in each
# masked pixel of the middle slice and collect one parameter map per subject.
Full_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,axial_middles,Outlines)):
    # Brain mask of this subject's middle slice
    mask = sma[:,:,sl]
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Define the function for optimization
    # (redefined on every iteration; it reads D, sl and kk of the current subject)
    def optimize_pixel(i, j):
        """Sample the full-protocol posterior for pixel (i, j); return (i, j, per-parameter histogram_mode of the samples)."""
        torch.manual_seed(10)  # If required
        # The subject index kk is appended to the signal, matching the training observations.
        posterior_samples_1 = posteriorFull.sample((1000,), x=np.append(D[i, j, sl, :],kk),show_progress_bars=False)
        return i, j, np.array([histogram_mode(p) for p in posterior_samples_1.T])
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )
    
    
    # Pixels outside the mask stay at zero.
    NoiseEst = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)
    # Quick look at the (1 - frac) map of this subject.
    fig,ax = plt.subplots()
    plt.imshow(1-NoiseEst[...,-3].T,vmin=0,vmax=1)
    plt.show()

In [None]:
# Reduced-protocol inference: same loop as the full-protocol one, but only the
# DevIndices measurements of each pixel are passed to posteriorDev.
Dev_SBI = []
for kk,(D,sl,sma) in enumerate(zip(Datas,axial_middles,Outlines)):
    # Brain mask of this subject's middle slice
    mask = sma[:,:,sl]
    
    # Get the indices where mask is True
    indices = np.argwhere(mask)
    
    # Define the function for optimization
    # (redefined on every iteration; it reads D, sl and kk of the current subject)
    def optimize_pixel(i, j):
        """Sample the reduced-protocol posterior for pixel (i, j); return (i, j, per-parameter histogram_mode of the samples)."""
        torch.manual_seed(10)  # If required
        # The subject index kk is appended to the signal, matching the training observations.
        posterior_samples_1 = posteriorDev.sample((1000,), x=np.append(D[i, j, sl, DevIndices],kk),show_progress_bars=False)
        return i, j, np.array([histogram_mode(p) for p in posterior_samples_1.T])
    
    # Initialize NoiseEst with the appropriate shape
    ArrShape = mask.shape
    
    # Use joblib to parallelize the optimization tasks
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(i, j) for i, j in tqdm(indices)
    )
    
    
    # Pixels outside the mask stay at zero.
    NoiseEst = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Dev_SBI.append(NoiseEst)
    # Quick look at the (1 - frac) map of this subject.
    fig,ax = plt.subplots()
    plt.imshow(1-NoiseEst[...,-3].T,vmin=0,vmax=1)
    plt.show()

In [None]:
# Exploratory check on one subject (index 5): SSIM between the reduced- and
# full-protocol maps of parameter 0 (the first fibre angle), averaged inside the brain mask.
SBI_comp = []
i = 5
NS1 = np.copy(Dev_SBI[i][...,0])
NS2 = np.copy(Full_SBI[i][...,0])
# Not used below; the mask slice is recomputed inline.
Ma  = Outlines[i][:,:,axial_middles[i]]
# data_range is the joint value range of both maps; with full=True, ssim returns the
# global score (named `core` here) and the per-pixel map.
core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=7)
masked_ssim = ssim_map[Outlines[i][:,:,axial_middles[i]]].mean()
# masked_ssim is already a scalar, so the extra .mean() leaves it unchanged.
SBI_comp.append(masked_ssim.mean())
SBI_comp

In [None]:
# Display-only: full-protocol map (NS2) followed by the reduced-protocol map (NS1).
plt.imshow(NS2)
plt.show()
plt.imshow(NS1)
plt.show()

In [None]:
# Display-only sanity check of masked_ssim's shape.
masked_ssim.shape

In [None]:
# Display-only: per-pixel SSIM map with its colour scale.
plt.imshow(ssim_map)
plt.colorbar()

In [None]:
# Display-only: the collected SSIM values.
SBI_comp

In [None]:
# Display-only: full-protocol map.
plt.imshow(NS2)

In [None]:
# Display-only: reduced-protocol map.
plt.imshow(NS1)

In [None]:
# Masked local SSIM between the reduced- and full-protocol maps of parameter 3
# (Dperp in the 13-parameter order), one value per subject.
SBI_comp = []
for i in range(8):
    NS1 = np.copy(Dev_SBI[i][...,3])
    NS2 = np.copy(Full_SBI[i][...,3])
    # The maps are 2-D (middle slice only), so the mask has to be the matching slice of
    # the 3-D brain mask — the same slice the single-subject SSIM check uses — rather
    # than the whole volume.
    Ma  = Outlines[i][:,:,axial_middles[i]]

    # masked_local_ssim is defined outside this file.
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SBI_comp.append(result)
print(SBI_comp)

In [None]:
    # NOTE(review): orphaned fragment. These lines repeat the tail of the per-subject
    # inference loop body, but here they sit in their own cell and every line is
    # indented, so the cell cannot run as written (unexpected indent). If dedented
    # and run, it would append one more map to Full_SBI built from whatever
    # `results` currently holds. Probably safe to delete; confirm with the author.
    NoiseEst = np.zeros(list(ArrShape) + [13])
    
    # Assign the optimization results to NoiseEst
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)
    fig,ax = plt.subplots()
    plt.imshow(1-NoiseEst[...,-3].T,vmin=0,vmax=1)
    plt.show()

In [None]:
# Reload all eight subjects at native resolution (no reslice call in this variant)
# and brain-mask them.
# NOTE: this rebinds BVecs, BVals, Datas and Outlines, and also `dat` / `data` from the
# single-subject cell; axial_middles is not rebuilt here.
Dirs = ['NMSS_11_1year','NMSS_15','NMSS_16','NMSS_18','NMSS_19','Ctrl055_R01_28','Ctrl056_R01_29','Ctrl057_R01_30']
BVecs = []
BVals = []
Deltas = []
deltas = []
# Never filled in this cell.
S_masks = []
Datas = []
Outlines = []
for D in tqdm(Dirs):
    dat = pmt.read_mat(MSDir+D+'/data_loaded.mat')
    BVecs.append(dat['direction'])
    BVals.append(dat['bval'])
    # The same acquisition timings (FixedParams) are recorded for every subject.
    Deltas.append(FixedParams['Delta'])
    deltas.append(FixedParams['delta'])

    data = dat['data']
    # Computed but not stored in this variant.
    axial_middle = data.shape[2] // 2
    md, mk = median_otsu(data, vol_idx=range(0, 10), median_radius=5,
                                 numpass=1, autocrop=False, dilate=2)
    Datas.append(md)
    Outlines.append(mk)

In [None]:
# Prior draws for the training set of the eight-individual network (300000 samples).
# Same recipe and seed as the earlier draws, except that Choice is 1-based here.
np.random.seed(12)
NumSamps = 300000

# Directions
# Normalised Gaussian draws give directions uniform on the sphere; SpherAng maps
# each one to (theta, phi) on the upper hemisphere.
x1  = np.random.randn(NumSamps)
y1  = np.random.randn(NumSamps)
z1  =  np.random.randn(NumSamps)
V = np.vstack([x1,y1,z1])
V = (V/np.linalg.norm(V,axis=0)).T
Angs = np.array([SpherAng(v) for v in V])

#Diffusion of restricted
# Uniform in [0, 5e-3).
Dpar  = np.random.rand(NumSamps)*5e-3
Dperp = np.random.rand(NumSamps)*5e-3

#Diffusion of hindered
# Six raw values per sample, turned into tensor values by ComputeDTI / ForceLowFA /
# mat_to_vals (all defined outside this file).
Params_abc =  np.random.rand(NumSamps,3)*0.14-0.07
Params_rest =  np.random.rand(NumSamps,3)*0.03-0.015
Params = np.hstack([Params_abc,Params_rest])
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

#Fraction of hindered
frac  = np.random.rand(NumSamps)

# Uniform in [1e-4, 5.1e-3).
mean = np.random.rand(NumSamps)*0.005+1e-4

# Uniform in [25, 2500).
S0Rand =np.random.rand(NumSamps)*2475+25

# Subject label 1..8 per sample; the training cell indexes the protocol lists with c-1.
Choice = np.random.choice([1,2,3,4,5,6,7,8],NumSamps)

In [None]:
# One row per sample; the final column is the dataset label multiplied by 100.
TrainParams = np.column_stack(
    [V, Angs, Dpar, Dperp, DHind, frac, mean, S0Rand, Choice * 100]
)

In [None]:
# Full-protocol posterior: reuse the cached network when the pickle exists,
# otherwise simulate the training set, train the estimator and cache it.
if os.path.exists(f"{network_path}/8Indv_50_300k_poisson.pickle"):
    # NOTE: unpickling executes arbitrary code. Only load files that were
    # written by this notebook.
    with open(f"{network_path}/8Indv_50_300k_poisson.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:
    TrainSig = []
    NoisyTrainSig = []

    # Row ranges of the three consecutive segments of the gradient table;
    # segment k is simulated with Delta[k].
    seg1 = slice(0, n_pts + 1)
    seg2 = slice(n_pts + 1, 2 * (n_pts + 1))
    seg3 = slice(2 * (n_pts + 1), None)

    for i in tqdm(range(NumSamps)):
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]

        dh = DHind[i]
        f = [frac[i], 1 - frac[i]]

        a = mean[i]
        s0 = S0Rand[i]
        c = Choice[i]  # 1-based dataset label

        Noise = 50  # constant level handed to AddNoise

        # Take directions AND b-values from dataset c. The b-values used to
        # come from the global `bvals` of a single dataset, which is only
        # equivalent if every dataset shares the same b-value table; the
        # reduced-protocol training already uses the per-dataset values.
        # A pickle cached before this change was trained with the old pairing.
        bvecs_c = BVecs[c - 1]
        bvals_c = BVals[c - 1]

        TrainSig1 = CombSignal_poisson(bvecs_c[seg1], bvals_c[seg1], Delta[0], delta,
                                       [v, dpar, dperp, dh, f, a, s0])
        TrainSig2 = CombSignal_poisson(bvecs_c[seg2], bvals_c[seg2], Delta[1], delta,
                                       [v, dpar, dperp, dh, f, a, s0])
        TrainSig3 = CombSignal_poisson(bvecs_c[seg3], bvals_c[seg3], Delta[2], delta,
                                       [v, dpar, dperp, dh, f, a, s0])
        TrainSig.append(np.hstack([TrainSig1, TrainSig2, TrainSig3]))

        # Observation = noisy signal followed by the dataset label * 100
        NoisyTrainSig.append(np.append(AddNoise(TrainSig[-1], s0, Noise), c * 100))
    NoisyTrainSig = np.array(NoisyTrainSig)

    Obs = torch.tensor(NoisyTrainSig).float()
    # The first three columns of TrainParams (V) are not inferred
    Par = torch.tensor(TrainParams[:, 3:]).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs=100)
    posterior = inference.build_posterior(density_estimator)
    with open(f"{network_path}/8Indv_50_300k_poisson.pickle", "wb") as handle:
        pickle.dump(posterior, handle)

In [None]:
def _pick_spread_indices(bve, bva, block, bval, distance_matrix, n_extra=2):
    """Greedily pick ``1 + n_extra`` directions of one shell in one segment.

    Parameters
    ----------
    bve, bva : arrays holding the complete gradient table (directions, b-values).
    block : slice selecting the segment of the table to draw from.
    bval : b-value of the shell inside that segment.
    distance_matrix : square pairwise-distance matrix used for the selection.
    n_extra : number of points added to the starting point (index 0).

    Returns
    -------
    list
        For every picked direction, the index of the FIRST row of ``bve``
        that equals it.

    NOTE(review): the picked/remaining indices are positions inside the shell
    subset, while the caller passes a distance matrix built from ALL rows of
    ``bve``, so the distances looked up do not belong to the subset entries.
    The lookup is kept exactly as it was because the cached reduced-protocol
    network depends on the resulting indices; passing
    ``squareform(pdist(subset))`` instead would presumably give the intended
    farthest-point selection, but requires retraining. TODO confirm.
    NOTE(review): if a direction occurs in more than one segment, the
    first-match lookup returns the earliest row. TODO confirm.
    """
    subset = bve[block][bva[block] == bval]

    # Arbitrary starting point: the first direction of the shell
    picked = [0]
    for _ in range(n_extra):
        remaining = list(set(range(len(subset))) - set(picked))

        # Distance from each remaining candidate to its closest picked point
        min_distances = np.min(distance_matrix[remaining][:, picked], axis=1)

        # Keep the candidate whose closest picked point is farthest away
        picked.append(remaining[np.argmax(min_distances)])

    return [np.where((b == bve).all(axis=1))[0][0] for b in subset[picked]]


IndxArr = []
BVecsDev = []
BValsDev = []
for bve, bva in zip(BVecs, BVals):
    # Identical for all six selections of a dataset, so compute it once
    distance_matrix = squareform(pdist(bve))

    # Three directions at b=2000 followed by three at b=4000, for each of the
    # three segments of the gradient table (rows 0-90, 91-181, 182-end).
    true_indices1 = (
        _pick_spread_indices(bve, bva, slice(None, 91), 2000, distance_matrix)
        + _pick_spread_indices(bve, bva, slice(None, 91), 4000, distance_matrix)
    )
    true_indices2 = (
        _pick_spread_indices(bve, bva, slice(91, 182), 2000, distance_matrix)
        + _pick_spread_indices(bve, bva, slice(91, 182), 4000, distance_matrix)
    )
    true_indices3 = (
        _pick_spread_indices(bve, bva, slice(182, None), 2000, distance_matrix)
        + _pick_spread_indices(bve, bva, slice(182, None), 4000, distance_matrix)
    )

    # Row 0 first, then the six picks per segment: 1 + 6 + 6 + 6 = 19 rows
    DevIndices = [0] + true_indices1 + true_indices2 + true_indices3
    bvecs_Dev = bve[DevIndices]
    bvals_Dev = bva[DevIndices]

    IndxArr.append(DevIndices)
    BVecsDev.append(bvecs_Dev)
    BValsDev.append(bvals_Dev)

In [None]:
# Fixed seed: the draws below must stay in this exact order to reproduce the
# same training parameters.
np.random.seed(12)
NumSamps = 300000

# Directions: normalise Gaussian triplets to unit length (one row per sample)
# and convert every row with SpherAng.
x1 = np.random.randn(NumSamps)
y1 = np.random.randn(NumSamps)
z1 = np.random.randn(NumSamps)
V = np.vstack([x1, y1, z1])
V = (V / np.linalg.norm(V, axis=0)).T
Angs = np.array(list(map(SpherAng, V)))

# Diffusion of restricted: uniform in [0, 5e-3)
Dpar = 5e-3 * np.random.rand(NumSamps)
Dperp = 5e-3 * np.random.rand(NumSamps)

# Diffusion of hindered: six uniform parameters per sample
# (three in [-0.07, 0.07), three in [-0.015, 0.015)) fed through ComputeDTI,
# then ForceLowFA and mat_to_vals.
Params_abc = np.random.rand(NumSamps, 3) * 0.14 - 0.07
Params_rest = np.random.rand(NumSamps, 3) * 0.03 - 0.015
Params = np.hstack((Params_abc, Params_rest))
DHind = np.array([ComputeDTI(p) for p in Params])
DHind = np.array([mat_to_vals(ForceLowFA(dt)) for dt in DHind])

# Fraction of hindered: uniform in [0, 1)
frac = np.random.rand(NumSamps)

# Uniform in [1e-4, 5.1e-3)
mean = 1e-4 + 0.005 * np.random.rand(NumSamps)

# Uniform in [25, 2500)
S0Rand = 25 + 2475 * np.random.rand(NumSamps)

# Dataset label (1..8) for every sample
Choice = np.random.choice(list(range(1, 9)), NumSamps)

In [None]:
# One row per sample; the final column is the dataset label multiplied by 100.
TrainParams = np.column_stack(
    [V, Angs, Dpar, Dperp, DHind, frac, mean, S0Rand, Choice * 100]
)

In [None]:
# Reduced-protocol posterior: train and cache it when no pickle exists yet,
# otherwise load the cached network.
if not os.path.exists(f"{network_path}/Dev_8Indv_50_300k_poisson.pickle"):
    TrainSig = []
    NoisyTrainSig = []
    for i in tqdm(range(NumSamps)):
        # Parameters of sample i
        v = np.array([Angs[i]])
        dpar = Dpar[i]
        dperp = Dperp[i]
        dh = DHind[i]
        f = [frac[i], 1 - frac[i]]
        a = mean[i]
        s0 = S0Rand[i]

        Noise = 50  # constant level handed to AddNoise
        c = Choice[i]  # 1-based dataset label

        # Reduced gradient table of dataset c: rows 0-6 go with Delta[0],
        # rows 7-12 with Delta[1], the remaining rows with Delta[2].
        dev_vecs = BVecsDev[c - 1]
        dev_vals = BValsDev[c - 1]
        TrainSig1 = CombSignal_poisson(
            dev_vecs[:7], dev_vals[:7], Delta[0], delta,
            [v, dpar, dperp, dh, f, a, s0],
        )
        TrainSig2 = CombSignal_poisson(
            dev_vecs[7:13], dev_vals[7:13], Delta[1], delta,
            [v, dpar, dperp, dh, f, a, s0],
        )
        TrainSig3 = CombSignal_poisson(
            dev_vecs[13:], dev_vals[13:], Delta[2], delta,
            [v, dpar, dperp, dh, f, a, s0],
        )
        TrainSig.append(np.hstack([TrainSig1, TrainSig2, TrainSig3]))

        # Observation = noisy signal followed by the dataset label * 100
        noisy = AddNoise(TrainSig[-1], s0, Noise)
        NoisyTrainSig.append(np.append(noisy, c * 100))
    NoisyTrainSig = np.array(NoisyTrainSig)

    Obs = torch.tensor(NoisyTrainSig).float()
    # The first three columns of TrainParams (V) are not inferred
    Par = torch.tensor(TrainParams[:, 3:]).float()

    # Inference object (NPE), fed with the simulated (parameter, signal) pairs
    inference = SNPE()
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator, build the posterior and cache it
    density_estimator = inference.train(stop_after_epochs=100)
    posteriorMin = inference.build_posterior(density_estimator)
    with open(f"{network_path}/Dev_8Indv_50_300k_poisson.pickle", "wb") as handle:
        pickle.dump(posteriorMin, handle)
else:
    with open(f"{network_path}/Dev_8Indv_50_300k_poisson.pickle", "rb") as handle:
        posteriorMin = pickle.load(handle)

In [None]:
[48,

In [None]:
# Show the first volume of slice 48 for the last loaded dataset. Datas gets
# one entry per name in Dirs (eight names), so index 8 was out of range.
plt.imshow(Datas[-1][:,:,48,0])

In [None]:
# Posterior-mean parameter maps from the full-protocol network, evaluated on
# slice 48 of each of the eight datasets.
Full_SBI = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Pixels of the slice whose summed signal is non-zero
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # In-plane shape of the output map
    ArrShape = mask.shape

    def optimize_pixel(i, j):
        # Same torch seed for every pixel
        torch.manual_seed(10)
        # Observation = full signal of the pixel followed by the dataset label
        observation = np.append(D[i, j, sl, :], 100 * (kk + 1))
        draws = posterior.sample((1000,), x=observation, show_progress_bars=False)
        return i, j, draws.mean(axis=0)

    # Evaluate all pixels in parallel
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(row, col) for row, col in tqdm(indices)
    )

    # 14 values per pixel; pixels outside the mask stay zero
    NoiseEst = np.zeros(tuple(ArrShape) + (14,))
    for i, j, x in results:
        NoiseEst[i, j] = x

    Full_SBI.append(NoiseEst)

In [None]:
# Posterior-mean parameter maps from the reduced-protocol network, evaluated
# on slice 48 of each of the eight datasets.
Min_SBI = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Pixels of the slice whose summed signal is non-zero
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # In-plane shape of the output map
    ArrShape = mask.shape

    def optimize_pixel(i, j):
        # Same torch seed for every pixel
        torch.manual_seed(10)
        # Observation = the volumes listed in IndxArr[kk] followed by the
        # dataset label
        observation = np.append(D[i, j, sl, IndxArr[kk]], 100 * (kk + 1))
        draws = posteriorMin.sample((1000,), x=observation, show_progress_bars=False)
        return i, j, draws.mean(axis=0)

    # Evaluate all pixels in parallel
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel)(row, col) for row, col in tqdm(indices)
    )

    # 14 values per pixel; pixels outside the mask stay zero
    NoiseEst = np.zeros(tuple(ArrShape) + (14,))
    for i, j, x in results:
        NoiseEst[i, j] = x

    Min_SBI.append(NoiseEst)

In [None]:
# Bounded least-squares parameter maps from the full protocol, slice 48 of
# each of the eight datasets.
Full_LS = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Pixels of the slice whose summed signal is non-zero
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # In-plane shape of the output map
    ArrShape = mask.shape

    # Split this dataset's gradient table into its three consecutive segments
    bve_split = [
        BVecs[kk][:(n_pts + 1)],
        BVecs[kk][(n_pts + 1):2 * (n_pts + 1)],
        BVecs[kk][2 * (n_pts + 1):],
    ]
    bva_split = [
        BVals[kk][:(n_pts + 1)],
        BVals[kk][(n_pts + 1):2 * (n_pts + 1)],
        BVals[kk][2 * (n_pts + 1):],
    ]

    def optimize_pixel_LS(i, j):
        # Fit residuals_S0 to the pixel signal, starting from `guess`
        fit = sp.optimize.least_squares(
            residuals_S0,
            guess,
            args=[D[i, j, sl, :], bve_split, bva_split, Delta],
            bounds=bounds,
            verbose=0,
            jac='3-point',
        )
        return i, j, fit.x

    # Fit all pixels in parallel
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel_LS)(row, col) for row, col in tqdm(indices)
    )

    # 13 values per pixel; pixels outside the mask stay zero
    NoiseEst_LS = np.zeros(tuple(ArrShape) + (13,))
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Full_LS.append(NoiseEst_LS)

In [None]:
# Bounded least-squares parameter maps from the reduced protocol, slice 48 of
# each of the eight datasets.
Min_LS = []
for kk, (D, sl) in enumerate(zip(Datas, [48] * 8)):
    # Pixels of the slice whose summed signal is non-zero
    mask = D[:, :, sl, :].sum(axis=-1) != 0
    indices = np.argwhere(mask)

    # In-plane shape of the output map
    ArrShape = mask.shape

    # Reduced gradient table split into rows 0-6, 7-12 and 13-end
    bve_splitd = [BVecsDev[kk][:7], BVecsDev[kk][7:13], BVecsDev[kk][13:]]
    bva_splitd = [BValsDev[kk][:7], BValsDev[kk][7:13], BValsDev[kk][13:]]

    def optimize_pixel_LS(i, j):
        # Fit residuals_S0 to the selected volumes with tight tolerances
        fit = sp.optimize.least_squares(
            residuals_S0,
            guess,
            args=[D[i, j, sl, IndxArr[kk]], bve_splitd, bva_splitd, Delta],
            bounds=bounds,
            verbose=0,
            xtol=1e-12,
            gtol=1e-12,
            ftol=1e-12,
            jac='3-point',
        )
        return i, j, fit.x

    # Fit all pixels in parallel
    results = Parallel(n_jobs=-1)(
        delayed(optimize_pixel_LS)(row, col) for row, col in tqdm(indices)
    )

    # 13 values per pixel; pixels outside the mask stay zero
    NoiseEst_LS = np.zeros(tuple(ArrShape) + (13,))
    for i, j, x in results:
        NoiseEst_LS[i, j] = x

    Min_LS.append(NoiseEst_LS)

In [None]:
# Store the full-protocol least-squares maps on disk (np.save appends the
# .npy suffix to the given name).
np.save('temp_Full_LS',Full_LS)

In [None]:
# Store the reduced-protocol least-squares maps on disk (np.save appends the
# .npy suffix to the given name).
np.save('temp_Min_LS',Min_LS)

In [None]:
# Audible "done" notification through the `say` shell command
# (presumably macOS only; the call just fails where `say` is missing).
os.system("say 'AxCaliber Fit done'") # or '\7'

In [None]:
# Load one mask per dataset name from WRKir and reorient it.
# NOTE(review): absolute, user-specific path; nothing is appended when no file
# name contains Name, which would shift every later entry of WMs. TODO confirm.
WRKir = '/Users/maximilianeggl/Dropbox/PostDoc/Silvia/SBIDTIPaper2/Code/MS_data/WM_masks/'
WMs = []
for i, Name in tqdm(enumerate([
        'NMSS_11_1year', 'NMSS_15', 'NMSS_16', 'NMSS_18', 'NMSS_19',
        'Ctrl055_R01_28', 'Ctrl056_R01_29', 'Ctrl057_R01_30'])):

    for k, x in enumerate(os.listdir(WRKir)):
        # Only files whose name contains the dataset name are relevant
        if Name not in x:
            continue

        print(Name)
        WM, affine, img = load_nifti(WRKir + x, return_img=True)

        # Swap the first two axes for every dataset; entries from index 5 on
        # additionally get an up-down flip before the common left-right flip.
        WM_t = np.swapaxes(WM, 0, 1)
        if i >= 5:
            WM_t = np.flipud(WM_t)
        WM_t = np.fliplr(WM_t)
        WMs.append(WM_t)

In [None]:
# Mean SSIM inside the outline between channel 3 of the reduced- and
# full-protocol SBI maps, one value per dataset.
SBI_comp = []
KK = [48] * 8  # slice index used for every dataset
for i in range(8):
    NS1 = np.copy(Min_SBI[i][..., 3])
    NS2 = np.copy(Full_SBI[i][..., 3])

    # data_range = joint maximum minus joint minimum of the two maps
    joint_max = max(NS1.max(), NS2.max())
    joint_min = min(NS1.min(), NS2.min())
    core, ssim_map = ssim(
        NS1,
        NS2,
        data_range=joint_max - joint_min,
        full=True,
        win_size=15,
    )

    # Average the SSIM map over the pixels of this dataset's outline
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    SBI_comp.append(masked_ssim)
print(SBI_comp)

In [None]:
def Par_frac(i,j,Mat):
    """Compute FA and MD for pixel ``(i, j)`` of ``Mat``.

    ``Mat[i, j]`` is converted to a matrix with ``vals_to_mat`` and its
    eigenvalues are taken with ``np.linalg.eigh`` (computed once and reused
    for both quantities).

    Returns
    -------
    tuple
        ``(i, j, [FA, MD])`` where ``MD`` is the mean of the eigenvalues and
        ``FA`` is the value returned by ``FracAni(eigenvalues, MD)``.
    """
    eigenvalues = np.linalg.eigh(vals_to_mat(Mat[i,j]))[0]
    MD = eigenvalues.mean()

    FA = FracAni(eigenvalues,MD)
    return i, j, [FA,MD]

In [None]:
# FA and MD maps from columns 4:10 of the full-protocol SBI estimates,
# evaluated on the outline of slice KK[jj] for each dataset.
KK = [48]*8
FA_Full_SBI = []
MD_Full_SBI = []
for jj in range(8):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j,Full_SBI[jj][...,4:10]) for i, j in tqdm(indices)
    )

    # Size the maps from this dataset's own mask instead of the global
    # ArrShape, which only holds the shape left by the last fitting cell run.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # k is the [FA, MD] pair returned by Par_frac
    for i, j, k in results:
        temp1[i, j] = k[0]
        temp2[i, j] = k[1]

    FA_Full_SBI.append(temp1)
    MD_Full_SBI.append(temp2)

In [None]:
# FA and MD maps from columns 4:10 of the reduced-protocol SBI estimates,
# evaluated on the outline of slice KK[jj] for each dataset.
KK = [48]*8
FA_Min_SBI = []
MD_Min_SBI = []
for jj in range(8):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j,Min_SBI[jj][...,4:10]) for i, j in tqdm(indices)
    )

    # Size the maps from this dataset's own mask instead of the global
    # ArrShape, which only holds the shape left by the last fitting cell run.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # k is the [FA, MD] pair returned by Par_frac
    for i, j, k in results:
        temp1[i, j] = k[0]
        temp2[i, j] = k[1]

    FA_Min_SBI.append(temp1)
    MD_Min_SBI.append(temp2)

In [None]:
# FA and MD maps from columns 4:10 of the least-squares estimates (full
# protocol first, reduced protocol second), evaluated on the outline of
# slice KK[jj] for each dataset.
KK = [48]*8
FA_Full_LS = []
MD_Full_LS = []
for jj in range(8):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j,Full_LS[jj][...,4:10]) for i, j in tqdm(indices)
    )

    # Size the maps from this dataset's own mask instead of the global
    # ArrShape, which only holds the shape left by the last fitting cell run.
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # k is the [FA, MD] pair returned by Par_frac
    for i, j, k in results:
        temp1[i, j] = k[0]
        temp2[i, j] = k[1]

    FA_Full_LS.append(temp1)
    MD_Full_LS.append(temp2)

KK = [48]*8
FA_Min_LS = []
MD_Min_LS = []
for jj in range(8):
    mask = Outlines[jj][:,:,KK[jj]]
    indices = np.argwhere(mask)

    results = Parallel(n_jobs=-1)(
        delayed(Par_frac)(i, j,Min_LS[jj][...,4:10]) for i, j in tqdm(indices)
    )

    # Same sizing rule as above: use the mask of the current dataset
    temp1 = np.zeros(mask.shape)
    temp2 = np.zeros(mask.shape)
    # k is the [FA, MD] pair returned by Par_frac
    for i, j, k in results:
        temp1[i, j] = k[0]
        temp2[i, j] = k[1]

    FA_Min_LS.append(temp1)
    MD_Min_LS.append(temp2)

In [None]:
# SSIM and spread statistics for channel 3 of the parameter maps.
jj = 0
KK = [48] * 8  # slice index used for every dataset

# Reduced vs. full protocol, SBI. The reduced map is smoothed (sigma=0.5)
# before the comparison.
SBI_comp = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(Min_SBI[i][..., 3]), sigma=0.5)
    NS2 = np.copy(Full_SBI[i][..., 3])

    # data_range = joint maximum minus joint minimum of the two maps
    value_span = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    core, ssim_map = ssim(NS1, NS2, data_range=value_span, full=True, win_size=15)
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    SBI_comp.append(masked_ssim)

# Reduced vs. full protocol, least squares (reduced map smoothed as above)
LS_comp = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(Min_LS[i][..., 3]), sigma=0.5)
    NS2 = np.copy(Full_LS[i][..., 3])

    value_span = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    core, ssim_map = ssim(NS1, NS2, data_range=value_span, full=True, win_size=15)
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    LS_comp.append(masked_ssim)

# Full-protocol SBI vs. full-protocol least squares (no smoothing)
SBI_LS_comp = []
for i in range(8):
    NS1 = np.copy(Full_SBI[i][..., 3])
    NS2 = np.copy(Full_LS[i][..., 3])

    value_span = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    core, ssim_map = ssim(NS1, NS2, data_range=value_span, full=True, win_size=15)
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim)

# Standard deviation of the channel inside the WM mask (slice 48), per dataset
Prec7_SBI, PrecFull_SBI = [], []
Prec7_NLLS, PrecFull_NLLS = [], []
for i in range(8):
    Prec7_SBI.append(np.std(Min_SBI[i][..., 3][WMs[i].astype(bool)[:, :, 48]]))
    PrecFull_SBI.append(np.std(Full_SBI[i][..., 3][WMs[i].astype(bool)[:, :, 48]]))

    Prec7_NLLS.append(np.std(Min_LS[i][..., 3][WMs[i].astype(bool)[:, :, 48]]))
    PrecFull_NLLS.append(np.std(Full_LS[i][..., 3][WMs[i].astype(bool)[:, :, 48]]))


In [None]:
# Box plots of the SSIM scores (SBI_comp at x=1.3, LS_comp at x=1.9).
fig, ax = plt.subplots(figsize=(3.2, 4.8))
plt.sca(ax)  # make ax current so the plt.* calls below draw on it

BoxPlots(
    np.array(SBI_comp), np.array([1.3]),
    ['lightseagreen'], ['paleturquoise'],
    ax, widths=0.2, scatter=True, scatter_alpha=0.5,
)
BoxPlots(
    np.array(LS_comp), np.array([1.9]),
    ['sandybrown'], ['peachpuff'],
    ax, widths=0.2, scatter=True,
)

# Dashed horizontal reference line at 0.66
plt.axhline(0.66, lw=3, ls='--', c='k')
# NOTE(review): the label reads 'NNLS' while the data variables are named
# *_NLLS / LS_*; check which abbreviation is intended.
plt.xticks([1.3, 1.9], ['SBI', 'NNLS'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc + 'MS_Ax_SSIM_Dperp.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of the within-mask standard deviations: full and reduced protocol
# for SBI (x=0.65, 1.1) and for least squares (x=1.8, 2.15).
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

box_specs = [
    (PrecFull_SBI, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI, 1.1, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS, 1.8, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS, 2.15, 'sandybrown', 'peachpuff'),
]
for values, position, first_color, second_color in box_specs:
    BoxPlots(np.array(values), np.array([position]), [first_color], [second_color],
             ax1, widths=0.2, scatter=True)
plt.xticks([0.65, 1.1, 1.8, 2.15], ['Full', 'Red.', 'Full', 'Red.'], fontsize=32, rotation=90)

# Hatched band between the 25th and 77th percentile of the NaN-free
# full-protocol values, drawn behind each pair of boxes.
# NOTE(review): 77 looks like it may be meant to be 75; confirm.
nlls_full = np.array(PrecFull_NLLS)
nlls_full = nlls_full[~np.isnan(nlls_full)]
x = np.arange(1.7, 2.3, 0.05)
y1 = np.full_like(x, np.percentile(nlls_full, 25))
y2 = np.full_like(x, np.percentile(nlls_full, 77))
plt.fill_between(x, y1, y2, color='sandybrown', zorder=10, alpha=0.2, hatch='//')

sbi_full = np.array(PrecFull_SBI)
sbi_full = sbi_full[~np.isnan(sbi_full)]
x = np.arange(0.55, 1.25, 0.05)
y1 = np.full_like(x, np.percentile(sbi_full, 25))
y2 = np.full_like(x, np.percentile(sbi_full, 77))
plt.fill_between(x, y1, y2, color='mediumturquoise', zorder=10, alpha=0.2, hatch='//')

# Scientific notation on the y axis
ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
if Save:
    plt.savefig(FigLoc + 'MS_Ax_Prec_Dperp.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# SSIM and spread statistics for the MD maps.
jj = 0
KK = [48] * 8  # slice index used for every dataset

# Reduced vs. full protocol, SBI. The reduced map is smoothed (sigma=0.5)
# before the comparison.
SBI_comp = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(MD_Min_SBI[i]), sigma=0.5)
    NS2 = np.copy(MD_Full_SBI[i])

    # data_range = joint maximum minus joint minimum of the two maps
    value_span = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    core, ssim_map = ssim(NS1, NS2, data_range=value_span, full=True, win_size=15)
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    SBI_comp.append(masked_ssim)

# Reduced vs. full protocol, least squares (reduced map smoothed as above)
LS_comp = []
for i in range(8):
    NS1 = gaussian_filter(np.copy(MD_Min_LS[i]), sigma=0.5)
    NS2 = np.copy(MD_Full_LS[i])

    value_span = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    core, ssim_map = ssim(NS1, NS2, data_range=value_span, full=True, win_size=15)
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    LS_comp.append(masked_ssim)

# Full-protocol SBI vs. full-protocol least squares (no smoothing)
SBI_LS_comp = []
for i in range(8):
    NS1 = np.copy(MD_Full_SBI[i])
    NS2 = np.copy(MD_Full_LS[i])

    value_span = max(NS1.max(), NS2.max()) - min(NS1.min(), NS2.min())
    core, ssim_map = ssim(NS1, NS2, data_range=value_span, full=True, win_size=15)
    masked_ssim = ssim_map[Outlines[i][:, :, KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim)

# Standard deviation of MD inside the WM mask (slice 48), per dataset
Prec7_SBI, PrecFull_SBI = [], []
Prec7_NLLS, PrecFull_NLLS = [], []
for i in range(8):
    Prec7_SBI.append(np.std(MD_Min_SBI[i][WMs[i].astype(bool)[:, :, 48]]))
    PrecFull_SBI.append(np.std(MD_Full_SBI[i][WMs[i].astype(bool)[:, :, 48]]))

    Prec7_NLLS.append(np.std(MD_Min_LS[i][WMs[i].astype(bool)[:, :, 48]]))
    PrecFull_NLLS.append(np.std(MD_Full_LS[i][WMs[i].astype(bool)[:, :, 48]]))


In [None]:
# Box plots of the SSIM scores (SBI_comp at x=1.3, LS_comp at x=1.9).
fig, ax = plt.subplots(figsize=(3.2, 4.8))
plt.sca(ax)  # make ax current so the plt.* calls below draw on it

BoxPlots(
    np.array(SBI_comp), np.array([1.3]),
    ['lightseagreen'], ['paleturquoise'],
    ax, widths=0.2, scatter=True, scatter_alpha=0.5,
)
BoxPlots(
    np.array(LS_comp), np.array([1.9]),
    ['sandybrown'], ['peachpuff'],
    ax, widths=0.2, scatter=True,
)

# Dashed horizontal reference line at 0.66
plt.axhline(0.66, lw=3, ls='--', c='k')
# NOTE(review): the label reads 'NNLS' while the data variables are named
# *_NLLS / LS_*; check which abbreviation is intended.
plt.xticks([1.3, 1.9], ['SBI', 'NNLS'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc + 'MS_Ax_SSIM_MD.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of the within-mask standard deviations: full and reduced protocol
# for SBI (x=0.65, 1.1) and for least squares (x=1.8, 2.15).
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

box_specs = [
    (PrecFull_SBI, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI, 1.1, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS, 1.8, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS, 2.15, 'sandybrown', 'peachpuff'),
]
for values, position, first_color, second_color in box_specs:
    BoxPlots(np.array(values), np.array([position]), [first_color], [second_color],
             ax1, widths=0.2, scatter=True)
plt.xticks([0.65, 1.1, 1.8, 2.15], ['Full', 'Red.', 'Full', 'Red.'], fontsize=32, rotation=90)

# Hatched band between the 25th and 77th percentile of the NaN-free
# full-protocol values, drawn behind each pair of boxes.
# NOTE(review): 77 looks like it may be meant to be 75; confirm.
nlls_full = np.array(PrecFull_NLLS)
nlls_full = nlls_full[~np.isnan(nlls_full)]
x = np.arange(1.7, 2.3, 0.05)
y1 = np.full_like(x, np.percentile(nlls_full, 25))
y2 = np.full_like(x, np.percentile(nlls_full, 77))
plt.fill_between(x, y1, y2, color='sandybrown', zorder=10, alpha=0.2, hatch='//')

sbi_full = np.array(PrecFull_SBI)
sbi_full = sbi_full[~np.isnan(sbi_full)]
x = np.arange(0.55, 1.25, 0.05)
y1 = np.full_like(x, np.percentile(sbi_full, 25))
y2 = np.full_like(x, np.percentile(sbi_full, 77))
plt.fill_between(x, y1, y2, color='mediumturquoise', zorder=10, alpha=0.2, hatch='//')

# Scientific notation on the y axis
ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
if Save:
    plt.savefig(FigLoc + 'MS_Ax_Prec_MD.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# SSIM and spread statistics for the parameter stored in column 10.
#
# The column is addressed with a positive index on purpose. The SBI maps hold
# 14 values per pixel and the least-squares maps 13, so the previous negative
# index (-4) pointed at column 10 of the SBI maps but at column 9 of the
# least-squares maps, i.e. the two methods were evaluated on different
# columns. Column 9 is also the last of the 4:10 block that Par_frac consumes
# for both kinds of map.
# NOTE(review): this assumes the least-squares parameter vector uses the same
# ordering as the SBI one for its first 13 entries - confirm against `guess`.
jj = 10
SBI_comp = []
KK = [48]*8
for i in range(8):
    NS1 = np.copy(Min_SBI[i][...,jj])
    # Smooth the reduced-protocol map before the comparison
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 = np.copy(Full_SBI[i][...,jj])

    # data_range = joint maximum minus joint minimum of the two maps
    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    # Average the SSIM map over the pixels of this dataset's outline
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_comp.append(masked_ssim)

LS_comp = []
for i in range(8):
    NS1 = np.copy(Min_LS[i][...,jj])
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 = np.copy(Full_LS[i][...,jj])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    LS_comp.append(masked_ssim)

# Full-protocol SBI vs. full-protocol least squares (no smoothing)
SBI_LS_comp = []
for i in range(8):
    NS1 = np.copy(Full_SBI[i][...,jj])
    NS2 = np.copy(Full_LS[i][...,jj])

    core,ssim_map = ssim(NS1,NS2, data_range=max([NS1.max(),NS2.max()])-min([NS1.min(),NS2.min()]),full=True,win_size=15)
    masked_ssim = ssim_map[Outlines[i][:,:,KK[i]]].mean()
    SBI_LS_comp.append(masked_ssim)

# Standard deviation of the column inside the WM mask (slice 48), per dataset
Prec7_SBI = []
PrecFull_SBI = []

Prec7_NLLS = []
PrecFull_NLLS = []
for i in range(8):
    Prec7_SBI.append(np.std(Min_SBI[i][...,jj][WMs[i].astype(bool)[:,:,48]]))
    PrecFull_SBI.append(np.std(Full_SBI[i][...,jj][WMs[i].astype(bool)[:,:,48]]))

    Prec7_NLLS.append(np.std(Min_LS[i][...,jj][WMs[i].astype(bool)[:,:,48]]))
    PrecFull_NLLS.append(np.std(Full_LS[i][...,jj][WMs[i].astype(bool)[:,:,48]]))


In [None]:
# Box plots of the SSIM scores (SBI_comp at x=1.3, LS_comp at x=1.9).
fig, ax = plt.subplots(figsize=(3.2, 4.8))
plt.sca(ax)  # make ax current so the plt.* calls below draw on it

BoxPlots(
    np.array(SBI_comp), np.array([1.3]),
    ['lightseagreen'], ['paleturquoise'],
    ax, widths=0.2, scatter=True, scatter_alpha=0.5,
)
BoxPlots(
    np.array(LS_comp), np.array([1.9]),
    ['sandybrown'], ['peachpuff'],
    ax, widths=0.2, scatter=True,
)

# Dashed horizontal reference line at 0.66
plt.axhline(0.66, lw=3, ls='--', c='k')
# NOTE(review): the label reads 'NNLS' while the data variables are named
# *_NLLS / LS_*; check which abbreviation is intended.
plt.xticks([1.3, 1.9], ['SBI', 'NNLS'], fontsize=32, rotation=90)
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax.set_ylim(-0.1, 1)

if Save:
    plt.savefig(FigLoc + 'MS_Ax_SSIM_Frac.pdf', format='PDF', transparent=True, bbox_inches='tight')

In [None]:
# Box plots of the within-mask standard deviations: full and reduced protocol
# for SBI (x=0.65, 1.1) and for least squares (x=1.8, 2.15).
fig, ax1 = plt.subplots(1, 1, figsize=(3.2, 4.8))

box_specs = [
    (PrecFull_SBI, 0.65, 'mediumturquoise', 'paleturquoise'),
    (Prec7_SBI, 1.1, 'mediumturquoise', 'paleturquoise'),
    (PrecFull_NLLS, 1.8, 'sandybrown', 'peachpuff'),
    (Prec7_NLLS, 2.15, 'sandybrown', 'peachpuff'),
]
for values, position, first_color, second_color in box_specs:
    BoxPlots(np.array(values), np.array([position]), [first_color], [second_color],
             ax1, widths=0.2, scatter=True)
plt.xticks([0.65, 1.1, 1.8, 2.15], ['Full', 'Red.', 'Full', 'Red.'], fontsize=32, rotation=90)

# Hatched band between the 25th and 77th percentile of the NaN-free
# full-protocol values, drawn behind each pair of boxes.
# NOTE(review): 77 looks like it may be meant to be 75; confirm.
nlls_full = np.array(PrecFull_NLLS)
nlls_full = nlls_full[~np.isnan(nlls_full)]
x = np.arange(1.7, 2.3, 0.05)
y1 = np.full_like(x, np.percentile(nlls_full, 25))
y2 = np.full_like(x, np.percentile(nlls_full, 77))
plt.fill_between(x, y1, y2, color='sandybrown', zorder=10, alpha=0.2, hatch='//')

sbi_full = np.array(PrecFull_SBI)
sbi_full = sbi_full[~np.isnan(sbi_full)]
x = np.arange(0.55, 1.25, 0.05)
y1 = np.full_like(x, np.percentile(sbi_full, 25))
y2 = np.full_like(x, np.percentile(sbi_full, 77))
plt.fill_between(x, y1, y2, color='mediumturquoise', zorder=10, alpha=0.2, hatch='//')

# Scientific notation on the y axis, with fixed tick positions
ax1.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
ax1.set_yticks([0, 0.1, 0.2])
if Save:
    plt.savefig(FigLoc + 'MS_Ax_Prec_Frac.pdf', format='PDF', transparent=True, bbox_inches='tight')