# 🍓 Task 1: Segmentation (YOLO11)

This notebook trains a **YOLO11 Segmentation Model** to detect and segment:
- **Strawberries** (Class 0, 1, 2: Ripe, Unripe, Half-Ripe)
- **Peduncles** (Class 3: Stems)

**Key Features:**
1.  **Automatic Dataset Download**: Fetches split dataset from GitHub Releases.
2.  **Visual Verification**: Checks integrity of data before training.
3.  **YOLO Training**: Uses Ultralytics YOLO11n-seg (Nano) for speed/demo.
4.  **Inference**: Visualizes predictions on validation set.

## 1. Environment & Data Setup

Installs dependencies and downloads the dataset if not present.

In [None]:
import os
import requests
import zipfile
import shutil
import yaml
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from ultralytics import YOLO

# Install Ultralytics if needed
try:
    import ultralytics
except ImportError:
    !pip install -q ultralytics


# --- Configuration ---
# Release coordinates of the split dataset archive on GitHub.
GITHUB_TYPE = "releases"
VERSION_TAG = "v1.0"
BASE_URL = (
    "https://github.com/SergKurchev/strawberry_synthetic_dataset"
    f"/releases/download/{VERSION_TAG}"
)

# Name of the archive once its parts are concatenated back together.
OUTPUT_ZIP = "strawberry_dataset.zip"

# The archive is published as three numbered parts: .001, .002, .003.
FILES_TO_DOWNLOAD = [f"{OUTPUT_ZIP}.{part:03d}" for part in range(1, 4)]

# Directory returned by setup_dataset() after the archive is extracted.
DATASET_ROOT = Path("strawberry_dataset")

def setup_dataset():
    """Locate the strawberry dataset, downloading and unpacking it if absent.

    A candidate directory is accepted only if it already contains a
    ``data.yaml``. If none qualifies, the archive parts listed in
    ``FILES_TO_DOWNLOAD`` are fetched from ``BASE_URL``, concatenated into
    ``OUTPUT_ZIP`` and extracted into the current working directory.

    Returns:
        Path of the dataset root directory.

    Raises:
        requests.HTTPError: If a part cannot be downloaded (e.g. HTTP 404).
        RuntimeError: If extraction did not produce ``DATASET_ROOT``.
    """
    # 1. Check existing
    search_paths = [
        Path("strawberry_dataset"),
        Path("dataset/strawberry_dataset"),
        Path("/kaggle/input/last-straw-dataset/strawberry_dataset"),
        Path("/kaggle/input/strawberry_synthetic_dataset/strawberry_dataset")
    ]
    for p in search_paths:
        # data.yaml existing implies the directory itself exists.
        if (p / "data.yaml").exists():
            print(f"‚úÖ Dataset found at: {p}")
            return p

    print("‚¨áÔ∏è Dataset not found. Downloading from GitHub Releases...")

    # 2. Download
    download_dir = Path("temp_download")
    os.makedirs(download_dir, exist_ok=True)
    for filename in FILES_TO_DOWNLOAD:
        file_path = download_dir / filename
        if file_path.exists():
            continue  # part already fetched completely by an earlier run

        url = f"{BASE_URL}/{filename}"
        print(f"  Downloading {filename}...")

        # Stream into a ".part" file and rename only once the transfer has
        # finished, so an interrupted download is never mistaken for a
        # complete part on the next run.
        partial_path = download_dir / (filename + ".part")
        with requests.get(url, stream=True, timeout=60) as r:
            # Fail loudly instead of saving an HTTP error page as archive data.
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        partial_path.replace(file_path)

    # 3. Combine
    print("üì¶ Combining...")
    with open(OUTPUT_ZIP, 'wb') as outfile:
        for filename in FILES_TO_DOWNLOAD:
            with open(download_dir / filename, 'rb') as infile:
                shutil.copyfileobj(infile, outfile)

    # 4. Extract
    print("üìÇ Extracting...")
    with zipfile.ZipFile(OUTPUT_ZIP, 'r') as zip_ref:
        zip_ref.extractall(".")

    # Verify before cleaning up, so the downloaded files are still available
    # for inspection if the archive layout is not what we expect.
    if not DATASET_ROOT.exists():
        raise RuntimeError(
            f"Extraction finished but '{DATASET_ROOT}' was not created"
        )

    # Cleanup
    shutil.rmtree(download_dir)
    os.remove(OUTPUT_ZIP)
    print("‚úÖ Done.")
    return DATASET_ROOT

DATASET_PATH = setup_dataset()
# A Path object is always truthy, so a bare `not DATASET_PATH` can never fire;
# check that the directory really exists as well.
if not DATASET_PATH or not DATASET_PATH.exists():
    raise RuntimeError("Dataset setup failed")

## 2. Prepare YOLO Configuration

We need to create a `data.yaml` file that points to the correct paths so that YOLO can find the images and labels.

In [None]:
# The dataset comes with 'labels/' already in YOLO format.
# Structure:
#   strawberry_dataset/
#     images/
#     labels/
# We need to tell YOLO where these are.
# Note: the provided dataset has a flat 'images/' folder with no pre-split
# train/val structure, whereas YOLO usually expects images/train and images/val.
# We could point YOLO at text-file lists instead, but here we split the folders.

# Let's reorganize slightly for YOLO best practices: create train/val splits.
import random
import glob

def prepare_yolo_splits(dataset_path, split_ratio=0.8, seed=42):
    """Split a flat YOLO dataset into train/val sub-folders, in place.

    Moves every ``*.png`` directly under ``<dataset_path>/images`` into
    ``images/train`` or ``images/val``. A label file with the same stem
    under ``<dataset_path>/labels`` is moved into the matching
    ``labels/<split>`` folder; images without a label file are still moved.

    Does nothing if ``images/train`` already exists.

    Args:
        dataset_path: ``pathlib.Path`` of the dataset root.
        split_ratio: Fraction of the images assigned to the train split.
        seed: Seed for the shuffle, making the split reproducible.
            Pass ``None`` for a different split on every run.

    Raises:
        FileNotFoundError: If no ``*.png`` images are found.
    """
    images_dir = dataset_path / "images"
    labels_dir = dataset_path / "labels"

    # Check if already split (look for 'train' folder)
    if (images_dir / "train").exists():
        print("‚úÖ Dataset already split.")
        return

    print("üîÑ Splitting dataset into train/val...")

    # Sort first: glob order depends on the filesystem, so a seeded shuffle
    # is only reproducible when it starts from a fixed order.
    all_images = sorted(images_dir.glob("*.png"))
    if not all_images:
        # Fail before creating any folder; otherwise an empty 'train' folder
        # would make every later run report "already split".
        raise FileNotFoundError(f"No .png images found in {images_dir}")

    # A private Random instance leaves the global RNG state untouched.
    random.Random(seed).shuffle(all_images)

    split_idx = int(len(all_images) * split_ratio)
    splits = {
        "train": all_images[:split_idx],
        "val": all_images[split_idx:],
    }

    for split, split_images in splits.items():
        img_out = images_dir / split
        lbl_out = labels_dir / split
        img_out.mkdir(exist_ok=True, parents=True)
        lbl_out.mkdir(exist_ok=True, parents=True)

        for img_path in split_images:
            shutil.move(str(img_path), str(img_out / img_path.name))
            lbl_path = labels_dir / (img_path.stem + ".txt")
            if lbl_path.exists():
                shutil.move(str(lbl_path), str(lbl_out / lbl_path.name))

    print("‚úÖ Split complete.")

prepare_yolo_splits(DATASET_PATH)

# Create data.yaml
# Serialize with PyYAML instead of hand-formatting the text, so a dataset
# path containing YAML-significant characters (" #", ": ", quotes) is quoted
# correctly rather than corrupting the file.
data_config = {
    "path": str(DATASET_PATH.absolute()),  # absolute path to dataset root
    "train": "images/train",
    "val": "images/val",
    "names": {
        0: "strawberry_ripe",
        1: "strawberry_unripe",
        2: "strawberry_half_ripe",
        3: "peduncle",
    },
}
# sort_keys=False keeps the keys in the order written above.
yaml_content = yaml.safe_dump(data_config, sort_keys=False, allow_unicode=True)

with open(DATASET_PATH / "data.yaml", "w", encoding="utf-8") as f:
    f.write(yaml_content)

print("üìÑ data.yaml created.")

## 3. Train YOLO11 Model

We use `yolo11n-seg.pt` (Nano) for quick training. For higher accuracy in production, use `yolo11m-seg.pt` or `yolo11l-seg.pt`.

In [None]:
# Load the YOLO11 nano segmentation weights.
model = YOLO("yolo11n-seg.pt")

# Training settings, gathered in one place so they are easy to tweak.
train_config = {
    "data": str(DATASET_PATH / "data.yaml"),
    "epochs": 20,    # Adjust epochs as needed
    "imgsz": 640,    # Image size
    "batch": 16,     # Batch size
    "project": "strawberry_yolo",
    "name": "yolo11n_run",
    "exist_ok": True,
}

# Train
results = model.train(**train_config)

## 4. Evaluation & Inference

Visualize predictions on the validation set.

In [None]:
# Validate
metrics = model.val()

# Predict on a sample image from val set.
# Sorted so the same sample is picked on every run; glob order is not
# guaranteed to be stable.
val_dir = DATASET_PATH / "images" / "val"
val_images = sorted(val_dir.glob("*.png"))
if val_images:
    test_img = val_images[0]
    results = model.predict(test_img)

    # Show result (the plotted array is converted from BGR to RGB for matplotlib)
    res_plotted = results[0].plot()
    plt.figure(figsize=(10, 10))
    plt.imshow(cv2.cvtColor(res_plotted, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.title(f"Prediction on {test_img.name}")
    plt.show()
else:
    # Say so explicitly instead of silently skipping the visualization.
    print(f"No validation images (*.png) found in {val_dir}")