In [None]:
##
#Import Packages
import os
import pickle
import sys

import nibabel as nib
import numpy as np
import torch

##
#SBI Specific Packages
from sbi import utils as utils
from sbi.inference import SNPE, simulate_for_sbi
from sbi.utils import RestrictionEstimator
from sbi.neural_nets import posterior_nn

In [None]:
##
#Define Path to Code Database (added to sys.path so the custom modules can be imported)
DirPath = '/PATH/TO/bin/'

#Define Output Path for Posterior Object (written with pickle)
OutputPath = '/PATH/TO/OUTPUT.pkl'

#Define Path to Input Data
DataPath = '/PATH/TO/Data'

##
#Fix Random Seeds so Simulations, Added Noise and Network Training can be Reproduced
#(seeds the torch and NumPy global generators; change Seed for an independent run)
Seed = 0
torch.manual_seed(Seed)
np.random.seed(Seed)

In [None]:
##
#Condition the Network on T1/T2/B1? (True/False)
#This flag is forwarded to the simulator wrapper as its Cond argument.
Cond = True

In [None]:
##
#Number of Simulations (must be Python integers)

##
#Restricted Prior Estimator: simulations used to train the classifier network that
#identifies prior regions producing invalid simulations (tensor eigenvalues < 0)
noSim_RPE = 500_000

##
#NPE Network: simulations drawn from the restricted prior to train the posterior
noSim = 1_000_000

In [None]:
##
#Training-Set Noise Configuration
#Total simulations passed to NPE = noSim x (SNRlevels + 1):
#one noise-free copy plus SNRlevels noisy copies of every simulation.

##
#SNR range [min, max]; every noisy row receives its own SNR drawn uniformly from this range
SNR = [2, 50]

##
#Number of noisy copies (the noise-free copy is not counted)
SNRlevels = 5

In [None]:
##
#Prior Bounds for the Tensor Elements [D11, D22, D33, D12, D13, D23] (mm2/s),
#chosen to suit post-mortem investigations.
#Diagonal elements lie in [0, 5e-4]; off-diagonal elements lie in [-2.5e-4, 2.5e-4].
PriorLow  = [0,    0,    0,    -2.5e-4, -2.5e-4, -2.5e-4]
PriorHigh = [5e-4, 5e-4, 5e-4,  2.5e-4,  2.5e-4,  2.5e-4]

In [None]:
##
#Uniform (Box) Prior over the six tensor elements, bounded element-wise by PriorLow / PriorHigh
prior = utils.BoxUniform(
    low=torch.tensor(PriorLow),
    high=torch.tensor(PriorHigh),
)

In [None]:
##
#[min, max] Ranges Passed to the Forward Simulator:
#T1 (ms), T2 (ms) and B1 (normalised)
T1Range = [300, 1200]
T2Range = [20, 80]
B1Range = [0.2, 1.2]

In [None]:
##
#Import Custom Functions (Including Simulator)
#Only add DirPath once, so re-running this cell does not keep growing sys.path.
if DirPath not in sys.path:
    sys.path.append(DirPath)
#NOTE(review): star imports hide where names come from. Later cells use
#FreedDWSSFPTensor_Conditional_SBIWrapper and ImportTextDataDWSSFP, presumably one from
#each module - confirm and switch to explicit imports.
from FreedAnalytical import *
from ImportData import *

In [None]:
##
#Load Acquisition Parameters from DataPath. Returned, in order:
#    bvecs      - b-vectors (3xn)
#    FlipAngles - flip angles in degrees (1xn)
#    tau        - diffusion gradient duration in seconds (1xn)
#    G          - diffusion gradient amplitude in G/cm (equivalent to mT/m divided by 10) (1xn)
#    TRs        - repetition times in seconds (1xn)
#    b0s        - b0 indicator per volume (b0 = 1, dwi = 0) (1xn)
(bvecs, FlipAngles, tau,
 G, TRs, b0s) = ImportTextDataDWSSFP(DataPath)

In [None]:
##
#DW-SSFP Tensor Forward Simulator
#sbi only needs a callable that maps parameter samples theta to simulated signals.
#This wrapper fixes the acquisition parameters, the T1/T2/B1 ranges and the Cond flag,
#reading them from the notebook namespace at call time, so theta is the only argument.
def simulator(theta):
    return FreedDWSSFPTensor_Conditional_SBIWrapper(
        theta, G, tau, TRs, FlipAngles, bvecs,
        B1Range, T1Range, T2Range,
        Cond=Cond,
    )

In [None]:
##
#Restricted Prior Estimator, Step 1: simulate noSim_RPE parameter sets drawn from the full prior
theta_RPE, x_RPE = simulate_for_sbi(simulator, proposal=prior, num_simulations=noSim_RPE)

##
#Step 2: train the classifier separating valid from invalid simulations
#(presumably the simulator marks negative-eigenvalue tensors as invalid, e.g. NaN - confirm in FreedAnalytical)
restriction_estimator = RestrictionEstimator(prior=prior)
restriction_estimator.append_simulations(theta_RPE, x_RPE)
classifier = restriction_estimator.train()

##
#Step 3: build a prior restricted to the region the classifier predicts as valid
restricted_prior = restriction_estimator.restrict_prior()

In [None]:
##
#Simulate noSim Training Examples from the Restricted Prior
#(parameter regions the restriction estimator flagged as invalid are avoided)
theta, x = simulate_for_sbi(simulator, proposal=restricted_prior, num_simulations=noSim)

In [None]:
##
#Create Different SNR Levels For Training Data (based on b0)
#Copy 0 of every simulation stays noise-free. Copies 1..SNRlevels get Gaussian noise with
#standard deviation b0/SNR, where SNR is drawn uniformly from [SNR[0], SNR[1]) independently
#for every noisy row. The absolute value of the noisy signal is kept. This is a
#single-channel magnitude approximation, not a two-channel Rician model.

##
#Signal Layout
#The first nSig columns of x are the measured volumes, one per entry of b0s. Any further
#columns are left untouched (presumably the T1/T2/B1 conditioning values when Cond=True -
#confirm against the simulator wrapper). ravel() accepts b0s as either (n,) or (1, n).
b0Mask = np.asarray(b0s).ravel() == 1
nSig = b0Mask.size
nSim = x.shape[0]
assert b0Mask.any(), "b0s marks no b0 volumes (b0 = 1); cannot set the noise level"
assert x.shape[1] >= nSig, "x has fewer columns than b0s has entries"

##
#Estimate Mean b0 (nanmean ignores NaN signals from invalid simulations)
b0 = torch.nanmean(x[:, :nSig][:, torch.as_tensor(b0Mask)])
assert torch.isfinite(b0), "Mean b0 signal is not finite; all b0 simulations may be invalid"

##
#Replicate the variable and signal arrays by the number of SNR levels + 1
thetaSNR = theta.repeat(SNRlevels + 1, 1)
xSNR = x.repeat(SNRlevels + 1, 1)

##
#Corrupt every copy after the first (random draw order: noise first, then per-row SNR)
noise = torch.randn((nSim * SNRlevels, nSig))
snrDraws = torch.distributions.uniform.Uniform(SNR[0], SNR[1]).sample([nSim * SNRlevels, 1])
xSNR[nSim:, :nSig] = torch.abs(xSNR[nSim:, :nSig] + noise * b0 / snrDraws)
#The noise matrix is as large as the noisy signal block; free it instead of keeping it in the kernel
del noise

In [None]:
##
#Neural Density Estimator: neural spline flow ("nsf")
neural_posterior = posterior_nn(model="nsf")

##
#Train NPE on the noise-free + noisy training set, then wrap the trained estimator as a posterior
#NOTE(review): rows with NaN/inf signals are presumably dropped by sbi's exclude_invalid_x
#default in append_simulations - confirm for the installed sbi version.
inference = SNPE(prior=prior, density_estimator=neural_posterior)
inference.append_simulations(thetaSNR, xSNR)
density_estimator = inference.train()
posterior = inference.build_posterior()

In [None]:
##
#Save Posterior Object
#Create the output directory first so a missing folder cannot lose the trained posterior
#at the very end of a long run. Only load this pickle back from a trusted location
#(unpickling can execute arbitrary code).
os.makedirs(os.path.dirname(OutputPath) or '.', exist_ok=True)
with open(OutputPath, "wb") as handle:
    pickle.dump(posterior, handle)