In [None]:
import scipy.io as sio

import math
import torch 
import torch.autograd.forward_ad as fwad

import matplotlib.pyplot as plt

import numpy as np

### Defining 3D-QALAS sequence parameters for simulation

In [None]:
# Specifying sequence parameters
num_reps = 5         # number of TRs to simulate to reach a steady-state signal
esp      = 5.74e-3   # time between flip angles in the echo trains (echo spacing)
tf       = 128       # number of echoes in each echo train (turbo factor)
ro_gap   = 900e-3    # gap between the acquisitions in each TR
time2rel = 0         # relaxation time from last acquisition to end of TR

b1_val  = torch.tensor([1])   # B1+ for simulation
inv_eff = torch.tensor([1])   # inversion efficiency for simulation
etl     = tf * esp            # length of each acquisition, given turbo factor and echo spacing

# Sequence timings derived from the parameters above.
# Each readout window (M2-M3, M6-M7, M8-M9, M10-M11, M12-M13) lasts one echo-train length,
delT_M2_M3 = delT_M6_M7 = delT_M8_M9 = delT_M10_M11 = delT_M12_M13 = etl
# and the gaps M7-M8, M9-M10, M11-M12 fill the remainder of ro_gap.
delT_M7_M8 = delT_M9_M10 = delT_M11_M12 = ro_gap - etl

delT_M1_M2 = 109.7e-3
delT_M0_M1 = ro_gap - etl - delT_M1_M2
delT_M2_M6 = ro_gap
delT_M4_M5 = 12.8e-3
delT_M5_M6 = 100e-3 - 6.45e-3
# M3-M4 is whatever remains of the M2-M6 span after the readout and the M4-M6 intervals.
delT_M3_M4 = delT_M2_M6 - delT_M2_M3 - delT_M4_M5 - delT_M5_M6

# time between end of the T2-prep pulse and the first acquisition
time_t2_prep_after = torch.tensor([9.7e-3])

### Defining the simulation function

In [None]:
def _relax_t1(Mz, M0, dt, T1):
    """Return Mz after free T1 recovery toward M0 over an interval dt (same time units as T1)."""
    return M0 - (M0 - Mz) * torch.exp(-dt / T1)


def _echo_train(Mz, M0, alpha, start, T1, signals):
    """Play `tf` excitations alpha[start], ..., alpha[start + tf - 1] spaced by `esp`.

    The signal sin(flip) * Mz of every echo is appended to `signals`; the
    longitudinal magnetization left after the last pulse is returned. No
    recovery is applied before the first pulse -- callers do that themselves.
    """
    for q in range(tf):
        if q > 0:
            Mz = _relax_t1(Mz, M0, esp, T1)
        flip = alpha[start + q]
        signals.append(torch.sin(flip) * Mz)
        Mz = torch.cos(flip) * Mz
    return Mz


def simulate(alpha,parameters,num_acqs):
    """Simulate the 3D-QALAS readout signal of one tissue after `num_reps` TRs.

    Relies on the module-level sequence constants (tf, esp, num_reps, b1_val,
    inv_eff, time_t2_prep_after and the delT_* timings).

    Args:
        alpha: flip-angle train in radians; acquisition k (0-based) uses
            alpha[k*tf : (k+1)*tf], so at least num_acqs * tf entries are needed.
        parameters: length-3 tensor ordered (T2, T1, M0 scale). It may be a
            forward-AD dual tensor, in which case the output carries tangents.
        num_acqs: number of acquisitions (1..5) played per TR; the TR is
            truncated after the last one.

    Returns:
        1-D tensor of num_acqs * tf samples from the final repetition,
        multiplied by parameters[2].

    Raises:
        ValueError: if num_acqs is outside 1..5 (larger values previously
            returned trailing zeros, making J^T J singular).
    """
    if not 1 <= num_acqs <= 5:
        raise ValueError('num_acqs must be between 1 and 5, got %r' % (num_acqs,))

    T2, T1, m0_scale = parameters[0], parameters[1], parameters[2]

    # Equilibrium magnetization; the first recovery step promotes Mz to floating point.
    M0 = torch.tensor([1])
    Mz = M0.clone()

    # T2-prep weighting is identical every TR, so it is computed once:
    # the sin^2 part of the B1-scaled 90-degree pulses decays with T2, the cos^2 part with T1.
    t2prep_time = delT_M1_M2 - time_t2_prep_after
    t2prep_factor = (torch.sin(b1_val * torch.pi/2)**2 * torch.exp(-t2prep_time/T2) +
                     torch.cos(b1_val * torch.pi/2)**2 * torch.exp(-t2prep_time/T1))

    # Free-recovery gap played before ACQ3, ACQ4 and ACQ5 respectively.
    later_gaps = (delT_M7_M8, delT_M9_M10, delT_M11_M12)

    signals = []
    for _ in range(num_reps):
        signals = []  # only the final repetition (assumed steady state) is returned

        Mz = _relax_t1(Mz, M0, delT_M0_M1, T1)
        Mz = Mz * t2prep_factor

        # ACQ1 starts time_t2_prep_after after the end of the T2 prep.
        Mz = _relax_t1(Mz, M0, time_t2_prep_after, T1)
        Mz = _echo_train(Mz, M0, alpha, 0, T1, signals)

        for acq in range(1, num_acqs):
            if acq == 1:
                # Recovery, inversion (scaled by inv_eff), recovery before ACQ2.
                # NOTE(review): delT_M4_M5 is never applied as relaxation here, i.e. the
                # inversion interval is treated as instantaneous -- confirm this is intended.
                Mz = _relax_t1(Mz, M0, delT_M3_M4, T1)
                Mz = -Mz * inv_eff
                Mz = _relax_t1(Mz, M0, delT_M5_M6, T1)
            else:
                Mz = _relax_t1(Mz, M0, later_gaps[acq - 2], T1)
            Mz = _echo_train(Mz, M0, alpha, acq * tf, T1, signals)

    return torch.cat(signals) * m0_scale

### Optimizing all flip angles in the sequence

In [None]:
num_acqs = 3
#-number of acquisitions we want to use

#-flip angle optimization parameters
iterations = 100
step_size  = 1e-4

#-representative tissue parameters to compute CRB for optimization
parameters = torch.tensor([[70e-3,700e-3,1],
                           [80e-3,1300e-3,1]])
#                             #t2s t1 m0

#-nparam: number of representative tissues (rows of `parameters`)
#-N: number of parameters estimated per tissue (columns: t2s, t1, m0)
nparam,N = parameters.shape

#-initializing flip angle train with standard 4 degree flip angles
alpha    = torch.ones((tf*num_acqs)) * 4 / 180 * math.pi
alpha.requires_grad = True

#-setting losses and tracking for optimization
all_losses = np.zeros((iterations,1))
#-detached snapshot so the initial train is not tied to any autograd graph
alpha_init = alpha.detach().clone()

#-relative weighting of each representative tissue (one entry per row of `parameters`)
pW = torch.ones(nparam)

#-diagonal weighting matrix per tissue: dividing by the squared true value turns each
# CRB into a relative variance, so t2s, t1 and m0 contribute on a comparable scale
W  = torch.zeros((N,N,nparam))
for pp in range(nparam):
    for nn in range(N):
        W[nn,nn,pp] = 1 / parameters[pp,nn]**2

for ii in range(iterations):
    total_crb = 0
    for pp in range(nparam):
        #-forward-mode AD: one pass per unit tangent yields one Jacobian column.
        # The primal needs no requires_grad -- gradients are only taken w.r.t. alpha.
        primal = parameters[pp,:].clone()
        tangs  = torch.eye(N)

        fwd_jac = []
        with fwad.dual_level():
            for tang in tangs:
                dual_input  = fwad.make_dual(primal,tang)
                dual_output = simulate(alpha,dual_input,num_acqs)
                fwd_jac.append(fwad.unpack_dual(dual_output).tangent)

        #-Jacobian of the signal w.r.t. the N parameters, shape (num_acqs*tf, N)
        fwd_jac = torch.stack(fwd_jac).T
        #-(J^T J)^-1 is the CRB matrix, i.e. the inverse Fisher information assuming
        # i.i.d. unit-variance Gaussian noise (assumption), weighted by W
        crb_mat = W[:,:,pp] @ torch.inverse(fwd_jac.T @ fwd_jac)

        crb     = torch.real(torch.trace(crb_mat)) * pW[pp]
        print('crb %d: %.2f' % (pp+1,crb))
        total_crb = total_crb + crb

    #-loss is the weighted CRB summed over tissues (no additional regularization term)
    loss = total_crb

    print('iteration %d/%d || crb: %.2f || loss: %.2f' % (ii+1,iterations,total_crb,loss))
    all_losses[ii] = loss.item()

    g_al = torch.autograd.grad(loss,alpha)[0]

    #-gradient-descent step; detaching keeps alpha a leaf so the autograd history
    # does not grow by one node every iteration
    alpha = (alpha - step_size * g_al).detach().requires_grad_()

### Visualizing the solution

In [None]:
#-row of `parameters` (representative tissue) whose signals are plotted;
# a separate name so the tissue count `nparam` is not clobbered
tissue_idx = 1

#-baseline: the standard 5-acquisition train with constant 4 degree flip angles
alpha_orig  = torch.ones((tf*5)) * 4 / 180 * math.pi

#-no autograd graph is needed just to plot the signals
with torch.no_grad():
    init_signal = simulate(alpha_orig,parameters[tissue_idx,:],num_acqs=5)
    sol_signal  = simulate(alpha,parameters[tissue_idx,:],num_acqs)

#-all_losses[-1] is the loss recorded at the last iteration (not necessarily the minimum)
print('final crb: %f' % all_losses[-1, 0])

fig, ax = plt.subplots()
ax.plot(all_losses.squeeze(), label='crb')
ax.set(title='CRB loss during optimization', xlabel='iteration', ylabel='weighted CRB')
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(alpha_init.detach().numpy() / math.pi*180, label='initial')
ax.plot(alpha.detach().numpy() / math.pi*180, label='optimized')
ax.set(title='Flip-angle train', xlabel='echo index', ylabel='flip angle (degrees)')
ax.legend()
plt.show()

#-note: the baseline has 5 acquisitions while the optimized train has num_acqs,
# so the two curves can differ in length
fig, ax = plt.subplots()
ax.plot(torch.real(init_signal.squeeze().detach()).numpy(), label='baseline (5 acqs, 4 deg)')
ax.plot(torch.real(sol_signal.squeeze().detach()).numpy(), label='optimized (%d acqs)' % num_acqs)
ax.set(title='Simulated signal', xlabel='echo index', ylabel='signal (a.u.)')
ax.legend()
plt.show()