In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, Callback
import time
from tensorflow.keras.utils import plot_model
import numpy as np
from sklearn.model_selection import KFold
import os
import glob

In [None]:
import os
import tensorflow as tf
import numpy as np
import glob

BASE_PATH = "/kaggle/input/dataset/new_data"

# Every split keeps its inputs under "images" and its masks under "GT_TE".
TRAIN_IMAGES_DIR, TRAIN_MASKS_DIR = (
    os.path.join(BASE_PATH, "train", sub) for sub in ("images", "GT_TE")
)
VAL_IMAGES_DIR, VAL_MASKS_DIR = (
    os.path.join(BASE_PATH, "valid", sub) for sub in ("images", "GT_TE")
)
TEST_IMAGES_DIR, TEST_MASKS_DIR = (
    os.path.join(BASE_PATH, "test", sub) for sub in ("images", "GT_TE")
)

# Report how many directory entries each folder holds as a quick sanity check.
for label, folder in (
    ("Train images", TRAIN_IMAGES_DIR),
    ("Train masks", TRAIN_MASKS_DIR),
    ("Validation images", VAL_IMAGES_DIR),
    ("Validation masks", VAL_MASKS_DIR),
    ("Test images", TEST_IMAGES_DIR),
    ("Test masks", TEST_MASKS_DIR),
):
    print(f"{label}:", len(os.listdir(folder)))

# Spatial size every image and mask is resized to.
IMG_HEIGHT = 256
IMG_WIDTH = 256

def load_image(image_path, mask_path):
    """Read one BMP image and its BMP mask and return them as float tensors.

    Both files are decoded as 3-channel BMP, converted to a single grayscale
    channel, resized to (IMG_HEIGHT, IMG_WIDTH) and scaled by 1/255. The mask
    is then thresholded at 0.5 so it only contains 0.0 and 1.0.

    Args:
        image_path: path of the input image file.
        mask_path: path of the matching mask file.

    Returns:
        Tuple ``(image, mask)``.
    """
    def _read_grayscale(path):
        # Shared decode -> grayscale -> resize -> [0, 1] scaling chain.
        raw = tf.io.read_file(path)
        rgb = tf.image.decode_bmp(raw, channels=3)
        gray = tf.image.rgb_to_grayscale(rgb)
        resized = tf.image.resize(gray, [IMG_HEIGHT, IMG_WIDTH])
        return tf.cast(resized, tf.float32) / 255.0

    image = _read_grayscale(image_path)
    mask = _read_grayscale(mask_path)
    # Binarize the mask after resizing.
    mask = tf.where(mask > 0.5, 1.0, 0.0)

    return image, mask



def create_dataset(image_dir, mask_dir, batch_size=8, shuffle=True):
    """Create a batched tf.data dataset of (image, mask) pairs from two folders.

    The ``*.bmp`` files of each folder are sorted by name and paired by
    position, so both folders must contain the same number of files.

    Args:
        image_dir: folder containing the input ``*.bmp`` images.
        mask_dir: folder containing the ``*.bmp`` masks.
        batch_size: number of pairs per batch.
        shuffle: when True, shuffle the pairs.

    Returns:
        A ``tf.data.Dataset`` yielding batches produced by ``load_image``.

    Raises:
        ValueError: if no images are found, or if the number of images and
            masks differ (positional pairing would be wrong).
    """
    image_paths = sorted(glob.glob(os.path.join(image_dir, "*.bmp")))
    mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.bmp")))

    print(f"Found {len(image_paths)} images and {len(mask_paths)} masks in {image_dir}")

    # Fail early with a readable message instead of an obscure tf.data error.
    if not image_paths:
        raise ValueError(f"No .bmp images found in {image_dir}")
    if len(image_paths) != len(mask_paths):
        raise ValueError(
            f"Image/mask count mismatch: {len(image_paths)} images in {image_dir} "
            f"but {len(mask_paths)} masks in {mask_dir}"
        )

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))

    if shuffle:
        # Shuffle the path strings before decoding: the buffer can then span
        # the whole dataset without holding decoded images in memory.
        dataset = dataset.shuffle(buffer_size=len(image_paths))

    dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)



BATCH_SIZE = 16

# Training data is shuffled; validation and test keep file order. The test
# set uses a batch size of 1.
train_dataset = create_dataset(
    TRAIN_IMAGES_DIR, TRAIN_MASKS_DIR, batch_size=BATCH_SIZE, shuffle=True
)
val_dataset = create_dataset(
    VAL_IMAGES_DIR, VAL_MASKS_DIR, batch_size=BATCH_SIZE, shuffle=False
)
test_dataset = create_dataset(
    TEST_IMAGES_DIR, TEST_MASKS_DIR, batch_size=1, shuffle=False
)



In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

# ---------------- Custom GroupNorm ----------------
class GroupNormalization(layers.Layer):
    """Group normalization for channels-last, rank-4 (N, H, W, C) feature maps.

    Channels are split into ``groups`` groups; each group is normalized with
    the mean and variance computed over its spatial positions and channels,
    independently for every sample. A learned per-channel scale (``gamma``)
    and offset (``beta``) are applied afterwards.

    Args:
        groups: number of channel groups; must divide the channel count.
        axis: channel axis. Only the last axis is supported because ``call``
            reshapes the input as (N, H, W, G, C // G).
        epsilon: small constant added to the variance for numerical stability.
    """

    def __init__(self, groups=4, axis=-1, epsilon=1e-5, **kwargs):
        super(GroupNormalization, self).__init__(**kwargs)
        self.groups = groups
        self.axis = axis
        self.epsilon = epsilon

    def build(self, input_shape):
        rank = len(input_shape)
        # call() hard-codes an (N, H, W, C) layout; reject anything else
        # instead of silently normalizing over the wrong axes.
        if rank != 4 or self.axis not in (-1, rank - 1):
            raise ValueError(
                "GroupNormalization only supports rank-4 channels-last inputs "
                "(got rank {} with axis {})".format(rank, self.axis))
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError("Axis {} of input tensor should have a defined dimension".format(self.axis))
        if dim % self.groups != 0:
            raise ValueError("Number of channels ({}) must be divisible by groups ({})".format(dim, self.groups))
        # Keep the static channel count so reshapes preserve it for later layers.
        self.channels = dim
        self.gamma = self.add_weight(shape=(dim,),
                                     initializer="ones",
                                     trainable=True,
                                     name="gamma")
        self.beta = self.add_weight(shape=(dim,),
                                    initializer="zeros",
                                    trainable=True,
                                    name="beta")
        super(GroupNormalization, self).build(input_shape)

    def call(self, inputs):
        input_shape = tf.shape(inputs)
        N, H, W = input_shape[0], input_shape[1], input_shape[2]
        C = self.channels

        G = self.groups
        x = tf.reshape(inputs, [N, H, W, G, C // G])

        # Statistics per sample and per group: reduce over H, W and the
        # channels inside each group.
        mean, var = tf.nn.moments(x, [1, 2, 4], keepdims=True)
        x = (x - mean) / tf.sqrt(var + self.epsilon)

        x = tf.reshape(x, [N, H, W, C])
        return x * self.gamma + self.beta

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(GroupNormalization, self).get_config()
        config.update({
            "groups": self.groups,
            "axis": self.axis,
            "epsilon": self.epsilon,
        })
        return config

# ---------------- DepthWiseConv2D ----------------
def DepthWiseConv2D(filters, kernel_size=3, dilation_rate=1, padding="same", activation=None):
    """Build a depthwise-separable convolution block as a Sequential model.

    The block is: depthwise convolution -> GroupNormalization(groups=4) ->
    1x1 convolution producing ``filters`` channels -> activation.

    Args:
        filters: number of output channels of the 1x1 convolution.
        kernel_size: kernel size of the depthwise convolution.
        dilation_rate: dilation rate of the depthwise convolution.
        padding: padding mode of the depthwise convolution.
        activation: activation name; a falsy value selects "linear".

    Returns:
        A ``models.Sequential`` containing the four layers.
    """
    depthwise = layers.DepthwiseConv2D(kernel_size, padding=padding, dilation_rate=dilation_rate)
    group_norm = GroupNormalization(groups=4)
    pointwise = layers.Conv2D(filters, 1, padding="same")
    # A missing activation becomes an explicit identity layer.
    final_activation = layers.Activation(activation or "linear")
    return models.Sequential([depthwise, group_norm, pointwise, final_activation])


In [None]:
# ---------------- Gated Attention Unit ----------------
class GatedAttentionUnit(layers.Layer):
    """Gated attention unit with a 1x1-convolution residual path.

    Two parallel depthwise-separable branches are applied to the input: a
    sigmoid gate (``w1``) and a GELU feature branch (``w2``, kernel size
    ``kernel_size + 2``). Their element-wise product goes through a third
    block (``wo``), and a 1x1 convolution of the input (``cw``) is added.

    Args:
        in_c: channel count produced by the two gated branches.
        out_c: channel count of the output.
        kernel_size: depthwise kernel size of the gate and output blocks.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, in_c, out_c, kernel_size=3, **kwargs):
        super(GatedAttentionUnit, self).__init__(**kwargs)
        # Stored so get_config() can reproduce the constructor call.
        self.in_c = in_c
        self.out_c = out_c
        self.kernel_size = kernel_size

        self.w1 = models.Sequential([
            DepthWiseConv2D(in_c, kernel_size, activation="linear"),
            layers.Activation("sigmoid")
        ])
        self.w2 = models.Sequential([
            DepthWiseConv2D(in_c, kernel_size + 2, activation="gelu"),
        ])
        self.wo = models.Sequential([
            DepthWiseConv2D(out_c, kernel_size, activation="gelu"),
        ])
        self.cw = layers.Conv2D(out_c, 1, padding="same")

    def call(self, x):
        gate, features = self.w1(x), self.w2(x)
        # Plain tensor ops: no new Multiply/Add layer objects on every call.
        out = self.wo(gate * features)
        return out + self.cw(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(GatedAttentionUnit, self).get_config()
        config.update({
            "in_c": self.in_c,
            "out_c": self.out_c,
            "kernel_size": self.kernel_size,
        })
        return config


In [None]:
class DilatedGatedAttention(layers.Layer):
    """Multi-dilation convolution bank followed by a GatedAttentionUnit.

    The input is passed through one ReLU convolution per dilation rate, the
    results are concatenated along the channel axis, layer-normalized, mixed
    back to ``in_c`` channels by a 1x1 convolution and fed to the gated
    attention unit, which produces ``out_c`` channels.

    Args:
        in_c: channel count of each dilated convolution and of the fused map.
        out_c: channel count of the output.
        k_size: kernel size of the dilated convolutions and the attention unit.
        dilated_ratio: iterable of dilation rates, one convolution per rate.
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, in_c, out_c, k_size=3, dilated_ratio=(1, 2, 4), **kwargs):
        super(DilatedGatedAttention, self).__init__(**kwargs)
        self.in_c = in_c
        self.out_c = out_c
        self.k_size = k_size
        self.dilated_ratio = list(dilated_ratio)

        # Dilated convolutions
        self.convs = [
            layers.Conv2D(
                in_c, k_size, padding="same", dilation_rate=d,
                activation="relu"
            ) for d in self.dilated_ratio
        ]

        # Normalization
        self.norm = layers.LayerNormalization(axis=-1)
        self.conv = layers.Conv2D(in_c, 1, padding="same")
        self.gau = GatedAttentionUnit(in_c, out_c, kernel_size=k_size)

    def call(self, x):
        dilated_features = [conv(x) for conv in self.convs]
        # Channel-wise concatenation without building a new layer per call.
        fused = tf.concat(dilated_features, axis=-1)
        fused = self.norm(fused)
        fused = self.conv(fused)
        return self.gau(fused)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(DilatedGatedAttention, self).get_config()
        config.update({
            "in_c": self.in_c,
            "out_c": self.out_c,
            "k_size": self.k_size,
            "dilated_ratio": list(self.dilated_ratio),
        })
        return config


In [None]:
# ------------------ IEA Block ------------------
class EAblock(layers.Layer):
    """External-attention block with a residual connection.

    Pipeline: 1x1 conv -> flatten spatial positions -> memory unit 1
    (expand to 4 * channels) -> double normalization -> memory unit 2
    (compress to channels) -> reshape -> 1x1 conv -> LayerNormalization ->
    add the block input -> GELU.

    The residual addition requires the input to already have ``channels``
    channels.

    Args:
        channels: channel count of the input and the output.
        name: optional layer name.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, channels, name=None, **kwargs):
        super(EAblock, self).__init__(name=name, **kwargs)
        self.channels = channels
        self.expand_channels = channels * 4

        # 1x1 Conv before attention
        self.conv1 = layers.Conv2D(channels, kernel_size=1, padding='same', activation=None)
        # Memory unit 1: expand
        self.linear_0 = layers.Conv1D(self.expand_channels, kernel_size=1, use_bias=False)
        # Memory unit 2: compress
        self.linear_1 = layers.Conv1D(channels, kernel_size=1, use_bias=False)
        # 1x1 Conv after attention
        self.conv2 = layers.Conv2D(channels, kernel_size=1, padding='same', activation=None)
        # Normalization
        self.norm = layers.LayerNormalization(axis=-1)
        # Activation
        self.act = layers.Activation('gelu')

    def call(self, x):
        # Save residual
        identity = x

        # 1x1 conv
        x = self.conv1(x)

        # Flatten spatial dimensions: (B, H, W, C) -> (B, H*W, C). The channel
        # size is passed statically so Conv1D sees a defined last dimension.
        shape = tf.shape(x)
        b, h, w = shape[0], shape[1], shape[2]
        c = self.channels
        x_flat = tf.reshape(x, [b, h * w, c])

        # Memory Unit 1: expand channels -> (B, HW, 4C)
        attn = self.linear_0(x_flat)
        # Double normalization: softmax over the spatial positions (axis 1),
        # then L1 normalization over the memory slots (last axis). Applying
        # both on the last axis would make the second step a no-op, because a
        # softmax output already sums to 1 along its own axis.
        attn = tf.nn.softmax(attn, axis=1)
        attn = attn / (1e-9 + tf.reduce_sum(attn, axis=-1, keepdims=True))

        # Memory Unit 2: compress back -> (B, HW, C)
        x_flat = self.linear_1(attn)

        # Reshape back to (B, H, W, C)
        x = tf.reshape(x_flat, [b, h, w, c])

        # 1x1 conv + normalization
        x = self.conv2(x)
        x = self.norm(x)

        # Add residual + activation
        x = x + identity
        x = self.act(x)

        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(EAblock, self).get_config()
        config.update({"channels": self.channels})
        return config


In [None]:

# ------------------ UNet with EA + DGA in Encoder, DGA + EA in Decoder ------------------
def unet_with_ea_dga(input_shape=(256,256,1)):
    """Build the U-Net variant with EA and DGA attention blocks.

    Encoder stages 1-3 are plain double convolutions; stages 4-6 add
    EAblock followed by DilatedGatedAttention. Decoder stages 7-9 apply
    DilatedGatedAttention followed by EAblock after the double convolution;
    stages 10-11 are plain. Every stage except the bottleneck (stage 6)
    contributes a skip connection to the mirrored decoder stage.

    Args:
        input_shape: shape of one input sample (height, width, channels).

    Returns:
        A Keras functional model named "UNet_with_EA_DGA" whose output is a
        single-channel sigmoid map.
    """
    def conv_pair(tensor, filters, prefix):
        # Two 3x3 ReLU convolutions named "<prefix>a" and "<prefix>b".
        for suffix in ("a", "b"):
            tensor = layers.Conv2D(filters, 3, activation='relu', padding='same',
                                   name=f"{prefix}{suffix}")(tensor)
        return tensor

    def upsample_and_merge(tensor, skip, filters, index):
        # 2x transposed-conv upsampling, then channel concat with the skip map.
        upsampled = layers.Conv2DTranspose(filters, 2, strides=(2,2), padding='same',
                                           name=f"up{index}")(tensor)
        return layers.concatenate([upsampled, skip], axis=3, name=f"concat{index}")

    inputs = layers.Input(input_shape, name="input_layer")
    x = inputs
    skips = []  # encoder outputs, consumed last-in-first-out by the decoder

    # ----- Encoder 1-3: plain convs -----
    for level, filters in ((1, 64), (2, 128), (3, 256)):
        x = conv_pair(x, filters, f"enc_c{level}")
        skips.append(x)
        x = layers.MaxPooling2D((2, 2), name=f"pool{level}")(x)

    # ----- Encoder 4-6: EAblock -> DGA -----
    for level, filters in ((4, 512), (5, 1024), (6, 1024)):
        x = conv_pair(x, filters, f"enc_c{level}")
        x = EAblock(filters, name=f"e_ablock_enc{level}")(x)
        x = DilatedGatedAttention(filters, filters, name=f"dga_block_enc{level}")(x)
        if level < 6:
            # Stage 6 is the bottleneck: no skip connection, no pooling.
            skips.append(x)
            x = layers.MaxPooling2D((2, 2), name=f"pool{level}")(x)

    # ----- Decoder 7-9: DGA -> EA -----
    for index, filters in ((7, 1024), (8, 512), (9, 256)):
        x = upsample_and_merge(x, skips.pop(), filters, index)
        x = conv_pair(x, filters, f"dec_c{index}")
        x = DilatedGatedAttention(filters, filters, name=f"dga_block_dec{index}")(x)
        x = EAblock(filters, name=f"e_ablock_dec{index}")(x)

    # ----- Decoder 10-11: plain convs -----
    for index, filters in ((10, 128), (11, 64)):
        x = upsample_and_merge(x, skips.pop(), filters, index)
        x = conv_pair(x, filters, f"dec_c{index}")

    outputs = layers.Conv2D(1, 1, activation='sigmoid', name="output_layer")(x)

    return models.Model(inputs=inputs, outputs=outputs, name="UNet_with_EA_DGA")

# ------------------ Test ------------------
# Build the network once for a 256x256 single-channel input and print its
# layer summary.
model = unet_with_ea_dga(input_shape=(256,256,1))
model.summary()


In [None]:
def dice_coefficient(y_true, y_pred, smooth=1e-6):
    """Dice coefficient computed over all elements of the flattened tensors.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor.
        smooth: constant added to numerator and denominator so the ratio is
            defined when both sums are zero.

    Returns:
        Scalar ``(2 * sum(y_true * y_pred) + smooth) /
        (sum(y_true) + sum(y_pred) + smooth)``.
    """
    truth = tf.keras.backend.flatten(y_true)
    pred = tf.keras.backend.flatten(y_pred)
    overlap = tf.reduce_sum(truth * pred)
    total = tf.reduce_sum(truth) + tf.reduce_sum(pred)
    return (2. * overlap + smooth) / (total + smooth)

def iou_metric(y_true, y_pred, smooth=1e-6):
    """Intersection-over-union computed over all elements of the flattened tensors.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor.
        smooth: constant added to numerator and denominator so the ratio is
            defined when the union is zero.

    Returns:
        Scalar ``(intersection + smooth) / (union + smooth)`` where
        ``union = sum(y_true) + sum(y_pred) - intersection``.
    """
    truth = tf.keras.backend.flatten(y_true)
    pred = tf.keras.backend.flatten(y_pred)
    overlap = tf.reduce_sum(truth * pred)
    combined = tf.reduce_sum(truth) + tf.reduce_sum(pred) - overlap
    return (overlap + smooth) / (combined + smooth)

def bce_dice_loss(y_true, y_pred):
    """Combined loss: binary cross-entropy plus (1 - Dice coefficient).

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor.

    Returns:
        Scalar loss value.
    """
    cross_entropy = tf.keras.losses.BinaryCrossentropy()
    bce_term = cross_entropy(y_true, y_pred)
    dice_term = 1 - dice_coefficient(y_true, y_pred)
    return bce_term + dice_term

In [None]:
# -------------------------
# Image loader
# -------------------------
def load_only_image(image_path):
    """Read a BMP file as a single-channel float tensor scaled by 1/255.

    The steps and their order (decode -> grayscale -> resize -> scale) and the
    target size (IMG_HEIGHT, IMG_WIDTH) match the ``load_image`` used to build
    the test dataset, so training and evaluation see identically preprocessed
    inputs.

    Args:
        image_path: path of the BMP file.

    Returns:
        Float32 tensor of shape (IMG_HEIGHT, IMG_WIDTH, 1).
    """
    img = tf.io.read_file(image_path)
    img = tf.image.decode_bmp(img, channels=3)   # decode as RGB
    img = tf.image.rgb_to_grayscale(img)         # convert to 1 channel
    img = tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])
    img = tf.cast(img, tf.float32) / 255.0
    return img

def load_image(image_path, mask_path):
    """Load an (image, mask) pair for training.

    Both files go through ``load_only_image``; the mask is then thresholded
    at 0.5 so the targets contain only 0.0 and 1.0. Without this step the
    resized grayscale mask keeps intermediate grey values, and the loss and
    the Dice/IoU metrics would be computed against soft targets.

    Args:
        image_path: path of the input image file.
        mask_path: path of the matching mask file.

    Returns:
        Tuple ``(img, mask)``.
    """
    img = load_only_image(image_path)
    mask = load_only_image(mask_path)
    mask = tf.where(mask > 0.5, 1.0, 0.0)  # binarize
    return img, mask


In [None]:
# -------------------------
# Attention map visualizer callback
# -------------------------
class AttentionMapLogger(Callback):
    """Callback that saves channel-averaged feature maps as PNG files.

    Every ``every_n_epochs`` epochs the sample image is run through a feature
    extractor and each returned feature map is written to ``save_dir``.

    The feature extractor is, in order of preference:
      * ``feature_model`` if given: a callable mapping the batched sample
        image to one feature map or a list of feature maps;
      * otherwise ``self.model.get_attention_features``, which the trained
        model must then provide.

    Args:
        sample_image: a single image without batch dimension.
        every_n_epochs: logging period in epochs.
        save_dir: output folder, created if missing.
        feature_model: optional feature extractor (see above).
    """

    def __init__(self, sample_image, every_n_epochs=5, save_dir="attention_maps",
                 feature_model=None):
        super().__init__()
        self.sample_image = tf.expand_dims(sample_image, axis=0)  # add batch dim
        self.every_n_epochs = every_n_epochs
        self.save_dir = save_dir
        self.feature_model = feature_model
        os.makedirs(save_dir, exist_ok=True)

    def on_train_begin(self, logs=None):
        # Fail before the first epoch rather than after `every_n_epochs`
        # epochs of training have already been spent.
        if self.feature_model is None and not hasattr(self.model, "get_attention_features"):
            raise AttributeError(
                "AttentionMapLogger needs either a `feature_model` argument or a "
                "model that defines `get_attention_features(image)`."
            )

    def save_attention_maps(self, feature_maps, stage_name, epoch):
        """Save one min-max normalized, channel-averaged map as a PNG.

        Args:
            feature_maps: tensor of shape (1, H, W, C).
            stage_name: label used in the title and the file name.
            epoch: epoch number used in the title and the file name.
        """
        # Imported here because the notebook only imports matplotlib in a
        # later cell than the one defining this class.
        import matplotlib.pyplot as plt

        fmap = feature_maps[0]  # remove batch
        fmap_mean = tf.reduce_mean(fmap, axis=-1)  # collapse channels -> (H,W)
        fmap_min = tf.reduce_min(fmap_mean)
        fmap_max = tf.reduce_max(fmap_mean)
        # The small constant avoids division by zero for a constant map.
        fmap_norm = (fmap_mean - fmap_min) / (fmap_max - fmap_min + 1e-6)

        plt.figure(figsize=(5,5))
        plt.imshow(fmap_norm, cmap='jet')
        plt.axis('off')
        plt.title(f"{stage_name} Epoch {epoch}")
        plt.savefig(os.path.join(self.save_dir, f"{stage_name}_epoch{epoch}.png"))
        plt.close()

    def _extract_features(self):
        """Return the list of feature maps for the stored sample image."""
        if self.feature_model is not None:
            features = self.feature_model(self.sample_image, training=False)
        else:
            features = self.model.get_attention_features(self.sample_image)
        # A single-output extractor returns a bare tensor; wrap it so the
        # caller never iterates over its batch dimension by mistake.
        if not isinstance(features, (list, tuple)):
            features = [features]
        return features

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every_n_epochs != 0:
            return

        for i, fmap in enumerate(self._extract_features()):
            stage_name = f"stage_{i+1}"
            self.save_attention_maps(fmap, stage_name, epoch + 1)





In [None]:
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import KFold
import numpy as np
import glob
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [None]:
# -------------------------
# Training loop with K-Fold
# -------------------------
EPOCHS = 300
BATCH_SIZE = 8  # replaces the value of 16 assigned in the data-loading cell

# Both callbacks track the validation Dice coefficient (higher is better).
early_stopping = EarlyStopping(
    monitor='val_dice_coefficient',
    patience=10,
    mode='max',
    restore_best_weights=True
)

# NOTE(review): this single checkpoint object is passed to every fold's fit()
# call, so "best_model.h5" presumably ends up holding the best epoch across
# all folds rather than per fold - confirm this is intended.
checkpoint = ModelCheckpoint(
    "best_model.h5",
    monitor='val_dice_coefficient',
    save_best_only=True,
    mode='max'
)

# Images and masks are paired by position after sorting their file names.
image_paths = np.array(sorted(glob.glob(os.path.join(TRAIN_IMAGES_DIR, "*.bmp"))))
mask_paths  = np.array(sorted(glob.glob(os.path.join(TRAIN_MASKS_DIR, "*.bmp"))))

# The folds index both arrays with the same indices, so a length mismatch
# would either raise an IndexError mid-run or silently mispair files.
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"Image/mask count mismatch: {len(image_paths)} images in {TRAIN_IMAGES_DIR} "
        f"but {len(mask_paths)} masks in {TRAIN_MASKS_DIR}"
    )

kfold = KFold(n_splits=5, shuffle=True, random_state=42)
fold_num = 1

# Layers whose outputs are plotted by the attention logger. The names must
# match the layer names assigned inside unet_with_ea_dga.
layer_names = [
    'enc_c1b',
    'enc_c2b',
    'enc_c3b',
    'enc_c4b',
    'e_ablock_enc4',
    'dga_block_enc4',
    'enc_c5b',
    'e_ablock_enc5',
    'dga_block_enc5',
    'enc_c6b',
    'e_ablock_enc6',
    'dga_block_enc6',
    'dec_c7b',
    'dga_block_dec7',
    'e_ablock_dec7',
    'dec_c8b',
    'dga_block_dec8',
    'e_ablock_dec8',
    'dec_c9b',
    'dga_block_dec9',
    'e_ablock_dec9',
    'dec_c10b',
    'dec_c11b'
]


def dataset_from_paths(images, masks, shuffle=True):
    """Build a batched (image, mask) dataset from two aligned path arrays.

    Args:
        images: array of image file paths.
        masks: array of mask file paths, aligned with ``images``.
        shuffle: when True, shuffle the pairs.

    Returns:
        A ``tf.data.Dataset`` batched with the global ``BATCH_SIZE``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((images, masks))
    if shuffle:
        # Shuffle the path strings before decoding so the buffer can cover
        # the whole fold without holding decoded images.
        dataset = dataset.shuffle(len(images))
    dataset = dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


class AttentionMapLogger(tf.keras.callbacks.Callback):
    """Plot channel-averaged feature maps of selected layers during training.

    Args:
        sample_image: batched sample image fed to ``feature_model``.
        every_n_epochs: plotting period in epochs.
        feature_model: model returning one feature map per entry of
            ``layer_names``.
        layer_names: layer names used in the plot titles.
    """

    def __init__(self, sample_image, every_n_epochs=5, feature_model=None, layer_names=None):
        super().__init__()
        self.sample_image = sample_image
        self.every_n_epochs = every_n_epochs
        self.feature_model = feature_model
        self.layer_names = layer_names

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.every_n_epochs != 0:
            return
        attention_features = self.feature_model(self.sample_image, training=False)
        for layer_name, fmap in zip(self.layer_names, attention_features):
            fmap_mean = tf.reduce_mean(fmap, axis=-1)  # average over channels
            fmap_img = tf.squeeze(fmap_mean).numpy()
            fig = plt.figure(figsize=(4,4))
            plt.imshow(fmap_img, cmap='jet')
            plt.title(f"Epoch {epoch+1} - Layer {layer_name}")
            plt.axis('off')
            plt.show()
            # Release the figure explicitly; many are created per logging epoch.
            plt.close(fig)


for train_idx, val_idx in kfold.split(image_paths):
    print(f"\n--- Fold {fold_num} ---")

    # Drop the Keras global state left over from the previous model so memory
    # does not accumulate across folds.
    tf.keras.backend.clear_session()

    train_images, train_masks = image_paths[train_idx], mask_paths[train_idx]
    val_images, val_masks = image_paths[val_idx], mask_paths[val_idx]

    # -------------------------
    # Dataset creation
    # -------------------------
    train_dataset = dataset_from_paths(train_images, train_masks)
    val_dataset = dataset_from_paths(val_images, val_masks, shuffle=False)

    # -------------------------
    # Build model
    # -------------------------
    model = unet_with_ea_dga(input_shape=(256,256,1))
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=bce_dice_loss,
        metrics=['accuracy', dice_coefficient, iou_metric]
    )

    # -------------------------
    # Feature extractor for attention visualization
    # -------------------------
    intermediate_outputs = [model.get_layer(name).output for name in layer_names]
    feature_model = tf.keras.Model(inputs=model.input, outputs=intermediate_outputs)

    # Sample image for attention visualization
    sample_image = load_only_image(val_images[0])
    sample_image = tf.expand_dims(sample_image, axis=0)  # Add batch dimension

    attention_logger = AttentionMapLogger(
        sample_image,
        every_n_epochs=5,
        feature_model=feature_model,
        layer_names=layer_names,
    )

    # -------------------------
    # Train
    # -------------------------
    start_time = time.time()
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        callbacks=[early_stopping, checkpoint, attention_logger]
    )
    end_time = time.time()

    # -------------------------
    # Model summary, GFLOPs, training time
    # -------------------------
    model.summary()
    try:
        from tensorflow.python.profiler import model_analyzer
        from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
        profile = model_analyzer.profile(
            model, options=ProfileOptionBuilder.float_operation()
        )
        print(f"GFLOPs: {profile.total_float_ops / 1e9}")
    except Exception as exc:
        # Best-effort measurement: report why it failed instead of hiding the
        # error (a bare `except:` would also swallow KeyboardInterrupt).
        print("GFLOPs calculation not available, use tensorflow profiler for exact computation."
              f" ({type(exc).__name__}: {exc})")

    print(f"Training time for fold {fold_num}: {end_time - start_time:.2f}s")
    fold_num += 1


In [None]:
# -------------------------
# Load best model
# -------------------------
# Rebuild the architecture and compile it with the same loss and metrics as
# during training.
model = unet_with_ea_dga(input_shape=(256, 256, 1))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=bce_dice_loss,
              metrics=['accuracy', dice_coefficient, iou_metric])

# "best_model.h5" is the file written by the ModelCheckpoint callback.
model.load_weights("best_model.h5")


# -------------------------
# Evaluate on test set
# -------------------------
results = model.evaluate(test_dataset)
print("\nTest Results:")
# Pair each reported value with its metric name.
for name, value in zip(model.metrics_names, results):
    print(f"{name}: {value:.4f}")

In [None]:
def display_predictions(model, dataset, num_samples=5):
    """Plot input, ground truth and thresholded prediction side by side.

    Args:
        model: object with a Keras-style ``predict(images, verbose=0)`` method.
        dataset: batched dataset yielding ``(images, masks)`` pairs.
        num_samples: maximum number of individual samples to display,
            whatever the batch size of ``dataset``.
    """
    shown = 0
    # At most `num_samples` batches are ever needed to show `num_samples` samples.
    for images, masks in dataset.take(num_samples):
        preds = model.predict(images, verbose=0)
        preds = (preds > 0.5).astype("float32")  # Binarize predictions

        for i in range(len(images)):
            if shown >= num_samples:
                return

            plt.figure(figsize=(12, 4))

            panels = (
                ("Input Image", images[i].numpy().squeeze()),
                ("Ground Truth", masks[i].numpy().squeeze()),
                ("Prediction", preds[i].squeeze()),
            )
            for column, (title, picture) in enumerate(panels, start=1):
                plt.subplot(1, 3, column)
                plt.title(title)
                plt.imshow(picture, cmap="gray")
                plt.axis("off")

            plt.show()
            shown += 1


# Plot input / ground truth / prediction panels for the test set
# (test_dataset was built with batch_size=1).
display_predictions(model, test_dataset, num_samples=5)