# Margin-Aware Contrastive Learning with SVM for Plant Disease Detection

## Paper Implementation: Contrastive Learning + SVM Classification

This notebook implements a novel approach combining:
- **SimCLR-style contrastive pretraining** for learning robust representations
- **SVM classification head** leveraging maximum margin properties
- **Application to PlantWildV2** dataset for plant disease detection

### Key Innovation
We hypothesize that SVM's maximum margin principle naturally aligns with contrastive learning's objective of separating representations in embedding space, leading to:
- Better few-shot learning performance
- Improved robustness to domain shift (hypothesized; not evaluated in this notebook)
- Enhanced feature separability

### Author: [Your Name]
### Date: 2025-10-31

## 1. Setup and Installation

In [None]:
# Install required packages (for Colab)
# NOTE: torchaudio, tensorboard and thundersvm are installed here but never imported by this notebook.
!pip install -q torch torchvision torchaudio
!pip install -q scikit-learn matplotlib seaborn
!pip install -q tensorboard pillow tqdm
!pip install -q thundersvm  # GPU-accelerated SVM; NOTE: not used below — SVMClassifier always uses scikit-learn

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.svm import LinearSVC, SVC
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from PIL import Image
import os
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# ---- Reproducibility: seed the RNGs the pipeline draws from ----
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# ---- Device selection: every tensor/model below is moved to `device` ----
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
class Config:
    """Configuration for the entire pipeline.

    Plain class attributes; the shared instance ``config = Config()`` is what
    the notebook cells read (and may overwrite, e.g. ``DATA_ROOT``).
    """

    # Data paths
    DATA_ROOT = './plantwildV2'  # Change this to your dataset path
    CHECKPOINT_DIR = './checkpoints'
    RESULTS_DIR = './results'

    # Image parameters
    IMG_SIZE = 224
    BATCH_SIZE = 128
    NUM_WORKERS = 4

    # Contrastive learning parameters
    PROJECTION_DIM = 128   # Dimension of projection head output
    EMBEDDING_DIM = 2048   # Encoder output dimension: ResNet-50 pooled features (ResNet-18 would be 512)
    TEMPERATURE = 0.5      # Temperature for NT-Xent loss

    # Training parameters - Contrastive
    CONTRASTIVE_EPOCHS = 200
    CONTRASTIVE_LR = 3e-4
    CONTRASTIVE_WEIGHT_DECAY = 1e-6

    # Training parameters - SVM
    SVM_C = 1.0  # SVM penalty parameter
    SVM_KERNEL = 'linear'  # 'linear' or 'rbf'
    SVM_MAX_ITER = 1000

    # Fine-tuning parameters (epochs and LR are shared by the linear probe and softmax fine-tuning)
    FINETUNE_EPOCHS = 50
    FINETUNE_LR = 1e-4
    FINETUNE_WEIGHT_DECAY = 1e-5  # applied by softmax fine-tuning only

    # Augmentation parameters
    COLOR_JITTER_STRENGTH = 0.5
    GAUSSIAN_BLUR_KERNEL = 23  # Gaussian-blur kernel size (odd) for contrastive augmentation

    # Experiment settings
    ENABLE_FEW_SHOT = True
    FEW_SHOT_K = [1, 5, 10, 20]  # K-shot learning scenarios
    
config = Config()

# Ensure output folders exist before any cell writes checkpoints or figures.
Path(config.CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)
Path(config.RESULTS_DIR).mkdir(parents=True, exist_ok=True)

## 3. Data Loading and Augmentation

In [None]:
class ContrastiveAugmentation:
    """Strong SimCLR-style augmentation that yields two views per image.

    Pipeline: random resized crop (scale 0.2-1.0), horizontal flip, colour
    jitter applied with p=0.8 (brightness/contrast/saturation 0.8*strength,
    hue 0.2*strength), random grayscale (p=0.2), Gaussian blur (p=0.5), then
    tensor conversion and ImageNet mean/std normalization.
    """

    def __init__(self, img_size=224, strength=0.5, blur_kernel=23):
        """
        Args:
            img_size: Output side length of each random crop.
            strength: Colour-jitter strength multiplier.
            blur_kernel: Gaussian-blur kernel size; torchvision requires it
                to be odd and positive. Defaults to 23 (the previous fixed
                value); ``Config.GAUSSIAN_BLUR_KERNEL`` can be passed here.
        """
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(img_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([
                transforms.ColorJitter(0.8*strength, 0.8*strength, 0.8*strength, 0.2*strength)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([transforms.GaussianBlur(blur_kernel, sigma=(0.1, 2.0))], p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        # Each call draws fresh random parameters, so the two views differ.
        return self.transform(x), self.transform(x)


class StandardAugmentation:
    """Single-view transform for supervised training or evaluation.

    Train mode: random resized crop, horizontal flip, ±15° rotation.
    Eval mode: resize the short side, then center-crop to ``img_size``.
    Both end with tensor conversion and ImageNet mean/std normalization.
    """

    def __init__(self, img_size=224, is_train=True, resize_size=None):
        """
        Args:
            img_size: Output side length.
            is_train: Use the stochastic training pipeline if True.
            resize_size: Eval-mode short-side resize. ``None`` keeps the
                conventional 256/224 resize-to-crop ratio (so 224 -> 256,
                matching the previous fixed value).
        """
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        if is_train:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(img_size),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(15),
                transforms.ToTensor(),
                normalize,
            ])
        else:
            if resize_size is None:
                # A fixed 256 would be smaller than the crop for img_size > 256.
                resize_size = int(round(img_size * 256 / 224))
            self.transform = transforms.Compose([
                transforms.Resize(resize_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                normalize,
            ])

    def __call__(self, x):
        return self.transform(x)


class PlantWildV2Dataset(Dataset):
    """Image-folder dataset loader for PlantWildV2.

    Expected structure:
    plantwildV2/
        class1/
            img1.jpg
            img2.jpg
        class2/
            img1.jpg
            ...

    Class indices are 0..C-1 in sorted directory-name order, counting only
    visible sub-directories, so ``self.classes[label]`` always names ``label``.
    Samples are listed in sorted path order so seeded splits are reproducible.
    """

    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp')

    def __init__(self, root_dir, transform=None, mode='contrastive'):
        """
        Args:
            root_dir: Root directory of dataset
            transform: Transform to apply. In 'contrastive' mode it must
                return a pair of views.
            mode: 'contrastive' or 'supervised'
        """
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.mode = mode

        # Get all image paths and labels
        self.samples = []
        self.class_to_idx = {}
        self.classes = []

        # Filter BEFORE enumerating: stray files (README, .DS_Store) would
        # otherwise consume an index and make labels non-contiguous; hidden
        # dirs such as .ipynb_checkpoints would otherwise become classes.
        class_dirs = sorted(
            entry for entry in self.root_dir.iterdir()
            if entry.is_dir() and not entry.name.startswith('.')
        )
        for idx, class_dir in enumerate(class_dirs):
            class_name = class_dir.name
            self.classes.append(class_name)
            self.class_to_idx[class_name] = idx

            # sorted(): directory listing order is filesystem-dependent.
            for img_path in sorted(class_dir.iterdir()):
                if img_path.is_file() and img_path.suffix.lower() in self.IMG_EXTENSIONS:
                    self.samples.append((str(img_path), idx))

        print(f"Found {len(self.samples)} images from {len(self.classes)} classes")
        print(f"Classes: {self.classes}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]

        # Context manager closes the file; convert() returns an in-memory copy.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            if self.mode == 'contrastive':
                # Return two views for contrastive learning
                view1, view2 = self.transform(image)
                return view1, view2, label
            # Return single view for supervised learning
            return self.transform(image), label

        return image, label


def create_few_shot_dataset(dataset, k_shot, num_classes, indices=None):
    """Create a k-shot subset by sampling ``k_shot`` examples per class.

    Args:
        dataset: Dataset exposing ``samples`` as ``(path, label)`` pairs.
        k_shot: Number of examples drawn (without replacement) per class.
        num_classes: Classes ``0..num_classes-1`` are sampled.
        indices: Optional iterable of dataset indices to sample from (e.g. a
            training split), so support sets never include held-out images.
            Defaults to every index in ``dataset``.

    Returns:
        ``torch.utils.data.Subset`` of ``dataset``. Classes with fewer than
        ``k_shot`` candidates are skipped and a notice is printed. Sampling
        uses NumPy's global RNG.
    """
    candidate_indices = range(len(dataset.samples)) if indices is None else indices

    # Group candidate indices by class; labels outside the range are ignored.
    class_samples = {class_idx: [] for class_idx in range(num_classes)}
    for idx in candidate_indices:
        label = dataset.samples[idx][1]
        if label in class_samples:
            class_samples[label].append(idx)

    few_shot_indices = []
    for class_idx in range(num_classes):
        pool = class_samples[class_idx]
        if len(pool) >= k_shot:
            sampled = np.random.choice(pool, k_shot, replace=False)
            few_shot_indices.extend(int(i) for i in sampled)
        else:
            print(f"  ! Skipping class {class_idx}: only {len(pool)} candidate(s) for {k_shot}-shot")

    return torch.utils.data.Subset(dataset, few_shot_indices)

## 4. Model Architectures

In [None]:
class ProjectionHead(nn.Module):
    """SimCLR projection MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Maps encoder embeddings into the space where the contrastive loss is
    computed.
    """

    def __init__(self, input_dim=512, hidden_dim=512, output_dim=128):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # One nn.Sequential named `projection` keeps checkpoint keys `projection.<i>.*`.
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.projection(x)
        return projected


class ContrastiveEncoder(nn.Module):
    """ResNet backbone (final FC removed) plus a SimCLR projection head.

    ``forward(x)`` returns projections for the contrastive loss;
    ``forward(x, return_embedding=True)`` returns the flattened pooled
    backbone features (``self.embedding_dim`` wide: 2048 for ResNet-50,
    512 for ResNet-18).
    """

    # base_model -> (torchvision.models constructor name, weights-enum name)
    _BACKBONES = {
        'resnet50': ('resnet50', 'ResNet50_Weights'),
        'resnet18': ('resnet18', 'ResNet18_Weights'),
    }

    def __init__(self, base_model='resnet50', projection_dim=128, pretrained=True):
        super().__init__()

        if base_model not in self._BACKBONES:
            raise ValueError(f"Unknown base model: {base_model}")
        ctor_name, weights_name = self._BACKBONES[base_model]

        resnet = self._build_backbone(getattr(models, ctor_name), weights_name, pretrained)
        self.embedding_dim = resnet.fc.in_features
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # Remove FC layer

        # Projection head for contrastive loss
        self.projection_head = ProjectionHead(
            input_dim=self.embedding_dim,
            hidden_dim=self.embedding_dim,
            output_dim=projection_dim
        )

    @staticmethod
    def _build_backbone(constructor, weights_name, pretrained):
        """Instantiate a torchvision ResNet, preferring the ``weights=`` API."""
        weights_enum = getattr(models, weights_name, None)
        if weights_enum is None:
            # torchvision < 0.13 only understands the legacy flag.
            return constructor(pretrained=pretrained)
        # IMAGENET1K_V1 is what the deprecated ``pretrained=True`` loaded.
        return constructor(weights=weights_enum.IMAGENET1K_V1 if pretrained else None)

    def forward(self, x, return_embedding=False):
        # Get embedding from encoder
        h = self.encoder(x)
        h = torch.flatten(h, 1)

        if return_embedding:
            return h

        # Project to contrastive space
        z = self.projection_head(h)
        return z

    def get_embedding(self, x):
        """Extract features for downstream tasks.

        Runs under ``torch.no_grad()``, so no gradients reach the encoder;
        the module's train/eval mode is left as-is.
        """
        with torch.no_grad():
            return self.forward(x, return_embedding=True)


class SVMClassifier:
    """SVM classifier wrapper using sklearn.

    Features are standardized with a ``StandardScaler`` fitted on the
    training data, then classified by ``LinearSVC`` (kernel='linear') or
    ``SVC`` (any other kernel).
    """

    def __init__(self, C=1.0, kernel='linear', max_iter=1000):
        """
        Args:
            C: SVM penalty parameter
            kernel: 'linear' or 'rbf'
            max_iter: Maximum iterations
        """
        if kernel == 'linear':
            self.svm = LinearSVC(C=C, max_iter=max_iter, dual=True)
        else:
            self.svm = SVC(C=C, kernel=kernel, max_iter=max_iter)

        self.scaler = StandardScaler()

    def fit(self, X, y):
        """Standardize ``X`` and train the SVM on it; returns ``self``."""
        X_scaled = self.scaler.fit_transform(X)
        if isinstance(self.svm, LinearSVC):
            # sklearn advises dual=False when n_samples > n_features; the dual
            # solver converges slowly there within a fixed max_iter (and the
            # ConvergenceWarning is hidden by the notebook's warning filter).
            n_samples, n_features = X_scaled.shape
            self.svm.set_params(dual=n_samples <= n_features)
        self.svm.fit(X_scaled, y)
        return self

    def predict(self, X):
        """Predict class labels for ``X`` using the fitted scaler and SVM."""
        X_scaled = self.scaler.transform(X)
        return self.svm.predict(X_scaled)

    def score(self, X, y):
        """Return mean accuracy (fraction in [0, 1]) on ``(X, y)``."""
        X_scaled = self.scaler.transform(X)
        return self.svm.score(X_scaled, y)


class LinearProbe(nn.Module):
    """Baseline head: one affine layer from frozen embeddings to class logits."""

    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=input_dim, out_features=num_classes)

    def forward(self, x):
        logits = self.fc(x)
        return logits


class SoftmaxClassifier(nn.Module):
    """Encoder plus a linear softmax head for end-to-end fine-tuning.

    Args:
        encoder: Module exposing ``embedding_dim`` and accepting
            ``forward(x, return_embedding=True)`` (e.g. ContrastiveEncoder).
        num_classes: Number of output logits.
        freeze_encoder: If True, encoder parameters get
            ``requires_grad=False`` and the encoder is kept in eval mode, so
            only the head is trained.
    """

    def __init__(self, encoder, num_classes, freeze_encoder=False):
        super().__init__()
        self.encoder = encoder
        self.fc = nn.Linear(encoder.embedding_dim, num_classes)
        self.freeze_encoder = freeze_encoder

        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super().train(mode)
        if self.freeze_encoder:
            # A frozen backbone must also keep its BatchNorm statistics fixed.
            self.encoder.eval()
        return self

    def forward(self, x):
        # Call forward() directly: get_embedding() wraps the pass in
        # torch.no_grad(), which would block gradients to the encoder and
        # silently turn fine-tuning into a linear probe.
        features = self.encoder(x, return_embedding=True)
        return self.fc(features)

## 5. Contrastive Loss Functions

In [None]:
class NTXentLoss(nn.Module):
    """Normalized Temperature-scaled Cross Entropy Loss (NT-Xent).

    This is the contrastive loss used in SimCLR. For each of the 2B
    projections, the positive is the other view of the same image and the
    softmax denominator runs over every other projection (k != i):

        l_i = -log( exp(s_i,pos / t) / sum_{k != i} exp(s_i,k / t) )

    with s the cosine similarity; the loss is the mean over all 2B rows.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, z_i, z_j):
        """
        Args:
            z_i: Projections from view 1, shape [batch_size, projection_dim]
            z_j: Projections from view 2, shape [batch_size, projection_dim]

        Returns:
            loss: scalar NT-Xent loss
        """
        batch_size = z_i.shape[0]
        n_views = 2 * batch_size

        # Unit-normalize so dot products are cosine similarities.
        representations = F.normalize(torch.cat([z_i, z_j], dim=0), dim=1)  # [2B, D]

        # A matmul avoids materializing the [2B, 2B, D] broadcast tensor.
        logits = (representations @ representations.T) / self.temperature

        # Drop self-similarity from the denominator (exp(-inf) = 0).
        self_mask = torch.eye(n_views, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_mask, float('-inf'))

        # Row i's positive is its paired view: i + B in the first half,
        # i - B in the second. Every other column is a negative, so the
        # positive appears exactly once in each row's denominator.
        targets = (torch.arange(n_views, device=logits.device) + batch_size) % n_views

        return F.cross_entropy(logits, targets)

## 6. Training Functions

In [None]:
def train_contrastive_epoch(model, dataloader, optimizer, criterion, device):
    """Run one pass of contrastive training and return the mean batch loss.

    Each batch is ``(view1, view2, labels)``; labels are ignored. Both views
    go through ``model`` separately and ``criterion`` compares their outputs.
    """
    model.train()
    running_loss = 0

    progress = tqdm(dataloader, desc='Training')
    for first_view, second_view, _labels in progress:
        first_view = first_view.to(device)
        second_view = second_view.to(device)

        batch_loss = criterion(model(first_view), model(second_view))

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        step_value = batch_loss.item()
        running_loss += step_value
        progress.set_postfix({'loss': f'{step_value:.4f}'})

    return running_loss / len(dataloader)


def train_contrastive(model, train_loader, config, device):
    """Pretrain ``model`` with NT-Xent for ``config.CONTRASTIVE_EPOCHS`` epochs.

    Uses Adam with a cosine-annealed learning rate. Whenever the epoch-mean
    loss improves, a checkpoint (epoch, model/optimizer state, loss) is
    written to ``<CHECKPOINT_DIR>/best_contrastive.pth``; the in-memory model
    keeps its final-epoch weights.

    Returns:
        ``{'loss': [per-epoch mean loss, ...]}``
    """
    criterion = NTXentLoss(temperature=config.TEMPERATURE).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=config.CONTRASTIVE_LR,
        weight_decay=config.CONTRASTIVE_WEIGHT_DECAY,
    )
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.CONTRASTIVE_EPOCHS)

    history = {'loss': []}
    best_loss = float('inf')
    total_epochs = config.CONTRASTIVE_EPOCHS

    for epoch_idx in range(total_epochs):
        print(f"\nEpoch {epoch_idx+1}/{total_epochs}")

        epoch_loss = train_contrastive_epoch(model, train_loader, optimizer, criterion, device)
        history['loss'].append(epoch_loss)
        scheduler.step()

        print(f"Average Loss: {epoch_loss:.4f}")

        # Written as `not <` so a NaN loss never counts as an improvement.
        if not (epoch_loss < best_loss):
            continue
        best_loss = epoch_loss
        checkpoint = {
            'epoch': epoch_idx,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': epoch_loss,
        }
        torch.save(checkpoint, os.path.join(config.CHECKPOINT_DIR, 'best_contrastive.pth'))
        print("✓ Saved best model")

    return history


def extract_features(model, dataloader, device):
    """Embed every batch of ``dataloader`` with ``model.get_embedding``.

    The model is switched to eval mode first.

    Returns:
        ``(features, labels)`` as NumPy arrays concatenated along axis 0.
    """
    model.eval()
    feature_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc='Extracting features'):
            embeddings = model.get_embedding(batch_images.to(device))
            feature_chunks.append(embeddings.cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    return np.concatenate(feature_chunks, axis=0), np.concatenate(label_chunks, axis=0)


def train_svm_classifier(encoder, train_loader, test_loader, config, device):
    """Fit an SVM on frozen encoder embeddings and report its accuracy.

    Embeddings for both loaders come from ``extract_features``. The SVM
    settings come from ``config.SVM_C``, ``config.SVM_KERNEL`` and
    ``config.SVM_MAX_ITER``.

    Returns:
        ``(svm, test_acc, y_pred, y_test)``: the fitted ``SVMClassifier``,
        the test accuracy as a fraction, the test predictions, and the test labels.
    """
    rule = "=" * 50
    print("\n" + rule)
    print("Training SVM Classifier")
    print(rule)

    print("Extracting training features...")
    X_train, y_train = extract_features(encoder, train_loader, device)
    print("Extracting test features...")
    X_test, y_test = extract_features(encoder, test_loader, device)

    for split_name, split_features in (("Train", X_train), ("Test", X_test)):
        print(f"{split_name} features shape: {split_features.shape}")

    print(f"\nTraining SVM (C={config.SVM_C}, kernel={config.SVM_KERNEL})...")
    svm = SVMClassifier(
        C=config.SVM_C,
        kernel=config.SVM_KERNEL,
        max_iter=config.SVM_MAX_ITER,
    )
    svm.fit(X_train, y_train)

    train_acc = svm.score(X_train, y_train)
    test_acc = svm.score(X_test, y_test)

    print("\nResults:")
    print(f"  Train Accuracy: {train_acc*100:.2f}%")
    print(f"  Test Accuracy: {test_acc*100:.2f}%")

    # Per-sample predictions for the classification report / confusion matrix.
    y_pred = svm.predict(X_test)
    return svm, test_acc, y_pred, y_test


def train_linear_probe(encoder, train_loader, test_loader, num_classes, config, device):
    """Train a linear probe baseline on frozen encoder embeddings.

    The encoder is switched to eval mode and never updated. Only the probe
    is optimized, using Adam with ``config.FINETUNE_LR`` for
    ``config.FINETUNE_EPOCHS`` epochs. Test accuracy is measured every
    10 epochs and always after the final epoch.

    Returns:
        The best test accuracy (percent) among the evaluated epochs. This
        selects on the test split, so it is an optimistic estimate.
    """
    print("\n" + "="*50)
    print("Training Linear Probe")
    print("="*50)

    # get_embedding() disables gradients but not BatchNorm running-stat
    # updates, so make the frozen encoder's behaviour explicit.
    encoder.eval()

    probe = LinearProbe(encoder.embedding_dim, num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(probe.parameters(), lr=config.FINETUNE_LR)

    best_acc = 0
    num_epochs = config.FINETUNE_EPOCHS

    for epoch in range(num_epochs):
        probe.train()
        total_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Get frozen features
            features = encoder.get_embedding(images)

            outputs = probe(features)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Evaluate every 10 epochs and after the last one, so short runs
        # (fewer than 10 epochs) still report an accuracy.
        if (epoch + 1) % 10 != 0 and epoch + 1 != num_epochs:
            continue

        probe.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = probe(encoder.get_embedding(images))
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total if total else 0.0
        print(f"Epoch {epoch+1}/{config.FINETUNE_EPOCHS}, Loss: {total_loss/len(train_loader):.4f}, Test Acc: {acc:.2f}%")
        best_acc = max(best_acc, acc)

    return best_acc


def train_softmax_finetune(encoder, train_loader, test_loader, num_classes, config, device):
    """Fine-tune a copy of ``encoder`` end to end with a softmax head.

    The caller's ``encoder`` is deep-copied first, so its weights and
    BatchNorm statistics stay untouched for later evaluations. The copy is
    trained with Adam (``config.FINETUNE_LR``, ``config.FINETUNE_WEIGHT_DECAY``)
    for ``config.FINETUNE_EPOCHS`` epochs. Test accuracy is measured every
    10 epochs and after the final epoch.

    Returns:
        The best test accuracy (percent) among the evaluated epochs. This
        selects on the test split, so it is an optimistic estimate.
    """
    import copy

    print("\n" + "="*50)
    print("Fine-tuning with Softmax")
    print("="*50)

    model = SoftmaxClassifier(copy.deepcopy(encoder), num_classes, freeze_encoder=False).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.FINETUNE_LR, weight_decay=config.FINETUNE_WEIGHT_DECAY)

    best_acc = 0
    num_epochs = config.FINETUNE_EPOCHS

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Evaluate every 10 epochs and after the last one.
        if (epoch + 1) % 10 != 0 and epoch + 1 != num_epochs:
            continue

        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        acc = 100 * correct / total if total else 0.0
        print(f"Epoch {epoch+1}/{config.FINETUNE_EPOCHS}, Loss: {total_loss/len(train_loader):.4f}, Test Acc: {acc:.2f}%")
        best_acc = max(best_acc, acc)

    return best_acc

## 7. Evaluation and Visualization

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes, title='Confusion Matrix', annot=None):
    """Plot a confusion-matrix heatmap and return the Figure.

    Args:
        y_true: True integer labels in ``0..len(classes)-1``.
        y_pred: Predicted integer labels in the same range.
        classes: Tick labels, indexed by label.
        title: Plot title.
        annot: Whether to print counts in cells. ``None`` (default)
            annotates only when there are at most 30 classes, where the
            numbers stay legible.
    """
    n_classes = len(classes)
    # Pin the label set so the matrix is always n_classes x n_classes;
    # otherwise labels absent from both y_true and y_pred are dropped and the
    # tick labels no longer line up with the rows/columns.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    if annot is None:
        annot = n_classes <= 30

    fig = plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=annot, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    return fig


def plot_tsne(features, labels, classes, title='t-SNE Visualization'):
    """Project features to 2-D with t-SNE and scatter them coloured by label.

    Args:
        features: Array of shape [n_samples, n_features].
        labels: Integer labels in ``0..len(classes)-1``, one per sample.
        classes: Class names indexed by label (used for the legend).
        title: Plot title.

    Returns:
        The matplotlib Figure.
    """
    print("Computing t-SNE...")
    n_samples = len(features)
    # scikit-learn requires perplexity < n_samples; 30 is its default.
    perplexity = min(30.0, max(1.0, n_samples - 1.0))
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    features_2d = tsne.fit_transform(features)

    n_classes = len(classes)
    # One discrete colour per class. tab10 only has 10 colours, so use
    # larger palettes when needed.
    if n_classes <= 10:
        cmap_name = 'tab10'
    elif n_classes <= 20:
        cmap_name = 'tab20'
    else:
        cmap_name = 'nipy_spectral'
    cmap = plt.get_cmap(cmap_name, n_classes)

    fig = plt.figure(figsize=(12, 10))
    # Fixed vmin/vmax map label i to colour cmap(i). Without them, matplotlib
    # normalizes by the labels actually present, and the legend (built from
    # cmap(i) below) would not match the scatter colours.
    scatter = plt.scatter(features_2d[:, 0], features_2d[:, 1],
                          c=labels, cmap=cmap, vmin=-0.5, vmax=n_classes - 0.5,
                          alpha=0.6, s=20)
    plt.colorbar(scatter, ticks=range(n_classes))
    plt.title(title)
    plt.xlabel('t-SNE 1')
    plt.ylabel('t-SNE 2')

    handles = [plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=cmap(i),
                          markersize=8, label=classes[i])
               for i in range(n_classes)]
    plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc='upper left',
               ncol=max(1, (n_classes + 29) // 30))
    plt.tight_layout()
    return fig


def plot_training_history(history):
    """Draw the per-epoch contrastive loss curve and return the Figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(history['loss'])
    ax.set_title('Contrastive Learning - Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    fig.tight_layout()
    return fig


def evaluate_few_shot(encoder, full_dataset, k_shots, num_classes, config, device, seed=42):
    """Evaluate SVM accuracy when only k labelled images per class are available.

    An eval-transform dataset is rebuilt from ``config.DATA_ROOT`` and split
    70/30 with a seeded generator. For each k, k images per class are drawn
    from the 70% portion only, never from the held-out 30%. An SVM is then
    fit on their embeddings and scored on the 30% portion.

    Args:
        encoder: Feature extractor passed to ``train_svm_classifier``.
        full_dataset: Unused; kept for call compatibility (the dataset is
            rebuilt from ``config.DATA_ROOT`` with eval transforms).
        k_shots: Iterable of positive shot counts per class.
        num_classes: Classes ``0..num_classes-1`` are sampled.
        config: Reads DATA_ROOT, IMG_SIZE, BATCH_SIZE and the SVM settings.
        device: Device used for feature extraction.
        seed: Seed for the train/test split and the support-set sampling.

    Returns:
        A dict mapping ``'<k>-shot'`` to test accuracy in percent. A k is
        omitted when fewer than two classes have k training images.
    """
    results = {}

    full_dataset_sup = PlantWildV2Dataset(
        root_dir=config.DATA_ROOT,
        transform=StandardAugmentation(config.IMG_SIZE, is_train=False),
        mode='supervised'
    )

    # Hold out 30% for testing; a seeded split keeps the evaluation reproducible.
    test_size = int(0.3 * len(full_dataset_sup))
    train_size = len(full_dataset_sup) - test_size
    split_generator = torch.Generator().manual_seed(seed)
    train_subset, test_dataset = random_split(
        full_dataset_sup, [train_size, test_size], generator=split_generator
    )
    test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False)

    # Candidate support images come from the training portion only, so no
    # test image can end up in a support set.
    train_pool = {class_idx: [] for class_idx in range(num_classes)}
    for idx in train_subset.indices:
        label = full_dataset_sup.samples[idx][1]
        if label in train_pool:
            train_pool[label].append(idx)

    rng = np.random.default_rng(seed)
    for k in k_shots:
        print(f"\n{k}-shot learning...")

        support_indices = []
        covered_classes = 0
        for class_idx in range(num_classes):
            candidates = train_pool[class_idx]
            if len(candidates) < k:
                continue
            support_indices.extend(int(i) for i in rng.choice(candidates, k, replace=False))
            covered_classes += 1

        if covered_classes < num_classes:
            print(f"  Note: {num_classes - covered_classes} class(es) have fewer than {k} training images and were skipped")
        if covered_classes < 2:
            # An SVM needs at least two classes to fit.
            print(f"  Skipping {k}-shot: fewer than two classes available")
            continue

        few_shot_dataset = torch.utils.data.Subset(full_dataset_sup, support_indices)
        few_shot_loader = DataLoader(few_shot_dataset, batch_size=32, shuffle=True)

        _, test_acc, _, _ = train_svm_classifier(
            encoder, few_shot_loader, test_loader, config, device
        )
        results[f'{k}-shot'] = test_acc * 100

    return results


def print_classification_report(y_true, y_pred, classes):
    """Print scikit-learn's per-class precision/recall/F1 table.

    ``labels`` is pinned to ``0..len(classes)-1``. Otherwise scikit-learn
    raises ValueError when some classes never occur in ``y_true``/``y_pred``,
    because ``target_names`` is then longer than the observed label set.
    ``zero_division=0`` reports 0 for classes that are never predicted.
    """
    print("\n" + "="*50)
    print("Classification Report")
    print("="*50)
    print(classification_report(
        y_true, y_pred,
        labels=list(range(len(classes))),
        target_names=classes,
        zero_division=0,
    ))

## 8. Main Experiment Pipeline

In [None]:
def run_full_experiment(config):
    """Run the complete experimental pipeline.

    The pipeline:
    1. Loads PlantWildV2 three ways (contrastive views, train augmentation,
       eval transforms) and makes a seeded 70/30 split.
    2. Contrastively pretrains a ResNet-50 encoder.
    3. Plots t-SNE of the test embeddings.
    4. Fits an SVM on frozen embeddings.
    5. Trains the linear-probe and softmax fine-tuning baselines.
    6. Optionally runs few-shot evaluation.

    Plots and ``results.json`` are written to ``config.RESULTS_DIR``.

    Returns:
        ``(results_dict, encoder, svm)``. ``encoder`` is the contrastively
        pretrained encoder; softmax fine-tuning runs on a copy.
    """
    import copy
    import json

    print("="*70)
    print("CONTRASTIVE LEARNING + SVM FOR PLANT DISEASE DETECTION")
    print("="*70)

    # 1. Load dataset
    print("\n[1/6] Loading PlantWildV2 Dataset...")

    # Contrastive dataset
    contrastive_dataset = PlantWildV2Dataset(
        root_dir=config.DATA_ROOT,
        transform=ContrastiveAugmentation(config.IMG_SIZE, config.COLOR_JITTER_STRENGTH),
        mode='contrastive'
    )

    # Supervised dataset
    supervised_dataset = PlantWildV2Dataset(
        root_dir=config.DATA_ROOT,
        transform=StandardAugmentation(config.IMG_SIZE, is_train=True),
        mode='supervised'
    )

    supervised_dataset_test = PlantWildV2Dataset(
        root_dir=config.DATA_ROOT,
        transform=StandardAugmentation(config.IMG_SIZE, is_train=False),
        mode='supervised'
    )

    num_classes = len(contrastive_dataset.classes)
    classes = contrastive_dataset.classes

    # Split 70% train / 30% test. The two supervised datasets scan the same
    # directory, so their indices refer to the same files (assuming the
    # directory is unchanged between scans). ONE permutation is drawn and
    # applied to both. Two independent random_split calls would draw
    # different permutations, and most "test" images would also be training images.
    n_total = len(supervised_dataset)
    train_size = int(0.7 * n_total)
    split_generator = torch.Generator().manual_seed(SEED)
    permutation = torch.randperm(n_total, generator=split_generator).tolist()
    train_dataset = torch.utils.data.Subset(supervised_dataset, permutation[:train_size])
    test_dataset = torch.utils.data.Subset(supervised_dataset_test, permutation[train_size:])

    # Create dataloaders
    # NOTE: contrastive pretraining sees every image, including the test
    # split (labels are not used), i.e. it is transductive.
    contrastive_loader = DataLoader(
        contrastive_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
        drop_last=True
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=False,
        num_workers=config.NUM_WORKERS
    )

    print(f"  ✓ Found {num_classes} classes")
    print(f"  ✓ Train size: {len(train_dataset)}")
    print(f"  ✓ Test size: {len(test_dataset)}")

    # 2. Contrastive Pretraining
    print("\n[2/6] Contrastive Pretraining...")
    encoder = ContrastiveEncoder(
        base_model='resnet50',
        projection_dim=config.PROJECTION_DIM,
        pretrained=True
    ).to(device)

    history = train_contrastive(encoder, contrastive_loader, config, device)

    # Plot training history
    fig = plot_training_history(history)
    fig.savefig(os.path.join(config.RESULTS_DIR, 'training_loss.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 3. Extract and visualize features
    print("\n[3/6] Extracting Features...")
    # Only the test embeddings are needed here (for t-SNE); the SVM step
    # below extracts its own train/test features.
    X_test, y_test = extract_features(encoder, test_loader, device)

    # t-SNE visualization
    print("Creating t-SNE visualization...")
    sample_size = min(2000, len(X_test))
    sample_indices = np.random.choice(len(X_test), sample_size, replace=False)
    fig = plot_tsne(X_test[sample_indices], y_test[sample_indices], classes)
    fig.savefig(os.path.join(config.RESULTS_DIR, 'tsne_features.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 4. Train SVM Classifier
    print("\n[4/6] Training SVM Classifier...")
    svm, svm_acc, y_pred_svm, y_test_svm = train_svm_classifier(
        encoder, train_loader, test_loader, config, device
    )

    # Print detailed metrics
    print_classification_report(y_test_svm, y_pred_svm, classes)

    # Plot confusion matrix
    fig = plot_confusion_matrix(y_test_svm, y_pred_svm, classes, 'SVM Confusion Matrix')
    fig.savefig(os.path.join(config.RESULTS_DIR, 'confusion_matrix_svm.png'), dpi=300, bbox_inches='tight')
    plt.close()

    # 5. Baseline Comparisons
    print("\n[5/6] Training Baseline Methods...")

    # Linear probe
    linear_acc = train_linear_probe(
        encoder, train_loader, test_loader, num_classes, config, device
    )

    # Softmax fine-tuning runs on a copy so the few-shot evaluation and the
    # returned encoder still use the contrastively pretrained weights.
    softmax_acc = train_softmax_finetune(
        copy.deepcopy(encoder), train_loader, test_loader, num_classes, config, device
    )

    # 6. Few-shot Learning
    few_shot_results = {}
    if config.ENABLE_FEW_SHOT:
        print("\n[6/6] Few-shot Learning Evaluation...")
        few_shot_results = evaluate_few_shot(
            encoder, supervised_dataset_test, config.FEW_SHOT_K,
            num_classes, config, device
        )

    # Summary
    print("\n" + "="*70)
    print("FINAL RESULTS SUMMARY")
    print("="*70)
    print(f"\nFull Dataset Results:")
    print(f"  SVM Classifier:          {svm_acc*100:.2f}%")
    print(f"  Linear Probe:            {linear_acc:.2f}%")
    print(f"  Softmax Fine-tuning:     {softmax_acc:.2f}%")

    if few_shot_results:
        print(f"\nFew-shot Learning Results:")
        for k, acc in few_shot_results.items():
            print(f"  {k}: {acc:.2f}%")

    # Save results
    results_dict = {
        'svm_accuracy': float(svm_acc * 100),
        'linear_probe_accuracy': float(linear_acc),
        'softmax_accuracy': float(softmax_acc),
        'few_shot_results': few_shot_results,
        'num_classes': num_classes,
        'classes': classes
    }

    with open(os.path.join(config.RESULTS_DIR, 'results.json'), 'w') as f:
        json.dump(results_dict, f, indent=4)

    print(f"\n✓ Results saved to {config.RESULTS_DIR}")
    print("="*70)

    return results_dict, encoder, svm

## 9. Run Experiment

In [None]:
# Point DATA_ROOT at your local copy of PlantWildV2 before running.
config.DATA_ROOT = './plantwildV2'

# Pretrain the encoder, fit the SVM, run the baselines and few-shot evaluation.
results, trained_encoder, trained_svm = run_full_experiment(config)

## 10. Inference and Visualization

In [None]:
def predict_image(image_path, encoder, svm, config, classes):
    """Predict the disease class for a single image and display it.

    Args:
        image_path: Path to an image file readable by PIL.
        encoder: Encoder exposing ``get_embedding`` (e.g. ContrastiveEncoder).
        svm: Fitted classifier exposing ``predict`` on embedding arrays.
        config: Reads ``IMG_SIZE``.
        classes: Class names indexed by predicted label.

    Returns:
        The predicted class name.
    """
    # Load and preprocess image; the context manager closes the file handle.
    transform = StandardAugmentation(config.IMG_SIZE, is_train=False)
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    # Put the input on whatever device the encoder's parameters live on.
    encoder_device = next(encoder.parameters()).device
    image_tensor = transform(image).unsqueeze(0).to(encoder_device)

    # Inference must use BatchNorm running statistics (a train-mode encoder
    # would normalize with this single image's statistics); restore the
    # caller's mode afterwards.
    was_training = encoder.training
    encoder.eval()
    try:
        with torch.no_grad():
            features = encoder.get_embedding(image_tensor).cpu().numpy()
    finally:
        encoder.train(was_training)

    # Predict
    prediction = svm.predict(features)[0]
    predicted_class = classes[prediction]

    # Visualize
    plt.figure(figsize=(8, 6))
    plt.imshow(image)
    plt.axis('off')
    plt.title(f'Predicted: {predicted_class}', fontsize=16)
    plt.tight_layout()
    plt.show()

    return predicted_class

# Example usage
# image_path = 'path/to/test/image.jpg'
# prediction = predict_image(image_path, trained_encoder, trained_svm, config, results['classes'])

## 11. Additional Analysis

In [None]:
def compare_methods_plot(results):
    """Bar chart of full-dataset accuracy for SVM, linear probe and softmax.

    Args:
        results: Dict holding 'svm_accuracy', 'linear_probe_accuracy' and
            'softmax_accuracy' (percent values).

    Returns:
        The matplotlib Figure.
    """
    methods = ['SVM', 'Linear Probe', 'Softmax Fine-tune']
    result_keys = ['svm_accuracy', 'linear_probe_accuracy', 'softmax_accuracy']
    accuracies = [results[key] for key in result_keys]

    fig, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(methods, accuracies, color=['#2ecc71', '#3498db', '#e74c3c'])
    ax.set_ylabel('Accuracy (%)', fontsize=12)
    ax.set_title('Comparison of Classification Methods', fontsize=14, fontweight='bold')
    ax.set_ylim([0, 100])

    # Label each bar with its value, centred just above the bar top.
    for bar, acc in zip(bars, accuracies):
        x_center = bar.get_x() + bar.get_width() / 2.
        ax.text(x_center, bar.get_height(), f'{acc:.2f}%',
                ha='center', va='bottom', fontsize=11, fontweight='bold')

    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()
    return fig

# Draw the method comparison only once the experiment cell has produced `results`.
if 'results' in locals():
    fig = compare_methods_plot(results)
    fig.savefig(
        os.path.join(config.RESULTS_DIR, 'methods_comparison.png'),
        dpi=300,
        bbox_inches='tight',
    )
    plt.show()

## 12. Export Model for Deployment

In [None]:
import pickle

def save_model_for_deployment(encoder, svm, config, classes):
    """Save the encoder, SVM and a small config file to ``config.CHECKPOINT_DIR``.

    Writes three files:
    - ``encoder_final.pth``: encoder state_dict, embedding_dim and classes.
    - ``svm_final.pkl``: the pickled SVM object. Unpickling requires the same
      class definitions and must only be done on trusted files.
    - ``config.json``: IMG_SIZE, EMBEDDING_DIM and classes.
    """
    import json  # not imported at notebook top level; needed for config.json

    # Save encoder
    torch.save({
        'model_state_dict': encoder.state_dict(),
        'embedding_dim': encoder.embedding_dim,
        'classes': classes
    }, os.path.join(config.CHECKPOINT_DIR, 'encoder_final.pth'))

    # Save SVM
    with open(os.path.join(config.CHECKPOINT_DIR, 'svm_final.pkl'), 'wb') as f:
        pickle.dump(svm, f)

    # Save config. EMBEDDING_DIM is taken from the encoder itself so it
    # always matches the saved weights (config.EMBEDDING_DIM may not).
    config_dict = {
        'IMG_SIZE': config.IMG_SIZE,
        'EMBEDDING_DIM': encoder.embedding_dim,
        'classes': classes
    }
    with open(os.path.join(config.CHECKPOINT_DIR, 'config.json'), 'w') as f:
        json.dump(config_dict, f, indent=4)

    print(f"✓ Models saved to {config.CHECKPOINT_DIR}")

# Persist the trained artefacts only if the experiment has run in this session.
if 'trained_encoder' in locals() and 'trained_svm' in locals():
    save_model_for_deployment(
        trained_encoder,
        trained_svm,
        config,
        results['classes'],
    )

## 13. Citation and References

If you use this code in your research, please cite:

```bibtex
@article{yourname2025contrastive,
  title={Margin-Aware Contrastive Learning with SVM for Plant Disease Detection},
  author={Your Name},
  journal={arXiv preprint arXiv:XXXX.XXXXX},
  year={2025}
}
```

### References
1. Chen, T., et al. (2020). A Simple Framework for Contrastive Learning of Visual Representations. ICML.
2. Tang, Y. (2013). Deep Learning using Linear Support Vector Machines. arXiv:1306.0239.
3. Agarap, A. F. (2017). An Architecture Combining Convolutional Neural Network (CNN) and Support Vector Machine (SVM) for Image Classification. arXiv:1712.03541.