In [None]:
##
#Import Packages
import os
import pickle
import sys

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import iqr

##
#SBI Specific Packages
from sbi import analysis as analysis
from sbi import utils as utils

In [None]:
##
#Filesystem Locations - Edit These Before Running

# Directory holding the project helper modules (ImportData, FreedAnalytical); added to sys.path below
DirPath = '/PATH/TO/bin/'
# Pickled, pre-trained sbi posterior object
InputPath = '/PATH/TO/Posterior.pkl'
# Directory of example acquisition data, read by ImportTextDataDWSSFP
DataPath = '/PATH/TO/ExampleData/'
# Directory the output figure is written to
OutputPath = '/PATH/TO/Output/'

In [None]:
##
#Load the Pre-Trained Posterior Object
# NOTE(review): pickle.load can execute arbitrary code - only load posterior files from a trusted source.
with open(InputPath, mode="rb") as PosteriorFile:
    posterior = pickle.load(PosteriorFile)

In [None]:
##
#Import Custom Functions from the Code Database
# Guard the append so re-running this cell does not stack duplicate sys.path entries.
if DirPath not in sys.path:
    sys.path.append(DirPath)
# NOTE(review): star imports hide where names come from; this notebook uses
# ImportTextDataDWSSFP and FreedDWSSFPTensor_Conditional_SBIWrapper from these
# modules - consider importing those names explicitly.
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load Acquisition Parameters
    # bvecs - b-vectors (3xn)
    # FlipAngles - Flip Angles (degrees) (1xn)
    # tau - Diffusion Gradient Duration (seconds) (1xn)
    # G - Diffusion Gradient Amplitude (G/cm - Equivalent to mT/m Divided by 10) (1xn)
    # TRs - Repetition Times (seconds) (1xn)
    # b0s - Array Defining b0 locations (b0 = 1, dwi = 0) (1xn)

# b0s is bound by name rather than to `_`, which in IPython would shadow the last-output variable.
bvecs, FlipAngles, tau, G, TRs, b0s = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#Obtain Arbitrary Tensor / Signal Pairs
from sbi.inference import SNPE, simulate_for_sbi

##
#Seed Random Number Generators so the Simulated Evaluation Set is Reproducible
# torch drives sbi's prior sampling; numpy is seeded as well in case the simulator
# draws from it (NOTE(review): confirm which RNG the wrapper uses).
Seed = 0
torch.manual_seed(Seed)
np.random.seed(Seed)

##
#Define Number of Simulations
nSim = 1000

##
#Define Prior Bounds (first three entries bounded below by 0, last three symmetric about 0)
PriorLow = [0, 0, 0, -0.00025, -0.00025, -0.00025]
PriorHigh = [0.0005, 0.0005, 0.0005, 0.00025, 0.00025, 0.00025]

##
#Define Prior (Uniform)
prior = utils.BoxUniform(low=torch.tensor(PriorLow), high=torch.tensor(PriorHigh))

##
#Define Range of T1 (ms), T2 (ms) & B1 (normalised) for Forward Simulator
# NOTE(review): TRs and tau are documented in seconds while T1/T2 here are in ms - confirm the wrapper's expected units.
T1Range = [300,1200]
T2Range = [20,80]
B1Range = [0.2,1.2]

##
#Define Simulator
def simulator(theta):
    """Map a batch of prior tensor parameters `theta` to simulated DW-SSFP signals."""
    return FreedDWSSFPTensor_Conditional_SBIWrapper(theta,G,tau,TRs,FlipAngles,bvecs,B1Range,T1Range,T2Range,Cond=True)

##
#Estimate Tensor / Signal Pairs - Oversample to Account for NaNs (i.e. Invalid Tensors), which are discarded afterwards
SimOversample = 5
DArb, SArb = simulate_for_sbi(simulator, prior, nSim*SimOversample)

In [None]:
##
#Obtain nSim Valid Tensor / Signal Pairs for Evaluation
# A pair is kept only if every signal value is finite (invalid tensors are expected
# to yield NaN/Inf signals). torch.nonzero(...).flatten() always returns a 1-D index
# tensor, unlike np.squeeze(np.argwhere(...)), which collapses to 0-d for a single hit.
ValidRows = torch.isfinite(SArb).all(dim=1)
Finite_idx = torch.nonzero(ValidRows).flatten()[:nSim]
if len(Finite_idx) < nSim:
    raise RuntimeError(
        f"Only {len(Finite_idx)} of {SArb.shape[0]} simulations produced finite signals; "
        f"{nSim} are required - increase the number of simulations."
    )
DArb = DArb[Finite_idx,:]
SArb = SArb[Finite_idx,:]

In [None]:
##
#Define Number of Posterior Samples per Evaluation
Samples = 100

##
#Initialise Array of Posterior Samples: (Samples, 6 tensor parameters, nSim evaluations)
D_NPE = np.zeros([Samples,6,nSim])

##
#Perform NPE Evaluation - one set of posterior draws per simulated signal
# Progress bars are disabled: with nSim separate calls they would flood the cell output.
for k in range(nSim):
    D_NPE[:,:,k] = posterior.sample((Samples,), x=SArb[k,:], show_progress_bars=False)

In [None]:
##
#Collapse the Sample Axis: Posterior Mean per Tensor Parameter and Evaluation (6 x nSim)
D_NPE_Mean = D_NPE.mean(axis=0)

In [None]:
##
#Plot Histograms of (Posterior Mean - Ground Truth), One Panel per Tensor Parameter
fig, axs = plt.subplots(1, 6)
fig.set_size_inches(6,3)
for iComp, ax in enumerate(axs):
    CompErr = D_NPE_Mean[iComp,:] - DArb[:,iComp].numpy()
    ax.hist(CompErr, bins=100)
    ax.set_xlim([-1.5E-5,1.5E-5])
    ax.set_yticks([])


In [None]:
##
#Summarise Absolute Error of Posterior Means versus Ground Truth
# Median and interquartile range, pooled over all tensor parameters and evaluations.
AbsErr = np.abs(D_NPE_Mean.T - DArb.numpy())

DDiff = np.median(AbsErr)
print('Median Difference versus Ground Truth = ' + '{0:.2f}'.format(DDiff*10**6) + r' x10^-6 mm2/s')

DIQR = iqr(AbsErr)
print('IQR = ' + '{0:.2f}'.format(DIQR*10**6) + r' x10^-6 mm2/s')


In [None]:
##
#Save Figure
# Create the output directory if needed; os.path.join also copes with OutputPath
# lacking a trailing slash (plain string joining would fuse the directory and file names).
os.makedirs(OutputPath, exist_ok=True)
fig.savefig(os.path.join(OutputPath,'Figure4d.pdf'),dpi=300,format='pdf',bbox_inches='tight')