In [None]:
import os
import time
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
import numpy as np
import math
import glob

from src.fastmri.math import complex_abs, complex_mul, complex_conj
from src.fastmri.transforms import to_tensor
from src.fastmri.coil_combine import rss_complex, rss
from src.fastmri.subsample import RandomMaskFunc
from src.fastmri.fft import fft2c, ifft2c, fftshift, fft2

from src.tvcs.tv_op import *
from src.utils import *
from src.metrics import psnr, ssim
import matplotlib.pyplot as plt

from scipy.io import loadmat,savemat

# Enumerate GPUs in PCI bus order so device indices match nvidia-smi.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Default to GPU 3, but respect a CUDA_VISIBLE_DEVICES already set by the
# launcher (e.g. `CUDA_VISIBLE_DEVICES=0 jupyter ...`); overriding it
# unconditionally hides every GPU on machines with fewer than 4 devices.
# This must take effect before the first CUDA call in the kernel.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

In [None]:
from src.models.condrefinenet_fourier import CondRefineNetDilated

# Score-network checkpoint; "F64" in the directory name presumably matches ch = 64 below.
chp_path = './checkpoints/fastMRI_scorenet_F64_c1/net.pth'
# map_location='cpu' decouples deserialization from the GPU index the checkpoint
# was saved on (after CUDA_VISIBLE_DEVICES remapping that index may not exist);
# load_state_dict below copies the weights onto the CUDA model.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
states = torch.load(chp_path, map_location='cpu')['weights']
ch = 64
# Positional args forwarded unchanged; assumed (in_ch, out_ch, features) -- TODO confirm.
scorenet = CondRefineNetDilated(2, 2, ch).cuda()
scorenet.load_state_dict(states)
scorenet.eval();

In [None]:
from src.utils_IDPCNN import Denoisers

# Directory holding the DnCNN-style denoiser checkpoints ("F128" presumably = 128 features).
chp_dir = './checkpoints/fastMRI_dncnn_c1_mse_F128/'
# Positional constructor arguments, passed through unchanged; presumably
# (in_ch, out_ch, depth, features) -- TODO confirm against Denoisers' signature.
denoiser_args = (2, 2, 17, 128)
cnn_denoisers = Denoisers(chp_dir, *denoiser_args, bias=True)

In [None]:
# Random Cartesian undersampling mask (RandomMaskFunc is imported in the first cell).
# fastMRI convention: center_fractions = fraction of fully-sampled low-frequency
# columns, accelerations = undersampling factor.
mask_func = RandomMaskFunc(
    center_fractions=[0.04],
    accelerations=[8],
)

file_ = './data/file1000243_17.mat'  # alternative: file_brain_AXT2_209_6001069_2.mat
mat_data = loadmat(file_)  # parse the .mat file once instead of once per variable
# Move the last stored axis to the front (presumably coils first).
target_np = mat_data['target'].transpose(2, 0, 1)
sense_map_espirit = mat_data['sense_map'].transpose(2, 0, 1)
# Axes 1 and 2 are rows, cols -> (h, w), matching `c, h, w, _ = data_kspace.shape` later.
_, h, w = target_np.shape
target_tensor, target_rss, sense_map_tensor = process_MRI_data(target_np, sense_map_espirit)
sense_map_tensor = sense_map_tensor.cuda().squeeze(0)

# Sanity print: file name, std of rows 20:60 of the RSS image, and tensor shape.
print(file_, target_rss[20:60].std(), target_tensor.shape)

In [None]:
# Crop size passed to crop_image / compare_psnr_ssim_tensor (display and metrics).
crop_win = 320
# Multi-coil k-space of the fully-sampled target; squeeze(0) drops a leading singleton axis.
data_kspace = fft2c(target_tensor).cuda().squeeze(0)
# Fixed seed (1234) so the undersampling pattern is identical on every run.
mask = mask_func(data_kspace.shape, 1234).cuda().byte()
# Repeat the mask h times along dim 1 (h = data_kspace.shape[-3]); the multiply
# below then broadcasts it against data_kspace.
mask = mask.repeat(1, data_kspace.shape[-3], 1, 1)
data_kspace = data_kspace * mask
c, h, w, _ = data_kspace.shape

# Reference RSS magnitude image, shown in grayscale.
plt.imshow(crop_image(target_rss, crop_win), cmap='gray')

# Zero-filled baseline: inverse FFT of the undersampled k-space, scored against the target.
PSNR, _, rec_im = compare_psnr_ssim_tensor(ifft2c(data_kspace), target_rss, crop_win, coil_dim=0)
print('zero-filled PSNR:', PSNR)

# cv2.imwrite signals failure (e.g. unwritable path) by returning False, not by raising.
ok_target = cv2.imwrite('target.png', 255 * crop_image(target_rss, crop_win))
ok_zf = cv2.imwrite('./zf.png', 255 * rec_im)
assert ok_target and ok_zf, 'cv2.imwrite failed to write target.png / zf.png'

In [None]:
from src.utils_rSGM import RSGM

# Regularizer switches forwarded to RSGM (reg_self is passed as reg_similar).
reg_cnn = True
reg_self = True
reg_diffusion = True

recon_module = RSGM(
    scorenet=scorenet,
    cnn_denoisers=cnn_denoisers,
    reg_diffusion=reg_diffusion,
    reg_cnn=reg_cnn,
    reg_similar=reg_self,
    reg_kernel_size=7,
    # Reuse the crop size defined with the data so solver, display and metrics agree.
    crop_win=crop_win,
    self_learn_K_iter=5,
    K_step=1e-3,
    # Solver hyper-parameters; their exact roles are defined inside RSGM.
    lam=3e-4,
    rho=3.3e-3,
    gamma=1.06,
    delta=1,
)

In [None]:
# Seed torch's RNGs so any torch-based randomness in the solver repeats run-to-run.
torch.manual_seed(1234)

# Keyword options for the RSGM call; `normlize_input` (sic) is RSGM's own spelling.
solver_opts = dict(
    sense_map=sense_map_tensor,
    target_rss=target_rss,  # reference image; presumably used for the returned psnrs trace
    normlize_input=False,
    max_iter=110,
    K_gd_iter=5,
    verbose=True,
)
rec_im, psnrs = recon_module(data_kspace, mask, **solver_opts)

In [None]:
# Score the final reconstruction with the same crop size used for the zero-filled baseline.
PSNR, _, rec_im_ = compare_psnr_ssim_tensor(rec_im, target_rss, crop_win=crop_win, coil_dim=0)
print(PSNR)
# Grayscale display; trailing ';' suppresses the AxesImage repr in the cell output.
plt.imshow(rec_im_, cmap='gray');