In [1]:
from tkinter import *
from tkinter import filedialog

def folder_selection_dialog():
    """Ask the user to pick the folder that contains the data.

    Creates a Tk root window titled with the prompt and opens a
    "Select A Folder" dialog starting at the filesystem root ("/").

    Returns:
        The directory chosen by the user, as returned by
        ``filedialog.askdirectory``. NOTE(review): if the user cancels, Tk
        returns an empty value (typically ``''``); callers should check for it.
    """
    root = Tk()
    try:
        root.title('Please select the directory containing the images')
        directory = filedialog.askdirectory(initialdir="/", title="Select A Folder")
    finally:
        # Always tear the Tk root down, even if the dialog raises, so no
        # orphaned Tk window/interpreter is left behind in the kernel.
        root.destroy()

    return directory

In [None]:
from tkinter import *
from tkinter import filedialog

def file_selection_dialog():
    """Ask the user to pick the model file they wish to use or retrain.

    Creates a Tk root window titled with the prompt and opens a
    "Select A File" dialog, starting at the filesystem root ("/"), that
    shows all file types.

    Returns:
        The file path chosen by the user, as returned by
        ``filedialog.askopenfilename``. NOTE(review): if the user cancels,
        Tk returns an empty value (typically ``''``); callers should check for it.
    """
    root = Tk()
    try:
        root.title('Please select the machine learning model in question')
        file_path = filedialog.askopenfilename(initialdir="/", title="Select A File", filetypes=[("All files", "*.*")])
    finally:
        # Always tear the Tk root down, even if the dialog raises, so no
        # orphaned Tk window/interpreter is left behind in the kernel.
        root.destroy()

    return file_path

In [None]:
import matplotlib.pyplot as plt
import os

def loss_graph(num_epochs,
               training_loss,
               validation_loss,
               save_plot,
               display_plot):
    """Plot the training and validation loss for every epoch.

    Args:
        num_epochs: Number of epochs. The x axis is ``range(num_epochs)``, so
            both loss sequences must contain exactly ``num_epochs`` values.
        training_loss: Sequence of per-epoch training loss values.
        validation_loss: Sequence of per-epoch validation loss values.
        save_plot: If True, save the figure to
            ``<parent of the current working directory>/img/training_and_validation_loss.png``.
            The ``img`` folder is created if it does not exist.
        display_plot: If True, show the figure; otherwise close it unshown.
    """
    # Set before the figure is created so that text created with the figure
    # (e.g. tick labels) uses it too. Note: this changes matplotlib's global
    # rcParams for the rest of the session, as before.
    plt.rcParams.update({'font.size': 15})

    # Draw into a dedicated figure so the curves never land on a figure left
    # open by an earlier cell.
    fig, ax = plt.subplots()
    epochs = range(num_epochs)
    ax.plot(epochs, training_loss, label="Training loss")
    ax.plot(epochs, validation_loss, label="Validation loss")
    ax.set_ylabel('Loss', labelpad=10)  # labelpad = gap between the axis label and the axis.
    ax.set_xlabel('Epoch', labelpad=10)
    # Anchor the legend just outside the top-right corner of the axes.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    if save_plot:
        img_directory = os.path.join(os.path.dirname(os.getcwd()), 'img')
        os.makedirs(img_directory, exist_ok=True)
        fig.savefig(os.path.join(img_directory, 'training_and_validation_loss.png'),
                    dpi=200, bbox_inches='tight')

    if display_plot:
        plt.show()
    else:
        # Close this specific figure so it does not accumulate in memory.
        plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import os

def accuracy_graph(num_epochs,
                   training_accuracy,
                   validation_accuracy,
                   save_plot,
                   display_plot):
    """Plot the training and validation accuracy for every epoch.

    Args:
        num_epochs: Number of epochs. The x axis is ``range(num_epochs)``, so
            both accuracy sequences must contain exactly ``num_epochs`` values.
        training_accuracy: Sequence of per-epoch training accuracy values.
        validation_accuracy: Sequence of per-epoch validation accuracy values.
        save_plot: If True, save the figure to
            ``<parent of the current working directory>/img/training_and_validation_accuracy.png``.
            The ``img`` folder is created if it does not exist.
        display_plot: If True, show the figure; otherwise close it unshown.
    """
    # Set before the figure is created so that text created with the figure
    # (e.g. tick labels) uses it too. Note: this changes matplotlib's global
    # rcParams for the rest of the session, as before.
    plt.rcParams.update({'font.size': 15})

    # Draw into a dedicated figure so the curves never land on a figure left
    # open by an earlier cell.
    fig, ax = plt.subplots()
    epochs = range(num_epochs)
    ax.plot(epochs, training_accuracy, label="Training accuracy")
    ax.plot(epochs, validation_accuracy, label="Validation accuracy")
    ax.set_ylabel('Accuracy', labelpad=10)  # labelpad = gap between the axis label and the axis.
    ax.set_xlabel('Epoch', labelpad=10)
    # Anchor the legend just outside the top-right corner of the axes.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    if save_plot:
        img_directory = os.path.join(os.path.dirname(os.getcwd()), 'img')
        os.makedirs(img_directory, exist_ok=True)
        fig.savefig(os.path.join(img_directory, 'training_and_validation_accuracy.png'),
                    dpi=200, bbox_inches='tight')

    if display_plot:
        plt.show()
    else:
        # Close this specific figure so it does not accumulate in memory.
        plt.close(fig)

In [None]:
import matplotlib.pyplot as plt 
from skimage import exposure

def display_montage(x_test,
                    y_test,
                    img_height,
                    img_width,
                    display_plot=True,
                    save_plot=False,
                    model=None,
                    index=0):
    """Show a raw image next to its ground-truth and predicted segmentation.

    Args:
        x_test: Raw images with a leading sample axis, e.g. ``(N, H, W, 1)`` as
            built by ``append_images(raw=True)``. A batch of one image behaves
            as before; for larger batches the sample at ``index`` is shown.
        y_test: Matching one-hot ground truth, shaped ``(N, H*W, n_classes)``
            as built in ``train_CNN``.
        img_height: Height of one image in pixels.
        img_width: Width of one image in pixels.
        display_plot: If True, show the figure; otherwise close it unshown.
        save_plot: If True, save the figure to
            ``<parent of the current working directory>/img/Montage.png``
            (the ``img`` folder is created if it does not exist).
        model: Trained Keras model used for the prediction panel. If omitted,
            the notebook-global ``model`` is used (the previous behaviour).
        index: Which sample of the batch to display.

    Raises:
        ValueError: If no model is passed and no global ``model`` exists.
    """
    if model is None:
        # Backward compatibility: earlier versions read a global ``model``.
        model = globals().get('model')
        if model is None:
            raise ValueError("display_montage needs a trained model; pass it via the `model` argument.")

    # Slice (rather than index) to keep a batch axis of size 1, which is what
    # model.predict expects.
    x_sample = x_test[index:index + 1]
    y_sample = y_test[index:index + 1]

    fig = plt.figure(figsize=(8, 7))
    rows, columns = 2, 3

    # Panel 1: the raw image, contrast-enhanced with adaptive histogram
    # equalization so faint structures are visible.
    ax_raw = fig.add_subplot(rows, columns, 1)
    raw_image = exposure.equalize_adapthist(np.squeeze(x_sample), clip_limit=0.03)
    ax_raw.imshow(raw_image)
    ax_raw.axis('off')
    ax_raw.set_title("Raw Image.")

    # Panel 2: ground truth. Undo the one-hot encoding (argmax over classes),
    # then restore the flattened pixels to a 2-D image.
    ax_truth = fig.add_subplot(rows, columns, 2)
    ground_truth = np.argmax(y_sample[0], axis=-1).reshape(img_height, img_width)
    ax_truth.imshow(ground_truth, cmap='tab10')
    ax_truth.axis('off')
    ax_truth.set_title("Ground Truth Image.")

    # Panel 3: the model's prediction, decoded the same way as the ground truth.
    y_pred = model.predict(x_sample)
    prediction = np.argmax(y_pred[0], axis=-1).reshape(img_height, img_width)
    ax_pred = fig.add_subplot(rows, columns, 3)
    ax_pred.imshow(prediction, cmap='tab10')
    ax_pred.axis('off')
    ax_pred.set_title("Image Predicted by CNN.")

    if save_plot:
        img_directory = os.path.join(os.path.dirname(os.getcwd()), 'img')
        os.makedirs(img_directory, exist_ok=True)
        fig.savefig(os.path.join(img_directory, 'Montage.png'), dpi=200, bbox_inches='tight')

    if display_plot:
        plt.show()
    else:
        # Close this specific figure so it does not accumulate in memory.
        plt.close(fig)

In [None]:
import cv2 
import numpy as np

def append_images(image_list,
                  directory,
                  num_classes_in,
                  raw=True):
    """Load the named images from ``directory`` and stack them into one array.

    Args:
        image_list: File names relative to ``directory``, e.g. ``image1.tif``.
        directory: Folder containing the images.
        num_classes_in: Currently unused; kept so existing callers keep working.
        raw: If True, each image is min-max scaled to [0, 1] and given a
            trailing channel axis of size 1. A constant image, which has no
            intensity range, becomes all zeros instead of NaNs. If False, the
            images are treated as ground-truth label images: pixel values are
            left untouched and the distinct values are counted.

    Returns:
        A tuple ``(image_stack, num_classes_out)``:
        - ``image_stack`` is ``np.array`` of the loaded images.
        - ``num_classes_out`` is the number of distinct pixel values across
          ALL label images when ``raw`` is False, and 0 when ``raw`` is True.

    Raises:
        FileNotFoundError: If an image cannot be read.
    """
    image_stack = []
    label_values = set()

    for image_name in image_list:
        file_path = os.path.join(directory, image_name)
        # -1 == cv2.IMREAD_UNCHANGED: keep the original bit depth and channels.
        img = cv2.imread(file_path, -1)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image: {file_path}")

        if raw:
            img_min = img.min()
            value_range = img.max() - img_min
            if value_range > 0:
                img = (img - img_min) / value_range  # Scale the image between 0 and 1.
            else:
                img = np.zeros(img.shape)
            img = img[..., np.newaxis]  # Add the single channel axis the model expects.
        else:
            # Accumulate over every image; a single image may not contain all classes.
            label_values.update(np.unique(img).tolist())

        image_stack.append(img)

    num_classes_out = 0 if raw else len(label_values)
    return np.array(image_stack), num_classes_out

In [None]:
from keras.models import Model 
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
import tensorflow as tf 

def _unet_conv_block(x, filters, dropout_rate, second_kernel_size=(3, 3)):
    """Apply one U-Net stage: 3x3 Conv -> Dropout -> Conv (ReLU, he_normal, 'same' padding)."""
    x = Conv2D(filters, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')(x)
    x = Dropout(dropout_rate)(x)
    x = Conv2D(filters, second_kernel_size, activation='relu', kernel_initializer='he_normal', padding='same')(x)
    return x


def _unet_up_block(x, skip, filters, dropout_rate):
    """Upsample ``x`` 2x with a transposed conv, concatenate the ``skip`` tensor, then apply a conv stage."""
    x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
    x = concatenate([x, skip])
    return _unet_conv_block(x, filters, dropout_rate)


def multiclass_Unet(n_classes=None,
                    img_height=None,
                    img_width=None,
                    img_channels=None):
    """Build a 4-level U-Net for per-pixel multi-class classification.

    Args:
        n_classes: Number of classes to predict per pixel.
        img_height: Image height in pixels.
        img_width: Image width in pixels.
        img_channels: Number of channels (1 for grayscale, 3 for RGB).

        Any argument left as None falls back to the notebook-global of the
        same name (``number_classes`` for ``n_classes``), looked up at call
        time. Earlier versions bound these globals when the cell ran, which
        raised NameError on a fresh kernel.

    Height and width must be divisible by 16 (four 2x poolings). Otherwise
    the upsampled tensors do not match the skip connections they are
    concatenated with.

    Returns:
        An uncompiled keras ``Model``. Its input is ``(img_height, img_width,
        img_channels)``. Its output holds softmax class probabilities flattened
        to ``(batch, img_height*img_width, n_classes)``, a shape chosen so that
        sample weights can be used.

    Raises:
        ValueError: If an argument is missing and no matching global exists.
    """
    from keras.layers import Reshape

    notebook_globals = globals()
    if n_classes is None:
        n_classes = notebook_globals.get('number_classes')
    if img_height is None:
        img_height = notebook_globals.get('img_height')
    if img_width is None:
        img_width = notebook_globals.get('img_width')
    if img_channels is None:
        img_channels = notebook_globals.get('img_channels')
    missing = [name for name, value in (('n_classes', n_classes),
                                        ('img_height', img_height),
                                        ('img_width', img_width),
                                        ('img_channels', img_channels)) if value is None]
    if missing:
        raise ValueError("multiclass_Unet: missing required argument(s): " + ", ".join(missing))

    inputs = Input((img_height, img_width, img_channels))

    # Contraction path: filters double at each stage while pooling halves H and W.
    c1 = _unet_conv_block(inputs, 16, 0.1)
    p1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(c1)
    c2 = _unet_conv_block(p1, 32, 0.1)
    p2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(c2)
    c3 = _unet_conv_block(p2, 64, 0.2)
    p3 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(c3)
    c4 = _unet_conv_block(p3, 128, 0.2)
    p4 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(c4)

    # Bottleneck. NOTE(review): the second conv here is 2x2 while every other
    # stage uses 3x3. It is kept as-is so the architecture is unchanged;
    # confirm whether this is intentional.
    c5 = _unet_conv_block(p4, 256, 0.3, second_kernel_size=(2, 2))

    # Expansion path: upsample and merge with the matching contraction stage.
    c6 = _unet_up_block(c5, c4, 128, 0.2)
    c7 = _unet_up_block(c6, c3, 64, 0.2)
    c8 = _unet_up_block(c7, c2, 32, 0.1)
    c9 = _unet_up_block(c8, c1, 16, 0.1)

    outputs = Conv2D(n_classes, (1, 1), activation='softmax')(c9)
    # Flatten the spatial dimensions to (batch, H*W, n_classes), which is needed
    # to use sample_weights. A Reshape layer is used rather than tf.reshape on a
    # symbolic Keras tensor.
    outputs = Reshape((img_height * img_width, n_classes))(outputs)

    return Model(inputs=[inputs], outputs=[outputs])

In [10]:
import tensorflow as tf
import segmentation_models as sm
import cv2
import numpy as np
import os 
import matplotlib.pyplot as plt 
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
from sklearn.utils import class_weight
import keras 
from keras.utils.np_utils import to_categorical

def _load_training_data(directory):
    """Load paired raw / ground-truth ``.tif`` images and one-hot encode the labels.

    Files whose name contains ``'gtruth'`` are ground truth; every other
    ``.tif`` file is a raw image. Both lists are sorted so that the i-th raw
    image pairs with the i-th ground-truth image. This assumes both sets share
    a naming scheme that sorts identically (e.g. ``img1.tif`` /
    ``img1_gtruth.tif``) — TODO confirm for your dataset.

    Returns:
        ``(X, Y_categorical, number_classes)``:
        - ``X`` is the raw stack from ``append_images(raw=True)``.
        - ``Y_categorical`` holds one-hot labels shaped
          ``(slices, height*width, number_classes)``.
        - ``number_classes`` is the number of distinct label values.

    Raises:
        ValueError: If no raw images are found or the raw/ground-truth counts differ.
    """
    tif_files = sorted(name for name in os.listdir(directory) if name.endswith('.tif'))
    raw_images = [name for name in tif_files if 'gtruth' not in name]
    gtruth_images = [name for name in tif_files if 'gtruth' in name]
    if not raw_images:
        raise ValueError(f"No raw .tif images found in {directory!r}.")
    if len(raw_images) != len(gtruth_images):
        raise ValueError(f"Found {len(raw_images)} raw images but {len(gtruth_images)} "
                         f"ground-truth images in {directory!r}; they must be paired one-to-one.")

    Y, gtruth_class_count = append_images(gtruth_images, directory, num_classes_in=0, raw=False)
    X, _ = append_images(raw_images, directory, num_classes_in=gtruth_class_count, raw=True)

    # Map the label values onto 0..n-1 (e.g. {0, 255} -> {0, 1}), as the model expects.
    label_encoder = LabelEncoder()
    slices, height, width = Y.shape
    Y_encoded = label_encoder.fit_transform(Y.ravel())
    # Take the class count from the encoder so the one-hot depth always matches the encoded labels.
    number_classes = len(label_encoder.classes_)

    # One-hot encode (e.g. 2 -> [0, 0, 1, 0]) for categorical losses, and flatten
    # each image's pixels to match the model's (H*W, n_classes) output.
    Y_categorical = to_categorical(Y_encoded, number_classes).reshape(slices, height * width, number_classes)
    return X, Y_categorical, number_classes


def train_CNN(directory,
              save_plot=False,
              display_plot=True,
              save_model=True,
              num_epochs=30,
              test_size=0.5,
              batch_size=1,
              random_state=None):
    """Train a U-Net to classify the pixels of ``.tif`` microscopy images of cell nuclei.

    Args:
        directory: Folder containing the raw images and their ``*gtruth*`` labels.
        save_plot: If True, the montage and training graphs are saved.
        display_plot: If True, the montage and training graphs are displayed.
        save_model: If True, the trained model is saved as
            ``multiclass_CNN.hdf5`` inside ``directory``.
        num_epochs: Number of training epochs.
        test_size: Fraction of the images held out for validation/testing.
        batch_size: Training batch size.
        random_state: Seed for the train/test split; set it for reproducible splits.

    Returns:
        The trained keras model.

    NOTE: the trained model is also bound to the notebook-global name
    ``model``, because ``display_montage`` looks it up there.
    """
    global model

    #### (1) Load the data and split it into training and testing sets.
    X, Y_categorical, number_classes = _load_training_data(directory)
    x_train, x_test, y_train, y_test = train_test_split(X, Y_categorical,
                                                        test_size=test_size,
                                                        random_state=random_state)

    #### (2) Build and compile the model.
    _, img_height, img_width, img_channels = x_train.shape
    model = multiclass_Unet(n_classes=number_classes,
                            img_height=img_height,
                            img_width=img_width,
                            img_channels=img_channels)
    model.compile(optimizer='adam', loss=sm.losses.CategoricalFocalLoss(), metrics=['accuracy'])
    model.summary()

    #### (3) Train the model.
    history = model.fit(x_train,
                        y_train,
                        batch_size=batch_size,
                        epochs=num_epochs,
                        verbose=1,
                        validation_data=(x_test, y_test))

    if save_model:
        model.save(os.path.join(directory, 'multiclass_CNN.hdf5'))

    #### (4) Assess the model.
    _, acc = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {acc}")

    # Montage of the first test image (a batch of one, as display_montage expects).
    display_montage(x_test[:1],
                    y_test[:1],
                    img_height,
                    img_width,
                    display_plot,
                    save_plot)

    loss_graph(num_epochs,
               history.history['loss'],
               history.history['val_loss'],
               save_plot,
               display_plot)

    accuracy_graph(num_epochs,
                   history.history['accuracy'],
                   history.history['val_accuracy'],
                   save_plot,
                   display_plot)

    return model

In [None]:
# TODO: Increase the size of the training dataset.

# TODO: Plot validation loss alongside training loss.
