In [1]:
from tkinter import *
from tkinter import filedialog

# A function to allow the user to select the folder containing the data.
# Function inputs args: None. 
# Function output 1: The path of the folder selected by the user. 
def folder_selection_dialog():
    """Ask the user to pick a folder through a Tk directory-chooser dialog.

    Returns:
        The value returned by ``filedialog.askdirectory`` for the user's
        choice (presumably an empty value when the dialog is cancelled --
        confirm against the tkinter documentation).
    """
    root = Tk()
    try:
        root.title('Please select the directory containing the images')
        # Browsing starts at the filesystem root ("/").
        directory = filedialog.askdirectory(initialdir="/", title="Select A Folder")
    finally:
        # Always tear the Tk root down, even when the dialog raises, so that
        # no orphaned window is left behind.
        root.destroy()

    return directory

In [None]:
import matplotlib.pyplot as plt
import os

# Function inputs arg 1: num_epochs --> The number of iterations over which the model is refined. 
# Function inputs arg 2: training_loss --> Array of size 1 x num_epochs. This array contains the calculated values of loss for training. 
# Function inputs arg 3: validation_loss --> Array of size 1 x num_epochs. This array contains the calculated values of loss for validation. 
# Function inputs arg 4: save_plot --> True or False. When True, saves the plot as a .png file in the 'img' folder located next to the current working directory.  
# Function inputs arg 5: display_plot --> True or False. When True, displays the plot; otherwise the plot is closed. 
# Function output: Graph with the loss per epoch.
def loss_graph(num_epochs, 
               training_loss, 
               validation_loss, 
               save_plot, 
               display_plot):
    """Plot the training and validation loss for every epoch.

    Args:
        num_epochs: Number of epochs; the x-axis runs from 0 to num_epochs - 1.
        training_loss: Sequence holding one training-loss value per epoch.
        validation_loss: Sequence holding one validation-loss value per epoch.
        save_plot: When truthy, the figure is written to
            'training_and_validation_loss.png' inside the 'img' folder that
            sits next to the current working directory. The folder is created
            when it does not exist yet.
        display_plot: When truthy the figure is shown, otherwise it is closed.
    """
    # Set the font size before any artist is created so that all of the
    # figure's text is built with it. Note that this updates the global
    # matplotlib rcParams, exactly as the function has always done.
    plt.rcParams.update({'font.size': 15})

    # Draw on a fresh figure instead of whichever figure happens to be current.
    epochs = list(range(num_epochs))
    fig, ax = plt.subplots()
    ax.plot(epochs, training_loss, label="Training loss")
    ax.plot(epochs, validation_loss, label="Validation loss")
    ax.set_ylabel('Loss', labelpad=10)  # labelpad sets the gap between the label and the axis.
    ax.set_xlabel('Epoch', labelpad=10)
    # Anchor the legend just outside the right-hand edge of the axes.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # Save the plot if the user desires it.
    if save_plot:
        parent_directory = os.path.dirname(os.getcwd())
        image_directory = os.path.join(parent_directory, 'img')
        # savefig does not create missing folders, so make sure the target exists.
        os.makedirs(image_directory, exist_ok=True)
        file_path = os.path.join(image_directory, 'training_and_validation_loss.png')
        fig.savefig(file_path, dpi=200, bbox_inches='tight')

    # Display the plot if the user desires it, otherwise release the figure.
    if display_plot:
        plt.show()
    else:
        plt.close(fig)

In [None]:
import matplotlib.pyplot as plt
import os

# Function inputs arg 1: num_epochs --> The number of iterations over which the model is refined. 
# Function inputs arg 2: training_accuracy --> Array of size 1 x num_epochs. This array contains the calculated values of training accuracy. 
# Function inputs arg 3: validation_accuracy --> Array of size 1 x num_epochs. This array contains the calculated values of validation accuracy. 
# Function inputs arg 4: save_plot --> True or False. When True, saves the plot as a .png file in the 'img' folder located next to the current working directory.  
# Function inputs arg 5: display_plot --> True or False. When True, displays the plot; otherwise the plot is closed. 
# Function output: Graph with the training and validation accuracy per epoch.
def accuracy_graph(num_epochs, 
                   training_accuracy, 
                   validation_accuracy, 
                   save_plot, 
                   display_plot):
    """Plot the training and validation accuracy for every epoch.

    Args:
        num_epochs: Number of epochs; the x-axis runs from 0 to num_epochs - 1.
        training_accuracy: Sequence holding one training-accuracy value per epoch.
        validation_accuracy: Sequence holding one validation-accuracy value per epoch.
        save_plot: When truthy, the figure is written to
            'training_and_validation_accuracy.png' inside the 'img' folder
            that sits next to the current working directory. The folder is
            created when it does not exist yet.
        display_plot: When truthy the figure is shown, otherwise it is closed.
    """
    # Set the font size before any artist is created so that all of the
    # figure's text is built with it. Note that this updates the global
    # matplotlib rcParams, exactly as the function has always done.
    plt.rcParams.update({'font.size': 15})

    # Draw on a fresh figure instead of whichever figure happens to be current.
    epochs = list(range(num_epochs))
    fig, ax = plt.subplots()
    ax.plot(epochs, training_accuracy, label="Training accuracy")
    ax.plot(epochs, validation_accuracy, label="Validation accuracy")
    ax.set_ylabel('Accuracy', labelpad=10)  # labelpad sets the gap between the label and the axis.
    ax.set_xlabel('Epoch', labelpad=10)
    # Anchor the legend just outside the right-hand edge of the axes.
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)

    # Save the plot if the user desires it.
    if save_plot:
        parent_directory = os.path.dirname(os.getcwd())
        image_directory = os.path.join(parent_directory, 'img')
        # savefig does not create missing folders, so make sure the target exists.
        os.makedirs(image_directory, exist_ok=True)
        file_path = os.path.join(image_directory, 'training_and_validation_accuracy.png')
        fig.savefig(file_path, dpi=200, bbox_inches='tight')

    # Display the plot if the user desires it, otherwise release the figure.
    if display_plot:
        plt.show()
    else:
        plt.close(fig)

In [None]:
import cv2 
import numpy as np

# A function which will append images within a directory into a numpy array. Raw images are also rescaled to the range 0-1. 
# Function input 1: image_list [list of strings] --> Each item in the list is the name of an image which needs to be appended into one stack e.g. image1.tif.
# Function input 2: directory [string] --> The directory containing the images.
# Function input 3: num_classes_in [int] --> The number of copies of each raw image that are stacked along the last axis. Only used when raw is True.
# Function input 4: raw [bool] --> When True, scales each image between 0 and 1 and stacks num_classes_in copies of it along the last axis. When False, the images are treated as ground truth and left unchanged. 
# Function output 1: image stack [numpy array] --> The stack of appended images. 
# Function output 2: num_classes_out [int] --> The number of unique pixel values found in the ground truth images (0 when raw is True). Used for determining the number of classes.
def append_images(image_list,
                  directory, 
                  num_classes_in,
                  raw=True):
    """Load the listed images from ``directory`` and stack them into one array.

    Args:
        image_list: Names of the image files to load, e.g. 'image1.tif'.
        directory: Folder that contains the images.
        num_classes_in: Only used when ``raw`` is True: how many copies of
            each scaled image are stacked along a new last axis.
        raw: When True the images are treated as raw data: each one is
            min-max scaled to the range 0-1 and replicated ``num_classes_in``
            times along the last axis. When False the images are treated as
            ground truth and are left untouched.

    Returns:
        A tuple ``(image_stack, num_classes_out)``. ``image_stack`` is the
        numpy array built from all loaded images. ``num_classes_out`` is the
        number of distinct pixel values found across all ground-truth images
        (always 0 when ``raw`` is True).

    Raises:
        OSError: If ``cv2.imread`` could not read one of the files.
    """
    image_stack = []
    label_values = set()

    # Iterate through the images of our list and append them to our stack.
    for image_name in image_list:
        file_path = os.path.join(directory, image_name)
        img = cv2.imread(file_path, -1)
        if img is None:
            # cv2.imread returns None instead of raising when a file cannot be read.
            raise OSError("Could not read image: {}".format(file_path))

        # For raw images.
        if raw:
            lowest = float(img.min())
            value_range = float(img.max()) - lowest
            # A constant image has no range; dividing by 1 maps it to all
            # zeros instead of producing NaNs through a division by zero.
            if value_range == 0:
                value_range = 1.0
            img = (img - lowest) / value_range  # Scale the image between 0 and 1.
            img = np.stack((img,) * num_classes_in, axis=-1)

        # For gtruth images.
        else:
            # Collect the labels of every image, so a class that is missing
            # from the last image is still counted.
            label_values.update(np.unique(img).tolist())

        image_stack.append(img)

    # Convert the stack to a numpy array.
    image_stack = np.array(image_stack)
    num_classes_out = len(label_values)

    return image_stack, num_classes_out

In [10]:
import tensorflow as tf
import segmentation_models as sm
import cv2
import numpy as np
import os 
import matplotlib.pyplot as plt 
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.preprocessing import LabelEncoder
import keras 
from keras.utils.np_utils import to_categorical
from keras.models import Model 
from keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda

# A function capable of training a CNN to classify pixels within .tif microscopy images of cell nuclei. 
# Function input 1: directory [str] --> The directory containing the original and gtruth data. 
# Function input 2: save_plot [bool] --> When True, graphical data will be saved. 
# Function input 3: display_plot [bool] --> When True, graphical data will be displayed in the console. 
# Function output 1: The trained CNN. 
def train_CNN(directory,
              save_plot=False,
              display_plot=True):
    """Train a U-Net (vgg16 encoder) to classify the pixels of the .tif images in ``directory``.

    Files ending in '.tif' whose name contains 'gtruth' are used as ground
    truth; every other '.tif' file is used as a raw input image.

    Args:
        directory: Folder containing the raw and the ground-truth .tif images.
        save_plot: Not read in the part of the function documented here.
            NOTE(review): presumably forwarded to the plotting helpers further
            down -- confirm.
        display_plot: Not read in the part of the function documented here.
            NOTE(review): presumably forwarded to the plotting helpers further
            down -- confirm.

    Returns:
        NOTE(review): the header comment promises the trained CNN, but no
        return statement follows the ``model.fit`` call in this part of the
        function -- confirm that the model is returned.

    Raises:
        ValueError: If the folder holds no raw images, if the number of raw
            and ground-truth images differs, or if fewer than two classes are
            found in the ground truth.
    """
     
    #### (1) Create our training and testing dataset. 
    # os.listdir returns names in arbitrary order, so the listing is sorted to
    # make the order of both lists deterministic.
    # NOTE(review): raw image i is paired with ground-truth image i purely by
    # position, so the file names must sort in matching order -- confirm the
    # naming convention guarantees this.
    tif_files = sorted(image for image in os.listdir(directory) if image.endswith('.tif'))
    raw_images = [image for image in tif_files if 'gtruth' not in image]
    gtruth_images = [image for image in tif_files if 'gtruth' in image]
    if not raw_images or len(raw_images) != len(gtruth_images):
        raise ValueError("Expected the same, non-zero number of raw and gtruth .tif images in {} "
                         "but found {} raw and {} gtruth.".format(directory, len(raw_images), len(gtruth_images)))
    
    # Get the images (X) and their ground truth equivalents (Y).
    # The vgg16 encoder built by sm.Unet below takes 3-channel input by
    # default, so every raw image is replicated into 3 channels, independent
    # of how many classes the ground truth contains.
    encoder_channels = 3
    Y, _ = append_images(gtruth_images, directory, num_classes_in=0, raw = False)
    X, _ = append_images(raw_images, directory, num_classes_in=encoder_channels, raw = True)

    # Encode our labels, to ensure that the first label value starts from 0 (not 1) as the model expects.
    label_encoder = LabelEncoder()
    slices, height, width = Y.shape
    Y_reshaped = Y.ravel() # Flatten all of the images into a single 1D array. 
    Y_reshaped_encoded = label_encoder.fit_transform(Y_reshaped)
    Y_reshaped_encoded = Y_reshaped_encoded.reshape(slices, height, width)

    # Count the classes over every ground truth image, so a class which is missing from one image is still counted.
    number_classes = len(label_encoder.classes_)
    if number_classes < 2:
        raise ValueError("At least 2 classes are needed in the gtruth images, found {}.".format(number_classes))
    
    # Add an additional dimension to our ground truth data, as the model expects it. 
    Y = np.expand_dims(Y_reshaped_encoded, axis = 3)
    
    # Convert our ground truth pixel values to a one-hot-encoded format. For instance, a pixel value of 2 would be converted to [0,0,1,0]. This is needed for loss functions such as categorical cross entropy loss functions.
    Y_categorical = to_categorical(Y, number_classes)
    Y_categorical = Y_categorical.reshape((Y.shape[0], Y.shape[1], Y.shape[2], number_classes))

    # Split our data into test and train datasets. The fixed seed makes the split reproducible between runs. 
    x_train, x_test, y_train, y_test = train_test_split(X, Y_categorical, test_size=0.5, random_state=42)
    
    #### (2) Define the loss algorithm and methods of model assessment. 
    
    # Establish the parameters for the model and the optimizer. 
    if number_classes == 2:
        activation = 'sigmoid' 
    else:
        activation = 'softmax'
    
    LR = 0.0001
    optim = tf.keras.optimizers.Adam(LR)
    
    # For semantic segmentation, we could use a cross entropy method, but it is suggested to use a combination of loss calculated by dice and focal functions. 
    # Every class gets the same weight; the weights sum to 1 (e.g. 0.25 each for 4 classes).
    dice_loss = sm.losses.DiceLoss(class_weights = np.full(number_classes, 1.0 / number_classes))
    focal_loss = sm.losses.CategoricalFocalLoss()
    loss_weight = 1
    total_loss = dice_loss + (loss_weight * focal_loss) # or simply ... total loss = sm.losses.binary_focal_dice_loss
    
    # Track the Intersection Over Union (IOU) score and the F-score (a measure of precision and recall).
    assessment_scores = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
    
    #### (3) Define our model. 
    
    # Create our encoder. 
    backbone = 'vgg16'
    
    # Preprocess our inputs. 
    preprocessing_method = sm.get_preprocessing(backbone)
    x_train = preprocessing_method(x_train)
    x_test = preprocessing_method(x_test)
    
    # Define the model: Unet with a vgg16 backbone. 
    model = sm.Unet(backbone, 
                    encoder_weights='imagenet', 
                    classes=number_classes, 
                    activation=activation)

    # Compile our keras model together with our optimizer, loss and assessment metrics. 
    model.compile(optim, total_loss, metrics=assessment_scores)

    #### (4) Train our model.
    
    history = model.fit(x_train,
                        y_train,
                        batch_size=1,
                        epochs=3,
                        verbose=1,
                        validation_data=(x_test,y_test))