In [None]:
from google.colab import drive
# Mount Google Drive; the dataset paths configured later live under /content/drive/MyDrive.
drive.mount('/content/drive')

# Fetch the SAM2 sources into the current working directory.
# NOTE(review): running this cell again after a successful clone makes git complain that
# the destination already exists; the notebook keeps going, but the message is noisy.
!git clone https://github.com/facebookresearch/sam2.git


In [None]:
cd sam2

In [None]:
!pip install setuptools

In [None]:
!pip install -e /content/sam2

In [None]:
import sys
# Make a Drive-hosted copy of the package importable.
# NOTE(review): the repository is cloned and installed from /content/sam2, not from
# Drive, so this entry looks redundant unless a second copy lives on Drive — confirm.
sys.path.append("/content/drive/MyDrive/sam2")

In [None]:
import random
import numpy as np
import os
import pandas as pd
import cv2
import torch
import torch.nn.utils
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# Seed every random number generator the notebook draws from so runs are repeatable
SEED = 3
for _seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    _seed_fn(SEED)

In [None]:
# Dataset and output locations
ROOT_PATH = "/content/drive/MyDrive/preprocessed/train/"
IMAGES_DIR = os.path.join(ROOT_PATH, "images/rgb/")
MASKS_DIR = os.path.join(ROOT_PATH, "masks/")
CHECKPOINT_DIR = "./checkpoints/"
FINE_TUNED_MODEL_DIR = "./"

# PNG file names in each folder, sorted so that image i lines up with mask i.
# NOTE(review): the pairing is purely positional; it presumes both folders hold the
# same number of files in matching sort order — verify against the dataset.
image_files = sorted(name for name in os.listdir(IMAGES_DIR) if name.endswith(".png"))
mask_files = sorted(name for name in os.listdir(MASKS_DIR) if name.endswith(".png"))

# One row per (image, mask) pair with both the bare names and the full paths
data_df = pd.DataFrame({"ImageId": image_files, "MaskId": mask_files})
data_df["image_path"] = [os.path.join(IMAGES_DIR, name) for name in image_files]
data_df["mask_path"] = [os.path.join(MASKS_DIR, name) for name in mask_files]

# Hold out 20% of the pairs for validation
train_df, val_df = train_test_split(data_df, test_size=0.2, random_state=42)


def _as_records(frame):
    """Turn a split DataFrame into the list of {"image", "annotation"} dicts used downstream."""
    return [
        {"image": image_path, "annotation": mask_path}
        for image_path, mask_path in zip(frame["image_path"], frame["mask_path"])
    ]


train_data = _as_records(train_df)
val_data = _as_records(val_df)

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

In [None]:
# Function to visualize segmentation masks
def show_anns(anns, borders=True):
    """Overlay annotation masks on the current matplotlib axes.

    Args:
        anns: sequence of dicts, each with an 'area' value (used for ordering) and a
            'segmentation' array used as a boolean index into the overlay.
        borders: when True, also draw the outline of every mask.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    # Largest masks first, so smaller ones painted later stay visible on top.
    by_area = sorted(anns, key=lambda entry: entry['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    # RGBA canvas that starts fully transparent.
    height, width = by_area[0]['segmentation'].shape[:2]
    overlay = np.ones((height, width, 4))
    overlay[:, :, 3] = 0

    for entry in by_area:
        seg = entry['segmentation']
        # Random opaque colour for this mask.
        overlay[seg] = np.concatenate([np.random.random(3), [1]])
        if not borders:
            continue
        import cv2
        outlines, _ = cv2.findContours(seg.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        outlines = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
        cv2.drawContours(overlay, outlines, -1, (0, 0, 1, 0.4), thickness=1)

    axes.imshow(overlay)

In [None]:
# Improved read_batch function
def read_batch(data, visualize_data=False):
    """Load one random sample and build the mask and point prompts for it.

    Args:
        data: list of dicts with an "image" path and an "annotation" (mask) path.
        visualize_data: when True, plot the image, the binarized mask and the
            sampled prompt points.

    Returns:
        ``(img, binary_mask, points, num_masks)`` where ``img`` is the resized RGB
        image, ``binary_mask`` has shape (1, H, W) with every non-zero label merged
        into foreground, ``points`` has shape (num_masks, 1, 2) in (x, y) order and
        ``num_masks`` is the number of distinct non-zero labels. Returns
        ``(None, None, None, 0)`` when a file cannot be read, the mask has no
        foreground, or no point survives the erosion.
    """
    # Select a random entry
    ent = data[np.random.randint(len(data))]

    img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)  # Read annotation as grayscale

    # imread signals failure by returning None, so check before touching the arrays.
    if img is None or ann_map is None:
        print(f"Error: Could not read image or mask from path {ent['image']} or {ent['annotation']}")
        return None, None, None, 0

    img = img[..., ::-1]  # Convert BGR to RGB

    # Resize image and mask so the longer side becomes 1024 pixels
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])  # Scaling factor
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    # Nearest-neighbour keeps label values intact instead of blending them.
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),
                         interpolation=cv2.INTER_NEAREST)

    # Distinct foreground labels. Filter on the value 0 explicitly: slicing off the
    # first unique value would drop a real label whenever the mask has no background.
    inds = np.unique(ann_map)
    inds = inds[inds != 0]
    if len(inds) == 0:  # Handle case with no segmentation
        return None, None, None, 0

    # Merge every labelled region into a single binary foreground mask
    binary_mask = (ann_map != 0).astype(np.uint8)

    # Erode the combined binary mask to avoid boundary points
    eroded_mask = cv2.erode(binary_mask, np.ones((5, 5), np.uint8), iterations=1)

    # Get all coordinates inside the eroded mask and choose random points
    points = []
    coords = np.argwhere(eroded_mask > 0)
    if len(coords) > 0:
        for _ in inds:  # Select as many points as there are unique labels
            yx = coords[np.random.randint(len(coords))]
            points.append([yx[1], yx[0]])  # argwhere yields (row, col); prompts use (x, y)

    points = np.array(points)
    if len(points) == 0:  # Handle case with no valid points
        return None, None, None, 0

    if visualize_data:
        # Plotting the images and points
        plt.figure(figsize=(15, 5))

        # Original Image
        plt.subplot(1, 3, 1)
        plt.title('Original Image')
        plt.imshow(img)
        plt.axis('off')

        # Segmentation Mask (binary_mask)
        plt.subplot(1, 3, 2)
        plt.title('Binarized Mask')
        plt.imshow(binary_mask, cmap='gray')
        plt.axis('off')

        # Mask with Points in Different Colors
        plt.subplot(1, 3, 3)
        plt.title('Binarized Mask with Points')
        plt.imshow(binary_mask, cmap='gray')

        # Plot points in different colors
        colors = list(mcolors.TABLEAU_COLORS.values())
        for i, point in enumerate(points):
            plt.scatter(point[0], point[1], c=colors[i % len(colors)], s=100, label=f'Point {i+1}')

        plt.axis('off')
        plt.tight_layout()
        plt.show()

    binary_mask = binary_mask[np.newaxis, ...]  # (H, W) -> (1, H, W)
    points = np.expand_dims(points, axis=1)  # (N, 2) -> (N, 1, 2)

    # Return the image, binarized mask, points, and number of masks
    return img, binary_mask, points, len(inds)

In [None]:
# Evaluation function to compute validation metrics
def evaluate_model(predictor, data, num_samples=50):
    """Average the training loss and IoU over the first ``num_samples`` entries of ``data``.

    NOTE(review): this notebook defines ``evaluate_model`` three times. The last
    definition, which also returns pixel accuracy, is the one in effect when the
    training loop runs, so this two-value variant is shadowed.

    Args:
        predictor: object exposing the SAM2ImagePredictor members used below
            (``set_image``, ``_prep_prompts``, ``model``, ``_features``,
            ``_transforms``, ``_orig_hw``).
        data: list of ``{"image": path, "annotation": path}`` dicts.
        num_samples: upper bound on how many entries are evaluated.

    Returns:
        ``(mean_loss, mean_iou)`` over the samples that produced a prediction, or
        ``(0, 0)`` when every sample was skipped.
    """
    total_iou = 0
    total_loss = 0
    valid_samples = 0

    with torch.no_grad():
        for i in range(min(num_samples, len(data))):
            # A one-element list makes read_batch load exactly data[i].
            image, mask, input_point, num_masks = read_batch([data[i]], visualize_data=False)
            if image is None or mask is None or num_masks == 0:
                continue

            input_label = np.ones((num_masks, 1))  # every prompt is a positive click
            if not isinstance(input_point, np.ndarray) or input_point.size == 0:
                continue

            # Set image and get predictions
            predictor.set_image(image)
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )

            if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
                continue

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None,
            )

            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            # Put the ground truth on whatever device the prediction lives on, so the
            # CPU fallback works instead of failing on a hard-coded .cuda().
            gt_mask = torch.tensor(mask.astype(np.float32), device=prd_masks.device)
            prd_mask = torch.sigmoid(prd_masks[:, 0])  # first of the multimask outputs
            hard_mask = prd_mask > 0.5

            # Binary cross-entropy; the epsilon keeps log() finite
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

            # IoU per prompt (gt_mask broadcasts across the prompt dimension)
            inter = (gt_mask * hard_mask).sum(1).sum(1)
            union = gt_mask.sum(1).sum(1) + hard_mask.sum(1).sum(1) - inter
            iou = inter / (union + 1e-6)

            # Penalise the gap between the predicted quality score and the real IoU
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
            loss = seg_loss + score_loss * 0.05

            total_loss += loss.item()
            total_iou += iou.mean().item()
            valid_samples += 1

    if valid_samples == 0:
        return 0, 0

    return total_loss / valid_samples, total_iou / valid_samples

In [None]:
# Evaluation function to compute validation metrics
def evaluate_model(predictor, data, num_samples=50):
    """Average the training loss and IoU over the first ``num_samples`` entries of ``data``.

    NOTE(review): this cell repeats the ``evaluate_model`` definition from the
    previous cell, and a later cell redefines the function once more with an extra
    pixel-accuracy return value. Only that last definition is in effect when the
    training loop runs.

    Args:
        predictor: object exposing the SAM2ImagePredictor members used below
            (``set_image``, ``_prep_prompts``, ``model``, ``_features``,
            ``_transforms``, ``_orig_hw``).
        data: list of ``{"image": path, "annotation": path}`` dicts.
        num_samples: upper bound on how many entries are evaluated.

    Returns:
        ``(mean_loss, mean_iou)`` over the samples that produced a prediction, or
        ``(0, 0)`` when every sample was skipped.
    """
    total_iou = 0
    total_loss = 0
    valid_samples = 0

    with torch.no_grad():
        for i in range(min(num_samples, len(data))):
            # A one-element list makes read_batch load exactly data[i].
            image, mask, input_point, num_masks = read_batch([data[i]], visualize_data=False)
            if image is None or mask is None or num_masks == 0:
                continue

            input_label = np.ones((num_masks, 1))  # every prompt is a positive click
            if not isinstance(input_point, np.ndarray) or input_point.size == 0:
                continue

            # Set image and get predictions
            predictor.set_image(image)
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )

            if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
                continue

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None,
            )

            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            # Put the ground truth on whatever device the prediction lives on, so the
            # CPU fallback works instead of failing on a hard-coded .cuda().
            gt_mask = torch.tensor(mask.astype(np.float32), device=prd_masks.device)
            prd_mask = torch.sigmoid(prd_masks[:, 0])  # first of the multimask outputs
            hard_mask = prd_mask > 0.5

            # Binary cross-entropy; the epsilon keeps log() finite
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

            # IoU per prompt (gt_mask broadcasts across the prompt dimension)
            inter = (gt_mask * hard_mask).sum(1).sum(1)
            union = gt_mask.sum(1).sum(1) + hard_mask.sum(1).sum(1) - inter
            iou = inter / (union + 1e-6)

            # Penalise the gap between the predicted quality score and the real IoU
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
            loss = seg_loss + score_loss * 0.05

            total_loss += loss.item()
            total_iou += iou.mean().item()
            valid_samples += 1

    if valid_samples == 0:
        return 0, 0

    return total_loss / valid_samples, total_iou / valid_samples

In [None]:
# Add accuracy calculation to the evaluation function
def evaluate_model(predictor, data, num_samples=50):
    """Average loss, IoU and pixel accuracy over the first ``num_samples`` entries of ``data``.

    Args:
        predictor: object exposing the SAM2ImagePredictor members used below
            (``set_image``, ``_prep_prompts``, ``model``, ``_features``,
            ``_transforms``, ``_orig_hw``).
        data: list of ``{"image": path, "annotation": path}`` dicts.
        num_samples: upper bound on how many entries are evaluated.

    Returns:
        ``(mean_loss, mean_iou, mean_accuracy)`` over the samples that produced a
        prediction, or ``(0, 0, 0)`` when every sample was skipped.
    """
    total_iou = 0
    total_loss = 0
    total_accuracy = 0
    valid_samples = 0

    with torch.no_grad():
        for i in range(min(num_samples, len(data))):
            # A one-element list makes read_batch load exactly data[i].
            image, mask, input_point, num_masks = read_batch([data[i]], visualize_data=False)
            if image is None or mask is None or num_masks == 0:
                continue

            input_label = np.ones((num_masks, 1))  # every prompt is a positive click
            if not isinstance(input_point, np.ndarray) or input_point.size == 0:
                continue

            # Set image and get predictions
            predictor.set_image(image)
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )

            if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
                continue

            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None,
            )

            batched_mode = unnorm_coords.shape[0] > 1
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=batched_mode,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            # Put the ground truth on whatever device the prediction lives on, so the
            # CPU fallback works instead of failing on a hard-coded .cuda().
            gt_mask = torch.tensor(mask.astype(np.float32), device=prd_masks.device)
            prd_mask = torch.sigmoid(prd_masks[:, 0])  # first of the multimask outputs

            # Calculate binary prediction (0 or 1)
            pred_binary = (prd_mask > 0.5).float()

            # Binary cross-entropy; the epsilon keeps log() finite
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

            # IoU per prompt (gt_mask broadcasts across the prompt dimension)
            inter = (gt_mask * pred_binary).sum(1).sum(1)
            union = gt_mask.sum(1).sum(1) + pred_binary.sum(1).sum(1) - inter
            iou = inter / (union + 1e-6)

            # Pixel accuracy. The comparison broadcasts gt_mask (1, H, W) against one
            # prediction per prompt (N, H, W), so average over the broadcast result;
            # dividing its sum by gt_mask.numel() would report up to N instead of 1.
            accuracy = (pred_binary == gt_mask).float().mean()

            # Penalise the gap between the predicted quality score and the real IoU
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
            loss = seg_loss + score_loss * 0.05

            total_loss += loss.item()
            total_iou += iou.mean().item()
            total_accuracy += accuracy.item()
            valid_samples += 1

    if valid_samples == 0:
        return 0, 0, 0

    return total_loss / valid_samples, total_iou / valid_samples, total_accuracy / valid_samples

In [None]:
# Download and setup the SAM2 model
!wget -O sam2.1_hiera_large.pt "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt"
!mkdir -p checkpoints
!mv sam2.1_hiera_large.pt checkpoints/

# Model configuration
sam2_checkpoint = "./checkpoints/sam2.1_hiera_large.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model for finetuning
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device)
predictor = SAM2ImagePredictor(sam2_model)

# Set model to training mode
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

In [None]:
# Training hyper-parameters
NO_OF_STEPS = 6000
accumulation_steps = 4  # number of backward passes folded into one optimizer update
FINE_TUNED_MODEL_NAME = "fine_tuned_sam2"

# Optimizer, AMP loss scaler and learning-rate schedule.
# NOTE(review): the training loop calls scheduler.step() once per loop iteration, so the
# learning rate is multiplied by 0.2 every 500 iterations — confirm that is intended.
optimizer = torch.optim.AdamW(
    params=predictor.model.parameters(),
    lr=0.0001,
    weight_decay=1e-4,
)
scaler = torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.2)

# Metric histories, appended to at every validation checkpoint
train_losses, train_ious, train_accuracies = [], [], []
val_losses, val_ious, val_accuracies = [], [], []
best_val_iou = 0

In [None]:
# Training loop
# Running (exponentially smoothed) training metrics. They start as None and are seeded
# by the first step that actually produces a loss; keying this on `step == 1` raised a
# NameError whenever the first sample was skipped.
mean_iou = mean_loss = mean_accuracy = None
optimizer.zero_grad()  # drop gradients left over from an interrupted earlier run

for step in range(1, NO_OF_STEPS + 1):
    image, mask, input_point, num_masks = read_batch(train_data, visualize_data=False)
    if image is None or mask is None or num_masks == 0:
        continue

    input_label = np.ones((num_masks, 1))  # every prompt is a positive click
    if not isinstance(input_point, np.ndarray) or input_point.size == 0:
        continue

    # Forward pass and loss under mixed precision
    with torch.cuda.amp.autocast():
        predictor.set_image(image)
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
        if unnorm_coords is None or labels is None or unnorm_coords.shape[0] == 0 or labels.shape[0] == 0:
            continue

        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None,
        )

        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # Keep the ground truth on the same device as the prediction instead of
        # hard-coding .cuda().
        gt_mask = torch.tensor(mask.astype(np.float32), device=prd_masks.device)
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # first of the multimask outputs

        # Calculate binary prediction (0 or 1)
        pred_binary = (prd_mask > 0.5).float()

        # Pixel accuracy. The comparison broadcasts gt_mask (1, H, W) against one
        # prediction per prompt (N, H, W), so average over the broadcast result;
        # dividing its sum by gt_mask.numel() would report up to N instead of 1.
        accuracy = (pred_binary == gt_mask).float().mean()

        # Binary cross-entropy; the epsilon keeps log() finite
        seg_loss = (-gt_mask * torch.log(prd_mask + 1e-6) - (1 - gt_mask) * torch.log((1 - prd_mask) + 1e-6)).mean()

        inter = (gt_mask * pred_binary).sum(1).sum(1)
        union = gt_mask.sum(1).sum(1) + pred_binary.sum(1).sum(1) - inter
        iou = inter / (union + 1e-6)
        # Penalise the gap between the predicted quality score and the real IoU
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05

    # Apply gradient accumulation (backward runs outside the autocast context)
    loss = loss / accumulation_steps
    scaler.scale(loss).backward()

    # Track training metrics
    step_iou = iou.mean().item()
    step_loss = loss.item() * accumulation_steps  # undo the accumulation scaling
    step_accuracy = accuracy.item()
    if mean_iou is None:
        mean_iou, mean_loss, mean_accuracy = step_iou, step_loss, step_accuracy
    else:
        mean_iou = mean_iou * 0.95 + 0.05 * step_iou
        mean_loss = mean_loss * 0.95 + 0.05 * step_loss
        mean_accuracy = mean_accuracy * 0.95 + 0.05 * step_accuracy

    # Update weights after accumulation steps
    if step % accumulation_steps == 0:
        # Gradients still carry the loss scale here; unscale them first so that
        # max_norm applies to their true magnitude.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(predictor.model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # Update scheduler
    scheduler.step()

    # Evaluate on validation set periodically
    if step % 500 == 0 or step == NO_OF_STEPS:
        # Save model checkpoint
        FINE_TUNED_MODEL = f"{FINE_TUNED_MODEL_NAME}_{step}.torch"
        torch.save(predictor.model.state_dict(), FINE_TUNED_MODEL)

        # Temporarily set model to eval mode
        predictor.model.eval()
        val_loss, val_iou, val_accuracy = evaluate_model(predictor, val_data)
        # Re-enable training mode only on the parts that were put in training mode at
        # setup; model.train() would also flip the image encoder into train mode.
        predictor.model.sam_mask_decoder.train(True)
        predictor.model.sam_prompt_encoder.train(True)

        # Track metrics
        train_losses.append(mean_loss)
        train_ious.append(mean_iou)
        train_accuracies.append(mean_accuracy)
        val_losses.append(val_loss)
        val_ious.append(val_iou)
        val_accuracies.append(val_accuracy)

        # Save best model
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            torch.save(predictor.model.state_dict(), f"{FINE_TUNED_MODEL_NAME}_best.torch")

        print(f"Step {step}:")
        print(f"  Train Loss: {mean_loss:.4f}, Train IoU: {mean_iou:.4f}, Train Accuracy: {mean_accuracy:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val IoU: {val_iou:.4f}, Val Accuracy: {val_accuracy:.4f}")

    # Display progress more frequently
    elif step % 100 == 0:
        print(f"Step {step}: Train Loss: {mean_loss:.4f}, Train IoU: {mean_iou:.4f}, Train Accuracy: {mean_accuracy:.4f}")


In [None]:
# One panel per metric, each comparing the training and validation histories.
# Entries: (legend suffix, y-axis label, panel title, training series, validation series)
metric_panels = [
    ('Loss', 'Loss', 'Training and Validation Loss', train_losses, val_losses),
    ('IoU', 'IoU', 'Training and Validation IoU', train_ious, val_ious),
    ('Accuracy', 'Pixel Accuracy', 'Training and Validation Pixel Accuracy', train_accuracies, val_accuracies),
]

plt.figure(figsize=(18, 6))

for panel_no, (suffix, y_label, panel_title, train_series, val_series) in enumerate(metric_panels, start=1):
    plt.subplot(1, 3, panel_no)
    # Histories are recorded every 500 steps, so entry k belongs to step 500 * (k + 1).
    for prefix, series in (('Train', train_series), ('Val', val_series)):
        checkpoint_steps = list(range(500, 500 * (len(series) + 1), 500))
        plt.plot(checkpoint_steps, series, label=f'{prefix} {suffix}')
    plt.xlabel('Steps')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()

plt.tight_layout()
plt.savefig('training_metrics_with_accuracy.png')
plt.show()

In [None]:
# Load the best model for inference
best_model_weights = f"{FINE_TUNED_MODEL_NAME}_best.torch"
# Reuse the device picked during setup rather than hard-coding "cuda", and map the
# saved tensors onto it so loading also works on a CPU-only runtime.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(best_model_weights, map_location=device))
predictor.model.eval()

# Test inference on a random validation sample
def read_image(image_path, mask_path):  # read and resize image and mask
    """Load an image and its mask, scaled so the longer image side is 1024 pixels.

    Args:
        image_path: path of the image file.
        mask_path: path of the mask file, read as a single-channel image.

    Returns:
        ``(img, mask)`` with ``img`` in RGB channel order, or ``(None, None)`` when
        either file cannot be read.
    """
    img = cv2.imread(image_path)
    mask = cv2.imread(mask_path, 0)

    # imread signals failure by returning None, so check before touching the arrays.
    if img is None or mask is None:
        print(f"Warning: Could not read image from {image_path} or mask from {mask_path}")
        return None, None

    img = img[..., ::-1]  # Convert BGR to RGB

    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])  # scaling factor
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    # Nearest-neighbour keeps mask values intact instead of blending them.
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask

def get_points(mask, num_points):  # Sample points inside the input mask
    """Draw random (x, y) prompt points from the foreground of ``mask``.

    Args:
        mask: array whose positive entries mark foreground, or None.
        num_points: number of points wanted; capped at the foreground pixel count.

    Returns:
        Array of shape (k, 1, 2) with one (x, y) point per row. Points are drawn
        independently, so the same pixel can appear more than once. An empty array is
        returned when ``mask`` is None or has no foreground.
    """
    if mask is None:
        return np.array([])

    if np.max(mask) == 0:
        print("Warning: Mask is empty (all zeros)")
        return np.array([])

    # (row, col) index of every foreground pixel
    foreground = np.argwhere(mask > 0)
    if len(foreground) == 0:
        return np.array([])

    sampled = []
    remaining = min(num_points, len(foreground))
    while remaining > 0:
        pick = foreground[np.random.randint(len(foreground))]
        sampled.append([[pick[1], pick[0]]])  # swap to (x, y) order
        remaining -= 1
    return np.array(sampled)

# Try multiple validation images until we find one with valid points
max_attempts = 10
attempt = 0
found_valid_image = False

# Tunables for merging the per-point predictions into one segmentation map
MIN_MASK_SCORE = 0.7      # predictions scoring below this are discarded
MAX_OVERLAP_RATIO = 0.15  # skip a mask when more than this share of it is already claimed

while attempt < max_attempts and not found_valid_image:
    # Randomly select a test image from the validation data
    selected_entry = random.choice(val_data)
    image_path = selected_entry['image']
    mask_path = selected_entry['annotation']

    print(f"Attempt {attempt+1}: Trying image {os.path.basename(image_path)}")

    # Load the selected image and mask
    image, mask = read_image(image_path, mask_path)

    # Check if image and mask were loaded successfully
    if image is None or mask is None:
        attempt += 1
        continue

    # Check if mask contains any foreground pixels
    if np.max(mask) == 0:
        print(f"Mask for {os.path.basename(image_path)} is empty. Trying another image.")
        attempt += 1
        continue

    # Generate random points for the input
    num_samples = 30  # Number of points per segment to sample
    input_points = get_points(mask, num_samples)

    if len(input_points) > 0:
        found_valid_image = True
        print(f"Found valid image: {os.path.basename(image_path)}")
    else:
        print(f"No valid points found in {os.path.basename(image_path)}. Trying another image.")
        attempt += 1

if not found_valid_image:
    print(f"Failed to find a valid image after {max_attempts} attempts.")
    print("Please check your validation dataset or increase the number of attempts.")
else:
    # Perform inference and predict masks (one positive-click prompt per sampled point)
    with torch.no_grad():
        predictor.set_image(image)
        masks, scores, logits = predictor.predict(
            point_coords=input_points,
            point_labels=np.ones([input_points.shape[0], 1])
        )

    # Keep the first candidate mask of every prompt and order them best score first
    np_masks = np.array(masks[:, 0])
    np_scores = scores[:, 0]
    sorted_indices = np.argsort(np_scores)[::-1]
    sorted_masks = np_masks[sorted_indices]
    sorted_scores = np_scores[sorted_indices]  # Also keep track of sorted scores

    # Print score information
    print(f"Number of predicted masks: {len(sorted_scores)}")
    print(f"Prediction scores (top 5): {sorted_scores[:5]}")

    # Initialize segmentation map and occupancy mask.
    # int32 labels cannot wrap around, unlike uint8 once more than 255 masks are merged.
    seg_map = np.zeros(sorted_masks[0].shape, dtype=np.int32)
    occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

    # Combine masks to create the final segmentation map
    for i in range(sorted_masks.shape[0]):
        # Skip masks with low scores
        if sorted_scores[i] < MIN_MASK_SCORE:  # You can adjust this threshold
            continue

        mask_bool = sorted_masks[i].astype(bool)

        # Share of this mask already covered by better-scoring masks
        mask_area = mask_bool.sum()
        overlap_ratio = 0
        if mask_area > 0:  # Avoid division by zero
            overlap_ratio = (mask_bool & occupancy_mask).sum() / mask_area

        # Skip if there's too much overlap
        if overlap_ratio > MAX_OVERLAP_RATIO:
            continue

        mask_bool[occupancy_mask] = False  # Set overlapping areas to False in the mask
        seg_map[mask_bool] = i + 1  # Use boolean mask to index seg_map
        occupancy_mask[mask_bool] = True  # Update occupancy_mask

    # Visualization: Show the original image, mask, input points, and final segmentation
    plt.figure(figsize=(20, 5))

    plt.subplot(1, 4, 1)
    plt.title('Test Image')
    plt.imshow(image)
    # Plot points on the image
    for point in input_points:
        plt.scatter(point[0][0], point[0][1], c='r', s=40)
    plt.axis('off')

    plt.subplot(1, 4, 2)
    plt.title('Original Mask')
    plt.imshow(mask, cmap='gray')
    plt.axis('off')

    plt.subplot(1, 4, 3)
    plt.title('Highest Scoring Predicted Mask')
    # Show the highest-scoring mask
    if len(sorted_masks) > 0:
        plt.imshow(sorted_masks[0], cmap='gray')
    plt.axis('off')

    plt.subplot(1, 4, 4)
    plt.title('Final Segmentation')
    plt.imshow(seg_map, cmap='jet')
    plt.axis('off')

    plt.tight_layout()
    plt.savefig('inference_example.png')
    plt.show()

    # Calculate metrics for the inference result
    gt_binary = mask > 0  # Convert ground truth mask to binary
    pred_binary = seg_map > 0  # Convert prediction to binary

    # Calculate IoU
    intersection = np.logical_and(gt_binary, pred_binary).sum()
    union = np.logical_or(gt_binary, pred_binary).sum()
    iou = intersection / union if union > 0 else 0

    # Calculate pixel accuracy
    total_pixels = gt_binary.size
    correct_pixels = np.sum(gt_binary == pred_binary)
    accuracy = correct_pixels / total_pixels

    print(f"Inference metrics - IoU: {iou:.4f}, Pixel Accuracy: {accuracy:.4f}")