# FEXI notebook

### Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
from tqdm import tqdm
from scipy.special import erf
import scipy.stats

from dmipy.core.acquisition_scheme import acquisition_scheme_from_bvalues
from dmipy.core.modeling_framework import MultiCompartmentSphericalMeanModel
from dmipy.signal_models import sphere_models, cylinder_models, gaussian_models

from scipy.io import savemat

  from .autonotebook import tqdm as notebook_tqdm


### FEXI simulations
#### Simulated $s$ signal

In [2]:
def axr_sim(adc, sigma, axr, bf, be, tm):
    """Generate the FEXI signal s for a given voxel.

    The apparent diffusion coefficient depends on the mixing time as

        ADC(tm) = adc * (1 - sigma * exp(-axr * tm))

    and the signal is s = exp(-ADC(tm) * be).

    Acquisitions without a filter (bf == 0) at the smallest mixing time are
    treated as the equilibrium reference. Their mixing time is set to inf, so
    the exchange term vanishes (for axr > 0) and ADC(tm) == adc.

    Inputs   - adc:      apparent diffusion coefficient [m2/s] single value
             - sigma:    filter efficiency single value
             - axr:      exchange rate [1/s] single value
             - bf:       filter block b-value [s/m2] nx1 (only used to find
                         the reference acquisitions)
             - be:       encoding block b-value [s/m2] nx1 (units must be
                         the reciprocal of adc's, since adc*be goes into exp)
             - tm:       mixing time [s] nx1; NOT modified by this function

    Output:  - s:        signal (sum of the magnetisations), one value per
                         acquisition, nx1

    Based off Elizabeth's code.
    """
    bf = np.asarray(bf)
    be = np.asarray(be)
    # Work on a float copy: the reference acquisitions get tm = inf below.
    # Writing that into the caller's array would corrupt it for subsequent
    # calls (e.g. when simulating many voxels with one scheme), and it would
    # fail outright for integer arrays.
    tm = np.array(tm, dtype=float)

    tm[(bf == 0) & (tm == tm.min())] = np.inf

    # ADC as a function of mixing time
    adc_tm = adc * (1 - sigma * np.exp(-axr * tm))

    # Signal for each acquisition
    return np.exp(-adc_tm * be)


Suppose we want to calculate s for a random distribution of ADC values.
I think AXR may also vary, but the two are probably linked to each other.

In [3]:
# Model parameters shared by every simulated voxel.
# NOTE(review): a filter efficiency (sigma) is usually expected in [0, 1];
# 2 looks like a placeholder value — confirm.
sigma = 2
axr = 3

# Acquisition scheme with 20 acquisitions.
# be: the first 10 have be = 0, the last 10 have be = 250.
be = np.repeat([0, 250], 10)

# bf and tm: even-indexed entries are all 1e-6; odd-indexed entries carry
# the per-acquisition filter b-value / mixing time (tm = -1 is the smallest
# value and pairs with bf = 0).
bf = np.full(20, 1e-6)
bf[1::2] = [0.090, 0, 1.5, 2, 3, 0, 0.500, 1.5, 2, 3]

tm = np.full(20, 1e-6)
tm[1::2] = [0.090, -1, 1.5, 2, 3, 0.090, 0.500, 1.5, 2, 3]

nvox = 1000

# One random ADC per voxel; axr is held fixed at the value above for now.
adcs = np.random.uniform(200, 300, nvox)

# Simulated signals: one row per voxel, one column per acquisition.
s = np.vstack([axr_sim(adc, sigma, axr, bf, be, tm) for adc in adcs])




### axr_fit

In [None]:
def _axr_fit_signal(params, bf, be, tm):
    """Evaluate the FEXI forward model used by ``axr_fit``.

    s = exp(-ADC(tm) * be), with ADC(tm) = adc * (1 - sigma * exp(-axr * tm)).
    Acquisitions with bf == 0 at the smallest mixing time are the equilibrium
    reference and get ADC(tm) = adc.

    params     : [adc, sigma, axr]
    bf, be, tm : 1-D float arrays of equal length (adc * be must be unitless)
    """
    adc, sigma, axr = params
    is_ref = (bf == 0) & (tm == tm.min())
    # Evaluate the exchange term with a dummy tm on the reference rows, then
    # zero it there. This is equivalent to tm = inf, but stays well defined
    # (no 0 * inf) when axr sits on a bound of 0.
    exchange = np.exp(-axr * np.where(is_ref, 0.0, tm))
    exchange[is_ref] = 0.0
    adc_tm = adc * (1 - sigma * exchange)
    return np.exp(-adc_tm * be)


def axr_fit(bf, be, tm, smeas, init, lb, ub):
    """
    Estimate AXR FEXI model parameters
    (Lasic 2011, MRM)

    Use: adc, sigma, axr = axr_fit(bf, be, tm, smeas, init, lb, ub)

    Fits s = exp(-ADC(tm) * be), with ADC(tm) = adc * (1 - sigma * exp(-axr * tm)),
    to smeas by bounded non-linear least squares. Unfiltered acquisitions
    (bf == 0) at the smallest mixing time are the equilibrium reference, for
    which ADC(tm) = adc.

    Inputs  - bf:       filter block b-value [s/m^2] nx1
            - be:       encoding block b-value [s/m^2] nx1
            - tm:       mixing time [s] nx1
            - smeas:    measured signal (normalised) nx1
            - init:     initial values  [adc, sigma, axr] [m2/s a.u. 1/s]
                        either 3 values or a kx3 array of k starting points;
                        the start reaching the lowest sum of squared errors wins
            - lb:       lower bounds   [adc, sigma, axr] [m2/s a.u. 1/s] 3x1
            - ub:       upper bounds   [adc, sigma, axr] [m2/s a.u. 1/s] 3x1
                        parameters with lb == ub are held fixed at that value

    Output: - adc:      fitted ADC [m2/s] single value
            - sigma:    fitted filter efficiency single value
            - axr:      fitted AXR [1/s] single value

    None of the input arrays are modified.
    Raises ValueError for inconsistent input sizes or lb > ub.

    Author: E Powell, 23/08/23
    """
    from scipy.optimize import least_squares

    bf = np.asarray(bf, dtype=float).ravel()
    be = np.asarray(be, dtype=float).ravel()
    tm = np.asarray(tm, dtype=float).ravel()
    smeas = np.asarray(smeas, dtype=float).ravel()
    starts = np.atleast_2d(np.asarray(init, dtype=float))
    lb = np.asarray(lb, dtype=float).ravel()
    ub = np.asarray(ub, dtype=float).ravel()

    n = smeas.size
    if n == 0 or not (bf.size == be.size == tm.size == n):
        raise ValueError("bf, be, tm and smeas must be non-empty and of equal length")
    if starts.shape[1] != 3 or lb.size != 3 or ub.size != 3:
        raise ValueError("init, lb and ub must hold [adc, sigma, axr]")
    if np.any(lb > ub):
        raise ValueError("lower bounds must not exceed upper bounds")

    # Scale sequence values so the fitted diffusivity is ~1: b-values are
    # scaled by 1e-9 and the ADC (initial values and bounds) by 1e9, which
    # leaves adc * b unchanged. These are new arrays, so the caller's inputs
    # are untouched. Index 0 is the ADC (MATLAB's index 1).
    bf = bf * 1e-9
    be = be * 1e-9
    scale = np.array([1e9, 1.0, 1.0])
    starts = starts * scale
    lb = lb * scale
    ub = ub * scale

    # Parameters with lb == ub are fixed; only the rest are optimised.
    idx_free = np.flatnonzero(lb != ub)
    idx_fixed = np.flatnonzero(lb == ub)
    lb_free = lb[idx_free]
    ub_free = ub[idx_free]

    def residuals(free_params, base_params):
        trial = base_params.copy()
        trial[idx_free] = free_params
        return _axr_fit_signal(trial, bf, be, tm) - smeas

    # Multi-start: fit from every row of init and keep the best. (The MATLAB
    # version used a parallel pool for >= 25 starts; here they run serially.)
    best_params, best_sse = None, np.inf
    for start in starts:
        params = start.copy()
        params[idx_fixed] = lb[idx_fixed]
        if idx_free.size:
            # least_squares rejects infeasible starting points, so clip them.
            x0 = np.clip(params[idx_free], lb_free, ub_free)
            fit = least_squares(residuals, x0, bounds=(lb_free, ub_free),
                                args=(params,))
            params[idx_free] = fit.x
        sse = np.sum((_axr_fit_signal(params, bf, be, tm) - smeas) ** 2)
        if best_params is None or sse < best_sse:
            best_params, best_sse = params, sse

    # Undo the ADC scaling before returning.
    adc, sigma, axr = best_params / scale
    return adc, sigma, axr

### fit_axr_sse

In [None]:
def fit_axr_sse(free_params, all_params, idx_free, bf, tm, be, smeas):
    """
    Objective for the AXR FEXI fit (port in progress of the MATLAB code).

    NOTE(review): no implementation is visible in this chunk — confirm the
    body and the return value once it is written.

    Inputs - free_params:  values of the parameters being fitted, i.e. the
                           entries of [adc sigma axr] selected by idx_free
           - all_params:   array of all parameters
                           [adc sigma axr] 3x1
           - idx_free:     indices of the free parameters within all_params
                           (axr_fit builds free_params as all_params[:, idx_free])
           - bf:           filter block b-values (presumably already scaled by
                           1e-9 as in axr_fit — confirm)
           - tm:           mixing times [s]
           - be:           encoding block b-values (presumably scaled like bf)
           - smeas:        measured signal, presumably NORMALISED to b=0 for
                           each bf, tm combination so that S0 = Sbf = 1 —
                           confirm against the MATLAB source

    Outputs - presumably the sum of squared errors between the model signal
              and smeas (inferred from the name). The MATLAB docstring listed
              sfit: fitted signal [1, ntm x nbval] — confirm which is returned.
    """