In [None]:
# Import general python libraries
import numpy as np
import os
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from skimage import exposure
from scipy.ndimage import convolve
from time import time
import random
import random as python_random
import tensorflow as tf

# Import the GDAL module from the osgeo package
from osgeo import gdal

# Import necessary functions from scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report

# Import necessary functions and classes from Keras
from keras.utils import to_categorical
from keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy, Precision, Recall, IoU, MeanIoU, FalseNegatives, FalsePositives
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Dropout, BatchNormalization
from keras.callbacks import ModelCheckpoint

In [None]:
# Define the function for clean-up / normalisation of vegetation indices
def post_idx_calc(index, normalise):
    """Sanitise a vegetation-index array and optionally min-max rescale it.

    Args:
        index: Array of index values; may contain NaN or +/-inf.
        normalise: When truthy, the cleaned array is passed through
            ``cv2.normalize`` with ``NORM_MINMAX`` (alpha=0.0, beta=1.0) and
            a ``CV_32F`` output type.

    Returns:
        The array after ``np.nan_to_num`` (NaN -> 0, infinities -> finite
        numbers), rescaled only when ``normalise`` is truthy.
    """
    cleaned = np.nan_to_num(index)
    if not normalise:
        return cleaned
    return cv2.normalize(
        cleaned,
        None,
        alpha=0.0,
        beta=1.0,
        norm_type=cv2.NORM_MINMAX,
        dtype=cv2.CV_32F,
    )

In [None]:
# Define function to calculate vegetation indices
def calculate_veg_indices(input_img):
    """Compute five vegetation indices from a multispectral tile.

    Args:
        input_img: Array indexed as ``[row, col, band]`` with at least five
            bands on the last axis, read here as 0=blue, 1=green, 2=red,
            3=red edge (unused by the indices below), 4=NIR.

    Returns:
        Array stacking NDVI, DVI, TVI, GDVI and ENDVI (in that order) along
        axis 2. NaN/inf values are replaced by ``post_idx_calc``; no
        normalisation is applied.
    """
    # Extract the channels used by the indices
    nir = input_img[:, :, 4]
    red = input_img[:, :, 2]
    green = input_img[:, :, 1]
    blue = input_img[:, :, 0]

    # The two ratio indices divide by band sums that can be zero. The resulting
    # NaN/inf values are cleaned up by post_idx_calc, so silence the NumPy
    # divide/invalid warnings instead of flooding the notebook output.
    with np.errstate(divide="ignore", invalid="ignore"):
        ndvi = post_idx_calc((nir - red) / (nir + red), normalise=False)
        endvi = post_idx_calc(
            ((nir + green) - (2 * blue)) / ((nir + green) + (2 * blue)),
            normalise=False,
        )

    # Difference-based indices
    dvi = post_idx_calc((nir - red), normalise=False)
    tvi = post_idx_calc((60 * (nir - green)) - (100 * (red - green)), normalise=False)
    gdvi = post_idx_calc((nir - green), normalise=False)

    veg_indices = np.stack((ndvi, dvi, tvi, gdvi, endvi), axis=2)

    return veg_indices

In [None]:
# Define a function to get the width and height of an image using GDAL (If required)
def get_image_dimensions(file_path):
    """Return ``(width, height)`` of the raster at ``file_path``.

    The values come from the dataset's ``RasterXSize`` / ``RasterYSize``.
    When ``gdal.Open`` returns ``None`` the result is ``(None, None)``.
    """
    dataset = gdal.Open(file_path)
    if dataset is None:
        return None, None
    return dataset.RasterXSize, dataset.RasterYSize

# Accepted raster size range (pixels) used when filtering input files;
# both bounds are treated as inclusive by the filtering code.
min_width, min_height = 256, 256
max_width, max_height = 2000, 2000

In [None]:
# Tiling configuration: square tiles of `tile_size` pixels, with neighbouring
# tiles sharing 30% of a tile edge (truncated to whole pixels).
tile_size = 256
overlap = int(0.3 * tile_size)

# Number of segmentation classes
num_classes = 2

In [None]:
# Define the root directory with input images and respective masks
# NOTE(review): these are absolute, machine-specific paths; adjust them (or
# read them from configuration) before running on another machine.
root_data_folder = r'/home/n10837647/pandanus_classification'
root_image_folder = r'/home/n10837647/pandanus_classification/msi_image_mask_rois/training'
root_model_folder = os.path.join(root_data_folder, 'model&outcomes')
# Create the "model&outcomes" folder if it doesn't exist yet. exist_ok=True
# replaces the separate existence check, so there is no window between
# "check" and "create" and re-running the cell is harmless.
os.makedirs(root_model_folder, exist_ok=True)

In [None]:
# Store the tiled images and masks
image_patches = []
mask_patches = []

# Specify the folder paths for images and masks
image_folder_path = os.path.join(root_image_folder, 'msi_rois')
mask_folder_path = os.path.join(root_image_folder, 'msi_mask_rois')

# Same folders under the names the listing code historically used
input_img_folder = image_folder_path
input_mask_folder = mask_folder_path


def _filter_files_by_dimensions(folder, file_names):
    """Return full paths of the files whose raster size is within the limits.

    A file is kept when its dimensions can be read and satisfy
    ``min_width <= width <= max_width`` and
    ``min_height <= height <= max_height``.
    """
    kept_paths = []
    for file_name in file_names:
        path = os.path.join(folder, file_name)
        width, height = get_image_dimensions(path)
        if width is None or height is None:
            # Dimensions could not be read; skip the file
            continue
        if min_width <= width <= max_width and min_height <= height <= max_height:
            kept_paths.append(path)
    return kept_paths


# List the .tif files; sorted because os.listdir returns entries in arbitrary order
img_files = sorted(file for file in os.listdir(input_img_folder) if file.endswith(".tif"))
mask_files = sorted(file for file in os.listdir(input_mask_folder) if file.endswith(".tif"))

# Filter image and mask files based on dimensions
filtered_image_files = _filter_files_by_dimensions(image_folder_path, img_files)
filtered_mask_files = _filter_files_by_dimensions(mask_folder_path, mask_files)

# Print the number of filtered image and mask files
print(f"Number of filtered image files: {len(filtered_image_files)}")
print(f"Number of filtered mask files: {len(filtered_mask_files)}")

In [None]:
# Sort the filtered files to ensure consistent ordering
filtered_image_files.sort()
filtered_mask_files.sort()

# Images and masks are paired purely by position in the sorted lists, so a
# length mismatch would either raise an IndexError mid-run or silently pair an
# image with the wrong mask. Fail early with a clear message instead.
if len(filtered_image_files) != len(filtered_mask_files):
    raise ValueError(
        "Image/mask count mismatch: {} images vs {} masks".format(
            len(filtered_image_files), len(filtered_mask_files)))

input_bands = 5  # Number of input bands
tile_stride = tile_size - overlap  # Step between the origins of neighbouring tiles

# Collect tiles in fresh lists so this cell can be re-run safely
# (image_patches / mask_patches become NumPy arrays at the end of the cell).
image_patch_list = []
mask_patch_list = []

for img_path, mask_path in zip(filtered_image_files, filtered_mask_files):
    img_file = os.path.basename(img_path)  # Get the file name without the path
    mask_file = os.path.basename(mask_path)  # Get the file name without the path

    ds_img = gdal.Open(img_path)
    ds_mask = gdal.Open(mask_path)
    if ds_img is None or ds_mask is None:
        raise IOError(f"Could not open image/mask pair: {img_path} / {mask_path}")
    width = ds_img.RasterXSize
    height = ds_img.RasterYSize

    # Calculate the number of full tiles that fit in the image
    num_tiles_x = (width - tile_size) // tile_stride + 1
    num_tiles_y = (height - tile_size) // tile_stride + 1

    for y in range(num_tiles_y):
        for x in range(num_tiles_x):
            # Calculate the top-left corner of the tile
            x_start = x * tile_stride
            y_start = y * tile_stride

            # Extract the image tile band by band, then reorder to (rows, cols, bands)
            input_img = np.array([
                ds_img.GetRasterBand(band + 1).ReadAsArray(x_start, y_start, tile_size, tile_size)
                for band in range(input_bands)
            ])
            input_img = np.transpose(input_img, (1, 2, 0))
            # Histogram equalisation is applied to the whole multi-band tile in one call
            input_img = exposure.equalize_hist(input_img)

            # Append the vegetation indices as extra channels
            veg_indices = calculate_veg_indices(input_img)
            input_img = np.concatenate((input_img, veg_indices), axis=2)

            # Matching mask tile, read from the first band of the mask raster
            input_mask = ds_mask.GetRasterBand(1).ReadAsArray(x_start, y_start, tile_size, tile_size).astype(int)

            image_patch_list.append(input_img)
            mask_patch_list.append(input_mask)

    # Drop the dataset references once the file has been tiled
    ds_img = None
    ds_mask = None

    print(f"Processed image: {img_file} --> Processed mask: {mask_file}")

# Convert the lists to NumPy arrays
image_patches = np.array(image_patch_list)
mask_patches = np.array(mask_patch_list)

# Print the shape of the arrays
print("image_patches.shape: {}".format(image_patches.shape))
print("mask_patches.shape: {}".format(mask_patches.shape))

In [None]:
# One-hot encode the integer mask tiles: each pixel label becomes a vector of
# length `num_classes`, added as a new last axis. The class count comes from
# the notebook's `num_classes` setting instead of a repeated literal.
# NOTE(review): this assumes mask values are integers in [0, num_classes) - confirm.
mask_patches_to_categorical = to_categorical(mask_patches, num_classes=num_classes)

In [None]:
# Train-test split: hold out 25% of the tiles, with a fixed seed so the split
# is reproducible.
# NOTE(review): tiles are split individually and neighbouring tiles overlap,
# so held-out tiles can share pixels with training tiles - confirm this is acceptable.
X_train, X_test, y_train, y_test = train_test_split(
    image_patches,
    mask_patches_to_categorical,
    test_size=0.25,
    random_state=22,
)

In [None]:
# Save, print and confirm the model data
output_file = os.path.join(root_model_folder, 'trainng and validation samples.txt')

# Report lines, in the order they are written to the text file
report_lines = [
    "image_patches.shape: {}\n".format(image_patches.shape),
    "mask_patches.shape: {}\n".format(mask_patches.shape),
    "\nX_train shape: {}\n".format(X_train.shape),
    "X_test shape: {}\n".format(X_test.shape),
    "y_train shape: {}\n".format(y_train.shape),
    "y_test shape: {}\n".format(y_test.shape),
    "image height: {}\n".format(X_train.shape[1]),
    "Image width: {}\n".format(X_train.shape[2]),
    "Image channels: {}\n".format(X_train.shape[3]),
]

# Write the whole report in one pass (the file is overwritten on each run)
with open(output_file, "w") as file:
    file.writelines(report_lines)

# Echo the split shapes in the notebook as well
for split_array in (X_train, X_test, y_train, y_test):
    print(split_array.shape)

In [None]:
#-----------------"""**Build the model**"""----------------------------#
# Import module 
# Define the number of classes (ID 1,2)
n_classes = 2


def _unet_conv_block(x, filters):
    """Apply Conv(3x3) -> BatchNorm -> Dropout(0.2) -> Conv(3x3) -> BatchNorm to ``x``."""
    x = Conv2D(filters, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same")(x)
    x = BatchNormalization()(x)
    x = Dropout(0.2)(x)
    x = Conv2D(filters, (3, 3), activation="relu", kernel_initializer="he_normal", padding="same")(x)
    x = BatchNormalization()(x)
    return x


def UNet(n_classes, image_height, image_width, image_channels):
    """Build a three-level U-Net for per-pixel classification.

    The encoder has three conv blocks (64, 128, 256 filters), each followed by
    2x2 max pooling, then a 512-filter bottleneck. The decoder mirrors it with
    three stride-2 transposed convolutions, each concatenated with the encoder
    block of the same resolution.

    The bottleneck previously read from ``p4`` and the first up-convolution
    from ``c6``, both of which only existed in commented-out code, so calling
    this function raised ``NameError``. They are now wired to ``p3`` and
    ``c5`` respectively, matching the three-level layout.

    Args:
        n_classes: Number of output channels of the final softmax layer.
        image_height: Input height in pixels. Should be divisible by 8 so the
            upsampled feature maps match the skip connections.
        image_width: Input width in pixels (same divisibility constraint).
        image_channels: Number of input channels.

    Returns:
        An uncompiled Keras ``Model`` mapping
        ``(image_height, image_width, image_channels)`` inputs to
        ``(image_height, image_width, n_classes)`` softmax outputs.
    """
    inputs = Input((image_height, image_width, image_channels))

    # Fix the seeds before any weighted layer is created. `python_random` in
    # this notebook is the same module as `random`, so one call covers both.
    seed_value = 22
    random.seed(seed_value)
    np.random.seed(seed_value)
    tf.random.set_seed(seed_value)

    # Encoder
    c1 = _unet_conv_block(inputs, 64)
    p1 = MaxPooling2D((2, 2))(c1)

    c2 = _unet_conv_block(p1, 128)
    p2 = MaxPooling2D((2, 2))(c2)

    c3 = _unet_conv_block(p2, 256)
    p3 = MaxPooling2D((2, 2))(c3)

    # Bottleneck: fed by the last active pooling layer (p3)
    c5 = _unet_conv_block(p3, 512)

    # Decoder: the first up-convolution starts from the bottleneck (c5)
    u7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding="same")(c5)
    u7 = concatenate([u7, c3])
    c7 = _unet_conv_block(u7, 256)

    u8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding="same")(c7)
    u8 = concatenate([u8, c2])
    c8 = _unet_conv_block(u8, 128)

    u9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding="same")(c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = _unet_conv_block(u9, 64)

    # Per-pixel class probabilities
    outputs = Conv2D(n_classes, (1, 1), activation="softmax")(c9)

    model = Model(inputs=inputs, outputs=outputs)
    return model
#----------------------------------------------------------------------#
# Create the model, sized from the training tiles (N, height, width, channels)
image_height = X_train.shape[1]
image_width = X_train.shape[2]
image_channels = X_train.shape[3]
model = UNet(n_classes=n_classes,
             image_height=image_height,
             image_width=image_width,
             image_channels=image_channels)
#----------------------------------------------------------------------#
# Compile the model
# The targets are one-hot encoded and the network outputs softmax
# probabilities. The plain IoU / MeanIoU metrics expect integer class labels
# for both, so the one-hot variants (which take the argmax over the last axis
# first) are used here. The explicit names keep the training-history keys
# ('io_u', 'mean_io_u') that the plotting code reads.
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss=BinaryCrossentropy(),  # Use BinaryCrossentropy for binary classification
    metrics=[
        BinaryAccuracy(),
        Precision(class_id=1),
        Recall(class_id=1),
        tf.keras.metrics.OneHotIoU(num_classes=2, target_class_ids=[1], name='io_u'),
        tf.keras.metrics.OneHotMeanIoU(num_classes=2, name='mean_io_u'),
        # NOTE(review): these two counters are not restricted to class 1; they
        # presumably count over both output channels - confirm this is intended.
        FalseNegatives(),
        FalsePositives(),
    ])
#----------------------------------------------------------------------#

In [None]:
# Model training
#----------------------------------------------------------------------#
# Define a log directory for checkpoints (created if missing)
log_dir = os.path.join(root_model_folder, 'log')
os.makedirs(log_dir, exist_ok=True)
#----------------------------------------------------------------------#
# Specify the file paths for the periodic weight files and the best model
weight_path = os.path.join(log_dir, "weights.{epoch:02d}-{val_loss:.2f}.hdf5")
best_model_path = os.path.join(root_model_folder, 'save_best_model.hdf5')
#----------------------------------------------------------------------#
# Checkpoint that keeps only the model with the lowest validation loss
checkpoint_best_model = ModelCheckpoint(best_model_path, save_best_only=True, monitor='val_loss', mode='min')
#----------------------------------------------------------------------#
# Checkpoint that saves weights every 50 epochs.
# The keyword was previously misspelled as `ave_weights_only`, so the
# weights-only option was never actually set.
checkpoint_weight = ModelCheckpoint(weight_path, save_weights_only=True, verbose=1, period=50)
#----------------------------------------------------------------------#
# Start recording time
start_time = time()

# Train the model (no class weights are passed), validating on the held-out tiles
history = model.fit(X_train, y_train,
                    batch_size=20,
                    verbose=1,
                    epochs=100,
                    validation_data=(X_test, y_test),
                    callbacks=[checkpoint_best_model, checkpoint_weight],
                    shuffle=True)

# Calculate and print the training time
end_time = time()
training_time = end_time - start_time
print(f"Training time: {training_time} seconds")

In [None]:
#Confusion_matrix and Classification_report
#----------------#
# Confusion_matrix
#----------------#

# Predict on the test data
y_pred = model.predict(X_test)

# Convert the predicted and true masks to class labels (argmax over the class axis)
y_pred_classes = np.argmax(y_pred, axis=-1)
y_test_classes = np.argmax(y_test, axis=-1)

# Class ids and their display names, in matching order
class_ids = [0, 1]
labels = ['Background', 'Pandanus']

# Compute the confusion matrix. Passing the class ids explicitly keeps it 2x2
# (and aligned with the tick labels) even if one class never occurs in the
# held-out tiles or in the predictions.
cm = confusion_matrix(y_test_classes.ravel(), y_pred_classes.ravel(), labels=class_ids)

# Print the confusion matrix
print(cm)
#----------------------------------------------------------------------#
# Plot the confusion matrix using heatmap()
plt.figure()
sns.heatmap(cm, annot=True, cmap='Blues', fmt='d', xticklabels=labels, yticklabels=labels)
plt.title('confusion matrix_heatmap')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.savefig(os.path.join(root_model_folder, 'CM_heatmap.png'), bbox_inches='tight')
plt.show()
print('Saved confusion matrix_heatmap')
#----------------------------------------------------------------------#
#---------------------#
# classification report
#---------------------#
# Explicit labels keep target_names matched to class ids when a class is absent
cr = classification_report(y_test_classes.ravel(), y_pred_classes.ravel(),
                           labels=class_ids, target_names=['Background', 'Pandanus'])
# Print the classification report
print(cr)
#----------------------------------------------------------------------#
# Export confusion matrix and classification report as .txt
file_path = os.path.join(root_model_folder, 'model training performance report.txt')
with open(file_path, 'w') as file:
    file.write(f"Training Time: {training_time} seconds\n")
    file.write("Confusion Matrix:\n")
    file.write(str(cm))
    file.write("\n\n")
    file.write("Classification Report:\n")
    file.write(cr)
print('Saved classification_and_confusion_report')

In [None]:
# Export training_history (If necessary)
# Per-epoch metric values from training become one DataFrame column per metric
history_csv_path = os.path.join(root_model_folder, 'training_history.csv')
history_df = pd.DataFrame(history.history)
# Write the table to CSV without the index column
history_df.to_csv(history_csv_path, index=False)
print('Saved training_history')

In [None]:
# Plot training/validation curves for every tracked metric using history
def plot_history_metric(history, metric_key, title, ylabel, file_name, description, output_dir):
    """Plot one metric's training and validation curves and save the figure.

    Args:
        history: Object returned by ``model.fit``; ``history.history`` must
            contain ``metric_key`` and ``'val_' + metric_key``.
        metric_key: Name of the training metric in ``history.history``.
        title: Figure title.
        ylabel: Label of the y axis.
        file_name: Name of the PNG file written inside ``output_dir``.
        description: Short name used in the "Saved ... graph" message.
        output_dir: Folder the figure is saved to.
    """
    train_values = history.history[metric_key]
    val_values = history.history['val_' + metric_key]
    epochs = range(1, len(train_values) + 1)

    fig = plt.figure(figsize=(10, 8))
    plt.plot(epochs, train_values)
    plt.plot(epochs, val_values)
    plt.title(title)
    plt.ylabel(ylabel)
    plt.xlabel('Epoch')
    plt.legend(['Train', 'Test'], loc='upper left')
    # Lay out the figure before saving so the saved file reflects it
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, file_name), bbox_inches='tight')
    plt.show()
    # Release the figure; eight figures are produced in a row
    plt.close(fig)
    print(f'Saved {description} graph')


# (history key, title, y label, output file, name used in the log message)
metric_plots = [
    ('binary_accuracy', 'Model accuracy', 'Accuracy', 'Accuracy.png', 'accuracy'),
    ('loss', 'Model loss', 'Loss', 'loss.png', 'loss'),
    ('precision', 'Model precision', 'precision(class id=1)', 'precision.png', 'precision'),
    ('recall', 'Model recall', 'recall(class id=1)', 'recall.png', 'recall'),
    ('io_u', 'Model IoU', 'IoU(class id=1)', 'IoU.png', 'IoU'),
    ('mean_io_u', 'Model MeanIoU', 'MeanIoU', 'MeanIoU.png', 'MeanIoU'),
    ('false_negatives', 'Model FalseNegatives', 'FalseNegatives', 'FalseNegatives.png', 'FalseNegatives'),
    ('false_positives', 'Model FalsePositives', 'FalsePositives', 'FalsePositives.png', 'FalsePositives'),
]

for metric_key, plot_title, plot_ylabel, plot_file, plot_name in metric_plots:
    plot_history_metric(history, metric_key, plot_title, plot_ylabel,
                        plot_file, plot_name, root_model_folder)

In [None]:
#IOU
# Calculate and save IoU for each class, then the average over classes
class_iou = []
with open(file_path, 'a') as file:
    file.write("\n\nIoU Results:\n")
    for i in range(2):
        true_class = (y_test_classes == i)
        pred_class = (y_pred_classes == i)
        intersection = np.logical_and(true_class, pred_class).sum()
        union = np.logical_or(true_class, pred_class).sum()
        # A class absent from both the labels and the predictions has an empty
        # union; report NaN explicitly instead of dividing by zero.
        iou = intersection / union if union > 0 else float('nan')
        class_iou.append(iou)
        # Classes are reported 1-based (class index i is written as i+1)
        file.write("IoU for class {}: {:.2f}\n".format(i+1, iou))
        print("IoU for class {}: {:.2f}".format(i+1, iou))

    # Average only over classes whose IoU is defined
    defined_iou = [value for value in class_iou if not np.isnan(value)]
    average_iou = np.mean(defined_iou) if defined_iou else float('nan')
    file.write("Average IoU: {:.2f}".format(average_iou))
    print("Average IoU: {:.2f}".format(average_iou))
print('Saved IoU results')
#-------------------------xxxxxx---------------------------------------#