# TorchMetrics and Advanced Logging

**File Location:** `notebooks/02_datamodules_and_metrics/05_torchmetrics_logging.ipynb`

## Introduction

This notebook covers TorchMetrics integration with PyTorch Lightning for comprehensive model evaluation. Learn to use built-in metrics, create custom metrics, and implement advanced logging strategies for different ML tasks.

## TorchMetrics Fundamentals

### Basic Metric Usage

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics import Accuracy, Precision, Recall, F1Score, AUROC, MeanSquaredError
from torchmetrics import MetricCollection, ConfusionMatrix
from torchmetrics.functional import accuracy, precision, recall
import numpy as np

# Simple model for demonstration
class MetricsDemo(pl.LightningModule):
    """Minimal classifier showing how TorchMetrics objects plug into Lightning.

    A two-layer MLP maps 20 input features to ``num_classes`` logits. Training
    tracks (micro) accuracy; validation tracks micro, macro and weighted
    accuracy side by side so the averaging strategies can be compared.

    Args:
        num_classes: Number of target classes.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.save_hyperparameters()

        hidden = 64
        self.model = nn.Sequential(
            nn.Linear(20, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_classes),
        )

        common = dict(task="multiclass", num_classes=num_classes)
        # One metric object per stage: each keeps its own running state.
        self.train_acc = Accuracy(**common)
        self.val_acc = Accuracy(**common)
        # The same metric with different class-averaging strategies.
        self.val_acc_macro = Accuracy(**common, average="macro")
        self.val_acc_weighted = Accuracy(**common, average="weighted")

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)

        # Feed hard labels; passing the metric object to self.log lets
        # Lightning compute and reset it at the end of the epoch.
        self.train_acc(logits.argmax(dim=1), targets)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, targets)
        predicted = logits.argmax(dim=1)

        for metric in (self.val_acc, self.val_acc_macro, self.val_acc_weighted):
            metric(predicted, targets)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        for name, metric in (('val_acc_macro', self.val_acc_macro),
                             ('val_acc_weighted', self.val_acc_weighted)):
            self.log(name, metric, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Create synthetic data
def create_multiclass_data(num_samples=1000, num_features=20, num_classes=5, seed=42):
    """Generate a synthetic multiclass dataset with learnable structure.

    Features are standard-normal. A random linear projection turns each row
    into a scalar score. The score range is then cut into ``num_classes``
    equal-width bins, and each sample's bin index is its label. Because the
    score is a sum of Gaussians, the outer bins (lowest/highest classes) are
    rarer than the middle ones.

    Args:
        num_samples: Number of rows to generate.
        num_features: Number of input features per row.
        num_classes: Number of target classes; labels are ``0..num_classes-1``.
        seed: Passed to ``torch.manual_seed``. Note this reseeds the *global*
            torch RNG, which also makes later model initialization reproducible.

    Returns:
        ``(X, y)``: ``X`` has shape ``(num_samples, num_features)``; ``y`` has
        shape ``(num_samples,)`` and dtype int64.
    """
    torch.manual_seed(seed)
    X = torch.randn(num_samples, num_features)
    weights = torch.randn(num_features)
    scores = X @ weights
    # Use num_classes bins, not num_classes - 1. With num_classes - 1 bins only
    # the single maximum score reached the last class. Clamping the span avoids
    # division by zero when all scores are equal.
    span = torch.clamp(scores.max() - scores.min(), min=1e-12)
    y = torch.div(scores - scores.min(), span / num_classes, rounding_mode='floor').long()
    # The maximum score lands exactly on index num_classes; fold it into the last class.
    y = torch.clamp(y, 0, num_classes - 1)
    return X, y

X, y = create_multiclass_data()
from torch.utils.data import TensorDataset, DataLoader, random_split

dataset = TensorDataset(X, y)
# Hold out 20% of the samples for validation so validation metrics measure
# generalization instead of re-scoring the training data. A seeded generator
# keeps the split reproducible.
num_val = len(dataset) // 5
train_subset, val_subset = random_split(
    dataset,
    [len(dataset) - num_val, num_val],
    generator=torch.Generator().manual_seed(42),
)
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)

print("✓ Basic metrics demo setup completed")
```

### Classification Metrics Suite

```python
class ClassificationMetrics(pl.LightningModule):
    """Multiclass classifier with a full suite of TorchMetrics classification metrics.

    Train/val/test each get their own clone of accuracy, macro
    precision/recall/F1 and AUROC. Validation additionally tracks per-class
    precision/recall and a confusion matrix. The confusion matrix is rendered
    to a logger that supports ``add_figure`` (e.g. TensorBoard).

    Args:
        num_classes: Number of target classes.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes)
        )

        # Comprehensive classification metrics. They are fed class
        # probabilities of shape (N, C): AUROC requires scores, and the
        # threshold-based metrics take the argmax internally.
        metrics = MetricCollection({
            'accuracy': Accuracy(task="multiclass", num_classes=num_classes),
            'precision': Precision(task="multiclass", num_classes=num_classes, average="macro"),
            'recall': Recall(task="multiclass", num_classes=num_classes, average="macro"),
            'f1': F1Score(task="multiclass", num_classes=num_classes, average="macro"),
            'auroc': AUROC(task="multiclass", num_classes=num_classes)
        })

        # Separate collections (and therefore separate states) for train/val/test
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

        # Per-class metrics (average=None returns one value per class)
        self.val_precision_per_class = Precision(
            task="multiclass", num_classes=num_classes, average=None
        )
        self.val_recall_per_class = Recall(
            task="multiclass", num_classes=num_classes, average=None
        )

        # Confusion matrix
        self.val_confmat = ConfusionMatrix(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        probs = F.softmax(logits, dim=1)
        self.train_metrics(probs, y)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        # Passing the collection lets Lightning compute/reset it at epoch end
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        probs = F.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)

        self.val_metrics(probs, y)
        # These are read manually in on_validation_epoch_end
        self.val_precision_per_class(preds, y)
        self.val_recall_per_class(preds, y)
        self.val_confmat(preds, y)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        probs = F.softmax(logits, dim=1)
        self.test_metrics(probs, y)

        self.log('test_loss', loss, on_epoch=True)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True)

        return loss

    def on_validation_epoch_end(self):
        precision_per_class = self.val_precision_per_class.compute()
        recall_per_class = self.val_recall_per_class.compute()
        confmat = self.val_confmat.compute()
        # These metric objects are never passed to self.log, so Lightning does
        # not reset them. Without an explicit reset their state would keep
        # accumulating across epochs (including the sanity-check run).
        self.val_precision_per_class.reset()
        self.val_recall_per_class.reset()
        self.val_confmat.reset()

        for i in range(self.hparams.num_classes):
            self.log(f'val_precision_class_{i}', precision_per_class[i])
            self.log(f'val_recall_class_{i}', recall_per_class[i])

        self._log_confusion_matrix(confmat)

    def _log_confusion_matrix(self, confmat):
        """Render ``confmat`` as a heatmap and send it to the attached logger.

        Does nothing when no logger is attached or its ``experiment`` lacks
        ``add_figure`` (a TensorBoard ``SummaryWriter`` method). Prints a notice
        when matplotlib/seaborn are not installed.
        """
        experiment = getattr(self.logger, 'experiment', None) if self.logger else None
        if experiment is None or not hasattr(experiment, 'add_figure'):
            return

        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
        except ImportError:
            print("matplotlib/seaborn not available for confusion matrix plot")
            return

        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            sns.heatmap(confmat.cpu().numpy(), annot=True, fmt='d', ax=ax)
            ax.set_xlabel('Predicted')
            ax.set_ylabel('True')
            ax.set_title('Confusion Matrix')
            experiment.add_figure('val_confusion_matrix', fig, self.current_epoch)
        finally:
            # Always free the figure, even if plotting/logging raised
            plt.close(fig)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Test comprehensive classification metrics
model = ClassificationMetrics(num_classes=5)

# Short run to watch the metrics update: 3 epochs with capped batch counts,
# and no checkpointing or logger.
trainer = pl.Trainer(
    max_epochs=3,
    limit_train_batches=20,
    limit_val_batches=10,
    logger=False,
    enable_checkpointing=False,
)

print("Training with comprehensive classification metrics...")
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("✓ Classification metrics demo completed")
```

### Regression Metrics

```python
class RegressionMetrics(pl.LightningModule):
    """MLP regressor that tracks a collection of TorchMetrics regression metrics.

    Train/val/test each get their own clone of MSE, RMSE, MAE, R² and explained
    variance. Validation additionally tracks MAPE.

    Args:
        output_dim: Number of values predicted per sample (final layer width).
    """

    def __init__(self, output_dim=1):
        super().__init__()
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, output_dim)
        )

        # Regression metrics
        from torchmetrics import MeanSquaredError, MeanAbsoluteError, R2Score
        from torchmetrics import MeanAbsolutePercentageError, ExplainedVariance

        regression_metrics = MetricCollection({
            'mse': MeanSquaredError(),
            'rmse': MeanSquaredError(squared=False),  # RMSE
            'mae': MeanAbsoluteError(),
            'r2': R2Score(),
            'explained_var': ExplainedVariance()
        })

        self.train_metrics = regression_metrics.clone(prefix='train_')
        self.val_metrics = regression_metrics.clone(prefix='val_')
        self.test_metrics = regression_metrics.clone(prefix='test_')

        # MAPE divides each absolute error by |target|, so it is only
        # informative when targets stay away from zero. The synthetic targets
        # in this notebook are centred on 0, so expect large values; it is
        # included to illustrate the API.
        self.val_mape = MeanAbsolutePercentageError()

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def _match_target_shape(y):
        """Return ``y`` as a column ``(N, 1)`` when it arrives as a flat ``(N,)`` vector."""
        return y.unsqueeze(1) if y.dim() == 1 else y

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y = self._match_target_shape(y)

        loss = F.mse_loss(y_hat, y)
        self.train_metrics(y_hat, y)

        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log_dict(self.train_metrics, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y = self._match_target_shape(y)

        loss = F.mse_loss(y_hat, y)
        self.val_metrics(y_hat, y)
        # Use the raw values. Taking abs() of predictions and targets would
        # hide sign errors (predicting -2 for a target of 2 would score 0%).
        self.val_mape(y_hat, y)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.val_metrics, on_step=False, on_epoch=True)
        self.log('val_mape', self.val_mape, on_epoch=True)

        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        y = self._match_target_shape(y)

        loss = F.mse_loss(y_hat, y)
        self.test_metrics(y_hat, y)

        self.log('test_loss', loss, on_epoch=True)
        self.log_dict(self.test_metrics, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Create regression data
def create_regression_data(num_samples=1000, num_features=20, noise=0.1):
    """Build a noisy linear-regression dataset.

    Reseeds the global torch RNG with 42. It then draws standard-normal
    features and a random weight vector scaled by 0.5. Gaussian noise with
    standard deviation ``noise`` is added to the linear response.

    Returns:
        ``(X, y)`` with ``X`` of shape ``(num_samples, num_features)`` and
        ``y`` of shape ``(num_samples,)``.
    """
    torch.manual_seed(42)
    features = torch.randn(num_samples, num_features)
    coef = 0.5 * torch.randn(num_features)
    clean = torch.matmul(features, coef)
    jitter = noise * torch.randn(num_samples)
    return features, clean + jitter

from torch.utils.data import random_split

X_reg, y_reg = create_regression_data()
reg_dataset = TensorDataset(X_reg, y_reg)
# Hold out 20% for validation so validation metrics are not computed on the
# training samples themselves. The seeded generator keeps the split reproducible.
num_reg_val = len(reg_dataset) // 5
reg_train_subset, reg_val_subset = random_split(
    reg_dataset,
    [len(reg_dataset) - num_reg_val, num_reg_val],
    generator=torch.Generator().manual_seed(42),
)
reg_train_loader = DataLoader(reg_train_subset, batch_size=32, shuffle=True)
reg_val_loader = DataLoader(reg_val_subset, batch_size=32, shuffle=False)

# Test regression metrics
reg_model = RegressionMetrics()

trainer = pl.Trainer(
    max_epochs=3,
    enable_checkpointing=False,
    logger=False,
    limit_train_batches=20,
    limit_val_batches=10
)

print("Training with regression metrics...")
trainer.fit(reg_model, reg_train_loader, reg_val_loader)
print("✓ Regression metrics demo completed")
```

## Custom Metrics

### Building Custom Metrics

```python
from torchmetrics import Metric
from typing import Any

class TopKCategoricalAccuracy(Metric):
    """Custom metric: Top-K categorical accuracy.

    A sample counts as correct when its true class is among the ``k``
    highest-scoring classes. TorchMetrics also ships this as
    ``Accuracy(task="multiclass", num_classes=C, top_k=k)``; this class shows
    how to build it by hand.

    Args:
        k: Number of top-scoring classes to consider. If it exceeds the
            number of classes in ``preds``, all classes are considered.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    # update() only adds to running counters, so forward() can compute the
    # batch value from a fresh state instead of calling update() twice.
    full_state_update = False
    is_differentiable = False
    higher_is_better = True

    def __init__(self, k: int = 3, **kwargs):
        super().__init__(**kwargs)
        self.k = k

        # Summed across processes when synced in distributed runs
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate top-k hits for one batch.

        Args:
            preds: Class scores (logits or probabilities); top-k is taken over
                dim 1, presumably shaped ``(batch, num_classes)``.
            target: Integer class indices, one per row of ``preds``.
        """
        # torch.topk raises if k exceeds the size of dim 1
        k = min(self.k, preds.size(1))
        top_k_preds = preds.topk(k, dim=1).indices

        # Broadcast (batch, 1) targets against the (batch, k) top-k indices
        hits = (top_k_preds == target.unsqueeze(1)).any(dim=1)

        self.correct += hits.sum()
        self.total += target.size(0)

    def compute(self):
        """Return the fraction of samples whose target was in the top-k."""
        return self.correct.float() / self.total

class MeanPredictionError(Metric):
    """Custom metric: mean signed prediction error, ``mean(preds - target)``.

    Positive values mean the model over-predicts on average and negative
    values mean it under-predicts. Errors of opposite sign cancel out, so pair
    this with an absolute-error metric.

    Args:
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    # update() only adds to running sums, so forward() need not re-run it
    # on the accumulated state.
    full_state_update = False

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Summed across processes when synced in distributed runs
        self.add_state("sum_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Add the signed errors of one batch (element-wise, any matching shape)."""
        error = preds - target
        self.sum_error += error.sum()
        self.total += target.numel()

    def compute(self):
        """Return the mean signed error over all elements seen since the last reset."""
        return self.sum_error / self.total

class PercentileError(Metric):
    """Custom metric: absolute error at specific percentiles.

    Keeps every absolute error seen since the last reset, so memory grows with
    the number of evaluated elements. It reports the requested percentiles as
    a dict ``{"error_p<p>": value}``.

    Args:
        percentiles: Percentiles in ``[0, 100]`` to report. Defaults to
            ``(50, 90, 95)``.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    full_state_update = False
    is_differentiable = False
    higher_is_better = False

    def __init__(self, percentiles=(50, 90, 95), **kwargs):
        super().__init__(**kwargs)
        # Copy into a tuple: avoids a shared mutable default and guards
        # against the caller mutating its own list afterwards.
        self.percentiles = tuple(percentiles)

        # List state; "cat" concatenates the per-process lists when synced
        self.add_state("errors", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Store the flattened absolute errors of one batch."""
        self.errors.append(torch.abs(preds - target).flatten())

    def compute(self):
        """Compute percentile errors (0.0 for every percentile if nothing was seen)."""
        errors = self.errors
        # After a distributed sync a "cat" state may already be a single tensor
        if isinstance(errors, torch.Tensor):
            all_errors = errors.flatten()
        elif len(errors) > 0:
            all_errors = torch.cat([e.flatten() for e in errors])
        else:
            all_errors = None

        if all_errors is None or all_errors.numel() == 0:
            return {f"error_p{p}": torch.tensor(0.0) for p in self.percentiles}

        # torch.quantile only accepts floating-point input
        if not all_errors.is_floating_point():
            all_errors = all_errors.float()

        q = torch.tensor(
            [p / 100.0 for p in self.percentiles],
            dtype=all_errors.dtype,
            device=all_errors.device,
        )
        values = torch.quantile(all_errors, q)
        return {f"error_p{p}": v for p, v in zip(self.percentiles, values)}

# Model using custom metrics
class CustomMetricsModel(pl.LightningModule):
    """Classifier pairing built-in accuracy with custom top-3/top-5 accuracy.

    Only the loss is tracked during training; all metrics are validation-only.

    Args:
        num_classes: Number of target classes.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.save_hyperparameters()

        layers = [
            nn.Linear(20, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        ]
        self.model = nn.Sequential(*layers)

        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        # Top-k metrics rank class scores, so they are fed logits, not labels.
        self.val_top3_acc = TopKCategoricalAccuracy(k=3)
        self.val_top5_acc = TopKCategoricalAccuracy(k=5)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        loss = F.cross_entropy(self(inputs), labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, labels)

        self.val_acc(logits.argmax(dim=1), labels)
        self.val_top3_acc(logits, labels)
        self.val_top5_acc(logits, labels)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True)
        for name, metric in (('val_top3_acc', self.val_top3_acc),
                             ('val_top5_acc', self.val_top5_acc)):
            self.log(name, metric, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Create data with more classes to test top-k accuracy
from torch.utils.data import random_split

X_multi, y_multi = create_multiclass_data(num_classes=10)
multi_dataset = TensorDataset(X_multi, y_multi)
# Hold out 20% for validation so top-k accuracy is measured on unseen samples.
num_multi_val = len(multi_dataset) // 5
multi_train_subset, multi_val_subset = random_split(
    multi_dataset,
    [len(multi_dataset) - num_multi_val, num_multi_val],
    generator=torch.Generator().manual_seed(42),
)
multi_train_loader = DataLoader(multi_train_subset, batch_size=32, shuffle=True)
multi_val_loader = DataLoader(multi_val_subset, batch_size=32, shuffle=False)

# Test custom metrics
custom_model = CustomMetricsModel(num_classes=10)

trainer = pl.Trainer(
    max_epochs=3,
    enable_checkpointing=False,
    logger=False,
    limit_train_batches=20,
    limit_val_batches=10
)

print("Training with custom metrics...")
trainer.fit(custom_model, multi_train_loader, multi_val_loader)
print("✓ Custom metrics demo completed")
```

## Advanced Logging Strategies

### Structured Logging with Metric Collections

```python
class AdvancedLoggingModel(pl.LightningModule):
    """Classifier demonstrating structured, multi-category metric logging.

    Logs the following:
    - Performance metrics (accuracy, macro/micro F1) for train and val.
    - Robustness metrics (AUROC, macro precision/recall) for val.
    - Per-step confidence/entropy during training, plus their epoch averages
      during validation.
    - A derived F1/accuracy ratio.
    - Extra scalars written directly to a logger whose experiment supports
      ``add_scalar`` (e.g. TensorBoard).

    The learning rate is halved when ``val_accuracy`` plateaus.

    Args:
        num_classes: Number of target classes.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

        # Organize metrics by category. Losses such as nn.CrossEntropyLoss are
        # not torchmetrics.Metric objects and MetricCollection rejects them, so
        # the loss is computed with F.cross_entropy and logged directly.
        self.performance_metrics = MetricCollection({
            'accuracy': Accuracy(task="multiclass", num_classes=num_classes),
            'f1_macro': F1Score(task="multiclass", num_classes=num_classes, average="macro"),
            'f1_micro': F1Score(task="multiclass", num_classes=num_classes, average="micro"),
        })

        # Fed probabilities: AUROC needs class scores, not hard labels
        self.robustness_metrics = MetricCollection({
            'auroc': AUROC(task="multiclass", num_classes=num_classes),
            'precision': Precision(task="multiclass", num_classes=num_classes, average="macro"),
            'recall': Recall(task="multiclass", num_classes=num_classes, average="macro"),
        })

        # Create train/val versions (independent states)
        self.train_performance = self.performance_metrics.clone(prefix='train_')
        self.val_performance = self.performance_metrics.clone(prefix='val_')
        self.val_robustness = self.robustness_metrics.clone(prefix='val_')

        # Plain Python accumulators for epoch-level prediction statistics.
        # NOTE: these are per-process and are not synchronized across devices.
        self.prediction_confidence_sum = 0.0
        self.prediction_entropy_sum = 0.0
        self.num_predictions = 0

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        self.train_performance(preds, y)

        # Confidence = top-class probability; entropy of the predicted distribution
        max_probs, _ = torch.max(probs, dim=1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_confidence', max_probs.mean(), on_step=True, on_epoch=True)
        self.log('train_entropy', entropy.mean(), on_step=True, on_epoch=True)

        # Metrics are computed/reset by Lightning at epoch end
        self.log_dict(self.train_performance, on_step=False, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)
        probs = F.softmax(logits, dim=1)

        self.val_performance(preds, y)
        self.val_robustness(probs, y)

        # Track prediction statistics for the epoch-level averages
        max_probs, _ = torch.max(probs, dim=1)
        entropy = -torch.sum(probs * torch.log(probs + 1e-8), dim=1)

        self.prediction_confidence_sum += max_probs.sum().item()
        self.prediction_entropy_sum += entropy.sum().item()
        self.num_predictions += x.size(0)

        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log_dict(self.val_performance, on_step=False, on_epoch=True, prog_bar=True)
        self.log_dict(self.val_robustness, on_step=False, on_epoch=True)

        return loss

    def on_validation_epoch_end(self):
        if self.num_predictions > 0:
            avg_confidence = self.prediction_confidence_sum / self.num_predictions
            avg_entropy = self.prediction_entropy_sum / self.num_predictions

            self.log('val_avg_confidence', avg_confidence)
            self.log('val_avg_entropy', avg_entropy)

            # Reset for next epoch
            self.prediction_confidence_sum = 0.0
            self.prediction_entropy_sum = 0.0
            self.num_predictions = 0

        # Index the collection by the base metric name. The 'val_' prefix is
        # applied to logged names, and prefixed lookup is not supported by
        # every torchmetrics version.
        current_f1 = self.val_performance['f1_macro'].compute()
        current_acc = self.val_performance['accuracy'].compute()

        self.log('val_f1_acc_ratio', current_f1 / (current_acc + 1e-8))

        # Extra scalars via the raw experiment object. Every Lightning logger
        # has an `experiment`, but only some (e.g. TensorBoard) have add_scalar.
        experiment = self.logger.experiment if self.logger else None
        if experiment is not None and hasattr(experiment, 'add_scalar'):
            for key, value in (('val_f1_macro', current_f1), ('val_accuracy', current_acc)):
                experiment.add_scalar(f'custom/{key}', value.item(), self.current_epoch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # mode='max' because the monitored metric is an accuracy.
        # `verbose` is omitted because it is deprecated in recent PyTorch.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=2
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_accuracy',  # Logged by val_performance's log_dict
                'frequency': 1,
                'interval': 'epoch'
            }
        }

# Test advanced logging
from pytorch_lightning.loggers import TensorBoardLogger

# TensorBoard writes event files under logs/advanced_logging/version_N
logger = TensorBoardLogger("logs", name="advanced_logging")
advanced_model = AdvancedLoggingModel(num_classes=5)

trainer = pl.Trainer(
    logger=logger,
    max_epochs=5,
    limit_train_batches=20,
    limit_val_batches=10,
    enable_checkpointing=False,
)

print("Training with advanced logging...")
trainer.fit(advanced_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("✓ Advanced logging demo completed")
print(f"Logs saved to: {logger.log_dir}")
```

### Multi-Modal and Hierarchical Metrics

```python
class HierarchicalMetrics(pl.LightningModule):
    """Model with hierarchical/grouped metric organization.

    Metrics are stored in nested ``nn.ModuleDict``s (metric family -> stage)
    and logged under slash-separated names such as ``metrics/accuracy/val``, so
    loggers that group by path (e.g. TensorBoard) show them together. Per-class
    precision/recall/F1 are logged under ``class_metrics/...`` at the end of
    each validation epoch.

    Args:
        num_classes: Number of target classes.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

        # Organize metrics hierarchically
        self.metrics = nn.ModuleDict({
            'accuracy': nn.ModuleDict({
                'train': Accuracy(task="multiclass", num_classes=num_classes),
                'val': Accuracy(task="multiclass", num_classes=num_classes),
                'test': Accuracy(task="multiclass", num_classes=num_classes)
            }),
            'f1': nn.ModuleDict({
                'train_macro': F1Score(task="multiclass", num_classes=num_classes, average="macro"),
                'train_micro': F1Score(task="multiclass", num_classes=num_classes, average="micro"),
                'val_macro': F1Score(task="multiclass", num_classes=num_classes, average="macro"),
                'val_micro': F1Score(task="multiclass", num_classes=num_classes, average="micro"),
            })
        })

        # Per-class metrics for detailed analysis. With average=None a single
        # metric already returns one value per class, so one instance of each
        # covers every class.
        self.per_class_metrics = nn.ModuleDict({
            'precision': Precision(task="multiclass", num_classes=num_classes, average=None),
            'recall': Recall(task="multiclass", num_classes=num_classes, average=None),
        })

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)

        # Update hierarchical metrics
        self.metrics['accuracy']['train'](preds, y)
        self.metrics['f1']['train_macro'](preds, y)
        self.metrics['f1']['train_micro'](preds, y)

        # Log with hierarchical names
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        self.log('metrics/accuracy/train', self.metrics['accuracy']['train'], on_epoch=True)
        self.log('metrics/f1/train_macro', self.metrics['f1']['train_macro'], on_epoch=True)
        self.log('metrics/f1/train_micro', self.metrics['f1']['train_micro'], on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)

        # Update all validation metrics
        self.metrics['accuracy']['val'](preds, y)
        self.metrics['f1']['val_macro'](preds, y)
        self.metrics['f1']['val_micro'](preds, y)

        # Per-class metrics are read and logged in on_validation_epoch_end
        self.per_class_metrics['precision'].update(preds, y)
        self.per_class_metrics['recall'].update(preds, y)

        # Log hierarchical metrics
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('metrics/accuracy/val', self.metrics['accuracy']['val'], on_epoch=True, prog_bar=True)
        self.log('metrics/f1/val_macro', self.metrics['f1']['val_macro'], on_epoch=True)
        self.log('metrics/f1/val_micro', self.metrics['f1']['val_micro'], on_epoch=True)

        return loss

    def on_validation_epoch_end(self):
        precision = self.per_class_metrics['precision'].compute()
        recall = self.per_class_metrics['recall'].compute()
        # These metric objects are never passed to self.log, so Lightning does
        # not reset them. Reset here so each epoch starts from a clean state.
        self.per_class_metrics['precision'].reset()
        self.per_class_metrics['recall'].reset()

        # Per-class F1 derived from precision and recall (epsilon avoids 0/0)
        f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

        for i in range(self.hparams.num_classes):
            self.log(f'class_metrics/precision/class_{i}', precision[i])
            self.log(f'class_metrics/recall/class_{i}', recall[i])
            self.log(f'class_metrics/f1/class_{i}', f1[i])

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

# Test hierarchical metrics
hierarchical_model = HierarchicalMetrics(num_classes=5)

# A separate TensorBoard run directory keeps these curves apart from the
# advanced-logging run.
trainer = pl.Trainer(
    logger=TensorBoardLogger("logs", name="hierarchical_metrics"),
    max_epochs=3,
    limit_train_batches=20,
    limit_val_batches=10,
    enable_checkpointing=False,
)

print("Training with hierarchical metrics...")
trainer.fit(hierarchical_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("✓ Hierarchical metrics demo completed")
```

## Summary

This notebook covered comprehensive metric tracking and logging in PyTorch Lightning:

1. **TorchMetrics Basics**: Using built-in metrics with proper state management and averaging
2. **Classification Metrics**: Accuracy, precision, recall, F1, AUROC, confusion matrix
3. **Regression Metrics**: MSE, RMSE, MAE, R², explained variance, MAPE
4. **Custom Metrics**: Building domain-specific metrics by extending the Metric class
5. **Advanced Logging**: Structured, hierarchical metric organization and logger-aware extras such as TensorBoard figures and custom scalars
6. **Metric Collections**: Grouping related metrics for cleaner code and logging

Key best practices:
- Use MetricCollection to group related metrics
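- Implement proper metric state management with `add_state()`, and `reset()` any metric you `compute()` manually (Lightning only resets metric objects passed to `self.log`/`self.log_dict`)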
- Log metrics at appropriate intervals (step vs epoch)
- Organize metrics hierarchically for better visualization
- Use custom metrics for domain-specific evaluation
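- Leverage multiple loggers for different visualization needs, guarding logger-specific calls (e.g., TensorBoard's `add_figure`/`add_scalar`) so they are skipped when that logger is not attached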
- Track prediction confidence and uncertainty alongside accuracy

Advanced features covered:
- Per-class metric tracking for imbalanced datasets
- Prediction confidence and entropy monitoring
- Learning rate scheduling based on metric values
- Custom epoch-end computations and derived metrics
- Integration with TensorBoard for rich visualizations

Next notebook: We'll explore checkpointing and early stopping strategies.