In [None]:
##
#Import Packages
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.optimize import curve_fit
from scipy.stats import iqr

##
#SBI Specific Packages
from sbi import analysis as analysis
from sbi import utils as utils

In [None]:
##
#Paths - edit these before running the notebook
#   DirPath    : directory holding the custom modules (ImportData, FreedAnalytical)
#   InputPath  : pickled posterior object
#   DataPath   : directory containing the example acquisition data
#   OutputPath : directory the figure is written to (keep the trailing '/')
DirPath, InputPath, DataPath, OutputPath = (
    '/PATH/TO/bin/',
    '/PATH/TO/Posterior.pkl',
    '/PATH/TO/ExampleData/',
    '/PATH/TO/Output/',
)

In [None]:
##
#Load Posterior - the pickled posterior object located at InputPath
#(the original referenced an undefined `PostPath`, raising NameError on a fresh kernel)
#NOTE: pickle.load can execute arbitrary code - only load posterior files from a trusted source
with open(InputPath, "rb") as handle:
    posterior = pickle.load(handle)

In [None]:
##
#Import Custom Functions
#Only add DirPath once, so re-running this cell does not keep appending it to sys.path
if DirPath not in sys.path:
    sys.path.append(DirPath)
#NOTE(review): this notebook uses ImportTextDataDWSSFP, FreedDWSSFPTensor_Conditional and
#FreedDWSSFPTensor_curve_fit from these modules - consider importing them explicitly instead of *
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load Data
#   bvecs - bvectors (3xn)
#   FlipAngles - Flip Angles (degrees) (1xn)
#   tau - Diffusion Gradient Duration (seconds) (1xn)
#   G - Diffusion Gradient Amplitude (G/cm - Equivalent to mT/m Divided by 10) (1xn)
#   TRs - Repetition Times (seconds) (1xn)
#   b0s - Array Defining b0 locations (b0 = 1, dwi = 0) (1xn)
#The b0 mask is bound to a named variable rather than `_`, which IPython uses for the last output.

bvecs, FlipAngles, tau, G, TRs, b0s = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#Define Defaults for Evaluation

##
#Define Ground-Truth Diffusion Tensor Elements (mm^2/s)
Dxx, Dyy, Dzz, Dxy, Dxz, Dyz = np.array([4, 3, 2, -1, 1, -2])*10**-4

##
#Define Relaxation Times
#NOTE(review): units are not stated (TRs and tau are loaded in seconds) - confirm the units
#FreedDWSSFPTensor_Conditional expects for T1/T2 (presumably ms)
T1 = 650
T2 = 35

##
#Relative B1
B1 = 1

##
#Number of Repeats - independent noise realisations per SNR level
nRepeats = 1000

##
#Define Number of Posterior Samples per Evaluation
Samples = 100

##
#Ground-truth parameter vector, ordered [Dxx, Dyy, Dzz, Dxy, Dxz, Dyz]
theta=np.array([Dxx, Dyy, Dzz, Dxy, Dxz, Dyz])

In [None]:
##
#Perform SNR comparisons - NPE and NLLS
#For every repeat and SNR level: simulate a noisy signal from the ground-truth tensor `theta`,
#estimate the tensor with NPE (posterior samples) and with NLLS (curve_fit), and store both.

##
#Seed NumPy (noise generation) and PyTorch (posterior sampling) so the evaluation is reproducible
Seed = 0
np.random.seed(Seed)
torch.manual_seed(Seed)

##
#Define SNR levels (wrt b0)
SNR = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]

##
#Define Initial Fitting Point & Bounds (NLLS)
#NOTE(review): Init equals the ground-truth tensor, which favours NLLS - confirm this is intended
Init = np.array([4, 3, 2, -1, 1, -2])*10**-4
low = np.array([0, 0, 0, -1, -1, -1])*10**-3
high = np.array([1, 1, 1, 1, 1, 1])*10**-3

##
#Initialise Matrices - indexed (tensor element, SNR level, repeat)
ThetaSBISNR = np.zeros((6,len(SNR),nRepeats))
ThetaSBISNR_std = np.zeros((6,len(SNR),nRepeats))
ThetaLSSNR = np.zeros((6,len(SNR),nRepeats))

##
#Perform Evaluation
for l in range(nRepeats):
    for k in range(len(SNR)):

        ##
        #Generate Signal + Noise
        #Noise s.d. is taken from the noiseless first measurement S[0] (assumed to be a b0 - confirm).
        #The last 3 entries are left noise-free; presumably these are conditioning values - TODO confirm.
        S = FreedDWSSFPTensor_Conditional(theta,G,tau,TRs,FlipAngles,bvecs,B1,T1,T2)
        NoiseSD = S[0]/SNR[k]
        S[:-3] = np.abs(S[:-3] + np.random.randn(S[:-3].shape[0])*NoiseSD)

        ##
        #Estimate Posterior Samples (progress bars off: this runs nRepeats*len(SNR) times)
        Post = posterior.sample((Samples,), x=torch.from_numpy(S[np.newaxis,:]), show_progress_bars=False)

        ##
        #Summarise Posterior Samples: mean as point estimate, std as uncertainty
        ThetaSBISNR[:,k,l] = torch.mean(Post,dim=0)
        ThetaSBISNR_std[:,k,l] = torch.std(Post,dim=0)

        ##
        #Perform NLLS on the noisy signal without the last 3 entries
        #xdata is the placeholder 1 (presumably unused by FreedDWSSFPTensor_curve_fit - confirm);
        #`params` is named so it does not shadow the ground-truth `theta`
        ThetaLSSNR[:,k,l], LSCov = curve_fit(lambda x, *params: FreedDWSSFPTensor_curve_fit(x, params, G, tau, TRs, FlipAngles, bvecs, B1, T1, T2), 1, S[:-3], p0 = Init, bounds = (low,high),method='trf',maxfev=10**6)

In [None]:
##
#Collapse the repeat axis (axis 2), leaving one mean estimate per tensor element and SNR level
ThetaSBISNR_mean = ThetaSBISNR.mean(axis=2)
ThetaLSSNR_mean = ThetaLSSNR.mean(axis=2)

In [None]:
##
#Plot Figure
#Left panel: NPE, right panel: NLLS. Each panel images the repeat-averaged estimates as
#(tensor element x [ground truth, SNR levels]); column 0 (ground truth) is outlined in green.

def _PlotTensorPanel(ax, Estimates, Truth, SNRLevels, Label):
    """Draw one panel of the SNR comparison figure and return its AxesImage.

    ax        : matplotlib Axes to draw into (its parent figure receives the colorbar)
    Estimates : mean estimates, shape (6, len(SNRLevels)) - one row per tensor element
    Truth     : ground-truth tensor elements, shape (6,)
    SNRLevels : SNR values labelling the estimate columns
    Label     : panel label written above the axes, e.g. '(a) NPE'
    """
    #Prepend the ground truth as column 0 and rescale to units of 1e-4 mm^2/s
    PanelData = np.concatenate((Truth[:,np.newaxis],Estimates),axis=1)*10**4
    im = ax.imshow(PanelData,vmin=-2,vmax=4)

    #Outline the ground-truth column
    ax.add_patch(plt.Rectangle((-0.4, -0.4), 1,5.85, fill=False, color="limegreen", linewidth=3))

    cbar = ax.figure.colorbar(im, extend='both', shrink=0.4, ax=ax)
    cbar.set_label('D (x10$^{-4}$ mm$^{2}$/s)', rotation=90, size=10)

    #Label only the first and last SNR columns
    ax.set_xticks((1,len(SNRLevels)),(SNRLevels[0],np.round(SNRLevels[-1])),size=12)
    ax.set_yticks((0,1,2,3,4,5),('$D_{xx}$','$D_{yy}$','$D_{zz}$','$D_{xy}$','$D_{xz}$','$D_{yz}$'), size=12)
    ax.set_xlabel('SNR', size=12)
    ax.text(-0.15, 1.05, Label, transform=ax.transAxes, size=15)
    return im

fig, axs = plt.subplots(1, 2)
fig.set_size_inches(12,6)

##
#NPE
im1 = _PlotTensorPanel(axs[0], ThetaSBISNR_mean, theta, SNR, '(a) NPE')

##
#NLLS
im2 = _PlotTensorPanel(axs[1], ThetaLSSNR_mean, theta, SNR, '(b) NLLS')

In [None]:
##
#Calculate Percentage Difference versus Ground Truth (median and interquartile range)
#Statistics pool all 6 tensor elements and all SNR levels of the repeat-averaged estimates.
#Assumes no ground-truth element is zero (the difference is divided by |theta|).

def _PercentDifference(Estimates, Truth):
    """Return |Estimates - Truth| as a percentage of |Truth|.

    Estimates has one row per tensor element (6 x number of SNR levels); Truth is the
    length-6 ground truth, broadcast along the SNR axis.
    """
    return np.abs((Estimates-Truth[:,np.newaxis]))/np.abs(Truth[:,np.newaxis])*100

##
#NPE
SNRDiffPct = _PercentDifference(ThetaSBISNR_mean, theta)
SNRDiff = np.median(SNRDiffPct)
print(''.join(['Median Difference versus Ground Truth (NPE) = ','{0:.2f}'.format(SNRDiff), '%']))
SNRDiffIQR = iqr(SNRDiffPct)
print(''.join(['IQR = ','{0:.2f}'.format(SNRDiffIQR), '%']))

##
#NLLS - same statistics for the comparison method shown in panel (b)
SNRDiffLSPct = _PercentDifference(ThetaLSSNR_mean, theta)
SNRDiffLS = np.median(SNRDiffLSPct)
print(''.join(['Median Difference versus Ground Truth (NLLS) = ','{0:.2f}'.format(SNRDiffLS), '%']))
SNRDiffLSIQR = iqr(SNRDiffLSPct)
print(''.join(['IQR (NLLS) = ','{0:.2f}'.format(SNRDiffLSIQR), '%']))


In [None]:
#Save Figure - os.path.join also handles an OutputPath given without a trailing '/'
fig.savefig(os.path.join(OutputPath,'Figure5ab.pdf'),dpi=300,format='pdf',bbox_inches='tight')