# In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import matplotlib.pyplot as plt
import plotly.figure_factory as ff

# Montgomery dataset locations: chest X-rays plus separate left/right lung masks.
# The three folders share one root, so the root is spelled out only once.
_montgomery_root = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Montgomery_Dataset'
image_dir = _montgomery_root + r'\CXR_png'
left_mask_dir = _montgomery_root + r'\Masks_png\leftMask'
right_mask_dir = _montgomery_root + r'\Masks_png\rightMask'

# Target size every image and mask is resized to before training.
img_size = (256, 256)

def load_images_and_masks(image_dir, left_mask_dir, right_mask_dir, img_size):
    """Load grayscale images and their combined left/right masks.

    Every entry of ``image_dir`` is read as a grayscale image; the masks are
    looked up under the same file name in ``left_mask_dir`` and
    ``right_mask_dir``. Entries whose image or either mask cannot be decoded
    are reported with ``print`` and skipped.

    Args:
        image_dir: Folder containing the input images.
        left_mask_dir: Folder containing the left-lung masks.
        right_mask_dir: Folder containing the right-lung masks.
        img_size: Target size handed to ``cv2.resize`` as ``dsize``
            (OpenCV interprets it as ``(width, height)``).

    Returns:
        A tuple ``(images, masks)`` of NumPy arrays. Each sample has a trailing
        channel axis of length 1 and values divided by 255. When nothing could
        be loaded both arrays are empty.
    """
    images = []
    masks = []

    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order (and therefore any seeded split of the result) reproducible.
    for img_name in sorted(os.listdir(image_dir)):
        # File names already carry their extension, so they are used as-is.
        img_path = os.path.join(image_dir, img_name)
        left_mask_path = os.path.join(left_mask_dir, img_name)
        right_mask_path = os.path.join(right_mask_dir, img_name)

        # cv2.imread returns None instead of raising when a file cannot be read.
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            print(f"Image not found: {img_path}")
            continue

        left_mask = cv2.imread(left_mask_path, cv2.IMREAD_GRAYSCALE)
        right_mask = cv2.imread(right_mask_path, cv2.IMREAD_GRAYSCALE)
        if left_mask is None or right_mask is None:
            print(f"Mask not found: {left_mask_path} or {right_mask_path}")
            continue

        # Resize only after all three files are known to be usable.
        img = cv2.resize(img, img_size) / 255.0  # Normalize image to range 0-1

        # Nearest-neighbour resampling never blends neighbouring pixels, so the
        # resized masks contain only values present in the mask files. The
        # default (bilinear) would introduce intermediate values at the edges.
        left_mask = cv2.resize(left_mask, img_size, interpolation=cv2.INTER_NEAREST)
        right_mask = cv2.resize(right_mask, img_size, interpolation=cv2.INTER_NEAREST)

        # A pixel belongs to the combined mask if it is set in either mask.
        combined_mask = np.maximum(left_mask, right_mask) / 255.0

        # Add the channel dimension expected by the network.
        images.append(np.expand_dims(img, axis=-1))
        masks.append(np.expand_dims(combined_mask, axis=-1))

    return np.array(images), np.array(masks)

# Read every usable image/mask pair from the dataset folders
images, masks = load_images_and_masks(
    image_dir=image_dir,
    left_mask_dir=left_mask_dir,
    right_mask_dir=right_mask_dir,
    img_size=img_size,
)

# Stop right away if either array came back empty
if not (len(images) and len(masks)):
    raise ValueError("No images or masks were loaded. Check the dataset paths and file names.")

# Capsule Layer Helper Functions
from tensorflow.keras import layers

def squash(vectors, axis=-1):
    """Rescale vectors along `axis` so that their lengths fall between 0 and 1.

    With n2 the squared norm along `axis`, every vector is multiplied by
    n2 / (1 + n2) / sqrt(n2 + epsilon); epsilon is the Keras backend fuzz
    factor and keeps the square root away from zero.
    """
    sq_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    eps = tf.keras.backend.epsilon()

    gain = sq_norm / (1 + sq_norm)
    gain = gain / tf.sqrt(sq_norm + eps)
    return vectors * gain

class CapsuleLayer(layers.Layer):
    """Capsule transform with routing, applied independently at every position.

    The last axis of the input is treated as one input vector per position
    (all leading axes are kept as they are). That vector is projected into
    ``num_capsules`` prediction vectors of length ``dim_capsule``, the
    predictions are weighted by iteratively refined coupling coefficients and
    passed through ``squash``.

    The capsule axes are flattened on output, so an input of shape
    ``[..., input_dim]`` yields ``[..., num_capsules * dim_capsule]``. Keeping a
    single channel axis lets the result feed pooling, convolution and
    concatenation layers directly.

    Args:
        num_capsules: Number of output capsules per position.
        dim_capsule: Length of each output capsule vector.
        num_routing: Number of routing iterations; must be at least 1.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``num_routing`` is smaller than 1, or (at build time)
            if the size of the input's last axis is unknown.
    """

    def __init__(self, num_capsules, dim_capsule, num_routing=3, **kwargs):
        super().__init__(**kwargs)
        # With zero iterations the routing loop would never produce an output.
        if num_routing < 1:
            raise ValueError(f"num_routing must be >= 1, got {num_routing}")
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.num_routing = num_routing

    def build(self, input_shape):
        input_dim = input_shape[-1]
        if input_dim is None:
            raise ValueError("The last dimension of the input to CapsuleLayer must be defined.")
        # One projection matrix per output capsule, shared by all positions.
        self.W = self.add_weight(
            name='W',
            shape=[input_dim, self.num_capsules, self.dim_capsule],
            initializer='glorot_uniform',
            trainable=True,
        )

    def call(self, inputs):
        # inputs: [..., input_dim] -> u_hat: [..., num_capsules, dim_capsule].
        # The weight uses explicit subscripts (no ellipsis) so its first axis is
        # contracted with the input channels rather than broadcast against the
        # leading axes of the input.
        u_hat = tf.einsum('...i,ijk->...jk', inputs, self.W)
        # Only the last routing iteration back-propagates into W.
        u_hat_stopped = tf.stop_gradient(u_hat)

        # Routing logits, one per output capsule and position: [..., num_capsules].
        b = tf.zeros_like(u_hat[..., 0])
        for i in range(self.num_routing):
            is_last = i == self.num_routing - 1
            # Coupling coefficients sum to 1 over the output capsules.
            c = tf.expand_dims(tf.nn.softmax(b, axis=-1), axis=-1)
            outputs = squash(c * (u_hat if is_last else u_hat_stopped), axis=-1)
            if not is_last:
                # Agreement between each capsule's output and its prediction.
                b += tf.reduce_sum(outputs * u_hat_stopped, axis=-1)

        # Merge [..., num_capsules, dim_capsule] into [..., num_capsules * dim_capsule]
        # (capsule-major order) without relying on dynamic shapes.
        return tf.concat(tf.unstack(outputs, num=self.num_capsules, axis=-2), axis=-1)

    def compute_output_shape(self, input_shape):
        return tuple(input_shape[:-1]) + (self.num_capsules * self.dim_capsule,)

    def get_config(self):
        # Expose constructor arguments so a model using this layer can be
        # serialized and re-created.
        config = super().get_config()
        config.update({
            'num_capsules': self.num_capsules,
            'dim_capsule': self.dim_capsule,
            'num_routing': self.num_routing,
        })
        return config

# U-Net with Capsule Network Layers
def unet_capsule_model(input_size=(256, 256, 1)):
    """Assemble a U-Net-style segmentation model with CapsuleLayer in the encoder.

    The encoder has three stages (3x3 convolution, CapsuleLayer, 2x2 max
    pooling), followed by a two-convolution bottleneck and three decoder
    stages (transposed convolution, concatenation with the matching encoder
    output, two 3x3 convolutions). A final 1x1 sigmoid convolution produces a
    single-channel output.

    Args:
        input_size: Shape of one input sample, passed to ``Input``.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    kl = tf.keras.layers

    def conv3x3(filters):
        # Fresh 3x3 same-padded ReLU convolution with the given filter count.
        return kl.Conv2D(filters, 3, activation='relu', padding='same')

    inputs = kl.Input(input_size)

    # Contracting path: (conv filters, num_capsules, dim_capsule) per stage.
    encoder_stages = ((64, 8, 16), (128, 16, 32), (256, 32, 64))
    skips = []
    x = inputs
    for filters, n_caps, caps_dim in encoder_stages:
        x = conv3x3(filters)(x)
        x = CapsuleLayer(num_capsules=n_caps, dim_capsule=caps_dim)(x)
        skips.append(x)  # kept for the skip connection of the matching decoder stage
        x = kl.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    for _ in range(2):
        x = conv3x3(512)(x)

    # Expansive path: deepest skip connection is consumed first.
    for filters, skip in zip((256, 128, 64), reversed(skips)):
        x = kl.Conv2DTranspose(filters, 2, strides=(2, 2), padding='same')(x)
        x = kl.concatenate([x, skip])
        x = conv3x3(filters)(x)
        x = conv3x3(filters)(x)

    outputs = kl.Conv2D(1, 1, activation='sigmoid')(x)

    return tf.keras.Model(inputs=[inputs], outputs=[outputs])

# Stop training once val_loss has not improved for 3 epochs and roll the
# model back to its best weights
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# 3-fold cross-validation with a fixed shuffle seed
kf = KFold(n_splits=3, shuffle=True, random_state=42)

# One list per metric; each fold appends a single value
accuracy_scores, recall_scores, precision_scores, f1_scores, specificity_scores = (
    [], [], [], [], []
)

def calculate_metrics(y_true, y_pred):
    """Compute pixel-wise binary classification metrics.

    Args:
        y_true: Array of ground-truth labels (0 or 1), any shape.
        y_pred: Array of predicted labels (0 or 1), same number of elements.

    Returns:
        Tuple ``(accuracy, recall, precision, f1, specificity)``.
    """
    y_true_flat = y_true.ravel()
    y_pred_flat = y_pred.ravel()

    accuracy = accuracy_score(y_true_flat, y_pred_flat)
    recall = recall_score(y_true_flat, y_pred_flat)
    precision = precision_score(y_true_flat, y_pred_flat)
    f1 = f1_score(y_true_flat, y_pred_flat)

    # labels=[0, 1] pins the matrix to 2x2 even when one class is absent from
    # both arrays; otherwise the four-way unpacking below would fail.
    tn, fp, fn, tp = confusion_matrix(y_true_flat, y_pred_flat, labels=[0, 1]).ravel()
    negatives = tn + fp
    # No true negatives or false positives means specificity is undefined;
    # report 0.0 instead of dividing by zero.
    specificity = tn / negatives if negatives else 0.0

    return accuracy, recall, precision, f1, specificity


# Loop through each fold in the 3-fold cross-validation
for fold, (train_index, test_index) in enumerate(kf.split(images), start=1):
    # Drop Keras global state left over from the previous fold's model so
    # memory does not grow with every fold.
    tf.keras.backend.clear_session()

    X_train, X_test = images[train_index], images[test_index]
    y_train, y_test = masks[train_index], masks[test_index]

    # Create the model for each fold
    model = unet_capsule_model()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    # Train the model with early stopping.
    # NOTE(review): the held-out fold is also passed as validation data, so it
    # drives early stopping; the metrics below are therefore not measured on
    # data that was unseen during model selection. Consider a separate
    # validation split taken from the training fold.
    history = model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=50, batch_size=8, callbacks=[early_stopping], verbose=1)

    # Predict on the test set and threshold the sigmoid outputs
    y_pred = model.predict(X_test)
    y_pred = (y_pred > 0.5).astype(np.uint8)  # Ensure binary predictions

    y_test_binary = (y_test > 0.5).astype(np.uint8)  # Convert y_test to binary for comparison

    # Calculate metrics for this fold
    accuracy, recall, precision, f1, specificity = calculate_metrics(y_test_binary, y_pred)

    # Append metrics to the respective arrays
    accuracy_scores.append(accuracy)
    recall_scores.append(recall)
    precision_scores.append(precision)
    f1_scores.append(f1)
    specificity_scores.append(specificity)

    # Plot confusion matrix for this fold (always 2x2 thanks to labels=[0, 1]).
    cm = confusion_matrix(y_test_binary.ravel(), y_pred.ravel(), labels=[0, 1])
    # Columns are predicted class 0 then 1; rows are flipped so that true
    # class 1 comes first, matching y_labels. Class 1 is a mask pixel here.
    x_labels = ['Normal', 'Abnormal']
    y_labels = ['Abnormal', 'Normal']
    cm_reversed = cm[::-1]
    fig = ff.create_annotated_heatmap(z=cm_reversed, x=x_labels, y=y_labels, colorscale='Blues')
    fig.update_layout(
        title=f'Confusion Matrix, Fold {fold}',
        xaxis=dict(title='Predicted labels', tickfont=dict(size=10)),
        yaxis=dict(title='True labels', tickfont=dict(size=10)),
        width=400,
        height=300,
        margin=dict(l=50, r=50, t=130, b=50)
    )
    fig.show()

# Mean of every metric over the cross-validation folds
avg_accuracy, avg_recall, avg_precision, avg_f1, avg_specificity = (
    np.mean(fold_values)
    for fold_values in (
        accuracy_scores,
        recall_scores,
        precision_scores,
        f1_scores,
        specificity_scores,
    )
)

# Report each average with four decimal places
for label, value in (
    ('Average Accuracy', avg_accuracy),
    ('Average Recall (Sensitivity)', avg_recall),
    ('Average Precision', avg_precision),
    ('Average F1 Score', avg_f1),
    ('Average Specificity', avg_specificity),
):
    print(f'{label}: {value:.4f}')
