# Training Strategies and Optimization

This notebook covers training strategies, optimization techniques, and best practices for the Hyena-GLT framework.

## Topics Covered:
1. Training Configuration
2. Optimization Strategies
3. Learning Rate Scheduling
4. Training Loop Monitoring and Logging
5. Memory Optimization
6. Distributed Training
7. Best Practices Summary

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Add project root to path
project_root = Path().absolute().parent
sys.path.append(str(project_root))

from hyena_glt.config import TrainingConfig
from hyena_glt.optimization import OptimizationConfig

print("✅ Imports successful!")

## 1. Training Configuration

Let's start by defining a basic training configuration:

In [None]:
# Basic training configuration used in this tutorial.
training_config = TrainingConfig(
    batch_size=32,
    learning_rate=1e-4,
    num_epochs=10,
    warmup_steps=1000,
    weight_decay=0.01,
    gradient_clip_norm=1.0,
    save_every=1000,
    eval_every=500,
    mixed_precision=True,
    gradient_accumulation_steps=4,
)

# Echo a few key settings, one "Label: value" line per setting.
print("Training Configuration:")
for label, attr_name in (
    ("Batch Size", "batch_size"),
    ("Learning Rate", "learning_rate"),
    ("Epochs", "num_epochs"),
    ("Mixed Precision", "mixed_precision"),
    ("Gradient Accumulation Steps", "gradient_accumulation_steps"),
):
    print(f"{label}: {getattr(training_config, attr_name)}")

## 2. Optimization Strategies

Different optimization strategies for various training scenarios:

In [None]:
# Optimization configuration for different scenarios.
# All three use AdamW with beta1=0.9; the table below holds what varies:
# (learning_rate, weight_decay, beta2, scheduler_type, warmup_ratio)
_strategy_params = {
    "conservative": (1e-5, 0.01, 0.999, "cosine", 0.1),
    "aggressive": (5e-4, 0.1, 0.95, "linear", 0.05),
    "fine_tuning": (2e-5, 0.01, 0.999, "polynomial", 0.06),
}

optimization_configs = {
    strategy_name: OptimizationConfig(
        optimizer_type="adamw",
        learning_rate=lr,
        weight_decay=wd,
        beta1=0.9,
        beta2=b2,
        scheduler_type=sched,
        warmup_ratio=warmup,
    )
    for strategy_name, (lr, wd, b2, sched, warmup) in _strategy_params.items()
}

# Summarize each strategy as a small indented report.
for name, config in optimization_configs.items():
    print(
        "\n".join(
            (
                f"\n{name.upper()} Strategy:",
                f"  Learning Rate: {config.learning_rate}",
                f"  Weight Decay: {config.weight_decay}",
                f"  Scheduler: {config.scheduler_type}",
                f"  Warmup Ratio: {config.warmup_ratio}",
            )
        )
    )

## 3. Learning Rate Scheduling

Visualizing four learning rate schedules, each preceded by a linear warmup phase:

In [None]:
def simulate_lr_schedule(schedule_type, base_lr=1e-4, total_steps=10000, warmup_steps=1000, power=2):
    """Simulate a learning-rate schedule with linear warmup.

    The learning rate ramps linearly from 0 towards ``base_lr`` during the first
    ``warmup_steps`` steps, then decays over the remaining steps according to
    ``schedule_type``.

    Args:
        schedule_type: One of "cosine", "linear", "polynomial", or "constant".
        base_lr: Peak learning rate, reached at the end of warmup.
        total_steps: Number of steps to simulate.
        warmup_steps: Length of the linear warmup phase.
        power: Exponent of the "polynomial" decay (default 2, i.e. quadratic).

    Returns:
        ``(steps, lrs)`` where ``steps`` is ``np.arange(total_steps)`` and ``lrs``
        is a list holding the learning rate at each step.

    Raises:
        ValueError: If ``schedule_type`` is not one of the supported names.
    """
    # Each decay maps normalized post-warmup progress p in [0, 1) to an LR.
    decay_fns = {
        "cosine": lambda p: base_lr * 0.5 * (1 + np.cos(np.pi * p)),
        "linear": lambda p: base_lr * (1 - p),
        "polynomial": lambda p: base_lr * ((1 - p) ** power),
        "constant": lambda p: np.full_like(p, base_lr),
    }
    # Reject unknown names instead of silently falling back to a constant LR,
    # which would hide typos such as "cosin".
    if schedule_type not in decay_fns:
        raise ValueError(
            f"Unknown schedule_type {schedule_type!r}; expected one of {sorted(decay_fns)}"
        )

    steps = np.arange(total_steps)
    lrs = np.empty(steps.shape, dtype=float)

    in_warmup = steps < warmup_steps
    # Guarded so warmup_steps == 0 never evaluates a division by zero.
    if in_warmup.any():
        lrs[in_warmup] = base_lr * (steps[in_warmup] / warmup_steps)

    in_decay = ~in_warmup
    # Non-empty only when total_steps > warmup_steps, so the denominator is positive.
    if in_decay.any():
        progress = (steps[in_decay] - warmup_steps) / (total_steps - warmup_steps)
        lrs[in_decay] = decay_fns[schedule_type](progress)

    return steps, lrs.tolist()

# Plot each schedule in its own panel of a 2x2 grid.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.flatten()

schedules = ["cosine", "linear", "polynomial", "constant"]
colors = ["blue", "red", "green", "orange"]

for ax, schedule, color in zip(axes, schedules, colors, strict=False):
    steps, lrs = simulate_lr_schedule(schedule)
    ax.plot(steps, lrs, color=color, linewidth=2)
    ax.set_title(f"{schedule.title()} Schedule")
    ax.set_xlabel("Training Steps")
    ax.set_ylabel("Learning Rate")
    ax.grid(True, alpha=0.3)
    # Shared y-range just above the default base_lr (1e-4) so panels are comparable.
    ax.set_ylim(0, 1.1e-4)

plt.tight_layout()
plt.show()

print("📈 Learning rate schedules visualized!")

## 4. Training Loop Example

A training-metrics monitor, demonstrated on simulated loss, learning-rate, gradient-norm, and step-time values:

In [None]:
class TrainingMonitor:
    """Collect per-step training metrics and plot them on a 2x2 dashboard.

    Attributes:
        metrics: Maps each tracked metric name to the list of values logged for it.
        metric_steps: Maps each tracked metric name to the step number at which
            each corresponding value in ``metrics`` was logged.
        step: Number of ``log_metrics`` calls made so far.
    """

    def __init__(self):
        self.metrics = {
            "train_loss": [],
            "val_loss": [],
            "learning_rate": [],
            "gradient_norm": [],
            "step_time": [],
        }
        # Parallel to ``metrics``: the step at which each value was logged, so
        # metrics recorded at different frequencies (e.g. validation every few
        # steps) still line up on a shared x-axis.
        self.metric_steps = {key: [] for key in self.metrics}
        self.step = 0

    def log_metrics(self, **kwargs):
        """Record one step's metrics.

        Advances ``step`` by one, then appends every keyword argument whose name
        is a tracked metric, together with the current step. Keyword arguments
        that are not tracked metrics are ignored.
        """
        self.step += 1
        for key, value in kwargs.items():
            if key in self.metrics:
                self.metrics[key].append(value)
                self.metric_steps[key].append(self.step)

    def _plot_series(self, ax, key, **line_kwargs):
        """Plot one metric against its logged steps; return True if a line was drawn."""
        values = self.metrics[key]
        if not values:
            return False
        steps = self.metric_steps.get(key, [])
        if len(steps) != len(values):
            # Values were appended to ``metrics`` directly, bypassing
            # log_metrics; fall back to plotting against the list index.
            steps = range(len(values))
        ax.plot(steps, values, **line_kwargs)
        return True

    @staticmethod
    def _style_axis(ax, title, ylabel):
        """Apply the shared title, axis labels, and grid to one panel."""
        ax.set_title(title)
        ax.set_xlabel("Steps")
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)

    def plot_metrics(self):
        """Plot loss, learning rate, gradient norm, and step time in a 2x2 grid."""
        _, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Loss curves (train and validation share one panel).
        loss_ax = axes[0, 0]
        drew_train = self._plot_series(loss_ax, "train_loss", label="Train Loss", color="blue")
        drew_val = self._plot_series(loss_ax, "val_loss", label="Val Loss", color="red")
        self._style_axis(loss_ax, "Training Loss", "Loss")
        if drew_train or drew_val:
            # Calling legend() with no labeled lines makes matplotlib warn.
            loss_ax.legend()

        # Learning rate
        self._plot_series(axes[0, 1], "learning_rate", color="green")
        self._style_axis(axes[0, 1], "Learning Rate", "LR")

        # Gradient norm
        self._plot_series(axes[1, 0], "gradient_norm", color="purple")
        self._style_axis(axes[1, 0], "Gradient Norm", "Norm")

        # Step time
        self._plot_series(axes[1, 1], "step_time", color="orange")
        self._style_axis(axes[1, 1], "Step Time", "Time (s)")

        plt.tight_layout()
        plt.show()

# Example usage: feed the monitor 100 steps of synthetic metrics.
monitor = TrainingMonitor()

# Seeded generator so the simulated curves (and the plots) are reproducible.
rng = np.random.default_rng(seed=0)

num_sim_steps = 100
# 20 linear warmup steps to a peak of 1e-4, then cosine decay over the
# remaining 80 steps -- reuses the schedule simulator defined above.
_, lr_schedule = simulate_lr_schedule(
    "cosine", base_lr=1e-4, total_steps=num_sim_steps, warmup_steps=20
)

for i in range(num_sim_steps):
    # Exponentially decaying losses with small Gaussian noise.
    train_loss = 2.0 * np.exp(-i / 50) + 0.1 + rng.normal(0, 0.05)
    val_loss = 2.2 * np.exp(-i / 60) + 0.15 + rng.normal(0, 0.03)

    lr = lr_schedule[i]

    # Log-normally distributed gradient norm (always positive).
    grad_norm = rng.lognormal(0, 0.5)

    # Step time around 0.5 s, clipped so noise never yields a negative duration.
    step_time = max(0.0, 0.5 + rng.normal(0, 0.1))

    monitor.log_metrics(
        train_loss=train_loss,
        val_loss=val_loss,
        learning_rate=lr,
        gradient_norm=grad_norm,
        step_time=step_time,
    )

monitor.plot_metrics()
print("📊 Training metrics visualized!")

## 5. Memory Optimization Techniques

Strategies for efficient memory usage during training:

In [None]:
class MemoryOptimizer:
    """Rough GPU-memory estimation and tips for reducing memory usage."""

    @staticmethod
    def estimate_memory_usage(model, batch_size, sequence_length):
        """Estimate GPU memory for training ``model`` with AdamW in float32.

        This is a rough approximation, not a measurement.

        Args:
            model: Object exposing ``parameters()`` (items with ``numel()``) and
                ``config.hidden_size``; ``config.num_layers`` is used when present.
            batch_size: Number of sequences per batch.
            sequence_length: Number of tokens per sequence.

        Returns:
            Dict of estimated sizes in GB (1e9 bytes) with keys "parameters",
            "activations", "gradients", "optimizer", and "total".
        """
        bytes_per_value = 4  # float32

        # Count parameters
        param_count = sum(p.numel() for p in model.parameters())
        param_memory = param_count * bytes_per_value

        # Rough model: one (batch, seq, hidden) hidden-state tensor retained per
        # layer, so activation memory grows with depth. Configs without
        # ``num_layers`` are treated as a single layer.
        num_layers = getattr(model.config, "num_layers", 1)
        activation_memory = (
            batch_size * sequence_length * model.config.hidden_size * num_layers * bytes_per_value
        )
        gradient_memory = param_memory  # one gradient value per parameter
        optimizer_memory = param_memory * 2  # AdamW keeps momentum and variance

        total_memory = param_memory + activation_memory + gradient_memory + optimizer_memory

        return {
            "parameters": param_memory / 1e9,  # GB
            "activations": activation_memory / 1e9,
            "gradients": gradient_memory / 1e9,
            "optimizer": optimizer_memory / 1e9,
            "total": total_memory / 1e9,
        }

    @staticmethod
    def get_memory_tips(memory_gb):
        """Return memory-reduction tips for an estimated usage of ``memory_gb`` GB.

        Above 16 GB the basic tips are returned; above 24 GB the model- and
        optimizer-sharding tips are appended as well. Returns an empty list at or
        below 16 GB.
        """
        tips = []

        if memory_gb > 16:
            tips.extend([
                "Consider gradient checkpointing",
                "Use gradient accumulation",
                "Enable mixed precision training",
                "Reduce batch size",
            ])

        if memory_gb > 24:
            tips.extend([
                "Use model parallelism",
                "Consider ZeRO optimizer",
                "Use activation recomputation",
            ])

        return tips

# Example memory estimation: print the section header for the analysis.
print("Memory Optimization Analysis", "=" * 40, sep="\n")

# Mock model config for estimation
class MockConfig:
    """Minimal stand-in for a model config holding only the sizes used for estimation."""

    def __init__(self):
        # Vocabulary, depth, and width of the pretend model.
        self.vocab_size = 8000
        self.num_layers = 12
        self.hidden_size = 1024

class MockModel:
    """Lightweight stand-in for a model: exposes ``config`` and ``parameters()``."""

    class _MockParam:
        """Fake parameter tensor that only reports its element count."""

        def __init__(self, numel):
            self._numel = numel

        def numel(self):
            return self._numel

    def __init__(self, config):
        self.config = config
        hidden, vocab = config.hidden_size, config.vocab_size
        # Rough count: token embedding, 4*hidden*hidden weights per layer,
        # and the output projection back to the vocabulary.
        embedding_params = vocab * hidden
        layer_params = config.num_layers * hidden * hidden * 4
        output_params = hidden * vocab
        self._param_count = embedding_params + layer_params + output_params

    def parameters(self):
        """Return a single fake parameter carrying the whole estimated count."""
        return [self._MockParam(self._param_count)]

config = MockConfig()
model = MockModel(config)
# MemoryOptimizer exposes only static methods; the instance is just a handle.
optimizer = MemoryOptimizer()

batch_sizes = [8, 16, 32, 64]
sequence_length = 2048

# (printed label, key in the estimate dict), in display order.
_memory_rows = (
    ("Parameters", "parameters"),
    ("Activations", "activations"),
    ("Gradients", "gradients"),
    ("Optimizer", "optimizer"),
    ("Total", "total"),
)

for batch_size in batch_sizes:
    memory = optimizer.estimate_memory_usage(model, batch_size, sequence_length)

    print(f"\nBatch Size: {batch_size}")
    for label, key in _memory_rows:
        print(f"  {label}: {memory[key]:.2f} GB")

    tips = optimizer.get_memory_tips(memory["total"])
    if tips:
        print(f"  💡 Tips: {', '.join(tips)}")

## 6. Distributed Training Setup

Configuration for multi-GPU and multi-node training:

In [None]:
class DistributedTrainingConfig:
    """Catalog of distributed training strategies with a simple recommender."""

    def __init__(self):
        self.strategies = {
            "data_parallel": {
                "description": "Replicate model on each GPU, split batch",
                "pros": ["Simple to implement", "Good for large batches"],
                "cons": ["Memory overhead", "Communication overhead"],
                "best_for": "Models that fit on single GPU"
            },
            "model_parallel": {
                "description": "Split model layers across GPUs",
                "pros": ["Handles large models", "Memory efficient"],
                "cons": ["Pipeline bubbles", "Complex implementation"],
                "best_for": "Very large models"
            },
            "pipeline_parallel": {
                "description": "Pipeline model layers with micro-batches",
                "pros": ["Good throughput", "Handles large models"],
                "cons": ["Latency overhead", "Memory for pipeline"],
                "best_for": "Large models with many layers"
            },
            "zero": {
                "description": "ZeRO optimizer state partitioning",
                "pros": ["Memory efficient", "Scales well"],
                "cons": ["Communication overhead", "Implementation complexity"],
                "best_for": "Large scale training"
            }
        }

    def recommend_strategy(self, model_size_gb, num_gpus, gpu_memory_gb, *,
                           headroom_ratio=0.5, zero_min_gpus=8):
        """Recommend distributed training strategies.

        Args:
            model_size_gb: Model size in GB.
            num_gpus: Number of available GPUs.
            gpu_memory_gb: Memory per GPU in GB.
            headroom_ratio: A model smaller than ``gpu_memory_gb * headroom_ratio``
                counts as fitting comfortably on one GPU.
            zero_min_gpus: GPU count from which ZeRO is recommended for scale.

        Returns:
            List of ``(strategy_name, reason)`` tuples; each strategy appears at
            most once. May be empty for a single GPU whose memory headroom is
            tight, since no distributed strategy applies there.
        """
        recommendations = []

        fits_comfortably = model_size_gb < gpu_memory_gb * headroom_ratio
        too_large = model_size_gb > gpu_memory_gb

        if fits_comfortably:
            recommendations.append(("data_parallel", "Model fits comfortably on single GPU"))
        elif not too_large and num_gpus > 1:
            # The model fits but leaves little room for activations and
            # optimizer state; partitioning that state across GPUs frees memory.
            recommendations.append(
                ("zero", "Model fits on single GPU but with little memory headroom")
            )

        if too_large:
            recommendations.append(("model_parallel", "Model too large for single GPU"))
            recommendations.append(("pipeline_parallel", "Alternative for large models"))

        already_zero = any(name == "zero" for name, _ in recommendations)
        if num_gpus >= zero_min_gpus and not already_zero:
            recommendations.append(("zero", "Large scale training with many GPUs"))

        return recommendations

    def print_strategies(self):
        """Print all available strategies."""
        print("Distributed Training Strategies")
        print("=" * 50)

        for name, info in self.strategies.items():
            print(f"\n{name.upper().replace('_', ' ')}:")
            print(f"  Description: {info['description']}")
            print(f"  Pros: {', '.join(info['pros'])}")
            print(f"  Cons: {', '.join(info['cons'])}")
            print(f"  Best for: {info['best_for']}")

# Example usage: list every strategy, then show recommendations for a few setups.
dist_config = DistributedTrainingConfig()
dist_config.print_strategies()

# (label, model size in GB, number of GPUs, memory per GPU in GB)
scenarios = [
    ("Small model", 2, 4, 24),
    ("Medium model", 8, 8, 24),
    ("Large model", 30, 8, 24),
]

print("\n\nRecommendations for Different Scenarios:", "=" * 50, sep="\n")

for name, model_size, num_gpus, gpu_memory in scenarios:
    print(f"\n{name}: {model_size}GB model, {num_gpus} GPUs ({gpu_memory}GB each)")
    for strategy, reason in dist_config.recommend_strategy(model_size, num_gpus, gpu_memory):
        print(f"  ✅ {strategy}: {reason}")

## 7. Best Practices Summary

Key takeaways for effective training:

In [None]:
# Key takeaways, grouped by topic.
training_best_practices = {
    "Learning Rate": [
        "Start with learning rate finder",
        "Use warmup for large batch sizes",
        "Cosine annealing often works well",
        "Monitor gradient norm for instability"
    ],
    "Batch Size": [
        "Larger batches → more stable gradients",
        "Use gradient accumulation for memory constraints",
        "Scale learning rate with batch size",
        "Monitor training dynamics"
    ],
    "Regularization": [
        "Weight decay for parameter regularization",
        "Dropout for overfitting prevention",
        "Gradient clipping for stability",
        "Early stopping based on validation"
    ],
    "Memory Optimization": [
        "Use mixed precision (FP16/BF16)",
        "Gradient checkpointing for large models",
        "Optimize data loading pipeline",
        "Monitor GPU memory usage"
    ],
    "Monitoring": [
        "Track loss curves",
        "Monitor learning rate schedule",
        "Watch gradient norms",
        "Log training metrics regularly"
    ]
}

# Assemble the whole report first, then emit it with a single print.
report_lines = ["🎯 Training Best Practices", "=" * 50]
for category, practices in training_best_practices.items():
    report_lines.append(f"\n{category}:")
    report_lines.extend(f"  • {practice}" for practice in practices)
report_lines.append("\n🚀 Ready to train your Hyena-GLT model!")

print("\n".join(report_lines))