# U-Net for Image Segmentation

* The Montgomery datast
* With CV and embedded Early Stopping to avoid overfitting.

In [1]:
import os
import numpy as np
import cv2
import tensorflow as tf
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import matplotlib.pyplot as plt

# Directories for Montgomery dataset images and masks
image_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Montgomery_Dataset\CXR_png'
left_mask_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Montgomery_Dataset\Masks_png\leftMask'
right_mask_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Montgomery_Dataset\Masks_png\rightMask'

img_size = (256, 256)

def load_images_and_masks(image_dir, left_mask_dir, right_mask_dir, img_size):
    images = []
    masks = []
    
    for img_name in os.listdir(image_dir):
        # No need to append .png since the filenames already have it
        img_path = os.path.join(image_dir, img_name)
        left_mask_path = os.path.join(left_mask_dir, img_name)
        right_mask_path = os.path.join(right_mask_dir, img_name)
        
        # Load image
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        
        # Ensure the image is loaded correctly
        if img is None:
            print(f"Image not found: {img_path}")
            continue
        
        img = cv2.resize(img, img_size)
        img = img / 255.0  # Normalize image to range 0-1
        
        # Load left and right lung masks
        left_mask = cv2.imread(left_mask_path, cv2.IMREAD_GRAYSCALE)
        right_mask = cv2.imread(right_mask_path, cv2.IMREAD_GRAYSCALE)
        
        # Ensure both masks are loaded correctly
        if left_mask is None or right_mask is None:
            print(f"Mask not found: {left_mask_path} or {right_mask_path}")
            continue
        
        left_mask = cv2.resize(left_mask, img_size)
        right_mask = cv2.resize(right_mask, img_size)
        
        # Combine left and right masks into a single mask
        combined_mask = np.maximum(left_mask, right_mask)
        combined_mask = combined_mask / 255.0  # Normalize mask to range 0-1
        
        # Append the image and the mask
        images.append(np.expand_dims(img, axis=-1))  # Add channel dimension to the image
        masks.append(np.expand_dims(combined_mask, axis=-1))  # Add channel dimension to the mask
    
    return np.array(images), np.array(masks)

# Load the images and combined masks
images, masks = load_images_and_masks(image_dir, left_mask_dir, right_mask_dir, img_size)

# Ensure data was loaded correctly
if len(images) == 0 or len(masks) == 0:
    raise ValueError("No images or masks were loaded. Check the dataset paths and file names.")

# U-Net Architecture
def unet_model(input_size=(256, 256, 1)):
    inputs = tf.keras.layers.Input(input_size)
    
    # Contracting Path
    c1 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    c1 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(c1)
    p1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
    
    c2 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(p1)
    c2 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(c2)
    p2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c2)
    
    c3 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(p2)
    c3 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(c3)
    p3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c3)
    
    # Bottleneck
    b = tf.keras.layers.Conv2D(512, 3, activation='relu', padding='same')(p3)
    b = tf.keras.layers.Conv2D(512, 3, activation='relu', padding='same')(b)
    
    # Expansive Path
    u1 = tf.keras.layers.Conv2DTranspose(256, 2, strides=(2, 2), padding='same')(b)
    u1 = tf.keras.layers.concatenate([u1, c3])
    c4 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(u1)
    c4 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(c4)
    
    u2 = tf.keras.layers.Conv2DTranspose(128, 2, strides=(2, 2), padding='same')(c4)
    u2 = tf.keras.layers.concatenate([u2, c2])
    c5 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(u2)
    c5 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(c5)
    
    u3 = tf.keras.layers.Conv2DTranspose(64, 2, strides=(2, 2), padding='same')(c5)
    u3 = tf.keras.layers.concatenate([u3, c1])
    c6 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(u3)
    c6 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(c6)
    
    outputs = tf.keras.layers.Conv2D(1, 1, activation='sigmoid')(c6)
    
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
    return model

# Early stopping callback to avoid overfitting
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)

# Cross-validation setup
kf = KFold(n_splits=10, shuffle=True, random_state=42)

# Arrays to store performance metrics for each fold
accuracy_scores = []
recall_scores = []
precision_scores = []
f1_scores = []
specificity_scores = []

# Loop through each fold in the 10-fold cross-validation
for train_index, test_index in kf.split(images):
    X_train, X_test = images[train_index], images[test_index]
    y_train, y_test = masks[train_index], masks[test_index]
    
    # Create the model for each fold
    model = unet_model()
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    
    # Train the model with early stopping
    history = model.fit(X_train, y_train, validation_data=(X_test, y_test), epochs=50, batch_size=8, callbacks=[early_stopping], verbose=1)
    
    # Predict on the test set
    y_pred = model.predict(X_test)
    y_pred = (y_pred > 0.5).astype(np.uint8)  # Ensure binary predictions
    
    y_test_binary = (y_test > 0.5).astype(np.uint8)  # Convert y_test to binary for comparison

    # Calculate performance metrics for this fold
    def calculate_metrics(y_true, y_pred):
        y_true_flat = y_true.flatten()
        y_pred_flat = y_pred.flatten()
        
        accuracy = accuracy_score(y_true_flat, y_pred_flat)
        recall = recall_score(y_true_flat, y_pred_flat)
        precision = precision_score(y_true_flat, y_pred_flat)
        f1 = f1_score(y_true_flat, y_pred_flat)
        
        tn, fp, fn, tp = confusion_matrix(y_true_flat, y_pred_flat).ravel()
        specificity = tn / (tn + fp)
        
        return accuracy, recall, precision, f1, specificity

    # Calculate metrics for this fold
    accuracy, recall, precision, f1, specificity = calculate_metrics(y_test_binary, y_pred)
    
    # Append metrics to the respective arrays
    accuracy_scores.append(accuracy)
    recall_scores.append(recall)
    precision_scores.append(precision)
    f1_scores.append(f1)
    specificity_scores.append(specificity)

# Calculate the average performance across all 10 folds
avg_accuracy = np.mean(accuracy_scores)
avg_recall = np.mean(recall_scores)
avg_precision = np.mean(precision_scores)
avg_f1 = np.mean(f1_scores)
avg_specificity = np.mean(specificity_scores)

# Print the average performance metrics
print(f'Average Accuracy: {avg_accuracy:.4f}')
print(f'Average Recall (Sensitivity): {avg_recall:.4f}')
print(f'Average Precision: {avg_precision:.4f}')
print(f'Average F1 Score: {avg_f1:.4f}')
print(f'Average Specificity: {avg_specificity:.4f}')


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m308s[0m 18s/step - accuracy: 0.6711 - loss: 0.6987 - val_accuracy: 0.7376 - val_loss: 0.5244
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m267s[0m 17s/step - accuracy: 0.7505 - loss: 0.4633 - val_accuracy: 0.7376 - val_loss: 0.3808
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m264s[0m 16s/step - accuracy: 0.7381 - loss: 0.3631 - val_accuracy: 0.7376 - val_loss: 0.3461
Epoch 4/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m260s[0m 16s/step - accuracy: 0.7978 - loss: 0.3292 - val_accuracy: 0.9075 - val_loss: 0.3041
Epoch 5/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m258s[0m 16s/step - accuracy: 0.9273 - loss: 0.2312 - val_accuracy: 0.8824 - val_loss: 0.3208
Epoch 6/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m254s[0m 16s/step - accuracy: 0.9240 - loss: 0.2036 - val_accuracy: 0.9441 - val_loss: 0.1476
Epoch 7/50
[1m16/16[0m [3

  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m233s[0m 14s/step - accuracy: 0.6579 - loss: 0.7182 - val_accuracy: 0.7323 - val_loss: 0.5278
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m228s[0m 14s/step - accuracy: 0.7458 - loss: 0.5015 - val_accuracy: 0.7323 - val_loss: 0.3937
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m230s[0m 14s/step - accuracy: 0.7465 - loss: 0.3685 - val_accuracy: 0.7323 - val_loss: 0.3394
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m236s[0m 14s/step - accuracy: 0.6623 - loss: 0.6168 - val_accuracy: 0.7427 - val_loss: 0.4296
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m227s[0m 14s/step - accuracy: 0.7379 - loss: 0.4000 - val_accuracy: 0.9194 - val_loss: 0.3054
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m226s[0m 14s/step - accuracy: 0.9049 - loss: 0.2703 - val_accuracy: 0.9274 - val_loss: 0.1828
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m235s[0m 14s/step - accuracy: 0.6685 - loss: 0.5905 - val_accuracy: 0.7446 - val_loss: 0.3948
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 14s/step - accuracy: 0.7428 - loss: 0.4079 - val_accuracy: 0.7537 - val_loss: 0.3534
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m225s[0m 14s/step - accuracy: 0.8394 - loss: 0.3210 - val_accuracy: 0.9257 - val_loss: 0.2554
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m234s[0m 14s/step - accuracy: 0.6779 - loss: 0.6063 - val_accuracy: 0.7673 - val_loss: 0.3891
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m225s[0m 14s/step - accuracy: 0.7490 - loss: 0.3841 - val_accuracy: 0.8656 - val_loss: 0.3036
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 14s/step - accuracy: 0.8930 - loss: 0.2952 - val_accuracy: 0.9375 - val_loss: 0.2397
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 5s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m230s[0m 14s/step - accuracy: 0.6516 - loss: 0.6975 - val_accuracy: 0.7481 - val_loss: 0.5048
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m225s[0m 14s/step - accuracy: 0.7527 - loss: 0.4617 - val_accuracy: 0.7481 - val_loss: 0.3710
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 14s/step - accuracy: 0.7462 - loss: 0.3709 - val_accuracy: 0.8621 - val_loss: 0.3216
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m235s[0m 14s/step - accuracy: 0.7232 - loss: 0.7294 - val_accuracy: 0.7271 - val_loss: 0.5095
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m223s[0m 14s/step - accuracy: 0.7476 - loss: 0.4572 - val_accuracy: 0.7271 - val_loss: 0.3997
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m224s[0m 14s/step - accuracy: 0.7431 - loss: 0.3701 - val_accuracy: 0.7821 - val_loss: 0.3310
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m245s[0m 15s/step - accuracy: 0.6625 - loss: 0.6589 - val_accuracy: 0.7696 - val_loss: 0.4760
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m290s[0m 18s/step - accuracy: 0.7354 - loss: 0.4787 - val_accuracy: 0.7696 - val_loss: 0.3208
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m243s[0m 15s/step - accuracy: 0.7423 - loss: 0.3909 - val_accuracy: 0.7696 - val_loss: 0.3115
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4s/step


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 1/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m281s[0m 17s/step - accuracy: 0.7172 - loss: 0.6525 - val_accuracy: 0.7401 - val_loss: 0.5271
Epoch 2/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m251s[0m 16s/step - accuracy: 0.7409 - loss: 0.4810 - val_accuracy: 0.7401 - val_loss: 0.3935
Epoch 3/50
[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m246s[0m 15s/step - accuracy: 0.7454 - loss: 0.3689 - val_accuracy: 0.8830 - val_loss: 0.3180
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3s/step


  _warn_prf(average, modifier, msg_start, len(result))


Average Accuracy: 0.7683
Average Recall (Sensitivity): 0.0931
Average Precision: 0.0968
Average F1 Score: 0.0949
Average Specificity: 0.9989
