In [None]:
# This code is for prospective high b-value DWI denoising using DnCNN
# Code dependencies: Python 3.7, TensorFlow and Keras
# Reference article: https://ieeexplore.ieee.org/abstract/document/7839189
##
!pip install bm3d
!pip install pydicom
!pip install medpy
##
import os
import cv2
import logging
import time 
import glob
import os.path
import models
import bm3d
import pydicom
import nibabel as nib
import numpy as np
import pandas as pd
import scipy
from scipy import stats 
from scipy.stats import rice
from scipy import ndimage
#from keras import backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from sklearn.model_selection import train_test_split
#import imquality.brisque as brisque
import skimage.measure    
from keras.callbacks import CSVLogger, ModelCheckpoint, LearningRateScheduler
from keras.callbacks import EarlyStopping
from keras.models import load_model
from keras.optimizers import Adam
from skimage.measure import compare_psnr, compare_ssim
#from skimage.metrics import structural_similarity as ssim
from medpy.filter.smoothing import anisotropic_diffusion   
from skimage import data
from math import log10, sqrt
from skimage.filters import unsharp_mask
from skimage import filters
from skimage import restoration
from skimage.filters import threshold_local,threshold_yen

 


In [None]:
# Parameters
class Args:
    """Run configuration for this notebook, used as a plain attribute namespace."""

    model = 'DnCNN'  # architecture built when no pretrained checkpoint is given
    batch_size = 128
    test_dir = '../prospective_study/testing_data/'
    pretrain = None  # path to a saved model; None builds a fresh `model`
    only_test = False  # True skips snapshot/log setup and reuses the checkpoint's folder


args = Args()

# PSNR function definition
def PSNR(original, denoised, mask, max_pixel=1.0):
    """Masked peak signal-to-noise ratio, in dB.

    Both images are multiplied by ``mask`` before comparison, so pixels outside
    the mask contribute zero error. The mean squared error is still averaged
    over *all* pixels, not only those inside the mask.

    Args:
        original: reference image.
        denoised: image to score against ``original``.
        mask: array broadcastable to the images (e.g. a binary region mask).
        max_pixel: peak signal value; the default 1.0 suits images normalised
            to [0, 1].

    Returns:
        PSNR in dB. Returns 0 when the masked images are identical (MSE == 0).
        This sentinel is kept for compatibility; mathematically PSNR is
        infinite in that case.
    """
    original_mask = mask * original
    denoised_mask = mask * denoised
    mse = np.mean((original_mask - denoised_mask) ** 2)
    if mse == 0:
        # No difference at all: PSNR is undefined/infinite, report 0 as before.
        return 0
    return 20 * log10(max_pixel / sqrt(mse))


In [None]:
# This section saves the run's information log in a snapshot directory
if not args.only_test:
    # NOTE(review): Args defines no `sigma` attribute, so fall back to 'NA' instead of
    # raising AttributeError; add `sigma` to Args to restore the NEX tag in the name.
    nex_tag = getattr(args, 'sigma', 'NA')
    save_dir = ('./snapshot/save_' + args.model + '_' + 'nex' + str(nex_tag) + '_'
                + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) + '/')
    # makedirs also creates ./snapshot on a fresh checkout (os.mkdir would fail there).
    os.makedirs(save_dir, exist_ok=True)

    # log to <save_dir>/info.log ...
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        filename=save_dir + 'info.log',
                        filemode='w')
    # ... and to the notebook output. Only add the console handler once, so re-running
    # this cell does not duplicate every log line.
    root_logger = logging.getLogger('')
    if not any(type(handler) is logging.StreamHandler for handler in root_logger.handlers):
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(name)-6s: %(levelname)-6s %(message)s')
        console.setFormatter(formatter)
        root_logger.addHandler(console)

    # Log the actual parameter values rather than the object's default repr.
    logging.info({key: getattr(args, key) for key in dir(args) if not key.startswith('_')})

else:
    # Test-only runs write their outputs next to the pretrained checkpoint.
    if not args.pretrain:
        raise ValueError('args.only_test is True but args.pretrain is not set; '
                         'point it at a saved model so its directory can be reused.')
    save_dir = '/'.join(args.pretrain.split('/')[:-1]) + '/'

root  : INFO   <__main__.Args object at 0x7fb6e2a425d0>


In [None]:
# Read NEX 4 gold standard data
x_sub1_4=[]

for dirpath, dirnames, filenames in os.walk('../prospective_study/data/'):
    for filename in [f for f in filenames if f.endswith("b2000_avg4.nii")]:
        img_path=os.path.join(dirpath, filename)
        img = nib.load(os.path.join(dirpath, filename))
        
        #read data
        x_test_nifti = np.array(img.get_data())     
        slice_total=x_test_nifti.shape[2]
        for num_slice in range(slice_total):
            x_test=x_test_nifti[:,:,num_slice]
            x_test=np.rot90(x_test,3)
            x_test=np.fliplr(x_test)
            x_test=x_test/np.max(x_test)
            x_test=x_test.astype('float32')
            x_sub1_4.append(x_test)                    
            x_sub4=np.asarray(x_sub1_4)
            
        x_sub4=np.transpose(x_sub4)
        print(x_sub4.shape)
        %matplotlib inline
        plt.imshow(x_sub4[:,:,69],cmap='gray')

    
         

In [None]:

# model selection
if args.pretrain:
    # pretrained checkpoint: load weights/architecture as saved, compile below
    model = load_model(args.pretrain, compile=False)
elif args.model == 'DnCNN':
    model = models.DnCNN()
else:
    # Without this branch `model` would silently stay undefined (or stale from a prior run).
    raise ValueError('Unknown model {!r}; only DnCNN is supported.'.format(args.model))

# compile the model
model.compile(optimizer='Adam', loss='mse')
#print(model.summary())

# The learning-rate schedule is only needed for training. `step_decay` is not defined
# anywhere in this notebook, so build the callback only when such a function exists.
if 'step_decay' in globals():
    lr = LearningRateScheduler(step_decay)


In [None]:
# Run this cell to load the model saved from the retrospective study.
# It replaces any `model` built by the previous cell; `loaded_model` refers to the same object.
model = loaded_model = tf.keras.models.load_model('../prospective_study/model_avg2_final_23/')



In [None]:
# Testing for all the slices of one subject at a time
# Run this cell if denoising prospective DWI data
#
# For every b2000_avg2 volume under ../prospective_study/data/:
#   1. normalise each slice by its own maximum,
#   2. denoise it with `model`,
#   3. add back SHARPEN_WEIGHT of the residual the network removed,
#   4. save the denoised volume (same shape and affine as the input) into `out_dir`,
#   5. score input and output against the NEX-4 reference `x_sub4` (PSNR / SSIM).
# Requires `args`, `save_dir`, `model`, `x_sub4` and `PSNR` from earlier cells.
# NOTE(review): PSNR expects a binary `mask` shaped like `x_sub4`; no visible cell defines
# one, so when it is absent PSNR is computed over the whole slice.

SHARPEN_WEIGHT = 0.9  # fraction of (input - prediction) added back to the prediction

print('Start to test on {}'.format(args.test_dir))
# test_dir ends with '/', so take the last *non-empty* path component as the folder name.
out_dir = os.path.join(save_dir, os.path.basename(os.path.normpath(args.test_dir))) + '/'
os.makedirs(out_dir, exist_ok=True)

im_data = []
name = []
psnr_denoised = []
psnr_noisy = []
ssim_denoised = []
ssim_noisy = []

eval_mask = globals().get('mask')
if eval_mask is None:
    print('No `mask` defined - computing PSNR over the whole slice.')
    eval_mask = np.ones_like(x_sub4)

# Read test images
for dirpath, dirnames, filenames in os.walk('../prospective_study/data/'):
    #for filename in [f for f in filenames if f.endswith("b2000_avg1.nii")]:
    for filename in [f for f in filenames if f.endswith("b2000_avg2.nii")]:
        img_path = os.path.join(dirpath, filename)
        img = nib.load(img_path)
        img_affine = img.affine
        x_test_nifti = np.asarray(img.get_fdata())  # get_data() is removed in nibabel >= 5
        pix_max = np.max(x_test_nifti)
        print(pix_max)
        im_data.append(x_test_nifti)
        slice_total = x_test_nifti.shape[2]

        denoised_slices = []
        for num_slice in range(slice_total):
            # read image: normalise the slice to [0, 1] by its own maximum
            x_test_clean = x_test_nifti[:, :, num_slice].astype('float32')
            slice_max = np.max(x_test_clean)
            if slice_max > 0:  # all-zero slices would otherwise become NaN (0/0)
                x_test_clean = x_test_clean / slice_max
            x_test_noisy = x_test_clean  # the acquired low-NEX slice is the network input

            # predict
            x_test = x_test_noisy.reshape(1, x_test_noisy.shape[0], x_test_noisy.shape[1], 1)
            y_predict = model.predict(x_test).reshape(x_test_clean.shape)
            y_predict = np.clip(y_predict, np.min(x_test_clean), np.max(x_test_clean))

            # Sharpening: unsharp masking to recover medium-level contrast details
            un_msk = x_test_clean - y_predict
            y_predict = y_predict + SHARPEN_WEIGHT * un_msk
            denoised_slices.append(y_predict)

            # Image quality metrics against the NEX-4 reference slice
            reference = x_sub4[:, :, num_slice]
            slice_mask = eval_mask[:, :, num_slice]
            psnr_noisy.append(PSNR(reference, x_test_noisy, slice_mask))
            psnr_denoised.append(PSNR(reference, y_predict, slice_mask))
            # Images are normalised to [0, 1]; without data_range skimage assumes the
            # float dtype range [-1, 1].
            ssim_noisy.append(compare_ssim(reference, x_test_noisy, data_range=1.0))
            ssim_denoised.append(compare_ssim(reference, y_predict, data_range=1.0))

        # Save the denoised volume in the output directory, once per input file. Stacking
        # along the last axis keeps the input's (x, y, slice) layout, so the original
        # affine still describes the geometry.
        denoised_nifti = np.stack(denoised_slices, axis=-1)
        denoised_img = nib.Nifti1Image(denoised_nifti, affine=img_affine)
        stem, ext = os.path.splitext(filename)
        nib.save(denoised_img, os.path.join(out_dir, stem + '_denoised' + ext))

# Save the quality metrics (one row per slice, across all processed volumes)
pd.DataFrame({'psnr_sub_denoised': np.array(psnr_denoised),
              'ssim_sub_denoised': np.array(ssim_denoised),
              'psnr_sub_original': np.array(psnr_noisy),
              'ssim_sub_original': np.array(ssim_noisy),
              }).to_csv(os.path.join(out_dir, 'metrics_sub.csv'), index=True)
           
              

In [None]:
# Testing for the specified slice and for one subject at a time
# Run this cell if denoising prospective DWI data
#
# Choose the acquisition with TARGET_SUFFIX; available files follow the pattern
#   b3000_avg{1,2,4}_full_128_sub{1,2,3}.nii
# Slice TEST_SLICE of every matching file is denoised and appended to `denoised_nifti`.
TARGET_SUFFIX = "b3000_avg1_full_128_sub1.nii"
TEST_SLICE = 69
SHARPEN_WEIGHT = 0.9  # fraction of (input - prediction) added back to the prediction

im_data = []
name = []
# metric accumulators (not filled by this cell)
psnr = []
ssim = []
entropy_dn = []
psnr_bm = []
psnr_bl = []
psnr_ad = []
psnr_tv = []
ssim_bm = []
ssim_bl = []
ssim_ad = []
ssim_tv = []
entropy_bl = []
entropy_bm = []
entropy_ad = []
entropy_tv = []
count = 0
denoised_nifti = []

# Read test images
for dirpath, dirnames, filenames in os.walk('../prospective_study/data/'):
    for filename in [f for f in filenames if f.endswith(TARGET_SUFFIX)]:
        img_path = os.path.join(dirpath, filename)
        img = nib.load(img_path)
        img_affine = img.affine
        x_test_nifti = np.asarray(img.get_fdata())  # get_data() is removed in nibabel >= 5
        im_data.append(x_test_nifti)

        # read image: normalise the chosen slice to [0, 1] by its own maximum
        x_test_clean = x_test_nifti[:, :, TEST_SLICE].astype('float32')
        slice_max = np.max(x_test_clean)
        if slice_max > 0:  # an all-zero slice would otherwise become NaN (0/0)
            x_test_clean = x_test_clean / slice_max
        x_test_noisy = x_test_clean  # the acquired slice is the network input

        # predict
        x_test = x_test_noisy.reshape(1, x_test_noisy.shape[0], x_test_noisy.shape[1], 1)
        y_predict = model.predict(x_test).reshape(x_test_clean.shape)
        y_predict = np.clip(y_predict, np.min(x_test_clean), np.max(x_test_clean))

        # Sharpening: unsharp masking to recover medium-level contrast details
        un_msk = x_test_clean - y_predict
        y_predict = y_predict + SHARPEN_WEIGHT * un_msk
        denoised_nifti.append(y_predict)
        
        

      

In [None]:
# Display images
%matplotlib inline
fig = plt.figure()  
ax = plt.gca()
ax.axes.xaxis.set_ticks([])
ax.axes.yaxis.set_ticks([])       
imgplot = plt.imshow(x_test_noisy,cmap='gray')                                         
fig.suptitle('X_test_DnCNN') 
fig.colorbar(imgplot, ax=ax)

fig = plt.figure() 
ax = plt.gca()
ax.axes.xaxis.set_ticks([])
ax.axes.yaxis.set_ticks([])                
imgplot = plt.imshow(y_predict,cmap='gray',vmax=1,vmin=0)        
fig.suptitle('Y_predict_DnCNN') 
fig.colorbar(imgplot, ax=ax)

fig = plt.figure() 
ax = plt.gca()
ax.axes.xaxis.set_ticks([])
ax.axes.yaxis.set_ticks([])                
imgplot = plt.imshow(y_predict-x_test_clean,cmap='gray',vmax=0, vmin=-0.1)        
fig.suptitle('diffmap_denoised-original') 
fig.colorbar(imgplot, ax=ax)

fig = plt.figure() 
ax = plt.gca()
ax.axes.xaxis.set_ticks([])
ax.axes.yaxis.set_ticks([]) 
plt.imshow(x_sub10[:,:,69]-x_test_noisy,cmap='gray')     
fig.suptitle('GT-original') 
fig.colorbar(imgplot, ax=ax)


fig = plt.figure() 
ax = plt.gca()
ax.axes.xaxis.set_ticks([])
ax.axes.yaxis.set_ticks([]) 
plt.imshow(x_sub10[:,:,69]-y_predict,cmap='gray')     
fig.suptitle('GT-denoised') 
fig.colorbar(imgplot, ax=ax)


