## 1. Set Up Environment

On Colab, clone the repository, install dependencies, and extract the dataset. When running locally, just add the project root to `sys.path`.

In [None]:
# Run this cell only on Google Colab
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    %cd /content
    
    # Clone the repository
    !git clone https://github.com/AntoFratta/DVARF.git
    %cd /content/DVARF
    
    # Install dependencies (remove Windows-specific packages)
    !grep -v "triton-windows" requirements.txt > requirements_colab.txt
    !pip install -q -r requirements_colab.txt
    
    # Extract data
    !apt-get update -y > /dev/null
    !apt-get install -y unrar > /dev/null
    !unrar x data.rar ./ > /dev/null
    
    print("✅ Colab setup complete!")
else:
    # Local development: just add project root to path
    from pathlib import Path
    project_root = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
    if str(project_root) not in sys.path:
        sys.path.insert(0, str(project_root))
    print(f"✅ Local setup complete! Project root: {project_root}")

## 2. Import Libraries and Set Up SAM3

In [None]:
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

# Recompute the environment flag instead of relying on the setup cell.
# This cell then also works on a fresh kernel where the setup cell was skipped.
IN_COLAB = 'google.colab' in sys.modules

# Resolve the project root (same rule as the setup cell) so `src` is importable
if IN_COLAB:
    project_root = Path("/content/DVARF")
else:
    project_root = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()

if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

# Import project modules
from src.sam3_wrapper import Sam3ImageModel
from src.prompts import CLASS_PROMPTS
from src.config import get_images_dir

print(f"Project root: {project_root}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("\nClass prompts:")
for class_id, prompt in CLASS_PROMPTS.items():
    print(f"  Class {class_id}: {prompt}")

## 3. Load SAM3 Model

Initialize the SAM3 wrapper. This will download the model weights if needed.

In [None]:
# Switch on the wrapper's debug output so the internal SAM3 output structure
# gets printed during the tests below.
# NOTE(review): if src.sam3_wrapper reads SAM3_DEBUG at import time rather
# than at construction/prediction time, it must be set before the import
# cell to take effect — confirm.
os.environ.update({"SAM3_DEBUG": "1"})

print("Loading SAM3 model...")
model = Sam3ImageModel()
print("✅ SAM3 model loaded successfully!")

## 4. Test on Sample Images

Run SAM3 on up to 3 images from the test split to verify that feature extraction works.

In [None]:
# Pick a small, deterministic sample of test-split images for the checks below.
N_TEST_IMAGES = 3                             # how many images the test cells use
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}    # accepted image file extensions

# Test images directory (filled by the data extraction step of the setup cell)
test_images_dir = project_root / "data" / "raw" / "images" / "test"

# Match suffixes case-insensitively so e.g. `.JPG` or `.png` files are not
# silently skipped. Sorting keeps the selection stable across runs.
test_image_files = sorted(
    p for p in test_images_dir.glob("*")
    if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
)[:N_TEST_IMAGES]

if not test_images_dir.is_dir():
    print(f"❌ Test images directory not found: {test_images_dir}")
elif not test_image_files:
    print("❌ No test images found! Check that data was extracted correctly.")
else:
    print(f"Found {len(test_image_files)} test images:")
    for img_path in test_image_files:
        print(f"  - {img_path.name}")

### Test Single Image with Single Prompt

In [None]:
# Test on first image with first class prompt. The cell checks that every
# per-detection output has one row per box and that features are FEATURE_DIM wide.
FEATURE_DIM = 256  # expected length of each SAM3 detection feature vector

if test_image_files:
    test_img = test_image_files[0]
    test_prompt = CLASS_PROMPTS[0]

    print(f"\n{'='*60}")
    print(f"Testing: {test_img.name}")
    print(f"Prompt: '{test_prompt}'")
    print(f"{'='*60}\n")

    # Run SAM3 prediction
    prediction = model.predict_with_text(test_img, test_prompt)

    print("\n" + "="*60)
    print("PREDICTION RESULTS")
    print("="*60)

    # Check prediction outputs
    print(f"\n✅ Boxes shape: {prediction.boxes.shape}")
    print(f"✅ Scores shape: {prediction.scores.shape}")
    print(f"✅ Masks shape: {prediction.masks.shape}")
    print(f"✅ Features shape: {prediction.features.shape}")

    num_detections = prediction.boxes.shape[0]
    print(f"\nNumber of detections: {num_detections}")

    # Scores and features are per detection, so their leading dimension must
    # equal the number of boxes. Checked even when there are no detections.
    assert prediction.scores.shape[0] == num_detections, (
        f"Expected {num_detections} scores, got {prediction.scores.shape[0]}"
    )
    assert prediction.features.shape[0] == num_detections, (
        f"Expected {num_detections} feature rows, got {prediction.features.shape[0]}"
    )

    if num_detections > 0:
        # `.float()` first: numpy has no bfloat16, so `.numpy()` would raise
        # if the wrapper returned reduced-precision tensors.
        print(f"\nScores: {prediction.scores.float().cpu().numpy()}")
        print(f"\nFeatures (first detection, first 10 dims): {prediction.features[0, :10].float().cpu().numpy()}")
        print(f"Feature dtype: {prediction.features.dtype}")
        print(f"Feature device: {prediction.features.device}")

        # Verify feature dimensions are correct (256-d)
        assert prediction.features.shape[1] == FEATURE_DIM, f"Expected {FEATURE_DIM}-d features, got {prediction.features.shape[1]}"
        print(f"\n✅ Feature dimensions are correct ({FEATURE_DIM}-d)")
    else:
        print("\n⚠️ No detections found for this image/prompt combination")
else:
    print("❌ No test images available")

### Test All Classes on One Image

In [None]:
# Run every class prompt against the first test image and summarise the
# detections and feature statistics for each class.
if test_image_files:
    test_img = test_image_files[0]
    banner = "=" * 60

    print(f"\n{banner}")
    print(f"Testing all classes on: {test_img.name}")
    print(f"{banner}\n")

    for cls_id, cls_prompt in CLASS_PROMPTS.items():
        print(f"\nClass {cls_id}: '{cls_prompt}'")
        print("-" * 50)

        prediction = model.predict_with_text(test_img, cls_prompt)
        feats = prediction.features
        n_found = prediction.boxes.shape[0]
        print(f"  Detections: {n_found}")

        # Guard clause: nothing to summarise for this class
        if n_found == 0:
            print(f"  Features shape: {feats.shape} (empty)")
            continue

        print(f"  Scores: {prediction.scores.cpu().numpy()}")
        print(f"  Features shape: {feats.shape}")
        print(f"  Feature mean: {feats.mean().item():.4f}")
        print(f"  Feature std: {feats.std().item():.4f}")

    print("\n" + banner)
    print("✅ All class predictions completed successfully!")
    print(banner)

## 5. Test Complete Pipeline on Multiple Images

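Simulate the per-image `run_sam3_on_split` workflow on a few images. For each image, run every class prompt, convert the detections to YOLO boxes, and build the 257-d feature vectors (256-d features + score).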

In [None]:
from src.yolo_export import sam3_boxes_to_yolo, YoloBox

SCORE_THRESHOLD = 0.26      # minimum detection score kept by sam3_boxes_to_yolo
EXPECTED_FEATURE_LEN = 257  # 256-d SAM3 features + 1 detection score


def _check_image_pipeline(img_path):
    """Run all class prompts on one image and validate the exported features.

    Returns True when every exported YOLO box carries features. Returns False
    (after printing an error) when some boxes are missing features.
    Raises AssertionError if the stacked feature array is not
    (num_boxes, EXPECTED_FEATURE_LEN).
    """
    # Only the size is needed. Reading it from the lazily-opened image avoids
    # decoding every pixel, and the context manager closes the file handle.
    with Image.open(img_path) as image:
        width, height = image.size

    all_boxes = []

    # Loop over all classes
    for class_id, prompt in CLASS_PROMPTS.items():
        prediction = model.predict_with_text(img_path, prompt)

        # Convert to YOLO format
        yolo_boxes = sam3_boxes_to_yolo(
            prediction=prediction,
            class_id=class_id,
            image_width=width,
            image_height=height,
            score_threshold=SCORE_THRESHOLD,
        )
        all_boxes.extend(yolo_boxes)
        print(f"  Class {class_id}: {len(yolo_boxes)} boxes")

    print(f"\n  Total boxes: {len(all_boxes)}")

    # Check that all boxes have features
    boxes_with_features = sum(1 for box in all_boxes if box.features is not None)
    boxes_without_features = len(all_boxes) - boxes_with_features
    print(f"  Boxes with features: {boxes_with_features}")
    print(f"  Boxes without features: {boxes_without_features}")

    if boxes_without_features > 0:
        print("\n  ❌ ERROR: Some boxes are missing features!")
        return False

    print("\n  ✅ All boxes have features")
    if not all_boxes:
        return True

    # Build 257-d feature vectors (256-d features + score); a missing score counts as 0.0
    all_features = [
        np.concatenate(
            [box.features, [box.score if box.score is not None else 0.0]]
        ).astype(np.float32)
        for box in all_boxes
    ]
    features_arr = np.array(all_features, dtype=np.float16)
    print(f"  Final features array shape: {features_arr.shape}")
    print(f"  Expected shape: ({len(all_boxes)}, {EXPECTED_FEATURE_LEN})")

    assert features_arr.shape == (len(all_boxes), EXPECTED_FEATURE_LEN), "Feature shape mismatch!"
    print("  ✅ Feature array shape is correct!")
    return True


# Test on up to 3 images
test_limit = min(3, len(test_image_files))

print(f"\n{'='*60}")
print(f"Testing complete pipeline on {test_limit} images")
print(f"{'='*60}\n")

failed_images = []
for idx, img_path in enumerate(test_image_files[:test_limit], 1):
    print(f"\n[{idx}/{test_limit}] Processing {img_path.name}")
    print("-" * 60)
    if not _check_image_pipeline(img_path):
        failed_images.append(img_path.name)

# Report the real outcome. Previously, success was printed even after a per-image error.
print("\n" + "="*60)
if test_limit == 0:
    print("⚠️ No test images available — pipeline test skipped")
elif failed_images:
    print(f"❌ Pipeline test FAILED for: {', '.join(failed_images)}")
else:
    print("✅ Pipeline test completed successfully!")
print("="*60)

assert not failed_images, f"Boxes missing features in: {failed_images}"

## 6. Visualize Sample Detection

Show detections with bounding boxes on an image.

In [None]:
import matplotlib.patches as patches

if test_image_files:
    # Draw the first class's detections on the first test image. Using a
    # single prompt keeps the figure readable.
    test_img = test_image_files[0]
    image = Image.open(test_img).convert("RGB")
    width, height = image.size

    prompt = CLASS_PROMPTS[0]
    prediction = model.predict_with_text(test_img, prompt)

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    ax.imshow(image)
    ax.set_title(f"SAM3 Detections: {test_img.name}\nPrompt: '{prompt}'")
    ax.axis('off')

    n_boxes = prediction.boxes.shape[0]
    if n_boxes > 0:
        # Boxes are unpacked as corner coordinates (x1, y1, x2, y2)
        box_rows = prediction.boxes.cpu().numpy()
        box_scores = prediction.scores.cpu().numpy()

        for (x1, y1, x2, y2), score in zip(box_rows, box_scores):
            # Outline rectangle anchored at the top-left corner
            ax.add_patch(
                patches.Rectangle(
                    (x1, y1), x2 - x1, y2 - y1,
                    linewidth=2, edgecolor='red', facecolor='none',
                )
            )
            # Score label just above the box
            ax.text(
                x1, y1 - 5, f'{score:.2f}',
                color='white', fontsize=10,
                bbox=dict(facecolor='red', alpha=0.7, edgecolor='none', pad=2),
            )

    plt.tight_layout()
    plt.show()

    print(f"\nDetections: {n_boxes}")
    print(f"Features extracted: {prediction.features.shape[0]} x {prediction.features.shape[1]}")

## 7. Summary

If all cells above ran successfully, the SAM3 feature extraction is working correctly!

**Expected results:**
- ✅ SAM3 model loads without errors
- ✅ Predictions return boxes, scores, masks, and **features**
- ✅ Features have shape (N, 256) where N = number of detections
- ✅ All boxes have associated features (no missing values)
- ✅ Final feature array has shape (N, 257) after appending each detection's score to its 256-d feature vector

**Next steps:**
1. Run `run_sam3_on_split.py` on the train split to generate features for all images
2. Run `build_linear_probe_dataset.py` to create the training dataset
3. Run `train_linear_probe.py` to train the classifier
4. Run `apply_linear_probe_to_split.py` to apply it on the test split
5. Evaluate results with `eval_sam3_linear_probe_on_split.py`

In [None]:
# Disable debug mode by removing the variable rather than setting it to "0".
# "0" is a non-empty (truthy) string, so a check such as
# `if os.environ.get("SAM3_DEBUG"):` would leave debug output on.
# Removing it restores the unset state however the wrapper parses the flag.
# NOTE(review): if the wrapper reads the flag only once (e.g. at model
# construction), this has no effect on the already-loaded `model` — confirm.
os.environ.pop("SAM3_DEBUG", None)
print("✅ Test notebook completed!")