In [None]:
##
#Import Packages
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.signal import savgol_filter

##
#SBI Specific Packages
from sbi import analysis as analysis
from sbi import utils as utils

In [None]:
##
#Filesystem locations - replace these placeholders before running the notebook

#Directory holding the custom modules imported below (ImportData, FreedAnalytical)
DirPath = "/PATH/TO/bin/"

#Pickled, pre-trained posterior object
InputPath = "/PATH/TO/Posterior.pkl"

#Directory holding the example acquisition data
DataPath = "/PATH/TO/ExampleData/"

#Directory the output figure is written to
OutputPath = "/PATH/TO/Output/"

In [None]:
##
#Load the pre-trained posterior object from disk
#NOTE: pickle.load can execute arbitrary code stored in the file - only load posteriors you trust.
with open(InputPath, mode="rb") as posterior_file:
    posterior = pickle.load(posterior_file)

In [None]:
##
#Import Custom Functions
#Only add DirPath once, so re-running this cell does not keep growing sys.path
if DirPath not in sys.path:
    sys.path.append(DirPath)
#Star imports kept as-is: they provide ImportTextDataDWSSFP and
#FreedDWSSFPTensor_Conditional_SBIWrapper used below (exact origin module of each - TODO confirm)
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load Acquisition Parameters for the Example Data
#ImportTextDataDWSSFP(DataPath) returns six values, unpacked in this order:
    # bvecs - bvectors (3xn)
    # FlipAngles - Flip Angles (degrees) (1xn)
    # tau - Diffusion Gradient Duration (seconds) (1xn)
    # G - Diffusion Gradient Amplitude (G/cm - Equivalent to mT/m Divided by 10) (1xn)
    # TRs - Repetition Times (seconds) (1xn)
    # b0s - Array Defining b0 locations (b0 = 1, dwi = 0) (1xn) - not needed here, discarded into _

bvecs, FlipAngles, tau, G, TRs, _ = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#Obtain Arbitrary Tensor / Signal Pairs
from sbi.inference import SNPE, simulate_for_sbi

##
#Seed the RNGs so prior draws, forward simulations, noise and posterior samples
#are reproducible under Restart & Run All
Seed = 0
np.random.seed(Seed)
torch.manual_seed(Seed)

##
#Define Number of Simulations
nSim = 1000

#Simulate this many times more pairs than needed: invalid tensors yield NaN signals
#and are discarded in the next cell
SimOversample = 5

#Define SNR Level (wrt b0)
SNR = [5,10,15,20,25,30,35,40,45,50]

##
#Define Prior Bounds (6 tensor coefficients; presumably 3 non-negative diagonal
#then 3 signed off-diagonal elements - confirm against the forward model)
PriorLow = [0, 0, 0, -0.00025, -0.00025, -0.00025]
PriorHigh = [0.0005, 0.0005, 0.0005, 0.00025, 0.00025, 0.00025]

##
#Define Prior (Uniform)
prior = utils.BoxUniform(low=torch.tensor(PriorLow), high=torch.tensor(PriorHigh))

##
#Define Range of T1 (ms), T2 (ms) & B1 (Normalised) for Forward Simulator
T1Range = [300,1200]
T2Range = [20,80]
B1Range = [0.2,1.2]

##
#Define Simulator
def simulator(theta):
    """Forward model passed to simulate_for_sbi.

    Maps a batch of prior draws `theta` to simulated signals by calling
    FreedDWSSFPTensor_Conditional_SBIWrapper with the acquisition parameters
    loaded above (G, tau, TRs, FlipAngles, bvecs), the B1/T1/T2 ranges and Cond=True.
    """
    return FreedDWSSFPTensor_Conditional_SBIWrapper(theta, G, tau, TRs, FlipAngles, bvecs,
                                                    B1Range, T1Range, T2Range, Cond=True)

##
#Estimate Tensor / Signal Pairs - Oversample the Number of Simulations to Account for NaNs (i.e. Invalid Tensors)
DArb, SArb = simulate_for_sbi(simulator, prior, nSim*SimOversample)

In [None]:
##
#Obtain nSim Valid Tensor / Signal Pairs for Evaluation
#A pair is valid only if every entry of its simulated signal row is finite.
#np.flatnonzero always returns a 1-D index array (np.squeeze(np.argwhere(...)) collapses to 0-d for a single hit).
Finite_mask = np.all(np.isfinite(np.asarray(SArb)), axis=1)
Finite_idx = np.flatnonzero(Finite_mask)
if Finite_idx.size < nSim:
    raise RuntimeError(f'Only {Finite_idx.size} finite simulations available but nSim={nSim} are required; '
                       'increase the simulation oversampling factor.')
Finite_idx = Finite_idx[:nSim]
DArb = DArb[Finite_idx,:]
SArb = SArb[Finite_idx,:]

In [None]:
##
#Posterior draws collected for every noisy signal
Samples = 1000

##
#Output buffer indexed as (posterior draw, tensor coefficient, simulation, SNR level)
D_NPE = np.zeros([Samples, 6, nSim, len(SNR)])

##
#NPE evaluation: for each SNR level, corrupt every column except the last three with
#Gaussian noise scaled by column 0 / SNR (then take the magnitude); the last three
#columns are left untouched (presumably the conditioning inputs - confirm). Then draw
#posterior samples for each noisy signal.
for snr_idx, snr_level in enumerate(SNR):
    SNoise = SArb * 1  # fresh copy so SArb itself is never modified
    clean_part = SArb[:, :-3]
    noise = torch.randn(*clean_part.shape) * SArb[:, 0][:, np.newaxis] / snr_level
    SNoise[:, :-3] = np.abs(clean_part + noise)
    for sim_idx in range(nSim):
        D_NPE[:, :, sim_idx, snr_idx] = posterior.sample((Samples,), x=SNoise[sim_idx, :], show_progress_bars=False)

In [None]:
##
#Generate Mean & Standard Deviation as a Function of the Number of Samples
#Entry k-1 along axis 0 summarises the first k posterior draws (same layout as D_NPE).
#Running sums make this O(Samples) instead of re-reducing every prefix (O(Samples^2)).

##
#Initialise Matrices
Mean_nSamples = np.zeros(D_NPE.shape)
STD_nSamples = np.zeros(D_NPE.shape)

##
#Accumulate deviations from the first draw: the shift keeps E[d^2] - E[d]^2 well conditioned
Ref = D_NPE[0,...] if D_NPE.shape[0] > 0 else np.zeros(D_NPE.shape[1:])
RunSum = np.zeros(D_NPE.shape[1:])
RunSumSq = np.zeros(D_NPE.shape[1:])
for k in range(1,D_NPE.shape[0]+1):
    Dev = D_NPE[k-1,...] - Ref
    RunSum += Dev
    RunSumSq += Dev**2
    MeanDev = RunSum/k
    Mean_nSamples[k-1,...] = Ref + MeanDev
    #Population std (ddof=0, as np.std); np.maximum clips tiny negative rounding residue
    STD_nSamples[k-1,...] = np.sqrt(np.maximum(RunSumSq/k - MeanDev**2, 0))

In [None]:
##
#Number of leading posterior-sample counts shown in the figure
nSamples = 50

##
#Collapse each (sample count, SNR) entry to the median absolute value taken over
#the tensor coefficients (axis 1) and the simulated repeats (axis 2)
Mean_abs = np.abs(Mean_nSamples[:nSamples])
STD_abs = np.abs(STD_nSamples[:nSamples])
Mean_nSamples_av = np.median(Mean_abs, axis=(1, 2))
STD_nSamples_av = np.median(STD_abs, axis=(1, 2))
del Mean_abs, STD_abs

In [None]:
##
#Visual smoothing only: Savitzky-Golay (window 5, order 2) along the sample-count
#axis (0) first, then along the SNR axis (1)
def _savgol_2d(grid):
    """Apply the window-5 / order-2 Savitzky-Golay filter along axis 0, then axis 1."""
    along_samples = savgol_filter(grid, 5, 2, axis=0)
    return savgol_filter(along_samples, 5, 2, axis=1)

Mean_nSamples_av_Smooth = _savgol_2d(Mean_nSamples_av)
#Row 0 (a single draw) has zero standard deviation by construction, so it is left out
STD_nSamples_av_Smooth = _savgol_2d(STD_nSamples_av[1:])

In [None]:
##
#Plot Figure: heat maps of the smoothed summaries, posterior-sample count on x, SNR on y
fig, axs = plt.subplots(1, 2)
fig.set_size_inches(12,3)

def _format_panel(ax, im, nCols, title):
    """Attach the colourbar and the shared axis labelling to one heat-map panel.

    ax, im : the Axes and the AxesImage drawn on it.
    nCols  : number of sample-count columns; column i corresponds to i+1 posterior samples.
    title  : panel tag written above the top-left corner of the axes.
    Returns the created colourbar.
    """
    cbar = fig.colorbar(im, extend='both', shrink=0.8, ax=ax)
    cbar.set_label(r'D ($\cdot 10^{-4}$ mm$^{2}$/s)', rotation=90, size=10)
    ax.set_yticks(range(len(SNR)), SNR, size=12)
    #About five x-ticks; clamp the step to >= 1 so small sample ranges cannot give range(..., step=0)
    step = max(int(np.round(nCols/5)), 1)
    ax.set_xticks(range(step-1, nCols, step), range(step, nCols+1, step), size=12)
    ax.set_xlabel('No. Post. Samp', size=12)
    ax.set_ylabel('SNR', size=12)
    ax.text(-0.15, 1.05, title, transform=ax.transAxes, size=12)
    return cbar

nCols = Mean_nSamples_av.shape[0]

##
#Mean
im1 = axs[0].imshow(np.transpose(Mean_nSamples_av_Smooth)*10**4, vmin=1.5, vmax=2, aspect='auto')
cbar = _format_panel(axs[0], im1, nCols, '(a) Mean (Averaged)')

##
#Standard Deviation: the single-sample row was dropped before smoothing, so prepend a NaN
#row (masked by imshow) to keep the columns aligned with the mean panel
STD_padded = np.concatenate((np.full((1, STD_nSamples_av_Smooth.shape[1]), np.nan),
                             STD_nSamples_av_Smooth*10**4), axis=0)
im2 = axs[1].imshow(np.transpose(STD_padded), vmin=0, vmax=0.25, aspect='auto')
cbar = _format_panel(axs[1], im2, nCols, '(b) Standard Deviation (Averaged)')

In [None]:
##
#Save Figure
#os.path.join works whether or not OutputPath ends with a path separator
fig.savefig(os.path.join(OutputPath, 'Figure7ab.pdf'), dpi=300, format='pdf', bbox_inches='tight')