# PyTorch Lightning Architecture Fundamentals

**File Location:** `notebooks/01_lightning_fundamentals/01_pl_architecture.ipynb`

## Introduction

This notebook introduces the core components of PyTorch Lightning: LightningModule, LightningDataModule, and Trainer. You'll learn how these components work together and understand the basic logging mechanism with `self.log()`.

## Core Lightning Components

### LightningModule - The Model Wrapper

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.optim import Adam
from torchmetrics import Accuracy

class SimpleMLP(pl.LightningModule):
    """Two-layer MLP classifier wrapped as a LightningModule.

    The network is Linear -> ReLU -> Dropout(0.2) -> Linear and is trained
    with cross-entropy loss and an Adam optimizer.

    Args:
        input_size: Number of input features expected by the first Linear layer.
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits / target classes.
        lr: Learning rate handed to Adam in ``configure_optimizers``.
    """

    def __init__(self, input_size=784, hidden_size=128, num_classes=10, lr=1e-3):
        super().__init__()
        # Store the constructor arguments so they are reachable via self.hparams
        self.save_hyperparameters()

        # Define the network
        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_size, num_classes)
        )

        # One metric object per stage so training and validation state never mix
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return the raw class logits produced by the network for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one training batch and log metrics.

        Args:
            batch: An ``(x, y)`` pair of inputs and integer class targets.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss tensor that Lightning backpropagates.
        """
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Log metrics - this is the key Lightning feature!
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

        # Update the metric state first, then log the metric *object* rather
        # than the tensor it returns. Per the TorchMetrics/Lightning
        # integration, logging the object lets Lightning compute the epoch
        # value from the accumulated state and reset it afterwards; logging
        # the returned tensor leaves the state accumulating across epochs.
        self.train_acc(y_hat, y)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one validation batch and log metrics.

        Args:
            batch: An ``(x, y)`` pair of inputs and integer class targets.
            batch_idx: Index of the batch within the validation run (unused).

        Returns:
            The validation loss tensor for this batch.
        """
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        # Log validation metrics
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)

        # Same pattern as in training_step: update, then log the metric object
        self.val_acc(y_hat, y)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the saved ``lr``."""
        return Adam(self.parameters(), lr=self.hparams.lr)

# Smoke test: build the model with its default hyperparameters, then show
# the saved hparams followed by the module tree.
model = SimpleMLP()
for report in (
    f"Model created with hparams: {model.hparams}",
    f"Model architecture:\n{model}",
):
    print(report)
```

### LightningDataModule - Data Pipeline Management

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class SyntheticDataModule(pl.LightningDataModule):
    """DataModule that serves randomly generated classification data.

    Features are drawn from a standard normal distribution and labels are
    uniform random integers in ``[0, num_classes)``.

    Args:
        num_samples: Size of the training split. The validation split has
            ``num_samples // 5`` samples and the test split ``num_samples // 10``.
        input_size: Number of features per sample.
        num_classes: Exclusive upper bound for the integer labels.
        batch_size: Batch size used by every dataloader.
    """

    # Distinct seeds per split. Re-seeding one shared stream with the same
    # value before every setup() call would make each split start from the
    # same random numbers (e.g. setup("test") after setup("fit")).
    _TRAIN_SEED = 42
    _VAL_SEED = 43
    _TEST_SEED = 44

    def __init__(self, num_samples=1000, input_size=784, num_classes=10, batch_size=32):
        super().__init__()
        self.save_hyperparameters()

    def _make_split(self, size, seed):
        """Return a ``(features, labels)`` pair of ``size`` samples for ``seed``."""
        # A private generator keeps the data reproducible without touching
        # the global RNG that weight init, dropout and shuffling rely on.
        generator = torch.Generator().manual_seed(seed)
        features = torch.randn(size, self.hparams.input_size, generator=generator)
        labels = torch.randint(
            0, self.hparams.num_classes, (size,), generator=generator
        )
        return features, labels

    def setup(self, stage=None):
        """Create the tensors for the requested stage.

        Args:
            stage: ``"fit"`` builds the train and validation splits,
                ``"test"`` builds the test split, ``None`` builds all of them.
        """
        if stage == "fit" or stage is None:
            # Training data
            self.train_x, self.train_y = self._make_split(
                self.hparams.num_samples, self._TRAIN_SEED
            )

            # Validation data (20% of training size)
            val_size = self.hparams.num_samples // 5
            self.val_x, self.val_y = self._make_split(val_size, self._VAL_SEED)

        if stage == "test" or stage is None:
            # Test data (10% of training size)
            test_size = self.hparams.num_samples // 10
            self.test_x, self.test_y = self._make_split(test_size, self._TEST_SEED)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        dataset = TensorDataset(self.train_x, self.train_y)
        return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled DataLoader over the validation split."""
        dataset = TensorDataset(self.val_x, self.val_y)
        return DataLoader(dataset, batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        """Return an unshuffled DataLoader over the test split."""
        dataset = TensorDataset(self.test_x, self.test_y)
        return DataLoader(dataset, batch_size=self.hparams.batch_size)

# Smoke test: build the "fit" splits and report their sizes.
dm = SyntheticDataModule(num_samples=1000, batch_size=64)
dm.setup("fit")

for report in (
    f"Training set size: {len(dm.train_x)}",
    f"Validation set size: {len(dm.val_x)}",
    f"Batch size: {dm.hparams.batch_size}",
):
    print(report)

# Pull a single batch from the training loader to confirm tensor shapes.
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))
print(f"Batch shape: {batch[0].shape}, Labels shape: {batch[1].shape}")
```

### Trainer - The Training Engine

```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoard logger writing under logs/pl_architecture_demo
logger = TensorBoardLogger(save_dir="logs", name="pl_architecture_demo")

# Trainer options collected in one place, then unpacked into the Trainer
trainer_options = {
    "max_epochs": 3,
    "logger": logger,
    "enable_checkpointing": True,
    "log_every_n_steps": 10,
    "enable_progress_bar": True,
    "enable_model_summary": True,
}
trainer = pl.Trainer(**trainer_options)

# Fresh model and data for the run
model = SimpleMLP(lr=1e-3)
dm = SyntheticDataModule(num_samples=1000, batch_size=64)

# Run the fit loop; the datamodule supplies the train/val dataloaders
print("Starting training...")
trainer.fit(model, datamodule=dm)

print(f"Training completed! Logs saved to: {logger.log_dir}")
```

### Understanding self.log() Parameters

```python
# Demonstration of different logging options
class LoggingDemo(pl.LightningModule):
    """Minimal regression module that showcases the ``self.log()`` options.

    A single ``nn.Linear(10, 1)`` layer is trained with MSE loss; the same
    loss value is logged under several names, each with a different
    combination of ``self.log()`` flags.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        """Compute the MSE loss for one batch and log it in several ways.

        Args:
            batch: An ``(x, y)`` pair; ``x`` must have 10 features per sample.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss tensor that Lightning backpropagates.
        """
        x, y = batch
        y_hat = self.model(x)

        # Targets may arrive as integers (SyntheticDataModule builds them with
        # torch.randint). mse_loss needs the target in the prediction dtype,
        # otherwise backward fails with a Long/Float dtype mismatch.
        target = y.view(-1, 1).to(y_hat.dtype)
        loss = F.mse_loss(y_hat, target)

        # Different logging configurations
        self.log('loss_step_only', loss, on_step=True, on_epoch=False)  # Only log per step
        self.log('loss_epoch_only', loss, on_step=False, on_epoch=True)  # Only log per epoch
        self.log('loss_both', loss, on_step=True, on_epoch=True)  # Log both
        self.log('loss_progbar', loss, prog_bar=True)  # Show in progress bar
        self.log('loss_logger_only', loss, logger=True, prog_bar=False)  # Only to logger

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with a fixed lr of 1e-3."""
        return Adam(self.parameters(), lr=1e-3)

# Data sized for LoggingDemo: 10 input features, labels drawn from a single class
demo_dm = SyntheticDataModule(num_samples=500, input_size=10, num_classes=1, batch_size=32)

# One-epoch trainer with logger and checkpointing turned off
demo_trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
demo_model = LoggingDemo()
# NOTE(review): demo_trainer.fit(demo_model, demo_dm) is never called here, so
# the objects above are only constructed and nothing is actually logged.

for description in (
    "Demonstrating different logging options:",
    "- loss_step_only: logged every step",
    "- loss_epoch_only: logged at end of epoch",
    "- loss_both: logged both step and epoch",
    "- loss_progbar: shown in progress bar",
    "- loss_logger_only: only sent to logger (not progress bar)",
):
    print(description)
```

### Model Inspection and Hyperparameters

```python
# Report basic facts about the model that was just trained
print("=== Model Inspection ===")
for fact in (
    f"Model type: {type(model).__name__}",
    f"Hyperparameters: {model.hparams}",
    f"Current epoch: {model.current_epoch}",
    f"Global step: {model.global_step}",
):
    print(fact)

# Dump whatever the trainer recorded through self.log(), four decimals each
print("\n=== Training Metrics ===")
if hasattr(trainer, 'logged_metrics'):
    for key, value in trainer.logged_metrics.items():
        print(f"{key}: {value:.4f}")

# Layer-by-layer summary, descending two levels into the module tree
print("\n=== Model Summary ===")
from pytorch_lightning.utilities.model_summary import ModelSummary
summary = ModelSummary(model, max_depth=2)
print(summary)
```

### Architecture Benefits Demo

```python
# Print a side-by-side comparison of Lightning and hand-written PyTorch
print("=== Lightning vs Pure PyTorch ===")

# What the Trainer-based workflow above provided out of the box
print("Lightning approach:")
for benefit in (
    "✓ Automatic GPU/CPU handling",
    "✓ Automatic logging and metrics",
    "✓ Built-in progress bars",
    "✓ Automatic optimization steps",
    "✓ Easy experiment tracking",
    "✓ Configurable training loops",
):
    print(benefit)

# What a plain PyTorch script would have to implement by hand
print("\nPure PyTorch equivalent would need:")
for chore in (
    "- Manual device handling (cuda/cpu)",
    "- Manual loss tracking and averaging",
    "- Manual progress bar implementation",
    "- Manual optimization loop",
    "- Manual validation loop",
    "- Manual metric computation",
    "- Manual logging setup",
):
    print(chore)

# Closing rough size comparison
for closing_line in (
    f"\nLightning training call: trainer.fit(model, datamodule)",
    f"Lines of Lightning code for training: ~50 lines",
    f"Equivalent PyTorch code: ~200+ lines",
):
    print(closing_line)
```

## Summary

In this notebook, you learned:

1. **LightningModule**: The core model wrapper that handles training/validation steps and optimizer configuration
2. **LightningDataModule**: Organized data pipeline management with setup() and dataloader methods
3. **Trainer**: The main training engine that orchestrates the entire training process
4. **self.log()**: Flexible logging system with options for step/epoch logging and progress bar display
5. **Architecture Benefits**: How Lightning reduces boilerplate code while maintaining flexibility

Key takeaways:
- Lightning separates concerns: model logic, data handling, and training configuration
- `self.log()` sends values to the logger and progress bar and, with `on_epoch=True`, automatically averages them over the epoch
- The Trainer handles all the training loop complexity
- Hyperparameters passed to `__init__` are saved by calling `self.save_hyperparameters()` and are then accessible via `self.hparams`
- This architecture makes experiments reproducible and code more maintainable

Next notebook: We'll explore Trainer's debugging and development features.