In [None]:
##
#Import Packages
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import iqr

##
#SBI Specific Packages
from sbi import analysis as analysis
from sbi import utils as utils

In [None]:
##
#Locations used by this notebook (placeholders - edit before running).
#Folder paths end with '/'.
#   DirPath    : folder holding the custom modules (appended to sys.path below)
#   InputPath  : pickled sbi posterior object
#   DataPath   : folder passed to ImportTextDataDWSSFP
#   OutputPath : folder the output figure is written to
DirPath, InputPath, DataPath, OutputPath = (
    '/PATH/TO/bin/',
    '/PATH/TO/Posterior.pkl',
    '/PATH/TO/ExampleData/',
    '/PATH/TO/Output/',
)

In [None]:
##
#Load the trained posterior object from disk.
#NOTE: pickle.load can execute arbitrary code - only load posterior files from a trusted source.
with open(InputPath, mode="rb") as PosteriorFile:
    posterior = pickle.load(PosteriorFile)

In [None]:
##
#Import Custom Functions
#DirPath must be on sys.path before the project modules below can be found,
#so the append has to stay ahead of the imports.
sys.path.append(DirPath)
#NOTE(review): star imports hide where names come from. The names used later in this
#notebook are ImportTextDataDWSSFP (presumably from ImportData) and
#FreedDWSSFPTensor_Conditional (presumably from FreedAnalytical) - consider importing them explicitly.
#NOTE(review): a star import can also silently shadow earlier names (e.g. `utils`, `analysis`) - verify.
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load Data
#ImportTextDataDWSSFP returns, in order:
#   bvecs      - bvectors (3xn)
#   FlipAngles - Flip Angles (degrees) (1xn)
#   tau        - Diffusion Gradient Duration (seconds) (1xn)
#   G          - Diffusion Gradient Amplitude (G/cm - Equivalent to mT/m Divided by 10) (1xn)
#   TRs        - Repetition Times (seconds) (1xn)
#   b0s        - Array Defining b0 locations (b0 = 1, dwi = 0) (1xn)
#b0s is not used below. It is bound to a named variable rather than `_` because assigning
#`_` in an IPython namespace stops IPython from updating its last-output variables.
bvecs, FlipAngles, tau, G, TRs, b0s = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#Evaluate T1, T2 & B1 Conditioning versus Ground Truth

##
#Define Relaxation & Flip Angle (B1) Arrays - 10 Elements per Array.
#np.linspace is used for the fractional-step grids: np.arange with a non-integer step
#has an element count (and endpoint inclusion) that depends on floating-point rounding.
#NOTE(review): T1/T2 appear to be in ms (see plot labels) while TRs/tau are documented
#in seconds - confirm the units FreedDWSSFPTensor_Conditional expects.
NumValues = 10
T1 = np.arange(300,1300,100)           #300 ... 1200 in steps of 100
T2 = np.linspace(20, 80, NumValues)     #20 ... 80 in steps of 60/9
B1 = np.linspace(0.2, 1.2, NumValues)   #0.2 ... 1.2 in steps of 1/9
assert T1.shape == T2.shape == B1.shape == (NumValues,), "parameter grids must each have NumValues elements"

##
#Define Number of Posterior Samples per Evaluation
Samples = 100

##
#Define Ground-Truth Diffusion Tensor (mm^2/s)
Dxx, Dyy, Dzz, Dxy, Dxz, Dyz = np.array([4, 3, 2, -1, 1, -2])*10**-4

#Convert into Array (component order: Dxx, Dyy, Dzz, Dxy, Dxz, Dyz)
theta=np.array([Dxx, Dyy, Dzz, Dxy, Dxz, Dyz])

##
#Initialise Array - indexed [B1 index, T1 index, T2 index, tensor component]
ThetaOutput = np.zeros((B1.shape[0],T1.shape[0],T2.shape[0],theta.shape[0]))

##
#Seed torch's global RNG so the posterior draws (and hence the figure and the summary
#statistics below) are reproducible when the notebook is re-run.
#NOTE(review): assumes posterior.sample draws from torch's global generator - confirm for the sbi version in use.
Seed = 0
torch.manual_seed(Seed)

##
#Perform Parameter Estimation
#np.ndindex walks the grid in C order (T2 index fastest), i.e. the same order as nested loops.
for iB1, iT1, iT2 in np.ndindex(ThetaOutput.shape[:3]):
    #Simulate the signal for the ground-truth tensor at this (B1, T1, T2)
    S = FreedDWSSFPTensor_Conditional(theta,G,tau,TRs,FlipAngles,bvecs,B1[iB1],T1[iT1],T2[iT2])
    #Draw posterior samples conditioned on that signal (leading batch dimension of 1)
    Post = posterior.sample((Samples,), x=torch.from_numpy(S[np.newaxis,:]))
    #Posterior mean is the point estimate; convert explicitly (works for CPU or GPU tensors)
    ThetaOutput[iB1,iT1,iT2,:] = torch.mean(Post,dim=0).detach().cpu().numpy()


In [None]:
##
#Define Elements to Create Output Plots
#Each plot sweeps one parameter while the other two are held at a fixed grid index:
#  B1 index 7 -> B1 = 0.2 + 7/9 (~0.98)
#  T1 index 3 -> T1 = 600
#  T2 index 2 -> T2 = 20 + 2*60/9 (~33.3)
FixedB1Index, FixedT1Index, FixedT2Index = 7, 3, 2

ThetaB1 = np.squeeze(ThetaOutput[:,FixedT1Index,FixedT2Index])   #(len(B1), len(theta))
ThetaT1 = np.squeeze(ThetaOutput[FixedB1Index,:,FixedT2Index])   #(len(T1), len(theta))
ThetaT2 = np.squeeze(ThetaOutput[FixedB1Index,FixedT1Index,:])   #(len(T2), len(theta))

In [None]:
##
#Create Figure
#Panels (a)-(c): each image shows the ground-truth tensor in column 0 (outlined in green)
#followed by the posterior-mean estimates across the swept parameter, scaled by 1e4.
TensorLabels = ('$D_{xx}$','$D_{yy}$','$D_{zz}$','$D_{xy}$','$D_{xz}$','$D_{yz}$')

def _PlotConditioningPanel(ax, Estimates, Truth, SweepValues, XLabel, PanelLabel):
    """Draw one ground-truth-versus-estimate image panel into `ax`.

    ax          : matplotlib Axes to draw into (its figure receives the colorbar)
    Estimates   : array of shape (len(SweepValues), len(Truth)) of estimated tensors
    Truth       : 1-D ground-truth tensor components (six, matching TensorLabels)
    SweepValues : 1-D array of swept parameter values; first/last become x tick labels
    XLabel      : x-axis label
    PanelLabel  : panel tag drawn above the top-left corner, e.g. '(a)'

    Returns the AxesImage created by imshow.
    """
    #Column 0 = ground truth, columns 1..N = estimates, displayed in units of 1e-4
    Image = np.concatenate((Truth[:,np.newaxis],np.transpose(Estimates)),axis=1)*10**4
    im = ax.imshow(Image,vmin=-2,vmax=4)
    #Outline the ground-truth column
    ax.add_patch(plt.Rectangle((-0.4, -0.4), 1,5.85, fill=False, color="limegreen", linewidth=3))
    cbar = ax.figure.colorbar(im, extend='both', shrink=1, ax=ax)
    cbar.set_label('D (x10$^{-4}$ mm$^{2}$/s)', rotation=90, size=10)
    #Estimate columns start at index 1, so ticks sit on the first and last estimate
    ax.set_xticks((1,SweepValues.shape[0]),(SweepValues[0],np.round(SweepValues[-1],decimals=1)), size=12)
    ax.set_yticks((0,1,2,3,4,5),TensorLabels, size=12)
    ax.set_xlabel(XLabel, size=12)
    ax.text(-0.12, 1.05, PanelLabel, transform=ax.transAxes, size=15)
    return im

fig, axs = plt.subplots(2, 2)
fig.set_size_inches(12,6)

im1 = _PlotConditioningPanel(axs[0,0], ThetaT1, theta, T1, '$T_{1}$ (ms)', '(a)')
im2 = _PlotConditioningPanel(axs[0,1], ThetaT2, theta, T2, '$T_{2}$ (ms)', '(b)')
im3 = _PlotConditioningPanel(axs[1,0], ThetaB1, theta, B1, '$B_{1}$', '(c)')

#Only three panels are drawn; hide the unused fourth axes (its space is kept in the layout)
axs[1,1].axis('off');

In [None]:
##
#Calculate Median Difference & IQR of the absolute percentage error versus ground truth.
#Each statistic is taken over all tensor components and sweep values of one sweep
#(B1, T1 or T2), and the three per-sweep values are then averaged. The reported numbers
#are therefore means of per-sweep medians/IQRs, not the median/IQR of the pooled errors.
#The ground-truth components are the divisor, so none of them may be zero.

def _PercentError(Estimates, Truth):
    """Return |estimate - truth| / |truth| * 100 with shape (len(Truth), len(Estimates))."""
    TruthCol = Truth[:,np.newaxis]
    return np.abs(np.transpose(Estimates)-TruthCol)/np.abs(TruthCol)*100

B1Err = _PercentError(ThetaB1, theta)
T1Err = _PercentError(ThetaT1, theta)
T2Err = _PercentError(ThetaT2, theta)

##
#Median (np.median flattens the 2-D error array)
B1Diff = np.median(B1Err)
T1Diff = np.median(T1Err)
T2Diff = np.median(T2Err)
AvDiff = (B1Diff + T1Diff + T2Diff)/3
print(f'Median Difference versus Ground Truth = {AvDiff:.2f}%')

##
#Calculate IQR (scipy's iqr also uses the flattened array by default)
B1DiffIQR = iqr(B1Err)
T1DiffIQR = iqr(T1Err)
T2DiffIQR = iqr(T2Err)
AvDiffIQR = (B1DiffIQR + T1DiffIQR + T2DiffIQR)/3
print(f'IQR = {AvDiffIQR:.2f}%')

In [None]:
##
#Save Figure (creating the output folder if it does not exist yet).
#os.path.join works whether or not OutputPath ends with a '/'.
os.makedirs(OutputPath, exist_ok=True)
fig.savefig(os.path.join(OutputPath,'Figure4ac.pdf'),dpi=300,format='pdf',bbox_inches='tight')