# 🍓 Task 6: Depth Anything V2 Evaluation

This notebook evaluates the **Depth Anything V2** (Metric Depth) model on the Strawberry Synthetic Dataset.

**Key Features:**
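1.  **Data Load**: Reuses a local copy of the dataset if one is found; otherwise downloads the multi-part ZIP archive from GitHub Releases and rebuilds `depth_metadata.json` if it is missing.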
2.  **Precision**: Uses **`.npy`** files for high-precision Ground Truth depth.
3.  **Metric Depth**: Uses the official `metric_depth` model for absolute depth estimation.
4.  **Console Inference**: Runs inference via command line script as requested.
5.  **Evaluation**: Computes 5 key depth estimation metrics.

**Metrics:**
- **Abs Rel**: Absolute Relative Error (smaller is better)
- **RMSE**: Root Mean Squared Error (smaller is better)
- **RMSE log**: Log-space RMSE (smaller is better)
- **δ1**: Threshold accuracy, i.e. the fraction of pixels with max(pred/gt, gt/pred) < 1.25 (larger is better, max 1.0)
- **Sq Rel**: Squared Relative Error (smaller is better)

In [None]:
# 1. Install Dependencies
# Clone the upstream Depth Anything V2 repository; the inference cell later in
# this notebook runs Depth-Anything-V2/metric_depth/run.py from this checkout.
!git clone https://github.com/DepthAnything/Depth-Anything-V2.git
# Install the Python packages listed in the cloned repository's requirements file.
!pip install -r Depth-Anything-V2/requirements.txt
# NOTE(review): huggingface_hub is not imported by any cell visible here —
# confirm it is still required before keeping this install.
!pip install huggingface_hub

In [None]:
import os
import sys
import json
import requests
import zipfile
import shutil
from pathlib import Path
from io import BytesIO

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm.auto import tqdm

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

## 2. Dataset Setup

In [None]:
import os
import sys
import json
import requests
import zipfile
import shutil
import glob
import inspect
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from pathlib import Path


# --- Robust Dataset Configuration ---
# Release tag of the GitHub release that hosts the dataset archive parts.
VERSION_TAG = "Dataset"
BASE_URL = f"https://github.com/SergKurchev/strawberry_synthetic_dataset/releases/download/{VERSION_TAG}"

# Name of the archive that the downloaded parts are concatenated into.
OUTPUT_ZIP = "strawberry_dataset.zip"

# The archive is published as three numbered parts: <zip>.001 ... <zip>.003.
FILES_TO_DOWNLOAD = [f"{OUTPUT_ZIP}.{part_no:03d}" for part_no in range(1, 4)]

def reconstruct_metadata(dataset_root):
    """Rebuild ``depth_metadata.json`` from the per-image chunks in ``metadata_temp/``.

    Every ``<id>_meta.json`` file found in ``dataset_root / "metadata_temp"``
    is loaded and stored under the key ``<id>.png`` in one combined
    dictionary, which is then written to
    ``dataset_root / "depth_metadata.json"``.

    Args:
        dataset_root: ``pathlib.Path`` of the dataset directory.

    Returns:
        ``True`` if at least one chunk was read and the combined file was
        written; ``False`` if ``metadata_temp`` does not exist or no chunk
        could be read.
    """
    print("⚠️ 'depth_metadata.json' not found. Attempting reconstruction from 'metadata_temp/'...")
    temp_dir = dataset_root / "metadata_temp"
    if not temp_dir.exists():
        print(f"❌ metadata_temp directory not found at {temp_dir}")
        return False

    suffix = "_meta.json"
    # Sort so the key order of the written JSON does not depend on the
    # order in which the filesystem happens to list the chunks.
    json_files = sorted(temp_dir.glob(f"*{suffix}"))
    print(f"  Found {len(json_files)} metadata chunks.")

    combined_metadata = {}
    for json_file in tqdm(json_files, desc="Reconstructing Metadata"):
        # Filename format: 00001_meta.json -> corresponds to 00001.png.
        # Only the trailing suffix is stripped (the glob guarantees it is there).
        img_name = f"{json_file.name[:-len(suffix)]}.png"
        try:
            with open(json_file, 'r', encoding="utf-8") as f:
                # The chunk content is stored as-is as that image's metadata.
                combined_metadata[img_name] = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
            # a single unreadable chunk must not abort the reconstruction.
            print(f"  Warning: Failed to read {json_file}: {e}")

    if not combined_metadata:
        print("❌ Failed to reconstruct any metadata.")
        return False

    target_path = dataset_root / "depth_metadata.json"
    print(f"💾 Saving reconstructed metadata to {target_path}...")
    with open(target_path, 'w', encoding="utf-8") as f:
        json.dump(combined_metadata, f, indent=2)

    return True

def setup_dataset():
    """Locate the Strawberry dataset locally, downloading it if necessary.

    Search order:
      1. Any directory below the current working directory that contains an
         ``images`` sub-directory plus either ``depth_metadata.json`` or a
         ``metadata_temp`` directory (``.git`` and ``temp_download`` trees
         are not entered).
      2. A fixed list of relative and Kaggle input paths.
      3. Download of the parts listed in ``FILES_TO_DOWNLOAD`` from
         ``BASE_URL``; the parts are concatenated into ``OUTPUT_ZIP``,
         extracted into the current directory, and the tree is searched again.

    A candidate only counts when ``depth_metadata.json`` exists or can be
    rebuilt with ``reconstruct_metadata``.

    Returns:
        ``pathlib.Path`` of the dataset root, or ``None`` if no usable root
        exists after extraction.

    Raises:
        RuntimeError: if a part download answers with a non-200 status.
        requests.RequestException: on connection errors or timeouts.
        zipfile.BadZipFile: if the concatenated archive is not a valid ZIP.
    """

    def validate_root(p):
        """Return True if *p* has, or can be given, a ``depth_metadata.json``."""
        if (p / "depth_metadata.json").exists():
            return True
        if (p / "metadata_temp").exists():
            # Try to fix it
            return reconstruct_metadata(p)
        return False

    def find_root(skip_dirs=()):
        """Walk the current directory and return the first valid dataset root."""
        for root, dirs, files in os.walk(".", topdown=True):
            # Prune in place: with topdown=True this stops os.walk from ever
            # descending into the skipped trees (a plain `continue` would not).
            dirs[:] = [d for d in dirs if d not in skip_dirs]
            if "images" in dirs and ("depth_metadata.json" in files or "metadata_temp" in dirs):
                candidate = Path(root)
                if validate_root(candidate):
                    return candidate
        return None

    # 1. Search for existing dataset
    print("🔍 Searching for existing dataset...")
    found = find_root(skip_dirs=(".git", "temp_download"))
    if found is not None:
        print(f"✅ Dataset found/Fixed at: {found}")
        return found

    # Check standard paths
    search_paths = [
        Path("strawberry_dataset"),
        Path("dataset/strawberry_dataset"),
        Path("/kaggle/input/last-straw-dataset/strawberry_dataset"),
        Path("/kaggle/input/strawberry_synthetic_dataset/strawberry_dataset")
    ]
    for p in search_paths:
        if p.exists() and validate_root(p):
            print(f"✅ Dataset found/Fixed at: {p}")
            return p

    print("⬇️ Dataset not found. Downloading from GitHub Releases...")

    # 2. Prepare Download Directory (always start from a clean state)
    if os.path.exists("temp_download"):
        shutil.rmtree("temp_download")
    os.makedirs("temp_download", exist_ok=True)

    if os.path.exists(OUTPUT_ZIP):
        os.remove(OUTPUT_ZIP)

    # 3. Download and Combine: the release ships a split archive, so the parts
    # are appended in order to rebuild the original ZIP.
    with open(OUTPUT_ZIP, 'wb') as outfile:
        for filename in FILES_TO_DOWNLOAD:
            file_path = Path("temp_download") / filename
            url = f"{BASE_URL}/{filename}"

            print(f"  Downloading {filename} from {url}...")
            # (connect, read) timeouts keep a stalled connection from hanging
            # the notebook forever; the context manager releases the
            # connection even when the download fails half-way.
            with requests.get(url, stream=True, timeout=(10, 60)) as r:
                if r.status_code != 200:
                    raise RuntimeError(f"Download failed for {filename}: HTTP {r.status_code}")

                with open(file_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)

            file_size_mb = file_path.stat().st_size / 1024 / 1024
            print(f"  Downloaded {filename} ({file_size_mb:.2f} MB). Appending to zip...")

            with open(file_path, 'rb') as infile:
                shutil.copyfileobj(infile, outfile)

    # 4. Extract
    total_size_mb = os.path.getsize(OUTPUT_ZIP) / 1024 / 1024
    print(f"📂 Extracting {OUTPUT_ZIP} ({total_size_mb:.2f} MB)...")

    try:
        with zipfile.ZipFile(OUTPUT_ZIP, 'r') as zip_ref:
            zip_ref.extractall(".")
            print("  Extraction complete.")
    except zipfile.BadZipFile as e:
        print(f"❌ BadZipFile Error: {e}")
        raise

    shutil.rmtree("temp_download", ignore_errors=True)
    if os.path.exists(OUTPUT_ZIP):
        os.remove(OUTPUT_ZIP)

    # --- FIX: Handle potential backslash filenames on Linux ---
    # An extracted entry whose name still contains backslashes ends up as one
    # flat file in the current directory; turn those names back into paths.
    print("🧹 Checking for backslash issues in filenames...")
    count = 0
    for filename in os.listdir("."):
        if "\\" not in filename:
            continue
        new_path = filename.replace("\\", "/")  # standardize to forward slash

        # Create parent dirs
        parent = os.path.dirname(new_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        # Move file (best effort: one failed move must not abort the setup)
        try:
            shutil.move(filename, new_path)
            count += 1
        except OSError as e:
            print(f"  Failed to move {filename} -> {new_path}: {e}")

    if count > 0:
        print(f"✅ Fixed {count} filenames with backslashes. Directory structure restored.")

    # 5. Locate and Fix
    print("🔎 Locating dataset root...")
    found = find_root()
    if found is not None:
        print(f"✅ Dataset extracted and verified at: {found}")
        return found

    return None

# Resolve the dataset location once and stop the notebook early if that fails.
DATASET_PATH = setup_dataset()
if not DATASET_PATH:
    raise RuntimeError("Dataset setup failed: Could not locate or reconstruct metadata")

# DATASET_ROOT is a second name for the same path.
DATASET_ROOT = DATASET_PATH


## 3. Prepare Checkpoints

In [None]:
# Download the specific checkpoint (Hypersim, ViT-Base)
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)
ckpt_path = CHECKPOINT_DIR / "depth_anything_v2_metric_hypersim_vitb.pth"

if not ckpt_path.exists():
    print("⬇️ Downloading Checkpoint...")
    # Using typical HuggingFace link pattern for Depth Anything V2
    url = "https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Base/resolve/main/depth_anything_v2_metric_hypersim_vitb.pth"

    # Download into a ".part" file and rename only after success, so an
    # interrupted or failed download is never mistaken for a valid checkpoint
    # by the `exists()` check on the next run.
    partial_path = ckpt_path.with_name(ckpt_path.name + ".part")

    # stream=True writes the file chunk by chunk instead of holding the whole
    # checkpoint in memory; the timeout prevents an indefinite hang.
    with requests.get(url, allow_redirects=True, stream=True, timeout=(10, 120)) as r:
        # Fail loudly on 4xx/5xx instead of saving an error page as the .pth.
        r.raise_for_status()
        with open(partial_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=1 << 20):
                f.write(chunk)

    os.replace(partial_path, ckpt_path)
    print("✅ Checkpoint downloaded.")
else:
    print("✅ Checkpoint already exists.")

## 4. Run Inference (Console Command)

In [None]:
# Define paths for inference
IMG_PATH = str(DATASET_PATH / "images")
OUT_DIR = "metric_depth_vis"

# Ensure output directory is clean
if os.path.exists(OUT_DIR):
    shutil.rmtree(OUT_DIR)
os.makedirs(OUT_DIR, exist_ok=True)

# Construct the command
dataset_images_abs = os.path.abspath(IMG_PATH)
out_dir_abs = os.path.abspath(OUT_DIR)
ckpt_abs = os.path.abspath(ckpt_path)

print("🚀 Running Inference...")
# Using --save-numpy to ensure we get raw metric depth if supported, otherwise falling back to png parsing

cmd = (
    f"python Depth-Anything-V2/metric_depth/run.py "
    f"--encoder vitb "
    f"--load-from {ckpt_abs} "
    f"--max-depth 20 " # As per user example
    f"--img-path {dataset_images_abs} "
    f"--outdir {out_dir_abs} "
    f"--save-numpy " # Adding this to ensure high precision for metrics
)

!{cmd}

## 5. Evaluation & Metrics

In [None]:
# Metric Calculation Function (Reused from UniDepth NB)
def compute_metrics(pred, gt, mask, min_depth=1e-3):
    """Compute standard depth-estimation metrics over the masked pixels.

    Args:
        pred: numpy array of predicted depth.
        gt: numpy array of ground-truth depth, same shape as ``pred``.
        mask: boolean numpy array selecting the pixels to evaluate.
        min_depth: lower bound applied to predictions before the ratio and
            log terms are computed. Predictions that are NaN, zero or
            negative are replaced by this value so they count as (large)
            errors instead of turning the metrics into inf/NaN.

    Returns:
        dict with keys ``a1``, ``a2``, ``a3`` (fraction of pixels whose
        max(gt/pred, pred/gt) is below 1.25, 1.25**2, 1.25**3), ``rmse``,
        ``rmse_log``, ``abs_rel`` and ``sq_rel``. Every value is NaN when no
        pixel with a finite, positive ground truth is selected.
    """
    pred = np.asarray(pred, dtype=np.float64)[mask]
    gt = np.asarray(gt, dtype=np.float64)[mask]

    # Ground truth must be finite and positive: it is a divisor and a log
    # argument below, so such pixels cannot be scored at all.
    usable = np.isfinite(gt) & (gt > 0)
    pred = pred[usable]
    gt = gt[usable]

    if gt.size == 0:
        return {
            'a1': np.nan,
            'a2': np.nan,
            'a3': np.nan,
            'rmse': np.nan,
            'rmse_log': np.nan,
            'abs_rel': np.nan,
            'sq_rel': np.nan
        }

    # np.maximum propagates NaN, so NaN predictions are replaced first.
    pred = np.where(np.isnan(pred), min_depth, pred)
    pred = np.maximum(pred, min_depth)

    # Threshold: per-pixel max ratio
    thresh = np.maximum(gt / pred, pred / gt)
    a1 = (thresh < 1.25).mean()       # delta1
    a2 = (thresh < 1.25 ** 2).mean()  # delta2
    a3 = (thresh < 1.25 ** 3).mean()  # delta3

    # Error metrics
    diff = gt - pred
    rmse = np.sqrt((diff ** 2).mean())
    rmse_log = np.sqrt(((np.log(gt) - np.log(pred)) ** 2).mean())
    abs_rel = np.mean(np.abs(diff) / gt)
    sq_rel = np.mean((diff ** 2) / gt)

    return {
        'a1': a1,
        'a2': a2,
        'a3': a3,
        'rmse': rmse,
        'rmse_log': rmse_log,
        'abs_rel': abs_rel,
        'sq_rel': sq_rel
    }

def plot_results(result_dir, max_gt_depth=10.0):
    """Evaluate the predictions in ``result_dir`` against the dataset ground truth.

    ``.npy`` predictions are preferred; if none exist, ``.png`` files are read
    instead. For every prediction with a matching ground-truth ``.npy`` the
    metrics from ``compute_metrics`` are computed on the pixels where
    ``0 < gt < max_gt_depth``. The per-image metrics are averaged and printed,
    and RGB / ground truth / prediction of the last evaluated image are shown.

    Args:
        result_dir: directory containing the inference outputs.
        max_gt_depth: upper bound (exclusive) on ground-truth depth for a
            pixel to be evaluated.

    Returns:
        None. Results are printed and plotted.
    """
    result_path = Path(result_dir)

    if not result_path.exists():
        print(f"❌ Result directory {result_dir} does not exist!")
        return

    file_list = sorted(result_path.glob("*.npy"))
    is_npy = len(file_list) > 0
    if not is_npy:
        print("⚠️ No .npy files found. Falling back to .png (less accurate)...")
        file_list = sorted(result_path.glob("*.png"))

    print(f"Processing {len(file_list)} output files...")

    metrics_list = []
    skipped_no_gt = 0
    skipped_unusable = 0

    last_pred = None
    last_gt = None
    last_metrics = None
    last_stem = ""

    for pred_path in tqdm(file_list):
        stem, gt_path = _find_ground_truth(pred_path.stem)
        if gt_path is None:
            skipped_no_gt += 1
            continue

        gt_depth = np.load(gt_path).squeeze()
        pred_depth = _load_prediction(pred_path, is_npy)

        # The boolean mask below is 2-D; a multi-channel prediction (e.g. a
        # colour PNG) cannot be indexed with it, so such files are skipped.
        if pred_depth is None or pred_depth.ndim != 2 or gt_depth.ndim != 2:
            skipped_unusable += 1
            continue

        # Resize if mismatch (cv2 expects the target size as (width, height))
        if pred_depth.shape != gt_depth.shape:
            pred_depth = cv2.resize(pred_depth, (gt_depth.shape[1], gt_depth.shape[0]), interpolation=cv2.INTER_LINEAR)

        mask = (gt_depth > 0) & (gt_depth < max_gt_depth)
        if not mask.any():
            continue

        metrics = compute_metrics(pred_depth, gt_depth, mask)
        metrics_list.append(metrics)

        last_pred = pred_depth
        last_gt = gt_depth
        last_metrics = metrics
        last_stem = stem

    # Make silent skips visible: if every file is skipped the only other
    # output would be "No valid metrics computed." with no hint why.
    if skipped_no_gt:
        print(f"⚠️ Skipped {skipped_no_gt} file(s) without matching ground truth.")
    if skipped_unusable:
        print(f"⚠️ Skipped {skipped_unusable} file(s) that could not be read as a 2-D depth map.")

    # Aggregate
    if not metrics_list:
        print("No valid metrics computed.")
        return

    avg = {k: np.mean([m[k] for m in metrics_list]) for k in metrics_list[0]}

    print("\n=== Evaluation Results ===")
    print(f"Abs Rel:  {avg['abs_rel']:.4f}")
    print(f"RMSE:     {avg['rmse']:.4f}")
    print(f"RMSE log: {avg['rmse_log']:.4f}")
    print(f"δ1 (Acc): {avg['a1']:.4f}")
    print(f"Sq Rel:   {avg['sq_rel']:.4f}")

    # Visualize the last evaluated sample
    plt.figure(figsize=(20, 6))

    # RGB (panel is left empty if the image file is missing)
    rgb_path = DATASET_PATH / "images" / f"{last_stem}.png"
    if rgb_path.exists():
        plt.subplot(1, 3, 1)
        plt.imshow(Image.open(rgb_path))
        plt.title("RGB Input")
        plt.axis("off")

    # GT
    plt.subplot(1, 3, 2)
    plt.imshow(last_gt, cmap='magma', vmin=0, vmax=5)
    plt.title("Ground Truth (.npy)")
    plt.axis("off")
    plt.colorbar(label='Meters')

    # Pred
    plt.subplot(1, 3, 3)
    plt.imshow(last_pred, cmap='magma', vmin=0, vmax=5)
    plt.title(f"Prediction (DAV2)\nAbsRel: {last_metrics['abs_rel']:.3f}")
    plt.axis("off")
    plt.colorbar(label='Meters')

    plt.show()


def _find_ground_truth(pred_stem):
    """Return ``(stem, gt_path)`` for a prediction file stem, or ``(stem, None)``."""
    # First candidate: the stem with "_raw" removed.
    candidates = [pred_stem.replace("_raw", "")]
    if "_raw" in pred_stem:
        # NOTE(review): upstream run.py appears to name --save-numpy outputs
        # "<image>_raw_depth_meter.npy" — confirm against the cloned version.
        # Dropping everything from "_raw" onwards covers that naming as well.
        prefix = pred_stem.split("_raw", 1)[0]
        if prefix and prefix not in candidates:
            candidates.append(prefix)

    for stem in candidates:
        for folder in ("depth_npy", "depth"):
            gt_path = DATASET_PATH / folder / f"{stem}.npy"
            if gt_path.exists():
                return stem, gt_path
    return candidates[0], None


def _load_prediction(pred_path, is_npy):
    """Load one prediction as a squeezed numpy array, or None if unreadable."""
    if is_npy:
        return np.squeeze(np.load(pred_path))

    # Fallback for PNG (typically 16-bit uint mm)
    img = cv2.imread(str(pred_path), cv2.IMREAD_UNCHANGED)
    if img is None:
        return None
    return np.squeeze(img.astype(float) / 1000.0)

# Run the evaluation on the inference output directory: prints the averaged
# metrics and shows the comparison figure for the last evaluated image.
plot_results(OUT_DIR)