# FEXI notebook

### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
from tqdm import tqdm
from scipy.special import erf
import scipy.stats

from dmipy.core.acquisition_scheme import acquisition_scheme_from_bvalues
from dmipy.core.modeling_framework import MultiCompartmentSphericalMeanModel
from dmipy.signal_models import sphere_models, cylinder_models, gaussian_models

from scipy.io import savemat

  from .autonotebook import tqdm as notebook_tqdm


### FEXI simulations
#### Simulated $s$ signal axr_sim()

In [3]:
def axr_sim(adc, sigma, axr, bf, be, tm):
    """Simulate the normalised FEXI signal for a single voxel.

    Model (Lasic 2011, MRM):

        ADC(tm) = adc * (1 - sigma * exp(-axr * tm))
        s       = exp(-ADC(tm) * be)

    Acquisitions with no filter (``bf == 0``) at the shortest mixing time are
    treated as reference measurements. Their filter term is set to zero (the
    ``tm -> inf`` limit), so ``ADC(tm) = adc`` for those points.

    Based on Elizabeth's code (E Powell).

    Parameters
    ----------
    adc : float
        Apparent diffusion coefficient [m2/s]; the value ADC(tm) approaches as
        tm -> inf.
    sigma : float
        Filter efficiency [a.u.].
    axr : float
        Exchange rate [1/s].
    bf : array_like, shape (N,)
        Filter block b-values [s/m2].
    be : array_like, shape (N,)
        Encoding block b-values [s/m2].
    tm : array_like, shape (N,)
        Mixing times [s]. The caller's array is NOT modified.

    Returns
    -------
    s : ndarray, shape (N,)
        Simulated normalised signal, one value per acquisition.
    """
    bf = np.asarray(bf)
    be = np.asarray(be)
    # Never write into tm. A previous version stored np.inf into the caller's
    # array, which silently altered it for later cells and raised
    # OverflowError for integer arrays.
    tm = np.asarray(tm, dtype=float)

    # Reference (unfiltered) acquisitions: no filter at the shortest mixing time.
    is_reference = (bf == 0) & (tm == tm.min())

    # Filter term exp(-axr*tm). It is zero for reference points (the tm -> inf
    # limit). Computing it explicitly avoids 0*inf = nan when axr == 0.
    filter_decay = np.where(is_reference, 0.0, np.exp(-axr * tm))

    # ADC as a function of mixing time.
    adc_tm = adc * (1 - sigma * filter_decay)

    # Signal attenuation from the encoding block.
    return np.exp(-adc_tm * be)


### Estimate $ADC$, $\sigma$ & $AXR$
#### axr_fit()

In [None]:
#==========================================================================
# for a given signal estimate the variables that produced it. 
# Estimate adc, sigma and axr
#  (Lasic 2011, MRM)
#
# Use: adc, sigma, axr = axr_fit(bf, be, tm, smeas, init, lb, ub)
#
# Inputs    - bf:       filter block b-value [s/m^2] Nx1
#           - be:       encoding block b-value [s/m^2] Nx1 
#           - tm:       mixing time [s] Nx1 
#           - smeas:    measured signal (normalised) Nx1
#           - init:     initial values  [adc, sigma, axr] [m2/s a.u. 1/s] 3x1
#           - lb:       lower bounds   [adc, sigma, axr] [m2/s a.u. 1/s] 3x1
#           - ub:       upper bounds   [adc, sigma, axr] [m2/s a.u. 1/s] 3x1
#
# Output: 	- adc:      fitted ADC [m2/s] single value
#           - sigma:	fitted filter efficiency single value 
#           - axr:      fitted AXR [1/s] single value
#
# Author: E Powell, 23/08/23
#
#==========================================================================

### Input variables from Elizabeth

In [4]:
# Example FEXI simulation and fit, written for Gabe.
# E Powell, 24/11/2023

# Acquisition protocol: one entry per acquisition.
bf = np.array([0, 0, 250, 250, 250, 250, 250, 250]) * 1e6     # filter b-values [s/m2]
be = np.array([0, 250, 0, 250, 0, 250, 0, 250]) * 1e6         # encoding b-values [s/m2]
tm = np.array([20, 20, 20, 20, 200, 200, 400, 400]) * 1e-3    # mixing times [s]

# Ground-truth parameters used for the simulation.
sim_adc = 1e-9                             # ADC, simulated [m2/s]
sim_sig = 0.2                              # sigma, simulated [a.u.]
sim_axr = 3                                # AXR, simulated [s-1]

# Simulate signals. Pass a copy of tm: axr_sim has overwritten the reference
# mixing times with np.inf in place, which would otherwise leak into axr_fit
# below.
s = axr_sim(sim_adc, sim_sig, sim_axr, bf, be, tm.copy())

# Fit the model to the simulated signals and estimate the parameters.
init = np.array([1.1e-9, .15, 3.5])        # initial guess  [adc, sigma, axr]
lb = np.array([.1e-9, 0, .1])              # lower bounds   [adc, sigma, axr]
ub = np.array([3.5e-9, 1, 20])             # upper bounds   [adc, sigma, axr]

fit_adc, fit_sig, fit_axr = axr_fit(bf, be, tm, s, init, lb, ub)

In [None]:
# Print and compare the simulated (ground truth) and fitted parameters.
# Order is [adc, sigma, axr]. ADC is multiplied by 1e9, i.e. shown in units
# of 1e-9 m2/s (= um^2/ms), so all three values are on a comparable scale.
param_scale = np.array([1e9, 1, 1])
print("simulated:", np.array([sim_adc, sim_sig, sim_axr]) * param_scale)
print("fitted:   ", np.array([fit_adc, fit_sig, fit_axr]) * param_scale)

### My old method

Suppose we want to calculate $s$ for a random distribution of ADC values across voxels.
AXR may also vary between voxels, but the two are probably linked.

In [3]:

# Simulate one signal vector per voxel for a population of random ADC values.
# NOTE(review): these values look inconsistent with the units used above. TODO confirm.
#   - sigma = 2 makes 1 - sigma*exp(-axr*tm) negative at short tm, so ADC(tm) < 0
#     and the signal overflows to inf wherever be > 0.
#   - adcs of 200-300 with be = 250 give exp(-5e4) or smaller, i.e. s == 0.
sigma = 2
axr = 3

be = np.array([np.repeat(0, 10), np.repeat(250, 10)]).flatten()
bf = np.array([1e-6, 0.090, 1e-6, 0, 1e-6, 1.5, 1e-6, 2, 1e-6, 3, 1e-6, 0, 1e-6, 0.500, 1e-6, 1.5, 1e-6, 2, 1e-6, 3])
# tm[3] = -1 makes the unfiltered point (bf[3] == 0) the shortest mixing time,
# presumably so that axr_sim treats it as the reference acquisition.
tm = np.array([1e-6, 0.090, 1e-6, -1, 1e-6, 1.5, 1e-6, 2, 1e-6, 3, 1e-6, 0.090, 1e-6, 0.500, 1e-6, 1.5, 1e-6, 2, 1e-6, 3])
nvox = 1000

rng = np.random.default_rng(0)               # fixed seed so the simulation is reproducible
adcs = rng.uniform(200, 300, nvox)
# axrs = rng.uniform(1, 100, nvox)           # AXR could also be varied per voxel

# Pass a fresh copy of tm for every voxel so axr_sim cannot alter the shared
# protocol between iterations.
s = np.array([axr_sim(adc, sigma, axr, bf, be, tm.copy()) for adc in adcs])


