<a href="https://colab.research.google.com/github/MRsources/MRzero-Core/blob/main/documentation/playground_mr0/pulseq_zero_DESC_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title Install dependencies

!pip install git+https://gitlab.cs.fau.de/mrzero/pypulseq_rfshim &> /dev/null
!pip install pulseqzero
!pip install ismrmrd
!pip install MRzeroCore --no-deps &> /dev/null
!pip install pydisseqt

(pulseq_DESC)=
# DESC with pulseq-zero

In [None]:
# @title imports and definitions

import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle
from copy import deepcopy

import MRzeroCore as mr0
import pulseqzero
pp = pulseqzero.pp_impl

res = 42  # matrix size: the phantom is interpolated to res x res and used as the sequence base_resolution

In [None]:
# @title Load a pTx brain phantom
!wget https://github.com/MRsources/MRzero-Core/raw/d499ec3d545455f39425740c26e76f3da1a03b15/documentation/playground_mr0/phantom.pickle  &> /dev/null

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only load this phantom from the trusted, commit-pinned URL downloaded above.
with open("phantom.pickle", "rb") as f:
    phantom = pickle.load(f)

# Resample the phantom to the res x res simulation grid
# (third dimension 1 — presumably a single slice; confirm in MRzeroCore docs).
phantom = phantom.interpolate(res, res, 1)

obj = phantom.build()
# Keep a single receive channel with unit sensitivity.
obj.coil_sens = torch.ones_like(obj.coil_sens[:1, :])

# Reference object without transmit inhomogeneity: a single B1 channel of all ones.
obj_nob1 = deepcopy(obj)
obj_nob1.B1 = torch.ones_like(obj.B1[:1, :])

In [None]:
# @title Define the sequence with pulseq

import pulseqzero
pp = pulseqzero.pp_impl

def tse_sequence(fov=200e-3,
                 slice_thickness=8e-3,
                 base_resolution=42,
                 TE_ms=5,
                 TI_s=None,
                 Ex_FA=90,
                 Ref_FA=180,
                 shim=None,
                 r_spoil=2,
                 PE_grad_on=True,
                 RO_grad_on=True):
    """
    Build a single-shot turbo spin echo (TSE) sequence with pulseq-zero.

    One slice-selective excitation is followed by ``base_resolution``
    slice-selective refocusing pulses. Each refocusing period contains one
    phase-encoded readout. Phase-encoding lines are acquired in centric order
    (0, -1, 1, -2, 2, ...). An optional non-selective 180° inversion with
    delay ``TI_s`` can be played before the excitation (FLAIR).

    Args:
        fov (float): Field of view in meters (default: 200e-3).
        slice_thickness (float): Slice thickness in meters (default: 8e-3).
        base_resolution (int): Number of readout samples and of
            phase-encoding lines / echoes (default: 42).
        TE_ms (float): Requested echo time (= echo spacing) in milliseconds
            (default: 5). Values below the minimum achievable spacing are
            clamped to it. The result is rounded to 0.1 ms.
        TI_s (float or None): Inversion time in seconds. ``None`` disables the
            inversion preparation (default: None).
        Ex_FA (float): Excitation flip angle in degrees (default: 90).
        Ref_FA (float or indexable): Refocusing flip angle in degrees
            (default: 180). If indexable, ``Ref_FA[ii]`` is used for the echo
            with phase-encoding index ``ii``. ``ii`` may be negative, so
            Python negative indexing applies.
        shim (tensor or None): Per-pulse pTx shim passed to
            ``make_sinc_pulse`` as ``shim_array``. ``shim[0]`` is used for the
            excitation and ``shim[i + 1]`` for the i-th refocusing pulse, so at
            least ``base_resolution + 1`` entries are needed (default: None).
        r_spoil (float): Spoiler factor. Each readout is surrounded by readout
            gradients of area ``r_spoil * (readout area / 2)`` (default: 2).
        PE_grad_on (bool): Enable/disable phase encoding gradients (default: True).
        RO_grad_on (bool): Enable/disable readout gradients (default: True).

    Returns:
        tuple: ``(seq, encoding)``.
            ``seq`` is the sequence object.
            ``encoding`` is the list of phase-encoding indices in acquisition
            order, which is needed to sort k-space during reconstruction.
    """
    # Define resolution
    Nread = base_resolution  # frequency encoding steps/samples
    Nphase = base_resolution  # phase encoding steps / number of echoes
    TE = TE_ms * 1e-3

    # %% S1. SETUP sys
    system = pp.Opts(
        max_grad=28, grad_unit='mT/m', max_slew=150, slew_unit='T/m/s',
        rf_ringdown_time=20e-6, rf_dead_time=100e-6,
        adc_dead_time=20e-6, grad_raster_time=10e-6)

    seq = pp.Sequence(system)

    # Define rf events.
    # The excitation uses shim[0] (if given). It carries a 90° phase offset
    # relative to the refocusing pulses, which use the default phase.
    shim_array = None
    if shim is not None:
        shim_array = shim[0, :, :]
    rf1, gz1, gzr1 = pp.make_sinc_pulse(
        flip_angle=Ex_FA * np.pi / 180, shim_array=shim_array, phase_offset=90 * np.pi / 180, duration=1e-3,
        slice_thickness=slice_thickness, apodization=0.5, time_bw_product=4,
        system=system, return_gz=True)

    # Nominal refocusing pulse. Only its slice-select gradient duration is used
    # (for minTE2 below); the pulses actually played are built in the loop.
    rf2, gz2, _ = pp.make_sinc_pulse(
        flip_angle=180*np.pi / 180, duration=1e-3,
        slice_thickness=slice_thickness, apodization=0.5, time_bw_product=4,
        system=system, return_gz=True)

    dwell=50e-6*2

    G_flag=(int(RO_grad_on),int(PE_grad_on))  # gradient flag (read,PE), if (0,0) all gradients are 0, for (1,0) PE is off

    # Define other gradients and ADC events
    gx = pp.make_trapezoid(channel='x', rise_time = 0.5*dwell, flat_area=Nread / fov*G_flag[0], flat_time=Nread*dwell, system=system)
    adc = pp.make_adc(num_samples=Nread, duration=Nread*dwell, phase_offset=90 * np.pi / 180, delay=0*gx.rise_time, system=system)
    # Gradient after excitation: (1 + r_spoil) * half the readout area
    gx_pre0 = pp.make_trapezoid(channel='x', area=+((1.0 + r_spoil) * gx.area / 2) , duration=1.5e-3, system=system)
    # Played before and after every readout: r_spoil * half the readout area
    gx_prewinder = pp.make_trapezoid(channel='x', area=+(r_spoil * gx.area / 2), duration=1e-3, system=system)
    # Zero-area placeholder, only used for its duration in the minTE2 calculation
    gp = pp.make_trapezoid(channel='y', area=0 / fov, duration=1e-3, system=system)
    rf_prep = pp.make_block_pulse(flip_angle=180 * np.pi / 180, duration=1e-3, system=system)

    # FLAIR: optional non-selective 180° inversion, TI delay, then a gradient block
    if TI_s is not None:
        seq.add_block(rf_prep)
        seq.add_block(pp.make_delay(TI_s))
        seq.add_block(gx_pre0)

    seq.add_block(rf1,gz1)
    seq.add_block(gx_pre0,gzr1)

    # the minimal TE is given by one full period from ref pulse to ref pulse, thus gz2+gx+2*gp
    minTE2=(pp.calc_duration(gz2) +pp.calc_duration(gx) + 2*pp.calc_duration(gp))/2

    # Round half the minimal echo spacing to 0.1 ms (10e-5 s); presumably this
    # keeps the derived delays on the block raster.
    minTE2=np.round(minTE2/10e-5)*10e-5

    # to realize longer TE, we introduce a TE delay that is added before and after the encoding period
    TEd=np.round(max(0, (TE/2-minTE2))/10e-5)*10e-5  # rounded to 0.1 ms; TE below the minimum is clamped

    if TEd==0:
        print('echo time set to minTE [ms]', 2*(minTE2 +TEd)*1000)
    else:
        print(' TE [ms]', 2*(minTE2 +TEd)*1000)

    # last timing step is to add TE/2 also between excitation and first ref pulse
    # from pulse top to pulse top we have already played out one full rf and gx_pre0, thus we subtract these from TE/2
    seq.add_block(pp.make_delay((minTE2 +TEd ) - pp.calc_duration(gz1)-pp.calc_duration(gx_pre0)))

    encoding = []

    # Centric phase-encoding order: ii = 0, -1, 1, -2, 2, ...
    for i in range(Nphase):
        ii = -(i+1)//2 if i % 2 == 1 else i // 2
        gp  = pp.make_trapezoid(channel='y', area=+ii / fov*G_flag[1], duration=1e-3, system=system)
        gp_ = pp.make_trapezoid(channel='y', area=-ii / fov*G_flag[1], duration=1e-3, system=system)  # rewinds gp
        encoding.append(ii)

        # Ref_FA may be a per-line array (indexed by ii) or a scalar. Scalars
        # raise TypeError (Python numbers) or IndexError (0-d tensors/arrays)
        # and fall back to the scalar path; any other error propagates.
        try:
            flip_angle = Ref_FA[ii] * torch.pi / 180
        except (TypeError, IndexError):
            flip_angle = Ref_FA * torch.pi / 180
        shim_array = None
        if shim is not None:
            shim_array = shim[i + 1, :, :]
        rf2, gz2, _ = pp.make_sinc_pulse(flip_angle=flip_angle, shim_array=shim_array,
                                         duration=1e-3,slice_thickness=slice_thickness, apodization=0.5, time_bw_product=4,system=system, return_gz=True)

        seq.add_block(rf2,gz2)
        seq.add_block(pp.make_delay(TEd)) # TE delay
        seq.add_block(gx_prewinder, gp)
        seq.add_block(adc, gx)
        seq.add_block(gx_prewinder, gp_)
        seq.add_block(pp.make_delay(TEd)) # TE delay

    # %% S2. CHECK, PLOT and WRITE the sequence  as .seq
    # Check whether the timing of the sequence is correct
    ok, error_report = seq.check_timing()
    if ok:
        print('Timing check passed successfully')
    else:
        print('Timing check failed. Error listing follows:')
        for e in error_report:
            print(e)

    return seq, encoding

In [None]:
# @title Define helper functions

def CP_8ch_shim(num_pulses=None):
    """
    Circularly-polarized (CP) mode shim for an 8-channel transmit array.

    Every channel has amplitude 0.35, and the phase steps by -45° from one
    channel to the next. The same per-channel shim is repeated for every RF
    pulse.

    Args:
        num_pulses (int or None): Number of RF pulses to shim. Defaults to
            ``res + 1``: one excitation plus ``res`` refocusing pulses, as
            expected by ``tse_sequence(base_resolution=res, shim=...)``.

    Returns:
        torch.Tensor of shape ``(num_pulses, 8, 2)``. ``[..., 0]`` is the
        amplitude and ``[..., 1]`` is the phase in radians.
    """
    if num_pulses is None:
        num_pulses = res + 1

    channel_shim = torch.tensor([[0.35,    0 * np.pi / 180],
                                 [0.35,  -45 * np.pi / 180],
                                 [0.35,  -90 * np.pi / 180],
                                 [0.35, -135 * np.pi / 180],
                                 [0.35, -180 * np.pi / 180],
                                 [0.35,  135 * np.pi / 180],
                                 [0.35,   90 * np.pi / 180],
                                 [0.35,   45 * np.pi / 180]])

    # repeat() returns a fresh leaf tensor, so callers may set requires_grad on it.
    return channel_shim.repeat(num_pulses, 1, 1)


def reconstruction(signal, encoding, Nread, Nphase):
    # reconstruct image
    kspace = torch.reshape((signal), (Nread, Nphase)).clone().t()
    encoding = np.stack(encoding)
    ipermvec = np.argsort(encoding)
    kspace=kspace[:,ipermvec]

    # fftshift FFT fftshift
    spectrum = torch.fft.fftshift(kspace)
    space = torch.fft.fft2(spectrum)
    space = torch.fft.ifftshift(space)

    return space


def sim(shim):
    """
    Simulate the TSE sequence with the given pTx shim on ``obj`` and reconstruct it.

    Uses the notebook globals ``res`` (matrix size) and ``obj`` (the phantom
    including its B1 maps). The refocusing flip angle is fixed at 120°.

    Args:
        shim: Per-pulse shim tensor of shape ``(res + 1, 8, 2)``, as built by
            ``CP_8ch_shim``. Gradients flow back to it through the simulation.

    Returns:
        Complex reconstructed image of shape ``(res, res)``.
    """
    with pulseqzero.mr0_mode():
        pulseq_seq, enc = tse_sequence(base_resolution=res, Ref_FA=120, shim=shim)
    mr0_seq = pulseq_seq.to_mr0()

    # Positional simulation settings; the reference (target) simulation uses
    # the same values. NOTE(review): presumably state-count / magnitude
    # thresholds of compute_graph / execute_graph — confirm in MRzeroCore docs.
    sim_graph = mr0.compute_graph(mr0_seq, obj, 1000, 1e-4)
    raw_signal = mr0.execute_graph(sim_graph, mr0_seq, obj, 1e-2, 1e-3)
    return reconstruction(raw_signal, enc, res, res)

In [None]:
# @title Generate target by simulating without B1 inhomogeneities or pTx
# Reference simulation: unshimmed sequence on the object with ideal (all-ones) B1.
with pulseqzero.mr0_mode():
    seq_noshim, encoding = tse_sequence(base_resolution=res, Ref_FA=120)
seq_noshim = seq_noshim.to_mr0()
graph_nob1 = mr0.compute_graph(seq_noshim, obj_nob1, 1000, 1e-4)
signal_target = mr0.execute_graph(graph_nob1, seq_noshim, obj_nob1, 1e-2, 1e-3)
target = reconstruction(signal_target, encoding, res, res)

# Same sequence on the pTx phantom, driven with the CP-mode shim.
reco = sim(CP_8ch_shim())
vmax = max(reco.abs().max(), target.abs().max())

# Columns: CP mode | perfect B1. Rows: magnitude (shared scale) | phase.
plt.figure()
for col, (panel_title, image) in enumerate([("CP mode", reco), ("perfect B1", target)]):
    plt.subplot(2, 2, col + 1)
    plt.title(panel_title)
    plt.imshow(image.abs().T, origin="lower", vmin=0, vmax=vmax)
    plt.axis("off")
    plt.subplot(2, 2, col + 3)
    plt.imshow(image.angle().T, origin="lower", vmin=-np.pi, vmax=np.pi, cmap="twilight")
    plt.axis("off")
plt.show()

In [None]:
# @title DESC: pTx with full sequence shimming and reco-based loss

def plot_optim(loss_hist, shim_hist, reco):
    """
    Draw a 5-panel summary of the shim optimization in a new figure.

    The panels are:
    1. the loss history;
    2. the per-pulse mean channel amplitude of every recorded shim, coloured
       from early to late iteration;
    3. the latest shim as complex points per channel, coloured by pulse index;
    4. the current reconstruction;
    5. the notebook-global ``target`` image.

    Args:
        loss_hist: Sequence of scalar losses, one per recorded iteration.
        shim_hist: Non-empty list of shim tensors of shape
            ``(num_pulses, num_channels, 2)``, holding the amplitude
            (``[..., 0]``) and the phase in radians (``[..., 1]``).
        reco: Current reconstructed image (complex tensor, detached).

    The figure is left as the current pyplot figure; the caller decides
    whether to save, show or close it.
    """
    cmap = plt.get_cmap("viridis")
    latest = shim_hist[-1]
    # Derived from the data instead of the global `res` and a literal 8.
    num_pulses, num_channels = latest.shape[0], latest.shape[1]

    plt.figure(figsize=(18, 3), dpi=80)
    plt.subplot(151)
    plt.title("loss history")
    plt.plot(loss_hist)
    plt.grid()

    plt.subplot(152)
    plt.title("shim history")
    for i, shim_i in enumerate(shim_hist):
        # Mean amplitude over channels, scaled by sqrt(num_channels)
        plt.plot(shim_i[:, :, 0].mean(1) * num_channels**0.5, c=cmap(i / len(shim_hist)))
    plt.grid()

    plt.subplot(153)
    plt.title("shim")
    # Complex per-channel weights of the latest shim, computed once (loop-invariant).
    rot = latest[:, :, 0] * torch.exp(1j * latest[:, :, 1])
    color_scale = max(num_pulses - 1, 1)  # guards a single-pulse shim
    for p in range(num_pulses):
        plt.scatter(rot[p, :].real, rot[p, :].imag, color=cmap(p / color_scale))
    plt.xlim(-0.7, 0.7)
    plt.ylim(-0.7, 0.7)
    plt.grid()

    plt.subplot(154)
    plt.title("reco")
    plt.imshow(reco.abs().T, origin="lower", vmin=0, vmax=target.abs().max())
    plt.axis("off")

    plt.subplot(155)
    plt.title("target")
    plt.imshow(target.abs().T, origin="lower", vmin=0)
    plt.axis("off")


# Start from the CP-mode shim and optimize all per-pulse shims jointly.
shim = CP_8ch_shim()
shim.requires_grad = True
optimizer = torch.optim.Adam(params=[shim], lr=0.05)
iter_count = 100
shim_hist = []
loss_hist = []

for i in range(iter_count):
    optimizer.zero_grad()
    reco = sim(shim)
    # RMSE between the complex reconstruction and the B1-free target
    diff = (reco - target).abs()
    loss = (diff**2).mean().sqrt()

    loss_hist.append(float(loss))
    shim_hist.append(shim.detach().clone())

    # Every iteration is saved as a GIF frame; only a few are shown inline.
    plot_optim(loss_hist, shim_hist, reco.detach())
    plt.savefig(f"iter-{i}.png")
    if i % 10 == 4:
        plt.show()
    else:
        plt.close()

    print(f"{i+1:03}/{iter_count:03}: loss={float(loss)}")
    loss.backward()
    optimizer.step()

# The last optimizer.step() changed `shim` after `reco` and `loss` were
# computed. Simulate once more so the final frame and loss reflect the final
# shim, instead of repeating the previous iteration's values.
with torch.no_grad():
    reco = sim(shim)
    loss = ((reco - target).abs() ** 2).mean().sqrt()
loss_hist.append(float(loss))
shim_hist.append(shim.detach().clone())
plot_optim(loss_hist, shim_hist, reco.detach())
plt.savefig(f"iter-{iter_count}.png")
plt.show()

In [None]:
# @title Generate a .gif of the optimization process

import os
import imageio

# One frame per saved snapshot: iter-0.png ... iter-{iter_count}.png
# (iter_count loop iterations plus the final evaluation).
images = [imageio.imread(f"iter-{i}.png") for i in range(iter_count + 1)]

imageio.mimsave('optim.gif', images)