# 🍓 Strawberry & Peduncle: Pipeline Evaluation with Depth Anything V3

This notebook evaluates the full pipeline using **Metric Depth Estimation**:
1.  **Metric Depth**: `Depth Anything V3` estimates depth in meters and camera intrinsics.
2.  **Segmentation**: `YOLOv11` detects Strawberries and Peduncles.
3.  **Association**: `AffinityNet` matches peduncles to strawberries.
4.  **3D Localization**: Combines metric depth with 2D detections to get real-world coordinates.

## Models
- **Depth**: `Depth-Anything-V3-Large`
- **Segmentation**: `yolo11l-seg-strawberry-stem-2`
- **Matching**: `affinity-net-strawberry-peduncle-maching-v1`

In [None]:
# üì¶ Install Dependencies
# Depth Anything V3
!pip install git+https://github.com/ByteDance-Seed/Depth-Anything-3.git
!pip install xformers torch>=2 torchvision matplotlib opencv-python-headless scikit-learn tqdm ultralytics

In [None]:
import os
import json
import glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from PIL import Image
from ultralytics import YOLO
from scipy.optimize import linear_sum_assignment
from tqdm.notebook import tqdm

# Depth Anything V3 comes from the pip cell above; report whether it is importable.
try:
    from depth_anything_3.api import DepthAnything3
except ImportError:
    print("‚ùå ERROR: depth_anything_3 not installed. Run the pip install cell!")
else:
    # Only reached when the import itself succeeded.
    print("‚úÖ Depth Anything V3 imported successfully")

# === CONFIGURATION ===
# Candidate dataset roots, probed in order; the first one that exists is used.
POSSIBLE_DATASET_PATHS = [
    "/kaggle/input/strawberry-peduncle-segmentation/strawberry_peduncle_segmentation/dataset",
    "/kaggle/input/strawberry-peduncle-segmentation/dataset",
    "dataset",  # Local fallback
]
DATASET_PATH = next(
    (candidate for candidate in POSSIBLE_DATASET_PATHS if os.path.exists(candidate)),
    None,
)

if DATASET_PATH is None:
    # Nothing found on disk: warn, then fall back to the relative local path.
    print("‚ö†Ô∏è DATASET NOT FOUND! Please check input paths.")
    DATASET_PATH = "dataset"

print(f"üìÇ using Dataset Path: {DATASET_PATH}")

# Default locations of the model weights.
WEIGHTS_YOLO = "/kaggle/input/yolo11l-seg-strawberry-stem-2/pytorch/default/1/yolo11l-seg-strawberry-stem-2.pt"
WEIGHTS_AFFINITY = "/kaggle/input/affinity-net-strawberry-peduncle-maching-v1/pytorch/default/1/best_affinity_net.pth"

# If a default file is missing, take the first hit of a recursive search of the
# corresponding input directory; keep the default path when nothing matches.
if not os.path.exists(WEIGHTS_YOLO):
    yolo_candidates = glob.glob("/kaggle/input/yolo11l-seg-strawberry-stem-2/**/*.pt", recursive=True)
    WEIGHTS_YOLO = yolo_candidates[0] if yolo_candidates else WEIGHTS_YOLO

if not os.path.exists(WEIGHTS_AFFINITY):
    affinity_candidates = glob.glob("/kaggle/input/affinity-net-strawberry-peduncle-maching-v1/**/*.pth", recursive=True)
    WEIGHTS_AFFINITY = affinity_candidates[0] if affinity_candidates else WEIGHTS_AFFINITY

# Run on the GPU when one is visible to torch, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f"‚úÖ Device: {DEVICE}")

## 1. Define AffinityNet & Utils

In [None]:
class AffinityNet(nn.Module):
    """MLP that maps a pairwise spatial feature vector to an affinity score.

    Every hidden layer is ``Linear -> ReLU -> Dropout(0.3)``; the head is
    ``Linear -> Sigmoid``, so each output is a single value per input row.

    Args:
        spatial_dim: Length of one input feature vector.
        hidden_dims: Width of each hidden layer, in order.
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, spatial_dim=5, hidden_dims=(32, 16)):
        super().__init__()
        self.spatial_dim = spatial_dim
        layers = []
        prev_dim = spatial_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3)
            ])
            prev_dim = hidden_dim
        layers.extend([
            nn.Linear(prev_dim, 1),
            nn.Sigmoid()
        ])
        # Layer order (and therefore state_dict keys) is unchanged, so existing
        # checkpoints still load.
        self.network = nn.Sequential(*layers)

    def forward(self, spatial_features):
        """Score a batch of feature vectors of shape (..., spatial_dim) -> (..., 1)."""
        return self.network(spatial_features)

    def predict_matrix(self, spatial_matrix):
        """Score every pair of an (N_para, N_cube, F) feature tensor.

        Returns:
            Tensor of shape (N_para, N_cube) holding one score per pair.
        """
        N_para, N_cube, feature_dim = spatial_matrix.shape
        # Use the tensor's own feature width rather than a hard-coded 5, and an
        # explicit row count so that an empty pair matrix can still be reshaped
        # (reshape(-1, F) is ambiguous for zero elements).
        spatial_flat = spatial_matrix.reshape(N_para * N_cube, feature_dim)
        affinity_flat = self.forward(spatial_flat)
        return affinity_flat.reshape(N_para, N_cube)

def compute_spatial_features_batch(para_bboxes, cube_bboxes, para_masks, cube_masks, image_size):
    """Build the pairwise feature tensor between two sets of detections.

    For every (para, cube) pair the five features are, in order:
    vertical_score, horizontal_overlap, centeredness, size_ratio, mask_iou.

    Args:
        para_bboxes: Sequence of (x1, y1, x2, y2) boxes for the first set.
        cube_bboxes: Sequence of (x1, y1, x2, y2) boxes for the second set.
        para_masks: Per-box masks for the first set, or None.
        cube_masks: Per-box masks for the second set, or None.
        image_size: (height, width) pair; only the height is used.

    Returns:
        float32 array of shape (len(para_bboxes), len(cube_bboxes), 5).
    """
    img_h, _ = image_size
    n_para, n_cube = len(para_bboxes), len(cube_bboxes)
    features = np.zeros((n_para, n_cube, 5), dtype=np.float32)
    # Mask IoU is only computed when both mask sets were supplied.
    have_masks = para_masks is not None and cube_masks is not None

    for i, j in np.ndindex(n_para, n_cube):
        p_left, p_top, p_right, p_bottom = para_bboxes[i]
        c_left, c_top, c_right, c_bottom = cube_bboxes[j]

        # Gap between the para's bottom edge and the cube's top edge, as a
        # fraction of image height; the score reaches 0 at a gap of 20%.
        vertical_gap = abs(p_bottom - c_top) / img_h
        vertical_score = max(0, 1.0 - vertical_gap * 5.0)

        # Share of the para's width that lies inside the cube's x-extent.
        shared_width = max(0, min(p_right, c_right) - max(p_left, c_left))
        horizontal_overlap = shared_width / ((p_right - p_left) + 1e-6)

        # 1 when the x-centres coincide, 0 once they are half a cube width apart.
        center_offset = abs((p_left + p_right) / 2 - (c_left + c_right) / 2)
        centeredness = max(0, 1.0 - center_offset / ((c_right - c_left) / 2 + 1e-6))

        # Box-area ratio para/cube, capped at 1.
        para_area = (p_right - p_left) * (p_bottom - p_top)
        cube_area = (c_right - c_left) * (c_bottom - c_top)
        size_ratio = min(para_area / (cube_area + 1e-6), 1.0)

        if have_masks:
            pm, cm = para_masks[i], cube_masks[j]
            mask_iou = np.logical_and(pm, cm).sum() / (np.logical_or(pm, cm).sum() + 1e-6)
        else:
            mask_iou = 0.0

        features[i, j] = [vertical_score, horizontal_overlap, centeredness, size_ratio, mask_iou]

    return features

## 2. Load Models

In [None]:
# 1. YOLO
# Leaves yolo_model as None when the weights are missing or fail to load, so
# later cells can detect the failure instead of crashing here.
print("üöÄ Loading YOLO...")
try:
    if os.path.exists(WEIGHTS_YOLO):
        yolo_model = YOLO(WEIGHTS_YOLO)
        print("‚úÖ YOLO Loaded")
    else:
        print("‚ùå ERROR: YOLO Weights not found!")
        yolo_model = None
except Exception as e:
    print(f"‚ö†Ô∏è Error loading YOLO: {e}")
    yolo_model = None

# 2. AffinityNet
print("üöÄ Loading AffinityNet...")
affinity_model = AffinityNet().to(DEVICE)
# Switch to inference mode unconditionally: the network contains Dropout, so a
# model left in training mode would give different scores on every call, even
# when the weights file is missing and the random initialisation is used.
affinity_model.eval()
if os.path.exists(WEIGHTS_AFFINITY):
    affinity_model.load_state_dict(torch.load(WEIGHTS_AFFINITY, map_location=DEVICE))
    print("‚úÖ AffinityNet Loaded")
else:
    print("‚ùå ERROR: AffinityNet Weights not found!")

# 3. Depth Anything V3
# Leaves depth_model as None on any failure (including a failed import of
# DepthAnything3 in the setup cell, which surfaces here as a NameError).
print("üöÄ Loading Depth Anything V3 (Metric)...")
try:
    # NOTE(review): confirm this identifier resolves for from_pretrained; the
    # upstream README may use a different naming scheme for its checkpoints.
    DEPTH_MODEL_NAME = "Depth-Anything-V3-Large" # Or "Depth-Anything-V3-Small"
    depth_model = DepthAnything3.from_pretrained(DEPTH_MODEL_NAME).to(DEVICE)
    print("‚úÖ Depth Anything V3 Loaded (with Intrinsics estimation)")
except Exception as e:
    print(f"‚ùå Error loading Depth Anything V3: {e}")
    depth_model = None

## 3. Processing Pipeline

We use the depth map from DA-V3 to get the Z-coordinate (in meters) for each object.
We use the **inferred intrinsics** from DA-V3 to calculate X and Y.

In [None]:
def pixel_to_3d_metric(bbox, depth_map, intrinsics):
    """Back-project the centre of a 2D box into camera coordinates.

    Z is the median of the valid depth samples inside the box (after clipping
    the box to the depth map); X and Y follow from the pinhole model applied
    to the centre of the clipped box. Y is negated relative to image rows.

    Args:
        bbox: (x1, y1, x2, y2) in depth-map pixel coordinates; values are
            truncated to int.
        depth_map: 2D array of depth values.
        intrinsics: 3x3 camera matrix; fx, fy, cx, cy are read from it.

    Returns:
        (X, Y, Z) as Python floats in the depth map's unit, or
        (0.0, 0.0, 0.0) when the clipped box is empty or holds no valid depth.
    """
    fx = intrinsics[0, 0]
    fy = intrinsics[1, 1]
    cx = intrinsics[0, 2]
    cy = intrinsics[1, 2]

    x1, y1, x2, y2 = map(int, bbox)
    # Ensure coords are within image bounds
    H, W = depth_map.shape
    x1, x2 = max(0, x1), min(W, x2)
    y1, y2 = max(0, y1), min(H, y2)

    if x2 <= x1 or y2 <= y1:
        return 0.0, 0.0, 0.0

    # Median depth over the box region (could be refined to use the mask).
    # NaN/inf and non-positive samples are dropped first: a single NaN would
    # otherwise turn the median, and with it X and Y, into NaN.
    depth_crop = np.asarray(depth_map[y1:y2, x1:x2])
    valid = depth_crop[np.isfinite(depth_crop)]
    valid = valid[valid > 0]
    if valid.size == 0:
        return 0.0, 0.0, 0.0
    Z = np.median(valid)

    # Centre of the clipped box
    u = (x1 + x2) / 2
    v = (y1 + y2) / 2

    X = (u - cx) * Z / fx
    Y = -(v - cy) * Z / fy  # Sign flipped so that Y grows upwards in the image

    return float(X), float(Y), float(Z)

def _to_numpy(value):
    """Return *value* as a NumPy array; accepts torch tensors on any device or array-likes."""
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value)


def process_image(image_path, visualize=False):
    """Run depth estimation, detection and peduncle/strawberry matching on one image.

    Steps: read the image, run the depth model, run YOLO, back-project every
    box with ``pixel_to_3d_metric``, then assign peduncles to strawberries with
    AffinityNet scores and the Hungarian algorithm (matches need a score
    above 0.5).

    Args:
        image_path: Path of the image file to process.
        visualize: When True, show the annotated image next to the depth map.

    Returns:
        None when the YOLO or depth model is not loaded; otherwise a dict with
        keys 'image', 'strawberries' and 'peduncles' (masks are stripped).

    Raises:
        OSError: If the image cannot be read.
    """
    if yolo_model is None or depth_model is None:
        print("‚ùå Models not loaded correctly.")
        return None

    img_filename = os.path.basename(image_path)

    # 1. Load Image
    image_cv = cv2.imread(image_path)
    if image_cv is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise OSError(f"Could not read image: {image_path}")
    image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
    image_pil = Image.fromarray(image_rgb)
    H, W, _ = image_rgb.shape

    # 2. Metric Depth Inference
    with torch.no_grad():
        depth_pred = depth_model.inference([image_pil])

    # Accept either torch tensors or NumPy arrays from the depth model; astype
    # copies, so the in-place scaling below never touches the model's output.
    depth_map = _to_numpy(depth_pred.depth[0]).astype(np.float32)
    intrinsics = _to_numpy(depth_pred.intrinsics[0]).astype(np.float64)

    # The detections below are in original-image pixels. If the depth model
    # worked at another resolution, bring the depth map and the intrinsics to
    # the image resolution (no-op when the shapes already agree).
    # NOTE(review): assumes the depth map covers the full image (pure resize,
    # no crop) - confirm against the depth model's preprocessing.
    depth_h, depth_w = depth_map.shape[:2]
    if (depth_h, depth_w) != (H, W):
        # Nearest-neighbour avoids inventing depths across object boundaries.
        depth_map = cv2.resize(depth_map, (W, H), interpolation=cv2.INTER_NEAREST)
        intrinsics[0, :] *= W / depth_w  # fx, skew, cx
        intrinsics[1, :] *= H / depth_h  # fy, cy

    # 3. YOLO Inference
    results = yolo_model(image_path, conf=0.5, verbose=False)[0]
    boxes = results.boxes.xyxy.cpu().numpy()
    classes = results.boxes.cls.cpu().numpy()
    masks = results.masks.data.cpu().numpy() if results.masks is not None else None

    # Class 0 is treated as strawberry ("cube"), class 1 as peduncle ("parallelepiped").
    cube_indices = np.where(classes == 0)[0]
    para_indices = np.where(classes == 1)[0]

    def describe(idx, label):
        """Build the record shared by both object kinds for detection *idx*."""
        box = boxes[idx]
        x, y, z = pixel_to_3d_metric(box, depth_map, intrinsics)
        return {
            'id': int(idx),
            'bbox': box.tolist(),
            'mask': masks[idx] if masks is not None else None,
            'pos_3d': [x, y, z],
            'class': label,
        }

    cubes = [describe(idx, 'strawberry') for idx in cube_indices]
    parallelepipeds = [
        {**describe(idx, 'peduncle'), 'matched_cube_id': None}
        for idx in para_indices
    ]

    # 4. Association
    if len(cubes) > 0 and len(parallelepipeds) > 0:
        cube_boxes = np.array([c['bbox'] for c in cubes])
        para_boxes = np.array([p['bbox'] for p in parallelepipeds])
        cm = np.array([c['mask'] for c in cubes]) if masks is not None else None
        pm = np.array([p['mask'] for p in parallelepipeds]) if masks is not None else None

        spatial_matrix = compute_spatial_features_batch(para_boxes, cube_boxes, pm, cm, (H, W))
        spatial_t = torch.from_numpy(spatial_matrix).to(DEVICE)

        with torch.no_grad():
            affinity = affinity_model.predict_matrix(spatial_t).cpu().numpy()

        # Negate the scores: linear_sum_assignment minimises total cost.
        row_ind, col_ind = linear_sum_assignment(-affinity)
        for r, c in zip(row_ind, col_ind):
            if affinity[r, c] > 0.5:
                parallelepipeds[r]['matched_cube_id'] = cubes[c]['id']

    # 5. Visualization
    if visualize:
        plt.figure(figsize=(15, 6))

        # RGB + Bounding Boxes, each labelled with its Z value
        plt.subplot(1, 2, 1)
        vis_img = image_rgb.copy()
        for c in cubes:
            x1, y1, x2, y2 = map(int, c['bbox'])
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), (255, 0, 0), 2)
            cv2.putText(vis_img, f"{c['pos_3d'][2]:.2f}m", (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)

        for p in parallelepipeds:
            x1, y1, x2, y2 = map(int, p['bbox'])
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(vis_img, f"{p['pos_3d'][2]:.2f}m", (x1, y1-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        plt.imshow(vis_img)
        plt.title("Detections & Metric Depth Labels")
        plt.axis('off')

        # Depth Map
        plt.subplot(1, 2, 2)
        plt.imshow(depth_map, cmap='turbo')
        plt.colorbar(label='Depth (m)')
        plt.title("Depth Anything V3 Output")
        plt.axis('off')
        plt.show()

    # Output Format: drop the masks so the records stay JSON-serialisable.
    return {
        'image': img_filename,
        'strawberries': [{k: v for k, v in c.items() if k != 'mask'} for c in cubes],
        'peduncles': [{k: v for k, v in p.items() if k != 'mask'} for p in parallelepipeds],
    }

In [None]:
# === RUN ON TEST IMAGES ===
# Take up to five PNGs from the first search pattern that matches anything.
if DATASET_PATH is not None:
    search_paths = [
        os.path.join(DATASET_PATH, "images", "*.png"),
        os.path.join(DATASET_PATH, "*.png")
    ]
    test_images = []
    for sp in search_paths:
        # glob order depends on the filesystem; sort so the same five images
        # are picked on every run.
        found = sorted(glob.glob(sp))
        if found:
            test_images = found
            break
    test_images = test_images[:5]
else:
    test_images = []

print(f"üß™ Testing on {len(test_images)} images...")

results_json = []
for img_path in test_images:
    # Best effort per image: report the failure and continue with the rest.
    try:
        res = process_image(img_path, visualize=True)
        if res:
            results_json.append(res)
    except Exception as e:
        print(f"‚ö†Ô∏è Error processing {img_path}: {e}")

# Save
with open("detailed_results.json", "w") as f:
    json.dump(results_json, f, indent=2)
print("‚úÖ Saved detailed_results.json")

In [None]:
# === 3D PLOT ===
def plot_3d_topdown(results):
    """Scatter all detections on the X-Z plane and link matched pairs.

    Args:
        results: List of per-image dicts with 'strawberries' and 'peduncles'
            lists; every item carries 'pos_3d' (x, y, z) and 'id', and each
            peduncle carries 'matched_cube_id' (None when unmatched).
    """
    plt.figure(figsize=(10, 10))
    labelled = set()

    def legend_label(name):
        """Return *name* on first use and "" afterwards, so each class gets one legend entry."""
        if name in labelled:
            return ""
        labelled.add(name)
        return name

    for entry in results:
        for s in entry['strawberries']:
            x, y, z = s['pos_3d']
            plt.scatter(x, z, c='red', marker='s', s=100, label=legend_label('Strawberry'))
        for p in entry['peduncles']:
            x, y, z = p['pos_3d']
            plt.scatter(x, z, c='green', marker='^', s=50, label=legend_label('Peduncle'))
            # Compare against None explicitly: id 0 is a valid match but falsy.
            if p['matched_cube_id'] is not None:
                match = next((s for s in entry['strawberries'] if s['id'] == p['matched_cube_id']), None)
                if match is not None:
                    mx, my, mz = match['pos_3d']
                    plt.plot([x, mx], [z, mz], 'k--', alpha=0.3)
    plt.xlabel("X (meters)")
    plt.ylabel("Z (Depth, meters)")
    plt.title("Top-Down View of Detected Objects (Metric Depth)")
    plt.grid(True)
    if labelled:
        # Skip the legend when nothing was plotted; matplotlib warns otherwise.
        plt.legend()
    plt.axis('equal')
    plt.show()

if results_json:
    plot_3d_topdown(results_json)