In [None]:
##
#Import Packages
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.optimize import curve_fit
from scipy.stats import iqr

##
#SBI Specific Packages
from sbi import analysis as analysis
from sbi import utils as utils

In [None]:
##
#User Configuration - Replace Each Placeholder Path Before Running the Notebook
DirPath = '/PATH/TO/bin/'              # added to sys.path so the custom ImportData / FreedAnalytical modules can be imported
InputPath = '/PATH/TO/Posterior.pkl'   # pickled, previously trained posterior object
DataPath = '/PATH/TO/ExampleData/'     # passed to ImportTextDataDWSSFP to read the example acquisition data
OutputPath = '/PATH/TO/Output/'        # directory the final figure is written to (keep the trailing slash)

In [None]:
##
#Load the Previously Trained Posterior (Presumably an sbi Posterior - it is Used via .sample() Below)
#SECURITY NOTE: pickle.load Can Execute Arbitrary Code - Only Load Posterior Files From a Trusted Source
with open(InputPath, mode="rb") as posterior_file:
    posterior = pickle.load(posterior_file)

In [None]:
##
#Import Custom Functions
#Guard the Append so Re-running this Cell Does Not Keep Adding Duplicate sys.path Entries
if DirPath not in sys.path:
    sys.path.append(DirPath)
#NOTE(review): Star Imports Hide Where Names Come From. The Custom Names Used Later in this Notebook Are
#ImportTextDataDWSSFP, FreedDWSSFPTensor_Conditional_SBIWrapper and FreedDWSSFPTensor_curve_fit - Consider Importing Them Explicitly
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load Data
#ImportTextDataDWSSFP (Custom Helper) Returns, in Order:
#  bvecs      - b-vectors (3xn)
#  FlipAngles - Flip Angles (degrees) (1xn)
#  tau        - Diffusion Gradient Duration (seconds) (1xn)
#  G          - Diffusion Gradient Amplitude (G/cm - Equivalent to mT/m Divided by 10) (1xn)
#  TRs        - Repetition Times (seconds) (1xn)
#  b0s        - Array Defining b0 Locations (b0 = 1, dwi = 0) (1xn)
#b0s is Unused Below; it is Named Rather Than Assigned to `_` so IPython's Last-Output Variable is Not Overwritten
bvecs, FlipAngles, tau, G, TRs, b0s = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#Obtain Arbitrary Tensor / Signal Pairs
from sbi.inference import simulate_for_sbi

##
#Seed the Random Number Generators so the Simulated Test Set (and the Resulting Figure) is Reproducible
#NOTE(review): Both torch and the legacy numpy global RNG are seeded because the RNG used inside the
#custom simulator is not visible from this notebook - confirm which one(s) it actually draws from
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

##
#Define Number of Simulations
nSim = 1000

#Define SNR Level (wrt b0 - assumes Signal Column 0 is a b0 Measurement)
SNR = 5

##
#Define Prior Bounds for the 6 Tensor Elements
#The First 3 are Dxx, Dyy, Dzz (per the Figure Labels Below); the Last 3 are Presumably the Off-Diagonal Elements - Confirm
PriorLow = [0, 0, 0, -0.00025, -0.00025, -0.00025]
PriorHigh = [0.0005, 0.0005, 0.0005, 0.00025, 0.00025, 0.00025]

##
#Define Prior (Uniform)
prior = utils.BoxUniform(low=torch.tensor(PriorLow), high=torch.tensor(PriorHigh))

##
#Define Range of T1 (ms), T2 (ms) & B1 (normalised) for Forward Simulator
T1Range = [300,1200]
T2Range = [20,80]
B1Range = [0.2,1.2]

##
#Define Simulator
simulator = lambda theta: FreedDWSSFPTensor_Conditional_SBIWrapper(theta,G,tau,TRs,FlipAngles,bvecs,B1Range,T1Range,T2Range,Cond=True)

##
#Estimate Tensor / Signal Pairs - Exaggerate no. Simulations to Account for NaNs (i.e. Invalid Tensors)
DArb, SArb = simulate_for_sbi(simulator, prior, nSim*5)

##
#Add Noise: Gaussian with sigma = (Signal Column 0) / SNR, Followed by a Magnitude Operation
#Only the Signal Columns are Perturbed - the Trailing 3 Columns are Passed to the NLLS Model Separately Below
SignalCols = SArb[:, :-3]
NoiseSigma = SArb[:, :1] / SNR  # (N, 1) so it broadcasts across every signal column of a row
SArb[:, :-3] = torch.abs(SignalCols + torch.randn_like(SignalCols) * NoiseSigma)

In [None]:
##
#Obtain nSim Valid Tensor / Signal Pairs for Evaluation
#A Row is Valid Only if Every Column is Finite (Invalid Tensors Produce NaN Signals, and curve_fit Rejects Non-Finite Data).
#Rows Keep Their Simulation Order, so the First nSim Valid Rows are Selected.
FiniteRows = torch.isfinite(SArb).all(dim=1).nonzero(as_tuple=True)[0]
#Fail Loudly Rather Than Silently Returning Fewer Rows (the Evaluation Loop Below Iterates Over range(nSim))
assert FiniteRows.numel() >= nSim, (
    f'Only {FiniteRows.numel()} of {SArb.shape[0]} simulations are finite; '
    f'increase the over-simulation factor to obtain {nSim} valid pairs'
)
DArb = DArb[FiniteRows[:nSim], :]
SArb = SArb[FiniteRows[:nSim], :]

In [None]:
##
#Evaluation Settings: Number of Posterior Draws per Test Case
Samples = 100

##
#NLLS Starting Point & Box Bounds (One Entry per Tensor Element)
Init = np.array([4, 3, 2, -1, 1, -2]) * 10**-4
low = np.array([0, 0, 0, -1, -1, -1]) * 10**-3
high = np.array([1, 1, 1, 1, 1, 1]) * 10**-3

##
#Pre-allocate Result Arrays
#D_NPE Axes: (posterior draw, tensor element, test case); D_NLLS Axes: (tensor element, test case)
D_NPE = np.zeros([Samples, 6, nSim])
D_NLLS = np.zeros([6, nSim])

##
#Run NPE (Posterior Sampling) and NLLS (curve_fit) on Every Test Case
for k in range(nSim):
    Obs_k = SArb[k, :]

    #NPE: Draw Samples From the Posterior Conditioned on the Full Observation Vector
    D_NPE[:, :, k] = posterior.sample((Samples,), x=Obs_k)

    #NLLS: the Trailing 3 Columns are Not Fitted - They are Rescaled and Passed to the Model as Fixed Inputs
    Fixed1 = SArb[k, -3].numpy()*1E2
    Fixed2 = SArb[k, -2].numpy()*1E5
    Fixed3 = SArb[k, -1].numpy()*1E4

    def Model_k(x, *theta):
        return FreedDWSSFPTensor_curve_fit(x, theta, G, tau, TRs, FlipAngles, bvecs, Fixed1, Fixed2, Fixed3)

    #xdata = 1 appears to be a placeholder; the acquisition parameters reach the model via the closure above
    popt, _ = curve_fit(Model_k, 1, SArb[k, :-3].numpy(), p0=Init, bounds=(low, high), method='trf', maxfev=10**6)
    D_NLLS[:, k] = popt

In [None]:
##
#Collapse the Posterior Draws to Their Mean: Result Has Shape (6 tensor elements, nSim test cases)
D_NPE_Mean = D_NPE.mean(axis=0)

In [None]:
##
#Plot Figure: Histograms of (Estimate - Ground Truth) for the 3 Diagonal Tensor Elements
#Panels 0-2 Show the NPE Posterior Mean, Panels 3-5 Show the NLLS Fit
fig, axs = plt.subplots(1, 6)
fig.set_size_inches(12,3)
labels = [r'$D_{xx}\cdot 10^{-4}$ $mm^2/s$',r'$D_{yy}\cdot 10^{-4}$ $mm^2/s$',r'$D_{zz}\cdot 10^{-4}$ $mm^2/s$']
D_Truth = DArb.numpy()
for k in range(3):
    for ax, D_Est in ((axs[k], D_NPE_Mean), (axs[k+3], D_NLLS)):
        ax.hist((D_Est[k,:]-D_Truth[:,k])*10**4, bins=25)
        ax.set_xlim([-3,3])
        ax.set_yticks([])
        ax.set_xlabel(labels[k])
#Annotate With the SNR Variable So the Labels Stay Correct if SNR is Changed Above
axs[0].text(-0.15, 1.05, f'(c) Difference (NPE; SNR = {SNR})', transform=axs[0].transAxes, size=15)
axs[3].text(-0.15, 1.05, f'(d) Difference (NLLS; SNR = {SNR})', transform=axs[3].transAxes, size=15);

In [None]:
##
#Absolute Error of the NPE Posterior Mean vs Ground Truth, Pooled Over All 6 Tensor Elements and All Test Cases
#D_NPE_Mean is (6, nSim), so it is Transposed to Match DArb's (nSim, 6) Layout
D_NPE_AbsErr = np.abs(np.transpose(D_NPE_Mean)-DArb.numpy())

##
#Calculate Median (Not Mean) Absolute Difference (NPE)
D_NPE_Diff = np.median(D_NPE_AbsErr)
print(f'Median Difference versus Ground Truth = {D_NPE_Diff*10**4:.2f} x10^-4 mm2/s')

##
#Calculate Interquartile Range of the Absolute Difference (NPE)
D_NPE_IQR = iqr(D_NPE_AbsErr)
print(f'IQR = {D_NPE_IQR*10**4:.2f} x10^-4 mm2/s')

In [None]:
##
#Save Figure
#os.path.join Inserts a Separator When OutputPath Lacks a Trailing Slash (Plain Concatenation Would Silently Mis-name the File)
fig.savefig(os.path.join(OutputPath, 'Figure5cd.pdf'), dpi=300, format='pdf', bbox_inches='tight')