# 🍓 Task 2: Strawberry Ripeness Classification

This notebook trains an **EfficientNet-B0** model to classify strawberries into 3 categories:
0.  **Ripe** (Green label in COCO)
1.  **Unripe** (Red label)
2.  **Half-Ripe** (Orange label)

**Workflow:**
1.  **Download Dataset** (from GitHub Releases).
2.  **Crop Extraction**: Use ground-truth bounding boxes to extract individual strawberry images.
3.  **Data Augmentation**: Apply transforms (Flip, Rotate, Color Jitter).
4.  **Training**: Fine-tune a pre-trained EfficientNet model.

## 1. Setup & Download

Standard setup block to get the dataset ready.

In [None]:
import os
import requests
import zipfile
import shutil
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# Ensure timm is installed for EfficientNet
try:
    import timm
except ImportError:
    !pip install -q timm
    import timm

# --- Configuration ---
# GitHub release that hosts the dataset archive.
GITHUB_TYPE = "releases"
VERSION_TAG = "v1.0"
BASE_URL = f"https://github.com/SergKurchev/strawberry_synthetic_dataset/releases/download/{VERSION_TAG}"

# The archive is published as numbered parts of a split zip. The parts are
# listed in the order in which they have to be concatenated into OUTPUT_ZIP.
OUTPUT_ZIP = "strawberry_dataset.zip"
FILES_TO_DOWNLOAD = [
    "strawberry_dataset.zip.001",
    "strawberry_dataset.zip.002",
    "strawberry_dataset.zip.003",
]

# Local directory that holds the dataset once it is available.
DATASET_ROOT = Path("strawberry_dataset")

def _download_file(url, dest, chunk_size=8192, timeout=(10, 60)):
    """Stream ``url`` into the file ``dest``, raising on HTTP errors.

    The payload is first written to a sibling ``.part`` file and only renamed
    to ``dest`` once the transfer has finished, so an interrupted run never
    leaves a truncated file that a later run would mistake for a complete
    download.

    Args:
        url: Address of the file to fetch.
        dest: ``Path`` the finished download is stored at.
        chunk_size: Number of bytes requested per streamed chunk.
        timeout: ``(connect, read)`` timeout in seconds passed to ``requests``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
    """
    partial = dest.with_name(dest.name + ".part")
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an HTML error page would be saved as archive data.
        response.raise_for_status()
        with open(partial, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    os.replace(partial, dest)


def setup_dataset():
    """Locate the strawberry dataset, downloading and unpacking it if needed.

    A handful of known locations (working directory, ``dataset/`` and two
    Kaggle input paths) are probed first; a location counts as valid when it
    contains ``annotations.json``. If none matches, the split archive listed
    in ``FILES_TO_DOWNLOAD`` is fetched from ``BASE_URL``, concatenated into
    ``OUTPUT_ZIP`` and extracted into the working directory. The temporary
    download folder and the combined zip are removed afterwards.

    Returns:
        Path of the dataset directory: the first valid search location, or
        ``DATASET_ROOT`` after a fresh download.

    Raises:
        requests.HTTPError: If any archive part cannot be downloaded.
    """
    search_paths = [
        Path("strawberry_dataset"),
        Path("dataset/strawberry_dataset"),
        Path("/kaggle/input/last-straw-dataset/strawberry_dataset"),
        Path("/kaggle/input/strawberry_synthetic_dataset/strawberry_dataset")
    ]
    for p in search_paths:
        if p.exists() and (p / "annotations.json").exists():
            print(f"‚úÖ Dataset found at: {p}")
            return p

    print("‚¨áÔ∏è Dataset not found. Downloading...")
    download_dir = Path("temp_download")
    download_dir.mkdir(exist_ok=True)
    for filename in FILES_TO_DOWNLOAD:
        file_path = download_dir / filename
        # Parts that were fully fetched by an earlier run are reused.
        if not file_path.exists():
            _download_file(f"{BASE_URL}/{filename}", file_path)

    print("üì¶ Combining...")
    with open(OUTPUT_ZIP, 'wb') as outfile:
        for filename in FILES_TO_DOWNLOAD:
            with open(download_dir / filename, 'rb') as infile:
                shutil.copyfileobj(infile, outfile)

    print("üìÇ Extracting...")
    with zipfile.ZipFile(OUTPUT_ZIP, 'r') as zip_ref:
        zip_ref.extractall(".")

    shutil.rmtree(download_dir)
    os.remove(OUTPUT_ZIP)
    return DATASET_ROOT

DATASET_PATH = setup_dataset()
# A Path object is always truthy, so a bare `not DATASET_PATH` can never
# detect a failed setup. Check for the annotation file that the rest of the
# notebook reads instead.
if not DATASET_PATH or not (DATASET_PATH / "annotations.json").exists():
    raise RuntimeError("Setup failed")

## 2. Crop Extraction

We need to cut out each strawberry from the full images based on the annotations to create a classification dataset.

In [None]:
from collections import defaultdict


def _padded_box(bbox, img_w, img_h, pad):
    """Return ``(x1, y1, x2, y2)`` for an ``[x, y, w, h]`` box grown by ``pad``.

    Each side is padded independently and then clipped to the image, so a box
    touching the left/top border does not get extra pixels added on the
    opposite side.
    """
    x, y, w, h = [int(v) for v in bbox]
    x1 = max(0, x - pad)
    y1 = max(0, y - pad)
    x2 = min(img_w, x + w + pad)
    y2 = min(img_h, y + h + pad)
    return x1, y1, x2, y2


# Load Annotations
with open(DATASET_PATH / "annotations.json", 'r') as f:
    coco = json.load(f)

# Start from an empty output directory so stale crops never leak into training.
CROPS_DIR = Path("strawberry_crops")
shutil.rmtree(CROPS_DIR, ignore_errors=True)
CROPS_DIR.mkdir(exist_ok=True)

# Categories we care about for classification
CLASS_MAP = {
    0: "ripe",
    1: "unripe",
    2: "half_ripe"
}

for name in CLASS_MAP.values():
    (CROPS_DIR / name).mkdir(exist_ok=True)

print("‚úÇÔ∏è Extracting crops...")
count = 0
failed_writes = 0

# Optimize: Group anns by image_id
img_to_anns = defaultdict(list)
for ann in coco['annotations']:
    img_to_anns[ann['image_id']].append(ann)

# Padding usually helps classification models
pad = 10

for img_info in tqdm(coco['images']):
    anns = img_to_anns.get(img_info['id'])
    if not anns:
        continue

    # Load Image
    img_path = DATASET_PATH / "images" / img_info['file_name']
    if not img_path.exists():
        continue

    # We read with OpenCV for faster cropping
    img = cv2.imread(str(img_path))
    if img is None:
        continue
    img_h, img_w = img.shape[:2]

    # file_name may contain sub-directories; only its last component is safe
    # to embed in the output file name.
    base_name = Path(img_info['file_name']).name

    for ann in anns:
        cat_id = ann['category_id']
        if cat_id not in CLASS_MAP:
            continue

        x1, y1, x2, y2 = _padded_box(ann['bbox'], img_w, img_h, pad)
        # Also guards against negative ends, which NumPy slicing would
        # otherwise interpret as "count from the far edge".
        if x2 <= x1 or y2 <= y1:
            continue

        crop = img[y1:y2, x1:x2]
        if crop.size == 0:
            continue

        save_name = f"{base_name}_{ann['id']}.jpg"
        save_path = CROPS_DIR / CLASS_MAP[cat_id] / save_name
        # cv2.imwrite reports failure through its return value, not an exception.
        if cv2.imwrite(str(save_path), crop):
            count += 1
        else:
            failed_writes += 1

if failed_writes:
    print(f"Warning: {failed_writes} crops could not be written.")
print(f"‚úÖ Extracted {count} crops.")

## 3. Training Pipeline

We use PyTorch + timm to fine-tune EfficientNet-B0.

In [None]:
# Dataset Class
class StrawberryDataset(Dataset):
    """Dataset that pairs image file paths with integer class labels.

    Args:
        file_paths: Sequence of image paths; item ``i`` is read from
            ``file_paths[i]``.
        labels: Sequence of class indices aligned with ``file_paths``.
        transform: Optional callable applied to the RGB PIL image before it
            is returned.
    """

    def __init__(self, file_paths, labels, transform=None):
        self.transform = transform
        self.labels = labels
        self.file_paths = file_paths

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is opened with PIL, converted to RGB and passed through
        ``transform`` when one was given; the label is a ``torch.long`` tensor.
        """
        image = Image.open(self.file_paths[idx])
        image = image.convert("RGB")
        if self.transform:
            image = self.transform(image)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, target

# Gather files
all_files = []
all_labels = []
label_map = {"ripe": 0, "unripe": 1, "half_ripe": 2}

for class_name, label_idx in label_map.items():
    # glob() yields entries in filesystem order, which differs between
    # machines and runs. Sorting makes the seeded split below reproducible.
    files = sorted((CROPS_DIR / class_name).glob("*.jpg"))
    all_files.extend(files)
    all_labels.extend([label_idx] * len(files))

# Fail early with a clear message instead of an opaque error from the splitter.
if not all_files:
    raise RuntimeError(f"No crops found under {CROPS_DIR}; run the crop extraction step first.")

# Split (stratified 80/20, so both subsets keep the class proportions)
train_files, val_files, train_labels, val_labels = train_test_split(
    all_files, all_labels, test_size=0.2, stratify=all_labels, random_state=42
)

# Transforms
# Both pipelines share the input resolution and the channel-wise
# normalisation constants (the values commonly used with ImageNet-pretrained
# backbones); only the training pipeline adds random augmentation.
_INPUT_SIZE = (224, 224)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _tensor_tail():
    """Return the final steps shared by both pipelines: tensor conversion + normalisation."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_NORM_MEAN, _NORM_STD),
    ]


train_tf = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    *_tensor_tail(),
])

val_tf = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    *_tensor_tail(),
])

# Loaders
# Only the training loader shuffles; validation keeps a fixed order.
train_ds = StrawberryDataset(train_files, train_labels, transform=train_tf)
val_ds = StrawberryDataset(val_files, val_labels, transform=val_tf)
train_loader = DataLoader(train_ds, shuffle=True, batch_size=32)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=32)

# Model
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pre-trained EfficientNet-B0 with a 3-way classification head.
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(f"üöÄ Training on {device} with {len(train_ds)} samples...")

num_epochs = 5  # Short training for demo
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch_imgs, batch_lbls in progress:
        batch_imgs = batch_imgs.to(device)
        batch_lbls = batch_lbls.to(device)

        # Standard optimisation step: reset grads, forward, backward, update.
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_imgs), batch_lbls)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
    # Reported value is the mean of the per-batch losses of this epoch.
    print(f"  Loss: {running_loss/len(train_loader):.4f}")

print("‚úÖ Training Complete.")

## 4. Evaluation

Check accuracy and confusion matrix.

In [None]:
# Switch off dropout / batch-norm updates for inference.
model.eval()
all_preds = []
all_true = []

with torch.no_grad():
    for imgs, lbls in tqdm(val_loader, desc="Validating"):
        imgs = imgs.to(device)
        # The predicted class is the index of the largest logit.
        preds = model(imgs).argmax(dim=1)
        all_preds.extend(preds.cpu().tolist())
        all_true.extend(lbls.tolist())

# Pass the label ids explicitly: without them the report raises whenever a
# class is missing from the validation data, because the number of observed
# classes no longer matches the number of target names.
class_names = list(label_map.keys())
class_ids = list(label_map.values())

print("\nClassification Report:")
print(classification_report(all_true, all_preds, labels=class_ids, target_names=class_names))

print("Confusion Matrix (rows = true, columns = predicted):")
print(f"Class order: {class_names}")
print(confusion_matrix(all_true, all_preds, labels=class_ids))