In [None]:
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import scipy.io as sio
import numpy as np
import time 
import matplotlib.pyplot as plt

import bart
import cfl

from python_utils import models
from python_utils import signalprocessing as sig

from pytorch_wavelets import DTCWTForward, DTCWTInverse, DWTForward,DWTInverse

### defining some helper functions

In [None]:
def roll_one_dim(x: torch.Tensor, shift: int, dim: int) -> torch.Tensor:
    """
    Circularly shift a tensor along a single dimension (torch.roll for one dim).
    Args:
        x: A PyTorch tensor.
        shift: Number of positions to roll; reduced modulo the size of ``dim``,
            so negative values roll the other way.
        dim: Dimension along which to roll.
    Returns:
        The rolled tensor. When the effective shift is zero, ``x`` itself is
        returned (no copy).
    """
    size = x.size(dim)
    k = shift % size
    if k:
        # The last k entries move to the front; the first size-k move to the back.
        split = size - k
        head = x.narrow(dim, 0, split)
        tail = x.narrow(dim, split, k)
        return torch.cat((tail, head), dim=dim)
    return x

def roll(
    x: torch.Tensor,
    shift: List[int],
    dim: List[int],
) -> torch.Tensor:
    """
    Roll a tensor along several dimensions, like np.roll with sequence arguments.
    Args:
        x: A PyTorch tensor.
        shift: Roll amount for each entry of ``dim`` (reduced modulo the size
            of that dimension by roll_one_dim).
        dim: Dimensions to roll, paired element-wise with ``shift``.
    Returns:
        x rolled by shift[i] along dim[i] for every i, applied in order.
    Raises:
        ValueError: if ``shift`` and ``dim`` have different lengths.
    """
    if len(shift) == len(dim):
        rolled = x
        for amount, axis in zip(shift, dim):
            rolled = roll_one_dim(rolled, amount, axis)
        return rolled
    raise ValueError("len(shift) must match len(dim)")


def fftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Move the zero-frequency entry to the center, mirroring np.fft.fftshift.
    Args:
        x: A PyTorch tensor.
        dim: Dimensions to shift; every dimension of ``x`` when None.
    Returns:
        The shifted tensor: each listed dimension rolled by floor(size / 2).
    """
    if dim is None:
        # Fill an int list element by element rather than list(range(...)):
        # torch.jit.script needs a concretely typed List[int] here.
        dim = [0] * x.dim()
        for axis in range(x.dim()):
            dim[axis] = axis

    # Preallocate and assign, again to keep torch.jit.script typing happy.
    amounts = [0] * len(dim)
    for k in range(len(dim)):
        amounts[k] = x.shape[dim[k]] // 2

    return roll(x, amounts, dim)


def ifftshift(x: torch.Tensor, dim: Optional[List[int]] = None) -> torch.Tensor:
    """
    Undo fftshift, mirroring np.fft.ifftshift.
    Args:
        x: A PyTorch tensor.
        dim: Dimensions to shift; every dimension of ``x`` when None.
    Returns:
        The shifted tensor: each listed dimension rolled by ceil(size / 2).
    """
    if dim is None:
        # Fill an int list element by element rather than list(range(...)):
        # torch.jit.script needs a concretely typed List[int] here.
        dim = [0] * x.dim()
        for axis in range(x.dim()):
            dim[axis] = axis

    # Preallocate and assign, again to keep torch.jit.script typing happy.
    amounts = [0] * len(dim)
    for k in range(len(dim)):
        amounts[k] = (x.shape[dim[k]] + 1) // 2

    return roll(x, amounts, dim)

def tonpy(x):
    '''
    Convert a real tensor whose last dim (size 2) holds real/imag parts into a
    complex NumPy array on the CPU.
    '''
    as_complex = torch.view_as_complex(x)
    return as_complex.cpu().numpy()

### setting parameters and loading

In [None]:
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Identifiers used to build the path of the trained decoder checkpoint.
model_type     = 3
latent_num     = 1
epoch          = 200000

# Monte Carlo noise study: number of noise instances, and the noise level.
# NOTE: despite its name, noivar is passed as the scale (standard deviation)
# of np.random.normal for each of the real and imaginary noise parts.
nmonte = 2
noivar = .003

In [None]:
# Acquisition data read with cfl.readcfl: sampling mask, multi-coil k-space and
# coil sensitivity maps.
mask    = cfl.readcfl('data/mask')
kspace  = cfl.readcfl('data/kspace')
coils   = cfl.readcfl('data/coils')

# Dictionary of signal evolutions. Its rows must equal the echo count E, because
# the bases built from it are later reshaped to (E x B).
dictionary = sio.loadmat('data/dictionary.mat')['dictionary']

# Sizes of the linear (SVD) subspaces the decoder is compared against.
all_num_latent_linear = [2, 3]

In [None]:
# Linear subspace bases = leading left singular vectors of the dictionary.
# The SVD does not depend on the subspace size, so it is computed once, outside
# the loop. (numpy returns the right factor already conjugate-transposed, i.e. V is V^H.)
[U,S,V] = np.linalg.svd(dictionary,full_matrices=False)

all_basis = []
for num_latent_linear in all_num_latent_linear:
    # Slicing past the available columns would silently return a smaller basis.
    if num_latent_linear > U.shape[1]:
        raise ValueError('requested %d basis vectors but the dictionary SVD has only %d'
                         % (num_latent_linear, U.shape[1]))
    basis   = U[:,:num_latent_linear]

    all_basis.append(basis)

### defining operators and preparing k-space

In [None]:
M, N, C, E = kspace.shape

# 3-level db3 discrete wavelet transform.
# NOTE(review): xfm is not referenced in the visible cells; confirm before removing.
xfm = DWTForward(J=3, mode='zero', wave='db3').to(device)

# Coil maps as float tensors with real/imag stacked in a trailing dim, shaped
# (M x N x C x 1 x 2) so they broadcast against (M x N x 1 x T x 2) images in sfor.
coils_torch = torch.stack([torch.tensor(part(coils), dtype=torch.float) for part in (np.real, np.imag)],
                          dim=-1).to(device).reshape(M, N, C, 1, 2)
coils_torch.requires_grad = False

# The mask (cast to float) duplicated into both trailing slots so it scales the
# real and imaginary parts alike.
mask_torch = torch.stack([torch.tensor(mask, dtype=torch.float) for _ in range(2)],
                         dim=-1).to(device)
mask_torch.requires_grad = False

In [None]:
def sfor(x,CC):
    '''
    Forward coil-sensitivity operator: complex multiply each image by each coil map,
    with complex values stored as real/imag in the last dim.
    inputs:
        x  (M x N x 1 x T x 2)    - time series of images
        CC (M x N x C x 1 x 2)    - coil sensitivity functions
    output
        out (M x N x C x T x 2)   - coil images (shape follows from broadcasting)
    '''
    x_re, x_im = x[..., 0], x[..., 1]
    c_re, c_im = CC[..., 0], CC[..., 1]
    real_part = x_re * c_re - x_im * c_im
    imag_part = x_re * c_im + x_im * c_re
    return torch.stack((real_part, imag_part), -1)

def ffor(data):
    '''
    Performs a forward, orthonormally scaled 2-D Fourier transform over the two
    spatial dims, with fftshift applied before and ifftshift applied after.
    inputs:
        data (M x N x C x T x 2)   - real/imag parts stacked in the last dim
    outputs:
        out  (M x N x C x T x 2)   - k-space, same layout and dtype as data
    '''
    if hasattr(torch, 'fft') and hasattr(torch.fft, 'fft2'):
        # PyTorch >= 1.8: torch.fft is a module and the legacy torch.fft() function
        # is gone. norm='ortho' matches the legacy normalized=True (1/sqrt(M*N)).
        spatial = (0, 1)
        cplx = torch.view_as_complex(data.contiguous())
        kspace_c = torch.fft.ifftshift(
            torch.fft.fft2(torch.fft.fftshift(cplx, dim=spatial), dim=spatial, norm='ortho'),
            dim=spatial)
        return torch.view_as_real(kspace_c)

    # Legacy API (PyTorch <= 1.7): it transforms the last `signal_ndim` dims before the
    # trailing size-2 dim, so the spatial dims are permuted to the back and then restored.
    return ifftshift(torch.fft(fftshift(data,dim = [0,1]).permute(2,3,0,1,4),signal_ndim = 2,normalized = True).permute(2,3,0,1,4),dim = [0,1])

def R(data,mask):
    '''
    Undersampling operator: elementwise product of the data with the mask
    (standard broadcasting applies).
    inputs:
        data  (M x N x C x T x 2)     - input to be masked
        mask  broadcastable to data   - e.g. (M x N x 1 x T x 1)
    output:
        the masked data
    '''
    masked = data * mask
    return masked

def Dfor(model, latent=None, scale_real=None, scale_imag=None):
    '''
    Decodes per-pixel latent codes into a complex image time series using the
    decoder trained to compress signal evolutions.
    inputs:
        model              - neural network model; only model.decode(...) is used
        latent     (M*N x L) - latent codes; defaults to the module-level `x`
        scale_real (M*N x 1) - per-pixel real scale; defaults to the module-level `xr`
        scale_imag (M*N x 1) - per-pixel imaginary scale; defaults to the module-level `xi`
    output
        out (M x N x 1 x E x 2) - real/imag stacked in the last dim
                                  (M, N and E are read from module scope)
    '''
    # Fall back to the notebook globals so the existing call Dfor(model) keeps working.
    if latent is None:
        latent = x
    if scale_real is None:
        scale_real = xr
    if scale_imag is None:
        scale_imag = xi

    # NOTE(review): decode output presumably (M*N x E); the (M*N x 1) scales
    # broadcast across echoes. Confirm against the model definition.
    out = model.decode(latent)
    return torch.stack((out * scale_real,out * scale_imag),-1).reshape(M,N,1,E,2)

### Reconstruction loop over noise instances (decoder vs. BART linear subspace)

In [None]:
model_path = 'data/model_type_%d_latent_%d_epochs_%d.pt' % (model_type,latent_num,epoch)
# The checkpoint holds a whole pickled nn.Module, so torch.load unpickles arbitrary
# objects: only load trusted files. map_location lets a checkpoint saved on a GPU load
# on a CPU-only machine.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which
# rejects pickled modules; pass weights_only=False there (trusted files only).
model = torch.load(model_path, map_location=device).to(device)

# The decoder is frozen: only the latent codes are optimized below.
for param in model.parameters():
    param.requires_grad = False

In [None]:
iter_deco   = 4000
iter_bart   = 100
noise_seed  = 0     # fixes the noise realizations so Restart & Run All is reproducible

#-preparations for bart reconstruction
all_timeseries_bart    = np.zeros((M,N,E,nmonte,len(all_basis)),dtype = complex)
all_timeseries_decoder = np.zeros((M,N,E,nmonte),dtype = complex)

coils_bart  = coils.reshape(M,N,1,C,1)

noise_rng = np.random.RandomState(noise_seed)

for nm in range(nmonte):
    print('monte %d/%d' % (nm+1,nmonte))
    # Complex Gaussian noise; noivar is the scale (std) of both real and imaginary parts.
    cur_noise = noise_rng.normal(0,noivar,size = kspace.shape) \
        + 1j * noise_rng.normal(0,noivar,size = kspace.shape)
    kspace_noise = kspace + cur_noise

    #DECODER RECONSTRUCTIONS
    print('  decoder recon')
    kspace_torch = torch.stack((torch.tensor(np.real(kspace_noise),dtype=torch.float),
                                torch.tensor(np.imag(kspace_noise),dtype=torch.float)),dim=-1).to(device)
    kspace_torch = R(kspace_torch,mask_torch)

    # Optimization variables. They must remain module-level names because
    # Dfor(model) reads x, xr and xi as globals.
    x  = torch.zeros(M*N,latent_num,device=device,requires_grad=True)
    xr = torch.zeros(M*N,1,device=device,requires_grad=True)
    xi = torch.zeros(M*N,1,device=device,requires_grad=True)

    criterion   = nn.MSELoss()
    optimizer   = optim.Adam([x,xr,xi],lr = 1e-1)

    for it in range(iter_deco):
        optimizer.zero_grad()

        # Data consistency: masked k-space of the decoded, coil-weighted images vs. measured.
        loss   = criterion(kspace_torch,R(ffor(sfor(Dfor(model),coils_torch)),mask_torch))

        loss.backward()
        optimizer.step()

        # loss.item() forces a device sync, so only call it when reporting.
        if it % 250 == 0:
            print('    iteration %d / %d, current loss: %.12f' % (it,iter_deco,loss.item()))

    # Drop the singleton coil axis by indexing (not squeeze) so E == 1 still works.
    timeseries_decoder = tonpy(Dfor(model).detach()[:,:,0].contiguous())
    all_timeseries_decoder[:,:,:,nm] = timeseries_decoder

    #BART RECONSTRUCTIONS
    kspace_bart = (kspace_noise * mask).reshape(M,N,1,C,1,E)

    for ctr_b, basis in enumerate(all_basis):
        B = basis.shape[1]
        print('  basis %d/%d' % (ctr_b+1,len(all_basis)))

        basis_bart  = basis.reshape(1,1,1,1,1,E,B)
        cfl.writecfl('data/basis_bart',basis_bart)

        #-no regularization bart string
        bartstr = 'pics -B data/basis_bart -i %d' % (iter_bart)

        #-reconstruction: subspace coefficients reshaped (not squeezed) to M x N x B so a
        # single-vector basis (B == 1) keeps its coefficient axis
        coeffs_bart     = bart.bart(1,bartstr,kspace_bart,coils_bart).reshape(M,N,B)
        timeseries_bart = (basis @ coeffs_bart.transpose(2,0,1).reshape(B,M*N)).reshape(E,M,N).transpose(1,2,0)

        all_timeseries_bart[:,:,:,nm,ctr_b] = timeseries_bart

print('DONE')

### comparisons to fully-sampled

#### Applying the brain mask

In [None]:
timeseries_truth = np.abs(cfl.readcfl('data/timeseries_truth'))

# Brain mask (M x N x 1, float 0/1): pixels whose first-echo truth magnitude exceeds 0.1.
# timeseries_truth is already a magnitude, so no further abs is needed.
brain_mask = np.where(timeseries_truth[:, :, 0] > .1, 1.0, 0.0).reshape(M, N, 1)

# Zero everything outside the brain; the extra singleton axes broadcast over echoes,
# noise instances and (for BART) bases.
all_timeseries_bart    = all_timeseries_bart * brain_mask[:, :, :, np.newaxis, np.newaxis]
all_timeseries_decoder = all_timeseries_decoder * brain_mask[:, :, :, np.newaxis]
truth                  = timeseries_truth * brain_mask

#### Computing the RMSE at each echo

In [None]:
rmse_decoder = np.zeros((E,nmonte))
rmse_bart    = np.zeros((E,nmonte,len(all_basis)))

for ee in range(E):
    # Magnitude truth for this echo, normalized with sig.nor like every reconstruction.
    truth_comp = sig.nor(np.abs(truth[:,:,ee]))
    for nm in range(nmonte):
        rmse_decoder[ee,nm] = sig.rmse(truth_comp,sig.nor(np.abs(all_timeseries_decoder[:,:,ee,nm])))

        for bb in range(len(all_basis)):
            rmse_bart[ee,nm,bb] = sig.rmse(truth_comp,sig.nor(np.abs(all_timeseries_bart[:,:,ee,nm,bb])))

#-plotting average rmse for each echo just to get an idea of what's going on
fig, ax = plt.subplots()
ax.plot(np.mean(rmse_decoder,1), label='dec')
for bb in range(len(all_basis)):
    # Label each curve with its actual subspace size instead of assuming [2, 3, ...].
    ax.plot(np.mean(rmse_bart[:,:,bb],1), label='lin %d' % all_basis[bb].shape[1])
ax.set(xlabel='echo index', ylabel='RMSE (mean over noise instances)',
       title='Per-echo RMSE vs. fully-sampled truth')
ax.legend()
plt.show()

#### -computing average error maps

In [None]:
def _relative_error_and_spread(truth_img, recons):
    '''
    Per-pixel relative RMSE and variance of the absolute error across noise instances.
    inputs:
        truth_img (M x N)          - normalized magnitude truth for one echo
        recons    (M x N x K)      - reconstructions of that echo for K noise instances
    outputs:
        error    (M x N) - sqrt(sum_k d_k) / sqrt(K * truth^2); 0 where truth == 0 or NaN
        variance (M x N) - variance over k of sqrt(d_k), with d_k = (truth - nor|recon_k|)^2
    '''
    n_inst = recons.shape[-1]
    sq_diff = np.zeros(truth_img.shape + (n_inst,))
    for k in range(n_inst):
        sq_diff[:,:,k] = (truth_img - sig.nor(np.abs(recons[:,:,k])))**2

    numer = np.sqrt(np.sum(sq_diff,-1))
    denom = np.sqrt(n_inst*(truth_img**2))
    # Skip zero-truth pixels instead of producing inf (which nan_to_num would turn
    # into ~1.8e308) and RuntimeWarnings; remaining NaNs become 0 as before.
    error = np.nan_to_num(np.divide(numer, denom, out=np.zeros_like(numer), where=denom > 0))
    variance = np.var(np.sqrt(sq_diff),-1)
    return error, variance


error_decoder = np.zeros((M,N,E))
error_bart    = np.zeros((M,N,E,len(all_basis)))

variance_decoder = np.zeros((M,N,E))
variance_bart    = np.zeros((M,N,E,len(all_basis)))

for ee in range(E):
    truth_comp = sig.nor(np.abs(truth[:,:,ee]))

    error_decoder[:,:,ee], variance_decoder[:,:,ee] = \
        _relative_error_and_spread(truth_comp, all_timeseries_decoder[:,:,ee,:])

    for bb in range(len(all_basis)):
        error_bart[:,:,ee,bb], variance_bart[:,:,ee,bb] = \
            _relative_error_and_spread(truth_comp, all_timeseries_bart[:,:,ee,:,bb])

### quick comparison of error maps

In [None]:
ech = 10
cl  = [0,1]
assert 0 <= ech < E, 'echo index %d out of range for %d echoes' % (ech, E)

# Decoder error map first, then one map per linear basis. Named error_panels
# (not `display`) to avoid shadowing IPython's built-in display().
error_panels = np.concatenate((np.expand_dims(error_decoder[:,:,ech],axis=0),
                               error_bart[:,:,ech,:].transpose(2,0,1)),axis=0)

# NOTE(review): assumes sig.mosaic(images, rows, cols, ...); confirm argument order.
sig.mosaic(error_panels,1,error_panels.shape[0],clim=cl)

### Showing a single reconstruction for a given noise instance and echo

In [None]:
ech = 10
nm  = 0
cl  = [0,1]
assert 0 <= ech < E and 0 <= nm < nmonte, 'echo %d / instance %d out of range' % (ech, nm)

# Decoder reconstruction first, then one per linear basis. Named recon_panels
# (not `display`) to avoid shadowing IPython's built-in display().
# NOTE(review): the BART panels go through a single sig.nor call jointly, while the
# error cells normalize each image separately; confirm this is intended.
recon_panels = np.concatenate((sig.nor(np.expand_dims(all_timeseries_decoder[:,:,ech,nm],axis=0)),
                               sig.nor(all_timeseries_bart[:,:,ech,nm,:]).transpose(2,0,1)),axis=0)

# NOTE(review): assumes sig.mosaic(images, rows, cols, ...); confirm argument order.
sig.mosaic(recon_panels,1,recon_panels.shape[0],clim=cl)