In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import cv2
import os
from glob import glob
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import albumentations as A
import torch.nn.functional as F
import torchvision.transforms as T
from albumentations.pytorch import ToTensorV2
from PIL import Image



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class EdgeSegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, target_size=(512, 512), transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.target_size = target_size
        self.transform = transform

        
        image_extensions = ["*.jpg", "*.jpeg", "*.png"]
        image_files = []
        for ext in image_extensions:
            image_files.extend(glob(os.path.join(image_dir, ext)))
            
        mask_files = glob(os.path.join(mask_dir, "*_mask.png"))
        
        image_dict = {os.path.splitext(os.path.basename(f))[0]: f for f in image_files}
        mask_dict = {os.path.splitext(os.path.basename(f))[0].replace('_mask', ''): f for f in mask_files}
        
        common_keys = set(image_dict.keys()) & set(mask_dict.keys())
        self.pairs = [(image_dict[key], mask_dict[key]) for key in common_keys]
        print(f"Found {len(self.pairs)} valid image-mask pairs")
        
    def __len__(self):
        return len(self.pairs)
        
    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]

        # Load image and mask in grayscale as NumPy arrays
        image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        if image is None or mask is None:
            raise FileNotFoundError(f"Failed to load image or mask at {img_path} or {mask_path}")

        # Resize using cv2
        image = cv2.resize(image, (self.target_size[1], self.target_size[0]), interpolation=cv2.INTER_AREA)
        mask = cv2.resize(mask, (self.target_size[1], self.target_size[0]), interpolation=cv2.INTER_NEAREST)

        # Apply augmentations with albumentations 
        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        # Ensure image and mask are NumPy arrays before converting to tensors
        if not isinstance(image, np.ndarray) or not isinstance(mask, np.ndarray):
            raise TypeError(f"Expected NumPy arrays after transform, got {type(image)} for image and {type(mask)} for mask")

        # Manually convert NumPy arrays to PyTorch tensors
        image = torch.from_numpy(image).float()  # Convert to float tensor
        mask = torch.from_numpy(mask).float()
        image = image.unsqueeze(0)  # Add channel dim: (H, W) -> (1, H, W)
        mask = mask.unsqueeze(0)

        # Normalize image to [0, 1]
        image = image / 255.0

        # Binarize mask, original range was 0, 255
        mask = (mask > 127.5).float()

        return image, mask, img_path, mask_path

In [None]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 2, dilation=2, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 2, dilation=2, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class UNET(nn.Module):
    def __init__(
            self, in_channels=1, out_channels=1, features=[64, 128, 256, 512],
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(
                    feature*2, feature, kernel_size=2, stride=2,
                )
            )
            self.ups.append(DoubleConv(feature*2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1]*2)
        # self.bottleneck = nn.Sequential(
        #         nn.Conv2d(features[-1], features[-1]*2, kernel_size=3, padding=2, dilation=2, bias=False),
        #         nn.BatchNorm2d(features[-1]*2),
        #         nn.ReLU(inplace=True),
        #         nn.Conv2d(features[-1]*2, features[-1]*2, kernel_size=3, padding=2, dilation=2, bias=False),
        #         nn.BatchNorm2d(features[-1]*2),
        #         nn.ReLU(inplace=True)
        #     )
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip_connection = skip_connections[idx//2]

            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx+1](concat_skip)

        return self.final_conv(x)

In [None]:
# %%
def compute_accuracy(outputs, masks):
    preds = torch.sigmoid(outputs) > 0.5  # Apply sigmoid for prob, then threshold
    correct = (preds == masks).float()
    accuracy = correct.sum() / correct.numel()
    return accuracy.item()

# %%
def evaluate_model(model, data_loader, device, criterion):
    model.eval()
    total_loss = 0.0
    total_accuracy = 0.0
    with torch.no_grad():
        for images, masks, _, _ in data_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            acc = compute_accuracy(outputs, masks)
            total_loss += loss.item()
            total_accuracy += acc
    avg_loss = total_loss / len(data_loader)
    avg_accuracy = total_accuracy / len(data_loader)
    return avg_loss, avg_accuracy

In [None]:
class OutlineConnectivityLoss(nn.Module):
    def __init__(self, smooth=1, connectivity_weight=0.1):
        super(OutlineConnectivityLoss, self).__init__()
        self.smooth = smooth
        self.connectivity_weight = connectivity_weight

    def dice_loss(self, inputs, targets):
        inputs_flat = inputs.view(-1)
        targets_flat = targets.view(-1)
        intersection = (inputs_flat * targets_flat).sum()
        return 1 - (2. * intersection + self.smooth) / (inputs_flat.sum() + targets_flat.sum() + self.smooth)

    def connectivity_penalty(self, inputs, targets):
        # Penalty on overlapping regions
        grad_h_overlap = torch.abs((inputs * targets)[:, :, 1:, :] - (inputs * targets)[:, :, :-1, :])
        grad_w_overlap = torch.abs((inputs * targets)[:, :, :, 1:] - (inputs * targets)[:, :, :, :-1])

        # Penalty on non-overlapping predicted regions (false positives)
        non_overlap = inputs * (1 - targets)  # Predicted regions that don't overlap with ground truth
        grad_h_non_overlap = torch.abs(non_overlap[:, :, 1:, :] - non_overlap[:, :, :-1, :])
        grad_w_non_overlap = torch.abs(non_overlap[:, :, :, 1:] - non_overlap[:, :, :, :-1])

        # Combine both penalties
        penalty_overlap = (grad_h_overlap.sum() + grad_w_overlap.sum()) / inputs.numel()
        penalty_non_overlap = (grad_h_non_overlap.sum() + grad_w_non_overlap.sum()) / inputs.numel()
        return penalty_overlap + penalty_non_overlap

    def forward(self, inputs, targets):
        inputs = torch.sigmoid(inputs)
        dice = self.dice_loss(inputs, targets)
        penalty = self.connectivity_penalty(inputs, targets)
        return dice + self.connectivity_weight * penalty

In [7]:
def train_model(model, train_loader, val_loader, num_epochs, device):
    model = model.to(device)
    # total_positives = 0
    # total_negatives = 0
    # for _, mask, _, _ in train_loader:
    #     mask_np = mask.numpy().flatten()
    #     positives = np.sum(mask_np == 1)
    #     negatives = np.sum(mask_np == 0)
    #     total_positives += positives
    #     total_negatives += negatives
    # pos_weight_value = total_negatives / total_positives if total_positives > 0 else 1.0
    # print(f"Using pos_weight: {pos_weight_value:.2f}")
    pos_weight = torch.tensor([2.0]).to(device)
    
    criterion_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    criterion_conn = OutlineConnectivityLoss(smooth=1, connectivity_weight=0.1)
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)
        for images, masks, _, _ in train_loop:
            images, masks = images.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss_bce = criterion_bce(outputs, masks)
            loss_conn = criterion_conn(outputs, masks)
            loss = 0.5 * loss_bce + 0.5 * loss_conn  # Combine BCE and Dice+Connectivity
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)

        train_loss /= len(train_loader.dataset)

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, masks, _, _ in train_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss_bce = criterion_bce(outputs, masks)
                loss_conn = criterion_conn(outputs, masks)
                loss = 0.5 * loss_bce + 0.5 * loss_conn
                val_loss += loss.item() * images.size(0)

        val_loss /= len(val_loader.dataset)

        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

    return model

In [None]:
IMAGE_HEIGHT = 512
IMAGE_WIDTH = 512

train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
    ],
)

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
    ],
)

image_dir = "images/"
mask_dir = "outputMasks/"
batch_size = 4
num_epochs = 80
test_split = 0.2  # 20% for testing
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loading
train_dataset = EdgeSegmentationDataset(image_dir, mask_dir, transform=train_transform)
val_dataset = EdgeSegmentationDataset(image_dir, mask_dir, transform=val_transforms)

# Train-test split on indices
indices = list(range(len(train_dataset)))
train_indices, test_indices = train_test_split(
    indices, test_size=test_split, random_state=42
)

train_subset = Subset(train_dataset, train_indices)
val_subset = Subset(val_dataset, test_indices)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

print(f"Training samples: {len(train_subset)}, Test samples: {len(val_subset)}")

# Model
model = UNET().to(device)

# Train and test
model = train_model(model, train_loader, test_loader, num_epochs=num_epochs, device=device)

# Save the model
torch.save(model.state_dict(), "edge_segmentation_model.pth")
print("Model saved as edge_segmentation_model.pth")

Found 244 valid image-mask pairs
Found 244 valid image-mask pairs
Training samples: 195, Test samples: 49


Epoch 1/50: 100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 1/50 - Train Loss: 0.6988, Val Loss: 2.7976


Epoch 2/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 2/50 - Train Loss: 0.6133, Val Loss: 2.4367


Epoch 3/50: 100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 3/50 - Train Loss: 0.5761, Val Loss: 2.2299


Epoch 4/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 4/50 - Train Loss: 0.5452, Val Loss: 2.0645


Epoch 5/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 5/50 - Train Loss: 0.5177, Val Loss: 1.9982


Epoch 6/50: 100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 6/50 - Train Loss: 0.4951, Val Loss: 1.8898


Epoch 7/50: 100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 7/50 - Train Loss: 0.4731, Val Loss: 1.8241


Epoch 8/50: 100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 8/50 - Train Loss: 0.4558, Val Loss: 1.7847


Epoch 9/50: 100%|██████████| 49/49 [00:28<00:00,  1.74it/s]


Epoch 9/50 - Train Loss: 0.4356, Val Loss: 1.6345


Epoch 10/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 10/50 - Train Loss: 0.4173, Val Loss: 1.6230


Epoch 11/50: 100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 11/50 - Train Loss: 0.3992, Val Loss: 1.5365


Epoch 12/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 12/50 - Train Loss: 0.3806, Val Loss: 1.4902


Epoch 13/50: 100%|██████████| 49/49 [00:26<00:00,  1.85it/s]


Epoch 13/50 - Train Loss: 0.3651, Val Loss: 1.4113


Epoch 14/50: 100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 14/50 - Train Loss: 0.3514, Val Loss: 1.3677


Epoch 15/50: 100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 15/50 - Train Loss: 0.3378, Val Loss: 1.2798


Epoch 16/50: 100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 16/50 - Train Loss: 0.3207, Val Loss: 1.2401


Epoch 17/50: 100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 17/50 - Train Loss: 0.3047, Val Loss: 1.1854


Epoch 18/50: 100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 18/50 - Train Loss: 0.2961, Val Loss: 1.1153


Epoch 19/50: 100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 19/50 - Train Loss: 0.2815, Val Loss: 1.1065


Epoch 20/50: 100%|██████████| 49/49 [00:26<00:00,  1.88it/s]


Epoch 20/50 - Train Loss: 0.2679, Val Loss: 0.9991


Epoch 21/50: 100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 21/50 - Train Loss: 0.2551, Val Loss: 1.0030


Epoch 22/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 22/50 - Train Loss: 0.2542, Val Loss: 0.9353


Epoch 23/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 23/50 - Train Loss: 0.2452, Val Loss: 0.9496


Epoch 24/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 24/50 - Train Loss: 0.2304, Val Loss: 0.9181


Epoch 25/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 25/50 - Train Loss: 0.2274, Val Loss: 0.9036


Epoch 26/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 26/50 - Train Loss: 0.2237, Val Loss: 0.8290


Epoch 27/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 27/50 - Train Loss: 0.2126, Val Loss: 0.8168


Epoch 28/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 28/50 - Train Loss: 0.2085, Val Loss: 0.8048


Epoch 29/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 29/50 - Train Loss: 0.2003, Val Loss: 0.7432


Epoch 30/50: 100%|██████████| 49/49 [00:25<00:00,  1.89it/s]


Epoch 30/50 - Train Loss: 0.1909, Val Loss: 0.7370


Epoch 31/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 31/50 - Train Loss: 0.1879, Val Loss: 0.7078


Epoch 32/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 32/50 - Train Loss: 0.1821, Val Loss: 0.7120


Epoch 33/50: 100%|██████████| 49/49 [00:25<00:00,  1.90it/s]


Epoch 33/50 - Train Loss: 0.1796, Val Loss: 0.6654


Epoch 34/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 34/50 - Train Loss: 0.1766, Val Loss: 0.7202


Epoch 35/50: 100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 35/50 - Train Loss: 0.1797, Val Loss: 0.6940


Epoch 36/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 36/50 - Train Loss: 0.1702, Val Loss: 0.6711


Epoch 37/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 37/50 - Train Loss: 0.1691, Val Loss: 0.6445


Epoch 38/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 38/50 - Train Loss: 0.1662, Val Loss: 0.6198


Epoch 39/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 39/50 - Train Loss: 0.1622, Val Loss: 0.6215


Epoch 40/50: 100%|██████████| 49/49 [00:26<00:00,  1.87it/s]


Epoch 40/50 - Train Loss: 0.1553, Val Loss: 0.5936


Epoch 41/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 41/50 - Train Loss: 0.1522, Val Loss: 0.5886


Epoch 42/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 42/50 - Train Loss: 0.1532, Val Loss: 0.5631


Epoch 43/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 43/50 - Train Loss: 0.1474, Val Loss: 0.5803


Epoch 44/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 44/50 - Train Loss: 0.1466, Val Loss: 0.5560


Epoch 45/50: 100%|██████████| 49/49 [00:25<00:00,  1.91it/s]


Epoch 45/50 - Train Loss: 0.1518, Val Loss: 0.5902


Epoch 46/50: 100%|██████████| 49/49 [00:25<00:00,  1.92it/s]


Epoch 46/50 - Train Loss: 0.1453, Val Loss: 0.5458


Epoch 47/50: 100%|██████████| 49/49 [00:26<00:00,  1.82it/s]


Epoch 47/50 - Train Loss: 0.1451, Val Loss: 0.5641


Epoch 48/50: 100%|██████████| 49/49 [00:28<00:00,  1.70it/s]


Epoch 48/50 - Train Loss: 0.1478, Val Loss: 0.5571


Epoch 49/50: 100%|██████████| 49/49 [00:26<00:00,  1.83it/s]


Epoch 49/50 - Train Loss: 0.1408, Val Loss: 0.5390


Epoch 50/50: 100%|██████████| 49/49 [00:26<00:00,  1.83it/s]


Epoch 50/50 - Train Loss: 0.1370, Val Loss: 0.5447
Model saved as edge_segmentation_model.pth


In [None]:
from skimage.segmentation import active_contour
model.eval()
test_samples = []
with torch.no_grad():
    for images, masks, img_paths, mask_paths in test_loader:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)
        outputs = torch.sigmoid(outputs)  # Ensure outputs are probabilities (0-1)
        test_samples.extend(zip(images.cpu(), masks.cpu(), outputs.cpu(), img_paths, mask_paths))
        if len(test_samples) >= 5:  # Limit to 5
            break

def create_outer_frame_contour(image_shape, border_offset=5, num_points=100):
    h, w = image_shape
    # Define the rectangle just inside the image borders
    top_left = (border_offset, border_offset)
    top_right = (w - border_offset, border_offset)
    bottom_right = (w - border_offset, h - border_offset)
    bottom_left = (border_offset, h - border_offset)

    # Create points along each side of the rectangle
    points_per_side = num_points // 4  # Distribute points across 4 sides
    top = np.linspace(top_left, top_right, points_per_side)
    right = np.linspace(top_right, bottom_right, points_per_side)
    bottom = np.linspace(bottom_right, bottom_left, points_per_side)
    left = np.linspace(bottom_left, top_left, points_per_side)

    # Concatenate the points to form a closed contour
    initial_contour = np.vstack([top, right, bottom, left])
    return initial_contour

for i, (image, true_mask, pred, img_path, mask_path) in enumerate(test_samples[:5]):
    image = image.permute(1, 2, 0).numpy()  # CHW to HWC
    
    # True mask processing
    true_mask_np = true_mask.squeeze().numpy()  # Remove channel dim
    print(f"Sample {i+1} - True Mask Stats: min={true_mask_np.min():.2f}, max={true_mask_np.max():.2f}, mean={true_mask_np.mean():.2f}")
    true_mask_display = (true_mask_np * 255).astype(np.uint8)  # Scale to 0-255
    
    # Predicted mask processing
    pred_np = pred.squeeze().numpy()  # Remove channel dim
    print(f"Sample {i+1} - Pred Mask Stats: min={pred_np.min():.2f}, max={pred_np.max():.2f}, mean={pred_np.mean():.2f}")
    pred_mask_display = (pred_np * 255).astype(np.uint8)  
    # pred_mask_display = (pred_np > 0.5 ).astype(np.uint8) * 255  

    contours_pred, _ = cv2.findContours(pred_mask_display, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    pred_contour = max(contours_pred, key=cv2.contourArea).squeeze() if contours_pred else None

    if pred_contour is not None:
        print(f"Sample {i+1} - pred_contour shape: {pred_contour.shape}")
        # Ensure pred_contour has the correct shape (N, 2)
        if len(pred_contour.shape) == 1:  # If 1D, reshape it
            pred_contour = pred_contour.reshape(-1, 2)
            print("No contour found")
        elif pred_contour.shape[1] != 2:  # If shape is not (N, 2), handle error
            print(f"Sample {i+1} - Invalid pred_contour shape: {pred_contour.shape}")
            pred_contour = None


        if pred_contour is not None:
            if not np.array_equal(pred_contour[0], pred_contour[-1]):
                print(f"Sample {i+1} - Initial contour not closed, forcing closure")
                pred_contour = np.vstack([pred_contour, pred_contour[0]])
    
    outer_contour = create_outer_frame_contour(pred_mask_display.shape, border_offset=5, num_points=200)

    if outer_contour is not None:

        # Parameters for active_contour:
        # - alpha: Controls the snake's elasticity (length penalty, positive to shrink)
        # - beta: Controls the snake's rigidity (smoothness penalty)
        # - gamma: Step size for iteration
        try:
            snake_contour = active_contour(
                pred_mask_display,  
                outer_contour,     
                alpha=0.05,       # Positive alpha to encourage shrinking
                beta=.01,          # Rigidity for smoothness
                gamma=0.001,      # Step size
            )
            # Ensure the snake contour is closed
            snake_contour = np.vstack([snake_contour, snake_contour[0]])
        except Exception as e:
            print(f"Sample {i+1} - Active contour failed: {str(e)}")
            snake_contour = None
    else:
        snake_contour = None


    plt.figure(figsize=(15, 5))
    
    plt.subplot(1, 4, 1)
    plt.title("Input Image")
    plt.imshow(image, cmap='gray')
    plt.axis('off')
    
    plt.subplot(1, 4, 2)
    plt.title("True Mask")
    plt.imshow(true_mask_display, cmap='gray', vmin=0, vmax=255)  
    plt.axis('off')
    
    plt.subplot(1, 4, 3)
    plt.title("Predicted Mask")
    plt.imshow(pred_mask_display, cmap='gray', vmin=0, vmax=255)  
    if pred_contour is not None:
        plt.plot(pred_contour[:, 0], pred_contour[:, 1], 'b-', linewidth=2, label='Initial Contour')  # Blue for initial contour
    if snake_contour is not None:
        plt.plot(snake_contour[:, 0], snake_contour[:, 1], 'r-', linewidth=2, label='Snake Contour')  # Blue for initial contour

    plt.axis('off')

    
    plt.suptitle(f"Sample {i+1}: {os.path.basename(img_path)}")
    plt.tight_layout()
    plt.show()