# Imports 

In [None]:
import numpy as np
import dill as pickle
import os
import tqdm.auto as tqdm
from joblib import Parallel, delayed

import scipy.stats as stats
from scipy.spatial.distance import pdist, squareform
from scipy.stats import lognorm,norm

from skimage.metrics import structural_similarity as ssim

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from matplotlib.colors import TwoSlopeNorm
import matplotlib.patches as mpatches
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import ScalarFormatter
from scipy.special import i0e

from scipy.optimize import least_squares
from numpy.linalg import inv


# Global matplotlib styling shared by every figure in this notebook.
# `font` stays a module-level dict; it is unpacked into plt.rc('font', ...).
font = dict(family='sans-serif', size=14)
font['sans-serif'] = ['Helvetica']  # Helvetica as the sans-serif face
plt.rc('font', **font)

# Large tick labels on both axes.
for _tick_axis in ('ytick', 'xtick'):
    plt.rc(_tick_axis, labelsize=24)

# Plain-text (non-LaTeX) rendering in Helvetica; hide the top/right spines
# and the legend frame.
plt.rcParams.update({
    "text.usetex": False,
    "font.family": "Helvetica",
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
})

from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.inference.potentials.posterior_based_potential import posterior_estimator_based_potential
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
from sbi.utils import process_prior,BoxUniform

import torch
from torch.distributions import Categorical,Normal
from torch import Tensor


from dipy.sims.voxel import single_tensor
from dipy.data import get_fnames
from dipy.io.gradients import read_bvals_bvecs
from dipy.core.gradients import gradient_table
from dipy.reconst.dti import (decompose_tensor, from_lower_triangular)
from dipy.io.image import load_nifti
from dipy.segment.mask import median_otsu
import dipy.reconst.dti as dti
import dipy.reconst.dki as dki
from dipy.align.reslice import reslice
from dipy.core.sphere import disperse_charges, Sphere, HemiSphere

import pymatreader as pmt

from scipy.optimize import minimize
from scipy.special import i0

from dwMRI_BasicFuncs import *

In [None]:
def BoxPlots(y_data, positions, colors, colors2, ax,hatch = False,scatter=False,scatter_alpha=0.5, **kwargs):
    """Draw outline-only box plots on *ax*, optionally overlaying jittered points.

    Parameters
    ----------
    y_data : 1-D array-like of numbers, or a sequence of 1-D array-likes
        Either a single group of values (one box) or one group per box
        (ragged groups are fine). For a 2-D array every ROW is one group.
        NaNs are dropped from each group before plotting.
    positions : sequence of float
        x-position of each box (a one-element sequence for a single group).
    colors : sequence of colours
        Outline and whisker colour of each box.
    colors2 : colour or sequence of colours
        Scatter colour: one colour for a single group, one per group otherwise.
    ax : matplotlib Axes
        Axes to draw on.
    hatch : bool
        If True, every box is hatched with ``'/'``.
    scatter : bool
        If True, the individual values are overlaid, each shifted horizontally
        by Student-t jitter (df=6, scale=0.02) so overlapping points stay visible.
    scatter_alpha : float
        Alpha of the scattered points.
    **kwargs
        Forwarded to ``ax.boxplot``.
    """
    GREY_DARK = "#747473"
    jitter = 0.02

    # One group of scalars vs. a collection of groups. Inspecting the elements
    # (instead of np.ndim(y_data)) also works for ragged lists of groups, for
    # which np.ndim raises on recent NumPy.
    single_group = all(np.ndim(v) == 0 for v in y_data)

    # Drop NaNs per group (asarray also accepts plain Python lists).
    if single_group:
        values = np.asarray(y_data, dtype=float)
        cleaned_data = values[~np.isnan(values)]
    else:
        cleaned_data = []
        for group in y_data:
            group = np.asarray(group, dtype=float)
            cleaned_data.append(group[~np.isnan(group)])

    # Boxes are patches (facecolor/edgecolor); medians and whiskers are lines.
    boxprops = dict(linewidth=2, facecolor='none', edgecolor='turquoise')
    medianprops = dict(linewidth=2, color=GREY_DARK, solid_capstyle="butt")
    whiskerprops = dict(linewidth=2, color='turquoise')

    bplot = ax.boxplot(
        cleaned_data,
        positions=positions,
        showfliers=False,
        showcaps=False,
        medianprops=medianprops,
        whiskerprops=whiskerprops,
        boxprops=boxprops,
        patch_artist=True,
        **kwargs
    )

    # Per-box colours: outline (optionally hatched) plus its two whiskers.
    for k, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[k])
        if hatch:
            box.set_hatch('/')
        for whisker in bplot['whiskers'][2 * k:2 * k + 2]:
            whisker.set_color(colors[k])

    # Caps are disabled above; colour any that are returned (two per box).
    for k, cap in enumerate(bplot.get('caps', [])):
        cap.set_color(colors[k // 2])

    if scatter:
        jitter_dist = stats.t(df=6, scale=jitter)
        if single_group:
            # Jitter around the single box position (previously the jitter
            # was computed with wrong broadcasting and then ignored).
            x_center = np.ravel(positions)[0]
            x_jittered = x_center + jitter_dist.rvs(len(cleaned_data))
            ax.scatter(x_jittered, cleaned_data, s=100, color=colors2, alpha=scatter_alpha)
        else:
            for pos, group, c in zip(positions, cleaned_data, colors2):
                x_jittered = pos + jitter_dist.rvs(len(group))
                ax.scatter(x_jittered, group, s=100, color=c, alpha=scatter_alpha)

## Basic functions

In [None]:
def vals_to_mat(dt):
    """Expand six tensor elements into a symmetric 3x3 matrix.

    ``dt[0:6]`` is read in the order ``[Dxx, Dxy, Dyy, Dxz, Dyz, Dzz]``; any
    further entries are ignored.
    """
    # Upper-triangle (row, col) slot for each of the six elements of dt.
    slots = ((0, 0), (0, 1), (1, 1), (0, 2), (1, 2), (2, 2))
    matrix = np.zeros((3, 3))
    for k, (row, col) in enumerate(slots):
        matrix[row, col] = matrix[col, row] = dt[k]
    return matrix

def mat_to_vals(DTI):
    """Pack a 3x3 tensor into its six upper-triangle elements.

    Returns a float array ordered ``[Dxx, Dxy, Dyy, Dxz, Dyz, Dzz]``, the
    inverse of ``vals_to_mat`` for symmetric input. Only the upper triangle
    of ``DTI`` is read.
    """
    slots = ((0, 0), (0, 1), (1, 1), (0, 2), (1, 2), (2, 2))
    packed = np.zeros(6)
    for k, (row, col) in enumerate(slots):
        packed[k] = DTI[row, col]
    return packed

def fill_lower_diag(a):
    """Build a 3x3 lower-triangular matrix from a 6-element parameter vector.

    ``a`` holds the diagonal first, then the strictly-lower entries:
    ``L00, L11, L22, L10, L20, L21`` (only ``a[0:6]`` is read).
    """
    lower = np.zeros((3, 3), dtype=float)
    rows, cols = np.tril_indices(3)  # row-major: (0,0),(1,0),(1,1),(2,0),(2,1),(2,2)
    lower[rows, cols] = [a[k] for k in (0, 3, 1, 4, 5, 2)]
    return lower


def ForceLowFA(dt):
    """Return a more isotropic (lower-FA) version of the symmetric matrix *dt*.

    The eigenvalues are clipped into ``[r * m, m]``, where ``m`` is their mean
    and ``r`` is a single draw of ``np.random.rand()`` (so the result depends
    on NumPy's global RNG). The eigenvectors are kept unchanged.
    """
    evals, evecs = np.linalg.eigh(dt)
    mean_eval = np.mean(evals)

    floor = mean_eval * np.random.rand()
    ceiling = mean_eval * 1.0
    squeezed = np.clip(evals, floor, ceiling)

    return evecs @ np.diag(squeezed) @ evecs.T
    
def FracAni(evals,MD):
    """Fractional anisotropy from eigenvalues *evals* and mean diffusivity *MD*.

    FA = sqrt(3/2) * sqrt(sum((evals - MD)**2)) / sqrt(sum(evals**2)).
    ``evals`` must support array arithmetic (e.g. a NumPy array).
    """
    spread = np.sum((evals - MD) ** 2)
    magnitude = np.sum(evals ** 2)
    return np.sqrt(3 * spread) / (np.sqrt(2) * np.sqrt(magnitude))

def clip_negative_eigenvalues(matrix, min_eigenvalue=1e-5):
    """Return *matrix* with every eigenvalue raised to at least *min_eigenvalue*.

    Used to turn a possibly indefinite diffusion tensor into a positive
    definite one before computing metrics.

    Parameters
    ----------
    matrix : (n, n) array_like
    min_eigenvalue : float, default 1e-5
        Floor applied to the eigenvalues (eigenvalues below it are replaced).

    Returns
    -------
    ndarray (n, n)
        For exactly symmetric input the result is real and symmetric
        (computed with ``eigh``). Otherwise the general eigendecomposition
        ``V diag(w) V^-1`` is used, which may be complex.
    """
    matrix = np.asarray(matrix)
    if np.array_equal(matrix, matrix.T):
        # eigh: real eigenvalues and orthonormal eigenvectors, so the
        # reconstruction V diag(w) V^T is real, symmetric and needs no inverse.
        eigenvalues, eigenvectors = np.linalg.eigh(matrix)
        clipped_eigenvalues = np.maximum(eigenvalues, min_eigenvalue)
        return (eigenvectors * clipped_eigenvalues) @ eigenvectors.T

    # Non-symmetric input: general eigendecomposition.
    eigenvalues, eigenvectors = np.linalg.eig(matrix)
    clipped_eigenvalues = np.maximum(eigenvalues, min_eigenvalue)
    return eigenvectors @ np.diag(clipped_eigenvalues) @ np.linalg.inv(eigenvectors)


In [None]:
def FitDT(Dat,seed=1):
    """Fit marginal distributions to diffusion-tensor elements.

    Parameters
    ----------
    Dat : ndarray (n_samples, >= 2)
        Column 0 is fitted with a log-normal and column 1 with a normal
        distribution (GenDTKT later uses the first for the diagonal slots
        0, 2, 5 and the second for the off-diagonal slots 1, 3, 4).
    seed : int
        NumPy's global RNG is seeded with this value (side effect).

    Returns
    -------
    (frozen lognorm, frozen norm)
    """
    np.random.seed(seed)
    shape, loc, scale = lognorm.fit(Dat[:, 0])
    diag_dist = stats.lognorm(shape, loc=loc, scale=scale)

    off_loc, off_scale = norm.fit(Dat[:, 1])
    off_diag_dist = stats.norm(loc=off_loc, scale=off_scale)

    return diag_dist, off_diag_dist

def FitKT(Dat,seed=1):
    """Fit marginal distributions to kurtosis-tensor elements.

    One representative column per KT group is fitted:
    column 0 -> log-normal (x4), column 3 -> normal (R1),
    column 9 -> log-normal (x2), column 12 -> normal (R2).
    NumPy's global RNG is seeded with *seed* first (side effect).

    Returns
    -------
    tuple of four frozen distributions ``(x4, R1, x2, R2)``
    """
    np.random.seed(seed)
    plan = ((0, lognorm), (3, norm), (9, lognorm), (12, norm))
    fitted = []
    for column, family in plan:
        fitted.append(family(*family.fit(Dat[:, column])))
    return tuple(fitted)

def GenDTKT(DT_Fits,KT_Fits,seed,size):
    """Draw *size* synthetic DT and KT parameter vectors from fitted marginals.

    DT columns 0, 2, 5 come from ``DT_Fits[0]`` and 1, 3, 4 from
    ``DT_Fits[1]``. KT columns 0-2, 3-8, 9-11 and 12-14 come from
    ``KT_Fits[0..3]``. NumPy's global RNG is seeded with *seed* first; the
    column order below is the draw order, so results are reproducible.

    Returns
    -------
    DT : ndarray (size, 6)
    KT : ndarray (size, 15)
    """
    np.random.seed(seed)
    DT = np.zeros([size, 6])
    KT = np.zeros([size, 15])

    schedule = (
        (DT, DT_Fits[0], (0, 2, 5)),
        (DT, DT_Fits[1], (1, 3, 4)),
        (KT, KT_Fits[0], range(0, 3)),
        (KT, KT_Fits[1], range(3, 9)),
        (KT, KT_Fits[2], range(9, 12)),
        (KT, KT_Fits[3], range(12, 15)),
    )
    for target, dist, columns in schedule:
        for col in columns:
            target[:, col] = dist.rvs(size)

    return DT, KT
    
def DKIMetrics(dt,kt,analytical=True):
    """Compute scalar DKI metrics from a diffusion tensor and a kurtosis tensor.

    Parameters
    ----------
    dt : array_like
        Either a 3x3 diffusion tensor or its six unique elements in the
        order accepted by ``vals_to_mat``.
    kt : array_like (15,)
        Kurtosis-tensor elements, appended after the eigendecomposition to
        form the 27-element parameter vector used by ``dipy.reconst.dki``.
    analytical : bool
        Forwarded to dipy's mean/axial/radial kurtosis estimators.

    Returns
    -------
    tuple
        ``(mk, ak, rk, mkt, kfa)``: mean, axial and radial kurtosis, mean of
        the kurtosis tensor, and kurtosis fractional anisotropy.
    """
    # Kurtosis clipping range: -3/7 lower bound (the same bound used to filter
    # the HCP fits), no upper clipping.
    min_kurtosis = -3.0 / 7
    max_kurtosis = np.inf

    if np.ndim(dt) == 1:
        dt = vals_to_mat(dt)
    evals, evecs = np.linalg.eigh(dt)
    # Sort eigenpairs by decreasing eigenvalue.
    idx = np.argsort(evals)[::-1]
    evals = evals[idx]
    evecs = evecs[:, idx]

    params = np.concatenate([evals, np.hstack(evecs), kt])

    mk = dki.mean_kurtosis(params, analytical=analytical,
                           min_kurtosis=min_kurtosis, max_kurtosis=max_kurtosis)
    ak = dki.axial_kurtosis(params, analytical=analytical,
                            min_kurtosis=min_kurtosis, max_kurtosis=max_kurtosis)
    rk = dki.radial_kurtosis(params, analytical=analytical,
                             min_kurtosis=min_kurtosis, max_kurtosis=max_kurtosis)
    mkt = dki.mean_kurtosis_tensor(params, min_kurtosis=min_kurtosis,
                                   max_kurtosis=max_kurtosis)
    kfa = kurtosis_fractional_anisotropy_test(params)

    return mk, ak, rk, mkt, kfa
    
def kurtosis_fractional_anisotropy_test(dki_params):
    r"""Compute the anisotropy of the kurtosis tensor (KFA).

    See :footcite:p:`Glenn2015` and :footcite:p:`NetoHenriques2021` for further
    details about the method.

    Parameters
    ----------
    dki_params : ndarray (x, y, z, 27), (n, 27) or (27,)
        All parameters estimated from the diffusion kurtosis model.
        Parameters are ordered as follows:
            1) Three diffusion tensor's eigenvalues
            2) Three lines of the eigenvector matrix each containing the first,
                second and third coordinates of the eigenvector
            3) Fifteen elements of the kurtosis tensor, in the order
               Wxxxx, Wyyyy, Wzzzz, Wxxxy, Wxxxz, Wxyyy, Wyyyz, Wxzzz, Wyzzz,
               Wxxyy, Wxxzz, Wyyzz, Wxxyz, Wxyyz, Wxyzz

    Returns
    -------
    kfa : ndarray or float
        Kurtosis fractional anisotropy, one value per parameter vector (a
        scalar for 1-D input). Where the kurtosis tensor is identically zero
        the ratio would be 0/0; KFA is set to 0 there instead of NaN.

    Notes
    -----
    The KFA is defined as :footcite:p:`Glenn2015`:

    .. math::

         KFA \equiv
         \frac{||\mathbf{W} - MKT \mathbf{I}^{(4)}||_F}{||\mathbf{W}||_F}

    where $W$ is the kurtosis tensor, MKT the kurtosis tensor mean, $I^{(4)}$ is
    the fully symmetric rank 4 isotropic tensor and $||...||_F$ is the tensor's
    Frobenius norm :footcite:p:`Glenn2015`.

    References
    ----------
    .. footbibliography::

    """
    Wxxxx = dki_params[..., 12]
    Wyyyy = dki_params[..., 13]
    Wzzzz = dki_params[..., 14]
    Wxxxy = dki_params[..., 15]
    Wxxxz = dki_params[..., 16]
    Wxyyy = dki_params[..., 17]
    Wyyyz = dki_params[..., 18]
    Wxzzz = dki_params[..., 19]
    Wyzzz = dki_params[..., 20]
    Wxxyy = dki_params[..., 21]
    Wxxzz = dki_params[..., 22]
    Wyyzz = dki_params[..., 23]
    Wxxyz = dki_params[..., 24]
    Wxyyz = dki_params[..., 25]
    Wxyzz = dki_params[..., 26]

    # Mean of the kurtosis tensor (MKT).
    W = 1.0 / 5.0 * (Wxxxx + Wyyyy + Wzzzz + 2 * Wxxyy + 2 * Wxxzz + 2 * Wyyzz)
    # Numerator: squared Frobenius norm of W - MKT * I^(4); the integer
    # weights count the index permutations of each unique element.
    A = (
        (Wxxxx - W) ** 2
        + (Wyyyy - W) ** 2
        + (Wzzzz - W) ** 2
        + 4 * (Wxxxy**2 + Wxxxz**2 + Wxyyy**2 + Wyyyz**2 + Wxzzz**2 + Wyzzz**2)
        + 6 * ((Wxxyy - W / 3) ** 2 + (Wxxzz - W / 3) ** 2 + (Wyyzz - W / 3) ** 2)
        + 12 * (Wxxyz**2 + Wxyyz**2 + Wxyzz**2)
    )
    # Denominator: squared Frobenius norm of W.
    B = (
        Wxxxx**2
        + Wyyyy**2
        + Wzzzz**2
        + 4 * (Wxxxy**2 + Wxxxz**2 + Wxyyy**2 + Wyyyz**2 + Wxzzz**2 + Wyzzz**2)
        + 6 * (Wxxyy**2 + Wxxzz**2 + Wyyzz**2)
        + 12 * (Wxxyz**2 + Wxyyz**2 + Wxyzz**2)
    )

    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # B == 0 only for an all-zero kurtosis tensor (e.g. background voxels):
    # leave KFA at 0 there instead of producing 0/0 = NaN. NaNs in the input
    # still propagate because NaN != 0.
    ratio = np.zeros(B.shape)
    np.divide(A, B, out=ratio, where=B != 0)
    KFA = np.sqrt(ratio)

    # [()] turns a 0-d result back into a NumPy scalar; arrays pass through.
    return KFA[()]

In [None]:
def CustomSimulator(Mat,gtab,S0,snr=None):
    """Simulate a single-tensor DTI signal for the 3x3 tensor *Mat*.

    The tensor is eigendecomposed and passed to dipy's ``single_tensor``.
    If *snr* is given, Rician noise with sigma = S0 / snr is added via
    ``AddNoise``; otherwise the noise-free signal is returned.
    """
    w, v = np.linalg.eigh(Mat)
    clean = single_tensor(gtab, S0=S0, evals=w, evecs=v)
    return clean if snr is None else AddNoise(clean, S0, snr)

def Simulator(gtab,S0,params,SNR):
    """Simulate a DTI signal from ComputeDTI parameters *params*.

    *SNR* is forwarded to ``CustomSimulator`` (None means noise-free).
    """
    return CustomSimulator(ComputeDTI(params), gtab, S0, SNR)


def GenRicciNoise(signal,S0,snr):
    """Corrupt *signal* with Rician noise of standard deviation ``S0 / snr``.

    Two independent Gaussian fields shaped like *signal* are drawn from
    NumPy's global RNG (the first added to the signal, the second as the
    quadrature channel), and the magnitude of the result is returned.
    """
    noise_sd = S0 / snr
    in_phase = signal + np.random.normal(0, noise_sd, size=signal.shape)
    quadrature = np.random.normal(0, noise_sd, size=signal.shape)
    return np.sqrt(in_phase ** 2 + quadrature ** 2)


def AddNoise(signal,S0,snr):
    """Return *signal* with Rician noise (sigma = S0 / snr) added."""
    return GenRicciNoise(signal, S0, snr)

def CustomDKISimulator(dt,kt,gtab,S0,snr=None):
    """Simulate a DKI signal from a diffusion tensor and kurtosis elements.

    *dt* may be a 3x3 tensor or its six unique elements (expanded with
    ``vals_to_mat``). If *snr* is given, Rician noise is added via AddNoise.
    """
    tensor = vals_to_mat(dt) if dt.ndim == 1 else dt
    w, v = np.linalg.eigh(tensor)
    dki_params = np.concatenate([w, np.hstack(v), kt])
    clean = dki.dki_prediction(dki_params, gtab, S0)
    if snr is None:
        return clean
    return AddNoise(clean, S0, snr)

def CustomDKISimulator2(params,kt,gtab,S0,snr=None):
    """Simulate a DKI signal from Cholesky-style DT parameters.

    Parameters
    ----------
    params : array_like
        Diffusion-tensor parameters in the ComputeDTI parameterisation
        (only the first six entries are read).
    kt : array_like (15,)
        Kurtosis-tensor elements.
    gtab : dipy GradientTable
    S0 : float
        Non-diffusion-weighted signal.
    snr : float or None
        If given, Rician noise with sigma = S0 / snr is added via AddNoise.

    Returns
    -------
    ndarray
        Predicted signal. It may be non-finite or very large for extreme
        parameters; callers should guard against that.
    """
    dt = ComputeDTI(params)
    evals, evecs = np.linalg.eigh(dt)
    combined_set = np.concatenate([evals, np.hstack(evecs), kt])
    signal = dki.dki_prediction(combined_set, gtab, S0)
    if snr is None:
        return signal
    return AddNoise(signal, S0, snr)

In [None]:
# Locations of trained networks and of saved figures.
network_path = '../Networks/'
image_path   = './Images/'
# exist_ok=True replaces the exists()/makedirs() pair: no check-then-create
# race and no error if the folder is already there.
os.makedirs(image_path, exist_ok=True)
# Noise settings; None presumably means noise-free (verify against use sites).
NoiseLevels = [None,20,10,5,2]

TrainingSamples = 50000
InferSamples    = 500

lower_abs,upper_abs = -0.07,0.07
lower_rest,upper_rest = -0.015,0.015
lower_S0 = 25
upper_S0 = 2000
Save = True

# Plot colours (RGB triples scaled to [0, 1]).
TrueCol  = 'k'
NoisyCol = 'k'
NLLSFit   = np.array([225,190,106])/255
SBIFit   = np.array([64,176,166])/255
WLSFit = np.array([140, 100, 200]) / 255  # muted violet
MLEFit = np.array([210, 80, 140]) / 255
BayFit = np.array([100, 120, 220]) / 255 

Errors_name = ['MD comparison','FA comparison','eig. comparison','Frobenius','Signal comparison','Correlation','Signal comparison','Correlation2']

## DKI Fits

In [None]:

# Fit DKI on patient 1 (axial middle slice, b=1k and b=3k shells combined)
# to collect per-voxel DT/KT estimates and their scalar metrics.
i = 1
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined table: 1k volumes first, then 3k (same order as the data
# concatenation below).
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=False, dilate=2)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

# Crop both shells to the bounding box of the brain mask.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)
brain_box = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))
maskdata  = maskdata[brain_box]
maskdata3 = data3[brain_box]

# Middle axial slice of the CROPPED volume (as in the MLE cells). Using the
# uncropped depth here would pick a slice shifted by min_coords[2].
axial_middle = maskdata.shape[2] // 2

# Flatten the slice to (voxels, volumes). Keep voxels with some signal in the
# 1k shell and no negative samples.
n_vols_1k = maskdata.shape[-1]
TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
FlatTD = TestData.reshape(-1, TestData.shape[-1])
FlatTD = FlatTD[FlatTD[:, :n_vols_1k].sum(axis=-1) > 0]
FlatTD = FlatTD[~(FlatTD < 0).any(axis=-1)]

dkimodel = dki.DiffusionKurtosisModel(gtabExt)
tenfit = dkimodel.fit(FlatTD)
DKIHCP = tenfit.kt
DTIHCP = tenfit.lower_triangular()
DKIFull = np.array(DKIHCP)
DTIFull = np.array(DTIHCP)

# Discard voxels with implausible kurtosis: any |W| >= 10, then any W <= -3/7.
bounded = (np.abs(DKIFull) < 10).all(axis=1)
DTIFilt1 = DTIFull[bounded]
DKIFilt1 = DKIFull[bounded]
above_min = (DKIFilt1 > -3/7).all(axis=1)
DTIFilt = DTIFilt1[above_min]
DKIFilt = DKIFilt1[above_min]

TrueMets = []
FA       = []
for (dt,kt) in tqdm.tqdm(zip(DTIFilt,DKIFilt), total=len(DTIFilt)):
    TrueMets.append(DKIMetrics(dt,kt))
    # One eigendecomposition per voxel; FracAni takes the eigenvalues and MD.
    evals = np.linalg.eigh(vals_to_mat(dt))[0]
    FA.append(FracAni(evals, np.mean(evals)))
TrueMets = np.array(TrueMets)
TrueFA = np.array(FA)

In [None]:
# Full fit
DT1_full,DT2_full = FitDT(DTIFilt,1)
x4_full,R1_full,x2_full,R2_full = FitKT(DKIFilt,1)

# LowFA Fit
DT1_lfa,DT2_lfa = FitDT(DTIFilt[TrueMets[:,-1]<0.3,:],1)
x4_lfa,R1_lfa,x2_lfa,R2_lfa = FitKT(DKIFilt[TrueMets[:,-1]<0.3,:],1)

# HighFA Fit
DT1_hfa,DT2_hfa = FitDT(DTIFilt[TrueMets[:,-1]>0.7,:],1)
x4_hfa,R1_hfa,x2_hfa,R2_hfa = FitKT(DKIFilt[TrueMets[:,-1]>0.7,:],1)

# UltraLowFA Fit
DT1_ulfa,DT2_ulfa = FitDT(DTIFilt[TrueMets[:,-1]<0.1,:],1)
x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa = FitKT(DKIFilt[TrueMets[:,-1]<0.1,:],1)

# HigherAK Fit
DT1_hak,DT2_hak = FitDT(DTIFilt[TrueMets[:,1]>0.9,:],1)
x4_hak,R1_hak,x2_hak,R2_hak = FitKT(DKIFilt[TrueMets[:,1]>0.9,:],1)

In [None]:
se = 14
torch.manual_seed(se)
np.random.seed(se)
j = 1
vL = torch.tensor([0.2*j])
vS = torch.tensor([0.01*j])

# Pick (reproducibly, via the seed above) one fitted prior family and draw a
# single DT/KT sample from it; the MLE cells use it as their starting guess.
kk = np.random.randint(0,4)
_prior_families = {
    0: ([DT1_full, DT2_full], [x4_full, R1_full, x2_full, R2_full]),
    1: ([DT1_lfa, DT2_lfa], [x4_lfa, R1_lfa, x2_lfa, R2_lfa]),
    2: ([DT1_hfa, DT2_hfa], [x4_hfa, R1_hfa, x2_hfa, R2_hfa]),
    3: ([DT1_ulfa, DT2_ulfa], [x4_ulfa, R1_ulfa, x2_ulfa, R2_ulfa]),
}
_dt_fits, _kt_fits = _prior_families[kk]
DT_guess, KT_guess = GenDTKT(_dt_fits, _kt_fits, se, 1)


In [None]:
# Output folder for the saved MLE maps. It can be overridden with the
# SBIDTI_DATFOLDER environment variable. A trailing separator is guaranteed
# because file names are appended with plain string concatenation.
DatFolder = os.path.join(os.environ.get('SBIDTI_DATFOLDER', '/Users/maximilianeggl/Dropbox/PostDoc/Silvia/SBIDTIPaper2/Code/SavedDat/'), '')

# MLE

In [None]:
def invertComputeDTI(A, scale_factor=1e-3, min_eigenvalue=1e-5):
    """Invert ``ComputeDTI``: map a diffusion tensor to its six parameters.

    ``ComputeDTI`` builds ``A = L @ L.T`` from a lower-triangular ``L`` with
    diagonal ``exp(p[0:3]) * scale_factor`` and off-diagonals
    ``L10, L20, L21 = p[3], p[4], p[5]``. This recovers ``p`` from the
    Cholesky factor of ``A``.

    Parameters
    ----------
    A : (3, 3) array_like
        Diffusion tensor; it is symmetrised before factorisation.
    scale_factor : float
        Must match the value used by ComputeDTI.
    min_eigenvalue : float
        Used only when ``A`` is not positive definite. Its eigenvalues are
        then raised to this floor so the Cholesky factor exists, and a
        RuntimeWarning is issued.

    Returns
    -------
    ndarray (6,)
        ``[log(L00/s), log(L11/s), log(L22/s), L10, L20, L21]``.
    """
    A = np.asarray(A, dtype=float)
    # 1) Enforce exact symmetry.
    A = (A + A.T) / 2.0

    # 2) Cholesky -> lower-triangular L with positive diagonal.
    try:
        L = np.linalg.cholesky(A)
    except np.linalg.LinAlgError:
        import warnings
        warnings.warn(
            "invertComputeDTI: tensor is not positive definite; clipping its "
            "eigenvalues to %g before inversion." % min_eigenvalue,
            RuntimeWarning,
            stacklevel=2,
        )
        evals, evecs = np.linalg.eigh(A)
        A = (evecs * np.maximum(evals, min_eigenvalue)) @ evecs.T
        A = (A + A.T) / 2.0
        L = np.linalg.cholesky(A)

    # 3) Parameters in the order fill_lower_diag/ComputeDTI expect.
    return np.array([
        np.log(L[0, 0] / scale_factor),  # p0: diagonal L00 (log-scaled)
        np.log(L[1, 1] / scale_factor),  # p1: diagonal L11 (log-scaled)
        np.log(L[2, 2] / scale_factor),  # p2: diagonal L22 (log-scaled)
        L[1, 0],                         # p3: L10
        L[2, 0],                         # p4: L20
        L[2, 1],                         # p5: L21
    ])
def ComputeDTI(params, scale_factor=1e-3):
    """Build a diffusion tensor from six unconstrained parameters.

    A lower-triangular factor is assembled with diagonal
    ``exp(params[0:3]) * scale_factor`` (always positive) and off-diagonals
    ``L10, L20, L21 = params[3], params[4], params[5]``. The tensor
    ``L @ L.T`` is then symmetric positive definite for finite parameters.
    Only the first six entries of *params* are read.
    """
    factor = fill_lower_diag(params)
    diag = np.arange(3)
    factor[diag, diag] = np.exp(factor[diag, diag]) * scale_factor
    return factor @ factor.T
    
def rician_nll_DKI(params,S_obs,gtab,S0):
    """Rician negative log-likelihood of observed signals under a DKI model.

    Parameters
    ----------
    params : ndarray
        ``[dt (6, ComputeDTI parameterisation), kt (params[6:-1]), log_sigma]``.
        This is 22 entries for the parameterisation used with ``bounds``.
    S_obs : ndarray
        Observed signal per gradient direction.
    gtab : dipy GradientTable
    S0 : float
        Non-diffusion-weighted signal used by the forward model (held fixed).

    Returns
    -------
    float
        Summed NLL. Returns ``np.inf`` when the forward model yields
        non-finite values, so an optimiser rejects the point instead of the
        run halting.
    """
    dt        = params[:6]
    kt        = params[6:-1]
    log_sigma = params[-1]

    S_model = CustomDKISimulator2(dt, kt, gtab, S0)

    sigma2 = np.exp(2 * log_sigma)  # ensures positivity
    # Clip to avoid log(0) and overflow in the products below.
    S_obs_clipped = np.clip(S_obs, 1e-10, None)
    S_model_clipped = np.clip(S_model, 1e-10, 1e10)

    bessel_arg = (S_obs_clipped * S_model_clipped) / sigma2

    # log I0(z) = log(i0e(z)) + |z|: the scaled Bessel avoids overflow.
    log_bessel = np.log(i0e(bessel_arg)) + np.abs(bessel_arg)

    nll = (
        np.log(sigma2)
        - np.log(S_obs_clipped)
        + (S_obs_clipped**2 + S_model_clipped**2) / (2 * sigma2)
        - log_bessel
    )
    total = np.sum(nll)
    if not np.isfinite(total):
        return np.inf
    return total

In [None]:
def residuals(params,gtab,y):
    """Residual vector ``y - model`` for the DKI least-squares fit.

    *params* is ``[dt (6, ComputeDTI parameterisation), kt, log_S0]``: the
    kurtosis elements are ``params[6:-1]`` and the last entry is log(S0).
    """
    model = CustomDKISimulator2(params[:6], params[6:-1], gtab, np.exp(params[-1]))
    return y - model

In [None]:
# Load, reslice and brain-crop patient `kk` (both shells).
kk = 3
_pat = f'./HCP_data/Pat{kk}/'
fdwi, bvalloc, bvecloc = _pat + 'diff_1k.nii.gz', _pat + 'bvals_1k.txt', _pat + 'bvecs_1k.txt'
fdwi3, bvalloc3, bvecloc3 = _pat + 'diff_3k.nii.gz', _pat + 'bvals_3k.txt', _pat + 'bvecs_3k.txt'

bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# 1k volumes first, then 3k.
gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)), np.vstack((bvecsHCP, bvecsHCP3)))

# Shared median_otsu settings for the cropped and autocropped masks.
_otsu_kwargs = dict(vol_idx=range(10, 50), median_radius=4, numpass=1, dilate=2)

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))
maskdata, mask = median_otsu(data, autocrop=False, **_otsu_kwargs)
_, mask2 = median_otsu(data, autocrop=True, **_otsu_kwargs)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5, 1.5, 1.5), (2.5, 2.5, 2.5))

# Bounding box of the brain mask, applied to both shells.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)
_brain_box = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

maskdata = maskdata[_brain_box]
axial_middle = maskdata.shape[2] // 2
maskdata3 = data3[_brain_box]

TestData = np.concatenate([maskdata[:, :, axial_middle, :], maskdata3[:, :, axial_middle, :]], axis=-1)
TestData4D = np.concatenate([maskdata, maskdata3], axis=-1)

In [None]:
# Box constraints for the 22-entry MLE vector
# [dt (6, ComputeDTI parameterisation), kt (15), log_sigma].
# Each entry: (lower, upper, number of consecutive parameters).
_bound_groups = [
    (-0.0, 5, 3),                # log-scaled diagonal of the DT factor
    (-0.015, 0.015, 3),          # off-diagonal DT factor entries
    (0, 5, 3),                   # KT elements 0-2
    (-1, 1, 9),                  # KT elements 3-11
    (-0.25, 0.25, 3),            # KT elements 12-14
    (np.log(2), np.log(20), 1),  # log noise sigma
]
lower_bounds = [lo for lo, hi, n in _bound_groups for _ in range(n)]
upper_bounds = [hi for lo, hi, n in _bound_groups for _ in range(n)]
bounds = list(zip(lower_bounds, upper_bounds))

In [None]:
for kk in tqdm.tqdm(range(1,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'
    
    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'
    
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    
    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)
    
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))
    
    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)
    
    
    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Get the indices of True values
    true_indices = np.argwhere(mask)
    
    # Determine the minimum and maximum indices along each dimension
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    
    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    DT_guess_new = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_guess_new = KT_guess.squeeze()
    for i in tqdm.tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            if(np.sum(TestData4D[i,j,axial_middle,:69],axis=-1) == 0):
                pass
            else:
                tObs = TestData4D[i,j,axial_middle,:]
                LS_x = least_squares(residuals, x0=np.hstack([DT_guess_new,KT_guess_new,np.log(tObs[gtabExt.bvals==0].mean())]), args=(gtabExt, tObs)).x
                res = minimize(
                        rician_nll_DKI,
                        np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                        args=(tObs, gtabExt, np.exp(LS_x[-1])),
                        options={'disp':0,'maxfun': 100000,   # ← raise the cap on f+g evaluations
                                'maxiter': 20000,   # ← (optional) raise the cap on iterations   # ← raise the cap on f+g evaluations   # ← raise the cap on f+g evaluations
                                'gtol': 1e-6,
                                'ftol': 1e-6,},
                            bounds=bounds
                    )
                NoiseEst[i,j] = res.x
                #DT_guess_new,KT_guess_new = LS_x[:6],LS_x[6:21]
    NoiseEstInv = np.zeros_like(NoiseEst)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):    
            NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
    NoiseEst2 =  np.zeros_like(NoiseEst)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):    
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
    
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]): 
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
    np.save(DatFolder+'Full_MK_MLE_'+str(kk),np.array(MK_SBIFull,dtype=object))
    np.save(DatFolder+'Full_AK_MLE_'+str(kk),np.array(AK_SBIFull,dtype=object))
    np.save(DatFolder+'Full_RK_MLE_'+str(kk),np.array(RK_SBIFull,dtype=object))

In [None]:
# Re-run of the per-voxel DKI MLE for patients 23-32, with error handling.
# Per voxel: a least-squares fit (`residuals`) is followed by a Rician MLE
# (`rician_nll_DKI`) started from it. If that MLE raises, it is retried from
# the prior-sampled DT/KT guess. If the least-squares step raises, the voxel
# is stored as all zeros. MK/AK/RK maps are saved to DatFolder and shown.
#
# L-BFGS-B settings shared by both MLE attempts (a fresh copy per call).
_mle_options = {
    'disp': 0,
    'maxfun': 100000,  # raise the cap on function evaluations
    'maxiter': 20000,  # raise the cap on iterations
    'gtol': 1e-6,
    'ftol': 1e-6,
}
for kk in tqdm.tqdm(range(23,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Combined table: 1k volumes first, then 3k (matches TestData4D below).
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

    # 22 fitted values per voxel: 6 DT parameters (ComputeDTI form),
    # 15 KT elements, log_sigma.
    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    DT_guess_new = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_guess_new = KT_guess.squeeze()
    for i in tqdm.tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # Background: no signal in the first 69 volumes -> left as zeros.
            if(np.sum(TestData4D[i,j,axial_middle,:69],axis=-1) == 0):
                continue
            tObs = TestData4D[i,j,axial_middle,:]
            avg_s0 = tObs[gtabExt.bvals==0].mean()
            if(avg_s0<0):
                # Shift the voxel so its minimum is zero, making log(S0)
                # below defined where possible.
                tObs = tObs + np.abs(tObs.min())
            try:
                LS_x = least_squares(residuals, x0=np.hstack([DT_guess_new,KT_guess_new,np.log(tObs[gtabExt.bvals==0].mean())]), args=(gtabExt, tObs)).x
                # S0 is held at the least-squares estimate during the MLE.
                mle_args = (tObs, gtabExt, np.exp(LS_x[-1]))
                try:
                    res = minimize(
                        rician_nll_DKI,
                        np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                        args=mle_args,
                        options=dict(_mle_options),
                        bounds=bounds,
                    )
                except Exception:
                    # Retry from the prior-sampled guess. The last entry is
                    # log(sigma), so start it at sigma = 10 as above. log of
                    # the mean b=0 signal is an S0 scale, typically far above
                    # the log(20) upper bound.
                    res = minimize(
                        rician_nll_DKI,
                        np.hstack([DT_guess_new,KT_guess_new,np.log(10)]),
                        args=mle_args,
                        options=dict(_mle_options),
                        bounds=bounds,
                    )
                NoiseEst[i,j] = res.x
            except Exception:
                # Narrow catch: KeyboardInterrupt/SystemExit still stop the run.
                NoiseEst[i,j] = 0

    # Convert DT parameters back to tensor elements (mat_to_vals order);
    # KT and log_sigma are copied unchanged.
    NoiseEstInv = np.zeros_like(NoiseEst)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
    # Floor the tensor eigenvalues (clip_negative_eigenvalues).
    NoiseEst2 =  np.zeros_like(NoiseEst)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])

    # Despite the *_SBIFull names these hold the MLE estimates (saved as *_MLE_*).
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
    # Saved as object arrays, so np.load needs allow_pickle=True.
    np.save(DatFolder+'Full_MK_MLE_'+str(kk),np.array(MK_SBIFull,dtype=object))
    np.save(DatFolder+'Full_AK_MLE_'+str(kk),np.array(AK_SBIFull,dtype=object))
    np.save(DatFolder+'Full_RK_MLE_'+str(kk),np.array(RK_SBIFull,dtype=object))

    fig,ax = plt.subplots(1,3)
    ax[0].imshow(MK_SBIFull,vmin=0,vmax=1)
    ax[1].imshow(AK_SBIFull,vmin=0,vmax=1)
    ax[2].imshow(RK_SBIFull,vmin=0,vmax=1)
    plt.show()

#### Min: minimal acquisition (greedy subset of well-spread gradient directions)

In [None]:
# Minimal acquisition scheme for patient 1: greedily select well-spread
# gradient directions from each shell (farthest-point sampling on the
# Euclidean distance between b-vectors), then build the reduced table
# gtabHCP7.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# b=1k shell: start from volume 1 and greedily add 5 more (6 directions).
# NOTE(review): unlike the 3k shell below, candidates are not restricted to
# b > 0 here, so the vector of a b=0 volume can also be picked - confirm.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))

    # Distance from each candidate to its nearest already-selected vector
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Take the candidate farthest from the current selection
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

# Prepend volume 0 (presumably a b=0 volume - verify): 7 volumes in total.
selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Gradient table of the 3k shell, built from the 3k b-values/b-vectors.
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# b=3k shell: only diffusion-weighted volumes are candidates. Start from the
# first one and greedily add 14 more (15 directions).
selected_indices = [0]

temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
for _ in range(14):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))

    # Distance from each candidate to its nearest already-selected vector
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Take the candidate farthest from the current selection
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

# Reduced table: the 7 selected 1k volumes followed by the 15 selected 3k ones.
gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

true_indx = []
for b in bvecsHCP7_3:
    true_indx.append(np.linalg.norm(b-bvecsHCP3,axis=1).argmin())
selected_indices7 = selected_indices7+[t+69 for t in true_indx]



In [None]:
# Rician-MLE DKI fits on the "Min" acquisition subset (selected_indices7) for
# every patient. The mid-axial slice is fitted voxel by voxel: a least-squares
# fit provides the start point for minimize(rician_nll_DKI, ...). MK/AK/RK maps
# are saved to DatFolder with the "Min_*_MLE_" prefix.
# Names from earlier cells: selected_indices7, residuals, rician_nll_DKI,
# bounds, minimize, DT_guess, KT_guess, invertComputeDTI, ComputeDTI,
# vals_to_mat, mat_to_vals, clip_negative_eigenvalues, DKIMetrics, DatFolder,
# reslice, median_otsu.
for kk in tqdm.tqdm(range(1,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Concatenated [b=1000, b=3000] scheme; the Min subset indexes into it.
    # NOTE(review): selected_indices7 was derived from patient 1 and is reused
    # for every patient -- assumes all patients share the same gradient table.
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))
    gtabHCP7 = gradient_table(gtabExt.bvals[selected_indices7],gtabExt.bvecs[selected_indices7])

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

    # Crop both shells to the bounding box of the brain mask.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    crop = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

    maskdata  = maskdata[crop]
    axial_middle = maskdata.shape[2] // 2
    maskdata3 = data3[crop]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    # The start points are identical for every voxel: compute them once.
    DT_guess_new = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_guess_new = KT_guess.squeeze()
    n_1k = maskdata.shape[-1]        # b=1000 volumes occupy TestData4D[..., :n_1k]
    is_b0 = gtabHCP7.bvals == 0      # b0 mask with the same length as tObs
    mle_options = {'disp': 0,
                   'maxfun': 100000,  # raise the cap on f+g evaluations
                   'maxiter': 20000,  # raise the cap on iterations
                   'gtol': 1e-6,
                   'ftol': 1e-6}
    for i in tqdm.tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # maskdata is presumably zeroed outside the brain mask, so a zero
            # b=1000 sum marks background: leave its parameters at 0.
            if np.sum(TestData4D[i,j,axial_middle,:n_1k], axis=-1) == 0:
                continue
            tObs = TestData4D[i,j,axial_middle,selected_indices7]
            avg_s0 = tObs[is_b0].mean()
            if avg_s0 < 0:
                # Shift the whole signal so it is non-negative.
                tObs = tObs + np.abs(tObs.min())
            # `except Exception` instead of a bare `except`: a bare except also
            # catches KeyboardInterrupt, making the long run impossible to stop.
            try:
                LS_x = least_squares(residuals,
                                     x0=np.hstack([DT_guess_new, KT_guess_new, np.log(100)]),
                                     args=(gtabHCP7, tObs)).x
                try:
                    res = minimize(rician_nll_DKI,
                                   np.hstack([LS_x[:6], LS_x[6:21], np.log(10)]),
                                   args=(tObs, gtabHCP7, np.exp(LS_x[-1])),
                                   options=dict(mle_options),
                                   bounds=bounds)
                except Exception:
                    # Fallback: restart from the global guess. The b0 mask must
                    # come from gtabHCP7 (same length as tObs); the previous
                    # gtabExt mask raised IndexError, so this branch never ran.
                    res = minimize(rician_nll_DKI,
                                   np.hstack([DT_guess_new, KT_guess_new, np.log(tObs[is_b0].mean())]),
                                   args=(tObs, gtabHCP7, np.exp(LS_x[-1])),
                                   options=dict(mle_options),
                                   bounds=bounds)
                NoiseEst[i,j] = res.x
            except Exception:
                NoiseEst[i,j] = 0

    # Per voxel: decode the DT parameters with ComputeDTI, clip negative DT
    # eigenvalues, then compute the kurtosis metrics (MK, AK, RK).
    NoiseEstInv = np.zeros_like(NoiseEst)
    NoiseEst2 =  np.zeros_like(NoiseEst)
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
    # Saved as object arrays to match the other cells' output format.
    np.save(DatFolder+'Min_MK_MLE_'+str(kk),np.array(MK_SBIFull,dtype=object))
    np.save(DatFolder+'Min_AK_MLE_'+str(kk),np.array(AK_SBIFull,dtype=object))
    np.save(DatFolder+'Min_RK_MLE_'+str(kk),np.array(RK_SBIFull,dtype=object))

    fig,ax = plt.subplots(1,3)
    ax[0].imshow(MK_SBIFull,vmin=0,vmax=1)
    ax[1].imshow(AK_SBIFull,vmin=0,vmax=1)
    ax[2].imshow(RK_SBIFull,vmin=0,vmax=1)
    plt.show()

# Variational Bayes

In [None]:
def invertComputeDTI_exp(A, scale_factor=1e-3):
    """Map a symmetric positive-definite 3x3 tensor to 6 log-Cholesky parameters.

    ``A`` is symmetrised and factorised as ``A = L @ L.T`` (lower Cholesky
    factor ``L``). The returned vector is ordered

        [log(L00/s), log(L11/s), log(L22/s), L10, L20, L21]

    with ``s = scale_factor``. The diagonal is log-transformed, so it can take
    any real value.

    Presumably the inverse of ``ComputeDTI_exp``. That holds only if
    ``fill_lower_diag`` (defined elsewhere) consumes the three diagonal
    entries first and then (1,0), (2,0), (2,1) -- TODO confirm.

    Raises ``numpy.linalg.LinAlgError`` when ``A`` is not positive definite.
    """
    chol = np.linalg.cholesky((A + A.T) / 2.0)
    log_diag = np.log(np.diag(chol) / scale_factor)
    below_diag = chol[[1, 2, 2], [0, 0, 1]]
    return np.concatenate([log_diag, below_diag])
def ComputeDTI_exp(params, scale_factor=1e-3):
    """Build a 3x3 diffusion tensor from 6 log-Cholesky parameters.

    ``fill_lower_diag`` (defined elsewhere) arranges ``params`` into a square
    matrix, presumably lower-triangular. Its diagonal is replaced by
    ``exp(diag) * scale_factor``, which is strictly positive, and the tensor
    is returned as ``L @ L.T``.
    """
    chol = fill_lower_diag(params)
    on_diag = np.diag_indices_from(chol)
    chol[on_diag] = np.exp(chol[on_diag]) * scale_factor
    return chol @ chol.T

def CustomDKISimulator_exp(params,kt,gtab,S0,snr=None):
    """Predict the DKI signal from log-Cholesky DT parameters and KT values.

    The tensor from ``ComputeDTI_exp(params)`` is eigendecomposed. Eigenvalues,
    row-flattened eigenvectors and ``kt`` are concatenated into the parameter
    vector passed to ``dki.dki_prediction`` together with ``gtab`` and ``S0``.
    If ``snr`` is given, noise is added via ``AddNoise(signal, S0, snr)``.
    """
    eigvals, eigvecs = np.linalg.eigh(ComputeDTI_exp(params))
    dki_params = np.concatenate([eigvals, eigvecs.ravel(), kt])
    signal = dki.dki_prediction(dki_params, gtab, S0)
    return signal if snr is None else AddNoise(signal, S0, snr)

def residuals(params,gtab,y):
    """Residual vector ``y - model(params)`` for ``least_squares``.

    ``params`` is [6 DT (log-Cholesky), KT..., log S0]; the model is
    ``CustomDKISimulator_exp``. NOTE(review): this definition is shadowed by a
    later ``residuals`` that uses ``CustomDKISimulator2``.
    """
    predicted = CustomDKISimulator_exp(params[:6], params[6:-1], gtab, np.exp(params[-1]))
    return y - predicted

In [None]:
def f(theta, gtab):
    """DKI signal model for one voxel.

    ``theta`` = [6 DT (log-Cholesky), KT..., log S0]. Returns the output of
    ``CustomDKISimulator_exp``, one value per gradient.
    """
    dt_part, kt_part, log_s0 = theta[:6], theta[6:-1], theta[-1]
    return CustomDKISimulator_exp(dt_part, kt_part, gtab, np.exp(log_s0))

# Finite-difference Jacobian (6 dt + K kt + 1 S0 parameters)
def jacobian(theta, gtab, eps=1e-5, model=None):
    """Forward-difference Jacobian of a signal model at ``theta``.

    Parameters
    ----------
    theta : array_like, shape (P,)
        Parameter vector (for the default model: 6 DT + K KT + 1 log-S0).
    gtab :
        Passed through unchanged to ``model``.
    eps : float
        Absolute step added to each parameter in turn.
    model : callable, optional
        ``model(theta, gtab) -> (N,) array``. Defaults to the DKI model ``f``.

    Returns
    -------
    J : ndarray, shape (N, P)
        ``J[:, i]`` approximates d model / d theta_i.
    f0 : ndarray, shape (N,)
        The model evaluated at ``theta``.
    """
    if model is None:
        model = f
    # Work in floating point: with an integer theta the "+ eps" step would be
    # truncated back to the same value, yielding an all-zero column.
    theta = np.asarray(theta, dtype=float)
    f0 = model(theta, gtab)
    J = np.zeros((len(f0), len(theta)))
    for i in range(len(theta)):
        t_eps = theta.copy()
        t_eps[i] += eps
        J[:, i] = (model(t_eps, gtab) - f0) / eps
    return J, f0
# ------------------------------------------------------------

def vb_gauss_one_voxel(y, gtab, theta0,
                       mu0=None, V0=None,  # Gaussian prior
                       a0=1e-3, b0=1e-3,   # Gamma prior
                       max_iter=100, tol=1e-10):
    """Mean-field variational Bayes fit of one voxel under Gaussian noise.

    q(theta) is Gaussian N(m, S); q(noise precision) is Gamma with shape ``a``
    and rate ``b`` (expected precision ``a / b``). Each iteration linearises
    the model with ``jacobian`` around the current mean and applies
    Gauss-Newton-style updates.

    Parameters
    ----------
    y : observed signal (one value per gradient).
    gtab : gradient table forwarded to ``jacobian``.
    theta0 : initial variational mean.
    mu0, V0 : prior mean and covariance (defaults: zeros and 100 * I).
    a0, b0 : Gamma prior shape and rate.
    max_iter, tol : stop after ``max_iter`` steps or when the mean moves by
        less than ``tol`` (Euclidean norm).

    Returns
    -------
    (m, S, a / b) : posterior mean, posterior covariance, expected precision.
    """
    n_params = len(theta0)
    if mu0 is None:
        mu0 = np.zeros(n_params)
    if V0 is None:
        V0 = np.eye(n_params) * 1e2          # very vague prior
    prior_prec = inv(V0)

    mean = theta0.copy()
    cov = V0.copy()
    shape = a0 + len(y)/2.
    rate = b0 + 1.0                          # overwritten on the first pass
    for _ in range(max_iter):
        lam = shape / rate                   # E[1/sigma^2]

        # Linearise the model around the current mean.
        J, f_m = jacobian(mean, gtab)
        resid = y - f_m

        cov_next = inv(prior_prec + lam * J.T @ J)
        mean_next = mean + cov_next @ (prior_prec @ (mu0 - mean) + lam * J.T @ resid)

        quad = (resid @ resid
                + np.trace(J @ cov_next @ J.T))
        rate_next = b0 + 0.5 * quad

        converged = np.linalg.norm(mean_next - mean) < tol
        mean, cov, rate = mean_next, cov_next, rate_next
        if converged:
            break

    return mean, cov, shape/rate

In [None]:
# Gaussian prior for vb_gauss_one_voxel over 22 parameters
# (6 DT + 15 KT + 1 log-S0): zero mean, diagonal covariance.
# NOTE(review): despite the "SD" name these numbers go straight onto the
# diagonal of V0, which vb_gauss_one_voxel treats as a covariance, i.e. they
# act as variances -- confirm which was intended.
_dt_prior = [1/0.001633, 1/.5e-5, 1/0.001633, 1/.5e-5, 1/ 7.5e-5, 1/0.001633]
_kt_prior = [1e2] * 15
_logS0_prior = [1e2]
priors_SD = _dt_prior + _kt_prior + _logS0_prior
mu0 = np.zeros(len(priors_SD))

V0 = np.diag(priors_SD)

In [None]:
def residuals(params,gtab,y):
    """Residual vector ``y - model(params)`` for ``least_squares``.

    ``params`` is [6 DT, KT..., log S0]; the model is ``CustomDKISimulator2``
    (defined elsewhere). This redefinition replaces the earlier
    ``CustomDKISimulator_exp``-based ``residuals`` for all later cells.
    """
    predicted = CustomDKISimulator2(params[:6], params[6:-1], gtab, np.exp(params[-1]))
    return y - predicted

In [None]:
# Variational-Bayes DKI fits (vb_gauss_one_voxel) on the full b=1000 + b=3000
# data of the mid-axial slice; MK/AK/RK maps are saved to DatFolder with the
# "Full_*_Bay_" prefix.
# NOTE(review): the range starts at 10, presumably to resume an interrupted
# run; use range(1, 33) to process every patient.
# NOTE(review): LS_x comes from `residuals` (last defined via
# CustomDKISimulator2) and the VB mean is decoded with ComputeDTI below, while
# vb_gauss_one_voxel evaluates the model through f -> CustomDKISimulator_exp
# (log-Cholesky DT parameterisation). Confirm they agree on the parameterisation.
for kk in tqdm.tqdm(range(10,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

    # Crop both shells to the bounding box of the brain mask.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    crop = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

    maskdata  = maskdata[crop]
    axial_middle = maskdata.shape[2] // 2
    maskdata3 = data3[crop]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    # The start point is identical for every voxel: compute it once.
    DT_guess_new = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_guess_new = KT_guess.squeeze()
    n_1k = maskdata.shape[-1]        # b=1000 volumes occupy TestData4D[..., :n_1k]
    is_b0 = gtabExt.bvals == 0       # b0 mask over all volumes (same length as tObs)
    for i in tqdm.tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # maskdata is presumably zeroed outside the brain mask, so a zero
            # b=1000 sum marks background: leave its parameters at 0.
            if np.sum(TestData4D[i,j,axial_middle,:n_1k], axis=-1) == 0:
                continue
            tObs = TestData4D[i,j,axial_middle,:]
            avg_s0 = tObs[is_b0].mean()
            if avg_s0 < 0:
                # Shift the whole signal so it is non-negative.
                tObs = tObs + np.abs(tObs.min())
            # `except Exception` instead of a bare `except`: a bare except also
            # catches KeyboardInterrupt, making the long run impossible to stop.
            try:
                x0 = np.hstack([DT_guess_new, KT_guess_new, np.log(tObs[is_b0].mean())])
                LS_x = least_squares(residuals, x0=x0, args=(gtabExt, tObs)).x
                try:
                    m, S, prec = vb_gauss_one_voxel(tObs, gtabExt, LS_x, mu0=mu0, V0=V0)
                except Exception:
                    m = 0
                NoiseEst[i,j] = m
            except Exception:
                NoiseEst[i,j] = 0

    # Per voxel: decode the DT parameters with ComputeDTI, clip negative DT
    # eigenvalues, then compute the kurtosis metrics (MK, AK, RK).
    NoiseEstInv = np.zeros_like(NoiseEst)
    NoiseEst2 =  np.zeros_like(NoiseEst)
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
    # Saved as object arrays to match the other cells' output format.
    np.save(DatFolder+'Full_MK_Bay_'+str(kk),np.array(MK_SBIFull,dtype=object))
    np.save(DatFolder+'Full_AK_Bay_'+str(kk),np.array(AK_SBIFull,dtype=object))
    np.save(DatFolder+'Full_RK_Bay_'+str(kk),np.array(RK_SBIFull,dtype=object))

    fig,ax = plt.subplots(1,3)
    ax[0].imshow(MK_SBIFull,vmin=0,vmax=1)
    ax[1].imshow(AK_SBIFull,vmin=0,vmax=1)
    ax[2].imshow(RK_SBIFull,vmin=0,vmax=1)
    plt.show()

In [None]:
# Build the "Min" acquisition subset from patient 1's gradient scheme:
#   b=1000 shell: volume 0 (assumed b0 -- TODO confirm) + 6 spread-out directions
#   b=3000 shell: 15 spread-out diffusion-weighted directions
# giving 22 volumes, one per fitted parameter (6 DT + 15 KT + 1 S0).
# Outputs used by later cells: selected_indices7 (indices into the concatenated
# b=1000 + b=3000 volume list) and gtabHCP7.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)


def _farthest_point_indices(vectors, start_index, n_extra):
    """Greedy farthest-point selection over the rows of ``vectors``.

    Starting from ``start_index``, repeatedly add the row whose minimum
    Euclidean distance to the already-selected rows is largest (ties go to the
    lowest index). Returns ``n_extra + 1`` indices in selection order.

    NOTE(review): plain Euclidean distance treats g and -g as far apart,
    although the diffusion signal is antipodally symmetric; if the scheme has
    near-antipodal pairs this may pick nearly redundant directions -- confirm.
    """
    dist = squareform(pdist(vectors))
    chosen = [start_index]
    for _ in range(n_extra):
        remaining = np.setdiff1d(np.arange(len(vectors)), chosen)
        nearest = dist[np.ix_(remaining, chosen)].min(axis=1)
        chosen.append(int(remaining[np.argmax(nearest)]))
    return chosen


# b=1000 shell: seed at volume 1, pick 5 more (6 directions), then prepend volume 0.
selected_indices = _farthest_point_indices(bvecsHCP, 1, 5)
selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Built from the b=3000 arrays (previously mistakenly from the b=1000 ones).
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# b=3000 shell: choose among diffusion-weighted volumes only (b > 0),
# seed at the first of them and pick 14 more (15 directions in total).
dw_idx3 = np.flatnonzero(bvalsHCP3 > 0)
temp_bvecs = bvecsHCP3[dw_idx3]
temp_bvals = bvalsHCP3[dw_idx3]
selected_indices = _farthest_point_indices(temp_bvecs, 0, 14)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Map the chosen b=3000 volumes back to their exact index in the 3k file, then
# offset by the number of b=1000 volumes so they index the concatenated
# [b=1000, b=3000] volume list.
true_indx = [int(t) for t in dw_idx3[selected_indices]]
selected_indices7 = selected_indices7+[t+len(bvalsHCP) for t in true_indx]



In [None]:
# Variational-Bayes DKI fits (vb_gauss_one_voxel) on the "Min" acquisition
# subset (selected_indices7) of the mid-axial slice; MK/AK/RK maps are saved to
# DatFolder with the "Min_*_Bay_" prefix.
# NOTE(review): the range starts at 9, presumably to resume an interrupted
# run; use range(1, 33) to process every patient.
# NOTE(review): LS_x comes from `residuals` (last defined via
# CustomDKISimulator2) and the VB mean is decoded with ComputeDTI below, while
# vb_gauss_one_voxel evaluates the model through f -> CustomDKISimulator_exp
# (log-Cholesky DT parameterisation). Confirm they agree on the parameterisation.
for kk in tqdm.tqdm(range(9,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Concatenated [b=1000, b=3000] scheme; the Min subset indexes into it.
    # NOTE(review): selected_indices7 was derived from patient 1 and is reused
    # for every patient -- assumes all patients share the same gradient table.
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))
    gtabHCP7 = gradient_table(gtabExt.bvals[selected_indices7],gtabExt.bvecs[selected_indices7])

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

    # Crop both shells to the bounding box of the brain mask.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    crop = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

    maskdata  = maskdata[crop]
    axial_middle = maskdata.shape[2] // 2
    maskdata3 = data3[crop]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    # The start point is identical for every voxel: compute it once.
    DT_guess_new = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_guess_new = KT_guess.squeeze()
    n_1k = maskdata.shape[-1]        # b=1000 volumes occupy TestData4D[..., :n_1k]
    is_b0 = gtabHCP7.bvals == 0      # b0 mask with the same length as tObs
    ls_x0 = np.hstack([DT_guess_new, KT_guess_new, np.log(100)])
    for i in tqdm.tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # maskdata is presumably zeroed outside the brain mask, so a zero
            # b=1000 sum marks background: leave its parameters at 0.
            if np.sum(TestData4D[i,j,axial_middle,:n_1k], axis=-1) == 0:
                continue
            tObs = TestData4D[i,j,axial_middle,selected_indices7]
            avg_s0 = tObs[is_b0].mean()
            if avg_s0 < 0:
                # Shift the whole signal so it is non-negative.
                tObs = tObs + np.abs(tObs.min())
            # `except Exception` instead of a bare `except`: a bare except also
            # catches KeyboardInterrupt, making the long run impossible to stop.
            try:
                LS_x = least_squares(residuals, x0=ls_x0.copy(), args=(gtabHCP7, tObs)).x
                try:
                    m, S, prec = vb_gauss_one_voxel(tObs, gtabHCP7, LS_x, mu0=mu0, V0=V0)
                except Exception:
                    m = 0
                NoiseEst[i,j] = m
            except Exception:
                NoiseEst[i,j] = 0

    # Per voxel: decode the DT parameters with ComputeDTI, clip negative DT
    # eigenvalues, then compute the kurtosis metrics (MK, AK, RK).
    NoiseEstInv = np.zeros_like(NoiseEst)
    NoiseEst2 =  np.zeros_like(NoiseEst)
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
            NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
            Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
    # Saved as object arrays to match the other cells' output format.
    np.save(DatFolder+'Min_MK_Bay_'+str(kk),np.array(MK_SBIFull,dtype=object))
    np.save(DatFolder+'Min_AK_Bay_'+str(kk),np.array(AK_SBIFull,dtype=object))
    np.save(DatFolder+'Min_RK_Bay_'+str(kk),np.array(RK_SBIFull,dtype=object))

    fig,ax = plt.subplots(1,3)
    ax[0].imshow(MK_SBIFull,vmin=0,vmax=1)
    ax[1].imshow(AK_SBIFull,vmin=0,vmax=1)
    ax[2].imshow(RK_SBIFull,vmin=0,vmax=1)
    plt.show()

In [None]:
# Audible "finished" notification after the batch fits. This presumably relies
# on the macOS `say` text-to-speech command; elsewhere the shell just reports
# the command as missing. The trailing note suggests the terminal bell ('\7')
# as an alternative.
os.system("say 'Other methods done'") # or '\7'

In [None]:
# Broken-y-axis precision comparison: SBI (turquoise) vs NLLS (orange) for the
# Full / Mid / Min acquisition subsets. The upper panel (ax1) repeats only the
# Min-NLLS box so its large values stay visible; the lower panel (ax2) shows
# all six. Relies on BoxPlots, Prec{Full,20,7}_{SBI,NLLS}, SBIFit and WLSFit
# from earlier cells.
fig, (ax1, ax2) = plt.subplots(2, 1,figsize=(3.2,4.8))
fig.subplots_adjust(hspace=0.05)

# Upper panel: Min-NLLS only.
y_data = np.array(Prec7_NLLS)
g_pos = np.array([2.5])
colors = ['sandybrown']
colors2 = ['peachpuff']
BoxPlots(y_data,g_pos,colors,colors2,ax1,widths=0.2,scatter=True)
ax1.tick_params(axis='y', labelsize=24)
ax1.set_xticks([])

# Lower panel: SBI at 0.65/1.0/1.35 and NLLS at 1.8/2.15/2.5 (Full/Mid/Min).
plt.sca(ax2)
colors = ['mediumturquoise']
colors2 = ['paleturquoise']
for prec, pos in ((PrecFull_SBI, 0.65), (Prec20_SBI, 1.0), (Prec7_SBI, 1.35)):
    y_data = np.array(prec)
    g_pos = np.array([pos])
    BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

colors = ['sandybrown']
colors2 = ['peachpuff']
for prec, pos in ((PrecFull_NLLS, 1.8), (Prec20_NLLS, 2.15), (Prec7_NLLS, 2.5)):
    y_data = np.array(prec)
    g_pos = np.array([pos])
    BoxPlots(y_data,g_pos,colors,colors2,ax2,widths=0.2,scatter=True)

ax2.tick_params(axis='y', labelsize=24)

# Broken-axis markers at the gap between the two panels.
d = .5
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12, linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0], [0], transform=ax1.transAxes, **kwargs)
ax2.plot([0], [1], transform=ax2.transAxes, **kwargs)

# Hide the spines between ax1 and ax2.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top
ax2.xaxis.tick_bottom()
ax2.yaxis.offsetText.set_visible(False)  # hide the offset text (e.g. "1e-3") on ax2
ax2.set_xlim([0.3,2.7])
ax1.set_xlim(ax2.get_xlim())
ax2.set_ylim(0,0.9)
ax2.set_yticks([0,0.4,0.8])
ax1.set_ylim(1,13)

# Shaded reference bands: inter-quartile range (25th-75th percentile) of each
# method's Full-data precision, drawn across that method's three boxes.
# The upper bound was previously the 77th percentile, presumably a typo.
x = np.arange(1.7,2.6,0.05)
q25, q75 = np.nanpercentile(PrecFull_NLLS, [25, 75])
ax2.fill_between(x, np.full_like(x, q25), np.full_like(x, q75), color=WLSFit, zorder=10, alpha=0.2, hatch='//')

x = np.arange(0.55,1.5,0.05)
q25, q75 = np.nanpercentile(PrecFull_SBI, [25, 75])
ax2.fill_between(x, np.full_like(x, q25), np.full_like(x, q75), color=SBIFit, zorder=10, alpha=0.2, hatch='//')

ax2.set_xticks([0.65,1,1.35,1.8,2.15,2.5])
ax2.set_xticklabels(['Full','Mid','Min','Full','Mid','Min'],fontsize=32,rotation=90)
#if Save: plt.savefig(FigLoc+'DKI_MK_Prec.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# WLS

In [None]:
# Refit the full protocol (b=1000 + b=3000) with WLS for each of the 32 subjects,
# loading and masking the raw data directly. Collect kurtosis maps for the central
# axial slice of the cropped brain volume.
# NOTE(review): unlike the TD-based fit cells, negative tensor eigenvalues are not
# clipped here before DKIMetrics; confirm this is intended.
MKFullWLArr = []
RKFullWLArr = []
AKFullWLArr = []
MKTFullWLArr = []
KFAFullWLArr = []
for kk in tqdm.tqdm(range(1,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Combined acquisition: b=1000 volumes followed by b=3000 volumes.
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask, used to crop both shells identically.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    # NOTE(review): the b=3000 data is cropped from the unmasked volume (data3),
    # whereas the b=1000 data is the masked volume; confirm this is intended.
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    dkimodelNL = dki.DiffusionKurtosisModel(gtabExt,fit_method='WLS')
    dkifitNL = dkimodelNL.fit(TestData[:,:,:])

    ArrShape = TestData4D[:,:,axial_middle,0].shape
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    MKT_SBIFull = np.zeros(ArrShape)
    KFA_SBIFull = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            Metrics = DKIMetrics(dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt)
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
            MKT_SBIFull[i,j] = Metrics[3]
            KFA_SBIFull[i,j] = Metrics[4]
    MKFullWLArr.append(MK_SBIFull)
    AKFullWLArr.append(AK_SBIFull)
    RKFullWLArr.append(RK_SBIFull)
    # Previously declared but never filled; populated like the TD-based fit cells.
    MKTFullWLArr.append(MKT_SBIFull)
    KFAFullWLArr.append(KFA_SBIFull)

In [None]:
# Build a reduced gradient subset for subject 1 by greedy farthest-point sampling of
# the diffusion directions, separately for the b=1000 and b=3000 shells. Then map the
# chosen directions back to volume indices of the concatenated acquisition.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# b=1000 shell: seed with index 1 of the non-zero-b directions.
# NOTE(review): the b=3000 shell below seeds with index 0; confirm this is intended.
selected_indices = [1]

temp_bvecs = bvecsHCP[bvalsHCP>0]
temp_bvals = bvalsHCP[bvalsHCP>0]
distance_matrix = squareform(pdist(temp_bvecs))
# Each iteration adds the direction whose minimum distance to the current selection
# is largest: 18 iterations + the seed = 19 directions.
for _ in range(18):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))

    # Minimum distance from each remaining direction to the selected set.
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Take the direction with the largest minimum distance.
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

temp = selected_indices

# Prepend one b=0 volume to the selected b=1000 directions.
bvalsHCP7_1 = np.insert(temp_bvals[temp],0,0)
bvecsHCP7_1 = np.insert(temp_bvecs[temp],0,[0,0,0],axis=0)

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Fixed: this table was previously built from the b=1000 arrays.
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# b=3000 shell: seed with index 0 of the non-zero-b directions.
selected_indices = [0]

temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
# 27 iterations + the seed = 28 directions.
for _ in range(27):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))

    # Minimum distance from each remaining direction to the selected set.
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Take the direction with the largest minimum distance.
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP20 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Map each selected direction back to its volume index (nearest matching bvec) in
# the original per-shell tables.
true_indx_one = []
for b in bvecsHCP7_1:
    true_indx_one.append(np.linalg.norm(b-bvecsHCP,axis=1).argmin())
true_indx = []
for b in bvecsHCP7_3:
    true_indx.append(np.linalg.norm(b-bvecsHCP3,axis=1).argmin())
# In the concatenated acquisition the b=3000 volumes follow all b=1000 volumes, so
# their indices are offset by len(bvalsHCP). This was previously hard-coded as 69.
selected_indices20 = true_indx_one+[t+len(bvalsHCP) for t in true_indx]

In [None]:
# Load, reslice, mask and crop every subject once. Keeps, per subject:
#   TD             - cropped 4-D data (b=1000 + b=3000 volumes),
#   axial_middles  - index of the central axial slice of the cropped volume,
#   masks          - that axial slice of the brain mask (2-D),
#   WMs            - thresholded white-matter map for that slice.
TD = []
axial_middles = []
masks = []
WMs = []
for kk in tqdm.tqdm(range(32)):
    fdwi = './HCP_data/Pat'+str(kk+1)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk+1)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk+1)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk+1)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk+1)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk+1)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask, used to crop both shells identically.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    # axial_middle indexes the *cropped* volume along the third axis.
    axial_middle = maskdata.shape[2] // 2
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middles.append(axial_middle)
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    TD.append(TestData4D)
    # Fixed: `mask` is uncropped, so the cropped-volume slice index must be offset by
    # min_coords[2] to select the same slice as TD[kk][:, :, axial_middle].
    masks.append(mask[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]+axial_middle])
    WM, affine, img = load_nifti('./flipped/c2Pat'+str(kk+1)+'_FP.nii', return_img=True)
    # NOTE(review): WM is indexed with the cropped-volume axial_middle and is not cropped
    # in-plane; confirm the segmentation is already in that cropped space.
    WMs.append(np.fliplr(WM[:,:,axial_middle]>0.8))

In [None]:
# Build, for each of the 32 subjects, three gradient tables over the concatenated
# (b=1000 volumes followed by b=3000 volumes) acquisition:
#   gTabsF  - all volumes,
#   gTabs7  - the subset selected by `selected_indices7` (defined elsewhere; TODO confirm),
#   gTabs20 - the subset selected by `selected_indices20`.
gTabsF = []
gTabs7 = []
gTabs20 = []

# NOTE(review): FullDat is not filled in this cell.
FullDat   = []

for i in tqdm.tqdm(range(1,33)):
    # Only the bval/bvec paths are read here; the diffusion-image paths are unused.
    fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
    
    fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'
    
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    
    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)
    
    # Full table: b=1000 entries followed by b=3000 entries.
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))
    gTabsF.append(gtabExt)
    
    # Reduced table: rows of the full table picked by selected_indices7.
    bvalsHCP7 = gtabExt.bvals[selected_indices7]
    bvecsHCP7 = gtabExt.bvecs[selected_indices7]
    gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)
    gTabs7.append(gtabHCP7)

    # Reduced table: rows of the full table picked by selected_indices20.
    bvalsHCP20 = gtabExt.bvals[selected_indices20]
    bvecsHCP20 = gtabExt.bvecs[selected_indices20]
    gtabHCP20 = gradient_table(bvalsHCP20, bvecsHCP20)
    gTabs20.append(gtabHCP20)

In [None]:
# WLS DKI fit of the full protocol on the central axial slice of each subject (TD data).
# Per voxel: stack the fitted parameters, clip negative tensor eigenvalues, then derive
# MK / AK / RK / MKT / KFA with DKIMetrics. Results are appended per subject.
MKFullWLArr = []
RKFullWLArr = []
AKFullWLArr = []
MKTFullWLArr = []
KFAFullWLArr = []
for kk in tqdm.tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk],fit_method='WLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],:])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # 21 parameters per voxel: first 6 = lower-triangular diffusion tensor,
    # remaining 15 = kurtosis tensor (see the [:6] / [6:21] split below).
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    # The "_NL7" names are reused across cells; here they hold full-protocol WLS maps.
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])
    # Presumably: tensor values -> 3x3 matrix, clip negative eigenvalues, back to
    # 6 values; kurtosis parameters are passed through unchanged (helpers defined elsewhere).
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            # Metrics order as used here: MK, AK, RK, MKT, KFA.
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    MKFullWLArr.append(MK_NL7)
    RKFullWLArr.append(RK_NL7)
    AKFullWLArr.append(AK_NL7)
    MKTFullWLArr.append(MKT_NL7)
    KFAFullWLArr.append(KFA_NL7)

In [None]:
# WLS DKI fit of the reduced protocol (volumes selected by selected_indices7, with
# matching gTabs7 tables) on the central axial slice of each subject. Same per-voxel
# post-processing as the full-protocol fit: eigenvalue clipping, then DKIMetrics.
MKMinWLArr = []
RKMinWLArr = []
AKMinWLArr = []
MKTMinWLArr = []
KFAMinWLArr = []
for kk in tqdm.tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk],fit_method='WLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices7])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # 21 parameters per voxel: 6 lower-triangular tensor values + 15 kurtosis values.
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])
    # Presumably clips negative tensor eigenvalues (helpers defined elsewhere).
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            # Metrics order as used here: MK, AK, RK, MKT, KFA.
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    MKMinWLArr.append(MK_NL7)
    RKMinWLArr.append(RK_NL7)
    AKMinWLArr.append(AK_NL7)
    MKTMinWLArr.append(MKT_NL7)
    KFAMinWLArr.append(KFA_NL7)

In [None]:
# NLLS DKI fit of the full protocol on the central axial slice of each subject (TD data),
# followed by eigenvalue clipping and DKIMetrics per voxel.
MKFullNLArr = []
RKFullNLArr = []
AKFullNLArr = []
MKTFullNLArr = []
KFAFullNLArr = []
for kk in tqdm.tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],:])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # 21 parameters per voxel: 6 lower-triangular tensor values + 15 kurtosis values.
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    # The "_NL7" names are reused across cells; here they hold full-protocol NLLS maps.
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])
    # Presumably clips negative tensor eigenvalues (helpers defined elsewhere).
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            # Metrics order as used here: MK, AK, RK, MKT, KFA.
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    MKFullNLArr.append(MK_NL7)
    RKFullNLArr.append(RK_NL7)
    AKFullNLArr.append(AK_NL7)
    MKTFullNLArr.append(MKT_NL7)
    KFAFullNLArr.append(KFA_NL7)

In [None]:
# NLLS DKI fit of the reduced protocol (selected_indices7 volumes, gTabs7 tables) on the
# central axial slice of each subject, followed by eigenvalue clipping and DKIMetrics.
MKMinNLArr = []
RKMinNLArr = []
AKMinNLArr = []
MKTMinNLArr = []
KFAMinNLArr = []
for kk in tqdm.tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk],fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices7])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    # 21 parameters per voxel: 6 lower-triangular tensor values + 15 kurtosis values.
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    MK_NL7  = np.zeros(ArrShape)
    AK_NL7  = np.zeros(ArrShape)
    RK_NL7 = np.zeros(ArrShape)
    MKT_NL7 = np.zeros(ArrShape)
    KFA_NL7 = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL[i,j] = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])
    # Presumably clips negative tensor eigenvalues (helpers defined elsewhere).
    NoiseEst_NL2 =  np.zeros_like(NoiseEst_NL)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            NoiseEst_NL2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j]))),NoiseEst_NL[i,j,6:]])
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            # Metrics order as used here: MK, AK, RK, MKT, KFA.
            Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6],NoiseEst_NL2[i,j][6:21])
            MK_NL7[i,j] = Metrics[0]
            AK_NL7[i,j] = Metrics[1]
            RK_NL7[i,j] = Metrics[2]
            MKT_NL7[i,j] = Metrics[3]
            KFA_NL7[i,j] = Metrics[4]
    MKMinNLArr.append(MK_NL7)
    RKMinNLArr.append(RK_NL7)
    AKMinNLArr.append(AK_NL7)
    MKTMinNLArr.append(MKT_NL7)
    KFAMinNLArr.append(KFA_NL7)

In [None]:
# Refit the full protocol (b=1000 + b=3000) with NLLS for each of the 32 subjects,
# loading and masking the raw data directly. Collect kurtosis maps for the central
# axial slice of the cropped brain volume.
# NOTE(review): unlike the TD-based fit cells, negative tensor eigenvalues are not
# clipped here before DKIMetrics; confirm this is intended.
MKFullNLArr = []
RKFullNLArr = []
AKFullNLArr = []
MKTFullNLArr = []
KFAFullNLArr = []
for kk in tqdm.tqdm(range(1,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    # Combined acquisition: b=1000 volumes followed by b=3000 volumes.
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                                 numpass=1, autocrop=True, dilate=2)

    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask, used to crop both shells identically.
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    # NOTE(review): the b=3000 data is cropped from the unmasked volume (data3),
    # whereas the b=1000 data is the masked volume; confirm this is intended.
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    dkimodelNL = dki.DiffusionKurtosisModel(gtabExt,fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TestData[:,:,:])

    ArrShape = TestData4D[:,:,axial_middle,0].shape
    MK_SBIFull  = np.zeros(ArrShape)
    AK_SBIFull  = np.zeros(ArrShape)
    RK_SBIFull  = np.zeros(ArrShape)
    MKT_SBIFull = np.zeros(ArrShape)
    KFA_SBIFull = np.zeros(ArrShape)
    for i in range(ArrShape[0]):
        for j in range(ArrShape[1]):
            Metrics = DKIMetrics(dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt)
            MK_SBIFull[i,j] = Metrics[0]
            AK_SBIFull[i,j] = Metrics[1]
            RK_SBIFull[i,j] = Metrics[2]
            MKT_SBIFull[i,j] = Metrics[3]
            KFA_SBIFull[i,j] = Metrics[4]
    MKFullNLArr.append(MK_SBIFull)
    AKFullNLArr.append(AK_SBIFull)
    RKFullNLArr.append(RK_SBIFull)
    # Previously declared but never filled; populated like the TD-based fit cells.
    MKTFullNLArr.append(MKT_SBIFull)
    KFAFullNLArr.append(KFA_SBIFull)

In [None]:
# Per-subject Bayesian-fit MK / AK / RK maps for the full protocol (subjects 1-32).
# NOTE(review): allow_pickle=True unpickles file contents; only use on trusted local data.
MK_Bay_Full = []
AK_Bay_Full = []
RK_Bay_Full = []
for kk in range(1,33):
    for _tag, _dest in (('MK', MK_Bay_Full), ('AK', AK_Bay_Full), ('RK', RK_Bay_Full)):
        _dest.append(np.load(DatFolder + f'Full_{_tag}_Bay_{kk}.npy', allow_pickle=True))

In [None]:
# Per-subject Bayesian-fit MK / AK / RK maps for the minimal protocol (subjects 1-32).
# NOTE(review): allow_pickle=True unpickles file contents; only use on trusted local data.
MK_Bay_Min = []
AK_Bay_Min = []
RK_Bay_Min = []
for kk in range(1,33):
    for _tag, _dest in (('MK', MK_Bay_Min), ('AK', AK_Bay_Min), ('RK', RK_Bay_Min)):
        _dest.append(np.load(DatFolder + f'Min_{_tag}_Bay_{kk}.npy', allow_pickle=True))

In [None]:
# All-subject Bayesian-fit maps for the full protocol, one combined file per metric.
MK_Bay_Full, AK_Bay_Full, RK_Bay_Full = (
    np.load(DatFolder + f'Full_{tag}_HCP_Bay.npy', allow_pickle=True)
    for tag in ('MK', 'AK', 'RK')
)

In [None]:
# All-subject MLE-fit maps for the minimal protocol, one combined file per metric.
MK_MLE_Min, AK_MLE_Min, RK_MLE_Min = (
    np.load(DatFolder + f'Min_{tag}_HCP_MLE.npy', allow_pickle=True)
    for tag in ('MK', 'AK', 'RK')
)

In [None]:
# All-subject MLE-fit maps for the full protocol, one combined file per metric.
MK_MLE_Full, AK_MLE_Full, RK_MLE_Full = (
    np.load(DatFolder + f'Full_{tag}_HCP_MLE.npy', allow_pickle=True)
    for tag in ('MK', 'AK', 'RK')
)

In [None]:
def _as_object_array(items):
    """Pack a sequence of (possibly differently shaped) arrays into a 1-D object array.

    ``np.array(items, dtype=object)`` builds a multi-dimensional array when all
    leading dimensions agree, and raises ``ValueError`` ("could not broadcast")
    when only some dimensions agree. Filling a pre-allocated 1-D object array
    always yields one entry per item.
    """
    packed = np.empty(len(items), dtype=object)
    for idx, item in enumerate(items):
        packed[idx] = item
    return packed


# Save the per-subject minimal-protocol Bayesian maps as one combined file per metric.
np.save(DatFolder+'Min_MK_HCP_Bay',_as_object_array(MK_Bay_Min))
np.save(DatFolder+'Min_AK_HCP_Bay',_as_object_array(AK_Bay_Min))
np.save(DatFolder+'Min_RK_HCP_Bay',_as_object_array(RK_Bay_Min))

In [None]:
# All-subject SBI maps for each metric (MK, AK, RK) and protocol (Min, Mid, Full).
(MKMinArr, MKMidArr, MKFullArr,
 AKMinArr, AKMidArr, AKFullArr,
 RKMinArr, RKMidArr, RKFullArr) = (
    np.load(DatFolder + f'{proto}_{metric}_HCP.npy', allow_pickle=True)
    for metric in ('MK', 'AK', 'RK')
    for proto in ('Min', 'Mid', 'Full')
)

# Combined

In [None]:
# Root directory under which figure sub-folders are created.
image_path = './Figures/'

In [None]:
# Output directory for this figure set. exist_ok=True avoids the check-then-create
# race and is a no-op when the directory already exists.
FigLoc = image_path + 'Fig_S11/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Load subject 1's b=1000 and b=3000 gradient files and b=1000 data, reslice it, and
# compute brain masks both without (mask) and with (mask2) autocropping.
kk = 1
fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined table: b=1000 entries followed by b=3000 entries.
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

data, affine, img = load_nifti(fdwi, return_img=True)
# Resample from voxel size (1.5,1.5,1.5) to (2.5,2.5,2.5).
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): median_radius=4 here, while the TD / masks_other loops use 3.
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=False, dilate=2)
_, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=True, dilate=2)

In [None]:
# Display the shape of subject 1's autocropped brain mask (notebook cell output).
mask2.shape

In [None]:
# Autocropped 3-D brain masks for every subject (median_radius=3, as in the TD loading
# loop), later indexed per subject with axial_middles[i] in the SSIM comparisons.
# Only data/mask2 are used; the gradient tables built here are not.
masks_other = []
for kk in tqdm.tqdm(range(1,33)):
    fdwi = './HCP_data/Pat'+str(kk)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk)+'/bvecs_1k.txt'
    
    fdwi3 = './HCP_data/Pat'+str(kk)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk)+'/bvecs_3k.txt'
    
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    
    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)
    
    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))
    
    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # autocrop=True: the returned mask is cropped to the brain bounding box.
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    masks_other.append(mask2)

In [None]:
# Quick look: masked local SSIM between minimal- and full-protocol Bayesian MK maps
# for the first 8 subjects.
SSIM_bay = []
for i in range(8):
    NS1 = MK_Bay_Min[i].astype(np.float32)
    NS2 = MK_Bay_Full[i].astype(np.float32)
    # masks[i] is built as a 2-D axial slice in the data-loading loop; indexing it
    # with [:,:,k] would raise IndexError, so only slice when it is 3-D.
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7)
    SSIM_bay.append(result)


In [None]:
# Masked local SSIM (7x7 window, data range 1) between the minimal- and full-protocol
# MK maps of every subject, one list per fitting method.
SSIM_bay = []
for i in range(32):
    NS1 = MK_Bay_Min[i].astype(np.float32)
    NS2 = MK_Bay_Full[i].astype(np.float32)
    # masks[i] is built as a 2-D axial slice in the data-loading loop; indexing it
    # with [:,:,k] would raise IndexError, so only slice when it is 3-D.
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_bay.append(result)

SSIM_MLE = []
for i in range(32):
    NS1 = MK_MLE_Min[i].astype(np.float32)
    NS2 = MK_MLE_Full[i].astype(np.float32)
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_MLE.append(result)

# masks_other holds cropped 3-D masks, so the axial slice is selected here.
SSIM_NL = []
for i in range(32):
    NS1 = MKMinNLArr[i]
    NS2 = MKFullNLArr[i]
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_NL.append(result)

SSIM_WL = []
for i in range(32):
    NS1 = MKMinWLArr[i]
    NS2 = MKFullWLArr[i]
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_WL.append(result)

SSIM_SBI= []
for i in range(32):
    NS1 = MKMinArr[i]
    NS2 = MKFullArr[i]
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_SBI.append(result)

In [None]:
# Per-method RGB colours with channels scaled to [0, 1], shared by the comparison figures.
NLLSFit = np.divide([225, 190, 106], 255)   # NLLS: sand / ochre
SBIFit = np.divide([64, 176, 166], 255)     # SBI: teal
WLSFit = np.divide((192, 108, 132), 255)    # WLS: muted rose
MLEFit = np.divide([70, 100, 150], 255)     # MLE: slate blue
BayFit = np.divide([140, 165, 200], 255)    # Bayesian: pale steel blue

In [None]:
fig, ax = plt.subplots(figsize=(3.2,4.8))

# One box per fitting method for the per-subject MK SSIM (minimal vs full protocol):
# (SSIM values, x position, box colour, scatter colour).
_ssim_groups = [
    (SSIM_SBI, 1, 'mediumturquoise', 'paleturquoise'),
    (SSIM_NL, 2, 'sandybrown', 'peachpuff'),
    (SSIM_WL, 3, WLSFit, 'lightsalmon'),
    (SSIM_MLE, 4, MLEFit, 'lightsteelblue'),
    (SSIM_bay, 5, BayFit, 'lavender'),
]
for _vals, _pos, _box_colour, _pt_colour in _ssim_groups:
    y_data = np.array(_vals)
    g_pos = np.array([_pos])
    colors = [_box_colour]
    colors2 = [_pt_colour]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Dashed reference line at SSIM = 0.66, drawn once after all boxes (previously the
# same line was drawn three times; the last-drawn copy was already on top).
plt.axhline(0.66, lw=3, ls='--', c='k')

plt.yticks([0,0.2,0.4,0.6,0.8])
plt.xticks([1,2,3,4,5],['SBI','NLLS','WLS','MLE','Bay.'],fontsize=32,rotation=90)

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_MK_Other.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Gaussian smoothing applied to the kurtosis maps before the AK / RK SSIM comparisons.
from scipy.ndimage import gaussian_filter

In [None]:
# Masked local SSIM (7x7 window, data range 1) between the minimal- and full-protocol
# AK maps of every subject, after Gaussian smoothing (sigma=0.5) of both maps.
SSIM_bay = []
for i in range(32):
    NS1 = AK_Bay_Min[i].astype(np.float32)
    NS2 = AK_Bay_Full[i].astype(np.float32)
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 = gaussian_filter(NS2, sigma=0.5)
    # masks[i] is built as a 2-D axial slice in the data-loading loop; indexing it
    # with [:,:,k] would raise IndexError, so only slice when it is 3-D.
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_bay.append(result)

SSIM_MLE = []
for i in range(32):
    NS1 = AK_MLE_Min[i].astype(np.float32)
    NS2 = AK_MLE_Full[i].astype(np.float32)
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_MLE.append(result)

# masks_other holds cropped 3-D masks, so the axial slice is selected here.
SSIM_NL = []
for i in range(32):
    NS1 = AKMinNLArr[i]
    NS2 = AKFullNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_NL.append(result)

SSIM_WL = []
for i in range(32):
    NS1 = AKMinWLArr[i]
    NS2 = AKFullWLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_WL.append(result)

SSIM_SBI= []
for i in range(32):
    NS1 =AKMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =AKFullArr[i]
    NS2 = gaussian_filter(NS2, sigma=0.5)
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM_SBI.append(result)

In [None]:
fig, ax = plt.subplots(figsize=(3.2,4.8))

# One box per fitting method for the per-subject AK SSIM (minimal vs full protocol):
# (SSIM values, x position, box colour, scatter colour).
_ssim_groups = [
    (SSIM_SBI, 1, 'mediumturquoise', 'paleturquoise'),
    (SSIM_NL, 2, 'sandybrown', 'peachpuff'),
    (SSIM_WL, 3, WLSFit, 'lightsalmon'),
    (SSIM_MLE, 4, MLEFit, 'lightsteelblue'),
    (SSIM_bay, 5, BayFit, 'lavender'),
]
for _vals, _pos, _box_colour, _pt_colour in _ssim_groups:
    y_data = np.array(_vals)
    g_pos = np.array([_pos])
    colors = [_box_colour]
    colors2 = [_pt_colour]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Dashed reference line at SSIM = 0.66, drawn once after all boxes (previously the
# same line was drawn three times; the last-drawn copy was already on top).
plt.axhline(0.66, lw=3, ls='--', c='k')

plt.yticks([0,0.2,0.4,0.6,0.8])
plt.ylim([0,1])
plt.xticks([1,2,3,4,5],['SBI','NLLS','WLS','MLE','Bay.'],fontsize=32,rotation=90)

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_AK_Other.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Masked local SSIM (7x7 window, data range 1) between the minimal- and full-protocol
# RK maps of every subject, after Gaussian smoothing (sigma=0.5) of the minimal map.
# NOTE(review): only NS1 (minimal protocol) is smoothed here, whereas the AK
# comparison smooths both maps; confirm this asymmetry is intended.
SSIM_bay = []
for i in range(32):
    NS1 = RK_Bay_Min[i].astype(np.float32)
    NS2 = RK_Bay_Full[i].astype(np.float32)
    NS1 = gaussian_filter(NS1, sigma=0.5)
    # masks[i] is built as a 2-D axial slice in the data-loading loop; indexing it
    # with [:,:,k] would raise IndexError, so only slice when it is 3-D.
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_bay.append(result)

SSIM_MLE = []
for i in range(32):
    NS1 = RK_MLE_Min[i].astype(np.float32)
    NS2 = RK_MLE_Full[i].astype(np.float32)
    NS1 = gaussian_filter(NS1, sigma=0.5)
    Ma = masks[i] if np.ndim(masks[i]) == 2 else masks[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_MLE.append(result)

# masks_other holds cropped 3-D masks, so the axial slice is selected here.
SSIM_NL = []
for i in range(32):
    NS1 = RKMinNLArr[i]
    NS2 = RKFullNLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_NL.append(result)

SSIM_WL = []
for i in range(32):
    NS1 = RKMinWLArr[i]
    NS2 = RKFullWLArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range = 1)
    SSIM_WL.append(result)

SSIM_SBI= []
for i in range(32):
    NS1 =RKMinArr[i]
    NS1 = gaussian_filter(NS1, sigma=0.5)
    NS2 =RKFullArr[i]
    Ma = masks_other[i][:,:,axial_middles[i]]
    result = masked_local_ssim(NS1, NS2, Ma, win_size=7,dat_range=1)
    SSIM_SBI.append(result)

In [None]:
fig, ax = plt.subplots(figsize=(3.2,4.8))

# One box per fitting method for the per-subject RK SSIM (minimal vs full protocol):
# (SSIM values, x position, box colour, scatter colour).
_ssim_groups = [
    (SSIM_SBI, 1, 'mediumturquoise', 'paleturquoise'),
    (SSIM_NL, 2, 'sandybrown', 'peachpuff'),
    (SSIM_WL, 3, WLSFit, 'lightsalmon'),
    (SSIM_MLE, 4, MLEFit, 'lightsteelblue'),
    (SSIM_bay, 5, BayFit, 'lavender'),
]
for _vals, _pos, _box_colour, _pt_colour in _ssim_groups:
    y_data = np.array(_vals)
    g_pos = np.array([_pos])
    colors = [_box_colour]
    colors2 = [_pt_colour]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Dashed reference line at SSIM = 0.66, drawn once after all boxes (previously the
# same line was drawn three times; the last-drawn copy was already on top).
plt.axhline(0.66, lw=3, ls='--', c='k')

plt.yticks([0,0.2,0.4,0.6,0.8])
plt.ylim([-0.1,1])
plt.xticks([1,2,3,4,5],['SBI','NLLS','WLS','MLE','Bay.'],fontsize=32,rotation=90)

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_RK_Other.pdf',format='PDF',transparent=True,bbox_inches='tight')