# 🏥 Gynecological Surgery Organ Segmentation with SegFormer

This notebook trains a SegFormer model to segment organs (uterus, ovary, fallopian tube) from laparoscopic surgery videos.

**Model**: SegFormer-B0 (Hugging Face)  
**Dataset**: BlackWalkersAnatomy from Kaggle  
**Classes**: Background, Uterus, Fallopian Tube, Ovary

## 1. Install Dependencies

In [None]:
# Install dependencies with compatible versions
print("Installing dependencies with compatible versions...")

!pip install -q --no-cache-dir \
    "numpy==1.26.4" \
    "transformers==4.44.0" \
    "datasets==2.14.0" \
    "albumentations" \
    "evaluate" \
    "Pillow"

print("\n✅ Installation complete!")
print("="*60)
print("🔴 CRITICAL: RESTART KERNEL NOW")
print("="*60)
print("\nColab: Runtime → Restart runtime")
print("Kaggle: Click restart button (⟳)")
print("\nAfter restart:")
print("  ‚Ä¢ Skip this cell")
print("  ‚Ä¢ Run from Cell 2 onwards")
print("="*60)

## 2. Import Libraries

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import cv2
from pathlib import Path
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from transformers import (
    SegformerForSemanticSegmentation, 
    SegformerImageProcessor,
    TrainingArguments,
    Trainer
)

import albumentations as A
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

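# Set seed for reproducibility (Python's `random` is seeded too, because
# some augmentation libraries draw from it)
import random

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)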

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Dataset Configuration

In [None]:
# Dataset paths
BASE_PATH = "/kaggle/input/blackwalkersanatomy/GynSurg_Anatomy_Dataset"
IMAGE_BASE = os.path.join(BASE_PATH, "ganseg")
MASK_BASE = os.path.join(BASE_PATH, "ganseg_mask")

# Class mapping for gynecological organs
id2label = {
    0: "background",
    1: "uterus",
    2: "fallopian_tube",
    3: "ovary",
}

label2id = {v: k for k, v in id2label.items()}
num_classes = len(id2label)

print(f"Number of classes: {num_classes}")
print(f"Class mapping:")
for class_id, label in id2label.items():
    print(f"  Class {class_id}: {label}")

# Global variable for intensity mapping (will be set after data exploration)
INTENSITY_TO_CLASS = {}

## 4. Collect Image-Mask Pairs

In [None]:
def find_corresponding_frame(mask_path, image_base_dir):
    """
    Find the corresponding frame image for a mask
    
    mask_path: /path/to/ganseg_mask/GANSEG_01/0.mp4_/0_010800_06-00-00_mask.png
    Returns: /path/to/ganseg/GANSEG_01/0.mp4_/0_010800_06-00-00.png (or .jpg)
    """
    mask_filename = os.path.basename(mask_path)
    
    # Get the directory structure
    parts = Path(mask_path).parts
    ganseg_idx = parts.index('ganseg_mask')
    ganseg_id = parts[ganseg_idx + 1]  # e.g., GANSEG_01
    video_folder = parts[ganseg_idx + 2]  # e.g., 0.mp4_
    
    # Construct image directory
    image_dir = os.path.join(image_base_dir, ganseg_id, video_folder)
    
    # Strip the '_mask.png' suffix to get the frame's base name
    # (masks are collected with the '*_mask.png' pattern, so the suffix is always present)
    frame_stem = mask_filename[:-len('_mask.png')]
    
    # Try each plausible frame extension in turn; only the filename is changed,
    # so directory names are never altered
    for ext in ('.png', '.jpg', '.jpeg'):
        image_path = os.path.join(image_dir, frame_stem + ext)
        if os.path.exists(image_path):
            return image_path
    
    return None
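
The path logic above can also be written as a pure function that only builds candidate paths, which makes it easy to unit-test without the dataset. This is a sketch: `candidate_frame_paths` is a hypothetical helper, and it assumes the mask tree mirrors the frame tree exactly (`<mask_root>/<GANSEG_xx>/<video>/<name>_mask.png`) and that frames use one of the three extensions the notebook already tries.

```python
from pathlib import Path

# Assumed frame formats, in the order the notebook tries them
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

def candidate_frame_paths(mask_path, mask_root, image_root, suffix="_mask"):
    """Return possible frame paths for a mask without touching the filesystem."""
    mask_path = Path(mask_path)
    rel_dir = mask_path.parent.relative_to(mask_root)  # e.g. GANSEG_01/0.mp4_
    stem = mask_path.stem                              # e.g. 0_010800_06-00-00_mask
    if not stem.endswith(suffix):
        raise ValueError(f"Not a mask file: {mask_path.name}")
    frame_stem = stem[: -len(suffix)]
    return [Path(image_root) / rel_dir / (frame_stem + ext) for ext in IMAGE_EXTENSIONS]
```

To resolve an actual file, take the first candidate that exists, e.g. `next((p for p in candidates if p.exists()), None)`.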

In [None]:
# Collect all mask files
print("Collecting mask files from all GANSEG folders...")
all_mask_paths = []

for ganseg_dir in sorted(os.listdir(MASK_BASE)):
    ganseg_path = os.path.join(MASK_BASE, ganseg_dir)
    if not os.path.isdir(ganseg_path):
        continue
    
    print(f"  Processing {ganseg_dir}...", end=" ")
    folder_mask_count = 0
    
    for video_folder in sorted(os.listdir(ganseg_path)):
        video_path = os.path.join(ganseg_path, video_folder)
        if not os.path.isdir(video_path):
            continue
        
        # Get all mask files in this video folder
        mask_files = glob.glob(os.path.join(video_path, '*_mask.png'))
        all_mask_paths.extend(mask_files)
        folder_mask_count += len(mask_files)
    
    print(f"{folder_mask_count} masks found")

print(f"\nTotal mask files found: {len(all_mask_paths)}")

# Find corresponding images
print("\nMatching images to masks...")
image_paths = []
mask_paths = []
missing_count = 0

for mask_path in tqdm(all_mask_paths):
    image_path = find_corresponding_frame(mask_path, IMAGE_BASE)
    if image_path:
        image_paths.append(image_path)
        mask_paths.append(mask_path)
    else:
        missing_count += 1

print(f"\n✓ Successfully matched {len(image_paths)} image-mask pairs")
if missing_count > 0:
    print(f"✗ {missing_count} masks without corresponding images")

# Show sample paths
if len(image_paths) > 0:
    print(f"\nSample paths:")
    print(f"  Image: {image_paths[0]}")
    print(f"  Mask:  {mask_paths[0]}")

## 5. Create Intensity Mapping

In [None]:
if len(image_paths) == 0:
    raise ValueError("No image-mask pairs found! Check dataset structure.")

print("="*60)
print("STEP 1: ANALYZING MASK INTENSITIES")
print("="*60)

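# Scan every mask: a small subset may not contain all organs, which would
# leave intensities unmapped (and cause KeyErrors later)
print(f"\nScanning all {len(mask_paths)} masks to detect intensity values...")
all_intensities = set()

for mask_path in tqdm(mask_paths):
    mask = np.array(Image.open(mask_path))
    if len(mask.shape) == 3:
        mask = mask[:,:,0]  # Take first channel if RGB
    all_intensities.update(np.unique(mask).tolist())

all_intensities = sorted(all_intensities)
print(f"Found {len(all_intensities)} unique intensity values: {all_intensities}")

# Each intensity becomes one class ID, so the count must match num_classes;
# extra values would produce class IDs the model has no output channel for
if len(all_intensities) != num_classes:
    raise ValueError(
        f"Expected {num_classes} mask intensities, found {len(all_intensities)}. "
        "Check for anti-aliasing/compression artifacts or missing classes."
    )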
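
Counting pixels per intensity, rather than only collecting the unique values, shows how balanced the classes are and makes stray values stand out (an intensity covering only a handful of pixels is more likely an anti-aliasing or compression artifact than a real class). A minimal NumPy sketch, assuming 8-bit masks; `intensity_histogram` is a hypothetical helper name:

```python
import numpy as np

def intensity_histogram(masks):
    """Sum pixel counts per grayscale intensity (0-255) over an iterable of masks."""
    counts = np.zeros(256, dtype=np.int64)
    for mask in masks:
        mask = np.asarray(mask)
        if mask.ndim == 3:  # RGB-encoded mask: keep the first channel, as the notebook does
            mask = mask[:, :, 0]
        counts += np.bincount(mask.ravel(), minlength=256)
    return counts
```

For example, `intensity_histogram(np.array(Image.open(p)) for p in mask_paths)`; the non-zero entries of the result are the intensities present, and their relative sizes indicate the class imbalance.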

In [None]:
print("\n" + "="*60)
print("STEP 2: CREATE INTENSITY TO CLASS MAPPING")
print("="*60)

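# Map intensities to class IDs in ascending order (lowest intensity -> class 0).
# ASSUMPTION: this order matches id2label (background, uterus, fallopian_tube, ovary);
# verify it against the visualization in the next section.
INTENSITY_TO_CLASS = {intensity: idx for idx, intensity in enumerate(all_intensities)}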
print(f"\nIntensity → Class ID mapping:")
for intensity, class_id in INTENSITY_TO_CLASS.items():
    organ_name = id2label.get(class_id, "unknown")
    print(f"  Intensity {intensity:3d} → Class {class_id} ({organ_name})")

print("\n✓ Intensity mapping created")
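
The per-intensity loop used to apply this mapping can be replaced by a 256-entry lookup table, which converts a whole mask with one indexing operation and makes unmapped intensities explicit instead of silently leaving them at 0 (background). A sketch under two assumptions: masks are 8-bit, and unmapped pixels should be marked with `-100`, the `ignore_index` the notebook's loss uses. The helper names and the example intensities 0/85/170/255 are hypothetical.

```python
import numpy as np

def build_lut(intensity_to_class, unknown=-100):
    """256-entry table: lut[intensity] -> class ID; unmapped intensities -> `unknown`."""
    lut = np.full(256, unknown, dtype=np.int64)
    for intensity, class_id in intensity_to_class.items():
        lut[int(intensity)] = class_id
    return lut

def mask_to_class_ids_lut(mask, lut):
    """Convert an 8-bit (H, W) or (H, W, 3) mask to int64 class IDs via fancy indexing."""
    mask = np.asarray(mask)
    if mask.ndim == 3:
        mask = mask[:, :, 0]
    return lut[mask]
```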

## 6. Visualize Sample Data

In [None]:
print("\n" + "="*60)
print("STEP 3: VISUALIZING SAMPLE DATA")
print("="*60)

# Find a mask with multiple intensities for better visualization
best_mask_idx = 0
max_intensities = 0
for i in range(min(20, len(mask_paths))):
    mask = np.array(Image.open(mask_paths[i]))
    if len(mask.shape) == 3:
        mask = mask[:,:,0]
    n_intensities = len(np.unique(mask))
    if n_intensities > max_intensities:
        max_intensities = n_intensities
        best_mask_idx = i

sample_img_path = image_paths[best_mask_idx]
sample_mask_path = mask_paths[best_mask_idx]

print(f"Visualizing: {os.path.basename(sample_mask_path)}")
print(f"This sample contains {max_intensities} classes\n")

# Load image and mask
img = Image.open(sample_img_path).convert('RGB')
mask = Image.open(sample_mask_path)
mask_array = np.array(mask)

# Handle RGB masks
if len(mask_array.shape) == 3:
    mask_array = mask_array[:,:,0]

unique_values = np.unique(mask_array)
print(f"Intensities present: {unique_values}")

# Pixel distribution
print(f"\nPixel distribution:")
for val in unique_values:
    count = np.sum(mask_array == val)
    percentage = (count / mask_array.size) * 100
    class_id = INTENSITY_TO_CLASS[val]
    organ = id2label[class_id]
    print(f"  {organ:15s} (intensity {val:3d}): {count:7d} pixels ({percentage:5.2f}%)")

# Create visualization
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

# Row 1: Original, Full Mask, Overlay
axes[0, 0].imshow(img)
axes[0, 0].set_title("Original Surgical Image", fontsize=14, fontweight='bold')
axes[0, 0].axis('off')

# Map intensities to class IDs for visualization
class_mask = np.zeros_like(mask_array)
for intensity, class_id in INTENSITY_TO_CLASS.items():
    class_mask[mask_array == intensity] = class_id

axes[0, 1].imshow(class_mask, cmap='tab10', vmin=0, vmax=num_classes-1)
axes[0, 1].set_title("Segmentation Mask", fontsize=14, fontweight='bold')
axes[0, 1].axis('off')

# Create overlay
img_array = np.array(img)
overlay = img_array.copy().astype(float)
colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
for val in unique_values:
    class_id = INTENSITY_TO_CLASS[val]
    if class_id == 0:  # Skip background
        continue
    mask_bool = mask_array == val
    color = colors[class_id][:3] * 255
    overlay[mask_bool] = overlay[mask_bool] * 0.4 + color * 0.6

axes[0, 2].imshow(overlay.astype(np.uint8))
axes[0, 2].set_title("Overlay on Image", fontsize=14, fontweight='bold')
axes[0, 2].axis('off')

# Row 2: Individual classes
class_colors_viz = {
    0: [0, 0, 0],       # background - not shown
    1: [255, 100, 100], # uterus - red
    2: [100, 255, 100], # fallopian_tube - green
    3: [100, 100, 255], # ovary - blue
}

for class_id, organ_name in id2label.items():
    if class_id == 0:  # Skip background
        continue
    
    col_idx = class_id - 1
    ax = axes[1, col_idx]
    
    # Find corresponding intensity
    intensity_val = [k for k, v in INTENSITY_TO_CLASS.items() if v == class_id][0]
    
    # Create colored overlay for this class only
    colored_img = img_array.copy()
    mask_bool = mask_array == intensity_val
    
    if np.any(mask_bool):
        color = np.array(class_colors_viz[class_id])
        colored_img[mask_bool] = (colored_img[mask_bool] * 0.3 + color * 0.7).astype(np.uint8)
    
    ax.imshow(colored_img)
    
    pixel_count = np.sum(mask_array == intensity_val)
    percentage = (pixel_count / mask_array.size) * 100
    
    ax.set_title(f"{organ_name.upper()}\n({pixel_count} pixels, {percentage:.1f}%)", 
                 fontsize=12, fontweight='bold')
    ax.axis('off')

plt.tight_layout()
plt.show()

print("\n✓ Visualization complete")
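
The overlay code above (and the similar code in later sections) alpha-blends a color onto each labeled pixel. Factoring it into one function keeps the blending consistent across plots. This is a sketch: `blend_overlay` is a hypothetical name, and `alpha` is the weight given to the class color (the notebook uses 0.6 and 0.7).

```python
import numpy as np

def blend_overlay(image, class_mask, colors, alpha=0.6, skip=(0,)):
    """Alpha-blend per-class RGB colors onto an (H, W, 3) uint8 image.

    colors: mapping class_id -> (R, G, B) in 0-255; classes in `skip` are left untouched.
    """
    out = image.astype(np.float32)
    for class_id, color in colors.items():
        if class_id in skip:
            continue
        sel = class_mask == class_id
        out[sel] = (1 - alpha) * out[sel] + alpha * np.asarray(color, dtype=np.float32)
    return np.clip(out, 0, 255).astype(np.uint8)
```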

## 7. Prepare Train-Validation Split

In [None]:
print("\n" + "="*60)
print("PREPARING TRAIN-VALIDATION SPLIT")
print("="*60)

# Group by GANSEG folder and video to avoid data leakage
video_groups = []
for mask_path in mask_paths:
    parts = Path(mask_path).parts
    ganseg_id = parts[parts.index('ganseg_mask') + 1]
    video_folder = parts[parts.index('ganseg_mask') + 2]
    video_groups.append(f"{ganseg_id}_{video_folder}")

# Convert to dataframe for easier splitting
df = pd.DataFrame({
    'image_path': image_paths,
    'mask_path': mask_paths,
    'video_group': video_groups
})

print(f"\nDataset statistics:")
print(f"  Total samples: {len(df)}")
print(f"  Unique video groups: {df['video_group'].nunique()}")

print(f"\nSamples per video group:")
group_counts = df['video_group'].value_counts()
for group, count in group_counts.items():
    print(f"  {group}: {count} frames")

# Split by video group to avoid leakage
unique_groups = df['video_group'].unique()
train_groups, val_groups = train_test_split(
    unique_groups, 
    test_size=0.2, 
    random_state=42
)

train_df = df[df['video_group'].isin(train_groups)]
val_df = df[df['video_group'].isin(val_groups)]

train_images = train_df['image_path'].tolist()
train_masks = train_df['mask_path'].tolist()
val_images = val_df['image_path'].tolist()
val_masks = val_df['mask_path'].tolist()

print(f"\n✓ Split complete:")
print(f"  Training samples: {len(train_images)} ({len(train_images)/len(df)*100:.1f}%)")
print(f"  Validation samples: {len(val_images)} ({len(val_images)/len(df)*100:.1f}%)")
print(f"\nTraining video groups: {sorted(train_groups)}")
print(f"Validation video groups: {sorted(val_groups)}")
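
scikit-learn's `GroupShuffleSplit` performs the same kind of split directly on per-frame group labels, and an explicit disjointness check guards against leakage if the grouping logic changes later. A sketch, assuming group labels like the `video_group` column above; `group_split` is a hypothetical helper. As with the notebook's approach, `test_size` is a fraction of groups, not of frames, so the frame-level ratio can differ from 20%.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

def group_split(groups, test_size=0.2, seed=42):
    """Return (train_idx, val_idx) such that no group appears on both sides."""
    groups = np.asarray(groups)
    splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    train_idx, val_idx = next(splitter.split(np.zeros(len(groups)), groups=groups))
    # Leakage check: every video group must land entirely in one split
    assert set(groups[train_idx]).isdisjoint(groups[val_idx]), "group leakage"
    return train_idx, val_idx
```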

## 8. Dataset Class Definition

In [None]:
class GynSurgDataset(Dataset):
    """Dataset for gynecological surgery segmentation"""
    
    def __init__(self, image_paths, mask_paths, processor, 
                 target_size=512, augment=False):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.processor = processor
        self.target_size = target_size
        self.augment = augment
        
        # Augmentation pipeline for training
        if augment:
            self.transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.3),
                A.RandomRotate90(p=0.5),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
                A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, 
                                  rotate_limit=15, p=0.5),
                A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, 
                                    val_shift_limit=10, p=0.3),
            ])
        else:
            self.transform = None
    
    def mask_to_class_ids(self, mask):
        """
        Convert grayscale mask intensities to class IDs (0, 1, 2, 3).
        Uses the global INTENSITY_TO_CLASS mapping.
        """
        mask = np.array(mask)
        
        # Handle RGB masks by taking first channel
        if len(mask.shape) == 3:
            mask = mask[:,:,0]
        
        # Create class ID mask
        h, w = mask.shape
        class_mask = np.zeros((h, w), dtype=np.int64)
        
        # Map each intensity to its class ID
        for intensity, class_id in INTENSITY_TO_CLASS.items():
            class_mask[mask == intensity] = class_id
        
        return class_mask
    
    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx]).convert("RGB")
        image = np.array(image)
        
        # Load and convert mask
        mask = Image.open(self.mask_paths[idx])
        mask = self.mask_to_class_ids(mask)
        
        # Apply augmentation
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        
        # Process image with SegFormer processor
        encoded = self.processor(
            images=image,
            return_tensors="pt"
        )
        
        # Resize mask to target size (nearest neighbor to preserve class IDs)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).float()
        mask_resized = F.interpolate(
            mask_tensor, 
            size=(self.target_size, self.target_size), 
            mode="nearest"
        )
        mask_resized = mask_resized.squeeze().long()
        
        # Prepare output
        encoded_inputs = {k: v.squeeze(0) for k, v in encoded.items()}
        encoded_inputs["labels"] = mask_resized
        
        return encoded_inputs

print("✓ Dataset class defined")
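
The mask is resized with `mode="nearest"` because interpolating label IDs produces values that are not valid classes. The toy example below, a sketch using a hypothetical 4x4 label map, shows the effect: bilinear upsampling creates fractional values at the boundary between classes 0 and 3, and rounding one of them yields class 1, which never occurs there.

```python
import torch
import torch.nn.functional as F

# Hypothetical 4x4 label map with classes 0-3 in four quadrants
labels = torch.tensor([[0, 0, 3, 3],
                       [0, 0, 3, 3],
                       [1, 1, 2, 2],
                       [1, 1, 2, 2]])

def resize_labels(mask, size, mode):
    """Resize an (H, W) label map; interpolate expects (N, C, H, W) float input."""
    x = mask[None, None].float()
    extra = {} if mode == "nearest" else {"align_corners": False}
    return F.interpolate(x, size=size, mode=mode, **extra)[0, 0]

nearest = resize_labels(labels, (8, 8), "nearest")
bilinear = resize_labels(labels, (8, 8), "bilinear")

print(nearest.unique())   # only the original class IDs
print(bilinear[0, :4])    # a fractional value appears next to the 0/3 boundary
```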

## 9. Initialize Model and Processor

In [None]:
MODEL_NAME = "nvidia/segformer-b0-finetuned-ade-512-512"
# For better accuracy, try: "nvidia/segformer-b1-finetuned-ade-512-512"

print("Loading SegFormer model and processor...")
print(f"Model: {MODEL_NAME}")

# Initialize processor
processor = SegformerImageProcessor.from_pretrained(MODEL_NAME)

# Initialize model
model = SegformerForSemanticSegmentation.from_pretrained(
    MODEL_NAME,
    num_labels=num_classes,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✓ Model loaded successfully")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")

## 10. Create Datasets and DataLoaders

In [None]:
TARGET_SIZE = 512

print("Creating datasets...")

train_dataset = GynSurgDataset(
    train_images, 
    train_masks, 
    processor,
    target_size=TARGET_SIZE,
    augment=True  # Enable augmentation for training
)

val_dataset = GynSurgDataset(
    val_images, 
    val_masks, 
    processor,
    target_size=TARGET_SIZE,
    augment=False  # No augmentation for validation
)

print(f"✓ Datasets created:")
print(f"  Train dataset: {len(train_dataset)} samples")
print(f"  Val dataset: {len(val_dataset)} samples")

# Test dataset loading
print("\nTesting dataset loading...")
sample = train_dataset[0]
print(f"  pixel_values shape: {sample['pixel_values'].shape}")
print(f"  labels shape: {sample['labels'].shape}")
print(f"  labels unique values: {torch.unique(sample['labels']).tolist()}")
print("✓ Dataset loading successful")

# Visualize a preprocessed sample
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

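# Undo the processor's (x - mean) / std normalization for display
# (min-max rescaling would shift the colors)
mean = np.array(processor.image_mean).reshape(1, 1, 3)
std = np.array(processor.image_std).reshape(1, 1, 3)
img = sample['pixel_values'].numpy().transpose(1, 2, 0)
img = np.clip(img * std + mean, 0, 1)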

axes[0].imshow(img)
axes[0].set_title("Preprocessed Image (512x512)", fontsize=12, fontweight='bold')
axes[0].axis('off')

axes[1].imshow(sample['labels'].numpy(), cmap='tab10', vmin=0, vmax=num_classes-1)
axes[1].set_title("Preprocessed Mask (512x512)", fontsize=12, fontweight='bold')
axes[1].axis('off')

plt.tight_layout()
plt.show()
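
`SegformerImageProcessor` rescales pixels to [0, 1] and then normalizes each channel as `(x - mean) / std`, so display code can invert that exactly instead of min-max rescaling. A sketch: `denormalize` is a hypothetical helper, and the ImageNet mean/std values in the test are an assumption about this checkpoint's processor config (in the notebook, read them from `processor.image_mean` and `processor.image_std`).

```python
import numpy as np

def denormalize(chw, mean, std):
    """Invert per-channel (x - mean) / std and return an (H, W, 3) float image in [0, 1]."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    img = chw * std + mean
    return np.clip(img.transpose(1, 2, 0), 0.0, 1.0)
```

Usage in the notebook would be `denormalize(sample['pixel_values'].numpy(), processor.image_mean, processor.image_std)`.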

## 11. Define Metrics Functions

In [None]:
def compute_iou(preds, labels, num_classes):
    """Compute mean IoU over classes present in preds or labels (absent classes are skipped)"""
    ious = []
    preds = preds.cpu().numpy()
    labels = labels.cpu().numpy()
    
    for cls in range(num_classes):
        pred_mask = (preds == cls)
        label_mask = (labels == cls)
        
        intersection = np.logical_and(pred_mask, label_mask).sum()
        union = np.logical_or(pred_mask, label_mask).sum()
        
        if union == 0:
            iou = float('nan')
        else:
            iou = intersection / union
        ious.append(iou)
    
    return np.nanmean(ious)

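def compute_dice(preds, labels, num_classes):
    """Compute mean Dice over classes present in preds or labels (absent classes are skipped)"""
    dices = []
    preds = preds.cpu().numpy()
    labels = labels.cpu().numpy()
    
    for cls in range(num_classes):
        pred_mask = (preds == cls)
        label_mask = (labels == cls)
        
        denom = pred_mask.sum() + label_mask.sum()
        if denom == 0:
            # Class absent from both prediction and ground truth: skip it, as compute_iou
            # does, instead of counting it as Dice 0 and dragging the mean down
            dices.append(float('nan'))
            continue
        
        intersection = np.logical_and(pred_mask, label_mask).sum()
        dices.append((2. * intersection) / denom)
    
    return np.nanmean(dices)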

def compute_per_class_metrics(preds, labels, num_classes):
    """Compute IoU and Dice for each class"""
    preds = preds.cpu().numpy()
    labels = labels.cpu().numpy()
    
    metrics = {}
    for cls in range(num_classes):
        pred_mask = (preds == cls)
        label_mask = (labels == cls)
        
        intersection = np.logical_and(pred_mask, label_mask).sum()
        union = np.logical_or(pred_mask, label_mask).sum()
        
        if union > 0:
            iou = intersection / union
            dice = (2. * intersection) / (pred_mask.sum() + label_mask.sum() + 1e-8)
        else:
            iou = float('nan')
            dice = float('nan')
        
        metrics[f"class_{cls}_iou"] = iou
        metrics[f"class_{cls}_dice"] = dice
    
    return metrics

print("✓ Metric functions defined")
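
The metrics above are computed per batch and then averaged, so every batch counts equally regardless of how many pixels of each class it contains. An alternative is to accumulate a confusion matrix over the whole validation set and compute IoU and Dice once at the end. A minimal NumPy sketch (helper names are hypothetical; rows are true classes, columns are predictions, and pixels labeled `-100` are skipped):

```python
import numpy as np

def update_confusion(conf, preds, labels, num_classes, ignore_index=-100):
    """Accumulate an in-place (num_classes x num_classes) confusion matrix."""
    preds = np.asarray(preds).ravel()
    labels = np.asarray(labels).ravel()
    keep = labels != ignore_index
    idx = labels[keep] * num_classes + preds[keep]  # flat (true, pred) index
    conf += np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return conf

def metrics_from_confusion(conf):
    """Per-class IoU and Dice; NaN where a class never appears in preds or labels."""
    tp = np.diag(conf).astype(np.float64)
    fp = conf.sum(axis=0) - tp
    fn = conf.sum(axis=1) - tp
    with np.errstate(invalid="ignore", divide="ignore"):
        iou = tp / (tp + fp + fn)
        dice = 2 * tp / (2 * tp + fp + fn)
    return iou, dice
```

In an evaluation loop, create `conf = np.zeros((num_classes, num_classes), dtype=np.int64)`, call `update_confusion(conf, preds.cpu().numpy(), labels.cpu().numpy(), num_classes)` per batch, and call `metrics_from_confusion(conf)` once after the loop; `np.nanmean` of each result gives dataset-level means.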

## 12. Training Configuration

In [None]:
OUTPUT_DIR = "./segformer-gynsurg-b0"
BATCH_SIZE = 4  # Reduce to 2 if OOM
NUM_EPOCHS = 50
LEARNING_RATE = 5e-5
WARMUP_STEPS = 100

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    warmup_steps=WARMUP_STEPS,
    weight_decay=0.01,
    eval_strategy="epoch",  # `evaluation_strategy` is deprecated in transformers 4.44
    save_strategy="epoch",
    save_total_limit=3,
    logging_steps=10,
    logging_first_step=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    fp16=torch.cuda.is_available(),
    dataloader_num_workers=2,
    dataloader_pin_memory=True,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none",
)

print("Training configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Image size: {TARGET_SIZE}x{TARGET_SIZE}")
print(f"  FP16 training: {torch.cuda.is_available()}")
print(f"  Output directory: {OUTPUT_DIR}")
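
With `lr_scheduler_type` left at its default (`"linear"`), the Trainer warms the learning rate up linearly for `WARMUP_STEPS` optimizer steps and then decays it linearly to zero by the last step. The sketch below reproduces that arithmetic so you can check how much of the run the warmup covers. It assumes a single GPU and no gradient accumulation; the function names are illustrative, and the 1,000-frame training set in the test is hypothetical.

```python
import math

def total_training_steps(num_samples, batch_size, epochs):
    """Optimizer steps for a run (last partial batch kept, no gradient accumulation)."""
    return math.ceil(num_samples / batch_size) * epochs

def linear_warmup_decay(step, base_lr, warmup_steps, total_steps):
    """Learning rate at `step`: linear warmup to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
```

With 1,000 training frames, batch size 4 and 50 epochs, that is 250 steps per epoch and 12,500 steps in total, so 100 warmup steps cover under 1% of training.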

## 13. Custom Trainer with Metrics

In [None]:
class SegmentationTrainer(Trainer):
    """Custom trainer with IoU/Dice metrics"""
    
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute loss"""
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        
        # Upsample logits to match label size
        upsampled_logits = F.interpolate(
            logits,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False
        )
        
        # Compute cross-entropy loss
        loss = F.cross_entropy(upsampled_logits, labels, ignore_index=-100)
        
        return (loss, outputs) if return_outputs else loss
    
    def evaluation_loop(self, dataloader, description, prediction_loss_only=None, 
                       ignore_keys=None, metric_key_prefix="eval"):
        """Custom evaluation with IoU and Dice"""
        model = self.model
        model.eval()
        
        total_loss = 0
        total_iou = 0
        total_dice = 0
        num_batches = 0
        
        # Per-class metrics
        per_class_iou = np.zeros(num_classes)
        per_class_dice = np.zeros(num_classes)
        per_class_count = np.zeros(num_classes)
        
        with torch.no_grad():
            for batch in tqdm(dataloader, desc=description):
                pixel_values = batch["pixel_values"].to(model.device)
                labels = batch["labels"].to(model.device)
                
                outputs = model(pixel_values=pixel_values)
                logits = outputs.logits
                
                # Upsample and compute loss
                upsampled_logits = F.interpolate(
                    logits,
                    size=labels.shape[-2:],
                    mode="bilinear",
                    align_corners=False
                )
                loss = F.cross_entropy(upsampled_logits, labels, ignore_index=-100)
                
                # Get predictions
                preds = upsampled_logits.argmax(dim=1)
                
                # Compute overall metrics
                iou = compute_iou(preds, labels, num_classes)
                dice = compute_dice(preds, labels, num_classes)
                
                # Compute per-class metrics
                class_metrics = compute_per_class_metrics(preds, labels, num_classes)
                for cls in range(num_classes):
                    iou_val = class_metrics[f"class_{cls}_iou"]
                    dice_val = class_metrics[f"class_{cls}_dice"]
                    if not np.isnan(iou_val):
                        per_class_iou[cls] += iou_val
                        per_class_dice[cls] += dice_val
                        per_class_count[cls] += 1
                
                total_loss += loss.item()
                total_iou += iou
                total_dice += dice
                num_batches += 1
        
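        # Aggregate metrics
        metrics = {
            f"{metric_key_prefix}_loss": total_loss / num_batches,
            f"{metric_key_prefix}_iou": total_iou / num_batches,
            f"{metric_key_prefix}_dice": total_dice / num_batches,
        }
        
        # Add per-class metrics
        for cls in range(num_classes):
            if per_class_count[cls] > 0:
                metrics[f"{metric_key_prefix}_iou_{id2label[cls]}"] = per_class_iou[cls] / per_class_count[cls]
                metrics[f"{metric_key_prefix}_dice_{id2label[cls]}"] = per_class_dice[cls] / per_class_count[cls]
        
        # Trainer.evaluate() reads `.metrics` and `.num_samples` from the loop's result,
        # so the dict must be wrapped in an EvalLoopOutput instead of returned directly
        from transformers.trainer_utils import EvalLoopOutput
        return EvalLoopOutput(
            predictions=None,
            label_ids=None,
            metrics=metrics,
            num_samples=self.num_examples(dataloader),
        )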

print("✓ Custom trainer defined")
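
Both `compute_loss` and `evaluation_loop` pass `ignore_index=-100` to `F.cross_entropy`, but `mask_to_class_ids` never emits `-100`, so every pixel currently contributes to the loss. If some pixels should be excluded (for example, unmapped intensities or uncertain borders), setting their label to `-100` is enough. The toy check below, a sketch with random logits, confirms that the ignored pixel simply drops out of the mean:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(1, 4, 2, 2)              # (batch, classes, H, W)
labels = torch.tensor([[[0, 1], [-100, 3]]])  # one pixel marked "ignore"

loss_ignore = F.cross_entropy(logits, labels, ignore_index=-100)

# Same value computed by hand: mean negative log-probability over the 3 labeled pixels
log_probs = F.log_softmax(logits, dim=1)
picked = [log_probs[0, 0, 0, 0], log_probs[0, 1, 0, 1], log_probs[0, 3, 1, 1]]
loss_manual = -torch.stack(picked).mean()
```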

## 14. Train the Model

In [None]:
trainer = SegmentationTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

print("\n" + "="*60)
print("STARTING TRAINING")
print("="*60)
print(f"Training on {len(train_dataset)} samples")
print(f"Validating on {len(val_dataset)} samples")
print("="*60 + "\n")

train_result = trainer.train()

print("\n" + "="*60)
print("✓ TRAINING COMPLETED!")
print("="*60)

## 15. Save the Model

In [None]:
print("\nSaving model...")
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)
print(f"✓ Model saved to {OUTPUT_DIR}")

## 16. Inference and Visualization

In [None]:
def visualize_predictions(model, dataset, processor, num_samples=5, indices=None):
    """Visualize model predictions"""
    model.eval()
    device = next(model.parameters()).device
    
    if indices is None:
        indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)
    else:
        num_samples = len(indices)
    
    fig, axes = plt.subplots(num_samples, 4, figsize=(20, 5*num_samples))
    if num_samples == 1:
        axes = axes.reshape(1, -1)
    
    with torch.no_grad():
        for idx, sample_idx in enumerate(indices):
            sample = dataset[sample_idx]
            pixel_values = sample["pixel_values"].unsqueeze(0).to(device)
            true_mask = sample["labels"]
            
            # Get prediction
            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits
            
            # Upsample to original size
            upsampled_logits = F.interpolate(
                logits,
                size=true_mask.shape,
                mode="bilinear",
                align_corners=False
            )
            pred_mask = upsampled_logits.argmax(dim=1).squeeze().cpu()
            
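            # Undo the processor's (x - mean) / std normalization for display
            mean = np.array(processor.image_mean).reshape(1, 1, 3)
            std = np.array(processor.image_std).reshape(1, 1, 3)
            img = pixel_values.squeeze(0).cpu().numpy().transpose(1, 2, 0)
            img = np.clip(img * std + mean, 0, 1)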
            
            # Compute metrics for this sample
            sample_iou = compute_iou(pred_mask.unsqueeze(0), true_mask.unsqueeze(0), num_classes)
            sample_dice = compute_dice(pred_mask.unsqueeze(0), true_mask.unsqueeze(0), num_classes)
            
            # Plot input image
            axes[idx, 0].imshow(img)
            axes[idx, 0].set_title("Input Image", fontsize=12, fontweight='bold')
            axes[idx, 0].axis('off')
            
            # Plot ground truth
            axes[idx, 1].imshow(true_mask, cmap='tab10', vmin=0, vmax=num_classes-1)
            axes[idx, 1].set_title("Ground Truth", fontsize=12, fontweight='bold')
            axes[idx, 1].axis('off')
            
            # Plot prediction
            axes[idx, 2].imshow(pred_mask, cmap='tab10', vmin=0, vmax=num_classes-1)
            axes[idx, 2].set_title(f"Prediction\nIoU: {sample_iou:.3f}, Dice: {sample_dice:.3f}", 
                                  fontsize=12, fontweight='bold')
            axes[idx, 2].axis('off')
            
            # Plot overlay
            img_uint8 = (img * 255).astype(np.uint8)
            overlay = img_uint8.copy()
            colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
            
            for cls in range(1, num_classes):  # Skip background
                mask_bool = pred_mask.numpy() == cls
                if np.any(mask_bool):
                    color = colors[cls][:3] * 255
                    overlay[mask_bool] = (overlay[mask_bool] * 0.4 + color * 0.6).astype(np.uint8)
            
            axes[idx, 3].imshow(overlay)
            axes[idx, 3].set_title("Overlay", fontsize=12, fontweight='bold')
            axes[idx, 3].axis('off')
    
    plt.tight_layout()
    plt.show()

print("\n" + "="*60)
print("VISUALIZING PREDICTIONS ON VALIDATION SET")
print("="*60 + "\n")

visualize_predictions(model, val_dataset, processor, num_samples=5)

## 17. Final Evaluation Metrics

In [None]:
print("\n" + "="*60)
print("FINAL EVALUATION ON VALIDATION SET")
print("="*60)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)

total_iou = 0
total_dice = 0
per_class_iou = np.zeros(num_classes)
per_class_dice = np.zeros(num_classes)
per_class_count = np.zeros(num_classes)
num_batches = 0

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        
        outputs = model(pixel_values=pixel_values)
        logits = outputs.logits
        
        upsampled_logits = F.interpolate(
            logits,
            size=labels.shape[-2:],
            mode="bilinear",
            align_corners=False
        )
        preds = upsampled_logits.argmax(dim=1)
        
        # Overall metrics
        iou = compute_iou(preds, labels, num_classes)
        dice = compute_dice(preds, labels, num_classes)
        
        # Per-class metrics
        class_metrics = compute_per_class_metrics(preds, labels, num_classes)
        for cls in range(num_classes):
            iou_val = class_metrics[f"class_{cls}_iou"]
            dice_val = class_metrics[f"class_{cls}_dice"]
            if not np.isnan(iou_val):
                per_class_iou[cls] += iou_val
                per_class_dice[cls] += dice_val
                per_class_count[cls] += 1
        
        total_iou += iou
        total_dice += dice
        num_batches += 1

mean_iou = total_iou / num_batches
mean_dice = total_dice / num_batches

print(f"\n{'='*60}")
print(f"OVERALL METRICS")
print(f"{'='*60}")
print(f"Mean IoU:  {mean_iou:.4f}")
print(f"Mean Dice: {mean_dice:.4f}")

print(f"\n{'='*60}")
print(f"PER-CLASS METRICS")
print(f"{'='*60}")
print(f"{'Class':<20} {'IoU':>10} {'Dice':>10} {'Samples':>10}")
print(f"{'-'*60}")

for cls in range(num_classes):
    organ_name = id2label[cls]
    if per_class_count[cls] > 0:
        cls_iou = per_class_iou[cls] / per_class_count[cls]
        cls_dice = per_class_dice[cls] / per_class_count[cls]
        print(f"{organ_name:<20} {cls_iou:>10.4f} {cls_dice:>10.4f} {int(per_class_count[cls]):>10}")
    else:
        print(f"{organ_name:<20} {'N/A':>10} {'N/A':>10} {'0':>10}")

print(f"{'='*60}\n")

## 18. Prediction Utilities for Inference

In [None]:
def predict_single_image(image_path, model, processor, device=None):
    """
    Predict segmentation mask for a single image
    
    Args:
        image_path: Path to input image
        model: Trained SegFormer model
        processor: SegFormer image processor
        device: Device to run inference on (defaults to CUDA if available, else CPU)
    
    Returns:
        pred_mask: Predicted class ID per pixel at the original resolution (numpy uint8 array)
        confidence: Softmax probability of the predicted class per pixel (numpy float array)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)
    
    # Load and process image
    image = Image.open(image_path).convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs['pixel_values'].to(device)
    
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
        
        # Upsample the logits (not the argmax) to the original size, as in training and
        # evaluation; PIL's image.size is (W, H), hence the reversal
        logits = F.interpolate(
            outputs.logits,
            size=image.size[::-1],
            mode="bilinear",
            align_corners=False
        )
        
        # Per-pixel class probabilities, then the most likely class and its probability
        probs = F.softmax(logits, dim=1)
        confidence, pred_mask = torch.max(probs, dim=1)
    
    pred_mask = pred_mask.squeeze(0).cpu().numpy().astype(np.uint8)
    confidence = confidence.squeeze(0).cpu().numpy()
    return pred_mask, confidence

def visualize_single_prediction(image_path, pred_mask, confidence=None):
    """Visualize prediction for a single image"""
    image = np.array(Image.open(image_path).convert('RGB'))
    
    fig, axes = plt.subplots(1, 3 if confidence is None else 4, figsize=(20, 5))
    
    # Original image
    axes[0].imshow(image)
    axes[0].set_title("Original Image", fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    # Predicted mask
    axes[1].imshow(pred_mask, cmap='tab10', vmin=0, vmax=num_classes-1)
    axes[1].set_title("Predicted Mask", fontsize=14, fontweight='bold')
    axes[1].axis('off')
    
    # Overlay
    overlay = image.copy().astype(float)
    colors = plt.cm.Set1(np.linspace(0, 1, num_classes))
    
    for cls in range(1, num_classes):
        mask_bool = pred_mask == cls
        if np.any(mask_bool):
            color = colors[cls][:3] * 255
            overlay[mask_bool] = overlay[mask_bool] * 0.4 + color * 0.6
    
    axes[2].imshow(overlay.astype(np.uint8))
    axes[2].set_title("Overlay", fontsize=14, fontweight='bold')
    axes[2].axis('off')
    
    # Confidence map
    if confidence is not None:
        im = axes[3].imshow(confidence, cmap='viridis', vmin=0, vmax=1)
        axes[3].set_title("Confidence Map", fontsize=14, fontweight='bold')
        axes[3].axis('off')
        plt.colorbar(im, ax=axes[3], fraction=0.046, pad=0.04)
    
    plt.tight_layout()
    plt.show()

print("✓ Inference functions created")
print("\nExample usage:")
print("  pred_mask, confidence = predict_single_image('/path/to/image.png', model, processor)")
print("  visualize_single_prediction('/path/to/image.png', pred_mask, confidence)")
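
Horizontal-flip test-time augmentation (listed among the improvements below) fits naturally on top of these utilities: predict on the frame and on its mirror image, flip the second result back, and average the probabilities. A sketch: `predict_with_flip_tta` is a hypothetical helper, it assumes a model whose output has a `.logits` tensor of shape (N, C, h, w) as SegFormer's does, and flips are a reasonable symmetry here only because training already used `HorizontalFlip`.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def predict_with_flip_tta(model, pixel_values, out_size):
    """Average softmax probabilities over the original and horizontally flipped input."""
    def probs(x):
        logits = model(pixel_values=x).logits
        logits = F.interpolate(logits, size=out_size, mode="bilinear", align_corners=False)
        return logits.softmax(dim=1)

    p = probs(pixel_values)
    p_flip = probs(torch.flip(pixel_values, dims=[-1]))
    p_flip = torch.flip(p_flip, dims=[-1])  # undo the flip so pixels line up again
    avg = (p + p_flip) / 2
    confidence, pred = avg.max(dim=1)
    return pred, confidence
```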

## 🎯 Model Training Complete!

### 📊 Results Summary
- Model trained on gynecological surgery frames
- Segments 4 classes: background, uterus, ovary, fallopian_tube
- Metrics: IoU and Dice coefficient computed per class

### 🚀 Next Steps for Improvement

#### Quick Improvements:
1. **Try SegFormer-B1**: a larger encoder (roughly 14M vs. 4M parameters for B0) that often improves accuracy
   ```python
   MODEL_NAME = "nvidia/segformer-b1-finetuned-ade-512-512"
   ```

2. **Increase training epochs**: Try 75-100 epochs
3. **Adjust learning rate**: Try 3e-5 or 7e-5
4. **Larger batch size**: If GPU allows, increase to 8 or 16

#### Advanced Techniques:
5. **Class weighting**: Handle class imbalance
6. **Combined loss**: Dice + CE loss
7. **Test-time augmentation**: Average predictions from augmented versions
8. **Post-processing**: CRF or morphological operations
9. **Ensemble**: Combine B0 and B1 predictions
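
Items 5 and 6 can be combined in one loss: class-weighted cross-entropy plus a soft Dice term computed from the softmax probabilities. The sketch below is one common formulation, not the only one. The function names are hypothetical, the inverse-frequency weighting (total pixels divided by `num_classes` times each class's count) is an assumed heuristic, and it assumes labels contain no `-100` pixels, which `F.one_hot` cannot encode.

```python
import torch
import torch.nn.functional as F

def ce_dice_loss(logits, labels, class_weights=None, dice_weight=1.0, eps=1e-6):
    """Weighted cross-entropy plus soft Dice loss for (N, C, H, W) logits and (N, H, W) labels."""
    ce = F.cross_entropy(logits, labels, weight=class_weights)
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(labels, num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and pixels, separately for each class
    intersection = (probs * one_hot).sum(dims)
    cardinality = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (cardinality + eps)
    return ce + dice_weight * (1 - dice.mean())

def inverse_frequency_weights(pixel_counts):
    """Per-class weights total / (num_classes * count), so rare classes weigh more."""
    counts = torch.as_tensor(pixel_counts, dtype=torch.float32)
    return counts.sum() / (len(counts) * counts.clamp(min=1))
```

To try it, call `ce_dice_loss` in place of `F.cross_entropy` in `SegmentationTrainer.compute_loss` on the upsampled logits, with weights computed from per-class pixel counts of the training masks and moved to the logits' device.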

#### Demo/Presentation:
10. **Video inference**: Process full surgical video sequences
11. **Real-time demo**: Show live segmentation on video frames
12. **Metrics dashboard**: Interactive visualization of performance
13. **Comparison**: Show before/after or compare with baseline

### 📁 Saved Files
- Model: `./segformer-gynsurg-b0/`
- Includes: model weights, config, processor

Good luck with your project! üöÄüè•