In [1]:
# Multi Coil D-GEC demo for fastMRI brain
# Compares D-GEC and PnP-PDS
# Algorithm and Denoiser is desgined for haar wavelet and level 4 wavelet decomposition

import os, sys
sys.path.append(os.path.dirname(sys.path[0]))
import numpy as np
import matplotlib.pyplot as plt
import torch
import random

from utils import general as gutil

from algorithms import D_GEC_multi_coil
from algorithms import PnP_PDS_multi_coil
from algorithms import general_multi_coil
from fastMRI_utils import transforms_new
from fastMRI_utils.utils_fastMRI import tensor_to_complex_np

from scipy.io import savemat
from scipy.io import loadmat

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device_cpu = torch.device("cpu")

random.seed(10)

In [2]:
# DNCNN and DnCNN_cpc Denoisers locations and names
modelnames_cpc = ['checkpoint_last_DnCNN_cpc_0_10.pt','checkpoint_last_DnCNN_cpc_10_20.pt', 'checkpoint_last_DnCNN_cpc_20_50.pt', 'checkpoint_last_DnCNN_cpc_50_120.pt', 'checkpoint_last_DnCNN_cpc_120_500.pt']    
modeldir_cpc = '/storage/D_GEC_Demo/'
model_PnP_PDS = '/storage/D_GEC_Demo/checkpoint_last_DnCNN_0_50.pt'

# Data
mdic = loadmat("/storage/D_GEC_Demo/R_4_VD_SNR_35_data_brain.mat")


In [4]:
# For R = 4, 35 dB SNR
image_number = 7

num_of_D_GEC_iterations = 10
num_of_PnP_PDS_iterations = 50
theta_damp = 0.3
zeta_damp = 0.3
sens_var = torch.tensor(2.5283e-11,device = device) # For R = 4 # noise variance introduced by imperfect sens-map estimation by ESPIRiT # computed this number by taking average over training data
gamma_tune = 12


y_mat = mdic['y_mat']
GT_target_complex_mat = mdic['GT_target_complex_mat']
sens_maps_mat = mdic['sens_maps_mat']
mask_mat = mdic['mask_mat']
prob_map_mat = mdic['prob_map_mat']
sigma_w_square_mat = mdic['sigma_w_square_mat']
M_mat = mdic['M_mat']
N_mat = mdic['N_mat']
metric_mask_mat = mdic['metric_mask_mat']
GT_target_abs_mat = mdic['GT_target_abs_mat']


y_foo = transforms_new.to_tensor(y_mat[image_number]).permute(2,0,1,3)
y = (torch.cat((y_foo[:,:,:,0], y_foo[:,:,:,1]), dim = 0).unsqueeze(0)).to(device)
GT_target_complex = transforms_new.to_tensor(GT_target_complex_mat[image_number]).permute(2,0,1).unsqueeze(0).to(device)
sens_maps_new = transforms_new.to_tensor(sens_maps_mat[image_number]).permute(2,0,1,3).to(device)
mask = mask_mat[image_number,:,:]
prob_map = prob_map_mat[image_number,:,:]
wvar = torch.tensor(sigma_w_square_mat[image_number,0],device=device)
M = M_mat[image_number,0]
N = N_mat[image_number,0]
metric_mask = transforms_new.to_tensor(metric_mask_mat[image_number,:,:]).to(device)
GT_target_abs = transforms_new.to_tensor(GT_target_abs_mat[image_number,:,:]).to(device)

y = y.type('torch.FloatTensor').to(device)
GT_target_complex = GT_target_complex.type('torch.FloatTensor').to(device)
sens_maps_new = sens_maps_new.type('torch.FloatTensor').to(device)

## DGEC

x_D_GEC_denoiser, x_D_GEC_LMMSE, PSNR_list_GEC = D_GEC_multi_coil.D_GEC(y, sens_maps_new, mask, wvar, sens_var, num_of_D_GEC_iterations, modelnames_cpc, modeldir_cpc ,theta_damp,zeta_damp, GT_target_abs, metric_mask)
recovered_image_DGEC_1 = transforms_new.complex_abs(x_D_GEC_denoiser.squeeze(0).permute(1,2,0))

# PnP-PDS

x_PnP_PDS, PSNR_list_PnP_PDS = PnP_PDS_multi_coil.PnP_PDS(y, sens_maps_new, mask, wvar, num_of_PnP_PDS_iterations, model_PnP_PDS, gamma_tune, GT_target_abs, metric_mask)
recovered_image_PNP = transforms_new.complex_abs(x_PnP_PDS.squeeze(0).permute(1,2,0))

# Metric

PSNR_D_GEC_Den = gutil.calc_psnr((recovered_image_DGEC_1*metric_mask).cpu(), (GT_target_abs*metric_mask).cpu(), max = (GT_target_abs*metric_mask).max().cpu())
PSNR_PnP_PDS = gutil.calc_psnr((recovered_image_PNP*metric_mask).cpu(), (GT_target_abs*metric_mask).cpu(), max = (GT_target_abs*metric_mask).max().cpu())

rSNR_D_GEC_Den = gutil.calc_rSNR_non_DB_scale((recovered_image_DGEC_1*metric_mask).cpu(), (GT_target_abs*metric_mask).cpu())
rSNR_PnP_PDS = gutil.calc_rSNR_non_DB_scale((recovered_image_PNP*metric_mask).cpu(), (GT_target_abs*metric_mask).cpu())

SSIM_D_GEC_Den = gutil.calc_SSIM((recovered_image_DGEC_1*metric_mask).cpu(), (GT_target_abs*metric_mask).cpu())
SSIM_PnP_PDS = gutil.calc_SSIM((recovered_image_PNP*metric_mask).cpu(), (GT_target_abs*metric_mask).cpu())

print("PSNR D-GEC: ", PSNR_D_GEC_Den)
print("PSNR PnP-PDS: ", PSNR_PnP_PDS)

print('Done!')


PSNR D-GEC:  40.79025507720501
PSNR PnP-PDS:  40.01435642031464
Done!
