# Frontmatter

## Imports 

In [None]:
import numpy as np
import dill as pickle
import os
from tqdm.auto import tqdm
from joblib import Parallel, delayed

import scipy.stats as stats
from scipy.spatial.distance import pdist, squareform
from scipy.stats import lognorm,norm
from scipy.optimize import least_squares

from skimage.metrics import structural_similarity as ssim

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from matplotlib.colors import TwoSlopeNorm
import matplotlib.patches as mpatches
from matplotlib.ticker import FormatStrFormatter
from matplotlib.ticker import ScalarFormatter
from scipy.special import i0e
from scipy.ndimage import gaussian_filter
from scipy.optimize import least_squares
# Global matplotlib styling shared by every figure in this notebook.
font = {
    'size': 14,
    'family': 'sans-serif',
    'sans-serif': ['Helvetica'],
}
plt.rc('font', **font)

# Large tick labels on both axes.
plt.rc('xtick', labelsize=24)
plt.rc('ytick', labelsize=24)

# Plain-text (non-LaTeX) rendering; note that 'font.family' set here replaces
# the 'sans-serif' family chosen above. Also hide the top/right spines and
# draw legends without a frame.
plt.rcParams.update({
    'text.usetex': False,
    'font.family': 'Helvetica',
    'axes.spines.top': False,
    'axes.spines.right': False,
    'legend.frameon': False,
})

from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.inference.potentials.posterior_based_potential import posterior_estimator_based_potential
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
from sbi.utils import process_prior,BoxUniform

import torch
from torch.distributions import Categorical,Normal
from torch import Tensor


from dipy.sims.voxel import single_tensor
from dipy.data import get_fnames
from dipy.io.gradients import read_bvals_bvecs
from dipy.core.gradients import gradient_table
from dipy.reconst.dti import (decompose_tensor, from_lower_triangular)
from dipy.io.image import load_nifti
from dipy.segment.mask import median_otsu
import dipy.reconst.dti as dti
import dipy.reconst.dki as dki
from dipy.align.reslice import reslice
from dipy.core.sphere import disperse_charges, Sphere, HemiSphere

import pymatreader as pmt

from scipy.optimize import minimize
from scipy.special import i0

## Functions

In [None]:
def vals_to_mat(dt):
    """Expand ``[d00, d01, d11, d02, d12, d22]`` into a symmetric 3x3 matrix.

    Only the first six entries of ``dt`` are read; each off-diagonal value is
    written to both mirrored positions.
    """
    rows = [0, 0, 1, 0, 1, 2]
    cols = [0, 1, 1, 2, 2, 2]
    entries = [dt[k] for k in range(6)]
    mat = np.zeros((3, 3))
    mat[rows, cols] = entries
    mat[cols, rows] = entries
    return mat

def mat_to_vals(DTI):
    """Pack a 3x3 tensor into ``[d00, d01, d11, d02, d12, d22]``.

    Only the diagonal and the upper triangle are read, so for symmetric input
    this is the inverse of ``vals_to_mat``.
    """
    positions = ((0, 0), (0, 1), (1, 1), (0, 2), (1, 2), (2, 2))
    packed = np.zeros(6)
    packed[:] = [DTI[r, c] for r, c in positions]
    return packed

def fill_lower_diag(a):
    """Place six parameters into a lower-triangular 3x3 float matrix.

    Layout: ``a[0]``, ``a[1]``, ``a[2]`` go on the diagonal; ``a[3]`` -> (1, 0),
    ``a[4]`` -> (2, 0), ``a[5]`` -> (2, 1). The upper triangle stays zero.
    """
    rows, cols = np.tril_indices(3)  # row-major: (0,0),(1,0),(1,1),(2,0),(2,1),(2,2)
    order = (0, 3, 1, 4, 5, 2)
    lower = np.zeros((3, 3), dtype=float)
    lower[rows, cols] = [a[k] for k in order]
    return lower

def ComputeDTI(params):
    """Build a symmetric positive semi-definite tensor from six Cholesky parameters.

    ``params`` uses the ``fill_lower_diag`` layout: entries 0-2 become the
    diagonal of a lower-triangular factor ``L`` and entries 3-5 fill its
    strictly-lower part. The diagonal is made non-negative with ``abs`` and
    the tensor ``L @ L.T`` is returned.

    NOTE: a later cell in this notebook redefines ``ComputeDTI`` with an
    exponential diagonal (``exp(p) * scale_factor``). Callers resolve the
    name at call time, so that version is used once that cell has run.
    """
    L = fill_lower_diag(params)
    np.fill_diagonal(L, np.abs(np.diagonal(L)))
    return L @ L.T

def ForceLowFA(dt):
    """Return a more isotropic (lower-FA) version of a symmetric tensor.

    The eigenvalues are clipped into ``[u * m, m]``. Here ``m`` is their mean
    and ``u ~ U(0, 1)`` is a single draw from NumPy's global RNG. The tensor
    is then rebuilt with the original eigenvectors.
    """
    w, V = np.linalg.eigh(dt)
    m = np.mean(w)
    lo = m * np.random.rand()
    hi = m * 1.0
    return V @ np.diag(np.clip(w, lo, hi)) @ V.T
    
def FracAni(evals, MD):
    """Fractional anisotropy of a set of eigenvalues.

    Computes ``sqrt(3 * sum((evals - MD)**2)) / (sqrt(2) * sqrt(sum(evals**2)))``.
    ``MD`` is supplied by the caller and is normally the mean of ``evals``.
    All-zero eigenvalues with ``MD == 0`` give 0/0, i.e. NaN.
    """
    spread = np.sum((evals - MD) ** 2)
    energy = np.sum(evals ** 2)
    return np.sqrt(3 * spread) / (np.sqrt(2) * np.sqrt(energy))

def clip_negative_eigenvalues(matrix, floor=1e-5):
    """Return ``matrix`` with every eigenvalue raised to at least ``floor``.

    Despite the name, eigenvalues are floored at ``floor`` (default 1e-5), not
    at 0. For symmetric input the result is therefore positive definite.

    Parameters
    ----------
    matrix : array_like, square
        Matrix to regularise.
    floor : float, optional
        Smallest eigenvalue allowed in the result.

    Returns
    -------
    ndarray
        For symmetric input, computed with ``np.linalg.eigh`` as
        ``V diag(max(w, floor)) V^T``. The result is real and exactly
        symmetric. Other input falls back to the general eigendecomposition
        ``V diag(max(w, floor)) V^-1``.
    """
    matrix = np.asarray(matrix)
    if np.allclose(matrix, matrix.T):
        # eigh guarantees real eigenvalues and orthonormal eigenvectors, so the
        # reconstruction stays real and symmetric (eig/inv can introduce tiny
        # asymmetries or complex parts for clustered eigenvalues).
        w, V = np.linalg.eigh(matrix)
        return V @ np.diag(np.maximum(w, floor)) @ V.T
    w, V = np.linalg.eig(matrix)
    return V @ np.diag(np.maximum(w, floor)) @ np.linalg.inv(V)


In [None]:
def FitDT(Dat, seed=1):
    """Fit parametric distributions to diffusion-tensor elements.

    Column 0 of ``Dat`` is fitted with a log-normal (``lognorm.fit`` with
    shape, loc and scale free). Column 1 is fitted with a normal. ``GenDTKT``
    uses the first fit for all diagonal entries and the second for all
    off-diagonal entries.

    Returns
    -------
    (frozen lognorm, frozen norm)
    """
    np.random.seed(seed)
    diag_sample = Dat[:, 0]
    offdiag_sample = Dat[:, 1]

    s, l, sc = lognorm.fit(diag_sample)
    mu, sd = norm.fit(offdiag_sample)

    return stats.lognorm(s, loc=l, scale=sc), stats.norm(loc=mu, scale=sd)

def FitKT(Dat, seed=1):
    """Fit parametric distributions to kurtosis-tensor elements.

    One representative column per element group is fitted:

    - column 0: log-normal (``x4``)
    - column 3: normal (``R1``)
    - column 9: log-normal (``x2``)
    - column 12: normal (``R2``)

    ``GenDTKT`` reuses each fit for its whole group (columns 0-2, 3-8, 9-11
    and 12-14 respectively).
    """
    np.random.seed(seed)

    def _lognorm_at(col):
        shape, loc, scale = lognorm.fit(Dat[:, col])
        return stats.lognorm(shape, loc=loc, scale=scale)

    def _norm_at(col):
        loc, scale = norm.fit(Dat[:, col])
        return norm(loc, scale)

    x4_fitted = _lognorm_at(0)
    R1_fitted = _norm_at(3)
    x2_fitted = _lognorm_at(9)
    R2_fitted = _norm_at(12)
    return x4_fitted, R1_fitted, x2_fitted, R2_fitted

def GenDTKT(DT_Fits, KT_Fits, seed, size):
    """Draw ``size`` synthetic DT (6 values) and KT (15 values) vectors.

    DT columns use the ``vals_to_mat`` layout. Diagonal columns 0, 2, 5 are
    drawn from ``DT_Fits[0]`` and off-diagonal columns 1, 3, 4 from
    ``DT_Fits[1]``. KT columns 0-2, 3-8, 9-11 and 12-14 are drawn from
    ``KT_Fits[0..3]``. NumPy's global RNG is reseeded with ``seed``, and the
    draw order is fixed so that results are reproducible.
    """
    np.random.seed(seed)
    DT = np.zeros([size, 6])
    KT = np.zeros([size, 15])

    dt_groups = (((0, 2, 5), 0), ((1, 3, 4), 1))
    for columns, which in dt_groups:
        for c in columns:
            DT[:, c] = DT_Fits[which].rvs(size)

    kt_groups = ((range(0, 3), 0), (range(3, 9), 1), (range(9, 12), 2), (range(12, 15), 3))
    for columns, which in kt_groups:
        for c in columns:
            KT[:, c] = KT_Fits[which].rvs(size)

    return DT, KT
    
def DKIMetrics(dt, kt, analytical=True):
    """Compute scalar DKI metrics for one voxel.

    Parameters
    ----------
    dt : ndarray
        Diffusion tensor, either a 3x3 matrix or six values in the
        ``vals_to_mat`` layout.
    kt : ndarray
        Kurtosis-tensor elements appended after the eigen-decomposition to
        form the dipy DKI parameter vector.
    analytical : bool, optional
        Forwarded to dipy's mean/axial/radial kurtosis functions.

    Returns
    -------
    (mk, ak, rk, mkt, kfa)
        Mean, axial and radial kurtosis, mean kurtosis tensor and kurtosis
        fractional anisotropy. Kurtosis values are bounded below at -3/7
        and unbounded above.
    """
    if dt.ndim == 1:
        dt = vals_to_mat(dt)
    evals, evecs = np.linalg.eigh(dt)
    # Sort eigenpairs in descending order so the first is the principal
    # direction (relevant for axial/radial kurtosis).
    order = np.argsort(evals)[::-1]
    evals = evals[order]
    evecs = evecs[:, order]

    params = np.concatenate([evals, np.hstack(evecs), kt])
    bounds = dict(min_kurtosis=-3.0 / 7, max_kurtosis=np.inf)

    mk = dki.mean_kurtosis(params, analytical=analytical, **bounds)
    ak = dki.axial_kurtosis(params, analytical=analytical, **bounds)
    rk = dki.radial_kurtosis(params, analytical=analytical, **bounds)
    mkt = dki.mean_kurtosis_tensor(params, **bounds)
    kfa = kurtosis_fractional_anisotropy_test(params)

    return mk, ak, rk, mkt, kfa
    
def kurtosis_fractional_anisotropy_test(dki_params):
    r"""Compute the anisotropy of the kurtosis tensor (KFA).

    See :footcite:p:`Glenn2015` and :footcite:p:`NetoHenriques2021` for further
    details about the method.

    Parameters
    ----------
    dki_params : ndarray (..., 27)
        All parameters estimated from the diffusion kurtosis model; any
        number of leading dimensions is supported, including a single
        1-D vector of 27 values.
        Parameters are ordered as follows:
            1) Three diffusion tensor's eigenvalues
            2) Three lines of the eigenvector matrix each containing the first,
                second and third coordinates of the eigenvector
            3) Fifteen elements of the kurtosis tensor (indices 12-26)

    Returns
    -------
    kfa : float or ndarray
        Kurtosis fractional anisotropy, shaped like ``dki_params[..., 0]``
        (a NumPy scalar for a single parameter vector). Voxels whose kurtosis
        tensor is all zeros (denominator 0) get KFA = 0 instead of NaN.

    Notes
    -----
    The KFA is defined as :footcite:p:`Glenn2015`:

    .. math::

         KFA \equiv
         \frac{||\mathbf{W} - MKT \mathbf{I}^{(4)}||_F}{||\mathbf{W}||_F}

    where $W$ is the kurtosis tensor, MKT the kurtosis tensor mean, $I^{(4)}$ is
    the fully symmetric rank-4 isotropic tensor and $||...||_F$ is the tensor's
    Frobenius norm :footcite:p:`Glenn2015`.

    References
    ----------
    .. footbibliography::

    """
    Wxxxx = dki_params[..., 12]
    Wyyyy = dki_params[..., 13]
    Wzzzz = dki_params[..., 14]
    Wxxxy = dki_params[..., 15]
    Wxxxz = dki_params[..., 16]
    Wxyyy = dki_params[..., 17]
    Wyyyz = dki_params[..., 18]
    Wxzzz = dki_params[..., 19]
    Wyzzz = dki_params[..., 20]
    Wxxyy = dki_params[..., 21]
    Wxxzz = dki_params[..., 22]
    Wyyzz = dki_params[..., 23]
    Wxxyz = dki_params[..., 24]
    Wxyyz = dki_params[..., 25]
    Wxyzz = dki_params[..., 26]

    # Mean of the kurtosis tensor (MKT)
    W = 1.0 / 5.0 * (Wxxxx + Wyyyy + Wzzzz + 2 * Wxxyy + 2 * Wxxzz + 2 * Wyyzz)
    # Equation numerator: ||W - MKT I||_F^2
    A = (
        (Wxxxx - W) ** 2
        + (Wyyyy - W) ** 2
        + (Wzzzz - W) ** 2
        + 4 * (Wxxxy**2 + Wxxxz**2 + Wxyyy**2 + Wyyyz**2 + Wxzzz**2 + Wyzzz**2)
        + 6 * ((Wxxyy - W / 3) ** 2 + (Wxxzz - W / 3) ** 2 + (Wyyzz - W / 3) ** 2)
        + 12 * (Wxxyz**2 + Wxyyz**2 + Wxyzz**2)
    )
    # Equation denominator: ||W||_F^2
    B = (
        Wxxxx**2
        + Wyyyy**2
        + Wzzzz**2
        + 4 * (Wxxxy**2 + Wxxxz**2 + Wxyyy**2 + Wyyyz**2 + Wxzzz**2 + Wyzzz**2)
        + 6 * (Wxxyy**2 + Wxxzz**2 + Wyyzz**2)
        + 12 * (Wxxyz**2 + Wxyyz**2 + Wxyzz**2)
    )

    # KFA is undefined for an all-zero kurtosis tensor (B == 0); report 0
    # there rather than NaN, and silence the 0/0 warning from the masked-out
    # entries.
    with np.errstate(divide='ignore', invalid='ignore'):
        KFA = np.where(B > 0, np.sqrt(A / B), 0.0)

    # [()] turns a 0-d result back into a NumPy scalar; arrays pass through.
    return KFA[()]

In [None]:
def CustomSimulator(Mat, gtab, S0, snr=None):
    """Predict a single-tensor DTI signal for the 3x3 tensor ``Mat``.

    The tensor is eigendecomposed and passed to dipy's ``single_tensor``.
    If ``snr`` is given, Rician noise with sigma ``S0 / snr`` is added via
    ``AddNoise``; ``snr=None`` returns the noise-free signal.
    """
    evals, evecs = np.linalg.eigh(Mat)
    clean = single_tensor(gtab, S0=S0, evals=evals, evecs=evecs)
    return clean if snr is None else AddNoise(clean, S0, snr)

def Simulator(gtab, S0, params, SNR):
    """Simulate a DTI signal from Cholesky-style tensor parameters.

    ``params`` is turned into a tensor by ``ComputeDTI`` and forwarded to
    ``CustomSimulator``; ``SNR=None`` yields a noise-free signal.
    """
    return CustomSimulator(ComputeDTI(params), gtab, S0, SNR)


def GenRicciNoise(signal, S0, snr):
    """Return Rician-corrupted magnitudes of ``signal``.

    Computes ``sqrt((signal + n1)**2 + n2**2)``, where ``n1`` and ``n2`` are
    i.i.d. ``N(0, (S0 / snr)**2)``. They are drawn in that order from
    NumPy's global RNG, with the same shape as ``signal``.
    """
    sigma = S0 / snr
    shape = signal.shape
    real_part = signal + np.random.normal(0, sigma, size=shape)
    imag_part = np.random.normal(0, sigma, size=shape)
    return np.sqrt(real_part ** 2 + imag_part ** 2)


def AddNoise(signal, S0, snr):
    """Corrupt ``signal`` with Rician noise of sigma ``S0 / snr``.

    Thin alias for ``GenRicciNoise``.
    """
    noisy = GenRicciNoise(signal, S0, snr)
    return noisy

def CustomDKISimulator(dt, kt, gtab, S0, snr=None):
    """Predict a DKI signal from a diffusion tensor and kurtosis elements.

    ``dt`` may be a 3x3 matrix or six values in the ``vals_to_mat`` layout.
    The eigendecomposition plus ``kt`` form the dipy DKI parameter vector
    passed to ``dki.dki_prediction``. Rician noise is added when ``snr`` is
    not None.
    """
    tensor = vals_to_mat(dt) if dt.ndim == 1 else dt
    evals, evecs = np.linalg.eigh(tensor)
    dki_params = np.concatenate([evals, np.hstack(evecs), kt])
    clean = dki.dki_prediction(dki_params, gtab, S0)
    return clean if snr is None else AddNoise(clean, S0, snr)

def CustomDKISimulator2(params, kt, gtab, S0, snr=None):
    """Predict a DKI signal from Cholesky-style DT parameters and kurtosis elements.

    Parameters
    ----------
    params : array_like
        Six tensor parameters converted to a 3x3 tensor by ``ComputeDTI``.
    kt : array_like
        Kurtosis-tensor elements appended to the eigendecomposition to form
        the dipy DKI parameter vector.
    gtab : gradient table passed to ``dki.dki_prediction``.
    S0 : float
        Unweighted signal.
    snr : float or None, optional
        If given, Rician noise with sigma ``S0 / snr`` is added.

    Returns
    -------
    ndarray
        Predicted (optionally noisy) signal. It is returned as-is even if it
        contains NaN/inf; callers such as the NLL functions clip it.
    """
    dt = ComputeDTI(params)
    evals, evecs = np.linalg.eigh(dt)
    combined_set = np.concatenate([evals, np.hstack(evecs), kt])
    signal = dki.dki_prediction(combined_set, gtab, S0)
    if snr is None:
        return signal
    return AddNoise(signal, S0, snr)

In [None]:
class DTIPrior:
    """Box-uniform prior over six tensor parameters.

    A sample is ``[abc (3), rest (3)]``. ``abc`` (the first three parameters,
    which ``fill_lower_diag`` places on the diagonal) is uniform in
    ``[lower_abs, upper_abs]^3``. ``rest`` is uniform in
    ``[lower_rest, upper_rest]^3``.

    With ``return_numpy=True``, ``sample`` returns a NumPy array and
    ``log_prob`` also accepts one.
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor, 
                       lower_rest: Tensor, upper_rest: Tensor,
                        return_numpy: bool = False):

        self.dist_abs = BoxUniform(low=lower_abs * torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest * torch.ones(3))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; ``sample_shape=(N,)`` gives an ``(N, 6)`` result."""
        abc = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)

        if self.return_numpy:
            return np.hstack([abc, rest])
        return torch.hstack([abc, rest])

    def log_prob(self, values):
        """Joint log-density of ``values`` with shape ``(N, 6)`` or ``(6,)``."""
        if self.return_numpy:
            values = torch.as_tensor(values)

        # Slice the last axis so both batched (N, 6) and single (6,) inputs work.
        abc = values[..., :3]
        rest = values[..., 3:]

        return self.dist_abs.log_prob(abc) + self.dist_rest.log_prob(rest)

class DTIPriorS0:
    """Box-uniform prior over six tensor parameters plus the unweighted signal S0.

    A sample is ``[abc (3), rest (3), S0 (1)]``. The parts are drawn
    independently:

    - ``abc`` from ``U([lower_abs, upper_abs]^3)``
    - ``rest`` from ``U([lower_rest, upper_rest]^3)``
    - ``S0`` from ``U(lower_S0, upper_S0)``

    With ``return_numpy=True``, ``sample`` returns a NumPy array and
    ``log_prob`` also accepts one.
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor, 
                       lower_rest: Tensor, upper_rest: Tensor,
                       lower_S0: Tensor, upper_S0: Tensor,
                        return_numpy: bool = False):

        self.dist_abs = BoxUniform(low=lower_abs * torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest * torch.ones(3))
        # Length-1 bounds: this distribution has an event dimension of size 1.
        self.dist_S0 = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; ``sample_shape=(N,)`` gives an ``(N, 7)`` result."""
        abc = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)
        S0 = self.dist_S0.sample(sample_shape)

        if self.return_numpy:
            return np.hstack([abc, rest, S0])
        return torch.hstack([abc, rest, S0])

    def log_prob(self, values):
        """Per-sample joint log-density of ``values`` with shape ``(N, 7)`` or ``(7,)``."""
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc = values[..., :3]
        rest = values[..., 3:-1]
        # Keep S0 as a trailing length-1 axis to match dist_S0's event shape.
        # BoxUniform (an Independent of Uniforms in sbi) sums over its last
        # axis. Passing shape (N,) would therefore sum over the whole batch
        # and add one scalar to every sample.
        S0 = values[..., -1:]

        log_prob_abc = self.dist_abs.log_prob(abc)
        log_prob_rest = self.dist_rest.log_prob(rest)
        log_prob_S0 = self.dist_S0.log_prob(S0)
        return log_prob_abc + log_prob_rest + log_prob_S0

class DTIPriorS0Direc:
    """Prior over six tensor parameters, S0 and a discrete direction index.

    A sample is ``[abc (3), rest (3), S0 (1), direc (1)]``. ``abc``, ``rest``
    and ``S0`` are independent box-uniforms. ``direc`` is a uniformly chosen
    index in ``{0, ..., 4}``.

    With ``return_numpy=True``, ``sample`` returns a NumPy array and
    ``log_prob`` also accepts one.
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor, 
                       lower_rest: Tensor, upper_rest: Tensor,
                       lower_S0: Tensor, upper_S0: Tensor,
                        return_numpy: bool = False):

        self.dist_abs = BoxUniform(low=lower_abs * torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest * torch.ones(3))
        # Length-1 bounds: this distribution has an event dimension of size 1.
        self.dist_S0 = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
        self.direction_choice = Categorical(probs=torch.ones(1, 5))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; ``sample_shape=(N,)`` gives an ``(N, 8)`` result."""
        abc = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)
        S0 = self.dist_S0.sample(sample_shape)
        direc = self.direction_choice.sample(sample_shape)

        if self.return_numpy:
            return np.hstack([abc, rest, S0, direc])
        return torch.hstack([abc, rest, S0, direc])

    def log_prob(self, values):
        """Per-sample joint log-density of ``values`` with shape ``(N, 8)`` or ``(8,)``."""
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc = values[..., :3]
        rest = values[..., 3:-2]
        # Keep S0 as a trailing length-1 axis to match dist_S0's event shape.
        # A shape-(N,) slice would be summed over the batch by BoxUniform.
        S0 = values[..., -2:-1]
        # Categorical has no event dimension, so the direction index stays (N,).
        direc = values[..., -1]

        log_prob_abc = self.dist_abs.log_prob(abc)
        log_prob_rest = self.dist_rest.log_prob(rest)
        log_prob_S0 = self.dist_S0.log_prob(S0)
        log_prob_direc = self.direction_choice.log_prob(direc)
        return log_prob_abc + log_prob_rest + log_prob_S0 + log_prob_direc

class DTIPriorS0Noise:
    """Prior over six tensor parameters, S0 and a noise parameter.

    A sample is ``[abc (3), rest (3), S0 (1), noise (1)]``. All parts are
    drawn from independent box-uniforms with the given bounds.

    With ``return_numpy=True``, ``sample`` returns a NumPy array and
    ``log_prob`` also accepts one.
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor, 
                       lower_rest: Tensor, upper_rest: Tensor,
                       lower_S0: Tensor, upper_S0: Tensor,
                       lower_noise: Tensor, upper_noise: Tensor,
                        return_numpy: bool = False):

        self.dist_abs = BoxUniform(low=lower_abs * torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest * torch.ones(3))
        # Length-1 bounds: these distributions have an event dimension of size 1.
        self.dist_S0 = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
        self.dist_noise = BoxUniform(low=torch.tensor([lower_noise]), high=torch.tensor([upper_noise]))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; ``sample_shape=(N,)`` gives an ``(N, 8)`` result."""
        abc = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)
        S0 = self.dist_S0.sample(sample_shape)
        noise = self.dist_noise.sample(sample_shape)

        if self.return_numpy:
            return np.hstack([abc, rest, S0, noise])
        return torch.hstack([abc, rest, S0, noise])

    def log_prob(self, values):
        """Per-sample joint log-density of ``values`` with shape ``(N, 8)`` or ``(8,)``."""
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc = values[..., :3]
        rest = values[..., 3:-2]
        # Keep S0 and noise as trailing length-1 axes to match their event
        # shapes. Shape-(N,) slices would be summed over the batch by
        # BoxUniform and broadcast as one scalar.
        S0 = values[..., -2:-1]
        noise = values[..., -1:]

        log_prob_abc = self.dist_abs.log_prob(abc)
        log_prob_rest = self.dist_rest.log_prob(rest)
        log_prob_S0 = self.dist_S0.log_prob(S0)
        log_prob_noise = self.dist_noise.log_prob(noise)
        return log_prob_abc + log_prob_rest + log_prob_S0 + log_prob_noise

def histogram_mode(data, bins=50):
    """Estimate the mode of ``data`` as the centre of its most populated histogram bin.

    Ties resolve to the first (lowest) such bin, following ``np.argmax``.
    """
    counts, edges = np.histogram(data, bins=bins)
    peak = np.argmax(counts)
    left, right = edges[peak], edges[peak + 1]
    return (left + right) / 2

In [None]:
def Errors(Guess, Truth, gtab, signal_true, signal_provided, S0Guess=200):
    """Compare an estimated 3x3 diffusion tensor with the ground truth.

    Returns
    -------
    tuple of 8 floats
        ``(MD, FA, EigError, Frob, Err, Corr, Err2, Corr2)``:

        - MD, FA: absolute mean-diffusivity and FA differences.
        - EigError: L2 distance between the sorted eigenvalues.
        - Frob: Frobenius norm of ``Guess - Truth``.
        - Err, Corr: length-normalised L2 error and Pearson correlation
          between ``signal_true`` and the signal predicted from ``Guess``
          (``single_tensor`` with ``S0Guess``).
        - Err2, Corr2: the same against ``signal_provided``, using only the
          first ``len(signal_provided)`` predicted values.
    """
    guess_raw, guess_vecs = np.linalg.eigh(Guess)
    true_raw, _ = np.linalg.eigh(Truth)
    guess_sorted = np.sort(guess_raw)
    true_sorted = np.sort(true_raw)

    eig_err = np.linalg.norm(guess_sorted - true_sorted)

    md_true = np.mean(true_sorted)
    md_guess = np.mean(guess_sorted)
    md_err = abs(md_true - md_guess)

    fa_err = abs(FracAni(true_sorted, md_true) - FracAni(guess_sorted, md_guess))

    frob_err = np.linalg.norm(Guess - Truth, 'fro')

    predicted = single_tensor(gtab, S0=S0Guess, evals=guess_raw, evecs=guess_vecs)
    n_full = len(signal_true)
    sig_err = np.linalg.norm(signal_true - predicted) / n_full
    sig_corr = np.corrcoef(signal_true, predicted)[0, 1]

    n_sub = len(signal_provided)
    head = predicted[:n_sub]
    sub_err = np.linalg.norm(signal_provided - head) / n_sub
    sub_corr = np.corrcoef(signal_provided, head)[0, 1]

    return md_err, fa_err, eig_err, frob_err, sig_err, sig_corr, sub_err, sub_corr

def ErrorsMDFA(Guess, Truth):
    """Absolute MD and FA errors between two 3x3 diffusion tensors.

    Parameters
    ----------
    Guess, Truth : ndarray (3, 3)
        Estimated and reference tensors.

    Returns
    -------
    (MD, FA)
        ``|MD_true - MD_guess|`` and ``|FA_true - FA_guess|``. MD is the
        mean eigenvalue and FA comes from ``FracAni``.
    """
    evals_guess = np.sort(np.linalg.eigh(Guess)[0])
    evals_true = np.sort(np.linalg.eigh(Truth)[0])

    mean_true = np.mean(evals_true)
    mean_guess = np.mean(evals_guess)
    MD = abs(mean_true - mean_guess)

    FA_true = FracAni(evals_true, mean_true)
    FA_guess = FracAni(evals_guess, mean_guess)
    FA = abs(FA_true - FA_guess)

    return MD, FA
    
def PercsMDFA(Guess, Truth):
    """Relative MD and FA errors between two 3x3 diffusion tensors.

    Parameters
    ----------
    Guess, Truth : ndarray (3, 3)
        Estimated and reference tensors.

    Returns
    -------
    (MD, FA)
        ``|MD_true - MD_guess| / MD_true`` and ``|FA_true - FA_guess| / FA_true``.
        Each is NaN when its reference value is exactly zero: the relative
        error is undefined there. NaN is used rather than inf because the
        plotting helpers in this notebook drop NaNs.
    """
    evals_guess = np.sort(np.linalg.eigh(Guess)[0])
    evals_true = np.sort(np.linalg.eigh(Truth)[0])

    mean_true = np.mean(evals_true)
    mean_guess = np.mean(evals_guess)
    MD = abs(mean_true - mean_guess) / mean_true if mean_true != 0 else np.nan

    FA_true = FracAni(evals_true, mean_true)
    FA_guess = FracAni(evals_guess, mean_guess)
    FA = abs(FA_true - FA_guess) / FA_true if FA_true != 0 else np.nan

    return MD, FA


def DKIErrors(GuessDT, GuessKT, TruthDT, TruthKT):
    """Absolute differences of the DKI metrics of two (DT, KT) pairs.

    Both pairs go through ``DKIMetrics`` with ``analytical=False``. The
    result is ``(mk, ak, rk, mkt, kfa)``, each ``|guess - truth|``.
    """
    guess = DKIMetrics(GuessDT, GuessKT, False)
    truth = DKIMetrics(TruthDT, TruthKT, False)
    return tuple(abs(g - t) for g, t in zip(guess, truth))

def Percs(GuessDT, GuessKT, TruthDT, TruthKT):
    """Relative differences of the DKI metrics of two (DT, KT) pairs.

    Both pairs go through ``DKIMetrics`` with ``analytical=False``. The
    result is ``(mk, ak, rk, mkt, kfa)``, each ``|guess - truth| / |truth|``.
    """
    guess = DKIMetrics(GuessDT, GuessKT, False)
    truth = DKIMetrics(TruthDT, TruthKT, False)
    return tuple(abs(g - t) / abs(t) for g, t in zip(guess, truth))


In [None]:
def viol_plot(A, col, hatch=False, iqr_factor=15, **kwargs):
    """Draw a violin plot of each column of ``A`` after a loose IQR outlier cut.

    For every column (``A`` assumed 2-D), NaNs are removed first. Values
    outside the open interval
    ``(Q1 - iqr_factor * IQR, Q3 + iqr_factor * IQR)`` are then dropped.
    Columns with ``Q1 == Q3`` are kept whole, because that interval would be
    empty.

    Parameters
    ----------
    A : array_like
        Data; one violin per column.
    col : colour
        Face colour of the bodies and colour of the bars/extrema. Means are
        drawn in black.
    hatch : bool, optional
        If True, hatch the first violin body with ``'//'``.
    iqr_factor : float, optional
        Width of the outlier fence in IQRs (default 15, a very loose cut).
    **kwargs
        Forwarded to ``plt.violinplot``.

    NOTE: a later cell redefines ``viol_plot`` with a z-score filter. Once
    that cell runs, this version is shadowed.
    """
    filtered_A = []
    for column in np.transpose(A):
        column = column[~np.isnan(column)]

        Q1 = np.percentile(column, 25)
        Q3 = np.percentile(column, 75)
        if Q1 == Q3:
            filtered_A.append(column)
            continue

        outlier_step = iqr_factor * (Q3 - Q1)
        keep = (column > Q1 - outlier_step) & (column < Q3 + outlier_step)
        filtered_A.append(column[keep])

    vp = plt.violinplot(filtered_A, showmeans=True, **kwargs)
    for body in vp['bodies']:
        body.set_facecolor(col)
    for part in ('cbars', 'cmins', 'cmaxes'):
        vp[part].set_color(col)
    vp['cmeans'].set_color('black')
    if hatch:
        # NOTE(review): only the first body is hatched — confirm intent when
        # plotting several columns at once.
        vp['bodies'][0].set_hatch('//')

def box_plot(data, edge_color, fill_color, hatch=None, linewidth=1.5, **kwargs):
    """Draw a filled box plot of ``data`` with NaNs removed.

    1-D ``data`` gives a single box. Otherwise each element of ``data``
    becomes one box. Boxes, whiskers, means, medians and caps are drawn in
    ``edge_color`` with ``linewidth``. Boxes are filled with ``fill_color``
    and optionally hatched. Extra keyword arguments go to ``plt.boxplot``.

    Returns the dict of artists produced by ``plt.boxplot``.
    """
    if np.ndim(data) == 1:
        cleaned = data[~np.isnan(data)]
    else:
        cleaned = [series[~np.isnan(series)] for series in data]

    bp = plt.boxplot(cleaned, patch_artist=True, **kwargs)

    for part in ('boxes', 'whiskers', 'means', 'medians', 'caps'):
        plt.setp(bp[part], color=edge_color, linewidth=linewidth)

    for box in bp['boxes']:
        box.set(facecolor=fill_color, linewidth=linewidth)
        if hatch is not None:
            box.set(hatch=hatch)

    return bp

def viol_plot2(A, col, hatch=False, z_threshold=1000, **kwargs):
    """Draw a violin plot of each column of ``A`` with a z-score outlier cut.

    For every column (``A`` assumed 2-D), NaNs are removed first. Values with
    ``|z| >= z_threshold`` are then dropped. The default of 1000 makes the
    cut effectively a no-op unless lowered.

    Columns with zero spread (all values equal, including single-value
    columns) are kept whole. Their z-scores are 0/0 = NaN, and the
    comparison would otherwise discard every value and leave an empty
    violin.

    Parameters
    ----------
    A : array_like
        Data; one violin per column.
    col : colour
        Face colour of the bodies and colour of the bars/extrema. Means are
        drawn in black.
    hatch : bool, optional
        If True, hatch the first violin body with ``'//'``.
    z_threshold : float, optional
        Absolute z-score at or above which values are dropped.
    **kwargs
        Forwarded to ``plt.violinplot``.
    """
    filtered_A = []
    for column in np.transpose(A):
        column = column[~np.isnan(column)]
        if column.size == 0 or np.all(column == column[0]):
            filtered_A.append(column)
            continue
        abs_z_scores = np.abs(stats.zscore(column))
        filtered_A.append(column[abs_z_scores < z_threshold])

    vp = plt.violinplot(filtered_A, showmeans=True, **kwargs)
    for body in vp['bodies']:
        body.set_facecolor(col)
    for part in ('cbars', 'cmins', 'cmaxes'):
        vp[part].set_color(col)
    vp['cmeans'].set_color('black')
    if hatch:
        # NOTE(review): only the first body is hatched — confirm intent.
        vp['bodies'][0].set_hatch('//')

In [None]:
# Output locations; the image directory is created if it does not exist.
network_path = '../Networks/'
image_path = './Images/'
if not os.path.exists(image_path):
    os.makedirs(image_path)

# SNR levels to simulate (None -> noise-free signal in the simulators).
NoiseLevels = [None, 20, 10, 5, 2]

TrainingSamples = 50000
InferSamples = 500

# Prior bounds: first three tensor parameters, remaining three, and S0.
lower_abs = -0.07
upper_abs = 0.07
lower_rest = -0.015
upper_rest = 0.015
lower_S0 = 25
upper_S0 = 2000
Save = True

# Plot colours (RGB scaled to [0, 1]).
TrueCol = 'k'
NoisyCol = 'k'
NLLSFit = np.array([225, 190, 106]) / 255
SBIFit = np.array([64, 176, 166]) / 255
WLSFit = np.array([140, 100, 200]) / 255  # muted violet
MLEFit = np.array([210, 80, 140]) / 255

Errors_name = ['MD comparison','FA comparison','eig. comparison','Frobenius','Signal comparison','Correlation','Signal comparison','Correlation2']

# Prior over (6 tensor parameters, S0), wrapped for use with sbi.
custom_prior = DTIPriorS0(lower_abs, upper_abs, lower_rest, upper_rest, lower_S0, upper_S0)
priorS0, *_ = process_prior(custom_prior)

## DKIFits

In [None]:

# Load one HCP subject (b=1000 and b=3000 acquisitions), fit DKI on a single
# axial slice and collect the per-voxel DT/KT estimates and metrics.
i = 1
fdwi = '../HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = '../HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = '../HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = '../HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = '../HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = '../HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined scheme: all b=1000 volumes first, then all b=3000 volumes, which
# matches the concatenation order of TestData below.
gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)), np.vstack((bvecsHCP, bvecsHCP3)))

# Reslice from voxel size 1.5 to 2.5 (presumably mm), then mask the brain.
data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): computed on the un-cropped grid but used to index the cropped
# arrays below, so it is the volume's middle slice only if the mask bounding
# box starts at z=0 — confirm the intended slice.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=False, dilate=2)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

# Crop both acquisitions to the bounding box of the brain mask.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)
crop = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

maskdata = maskdata[crop]
maskdata3 = data3[crop]

# Flatten one axial slice to (voxels, volumes). Volume counts come from the
# data instead of being hard-coded, so other acquisitions also work.
n_vols_1k = maskdata.shape[-1]
TestData = np.concatenate([maskdata[:, :, axial_middle, :], maskdata3[:, :, axial_middle, :]], axis=-1)
FlatTD = TestData.reshape(-1, TestData.shape[-1])
# Keep voxels with signal in the (masked) b=1000 volumes and no negative
# intensities anywhere.
FlatTD = FlatTD[FlatTD[:, :n_vols_1k].sum(axis=-1) > 0]
FlatTD = FlatTD[~(FlatTD < 0).any(axis=-1)]

dkimodel = dki.DiffusionKurtosisModel(gtabExt)
tenfit = dkimodel.fit(FlatTD)
DKIHCP = tenfit.kt
DTIHCP = tenfit.lower_triangular()
DKIMin = np.array(DKIHCP)
DTIMin = np.array(DTIHCP)

# Discard voxels with any |KT element| >= 10, then any KT element <= -3/7.
kt_bounded = (abs(DKIMin) < 10).all(axis=1)
DTIFilt1 = DTIMin[kt_bounded]
DKIFilt1 = DKIMin[kt_bounded]
kt_above_min = (DKIFilt1 > -3/7).all(axis=1)
DTIFilt = DTIFilt1[kt_above_min]
DKIFilt = DKIFilt1[kt_above_min]

TrueMets = []
FA = []
for (dt, kt) in tqdm(zip(DTIFilt, DKIFilt), total=len(DTIFilt)):
    TrueMets.append(DKIMetrics(dt, kt))
    dt_evals = np.linalg.eigh(vals_to_mat(dt))[0]
    FA.append(FracAni(dt_evals, np.mean(dt_evals)))
TrueMets = np.array(TrueMets)
TrueFA = np.array(FA)

In [None]:
# Min fit
# Parametric fits of the DT / KT element distributions: first on every
# retained voxel, then on voxel subsets chosen by a kurtosis metric.
# NOTE(review): TrueMets columns follow DKIMetrics' return order
# (mk, ak, rk, mkt, kfa). So TrueMets[:, -1] is the kurtosis FA and
# TrueMets[:, 1] the axial kurtosis. The "FA" subsets therefore use kurtosis
# FA, not the diffusion FA held in TrueFA — confirm this is intended.

# All voxels
DT1_Min, DT2_Min = FitDT(DTIFilt, 1)
x4_Min, R1_Min, x2_Min, R2_Min = FitKT(DKIFilt, 1)

# Low FA (< 0.3)
DT1_lfa, DT2_lfa = FitDT(DTIFilt[TrueMets[:, -1] < 0.3], 1)
x4_lfa, R1_lfa, x2_lfa, R2_lfa = FitKT(DKIFilt[TrueMets[:, -1] < 0.3], 1)

# High FA (> 0.7)
DT1_hfa, DT2_hfa = FitDT(DTIFilt[TrueMets[:, -1] > 0.7], 1)
x4_hfa, R1_hfa, x2_hfa, R2_hfa = FitKT(DKIFilt[TrueMets[:, -1] > 0.7], 1)

# Ultra-low FA (< 0.1)
DT1_ulfa, DT2_ulfa = FitDT(DTIFilt[TrueMets[:, -1] < 0.1], 1)
x4_ulfa, R1_ulfa, x2_ulfa, R2_ulfa = FitKT(DKIFilt[TrueMets[:, -1] < 0.1], 1)

# High axial kurtosis (> 0.9)
DT1_hak, DT2_hak = FitDT(DTIFilt[TrueMets[:, 1] > 0.9], 1)
x4_hak, R1_hak, x2_hak, R2_hak = FitKT(DKIFilt[TrueMets[:, 1] > 0.9], 1)

In [None]:
def invertComputeDTI(A, scale_factor=1e-3):
    """Recover the six ``ComputeDTI`` parameters from an SPD 3x3 tensor.

    ``A`` is symmetrised and Cholesky-factored as ``L L^T``. Then:

    - parameters 0-2 are ``log(L[k, k] / scale_factor)`` for k = 0, 1, 2
    - parameters 3-5 are ``L[1, 0]``, ``L[2, 0]``, ``L[2, 1]``

    This is the ``fill_lower_diag`` layout. ``np.linalg.cholesky`` raises
    ``LinAlgError`` if ``A`` is not positive definite.
    """
    factor = np.linalg.cholesky((A + A.T) / 2.0)
    log_diag = np.log(np.diag(factor) / scale_factor)
    strictly_lower = factor[[1, 2, 2], [0, 0, 1]]
    return np.concatenate([log_diag, strictly_lower])
def ComputeDTI(params, scale_factor=1e-3):
    """Build an SPD tensor from six Cholesky parameters with a log-diagonal.

    ``params`` uses the ``fill_lower_diag`` layout. Each diagonal entry ``p``
    of the factor becomes ``exp(p) * scale_factor``, which is strictly
    positive; the tensor ``L @ L.T`` is returned.
    """
    chol = fill_lower_diag(params)
    idx = np.arange(chol.shape[0])
    chol[idx, idx] = np.exp(chol[idx, idx]) * scale_factor
    return chol @ chol.T
    
def rician_nll_DKI(params, S_obs, gtab, S0):
    """Negative Rician log-likelihood of ``S_obs`` under the DKI signal model.

    Parameters
    ----------
    params : array_like
        ``[dt (6), kt (...), log_sigma]``. The first six entries go to
        ``ComputeDTI`` (via ``CustomDKISimulator2``), the last is
        ``log(sigma)``, and everything in between is the kurtosis-tensor
        input.
    S_obs : ndarray
        Observed signal; presumably one value per gradient in ``gtab``.
    gtab : gradient table used for the prediction.
    S0 : float
        Unweighted signal used for the prediction.

    Returns
    -------
    float
        Sum over measurements of
        ``log(sigma^2) - log(S) + (S^2 + A^2) / (2 sigma^2) - log I0(S A / sigma^2)``.
        Here ``S`` is the observed and ``A`` the model signal, both clipped.
        Returns ``np.inf`` if any term is NaN or infinite.
    """
    dt = params[:6]
    kt = params[6:-1]
    log_sigma = params[-1]

    S_model = CustomDKISimulator2(dt, kt, gtab, S0)

    sigma2 = np.exp(2 * log_sigma)  # log-parameterised so sigma^2 > 0
    # Clip to avoid log(0) and overflow in the products below.
    S_obs_clipped = np.clip(S_obs, 1e-10, None)
    S_model_clipped = np.clip(S_model, 1e-10, 1e10)

    bessel_arg = (S_obs_clipped * S_model_clipped) / sigma2

    # log I0(x) = log(i0e(x)) + |x|; the scaled Bessel avoids overflow.
    log_bessel = np.log(i0e(bessel_arg)) + np.abs(bessel_arg)

    nll = (
        np.log(sigma2)
        - np.log(S_obs_clipped)
        + (S_obs_clipped**2 + S_model_clipped**2) / (2 * sigma2)
        - log_bessel
    )
    if not np.all(np.isfinite(nll)):
        # Do not stop in a debugger here: that hangs unattended or parallel
        # fits. An infinitely bad objective makes optimisers reject the point.
        return np.inf
    return np.sum(nll)

In [None]:
def rician_nll_DKI_real(params, S_obs, gtab):
    """Negative Rician log-likelihood with S0 also estimated.

    ``params`` is ``[dt (6), kt (...), log_S0, log_sigma]``. The model signal
    comes from ``CustomDKISimulator2`` with ``S0 = exp(log_S0)``. The
    returned value is the sum over measurements of
    ``log(sigma^2) - log(S) + (S^2 + A^2) / (2 sigma^2) - log I0(S A / sigma^2)``,
    computed on clipped observed (``S``) and model (``A``) signals.
    """
    dt_params = params[:6]
    kt = params[6:-2]
    S0 = np.exp(params[-2])
    log_sigma = params[-1]

    model = CustomDKISimulator2(dt_params, kt, gtab, S0)
    var = np.exp(2 * log_sigma)

    # Clipping keeps the logarithms finite and the products bounded.
    obs = np.clip(S_obs, 1e-10, None)
    model = np.clip(model, 1e-10, 1e10)

    z = (obs * model) / var
    log_i0 = np.log(i0e(z)) + np.abs(z)  # overflow-safe log I0

    terms = np.log(var) - np.log(obs) + (obs**2 + model**2) / (2 * var) - log_i0
    return np.sum(terms)
def robust_cholesky(A, tol=1e-12):
    """
    Drop-in replacement for np.linalg.cholesky that tolerates tiny negative eigenvalues.

    The symmetric input is eigendecomposed and eigenvalues below ``tol`` are
    raised to ``tol``. This gives a square root ``A_half`` with
    ``A_half @ A_half.T ~= A``. A QR decomposition of ``A_half.T`` then
    yields the lower-triangular factor.

    Parameters:
    - A: symmetric matrix (should be positive semi-definite)
    - tol: eigenvalue floor to avoid sqrt of negative numbers

    Returns:
    - L: lower-triangular matrix with non-negative diagonal such that
      A ≈ L @ L.T. For positive-definite A this equals np.linalg.cholesky(A).

    Raises:
    - ValueError: if A is not symmetric (within atol=1e-10).
    """
    # Symmetry check
    if not np.allclose(A, A.T, atol=1e-10):
        raise ValueError("Matrix must be symmetric.")

    # Eigen-decomposition
    eigvals, eigvecs = np.linalg.eigh(A)

    # Clip small/negative eigenvalues to ensure stability
    eigvals_clipped = np.clip(eigvals, tol, None)

    # Square root of A: A_half @ A_half.T == V diag(w) V^T
    A_half = eigvecs @ np.diag(np.sqrt(eigvals_clipped))

    # A_half.T = Q R  =>  A ≈ R^T R, so L = R^T is lower-triangular.
    Q, R = np.linalg.qr(A_half.T)
    L = R.T

    # Make the diagonal positive by flipping the sign of whole COLUMNS:
    # (L S)(L S)^T = L L^T for any diagonal S of ±1. Scaling rows instead
    # would give S A S, which flips the sign of off-diagonal entries.
    signs = np.sign(np.diag(L))
    signs[signs == 0] = 1  # avoid zero sign
    L = L * signs[np.newaxis, :]

    # Zero any round-off above the diagonal.
    L = np.tril(L)

    return L
def invertComputeDTI(A, scale_factor=1e-3):
    """Recover the six ``ComputeDTI`` parameters from a (near-)SPD 3x3 tensor.

    ``A`` is symmetrised and factored with ``robust_cholesky``, which
    tolerates tiny negative eigenvalues. Then:

    - parameters 0-2 are ``log(L[k, k] / scale_factor)`` for k = 0, 1, 2
    - parameters 3-5 are ``L[1, 0]``, ``L[2, 0]``, ``L[2, 1]``

    This is the ``fill_lower_diag`` layout.
    """
    factor = robust_cholesky((A + A.T) / 2.0)
    log_diag = np.log(np.diag(factor) / scale_factor)
    strictly_lower = factor[[1, 2, 2], [0, 0, 1]]
    return np.concatenate([log_diag, strictly_lower])


# Directory for saved intermediate results.
DatFolder = './SavedDat/'

In [None]:
# Plot colours per fitting method (RGB scaled to [0, 1]); these replace the
# values assigned in the settings cell.
NLLSFit = np.array([225, 190, 106]) / 255
SBIFit = np.array([64, 176, 166]) / 255
WLSFit = np.array([192, 108, 132]) / 255  # muted rose
MLEFit = np.array([70, 100, 150]) / 255
BayFit = np.array([140, 165, 200]) / 255

In [None]:
def masked_local_ssim(img1, img2, mask, win_size=15, dat_range=None):
    """Mean SSIM over local windows that lie fully inside ``mask``.

    For every pixel of the 2-D ``img1``, the ``win_size x win_size`` window
    starting at that pixel in the padded images is considered. Images are
    reflect-padded and the mask is zero-padded by ``win_size // 2``.
    Windows are used only if the mask is true throughout and neither image
    patch contains NaN.

    Parameters
    ----------
    img1, img2 : ndarray (rows, cols)
        Images to compare; NaNs are allowed.
    mask : ndarray (rows, cols)
        Truthy where windows may be taken.
    win_size : int, optional
        Window side length.
    dat_range : float or None, optional
        ``data_range`` for SSIM. If None, the NaN-aware range of ``img1``
        (``nanmax - nanmin``) is used. A plain ``max - min`` would be NaN
        whenever ``img1`` has a NaN and would make every SSIM value NaN.

    Returns
    -------
    float
        Mean of the valid window SSIMs, or NaN if no window qualifies.
    """
    half_win = win_size // 2
    padded_img1 = np.pad(img1, half_win, mode='reflect')
    padded_img2 = np.pad(img2, half_win, mode='reflect')
    padded_mask = np.pad(mask, half_win, mode='constant', constant_values=0)

    # Hoisted out of the loop: the default range depends only on img1.
    data_range = np.nanmax(img1) - np.nanmin(img1) if dat_range is None else dat_range

    ssim_values = []
    rows, cols = img1.shape
    for i in range(rows):
        for j in range(cols):
            window = (slice(i, i + win_size), slice(j, j + win_size))
            if not padded_mask[window].all():
                continue
            img1_patch = padded_img1[window]
            img2_patch = padded_img2[window]
            if np.isnan(img1_patch).any() or np.isnan(img2_patch).any():
                continue
            ssim_values.append(ssim(img1_patch, img2_patch, data_range=data_range, full=False))

    if not ssim_values:
        return np.nan
    return np.nanmean(ssim_values)

In [None]:
def viol_plot(A, col, hatch=False, z_thresh=1000, **kwargs):
    """
    Draw one violin per column of ``A`` on the current axes, coloured ``col``.

    NaNs are removed per column, and values whose absolute z-score is at
    least ``z_thresh`` are discarded as outliers. The default of 1000 keeps
    essentially every finite value; pass e.g. ``z_thresh=3`` for a
    conventional 3-sigma cut.

    Parameters
    ----------
    A : 2-D array_like, shape (n_samples, n_groups)
        Each column becomes one violin.
    col : colour
        Colour of the violin bodies and of the bar/min/max lines.
    hatch : bool, optional
        If True, hatch the first violin body with '//'.
    z_thresh : float, optional
        Outlier cut-off on the absolute z-score.
    **kwargs
        Forwarded to ``plt.violinplot`` (e.g. ``positions``, ``widths``).
    """
    filtered_A = []
    for column in np.transpose(A):
        column = column[~np.isnan(column)]
        # A constant column has zero spread, so every z-score is NaN.
        # `~(z >= thresh)` keeps those values; the plain `z < thresh` test is
        # False for NaN and would silently empty the whole column.
        with np.errstate(invalid='ignore', divide='ignore'):
            abs_z = np.abs(stats.zscore(column))
        filtered_A.append(column[~(abs_z >= z_thresh)])

    vp = plt.violinplot(filtered_A, showmeans=True, **kwargs)
    for body in vp['bodies']:
        body.set_facecolor(col)
    for part in ('cbars', 'cmins', 'cmaxes'):
        vp[part].set_color(col)
    vp['cmeans'].set_color('black')
    if hatch:
        vp['bodies'][0].set_hatch('//')

def BoxPlots(y_data, positions, colors, colors2, ax, hatch=False, scatter=False,
             scatter_alpha=0.5, **kwargs):
    """
    Draw NaN-free box plots on ``ax``, optionally overlaid with jittered points.

    Parameters
    ----------
    y_data : 1-D array_like, or 2-D array / sequence of 1-D arrays
        A flat input is drawn as a single box; otherwise each row (element)
        is one box. NaNs are removed before plotting. Rows may end up with
        different lengths after NaN removal.
    positions : sequence of float
        x-position of each box (a single entry for flat input).
    colors : sequence of colours
        Edge and whisker colour for each box.
    colors2 : colour or sequence of colours
        Scatter colour: one colour for flat input, one per box otherwise.
    ax : matplotlib Axes
        Axes to draw on.
    hatch : bool, optional
        If True, hatch every box with '/'.
    scatter : bool, optional
        If True, overlay the data points with Student-t jitter along x.
    scatter_alpha : float, optional
        Alpha of the scatter points.
    **kwargs
        Forwarded to ``ax.boxplot`` (e.g. ``widths``).
    """
    GREY_DARK = "#747473"
    jitter = 0.02

    def _drop_nan(values):
        # Return `values` as a float array with NaNs removed.
        arr = np.asarray(values, dtype=float)
        return arr[~np.isnan(arr)]

    # Decide flat-vs-nested from the first element. np.ndim() on a ragged
    # sequence (rows of different length) raises in recent NumPy.
    is_flat = len(y_data) == 0 or np.ndim(y_data[0]) == 0
    if is_flat:
        cleaned_data = _drop_nan(y_data)
    else:
        cleaned_data = [_drop_nan(row) for row in y_data]

    # Box outlines (patch objects); edge colours are overridden per box below.
    boxprops = dict(
        linewidth=2,
        facecolor='none',
        edgecolor='turquoise'
    )
    medianprops = dict(
        linewidth=2,
        color=GREY_DARK,
        solid_capstyle="butt"
    )
    # Whiskers are Line2D objects, so they take 'color'.
    whiskerprops = dict(
        linewidth=2,
        color='turquoise'
    )

    bplot = ax.boxplot(
        cleaned_data,
        positions=positions,
        showfliers=False,
        showcaps=False,
        medianprops=medianprops,
        whiskerprops=whiskerprops,
        boxprops=boxprops,
        patch_artist=True,
        **kwargs
    )

    for i, box in enumerate(bplot['boxes']):
        box.set_edgecolor(colors[i])
        if hatch:
            box.set_hatch('/')

    # Two whiskers per box.
    for i in range(len(positions)):
        bplot['whiskers'][2*i].set_color(colors[i])
        bplot['whiskers'][2*i+1].set_color(colors[i])

    # Two caps per box (the list is empty when caps are disabled).
    for i, cap in enumerate(bplot.get('caps', [])):
        cap.set_color(colors[i // 2])

    if scatter:
        t_jitter = stats.t(df=6, scale=jitter)
        if is_flat:
            # Single box: jitter every point around its one x-position.
            x0 = np.ravel(positions)[0]
            ax.scatter(x0 + t_jitter.rvs(len(cleaned_data)), cleaned_data,
                       s=100, color=colors2, alpha=scatter_alpha)
        else:
            for pos, values, c in zip(positions, cleaned_data, colors2):
                ax.scatter(pos + t_jitter.rvs(len(values)), values,
                           s=100, color=c, alpha=scatter_alpha)

def BoxPlots2(y_data, positions, colors, colors2, ax, hatch=False):
    """
    Draw box plots (with means) on ``ax`` and overlay jittered data points.

    Parameters
    ----------
    y_data : sequence of 1-D arrays
        One dataset per box. NaNs are NOT removed here.
    positions : sequence of float
        x-position of each box.
    colors : sequence of colours
        Edge, median and whisker colour for each box.
    colors2 : sequence of colours
        Scatter colour for each dataset.
    ax : matplotlib Axes
        Axes to draw on.
    hatch : bool, optional
        If True, hatch every box with '/'.
    """
    # Student-t jitter along x, drawn for every dataset before plotting.
    t_jitter = stats.t(df=6, scale=0.02)
    jittered_x = []
    for idx, values in enumerate(y_data):
        x_line = np.array([positions[idx]] * len(values))
        jittered_x.append(x_line + t_jitter.rvs(len(x_line)))

    bplot = ax.boxplot(
        y_data,
        positions=positions,
        showfliers=False,
        showcaps=False,
        showmeans=True,
        medianprops={'linewidth': 2, 'color': 'dimgray', 'solid_capstyle': "butt"},
        whiskerprops={'linewidth': 2, 'color': 'turquoise'},
        boxprops={'linewidth': 2, 'facecolor': 'none', 'edgecolor': 'turquoise'},
        patch_artist=True
    )

    for idx, patch in enumerate(bplot['boxes']):
        patch.set_edgecolor(colors[idx])
        if hatch:
            patch.set_hatch('/')

    for idx, median_line in enumerate(bplot['medians']):
        median_line.set_color(colors[idx])

    # Two whiskers per box: lower at 2*idx, upper at 2*idx + 1.
    whiskers = bplot['whiskers']
    for idx in range(len(positions)):
        whiskers[2*idx].set_color(colors[idx])
        whiskers[2*idx+1].set_color(colors[idx])

    # Two caps per box (the list is empty when caps are disabled).
    for idx, cap_line in enumerate(bplot.get('caps', [])):
        cap_line.set_color(colors[idx // 2])

    for xs, ys, dot_colour in zip(jittered_x, y_data, colors2):
        ax.scatter(xs, ys, s=100, color=dot_colour, alpha=0.5)

# In-silico (a-b)

In [None]:
# Simulated two-shell protocol: dipy's 'small_64D' directions are re-spread
# with disperse_charges, then reused at b = 3000.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
# Index 0 is treated as the b0 volume; only the remaining vectors are re-spread.
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_updated,_ = disperse_charges(hsph_initial,5000)
bvecs = np.vstack([[0,0,0],hsph_updated.vertices])
bvalsExt = np.hstack([bvals, 3000*np.ones_like(bvals)])
bvecsExt = np.vstack([bvecs, bvecs])
# The first entry of the b=3000 half is the copy of the b0 and must stay at
# b = 0. Index it by len(bvals) rather than a hard-coded 65 (presumably
# len(bvals) for small_64D) so the cell does not silently break for another
# gradient set.
bvalsExt[len(bvals)] = 0
gtabSim = gradient_table(bvalsExt, bvecsExt)

In [None]:
# Load the SBI posterior for the full simulated protocol (gtabSim) from its
# pickle cache if present; otherwise simulate a training set, train SNPE and
# cache the resulting posterior.
if os.path.exists(f"{network_path}/DKISimFull.pickle"):
    with open(f"{network_path}/DKISimFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []
    # Training tensors from the five parameter sets (_full, _lfa, _hfa, _ulfa,
    # _hak) via GenDTKT (defined elsewhere in the notebook).
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    # NOTE(review): S0 is still the empty list here, so this produces a (0, 1)
    # array that is not used below.
    S0 = np.array(S0).reshape(len(S0),1)
    
    # Simulate every voxel with S0=200 and an SNR drawn uniformly from [0, 30).
    # Any voxel whose signal leaves [0, 800] is regenerated (KT halved, DT
    # redrawn from the _full set with seed kk) and re-simulated, until no
    # voxel is rejected.
    # NOTE(review): GenDTKT(..., kk, 1)[0] presumably yields one DT row that is
    # broadcast to every rejected voxel — confirm this is intended.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSim.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm(indx): 
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSim,200,np.random.rand()*30)
        
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    # Parameters are stacked as [DT | KT] per voxel, matching the SBI theta layout.
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    
    # Audible completion notice via the `say` command (presumably macOS-only;
    # elsewhere the shell just reports that the command is missing).
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKISimFull.pickle"):
        with open(f"{network_path}/DKISimFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Ground-truth test set: 40 (DT, KT) draws (presumably; 4th GenDTKT argument)
# from each of the five parameter sets, each simulated once per entry of
# NoiseLevels with S0=200 on the full protocol.
torch.manual_seed(1)
np.random.seed(1)
DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],1,40)
DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],1,40)
DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],1,40)
DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],1,40)
# NOTE(review): this set is generated with 12 where the other four use 1 — confirm intended.
DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,40)

SampsDT = np.vstack([DT1,DT2,DT3,DT4,DT5])
SampsKT = np.vstack([KT1,KT2,KT3,KT4,KT5])

Samples  = []

# Samples[i, k] is ground-truth voxel i simulated at noise level NoiseLevels[k].
for Sd,Sk in zip(SampsDT,SampsKT):
    Samples.append([CustomDKISimulator(Sd,Sk,gtabSim, S0=200,snr=scale) for scale in NoiseLevels])

Samples = np.array(Samples)

In [None]:
# Initial (DT, KT) guess used as the starting point of the least-squares fits:
# one draw from the _full parameter set, generated with `se`.
se = 14
torch.manual_seed(se)
np.random.seed(se)
j = 1
# NOTE(review): vL, vS and kk are not used in this cell. The randint call
# still advances NumPy's global RNG, so removing it may change later draws.
vL = torch.tensor([0.2*j])
vS = torch.tensor([0.01*j])  

kk = np.random.randint(0,4)
DT_guess,KT_guess= GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],se,1)



In [None]:
def residuals(params,gtab,y):
    """Residual vector ``y - model`` for least-squares DKI fitting with S0 fixed at 200.

    ``params`` holds the 6 DT parameters followed by the KT parameters.
    """
    predicted = CustomDKISimulator2(params[:6], params[6:], gtab, 200)
    return y - predicted

In [None]:
def invertComputeDTI_exp(A, scale_factor=1e-3):
    """
    Map a symmetric positive-definite 3x3 matrix to its 6 Cholesky parameters.

    The matrix is symmetrised, then factored as A = L @ L.T with L
    lower-triangular and a positive diagonal.

    Returns
    -------
    ndarray, shape (6,)
        ``[log(L00/s), log(L11/s), log(L22/s), L10, L20, L21]`` with
        ``s = scale_factor``: three log-scaled diagonal terms followed by the
        three raw off-diagonal terms.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the symmetrised matrix is not positive definite.
    """
    symmetric = 0.5 * (A + A.T)
    chol = np.linalg.cholesky(symmetric)
    log_diag = np.log(np.diag(chol) / scale_factor)
    off_diag = chol[(1, 2, 2), (0, 0, 1)]
    return np.concatenate([log_diag, off_diag])

def ComputeDTI_exp(params, scale_factor=1e-3):
    """
    Build the matrix ``L @ L.T`` from a parameter vector.

    ``fill_lower_diag`` (defined elsewhere in the notebook) arranges
    ``params`` into the matrix L. The diagonal entries of L are then replaced
    by ``exp(value) * scale_factor``, making them strictly positive. This
    presumably yields a symmetric positive-definite tensor, assuming
    fill_lower_diag returns a lower-triangular matrix.
    """
    chol = fill_lower_diag(params)
    diag = np.diag_indices_from(chol)
    chol[diag] = scale_factor * np.exp(chol[diag])
    return np.matmul(chol, chol.T)
    
def rician_nll_DKI(params, S_obs, gtab, S0, debug=False):
    """
    Negative log-likelihood of observed signals under Rician noise for a DKI model.

    Parameters
    ----------
    params : ndarray
        ``[6 DT params, KT params..., log(sigma)]``. The last entry is the log
        of the Rician noise scale.
    S_obs : ndarray
        Observed signal, one value per gradient in ``gtab``.
    gtab : gradient table
        Passed to ``CustomDKISimulator2``.
    S0 : float
        Non-diffusion-weighted signal used by the forward model.
    debug : bool, optional
        If True, drop into ``pdb`` when the NLL is non-finite and then return
        the raw (non-finite) sum. Off by default so batch fits never block
        waiting for a debugger.

    Returns
    -------
    float
        Sum of per-measurement NLL terms, or ``np.inf`` when any term is
        non-finite, so the optimiser rejects that step.
    """
    dt = params[:6]
    kt = params[6:-1]
    log_sigma = params[-1]

    S_model = CustomDKISimulator2(dt, kt, gtab, S0)

    sigma2 = np.exp(2 * log_sigma)  # parameterised in log space to keep sigma > 0
    # Clip to keep log() and the Bessel argument finite.
    S_obs_clipped = np.clip(S_obs, 1e-10, None)
    S_model_clipped = np.clip(S_model, 1e-10, 1e10)

    bessel_arg = (S_obs_clipped * S_model_clipped) / sigma2

    # log I0(x) = log(i0e(x)) + |x|; the exponentially scaled Bessel function
    # avoids overflow for large arguments.
    log_bessel = np.log(i0e(bessel_arg)) + np.abs(bessel_arg)

    nll = (
        np.log(sigma2)
        - np.log(S_obs_clipped)
        + (S_obs_clipped**2 + S_model_clipped**2) / (2 * sigma2)
        - log_bessel
    )
    if not np.all(np.isfinite(nll)):
        if not debug:
            return np.inf
        import pdb
        pdb.set_trace()
    return np.sum(nll)

In [None]:
# Box constraints for the 22 parameters optimised with rician_nll_DKI:
#   [0:3]   log-diagonal terms of the DT Cholesky factor
#   [3:6]   off-diagonal terms of the DT Cholesky factor
#   [6:21]  15 kurtosis-tensor parameters (groups of 3 / 9 / 3)
#   [21]    log(sigma)
lower_bounds = (
    [-0.0] * 3
    + [-0.015] * 3
    + [0] * 3
    + [-1] * 9
    + [-0.25] * 3
    + [np.log(1e-2)]
)

upper_bounds = (
    [5] * 3
    + [0.015] * 3
    + [5] * 3
    + [1] * 9
    + [0.25] * 3
    + [np.log(200)]
)
bounds = [(lo, hi) for lo, hi in zip(lower_bounds, upper_bounds)]

In [None]:
# MLE baseline on the full protocol, for noise levels k = 1..4 only. A
# least-squares fit (S0 fixed at 200) provides the starting point for a
# bounded minimisation of the Rician NLL. Results are loaded from DatFolder
# when the cache file exists.
# NOTE(review): the computed Error_MLE is never written back to that file.
if(os.path.exists(DatFolder+'/Sim_Full_Error_MLE.npy')):
    Error_MLE = np.load(DatFolder+'/Sim_Full_Error_MLE.npy')
else:
    Error_MLE = []
    for k in tqdm([1,2,3,4]):
        ErrorN2 = []
        ENoise = []
        for i in tqdm(range(200)):
            tObs = Samples[i,k,:]#Simulator(bvals,bvecs,200,params,Noise)
            # 21-parameter LS fit: 6 DT Cholesky parameters + 15 KT parameters.
            LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI_exp(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze()]), args=(gtabSim, tObs)).x
            # 22nd parameter is log(sigma), started at log(10).
            res = minimize(
                    rician_nll_DKI,
                    np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                    args=(tObs, gtabSim, 200),
                    options={'disp':0,
                        'gtol': 1e-10,
                        'ftol': 1e-10,},
                        bounds=bounds
                )
            
            # Convert the Cholesky parameters back to a DT before scoring.
            ErrorN2.append(DKIErrors(mat_to_vals(ComputeDTI_exp(res.x[:6])),res.x[6:-1],SampsDT[i],SampsKT[i]))
        Error_MLE.append(ErrorN2)

In [None]:
# Errors on the full protocol for all five noise levels (k = 0..4):
#   ErrorFull  - SBI: posterior mean of InferSamples draws per voxel
#   Error_WLS  - dipy DiffusionKurtosisModel, WLS fit
#   Error_NLLS - dipy DiffusionKurtosisModel, NLLS fit
# NOTE(review): ENoise is never used.
torch.manual_seed(10)
ErrorFull = []
for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(200):
        tObs = Samples[i,k,:]
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)
        
        # SBI theta layout is [DT (6) | KT].
        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorFull.append(ErrorN2)

Error_WLS = []
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='WLS')

for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(200):
        tObs = Samples[i,k,:]#Simulator(bvals,bvecs,200,params,Noise)
        tenfit = dkimodel.fit(tObs)
        
        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_WLS.append(ErrorN2)

Error_NLLS = []
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='NLLS')

for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(200):
        tObs = Samples[i,k,:]#Simulator(bvals,bvecs,200,params,Noise)
        tenfit = dkimodel.fit(tObs)
        
        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_NLLS.append(ErrorN2)


In [None]:
# Gaussian prior for vb_gauss_one_voxel over its 22 parameters
# (6 DT, 15 KT, 1 log-S0): zero mean, diagonal V0.
# NOTE(review): vb_gauss_one_voxel inverts V0 as a covariance matrix, so
# despite the name these entries act as prior variances, not SDs.
priors_SD =  [1/0.001633,    1/.5e-5,    1/0.001633,    1/.5e-5,   1/ 7.5e-5, 1/0.001633] + [1e2]*15 + [1e2]
mu0 = np.zeros(22)

V0  = np.diag(priors_SD)

In [None]:
# ------------------------------------------------------------
# 0.  Forward model ------------------------------------------
def f(theta, gtab):
    """DKI signal for a single voxel.

    ``theta`` is ``[6 DT params, KT params..., log(S0)]``; the last entry is
    exponentiated to obtain S0 before calling ``CustomDKISimulator2``.
    """
    return CustomDKISimulator2(theta[:6], theta[6:-1], gtab, np.exp(theta[-1]))

# Finite-difference Jacobian (6 dt + K kt + 1 S0 parameters)
def jacobian(theta, gtab, eps=1e-5):
    """Forward-difference Jacobian of ``f`` with respect to ``theta``.

    Returns
    -------
    J : ndarray, shape (len(gtab.bvals), len(theta))
        Column c holds ``(f(theta + eps*e_c) - f(theta)) / eps``.
    baseline : ndarray
        ``f(theta, gtab)``, returned so callers can reuse it.
    """
    n_params = len(theta)
    J = np.zeros((len(gtab.bvals), n_params))
    baseline = f(theta, gtab)
    for col in range(n_params):
        shifted = theta.copy()
        shifted[col] = shifted[col] + eps
        J[:, col] = (f(shifted, gtab) - baseline) / eps
    return J, baseline
# ------------------------------------------------------------

def vb_gauss_one_voxel(y, gtab, theta0,
                       mu0=None, V0=None,  # Gaussian prior
                       a0=1e-3, b0=1e-3,   # Gamma prior
                       max_iter=100, tol=1e-10):
    """
    Mean-field VB (Gaussian q(theta), Gamma q(1/σ²))
    for one voxel with Gaussian noise.

    At every iteration the forward model ``f`` is linearised around the
    current mean with the finite-difference ``jacobian``. This gives a
    Gauss-Newton-style update of the Gaussian factor and a closed-form update
    of the Gamma factor on the noise precision.

    Parameters
    ----------
    y : ndarray
        Observed signal for one voxel, one value per gradient in ``gtab``.
    gtab : gradient table
        Passed through to ``jacobian`` / ``f``.
    theta0 : ndarray
        Starting variational mean (copied, not modified).
    mu0, V0 : ndarray, optional
        Prior mean and covariance of theta. Default: zeros and ``1e2 * I``.
    a0, b0 : float, optional
        Parameters of the Gamma prior on the precision; ``a / b`` is used as
        its expectation.
    max_iter : int, optional
        Maximum number of update iterations.
    tol : float, optional
        Stop once the Euclidean norm of the change in the mean is below this.

    Returns
    -------
    m : ndarray
        Posterior mean of theta.
    S : ndarray
        Posterior covariance of theta.
    precision : float
        ``a / b``, the posterior expectation of 1/σ².

    Raises
    ------
    Whatever ``inv`` (defined elsewhere, presumably a matrix inverse) raises
    for a singular matrix — e.g. numpy.linalg.LinAlgError; TODO confirm.
    """
    D = len(theta0)                       # #parameters
    if mu0 is None:
        mu0 = np.zeros(D)
    if V0 is None:
        V0 = np.eye(D) * 1e2              # very vague prior
    V0_inv = inv(V0)

    m  = theta0.copy()                    # variational mean
    S  = V0.copy()                        # variational cov (replaced on the first iteration)
    a  = a0 + len(y)/2.                   # shape is fixed: depends only on the number of measurements
    b  = b0 + 1.0                         # will be updated
    for it in range(max_iter):
        # --- E[precision] ----------------
        lam = a / b                       # <1/σ²>

        # --- linearise model around current mean
        J, f_m = jacobian(m, gtab)
        r      = y - f_m                  # residual

        # --- Gauss-Newton-style VB updates
        S_new  = inv(V0_inv + lam * J.T @ J)
        m_new  = m + S_new @ (V0_inv @ (mu0 - m) + lam * J.T @ r)

        # --- update Gamma factors
        # Expected squared residual under q(theta): r'r plus tr(J S J') for parameter uncertainty.
        quad   = (r @ r
                  + np.trace(J @ S_new @ J.T))
        b_new  = b0 + 0.5 * quad

        # --- convergence check
        if np.linalg.norm(m_new - m) < tol:
            m, S, b = m_new, S_new, b_new
            break

        m, S, b = m_new, S_new, b_new

    return m, S, a/b 

In [None]:
# Variational-Bayes baseline on the full protocol (noise levels k = 1..4).
# A least-squares fit (S0 fixed at 200) seeds vb_gauss_one_voxel, and the
# posterior mean is scored with DKIErrors. Fits that raise are recorded as a
# row of NaNs.
# `tqdm` is the callable imported from tqdm.auto (as used by the other cells);
# it has no `.tqdm` attribute.
x0_ls = np.hstack([invertComputeDTI_exp(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze()])  # loop-invariant LS start
Error_Bay = []
for k in tqdm([1,2,3,4]):
    ErrorN2 = []
    for i in tqdm(range(200)):
        tObs = Samples[i,k,:]
        LS_x = least_squares(residuals, x0=x0_ls, args=(gtabSim, tObs)).x
        try:
            # Append log(S0) = log(200) as the 22nd VB parameter.
            m, S, prec = vb_gauss_one_voxel(tObs, gtabSim, np.append(LS_x,np.log(200)),mu0=mu0, V0=V0)
            ErrorN2.append(DKIErrors(mat_to_vals(ComputeDTI_exp(m[:6])),m[6:-1],SampsDT[i],SampsKT[i]))
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C still stops this long loop.
            ErrorN2.append(np.full(5, np.nan))

    Error_Bay.append(ErrorN2)

In [None]:
ErrorFull = np.array(ErrorFull)
Error_NLLS = np.array(Error_NLLS)
Error_WLS = np.array(Error_WLS)
Error_MLE = np.array(Error_MLE)
Error_Bay = np.array(Error_Bay)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']
# One subplot per error metric, each with five methods x four noise levels.
# Every array is sliced to its LAST four noise levels. WLS/NLLS/SBI are
# computed for k = 0..4, while MLE/Bayes are computed for k = 1..4, so [-4:]
# selects the same levels for all methods, whether an array was computed
# above or loaded from disk.
# Legend entry shown on subplot i: exactly one method per subplot.
legend_entries = [(WLSFit, 'WLS'), (NLLSFit, 'NLLS'), (MLEFit, 'MLE'),
                  (BayFit, 'Bayes'), (SBIFit, 'SBI')]
fig,ax = plt.subplots(1,5,figsize=(25.5,3))
for i in range(5):
    plt.sca(ax[i])
    box_plot(Error_WLS[-4:,:,i],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,showmeans=False,widths=0.3,positions=[1,2.9,4.8,6.7])
    box_plot(Error_NLLS[-4:,:,i],NLLSFit-0.2, np.clip(NLLSFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.3,3.2,5.1,7])
    box_plot(Error_MLE[-4:,:,i],MLEFit-0.2, np.clip(MLEFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.6,3.5,5.4,7.3])
    box_plot(Error_Bay[-4:,:,i],BayFit-0.2, np.clip(BayFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.9,3.8,5.7,7.6])
    box_plot(ErrorFull[-4:,:,i],SBIFit-0.2, np.clip(SBIFit+0.2,0,1),showfliers=False,widths=0.3,positions=[2.2,4.1,6,7.9])

    plt.xticks([1.6, 3.5, 5.4, 7.3,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)

    legend_color, legend_label = legend_entries[i]
    handles = [
        Line2D([0], [0], color=legend_color, lw=4, label=legend_label),
    ]
    plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
               fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
if Save: plt.savefig(image_path+'ErrorsFull2.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Minimal simulated protocol: one b0, 6 directions at b = 1000 and 15 at
# b = 3000 (taken from small_64D's vectors 1..6 and 1..15), each set
# re-spread with disperse_charges.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial15 = HemiSphere(xyz=bvecs[1:16])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
hsph_updated15,_ = disperse_charges(hsph_initial15,5000)
hsph_updated7,_ = disperse_charges(hsph_initial7,5000)
gtabSimSub = gradient_table(np.array([0]+[1000]*6+[3000]*15).squeeze(), np.vstack([[0,0,0],hsph_updated7.vertices,hsph_updated15.vertices]))

In [None]:
# Load the SBI posterior for the minimal protocol (gtabSimSub) from its
# pickle cache if present; otherwise simulate a training set, train SNPE and
# cache the result.
# NOTE(review): the posterior is stored in `posteriorFull`, overwriting the
# full-protocol posterior; the minimal-protocol cells below rely on that name.
if os.path.exists(f"{network_path}/DKISimMin.pickle"):
    with open(f"{network_path}/DKISimMin.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    S0 = np.array(S0).reshape(len(S0),1)
    
    # Simulate with S0=200 and an SNR drawn uniformly from [0, 30). Voxels
    # whose signal leaves [0, 800] get KT halved and DT redrawn from the _full
    # set, and are re-simulated until none are rejected.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSimSub.bvecs)])
    kk = 0
    while len(indx)>0:
        # `tqdm` is the callable imported from tqdm.auto (as in the
        # full-protocol cell); `tqdm.tqdm` is not an attribute of it.
        for i in tqdm(indx): 
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSimSub,200,np.random.rand()*30)
        
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    
    # Audible completion notice (presumably macOS `say`).
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKISimMin.pickle"):
        with open(f"{network_path}/DKISimMin.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Minimal-protocol test set: the same ground truth (SampsDT, SampsKT)
# simulated on gtabSimSub at every entry of NoiseLevels with S0=200.
# Samples7[i, k] is voxel i at noise level NoiseLevels[k].
torch.manual_seed(1)
np.random.seed(1)

Samples7  = []

for Sd,Sk in zip(SampsDT,SampsKT):
    Samples7.append([CustomDKISimulator(Sd,Sk,gtabSimSub, S0=200,snr=scale) for scale in NoiseLevels])

Samples7 = np.array(Samples7)

In [None]:
# Errors on the minimal protocol (gtabSimSub / Samples7) for all five noise
# levels: dipy NLLS and WLS fits, and the SBI posterior mean.
# `tqdm` is the callable imported from tqdm.auto; it has no `.tqdm` attribute.
Error_NLLS = []
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')

for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(200):
        tObs = Samples7[i,k,:]
        tenfit = dkimodel.fit(tObs)
        
        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_NLLS.append(ErrorN2)
Error_WLS = []
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='WLS')

for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(200):
        tObs = Samples7[i,k,:]
        tenfit = dkimodel.fit(tObs)
        
        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_WLS.append(ErrorN2)

# posteriorFull holds the minimal-protocol posterior at this point.
torch.manual_seed(10)
ErrorFull = []
for k in tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(200):
        tObs = Samples7[i,k,:]
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)
        
        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorFull.append(ErrorN2)


In [None]:
# MLE baseline on the MINIMAL protocol (noise levels k = 1..4): fits
# Samples7 on gtabSimSub, like the other minimal-protocol baselines
# (Error_Bay_Min, NLLS/WLS/SBI). Results are loaded from ./data/ when cached.
# NOTE(review): a cache file written by an earlier version of this cell may
# contain full-protocol (Samples/gtabSim) fits and should be regenerated.
# This cache also lives in ./data/ rather than DatFolder, and the computed
# result is never saved.
if(os.path.exists('./data/Sim_Min_Error_MLE.npy')):
    Error_MLE = np.load('./data/Sim_Min_Error_MLE.npy')
else:
    Error_MLE = []
    for k in tqdm([1,2,3,4]):
        ErrorN2 = []
        for i in tqdm(range(200)):
            tObs = Samples7[i,k,:]
            LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI_exp(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze()]), args=(gtabSimSub, tObs)).x
            # 22nd parameter is log(sigma), started at log(10).
            res = minimize(
                    rician_nll_DKI,
                    np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                    args=(tObs, gtabSimSub, 200),
                    options={'disp':0,
                        'gtol': 1e-10,
                        'ftol': 1e-10,},
                        bounds=bounds
                )
            
            ErrorN2.append(DKIErrors(mat_to_vals(ComputeDTI_exp(res.x[:6])),res.x[6:-1],SampsDT[i],SampsKT[i]))
        Error_MLE.append(ErrorN2)

In [None]:


# Variational-Bayes baseline on the minimal protocol for all five noise
# levels. A least-squares fit (S0 fixed at 200) seeds vb_gauss_one_voxel.
# Fits that raise are recorded as a row of NaNs.
x0_ls = np.hstack([invertComputeDTI_exp(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze()])  # loop-invariant LS start
Error_Bay_Min = []

for k in tqdm(range(5)):
    ErrorN2 = []
    for i in tqdm(range(200)):
        tObs = Samples7[i,k,:]
        LS_x = least_squares(residuals, x0=x0_ls, args=(gtabSimSub, tObs)).x
        try:
            m, S, prec = vb_gauss_one_voxel(tObs, gtabSimSub, np.append(LS_x,np.log(200)),mu0=mu0, V0=V0)
            ErrorN2.append(DKIErrors(mat_to_vals(ComputeDTI_exp(m[:6])),m[6:-1],SampsDT[i],SampsKT[i]))
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C still stops this long loop.
            ErrorN2.append(np.full(5, np.nan))
    Error_Bay_Min.append(ErrorN2)

In [None]:
# Plot-ready copy of the minimal-protocol MLE errors: values above 100
# (presumably diverged fits) are replaced by NaN; Error_MLE is left untouched.
Error_MLE_Min2 = np.copy(Error_MLE)
Error_MLE_Min2[Error_MLE_Min2>100] = math.nan

In [None]:
ErrorFull = np.array(ErrorFull)
Error_NLLS = np.array(Error_NLLS)
Error_WLS = np.array(Error_WLS)
Error_MLE_Min2 = np.array(Error_MLE_Min2)
Error_Bay_Min = np.array(Error_Bay_Min)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']
# One subplot per error metric, each with five methods x four noise levels.
# Every array is sliced to its LAST four noise levels. The MLE array holds
# only k = 1..4 when computed in this notebook, so the old `[1:]` gave three
# boxes for four positions. [-4:] selects the same levels for every method.
# Legend entry shown on subplot i: exactly one method per subplot.
legend_entries = [(WLSFit, 'WLS'), (NLLSFit, 'NLLS'), (MLEFit, 'MLE'),
                  (BayFit, 'Bayes'), (SBIFit, 'SBI')]
fig,ax = plt.subplots(1,5,figsize=(25.5,3))
for i in range(5):
    plt.sca(ax[i])
    box_plot(Error_WLS[-4:,:,i],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,showmeans=False,widths=0.3,positions=[1,2.9,4.8,6.7])
    box_plot(Error_NLLS[-4:,:,i],NLLSFit-0.2, np.clip(NLLSFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.3,3.2,5.1,7])
    box_plot(Error_MLE_Min2[-4:,:,i],MLEFit-0.2, np.clip(MLEFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.6,3.5,5.4,7.3])
    box_plot(Error_Bay_Min[-4:,:,i],BayFit-0.2, np.clip(BayFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.9,3.8,5.7,7.6])
    box_plot(ErrorFull[-4:,:,i],SBIFit-0.2, np.clip(SBIFit+0.2,0,1),showfliers=False,widths=0.3,positions=[2.2,4.1,6,7.9])

    plt.xticks([1.6, 3.5, 5.4, 7.3,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)

    legend_color, legend_label = legend_entries[i]
    handles = [
        Line2D([0], [0], color=legend_color, lw=4, label=legend_label),
    ]
    plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
               fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
if Save: plt.savefig(image_path+'ErrorsMin2.pdf',format='pdf',bbox_inches='tight',transparent=True)

# In-vivo (c-f)

### MLE

In [None]:
# In-vivo data for one HCP subject. This cell:
#   - loads the 1k and 3k acquisitions (per the file names) and their gradient tables;
#   - reslices from (1.5,1.5,1.5) to (2.5,2.5,2.5) voxels and crops to the
#     median_otsu brain mask bounding box;
#   - derives a reduced protocol (b0 + 6 directions from the 1k set, 15 from
#     the 3k set) by greedy farthest-point selection of gradient directions.


def _farthest_point_indices(vecs, start, n_extra):
    """Greedy farthest-point selection on the rows of ``vecs``.

    Starting from index ``start``, append ``n_extra`` indices; each new one
    maximises its minimum Euclidean distance to those already chosen.
    Candidates are scanned in ascending index order, so ties resolve
    deterministically to the lowest index.
    """
    distance_matrix = squareform(pdist(vecs))
    chosen = [start]
    for _ in range(n_extra):
        remaining = [idx for idx in range(len(vecs)) if idx not in chosen]
        min_distances = distance_matrix[remaining][:, chosen].min(axis=1)
        chosen.append(remaining[int(np.argmax(min_distances))])
    return chosen


i=3
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# 3k gradient table. A later line used to overwrite this with the 1k
# bvals/bvecs; that copy-paste overwrite has been removed.
bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

# Number of 1k volumes. The 3k volumes follow them in the concatenated arrays.
n_vols_1k = len(bvalsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=False, dilate=2)
_, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=True, dilate=2)


data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

# Bounding box of the brain mask, used to crop both acquisitions.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)

maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
axial_middle = maskdata.shape[2] // 2
# NOTE(review): the 3k volumes are cropped from the raw data3, not from
# median_otsu-masked data like maskdata — confirm intended.
maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

# 1k shell: 6 directions by farthest-point selection starting from index 1;
# index 0 (presumably the b0 volume) is prepended, giving 7 volumes.
# NOTE(review): candidates are not filtered by b-value here (unlike the 3k
# selection below), so a b0 row could in principle be picked.
selected_indices = _farthest_point_indices(bvecsHCP, start=1, n_extra=5)
selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

# 3k shell: 15 directions chosen among the diffusion-weighted (b > 0) volumes.
temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
selected_indices = _farthest_point_indices(temp_bvecs, start=0, n_extra=14)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Volume indices of the reduced protocol within the concatenated 1k+3k data:
# each chosen 3k direction is matched to its nearest row of bvecsHCP3.
true_indx = [int(np.linalg.norm(b-bvecsHCP3,axis=1).argmin()) for b in bvecsHCP7_3]
true_indx = selected_indices7+[t+n_vols_1k for t in true_indx]

# In-plane voxels with any non-zero 1k signal on the middle axial slice.
cutout = np.sum(TestData4D[:,:,axial_middle,:n_vols_1k], axis=-1) != 0

In [None]:
def residuals(params,gtab,y):
    """Residual vector for ``least_squares``: observed minus modelled signal.

    ``params`` packs the first six diffusion-tensor parameters, then the
    kurtosis parameters, and finally log(S0) as its last element.
    """
    predicted = CustomDKISimulator2(params[:6], params[6:-1], gtab, np.exp(params[-1]))
    return y - predicted

In [None]:
# Seed the RNGs and draw one DT/KT sample to serve as the starting guess
# for the voxel-wise fits that follow.
se = 14
torch.manual_seed(se)
np.random.seed(se)
j = 1
vL = torch.tensor([0.2 * j])
vS = torch.tensor([0.01 * j])

# One generator per prior family. The lambdas are evaluated lazily, so only
# the family picked by kk (0, 1, 2 or 3) has its variables read.
guess_builders = [
    lambda: GenDTKT([DT1_full, DT2_full], [x4_full, R1_full, x2_full, R2_full], se, 1),
    lambda: GenDTKT([DT1_lfa, DT2_lfa], [x4_lfa, R1_lfa, x2_lfa, R2_lfa], se, 1),
    lambda: GenDTKT([DT1_hfa, DT2_hfa], [x4_hfa, R1_hfa, x2_hfa, R2_hfa], se, 1),
    lambda: GenDTKT([DT1_ulfa, DT2_ulfa], [x4_ulfa, R1_ulfa, x2_ulfa, R2_ulfa], se, 1),
]
kk = np.random.randint(0, 4)
DT_guess, KT_guess = guess_builders[kk]()

# Box bounds on the 22 optimisation parameters, in blocks of
# 3 + 3 (DT), 3 + 9 + 3 (KT), and one trailing log-scale parameter in [log 2, log 20].
lower_bounds = [-0.0] * 3 + [-0.015] * 3 + [0] * 3 + [-1] * 9 + [-0.25] * 3 + [np.log(2)]
upper_bounds = [5] * 3 + [0.015] * 3 + [5] * 3 + [1] * 9 + [0.25] * 3 + [np.log(20)]
bounds = list(zip(lower_bounds, upper_bounds))


In [None]:
# Rician maximum-likelihood DKI fit of the in-brain voxels of the middle axial
# slice, using the full acquisition (gtabExt). Each voxel proceeds in three steps:
#   1. A least-squares fit provides the starting point.
#   2. A bounded `minimize` of rician_nll_DKI refines it.
#   3. The estimates pass through ComputeDTI and eigenvalue clipping, then are
#      reduced to MK/AK/RK maps.
ArrShape = TestData[:,:,0].shape
NoiseEst = np.zeros(list(ArrShape) + [22])
torch.manual_seed(10)
for i in tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Background voxel (all-zero b=1000 shell): keep its parameters at 0.
        if np.sum(maskdata[i,j,axial_middle,:69], axis=-1) == 0:
            continue
        tObs = TestData[i,j,:]
        ls_start = np.hstack([invertComputeDTI(vals_to_mat(DT_guess.squeeze())),
                              KT_guess.squeeze(),
                              np.log(tObs[gtabExt.bvals==0].mean())])
        LS_x = least_squares(residuals, x0=ls_start, args=(gtabExt, tObs)).x
        # The final free parameter starts at log(10). The LS S0 is passed as a
        # fixed argument.
        res = minimize(
            rician_nll_DKI,
            np.hstack([LS_x[:6], LS_x[6:21], np.log(10)]),
            args=(tObs, gtabExt, np.exp(LS_x[-1])),
            options={'disp': 0, 'maxfun': 100000, 'maxiter': 20000,
                     'gtol': 1e-6, 'ftol': 1e-6},
            bounds=bounds,
        )
        NoiseEst[i,j] = res.x

voxel_grid = list(np.ndindex(*ArrShape))
# Map the first six parameters through ComputeDTI; keep KT and the last entry.
NoiseEstInv = np.zeros_like(NoiseEst)
for i, j in voxel_grid:
    NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])), NoiseEst[i,j,6:21], NoiseEst[i,j,-1]])
# Clip negative diffusion-tensor eigenvalues.
NoiseEst2 = np.zeros_like(NoiseEst)
for i, j in voxel_grid:
    NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))), NoiseEstInv[i,j,6:]])
MK_MLEFull = np.zeros(ArrShape)
AK_MLEFull = np.zeros(ArrShape)
RK_MLEFull = np.zeros(ArrShape)
for i, j in voxel_grid:
    Metrics = DKIMetrics(NoiseEst2[i,j][:6], NoiseEst2[i,j][6:21])
    MK_MLEFull[i,j], AK_MLEFull[i,j], RK_MLEFull[i,j] = Metrics[0], Metrics[1], Metrics[2]

In [None]:
# Show (and, if Save, write to PDF) the MK, AK and RK maps of the middle axial
# slice from the full-acquisition MLE fit. Maps are masked by mask2 and drawn in
# grey on a 0..1 scale centred at 0.5. Only the last panel gets a colorbar, and
# it is left for the notebook to display.
# NOTE(review): the AK filename lacks the "L" of the MK/RK names; kept verbatim
# in case it is referenced elsewhere.
mle_full_panels = (
    (MK_MLEFull, 'MKNLFull_MLE.pdf'),
    (AK_MLEFull, 'AKNFull_MLE.pdf'),
    (RK_MLEFull, 'RKNLFull_MLE.pdf'),
)
last_panel = len(mle_full_panels) - 1
for panel_no, (metric_map, pdf_name) in enumerate(mle_full_panels):
    tnorm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map*mask2[:,:,axial_middle]).T, norm=tnorm, cmap='gray')
    if panel_no == last_panel:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(image_path+pdf_name, format='pdf', bbox_inches='tight', transparent=True)
    if panel_no != last_panel:
        plt.show()

In [None]:
# Rician MLE fit of the middle axial slice using only the reduced acquisition
# (volumes true_indx, gradient table gtabHCP7), followed by MK/AK/RK/MKT/KFA maps.
# Results are stored in the *_MLEMin arrays that the plotting cell reads. This
# leaves any SBI results held in *_SBIMin untouched.
ArrShape = TestData4D[:,:,axial_middle,0].shape
NoiseEst = np.zeros(list(ArrShape) + [22])
torch.manual_seed(10)
for i in tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Background voxel (all-zero b=1000 shell): keep its parameters at 0.
        if np.sum(TestData4D[i,j,axial_middle,:69], axis=-1) == 0:
            continue
        tObs = TestData4D[i,j,axial_middle,true_indx]
        LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze(),np.log(100)]), args=(gtabHCP7, tObs)).x
        res = minimize(
                rician_nll_DKI,
                np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                args=(tObs, gtabHCP7, np.exp(LS_x[-1])),
                options={'disp':0},
                bounds=bounds
            )
        NoiseEst[i,j] = res.x
NoiseEstInv = np.zeros_like(NoiseEst)
for i in range(ArrShape[0]):
    for j in range(ArrShape[1]):
        NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
# Clip negative DT eigenvalues, as is done for every other estimator.
NoiseEst2 =  np.zeros_like(NoiseEst)
for i in range(ArrShape[0]):
    for j in range(ArrShape[1]):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
MK_MLEMin  = np.zeros(ArrShape)
AK_MLEMin  = np.zeros(ArrShape)
RK_MLEMin  = np.zeros(ArrShape)
MKT_MLEMin = np.zeros(ArrShape)
KFA_MLEMin = np.zeros(ArrShape)
for i in tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        MK_MLEMin[i,j] = Metrics[0]
        AK_MLEMin[i,j] = Metrics[1]
        RK_MLEMin[i,j] = Metrics[2]
        MKT_MLEMin[i,j] = Metrics[3]
        KFA_MLEMin[i,j] = Metrics[4]

In [None]:
# Show (and, if Save, write to PDF) the MK, AK and RK maps of the middle axial
# slice from the reduced-acquisition MLE fit. Maps are masked by mask2, drawn in
# grey on a 0..1 scale centred at 0.5, and only the last panel has a colorbar.
# NOTE(review): the AK filename lacks the "L" of the MK/RK names; kept verbatim.
mle_min_panels = [
    (MK_MLEMin, 'MKNLMin_MLE.pdf'),
    (AK_MLEMin, 'AKNMin_MLE.pdf'),
    (RK_MLEMin, 'RKNLMin_MLE.pdf'),
]
for panel_no, (metric_map, pdf_name) in enumerate(mle_min_panels):
    is_last = panel_no == len(mle_min_panels) - 1
    tnorm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map*mask2[:,:,axial_middle]).T, norm=tnorm, cmap='gray')
    if is_last:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(image_path+pdf_name, format='pdf', bbox_inches='tight', transparent=True)
    if not is_last:
        plt.show()

### Bayes

In [None]:
import numpy as np
from numpy.linalg import inv

# ------------------------------------------------------------
# 0.  Your forward model -------------------------------------
def f(theta, gtab):
    """DKI forward model for a single voxel.

    ``theta`` holds the six diffusion-tensor parameters, then the kurtosis
    parameters, and finally log(S0). The result is the signal
    CustomDKISimulator2 predicts for ``gtab``.
    """
    dt_part, kt_part, log_s0 = theta[:6], theta[6:-1], theta[-1]
    return CustomDKISimulator2(dt_part, kt_part, gtab, np.exp(log_s0))

# Finite-difference Jacobian (6 dt + K kt + 1 S0 parameters)
def jacobian(theta, gtab, eps=1e-5):
    """Forward-difference Jacobian of ``f`` with respect to ``theta``.

    Returns ``(J, f0)``. ``J`` has one row per entry of ``gtab.bvals`` and one
    column per parameter, and ``f0 = f(theta, gtab)`` is the unperturbed model.
    """
    baseline = f(theta, gtab)
    n_params = len(theta)
    jac = np.zeros((len(gtab.bvals), n_params))
    for col in range(n_params):
        shifted = theta.copy()
        shifted[col] += eps
        jac[:, col] = (f(shifted, gtab) - baseline) / eps
    return jac, baseline
# ------------------------------------------------------------

def vb_gauss_one_voxel(y, gtab, theta0,
                       mu0=None, V0=None,  # Gaussian prior
                       a0=1e-3, b0=1e-3,   # Gamma prior
                       max_iter=100, tol=1e-10):
    """
    Mean-field variational Bayes for one voxel with Gaussian noise.

    The posterior is approximated by a Gaussian q(theta) and a Gamma q(1/sigma^2).
    Each iteration linearises ``f`` around the current mean (via ``jacobian``)
    and takes a Gauss-Newton-style step.

    Parameters
    ----------
    y : observed signal of the voxel.
    gtab : gradient table passed to the forward model.
    theta0 : starting parameter vector (also the initial posterior mean).
    mu0, V0 : prior mean and covariance. Defaults are zeros and 100 * identity.
    a0, b0 : shape and rate of the Gamma prior on the noise precision.
    max_iter, tol : iteration cap, and the absolute tolerance on the change of
        the mean used as the stopping rule.

    Returns
    -------
    (mean, covariance, expected noise precision a/b) of the final iterate.
    """
    n_par = len(theta0)
    prior_mean = np.zeros(n_par) if mu0 is None else mu0
    prior_cov = np.eye(n_par) * 1e2 if V0 is None else V0
    prior_prec = inv(prior_cov)

    mean = theta0.copy()
    cov = prior_cov.copy()
    shape_post = a0 + len(y)/2.
    rate_post = b0 + 1.0                  # refined on every iteration
    for _ in range(max_iter):
        noise_prec = shape_post / rate_post          # <1/sigma^2>

        J, fitted = jacobian(mean, gtab)
        resid = y - fitted

        cov_next = inv(prior_prec + noise_prec * J.T @ J)
        mean_next = mean + cov_next @ (prior_prec @ (prior_mean - mean) + noise_prec * J.T @ resid)
        rate_next = b0 + 0.5 * (resid @ resid + np.trace(J @ cov_next @ J.T))

        step = np.linalg.norm(mean_next - mean)
        mean, cov, rate_post = mean_next, cov_next, rate_next
        if step < tol:
            break

    return mean, cov, shape_post / rate_post

In [None]:
# Diagonal entries of the VB prior covariance. Despite the name they are used as
# variances (V0 = np.diag(priors_SD) in the next cell). There are six tensor
# entries, then 15 kurtosis entries and a final entry, all at 1e2.
priors_SD = [1/0.001633, 1/.5e-5, 1/0.001633, 1/.5e-5, 1/7.5e-5, 1/0.001633] + [1e2] * 16

In [None]:
# Gaussian prior for the 22-parameter VB fit: zero mean, and a diagonal
# covariance whose entries are the priors_SD values.
mu0 = np.zeros(22)

V0 = np.diag(np.asarray(priors_SD))

In [None]:
# Variational-Bayes (Gaussian noise) DKI fit of the middle axial slice using
# the full acquisition. Estimates are loaded from BayesG_Full_One.npy when that
# cache exists. Then the DKI metric maps (*_BayFull) are derived.
if os.path.exists(DatFolder+'BayesG_Full_One.npy'):
    NoiseEst = np.load(DatFolder+'BayesG_Full_One.npy')
else:
    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # Background voxel (all-zero b=1000 shell): keep its parameters at 0.
            if np.sum(TestData4D[i,j,axial_middle,:69], axis=-1) == 0:
                continue
            tObs = TestData4D[i,j,axial_middle,:]
            # A least-squares fit provides the VB starting point.
            LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze(),np.log(tObs[gtabExt.bvals==0].mean())]), args=(gtabExt, tObs)).x
            try:
                m, S, prec = vb_gauss_one_voxel(tObs, gtabExt, LS_x,mu0=mu0, V0=V0)
            except Exception:
                # Best effort: a voxel whose VB iteration fails (e.g. a singular
                # matrix) keeps all-zero parameters. KeyboardInterrupt still
                # propagates.
                m = 0
            NoiseEst[i,j] = m

# Grid size taken from the estimates, so it is right for both cached and
# computed results.
n_rows, n_cols = NoiseEst.shape[:2]
NoiseEstInv = np.zeros_like(NoiseEst)
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
# Clip negative diffusion-tensor eigenvalues.
NoiseEst2 =  np.zeros_like(NoiseEst)
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
MK_BayFull  = np.zeros([n_rows, n_cols])
AK_BayFull  = np.zeros([n_rows, n_cols])
RK_BayFull  = np.zeros([n_rows, n_cols])
MKT_BayFull = np.zeros([n_rows, n_cols])
KFA_BayFull = np.zeros([n_rows, n_cols])
for i in tqdm(range(n_rows)):
    for j in range(n_cols):
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        MK_BayFull[i,j] = Metrics[0]
        AK_BayFull[i,j] = Metrics[1]
        RK_BayFull[i,j] = Metrics[2]
        MKT_BayFull[i,j] = Metrics[3]
        KFA_BayFull[i,j] = Metrics[4]

In [None]:
# Show (and, if Save, write to PDF) the MK, AK and RK maps of the middle axial
# slice from the full-acquisition Bayesian fit. Maps are masked by mask2, drawn
# in grey on a 0..1 scale centred at 0.5, and the last panel has a colorbar.
bay_full_panels = [
    (MK_BayFull, 'MKNLFull_Bay.pdf'),
    (AK_BayFull, 'AKNLFull_Bay.pdf'),
    (RK_BayFull, 'RKNLFull_Bay.pdf'),
]
for panel_no, (metric_map, pdf_name) in enumerate(bay_full_panels):
    is_last = panel_no == len(bay_full_panels) - 1
    tnorm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map*mask2[:,:,axial_middle]).T, norm=tnorm, cmap='gray')
    if is_last:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(image_path+pdf_name, format='pdf', bbox_inches='tight', transparent=True)
    if not is_last:
        plt.show()

In [None]:
# Variational-Bayes (Gaussian noise) DKI fit of the middle axial slice using
# only the reduced acquisition (volumes true_indx, gradient table gtabHCP7).
# Estimates are loaded from BayesG_Min_One.npy when that cache exists.
if os.path.exists(DatFolder+'BayesG_Min_One.npy'):
    NoiseEst = np.load(DatFolder+'BayesG_Min_One.npy')
else:
    ArrShape = TestData4D[:,:,axial_middle,0].shape
    NoiseEst = np.zeros(list(ArrShape) + [22])
    torch.manual_seed(10)
    for i in tqdm(range(ArrShape[0])):
        for j in range(ArrShape[1]):
            # Background voxel (all-zero b=1000 shell): keep its parameters at 0.
            if np.sum(TestData4D[i,j,axial_middle,:69], axis=-1) == 0:
                continue
            tObs = TestData4D[i,j,axial_middle,true_indx]
            # A least-squares fit (S0 initialised at 100) provides the VB start.
            LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze(),np.log(100)]), args=(gtabHCP7, tObs)).x
            try:
                m, S, prec = vb_gauss_one_voxel(tObs, gtabHCP7, LS_x,mu0=mu0, V0=V0)
            except Exception:
                # Best effort: a voxel whose VB iteration fails keeps all-zero
                # parameters. KeyboardInterrupt still propagates.
                m = 0
            NoiseEst[i,j] = m

# Grid size taken from the estimates, so it is right for both cached and
# computed results.
n_rows, n_cols = NoiseEst.shape[:2]
NoiseEstInv = np.zeros_like(NoiseEst)
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
# Clip negative diffusion-tensor eigenvalues.
NoiseEst2 =  np.zeros_like(NoiseEst)
for i in range(n_rows):
    for j in range(n_cols):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
MK_BayMin  = np.zeros([n_rows, n_cols])
AK_BayMin  = np.zeros([n_rows, n_cols])
RK_BayMin  = np.zeros([n_rows, n_cols])
MKT_BayMin = np.zeros([n_rows, n_cols])
KFA_BayMin = np.zeros([n_rows, n_cols])
for i in tqdm(range(n_rows)):
    for j in range(n_cols):
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        MK_BayMin[i,j] = Metrics[0]
        AK_BayMin[i,j] = Metrics[1]
        RK_BayMin[i,j] = Metrics[2]
        MKT_BayMin[i,j] = Metrics[3]
        KFA_BayMin[i,j] = Metrics[4]

In [None]:
# Show (and, if Save, write to PDF) the MK, AK and RK maps of the middle axial
# slice from the reduced-acquisition Bayesian fit. Maps are masked by mask2,
# drawn in grey on a 0..1 scale centred at 0.5, and the last panel has a colorbar.
bay_min_panels = [
    (MK_BayMin, 'MKNLMin_Bay.pdf'),
    (AK_BayMin, 'AKNLMin_Bay.pdf'),
    (RK_BayMin, 'RKNLMin_Bay.pdf'),
]
for panel_no, (metric_map, pdf_name) in enumerate(bay_min_panels):
    is_last = panel_no == len(bay_min_panels) - 1
    tnorm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map*mask2[:,:,axial_middle]).T, norm=tnorm, cmap='gray')
    if is_last:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(image_path+pdf_name, format='pdf', bbox_inches='tight', transparent=True)
    if not is_last:
        plt.show()

### WLS

In [None]:
# dipy WLS DKI fit of the middle axial slice with the full acquisition.
# Negative DT eigenvalues are clipped before the metrics are computed, as is
# done for the MLE/Bayes estimates and the multi-subject WLS/NLLS fits.
dkimodelNL = dki.DiffusionKurtosisModel(gtabExt,fit_method='WLS')
dkifitNL = dkimodelNL.fit(TestData[:,:,:])
n_rows, n_cols = TestData.shape[:2]
MK_NLFull  = np.zeros([n_rows, n_cols])
AK_NLFull  = np.zeros([n_rows, n_cols])
RK_NLFull  = np.zeros([n_rows, n_cols])
MKT_NLFull = np.zeros([n_rows, n_cols])
KFA_NLFull = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        # 6 lower-triangular DT values followed by the KT values
        wls_params = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])
        dt_clipped = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(wls_params)))
        Metrics = DKIMetrics(dt_clipped,wls_params[6:21])
        MK_NLFull[i,j] = Metrics[0]
        AK_NLFull[i,j] = Metrics[1]
        RK_NLFull[i,j] = Metrics[2]
        MKT_NLFull[i,j] = Metrics[3]
        KFA_NLFull[i,j] = Metrics[4]

In [None]:
# Show (and, if Save, write to PDF) the MK, AK and RK maps of the middle axial
# slice from the full-acquisition WLS fit. Maps are masked by mask2, drawn in
# grey on a 0..1 scale centred at 0.5, and the last panel has a colorbar.
wls_full_panels = [
    (MK_NLFull, 'MKNLFull_WLS.pdf'),
    (AK_NLFull, 'AKNLFull_WLS.pdf'),
    (RK_NLFull, 'RKNLFull_WLS.pdf'),
]
for panel_no, (metric_map, pdf_name) in enumerate(wls_full_panels):
    is_last = panel_no == len(wls_full_panels) - 1
    tnorm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map*mask2[:,:,axial_middle]).T, norm=tnorm, cmap='gray')
    if is_last:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(image_path+pdf_name, format='pdf', bbox_inches='tight', transparent=True)
    if not is_last:
        plt.show()

In [None]:
# dipy WLS DKI fit of the middle axial slice with the reduced acquisition
# (volumes true_indx, gradient table gtabHCP7). Negative DT eigenvalues are
# clipped before the metrics, consistent with the other estimators.
dkimodelNL = dki.DiffusionKurtosisModel(gtabHCP7,fit_method='WLS')
dkifitNL = dkimodelNL.fit(TestData[:,:,true_indx])
n_rows, n_cols = TestData.shape[:2]
MK_NLMin  = np.zeros([n_rows, n_cols])
AK_NLMin  = np.zeros([n_rows, n_cols])
RK_NLMin  = np.zeros([n_rows, n_cols])
MKT_NLMin = np.zeros([n_rows, n_cols])
KFA_NLMin = np.zeros([n_rows, n_cols])
for i in range(n_rows):
    for j in range(n_cols):
        # 6 lower-triangular DT values followed by the KT values
        wls_params = np.hstack([dkifitNL[i,j].lower_triangular(),dkifitNL[i,j].kt])
        dt_clipped = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(wls_params)))
        Metrics = DKIMetrics(dt_clipped,wls_params[6:21])
        MK_NLMin[i,j] = Metrics[0]
        AK_NLMin[i,j] = Metrics[1]
        RK_NLMin[i,j] = Metrics[2]
        MKT_NLMin[i,j] = Metrics[3]
        KFA_NLMin[i,j] = Metrics[4]

In [None]:
# Show (and, if Save, write to PDF) the MK, AK and RK maps of the middle axial
# slice from the reduced-acquisition WLS fit. Maps are masked by mask2, drawn in
# grey on a 0..1 scale centred at 0.5, and the last panel has a colorbar.
wls_min_panels = [
    (MK_NLMin, 'MKNLMin_WLS.pdf'),
    (AK_NLMin, 'AKNLMin_WLS.pdf'),
    (RK_NLMin, 'RKNLMin_WLS.pdf'),
]
for panel_no, (metric_map, pdf_name) in enumerate(wls_min_panels):
    is_last = panel_no == len(wls_min_panels) - 1
    tnorm = TwoSlopeNorm(vmin=0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map*mask2[:,:,axial_middle]).T, norm=tnorm, cmap='gray')
    if is_last:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(image_path+pdf_name, format='pdf', bbox_inches='tight', transparent=True)
    if not is_last:
        plt.show()

## Combined

In [None]:
# Load all 32 HCP subjects. For each subject:
#   - reslice the b=1000 and b=3000 series from 1.5 mm to 2.5 mm;
#   - brain-mask the b=1000 series with median_otsu;
#   - crop both series to the mask's bounding box;
#   - keep the data, the middle-slice index, and the brain and WM masks of that slice.
TD = []             # per subject: cropped [b=1000 | b=3000] 4-D data
axial_middles = []  # per subject: middle axial index within the cropped volume
masks = []          # per subject: brain mask of that axial slice (cropped in x/y)
WMs = []            # per subject: WM map of the slice thresholded at 0.8, flipped left-right
for kk in tqdm(range(32)):
    fdwi = './HCP_data/Pat'+str(kk+1)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(kk+1)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(kk+1)+'/bvecs_1k.txt'

    fdwi3 = './HCP_data/Pat'+str(kk+1)+'/diff_3k.nii.gz'
    bvalloc3 = './HCP_data/Pat'+str(kk+1)+'/bvals_3k.txt'
    bvecloc3 = './HCP_data/Pat'+str(kk+1)+'/bvecs_3k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3 = np.loadtxt(bvalloc3)
    bvecsHCP3 = np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # autocrop=False: maskdata and mask stay on the full (uncropped) grid.
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)
    _, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)


    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Bounding box of the brain mask
    true_indices = np.argwhere(mask)
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)

    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middle = maskdata.shape[2] // 2
    # The b=3000 series is cropped to the same box (it is not multiplied by the mask).
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    axial_middles.append(axial_middle)
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)
    TD.append(TestData4D)
    # axial_middle indexes the cropped volume. The uncropped mask needs the
    # min_coords[2] offset to select the same slice as TD[kk][:, :, axial_middle].
    masks.append(mask[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]+axial_middle])
    WM, affine, img = load_nifti('./HCP_data/WM_Masks/c2Pat'+str(kk+1)+'_FP.nii', return_img=True)
    # NOTE(review): WM is indexed with the cropped-volume axial_middle and is not
    # cropped in x/y. This presumably assumes the WM maps were produced in the
    # cropped space; confirm.
    WMs.append(np.fliplr(WM[:,:,axial_middle]>0.8))

In [None]:
# Pick the reduced acquisition from patient 1's gradient scheme using greedy
# farthest-point sampling. selected_indices7 ends up holding indices into the
# concatenated [b=1000 | b=3000] volumes and is reused for every subject.
i = 1
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# b=1000 shell: sampling over all rows, seeded at row 1. Row 0 (presumably the
# b=0 volume) is prepended separately below.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
for _ in range(5):  # seed + 5 = 6 directions; 7 volumes with row 0
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))

    # Distance from each remaining direction to its nearest selected direction
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Greedily add the direction farthest from the current selection
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Gradient table of the b=3000 scheme that was just loaded.
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# b=3000 shell: sampling over its b>0 rows only.
dw_rows3 = np.flatnonzero(bvalsHCP3 > 0)   # row in bvecsHCP3 of each b>0 volume
temp_bvecs = bvecsHCP3[dw_rows3]
temp_bvals = bvalsHCP3[dw_rows3]
distance_matrix = squareform(pdist(temp_bvecs))
selected_indices = [0]
for _ in range(14):  # seed + 14 = 15 b=3000 directions
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))

    # Distance from each remaining direction to its nearest selected direction
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Greedily add the direction farthest from the current selection
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Rows of the chosen directions in the b=3000 scheme. The exact row numbers are
# used, rather than a nearest-bvec search that could match a b=0 row or a
# repeated direction. They are then offset by the number of b=1000 volumes,
# which precede the b=3000 shell in the concatenated data.
true_indx = [int(t) for t in dw_rows3[selected_indices]]
n_1k_volumes = len(bvalsHCP)
selected_indices7 = selected_indices7+[t+n_1k_volumes for t in true_indx]



In [None]:
# Per-patient gradient tables. gTabsF holds the full [b=1000 | b=3000] scheme;
# gTabs7 holds the reduced subset selected by selected_indices7.
gTabsF = []
gTabs7 = []

MinDat = []

for i in tqdm(range(1, 33)):
    pat_dir = f'./HCP_data/Pat{i}'
    fdwi, bvalloc, bvecloc = f'{pat_dir}/diff_1k.nii.gz', f'{pat_dir}/bvals_1k.txt', f'{pat_dir}/bvecs_1k.txt'
    fdwi3, bvalloc3, bvecloc3 = f'{pat_dir}/diff_3k.nii.gz', f'{pat_dir}/bvals_3k.txt', f'{pat_dir}/bvecs_3k.txt'

    bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)), np.vstack((bvecsHCP, bvecsHCP3)))
    gTabsF.append(gtabExt)

    bvalsHCP7, bvecsHCP7 = gtabExt.bvals[selected_indices7], gtabExt.bvecs[selected_indices7]
    gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)
    gTabs7.append(gtabHCP7)

In [None]:
# Multi-subject dipy WLS fits with the full acquisition (gTabsF). For each of
# the 32 patients, the middle axial slice is fitted, negative DT eigenvalues are
# clipped, and the MK/RK/AK/MKT/KFA maps are appended in patient order.
MKFullWLArr, RKFullWLArr, AKFullWLArr, MKTFullWLArr, KFAFullWLArr = [], [], [], [], []
for kk in tqdm(range(32)):
    slice_kk = TD[kk][:,:,axial_middles[kk],:]
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk], fit_method='WLS')
    dkifitNL = dkimodelNL.fit(slice_kk)
    ArrShape = slice_kk.shape[:2]
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7 = [np.zeros(ArrShape) for _ in range(5)]
    voxel_grid = list(np.ndindex(*ArrShape))
    # Raw estimates: 6 lower-triangular DT values followed by the KT values
    for i, j in voxel_grid:
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(), voxel_fit.kt])
    # Clip negative DT eigenvalues; KT is carried over unchanged
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    for i, j in voxel_grid:
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt, NoiseEst_NL[i,j,6:]])
    # Scalar kurtosis metrics
    for i, j in voxel_grid:
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6], NoiseEst_NL2[i,j][6:21])
        MK_NL7[i,j], AK_NL7[i,j], RK_NL7[i,j] = Metrics[0], Metrics[1], Metrics[2]
        MKT_NL7[i,j], KFA_NL7[i,j] = Metrics[3], Metrics[4]
    MKFullWLArr.append(MK_NL7)
    RKFullWLArr.append(RK_NL7)
    AKFullWLArr.append(AK_NL7)
    MKTFullWLArr.append(MKT_NL7)
    KFAFullWLArr.append(KFA_NL7)

In [None]:
# Multi-subject dipy WLS fits with the reduced acquisition (selected_indices7
# volumes, gTabs7 tables). For each of the 32 patients, the middle axial slice is
# fitted, negative DT eigenvalues are clipped, and the metric maps are appended.
MKMinWLArr, RKMinWLArr, AKMinWLArr, MKTMinWLArr, KFAMinWLArr = [], [], [], [], []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk], fit_method='WLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices7])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7 = [np.zeros(ArrShape) for _ in range(5)]
    voxel_grid = list(np.ndindex(*ArrShape))
    # Raw estimates: 6 lower-triangular DT values followed by the KT values
    for i, j in voxel_grid:
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(), voxel_fit.kt])
    # Clip negative DT eigenvalues; KT is carried over unchanged
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    for i, j in voxel_grid:
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt, NoiseEst_NL[i,j,6:]])
    # Scalar kurtosis metrics
    for i, j in voxel_grid:
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6], NoiseEst_NL2[i,j][6:21])
        MK_NL7[i,j], AK_NL7[i,j], RK_NL7[i,j] = Metrics[0], Metrics[1], Metrics[2]
        MKT_NL7[i,j], KFA_NL7[i,j] = Metrics[3], Metrics[4]
    MKMinWLArr.append(MK_NL7)
    RKMinWLArr.append(RK_NL7)
    AKMinWLArr.append(AK_NL7)
    MKTMinWLArr.append(MKT_NL7)
    KFAMinWLArr.append(KFA_NL7)

In [None]:
# Multi-subject dipy NLLS fits with the full acquisition (gTabsF). For each of
# the 32 patients, the middle axial slice is fitted, negative DT eigenvalues are
# clipped, and the MK/RK/AK/MKT/KFA maps are appended in patient order.
MKFullNLArr, RKFullNLArr, AKFullNLArr, MKTFullNLArr, KFAFullNLArr = [], [], [], [], []
for kk in tqdm(range(32)):
    slice_kk = TD[kk][:,:,axial_middles[kk],:]
    dkimodelNL = dki.DiffusionKurtosisModel(gTabsF[kk], fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(slice_kk)
    ArrShape = slice_kk.shape[:2]
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7 = [np.zeros(ArrShape) for _ in range(5)]
    voxel_grid = list(np.ndindex(*ArrShape))
    # Raw estimates: 6 lower-triangular DT values followed by the KT values
    for i, j in voxel_grid:
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(), voxel_fit.kt])
    # Clip negative DT eigenvalues; KT is carried over unchanged
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    for i, j in voxel_grid:
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt, NoiseEst_NL[i,j,6:]])
    # Scalar kurtosis metrics
    for i, j in voxel_grid:
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6], NoiseEst_NL2[i,j][6:21])
        MK_NL7[i,j], AK_NL7[i,j], RK_NL7[i,j] = Metrics[0], Metrics[1], Metrics[2]
        MKT_NL7[i,j], KFA_NL7[i,j] = Metrics[3], Metrics[4]
    MKFullNLArr.append(MK_NL7)
    RKFullNLArr.append(RK_NL7)
    AKFullNLArr.append(AK_NL7)
    MKTFullNLArr.append(MKT_NL7)
    KFAFullNLArr.append(KFA_NL7)

In [None]:
# Multi-subject dipy NLLS fits with the reduced acquisition (selected_indices7
# volumes, gTabs7 tables). For each of the 32 patients, the middle axial slice is
# fitted, negative DT eigenvalues are clipped, and the metric maps are appended.
MKMinNLArr, RKMinNLArr, AKMinNLArr, MKTMinNLArr, KFAMinNLArr = [], [], [], [], []
for kk in tqdm(range(32)):
    dkimodelNL = dki.DiffusionKurtosisModel(gTabs7[kk], fit_method='NLLS')
    dkifitNL = dkimodelNL.fit(TD[kk][:,:,axial_middles[kk],selected_indices7])
    ArrShape = TD[kk][:,:,axial_middles[kk],0].shape
    NoiseEst_NL = np.zeros(list(ArrShape)+[21])
    MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7 = [np.zeros(ArrShape) for _ in range(5)]
    voxel_grid = list(np.ndindex(*ArrShape))
    # Raw estimates: 6 lower-triangular DT values followed by the KT values
    for i, j in voxel_grid:
        voxel_fit = dkifitNL[i,j]
        NoiseEst_NL[i,j] = np.hstack([voxel_fit.lower_triangular(), voxel_fit.kt])
    # Clip negative DT eigenvalues; KT is carried over unchanged
    NoiseEst_NL2 = np.zeros_like(NoiseEst_NL)
    for i, j in voxel_grid:
        clipped_dt = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst_NL[i,j])))
        NoiseEst_NL2[i,j] = np.hstack([clipped_dt, NoiseEst_NL[i,j,6:]])
    # Scalar kurtosis metrics
    for i, j in voxel_grid:
        Metrics = DKIMetrics(NoiseEst_NL2[i,j][:6], NoiseEst_NL2[i,j][6:21])
        MK_NL7[i,j], AK_NL7[i,j], RK_NL7[i,j] = Metrics[0], Metrics[1], Metrics[2]
        MKT_NL7[i,j], KFA_NL7[i,j] = Metrics[3], Metrics[4]
    MKMinNLArr.append(MK_NL7)
    RKMinNLArr.append(RK_NL7)
    AKMinNLArr.append(AK_NL7)
    MKTMinNLArr.append(MKT_NL7)
    KFAMinNLArr.append(KFA_NL7)

In [None]:
# Multi-subject Rician MLE fits with the full acquisition. MK/AK/RK maps of each
# patient's middle axial slice are loaded from the Full_*_HCP_MLE.npy caches
# when present. Otherwise they are computed voxel by voxel, in patient order.
if os.path.exists(DatFolder+'Full_MK_HCP_MLE.npy'):
    MK_MLE_Full = np.load(DatFolder+'Full_MK_HCP_MLE.npy',allow_pickle=True)
    AK_MLE_Full = np.load(DatFolder+'Full_AK_HCP_MLE.npy',allow_pickle=True)
    RK_MLE_Full = np.load(DatFolder+'Full_RK_HCP_MLE.npy',allow_pickle=True)
else:
    MK_MLE_Full = []
    AK_MLE_Full = []
    RK_MLE_Full = []
    # TD, gTabsF and axial_middles are 0-indexed lists with one entry per patient.
    for kk in tqdm(range(len(TD))):
        gtab_kk = gTabsF[kk]
        z_mid = axial_middles[kk]
        ArrShape = TD[kk][:,:,z_mid,0].shape
        NoiseEst = np.zeros(list(ArrShape) + [22])
        torch.manual_seed(10)
        for i in tqdm(range(ArrShape[0])):
            for j in range(ArrShape[1]):
                # Background voxel (all-zero b=1000 shell): keep its parameters at 0.
                if np.sum(TD[kk][i,j,z_mid,:69], axis=-1) == 0:
                    continue
                tObs = TD[kk][i,j,z_mid,:]
                # A least-squares fit seeds the bounded Rician NLL minimisation.
                LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze(),np.log(tObs[gtab_kk.bvals==0].mean())]), args=(gtab_kk, tObs)).x
                res = minimize(
                        rician_nll_DKI,
                        np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                        args=(tObs, gtab_kk, np.exp(LS_x[-1])),
                        options={'disp':0,'maxfun': 100000,   # raise the cap on f+g evaluations
                            'maxiter': 20000,                 # raise the cap on iterations
                            'gtol': 1e-6,
                            'ftol': 1e-6,},
                        bounds=bounds
                    )
                NoiseEst[i,j] = res.x
        NoiseEstInv = np.zeros_like(NoiseEst)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
        # Clip negative diffusion-tensor eigenvalues.
        NoiseEst2 =  np.zeros_like(NoiseEst)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
        MK_MLEFull  = np.zeros(ArrShape)
        AK_MLEFull  = np.zeros(ArrShape)
        RK_MLEFull  = np.zeros(ArrShape)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
                MK_MLEFull[i,j] = Metrics[0]
                AK_MLEFull[i,j] = Metrics[1]
                RK_MLEFull[i,j] = Metrics[2]
        MK_MLE_Full.append(MK_MLEFull)
        AK_MLE_Full.append(AK_MLEFull)
        RK_MLE_Full.append(RK_MLEFull)

In [None]:
# Multi-subject Rician MLE fits with the reduced acquisition (selected_indices7
# volumes, gTabs7 tables). MK/AK/RK maps of each patient's middle axial slice
# are loaded from the Min_*_HCP_MLE.npy caches when present, otherwise computed.
if os.path.exists(DatFolder+'Min_MK_HCP_MLE.npy'):
    MK_MLE_Min = np.load(DatFolder+'Min_MK_HCP_MLE.npy',allow_pickle=True)
    AK_MLE_Min = np.load(DatFolder+'Min_AK_HCP_MLE.npy',allow_pickle=True)
    RK_MLE_Min = np.load(DatFolder+'Min_RK_HCP_MLE.npy',allow_pickle=True)
else:
    MK_MLE_Min = []
    AK_MLE_Min = []
    RK_MLE_Min = []
    # TD, gTabs7 and axial_middles are 0-indexed lists with one entry per patient.
    for kk in tqdm(range(len(TD))):
        # Gradient table that matches the reduced volumes in tObs.
        gtab_kk = gTabs7[kk]
        z_mid = axial_middles[kk]
        ArrShape = TD[kk][:,:,z_mid,0].shape
        NoiseEst = np.zeros(list(ArrShape) + [22])
        torch.manual_seed(10)
        for i in tqdm(range(ArrShape[0])):
            for j in range(ArrShape[1]):
                # Background voxel (all-zero b=1000 shell): keep its parameters at 0.
                if np.sum(TD[kk][i,j,z_mid,:69], axis=-1) == 0:
                    continue
                tObs = TD[kk][i,j,z_mid,selected_indices7]
                # A least-squares fit seeds the bounded Rician NLL minimisation.
                LS_x = least_squares(residuals, x0=np.hstack([invertComputeDTI(vals_to_mat(DT_guess.squeeze())),KT_guess.squeeze(),np.log(tObs[gtab_kk.bvals==0].mean())]), args=(gtab_kk, tObs)).x
                res = minimize(
                        rician_nll_DKI,
                        np.hstack([LS_x[:6],LS_x[6:21],np.log(10)]),
                        args=(tObs, gtab_kk, np.exp(LS_x[-1])),
                        options={'disp':0,'maxfun': 100000,   # raise the cap on f+g evaluations
                            'maxiter': 20000,                 # raise the cap on iterations
                            'gtol': 1e-6,
                            'ftol': 1e-6,},
                        bounds=bounds
                    )
                NoiseEst[i,j] = res.x
        NoiseEstInv = np.zeros_like(NoiseEst)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
        # Clip negative diffusion-tensor eigenvalues.
        NoiseEst2 =  np.zeros_like(NoiseEst)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
        MK_MLEMin  = np.zeros(ArrShape)
        AK_MLEMin  = np.zeros(ArrShape)
        RK_MLEMin  = np.zeros(ArrShape)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
                MK_MLEMin[i,j] = Metrics[0]
                AK_MLEMin[i,j] = Metrics[1]
                RK_MLEMin[i,j] = Metrics[2]
        MK_MLE_Min.append(MK_MLEMin)
        AK_MLE_Min.append(AK_MLEMin)
        RK_MLE_Min.append(RK_MLEMin)

In [None]:
# Full-protocol Bayesian fit: for every subject kk, initialise each non-background
# voxel of the middle axial slice by least squares, refine with
# vb_gauss_one_voxel, and derive MK/AK/RK maps. Cached maps are loaded from
# Full_*_HCP_Bay.npy when present.
if os.path.exists(DatFolder+'Full_MK_HCP_Bay.npy'):
    MK_Bay_Full = np.load(DatFolder+'Full_MK_HCP_Bay.npy',allow_pickle=True)
    AK_Bay_Full = np.load(DatFolder+'Full_AK_HCP_Bay.npy',allow_pickle=True)
    RK_Bay_Full = np.load(DatFolder+'Full_RK_HCP_Bay.npy',allow_pickle=True)
else:
    MK_Bay_Full = []
    AK_Bay_Full = []
    RK_Bay_Full = []
    # Loop-invariant parts of the least-squares starting point.
    DT_x0 = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_x0 = KT_guess.squeeze()
    for kk in tqdm(range(1,33)):
        SliceDat = TD[kk][:,:,axial_middles[kk]]
        ArrShape = SliceDat.shape[:2]
        NoiseEst = np.zeros(list(ArrShape) + [22])
        torch.manual_seed(10)
        for i in tqdm(range(ArrShape[0])):
            for j in range(ArrShape[1]):
                # Skip background voxels: all-zero signal in THIS voxel only.
                if np.sum(SliceDat[i,j,:]) == 0:
                    continue
                tObs = SliceDat[i,j,:]
                LS_x = least_squares(residuals, x0=np.hstack([DT_x0,KT_x0,np.log(100)]), args=(gtabsF[kk], tObs)).x
                # NOTE(review): the LS fit uses gtabsF[kk] but the VB fit uses
                # gtabHCP7 (also used by the min-protocol cell); confirm it matches
                # the full-protocol tObs.
                try:
                    m, _, _ = vb_gauss_one_voxel(tObs, gtabHCP7, LS_x, mu0=mu0, V0=V0)
                except Exception:
                    # Best-effort: a failed fit leaves this voxel's estimate at zero.
                    m = 0
                NoiseEst[i,j] = m
        NoiseEstInv = np.zeros_like(NoiseEst)
        NoiseEst2 = np.zeros_like(NoiseEst)
        MK_BayFull = np.zeros(ArrShape)
        AK_BayFull = np.zeros(ArrShape)
        RK_BayFull = np.zeros(ArrShape)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                # Map the 6 optimiser-space DT parameters through ComputeDTI
                # (paired with invertComputeDTI used for x0); keep KT and last entry.
                NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
                # Pass the tensor through clip_negative_eigenvalues before the metrics.
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
                Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
                MK_BayFull[i,j] = Metrics[0]
                AK_BayFull[i,j] = Metrics[1]
                RK_BayFull[i,j] = Metrics[2]
        MK_Bay_Full.append(MK_BayFull)
        AK_Bay_Full.append(AK_BayFull)
        RK_Bay_Full.append(RK_BayFull)

In [None]:
# Minimal-protocol Bayesian fit: for every subject kk, initialise each
# non-background voxel of the middle axial slice by least squares on the
# selected_indices7 measurements, refine with vb_gauss_one_voxel, and derive
# MK/AK/RK maps. Cached maps are loaded from Min_*_HCP_Bay.npy when present.
if os.path.exists(DatFolder+'Min_MK_HCP_Bay.npy'):
    MK_Bay_Min = np.load(DatFolder+'Min_MK_HCP_Bay.npy',allow_pickle=True)
    AK_Bay_Min = np.load(DatFolder+'Min_AK_HCP_Bay.npy',allow_pickle=True)
    RK_Bay_Min = np.load(DatFolder+'Min_RK_HCP_Bay.npy',allow_pickle=True)
else:
    MK_Bay_Min = []
    AK_Bay_Min = []
    RK_Bay_Min = []
    # Loop-invariant parts of the least-squares starting point.
    DT_x0 = invertComputeDTI(vals_to_mat(DT_guess.squeeze()))
    KT_x0 = KT_guess.squeeze()
    for kk in tqdm(range(1,33)):
        SliceDat = TD[kk][:,:,axial_middles[kk]]
        ArrShape = SliceDat.shape[:2]
        NoiseEst = np.zeros(list(ArrShape) + [22])
        torch.manual_seed(10)
        for i in tqdm(range(ArrShape[0])):
            for j in range(ArrShape[1]):
                # Skip background voxels: all-zero signal in THIS voxel only.
                if np.sum(SliceDat[i,j,:]) == 0:
                    continue
                tObs = SliceDat[i,j,selected_indices7]
                LS_x = least_squares(residuals, x0=np.hstack([DT_x0,KT_x0,np.log(100)]), args=(gtabs7[kk], tObs)).x
                # NOTE(review): the LS fit uses gtabs7[kk] but the VB fit uses
                # gtabHCP7 (also used by the full-protocol cell); confirm it matches
                # the selected_indices7 subset.
                try:
                    m, _, _ = vb_gauss_one_voxel(tObs, gtabHCP7, LS_x, mu0=mu0, V0=V0)
                except Exception:
                    # Best-effort: a failed fit leaves this voxel's estimate at zero.
                    m = 0
                # Store this voxel's VB estimate (not a stale optimiser result).
                NoiseEst[i,j] = m
        NoiseEstInv = np.zeros_like(NoiseEst)
        NoiseEst2 = np.zeros_like(NoiseEst)
        MK_BayMin = np.zeros(ArrShape)
        AK_BayMin = np.zeros(ArrShape)
        RK_BayMin = np.zeros(ArrShape)
        for i in range(ArrShape[0]):
            for j in range(ArrShape[1]):
                # Map the 6 optimiser-space DT parameters through ComputeDTI
                # (paired with invertComputeDTI used for x0); keep KT and last entry.
                NoiseEstInv[i,j] = np.hstack([mat_to_vals(ComputeDTI(NoiseEst[i,j,:6])),NoiseEst[i,j,6:21],NoiseEst[i,j,-1]])
                # Pass the tensor through clip_negative_eigenvalues before the metrics.
                NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEstInv[i,j]))),NoiseEstInv[i,j,6:]])
                Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
                MK_BayMin[i,j] = Metrics[0]
                AK_BayMin[i,j] = Metrics[1]
                RK_BayMin[i,j] = Metrics[2]
        MK_Bay_Min.append(MK_BayMin)
        AK_Bay_Min.append(AK_BayMin)
        RK_Bay_Min.append(RK_BayMin)

In [None]:
# Masked local SSIM between the minimal- and full-protocol MK maps of each of
# the 32 subjects, one list per fitting method (no smoothing for MK).
def _mk_pair_ssim(map_min, map_full, mask):
    """Return masked_local_ssim of two MK maps (window 7, data range 1)."""
    return masked_local_ssim(map_min, map_full, mask, win_size=7, dat_range=1)

SSIM_bay = [
    _mk_pair_ssim(MK_Bay_Min[s].astype(np.float32), MK_Bay_Full[s].astype(np.float32), masks[s])
    for s in range(32)
]
SSIM_MLE = [
    _mk_pair_ssim(MK_MLE_Min[s].astype(np.float32), MK_MLE_Full[s].astype(np.float32), masks[s])
    for s in range(32)
]
SSIM_NL = [_mk_pair_ssim(MKMinNLArr[s], MKFullNLArr[s], masks[s]) for s in range(32)]
SSIM_WL = [_mk_pair_ssim(MKMinWLArr[s], MKFullWLArr[s], masks[s]) for s in range(32)]
SSIM_SBI = [_mk_pair_ssim(MKMinArr[s], MKFullArr[s], masks[s]) for s in range(32)]

In [None]:
Save: bool = False  # set True to write the SSIM box-plot figures to FigLoc as PDF

In [None]:
# Box plot of per-subject MK SSIM (minimal vs. full protocol), one box per method.
fig, ax = plt.subplots(figsize=(3.2,4.8))

# (SSIM values, box colour, fill colour) per method, in x-axis order 1..5.
MethodStyles = [
    (SSIM_SBI, 'mediumturquoise', 'paleturquoise'),
    (SSIM_NL, 'sandybrown', 'peachpuff'),
    (SSIM_WL, WLSFit, 'lightsalmon'),
    (SSIM_MLE, MLEFit, 'lightsteelblue'),
    (SSIM_bay, BayFit, 'lavender'),
]
for pos, (vals, edge_col, fill_col) in enumerate(MethodStyles, start=1):
    y_data = np.array(vals)
    g_pos = np.array([pos])
    colors = [edge_col]
    colors2 = [fill_col]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Dashed reference line at SSIM = 0.66, drawn once on top of all boxes.
ax.axhline(0.66, lw=3, ls='--', c='k')

plt.yticks([0,0.2,0.4,0.6,0.8])
plt.xticks([1,2,3,4,5],['SBI','NLLS','WLS','MLE','Bay.'],fontsize=32,rotation=90)

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_MK_Other.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Masked local SSIM between the minimal- and full-protocol AK maps of each of
# the 32 subjects, one list per fitting method. Both maps of every pair are
# smoothed with a Gaussian (sigma=0.5) before the comparison.
def _smoothed_ssim(map_min, map_full, mask):
    """Return masked_local_ssim of two maps after smoothing both (sigma=0.5)."""
    return masked_local_ssim(
        gaussian_filter(map_min, sigma=0.5),
        gaussian_filter(map_full, sigma=0.5),
        mask, win_size=7, dat_range=1,
    )

SSIM_bay = [
    _smoothed_ssim(AK_Bay_Min[s].astype(np.float32), AK_Bay_Full[s].astype(np.float32), masks[s])
    for s in range(32)
]
SSIM_MLE = [
    _smoothed_ssim(AK_MLE_Min[s].astype(np.float32), AK_MLE_Full[s].astype(np.float32), masks[s])
    for s in range(32)
]
SSIM_NL = [_smoothed_ssim(AKMinNLArr[s], AKFullNLArr[s], masks[s]) for s in range(32)]
SSIM_WL = [_smoothed_ssim(AKMinWLArr[s], AKFullWLArr[s], masks[s]) for s in range(32)]
SSIM_SBI = [_smoothed_ssim(AKMinArr[s], AKFullArr[s], masks[s]) for s in range(32)]

In [None]:
# Box plot of per-subject AK SSIM (minimal vs. full protocol), one box per method.
fig, ax = plt.subplots(figsize=(3.2,4.8))

# (SSIM values, box colour, fill colour) per method, in x-axis order 1..5.
MethodStyles = [
    (SSIM_SBI, 'mediumturquoise', 'paleturquoise'),
    (SSIM_NL, 'sandybrown', 'peachpuff'),
    (SSIM_WL, WLSFit, 'lightsalmon'),
    (SSIM_MLE, MLEFit, 'lightsteelblue'),
    (SSIM_bay, BayFit, 'lavender'),
]
for pos, (vals, edge_col, fill_col) in enumerate(MethodStyles, start=1):
    y_data = np.array(vals)
    g_pos = np.array([pos])
    colors = [edge_col]
    colors2 = [fill_col]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Dashed reference line at SSIM = 0.66, drawn once on top of all boxes.
ax.axhline(0.66, lw=3, ls='--', c='k')

plt.yticks([0,0.2,0.4,0.6,0.8])
plt.ylim([0,1])
plt.xticks([1,2,3,4,5],['SBI','NLLS','WLS','MLE','Bay.'],fontsize=32,rotation=90)

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_AK_Other.pdf',format='PDF',transparent=True,bbox_inches='tight')

In [None]:
# Masked local SSIM between the minimal- and full-protocol RK maps of each of
# the 32 subjects, one list per fitting method. Both maps of every pair are
# smoothed with a Gaussian (sigma=0.5), as in the AK comparison. Smoothing
# only one side of a pair makes the maps differ systematically and biases
# SSIM downward.
def _smoothed_ssim(map_min, map_full, mask):
    """Return masked_local_ssim of two maps after smoothing both (sigma=0.5)."""
    return masked_local_ssim(
        gaussian_filter(map_min, sigma=0.5),
        gaussian_filter(map_full, sigma=0.5),
        mask, win_size=7, dat_range=1,
    )

SSIM_bay = [
    _smoothed_ssim(RK_Bay_Min[s].astype(np.float32), RK_Bay_Full[s].astype(np.float32), masks[s])
    for s in range(32)
]
SSIM_MLE = [
    _smoothed_ssim(RK_MLE_Min[s].astype(np.float32), RK_MLE_Full[s].astype(np.float32), masks[s])
    for s in range(32)
]
SSIM_NL = [_smoothed_ssim(RKMinNLArr[s], RKFullNLArr[s], masks[s]) for s in range(32)]
SSIM_WL = [_smoothed_ssim(RKMinWLArr[s], RKFullWLArr[s], masks[s]) for s in range(32)]
SSIM_SBI = [_smoothed_ssim(RKMinArr[s], RKFullArr[s], masks[s]) for s in range(32)]

In [None]:
# Box plot of per-subject RK SSIM (minimal vs. full protocol), one box per method.
fig, ax = plt.subplots(figsize=(3.2,4.8))

# (SSIM values, box colour, fill colour) per method, in x-axis order 1..5.
MethodStyles = [
    (SSIM_SBI, 'mediumturquoise', 'paleturquoise'),
    (SSIM_NL, 'sandybrown', 'peachpuff'),
    (SSIM_WL, WLSFit, 'lightsalmon'),
    (SSIM_MLE, MLEFit, 'lightsteelblue'),
    (SSIM_bay, BayFit, 'lavender'),
]
for pos, (vals, edge_col, fill_col) in enumerate(MethodStyles, start=1):
    y_data = np.array(vals)
    g_pos = np.array([pos])
    colors = [edge_col]
    colors2 = [fill_col]
    BoxPlots(y_data,g_pos,colors,colors2,ax,widths=0.2,scatter=True)

# Dashed reference line at SSIM = 0.66, drawn once on top of all boxes.
ax.axhline(0.66, lw=3, ls='--', c='k')

plt.yticks([0,0.2,0.4,0.6,0.8])
plt.ylim([-0.1,1])
plt.xticks([1,2,3,4,5],['SBI','NLLS','WLS','MLE','Bay.'],fontsize=32,rotation=90)

if Save: plt.savefig(FigLoc+'DKIHCP_SSIM_RK_Other.pdf',format='PDF',transparent=True,bbox_inches='tight')