In [None]:
##
#Import Packages (standard library)
import os
import pickle
import sys
from timeit import default_timer as timer

##
#Import Packages (third party)
import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.optimize import curve_fit
from scipy.stats import iqr

##
#SBI Specific Packages
from sbi import analysis as analysis
from sbi import utils as utils

In [None]:
##
#Paths - each may be overridden with an environment variable so the notebook
#can run on a different machine without editing this cell. The literal
#defaults are placeholders and must be replaced (or overridden) before running.
#Directory paths are normalised to end with a path separator because file
#names may be built from them by plain string concatenation.

##
#Define Path to Code Database (directory added to sys.path for the custom modules)
DirPath = os.path.join(os.environ.get('DWSSFP_BIN_DIR', '/PATH/TO/bin/'), '')

##
#Define Path for Posterior Object (pickle file)
InputPath = os.environ.get('DWSSFP_POSTERIOR', '/PATH/TO/Posterior.pkl')

##
#Define Path to Example Data
DataPath = os.path.join(os.environ.get('DWSSFP_DATA_DIR', '/PATH/TO/ExampleData/'), '')

##
#Define Output Path
OutputPath = os.path.join(os.environ.get('DWSSFP_OUTPUT_DIR', '/PATH/TO/Output/'), '')

In [None]:
##
#Load the trained posterior object from disk
#WARNING: unpickling can execute arbitrary code - only load posterior files from a trusted source
with open(InputPath, mode="rb") as posterior_file:
    posterior = pickle.load(posterior_file)

In [None]:
##
#Import Custom Functions
#Guard the insertion so re-running this cell does not append duplicate sys.path entries
if DirPath not in sys.path:
    sys.path.append(DirPath)
#NOTE(review): star imports hide where names come from. The names used later in
#this notebook are ImportTextDataDWSSFP, FreedDWSSFPTensor_Conditional_SBIWrapper
#and FreedDWSSFPTensor_curve_fit - presumably provided by these two modules;
#consider switching to explicit imports once their origin is confirmed.
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load Data - acquisition parameters returned by ImportTextDataDWSSFP:
#   bvecs      - b-vectors (3xn)
#   FlipAngles - Flip Angles (degrees) (1xn)
#   tau        - Diffusion Gradient Duration (seconds) (1xn)
#   G          - Diffusion Gradient Amplitude (G/cm - Equivalent to mT/m Divided by 10) (1xn)
#   TRs        - Repetition Times (seconds) (1xn)
#   b0s        - Array Defining b0 locations (b0 = 1, dwi = 0) (1xn) - not needed here, so discarded

bvecs, FlipAngles, tau, G, TRs, _ = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#Obtain Arbitrary Tensor / Signal Pairs
from sbi.inference import SNPE, simulate_for_sbi

##
#Define Number of Simulations
nSim = 1000

##
#Oversampling Factor - simulate more pairs than needed because invalid tensors
#yield NaN signals that are discarded before evaluation
OverSample = 5

##
#Define SNR Level (wrt b0)
SNR = 20

##
#Random Seed - makes the simulated tensors and the added noise reproducible
Seed = 0

##
#Define Prior Bounds
#NOTE(review): six entries, presumably the unique diffusion-tensor elements -
#confirm ordering/units against FreedDWSSFPTensor_Conditional_SBIWrapper
PriorLow = [0, 0, 0, -0.00025, -0.00025, -0.00025]
PriorHigh = [0.0005, 0.0005, 0.0005, 0.00025, 0.00025, 0.00025]

##
#Define Prior (Uniform)
prior = utils.BoxUniform(low=torch.tensor(PriorLow), high=torch.tensor(PriorHigh))

##
#Define Range of T1 (ms), T2 (ms) & B1 (Normalised) for Forward Simulator
T1Range = [300,1200]
T2Range = [20,80]
B1Range = [0.2,1.2]

##
#Define Simulator
simulator = lambda theta: FreedDWSSFPTensor_Conditional_SBIWrapper(theta,G,tau,TRs,FlipAngles,bvecs,B1Range,T1Range,T2Range,Cond=True)

##
#Seed torch (prior sampling and noise) and numpy (in case the simulator draws from it)
torch.manual_seed(Seed)
np.random.seed(Seed)

##
#Estimate Tensor / Signal Pairs - oversample to account for NaNs (i.e. Invalid Tensors)
DArb, SArb = simulate_for_sbi(simulator, prior, nSim*OverSample)

##
#Add Noise to the signal columns only - the last three columns hold the
#conditioning values (B1, T1, T2) and are left untouched.
#Noise std is the first signal column (presumably the b0) divided by SNR;
#the magnitude is taken afterwards so signals stay non-negative.
SigCols = slice(None, -3)
NoisyMeasurements = SArb[:, SigCols] + torch.randn_like(SArb[:, SigCols]) * SArb[:, :1] / SNR
SArb[:, SigCols] = torch.abs(NoisyMeasurements)

In [None]:
##
#Obtain nSim Valid Tensor / Signal Pairs for Evaluation
#A pair is valid only if its whole row is finite: invalid tensors produce NaNs,
#and curve_fit (used below) rejects NaN/inf data, so checking just the first
#column is not sufficient.
Finite_idx = np.flatnonzero(np.isfinite(SArb.numpy()).all(axis=1))

#The timing loops below iterate over exactly nSim rows, so fail early and
#clearly if too few valid simulations were produced
if Finite_idx.size < nSim:
    raise RuntimeError(
        f'Only {Finite_idx.size} of {SArb.shape[0]} simulations produced finite '
        f'signals but nSim={nSim} are required - simulate more pairs.'
    )

DArb = DArb[Finite_idx[:nSim], :]
SArb = SArb[Finite_idx[:nSim], :]

In [None]:
##
#Perform Time Evaluation - NLLS
#Every simulated signal is fitted on its own; the wall-clock time of each
#curve_fit call is stored in milliseconds.

##
#Per-signal NLLS timings (ms)
TimeNLLS = np.zeros(nSim)

##
#Starting point and box bounds for the six fitted parameters (NLLS)
Init = np.array([4, 3, 2, -1, 1, -2]) * 10.0**-4
low = np.array([0, 0, 0, -1, -1, -1]) * 10.0**-3
high = np.ones(6) * 10.0**-3

##
#Convert inputs to numpy once, outside the timed loop, to reduce overhead.
#The last three columns are the conditioning values (B1, T1, T2).
#NOTE(review): the scale factors presumably undo a normalisation applied by the
#simulator wrapper - confirm against FreedAnalytical.
SArb_NLLS = SArb[:, :-3].numpy()
B1 = SArb[:, -3].numpy() * 1E2
T1 = SArb[:, -2].numpy() * 1E5
T2 = SArb[:, -1].numpy() * 1E4

##
#Time one NLLS fit per simulated signal
for k in range(nSim):
    t_start = timer()
    fitted_params = curve_fit(
        lambda x, *theta: FreedDWSSFPTensor_curve_fit(x, theta, G, tau, TRs, FlipAngles, bvecs, B1[k], T1[k], T2[k]),
        1,
        SArb_NLLS[k, :],
        p0=Init,
        bounds=(low, high),
        method='trf',
        maxfev=10**6,
    )[0]
    TimeNLLS[k] = 1000 * (timer() - t_start)

In [None]:
##
#Perform Time Evaluation - NPE

##
#Number of posterior draws per evaluation: 21 values log-spaced from 1 to 10,000
Samples = np.round(np.power(10, np.linspace(0, 4, num=21)))

##
#Timing matrix (ms) - rows follow Samples, columns follow the simulated signals
TimeNPE = np.zeros((Samples.size, nSim))

##
#Time posterior sampling conditioned on each simulated signal
for row, n_draws in enumerate(Samples):
    for k in range(nSim):
        t_start = timer()
        draws = posterior.sample((int(n_draws),), x=SArb[k, :], show_progress_bars=False)
        TimeNPE[row, k] = 1000 * (timer() - t_start)

In [None]:
##
#Summarise timings for plotting: element 0 is NLLS, elements 1.. follow Samples
MedianTimes = np.concatenate(([np.median(TimeNLLS)], np.median(TimeNPE, axis=1)))
IQRTimes = np.concatenate(([iqr(TimeNLLS)], iqr(TimeNPE, axis=1)))

In [None]:
##
#Weighted least-squares fit of median NPE time against number of samples:
#    time ~= Coeff[0] + Coeff[1] * N
#Each row is weighted by 1/IQR so noisier timings carry less influence.

##
#Design-matrix columns: intercept term and number of samples
Off = np.ones(len(Samples))
X = Samples.copy()

##
#Per-row weights (NPE entries only - index 0 of IQRTimes is NLLS)
Weights = 1 / IQRTimes[1:]

##
#Weighted design matrix (one row per sample count, columns [1, N])
A = np.column_stack((Off, X)) * Weights[:, np.newaxis]

##
#Solve via the pseudo-inverse: Coeff = [intercept (ms), slope (ms per sample)]
Coeff = np.dot(np.linalg.pinv(A), Weights * MedianTimes[1:])

##
#Evaluate the fitted line at each sample count
EvRecon = Coeff[0] + Coeff[1] * X

In [None]:
##
#Plot Figure
#(a) median evaluation time (error bars: IQR) for NLLS and selected NPE sample counts
#(b) median NPE time versus number of posterior samples, with the weighted linear fit

##
#Pick the NPE sample counts shown in panel (a) by looking them up in Samples
#(nearest match) instead of hard-coding indices, so the panel stays correct if
#the Samples grid changes. Offset by one because MedianTimes[0] / IQRTimes[0]
#hold the NLLS result. Labels are built from the actual sample counts.
BarSampleCounts = [1, 10, 100, 1000, 10000]
NPEBarIdx = [int(np.argmin(np.abs(Samples - n))) for n in BarSampleCounts]
BarIdx = [0] + [i + 1 for i in NPEBarIdx]
xLabels = ['NLLS'] + ['NPE\n({:,})'.format(int(Samples[i])) for i in NPEBarIdx]

fig, axs = plt.subplots(1, 2)
fig.set_size_inches(12,5)

##
#Panel (a) - NOTE: fixed y-range, values above 120 ms are clipped
axs[0].bar(range(len(BarIdx)), MedianTimes[BarIdx], yerr=IQRTimes[BarIdx], tick_label=xLabels)
axs[0].set_ylim([0,120])
axs[0].set_ylabel('Evaluation time (ms)',fontsize=12)
axs[0].text(-0.2, 1.05, '(a)', transform=axs[0].transAxes, size=22)

##
#Panel (b) - x axis is log10(number of samples); the fit is linear in N
FitLabel = ''.join(['Linear Fit (','{0:.2f}'.format(Coeff[1]*1000),r' $\mathrm{\mu} s \cdot \mathrm{N}_{\mathrm{Samples}}$ +', '{0:.2f}'.format(Coeff[0]),' ms)'])
axs[1].scatter(np.log10(Samples),MedianTimes[1:],s=60,marker='x',label='Estimated Evaluation Time (Median)')
axs[1].set_xlim([0,4])
axs[1].set_ylim([0,120])
axs[1].plot(np.log10(Samples),EvRecon,linestyle='--',linewidth=3,label=FitLabel)
axs[1].axhline(MedianTimes[0],linestyle='--',linewidth=3,color='tab:orange',label = 'NLLS Evaluation Time',alpha=0.5)
axs[1].set_xticks([0,1,2,3,4])
axs[1].set_xticklabels(['1','10','100','1,000','10,000'])
axs[1].set_ylabel('Evaluation Time (ms)',fontsize=12)
axs[1].set_xlabel(r'Number of Samples (log$_{10}$)',fontsize=12)
axs[1].legend(loc='upper left')
axs[1].text(-0.2, 1.05, '(b)', transform=axs[1].transAxes, size=22);

In [None]:
##
#Identify no. Samples for NPE versus NLLS Assuming Equivalent Analysis Time
#Solve MedianNLLS = Intercept + Slope * N for N, rounded to the nearest integer
Intercept, Slope = Coeff[0], Coeff[1]
noSamplesEqTime = np.round((MedianTimes[0] - Intercept) / Slope)

print(f'Initialisation Time = {Intercept:.2f} ms')
print(f'Time Per Posterior Estimate = {Slope*1000:.2f} us')
print(f'No. Posterior samples in NLLS Equivalent Time = {noSamplesEqTime:.0f}')

In [None]:
##
#Save Figure
#Create the output directory if needed, and build the file name with
#os.path.join so a missing trailing separator in OutputPath does not
#silently produce a wrong file name (e.g. '/outFigure6.pdf').
os.makedirs(OutputPath, exist_ok=True)
fig.savefig(os.path.join(OutputPath, 'Figure6.pdf'),dpi=300,format='pdf',bbox_inches='tight')