<a href="https://colab.research.google.com/github/CARDIAL-nyu/cmr-playground/blob/main/golden_angle_sample/OCMR_Fully_Sampled_to_DynamicRadCineMRI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup Google Drive for storage

In [None]:
# Mount Google Drive so that cloned repositories and saved reconstructions
# persist across Colab sessions. `sys` is used by the pip cells that follow.
import sys
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Install the older torch / torchkbnufft versions that DynamicRadCineMRI requires.
# NOTE(review): torch==1.6.0 wheels may not exist for the Python version shipped
# with current Colab runtimes — confirm this install still succeeds.
!{sys.executable} -m pip install torch==1.6.0 torchvision torchtext torchkbnufft==0.3.4

In [None]:
# List installed package versions (useful when reproducing this environment).
!{sys.executable} -m pip freeze

# Dependencies

In [None]:
# NOTE(review): the `notebook` matplotlib backend may not render inside Google
# Colab — confirm, or switch to `inline` if figures do not appear.
%matplotlib notebook

# Generic scientific python
import numpy as np
from matplotlib import pyplot as plt
# NOTE(review): `cp` does not appear to be used anywhere in this notebook — confirm before removing.
import cupy as cp
import pandas as pd

# Aesthetics: FiveThirtyEight style and palette, with grid lines disabled
import seaborn as sns
import matplotlib as mpl
plt.style.use('fivethirtyeight')
# Color cycle used by seaborn for the plots below
five_thirty_eight = [
"#30a2da",
"#fc4f30",
"#e5ae38",
"#6d904f",
"#8b8b8b",
]
sns.set_palette(five_thirty_eight, color_codes=True)
sns.set_color_codes()
mpl.rcParams['axes.grid'] = False

# ISMRM Tools: ISMRMRD file reader plus the k-space <-> image transform helpers
import sys
!{sys.executable} -m pip install git+https://github.com/ismrmrd/ismrmrd-python-tools.git git+https://github.com/ismrmrd/ismrmrd-python.git --no-binary ismrmrd
import ismrmrdtools
import ismrmrdtools.coils
import ismrmrd
import ismrmrd.xsd
from ismrmrdtools import show, transform

# AWS and other MRI tools
!{sys.executable} -m pip install boto3==1.19.12 watermark

# AWS: anonymous (unsigned) access to the public OCMR bucket
import boto3
from botocore import UNSIGNED
from botocore.client import Config

# Progress Bar
# NOTE(review): `tqdm_notebook` / `tnrange` are deprecated aliases in recent tqdm
# releases (tqdm.notebook.tqdm / trange) — confirm the installed version still provides them.
from tqdm import tqdm_notebook as tqdm
from tqdm import tnrange as trange

# Standard library
import glob
import itertools
import tempfile
import datetime
import sys
import os

# MRI libraries: sigpy (whitening, ESPIRiT) and mrrt.mri (PCA coil compression)
!{sys.executable} -m pip install sigpy h5py git+https://github.com/mritools/mrrt.mri.git
import sigpy as sp
import sigpy.plot as pl
import sigpy.mri as mri
from mrrt.mri.coils import coil_pca, apply_pca_weights

# Visualization
from IPython.core.display import HTML, Video
import matplotlib.animation as animation
from skimage import exposure

# DynamicRadCineMRI requires these older library versions; the pinned install
# itself runs in an earlier cell:
#!{sys.executable} -m pip install torch==1.6.0 torchkbnufft==0.3.4

In [None]:
# List installed package versions again after all of the installs above.
!{sys.executable} -m pip freeze

# Utility Functions
## Major utilities
 - `read_ocmr`: reads OCMR k-space (based on the sample code that accompanies the OCMR data)
 - `return_ekg_from_ismrmrd`: extracts the EKG from the ISMRMRD file, if available
 - `view_cine`: displays an HTML5 video player for visualizing a cine

In [None]:
def _ocmr_encoding_count(limits, name):
    """Return ``limits.<name>.maximum + 1``, or 1 if the header lacks that limit.

    A missing limit shows up either as a missing/``None`` attribute
    (AttributeError on ``.maximum``) or a ``None`` maximum (TypeError on ``+ 1``).
    """
    try:
        return getattr(limits, name).maximum + 1
    except (AttributeError, TypeError):
        return 1


def read_ocmr(filename, extrareturns=False):
    """Read Cartesian k-space and scan parameters from an OCMR ISMRMRD ``.h5`` file.

    Based on the OCMR sample reader by Chong Chen (Chong.Chen@osumc.edu, last
    modified 06-12-2020), itself derived from ``recon_ismrmrd_dataset.py`` in
    https://github.com/ismrmrd/ismrmrd-python-tools. Requires ``ismrmrd`` and
    ``ismrmrd-python-tools``.

    Parameters
    ----------
    filename : str
        Path to the ISMRMRD ``.h5`` file.
    extrareturns : bool
        If True, also return the encoding header, the first imaging
        acquisition, the noise scans and the EKG table.

    Returns
    -------
    all_data : ndarray of complex64
        k-space with axes (kx, ky, kz, coil, phase, set, slice, rep, avg).
    param : dict
        Selected scan parameters (TR, FOV, TE, TI, echo spacing, flip angle,
        sequence type, k-space axis labels).
    enc, acq_first, noise_scan, twix_ekg
        Only when ``extrareturns`` is True: ``header.encoding[0]``, the first
        non-noise acquisition, the noise-scan samples stacked into an ndarray
        of shape (n_noise_acq, coil, samples), and the DataFrame from
        ``return_ekg_from_ismrmrd`` (None if the EKG could not be read).

    Raises
    ------
    FileNotFoundError
        If ``filename`` is not an existing file.
    ValueError
        If the file contains no imaging (non-noise) acquisitions.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"{filename} is not a valid file")

    dset = ismrmrd.Dataset(filename, 'dataset', create_if_needed=False)
    try:
        header = ismrmrd.xsd.CreateFromDocument(dset.read_xml_header())
        enc = header.encoding[0]
        limits = enc.encodingLimits

        # Matrix size; Ny comes from the encoding limits (no zero padding along Ny)
        eNx = enc.encodedSpace.matrixSize.x
        eNz = enc.encodedSpace.matrixSize.z
        eNy = limits.kspace_encoding_step_1.maximum + 1

        fov = enc.encodedSpace.fieldOfView_mm
        seq = header.sequenceParameters
        param = {
            'TRes': str(seq.TR),
            'FOV': [fov.x, fov.y, fov.z],
            'TE': str(seq.TE),
            'TI': str(seq.TI),
            'echo_spacing': str(seq.echo_spacing),
            'flipAngle_deg': str(seq.flipAngle_deg),
            'sequence_type': seq.sequence_type,
        }

        # Number of coils, slices, repetitions, phases, sets and averages
        nCoils = header.acquisitionSystemInformation.receiverChannels
        nSlices = _ocmr_encoding_count(limits, 'slice')
        nReps = _ocmr_encoding_count(limits, 'repetition')
        nPhases = _ocmr_encoding_count(limits, 'phase')
        nSets = _ocmr_encoding_count(limits, 'set')
        nAverage = _ocmr_encoding_count(limits, 'average')

        n_acq = dset.number_of_acquisitions()

        # Collect the noise measurements that precede the first imaging acquisition
        noise_scan = []
        for acqnum in trange(n_acq):
            acq = dset.read_acquisition(acqnum)
            if acq.isFlagSet(ismrmrd.ACQ_IS_NOISE_MEASUREMENT):
                noise_scan.append(acq.data)
                continue
            firstacq = acqnum
            print("Imaging acquisition starts acq ", acqnum)
            break
        else:
            raise ValueError(f"{filename} contains no imaging acquisitions")
        noise_scan = np.asarray(noise_scan)

        # Asymmetric echo: fewer samples than eNx were acquired, so the acquired
        # samples are placed at the end of the kx axis (leading samples stay zero).
        acq_first = dset.read_acquisition(firstacq)
        kx_prezp = 0
        if acq_first.center_sample * 2 < eNx:
            kx_prezp = eNx - acq_first.number_of_samples

        # NOTE(review): a one-element set (ported from a MATLAB cell array); kept for compatibility.
        param['kspace_dim'] = {'kx ky kz coil phase set slice rep avg'}
        all_data = np.zeros((eNx, eNy, eNz, nCoils, nPhases, nSets, nSlices, nReps, nAverage),
                            dtype=np.complex64)

        # Place every remaining acquisition at its encoding indices
        for acqnum in trange(firstacq, n_acq):
            acq = dset.read_acquisition(acqnum)
            idx = acq.idx
            # acq.data is (coil, samples); the target slot is (kx, coil)
            all_data[kx_prezp:, idx.kspace_encode_step_1, idx.kspace_encode_step_2, :,
                     idx.phase, idx.set, idx.slice, idx.repetition, idx.average] = acq.data.T

        try:
            twix_ekg = return_ekg_from_ismrmrd(dset)
        except Exception:
            # The EKG is optional: any failure to read it is reported, not fatal.
            twix_ekg = None
            print('EKG not found in data')
    finally:
        dset.close()

    if not extrareturns:
        return all_data, param
    return all_data, param, enc, acq_first, noise_scan, twix_ekg
    
def return_ekg_from_ismrmrd(f):
    """Return the EKG waveforms of an open ISMRMRD dataset as a DataFrame.

    Only waveforms with ``waveform_id == 0`` are kept. Their (channel, sample)
    blocks are concatenated along the sample axis. A time axis is built assuming
    the 400 Hz EKG sampling rate used by Siemens twix files.

    Parameters
    ----------
    f : ismrmrd.Dataset
        Open dataset to read the waveforms from.

    Returns
    -------
    pandas.DataFrame
        Columns ``time_sec, ch1, ch2, ch3, ch4, isTrigger_boolean``.

    Raises
    ------
    ValueError
        If the dataset holds no waveform with ``waveform_id == 0``.
    """
    nwaveforms = f.number_of_waveforms()

    ekg_blocks = []
    for i in trange(nwaveforms, desc='Reading waveforms...'):
        waveform = f.read_waveform(i)  # read each waveform only once
        if waveform.waveform_id == 0:
            ekg_blocks.append(waveform.data)

    if not ekg_blocks:
        raise ValueError('No EKG waveforms (waveform_id == 0) found in dataset')

    # A single concatenation avoids re-copying the growing array block by block.
    x1 = np.hstack(ekg_blocks)

    # Siemens twix files have EKG recorded at sampling rate of 400 Hz.
    t1 = np.arange(x1.shape[1]) / 400.0

    return pd.DataFrame(np.vstack((t1, x1)).T,
                        columns=['time_sec', 'ch1', 'ch2', 'ch3', 'ch4', 'isTrigger_boolean'])

def _read_bart_cfl(basename):
    """Read the BART pair ``<basename>.hdr`` / ``<basename>.cfl`` into an ndarray.

    The second line of the ``.hdr`` file lists the array dimensions. The
    ``.cfl`` file holds complex64 samples in column-major (Fortran) order.
    """
    with open(basename + '.hdr', 'r') as hdr:
        hdr.readline()  # skip the "# Dimensions" comment line
        dims = [int(d) for d in hdr.readline().split()]
    n_samples = int(np.prod(dims))
    data = np.fromfile(basename + '.cfl', dtype=np.complex64, count=n_samples)
    return data.reshape(dims, order='F')


def return_xyt_array(filename):
    """Load a BART reconstruction from disk and return it as an (x, y, t) array.

    Parameters
    ----------
    filename : str
        Path to the BART output, with or without a ``.cfl`` / ``.hdr`` suffix.

    Returns
    -------
    ndarray
        Axes 0, 1 and 10 of the BART array are kept, and every other axis is
        indexed at 0. Axis 10 is used as time. The result is flipped along
        its first axis, then its first two axes are swapped.
    """
    base, ext = os.path.splitext(filename)
    if ext not in ('.cfl', '.hdr'):
        base = filename
    recon = _read_bart_cfl(base)
    xyt = recon[:, :, 0, 0, 0, 0, 0, 0, 0, 0, :].squeeze()
    return xyt[::-1, :, :].transpose((1, 0, 2))

def view_cine(filename, display_now=True, fig_width=7, fps=10, equalize=True):
    """Render a cine as an HTML5 animation and show it inline or save it to HTML.

    Parameters
    ----------
    filename : str or ndarray or list of ndarray
        A BART ``.cfl``/``.hdr`` path, loaded with ``return_xyt_array``, or an
        in-memory (x, y, t) cine (or list of cines) passed to ``cine_html5``.
    display_now : bool
        If True, display inline. Otherwise write ``<filename>.html``; for
        in-memory input the name is ``<id(input)>.html``.
    fig_width : float
        Figure width in inches.
    fps : int
        Playback frame rate, used both inline and in the saved file.
    equalize : bool
        Forwarded to ``cine_html5``.
    """
    if isinstance(filename, str):
        xyt_array = return_xyt_array(filename)
    else:
        xyt_array = filename
        filename = id(filename)

    cine_html, cine_ani = cine_html5(xyt_array, fig_width=fig_width, equalize=equalize)
    if display_now:
        # cine_html5 encodes at 10 fps; re-encode only when another rate is requested.
        if fps != 10:
            cine_html = HTML(cine_ani.to_jshtml(fps=fps, default_mode='loop'))
        display(cine_html)
    else:
        html_path = f'{filename}.html'
        with open(html_path, 'w') as f:
            f.write(cine_ani.to_jshtml(fps=fps, default_mode='loop'))
        print(f'HTML of cine saved at: {html_path}')
    
def equalize_xyt(xyt_array):
    """Normalize and contrast-equalize the magnitude of an (x, y, t) cine.

    If the first axis is longer than the second, the first two axes are
    swapped. The magnitude is then processed in four steps:

    1. Scale to a maximum of 1.
    2. Reshape to a 2-D (rows, cols*t) image.
    3. Equalize with one CLAHE call (``skimage.exposure.equalize_adapthist``,
       ``clip_limit=0.01``).
    4. Rescale to a maximum of 1 and reshape back.

    All-zero input is returned without scaling, so no NaNs are produced.
    """
    if xyt_array.shape[1] < xyt_array.shape[0]:
        xyt_array = xyt_array.transpose((1, 0, 2))

    magnitude = np.abs(xyt_array).reshape((xyt_array.shape[0], -1))
    peak = magnitude.max()
    if peak > 0:
        magnitude = magnitude / peak

    equalized = exposure.equalize_adapthist(magnitude, clip_limit=0.01)
    equalized_peak = equalized.max()
    if equalized_peak > 0:
        equalized = equalized / equalized_peak

    return equalized.reshape(xyt_array.shape)

# Module-level reference to the most recent animation (presumably kept so that
# the FuncAnimation is not garbage-collected while it is displayed).
ani = None


def cine_html5(xyt_array, xlims=None, ylims=None, fig_width=7, equalize=True):
    """Build an HTML5 (JS) animation of one cine or several cines side by side.

    Parameters
    ----------
    xyt_array : ndarray or list of ndarray
        An (x, y, t) cine, or a list of cines concatenated along axis 1.
    xlims, ylims : tuple, optional
        Axis limits applied to the displayed image.
    fig_width : float
        Figure width in inches; the height follows the frame aspect ratio.
    equalize : bool
        If True, each cine goes through ``equalize_xyt``. Otherwise the raw
        magnitude is shown. This flag is honored for both single and list input.

    Returns
    -------
    (IPython.display.HTML, matplotlib.animation.FuncAnimation)
        The animation rendered at 10 fps, and the animation object itself.
    """
    global ani
    if isinstance(xyt_array, list):
        mag_video = np.concatenate(
            [equalize_xyt(a_cine) if equalize else a_cine for a_cine in xyt_array], axis=1)
    elif equalize:
        mag_video = equalize_xyt(xyt_array)
    else:
        mag_video = xyt_array

    # Compute the displayed magnitude once instead of on every frame
    video = np.abs(np.squeeze(mag_video))
    if video.ndim == 2:  # a single frame
        video = video[..., np.newaxis]
    num_frame = video.shape[-1]

    plt.ioff()
    try:
        fig, ax = plt.subplots(figsize=(fig_width, fig_width * mag_video.shape[0] / mag_video.shape[1]))
        image = ax.imshow(video[:, :, 0], cmap='gray', origin='lower', animated=True,
                          interpolation='bilinear')
        if xlims:
            ax.set_xlim(xlims)
        if ylims:
            ax.set_ylim(ylims)
        ax.set_axis_off()
        fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

        def animate(i):
            image.set_data(video[:, :, i])

        ani = animation.FuncAnimation(fig, animate, frames=num_frame)
        plt.close(fig)
    finally:
        # Always restore interactive mode, even if figure construction fails
        plt.ion()
    return (HTML(ani.to_jshtml(fps=10, default_mode='loop')), ani)

# Available cines from the OCMR dataset
  - filter the table by:
   1. cases with more than 1 slice
   2. fully sampled

In [None]:
# Per-case attributes of the public OCMR dataset
ocmr_attributes_DF = pd.read_csv(
    'https://ocmr.s3.amazonaws.com/ocmr_data_attributes.csv'
)

# Display only the multi-slice, fully-sampled cases
ocmr_attributes_DF.query(
    'slices > 1 and smp == "fs"'
)

## Download one particular case

In [None]:
# OCMR case to download and reconstruct; other candidates are listed in the table above.
case_name = 'fs_0056_1_5T.h5'

In [None]:
# Download into a temporary directory under the working directory. The
# TemporaryDirectory object must stay referenced, or the directory is removed.
temp_download_path = tempfile.TemporaryDirectory(dir='./')
download_path = temp_download_path.name
bucket_name = 'ocmr'
s3_key = f'data/{case_name}'

h5_file = f'{download_path}/{case_name}'

os.makedirs(download_path, exist_ok=True)


# Hook for boto3 client to update tqdm progress bar
def hook(t):
    """Return a boto3 transfer callback that advances tqdm bar ``t`` by the bytes transferred."""
    def inner(bytes_amount):
        t.update(bytes_amount)
    return inner


# One anonymous (unsigned) client for the public bucket, used for both the size
# lookup and the download, so the bucket name is given in one place only.
s3_client = boto3.client('s3', config=Config(signature_version=UNSIGNED))
filesize = s3_client.head_object(Bucket=bucket_name, Key=s3_key)['ContentLength']

with tqdm(total=filesize, unit='B', unit_scale=True, desc=case_name) as t:
    s3_client.download_file(bucket_name, s3_key, h5_file, Callback=hook(t))

# Read the k-space data and EKG (if available)

In [None]:
# Read k-space plus the extras (encoding header, first imaging acquisition,
# noise scans and EKG) from the downloaded file.
all_data, param, enc, acq_first, noise_scan, twix_ekg = read_ocmr(
    h5_file,
    extrareturns=True,
)

# Reconstruct the cine
1. Whiten the data if a noise scan is available
2. Coil compress to a smaller number of virtual coils (e.g. 12 virtual coils)
 - keeping all 30+ coils may become too memory intensive
 - some network architectures may expect a fixed number of coils (DynamicRadCineMRI expects 12 coils by default)
3. Normalize k-space data
 - different MRI vendors and models, coils, protocols, etc. may produce raw k-space data with different ranges of values
 - we'll need to normalize them somehow to improve interoperability and agnosticism of our reconstructions
 - this also helps to make regularization behavior consistent across very different datasets
 - our strategy is similar to the BART strategy: normalize k-space by some scalar such that the 99th percentile of the naive root-sum-of-squares (RSS) reconstruction of the k-space data is equal to 1.0
4. Reconstruct the multi-coil cine (e.g. a cine with `X` frames, where each frame is composed of `Y` coil images)
5. Estimate the coil sensitivity maps
6. Use coil sensitivity information to reduce the multi-coil cine down to a "normal" cine
  - this uses coil sensitivity information to adaptively combine each multi-coil frame into a single-coil frame

In [None]:
def proc_fs_ocmr(all_data, noise_scan=None, n_virtual_coils=None):
    """Whiten, coil-compress, normalize and coil-combine fully-sampled OCMR k-space.

    Parameters
    ----------
    all_data : ndarray
        k-space from ``read_ocmr`` with axes
        (kx, ky, kz, coil, phase, set, slice, rep, avg). The kz, set, rep and
        avg axes must have length 1.
    noise_scan : ndarray, optional
        Noise acquisitions from ``read_ocmr``, shape (n_noise_acq, coil, samples).
        When given and non-empty, the k-space is pre-whitened. The coil noise
        covariance is estimated from all noise samples pooled together.
    n_virtual_coils : int, optional
        If given, compress to this many virtual coils with PCA.

    Returns
    -------
    maps_comb : ndarray
        Sensitivity-weighted coil combination of the middle slice, axes (x, y, phase).
    smaps_return : ndarray
        ESPIRiT sensitivity maps of the middle slice, axes (coil, x, y).
    all_data : ndarray
        Processed, normalized k-space with axes (kx, ky, slice, phase, coil).
    im_coil_prew : ndarray
        Multi-coil images of the middle slice, axes (x, y, phase, coil).

    Raises
    ------
    ValueError
        If any of the kz, set, rep or avg axes has a length other than 1.
    """
    extra_axes = {'kz': 2, 'set': 5, 'rep': 7, 'avg': 8}
    non_singleton = [name for name, axis in extra_axes.items() if all_data.shape[axis] != 1]
    if non_singleton:
        raise ValueError(f'proc_fs_ocmr expects singleton {non_singleton} axes; got shape {all_data.shape}')

    # Explicit indexing, unlike squeeze(), keeps single-slice/single-phase data 5-D:
    # (kx, ky, coil, phase, slice)
    ksp = all_data[:, :, 0, :, :, 0, :, 0, 0]

    # Whiten the data
    if noise_scan is not None and noise_scan.size > 0:
        # Pool every sample of every noise acquisition into (coil, samples).
        # Averaging the acquisitions first would discard most of the samples.
        noise_cov = sp.mri.util.get_cov(np.concatenate(noise_scan, axis=1))
        whitened = np.empty(ksp.shape, dtype=np.complex64)
        for a_slice in trange(ksp.shape[-1]):
            coil_first = ksp[..., a_slice].transpose((2, 0, 1, 3))  # (coil, kx, ky, phase)
            whitened[..., a_slice] = sp.mri.util.whiten(coil_first, noise_cov).transpose((1, 2, 0, 3))
        ksp = whitened

    n_x, n_y, n_coils, n_phases, n_z = ksp.shape
    # Reorder to (kx, ky, slice, phase, coil), with or without coil compression
    ksp = ksp.transpose((0, 1, 4, 3, 2))

    # Coil compress, estimating the PCA matrix from the phase-averaged k-space
    if n_virtual_coils is not None:
        pca_mtx, _ = coil_pca(ksp.mean(axis=3), ncal_x=n_x, ncal_y=n_y, ncal_z=n_z,
                              neig=n_virtual_coils, percentile=99, pca_matrix_only=True)
        ksp = apply_pca_weights(ksp, pca_mtx, neig=n_virtual_coils)

    ksp = normalize_kspace(ksp)

    mid = n_z // 2
    # Multi-coil images of the middle slice: (x, y, phase, coil)
    im_coil_prew = transform.transform_kspace_to_image(ksp[:, :, mid], [0, 1])

    # ESPIRiT maps from the phase-averaged k-space of the middle slice: (coil, kx, ky)
    calib = ksp[:, :, mid].mean(axis=2).transpose((2, 0, 1))
    smaps_return = sp.mri.app.EspiritCalib(calib, thresh=0.002, crop=0.98).run()

    # Sensitivity-weighted combination: sum_c conj(S_c) I_c / sum_c |S_c|^2
    smaps_na = smaps_return[..., np.newaxis]  # (coil, x, y, 1)
    numerator = np.sum(im_coil_prew.transpose((3, 0, 1, 2)) * np.conj(smaps_na), axis=0)
    denominator = np.sum(smaps_na * np.conj(smaps_na), axis=0)
    # Where the maps are zero the division gives NaN; nan_to_num sets those to 0
    with np.errstate(divide='ignore', invalid='ignore'):
        maps_comb = np.nan_to_num(numerator / denominator)

    return maps_comb, smaps_return, ksp, im_coil_prew

def normalize_kspace(ksp):
    """Scale k-space so the 99th percentile of a reference RSS image equals 1.

    The reference image is the root-sum-of-squares over coils,
    ``sqrt(sum_c |I_c|^2)``. It is computed from the first frame (index 0 on
    axis 3) of the middle index of axis 2.

    Parameters
    ----------
    ksp : ndarray
        k-space assumed to have axes (kx, ky, slice, phase, coil), which is
        the layout ``proc_fs_ocmr`` passes in.

    Returns
    -------
    ndarray
        ``ksp`` divided by the 99th-percentile RSS value.
    """
    ref_ksp = ksp[:, :, ksp.shape[2] // 2, 0, :]  # (kx, ky, coil)
    ref_img = transform.transform_kspace_to_image(ref_ksp, [0, 1])
    # Square the magnitudes, not the complex values, so coil phases cannot cancel.
    rss = np.sqrt(np.sum(np.abs(ref_img) ** 2, axis=2))
    return ksp / np.percentile(rss, 99)

# Whiten, compress to 12 virtual coils (what DynamicRadCineMRI expects by
# default), normalize and coil-combine; then preview the combined cine.
maps_comb, smaps_return, all_data_proc, im_coil = proc_fs_ocmr(
    all_data,
    noise_scan=noise_scan,
    n_virtual_coils=12,
)
view_cine(maps_comb)

# Demo of the [`DynamicRadCineMRI`](https://github.com/koflera/DynamicRadCineMRI/) Network

## Clone the repo to Google Drive
 - cloning to Google Drive persists the directory across Google Colab sessions


In [None]:
# Clone into Google Drive so the checkout persists across Colab sessions.
# NOTE(review): on later sessions the directory already exists and git will
# report an error instead of cloning; that is harmless here, but confirm.
!git clone https://github.com/koflera/DynamicRadCineMRI.git /content/drive/MyDrive/DynamicRadCineMRI

In [None]:
#np.percentile(np.abs(np.sqrt(np.sum(np.power(im_coil_prew[:,:,n_z//2,0,:], 2), axis=2))), 99)

## Save the OCMR reconstruction from above for use with the `DynamicRadCineMRI` network
 - save to same Google Drive directory for persistent storage

In [None]:
# Persist the combined cine, the sensitivity maps and the multi-coil images
# next to the cloned repository, e.g. fs_0056_1_5T.h5 -> fs_0056_1_5Tcc.npz
complex_cine_file = os.path.join(
    '/content/drive/MyDrive/DynamicRadCineMRI/',
    case_name.replace('.h5', 'cc.npz'),
)
np.savez(
    complex_cine_file,
    maps_comb_prew_cc=maps_comb,
    smaps_na_prew_cc=smaps_return,
    im_coil_prew_cc=im_coil,
)

## Dependencies for `DynamicRadCineMRI`

In [None]:
import torch
# Put the cloned repository root on the import path so `network.*` and
# `helper_funcs.*` resolve. The `network/` directory is also added, presumably
# for imports inside the repo's own modules — confirm if it can be dropped.
sys.path.append('/content/drive/MyDrive/DynamicRadCineMRI/')
sys.path.append('/content/drive/MyDrive/DynamicRadCineMRI/network/')
from network.nufft_operator import Dyn2DRadEncObj
from network.reconstruction_network import NUFFTCascade
from network.xtyt_fft_unet import XTYTFFTCNN

from helper_funcs.noise_funcs import add_gaussian_noise

import matplotlib
# Effectively remove matplotlib's size cap on embedded (JS/HTML) animations
matplotlib.rcParams['animation.embed_limit'] = 2**128

## Utility function to generate a tiny golden-angle radial trajectory

In [None]:
def return_ga_traj(nX, ntviews, tiny_num=7):
    """Return a tiny-golden-angle radial trajectory and its density compensation.

    Spoke ``i`` is rotated by ``i * pi / (tau + tiny_num - 1)``, where ``tau``
    is the golden ratio, and the angle is wrapped to [0, 2*pi).

    Parameters
    ----------
    nX : int
        Number of readout points per spoke.
    ntviews : int
        Number of radial spokes.
    tiny_num : int
        Which tiny golden angle to use, e.g. 7 for the 7th tiny golden angle
        (1 gives the classic golden angle pi/tau).

    Returns
    -------
    traj : ndarray, shape (ntviews, nX, 2)
        k-space coordinates: ``[..., 0] = rho*sin(angle)`` and
        ``[..., 1] = rho*cos(angle)``. The readout positions ``rho`` are nX
        unit-spaced values centred on 0.
    densitycomp : ndarray, shape (ntviews, nX)
        Ramp (``|k|``) density compensation scaled to a maximum of 1.

    Reference
    ---------
    https://doi.org/10.1002/mrm.25831
    """
    golden_ratio = (np.sqrt(5.0) + 1.0) / 2.0
    golden_angle = np.pi / (golden_ratio + tiny_num - 1)

    radian = np.mod(np.arange(ntviews) * golden_angle, 2.0 * np.pi)
    # Exactly nX readout positions centred on 0: half-integers for even nX,
    # integers (including 0) for odd nX
    rho = np.arange(nX) - nX / 2.0 + 0.5

    # Spokes first, readout second: (ntviews, nX, 2)
    traj = np.stack((np.outer(np.sin(radian), rho),
                     np.outer(np.cos(radian), rho)), axis=-1)

    # Ramp density compensation |k|, normalized to a peak of 1 (left as-is if all zero)
    densitycomp = np.hypot(traj[..., 0], traj[..., 1])
    peak = densitycomp.max()
    if peak > 0:
        densitycomp /= peak

    return traj, densitycomp

## Load the saved OCMR reconstruction and center-crop or zero-pad each frame to 320x320

In [None]:
def _center_fit_slices(n, target):
    """Return (source, destination) slices that center-crop or zero-pad length n to target.

    When the size difference is odd, the trailing side gets the extra element.
    """
    if n >= target:
        start = (n - target) // 2
        return slice(start, start + target), slice(0, target)
    start = (target - n) // 2
    return slice(0, n), slice(start, start + n)


#datatype
dtype = torch.float

#whether to use the gpu or not;
use_GPU = 1

ocmr_fs_load = np.load(complex_cine_file)
cine_complex = ocmr_fs_load['maps_comb_prew_cc']
n_ro, n_pe, n_frames = cine_complex.shape

# Center-crop or zero-pad both spatial axes to 320x320. This works for any
# input size, including an axis that is already exactly 320.
src_ro, dst_ro = _center_fit_slices(n_ro, 320)
src_pe, dst_pe = _center_fit_slices(n_pe, 320)
square_cine = np.zeros((320, 320, n_frames), dtype=np.complex128)
square_cine[dst_ro, dst_pe, :] = cine_complex[src_ro, src_pe, :]

## Load the coil sensitivity maps and center-crop or zero-pad them to 320x320

In [None]:
def _center_fit_slices(n, target):
    """Return (source, destination) slices that center-crop or zero-pad length n to target.

    When the size difference is odd, the trailing side gets the extra element.
    """
    if n >= target:
        start = (n - target) // 2
        return slice(start, start + target), slice(0, target)
    start = (target - n) // 2
    return slice(0, n), slice(start, start + n)


a_coil_sense_map = ocmr_fs_load['smaps_na_prew_cc'].squeeze()

n_coils_sense, n_ro_sense, n_pe_sense = a_coil_sense_map.shape

# Center-crop or zero-pad the maps to 320x320, sized from the maps' own dimensions
src_ro_s, dst_ro_s = _center_fit_slices(n_ro_sense, 320)
src_pe_s, dst_pe_s = _center_fit_slices(n_pe_sense, 320)
square_sense_map = np.zeros((n_coils_sense, 320, 320), dtype=np.complex128)
square_sense_map[:, dst_ro_s, dst_pe_s] = a_coil_sense_map[:, src_ro_s, src_pe_s]

print(n_coils_sense, n_ro_sense, n_pe_sense, square_sense_map.shape)

#convert to tensor of shape (1, n_coils_sense, 2, 320, 320): real/imag on axis 2
csm_tensor = torch.stack([torch.tensor(square_sense_map.real), torch.tensor(square_sense_map.imag)], dim=1).unsqueeze(0)

## Create Radial Trajectory

In [None]:
xf = square_cine
im_size = xf.shape

# Real/imag stacked on a new channel axis: (1, 1, 2, 320, 320, n_frames)
xf_tensor = torch.stack(
    [torch.tensor(xf.real), torch.tensor(xf.imag)],
    dim=0).unsqueeze(0).unsqueeze(0)

# Fully sampling a 320-point radial readout needs about 320*pi/2 ~ 500 spokes per
# frame; anything fewer is undersampled (here 96 spokes per frame).
spokes_per_frame = 96
n_readout = 320
traj, dc = return_ga_traj(nX=n_readout, ntviews=int(n_frames*spokes_per_frame))
ntviews = traj.shape[0]

# NUFFT libraries use different coordinate conventions; torchkbnufft expects the
# largest |k| to be pi, so scale accordingly.
traj *= np.pi / np.abs(traj).max()

# After retrospective sorting, real radial spokes are neither sequential nor
# evenly spaced within a frame. Simulate this by randomly permuting all spokes
# and giving frame f the spokes in row f of `spoke_indices`. The density
# compensation is permuted identically.
np.random.seed(20211110)
spoke_indices = np.random.choice(traj.shape[0], (n_frames, spokes_per_frame), replace=False)
traj = traj[spoke_indices.flatten(), ...]
dc = dc[spoke_indices.flatten(), ...]

# Group the spokes by frame before moving the frame axis last, so each frame
# really gets its own row of spoke_indices:
#   traj (n_frames*spf, nX, 2) -> tensor (1, 2, Nrad, n_frames), Nrad = spf*nX
ktraj_tensor = torch.tensor(
    traj.reshape((n_frames, -1, 2)).transpose((2, 1, 0))).unsqueeze(0)

#   dc (n_frames*spf, nX) -> tensor (1, 1, 1, Nrad, n_frames), same point order as traj
dcomp_tensor = torch.tensor(dc.reshape((n_frames, -1)).T)[None, None, None]

if use_GPU:
    xf_tensor = xf_tensor.float().to('cuda')
    ktraj_tensor = ktraj_tensor.float().to('cuda')
    csm_tensor = csm_tensor.float().to('cuda')
    dcomp_tensor = dcomp_tensor.float().to('cuda')

Note: for case `fs_0045_3T.h5`, the `E3C2K16` model suppresses the extent of contraction.

In [None]:
#create the radial encoding operator (NUFFT with coil sensitivities and density compensation)
# NOTE(review): `.cuda()` is applied regardless of `use_GPU` — this cell requires a GPU.
EncObj = Dyn2DRadEncObj(im_size,
                        ktraj_tensor,
                        dcomp_tensor,
                        csm_tensor,
                        norm='ortho').cuda()

#define the CNN-block and thus the model;
#available pre-trained models: E3C2K4, E3C2K8, E3C2K16
#(E = encoding stages, C = convolutions per stage, K = number of filters)
n_enc_stages = 3
n_convs_per_stage = 2
n_filters = 8
CNN = XTYTFFTCNN(n_ch=2,
                 n_enc_stages=n_enc_stages,
                 n_convs_per_stage=n_convs_per_stage,
                 n_filters=n_filters)

#initialize the reconstruction network (alternating CNN- and CG-blocks)
reconstruction_network = NUFFTCascade(EncObj,
                                      CNN,
                                      learn_lambda=False,
                                      use_precon=True,
                                      mode='fine_tuning').cuda()

#load the pre-trained weights matching the chosen E/C/K configuration
model_folder = '/content/drive/MyDrive/DynamicRadCineMRI/pre_trained_models/'
model_id = 'E{}C{}K{}'.format(n_enc_stages, n_convs_per_stage, n_filters)

reconstruction_network.load_state_dict(
    torch.load(model_folder + 'model_{}.pt'.format(model_id)))

#forward operator: transform the OCMR cine into (undersampled) radial k-space
k_tensor = EncObj.apply_A(xf_tensor)
# add some noise
# NOTE(review): no torch seed is set here, so the noise presumably differs between runs.
k_tensor = add_gaussian_noise(k_tensor, sigma=0.06)

#adjoint NUFFT reconstruction of the radial k-space; this undersampled
# reconstruction shows the expected undersampling (streaking) artifacts
xu_tensor = EncObj.apply_Adag(k_tensor)

In [None]:
# Hyper-parameter grid (values as used in the paper):
#   nu   - number of alternations between the CG- and CNN-modules
#          (other settings noted: [10,12,14] and [1, 1, 2, 4, 8, 12])
#   npcg - number of CG iterations inside each CG-module
#          (other settings noted: [6,8,10] and [0, 8, 4, 4, 4, 4])
nu_list = [10, 14]
npcg_list = [6, 10]

# Network reconstructions keyed by 'xcnn_nu{nu}_npcg{npcg}'
D_cnn_recos = {}
n_tests = 6

if use_GPU:
    xu_tensor = xu_tensor.cuda()

for nu, npcg in tqdm(list(itertools.product(nu_list, npcg_list))):
    reconstruction_network.nu, reconstruction_network.npcg = nu, npcg

    # CNN-block + CG-block, inference only
    with torch.no_grad():
        xcnn_reg = reconstruction_network(xu_tensor.squeeze(0))
    xcnn_reg = xcnn_reg.cpu() if use_GPU else xcnn_reg

    # Two real/imag channels -> one complex ndarray
    xcnn_reg = xcnn_reg.squeeze(0).squeeze(0).numpy()
    xcnn_reg = xcnn_reg[0, ...] + 1j * xcnn_reg[1, ...]
    D_cnn_recos[f'xcnn_nu{nu}_npcg{npcg}'] = xcnn_reg

if use_GPU:
    xu_tensor = xu_tensor.cpu()

# The undersampled (adjoint NUFFT) reconstruction as a complex ndarray
xu = xu_tensor.squeeze(0).squeeze(0).cpu().numpy()
xu = xu[0, ...] + 1j * xu[1, ...]

## Visualization
### Left: fully-sampled cartesian reconstruction | Right: under-sampled non-cartesian reconstruction
 - notice the streaking artifacts on the right that are characteristic of reconstructing from undersampled radial k-space
 - we're also using indexing to zoom into the central 160x160 of the cine

In [None]:
# Side by side: fully-sampled Cartesian (left) vs undersampled radial (right),
# zoomed to the central 160x160 region.
view_cine(
    [a_cine[320//4:-320//4, 320//4:-320//4, ...] for a_cine in [xf, xu]],
    fig_width=7,
)

### Set the `equalize` flag to `False` to disable the adaptive histogram equalization used above to improve contrast

In [None]:
# Fully-sampled cine alone, without adaptive histogram equalization
view_cine(
    [xf[320//4:-320//4, 320//4:-320//4, ...]],
    fig_width=3.5,
    equalize=False,
)

### Visualize the output of the `DynamicRadCineMRI` network

#### Observations for some cases
1. Case: `fs_0056_1_5T.h5`
 - when selecting the pre-trained `E3C2K8` model by setting the following:
```python
n_enc_stages = 3
n_convs_per_stage = 2
n_filters = 8
```
we can see that streaking artifacts are suppressed, but now the LV contraction is oversmoothed such that it no longer contracts to the same extent (the LV chamber doesn't get as small) as the fully-sampled case.  This suggests more tuning is necessary to properly trade-off LV motion fidelity and removal of undersampling streaking artifacts



In [None]:
# Show which (nu, npcg) settings were reconstructed
D_cnn_recos.keys()

In [None]:
# All network reconstructions side by side, in (nu, npcg) sweep order
view_cine(
    [a_cine[320//4:-320//4, 320//4:-320//4, ...] for a_cine in D_cnn_recos.values()],
    fig_width=14,
    equalize=True,
)