In [None]:
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import scipy.io as sio
import numpy as np
import time 
import matplotlib.pyplot as plt

import cfl
import bart 
import math

from python_utils import models
from python_utils import signalprocessing as sig
from python_utils import simulator

### helper functions

In [None]:
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.

    The last ``shift`` entries along ``dim`` are moved to the front (a cyclic
    shift towards higher indices), matching ``torch.roll(x, shift, dim)``.

    Args:
        x: A PyTorch tensor.
        shift: Amount to roll. May be negative or larger than the dimension;
            it is reduced modulo the dimension length.
        dim: Which dimension to roll.
    Returns:
        Rolled version of x. When the effective shift is zero, or the
        dimension is empty, x itself is returned (not a copy).
    """
    size = x.size(dim)
    if size == 0:
        # Nothing to roll; this also avoids a ZeroDivisionError in the modulo below.
        return x

    shift = shift % size
    if shift == 0:
        return x

    left = x.narrow(dim, 0, size - shift)
    right = x.narrow(dim, size - shift, shift)

    return torch.cat((right, left), dim=dim)

def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Multi-axis cyclic shift, the PyTorch analogue of ``np.roll``.

    ``shift[k]`` is applied along ``dim[k]``, one axis after another.

    Args:
        x: A PyTorch tensor.
        shift: Amount to roll along each entry of ``dim``.
        dim: Dimensions to roll; must be as long as ``shift``.
    Returns:
        Rolled version of x.
    Raises:
        ValueError: if ``shift`` and ``dim`` differ in length.
    """
    if len(dim) != len(shift):
        raise ValueError("len(shift) must match len(dim)")

    rolled = x
    for k in range(len(shift)):
        rolled = roll_one_dim(rolled, shift[k], dim[k])
    return rolled


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of ``np.fft.fftshift``.

    Every selected dimension of length n is rolled by ``n // 2``, which moves
    the zero-frequency entry to the centre.

    Args:
        x: A PyTorch tensor.
        dim: Dimensions to shift; ``None`` selects every dimension of ``x``.
    Returns:
        fftshifted version of x.
    """
    if dim is None:
        dim = [axis for axis in range(x.dim())]

    shifts = [x.shape[axis] // 2 for axis in dim]
    return roll(x, shifts, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of ``np.fft.ifftshift``.

    Every selected dimension of length n is rolled by ``(n + 1) // 2``. Since
    ``n // 2 + (n + 1) // 2 == n``, this undoes ``fftshift`` for odd lengths too.

    Args:
        x: A PyTorch tensor.
        dim: Dimensions to shift; ``None`` selects every dimension of ``x``.
    Returns:
        ifftshifted version of x.
    """
    if dim is None:
        dim = [axis for axis in range(x.dim())]

    shifts = [(x.shape[axis] + 1) // 2 for axis in dim]
    return roll(x, shifts, dim)

def tonpy(x):
    """Turn a real tensor whose last axis is (real, imag) into a complex NumPy array on the CPU."""
    complex_view = torch.view_as_complex(x)
    return complex_view.cpu().numpy()

### setting model parameters and loading the dataset

In [None]:
# Use the first GPU when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# These three values select the checkpoint file
# 'data/model_type_<model_type>_latent_<latent_num>_epochs_<epoch>.pt' loaded below.
model_type = 3      # model variant id -- TODO confirm meaning in python_utils.models
latent_num = 1      # number of latent variables per pixel
epoch = 200000      # epoch count encoded in the checkpoint file name

In [None]:
# Acquisition inputs, read with BART's cfl reader from the data/ folder.
mask = cfl.readcfl('data/mask')      # sampling mask, applied to k-space via R
kspace = cfl.readcfl('data/kspace')  # unpacked below as [M, N, C, E]
coils = cfl.readcfl('data/coils')    # coil sensitivity maps

### defining operators and preparing k-space

In [None]:
# Grid size M x N, C receive coils and E time points (echoes).
[M,N,C,E] = kspace.shape

# Complex NumPy data is stored as real tensors with a trailing (real, imag) axis.
coils_torch   = torch.stack((torch.tensor(np.real(coils),dtype=torch.float),
                             torch.tensor(np.imag(coils),dtype=torch.float)),
                            dim=-1).to(device).reshape(M,N,C,1,2)
coils_torch.requires_grad = False

# The mask is treated as real-valued. .cfl files are complex, so take the real part
# explicitly instead of casting a complex array to float (which warns on recent
# PyTorch and is unsupported on old versions). The same mask is applied to both
# the real and the imaginary channel.
mask_torch    = torch.stack((torch.tensor(np.real(mask),dtype=torch.float),
                             torch.tensor(np.real(mask),dtype=torch.float)),
                            dim=-1).to(device)
mask_torch.requires_grad = False

In [None]:
def sfor(x,CC):
    '''
    Forward coil-sensitivity operator: multiply an image series by every coil map.

    Complex numbers are stored as a trailing (real, imag) axis of size 2. The
    result is the ordinary complex product, broadcast over all other axes.

    inputs:
        x  (M x N x 1 x T x 2)    - time series of images
        CC (M x N x C x 1 x 2)    - coil sensitivity functions
    output
        out (M x N x C x T x 2)   - coil-weighted images
    '''
    a_re, a_im = x[..., 0], x[..., 1]
    c_re, c_im = CC[..., 0], CC[..., 1]

    prod_re = a_re * c_re - a_im * c_im
    prod_im = a_re * c_im + a_im * c_re
    return torch.stack((prod_re, prod_im), -1)

def ffor(data):
    '''
    Performs a centred, orthonormal 2D forward Fourier transform over the
    first two (spatial) axes.

    inputs:
        data (M x N x C x T x 2) - complex data as a trailing (real, imag) axis
    outputs:
        out  (M x N x C x T x 2)

    The legacy callable ``torch.fft(input, signal_ndim, normalized)`` was removed
    in PyTorch 1.8, where ``torch.fft`` became a module. The transform therefore
    dispatches on whichever API is available. ``normalized=True`` (legacy) and
    ``norm="ortho"`` (module API) both apply the 1/sqrt(M*N) scaling.
    '''
    shifted = fftshift(data, dim=[0, 1])

    if callable(torch.fft):
        # Legacy API: transforms the `signal_ndim` axes just before the complex
        # axis, so move M, N to the end and back again afterwards.
        spectrum = torch.fft(shifted.permute(2, 3, 0, 1, 4),
                             signal_ndim=2, normalized=True).permute(2, 3, 0, 1, 4)
    else:
        # torch.fft module API: operate on a complex view and transform axes 0, 1.
        # view_as_complex needs a contiguous last axis of size 2.
        spectrum = torch.view_as_real(
            torch.fft.fftn(torch.view_as_complex(shifted.contiguous()),
                           dim=(0, 1), norm="ortho"))

    return ifftshift(spectrum, dim=[0, 1])

def R(data,mask):
    '''
    Undersampling operator: element-wise product of ``data`` and ``mask``.

    inputs:
        data  (M x N x C x T x 2)     - input to be masked
        mask                          - sampling mask; must broadcast against data
                                        (mask_torch in this notebook carries a
                                        trailing axis of size 2)
    output:
        data * mask
    '''
    masked = data * mask
    return masked

def Dfor(model, latent=None, scale_real=None, scale_imag=None):
    '''
    Decoder forward model: map per-pixel latent variables to a complex image
    time series.

    The (real) decoder output for every pixel is multiplied by a per-pixel
    complex factor ``scale_real + i*scale_imag`` and reshaped to image layout.

    input
        model       - neural network model exposing ``decode``
        latent      - (M*N x latent_num) latent variables; defaults to the
                      notebook-level ``x``
        scale_real  - (M*N x 1) real part of the per-pixel factor; defaults to
                      the notebook-level ``xr``
        scale_imag  - (M*N x 1) imaginary part of the per-pixel factor;
                      defaults to the notebook-level ``xi``
    output
        out (M x N x 1 x E x 2)

    The defaults are resolved at call time, so existing ``Dfor(model)`` calls
    keep using the current optimisation variables.
    '''
    if latent is None:
        latent = x
    if scale_real is None:
        scale_real = xr
    if scale_imag is None:
        scale_imag = xi

    # presumably (M*N x E): one decoded signal evolution per pixel -- TODO confirm
    out = model.decode(latent)
    return torch.stack((out * scale_real, out * scale_imag), -1).reshape(M, N, 1, E, 2)

# Measured k-space as a real tensor with a trailing (real, imag) axis.
kspace_torch = torch.stack(
    (
        torch.tensor(np.real(kspace), dtype=torch.float),
        torch.tensor(np.imag(kspace), dtype=torch.float),
    ),
    dim=-1,
).to(device)
kspace_torch.requires_grad = False

# Apply the sampling mask so the data term only sees sampled locations.
kspace_torch = R(kspace_torch, mask_torch)

### decoder reconstruction

In [None]:
model_path = 'data/model_type_%d_latent_%d_epochs_%d.pt' % (model_type, latent_num, epoch)

# NOTE(review): torch.load unpickles a whole nn.Module, which can execute arbitrary
# code, so only load trusted checkpoints. Recent PyTorch releases default to
# weights_only=True, which rejects such full-module pickles.
model = torch.load(model_path).to(device)

# Freeze the network; only the reconstruction variables are optimised below.
for weight in model.parameters():
    weight.requires_grad_(False)

In [None]:
# Fit the decoder model to the undersampled k-space by plain gradient descent on
# the per-pixel latent code (x) and the per-pixel complex factor (xr + i*xi).
lam         = 0        # regularisation weight; not used by the loss below
iterations  = 2000
lr          = 1e7

saveiter    = 50       # snapshot the estimates every `saveiter` iterations
dispiter    = 10       # print progress every `dispiter` iterations

# All optimisation variables start at zero on the compute device.
x  = torch.zeros(M*N, latent_num, device=device, requires_grad=True)
xr = torch.zeros(M*N, 1, device=device, requires_grad=True)
xi = torch.zeros(M*N, 1, device=device, requires_grad=True)

criterion   = nn.MSELoss()

all_T2   = []
all_real = []
all_imag = []

tic = time.perf_counter()

for it in range(iterations):

    loss = criterion(kspace_torch, R(ffor(sfor(Dfor(model), coils_torch)), mask_torch))

    # A single backward pass yields the gradients of all three variables.
    gx, gr, gi = torch.autograd.grad(loss, (x, xr, xi))

    # Gradient step. Starting from detached copies keeps every iterate a fresh
    # leaf, so the autograd history does not grow with the iteration count.
    x  = (x.detach()  - gx * lr).requires_grad_()
    xr = (xr.detach() - gr * lr).requires_grad_()
    xi = (xi.detach() - gi * lr).requires_grad_()

    running_loss = loss.item()

    if it % dispiter == 0:
        toc = time.perf_counter()
        print('iteration %d / %d, current loss: %.12f / elapsed time: %.2f (s)' % (it,iterations,running_loss, toc-tic))
        tic = time.perf_counter()

    if it % saveiter == 0:
        all_T2.append(x.detach().reshape(M,N,latent_num).squeeze().cpu().numpy())
        all_real.append(xr.detach().reshape(M,N,1).squeeze().cpu().numpy())
        all_imag.append(xi.detach().reshape(M,N,1).squeeze().cpu().numpy())

T2    = x.detach().reshape(M,N,latent_num).squeeze().cpu().numpy()
real  = xr.detach().reshape(M,N,1).squeeze().cpu().numpy()
imag  = xi.detach().reshape(M,N,1).squeeze().cpu().numpy()

timeseries_decoder = tonpy(Dfor(model).detach().squeeze().contiguous())

del x, xr, xi, gx, gr, gi, loss

In [None]:
# Latent map(s) next to the real and imaginary factor maps, one panel each.
# T2 is squeezed to (M, N) only when latent_num == 1, so reshape it back to
# (M, N, latent_num). This keeps the concatenation valid for any latent_num,
# matching the 2 + latent_num panels requested from mosaic.
display = np.concatenate((T2.reshape(M, N, -1),
                          real.reshape(M, N, 1),
                          imag.reshape(M, N, 1)), axis=-1).transpose(2, 0, 1)
sig.mosaic(sig.nor(display),1,2 + latent_num)

### estimating T2 maps by dictionary matching in the latent space

In [None]:
# Same CUDA-if-available rule as the configuration cell. A hard-coded "cuda:0"
# fails on CPU-only machines and would not match the device kspace_torch and
# coils_torch were placed on.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Objects used to simulate the EPG dictionary. The /1000 factors presumably
# convert milliseconds to seconds -- TODO confirm the units expected by
# simulator.FSE_signal2_ex.
n_dict      = 1000                                                   # number of dictionary entries
spacing     = 11.5 / 1000                                            # presumably the echo spacing
angles_rad  = (torch.ones(1,E) * 180 * math.pi / 180).to(device)    # 180 deg for each of the E echoes
angle_exc   = (torch.ones(1,1) * 90 * math.pi/180).to(device)       # 90 deg excitation angle
t2_range    = torch.linspace(0,400,n_dict).to(device) / 1000         # T2 grid: 0..400, then /1000
t1_range    = torch.ones(n_dict).to(device)*1000 / 1000              # the same T1 for every entry

angles_rad.requires_grad = False
angle_exc.requires_grad  = False
t1_range.requires_grad   = False
t2_range.requires_grad   = False

In [None]:
model.to(device)

# Simulate the dictionary over the (t1_range, t2_range) grid with the EPG model;
# presumably one signal evolution per entry -- TODO confirm simulator output layout.
dictionary_epg = simulator.FSE_signal2_ex(
    angle_exc, angles_rad, spacing, t1_range, t2_range
)[0].squeeze()

# Encode the dictionary into the decoder's latent space; kept on the CPU for matching.
dictionary_epg_latent = model.encode(dictionary_epg).cpu()

In [None]:
# Estimate a T2 map for every saved decoder snapshot by nearest-neighbour
# matching in latent space.
# NOTE: the absolute-difference matching below assumes latent_num == 1.
all_tmap_est = []

# (latent_num x n_entries); loop-invariant, so computed once.
dictionary_row = dictionary_epg_latent.permute((1, 0))

for ii in range(len(all_T2)):
    # Latents of the ii-th snapshot. The previous code matched the final `T2`
    # on every pass, which made all entries of all_tmap_est identical.
    latents = torch.from_numpy(np.asarray(all_T2[ii])).reshape((M*N, 1))

    # vectorised dictionary matching: index of the closest entry per pixel
    min_indices = torch.argmin(torch.abs(dictionary_row - latents), 1)

    # estimating t2 map
    t2map_est = t2_range[min_indices].reshape(M,N).cpu().numpy()

    all_tmap_est.append(t2map_est)

### computing EPG gradients

In [None]:
# For every saved decoder snapshot, evaluate the EPG forward model at its T2
# estimate and density, and record the norm of the data-consistency gradient
# with respect to T2 (x_epg) and to the real part of the density (xr_epg).
all_norms_x  = np.zeros(len(all_tmap_est))
all_norms_xr = np.zeros(len(all_tmap_est))

for ii in range(len(all_tmap_est)):
    print('gradient %d/%d' % (ii+1,len(all_tmap_est)))
    x_epg   = torch.tensor(all_tmap_est[ii],dtype = torch.float).reshape(M*N).to(device).detach()
    xr_epg  = torch.tensor(all_real[ii],dtype = torch.float).reshape(M*N,1,1).to(device).detach()
    xi_epg  = torch.tensor(all_imag[ii],dtype = torch.float).reshape(M*N,1,1).to(device).detach()
    x_t1    = torch.ones(M*N).to(device)

    x_epg.requires_grad  = True
    xr_epg.requires_grad = True
    xi_epg.requires_grad = False
    x_t1.requires_grad   = False

    out    = simulator.FSE_signal2_ex(angle_exc,angles_rad,spacing,x_t1,x_epg)[0]
    out    = torch.stack((out * xr_epg,out * xi_epg),-1).reshape(M,N,1,E,2)
    loss   = criterion(kspace_torch,R(ffor(sfor(out,coils_torch)),mask_torch))

    # one backward pass for both gradients
    g, gp = torch.autograd.grad(loss, (x_epg, xr_epg))

    # .item() copies the scalar to the host. Writing a (possibly CUDA) tensor
    # directly into a NumPy array can fail, because NumPy may convert it through
    # Tensor.numpy(), which only works for CPU tensors.
    all_norms_x[ii]  = torch.norm(g).item()
    all_norms_xr[ii] = torch.norm(gp).item()

    del x_epg,xr_epg,xi_epg,x_t1,out,loss,g,gp

### plotting the norm of the gradients

In [None]:
# Normalise both curves by their common maximum so they share one y axis.
nor = np.max(np.concatenate((all_norms_x,all_norms_xr)))
lw  = 3

# Snapshot ii was saved at decoder iteration ii * saveiter.
snapshot_iters = np.arange(len(all_norms_x)) * saveiter

fig, ax = plt.subplots()
ax.set_title('norm of gradients in EPG Forward Model',fontsize = 24)
ax.plot(snapshot_iters, all_norms_x / nor, linewidth=lw)
ax.plot(snapshot_iters, all_norms_xr / nor, linewidth=lw)
ax.set_xlabel('decoder iteration')
ax.set_ylabel('normalized gradient norm')
ax.legend(['T2 gradient','density gradient'],prop={'size':20})
plt.show()

### EPG reconstruction initialized with the decoder reconstruction

In [None]:
iterations = 200
lr         = 1e7

# Initialise T2 and density from the last entries derived from the decoder reconstruction.
x_epg   = torch.tensor(all_tmap_est[-1],dtype = torch.float).reshape(M*N).to(device).detach()
xr_epg  = torch.tensor(all_real[-1],dtype = torch.float).reshape(M*N,1,1).to(device).detach()
xi_epg  = torch.tensor(all_imag[-1],dtype = torch.float).reshape(M*N,1,1).to(device).detach()
x_t1    = torch.ones(M*N).to(device)

x_epg.requires_grad  = True
xr_epg.requires_grad = True
xi_epg.requires_grad = False   # imaginary density part is held fixed
x_t1.requires_grad   = False   # T1 is held fixed at 1 for every pixel

tic = time.perf_counter()
for it in range(iterations):
    out    = simulator.FSE_signal2_ex(angle_exc,angles_rad,spacing,x_t1,x_epg)[0]
    out    = torch.stack((out * xr_epg,out * xi_epg),-1).reshape(M,N,1,E,2)
    loss   = criterion(kspace_torch,R(ffor(sfor(out,coils_torch)),mask_torch))

    # a single backward pass gives both gradients
    g, gp = torch.autograd.grad(loss, (x_epg, xr_epg))

    # Gradient step. Detached copies keep each iterate a fresh leaf, so the
    # autograd history does not grow with the iteration count.
    x_epg  = (x_epg.detach()  - lr * g).requires_grad_()
    xr_epg = (xr_epg.detach() - lr * gp).requires_grad_()

    running_loss = loss.item()

    if it % dispiter == 0:
        toc = time.perf_counter()
        print('iteration %d / %d, current loss: %.12f / elapsed: %.2f' % (it,iterations,running_loss,toc-tic))
        tic = time.perf_counter()

    del out,loss,g,gp

In [None]:
# x_epg holds one T2 value per pixel regardless of latent_num, so reshape to
# (M, N). Reshaping to (M, N, latent_num) fails whenever latent_num != 1.
T2_epg    = x_epg.detach().reshape(M,N).cpu().numpy()
real_epg  = xr_epg.detach().reshape(M,N).cpu().numpy()
imag_epg  = xi_epg.detach().reshape(M,N).cpu().numpy()

# Evaluate the fitted EPG model on the CPU to get the reconstructed time series.
out    = simulator.FSE_signal2_ex(angle_exc.detach().cpu(),angles_rad.detach().cpu(),spacing,
                                  x_t1.detach().cpu(),x_epg.detach().cpu())[0]
out    = torch.stack((out * xr_epg.detach().cpu(),out * xi_epg.detach().cpu()),-1).reshape(M,N,1,E,2)

timeseries_epg = tonpy(out.detach().squeeze().contiguous())

del out

### comparing the decoder reconstruction with the decoder-initialized EPG reconstruction

In [None]:
timeseries_truth = np.abs(cfl.readcfl('data/timeseries_truth'))

# Mask out results outside of the brain: a pixel is kept where the first echo
# of the reference time series is non-zero.
first_echo_support = np.abs(timeseries_truth[:, :, 0]) > 0
brain_mask = np.zeros((M, N))
brain_mask[np.nonzero(first_echo_support)] = 1
brain_mask = brain_mask[:, :, np.newaxis]

In [None]:
lin_idx     = 0
regidx_bart = 0   # NOTE(review): lin_idx / regidx_bart are not used in the visible cells

ech  = 0                     # echo index to compare
echo = slice(ech, ech + 1)   # slicing (not indexing) keeps a trailing singleton axis

# Each image is masked to the brain and normalised before comparison.
truth_show   = sig.nor(timeseries_truth[:, :, echo] * brain_mask)
decoder_show = sig.nor(np.abs(timeseries_decoder[:, :, echo]) * brain_mask)
epg_show     = sig.nor(np.abs(timeseries_epg[:, :, echo]) * brain_mask)

display = np.concatenate((decoder_show, epg_show), axis=-1).transpose(2, 0, 1)
sig.mosaic(sig.nor(display), 1, 2, clim=[0, 1])

# errors against the reference, reported multiplied by 100
print('  proposed:  %.2f' % (sig.rmse(truth_show,decoder_show)*100))
print('  epg:       %.2f' % (sig.rmse(truth_show,epg_show)*100))