In [1]:
import tensorflow as tf
import cProfile
import os
from os import walk
import matplotlib.pyplot as plt
import pydicom as dicom
import numpy as np
from skimage import data, img_as_float
from skimage import measure
import math
from tf_unet import unet, util, image_util

# Data Preparation and Undersampling

In [2]:
class Image:
    """One DICOM slice plus helpers to move it into k-space and undersample it.

    The pixel data comes from ``pydicom.dcmread(name).pixel_array``.
    """

    def __init__(self, name):
        """Read the DICOM file at path ``name``."""
        self.data_set = dicom.dcmread(name)

    def get_Image(self):
        """Return the pixel array of the image."""
        return self.data_set.pixel_array

    def get_Kspace(self, flip):
        """Return the 2-D FFT (k-space) of the image.

        If ``flip`` is truthy the spectrum is passed through ``np.fft.fftshift``
        so the zero-frequency component sits at the centre of the array.
        """
        arr = np.fft.fft2(self.get_Image())
        return np.fft.fftshift(arr) if flip else arr

    def skipEveryNLine(self, N, img):
        """Zero rows of ``img`` in place in a repeating "N zeroed, 1 kept" pattern.

        Starting at row 0, N consecutive rows are zeroed and the following row
        is left untouched; this repeats until the end of ``img``. (For N <= 0
        the counter test is always true, so every other row starting at row 0
        is zeroed.)

        ``img`` is expected to be a 2-D NumPy array or a view into one; it is
        modified in place and also returned.
        """
        data = img
        i = 0
        count = 0
        while i < len(data):
            data[i][:] = 0  # zero the whole row at once
            count += 1
            if count >= N:
                # N rows zeroed: keep (skip over) the next row
                count = 0
                i += 2
            else:
                i += 1
        return data

    def grappaUndersample(self, accelaration_factor, centerline_factor):
        """Undersample the fftshift-ed k-space outside a fully-sampled centre band.

        Rows ``upper_border .. lower_border`` (inclusive) around the centre row
        are kept intact, where the borders lie ``centerline_factor / 2`` rows
        either side of the centre. The rows above and below the band are
        undersampled with ``skipEveryNLine(accelaration_factor, ...)``, or zeroed
        entirely when ``accelaration_factor`` is 0.

        Returns complex k-space with the same shape as the image.
        """
        flipped_Img = self.get_Kspace(True)
        center_line = int(len(flipped_Img) / 2)
        # The band spans centerline_factor/2 rows on each side of the centre.
        half_band = centerline_factor / 2
        upper_border = int(center_line - half_band)
        lower_border = int(center_line + half_band)

        upper_region = flipped_Img[0:upper_border]
        lower_region = flipped_Img[lower_border + 1:]

        if accelaration_factor == 0:
            # Drop everything outside the centre band; zeros_like keeps the
            # exact shape of each region so the result has the original size.
            upSubsample = np.zeros_like(upper_region)
            downSubsample = np.zeros_like(lower_region)
        else:
            upSubsample = self.skipEveryNLine(accelaration_factor, upper_region)
            downSubsample = self.skipEveryNLine(accelaration_factor, lower_region)

        return np.concatenate(
            (upSubsample, flipped_Img[upper_border:lower_border + 1], downSubsample))

    def showImage(self, img, inFourierDomain, title):
        """Display ``img`` with the bone colormap, hidden axes and the given ``title``.

        When ``inFourierDomain`` is truthy the magnitude is log-compressed
        (log(1 + |img|)) so k-space structure is visible.
        """
        shown = np.absolute(img)
        if inFourierDomain:
            shown = np.log(1 + shown)
        plt.imshow(shown, cmap=plt.cm.bone)
        plt.title(title)
        plt.axis('off')
        plt.show()
    
    

In [3]:
def fourierTransform(img):
    """Return the unshifted 2-D discrete Fourier transform of ``img``."""
    return np.fft.fft2(img)

def inverseFourierTransform(img):
    """Return the 2-D inverse discrete Fourier transform of ``img`` (no ifftshift applied)."""
    return np.fft.ifft2(img)

# Import dataset

In [4]:
from os import walk
from os import listdir
from os.path import isfile, join
def importDataset(path):
    """Load every regular file directly inside ``path`` as an ``Image``.

    Files are loaded in sorted path order. ``os.listdir`` returns entries in
    arbitrary order, so sorting keeps dataset indices (and the
    train/validation split made later) reproducible between runs.
    Every file in the folder is expected to be a readable DICOM file.
    """
    dcm_files = sorted(
        join(path, f) for f in listdir(path) if isfile(join(path, f))
    )
    return [Image(f) for f in dcm_files]
def showImg(Img, inFourierDomain):
    """Display ``Img`` with the bone colormap, hidden axes and the title "Original Image".

    When ``inFourierDomain`` is truthy the magnitude is log-compressed
    (log(1 + |Img|)) so k-space structure is visible.
    """
    shown = np.absolute(Img)
    if inFourierDomain:
        shown = np.log(1 + shown)
    plt.imshow(shown, cmap=plt.cm.bone)
    # Title and axis are set separately: passing plt.axis(...) as the second
    # argument of plt.title would be interpreted as a fontdict.
    plt.title("Original Image")
    plt.axis('off')
    plt.show()

Import the dataset using `importDataset` (defined above). Set `mypath` to the path of the training dataset folder and `testPath` to the path of the testing folder. The function returns a list of `Image` objects.

In [7]:
mypath = './brain'
# testPath = 'F:/fastMRI_brain_DICOM/183268758055'

datasets = importDataset(mypath)
# test_datasets = importDataset(testPath)
if not datasets:
    raise ValueError("No files found in '{}'".format(mypath))

# Image size of the first slice; later cells read dataset_shape (rows) and
# dataset_column (columns) as the network input size.
first_image = datasets[0].get_Image()
dataset_shape = len(first_image)
dataset_column = len(first_image[0])
print(dataset_shape)

320


In [6]:
def loadImages(dataset, AF):
    """Build paired training arrays from a list of ``Image`` objects.

    For every image and every acceleration factor in ``AF``, the function
    pairs two arrays:
    - the fully-sampled magnitude image, and
    - the magnitude of ``inverseFourierTransform(image.grappaUndersample(af, 20))``.

    Returns ``(x_train, y_train)``. Note the naming: ``x_train`` holds the
    fully-sampled images and ``y_train`` the undersampled ones. Each has
    shape ``(len(dataset) * len(AF), H, W, 1)``, ordered image-major then
    AF-minor.

    Raises Exception if ``dataset`` or ``AF`` is empty, or if any AF lies
    outside 0..4.
    """
    #edge cases
    if not dataset:
        raise Exception("Input a valid dataset")

    if not AF:
        raise Exception("Enter an accelaration factor")

    if any((x > 4 or x < 0) for x in AF):
        raise Exception("The model only trains accelaration factors in the range of x where 0 <= x <= 4")

    # Collect per-sample arrays and stack once. Repeated np.vstack inside the
    # loop would copy the whole accumulated array on every iteration.
    originals = []
    undersampled = []
    for image in dataset:
        full = np.absolute(image.get_Image())
        for af in AF:
            originals.append(full)
            undersampled.append(
                np.absolute(inverseFourierTransform(image.grappaUndersample(af, 20))))

    # Add a trailing channel axis for the Conv2D input.
    x_train = np.stack(originals)[..., np.newaxis]
    y_train = np.stack(undersampled)[..., np.newaxis]
    return (x_train, y_train)
    

# UNet Implementation

In [7]:
from tensorflow.keras.models import Model
from keras.layers import Input, Dense, Activation, concatenate, UpSampling2D
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D, AveragePooling2D
from keras.losses import mean_squared_error, mean_absolute_error
from keras.optimizers import RMSprop
from keras.initializers import RandomNormal
from keras.callbacks import ModelCheckpoint


In [8]:
class uNet:
    """Small U-Net (two pooling levels) for reconstructing undersampled MRI images.

    Wraps a Keras functional model. The architecture is built lazily by
    ``train()`` (or explicitly via ``create_architecture()``). Alternatively,
    an existing model can be attached with ``set_model()``.
    """

    def __init__(self, activation, input_shape=None):
        """Create an untrained wrapper.

        activation: Keras activation used by every hidden Conv2D (e.g. 'relu').
        input_shape: optional (height, width, channels) of the network input.
            When None, create_architecture() uses the notebook globals
            ``dataset_shape`` / ``dataset_column`` with one channel.
        """
        self.hasArchitecture = False
        self.model = None
        self.activation = activation
        self.history = None
        self.kernels = (3, 3)
        self.strides = (1, 1)
        self.input_shape = input_shape

    def set_model(self, model):
        """Attach an already-built (e.g. loaded) Keras model."""
        self.model = model
        self.hasArchitecture = True
        # A loaded model may not carry a training History; tolerate that.
        self.history = getattr(model, "history", None)

    @staticmethod
    def _unique_model_dir(base_path):
        """Return ``base_path`` if unused, else the first unused ``base_path/v_<i>`` (i = 1, 2, ...)."""
        if not os.path.exists(base_path):
            return base_path
        version = 1
        candidate = os.path.join(base_path, "v_{}".format(version))
        while os.path.exists(candidate):
            version += 1
            candidate = os.path.join(base_path, "v_{}".format(version))
        return candidate

    def train(self, subsampled, original, batch_size, num_ephocs, checkpoint_output):
        """Fit the network to map ``subsampled`` inputs onto ``original`` targets.

        Checkpoints (one .hdf5 per epoch) and the final model are written to
        ``checkpoint_output/<model name>``. If that directory already exists,
        the first free ``v_<i>`` subdirectory inside it is used instead.
        20% of the samples are held out for validation. The Keras History is
        stored in ``self.history``.
        """
        if not self.hasArchitecture:
            self.create_architecture()

        model_path = self._unique_model_dir(
            os.path.join(checkpoint_output, self.model.name))
        # makedirs also creates checkpoint_output itself if it is missing.
        os.makedirs(model_path)
        cp_format = os.path.join(model_path, "unet-{epoch:02d}.hdf5")

        # Save a checkpoint at the end of every epoch.
        cp_callback = ModelCheckpoint(cp_format, save_freq='epoch')

        print("Network checkpoints will be saved to: '{}'".format(model_path))

        self.history = self.model.fit(
            subsampled,
            original,
            batch_size=batch_size,
            epochs=num_ephocs,
            shuffle=True,
            validation_split=.2,
            callbacks=[cp_callback]
        )

        self.model.save(model_path)

    def create_architecture(self):
        """Build, summarise and compile the U-Net, storing it in ``self.model``.

        Layout (all convolutions 'same'-padded with the configured kernel/strides):
          encoder:    Conv64, Conv64, MaxPool, Conv128, Conv128, MaxPool
          bottleneck: Conv256, Conv128
          decoder:    UpSample + concat(2nd Conv128), Conv128, Conv64,
                      UpSample + concat(2nd Conv64), Conv64, Conv64
          output:     Conv1 with no activation (linear)
        Compiled with RMSprop (learning rate 1e-3, rho 0.9, epsilon 1e-8)
        and mean-squared-error as both loss and metric.
        """
        if self.input_shape is not None:
            input_shape = tuple(self.input_shape)
        else:
            # Fall back to the notebook globals set when the dataset was imported.
            input_shape = (dataset_shape, dataset_column, 1)
        inputs = Input(shape=input_shape)

        # A single initializer instance is shared by every layer.
        weights_init = RandomNormal(mean=0.0, stddev=0.1)

        def conv(x, filters, activation=self.activation):
            """Apply one zero-padded ('same') Conv2D with the shared settings."""
            return Conv2D(
                filters=filters,
                kernel_size=self.kernels,
                strides=self.strides,
                padding="same",
                activation=activation,
                kernel_initializer=weights_init)(x)

        # Encoder
        enc1a = conv(inputs, 64)
        enc1b = conv(enc1a, 64)
        pool1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(enc1b)
        enc2a = conv(pool1, 128)
        enc2b = conv(enc2a, 128)
        pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(enc2b)

        # Bottleneck
        mid_a = conv(pool2, 256)
        mid_b = conv(mid_a, 128)

        # Decoder: upsample, then concatenate the matching encoder features on channels.
        skip2 = concatenate([UpSampling2D(size=(2, 2))(mid_b), enc2b], axis=3)
        dec2a = conv(skip2, 128)
        dec2b = conv(dec2a, 64)
        skip1 = concatenate([UpSampling2D(size=(2, 2))(dec2b), enc1b], axis=3)
        dec1a = conv(skip1, 64)
        dec1b = conv(dec1a, 64)

        # Single-channel linear output (regression to pixel intensities).
        finalOutput = conv(dec1b, 1, activation=None)

        self.model = Model(inputs=[inputs], outputs=[finalOutput])

        self.model.summary()

        # `learning_rate` replaces the deprecated `lr`; decay=0 was the default
        # and is rejected by newer Keras optimizers, so it is omitted.
        self.model.compile(
            optimizer=RMSprop(learning_rate=.001, rho=0.9, epsilon=1e-08),
            loss=mean_squared_error,
            metrics=[mean_squared_error])

        self.hasArchitecture = True

    def reconstruct(self, undersampled_img):
        """Run the model's predict() on a batch of undersampled images."""
        return self.model.predict(undersampled_img)

    def plot_model(self, metric):
        """Plot training-history ``metric`` per epoch on log-log axes; returns the History object.

        Raises ValueError if no training history is available.
        """
        if self.history is None:
            raise ValueError("No training history available; train the model first")
        plt.plot(self.history.history[metric])
        plt.title(metric)
        plt.ylabel(metric)
        plt.xlabel('Epoch')
        plt.yscale("log")
        plt.xscale("log")
        return self.history
      

In [None]:
# Acceleration factors to train on (loadImages accepts values in 0..4)
AF = [1, 2, 3, 4]

# Training hyper-parameters and where checkpoints / the final model go
batch_size = 32
epochs = 2000
modelOutputDirectory = 'C:/MRI_dataset/output'

# loadImages returns (fully-sampled, undersampled) image stacks
original, subsampled = loadImages(datasets, AF)

# ReLU network; train() builds the architecture on first use
net = uNet('relu')

# Undersampled images are the network input, fully-sampled images the target
net.train(subsampled, original, batch_size, epochs, modelOutputDirectory)

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 320, 320, 1) 0                                            
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 320, 320, 64) 640         input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 320, 320, 64) 36928       conv2d_11[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 160, 160, 64) 0           conv2d_12[0][0]                  
_______________________________________________________________________________________

# Evaluation class

The following cell contains the code for image reconstruction, k-space correction (error fixing), and plotting of images and MSE values.

In [37]:
class Eval():
    """Evaluation helpers for a trained reconstruction network.

    The Keras model passed in is wrapped in a ``uNet`` (via ``set_model``) so
    that ``reconstruct()`` can call its ``predict()``. Results are reported
    per (image, acceleration factor) pair produced by ``undersample_dataset()``.
    """

    def __init__(self, model):
        """Wrap ``model`` (a Keras model, e.g. ``net.model``) for evaluation."""
        m = uNet('relu')
        m.set_model(model)
        self.model = m

    def undersample_dataset(self, dataset, AF):
        """Return ``{"<image index><AF>": undersampled magnitude image}`` for every pair.

        Each value is ``|ifft2(image.grappaUndersample(af, 20))|``. Keys are
        built as ``"{img}{AF}"``, so for single-digit (integer) AF values the
        last character is the AF and the rest is the image index.

        Raises Exception on an empty dataset, an empty AF list or an AF
        outside 0..4.
        """
        #edge cases
        if not dataset:
            raise Exception("Input a valid dataset")

        if not AF:
            raise Exception("Enter an accelaration factor")

        if any((x > 4 or x < 0) for x in AF):
            raise Exception("The model only trains accelaration factors in the range of x where 0 <= x <= 4")

        undersampled_imgs = {}
        for i, image in enumerate(dataset):
            for af in AF:
                key = "{img}{AF}".format(img=i, AF=af)
                undersampled_imgs[key] = np.absolute(
                    inverseFourierTransform(image.grappaUndersample(af, 20)))

        return undersampled_imgs

    @staticmethod
    def _parse_key(key):
        """Split an ``undersample_dataset`` key into (image index, AF as a string)."""
        return int(key[:-1]), key[-1]

    @staticmethod
    def _make_output_dir(output_dir, name):
        """Create and return ``output_dir/name``.

        If that directory already exists, the first unused
        ``output_dir/name/v<z>`` (z = 1, 2, ...) is created and returned instead.
        """
        base = os.path.join(output_dir, name)
        path = base
        version = 1
        while os.path.exists(path):
            path = os.path.join(base, "v{}".format(version))
            version += 1
        os.makedirs(path)
        return path

    @staticmethod
    def _save_side_by_side(left, left_title, right, right_title, save_path):
        """Save two images side by side (bone colormap, axes hidden) to ``save_path``."""
        fig = plt.figure(figsize=(15, 15))
        for position, image, title in ((121, left, left_title), (122, right, right_title)):
            plt.subplot(position)
            plt.imshow(image, cmap=plt.cm.bone)
            plt.title(title, fontsize=25)
            plt.axis('off')
        fig.savefig(save_path, bbox_inches='tight')
        # Close each figure so looping over many images does not leak memory.
        plt.close(fig)

    def reconstruct(self, undersampled_img):
        """Run the network on one 2-D image and return the 2-D reconstruction."""
        rows, cols = np.shape(undersampled_img)[:2]
        undersampled_input = np.reshape(undersampled_img, (1, rows, cols, 1))
        recon_output = self.model.reconstruct(undersampled_input)
        return np.squeeze(recon_output)

    def k_space_Correction(self, recon_img, original_img, centerline_factor=20):
        """Replace the central (low-frequency) k-space rows of ``recon_img`` with ``original_img``'s.

        Both spectra are fftshift-ed so the centre rows hold the low
        frequencies. This is the same band (rows centre - cf/2 .. centre + cf/2,
        inclusive) that ``Image.grappaUndersample`` keeps fully sampled.
        Returns the magnitude of the corrected image.
        """
        recon_fft = np.fft.fftshift(fourierTransform(recon_img))
        original_fft = np.fft.fftshift(fourierTransform(original_img))

        center_line = int(len(recon_fft) / 2)
        upper_border = int(center_line - centerline_factor / 2)
        lower_border = int(center_line + centerline_factor / 2)

        #Fix k-space of recon Image: copy the trusted centre band from the original
        recon_fft[upper_border:lower_border + 1] = original_fft[upper_border:lower_border + 1]

        return np.absolute(inverseFourierTransform(np.fft.ifftshift(recon_fft)))

    def plot_MSE(self, dataset, AF):
        """Print the MSE between each reconstruction and its fully-sampled image."""
        undersampled_imgs = self.undersample_dataset(dataset, AF)
        mse = tf.keras.losses.MeanSquaredError()

        for key, undersampled in undersampled_imgs.items():
            img_index, af = self._parse_key(key)
            x = self.reconstruct(undersampled)
            y = dataset[img_index].get_Image()

            print("Image {a} AF {b}".format(a=img_index, b=af))
            print(mse(y, x).numpy())

    def plot_kspace_fix_MSE(self, dataset, AF):
        """Print the MSE between each k-space-corrected reconstruction and its fully-sampled image."""
        undersampled_imgs = self.undersample_dataset(dataset, AF)
        mse = tf.keras.losses.MeanSquaredError()

        for key, undersampled in undersampled_imgs.items():
            img_index, af = self._parse_key(key)
            y = dataset[img_index].get_Image()
            x = self.k_space_Correction(self.reconstruct(undersampled), y)

            print("Image {a} AF {b}".format(a=img_index, b=af))
            print(mse(y, x).numpy())

    def plot_imgs(self, dataset, AF, output_dir):
        """Save undersampled vs reconstructed images side by side.

        Files are written as ``<key>.png`` under a fresh directory
        ``output_dir/<model name>[/v<z>]``.
        """
        undersampled_imgs = self.undersample_dataset(dataset, AF)
        path = self._make_output_dir(output_dir, self.model.model.name)

        print("Images saved at '% s' " % path)

        for key, undersampled in undersampled_imgs.items():
            img_index, af = self._parse_key(key)
            reconstructed_img = self.reconstruct(undersampled)

            save_img = os.path.join(path, "{}.png".format(key))
            self._save_side_by_side(
                undersampled,
                "Undersampled Image {img} at AF={af}".format(img=img_index, af=af),
                reconstructed_img,
                'Reconstructed Image',
                save_img)

            print("Saved {index} to {path}".format(index=key, path=save_img))

    def plot_k_spaceFix_imgs(self, dataset, AF, output_dir, image_offset=14):
        """Save reconstructed vs k-space-corrected images side by side.

        ``image_offset`` is added to the image index in the titles (default 14,
        as in the original numbering of validation images).
        Files are written as ``<key>.png`` under a fresh directory
        ``output_dir/<model name>[/v<z>]``.
        """
        undersampled_imgs = self.undersample_dataset(dataset, AF)
        path = self._make_output_dir(output_dir, self.model.model.name)

        print("Images saved at '% s' " % path)

        for key, undersampled in undersampled_imgs.items():
            img_index, af = self._parse_key(key)
            reconstructed_img = self.reconstruct(undersampled)
            y = dataset[img_index].get_Image()
            k_space_fix = self.k_space_Correction(reconstructed_img, y)

            save_img = os.path.join(path, "{}.png".format(key))
            self._save_side_by_side(
                reconstructed_img,
                "Reconstructed Validation Image {img} at AF={af}".format(
                    img=img_index + image_offset, af=af),
                k_space_fix,
                'k-space fixed Image',
                save_img)

            print("Saved {index} to {path}".format(index=key, path=save_img))


    

In [31]:
# Reload a previously saved model directory (presumably written by uNet.train).
# NOTE(review): `model` is not used below; Eval is constructed from net.model.
model = tf.keras.models.load_model(filepath='C:/MRI_dataset/output/functional_1')


In [38]:
# Evaluate the network trained above (its Keras model is wrapped by Eval).
eval_imgs = Eval(model=net.model)

In [39]:
eval_imgs.plot_imgs(dataset=datasets, AF=[1, 2, 3], output_dir="C:/test")

Images saved at 'C:/test\functional_1\v1' 

Saved 01 to C:/test\functional_1\v1\01.png

Saved 02 to C:/test\functional_1\v1\02.png

Saved 03 to C:/test\functional_1\v1\03.png

Saved 11 to C:/test\functional_1\v1\11.png

Saved 12 to C:/test\functional_1\v1\12.png

Saved 13 to C:/test\functional_1\v1\13.png

Saved 21 to C:/test\functional_1\v1\21.png


KeyboardInterrupt: 