# 🍓 LAST-Straw Dataset - Comprehensive Overview

This notebook provides a complete guide to the **LAST-Straw** synthetic strawberry dataset. 
It demonstrates how to:
1.  **Download & Extract** the dataset automatically from GitHub Releases (handling split archives).
2.  **Visualize** RGB images, Depth maps, and Instance Masks.
3.  **Inspect Annotations** (Bounding Boxes, Categories, Attributes).
4.  **Visualize Matching** relationships (Strawberry connectivity to Peduncles).

**Repository**: [https://github.com/SergKurchev/strawberry_synthetic_dataset](https://github.com/SergKurchev/strawberry_synthetic_dataset)

## 1. Setup & Data Loading

This block handles the entire setup process. It checks if the dataset exists locally (or in Kaggle Inputs). If not, it automatically downloads the split parts from GitHub Releases, combines them, and extracts the data.

In [None]:
import os
import requests
import zipfile
import shutil
import numpy as np
import json
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image

# --- Download configuration ------------------------------------------------
# The dataset is published as a multi-part ZIP attached to a GitHub release
# (too large for raw hosting); the parts are fetched, concatenated, unzipped.
GITHUB_TYPE = "releases"  # "raw" would only work for small, unsplit files
VERSION_TAG = "v1.0"
BASE_URL = (
    "https://github.com/SergKurchev/strawberry_synthetic_dataset"
    f"/releases/download/{VERSION_TAG}"
)

# Name of the re-assembled archive; the split parts carry .001, .002, ... suffixes.
OUTPUT_ZIP = "strawberry_dataset.zip"
FILES_TO_DOWNLOAD = [f"{OUTPUT_ZIP}.{part:03d}" for part in (1, 2, 3)]

# Directory the archive is expected to unpack into (relative to the CWD).
DATASET_ROOT = Path("strawberry_dataset")

def _download_file(url, dest, timeout=60):
    """Stream ``url`` into ``dest``; return True on success, False on failure.

    The body is written to a sibling ``*.partial`` file and renamed onto
    ``dest`` only after the whole response has been received. An interrupted
    download therefore never leaves a truncated ``dest`` behind, which
    ``setup_dataset`` would otherwise treat as an already-downloaded part on
    the next run and splice into a corrupt archive.
    """
    tmp_path = dest.with_name(dest.name + ".partial")
    try:
        # timeout: don't hang forever on a stalled connection.
        # `with`: release the streamed connection even if writing fails midway.
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(tmp_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        tmp_path.replace(dest)
    except (requests.RequestException, OSError) as e:
        print(f"❌ Failed to download {dest.name}: {e}")
        if tmp_path.exists():
            tmp_path.unlink()
        return False
    return True


def _remove_download_artifacts(temp_dir):
    """Delete the downloaded-parts directory and the combined archive, if present."""
    if temp_dir.exists():
        shutil.rmtree(temp_dir)
    if os.path.exists(OUTPUT_ZIP):
        os.remove(OUTPUT_ZIP)


def setup_dataset():
    """Locate the dataset on disk, downloading and extracting it if necessary.

    1. Return the first of a few local / Kaggle-input directories that
       contains ``annotations.json``.
    2. Otherwise download every part in ``FILES_TO_DOWNLOAD`` from
       ``BASE_URL`` into ``temp_download/`` (parts already there are reused).
    3. Concatenate the parts into ``OUTPUT_ZIP`` and extract it into the
       current directory.
    4. Remove the temporary parts and archive.

    Returns:
        The dataset root ``Path``, or ``None`` when a download fails, the
        combined archive is not a valid ZIP, or ``annotations.json`` is not
        found under ``DATASET_ROOT`` after extraction (a message is printed).
    """
    # 1. Check for existing dataset (Kaggle Input or Local)
    search_paths = [
        Path("strawberry_dataset"),
        Path("dataset/strawberry_dataset"),
        Path("/kaggle/input/last-straw-dataset/strawberry_dataset"),
        Path("/kaggle/input/strawberry_synthetic_dataset/strawberry_dataset"),
    ]
    for p in search_paths:
        if (p / "annotations.json").exists():
            print(f"✅ Dataset found at: {p}")
            return p

    print("⬇️ Dataset not found. Downloading from GitHub Releases...")

    # 2. Download split parts (atomically, so a failed run leaves no bad part)
    temp_dir = Path("temp_download")
    temp_dir.mkdir(exist_ok=True)
    for filename in FILES_TO_DOWNLOAD:
        file_path = temp_dir / filename
        if file_path.exists():
            print(f"  {filename} already exists in temp.")
            continue
        print(f"  Downloading {filename}...")
        if not _download_file(f"{BASE_URL}/{filename}", file_path):
            return None

    # 3. Combine parts (a split ZIP is just the byte-wise concatenation)
    print("📦 Combining split archive...")
    with open(OUTPUT_ZIP, 'wb') as outfile:
        for filename in FILES_TO_DOWNLOAD:
            with open(temp_dir / filename, 'rb') as infile:
                shutil.copyfileobj(infile, outfile)

    # 4. Extract
    print("📂 Extracting dataset...")
    try:
        with zipfile.ZipFile(OUTPUT_ZIP, 'r') as zip_ref:
            zip_ref.extractall(".")
    except zipfile.BadZipFile:
        print("❌ Error: The combined ZIP file is corrupted.")
        # Discard the parts too; otherwise a rerun would reassemble the same
        # bad bytes instead of downloading fresh copies.
        _remove_download_artifacts(temp_dir)
        return None
    print("✅ Extraction complete.")

    # Cleanup
    _remove_download_artifacts(temp_dir)

    # Guard against an archive whose layout differs from DATASET_ROOT.
    if not (DATASET_ROOT / "annotations.json").exists():
        print(f"❌ Error: annotations.json not found under {DATASET_ROOT} after extraction.")
        return None
    return DATASET_ROOT

# Every later cell reads from DATASET_PATH, so stop right here if it is unavailable.
DATASET_PATH = setup_dataset()
if not DATASET_PATH:  # setup_dataset() yields a Path (always truthy) or None
    raise RuntimeError("Failed to setup dataset!")

## 2. Load Metadata

Load `annotations.json`, which contains COCO-style annotations for the entire dataset.

In [None]:
# JSON text is UTF-8 by specification; say so explicitly so category names
# decode correctly where the default locale encoding differs (e.g. Windows).
with open(DATASET_PATH / "annotations.json", 'r', encoding='utf-8') as f:
    coco_data = json.load(f)

images = coco_data['images']
annotations = coco_data['annotations']
# Index categories by id so drawing code can look names up directly.
categories = {c['id']: c for c in coco_data['categories']}

print(f"🖼️ Total Images: {len(images)}")
print(f"🏷️ Total Annotations: {len(annotations)}")
print("📋 Categories:")
for cid, cat in categories.items():
    print(f"  {cid}: {cat['name']}")

## 3. Visualization Helpers

Functions to draw bounding boxes, decode depth maps, and visualize relationship lines.

In [None]:
# Color palette for visualization, keyed by category id.
# The tuples are RGB: the drawing helpers are applied to images that were
# converted with cv2.COLOR_BGR2RGB and are displayed with matplotlib, so the
# channel order must be RGB for the colors below to appear as named.
# NOTE(review): the id -> class mapping in the comments is carried over from
# the original notes; confirm against the names printed from annotations.json.
COLORS = {
    0: (0, 255, 0),      # Ripe (Green)
    1: (255, 0, 0),      # Unripe (Red)
    2: (255, 165, 0),    # Half-ripe (Orange)
    3: (139, 69, 19),    # Peduncle (Brown)
}

def draw_annotations(img, anns, show_bboxes=True, show_labels=True):
    """Return a copy of ``img`` with each annotation's box and/or label drawn.

    Args:
        img: Image array to draw on; it is copied, not modified. ``COLORS``
            values are written into its channels as-is.
        anns: COCO-style annotation dicts with ``category_id`` and an
            ``[x, y, w, h]`` ``bbox``.
        show_bboxes: Draw the bounding rectangle.
        show_labels: Draw the category name just above the box.

    Returns:
        The annotated copy of ``img``.
    """
    vis = img.copy()
    for ann in anns:
        cat_id = ann['category_id']
        color = COLORS.get(cat_id, (255, 255, 255))
        x, y, w, h = (int(v) for v in ann['bbox'])

        if show_bboxes:
            cv2.rectangle(vis, (x, y), (x + w, y + h), color, 2)

        if show_labels:
            # Mirror the color fallback: unknown categories are labelled by
            # their numeric id instead of raising KeyError.
            label = categories.get(cat_id, {}).get('name', str(cat_id))
            # putText anchors at the text baseline; keep it far enough from the
            # top edge that labels of boxes touching y=0 remain visible.
            text_y = max(y - 5, 12)
            cv2.putText(vis, label, (x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
    return vis

def _bbox_center(bbox):
    """Return the integer (x, y) centre of an ``[x, y, w, h]`` box."""
    x, y, w, h = bbox
    return (int(x + w / 2), int(y + h / 2))


def visualize_matching(img, anns):
    """Return a copy of ``img`` with each strawberry linked to its parent.

    For every annotation whose category is 0, 1 or 2 (strawberry classes) and
    whose ``parent_id`` is non-zero and equals the ``instance_id`` of another
    annotation in ``anns``, draw:
    - a white line between the two bbox centres;
    - a dot in the strawberry's class color at the strawberry end;
    - a ``COLORS[3]`` (peduncle) dot at the parent end.

    Annotations without ``parent_id`` or ``instance_id`` are simply skipped.
    """
    vis = img.copy()
    # instance_id -> annotation. Entries lacking an instance_id cannot be a
    # link target, so leave them out instead of failing with KeyError.
    id_to_ann = {a['instance_id']: a for a in anns if 'instance_id' in a}

    for ann in anns:
        parent_id = ann.get('parent_id', 0)
        # parent_id == 0 is treated as "no parent".
        if ann['category_id'] not in (0, 1, 2) or parent_id == 0:
            continue
        parent = id_to_ann.get(parent_id)
        if parent is None:
            continue

        center_s = _bbox_center(ann['bbox'])
        center_p = _bbox_center(parent['bbox'])

        cv2.line(vis, center_s, center_p, (255, 255, 255), 2)
        cv2.circle(vis, center_s, 5, COLORS[ann['category_id']], -1)
        cv2.circle(vis, center_p, 5, COLORS[3], -1)

    return vis

def _decode_depth_array(depth_arr):
    """Convert a raw depth-image array to metres as float32.

    A 3-D array (H, W, C) is treated as channel-packed 16-bit depth: channel
    0 holds the high byte and channel 1 the low byte, so
    ``depth_mm = R * 256 + G``. A 2-D array is taken to already contain
    millimetre values.
    """
    if depth_arr.ndim == 3:  # RGB / RGBA packing
        high = depth_arr[:, :, 0].astype(np.uint16)
        low = depth_arr[:, :, 1].astype(np.uint16)
        depth_mm = (high << 8) | low
    else:  # single-channel (e.g. 16-bit grayscale)
        depth_mm = depth_arr
    return depth_mm.astype(np.float32) / 1000.0  # mm -> m


def decode_depth(depth_path):
    """Load a depth PNG and return its depth in metres.

    The dataset stores depth as a 16-bit value, either split across the R
    (high byte) and G (low byte) channels or as a single gray channel, in
    millimetres.

    Args:
        depth_path: Path (str or ``Path``) to the depth image.

    Returns:
        A float32 ``(H, W)`` array in metres, or ``None`` if the file does
        not exist.
    """
    if not os.path.exists(depth_path):
        return None

    # The context manager closes the file handle; a bare Image.open keeps it
    # open until the image object happens to be garbage-collected.
    with Image.open(depth_path) as img:
        depth_arr = np.array(img)
    return _decode_depth_array(depth_arr)

## 4. Explore Samples

Let's visualize a few random samples from the dataset.

In [None]:
import random

SAMPLES_TO_SHOW = 3
# random.sample raises ValueError when asked for more items than exist.
n_samples = min(SAMPLES_TO_SHOW, len(images))
indices = random.sample(range(len(images)), n_samples)

for idx in indices:
    img_info = images[idx]
    img_path = DATASET_PATH / "images" / img_info['file_name']
    # NOTE(review): assumes each depth map shares the RGB image's file name
    # (including extension) under depth/ — confirm against the dataset layout.
    depth_path = DATASET_PATH / "depth" / img_info['file_name']

    # Read Images. cv2.imread returns None (no exception) for missing or
    # unreadable files, which would otherwise crash cvtColor obscurely.
    bgr = cv2.imread(str(img_path))
    if bgr is None:
        print(f"⚠️ Could not read image: {img_path}")
        continue
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    depth_map = decode_depth(depth_path)

    # Get Annotations for this image
    img_anns = [a for a in annotations if a['image_id'] == img_info['id']]

    # Create Visualizations
    vis_bbox = draw_annotations(rgb, img_anns)
    vis_match = visualize_matching(rgb, img_anns)

    # Plot: annotations | depth | matching
    fig, ax = plt.subplots(1, 3, figsize=(18, 6))

    ax[0].imshow(vis_bbox)
    ax[0].set_title(f"Annotations: {img_info['file_name']}")
    ax[0].axis('off')

    if depth_map is not None:
        im1 = ax[1].imshow(depth_map, cmap='magma')
        ax[1].set_title("Depth Map (Meters)")
        plt.colorbar(im1, ax=ax[1])
    else:
        # Axes coordinates so the message is centred regardless of data limits.
        ax[1].text(0.5, 0.5, "Depth Missing", ha='center', va='center',
                   transform=ax[1].transAxes)
    ax[1].axis('off')

    ax[2].imshow(vis_match)
    ax[2].set_title("Matching (Strawberry -> Peduncle)")
    ax[2].axis('off')

    plt.tight_layout()
    plt.show()