# T2 Mapping - T2-prepared FLASH

T2 mapping using a T2-preparation (T2-prep) module followed by a FLASH readout. 
Images are acquired for several T2-preparation echo times, and T2 is then estimated voxel-wise from the mono-exponential signal decay across these images. 

### Imports

In [None]:
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import MRzeroCore as mr0
import numpy as np
import torch
from cmap import Colormap
from einops import rearrange
from mrpro.algorithms.reconstruction import DirectReconstruction
from mrpro.data import KData
from mrpro.data import SpatialDimension
from mrpro.data.traj_calculators import KTrajectoryCartesian
from mrpro.operators import DictionaryMatchOp
from mrpro.operators.models import MonoExponentialDecay

from mrseq.scripts.t2_t2prep_flash import main as create_seq
from mrseq.utils import sys_defaults

### Settings
We are going to use a numerical phantom with a matrix size of 64 x 64. Images are acquired for T2-preparation echo times of 0, 20 and 80 ms. 

In [None]:
# Image matrix: entry 0 is used as the readout size, entry 1 as the phase-encoding size.
image_matrix_size = [64, 64]
# T2-preparation echo times in seconds (converted to ms for the plot titles).
t2_prep_echo_times = [0, 0.02, 0.08]

# Scratch directory for the simulated raw data. The TemporaryDirectory object must stay
# referenced, because the directory is removed once it is cleaned up.
tmp = tempfile.TemporaryDirectory()
fname_mrd = Path(tmp.name, 't2.mrd')

### Create the digital phantom

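We create a digital phantom that mimics the T1MES phantom: nine circular vials with different T1 and T2 values, arranged on a 3 x 3 grid inside a rounded-square background. The phantom is passed to [MRzero](https://github.com/MRsources/MRzero-Core) as a voxel-grid phantom, with the B1-field and the coil sensitivity set to be constant everywhere.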

In [None]:
def times_phantom(values, background_value, size=64, circle_radius=0.15, corner_radius=0.15, margin=0.15):
    """Generate 3 x 3 circles to simulate the T1MES phantom.

    Nine circles are laid out on a 3 x 3 grid inside a square with rounded corners. All edges
    are hard (binary masks), and every region holds a constant value.

    Parameters
    ----------
    values
        Exactly nine circle values. They are assigned row by row: the first three fill the
        circles of the first image row (increasing column index), the next three the second row, ...
    background_value
        Value inside the rounded square but outside all circles.
    size
        Number of pixels along each axis; the result is a ``size x size`` image.
    circle_radius
        Circle radius as a fraction of the rounded square's side length.
    corner_radius
        Corner rounding radius as a fraction of the rounded square's side length.
    margin
        Empty border on each side, as a fraction of the image extent.

    Returns
    -------
        A ``(size, size)`` float tensor on the CPU. Pixels outside the rounded square are zero.

    Raises
    ------
    ValueError
        If ``values`` does not contain exactly nine entries.
    """
    if len(values) != 9:
        raise ValueError('You must provide exactly 9 values for the circles.')

    device = 'cpu'
    circle_values = torch.tensor(values, dtype=torch.float32, device=device)
    background = torch.tensor(background_value, dtype=torch.float32, device=device)

    # Normalised coordinates in [0, 1]; with 'ij' indexing the first axis is the image row.
    axis = torch.linspace(0, 1, size, device=device)
    rows, cols = torch.meshgrid(axis, axis, indexing='ij')

    # Rescale so the rounded square spans [0, 3] x [0, 3], i.e. one unit per grid cell.
    extent = (1 - margin) - margin
    u = (cols - margin) / extent * 3.0
    v = (rows - margin) / extent * 3.0

    def corner_region(px, py, cx, cy, half):
        """Select the pixels lying in the outer quadrant of the corner centred at (cx, cy)."""
        if cx < half and cy < half:
            return (px < cx) & (py < cy)
        if cx > half and cy < half:
            return (px > cx) & (py < cy)
        if cx < half and cy > half:
            return (px < cx) & (py > cy)
        return (px > cx) & (py > cy)

    def rounded_square(px, py, side=3.0, r=0.3):
        """Binary float mask: 1 inside [0, side]^2 with corners rounded to radius r, 0 outside."""
        mask = ((px >= 0) & (px <= side) & (py >= 0) & (py <= side)).float()
        half = side / 2
        # In each corner quadrant keep only the pixels within r of the corner centre.
        for cx, cy in ((r, r), (side - r, r), (r, side - r), (side - r, side - r)):
            region = corner_region(px, py, cx, cy, half)
            within = (torch.sqrt((px - cx) ** 2 + (py - cy) ** 2) <= r).float()
            mask[region] = within[region]
        return mask

    bg_mask = rounded_square(u, v, side=3.0, r=corner_radius * 3.0)
    img = background * bg_mask

    # Paint each circle on top of the background (hard edges, constant value per circle).
    circle_r = circle_radius * 3.0
    centres = [(col + 0.5, row + 0.5) for row in range(3) for col in range(3)]
    for idx, (cx, cy) in enumerate(centres):
        inside = (torch.sqrt((u - cx) ** 2 + (v - cy) ** 2) <= circle_r).float()
        img = img * (1 - inside) + circle_values[idx] * inside

    # Clip circles that extend beyond the rounded square.
    return img * bg_mask


# Quick visual check of the phantom geometry: arbitrary circle values on a unit background.
values = [0.1, 0.3, 0.5, 0.2, 0.7, 0.4, 0.9, 0.6, 0.8]

tensor_img = times_phantom(values, background_value=1.0)
print(tensor_img.shape, tensor_img.min().item(), tensor_img.max().item())

fig, ax = plt.subplots()
im = ax.imshow(tensor_img)
fig.colorbar(im, ax=ax)

In [None]:
# Preview of the T1 input map: one T1 value per vial (row-major vial order) plus the background.
# The previous placeholder assignment `t1mes = [1, ..., 9]` was dead code (immediately
# overwritten) and has been removed together with the redundant alias.
t1_values_t1mes = [0.256, 1.49, 0.427, 0.818, 1.384, 1.107, 0.295, 0.557, 0.429]
fig, ax = plt.subplots()
im = ax.imshow(times_phantom(values=t1_values_t1mes, background_value=0.846), vmin=0, vmax=1.5)
fig.colorbar(im, ax=ax, label='T1 (s)')

In [None]:
# Per-vial relaxation times of the T1MES-like phantom (row-major vial order).
t1_values_t1mes = [0.256, 1.49, 0.427, 0.818, 1.384, 1.107, 0.295, 0.557, 0.429]
t2_values_t1mes = [0.172, 0.282, 0.212, 0.054, 0.057, 0.056, 0.05, 0.051, 0.05]

# times_phantom only produces square images. Its size must match the constant maps below,
# which are built from image_matrix_size. Previously it silently used its default size of 64,
# so any other matrix size gave MRzero maps of inconsistent shape.
if image_matrix_size[0] != image_matrix_size[1]:
    raise ValueError(f'times_phantom only supports square matrices, got {image_matrix_size}.')
n_pixels = image_matrix_size[0]

phantom = mr0.VoxelGridPhantom(
    PD=times_phantom(values=[1.0] * 9, background_value=0.2, size=n_pixels)[:, :, None],
    T1=times_phantom(values=t1_values_t1mes, background_value=0.846, size=n_pixels)[:, :, None],
    T2=times_phantom(values=t2_values_t1mes, background_value=0.141, size=n_pixels)[:, :, None],
    T2dash=torch.ones([*image_matrix_size, 1]) * 10,
    D=torch.zeros([*image_matrix_size, 1]),
    # NOTE(review): B0 is a constant 1 everywhere (not 0); confirm this off-resonance is intended.
    B0=torch.ones([*image_matrix_size, 1]),
    B1=torch.ones([1, *image_matrix_size, 1]),
    coil_sens=torch.ones([1, *image_matrix_size, 1]),
    size=(0.200, 0.200, 0.008),
)

### Create the T2-prepared FLASH sequence

To create the T2-prepared FLASH sequence, we use the previously imported [t2_t2prep_flash script](../src/mrseq/scripts/t2_t2prep_flash.py).


In [None]:
# The in-plane FOV of the sequence is taken from the phantom's first size entry.
fov_xy = float(phantom.size.numpy()[0])
sequence, fname_seq = create_seq(
    system=sys_defaults,
    t2_prep_echo_times=t2_prep_echo_times,
    fov_xy=fov_xy,
    n_readout=image_matrix_size[0],
    acceleration=1,
    cardiac_trigger_delay=1.0,
    test_report=False,
    timing_check=False,
)

### Simulate the sequence
Now, we pass the sequence and the phantom to the MRzero simulation and save the simulated signal as an (ISMR)MRD file.

In [None]:
# Load the .seq file written by the sequence script and simulate it on the phantom.
seq_file = fname_seq.with_suffix('.seq')
mr0_sequence = mr0.Sequence.import_file(str(seq_file))
signal, ktraj_adc = mr0.util.simulate(mr0_sequence, phantom, accuracy=1e-5)
# Write the simulated signal to an MRD file (sig_to_mrd also receives the sequence object).
mr0.sig_to_mrd(fname_mrd, signal, sequence)

### Reconstruct the images at different echo times

We use [MRpro](https://github.com/PTB-MR/MRpro) for the image reconstruction.

In [None]:
# Read the simulated raw data with a Cartesian trajectory.
kdata = KData.from_file(fname_mrd, trajectory=KTrajectoryCartesian())
n_read = image_matrix_size[0]
n_phase = image_matrix_size[1]
# The encoded readout is twice the image matrix along x (2x oversampling); the
# reconstruction matrix is set to the nominal image size (presumably cropping the oversampling).
kdata.header.encoding_matrix = SpatialDimension(z=1, y=n_phase, x=2 * n_read)
kdata.header.recon_matrix = SpatialDimension(z=1, y=n_phase, x=n_read)
recon = DirectReconstruction(kdata, csm=None)
idata = recon(kdata)

We can now plot the images at the different T2-preparation echo times.

In [None]:
# Magnitude images, one per T2-preparation echo time.
# Assumes all other axes (coils, z) are singleton, so squeeze() leaves (n_echoes, y, x).
idat = idata.data.abs().numpy().squeeze()
# squeeze() also drops the echo axis when only one echo time is simulated; restore it.
idat = idat.reshape(-1, *idat.shape[-2:])
n_images = idat.shape[0]

# squeeze=False keeps `ax` two-dimensional even for a single image, so indexing always works.
fig, ax = plt.subplots(1, n_images, figsize=(4 * n_images, 4), squeeze=False)
for cax, image, te in zip(ax[0], idat, t2_prep_echo_times, strict=True):
    cax.imshow(image, cmap='gray')
    # ':g' avoids the truncation of int() for non-integer millisecond values.
    cax.set_title(f'TE = {te * 1000:g} ms')
    cax.set_xticks([])
    cax.set_yticks([])

### Estimate the T2 maps
We use a dictionary matching approach to estimate the T2 map. Afterwards, we compare it to the input T2 values and check that the relative error is below 2 %.

In [None]:
# Dictionary of mono-exponential decays sampled at the T2-preparation echo times, for 1000
# T2 values between 0.01 s and 0.8 s. Parameter 0 (M0) is handled by DictionaryMatchOp as the
# scaling parameter.
dictionary = DictionaryMatchOp(MonoExponentialDecay(decay_time=t2_prep_echo_times), index_of_scaling_parameter=0)
t2_grid = torch.linspace(0.01, 0.8, 1000)[None, :]
dictionary.append(torch.tensor(1.0), t2_grid)
# Match the images of the first coil and the first slice.
m0_match, t2_match = dictionary(idata.data[:, 0, 0])

# Flip and transpose the input T2 map, then shift it by one pixel, to bring it into the
# orientation of the reconstructed images.
t2_input = np.roll(rearrange(phantom.T2.numpy().squeeze()[::-1, ::-1], 'x y -> y x'), shift=(1, 1), axis=(0, 1))
# Only compare pixels where the phantom has a non-zero T2.
obj_mask = (t2_input > 0).astype(t2_input.dtype)
t2_measured = t2_match.numpy().squeeze() * obj_mask

navia = Colormap('navia').to_mpl()
panels = [
    (t2_input, 0, 0.3, navia, 'Input T2 (s)'),
    (t2_measured, 0, 0.3, navia, 'Measured T2 (s)'),
    (t2_measured - t2_input, -0.3, 0.3, 'bwr', 'Difference T2 (s)'),
]
fig, ax = plt.subplots(1, 3, figsize=(15, 3))
for cax, (image, lower, upper, cmap, label) in zip(ax, panels, strict=True):
    im = cax.imshow(image, vmin=lower, vmax=upper, cmap=cmap)
    fig.colorbar(im, ax=cax, label=label)
    cax.set_xticks([])
    cax.set_yticks([])

# L1 error relative to the L1 norm of the input map; the notebook fails if it exceeds 2 %.
relative_error = np.sum(np.abs(t2_input - t2_measured)) / np.sum(np.abs(t2_input))
print(f'Relative error {relative_error}')
assert relative_error < 0.02