## SECTION A: Imports, Seed Setup, and File Paths

In [None]:
# SECTION A - Imports and Setup

import os
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.optim as optim
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import pytorch_ssim

# Fix every random number generator used by the pipeline so runs are repeatable.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# === Dataset layout ===
# Images and annotation masks live in parallel train/val/test sub-folders
# underneath the dataset root.
base_path = r"D:\Datasets\ham10000"

train_image_dir, val_image_dir, test_image_dir = (
    os.path.join(base_path, f"images/{split}")
    for split in ("train", "val", "test")
)

train_mask_dir, val_mask_dir, test_mask_dir = (
    os.path.join(base_path, f"annotations/{split}")
    for split in ("train", "val", "test")
)

# Read image IDs
def read_ids(filepath):
    """Return the image IDs listed in a split file, one ID per line.

    Surrounding whitespace is stripped from every line. Blank or
    whitespace-only lines (for example a trailing newline at the end of the
    file) are skipped, so they cannot turn into empty IDs that would later
    resolve to non-existent ".jpg" / ".png" paths.

    Args:
        filepath: Path of the text file to read.

    Returns:
        A list of non-empty ID strings in file order.
    """
    with open(filepath, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        return [image_id for image_id in stripped_lines if image_id]

# Load the ID list of every split from "<split>.txt" in the dataset root.
train_ids, val_ids, test_ids = (
    read_ids(os.path.join(base_path, f"{split}.txt"))
    for split in ("train", "val", "test")
)


## SECTION B: PyTorch Dataset Class 

In [None]:
# SECTION B - Custom Dataset Class

class SkinLesionDataset(Dataset):
    """Pairs of RGB lesion images and binary segmentation masks.

    Every sample is looked up by ID: the image is "<id>.jpg" inside
    ``image_dir`` and the mask is "<id>.png" inside ``mask_dir``. Both are
    resized to ``image_size`` and converted to tensors; the mask is then
    thresholded at 0.5 into a {0, 1} float tensor.

    Args:
        ids_list: Sequence of sample IDs (file names without extension).
        image_dir: Directory holding the ".jpg" images.
        mask_dir: Directory holding the ".png" masks.
        image_size: Target size passed to ``transforms.Resize``.
        augment: Stored on the instance; not used by any code in this class.
    """

    def __init__(self, ids_list, image_dir, mask_dir, image_size=(384, 384), augment=False):
        self.ids_list = ids_list
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_size = image_size
        self.augment = augment

        # Images are resized with bicubic interpolation.
        image_steps = [
            transforms.Resize(image_size, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ]
        self.transform_image = transforms.Compose(image_steps)

        # Masks use nearest-neighbour so no intermediate grey values appear.
        mask_steps = [
            transforms.Resize(image_size, interpolation=Image.NEAREST),
            transforms.ToTensor(),
        ]
        self.transform_mask = transforms.Compose(mask_steps)

    def __len__(self):
        """Number of samples, i.e. the number of IDs."""
        return len(self.ids_list)

    def __getitem__(self, idx):
        """Load, resize and tensorise the (image, mask) pair at ``idx``."""
        sample_id = self.ids_list[idx]

        rgb = Image.open(os.path.join(self.image_dir, sample_id + ".jpg")).convert("RGB")
        grey = Image.open(os.path.join(self.mask_dir, sample_id + ".png")).convert("L")

        image = self.transform_image(rgb)

        # Binarize mask: values >= 0.5 become 1.0, everything else 0.0.
        mask = self.transform_mask(grey).ge(0.5).float()

        return image, mask

    def get_filenames(self, idx):
        """Return the ID (or slice of IDs) stored at ``idx``."""
        return self.ids_list[idx]

# One dataset per split, all resized to 384x384.
train_dataset, val_dataset, test_dataset = (
    SkinLesionDataset(split_ids, split_images, split_masks, image_size=(384, 384))
    for split_ids, split_images, split_masks in (
        (train_ids, train_image_dir, train_mask_dir),
        (val_ids, val_image_dir, val_mask_dir),
        (test_ids, test_image_dir, test_mask_dir),
    )
)

# Matching loaders: only the training split is shuffled.
train_loader, val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=4, shuffle=do_shuffle, num_workers=2)
    for split_dataset, do_shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

## SECTION C: Custom Loss Functions 

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    By default ``y_pred`` is treated as probabilities in [0, 1]. This matches
    ``FEEINnet.forward``, which already applies a sigmoid to its output;
    applying a second sigmoid here would squash every prediction into roughly
    [0.5, 0.73], so the loss could never approach zero.

    Args:
        smooth: Small constant added to numerator and denominator to avoid
            division by zero on empty masks.
        from_logits: Set to True when ``y_pred`` holds raw logits; a sigmoid
            is then applied before the Dice score is computed.
    """

    def __init__(self, smooth=1e-5, from_logits=False):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Return ``1 - dice`` computed over all elements of the batch."""
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        # reshape (rather than view) also accepts non-contiguous inputs.
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1).to(y_pred.dtype)

        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1 - dice

class BinaryFocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - p_t) ** gamma * BCE``, averaged.

    By default ``y_pred`` is treated as probabilities in [0, 1]. This matches
    ``FEEINnet.forward``, which already applies a sigmoid to its output;
    applying a second sigmoid here would compress every prediction into
    roughly [0.5, 0.73] and distort the loss.

    Args:
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        alpha: Constant scale applied to every element (it is not a
            per-class weight in this implementation).
        from_logits: Set to True when ``y_pred`` holds raw logits; the
            numerically stable logits form of BCE is used in that case.
    """

    def __init__(self, gamma=2, alpha=0.25, from_logits=False):
        super(BinaryFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Return the mean focal loss over all elements of the batch."""
        y_pred = y_pred.reshape(-1)
        y_true = y_true.reshape(-1).to(y_pred.dtype)

        if self.from_logits:
            bce = F.binary_cross_entropy_with_logits(y_pred, y_true, reduction='none')
        else:
            bce = F.binary_cross_entropy(y_pred, y_true, reduction='none')

        # exp(-BCE) recovers p_t, the probability assigned to the true class.
        pt = torch.exp(-bce)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce
        return focal_loss.mean()



class SSIMLoss(nn.Module):
    """Structural-similarity loss: ``1 - SSIM(y_pred, y_true)``.

    By default ``y_pred`` is treated as probabilities in [0, 1]. This matches
    ``FEEINnet.forward``, which already applies a sigmoid to its output;
    a second sigmoid here would compress the prediction range before the
    similarity is measured.

    Args:
        from_logits: Set to True when ``y_pred`` holds raw logits; a sigmoid
            is then applied before computing SSIM.
    """

    def __init__(self, from_logits=False):
        super(SSIMLoss, self).__init__()
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Return ``1 - pytorch_ssim.ssim(y_pred, y_true)``."""
        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)
        ssim_score = pytorch_ssim.ssim(y_pred, y_true)
        return 1 - ssim_score

class CombinedLoss(nn.Module):
    """Unweighted sum of the Dice, binary focal and SSIM losses."""

    def __init__(self):
        super().__init__()
        self.dice = DiceLoss()
        self.focal = BinaryFocalLoss()
        self.ssim = SSIMLoss()

    def forward(self, y_pred, y_true):
        """Evaluate the three component losses in turn and add them up."""
        total = self.dice(y_pred, y_true)
        total = total + self.focal(y_pred, y_true)
        total = total + self.ssim(y_pred, y_true)
        return total

## SECTION D: Implementation of FEEINnet

In [None]:
from efficientnet_pytorch import EfficientNet

class ConvBlock(nn.Module):
    """Two stacked 3x3 conv -> batch-norm -> LeakyReLU(0.2) stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Padding of 1 preserves the spatial size.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # Stage 1 consumes the input width, stage 2 the block's own width.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.LeakyReLU(0.2))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages."""
        return self.conv(x)


class DecoderBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concat, dropout, ConvBlock.

    Args:
        in_channels: Channels of the incoming (coarser) feature map.
        skip_channels: Channels of the skip connection that is concatenated.
        out_channels: Channels produced by the upsampling and by the block.
        dropout: Probability used by the channel-wise ``nn.Dropout2d``.
    """

    def __init__(self, in_channels, skip_channels, out_channels, dropout=0.5):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.dropout = nn.Dropout2d(dropout)
        # The ConvBlock sees the upsampled map and the skip stacked on channels.
        self.conv = ConvBlock(out_channels + skip_channels, out_channels)

    def forward(self, x, skip):
        """Upsample ``x``, concatenate ``skip`` on dim 1, drop out and convolve."""
        merged = torch.cat([self.upsample(x), skip], dim=1)
        return self.conv(self.dropout(merged))

class channel_attention(nn.Module):
    """Channel attention built from globally average- and max-pooled descriptors.

    Both pooled descriptors go through the same two-layer bottleneck MLP
    (``in_channels -> in_channels // ratio -> in_channels``); their sum is
    passed through a sigmoid and used to rescale the input channel-wise.
    """

    def __init__(self, in_channels, ratio=8):
        super().__init__()
        self.in_channels = in_channels
        self.shared_fc1 = nn.Linear(in_channels, in_channels // ratio)
        self.relu = nn.ReLU(inplace=True)
        self.shared_fc2 = nn.Linear(in_channels // ratio, in_channels)
        self.sigmoid = nn.Sigmoid()

    def squeeze_operation(self, x):
        """Collapse (B, C, H, W) to (B, C) by averaging over the spatial positions."""
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)
        return flat.sum(dim=2) / (height * width)

    def _shared_mlp(self, pooled):
        """Flatten a pooled (B, C, 1, 1) map and run it through the shared MLP."""
        hidden = self.shared_fc1(self.squeeze_operation(pooled))
        hidden = self.relu(hidden)
        return self.shared_fc2(hidden)

    def forward(self, x):
        """Return ``x`` scaled by its per-channel attention weights."""
        batch, channels, _, _ = x.size()

        # Average-pooled branch first, then max-pooled branch.
        avg_branch = self._shared_mlp(F.adaptive_avg_pool2d(x, 1))
        max_branch = self._shared_mlp(F.adaptive_max_pool2d(x, 1))

        weights = self.sigmoid(avg_branch + max_branch).view(batch, channels, 1, 1)
        return x * weights.expand_as(x)

class EdgeAttentionModule(nn.Module):
    """Edge attention module (EAM).

    Fuses two encoder feature maps and gates them with a boundary map derived
    from the final decoder output. Every convolution maps to ``num_filters``
    channels, so ``s3``, ``s4`` and ``final_decoder_output`` must all carry
    ``num_filters`` channels.

    Args:
        num_filters: Channel width of all inputs and of the output.
    """

    def __init__(self, num_filters):
        super(EdgeAttentionModule, self).__init__()
        # For s3 and s4
        self.conv_s3 = self._conv_bn_act(num_filters, num_filters)
        self.conv_s4 = self._conv_bn_act(num_filters, num_filters)
        # For combined
        self.conv_combined = self._conv_bn_act(num_filters * 2, num_filters)
        # For final decoder output
        self.conv_final_map = self._conv_bn_act(num_filters, num_filters)
        # After multiplication
        self.conv_enhanced = self._conv_bn_act(num_filters, num_filters)

    @staticmethod
    def _conv_bn_act(in_channels, out_channels):
        """Build a 3x3 conv -> batch-norm -> LeakyReLU(0.2) stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def distance_transform(self, S):
        """Per-channel 5x5 box-filter sum of ``1 - S``.

        Despite the name this is a local sum, not a true distance transform.
        The kernel has shape (C, 1, 5, 5) because a grouped convolution with
        ``groups=C`` expects ``in_channels / groups == 1`` input channels per
        filter; a (C, C, 5, 5) kernel raises for any C > 1. The kernel also
        follows the dtype of ``S`` so reduced-precision inputs work.
        """
        S = 1 - S
        channels = S.shape[1]
        kernel = torch.ones((channels, 1, 5, 5), device=S.device, dtype=S.dtype)
        S = F.conv2d(S, kernel, stride=1, padding=2, groups=channels)
        return S

    def calculate_di(self, Si, inv_Si):
        """Sum of the max-normalised transforms of ``Si`` and ``inv_Si``.

        Each transform is divided by its per-sample, per-channel spatial
        maximum (plus 1e-8 to avoid division by zero).
        """
        DT_Si = self.distance_transform(Si)
        DT_inv_Si = self.distance_transform(inv_Si)

        max_DT_Si = DT_Si.amax(dim=(2, 3), keepdim=True)
        max_DT_inv_Si = DT_inv_Si.amax(dim=(2, 3), keepdim=True)

        Di = (DT_Si / (max_DT_Si + 1e-8)) + (DT_inv_Si / (max_DT_inv_Si + 1e-8))
        return Di

    def forward(self, s3, s4, final_decoder_output):
        """Return boundary-enhanced features at the decoder output resolution.

        Args:
            s3: Feature map with ``num_filters`` channels.
            s4: Feature map with ``num_filters`` channels; it is bilinearly
                resized to the spatial size of ``s3`` before fusion.
            final_decoder_output: Decoder features with ``num_filters``
                channels; their spatial size defines the output size.
        """
        s3 = self.conv_s3(s3)
        s4 = F.interpolate(s4, size=s3.shape[2:], mode='bilinear', align_corners=False)
        s4 = self.conv_s4(s4)

        combined = torch.cat([s3, s4], dim=1)
        combined = self.conv_combined(combined)

        final_map = self.conv_final_map(final_decoder_output)

        boundary = torch.sigmoid(final_map)
        inv_boundary = 1 - boundary

        Di = self.calculate_di(boundary, inv_boundary)
        Mi_B = 1 - Di

        # Bring the fused encoder features up to the gate's resolution.
        combined_up = F.interpolate(combined, size=Mi_B.shape[2:], mode='bilinear', align_corners=False)
        enhanced_boundary = combined_up * Mi_B
        enhanced_boundary = self.conv_enhanced(enhanced_boundary)

        return enhanced_boundary

class inverse_attention_module(nn.Module):
    """Inverse (reverse) attention block.

    Projects a boundary map ``Ob`` down to the resolution of a feature map
    ``Fi``, fuses the two, and gates the upsampled result with
    ``1 - sigmoid(Ui_plus_1)``.

    Args:
        Fi_channels: Channels of the feature map ``Fi``.
        Ob_channels: Channels of the boundary map ``Ob``.
        num_filters: Channels of the projected boundary map and of the output.
    """

    def __init__(self, Fi_channels, Ob_channels, num_filters):
        super().__init__()

        # 1x1 projection of Ob; forward() re-applies its weights with a stride.
        self.down_ob = nn.Sequential(
            nn.Conv2d(Ob_channels, num_filters, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(num_filters),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Two 3x3 conv stages applied to the concatenation of Fi and projected Ob.
        fusion_layers = [
            nn.Conv2d(Fi_channels + num_filters, num_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_filters),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_filters, num_filters, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_filters),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*fusion_layers)

    def forward(self, Ui_plus_1, Fi, Ob):
        """Return ``Fi``/``Ob`` fusion gated by the inverse of ``sigmoid(Ui_plus_1)``."""
        _, _, fi_h, fi_w = Fi.shape
        _, _, ob_h, ob_w = Ob.shape

        # Integer stride that shrinks Ob towards Fi's size (never below 1).
        stride = (max(ob_h // fi_h, 1), max(ob_w // fi_w, 1))

        # Reuse the 1x1 conv weights with that stride, then batch-norm + activation.
        projection, norm, activation = self.down_ob
        ob_small = F.conv2d(
            Ob,
            weight=projection.weight,
            bias=projection.bias,
            stride=stride,
            padding=0,
        )
        ob_small = activation(norm(ob_small))

        # Fuse on the channel axis and reduce to num_filters channels.
        fused = self.conv(torch.cat([Fi, ob_small], dim=1))

        # Resize the fused features to the spatial size of Ui_plus_1.
        fused_up = F.interpolate(fused, size=Ui_plus_1.shape[2:], mode='bilinear', align_corners=False)

        # Gate with the inverse attention map.
        inverse_gate = 1 - torch.sigmoid(Ui_plus_1)
        return fused_up * inverse_gate

class EfficientNetB4Encoder(nn.Module):
    """Pretrained EfficientNet-B4 backbone exposing six intermediate activations.

    Forward hooks are registered on selected backbone sub-modules; each hook
    stores the module output in ``self.outputs`` under a short tap name.
    ``forward`` runs the backbone once and returns the captured tensors.
    """

    def __init__(self):
        super().__init__()
        self.backbone = EfficientNet.from_pretrained('efficientnet-b4')

        # Filled by the forward hooks on every backbone pass.
        self.outputs = {}

        # Tap name -> dotted module path inside the backbone.
        # NOTE(review): the channel width and stride of each tapped block must
        # match what FEEINnet's 1x1 convs and decoder expect - confirm against
        # the efficientnet_pytorch B4 block layout.
        self.feature_layers = {
            "s1": "_blocks.1",
            "s2": "_blocks.2",
            "s3": "_blocks.3",
            "s4": "_blocks.4",
            "s5": "_blocks.6",
            "b1": "_blocks.7"
        }

        # Build the name -> module lookup once and attach one hook per tap.
        modules_by_name = dict(self.backbone.named_modules())
        for tap_name, module_path in self.feature_layers.items():
            modules_by_name[module_path].register_forward_hook(self.save_output(tap_name))

    def save_output(self, name):
        """Create a forward hook that records a module's output under ``name``."""
        def hook(module, input, output):
            self.outputs[name] = output
        return hook

    def forward(self, x):
        """Run the backbone and return the taps in the order s1, s2, s3, s4, s5, b1."""
        # The return value is discarded; only the hooked activations are used.
        _ = self.backbone.extract_features(x)
        tap_order = ("s1", "s2", "s3", "s4", "s5", "b1")
        return [self.outputs[tap] for tap in tap_order]


class FEEINnet(nn.Module):
    """Segmentation network: EfficientNet-B4 encoder, attention decoder, EAM and two IAMs.

    ``forward`` returns a single-channel map that has ALREADY been passed
    through a sigmoid, i.e. probabilities rather than logits.

    Args:
        dropout_rate: Dropout probability forwarded to every DecoderBlock.
    """

    def __init__(self, dropout_rate=0.5):
        super(FEEINnet, self).__init__()
        self.encoder = EfficientNetB4Encoder()
        # 1x1 convs that re-map each encoder tap to the width used by the decoder.
        # NOTE(review): the in_channels below assume the hooked encoder layers emit
        # 48 / 64 / 160 / 224 / 384 / 1792 channels - TODO confirm against the
        # layers actually tapped in EfficientNetB4Encoder.
        self.s1_conv = nn.Conv2d(in_channels=48, out_channels=48, kernel_size=1)
        self.s2_conv = nn.Conv2d(in_channels=64, out_channels=144, kernel_size=1)
        self.s3_conv = nn.Conv2d(in_channels=160, out_channels=192, kernel_size=1)
        self.s4_conv = nn.Conv2d(in_channels=224, out_channels=336, kernel_size=1)
        self.s5_conv = nn.Conv2d(in_channels=384, out_channels=1632, kernel_size=1)

        self.b1_conv = nn.Conv2d(in_channels=1792, out_channels=1632, kernel_size=1)

        # NOTE(review): d1..d4 are declared with skip widths 336 / 192 / 144 / 48,
        # but forward() feeds them s5 / s4 / s3 / s2, which the 1x1 convs above
        # emit with 1632 / 336 / 192 / 144 channels - the skips look shifted by
        # one level; confirm the intended pairing.
        self.d1 = DecoderBlock(in_channels=1632, skip_channels=336, out_channels=256, dropout=dropout_rate)
        self.cam1 = channel_attention(256)

        self.d2 = DecoderBlock(256, 192, 128, dropout_rate)
        self.d3 = DecoderBlock(128, 144, 64, dropout_rate)
        self.cam3 = channel_attention(64)

        self.d4 = DecoderBlock(64, 48, 32, dropout_rate)
        # NOTE(review): skip_channels=3 here, but forward() passes s1, which
        # s1_conv emits with 48 channels - presumably the raw RGB input was
        # meant as the skip; confirm.
        self.d5 = DecoderBlock(32, 3, 16, dropout_rate)  
        self.cam5 = channel_attention(16)

        # EAM
        # NOTE(review): built for 16-channel inputs, but forward() passes s3 and
        # s4, which s3_conv / s4_conv emit with 192 and 336 channels - confirm.
        self.eam = EdgeAttentionModule(16)

        # IAM modules
        self.iam1 = inverse_attention_module(Fi_channels=1632, Ob_channels=16, num_filters=16)
        self.iam2 = inverse_attention_module(Fi_channels=1632, Ob_channels=16, num_filters=16)

        # Final 1x1 conv collapsing the 16 refined channels to one output channel.
        self.final_conv = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):
        """Return per-pixel probabilities (sigmoid already applied) for ``x``."""
        s1, s2, s3, s4, s5, b1 = self.encoder(x)

        # Re-map every encoder tap to its decoder width.
        s1 = self.s1_conv(s1)
        s2 = self.s2_conv(s2)
        s3 = self.s3_conv(s3)
        s4 = self.s4_conv(s4)
        s5 = self.s5_conv(s5)
        b1 = self.b1_conv(b1)

        # Decoder
        # Channel attention follows the 1st, 3rd and 5th decoder stages.
        d1 = self.cam1(self.d1(b1, s5))
        d2 = self.d2(d1, s4)
        d3 = self.cam3(self.d3(d2, s3))
        d4 = self.d4(d3, s2)
        d5 = self.cam5(self.d5(d4, s1))

        # Edge Attention
        enhanced_boundary = self.eam(s3, s4, d5)

        # Inverse Attention Mechanism
        # Each IAM output is added residually to the map it refines.
        ra1_out = self.iam1(d5, b1, enhanced_boundary)
        ra1_sum = ra1_out + d5

        ra2_out = self.iam2(ra1_sum, s5, enhanced_boundary)
        ra2_sum = ra2_out + ra1_sum

        # Sigmoid is applied here, so callers receive probabilities, not logits.
        output = torch.sigmoid(self.final_conv(ra2_sum))
        return output

## SECTION E: Optimizer and Training

In [None]:
# Input geometry used throughout the pipeline.
IMAGE_WIDTH, IMAGE_HEIGHT, CHANNELS = 384, 384, 3

# Build the network and move it to the selected device.
model = FEEINnet(dropout_rate=0.4)
model = model.to(device)
print(model)

# Training objective.
loss_fn = CombinedLoss()

# Adam with its standard beta coefficients.
learning_rate = 1e-3
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
)

# Cut the learning rate by 10x after 5 epochs without a lower monitored
# value, never going below 1e-6.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases - confirm against the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=5,
    min_lr=1e-6,
    verbose=True,
)

class EarlyStopping:
    """Flag when the validation loss has stopped improving.

    A call counts as an improvement when the loss is lower than the best
    value seen so far by more than ``min_delta``. After ``patience``
    consecutive calls without improvement ``early_stop`` becomes True.
    """

    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss):
        """Update the tracker with the latest validation loss."""
        threshold = self.best_loss - self.min_delta
        if val_loss < threshold:
            # New best: remember it and restart the patience window.
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

# Train for at most `num_epochs` epochs; stop earlier once the validation
# loss has not improved for 10 consecutive epochs.
num_epochs = 100
early_stopping = EarlyStopping(patience=10)

# Per-epoch history. "loss" is the training loss; every other entry is
# computed on the validation loader inside the loop below.
train_history = {
    "loss": [], "val_loss": [],
    "iou": [], "dice": [],
    "precision": [], "recall": [],
    "accuracy": [], "ssim": []
}

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    # Training phase
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Mean loss per batch (not per sample).
    train_loss /= len(train_loader)

    # Validation phase
    model.eval()
    val_loss = 0.0
    all_iou, all_dice, all_prec, all_rec, all_acc, all_ssim = [], [], [], [], [], []
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = loss_fn(outputs, masks)
            val_loss += loss.item()

            # NOTE(review): iou_score, dice_score, precision_score, recall_score,
            # accuracy_score and ssim_score are not defined in the visible part
            # of this file - presumably provided elsewhere; verify. They must
            # return values np.mean can average.
            all_iou.append(iou_score(outputs, masks))
            all_dice.append(dice_score(outputs, masks))
            all_prec.append(precision_score(outputs, masks))
            all_rec.append(recall_score(outputs, masks))
            all_acc.append(accuracy_score(outputs, masks))
            all_ssim.append(ssim_score(outputs, masks))

    val_loss /= len(val_loader)

    # Average the per-batch validation metrics for this epoch.
    metrics = {
        "iou": np.mean(all_iou),
        "dice": np.mean(all_dice),
        "precision": np.mean(all_prec),
        "recall": np.mean(all_rec),
        "accuracy": np.mean(all_acc),
        "ssim": np.mean(all_ssim)
    }

    print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
          f"IoU: {metrics['iou']:.4f} | Dice: {metrics['dice']:.4f} | "
          f"Precision: {metrics['precision']:.4f} | Recall: {metrics['recall']:.4f} | "
          f"Accuracy: {metrics['accuracy']:.4f} | SSIM: {metrics['ssim']:.4f}")

    # Step the scheduler
    scheduler.step(val_loss)

    # Save history
    train_history["loss"].append(train_loss)
    train_history["val_loss"].append(val_loss)
    for key in ["iou", "dice", "precision", "recall", "accuracy", "ssim"]:
        train_history[key].append(metrics[key])

    # Early stopping
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered.")
        break

    # Save best model
    # early_stopping has already folded this epoch in, so best_loss equals
    # val_loss exactly when this epoch set (or tied) the best value.
    if val_loss <= early_stopping.best_loss:
        torch.save(model.state_dict(), "best_feeinnet.pth")

## SECTION F: Evaluate the Model

In [None]:
# Reload best saved model
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("best_feeinnet.pth", map_location=device))
model.eval()

# Evaluation Function
def evaluate(model, dataloader, loss_fn):
    """Compute the mean loss and mean per-batch metrics over ``dataloader``.

    The model is put in eval mode and run without gradient tracking. Each
    batch is scored with the module-level metric helpers; per-batch values
    are averaged with ``np.mean``.

    Returns:
        A dict with keys "loss", "iou", "fscore", "accuracy", "precision",
        "recall" and "ssim".
    """
    model.eval()

    # Metric key -> scoring helper, in the order the helpers are invoked.
    scorers = {
        "iou": iou_score,
        "fscore": dice_score,  # F1 = Dice
        "precision": precision_score,
        "recall": recall_score,
        "accuracy": accuracy_score,
        "ssim": ssim_score,
    }
    per_batch = {metric_name: [] for metric_name in scorers}
    running_loss = 0

    with torch.no_grad():
        for batch_images, batch_masks in dataloader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            predictions = model(batch_images)
            running_loss += loss_fn(predictions, batch_masks).item()

            for metric_name, scorer in scorers.items():
                per_batch[metric_name].append(scorer(predictions, batch_masks))

    # Loss is averaged per batch, matching the other metrics.
    results = {"loss": running_loss / len(dataloader)}
    for metric_name in ("iou", "fscore", "accuracy", "precision", "recall", "ssim"):
        results[metric_name] = np.mean(per_batch[metric_name])
    return results

# Validation Evaluation
val_metrics = evaluate(model, val_loader, loss_fn)
print("Validation Results:")
for label, key in (
    ("Loss", "loss"),
    ("IOU", "iou"),
    ("FScore", "fscore"),
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("SSIM", "ssim"),
):
    print(f"Validation {label}: {val_metrics[key]:.4f}")

# Test Evaluation
test_metrics = evaluate(model, test_loader, loss_fn)
print("\nTest Results:")
for label, key in (
    ("Loss", "loss"),
    ("IOU", "iou"),
    ("FScore", "fscore"),
    ("Accuracy", "accuracy"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("SSIM", "ssim"),
):
    print(f"Test {label}: {test_metrics[key]:.4f}")


## SECTION G: Visualize Loss and IoU Curves

In [None]:
# Directory to save outputs
output_dir = r"D:\20BPS1134\Implementation\RA\3"
os.makedirs(output_dir, exist_ok=True)

# Plot Training & Validation Loss
plt.figure(figsize=(12, 6))
plt.plot(train_history["loss"], label="Train Loss")
plt.plot(train_history["val_loss"], label="Validation Loss")
plt.title("Model Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(loc="upper right")
plt.grid(True)
plt.savefig(os.path.join(output_dir, "loss.png"))
plt.show()

# Plot Validation IoU
# train_history["iou"] is filled from the validation loop only; no training
# IoU is recorded. The curve is therefore labelled as validation IoU, and the
# validation *loss* is no longer drawn on the IoU axes.
plt.figure(figsize=(12, 6))
plt.plot(train_history["iou"], label="Validation IoU")
plt.title("Model IoU Score")
plt.xlabel("Epoch")
plt.ylabel("IoU Score")
plt.legend(loc="upper left")
plt.grid(True)
plt.savefig(os.path.join(output_dir, "iou_score.png"))
plt.show()


## SECTION H: Visualize Results for 3 Images

In [None]:
from PIL import Image

# Function to load and preprocess image
def load_image(image_path):
    """Load an RGB image as a 384x384 float tensor.

    Bicubic interpolation is used so the preprocessing matches the image
    transform of ``SkinLesionDataset``; the model then sees visualisation
    inputs prepared the same way as its training and evaluation inputs.

    Args:
        image_path: Path of the image file.

    Returns:
        The tensor produced by ``T.ToTensor`` after resizing.
    """
    transform = T.Compose([
        T.Resize((384, 384), interpolation=Image.BICUBIC),
        T.ToTensor()
    ])
    # The context manager closes the file handle; convert() returns a fully
    # loaded copy that stays valid afterwards.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")
    return transform(img)

# Function to load mask
def load_mask(mask_path):
    """Load a segmentation mask as a binary 384x384 float tensor.

    Nearest-neighbour interpolation is used so the preprocessing matches the
    mask transform of ``SkinLesionDataset`` and resizing cannot introduce
    intermediate grey values along the lesion boundary. The result is
    thresholded at 0.5 into {0.0, 1.0}.

    Args:
        mask_path: Path of the mask file.

    Returns:
        A float tensor containing only 0.0 and 1.0.
    """
    transform = T.Compose([
        T.Resize((384, 384), interpolation=Image.NEAREST),
        T.ToTensor()
    ])
    # The context manager closes the file handle; convert() returns a fully
    # loaded copy that stays valid afterwards.
    with Image.open(mask_path) as raw:
        mask = raw.convert("L")
    mask = transform(mask)
    mask = (mask >= 0.5).float()
    return mask

# Function to visualize predictions
def plot_predictions(model, image_tensor, mask_tensor, output_dir, img_index):
    """Plot the input image, ground-truth mask and predicted mask side by side.

    The model is expected to return probabilities, as ``FEEINnet.forward``
    does (it applies a sigmoid itself). No further sigmoid is applied here:
    a second one would map every value into roughly [0.5, 0.73], so the
    binarised mask would always be entirely foreground.

    Args:
        model: Segmentation model returning a (1, 1, H, W) probability map.
        image_tensor: (3, H, W) image tensor.
        mask_tensor: Ground-truth mask tensor.
        output_dir: Directory where "result_<img_index>.png" is written.
        img_index: Number used in the output file name.
    """
    model.eval()
    with torch.no_grad():
        image_input = image_tensor.unsqueeze(0).to(device)  # Add batch dimension
        pred_prob = model(image_input)[0, 0].cpu().numpy()
    # Same 0.5 threshold that is used to binarise the ground-truth masks.
    pred_mask = (pred_prob >= 0.5).astype(np.float32)

    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    # .cpu() is a no-op for CPU tensors and makes device tensors plottable.
    ax[0].imshow(np.transpose(image_tensor.cpu().numpy(), (1, 2, 0)))
    ax[0].set_title('Input Image')
    ax[1].imshow(mask_tensor.squeeze().cpu().numpy(), cmap='gray')
    ax[1].set_title('True Mask')
    ax[2].imshow(pred_mask, cmap='gray')
    ax[2].set_title('Predicted Mask')

    result_path = os.path.join(output_dir, f'result_{img_index}.png')
    plt.savefig(result_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)

# Output directory
output_dir = r"D:\20BPS1134\Implementation\RA\3\HAM10000_Results"
os.makedirs(output_dir, exist_ok=True)

# Visualize predictions for 3 random test images
# All indices are drawn in one call without replacement, so the same image
# cannot be shown twice. Fewer than 3 are drawn if the test split is smaller.
num_samples = min(3, len(test_ids))
sample_indices = np.random.choice(len(test_ids), size=num_samples, replace=False)

for i, sample_index in enumerate(sample_indices, start=1):
    sample_id = test_ids[sample_index]
    sample_image_path = os.path.join(test_image_dir, sample_id + '.jpg')
    sample_mask_path = os.path.join(test_mask_dir, sample_id + '.png')

    sample_image = load_image(sample_image_path)
    sample_mask = load_mask(sample_mask_path)

    plot_predictions(model, sample_image, sample_mask, output_dir, img_index=i)
