In [None]:
# GPU server
%env CUDA_VISIBLE_DEVICES=0

# Libraries
%matplotlib widget
import os
import torch
import numpy as np
import scipy.io as sio
import sigpy.mri as mr
import sigpy.plot as pl
import matplotlib.pyplot as plt

from codes import utils, parser_ops
from codes.model_3d import UnrolledNet

# Prefer the GPU made visible by CUDA_VISIBLE_DEVICES; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Parse an empty argument list so the Jupyter kernel's own argv is ignored;
# every option therefore takes the default defined in parser_ops.
parser = parser_ops.get_parser()
args = parser.parse_args(args=[])

### Plot loss curves

In [None]:
# Name of the trained model directory under ./saved_models to evaluate.
model_name = "T1w_R2_5Unrolls_5ResNet"
print(f"Loading model {model_name}")

file_dir = os.path.join(os.getcwd(), 'saved_models')
saved_model_dir = os.path.join(file_dir, model_name)

# Read the training log once; it holds both the training and validation loss histories.
training_log = sio.loadmat(os.path.join(saved_model_dir, 'TrainingLog.mat'))
trn_loss = training_log['trn_loss']
val_loss = training_log['val_loss']

fig, ax = plt.subplots()
# loadmat returns 2-D arrays; the transpose presumably turns each 1xN row
# vector into one line over epochs (TODO confirm how TrainingLog.mat is saved).
ax.plot(np.asarray(trn_loss).T)
ax.plot(np.asarray(val_loss).T)
ax.set(title='Loss Curves', xlabel='Epochs', ylabel='Loss')
ax.legend(['trn loss', 'val loss'])
ax.grid()
plt.show()

### Load the test data and form the undersampled network input and fully sampled reference images

In [None]:
# Load the 3D image volume and add a leading contrast axis.
img_data = np.load(args.data_dir)[None]
if img_data.ndim != 4:
    raise ValueError(
        f"Expected a 3D image volume in {args.data_dir}, got shape {img_data.shape[1:]}"
    )
args.ncontrast, args.nrow, args.ncol, args.ndepth = img_data.shape

# Simulate fully sampled k-space from the image volume (3D FFT via utils.fftcn).
kspace_test = utils.fftcn(img_data, axes=(-1, -2, -3))

# 1D inverse FFT along axis -3 only; the last two axes stay in k-space and are
# transformed back to image space with ifftc2 below.
kspace_test = utils.ifftc1(kspace_test, axis=-3)

# Normalise so the 95th-percentile magnitude is 1. An all-zero (or mostly zero)
# volume would otherwise divide by zero and silently fill the data with inf/NaN.
kspace_scale = np.percentile(np.abs(kspace_test), 95)
if kspace_scale == 0:
    raise ValueError("95th-percentile k-space magnitude is zero; cannot normalise the data")
kspace_test = kspace_test / kspace_scale

# Load the sampling mask stored with the trained model (not generated here).
mask = np.load(os.path.join(saved_model_dir, "mask.npy"))

test_mask = np.complex64(mask)
# Network input: masked k-space (zero wherever the mask is 0) taken to image space.
nw_input = utils.ifftc2(kspace_test * test_mask, axes=(-1, -2))
# Reference: the fully sampled k-space taken to image space.
ref_image = utils.ifftc2(kspace_test, axes=(-1, -2))

### Load the model and perform the reconstruction

In [None]:
# Rebuild the unrolled network from `args` and restore the best checkpoint.
# Checkpoint tensors are deserialized onto the CPU; the model itself was
# already moved to `device`.
model = UnrolledNet(args, device=device).to(device)
ckpt_path = os.path.join(saved_model_dir, 'best.pth')
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu'))["model_state"])
model.eval()

# Real-valued channel count: two real channels per coil/contrast image.
n_real_channels = 2 * args.ncoil * args.ncontrast

with torch.no_grad():
    # utils.c2r presumably packs real/imag parts along axis 1 (mirrors the
    # r2c(axis=1) below; TODO confirm in codes.utils). Flatten to
    # (channels, nrow, ncol, ndepth) and prepend a batch axis.
    input_to_nw = torch.from_numpy(
        utils.c2r(nw_input, axis=1).reshape(n_real_channels, args.nrow, args.ncol, args.ndepth)[None]
    ).to(device)
    trn_mask = torch.from_numpy(test_mask[None]).to(device)
    # At test time the same mask is supplied for both mask arguments.
    nw_img_output, lamdas, nw_kspace_output = model(input_to_nw, trn_mask, trn_mask)

# Unpack the real-valued output back into complex images, one per coil/contrast.
zs_ssl_recon = utils.r2c(
    nw_img_output.squeeze()
    .reshape(args.ncoil * args.ncontrast, 2, args.nrow, args.ncol, args.ndepth)
    .cpu()
    .numpy(),
    axis=1,
)

In [None]:
# First channel of the reference, the undersampled input and the reconstruction,
# stacked along a new leading axis for side-by-side viewing.
panels = [ref_image[0], nw_input[0], zs_ssl_recon[0]]
combined_img = np.stack(panels)

# Move the depth axis ahead of (nrow, ncol) before plotting.
# NOTE(review): assumes ImagePlot's z=0 lays the three volumes out side by side; confirm in sigpy docs.
pl.ImagePlot(
    np.transpose(combined_img, (0, 3, 1, 2)),
    z=0,
    title="left: R=1 | middle: R=2 retrospective | right: ZS-SSL output",
)

In [None]:
# SSIM / PSNR are evaluated on one depth slice of the magnitude images.
EVAL_SLICE = 32  # index along the last (ndepth) axis
if not 0 <= EVAL_SLICE < ref_image.shape[-1]:
    raise ValueError(
        f"EVAL_SLICE={EVAL_SLICE} is out of range for depth {ref_image.shape[-1]}"
    )

# Compute each magnitude slice once and reuse it for both metrics.
ref_slice_mag = np.abs(ref_image[..., EVAL_SLICE])
recon_slice_mag = np.abs(zs_ssl_recon[..., EVAL_SLICE])

all_ssim = utils.ssim_batch(ref_slice_mag, recon_slice_mag)
all_psnr = utils.psnr_batch(ref_slice_mag, recon_slice_mag)

print(all_ssim)
print(all_psnr)