# Multi-Model Segmentation Inference

This notebook runs video segmentation inference with one of two architectures:

1. **DinoV2-Based Model**:
   - Self-supervised pre-training
   - Vision Transformer backbone
   - Multi-class segmentation

2. **U-Net Model**:
   - CNN-based architecture
   - ResNet34 encoder
   - Standard segmentation approach

The notebook supports switching between models and provides consistent visualization for comparison.

## 1. Imports

In [11]:
import os
from pathlib import Path
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import Dinov2Model, AutoImageProcessor
import cv2
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2

## 2. Model Architecture

### DinoV2 Architecture
The `DinoV2ForSemanticSegmentation` class implements:
- Frozen DinoV2-base backbone
- Custom decoder head
- Single-scale features taken from the last hidden state
- Progressive upsampling

Key components:
1. **Feature Extraction**:
   - Pre-trained ViT-B backbone
   - CLS token dropped, patch tokens kept
   - Patch tokens reshaped into a 2D feature map

2. **Decoder Design**:
   - Channel reduction (768→256→128→64)
   - Progressive upsampling
   - Bilinear interpolation
   - Final classification head
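
To make the decoder's resolution changes concrete, the following sketch traces channels and spatial size through each stage. It assumes a 224×224 input and the 14-pixel patch size of `facebook/dinov2-base`; `trace_decoder_shapes` is a hypothetical helper for illustration and is not part of the notebook.

```python
def trace_decoder_shapes(image_size=224, patch_size=14, num_classes=5):
    """Return (stage, channels, side) tuples for the decoder head."""
    side = image_size // patch_size  # side length of the patch grid
    stages = [("patch tokens", 768, side)]
    for out_channels in (256, 128, 64):
        # 3x3 convolution with padding=1 keeps the spatial size
        stages.append((f"conv to {out_channels}", out_channels, side))
        side *= 2  # bilinear upsampling with scale_factor=2
        stages.append(("upsample x2", out_channels, side))
    stages.append(("classifier 1x1", num_classes, side))
    stages.append(("resize to input", num_classes, image_size))
    return stages

for name, channels, side in trace_decoder_shapes():
    print(f"{name:<16} {channels:>4} x {side} x {side}")
```

Under these assumptions the head produces a 128×128 logit map, so the final interpolation to 224×224 is a 1.75× upscale.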

In [2]:
DINO_IMAGE_SIZE = 224

class DinoV2ForSemanticSegmentation(nn.Module):
    def __init__(self, num_classes=5):
        super().__init__()
        self.dinov2 = Dinov2Model.from_pretrained("facebook/dinov2-base")
        for param in self.dinov2.parameters():
            param.requires_grad = False
        self.head = nn.Sequential(
            nn.Conv2d(768, 256, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(256, 128, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(64, num_classes, kernel_size=1)
        )

    def forward(self, pixel_values):
        # Only the final layer is used, so intermediate hidden states are not requested
        outputs = self.dinov2(pixel_values)
        last_hidden_state = outputs.last_hidden_state
        # Drop the CLS token (index 0) and keep the patch tokens
        patch_tokens = last_hidden_state[:, 1:, :]
        batch_size, seq_len, num_channels = patch_tokens.shape
        # Patch tokens form a square grid (assumes a square input)
        height = width = int(seq_len**0.5)
        # (B, N, C) -> (B, C, H, W)
        feature_map = patch_tokens.permute(0, 2, 1).contiguous().reshape(batch_size, num_channels, height, width)
        logits = self.head(feature_map)
        # Resize the logits to the model input resolution
        final_logits = nn.functional.interpolate(logits, size=(DINO_IMAGE_SIZE, DINO_IMAGE_SIZE), mode='bilinear', align_corners=False)
        return final_logits
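
The token-to-feature-map step in `forward` can be checked in isolation. This sketch uses NumPy, with random values standing in for the backbone output. The shapes assume one CLS token followed by 256 patch tokens (a 16×16 grid), which is an assumption about a 224×224 input rather than something the notebook prints. `tokens_to_feature_map` is a hypothetical name.

```python
import numpy as np

def tokens_to_feature_map(hidden_state):
    """Drop the CLS token and reshape (B, 1+N, C) tokens into (B, C, H, W)."""
    patch_tokens = hidden_state[:, 1:, :]  # remove the CLS token at index 0
    batch, seq_len, channels = patch_tokens.shape
    side = int(round(seq_len ** 0.5))
    if side * side != seq_len:
        raise ValueError(f"{seq_len} patch tokens do not form a square grid")
    # (B, N, C) -> (B, C, N) -> (B, C, H, W); tokens are in row-major order
    return patch_tokens.transpose(0, 2, 1).reshape(batch, channels, side, side)

rng = np.random.default_rng(0)
hidden = rng.standard_normal((2, 1 + 256, 768)).astype(np.float32)
feature_map = tokens_to_feature_map(hidden)
print(feature_map.shape)  # (2, 768, 16, 16)
```

The explicit square-grid check is an addition for clarity. With `int(seq_len**0.5)` alone, a token count that is not a perfect square would only surface as a reshape error.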

## 3. Configuration

### Model Selection and Setup
- Choice between DinoV2 and U-Net architectures
- Run-specific model loading
- Automatic output organization

### Path Configuration
1. **Model Paths**:
   - Architecture-specific model files
   - Run-based organization
   - Automatic path resolution

2. **Video Settings**:
   - Input/output paths
   - Automated directory creation
   - Consistent naming conventions

3. **Hardware Setup**:
   - Device detection in priority order: CUDA, then MPS, then CPU
   - Model and input tensors moved to the selected device
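
The run-based layout described above can be sketched as a small helper. `resolve_paths` is a hypothetical function, the directory names simply mirror the ones used in this notebook, and `example_run` and `clip.mp4` are placeholder values. The example runs inside a temporary directory so it does not touch real run folders.

```python
import tempfile
from pathlib import Path

def resolve_paths(root, run_name, model_filename, model_type, video_filename):
    """Derive model and output paths for one run and create the output folder."""
    metrics_dir = Path(root) / "metrics" / run_name
    output_dir = metrics_dir / "video_predictions"
    # parents=True also creates metrics/<run_name> if it is missing
    output_dir.mkdir(parents=True, exist_ok=True)
    return {
        "model": metrics_dir / model_filename,
        "output_video": output_dir / f"predicted_{model_type}_{video_filename}",
    }

with tempfile.TemporaryDirectory() as tmp:
    paths = resolve_paths(tmp, "example_run", "best_model.pth", "dino", "clip.mp4")
    output_dir_exists = paths["output_video"].parent.is_dir()
    output_name = paths["output_video"].name
    print(output_name)
```

Keeping every derived path under the run folder means predictions from different runs and model types cannot overwrite each other.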

In [15]:
# --- CHOOSE THE MODEL TYPE ---
MODEL_TYPE = "unet"  # Options: "dino" or "unet"

# --- CHOOSE THE RUN AND VIDEO ---
# If MODEL_TYPE is "dino", use a DinoV2 run name
if MODEL_TYPE == "dino":
    RUN_NAME = "dinov2_2025-08-03_19-45-00" # <--- DinoV2 run folder name
    MODEL_FILENAME = "best_model.pth"
# If MODEL_TYPE is "unet", use a U-Net run name
else:
    RUN_NAME = "2025-08-03_15-56-59" # <--- U-Net run folder name
    MODEL_FILENAME = "peatland_segmentation_model.pth"

VIDEO_FILENAME = "Clip_1_42s.mp4" # <--- Video you want to process

# --- Paths and settings derived automatically ---
METRICS_DIR = Path("metrics") / RUN_NAME
MODEL_PATH = METRICS_DIR / MODEL_FILENAME
VIDEO_PATH = Path("../data/video/splits") / VIDEO_FILENAME
OUTPUT_DIR = METRICS_DIR / "video_predictions"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # also creates missing parent folders
OUTPUT_VIDEO_PATH = OUTPUT_DIR / f"predicted_{MODEL_TYPE}_{VIDEO_FILENAME}"

# Prefer CUDA, then Apple MPS, then fall back to CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"MODEL_TYPE: {MODEL_TYPE}")
print(f"Loading model from run: {RUN_NAME}")
print(f"Processing video: {VIDEO_PATH}")
print(f"Using device: {DEVICE}")

MODEL_TYPE: unet
Loading model from run: 2025-08-03_15-56-59
Processing video: ../data/video/splits/Clip_1_42s.mp4
Using device: mps


## 4. Load Model


In [16]:
if MODEL_TYPE == "dino":
    DINO_IMAGE_SIZE = 224
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    model = DinoV2ForSemanticSegmentation(num_classes=5).to(DEVICE)
else: # unet
    IMG_HEIGHT = 480
    IMG_WIDTH = 640
    IMAGENET_MEAN = [0.485, 0.456, 0.406]
    IMAGENET_STD = [0.229, 0.224, 0.225]
    processor = A.Compose([
        A.Resize(height=IMG_HEIGHT, width=IMG_WIDTH),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD, max_pixel_value=255.0),
        ToTensorV2(),
    ])
    model = smp.Unet("resnet34", encoder_weights=None, in_channels=3, classes=5).to(DEVICE)

model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device(DEVICE)))
model.eval()
print("Model loaded successfully.")

Model loaded successfully.
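
For reference, the normalization in the U-Net branch is equivalent to scaling pixels to [0, 1], subtracting the ImageNet mean, dividing by the standard deviation, and moving the channel axis first. The NumPy sketch below reproduces only that arithmetic (it does not resize) on a synthetic frame; `normalize_frame` is a hypothetical name.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize_frame(image_rgb):
    """Convert an HxWx3 uint8 RGB image to a normalized 3xHxW float32 array."""
    scaled = image_rgb.astype(np.float32) / 255.0
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD  # broadcasts over H and W
    return normalized.transpose(2, 0, 1)                   # HWC -> CHW

white = np.full((480, 640, 3), 255, dtype=np.uint8)  # synthetic all-white frame
chw = normalize_frame(white)
print(chw.shape, chw[:, 0, 0])
```

Inference must use the same statistics as training, which is why the mean and standard deviation are fixed constants here.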


## 5. Visualization and Processing Functions

### Color Scheme
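Color tuples are in BGR order, because they are drawn directly onto OpenCV frames.

- **PATH**: Purple (152, 16, 60)
- **NATURAL_GROUND**: Pink (246, 41, 132), defined but not drawn by the video loop in section 6
- **TREE**: Light Blue (228, 193, 110)
- **VEGETATION**: Yellow (58, 221, 254)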
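
One way to apply this color scheme to a whole mask at once is a lookup table indexed by class ID. The sketch assumes the class IDs used later in this notebook (0 to 4, with 4 meaning IGNORE) and treats the tuples as BGR. `colorize_mask` is a hypothetical helper, and the black entry for IGNORE is an assumption.

```python
import numpy as np

CLASS_COLORS_BGR = np.array([
    (152, 16, 60),    # 0: PATH
    (246, 41, 132),   # 1: NATURAL_GROUND
    (228, 193, 110),  # 2: TREE
    (58, 221, 254),   # 3: VEGETATION
    (0, 0, 0),        # 4: IGNORE (assumed to stay black)
], dtype=np.uint8)

def colorize_mask(mask):
    """Map an HxW array of class IDs to an HxWx3 uint8 color image."""
    return CLASS_COLORS_BGR[mask]  # integer-array indexing does the per-pixel lookup

mask = np.array([[0, 1], [3, 4]], dtype=np.uint8)
colored = colorize_mask(mask)
print(colored.shape)  # (2, 2, 3)
```

A table lookup avoids one boolean mask per class and scales to any number of classes.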

### Processing Pipeline
1. **Frame Processing**:
   - RGB conversion
   - Model-specific preprocessing
   - Batch dimension handling

2. **Mask Generation**:
   - Model inference
   - Argmax prediction
   - Resolution matching

3. **Visualization**:
   - Class-based coloring
   - Contour drawing
   - Transparent overlays
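
The mask-generation step (argmax followed by resolution matching) can be illustrated without a model. This NumPy sketch uses random logits and a simple index-based nearest-neighbor resize as a stand-in for `cv2.resize(..., interpolation=cv2.INTER_NEAREST)`. The exact pixel mapping may differ slightly from OpenCV's, and `logits_to_mask` is a hypothetical helper.

```python
import numpy as np

def logits_to_mask(logits, out_height, out_width):
    """Turn (C, H, W) class logits into an (out_height, out_width) class-ID mask."""
    class_ids = np.argmax(logits, axis=0).astype(np.uint8)  # most likely class per pixel
    in_height, in_width = class_ids.shape
    # Nearest-neighbor: pick the source row and column each output pixel falls into
    rows = np.arange(out_height) * in_height // out_height
    cols = np.arange(out_width) * in_width // out_width
    return class_ids[rows[:, None], cols[None, :]]

rng = np.random.default_rng(0)
logits = rng.standard_normal((5, 224, 224))
mask = logits_to_mask(logits, out_height=480, out_width=640)
print(mask.shape, mask.dtype)
```

Nearest-neighbor resizing keeps every output value a valid class ID, whereas bilinear interpolation would blend neighboring IDs into meaningless intermediate values.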

In [17]:
CLASS_COLORS = {
    "PATH": (152, 16, 60), "NATURAL_GROUND": (246, 41, 132),
    "TREE": (228, 193, 110), "VEGETATION": (58, 221, 254),
}
CLASS_IDS = {"PATH": 0, "NATURAL_GROUND": 1, "TREE": 2, "VEGETATION": 3, "IGNORE": 4}

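def get_prediction_mask(frame: np.ndarray, model, processor, model_type: str, device: str) -> np.ndarray:
    """Takes a BGR video frame, runs inference, and returns the class-ID mask at frame resolution."""
    # OpenCV decodes frames as BGR; both preprocessors expect RGB
    image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    if model_type == "dino":
        # The Hugging Face processor returns a batched tensor of shape (1, 3, H, W).
        # Note: its default settings may resize and center-crop, in which case the
        # mask covers only the central region of a non-square frame.
        image_pil = Image.fromarray(image_rgb)
        pixel_values = processor(image_pil, return_tensors="pt").pixel_values.to(device)
    else: # unet
        # Albumentations returns a (3, H, W) tensor, so add the batch dimension
        augmented = processor(image=image_rgb)
        pixel_values = augmented['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(pixel_values)

    # Most likely class per pixel, as a uint8 array of shape (H, W)
    preds = torch.argmax(outputs, dim=1).squeeze(0).cpu().numpy().astype(np.uint8)
    # Nearest-neighbor keeps class IDs intact when resizing to the frame size
    pred_resized = cv2.resize(preds, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST)

    return pred_resized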

## 6. Video Processing

In [None]:
cap = cv2.VideoCapture(str(VIDEO_PATH))
if not cap.isOpened():
    print(f"Error: Could not open video file {VIDEO_PATH}")
else:
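    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep the frame rate as a float: truncating e.g. 29.97 to 29 would change playback speed
    fps = cap.get(cv2.CAP_PROP_FPS)
    # The reported frame count can be approximate; the loop below also stops when reading fails
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(str(OUTPUT_VIDEO_PATH), fourcc, fps, (frame_width, frame_height))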

    print(f"Processing {total_frames} frames...")
    
    for _ in tqdm(range(total_frames)):
        ret, frame = cap.read()
        if not ret:
            break
            
        # 1. Get the raw prediction mask from the model
        pred_mask = get_prediction_mask(frame, model, processor, MODEL_TYPE, DEVICE)
        
        # 2. Create an empty (black) overlay the same size as the frame
        overlay = np.zeros_like(frame, dtype=np.uint8)

        # 3. Fill the PATH area with its class color (opacity is applied when blending)
        path_mask = (pred_mask == CLASS_IDS["PATH"])
        overlay[path_mask] = CLASS_COLORS["PATH"]
        
        # 4. Find and draw contours for VEGETATION
        veg_mask = (pred_mask == CLASS_IDS["VEGETATION"]).astype(np.uint8)
        contours, _ = cv2.findContours(veg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, CLASS_COLORS["VEGETATION"], 2) # 2px thick line
        
        # 5. Find and draw contours for TREES
        tree_mask = (pred_mask == CLASS_IDS["TREE"]).astype(np.uint8)
        contours, _ = cv2.findContours(tree_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(overlay, contours, -1, CLASS_COLORS["TREE"], 2) # 2px thick line

        # 6. Blend: add 60% of the overlay onto the unchanged frame (saturating at 255)
        final_frame = cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)
        
        out.write(final_frame)
        
    cap.release()
    out.release()
    print("\nProcessing complete.")
    print(f"Output video saved to: {OUTPUT_VIDEO_PATH}")

Processing 1287 frames...


100%|██████████| 1287/1287 [02:00<00:00, 10.69it/s]


Processing complete.
Output video saved to: metrics/2025-08-03_15-56-59/video_predictions/predicted_unet_Clip_1_42s.mp4
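
The loop above blends with `cv2.addWeighted(frame, 1.0, overlay, 0.6, 0)`, which adds 60% of the overlay color on top of the full-brightness frame and saturates at 255. If a conventional alpha blend inside the masked region is preferred, one possible alternative is sketched below. It uses NumPy only, `alpha_blend_region` is a hypothetical helper, and reusing 0.6 as the opacity is an assumption about the intended look.

```python
import numpy as np

def alpha_blend_region(frame, mask, color, alpha=0.6):
    """Blend `color` into `frame` where `mask` is True; other pixels are unchanged."""
    blended = frame.astype(np.float32)
    color = np.asarray(color, dtype=np.float32)
    # Weighted average of the original pixel and the class color
    blended[mask] = (1.0 - alpha) * blended[mask] + alpha * color
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)

frame = np.full((4, 4, 3), 200, dtype=np.uint8)  # synthetic gray frame
path_mask = np.zeros((4, 4), dtype=bool)
path_mask[:2, :] = True                           # top half stands in for PATH
result = alpha_blend_region(frame, path_mask, (152, 16, 60))
print(result[0, 0], result[3, 3])
```

With this variant bright regions are tinted instead of clipping toward white, at the cost of slightly darkening the masked area.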



