# Task 3: Strawberry-Peduncle Matching

This notebook trains a model to match strawberries with their corresponding peduncles (stems).

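- **Approach**: Siamese network (shared EfficientNet-B0 encoder) with a pair-classification head trained with binary cross-entropy to predict match / no match
- **Dataset**: Match labels come from the `parent_id` relationships in the annotations
- **Environment**: Designed for Kaggle with GPU support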

## 1. Environment Setup

In [None]:
# Report the PyTorch build and whether a CUDA GPU is visible to this kernel.
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Install dependencies
!pip install -q torch torchvision timm opencv-python-headless matplotlib seaborn scikit-learn
!pip install -q pillow numpy pandas tqdm albumentations

In [None]:
import json
import os
import random
from collections import defaultdict
from pathlib import Path

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    precision_recall_fscore_support,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

# Plot styling used by every figure below
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)

# Prefer the GPU when one is visible
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed the Python, NumPy and PyTorch RNGs (in that order)
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

## 2. Download Dataset

In [None]:
# Clone dataset repository
REPO_URL = "https://github.com/SergKurchev/strawberry_synthetic_dataset.git"
DATASET_DIR = "/kaggle/working/strawberry_dataset"

if not os.path.exists(DATASET_DIR):
    print("Cloning dataset repository...")
    !git clone {REPO_URL} {DATASET_DIR}
else:
    print("Dataset already exists")

# Load annotations
with open(os.path.join(DATASET_DIR, "annotations.json"), 'r') as f:
    coco_data = json.load(f)

print(f"\nTotal images: {len(coco_data['images'])}")
print(f"Total annotations: {len(coco_data['annotations'])}")

## 3. Extract Crops and Build Matching Pairs

In [None]:
# Extract one padded crop per strawberry / peduncle annotation and record
# strawberry -> parent-peduncle links from `parent_id`.
CROPS_DIR = "/kaggle/working/matching_crops"
os.makedirs(f"{CROPS_DIR}/strawberries", exist_ok=True)
os.makedirs(f"{CROPS_DIR}/peduncles", exist_ok=True)

STRAWBERRY_CATEGORY_IDS = {0, 1, 2}  # category ids treated as strawberries
PEDUNCLE_CATEGORY_ID = 3
CROP_PADDING = 10  # pixels added on each side of the bbox, clipped to the image

print("Extracting crops...")

# Keys are (image_id, instance_id) tuples.
# NOTE(review): it is not visible here whether instance_id is globally unique or only
# unique within an image; the tuple key is correct in both cases and prevents crops
# from different images silently overwriting each other.
strawberry_crops = {}  # (image_id, instance_id) -> crop_path
peduncle_crops = {}    # (image_id, instance_id) -> crop_path
matching_info = {}     # strawberry key -> parent peduncle key (same image)

# Group annotations by image once instead of rescanning the full list for every image.
anns_by_image = defaultdict(list)
for ann in coco_data['annotations']:
    anns_by_image[ann['image_id']].append(ann)

for img_info in tqdm(coco_data['images']):
    img_path = os.path.join(DATASET_DIR, "images", img_info['file_name'])
    img = cv2.imread(img_path)
    if img is None:
        continue  # missing or unreadable image file

    image_id = img_info['id']
    img_h, img_w = img.shape[:2]

    for ann in anns_by_image.get(image_id, []):
        category_id = ann['category_id']
        if category_id in STRAWBERRY_CATEGORY_IDS:
            kind, subdir, store = "strawberry", "strawberries", strawberry_crops
        elif category_id == PEDUNCLE_CATEGORY_ID:
            kind, subdir, store = "peduncle", "peduncles", peduncle_crops
        else:
            continue

        # Pad the bbox on every side and clip both corners to the image, so a box
        # near the left/top border does not get extra padding on the right/bottom.
        bx, by, bw, bh = [int(v) for v in ann['bbox']]
        x0 = max(0, bx - CROP_PADDING)
        y0 = max(0, by - CROP_PADDING)
        x1 = min(img_w, bx + bw + CROP_PADDING)
        y1 = min(img_h, by + bh + CROP_PADDING)
        if x1 <= x0 or y1 <= y0:
            continue  # degenerate box; also avoids negative (wrap-around) slice ends

        crop = img[y0:y1, x0:x1]
        if crop.size == 0:
            continue

        instance_id = ann['instance_id']
        key = (image_id, instance_id)
        crop_path = os.path.join(CROPS_DIR, subdir, f"{kind}_{image_id}_{instance_id}.png")
        cv2.imwrite(crop_path, crop)
        store[key] = crop_path

        # A missing / 0 parent_id means the strawberry has no linked peduncle.
        parent_id = ann.get('parent_id')
        if kind == "strawberry" and parent_id not in (None, 0):
            matching_info[key] = (image_id, parent_id)

print(f"\nExtracted {len(strawberry_crops)} strawberry crops")
print(f"Extracted {len(peduncle_crops)} peduncle crops")
print(f"Found {len(matching_info)} strawberry-peduncle matches")

In [None]:
# Build labelled (strawberry_path, peduncle_path, label) pairs and split them into train / val.
NEGATIVES_PER_STRAWBERRY = 2

positive_pairs = []
negative_pairs = []

print("Creating training pairs...")

# Positive pairs: strawberry with its parent peduncle
for strawberry_id, peduncle_id in matching_info.items():
    if strawberry_id in strawberry_crops and peduncle_id in peduncle_crops:
        positive_pairs.append((
            strawberry_crops[strawberry_id],
            peduncle_crops[peduncle_id],
            1  # Label: match
        ))

# Negative pairs: strawberry with randomly chosen *other* peduncles.
# Drawing NEGATIVES_PER_STRAWBERRY + 1 distinct candidates and dropping the correct one
# is still uniform over the wrong peduncles, without rebuilding an O(#peduncles) list
# for every strawberry.
peduncle_ids = list(peduncle_crops.keys())
for strawberry_id, correct_peduncle_id in matching_info.items():
    if strawberry_id not in strawberry_crops:
        continue

    n_draw = min(NEGATIVES_PER_STRAWBERRY + 1, len(peduncle_ids))
    candidates = random.sample(peduncle_ids, n_draw)
    neg_peduncles = [p for p in candidates if p != correct_peduncle_id][:NEGATIVES_PER_STRAWBERRY]
    if len(neg_peduncles) < NEGATIVES_PER_STRAWBERRY:
        continue  # not enough distinct wrong peduncles available

    for neg_ped_id in neg_peduncles:
        negative_pairs.append((
            strawberry_crops[strawberry_id],
            peduncle_crops[neg_ped_id],
            0  # Label: no match
        ))

print(f"Positive pairs: {len(positive_pairs)}")
print(f"Negative pairs: {len(negative_pairs)}")

# Combine, then split *by strawberry crop*. Every pair involving a given strawberry lands
# on the same side. A pair-level split lets one crop appear in both train and val
# (e.g. its positive in val and its negatives in train), which inflates val metrics.
# NOTE(review): peduncle crops drawn as random negatives can still appear on both sides.
all_pairs = positive_pairs + negative_pairs
random.shuffle(all_pairs)

strawberry_groups = sorted({pair[0] for pair in all_pairs})
train_groups, _val_groups = train_test_split(strawberry_groups, test_size=0.2, random_state=42)
train_group_set = set(train_groups)
train_pairs = [pair for pair in all_pairs if pair[0] in train_group_set]
val_pairs = [pair for pair in all_pairs if pair[0] not in train_group_set]
print(f"\nTrain pairs: {len(train_pairs)}")
print(f"Val pairs: {len(val_pairs)}")

In [None]:
# Visualize up to 8 positive (strawberry, peduncle) pairs on a 4x4 grid, two pairs per row.
MAX_PAIRS_SHOWN = 8
fig, axes = plt.subplots(4, 4, figsize=(16, 16))
for ax in axes.flat:
    ax.axis('off')  # panels stay blank when fewer than MAX_PAIRS_SHOWN pairs exist

for idx, (strawberry_path, peduncle_path, _label) in enumerate(positive_pairs[:MAX_PAIRS_SHOWN]):
    # Crops are stored as BGR by OpenCV; matplotlib expects RGB.
    strawberry_img = cv2.cvtColor(cv2.imread(strawberry_path), cv2.COLOR_BGR2RGB)
    peduncle_img = cv2.cvtColor(cv2.imread(peduncle_path), cv2.COLOR_BGR2RGB)

    row = idx // 2
    col = (idx % 2) * 2

    axes[row, col].imshow(strawberry_img)
    axes[row, col].set_title('Strawberry', fontsize=10, color='green')

    axes[row, col + 1].imshow(peduncle_img)
    axes[row, col + 1].set_title('Peduncle (MATCH)', fontsize=10, color='green')

plt.suptitle('Sample Positive Pairs (Matching)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig('/kaggle/working/matching_pairs.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Create Dataset and DataLoaders

In [None]:
# Pair Dataset
class MatchingDataset(Dataset):
    """Dataset of (strawberry crop, peduncle crop, match label) samples.

    Args:
        pairs: sequence of ``(strawberry_path, peduncle_path, label)`` tuples,
            where ``label`` is 1 for a matching pair and 0 otherwise.
        transform: optional callable invoked as ``transform(image=array)['image']``
            (albumentations style); called separately on each of the two crops.

    ``__getitem__`` returns ``(strawberry, peduncle, label)`` with the label as a
    float32 tensor.
    """

    def __init__(self, pairs, transform=None):
        self.pairs = pairs
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Read an image with OpenCV and convert it from BGR to RGB."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None; surface the offending path
            # instead of letting cvtColor fail with an opaque assertion.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        strawberry_path, peduncle_path, label = self.pairs[idx]

        strawberry = self._load_rgb(strawberry_path)
        peduncle = self._load_rgb(peduncle_path)

        if self.transform:
            strawberry = self.transform(image=strawberry)['image']
            peduncle = self.transform(image=peduncle)['image']

        return strawberry, peduncle, torch.tensor(label, dtype=torch.float32)

# Transforms: both pipelines resize to IMG_SIZE and apply mean/std normalization;
# only training adds random flips and brightness/contrast jitter.
IMG_SIZE = 128
BATCH_SIZE = 32
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

train_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=NORM_MEAN, std=NORM_STD),
    ToTensorV2()
])

# Create datasets
train_dataset = MatchingDataset(train_pairs, transform=train_transform)
val_dataset = MatchingDataset(val_pairs, transform=val_transform)

# Create dataloaders.
# drop_last avoids a ragged (possibly size-1) final training batch; it is only enabled
# when there is at least one full batch, so a tiny train set still yields batches.
use_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=use_pin_memory,
    drop_last=len(train_dataset) >= BATCH_SIZE,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
    pin_memory=use_pin_memory,
)

print("DataLoaders created")

## 5. Build Siamese Network

In [None]:
# Siamese Network with shared encoder
class SiameseNetwork(nn.Module):
    """Siamese matcher: a shared encoder embeds both crops and an MLP scores the pair.

    Args:
        embedding_dim: size of the L2-normalized embedding produced for each crop.

    ``forward(x1, x2)`` returns one match probability in [0, 1] per pair, with
    shape ``(batch,)``. The shape is kept even when the batch holds a single pair.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()

        # Shared pretrained encoder; num_classes=0 makes timm return pooled features.
        self.encoder = timm.create_model('efficientnet_b0', pretrained=True, num_classes=0)
        encoder_dim = self._infer_encoder_dim()

        # Embedding head
        self.embedding_head = nn.Sequential(
            nn.Linear(encoder_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, embedding_dim)
        )

        # Matching head: concatenated pair embedding -> match probability
        self.matching_head = nn.Sequential(
            nn.Linear(embedding_dim * 2, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def _infer_encoder_dim(self, input_size=128):
        """Return the encoder's output feature width without touching its BatchNorm stats."""
        num_features = getattr(self.encoder, 'num_features', None)
        if num_features:
            return int(num_features)
        # Fallback: probe with a dummy batch in eval mode. A train-mode forward pass,
        # even under no_grad, would update the pretrained BatchNorm running statistics.
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            dim = self.encoder(torch.randn(1, 3, input_size, input_size)).shape[1]
        self.encoder.train(was_training)
        return dim

    def forward_one(self, x):
        """Embed a batch of crops into L2-normalized vectors."""
        x = self.encoder(x)
        x = self.embedding_head(x)
        return F.normalize(x, p=2, dim=1)

    def forward(self, x1, x2):
        """Return per-pair match probabilities for two aligned batches of crops."""
        emb1 = self.forward_one(x1)
        emb2 = self.forward_one(x2)

        # Concatenate embeddings
        combined = torch.cat([emb1, emb2], dim=1)

        # Predict match probability. squeeze(-1) only drops the trailing size-1 dim;
        # a bare squeeze() would return a 0-d tensor for a batch of 1, which BCELoss
        # rejects against a (1,)-shaped target.
        output = self.matching_head(combined)
        return output.squeeze(-1)

model = SiameseNetwork(embedding_dim=128).to(device)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

## 6. Training

In [None]:
# Loss / optimizer / LR schedule.
# Plain BCELoss (not BCEWithLogitsLoss) because SiameseNetwork already ends in a Sigmoid.
criterion = nn.BCELoss()
optimizer = optim.AdamW(model.parameters(), weight_decay=1e-4, lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Args:
        model: callable as ``model(strawberry, peduncle)`` returning match probabilities.
        loader: iterable of ``(strawberry, peduncle, labels)`` batches.
        criterion: loss on (probabilities, labels); assumed to use mean reduction
            (the default for ``nn.BCELoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent)``. The loss is averaged per sample; accuracy
        uses predictions thresholded at 0.5.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []

    for strawberry, peduncle, labels in tqdm(loader, desc="Training"):
        strawberry = strawberry.to(device)
        peduncle = peduncle.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        # reshape guards against a 0-d output when a batch holds a single pair
        outputs = model(strawberry, peduncle).reshape(labels.shape)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight by batch size so a short final batch does not skew the epoch mean.
        total_loss += loss.item() * batch_size
        n_samples += batch_size
        all_preds.extend((outputs > 0.5).cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("train_epoch received a loader with no samples")

    acc = accuracy_score(all_labels, all_preds)
    return total_loss / n_samples, acc * 100

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` in eval mode without gradient tracking.

    Args:
        model: callable as ``model(strawberry, peduncle)`` returning match probabilities.
        loader: iterable of ``(strawberry, peduncle, labels)`` batches.
        criterion: loss on (probabilities, labels); assumed to use mean reduction
            (the default for ``nn.BCELoss``).
        device: device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy_percent, preds, probs, labels)``:
        - the loss is averaged per sample;
        - ``preds`` are the 0.5-thresholded booleans;
        - ``probs`` are the model outputs;
        - ``labels`` are the targets.
        The three lists are flat and in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for strawberry, peduncle, labels in tqdm(loader, desc="Validation"):
            strawberry = strawberry.to(device)
            peduncle = peduncle.to(device)
            labels = labels.to(device)

            # reshape guards against a 0-d output when a batch holds a single pair
            outputs = model(strawberry, peduncle).reshape(labels.shape)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            n_samples += batch_size
            all_probs.extend(outputs.cpu().numpy())
            all_preds.extend((outputs > 0.5).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if n_samples == 0:
        raise ValueError("validate received a loader with no samples")

    acc = accuracy_score(all_labels, all_preds)
    return total_loss / n_samples, acc * 100, all_preds, all_probs, all_labels

In [None]:
# Train model for the length of the cosine LR schedule, checkpointing the epoch with
# the best validation accuracy.
num_epochs = scheduler.T_max  # derived from the scheduler so the two cannot drift apart
BEST_MODEL_PATH = '/kaggle/working/best_matching_model.pth'
best_acc = 0.0
history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, val_preds, val_probs, val_labels = validate(model, val_loader, criterion, device)

    scheduler.step()

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"✓ Best model saved (acc: {best_acc:.2f}%)")

print(f"\nTraining complete! Best validation accuracy: {best_acc:.2f}%")

## 7. Evaluate and Visualize

In [None]:
# Plot loss and accuracy curves for every epoch actually recorded in `history`.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Derived from history (not num_epochs) so the plot still works if training was interrupted.
epochs_range = range(1, len(history['train_loss']) + 1)

panels = [
    (ax1, 'loss', 'Loss', 'Loss', 'Training and Validation Loss'),
    (ax2, 'acc', 'Acc', 'Accuracy (%)', 'Training and Validation Accuracy'),
]
for ax, key, short_name, ylabel, title in panels:
    ax.plot(epochs_range, history[f'train_{key}'], label=f'Train {short_name}', linewidth=2)
    ax.plot(epochs_range, history[f'val_{key}'], label=f'Val {short_name}', linewidth=2)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.savefig('/kaggle/working/matching_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Load the best checkpoint onto the current device and evaluate it on the validation split.
model.load_state_dict(torch.load('/kaggle/working/best_matching_model.pth', map_location=device))
_, _, val_preds, val_probs, val_labels = validate(model, val_loader, criterion, device)

# Predictions come back as booleans and labels as float32; cast both to int so every
# sklearn metric (and the confusion matrix below) sees the same {0, 1} encoding.
y_true = np.asarray(val_labels).astype(int)
y_pred = np.asarray(val_preds).astype(int)

# Metrics
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average='binary', zero_division=0
)
# ROC-AUC is undefined when the validation split contains a single class.
auc = roc_auc_score(y_true, val_probs) if len(np.unique(y_true)) > 1 else float('nan')

print("\n=== Matching Metrics ===")
print(f"Accuracy: {accuracy_score(y_true, y_pred)*100:.2f}%")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"AUC-ROC: {auc:.4f}")

# Confusion matrix; labels=[0, 1] keeps it 2x2 (matching the two tick labels) even if a
# class never occurs in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['No Match', 'Match'], yticklabels=['No Match', 'Match'])
plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('/kaggle/working/matching_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 8. Save Results

In [None]:
# Collect dataset sizes, training settings and final validation metrics into one JSON file.
# NOTE: "best_val_acc" is in percent (as tracked during training), whereas
# "metrics.accuracy" is a 0-1 fraction.
dataset_stats = {
    "total_pairs": len(all_pairs),
    "train_pairs": len(train_pairs),
    "val_pairs": len(val_pairs),
    "positive_pairs": len(positive_pairs),
    "negative_pairs": len(negative_pairs),
}
training_stats = {
    "epochs": num_epochs,
    "best_val_acc": float(best_acc),
}
final_metrics = {
    "accuracy": float(accuracy_score(val_labels, val_preds)),
    "precision": float(precision),
    "recall": float(recall),
    "f1_score": float(f1),
    "auc_roc": float(auc),
}
summary = {
    "model": "Siamese Network (EfficientNet-B0)",
    "dataset": dataset_stats,
    "training": training_stats,
    "metrics": final_metrics,
}

summary_path = '/kaggle/working/matching_summary.json'
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print("Summary saved!")
print(json.dumps(summary, indent=2))