In [None]:
##
#Import Packages
import os
import pickle
import sys

import nibabel as nib
import numpy as np
from interruptingcow import timeout

##
#SBI Specific Packages
import torch
from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import SNPE, NPE, MCMCPosterior, posterior_estimator_based_potential, simulate_for_sbi
from sbi.utils import RestrictionEstimator
from sbi.utils.user_input_checks import check_sbi_inputs, process_prior, process_simulator
from sbi.analysis import conditional_corrcoeff, conditional_pairplot, conditional_potential, pairplot, pairplot
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.neural_nets import posterior_nn

In [None]:
##
#Import custom functions
sys.path.append('YourPath')
from ImportData import *
from FreedAnalytical import *

In [None]:
##
#Load the trained posterior object from disk.
#NOTE(review): pickle.load can execute arbitrary code embedded in the file -- only
#load posterior files that you created yourself or otherwise trust.
with open("YourPosterior.pkl", mode="rb") as PosteriorFile:
    posterior = pickle.load(PosteriorFile)

In [None]:
##
#Define the data path, then unpack the eleven values returned by ImportDataDWSSFP
#(image objects such as Data/T1Map/T2Map/B1Map/Mask plus per-volume acquisition
# arrays; exact contents defined in ImportData -- TODO confirm)
DataPath = 'YourDataPath'
(Data, T1Map, T2Map, B1Map, Mask,
 noisefloor, bvecs, FlipAngles, tau, G, TRs) = ImportDataDWSSFP(DataPath)

In [None]:
##
#Normalise Data by S0
#For each distinct flip angle:
#  1. Remove the mean noise floor in quadrature from the tau == 0 volumes acquired
#     at that flip angle.
#  2. Average those volumes and divide by the theoretical signal from FreedDWSSFP,
#     then apply the mask, to obtain S0.
#  3. Divide every volume acquired at that flip angle by S0.

##
#Initialise Arrays
#Explicit float dtype: Data.data*0 would inherit an integer dtype if the raw data were
#stored as integers, silently truncating the normalised ratios on assignment.
Data.norm = np.zeros(Data.data.shape, dtype=np.float64)

##
#Define unique flip angles and the index of the first volume acquired at each
Values, Index = np.unique(FlipAngles, return_index=True)

##
#Estimate normalised data for each flip angle
#(FlipAngles, tau and noisefloor are indexed with boolean masks here, so they are
# presumably 1D numpy arrays with one entry per volume -- TODO confirm)
for idx, k in enumerate(Index):
    #Volumes acquired at this flip angle, and the tau == 0 subset used to estimate S0
    FAVolumes = FlipAngles == Values[idx]
    S0Volumes = (tau == 0) & FAVolumes
    FAData = Data.data[:,:,:,FAVolumes]
    #Calculate theoretical signal amplitude using the acquisition parameters of volume k
    #(the fifth argument, 0, presumably sets diffusivity to zero -- TODO confirm)
    b0 = FreedDWSSFP(G[k], tau[k], TRs[k], FlipAngles[k]*B1Map.data, 0, T1Map.data, T2Map.data)
    #Identify S0 (incorporating noisefloor contribution)
    #NOTE(review): voxels whose signal is below the noise floor take the square root of a
    #negative number and give NaN here; this is unchanged from the original behaviour.
    NoiseFloorSq = np.mean(noisefloor[S0Volumes])**2
    S0 = np.mean((Data.data[:,:,:,S0Volumes]**2 - NoiseFloorSq)**0.5, axis=3)/b0*Mask.data
    #Divide by S0. Voxels with S0 == 0 (e.g. outside the mask) are left at 0 instead of
    #becoming inf/NaN with divide-by-zero warnings; NaN S0 values still propagate.
    Data.norm[:,:,:,FAVolumes] = np.divide(
        FAData,
        S0[:,:,:,np.newaxis],
        out=np.zeros(FAData.shape, dtype=np.float64),
        where=(S0 != 0)[:,:,:,np.newaxis],
    )

In [None]:
##
#Perform fitting
#For every voxel inside Mask, build the observation vector (normalised signal followed by
#scaled B1, T1 and T2 values) and draw Samples posterior samples of the six tensor elements.
#Voxels where sampling raises or exceeds the time limit are flagged with 1 in TensorErr and
#keep all-zero samples in Tensor.

##
#Define Number of samples
Samples = 100

##
#Per-voxel time limit (seconds) for posterior sampling.
#NOTE(review): interruptingcow implements timeouts with signals, so this presumably only
#works on Unix and in the main thread -- confirm for your platform.
SampleTimeout = 5

##
#Initialise Output Array
Tensor = np.zeros((*T1Map.data.shape,Samples,6))
TensorErr = np.zeros(T1Map.data.shape)

for k in range(Data.data.shape[0]):
    for l in range(Data.data.shape[1]):
        for m in range(Data.data.shape[2]):
            #Skip voxels outside the mask
            if Mask.data[k,l,m] == 0:
                continue
            try:
                with timeout(SampleTimeout, exception=RuntimeError):
                    #Generate Data Vector; the B1/T1/T2 divisors presumably reproduce the
                    #scaling used when the posterior was trained -- TODO confirm
                    DataVec = np.concatenate((Data.norm[k,l,m,:],[B1Map.data[k,l,m]/100],[T1Map.data[k,l,m]/100000],[T2Map.data[k,l,m]/10000]))
                    Tensor[k,l,m,:,:] = posterior.sample((Samples,), x=torch.from_numpy(DataVec[np.newaxis,:]), show_progress_bars = False)
            #Catch Exception, not a bare except: KeyboardInterrupt/SystemExit must still stop the
            #loop. The timeout raises RuntimeError, which is an Exception subclass.
            except Exception:
                TensorErr[k,l,m] = 1
    #Progress report: index of the completed first-axis slice
    print(k)

##
#Report how many masked voxels failed or timed out
print(f'{int(TensorErr.sum())} of {int(np.count_nonzero(Mask.data))} masked voxels failed or timed out')

In [None]:
##
#Output mean tensor estimates
#Writes one 3D NIfTI per tensor element containing the mean over the posterior samples.
#Voxels outside the mask, or flagged in TensorErr, hold zeros because Tensor was
#zero-initialised.
#OutputPath is treated as a directory; os.path.join inserts the separator whether or not
#a trailing slash is given.
OutputPath = 'YourPath'
#Names in the order of the last axis of Tensor
TensorComponents = ('D11', 'D22', 'D33', 'D12', 'D13', 'D23')
for c, Name in enumerate(TensorComponents):
    nib.save(nib.Nifti1Image(np.mean(Tensor[:,:,:,:,c],axis=3),Data.aff), os.path.join(OutputPath, f'{Name}_Mean.nii.gz'))

In [None]:
##
#Output posterior tensor distributions
#Writes one 4D NIfTI per tensor element holding all posterior samples along the 4th axis.
#OutputPath (set in the previous cell) is treated as a directory; os.path.join inserts the
#separator whether or not a trailing slash is given.
#Names in the order of the last axis of Tensor
TensorComponents = ('D11', 'D22', 'D33', 'D12', 'D13', 'D23')
for c, Name in enumerate(TensorComponents):
    nib.save(nib.Nifti1Image(Tensor[:,:,:,:,c],Data.aff), os.path.join(OutputPath, f'{Name}_Samples.nii.gz'))