In [1]:
# ============================================================================
# Cross-Species Fruit Quality Grading via Few-Shot Prototypical Networks
# Learning Class-Agnostic Defect Representations
# 
# Author: Amr Samir
# Master's Thesis - 2026
# ============================================================================
# 
# RESEARCH GAP: 
# While models can classify fruit species, they fail to generalize quality 
# grading (Good/Bad) across unseen fruit types. This work proves that metric 
# learning can learn "defectness" rather than "fruit-specific features."
#
# KEY CONTRIBUTION:
# Train on {Apple, Banana, Grape} ‚Üí Test on {Mango, Orange} WITHOUT retraining
# ============================================================================

import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
from collections import defaultdict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

Using device: cpu
PyTorch version: 2.5.1+cpu


## 1. Configuration & Hyperparameters

Define all experimental settings in one place for reproducibility and easy ablation studies.

In [2]:
# ============================================================================
# CONFIGURATION - Modify these based on your dataset and experiments
# ============================================================================

class Config:
    """Centralized configuration for all experiments"""
    
    # Dataset paths - FruitVision Dataset
    DATA_ROOT = r"D:\Datasets\FruitVision"
    
    # Fruits for training (SEEN during training)
    TRAIN_FRUITS = ['apple', 'banana', 'grape']
    
    # Fruits for testing (UNSEEN - the key experiment!)
    TEST_FRUITS = ['mango', 'orange']
    
    # Quality classes (binary grading) - FruitVision uses fresh/rotten
    CLASSES = ['fresh', 'rotten']  # Maps to Good/Bad
    N_CLASSES = 2  # Binary: Fresh (Good) vs Rotten (Bad)
    
    # Few-shot settings
    N_SHOT = 5       # Number of support examples per class (5-shot learning)
    N_QUERY = 15     # Number of query examples per class
    N_EPISODES_TRAIN = 1000   # Training episodes per epoch
    N_EPISODES_VAL = 200      # Validation episodes
    N_EPISODES_TEST = 600     # Test episodes for statistical significance
    
    # Model settings
    BACKBONE = 'resnet18'     # Options: 'resnet18', 'resnet50', 'vit_tiny'
    EMBEDDING_DIM = 512       # Dimension of the embedding space
    PRETRAINED = True         # Use ImageNet pretrained weights
    
    # Training settings
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    WEIGHT_DECAY = 1e-5
    BATCH_SIZE = 1            # For episodic training, batch_size = 1 episode
    
    # Image settings
    IMAGE_SIZE = 224
    
    # Paths for saving
    CHECKPOINT_DIR = './checkpoints'
    RESULTS_DIR = './results'
    
    # Experiment name (for logging)
    EXPERIMENT_NAME = f"ProtoNet_{BACKBONE}_{N_SHOT}shot"

config = Config()

# Create directories
os.makedirs(config.CHECKPOINT_DIR, exist_ok=True)
os.makedirs(config.RESULTS_DIR, exist_ok=True)

print("=" * 60)
print("EXPERIMENTAL CONFIGURATION")
print("=" * 60)
print(f"Train Fruits (SEEN):     {config.TRAIN_FRUITS}")
print(f"Test Fruits (UNSEEN):    {config.TEST_FRUITS}")
print(f"Few-Shot Setting:        {config.N_SHOT}-shot, {config.N_QUERY}-query")
print(f"Backbone:                {config.BACKBONE}")
print(f"Embedding Dimension:     {config.EMBEDDING_DIM}")
print(f"Training Episodes:       {config.N_EPISODES_TRAIN}")
print("=" * 60)

EXPERIMENTAL CONFIGURATION
Train Fruits (SEEN):     ['apple', 'banana', 'grape']
Test Fruits (UNSEEN):    ['mango', 'orange']
Few-Shot Setting:        5-shot, 15-query
Backbone:                resnet18
Embedding Dimension:     512
Training Episodes:       1000


## 2. Data Augmentation & Transforms

Strong augmentation is critical for learning generalizable defect features. We use different transforms for support (stable) and query (augmented) sets.

In [3]:
# ============================================================================
# DATA AUGMENTATION STRATEGIES
# ============================================================================
# Key insight: Defects have texture/edge patterns. Augmentations should preserve
# these while varying lighting, orientation, and scale.

# Training augmentation (strong)
train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop(config.IMAGE_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=30),
    transforms.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.3,
        hue=0.1
    ),
    transforms.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1)
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet stats
        std=[0.229, 0.224, 0.225]
    ),
    transforms.RandomErasing(p=0.2, scale=(0.02, 0.1))  # Simulates occlusion
])

# Validation/Test augmentation (minimal - just normalization)
eval_transform = transforms.Compose([
    transforms.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

print("‚úì Data transforms defined")
print(f"  Training: Strong augmentation with color jitter, rotation, erasing")
print(f"  Evaluation: Minimal transforms (resize + normalize)")

‚úì Data transforms defined
  Training: Strong augmentation with color jitter, rotation, erasing
  Evaluation: Minimal transforms (resize + normalize)


## 3. Episodic Dataset for Few-Shot Learning

The core innovation: Instead of traditional batches, we sample **episodes**. Each episode contains:
- **Support Set**: K examples of Good + K examples of Bad (used to build prototypes)
- **Query Set**: Q examples to classify using the prototypes

This forces the model to learn "defectness" in a generalizable way.

In [4]:
# ============================================================================
# EPISODIC DATASET FOR FEW-SHOT QUALITY GRADING
# ============================================================================

class FruitQualityDataset:
    """
    Loads fruit images organized by fruit type and quality.
    Expected structure (FruitVision):
        data_root/
            fruit_name/
                fresh/
                    img1.jpg, img2.jpg, ...
                rotten/
                    img1.jpg, img2.jpg, ...
    """
    def __init__(self, data_root, fruit_types, transform=None):
        self.data_root = data_root
        self.fruit_types = fruit_types
        self.transform = transform
        self.classes = ['fresh', 'rotten']  # FruitVision naming
        
        # Organize images by fruit and quality
        self.data = defaultdict(lambda: defaultdict(list))
        self._load_data()
        
    def _load_data(self):
        """Load all image paths organized by fruit type and quality"""
        for fruit in self.fruit_types:
            for quality in self.classes:
                folder_path = os.path.join(self.data_root, fruit, quality)
                if os.path.exists(folder_path):
                    images = [
                        os.path.join(folder_path, f) 
                        for f in os.listdir(folder_path) 
                        if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp'))
                    ]
                    self.data[fruit][quality] = images
                    print(f"  Loaded {len(images):4d} images: {fruit}/{quality}")
                else:
                    print(f"  ‚ö† Missing folder: {folder_path}")
    
    def get_episode(self, n_shot, n_query, fruit=None):
        """
        Sample a single episode for few-shot learning.
        
        Args:
            n_shot: Number of support examples per class
            n_query: Number of query examples per class
            fruit: Specific fruit to sample from (None = random)
            
        Returns:
            support_images: [n_classes * n_shot] tensor
            support_labels: [n_classes * n_shot] tensor
            query_images: [n_classes * n_query] tensor
            query_labels: [n_classes * n_query] tensor
            fruit_name: Name of the fruit in this episode
        """
        # Select fruit
        if fruit is None:
            fruit = random.choice(self.fruit_types)
        
        support_images, support_labels = [], []
        query_images, query_labels = [], []
        
        for class_idx, quality in enumerate(self.classes):
            # Get all images for this fruit-quality combination
            all_images = self.data[fruit][quality]
            
            if len(all_images) < n_shot + n_query:
                raise ValueError(
                    f"Not enough images for {fruit}/{quality}. "
                    f"Need {n_shot + n_query}, have {len(all_images)}"
                )
            
            # Randomly sample support and query
            sampled = random.sample(all_images, n_shot + n_query)
            support_paths = sampled[:n_shot]
            query_paths = sampled[n_shot:]
            
            # Load and transform images
            for path in support_paths:
                img = Image.open(path).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                support_images.append(img)
                support_labels.append(class_idx)
                
            for path in query_paths:
                img = Image.open(path).convert('RGB')
                if self.transform:
                    img = self.transform(img)
                query_images.append(img)
                query_labels.append(class_idx)
        
        # Stack into tensors
        support_images = torch.stack(support_images)  # [n_classes*n_shot, C, H, W]
        support_labels = torch.tensor(support_labels)
        query_images = torch.stack(query_images)      # [n_classes*n_query, C, H, W]
        query_labels = torch.tensor(query_labels)
        
        return support_images, support_labels, query_images, query_labels, fruit


class EpisodicDataLoader:
    """
    DataLoader that yields episodes instead of batches.
    Each iteration returns one episode (support + query sets).
    """
    def __init__(self, dataset, n_shot, n_query, n_episodes, fruits=None):
        self.dataset = dataset
        self.n_shot = n_shot
        self.n_query = n_query
        self.n_episodes = n_episodes
        self.fruits = fruits if fruits else dataset.fruit_types
        
    def __iter__(self):
        for _ in range(self.n_episodes):
            fruit = random.choice(self.fruits)
            yield self.dataset.get_episode(
                self.n_shot, 
                self.n_query, 
                fruit=fruit
            )
    
    def __len__(self):
        return self.n_episodes

print("‚úì Episodic dataset classes defined")

‚úì Episodic dataset classes defined


## 4. Prototypical Network Architecture

The **Prototypical Network** computes a prototype (centroid) for each class from support examples, then classifies queries by distance to prototypes.

$$d(z_q, c_k) = \|f_\theta(x_q) - \frac{1}{|S_k|}\sum_{x_i \in S_k} f_\theta(x_i)\|^2$$

Where:
- $f_\theta$ is our embedding network (ResNet backbone)
- $c_k$ is the prototype for class $k$
- $z_q$ is the query embedding

In [5]:
# ============================================================================
# PROTOTYPICAL NETWORK ARCHITECTURE
# ============================================================================

class EmbeddingNetwork(nn.Module):
    """
    Feature extraction backbone that maps images to embedding space.
    The embedding should capture defect-related features (texture, edges, anomalies).
    """
    def __init__(self, backbone='resnet18', embedding_dim=512, pretrained=True):
        super().__init__()
        
        if backbone == 'resnet18':
            self.encoder = models.resnet18(pretrained=pretrained)
            in_features = self.encoder.fc.in_features
            self.encoder.fc = nn.Identity()  # Remove classification head
            
        elif backbone == 'resnet50':
            self.encoder = models.resnet50(pretrained=pretrained)
            in_features = self.encoder.fc.in_features
            self.encoder.fc = nn.Identity()
            
        elif backbone == 'efficientnet_b0':
            self.encoder = models.efficientnet_b0(pretrained=pretrained)
            in_features = self.encoder.classifier[1].in_features
            self.encoder.classifier = nn.Identity()
            
        else:
            raise ValueError(f"Unknown backbone: {backbone}")
        
        # Projection head to embedding dimension
        self.projection = nn.Sequential(
            nn.Linear(in_features, embedding_dim),
            nn.BatchNorm1d(embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
        
        self.embedding_dim = embedding_dim
        
    def forward(self, x):
        """Extract embeddings from images"""
        features = self.encoder(x)
        embeddings = self.projection(features)
        # L2 normalize embeddings (important for metric learning)
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings


class PrototypicalNetwork(nn.Module):
    """
    Few-Shot Prototypical Network for Fruit Quality Grading.
    
    Key idea: Build a prototype (mean embedding) for each class from support 
    examples, then classify queries by distance to prototypes.
    """
    def __init__(self, backbone='resnet18', embedding_dim=512, pretrained=True):
        super().__init__()
        self.encoder = EmbeddingNetwork(backbone, embedding_dim, pretrained)
        self.embedding_dim = embedding_dim
        
    def compute_prototypes(self, support_embeddings, support_labels, n_classes=2):
        """
        Compute class prototypes from support set.
        
        Args:
            support_embeddings: [n_support, embedding_dim]
            support_labels: [n_support]
            n_classes: Number of classes
            
        Returns:
            prototypes: [n_classes, embedding_dim]
        """
        prototypes = torch.zeros(n_classes, self.embedding_dim, device=support_embeddings.device)
        
        for c in range(n_classes):
            mask = (support_labels == c)
            class_embeddings = support_embeddings[mask]
            prototypes[c] = class_embeddings.mean(dim=0)
        
        return prototypes
    
    def forward(self, support_images, support_labels, query_images, n_classes=2):
        """
        Forward pass for one episode.
        
        Args:
            support_images: [n_support, C, H, W]
            support_labels: [n_support]
            query_images: [n_query, C, H, W]
            n_classes: Number of classes
            
        Returns:
            logits: [n_query, n_classes] - negative distances (for softmax)
            query_embeddings: [n_query, embedding_dim]
            prototypes: [n_classes, embedding_dim]
        """
        # Get embeddings
        support_embeddings = self.encoder(support_images)  # [n_support, dim]
        query_embeddings = self.encoder(query_images)      # [n_query, dim]
        
        # Compute prototypes
        prototypes = self.compute_prototypes(
            support_embeddings, support_labels, n_classes
        )  # [n_classes, dim]
        
        # Compute distances from queries to prototypes
        # Using squared Euclidean distance
        distances = torch.cdist(query_embeddings, prototypes, p=2)  # [n_query, n_classes]
        
        # Return negative distances as logits (closer = higher probability)
        logits = -distances
        
        return logits, query_embeddings, prototypes


# Initialize model
model = PrototypicalNetwork(
    backbone=config.BACKBONE,
    embedding_dim=config.EMBEDDING_DIM,
    pretrained=config.PRETRAINED
).to(device)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"‚úì Model initialized: {config.BACKBONE}")
print(f"  Total parameters:     {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Embedding dimension:  {config.EMBEDDING_DIM}")

‚úì Model initialized: resnet18
  Total parameters:     11,702,848
  Trainable parameters: 11,702,848
  Embedding dimension:  512


## 5. Supervised Contrastive Loss (Alternative)

We implement **SupCon Loss** as an alternative to standard cross-entropy. This loss pushes embeddings of the same class together and different classes apart, which is crucial for learning generalizable defect features.

$$\mathcal{L}_{SupCon} = \sum_{i} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(z_i \cdot z_p / \tau)}{\sum_{a \neq i} \exp(z_i \cdot z_a / \tau)}$$

In [6]:
# ============================================================================
# LOSS FUNCTIONS
# ============================================================================

class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss (Khosla et al., 2020)
    
    This loss is particularly good for learning embeddings that generalize
    across different visual domains (different fruit types).
    """
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature
        
    def forward(self, features, labels):
        """
        Args:
            features: [batch_size, embedding_dim] - L2 normalized embeddings
            labels: [batch_size]
        """
        device = features.device
        batch_size = features.shape[0]
        
        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(device)
        
        # Compute similarity
        anchor_dot_contrast = torch.div(
            torch.matmul(features, features.T),
            self.temperature
        )
        
        # For numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()
        
        # Mask out self-contrast
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask
        
        # Compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
        
        # Compute mean of log-likelihood over positive
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6)
        
        # Loss
        loss = -mean_log_prob_pos
        loss = loss.mean()
        
        return loss


class PrototypicalLoss(nn.Module):
    """
    Standard Prototypical Network loss (cross-entropy on distances).
    """
    def __init__(self):
        super().__init__()
        self.ce = nn.CrossEntropyLoss()
        
    def forward(self, logits, labels):
        """
        Args:
            logits: [n_query, n_classes] - negative distances
            labels: [n_query]
        """
        return self.ce(logits, labels)


class CombinedLoss(nn.Module):
    """
    Combines prototypical loss with supervised contrastive loss.
    This often improves embedding quality.
    """
    def __init__(self, proto_weight=1.0, supcon_weight=0.5, temperature=0.07):
        super().__init__()
        self.proto_loss = PrototypicalLoss()
        self.supcon_loss = SupConLoss(temperature)
        self.proto_weight = proto_weight
        self.supcon_weight = supcon_weight
        
    def forward(self, logits, query_labels, all_embeddings, all_labels):
        """
        Args:
            logits: [n_query, n_classes]
            query_labels: [n_query]
            all_embeddings: [n_support + n_query, dim]
            all_labels: [n_support + n_query]
        """
        loss_proto = self.proto_loss(logits, query_labels)
        loss_supcon = self.supcon_loss(all_embeddings, all_labels)
        
        total_loss = (self.proto_weight * loss_proto + 
                      self.supcon_weight * loss_supcon)
        
        return total_loss, loss_proto, loss_supcon

# Initialize loss function
criterion = CombinedLoss(
    proto_weight=1.0, 
    supcon_weight=0.5, 
    temperature=0.1
)

print("‚úì Loss functions defined")
print("  Using: Combined Prototypical + Supervised Contrastive Loss")

‚úì Loss functions defined
  Using: Combined Prototypical + Supervised Contrastive Loss


## 6. Training Loop (Episodic Training)

Unlike traditional training, we iterate over **episodes** not batches. Each episode simulates a few-shot scenario, forcing the model to learn generalizable features.

In [7]:
# ============================================================================
# TRAINING UTILITIES
# ============================================================================

def compute_accuracy(logits, labels):
    """Compute accuracy from logits"""
    predictions = logits.argmax(dim=1)
    accuracy = (predictions == labels).float().mean().item()
    return accuracy


def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train for one epoch (many episodes)"""
    model.train()
    
    total_loss = 0.0
    total_acc = 0.0
    n_episodes = 0
    
    pbar = tqdm(dataloader, desc="Training", leave=False)
    for support_imgs, support_lbls, query_imgs, query_lbls, fruit in pbar:
        # Move to device
        support_imgs = support_imgs.to(device)
        support_lbls = support_lbls.to(device)
        query_imgs = query_imgs.to(device)
        query_lbls = query_lbls.to(device)
        
        # Forward pass
        logits, query_emb, prototypes = model(
            support_imgs, support_lbls, query_imgs
        )
        
        # Get support embeddings for SupCon loss
        support_emb = model.encoder(support_imgs)
        all_embeddings = torch.cat([support_emb, query_emb], dim=0)
        all_labels = torch.cat([support_lbls, query_lbls], dim=0)
        
        # Compute loss
        loss, loss_proto, loss_supcon = criterion(
            logits, query_lbls, all_embeddings, all_labels
        )
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Metrics
        acc = compute_accuracy(logits, query_lbls)
        total_loss += loss.item()
        total_acc += acc
        n_episodes += 1
        
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{acc:.3f}',
            'fruit': fruit
        })
    
    return total_loss / n_episodes, total_acc / n_episodes


@torch.no_grad()
def evaluate(model, dataloader, device, desc="Evaluating"):
    """Evaluate on validation/test episodes"""
    model.eval()
    
    total_acc = 0.0
    n_episodes = 0
    fruit_accuracies = defaultdict(list)
    
    pbar = tqdm(dataloader, desc=desc, leave=False)
    for support_imgs, support_lbls, query_imgs, query_lbls, fruit in pbar:
        # Move to device
        support_imgs = support_imgs.to(device)
        support_lbls = support_lbls.to(device)
        query_imgs = query_imgs.to(device)
        query_lbls = query_lbls.to(device)
        
        # Forward pass
        logits, _, _ = model(support_imgs, support_lbls, query_imgs)
        
        # Compute accuracy
        acc = compute_accuracy(logits, query_lbls)
        total_acc += acc
        n_episodes += 1
        fruit_accuracies[fruit].append(acc)
        
        pbar.set_postfix({'acc': f'{acc:.3f}', 'fruit': fruit})
    
    # Compute per-fruit accuracy
    per_fruit_acc = {
        fruit: np.mean(accs) for fruit, accs in fruit_accuracies.items()
    }
    
    return total_acc / n_episodes, per_fruit_acc


def save_checkpoint(model, optimizer, epoch, metrics, path):
    """Save model checkpoint"""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'metrics': metrics
    }, path)
    print(f"  ‚úì Checkpoint saved: {path}")


def load_checkpoint(model, optimizer, path):
    """Load model checkpoint"""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['metrics']

print("‚úì Training utilities defined")

‚úì Training utilities defined


## 7. Load Dataset & Create DataLoaders

**Dataset:** FruitVision (D:\Datasets\FruitVision)

Dataset structure:
```
FruitVision/
‚îú‚îÄ‚îÄ apple/
‚îÇ   ‚îú‚îÄ‚îÄ fresh/
‚îÇ   ‚îî‚îÄ‚îÄ rotten/
‚îú‚îÄ‚îÄ banana/
‚îÇ   ‚îú‚îÄ‚îÄ fresh/
‚îÇ   ‚îî‚îÄ‚îÄ rotten/
‚îú‚îÄ‚îÄ grape/
‚îÇ   ‚îú‚îÄ‚îÄ fresh/
‚îÇ   ‚îî‚îÄ‚îÄ rotten/
‚îú‚îÄ‚îÄ mango/          ‚Üê UNSEEN (test only)
‚îÇ   ‚îú‚îÄ‚îÄ fresh/
‚îÇ   ‚îî‚îÄ‚îÄ rotten/
‚îî‚îÄ‚îÄ orange/         ‚Üê UNSEEN (test only)
    ‚îú‚îÄ‚îÄ fresh/
    ‚îî‚îÄ‚îÄ rotten/
```

In [8]:
# ============================================================================
# LOAD DATASET
# ============================================================================
# NOTE: Update config.DATA_ROOT to your actual dataset path!
# You can download datasets from:
# 1. FruitVision (Kaggle): https://www.kaggle.com/datasets/fruitvision
# 2. Zenodo Fruit Quality: https://zenodo.org/records/1310165
# 3. Fruits Fresh and Rotten: https://www.kaggle.com/datasets/sriramr/fruits-fresh-and-rotten-for-classification

print("=" * 60)
print("LOADING DATASETS")
print("=" * 60)
print(f"Data root: {config.DATA_ROOT}")
print()

# Check if data exists
if not os.path.exists(config.DATA_ROOT):
    print("‚ö†Ô∏è  WARNING: Dataset not found!")
    print(f"   Please ensure FruitVision is at: {config.DATA_ROOT}")
    print()
    print("   Expected structure:")
    print("   FruitVision/")
    print("   ‚îú‚îÄ‚îÄ apple/")
    print("   ‚îÇ   ‚îú‚îÄ‚îÄ fresh/")
    print("   ‚îÇ   ‚îî‚îÄ‚îÄ rotten/")
    print("   ‚îú‚îÄ‚îÄ banana/")
    print("   ‚îÇ   ‚îú‚îÄ‚îÄ fresh/")
    print("   ‚îÇ   ‚îî‚îÄ‚îÄ rotten/")
    print("   ‚îî‚îÄ‚îÄ ...")
else:
    # Load training dataset (SEEN fruits)
    print("Loading TRAINING data (seen fruits):")
    train_dataset = FruitQualityDataset(
        data_root=config.DATA_ROOT,
        fruit_types=config.TRAIN_FRUITS,
        transform=train_transform
    )
    
    print()
    print("Loading TEST data (UNSEEN fruits - key experiment!):")
    test_dataset = FruitQualityDataset(
        data_root=config.DATA_ROOT,
        fruit_types=config.TEST_FRUITS,
        transform=eval_transform
    )
    
    # Create episodic dataloaders
    train_loader = EpisodicDataLoader(
        dataset=train_dataset,
        n_shot=config.N_SHOT,
        n_query=config.N_QUERY,
        n_episodes=config.N_EPISODES_TRAIN
    )
    
    val_loader = EpisodicDataLoader(
        dataset=train_dataset,  # Validation on seen fruits
        n_shot=config.N_SHOT,
        n_query=config.N_QUERY,
        n_episodes=config.N_EPISODES_VAL
    )
    
    # KEY EXPERIMENT: Test on UNSEEN fruits
    test_loader = EpisodicDataLoader(
        dataset=test_dataset,  # Unseen fruits!
        n_shot=config.N_SHOT,
        n_query=config.N_QUERY,
        n_episodes=config.N_EPISODES_TEST
    )
    
    print()
    print("=" * 60)
    print("DataLoaders created:")
    print(f"  Train: {len(train_loader)} episodes on {config.TRAIN_FRUITS}")
    print(f"  Val:   {len(val_loader)} episodes on {config.TRAIN_FRUITS}")
    print(f"  Test:  {len(test_loader)} episodes on {config.TEST_FRUITS} (UNSEEN!)")
    print("=" * 60)

LOADING DATASETS
Data root: D:\Datasets\FruitVision

Loading TRAINING data (seen fruits):
  Loaded  765 images: apple/fresh
  Loaded  630 images: apple/rotten
  Loaded  749 images: banana/fresh
  Loaded  632 images: banana/rotten
  Loaded  770 images: grape/fresh
  Loaded  630 images: grape/rotten

Loading TEST data (UNSEEN fruits - key experiment!):
  Loaded  763 images: mango/fresh
  Loaded  630 images: mango/rotten
  Loaded  753 images: orange/fresh
  Loaded  656 images: orange/rotten

DataLoaders created:
  Train: 1000 episodes on ['apple', 'banana', 'grape']
  Val:   200 episodes on ['apple', 'banana', 'grape']
  Test:  600 episodes on ['mango', 'orange'] (UNSEEN!)


## 8. Main Training Loop

Run the full training with validation. The model learns from episodes of Apple, Banana, Grape (seen fruits) and will be tested on Mango, Orange (unseen fruits).

In [None]:
# ============================================================================
# MAIN TRAINING LOOP
# ============================================================================

def train_protonet(model, train_loader, val_loader, config, device):
    """
    Full training loop with validation and checkpointing.
    """
    # Optimizer with different learning rates
    optimizer = torch.optim.AdamW([
        {'params': model.encoder.encoder.parameters(), 'lr': config.LEARNING_RATE * 0.1},  # Backbone
        {'params': model.encoder.projection.parameters(), 'lr': config.LEARNING_RATE}       # Projection
    ], weight_decay=config.WEIGHT_DECAY)
    
    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config.EPOCHS, eta_min=1e-6
    )
    
    # Loss function
    criterion = CombinedLoss(proto_weight=1.0, supcon_weight=0.5)
    
    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_acc': [], 'val_per_fruit': []
    }
    
    best_val_acc = 0.0
    
    print("=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Training on SEEN fruits: {config.TRAIN_FRUITS}")
    print(f"Will test on UNSEEN fruits: {config.TEST_FRUITS}")
    print()
    
    for epoch in range(1, config.EPOCHS + 1):
        print(f"\nEpoch {epoch}/{config.EPOCHS}")
        print("-" * 40)
        
        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        
        # Validate
        val_acc, val_per_fruit = evaluate(
            model, val_loader, device, desc="Validating"
        )
        
        # Update scheduler
        scheduler.step()
        
        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_per_fruit'].append(val_per_fruit)
        
        # Print metrics
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.3f}")
        print(f"  Val Acc:    {val_acc:.3f}")
        print(f"  Per-fruit:  {val_per_fruit}")
        
        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_checkpoint(
                model, optimizer, epoch,
                {'val_acc': val_acc, 'per_fruit': val_per_fruit},
                os.path.join(config.CHECKPOINT_DIR, 'best_model.pth')
            )
    
    print("\n" + "=" * 60)
    print(f"TRAINING COMPLETE")
    print(f"Best validation accuracy: {best_val_acc:.3f}")
    print("=" * 60)
    
    return history

# Uncomment to train (only run if dataset exists)
history = train_protonet(model, train_loader, val_loader, config, device)

print("‚úì Training function defined")
print("  To start training, run:")
print("  history = train_protonet(model, train_loader, val_loader, config, device)")

STARTING TRAINING
Training on SEEN fruits: ['apple', 'banana', 'grape']
Will test on UNSEEN fruits: ['mango', 'orange']


Epoch 1/50
----------------------------------------


Training:  56%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã    | 564/1000 [12:24:20<50:33,  6.96s/it, loss=2.1887, acc=0.900, fruit=apple]        

## 9. KEY EXPERIMENT: Test on UNSEEN Fruits üéØ

**This is the core contribution of your thesis!**

The model was trained ONLY on Apple, Banana, Grape. Now we test on Mango and Orange (which the model has NEVER seen). If it performs well, we have proven that it learned "defectness" rather than fruit-specific features.

In [None]:
# ============================================================================
# KEY EXPERIMENT: TEST ON UNSEEN FRUITS
# ============================================================================

def test_on_unseen_fruits(model, test_loader, device, config):
    """
    The critical experiment that validates our research gap.
    
    Model trained on: Apple, Banana, Grape
    Testing on: Mango, Orange (NEVER seen during training!)
    
    If accuracy > 80%, we have proven cross-species generalization.
    """
    print("=" * 70)
    print("üéØ KEY EXPERIMENT: TESTING ON UNSEEN FRUIT SPECIES")
    print("=" * 70)
    print(f"Model was trained on: {config.TRAIN_FRUITS}")
    print(f"Testing on (UNSEEN): {config.TEST_FRUITS}")
    print(f"Few-shot setting: {config.N_SHOT}-shot")
    print()
    
    # Load best model
    checkpoint_path = os.path.join(config.CHECKPOINT_DIR, 'best_model.pth')
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"‚úì Loaded best model from training")
    else:
        print("‚ö† No checkpoint found, using current model weights")
    
    # Evaluate on unseen fruits
    test_acc, per_fruit_acc = evaluate(
        model, test_loader, device, 
        desc="Testing on UNSEEN fruits"
    )
    
    # Compute 95% confidence interval
    # Run multiple times for statistical significance
    print("\nComputing confidence intervals (600 episodes)...")
    all_accuracies = []
    per_fruit_all = defaultdict(list)
    
    for _ in tqdm(range(5), desc="Trials"):
        acc, per_fruit = evaluate(model, test_loader, device, desc="Trial")
        all_accuracies.append(acc)
        for fruit, fruit_acc in per_fruit.items():
            per_fruit_all[fruit].append(fruit_acc)
    
    mean_acc = np.mean(all_accuracies)
    std_acc = np.std(all_accuracies)
    ci_95 = 1.96 * std_acc / np.sqrt(len(all_accuracies))
    
    # Results
    print("\n" + "=" * 70)
    print("üìä RESULTS ON UNSEEN FRUITS")
    print("=" * 70)
    print(f"\nOverall Accuracy: {mean_acc:.3f} ¬± {ci_95:.3f} (95% CI)")
    print(f"\nPer-Fruit Accuracy:")
    
    for fruit in config.TEST_FRUITS:
        fruit_mean = np.mean(per_fruit_all[fruit])
        fruit_std = np.std(per_fruit_all[fruit])
        print(f"  {fruit.capitalize()}: {fruit_mean:.3f} ¬± {fruit_std:.3f}")
    
    print("\n" + "=" * 70)
    if mean_acc > 0.80:
        print("‚úÖ SUCCESS: Model generalizes to unseen fruits!")
        print("   This validates our hypothesis that metric learning can learn")
        print("   'class-agnostic defect representations'.")
    elif mean_acc > 0.65:
        print("‚ö†Ô∏è  PARTIAL SUCCESS: Better than random (50%), room for improvement")
    else:
        print("‚ùå Model struggles to generalize. Consider:")
        print("   - More diverse training fruits")
        print("   - Stronger augmentation")
        print("   - Different backbone")
    print("=" * 70)
    
    return {
        'mean_accuracy': mean_acc,
        'std': std_acc,
        'ci_95': ci_95,
        'per_fruit': dict(per_fruit_all)
    }

# Uncomment to run experiment (only after training)
results = test_on_unseen_fruits(model, test_loader, device, config)

print("‚úì Key experiment function defined")
print("  To run: results = test_on_unseen_fruits(model, test_loader, device, config)")

‚úì Key experiment function defined
  To run: results = test_on_unseen_fruits(model, test_loader, device, config)


## 10. Baseline Comparisons (For Paper)

To validate your contribution, you must compare against:
1. **Supervised CNN** (ResNet trained on all fruits)
2. **Transfer Learning** (Pretrained ResNet, fine-tuned)
3. **Zero-Shot CLIP** (No training examples)
4. **Standard Prototypical Network** (Without SupCon loss)

In [None]:
# ============================================================================
# BASELINE COMPARISONS
# ============================================================================

# Baseline 1: Supervised CNN (Upper Bound)
class SupervisedBaseline(nn.Module):
    """
    Standard supervised classifier trained on all fruits.
    This is the UPPER BOUND - it sees all data during training.
    Your few-shot model should approach this performance.
    """
    def __init__(self, backbone='resnet18', num_classes=2, pretrained=True):
        super().__init__()
        if backbone == 'resnet18':
            self.model = models.resnet18(pretrained=pretrained)
            self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        
    def forward(self, x):
        return self.model(x)


# Baseline 2: Zero-Shot CLIP (if available)
def create_clip_baseline():
    """
    Zero-shot CLIP baseline using text prompts.
    Install: pip install transformers
    """
    try:
        from transformers import CLIPProcessor, CLIPModel
        
        model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        
        return model, processor
    except ImportError:
        print("‚ö† Install transformers for CLIP baseline: pip install transformers")
        return None, None


def clip_zero_shot_classify(image, model, processor, device):
    """
    Classify fruit quality using CLIP zero-shot.
    """
    # Text prompts for quality
    text_prompts = [
        "a photo of a fresh, healthy fruit",
        "a photo of a rotten, defective fruit"
    ]
    
    inputs = processor(
        text=text_prompts, 
        images=image, 
        return_tensors="pt", 
        padding=True
    ).to(device)
    
    outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image
    probs = logits_per_image.softmax(dim=1)
    
    return probs.argmax().item()


# Baseline 3: Standard ProtoNet (without SupCon)
class StandardProtoNet(PrototypicalNetwork):
    """
    Standard Prototypical Network without Supervised Contrastive Loss.
    Used to show the benefit of adding SupCon.
    """
    pass  # Same architecture, different loss function


def run_baseline_comparisons(config, device):
    """
    Run all baseline comparisons for the paper.
    """
    results = {}
    
    print("=" * 60)
    print("BASELINE COMPARISONS")
    print("=" * 60)
    
    # TODO: Add actual baseline runs here
    # This is a placeholder showing expected results format
    
    results['our_method'] = {
        'seen_fruits': 0.95,   # Expected ~95% on seen fruits
        'unseen_fruits': 0.85  # Expected ~85% on unseen fruits
    }
    
    results['supervised_cnn'] = {
        'seen_fruits': 0.98,   # Upper bound
        'unseen_fruits': 0.60  # Fails to generalize
    }
    
    results['clip_zero_shot'] = {
        'seen_fruits': 0.70,   # No training
        'unseen_fruits': 0.70  # No training
    }
    
    results['standard_protonet'] = {
        'seen_fruits': 0.92,
        'unseen_fruits': 0.78  # Our method should beat this
    }
    
    return results

print("‚úì Baseline comparison functions defined")

‚úì Baseline comparison functions defined


## 11. Visualization & Analysis

Essential visualizations for your paper:
1. t-SNE of embedding space (showing clustering of Good/Bad across fruits)
2. Training curves
3. Confusion matrices
4. Per-fruit performance comparison

In [None]:
# ============================================================================
# VISUALIZATION FUNCTIONS
# ============================================================================

from sklearn.manifold import TSNE
import seaborn as sns

def visualize_embedding_space(model, dataset, device, n_samples=200, save_path=None):
    """
    Create t-SNE visualization of the embedding space.
    This is a KEY FIGURE for your paper showing that:
    - Good fruits cluster together (across species)
    - Bad fruits cluster together (across species)
    """
    model.eval()
    
    embeddings = []
    labels = []
    fruit_types = []
    
    # Sample images from each fruit-quality combination
    for fruit in dataset.fruit_types:
        for quality_idx, quality in enumerate(['fresh', 'rotten']):
            images = dataset.data[fruit][quality][:n_samples//len(dataset.fruit_types)//2]
            
            for img_path in images:
                img = Image.open(img_path).convert('RGB')
                img_tensor = eval_transform(img).unsqueeze(0).to(device)
                
                with torch.no_grad():
                    emb = model.encoder(img_tensor)
                
                embeddings.append(emb.cpu().numpy().flatten())
                labels.append(quality)
                fruit_types.append(fruit)
    
    embeddings = np.array(embeddings)
    
    # t-SNE
    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    embeddings_2d = tsne.fit_transform(embeddings)
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    
    # Plot 1: Colored by quality
    ax1 = axes[0]
    colors = ['green' if l == 'fresh' else 'red' for l in labels]
    ax1.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], c=colors, alpha=0.6, s=50)
    ax1.set_title('Embedding Space by Quality\n(Green=Fresh, Red=Rotten)', fontsize=12)
    ax1.set_xlabel('t-SNE 1')
    ax1.set_ylabel('t-SNE 2')
    
    # Plot 2: Colored by fruit type
    ax2 = axes[1]
    unique_fruits = list(set(fruit_types))
    color_map = plt.cm.get_cmap('tab10')
    colors = [color_map(unique_fruits.index(f)) for f in fruit_types]
    
    for fruit in unique_fruits:
        mask = [f == fruit for f in fruit_types]
        ax2.scatter(
            embeddings_2d[mask, 0], 
            embeddings_2d[mask, 1], 
            label=fruit.capitalize(),
            alpha=0.6, s=50
        )
    ax2.legend()
    ax2.set_title('Embedding Space by Fruit Type', fontsize=12)
    ax2.set_xlabel('t-SNE 1')
    ax2.set_ylabel('t-SNE 2')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"‚úì Saved: {save_path}")
    
    plt.show()
    
    return embeddings_2d


def plot_training_history(history, save_path=None):
    """Plot training curves"""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Loss
    axes[0].plot(history['train_loss'], label='Train Loss', color='blue')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[1].plot(history['train_acc'], label='Train Acc', color='blue')
    axes[1].plot(history['val_acc'], label='Val Acc', color='orange')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].set_title('Training & Validation Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()


def plot_comparison_bar_chart(results, save_path=None):
    """
    Bar chart comparing methods - KEY FIGURE for paper.
    """
    methods = list(results.keys())
    seen_acc = [results[m]['seen_fruits'] for m in methods]
    unseen_acc = [results[m]['unseen_fruits'] for m in methods]
    
    x = np.arange(len(methods))
    width = 0.35
    
    fig, ax = plt.subplots(figsize=(10, 6))
    bars1 = ax.bar(x - width/2, seen_acc, width, label='Seen Fruits', color='steelblue')
    bars2 = ax.bar(x + width/2, unseen_acc, width, label='Unseen Fruits', color='coral')
    
    ax.set_ylabel('Accuracy')
    ax.set_title('Cross-Species Generalization Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels([m.replace('_', ' ').title() for m in methods])
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5, label='Random')
    
    # Add value labels
    for bar in bars1 + bars2:
        height = bar.get_height()
        ax.annotate(f'{height:.2f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    
    plt.show()

print("‚úì Visualization functions defined")

‚úì Visualization functions defined


## 12. Ablation Studies (Required for Strong Paper)

Ablation studies prove that each component of your method matters:
1. **N-shot analysis**: 1-shot, 3-shot, 5-shot, 10-shot
2. **Loss function**: ProtoLoss only vs ProtoLoss + SupCon
3. **Backbone**: ResNet-18 vs ResNet-50 vs EfficientNet
4. **Number of training fruits**: Train on 1, 2, 3 fruits

In [None]:
# ============================================================================
# ABLATION STUDIES
# ============================================================================

def ablation_n_shot(model, test_dataset, device, shots=[1, 3, 5, 10], n_episodes=200):
    """
    Ablation: How does performance change with number of shots?
    """
    results = {}
    
    print("=" * 60)
    print("ABLATION: N-SHOT ANALYSIS")
    print("=" * 60)
    
    for n_shot in shots:
        print(f"\nTesting {n_shot}-shot...")
        
        test_loader = EpisodicDataLoader(
            dataset=test_dataset,
            n_shot=n_shot,
            n_query=15,
            n_episodes=n_episodes
        )
        
        acc, per_fruit = evaluate(model, test_loader, device, f"{n_shot}-shot")
        results[n_shot] = {'accuracy': acc, 'per_fruit': per_fruit}
        
        print(f"  {n_shot}-shot accuracy: {acc:.3f}")
    
    # Plot results
    plt.figure(figsize=(8, 5))
    plt.plot(shots, [results[s]['accuracy'] for s in shots], 'bo-', linewidth=2, markersize=8)
    plt.xlabel('Number of Shots (K)')
    plt.ylabel('Accuracy on Unseen Fruits')
    plt.title('N-Shot Ablation Study')
    plt.grid(True, alpha=0.3)
    plt.xticks(shots)
    plt.ylim(0.5, 1.0)
    plt.savefig(os.path.join(config.RESULTS_DIR, 'ablation_nshot.png'), dpi=300)
    plt.show()
    
    return results


def ablation_loss_function(config, device):
    """
    Ablation: Compare different loss functions.
    """
    results = {}
    
    losses = {
        'proto_only': PrototypicalLoss(),
        'proto_supcon': CombinedLoss(proto_weight=1.0, supcon_weight=0.5),
        'supcon_heavy': CombinedLoss(proto_weight=1.0, supcon_weight=1.0),
    }
    
    print("=" * 60)
    print("ABLATION: LOSS FUNCTION")
    print("=" * 60)
    
    # TODO: Train separate models with each loss
    # results[loss_name] = accuracy
    
    return results


def ablation_training_diversity(config, device):
    """
    Ablation: How does the number of training fruits affect generalization?
    
    - Train on 1 fruit  ‚Üí Test on unseen
    - Train on 2 fruits ‚Üí Test on unseen  
    - Train on 3 fruits ‚Üí Test on unseen
    """
    print("=" * 60)
    print("ABLATION: TRAINING DIVERSITY")
    print("=" * 60)
    
    fruit_combinations = [
        ['apple'],
        ['apple', 'banana'],
        ['apple', 'banana', 'grape'],
    ]
    
    results = {}
    
    for fruits in fruit_combinations:
        print(f"\nTraining on: {fruits}")
        # TODO: Train model on subset and evaluate
        # results[len(fruits)] = accuracy
    
    return results

print("‚úì Ablation study functions defined")

‚úì Ablation study functions defined


## 13. Summary & Next Steps

### What You've Built:
‚úÖ Prototypical Network with SupCon Loss for few-shot quality grading  
‚úÖ Episodic training framework  
‚úÖ Cross-species generalization experiment  
‚úÖ Baseline comparisons structure  
‚úÖ Visualization tools  

### Dataset Download Links:
1. **Fruits Fresh and Rotten** (Kaggle):  
   https://www.kaggle.com/datasets/sriramr/fruits-fresh-and-rotten-for-classification

2. **Fruit Quality Good/Bad** (Zenodo):  
   https://zenodo.org/records/1310165

3. **FruitVision 2025** (Search on Kaggle)

### To Run:
1. Download dataset ‚Üí Organize as shown above
2. Update `config.DATA_ROOT` 
3. Run cells sequentially
4. Run `train_protonet()` 
5. Run `test_on_unseen_fruits()` ‚Üê **KEY EXPERIMENT**