## Load libraries

In [None]:
import os
import random
import sys

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

from osgeo import gdal
import tensorflow as tf
import numpy as np
from sklearn.preprocessing import StandardScaler
from joblib import dump, load
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Load test datasets

In [None]:
# Load every multi-band .tif file in a directory into one stacked NumPy array.
def load_data(directory):
    """
    Load multi-band .tif files from a directory into a NumPy array.

    Files are read in sorted filename order, so image and label directories
    whose files share names line up index-by-index. Files GDAL cannot open
    are reported and skipped.

    Args:
        directory (str): Path to the directory containing .tif files.

    Returns:
        tuple:
            np.ndarray or None: Stack of the loaded files as a 4D array
                (num_files, height, width, bands), or None if no file
                could be loaded.
            list: Filenames in the order they were stacked (empty if none).

    Raises:
        ValueError: If the loaded files do not all share the same
            (height, width, bands) shape and therefore cannot be stacked.
    """
    tif_files = sorted(f for f in os.listdir(directory) if f.endswith('.tif'))

    arrays = []
    filenames = []

    for tif_file in tif_files:
        file_path = os.path.join(directory, tif_file)

        try:
            dataset = gdal.Open(file_path)
        except RuntimeError:
            # With gdal.UseExceptions() enabled GDAL raises instead of
            # returning None; treat both the same way (skip the file).
            dataset = None
        if dataset is None:
            print(f"Failed to load {file_path}")
            continue

        try:
            # GDAL band indices are 1-based.
            bands = [dataset.GetRasterBand(i).ReadAsArray()
                     for i in range(1, dataset.RasterCount + 1)]
        finally:
            # Dropping the last reference closes the GDAL dataset, even if a
            # band read failed.
            dataset = None

        # Stack bands along the last axis -> (height, width, bands)
        arrays.append(np.stack(bands, axis=-1))
        filenames.append(tif_file)

    if not arrays:
        print("No .tif files loaded.")
        return None, []

    # np.stack would also fail on mismatched shapes, but without saying which
    # file is the odd one out.
    expected_shape = arrays[0].shape
    for name, array in zip(filenames, arrays):
        if array.shape != expected_shape:
            raise ValueError(
                f"Cannot stack {name}: shape {array.shape} differs from "
                f"{filenames[0]} shape {expected_shape}"
            )

    stacked_array = np.stack(arrays, axis=0)
    print(f"Loaded {len(arrays)} .tif files into array of shape {stacked_array.shape}")
    return stacked_array, filenames

In [None]:
# Size constants and class count. Images are NOT resized anywhere in this
# notebook. SIZE_X/SIZE_Y are not referenced by the later cells; presumably
# they record the training patch size (it matches patch_size=256 used for
# prediction) - TODO confirm.
SIZE_X = 256
SIZE_Y = 256
n_classes = 3  # Number of classes for segmentation

# load_data returns (stacked_array_or_None, filenames); keep only the array.
X_test = load_data("path to the test image/")[0]
# Optional: subset to the first 36 images / first 2 bands for quick experiments.
#X_test = X_test[:36,:,:,:2]
y_test = load_data("path to the test label/")[0]  # if required

# Fail here with a clear message instead of an opaque error on `None.shape`
# in the standardisation cell.
if X_test is None:
    raise FileNotFoundError("No test images could be loaded from 'path to the test image/'")
if y_test is None:
    print("No test labels loaded; the one-hot, visualisation and evaluation cells need y_test.")

# Standardisation 

In [None]:
# Flatten the image stack to one row per pixel so the scaler treats each band
# as a separate feature. NOTE: despite their names, `width` receives
# X_test.shape[1] (rows) and `height` receives X_test.shape[2] (columns);
# later cells rely on these names, so they are kept unchanged.
num_samples, width, height, num_channels = X_test.shape
X_test_reshaped = X_test.reshape((-1, num_channels))

# Reuse the saved standardisation parameters instead of refitting on the
# prediction data. joblib.load unpickles the file - only load trusted files.
scaler = load("path to/pretrained_scaler_4model_predictions.joblib")
X_test_normalized = scaler.transform(X_test_reshaped)

In [None]:
# Undo the per-pixel flattening now that every band is standardised:
# back to (num_samples, width, height, num_channels).
X_test = X_test_normalized.reshape((num_samples, width, height, num_channels))

# Release the flattened intermediates; only X_test is used from here on.
del X_test_reshaped
del X_test_normalized

## One-hot encoding

In [None]:
# One-hot encode the integer label mask into n_classes channels; the result
# is the ground truth passed to evaluate_segmentation.
y_test_cat = tf.keras.utils.to_categorical(y_test, n_classes)

## Model Prediction

In [None]:
# Evaluate one pretrained checkpoint at a time: point `path` at its .hdf5 file.
path = "path to pretrained saved model/UNeTASMMonitoring_30_epochs_FocalLossAndDiceLoss.hdf5"

# compile=False skips restoring the training configuration (loss/optimizer);
# only inference is done here. NOTE(review): the filename suggests a custom
# focal + dice loss, which would otherwise have to be supplied via custom_objects.
Attmodel = tf.keras.models.load_model(filepath=path, compile=False)


In [None]:
# predict_large_image takes a single (rows, cols, channels) scene, so this
# notebook expects exactly one image in the test directory. `width`/`height`
# hold X_test.shape[1]/X_test.shape[2] from the standardisation cell.
if num_samples != 1:
    raise ValueError(
        f"Expected exactly one test image for large-image prediction, got {num_samples}"
    )
X_test = X_test.reshape(width, height, num_channels)

def _tile_starts(length, patch_size, stride):
    """
    Return patch start offsets that tile [0, length) along one axis.

    Offsets step by `stride` from 0. If that grid does not end flush with the
    edge, one extra offset `length - patch_size` is appended so the final
    pixels are covered. Every offset appears exactly once.
    Requires length >= patch_size.
    """
    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts


def predict_large_image(image, model, patch_size=256, num_classes=3, overlap=0.2, batch_size=8):
    """
    Predict a large image using a UNet model trained on smaller patches.

    The image is tiled into overlapping patch_size x patch_size windows. The
    last row/column of windows is shifted to end flush with the image edge,
    and each window is predicted exactly once. Per-pixel outputs are then
    averaged over all windows covering that pixel. Images smaller than one
    patch are zero-padded on the bottom/right before prediction (for
    StandardScaler-standardised input, 0 is the per-band mean), and the
    padding is cropped off the result.

    Parameters:
        image (np.array): (H, W, C) input image (standardized/normalized as required by model)
        model (tf.keras.Model): Trained UNet model. model.predict must map
            (n, patch_size, patch_size, C) -> (n, patch_size, patch_size, num_classes)
        patch_size (int): Size of patches used during training
        num_classes (int): Number of output classes (for softmax)
        overlap (float): Fraction of overlap between patches in [0, 1), e.g., 0.2 = 20%
        batch_size (int): Number of patches sent to model.predict per call

    Returns:
        output (np.array): Averaged per-pixel prediction (H, W, num_classes)

    Raises:
        ValueError: If overlap is outside [0, 1) or batch_size < 1.
    """
    if not 0 <= overlap < 1:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    H, W, _ = image.shape
    # Step between patch origins; at least 1 so the tiling always advances.
    stride = max(1, int(patch_size * (1 - overlap)))

    # Pad images smaller than one patch so every window is full-sized.
    pad_h = max(0, patch_size - H)
    pad_w = max(0, patch_size - W)
    if pad_h or pad_w:
        image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode="constant")
    padded_h, padded_w = image.shape[:2]

    # Deduplicated grid of window origins (includes the bottom/right borders
    # and the bottom-right corner exactly once).
    origins = [
        (y, x)
        for y in _tile_starts(padded_h, patch_size, stride)
        for x in _tile_starts(padded_w, patch_size, stride)
    ]

    output = np.zeros((padded_h, padded_w, num_classes))
    # One counter per pixel, broadcast over the class axis when averaging.
    count = np.zeros((padded_h, padded_w, 1))

    for start in range(0, len(origins), batch_size):
        batch_origins = origins[start:start + batch_size]
        batch = np.stack([
            image[y:y + patch_size, x:x + patch_size, :] for y, x in batch_origins
        ])
        preds = model.predict(batch, verbose=0)  # (n, patch_size, patch_size, num_classes)
        for (y, x), pred in zip(batch_origins, preds):
            output[y:y + patch_size, x:x + patch_size, :] += pred
            count[y:y + patch_size, x:x + patch_size, :] += 1

    # Every pixel is covered by at least one window, so count >= 1 everywhere.
    output /= count
    return output[:H, :W, :]

In [None]:
# Prediction over the test dataset: tile the scene into 256-pixel patches with
# 50% overlap, then collapse per-class scores to one class index per pixel.
pred_test = predict_large_image(X_test, Attmodel, patch_size=256, num_classes=n_classes, overlap=0.5)

pred_test = np.argmax(pred_test, axis=-1)
print(pred_test.shape)

# Visualise Predictions

In [None]:
import matplotlib.colors as mcolors

# Display colour (hex) and legend label for each class index.
class_info = {
    0: ("#023020", "Vegetation"),  # Dark green
    1: ("#ffbf00", "Mines"),       # Gold
    2: ("#999999", "Other"),       # Med Gray
}

# Split into parallel lists ordered by class index 0..N-1.
color_list, class_names = map(list, zip(*(class_info[i] for i in range(len(class_info)))))

# Discrete colormap with one bin per class: value k falls in [k, k+1) -> colour k.
cmap = mcolors.ListedColormap(color_list)
bounds = np.arange(len(class_info) + 1)
norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=len(class_info))

In [None]:
# Side-by-side view of the predicted class map and the ground-truth mask.
# Passing `norm` pins each class index to its own colour; without it imshow
# rescales the colormap to each map's data range, so colours shift whenever
# a class is absent from one of the maps.
plt.figure(figsize=(12, 6))

plt.subplot(121)
plt.imshow(pred_test, cmap=cmap, norm=norm, interpolation='nearest')
plt.title('Predicted')
plt.axis('off')
#plt.savefig("West.png", dpi=600, bbox_inches='tight', pad_inches=0.0)

plt.subplot(122)
plt.imshow(y_test[0, :, :, 0], cmap=cmap, norm=norm, interpolation='nearest')
plt.title('True Label')
plt.axis('off')

plt.show()

## Model Performance Evaluation

In [None]:
# Per-class probability map used by evaluate_segmentation. Same tiling
# settings as the prediction cell above, so this recomputes the same scene.
CNN_pred = predict_large_image(X_test, Attmodel, patch_size=256, num_classes=n_classes, overlap=0.5)
CNN_pred_argmax = np.argmax(CNN_pred, axis=-1)

In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score, jaccard_score

def evaluate_segmentation(y_true, y_pred, class_names=None, decimals=2):
    """
    Compute per-class and overall Precision, Recall, IoU, F1-score, and Accuracy
    for multi-class semantic segmentation.

    Both inputs are reduced to class indices with argmax over the last axis and
    then flattened. They therefore only need the class axis last and the same
    number of pixels; e.g. y_true of shape (1, H, W, C) with y_pred of shape
    (H, W, C) is accepted.

    Parameters:
        y_true (numpy array): Ground truth masks, one-hot encoded, class axis last - e.g. (N, H, W, C)
        y_pred (numpy array): Predicted masks (probabilities or one-hot), class axis last
        class_names (list): Optional list of class names. Length must equal number of classes (C).
        decimals (int): Decimal places every metric is rounded to (default 2).

    Returns:
        dict: {"Per-Class Metrics": {name: {"Precision", "Recall", "F1-score", "IoU"}},
               "Precision", "Recall", "F1", "Mean IoU", "Accuracy"}.
            Per-class entries cover every class index 0..C-1. Overall Precision,
            Recall, F1 and Mean IoU are unweighted (macro) means over the classes
            that occur in y_true or y_pred. Accuracy is overall pixel accuracy.
            All values are plain Python floats.

    Raises:
        ValueError: If class_names is given but its length differs from C.
    """
    # Convert to class indices, one entry per pixel
    y_true_labels = np.argmax(y_true, axis=-1).flatten()
    y_pred_labels = np.argmax(y_pred, axis=-1).flatten()

    num_classes = y_true.shape[-1]
    class_indices = list(range(num_classes))

    # Default class names if not provided; otherwise they must match the class
    # axis 1:1, or zip() below would silently drop or mislabel classes.
    if class_names is None:
        class_names = [f"Class_{i}" for i in class_indices]
    elif len(class_names) != num_classes:
        raise ValueError(
            f"class_names has {len(class_names)} entries but y_true has {num_classes} classes"
        )

    def _rounded(value):
        # Plain Python float (not a NumPy scalar) so printed results stay readable.
        return round(float(value), decimals)

    # --- Per-class metrics ---
    # zero_division=1: a class with no predicted/true pixels scores 1.0 instead of warning.
    precision_per_class = precision_score(y_true_labels, y_pred_labels, average=None, labels=class_indices, zero_division=1)
    recall_per_class = recall_score(y_true_labels, y_pred_labels, average=None, labels=class_indices, zero_division=1)
    f1_per_class = f1_score(y_true_labels, y_pred_labels, average=None, labels=class_indices, zero_division=1)
    iou_per_class = jaccard_score(y_true_labels, y_pred_labels, average=None, labels=class_indices)

    # --- Overall metrics (macro) ---
    precision_macro = precision_score(y_true_labels, y_pred_labels, average='macro', zero_division=1)
    recall_macro = recall_score(y_true_labels, y_pred_labels, average='macro', zero_division=1)
    f1_macro = f1_score(y_true_labels, y_pred_labels, average='macro', zero_division=1)
    iou_macro = jaccard_score(y_true_labels, y_pred_labels, average='macro')
    accuracy = accuracy_score(y_true_labels, y_pred_labels)

    # --- Organize per-class metrics into readable dict ---
    per_class_metrics = {
        name: {
            "Precision": _rounded(p),
            "Recall": _rounded(r),
            "F1-score": _rounded(f),
            "IoU": _rounded(i),
        }
        for name, p, r, f, i in zip(class_names, precision_per_class, recall_per_class, f1_per_class, iou_per_class)
    }

    # --- Final dictionary ---
    results = {
        "Per-Class Metrics": per_class_metrics,
        "Precision": _rounded(precision_macro),
        "Recall": _rounded(recall_macro),
        "F1": _rounded(f1_macro),
        "Mean IoU": _rounded(iou_macro),
        "Accuracy": _rounded(accuracy),
    }

    return results


In [None]:
# Score the probability map against the one-hot ground truth, then print the
# per-class metrics followed by the overall metrics.
class_names = ["Vegetation", "Mines", "Other"]
metrics = evaluate_segmentation(y_test_cat, CNN_pred, class_names=class_names)

for cls, vals in metrics["Per-Class Metrics"].items():
    print(f"{cls}: {vals}")

print("\nOverall Metrics:")
for k, v in metrics.items():
    if k == "Per-Class Metrics":
        continue  # already printed above
    print(f"{k}: {v}")
