# Task 2: Strawberry Ripeness Classification

This notebook trains a classification model to categorize strawberries into:
- Ripe
- Unripe
- Half-ripe

- **Input**: Segmented strawberry crops from Task 1 (cut out in Section 3 using the dataset's bounding-box annotations)
- **Environment**: Designed for Kaggle with GPU support

## 1. Environment Setup

In [None]:
# Check GPU
# Report the installed PyTorch build and, when CUDA is usable, the name of device 0.
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install dependencies
!pip install -q torch torchvision timm opencv-python-headless matplotlib seaborn scikit-learn
!pip install -q pillow numpy pandas tqdm albumentations

In [None]:
import json
import os
import random
import shutil
from pathlib import Path

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from tqdm.auto import tqdm

# Plot defaults used by every figure in the notebook.
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Seed the random number generators so shuffling, augmentation and weight
# initialisation are repeatable between runs. GPU kernels may still introduce
# small run-to-run differences.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available; fall back to CPU otherwise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Download Dataset

In [None]:
# Clone dataset repository
REPO_URL = "https://github.com/SergKurchev/strawberry_synthetic_dataset.git"
DATASET_DIR = "/kaggle/working/strawberry_dataset"

if not os.path.exists(DATASET_DIR):
    print("Cloning dataset repository...")
    !git clone {REPO_URL} {DATASET_DIR}
else:
    print("Dataset already exists")

# Load annotations
with open(os.path.join(DATASET_DIR, "annotations.json"), 'r') as f:
    coco_data = json.load(f)

print(f"\nTotal images: {len(coco_data['images'])}")
print(f"Total annotations: {len(coco_data['annotations'])}")

## 3. Extract Strawberry Crops

In [None]:
# Extract strawberry crops using bounding boxes
# Each kept annotation is cut out of its source image (with a small margin)
# and written to CROPS_DIR/<class_name>/<image stem>_<annotation id>.png.
CROPS_DIR = "/kaggle/working/strawberry_crops"
os.makedirs(CROPS_DIR, exist_ok=True)

# Create class directories
# Maps annotation category_id -> class folder name.
# NOTE(review): assumes category ids 0/1/2 in annotations.json carry these
# meanings - confirm against the dataset's category list.
class_map = {
    0: 'ripe',
    1: 'unripe',
    2: 'half_ripe'
}

for class_name in class_map.values():
    os.makedirs(os.path.join(CROPS_DIR, class_name), exist_ok=True)

# Index the strawberry annotations by image once (other categories, e.g.
# peduncles, are excluded) instead of re-scanning the full list per image.
anns_by_image = {}
for ann in coco_data['annotations']:
    if ann['category_id'] in class_map:
        anns_by_image.setdefault(ann['image_id'], []).append(ann)

print("Extracting strawberry crops...")

crop_counts = {name: 0 for name in class_map.values()}

# Margin (in pixels) added on every side of the bounding box.
padding = 10

for img_info in tqdm(coco_data['images']):
    img_path = os.path.join(DATASET_DIR, "images", img_info['file_name'])
    img = cv2.imread(img_path)

    # cv2.imread returns None for missing/unreadable files; skip those images.
    if img is None:
        continue

    img_h, img_w = img.shape[:2]
    # splitext handles extensions of any length (".jpeg" as well as ".png").
    file_stem = os.path.splitext(img_info['file_name'])[0]

    for ann in anns_by_image.get(img_info['id'], []):
        # Get bounding box (x, y, width, height)
        x, y, w, h = [int(v) for v in ann['bbox']]

        # Pad each edge independently and clamp to the image, so a box near the
        # left/top border does not gain extra margin on the opposite side.
        x0 = max(0, x - padding)
        y0 = max(0, y - padding)
        x1 = max(x0, min(img_w, x + w + padding))
        y1 = max(y0, min(img_h, y + h + padding))

        # Crop
        crop = img[y0:y1, x0:x1]

        if crop.size == 0:
            continue

        # Save crop; count it only if the file was actually written.
        class_name = class_map[ann['category_id']]
        crop_filename = f"{file_stem}_{ann['id']}.png"
        crop_path = os.path.join(CROPS_DIR, class_name, crop_filename)
        if cv2.imwrite(crop_path, crop):
            crop_counts[class_name] += 1

print("\nCrop extraction complete!")
for class_name, count in crop_counts.items():
    print(f"  {class_name}: {count} crops")

In [None]:
# Visualize class distribution
# Two views of crop_counts side by side: absolute counts and relative share.
classes, counts = list(crop_counts), list(crop_counts.values())
colors = ['#FF6B6B', '#4ECDC4', '#FFE66D']

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))

# Left panel: bar chart with the count written just above each bar.
ax1.bar(classes, counts, color=colors)
ax1.set_title('Strawberry Class Distribution', fontsize=14, fontweight='bold')
ax1.set_ylabel('Count')
for bar_index, bar_height in enumerate(counts):
    ax1.text(bar_index, bar_height + 5, str(bar_height), ha='center', fontweight='bold')

# Right panel: pie chart of the same counts as percentages.
ax2.pie(counts, labels=classes, autopct='%1.1f%%', colors=colors, startangle=90)
ax2.set_title('Class Proportion', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig('/kaggle/working/classification_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Visualize sample crops
# One row per class, showing up to SAMPLES_PER_CLASS crops (first by filename).
SAMPLES_PER_CLASS = 6
fig, axes = plt.subplots(len(class_map), SAMPLES_PER_CLASS, figsize=(18, 9), squeeze=False)

# Hide every axis up front so classes with fewer crops than columns leave
# blank panels instead of empty framed axes.
for ax in axes.ravel():
    ax.axis('off')

for row, class_name in enumerate(class_map.values()):
    class_dir = os.path.join(CROPS_DIR, class_name)
    samples = sorted(os.listdir(class_dir))[:SAMPLES_PER_CLASS]

    for col, sample in enumerate(samples):
        img = cv2.imread(os.path.join(class_dir, sample))
        # Skip files OpenCV cannot decode rather than crashing in cvtColor.
        if img is None:
            continue
        # OpenCV loads images as BGR; matplotlib expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        axes[row, col].imshow(img)
        # Label each row once, on its first column.
        axes[row, col].set_title(class_name if col == 0 else '', fontsize=10)

plt.suptitle('Sample Crops per Class', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('/kaggle/working/sample_crops.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Create Dataset and DataLoaders

In [None]:
# Custom Dataset
class StrawberryDataset(Dataset):
    """Dataset of image files on disk paired with integer class labels.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of labels, one per entry in ``image_paths``.
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a mapping with the transformed image under ``'image'``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"({len(image_paths)} != {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is read with OpenCV, converted from BGR to RGB, and passed
        through ``transform`` when one was given.

        Raises:
            FileNotFoundError: If the image cannot be read from disk.
        """
        path = self.image_paths[idx]
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None; fail with the path so
        # the broken file is easy to find.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels[idx]

        if self.transform:
            img = self.transform(image=img)['image']

        return img, label

# Prepare data
# Collect every crop path together with the integer label of its class folder.
all_images = []
all_labels = []
label_to_idx = {'ripe': 0, 'unripe': 1, 'half_ripe': 2}

for class_name in class_map.values():
    class_dir = os.path.join(CROPS_DIR, class_name)
    # os.listdir returns entries in arbitrary order; sort them so the seeded
    # split below selects the same files on every run.
    for img_name in sorted(os.listdir(class_dir)):
        all_images.append(os.path.join(class_dir, img_name))
        all_labels.append(label_to_idx[class_name])

# Split dataset
# 80/20 train/validation split, stratified so both parts keep the class ratio.
X_train, X_val, y_train, y_val = train_test_split(
    all_images, all_labels, test_size=0.2, random_state=42, stratify=all_labels
)

print(f"Train samples: {len(X_train)}")
print(f"Val samples: {len(X_val)}")

In [None]:
# Data augmentation
# Shared pieces of both pipelines: target size and per-channel normalisation.
# (These mean/std values look like the usual ImageNet statistics - TODO confirm
# they match what the pretrained backbone expects.)
target_size = (224, 224)
norm_stats = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Random augmentations applied to training images only.
# NOTE(review): the hue shift changes fruit colour, which is presumably a cue
# for ripeness - verify it does not blur the class boundaries.
train_only_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.3),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
    A.GaussNoise(p=0.3),
]

train_transform = A.Compose([
    A.Resize(*target_size),
    *train_only_augmentations,
    A.Normalize(**norm_stats),
    ToTensorV2(),
])

# Validation: deterministic resize + normalise only.
val_transform = A.Compose([
    A.Resize(*target_size),
    A.Normalize(**norm_stats),
    ToTensorV2(),
])

# Create datasets
train_dataset = StrawberryDataset(X_train, y_train, transform=train_transform)
val_dataset = StrawberryDataset(X_val, y_val, transform=val_transform)

# Create dataloaders
# Only the training loader shuffles; validation keeps a fixed order.
loader_options = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print("DataLoaders created")

## 5. Build Model

In [None]:
# Use EfficientNet-B0 pretrained model
# timm builds the network with pretrained weights and a 3-output classifier;
# it is then moved to the selected device.
model = timm.create_model('efficientnet_b0', pretrained=True, num_classes=3).to(device)

param_count = sum(param.numel() for param in model.parameters())
print("Model: EfficientNet-B0")
print(f"Parameters: {param_count:,}")

## 6. Training

In [None]:
# Training setup
# Loss: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()
# Optimizer: AdamW over all model parameters.
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=1e-4)
# Schedule: cosine annealing with a period of 50 scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=50)

# Training loop
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy_percent)``. The loss is averaged per
        sample, so a smaller final batch does not skew the mean.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch's loss by its size to get a true per-sample average.
        # NOTE(review): assumes criterion returns the batch mean (default reduction).
        loss_sum += loss.item() * batch_size
        # Predicted class = index of the largest logit.
        predicted = outputs.argmax(dim=1)
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch received an empty loader")

    return loss_sum / total, 100. * correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating weights.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, accuracy_percent, predictions, targets)`` where
        the loss is averaged per sample and the last two items are flat lists
        of integer class indices in loader order.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    predictions = []
    # Named 'targets' to avoid shadowing the notebook-level all_labels list.
    targets = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight by batch size so a short final batch does not skew the mean.
            # NOTE(review): assumes criterion returns the batch mean (default reduction).
            loss_sum += loss.item() * batch_size
            predicted = outputs.argmax(dim=1)
            total += batch_size
            correct += predicted.eq(labels).sum().item()

            predictions.extend(predicted.cpu().tolist())
            targets.extend(labels.cpu().tolist())

    if total == 0:
        raise ValueError("validate received an empty loader")

    return loss_sum / total, 100. * correct / total, predictions, targets

In [None]:
# Train model
# Runs the full schedule, records per-epoch metrics in `history`, and keeps a
# checkpoint of the weights with the highest validation accuracy so far.
# The scheduler in the setup cell is created with T_max=50; keep num_epochs in
# sync with it if either value is changed.
num_epochs = 50
best_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_preds, val_labels = validate(model, val_loader, criterion, device)

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), '/kaggle/working/best_classification_model.pth')
        # \u2713 is the check mark; the previous literal was a mis-decoded copy of it.
        print(f"\u2713 Best model saved (acc: {best_acc:.2f}%)")

print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")

## 7. Evaluate and Visualize Results

In [None]:
# Plot training curves
# Loss and accuracy per epoch for the training and validation sets.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Derive the x-axis from what was actually recorded, so the plot still works
# when training was stopped before num_epochs completed.
epochs_range = range(1, len(history['train_loss']) + 1)

# (axis, history key suffix, legend suffix, y-axis label, title)
panels = [
    (ax1, 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'acc', 'Acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
]

for ax, key, legend_suffix, y_label, title in panels:
    ax.plot(epochs_range, history[f'train_{key}'], label=f'Train {legend_suffix}', linewidth=2)
    ax.plot(epochs_range, history[f'val_{key}'], label=f'Val {legend_suffix}', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title, fontweight='bold')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('/kaggle/working/training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Load best model and evaluate
# map_location makes the checkpoint loadable on the current device even if it
# was saved from a different one (e.g. saved on GPU, reloaded on CPU).
model.load_state_dict(
    torch.load('/kaggle/working/best_classification_model.pth', map_location=device)
)
_, _, val_preds, val_labels = validate(model, val_loader, criterion, device)

# Classification report
# class_names is ordered by label index (0, 1, 2). Passing the label ids
# explicitly keeps the report and matrix aligned with these names even when a
# class never appears in the validation labels or predictions.
class_names = ['ripe', 'unripe', 'half_ripe']
label_ids = list(range(len(class_names)))
print("\n=== Classification Report ===")
print(classification_report(val_labels, val_preds, labels=label_ids, target_names=class_names))

# Confusion matrix (rows: true label, columns: predicted label)
cm = confusion_matrix(val_labels, val_preds, labels=label_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('/kaggle/working/confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Save Results

In [None]:
# Save summary
# Gather dataset sizes, training outcomes and validation metrics into one
# JSON document, write it to disk and echo it to the cell output.
dataset_section = {
    "total_samples": len(all_images),
    "train_samples": len(X_train),
    "val_samples": len(X_val),
    "classes": class_names,
}

training_section = {
    "epochs": num_epochs,
    "best_val_acc": float(best_acc),
    # "final_*" values are those of the last recorded epoch.
    "final_train_acc": float(history['train_acc'][-1]),
    "final_val_acc": float(history['val_acc'][-1]),
}

# Metrics are computed from the val_labels / val_preds currently in memory.
metrics_section = {
    "accuracy": float(accuracy_score(val_labels, val_preds)),
    "per_class": classification_report(
        val_labels, val_preds, target_names=class_names, output_dict=True
    ),
}

summary = {
    "model": "EfficientNet-B0",
    "dataset": dataset_section,
    "training": training_section,
    "metrics": metrics_section,
}

# Serialise once and reuse the text for both the file and the printout.
summary_text = json.dumps(summary, indent=2)

with open('/kaggle/working/classification_summary.json', 'w') as f:
    f.write(summary_text)

print("Summary saved!")
print(summary_text)