# **1. Install packages**

In [None]:
%%capture
!pip install segmentation-models-pytorch

# **2. Import libraries**

In [None]:
# Data handling
import pandas as pd
import numpy as np

# Data visualization
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import cv2

# Torch
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import segmentation_models_pytorch as smp
from torchinfo import summary

# os
import os

# Path
from pathlib import Path

# tqdm
from tqdm.auto import tqdm

# warnings
import warnings
warnings.filterwarnings("ignore")

# **3. Load data**

In [None]:
# We define a function to create a list of the paths of the images and masks.
def image_mask_path(image_path:str, mask_path:str):
    """Collect the ``*.png`` files of an image folder and of a mask folder.

    Args:
        image_path: directory containing the input images.
        mask_path: directory containing the segmentation masks.

    Returns:
        ``(image_paths, mask_paths)`` -- two lists of ``pathlib.Path`` objects,
        each sorted lexicographically. When images and masks share file names,
        the two lists therefore line up element by element.
    """
    def _sorted_pngs(folder: str):
        # Direct children only: Path.glob("*.png") does not recurse.
        return sorted(Path(folder).glob("*.png"))

    return _sorted_pngs(image_path), _sorted_pngs(mask_path)

In [None]:
image_path_train = "/kaggle/input/breast-cancer-semantic-segmentation-bcss/BCSS_512/train_512"
mask_path_train = "/kaggle/input/breast-cancer-semantic-segmentation-bcss/BCSS_512/train_mask_512"

IMAGE_PATH_LIST_TRAIN, MASK_PATH_LIST_TRAIN = image_mask_path(image_path_train, 
                                                              mask_path_train)

print(f'Total Images Train: {len(IMAGE_PATH_LIST_TRAIN)}')
print(f'Total Masks Train: {len(MASK_PATH_LIST_TRAIN)}')

# Images and masks are later paired by position (sorted order), so fail here with a
# clear message instead of deep inside the DataFrame / DataLoader code.
assert IMAGE_PATH_LIST_TRAIN, f"no .png images found in {image_path_train}"
assert len(IMAGE_PATH_LIST_TRAIN) == len(MASK_PATH_LIST_TRAIN), \
    "train images and train masks differ in count"

In [None]:
image_path_val = "/kaggle/input/breast-cancer-semantic-segmentation-bcss/BCSS_512/val_512"
mask_path_val = "/kaggle/input/breast-cancer-semantic-segmentation-bcss/BCSS_512/val_mask_512"

IMAGE_PATH_LIST_VAL, MASK_PATH_LIST_VAL = image_mask_path(image_path_val, 
                                                          mask_path_val)

print(f'Total Images Val: {len(IMAGE_PATH_LIST_VAL)}')
print(f'Total Masks Val: {len(MASK_PATH_LIST_VAL)}')

# Images and masks are later paired by position (sorted order), so fail here with a
# clear message instead of deep inside the DataFrame / DataLoader code.
assert IMAGE_PATH_LIST_VAL, f"no .png images found in {image_path_val}"
assert len(IMAGE_PATH_LIST_VAL) == len(MASK_PATH_LIST_VAL), \
    "val images and val masks differ in count"

In [None]:
# Distinct pixel (label) values of every training mask, read as single-channel images.
VALUES_UNIQUE_TRAIN = [
    np.unique(cv2.imread(str(mask_file), cv2.IMREAD_GRAYSCALE))
    for mask_file in MASK_PATH_LIST_TRAIN
]

# Merge the per-mask sets into one sorted array of label values seen in the split.
FINAL_VALUES_UNIQUE_TRAIN = np.concatenate(VALUES_UNIQUE_TRAIN)
print("Unique values Train:\n")
print(np.unique(FINAL_VALUES_UNIQUE_TRAIN))

In [None]:
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

# Class names; entry i is shown in the legend as "i: name".
labels = [
    "outside_roi", "tumor", "stroma", "lymphocytic_infiltrate", 
    "necrosis_or_debris", "glandular_secretions", "blood", "exclude", 
    "metaplasia_NOS", "fat", "plasma_cells", "other_immune_infiltrate", 
    "mucoid_material", "normal_acinus_or_duct", "lymphatics", "undetermined", 
    "nerve", "skin_adnexa", "blood_vessel", "angioinvasion", "dcis", "other"
]

# One distinct colour per label. tab20 only has 20 entries: indexing it with 20 or 21
# returns its "over" colour, which equals entry 19, so three labels would share a
# swatch. Top up with colours from tab20b instead.
colors = np.vstack([plt.cm.tab20(np.arange(20)),
                    plt.cm.tab20b(np.arange(len(labels) - 20))])

# Generate label patches for legend
legend_patches = [mpatches.Patch(color=colors[i], label=f"{i}: {labels[i]}")
                  for i in range(len(labels))]

# Plot legend in horizontal layout on an otherwise empty axes
fig, ax = plt.subplots(figsize=(12, 2))
ax.axis('off')

ax.legend(
    handles=legend_patches,
    loc='center',
    bbox_to_anchor=(0.5, 0.5),
    ncol=5,  # Number of columns in the legend
    title="Label Mapping"
)
ax.set_title("Label Color Mapping", fontsize=12, fontweight="bold")
plt.show()


In [None]:
# Display the first 5 training images (top row) and their masks (bottom row)
N_PREVIEW = 5
fig, ax = plt.subplots(nrows=2, ncols=N_PREVIEW, figsize=(25, 10))
for a in ax.flat:
    a.axis('off')  # also hides panels left empty if there are fewer images

preview_pairs = zip(IMAGE_PATH_LIST_TRAIN[:N_PREVIEW], MASK_PATH_LIST_TRAIN[:N_PREVIEW])
for i, (img_path, mask_path) in enumerate(preview_pairs):
    # Load and display the image (OpenCV reads BGR)
    img_rgb = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
    ax[0, i].imshow(img_rgb)
    ax[0, i].set_title(f"Image\nShape: {img_rgb.shape}", fontsize=10, fontweight="bold", color="black")

    # Load and display the mask. Without vmin/vmax imshow rescales every mask to its
    # own min/max, so equal colours would not mean equal labels across panels;
    # pin the range to the 22 label values 0..21.
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    ax[1, i].imshow(mask, vmin=0, vmax=21)
    ax[1, i].set_title(f"Mask\nShape: {mask.shape}", fontsize=10, fontweight="bold", color="black")

fig.tight_layout()
plt.show()


In [None]:
# We visualize the first 20 training images with their mask superimposed (30 % opacity).
fig, ax = plt.subplots(nrows = 10, ncols = 2, figsize = (12,30))
ax = ax.ravel()
for a in ax:
    a.axis('off')  # also hides panels left empty if there are fewer images

for i, (img_path, mask_path) in enumerate(zip(IMAGE_PATH_LIST_TRAIN[:20], MASK_PATH_LIST_TRAIN[:20])):
    img_rgb = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

    ax[i].imshow(img_rgb)
    # Fixed range 0..21 so a label value maps to the same colour in every panel.
    ax[i].imshow(mask, alpha = 0.30, vmin = 0, vmax = 21)

fig.tight_layout()
# plt.show() rather than fig.show(): the latter is meant for GUI backends and only
# warns (silenced above) under the inline backend.
plt.show()

# **4. Preprocessing**

**We will create dataframes for both data sets.**

In [None]:
# One row per sample: image path in the first column, mask path in the second
# (CustomImageMaskDataset reads them by position).
data_train, data_val = (
    pd.DataFrame({'Image': image_list, 'Mask': mask_list})
    for image_list, mask_list in ((IMAGE_PATH_LIST_TRAIN, MASK_PATH_LIST_TRAIN),
                                  (IMAGE_PATH_LIST_VAL, MASK_PATH_LIST_VAL))
)

**Next, we look up which transformations were applied to the images when the encoder was pre-trained, so that we can apply the same ones to our images.**

In [None]:
# Ask smp for the input preprocessing that goes with resnet34's ImageNet weights,
# then display it so its parameters can be copied below.
preprocess_input = smp.encoders.get_preprocessing_fn(
    encoder_name="resnet34",
    pretrained="imagenet",
)
preprocess_input

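**We replicate this preprocessing with torchvision transforms: pixels are scaled to [0, 1] and normalized with the ImageNet mean and standard deviation. In addition, images and masks are resized to 224×224.**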

In [None]:
RESIZE = (224, 224)
# ImageNet channel statistics, as used by the resnet34 preprocessing displayed above.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Images: resize (bilinear by default), scale to [0, 1] and normalize per channel.
image_transforms = transforms.Compose([transforms.Resize(RESIZE),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean = MEAN, std = STD)])

# Masks store integer class ids, so they must be resized with nearest-neighbour
# interpolation. The default (bilinear) averages neighbouring ids at every region
# boundary and produces label values belonging to unrelated classes.
mask_transforms = transforms.Compose([transforms.Resize(RESIZE,
                                                        interpolation=transforms.InterpolationMode.NEAREST),
                                      transforms.PILToTensor()])

**We define a Dataset that loads each image/mask pair and applies these transformations.**

- **Dataset**

In [None]:
class CustomImageMaskDataset(Dataset):
    """Image/mask pairs read from a two-column DataFrame.

    Column 0 of ``data`` holds image paths and column 1 the matching mask paths.
    Images are converted to RGB before ``image_transforms`` is applied; masks are
    handed to ``mask_transforms`` in the mode they are stored in on disk.
    ``__getitem__`` returns ``(transformed_image, transformed_mask)``.
    """

    def __init__(self, data:pd.DataFrame, image_transforms, mask_transforms):
        self.data, self.image_transforms, self.mask_transforms = data, image_transforms, mask_transforms

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        image_file, mask_file = self.data.iloc[idx, 0], self.data.iloc[idx, 1]
        rgb_image = Image.open(image_file).convert("RGB")
        label_map = Image.open(mask_file)
        return self.image_transforms(rgb_image), self.mask_transforms(label_map)

In [None]:
# Both splits share the same image and mask transforms.
train_dataset, val_dataset = (
    CustomImageMaskDataset(split_frame, image_transforms, mask_transforms)
    for split_frame in (data_train, data_val)
)

- **DataLoader**

In [None]:
BATCH_SIZE = 64
# os.cpu_count() returns None when the count cannot be determined; 0 means
# "load in the main process", which DataLoader always accepts.
NUM_WORKERS = os.cpu_count() or 0

train_dataloader = DataLoader(dataset = train_dataset, batch_size = BATCH_SIZE, 
                              shuffle = True, num_workers = NUM_WORKERS)

# Validation needs no shuffling: a fixed order keeps the batch composition identical
# every epoch, so the batch-averaged validation loss/accuracy (which drive early
# stopping) are not perturbed by random regrouping.
val_dataloader = DataLoader(dataset = val_dataset, batch_size = BATCH_SIZE, 
                            shuffle = False, num_workers = NUM_WORKERS)

In [None]:
# Pull a single training batch and display the shapes of its images and masks.
first_batch = iter(train_dataloader)
batch_images, batch_masks = next(first_batch)

(batch_images.shape, batch_masks.shape)

# **5. UNet++ model**

In [None]:
# Define UNet++ Model: ImageNet-pretrained resnet34 encoder, 22 output classes, raw logits.
model = smp.UnetPlusPlus(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=3,
    classes=22,
    activation=None
)


def show_model_summary(net: nn.Module) -> None:
    """Print a torchinfo summary of `net` for a (BATCH_SIZE, 3, 224, 224) input.

    torchinfo's default verbosity depends on the environment, and its returned
    table is only rendered when it is the cell's last expression. Ask for a
    quiet summary and print it explicitly so each call shows exactly one table.
    """
    stats = summary(
        model=net,
        input_size=[BATCH_SIZE, 3, 224, 224],
        col_width=15,
        col_names=['input_size', 'output_size', 'num_params', 'trainable'],
        row_settings=['var_names'],
        verbose=0
    )
    print(stats)


# Visualize Model Summary
show_model_summary(model)

# Freeze Encoder Parameters (only the decoder and segmentation head are trained)
for param in model.encoder.parameters():
    param.requires_grad = False

# Verify Freezing
show_model_summary(model)

# Early Stopping Class
class EarlyStopping:
    """Track the validation loss and signal when it stops improving.

    An epoch counts as an improvement only if ``val_loss`` is lower than the best
    loss seen so far by MORE than ``delta``. Every improvement (and the first call)
    writes ``model.state_dict()`` to ``path``. After ``patience`` consecutive epochs
    without improvement, ``early_stop`` is set to True.

    Args:
        patience: number of non-improving epochs tolerated before stopping.
        delta: minimum decrease of the loss that counts as an improvement.
        path: file the best ``state_dict`` is saved to.
    """

    def __init__(self, patience: int = 5, delta: float = 0.0001, path: str = "best_model.pth"):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.best_score = None   # lowest validation loss accepted as an improvement
        self.counter = 0         # consecutive epochs without improvement
        self.early_stop = False

    def __call__(self, val_loss, model):
        if self.best_score is None:
            self.best_score = val_loss
            self.save_checkpoint(model)

        elif val_loss >= self.best_score - self.delta:
            # Not better than the best by more than `delta`. (Comparing against
            # best + delta would accept slightly WORSE losses as improvements.)
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = val_loss
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        """Write the model's parameters (state_dict only) to ``self.path``."""
        torch.save(model.state_dict(), self.path)

# Initialize Early Stopping: tolerate 20 epochs without a strictly lower validation
# loss; the best weights go to best_model.pth in the working directory.
early_stopping = EarlyStopping(delta=0.0, patience=20)


In [None]:
# Define Training and Validation Steps

def _resolve_device(device):
    """Return `device`, or the notebook-level DEVICE when `device` is None."""
    return DEVICE if device is None else device


def _pixel_accuracy(logit_mask: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of pixels whose arg-max class equals the target label.

    `logit_mask` has the class axis at dim 1; `target` holds class ids with the
    class axis already removed. This is plain pixel accuracy. smp's multiclass
    micro "accuracy" ((tp + tn) / total over one-vs-rest stats) counts each pixel
    once per class, so true negatives dominate and the value is heavily inflated.
    """
    return (logit_mask.argmax(dim=1) == target).float().mean().item()


def train_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, optimizer: torch.optim.Optimizer, device=None):
    """Run one training epoch.

    Args:
        model: segmentation network returning per-class logits at dim 1.
        dataloader: yields ``(images, masks)``; masks may carry a singleton
            channel axis at dim 1, which is removed here.
        loss_fn: loss taking ``(logits, class_ids)``.
        optimizer: optimizer over the trainable parameters of `model`.
        device: device to run on; defaults to the notebook-level ``DEVICE``.

    Returns:
        ``(mean loss, mean pixel accuracy)``, each averaged over batches.
    """
    device = _resolve_device(device)
    model.train()
    train_loss = 0.0
    train_accuracy = 0.0

    for X, y in dataloader:
        X = X.to(device=device, dtype=torch.float32)
        # Drop only the channel axis: a bare squeeze() would also drop the batch
        # axis when the final batch has a single sample and break the loss.
        y = y.to(device=device, dtype=torch.long).squeeze(1)

        optimizer.zero_grad()
        logit_mask = model(X)
        loss = loss_fn(logit_mask, y)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        train_accuracy += _pixel_accuracy(logit_mask.detach(), y)

    train_loss /= len(dataloader)
    train_accuracy /= len(dataloader)

    return train_loss, train_accuracy


def val_step(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, 
             loss_fn: torch.nn.Module, device=None):
    """Evaluate `model` on `dataloader` without gradient tracking.

    Same inputs and return value as ``train_step`` (minus the optimizer):
    ``(mean loss, mean pixel accuracy)`` averaged over batches.
    """
    device = _resolve_device(device)
    model.eval()
    val_loss = 0.0
    val_accuracy = 0.0

    with torch.inference_mode():
        for X, y in dataloader:
            X = X.to(device=device, dtype=torch.float32)
            y = y.to(device=device, dtype=torch.long).squeeze(1)
            logit_mask = model(X)
            val_loss += loss_fn(logit_mask, y).item()
            val_accuracy += _pixel_accuracy(logit_mask, y)

    val_loss /= len(dataloader)
    val_accuracy /= len(dataloader)
    return val_loss, val_accuracy


def train_model(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, 
               val_dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer, early_stopping, epochs: int = 10, device=None):
    """Train for up to `epochs` epochs, validating after each one.

    `early_stopping` is called with ``(val_loss, model)`` every epoch and must
    expose an ``early_stop`` flag; training stops once that flag is set.

    Returns:
        dict with per-epoch lists ``train_loss``, ``train_accuracy``,
        ``val_loss`` and ``val_accuracy`` (including the epoch that triggered
        early stopping).
    """
    results = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}

    for epoch in tqdm(range(epochs)):
        train_loss, train_accuracy = train_step(model, train_dataloader, loss_fn, optimizer, device)
        val_loss, val_accuracy = val_step(model, val_dataloader, loss_fn, device)

        print(f'Epoch: {epoch + 1} | '
              f'Train Loss: {train_loss:.4f} | '
              f'Train Accuracy: {train_accuracy:.4f} | '
              f'Val Loss: {val_loss:.4f} | '
              f'Val Accuracy: {val_accuracy:.4f}')

        # Record before checking early stopping so the last epoch is not lost.
        results['train_loss'].append(train_loss)
        results['train_accuracy'].append(train_accuracy)
        results['val_loss'].append(val_loss)
        results['val_accuracy'].append(val_accuracy)

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early Stopping triggered!")
            break

    return results

In [None]:
# CUDA: run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DEVICE

In [None]:

# Set Seeds for Reproducibility
SEED = 42
EPOCHS = 50  # Adjust as needed
torch.cuda.manual_seed(SEED)
torch.manual_seed(SEED)
# NOTE(review): `model` (including its randomly initialised decoder) is built in an
# earlier cell, before this seed is set, so weight initialisation is not covered by
# SEED -- only randomness consumed from here on (e.g. DataLoader shuffling).

# Move the model to the target device BEFORE building the optimizer, as the PyTorch
# docs recommend, so the optimizer is constructed over the on-device parameters.
model = model.to(DEVICE)

# Define Loss Function and Optimizer (only parameters left trainable, i.e. not the frozen encoder)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# Train the Model
RESULTS = train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    early_stopping=early_stopping,
    epochs=EPOCHS
)

# 6. Evaluation

In [None]:
# Plot Loss and Accuracy
def loss_and_metric_plot(results: dict):
    """Plot per-epoch training curves side by side: loss (left), accuracy (right).

    `results` must provide the sequences 'train_loss', 'train_accuracy',
    'val_loss' and 'val_accuracy', as returned by ``train_model``.
    """
    curves = {key: results[key] for key in ('train_loss', 'train_accuracy', 'val_loss', 'val_accuracy')}

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))

    # (axes, key suffix, label used in legend/y-axis, panel title)
    panels = [
        (axes[0], 'loss', "Loss", "CrossEntropyLoss"),
        (axes[1], 'accuracy', "Accuracy", "Accuracy"),
    ]
    for panel, suffix, label, title in panels:
        panel.plot(curves[f'train_{suffix}'], label=f"Train {label}")
        panel.plot(curves[f'val_{suffix}'], label=f"Val {label}")
        panel.set_title(title, fontsize=12, fontweight="bold")
        panel.set_xlabel("Epoch", fontsize=10, fontweight="bold")
        panel.set_ylabel(label, fontsize=10, fontweight="bold")
        panel.legend()

    plt.tight_layout()
    plt.show()

loss_and_metric_plot(RESULTS)

In [None]:
image_path_test = "/kaggle/input/breast-cancer-semantic-segmentation-bcss/BCSS/test"

# Sorted (like the train/val lists) so sample order -- and every index-based plot
# below -- is deterministic; Path.glob yields files in filesystem order.
IMAGE_PATH_LIST_TEST = sorted(Path(image_path_test).glob("*.png"))

print(f'Total Images Test: {len(IMAGE_PATH_LIST_TEST)}')

In [None]:
class CustomTestDataset(Dataset):
    """Unlabelled images listed in the first column of a DataFrame.

    Each item is the image at that path, converted to RGB and passed through
    ``image_transforms``; no mask is returned.
    """

    def __init__(self, data:pd.DataFrame, image_transforms):
        self.data, self.image_transforms = data, image_transforms

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        rgb_image = Image.open(self.data.iloc[idx, 0]).convert("RGB")
        return self.image_transforms(rgb_image)

In [None]:
def predictions_mask(test_dataloader: torch.utils.data.DataLoader,
                     checkpoint_path: str = "/kaggle/working/best_model.pth",
                     device=None):
    """Predict a class-id mask for every image yielded by `test_dataloader`.

    Rebuilds the training architecture (UNet++, resnet34 encoder, 22 classes,
    raw logits), loads the saved ``state_dict`` and runs inference batch by batch.

    Args:
        test_dataloader: yields batches of preprocessed image tensors (no masks).
        checkpoint_path: file holding the ``state_dict`` saved during training.
        device: device to run on; defaults to the notebook-level ``DEVICE``.

    Returns:
        CPU tensor with the arg-max class per pixel for each image, concatenated
        along dim 0 in dataloader order.
    """
    device = DEVICE if device is None else device

    # map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    loaded_model = smp.UnetPlusPlus(
        encoder_name="resnet34",
        encoder_weights=None,  # weights are loaded from checkpoint
        in_channels=3,
        classes=22,  # must match the trained model
        activation=None
    )
    loaded_model.load_state_dict(checkpoint)
    loaded_model.to(device=device)
    loaded_model.eval()

    y_pred_mask = []
    with torch.inference_mode():
        for X in tqdm(test_dataloader, total=len(test_dataloader)):
            X = X.to(device=device, dtype=torch.float32)
            # argmax of the logits equals argmax of their softmax, so skip the softmax.
            y_pred_mask.append(loaded_model(X).argmax(dim=1).cpu())

    y_pred_mask = torch.cat(y_pred_mask)
    print(f'Total Images Test: {len(y_pred_mask)}')
    return y_pred_mask

# Prepare Test Data
data_test = pd.DataFrame({'Image': IMAGE_PATH_LIST_TEST})
# A bare `data_test.head()` mid-cell is discarded; display it explicitly.
display(data_test.head())

# Create Test Dataset and DataLoader (fixed order so predictions align with data_test rows)
test_dataset = CustomTestDataset(data_test, image_transforms)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Execute Predictions
y_pred_mask = predictions_mask(test_dataloader)

In [None]:
# Visualize Original Image, Ground Truth Mask (when available) and Predicted Mask
n_rows = 10

# Ground truth for a test image is looked up BY FILE NAME. Pairing with
# MASK_PATH_LIST_TRAIN by index would show unrelated training masks.
# NOTE(review): test masks are assumed to sit in a "test_mask" folder next to the
# test images -- confirm against the dataset layout. If no matching mask is found,
# the middle panel is left blank.
test_mask_dir = Path(image_path_test).parent / "test_mask"
test_mask_lookup = {p.name: p for p in test_mask_dir.glob("*.png")} if test_mask_dir.is_dir() else {}

fig, ax = plt.subplots(nrows=n_rows, ncols=3, figsize=(18, 35))
for a in ax.flat:
    a.axis('off')

for index, img_path in enumerate(data_test['Image'].head(n_rows)):
    # Original Image
    img_rgb = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
    ax[index, 0].imshow(img_rgb)
    ax[index, 0].set_title("Original Image", fontsize=12, fontweight="bold", color="black")

    # Ground Truth Mask -- same colour scale as the prediction (labels 0..21)
    gt_path = test_mask_lookup.get(Path(img_path).name)
    if gt_path is not None:
        ground_truth_mask = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
        ax[index, 1].imshow(ground_truth_mask, cmap='jet', vmin=0, vmax=21)
        ax[index, 1].set_title("Ground Truth Mask", fontsize=12, fontweight="bold", color="black")
    else:
        ax[index, 1].set_title("Ground Truth Mask\n(Not Available)", fontsize=12, fontweight="bold", color="black")

    # Predicted Mask (fixed vmin/vmax so colours mean the same class in every panel)
    ax[index, 2].imshow(y_pred_mask[index].squeeze().numpy(), cmap='jet', vmin=0, vmax=21)
    ax[index, 2].set_title("Predicted Mask", fontsize=12, fontweight="bold", color="black")

fig.tight_layout()
plt.show()

In [None]:
# Visualize 4 Original Images, Ground Truth Masks (when available), and Predicted Masks
n_rows = 4

# Ground truth for a test image is looked up BY FILE NAME. Pairing with
# MASK_PATH_LIST_TRAIN by index would show unrelated training masks.
# NOTE(review): test masks are assumed to sit in a "test_mask" folder next to the
# test images -- confirm against the dataset layout. If no matching mask is found,
# the middle panel is left blank.
test_mask_dir = Path(image_path_test).parent / "test_mask"
test_mask_lookup = {p.name: p for p in test_mask_dir.glob("*.png")} if test_mask_dir.is_dir() else {}

fig, ax = plt.subplots(nrows=n_rows, ncols=3, figsize=(18, 25))
for a in ax.flat:
    a.axis('off')

for index, img_path in enumerate(data_test['Image'].head(n_rows)):
    # Original Image
    img_rgb = cv2.cvtColor(cv2.imread(str(img_path)), cv2.COLOR_BGR2RGB)
    ax[index, 0].imshow(img_rgb)
    ax[index, 0].set_title("Original Image", fontsize=12, fontweight="bold", color="black")

    # Ground Truth Mask -- same colour scale as the prediction (labels 0..21)
    gt_path = test_mask_lookup.get(Path(img_path).name)
    if gt_path is not None:
        ground_truth_mask = cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
        ax[index, 1].imshow(ground_truth_mask, cmap='jet', vmin=0, vmax=21)
        ax[index, 1].set_title("Ground Truth Mask", fontsize=12, fontweight="bold", color="black")
    else:
        ax[index, 1].set_title("Ground Truth Mask\n(Not Available)", fontsize=12, fontweight="bold", color="black")

    # Predicted Mask (fixed vmin/vmax so colours mean the same class in every panel)
    ax[index, 2].imshow(y_pred_mask[index].squeeze().numpy(), cmap='jet', vmin=0, vmax=21)
    ax[index, 2].set_title("Predicted Mask", fontsize=12, fontweight="bold", color="black")

fig.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import pandas as pd
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import segmentation_models_pytorch as smp
import torch.nn as nn

# Assuming the following variables are already defined:
# DEVICE, CustomTestDataset, image_transforms, BATCH_SIZE, NUM_WORKERS,
# IMAGE_PATH_LIST_TEST, MASK_PATH_LIST_TRAIN, y_pred_mask

def visualize_predictions(data_test, y_pred_mask, mask_path_list_train, num_samples=5, num_classes=22):
    """
    Visualize test images, ground-truth masks and predicted masks, one sample per column.

    Row 1 shows the test image, row 2 its ground-truth mask and row 3 the
    predicted mask. Both mask rows share the colour range 0 .. num_classes - 1,
    so equal colours mean equal classes.

    Ground truth is matched BY FILE NAME: a path from ``mask_path_list_train`` is
    shown for a test image only if both have the same file name. Masks are not
    paired by list position, because that list need not follow the test-image
    order (or even describe the test images at all). Without a match the panel is
    marked "(Not Available)".

    Parameters:
    - data_test (pd.DataFrame): DataFrame with an 'Image' column of test image paths.
    - y_pred_mask (torch.Tensor): predicted class-id masks, one per row of data_test.
    - mask_path_list_train (list): candidate ground-truth mask paths, matched by file name.
    - num_samples (int): number of samples (columns) to visualize.
    - num_classes (int): number of label values; fixes the mask colour range.
    """
    # Ensure num_samples does not exceed available samples
    num_samples = min(num_samples, len(data_test), len(y_pred_mask))
    if num_samples < 1:
        print("Nothing to visualize.")
        return

    gt_by_name = {Path(p).name: p for p in mask_path_list_train}
    mask_style = dict(cmap='jet', vmin=0, vmax=num_classes - 1)

    # squeeze=False keeps `axes` 2-D even when num_samples == 1
    fig, axes = plt.subplots(nrows=3, ncols=num_samples, figsize=(5 * num_samples, 15), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    for col in range(num_samples):
        # --- Row 1: Original Image ---
        img_path = data_test.iloc[col]['Image']
        img_bgr = cv2.imread(str(img_path))
        if img_bgr is None:
            print(f"Warning: Image at {img_path} could not be loaded.")
            axes[0, col].set_title(f"Sample {col+1}\nOriginal Image\n(Not Available)", fontsize=12, fontweight="bold")
        else:
            axes[0, col].imshow(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
            axes[0, col].set_title(f"Sample {col+1}\nOriginal Image", fontsize=12, fontweight="bold")

        # --- Row 2: Ground Truth Mask (matched by file name) ---
        gt_path = gt_by_name.get(Path(img_path).name)
        ground_truth_mask = None if gt_path is None else cv2.imread(str(gt_path), cv2.IMREAD_GRAYSCALE)
        if ground_truth_mask is not None:
            axes[1, col].imshow(ground_truth_mask, **mask_style)
            axes[1, col].set_title(f"Sample {col+1}\nGround Truth Mask", fontsize=12, fontweight="bold")
        else:
            axes[1, col].set_title(f"Sample {col+1}\nGround Truth Mask\n(Not Available)", fontsize=12, fontweight="bold")

        # --- Row 3: Predicted Mask ---
        predicted_mask = y_pred_mask[col].squeeze().numpy()
        axes[2, col].imshow(predicted_mask, **mask_style)
        axes[2, col].set_title(f"Sample {col+1}\nPredicted Mask", fontsize=12, fontweight="bold")

    # Adjust layout for better spacing
    plt.tight_layout()
    plt.show()

# Example Usage:

# `data_test` and `y_pred_mask` come from the prediction cell above (same test
# DataFrame, same checkpoint). Reuse them instead of rebuilding the dataset and
# re-running inference over the whole test set.

# Visualize Predictions
visualize_predictions(
    data_test=data_test,
    y_pred_mask=y_pred_mask,
    mask_path_list_train=MASK_PATH_LIST_TRAIN,
    num_samples=5  # Adjust the number of samples as needed
)