In [None]:
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import scipy.io as sio
import numpy as np
import time 
import matplotlib.pyplot as plt

import bart 

from python_utils import models
from python_utils import signalprocessing as sig

from pytorch_wavelets import DTCWTForward, DTCWTInverse, DWTForward,DWTInverse

### defining some helper functions

In [None]:
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Circularly shift a tensor along a single dimension.
    Args:
        x: A PyTorch tensor.
        shift: Amount to roll. Negative values and values larger than the
            size of the dimension wrap around.
        dim: Which dimension to roll.
    Returns:
        Rolled version of x. When nothing has to move (effective shift of
        zero, or an empty dimension) x itself is returned.
    """
    size = x.size(dim)

    # An empty dimension has nothing to roll; returning early also avoids the
    # modulo-by-zero that the wrap-around below would otherwise hit.
    if size == 0:
        return x

    shift = shift % size
    if shift == 0:
        return x

    # Split at the wrap-around point and swap the two pieces.
    left = x.narrow(dim, 0, size - shift)
    right = x.narrow(dim, size - shift, shift)

    return torch.cat((right, left), dim=dim)

def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Roll a PyTorch tensor along several dimensions (cf. np.roll).
    Args:
        x: A PyTorch tensor.
        shift: Amount to roll, one entry per entry of dim.
        dim: Dimensions to roll, processed in the order given.
    Returns:
        Rolled version of x.
    Raises:
        ValueError: if shift and dim do not have the same length.
    """
    if len(shift) != len(dim):
        raise ValueError("len(shift) must match len(dim)")

    rolled = x
    for idx in range(len(dim)):
        rolled = roll_one_dim(rolled, shift[idx], dim[idx])

    return rolled


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of np.fft.fftshift.
    Args:
        x: A PyTorch tensor.
        dim: Dimensions to fftshift; every dimension of x when None.
    Returns:
        fftshifted version of x (each selected dimension rolled by size // 2).
    """
    if dim is None:
        # Spelled as a pre-sized int list filled by a loop because
        # torch.jit.script needs the list type to be unambiguous.
        dim = [0] * x.dim()
        for axis in range(1, x.dim()):
            dim[axis] = axis

    # Same explicit-loop form for the shifts, again for torch.jit.script.
    shift = [0] * len(dim)
    for pos in range(len(dim)):
        shift[pos] = x.size(dim[pos]) // 2

    return roll(x, shift, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    PyTorch counterpart of np.fft.ifftshift.
    Args:
        x: A PyTorch tensor.
        dim: Dimensions to ifftshift; every dimension of x when None.
    Returns:
        ifftshifted version of x (each selected dimension rolled by
        (size + 1) // 2).
    """
    if dim is None:
        # Spelled as a pre-sized int list filled by a loop because
        # torch.jit.script needs the list type to be unambiguous.
        dim = [0] * x.dim()
        for axis in range(1, x.dim()):
            dim[axis] = axis

    # Same explicit-loop form for the shifts, again for torch.jit.script.
    shift = [0] * len(dim)
    for pos in range(len(dim)):
        shift[pos] = (x.size(dim[pos]) + 1) // 2

    return roll(x, shift, dim)

def tonpy(x):
    """
    Convert a real tensor whose last dimension holds (real, imag) pairs into
    a complex NumPy array on the CPU.
    """
    complex_view = torch.view_as_complex(x)
    return complex_view.cpu().numpy()

### loading data and models

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# These five values are substituted into the file name of the saved decoder
# model that is loaded further down. latent_num additionally sets the number
# of latent variables optimised per pixel in the decoder reconstruction.
model_layers   = 2
nonlinearity   = 'tanh'
latent_num     = 1
epoch          = 100000
rep            = 0

In [None]:
# 'dictionary' is factorised by SVD below to build the linear bases;
# 'for_t2shfl.mat' bundles the mask, k-space, coil maps and ground truth,
# which are unpacked in the next cell.
dictionary = sio.loadmat('data/dictionary.mat')['dictionary']
fort2shfl  = sio.loadmat('data/for_t2shfl.mat')

In [None]:
# Index down to the single record of the loaded 'for_t2shfl' variable once,
# then read its fields by position.
t2shfl_record = fort2shfl['for_t2shfl'][0][0]

mask, kspace, coils, truth = (t2shfl_record[field] for field in range(4))

# Number of basis vectors kept for each linear (subspace) reconstruction.
all_num_latent_linear = [2, 3]

In [None]:
# The SVD of the dictionary does not depend on the requested rank, so it is
# computed once and only truncated inside the loop.
# (np.linalg.svd returns the third factor already transposed; it is unused.)
U, S, V = np.linalg.svd(dictionary, full_matrices=False)

all_basis = []

for num_latent_linear in all_num_latent_linear:
    # Keep the leading left singular vectors as the linear temporal basis.
    basis = U[:, :num_latent_linear]

    all_basis.append(basis)

### defining forward model operators and preparing k-space to reconstruct

In [None]:
# Sizes of the four k-space axes; the helper docstrings below refer to the
# first three as M x N x C and to the last one as T.
M, N, C, E = kspace.shape

# Three-level 'db3' wavelet transform with zero padding, used for the l1
# regularisation term of the decoder reconstruction.
xfm = DWTForward(J=3, mode='zero', wave='db3').to(device)

# Complex arrays are carried as real tensors with a trailing (real, imag) axis.
coils_real = torch.tensor(np.real(coils), dtype=torch.float)
coils_imag = torch.tensor(np.imag(coils), dtype=torch.float)
coils_torch = torch.stack((coils_real, coils_imag), dim=-1).to(device).reshape(M, N, C, 1, 2)
coils_torch.requires_grad = False

# The mask is duplicated along the trailing axis so that it multiplies the
# real and imaginary channels alike.
mask_plane = torch.tensor(mask, dtype=torch.float)
mask_torch = torch.stack((mask_plane, mask_plane), dim=-1).to(device)
mask_torch.requires_grad = False

In [None]:
def sfor(x,CC):
    '''
    Forward coil sensitivity operator: complex element-wise product of the
    images with the coil maps, with real/imag stored in the last axis.
    inputs:
        x  (M x N x 1 x T x 2)    - time series of images
        CC (M x N x C x 1 x 2)    - coil sensitivity functions
    output
        out (M x N x C x T x 2)   - coil-weighted images (shapes broadcast)
    '''
    x_re, x_im = x[..., 0], x[..., 1]
    c_re, c_im = CC[..., 0], CC[..., 1]

    # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
    out_re = x_re * c_re - x_im * c_im
    out_im = x_re * c_im + x_im * c_re

    return torch.stack((out_re, out_im), -1)

def ffor(data):
    '''
    Performs a forward, orthonormal 2D Fourier transform over the first two
    dimensions, wrapped as ifftshift(FFT(fftshift(.))).
    inputs:
        data (M x N x C x T x 2)  - real/imag stored in the last axis
    outputs:
        out  (M x N x C x T x 2)
    '''
    shifted = fftshift(data, dim=[0, 1])

    if hasattr(torch.fft, 'fft2'):
        # PyTorch >= 1.8: torch.fft is a module and the legacy torch.fft(...)
        # function no longer exists. fft2 works on complex tensors, so view
        # the trailing (real, imag) axis as complex and back again.
        # .contiguous() guards view_as_complex, which needs a unit stride on
        # the last axis (fftshift may hand back its input unchanged).
        spectrum = torch.fft.fft2(
            torch.view_as_complex(shifted.contiguous()),
            dim=(0, 1),
            norm='ortho',
        )
        transformed = torch.view_as_real(spectrum)
    else:
        # Legacy PyTorch: the function transforms the last two non-complex
        # axes, hence the permutes around the call.
        transformed = torch.fft(
            shifted.permute(2, 3, 0, 1, 4), signal_ndim=2, normalized=True
        ).permute(2, 3, 0, 1, 4)

    return ifftshift(transformed, dim=[0, 1])

def R(data,mask):
    '''
    Undersampling operator: element-wise product of the input with a
    sampling mask (standard broadcasting applies).
    inputs:
        data  (M x N x C x T x 2)     - input to be masked
        mask  (M x N x 1 x T x 1)     - mask to be applied
    output
        masked data, same shape as the broadcast of data and mask
    '''
    masked = data * mask
    return masked

def Dfor(model):
    '''
    Decodes the current latent estimate into a time series of complex images.

    Besides `model`, this reads the notebook-level variables x (latent
    values), xr and xi (per-pixel real and imaginary scalings) and the sizes
    M, N and E; they must be defined before the call.
    input
        model                     - network exposing a decode() method
    output
        out (M x N x 1 x E x 2)   - decoded signal times xr (real channel)
                                    and times xi (imaginary channel)
    '''
    decoded = model.decode(x)

    real_part = decoded * xr
    imag_part = decoded * xi

    stacked = torch.stack((real_part, imag_part), -1)
    return stacked.reshape(M,N,1,E,2)
    
def Wfor_dec(x,W):
    '''
    l1 norm of the wavelet coefficients of a set of coefficient maps.
    inputs:
        x (M*N x L)  - set of coefficients to go through the forward operator;
                       reshaped into L images of size M x N using the
                       notebook-level M and N
        W            - pytorch wavelet operator returning the lowpass
                       coefficients and an iterable of highpass bands
    output
        scalar tensor holding the sum of absolute values of all coefficients
    '''
    num_maps = x.shape[1]

    # (M*N, L) -> (1, L, M, N), the layout handed to the wavelet operator.
    images = x.reshape(M,N,num_maps,1).permute(3,2,0,1)
    lowpass, highpass_bands = W(images)

    total = torch.sum(torch.abs(lowpass))
    for band in highpass_bands:
        total = total + torch.sum(torch.abs(band))

    return total

In [None]:
# Complex k-space as a real tensor with a trailing (real, imag) axis.
kspace_real = torch.tensor(np.real(kspace), dtype=torch.float)
kspace_imag = torch.tensor(np.imag(kspace), dtype=torch.float)

kspace_torch = torch.stack((kspace_real, kspace_imag), dim=-1).to(device)
kspace_torch.requires_grad = False

# Apply the sampling mask so that only the masked samples enter the data term.
kspace_torch = R(kspace_torch, mask_torch)

### decoder, wavelet regularized reconstruction

In [None]:
model_path = 'data/model_layers_%d_nonlin_%s_latent_%d_epochs_%d_rep%d.pt' % \
    (model_layers,nonlinearity,latent_num,epoch,rep)

# map_location remaps the stored tensors onto `device`, so a model that was
# saved from a GPU can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# refuses to unpickle a full model object; confirm against the installed version.
model = torch.load(model_path, map_location=device).to(device)

# The decoder is fixed during reconstruction; only the latent maps are optimised.
for param in model.parameters():
    param.requires_grad = False

In [None]:
# Wavelet regularisation weights to try, and the number of Adam steps per weight.
lambdas_decoder = [8e-13]

iterations  = 8000

LD = len(lambdas_decoder)

# Results for every regularisation weight are collected along the last axis.
all_T2   = np.zeros((M,N,latent_num,LD))
all_real = np.zeros((M,N,LD))
all_imag = np.zeros((M,N,LD))
all_timeseries_decoder = np.zeros((M,N,E,LD),dtype=complex)

for ctr, lam in enumerate(lambdas_decoder):

    # Unknowns, one row per pixel: latent value(s) x plus a real (xr) and an
    # imaginary (xi) scaling. They are notebook-level names on purpose:
    # Dfor() reads x, xr and xi from the enclosing scope.
    x  = torch.zeros(M*N,latent_num,device=device,requires_grad=True)
    xr = torch.ones(M*N,1,device=device,requires_grad=True)
    xi = torch.zeros(M*N,1,device=device,requires_grad=True)

    criterion   = nn.MSELoss()
    optimizer   = optim.Adam([x,xr,xi],lr = 1e0)

    # `it` rather than `iter`, so the builtin iter() is not shadowed for the
    # rest of the notebook.
    for it in range(iterations):
        optimizer.zero_grad()

        # Data consistency on the sampled k-space plus l1-wavelet penalties.
        # NOTE(review): only x and xr are wavelet-regularised, xi is not;
        # confirm this is intended.
        loss   = criterion(kspace_torch,R(ffor(sfor(Dfor(model),coils_torch)),mask_torch)) + \
            lam * Wfor_dec(x,xfm) + lam * Wfor_dec(xr,xfm)

        loss.backward()
        optimizer.step()

        if it % 100 == 0:
            # .item() synchronises with the device, so only fetch the value
            # when it is actually printed.
            running_loss = loss.item()
            print('iteration %d / %d, current loss: %.12f' % (it,iterations,running_loss))

    # Explicit reshapes instead of squeeze(): squeeze would also drop M, N or
    # E whenever one of them happens to be 1.
    T2    = x.detach().reshape(M,N,latent_num).cpu().numpy()
    real  = xr.detach().reshape(M,N).cpu().numpy()
    imag  = xi.detach().reshape(M,N).cpu().numpy()

    # Dfor returns M x N x 1 x E x 2; drop the singleton axis and convert the
    # trailing (real, imag) axis to complex.
    timeseries_decoder = tonpy(Dfor(model).detach().reshape(M,N,E,2).contiguous())

    all_T2[:,:,:,ctr]  = T2
    all_real[:,:,ctr]  = real
    all_imag[:,:,ctr]  = imag

    all_timeseries_decoder[:,:,:,ctr] = timeseries_decoder

### quick display of decoder reconstruction for a particular regularization value

In [None]:
regidx = 0

# First latent map, real scaling and imaginary scaling for the chosen
# regularisation index, stacked as (3, M, N) for the mosaic.
panels = (all_T2[:,:,0,regidx], all_real[:,:,regidx], all_imag[:,:,regidx])
display = np.stack(panels, axis=-1).transpose(2,0,1)

sig.mosaic(sig.nor(display),1,2 + latent_num)

### wavelet regularized linear reconstruction with BART

In [None]:
import cfl

# Regularisation weights and iteration count handed to BART's `pics`.
lambdas = [3e-5]


iter_bart = 4000

nB = len(all_basis)
nR = len(lambdas)

all_coeffs_bart     = []
all_timeseries_bart = []

#-preparing preliminaries for BART reconstruction
# The masked k-space and the coil maps do not depend on the basis, so they
# are reshaped once, outside the loop.
kspace_bart = (kspace * mask).reshape(M,N,1,C,1,E)
coils_bart  = coils.reshape(M,N,1,C,1)

for ctr, basis in enumerate(all_basis, start=1):
    B = basis.shape[1]
    print('reconstruction %d || rank %d' % (ctr,B))

    # The basis is written to disk and handed to `pics` through its -B option.
    basis_bart  = basis.reshape(1,1,1,1,1,E,B)
    cfl.writecfl('data/basis_bart',basis_bart)

    cur_coeffs_bart = np.zeros((M,N,B,nR),dtype = complex)
    cur_timeseries  = np.zeros((M,N,E,nR),dtype = complex)

    for ctr_r, lam in enumerate(lambdas):
        # %g keeps small weights intact; %f would print anything below 5e-7
        # as 0.000000 and silently switch the regularisation off.
        bartstr = 'pics -B data/basis_bart -i %d -l1 -r %g' % (iter_bart,lam)

        print('bart string: ' + bartstr)
        # reshape instead of squeeze: squeeze would also drop the coefficient
        # axis when B == 1 (or M / N when either is 1).
        coeffs_bart = np.asarray(bart.bart(1,bartstr,kspace_bart,coils_bart)).reshape(M,N,B)

        cur_coeffs_bart[:,:,:,ctr_r] = coeffs_bart
        # Expand the coefficients back to a time series: (E x B) @ (B x M*N).
        timeseries = (basis @ coeffs_bart.transpose(2,0,1).reshape(B,M*N)).reshape(E,M,N).transpose(1,2,0)
        cur_timeseries[:,:,:,ctr_r] = timeseries

    print(' done')


    all_coeffs_bart.append(cur_coeffs_bart)
    all_timeseries_bart.append(cur_timeseries)

### comparisons

In [None]:
timeseries_truth = np.abs(np.squeeze(truth))

#Masking out results outside of the brain
# 1.0 wherever the first truth image is non-zero, 0.0 elsewhere; the trailing
# singleton axis lets the mask broadcast over the last axis of the time series.
brain_mask = (np.abs(truth[:,:,0,0]) > 0).astype(float)[:,:,np.newaxis]

rmse_decoder_all = np.zeros((E,LD))
truth_comp       = timeseries_truth * brain_mask

# RMSE per index along the last axis, for every decoder regularisation weight.
for ll in range(len(lambdas_decoder)):
    decoder = all_timeseries_decoder[:,:,:,ll] * brain_mask

    for ee in range(E):
        truth_echo   = sig.nor(truth_comp[:,:,ee])
        decoder_echo = sig.nor(np.abs(decoder[:,:,ee]))
        rmse_decoder_all[ee,ll] = sig.rmse(truth_echo,decoder_echo)

In [None]:
# What to show: which linear basis (index into all_timeseries_bart), which
# regularisation index for each method, and which index along the last
# (time-series) axis.
lin_idx      = 1
reg_idx_bart = 0
reg_idx_dec  = 0
ech          = 20

# ech:ech+1 keeps a singleton last axis, so each image stays M x N x 1 and
# can be concatenated along that axis below.
truth_show   = sig.nor(timeseries_truth[:,:,ech:ech+1]*brain_mask)
decoder_show = sig.nor(np.abs(all_timeseries_decoder[:,:,ech:ech+1,reg_idx_dec])*brain_mask)
bart_show    = sig.nor(np.abs(all_timeseries_bart[lin_idx][:,:,ech:ech+1,reg_idx_bart]*brain_mask))

# Only the BART and decoder images go into the mosaic; truth_show serves
# solely as the reference for the RMSE values printed below.
display = np.concatenate((bart_show,decoder_show),axis = -1).transpose(2,0,1)

sig.mosaic(sig.nor(display),1,2,clim=[0,1])

# RMSE values are printed scaled by 100.
print('  bart:      %.2f' % (sig.rmse(truth_show,bart_show)*100))
print('  proposed:  %.2f' % (sig.rmse(truth_show,decoder_show)*100))

In [None]:
# Regularisation indices used for the BART and decoder results compared below.
reg_idx_bart = 0
reg_idx_dec  = 0

rmse_decoder = np.zeros(E)
rmse_bart    = np.zeros((E,len(all_num_latent_linear)))

decoder    = all_timeseries_decoder[:,:,:,reg_idx_dec] * brain_mask
truth_comp = timeseries_truth * brain_mask

# RMSE of the decoder reconstruction at every index along the last axis.
for ee in range(E):
    rmse_decoder[ee] = sig.rmse(sig.nor(truth_comp[:,:,ee]),sig.nor(np.abs(decoder[:,:,ee])))

# The same for each linear (BART) reconstruction, one column per basis rank.
for ll in range(len(all_num_latent_linear)):
    bart_comp    = all_timeseries_bart[ll][:,:,:,reg_idx_bart] * brain_mask

    for ee in range(E):
        rmse_bart[ee,ll] = sig.rmse(sig.nor(truth_comp[:,:,ee]),sig.nor(np.abs(bart_comp[:,:,ee])))

plt.plot(rmse_bart)
plt.plot(rmse_decoder)

# Legend entries follow the plotting order: linear ranks first, then the decoder.
legend = ['linear %d' % rank for rank in all_num_latent_linear]
legend.append('proposed %d' % latent_num)

plt.legend(legend)
plt.xlabel('echo index')
plt.ylabel('rmse')
plt.show()

#printing out average rmse
print('average rmse:')
print('  pro %d: %.2f' % (latent_num,100*np.mean(rmse_decoder)))
for ll in range(len(all_num_latent_linear)):
    print('  lin %d: %.2f' % (all_num_latent_linear[ll],100*np.mean(rmse_bart[:,ll])))