# Task 4: Depth Estimation Model Evaluation

This notebook evaluates two state-of-the-art depth estimation models:
1. **Depth Anything 3** (ByteDance)
2. **UniDepthV2** (ETH Zurich)

**Goal**: Test absolute depth estimation (in meters) on the strawberry dataset.

**Environment**: Kaggle with GPU support (carefully managed dependencies).

> ⚠️ **IMPORTANT: GPU REQUIRED**
> These models require a GPU to run efficiently. On Kaggle, please ensure you have enabled **GPU T4 x2** or **P100** in the Accelerator settings.
> 
> **Local CPU Testing**: Not supported for these models due to heavy dependencies (xformers) and high VRAM usage. The notebook will gracefully skip models if they cannot be loaded.

## 1. Environment Setup and Dependency Management

In [None]:
# Check system info
import sys
import torch
# Interpreter / framework summary, one item per line.
print(
    f"Python version: {sys.version}",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    # Only device 0 is reported, even when several GPUs are attached.
    for gpu_line in (
        f"CUDA version: {torch.version.cuda}",
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
    ):
        print(gpu_line)

### Install Common Dependencies

In [None]:
# Install common dependencies first
!pip install -q opencv-python-headless matplotlib seaborn scikit-learn
!pip install -q pillow numpy pandas tqdm einops timm
!pip install -q huggingface_hub

## 2. Download Dataset

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image
from tqdm.auto import tqdm
import cv2

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (15, 10)

# Clone dataset repository
REPO_URL = "https://github.com/SergKurchev/strawberry_synthetic_dataset.git"
DATASET_DIR = "/kaggle/working/strawberry_dataset"

if not os.path.exists(DATASET_DIR):
    print("Cloning dataset repository...")
    !git clone {REPO_URL} {DATASET_DIR}
else:
    print("Dataset already exists")

# Load metadata
with open(os.path.join(DATASET_DIR, "depth_metadata.json"), 'r') as f:
    depth_metadata = json.load(f)

print(f"\nDataset loaded: {len(depth_metadata)} images with depth ground truth")

In [None]:
# Select test images: the first 20 metadata keys in sorted order
image_names_sorted = sorted(depth_metadata.keys())
test_images = image_names_sorted[:20]
print(f"Selected {len(test_images)} test images")

def load_depth_png(depth_path):
    """Read a depth PNG and return the depth map in meters as float32.

    Images with at least two channels are decoded as a 16-bit millimeter
    value whose high byte is channel 0 (R) and low byte is channel 1 (G).
    Any other image is taken to hold millimeters directly.
    """
    pixels = np.array(Image.open(depth_path))

    packed_in_channels = pixels.ndim == 3 and pixels.shape[2] >= 2
    if not packed_in_channels:
        # Single channel - assume already in mm
        return pixels.astype(np.float32) / 1000.0

    # Widen to uint16 first so the shift below does not overflow 8-bit data
    msb = pixels[:, :, 0].astype(np.uint16)
    lsb = pixels[:, :, 1].astype(np.uint16)
    millimeters = (msb << 8) | lsb

    return millimeters.astype(np.float32) / 1000.0  # mm to m

# Visualize sample images with ground truth depth
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

# Consecutive axes form (RGB, depth) pairs: (0, 1), (2, 3), ...
for idx, img_name in enumerate(test_images[:4]):
    rgb_ax, depth_ax = axes[2 * idx], axes[2 * idx + 1]

    # OpenCV reads BGR; convert for matplotlib
    bgr = cv2.imread(os.path.join(DATASET_DIR, "images", img_name))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Depth file shares the image's file name
    depth_meters = load_depth_png(os.path.join(DATASET_DIR, "depth", img_name))

    rgb_ax.imshow(rgb)
    rgb_ax.set_title(f"RGB: {img_name}", fontsize=10)
    rgb_ax.axis('off')

    # Fixed 0-3 color range so the panels are comparable
    depth_image = depth_ax.imshow(depth_meters, cmap='turbo', vmin=0, vmax=3)
    depth_ax.set_title(f"Ground Truth Depth (m)", fontsize=10)
    depth_ax.axis('off')
    plt.colorbar(depth_image, ax=depth_ax, fraction=0.046)

plt.suptitle('Sample Images with Ground Truth Depth', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('/kaggle/working/dataset_samples.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Model 1: Depth Anything 3

### Installation Strategy for Kaggle

In [None]:
# Install Depth Anything 3 dependencies
print("Installing Depth Anything 3...")

# Install xformers (required for DA3)
!pip install -q xformers

# Clone repository
DA3_DIR = "/kaggle/working/Depth-Anything-3"
if not os.path.exists(DA3_DIR):
    !git clone https://github.com/ByteDance-Seed/Depth-Anything-3.git {DA3_DIR}

# Install package
import sys
sys.path.insert(0, DA3_DIR)

# Install DA3 in editable mode
!cd {DA3_DIR} && pip install -q -e .

print("Depth Anything 3 installed successfully!")

In [None]:
# Load Depth Anything 3 model

# Resolve the device before the try block so `device` is defined for the rest of
# the notebook even when the DA3 import below fails.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

try:
    from depth_anything_3.api import DepthAnything3

    # NOTE(review): the checkpoint name denotes the nested giant/large model, not
    # a smaller variant; confirm it fits in the selected Kaggle GPU's memory.
    print("Loading Depth Anything 3 model...")
    da3_model = DepthAnything3.from_pretrained("depth-anything/DA3NESTED-GIANT-LARGE")
    da3_model = da3_model.to(device=device)
    da3_model.eval()

    print(f"✓ Depth Anything 3 loaded on {device}")
    DA3_AVAILABLE = True

except Exception as e:
    # Deliberately broad: any import/download/out-of-memory failure should only
    # disable this model, not stop the notebook.
    print(f"⚠ Failed to load Depth Anything 3: {e}")
    print("Continuing with UniDepthV2 only...")
    DA3_AVAILABLE = False

### Test Depth Anything 3

In [None]:
# Start from an empty result map so later cells can query it either way
da3_results = {}

if not DA3_AVAILABLE:
    print("Skipping Depth Anything 3 (not available)")
else:
    print("Running Depth Anything 3 inference...")
    for img_name in tqdm(test_images):
        img_path = os.path.join(DATASET_DIR, "images", img_name)

        try:
            # One path per call; keep the first (only) depth map of the result
            da3_results[img_name] = da3_model.inference([img_path]).depth[0]
        except Exception as e:
            # Best effort: report the failing image and move on to the next one
            print(f"Error processing {img_name}: {e}")

    print(f"\n✓ Depth Anything 3 processed {len(da3_results)} images")

## 4. Model 2: UniDepthV2

### Installation Strategy for Kaggle

In [None]:
# Install UniDepthV2 dependencies
print("Installing UniDepthV2...")

# Clone repository
UNIDEPTH_DIR = "/kaggle/working/UniDepth"
if not os.path.exists(UNIDEPTH_DIR):
    !git clone https://github.com/lpiccinelli-eth/UniDepth.git {UNIDEPTH_DIR}

# Install package
sys.path.insert(0, UNIDEPTH_DIR)

# Install dependencies from requirements
!cd {UNIDEPTH_DIR} && pip install -q -e .

print("UniDepthV2 installed successfully!")

In [None]:
# Load UniDepthV2 model

# Resolve the device up front: the torch.hub fallback below needs it even when
# the package import inside the first try block fails.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

try:
    from unidepth.models import UniDepthV2
    import torch.nn.functional as F

    # Load model
    print("Loading UniDepthV2 model...")
    unidepth_model = UniDepthV2.from_pretrained("lpiccinelli/unidepth-v2-vitl14")
    unidepth_model = unidepth_model.to(device)
    unidepth_model.eval()

    print(f"✓ UniDepthV2 loaded on {device}")
    UNIDEPTH_AVAILABLE = True

except Exception as e:
    print(f"⚠ Failed to load UniDepthV2: {e}")
    print("Trying alternative loading method...")

    try:
        # Alternative: use torch.hub
        # NOTE(review): confirm that "UniDepthV2" is an entry point exposed by
        # the repository's hubconf.py.
        unidepth_model = torch.hub.load("lpiccinelli-eth/UniDepth", "UniDepthV2", trust_repo=True)
        unidepth_model = unidepth_model.to(device)
        unidepth_model.eval()
        print(f"✓ UniDepthV2 loaded via torch.hub on {device}")
        UNIDEPTH_AVAILABLE = True
    except Exception as e2:
        # Both loading routes failed: disable the model instead of stopping.
        print(f"⚠ Failed to load UniDepthV2: {e2}")
        UNIDEPTH_AVAILABLE = False

### Test UniDepthV2

In [None]:
# Always define the result map so later cells can query it either way
unidepth_results = {}

if UNIDEPTH_AVAILABLE:
    print("Running UniDepthV2 inference...")
    for img_name in tqdm(test_images):
        img_path = os.path.join(DATASET_DIR, "images", img_name)

        try:
            # infer() is given a (C, H, W) RGB tensor, as in the UniDepth README;
            # a PIL image is not a tensor, so every image would land in the
            # except branch below.
            with Image.open(img_path) as pil_image:
                rgb = torch.from_numpy(np.array(pil_image.convert('RGB'))).permute(2, 0, 1)

            # Run inference
            with torch.no_grad():
                predictions = unidepth_model.infer(rgb)

            # Extract depth map (in meters) and drop the singleton dimensions
            depth_pred = predictions['depth'].cpu().numpy().squeeze()

            # Store result
            unidepth_results[img_name] = depth_pred

        except Exception as e:
            # Best effort: report the failing image and move on to the next one
            print(f"Error processing {img_name}: {e}")
            continue

    print(f"\n✓ UniDepthV2 processed {len(unidepth_results)} images")
else:
    print("Skipping UniDepthV2 (not available)")

## 5. Evaluation Metrics

In [None]:
# Define evaluation metrics
def compute_depth_metrics(pred, gt, mask=None, align_scale=True):
    """
    Compute depth-estimation error metrics for one image.

    Args:
        pred: predicted depth map (meters), same shape as ``gt``.
        gt: ground-truth depth map (meters).
        mask: optional boolean array selecting the pixels to evaluate.
            Defaults to ``0 < gt < 10``.
        align_scale: when True (the default, and the original behaviour) the
            prediction is multiplied by ``median(gt) / median(pred)`` before
            the errors are computed, which makes every metric insensitive to
            a global scale error. Pass False to score the prediction's
            absolute (metric) scale as-is.

    Returns:
        dict of Python floats with keys ``abs_rel``, ``sq_rel``, ``rmse``,
        ``rmse_log``, ``a1``, ``a2``, ``a3`` and ``scale`` (the factor applied
        to ``pred``; 1.0 when ``align_scale`` is False). Every value is NaN
        when no pixel is valid.
    """
    metric_names = ('abs_rel', 'sq_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3', 'scale')

    # float64 keeps the reductions accurate and the results JSON-serialisable
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)

    if mask is None:
        mask = (gt > 0) & (gt < 10)  # Valid depth range

    # The ratios and logarithms below are undefined for non-positive or
    # non-finite depths, so such pixels are always excluded.
    valid = (
        np.asarray(mask, dtype=bool)
        & np.isfinite(gt) & (gt > 0)
        & np.isfinite(pred) & (pred > 0)
    )

    if not valid.any():
        return {name: float('nan') for name in metric_names}

    pred = pred[valid]
    gt = gt[valid]

    # Optional median scale alignment (for relative depth models)
    scale = float(np.median(gt) / np.median(pred)) if align_scale else 1.0
    pred_scaled = pred * scale

    # Error metrics
    diff = pred_scaled - gt
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean((diff ** 2) / gt)
    rmse = np.sqrt(np.mean(diff ** 2))
    rmse_log = np.sqrt(np.mean((np.log(pred_scaled) - np.log(gt)) ** 2))

    # Threshold accuracy: share of pixels whose ratio to the truth is below 1.25^k
    thresh = np.maximum(gt / pred_scaled, pred_scaled / gt)
    a1 = (thresh < 1.25).mean()
    a2 = (thresh < 1.25 ** 2).mean()
    a3 = (thresh < 1.25 ** 3).mean()

    return {
        'abs_rel': float(abs_rel),
        'sq_rel': float(sq_rel),
        'rmse': float(rmse),
        'rmse_log': float(rmse_log),
        'a1': float(a1),
        'a2': float(a2),
        'a3': float(a3),
        'scale': scale
    }

print("Evaluation metrics defined")

## 6. Compare Models

In [None]:
# Per-image metric dicts, one list per model
da3_metrics_list = []
unidepth_metrics_list = []

print("Evaluating models...\n")

# Each model's predictions paired with the list that collects its metrics
model_outputs = (
    (da3_results, da3_metrics_list),
    (unidepth_results, unidepth_metrics_list),
)

for img_name in test_images:
    # Ground truth shares the image's file name under depth/
    gt_depth = load_depth_png(os.path.join(DATASET_DIR, "depth", img_name))

    for predictions_by_image, metrics_sink in model_outputs:
        if img_name not in predictions_by_image:
            continue

        pred_depth = predictions_by_image[img_name]
        if pred_depth.shape != gt_depth.shape:
            # cv2.resize takes (width, height)
            pred_depth = cv2.resize(pred_depth, (gt_depth.shape[1], gt_depth.shape[0]))

        metrics_sink.append(compute_depth_metrics(pred_depth, gt_depth))

# Aggregate metrics
def aggregate_metrics(metrics_list):
    """Average each metric over a list of per-image metric dicts.

    The ``scale`` entry is left out of the result. The metric names are taken
    from the first dict. Returns None for an empty list. Averages are returned
    as built-in floats so the result can be passed to ``json.dump`` directly
    (``np.mean`` over float32 scalars would otherwise yield ``np.float32``).
    """
    if not metrics_list:
        return None

    return {
        key: float(np.mean([m[key] for m in metrics_list]))
        for key in metrics_list[0]
        if key != 'scale'
    }

da3_avg_metrics = aggregate_metrics(da3_metrics_list)
unidepth_avg_metrics = aggregate_metrics(unidepth_metrics_list)

# Same report layout for both models: a header, then one metric per line
for header, averaged in (
    ("\n=== Depth Anything 3 Results ===", da3_avg_metrics),
    ("\n=== UniDepthV2 Results ===", unidepth_avg_metrics),
):
    print(header)
    if not averaged:
        print("Not available")
        continue
    for key, value in averaged.items():
        print(f"{key}: {value:.4f}")

In [None]:
# Bar chart per metric; drawn only when both models produced averages
if da3_avg_metrics and unidepth_avg_metrics:
    metrics_names = ['abs_rel', 'rmse', 'rmse_log', 'a1', 'a2', 'a3']
    model_labels = ['DA3', 'UniDepthV2']
    bar_colors = ['#FF6B6B', '#4ECDC4']

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    axes = axes.flatten()

    # One panel per metric, in the order listed above
    for ax, metric in zip(axes, metrics_names):
        heights = [da3_avg_metrics[metric], unidepth_avg_metrics[metric]]
        ax.bar(model_labels, heights, color=bar_colors)
        ax.set_title(metric.upper(), fontweight='bold')
        ax.set_ylabel('Value')
        ax.grid(True, alpha=0.3)

    plt.suptitle('Model Comparison on Strawberry Dataset', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('/kaggle/working/model_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

## 7. Visual Comparison on Sample Images

In [None]:
# Qualitative grid: one row per sample, columns RGB | ground truth | DA3 | UniDepthV2
sample_images = test_images[:4]

fig, axes = plt.subplots(4, 4, figsize=(20, 20))

# (column index, panel title, predictions keyed by image name)
prediction_columns = (
    (2, 'Depth Anything 3', da3_results),
    (3, 'UniDepthV2', unidepth_results),
)

for row, img_name in enumerate(sample_images):
    # OpenCV reads BGR; convert for matplotlib
    bgr = cv2.imread(os.path.join(DATASET_DIR, "images", img_name))
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    gt_depth = load_depth_png(os.path.join(DATASET_DIR, "depth", img_name))

    rgb_ax, gt_ax = axes[row, 0], axes[row, 1]

    rgb_ax.imshow(rgb)
    rgb_ax.set_title('RGB', fontweight='bold')
    rgb_ax.axis('off')

    gt_image = gt_ax.imshow(gt_depth, cmap='turbo', vmin=0, vmax=3)
    gt_ax.set_title('Ground Truth', fontweight='bold')
    gt_ax.axis('off')
    plt.colorbar(gt_image, ax=gt_ax, fraction=0.046)

    for col, title, results in prediction_columns:
        ax = axes[row, col]
        if img_name in results:
            pred = results[img_name]
            if pred.shape != gt_depth.shape:
                # cv2.resize takes (width, height)
                pred = cv2.resize(pred, (gt_depth.shape[1], gt_depth.shape[0]))
            pred_image = ax.imshow(pred, cmap='turbo', vmin=0, vmax=3)
            ax.set_title(title, fontweight='bold')
            ax.axis('off')
            plt.colorbar(pred_image, ax=ax, fraction=0.046)
        else:
            # No prediction stored for this image
            ax.text(0.5, 0.5, 'N/A', ha='center', va='center', fontsize=20)
            ax.axis('off')

plt.suptitle('Depth Estimation Comparison', fontsize=18, fontweight='bold')
plt.tight_layout()
plt.savefig('/kaggle/working/visual_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Save Results

In [None]:
# Save comprehensive summary
summary = {
    "dataset": {
        "name": "Strawberry Synthetic Dataset",
        "test_images": len(test_images),
        "depth_range_meters": "0.3 - 2.5"
    },
    "models": {
        "depth_anything_3": {
            "available": DA3_AVAILABLE,
            "metrics": da3_avg_metrics if da3_avg_metrics else "N/A"
        },
        "unidepth_v2": {
            "available": UNIDEPTH_AVAILABLE,
            "metrics": unidepth_avg_metrics if unidepth_avg_metrics else "N/A"
        }
    },
    "notes": [
        "Metrics computed on absolute depth (meters)",
        "Scale alignment applied using median scaling",
        "Valid depth range: 0-10 meters"
    ]
}

# Serialise once, before the file is opened, so a serialisation error cannot
# leave a truncated JSON file behind. `default=float` converts NumPy scalar
# metrics (e.g. np.float32), which the json module cannot encode on its own.
summary_json = json.dumps(summary, indent=2, default=float)

with open('/kaggle/working/depth_evaluation_summary.json', 'w') as f:
    f.write(summary_json)

print("\n=== Evaluation Summary ===")
print(summary_json)

print("\n✓ All results saved to /kaggle/working/")