In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
import tifffile
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LearningRateScheduler
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split

# ---- CONFIG ---- #
# tile_size: edge length (pixels) of the windows cut out of each source raster.
# size:      edge length every window is resized to before it is stored.
tile_size, size = 256, 256
# Directories that are scanned for the image rasters and the mask rasters.
image_dir, mask_dir = "datasets2/images", "datasets2/masks"

# ---- Function to split large image into tiles ---- #
def split_image_into_tiles(image_path, mask_path, tile_size):
    """Cut an image/mask TIFF pair into aligned tiles of ``size`` x ``size`` pixels.

    The image is walked in ``tile_size`` steps (columns in the outer loop, rows
    in the inner loop). Every window, including the smaller ones at the right
    and bottom borders, is resized to the module-level ``size``. Masks are
    binarised (any value > 0 becomes 1) and returned as ``uint8``.

    After tiling, all-zero tiles are appended until the tile count is a
    perfect square.

    Args:
        image_path: Path of the image TIFF (2-D grayscale or 3-D with a
            trailing channel axis).
        mask_path: Path of the mask TIFF. For a 3-D mask only channel 0 is used.
        tile_size: Step and window edge length in pixels.

    Returns:
        Tuple ``(tiles_img, tiles_mask)`` of NumPy arrays with one tile per
        entry along the first axis.
    """
    img = tifffile.imread(image_path)
    mask = tifffile.imread(mask_path)
    # A multi-channel mask carries the label in its first channel only.
    if mask.ndim == 3:
        mask = mask[:, :, 0]

    tiles_img = []
    tiles_mask = []

    for x in range(0, img.shape[1], tile_size):
        for y in range(0, img.shape[0], tile_size):
            # Slicing only the two leading axes works for 2-D and 3-D rasters.
            tile_img = img[y:y + tile_size, x:x + tile_size]
            tile_mask = mask[y:y + tile_size, x:x + tile_size]

            tile_img = cv2.resize(tile_img, (size, size))

            # Binarise first, then resize with nearest-neighbour: interpolating
            # a label map and thresholding at > 0 would grow every labelled
            # region on stretched border tiles. Converting to uint8 up front
            # also keeps cv2.resize away from dtypes it cannot handle.
            tile_mask = (tile_mask > 0).astype(np.uint8)
            tile_mask = cv2.resize(tile_mask, (size, size), interpolation=cv2.INTER_NEAREST)

            tiles_img.append(tile_img)
            tiles_mask.append(tile_mask)

    # Padding to form square grid
    num_tiles = len(tiles_img)
    perfect_square_size = int(np.ceil(np.sqrt(num_tiles)))
    total_tiles_needed = perfect_square_size ** 2
    num_tiles_to_add = total_tiles_needed - num_tiles

    if num_tiles_to_add:
        # Derive the blank tile from a real one so its shape and dtype always
        # match (grayscale rasters have no img.shape[2] to read).
        blank_img = np.zeros_like(tiles_img[0])
        blank_mask = np.zeros((size, size), dtype=np.uint8)
        tiles_img.extend([blank_img] * num_tiles_to_add)
        tiles_mask.extend([blank_mask] * num_tiles_to_add)

    return np.array(tiles_img), np.array(tiles_mask)

# ---- Load all image-mask tiles ---- #
def load_data(image_dir, mask_dir, tile_size):
    """Tile every ``*.TIF`` image that has a matching ``*_mask.TIF`` mask.

    Images are visited in sorted filename order. For an image ``<stem>.TIF``
    the mask ``<stem>_mask.TIF`` is looked up in ``mask_dir``; images without
    such a mask are skipped.

    Args:
        image_dir: Directory containing the image TIFFs.
        mask_dir: Directory containing the mask TIFFs.
        tile_size: Tile edge length forwarded to ``split_image_into_tiles``.

    Returns:
        Tuple ``(images, masks)`` of NumPy arrays holding the tiles of all
        matched pairs, concatenated along the first axis.
    """
    image_suffix = ".TIF"
    images, masks = [], []
    image_filenames = sorted(os.listdir(image_dir))
    # Set membership keeps the per-image mask lookup constant-time.
    mask_filenames = set(os.listdir(mask_dir))

    for image_filename in image_filenames:
        if not image_filename.endswith(image_suffix):
            continue

        # Swap only the trailing extension; str.replace would also rewrite a
        # ".TIF" that happens to occur earlier in the file name.
        stem = image_filename[:-len(image_suffix)]
        mask_filename = stem + "_mask" + image_suffix
        if mask_filename not in mask_filenames:
            continue

        img_path = os.path.join(image_dir, image_filename)
        mask_path = os.path.join(mask_dir, mask_filename)
        img_tiles, mask_tiles = split_image_into_tiles(img_path, mask_path, tile_size)
        images.extend(img_tiles)
        masks.extend(mask_tiles)

    return np.array(images), np.array(masks)

# ---- Load and prepare dataset ---- #
tiles_img, tiles_mask = load_data(image_dir, mask_dir, tile_size)

# Hold out 20% of the tiles for testing, then carve 10% of what is left for
# validation; both splits use the same fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    tiles_img, tiles_mask, test_size=0.2, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.1, random_state=42
)

# Give every mask a trailing channel axis: (H, W) -> (H, W, 1)
y_train, y_val, y_test = (
    np.expand_dims(mask_set, axis=-1) for mask_set in (y_train, y_val, y_test)
)

# ---- Define AlexNet-style Segmentation Model ---- #
def alexnet_segmentation_model(input_size=(256, 256, 3)):
    """Build and compile an AlexNet-style encoder/decoder for binary segmentation.

    The encoder is five convolutions with three max-pools. The stride-4 first
    convolution plus the three stride-2 pools shrink the spatial size by about
    32x. The decoder is four stride-2 transposed convolutions (16x) with no
    skip connections, so the sigmoid map comes out at about half the input
    resolution and is resized bilinearly to ``input_size[:2]``.

    Args:
        input_size: ``(height, width, channels)`` of the network input.

    Returns:
        A compiled Keras ``Model`` (Adam optimizer, binary cross-entropy loss,
        accuracy metric). The model summary is printed as a side effect.
    """
    inputs = Input(shape=input_size)

    # Encoder (AlexNet-style)
    x = Conv2D(96, (11, 11), strides=4, padding='same', activation='relu')(inputs)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

    x = Conv2D(256, (5, 5), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

    x = Conv2D(384, (3, 3), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(384, (3, 3), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)

    x = Conv2D(256, (3, 3), padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

    # Decoder (Upsampling only, no skip connections): each stage doubles the
    # spatial size and halves the filter count.
    for filters in (256, 128, 64, 32):
        x = Conv2DTranspose(filters, (3, 3), strides=2, padding='same', activation='relu')(x)
        x = BatchNormalization()(x)

    output = Conv2D(1, (1, 1), activation='sigmoid')(x)

    # Resize to match ground truth. A Resizing *layer* is used instead of a
    # raw tf.image.resize call because Keras 3 rejects TensorFlow ops applied
    # directly to symbolic Keras tensors; the interpolation is the same.
    output = tf.keras.layers.Resizing(
        input_size[0], input_size[1], interpolation='bilinear'
    )(output)

    model = Model(inputs, output)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    model.summary()
    return model

# ---- Build Model ---- #
# Square, three-channel input whose edge length is the configured `size`.
alexnet_model = alexnet_segmentation_model((size, size, 3))

# ---- Learning Rate Schedule ---- #
def lr_schedule(epoch):
    """Step-decay learning rate for the given 0-based epoch index.

    Starts at 1e-4 and is multiplied by 0.9 once for every 10 full epochs
    that have elapsed. The function deliberately takes ``epoch`` only.
    """
    base_rate, shrink_factor, epochs_per_stage = 0.0001, 0.9, 10
    completed_stages = epoch // epochs_per_stage
    return base_rate * (shrink_factor ** completed_stages)

lr_scheduler = LearningRateScheduler(lr_schedule)

# ---- Callbacks ---- #
checkpointer = ModelCheckpoint("alexnet_segmentation_best.h5", monitor="val_loss", save_best_only=True, verbose=1)
earlyStopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1)

# ---- Data Augmentation ---- #
datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    brightness_range=[0.8, 1.2]
)

batch_size = 32


def paired_augmentation_flow(images, masks, generator, batch_size, seed=None):
    """Yield endless ``(image_batch, mask_batch)`` pairs with aligned augmentation.

    ``generator.flow(images, masks)`` transforms only the images and passes
    the labels through untouched, so rotated/shifted/flipped images would no
    longer line up with their masks. Here one set of random parameters is
    drawn per sample and its geometric part is applied to both the image and
    the mask. Brightness and rescaling are applied to the image only.

    Args:
        images: Array of training images, one sample per leading index.
        masks: Array of ``(H, W, 1)`` binary masks aligned with ``images``.
        generator: Configured ``ImageDataGenerator`` supplying the transforms.
        batch_size: Number of samples per yielded batch.
        seed: Optional seed for the per-epoch shuffling order.

    Yields:
        Tuples of float32 arrays ``(image_batch, mask_batch)``.
    """
    rng = np.random.default_rng(seed)
    sample_count = len(images)
    while True:
        order = rng.permutation(sample_count)
        for start in range(0, sample_count, batch_size):
            batch_images, batch_masks = [], []
            for index in order[start:start + batch_size]:
                image = images[index].astype("float32")
                mask = masks[index].astype("float32")

                params = generator.get_random_transform(image.shape)
                image = generator.standardize(generator.apply_transform(image, params))

                # Same geometry for the mask, but no photometric change.
                mask = generator.apply_transform(mask, dict(params, brightness=None))
                # Interpolation leaves fractional values on borders; re-binarise.
                mask = (mask > 0.5).astype("float32")

                batch_images.append(image)
                batch_masks.append(mask)
            yield np.stack(batch_images), np.stack(batch_masks)


# ---- Train Model ---- #
history = alexnet_model.fit(
    paired_augmentation_flow(X_train, y_train, datagen, batch_size, seed=42),
    # A Python generator has no length, so the epoch size must be explicit.
    steps_per_epoch=int(np.ceil(len(X_train) / batch_size)),
    epochs=100,
    validation_data=(X_val / 255.0, y_val),
    callbacks=[lr_scheduler, checkpointer, earlyStopping]
)

# ---- Evaluate Model ---- #
loss, accuracy = alexnet_model.evaluate(X_test / 255.0, y_test)
print(f"Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}")

# ---- Save Final Model ---- #
alexnet_model.save("alexnet_segmentation_model_final.h5")

# ---- Plot Training History ---- #
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.xlabel("Epochs")
plt.ylabel("Loss/Accuracy")
plt.title("AlexNet Segmentation Training History")
plt.legend()
plt.show()


In [None]:
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    roc_auc_score
)
from tensorflow.keras.models import load_model

# Load the saved model
loaded_model = load_model('alexnet_segmentation_best.h5')

# Predict probabilities on the test set.
# The network was trained and validated on pixels scaled by 1/255, so the
# test images must be scaled the same way before inference.
y_pred = loaded_model.predict(X_test / 255.0)

# Threshold the predictions to get binary values (0 or 1)
y_pred_binary = (y_pred > 0.5).astype(int)

# Flatten arrays for metric calculations
y_test_flat = y_test.flatten()
y_pred_flat = y_pred_binary.flatten()
y_pred_prob_flat = y_pred.flatten()  # Keep raw probabilities for ROC-AUC

# Accuracy
accuracy = accuracy_score(y_test_flat, y_pred_flat)
print(f'Accuracy: {accuracy:.4f}')

# F1 Score
f1 = f1_score(y_test_flat, y_pred_flat)
print(f'F1 Score: {f1:.4f}')

# Precision
precision = precision_score(y_test_flat, y_pred_flat)
print(f'Precision: {precision:.4f}')

# Recall
recall = recall_score(y_test_flat, y_pred_flat)
print(f'Recall: {recall:.4f}')

# Confusion Matrix (labels fixed so the matrix is 2x2 even if a class is absent)
conf_matrix = confusion_matrix(y_test_flat, y_pred_flat, labels=[0, 1])
print('Confusion Matrix:')
print(conf_matrix)

# Intersection over Union of the foreground class
intersection = np.sum(np.logical_and(y_test_flat, y_pred_flat))
union = np.sum(np.logical_or(y_test_flat, y_pred_flat))
iou_foreground = intersection / union if union > 0 else float('nan')
print(f'Foreground IoU: {iou_foreground:.4f}')

# Mean Intersection over Union (mIoU): average of the background and
# foreground IoU; a class that is absent from both labels and predictions is
# left out of the mean.
tn, fp, fn, tp = conf_matrix.ravel()
class_ious = [
    hits / total
    for hits, total in ((tn, tn + fp + fn), (tp, tp + fp + fn))
    if total > 0
]
miou = float(np.mean(class_ious)) if class_ious else float('nan')
print(f'Mean Intersection over Union (mIoU): {miou:.4f}')

# ROC-AUC Score
roc_auc = roc_auc_score(y_test_flat, y_pred_prob_flat)
print(f'ROC-AUC Score: {roc_auc:.4f}')