In [None]:
# Standard libraries
import os
import shutil
import random
from pathlib import Path
import warnings
import time

warnings.filterwarnings("ignore")

# Image processing
import cv2
from PIL import Image

# Data handling
import numpy as np
import pandas as pd

# Plotting
import matplotlib.pyplot as plt

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader


# Torchvision
from torchvision.models.efficientnet import efficientnet_b0, EfficientNet_B0_Weights

# Albumentations
import albumentations as album
from albumentations.pytorch import ToTensorV2

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# Select the compute device: the GPU when CUDA is usable, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

In [None]:
# Destructive: deletes every non-hidden entry under /kaggle/working before the
# notebook writes its own outputs (resized dataset, checkpoint) there.
!rm -rf /kaggle/working/*

# Dataset Processing

In [None]:
src_base = "/kaggle/input/combined-dataset-pt1/S5_STARE_Dataset"  # root of the raw dataset that gets resized
EPOCHS = 200  # upper bound on training epochs; early stopping may end training sooner
pat = 50  # patience (in epochs) handed to EarlyStopping

In [None]:
import os
from PIL import Image

# --------------------------
# PARAMETERS
# --------------------------
working_base = "/kaggle/working"                                  # writable output root
resized_base = os.path.join(working_base, "Resized")             # destination tree for the resized copy of the dataset
target_size = (256, 256)                                         # size handed to PIL's Image.resize for every file

# --------------------------
# FUNCTIONS
# --------------------------
def is_image_file(filename):
    """Return True when ``filename`` ends with a supported image extension.

    The comparison is case-insensitive; supported extensions are .png, .jpg,
    .jpeg and .bmp.
    """
    lowered = filename.lower()
    supported_suffixes = ('.png', '.jpg', '.jpeg', '.bmp')
    return any(lowered.endswith(suffix) for suffix in supported_suffixes)

def resize_image(src_path, dst_path, is_mask=False, size=(256, 256)):
    """Resize the image stored at ``src_path`` and write it to ``dst_path``.

    Args:
        src_path: Path of the file to read.
        dst_path: Path the resized file is saved to (format chosen by PIL from
            the extension).
        is_mask: When True, nearest-neighbour resampling is used so no new
            pixel values are interpolated into the mask; otherwise Lanczos
            resampling is used.
        size: Target size passed to ``Image.resize``.
    """
    resample_filter = Image.NEAREST if is_mask else Image.LANCZOS
    with Image.open(src_path) as source:
        resized = source.resize(size, resample_filter)
        resized.save(dst_path)

def process_dataset(src_base, dst_base, size=None):
    """Mirror the folder tree under ``src_base`` into ``dst_base`` with resized images.

    Every folder that contains at least one file is re-created under
    ``dst_base``. Files accepted by ``is_image_file`` are resized with
    ``resize_image``; all other files are ignored.

    A folder is treated as a mask folder (nearest-neighbour resampling) when
    'mask' or 'groundtruth' appears, case-insensitively, in its path *within
    the dataset*: the dataset folder's own name plus the path below it.
    Directories above ``src_base`` are deliberately not inspected, so an
    unrelated ancestor whose name happens to contain "mask" cannot turn every
    image into a mask.

    Args:
        src_base: Root folder of the source dataset.
        dst_base: Root folder the resized copy is written to.
        size: Target size forwarded to ``resize_image``. Defaults to the
            module-level ``target_size``.
    """
    if size is None:
        size = target_size

    dataset_name = os.path.basename(os.path.normpath(src_base))
    mask_markers = ('mask', 'groundtruth')

    for root, _dirs, files in os.walk(src_base):
        if not files:
            continue

        # Re-create the same relative folder below the destination root.
        rel_path = os.path.relpath(root, src_base)
        dst_folder = os.path.join(dst_base, rel_path)
        os.makedirs(dst_folder, exist_ok=True)

        # Classify using only the dataset-relative part of the path.
        scoped_path = os.path.join(dataset_name, rel_path).lower()
        is_mask = any(marker in scoped_path for marker in mask_markers)

        for f in files:
            if not is_image_file(f):
                continue
            src_path = os.path.join(root, f)
            dst_path = os.path.join(dst_folder, f)
            resize_image(src_path, dst_path, is_mask=is_mask, size=size)

# --------------------------
# EXECUTION
# --------------------------
# Write a resized copy of the dataset under src_base into resized_base.
process_dataset(src_base, resized_base)
print("Dataset processed and resized successfully!")


# Model Definition

In [None]:
class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A per-channel convolution (``groups=in_channels``) is followed by a 1x1
    pointwise convolution that maps ``in_channels`` to ``out_channels``.
    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` apply to the
    depthwise stage only; ``bias`` applies to both stages.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, bias=False):
        super().__init__()
        # Spatial filtering: one filter per input channel.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # Channel mixing: 1x1 convolution to the requested width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        spatially_filtered = self.depthwise(x)
        return self.pointwise(spatially_filtered)

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Runs the input through parallel branches and fuses them:
      * one 1x1 convolution branch,
      * one dilated ``SeparableConv2d`` branch per entry in ``atrous_rates``
        (padding equals the dilation rate),
      * one global-average-pooling branch whose 1x1 output is upsampled back
        to the input's spatial size.
    The branch outputs are concatenated along the channel axis and projected
    to ``out_channels`` by a 1x1 convolution, followed by BatchNorm, ReLU and
    ``Dropout2d(0.5)``.
    """

    def __init__(self, in_channels, out_channels, atrous_rates):
        super().__init__()

        def norm_act():
            # Fresh BatchNorm + ReLU pair appended to every branch.
            return [nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]

        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                *norm_act(),
            )
        ]
        for rate in atrous_rates:
            branches.append(
                nn.Sequential(
                    SeparableConv2d(in_channels, out_channels, 3, padding=rate, dilation=rate, bias=False),
                    *norm_act(),
                )
            )
        # Image-level branch; must stay last because forward() upsamples the final branch.
        branches.append(
            nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, out_channels, 1, bias=False),
                *norm_act(),
            )
        )
        self.convs = nn.ModuleList(branches)

        branch_count = len(atrous_rates) + 2
        self.project = nn.Sequential(
            nn.Conv2d(branch_count * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.5),
        )

    def forward(self, x):
        spatial_size = x.shape[2:]
        pooled_index = len(self.convs) - 1
        branch_outputs = []
        for index, branch in enumerate(self.convs):
            out = branch(x)
            if index == pooled_index:
                # Only the pooled branch has lost its spatial extent.
                out = F.interpolate(out, size=spatial_size, mode='bilinear', align_corners=True)
            branch_outputs.append(out)
        return self.project(torch.cat(branch_outputs, dim=1))

class MFFBlock(nn.Module):
    """Fuses a low-level and a high-level feature map and re-weights the channels.

    Both inputs are projected to ``out_channels`` with 1x1 convolutions (the
    high-level map is first upsampled to the low-level map's spatial size),
    summed, passed through a 3x3 conv + BatchNorm + ReLU, and finally scaled
    per channel by a squeeze-and-excitation style gate with a reduction
    factor of 8.
    """

    def __init__(self, in_channels_low, in_channels_high, out_channels):
        super().__init__()
        self.low_proj = nn.Conv2d(in_channels_low, out_channels, kernel_size=1, bias=False)
        self.high_proj = nn.Conv2d(in_channels_high, out_channels, kernel_size=1, bias=False)
        self.fusion = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        bottleneck = out_channels // 8
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, bottleneck, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(bottleneck, out_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, low_feat, high_feat):
        # Bring the high-level map to the low-level resolution before projecting it.
        target_hw = low_feat.shape[2:]
        upsampled_high = F.interpolate(high_feat, size=target_hw, mode='bilinear', align_corners=True)
        fused = self.fusion(self.low_proj(low_feat) + self.high_proj(upsampled_high))
        channel_gate = self.se(fused)
        return fused * channel_gate

class CAFSEBlock(nn.Module):
    """Refines decoder features with a gated residual taken from the ASPP output.

    ``coarse`` (3x3 conv + BatchNorm + ReLU) is applied to the ASPP features
    after they are upsampled to the decoder resolution; ``fine`` (1x1 conv +
    BatchNorm + Sigmoid) turns the decoder features into a gate in (0, 1).
    The block returns ``decoder_feat + coarse * fine``. Both inputs must carry
    ``channels`` channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.coarse = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        )
        self.fine = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1),
            nn.BatchNorm2d(channels),
            nn.Sigmoid(),
        )

    def forward(self, decoder_feat, aspp_feat):
        decoder_hw = decoder_feat.shape[2:]
        upsampled_aspp = F.interpolate(aspp_feat, size=decoder_hw, mode='bilinear', align_corners=True)
        context = self.coarse(upsampled_aspp)
        gate = self.fine(decoder_feat)
        return decoder_feat + context * gate

class Decoder(nn.Module):
    """Decoder stage that merges upsampled features with a low-level feature map.

    ``x`` is upsampled to the spatial size of ``low_level_feat``, reduced to
    48 channels by a 1x1 conv + BatchNorm + ReLU, concatenated with
    ``low_level_feat`` and refined by two separable 3x3 conv blocks followed
    by ``Dropout2d(0.3)``.

    The first fusing convolution is fixed at 96 input channels, so
    ``low_level_feat`` has to provide 48 channels.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        reduced_channels = 48
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, reduced_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(reduced_channels),
            nn.ReLU(inplace=True),
        )
        fuse_layers = [
            SeparableConv2d(96, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SeparableConv2d(out_channels, out_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.3),
        ]
        self.fuse = nn.Sequential(*fuse_layers)

    def forward(self, x, low_level_feat):
        target_hw = low_level_feat.shape[2:]
        upsampled = F.interpolate(x, size=target_hw, mode='bilinear', align_corners=True)
        reduced = self.conv1(upsampled)
        merged = torch.cat([reduced, low_level_feat], dim=1)
        return self.fuse(merged)

class DeepFusionLab(nn.Module):
    """EfficientNet-B0 based network with a segmentation head and a classification head.

    The ImageNet-pretrained ``efficientnet_b0`` feature extractor is split in
    two: ``low_level`` (the first 3 feature stages when ``output_stride == 16``,
    otherwise the first 2) and ``high_level`` (the remaining stages).

    * ``mode == 1`` (segmentation): high-level features go through ASPP, are
      fused with the projected low-level features (MFF block), decoded,
      refined by the CAFSE block, mapped to ``num_classes_seg`` channels,
      upsampled to the input size and passed through ``activation``.
    * ``mode == 0`` (classification): high-level features are globally pooled
      and classified into ``num_classes_cls`` logits.

    Args:
        num_classes_seg: Number of output channels of the segmentation head.
        num_classes_cls: Number of logits produced by the classification head.
        mode: 0 for classification, 1 for segmentation. Checked on every
            ``forward`` call, so it may be changed after construction.
        output_stride: Selects the backbone split point, the low-level channel
            count (24 for 16, else 16) and the ASPP atrous rates
            ([6, 12, 18] for 16, else [12, 24, 36]).
        activation: 'sigmoid', 'softmax2d', or anything else for no activation
            on the segmentation output.

    Raises:
        ValueError: from ``forward`` when ``mode`` is neither 0 nor 1.
    """

    def __init__(self, num_classes_seg=1, num_classes_cls=2, mode=1, output_stride=16, activation='sigmoid'):
        super().__init__()
        self.mode = mode
        self.output_stride = output_stride

        backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        features = list(backbone.features.children())
        if output_stride == 16:
            self.low_level = nn.Sequential(*features[:3])
            self.high_level = nn.Sequential(*features[3:])
        else:
            self.low_level = nn.Sequential(*features[:2])
            self.high_level = nn.Sequential(*features[2:])

        # Channel count of the low-level split point chosen above.
        low_level_channels = 24 if output_stride == 16 else 16
        self.low_proj = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True)
        )

        atrous_rates = [6, 12, 18] if output_stride == 16 else [12, 24, 36]
        self.aspp = ASPP(1280, 256, atrous_rates)
        self.mff = MFFBlock(48, 256, 256)
        self.decoder = Decoder(256, 256)
        self.cafse = CAFSEBlock(256)
        self.final_conv = nn.Conv2d(256, num_classes_seg, 1)

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes_cls)
        )

        if activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'softmax2d':
            self.activation = nn.Softmax2d()
        else:
            self.activation = None

    def forward(self, x):
        # Fail fast: reject an invalid mode before paying for a backbone pass.
        if self.mode not in (0, 1):
            raise ValueError("Mode must be 0 (classification) or 1 (segmentation)")

        input_size = x.size()[2:]
        low_feat = self.low_level(x)
        high_feat = self.high_level(low_feat)

        if self.mode == 0:
            # The classification head only needs the high-level features.
            return self.classifier(high_feat)

        # Segmentation path: the low-level projection is only needed here.
        low_proj = self.low_proj(low_feat)
        aspp_out = self.aspp(high_feat)
        mff_out = self.mff(low_proj, aspp_out)
        decoder_out = self.decoder(mff_out, low_proj)
        cafse_out = self.cafse(decoder_out, aspp_out)
        out = self.final_conv(cafse_out)
        out = F.interpolate(out, size=input_size, mode='bilinear', align_corners=True)
        if self.activation is not None:
            out = self.activation(out)
        return out


# Model Instantiation and Summary

In [None]:
# Smoke test: build the segmentation model and push one random batch through it.
model = DeepFusionLab(num_classes_seg=2, num_classes_cls=2, mode=1)  # mode=1 for segmentation

# Random stand-in for a batch of two 3-channel 256x256 images
input_tensor = torch.randn(2, 3, 256, 256)

# Forward pass; the output is upsampled back to the input size with
# num_classes_seg channels, so the printed shape should be (2, 2, 256, 256)
output = model(input_tensor)
print(output.shape)

def print_model_parameters(model):
    """Print the total, trainable and non-trainable parameter counts of ``model``.

    A parameter counts as trainable when its ``requires_grad`` flag is set.
    Counts are printed with thousands separators.
    """
    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        element_count = parameter.numel()
        total_params += element_count
        if parameter.requires_grad:
            trainable_params += element_count
    non_trainable_params = total_params - trainable_params

    print(f"Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Non-Trainable Parameters: {non_trainable_params:,}")

# Report parameter counts and a layer-by-layer torchinfo summary.
# NOTE(review): this rebuilds the model that was already created at the top of
# this cell; the earlier instance could be reused instead.
model = DeepFusionLab(num_classes_seg=2, num_classes_cls=2, mode=1)
print_model_parameters(model)

summary(model, input_size=(2, 3, 256, 256))

# Metric Definitions & Utilities

In [None]:
class DiceCoefficient(torch.nn.Module):
    """Thresholded Dice coefficient, averaged over batch and channels.

    Both inputs are binarized with ``> threshold`` and Dice is computed per
    sample and channel as ``2 * |A ∩ B| / (|A| + |B| + 1e-6)``.

    Dice is symmetric in its two arguments, and because both tensors are
    binarized the same way the result does not depend on which one is passed
    first. This matters here: the signature declares ``(y_true, y_pred)`` but
    the training and evaluation loops call the metric as ``(preds, masks)``.

    Inputs are expected to be probabilities or {0, 1} masks. No sigmoid is
    applied: the model already returns activated outputs, and a second sigmoid
    would push every probability above 0.5.

    Args:
        threshold: Cut-off used to binarize both tensors.
    """

    def __init__(self, threshold=0.5):
        super(DiceCoefficient, self).__init__()
        self.threshold = threshold
        self.__name__ = "DiceCoefficient"  # key used when results are collected by name

    def forward(self, y_true, y_pred):
        """Return the mean Dice score of two (N, C, H, W) tensors."""
        # Binarize both sides; a {0, 1} mask is unchanged by a 0.5 threshold.
        y_true = (y_true > self.threshold).float()
        y_pred = (y_pred > self.threshold).float()

        intersection = (y_true * y_pred).sum(dim=(2, 3))  # per sample and channel
        cardinality = (y_true + y_pred).sum(dim=(2, 3))   # |A| + |B|

        # Epsilon avoids 0/0; note that two empty masks therefore score 0.
        dice = 2. * intersection / (cardinality + 1e-6)
        return dice.mean()



class IoU(torch.nn.Module):
    """Thresholded Intersection-over-Union (Jaccard index).

    Args:
        threshold (float): Predictions strictly above this value count as positive.
        eps (float): Added to numerator and denominator so that an empty
            prediction on an empty target scores 1 instead of dividing by zero.
        activation (callable, optional): Applied to the predictions before
            thresholding (e.g. ``torch.sigmoid`` for logits). With ``None`` the
            predictions are thresholded as given.
    """

    def __init__(self, threshold=0.5, eps=1e-7, activation=None):
        super(IoU, self).__init__()
        self.threshold, self.eps, self.activation = threshold, eps, activation
        self.__name__ = "IoU"  # key used when results are collected by name

    def forward(self, y_pred, y_true):
        """Return the IoU of ``y_pred`` vs ``y_true``, averaged over batch and channels."""
        if self.activation is not None:
            y_pred = self.activation(y_pred)

        y_pred = torch.gt(y_pred, self.threshold).float()
        y_true = y_true.float()
        assert y_pred.shape == y_true.shape, f"Shape mismatch: y_pred {y_pred.shape}, y_true {y_true.shape}"

        # |A ∩ B| and |A ∪ B| = |A| + |B| - |A ∩ B|, reduced over H and W.
        overlap = y_pred * y_true
        intersection = torch.sum(overlap, dim=(2, 3))
        union = torch.sum(y_pred + y_true - overlap, dim=(2, 3))

        return torch.mean((intersection + self.eps) / (union + self.eps))

class AUC(torch.nn.Module):
    """Per-batch ROC AUC over all pixels.

    The score is computed from the flattened predictions and labels of the
    batch that is passed in, not accumulated over the whole dataset, so
    averaging it across batches only approximates the dataset-level AUC.
    """

    def __init__(self):
        super(AUC, self).__init__()
        self.__name__ = "AUC"

    def forward(self, y_pred, y_true):
        """Return the batch ROC AUC as a 0-d tensor (0.5 if sklearn raises ValueError)."""
        # A sigmoid is applied to the scores before flattening; labels are cast to float.
        scores = torch.sigmoid(y_pred).view(-1)
        labels = y_true.view(-1).float()
        scores = scores.detach().cpu().numpy()
        labels = labels.detach().cpu().numpy()

        from sklearn.metrics import roc_auc_score
        try:
            auc = roc_auc_score(labels, scores)
        except ValueError:
            # e.g. only one class present in this batch
            auc = 0.5
        return torch.tensor(auc)

class Accuracy(torch.nn.Module):
    """Fraction of elements whose thresholded prediction equals the target."""

    def __init__(self, threshold=0.5):
        super(Accuracy, self).__init__()
        self.threshold = threshold
        self.__name__ = "Accuracy"  # key used when results are collected by name

    def forward(self, y_pred, y_true):
        """Threshold ``y_pred`` as given (no activation is applied) and compare with ``y_true``."""
        predicted = torch.gt(y_pred, self.threshold).float()
        target = y_true.float()
        matches = torch.eq(predicted, target).float()
        return matches.mean()


class Precision(torch.nn.Module):
    """Thresholded precision, TP / (TP + FP), averaged over batch and channels.

    ``eps`` is added to numerator and denominator, so a sample with no
    predicted positives scores 1.
    """

    def __init__(self, threshold=0.5, eps=1e-7):
        super(Precision, self).__init__()
        self.threshold, self.eps = threshold, eps
        self.__name__ = "Precision"

    def forward(self, y_pred, y_true):
        """Threshold ``y_pred`` as given and score it against ``y_true``."""
        predicted = torch.gt(y_pred, self.threshold).float()
        target = y_true.float()

        true_pos = torch.sum(predicted * target, dim=(2, 3))
        false_pos = torch.sum(predicted * (1 - target), dim=(2, 3))

        return torch.mean((true_pos + self.eps) / (true_pos + false_pos + self.eps))


class Recall(torch.nn.Module):
    """Thresholded recall, TP / (TP + FN), averaged over batch and channels.

    ``eps`` is added to numerator and denominator, so a sample with no
    positive target pixels scores 1.
    """

    def __init__(self, threshold=0.5, eps=1e-7):
        super(Recall, self).__init__()
        self.threshold, self.eps = threshold, eps
        self.__name__ = "Recall"

    def forward(self, y_pred, y_true):
        """Threshold ``y_pred`` as given and score it against ``y_true``."""
        predicted = torch.gt(y_pred, self.threshold).float()
        target = y_true.float()

        true_pos = torch.sum(predicted * target, dim=(2, 3))
        false_neg = torch.sum((1 - predicted) * target, dim=(2, 3))

        return torch.mean((true_pos + self.eps) / (true_pos + false_neg + self.eps))


class F1Score(torch.nn.Module):
    """Thresholded F1 score, averaged over batch and channels.

    Precision and recall are computed per sample and channel (each smoothed
    with ``eps``) and combined as ``2 * P * R / (P + R + eps)``.
    """

    def __init__(self, threshold=0.5, eps=1e-7):
        super(F1Score, self).__init__()
        self.threshold, self.eps = threshold, eps
        self.__name__ = "F1Score"

    def forward(self, y_pred, y_true):
        """Threshold ``y_pred`` as given and score it against ``y_true``."""
        predicted = torch.gt(y_pred, self.threshold).float()
        target = y_true.float()

        true_pos = torch.sum(predicted * target, dim=(2, 3))
        false_pos = torch.sum(predicted * (1 - target), dim=(2, 3))
        false_neg = torch.sum((1 - predicted) * target, dim=(2, 3))

        precision = (true_pos + self.eps) / (true_pos + false_pos + self.eps)
        recall = (true_pos + self.eps) / (true_pos + false_neg + self.eps)
        return torch.mean(2 * precision * recall / (precision + recall + self.eps))

# Metric instances used by the training/validation loop.
dice_metric = DiceCoefficient(threshold=0.5)
iou_metric = IoU(threshold=0.5)  # activation left at None: predictions are thresholded as given
auc_metric = AUC()  # NOTE(review): not referenced again in the visible part of the notebook

In [None]:
class ThresholdedDiceLoss(torch.nn.Module):
    """Differentiable (soft) Dice loss: ``1 - mean(2 * sum(a * b) / (sum(a) + sum(b) + 1e-6))``.

    Neither input is thresholded. Hard-thresholding the prediction makes the
    loss piecewise constant, so no gradient can flow back to the model; the
    previous implementation only trained because the training loop passes
    ``(preds, masks)`` while the signature declares ``(y_true, y_pred)``,
    which meant the thresholding landed on the masks instead.

    The soft Dice expression is symmetric, so the loss is now the same and
    differentiable whichever order the two tensors are passed in. For {0, 1}
    masks the value for ``loss_fn(preds, masks)`` is unchanged.

    Inputs are expected to be probabilities (the model already applies its
    activation) and masks with values in [0, 1]; no sigmoid is applied here.

    Args:
        threshold: Retained for backward compatibility of the constructor;
            it is stored but no longer used by ``forward``.
    """

    def __init__(self, threshold=0.5):
        super(ThresholdedDiceLoss, self).__init__()
        self.threshold = threshold  # kept so existing constructor calls keep working
        self.__name__ = 'dice_loss'

    def forward(self, y_true, y_pred):
        """Return ``1 - soft Dice`` of two (N, C, H, W) tensors, averaged over batch and channels."""
        intersection = (y_true * y_pred).sum(dim=(2, 3))  # per sample and channel
        cardinality = (y_true + y_pred).sum(dim=(2, 3))   # sum(a) + sum(b)

        dice = 2. * intersection / (cardinality + 1e-6)  # epsilon avoids 0/0
        return 1 - dice.mean()


# Loss instance shared by the training, validation and test loops.
loss_fn = ThresholdedDiceLoss(threshold=0.5)

class EarlyStopping:
    """Signals when a monitored validation metric has stopped improving.

    Call the instance once per epoch with a dict of metrics. Higher values are
    treated as better, except for ``'val_loss'``, which is negated internally
    so the same comparison applies. ``stop_training`` becomes True once
    ``patience`` consecutive calls failed to reach ``best_score + min_delta``.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Margin the score must reach above the best score to count
            as an improvement.
        metric: Key looked up in the metrics dict ('val_iou', 'val_dice' or
            'val_loss').
    """

    def __init__(self, patience=5, min_delta=1e-20, metric='val_iou'):
        self.patience = patience
        self.min_delta = min_delta
        self.metric = metric
        self.best_score = None      # best (sign-adjusted) score seen so far
        self.wait = 0               # consecutive epochs without improvement
        self.stop_training = False

    def __call__(self, metrics_dict):
        score = metrics_dict[self.metric]
        if self.metric == 'val_loss':
            score = -score  # lower loss is better, so flip the sign

        # First call only establishes the baseline.
        if self.best_score is None:
            self.best_score = score
            return

        fell_short = score < self.best_score + self.min_delta
        if not fell_short:
            self.best_score = score
            self.wait = 0
            return

        self.wait += 1
        print(f"No improvement in {self.metric}. Wait: {self.wait}/{self.patience}")
        if self.wait >= self.patience:
            self.stop_training = True
            print(f"Early stopping triggered after {self.wait} epochs without improvement!")

def one_hot_encode(label, label_values):
    """Convert a colour-coded label image into a stack of per-class boolean maps.

    Args:
        label: Array whose last axis holds the colour of each pixel.
        label_values: Sequence of class colours, one per class.

    Returns:
        Boolean array with the class axis last; channel ``k`` is True where
        the pixel colour equals ``label_values[k]`` in every component.
    """
    class_maps = [np.all(np.equal(label, colour), axis=-1) for colour in label_values]
    return np.stack(class_maps, axis=-1)

def reverse_one_hot(image):
    """Collapse the last (class) axis to the index of its largest entry."""
    class_axis = -1
    return np.argmax(image, axis=class_axis)

def colour_code_segmentation(image, label_values):
    """Map an array of class indices to the corresponding colours.

    Args:
        image: Array of class indices (cast to ``int`` before the lookup).
        label_values: Sequence of class colours, indexed by class id.

    Returns:
        Array shaped like ``image`` with one extra trailing colour axis.
    """
    palette = np.array(label_values)
    class_ids = image.astype(int)
    return palette[class_ids]

In [None]:
# Per-channel cut-offs, compared against the channel order returned by cv2.imread.
threshold = np.array([112, 127, 127])

# Binarize every training mask in place: a pixel becomes black only when ALL
# of its channels are below the cut-off, and white otherwise.
mask_dir = '/kaggle/working/Resized/Train/Mask'
mask_files = os.listdir(mask_dir)

for mask_filename in mask_files:
    mask_path = os.path.join(mask_dir, mask_filename)
    mask_img = cv2.imread(mask_path)

    # Skip anything OpenCV cannot decode.
    if mask_img is None:
        print(f"Could not read image: {mask_path}")
        continue

    # Make sure there is a channel axis to compare against the 3-value cut-off.
    if mask_img.ndim == 2:
        mask_img_color = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2BGR)
    else:
        mask_img_color = mask_img.copy()

    mask_below_threshold = (mask_img_color < threshold).all(axis=-1)

    # Start from an all-white image and blacken the below-threshold pixels.
    thresholded_mask = np.full_like(mask_img_color, 255, dtype=np.uint8)
    thresholded_mask[mask_below_threshold] = 0

    # Overwrite the original mask file.
    cv2.imwrite(mask_path, thresholded_mask)

print("Mask images on Train/Masks have been thresholded.")

In [None]:
# Binarize every test mask in place, using the same rule as for the training
# masks. `threshold` is the cut-off array defined in the preceding cell, so
# that cell has to run first.
mask_dir = '/kaggle/working/Resized/Test/Mask'
mask_files = os.listdir(mask_dir)

for mask_filename in mask_files:
    mask_path = os.path.join(mask_dir, mask_filename)
    mask_img = cv2.imread(mask_path)

    # Skip anything OpenCV cannot decode.
    if mask_img is None:
        print(f"Could not read image: {mask_path}")
        continue

    # Make sure there is a channel axis to compare against the 3-value cut-off.
    if mask_img.ndim == 2:
        mask_img_color = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2BGR)
    else:
        mask_img_color = mask_img.copy()

    # True where every channel is below its cut-off.
    mask_below_threshold = (mask_img_color < threshold).all(axis=-1)

    # Start from an all-white image and blacken the below-threshold pixels.
    thresholded_mask = np.full_like(mask_img_color, 255, dtype=np.uint8)
    thresholded_mask[mask_below_threshold] = 0

    # Overwrite the original mask file.
    cv2.imwrite(mask_path, thresholded_mask)

print("Mask images on Test/Masks have been thresholded.")


In [None]:
# Build the train / validation / test tables of (image_path, mask_path) pairs.
dataset_base_dir = '/kaggle/working/Resized'
train_dir = os.path.join(dataset_base_dir, 'Train')
test_dir = os.path.join(dataset_base_dir, 'Test')

# os.listdir returns entries in arbitrary order. Sorting makes the row order,
# and therefore the seeded train/validation split below, reproducible.
train_image_files = sorted(f for f in os.listdir(os.path.join(train_dir, 'Image')) if f.endswith('.png'))

# Full paths for the training images and masks
train_image_paths = [os.path.join(train_dir, 'Image', f) for f in train_image_files]
train_mask_paths = [os.path.join(train_dir, 'Mask', f) for f in train_image_files]  # assumes each mask shares its image's filename

train_df_full = pd.DataFrame({'image_path': train_image_paths, 'mask_path': train_mask_paths})

# 90% train, 10% validation, with a fixed seed
train_df, valid_df = train_test_split(train_df_full, test_size=0.1, random_state=42)

# Reset indices after splitting
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)

# Same listing for the held-out test set (sorted for a stable evaluation order)
test_image_files = sorted(f for f in os.listdir(os.path.join(test_dir, 'Image')) if f.endswith('.png'))

test_image_paths = [os.path.join(test_dir, 'Image', f) for f in test_image_files]
test_mask_paths = [os.path.join(test_dir, 'Mask', f) for f in test_image_files]  # assumes each mask shares its image's filename

test_df = pd.DataFrame({'image_path': test_image_paths, 'mask_path': test_mask_paths})

print(f"Number of training samples: {len(train_df)}")
print(f"Number of validation samples: {len(valid_df)}")
print(f"Number of test samples: {len(test_df)}")

# Display the first few rows of each DataFrame
print("\nTrain DataFrame:")
print(train_df.head())
print("\nValidation DataFrame:")
print(valid_df.head())
print("\nTest DataFrame:")
print(test_df.head())


In [None]:
class MyDataGenerator(Dataset):
    """Dataset yielding (image, one-hot mask) pairs read from disk with OpenCV.

    Args:
        df: DataFrame with ``image_path`` and ``mask_path`` columns.
        class_rgb_values: RGB colour of each class, used to one-hot encode
            the mask.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` entries.
        preprocessing: Optional callable with the same calling convention,
            applied after the augmentation.
    """

    def __init__(self, df, class_rgb_values, augmentation=None, preprocessing=None):
        self.image_paths = df['image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, i):
        # OpenCV decodes to BGR; convert both files to RGB.
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)

        # (H, W, num_classes) float map with values in {0, 1}; no rescaling needed.
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        # Augmentation first, then preprocessing; each stage is optional.
        for stage in (self.augmentation, self.preprocessing):
            if stage:
                sample = stage(image=image, mask=mask)
                image, mask = sample['image'], sample['mask']

        return image, mask


def get_training_augmentation():
    """Return the training augmentation pipeline.

    Each sample is independently flipped horizontally with probability 0.5
    and vertically with probability 0.5.
    """
    flips = [
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(flips)

def get_validation_augmentation():
    """Return the validation/test pipeline: zero-pad inputs up to at least 256x256."""
    pad_to_min_size = album.PadIfNeeded(
        min_height=256,
        min_width=256,
        always_apply=True,
        border_mode=cv2.BORDER_CONSTANT,
        value=0,
    )
    return album.Compose([pad_to_min_size])

def to_tensor(x, **kwargs):
    """Reorder an (H, W, C) array to (C, H, W) and cast it to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = np.transpose(x, (2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing():
    """Return the preprocessing pipeline applied after augmentation.

    Images and masks are resized to 256x256 and then converted to
    channels-first float32 arrays via ``to_tensor``. No intensity
    normalisation is performed here.
    """
    steps = [
        album.Resize(height=256, width=256, always_apply=True),
        album.Lambda(image=to_tensor, mask=to_tensor),
    ]
    return album.Compose(steps)

# RGB colour of each class, in class-index order: 0 = black, 1 = white.
select_class_rgb_values = np.array([[0, 0, 0],
                            [255, 255, 255]])

In [None]:
# Datasets: training samples get random flips; validation and test samples are only padded.
# NOTE(review): the preprocessing pipeline only resizes and transposes, so images reach the
# model as raw 0-255 floats - confirm that omitting input normalisation is intentional for
# the ImageNet-pretrained backbone.
train_dataset = MyDataGenerator(train_df, select_class_rgb_values, get_training_augmentation(), get_preprocessing())
valid_dataset = MyDataGenerator(valid_df, select_class_rgb_values, get_validation_augmentation(), get_preprocessing())
test_dataset = MyDataGenerator(test_df, select_class_rgb_values, get_validation_augmentation(), get_preprocessing())

# NOTE(review): drop_last is not set; a final training batch holding a single sample may
# break BatchNorm layers that see 1x1 feature maps in train mode - TODO confirm.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

# Model, optimizer and learning-rate schedule (the loss is the module-level loss_fn).
# The scheduler is stepped once per epoch at the bottom of the loop, so T_0=1 / T_mult=2
# are expressed in epochs.
model = DeepFusionLab(num_classes_seg=2, num_classes_cls=2, mode=1).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1, T_mult=2, eta_min=1e-9)

# Per-epoch history, consumed by the plotting cell
train_loss_list, valid_loss_list = [], []
train_dice_list, valid_dice_list = [], []
train_iou_list, valid_iou_list = [], []

# Best validation IoU seen so far; a checkpoint is only written when it is exceeded,
# so no file is saved if the validation IoU never rises above 0.0.
best_iou = 0.0
early_stopping = EarlyStopping(patience=pat, min_delta=1e-20, metric='val_iou')
# Wall-clock timer for the whole training run
start_time = time.time()

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss, train_dice, train_iou = 0.0, 0.0, 0.0

    for images, masks in train_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()
        preds = model(images)
        # NOTE(review): called as (preds, masks) although the loss and Dice metric declare
        # forward(y_true, y_pred); positionally the predictions land in `y_true`.
        loss = loss_fn(preds, masks)
        loss.backward()
        optimizer.step()

        # Accumulate batch means weighted by batch size, so the division by the dataset
        # size below yields per-sample averages.
        # NOTE(review): the metrics are evaluated with autograd still enabled.
        train_loss += loss.item() * images.size(0)
        train_dice += dice_metric(preds, masks).item() * images.size(0)
        train_iou += iou_metric(preds, masks).item() * images.size(0)

    train_loss /= len(train_loader.dataset)
    train_dice /= len(train_loader.dataset)
    train_iou /= len(train_loader.dataset)

    # ---- validation pass (no gradients, eval-mode layers) ----
    model.eval()
    valid_loss, valid_dice, valid_iou = 0.0, 0.0, 0.0

    with torch.no_grad():
        for images, masks in valid_loader:
            images, masks = images.to(DEVICE), masks.to(DEVICE)
            preds = model(images)
            loss = loss_fn(preds, masks)

            valid_loss += loss.item() * images.size(0)
            valid_dice += dice_metric(preds, masks).item() * images.size(0)
            valid_iou += iou_metric(preds, masks).item() * images.size(0)

    valid_loss /= len(valid_loader.dataset)
    valid_dice /= len(valid_loader.dataset)
    valid_iou /= len(valid_loader.dataset)

    # Logging
    print(f"Epoch {epoch+1}: "
          f"Train Loss={train_loss:.4f}, Train Dice={train_dice:.4f}, Train IoU={train_iou:.4f}, "
          f"Valid Loss={valid_loss:.4f}, Valid Dice={valid_dice:.4f}, Valid IoU={valid_iou:.4f}")

    # Store metrics
    train_loss_list.append(train_loss)
    valid_loss_list.append(valid_loss)

    train_dice_list.append(train_dice)
    valid_dice_list.append(valid_dice)

    train_iou_list.append(train_iou)
    valid_iou_list.append(valid_iou)

    # Checkpoint on a strictly better validation IoU
    if valid_iou > best_iou:
        best_iou = valid_iou
        torch.save(model.state_dict(), 'Best_Weight.pth')
        print("Model saved!")

    # Early stopping on validation IoU; the scheduler is not stepped on the stopping epoch
    metrics_dict = {'val_loss': valid_loss, 'val_dice': valid_dice, 'val_iou': valid_iou}
    early_stopping(metrics_dict)
    if early_stopping.stop_training:
        print(f"Training stopped at epoch {epoch+1}")
        break

    scheduler.step()

# End timing
end_time = time.time()
elapsed_time = end_time - start_time

print(f"\nTotal training time: {elapsed_time / 60:.2f} minutes")

In [None]:
# Training curves: one panel each for loss, Dice and IoU, train vs. validation.
epochs_range = range(1, len(train_loss_list) + 1)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (title, y-axis label, train series, validation series, train legend, validation legend)
panels = [
    ('Loss over Epochs', 'Loss', train_loss_list, valid_loss_list, 'Train Loss', 'Val Loss'),
    ('Dice Score over Epochs', 'Dice Score', train_dice_list, valid_dice_list, 'Train Dice', 'Val Dice'),
    ('IoU over Epochs', 'IoU', train_iou_list, valid_iou_list, 'Train IoU', 'Val IoU'),
]

for ax, (title, y_label, train_values, val_values, train_label, val_label) in zip(axes, panels):
    ax.plot(epochs_range, train_values, label=train_label)
    ax.plot(epochs_range, val_values, label=val_label)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# Evaluate the best checkpoint on the held-out test set.
model = DeepFusionLab(num_classes_seg=2, num_classes_cls=2, mode=1).to(DEVICE)

# Load the weights written by the training loop (requires that at least one
# checkpoint was saved to 'Best_Weight.pth').
model.load_state_dict(torch.load('Best_Weight.pth', map_location=DEVICE))
print("✅ Loaded best model weights")

# Eval mode: BatchNorm uses running statistics and dropout is disabled
model.eval()

# Test loader (smaller batch size than during training)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

# Metrics; every one is called positionally as metric(preds, masks) below
metrics = [
    DiceCoefficient(threshold=0.5),
    IoU(threshold=0.5),
    AUC(),
    Accuracy(threshold=0.5),
    Precision(threshold=0.5),
    Recall(threshold=0.5),
    F1Score(threshold=0.5),
]

# Running sums keyed by metric name; batch means are weighted by batch size
test_stats = {metric.__name__: 0.0 for metric in metrics}
test_loss = 0.0
total_samples = 0

# Evaluation loop
with torch.no_grad():
    for images, masks in test_loader:
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        preds = model(images)

        # Loss
        loss = loss_fn(preds, masks)
        test_loss += loss.item() * images.size(0)

        # Metrics
        for metric in metrics:
            test_stats[metric.__name__] += metric(preds, masks).item() * images.size(0)

        total_samples += images.size(0)

# Convert the weighted sums into per-sample averages
test_loss /= total_samples
for key in test_stats:
    test_stats[key] /= total_samples

# Print results
print("\n📊 Test Evaluation Results:")
print(f"  Loss: {test_loss:.4f}")
for key, value in test_stats.items():
    print(f"  {key}: {value:.4f}")
