In [None]:
##
#Import Packages
import numpy as np
import nibabel as nib
import sys
import pickle


##
#SBI Specific Packages
import torch
from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import SNPE, NPE, MCMCPosterior, posterior_estimator_based_potential, simulate_for_sbi
from sbi.utils import RestrictionEstimator
from sbi.utils.user_input_checks import check_sbi_inputs, process_prior, process_simulator
from sbi.analysis import conditional_corrcoeff, conditional_pairplot, conditional_potential, pairplot, pairplot
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.neural_nets import posterior_nn

In [None]:
##
#Import custom functions (Define your path)
sys.path.append('YourPath')
from FreedAnalytical import *
from ImportData import *

In [None]:
##
#Point DataPath at the folder holding the acquisition-parameter text files, then load them
DataPath = 'YourDataPath'
#ImportDataDWSSFP returns eleven values; the first six are discarded and only the acquisition parameters are kept
(
    _, _, _, _, _, _,
    bvecs, FlipAngles, tau, G, TRs,
) = ImportDataDWSSFP(DataPath)

In [None]:
##
#Define Priors
#Box-uniform prior over six parameters: the first three span [0, 0.0005], the last three span [-0.00025, 0.00025]
#(presumably the diagonal and off-diagonal diffusion tensor elements - TODO confirm ordering and units against the simulator)
num_dim = 6
prior_low = torch.tensor([0,0,0,-0.00025,-0.00025,-0.00025]) * torch.ones(num_dim)
prior_high = torch.tensor([0.0005,0.0005,0.0005,0.00025,0.00025,0.00025]) * torch.ones(num_dim)
prior = utils.BoxUniform(low=prior_low, high=prior_high)

In [None]:
##
#Define Simulator - wrapped in a single-argument function because only theta may be passed in;
#the acquisition parameters (G, tau, TRs, FlipAngles, bvecs) are read from the notebook namespace at call time.
#Note - simulator is conditioned on arbitrary T1, T2 & B1 values.
def simulator(theta):
    return FreedDWSSFPTensor_Conditional_SBIWrapper(theta,G,tau,TRs,FlipAngles,bvecs)

In [None]:
##
#Restriction estimator: a classifier that learns which prior regions give invalid simulations
#(for the diffusion tensor, where eigenvalues < 0), so that those regions can be excluded from the prior

##
#Step 1 - draw parameters from the full prior and simulate them to obtain classifier training data
number_of_simulations = 500_000
theta_RestrictionEstimator, x_RestrictionEstimator = simulate_for_sbi(
    simulator,
    prior,
    num_simulations=number_of_simulations,
)

##
#Step 2 - hand the simulations to the restriction estimator and train its classifier
restriction_estimator = RestrictionEstimator(prior=prior)
restriction_estimator.append_simulations(
    theta_RestrictionEstimator,
    x_RestrictionEstimator,
)
classifier = restriction_estimator.train()

##
#Step 3 - build the restricted prior from the trained classifier
restricted_prior = restriction_estimator.restrict_prior()

In [None]:
##
#Generate the training set for the SBI inference network: parameters are drawn from the restricted prior and pushed through the simulator
number_of_simulations = 1_000_000
theta, x = simulate_for_sbi(
    simulator,
    restricted_prior,
    num_simulations=number_of_simulations,
)

In [None]:
##
#Build a noise-augmented training set: one noise-free copy of the simulations plus `Rounds` noisy copies
Rounds = 5

##
#Estimate average Signal Amplitude of b0 data (NaN-aware mean over the signal columns where tau == 0)
#The last 3 columns of x are excluded from the signal throughout this cell - presumably the conditioning
#values (T1, T2 & B1) appended by the simulator wrapper; TODO confirm
b0 = torch.nanmean(x[:,:-3][:,tau==0])
#Define minimum and maximum SNR
SNR = [2,50]

##
#Replicate the variable and signal arrays (1 noise-free copy followed by `Rounds` copies that will receive noise)
thetaSNR = theta.repeat(Rounds+1,1)
xSNR = x.repeat(Rounds+1,1)

##
#Add noise to every copy after the first. Each noisy row gets its own SNR drawn uniformly from [SNR[0], SNR[1]];
#zero-mean Gaussian noise with standard deviation b0/SNR is added to the signal columns and the magnitude is taken.
num_noise_free = theta.shape[0]
noisy_shape = xSNR[num_noise_free:,:-3].shape
#Draw the unit-variance noise before the SNR values (same order of random draws as a single combined expression)
unit_noise = torch.randn(noisy_shape)
snr_per_row = torch.distributions.uniform.Uniform(SNR[0], SNR[1]).sample([noisy_shape[0],1])
#torch.abs keeps the computation in torch instead of relying on implicit NumPy conversion of the tensor
xSNR[num_noise_free:,:-3] = torch.abs(xSNR[num_noise_free:,:-3]+unit_noise*b0/snr_per_row)
#Release the large temporaries so they do not linger in the kernel namespace
del unit_noise, snr_per_row

In [None]:
##
#Neural density estimator for the posterior ("nsf" model)
neural_posterior = posterior_nn(model="nsf")

##
#Set up the inference object, register the noise-augmented simulations, train, and build the posterior
#Note - the original (unrestricted) `prior` is passed here, while the training parameters were drawn from `restricted_prior`
inference = SNPE(prior=prior,density_estimator=neural_posterior)
inference.append_simulations(thetaSNR, xSNR)
density_estimator = inference.train()
posterior = inference.build_posterior()

In [None]:
##
#Serialise the trained posterior to disk with pickle (choose your own file name)
#Note - only unpickle this file from a trusted source; loading a pickle can execute arbitrary code
with open("YourPosterior.pkl", "wb") as posterior_file:
    pickle.dump(posterior, posterior_file)