# 04 - Training Setup

In this notebook, we'll set up the training pipeline for video similarity learning.

## Learning Objectives

By the end of this notebook, you will:
- Prepare data loaders for training
- Configure training parameters and optimizers
- Set up data augmentation strategies
- Implement training loops and validation
- **Complete 4 hands-on exercises** plus a final report exercise on training setup

## Key Concepts

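**Data Loaders**: Wrap a `Dataset` to shuffle, batch, and (optionally) load samples in parallel worker processes during training.

**Optimizers**: Algorithms (SGD, Adam, ...) that update model parameters using the gradients of the loss.

**Data Augmentation**: Random, label-preserving transformations of training samples that increase the effective diversity of the dataset.

**Training Loop**: The iterative process of forward pass, loss calculation, backpropagation, and parameter update.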
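These four concepts fit together in a few lines. The following is a minimal sketch on a toy linear-regression problem (synthetic data, not the video model), using only standard PyTorch APIs. The `toy_` prefixes keep it from clobbering the notebook's own `model`, `optimizer`, and `loss_fn`:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic data (illustrative only): y = 3x + 1 plus a little noise
x = torch.randn(256, 1)
y = 3 * x + 1 + 0.1 * torch.randn(256, 1)

# Data loader: shuffles and batches the samples every epoch
toy_loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

toy_model = nn.Linear(1, 1)
toy_loss_fn = nn.MSELoss()
# Optimizer: updates the parameters from their gradients
toy_optimizer = optim.SGD(toy_model.parameters(), lr=0.1)

losses = []
for epoch in range(20):
    epoch_loss = 0.0
    for xb, yb in toy_loader:
        toy_optimizer.zero_grad()        # clear gradients from the previous step
        pred = toy_model(xb)             # forward pass
        loss = toy_loss_fn(pred, yb)     # loss calculation
        loss.backward()                  # backpropagation
        toy_optimizer.step()             # parameter update
        epoch_loss += loss.item() * len(xb)
    losses.append(epoch_loss / len(toy_loader.dataset))

print(f"first epoch loss: {losses[0]:.4f}, last epoch loss: {losses[-1]:.4f}")
```

The video training loop later in this notebook has the same shape. Only the model, the loss, and the contents of each batch change.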

In [None]:
# Import required libraries
import sys
import os
from pathlib import Path

# Add the project root to the path
project_root = Path.cwd().parent
sys.path.append(str(project_root))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import yaml

# Import our utilities
from utils.data_utils import VideoDataset, create_sample_dataset
from utils.model_utils import VideoSiameseNetwork, VideoTripletNetwork
from utils.training_utils import ContrastiveLoss, TripletLoss

# Set up plotting
plt.style.use('default')
sns.set_palette("husl")

print("Libraries imported successfully!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Data Preparation

Let's prepare our data for training by creating data loaders.

In [None]:
# Load configuration
config_path = project_root / "configs" / "default_config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded:")
for key, value in config.items():
    print(f"  {key}: {value}")
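Later cells read specific keys from `config`, and a missing key only surfaces as a `KeyError` deep inside one of them. The sketch below checks up front for exactly the keys this notebook reads. The key list is taken from the cells in this notebook, and the helper `missing_config_keys` is hypothetical:

```python
# Keys this notebook reads from default_config.yaml, grouped by section
REQUIRED_KEYS = {
    "data": ["max_frames", "image_size"],
    "model": ["architecture", "feature_dim", "embedding_dim"],
    "training": ["batch_size", "num_workers", "learning_rate", "weight_decay",
                 "lr_step_size", "lr_gamma", "margin"],
}

def missing_config_keys(config):
    """Return dotted names (e.g. 'training.margin') of required keys that are
    absent from a nested config dict."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        block = config.get(section) or {}  # treat a missing section as empty
        for key in keys:
            if key not in block:
                missing.append(f"{section}.{key}")
    return missing
```

For example, `assert not missing_config_keys(config), missing_config_keys(config)` right after loading the YAML lists every absent key at once.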

In [None]:
# Define data paths
data_dir = project_root / "data" / "videos"
metadata_file = data_dir / "sample_metadata.csv"
pairs_file = data_dir / "similarity_pairs.csv"

# Load metadata and pairs
metadata = pd.read_csv(metadata_file)
pairs = pd.read_csv(pairs_file)

print(f"Dataset loaded:")
print(f"  Videos: {len(metadata)}")
print(f"  Similarity pairs: {len(pairs)}")
print(f"  Unique labels: {metadata['label'].nunique()}")
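The split below stratifies on the `similarity` column, so it is worth checking the label balance first. `train_test_split(..., stratify=...)` raises a `ValueError` when the least populated class has fewer than 2 members. The following is a minimal sketch; the helper name is illustrative:

```python
import pandas as pd

def check_label_balance(pairs, label_col="similarity"):
    """Print the label proportions and return per-class counts; fail early if
    stratified splitting is impossible."""
    counts = pairs[label_col].value_counts().sort_index()
    print((counts / counts.sum()).round(3))
    # Stratified splitting needs at least 2 samples in every class
    if counts.min() < 2:
        raise ValueError(
            f"Class {counts.idxmin()!r} has only {counts.min()} pair(s); cannot stratify."
        )
    return counts
```

Running `check_label_balance(pairs)` also shows how imbalanced similar and dissimilar pairs are, which matters when choosing a loss margin or sampling strategy.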

## 2. Data Loader Setup

Now let's create data loaders for training and validation.

In [None]:
# Split data into train/validation
from sklearn.model_selection import train_test_split

# Split pairs into train/val
train_pairs, val_pairs = train_test_split(
    pairs, test_size=0.2, random_state=42, stratify=pairs['similarity']
)

print(f"Data split:")
print(f"  Training pairs: {len(train_pairs)}")
print(f"  Validation pairs: {len(val_pairs)}")

# Create datasets
train_dataset = VideoDataset(
    pairs=train_pairs,
    video_dir=data_dir,
    max_frames=config['data']['max_frames'],
    image_size=config['data']['image_size'],
    is_training=True
)

val_dataset = VideoDataset(
    pairs=val_pairs,
    video_dir=data_dir,
    max_frames=config['data']['max_frames'],
    image_size=config['data']['image_size'],
    is_training=False
)

print(f"Datasets created:")
print(f"  Training dataset: {len(train_dataset)} samples")
print(f"  Validation dataset: {len(val_dataset)} samples")
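Splitting at the pair level means the same video can appear in both training and validation pairs, which can make validation scores optimistic. The sketch below splits at the video level instead. It assumes the pairs table stores the two video IDs in columns named `video1` and `video2`; those column names are hypothetical and should be adjusted to the real `similarity_pairs.csv` schema:

```python
import numpy as np
import pandas as pd

def split_pairs_by_video(pairs, col_a="video1", col_b="video2",
                         val_fraction=0.2, seed=42):
    """Assign each video to train or val, then keep only pairs whose two videos
    fall on the same side. Cross-split pairs are dropped to avoid leakage."""
    videos = pd.unique(pairs[[col_a, col_b]].values.ravel())
    rng = np.random.default_rng(seed)
    n_val = int(round(len(videos) * val_fraction))
    val_videos = set(rng.choice(videos, size=n_val, replace=False))
    in_val_a = pairs[col_a].isin(val_videos)
    in_val_b = pairs[col_b].isin(val_videos)
    train = pairs[~in_val_a & ~in_val_b]
    val = pairs[in_val_a & in_val_b]
    return train, val
```

The trade-off is that pairs spanning both sides are discarded and the label balance is no longer guaranteed to match, so it is worth re-checking `similarity` proportions after the split.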

In [None]:
# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=True,
    num_workers=config['training']['num_workers'],
    pin_memory=torch.cuda.is_available()  # pinned memory only speeds up host-to-GPU copies
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config['training']['batch_size'],
    shuffle=False,
    num_workers=config['training']['num_workers'],
    pin_memory=torch.cuda.is_available()  # pinned memory only speeds up host-to-GPU copies
)

print(f"Data loaders created:")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
print(f"  Batch size: {config['training']['batch_size']}")
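The training loop below unpacks every batch as `video1, video2, labels`. A toy dataset with the same structure shows how `DataLoader` collates samples and what `len(loader)` counts. `ToyPairDataset` is a hypothetical stand-in for `VideoDataset`, and the `(frames, channels, height, width)` clip layout is an assumption; the real dataset may order dimensions differently:

```python
import math
import torch
from torch.utils.data import DataLoader, Dataset

class ToyPairDataset(Dataset):
    """Hypothetical stand-in for VideoDataset: returns (clip1, clip2, label),
    each clip shaped (frames, channels, height, width)."""
    def __init__(self, n_pairs=10, frames=4, size=32):
        self.n_pairs, self.frames, self.size = n_pairs, frames, size

    def __len__(self):
        return self.n_pairs

    def __getitem__(self, idx):
        shape = (self.frames, 3, self.size, self.size)
        return torch.rand(shape), torch.rand(shape), torch.tensor(idx % 2)

ds = ToyPairDataset()
loader = DataLoader(ds, batch_size=4, shuffle=False)
video1, video2, labels = next(iter(loader))
# Default collation stacks each field along a new leading batch dimension
print(video1.shape, labels.shape)
# The last, smaller batch is kept unless drop_last=True
print(len(loader), math.ceil(len(ds) / 4))
```

With 10 pairs and a batch size of 4, the loader yields batches of 4, 4, and 2; `drop_last=True` would discard the final partial batch.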

## 🎯 EXERCISE 1: Data Loader Analysis

**Task**: Analyze and optimize the data loading pipeline.

**Requirements**:
1. Calculate the memory usage of a single batch
2. Measure data loading time for different batch sizes
3. Implement data prefetching for faster loading
4. Create a data loading benchmark
5. Suggest optimizations for the data pipeline

**Your code here**:

In [None]:
# TODO: Write your data loader analysis code

# 1. Calculate memory usage
# Your code here...

# 2. Measure loading time
# Your code here...

# 3. Implement prefetching
# Your code here...

# 4. Create benchmark
# Your code here...

# 5. Suggest optimizations
# Your code here...

## 3. Model and Loss Function Setup

Let's set up our model and loss functions.
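For reference, the textbook forms of the two losses are given below. Here $f$ is the embedding network, $m$ is the margin (`config['training']['margin']`), and $d(u, v) = \lVert f(u) - f(v) \rVert_2$. This is the common formulation, not necessarily the exact one in `utils.training_utils`. Implementations differ, for example in whether $y = 1$ marks similar or dissimilar pairs and whether distances are squared.

$$
\mathcal{L}_{\text{contrastive}} = y \, d(x_1, x_2)^2 + (1 - y) \, \max\big(0,\ m - d(x_1, x_2)\big)^2, \qquad y = 1 \text{ for similar pairs}
$$

$$
\mathcal{L}_{\text{triplet}} = \max\big(0,\ d(a, p) - d(a, n) + m\big)
$$

In the triplet loss, a negative identical to the anchor gives $d(a, n) = 0$. The loss then reduces to $d(a, p) + m$ and never pushes embeddings apart.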

In [None]:
# Create model
if config['model']['architecture'] == 'siamese':
    model = VideoSiameseNetwork(
        feature_dim=config['model']['feature_dim'],
        embedding_dim=config['model']['embedding_dim']
    )
    loss_fn = ContrastiveLoss(margin=config['training']['margin'])
elif config['model']['architecture'] == 'triplet':
    model = VideoTripletNetwork(
        feature_dim=config['model']['feature_dim'],
        embedding_dim=config['model']['embedding_dim']
    )
    loss_fn = TripletLoss(margin=config['training']['margin'])
else:
    raise ValueError(f"Unknown architecture: {config['model']['architecture']}")

print(f"Model created: {config['model']['architecture']}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Loss function: {type(loss_fn).__name__}")
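The triplet architecture needs (anchor, positive, negative) samples, which a pair loader does not provide. The following is a minimal sketch of index-level triplet sampling from per-video class labels, for instance the `label` column of `metadata` loaded above. The function is illustrative and not part of `utils`. It samples negatives uniformly at random, whereas hard-negative mining is a common refinement:

```python
import random
from collections import defaultdict

def sample_triplets(labels, n_triplets, seed=0):
    """Return (anchor, positive, negative) index triplets. The positive shares
    the anchor's label; the negative has a different label."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)
    # An anchor needs at least one other sample with the same label
    anchor_labels = [lab for lab, idxs in by_label.items() if len(idxs) >= 2]
    if not anchor_labels or len(by_label) < 2:
        raise ValueError("need a label with >= 2 samples and at least 2 distinct labels")
    triplets = []
    for _ in range(n_triplets):
        lab = rng.choice(anchor_labels)
        anchor, positive = rng.sample(by_label[lab], 2)
        neg_label = rng.choice([other for other in by_label if other != lab])
        negative = rng.choice(by_label[neg_label])
        triplets.append((anchor, positive, negative))
    return triplets
```

For example, `sample_triplets(metadata['label'].tolist(), 1000)` returns row positions into `metadata` that a triplet dataset could map to video files.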

In [None]:
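# Move the model to its target device before constructing the optimizer,
# as the PyTorch optimizer documentation recommends
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Set up optimizer. Note: Adam's weight_decay adds an L2 penalty to the
# gradient; AdamW applies decoupled weight decay instead (see Exercise 2)
optimizer = optim.Adam(
    model.parameters(),
    lr=config['training']['learning_rate'],
    weight_decay=config['training']['weight_decay']
)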

# Set up learning rate scheduler
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=config['training']['lr_step_size'],
    gamma=config['training']['lr_gamma']
)

print(f"Optimizer: {type(optimizer).__name__}")
print(f"Learning rate: {config['training']['learning_rate']}")
print(f"Weight decay: {config['training']['weight_decay']}")
print(f"Scheduler: {type(scheduler).__name__}")
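`StepLR` multiplies the learning rate by `gamma` every `step_size` calls to `scheduler.step()`, giving the closed form $\text{lr}_e = \text{lr}_0 \cdot \gamma^{\lfloor e / \text{step\_size} \rfloor}$. The training functions in this notebook do not call `scheduler.step()`, so it has to be called once per epoch by whatever loop drives them. The sketch below checks the formula against PyTorch on a dummy parameter; the values (`lr=1e-3`, `step_size=10`, `gamma=0.1`) are illustrative, not the notebook's config:

```python
import torch
import torch.optim as optim

def step_lr(base_lr, epoch, step_size, gamma):
    """Closed form of StepLR: base_lr * gamma ** (epoch // step_size)."""
    return base_lr * gamma ** (epoch // step_size)

# A dummy parameter so we can build an optimizer without a model
param = torch.nn.Parameter(torch.zeros(1))
opt = optim.SGD([param], lr=1e-3)
sched = optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.1)

observed = []
for epoch in range(25):
    observed.append(opt.param_groups[0]["lr"])  # LR used during this epoch
    opt.step()     # optimizer.step() first (a no-op here: no gradients) ...
    sched.step()   # ... then scheduler.step(), once per epoch

expected = [step_lr(1e-3, e, 10, 0.1) for e in range(25)]
print(observed[0], observed[10], observed[20])
```

Calling `scheduler.step()` before `optimizer.step()` makes PyTorch emit a warning and effectively skips the first value of the schedule, which is why the order above matters.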

## 🎯 EXERCISE 2: Optimizer and Scheduler Analysis

**Task**: Analyze and compare different optimizers and schedulers.

**Requirements**:
1. Implement different optimizers (SGD, AdamW, RMSprop)
2. Create different learning rate schedulers (cosine, exponential, plateau)
3. Compare optimizer convergence on a simple loss function
4. Analyze the impact of different learning rates
5. Suggest optimal optimizer/scheduler combinations

**Your code here**:

In [None]:
# TODO: Write your optimizer and scheduler analysis code

# 1. Implement different optimizers
# Your code here...

# 2. Create different schedulers
# Your code here...

# 3. Compare convergence
# Your code here...

# 4. Analyze learning rate impact
# Your code here...

# 5. Suggest optimal combinations
# Your code here...

## 4. Training Loop Setup

Let's create the training loop with validation.

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

print(f"Using device: {device}")

# Training function
def train_epoch(model, train_loader, optimizer, loss_fn, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    num_batches = 0
    
    progress_bar = tqdm(train_loader, desc="Training")
    for batch in progress_bar:
        video1, video2, labels = batch
        video1, video2 = video1.to(device), video2.to(device)
        labels = labels.to(device)
        
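        # Reset gradients from the previous step
        optimizer.zero_grad()
        
        # Forward pass
        if isinstance(model, VideoSiameseNetwork):
            predictions = model(video1, video2)
            loss = loss_fn(predictions, labels.float())
        else:
            # Placeholder for triplet networks: pair batches carry no true negative.
            # Reusing video1 as the negative makes d(anchor, negative) = 0, so the
            # loss cannot push dissimilar videos apart. Replace with real
            # (anchor, positive, negative) triplets before training this branch.
            anchor, positive, negative = video1, video2, video1
            anchor_emb, pos_emb, neg_emb = model(anchor, positive, negative)
            loss = loss_fn(anchor_emb, pos_emb, neg_emb)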
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
        
        progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / num_batches

# Validation function
def validate_epoch(model, val_loader, loss_fn, device):
    """Validate for one epoch"""
    model.eval()
    total_loss = 0
    num_batches = 0
    
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation")
        for batch in progress_bar:
            video1, video2, labels = batch
            video1, video2 = video1.to(device), video2.to(device)
            labels = labels.to(device)
            
            # Forward pass
            if isinstance(model, VideoSiameseNetwork):
                predictions = model(video1, video2)
                loss = loss_fn(predictions, labels.float())
            else:
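                # Placeholder: negative == anchor here, so d(anchor, negative) = 0
                # and this loss is not a meaningful validation signal for a
                # triplet model until real triplets are supplied.
                anchor, positive, negative = video1, video2, video1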
                anchor_emb, pos_emb, neg_emb = model(anchor, positive, negative)
                loss = loss_fn(anchor_emb, pos_emb, neg_emb)
            
            total_loss += loss.item()
            num_batches += 1
            
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    return total_loss / num_batches

print("Training and validation functions created successfully!")
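`train_epoch` and `validate_epoch` each handle a single epoch; this notebook does not yet call them in a loop or call `scheduler.step()`. The following is a minimal sketch of an epoch driver that takes the two functions as arguments, steps the scheduler once per epoch, and records a history for plotting. The function name and the `num_epochs` argument are hypothetical (there is no such key in the notebook's config):

```python
def fit(model, train_loader, val_loader, optimizer, scheduler, loss_fn, device,
        num_epochs, train_fn, val_fn):
    """Run one training and one validation pass per epoch, step the LR
    scheduler once per epoch, and return the per-epoch history."""
    history = {"train_loss": [], "val_loss": [], "lr": []}
    for epoch in range(num_epochs):
        # Record the learning rate used during this epoch
        history["lr"].append(optimizer.param_groups[0]["lr"])
        train_loss = train_fn(model, train_loader, optimizer, loss_fn, device)
        val_loss = val_fn(model, val_loader, loss_fn, device)
        # StepLR counts epochs here, so step it after the epoch's optimizer updates
        scheduler.step()
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}: "
              f"train loss {train_loss:.4f}, val loss {val_loss:.4f}")
    return history
```

For example, `fit(model, train_loader, val_loader, optimizer, scheduler, loss_fn, device, num_epochs=10, train_fn=train_epoch, val_fn=validate_epoch)`. The returned lists can be plotted with matplotlib to compare training and validation loss.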

## 🎯 EXERCISE 3: Training Loop Optimization

**Task**: Optimize the training loop for better performance.

**Requirements**:
1. Implement gradient clipping to prevent exploding gradients
2. Add early stopping to prevent overfitting
3. Implement model checkpointing
4. Add training metrics tracking (accuracy, precision, recall)
5. Create a training visualization dashboard

**Your code here**:

In [None]:
# TODO: Write your training loop optimization code

# 1. Implement gradient clipping
# Your code here...

# 2. Add early stopping
# Your code here...

# 3. Implement checkpointing
# Your code here...

# 4. Add metrics tracking
# Your code here...

# 5. Create visualization dashboard
# Your code here...

## 5. Data Augmentation Setup

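Let's set up data augmentation strategies for better generalization. The transforms below are standard torchvision image transforms that act on one frame at a time. Each call to `RandomHorizontalFlip`, `RandomRotation`, `ColorJitter`, or `RandomResizedCrop` draws fresh random parameters, so applying them frame by frame would give each frame of a clip a different flip, crop, and color shift. In this notebook they are defined for illustration only: they are not passed to `VideoDataset`, and the 224×224 size is hard-coded rather than read from `config['data']['image_size']`.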

In [None]:
import torchvision.transforms as transforms
from PIL import Image

# Define augmentation transforms
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

print("Data augmentation transforms created:")
print("Training transforms:")
for transform in train_transforms.transforms:
    print(f"  - {type(transform).__name__}")
print("\nValidation transforms:")
for transform in val_transforms.transforms:
    print(f"  - {type(transform).__name__}")
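To keep a clip temporally coherent, random parameters should be drawn once per clip rather than once per frame. The following is a minimal sketch on a float tensor of shape `(T, C, H, W)` (an assumed layout) using only core PyTorch. It applies a random temporal crop, one spatial crop box, and one horizontal-flip decision shared by every frame. The function and its parameters are hypothetical, and color jitter would need the same one-draw-per-clip treatment (not shown):

```python
import torch
import torch.nn.functional as F

def augment_clip(clip, out_size=224, num_frames=None, flip_p=0.5,
                 min_scale=0.8, generator=None):
    """clip: float tensor (T, C, H, W). Every frame of the returned clip
    shares the same spatial crop and flip."""
    t, _, h, w = clip.shape
    g = generator
    # Temporal crop: a random contiguous window of num_frames frames
    if num_frames is not None and num_frames < t:
        start = int(torch.randint(0, t - num_frames + 1, (1,), generator=g))
        clip = clip[start:start + num_frames]
    # One spatial crop box for the whole clip
    scale = min_scale + (1 - min_scale) * float(torch.rand(1, generator=g))
    ch, cw = max(1, int(h * scale)), max(1, int(w * scale))
    top = int(torch.randint(0, h - ch + 1, (1,), generator=g))
    left = int(torch.randint(0, w - cw + 1, (1,), generator=g))
    clip = clip[:, :, top:top + ch, left:left + cw]
    # Resize all frames together: frames act as the batch dimension
    clip = F.interpolate(clip, size=(out_size, out_size),
                         mode="bilinear", align_corners=False)
    # One horizontal-flip decision for the whole clip (flip along width)
    if float(torch.rand(1, generator=g)) < flip_p:
        clip = torch.flip(clip, dims=[3])
    return clip
```

Passing a seeded `torch.Generator` makes the augmentation reproducible, which helps when visualizing augmented samples side by side.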

## 🎯 EXERCISE 4: Data Augmentation Analysis

**Task**: Analyze and design effective data augmentation strategies.

**Requirements**:
1. Implement video-specific augmentations (temporal cropping, frame dropping)
2. Create a function to visualize augmented samples
3. Compare the effectiveness of different augmentation strategies
4. Implement adaptive augmentation based on training progress
5. Suggest domain-specific augmentations for video similarity

**Your code here**:

In [None]:
# TODO: Write your data augmentation analysis code

# 1. Implement video-specific augmentations
# Your code here...

# 2. Create visualization function
# Your code here...

# 3. Compare augmentation strategies
# Your code here...

# 4. Implement adaptive augmentation
# Your code here...

# 5. Suggest domain-specific augmentations
# Your code here...

## 6. Training Configuration Summary

Let's summarize our training setup.

In [None]:
print("=== TRAINING SETUP SUMMARY ===")
print(f"\nData Configuration:")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Validation samples: {len(val_dataset)}")
print(f"  Batch size: {config['training']['batch_size']}")
print(f"  Max frames per video: {config['data']['max_frames']}")
print(f"  Image size: {config['data']['image_size']}")

print(f"\nModel Configuration:")
print(f"  Architecture: {config['model']['architecture']}")
print(f"  Feature dimension: {config['model']['feature_dim']}")
print(f"  Embedding dimension: {config['model']['embedding_dim']}")
print(f"  Total parameters: {sum(p.numel() for p in model.parameters()):,}")

print(f"\nTraining Configuration:")
print(f"  Optimizer: {type(optimizer).__name__}")
print(f"  Learning rate: {config['training']['learning_rate']}")
print(f"  Weight decay: {config['training']['weight_decay']}")
print(f"  Loss function: {type(loss_fn).__name__}")
print(f"  Margin: {config['training']['margin']}")

print(f"\nHardware Configuration:")
print(f"  Device: {device}")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 🎯 FINAL EXERCISE: Training Setup Report

**Task**: Write a comprehensive report on training setup optimization.

**Requirements**:
1. Analyze the current training setup and identify bottlenecks
2. Suggest improvements for data loading efficiency
3. Recommend optimal hyperparameters for different scenarios
4. Propose a training monitoring and debugging strategy
5. Design a scalable training pipeline for large datasets

**Your report here** (write in markdown):

In [None]:
# TODO: Write your training setup report
report = """
## Training Setup Report

### Current Setup Analysis:
[Your analysis here]

### Data Loading Improvements:
[Your suggestions here]

### Optimal Hyperparameters:
[Your recommendations here]

### Monitoring Strategy:
[Your proposal here]

### Scalable Pipeline:
[Your design here]
"""

print(report)

## Summary

In this notebook, we've set up:

✅ **Data Preparation**: Created train/validation splits and data loaders
✅ **Model Setup**: Initialized models and loss functions
✅ **Optimizer Configuration**: Set up optimizers and learning rate schedulers
✅ **Training Loop**: Created training and validation functions
✅ **Data Augmentation**: Defined image-level augmentation transforms for training and validation
✅ **4 Interactive Exercises + Final Report**: Hands-on training setup optimization

### Key Takeaways:

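1. **Data Loading Efficiency**: Video decoding is often expensive. If the GPU waits on the loader, throughput drops, so `num_workers`, `pin_memory`, and batch size are worth tuning
2. **Optimizer Choice**: The optimizer and its learning rate interact, so compare them empirically on your data rather than relying on defaults
3. **Learning Rate Scheduling**: Decaying the learning rate (e.g. with `StepLR`) allows large steps early and smaller, more stable steps later in training
4. **Data Augmentation**: Label-preserving random transformations reduce overfitting; for video, random parameters should generally be shared across all frames of a clip
5. **Monitoring**: Tracking training and validation loss per epoch exposes overfitting (validation loss rising while training loss falls) and divergence early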

### Next Steps:

In the next notebook, we'll learn about **Model Training** - how to actually train our models and monitor their performance.

---

**Questions to think about:**
- What would be the optimal batch size for your hardware?
- How would you handle class imbalance in the training data?
- What augmentation strategies would work best for your video domain?
- How would you implement distributed training for large datasets?
- What metrics would you track during training?