# 🍓 Task 5: UniDepthV2 Evaluation with Intrinsics (Batched)

This notebook evaluates the **UniDepthV2** model on the Strawberry Synthetic Dataset.

**Key Features:**
1.  **Data Load**: Reuses a local copy of the dataset if one is found; otherwise downloads the split archive from GitHub Releases and reassembles it.
2.  **Precision**: Uses **`.npy`** files for high-precision Ground Truth depth.
3.  **Metric Depth**: Passes the per-image **Camera Intrinsics** to the model at inference time to improve the accuracy of the metric (metre-scale) depth estimates.
4.  **Batch Inference**: Uses `DataLoader` with batch size 32 for efficient GPU utilization.
5.  **Evaluation**: Computes 5 key depth estimation metrics.

**Metrics:**
- **Abs Rel**: Absolute Relative Error (smaller is better)
- **RMSE**: Root Mean Squared Error (smaller is better)
- **RMSE log**: Log-space RMSE (smaller is better)
- **δ1**: Threshold accuracy — fraction of valid pixels with max(pred/gt, gt/pred) < 1.25 (larger is better, max 1.0)
- **Sq Rel**: Squared Relative Error (smaller is better)

In [None]:
# 1. Install Dependencies
!git clone https://github.com/lpiccinelli-eth/UniDepth.git
!cd UniDepth && pip install -e .
!pip install timm huggingface_hub

In [None]:
import os
import sys
import json
import requests
import zipfile
import shutil
import glob
import inspect
from pathlib import Path
from io import BytesIO

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader

# Add UniDepth to path if not installed globally.
# The Kaggle clone location is appended explicitly — presumably a safety net in
# case the editable install from the install cell is not visible to this kernel.
if os.path.exists('/kaggle/working/UniDepth'):
    sys.path.append('/kaggle/working/UniDepth')

# Import the model class; if that fails, append the cwd-relative clone
# directory 'UniDepth' (created by the `git clone` in the install cell) to
# sys.path and retry. A second ImportError propagates and stops the notebook.
try:
    from unidepth.models import UniDepthV2
except ImportError:
    print("UniDepth not found in path, trying to import from local clone...")
    sys.path.append('UniDepth')
    from unidepth.models import UniDepthV2

# Run on CUDA when available, otherwise on CPU; later cells move the model and
# every batch to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 2. Dataset Setup

In [None]:
# --- Dataset Configuration ---
# Release tag on GitHub under which the dataset archive parts are published.
VERSION_TAG = "Dataset"
BASE_URL = (
    "https://github.com/SergKurchev/strawberry_synthetic_dataset"
    f"/releases/download/{VERSION_TAG}"
)
# Name of the reassembled archive; the release ships it as numbered parts
# (.001, .002, .003) that are concatenated in order.
OUTPUT_ZIP = "strawberry_dataset.zip"
FILES_TO_DOWNLOAD = [f"{OUTPUT_ZIP}.{part:03d}" for part in range(1, 4)]

def reconstruct_metadata(dataset_root):
    """Rebuild ``depth_metadata.json`` from the per-image files in ``metadata_temp/``.

    Every ``<id>_meta.json`` in ``dataset_root / "metadata_temp"`` is loaded and
    stored under the key ``"<id>.png"``. The combined mapping is written to
    ``dataset_root / "depth_metadata.json"``.

    Args:
        dataset_root: dataset root directory as a ``pathlib.Path``.

    Returns:
        True if the combined file was written. False if ``metadata_temp/`` is
        missing, no chunk could be read, or the combined file could not be
        written (e.g. the root lives on a read-only mount such as
        ``/kaggle/input``), so the caller can try another location instead of
        crashing.
    """
    print("⚠️ 'depth_metadata.json' not found. Attempting reconstruction from 'metadata_temp/'...")
    temp_dir = dataset_root / "metadata_temp"
    if not temp_dir.exists():
        print(f"❌ metadata_temp directory not found at {temp_dir}")
        return False

    suffix = "_meta.json"
    # Sorted so the reconstructed file is identical across runs (glob order is arbitrary).
    json_files = sorted(temp_dir.glob(f"*{suffix}"))
    print(f"  Found {len(json_files)} metadata chunks.")

    combined_metadata = {}
    for json_file in tqdm(json_files, desc="Reconstructing Metadata"):
        # 00001_meta.json -> 00001.png; the file content is taken as that image's
        # metadata dict. NOTE(review): assumes all images are .png, as the rest of
        # this notebook does — confirm for other datasets.
        img_name = json_file.name[: -len(suffix)] + ".png"
        try:
            with open(json_file, "r", encoding="utf-8") as f:
                combined_metadata[img_name] = json.load(f)
        except (OSError, ValueError) as e:
            # Best effort: skip unreadable / malformed chunks (JSONDecodeError and
            # UnicodeDecodeError are ValueErrors) but report each one.
            print(f"  Warning: Failed to read {json_file}: {e}")

    if not combined_metadata:
        print("❌ Failed to reconstruct any metadata.")
        return False

    target_path = dataset_root / "depth_metadata.json"
    print(f"💾 Saving reconstructed metadata to {target_path}...")
    try:
        with open(target_path, "w", encoding="utf-8") as f:
            json.dump(combined_metadata, f, indent=2)
    except OSError as e:
        print(f"❌ Could not write {target_path}: {e}")
        # Do not leave a truncated file behind: a later run would treat it as valid.
        try:
            target_path.unlink()
        except OSError:
            pass  # nothing was created (e.g. read-only mount)
        return False

    return True

# Directory names the dataset search never descends into (hidden dirs are skipped too).
_SEARCH_SKIP_DIRS = {".git", "temp_download"}

# Well-known locations checked after the recursive search of the working directory.
_STANDARD_DATASET_PATHS = (
    Path("strawberry_dataset"),
    Path("dataset/strawberry_dataset"),
    Path("/kaggle/input/last-straw-dataset/strawberry_dataset"),
    Path("/kaggle/input/strawberry_synthetic_dataset/strawberry_dataset"),
)

# Network timeout (seconds) per request for connecting / waiting for data.
_DOWNLOAD_TIMEOUT_S = 60


def _validate_root(p):
    """Return True if ``p`` has ``depth_metadata.json``, rebuilding it from ``metadata_temp/`` if needed."""
    if (p / "depth_metadata.json").exists():
        return True
    if (p / "metadata_temp").exists():
        return reconstruct_metadata(p)
    return False


def _iter_candidate_roots(start="."):
    """Yield directories under ``start`` that contain ``images/`` plus metadata (file or ``metadata_temp/``)."""
    for root, dirs, files in os.walk(start, topdown=True):
        # Prune in place: with topdown=True this stops os.walk from descending
        # into VCS/temp/hidden directories at all. A bare `continue` would only
        # skip the check while still walking the whole subtree.
        dirs[:] = [d for d in dirs if d not in _SEARCH_SKIP_DIRS and not d.startswith(".")]
        if "images" in dirs and ("depth_metadata.json" in files or "metadata_temp" in dirs):
            yield Path(root)


def _find_valid_root(start="."):
    """Return the first candidate root under ``start`` that validates, or None."""
    for p in _iter_candidate_roots(start):
        if _validate_root(p):
            return p
    return None


def _download_and_extract():
    """Stream every archive part from BASE_URL into OUTPUT_ZIP, then extract it into the cwd.

    Raises:
        RuntimeError: a part did not answer with HTTP 200.
        zipfile.BadZipFile: the reassembled archive is not a valid zip.
    """
    # Remove leftovers from an interrupted earlier run.
    shutil.rmtree("temp_download", ignore_errors=True)
    if os.path.exists(OUTPUT_ZIP):
        os.remove(OUTPUT_ZIP)

    # The parts are raw byte chunks of one zip, so concatenating them in order
    # restores the archive. Streaming straight into it avoids keeping a second
    # on-disk copy of every part.
    with open(OUTPUT_ZIP, "wb") as outfile:
        for filename in FILES_TO_DOWNLOAD:
            url = f"{BASE_URL}/{filename}"
            print(f"  Downloading {filename} from {url}...")
            # `with` releases the streamed connection even on error; the timeout
            # keeps a stalled connection from hanging the notebook forever.
            with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT_S) as r:
                if r.status_code != 200:
                    raise RuntimeError(f"Download failed for {filename}: HTTP {r.status_code}")
                n_bytes = 0
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    outfile.write(chunk)
                    n_bytes += len(chunk)
            print(f"  Downloaded {filename} ({n_bytes / 1024 / 1024:.2f} MB) and appended to zip.")

    total_size_mb = os.path.getsize(OUTPUT_ZIP) / 1024 / 1024
    print(f"📂 Extracting {OUTPUT_ZIP} ({total_size_mb:.2f} MB)...")
    try:
        with zipfile.ZipFile(OUTPUT_ZIP, "r") as zip_ref:
            zip_ref.extractall(".")
            print("  Extraction complete.")
    except zipfile.BadZipFile as e:
        print(f"❌ BadZipFile Error: {e}")
        raise

    os.remove(OUTPUT_ZIP)


def _fix_backslash_filenames(directory="."):
    """Move flat files whose names contain literal backslashes into real subdirectories.

    Zip entries stored with '\\' separators are extracted on Linux as single
    files named e.g. ``strawberry_dataset\\images\\00001.png``. Each such
    file in ``directory`` is moved to the equivalent '/'-separated path.

    Returns:
        Number of files moved.
    """
    moved = 0
    for filename in os.listdir(directory):
        if "\\" not in filename:
            continue
        src = os.path.join(directory, filename)
        dst = os.path.join(directory, filename.replace("\\", "/"))
        parent = os.path.dirname(dst)
        if parent:
            os.makedirs(parent, exist_ok=True)
        try:
            shutil.move(src, dst)
            moved += 1
        except OSError as e:  # shutil.Error is an OSError subclass
            print(f"  Failed to move {filename} -> {dst}: {e}")
    return moved


def setup_dataset():
    """Return the dataset root, searching locally first and downloading it otherwise.

    Search order:
        1. A recursive walk of the cwd.
        2. A few well-known paths, including Kaggle input mounts.
        3. Download and extract from GitHub Releases, then walk again.

    A directory is accepted as the root when it holds ``images/`` and
    ``depth_metadata.json``, or when that file can be rebuilt from
    ``metadata_temp/``.

    Returns:
        The dataset root as a ``pathlib.Path``, or None if no valid root was
        found even after downloading.

    Raises:
        RuntimeError: an archive part failed to download.
        zipfile.BadZipFile: the downloaded archive is corrupt.
    """
    print("🔍 Searching for existing dataset...")
    root = _find_valid_root(".")
    if root is not None:
        print(f"✅ Dataset found/Fixed at: {root}")
        return root

    for p in _STANDARD_DATASET_PATHS:
        if p.exists() and _validate_root(p):
            print(f"✅ Dataset found/Fixed at: {p}")
            return p

    print("⬇️ Dataset not found. Downloading from GitHub Releases...")
    _download_and_extract()

    print("🧹 Checking for backslash issues in filenames...")
    count = _fix_backslash_filenames(".")
    if count > 0:
        print(f"✅ Fixed {count} filenames with backslashes. Directory structure restored.")

    print("🔎 Locating dataset root...")
    root = _find_valid_root(".")
    if root is not None:
        print(f"✅ Dataset extracted and verified at: {root}")
    return root

# Locate (or download) the dataset once; later cells read DATASET_PATH.
DATASET_PATH = setup_dataset()
if DATASET_PATH is None:
    # setup_dataset() returns None only when even a fresh download yielded no valid root.
    raise RuntimeError("Dataset setup failed: Could not locate or reconstruct metadata")

## 3. Load Model & Prepare Dataset

In [None]:
# Load UniDepthV2 Model
VERSION = "lpiccinelli/unidepth-v2-vitl14"  # pretrained checkpoint id for from_pretrained
print(f"Loading model {VERSION} to {device}...")
model = UniDepthV2.from_pretrained(VERSION)
model = model.to(device)  # nn.Module.to moves parameters in place and returns the module
model.eval();  # evaluation mode; the trailing ';' suppresses the notebook repr output

In [None]:
# Load Metadata: maps image file name -> per-image metadata dict.
with (DATASET_PATH / "depth_metadata.json").open("r") as f:
    metadata = json.load(f)

# Deterministic evaluation order: image names sorted lexicographically.
test_images = sorted(metadata)
print(f"Found {len(test_images)} images in metadata.")

# Custom Dataset for Batching
class StrawberryEvalDataset(Dataset):
    """Yield the RGB image, 3x3 camera intrinsics and file name for each listed image.

    Args:
        dataset_path: dataset root (a ``pathlib.Path``). Images are read from
            ``dataset_path / "images"``.
        image_list: image file names; also used as keys into ``metadata``.
        metadata: mapping of image name to per-image metadata dict. Intrinsics
            are read from its ``"camera_intrinsics"`` entry (keys fx, fy, cx, cy).
        default_focal: focal length in pixels, used only when an image has no
            intrinsics. In that case the principal point is the image centre.
    """

    def __init__(self, dataset_path, image_list, metadata, default_focal=1000.0):
        self.dataset_path = dataset_path
        self.image_list = image_list
        self.metadata = metadata
        self.default_focal = default_focal

    def __len__(self):
        return len(self.image_list)

    @staticmethod
    def _intrinsics_matrix(intrinsics_data, width, height, default_focal=1000.0):
        """Build a float32 3x3 pinhole K from an fx/fy/cx/cy dict, or a centred fallback when it is empty.

        A non-empty dict missing any of fx, fy, cx or cy raises KeyError,
        rather than being silently patched with guessed values.
        """
        if intrinsics_data:
            fx = intrinsics_data["fx"]
            fy = intrinsics_data["fy"]
            cx = intrinsics_data["cx"]
            cy = intrinsics_data["cy"]
        else:
            # Fallback if missing (should not happen in this dataset). Centre
            # the principal point on the actual image instead of a fixed 512 px.
            fx = fy = default_focal
            cx, cy = width / 2.0, height / 2.0
        return torch.tensor(
            [[fx, 0, cx],
             [0, fy, cy],
             [0, 0, 1]],
            dtype=torch.float32,
        )

    def __getitem__(self, idx):
        img_name = self.image_list[idx]

        # Load RGB; the context manager closes the file handle promptly
        # (relevant with several DataLoader workers opening many files).
        img_path = self.dataset_path / "images" / img_name
        with Image.open(img_path) as img:
            rgb_np = np.array(img.convert("RGB"))
        height, width = rgb_np.shape[:2]
        # [H, W, 3] uint8 -> [3, H, W] float32, values kept in 0..255.
        rgb_tensor = torch.from_numpy(rgb_np).permute(2, 0, 1).float()

        intrinsics_data = self.metadata[img_name].get("camera_intrinsics", {})
        K = self._intrinsics_matrix(intrinsics_data, width, height, self.default_focal)

        return {
            'rgb': rgb_tensor,
            'K': K,
            'img_name': img_name
        }

# Create DataLoader
BATCH_SIZE = 32
NUM_WORKERS = 2
dataset = StrawberryEvalDataset(DATASET_PATH, test_images, metadata)
# Pinned host memory only speeds up host->GPU copies. On a CPU-only runtime it
# is useless and makes PyTorch emit a warning, so enable it only for CUDA.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep the sorted metadata order for reproducible runs
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == "cuda"),
)

print(f"DataLoader created with batch size {BATCH_SIZE}.")

# Metric Calculation Function
def compute_metrics(pred, gt, mask, min_depth=1e-3):
    """Compute standard depth-estimation metrics over the valid pixels.

    Args:
        pred: predicted depth in metres (numpy array).
        gt: ground-truth depth in metres, same shape as ``pred``.
        mask: boolean array of the same shape selecting valid pixels. The
            caller must exclude non-positive GT here.
        min_depth: predictions are clamped to at least this value (metres).
            A zero or negative prediction then gives a large but finite error
            instead of inf/NaN, which would poison the dataset-wide mean.

    Returns:
        A dict with these keys, or None if ``mask`` selects no pixels:

        * 'a1', 'a2', 'a3': fraction of pixels with max(gt/pred, pred/gt)
          below 1.25, 1.25**2 and 1.25**3.
        * 'rmse', 'rmse_log', 'abs_rel', 'sq_rel'.
    """
    gt = gt[mask]
    if gt.size == 0:
        return None
    pred = np.maximum(pred[mask], min_depth)

    # Threshold accuracy: per-pixel max ratio.
    ratio = np.maximum(gt / pred, pred / gt)
    a1 = (ratio < 1.25).mean()       # delta1
    a2 = (ratio < 1.25 ** 2).mean()  # delta2
    a3 = (ratio < 1.25 ** 3).mean()  # delta3

    # Error metrics.
    diff = gt - pred
    sq_err = diff ** 2
    rmse = np.sqrt(sq_err.mean())
    rmse_log = np.sqrt(np.mean((np.log(gt) - np.log(pred)) ** 2))
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean(sq_err / gt)

    return {
        'a1': a1,
        'a2': a2,
        'a3': a3,
        'rmse': rmse,
        'rmse_log': rmse_log,
        'abs_rel': abs_rel,
        'sq_rel': sq_rel
    }

## 4. Batch Evaluation Loop

In [None]:
# Per-image metric dicts, averaged at the end of this cell.
metrics_list = []
# Valid GT depth range (metres) for the evaluation mask.
MIN_EVAL_DEPTH = 0.0
MAX_EVAL_DEPTH = 10.0
# Bookkeeping so skipped images and an intrinsics fallback show up in the summary.
n_missing_gt = 0
n_no_valid_pixels = 0
n_batches = 0
n_batches_with_k = 0

print("Starting Batch Inference...")

# Decide ONCE, from the signature, how intrinsics reach model.infer. Trial and
# error with a blanket fallback silently drops K whenever the keyword name is
# wrong. UniDepthV2 is expected to call it `camera` — confirm via the printed
# signature.
try:
    _infer_sig = inspect.signature(model.infer)
    print("Model infer signature:", _infer_sig)
    _infer_params = set(_infer_sig.parameters)
except (TypeError, ValueError):
    _infer_params = set()

_intrinsics_kw = next((n for n in ("intrinsics", "camera", "K") if n in _infer_params), None)

_Pinhole = None
if _intrinsics_kw == "camera":
    # NOTE(review): assumes unidepth.utils.camera.Pinhole(K=...) accepts a
    # batched [B, 3, 3] tensor. If the import fails, the raw tensor is passed.
    try:
        from unidepth.utils.camera import Pinhole as _Pinhole
    except ImportError:
        _Pinhole = None


def _intrinsics_kwargs(K):
    """Return the keyword arguments that hand the [B, 3, 3] intrinsics ``K`` to ``model.infer``."""
    if _intrinsics_kw is None:
        return {}
    return {_intrinsics_kw: _Pinhole(K=K) if _Pinhole is not None else K}


def _gt_depth_path(img_name):
    """Return the ground-truth ``.npy`` path for ``img_name`` (``depth_npy/`` first, then ``depth/``), or None."""
    npy_name = str(Path(img_name).with_suffix(".npy"))
    for sub in ("depth_npy", "depth"):
        candidate = DATASET_PATH / sub / npy_name
        if candidate.exists():
            return candidate
    return None


use_intrinsics = _intrinsics_kw is not None
if use_intrinsics:
    print(f"Passing camera intrinsics to model.infer via `{_intrinsics_kw}=`.")
else:
    print("WARNING: model.infer exposes no intrinsics argument; predicting WITHOUT intrinsics.")

for batch in tqdm(dataloader):
    rgb_dev = batch['rgb'].to(device, non_blocking=True)  # [B, 3, H, W]
    K_dev = batch['K'].to(device, non_blocking=True)      # [B, 3, 3]
    names = batch['img_name']
    n_batches += 1

    # Inference
    with torch.no_grad():
        predictions = None
        if use_intrinsics:
            try:
                predictions = model.infer(rgb_dev, **_intrinsics_kwargs(K_dev))
                n_batches_with_k += 1
            except Exception as err:
                # Deliberate best-effort fallback, but loud and decided once:
                # later batches skip the failing call instead of retrying it.
                use_intrinsics = False
                print(f"WARNING: inference with intrinsics failed ({err!r}); "
                      "continuing WITHOUT intrinsics - metric scale may be off.")
        if predictions is None:
            predictions = model.infer(rgb_dev)

    pred_depth_batch = predictions["depth"]  # [B, 1, H, W]

    # Process each image in batch
    for j, name in enumerate(names):
        gt_path = _gt_depth_path(name)
        if gt_path is None:
            n_missing_gt += 1
            continue
        gt = np.load(gt_path).squeeze()

        # Resize on-device with bilinear interpolation (half-pixel centres,
        # comparable to cv2.INTER_LINEAR) when resolutions differ.
        pred_t = pred_depth_batch[j].reshape(1, 1, *pred_depth_batch.shape[-2:])
        if tuple(pred_t.shape[-2:]) != gt.shape:
            pred_t = torch.nn.functional.interpolate(
                pred_t.float(), size=gt.shape[:2], mode="bilinear", align_corners=False
            )
        pred = pred_t.squeeze().float().cpu().numpy()

        # Keep these four in sync so the visualization cell (which reads
        # rgb_batch[i], gt_depth and pred_depth) always shows a single image.
        rgb_batch, i, gt_depth, pred_depth = rgb_dev, j, gt, pred

        # Mask & Compute Metrics
        mask = (gt > MIN_EVAL_DEPTH) & (gt < MAX_EVAL_DEPTH) & np.isfinite(pred)
        if not mask.any():
            n_no_valid_pixels += 1
            continue
        m = compute_metrics(pred, gt, mask)
        if m is not None:
            metrics_list.append(m)

# Average Metrics (per-image mean)
avg_metrics = {}
if metrics_list:
    avg_metrics = {k: np.mean([x[k] for x in metrics_list]) for k in metrics_list[0]}

    print("\n=== Evaluation Results ===")
    print(f"Images:   {len(metrics_list)} evaluated, {n_missing_gt} without GT, "
          f"{n_no_valid_pixels} without valid pixels")
    print(f"Intrinsics used in {n_batches_with_k}/{n_batches} batches")
    print(f"Abs Rel:  {avg_metrics['abs_rel']:.4f}")
    print(f"RMSE:     {avg_metrics['rmse']:.4f}")
    print(f"RMSE log: {avg_metrics['rmse_log']:.4f}")
    print(f"δ1 (Acc): {avg_metrics['a1']:.4f}")
    print(f"Sq Rel:   {avg_metrics['sq_rel']:.4f}")
else:
    print("⚠️ No metrics calculated.")
    print(f"  ({n_missing_gt} images without GT, {n_no_valid_pixels} without valid pixels)")

## 5. Visualization

In [None]:
# Visualize the last evaluated image. This reads the globals left behind by the
# evaluation loop: `rgb_batch` (device batch), `i` (index into it), and
# `gt_depth` / `pred_depth` (2-D depth maps in metres).
# NOTE(review): if the loop skipped trailing images (e.g. missing GT), confirm
# these four still refer to the same image before trusting the comparison.
DEPTH_VMAX_M = 3.0  # shared colour-scale upper bound (metres) for GT and prediction

if all(name in globals() for name in ("rgb_batch", "i", "gt_depth", "pred_depth")):
    # [3, H, W] float in 0..255 -> [H, W, 3] float in [0, 1], which imshow takes
    # directly; after the /255 a uint8 conversion branch could never trigger.
    last_rgb_tensor = rgb_batch[i].cpu().permute(1, 2, 0).numpy() / 255.0

    depth_style = {"cmap": "magma", "vmin": 0, "vmax": DEPTH_VMAX_M}
    panels = (
        (last_rgb_tensor, "RGB Input", None),
        (gt_depth, "Ground Truth (.npy)", depth_style),
        (pred_depth, "UniDepthV2 Prediction", depth_style),
    )

    fig, axes = plt.subplots(1, 3, figsize=(20, 6))
    for ax, (image, title, style) in zip(axes, panels):
        shown = ax.imshow(image, **(style or {}))
        ax.set_title(title)
        ax.axis("off")
        if style is not None:
            # Depth panels share one fixed range so they are directly comparable.
            fig.colorbar(shown, ax=ax, label='Meters')

    plt.show()