# T1-T2 Saturation Recovery CPMG Notebook

#### Setup
Edit the cell below to set the `LAB_USER_NAME` variable to your name, then click **Run->Run All Cells** in the top menu bar of JupyterLab. The data from your experiments will be saved into a folder named after your user name.

In [None]:
# User name used as the first component of the workspace path under which
# experiment data is saved; replace the placeholder before running.
LAB_USER_NAME = 'REPLACE_ME'

In [None]:
import panel as pn
import sys
import os
import yaml
import numpy as np
from bokeh.palettes import Viridis
import asyncio

from matipo import Sequence, SEQUENCE_DIR, GLOBALS_DIR, DATA_DIR, Unit
from matipo.util.decimation import decimate
from matipo.util.fft import fft_reconstruction
from matipo.util.etl import deinterlace
from matipo.util import ilt

from matipo.experiment.base_experiment import BaseExperiment, auto_inputs, PlotInterface
from matipo.experiment.models import PLOT_COLORS, SITickFormatter, SIHoverFormatter
from matipo.experiment.plots import SignalPlot, SpectrumPlot, ImagePlot, Image1DPlot, ComplexDataLinePlot

from loglogmap_plot import LogLogMapPlot
from ilt_SRCPMG import SRCPMG_T1T2_spectrum

# Initialise the Panel extension before any widgets or plots are created.
pn.extension()

# Workspace path relative to DATA_DIR: <LAB_USER_NAME>/SRCPMG_Map.
WORKSPACE = os.path.join(LAB_USER_NAME, 'SRCPMG_Map')
# Absolute directory for this lab's data; created if it does not exist yet.
LAB_DIR = os.path.join(DATA_DIR, WORKSPACE)
os.makedirs(LAB_DIR, exist_ok=True)
print('Data will be saved to', LAB_DIR)

##### SRCPMG Pulse Sequence
(90 pulse | `t_sat_i`) for `t_sat_i` in `t_sat` | `t_rec` | 90 pulse | `t_echo/2` | (180 pulse | `t_echo`) * `n_echo` | `t_end`

Acquisitions are centred between 180 pulses. The saturation module (`t_sat`) is not adjustable via the GUI as it has been optimised for *ilumr*.

##### Parameter Selection
- **n_scans:** The number of times that the entire pulse sequence is repeated and the data averaged together. This should be an even number for optimal phase cycling. 
- **n_rec:** The number of recovery times used in the pulse sequence.
- **t_rec_min, t_rec_max:** The minimum and maximum recovery times used in the pulse sequence. The experiment will calculate `n_rec` logarithmically spaced recovery times between the selected `t_rec_min` and `t_rec_max`.
- **t_echo:** The time from the excitation 90 pulse to the first echo, and the spacing between subsequent echoes.
- **n_echo:** The number of echoes.
- **n_samples:** The number of samples acquired per echo. These samples will be averaged together during data processing to get a single data point per echo. Recommended to use between 16 and 64 for a good tradeoff between the DSP group delay and post-processing time.
- **t_dw:** The sample dwell time. Total sampling time (`n_samples * t_dw`) must be able to fit between the 180 pulses, i.e. `(n_samples * t_dw) + t_180 < t_echo`, otherwise the sampling will overlap with the 180 pulse and cause artifacts. The value of `t_180` can be found in `/home/globals/hardpulse_180.yaml`. 


##### Recommended Default Parameters
n_scans=2, n_rec=10, t_rec_min=0.0002 s, t_rec_max=10 s, t_echo=200 µs, n_echo=20000, n_samples=16, t_dw=8 µs.

In [None]:
class T1_T2_SRCPMG(BaseExperiment):
    """T1-T2 correlation experiment: saturation recovery followed by CPMG.

    Drives the ``SRCPMG.py`` pulse sequence, shows the saturation recovery
    and echo decay curves on every plot update, and computes a T1-T2 map
    with an inverse Laplace transform on the final update.
    """

    def setup(self):
        """Configure the sequence, the three plots and the user inputs."""
        self.enable_partialplot = True
        self.title = "T1-T2 SRCPMG Experiment"
        self.workspace = WORKSPACE
        self.seq = Sequence('SRCPMG.py')
        self.plots = {
            'map': LogLogMapPlot(
                interpolate=10,
                figure_opts=dict(
                    title="T1-T2 Map",
                    x_axis_label="T2 (s)",
                    y_axis_label="T1 (s)",
                )),
            'recovery': ComplexDataLinePlot(
                figure_opts=dict(
                    title="T1 Saturation Recovery",
                    x_axis_label="Time (s)",
                    y_axis_label="Amplitude")),
            'decay': ComplexDataLinePlot(
                figure_opts=dict(
                    title="T2 decay",
                    x_axis_label="Time (s)",
                    y_axis_label="Amplitude"))
        }

        # Default values for the inputs generated from the sequence parameters.
        self.inputs = auto_inputs(self.seq, {
            'n_scans': 2,
            'n_rec': 10,
            't_rec_min': 2e-4*Unit('s'),
            't_rec_max': 2*Unit('s'),
            't_echo': 200*Unit('us'),
            'n_echo': 5000,
            'n_samples': 16,
            't_dw': 4*Unit('us')
        })

    def update_par(self):
        """Set the sequence parameters that are not exposed as inputs.

        The recovery times are derived from the ``t_rec_min``, ``t_rec_max``
        and ``n_rec`` inputs; ``t_end`` and ``t_sat`` are fixed.
        """
        t_rec_max = self.inputs['t_rec_max'].value
        t_rec_min = self.inputs['t_rec_min'].value
        n_rec = self.inputs['n_rec'].value

        self.seq.setpar(
            # don't need to wait for T1 recovery in saturation recovery type
            # experiments, should still wait >>T2* to avoid coherence pathways
            t_end=0.1,
            # pulse spacings should be larger than T2* and decreasing
            # (recommended by literature)
            t_sat=5e-3*np.array([7, 5, 3, 2]),
            # log-spaced recovery times measured in decreasing order
            # (recommended by literature)
            t_rec=np.logspace(np.log10(t_rec_max), np.log10(t_rec_min), n_rec)
        )

    async def update_plots(self, final):
        """Fetch the acquired data and refresh the plots.

        Args:
            final: When true, the inverse Laplace transform is also run and
                the T1-T2 map plot is updated; the results are stored in
                ``self.T1``, ``self.T2`` and ``self.S``.
        """
        await self.seq.fetch_data()

        par = self.seq.par
        n_done = self.progress.value

        # average the samples of each echo to get one point per echo, then
        # arrange the echo amplitudes as one row of n_echo points per
        # recovery time
        echo_means = np.mean(np.reshape(self.seq.data, (-1, par.n_samples)), axis=1)
        y = np.reshape(echo_means, (-1, par.n_echo))

        # Rotate the data so the first two echoes of the first row have zero
        # mean phase. update_par orders t_rec from t_rec_max down to
        # t_rec_min, so row 0 is expected to be the longest recovery time.
        phase = np.angle(np.mean(y[0, :2]))
        y *= np.exp(1j*-phase)

        # make time axes: echo k (1-based) is centred k*t_echo after the
        # excitation pulse, so the axis runs t_echo, 2*t_echo, ..., n_echo*t_echo
        t_T2 = par.t_echo * np.arange(1, par.n_echo + 1)
        t_T1 = par.t_rec

        # recovery curve: mean of the first two echoes of each row up to n_done
        self.plots['recovery'].update(t_T1[:n_done], np.mean(y[:n_done, :2], axis=1))
        # decay curve: echo train of row n_done-1
        self.plots['decay'].update(t_T2, y[n_done-1])
        self.log.debug(f'self.progress.value: {n_done}')

        if final:
            self.log.info('Running Inverse Laplace Transform (~1 min)...')
            # Run Inverse Laplace Transform on a 50 x 50 grid of log-spaced
            # T1 and T2 values spanning 1e-3 to 1e1
            n_T1 = 50
            n_T2 = 50
            alpha = 1  # presumably the regularisation weight of the ILT - TODO confirm
            self.T1 = np.logspace(-3, 1, n_T1)
            self.T2 = np.logspace(-3, 1, n_T2)
            self.S = SRCPMG_T1T2_spectrum(self.T1, self.T2, t_T1, t_T2, y, alpha, tol=1e-5, progress=1000)

            self.plots['map'].update(self.S, self.T2[0], self.T2[-1], self.T1[0], self.T1[-1])
# Create the experiment with the state identifier 'srcpmg'.
SRCPMGExp = T1_T2_SRCPMG(state_id='srcpmg')
# Calling the instance as the last expression of the cell presumably returns
# the experiment GUI for display in the notebook - TODO confirm.
SRCPMGExp()