In [None]:
!pip install pypulseq==1.3.1.post1 &> /dev/null
!pip install nevergrad &> /dev/null
!pip install ismrmrd
!pip install MRzeroCore &> /dev/null
!wget https://github.com/MRsources/MRzero-Core/raw/main/documentation/playground_mr0/numerical_brain_cropped.mat &> /dev/null

In [None]:
# @title On Google Colab, you need to restart the runtime after executing this cell
!pip install numpy==1.24

(mr00_FLASH_2D_ernstAngle_opt)=
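# Ernst angle optimization

This notebook simulates a 2D FLASH (RF-spoiled gradient echo) sequence with MRzero on a brain phantom. It then searches for the flip angle that maximizes the energy of the reconstructed image. The search is done first with a gradient-free optimizer (nevergrad) and then with a simple line search. For a single tissue, the spoiled-GRE signal $S \propto \sin\alpha \,\frac{1-E_1}{1-\cos\alpha\,E_1}$ with $E_1 = e^{-T_R/T_1}$ peaks at the Ernst angle $\alpha_E = \arccos(E_1)$.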

In [None]:
#@title setup basic sequence params & phantom
# %% S0. SETUP env
import MRzeroCore as mr0
import numpy as np

import pypulseq as pp
import torch
import matplotlib.pyplot as plt

import nevergrad as ng

from IPython.display import clear_output

plt.rcParams['figure.figsize'] = [10, 5]
plt.rcParams['figure.dpi'] = 100 # 200 e.g. is really fine, but slower



# %% S1. SETUP sys

# Choose the scanner limits used by every pypulseq event constructor below.
# NOTE(review): grad_raster_time = 50 * 10e-6 = 500 us, i.e. 50x coarser than
# pypulseq's default 10 us raster; presumably chosen on purpose to keep the
# simulation cheap — confirm before changing.
system = pp.Opts(max_grad=28, grad_unit='mT/m', max_slew=150, slew_unit='T/m/s',
                 rf_ringdown_time=20e-6, rf_dead_time=100e-6, adc_dead_time=20e-6,
                 grad_raster_time=50 * 10e-6)

# Define FOV and resolution (pypulseq works in SI units, i.e. metres)
fov = 200e-3             # in-plane field of view [m]
slice_thickness = 8e-3   # [m]
# The phantom resolution `sz` is set together with the phantom in the next section
# (a previous, immediately overwritten `sz = (32, 32)` assignment was removed).
Nread = 64    # frequency encoding steps/samples
Nphase = 64   # phase encoding steps/samples

# %% S4: SETUP SPIN SYSTEM/object on which the MR sequence is simulated
# (the sequence file itself, gre.seq, is written later by generate_flash_seq)

# In-plane phantom resolution after interpolation; matches Nread x Nphase (64 x 64)
sz = [64, 64]
# (i) load a phantom object from file (downloaded by the wget in the install cell)
obj_p = mr0.VoxelGridPhantom.load_mat('numerical_brain_cropped.mat')
# Resample to sz[0] x sz[1] voxels; the trailing 1 is presumably a single slice in z — confirm
obj_p = obj_p.interpolate(sz[0], sz[1], 1)
# Manipulate loaded data
obj_p.T2dash[:] = 30e-3   # uniform T2' for every voxel (30 ms, assuming seconds)
obj_p.D *= 0   # zero the D map (presumably diffusion) — TODO confirm
obj_p.B0 *= 1    # scale factor for the B0 inhomogeneity; 1 leaves it unchanged, edit to alter it
# Store PD and B0 for comparison (taken from the voxel phantom, before build() replaces obj_p)
PD = obj_p.PD
B0 = obj_p.B0
obj_p.plot()
# Convert Phantom into simulation data; from here on obj_p is the built simulation
# object that the later cells pass on to mr0.compute_graph / mr0.execute_graph
obj_p = obj_p.build()

In [None]:
#@title set up functions for generating, simulating and reconstructing FLASH sequence
def ifft2d(x):
    """Centered inverse 2D FFT over the last two dimensions.

    This is the exact inverse of ``fft2d``. The input centering is undone with
    ``ifftshift`` before the transform, and the result is re-centered with
    ``fftshift`` afterwards. That order is correct for odd as well as even sizes.
    The previous fftshift -> ifft2 -> ifftshift order only inverted ``fft2d`` for
    even sizes. Leading (batch) dimensions are not shifted.

    Args:
        x: complex tensor whose DC component sits at the centre of its last two axes.

    Returns:
        Tensor of the same shape with the centered inverse transform.
    """
    dims = (-2, -1)
    x = torch.fft.ifftshift(x, dim=dims)
    x = torch.fft.ifft2(x, dim=dims)
    x = torch.fft.fftshift(x, dim=dims)
    return x

def fft2d(x):
    """Centered forward 2D FFT over the last two dimensions.

    The DC component of the input is expected at the centre of the last two axes
    and is placed at the centre of the output as well. This is the transform
    used to turn the sorted k-space into an image.
    """
    spatial = (-2, -1)
    uncentered = torch.fft.ifftshift(x, dim=spatial)
    spectrum = torch.fft.fft2(uncentered, dim=spatial)
    return torch.fft.fftshift(spectrum, dim=spatial)

def generate_flash_seq(FA=10, fname='gre.seq', verbose=0):
    """Build a 2D FLASH sequence with pypulseq and write it to ``fname``.

    The sequence acquires ``Nphase`` phase-encoding lines of ``Nread`` samples each,
    using the module-level ``Nread``, ``Nphase``, ``fov``, ``slice_thickness`` and
    ``system``. RF spoiling uses a phase increment that grows by 117 deg per TR,
    and an x-gradient spoiler follows every readout.

    Args:
        FA: flip angle in degrees. Callers in this notebook pass either a scalar or
            a length-1 numpy array (the optimizer parameter).
        fname: path of the ``.seq`` file to write.
        verbose: if > 0, report a successful timing check and plot the sequence.

    Returns:
        dict with key ``'permvec'``, the phase-encoding order. Acquisition ``ii``
        measures phase-encoding line ``permvec[ii]``. ``simu_seq`` uses it to sort
        the k-space lines back.
    """
    # %% S2. DEFINE the sequence
    seq = pp.Sequence()

    # Define rf events (the slice-select gradient returned by make_sinc_pulse is discarded)
    rf1, _, _ = pp.make_sinc_pulse(
        flip_angle=FA * np.pi / 180, duration=1e-3,
        slice_thickness=slice_thickness, apodization=0.5, time_bw_product=4,
        system=system, return_gz=True
    )
    # rf1 = pp.make_block_pulse(flip_angle=90 * np.pi / 180, duration=1e-3, system=system)

    # Define other gradients and ADC events
    gx = pp.make_trapezoid(channel='x', flat_area=Nread / fov, flat_time=10e-3, system=system)
    # The ADC starts once the readout gradient has ramped up to its flat top
    adc = pp.make_adc(num_samples=Nread, duration=10e-3, phase_offset=0 * np.pi / 180,
                      delay=gx.rise_time, system=system)
    # Half the readout area with opposite sign: k_x = 0 is crossed mid-readout
    gx_pre = pp.make_trapezoid(channel='x', area=-gx.area / 2, duration=5e-3, system=system)
    gx_spoil = pp.make_trapezoid(channel='x', area=1.5 * gx.area, duration=2e-3, system=system)

    # RF-spoiling state in degrees: the increment itself grows by rf_spoiling_inc per TR
    rf_phase = 0
    rf_inc = 0
    rf_spoiling_inc = 117

    # ======
    # CONSTRUCT SEQUENCE
    # ======
    # Phase-encoding areas [1/m] for lines -Nphase/2 ... Nphase/2 - 1
    phenc = np.arange(-Nphase // 2, Nphase // 2, 1) / fov
    # linear reordering
    permvec = np.arange(0, Nphase, 1)
    # centric reordering:
    # permvec = sorted(np.arange(len(phenc)), key=lambda x: abs(len(phenc) // 2 - x))
    # random reordering:
    # permvec = np.random.permutation(np.arange(0, Nphase, 1))

    phenc_centr = phenc[permvec]   # phase-encoding area per acquisition, in acquisition order

    for ii, area_y in enumerate(phenc_centr):
        rf1.phase_offset = rf_phase / 180 * np.pi   # set current rf phase
        adc.phase_offset = rf_phase / 180 * np.pi   # ADC follows the RF phase
        rf_inc = (rf_inc + rf_spoiling_inc) % 360.0   # increase increment
        rf_phase = (rf_phase + rf_inc) % 360.0        # increment additional phase

        seq.add_block(rf1)
        seq.add_block(pp.make_delay(0.005))
        gp = pp.make_trapezoid(channel='y', area=area_y, duration=5e-3, system=system)
        seq.add_block(gx_pre, gp)
        seq.add_block(adc, gx)
        # Rewind the phase encoding while spoiling along x
        gp_rewind = pp.make_trapezoid(channel='y', area=-area_y, duration=5e-3, system=system)
        seq.add_block(gx_spoil, gp_rewind)
        if ii < Nphase - 1:
            seq.add_block(pp.make_delay(0.01))

    # %% S3. CHECK, PLOT and WRITE the sequence as .seq
    # Check whether the timing of the sequence is correct
    ok, error_report = seq.check_timing()
    if not ok:
        print('Timing check failed. Error listing follows:')
        for e in error_report:
            print(e)
    elif verbose > 0:
        print('Timing check passed successfully')

    # PLOT sequence
    if verbose > 0:
        mr0.util.pulseq_plot(seq, clear=False, figid=(11, 12))

    # Prepare the sequence output for the scanner
    seq.set_definition('FOV', [fov, fov, slice_thickness])
    seq.set_definition('Name', 'gre')
    seq.write(fname)

    return {'permvec': permvec}

def simu_seq(fname, obj_p, reco_params, noiselevel=1e-4, verbose=0):
    """Simulate a pulseq ``.seq`` file on a phantom with MRzero and reconstruct an image.

    Args:
        fname: path of the ``.seq`` file, e.g. as written by ``generate_flash_seq``.
        obj_p: built MRzero simulation data (output of ``VoxelGridPhantom.build()``).
        reco_params: dict with ``'permvec'``, the phase-encoding order of the acquisition.
        noiselevel: standard deviation of the Gaussian noise added independently to
            the real and imaginary part of every ADC sample.
        verbose: if > 0, plot the k-space trajectory, the sequence with the simulated
            signal, and the raw ADC signal.

    Returns:
        Complex torch tensor of shape (Nread, Nphase) holding the reconstructed image.
    """
    permvec = reco_params['permvec']

    # %% S5:. SIMULATE the .seq file and add acquired signal to ADC plot
    # Read in the sequence
    seq0 = mr0.Sequence.import_file(fname)
    if verbose > 0:
        seq0.plot_kspace_trajectory()
    # Simulate the sequence
    graph = mr0.compute_graph(seq0, obj_p, 200, 1e-3)
    signal = mr0.execute_graph(graph, seq0, obj_p, print_progress=False)

    # PLOT sequence with signal in the ADC subplot
    if verbose > 0:
        plt.close(11); plt.close(12)
        # pulseq_plot needs a pypulseq Sequence. The one built in generate_flash_seq
        # is local to that function, so re-read it from the file here
        # (previously an undefined name `seq` was used, raising NameError).
        seq_pp = pp.Sequence()
        seq_pp.read(fname)
        mr0.util.pulseq_plot(seq_pp, clear=False, signal=signal.numpy())

    # additional noise as simulation is perfect
    signal += noiselevel * np.random.randn(signal.shape[0], 2).view(np.complex128)

    # %% S6: MR IMAGE RECON of signal ::: #####################################
    if verbose > 0:
        plt.figure()
        plt.subplot(411)
        plt.title('ADC signal')
        plt.plot(torch.real(signal), label='real')
        plt.plot(torch.imag(signal), label='imag')
        # this adds ticks at the start of every readout (one every Nread samples)
        major_ticks = np.arange(0, Nphase * Nread, Nread)
        ax = plt.gca()
        ax.set_xticks(major_ticks)
        ax.grid()

    # One readout of Nread samples per acquisition -> (Nread, Nphase) after transposing
    kspace = torch.reshape(signal, (Nphase, Nread)).clone().t()

    # Acquisition ii measured line permvec[ii]; the inverse permutation sorts the columns
    ipermvec = np.argsort(permvec)
    kspace = kspace[:, ipermvec]

    img = fft2d(kspace)

    return img


In [None]:
#@title nevergrad optimization
def calc_loss(FA):
    """Negative image energy for flip angle ``FA`` (degrees), with a live progress plot.

    Generates and simulates the FLASH sequence for ``FA`` and reconstructs the image.
    Returns ``-sum(|img|^2)``, so that minimizing the loss maximizes the signal.
    Side effects: redraws a figure showing the current image and the loss history
    held in the global ``values`` list, and increments the global counter ``iter``.
    ``FA`` is indexed as ``FA[0]``, so it must be a length-1 array.
    """
    global iter
    params = generate_flash_seq(FA=FA, fname='gre.seq')
    img = simu_seq('gre.seq', obj_p, params, noiselevel=0*1e-3)
    magnitudes = img.flatten().abs()
    mag = torch.sum(magnitudes**2)

    clear_output(wait=True)
    plt.figure(figsize=(13, 5))

    # left panel: current reconstruction
    plt.subplot(1, 2, 1)
    plt.imshow(img.abs())
    plt.colorbar()
    plt.title(f'iter {iter}: FA={FA[0]:.2f}, MAG={mag.item():.2f}')

    # right panel: loss history recorded by the optimizer callback
    plt.subplot(1, 2, 2)
    plt.plot(values, '.-')
    plt.xlabel('iteration')
    plt.ylabel('loss')
    plt.show()

    iter = iter + 1
    loss = -mag.item()
    return loss

# def calc_loss(FA):
#   E1 = np.exp(-20e-3/1)
#   FArad = np.deg2rad(FA)
#   S = np.sin(FArad) * (1-E1)/(1-np.cos(FArad)*E1)
#   return -S

def rescale_vars(x, a,b,c,d):
    """Linearly map ``x`` from the interval (a, b) onto the interval (c, d).

    ``a`` maps to ``c`` and ``b`` maps to ``d``. Values outside (a, b) are
    extrapolated, not clipped. Works elementwise on numpy arrays.
    """
    fraction = (x - a) / (b - a)   # relative position of x inside the source range
    return fraction * (d - c) + c

def obj_fun_rescaled(x1):
    """Rescaled loss function, so that the optimizer only ever sees ``normrange``.

    ``x1`` lives in ``normrange``. It is mapped linearly to flip angles in
    ``valrange`` (degrees) before ``calc_loss`` is evaluated.

    Args:
        x1: array-like (numpy array, list or tuple) of normalized parameters.

    Returns:
        The value of ``calc_loss`` at the rescaled flip angle(s).
    """
    # np.asarray accepts any array-like and does not copy existing numpy arrays
    # (replaces the former `type(x1) == list` check, which missed tuples/subclasses)
    x1 = np.asarray(x1)
    x1r = rescale_vars(x1, *normrange, *valrange)
    return calc_loss(x1r)

def print_candidate_and_value(optimizer, candidate, value):
    """nevergrad "tell" callback that records every evaluated candidate and its loss.

    Appends to the module-level lists ``cands`` and ``values``. They are mutated in
    place and never rebound, so no ``global`` declaration is needed (the old one
    also named a non-existent ``xx``). ``optimizer`` is unused but is part of the
    signature the callback is called with.
    """
    cands.append(candidate)
    values.append(value)

iter = 0 # global iteration counter, read and incremented by calc_loss

# number of cost function evaluations ("budget")
# Limited for building docs - should be increased
niter = 10

# this is the range in which the optimizer operates, see https://cma-es.github.io/cmaes_sourcecode_page.html#practical
normrange = (-3,3)

# boundaries in physical units (flip angle in degrees)
valrange = (1e-1,180)

# initial value in physical units (whether it matters depends on the optimizer)
init = np.array([5])

# seed for the optimizer's random state, so that Restart & Run All reproduces the search
ng_seed = 42

# defining optimizable variables ("instrumentation"): one array bounded to normrange
instrum = ng.p.Instrumentation(
    ng.p.Array(init=rescale_vars(init,*valrange,*normrange)).set_bounds(*normrange),
)
# nevergrad optimizers draw their random numbers from the parametrization's random state
instrum.random_state.seed(ng_seed)

cands = [] # to save all candidates during opt
values = [] # to save loss values during opt

# NGOpt is a "meta-optimizer" that chooses an algorithm based on the instrumentation;
# the nevergrad documentation suggests it as a good first choice.
optimizer = ng.optimizers.NGOpt(parametrization=instrum, budget=niter)
# optimizer = ng.optimizers.registry["PSO"](parametrization=instrum, budget=niter) # particle swarm
# optimizer = ng.families.ParametrizedBO()(parametrization=instrum, budget=niter) # Bayesian optimization, this might be good for continuous in dim > 1, dim < ~100 (?)
# optimizer = ng.families.NonObjectOptimizer(method='Powell')(parametrization=instrum, budget=niter) # more traditional grad free things
# optimizer = ng.families.NonObjectOptimizer(method='NLOPT_GN_DIRECT')(parametrization=instrum, budget=niter) # more traditional grad free things
# optimizer = ng.families.ParametrizedCMA()(parametrization=instrum, budget=niter) # only in dim > 1
# optimizer = ng.families.RandomSearchMaker()(parametrization=instrum, budget=niter) # random search as baseline

optimizer.register_callback("tell", print_candidate_and_value) # record every evaluation

recommendation = optimizer.minimize(obj_fun_rescaled) # run opt
# recommendation[0] holds the positional args; [0] is the first (and only) Array parameter
FAopt = rescale_vars(recommendation[0][0].value, *normrange, *valrange)

print("final result:", FAopt)  # opt result in degrees
# Only meta-optimizers such as NGOpt expose the selected sub-optimizer as `_optim`;
# fall back to the optimizer itself when one of the alternatives above is used.
print("used optimizer", getattr(optimizer, '_optim', optimizer))


In [None]:
#@title some details on optimization history
# Extract the optimization history: value of the first (and only) instrumented
# argument of every evaluated candidate, in normalized units
x_explored = [cand[0].value[0] for cand in cands]
# Rescale back to flip angles in degrees
FA_explored = np.array([rescale_vars(x, *normrange, *valrange) for x in x_explored])

fig, axs = plt.subplots(4, 1, figsize=(7, 12))

axs[0].plot(FA_explored, values, '.')
axs[0].set(xlabel='FA [deg]', ylabel='loss')

axs[1].plot(FA_explored, '.-')
axs[1].set(xlabel='evaluation', ylabel='FA [deg]')

axs[2].plot(values, '.-')
axs[2].set(xlabel='evaluation', ylabel='loss')

axs[3].hist(FA_explored, bins=30)
axs[3].set(xlabel='FA [deg]', ylabel='count')

fig.tight_layout()

In [None]:
#@title manual line search
iter = 0  # restart the iteration counter shown in calc_loss' plot titles
FAs = np.linspace(0, 90, 10)  # Reduced precision for building docs
# Evaluate the loss on the flip-angle grid (calc_loss expects a length-1 array)
losses = np.array([calc_loss(np.array([fa])) for fa in FAs])

fig, ax = plt.subplots()
ax.plot(FAs, losses, '.-')
ax.set_xlabel('FA')
ax.set_ylabel('loss');