# SAM 3 with Linear Probing on DVARF Dataset

This notebook demonstrates the complete pipeline for training and applying a linear probe on top of SAM 3's zero-shot predictions to improve object detection performance on the DVARF aerial accident dataset.

## 1. Environment Setup

Clone the repository and install required dependencies.

In [None]:
%cd /content

# Clone the DVARF repository
!git clone https://github.com/AntoFratta/DVARF.git

%cd /content/DVARF

# Create a Colab-compatible requirements file by removing Windows-specific packages
!grep -v "triton-windows" requirements.txt > requirements_colab.txt

# Install all project dependencies
!pip install -r requirements_colab.txt

## 2. Extract Dataset

Extract the data archive containing images and annotations.

In [None]:
%cd /content/DVARF

# Install unrar utility
!apt-get update -y > /dev/null
!apt-get install -y unrar > /dev/null

# Extract data.rar to create the data/ directory structure
!unrar x data.rar ./ > /dev/null

# Verify extraction
print("After extracting data.rar:")
!ls
print("\nInside data/:")
!ls data
print("\nInside data/raw/:")
!ls data/raw
print("\nImage splits:")
!ls data/raw/images
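
As a complement to the `ls` checks above, a short Python sketch can count the images in each split. It assumes that `data/raw/images` holds one subdirectory per split and that images use the `.jpg` or `.png` extensions; `count_images_per_split` is an illustrative helper, not part of the DVARF code.

```python
from pathlib import Path

# Extensions assumed for this sketch
IMAGE_SUFFIXES = {".jpg", ".png"}


def count_images_per_split(images_root):
    """Return {split_name: number_of_image_files} for each subdirectory."""
    root = Path(images_root)
    counts = {}
    for split_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        counts[split_dir.name] = sum(
            1
            for f in split_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES
        )
    return counts


images_root = Path("data/raw/images")
if images_root.is_dir():
    for split, n in count_images_per_split(images_root).items():
        print(f"{split}: {n} images")
else:
    print(f"{images_root} not found; was data.rar extracted?")
```

An unexpectedly small count for a split is usually a sign that the archive was only partially extracted.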

## 3. Install SAM 3

Clone and install the official SAM 3 repository.

In [None]:
%cd /content

# Clone the official SAM 3 repository from Meta
!git clone https://github.com/facebookresearch/sam3.git

In [None]:
%cd /content/sam3

# Install SAM 3 in editable mode
!pip install -e .

## 4. Hugging Face Authentication

Authenticate with Hugging Face to access SAM 3 model weights.

In [None]:
from huggingface_hub import login

# Login to Hugging Face (widget will prompt for your token)
login()

## 5. Run SAM 3 on Training Split

Execute zero-shot inference on the training set to generate predictions for linear probe training.

In [None]:
%cd /content/DVARF

import sys
if "/content/DVARF" not in sys.path:
    sys.path.insert(0, "/content/DVARF")

from scripts.run_sam3_on_split import run_sam3_on_split

# Run SAM 3 on TRAIN split (segmentations not needed for linear probing)
run_sam3_on_split(
    split="train",
    score_threshold=0.26,
    max_images=None,
    save_segmentations=False,
    max_masks_per_image_per_class=None,
)

## 6. Build Linear Probe Dataset

Create training dataset for the linear probe by matching predictions with ground truth annotations.

In [None]:
%cd /content/DVARF

# Build the linear probe training dataset by matching SAM 3 predictions to ground truth
!python scripts/build_linear_probe_dataset.py
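
The script itself is not shown here, so the following is only a sketch of how matching predictions to ground truth could produce binary training labels. It assumes boxes in pixel `(x1, y1, x2, y2)` format and uses the same greedy rule as the IoU metric cell later in this notebook: predictions are visited by descending score, and each ground-truth box can be matched once. The names `box_iou` and `label_predictions` are hypothetical.

```python
def box_iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def label_predictions(preds, gt_boxes, iou_threshold=0.5):
    """Label each (box, score) prediction 1 if it matches a GT box, else 0.

    Returns (score, label) pairs in the original prediction order.
    """
    order = sorted(range(len(preds)), key=lambda i: preds[i][1], reverse=True)
    used = [False] * len(gt_boxes)
    labels = [0] * len(preds)
    for i in order:
        box, _score = preds[i]
        # Find the ground-truth box with the highest overlap
        best_j, best_iou = -1, 0.0
        for j, gt in enumerate(gt_boxes):
            iou = box_iou(box, gt)
            if iou > best_iou:
                best_j, best_iou = j, iou
        # Positive only if the overlap is high enough and the GT box is free
        if best_j >= 0 and best_iou >= iou_threshold and not used[best_j]:
            used[best_j] = True
            labels[i] = 1
    return [(preds[i][1], labels[i]) for i in range(len(preds))]


preds = [((0, 0, 10, 10), 0.9), ((1, 1, 10, 10), 0.6), ((50, 50, 60, 60), 0.4)]
gt = [(0, 0, 10, 10)]
print(label_predictions(preds, gt))  # [(0.9, 1), (0.6, 0), (0.4, 0)]
```

The second prediction overlaps the ground-truth box well (IoU 0.81) but is labeled 0 because the higher-scoring prediction already claimed that box.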

## 7. Train Linear Probe

Train a simple logistic regression classifier for each object class to refine SAM 3's predictions.

In [None]:
%cd /content/DVARF

# Train the linear probe classifier
!python scripts/train_linear_probe.py
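
To illustrate the idea of a per-class logistic regression, here is a minimal, dependency-free sketch. It assumes the only input feature is the SAM 3 confidence score and that the label marks whether the prediction matched a ground-truth box; the features and optimizer used by `train_linear_probe.py` may differ. `fit_logistic_1d` and `fit_per_class` are hypothetical names.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def fit_logistic_1d(scores, labels, lr=0.5, epochs=2000):
    """Fit p(y=1 | s) = sigmoid(w * s + b) by batch gradient descent."""
    w, b = 0.0, 0.0
    n = len(scores)
    for _ in range(epochs):
        grad_w = grad_b = 0.0
        for s, y in zip(scores, labels):
            err = sigmoid(w * s + b) - y  # gradient of the log-loss w.r.t. the logit
            grad_w += err * s
            grad_b += err
        w -= lr * grad_w / n
        b -= lr * grad_b / n
    return w, b


def fit_per_class(samples):
    """samples: {class_id: [(score, label), ...]} -> {class_id: (w, b)}."""
    probes = {}
    for class_id, pairs in samples.items():
        scores = [s for s, _ in pairs]
        labels = [y for _, y in pairs]
        probes[class_id] = fit_logistic_1d(scores, labels)
    return probes


# Toy data: low scores tend to be false positives, high scores true positives
toy = {0: [(0.3, 0), (0.35, 0), (0.4, 0), (0.5, 0),
           (0.7, 1), (0.8, 1), (0.85, 1), (0.9, 1)]}
w, b = fit_per_class(toy)[0]
print(f"w={w:.2f}, b={b:.2f}, p(0.9)={sigmoid(w * 0.9 + b):.2f}")
```

Fitting one `(w, b)` pair per class lets the mapping from raw score to probability differ between classes, which is the point of training a separate classifier for each.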

## 8. Run SAM 3 on Test Split

Execute zero-shot inference on the test set.

In [None]:
%cd /content/DVARF

import sys
if "/content/DVARF" not in sys.path:
    sys.path.insert(0, "/content/DVARF")

from scripts.run_sam3_on_split import run_sam3_on_split

# Run SAM 3 on TEST split with segmentation masks saved
run_sam3_on_split(
    split="test",
    score_threshold=0.26,
    max_images=None,
    save_segmentations=True,
    max_masks_per_image_per_class=None,
)

## 9. Apply Linear Probe to Test Predictions

Refine the test set predictions using the trained linear probe.

In [None]:
%cd /content/DVARF

# Apply the trained linear probe to test predictions
!python scripts/apply_linear_probe_to_split.py
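
A sketch of what the refinement step could look like is shown below. It rests on several assumptions that are not confirmed by this notebook: each prediction line has the form `class cx cy w h score`, the probe is a per-class `(w, b)` pair applied to the score, and low-probability detections are dropped. `rescore_lines` and `keep_threshold` are hypothetical.

```python
import math


def _sigmoid(z):
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def rescore_line(line, probes, keep_threshold=0.5):
    """Re-score one 'class cx cy w h score' line; return None if it is dropped."""
    parts = line.split()
    class_id, coords, score = int(parts[0]), parts[1:5], float(parts[5])
    if class_id in probes:
        w, b = probes[class_id]
        score = _sigmoid(w * score + b)  # probe output replaces the raw score
    if score < keep_threshold:
        return None
    return " ".join([str(class_id), *coords, f"{score:.6f}"])


def rescore_lines(lines, probes, keep_threshold=0.5):
    """Apply rescore_line to every non-empty line and keep the survivors."""
    kept = []
    for line in lines:
        if not line.strip():
            continue
        new_line = rescore_line(line, probes, keep_threshold)
        if new_line is not None:
            kept.append(new_line)
    return kept


probes = {0: (10.0, -6.0)}  # illustrative weights, not trained values
lines = [
    "0 0.5 0.5 0.2 0.2 0.9",
    "0 0.1 0.1 0.05 0.05 0.3",
    "1 0.4 0.4 0.1 0.1 0.7",
]
for out in rescore_lines(lines, probes):
    print(out)
```

In this sketch a class without a trained probe keeps its original score, so the third line passes through unchanged.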

## 10. Evaluate Linear Probe Results

Compute metrics on the refined predictions.

In [None]:
%cd /content/DVARF

# Make sure the output directory exists, then evaluate the linear-probe predictions and save the metrics
!mkdir -p results
!python scripts/eval_sam3_linear_probe_on_split.py | tee results/sam3_linear_probe_test_metrics.txt

## 11. Archive Results

Create compressed archives of prediction outputs and training artifacts for download.

In [None]:
%cd /content/DVARF

# Archive zero-shot SAM 3 predictions on test set
!zip -r results/sam3_test_predictions.zip data/processed/predictions/sam3_yolo/test

# Archive linear probe enhanced predictions on test set
!zip -r results/sam3_linear_probe_test_predictions.zip data/processed/predictions/sam3_linear_probe_yolo/test

In [None]:
%cd /content/DVARF

# Archive linear probe training artifacts (dataset and weights)
!zip -r results/sam3_linear_probe_training_artifacts.zip data/processed/linear_probe

## 12. Compute Additional Metrics for Comparison with the Thesis

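Compute two additional metrics for the linear-probe predictions:
- **Class-specific average IoU**: mean IoU over true positives, i.e. predictions matched to a ground-truth box of the same class with IoU ≥ 0.5. The overall figure is the unweighted mean of the per-class values, and a class with no true positives contributes 0.
- **Inference speed**: average wall-clock time per frame for SAM 3 inference plus post-processing (YOLO conversion and NMS). The measurement includes image loading from disk, because `predict_with_text()` reads each image from its file path. As written, the timed loop does not apply the linear probe.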
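
To make the IoU metric concrete, here is a minimal sketch of the aggregation step, assuming the per-class true-positive IoU values have already been collected. Like the cell below, it counts a class without true positives as 0.0 in the overall mean. The names `summarize_tp_ious` and `pooled_tp_iou` are illustrative and not part of the DVARF code; the pooled variant is included only for contrast.

```python
from statistics import fmean


def summarize_tp_ious(iou_per_class):
    """Return (per-class mean IoU, unweighted mean across classes)."""
    per_class = {c: (fmean(v) if v else 0.0) for c, v in iou_per_class.items()}
    macro = fmean(per_class.values()) if per_class else 0.0
    return per_class, macro


def pooled_tp_iou(iou_per_class):
    """Mean IoU over all true positives, regardless of class."""
    all_ious = [v for ious in iou_per_class.values() for v in ious]
    return fmean(all_ious) if all_ious else 0.0


# Toy values: class 2 has no true positives
tp_ious = {0: [0.8, 0.6], 1: [0.9], 2: []}
per_class, macro = summarize_tp_ious(tp_ious)
print(f"macro mean:  {macro:.4f}")                    # 0.5333
print(f"pooled mean: {pooled_tp_iou(tp_ious):.4f}")   # 0.7667
```

The two numbers differ because the unweighted mean penalizes classes that are never detected, while the pooled mean only describes the localization quality of the detections that were matched.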

In [None]:
%cd /content/DVARF

import sys
from pathlib import Path

import numpy as np

# Make the DVARF package importable from the notebook
if "/content/DVARF" not in sys.path:
    sys.path.insert(0, "/content/DVARF")

from src.config import get_labels_dir
from src.prompts import CLASS_PROMPTS
from src.eval_yolo import _load_yolo_dataset, _compute_iou

# ============================================================================
# PART 1: Calculate class-specific average IoU for Linear Probe predictions
# ============================================================================

split = "test"
confidence_threshold = 0.26
iou_threshold = 0.5
num_classes = len(CLASS_PROMPTS)

labels_dir = get_labels_dir(split)
# Linear probe predictions are stored in sam3_linear_probe_yolo directory
preds_dir = Path("data/processed/predictions/sam3_linear_probe_yolo") / split

print("="*70)
print("CLASS-SPECIFIC AVERAGE IoU (Linear Probe - True Positives with IoU ≥ 0.5)")
print("="*70)
print(f"Split: {split}")
print(f"Labels directory: {labels_dir}")
print(f"Predictions directory: {preds_dir}")
print(f"Confidence threshold: {confidence_threshold}")
print(f"IoU threshold: {iou_threshold}\n")

# Load ground truth and predictions
gt_by_class, preds_by_class = _load_yolo_dataset(
    labels_dir=labels_dir,
    preds_dir=preds_dir,
    num_classes=num_classes,
    confidence_threshold=confidence_threshold,
)

# Calculate IoU for True Positives per class.
# CLASS_PROMPTS maps class id -> text prompt; the prompt doubles as the display name.
class_names = dict(CLASS_PROMPTS)
iou_per_class = {}

for c in range(num_classes):
    gt_dict = gt_by_class.get(c, {})
    preds = preds_by_class.get(c, [])
    
    if len(preds) == 0:
        iou_per_class[c] = []
        continue
    
    # Sort predictions by score (highest first)
    preds_sorted = sorted(preds, key=lambda x: x[2], reverse=True)
    
    # Track which GT boxes have been matched
    matched = {}
    for img_id, boxes in gt_dict.items():
        matched[img_id] = np.zeros(len(boxes), dtype=bool)
    
    # Collect IoU values for True Positives
    tp_ious = []
    
    for img_id, box_pred, _score in preds_sorted:
        gt_boxes = gt_dict.get(img_id)
        if gt_boxes is None or gt_boxes.size == 0:
            continue  # False Positive (no GT in this image)
        
        # Compute IoU with all GT boxes in this image
        ious = _compute_iou(box_pred, gt_boxes)
        best_idx = int(np.argmax(ious))
        best_iou = float(ious[best_idx])
        
        # True Positive if IoU ≥ threshold and GT not already matched
        if best_iou >= iou_threshold and not matched[img_id][best_idx]:
            tp_ious.append(best_iou)
            matched[img_id][best_idx] = True
    
    iou_per_class[c] = tp_ious

# Calculate and display mean IoU per class
mean_ious = {}
for c in range(num_classes):
    class_name = class_names.get(c, f"class_{c}")
    ious = iou_per_class[c]

    if len(ious) > 0:
        mean_iou = np.mean(ious)
        mean_ious[c] = mean_iou
        print(f"  {class_name:15s}: mean_IoU = {mean_iou:.4f}  (n_TP = {len(ious)})")
    else:
        # A class with no true positives counts as 0.0 in the overall mean
        mean_ious[c] = 0.0
        print(f"  {class_name:15s}: mean_IoU = 0.0000  (n_TP = 0)")

# Calculate overall mean IoU (unweighted mean of the per-class values)
if len(mean_ious) > 0:
    iou_medio_totale = np.mean(list(mean_ious.values()))
    print(f"\n  {'MEAN (all classes)':15s}: overall_mean_IoU = {iou_medio_totale:.4f}")
else:
    iou_medio_totale = 0.0
    print(f"\n  {'MEAN (all classes)':15s}: overall_mean_IoU = 0.0000")


# Save metrics to file
metrics_file = Path("results/sam3_linear_probe_test_metrics.txt")
with open(metrics_file, "a", encoding="utf-8") as f:
    f.write("\n\n" + "="*70 + "\n")
    f.write("Linear Probe - CLASS-SPECIFIC AVERAGE IoU (True Positives with IoU ≥ 0.5)\n")
    f.write("="*70 + "\n")
    for c in range(num_classes):
        class_name = class_names.get(c, f"class_{c}")
        ious = iou_per_class[c]
        if len(ious) > 0:
            mean_iou = mean_ious[c]
            f.write(f"  {class_name:15s}: mean_IoU = {mean_iou:.4f}  (n_TP = {len(ious)})\n")
        else:
            f.write(f"  {class_name:15s}: mean_IoU = 0.0000  (n_TP = 0)\n")
    f.write(f"\n  {'MEAN (all classes)':15s}: overall_mean_IoU = {iou_medio_totale:.4f}\n")
    f.write("="*70 + "\n")

print(f"\nMetrics saved to: {metrics_file}")
print("="*70)

In [None]:
%cd /content/DVARF

import sys
import torch
from time import time
from pathlib import Path

if "/content/DVARF" not in sys.path:
    sys.path.insert(0, "/content/DVARF")

from src.config import get_images_dir
from src.prompts import CLASS_PROMPTS
from src.sam3_wrapper import Sam3ImageModel
from src.yolo_export import sam3_boxes_to_yolo, nms_yolo_boxes

# ======================================================================
# PART 2: Measure inference speed per frame (Linear Probe - SAM 3)
# ======================================================================
# NOTE: The timing includes image loading from disk because SAM 3's
# predict_with_text() method loads images internally from file paths.
# The measured time represents the complete pipeline: I/O + inference + post-processing.
# ======================================================================

split = "test"
score_threshold = 0.26
nms_iou = 0.7
nms_max_det = 300

images_dir = get_images_dir(split)
image_files = sorted(
    list(images_dir.glob("*.jpg")) + list(images_dir.glob("*.png")),
    key=lambda p: int(p.stem),
)

print("\n" + "="*70)
print("Linear Probe - INFERENCE SPEED MEASUREMENT")
print("="*70)
print(f"Split: {split}")
print(f"Number of images: {len(image_files)}")
print(f"Images directory: {images_dir}")
print(f"Score threshold: {score_threshold}")
print(f"NMS: iou={nms_iou}, max_det={nms_max_det}")
print("NOTE: Timing includes I/O + model inference + post-processing\n")

# Initialize model
model = Sam3ImageModel()

# Synchronize GPU if available
if torch.cuda.is_available():
    torch.cuda.synchronize()

# Import PIL before the clock starts so the import is not part of the measurement
from PIL import Image

# Start timing (I/O + model inference + post-processing)
print("Starting inference timing...")
t_start = time()

for img_path in image_files:
    # Get image dimensions (quick read just for metadata)
    with Image.open(img_path) as img:
        width, height = img.size
    
    all_boxes = []
    
    # Query SAM 3 for each class (this includes image loading)
    for class_id, prompt in CLASS_PROMPTS.items():
        prediction = model.predict_with_text(img_path, prompt)
        
        # Convert to YOLO format
        yolo_boxes = sam3_boxes_to_yolo(
            prediction=prediction,
            class_id=class_id,
            image_width=width,
            image_height=height,
            score_threshold=score_threshold,
        )
        all_boxes.extend(yolo_boxes)
    
    # Apply NMS
    all_boxes = nms_yolo_boxes(all_boxes, iou_threshold=nms_iou, max_det=nms_max_det)

# Synchronize GPU if available
if torch.cuda.is_available():
    torch.cuda.synchronize()

t_end = time()
total_time = t_end - t_start

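# Calculate speed per frame (fail clearly if the split contains no images)
num_images = len(image_files)
if num_images == 0:
    raise RuntimeError(f"No images found in {images_dir}")
speed_per_frame_s = total_time / num_images
speed_per_frame_ms = speed_per_frame_s * 1000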

print(f"\nInference completed.")
print(f"Total time: {total_time:.2f} s")
print(f"Number of images: {num_images}")
print(f"Speed per frame: {speed_per_frame_ms:.2f} ms/frame ({speed_per_frame_s:.4f} s/frame)")

# Save speed metrics to file
metrics_file = Path("results/sam3_linear_probe_test_metrics.txt")
with open(metrics_file, "a", encoding="utf-8") as f:
    f.write("\n" + "="*70 + "\n")
    f.write("Linear Probe - INFERENCE SPEED MEASUREMENT\n")
    f.write("="*70 + "\n")
    f.write(f"Measured on: test set ({num_images} images)\n")
    f.write(f"Components: I/O + model inference + post-processing (NMS)\n")
    f.write(f"Total time: {total_time:.2f} s\n")
    f.write(f"Speed per frame: {speed_per_frame_ms:.2f} ms/frame ({speed_per_frame_s:.4f} s/frame)\n")
    f.write("="*70 + "\n")

print(f"\nSpeed metrics saved to: {metrics_file}")
print("="*70)
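
The cell above reports a single average over the whole split, which also absorbs one-off warm-up costs on the first images. A possible refinement, sketched below under the assumption that per-image timings are wanted, is to time each image separately with `time.perf_counter` after an untimed warm-up pass. `time_per_item` is a hypothetical helper; its optional `sync` argument is meant for a callable such as `torch.cuda.synchronize`.

```python
from statistics import fmean, median
from time import perf_counter


def time_per_item(items, process, warmup=1, sync=None):
    """Time process(item) for each item; return durations in seconds.

    warmup: number of leading items processed once, untimed, before measuring.
    sync:   optional zero-argument callable invoked before each clock read,
            so that queued GPU work is included in the measurement.
    """
    items = list(items)
    for item in items[:warmup]:
        process(item)  # untimed warm-up run
    durations = []
    for item in items:
        if sync is not None:
            sync()
        t0 = perf_counter()
        process(item)
        if sync is not None:
            sync()
        durations.append(perf_counter() - t0)
    return durations


# Stand-in workload; in the notebook this would wrap the per-image loop body
durations = time_per_item(range(20), lambda _: sum(range(10_000)), warmup=2)
print(f"mean:   {fmean(durations) * 1000:.3f} ms/frame")
print(f"median: {median(durations) * 1000:.3f} ms/frame")
```

Reporting the median alongside the mean makes the figure less sensitive to occasional slow frames.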