In [None]:
import os
import cv2
import imageio
import scipy.ndimage as ndi
import numpy as np
import matplotlib.pyplot as plt

In [None]:
def format_and_render_plot():
    '''Apply the common finishing steps to the current figure and display it.

    Works on whatever figure `plt.gcf()` returns:
    1. Switches every axes of that figure off via `axis('off')`
       (this hides the whole axis - ticks, labels and frame - not only a grid).
    2. Calls `plt.tight_layout()` to improve subplot spacing.
    3. Calls `plt.show()` to render the plot.'''
    for subplot in plt.gcf().axes:
        subplot.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def normalize_image(img_array):
    '''
    Min-max image normalization

    Linearly rescales the values so that the smallest one maps to 0 and the
    largest one maps to 1.

    Parameters
    ----------
    img_array : array_like
        Image data of any shape. It must contain at least one element,
        because `min`/`max` of an empty array raise `ValueError`.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as the input and values in [0, 1].
        A constant image (max == min) is returned as all zeros instead of
        the NaN array a 0/0 division would produce.
    '''
    im = np.asanyarray(img_array)

    # Integer (and bool) data is promoted to float64 before any arithmetic:
    # fixed-width integers such as int16 would otherwise wrap around in
    # `im - min` / `max - min` when the value range exceeds the dtype's limit.
    if not np.issubdtype(im.dtype, np.inexact):
        im = im.astype(np.float64)

    lowest = im.min()
    value_range = im.max() - lowest

    # A constant image has no contrast to stretch; avoid dividing by zero.
    if value_range == 0:
        return np.zeros_like(im)

    # apply min-max normalization
    return (im - lowest) / value_range


def filter_image(img_array, threshold=0.4, iterations=2):
    '''
    Filter and mask image

    Builds a binary mask of all pixels at or above `threshold` and then grows
    that mask with a binary dilation.

    Parameters
    ----------
    img_array : array_like
        Image to threshold. The default threshold of 0.4 is meant for data
        that has already been scaled to [0, 1] (e.g. by `normalize_image`).
    threshold : float, optional
        Pixels with a value >= threshold belong to the mask. Default 0.4.
    iterations : int, optional
        Number of dilation passes handed to `scipy.ndimage.binary_dilation`.
        Default 2. Per the SciPy documentation, a value below 1 repeats the
        dilation until the result no longer changes.

    Returns
    -------
    numpy.ndarray of bool
        The dilated mask, same shape as `img_array`.
    '''
    # create mask based on threshold
    mask_bone = np.asarray(img_array) >= threshold

    # binary_dilation: add pixels along edges
    mask_dilate = ndi.binary_dilation(mask_bone, iterations=iterations)

    return mask_dilate


# define the function to compute MSE between two images
def mse(img1, img2):
    '''
    Mean squared error between two images of identical shape.

    Both inputs are converted to float32, subtracted element-wise and the
    squared differences are averaged over all elements, so images of any
    dimensionality (2-D slices, multi-channel images, volumes) are accepted.

    Parameters
    ----------
    img1, img2 : array_like
        Images to compare. Boolean masks are treated as 0/1.

    Returns
    -------
    error : float
        Sum of squared differences divided by the number of elements.
    diff : numpy.ndarray of float32
        The signed difference image `img1 - img2`.

    Raises
    ------
    ValueError
        If the two images do not have the same shape.
    '''
    img1 = np.asarray(img1, dtype=np.float32)
    img2 = np.asarray(img2, dtype=np.float32)

    # Fail early instead of silently broadcasting differently shaped images.
    if img1.shape != img2.shape:
        raise ValueError("Shape mismatch: img1 and img2 must have the same shape.")

    # For float32 data a plain subtraction is all that is needed (no
    # saturation is involved), so NumPy is used directly.
    diff = img1 - img2
    err = np.sum(diff**2)

    # Divide by the total element count rather than h*w so that inputs which
    # are not exactly two-dimensional work as well.
    error = err / float(img1.size)
    return error, diff


def dice_coef(y_true, y_pred, flatten=False):
    '''
    Calculate DICE coefficient

    DICE = 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))

    Parameters
    ----------
    y_true, y_pred : array_like
        Masks of identical shape. Boolean masks are treated as 0/1; other
        numeric input is used as-is after conversion to float64.
    flatten : bool, optional
        If True the inputs are flattened to 1-D before the sums are taken.
        Every sum runs over all elements either way, so this only affects
        the memory layout that is reduced. Default False.

    Returns
    -------
    float
        The DICE coefficient. If both inputs are entirely zero (two empty
        masks) the result is defined as 1.0, since the masks agree perfectly,
        rather than the NaN a 0/0 division would give.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    '''
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if y_true.shape != y_pred.shape:
        raise ValueError("Shape mismatch: y_true and y_pred must have the same shape.")

    # Work in float64 so the element-wise product cannot overflow for
    # narrow integer dtypes (e.g. uint8 masks).
    y_true = y_true.astype(np.float64)
    y_pred = y_pred.astype(np.float64)

    if flatten:
        y_true = y_true.ravel()
        y_pred = y_pred.ravel()

    # Two empty masks: the denominator would be zero.
    if not y_true.any() and not y_pred.any():
        return 1.0

    # calculate intersection
    intersection = np.sum(y_true * y_pred)

    # calculate DICE
    return (2. * intersection) / (np.sum(y_true) + np.sum(y_pred))


def process_images(real_image_path, fake_image_path, plot_images=False):
    '''
    Run the comparison pipeline for one pair of images.

    Steps: read both files with `imageio.imread`, normalize each with
    `normalize_image`, turn each into a mask with `filter_image`, then
    compute the DICE coefficient and the MSE on the two masks.

    Parameters
    ----------
    real_image_path, fake_image_path : str
        Paths of the two image files to compare.
    plot_images : bool, optional
        If True, show a three-panel figure with both masks and the
        difference image returned by `mse`. Default False.

    Returns
    -------
    tuple
        `(dice_coefficient, error)` where `error` is the MSE value.
    '''
    # read both files from disk
    raw_real = imageio.imread(real_image_path)
    raw_fake = imageio.imread(fake_image_path)

    # rescale both images
    scaled_real, scaled_fake = (normalize_image(raw) for raw in (raw_real, raw_fake))

    # threshold + dilate both images
    mask_real, mask_fake = (filter_image(scaled) for scaled in (scaled_real, scaled_fake))

    # similarity metrics on the masks
    dice_coefficient = dice_coef(mask_real, mask_fake)
    error, diff = mse(mask_real, mask_fake)

    # nothing to draw: hand back the metrics right away
    if not plot_images:
        return dice_coefficient, error

    # one grayscale panel per image, in fixed left-to-right order
    panels = ((mask_real, 'sub'), (mask_fake, 'GAN'), (diff, 'Diff'))
    _, axes = plt.subplots(1, 3, sharex=True)
    for axis, (image, title) in zip(axes, panels):
        axis.imshow(image, cmap='gray')
        axis.set_title(title, fontweight="bold")
    format_and_render_plot()

    return dice_coefficient, error

In [None]:
# define path
# The real images live in the ".../real/.../sub" folder and the GAN output in
# ".../fake/.../generated"; the two assignments were previously swapped, which
# mislabelled the plots ('sub' vs 'GAN') and the log lines.
# NOTE(review): these are machine-specific absolute paths - adjust them (or
# move them to a configurable base directory) before running elsewhere.
real_path = 'C:/Users/einspaen/AppData/Local/xnat-dataset/classifierAppData/real/163014180-2/sub'
fake_path = 'C:/Users/einspaen/AppData/Local/xnat-dataset/classifierAppData/fake/163014180-2/generated'

In [None]:
# list the image files of both folders
# os.listdir returns the entries in arbitrary order, so both listings are
# sorted to make the i-th real file line up with the i-th fake file.
# NOTE(review): this assumes the file names sort in the same (slice) order in
# both folders - confirm against the actual naming scheme.
real_files = sorted(os.listdir(real_path))
fake_files = sorted(os.listdir(fake_path))

# storage for DICE and MSE results
dice_results = []
mse_results = []

# iterate over all files
if len(real_files) != len(fake_files):
    # message (German): "The number of real and fake images does not match."
    print("Anzahl der realen und gefälschten Bilder stimmt nicht überein.")
else:
    for real_file, fake_file in zip(real_files, fake_files):
        # define image path
        real_image_path = os.path.join(real_path, real_file)
        fake_image_path = os.path.join(fake_path, fake_file)

        # start pipeline
        dice_coefficient, mse_error = process_images(real_image_path, fake_image_path)

        # save results and print
        dice_results.append(dice_coefficient)
        mse_results.append(mse_error)
        print(f"DICE-Coeff for {real_file} and {fake_file}: {dice_coefficient}")
        print(f"Image matching Error between {real_file} and {fake_file}: {mse_error}")

In [None]:
# slice indices (x-axis)
# Derived from the collected results rather than from the directory listing,
# so x and y always have the same length - also when no pair was evaluated.
number_of_files = np.arange(len(dice_results))

# plot: DICE on top, MSE below, sharing the slice axis
fig, axes = plt.subplots(2, 1, sharex=True)
axes[0].plot(number_of_files, dice_results, marker='o', linestyle='-')
axes[0].set_title('DICE results for each slice')
axes[0].set_ylabel('DICE')
axes[1].plot(number_of_files, mse_results, marker='o', linestyle='-')
axes[1].set_title('MSE results for each slice')
axes[1].set_ylabel('MSE')
axes[1].set_xlabel('Slice index')
fig.tight_layout()

# show plot
plt.show()