In [None]:
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import scipy.io as sio
import numpy as np
import time 
import matplotlib.pyplot as plt

import bart 
import cfl

from python_utils import models
from python_utils import signalprocessing as sig

from pytorch_wavelets import DTCWTForward, DTCWTInverse, DWTForward,DWTInverse

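### Helper functions

PyTorch equivalents of `np.roll`, `np.fft.fftshift` and `np.fft.ifftshift` (written to stay compatible with `torch.jit.script`), plus `tonpy`, which converts a real `(..., 2)` tensor of (real, imag) pairs into a complex NumPy array.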

In [None]:
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Similar to roll but for only one dim.
    Args:
        x: A PyTorch tensor.
        shift: Amount to roll (may be negative or larger than the dim size).
        dim: Which dimension to roll.
    Returns:
        Rolled version of x. When the effective shift is zero (or the
        dimension is empty) x itself is returned, not a copy.
    """
    size = x.size(dim)
    # an empty dimension cannot be rolled; also avoids `shift % 0`
    if size == 0:
        return x

    shift = shift % size
    if shift == 0:
        return x

    left = x.narrow(dim, 0, size - shift)
    right = x.narrow(dim, size - shift, shift)

    return torch.cat((right, left), dim=dim)

def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Multi-axis analogue of np.roll for PyTorch tensors.
    Args:
        x: A PyTorch tensor.
        shift: Per-axis roll amounts, paired element-wise with `dim`.
        dim: Axes to roll, in the order they are applied.
    Returns:
        Rolled version of x.
    Raises:
        ValueError: if `shift` and `dim` have different lengths.
    """
    if len(dim) != len(shift):
        raise ValueError("len(shift) must match len(dim)")

    out = x
    for k in range(len(shift)):
        out = roll_one_dim(out, shift[k], dim[k])
    return out


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of np.fft.fftshift: moves the zero-frequency entry
    to the centre by rolling each selected axis by floor(size / 2).
    Args:
        x: A PyTorch tensor.
        dim: Axes to shift; all axes when None.
    Returns:
        fftshifted version of x.
    """
    if dim is None:
        # every axis, built as an explicit List[int] (torch.jit.script typing)
        dim = [axis for axis in range(x.dim())]

    amounts = [x.shape[axis] // 2 for axis in dim]
    return roll(x, amounts, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of np.fft.ifftshift: the inverse of fftshift, rolling
    each selected axis by ceil(size / 2).
    Args:
        x: A PyTorch tensor.
        dim: Axes to shift; all axes when None.
    Returns:
        ifftshifted version of x.
    """
    if dim is None:
        # every axis, built as an explicit List[int] (torch.jit.script typing)
        dim = [axis for axis in range(x.dim())]

    amounts = [(x.shape[axis] + 1) // 2 for axis in dim]
    return roll(x, amounts, dim)

def tonpy(x):
    '''
    Converts a real tensor whose last axis holds (real, imag) pairs into a
    complex NumPy array on the CPU.
    inputs:
        x (... x 2)   - real/imag tensor (must satisfy torch.view_as_complex)
    outputs:
        out (...)     - complex NumPy array
    '''
    return torch.view_as_complex(x).cpu().numpy()

### Dataset and parameters

In [None]:
# Prefer the second GPU when there is one (the original setup), otherwise the
# first GPU, otherwise the CPU. Hard-coding "cuda:1" fails on machines that
# have exactly one GPU.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# identifiers of the pre-trained decoder; they select the checkpoint file below
model_layers   = 2
nonlinearity   = 'tanh'
latent_num     = 1
epoch          = 80000

In [None]:
# MATLAB file bundling the sampling mask and coil sensitivities (unpacked below)
fort2shfl = sio.loadmat(file_name='data/for_t2shfl.mat')

In [None]:
# fields 0 and 1 of the loaded record hold the sampling mask and the coil maps
mask, coils = (fort2shfl['for_t2shfl'][0][0][k] for k in (0, 1))

# phase estimate and (fully sampled) k-space stored in BART .cfl format
phase_es = cfl.readcfl('data/phase_recon')
kspace   = cfl.readcfl('data/kspace')

### Forward operators and k-space

In [None]:
M, N, C, E = kspace.shape

# complex coil maps -> real tensor with a trailing (real, imag) axis,
# laid out as M x N x C x 1 x 2 so they broadcast over time points
coils_torch = torch.stack(
    [torch.tensor(part, dtype=torch.float) for part in (np.real(coils), np.imag(coils))],
    dim=-1,
).to(device).reshape(M, N, C, 1, 2)
coils_torch.requires_grad = False

# the same real-valued mask is applied to both the real and imaginary channel
mask_torch = torch.stack(
    [torch.tensor(mask, dtype=torch.float) for _ in range(2)],
    dim=-1,
).to(device)
mask_torch.requires_grad = False

In [None]:
def phase_for(x,P):
    '''
    Multiplies a time series of images by a pre-computed phase, treating the
    trailing axis of both inputs as (real, imag).
    inputs:
        x (M x N x 1 x T x 2)      - time-series of images
        P (broadcastable to x)     - phase, same (real, imag) layout
    outputs:
        out (M x N x 1 x T x 2)    - time-series of images multiplied with phase
    '''
    x_re, x_im = x[..., 0], x[..., 1]
    p_re, p_im = P[..., 0], P[..., 1]

    # (a + ib)(c + id) = (ac - bd) + i(bc + ad)
    out_re = x_re * p_re - x_im * p_im
    out_im = x_im * p_re + x_re * p_im
    return torch.stack((out_re, out_im), -1)

def sfor(x,CC):
    '''
    Forward coil-sensitivity operator: complex product of each image with each
    coil map (trailing axis = real, imag), broadcasting over coils and time.
    inputs:
        x  (M x N x 1 x T x 2)    - time series of images
        CC (M x N x C x 1 x 2)    - coil sensitivity functions
    output
        out (M x N x C x T x 2)   - coil images
    '''
    img_re, img_im = x[..., 0], x[..., 1]
    sens_re, sens_im = CC[..., 0], CC[..., 1]

    coil_re = img_re * sens_re - img_im * sens_im
    coil_im = img_re * sens_im + img_im * sens_re
    return torch.stack((coil_re, coil_im), -1)

def ffor(data):
    '''
    Performs a forward, orthonormally scaled 2D Fourier transform over the
    first two (spatial) axes, with fftshift applied before the transform and
    ifftshift after it.
    inputs:
        data (M x N x C x T x 2)   - image series, trailing axis = (real, imag)
    outputs:
        out  (M x N x C x T x 2)   - k-space series, trailing axis = (real, imag)

    NOTE(review): for even M and N this shift order is identical to the usual
    ifftshift -> FFT -> fftshift centered transform; for odd sizes the two
    orders differ, so confirm the convention matches the measured k-space.
    '''
    shifted = fftshift(data, dim=[0, 1])

    if hasattr(torch.fft, 'fft2'):
        # Modern API (torch.fft is a module). The legacy callable
        # torch.fft(..., signal_ndim=...) was removed in PyTorch 1.8.
        spectrum = torch.fft.fft2(torch.view_as_complex(shifted.contiguous()),
                                  dim=(0, 1), norm='ortho')
        spectrum = torch.view_as_real(spectrum)
    else:
        # Legacy API (PyTorch < 1.8): transforms the `signal_ndim` axes just
        # before the trailing real/imag axis, hence the permutes around it.
        spectrum = torch.fft(shifted.permute(2, 3, 0, 1, 4),
                             signal_ndim=2, normalized=True).permute(2, 3, 0, 1, 4)

    return ifftshift(spectrum, dim=[0, 1])

def R(data,mask):
    '''
    Undersampling operator: element-wise multiplication by a sampling mask.
    inputs:
        data  (M x N x C x T x 2)     - input to be masked
        mask  (broadcastable to data) - mask to be applied; in this notebook
                                        the mask is duplicated along a trailing
                                        (real, imag) axis
    output:
        masked data, same broadcast shape as data * mask
    '''
    masked = data * mask
    return masked

def Dfor(model):
    '''
    Decodes the current latent maps into a complex time series of images.

    Reads the notebook-level optimization variables ``x`` (M*N x latent_num
    latent coefficients), ``xr`` and ``xi`` (M*N x 1 real / imaginary
    scaling maps) and the sizes ``M``, ``N``, ``E``.
    input
        model              - neural network model exposing ``decode``
    output
        out (M x N x 1 x E x 2) - decoded signal scaled by xr (real channel)
                                  and xi (imaginary channel)
    '''
    # Decode once and reuse: calling model.decode twice doubled the decoder
    # cost on every forward pass (and would give mismatched channels if the
    # model were stochastic). Output assumed M*N x E, per the reshape below.
    decoded = model.decode(x)
    return torch.stack((decoded * xr, decoded * xi), -1).reshape(M, N, 1, E, 2)

def Wfor_dec(x,W):
    '''
    L1 norm of the 2D wavelet coefficients of a set of flattened maps.
    inputs:
        x (M*N x L)  - maps to regularize; reshaped to M x N using the
                       notebook-level image size
        W            - pytorch wavelet operator returning (lowpass, highpass
                       sub-bands)
    output:
        scalar tensor - sum of absolute coefficients over all sub-bands
    '''
    n_maps = x.shape[1]

    # (M*N, L) -> (1, L, M, N): one image with L channels for the 2D transform
    batch = x.reshape(M, N, n_maps, 1).permute(3, 2, 0, 1)
    lowpass, highpass = W(batch)

    total = torch.abs(lowpass).sum()
    for band in highpass:
        total = total + torch.abs(band).sum()
    return total

In [None]:
# measured k-space as a real tensor with a trailing (real, imag) axis
kspace_torch = torch.stack(
    [torch.tensor(part, dtype=torch.float) for part in (np.real(kspace), np.imag(kspace))],
    dim=-1,
).to(device)
kspace_torch.requires_grad = False

# keep only the sampled locations
kspace_torch = R(kspace_torch, mask_torch)

# phase estimate laid out as M x N x 1 x E x 2 to broadcast over the images
phase_es_torch = torch.stack(
    [torch.tensor(part, dtype=torch.float) for part in (np.real(phase_es), np.imag(phase_es))],
    dim=-1,
).to(device).reshape(M, N, 1, E, 2)
phase_es_torch.requires_grad = False

# 3-level db3 discrete wavelet transform used by the L1 regularizer
xfm = DWTForward(J=3, mode='zero', wave='db3').to(device)

### Decoder-based reconstruction

In [None]:
model_path = 'data/model_layers_%d_nonlin_%s_latent_%d_epochs_%d.pt' % (model_layers, nonlinearity, latent_num, epoch)

# NOTE: torch.load unpickles a whole model object (arbitrary code execution on
# untrusted files) -- only load checkpoints you produced yourself.
# map_location lets a checkpoint saved on another device (e.g. a different
# GPU) be loaded on the device selected in the config cell.
model = torch.load(model_path, map_location=device).to(device)

# The decoder is only used for inference here: switch to eval mode so any
# train-only layers (e.g. dropout) are inactive, and freeze its weights.
model.eval()
for param in model.parameters():
    param.requires_grad = False
    

In [None]:
lambda_l1_x   = 1e-11   # wavelet-L1 weight on the latent map(s) x
lambda_l1_xr  = 1e-11   # wavelet-L1 weight on the real scaling map xr
lambda_l1_xi  = 1e-11   # wavelet-L1 weight on the imaginary scaling map xi

sc = 10  # scaling k-space to reconstruct (measured k-space is divided by sc)

# Optimization variables, created directly as leaf tensors on `device`:
#   x      (M*N x latent_num) - latent coefficients fed to the decoder
#   xr, xi (M*N x 1)          - real / imaginary scaling of the decoded signal
x  = torch.zeros(M*N, latent_num, device=device, requires_grad=True)
xr = torch.zeros(M*N, 1, device=device, requires_grad=True)
xi = torch.zeros(M*N, 1, device=device, requires_grad=True)

criterion   = nn.MSELoss()
optimizer   = optim.Adam([x, xr, xi], lr=1e0)
iterations  = 10000

# loop-invariant: scale the measured (masked) k-space once
kspace_target = kspace_torch / sc

# `it` rather than `iter`: a top-level loop variable named `iter` would shadow
# the builtin for every later cell of the notebook
for it in range(iterations):
    optimizer.zero_grad()

    # forward model: decode -> apply phase -> coil maps -> FFT -> sampling mask
    kspace_est = R(ffor(sfor(phase_for(Dfor(model), phase_es_torch), coils_torch)), mask_torch)

    loss = criterion(kspace_est, kspace_target) + \
        lambda_l1_x * Wfor_dec(x, xfm) + lambda_l1_xr * Wfor_dec(xr, xfm) + lambda_l1_xi * Wfor_dec(xi, xfm)

    loss.backward()
    optimizer.step()

    # .item() forces a host/device synchronization, so read it only when printing
    if it % 100 == 0:
        print('iteration %d / %d, current loss: %.12f' % (it, iterations, loss.item()))

# loss of the final iteration
if iterations > 0:
    running_loss = loss.item()

In [None]:
# optimized maps back to NumPy (M x N [x latent_num]), singleton axes dropped
T2   = x.detach().reshape(M, N, latent_num).squeeze().cpu().numpy()
real = xr.detach().reshape(M, N, 1).squeeze().cpu().numpy()
imag = xi.detach().reshape(M, N, 1).squeeze().cpu().numpy()

# decoded complex time series for the final latent/scaling maps
timeseries_decoder = tonpy(Dfor(model).detach().squeeze().contiguous())

In [None]:
# one panel per latent map, followed by the real and imaginary scaling maps
if latent_num == 1:
    T2_show = np.expand_dims(T2, axis=-1)
else:
    T2_show = np.copy(T2)

# named `maps_display` (not `display`) so the cell does not shadow IPython's
# `display` function for the rest of the notebook
maps_display = np.concatenate((T2_show,
                               np.expand_dims(real, axis=-1),
                               np.expand_dims(imag, axis=-1)), axis=-1).transpose(2, 0, 1)

sig.mosaic(maps_display, 1, 2 + latent_num)

### Comparison against ground truth and the linear reconstruction

In [None]:
truth   = cfl.readcfl('data/truth')
lin_rec = sio.loadmat('data/linear_llr_ablation.mat')['llr'][0][0][0]

# Pad 3-D or 4-D linear reconstructions with trailing singleton axes up to 5-D
# so they can be indexed as [:, :, echo, lin_idx, lin_reg] below.
if lin_rec.ndim in (3, 4):
    lin_rec = lin_rec.reshape(lin_rec.shape + (1,) * (5 - lin_rec.ndim))

# binary mask (M x N x 1) of pixels whose ground-truth magnitude at index 0
# of the last axis exceeds 0.8
brain_mask = np.zeros((M, N))
brain_mask[np.abs(truth[:, :, 0, 0]) > .8] = 1
brain_mask = brain_mask[:, :, np.newaxis]

In [None]:
lin_reg = 0
lin_idx = 0

lin_rmse = np.zeros(E)
pro_rmse = np.zeros(E)

for ee in range(E):
    truth_comp   = np.abs(truth[:,:,0,ee:ee+1]) * brain_mask
    # multiply by sc to undo the k-space scaling used during reconstruction
    pro_rmse[ee] = sig.rmse(truth_comp, np.abs(timeseries_decoder[:,:,ee:ee+1]) * brain_mask * sc)
    lin_rmse[ee] = sig.rmse(truth_comp, np.abs(lin_rec[:,:,ee:ee+1,lin_idx,lin_reg]) * brain_mask)

# exactly one label per plotted curve (a stray third label 'ana' had no curve)
legend = ['lin', 'pro %d' % (latent_num+2)]

fig, ax = plt.subplots()
ax.plot(lin_rmse, label=legend[0])
ax.plot(pro_rmse, label=legend[1])
ax.set(xlabel='echo index', ylabel='RMSE (masked)', title='Per-echo RMSE vs. ground truth')
ax.legend()
plt.show()

print('avg rmse:')
print('  pro %d: %.2f' % (latent_num+2,np.mean(pro_rmse)*100))
print('  lin:    %.2f' % (100*np.mean(lin_rmse)))


In [None]:
lin_reg = 0
lin_idx = 1
ech     = 39

# An out-of-range echo would silently produce empty slices (and meaningless
# RMSE values) below, so fail loudly instead.
if not 0 <= ech < E:
    raise ValueError('ech must be in [0, %d), got %d' % (E, ech))

decoder_show = np.abs(timeseries_decoder[:,:,ech:ech+1]) * brain_mask * sc
lin_show     = np.abs(lin_rec[:,:,ech:ech+1,lin_idx,lin_reg]) * brain_mask
truth_show   = np.abs(truth[:,:,0,ech:ech+1]) * brain_mask

# panels: truth | decoder | linear; named so IPython's `display` is not shadowed
echo_display = np.concatenate((truth_show, decoder_show, lin_show), axis=-1).transpose(2, 0, 1)

sig.mosaic(sig.nor(echo_display),1,3,clim=[0,1])
print('rmse ech %d' % (ech+1))
print('   prop: %.2f' % (sig.rmse(truth_show,decoder_show) * 100))
print('   lin: %.2f'  % (sig.rmse(truth_show,lin_show) * 100))