In [None]:
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import scipy.io as sio
import numpy as np
import time 
import matplotlib.pyplot as plt

import bart 

from python_utils import models
from python_utils import signalprocessing as sig

### helper functions 

In [None]:
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Circularly shift a tensor along a single dimension (like torch.roll with one dim).
    Args:
        x: A PyTorch tensor.
        shift: Number of positions to move elements forward; negative values move backward.
        dim: The dimension along which to shift.
    Returns:
        The shifted tensor; `x` itself is returned when the effective shift is zero.
    """
    length = x.size(dim)
    offset = shift % length
    if not offset:
        return x

    # The last `offset` entries wrap around to the front.
    cut = length - offset
    head = x.narrow(dim, cut, offset)
    tail = x.narrow(dim, 0, cut)
    return torch.cat((head, tail), dim=dim)

def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Multi-dimension circular shift for PyTorch tensors, mirroring np.roll.
    Args:
        x: A PyTorch tensor.
        shift: Per-dimension shift amounts, paired positionally with `dim`.
        dim: Dimensions to shift.
    Returns:
        The shifted tensor (`x` itself when there is nothing to shift).
    Raises:
        ValueError: if `shift` and `dim` have different lengths.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    rolled = x
    for k in range(len(dim)):
        rolled = roll_one_dim(rolled, shift[k], dim[k])
    return rolled


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of np.fft.fftshift.
    Args:
        x: A PyTorch tensor.
        dim: Dimensions to shift; every dimension of `x` when None.
    Returns:
        The fftshifted tensor.
    """
    axes = list(range(x.dim())) if dim is None else dim
    # fftshift rolls each axis forward by floor(n / 2).
    offsets = [x.shape[axis] // 2 for axis in axes]
    return roll(x, offsets, axes)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of np.fft.ifftshift (the inverse of fftshift).
    Args:
        x: A PyTorch tensor.
        dim: Dimensions to shift; every dimension of `x` when None.
    Returns:
        The ifftshifted tensor.
    """
    axes = list(range(x.dim())) if dim is None else dim
    # Rolling forward by ceil(n / 2) is the same as rolling back by floor(n / 2),
    # which undoes fftshift for both even and odd lengths.
    offsets = [(x.shape[axis] + 1) // 2 for axis in axes]
    return roll(x, offsets, axes)

def tonpy(x: torch.Tensor) -> np.ndarray:
    """
    Convert a real tensor with a trailing (real, imag) axis of size 2 to a complex numpy array.
    Args:
        x: A PyTorch tensor whose last dimension holds the real and imaginary parts.
    Returns:
        A complex numpy array on the CPU, with the trailing axis folded into the dtype.
    """
    # detach() so tensors that require grad can be converted to numpy;
    # contiguous() because view_as_complex needs the real/imag axis to have stride 1.
    return torch.view_as_complex(x.detach().contiguous()).cpu().numpy()

### loading data

In [None]:
# Signal dictionary and the packed MATLAB struct holding the acquisition data.
dictionary = sio.loadmat('data/dictionary.mat')['dictionary']
fort2shfl  = sio.loadmat('data/for_t2shfl.mat')

# The struct fields are read by position: mask, kspace, coils, truth.
record = fort2shfl['for_t2shfl'][0][0]
mask   = record[0]
kspace = record[1]
coils  = record[2]
truth  = record[3]

# Subspace sizes for the linear (BART) baselines, and latent size of the decoder.
all_num_latent_linear = [2, 3, 4]
latent_num = 1

### generating linear subspace

In [None]:
# The SVD of the dictionary does not depend on the subspace size, so compute it once
# and keep the leading left singular vectors for each requested rank.
# NOTE(review): basis is later reshaped with E rows, so dictionary rows are presumably echoes.
U, _, _ = np.linalg.svd(dictionary, full_matrices=False)

all_basis = [U[:, :num_latent_linear] for num_latent_linear in all_num_latent_linear]

### building forward model and generating k-space

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# kspace is 4-D: rows, columns, coils, echoes.
[M, N, C, E] = kspace.shape

# Coil maps as a real tensor with real/imag on the last axis, broadcastable over echoes.
coils_re = torch.tensor(np.real(coils), dtype=torch.float)
coils_im = torch.tensor(np.imag(coils), dtype=torch.float)
coils_torch = torch.stack((coils_re, coils_im), dim=-1).to(device).reshape(M, N, C, 1, 2)
coils_torch.requires_grad = False

# The sampling mask is duplicated on the real and imaginary channels.
mask_float = torch.tensor(mask, dtype=torch.float)
mask_torch = torch.stack((mask_float, mask_float), dim=-1).to(device)
mask_torch.requires_grad = False

In [None]:
def sfor(x,CC):
    '''
    Complex multiplication of an image series by coil sensitivity maps.
    Real and imaginary parts are stored on the last axis of both inputs.
    inputs:
        x  (M x N x 1 x T x 2)    - time series of images
        CC (M x N x C x 1 x 2)    - coil sensitivity functions
    output
        out (M x N x C x T x 2)   - coil-weighted image series
    '''
    a, b = x[..., 0], x[..., 1]
    c, d = CC[..., 0], CC[..., 1]
    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    real_part = a * c - b * d
    imag_part = a * d + b * c
    return torch.stack((real_part, imag_part), -1)

def ffor(data):
    '''
    Centered, orthonormal 2D forward Fourier transform over the first two (image) axes.
    inputs:
        data (M x N x C x T x 2)   - real/imag stacked on the last axis
    outputs:
        out  (M x N x C x T x 2)   - k-space, same layout
    '''
    # Centered FFT convention: ifftshift -> FFT -> fftshift. Applying the shifts in
    # the opposite order only gives the same result when M and N are even.
    shifted = ifftshift(data, dim=[0, 1])

    if hasattr(torch.fft, 'fftn'):
        # torch.fft module API (the torch.fft() function was removed in PyTorch 1.8).
        spectrum = torch.view_as_real(
            torch.fft.fftn(torch.view_as_complex(shifted.contiguous()), dim=(0, 1), norm='ortho'))
    else:
        # Legacy API: torch.fft(input, signal_ndim, normalized) transforms the signal_ndim
        # axes right before the trailing real/imag axis, so move M and N there and back.
        spectrum = torch.fft(shifted.permute(2, 3, 0, 1, 4), signal_ndim=2,
                             normalized=True).permute(2, 3, 0, 1, 4)

    return fftshift(spectrum, dim=[0, 1])

def R(data,mask):
    '''
    Undersampling operator: element-wise product of the input with the sampling mask.
    inputs:
        data  (M x N x C x T x 2)     - input to be masked
        mask  broadcastable to data   - in this notebook it is built with the mask
                                        duplicated on the real/imag axis (last dim 2)
    output
        data with unsampled entries zeroed
    '''
    return mask * data

def Dfor(model, latent=None, scale_real=None, scale_imag=None, image_shape=None):
    '''
    Decode per-pixel latent codes into a complex image time series.
    inputs:
        model        - network exposing decode(latent)
        latent       (M*N x L)  - latent codes; defaults to the notebook global `x`
        scale_real   (M*N x 1)  - real part of the per-pixel scale; defaults to global `xr`
        scale_imag   (M*N x 1)  - imaginary part of the per-pixel scale; defaults to global `xi`
        image_shape  (rows, cols) - defaults to the notebook globals (M, N)
    output
        out (rows x cols x 1 x T x 2) - decoded series scaled by (scale_real + i*scale_imag)
    '''
    # Fall back to the notebook globals so the original call Dfor(model) keeps working.
    if latent is None:
        latent = x
    if scale_real is None:
        scale_real = xr
    if scale_imag is None:
        scale_imag = xi
    rows, cols = (M, N) if image_shape is None else image_shape

    # NOTE(review): decode output is presumably (M*N x T) so it broadcasts with the (M*N x 1) scales.
    evolution = model.decode(latent)
    out = torch.stack((evolution * scale_real, evolution * scale_imag), -1)
    # -1 infers T, which equals E whenever the original fixed-size reshape succeeded.
    return out.reshape(rows, cols, 1, -1, 2)

In [None]:
# Measured k-space as a real tensor with real/imag stacked on the last axis.
kspace_re = torch.tensor(np.real(kspace), dtype=torch.float)
kspace_im = torch.tensor(np.imag(kspace), dtype=torch.float)

# Keep only the sampled locations, matching the masked forward model.
kspace_torch = R(torch.stack((kspace_re, kspace_im), dim=-1).to(device), mask_torch)

### loading auto-encoder model 

In [None]:
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted checkpoints.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# On PyTorch >= 2.6, torch.load defaults to weights_only=True, which cannot unpickle a
# whole nn.Module; pass weights_only=False there if loading fails.
model_path = 'data/model_type_2_latent_%d_epochs_200000.pt' % latent_num
model = torch.load(model_path, map_location=device).to(device)

# Freeze the decoder: only the latent maps are optimized during reconstruction.
for param in model.parameters():
    param.requires_grad_(False)

### performing reconstruction 

In [None]:
# Per-pixel latent codes, plus the per-pixel real/imag scale that Dfor applies to the decoded series.
x  = torch.zeros(M*N, latent_num, device=device, requires_grad=True)
xr = torch.ones(M*N, 1, device=device, requires_grad=True)
xi = torch.zeros(M*N, 1, device=device, requires_grad=True)

criterion   = nn.MSELoss()
optimizer   = optim.Adam([x, xr, xi], lr=1e0)
iterations  = 2000
print_every = 100   # logging interval; set to 1 to print every iteration

for it in range(iterations):
    optimizer.zero_grad()

    # Data consistency: masked k-space of the decoded, coil-weighted image series
    # compared against the measured (masked) k-space.
    prediction = R(ffor(sfor(Dfor(model), coils_torch)), mask_torch)
    loss = criterion(prediction, kspace_torch)
    loss.backward()
    optimizer.step()

    running_loss = loss.item()

    if it % print_every == 0 or it == iterations - 1:
        print('iteration %d / %d, current loss: %.12f' % (it, iterations, running_loss))

In [None]:
# Optimized latent maps and per-pixel scales, moved to numpy on the image grid.
T2    = x.detach().reshape(M, N, latent_num).squeeze().cpu().numpy()
real  = xr.detach().reshape(M, N, 1).squeeze().cpu().numpy()
imag  = xi.detach().reshape(M, N, 1).squeeze().cpu().numpy()

# no_grad: we only need the values, not an autograd graph through the decoder.
with torch.no_grad():
    timeseries_decoder = tonpy(Dfor(model).squeeze().contiguous())

### computing bart reconstructions

In [None]:
import os

import cfl

BART_ITERS = 50                        # number of BART pics iterations
basis_path = 'bart_basis/basis_bart'   # cfl file through which the basis is passed to pics
pics_cmd   = 'pics -B %s -i %d' % (basis_path, BART_ITERS)

# NOTE(review): cfl.writecfl presumably writes files at this path directly, so ensure the directory exists.
os.makedirs(os.path.dirname(basis_path), exist_ok=True)

# Inputs shared by every reconstruction: masked k-space and coil maps laid out for BART
# (echoes on dim 5, basis coefficients on dim 6).
kspace_bart = (kspace * mask).reshape(M, N, 1, C, 1, E)
coils_bart  = coils.reshape(M, N, 1, C, 1)

all_coeffs_bart     = []
all_timeseries_bart = []
for ctr, basis in enumerate(all_basis, start=1):
    print('reconstruction %d' % ctr)

    B = basis.shape[1]
    basis_bart = basis.reshape(1, 1, 1, 1, 1, E, B)

    cfl.writecfl(basis_path, basis_bart)
    # Subspace coefficients, reordered to (B, M, N).
    coeffs_bart = np.squeeze(bart.bart(1, pics_cmd, kspace_bart, coils_bart)).transpose(2, 0, 1)
    print(' done')

    # Expand coefficients back to a (M, N, E) time series.
    timeseries_bart = (basis @ coeffs_bart.reshape(B, M*N)).reshape(E, M, N).transpose(1, 2, 0)

    all_coeffs_bart.append(coeffs_bart)
    all_timeseries_bart.append(timeseries_bart)

### comparisons at a particular echo

In [None]:
lin_idx = 1   # which linear baseline to show (index into all_num_latent_linear)
ech = 0       # echo to display

# Mask out results outside the brain: the nonzero support of truth[:, :, 0, 0].
brain_mask = (np.abs(truth[:, :, 0, 0]) > 0).astype(float)[:, :, np.newaxis]

timeseries_truth = np.abs(np.squeeze(truth))

truth_show   = sig.nor(timeseries_truth[:, :, ech:ech+1] * brain_mask)
decoder_show = sig.nor(np.abs(timeseries_decoder[:, :, ech:ech+1]) * brain_mask)
bart_show    = sig.nor(np.abs(all_timeseries_bart[lin_idx][:, :, ech:ech+1] * brain_mask))

# Named `panels` (not `display`) so IPython's built-in display() is not shadowed.
panels = np.concatenate((truth_show, bart_show, decoder_show), axis=-1).transpose(2, 0, 1)

sig.mosaic(sig.nor(panels), 1, 3)

print('rmse at echo %d' % ech)
print('  bart (linear %d): %.2f' % (all_num_latent_linear[lin_idx], sig.rmse(truth_show, bart_show) * 100))
print('  proposed:  %.2f' % (sig.rmse(truth_show, decoder_show) * 100))

### rmse comparison plot across all echoes

In [None]:
rmse_decoder = np.zeros(E)
rmse_bart    = np.zeros((E, len(all_num_latent_linear)))

decoder    = timeseries_decoder * brain_mask
truth_comp = timeseries_truth * brain_mask

# Per-echo RMSE between normalized magnitude images inside the brain mask.
for ee in range(E):
    rmse_decoder[ee] = sig.rmse(sig.nor(truth_comp[:, :, ee]), sig.nor(np.abs(decoder[:, :, ee])))

for ll in range(len(all_num_latent_linear)):
    bart_comp = all_timeseries_bart[ll] * brain_mask
    for ee in range(E):
        rmse_bart[ee, ll] = sig.rmse(sig.nor(truth_comp[:, :, ee]), sig.nor(np.abs(bart_comp[:, :, ee])))

fig, ax = plt.subplots()
ax.plot(rmse_bart)      # one line per linear subspace size
ax.plot(rmse_decoder)

legend = ['linear %d' % ll for ll in all_num_latent_linear]
legend.append('proposed %d' % latent_num)

ax.legend(legend)
ax.set(xlabel='echo index', ylabel='RMSE', title='Reconstruction error across echoes')
plt.show()