# 🍓 Strawberry & Peduncle: Segmentation, Matching & 3D Localization

This notebook evaluates the full pipeline:
1.  **Segmentation**: YOLOv11 detects Strawberries (Cubes) and Peduncles (Parallelepipeds).
2.  **Association**: AffinityNet predicts which peduncle belongs to which strawberry.
3.  **3D Localization**: Estimates XYZ coordinates in meters relative to the camera, from each detection's bounding box and the object's known physical size.

## Models
- **Segmentation**: `yolo11l-seg-strawberry-stem-2`
- **Matching**: `affinity-net-strawberry-peduncle-maching-v1`

## ⚠️ Troubleshooting
If the weights are not found, make sure both models listed above have been added as inputs to your Kaggle notebook.

In [None]:
# 📦 Install Dependencies
# IPython shell escape (`!`); only valid inside a notebook, not a plain .py script.
# NOTE: torch and scipy are imported in the next cell but are not installed here;
# they are assumed to already be available in the environment.
!pip install ultralytics opencv-python-headless matplotlib scikit-learn tqdm

In [None]:
import os
import json
import glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from ultralytics import YOLO
from scipy.optimize import linear_sum_assignment
from tqdm.notebook import tqdm

# === CONFIGURATION ===
# Dataset Path check: candidate locations are tried in order; the first one that exists wins.
POSSIBLE_DATASET_PATHS = [
    "/kaggle/input/strawberry-peduncle-segmentation/strawberry_peduncle_segmentation/dataset",
    "/kaggle/input/strawberry-peduncle-segmentation/dataset",
    "dataset",  # Local fallback
]
DATASET_PATH = next((p for p in POSSIBLE_DATASET_PATHS if os.path.exists(p)), None)

if DATASET_PATH is None:
    print("‚ö†Ô∏è DATASET NOT FOUND! Please check input paths.")
    DATASET_PATH = "/kaggle/input/strawberry-peduncle-segmentation/strawberry_peduncle_segmentation/dataset"  # Default to try

print(f"üìÇ using Dataset Path: {DATASET_PATH}")


def pick_weights_file(candidates, preferred_name):
    """
    Deterministically choose one weights file from a list of glob results.

    glob() returns paths in a filesystem-dependent order, so blindly taking
    the first hit could load a different checkpoint than intended. A file whose
    basename equals `preferred_name` wins; otherwise the lexicographically
    first path is returned.

    Args:
        candidates: iterable of file paths (e.g. the output of glob.glob).
        preferred_name: basename to prefer, e.g. "best.pt".

    Returns:
        The chosen path, or None if `candidates` is empty.
    """
    ordered = sorted(candidates)
    for path in ordered:
        if os.path.basename(path) == preferred_name:
            return path
    return ordered[0] if ordered else None


# Model Weights Check
# Weights might be in slightly different folders depending on version
WEIGHTS_YOLO = "/kaggle/input/yolo11l-seg-strawberry-stem-2/pytorch/default/1/best.pt"
WEIGHTS_AFFINITY = "/kaggle/input/affinity-net-strawberry-peduncle-maching-v1/pytorch/default/1/best_affinity_net.pth"

# If the expected path is missing, search the model's whole input folder and
# prefer a file with the same basename as the expected one.
if not os.path.exists(WEIGHTS_YOLO):
    found = pick_weights_file(
        glob.glob("/kaggle/input/yolo11l-seg-strawberry-stem-2/**/*.pt", recursive=True),
        os.path.basename(WEIGHTS_YOLO),
    )
    if found:
        WEIGHTS_YOLO = found
        print(f"üîç Found YOLO weights at: {WEIGHTS_YOLO}")
    else:
        print(f"‚ùå YOLO Weights NOT FOUND at {WEIGHTS_YOLO}")

if not os.path.exists(WEIGHTS_AFFINITY):
    found = pick_weights_file(
        glob.glob("/kaggle/input/affinity-net-strawberry-peduncle-maching-v1/**/*.pth", recursive=True),
        os.path.basename(WEIGHTS_AFFINITY),
    )
    if found:
        WEIGHTS_AFFINITY = found
        print(f"üîç Found Affinity weights at: {WEIGHTS_AFFINITY}")
    else:
        print(f"‚ùå Affinity Weights NOT FOUND at {WEIGHTS_AFFINITY}")

# Camera Intrinsics (Unity Default)
IMG_WIDTH = 1024
IMG_HEIGHT = 1024
FOV_VERTICAL_DEG = 60.0

# Real object sizes (Meters)
CUBE_RealHeight = 0.03  # 3 cm
PARA_RealHeight = 0.02  # 2 cm (approx length for depth estimation)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"‚úÖ Device: {DEVICE}")

## 1. Define AffinityNet & Utils

The network architecture and the feature-extraction logic are defined here so that the notebook is self-contained.

In [None]:
class AffinityNet(nn.Module):
    """
    MLP for predicting association affinity between parallelepiped and cube.
    Uses only geometric/spatial features - no visual features needed.

    Args:
        spatial_dim: number of input features per (peduncle, strawberry) pair.
            compute_spatial_features in this notebook produces 5.
        hidden_dims: widths of the hidden Linear -> ReLU -> Dropout stages.

    The layer order determines the state_dict keys (``network.0.weight``, ...),
    so it must match the checkpoint being loaded.
    """
    def __init__(self, spatial_dim=5, hidden_dims=(32, 16)):
        super().__init__()
        layers = []
        prev_dim = spatial_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
            ])
            prev_dim = hidden_dim
        # Single sigmoid output: an affinity score in [0, 1].
        layers.extend([
            nn.Linear(prev_dim, 1),
            nn.Sigmoid(),
        ])
        self.network = nn.Sequential(*layers)

    def forward(self, spatial_features):
        """Map (..., spatial_dim) features to (..., 1) affinity scores."""
        return self.network(spatial_features)

    def predict_matrix(self, spatial_matrix):
        """
        Score every (peduncle, strawberry) pair at once.

        Args:
            spatial_matrix: tensor of shape (N_para, N_cube, spatial_dim).

        Returns:
            Tensor of shape (N_para, N_cube) with affinities in [0, 1].
            Callers should put the model in eval() mode first so Dropout is off.
        """
        n_para, n_cube, feat_dim = spatial_matrix.shape
        # Use the actual feature width instead of a hard-coded 5 so models built
        # with a different spatial_dim also work.
        affinity_flat = self.forward(spatial_matrix.reshape(-1, feat_dim))
        return affinity_flat.reshape(n_para, n_cube)

def compute_spatial_features(para_bbox, cube_bbox, para_mask, cube_mask, image_size):
    """
    Build the geometric feature vector for one (peduncle, strawberry) pair.

    Args:
        para_bbox: peduncle box [x1, y1, x2, y2] in pixels.
        cube_bbox: strawberry box [x1, y1, x2, y2] in pixels.
        para_mask, cube_mask: masks of identical shape, or None (then IoU is 0).
        image_size: (H, W); only H is used, to normalise the vertical gap.

    Returns:
        float32 array [vertical_score, horizontal_overlap, centeredness,
        size_ratio, mask_iou]. Presumably this is the order AffinityNet was
        trained with, so keep it fixed.
    """
    img_h, _ = image_size
    p_x1, p_y1, p_x2, p_y2 = para_bbox
    c_x1, c_y1, c_x2, c_y2 = cube_bbox

    p_w = p_x2 - p_x1
    c_w = c_x2 - c_x1

    # 1. Gap between the peduncle's bottom edge and the strawberry's top edge,
    #    as a fraction of image height; a gap of 20% of H or more scores 0.
    gap = abs(p_y2 - c_y1) / img_h
    vertical_score = max(0, 1.0 - gap * 5.0)

    # 2. Share of the peduncle's width that lies within the strawberry's x-range.
    shared = max(0, min(p_x2, c_x2) - max(p_x1, c_x1))
    horizontal_overlap = shared / (p_w + 1e-6)

    # 3. 1 when the x-centres coincide, dropping to 0 at half the strawberry width.
    centre_offset = abs((p_x1 + p_x2) / 2 - (c_x1 + c_x2) / 2)
    centeredness = max(0, 1.0 - centre_offset / (c_w / 2 + 1e-6))

    # 4. Peduncle/strawberry box-area ratio, capped at 1.
    size_ratio = min((p_w * (p_y2 - p_y1)) / ((c_w * (c_y2 - c_y1)) + 1e-6), 1.0)

    # 5. Mask IoU; stays 0 when either mask is unavailable.
    mask_iou = 0.0
    if para_mask is not None and cube_mask is not None:
        inter = np.logical_and(para_mask, cube_mask).sum()
        union = np.logical_or(para_mask, cube_mask).sum()
        mask_iou = inter / (union + 1e-6)

    return np.array(
        [vertical_score, horizontal_overlap, centeredness, size_ratio, mask_iou],
        dtype=np.float32,
    )

def compute_spatial_features_batch(para_bboxes, cube_bboxes, para_masks, cube_masks, image_size):
    """
    Compute features for every (peduncle, strawberry) combination.

    Args:
        para_bboxes, cube_bboxes: sequences of [x1, y1, x2, y2] boxes.
        para_masks, cube_masks: per-detection masks indexable in step with the
            boxes, or None to leave the mask-IoU feature at 0.
        image_size: (H, W), forwarded unchanged to compute_spatial_features.

    Returns:
        float32 array of shape (len(para_bboxes), len(cube_bboxes), 5), where
        entry [i, j] describes peduncle i against strawberry j.
    """
    features = np.zeros((len(para_bboxes), len(cube_bboxes), 5), dtype=np.float32)
    for i, p_box in enumerate(para_bboxes):
        for j, c_box in enumerate(cube_bboxes):
            p_mask = None if para_masks is None else para_masks[i]
            c_mask = None if cube_masks is None else cube_masks[j]
            features[i, j] = compute_spatial_features(p_box, c_box, p_mask, c_mask, image_size)
    return features

## 2. 3D Coordinate Calculation

We calculate the 3D position $(X, Y, Z)$ of each object using the Pinhole Camera Model.

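### Intrinsics
Given the vertical field of view $FOV_{vert} = 60^{\circ}$ and an image of width $W = 1024$ and height $H = 1024$:

$f_y = \frac{H / 2}{\tan(FOV_{vert} / 2)}, \quad f_x = f_y$ (square pixels), $\quad c_x = W / 2, \quad c_y = H / 2$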

### Depth Estimation ($Z$)
We use the known physical height of the object and its projected pixel height:

$Z = \frac{f_y \cdot H_{real}}{H_{pixel}}$

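### X, Y Estimation
With $(u, v)$ the bounding-box centre in pixels:

$X = \frac{(u - c_x) \cdot Z}{f_x}$

$Y = -\frac{(v - c_y) \cdot Z}{f_y}$ (the minus sign is needed because Unity's Y axis points up while image rows increase downward)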

In [None]:
def get_intrinsics(img_width, img_height, fov_deg):
    """
    Pinhole intrinsics derived from a vertical field of view.

    Args:
        img_width, img_height: image size in pixels.
        fov_deg: vertical field of view in degrees.

    Returns:
        (f_x, f_y, cx, cy) in pixels. f_x equals f_y (square pixels) and the
        principal point is the image centre.
    """
    half_fov = np.deg2rad(fov_deg) / 2
    focal = (img_height / 2) / np.tan(half_fov)
    return focal, focal, img_width / 2, img_height / 2

def pixel_to_3d(bbox, real_height, intrinsics):
    """
    Estimate an object's 3D centre from its 2D bounding box (pinhole model).

    Depth uses similar triangles, Z = f_y * real_height / box_height. This
    assumes the box's pixel height spans the object's full physical height,
    i.e. a roughly fronto-parallel, unoccluded object.

    Args:
        bbox: [x1, y1, x2, y2] in pixels.
        real_height: physical height of the object in meters.
        intrinsics: (fx, fy, cx, cy) in pixels.

    Returns:
        (X, Y, Z) as Python floats, in the same units as real_height.
        X points right, Y points up (image v is flipped to match Unity's
        'Up' direction), and Z points forward.
    """
    fx, fy, cx, cy = intrinsics
    x1, y1, x2, y2 = bbox

    # Clamp zero/negative heights: a degenerate box yields a very large Z
    # instead of a ZeroDivisionError.
    h_pixel = max(y2 - y1, 1e-5)

    # Box centre in pixels.
    u_center = (x1 + x2) / 2
    v_center = (y1 + y2) / 2

    Z = (fy * real_height) / h_pixel

    # Back-project; the sign flip on Y converts image-down to Unity-up.
    X = (u_center - cx) * Z / fx
    Y = -(v_center - cy) * Z / fy

    return float(X), float(Y), float(Z)

## 3. Pipeline

1.  **Detect Objects** (YOLO)
2.  **Separate Classes**
3.  **Compute Affinity**
4.  **Match**
5.  **Compute 3D**
6.  **Visualize**

In [None]:
# Load Models
print("üöÄ Loading Models...")

# 1. YOLO: stays None when weights are missing or fail to load, so that
#    process_image can skip inference instead of crashing.
yolo_model = None
if os.path.exists(WEIGHTS_YOLO):
    try:
        yolo_model = YOLO(WEIGHTS_YOLO)
        print("‚úÖ YOLO Loaded")
    except Exception as e:
        print(f"‚ö†Ô∏è Error loading YOLO: {e}")
else:
    print("‚ùå ERROR: YOLO Weights not found anywhere! Check inputs.")

# 2. AffinityNet
affinity_model = AffinityNet().to(DEVICE)
if os.path.exists(WEIGHTS_AFFINITY):
    affinity_model.load_state_dict(torch.load(WEIGHTS_AFFINITY, map_location=DEVICE))
    print("‚úÖ AffinityNet Loaded")
else:
    print(f"‚ùå ERROR: AffinityNet Weights not found at {WEIGHTS_AFFINITY}")
    print("   Associations will come from an untrained network and are not meaningful.")
# Always switch to inference mode. AffinityNet contains Dropout layers, which
# would otherwise randomize every prediction, even for an unloaded model.
affinity_model.eval()

# Intrinsics
fx, fy, cx, cy = get_intrinsics(IMG_WIDTH, IMG_HEIGHT, FOV_VERTICAL_DEG)
intrinsics = (fx, fy, cx, cy)
print(f"üì∏ Intrinsics: fx={fx:.1f}, fy={fy:.1f}, cx={cx:.1f}, cy={cy:.1f}")

In [None]:
def _objects_from_detections(indices, boxes, masks, real_height, cam_intrinsics, label, **extra):
    """Build one record per detection index, including a 3D position estimated from its box."""
    records = []
    for idx in indices:
        box = boxes[idx]
        x, y, z = pixel_to_3d(box, real_height, cam_intrinsics)
        records.append({
            'id': int(idx),
            'bbox': box.tolist(),
            'mask': masks[idx] if masks is not None else None,
            'pos_3d': [x, y, z],
            'class': label,
            **extra,
        })
    return records


def _associate(cubes, parallelepipeds, image_size, match_threshold):
    """Set each peduncle's 'matched_cube_id' in place using AffinityNet scores + Hungarian matching."""
    if not cubes or not parallelepipeds:
        return

    cube_boxes = np.array([c['bbox'] for c in cubes])
    para_boxes = np.array([p['bbox'] for p in parallelepipeds])
    # All records come from one YOLO result, so masks exist for every detection or for none.
    has_masks = cubes[0]['mask'] is not None
    cube_masks = np.array([c['mask'] for c in cubes]) if has_masks else None
    para_masks = np.array([p['mask'] for p in parallelepipeds]) if has_masks else None

    spatial_matrix = compute_spatial_features_batch(
        para_boxes, cube_boxes, para_masks, cube_masks, image_size
    )
    with torch.no_grad():
        affinity = affinity_model.predict_matrix(
            torch.from_numpy(spatial_matrix).to(DEVICE)
        ).cpu().numpy()

    # One-to-one assignment maximising total affinity; weak pairs are discarded afterwards.
    rows, cols = linear_sum_assignment(-affinity)
    for r, c in zip(rows, cols):
        if affinity[r, c] > match_threshold:
            parallelepipeds[r]['matched_cube_id'] = cubes[c]['id']


def _visualize(image_path, filename, cubes, parallelepipeds):
    """Plot boxes, strawberry depths and predicted peduncle->strawberry links over the image."""
    img = cv2.imread(image_path)
    if img is None:  # cv2.imread signals unreadable files by returning None
        print(f"WARNING: could not read {image_path} for visualization.")
        return
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    plt.figure(figsize=(12, 12))
    plt.imshow(img)
    ax = plt.gca()
    cube_by_id = {c['id']: c for c in cubes}

    # Draw Cubes
    for c in cubes:
        x1, y1, x2, y2 = c['bbox']
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='red', linewidth=2))
        ax.text(x1, y1 - 5, f"Straw {c['id']} {c['pos_3d'][2]:.2f}m", color='red', fontsize=10, backgroundcolor='white')

    # Draw Paras & Lines
    for p in parallelepipeds:
        x1, y1, x2, y2 = p['bbox']
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='green', linewidth=2))
        ax.text(x1, y1 - 5, f"Stem {p['id']}", color='green', fontsize=8, backgroundcolor='white')

        cube = cube_by_id.get(p['matched_cube_id'])
        if cube is not None:
            bx1, by1, bx2, by2 = cube['bbox']
            plt.plot(
                [(bx1 + bx2) / 2, (x1 + x2) / 2],
                [(by1 + by2) / 2, (y1 + y2) / 2],
                color='yellow', linewidth=2,
            )

    plt.axis('off')
    plt.title(f"Results: {filename}\nYellow lines = Predicted Associations")
    plt.show()


def _export_record(obj):
    """JSON-safe copy of a record: drop the mask array and keep 'bbox' as the last key."""
    record = {k: v for k, v in obj.items() if k not in ('mask', 'bbox')}
    record['bbox'] = obj['bbox']
    return record


def process_image(image_path, visualize=False, conf=0.5, match_threshold=0.5):
    """
    Run the full pipeline on one image: detect, localize in 3D, associate, optionally plot.

    Args:
        image_path: path to the image file.
        visualize: if True, display detections and associations with matplotlib.
        conf: YOLO confidence threshold.
        match_threshold: minimum AffinityNet score for a Hungarian pair to be kept.

    Returns:
        A JSON-serialisable dict with keys 'image', 'strawberries' and
        'peduncles', or None when the YOLO model is not loaded.
    """
    if yolo_model is None:
        print("‚ùå YOLO model not loaded, skipping inference.")
        return None

    filename = os.path.basename(image_path)

    # 1. Inference
    results = yolo_model(image_path, conf=conf, verbose=False)[0]

    # 2. Extract Data
    boxes = results.boxes.xyxy.cpu().numpy()
    classes = results.boxes.cls.cpu().numpy()
    masks = results.masks.data.cpu().numpy() if results.masks is not None else None  # [N, H, W]

    # Boxes are in original-image pixels, so derive the camera model and the
    # feature normalisation from the real image size instead of assuming
    # IMG_WIDTH x IMG_HEIGHT (identical when the image is 1024x1024).
    img_h, img_w = results.orig_shape[:2]
    img_intrinsics = get_intrinsics(img_w, img_h, FOV_VERTICAL_DEG)

    # Separate by class. Assumes YOLO class 0 = cube/strawberry and 1 = para/peduncle
    # (COCO annotation ids 1/2 shifted to 0-based); confirm against dataset.yaml.
    cubes = _objects_from_detections(
        np.where(classes == 0)[0], boxes, masks, CUBE_RealHeight, img_intrinsics, 'strawberry'
    )
    parallelepipeds = _objects_from_detections(
        np.where(classes == 1)[0], boxes, masks, PARA_RealHeight, img_intrinsics, 'peduncle',
        matched_cube_id=None,
    )

    # 3. Association
    _associate(cubes, parallelepipeds, (img_h, img_w), match_threshold)

    # 4. Visualization
    if visualize:
        _visualize(image_path, filename, cubes, parallelepipeds)

    # 5. Output Data Format
    return {
        'image': filename,
        'strawberries': [_export_record(c) for c in cubes],
        'peduncles': [_export_record(p) for p in parallelepipeds],
    }

In [None]:
# === RUN ON TEST IMAGES ===
MAX_TEST_IMAGES = 5  # keep the demo quick; raise to evaluate more images

test_images = []
if DATASET_PATH is not None:
    # Try multiple subfolder structures; the first pattern with any match wins.
    search_paths = [
        os.path.join(DATASET_PATH, "images", "*.png"),
        os.path.join(DATASET_PATH, "*.png"),
    ]
    for sp in search_paths:
        # glob() order depends on the filesystem; sort so every run picks the same images.
        found = sorted(glob.glob(sp))
        if found:
            test_images = found[:MAX_TEST_IMAGES]
            break

results_json = []

print(f"üß™ Testing on {len(test_images)} images...")

if not test_images:
    print("‚ö†Ô∏è NO IMAGES FOUND. Check if the dataset is added to the notebook input correctly.")

for img_path in test_images:
    # Best-effort: one bad image should not abort the whole evaluation run.
    try:
        res = process_image(img_path, visualize=True)
    except Exception as e:
        print(f"‚ö†Ô∏è Error processing {img_path}: {e}")
        continue
    if res is not None:
        results_json.append(res)

# Save JSON
with open("detailed_results.json", "w", encoding="utf-8") as f:
    json.dump(results_json, f, indent=2)

print("\n‚úÖ Saved detailed_results.json")

In [None]:
# === JSON PREVIEW ===
# Pretty-print the record for the first processed image, if there is one.
if not results_json:
    print("‚ö†Ô∏è No results to preview. Check logs for errors.")
else:
    first_record = results_json[0]
    print(json.dumps(first_record, indent=2))

## Visualizing 3D Positions
Here we make a simple top-down scatter plot (X–Z plane) of the estimated 3D positions of the detected objects. Dashed lines connect each peduncle to its matched strawberry.

In [None]:
def plot_3d_topdown(results):
    """
    Top-down (X-Z) scatter of all estimated object positions.

    Args:
        results: list of per-image dicts as returned by process_image, each with
            'strawberries' and 'peduncles' lists whose items carry 'id',
            'pos_3d' = [X, Y, Z] and, for peduncles, 'matched_cube_id'.

    Strawberries are drawn as red squares and peduncles as green triangles.
    Dashed lines link each peduncle to its matched strawberry.
    """
    straw_x, straw_z = [], []
    ped_x, ped_z = [], []
    links = []  # ([x_ped, x_straw], [z_ped, z_straw]) line segments

    for entry in results:
        # Detection ids are only unique within one image, so resolve matches per entry.
        straw_by_id = {s['id']: s for s in entry['strawberries']}
        for s in entry['strawberries']:
            x, _, z = s['pos_3d']
            straw_x.append(x)
            straw_z.append(z)
        for p in entry['peduncles']:
            x, _, z = p['pos_3d']
            ped_x.append(x)
            ped_z.append(z)
            match = straw_by_id.get(p['matched_cube_id'])
            if match is not None:
                mx, _, mz = match['pos_3d']
                links.append(([x, mx], [z, mz]))

    plt.figure(figsize=(10, 10))
    for xs, zs in links:
        plt.plot(xs, zs, 'k--', alpha=0.3)
    # One scatter call per class gives exactly one legend entry each.
    if straw_x:
        plt.scatter(straw_x, straw_z, c='red', marker='s', s=100, label='Strawberry')
    if ped_x:
        plt.scatter(ped_x, ped_z, c='green', marker='^', s=50, label='Peduncle')

    plt.xlabel("X (meters)")
    plt.ylabel("Z (Depth, meters)")
    plt.title("Top-Down View of Detected Objects (X-Z Plane)")
    plt.grid(True)
    if straw_x or ped_x:  # avoid matplotlib's "no artists with labels" warning
        plt.legend()
    plt.axis('equal')
    plt.show()

if len(results_json) > 0:
    plot_3d_topdown(results_json)
else:
    print("‚ö†Ô∏è No data to plot.")