# 🍓 Task 1: Segmentation (YOLO11)

This notebook trains a **YOLO11 Segmentation Model** to detect and segment:
- **Strawberries** (Class 0, 1, 2: Ripe, Unripe, Half-Ripe)
- **Peduncles** (Class 3: Stems)

**Key Features:**
1.  **Automatic Dataset Download**: Fetches the multi-part dataset archive from GitHub Releases.
2.  **Visual Verification**: Checks the integrity of the data before training.
3.  **YOLO Training**: Uses Ultralytics YOLO11n-seg (Nano) for fast training in this demo.
4.  **Inference**: Visualizes predictions on the validation set.

## 1. Environment & Data Setup

Imports dependencies, then locates the dataset, downloading it from GitHub Releases if it is not already present.

In [None]:
import glob
import inspect
import json
import os
import shutil
import sys
import zipfile
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import ultralytics
import yaml
from PIL import Image
from tqdm.auto import tqdm
from ultralytics import YOLO


# --- Dataset download configuration ---
# The dataset is published on GitHub Releases as a zip archive split into
# numbered parts; setup_dataset() concatenates them back into OUTPUT_ZIP.
OUTPUT_ZIP = "strawberry_dataset.zip"
VERSION_TAG = "Dataset"
BASE_URL = (
    "https://github.com/SergKurchev/strawberry_synthetic_dataset"
    f"/releases/download/{VERSION_TAG}"
)
FILES_TO_DOWNLOAD = [f"{OUTPUT_ZIP}.{part:03d}" for part in (1, 2, 3)]

def reconstruct_metadata(dataset_root):
    """Rebuild ``depth_metadata.json`` from the per-image chunks in ``metadata_temp/``.

    Every ``<id>_meta.json`` file in ``<dataset_root>/metadata_temp`` is loaded
    and stored under the key ``"<id>.png"`` in one combined dict. That dict is
    then written to ``<dataset_root>/depth_metadata.json``. Unreadable or
    invalid chunks are reported and skipped.

    Args:
        dataset_root: Dataset root directory (``pathlib.Path`` or ``str``).

    Returns:
        True if at least one chunk was read and the combined file was written.
        False if ``metadata_temp/`` is missing or no chunk could be read.
    """
    dataset_root = Path(dataset_root)
    print("⚠️ 'depth_metadata.json' not found. Attempting reconstruction from 'metadata_temp/'...")
    temp_dir = dataset_root / "metadata_temp"
    if not temp_dir.exists():
        print(f"❌ metadata_temp directory not found at {temp_dir}")
        return False

    suffix = "_meta.json"
    # Sorted so the combined JSON has a stable key order regardless of the
    # order in which the filesystem lists the chunk files.
    json_files = sorted(temp_dir.glob(f"*{suffix}"))
    print(f"  Found {len(json_files)} metadata chunks.")

    combined_metadata = {}
    for json_file in tqdm(json_files, desc="Reconstructing Metadata"):
        # Filename format: 00001_meta.json -> corresponds to 00001.png.
        # Strip only the trailing suffix (str.replace would hit any occurrence).
        img_id = json_file.name[: -len(suffix)]
        try:
            with open(json_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"  Warning: Failed to read {json_file}: {e}")
            continue
        combined_metadata[f"{img_id}.png"] = data

    if not combined_metadata:
        print("❌ Failed to reconstruct any metadata.")
        return False

    target_path = dataset_root / "depth_metadata.json"
    print(f"💾 Saving reconstructed metadata to {target_path}...")
    with open(target_path, "w", encoding="utf-8") as f:
        json.dump(combined_metadata, f, indent=2)

    return True

# Directory names never searched for a dataset root (VCS data, download scratch space).
_SKIP_DIRS = frozenset({".git", "temp_download"})
# (connect, read) timeouts in seconds for each HTTP request. The read timeout
# bounds the wait between received bytes, not the total download time.
_DOWNLOAD_TIMEOUT = (10, 60)


def _validate_root(p):
    """Return True if ``p`` has ``depth_metadata.json``, rebuilding it from ``metadata_temp/`` if needed."""
    if (p / "depth_metadata.json").exists():
        return True
    if (p / "metadata_temp").exists():
        # Try to fix it
        return reconstruct_metadata(p)
    return False


def _find_dataset_in_tree(start="."):
    """Walk ``start`` (skipping ``_SKIP_DIRS``) and return the first valid dataset root, or None."""
    for root, dirs, files in os.walk(start, topdown=True):
        # Prune in place so os.walk never descends into skipped directories
        # (a plain `continue` would still walk their entire subtree).
        dirs[:] = [d for d in dirs if d not in _SKIP_DIRS]
        if "images" in dirs and ("depth_metadata.json" in files or "metadata_temp" in dirs):
            p = Path(root)
            if _validate_root(p):
                return p
    return None


def _download_and_concatenate(tmp_dir):
    """Download each part in FILES_TO_DOWNLOAD into ``tmp_dir`` and append it to OUTPUT_ZIP."""
    with open(OUTPUT_ZIP, "wb") as outfile:
        for filename in FILES_TO_DOWNLOAD:
            file_path = tmp_dir / filename
            url = f"{BASE_URL}/{filename}"

            print(f"  Downloading {filename} from {url}...")
            # The context manager releases the connection even if writing fails.
            with requests.get(url, stream=True, timeout=_DOWNLOAD_TIMEOUT) as r:
                if r.status_code != 200:
                    raise RuntimeError(f"Download failed for {filename}: HTTP {r.status_code}")
                with open(file_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)

            file_size_mb = file_path.stat().st_size / 1024 / 1024
            print(f"  Downloaded {filename} ({file_size_mb:.2f} MB). Appending to zip...")

            with open(file_path, "rb") as infile:
                shutil.copyfileobj(infile, outfile)


def _fix_backslash_filenames(directory="."):
    """Restore directory structure for files in ``directory`` whose names contain backslashes.

    An archive created on Windows may store paths like ``a\\b\\c.png``. On
    Linux, such a member can end up extracted as one flat file literally named
    ``a\\b\\c.png``. Each such file is moved to ``a/b/c.png``. Empty, ``.`` and
    ``..`` components are dropped so the result always stays inside
    ``directory``.

    Returns:
        Number of files successfully moved.
    """
    count = 0
    for filename in os.listdir(directory):
        if "\\" not in filename:
            continue
        parts = [part for part in filename.split("\\") if part not in ("", os.curdir, os.pardir)]
        if not parts:
            continue
        src = os.path.join(directory, filename)
        dst = os.path.join(directory, *parts)
        try:
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            shutil.move(src, dst)
            count += 1
        except OSError as e:
            print(f"  Failed to move {filename} -> {dst}: {e}")
    return count


def setup_dataset():
    """Locate the strawberry dataset, downloading and extracting it if needed.

    Search order:
      1. Any directory under the current working directory that contains an
         ``images/`` folder plus ``depth_metadata.json`` or ``metadata_temp/``.
      2. A fixed list of well-known local / Kaggle input paths.
      3. Download the multi-part archive from GitHub Releases and extract it
         into the current directory. Then repair backslash-separated
         filenames and search again.

    In every case a root with only ``metadata_temp/`` gets its
    ``depth_metadata.json`` rebuilt via ``reconstruct_metadata``.

    Returns:
        Path of the dataset root, or None if it could not be found after extraction.

    Raises:
        RuntimeError: if a part download returns a non-200 status.
        zipfile.BadZipFile: if the reassembled archive is corrupt.
        requests.RequestException: on network errors or timeouts.
    """
    # 1. Search for existing dataset
    print("🔍 Searching for existing dataset...")
    found = _find_dataset_in_tree(".")
    if found is not None:
        print(f"✅ Dataset found/Fixed at: {found}")
        return found

    # Check standard paths
    search_paths = [
        Path("strawberry_dataset"),
        Path("dataset/strawberry_dataset"),
        Path("/kaggle/input/last-straw-dataset/strawberry_dataset"),
        Path("/kaggle/input/strawberry_synthetic_dataset/strawberry_dataset"),
    ]
    for p in search_paths:
        if p.exists() and _validate_root(p):
            print(f"✅ Dataset found/Fixed at: {p}")
            return p

    print("⬇️ Dataset not found. Downloading from GitHub Releases...")

    # 2. Prepare Download Directory
    tmp_dir = Path("temp_download")
    shutil.rmtree(tmp_dir, ignore_errors=True)
    tmp_dir.mkdir(parents=True, exist_ok=True)
    if os.path.exists(OUTPUT_ZIP):
        os.remove(OUTPUT_ZIP)

    try:
        # 3. Download and Combine
        _download_and_concatenate(tmp_dir)

        # 4. Extract
        total_size_mb = os.path.getsize(OUTPUT_ZIP) / 1024 / 1024
        print(f"📂 Extracting {OUTPUT_ZIP} ({total_size_mb:.2f} MB)...")
        try:
            with zipfile.ZipFile(OUTPUT_ZIP, "r") as zip_ref:
                zip_ref.extractall(".")
        except zipfile.BadZipFile as e:
            print(f"❌ BadZipFile Error: {e}")
            raise
        print("  Extraction complete.")
    finally:
        # Clean up on success AND failure so a partial multi-part archive is
        # never left behind on disk.
        shutil.rmtree(tmp_dir, ignore_errors=True)
        if os.path.exists(OUTPUT_ZIP):
            os.remove(OUTPUT_ZIP)

    # --- FIX: Handle potential backslash filenames on Linux ---
    print("🧹 Checking for backslash issues in filenames...")
    count = _fix_backslash_filenames(".")
    if count > 0:
        print(f"✅ Fixed {count} filenames with backslashes. Directory structure restored.")

    # 5. Locate and Fix
    print("🔎 Locating dataset root...")
    found = _find_dataset_in_tree(".")
    if found is not None:
        print(f"✅ Dataset extracted and verified at: {found}")
    return found

DATASET_PATH = setup_dataset()
# setup_dataset() returns a Path on success and None when nothing usable was found.
if DATASET_PATH is None:
    raise RuntimeError("Dataset setup failed: Could not locate or reconstruct metadata")
# Alias of DATASET_PATH.
DATASET_ROOT = DATASET_PATH


## 2. Prepare YOLO Configuration

We need to create a `data.yaml` file that points to the correct paths so YOLO can find the images and labels.

In [None]:
# The dataset comes with 'labels/' already in YOLO format.
# Structure:
#   strawberry_dataset/
#     images/
#     labels/
# We need to tell YOLO where these are.
# Note: the dataset does not ship with a pre-split train/val structure; 'images/' is flat.
# YOLO usually expects images/train and images/val (with matching labels/train and labels/val).
# We could point YOLO at text files listing the images, but instead we split the folders now.

# Reorganize slightly for YOLO best practices: create train/val splits.
import random
import glob

def prepare_yolo_splits(dataset_path, split_ratio=0.8, seed=None):
    """Split a flat ``images/`` + ``labels/`` dataset into YOLO train/val folders.

    Moves ``images/*.png`` into ``images/train`` and ``images/val``. Each
    image's same-stem ``labels/<stem>.txt``, when present, goes to the matching
    ``labels/train`` or ``labels/val`` folder. Does nothing if
    ``images/train`` already exists.

    Args:
        dataset_path: Dataset root containing ``images/`` and ``labels/``
            (``pathlib.Path`` or ``str``).
        split_ratio: Fraction of images assigned to the train split; the rest
            go to val.
        seed: Optional seed for a reproducible split. ``None`` (default)
            shuffles with the global ``random`` module state.

    Raises:
        ValueError: if ``split_ratio`` is not within [0, 1].
    """
    if not 0.0 <= split_ratio <= 1.0:
        raise ValueError(f"split_ratio must be in [0, 1], got {split_ratio}")

    dataset_path = Path(dataset_path)
    images_dir = dataset_path / "images"
    labels_dir = dataset_path / "labels"

    # Check if already split (look for 'train' folder)
    if (images_dir / "train").exists():
        print("✅ Dataset already split.")
        return

    print("🔄 Splitting dataset into train/val...")
    # Sort first: glob order depends on the filesystem, so shuffling an
    # unsorted list would not be reproducible even with a fixed seed.
    all_images = sorted(images_dir.glob("*.png"))
    if not all_images:
        # Create no folders, so a later run (after the images appear) still splits.
        print(f"⚠️ No .png images found in {images_dir}; nothing to split.")
        return

    if seed is None:
        random.shuffle(all_images)
    else:
        random.Random(seed).shuffle(all_images)

    split_idx = int(len(all_images) * split_ratio)
    splits = {"train": all_images[:split_idx], "val": all_images[split_idx:]}

    missing_labels = 0
    for split, images in splits.items():
        img_dst = images_dir / split
        lbl_dst = labels_dir / split
        img_dst.mkdir(parents=True, exist_ok=True)
        lbl_dst.mkdir(parents=True, exist_ok=True)
        for img_path in images:
            shutil.move(str(img_path), str(img_dst / img_path.name))
            lbl_path = labels_dir / f"{img_path.stem}.txt"
            if lbl_path.exists():
                shutil.move(str(lbl_path), str(lbl_dst / lbl_path.name))
            else:
                missing_labels += 1

    if missing_labels:
        print(f"⚠️ {missing_labels} image(s) had no matching label file in {labels_dir}.")
    print("✅ Split complete.")

prepare_yolo_splits(DATASET_PATH)

# Create data.yaml. The config is built as a dict and serialized with PyYAML so
# the dataset path is quoted correctly even if it contains ' #' or ': '.
data_config = {
    "path": str(DATASET_PATH.absolute()),  # absolute path to dataset root
    "train": "images/train",
    "val": "images/val",
    "names": {
        0: "strawberry_ripe",
        1: "strawberry_unripe",
        2: "strawberry_half_ripe",
        3: "peduncle",
    },
}
yaml_content = yaml.safe_dump(data_config, sort_keys=False, allow_unicode=True)

with open(DATASET_PATH / "data.yaml", "w", encoding="utf-8") as f:
    f.write(yaml_content)

print("📄 data.yaml created.")

## 3. Train YOLO11 Model

We use `yolo11n-seg.pt` (Nano) for quick training. For higher accuracy in production, use `yolo11m-seg.pt` or `yolo11l-seg.pt`.

In [None]:
# Training settings, collected in one place for easy tweaking.
train_args = {
    "data": str(DATASET_PATH / "data.yaml"),
    "epochs": 20,       # Adjust epochs as needed
    "imgsz": 640,       # Image size
    "batch": 16,        # Batch size
    "project": "strawberry_yolo",
    "name": "yolo11n_run",
    "exist_ok": True,   # allow reusing an existing strawberry_yolo/yolo11n_run directory
}

# Start from the pretrained Nano segmentation checkpoint and fine-tune it.
model = YOLO('yolo11n-seg.pt')
results = model.train(**train_args)

## 4. Evaluation & Inference

Evaluate the trained model on the validation set and visualize a prediction on one validation image.

In [None]:
# Validate
metrics = model.val()
# mAP50-95 for boxes and masks of the segmentation model.
print(f"📊 Box mAP50-95: {metrics.box.map:.4f} | Mask mAP50-95: {metrics.seg.map:.4f}")

# Predict on a sample image from val set. Sorted so the showcased image is
# the same on every run (glob order is filesystem-dependent).
val_images = sorted((DATASET_PATH / "images" / "val").glob("*.png"))
if val_images:
    test_img = val_images[0]
    results = model.predict(test_img)

    # Show result: plot() gives a BGR image, so convert for matplotlib (RGB).
    res_plotted = results[0].plot()
    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title(f"Prediction on {test_img.name}")
    plt.show()
else:
    print("⚠️ No validation images found; skipping prediction preview.")