In [None]:
# Install the packages used by this notebook. The output of the first two
# installs is discarded; the ismrmrd install prints its log.
!pip install pulseqzero &> /dev/null
!pip install MRzeroCore &> /dev/null
!pip install ismrmrd

# Download the numerical brain phantom data file from the MRzero-Core repository
# (output discarded). NOTE(review): presumably this is the file read by
# mr0.util.load_phantom further down - confirm.
!wget https://github.com/MRsources/MRzero-Core/raw/main/documentation/playground_mr0/numerical_brain_cropped.mat &> /dev/null

In [None]:
# @title On Google Colab, you need to restart the runtime after executing this cell
# Pin numpy to 1.24. The reason for the pin is not stated in this notebook
# (NOTE(review): presumably a compatibility requirement of the packages
# installed above - confirm). Per the cell title, the runtime has to be
# restarted afterwards on Google Colab.
!pip install numpy==1.24

(FLASH_FAopt_PSF)=
# Pulseq-zero Demo
Pulseq-zero combines Pulseq and MR-zero and allows you to optimize a Pulseq sequence directly.

For example, here we optimize the variable flip angles of a single-shot FLASH sequence to improve its point spread function (PSF) and reach the sharpness of a multi-shot FLASH.


### First,
we need to define a FLASH sequence as a function whose arguments are the parameters we wish to optimize:

# 2D FLASH - flipangle optimization


In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib import cm

import MRzeroCore as mr0
import pulseqzero
pp = pulseqzero.pp_impl

In [None]:
# @title FLASH sequence definition

def flash(fov=200e-3,
          slice_thickness=8e-3,
          Nread=64,   # frequency encoding steps/samples
          Nphase=64,  # phase encoding steps/samples
          dwell_time=10e-5,
          shots = 1,
          fa_readout=7*torch.pi/180, # readout flip angle(s) in radians
          Trec=15, # delay played before each shot
          system=None
          ):
  """Build a centric-reordered 2D FLASH sequence with quadratic RF phase cycling.

  The Nphase phase-encoding lines are distributed over `shots` interleaved
  shots. Every shot starts with a delay of `Trec` (skipped if `Trec <= 0`)
  and then plays Nphase/shots repetitions, each consisting of
  RF pulse -> readout prephaser + phase encoding -> ADC + readout gradient
  -> readout spoiler + phase rewinder.
  The slice-selection gradients returned by `pp.make_sinc_pulse` are
  discarded; only the RF event is added to the sequence.

  Args:
    fov: field of view; sets the readout flat area (Nread / fov) and the
      phase-encoding areas (line index / fov).
    slice_thickness: forwarded to `pp.make_sinc_pulse`.
    Nread: number of ADC samples per readout.
    Nphase: number of phase-encoding lines; must be divisible by `shots`.
    dwell_time: duration of one ADC sample; readout lasts Nread * dwell_time.
    shots: number of shots the phase-encoding lines are split into.
    fa_readout: readout flip angle in radians. Either a scalar (used for all
      repetitions) or a sequence / 1D torch.Tensor with Nphase/shots entries,
      entry `ii` being the flip angle of repetition `ii` of every shot. Pass a
      tensor with `requires_grad=True` to optimize the flip angles.
    Trec: delay passed to `pp.make_delay` at the start of each shot.
    system: `pp.Opts` system limits; defaults to 80 mT/m and 200 T/m/s.

  Returns:
    Tuple `(seq, encoding)`: the `pp.Sequence` and a list with the
    phase-encoding index (gradient area * fov) of every acquired line in
    acquisition order.

  Raises:
    ValueError: if `shots` is not positive, if `Nphase` is not divisible by
      `shots`, or if `fa_readout` does not have Nphase/shots entries.
  """

  # =====
  # default system settings
  # =====

  if system is None:
    system = pp.Opts(max_grad=80,
                     grad_unit='mT/m',
                     max_slew=200,
                     slew_unit='T/m/s')

  # =====
  # Validate the shot layout
  # =====

  if shots < 1:
    raise ValueError("Parameter shots must be a positive integer")
  # A remainder would silently drop phase-encoding lines and leave k-space
  # incomplete, so reject it explicitly.
  if Nphase % shots != 0:
    raise ValueError("Parameter Nphase must be divisible by shots")

  # number of repetitions (= required flip angles) per shot
  fa_readout_size = Nphase // shots

  # =====
  # Handle parameter fa_readout: for optimization purposes needs to be torch.tensor
  # =====

  # A single flip angle (Python or numpy scalar, int or float) is expanded to
  # a constant flip angle train.
  if np.isscalar(fa_readout):
    fa_readout = torch.full((fa_readout_size,), float(fa_readout))

  # Tensors, arrays and lists alike must provide one flip angle per repetition.
  if len(fa_readout) != fa_readout_size:
    raise ValueError("Parameter fa_readout is of impropper size")

  # =====
  # Define gradients and ADC events (RF events are created per repetition below)
  # =====

  gx = pp.make_trapezoid(channel='x', flat_area=Nread / fov, flat_time=Nread*dwell_time, system=system)
  adc = pp.make_adc(num_samples=Nread, duration=Nread*dwell_time, phase_offset=0 * np.pi/180, delay=gx.rise_time, system=system)
  gx_pre = pp.make_trapezoid(channel='x', area=-gx.area / 2, duration=1e-3, system=system)
  gx_spoil = pp.make_trapezoid(channel='x', area=1.5 * gx.area, duration=2e-3, system=system)

  # RF phase cycling state, in degrees
  rf_phase = 0
  rf_inc = 0
  rf_spoiling_inc = 84

  # ======
  # CONSTRUCT SEQUENCE
  # ======

  seq = pp.Sequence()

  ## centric reordering: sort line indices by their distance to the k-space center
  phenc = np.arange(-Nphase // 2, Nphase // 2, 1) / fov
  permvec = sorted(np.arange(len(phenc)), key=lambda x: abs(len(phenc) // 2 - x))

  phenc_centr = phenc[permvec]
  encoding = []

  for shot in range(shots):

    if Trec > 0:
      seq.add_block(pp.make_delay(Trec))

    for ii in range(fa_readout_size):

      # a new pulse per repetition so that every repetition gets its own flip angle
      rf1, _, _ = pp.make_sinc_pulse(flip_angle=fa_readout[ii],
                                     duration=1e-3,
                                     slice_thickness=slice_thickness,
                                     apodization=0.5,
                                     time_bw_product=4,
                                     system=system,
                                     return_gz=True
                                     )

      rf1.phase_offset = rf_phase / 180 * np.pi   # set current rf phase
      adc.phase_offset = rf_phase / 180 * np.pi   # follow with ADC
      # the increment itself grows by rf_spoiling_inc every repetition,
      # which makes the accumulated phase quadratic in the repetition index
      rf_inc = (rf_inc + rf_spoiling_inc) % 360.0
      rf_phase = (rf_phase + rf_inc) % 360.0
      seq.add_block(rf1)

      # shots are interleaved: repetition ii of shot `shot` takes this line
      line_area = phenc_centr[ii*shots+shot]
      encoding.append(line_area*fov)

      gp = pp.make_trapezoid(channel='y', area=line_area, duration=1e-3, system=system)
      seq.add_block(gx_pre, gp)
      seq.add_block(adc, gx)
      # rewind the phase encoding while spoiling along the readout direction
      gp = pp.make_trapezoid(channel='y', area=-line_area, duration=1e-3, system=system)
      seq.add_block(gx_spoil, gp)

  return seq, encoding

In [None]:
# @title image reconstruction via FFT

def reconstruction(signal, encoding, Nread, Nphase):
  """Reconstruct a complex image from the acquired signal with a 2D FFT.

  Args:
    signal: tensor with Nphase * Nread samples, stored line after line
      (all Nread samples of the first acquired line, then the second, ...).
    encoding: phase-encoding value of every acquired line in acquisition
      order (length Nphase); used to sort the lines into ascending order.
    Nread: number of samples per readout line.
    Nphase: number of acquired phase-encoding lines.

  Returns:
    Complex tensor of shape (Nread, Nphase) holding the image.
  """
  # The samples arrive line by line, so the slow index is the phase-encoding
  # line and the fast index the readout sample. Reshaping to (Nphase, Nread)
  # and transposing yields k-space as (Nread, Nphase). Reshaping to
  # (Nread, Nphase) instead only works when Nread == Nphase.
  kspace = torch.reshape(signal, (Nphase, Nread)).clone().t()

  # undo the acquisition reordering along the phase-encoding axis
  encoding = np.stack(encoding)
  ipermvec = np.argsort(encoding)
  kspace = kspace[:, ipermvec]

  # fftshift FFT ifftshift
  spectrum = torch.fft.fftshift(kspace)
  space = torch.fft.fft2(spectrum)
  space = torch.fft.ifftshift(space)

  return space

In [None]:
# @title ploting functions for optimization steps and results

def plot_results_images(target, init, result=None, finished=False, colorbars=False):
  """Show target, initial and (optionally) current/optimized image side by side.

  All panels share one gray color scale spanning the minimum and maximum of
  the displayed images.

  Args:
    target: image the optimizer tries to reach.
    init: image simulated before the optimization.
    result: optional third image (current optimizer step or final result).
    finished: if True the third panel is titled "optimizer result",
      otherwise "optimizer step".
    colorbars: if True a colorbar is drawn next to every panel.
  """
  # images that end up on screen, in display order
  shown = [target, init]
  if result is not None:
    shown.append(result)

  # common color scale
  vmin = min(im.min() for im in shown)
  vmax = max(im.max() for im in shown)

  # (subplot position, title, image) for every panel
  panels = [
      (131, "optimizer target", target),
      (132, "initial image", init),
  ]
  if result is not None:
    third_title = "optimizer result" if finished else "optimizer step"
    panels.append((133, third_title, result))

  for position, heading, picture in panels:
    plt.subplot(position)
    plt.title(heading)
    plt.axis('off')
    mr0.util.imshow(picture, vmin=vmin, vmax=vmax, cmap=cm.gray)
    if colorbars:
      plt.colorbar(cmap='gray')

  plt.show()

def plot_optimizer_history(loss_hist, param_hist, finished=False):
    """Plot the loss curve and the optimized parameters next to each other.

    Args:
      loss_hist: loss value of every iteration so far; plotted relative to
        the first entry.
      param_hist: parameter vector (flip angle train) of every iteration.
      finished: if True the iteration with the lowest loss is marked and all
        recorded parameter vectors are drawn; otherwise only the last two.
    """
    # --- left panel: loss relative to the first iteration
    plt.subplot(121)
    plt.title("Loss")
    plt.xlabel("itertation")
    relative_loss = [entry / loss_hist[0] for entry in loss_hist]
    plt.plot(relative_loss, label="loss")
    if finished:
      best = np.argmin(loss_hist)
      plt.plot([best], [relative_loss[best]], "rx", label="optimum")
      plt.legend()
    plt.grid()

    # --- right panel: parameter vectors, one curve per iteration
    plt.subplot(122)
    plt.xlabel("repetition")
    plt.ylabel("FA")
    plt.title("Optim. param")
    # while the optimization is running only the previous and the current
    # configuration are drawn
    drawn = param_hist if finished else param_hist[-2:]
    plt.plot(np.array(drawn).T)
    plt.gca().yaxis.tick_right()
    plt.grid()

    plt.show()

def plot_optimized_flipangles(fa_optimized):
  """Plot a flip angle train (one marker per repetition) and show the figure.

  Args:
    fa_optimized: flip angle per repetition; the y axis is labelled in degrees.
  """
  plt.plot(fa_optimized, "o--")

  # decorate the axes the curve was drawn into
  axes = plt.gca()
  axes.set_xlabel("repetition")
  axes.set_ylabel("FA [deg]")
  axes.set_title("Optimized readout flip angle train")
  axes.yaxis.tick_right()
  axes.grid()

  plt.show()

In [None]:
# @title setup spin system

# Load the phantom that every simulation below runs on.
# NOTE(review): [96,96] is presumably the in-plane size the phantom is
# resampled to - confirm against mr0.util.load_phantom.
obj_p = mr0.util.load_phantom([96,96])


In [None]:
# @title Generate optimization target

# The target is a multi-shot FLASH: with shots == Nphase every shot acquires a
# single phase-encoding line, preceded by the Trec delay of flash().
shots = 64
fa_readout = 8.0*torch.pi/180  # constant 8 degree flip angle, in radians
Nread = 64
Nphase = 64
with pulseqzero.mr0_mode():
  # build the sequence and convert it to an MR-zero sequence
  seq, encoding = flash(fa_readout=fa_readout, shots=shots, Nread=Nread, Nphase=Nphase)

  seq0 = seq.to_mr0()
  signal,_ = mr0.util.simulate(seq0,obj_p,accuracy=1e-4)

# reconstruct image
space = reconstruction(signal, encoding, Nread, Nphase)

# plot result (magnitude image)
plt.subplot(121)
plt.title('FFT-magnitude')
mr0.util.imshow(np.abs(space.numpy()), cmap=cm.gray)
plt.colorbar()

# store the magnitude image as target for the optimization
target = torch.abs(space)

In [None]:
# @title Simulate the inital image before optimization

# Single-shot acquisition: all Nphase lines are acquired in one shot, so one
# flip angle per phase-encoding line is needed.
shots = 1
# start from a constant 8 degree train; requires_grad makes it optimizable
fa_readout = torch.full((Nphase,), 8.0*torch.pi/180, requires_grad=True)

# simulate initial image
with pulseqzero.mr0_mode():
    seq, encoding = flash(fa_readout=fa_readout, shots=shots, Nread=Nread, Nphase=Nphase)

    seq0 = seq.to_mr0()
    signal,_ = mr0.util.simulate(seq0,obj_p,accuracy=1e-4)

# reconstruct image
space = reconstruction(signal, encoding, Nread, Nphase)
init = torch.abs(space) # magnitude image before any optimizer step

# compare the single-shot starting point with the multi-shot target
plot_results_images(target, init)

In [None]:
# @title Perform optimization

# initialize optimizer: Adam on the flip angle train defined in the previous cell
iterations = 100
params = [{"params": fa_readout, "lr": 0.01}]  # adjust learning rate as needed
optimizer = torch.optim.Adam(params)

# one entry per iteration: loss value and flip angle train (in degrees)
loss_hist = []
FA_readout_hist = []

# optimization loop
for i in range(iterations):

    optimizer.zero_grad()

    # ====
    # simulate
    # ====

    with pulseqzero.mr0_mode():
      # rebuild the sequence so that it uses the current flip angles
      seq, encoding = flash(fa_readout=fa_readout, shots=shots, Nread=Nread, Nphase=Nphase)

      seq0 = seq.to_mr0()

      # the graph is only recomputed every 5th iteration and reused in between
      if i%5 == 0:
        graph = mr0.compute_graph(seq0, obj_p.build(), 100000, 1e-4)

      signal = mr0.execute_graph(graph, seq0, obj_p.build(), 1e-4, 1e-4)

    # reconstruct image
    space = reconstruction(signal, encoding, Nread, Nphase)
    image = torch.abs(space) # current optimizer step image


    # ====
    # loss computation
    # ====

    # mean squared error between the current and the target magnitude image
    loss = ((image - target)**2).mean()
    print(f"{i+1} / {iterations}: loss={loss.item()}, fa_readout={fa_readout.detach().numpy() * 180/torch.pi}")

    # recorded before the optimizer step, so entry i of both histories
    # belongs to the same flip angle configuration
    loss_hist.append(loss.item())
    FA_readout_hist.append(fa_readout.detach().numpy().copy()*180/torch.pi)

    # ====
    # perform optimizer step
    # ====

    loss.backward()
    optimizer.step()

    # plot images
    plot_results_images(target, init, image)

    # optimization timeline
    plot_optimizer_history(loss_hist, FA_readout_hist)

In [None]:
# @title Evaluate optimization result

# simulate optimizer result: flip angle configuration of the iteration with the lowest loss
best_iteration = int(np.argmin(loss_hist))
fa_best_deg = FA_readout_hist[best_iteration]  # history is stored in degrees

with pulseqzero.mr0_mode():
  # flash() expects radians; pass a tensor like during the optimization.
  # Nread must be the readout size (it was Nphase before, which only worked
  # because both are 64 in this notebook).
  seq, encoding = flash(fa_readout=torch.as_tensor(fa_best_deg*torch.pi/180), shots=1, Nread=Nread, Nphase=Nphase)

  seq0 = seq.to_mr0()
  graph = mr0.compute_graph(seq0, obj_p.build(), 100000, 1e-8)
  signal = mr0.execute_graph(graph, seq0, obj_p.build(), 1e-8, 1e-8)  # high accuracy to check if more states are necessary

  # reconstruct image
  space = reconstruction(signal, encoding, Nread, Nphase)
  result = torch.abs(space) # image obtained with the best flip angle configuration

# ====
# plot results
# ====

# images
plot_results_images(target, init, result, finished=True)

# optimization timeline
plot_optimizer_history(loss_hist, FA_readout_hist, finished=True)

# optimized flip angle configuration
plot_optimized_flipangles(fa_best_deg)