# Front matter and Imports

## Imports

In [None]:
import numpy as np
import dill as pickle
import os
import tqdm

import scipy.stats as stats
from scipy.spatial.distance import pdist, squareform
from scipy.stats import lognorm,norm

from skimage.metrics import structural_similarity as ssim

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.gridspec import GridSpec
from matplotlib.colors import TwoSlopeNorm
import matplotlib.patches as mpatches
from matplotlib.ticker import FormatStrFormatter

plt.rc('text', usetex=True)
font = {'family' : 'serif',
        'size'   : 14}
plt.rc('font', **font)
plt.rc('ytick', labelsize=24)
plt.rc('xtick', labelsize=24)
plt.rc('text.latex', preamble=r'\usepackage{color}')
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['legend.frameon'] = False

from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.inference.potentials.posterior_based_potential import posterior_estimator_based_potential
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
from sbi.utils import process_prior,BoxUniform

import torch
from torch.distributions import Categorical,Normal
from torch import Tensor


from dipy.sims.voxel import single_tensor
from dipy.data import get_fnames
from dipy.io.gradients import read_bvals_bvecs
from dipy.core.gradients import gradient_table
from dipy.reconst.dti import (decompose_tensor, from_lower_triangular)
from dipy.io.image import load_nifti
from dipy.segment.mask import median_otsu
import dipy.reconst.dti as dti
import dipy.reconst.dki as dki
from dipy.align.reslice import reslice
from dipy.core.sphere import disperse_charges, Sphere, HemiSphere

## Helper Functions

### DTI Functions

In [None]:
def vals_to_mat(dt):
    """Expand six tensor coefficients into a symmetric 3x3 matrix.

    ``dt`` is read in the order (Dxx, Dxy, Dyy, Dxz, Dyz, Dzz); each
    off-diagonal coefficient is mirrored across the diagonal. Entries past
    index 5 are ignored.
    """
    xx, xy, yy, xz, yz, zz = (dt[k] for k in range(6))
    return np.array([[xx, xy, xz],
                     [xy, yy, yz],
                     [xz, yz, zz]], dtype=float)

def mat_to_vals(DTI):
    """Collapse a 3x3 tensor into its six unique coefficients.

    Inverse of ``vals_to_mat``: returns a float array ordered
    (D[0,0], D[0,1], D[1,1], D[0,2], D[1,2], D[2,2]). Only the upper triangle
    (including the diagonal) is read.
    """
    rows = (0, 0, 1, 0, 1, 2)
    cols = (0, 1, 1, 2, 2, 2)
    return np.array([DTI[r, c] for r, c in zip(rows, cols)], dtype=float)

def fill_lower_diag(a):
    """Build a 3x3 lower-triangular matrix from six parameters.

    a[0], a[1], a[2] go on the diagonal; a[3], a[4], a[5] go to positions
    (1,0), (2,0) and (2,1) respectively. Everything above the diagonal is 0.
    """
    out = np.zeros((3, 3), dtype=float)
    out[np.diag_indices(3)] = (a[0], a[1], a[2])
    out[1, 0] = a[3]
    out[2, 0] = a[4]
    out[2, 1] = a[5]
    return out

def ComputeDTI(params):
    """Map six unconstrained parameters to a symmetric PSD tensor L @ L.T.

    ``params`` is laid out as in ``fill_lower_diag``. The diagonal of L is
    made non-negative before forming the product.
    """
    lower = fill_lower_diag(params)
    diag = np.arange(3)
    lower[diag, diag] = np.abs(lower[diag, diag])
    return lower @ lower.T

def ForceLowFA(dt):
    """Pull a tensor's eigenvalues toward their mean to reduce anisotropy.

    Eigenvalues are clipped into ``[m * u, m]``, where ``m`` is their mean and
    ``u`` is one ``np.random.rand()`` draw from NumPy's global RNG. The
    eigenvectors are kept, and the tensor is rebuilt as V diag(w) V^T.
    """
    w, V = np.linalg.eigh(dt)
    m = np.mean(w)
    lo = m * np.random.rand()
    hi = m * 1.0
    w_clipped = np.clip(w, lo, hi)
    return V @ np.diag(w_clipped) @ V.T
    
def FracAni(evals,MD):
    """Fractional anisotropy from eigenvalues and their mean ``MD``.

    FA = sqrt(3 * sum((evals - MD)^2)) / (sqrt(2) * sqrt(sum(evals^2))).
    All-zero eigenvalues give 0/0 (NaN).
    """
    spread = np.sum((evals - MD) ** 2)
    magnitude = np.sum(evals ** 2)
    return np.sqrt(3 * spread) / (np.sqrt(2) * np.sqrt(magnitude))

def clip_negative_eigenvalues(matrix):
    """Raise every eigenvalue below 1e-5 (including negatives) to 1e-5.

    Uses the general decomposition M = V diag(w) V^-1 from ``np.linalg.eig``
    and rebuilds the matrix with the floored eigenvalues.
    """
    w, V = np.linalg.eig(matrix)
    w_floored = np.maximum(w, 1e-5)
    return V @ np.diag(w_floored) @ np.linalg.inv(V)


### DKI Functions

In [None]:
def FitDT(Dat,seed=1):
    """Fit marginal distributions to two diffusion-tensor coefficient columns.

    Column 0 of ``Dat`` gets a 3-parameter log-normal MLE fit; column 1 gets a
    normal MLE fit. NumPy's global RNG is reseeded with ``seed`` as a side
    effect.

    Returns
    -------
    (scipy frozen lognorm, scipy frozen norm)
    """
    np.random.seed(seed)
    s, ln_loc, ln_scale = lognorm.fit(Dat[:, 0])
    mu, sigma = norm.fit(Dat[:, 1])
    first = stats.lognorm(s, loc=ln_loc, scale=ln_scale)
    second = stats.norm(loc=mu, scale=sigma)
    return first, second

def FitKT(Dat,seed=1):
    """Fit marginal distributions to four representative kurtosis columns.

    Columns 0, 3, 9 and 12 of ``Dat`` are fitted with log-normal, normal,
    log-normal and normal MLE fits respectively. NumPy's global RNG is
    reseeded with ``seed`` as a side effect.

    Returns
    -------
    tuple of four frozen scipy distributions, in column order (0, 3, 9, 12).
    """
    np.random.seed(seed)
    plan = ((0, lognorm), (3, norm), (9, lognorm), (12, norm))
    fitted = []
    for col, family in plan:
        fit_params = family.fit(Dat[:, col])
        fitted.append(family(*fit_params))
    return tuple(fitted)

def GenDTKT(DT_Fits,KT_Fits,seed,size):
    """Draw ``size`` synthetic DT (6-col) and KT (15-col) coefficient rows.

    NumPy's global RNG is reseeded with ``seed``, then each column is filled
    with ``.rvs(size)`` from its assigned fitted distribution:

    * DT columns 0, 2, 5  <- DT_Fits[0];  DT columns 1, 3, 4  <- DT_Fits[1]
    * KT columns 0-2 <- KT_Fits[0]; 3-8 <- KT_Fits[1];
      9-11 <- KT_Fits[2]; 12-14 <- KT_Fits[3]

    The columns are drawn in exactly that order, so results are reproducible
    for a given seed.
    """
    np.random.seed(seed)
    DT = np.zeros([size, 6])
    KT = np.zeros([size, 15])

    dt_plan = ((DT_Fits[0], (0, 2, 5)),
               (DT_Fits[1], (1, 3, 4)))
    for dist, columns in dt_plan:
        for c in columns:
            DT[:, c] = dist.rvs(size)

    kt_plan = ((KT_Fits[0], range(0, 3)),
               (KT_Fits[1], range(3, 9)),
               (KT_Fits[2], range(9, 12)),
               (KT_Fits[3], range(12, 15)))
    for dist, columns in kt_plan:
        for c in columns:
            KT[:, c] = dist.rvs(size)

    return DT, KT
    
def DKIMetrics(dt,kt,analytical=True):
    """Compute scalar DKI metrics (MK, AK, RK, MKT, KFA) for one voxel.

    Parameters
    ----------
    dt : ndarray
        Diffusion tensor, either a 3x3 matrix or its six coefficients in the
        (Dxx, Dxy, Dyy, Dxz, Dyz, Dzz) order expanded by ``vals_to_mat``.
    kt : ndarray
        Kurtosis-tensor elements, appended after the eigen-system to form the
        dipy DKI parameter vector (15 elements expected — see
        ``kurtosis_fractional_anisotropy_test``).
    analytical : bool, optional
        Forwarded to ``dki.mean_kurtosis``, ``dki.axial_kurtosis`` and
        ``dki.radial_kurtosis``.

    Returns
    -------
    tuple
        ``(mk, ak, rk, mkt, kfa)``.
    """
    if dt.ndim == 1:
        dt = vals_to_mat(dt)
    evals, evecs = np.linalg.eigh(dt)
    # Reorder eigenpairs by descending eigenvalue, keeping the eigenvector
    # columns aligned with their eigenvalues.
    order = np.argsort(evals)[::-1]
    evals = evals[order]
    evecs = evecs[:, order]

    # 3 eigenvalues + 9 eigenvector-matrix entries (row by row) + kt.
    params = np.concatenate([evals, np.hstack(evecs), kt])

    # Same kurtosis bounds for every dipy call below.
    kurt_bounds = dict(min_kurtosis=-3.0 / 7, max_kurtosis=np.inf)
    mk = dki.mean_kurtosis(params, analytical=analytical, **kurt_bounds)
    ak = dki.axial_kurtosis(params, analytical=analytical, **kurt_bounds)
    rk = dki.radial_kurtosis(params, analytical=analytical, **kurt_bounds)
    mkt = dki.mean_kurtosis_tensor(params, **kurt_bounds)
    kfa = kurtosis_fractional_anisotropy_test(params)

    return mk, ak, rk, mkt, kfa
    
def kurtosis_fractional_anisotropy_test(dki_params):
    r"""Compute the anisotropy of the kurtosis tensor (KFA).

    See :footcite:p:`Glenn2015` and :footcite:p:`NetoHenriques2021` for further
    details about the method.

    Parameters
    ----------
    dki_params : ndarray (x, y, z, 27) or (n, 27) or (27,)
        All parameters estimated from the diffusion kurtosis model.
        Parameters are ordered as follows:
            1) Three diffusion tensor's eigenvalues
            2) Three lines of the eigenvector matrix each containing the first,
                second and third coordinates of the eigenvector
            3) Fifteen elements of the kurtosis tensor

    Returns
    -------
    kfa : float or ndarray
        Kurtosis fractional anisotropy, with shape ``dki_params.shape[:-1]``
        (a NumPy scalar for a single 27-element vector). Voxels whose kurtosis
        tensor has zero Frobenius norm get KFA = 0 instead of 0/0 = NaN.

    Notes
    -----
    The KFA is defined as :footcite:p:`Glenn2015`:

    .. math::

         KFA \equiv
         \frac{||\mathbf{W} - MKT \mathbf{I}^{(4)}||_F}{||\mathbf{W}||_F}

    where $W$ is the kurtosis tensor, MKT the kurtosis tensor mean, $I^{(4)}$ is
    the fully symmetric rank 4 isotropic tensor and $||...||_F$ is the tensor's
    Frobenius norm :footcite:p:`Glenn2015`.

    References
    ----------
    .. footbibliography::

    """
    Wxxxx = dki_params[..., 12]
    Wyyyy = dki_params[..., 13]
    Wzzzz = dki_params[..., 14]
    Wxxxy = dki_params[..., 15]
    Wxxxz = dki_params[..., 16]
    Wxyyy = dki_params[..., 17]
    Wyyyz = dki_params[..., 18]
    Wxzzz = dki_params[..., 19]
    Wyzzz = dki_params[..., 20]
    Wxxyy = dki_params[..., 21]
    Wxxzz = dki_params[..., 22]
    Wyyzz = dki_params[..., 23]
    Wxxyz = dki_params[..., 24]
    Wxyyz = dki_params[..., 25]
    Wxyzz = dki_params[..., 26]

    # Mean of the kurtosis tensor (MKT).
    W = 1.0 / 5.0 * (Wxxxx + Wyyyy + Wzzzz + 2 * Wxxyy + 2 * Wxxzz + 2 * Wyyzz)
    # Squared Frobenius norm of (W - MKT * I4): the equation's numerator.
    A = (
        (Wxxxx - W) ** 2
        + (Wyyyy - W) ** 2
        + (Wzzzz - W) ** 2
        + 4 * (Wxxxy**2 + Wxxxz**2 + Wxyyy**2 + Wyyyz**2 + Wxzzz**2 + Wyzzz**2)
        + 6 * ((Wxxyy - W / 3) ** 2 + (Wxxzz - W / 3) ** 2 + (Wyyzz - W / 3) ** 2)
        + 12 * (Wxxyz**2 + Wxyyz**2 + Wxyzz**2)
    )
    # Squared Frobenius norm of W: the equation's denominator.
    B = (
        Wxxxx**2
        + Wyyyy**2
        + Wzzzz**2
        + 4 * (Wxxxy**2 + Wxxxz**2 + Wxyyy**2 + Wyyyz**2 + Wxzzz**2 + Wyzzz**2)
        + 6 * (Wxxyy**2 + Wxxzz**2 + Wyyzz**2)
        + 12 * (Wxxyz**2 + Wxyyz**2 + Wxyzz**2)
    )

    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    # Avoid the 0/0 singularity: where B == 0 (all-zero kurtosis tensor) the
    # ratio stays 0, so KFA = 0 there instead of NaN.
    ratio = np.zeros(np.broadcast(A, B).shape)
    np.divide(A, B, out=ratio, where=B > 0)
    KFA = np.sqrt(ratio)

    # `[()]` unwraps a 0-d result to a NumPy scalar (single-voxel input) and
    # is a no-op view for n-d results.
    return KFA[()]

### Simulators

In [None]:
def CustomSimulator(Mat,gtab,S0,snr=None):
    """Simulate a single-tensor DWI signal for tensor ``Mat`` on ``gtab``.

    The tensor's eigen-system (from ``np.linalg.eigh``) is passed to dipy's
    ``single_tensor``. When ``snr`` is given, noise is added via ``AddNoise``;
    ``snr=None`` returns the noise-free signal.
    """
    w, V = np.linalg.eigh(Mat)
    clean = single_tensor(gtab, S0=S0, evals=w, evecs=V)
    return clean if snr is None else AddNoise(clean, S0, snr)

def Simulator(bvals,bvecs,S0,params,SNR):
    """Simulate a single-tensor signal from six raw tensor parameters.

    ``params`` is mapped to a PSD tensor by ``ComputeDTI``; ``SNR`` is
    forwarded as ``snr`` (None yields a noise-free signal).
    """
    tensor = ComputeDTI(params)
    gtab = gradient_table(bvals, bvecs)
    return CustomSimulator(tensor, gtab, S0, SNR)


def GenRicciNoise(signal,S0,snr):
    """Corrupt ``signal`` with Rician noise of per-channel std ``S0 / snr``.

    Two independent zero-mean Gaussian arrays are drawn from NumPy's global
    RNG: the first is added to the signal, the second is the quadrature
    channel. The magnitude of the result is returned.
    """
    sigma = S0 / snr
    real_part = signal + np.random.normal(0, sigma, size=signal.shape)
    imag_part = np.random.normal(0, sigma, size=signal.shape)
    return np.sqrt(real_part ** 2 + imag_part ** 2)


def AddNoise(signal,S0,snr):
    """Return a noisy copy of ``signal``; currently Rician via ``GenRicciNoise``."""
    noisy = GenRicciNoise(signal=signal, S0=S0, snr=snr)
    return noisy

def CustomDKISimulator(dt,kt,gtab,S0,snr=None):
    """Simulate a DKI signal from a diffusion tensor and kurtosis elements.

    ``dt`` may be a 3x3 matrix or six coefficients (expanded by
    ``vals_to_mat``). The parameter vector is [eigenvalues, eigenvector matrix
    flattened row by row, kt], passed to ``dki.dki_prediction``. When ``snr``
    is given, noise is added via ``AddNoise``.
    """
    tensor = vals_to_mat(dt) if dt.ndim == 1 else dt
    w, V = np.linalg.eigh(tensor)
    model_params = np.concatenate([w, V.ravel(), kt])
    clean = dki.dki_prediction(model_params, gtab, S0)
    if snr is None:
        return clean
    return AddNoise(clean, S0, snr)

### SBI Priors

In [None]:
class DTIPrior:
    """Box-uniform prior over six tensor parameters.

    The first three components are uniform on [lower_abs, upper_abs] and the
    last three on [lower_rest, upper_rest]. ``sample`` returns a torch tensor,
    or a NumPy array when ``return_numpy`` is True.
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor,
                       lower_rest: Tensor, upper_rest: Tensor,
                       return_numpy: bool = False):
        ones = torch.ones(3)
        self.dist_abs = BoxUniform(low=lower_abs * ones, high=upper_abs * ones)
        self.dist_rest = BoxUniform(low=lower_rest * ones, high=upper_rest * ones)
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples laid out as [abs_0..abs_2, rest_0..rest_2] on the last axis."""
        pieces = [self.dist_abs.sample(sample_shape),
                  self.dist_rest.sample(sample_shape)]
        stack = np.hstack if self.return_numpy else torch.hstack
        return stack(pieces)

    def log_prob(self, values):
        """Joint log-density of a 2-D batch of parameter rows."""
        if self.return_numpy:
            values = torch.as_tensor(values)
        first = self.dist_abs.log_prob(values[:, :3])
        second = self.dist_rest.log_prob(values[:, 3:])
        return first + second

class DTIPriorS0:
    """Box-uniform prior over six tensor parameters plus the S0 signal scale.

    Sample layout on the last axis:
    [abs_0, abs_1, abs_2, rest_0, rest_1, rest_2, S0].
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor,
                       lower_rest: Tensor, upper_rest: Tensor,
                       lower_S0: Tensor, upper_S0: Tensor,
                       return_numpy: bool = False):
        self.dist_abs = BoxUniform(low= lower_abs* torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest *torch.ones(3))
        # One-dimensional box: its event shape is (1,).
        self.dist_S0 = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; returns a NumPy array when ``return_numpy`` is set."""
        abc  = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)
        S0   = self.dist_S0.sample(sample_shape)

        if self.return_numpy:
            params = np.hstack([abc,rest,S0])
        else:
            params = torch.hstack([abc,rest,S0])

        return params

    def log_prob(self, values):
        """Per-row joint log-density of ``values`` (layout as in ``sample``).

        Accepts a single 7-vector or any batch ``(..., 7)``. The S0 column is
        sliced with ``-1:`` so it keeps a trailing axis of size 1 matching
        ``dist_S0``'s event shape. With a bare ``[:, -1]`` column, that
        one-dimensional box would treat the whole batch as one event and add
        a single summed log-prob to every row.
        """
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc  = values[..., :3]
        rest = values[..., 3:-1]
        S0   = values[..., -1:]

        log_prob_abc  = self.dist_abs.log_prob(abc)
        log_prob_rest = self.dist_rest.log_prob(rest)
        log_prob_S0   = self.dist_S0.log_prob(S0)
        return log_prob_abc+log_prob_rest+log_prob_S0

class DTIPriorS0Direc:
    """Prior over six tensor parameters, S0, and a categorical direction index.

    Sample layout on the last axis:
    [abs_0..abs_2, rest_0..rest_2, S0, direction], where the direction is an
    index in {0, ..., 4} drawn uniformly by ``direction_choice``.
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor,
                       lower_rest: Tensor, upper_rest: Tensor,
                       lower_S0: Tensor, upper_S0: Tensor,
                       return_numpy: bool = False):
        self.dist_abs = BoxUniform(low= lower_abs* torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest *torch.ones(3))
        # One-dimensional box: its event shape is (1,).
        self.dist_S0 = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
        # Batch shape (1,), so samples carry a trailing axis of size 1.
        self.direction_choice = Categorical(probs=torch.ones(1, 5))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; returns a NumPy array when ``return_numpy`` is set."""
        abc  = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)
        S0   = self.dist_S0.sample(sample_shape)
        direc = self.direction_choice.sample(sample_shape)

        if self.return_numpy:
            params = np.hstack([abc,rest,S0,direc])
        else:
            params = torch.hstack([abc,rest,S0,direc])

        return params

    def log_prob(self, values):
        """Per-row joint log-density of ``values`` (layout as in ``sample``).

        The S0 column is sliced with ``-2:-1`` so it keeps the size-1 event
        axis expected by ``dist_S0``. A bare column would make that box sum
        the log-probs of the whole batch into one scalar added to every row.
        The direction column stays 1-D because the categorical has no event
        axis.
        """
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc   = values[..., :3]
        rest  = values[..., 3:-2]
        S0    = values[..., -2:-1]
        direc = values[..., -1]

        log_prob_abc   = self.dist_abs.log_prob(abc)
        log_prob_rest  = self.dist_rest.log_prob(rest)
        log_prob_S0    = self.dist_S0.log_prob(S0)
        log_prob_direc = self.direction_choice.log_prob(direc)
        return log_prob_abc+log_prob_rest+log_prob_S0+log_prob_direc

class DTIPriorS0Noise:
    """Box-uniform prior over six tensor parameters, S0, and a noise level.

    Sample layout on the last axis:
    [abs_0..abs_2, rest_0..rest_2, S0, noise].
    """

    def __init__(self, lower_abs : Tensor, upper_abs : Tensor,
                       lower_rest: Tensor, upper_rest: Tensor,
                       lower_S0: Tensor, upper_S0: Tensor,
                       lower_noise: Tensor, upper_noise: Tensor,
                       return_numpy: bool = False):
        self.dist_abs = BoxUniform(low= lower_abs* torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest *torch.ones(3))
        # One-dimensional boxes: event shape (1,) each.
        self.dist_S0 = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
        self.dist_noise = BoxUniform(low=torch.tensor([lower_noise]), high=torch.tensor([upper_noise]))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples; returns a NumPy array when ``return_numpy`` is set."""
        abc     = self.dist_abs.sample(sample_shape)
        rest    = self.dist_rest.sample(sample_shape)
        S0      = self.dist_S0.sample(sample_shape)
        noise   = self.dist_noise.sample(sample_shape)

        if self.return_numpy:
            params = np.hstack([abc,rest,S0,noise])
        else:
            params = torch.hstack([abc,rest,S0,noise])

        return params

    def log_prob(self, values):
        """Per-row joint log-density of ``values`` (layout as in ``sample``).

        S0 and noise are sliced as ``-2:-1`` and ``-1:`` so each keeps the
        size-1 event axis of its one-dimensional box. Bare columns would make
        each box sum the log-probs of the whole batch into one scalar added
        to every row.
        """
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc     = values[..., :3]
        rest    = values[..., 3:-2]
        S0      = values[..., -2:-1]
        noise   = values[..., -1:]

        log_prob_abc   = self.dist_abs.log_prob(abc)
        log_prob_rest  = self.dist_rest.log_prob(rest)
        log_prob_S0    = self.dist_S0.log_prob(S0)
        log_prob_noise = self.dist_noise.log_prob(noise)
        return log_prob_abc+log_prob_rest+log_prob_S0+log_prob_noise

def histogram_mode(data, bins=50):
    """Estimate the mode of ``data`` as the centre of its fullest histogram bin.

    ``bins`` is passed to ``np.histogram``. Ties go to the first (lowest) bin,
    following ``np.argmax``.
    """
    counts, edges = np.histogram(data, bins=bins)
    centres = (edges[:-1] + edges[1:]) / 2
    return centres[np.argmax(counts)]

### Errors

In [None]:
def Errors(Guess,Truth,gtab,signal_true,signal_provided,S0Guess=200):
    """Compare an estimated diffusion tensor ``Guess`` against ``Truth``.

    A single-tensor signal is simulated from ``Guess`` (with ``S0Guess``) on
    ``gtab`` and compared against both the full true signal and the (possibly
    shorter) provided signal. In the latter case, only its first
    ``len(signal_provided)`` samples are used.

    Returns
    -------
    tuple
        (|dMD|, |dFA|, eigenvalue L2 error, Frobenius error,
         full-signal L2 error / N, full-signal correlation,
         provided-signal L2 error / N, provided-signal correlation)
    """
    w_guess_raw, V_guess = np.linalg.eigh(Guess)
    w_true_raw, _ = np.linalg.eigh(Truth)
    w_guess = np.sort(w_guess_raw)
    w_true = np.sort(w_true_raw)

    eig_err = np.linalg.norm(w_guess - w_true)

    md_true = np.mean(w_true)
    md_guess = np.mean(w_guess)
    md_err = abs(md_true - md_guess)

    fa_err = abs(FracAni(w_true, md_true) - FracAni(w_guess, md_guess))

    frob_err = np.linalg.norm(Guess - Truth, 'fro')

    signal_guess = single_tensor(gtab, S0=S0Guess, evals=w_guess_raw, evecs=V_guess)
    n_true = len(signal_true)
    sig_err = np.linalg.norm(signal_true - signal_guess) / n_true
    sig_corr = np.corrcoef(signal_true, signal_guess)[0, 1]

    n_prov = len(signal_provided)
    guess_head = signal_guess[:n_prov]
    prov_err = np.linalg.norm(signal_provided - guess_head) / n_prov
    prov_corr = np.corrcoef(signal_provided, guess_head)[0, 1]

    return md_err, fa_err, eig_err, frob_err, sig_err, sig_corr, prov_err, prov_corr

def ErrorsMDFA(Guess,Truth):
    """Absolute mean-diffusivity and FA differences between two tensors.

    Parameters
    ----------
    Guess, Truth : ndarray
        Symmetric 3x3 diffusion tensors (decomposed with ``np.linalg.eigh``).

    Returns
    -------
    MD : float
        |mean eigenvalue of Truth - mean eigenvalue of Guess|.
    FA : float
        |FracAni(Truth) - FracAni(Guess)|.
    """
    # Only the eigenvalues are needed; sort them for a consistent ordering.
    evals_guess = np.sort(np.linalg.eigh(Guess)[0])
    evals_true = np.sort(np.linalg.eigh(Truth)[0])

    # Mean diffusivity
    mean_true = np.mean(evals_true)
    mean_guess = np.mean(evals_guess)
    MD = abs(mean_true - mean_guess)

    # Fractional anisotropy
    FA_true = FracAni(evals_true, mean_true)
    FA_guess = FracAni(evals_guess, mean_guess)
    FA = abs(FA_true - FA_guess)

    return MD, FA
    
def PercsMDFA(Guess,Truth):
    """Relative mean-diffusivity and FA errors of ``Guess`` w.r.t. ``Truth``.

    Parameters
    ----------
    Guess, Truth : ndarray
        Symmetric 3x3 diffusion tensors (decomposed with ``np.linalg.eigh``).

    Returns
    -------
    MD : float
        |MD_true - MD_guess| / MD_true.
    FA : float
        |FA_true - FA_guess| / FA_true.

    Notes
    -----
    No zero guard is applied. If the true MD or FA is exactly 0, the NumPy
    float division yields inf, or nan when the difference is also 0, and
    emits a RuntimeWarning.
    """
    # Only the eigenvalues are needed; sort them for a consistent ordering.
    evals_guess = np.sort(np.linalg.eigh(Guess)[0])
    evals_true = np.sort(np.linalg.eigh(Truth)[0])

    # Mean diffusivity
    mean_true = np.mean(evals_true)
    mean_guess = np.mean(evals_guess)
    MD = abs(mean_true - mean_guess) / mean_true

    # Fractional anisotropy
    FA_true = FracAni(evals_true, mean_true)
    FA_guess = FracAni(evals_guess, mean_guess)
    FA = abs(FA_true - FA_guess) / FA_true

    return MD, FA


def DKIErrors(GuessDT,GuessKT,TruthDT,TruthKT):
    """Absolute differences of (MK, AK, RK, MKT, KFA) between guess and truth.

    Both metric sets come from ``DKIMetrics`` with ``analytical=False``.
    """
    guess = DKIMetrics(GuessDT, GuessKT, False)
    truth = DKIMetrics(TruthDT, TruthKT, False)
    mk, ak, rk, mkt, kfa = (abs(g - t) for g, t in zip(guess, truth))
    return mk, ak, rk, mkt, kfa

def Percs(GuessDT,GuessKT,TruthDT,TruthKT):
    """Relative errors |guess - truth| / |truth| for (MK, AK, RK, MKT, KFA).

    Both metric sets come from ``DKIMetrics`` with ``analytical=False``.
    """
    guess = DKIMetrics(GuessDT, GuessKT, False)
    truth = DKIMetrics(TruthDT, TruthKT, False)
    mk, ak, rk, mkt, kfa = (abs(g - t) / abs(t) for g, t in zip(guess, truth))
    return mk, ak, rk, mkt, kfa


### Plotting

In [None]:
def viol_plot(A,col,hatch=False,*,z_threshold=1000,**kwargs):
    """Draw one violin per column of ``A`` in colour ``col``.

    Parameters
    ----------
    A : array_like
        2-D data; each column becomes one violin. NaNs are dropped per column.
    col : colour
        Colour for the bodies, bars, mins and maxes (the means are black).
    hatch : bool, optional
        If True, hatch the first violin body with '//'.
    z_threshold : float, keyword-only, optional
        Points with |z-score| >= this value are discarded as outliers. The
        default of 1000 effectively disables filtering.
    **kwargs
        Passed through to ``plt.violinplot``.

    Returns
    -------
    dict
        The artist dictionary returned by ``plt.violinplot``.
    """
    filtered_A = []
    for column in np.transpose(A):
        column = column[~np.isnan(column)]
        abs_z = np.abs(stats.zscore(column))
        # A column with zero spread (constant or single-valued) has NaN
        # z-scores. `abs_z < z_threshold` would drop every such point and
        # hand violinplot an empty dataset, so negate `>=` to keep NaNs.
        keep = ~(abs_z >= z_threshold)
        filtered_A.append(column[keep])

    vp = plt.violinplot(filtered_A, showmeans=True, **kwargs)
    for body in vp['bodies']:
        body.set_facecolor(col)
    for part in ('cbars', 'cmins', 'cmaxes'):
        vp[part].set_color(col)
    vp['cmeans'].set_color('black')
    if hatch:
        vp['bodies'][0].set_hatch('//')
    return vp

def box_plot(data, edge_color, fill_color, hatch=None, linewidth=1.5, **kwargs):
    """Draw a filled box plot after dropping NaNs.

    1-D ``data`` is cleaned as a single array; otherwise each element of
    ``data`` is cleaned separately. Boxes, whiskers, means, medians and caps
    are drawn in ``edge_color``; box faces use ``fill_color`` and, if given,
    ``hatch``. Extra ``kwargs`` go to ``plt.boxplot``.

    Returns
    -------
    dict
        The artist dictionary returned by ``plt.boxplot``.
    """
    if np.ndim(data) == 1:
        cleaned = data[~np.isnan(data)]
    else:
        cleaned = [series[~np.isnan(series)] for series in data]

    bp = plt.boxplot(cleaned, patch_artist=True, **kwargs)

    for part in ('boxes', 'whiskers', 'means', 'medians', 'caps'):
        plt.setp(bp[part], color=edge_color, linewidth=linewidth)

    face_props = {'facecolor': fill_color, 'linewidth': linewidth}
    if hatch is not None:
        face_props['hatch'] = hatch
    for box in bp['boxes']:
        box.set(**face_props)

    return bp

## Variables 

In [None]:
# Output/input locations. network_path is not used in this cell
# (presumably where trained networks live — confirm); image_path is
# created here if it does not exist.
network_path = './Networks/'
image_path   = './Images/'
if not os.path.exists(image_path):
    os.makedirs(image_path)
# Noise settings; None presumably means noise-free (CustomSimulator skips
# AddNoise when snr is None) — TODO confirm how this list is consumed.
NoiseLevels = [None,20,10,5,2]

TrainingSamples = 50000
InferSamples    = 100

# Prior bounds (presumably fed to the DTIPrior* classes: abs/rest parameter
# boxes and the S0 range) — confirm against the training cells.
lower_abs,upper_abs = -0.07,0.07
lower_rest,upper_rest = -0.015,0.015
lower_S0 = 25
upper_S0 = 2000
Save = True

# Plot colours; RGB triples scaled to [0, 1].
TrueCol  = 'k'
NoisyCol = 'k'
WLSFit   = np.array([225,190,106])/255
SBIFit   = np.array([64,176,166])/255

# Labels in the same order as the 8 values returned by Errors().
# NOTE(review): entries 4 and 6 share the label 'Signal comparison' — confirm.
Errors_name = ['MD comparison','FA comparison','eig. comparison','Frobenius','Signal comparison','Correlation','Signal comparison','Correlation2']


## DKI Fits

In [None]:

# Load patient 1's b=1000 and b=3000 HCP shells, fit DKI on one axial slice,
# and keep the filtered DT/KT coefficients plus their DKI metrics and FA.
i = 1
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Combined gradient table: 1k shell first, then 3k, matching the channel
# order of TestData below.
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

data, affine, img = load_nifti(fdwi, return_img=True)
# Reslice from voxel size 1.5 to 2.5 (mm, presumably) isotropic.
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): axial_middle is taken from the *uncropped* volume but is
# later used to index the cropped maskdata/maskdata3 — confirm intended.
axial_middle = data.shape[2] // 2
# Brain mask from volumes 10..49 of the 1k shell.
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=False, dilate=2)

# Note: this also overwrites `affine` with the 3k shell's resliced affine.
data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# Get the indices of True values
true_indices = np.argwhere(mask)

# Determine the minimum and maximum indices along each dimension
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)

# Crop both shells to the mask's bounding box. The 3k shell is cropped but
# not masked; masked-out voxels are removed below via the 1k-shell sum test.
maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]

# One axial slice with both shells concatenated along the volume axis.
TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
# Hard-coded 138 / 69 assume 69 volumes per shell — TODO confirm.
FlatTD = TestData.reshape(maskdata.shape[0]*maskdata.shape[1],138)
# Drop voxels with an all-zero 1k shell (outside the mask) ...
FlatTD = FlatTD[FlatTD[:,:69].sum(axis=-1)>0]
# ... and voxels with any negative sample.
FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]

dkimodel = dki.DiffusionKurtosisModel(gtabExt)
tenfit = dkimodel.fit(FlatTD)
DKIHCP = tenfit.kt
DTIHCP = tenfit.lower_triangular()
DKIFull = np.array(DKIHCP)
DTIFull = np.array(DTIHCP)


# Keep voxels whose kurtosis elements all satisfy |kt| < 10 ...
DTIFilt1 = DTIFull[(abs(DKIFull)<10).all(axis=1)]
DKIFilt1 = DKIFull[(abs(DKIFull)<10).all(axis=1)]
# ... and are all greater than -3/7.
DTIFilt = DTIFilt1[(DKIFilt1>-3/7).all(axis=1)]
DKIFilt = DKIFilt1[(DKIFilt1>-3/7).all(axis=1)]

# TrueMets rows are DKIMetrics outputs: (mk, ak, rk, mkt, kfa).
# TrueFA is the DTI fractional anisotropy of each voxel's tensor.
TrueMets = []
FA       = []
for (dt,kt) in tqdm.tqdm(zip(DTIFilt,DKIFilt)):
    TrueMets.append(DKIMetrics(dt,kt))
    FA.append(FracAni(np.linalg.eigh(vals_to_mat(dt))[0],np.mean(np.linalg.eigh(vals_to_mat(dt))[0])))
TrueMets = np.array(TrueMets)
TrueFA = np.array(FA)

In [None]:
# Fit DT/KT marginal distributions on all voxels and on metric-based subsets.
# NOTE(review): the "FA" subsets below index TrueMets[:,-1], which is KFA
# (the last DKIMetrics output), not the DTI FA stored in TrueFA — confirm
# which anisotropy measure is intended.

# Full fit
DT1_full,DT2_full = FitDT(DTIFilt,1)
x4_full,R1_full,x2_full,R2_full = FitKT(DKIFilt,1)

# LowFA Fit
DT1_lfa,DT2_lfa = FitDT(DTIFilt[TrueMets[:,-1]<0.3,:],1)
x4_lfa,R1_lfa,x2_lfa,R2_lfa = FitKT(DKIFilt[TrueMets[:,-1]<0.3,:],1)

# HighFA Fit
DT1_hfa,DT2_hfa = FitDT(DTIFilt[TrueMets[:,-1]>0.7,:],1)
x4_hfa,R1_hfa,x2_hfa,R2_hfa = FitKT(DKIFilt[TrueMets[:,-1]>0.7,:],1)

# UltraLowFA Fit
DT1_ulfa,DT2_ulfa = FitDT(DTIFilt[TrueMets[:,-1]<0.1,:],1)
x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa = FitKT(DKIFilt[TrueMets[:,-1]<0.1,:],1)

# HigherAK Fit (TrueMets[:,1] is AK)
# NOTE(review): the DT fit uses AK > 0.7 but the KT fit uses AK > 0.9, so the
# two fits see different voxel subsets — likely a copy/paste slip; confirm.
DT1_hak,DT2_hak = FitDT(DTIFilt[TrueMets[:,1]>0.7,:],1)
x4_hak,R1_hak,x2_hak,R2_hak = FitKT(DKIFilt[TrueMets[:,1]>0.9,:],1)

# Fig 1

In [None]:
# Output folder for the Figure 1 panels (created if missing).
FigLoc = image_path + 'Fig_1/'
if not os.path.exists(FigLoc):
    os.makedirs(FigLoc)

## b

In [None]:
# Reload patient 1's gradient files for panel b (same paths as the DKI-fit
# cell, so this cell can be run on its own). Only the 1k shell gets a
# gradient table here.
i = 1
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)

bvalsHCP3 = np.loadtxt(bvalloc3)
bvecsHCP3 = np.loadtxt(bvecloc3)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

In [None]:
# Greedy farthest-point selection of gradient-direction subsets: each step
# adds the remaining direction whose minimum Euclidean distance to the
# already-selected set is largest.

# 1k shell, 7 indices total (start at index 0).
# Choose the first point (arbitrary starting point, e.g., the first gradient)
selected_indices = [0]
distance_matrix = squareform(pdist(bvecsHCP))
# Iteratively select the point furthest from the current selection
for _ in range(6):  # We need 7 points in total, and one is already selected
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices = selected_indices

# 1k shell, 20 indices total (start at index 0).
# Choose the first point (arbitrary starting point, e.g., the first gradient)
selected_indices_20 = [0]
distance_matrix = squareform(pdist(bvecsHCP))
# Iteratively select the point furthest from the current selection
for _ in range(19):  # 20 points in total; one is already selected
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices_20))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices_20], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices_20.append(next_index)

selected_indices_20 = np.array(selected_indices_20)

# 3k shell, 16 candidates (start at index 1), then index 0 is removed if it
# was picked (presumably a b0 row — confirm).
# Choose the first point (arbitrary starting point, e.g., the first gradient)
selected_indices3 = [1]
distance_matrix = squareform(pdist(bvecsHCP3))
# Iteratively select the point furthest from the current selection
for _ in range(15):  # 16 candidates in total; one is already selected
    remaining_indices = list(set(range(len(bvecsHCP3))) - set(selected_indices3))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices3], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices3.append(next_index)

selected_indices3 = np.array(selected_indices3)
selected_indices3 = selected_indices3[selected_indices3>0]

# 3k shell, 29 candidates (start at index 1), then index 0 is removed if picked.
# Choose the first point (arbitrary starting point, e.g., the first gradient)
selected_indices3_48 = [1]
distance_matrix = squareform(pdist(bvecsHCP3))
# Iteratively select the point furthest from the current selection
for _ in range(28):  # 29 candidates in total; one is already selected
    remaining_indices = list(set(range(len(bvecsHCP3))) - set(selected_indices3_48))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices3_48], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices3_48.append(next_index)

selected_indices3_48= np.array(selected_indices3_48)
selected_indices3_48= selected_indices3_48[selected_indices3_48>0]

In [None]:
# 3-D plot of the full acquisition: b=1000 directions on a radius-2 sphere,
# b=3000 directions on a radius-4 sphere, and a point at the origin for B0.
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)

# Radius-4 sphere surface (3k shell).
x = 4 * np.outer(np.cos(u), np.sin(v))
y = 4 * np.outer(np.sin(u), np.sin(v))
z = 4 * np.outer(np.ones(np.size(u)), np.cos(v))

# Radius-2 sphere surface (1k shell).
x1 = 2 * np.outer(np.cos(u), np.sin(v))
y1 = 2 * np.outer(np.sin(u), np.sin(v))
z1 = 2 * np.outer(np.ones(np.size(u)), np.cos(v))
#for i in range(2):
#    ax.plot_surface(x+random.randint(-5,5), y+random.randint(-5,5), z+random.randint(-5,5),  rstride=4, cstride=4, color='b', linewidth=0, alpha=0.5)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.set_aspect('equal')
ax.scatter(0,0,0,s=50,color='k',label=r'$B_0$')

# Rows whose components sum to 0 are skipped (meant to drop b0 rows).
# NOTE(review): this would also drop a unit direction whose x+y+z == 0 —
# filtering on the b-values or the vector norm would be more robust.
ax.plot_surface(x1, y1, z1,  rstride=4, cstride=4, color=WLSFit, linewidth=0, alpha=0.25)
ax.scatter(2*bvecsHCP[np.sum(bvecsHCP,axis=1)!=0][:,0],2*bvecsHCP[np.sum(bvecsHCP,axis=1)!=0][:,1],2*bvecsHCP[np.sum(bvecsHCP,axis=1)!=0][:,2],s=50,
           color=WLSFit-0.2,label=r'$B = 1000$')

ax.plot_surface(x, y, z,  rstride=4, cstride=4, color=SBIFit, linewidth=0, alpha=0.25)
ax.scatter(4*bvecsHCP3[np.sum(bvecsHCP3,axis=1)!=0][:,0],4*bvecsHCP3[np.sum(bvecsHCP3,axis=1)!=0][:,1],4*bvecsHCP3[np.sum(bvecsHCP3,axis=1)!=0][:,2],s=50,
           color=SBIFit-0.2,label=r'$B = 3000$')


plt.axis('off')
plt.legend(ncols=3,fontsize=18,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,loc=2,bbox_to_anchor=(0.05,0.3))
plt.title('Full Set \n 69 measurements',y=0.85,fontsize=24)
if Save: plt.savefig(FigLoc+'MultiShell.pdf',format='pdf',bbox_inches='tight', transparent=True)

In [None]:
# Same sphere plot for the minimum subset: selected_indices (1k shell) and
# selected_indices3 (3k shell) from the farthest-point selection cell.
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)

# Radius-4 sphere surface (3k shell).
x = 4 * np.outer(np.cos(u), np.sin(v))
y = 4 * np.outer(np.sin(u), np.sin(v))
z = 4 * np.outer(np.ones(np.size(u)), np.cos(v))

# Radius-2 sphere surface (1k shell).
x1 = 2 * np.outer(np.cos(u), np.sin(v))
y1 = 2 * np.outer(np.sin(u), np.sin(v))
z1 = 2 * np.outer(np.ones(np.size(u)), np.cos(v))
#for i in range(2):
#    ax.plot_surface(x+random.randint(-5,5), y+random.randint(-5,5), z+random.randint(-5,5),  rstride=4, cstride=4, color='b', linewidth=0, alpha=0.5)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.set_aspect('equal')
ax.scatter(0,0,0,s=50,color='k',label=r'$B_0$')

ax.plot_surface(x1, y1, z1,  rstride=4, cstride=4, color=WLSFit, linewidth=0, alpha=0.25)
ax.scatter(2*bvecsHCP[selected_indices,0],2*bvecsHCP[selected_indices,1],2*bvecsHCP[selected_indices,2],s=50,
           color=WLSFit-0.2,label=r'$B = 1000$')

ax.plot_surface(x, y, z,  rstride=4, cstride=4, color=SBIFit, linewidth=0, alpha=0.25)
ax.scatter(4*bvecsHCP3[selected_indices3,0],4*bvecsHCP3[selected_indices3,1],4*bvecsHCP3[selected_indices3,2],s=50,
           color=SBIFit-0.2,label=r'$B = 3000$')


plt.axis('off')
plt.title('Minimum Set \n 22 measurements',y=0.85,fontsize=32)
if Save: plt.savefig(FigLoc+'MiniSet.pdf',format='pdf',bbox_inches='tight', transparent=True)

In [None]:
# Same sphere plot for the medium subset: selected_indices_20 (1k shell) and
# selected_indices3_48 (3k shell) from the farthest-point selection cell.
u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)

# Radius-4 sphere surface (3k shell).
x = 4 * np.outer(np.cos(u), np.sin(v))
y = 4 * np.outer(np.sin(u), np.sin(v))
z = 4 * np.outer(np.ones(np.size(u)), np.cos(v))

# Radius-2 sphere surface (1k shell).
x1 = 2 * np.outer(np.cos(u), np.sin(v))
y1 = 2 * np.outer(np.sin(u), np.sin(v))
z1 = 2 * np.outer(np.ones(np.size(u)), np.cos(v))
#for i in range(2):
#    ax.plot_surface(x+random.randint(-5,5), y+random.randint(-5,5), z+random.randint(-5,5),  rstride=4, cstride=4, color='b', linewidth=0, alpha=0.5)

fig = plt.figure()
ax = fig.add_subplot(projection='3d')

ax.set_aspect('equal')
ax.scatter(0,0,0,s=50,color='k',label=r'$B_0$')

ax.plot_surface(x1, y1, z1,  rstride=4, cstride=4, color=WLSFit, linewidth=0, alpha=0.25)
ax.scatter(2*bvecsHCP[selected_indices_20,0],2*bvecsHCP[selected_indices_20,1],2*bvecsHCP[selected_indices_20,2],s=50,
           color=WLSFit-0.2,label=r'$B = 1000$')

ax.plot_surface(x, y, z,  rstride=4, cstride=4, color=SBIFit, linewidth=0, alpha=0.25)
ax.scatter(4*bvecsHCP3[selected_indices3_48,0],4*bvecsHCP3[selected_indices3_48,1],4*bvecsHCP3[selected_indices3_48,2],s=50,
           color=SBIFit-0.2,label=r'$B = 3000$')


plt.axis('off')
plt.title('Medium Set \n 48 measurements',y=0.85,fontsize=32)
if Save: plt.savefig(FigLoc+'MedSet.pdf',format='pdf',bbox_inches='tight',transparent=True)

## d

In [None]:
# Three synthetic single-shell (b = 1000) schemes built from dipy's 'small_64D' directions:
# 64 directions (gtabSimF), 19 (gtabSim20) and 6 (gtabSim7), each with one b0 prepended.
# The directions are first spread over the hemisphere with disperse_charges (5000 iterations).
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
# Row 0 is skipped; presumably it is the file's b0 entry — confirm.
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_initial20 = HemiSphere(xyz=bvecs[1:20])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
# The returned potentials are not used.
hsph_updated,potentials = disperse_charges(hsph_initial,5000)
hsph_updated20,potentials = disperse_charges(hsph_initial20,5000)
hsph_updated7,potentials = disperse_charges(hsph_initial7,5000)

# b-value lists are 1 b0 + one 1000 per dispersed direction; they must match the vertex counts.
gtabSimF = gradient_table(np.array([0]+[1000]*64).squeeze(), np.vstack([[0,0,0],hsph_updated.vertices]))
gtabSim20 = gradient_table(np.array([0]+[1000]*19).squeeze(), np.vstack([[0,0,0],hsph_updated20.vertices]))
gtabSim7 = gradient_table(np.array([0]+[1000]*6).squeeze(), np.vstack([[0,0,0],hsph_updated7.vertices]))

In [None]:
class DTIPriorDirec:
    """Prior over six DTI parameters plus a categorical gradient-scheme index.

    A single draw is a 7-vector laid out as
    ``[abs_0, abs_1, abs_2, rest_0, rest_1, rest_2, direc]`` where

    * ``abs_*``  ~ BoxUniform(lower_abs, upper_abs)   (3 values),
    * ``rest_*`` ~ BoxUniform(lower_rest, upper_rest) (3 values),
    * ``direc``  ~ uniform Categorical over ``{0, ..., n_directions - 1}``,
      stored in the same floating-point row as the other values.

    The ``sample`` / ``log_prob`` / ``return_numpy`` layout follows sbi's
    custom-prior convention, so an instance can be passed to ``process_prior``.

    Args:
        lower_abs, upper_abs: Bounds of the three ``abs`` entries (a scalar is
            broadcast to all three via ``* torch.ones(3)``).
        lower_rest, upper_rest: Bounds of the three ``rest`` entries.
        return_numpy: If True, ``sample`` returns a NumPy array and
            ``log_prob`` accepts NumPy input.
        n_directions: Number of categories for the last entry (default 5).
    """

    def __init__(self, lower_abs: Tensor, upper_abs: Tensor,
                 lower_rest: Tensor, upper_rest: Tensor,
                 return_numpy: bool = False, n_directions: int = 5):

        self.dist_abs = BoxUniform(low=lower_abs * torch.ones(3), high=upper_abs * torch.ones(3))
        self.dist_rest = BoxUniform(low=lower_rest * torch.ones(3), high=upper_rest * torch.ones(3))
        # probs has shape (1, n_directions): batch shape (1,), so every draw carries a
        # trailing length-1 axis that concatenates cleanly next to abc and rest.
        self.direction_choice = Categorical(probs=torch.ones(1, n_directions))
        self.return_numpy = return_numpy

    def sample(self, sample_shape=torch.Size([])):
        """Draw parameters; the result has shape ``sample_shape + (7,)``.

        Concatenation is along the last axis, so any ``sample_shape`` works
        (``hstack`` joined along axis 1 and broke for multi-dimensional shapes).
        The integer category is promoted to the floating dtype of the row.
        """
        abc = self.dist_abs.sample(sample_shape)
        rest = self.dist_rest.sample(sample_shape)
        direc = self.direction_choice.sample(sample_shape)

        if self.return_numpy:
            # NumPy promotes float32 + int64 to float64, as np.hstack did before.
            return np.concatenate([abc, rest, direc], axis=-1)
        return torch.cat([abc, rest, direc], dim=-1)

    def log_prob(self, values):
        """Log-density of ``values`` under the prior.

        Args:
            values: Tensor (or NumPy array when ``return_numpy``) whose last
                axis has length 7; leading axes are batch axes. A lone 1-D
                7-vector is also accepted (it previously raised on 2-D slicing).

        Returns:
            A torch tensor with the batch shape of ``values`` (shape (1,) for a
            lone 1-D vector).
        """
        if self.return_numpy:
            values = torch.as_tensor(values)

        abc = values[..., :3]
        rest = values[..., 3:-1]
        direc = values[..., -1]

        log_prob_abc = self.dist_abs.log_prob(abc)
        log_prob_rest = self.dist_rest.log_prob(rest)
        # Categorical.log_prob casts the float-stored index to long internally.
        log_prob_direc = self.direction_choice.log_prob(direc)
        return log_prob_abc + log_prob_rest + log_prob_direc

In [None]:
# Build the DTI + S0 prior (DTIPriorS0 is defined elsewhere in the notebook) and let sbi's
# process_prior validate/wrap it; only the processed prior is kept.
custom_prior = DTIPriorS0(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0)
priorS0, *_ = process_prior(custom_prior) 

In [None]:
np.random.seed(1)

# gTabs[0] is gtabSimF in its original order; gTabs[1..4] are four seeded random
# re-orderings of the same (bval, bvec) pairs (the b0 row is shuffled too). The training
# cell uses a sample's last parameter as an index into this list.
# The permutation length is derived from gtabSimF instead of being hard-coded (65), so it
# stays correct if the scheme changes; a dedicated name avoids clobbering the global `x`.
n_measurements = len(gtabSimF.bvals)
gTabs = [gtabSimF]
for _ in range(4):
    perm = np.random.permutation(np.arange(n_measurements))
    bvecs_shuffle = gtabSimF.bvecs[perm]
    bvals_shuffle = gtabSimF.bvals[perm]

    gTabs.append(gradient_table(bvals_shuffle, bvecs_shuffle))

In [None]:
# Load the cached "example" posterior (trained across the five gradient orderings in gTabs)
# if the pickle exists; otherwise simulate TrainingSamples (signal, parameter) pairs, train
# NPE and cache the result.
# NOTE(review): the file is unpickled with dill — only load files you created yourself.
if os.path.exists(f"{network_path}/EgPosterior.pickle"):
    with open(f"{network_path}/EgPosterior.pickle", "rb") as handle:
        posterior = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        # NOTE(review): priorDirec is not created in this part of the notebook; presumably it is
        # process_prior(DTIPriorDirec(...)), whose last entry is the gTabs index — confirm.
        params = priorDirec.sample()
        dt = ComputeDTI(params[:-1])
        # Each tensor is passed through ForceLowFA with probability 0.8.
        if(np.random.rand()<0.8):
            dt = ForceLowFA(dt)
        # The last parameter selects which gradient ordering the signal is simulated with.
        cG = gTabs[int(params[-1])]
        Obs.append(CustomSimulator(dt,cG,200,None))
        # Targets are the tensor's elements (mat_to_vals) plus the ordering index,
        # not the raw prior draw.
        Par.append(np.append(mat_to_vals(dt),params[-1]))
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    # Create inference object. Here, NPE is used.
    # NOTE(review): priorS0's last dimension is presumably the S0 range (lower_S0..upper_S0),
    # while the last column of Par is the gTabs index 0..4, and the first six columns are
    # tensor elements rather than prior draws. Confirm the prior's support covers Par;
    # otherwise sbi's leakage correction may reject most posterior samples.
    inference = SNPE(prior=priorS0)
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior = inference.build_posterior(density_estimator)
    # Redundant with the outer check: this branch only runs when the file is missing.
    if not os.path.exists(f"{network_path}/EgPosterior.pickle"):
        with open(f"{network_path}/EgPosterior.pickle", "wb") as handle:
            pickle.dump(posterior, handle)

In [None]:
# Seeded test sets: 5000 low-FA tensors per scheme (gtabSimF: 65 measurements, gtabSim20: 20,
# gtabSim7: 7), each simulated at every level in NoiseLevels. The three sets use independent
# prior draws; their order matters for reproducibility because the RNG state is shared.
np.random.seed(0)
torch.manual_seed(0)
Samples  = []
DTISim = []
S0Sim    = []

params = priorS0.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim.append(dt)
    # NOTE(review): the prior's last entry is recorded here as S0, but the signals are
    # simulated with a fixed S0=200 — confirm S0Sim is meant to differ from the simulation.
    S0Sim.append(params[i,-1])
    Samples.append([CustomSimulator(dt,gtabSimF, S0=200,snr=scale) for scale in NoiseLevels])
    
# Assuming CustomSimulator returns a 1-D signal: (sample, level, measurement) is rearranged to
# Samples[level, measurement, sample]. Note squeeze() would also drop the level axis if
# NoiseLevels had a single entry.
Samples = np.array(Samples).squeeze()
Samples = np.moveaxis(Samples, 0, -1)

Samples20  = []
DTISim20 = []
S0Sim20    = []

params = priorS0.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim20.append(dt)
    S0Sim20.append(params[i,-1])
    Samples20.append([CustomSimulator(dt,gtabSim20, S0=200,snr=scale) for scale in NoiseLevels])
    
# Same layout as Samples: Samples20[level, measurement, sample].
Samples20 = np.array(Samples20).squeeze()
Samples20 = np.moveaxis(Samples20, 0, -1)

Samples7  = []
DTISim7 = []
S0Sim7    = []

params = priorS0.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim7.append(dt)
    S0Sim7.append(params[i,-1])
    Samples7.append([CustomSimulator(dt,gtabSim7, S0=200,snr=scale) for scale in NoiseLevels])
    
# Same layout as Samples: Samples7[level, measurement, sample].
Samples7 = np.array(Samples7).squeeze()
Samples7 = np.moveaxis(Samples7, 0, -1)

In [None]:
# Pick a random gradient ordering (one of the five gTabs) and one of the first 64 simulated
# tensors. This cell does not re-seed, so the picks depend on the RNG state left by earlier cells.
i = np.random.choice(5)
j = np.random.choice(64)
gT = gTabs[i]
dT = DTISim[j]

In [None]:
# Seeded posterior sampling for the picked tensor/ordering. The signal is simulated with
# S0=200 and, presumably, snr=None (no noise) — positional arguments of CustomSimulator.
np.random.seed(1)
torch.manual_seed(1)
tObs = CustomSimulator(dT,gT,200,None)
posterior_samples_1 = posterior.sample((InferSamples,), x=tObs)

In [None]:
# One borderless strip plot for each of the first five simulated signals at the first
# noise level (columns of Samples[0], i.e. Samples[level, measurement, sample]).
for i, trace in enumerate(Samples[0].T[:5]):
    plt.subplots(figsize=(6,1))
    plt.plot(trace, c=SBIFit, lw=3)
    plt.axis('off')
    if Save:
        plt.savefig(f'{FigLoc}EgSamples{i}.pdf', format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Example inference on simulated tensor #20 with the reference ordering (gtabSimF): plot the
# signal alone, then overlaid with the signal re-simulated from the per-parameter posterior
# modes. This cell does not re-seed before sampling.
i = 1  # not used below
j = 20
gT = gtabSimF
dT = DTISim[j]
tObs = CustomSimulator(dT,gT,200,None)
posterior_samples_1 = posterior.sample((InferSamples,), x=tObs)

# histogram_mode is applied to every parameter column; the last column (ordering index) is dropped.
signal_dti = CustomSimulator(vals_to_mat([histogram_mode(p) for p in posterior_samples_1.T][:-1]),gT,200,None)
plt.subplots(figsize=(6,1))
plt.plot(tObs,lw=3,c='k')
plt.axis('off')
if Save: plt.savefig(FigLoc+'EgInfPre.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.subplots(figsize=(6,1))
plt.plot(tObs,lw=3,c='k',label='true signal')
plt.plot(signal_dti,lw=2,c=SBIFit,ls='--',label='Recon. signal')
plt.axis('off')
# Legend placed above the strip (bbox y > 1) so it does not cover the signal.
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
if Save: plt.savefig(FigLoc+'EgInfPost.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Fig 2

In [None]:
FigLoc = image_path + 'Fig_2/'
# exist_ok=True creates the folder when needed and is a no-op when it already exists,
# replacing the separate exists() check and its check-then-create race.
os.makedirs(FigLoc, exist_ok=True)

## a

In [None]:
# Seeded ground truth for Fig 2: one low-FA tensor, its signal with snr=None, and one
# realisation at each level in NoiseLevels[1:] (SNR[k] corresponds to NoiseLevels[1:][k]).
torch.manual_seed(0)
np.random.seed(0)

params = priorS0.sample()
dtTruth = ComputeDTI(params)
dtTruth = ForceLowFA(dtTruth)
Truth = CustomSimulator(dtTruth,gtabSimF,S0=200,snr=None)

    
# Eigenvalues of the true tensor; used below for the reference FA and MD lines.
dt_evals,dt_evecs = np.linalg.eigh(dtTruth)

SNR = [CustomSimulator(dtTruth,gtabSimF, S0=200,snr=scale) for scale in NoiseLevels[1:]]
    
SNR = np.array(SNR)

In [None]:
# Four strip plots, each comparing the true signal with one noisy realisation SNR[k].
# The file names presumably reflect NoiseLevels[1:] == [20, 10, 5, 2].
# Only the first panel is labelled and carries the legend.
for k, fname in enumerate(['EgSig20.pdf', 'EgSig10.pdf', 'EgSig5.pdf', 'EgSig2.pdf']):
    plt.subplots(figsize=(6,1))
    if k == 0:
        plt.plot(Truth,'k',lw=2,label='True signal')
        plt.plot(SNR[k],'gray',lw=2,ls='--',label='Noisy signal')
    else:
        plt.plot(Truth,'k',lw=2)
        plt.plot(SNR[k],'gray',lw=2,ls='--')
    plt.axis('off')
    if k == 0:
        plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
    if Save:
        plt.savefig(FigLoc+fname,format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

## b

In [None]:
# 100 independent noisy realisations of the dtTruth signal per SNR (one row per repeat).
# Tuple unpacking consumes the generator in order, so draws still happen as 20, 10, 5, 2.
SNR20, SNR10, SNR5, SNR2 = (
    np.vstack([CustomSimulator(dtTruth,gtabSimF, S0=200,snr=level) for _ in range(100)])
    for level in (20, 10, 5, 2)
)


In [None]:
# Fit every batch of noisy replicates with dipy's NLLS tensor model and keep the FA and
# MD of each fitted tensor, one array per SNR level.
tenmodel = dti.TensorModel(gtabSimF,return_S0_hat = True,fit_method='NLLS')
fa_md_by_level = []
for noisy_batch in (SNR20, SNR10, SNR5, SNR2):
    tenfit = tenmodel.fit(noisy_batch)
    fa_md_by_level.append((dti.fractional_anisotropy(tenfit.evals),
                           dti.mean_diffusivity(tenfit.evals)))
(FA20, MD20), (FA10, MD10), (FA5, MD5), (FA2, MD2) = fa_md_by_level

In [None]:
# Violin plots of the NLLS-fitted FA (first figure) and MD (second figure) across the 100
# replicates at each SNR, with the ground-truth value as a dashed line.
plt.subplots(figsize=(6.4,2.4))
viol_plot(np.array([FA20,FA10,FA5,FA2]).T,WLSFit)

l = plt.axhline(FracAni(dt_evals,np.mean(dt_evals)),c='k',lw=3,ls='--',label='True FA')
plt.xticks([1,2,3,4],[20,10,5,2],fontsize=28)
plt.xticks(fontsize=28)  # repeats the font size already set on the line above
plt.xlabel('SNR',fontsize=32)
plt.ylabel('FA',fontsize=32)
# Legend built from proxy artists: a patch for the violins, a dashed line for the truth.
legend_elements = [
    mpatches.Patch(facecolor=WLSFit, edgecolor='k', label='Fit FA'),
    Line2D([0], [0], color='k', lw=3, ls='--', label='True FA')
]
plt.legend(handles=legend_elements,loc = 'lower left',bbox_to_anchor=(0.05,0.5),
           fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,ncols=2)
plt.yticks([0,0.5,1])
if Save: plt.savefig(FigLoc+'EgNoiseFA.pdf',format='pdf',bbox_inches='tight',transparent=True)

plt.subplots(figsize=(6.4,2.4))
viol_plot(np.array([MD20,MD10,MD5,MD2]).T,WLSFit)

# True MD is the mean of the true tensor's eigenvalues.
l = plt.axhline(np.mean(dt_evals),c='k',lw=3,ls='--',label='True MD')
plt.xticks([1,2,3,4],[20,10,5,2],fontsize=28)
plt.xticks(fontsize=28)
plt.xlabel('SNR',fontsize=32)
plt.ylabel('MD',fontsize=32)
legend_elements = [
    mpatches.Patch(facecolor=WLSFit, edgecolor='k', label='Fit MD'),
    Line2D([0], [0], color='k', lw=3, ls='--', label='True MD')
]
plt.legend(handles=legend_elements,loc = 'lower left',bbox_to_anchor=(0.05,0.5),
           fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,ncols=2)
plt.yticks([0,0.001,0.002])
# Fixed y-range, identical to the SBI MD panel, so the two figures share a scale.
plt.ylim((-7.687787458229293e-05, 0.0025))
if Save: plt.savefig(FigLoc+'EgNoiseMD.pdf',format='pdf',bbox_inches='tight',transparent=True)

## c

In [None]:
# Same prior class as priorS0 but with its last dimension bounded by (0, 30); the training
# cells below read that entry (params[-1]) as the noise level passed to the simulator.
custom_prior = DTIPriorS0(lower_abs,upper_abs,lower_rest,upper_rest,0,30)
priorNoise, *_ = process_prior(custom_prior) 

In [None]:
# Load (or train and cache) the DTI posterior for the 65-measurement scheme gtabSimF. Each
# training signal uses a noise level drawn from priorNoise's last dimension, and that level
# is also a training target, so the posterior estimates it too.
# NOTE(review): unpickled with dill — only load files you created yourself.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimFull.pickle"):
    with open(f"{network_path}/DTISimFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        # Noise level for this sample (presumably the snr argument of CustomSimulator).
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSimF,200,a))
        # Targets: the six tensor elements plus the noise level.
        Par.append(np.hstack([mat_to_vals(dt),a]))
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    # NOTE(review): the first six columns of Par are tensor elements, not priorNoise's own
    # draws — confirm priorNoise's support covers them.
    inference = SNPE(prior=priorNoise)
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorFull = inference.build_posterior(density_estimator)
    # Redundant with the outer check: this branch only runs when the file is missing.
    if not os.path.exists(f"{network_path}/DTISimFull.pickle"):
        with open(f"{network_path}/DTISimFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
def _sbi_fa_md(noisy_signals):
    """Return ``(MD list, FA list)`` from posteriorFull for each row of *noisy_signals*.

    For each signal the posterior mean of InferSamples draws is turned into a tensor
    (vals_to_mat), negative eigenvalues are clipped (clip_negative_eigenvalues), and MD
    (mean eigenvalue) and FA (FracAni) are computed from the clipped eigenvalues.
    Both RNGs are re-seeded with 2 first, so every SNR level starts from the same state,
    matching the per-level seeding of the original four loops.
    """
    torch.manual_seed(2)
    np.random.seed(2)
    md_vals = []
    fa_vals = []
    for S in tqdm.tqdm(noisy_signals):
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=S,show_progress_bars=False)
        Guess = vals_to_mat(posterior_samples_1.mean(axis=0))
        Guess_clean = clip_negative_eigenvalues(Guess)
        evals_guess_raw, _ = np.linalg.eigh(Guess_clean)
        # Diagnostic only: after clipping no eigenvalue should be negative.
        if (evals_guess_raw < 0).any():
            print('negative eigenvalue after clipping:', evals_guess_raw)
        md_vals.append(np.mean(evals_guess_raw))
        fa_vals.append(FracAni(evals_guess_raw, md_vals[-1]))
    return md_vals, fa_vals


# Same four SNR levels as the NLLS comparison. The SNR-20 loop no longer evaluates a
# discarded CustomSimulator(...) call that used the unrelated global gT.
MD20, FA20 = _sbi_fa_md(SNR20)
MD10, FA10 = _sbi_fa_md(SNR10)
MD5, FA5 = _sbi_fa_md(SNR5)
MD2, FA2 = _sbi_fa_md(SNR2)

In [None]:
# SBI counterpart of the NLLS violin plots: posterior-mean FA and MD across the 100
# replicates per SNR, with the ground-truth value as a dashed line.
plt.subplots(figsize=(6.4,2.4))
viol_plot(np.array([FA20,FA10,FA5,FA2]).T,SBIFit)

l = plt.axhline(FracAni(dt_evals,np.mean(dt_evals)),c='k',lw=3,ls='--',label='True FA')
plt.xticks([1,2,3,4],[20,10,5,2],fontsize=28)
plt.xticks(fontsize=28)  # repeats the font size already set on the line above
plt.xlabel('SNR',fontsize=32)
plt.ylabel('FA',fontsize=32)
# Legend built from proxy artists: a patch for the violins, a dashed line for the truth.
legend_elements = [
    mpatches.Patch(facecolor=SBIFit, edgecolor='k', label='Fit FA'),
    Line2D([0], [0], color='k', lw=3, ls='--', label='True FA')
]
plt.legend(handles=legend_elements,loc = 'lower left',bbox_to_anchor=(0.05,0.5),
           fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,ncols=2)
plt.yticks([0,0.5,1])
if Save: plt.savefig(FigLoc+'EgNoiseFA_SBI.pdf',format='pdf',bbox_inches='tight',transparent=True)
    
plt.subplots(figsize=(6.4,2.4))
viol_plot(np.array([MD20,MD10,MD5,MD2]).T,SBIFit)

l = plt.axhline(np.mean(dt_evals),c='k',lw=3,ls='--',label='True MD')
plt.xticks([1,2,3,4],[20,10,5,2],fontsize=28)
plt.xticks(fontsize=28)
plt.xlabel('SNR',fontsize=32)
plt.ylabel('MD',fontsize=32)
legend_elements = [
    mpatches.Patch(facecolor=SBIFit, edgecolor='k', label='Fit MD'),
    Line2D([0], [0], color='k', lw=3, ls='--', label='True MD')
]
# Legend anchored differently from the FA panel (lower left, just outside the axes).
plt.legend(handles=legend_elements,loc = 'lower left',bbox_to_anchor=(-0.05,-0.05),
           fontsize=32,columnspacing=0.3,handlelength=0.6,handletextpad=0.3,ncols=2)
plt.yticks([0,0.001,0.002])
# Fixed y-range, identical to the NLLS MD panel.
plt.ylim((-7.687787458229293e-05, 0.0025))
if Save: plt.savefig(FigLoc+'EgNoiseMD_SBI.pdf',format='pdf',bbox_inches='tight',transparent=True)

## d

In [None]:
# Errors of SBI (posteriorFull) and dipy NLLS on the first 500 simulated tensors at every
# noise level. ErrorFull[k][i] holds Errors(...) for level k and tensor i; NoiseApproxFull
# holds the posterior-mean noise parameter (last posterior column).
torch.manual_seed(10)
SNR = NoiseLevels  # not used below
ErrorFull = []
NoiseApproxFull = []
for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim[i])
        tObs = Samples[k,:,i]#Simulator(bvals,bvecs,200,params,Noise)
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        # Noise-free reference signal of the true tensor.
        true_signal_dti = single_tensor(gtabSimF, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        # SBI point estimate: per-parameter histogram mode, then clip negative eigenvalues.
        # NOTE(review): the 7-measurement comparison (Error7) uses the posterior mean without
        # clipping — confirm the two estimators are meant to differ.
        mat_guess = vals_to_mat([histogram_mode(p) for p in posterior_samples_1.T])
        mat_guess = clip_negative_eigenvalues(mat_guess)
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSimF,true_signal_dti,tObs))
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApproxFull.append(ENoise)
    ErrorFull.append(ErrorN2)

NoiseApproxFull = np.array(NoiseApproxFull)    

# NLLS baseline on the full, 20- and 7-measurement schemes. The level-0 signal (Samps[0]) is
# passed to Errors as the reference signal.
Error_s = []
for k,gtab,Samps,DTIS in zip([65,20,7],[gtabSimF,gtabSim20,gtabSim7],[Samples,Samples20,Samples7],[DTISim,DTISim20,DTISim7]):  # k is not used
    tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
    Error_n = []
    for S,Noise in zip(Samps,NoiseLevels):  # Noise is not used
        Error = []
        for i in range(500):
            tenfit = tenmodel.fit(S[:,i])
            tensor_vals = dti.lower_triangular(tenfit.quadratic_form)
            DT_test = vals_to_mat(tensor_vals)
            Error.append(Errors(DT_test,DTIS[i],gtab,Samps[0][:,i],S[:,i]))
        Error_n.append(Error)
    Error_s.append(Error_n)
Error_s = np.array(Error_s)
# -> Error_s[level, scheme (full, 20, 7), tensor, metric], assuming Errors returns a 1-D vector.
Error_s = np.swapaxes(Error_s,0,1)

In [None]:
fig,axs = plt.subplots(1,2,figsize=(9,3))
ax = axs.ravel()
# One panel per error metric; zip() stops at the two axes, so only the first two metrics
# are drawn. E is (tensor, noise level), assuming Errors returns a 1-D metric vector.
for ll,(a,E,_) in enumerate(zip(ax,np.array(ErrorFull).T,Errors_name)):
    plt.sca(a)
    # Noise level 0 is excluded. SBI boxes sit at x + 0.3 next to the NLLS boxes, which come
    # from the full scheme (index 0 of Error_s's scheme axis).
    box_plot(E[:,1:].T,SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[1.3,2.3,3.3,4.3],showfliers=False,widths=0.3)
    box_plot(Error_s[1:,0,:,ll],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    plt.sca(a)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    # Ticks centred between each NLLS/SBI pair, labelled with the noisy levels.
    plt.xticks([1.15, 2.15, 3.15, 4.15,], NoiseLevels[1:],fontsize=32)

    # The method legend is split across the two panels: NLLS on the first, SBI on the second.
    if(ll==1):
        handles = [
            Line2D([0], [0], color=SBIFit, lw=4, label='SBI'),
        ]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    if(ll==0):
        handles = [
            Line2D([0], [0], color=WLSFit, lw=4, label='NLLS'),
        ]
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)

    plt.grid()
plt.tight_layout()

if Save: plt.savefig(FigLoc+'SimDatDTIErrors1.pdf',format='pdf',bbox_inches='tight',transparent=True)

## e

In [None]:
# Load (or train and cache) the DTI posterior for the 7-measurement scheme gtabSim7, using
# the same recipe as the full-scheme network: noise level drawn from priorNoise and kept as
# an extra training target.
# NOTE(review): unpickled with dill — only load files you created yourself.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimMin.pickle"):
    with open(f"{network_path}/DTISimMin.pickle", "rb") as handle:
        posterior7 = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        # Noise level for this sample (presumably CustomSimulator's snr argument).
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSim7,200,a))
        Par.append(np.hstack([mat_to_vals(dt),a]))
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior7 = inference.build_posterior(density_estimator)
    # Redundant with the outer check: this branch only runs when the file is missing.
    if not os.path.exists(f"{network_path}/DTISimMin.pickle"):
        with open(f"{network_path}/DTISimMin.pickle", "wb") as handle:
            pickle.dump(posterior7, handle)

In [None]:
# SBI errors of posterior7 on the first 500 tensors of the 7-measurement test set, per noise
# level; NoiseApprox7 holds the posterior-mean noise parameter.
torch.manual_seed(10)
SNR = NoiseLevels  # not used below
Error7 = []
NoiseApprox7 = []
for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim7[i])
        tObs = Samples7[k,:,i]
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim7, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posterior7.sample((InferSamples,), x=tObs,show_progress_bars=False)
        # Point estimate: posterior mean, without negative-eigenvalue clipping.
        # NOTE(review): the full-scheme errors use the histogram mode plus
        # clip_negative_eigenvalues — confirm the difference is intended.
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSim7,true_signal_dti,tObs))
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApprox7.append(ENoise)
    Error7.append(ErrorN2)

NoiseApprox7 = np.array(NoiseApprox7)    


In [None]:
fig,axs = plt.subplots(1,2,figsize=(9,3))
ax = axs.ravel()
# Same layout as the full-scheme error figure, without legends. SBI errors come from
# posterior7 (Error7); the NLLS comparison uses the 7-measurement scheme, the last entry of
# Error_s's scheme axis. zip() stops at the two axes, so two metrics are drawn.
for ll,(a,E,_) in enumerate(zip(ax,np.array(Error7).T,Errors_name)):
    plt.sca(a)
    box_plot(E[:,1:].T,SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[1.3,2.3,3.3,4.3],showfliers=False,widths=0.3)
    box_plot(Error_s[1:,-1,:,ll],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    plt.sca(a)
    # Ticks centred between each NLLS/SBI pair, labelled with the noisy levels.
    plt.xticks([1.15, 2.15, 3.15, 4.15,], NoiseLevels[1:],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.yticks(fontsize=32)
    plt.grid()
plt.tight_layout()

if Save: plt.savefig(FigLoc+'SimDatDTIErrors1_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Fig 3

In [None]:
FigLoc = image_path + 'Fig_3/'
# exist_ok=True creates the folder when needed and is a no-op when it already exists,
# replacing the separate exists() check and its check-then-create race.
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Two-shell scheme for the DKI simulations: the dispersed 'small_64D' directions (with a zero
# vector re-attached at index 0) acquired once at the file's b-values and once at b = 3000.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_updated,_ = disperse_charges(hsph_initial,5000)
bvecs = np.vstack([[0,0,0],hsph_updated.vertices])
bvalsExt = np.hstack([bvals, 3000*np.ones_like(bvals)])
bvecsExt = np.vstack([bvecs, bvecs])
# The second copy also starts with the zero vector, so make that entry a b0. Its index is the
# length of the first shell (previously hard-coded as 65).
bvalsExt[len(bvals)] = 0
gtabSim = gradient_table(bvalsExt, bvecsExt)

## a

In [None]:
# Load the cached DKI posterior for the two-shell scheme gtabSim, or simulate a training set
# from five tensor families and train NPE on [DT, KT] targets.
# Note: this reuses (overwrites) the name posteriorFull from the DTI section.
# NOTE(review): unpickled with dill — only load files you created yourself.
if os.path.exists(f"{network_path}/DKISimFull.pickle"):
    with open(f"{network_path}/DKISimFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []  # never filled; the reshape below yields an empty (0, 1) array that is not used
    # Five families (full / lfa / hfa / ulfa / hak): 15000 draws each, except 5000 for lfa.
    # GenDTKT's third argument (12) is presumably a seed and the fourth the count — confirm.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    S0 = np.array(S0).reshape(len(S0),1)
    
    # Rejection loop: simulate every pending index with a random noise level in [0, 30)
    # (presumably CustomDKISimulator's snr). Any signal outside [0, 800] has its KT halved
    # and its DT redrawn, and is simulated again until none remain.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSim.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm.tqdm(indx): 
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSim,200,np.random.rand()*30)
        
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): GenDTKT(..., kk, 1)[0] is a single DT that is broadcast to every
        # rejected index, always from the "full" family, and no longer paired with the KT it
        # was generated with. Presumably len(indxNew) fresh tensors were intended — confirm.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used (no explicit prior is passed).
    inference = SNPE()
    
    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    # (stop_after_epochs: epochs without validation improvement before training stops)
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    
    # Audible notification; the 'say' command exists on macOS only.
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKISimFull.pickle"):
        with open(f"{network_path}/DKISimFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Seeded pick of one example (DT, KT) pair for the DKI illustrations. randint(0,4) yields
# 0..3, so one of the first four families is used (never "hak"). The pair is simulated with
# noise level 20 (tObs) and without noise (tTrue).
torch.manual_seed(2)
np.random.seed(2)
j = 1
vL = torch.tensor([0.2*j])   # not used in this cell
vS = torch.tensor([0.01*j])  # not used in this cell

kk = np.random.randint(0,4)
if(kk==0):
    DT,KT = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],2,1)
elif(kk==1):
    DT,KT = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],2,1)
elif(kk==2):
    DT,KT = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],2,1)
elif(kk==3):
    DT,KT = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],2,1)

tObs = CustomDKISimulator(DT.squeeze(),KT.squeeze(),gtabSim,200,20)
tTrue = CustomDKISimulator(DT.squeeze(),KT.squeeze(),gtabSim,200,None)

In [None]:
# SBI reconstruction of the example: the posterior mean of InferSamples draws is split into
# its first six (DT) and remaining (KT) entries, re-simulated (noise argument left at its
# default) and overlaid on the noise-free signal.
torch.manual_seed(1)
np.random.seed(1)
posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=True)
GuessDKI = posterior_samples_1.mean(axis=0)
GuessSig = CustomDKISimulator(GuessDKI[:6],GuessDKI[6:],gtabSim,200)
plt.subplots(figsize=(6,1))
plt.plot(tTrue,lw=2,c='k',label='true signal')
plt.plot(GuessSig,lw=2,c=SBIFit,ls='--',label='SBI signal')
plt.axis('off')
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
if Save: plt.savefig(FigLoc+'FullReconSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# dipy NLLS DKI fit of the same noisy example signal, predicted with S0=200 and overlaid on
# the noise-free signal.
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='NLLS')
tenfit = dkimodel.fit(tObs)
plt.subplots(figsize=(6,1))
plt.plot(tTrue,lw=2,c='k',label='true signal')
plt.plot(tenfit.predict(gtabSim,200),lw=2,c=WLSFit,ls='--',label='NLLS signal')
plt.axis('off')
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
if Save: plt.savefig(FigLoc+'FullReconWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)

## c

In [None]:
torch.manual_seed(1)
np.random.seed(1)
# The NLLS model depends only on the acquisition scheme, so it is built once instead of once
# per replicate (construction does not touch the RNGs, so the draws are unchanged).
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='NLLS')
Mets = []
MetsSBI = []
for snr in tqdm.tqdm([20,10,5,2]):
    m = []
    m2 = []
    # 50 noisy replicates of the example (DT, KT) at this SNR, each fitted by NLLS and SBI.
    for k in range(50):
        tObs = CustomDKISimulator(np.squeeze(DT), np.squeeze(KT),gtabSim, S0=200, snr=snr)
        tenfit = dkimodel.fit(tObs)
        m.append(DKIMetrics(tenfit.lower_triangular(),tenfit.kt,False))
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        GuessDKI = posterior_samples_1.mean(axis=0)
        m2.append(DKIMetrics(GuessDKI[:6],GuessDKI[6:],False))
    Mets.append(m)
    MetsSBI.append(m2)
# Mets[snr level, replicate, metric], assuming DKIMetrics returns a 1-D vector.
Mets = np.array(Mets)
MetsSBI = np.array(MetsSBI)

In [None]:
# One panel for each of the first five DKIMetrics outputs: NLLS violins at x = 1..4, SBI
# violins at x + 0.3, and the metric of the true (DT, KT) as a dashed line.
fig,ax = plt.subplots(1,5,figsize=(22.5,3))
for i in range(5):
    plt.sca(ax[i])
    viol_plot(Mets[:,:,i].T,WLSFit,)
    viol_plot(MetsSBI[:,:,i].T,SBIFit,widths=0.3,positions=[1.3,2.3,3.3,4.3],)
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.axhline(DKIMetrics(np.squeeze(DT),np.squeeze(KT),False)[i],lw=3,ls='--',c='k')
    plt.yticks(fontsize=32)

# The three legend entries are spread over panels 0, 1 and 3.
legend_elements = [
    mpatches.Patch(facecolor=WLSFit, edgecolor='k', label='NLLS')
]
ax[0].legend(handles=legend_elements, ncols=1,loc=1,bbox_to_anchor=(0.8,1),fontsize=32,columnspacing=0.5,handlelength=0.8,handletextpad=0.3)
legend_elements = [
    mpatches.Patch(facecolor=SBIFit, edgecolor='k', label='SBI')
]
ax[1].legend(handles=legend_elements, ncols=1,loc=1,bbox_to_anchor=(0.8,1),fontsize=32,columnspacing=0.5,handlelength=0.8,handletextpad=0.3)
legend_elements = [
    Line2D([0], [0], color='k', lw=3, ls='--', label='True value')
]
ax[3].legend(handles=legend_elements, ncols=1,loc=1,bbox_to_anchor=(0.95,1.1),fontsize=32,columnspacing=0.5,handlelength=0.8,handletextpad=0.3)

if Save: plt.savefig(FigLoc+'EgSigMetricsFull.pdf',format='pdf',bbox_inches='tight',transparent=True)

## d

In [None]:
# Seeded DKI test set: 40 (DT, KT) pairs requested from each of the five families, each
# simulated at every level in NoiseLevels. Note this replaces the DTI `Samples` array.
torch.manual_seed(1)
np.random.seed(1)
DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],1,40)
DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],1,40)
DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],1,40)
DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],1,40)
# NOTE(review): the "hak" family uses third argument 12 (the training-set value) while the
# others use 1 — confirm this is intended.
DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,40)

SampsDT = np.vstack([DT1,DT2,DT3,DT4,DT5])
SampsKT = np.vstack([KT1,KT2,KT3,KT4,KT5])

# Samples[i, k] is the signal of pair i at NoiseLevels[k].
Samples  = []

for Sd,Sk in zip(SampsDT,SampsKT):
    Samples.append([CustomDKISimulator(Sd,Sk,gtabSim, S0=200,snr=scale) for scale in NoiseLevels])

Samples = np.array(Samples)

In [None]:
torch.manual_seed(10)
# Samples is indexed [tensor pair, noise level, measurement]; loop bounds are derived from it
# rather than hard-coded (200 pairs, 5 levels) so they follow the test set.
n_examples, n_levels = Samples.shape[0], Samples.shape[1]

# SBI: posterior mean of InferSamples draws, split into DT (first six) and KT (rest).
ErrorFull = []
for k in tqdm.tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_examples):
        tObs = Samples[i,k,:]
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)

        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorFull.append(ErrorN2)

# NLLS baseline: dipy's DKI fit on exactly the same noisy signals.
Error_s = []
dkimodel = dki.DiffusionKurtosisModel(gtabSim,fit_method='NLLS')

for k in tqdm.tqdm(range(n_levels)):
    ErrorN2 = []
    for i in range(n_examples):
        tObs = Samples[i,k,:]
        tenfit = dkimodel.fit(tObs)

        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_s.append(ErrorN2)


In [None]:
# Box plots of the DKI errors per metric (5 panels): NLLS at x = 1..4, SBI at x + 0.3,
# noise level 0 excluded. Arrays are indexed [noise level, tensor pair, metric].
ErrorFull = np.array(ErrorFull)
Error_s = np.array(Error_s)
# NOTE(review): ErrorNames is not used below (no panel titles are set) — presumably meant for
# plt.title; confirm.
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']
fig,ax = plt.subplots(1,5,figsize=(22.5,3))
for i in range(5):
    plt.sca(ax[i])
    box_plot(Error_s[1:,:,i],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,showmeans=False,widths=0.3)
    box_plot(ErrorFull[1:,:,i],SBIFit-0.2, np.clip(SBIFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.3,2.3,3.3,4.3])
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)

if Save: plt.savefig(FigLoc+'ErrorsFull.pdf',format='pdf',bbox_inches='tight',transparent=True)

## e

In [None]:
# Minimal two-shell scheme (22 measurements): one b0, 6 dispersed directions at b = 1000 and
# 15 dispersed directions at b = 3000, all seeded from dipy's 'small_64D' directions.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial15 = HemiSphere(xyz=bvecs[1:16])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
hsph_updated15,_ = disperse_charges(hsph_initial15,5000)
hsph_updated7,_ = disperse_charges(hsph_initial7,5000)
# b-value list must line up with the stacked vectors: [b0] + 6 x 1000 + 15 x 3000.
gtabSimSub = gradient_table(np.array([0]+[1000]*6+[3000]*15).squeeze(), np.vstack([[0,0,0],hsph_updated7.vertices,hsph_updated15.vertices]))

In [None]:
# Load the cached DKI posterior for the reduced 22-measurement scheme
# (gtabSimSub), or train it with NPE if no cache exists.
#
# Fix: the cell used to look for "7SampSim.pickle", which appears to be the name
# of a 7-measurement network. It then saved the freshly trained network as
# "DKISimMin.pickle", which was never looked for. Loading and saving now use
# the same file.
import sys

dki_min_cache = f"{network_path}/DKISimMin.pickle"
if os.path.exists(dki_min_cache):
    with open(dki_min_cache, "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    # Draw (DT, KT) training pairs from five tissue families. GenDTKT arguments
    # are presumably (tensor stats, kurtosis stats, seed, number of draws) —
    # TODO confirm against its definition.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

    # Simulate every pair. Any signal with a value above 800 or below 0 is
    # rejected: its KT is halved, its DT redrawn, and only the rejected rows
    # are re-simulated, until every signal lies in range.
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSimSub.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm.tqdm(indx):
            # S0 = 200; the 5th argument (snr in keyword calls elsewhere) is
            # drawn uniformly from [0, 30).
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSimSub,200,np.random.rand()*30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): one DT draw (seeded by kk) is broadcast to every rejected
        # row — confirm that sharing a single tensor is intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    # Audible notification; the `say` command only exists on macOS.
    if sys.platform == "darwin":
        os.system('say "Network done."')
    with open(dki_min_cache, "wb") as handle:
        pickle.dump(posteriorFull, handle)

In [None]:
# Draw one example voxel (DT, KT) from a randomly chosen tissue family and
# simulate it twice. tObs is a noisy measurement on the reduced scheme
# (S0=200, 5th argument 20 — presumably the SNR). tTrue is a reference on the
# full scheme (5th argument None — presumably noise-free).
torch.manual_seed(2)
np.random.seed(2)
j = 1
# vL / vS are not used in this cell.
vL = torch.tensor([0.2*j])
vS = torch.tensor([0.01*j])  

# kk is drawn from {0, 1, 2, 3}. The "hak" family that appears in the training
# cell is not one of the options here.
kk = np.random.randint(0,4)
if(kk==0):
    DT,KT = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],2,1)
elif(kk==1):
    DT,KT = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],2,1)
elif(kk==2):
    DT,KT = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],2,1)
elif(kk==3):
    DT,KT = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],2,1)

tObs = CustomDKISimulator(np.squeeze(DT),np.squeeze(KT),gtabSimSub,200,20)
tTrue = CustomDKISimulator(np.squeeze(DT),np.squeeze(KT),gtabSim,200,None)

In [None]:
# Seeded posterior draws for the example measurement tObs.
torch.manual_seed(1)
np.random.seed(1)
posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=True)

In [None]:
# Predict the full-scheme signal from the SBI posterior mean (columns 0-5 as the
# diffusion tensor, the rest as the kurtosis tensor). Plot it as a dashed line
# over the reference signal tTrue.
GuessDKI = posterior_samples_1.mean(axis=0)
GuessSig = CustomDKISimulator(GuessDKI[:6],GuessDKI[6:],gtabSim,200)
plt.subplots(figsize=(6,1))
plt.plot(tTrue,lw=2,c='k',label='true signal')
plt.plot(GuessSig,lw=2,c=SBIFit,ls='--',label='SBI signal')
plt.axis('off')
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
# Grey bands over measurement indices [0, 7] and [64, 79].
plt.fill_betweenx(np.arange(0,500,50),0*np.ones(10),7*np.ones(10),color='gray',alpha=0.5)
plt.fill_betweenx(np.arange(0,500,50),64*np.ones(10),79*np.ones(10),color='gray',alpha=0.5)
# Fixed y-limits, shared with the NLLS panel below.
plt.ylim(-9.996985449425491, 209.99985644997255)

if Save: plt.savefig(FigLoc+'7ReconSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)

## f

In [None]:
# Same panel for a conventional NLLS DKI fit of the reduced-scheme measurement;
# the fitted model predicts the full-scheme signal with S0 = 200. The file name
# says "WLS" although fit_method is 'NLLS'.
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')
tenfit = dkimodel.fit(tObs)
plt.subplots(figsize=(6,1))
plt.plot(tTrue,lw=2,c='k',label='true signal')
plt.plot(tenfit.predict(gtabSim,200),lw=2,c=WLSFit,ls='--',label='NLLS signal')
plt.axis('off')
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
plt.fill_betweenx(np.arange(0,500,50),0*np.ones(10),7*np.ones(10),color='gray',alpha=0.5)
plt.fill_betweenx(np.arange(0,500,50),64*np.ones(10),79*np.ones(10),color='gray',alpha=0.5)
plt.ylim(-9.996985449425491, 209.99985644997255)
if Save: plt.savefig(FigLoc+'7ReconWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)

## g

In [None]:
# Repeat the example-voxel experiment 50 times at each SNR (20, 10, 5, 2) on the
# reduced scheme. For each noisy draw, record the DKI metrics of the NLLS fit
# (Mets) and of the SBI posterior mean (MetsSBI). Both end up as one list per
# SNR holding 50 metric vectors.
torch.manual_seed(1)
np.random.seed(1)
# The NLLS model depends only on the gradient table, so build it once instead
# of once per repetition (200 identical constructions previously).
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')
Mets = []
MetsSBI = []
for i in tqdm.tqdm([20,10,5,2]):
    m = []
    m2 = []
    for k in range(50):
        tObs = CustomDKISimulator(np.squeeze(DT), np.squeeze(KT),gtabSimSub, S0=200, snr=i)
        tenfit = dkimodel.fit(tObs)
        m.append(DKIMetrics(tenfit.lower_triangular(),tenfit.kt,False))
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        GuessDKI = posterior_samples_1.mean(axis=0)
        # Columns 0-5: diffusion tensor, remaining columns: kurtosis tensor.
        m2.append(DKIMetrics(GuessDKI[:6],GuessDKI[6:],False))
    Mets.append(m)
    MetsSBI.append(m2)

In [None]:
# Violin plots of each metric over the 50 repetitions per SNR: NLLS at the
# default positions, SBI at +0.3. The dashed line marks the metric of the true
# (DT, KT).
Mets = np.array(Mets)
MetsSBI = np.array(MetsSBI)
fig,ax = plt.subplots(1,5,figsize=(22.5,3))
for i in range(5):
    plt.sca(ax[i])
    viol_plot(Mets[:,:,i].T,WLSFit,)
    viol_plot(MetsSBI[:,:,i].T,SBIFit,widths=0.3,positions=[1.3,2.3,3.3,4.3],)
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.axhline(DKIMetrics(np.squeeze(DT),np.squeeze(KT),False)[i],lw=3,ls='--',c='k')
    plt.yticks(fontsize=32)
if Save: plt.savefig(FigLoc+'EgSigMetrics7.pdf',format='pdf',bbox_inches='tight',transparent=True)

## h

In [None]:
# Test set for the reduced scheme: simulate every test (DT, KT) pair on
# gtabSimSub at each noise level in NoiseLevels.
torch.manual_seed(1)
np.random.seed(1)

Samples7  = []

for Sd,Sk in zip(SampsDT,SampsKT):
    Samples7.append([CustomDKISimulator(Sd,Sk,gtabSimSub, S0=200,snr=scale) for scale in NoiseLevels])

# Indexed as Samples7[test_pair, noise_level, :].
Samples7 = np.array(Samples7)

In [None]:
# Accuracy on the reduced scheme. This is the same procedure as for the full
# scheme: the SBI posterior mean first, then an NLLS DKI fit on gtabSimSub,
# for the first 200 test pairs at each of 5 noise levels.
torch.manual_seed(10)
ErrorFull = []
for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []  # not used below
    for i in range(200):
        tObs = Samples7[i,k,:]
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        GuessSBI = posterior_samples_1.mean(axis=0)
        
        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorFull.append(ErrorN2)

Error_s = []
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')

for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []  # not used below
    for i in range(200):
        tObs = Samples7[i,k,:]#Simulator(bvals,bvecs,200,params,Noise)
        tenfit = dkimodel.fit(tObs)
        
        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_s.append(ErrorN2)



In [None]:
# Error box plots for the reduced scheme (layout as for the full scheme).
# NOTE(review): the full-scheme version passes showmeans=False for the NLLS
# boxes and this one does not — confirm which styling is intended.
ErrorFull = np.array(ErrorFull)
Error_s = np.array(Error_s)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']  # not used in this cell
fig,ax = plt.subplots(1,5,figsize=(22.5,3))
for i in range(5):
    plt.sca(ax[i])
    box_plot(Error_s[1:,:,i],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    box_plot(ErrorFull[1:,:,i],SBIFit-0.2, np.clip(SBIFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.3,2.3,3.3,4.3])
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Errors7.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Fig 4

In [None]:
# Output directory for the Fig. 4 panels. exist_ok=True makes this idempotent
# and avoids the check-then-create race of os.path.exists + os.makedirs.
FigLoc = image_path + 'Fig_4/'
os.makedirs(FigLoc, exist_ok=True)

## a

In [None]:
# Load subject 1's b=1000 HCP data, reslice it to 2.5 mm voxels and compute a
# brain mask (autocropped). Then pick a 7-volume subset of the acquisition.
fdwi = './HCP_data/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
# NOTE(review): the input voxel size is hard-coded as 1.5 mm; it could be read
# from img.header.get_zooms()[:3] instead.
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): the slice index comes from the uncropped volume but is used to
# index the autocropped maskdata below. The multi-subject cell later computes
# it from maskdata.shape instead — confirm which slice is intended.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)

# Greedy farthest-point selection of gradient directions, starting at index 1
# (arbitrary starting point, e.g., the first gradient)
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
# Each pass adds the remaining direction whose minimum Euclidean distance to
# the directions already chosen is largest.
for _ in range(5):  # 5 more points; with index 1 and index 0 (prepended below) that gives 7
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))
    
    # Calculate the minimum distance to the selected points for each remaining point
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)
    
    # Select the point with the maximum minimum distance
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

# NOTE(review): index 0 is not excluded from the search above, so if it were
# picked it would appear twice after this prepend — confirm it cannot be.
selected_indices = [0]+selected_indices

bvalsHCP7 = bvalsHCP[selected_indices]
bvecsHCP7 = bvecsHCP[selected_indices]
gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)

In [None]:
# Prior over DTI parameters, S0 and a noise parameter bounded to [0, 30].
# NOTE(review): the training cells below sample from priorS0, not from
# priorS0Noise — confirm whether this prior is still needed.
custom_prior = DTIPriorS0Noise(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0,0,30)
priorS0Noise, *_ = process_prior(custom_prior) 

In [None]:
# Load the cached DTI + S0 posterior for the full HCP b=1000 scheme, or train it.
if os.path.exists(f"{network_path}/DTIHCPFull.pickle"):
    with open(f"{network_path}/DTIHCPFull.pickle", "rb") as handle:
        posterior2 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    bvals = gtabHCP.bvals
    bvecs = gtabHCP.bvecs
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        # Last prior component is passed to the simulator as its 3rd argument
        # (presumably S0); the rest parameterise the tensor.
        dt = ComputeDTI(params[:-1])
        # In roughly 20% of draws, replace the tensor by a low-FA version.
        if(np.random.rand()<0.2):
            dt = ForceLowFA(dt)
        Obs.append(CustomSimulator(dt,gtabHCP,params[-1],np.random.rand()*30))
        # Fix: label each signal with the tensor that was actually simulated.
        # Recomputing ComputeDTI(params) discarded the ForceLowFA change, so
        # low-FA signals were paired with the original tensor.
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior2 = inference.build_posterior(density_estimator)
    # The cache file is known not to exist in this branch.
    with open(f"{network_path}/DTIHCPFull.pickle", "wb") as handle:
        pickle.dump(posterior2, handle)

In [None]:
# SBI estimate for every voxel of the chosen axial slice. For each of the 7
# outputs of posterior2 (6 tensor values + the last prior component), store
# histogram_mode over InferSamples posterior draws. Arrays are sized from the
# slice shape instead of the previously hard-coded 55 x 64.
ArrShape = maskdata[:,:,axial_middle,0].shape
NoiseEst = np.zeros(ArrShape + (7,))
VarEst   = np.zeros(ArrShape)  # allocated but not filled here
for i in tqdm.tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Seeding every voxel (as before) keeps the global RNG state after the
        # loop unchanged.
        torch.manual_seed(10)
        if np.sum(maskdata[i,j,axial_middle,:]) == 0:
            continue  # background voxel: estimate stays zero
        # Reseed so each voxel's estimate is reproducible on its own.
        np.random.seed(1)
        torch.manual_seed(1)
        posterior_samples_1 = posterior2.sample((InferSamples,), x=maskdata[i,j,axial_middle,:],show_progress_bars=False)
        NoiseEst[i,j] = np.array([histogram_mode(p) for p in posterior_samples_1.T])


In [None]:
# Post-process the SBI estimates voxel by voxel:
# - pass the tensor part through clip_negative_eigenvalues (keeping the last
#   column unchanged);
# - compute MD (mean eigenvalue) and FA from the eigenvalues.
# The grid size comes from NoiseEst instead of the hard-coded 55 x 64.
n_x, n_y = NoiseEst.shape[:2]
NoiseEst2 =  np.zeros_like(NoiseEst)
MD_SBIFull = np.zeros([n_x, n_y])
FA_SBIFull = np.zeros([n_x, n_y])
for i in range(n_x):
    for j in range(n_y):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
        MD_SBIFull[i,j] = np.mean(Eigs)
        FA_SBIFull[i,j] = FracAni(Eigs,np.mean(Eigs))
# FracAni yields NaN for some voxels (presumably all-zero background tensors).
FA_SBIFull[np.isnan(FA_SBIFull)] = 0

In [None]:
# Reference NLLS tensor fit of the same slice on the full HCP scheme.
tenmodel = dti.TensorModel(gtabHCP,return_S0_hat = True,fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:,:,axial_middle])
FAFull = dti.fractional_anisotropy(tenfit.evals)
MDFull = dti.mean_diffusivity(tenfit.evals)

In [None]:
# Zero the NLLS FA in background voxels (all volumes summing to zero), as is
# done for the SBI map. This is vectorised and not tied to a 55 x 64 slice.
background = np.sum(maskdata[:,:,axial_middle,:], axis=-1) == 0
FAFull[background] = 0

In [None]:
# SBI MD map. Its automatic colour limits are stored in vmin/vmax so the NLLS MD
# map in the next cell uses the same grey scale.
img = plt.imshow(MD_SBIFull.T,cmap='gray')
plt.axis('off')
vmin, vmax = img.get_clim()

if Save: plt.savefig(FigLoc+'HCP_SBI_MD.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# NLLS MD map on the SBI map's grey scale, with a scientific-notation colourbar.
plt.imshow(MDFull.T,cmap='gray',vmin=vmin, vmax=vmax)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))

if Save: plt.savefig(FigLoc+'HCP_WLS_MD.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Signed MD difference map (NLLS - SBI), shown with a diverging colormap
# centred on 0; the colourbar is ticked at min, 0 and max.
# The names are distinct so that `data` (the DWI volume) and `norm`
# (scipy.stats.norm, imported at the top of the file) are not overwritten.
MD_Diff = MDFull.T-MD_SBIFull.T
MD_DiffNorm = TwoSlopeNorm(vmin=np.nanmin(MD_Diff), vcenter=0, vmax=np.nanmax(MD_Diff))
plt.imshow(MD_Diff,cmap='seismic',norm=MD_DiffNorm)
plt.axis('off')
cbar = plt.colorbar()
ticks = [np.nanmin(MD_Diff), 0, np.nanmax(MD_Diff)]
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))

if Save: plt.savefig(FigLoc+'HCP_MD_Diff.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# SBI vs NLLS MD agreement:
# - mean absolute difference inside the mask;
# - SSIM;
# - median of 2|a-b| / (|a|+|b|). Outside the mask this is 0/0 = NaN, which
#   nanmedian skips.
# NOTE(review): here the SSIM uses the unmasked maps, while the 7-volume
# comparisons mask first — confirm which is intended.
print(np.mean(np.abs(MD_SBIFull*mask[:,:,axial_middle]-MDFull*mask[:,:,axial_middle])))
ssim_noise = ssim(MD_SBIFull, MDFull, data_range=np.max([MD_SBIFull.max(),MDFull.max()])-np.min([MD_SBIFull.min(),MDFull.min()]))
print(ssim_noise)
Num = 2*np.abs(MD_SBIFull*mask[:,:,axial_middle]-MDFull*mask[:,:,axial_middle])
Den = np.abs(MDFull*mask[:,:,axial_middle])+np.abs(MD_SBIFull*mask[:,:,axial_middle])
print(np.nanmedian(Num/Den))

## b

In [None]:
# Load the cached DTI + S0 posterior for the 7-volume HCP subset (gtabHCP7), or
# train it.
if os.path.exists(f"{network_path}/DTIHCPMin.pickle"):
    with open(f"{network_path}/DTIHCPMin.pickle", "rb") as handle:
        posterior7_2 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        dt = ComputeDTI(params[:-1])
        # In roughly 20% of draws, replace the tensor by a low-FA version.
        if(np.random.rand()<0.2):
            dt = ForceLowFA(dt)
        Obs.append(CustomSimulator(dt,gtabHCP7,params[-1],np.random.rand()*30)[:7])
        # Fix: label with the simulated tensor, including any ForceLowFA change
        # (ComputeDTI(params) discarded it).
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    # NOTE(review): unlike the other networks, no stop_after_epochs=100 is
    # passed here (sbi's default applies) — confirm this is intended.
    density_estimator = inference.train()
    posterior7_2= inference.build_posterior(density_estimator)
    # The cache file is known not to exist in this branch.
    with open(f"{network_path}/DTIHCPMin.pickle", "wb") as handle:
        pickle.dump(posterior7_2, handle)

In [None]:
# SBI estimate for every voxel of the slice with the 7-volume network, which is
# fed only the selected volumes. Arrays are sized from the slice shape instead
# of the hard-coded 55 x 64.
ArrShape = maskdata[:,:,axial_middle,0].shape
NoiseEst = np.zeros(ArrShape + (7,))
for i in tqdm.tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Reseed for every voxel (including background) exactly as before. Later
        # unseeded sampling cells depend on the RNG state this leaves behind.
        torch.manual_seed(10)
        if np.sum(maskdata[i,j,axial_middle,:]) == 0:
            continue  # background voxel: estimate stays zero
        posterior_samples_1 = posterior7_2.sample((InferSamples,), x=maskdata[i,j,axial_middle,selected_indices],show_progress_bars=False)
        NoiseEst[i,j] = np.array([histogram_mode(p) for p in posterior_samples_1.T])
            


In [None]:
# Same post-processing as for the full-scheme SBI estimates:
# - pass the tensor part through clip_negative_eigenvalues;
# - compute MD and FA from the eigenvalues.
# The grid size comes from NoiseEst instead of the hard-coded 55 x 64.
n_x, n_y = NoiseEst.shape[:2]
NoiseEst2 =  np.zeros_like(NoiseEst)
MD_SBI7 = np.zeros([n_x, n_y])
FA_SBI7 = np.zeros([n_x, n_y])
for i in range(n_x):
    for j in range(n_y):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
        Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
        MD_SBI7[i,j] = np.mean(Eigs)
        FA_SBI7[i,j] = FracAni(Eigs,np.mean(Eigs))
# FracAni yields NaN for some voxels (presumably all-zero background tensors).
FA_SBI7[np.isnan(FA_SBI7)] = 0

In [None]:
# NLLS tensor fit of the slice using only the 7 selected volumes.
tenmodel = dti.TensorModel(gtabHCP7,return_S0_hat = True,fit_method='NLLS')
tenfit = tenmodel.fit(maskdata[:,:,axial_middle,selected_indices])
FA7 = dti.fractional_anisotropy(tenfit.evals)
MD7 = dti.mean_diffusivity(tenfit.evals)

In [None]:
# Zero the 7-volume NLLS FA in background voxels. The background test uses all
# volumes, giving the same mask as for the full-scheme map.
background = np.sum(maskdata[:,:,axial_middle,:], axis=-1) == 0
FA7[background] = 0

In [None]:
# SBI MD map from 7 volumes. Its colour limits are stored in vmin/vmax for the
# matching NLLS panel in the next cell.
img = plt.imshow(MD_SBI7.T,cmap='gray')
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()

if Save: plt.savefig(FigLoc+'HCP_SBI_MD_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# NLLS MD map from 7 volumes, on the SBI-7 grey scale.
plt.imshow(MD7.T,cmap='gray',vmin=vmin, vmax=vmax)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'HCP_WLS_MD_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Absolute MD error maps of the 7-volume estimates against their full-data
# counterparts: SBI-7 vs SBI-full, then NLLS-7 vs NLLS-full.
# Both panels use the colour scale of the SBI error map (apparently so that the
# single colourbar on the second panel applies to both). The norm is therefore
# built once and reused.
# The names are distinct so that `data` (the DWI volume) and `norm`
# (scipy.stats.norm) are not overwritten.
# NOTE(review): MD_SBIFullNan / MD_SBI7Nan are not defined in this part of the
# notebook; presumably they are NaN-masked copies of MD_SBIFull / MD_SBI7.
MD_Err = np.abs(MD_SBIFullNan.T-MD_SBI7Nan.T)
MD_ErrNorm = TwoSlopeNorm(vmin=0, vcenter=np.max(MD_Err)/2, vmax=np.max(MD_Err))
ticks = [0, np.round(np.max(MD_Err),3)]
plt.imshow(MD_Err,cmap='Reds',norm=MD_ErrNorm)
plt.axis('off')
if Save: plt.savefig(FigLoc+'DTI_MDSBIErr.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

MD_Err = np.abs(MDFull.T-MD7.T)
plt.imshow(MD_Err,cmap='Reds',norm=MD_ErrNorm)

plt.axis('off')
cbar = plt.colorbar()

cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'DTI_MDWLSErr.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# SBI-7 vs SBI-full MD inside the mask: mean absolute difference, SSIM (data
# range from the full map) and median symmetric relative difference.
print(np.mean(np.abs(MD_SBIFull*mask[:,:,axial_middle]-MD_SBI7*mask[:,:,axial_middle])))
ssim_noise = ssim(MD_SBIFull*mask[:,:,axial_middle], MD_SBI7*mask[:,:,axial_middle], data_range=MD_SBIFull.max()-MD_SBIFull.min())
print(ssim_noise)
Num = 2*np.abs(MD_SBIFull*mask[:,:,axial_middle]-MD_SBI7*mask[:,:,axial_middle])
Den = np.abs(MD_SBI7*mask[:,:,axial_middle])+np.abs(MD_SBIFull*mask[:,:,axial_middle])
print(np.nanmedian(Num/Den))

In [None]:
# NLLS-7 vs NLLS-full MD inside the mask, using the same three statistics.
print(np.mean(np.abs(MDFull*mask[:,:,axial_middle]-MD7*mask[:,:,axial_middle])))
ssim_noise = ssim(MDFull*mask[:,:,axial_middle], MD7*mask[:,:,axial_middle], data_range=MDFull.max()-MDFull.min())
print(ssim_noise)
Num = 2*np.abs(MDFull*mask[:,:,axial_middle]-MD7*mask[:,:,axial_middle])
Den = np.abs(MD7*mask[:,:,axial_middle])+np.abs(MDFull*mask[:,:,axial_middle])
print(np.nanmedian(Num/Den))

## c

In [None]:
# SBI FA map; its colour limits are stored in vmin/vmax.
img = plt.imshow(FA_SBIFull.T,cmap='gray')
plt.axis('off')
vmin, vmax = img.get_clim()
if Save: plt.savefig(FigLoc+'HCP_SBI_FA.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# NLLS FA map on the SBI FA map's grey scale. The limits computed in the
# previous cell were previously ignored, unlike for every other SBI/NLLS panel
# pair. Colourbar labels are written as '{:.2e}'.
plt.imshow(FAFull.T,cmap='gray',vmin=vmin, vmax=vmax)
plt.axis('off')
cbar = plt.colorbar()
tick_labels = ['{:.2e}'.format(t) for t in cbar.get_ticks()]
cbar.set_ticks(cbar.get_ticks())
cbar.set_ticklabels(tick_labels)
cbar.update_ticks()
if Save: plt.savefig(FigLoc+'HCP_WLS_FA.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Signed FA difference map (NLLS - SBI) on a fixed [-1, 1] diverging scale.
# NOTE(review): this rebinds `data`, which earlier held the DWI volume.
data = FAFull.T-FA_SBIFull.T
plt.imshow(data,cmap='seismic',vmin=-1, vmax=1)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
ticks = [-1, 0, 1]  # Adjust the number of ticks as needed
cbar.set_ticks(ticks)
if Save: plt.savefig(FigLoc+'HCP_FA_Diff.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# SBI vs NLLS FA agreement: masked mean absolute difference, SSIM (computed on
# the unmasked maps) and median symmetric relative difference.
print(np.mean(np.abs(FA_SBIFull*mask[:,:,axial_middle]-FAFull*mask[:,:,axial_middle])))
ssim_noise = ssim(FA_SBIFull, FAFull, data_range=np.max([FA_SBIFull.max(),FAFull.max()])-np.min([FA_SBIFull.min(),FAFull.min()]))
print(ssim_noise)
Num = 2*np.abs(FA_SBIFull*mask[:,:,axial_middle]-FAFull*mask[:,:,axial_middle])
Den = np.abs(FAFull*mask[:,:,axial_middle])+np.abs(FA_SBIFull*mask[:,:,axial_middle])
print(np.nanmedian(Num/Den))

## d

In [None]:
# SBI FA map from 7 volumes; its colour limits are stored in vmin/vmax.
img = plt.imshow(FA_SBI7.T,cmap='gray')
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()
if Save: plt.savefig(FigLoc+'HCP_SBI_FA_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# NOTE(review): repeats the previous cell's plot and vmin/vmax computation, with
# saving commented out — likely redundant.
img = plt.imshow(FA_SBI7.T,cmap='gray')
plt.axis('off')
#cbar = plt.colorbar()
#cbar.formatter.set_powerlimits((0, 0))
vmin, vmax = img.get_clim()
#if Save: plt.savefig(FigLoc+'HCP_SBI_FA_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# NLLS FA map from 7 volumes, on the SBI-7 grey scale.
plt.imshow(FA7.T,cmap='gray',vmin=vmin, vmax=vmax)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'HCP_WLS_FA_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# Absolute FA error maps of the 7-volume estimates against their full-data
# counterparts: SBI-7 vs SBI-full, then NLLS-7 vs NLLS-full.
# Both panels use the colour scale of the SBI error map (apparently so that the
# colourbar on the second panel applies to both), so the norm is built once.
# TwoSlopeNorm is already imported at the top of the file. The names are
# distinct so that `data` (the DWI volume) and `norm` (scipy.stats.norm) are
# not overwritten.
FA_Err = np.abs(FA_SBIFull.T-FA_SBI7.T)
FA_ErrNorm = TwoSlopeNorm(vmin=0, vcenter=np.max(FA_Err)/2, vmax=np.max(FA_Err))
ticks = [0, np.max(FA_Err)]
plt.imshow(FA_Err,cmap='Reds',norm=FA_ErrNorm)
plt.axis('off')
if Save: plt.savefig(FigLoc+'DTI_FASBIErr.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

FA_Err = np.abs(FAFull.T-FA7.T)
plt.imshow(FA_Err,cmap='Reds',norm=FA_ErrNorm)

plt.axis('off')
cbar = plt.colorbar()

cbar.set_ticks(ticks)
if Save: plt.savefig(FigLoc+'DTI_FAWLSErr.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# SBI-7 vs SBI-full FA inside the mask: mean absolute difference, SSIM and
# median symmetric relative difference.
print(np.mean(np.abs(FA_SBIFull*mask[:,:,axial_middle]-FA_SBI7*mask[:,:,axial_middle])))
ssim_noise = ssim(FA_SBIFull*mask[:,:,axial_middle], FA_SBI7*mask[:,:,axial_middle], data_range=FA_SBIFull.max()-FA_SBIFull.min())
print(ssim_noise)
Num = 2*np.abs(FA_SBIFull*mask[:,:,axial_middle]-FA_SBI7*mask[:,:,axial_middle])
Den = np.abs(FA_SBI7*mask[:,:,axial_middle])+np.abs(FA_SBIFull*mask[:,:,axial_middle])
print(np.nanmedian(Num/Den))

In [None]:
# NLLS-7 vs NLLS-full FA inside the mask, using the same three statistics.
print(np.mean(np.abs(FAFull*mask[:,:,axial_middle]-FA7*mask[:,:,axial_middle])))
ssim_noise = ssim(FAFull*mask[:,:,axial_middle], FA7*mask[:,:,axial_middle], data_range=FAFull.max()-FAFull.min())
print(ssim_noise)
Num = 2*np.abs(FAFull*mask[:,:,axial_middle]-FA7*mask[:,:,axial_middle])
Den = np.abs(FA7*mask[:,:,axial_middle])+np.abs(FAFull*mask[:,:,axial_middle])
print(np.nanmedian(Num/Den))

## e

In [None]:
# Prior over DTI parameters, S0 and a gradient-table selector. The training
# cells use its last component, cast to int, to index the per-subject tables.
custom_prior = DTIPriorS0Direc(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0)
priorDirec, *_ = process_prior(custom_prior) 

In [None]:
def _farthest_point_indices(vecs, n_points, start=0):
    """Greedily choose `n_points` row indices of `vecs` that are spread far apart.

    Starting from `start`, each step adds the not-yet-chosen row whose minimum
    Euclidean distance to the rows already chosen is largest. Ties go to the
    first such row in the iteration order of the remaining set.

    Returns the chosen indices as a list, in selection order.
    """
    distance_matrix = squareform(pdist(vecs))
    chosen = [start]
    for _ in range(n_points - 1):
        remaining = list(set(range(len(vecs))) - set(chosen))
        min_distances = np.min(distance_matrix[remaining][:, chosen], axis=1)
        chosen.append(remaining[np.argmax(min_distances)])
    return chosen


# For HCP subjects 1-5, load and reslice the b=1000 data and mask it. Keep the
# in-mask, non-negative voxels of the middle axial slice (FullDat) and build
# the full, 20- and 7-volume gradient tables with their volume indices.
gTabsF = []
gTabs20 = []
gTabs7 = []

Indices20 = []
Indices7  = []
FullDat   = []
for i in tqdm.tqdm(range(1,6)):
    fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    gTabsF.append(gtabHCP)

    data, affine, img = load_nifti(fdwi, return_img=True)
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # median_otsu was previously run twice with identical arguments; one call
    # gives the same result.
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    # autocrop=True changes the array shape, so take the middle slice of the
    # cropped volume.
    axial_middle = maskdata.shape[2] // 2
    TestData = maskdata[:, :, axial_middle, :]
    # Flatten to (voxels, volumes); the volume count is taken from the data
    # rather than hard-coded as 69. Keep voxels with a positive total signal
    # and no negative values.
    FlatTD = TestData.reshape(-1, TestData.shape[-1])
    FlatTD = FlatTD[FlatTD.sum(axis=-1)>0]
    FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]
    FullDat.append(FlatTD)

    # 7-volume subset: start at volume 0 and add 6 spread-out directions.
    selected_indices = _farthest_point_indices(bvecsHCP, 7)
    bvalsHCP7 = bvalsHCP[selected_indices]
    bvecsHCP7 = bvecsHCP[selected_indices]
    gtabHCP7 = gradient_table(bvalsHCP7, bvecsHCP7)
    gTabs7.append(gtabHCP7)
    Indices7.append(selected_indices)

    # 20-volume subset: start at volume 0 and add 19 spread-out directions.
    selected_indices = _farthest_point_indices(bvecsHCP, 20)
    bvalsHCP20 = bvalsHCP[selected_indices]
    bvecsHCP20 = bvecsHCP[selected_indices]
    gtabHCP20 = gradient_table(bvalsHCP20, bvecsHCP20)
    gTabs20.append(gtabHCP20)
    Indices20.append(selected_indices)

In [None]:
# Randomly permute the volume order of each subject's full, 20- and 7-volume
# gradient tables (seeded). The permutations are kept so the measured data can
# be reordered to match the permuted tables.
gTabsPermF = []
gTabsPerm20 = []
gTabsPerm7 = []
Perms     = []
Perms20   = []
Perms7    = []
# np.random.seed returns None; it was previously (misleadingly) assigned to x.
np.random.seed(116)
for gtabF, gtab20, gtab7 in tqdm.tqdm(zip(gTabsF, gTabs20, gTabs7), total=len(gTabsF)):
    # Draw order (full, 20, 7) is unchanged, so the permutations match earlier
    # runs. Sizes come from each table rather than the hard-coded 69 / 20 / 7.
    x = np.random.permutation(np.arange(len(gtabF.bvals)))
    x20 = np.random.permutation(np.arange(len(gtab20.bvals)))
    # NOTE(review): originally tagged "Got to fix" without explanation — confirm intent.
    x7 = np.random.permutation(np.arange(len(gtab7.bvals)))
    Perms.append(x)
    Perms20.append(x20)
    Perms7.append(x7)
    gTabsPermF.append(gradient_table(gtabF.bvals[x], gtabF.bvecs[x]))
    gTabsPerm20.append(gradient_table(gtab20.bvals[x20], gtab20.bvecs[x20]))
    gTabsPerm7.append(gradient_table(gtab7.bvals[x7], gtab7.bvecs[x7]))

In [None]:
# Load the cached multi-subject DTI posterior for the full (permuted) HCP
# schemes, or train it. Each draw simulates on the subject table selected by
# the last prior component.
if os.path.exists(f"{network_path}/DTIMultiHCPFull.pickle"):
    with open(f"{network_path}/DTIMultiHCPFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        # params[-2] goes to the simulator's 3rd argument (presumably S0);
        # params[-1] selects the gradient table.
        dt = ComputeDTI(params[:-2])
        # In roughly 20% of draws, replace the tensor by a low-FA version.
        if(np.random.rand()<0.2):
            dt = ForceLowFA(dt)
        cG = gTabsPermF[int(params[-1])]
        Obs.append(CustomSimulator(dt,cG,params[-2],np.random.rand()*30))
        # Fix: label with the simulated tensor, including any ForceLowFA change
        # (ComputeDTI(params) discarded it).
        Par.append(np.hstack([mat_to_vals(dt),params[-2],params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorDirec)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    # The cache file is known not to exist in this branch.
    with open(f"{network_path}/DTIMultiHCPFull.pickle", "wb") as handle:
        pickle.dump(posteriorFull, handle)

In [None]:
# Load the cached multi-subject DTI posterior for the permuted 20-volume
# subsets, or train it.
if os.path.exists(f"{network_path}/DTIMultiHCPMid.pickle"):
    with open(f"{network_path}/DTIMultiHCPMid.pickle", "rb") as handle:
        posterior20 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        dt = ComputeDTI(params[:-2])
        # In roughly 20% of draws, replace the tensor by a low-FA version.
        if(np.random.rand()<0.2):
            dt = ForceLowFA(dt)
        cG = gTabsPerm20[int(params[-1])]
        Obs.append(CustomSimulator(dt,cG,params[-2],np.random.rand()*30)[:20])
        # Fix: label with the simulated tensor, including any ForceLowFA change
        # (ComputeDTI(params) discarded it).
        Par.append(np.hstack([mat_to_vals(dt),params[-2],params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorDirec)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior20 = inference.build_posterior(density_estimator)
    # The cache file is known not to exist in this branch.
    with open(f"{network_path}/DTIMultiHCPMid.pickle", "wb") as handle:
        pickle.dump(posterior20, handle)

In [None]:
# Load the cached multi-subject DTI posterior for the permuted 7-volume
# subsets, or train it.
if os.path.exists(f"{network_path}/DTIMultiHCPMin.pickle"):
    with open(f"{network_path}/DTIMultiHCPMin.pickle", "rb") as handle:
        posterior7 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorDirec.sample()
        dt = ComputeDTI(params[:-2])
        # In roughly 20% of draws, replace the tensor by a low-FA version.
        if(np.random.rand()<0.2):
            dt = ForceLowFA(dt)
        cG = gTabsPerm7[int(params[-1])]
        Obs.append(CustomSimulator(dt,cG,params[-2],np.random.rand()*30)[:7])
        # Fix: label with the simulated tensor, including any ForceLowFA change
        # (ComputeDTI(params) discarded it).
        Par.append(np.hstack([mat_to_vals(dt),params[-2],params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)

    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorDirec)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior7 = inference.build_posterior(density_estimator)
    # The cache file is known not to exist in this branch.
    with open(f"{network_path}/DTIMultiHCPMin.pickle", "wb") as handle:
        pickle.dump(posterior7, handle)

In [None]:
# Draw 200 distinct voxels per subject (seeded) and assemble the test signals:
# - SubDatUnPerm: unpermuted full data;
# - SubDatFull / SubDat20 / SubDat7: the full, 20- and 7-volume columns,
#   reordered by each subject's stored permutation so they line up with the
#   permuted gradient tables.
np.random.seed(1)
RandChoices = []
for F in FullDat:
    RandChoices.append(np.random.choice(len(F),200,replace=False))

SubDatUnPerm = np.vstack([F[R] for F,R in zip(FullDat,RandChoices)])
SubDatFull = np.vstack([F[R][:,P] for F,P,R in zip(FullDat,Perms,RandChoices)])
SubDat20 = np.vstack([F[R][:,np.array(I)[P]] for F,I,P,R in zip(FullDat,Indices20,Perms20,RandChoices)])
SubDat7 = np.vstack([F[R][:,np.array(I)[P]] for F,I,P,R in zip(FullDat,Indices7,Perms7,RandChoices)])

In [None]:
# SBI posterior-mean estimate for every test voxel with each network. The *_2
# arrays hold the tensor part after clip_negative_eigenvalues.
# NOTE(review): torch is not reseeded here, so results depend on the RNG state
# left by earlier cells.
SBIGuessFull = []
for S in tqdm.tqdm(SubDatFull):
    SBIGuessFull.append(posteriorFull.sample((InferSamples,), x=S,show_progress_bars=False).mean(axis=0))
SBIGuessFull = np.vstack(SBIGuessFull)
SBIGuessFull_2 = np.vstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(S))) for S in SBIGuessFull])

SBIGuess20 = []
for S in tqdm.tqdm(SubDat20):
    SBIGuess20.append(posterior20.sample((InferSamples,), x=S,show_progress_bars=False).mean(axis=0))
SBIGuess20 = np.vstack(SBIGuess20)
SBIGuess20_2 = np.vstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(S))) for S in SBIGuess20])

SBIGuess7 = []
for S in tqdm.tqdm(SubDat7):
    SBIGuess7.append(posterior7.sample((InferSamples,), x=S,show_progress_bars=False).mean(axis=0))
SBIGuess7 = np.vstack(SBIGuess7)
SBIGuess7_2 = np.vstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(S))) for S in SBIGuess7])

In [None]:
# Conventional tensor fits of the same test voxels, one subject (200 rows) at a
# time. Despite the "WLS" names, fit_method is 'NLLS'.
# - The full-scheme fit pairs the unpermuted data with the original tables.
# - The 20/7-volume fits pair the permuted data with the matching permuted
#   tables.
# quadratic_form gives one 3x3 tensor per voxel, converted with mat_to_vals.
WLSGuess = []
for i,gtab in enumerate(gTabsF):
    tenmodel = dti.TensorModel(gtab,return_S0_hat = True,fit_method='NLLS')
    tenfit = tenmodel.fit(SubDatUnPerm[i*200:(i+1)*200])
    WLSGuess.append([mat_to_vals(t) for t in tenfit.quadratic_form])
WLSGuess = np.vstack(WLSGuess)

WLSGuess20 = []
for i,gtab in enumerate(gTabsPerm20):
    tenmodel = dti.TensorModel(gtab,return_S0_hat = True,fit_method='NLLS')
    tenfit = tenmodel.fit(SubDat20[i*200:(i+1)*200])
    WLSGuess20.append([mat_to_vals(t) for t in tenfit.quadratic_form])
WLSGuess20 = np.vstack(WLSGuess20)

WLSGuess7 = []
for i,gtab in enumerate(gTabsPerm7):
    tenmodel = dti.TensorModel(gtab,return_S0_hat = True,fit_method='NLLS')
    tenfit = tenmodel.fit(SubDat7[i*200:(i+1)*200])
    WLSGuess7.append([mat_to_vals(t) for t in tenfit.quadratic_form])
WLSGuess7 = np.vstack(WLSGuess7)

In [None]:
# ErrorsMDFA of the 20-/7-direction SBI estimates and of the full-protocol
# NLLS estimate, all measured against the full-protocol SBI estimate.
Errors20 = np.vstack([ErrorsMDFA(vals_to_mat(est), vals_to_mat(ref))
                      for est, ref in zip(SBIGuess20, SBIGuessFull)])
Errors7 = np.vstack([ErrorsMDFA(vals_to_mat(est), vals_to_mat(ref))
                     for est, ref in zip(SBIGuess7, SBIGuessFull)])
ErrorsFullW = np.vstack([ErrorsMDFA(vals_to_mat(est), vals_to_mat(ref))
                         for est, ref in zip(WLSGuess, SBIGuessFull)])

In [None]:
# PercsMDFA counterparts of the errors above (reference: full-protocol SBI).
Percs20 = np.vstack([PercsMDFA(vals_to_mat(est), vals_to_mat(ref))
                     for est, ref in zip(SBIGuess20, SBIGuessFull)])
Percs7 = np.vstack([PercsMDFA(vals_to_mat(est), vals_to_mat(ref))
                    for est, ref in zip(SBIGuess7, SBIGuessFull)])
PercsFullW = np.vstack([PercsMDFA(vals_to_mat(est), vals_to_mat(ref))
                        for est, ref in zip(WLSGuess, SBIGuessFull)])

In [None]:
# PercsMDFA of the reduced-protocol NLLS fits against the full-protocol NLLS fit.
Percs20W_WLS = np.vstack([PercsMDFA(vals_to_mat(est), vals_to_mat(ref))
                          for est, ref in zip(WLSGuess20, WLSGuess)])
Percs7W_WLS = np.vstack([PercsMDFA(vals_to_mat(est), vals_to_mat(ref))
                         for est, ref in zip(WLSGuess7, WLSGuess)])

In [None]:
# ErrorsMDFA of the reduced-protocol NLLS fits against the full-protocol NLLS fit.
Errors20W_WLS = np.vstack([ErrorsMDFA(vals_to_mat(est), vals_to_mat(ref))
                           for est, ref in zip(WLSGuess20, WLSGuess)])
Errors7W_WLS = np.vstack([ErrorsMDFA(vals_to_mat(est), vals_to_mat(ref))
                          for est, ref in zip(WLSGuess7, WLSGuess)])

In [None]:
# Metric column i = 0: ErrorsMDFA (first figure) and PercsMDFA (second figure)
# box plots. Grey/black box: full-protocol NLLS scored against full SBI;
# plain boxes: reduced SBI vs full SBI; '/'-hatched boxes: reduced NLLS vs
# full NLLS.
i=0
fig,ax = plt.subplots()
box_plot(ErrorsFullW[:,i],'black', 'gray',positions=[0],showfliers=False,widths=0.3,)
box_plot(np.array([Errors20[:,i],Errors7[:,i]]),SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[0.7,1],
         showfliers=False,widths=0.3)
box_plot(np.array([Errors20W_WLS[:,i],Errors7W_WLS[:,i]]),WLSFit-0.2, np.clip(WLSFit+0.2,0,1),'/',positions=[1.8,2.1],
         showfliers=False,widths=0.3)

plt.xticks([0,0.7,1,1.8,2.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)

# Legend depends on which metric column is plotted.
if i == 0:
    handles = [
        mpatches.Patch(facecolor=np.clip(SBIFit+0.2,0,1),edgecolor=SBIFit-0.2, label='Truth: SBI-full'),
    ]
    plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.15,1.2),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
elif i == 1:
    handles = [
        mpatches.Patch(facecolor=np.clip(WLSFit+0.2,0,1),edgecolor=WLSFit-0.2,hatch='/', label='Truth: NLLS-full'),
    ]
    plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.15,1.2),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
if Save: plt.savefig(FigLoc+'ErrorsHCP_'+str(i)+'.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

fig,ax = plt.subplots()
box_plot(PercsFullW[:,i],'black', 'gray',positions=[0],showfliers=False,widths=0.3,)
box_plot(np.array([Percs20[:,i],Percs7[:,i]]),SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[0.7,1],
         showfliers=False,widths=0.3)
box_plot(np.array([Percs20W_WLS[:,i],Percs7W_WLS[:,i]]),WLSFit-0.2, np.clip(WLSFit+0.2,0,1),'/',positions=[1.8,2.1],
         showfliers=False,widths=0.3)

plt.xticks([0,0.7,1,1.8,2.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)

# The percentage figure carries the NLLS legend. An SBI legend used to be
# drawn first for i == 0, but plt.legend() replaces an axes' existing legend,
# so that one never reached the output and has been dropped.
handles = [
    mpatches.Patch(facecolor=np.clip(WLSFit+0.2,0,1),edgecolor=WLSFit-0.2,hatch='/', label='Truth: NLLS-full'),
]
plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.15,1.2),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.ylim([0,1])
# Shade and mark the 0.33 level.
plt.axhspan(0,0.33,color='gray',alpha=0.25)
plt.axhline(0.33,ls='--',color='k')
if Save: plt.savefig(FigLoc+'PercsHCP_'+str(i)+'.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

## f


In [None]:
# Metric column i = 1: ErrorsMDFA (first figure) and PercsMDFA (second figure)
# box plots, laid out like the i = 0 cell.
i=1
fig,ax = plt.subplots()
box_plot(ErrorsFullW[:,i],'black', 'gray',positions=[0],showfliers=False,widths=0.3,)
box_plot(np.array([Errors20[:,i],Errors7[:,i]]),SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[0.7,1],
         showfliers=False,widths=0.3)
box_plot(np.array([Errors20W_WLS[:,i],Errors7W_WLS[:,i]]),WLSFit-0.2, np.clip(WLSFit+0.2,0,1),'/',positions=[1.8,2.1],
         showfliers=False,widths=0.3)

plt.xticks([0,0.7,1,1.8,2.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)

handles = [
    mpatches.Patch(facecolor=np.clip(SBIFit+0.2,0,1),edgecolor=SBIFit-0.2, label='Truth: SBI-full'),
]
plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.15,1.2),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)

plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
if Save: plt.savefig(FigLoc+'ErrorsHCP_'+str(i)+'.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

fig,ax = plt.subplots()
box_plot(PercsFullW[:,i],'black', 'gray',positions=[0],showfliers=False,widths=0.3,)
box_plot(np.array([Percs20[:,i],Percs7[:,i]]),SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[0.7,1],
         showfliers=False,widths=0.3)
box_plot(np.array([Percs20W_WLS[:,i],Percs7W_WLS[:,i]]),WLSFit-0.2, np.clip(WLSFit+0.2,0,1),'/',positions=[1.8,2.1],
         showfliers=False,widths=0.3)

plt.xticks([0,0.7,1,1.8,2.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)

# NLLS legend only. The former `if(i == 0)` SBI legend could not run here
# (i is 1), and it would have been replaced by this call anyway.
handles = [
    mpatches.Patch(facecolor=np.clip(WLSFit+0.2,0,1),edgecolor=WLSFit-0.2,hatch='/', label='Truth: NLLS-full'),
]
plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.15,1.2),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.ylim([0,1])
# Shade and mark the 0.33 level.
plt.axhspan(0,0.33,color='gray',alpha=0.25)
plt.axhline(0.33,ls='--',color='k')
if Save: plt.savefig(FigLoc+'PercsHCP_'+str(i)+'.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# Fig 5

In [None]:
InferSamples = 100  # posterior draws per voxel for the Fig. 5 estimates

In [None]:
# Output directory for the Fig. 5 panels. exist_ok=True replaces the separate
# existence check, which was redundant and racy.
FigLoc = image_path + 'Fig_5/'
os.makedirs(FigLoc, exist_ok=True)

## a

In [None]:
# Subject 3: load the b=1000 and b=3000 HCP acquisitions, reslice both from
# 1.5 mm to 2.5 mm voxels, brain-mask the 1k data with median_otsu and crop
# the volumes to the bounding box of that mask.
i=3
fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'

fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
bvalloc3 = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc3 = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

# Per-shell gradient tables, plus a combined table (1k volumes first) that
# matches the concatenated data built at the end of this cell.
bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# Identical median_otsu settings twice: the uncropped `mask` defines the
# bounding box below; the autocropped `mask2` masks the displayed maps.
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                             numpass=1, autocrop=False, dilate=2)
_, mask2 = median_otsu(data, vol_idx=range(10, 50), median_radius=4,
                       numpass=1, autocrop=True, dilate=2)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))

# Inclusive bounding box of the brain mask, turned into one slice per axis.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)
bbox = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

maskdata  = maskdata[bbox]
axial_middle = maskdata.shape[2] // 2
maskdata3 = data3[bbox]

# Mid-axial slice and full cropped volume, 1k volumes followed by 3k volumes.
TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
TestData4D = np.concatenate([maskdata,maskdata3],axis=-1)

In [None]:
# Load the DKI posterior for the full two-shell protocol (gtabExt), or
# simulate a training set and train one with SNPE when no pickle exists.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIHCPFull.pickle"):
    # NOTE(review): unpickling runs arbitrary code; only load networks this
    # notebook produced.
    with open(f"{network_path}/DKIHCPFull.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    # (The seeds set just above are still fresh here, so no re-seed is needed.)
    # Diffusion (DT) and kurtosis (KT) tensors from five prior families.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(4*3000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(4*1000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(4*3000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(4*3000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(4*3000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    # One S0 per simulated tensor pair. Previously hard-coded as 4*13000, which
    # had to match len(KT) by hand for the hstack below to work.
    S0 = S0Dist.sample([len(KT)])
    S0 = np.array(S0).reshape(len(S0),1)

    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabExt.bvecs)])
    kk = 0
    # Rejection loop: any signal that is negative or above 4*S0 is simulated
    # again after halving its KT. All rejected rows also get one DT from a
    # fresh GenDTKT draw of the `_full` family.
    while len(indx)>0:
        for i in tqdm.tqdm(indx):
            # last argument presumably the SNR, drawn from [30, 50) -- TODO confirm
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabExt,S0[i],np.random.rand()*20 + 30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    if not os.path.exists(f"{network_path}/DKIHCPFull.pickle"):
        with open(f"{network_path}/DKIHCPFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

# Posterior-mode estimate of every parameter for each voxel of the mid-axial
# slice, using the full-protocol DKI posterior.
ArrShape = TestData4D[:,:,axial_middle,0].shape
# 22 entries per voxel: the DT, KT and S0 columns of the training Par (later
# sliced as [:6], [6:21], [21:]). Sized from ArrShape rather than a hard-coded
# 62x68, so it always matches the loop bounds below.
NoiseEst = np.zeros((*ArrShape, 22))
torch.manual_seed(10)
for i in tqdm.tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Skip voxels whose first 69 volumes (the 1k shell) sum to zero, i.e.
        # voxels zeroed by the median_otsu masking.
        if np.sum(TestData4D[i,j,axial_middle,:69],axis=-1) == 0:
            continue
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=TestData4D[i,j,axial_middle,:],show_progress_bars=False)
        NoiseEst[i,j] = np.array([histogram_mode(p) for p in posterior_samples_1.T])

In [None]:
# Eigenvalue-clip the tensor part of each full-protocol SBI estimate; the
# remaining entries (index 6 onward) are copied unchanged.
NoiseEst2 = np.zeros_like(NoiseEst)
for i in range(62):
    for j in range(68):
        tensor_part = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j])))
        NoiseEst2[i, j] = np.hstack([tensor_part, NoiseEst[i, j, 6:]])

In [None]:
# Five DKIMetrics maps (indices 0-4 -> MK, AK, RK, MKT, KFA) for the clipped
# full-protocol SBI estimates.
MK_SBIFull, AK_SBIFull, RK_SBIFull, MKT_SBIFull, KFA_SBIFull = (np.zeros([62, 68]) for _ in range(5))
sbi_full_targets = (MK_SBIFull, AK_SBIFull, RK_SBIFull, MKT_SBIFull, KFA_SBIFull)
for i in tqdm.tqdm(range(62)):
    for j in range(68):
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        for k, target in enumerate(sbi_full_targets):
            target[i, j] = Metrics[k]

In [None]:
# NLLS DKI fit of the full two-shell mid slice and its five DKIMetrics maps
# (indices 0-4 -> MK, AK, RK, MKT, KFA).
dkimodelNL = dki.DiffusionKurtosisModel(gtabExt,fit_method='NLLS')
dkifitNL = dkimodelNL.fit(TestData[:,:,:])
MK_NLFull, AK_NLFull, RK_NLFull, MKT_NLFull, KFA_NLFull = (np.zeros([62, 68]) for _ in range(5))
nl_full_targets = (MK_NLFull, AK_NLFull, RK_NLFull, MKT_NLFull, KFA_NLFull)
for i in range(62):
    for j in range(68):
        voxel_fit = dkifitNL[i,j]
        Metrics = DKIMetrics(voxel_fit.lower_triangular(), voxel_fit.kt)
        for k, target in enumerate(nl_full_targets):
            target[i, j] = Metrics[k]

In [None]:
# Voxels where KFA came out NaN are set to 1.
nan_cells = np.isnan(KFA_SBIFull)
KFA_SBIFull[nan_cells] = 1

In [None]:
# Masked, transposed mid-slice maps of the full-protocol SBI metrics on a
# shared 0..1 grey scale; only the KFA panel gets a colorbar.
for metric_map, fname, add_colorbar in (
    (MK_SBIFull, 'MKSBIFull.pdf', False),
    (AK_SBIFull, 'AKSBIFull.pdf', False),
    (RK_SBIFull, 'RKSBIFull.pdf', False),
    (MKT_SBIFull, 'MKTSBIFull.pdf', False),
    (KFA_SBIFull, 'KFASBIFull.pdf', True),
):
    tnorm = TwoSlopeNorm(vmin=-0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map * mask2[:, :, axial_middle]).T, norm=tnorm, cmap='gray')
    if add_colorbar:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(FigLoc + fname, format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

## b

In [None]:
# Masked, transposed mid-slice maps of the full-protocol NLLS metrics on a
# shared 0..1 grey scale; only the KFA panel gets a colorbar.
for metric_map, fname, add_colorbar in (
    (MK_NLFull, 'MKNLFull.pdf', False),
    (AK_NLFull, 'AKNLFull.pdf', False),
    (RK_NLFull, 'RKNLFull.pdf', False),
    (MKT_NLFull, 'MKTNLFull.pdf', False),
    (KFA_NLFull, 'KFANLFull.pdf', True),
):
    tnorm = TwoSlopeNorm(vmin=-0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map * mask2[:, :, axial_middle]).T, norm=tnorm, cmap='gray')
    if add_colorbar:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(FigLoc + fname, format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

In [None]:
# Full-protocol NLLS (S) vs SBI (N), per metric. Prints three numbers:
#  - mean |S-N| over the masked maps (zeros outside the mask are included);
#  - SSIM, with a data range spanning both maps;
#  - mean of 2|S-N|/(|S|+|N|), with 0/0 NaNs outside the mask ignored.
mid_mask = mask2[:,:,axial_middle]
for S, N in zip([MK_NLFull,AK_NLFull,RK_NLFull,MKT_NLFull,KFA_NLFull],
                [MK_SBIFull,AK_SBIFull,RK_SBIFull,MKT_SBIFull,KFA_SBIFull]):
    S_masked = S * mid_mask
    N_masked = N * mid_mask
    print('==')
    print(np.mean(np.abs(S_masked - N_masked)))
    lo = np.min([S.min(), N.min()])
    hi = np.max([S.max(), N.max()])
    ssim_noise = ssim(S, N, data_range=hi - lo)
    print(ssim_noise)
    Num = 2 * np.abs(S_masked - N_masked)
    Den = np.abs(N_masked) + np.abs(S_masked)
    print(np.nanmean(Num / Den))
    print('==')

## c

In [None]:
# Reduced two-shell protocol: 7 volumes from the b=1000 shell (volume 0 plus
# 6 greedily spread directions) and 15 greedily spread b=3000 directions.
i=3
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# Greedy farthest-point selection over the 1k gradients, seeded with index 1.
selected_indices = [1]
distance_matrix = squareform(pdist(bvecsHCP))
# 5 more picks -> 6 directions; volume 0 (presumably a b0) is prepended below.
for _ in range(5):
    remaining_indices = list(set(range(len(bvecsHCP))) - set(selected_indices))

    # Distance from each remaining gradient to its nearest selected gradient
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Keep the gradient farthest from the current selection
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

selected_indices7 = [0]+selected_indices

bvalsHCP7_1 = bvalsHCP[selected_indices7]
bvecsHCP7_1 = bvecsHCP[selected_indices7]

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Bug fix: this previously rebuilt gtabHCP3 from the 1k arrays (bvalsHCP, bvecsHCP).
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

# Greedy selection over the diffusion-weighted (b > 0) 3k gradients only.
selected_indices = [0]

dw3 = bvalsHCP3>0
temp_bvecs = bvecsHCP3[dw3]
temp_bvals = bvalsHCP3[dw3]
distance_matrix = squareform(pdist(temp_bvecs))
# 14 more picks -> 15 directions in total
for _ in range(14):
    remaining_indices = list(set(range(len(temp_bvecs))) - set(selected_indices))

    # Distance from each remaining gradient to its nearest selected gradient
    min_distances = np.min(distance_matrix[remaining_indices][:, selected_indices], axis=1)

    # Keep the gradient farthest from the current selection
    next_index = remaining_indices[np.argmax(min_distances)]
    selected_indices.append(next_index)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

# Column indices of the chosen volumes in the concatenated [1k | 3k] data.
# The 3k picks map back to bvecsHCP3 rows exactly through the b > 0 mask,
# which replaces the nearest-bvec lookup. They are offset by the number of 1k
# volumes, previously hard-coded as 69.
n_vols_1k = len(bvalsHCP)
true_indx = selected_indices7 + [t + n_vols_1k for t in np.flatnonzero(dw3)[selected_indices]]
gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

In [None]:
# Load the DKI posterior for the reduced protocol (gtabHCP7), or simulate a
# training set and train one with SNPE when no pickle exists.
# NOTE(review): the network is bound to `posteriorFull`, replacing the
# full-protocol network loaded earlier; the inference loop below reads it
# under that name.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIHCPMin.pickle"):
    # NOTE(review): unpickling runs arbitrary code; only load networks this
    # notebook produced.
    with open(f"{network_path}/DKIHCPMin.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    # Diffusion (DT) and kurtosis (KT) tensors from five prior families.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(4*30000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(4*10000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(4*30000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(4*30000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(4*30000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    # One S0 per simulated tensor pair. Previously hard-coded as 520000, which
    # had to match len(KT) by hand for the hstack below to work.
    S0 = S0Dist.sample([len(KT)])
    S0 = np.array(S0).reshape(len(S0),1)

    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabHCP7.bvecs)])
    kk = 0
    # Rejection loop: any signal that is negative or above 4*S0 is simulated
    # again after halving its KT. All rejected rows also get one DT from a
    # fresh GenDTKT draw of the `_full` family.
    while len(indx)>0:
        for i in tqdm.tqdm(indx):
            # last argument presumably the SNR, drawn from [30, 50) -- TODO confirm
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabHCP7,S0[i],np.random.rand()*20 + 30)

        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    if not os.path.exists(f"{network_path}/DKIHCPMin.pickle"):
        with open(f"{network_path}/DKIHCPMin.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)
# Posterior-mode estimates on the mid-axial slice from only the reduced
# protocol's volumes (true_indx), using the reduced-protocol posterior.
ArrShape = TestData4D[:,:,axial_middle,0].shape
# 22 entries per voxel (DT, KT, S0 columns of the training Par), sized from
# ArrShape so it always matches the loop bounds.
NoiseEst7 = np.zeros((*ArrShape, 22))
torch.manual_seed(10)
for i in tqdm.tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Skip voxels whose first 69 volumes (the 1k shell) sum to zero, i.e.
        # voxels zeroed by the median_otsu masking.
        if np.sum(TestData4D[i,j,axial_middle,:69],axis=-1) == 0:
            continue
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=TestData4D[i,j,axial_middle,true_indx],show_progress_bars=False)
        NoiseEst7[i,j] = np.array([histogram_mode(p) for p in posterior_samples_1.T])


In [None]:
# Eigenvalue-clip the tensor part of each reduced-protocol SBI estimate; the
# remaining entries (index 6 onward) are copied unchanged. The output is sized
# from NoiseEst7, the array actually read here, instead of NoiseEst.
NoiseEst2 =  np.zeros_like(NoiseEst7)
for i in range(62):
    for j in range(68):
        NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst7[i,j]))),NoiseEst7[i,j,6:]])

In [None]:
# Five DKIMetrics maps (indices 0-4 -> MK, AK, RK, MKT, KFA) for the clipped
# reduced-protocol SBI estimates.
MK_SBI7, AK_SBI7, RK_SBI7, MKT_SBI7, KFA_SBI7 = (np.zeros([62, 68]) for _ in range(5))
sbi7_targets = (MK_SBI7, AK_SBI7, RK_SBI7, MKT_SBI7, KFA_SBI7)
for i in tqdm.tqdm(range(62)):
    for j in range(68):
        Metrics = DKIMetrics(NoiseEst2[i,j][:6],NoiseEst2[i,j][6:21])
        for k, target in enumerate(sbi7_targets):
            target[i, j] = Metrics[k]

In [None]:
# NLLS DKI fit using only the reduced-protocol volumes, and its five
# DKIMetrics maps (indices 0-4 -> MK, AK, RK, MKT, KFA).
dkimodelNL = dki.DiffusionKurtosisModel(gtabHCP7,fit_method='NLLS')
dkifitNL = dkimodelNL.fit(TestData[:,:,true_indx])
MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7 = (np.zeros([62, 68]) for _ in range(5))
nl7_targets = (MK_NL7, AK_NL7, RK_NL7, MKT_NL7, KFA_NL7)
for i in range(62):
    for j in range(68):
        voxel_fit = dkifitNL[i,j]
        Metrics = DKIMetrics(voxel_fit.lower_triangular(), voxel_fit.kt)
        for k, target in enumerate(nl7_targets):
            target[i, j] = Metrics[k]

In [None]:
# Repeats the NaN -> 1 fill of KFA_SBIFull (idempotent if already applied).
nan_cells = np.isnan(KFA_SBIFull)
KFA_SBIFull[nan_cells] = 1

In [None]:
# Voxels where the reduced-protocol KFA came out NaN are set to 1.
nan_cells = np.isnan(KFA_SBI7)
KFA_SBI7[nan_cells] = 1

In [None]:
# Masked, transposed mid-slice maps of the reduced-protocol SBI metrics on a
# shared 0..1 grey scale; only the KFA panel gets a colorbar.
for metric_map, fname, add_colorbar in (
    (MK_SBI7, 'MKSBI7.pdf', False),
    (AK_SBI7, 'AKSBI7.pdf', False),
    (RK_SBI7, 'RKSBI7.pdf', False),
    (MKT_SBI7, 'MKTSBI7.pdf', False),
    (KFA_SBI7, 'KFASBI7.pdf', True),
):
    tnorm = TwoSlopeNorm(vmin=-0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map * mask2[:, :, axial_middle]).T, norm=tnorm, cmap='gray')
    if add_colorbar:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(FigLoc + fname, format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

## d

In [None]:
# Masked, transposed mid-slice maps of the reduced-protocol NLLS metrics on a
# shared 0..1 grey scale; only the KFA panel gets a colorbar.
for metric_map, fname, add_colorbar in (
    (MK_NL7, 'MKNL7.pdf', False),
    (AK_NL7, 'AKNL7.pdf', False),
    (RK_NL7, 'RKNL7.pdf', False),
    (MKT_NL7, 'MKTNL7.pdf', False),
    (KFA_NL7, 'KFANL7.pdf', True),
):
    tnorm = TwoSlopeNorm(vmin=-0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map * mask2[:, :, axial_middle]).T, norm=tnorm, cmap='gray')
    if add_colorbar:
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(FigLoc + fname, format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

## e

In [None]:
# |SBI full - SBI reduced| maps for each DKI metric (masked mid slice).
# The locals are named diff_map / cnorm, not data / norm. That way this cell
# neither overwrites the loaded DWI array `data` nor shadows scipy.stats.norm,
# which is imported at the top of the file.
ticks = [0,1,2]
diff_map = np.abs((MK_SBIFull-MK_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=1, vmax=2)
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar(ticks=ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'MKDiffSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# AK: colour scale runs from 0 to the largest difference in the map.
diff_map = np.abs((AK_SBIFull-AK_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=np.max(diff_map)/2, vmax=np.max(diff_map))
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar(ticks=ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'AKDiffSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

diff_map = np.abs((RK_SBIFull-RK_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=1,vmax=2)
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'RKDiffSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

diff_map = np.abs((MKT_SBIFull-MKT_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=1, vmax=2)
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'MKTDiffSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# KFA: data-driven scale, NaN-aware.
diff_map = np.abs((KFA_SBIFull-KFA_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=np.nanmax(diff_map)/2, vmax=np.nanmax(diff_map))
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'KFADiffSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

## f

In [None]:
# |NLLS full - NLLS reduced| maps for each DKI metric (masked mid slice).
# The AK and KFA panels reuse the colour scale of the matching SBI difference
# map (`dat`), so both methods share a range. The locals diff_map / cnorm
# avoid overwriting the DWI array `data` and shadowing scipy.stats.norm.
diff_map = np.abs((MK_NLFull-MK_NL7)*mask2[:,:,axial_middle]).T
plt.imshow(diff_map,cmap='Reds',vmin=0,vmax=2)
plt.axis('off')
cbar = plt.colorbar(ticks=[0,1,2])
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'MKDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

dat = np.abs((AK_SBIFull-AK_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=np.max(dat)/2, vmax=np.max(dat))
diff_map = np.abs((AK_NLFull-AK_NL7)*mask2[:,:,axial_middle]).T
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'AKDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# RK and MKT share a fixed 0..2 scale.
cnorm = TwoSlopeNorm(vmin=0, vcenter=1,vmax=2)
diff_map = np.abs((RK_NLFull-RK_NL7)*mask2[:,:,axial_middle]).T
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'RKDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

diff_map = np.abs((MKT_NLFull-MKT_NL7)*mask2[:,:,axial_middle]).T
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'MKTDiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

dat = np.abs((KFA_SBIFull-KFA_SBI7)*mask2[:,:,axial_middle]).T
cnorm = TwoSlopeNorm(vmin=0, vcenter=np.nanmax(dat)/2, vmax=np.nanmax(dat))
diff_map = np.abs((KFA_NLFull-KFA_NL7)*mask2[:,:,axial_middle]).T
plt.imshow(diff_map,cmap='Reds',norm=cnorm)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'KFADiffWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)
# Show and release the figure like every other panel, so later plotting calls
# do not draw onto it.
plt.show()

In [None]:
# Agreement between minimal-protocol ("7") and full-protocol SBI metric maps.
# For each metric this prints, between '==' separators:
#   1. mean absolute difference over the masked slice,
#   2. SSIM of the unmasked maps (data range taken from the full-protocol map),
#   3. NaN-ignoring mean symmetric relative difference 2|a-b| / (|a|+|b|).
slice_mask = mask2[:,:,axial_middle]
for est, ref in zip([MK_SBI7,AK_SBI7,RK_SBI7,MKT_SBI7,KFA_SBI7],
                    [MK_SBIFull,AK_SBIFull,RK_SBIFull,MKT_SBIFull,KFA_SBIFull]):
    est_m = est*slice_mask
    ref_m = ref*slice_mask
    abs_diff = np.abs(est_m-ref_m)
    print('==')
    print(np.mean(abs_diff))
    print(ssim(est, ref, data_range=ref.max()-ref.min()))
    # Voxels where both maps are zero give 0/0 = NaN and are skipped by nanmean.
    print(np.nanmean(2*abs_diff/(np.abs(ref_m)+np.abs(est_m))))
    print('==')

In [None]:
# Agreement between minimal-protocol ("7") and full-protocol NLLS metric maps.
# For each metric this prints, between '==' separators: mean absolute
# difference over the masked slice, SSIM of the unmasked maps, and the
# NaN-ignoring mean symmetric relative difference 2|a-b| / (|a|+|b|).
slice_mask = mask2[:,:,axial_middle]
for est, ref in zip([MK_NL7,AK_NL7,RK_NL7,MKT_NL7,KFA_NL7],
                    [MK_NLFull,AK_NLFull,RK_NLFull,MKT_NLFull,KFA_NLFull]):
    est_m = est*slice_mask
    ref_m = ref*slice_mask
    abs_diff = np.abs(est_m-ref_m)
    print('==')
    print(np.mean(abs_diff))
    print(ssim(est, ref, data_range=ref.max()-ref.min()))
    print(np.nanmean(2*abs_diff/(np.abs(ref_m)+np.abs(est_m))))
    print('==')

## g

In [None]:
def _farthest_point_indices(vecs, start, n_extra):
    """Greedy farthest-point selection of gradient directions.

    Starting from the indices in ``start``, repeatedly add the row of ``vecs``
    whose smallest Euclidean distance to the already-chosen rows is largest.

    Parameters
    ----------
    vecs : array, shape (n, 3)
        Candidate gradient vectors.
    start : list of int
        Indices selected up front.
    n_extra : int
        Number of additional indices to add.

    Returns
    -------
    list of int
        ``start`` followed by the ``n_extra`` greedily chosen indices.
    """
    dist = squareform(pdist(vecs))
    chosen = list(start)
    for _ in range(n_extra):
        remaining = list(set(range(len(vecs))) - set(chosen))
        min_dist = np.min(dist[remaining][:, chosen], axis=1)
        chosen.append(remaining[np.argmax(min_dist)])
    return chosen


# Subject whose b=3000 gradient scheme is used for every subject.
# NOTE(review): the original re-assigned the loop variable (`i=3`) inside the
# loop, so the 3k directions always come from Pat3 while the 1k directions are
# per subject. Kept as-is; confirm this is intended (it matters only if the
# 3k bvecs differ between subjects).
REF_PAT_3K = 3

# Minimal ("7") protocol per subject: 7 volumes from the b=1000 shell and
# 15 from the b=3000 shell (22 in total).
TrueIndxs = []
gTabs7 = []
for i in range(1,6):
    bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)

    # b=1000 shell: seed with row 1 and add 5 more (6 directions); row 0
    # (presumably a b0 volume — TODO confirm) is prepended -> 7 volumes.
    selected_indices = _farthest_point_indices(bvecsHCP, [1], 5)
    selected_indices7 = [0]+selected_indices

    bvalsHCP7_1 = bvalsHCP[selected_indices7]
    bvecsHCP7_1 = bvecsHCP[selected_indices7]

    bvalloc = './HCP_data/Pat'+str(REF_PAT_3K)+'/bvals_3k.txt'
    bvecloc = './HCP_data/Pat'+str(REF_PAT_3K)+'/bvecs_3k.txt'
    bvalsHCP3 = np.loadtxt(bvalloc)
    bvecsHCP3 = np.loadtxt(bvecloc)

    # b=3000 shell: only b>0 volumes are candidates; seed with the first one
    # and add 14 more -> 15 directions.
    temp_bvecs = bvecsHCP3[bvalsHCP3>0]
    temp_bvals = bvalsHCP3[bvalsHCP3>0]
    selected_indices = _farthest_point_indices(temp_bvecs, [0], 14)

    bvalsHCP7_3 = temp_bvals[selected_indices]
    bvecsHCP7_3 = temp_bvecs[selected_indices]

    gtabHCP7 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

    # Map the chosen 3k directions back to volume indices of the full 3k
    # acquisition. They are offset by 69 because the 3k volumes follow the
    # 69 1k volumes in the concatenated per-voxel signal.
    true_indx = [np.linalg.norm(b-bvecsHCP3,axis=1).argmin() for b in bvecsHCP7_3]
    true_indx = selected_indices7+[t+69 for t in true_indx]
    TrueIndxs.append(true_indx)
    gTabs7.append(gtabHCP7)

In [None]:
def _farthest_point_indices(vecs, start, n_extra):
    """Greedy farthest-point selection of gradient directions.

    Starting from the indices in ``start``, repeatedly add the row of ``vecs``
    whose smallest Euclidean distance to the already-chosen rows is largest.

    Parameters
    ----------
    vecs : array, shape (n, 3)
        Candidate gradient vectors.
    start : list of int
        Indices selected up front.
    n_extra : int
        Number of additional indices to add.

    Returns
    -------
    list of int
        ``start`` followed by the ``n_extra`` greedily chosen indices.
    """
    dist = squareform(pdist(vecs))
    chosen = list(start)
    for _ in range(n_extra):
        remaining = list(set(range(len(vecs))) - set(chosen))
        min_dist = np.min(dist[remaining][:, chosen], axis=1)
        chosen.append(remaining[np.argmax(min_dist)])
    return chosen


# Subject whose b=3000 gradient scheme is used for every subject.
# NOTE(review): the original re-assigned the loop variable (`i=3`) inside the
# loop; kept as-is. Confirm this is intended.
REF_PAT_3K = 3

# Mid ("20") protocol per subject: 20 volumes at b=1000 (one b0 + 19
# directions) and 28 at b=3000 (48 in total).
TrueIndxs20 = []
gTabs20 = []
for i in range(1,6):
    bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)

    # b=1000 shell: among b>0 volumes, seed with candidate 1 and add 18 more
    # -> 19 directions; a b=0 entry with bvec [0,0,0] is prepended -> 20.
    temp_bvecs = bvecsHCP[bvalsHCP>0]
    temp_bvals = bvalsHCP[bvalsHCP>0]
    selected_indices7 = _farthest_point_indices(temp_bvecs, [1], 18)

    bvalsHCP7_1 = np.insert(temp_bvals[selected_indices7],0,0)
    bvecsHCP7_1 = np.insert(temp_bvecs[selected_indices7],0,[0,0,0],axis=0)

    bvalloc = './HCP_data/Pat'+str(REF_PAT_3K)+'/bvals_3k.txt'
    bvecloc = './HCP_data/Pat'+str(REF_PAT_3K)+'/bvecs_3k.txt'
    bvalsHCP3 = np.loadtxt(bvalloc)
    bvecsHCP3 = np.loadtxt(bvecloc)

    # b=3000 shell: seed with the first b>0 volume and add 27 more -> 28.
    temp_bvecs = bvecsHCP3[bvalsHCP3>0]
    temp_bvals = bvalsHCP3[bvalsHCP3>0]
    selected_indices = _farthest_point_indices(temp_bvecs, [0], 27)

    bvalsHCP7_3 = temp_bvals[selected_indices]
    bvecsHCP7_3 = temp_bvecs[selected_indices]

    gtabHCP20 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

    # Map every chosen bvec to the nearest volume index in its shell's full
    # acquisition (the inserted [0,0,0] maps to the nearest zero-vector row,
    # presumably the first b0 — TODO confirm). 3k indices are offset by the
    # 69 1k volumes that precede them in the concatenated signal.
    true_indx_one = [np.linalg.norm(b-bvecsHCP,axis=1).argmin() for b in bvecsHCP7_1]
    true_indx = [np.linalg.norm(b-bvecsHCP3,axis=1).argmin() for b in bvecsHCP7_3]
    true_indx = true_indx_one+[t+69 for t in true_indx]
    TrueIndxs20.append(true_indx)
    gTabs20.append(gtabHCP20)

In [None]:
# Full ("extended") gradient table per subject: every b=1000 volume followed
# by every b=3000 volume, in acquisition order.
gTabsE = []

for i in tqdm.tqdm(range(1,6)):
    stem = './HCP_data/Pat' + str(i)
    bvalloc, bvecloc = stem + '/bvals_1k.txt', stem + '/bvecs_1k.txt'
    bvalloc3, bvecloc3 = stem + '/bvals_3k.txt', stem + '/bvecs_3k.txt'

    # Per-shell tables are not used below; they are kept so the globals
    # gtabHCP / gtabHCP3 hold the same values as before.
    bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

    bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
    gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)

    gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)),
                             np.vstack((bvecsHCP, bvecsHCP3)))
    gTabsE.append(gtabExt)

In [None]:
# Fixed random volume orderings per subject and protocol. The same permutation
# is applied to the gradient tables here and to the measured signals when the
# test data are subsampled, so each subject's data and table stay aligned.
gTabsPermE = []
gTabsPerm7 = []
gTabsPerm20 = []
Perms     = []
Perms7    = []
Perms20    = []
# np.random.seed returns None; there is nothing to keep from it.
np.random.seed(116)
for i in tqdm.tqdm(range(1,6)):
    gtE, gt7, gt20 = gTabsE[i-1], gTabs7[i-1], gTabs20[i-1]
    # Permutation lengths come from the tables themselves instead of the
    # hard-coded 138 / 22 / 48, which had to be kept in sync with the protocol
    # sizes by hand. Draw order (E, 7, 20) is unchanged, so the RNG stream is too.
    x = np.random.permutation(np.arange(len(gtE.bvals)))
    x7 = np.random.permutation(np.arange(len(gt7.bvals)))
    x20 = np.random.permutation(np.arange(len(gt20.bvals)))
    Perms.append(x)
    Perms7.append(x7)
    Perms20.append(x20)
    gTabsPermE.append(gradient_table(gtE.bvals[x], gtE.bvecs[x]))
    gTabsPerm7.append(gradient_table(gt7.bvals[x7], gt7.bvecs[x7]))
    gTabsPerm20.append(gradient_table(gt20.bvals[x20], gt20.bvecs[x20]))

In [None]:
# Load a cached SBI posterior for the full protocol (gTabsPermE), or simulate a
# training set and train it with SNPE, then cache it.
# The exact order of RNG draws below determines the training set, so the
# statements are intentionally left untouched.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIMultiHCPFull.pickle"):
    with open(f"{network_path}/DKIMultiHCPFull.pickle", "rb") as handle:
        # NOTE: (dill) pickle can execute arbitrary code on load; only load trusted files.
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []
    # Draw DT/KT parameter sets from five prior families (suffixes full, lfa,
    # hfa, ulfa, hak). The last argument is the draw count (26000 in total);
    # the third (12) is presumably a seed — TODO confirm against GenDTKT.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(6000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    # One S0 per simulation, uniform in [lower_S0, upper_S0].
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    
    S0 = S0Dist.sample([26000])
    
    S0 = np.array(S0).reshape(len(S0),1)
    
    # Random subject (gradient table) index per simulation.
    A  = np.random.choice(5,26000)
    
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gTabsPermE[0].bvecs)])
    kk = 0
    # Simulate; then redo every sample whose signal exceeds 4*S0 or goes
    # negative, after halving its KT and redrawing its DT.
    while len(indx)>0:
        for i in tqdm.tqdm(indx): 
            # Fifth argument is uniform in [30, 50) — presumably the SNR.
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsPermE[A[i]],S0[i],np.random.rand()*20 + 30)
        
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): draws a single DT (count 1) and broadcasts it to every
        # rejected row — confirm this is intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    A = np.array(A).reshape(len(A),1)
    # The subject index A is deliberately left out of the parameter vector.
    Par = np.hstack([DT,KT,S0])#,A])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # pass the (already generated) simulations to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    
    # Spoken notification (presumably the macOS `say` command; harmless elsewhere).
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKIMultiHCPFull.pickle"):
        with open(f"{network_path}/DKIMultiHCPFull.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Load a cached SBI posterior for the minimal ("7") protocol (gTabsPerm7), or
# simulate a 65000-sample training set, train it with SNPE, and cache it.
# The exact order of RNG draws determines the training set, so statements
# are left untouched.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIMultiHCPMin.pickle"):
    with open(f"{network_path}/DKIMultiHCPMin.pickle", "rb") as handle:
        # NOTE: (dill) pickle can execute arbitrary code on load; only load trusted files.
        posterior7 = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []
    # Same five prior families as the full-protocol cell, 2.5x as many draws.
    # The third argument (12) is presumably a seed — TODO confirm.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    # One S0 per simulation, uniform in [lower_S0, upper_S0].
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    
    S0 = S0Dist.sample([65000])
    
    S0 = np.array(S0).reshape(len(S0),1)
    
    # Random subject (gradient table) index per simulation.
    A  = np.random.choice(5,65000)
    
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gTabsPerm7[0].bvecs)])
    kk = 0
    # Simulate; redo samples whose signal exceeds 4*S0 or is negative, after
    # halving their KT and redrawing their DT.
    while len(indx)>0:
        for i in tqdm.tqdm(indx): 
            # Fifth argument is uniform in [30, 50) — presumably the SNR.
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsPerm7[A[i]],S0[i],np.random.rand()*20 + 30)
        
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>4*np.array(S0[i])).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): a single DT draw is broadcast to all rejected rows — confirm.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    A = np.array(A).reshape(len(A),1)
    # The subject index A is deliberately left out of the parameter vector.
    Par = np.hstack([DT,KT,S0])#,A])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # pass the (already generated) simulations to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior7 = inference.build_posterior(density_estimator)
    
    # Spoken notification (presumably the macOS `say` command).
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKIMultiHCPMin.pickle"):
        with open(f"{network_path}/DKIMultiHCPMin.pickle", "wb") as handle:
            pickle.dump(posterior7, handle)

In [None]:
# Load a cached SBI posterior for the Mid ("20") protocol (gTabsPerm20), or
# simulate a training set, train it with SNPE, and cache it. This mirrors the
# Full and Min training cells.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIMultiHCPMid.pickle"):
    with open(f"{network_path}/DKIMultiHCPMid.pickle", "rb") as handle:
        # NOTE: (dill) pickle can execute arbitrary code on load; only load trusted files.
        posterior20 = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    # Five prior families (suffixes full/lfa/hfa/ulfa/hak); the last argument
    # is the draw count (65000 in total). The third argument (12) is
    # presumably a seed — TODO confirm against GenDTKT.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*6000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*2000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*6000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*6000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*6000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    n_sims = len(KT)

    # One S0 per simulation, uniform in [lower_S0, upper_S0].
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([n_sims])
    S0 = np.array(S0).reshape(len(S0),1)

    # One randomly chosen subject gradient table per simulation. The original
    # drew 650000 indices here, a typo for 65000. Only the first 65000 were
    # used, but the extra draws shifted the NumPy RNG stream relative to the
    # sibling Full/Min cells.
    A  = np.random.choice(len(gTabsPerm20), n_sims)

    indx = np.arange(n_sims)
    Obs = np.zeros([n_sims,len(gTabsPerm20[0].bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm.tqdm(indx):
            # Fifth argument is uniform in [30, 50) — presumably the SNR.
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gTabsPerm20[A[i]],S0[i],np.random.rand()*20 + 30)

        # Redo every simulation whose signal exceeds 4*S0 or goes negative
        # (same test as before, vectorised over all rows).
        bad = (Obs > 4*S0).any(axis=1) | (Obs < 0).any(axis=1)
        indxNew = np.flatnonzero(bad)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): this draws a single DT (count 1) and broadcasts it to
        # every rejected row — confirm this is intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    A = np.array(A).reshape(len(A),1)
    # The subject index A is deliberately left out of the parameter vector.
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # Pass the (already generated) simulations to the inference object.
    inference = inference.append_simulations(Par, Obs)

    # Train the density estimator and build the posterior.
    density_estimator = inference.train(stop_after_epochs= 100)
    posterior20 = inference.build_posterior(density_estimator)

    # Spoken notification (presumably the macOS `say` command).
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKIMultiHCPMid.pickle"):
        with open(f"{network_path}/DKIMultiHCPMid.pickle", "wb") as handle:
            pickle.dump(posterior20, handle)

In [None]:
# Load each subject's b=1000 and b=3000 DWI volumes, resample them, crop to the
# brain-mask bounding box, and flatten one axial slice into per-voxel signal
# rows (1k volumes first, then 3k volumes; 138 columns).
FullDat = []
for i in range(1,6):
    fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
    fdwi3 = './HCP_data/Pat'+str(i)+'/diff_3k.nii.gz'
    
    data, affine, img = load_nifti(fdwi, return_img=True)
    # Resample from 1.5 to 2.5 voxel size (affine units, presumably mm).
    data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    axial_middle = data.shape[2] // 2
    # Brain mask from the 1k data; maskdata is the masked (zero outside) volume.
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=False, dilate=2)
    
    data3, affine, img = load_nifti(fdwi3, return_img=True)
    data3, affine = reslice(data3, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
    # Get the indices of True values
    true_indices = np.argwhere(mask)
    
    # Determine the minimum and maximum indices along each dimension
    min_coords = true_indices.min(axis=0)
    max_coords = true_indices.max(axis=0)
    
    # Crop both shells to the 1k mask's bounding box. The 3k data are cropped
    # but not masked; voxels are filtered below by the 1k signal instead.
    maskdata  = maskdata[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    maskdata3 = data3[min_coords[0]:max_coords[0]+1,min_coords[1]:max_coords[1]+1,min_coords[2]:max_coords[2]+1]
    
    # NOTE(review): axial_middle was computed on the uncropped volume but is used
    # here on the z-cropped arrays, so the slice is shifted by min_coords[2].
    # Confirm this is the intended slice.
    TestData = np.concatenate([maskdata[:, :, axial_middle, :],maskdata3[:, :, axial_middle, :]],axis=-1)
    FlatTD = TestData.reshape(maskdata.shape[0]*maskdata.shape[1],138)
    # Keep voxels with nonzero 1k signal (i.e. inside the mask) and no negative values.
    FlatTD = FlatTD[FlatTD[:,:69].sum(axis=-1)>0]
    FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]
    FullDat.append(FlatTD)

# Draw 200 voxels per subject (without replacement) for evaluation.
np.random.seed(1)
RandChoices = []
for F in FullDat:
    RandChoices.append(np.random.choice(len(F),200,replace=False))
    
# Same voxels in four column layouts: original order; full protocol permuted
# by Perms; Mid/Min protocols selected by TrueIndxs20/TrueIndxs and permuted
# by Perms20/Perms7, matching the permuted gradient tables.
SubDatUnPerm = np.vstack([F[R] for F,R in zip(FullDat,RandChoices)])
SubDatFull = np.vstack([F[R][:,P] for F,P,R in zip(FullDat,Perms,RandChoices)])
SubDat20 = np.vstack([F[R][:,np.array(T)[P]] for F,T,P,R in zip(FullDat,TrueIndxs20,Perms20,RandChoices)])
SubDat7 = np.vstack([F[R][:,np.array(T)[P]] for F,T,P,R in zip(FullDat,TrueIndxs,Perms7,RandChoices)])

In [None]:
def _posterior_means(posterior, observations):
    """Stack the mean of `InferSamples` posterior draws for each observation row."""
    return np.vstack([
        posterior.sample((InferSamples,), x=obs, show_progress_bars=False).mean(axis=0)
        for obs in tqdm.tqdm(observations)
    ])

# Posterior-mean parameter estimates for the sampled voxels (full and Mid protocols).
SBIGuessFull = _posterior_means(posteriorFull, SubDatFull)
SBIGuess20 = _posterior_means(posterior20, SubDat20)

In [None]:
# Posterior-mean estimates for the minimal protocol.
SBIGuess7 = np.vstack([
    posterior7.sample((InferSamples,), x=obs, show_progress_bars=False).mean(axis=0)
    for obs in tqdm.tqdm(SubDat7)
])

def _clip_dt_part(row):
    """Clip negative eigenvalues of the 6 DT entries of `row`; keep the rest unchanged."""
    dt_vals = mat_to_vals(clip_negative_eigenvalues(vals_to_mat(row)))
    return np.hstack([dt_vals, row[6:]])

SBIGuessFull_2 = np.array([_clip_dt_part(r) for r in SBIGuessFull])
SBIGuess7_2 = np.array([_clip_dt_part(r) for r in SBIGuess7])
SBIGuess20_2 = np.array([_clip_dt_part(r) for r in SBIGuess20])

In [None]:

def _dki_nlls_params(gtabs, signals, verbose=False, n_vox=200):
    """Fit DKI with NLLS per subject and stack [DT lower-triangular, KT] rows.

    Subject k owns rows k*n_vox:(k+1)*n_vox of `signals` and is fitted with
    gradient table gtabs[k]. With `verbose`, k is printed before each fit.
    """
    rows = []
    for k, gt in enumerate(gtabs):
        if verbose:
            print(k)
        fit = dki.DiffusionKurtosisModel(gt, fit_method='NLLS').fit(signals[k*n_vox:(k+1)*n_vox])
        rows.append([np.hstack((d.lower_triangular(), d.kt)) for d in fit])
    return np.vstack(rows)

# Despite the "WLS" names these are NLLS fits: full (unpermuted), Mid and Min protocols.
WLSGuess = _dki_nlls_params(gTabsE, SubDatUnPerm, verbose=True)
WLSGuess20 = _dki_nlls_params(gTabsPerm20, SubDat20)
WLSGuess7 = _dki_nlls_params(gTabsPerm7, SubDat7)

In [None]:
def _per_voxel(metric, est, ref):
    """Apply metric(est_DT, est_KT, ref_DT, ref_KT) row by row (split at column 6) and stack."""
    return np.vstack([metric(e[:6], e[6:], r[:6], r[6:]) for e, r in zip(est, ref)])

# Reference = full-protocol SBI estimates.
Errors20 = _per_voxel(DKIErrors, SBIGuess20_2, SBIGuessFull_2)
Errors7 = _per_voxel(DKIErrors, SBIGuess7_2, SBIGuessFull_2)
ErrorsFullW = _per_voxel(DKIErrors, WLSGuess, SBIGuessFull_2)

# Reference = full-protocol NLLS estimates.
Errors20W_WLS = _per_voxel(DKIErrors, WLSGuess20, WLSGuess)
Errors7W_WLS = _per_voxel(DKIErrors, WLSGuess7, WLSGuess)

In [None]:
def _per_voxel(metric, est, ref):
    """Apply metric(est_DT, est_KT, ref_DT, ref_KT) row by row (split at column 6) and stack."""
    return np.vstack([metric(e[:6], e[6:], r[:6], r[6:]) for e, r in zip(est, ref)])

# Reference = full-protocol SBI estimates.
Percs20 = _per_voxel(Percs, SBIGuess20_2, SBIGuessFull_2)
Percs7 = _per_voxel(Percs, SBIGuess7_2, SBIGuessFull_2)
PercsFullW = _per_voxel(Percs, WLSGuess, SBIGuessFull_2)

# Reference = full-protocol NLLS estimates.
Percs20W_WLS = _per_voxel(Percs, WLSGuess20, WLSGuess)
Percs7W_WLS = _per_voxel(Percs, WLSGuess7, WLSGuess)

In [None]:
# One box-plot figure per column i of the Percs arrays (5 columns). Each figure shows:
# NLLS-full vs SBI-full (position 0), SBI Mid/Min vs SBI-full (0.7, 1), and
# NLLS Mid/Min vs NLLS-full (1.8, 2.1, hatched).
for i in range(5):
    fig,ax = plt.subplots()
    box_plot(PercsFullW[:,i],'black', 'gray',positions=[0],showfliers=False,widths=0.3,)
    box_plot(np.array([Percs20[:,i],Percs7[:,i]]),SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[0.7,1],
             showfliers=False,widths=0.3)    
    box_plot(np.array([Percs20W_WLS[:,i],Percs7W_WLS[:,i]]),WLSFit-0.2, np.clip(WLSFit+0.2,0,1),'/',positions=[1.8,2.1],
             showfliers=False,widths=0.3)
    
    plt.xticks([0,0.7,1,1.8,2.1],['Full','Mid','Min','Mid','Min'],fontsize=32,rotation=90)

    # Legends are drawn only on the first two figures, one per reference type.
    if(i == 0):
        handles = [
            mpatches.Patch(facecolor=np.clip(SBIFit+0.2,0,1),edgecolor=SBIFit-0.2, label='Truth: SBI-full'),  # Adjust color as per the actual plot color
        ]
        plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.0,0.8),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    if(i == 1):
        handles = [
                mpatches.Patch(facecolor=np.clip(WLSFit+0.2,0,1),edgecolor=WLSFit-0.2,hatch='/', label='Truth: NLLS-full') # Adjust color as per the actual plot color
                ]
        plt.legend(handles=handles,loc=2, fontsize=32,bbox_to_anchor=(0.0,0.8),columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.ylim([0,1])
    # Shaded band from 0 to 0.33 with a dashed line at 0.33.
    plt.axhspan(0,0.33,color='gray',alpha=0.25)
    plt.axhline(0.33,ls='--',color='k')
    if Save: plt.savefig(FigLoc+'PercsHCP_'+str(i)+'.pdf',format='pdf',bbox_inches='tight',transparent=True)
    plt.show()

# Supplemental Fig 1

In [None]:
# Number of posterior draws averaged per voxel when forming SBI estimates.
InferSamples = int(1e3)

In [None]:
# Output directory for the Supplemental Figure 1 panels. exist_ok avoids the
# race between checking for and creating the directory.
FigLoc = image_path + 'Fig_S1/'
os.makedirs(FigLoc, exist_ok=True)

## a

In [None]:
# Gradient tables for simulation: one b0 plus 64, 19 or 6 directions at b=1000.
# The starting directions come from dipy's 'small_64D' sample and are spread
# out with disperse_charges (5000 iterations).
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
# Row 0 is skipped (presumably the b0 entry — TODO confirm).
hsph_initial = HemiSphere(xyz=bvecs[1:])
hsph_initial20 = HemiSphere(xyz=bvecs[1:20])
hsph_initial7 = HemiSphere(xyz=bvecs[1:7])
# `potentials` is overwritten each time and not used afterwards.
hsph_updated,potentials = disperse_charges(hsph_initial,5000)
hsph_updated20,potentials = disperse_charges(hsph_initial20,5000)
hsph_updated7,potentials = disperse_charges(hsph_initial7,5000)

gtabSimF = gradient_table(np.array([0]+[1000]*64).squeeze(), np.vstack([[0,0,0],hsph_updated.vertices]))
gtabSim20 = gradient_table(np.array([0]+[1000]*19).squeeze(), np.vstack([[0,0,0],hsph_updated20.vertices]))
gtabSim7 = gradient_table(np.array([0]+[1000]*6).squeeze(), np.vstack([[0,0,0],hsph_updated7.vertices]))

In [None]:
# Fit NLLS DTI to one axial slice of every subject's native-resolution b=1000
# data. Collects voxel signals (FullDat), 3x3 tensors (DTIFull) and S0 (S0Full).
# NOTE: this overwrites the FullDat / axial_middle / mask globals set by earlier cells.
FullDat = []
S0Full  = []
DTIFull = []
for i in tqdm.tqdm(range(1,6)):
    fdwi = './HCP_data/Pat'+str(i)+'/diff_1k.nii.gz'
    bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
    bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
    
    bvalsHCP = np.loadtxt(bvalloc)
    bvecsHCP = np.loadtxt(bvecloc)
    gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
    
    data, affine, img = load_nifti(fdwi, return_img=True)
    axial_middle = data.shape[2] // 2
    # NOTE(review): autocrop=True crops maskdata, but axial_middle refers to the
    # uncropped z axis — confirm the intended slice.
    maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                                 numpass=1, autocrop=True, dilate=2)
    print('maskdata.shape (%d, %d, %d, %d)' % maskdata.shape)
    
    TestData = maskdata[:, :, axial_middle, :]
    FlatTD = TestData.reshape(maskdata.shape[0]*maskdata.shape[1],69)
    # Drop background voxels (all-zero signal) and voxels with any negative value.
    FlatTD = FlatTD[FlatTD.sum(axis=-1)>0]
    FlatTD = FlatTD[~np.array(FlatTD<0).any(axis=-1)]
    FullDat.append(FlatTD)
    # Fit the tensor model to the DWI data with return_S0_hat=True
    tenmodel = dti.TensorModel(gtabHCP, return_S0_hat=True,fit_method='NLLS')
    tenfit = tenmodel.fit(FlatTD)
    DTIHCP = tenfit.quadratic_form
    DTIFull.append(DTIHCP)
    # Get the estimated S0_hat values
    S0HCP = tenfit.S0_hat
    S0Full.append(S0HCP)
DTIFull = np.concatenate(DTIFull)
FullDat = np.concatenate(FullDat)
S0Full = np.hstack(S0Full)

In [None]:
# Draw 5000 DTI prior samples, simulate signals at each noise level, and compute
# MD/FA for the simulated tensors and for the HCP fits, for comparison.
# Draw order matters for reproducibility; statements are left untouched.
np.random.seed(0)
torch.manual_seed(0)
Samples  = []
DTISim = []
S0Sim    = []
# Define the lower and upper bounds
# NOTE: lower_S0 / upper_S0 are globals also read by the posterior-training cells.

lower_abs,upper_abs = -0.07,0.07
lower_rest,upper_rest = -0.015,0.015
lower_S0 = 25
upper_S0 = 2000

custom_prior = DTIPriorS0(lower_abs,upper_abs,lower_rest,upper_rest,lower_S0,upper_S0)
prior, *_ = process_prior(custom_prior) 

params = prior.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim.append(dt)
    # Sampled S0 is the last parameter. It is recorded for the S0 histogram only.
    S0Sim.append(params[i,-1])
    # NOTE(review): simulations use a fixed S0=200, not the sampled S0 — confirm.
    Samples.append([CustomSimulator(dt,gtabSimF, S0=200,snr=scale) for scale in NoiseLevels])
    
# Move the sample axis last.
Samples = np.array(Samples).squeeze()
Samples = np.moveaxis(Samples, 0, -1)

DTISim = np.array(DTISim)

# MD = mean of the tensor eigenvalues.
MDSim = [np.mean(np.linalg.eigh(B)[0]) for B in DTISim]
MDHCP = [np.mean(np.linalg.eigh(B)[0]) for B in DTIFull]

FASim = [FracAni(np.linalg.eigh(B)[0],m) for B,m in zip(DTISim,MDSim)]
FAHCP = [FracAni(np.linalg.eigh(B)[0],m) for B,m in zip(DTIFull,MDHCP)]

In [None]:
# Supplemental Fig. 1a: S0 of the prior samples vs S0 estimated from HCP DTI fits.
plt.hist(S0Sim,density=True,stacked=True,alpha=0.75,label='Simulated',color=SBIFit)
plt.hist(S0Full,density=True,stacked=True,alpha=0.75,label='HCP',color='gray')
plt.legend(fontsize=32,loc=1,bbox_to_anchor=(1,1))
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.yticks(fontsize=32)
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'S0Dist.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# Supplemental Fig. 1b: mean diffusivity of the prior samples vs the HCP subset.
# The label said 'HPC'; it is fixed to 'HCP' to match the S0 panel.
plt.hist(MDSim,density=True,stacked=True,label='Simulated samples',color=SBIFit)
plt.hist(MDHCP,density=True,stacked=True,alpha=0.75,label='HCP subset',color='gray')
plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
plt.yticks(fontsize=32)
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'MDDist.pdf',format='pdf',bbox_inches='tight',transparent=True)

## c

In [None]:
# Supplemental Fig. 1c: fractional anisotropy of the prior samples vs the HCP subset.
# The label said 'HPC'; it is fixed to 'HCP' to match the S0 panel.
plt.hist(FASim,density=True,label='Simulated samples',color=SBIFit)
plt.hist(FAHCP,density=True,alpha=0.75,label='HCP subset',color='gray')
plt.yticks(fontsize=32)
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'FADist.pdf',format='pdf',bbox_inches='tight',transparent=True)

## d

In [None]:
# Supplemental Fig. 1d: histograms of each upper-triangle DT element, simulated
# prior (coloured) overlaid with HCP fits (gray), in a 3x3 grid laid out like the tensor.
fig, axs = plt.subplots(3, 3, figsize=(12, 12))
ax = axs.ravel()

# Flat panel index -> (row, col) of the tensor element drawn there.
panel_elems = {0: (0, 0), 1: (0, 1), 2: (0, 2), 4: (1, 1), 5: (1, 2), 8: (2, 2)}
for k, (r, c) in panel_elems.items():
    ax[k].hist(DTISim[:, r, c], density=True, color=SBIFit)
    ax[k].hist(DTIFull[:, r, c], density=True, alpha=0.75, color='gray')

# Lower-triangle panels duplicate the symmetric elements and are hidden.
for k in (3, 7, 6):
    ax[k].axis('off')

for a in ax:
    a.tick_params(axis='both', labelsize=32)
    a.ticklabel_format(axis='y', style='sci', scilimits=(-1, 1))
plt.tight_layout()
if Save: plt.savefig(FigLoc+'DTDist.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Supplemental Fig 2

In [None]:
# Output directory for the Supplemental Figure 2 panels. exist_ok avoids the
# race between checking for and creating the directory.
FigLoc = image_path + 'Fig_S2/'
os.makedirs(FigLoc, exist_ok=True)

## a-d

In [None]:
# Supplemental Fig. 2a-d: fit parametric distributions to HCP tensor
# components and overlay each fitted PDF on its histogram, printing the fitted
# parameters on the panel.

# (a) Lognormal fit to DTIFilt column 0.
# NOTE(review): the histogram shows DTIFilt column 5 (plus simulated
# DTISim[:,0,0]) while the fit uses column 0 — confirm this is intended.
data = DTIFilt[:,0]
shape,loc,scale = lognorm.fit(data)
# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DTIFilt[:,5],bins=30,density=True,color=WLSFit)
plt.hist(DTISim[:,0,0],bins=30,density=True,alpha=0.5,color='gray')
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.0014,600,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Normal1.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# (b) Normal fit to DTIFilt column 1 ("DT_rest"), compared with the simulated
# off-diagonal element DTISim[:,1,0].
data = DTIFilt[:,1]
loc,scale = stats.norm.fit(data)

# Compute the fitted PDF
dti2_fitted = stats.norm(loc=loc, scale=scale)

x = np.linspace(min(data), max(data), 1000)

# Same normal fit, under the name used for plotting.
dti1_fitted = stats.norm(loc=loc, scale=scale)
plt.hist(DTIFilt[:,1],bins=30,density=True,color=WLSFit,label='HCP data')
plt.hist(DTISim[:,1,0],bins=30,density=True,alpha=0.5,color='gray',label='DTI prior')
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit,label='stat. fit')
# Annotate with the normal fit's mean (loc) and S.D. (scale). This was
# previously formatted with the stale lognormal `shape` as the mean and `loc`
# as the S.D.
plt.text(0.00011,2000,"Normal, \n mean = {:.2f},\n S.D. = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
plt.legend(fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1,loc=1,bbox_to_anchor=(0.5,1))
if Save: plt.savefig(FigLoc+'Normal2.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# (c) Lognormal fit to DKIFilt column 0.
data = DKIFilt[:,0]
shape,loc,scale = lognorm.fit(data)
x4_fitted = stats.lognorm(shape, loc=loc, scale=scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DKIFilt[:,0],bins=30,density=True,color=WLSFit)
plt.plot(x,x4_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(1,0.8,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Normal3.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()


# (d) Fitting R1: normal fit to DKIFilt column 3.
data = DKIFilt[:,3]
loc,scale = stats.norm.fit(data)
R1_fitted = stats.norm(loc,scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

plt.hist(DKIFilt[:,3],bins=30,density=True,color=WLSFit)
plt.plot(x,R1_fitted.pdf(x),lw=3,c=SBIFit)
# Mean = loc, S.D. = scale of the normal fit. Previously (shape, loc) was used,
# i.e. the lognormal shape from panel (c) and the mean mislabelled as the S.D.
plt.text(0.1,3,"Normal, \n mean = {:.2f},\n S.D. = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'Normal4.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()



## e-h


In [None]:

# Supplemental Fig. 2e-h: the same fits as panels a-d, restricted to voxels
# whose last TrueMets column is below 0.3 (the "lowFA" subset, per the output
# file names).
# NOTE: this rebinds the global `mask`, which earlier cells used for the brain mask.
mask = TrueMets[:,-1]<0.3

# (e) Lognormal fit to DTIFilt column 0.
# NOTE(review): the histogram shows column 5 while the fit uses column 0 — confirm.
data = DTIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DTIFilt[mask,5],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.0014,600,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA1.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# (f) Normal fit to DTIFilt column 1 ("DT_rest").
data = DTIFilt[mask,1]
loc,scale = stats.norm.fit(data)

# Compute the fitted PDF
dti2_fitted = stats.norm(loc=loc, scale=scale)

x = np.linspace(min(data), max(data), 1000)

# Same normal fit, under the name used for plotting.
dti1_fitted = stats.norm(loc=loc, scale=scale)
plt.hist(DTIFilt[mask,1],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
# Mean = loc, S.D. = scale (was wrongly formatted with the stale lognormal `shape`).
plt.text(0.00011,2000,"Normal, \n mean = {:.2f},\n S.D. = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA2.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

# (g) Lognormal fit to DKIFilt column 0.
data = DKIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
x4_fitted = stats.lognorm(shape, loc=loc, scale=scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DKIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,x4_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(1,0.8,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA3.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()


# (h) Fitting R1: normal fit to DKIFilt column 3.
data = DKIFilt[mask,3]
loc,scale = stats.norm.fit(data)
R1_fitted = stats.norm(loc,scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

plt.hist(DKIFilt[mask,3],bins=30,density=True,color=WLSFit)
plt.plot(x,R1_fitted.pdf(x),lw=3,c=SBIFit)
# Mean = loc, S.D. = scale (was (shape, loc): stale lognormal shape and mislabelled mean).
plt.text(0.05,3,"Normal, \n mean = {:.2f},\n S.D. = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'lowFA4.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()



## i-l

In [None]:
# High-KFA subset: voxels whose last TrueMets column (labelled KFA in the
# scatter legends) exceeds 0.7. Each panel fits a distribution to one column
# and overlays the fitted PDF on that column's normalised histogram.
mask = TrueMets[:,-1]>0.7
data = DTIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Compute the fitted PDF
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DTIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(0.0008,1400,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA1.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()
#DT_rest
data = DTIFilt[mask,1]
# norm.fit returns (loc, scale) = (mean, standard deviation).
loc,scale = stats.norm.fit(data)

# Compute the fitted PDF
dti2_fitted = stats.norm(loc=loc, scale=scale)

x = np.linspace(min(data), max(data), 1000)

# Same distribution as `dti2_fitted`; `dti1_fitted` is rebound for the plot.
dti1_fitted = stats.norm(loc=loc, scale=scale)
plt.hist(DTIFilt[mask,1],bins=30,density=True,color=WLSFit)
plt.plot(x,dti1_fitted.pdf(x),lw=3,c=SBIFit)
# Show the normal fit's mean/S.D. The tensor entries are ~1e-4, so the mean uses
# scientific notation (fixed-point .2f would print 0.00).
plt.text(0.00013,1600,"Normal, \n mean = {:.2e},\n S.D. = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA2.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

data = DKIFilt[mask,0]
shape,loc,scale = lognorm.fit(data)
x4_fitted = stats.lognorm(shape, loc=loc, scale=scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

# Same distribution as `x4_fitted`; kept in case later cells read `dti1_fitted`.
dti1_fitted = stats.lognorm(shape, loc=loc, scale=scale)
plt.hist(DKIFilt[mask,0],bins=30,density=True,color=WLSFit)
plt.plot(x,x4_fitted.pdf(x),lw=3,c=SBIFit)
plt.text(1.3,0.5,"Lognormal, \n shape = {:.2f}, \n location = {:.2e} \n scale = {:.2e}".format(shape,loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA3.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()


# Fitting R1 (DKIFilt[:,3]) with a normal distribution.
data = DKIFilt[mask,3]
loc,scale = stats.norm.fit(data)
R1_fitted = stats.norm(loc,scale)

# Generate x-values for plotting
x = np.linspace(min(data), max(data), 1000)

plt.hist(DKIFilt[mask,3],bins=30,density=True,color=WLSFit)
plt.plot(x,R1_fitted.pdf(x),lw=3,c=SBIFit)
# Annotate with the normal fit's (mean, S.D.), not the stale lognormal `shape`.
plt.text(0.3,1,"Normal, \n mean = {:.2f},\n S.D. = {:.2e} \n".format(loc,scale),
        fontsize=32)
plt.yticks([])
plt.xticks(fontsize=32)
if Save: plt.savefig(FigLoc+'HighFA4.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()



## m

In [None]:
# Panel m (data): DTIFilt[:, i] against DKIFilt[:, j], drawn as three groups
# split on the last TrueMets column (labelled KFA in the legend).
i,j=0,0
kfa = TrueMets[:,-1]
scatter_groups = (
    ((kfa<0.7)*(kfa>0.3), np.clip(WLSFit-0.5,0,1), {}, 'HCP data'),
    (kfa>0.7, np.clip(WLSFit,0,1), {'marker': 'v'}, 'HCP data (KFA$>$0.7)'),
    (kfa<0.3, np.clip(WLSFit-0.3,0,1), {'marker': '^'}, 'HCP data (KFA$<$0.3)'),
)
# Plot in the original order; `mask` ends as the low-KFA selection.
for mask, colour, marker_kw, lab in scatter_groups:
    plt.scatter(DTIFilt[mask,i], DKIFilt[mask,j], color=colour, label=lab, **marker_kw)
plt.yticks([])
plt.xticks([])
plt.legend(fontsize=24, loc=1, bbox_to_anchor=(1.1,1.1),
           handlelength=0.4, handletextpad=0.4, markerscale=2)
if Save:
    plt.savefig(FigLoc+'Scatter1Dat.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

## n

In [None]:
# Draw simulated (DT, KT) pairs from each fitted family, then compute
# DKIMetrics pair by pair: per family (ParTest1..ParTest5) and pooled (ParTest).
DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,200)
DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,200)
DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,200)
DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,200)
DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,800)

# Pooled arrays: all families stacked row-wise.
DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

ParTest1 = np.array([DKIMetrics(dt_row, kt_row) for dt_row, kt_row in tqdm.tqdm(zip(DT1,KT1))])
ParTest2 = np.array([DKIMetrics(dt_row, kt_row) for dt_row, kt_row in tqdm.tqdm(zip(DT2,KT2))])
ParTest3 = np.array([DKIMetrics(dt_row, kt_row) for dt_row, kt_row in tqdm.tqdm(zip(DT3,KT3))])
ParTest4 = np.array([DKIMetrics(dt_row, kt_row) for dt_row, kt_row in tqdm.tqdm(zip(DT4,KT4))])
ParTest5 = np.array([DKIMetrics(dt_row, kt_row) for dt_row, kt_row in tqdm.tqdm(zip(DT5,KT5))])
# `ParMets` is left holding the pooled per-pair list, as before.
ParMets = [DKIMetrics(dt_row, kt_row) for dt_row, kt_row in tqdm.tqdm(zip(DT,KT))]
ParTest = np.array(ParMets)

In [None]:
# Panel m (simulation): DT[:, i] against KT[:, j] for the simulated tensors,
# drawn as three groups split on the last ParTest column (labelled KFA).
i,j=0,0
mask = (ParTest[:,-1]<0.7)*(ParTest[:,-1]>0.3)
plt.scatter(DT[mask,i],KT[mask,j],color=np.clip(SBIFit-0.5,0,1),label='Sim. data')
mask = ParTest[:,-1]>0.7
plt.scatter(DT[mask,i],KT[mask,j],color=np.clip(SBIFit+0.2,0,1),marker='v',label='Sim. data (KFA$>$0.7)')
mask = ParTest[:,-1]<0.3
# This group is KFA < 0.3; the label now matches the mask and the data panel.
plt.scatter(DT[mask,i],KT[mask,j],color=np.clip(SBIFit-0.3,0,1),marker='^',label='Sim. data (KFA$<$0.3)')
# Hard-coded x-range; presumably copied from the data panel's limits — TODO confirm.
plt.xlim((2.6217407288099334e-05, 0.004342572522996818))
plt.yticks([])
plt.xticks([])
plt.legend(fontsize=24,loc=1,bbox_to_anchor=(1.1,1.1),handlelength=0.4,handletextpad=0.4,markerscale=2)
if Save: plt.savefig(FigLoc+'Scatter1Sim.pdf',format='pdf',bbox_inches='tight',transparent=True)
plt.show()

## o

In [None]:
# Panel o (data): DKIFilt[:, i] against DKIFilt[:, j], split into three groups
# on the last TrueMets column (KFA).
i,j=9,0
kfa = TrueMets[:,-1]
for mask, shade, marker_kw in (
    ((kfa<0.7)*(kfa>0.3), WLSFit-0.5, {}),
    (kfa>0.7, WLSFit+0.2, {'marker': 'v'}),
    (kfa<0.3, WLSFit-0.3, {'marker': '^'}),
):
    plt.scatter(DKIFilt[mask,i], DKIFilt[mask,j], color=np.clip(shade,0,1), **marker_kw)
plt.yticks([])
plt.xticks([])
if Save:
    plt.savefig(FigLoc+'Scatter2Dat.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

## p

In [None]:
# Panel p (simulation): KT[:, i] against KT[:, j], split into three groups on
# the last ParTest column (KFA); axis limits are fixed by hand.
i,j=9,0
kfa = ParTest[:,-1]
for mask, shade, marker_kw in (
    ((kfa<0.7)*(kfa>0.3), SBIFit-0.5, {}),
    (kfa>0.7, SBIFit+0.2, {'marker': 'v'}),
    (kfa<0.3, SBIFit-0.3, {'marker': '^'}),
):
    plt.scatter(KT[mask,i], KT[mask,j], color=np.clip(shade,0,1), **marker_kw)
plt.xlim((-0.07662077262433849, 1.0282445059644276))
plt.ylim((-0.5425279356172327, 4.178757181686705))
plt.yticks([])
plt.xticks([])
if Save:
    plt.savefig(FigLoc+'Scatter2Sim.pdf', format='pdf', bbox_inches='tight', transparent=True)
plt.show()

# Supplemental Fig 3

In [None]:
# Output folder for the Supplemental Fig 3 panels (created if missing).
FigLoc = image_path + 'Fig_S3/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Train (or load a cached copy of) the NPE posterior for the simulated
# `gtabSim20` acquisition. Both RNGs are seeded so a retrain is repeatable.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimMid.pickle"):
    # Cached posterior found; `pickle` is dill here.
    # NOTE(review): unpickling executes code — only load trusted files.
    with open(f"{network_path}/DTISimMid.pickle", "rb") as handle:
        posterior20 = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        # One prior draw; its last entry `a` is passed as CustomSimulator's
        # fourth (noise/SNR) argument and appended to the training target.
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSim20,200,a))
        # Target = tensor entries from mat_to_vals followed by the noise parameter.
        Par.append(np.hstack([mat_to_vals(dt),a]))
    
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)
    
    # pass the precomputed (parameter, observation) pairs to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior20 = inference.build_posterior(density_estimator)
    # Cache the trained posterior (inside this branch the file never exists yet).
    if not os.path.exists(f"{network_path}/DTISimMid.pickle"):
        with open(f"{network_path}/DTISimMid.pickle", "wb") as handle:
            pickle.dump(posterior20, handle)

In [None]:
# Evaluate the 20-volume SBI posterior on the first 500 simulated voxels at
# each of 5 noise levels. Two things are collected per (level, voxel):
#   - the error of the posterior-mean tensor against the true tensor, via
#     `Errors`, which is given the noise-free `single_tensor` signal;
#   - the mean of the last posterior dimension (the noise parameter).
torch.manual_seed(10)
SNR = NoiseLevels
Error20 = []
NoiseApprox20 = []
for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim20[i])
        # assumes Samples20 is laid out (noise level, volume, voxel) — TODO confirm
        tObs = Samples20[k,:,i]
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim20, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posterior20.sample((InferSamples,), x=tObs,show_progress_bars=False)
        # Posterior mean -> tensor; vals_to_mat presumably reads only the tensor
        # entries and ignores the trailing noise parameter.
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSim20,true_signal_dti,tObs))
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApprox20.append(ENoise)
    Error20.append(ErrorN2)

NoiseApprox20 = np.array(NoiseApprox20)    


In [None]:
# NLLS tensor-fit baseline on the same 500 voxels at every noise level.
# Error_n[level][voxel] holds the `Errors` output for that fit.
k,gtab,Samps,DTIS = 20,gtabSim20,Samples20,DTISim20  # `k` is not used below
tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
Error_n = []
for S,Noise in zip(Samps,NoiseLevels):
    Error = []
    for i in range(500):
        tenfit = tenmodel.fit(S[:,i])
        tensor_vals = dti.lower_triangular(tenfit.quadratic_form)
        DT_test = vals_to_mat(tensor_vals)
        # The reference signal is Samps[0] (first noise level).
        # NOTE(review): assumes NoiseLevels[0] is noise-free — confirm.
        Error.append(Errors(DT_test,DTIS[i],gtab,Samps[0][:,i],S[:,i]))
    Error_n.append(Error)
Error_n = np.array(Error_n)

In [None]:
# Two rows of four box-plot panels: SBI (Error20) vs NLLS (Error_n) errors per
# noise level, skipping the first level. The first figure shows error metrics
# 0-3 and the second shows metrics 4-7.
fig,axs = plt.subplots(1,4,figsize=(18,3))
ax = axs.ravel()
# np.array(Error20).T is (metric, voxel, level), so E is (voxel, level) for one metric.
for ll,(a,E,t) in enumerate(zip(ax,np.array(Error20).T,Errors_name)):
    plt.sca(a) 
    bp = box_plot(E[:,1:].T,SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[1.3,2.3,3.3,4.3],showfliers=False,widths=0.3)
    bp2 = box_plot(Error_n[1:,:,ll],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    plt.sca(a)
    # NOTE(review): with 4 panels ll is 0-3, so this xlabel is never drawn.
    if(ll>3):
        plt.xlabel('SNR', fontsize=24)
    plt.xticks([1.15, 2.15, 3.15, 4.15,], NoiseLevels[1:],fontsize=32)
    # `ymax` is computed but not used in this cell.
    ymax = max(bp2['whiskers'][1].get_ydata()[1],bp['whiskers'][1].get_ydata()[1])*1.2
    plt.yticks(fontsize=32)
    # Scientific-notation y ticks and horizontal grid lines
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
plt.tight_layout()
if Save: plt.savefig(FigLoc+'SimDatDTIErrors1_20.pdf',format='pdf',bbox_inches='tight',transparent=True)

fig,axs = plt.subplots(1,4,figsize=(18,3))
ax = axs.ravel()
# Second figure: metrics 4 onward, hence the ll+4 index into Error_n.
for ll,(a,E,t) in enumerate(zip(ax,np.array(Error20).T[4:],Errors_name[4:])):
    plt.sca(a) 
    bp = box_plot(E[:,1:].T,SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[1.3,2.3,3.3,4.3],showfliers=False,widths=0.3)
    bp2 = box_plot(Error_n[1:,:,ll+4],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    plt.sca(a)
    # NOTE(review): never true with 4 panels (see above).
    if(ll>3):
        plt.xlabel('SNR', fontsize=24)
    plt.xticks([1.15, 2.15, 3.15, 4.15,], NoiseLevels[1:],fontsize=32)
    ymax = max(bp2['whiskers'][1].get_ydata()[1],bp['whiskers'][1].get_ydata()[1])*1.2
    plt.yticks(fontsize=32)
    # Scientific-notation y ticks and horizontal grid lines
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
plt.tight_layout()
if Save: plt.savefig(FigLoc+'SimDatDTIErrors2_20.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Supplemental Fig 4

In [None]:
# Output folder for the Supplemental Fig 4 panels (created if missing).
FigLoc = image_path + 'Fig_S4/'
os.makedirs(FigLoc, exist_ok=True)

## a

In [None]:
# Simulate three test sets of 5000 low-FA tensors, one per acquisition scheme
# (gtabSimF, gtabSim20, gtabSim7). Each tensor gets one signal per noise level
# in NoiseLevels. The sets are drawn in sequence from the same seeded RNG
# stream, so statement order matters.
np.random.seed(0)
torch.manual_seed(0)
Samples  = []
DTISim = []
S0Sim    = []

params = priorS0.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim.append(dt)
    S0Sim.append(params[i,-1])  # last prior entry is recorded, but the simulator below uses S0=200
    Samples.append([CustomSimulator(dt,gtabSimF, S0=200,snr=scale) for scale in NoiseLevels])
    
# (voxel, level, volume) -> (level, volume, voxel), assuming the simulator
# returns a 1-D signal (squeeze drops any singleton dims).
Samples = np.array(Samples).squeeze()
Samples = np.moveaxis(Samples, 0, -1)

Samples20  = []
DTISim20 = []
S0Sim20    = []

params = priorS0.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim20.append(dt)
    S0Sim20.append(params[i,-1])
    Samples20.append([CustomSimulator(dt,gtabSim20, S0=200,snr=scale) for scale in NoiseLevels])
    
Samples20 = np.array(Samples20).squeeze()
Samples20 = np.moveaxis(Samples20, 0, -1)

Samples7  = []
DTISim7 = []
S0Sim7    = []

params = priorS0.sample([5000])
for i in tqdm.tqdm(range(5000)):
    dt = ComputeDTI(params[i])
    dt = ForceLowFA(dt)
    DTISim7.append(dt)
    S0Sim7.append(params[i,-1])
    Samples7.append([CustomSimulator(dt,gtabSim7, S0=200,snr=scale) for scale in NoiseLevels])
    
Samples7 = np.array(Samples7).squeeze()
Samples7 = np.moveaxis(Samples7, 0, -1)

In [None]:
# Train (or load a cached copy of) the NPE posterior for the full simulated
# acquisition `gtabSimF`. Both RNGs are seeded so a retrain is repeatable.
np.random.seed(1)
torch.manual_seed(1)
# Load and save through the same `network_path` file so that a freshly trained
# posterior is actually picked up on the next run, as with the other networks.
posterior_file = f"{network_path}/DTISimFull.pickle"
if os.path.exists(posterior_file):
    # `pickle` is dill here. NOTE(review): unpickling executes code — trusted files only.
    with open(posterior_file, "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        # One prior draw; its last entry `a` is passed as CustomSimulator's
        # fourth (noise/SNR) argument and appended to the training target.
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSimF,200,a))
        Par.append(np.hstack([mat_to_vals(dt),a]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)

    # pass the precomputed (parameter, observation) pairs to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posteriorFull = inference.build_posterior(density_estimator)
    with open(posterior_file, "wb") as handle:
        pickle.dump(posteriorFull, handle)


In [None]:
# Evaluate the full-scheme SBI posterior on the first 500 simulated voxels at
# each of 5 noise levels. Errors are computed against the noise-free signal,
# and the mean posterior noise parameter is kept per voxel. Unlike the 20- and
# 7-volume evaluations, the posterior-mean tensor is passed through
# clip_negative_eigenvalues first.
torch.manual_seed(10)
SNR = NoiseLevels
ErrorFull = []
NoiseApproxFull = []
for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim[i])
        tObs = Samples[k,:,i]  # noise level k, voxel i (Samples is (level, volume, voxel))
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSimF, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        mat_guess = clip_negative_eigenvalues(mat_guess)
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSimF,true_signal_dti,tObs))
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApproxFull.append(ENoise)
    ErrorFull.append(ErrorN2)

NoiseApproxFull = np.array(NoiseApproxFull)    


In [None]:
# NLLS tensor-fit baseline for the full scheme on the same 500 voxels.
# Error_n[level][voxel] holds the `Errors` output for that fit.
k,gtab,Samps,DTIS = 65,gtabSimF,Samples,DTISim  # `k` is not used below
tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
Error_n = []
for S,Noise in zip(Samps,NoiseLevels):
    Error = []
    for i in range(500):
        tenfit = tenmodel.fit(S[:,i])
        tensor_vals = dti.lower_triangular(tenfit.quadratic_form)
        DT_test = vals_to_mat(tensor_vals)
        # NOTE(review): Samps[0] is used as the reference signal; assumes the
        # first noise level is noise-free — confirm.
        Error.append(Errors(DT_test,DTIS[i],gtab,Samps[0][:,i],S[:,i]))
    Error_n.append(Error)
Error_n = np.array(Error_n)

In [None]:
# Six box-plot panels: SBI (ErrorFull) vs NLLS (Error_n) for error metrics 2
# onward, per noise level, skipping the first level. The first two panels
# carry the method legends.
fig,axs = plt.subplots(1,6,figsize=(27,3))
ax = axs.ravel()
for ll,(a,E,t) in enumerate(zip(ax,np.array(ErrorFull).T[2:],Errors_name[2:])):
    plt.sca(a) 
    bp = box_plot(E[:,1:].T,SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[1.3,2.3,3.3,4.3],showfliers=False,widths=0.3)
    bp2 = box_plot(Error_n[1:,:,ll+2],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    plt.sca(a)
    # Only the last two panels (ll = 4, 5) get an x label.
    if(ll>3):
        plt.xlabel('SNR', fontsize=24)
    plt.xticks([1.15, 2.15, 3.15, 4.15,], NoiseLevels[1:],fontsize=32)
    # `ymax` is computed but not used in this cell.
    ymax = max(bp2['whiskers'][1].get_ydata()[1],bp['whiskers'][1].get_ydata()[1])*1.2
    plt.yticks(fontsize=32)
    # Scientific-notation y ticks and horizontal grid lines
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    if(ll==1):
        handles = [
            Line2D([0], [0], color=SBIFit, lw=4, label='SBI'),  # Adjust color as per the actual plot color
        ]
        # Add the legend
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.05),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    if(ll==0):
        handles = [
            Line2D([0], [0], color=WLSFit, lw=4, label='NLLS'),  # Adjust color as per the actual plot color
        ]
        # Add the legend
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
plt.tight_layout()
if Save: plt.savefig(FigLoc+'SimDatDTIErrors2.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# Train (or load a cached copy of) the NPE posterior for the simulated
# `gtabSim7` acquisition. Both RNGs are seeded so a retrain is repeatable.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DTISimMin.pickle"):
    # Cached posterior found; `pickle` is dill here.
    # NOTE(review): unpickling executes code — only load trusted files.
    with open(f"{network_path}/DTISimMin.pickle", "rb") as handle:
        posterior7 = pickle.load(handle)
else:
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        # One prior draw; its last entry `a` is passed as CustomSimulator's
        # fourth (noise/SNR) argument and appended to the training target.
        params = priorNoise.sample()
        dt = ComputeDTI(params)
        dt = ForceLowFA(dt)
        a = params[-1]
        Obs.append(CustomSimulator(dt,gtabSim7,200,a))
        Par.append(np.hstack([mat_to_vals(dt),a]))
    
    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorNoise)
    
    # pass the precomputed (parameter, observation) pairs to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior7 = inference.build_posterior(density_estimator)
    # Cache the trained posterior (inside this branch the file never exists yet).
    if not os.path.exists(f"{network_path}/DTISimMin.pickle"):
        with open(f"{network_path}/DTISimMin.pickle", "wb") as handle:
            pickle.dump(posterior7, handle)

In [None]:
# Evaluate the 7-volume SBI posterior on the first 500 simulated voxels at
# each of 5 noise levels. Errors are computed against the noise-free signal,
# and the mean posterior noise parameter is kept per voxel.
torch.manual_seed(10)
SNR = NoiseLevels
Error7 = []
NoiseApprox7 = []
for k in tqdm.tqdm(range(5)):
    ErrorN2 = []
    ENoise = []
    for i in range(500):
        tparams = mat_to_vals(DTISim7[i])
        tObs = Samples7[k,:,i]  # noise level k, voxel i
        mat_true = vals_to_mat(tparams)
        evals_true,evecs_true = np.linalg.eigh(mat_true)
        true_signal_dti = single_tensor(gtabSim7, S0=200, evals=evals_true, evecs=evecs_true,
                           snr=None)
        posterior_samples_1 = posterior7.sample((InferSamples,), x=tObs,show_progress_bars=False)
        mat_guess = vals_to_mat(np.array(posterior_samples_1.mean(axis=0)))
        ErrorN2.append(Errors(mat_guess,mat_true,gtabSim7,true_signal_dti,tObs))
        ENoise.append(posterior_samples_1[:,-1].mean())
    NoiseApprox7.append(ENoise)
    Error7.append(ErrorN2)

NoiseApprox7 = np.array(NoiseApprox7)    


In [None]:
# NLLS tensor-fit baseline for the 7-volume scheme on the same 500 voxels.
# Error_n[level][voxel] holds the `Errors` output for that fit.
k,gtab,Samps,DTIS = 7,gtabSim7,Samples7,DTISim7  # `k` is not used below
tenmodel = dti.TensorModel(gtab,fit_method='NLLS')
Error_n = []
for S,Noise in zip(Samps,NoiseLevels):
    Error = []
    for i in range(500):
        tenfit = tenmodel.fit(S[:,i])
        tensor_vals = dti.lower_triangular(tenfit.quadratic_form)
        DT_test = vals_to_mat(tensor_vals)
        # NOTE(review): Samps[0] is used as the reference signal; assumes the
        # first noise level is noise-free — confirm.
        Error.append(Errors(DT_test,DTIS[i],gtab,Samps[0][:,i],S[:,i]))
    Error_n.append(Error)
Error_n = np.array(Error_n)

In [None]:
# Six box-plot panels for the 7-volume scheme: SBI (Error7) vs NLLS (Error_n)
# for error metrics 2 onward, per noise level, skipping the first level.
fig,axs = plt.subplots(1,6,figsize=(27,3))
ax = axs.ravel()
for ll,(a,E,t) in enumerate(zip(ax,np.array(Error7).T[2:],Errors_name[2:])):
    plt.sca(a) 
    bp = box_plot(E[:,1:].T,SBIFit-0.2, np.clip(SBIFit+0.2,0,1),positions=[1.3,2.3,3.3,4.3],showfliers=False,widths=0.3)
    bp2 = box_plot(Error_n[1:,:,ll+2],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    plt.sca(a)
    # Only the last two panels (ll = 4, 5) get an x label.
    if(ll>3):
        plt.xlabel('SNR', fontsize=24)
    plt.xticks([1.15, 2.15, 3.15, 4.15,], NoiseLevels[1:],fontsize=32)
    # `ymax` is computed but not used in this cell.
    ymax = max(bp2['whiskers'][1].get_ydata()[1],bp['whiskers'][1].get_ydata()[1])*1.2
    plt.yticks(fontsize=32)
    # Scientific-notation y ticks and horizontal grid lines
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    if(ll==1):
        handles = [
            Line2D([0], [0], color=SBIFit, lw=4, label='SBI'),  # Adjust color as per the actual plot color
        ]
        # Add the legend
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.05),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
    if(ll==0):
        handles = [
            Line2D([0], [0], color=WLSFit, lw=4, label='NLLS'),  # Adjust color as per the actual plot color
        ]
        # Add the legend
        plt.legend(handles=handles,loc=2, bbox_to_anchor=(0,1.15),
                   fontsize=36,columnspacing=0.3,handlelength=0.6,handletextpad=0.3)
plt.tight_layout()
if Save: plt.savefig(FigLoc+'SimDatDTIErrors2_7.pdf',format='pdf',bbox_inches='tight',transparent=True)

# Supplemental Fig 5

In [None]:
# Output folder for the Supplemental Fig 5 panels (created if missing).
FigLoc = image_path + 'Fig_S5/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Build a subsampled two-shell scheme from DIPY's 'small_64D' sample data:
#   - one b=0 volume;
#   - 19 b=1000 directions, taken from bvecs[1:20];
#   - 28 b=3000 directions, taken from bvecs[1:29].
# Each direction set is spread with disperse_charges (5000 iterations).
# get_fnames may download the sample data on first use.
fimg_init, fbvals, fbvecs = get_fnames('small_64D')
bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
hsph_initial28 = HemiSphere(xyz=bvecs[1:29])
hsph_initial20 = HemiSphere(xyz=bvecs[1:20])  # 19 directions (+ b0 = 20 volumes)
hsph_updated28,_ = disperse_charges(hsph_initial28,5000)
hsph_updated20,_ = disperse_charges(hsph_initial20,5000)
gtabSimSub = gradient_table(np.array([0]+[1000]*19+[3000]*28).squeeze(), np.vstack([[0,0,0],hsph_updated20.vertices,hsph_updated28.vertices]))

In [None]:
# Train (or load a cached copy of) the DKI NPE posterior for `gtabSimSub`.
# Training pairs are (DT, KT) draws from five fitted families, with signals
# simulated at a random SNR. Signals that leave [0, 800] are re-simulated with
# halved KT until none remain.
if os.path.exists(f"{network_path}/DKISimMid.pickle"):
    # `pickle` is dill here. NOTE(review): unpickling executes code — trusted files only.
    with open(f"{network_path}/DKISimMid.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    torch.manual_seed(1)
    np.random.seed(1)
    DT = []
    KT = []
    S0 = []  # not filled below; the reshape gives an empty (0, 1) array
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*60000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*20000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*60000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*60000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*60000))
    
    
    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])
    
    S0 = np.array(S0).reshape(len(S0),1)
    
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabSimSub.bvecs)])
    kk = 0
    # Rejection loop: re-simulate only the indices flagged in the previous pass.
    while len(indx)>0:
        for i in tqdm.tqdm(indx): 
            # S0 = 200, SNR drawn uniformly from [0, 30)
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabSimSub,200,np.random.rand()*30)
        
        # Flag every signal (across all of Obs) with a value above 800 or below 0.
        indxNew = []
        for i,O in enumerate(Obs):
            if (O>800).any() or (O<0).any():
                indxNew.append(i)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): the last GenDTKT argument looks like a sample count (1), so a
        # single DT appears to be broadcast onto every flagged row — confirm intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]
    
        indx = indxNew
        kk+=1
    # Training target: DT entries followed by KT entries, per row.
    Par = np.hstack([DT,KT])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()
    
    # Create inference object. Here, NPE is used.
    inference = SNPE()
    
    # pass the precomputed (parameter, observation) pairs to the inference object
    inference = inference.append_simulations(Par, Obs)
    
    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)
    
    # Audible notification via the macOS `say` command (fails harmlessly elsewhere).
    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKISimMid.pickle"):
        with open(f"{network_path}/DKISimMid.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Draw one example (DT, KT) pair from a randomly chosen family, then simulate:
#   - a noisy observation (SNR 20) on the subsampled scheme gtabSimSub;
#   - a noise-free reference signal on gtabSim.
torch.manual_seed(2)
np.random.seed(2)
j = 1
# vL and vS are not used within this cell.
vL = torch.tensor([0.2*j])
vS = torch.tensor([0.01*j])  

# randint(0,4) yields 0-3, so the `_hak` family is never selected here.
kk = np.random.randint(0,4)
if(kk==0):
    DT,KT = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],2,1)
elif(kk==1):
    DT,KT = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],2,1)
elif(kk==2):
    DT,KT = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],2,1)
elif(kk==3):
    DT,KT = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],2,1)

tObs = CustomDKISimulator(np.squeeze(DT),np.squeeze(KT),gtabSimSub,200,20)
tTrue = CustomDKISimulator(np.squeeze(DT),np.squeeze(KT),gtabSim,200,None)

In [None]:
# Re-seed both RNGs, then draw InferSamples posterior samples for the noisy
# example observation `tObs`, showing a progress bar.
np.random.seed(1)
torch.manual_seed(1)
posterior_samples_1 = posteriorFull.sample(
    (InferSamples,),
    x=tObs,
    show_progress_bars=True,
)

## a

In [None]:
# Panel a: the SBI posterior mean (first 6 entries DT, rest KT) is re-simulated
# noise-free on gtabSim and overlaid on the true signal.
GuessDKI = posterior_samples_1.mean(axis=0)
GuessSig = CustomDKISimulator(GuessDKI[:6],GuessDKI[6:],gtabSim,200)
plt.subplots(figsize=(6,1))
plt.plot(tTrue,lw=2,c='k',label='true signal')
plt.plot(GuessSig,lw=2,c=SBIFit,ls='--',label='SBI signal')
plt.axis('off')
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
# Grey bands over volume indices 0-20 and 64-92 (hard-coded; presumably the
# volumes kept in the subsampled scheme — TODO confirm).
plt.fill_betweenx(np.arange(0,500,50),0*np.ones(10),20*np.ones(10),color='gray',alpha=0.5)
plt.fill_betweenx(np.arange(0,500,50),64*np.ones(10),92*np.ones(10),color='gray',alpha=0.5)
plt.ylim(-9.996985449425491, 209.99985644997255)

# Respect the global Save switch like every other figure cell.
if Save: plt.savefig(FigLoc+'20ReconSBI.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# Panel b: NLLS DKI fit to the noisy example observation on gtabSimSub, then
# predicted on gtabSim (S0 = 200) and overlaid on the true signal.
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')
tenfit = dkimodel.fit(tObs)
plt.subplots(figsize=(6,1))
plt.plot(tTrue,lw=2,c='k',label='true signal')
plt.plot(tenfit.predict(gtabSim,200),lw=2,c=WLSFit,ls='--',label='NLLS signal')
plt.axis('off')
plt.legend(ncols=2,loc=1,bbox_to_anchor =  (1.09,1.95),fontsize=32,columnspacing=0.3,handlelength=0.4,handletextpad=0.1)
# Same hard-coded grey bands as the SBI panel.
plt.fill_betweenx(np.arange(0,500,50),0*np.ones(10),20*np.ones(10),color='gray',alpha=0.5)
plt.fill_betweenx(np.arange(0,500,50),64*np.ones(10),92*np.ones(10),color='gray',alpha=0.5)
plt.ylim(-9.996985449425491, 209.99985644997255)
# Respect the global Save switch like every other figure cell.
if Save: plt.savefig(FigLoc+'20ReconWLS.pdf',format='pdf',bbox_inches='tight',transparent=True)

In [None]:
# For SNR in (20, 10, 5, 2), re-simulate the example voxel 50 times. For each
# realisation, record DKIMetrics from the NLLS fit (Mets) and from the SBI
# posterior mean (MetsSBI).
torch.manual_seed(1)
np.random.seed(1)
# The NLLS model depends only on the gradient table, so build it once rather
# than once per realisation.
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')
Mets = []
MetsSBI = []
for i in tqdm.tqdm([20,10,5,2]):
    m = []
    m2 = []
    for k in range(50):
        tObs = CustomDKISimulator(np.squeeze(DT), np.squeeze(KT),gtabSimSub, S0=200, snr=i)
        tenfit = dkimodel.fit(tObs)
        m.append(DKIMetrics(tenfit.lower_triangular(),tenfit.kt,False))
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        # Posterior mean: first 6 entries are DT, the remainder KT.
        GuessDKI = posterior_samples_1.mean(axis=0)
        m2.append(DKIMetrics(GuessDKI[:6],GuessDKI[6:],False))
    Mets.append(m)
    MetsSBI.append(m2)

## c

In [None]:
# Panel c: violin plots of the first five DKIMetrics outputs per SNR, NLLS vs
# SBI. The dashed line marks the metric of the true (DT, KT).
Mets = np.array(Mets)
MetsSBI = np.array(MetsSBI)
fig,ax = plt.subplots(1,5,figsize=(22.5,3))
for i in range(5):
    plt.sca(ax[i])
    viol_plot(Mets[:,:,i].T,WLSFit,)
    viol_plot(MetsSBI[:,:,i].T,SBIFit,widths=0.3,positions=[1.3,2.3,3.3,4.3])
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.axhline(DKIMetrics(np.squeeze(DT),np.squeeze(KT),False)[i],lw=3,ls='--',c='k')
    plt.yticks(fontsize=32)
# Respect the global Save switch like every other figure cell.
if Save: plt.savefig(FigLoc+'EgSigMetrics20.pdf',format='pdf',bbox_inches='tight',transparent=True)

## d

In [None]:
# Draw 40 (DT, KT) test pairs from each of the five fitted families and stack
# them into SampsDT / SampsKT (200 rows in total).
torch.manual_seed(1)
np.random.seed(1)
DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],1,40)
DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],1,40)
DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],1,40)
DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],1,40)
# NOTE(review): the third argument is 12 here but 1 for the other families — confirm intended.
DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,40)

SampsDT = np.vstack([DT1,DT2,DT3,DT4,DT5])
SampsKT = np.vstack([KT1,KT2,KT3,KT4,KT5])

In [None]:
# Simulate every test (DT, KT) pair on gtabSimSub at each SNR in NoiseLevels.
# Result layout: Samples20[voxel, noise level, volume].
torch.manual_seed(1)
np.random.seed(1)

Samples20 = np.array([
    [CustomDKISimulator(Sd, Sk, gtabSimSub, S0=200, snr=scale) for scale in NoiseLevels]
    for Sd, Sk in zip(SampsDT, SampsKT)
])

In [None]:
# DKIErrors of the SBI posterior mean (ErrorFull) and of an NLLS DKI fit
# (Error_s) against the true (SampsDT, SampsKT), for every test voxel and
# noise level. Both lists end up indexed [level][voxel].
torch.manual_seed(10)
# Samples20 is (voxel, noise level, volume). Take the loop bounds from it
# instead of hard-coding 200 voxels / 5 levels.
n_vox, n_levels = Samples20.shape[:2]
ErrorFull = []
for k in tqdm.tqdm(range(n_levels)):
    ErrorN2 = []
    ENoise = []
    for i in range(n_vox):
        tObs = Samples20[i,k,:]
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=tObs,show_progress_bars=False)
        # Posterior mean: first 6 entries are DT, the remainder KT.
        GuessSBI = posterior_samples_1.mean(axis=0)

        ErrorN2.append(DKIErrors(GuessSBI[:6],GuessSBI[6:],SampsDT[i],SampsKT[i]))
    ErrorFull.append(ErrorN2)

Error_s = []
dkimodel = dki.DiffusionKurtosisModel(gtabSimSub,fit_method='NLLS')

for k in tqdm.tqdm(range(n_levels)):
    ErrorN2 = []
    ENoise = []
    for i in range(n_vox):
        tObs = Samples20[i,k,:]
        tenfit = dkimodel.fit(tObs)

        ErrorN2.append(DKIErrors(tenfit.lower_triangular(),tenfit.kt,SampsDT[i],SampsKT[i]))
    Error_s.append(ErrorN2)



In [None]:
# Panel d: box plots of the first five DKIErrors components per noise level,
# skipping the first level, NLLS vs SBI.
ErrorFull = np.array(ErrorFull)
Error_s = np.array(Error_s)
ErrorNames = ['MK Error', 'AK Error', 'RK Error', 'MKT Error', 'KFA Error']
fig,ax = plt.subplots(1,5,figsize=(22.5,3))
for i in range(5):
    plt.sca(ax[i])
    box_plot(Error_s[1:,:,i],WLSFit-0.2, np.clip(WLSFit+0.2,0,1),showfliers=False,widths=0.3)
    box_plot(ErrorFull[1:,:,i],SBIFit-0.2, np.clip(SBIFit+0.2,0,1),showfliers=False,widths=0.3,positions=[1.3,2.3,3.3,4.3])
    plt.xticks([1.15, 2.15, 3.15, 4.15,],[20,10,5,2],fontsize=32)
    plt.gca().ticklabel_format(axis='y',style='sci',scilimits=(-1,1))
    plt.grid(axis='y')
    plt.yticks(fontsize=32)
# Respect the global Save switch like every other figure cell.
if Save: plt.savefig(FigLoc+'Errors20.pdf',format='pdf',bbox_inches='tight',transparent=True)

## e-g

In [None]:
def _farthest_point_indices(distance_matrix, start, n_total):
    """Greedily choose `n_total` row indices by farthest-point selection.

    Starts from index `start`. At each step it adds the remaining index whose
    minimum distance (per `distance_matrix`) to the already-chosen indices is
    largest. Returns the indices in selection order.
    """
    selected = [start]
    while len(selected) < n_total:
        remaining = list(set(range(len(distance_matrix))) - set(selected))
        # Minimum distance from each remaining point to the current selection.
        min_distances = np.min(distance_matrix[remaining][:, selected], axis=1)
        selected.append(remaining[np.argmax(min_distances)])
    return selected


# Build a reduced HCP scheme (gtabHCP20) for patient 3: one b=0 entry, plus
# 19 well-spread directions from the "1k" shell and 28 from the "3k" shell.
i=3
bvalloc = './HCP_data/Pat'+str(i)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_1k.txt'
bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

# 19 diffusion-weighted (b>0) directions, starting from the second one.
temp_bvecs = bvecsHCP[bvalsHCP>0]
temp_bvals = bvalsHCP[bvalsHCP>0]
distance_matrix = squareform(pdist(temp_bvecs))
selected_indices7 = _farthest_point_indices(distance_matrix, 1, 19)

# Prepend a single b=0 entry.
bvalsHCP7_1 = np.insert(temp_bvals[selected_indices7],0,0)
bvecsHCP7_1 = np.insert(temp_bvecs[selected_indices7],0,[0,0,0],axis=0)

bvalloc = './HCP_data/Pat'+str(i)+'/bvals_3k.txt'
bvecloc = './HCP_data/Pat'+str(i)+'/bvecs_3k.txt'

bvalsHCP3 = np.loadtxt(bvalloc)
bvecsHCP3 = np.loadtxt(bvecloc)
# Gradient table of the 3k shell itself.
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)
# Combined 1k+3k table, built only once this patient's 3k arrays are loaded.
gtabExt  = gradient_table(np.hstack((bvalsHCP,bvalsHCP3)), np.vstack((bvecsHCP,bvecsHCP3)))

# 28 diffusion-weighted directions from the 3k shell, starting from the first.
temp_bvecs = bvecsHCP3[bvalsHCP3>0]
temp_bvals = bvalsHCP3[bvalsHCP3>0]
distance_matrix = squareform(pdist(temp_bvecs))
selected_indices = _farthest_point_indices(distance_matrix, 0, 28)

bvalsHCP7_3 = temp_bvals[selected_indices]
bvecsHCP7_3 = temp_bvecs[selected_indices]

gtabHCP20 = gradient_table(np.hstack((bvalsHCP7_1,bvalsHCP7_3)), np.vstack((bvecsHCP7_1,bvecsHCP7_3)))

# Map each selected direction back to its nearest row in the original schemes.
true_indx_one = [np.linalg.norm(b-bvecsHCP,axis=1).argmin() for b in bvecsHCP7_1]
true_indx20 = [np.linalg.norm(b-bvecsHCP3,axis=1).argmin() for b in bvecsHCP7_3]
# 3k indices are shifted by 69 to address the concatenated 1k+3k volumes.
# NOTE(review): 69 is presumably len(bvalsHCP) for this patient — confirm.
true_indx20 = true_indx_one+[t+69 for t in true_indx20]

In [None]:
# Train (or load the cached) SBI posterior for DKI (DT + KT + S0) on the
# reduced gtabHCP20 scheme. Both RNGs are seeded once, before the branch,
# so the simulated training set is reproducible.
np.random.seed(1)
torch.manual_seed(1)
if os.path.exists(f"{network_path}/DKIHCPMid.pickle"):
    with open(f"{network_path}/DKIHCPMid.pickle", "rb") as handle:
        posteriorFull = pickle.load(handle)
else:
    # Draw diffusion/kurtosis tensor pairs from the five prior families.
    DT1,KT1 = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],12,int(2.5*60000))
    DT2,KT2 = GenDTKT([DT1_lfa,DT2_lfa],[x4_lfa,R1_lfa,x2_lfa,R2_lfa],12,int(2.5*20000))
    DT3,KT3 = GenDTKT([DT1_hfa,DT2_hfa],[x4_hfa,R1_hfa,x2_hfa,R2_hfa],12,int(2.5*60000))
    DT4,KT4 = GenDTKT([DT1_ulfa,DT2_ulfa],[x4_ulfa,R1_ulfa,x2_ulfa,R2_ulfa],12,int(2.5*60000))
    DT5,KT5 = GenDTKT([DT1_hak,DT2_hak],[x4_hak,R1_hak,x2_hak,R2_hak],12,int(2.5*60000))

    DT = np.vstack([DT1,DT2,DT3,DT4,DT5])
    KT = np.vstack([KT1,KT2,KT3,KT4,KT5])

    # One S0 per tensor pair. The count comes from KT (previously hard-coded
    # as 650000) so it stays aligned if the per-family sample sizes change.
    S0Dist = BoxUniform(low=torch.tensor([lower_S0]), high=torch.tensor([upper_S0]))
    S0 = S0Dist.sample([len(KT)])
    S0 = np.array(S0).reshape(len(S0),1)

    # Simulate every row. Any row whose signal is negative or exceeds 4*S0
    # gets its KT halved and its DT redrawn, then is simulated again, until
    # all rows pass. The last simulator argument is drawn uniformly in
    # [30, 50) (presumably an SNR — TODO confirm).
    indx = np.arange(len(KT))
    Obs = np.zeros([len(KT),len(gtabHCP20.bvecs)])
    kk = 0
    while len(indx)>0:
        for i in tqdm.tqdm(indx):
            Obs[i] = CustomDKISimulator(DT[i],KT[i],gtabHCP20,S0[i],np.random.rand()*20 + 30)

        # Vectorized validity check over all rows; S0 is (N, 1) and broadcasts.
        invalid = (Obs > 4*S0).any(axis=1) | (Obs < 0).any(axis=1)
        indxNew = np.flatnonzero(invalid)
        KT[indxNew] = KT[indxNew]/2
        # NOTE(review): the final argument looks like a sample count of 1, so a
        # single DT appears to be broadcast to every rejected row — confirm
        # whether len(indxNew) draws were intended.
        DT[indxNew] = GenDTKT([DT1_full,DT2_full],[x4_full,R1_full,x2_full,R2_full],kk,1)[0]

        indx = indxNew
        kk+=1
    Par = np.hstack([DT,KT,S0])
    Obs = torch.tensor(Obs).float()
    Par = torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE()

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train(stop_after_epochs= 100)
    posteriorFull = inference.build_posterior(density_estimator)

    os.system('say "Network done."')
    if not os.path.exists(f"{network_path}/DKIHCPMid.pickle"):
        with open(f"{network_path}/DKIHCPMid.pickle", "wb") as handle:
            pickle.dump(posteriorFull, handle)

In [None]:
# Load HCP patient 3's b=1000 and b=3000 acquisitions and reslice both from
# 1.5 mm to 2.5 mm voxels. The b=1000 data are brain-masked with median_otsu;
# mask2 is the autocropped mask. Both shells are cropped to the mask's
# bounding box and stacked along the volume axis.
i = 3
pat_dir = './HCP_data/Pat' + str(i) + '/'

fdwi, bvalloc, bvecloc = pat_dir + 'diff_1k.nii.gz', pat_dir + 'bvals_1k.txt', pat_dir + 'bvecs_1k.txt'
fdwi3, bvalloc3, bvecloc3 = pat_dir + 'diff_3k.nii.gz', pat_dir + 'bvals_3k.txt', pat_dir + 'bvecs_3k.txt'

bvalsHCP, bvecsHCP = np.loadtxt(bvalloc), np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)
bvalsHCP3, bvecsHCP3 = np.loadtxt(bvalloc3), np.loadtxt(bvecloc3)
gtabHCP3 = gradient_table(bvalsHCP3, bvecsHCP3)
gtabExt = gradient_table(np.hstack((bvalsHCP, bvalsHCP3)), np.vstack((bvecsHCP, bvecsHCP3)))

zooms_in, zooms_out = (1.5, 1.5, 1.5), (2.5, 2.5, 2.5)
otsu_kwargs = dict(vol_idx=range(10, 50), median_radius=4, numpass=1, dilate=2)

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, zooms_in, zooms_out)
maskdata, mask = median_otsu(data, autocrop=False, **otsu_kwargs)
_, mask2 = median_otsu(data, autocrop=True, **otsu_kwargs)

data3, affine, img = load_nifti(fdwi3, return_img=True)
data3, affine = reslice(data3, affine, zooms_in, zooms_out)

# Bounding box of the (uncropped) brain mask, inclusive on both ends.
true_indices = np.argwhere(mask)
min_coords = true_indices.min(axis=0)
max_coords = true_indices.max(axis=0)
bbox = tuple(slice(lo, hi + 1) for lo, hi in zip(min_coords, max_coords))

maskdata = maskdata[bbox]
axial_middle = maskdata.shape[2] // 2
# NOTE(review): data3 is cropped but not multiplied by the mask, unlike
# maskdata — confirm this is intended.
maskdata3 = data3[bbox]

TestData = np.concatenate([maskdata[:, :, axial_middle, :], maskdata3[:, :, axial_middle, :]], axis=-1)
TestData4D = np.concatenate([maskdata, maskdata3], axis=-1)

In [None]:
# Posterior-mean DKI parameters for every brain voxel of the middle axial
# slice. Each estimate has 22 entries (6 DT + 15 KT + S0), matching the
# [DT, KT, S0] parameter layout used to train posteriorFull.
ArrShape = TestData4D[:,:,axial_middle,0].shape
# The first maskdata.shape[-1] volumes of TestData4D are the masked b=1000
# data (previously hard-coded as 69); all-zero there means outside the mask.
n_b1k = maskdata.shape[-1]
NoiseEst = np.zeros(ArrShape + (22,))
torch.manual_seed(10)
for i in tqdm.tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        if np.sum(TestData4D[i, j, axial_middle, :n_b1k]) == 0:
            continue
        posterior_samples_1 = posteriorFull.sample((InferSamples,), x=TestData4D[i, j, axial_middle, true_indx20], show_progress_bars=False)
        NoiseEst[i, j] = posterior_samples_1.mean(axis=0)
os.system('say "Finished sampling."')


In [None]:
# Make each voxel's diffusion tensor (first 6 entries) valid by clipping
# negative eigenvalues, and keep the remaining entries unchanged. The loop
# bounds come from NoiseEst itself (previously hard-coded as 62 x 68).
NoiseEst2 = np.zeros_like(NoiseEst)
for i, j in np.ndindex(NoiseEst.shape[:2]):
    NoiseEst2[i, j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i, j]))), NoiseEst[i, j, 6:]])

In [None]:

# DKI scalar maps (MK, AK, RK, MKT, KFA) from the clipped SBI estimates.
# Voxels that are all-zero in the masked b=1000 volumes stay 0. The map
# shape and the b=1000 split come from the data (previously 62 x 68 and 69).
slice_shape = NoiseEst2.shape[:2]
n_b1k = maskdata.shape[-1]
MK_SBIFull  = np.zeros(slice_shape)
AK_SBIFull  = np.zeros(slice_shape)
RK_SBIFull  = np.zeros(slice_shape)
MKT_SBIFull = np.zeros(slice_shape)
KFA_SBIFull = np.zeros(slice_shape)
for i in tqdm.tqdm(range(slice_shape[0])):
    for j in range(slice_shape[1]):
        if np.sum(TestData4D[i, j, axial_middle, :n_b1k]) == 0:
            continue
        # Entries [:6] are the DT and [6:21] the KT.
        Metrics = DKIMetrics(NoiseEst2[i, j][:6], NoiseEst2[i, j][6:21])
        MK_SBIFull[i, j] = Metrics[0]
        AK_SBIFull[i, j] = Metrics[1]
        RK_SBIFull[i, j] = Metrics[2]
        MKT_SBIFull[i, j] = Metrics[3]
        KFA_SBIFull[i, j] = Metrics[4]

In [None]:
# Conventional DKI fit (dipy, NLLS) of the middle slice on the same selected
# volumes (true_indx20), then the same five scalar maps as the SBI cell.
# The map shape and the b=1000 split come from the data (previously 62 x 68 and 69).
dkimodelNL = dki.DiffusionKurtosisModel(gtabHCP20, fit_method='NLLS')
dkifitNL = dkimodelNL.fit(TestData[:, :, true_indx20])
slice_shape = TestData.shape[:2]
n_b1k = maskdata.shape[-1]
MK_NLFull  = np.zeros(slice_shape)
AK_NLFull  = np.zeros(slice_shape)
RK_NLFull  = np.zeros(slice_shape)
MKT_NLFull = np.zeros(slice_shape)
KFA_NLFull = np.zeros(slice_shape)
for i in range(slice_shape[0]):
    for j in range(slice_shape[1]):
        if np.sum(TestData4D[i, j, axial_middle, :n_b1k]) == 0:
            continue
        Metrics = DKIMetrics(dkifitNL[i, j].lower_triangular(), dkifitNL[i, j].kt)
        MK_NLFull[i, j] = Metrics[0]
        AK_NLFull[i, j] = Metrics[1]
        RK_NLFull[i, j] = Metrics[2]
        MKT_NLFull[i, j] = Metrics[3]
        KFA_NLFull[i, j] = Metrics[4]

In [None]:
# Replace undefined (NaN) SBI KFA values with 1, in place.
np.putmask(KFA_SBIFull, np.isnan(KFA_SBIFull), 1)


In [None]:
# Middle-slice SBI kurtosis maps, masked with mask2 and shown on a grey scale
# from 0 to 1 centred at 0.5. Only the KFA map gets a colorbar.
pdf_opts = dict(format='pdf', bbox_inches='tight', transparent=True)
sbi_panels = [
    (MK_SBIFull, 'MKSBI20.pdf'),
    (AK_SBIFull, 'AKSBI20.pdf'),
    (RK_SBIFull, 'RKSBI20.pdf'),
    (MKT_SBIFull, 'MKTSBI20.pdf'),
    (KFA_SBIFull, 'KFASBI20.pdf'),
]
for metric_map, file_name in sbi_panels:
    tnorm = TwoSlopeNorm(vmin=-0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map * mask2[:, :, axial_middle]).T, norm=tnorm, cmap='gray')
    if file_name == 'KFASBI20.pdf':
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(FigLoc + file_name, **pdf_opts)
    plt.show()

In [None]:
# Middle-slice NLLS kurtosis maps, with the same masking and grey scale as the
# SBI maps. Only the KFA map gets a colorbar.
pdf_opts = dict(format='pdf', bbox_inches='tight', transparent=True)
nl_panels = [
    (MK_NLFull, 'MKNL20.pdf'),
    (AK_NLFull, 'AKNL20.pdf'),
    (RK_NLFull, 'RKNL20.pdf'),
    (MKT_NLFull, 'MKTNL20.pdf'),
    (KFA_NLFull, 'KFANL20.pdf'),
]
for metric_map, file_name in nl_panels:
    tnorm = TwoSlopeNorm(vmin=-0, vcenter=0.5, vmax=1)
    plt.imshow((metric_map * mask2[:, :, axial_middle]).T, norm=tnorm, cmap='gray')
    if file_name == 'KFANL20.pdf':
        plt.colorbar()
    plt.xticks([])
    plt.yticks([])
    if Save:
        plt.savefig(FigLoc + file_name, **pdf_opts)
    plt.show()

# Supplemental Fig 6

In [None]:
# Output directory for Supplemental Fig. 6. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
FigLoc = image_path + 'Fig_S6/'
os.makedirs(FigLoc, exist_ok=True)

In [None]:
# Supplemental Fig. 6: for each of the five metrics, overlay density
# histograms on [0, 1] of the simulated values (ParTest) and the HCP values
# (TrueMets). Only the first panel carries a legend.
hist_opts = dict(density=True, range=[0, 1])
for i in range(5):
    plt.hist(ParTest[:, i], color=SBIFit, label='Simulated', **hist_opts)
    plt.hist(TrueMets[:, i], alpha=0.8, color='gray', label='HCP', **hist_opts)
    if i == 0:
        plt.legend(fontsize=32, columnspacing=0.3, handlelength=0.4, handletextpad=0.1)
    if Save:
        plt.savefig(FigLoc + 'EgMetricDKI_' + str(i) + '.pdf',
                    format='pdf', bbox_inches='tight', transparent=True)
    plt.show()

# Supplemental Fig 7

In [None]:
# Relative directory for SBI network weights (trailing slash kept for concatenation).
save_path = "./SBI_Weights/"

In [None]:
# Output directory for Supplemental Fig. 7. exist_ok avoids the
# check-then-create race of a separate os.path.exists test.
FigLoc = image_path + 'Fig_S7/'
os.makedirs(FigLoc, exist_ok=True)

## a

In [None]:
# Fig. S7(a) inputs: HCP patient 1's b=1000 shell, resliced to 2.5 mm voxels
# and brain-masked (autocropped), plus a 20-volume farthest-point subset of
# its gradient table.
fdwi = './HCP_data/Pat'+str(1)+'/diff_1k.nii.gz'
bvalloc = './HCP_data/Pat'+str(1)+'/bvals_1k.txt'
bvecloc = './HCP_data/Pat'+str(1)+'/bvecs_1k.txt'

bvalsHCP = np.loadtxt(bvalloc)
bvecsHCP = np.loadtxt(bvecloc)
gtabHCP = gradient_table(bvalsHCP, bvecsHCP)

data, affine, img = load_nifti(fdwi, return_img=True)
data, affine = reslice(data, affine, (1.5,1.5,1.5), (2.5,2.5,2.5))
# NOTE(review): axial_middle comes from the uncropped `data` but later indexes
# the autocropped `maskdata`; if autocrop trims the z-axis these are
# different slices — confirm this is intended.
axial_middle = data.shape[2] // 2
maskdata, mask = median_otsu(data, vol_idx=range(10, 50), median_radius=3,
                             numpass=1, autocrop=True, dilate=2)

# Greedy farthest-point selection of 20 volumes: index 0 plus 19 more, chosen
# over every row of bvecsHCP (no b>0 filter). Each step adds the row farthest
# from its nearest selected row; ties go to the lowest index. The
# nearest-distance vector is updated incrementally instead of being
# recomputed over all selected columns every step.
distance_matrix = squareform(pdist(bvecsHCP))
selected_indices = [0]
nearest = distance_matrix[0].copy()
for _ in range(19):
    score = nearest.copy()
    score[selected_indices] = -np.inf  # never re-pick a selected row
    next_index = int(np.argmax(score))
    selected_indices.append(next_index)
    nearest = np.minimum(nearest, distance_matrix[next_index])

bvalsHCP20 = bvalsHCP[selected_indices]
bvecsHCP20 = bvecsHCP[selected_indices]
gtabHCP20 = gradient_table(bvalsHCP20, bvecsHCP20)

In [None]:
# Train (or load the cached) SBI posterior for DTI + S0 on the 20-volume
# scheme gtabHCP20.
if os.path.exists(f"{network_path}/DTIHCPMid.pickle"):
    with open(f"{network_path}/DTIHCPMid.pickle", "rb") as handle:
        posterior20 = pickle.load(handle)
else:
    np.random.seed(1)
    torch.manual_seed(1)
    Obs = []
    Par = []
    for i in tqdm.tqdm(range(TrainingSamples)):
        params = priorS0.sample()
        dt = ComputeDTI(params[:-1])
        # About 20% of draws are passed through ForceLowFA.
        if(np.random.rand()<0.2):
            dt = ForceLowFA(dt)
        Obs.append(CustomSimulator(dt,gtabHCP20,params[-1],np.random.rand()*30))
        # Label each sample with the tensor that was actually simulated. The
        # previous label, a fresh ComputeDTI(params), dropped the ForceLowFA
        # change. NOTE(review): confirm ComputeDTI(params) and
        # ComputeDTI(params[:-1]) agreed otherwise.
        Par.append(np.hstack([mat_to_vals(dt),params[-1]]))

    Obs = np.array(Obs)
    Par = np.array(Par)
    Obs = torch.tensor(Obs).float()
    Par= torch.tensor(Par).float()

    # Create inference object. Here, NPE is used.
    inference = SNPE(prior=priorS0)

    # generate simulations and pass to the inference object
    inference = inference.append_simulations(Par, Obs)

    # train the density estimator and build the posterior
    density_estimator = inference.train()
    posterior20= inference.build_posterior(density_estimator)
    if not os.path.exists(f"{network_path}/DTIHCPMid.pickle"):
        with open(f"{network_path}/DTIHCPMid.pickle", "wb") as handle:
            pickle.dump(posterior20, handle)

In [None]:
# Per-voxel posterior mode (histogram_mode of each parameter's samples) of the
# 7 DTI + S0 parameters on the middle axial slice. The output shape comes
# from the data (previously hard-coded as 55 x 64).
ArrShape = maskdata[:,:,axial_middle,0].shape
NoiseEst = np.zeros(ArrShape + (7,))
for i in tqdm.tqdm(range(ArrShape[0])):
    for j in range(ArrShape[1]):
        # Reseed for every voxel, including skipped ones, so each voxel's
        # samples are drawn from the same RNG state.
        torch.manual_seed(10)
        if np.sum(maskdata[i,j,axial_middle,:]) == 0:
            continue  # all-zero (masked-out) voxel
        posterior_samples_1 = posterior20.sample((InferSamples,), x=maskdata[i,j,axial_middle,selected_indices],show_progress_bars=False)
        NoiseEst[i,j] = np.array([histogram_mode(p) for p in posterior_samples_1.T])
            


In [None]:
# Clip negative eigenvalues of each voxel's SBI tensor, then compute MD (mean
# eigenvalue) and FA. NaN FA values, e.g. from all-zero background tensors,
# are set to 0. The map shape comes from NoiseEst (previously 55 x 64), and
# the two per-voxel passes are merged since voxels are independent.
slice_shape = NoiseEst.shape[:2]
NoiseEst2 = np.zeros_like(NoiseEst)
MD_SBI20 = np.zeros(slice_shape)
FA_SBI20 = np.zeros(slice_shape)
for i, j in np.ndindex(slice_shape):
    NoiseEst2[i,j] = np.hstack([mat_to_vals(clip_negative_eigenvalues(vals_to_mat(NoiseEst[i,j]))),NoiseEst[i,j,-1]])
    Eigs = np.linalg.eigh(vals_to_mat(NoiseEst2[i,j,:6]))[0]
    MD_SBI20[i,j] = np.mean(Eigs)
    FA_SBI20[i,j] = FracAni(Eigs,np.mean(Eigs))
FA_SBI20[np.isnan(FA_SBI20)] = 0

In [None]:
# Conventional DTI fit (dipy TensorModel, NLLS, also returning S0) of the
# middle slice on the same 20 selected volumes. FA20 and MD20 are its scalar maps.
gtab2 = gtabHCP20
slice_signal = maskdata[:, :, axial_middle, selected_indices]
tenmodel = dti.TensorModel(gtab2, fit_method='NLLS', return_S0_hat=True)
tenfit = tenmodel.fit(slice_signal)
fit_evals = tenfit.evals
FA20 = dti.fractional_anisotropy(fit_evals)
MD20 = dti.mean_diffusivity(fit_evals)

In [None]:
# SBI mean-diffusivity map. Its colour limits are stored in vmin/vmax so the
# conventional-fit MD map can share the same scale.
img = plt.imshow(MD_SBI20.T, cmap='gray')
plt.axis('off')
vmin, vmax = img.get_clim()
if Save:
    plt.savefig(FigLoc + 'HCP_SBI_MD_20.pdf',
                format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Conventional-fit MD map on the SBI map's colour limits, with a
# scientific-notation colorbar.
md_image = plt.imshow(MD20.T, cmap='gray', vmin=vmin, vmax=vmax)
plt.axis('off')
cbar = plt.colorbar(md_image)
cbar.formatter.set_powerlimits((0, 0))
if Save:
    plt.savefig(FigLoc + 'HCP_WLS_MD_20.pdf',
                format='pdf', bbox_inches='tight', transparent=True)

In [None]:
# Signed difference between the conventional (MD20) and SBI (MD_SBI20) MD
# maps, on a diverging colour map centred at 0 with ticks at min/0/max.
# The norm gets its own name so it does not rebind the module-level
# `norm` (scipy.stats.norm, imported at the top of the file).
data = MD20.T-MD_SBI20.T
diff_norm = TwoSlopeNorm(vmin=np.min(data), vcenter=0, vmax=np.max(data))
plt.imshow(data,cmap='seismic',norm=diff_norm)
plt.axis('off')
cbar = plt.colorbar()
ticks = [np.min(data), 0, np.max(data)]
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'HCP_MD_Diff_20.pdf',format='pdf',bbox_inches='tight',transparent=True)

## b

In [None]:
# Zero the conventional-fit FA wherever the masked signal on the middle slice
# sums to 0 (outside the brain), as the SBI FA map already is there.
# Vectorized over the slice instead of hard-coding its 55 x 64 shape.
background = np.sum(maskdata[:, :, axial_middle, :], axis=-1) == 0
FA20[background] = 0

In [None]:
# SBI FA map first. Its colour limits are reused for the conventional-fit FA
# map so both share one scale; only the second map gets a colorbar.
pdf_opts = dict(format='pdf', bbox_inches='tight', transparent=True)

img = plt.imshow(FA_SBI20.T, cmap='gray')
plt.axis('off')
vmin, vmax = img.get_clim()
if Save:
    plt.savefig(FigLoc + 'HCP_SBI_FA_20.pdf', **pdf_opts)
plt.show()

plt.imshow(FA20.T, cmap='gray', vmin=vmin, vmax=vmax)
plt.axis('off')
cbar = plt.colorbar()
cbar.formatter.set_powerlimits((0, 0))
if Save:
    plt.savefig(FigLoc + 'HCP_WLS_FA_20.pdf', **pdf_opts)

In [None]:
# Signed difference between the conventional (FA20) and SBI (FA_SBI20) FA
# maps for the 20-volume scheme, on a diverging colour map centred at 0.
# This previously plotted FAFull - FA_SBIFull (the full-scheme maps) under the
# "_20" file name. The norm gets its own name so it does not rebind the
# module-level `norm` (scipy.stats.norm); TwoSlopeNorm is already imported at
# the top of the file.
data = FA20.T-FA_SBI20.T
diff_norm = TwoSlopeNorm(vmin=np.min(data),vcenter=0, vmax=np.max(data))
plt.imshow(data,cmap='seismic',norm=diff_norm)
plt.axis('off')
cbar = plt.colorbar()
ticks = [np.min(data), 0, np.max(data)]
cbar.set_ticks(ticks)
cbar.formatter.set_powerlimits((0, 0))
if Save: plt.savefig(FigLoc+'HCP_FA_Diff_20.pdf',format='pdf',bbox_inches='tight',transparent=True)