In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import os
import glob
from time import time
import math
from torch.nn import init
import copy
import cv2
from skimage.metrics import structural_similarity as ssim
#Time module
from datetime import datetime
import os 

## Input parameters

In [24]:
# ---------------------------------------------------------------------------
# Experiment configuration (read by the cells below when building file paths)
# ---------------------------------------------------------------------------
cs_ratio, epoch_num = 50, 13      # sampling-ratio tag / checkpoint epoch used in file names
learning_rate = 1e-4

# Network layout tags (also part of the checkpoint folder name)
layer_num, group_num = 9, 1

# Sampling mask location; mask_type: q1 for DLMRi and od for outer dense
mask_dir, mask_type = 'mask_dir', 'q1'

In [25]:
# GPU selection: expose only the listed device(s), enumerated in PCI bus order.
gpu_list = '0'
img_size = 256
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": gpu_list,
})
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Sampling mask: read the 'mask_matrix' entry of the .mat file selected by the config cell.
Phi_data_Name = './%s/rand_%s_%d.mat' % (mask_dir, mask_type, cs_ratio)
Phi_data = sio.loadmat(Phi_data_Name)
mask_matrix = torch.from_numpy(Phi_data['mask_matrix']).type(torch.FloatTensor)

# Duplicate the mask along a new axis 2 so it can multiply a tensor whose
# axis 2 holds two components (real and imaginary in FFT_Mask_ForBack).
mask = torch.stack((mask_matrix, mask_matrix), dim=2).to(device)

In [26]:
class FFT_Mask_ForBack(torch.nn.Module):
    """Masked Fourier round trip: real part of IFFT2(mask * FFT2(x)).

    The function-style ``torch.fft(z, 2)`` / ``torch.ifft(z, 2)`` calls this
    module was written with were removed in PyTorch 1.8, where ``torch.fft``
    is a module and calling it raises ``TypeError``.  ``forward`` therefore
    uses ``torch.fft.fft2`` / ``torch.fft.ifft2`` when they exist and falls
    back to the legacy calls otherwise.  Both paths leave the forward
    transform unscaled and scale the inverse by 1/(H*W).
    """

    def __init__(self):
        super(FFT_Mask_ForBack, self).__init__()

    def forward(self, x, mask):
        """Apply the masked FFT round trip to every 2-D plane of ``x``.

        Args:
            x: real 4-D tensor; the transform runs over its last two axes.
            mask: tensor that broadcasts against ``(planes, H, W, 2)``, where
                the trailing axis holds the (real, imaginary) components of
                the spectrum.  The notebook passes an ``(H, W, 2)`` mask.

        Returns:
            Real tensor with the same 4-D shape as ``x``.
        """
        x_dim_0 = x.shape[0]
        x_dim_1 = x.shape[1]
        x_dim_2 = x.shape[2]
        x_dim_3 = x.shape[3]
        planes = x.reshape(-1, x_dim_2, x_dim_3)

        if hasattr(torch.fft, "fft2"):
            # Module-style API.  Go through the (..., 2) real view so the mask
            # multiplies real and imaginary parts exactly as the legacy code did.
            spectrum = torch.view_as_real(torch.fft.fft2(planes))
            masked = (spectrum * mask).contiguous()
            real_part = torch.fft.ifft2(torch.view_as_complex(masked)).real
        else:
            # Legacy API: complex numbers are a trailing axis of size 2, with
            # a zero imaginary part for the real input.
            z = torch.stack([planes, torch.zeros_like(planes)], dim=3)
            fftz = torch.fft(z, 2)
            z_hat = torch.ifft(fftz * mask, 2)
            real_part = z_hat[:, :, :, 0]

        return real_part.reshape(x_dim_0, x_dim_1, x_dim_2, x_dim_3)
# Define ISTA-Net-plus Block
class BasicBlock(torch.nn.Module):
    def __init__(self):
        super(BasicBlock, self).__init__()
        self.lambda_step = nn.Parameter(torch.Tensor([0.5]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.01]))
        ################
        self.W1 = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 1, 3, 3)))
        self.conv1_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv2_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv1_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv2_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(32, 32, 3, 3)))
        self.conv_G = nn.Parameter(init.xavier_normal_(torch.Tensor(1, 32, 3, 3)))

        self.c1 = nn.Parameter(torch.ones(1,1,1,1), requires_grad=True)
        self.c2 = nn.Parameter(torch.ones(1,1,1,1), requires_grad=True)
        self.c3 = nn.Parameter(torch.ones(1,1,1,1), requires_grad=True)        
        self.b1 = nn.Parameter(torch.zeros(1,32,1,1), requires_grad=True)
        self.b2 = nn.Parameter(torch.zeros(1,32,1,1), requires_grad=True)
        self.b3 = nn.Parameter(torch.zeros(1,32,1,1), requires_grad=True)
    def forward(self, x, fft_forback, PhiTb, mask):
        x = x - self.lambda_step * fft_forback(x, mask)
        x = x + self.lambda_step * PhiTb
        x_input = x
        gamma1 = F.conv2d(x_input, self.W1, padding=1)
        gamma2 = self.c1*F.conv2d(gamma1, self.conv1_forward, padding=1)+self.b1
        gamma2 = F.relu(gamma2)
        gamma3 = F.conv2d(gamma2, self.conv2_forward, padding=1)
        for _ in  range(1):            
            # backward computation
            gamma2 = F.conv_transpose2d(gamma3,self.conv2_forward,padding = 1)
            gamma1 = F.conv_transpose2d(gamma2,self.conv1_forward,padding = 1)            
            # forward computation
            gamma1 = F.relu( (gamma1 - self.c1 * F.conv2d( F.conv_transpose2d(gamma1,self.W1,padding=1) - x ,self.W1,padding=1)) + self.b1)
            gamma2 = F.relu( (gamma2 - self.c2 * F.conv2d( F.conv_transpose2d(gamma2,self.conv1_forward,padding=1) - gamma1, self.conv1_forward,padding=1)) + self.b2) 
            gamma3 = F.relu( (gamma3 - self.c3 * F.conv2d( F.conv_transpose2d(gamma3,self.conv2_forward,padding=1) - gamma2, self.conv2_forward,padding=1)) + self.b3) 
        gammaE = torch.mul(torch.sign(gamma3), F.relu(torch.abs(gamma3) - self.soft_thr))
        gamma4 = F.conv2d(gammaE, self.conv1_backward, padding=1)
        gamma4 = F.relu(gamma4)
        gamma5 = F.conv2d(gamma4, self.conv2_backward, padding=1)
        gamma6 = F.conv2d(gamma5, self.conv_G, padding=1)
        x_pred = x_input + gamma6
        x = F.conv2d(gamma3, self.conv1_backward, padding=1)
        x = F.relu(x)
        x_D_est = F.conv2d(x, self.conv2_backward, padding=1)
        symloss = x_D_est - gamma1
        return [x_pred, symloss]
        #############
# Define ISTA-Net-plus
class ISTANetplus(torch.nn.Module):
    """Stack of ``LayerNo`` ``BasicBlock`` modules sharing one ``FFT_Mask_ForBack`` operator."""

    def __init__(self, LayerNo):
        super(ISTANetplus, self).__init__()
        self.LayerNo = LayerNo
        self.fft_forback = FFT_Mask_ForBack()
        self.fcs = nn.ModuleList([BasicBlock() for _ in range(LayerNo)])

    def forward(self, PhiTb, mask):
        """Run every block in sequence, starting from ``PhiTb``.

        Returns:
            ``[x_final, layers_sym]``: the output of the last block and the
            list of per-block symmetry terms (used for the symmetric loss).
        """
        estimate = PhiTb
        layers_sym = []
        for idx in range(self.LayerNo):
            estimate, sym_term = self.fcs[idx](estimate, self.fft_forback, PhiTb, mask)
            layers_sym.append(sym_term)
        return [estimate, layers_sym]
# Build the network, wrap it in DataParallel, and move it to the selected device.
model = nn.DataParallel(ISTANetplus(layer_num)).to(device)

In [27]:
# IPython automagic (%pwd): display the kernel's current working directory,
# which is where the relative paths used in this notebook are resolved from.
pwd

'd:\\python_dir\\knee_csmri'

## Define file locations

In [28]:
# Root folders, relative to the working directory.
# `model_root` is kept separate from the derived `model_dir` so that re-running
# this cell does not nest the checkpoint path inside itself.
model_root = 'model_dir_knee'
log_dir = 'log_dir_knee'
data_dir = './test_data/knee_test_256'
result_dir = './result_dir_knee'
test_name1 = 'brainMR_test'
test_name2 = 'kneeMR_test'
#test_dir = os.path.join(data_dir, test_name)

# Test images. glob returns files in arbitrary order, so sort them to make the
# image index and the per-image logs reproducible between runs.
filepaths = sorted(glob.glob(data_dir + '/*.jpg'))
#result_dir = os.path.join(result_dir, test_name2)

model_dir = "./%s/MRI_CS_ISTA_Net_plus_layer_%d_group_%d_ratio_%d" % (model_root, layer_num, group_num, cs_ratio)

# Load pre-trained model with epoch number.
# map_location lets a checkpoint saved from a GPU load when `device` is the CPU.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('%s/net_params_%d.pkl' % (model_dir, epoch_num), map_location=device))
def psnr(img1, img2):
    """Peak signal-to-noise ratio in dB between two images on a 0-255 scale.

    Args:
        img1, img2: array-likes of the same shape, of any numeric dtype.

    Returns:
        ``20 * log10(255 / sqrt(MSE))``, or 100 when the images are identical.
    """
    # The conversions must be assigned: `astype` returns a new array, and
    # without them integer inputs (e.g. uint8) wrap around in the subtraction
    # and the square.  float64 keeps results identical for float64 callers.
    img1 = np.asarray(img1).astype(np.float64)
    img2 = np.asarray(img2).astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 100
    PIXEL_MAX = 255.0
    return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
# Note: result_dir is not created in this cell.
ImgNum = len(filepaths)

# One (1, ImgNum) float32 row per metric: reconstructed and initial PSNR/SSIM.
PSNR_All, SSIM_All, Init_PSNR_All, Init_SSIM_All = (
    np.zeros([1, ImgNum], dtype=np.float32) for _ in range(4)
)
print('Images loaded=:', ImgNum)

Images loaded=: 22


In [29]:
print('\n')
print("MRI CS Reconstruction Start")
# Kept only so the `optimizer` name this cell has always defined remains
# available; nothing in the inference loop below uses it.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# cv2.imwrite returns False (no exception) when the target folder is missing,
# and open(..., 'a') raises for a missing folder, so create both up front.
os.makedirs(result_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

start_time = datetime.now()
with torch.no_grad():
    for img_no in range(ImgNum):
        imgName = filepaths[img_no]
        Iorg = cv2.imread(imgName, 0)  # flag 0: read as single-channel grayscale
        if Iorg is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError("Could not read image file: %s" % imgName)
        img_h, img_w = Iorg.shape
        Icol = Iorg.reshape(1, 1, img_h, img_w) / 255.0
        Img_output = Icol

        # CUDA work is queued asynchronously; synchronise around the timed
        # region so the reported time covers the actual computation.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time()
        batch_x = torch.from_numpy(Img_output).type(torch.FloatTensor).to(device)
        PhiTb = FFT_Mask_ForBack()(batch_x, mask)
        [x_output, loss_layers_sym] = model(PhiTb, mask)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        end = time()

        # Back to numpy, clipped to [0, 1]; metrics are computed on a 0-255 scale.
        initial_result = PhiTb.detach().cpu().numpy().reshape(img_h, img_w)
        Prediction_value = x_output.detach().cpu().numpy().reshape(img_h, img_w)
        X_init = np.clip(initial_result, 0, 1).astype(np.float64)
        X_rec = np.clip(Prediction_value, 0, 1).astype(np.float64)
        Iref = Iorg.astype(np.float64)
        init_PSNR = psnr(X_init * 255, Iref)
        init_SSIM = ssim(X_init * 255, Iref, data_range=255)
        rec_PSNR = psnr(X_rec * 255., Iref)
        rec_SSIM = ssim(X_rec * 255., Iref, data_range=255)
        print("[%02d/%02d]time for %s is %.4f, Proposed PSNR is %.2f, Proposed SSIM is %.4f" % (img_no, ImgNum, imgName, (end - start), rec_PSNR, rec_SSIM))

        # Save the reconstruction next to the other results, keeping the input file name.
        im_rec_rgb = np.clip(X_rec * 255, 0, 255).astype(np.uint8)
        resultName = os.path.join(result_dir, os.path.basename(imgName))
        out_path = "%s_ISTA_Net_plus_ratio_%d_epoch_%d_PSNR_%.2f_SSIM_%.4f.bmp" % (resultName, cs_ratio, epoch_num, rec_PSNR, rec_SSIM)
        if not cv2.imwrite(out_path, im_rec_rgb):
            print("WARNING: could not write %s" % out_path)
        del x_output

        PSNR_All[0, img_no] = rec_PSNR
        SSIM_All[0, img_no] = rec_SSIM

        Init_PSNR_All[0, img_no] = init_PSNR
        Init_SSIM_All[0, img_no] = init_SSIM

print('\n')
# The images come from the knee test folder, so label the summary with the knee
# test name (it was previously labelled with the brain test name).
init_data =   "CS ratio is %d, Avg Initial  PSNR/SSIM for %s is %.2f/%.4f" % (cs_ratio, test_name2, np.mean(Init_PSNR_All), np.mean(Init_SSIM_All))
output_data = "CS ratio is %d, Avg Proposed PSNR/SSIM for %s is %.2f/%.4f, Epoch number of model is %d \n" % (cs_ratio, test_name2, np.mean(PSNR_All), np.mean(SSIM_All), epoch_num)
print(init_data)
print(output_data)

# Append the summary line to the run log; `with` closes the file even on error.
output_file_name = "./%s/PSNR_SSIM_Results_MRI_CS_ISTA_Net_plus_layer_%d_group_%d_ratio_%d.txt" % (log_dir, layer_num, group_num, cs_ratio)
with open(output_file_name, 'a') as output_file:
    output_file.write(output_data)

end_time = datetime.now()
print('Duration: {}'.format(end_time - start_time))
print("MRI CS Reconstruction End")



MRI CS Reconstruction Start
[00/22]time for ./test_data/knee_test_256\c122.jpg is 0.0209, Proposed PSNR is 44.19, Proposed SSIM is 0.9844
[01/22]time for ./test_data/knee_test_256\c136.jpg is 0.0209, Proposed PSNR is 43.89, Proposed SSIM is 0.9837
[02/22]time for ./test_data/knee_test_256\c168.jpg is 0.0209, Proposed PSNR is 42.79, Proposed SSIM is 0.9817
[03/22]time for ./test_data/knee_test_256\c176.jpg is 0.0209, Proposed PSNR is 40.82, Proposed SSIM is 0.9756
[04/22]time for ./test_data/knee_test_256\c178.jpg is 0.0219, Proposed PSNR is 43.63, Proposed SSIM is 0.9820
[05/22]time for ./test_data/knee_test_256\c180.jpg is 0.0219, Proposed PSNR is 43.63, Proposed SSIM is 0.9820
[06/22]time for ./test_data/knee_test_256\c183.jpg is 0.0219, Proposed PSNR is 43.46, Proposed SSIM is 0.9828
[07/22]time for ./test_data/knee_test_256\c196.jpg is 0.0229, Proposed PSNR is 43.56, Proposed SSIM is 0.9814
[08/22]time for ./test_data/knee_test_256\n110.jpg is 0.0229, Proposed PSNR is 43.41, Prop

## Save per-image PSNR/SSIM values to text files

In [None]:
print('\n')
result_dir='result_dir'
print("MRI CS Reconstruction Start")
# Kept only so the `optimizer` name this cell has always defined remains
# available; nothing in the inference loop below uses it.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# cv2.imwrite returns False (no exception) when the target folder is missing,
# and open(..., 'a') raises for a missing folder, so create both up front.
os.makedirs(result_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# Per-image metric logs: one value per line, appended across runs.
output_file_name = "./%s/psnr_knee_%d.txt" % (log_dir, cs_ratio)
output_file_name2 = "./%s/ssim_knee_%d.txt" % (log_dir, cs_ratio)

start_time = datetime.now()
# Open both logs once for the whole loop instead of once per image; the
# context manager closes them even if an iteration raises.
with torch.no_grad(), open(output_file_name, 'a') as output_file, open(output_file_name2, 'a') as output_file2:
    for img_no in range(ImgNum):
        imgName = filepaths[img_no]
        Iorg = cv2.imread(imgName, 0)  # flag 0: read as single-channel grayscale
        if Iorg is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError("Could not read image file: %s" % imgName)
        img_h, img_w = Iorg.shape
        Icol = Iorg.reshape(1, 1, img_h, img_w) / 255.0
        Img_output = Icol

        # CUDA work is queued asynchronously; synchronise around the timed
        # region so the reported time covers the actual computation.
        if device.type == 'cuda':
            torch.cuda.synchronize()
        start = time()
        batch_x = torch.from_numpy(Img_output).type(torch.FloatTensor).to(device)
        PhiTb = FFT_Mask_ForBack()(batch_x, mask)
        [x_output, loss_layers_sym] = model(PhiTb, mask)
        if device.type == 'cuda':
            torch.cuda.synchronize()
        end = time()

        # Back to numpy, clipped to [0, 1]; metrics are computed on a 0-255 scale.
        initial_result = PhiTb.detach().cpu().numpy().reshape(img_h, img_w)
        Prediction_value = x_output.detach().cpu().numpy().reshape(img_h, img_w)
        X_init = np.clip(initial_result, 0, 1).astype(np.float64)
        X_rec = np.clip(Prediction_value, 0, 1).astype(np.float64)
        Iref = Iorg.astype(np.float64)
        init_PSNR = psnr(X_init * 255, Iref)
        init_SSIM = ssim(X_init * 255, Iref, data_range=255)
        rec_PSNR = psnr(X_rec * 255., Iref)
        rec_SSIM = ssim(X_rec * 255., Iref, data_range=255)
        print("[%02d/%02d] recons. time for %s is, %.4f, Proposed PSNR is, %.2f,Proposed SSIM is, %.4f" % (img_no, ImgNum, imgName, (end - start), rec_PSNR, rec_SSIM))

        # Save log
        output_data = "%.2f, \n" % (rec_PSNR)
        output_data2 = "%.4f, \n" % (rec_SSIM)
        output_file.write(output_data)
        output_file2.write(output_data2)

        # Save the reconstruction, keeping the input file name.
        im_rec_rgb = np.clip(X_rec * 255, 0, 255).astype(np.uint8)
        resultName = os.path.join(result_dir, os.path.basename(imgName))
        out_path = "%s_ISTA_Net_plus_ratio_%d_epoch_%d_PSNR_%.2f_SSIM_%.4f.bmp" % (resultName, cs_ratio, epoch_num, rec_PSNR, rec_SSIM)
        if not cv2.imwrite(out_path, im_rec_rgb):
            print("WARNING: could not write %s" % out_path)
        del x_output

        PSNR_All[0, img_no] = rec_PSNR
        SSIM_All[0, img_no] = rec_SSIM

        Init_PSNR_All[0, img_no] = init_PSNR
        Init_SSIM_All[0, img_no] = init_SSIM
print('\n')
print("MRI CS Reconstruction End")



MRI CS Reconstruction Start
[00/22] recons. time for knee_test_256/c168.jpg is, 0.0232, Proposed PSNR is, 36.39,Proposed SSIM is, 0.9262
[01/22] recons. time for knee_test_256/c122.jpg is, 0.0220, Proposed PSNR is, 37.62,Proposed SSIM is, 0.9365
[02/22] recons. time for knee_test_256/c136.jpg is, 0.0223, Proposed PSNR is, 38.08,Proposed SSIM is, 0.9358
[03/22] recons. time for knee_test_256/c183.jpg is, 0.0265, Proposed PSNR is, 36.80,Proposed SSIM is, 0.9273
[04/22] recons. time for knee_test_256/c180.jpg is, 0.0248, Proposed PSNR is, 37.15,Proposed SSIM is, 0.9255
[05/22] recons. time for knee_test_256/c176.jpg is, 0.0243, Proposed PSNR is, 34.37,Proposed SSIM is, 0.8945
[06/22] recons. time for knee_test_256/c178.jpg is, 0.0259, Proposed PSNR is, 37.15,Proposed SSIM is, 0.9255
[07/22] recons. time for knee_test_256/c196.jpg is, 0.0215, Proposed PSNR is, 37.32,Proposed SSIM is, 0.9246
[08/22] recons. time for knee_test_256/n142.jpg is, 0.0311, Proposed PSNR is, 36.12,Proposed SSIM 