
### U-Caps
* No CV
* OneEpoch
* Results Showing for both Train and Test
* **All Performance metrics showing for each Batch during Run!**
  * This feature is not provided by Keras! I recalled the log using *MetricsCallback*!

In [1]:
import os
import numpy as np
import pydicom
import cv2
import tensorflow as tf
import pandas as pd
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Directories for SIIM-ACR Pneumothorax Segmentation dataset
train_image_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\SIIM-ACR Pneumothorax Segmentation\archive\pneumothorax\dicom-images-train'
test_image_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\SIIM-ACR Pneumothorax Segmentation\archive\pneumothorax\dicom-images-test'
train_csv_path = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\SIIM-ACR Pneumothorax Segmentation\archive\pneumothorax\train-rle.csv'

img_size = (256, 256)

# Load the combined CSV that contains both train and test mask information
combined_df = pd.read_csv(train_csv_path)

# Split into training and test sets based on the availability of the image files
train_df = combined_df[combined_df['ImageId'].apply(lambda x: os.path.exists(os.path.join(train_image_dir, x + '.dcm')))]
test_df = combined_df[combined_df['ImageId'].apply(lambda x: os.path.exists(os.path.join(test_image_dir, x + '.dcm')))]

# Define the rle2mask function here instead of importing from mask_functions
def rle2mask(rle, width, height):
    mask = np.zeros(width * height, dtype=np.uint8)
    array = np.asarray([int(x) for x in rle.split()])
    starts = array[0::2]
    lengths = array[1::2]
    current_position = 0
    for start, length in zip(starts, lengths):
        current_position += start
        mask[current_position:current_position + length] = 255
        current_position += length
    return mask.reshape((height, width))

# Data generator to load data in batches
def data_generator(image_dir, df, img_size, batch_size=16):
    while True:
        df_shuffled = df.sample(frac=1).reset_index(drop=True)
        for start in range(0, len(df_shuffled), batch_size):
            end = min(start + batch_size, len(df_shuffled))
            batch_df = df_shuffled.iloc[start:end]

            images = []
            masks = []

            for index, row in batch_df.iterrows():
                img_id = row['ImageId']
                img_path = os.path.join(image_dir, img_id + '.dcm')

                dicom_data = pydicom.dcmread(img_path)
                img = dicom_data.pixel_array

                img = cv2.resize(img, img_size)
                img = img / 255.0  # Normalize image to range 0-1

                # Check if there is a mask
                if pd.isna(row['EncodedPixels']):
                    mask = np.zeros(img_size)  # No pneumothorax, empty mask
                else:
                    mask = rle2mask(row['EncodedPixels'], dicom_data.Rows, dicom_data.Columns)
                    mask = cv2.resize(mask, img_size)
                    mask = mask / 255.0  # Normalize mask to range 0-1

                images.append(np.expand_dims(img, axis=-1))  # Add channel dimension to the image
                masks.append(np.expand_dims(mask, axis=-1))  # Add channel dimension to the mask

            yield np.array(images), np.array(masks)

# Capsule Layer with Dynamic Routing
from tensorflow.keras import layers

def squash(vectors, axis=-1):
    """Squashing function to ensure output vectors' lengths are between 0 and 1"""
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())
    return scale * vectors

class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsules, dim_capsule, num_routing=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.num_routing = num_routing

    def build(self, input_shape):
        self.W = self.add_weight(shape=[input_shape[-1], self.num_capsules * self.dim_capsule],
                                 initializer='glorot_uniform', trainable=True)

    def call(self, inputs):
        inputs = tf.reshape(inputs, [-1, inputs.shape[1] * inputs.shape[2], inputs.shape[3]])
        u_hat = tf.einsum('...ij,jk->...ik', inputs, self.W)
        u_hat = tf.reshape(u_hat, [-1, inputs.shape[1], self.num_capsules, self.dim_capsule])
        
        b = tf.zeros(shape=[tf.shape(inputs)[0], inputs.shape[1], self.num_capsules])
        for i in range(self.num_routing):
            c = tf.nn.softmax(b, axis=-1)
            s = tf.reduce_sum(c[..., tf.newaxis] * u_hat, axis=1)
            v = squash(s)
            if i < self.num_routing - 1:
                b += tf.reduce_sum(u_hat * v[:, tf.newaxis, :, :], axis=-1)
        return v

# U-Net with Capsule Network Layers and Dynamic Routing
def unet_capsule_model(input_size=(256, 256, 1)):
    inputs = tf.keras.layers.Input(input_size)
    
    # Contracting Path with Capsules
    c1 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    c1 = CapsuleLayer(num_capsules=8, dim_capsule=16)(c1)
    c1_flattened = tf.keras.layers.Flatten()(c1)  # Flatten the capsule output
    c1_reshaped = tf.keras.layers.Dense(256*256, activation='relu')(c1_flattened)  # Fully connected layer to reshape
    c1_reshaped = tf.keras.layers.Reshape((256, 256, 1))(c1_reshaped)  # Reshape to 4D
    p1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c1_reshaped)
    
    c2 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(p1)
    c2 = CapsuleLayer(num_capsules=16, dim_capsule=32)(c2)
    c2_flattened = tf.keras.layers.Flatten()(c2)
    c2_reshaped = tf.keras.layers.Dense(128*128, activation='relu')(c2_flattened)
    c2_reshaped = tf.keras.layers.Reshape((128, 128, 1))(c2_reshaped)
    p2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c2_reshaped)
    
    c3 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(p2)
    c3 = CapsuleLayer(num_capsules=32, dim_capsule=64)(c3)
    c3_flattened = tf.keras.layers.Flatten()(c3)
    c3_reshaped = tf.keras.layers.Dense(64*64, activation='relu')(c3_flattened)
    c3_reshaped = tf.keras.layers.Reshape((64, 64, 1))(c3_reshaped)
    p3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c3_reshaped)
    
    # Bottleneck
    b = tf.keras.layers.Conv2D(512, 3, activation='relu', padding='same')(p3)
    b = tf.keras.layers.Conv2D(512, 3, activation='relu', padding='same')(b)
    
    # Expansive Path
    u1 = tf.keras.layers.Conv2DTranspose(256, 2, strides=(2, 2), padding='same')(b)
    u1 = tf.keras.layers.concatenate([u1, c3_reshaped])
    c4 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(u1)
    c4 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(c4)
    
    u2 = tf.keras.layers.Conv2DTranspose(128, 2, strides=(2, 2), padding='same')(c4)
    u2 = tf.keras.layers.concatenate([u2, c2_reshaped])
    c5 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(u2)
    c5 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(c5)
    
    u3 = tf.keras.layers.Conv2DTranspose(64, 2, strides=(2, 2), padding='same')(c5)
    u3 = tf.keras.layers.concatenate([u3, c1_reshaped])
    c6 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(u3)
    c6 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(c6)
    
    outputs = tf.keras.layers.Conv2D(1, 1, activation='sigmoid')(c6)
    
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
    return model

# Dice and Binary Crossentropy combined loss function
def combined_dice_bce_loss(y_true, y_pred):
    y_true_f = tf.cast(tf.keras.backend.flatten(y_true), dtype='float32')
    y_pred_f = tf.cast(tf.keras.backend.flatten(y_pred), dtype='float32')
    
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    dice_loss = 1 - (2. * intersection + 1) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + 1)
    
    bce_loss = tf.keras.losses.binary_crossentropy(y_true_f, y_pred_f)
    
    return dice_loss + bce_loss

# Custom metrics for Keras (with data type casting to fix type mismatch)
def custom_precision(y_true, y_pred):
    y_true = tf.cast(y_true, dtype='float32')
    y_pred = tf.round(tf.cast(y_pred, dtype='float32'))
    true_positives = tf.reduce_sum(y_true * y_pred)
    predicted_positives = tf.reduce_sum(y_pred)
    precision = true_positives / (predicted_positives + tf.keras.backend.epsilon())
    return precision

def custom_recall(y_true, y_pred):
    y_true = tf.cast(y_true, dtype='float32')
    y_pred = tf.round(tf.cast(y_pred, dtype='float32'))
    true_positives = tf.reduce_sum(y_true * y_pred)
    possible_positives = tf.reduce_sum(y_true)
    recall = true_positives / (possible_positives + tf.keras.backend.epsilon())
    return recall

def custom_f1(y_true, y_pred):
    precision = custom_precision(y_true, y_pred)
    recall = custom_recall(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))

def custom_specificity(y_true, y_pred):
    y_true = tf.cast(y_true, dtype='float32')
    y_pred = tf.round(tf.cast(y_pred, dtype='float32'))
    true_negatives = tf.reduce_sum((1 - y_true) * (1 - y_pred))
    possible_negatives = tf.reduce_sum(1 - y_true)
    specificity = true_negatives / (possible_negatives + tf.keras.backend.epsilon())
    return specificity

# Custom callback to print more metrics at each batch in the exact format you requested
class MetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.batch_counter = 1  # Initialize the batch counter
    
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        accuracy = logs.get('accuracy', 0)
        loss = logs.get('loss', 0)
        precision = logs.get('custom_precision', 0)
        recall = logs.get('custom_recall', 0)
        f1 = logs.get('custom_f1', 0)
        specificity = logs.get('custom_specificity', 0)
        
        # Time formatting for current step
        current_time = datetime.now().strftime("%H:%M:%S")
        
        # Print the metrics with proper formatting
        print(f"{self.batch_counter}/723 ━━━━━━━━━━━━━━━━━━━━ {current_time} - 60s/step")
        print(f"Accuracy: {accuracy:.4f} - Precision: {precision:.4f} - Recall: {recall:.4f} - Specificity: {specificity:.4f} - F1: {f1:.4f} - Loss: {loss:.4f}\n")
        
        # Increment batch counter
        self.batch_counter += 1

# Batch size for data generator
batch_size = 16

# Train generator and test generator
train_generator = data_generator(train_image_dir, train_df, img_size, batch_size=batch_size)
test_generator = data_generator(test_image_dir, test_df, img_size, batch_size=batch_size)

# Model training and testing
model = unet_capsule_model()
model.compile(optimizer='adam', loss=combined_dice_bce_loss, metrics=['accuracy', custom_precision, custom_recall, custom_f1, custom_specificity])

# Train the model with the custom callback (verbose=0 to avoid duplicate output)
history = model.fit(train_generator, steps_per_epoch=len(train_df) // batch_size, epochs=1,  # Train for 1 epoch
                    validation_data=test_generator, validation_steps=len(test_df) // batch_size,
                    callbacks=[MetricsCallback()], verbose=0)

# Evaluate on the train set
train_generator_eval = data_generator(train_image_dir, train_df, img_size, batch_size=batch_size)
y_train_pred = model.predict(train_generator_eval, steps=len(train_df) // batch_size)
y_train_true = np.array([mask for _, mask in train_generator_eval]).astype(np.uint8)
y_train_pred = (y_train_pred > 0.5).astype(np.uint8)

# Confusion Matrix for training
conf_matrix_train = confusion_matrix(y_train_true.flatten(), y_train_pred.flatten())
sns.heatmap(conf_matrix_train, annot=True, fmt="d", cmap="Blues")
plt.title(f"Confusion Matrix for Train")
plt.show()

# Evaluate on the test set
y_test_pred = model.predict(test_generator, steps=len(test_df) // batch_size)
y_test_true = np.array([mask for _, mask in test_generator]).astype(np.uint8)
y_test_pred = (y_test_pred > 0.5).astype(np.uint8)

# Confusion Matrix for testing
conf_matrix_test = confusion_matrix(y_test_true.flatten(), y_test_pred.flatten())
sns.heatmap(conf_matrix_test, annot=True, fmt="d", cmap="Blues")
plt.title(f"Confusion Matrix for Test")
plt.show()

# Visualization: Show input image, true mask, and predicted mask for a few samples for both train and test
def visualize_predictions(generator, true_masks, pred_masks, title):
    for i in range(3):  # Visualize first 3 predictions
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))

        X = next(generator)[0]
        
        ax[0].imshow(X[i].squeeze(), cmap='gray')
        ax[0].set_title('Input Image')

        ax[1].imshow(true_masks[i].squeeze(), cmap='gray')
        ax[1].set_title('True Mask')

        ax[2].imshow(pred_masks[i].squeeze(), cmap='gray')
        ax[2].set_title('Predicted Mask')

        plt.suptitle(title)
        plt.show()

# Visualize predictions for training set
visualize_predictions(train_generator_eval, y_train_true, y_train_pred, "Train Set Predictions")

# Visualize predictions for testing set
visualize_predictions(test_generator, y_test_true, y_test_pred, "Test Set Predictions")

# Performance report for training set
train_accuracy = accuracy_score(y_train_true.flatten(), y_train_pred.flatten())
train_recall = recall_score(y_train_true.flatten(), y_train_pred.flatten())
train_precision = precision_score(y_train_true.flatten(), y_train_pred.flatten())
train_f1 = f1_score(y_train_true.flatten(), y_train_pred.flatten())
train_tn, train_fp, train_fn, train_tp = confusion_matrix(y_train_true.flatten(), y_train_pred.flatten()).ravel()
train_specificity = train_tn / (train_tn + train_fp)

print(f'Training Set Results:')
print(f'Accuracy: {train_accuracy:.4f}')
print(f'Recall (Sensitivity): {train_recall:.4f}')
print(f'Precision: {train_precision:.4f}')
print(f'F1 Score: {train_f1:.4f}')
print(f'Specificity: {train_specificity:.4f}')

# Performance report for testing set
test_accuracy = accuracy_score(y_test_true.flatten(), y_test_pred.flatten())
test_recall = recall_score(y_test_true.flatten(), y_test_pred.flatten())
test_precision = precision_score(y_test_true.flatten(), y_test_pred.flatten())
test_f1 = f1_score(y_test_true.flatten(), y_test_pred.flatten())
test_tn, test_fp, test_fn, test_tp = confusion_matrix(y_test_true.flatten(), y_test_pred.flatten()).ravel()
test_specificity = test_tn / (test_tn + test_fp)

print(f'Testing Set Results:')
print(f'Accuracy: {test_accuracy:.4f}')
print(f'Recall (Sensitivity): {test_recall:.4f}')
print(f'Precision: {test_precision:.4f}')
print(f'F1 Score: {test_f1:.4f}')
print(f'Specificity: {test_specificity:.4f}')



1/723 ━━━━━━━━━━━━━━━━━━━━ 17:18:51 - 60s/step
Accuracy: 0.7807 - Precision: 0.0081 - Recall: 0.2152 - Specificity: 0.7854 - F1: 0.0156 - Loss: 1.6770

2/723 ━━━━━━━━━━━━━━━━━━━━ 17:19:32 - 60s/step
Accuracy: 0.8876 - Precision: 0.0040 - Recall: 0.1076 - Specificity: 0.8927 - F1: 0.0078 - Loss: 1.6754

3/723 ━━━━━━━━━━━━━━━━━━━━ 17:20:25 - 60s/step
Accuracy: 0.9248 - Precision: 0.0027 - Recall: 0.0717 - Specificity: 0.9285 - F1: 0.0052 - Loss: 1.6700

4/723 ━━━━━━━━━━━━━━━━━━━━ 17:21:07 - 60s/step
Accuracy: 0.9426 - Precision: 0.0020 - Recall: 0.0538 - Specificity: 0.9464 - F1: 0.0039 - Loss: 1.6321



KeyboardInterrupt: 

Results show that we have Good Accuracy but High loss! I checked the diagrams and found that this high value for Loss could be due to the Poor performance of the model on Edge (Boundy) detection of segmentations.

The model is improved by the followings:
* Use Boundary-Aware Loss Functions: Introduce loss functions that emphasize boundary accuracy. Examples include:
 - Boundary Loss: Directly focus on minimizing the error at object boundaries by computing the distance between predicted and true boundary pixels.

* Attention Mechanisms: Incorporate attention layers or attention gates to allow the model to focus on boundary regions more effectively. Attention mechanisms help the model prioritize important features and pixels, like object boundaries, during training.


* Edge Detection as Preprocessing: Incorporate an edge detection module (e.g., using Sobel filters or Canny edge detectors) to provide the model with explicit boundary information as input, making it easier for the model to focus on boundaries.


In [None]:
import os
import numpy as np
import pydicom
import cv2
import tensorflow as tf
import pandas as pd
from sklearn.metrics import confusion_matrix, accuracy_score, recall_score, precision_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# Directories for SIIM-ACR Pneumothorax Segmentation dataset
train_image_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\SIIM-ACR Pneumothorax Segmentation\archive\pneumothorax\dicom-images-train'
test_image_dir = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\SIIM-ACR Pneumothorax Segmentation\archive\pneumothorax\dicom-images-test'
train_csv_path = r'C:\Users\Jaber\OneDrive - University of Florida\Educational\GitHub\Datasets\ImageSegmentation\SIIM-ACR Pneumothorax Segmentation\archive\pneumothorax\train-rle.csv'

img_size = (256, 256)

# Load the combined CSV that contains both train and test mask information
combined_df = pd.read_csv(train_csv_path)

# Split into training and test sets based on the availability of the image files
train_df = combined_df[combined_df['ImageId'].apply(lambda x: os.path.exists(os.path.join(train_image_dir, x + '.dcm')))]
test_df = combined_df[combined_df['ImageId'].apply(lambda x: os.path.exists(os.path.join(test_image_dir, x + '.dcm')))]

# Define the rle2mask function
def rle2mask(rle, width, height):
    mask = np.zeros(width * height, dtype=np.uint8)
    array = np.asarray([int(x) for x in rle.split()])
    starts = array[0::2] - 1  # Adjust start positions
    lengths = array[1::2]
    for start, length in zip(starts, lengths):
        mask[int(start):int(start + length)] = 255
    return mask.reshape((height, width)).T  # Transpose to match the image orientation

# Sobel edge detection for preprocessing
def sobel_edge_detection(image):
    image_rank = len(image.shape)
    if image_rank == 2:
        # Shape is [height, width]
        image = tf.expand_dims(image, axis=-1)  # Add channel dimension
        image = tf.expand_dims(image, axis=0)   # Add batch dimension
        edges = tf.image.sobel_edges(image)     # Output shape: [1, height, width, 1, 2]
        edges = tf.sqrt(tf.reduce_sum(tf.square(edges), axis=-1))  # Shape: [1, height, width, 1]
        edges = tf.squeeze(edges, axis=[0, -1])  # Remove batch and channel dimensions
        edges = edges / (tf.reduce_max(edges) + tf.keras.backend.epsilon())  # Normalize to 0-1
    elif image_rank == 3:
        # Shape is [height, width, channels]
        image = tf.expand_dims(image, axis=0)   # Add batch dimension
        edges = tf.image.sobel_edges(image)     # Output shape: [1, height, width, channels, 2]
        edges = tf.sqrt(tf.reduce_sum(tf.square(edges), axis=-1))  # Shape: [1, height, width, channels]
        edges = tf.squeeze(edges, axis=0)       # Remove batch dimension
        edges = edges / (tf.reduce_max(edges) + tf.keras.backend.epsilon())  # Normalize to 0-1
    elif image_rank == 4:
        # Shape is [batch_size, height, width, channels]
        edges = tf.image.sobel_edges(image)     # Output shape: [batch_size, height, width, channels, 2]
        edges = tf.sqrt(tf.reduce_sum(tf.square(edges), axis=-1))  # Shape: [batch_size, height, width, channels]
        edges = edges / (tf.reduce_max(edges, axis=[1,2,3], keepdims=True) + tf.keras.backend.epsilon())  # Normalize to 0-1
    else:
        raise ValueError("Unsupported image rank: {}".format(image_rank))
    return edges

# Data generator with edge detection preprocessing
def data_generator(image_dir, df, img_size, batch_size=16):
    while True:
        df_shuffled = df.sample(frac=1).reset_index(drop=True)
        for start in range(0, len(df_shuffled), batch_size):
            end = min(start + batch_size, len(df_shuffled))
            batch_df = df_shuffled.iloc[start:end]

            images = []
            masks = []

            for index, row in batch_df.iterrows():
                img_id = row['ImageId']
                img_path = os.path.join(image_dir, img_id + '.dcm')

                dicom_data = pydicom.dcmread(img_path)
                img = dicom_data.pixel_array

                # Resize the original image
                img_resized = cv2.resize(img, img_size)
                img_resized = img_resized / 255.0  # Normalize image to range 0-1

                # Apply Sobel edge detection
                edge_img = sobel_edge_detection(tf.convert_to_tensor(img_resized, dtype=tf.float32))
                edge_img = edge_img.numpy()  # Convert to numpy array

                # Ensure edge_img has same shape as img_resized
                if img_resized.shape != edge_img.shape:
                    edge_img = cv2.resize(edge_img, (img_resized.shape[1], img_resized.shape[0]))

                # Stack the resized original and edge-detected image
                img_combined = np.stack([img_resized, edge_img], axis=-1)

                # Check if there is a mask
                if pd.isna(row['EncodedPixels']):
                    mask = np.zeros(img_size, dtype=np.uint8)  # No pneumothorax, empty mask
                else:
                    mask = rle2mask(row['EncodedPixels'], dicom_data.Columns, dicom_data.Rows)
                    mask = cv2.resize(mask, img_size, interpolation=cv2.INTER_NEAREST)
                    mask = (mask > 127).astype(np.uint8)  # Binarize mask

                images.append(img_combined)
                masks.append(np.expand_dims(mask, axis=-1))  # Add channel dimension to the mask

            yield np.array(images), np.array(masks)

# Custom Metrics for Precision, Recall, F1, and Specificity
def custom_precision(y_true, y_pred):
    y_true = tf.cast(tf.round(y_true), 'float32')  # Ensure y_true is binary
    y_pred = tf.round(y_pred)
    true_positives = tf.reduce_sum(y_true * y_pred)
    predicted_positives = tf.reduce_sum(y_pred)
    precision = true_positives / (predicted_positives + tf.keras.backend.epsilon())
    return precision

def custom_recall(y_true, y_pred):
    y_true = tf.cast(tf.round(y_true), 'float32')  # Ensure y_true is binary
    y_pred = tf.round(y_pred)
    true_positives = tf.reduce_sum(y_true * y_pred)
    possible_positives = tf.reduce_sum(y_true)
    recall = true_positives / (possible_positives + tf.keras.backend.epsilon())
    return recall

def custom_f1(y_true, y_pred):
    precision = custom_precision(y_true, y_pred)
    recall = custom_recall(y_true, y_pred)
    f1 = 2 * ((precision * recall) / (precision + recall + tf.keras.backend.epsilon()))
    return f1

def custom_specificity(y_true, y_pred):
    y_true = tf.cast(tf.round(y_true), 'float32')  # Ensure y_true is binary
    y_pred = tf.round(y_pred)
    true_negatives = tf.reduce_sum((1 - y_true) * (1 - y_pred))
    possible_negatives = tf.reduce_sum(1 - y_true)
    specificity = true_negatives / (possible_negatives + tf.keras.backend.epsilon())
    return specificity

# Capsule Layer with Dynamic Routing
from tensorflow.keras import layers

def squash(vectors, axis=-1):
    """Squashing function to ensure output vectors' lengths are between 0 and 1"""
    s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
    scale = s_squared_norm / (1 + s_squared_norm + tf.keras.backend.epsilon())
    return scale * vectors / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())

class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsules, dim_capsule, num_routing=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.num_routing = num_routing

    def build(self, input_shape):
        self.W = self.add_weight(shape=[input_shape[-1], self.num_capsules * self.dim_capsule],
                                 initializer='glorot_uniform', trainable=True)

    def call(self, inputs):
        inputs = tf.reshape(inputs, [-1, inputs.shape[1] * inputs.shape[2], inputs.shape[3]])
        u_hat = tf.einsum('...ij,jk->...ik', inputs, self.W)
        u_hat = tf.reshape(u_hat, [-1, inputs.shape[1], self.num_capsules, self.dim_capsule])
        
        b = tf.zeros(shape=[tf.shape(inputs)[0], inputs.shape[1], self.num_capsules])
        for i in range(self.num_routing):
            c = tf.nn.softmax(b, axis=-1)
            s = tf.reduce_sum(c[..., tf.newaxis] * u_hat, axis=1)
            v = squash(s)
            if i < self.num_routing - 1:
                b += tf.reduce_sum(u_hat * v[:, tf.newaxis, :, :], axis=-1)
        return v

# Attention Gate (fixed shape mismatch)
def attention_gate(x, g, inter_shape, upsample=False):
    theta_x = tf.keras.layers.Conv2D(inter_shape, kernel_size=1, strides=1, padding='same')(x)
    phi_g = tf.keras.layers.Conv2D(inter_shape, kernel_size=1, padding='same')(g)
    
    if upsample:
        # Use Lambda layer to resize phi_g to match theta_x
        phi_g = tf.keras.layers.Lambda(
            lambda inputs: tf.image.resize(inputs[0], tf.shape(inputs[1])[1:3], method='bilinear'))([phi_g, theta_x])
    
    add_xg = tf.keras.layers.add([theta_x, phi_g])
    relu_xg = tf.keras.layers.Activation('relu')(add_xg)
    psi = tf.keras.layers.Conv2D(1, kernel_size=1, padding='same')(relu_xg)
    sigmoid_xg = tf.keras.layers.Activation('sigmoid')(psi)
    return tf.keras.layers.Multiply()([x, sigmoid_xg])

# U-Net with Capsule Network Layers, Attention Mechanism, and Dynamic Routing
def unet_capsule_model(input_size=(256, 256, 2)):  # 2-channel input (original + edge)
    inputs = tf.keras.layers.Input(input_size)
    
    # Contracting Path with Capsules
    c1 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    c1 = CapsuleLayer(num_capsules=8, dim_capsule=16)(c1)
    c1_flattened = tf.keras.layers.Flatten()(c1)
    c1_reshaped = tf.keras.layers.Dense(256 * 256, activation='relu')(c1_flattened)
    c1_reshaped = tf.keras.layers.Reshape((256, 256, 1))(c1_reshaped)
    p1 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c1_reshaped)
    
    c2 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(p1)
    c2 = CapsuleLayer(num_capsules=16, dim_capsule=32)(c2)
    c2_flattened = tf.keras.layers.Flatten()(c2)
    c2_reshaped = tf.keras.layers.Dense(128 * 128, activation='relu')(c2_flattened)
    c2_reshaped = tf.keras.layers.Reshape((128, 128, 1))(c2_reshaped)
    p2 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c2_reshaped)
    
    c3 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(p2)
    c3 = CapsuleLayer(num_capsules=32, dim_capsule=64)(c3)
    c3_flattened = tf.keras.layers.Flatten()(c3)
    c3_reshaped = tf.keras.layers.Dense(64 * 64, activation='relu')(c3_flattened)
    c3_reshaped = tf.keras.layers.Reshape((64, 64, 1))(c3_reshaped)
    p3 = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(c3_reshaped)
    
    # Bottleneck
    b = tf.keras.layers.Conv2D(512, 3, activation='relu', padding='same')(p3)
    b = tf.keras.layers.Conv2D(512, 3, activation='relu', padding='same')(b)
    
    # Attention mechanism in expansive path
    g1 = tf.keras.layers.Conv2DTranspose(256, 2, strides=(2, 2), padding='same')(b)
    a1 = attention_gate(c3_reshaped, g1, 128, upsample=True)
    u1 = tf.keras.layers.concatenate([g1, a1])
    c4 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(u1)
    c4 = tf.keras.layers.Conv2D(256, 3, activation='relu', padding='same')(c4)
    
    g2 = tf.keras.layers.Conv2DTranspose(128, 2, strides=(2, 2), padding='same')(c4)
    a2 = attention_gate(c2_reshaped, g2, 64, upsample=True)
    u2 = tf.keras.layers.concatenate([g2, a2])
    c5 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(u2)
    c5 = tf.keras.layers.Conv2D(128, 3, activation='relu', padding='same')(c5)
    
    g3 = tf.keras.layers.Conv2DTranspose(64, 2, strides=(2, 2), padding='same')(c5)
    a3 = attention_gate(c1_reshaped, g3, 32, upsample=True)
    u3 = tf.keras.layers.concatenate([g3, a3])
    c6 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(u3)
    c6 = tf.keras.layers.Conv2D(64, 3, activation='relu', padding='same')(c6)
    
    outputs = tf.keras.layers.Conv2D(1, 1, activation='sigmoid')(c6)
    
    model = tf.keras.Model(inputs=[inputs], outputs=[outputs])
    return model

# Boundary Loss Function
def boundary_loss(y_true, y_pred):
    y_true = tf.cast(tf.round(y_true), dtype='float32')  # Ensure y_true is binary
    y_pred = tf.cast(y_pred, dtype='float32')

    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    
    # Dice Loss
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    dice_loss = 1 - (2. * intersection + 1) / (tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) + 1)
    
    # Boundary Loss (difference between boundary of prediction and ground truth)
    sobel_true = sobel_edge_detection(y_true)
    sobel_pred = sobel_edge_detection(y_pred)
    boundary_diff = tf.reduce_mean(tf.abs(sobel_true - sobel_pred))
    
    return dice_loss + boundary_diff

# Custom callback to print more metrics at each batch in the exact format you requested
class MetricsCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self.batch_counter = 1  # Initialize the batch counter
    
    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        accuracy = logs.get('accuracy', 0)
        loss = logs.get('loss', 0)
        precision = logs.get('custom_precision', 0)
        recall = logs.get('custom_recall', 0)
        f1 = logs.get('custom_f1', 0)
        specificity = logs.get('custom_specificity', 0)
        
        # Time formatting for current step
        current_time = datetime.now().strftime("%H:%M:%S")
        
        # Print the metrics with proper formatting
        print(f"{self.batch_counter}/723 ━━━━━━━━━━━━━━━━━━━━ {current_time} - 60s/step")
        print(f"Accuracy: {accuracy:.4f} - Precision: {precision:.4f} - Recall: {recall:.4f} - Specificity: {specificity:.4f} - F1: {f1:.4f} - Loss: {loss:.4f}\n")
        
        # Increment batch counter
        self.batch_counter += 1

# Batch size for data generator
batch_size = 16

# Train generator and test generator
train_generator = data_generator(train_image_dir, train_df, img_size, batch_size=batch_size)
test_generator = data_generator(test_image_dir, test_df, img_size, batch_size=batch_size)

# Model training and testing
model = unet_capsule_model()
model.compile(optimizer='adam', loss=boundary_loss, metrics=['accuracy', custom_precision, custom_recall, custom_f1, custom_specificity])

# Train the model with the custom callback (verbose=0 to avoid duplicate output)
history = model.fit(train_generator, steps_per_epoch=len(train_df) // batch_size, epochs=1,  # Train for 1 epoch
                    validation_data=test_generator, validation_steps=len(test_df) // batch_size,
                    callbacks=[MetricsCallback()], verbose=0)

# Evaluate on the train set
train_generator_eval = data_generator(train_image_dir, train_df, img_size, batch_size=batch_size)
y_train_pred = model.predict(train_generator_eval, steps=len(train_df) // batch_size)
y_train_true_list = []
for _ in range(len(train_df) // batch_size):
    _, masks = next(train_generator_eval)
    y_train_true_list.append(masks)
y_train_true = np.concatenate(y_train_true_list)
y_train_pred = (y_train_pred > 0.5).astype(np.uint8)

# Confusion Matrix for training
conf_matrix_train = confusion_matrix(y_train_true.flatten(), y_train_pred.flatten())
sns.heatmap(conf_matrix_train, annot=True, fmt="d", cmap="Blues")
plt.title(f"Confusion Matrix for Train")
plt.show()

# Evaluate on the test set
test_generator_eval = data_generator(test_image_dir, test_df, img_size, batch_size=batch_size)
y_test_pred = model.predict(test_generator_eval, steps=len(test_df) // batch_size)
y_test_true_list = []
for _ in range(len(test_df) // batch_size):
    _, masks = next(test_generator_eval)
    y_test_true_list.append(masks)
y_test_true = np.concatenate(y_test_true_list)
y_test_pred = (y_test_pred > 0.5).astype(np.uint8)

# Confusion Matrix for testing
conf_matrix_test = confusion_matrix(y_test_true.flatten(), y_test_pred.flatten())
sns.heatmap(conf_matrix_test, annot=True, fmt="d", cmap="Blues")
plt.title(f"Confusion Matrix for Test")
plt.show()

# Visualization: Show input image, true mask, and predicted mask for a few samples for both train and test
def visualize_predictions(generator, true_masks, pred_masks, title):
    for i in range(3):  # Visualize first 3 predictions
        fig, ax = plt.subplots(1, 3, figsize=(15, 5))

        X_batch, _ = next(generator)
        X = X_batch[i]
        
        ax[0].imshow(X[:, :, 0].squeeze(), cmap='gray')
        ax[0].set_title('Input Image')

        ax[1].imshow(true_masks[i].squeeze(), cmap='gray')
        ax[1].set_title('True Mask')

        ax[2].imshow(pred_masks[i].squeeze(), cmap='gray')
        ax[2].set_title('Predicted Mask')

        plt.suptitle(title)
        plt.show()

# Visualize predictions for training set
train_generator_eval = data_generator(train_image_dir, train_df, img_size, batch_size=batch_size)
visualize_predictions(train_generator_eval, y_train_true, y_train_pred, "Train Set Predictions")

# Visualize predictions for testing set
test_generator_eval = data_generator(test_image_dir, test_df, img_size, batch_size=batch_size)
visualize_predictions(test_generator_eval, y_test_true, y_test_pred, "Test Set Predictions")

# Performance report for training set
train_accuracy = accuracy_score(y_train_true.flatten(), y_train_pred.flatten())
train_recall = recall_score(y_train_true.flatten(), y_train_pred.flatten())
train_precision = precision_score(y_train_true.flatten(), y_train_pred.flatten())
train_f1 = f1_score(y_train_true.flatten(), y_train_pred.flatten())
train_tn, train_fp, train_fn, train_tp = confusion_matrix(y_train_true.flatten(), y_train_pred.flatten()).ravel()
train_specificity = train_tn / (train_tn + train_fp)

print(f'Training Set Results:')
print(f'Accuracy: {train_accuracy:.4f}')
print(f'Recall (Sensitivity): {train_recall:.4f}')
print(f'Precision: {train_precision:.4f}')
print(f'F1 Score: {train_f1:.4f}')
print(f'Specificity: {train_specificity:.4f}')

# Performance report for testing set
test_accuracy = accuracy_score(y_test_true.flatten(), y_test_pred.flatten())
test_recall = recall_score(y_test_true.flatten(), y_test_pred.flatten())
test_precision = precision_score(y_test_true.flatten(), y_test_pred.flatten())
test_f1 = f1_score(y_test_true.flatten(), y_test_pred.flatten())
test_tn, test_fp, test_fn, test_tp = confusion_matrix(y_test_true.flatten(), y_test_pred.flatten()).ravel()
test_specificity = test_tn / (test_tn + test_fp)

print(f'Testing Set Results:')
print(f'Accuracy: {test_accuracy:.4f}')
print(f'Recall (Sensitivity): {test_recall:.4f}')
print(f'Precision: {test_precision:.4f}')
print(f'F1 Score: {test_f1:.4f}')
print(f'Specificity: {test_specificity:.4f}')



1/723 ━━━━━━━━━━━━━━━━━━━━ 19:53:09 - 60s/step
Accuracy: 0.7317 - Precision: 0.0001 - Recall: 0.1695 - Specificity: 0.7317 - F1: 0.0001 - Loss: 1.1988

2/723 ━━━━━━━━━━━━━━━━━━━━ 19:54:19 - 60s/step
Accuracy: 0.4536 - Precision: 0.0001 - Recall: 0.2324 - Specificity: 0.4536 - F1: 0.0001 - Loss: 1.1431

3/723 ━━━━━━━━━━━━━━━━━━━━ 19:55:11 - 60s/step
Accuracy: 0.4267 - Precision: 0.0000 - Recall: 0.1685 - Specificity: 0.4267 - F1: 0.0001 - Loss: 1.1049

4/723 ━━━━━━━━━━━━━━━━━━━━ 19:56:04 - 60s/step
Accuracy: 0.4318 - Precision: 0.0000 - Recall: 0.1317 - Specificity: 0.4319 - F1: 0.0001 - Loss: 1.0809

5/723 ━━━━━━━━━━━━━━━━━━━━ 19:56:47 - 60s/step
Accuracy: 0.4431 - Precision: 0.0000 - Recall: 0.1053 - Specificity: 0.4432 - F1: 0.0000 - Loss: 1.0661

6/723 ━━━━━━━━━━━━━━━━━━━━ 19:57:39 - 60s/step
Accuracy: 0.4561 - Precision: 0.0000 - Recall: 0.0878 - Specificity: 0.4561 - F1: 0.0000 - Loss: 1.0562

7/723 ━━━━━━━━━━━━━━━━━━━━ 19:58:44 - 60s/step
Accuracy: 0.4695 - Precision: 0.0000 - R

* Precsion is not promissing!
* Trying again with the above items!!!