In [1]:
import torch
import json
import SimpleITK as sitk
import numpy as np
import sigpy as sp
from scipy import io
from MoCo_INR import MoCoINR
from utils import fftnc, ifftnc, cal_metrics

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## load config file
config_path = 'Config/VISTA_recon.json'


def _with_trailing_slash(path):
    """Return `path` ending in '/' so a file name can be appended directly.

    An empty string is returned unchanged, so it keeps meaning "current
    directory" instead of turning into the filesystem root.
    """
    if path and not path.endswith('/'):
        return path + '/'
    return path


# Parse the JSON experiment configuration.
with open(config_path, 'r') as f:
    config = json.load(f)

# Later cells build file names as `data_path + name`, so make sure the
# separator is present even if the config entry omits it.
data_path = _with_trailing_slash(config['data']['data_path'])
AF = config['data']['AF']   # you can change AF to 12 or 20

# Results go into a per-AF sub-folder. Normalising the base folder first means
# 'out' yields 'out/AF_8/' rather than 'outAF_8/'; a base that already ends in
# '/' gives exactly the same result as plain concatenation. The trailing '/'
# is kept, and the updated path is written back into the config.
save_path = _with_trailing_slash(config['evaluation']['save_path']) + f'AF_{AF}/'
config['evaluation']['save_path'] = save_path

# Use the configured CUDA device when available, otherwise fall back to the CPU.
gpu = config['train']['gpu']
device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

#### Load Demo Data
This demo reconstructs an LAX-4ch view from the OCMR dataset.

In [3]:
# Read the reference images and the coil sensitivity maps as NumPy arrays.
gt_img = sitk.GetArrayFromImage(sitk.ReadImage(data_path + 'gt_img.nii.gz'))
smap = sitk.GetArrayFromImage(sitk.ReadImage(data_path + 'csm.nii.gz'))

# Sizes are taken from the loaded arrays: frames and grid size from the
# images, coil count from the sensitivity maps.
n_frame = gt_img.shape[0]
grid_size = gt_img.shape[-1]
n_coil = smap.shape[0]

# Load the VISTA mask whose file name matches grid size, frame count and AF.
mask_file = data_path + f'samp_VISTA_{grid_size}x{n_frame}_R{AF}.mat'
samp_2d = io.loadmat(mask_file)['samp']

# Transpose the stored 2-D pattern, then copy it `grid_size` times along a
# new trailing axis to obtain a 3-D mask.
samp_mask = np.repeat(samp_2d.T[:, :, np.newaxis], grid_size, axis=2)

In [4]:
## prepare data
# Move the reference images, coil maps and sampling mask onto the device.
gt_tensor = torch.from_numpy(gt_img).to(device)            # (n_frame, H, W)
smap_tensor = torch.from_numpy(smap).to(device)            # (n_coil, H, W)
sam_mask_tensor = torch.from_numpy(samp_mask).to(device)   # (n_frame, H, W)

# Simulate the undersampled acquisition: weight each frame by every coil map,
# transform over the two spatial axes, keep only the sampled locations, and
# apply the k-space scaling from the config.
coil_images = gt_tensor.unsqueeze(1) * smap_tensor.unsqueeze(0)
full_kspace = fftnc(coil_images, dim=(-2, -1))
kdata = full_kspace * sam_mask_tensor.unsqueeze(1)         # (n_frame, n_coil, H, W)
kdata = kdata * config['train']['kscale']

## Zero-filling Reconstruction
# Inverse-transform the masked k-space per coil, combine the coils with the
# conjugate sensitivity maps, then scale so the peak magnitude is 1.
coil_zf = ifftnc(kdata, dim=(-2, -1))
zf_recon = (coil_zf * smap_tensor.unsqueeze(0).conj()).sum(dim=1)
zf_recon = zf_recon / zf_recon.abs().max()

# Baseline quality of the zero-filled result against the reference magnitudes.
psnr, ssim = cal_metrics(zf_recon.abs(), gt_tensor.abs())
print(f'PSNR/SSIM for ZF recon: {psnr:.2f}/{ssim:.3f}')

PSNR/SSIM for ZF recon: 14.34/0.207


#### Initialize the MoCo-INR model and reconstruct the CMR images
+ Intermediate results are saved in the `out_VISTA` folder.
  + If you don't want to save the intermediate results, set `config['evaluation']['flag']` to `false`.
+ The reconstruction takes around 40 seconds.

In [5]:
# Collect the tensors the model is built from: reference images, coil
# sensitivity maps, undersampled k-space and the sampling mask.
recon_inputs = dict(
    gt_img=gt_tensor,
    smap=smap_tensor,
    kdata=kdata,
    mask_tensor=sam_mask_tensor,
)
MoCo_INR_model = MoCoINR(config, **recon_inputs)

# Run the reconstruction and keep the result for the visualisation cell below.
est_recon = MoCo_INR_model.train()  # (n_frame, H, W)

 13%|████████▋                                                        | 201/1500 [00:05<01:02, 20.88it/s, loss=0.009916]

PSNR/SSIM: 34.64/0.923


 26%|█████████████████▏                                               | 397/1500 [00:11<00:29, 37.73it/s, loss=0.008076]

PSNR/SSIM: 36.53/0.936


 40%|██████████████████████████                                       | 601/1500 [00:17<00:38, 23.38it/s, loss=0.003440]

PSNR/SSIM: 37.90/0.947


 53%|██████████████████████████████████▌                              | 797/1500 [00:22<00:18, 38.34it/s, loss=0.004359]

PSNR/SSIM: 37.88/0.948


 66%|███████████████████████████████████████████▏                     | 997/1500 [00:27<00:13, 37.18it/s, loss=0.002728]

PSNR/SSIM: 38.17/0.949


 80%|███████████████████████████████████████████████████▏            | 1201/1500 [00:33<00:12, 23.57it/s, loss=0.002683]

PSNR/SSIM: 38.43/0.950


 93%|███████████████████████████████████████████████████████████▌    | 1397/1500 [00:38<00:02, 37.70it/s, loss=0.003033]

PSNR/SSIM: 38.63/0.951


100%|████████████████████████████████████████████████████████████████| 1500/1500 [00:41<00:00, 35.81it/s, loss=0.002699]


PSNR/SSIM: 38.32/0.951


In [None]:

from utils import save_as_gif
from IPython.display import HTML, display
import os

# Make sure the output folder exists before writing the animations.
os.makedirs(save_path, exist_ok=True)

# `save_path` is built with a trailing '/', so appending '/gt.gif' directly
# would give '...//gt.gif'. Strip the trailing separator first so both the
# file paths and the <img> URLs below contain a single '/'.
gif_dir = save_path.rstrip('/')
gt_gif_path = f'{gif_dir}/gt.gif'
recon_gif_path = f'{gif_dir}/recon.gif'

# Write the magnitude images of the reference and the reconstruction as GIFs;
# the two trailing arguments (0, 0.6) are passed to save_as_gif unchanged for
# both so the animations are rendered with identical settings.
save_as_gif(torch.abs(gt_tensor).detach().cpu().numpy(), gt_gif_path, 0, 0.6)
save_as_gif(torch.abs(est_recon).detach().cpu().numpy(), recon_gif_path, 0, 0.6)

# Show both animations side by side for a visual comparison.
display(HTML(
    f"""
    <div style="display:flex; gap:24px; justify-content:center;">
      <figure style="text-align:center;">
        <img src="{gt_gif_path}" style="max-width:640px;">
        <figcaption>GT</figcaption>
      </figure>
      <figure style="text-align:center;">
        <img src="{recon_gif_path}" style="max-width:640px;">
        <figcaption>MoCo-INR</figcaption>
      </figure>
    </div>
    """
))

