## U-Net architecture

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras import backend as K
import cv2
import matplotlib.pyplot as plt
from datetime import datetime  # Import datetime for the custom callback
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import seaborn as sns
from tensorflow.keras.callbacks import EarlyStopping  # Import EarlyStopping

# Directory paths
# Absolute, machine-specific locations of the image folder and mask folder for
# each split (train / test / valid).
# NOTE(review): the image folders live under "GitHub\Datasets\ImageSegmentation"
# while the mask folders live under "GitHub\ImageSegmentation\Datasets" — the two
# path segments are transposed. Confirm both trees exist and this is intended.
train_img_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\train"
train_mask_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\ImageSegmentation\Datasets\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\train\train_mask"
test_img_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\test"
test_mask_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\ImageSegmentation\Datasets\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\test\test_mask"
valid_img_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\valid"
valid_mask_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\ImageSegmentation\Datasets\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\valid\valid_mask"

# U-Net model
def unet_model(input_size=(256, 256, 1)):
    """Build an (uncompiled) U-Net for single-channel mask prediction.

    The network has four encoder stages (64, 128, 256, 512 filters), a
    1024-filter bottleneck, and four mirrored decoder stages that upsample with
    a stride-2 transposed convolution and concatenate the matching encoder
    feature map. The final 1x1 convolution has one sigmoid output channel.

    Args:
        input_size: Shape of one input sample, passed to ``layers.Input``.
            NOTE(review): presumably height and width must be divisible by 16
            so the skip connections line up after four 2x2 poolings — confirm.

    Returns:
        A ``models.Model`` mapping the input tensor to the sigmoid mask.
    """

    def double_conv(tensor, filters):
        # Two stacked 3x3 ReLU convolutions with 'same' padding.
        for _ in range(2):
            tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = layers.Input(input_size)

    # Downsample: keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        x = double_conv(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D((2, 2))(x)

    # Bottleneck
    x = double_conv(x, 1024)

    # Upsample: deepest skip first. Concatenation uses the last axis, which is
    # the channel axis of these 4-D feature maps.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = double_conv(x, filters)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

    return models.Model(inputs=[inputs], outputs=[outputs])

# Load images and masks
def load_images_and_masks(img_dir, mask_dir, img_size=(256, 256)):
    """Load every image in ``img_dir`` that has a matching mask in ``mask_dir``.

    The mask for an image file ``<name>`` is looked up as ``<name>_mask.png``
    (the suffix is appended to the full file name, extension included). Images
    and masks are both read as grayscale, resized to ``img_size`` and divided
    by 255.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_mask.png`` files.
        img_size: Target size forwarded to ``load_img`` as ``target_size``.

    Returns:
        ``(images, masks)`` as two NumPy arrays with one entry per image/mask
        pair, in sorted file-name order. Both are empty when nothing matched.
    """
    images = []
    masks = []

    # os.listdir returns entries in arbitrary order; sort so that the sample
    # order (and therefore "the first N samples") is reproducible across runs.
    for img_file in sorted(os.listdir(img_dir)):
        img_path = os.path.join(img_dir, img_file)

        # Skip sub-directories (e.g. a mask folder nested inside the image
        # folder) instead of reporting them as images with a missing mask.
        if not os.path.isfile(img_path):
            continue

        # Adjust mask filename to include "_mask.png"
        mask_path = os.path.join(mask_dir, img_file + "_mask.png")
        if not os.path.exists(mask_path):
            print(f"Mask not found for {img_file}, skipping this image.")
            continue

        # Load image and mask
        img = load_img(img_path, color_mode='grayscale', target_size=img_size)
        mask = load_img(mask_path, color_mode='grayscale', target_size=img_size)
        images.append(img_to_array(img) / 255.0)
        masks.append(img_to_array(mask) / 255.0)

    return np.array(images), np.array(masks)

# Load training and validation data
# Each call returns an (images, masks) pair of NumPy arrays; images without a
# matching mask file are skipped by the loader.
X_train, y_train = load_images_and_masks(train_img_dir, train_mask_dir)
X_valid, y_valid = load_images_and_masks(valid_img_dir, valid_mask_dir)

# Define custom metrics
def custom_precision(y_true, y_pred):
    """Precision over the whole batch: TP / (predicted positives + epsilon).

    Predictions are rounded to the nearest integer before counting; the
    epsilon term keeps the ratio finite when nothing is predicted positive.
    """
    hard_pred = K.round(y_pred)
    tp = K.sum(K.round(y_true * hard_pred))
    pred_pos = K.sum(hard_pred)
    return tp / (pred_pos + K.epsilon())

def custom_recall(y_true, y_pred):
    """Recall over the whole batch: TP / (sum of y_true + epsilon).

    Predictions are rounded to the nearest integer before counting; the
    denominator sums ``y_true`` as given (it is not rounded).
    """
    hard_pred = K.round(y_pred)
    tp = K.sum(K.round(y_true * hard_pred))
    actual_pos = K.sum(y_true)
    return tp / (actual_pos + K.epsilon())

def custom_specificity(y_true, y_pred):
    """Specificity over the whole batch: TN / (sum of (1 - y_true) + epsilon).

    Predictions are rounded to the nearest integer; a true negative is counted
    where both the inverted label and the inverted rounded prediction are set.
    """
    hard_pred = K.round(y_pred)
    neg_true = 1 - y_true
    tn = K.sum(K.round(neg_true * (1 - hard_pred)))
    actual_neg = K.sum(neg_true)
    return tn / (actual_neg + K.epsilon())

def custom_f1(y_true, y_pred):
    """F1 score: harmonic mean of ``custom_precision`` and ``custom_recall``.

    Epsilon in the denominator avoids 0/0 when both components are zero.
    """
    p = custom_precision(y_true, y_pred)
    r = custom_recall(y_true, y_pred)
    numerator = 2 * (p * r)
    return numerator / (p + r + K.epsilon())

# Define Focal Loss
def focal_loss_fixed(y_true, y_pred):
    """Binary focal loss with fixed gamma=2.0 and alpha=0.25, averaged over all elements.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - p)^gamma`` where the label is 1 and by
    ``(1 - alpha) * p^gamma`` where the label is 0, so confidently correct
    predictions contribute little to the loss.
    """
    gamma, alpha = 2.0, 0.25
    eps = K.epsilon()

    # Keep probabilities strictly inside (0, 1) so the logs stay finite.
    p = K.clip(y_pred, eps, 1. - eps)

    bce = -y_true * K.log(p) - (1 - y_true) * K.log(1 - p)
    pos_weight = alpha * y_true * K.pow((1 - p), gamma)
    neg_weight = (1 - alpha) * (1 - y_true) * K.pow(p, gamma)

    return K.mean((pos_weight + neg_weight) * bce)

# Compile the model
# Adam optimizer with the focal loss defined above; besides Keras' built-in
# 'accuracy', the four custom batch-level metrics are tracked during training.
model = unet_model()
model.compile(optimizer='adam', loss=focal_loss_fixed, metrics=['accuracy', custom_precision, custom_recall, custom_specificity, custom_f1])

# Batch size for training
batch_size = 16

# Calculate total batches for training
# (ceil, so a final partial batch is counted as a batch)
total_batches = int(np.ceil(len(X_train) / batch_size))

# Custom callback to print more metrics at each batch
class MetricsCallback(tf.keras.callbacks.Callback):
    """Print the training metrics found in ``logs`` after every batch.

    Intended to replace the default progress output when ``fit`` is called
    with ``verbose=0``.

    Args:
        total_batches: Number of batches per epoch, used only for the
            "Batch i/N" label.
    """

    def __init__(self, total_batches):
        super().__init__()
        self.batch_counter = 1  # 1-based index of the batch within the current epoch
        self.total_batches = total_batches  # Total number of batches per epoch
        self.current_epoch = 1  # 1-based index of the current epoch

    def on_epoch_begin(self, epoch, logs=None):
        """Record the epoch number, restart the batch count and print the header."""
        self.current_epoch = epoch + 1  # Epochs are zero-indexed
        # Reset here as well as in on_epoch_end: if a previous fit() call was
        # interrupted mid-epoch, on_epoch_end never ran and the counter would
        # otherwise carry a stale value into the next run.
        self.batch_counter = 1
        print(f"\nEpoch {self.current_epoch}/{self.params['epochs']}")

    def on_batch_end(self, batch, logs=None):
        """Print a timestamped line with the metrics reported for this batch."""
        logs = logs or {}
        # Metrics missing from logs are printed as 0 rather than raising.
        accuracy = logs.get('accuracy', 0)
        loss = logs.get('loss', 0)
        precision = logs.get('custom_precision', 0)
        recall = logs.get('custom_recall', 0)
        f1 = logs.get('custom_f1', 0)
        specificity = logs.get('custom_specificity', 0)

        # Time formatting for current step
        current_time = datetime.now().strftime("%H:%M:%S")

        # Print the metrics with proper formatting
        print(f"Batch {self.batch_counter}/{self.total_batches} ━━━━━━━━━━━━━━━━━━━━ {current_time}")
        print(f"Accuracy: {accuracy:.4f} - Precision: {precision:.4f} - Recall: {recall:.4f} - Specificity: {specificity:.4f} - F1: {f1:.4f} - Loss: {loss:.4f}\n")

        # Increment batch counter
        self.batch_counter += 1

    def on_epoch_end(self, epoch, logs=None):
        """Reset the batch counter at the end of each epoch."""
        self.batch_counter = 1

# Define early stopping
early_stopping = EarlyStopping(
    monitor='val_loss',        # Metric to monitor
    patience=5,                # Number of epochs with no improvement after which training will be stopped
    restore_best_weights=True  # Restore model weights from the epoch with the best value of the monitored metric
)

# Initialize the custom callback
metrics_callback = MetricsCallback(total_batches=total_batches)

# Train the model with early stopping
# verbose=0 silences Keras' own progress output, so the only per-batch output
# comes from metrics_callback.
history = model.fit(
    X_train, y_train,
    epochs=20,               # Number of epochs
    batch_size=batch_size,   # Batch size of 16
    validation_data=(X_valid, y_valid),
    callbacks=[metrics_callback, early_stopping],  # Add early_stopping to the callbacks list
    verbose=0                # Suppress default Keras logging
)

# Save the model
# NOTE(review): the model was compiled with a custom loss and custom metrics;
# reloading this file presumably needs them passed via custom_objects (or
# compile=False) — confirm before relying on the saved file.
model.save('dental_xray_unet_model.h5')

# Load test data
X_test, y_test = load_images_and_masks(test_img_dir, test_mask_dir)

# ---------------------------------------------------------------------------
# Evaluation: pixel-level confusion matrix and summary metrics for each split
# ---------------------------------------------------------------------------
def _flat_labels(mask_array, threshold=0.5):
    """Threshold a mask array to {0, 1} uint8 labels and flatten it to 1-D.

    scikit-learn's classification metrics need discrete labels; integer labels
    also avoid a per-element label conversion on very large pixel arrays.
    NOTE(review): assumes the masks are (close to) 0/1 after the /255 scaling
    done by the loader — confirm against the mask files.
    """
    return (mask_array > threshold).astype(np.uint8).ravel()


def _plot_confusion_matrix(matrix, split_name):
    """Show a confusion matrix as an annotated heatmap titled for ``split_name``."""
    plt.figure(figsize=(6, 4))
    sns.heatmap(matrix, annot=True, fmt="d", cmap="Blues")
    plt.title(f"Confusion Matrix for {split_name}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()


def _report_split(heading, true_flat, pred_flat, matrix):
    """Compute and print the summary metrics of one split.

    Args:
        heading: Split label used in the "<heading> Set Results:" line.
        true_flat: 1-D array of ground-truth labels (0/1).
        pred_flat: 1-D array of predicted labels (0/1).
        matrix: The 2x2 confusion matrix already computed for this split; it is
            reused so the matrix is not rebuilt from the pixel arrays.

    Returns:
        ``(accuracy, recall, precision, f1, specificity, tn, fp, fn, tp)``.
    """
    accuracy = accuracy_score(true_flat, pred_flat)
    recall = recall_score(true_flat, pred_flat)
    precision = precision_score(true_flat, pred_flat)
    f1 = f1_score(true_flat, pred_flat)
    tn, fp, fn, tp = matrix.ravel()
    specificity = tn / (tn + fp)

    print(f'{heading} Set Results:')
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Recall (Sensitivity): {recall:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(f'Specificity: {specificity:.4f}')

    return accuracy, recall, precision, f1, specificity, tn, fp, fn, tp


# Evaluate on the training set
y_train_pred = model.predict(X_train, batch_size=batch_size)
y_train_pred_bin = (y_train_pred > 0.5).astype(np.uint8)
_train_true, _train_pred = _flat_labels(y_train), y_train_pred_bin.ravel()

# Confusion Matrix for training
# labels=[0, 1] forces a 2x2 matrix even if one class is absent from the
# split, so the later 4-way unpack of the matrix cannot fail.
conf_matrix_train = confusion_matrix(_train_true, _train_pred, labels=[0, 1])
_plot_confusion_matrix(conf_matrix_train, "Train")

# Evaluate on the validation set
y_valid_pred = model.predict(X_valid, batch_size=batch_size)
y_valid_pred_bin = (y_valid_pred > 0.5).astype(np.uint8)
_valid_true, _valid_pred = _flat_labels(y_valid), y_valid_pred_bin.ravel()

# Confusion Matrix for validation
conf_matrix_valid = confusion_matrix(_valid_true, _valid_pred, labels=[0, 1])
_plot_confusion_matrix(conf_matrix_valid, "Validation")

# Evaluate on the test set
y_test_pred = model.predict(X_test, batch_size=batch_size)
y_test_pred_bin = (y_test_pred > 0.5).astype(np.uint8)
_test_true, _test_pred = _flat_labels(y_test), y_test_pred_bin.ravel()

# Confusion Matrix for testing
conf_matrix_test = confusion_matrix(_test_true, _test_pred, labels=[0, 1])
_plot_confusion_matrix(conf_matrix_test, "Test")

# Performance report for training set
(train_accuracy, train_recall, train_precision, train_f1, train_specificity,
 train_tn, train_fp, train_fn, train_tp) = _report_split(
    'Training', _train_true, _train_pred, conf_matrix_train)

# Performance report for validation set
(valid_accuracy, valid_recall, valid_precision, valid_f1, valid_specificity,
 valid_tn, valid_fp, valid_fn, valid_tp) = _report_split(
    'Validation', _valid_true, _valid_pred, conf_matrix_valid)

# Performance report for testing set
(test_accuracy, test_recall, test_precision, test_f1, test_specificity,
 test_tn, test_fp, test_fn, test_tp) = _report_split(
    'Testing', _test_true, _test_pred, conf_matrix_test)

# Visualization: Show input image, true mask, and predicted mask for a few samples
def visualize_predictions(images, true_masks, pred_masks, title, num_samples=3):
    """Plot input image, ground-truth mask and predicted mask side by side.

    One figure with three panels is shown per sample, for the first
    ``num_samples`` samples.

    Args:
        images: Indexable collection of input images (each must support
            ``.squeeze()``).
        true_masks: Ground-truth masks, aligned with ``images``.
        pred_masks: Predicted masks, aligned with ``images``.
        title: Figure-level title applied to every figure.
        num_samples: Maximum number of samples to show (default 3). Fewer are
            shown when the inputs contain fewer samples; nothing is drawn when
            they are empty.
    """
    # Never index past the shortest input: a small split would otherwise raise
    # IndexError part-way through plotting.
    available = min(len(images), len(true_masks), len(pred_masks))

    for i in range(min(num_samples, available)):
        plt.figure(figsize=(12, 4))

        panels = (
            (images[i], 'Original Image'),
            (true_masks[i], 'Ground Truth Mask'),
            (pred_masks[i], 'Predicted Mask'),
        )
        for column, (data, caption) in enumerate(panels, start=1):
            plt.subplot(1, 3, column)
            plt.imshow(data.squeeze(), cmap='gray')
            plt.title(caption)

        plt.suptitle(title)
        plt.show()

# Visualize predictions for training set
# (the binarized predictions are shown, not the raw sigmoid outputs)
visualize_predictions(X_train, y_train, y_train_pred_bin, "Train Set Predictions")

# Visualize predictions for validation set
visualize_predictions(X_valid, y_valid, y_valid_pred_bin, "Validation Set Predictions")

# Visualize predictions for testing set
visualize_predictions(X_test, y_test, y_test_pred_bin, "Test Set Predictions")


Mask not found for train_mask, skipping this image.
Mask not found for valid_mask, skipping this image.

Epoch 1/20
Batch 1/299 ━━━━━━━━━━━━━━━━━━━━ 12:41:20
Accuracy: 0.1233 - Precision: 0.0415 - Recall: 0.9663 - Specificity: 0.0888 - F1: 0.0796 - Loss: 0.1293

Batch 2/299 ━━━━━━━━━━━━━━━━━━━━ 12:42:09
Accuracy: 0.5499 - Precision: 0.0208 - Recall: 0.4831 - Specificity: 0.5444 - F1: 0.0398 - Loss: 0.1185

Batch 3/299 ━━━━━━━━━━━━━━━━━━━━ 12:42:47
Accuracy: 0.6912 - Precision: 0.0138 - Recall: 0.3221 - Specificity: 0.6963 - F1: 0.0265 - Loss: 0.1015

Batch 4/299 ━━━━━━━━━━━━━━━━━━━━ 12:43:25
Accuracy: 0.7590 - Precision: 0.0104 - Recall: 0.2416 - Specificity: 0.7722 - F1: 0.0199 - Loss: 0.0846

Batch 5/299 ━━━━━━━━━━━━━━━━━━━━ 12:44:08
Accuracy: 0.8026 - Precision: 0.0083 - Recall: 0.1933 - Specificity: 0.8178 - F1: 0.0159 - Loss: 0.0728

Batch 6/299 ━━━━━━━━━━━━━━━━━━━━ 12:45:01
Accuracy: 0.8313 - Precision: 0.0069 - Recall: 0.1610 - Specificity: 0.8481 - F1: 0.0133 - Loss: 0.0651

Ba

## VGG16 architecture

In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras import backend as K
import cv2
import matplotlib.pyplot as plt
from datetime import datetime  # Import datetime for the custom callback
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import seaborn as sns
from tensorflow.keras.callbacks import EarlyStopping  # Import EarlyStopping
from tensorflow.keras.applications import VGG16  # Import VGG16 for the new model

# Directory paths
# Absolute, machine-specific locations of the image folder and mask folder for
# each split (train / test / valid).
# NOTE(review): the image folders live under "GitHub\Datasets\ImageSegmentation"
# while the mask folders live under "GitHub\ImageSegmentation\Datasets" — the two
# path segments are transposed. Confirm both trees exist and this is intended.
train_img_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\train"
train_mask_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\ImageSegmentation\Datasets\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\train\train_mask"
test_img_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\test"
test_mask_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\ImageSegmentation\Datasets\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\test\test_mask"
valid_img_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\valid"
valid_mask_dir = r"C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\ImageSegmentation\Datasets\Dental_XRay_Computacional_Vision_Segmentation\Dental X_Ray\valid\valid_mask"

# U-Net model replaced with VGG16-based model
def vgg16_unet_model(input_size=(256, 256, 3)):
    """Build an (uncompiled) U-Net-style decoder on top of a frozen VGG16 encoder.

    The encoder is VGG16 without its classification head, loaded with ImageNet
    weights and with every layer set non-trainable. The outputs of its five
    pooling layers are used: the deepest one feeds the decoder, the other four
    serve as skip connections. A final transposed convolution without a skip
    connection performs one more 2x upsampling before the 1-channel sigmoid
    output.

    Args:
        input_size: Input shape forwarded to ``VGG16`` as ``input_shape``.

    Returns:
        A ``models.Model`` from the VGG16 input to the sigmoid mask.

    NOTE(review): this file feeds images scaled to [0, 1]; the ImageNet weights
    presumably expect VGG16's own ``preprocess_input`` — confirm.
    """

    def double_conv(tensor, filters):
        # Two stacked 3x3 ReLU convolutions with 'same' padding.
        for _ in range(2):
            tensor = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    # Load VGG16 as the encoder with pre-trained ImageNet weights
    encoder = VGG16(include_top=False, weights='imagenet', input_shape=input_size)

    # Freeze VGG16 layers to prevent them from being trained
    for encoder_layer in encoder.layers:
        encoder_layer.trainable = False

    # Pooling outputs block1_pool .. block5_pool; each is half the resolution of
    # the previous one (128x128 down to 8x8 for the default 256x256 input).
    pooled = [encoder.get_layer(f'block{idx}_pool').output for idx in range(1, 6)]
    skips, x = pooled[:4], pooled[4]

    # Decoder: upsample 2x, merge with the encoder output of matching size
    # (deepest skip first), then refine with two convolutions.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = layers.concatenate([x, skip])
        x = double_conv(x, filters)

    # Final upsampling to reach original image size
    x = layers.Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same')(x)
    x = double_conv(x, 32)

    outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(x)

    return models.Model(inputs=encoder.input, outputs=outputs)

# Load images and masks
def load_images_and_masks(img_dir, mask_dir, img_size=(256, 256)):
    """Load every image in ``img_dir`` that has a matching mask in ``mask_dir``.

    The mask for an image file ``<name>`` is looked up as ``<name>_mask.png``
    (the suffix is appended to the full file name, extension included). Images
    are read as RGB and masks as grayscale; both are resized to ``img_size``
    and divided by 255.

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_mask.png`` files.
        img_size: Target size forwarded to ``load_img`` as ``target_size``.

    Returns:
        ``(images, masks)`` as two NumPy arrays with one entry per image/mask
        pair, in sorted file-name order. Both are empty when nothing matched.
    """
    images = []
    masks = []

    # os.listdir returns entries in arbitrary order; sort so that the sample
    # order (and therefore "the first N samples") is reproducible across runs.
    for img_file in sorted(os.listdir(img_dir)):
        img_path = os.path.join(img_dir, img_file)

        # Skip sub-directories (e.g. a mask folder nested inside the image
        # folder) instead of reporting them as images with a missing mask.
        if not os.path.isfile(img_path):
            continue

        # Adjust mask filename to include "_mask.png"
        mask_path = os.path.join(mask_dir, img_file + "_mask.png")
        if not os.path.exists(mask_path):
            print(f"Mask not found for {img_file}, skipping this image.")
            continue

        # Load image as RGB
        img = load_img(img_path, color_mode='rgb', target_size=img_size)
        # Load mask as grayscale
        mask = load_img(mask_path, color_mode='grayscale', target_size=img_size)
        images.append(img_to_array(img) / 255.0)
        masks.append(img_to_array(mask) / 255.0)

    return np.array(images), np.array(masks)

# Load training and validation data
# Each call returns an (images, masks) pair of NumPy arrays; images without a
# matching mask file are skipped by the loader.
X_train, y_train = load_images_and_masks(train_img_dir, train_mask_dir)
X_valid, y_valid = load_images_and_masks(valid_img_dir, valid_mask_dir)

# Define custom metrics
def custom_precision(y_true, y_pred):
    """Precision over the whole batch: TP / (predicted positives + epsilon).

    Predictions are rounded to the nearest integer before counting; the
    epsilon term keeps the ratio finite when nothing is predicted positive.
    """
    hard_pred = K.round(y_pred)
    tp = K.sum(K.round(y_true * hard_pred))
    pred_pos = K.sum(hard_pred)
    return tp / (pred_pos + K.epsilon())

def custom_recall(y_true, y_pred):
    """Recall over the whole batch: TP / (sum of y_true + epsilon).

    Predictions are rounded to the nearest integer before counting; the
    denominator sums ``y_true`` as given (it is not rounded).
    """
    hard_pred = K.round(y_pred)
    tp = K.sum(K.round(y_true * hard_pred))
    actual_pos = K.sum(y_true)
    return tp / (actual_pos + K.epsilon())

def custom_specificity(y_true, y_pred):
    """Specificity over the whole batch: TN / (sum of (1 - y_true) + epsilon).

    Predictions are rounded to the nearest integer; a true negative is counted
    where both the inverted label and the inverted rounded prediction are set.
    """
    hard_pred = K.round(y_pred)
    neg_true = 1 - y_true
    tn = K.sum(K.round(neg_true * (1 - hard_pred)))
    actual_neg = K.sum(neg_true)
    return tn / (actual_neg + K.epsilon())

def custom_f1(y_true, y_pred):
    """F1 score: harmonic mean of ``custom_precision`` and ``custom_recall``.

    Epsilon in the denominator avoids 0/0 when both components are zero.
    """
    p = custom_precision(y_true, y_pred)
    r = custom_recall(y_true, y_pred)
    numerator = 2 * (p * r)
    return numerator / (p + r + K.epsilon())

# Define Focal Loss
def focal_loss_fixed(y_true, y_pred):
    """Binary focal loss with fixed gamma=2.0 and alpha=0.25, averaged over all elements.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - p)^gamma`` where the label is 1 and by
    ``(1 - alpha) * p^gamma`` where the label is 0, so confidently correct
    predictions contribute little to the loss.
    """
    gamma, alpha = 2.0, 0.25
    eps = K.epsilon()

    # Keep probabilities strictly inside (0, 1) so the logs stay finite.
    p = K.clip(y_pred, eps, 1. - eps)

    bce = -y_true * K.log(p) - (1 - y_true) * K.log(1 - p)
    pos_weight = alpha * y_true * K.pow((1 - p), gamma)
    neg_weight = (1 - alpha) * (1 - y_true) * K.pow(p, gamma)

    return K.mean((pos_weight + neg_weight) * bce)

# Compile the model
# Adam optimizer with the focal loss defined above; besides Keras' built-in
# 'accuracy', the four custom batch-level metrics are tracked during training.
model = vgg16_unet_model()
model.compile(optimizer='adam', loss=focal_loss_fixed, metrics=['accuracy', custom_precision, custom_recall, custom_specificity, custom_f1])

# Batch size for training
batch_size = 16

# Calculate total batches for training
# (ceil, so a final partial batch is counted as a batch)
total_batches = int(np.ceil(len(X_train) / batch_size))

# Custom callback to print more metrics at each batch with epoch tracking
class MetricsCallback(tf.keras.callbacks.Callback):
    """Print the training metrics found in ``logs`` after every batch.

    Intended to replace the default progress output when ``fit`` is called
    with ``verbose=0``.

    Args:
        total_batches: Number of batches per epoch, used only for the
            "Batch i/N" label.
    """

    def __init__(self, total_batches):
        super().__init__()
        self.batch_counter = 1  # 1-based index of the batch within the current epoch
        self.total_batches = total_batches  # Total number of batches per epoch
        self.current_epoch = 1  # 1-based index of the current epoch

    def on_epoch_begin(self, epoch, logs=None):
        """Record the epoch number, restart the batch count and print the header."""
        self.current_epoch = epoch + 1  # Epochs are zero-indexed
        # Reset here as well as in on_epoch_end: if a previous fit() call was
        # interrupted mid-epoch, on_epoch_end never ran and the counter would
        # otherwise carry a stale value into the next run.
        self.batch_counter = 1
        print(f"\nEpoch {self.current_epoch}/{self.params['epochs']}")

    def on_batch_end(self, batch, logs=None):
        """Print a timestamped line with the metrics reported for this batch."""
        logs = logs or {}
        # Metrics missing from logs are printed as 0 rather than raising.
        accuracy = logs.get('accuracy', 0)
        loss = logs.get('loss', 0)
        precision = logs.get('custom_precision', 0)
        recall = logs.get('custom_recall', 0)
        f1 = logs.get('custom_f1', 0)
        specificity = logs.get('custom_specificity', 0)

        # Time formatting for current step
        current_time = datetime.now().strftime("%H:%M:%S")

        # Print the metrics with proper formatting
        print(f"Batch {self.batch_counter}/{self.total_batches} ━━━━━━━━━━━━━━━━━━━━ {current_time}")
        print(f"Accuracy: {accuracy:.4f} - Precision: {precision:.4f} - Recall: {recall:.4f} - Specificity: {specificity:.4f} - F1: {f1:.4f} - Loss: {loss:.4f}\n")

        # Increment batch counter
        self.batch_counter += 1

    def on_epoch_end(self, epoch, logs=None):
        """Reset the batch counter at the end of each epoch."""
        self.batch_counter = 1

# Define early stopping
early_stopping = EarlyStopping(
    monitor='val_loss',        # Metric to monitor
    patience=5,                # Number of epochs with no improvement after which training will be stopped
    restore_best_weights=True  # Restore model weights from the epoch with the best value of the monitored metric
)

# Initialize the custom callback
metrics_callback = MetricsCallback(total_batches=total_batches)

# Train the model with early stopping
# verbose=0 silences Keras' own progress output, so the only per-batch output
# comes from metrics_callback.
history = model.fit(
    X_train, y_train,
    epochs=20,               # Number of epochs
    batch_size=batch_size,   # Batch size of 16
    validation_data=(X_valid, y_valid),
    callbacks=[metrics_callback, early_stopping],  # Add early_stopping to the callbacks list
    verbose=0                # Suppress default Keras logging
)

# Save the model
# NOTE(review): the model was compiled with a custom loss and custom metrics;
# reloading this file presumably needs them passed via custom_objects (or
# compile=False) — confirm before relying on the saved file.
model.save('dental_xray_vgg16_unet_model.h5')

# Load test data
X_test, y_test = load_images_and_masks(test_img_dir, test_mask_dir)

# ---------------------------------------------------------------------------
# Evaluation: pixel-level confusion matrix and summary metrics for each split
# ---------------------------------------------------------------------------
def _flat_labels(mask_array, threshold=0.5):
    """Threshold a mask array to {0, 1} uint8 labels and flatten it to 1-D.

    scikit-learn's classification metrics need discrete labels; integer labels
    also avoid a per-element label conversion on very large pixel arrays.
    NOTE(review): assumes the masks are (close to) 0/1 after the /255 scaling
    done by the loader — confirm against the mask files.
    """
    return (mask_array > threshold).astype(np.uint8).ravel()


def _plot_confusion_matrix(matrix, split_name):
    """Show a confusion matrix as an annotated heatmap titled for ``split_name``."""
    plt.figure(figsize=(6, 4))
    sns.heatmap(matrix, annot=True, fmt="d", cmap="Blues")
    plt.title(f"Confusion Matrix for {split_name}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()


def _report_split(heading, true_flat, pred_flat, matrix):
    """Compute and print the summary metrics of one split.

    Args:
        heading: Split label used in the "<heading> Set Results:" line.
        true_flat: 1-D array of ground-truth labels (0/1).
        pred_flat: 1-D array of predicted labels (0/1).
        matrix: The 2x2 confusion matrix already computed for this split; it is
            reused so the matrix is not rebuilt from the pixel arrays.

    Returns:
        ``(accuracy, recall, precision, f1, specificity, tn, fp, fn, tp)``.
    """
    accuracy = accuracy_score(true_flat, pred_flat)
    recall = recall_score(true_flat, pred_flat)
    precision = precision_score(true_flat, pred_flat)
    f1 = f1_score(true_flat, pred_flat)
    tn, fp, fn, tp = matrix.ravel()
    specificity = tn / (tn + fp)

    print(f'{heading} Set Results:')
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Recall (Sensitivity): {recall:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'F1 Score: {f1:.4f}')
    print(f'Specificity: {specificity:.4f}')

    return accuracy, recall, precision, f1, specificity, tn, fp, fn, tp


# Evaluate on the training set
y_train_pred = model.predict(X_train, batch_size=batch_size)
y_train_pred_bin = (y_train_pred > 0.5).astype(np.uint8)
_train_true, _train_pred = _flat_labels(y_train), y_train_pred_bin.ravel()

# Confusion Matrix for training
# labels=[0, 1] forces a 2x2 matrix even if one class is absent from the
# split, so the later 4-way unpack of the matrix cannot fail.
conf_matrix_train = confusion_matrix(_train_true, _train_pred, labels=[0, 1])
_plot_confusion_matrix(conf_matrix_train, "Train")

# Evaluate on the validation set
y_valid_pred = model.predict(X_valid, batch_size=batch_size)
y_valid_pred_bin = (y_valid_pred > 0.5).astype(np.uint8)
_valid_true, _valid_pred = _flat_labels(y_valid), y_valid_pred_bin.ravel()

# Confusion Matrix for validation
conf_matrix_valid = confusion_matrix(_valid_true, _valid_pred, labels=[0, 1])
_plot_confusion_matrix(conf_matrix_valid, "Validation")

# Evaluate on the test set
y_test_pred = model.predict(X_test, batch_size=batch_size)
y_test_pred_bin = (y_test_pred > 0.5).astype(np.uint8)
_test_true, _test_pred = _flat_labels(y_test), y_test_pred_bin.ravel()

# Confusion Matrix for testing
conf_matrix_test = confusion_matrix(_test_true, _test_pred, labels=[0, 1])
_plot_confusion_matrix(conf_matrix_test, "Test")

# Performance report for training set
(train_accuracy, train_recall, train_precision, train_f1, train_specificity,
 train_tn, train_fp, train_fn, train_tp) = _report_split(
    'Training', _train_true, _train_pred, conf_matrix_train)

# Performance report for validation set
(valid_accuracy, valid_recall, valid_precision, valid_f1, valid_specificity,
 valid_tn, valid_fp, valid_fn, valid_tp) = _report_split(
    'Validation', _valid_true, _valid_pred, conf_matrix_valid)

# Performance report for testing set
(test_accuracy, test_recall, test_precision, test_f1, test_specificity,
 test_tn, test_fp, test_fn, test_tp) = _report_split(
    'Testing', _test_true, _test_pred, conf_matrix_test)

# Visualization: Show input image, true mask, and predicted mask for a few samples
def visualize_predictions(images, true_masks, pred_masks, title, num_samples=3):
    """Plot input image, ground-truth mask and predicted mask side by side.

    One figure with three panels is shown per sample, for the first
    ``num_samples`` samples.

    Args:
        images: Indexable collection of input images (each must support
            ``.squeeze()``).
        true_masks: Ground-truth masks, aligned with ``images``.
        pred_masks: Predicted masks, aligned with ``images``.
        title: Figure-level title applied to every figure.
        num_samples: Maximum number of samples to show (default 3). Fewer are
            shown when the inputs contain fewer samples; nothing is drawn when
            they are empty.
    """
    # Never index past the shortest input: a small split would otherwise raise
    # IndexError part-way through plotting.
    available = min(len(images), len(true_masks), len(pred_masks))

    for i in range(min(num_samples, available)):
        plt.figure(figsize=(12, 4))

        panels = (
            (images[i], 'Original Image'),
            (true_masks[i], 'Ground Truth Mask'),
            (pred_masks[i], 'Predicted Mask'),
        )
        for column, (data, caption) in enumerate(panels, start=1):
            plt.subplot(1, 3, column)
            plt.imshow(data.squeeze(), cmap='gray')
            plt.title(caption)

        plt.suptitle(title)
        plt.show()

# Visualize predictions for training set
# (the binarized predictions are shown, not the raw sigmoid outputs)
visualize_predictions(X_train, y_train, y_train_pred_bin, "Train Set Predictions")

# Visualize predictions for validation set
visualize_predictions(X_valid, y_valid, y_valid_pred_bin, "Validation Set Predictions")

# Visualize predictions for testing set
visualize_predictions(X_test, y_test, y_test_pred_bin, "Test Set Predictions")


Mask not found for train_mask, skipping this image.
Mask not found for valid_mask, skipping this image.

Epoch 1/20
Batch 1/299 ━━━━━━━━━━━━━━━━━━━━ 13:13:53
Accuracy: 0.9514 - Precision: 0.0494 - Recall: 0.0375 - Specificity: 0.9785 - F1: 0.0426 - Loss: 0.0644

Batch 2/299 ━━━━━━━━━━━━━━━━━━━━ 13:14:03
Accuracy: 0.9619 - Precision: 0.0247 - Recall: 0.0187 - Specificity: 0.9893 - F1: 0.0213 - Loss: 0.0469

Batch 3/299 ━━━━━━━━━━━━━━━━━━━━ 13:14:12
Accuracy: 0.9667 - Precision: 0.0165 - Recall: 0.0125 - Specificity: 0.9928 - F1: 0.0142 - Loss: 0.0390

Batch 4/299 ━━━━━━━━━━━━━━━━━━━━ 13:14:21
Accuracy: 0.9690 - Precision: 0.0123 - Recall: 0.0094 - Specificity: 0.9946 - F1: 0.0107 - Loss: 0.0338

Batch 5/299 ━━━━━━━━━━━━━━━━━━━━ 13:14:29
Accuracy: 0.9693 - Precision: 0.0099 - Recall: 0.0075 - Specificity: 0.9957 - F1: 0.0085 - Loss: 0.0303

Batch 6/299 ━━━━━━━━━━━━━━━━━━━━ 13:14:38
Accuracy: 0.9714 - Precision: 0.0082 - Recall: 0.0062 - Specificity: 0.9964 - F1: 0.0071 - Loss: 0.0272

Ba