# Utilities and Helper Functions

This notebook provides interactive exploration and testing of utility functions, metrics, losses, and training helpers.
It allows you to understand and experiment with the core components of the training system.

## Setup and Imports

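Import the third-party libraries used throughout the notebook and make the project's `src` directory importable. The project's own utility modules are loaded in the next section.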

In [None]:
import os
import sys
import warnings
from pathlib import Path
from typing import Dict, Any, List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score

# Suppress warnings so the widget output below stays readable.
# NOTE(review): this silences *every* warning for the whole kernel, including
# potentially useful ones from PyTorch / sklearn; consider a narrower filter.
warnings.filterwarnings('ignore')

# Add src to the import path (guarded so re-running the cell does not append
# a duplicate entry).
if 'src' not in sys.path:
    sys.path.append('src')

print("✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## Load Project Utilities

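Run the configuration-management notebook, then import the project's utility modules. The optimizer explorer below uses `OptimizerConfig` and `SchedulerConfig`, which are not imported anywhere in this notebook and presumably come from that `%run`.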

In [None]:
# Load configuration management
%run 01_Configuration_Management.ipynb

# Import utility modules
from src.utils import (
    set_seed, get_optimizer, get_scheduler, evaluate,
    prepare_thresholds, compute_loss
)
from src.utils.metrics import compute_metrics
from src.utils.training import (
    train_epoch, validate_epoch, compute_loss as training_compute_loss
)
from src.utils.ema import EMA
from src.losses import FocalLoss, LabelSmoothingCrossEntropy

# Confirm that the project utilities above imported without error.
print("✅ Project utilities loaded!")

## Loss Functions Explorer

Interactive exploration of different loss functions.

In [None]:
def create_loss_function_explorer():
    """Create an interactive loss function explorer.

    Builds an ipywidgets UI for choosing a loss (``nn.BCEWithLogitsLoss``, the
    project's ``FocalLoss`` or ``LabelSmoothingCrossEntropy``), tuning its
    hyper-parameters, and evaluating it on synthetic data: standard-normal
    logits and independent 0/1 targets, both of shape ``(samples, classes)``.
    Clicking the button prints the loss value and a 0.5-threshold accuracy,
    draws four diagnostic plots, and compares the result against default
    configurations of all three losses.

    Returns:
        widgets.VBox: the controls stacked above an ``Output`` area that
        receives the printed results and plots.
    """

    # Loss function selection
    loss_type = widgets.Dropdown(
        options=['Binary Cross Entropy', 'Focal Loss', 'Label Smoothing CE'],
        value='Binary Cross Entropy',
        description='Loss Type:'
    )

    # Focal Loss parameters (enabled only while 'Focal Loss' is selected)
    focal_alpha = widgets.FloatSlider(
        value=1.0,
        min=0.1,
        max=2.0,
        step=0.1,
        description='Focal α:',
        disabled=True
    )

    focal_gamma = widgets.FloatSlider(
        value=2.0,
        min=0.0,
        max=5.0,
        step=0.5,
        description='Focal γ:',
        disabled=True
    )

    # Label Smoothing parameter (enabled only while 'Label Smoothing CE' is selected)
    smoothing = widgets.FloatSlider(
        value=0.1,
        min=0.0,
        max=0.5,
        step=0.05,
        description='Smoothing:',
        disabled=True
    )

    # Synthetic-data size
    num_samples = widgets.IntSlider(
        value=1000,
        min=100,
        max=5000,
        step=100,
        description='Samples:'
    )

    num_classes = widgets.IntSlider(
        value=5,
        min=2,
        max=20,
        step=1,
        description='Classes:'
    )

    test_button = widgets.Button(
        description='🧪 Test Loss Function',
        button_style='primary'
    )

    output = widgets.Output()

    def on_loss_type_change(change):
        """Enable only the hyper-parameter sliders relevant to the selected loss."""
        selected = change['new']
        focal_alpha.disabled = selected != 'Focal Loss'
        focal_gamma.disabled = selected != 'Focal Loss'
        smoothing.disabled = selected != 'Label Smoothing CE'

    loss_type.observe(on_loss_type_change, names='value')

    def on_test_clicked(b):
        """Run the selected loss on synthetic data and report / plot the results."""
        with output:
            output.clear_output()

            try:
                print(f"🧪 Testing {loss_type.value}")
                print("=" * 40)

                # Fixed seed so repeated clicks with the same settings are comparable
                torch.manual_seed(42)

                # Random logits and independent 0/1 targets per (sample, class)
                logits = torch.randn(num_samples.value, num_classes.value)
                targets = torch.randint(0, 2, (num_samples.value, num_classes.value)).float()

                print(f"   Data shape: {logits.shape}")
                print(f"   Positive rate: {targets.mean():.3f}")

                # Create loss function.
                # NOTE(review): LabelSmoothingCrossEntropy receives the same
                # multi-hot float targets as BCE; whether it supports them
                # depends on its implementation — confirm. Any failure is
                # caught and reported by the except clause below.
                if loss_type.value == 'Binary Cross Entropy':
                    loss_fn = nn.BCEWithLogitsLoss()
                    loss_name = 'BCE'
                elif loss_type.value == 'Focal Loss':
                    loss_fn = FocalLoss(alpha=focal_alpha.value, gamma=focal_gamma.value)
                    loss_name = f'Focal(α={focal_alpha.value}, γ={focal_gamma.value})'
                else:  # Label Smoothing CE
                    loss_fn = LabelSmoothingCrossEntropy(smoothing=smoothing.value)
                    loss_name = f'LabelSmooth(ε={smoothing.value})'

                # Compute loss
                loss_value = loss_fn(logits, targets)

                print("\n📊 Loss Results:")
                print(f"   Loss function: {loss_name}")
                print(f"   Loss value: {loss_value.item():.4f}")

                # Element-wise accuracy at a 0.5 probability threshold
                probs = torch.sigmoid(logits)
                predictions = (probs > 0.5).float()

                accuracy = (predictions == targets).float().mean()
                print(f"   Accuracy: {accuracy.item():.4f}")

                print("\n📈 Loss Landscape Analysis:")

                # Loss as a function of the model's confidence in the *true*
                # label: positive targets get logit +logit(conf) and negative
                # targets get -logit(conf). Giving every entry the same logit
                # would make "confidence" mean "predicted positive probability"
                # and mix correct and incorrect predictions into one curve.
                label_signs = 2 * targets - 1
                confidence_levels = torch.linspace(0.01, 0.99, 50)
                loss_values = []

                for conf in confidence_levels:
                    confident_logits = torch.log(conf / (1 - conf)) * label_signs
                    loss_val = loss_fn(confident_logits, targets)
                    loss_values.append(loss_val.item())

                plt.figure(figsize=(12, 8))

                # Loss vs confidence in the true label
                plt.subplot(2, 2, 1)
                plt.plot(confidence_levels.numpy(), loss_values, 'b-', linewidth=2)
                plt.xlabel('Confidence in True Label')
                plt.ylabel('Loss Value')
                plt.title(f'{loss_name} vs Confidence')
                plt.grid(True, alpha=0.3)

                # Per-sample loss distribution (first 100 samples at most)
                plt.subplot(2, 2, 2)
                individual_losses = []
                for i in range(min(100, num_samples.value)):
                    single_loss = loss_fn(logits[i:i+1], targets[i:i+1])
                    individual_losses.append(single_loss.item())

                plt.hist(individual_losses, bins=20, alpha=0.7, color='skyblue')
                plt.xlabel('Loss Value')
                plt.ylabel('Frequency')
                plt.title('Loss Distribution (Sample)')
                plt.grid(True, alpha=0.3)

                # Prediction distribution
                plt.subplot(2, 2, 3)
                plt.hist(probs.flatten().numpy(), bins=30, alpha=0.7, color='lightcoral')
                plt.xlabel('Predicted Probability')
                plt.ylabel('Frequency')
                plt.title('Prediction Distribution')
                plt.grid(True, alpha=0.3)

                # Positive count per class
                plt.subplot(2, 2, 4)
                class_counts = targets.sum(dim=0).numpy()
                plt.bar(range(len(class_counts)), class_counts, color='lightgreen')
                plt.xlabel('Class Index')
                plt.ylabel('Positive Count')
                plt.title('Target Distribution')
                plt.grid(True, alpha=0.3)

                plt.tight_layout()
                plt.show()

                # Compare with default configurations of every loss
                print("\n🔄 Comparison with Other Loss Functions:")

                comparison_losses = {
                    'BCE': nn.BCEWithLogitsLoss(),
                    'Focal(α=1,γ=2)': FocalLoss(alpha=1.0, gamma=2.0),
                    'LabelSmooth(ε=0.1)': LabelSmoothingCrossEntropy(smoothing=0.1)
                }

                for name, loss_func in comparison_losses.items():
                    try:
                        comp_loss = loss_func(logits, targets)
                        print(f"   {name}: {comp_loss.item():.4f}")
                    except Exception as e:
                        # Report per-loss failures without aborting the comparison
                        print(f"   {name}: Error - {e}")

            except Exception as e:
                print(f"❌ Error testing loss function: {e}")
                import traceback
                traceback.print_exc()

    test_button.on_click(on_test_clicked)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>Loss Function Explorer</h3>"),
        loss_type,
        widgets.HBox([focal_alpha, focal_gamma]),
        smoothing,
        widgets.HBox([num_samples, num_classes]),
        test_button
    ])

    return widgets.VBox([controls, output])

# Display loss function explorer
loss_explorer = create_loss_function_explorer()
display(loss_explorer)

## Optimizer and Scheduler Explorer

Interactive exploration of optimizers and learning rate schedulers.

In [None]:
def create_optimizer_scheduler_explorer():
    """Create an interactive optimizer and scheduler explorer.

    Lets the user choose an optimizer (AdamW/Adam/SGD), learning rate, weight
    decay, LR-scheduler type, epoch count and warmup ratio. Clicking the
    button builds an optimizer and scheduler with the project's
    ``get_optimizer`` / ``get_scheduler`` for a tiny ``nn.Linear`` model,
    steps the scheduler through ``epochs * 100`` simulated steps, and plots
    and prints the resulting learning-rate schedule.

    NOTE(review): ``OptimizerConfig`` and ``SchedulerConfig`` are not imported
    in this notebook; they presumably come from the ``%run`` of the
    configuration-management notebook — confirm.

    Returns:
        widgets.VBox: the controls stacked above an ``Output`` area.
    """

    # Optimizer selection
    optimizer_type = widgets.Dropdown(
        options=['AdamW', 'Adam', 'SGD'],
        value='AdamW',
        description='Optimizer:'
    )

    # Optimizer parameters (log sliders: min/max are exponents of 10)
    learning_rate = widgets.FloatLogSlider(
        value=1e-4,
        base=10,
        min=-6,
        max=-2,
        step=0.1,
        description='Learning Rate:'
    )

    weight_decay = widgets.FloatLogSlider(
        value=1e-2,
        base=10,
        min=-5,
        max=-1,
        step=0.1,
        description='Weight Decay:'
    )

    # Scheduler selection
    scheduler_type = widgets.Dropdown(
        options=['linear', 'cosine', 'polynomial', 'constant'],
        value='linear',
        description='Scheduler:'
    )

    # Scheduler parameters
    num_epochs = widgets.IntSlider(
        value=10,
        min=1,
        max=100,
        description='Epochs:'
    )

    warmup_ratio = widgets.FloatSlider(
        value=0.1,
        min=0.0,
        max=0.5,
        step=0.05,
        description='Warmup Ratio:'
    )

    test_button = widgets.Button(
        description='📊 Test Configuration',
        button_style='primary'
    )

    output = widgets.Output()

    def on_test_clicked(b):
        """Build the optimizer/scheduler, simulate the schedule, and report it."""
        with output:
            output.clear_output()

            try:
                print("📊 Testing Optimizer and Scheduler Configuration")
                print("=" * 50)

                # A tiny model is enough: only the optimizer's param_groups matter here
                model = nn.Linear(10, 1)

                # Create optimizer configuration
                optimizer_config = OptimizerConfig(
                    type=optimizer_type.value.lower(),
                    learning_rate=learning_rate.value,
                    weight_decay=weight_decay.value
                )

                # Create scheduler configuration
                scheduler_config = SchedulerConfig(
                    type=scheduler_type.value,
                    warmup_ratio=warmup_ratio.value
                )

                print(f"   Optimizer: {optimizer_type.value}")
                print(f"   Learning Rate: {learning_rate.value:.2e}")
                print(f"   Weight Decay: {weight_decay.value:.2e}")
                print(f"   Scheduler: {scheduler_type.value}")
                print(f"   Warmup Ratio: {warmup_ratio.value}")

                optimizer = get_optimizer(model, optimizer_config)

                # Calculate total and warmup steps (computed once, reused below)
                steps_per_epoch = 100  # Simulated
                total_steps = num_epochs.value * steps_per_epoch
                warmup_steps = int(total_steps * warmup_ratio.value)

                scheduler = get_scheduler(optimizer, scheduler_config, total_steps)

                print("\n📈 Learning Rate Schedule:")
                print(f"   Total steps: {total_steps}")
                print(f"   Warmup steps: {warmup_steps}")

                # Record the LR in effect at each step, then advance the schedule.
                # optimizer.step() is intentionally not called: there are no
                # gradients, and only the scheduler's effect on the LR matters.
                learning_rates = []
                steps = []

                for step in range(total_steps):
                    learning_rates.append(optimizer.param_groups[0]['lr'])
                    steps.append(step)

                    if scheduler is not None:
                        scheduler.step()

                plt.figure(figsize=(12, 8))

                # Learning rate over steps
                plt.subplot(2, 2, 1)
                plt.plot(steps, learning_rates, 'b-', linewidth=2)
                plt.xlabel('Training Step')
                plt.ylabel('Learning Rate')
                plt.title(f'{scheduler_type.value.title()} LR Schedule')
                plt.grid(True, alpha=0.3)

                # Learning rate at the first step of each epoch
                plt.subplot(2, 2, 2)
                epoch_lrs = [learning_rates[i * steps_per_epoch] for i in range(num_epochs.value)]
                plt.plot(range(num_epochs.value), epoch_lrs, 'r-o', linewidth=2, markersize=6)
                plt.xlabel('Epoch')
                plt.ylabel('Learning Rate')
                plt.title('LR at Epoch Start')
                plt.grid(True, alpha=0.3)

                # Log scale learning rate (non-positive values cannot be shown)
                plt.subplot(2, 2, 3)
                plt.semilogy(steps, learning_rates, 'g-', linewidth=2)
                plt.xlabel('Training Step')
                plt.ylabel('Learning Rate (log scale)')
                plt.title('LR Schedule (Log Scale)')
                plt.grid(True, alpha=0.3)

                # Learning rate statistics
                plt.subplot(2, 2, 4)
                initial_lr = learning_rates[0]
                max_lr = max(learning_rates)
                final_lr = learning_rates[-1]
                min_lr = min(learning_rates)
                lr_stats = {
                    'Initial LR': initial_lr,
                    'Max LR': max_lr,
                    'Final LR': final_lr,
                    'Min LR': min_lr
                }

                bars = plt.bar(lr_stats.keys(), lr_stats.values(), color=['blue', 'green', 'red', 'orange'])
                plt.ylabel('Learning Rate')
                plt.title('LR Statistics')
                plt.yscale('log')
                plt.xticks(rotation=45)

                # Add value labels on bars. An LR of exactly 0 (e.g. a warmup
                # that starts from 0) has no position on a log axis, so skip it.
                for bar, value in zip(bars, lr_stats.values()):
                    if value > 0:
                        plt.text(bar.get_x() + bar.get_width()/2, value, f'{value:.2e}',
                                ha='center', va='bottom', fontsize=8)

                plt.tight_layout()
                plt.show()

                print("\n📊 Detailed Statistics:")
                print(f"   Initial LR: {initial_lr:.2e}")
                print(f"   Maximum LR: {max_lr:.2e}")
                print(f"   Final LR: {final_lr:.2e}")
                print(f"   Minimum LR: {min_lr:.2e}")
                if initial_lr > 0:
                    print(f"   LR Decay Ratio: {final_lr / initial_lr:.4f}")
                else:
                    # Guard: dividing by a zero initial LR raised ZeroDivisionError
                    print("   LR Decay Ratio: n/a (initial LR is 0)")

                # Warmup analysis
                if warmup_steps > 0:
                    warmup_lrs = learning_rates[:warmup_steps]
                    print("\n🔥 Warmup Analysis:")
                    print(f"   Warmup steps: {warmup_steps}")
                    print(f"   LR at warmup end: {warmup_lrs[-1]:.2e}")
                    print(f"   Warmup slope: {(warmup_lrs[-1] - warmup_lrs[0]) / warmup_steps:.2e} per step")

            except Exception as e:
                print(f"❌ Error testing configuration: {e}")
                import traceback
                traceback.print_exc()

    test_button.on_click(on_test_clicked)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>Optimizer & Scheduler Explorer</h3>"),
        widgets.HBox([optimizer_type, scheduler_type]),
        widgets.HBox([learning_rate, weight_decay]),
        widgets.HBox([num_epochs, warmup_ratio]),
        test_button
    ])

    return widgets.VBox([controls, output])

# Display optimizer and scheduler explorer
print("\n📊 Optimizer & Scheduler Explorer:")
optimizer_explorer = create_optimizer_scheduler_explorer()
display(optimizer_explorer)

## Metrics and Evaluation Explorer

Interactive exploration of evaluation metrics.

In [None]:
def create_metrics_explorer():
    """Create an interactive metrics explorer.

    Generates synthetic multi-label targets (Bernoulli with the chosen
    positive rate) and correlated prediction probabilities, then reports the
    project's ``compute_metrics`` output, per-class sklearn AUC/precision/
    recall/F1, four diagnostic plots for class 0 and all classes, and a table
    of metrics at several decision thresholds.

    Returns:
        widgets.VBox: the controls stacked above an ``Output`` area.
    """

    # Data generation parameters
    num_samples = widgets.IntSlider(
        value=1000,
        min=100,
        max=5000,
        step=100,
        description='Samples:'
    )

    num_classes = widgets.IntSlider(
        value=5,
        min=2,
        max=20,
        step=1,
        description='Classes:'
    )

    positive_rate = widgets.FloatSlider(
        value=0.3,
        min=0.1,
        max=0.9,
        step=0.1,
        description='Positive Rate:'
    )

    noise_level = widgets.FloatSlider(
        value=0.1,
        min=0.0,
        max=0.5,
        step=0.05,
        description='Noise Level:'
    )

    compute_button = widgets.Button(
        description='📊 Compute Metrics',
        button_style='primary'
    )

    output = widgets.Output()

    def on_compute_clicked(b):
        """Generate synthetic data, compute metrics, and render plots / tables."""
        with output:
            output.clear_output()

            try:
                print("📊 Computing Evaluation Metrics")
                print("=" * 40)

                # Fixed seeds so repeated clicks with the same settings match
                torch.manual_seed(42)
                np.random.seed(42)

                # Independent Bernoulli targets per (sample, class)
                targets = torch.bernoulli(torch.full((num_samples.value, num_classes.value), positive_rate.value))

                # Predictions correlated with the targets: positives start at
                # (0.8 - noise), negatives at noise, plus small Gaussian jitter.
                # NOTE(review): for noise_level >= 0.4 the positive base
                # probability is <= the negative one, so predictions become
                # uninformative or anti-correlated; confirm this is intended.
                base_probs = targets.float() * (0.8 - noise_level.value) + (1 - targets.float()) * noise_level.value
                noise = torch.randn_like(base_probs) * 0.1
                probs = torch.clamp(base_probs + noise, 0.01, 0.99)

                predictions = (probs > 0.5).float()

                print(f"   Data shape: {targets.shape}")
                print(f"   Actual positive rate: {targets.mean():.3f}")
                print(f"   Predicted positive rate: {predictions.mean():.3f}")

                # Project metrics. NOTE(review): argument order (predictions,
                # targets) follows the original call; confirm against the
                # signature of compute_metrics.
                metrics = compute_metrics(predictions.numpy(), targets.numpy())

                print("\n📈 Computed Metrics:")
                for metric_name, value in metrics.items():
                    # Include numpy scalar types (e.g. float32) so they are formatted too
                    if isinstance(value, (int, float, np.integer, np.floating)):
                        print(f"   {metric_name}: {value:.4f}")
                    else:
                        print(f"   {metric_name}: {value}")

                # Per-class sklearn metrics for comparison
                print("\n🔍 Additional Metrics:")

                for i in range(num_classes.value):
                    class_targets = targets[:, i].numpy()
                    class_preds = predictions[:, i].numpy()
                    class_probs = probs[:, i].numpy()

                    if class_targets.sum() > 0:  # Only if there are positive examples
                        try:
                            auc = roc_auc_score(class_targets, class_probs)
                            precision, recall, f1, _ = precision_recall_fscore_support(
                                class_targets, class_preds, average='binary', zero_division=0
                            )
                            print(f"   Class {i}: AUC={auc:.3f}, P={precision:.3f}, R={recall:.3f}, F1={f1:.3f}")
                        except ValueError:
                            # roc_auc_score raises when only one label value is present
                            print(f"   Class {i}: Cannot compute AUC (single class)")

                fig, axes = plt.subplots(2, 2, figsize=(15, 12))
                fig.suptitle('Evaluation Metrics Analysis', fontsize=16)

                # 1. Confusion matrix heatmap (for first class). Passing both
                #    labels keeps it 2x2 even when class 0 has only one label value.
                from sklearn.metrics import confusion_matrix
                cm = confusion_matrix(
                    targets[:, 0].numpy(), predictions[:, 0].numpy(), labels=[0.0, 1.0]
                )

                sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0, 0])
                axes[0, 0].set_title('Confusion Matrix (Class 0)')
                axes[0, 0].set_xlabel('Predicted')
                axes[0, 0].set_ylabel('Actual')

                # 2. Prediction probability distribution
                axes[0, 1].hist(probs.flatten().numpy(), bins=30, alpha=0.7, color='skyblue', edgecolor='black')
                axes[0, 1].axvline(0.5, color='red', linestyle='--', label='Decision Threshold')
                axes[0, 1].set_xlabel('Predicted Probability')
                axes[0, 1].set_ylabel('Frequency')
                axes[0, 1].set_title('Prediction Probability Distribution')
                axes[0, 1].legend()
                axes[0, 1].grid(True, alpha=0.3)

                # 3. Per-class F1 (computed for every class; zero_division=0 covers empty classes)
                class_f1_scores = []
                for i in range(num_classes.value):
                    class_targets = targets[:, i].numpy()
                    class_preds = predictions[:, i].numpy()

                    _, _, f1, _ = precision_recall_fscore_support(
                        class_targets, class_preds, average='binary', zero_division=0
                    )
                    class_f1_scores.append(f1)

                axes[1, 0].bar(range(num_classes.value), class_f1_scores, color='lightgreen')
                axes[1, 0].set_xlabel('Class Index')
                axes[1, 0].set_ylabel('F1 Score')
                axes[1, 0].set_title('Per-Class F1 Scores')
                axes[1, 0].grid(True, alpha=0.3)

                # 4. Precision-Recall curve (for first class)
                from sklearn.metrics import precision_recall_curve

                if targets[:, 0].sum() > 0:
                    precision_curve, recall_curve, _ = precision_recall_curve(
                        targets[:, 0].numpy(), probs[:, 0].numpy()
                    )

                    axes[1, 1].plot(recall_curve, precision_curve, 'b-', linewidth=2)
                    axes[1, 1].set_xlabel('Recall')
                    axes[1, 1].set_ylabel('Precision')
                    axes[1, 1].set_title('Precision-Recall Curve (Class 0)')
                    axes[1, 1].grid(True, alpha=0.3)
                else:
                    axes[1, 1].text(0.5, 0.5, 'No positive examples\nfor Class 0',
                                   ha='center', va='center', transform=axes[1, 1].transAxes)
                    axes[1, 1].set_title('Precision-Recall Curve (Class 0)')

                plt.tight_layout()
                plt.show()

                # Threshold analysis: recompute project metrics at several cut-offs
                print("\n🎯 Threshold Analysis:")

                thresholds = [0.3, 0.4, 0.5, 0.6, 0.7]
                threshold_results = []

                for threshold in thresholds:
                    thresh_preds = (probs > threshold).float()
                    thresh_metrics = compute_metrics(thresh_preds.numpy(), targets.numpy())

                    # Missing keys default to 0; key names presumably match compute_metrics — confirm
                    threshold_results.append({
                        'Threshold': threshold,
                        'Accuracy': thresh_metrics.get('accuracy', 0),
                        'F1': thresh_metrics.get('f1_macro', 0),
                        'Precision': thresh_metrics.get('precision_macro', 0),
                        'Recall': thresh_metrics.get('recall_macro', 0)
                    })

                threshold_df = pd.DataFrame(threshold_results)
                print("\n   Threshold Analysis Results:")
                display(threshold_df.round(4))

            except Exception as e:
                print(f"❌ Error computing metrics: {e}")
                import traceback
                traceback.print_exc()

    compute_button.on_click(on_compute_clicked)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>Metrics Explorer</h3>"),
        widgets.HBox([num_samples, num_classes]),
        widgets.HBox([positive_rate, noise_level]),
        compute_button
    ])

    return widgets.VBox([controls, output])

# Display metrics explorer
print("\n📊 Metrics Explorer:")
metrics_explorer = create_metrics_explorer()
display(metrics_explorer)

## EMA (Exponential Moving Average) Explorer

Interactive exploration of EMA for model weights.

In [None]:
def create_ema_explorer():
    """Create an interactive EMA explorer.

    Lets the user pick an EMA decay rate and a number of updates. Clicking the
    button wraps a small ``nn.Linear`` in the project's ``EMA`` helper, adds
    random-walk noise to the weights for the requested number of steps
    (calling ``ema.update()`` after each), and plots how the first weight of
    the raw model and of ``ema.ema_model`` evolve. It also plots the
    per-step EMA weights ``(1 - decay) * decay**(k - 1)`` and reports how many
    of the most recent steps carry 50% / 90% of the total weight.

    Returns:
        widgets.VBox: the controls stacked above an ``Output`` area.
    """

    # EMA parameters
    decay_rate = widgets.FloatSlider(
        value=0.999,
        min=0.9,
        max=0.9999,
        step=0.0001,
        description='Decay Rate:',
        readout_format='.4f'
    )

    num_updates = widgets.IntSlider(
        value=1000,
        min=100,
        max=10000,
        step=100,
        description='Updates:'
    )

    test_button = widgets.Button(
        description='🧪 Test EMA',
        button_style='primary'
    )

    output = widgets.Output()

    def on_test_clicked(b):
        """Simulate noisy weight updates with EMA tracking and report the results."""
        with output:
            output.clear_output()

            try:
                decay = decay_rate.value
                print(f"🧪 Testing EMA with decay={decay}")
                print("=" * 40)

                model = nn.Linear(10, 1)

                # Initialize EMA (project helper; exposes .update() and .ema_model)
                ema = EMA(model, decay=decay)

                print(f"   Model parameters: {sum(p.numel() for p in model.parameters())}")
                print(f"   EMA decay rate: {decay}")
                print(f"   Effective window: ~{1/(1-decay):.1f} updates")

                # Simulate training updates, sampling the first weight every 10 steps
                steps_sampled = []
                original_weights = []
                ema_weights = []
                weight_differences = []

                for step in range(num_updates.value):
                    # Simulated training update: random-walk noise on the weights
                    with torch.no_grad():
                        model.weight.data += torch.randn_like(model.weight.data) * 0.01

                    ema.update()

                    if step % 10 == 0:
                        current_weight = model.weight.data[0, 0].item()
                        ema_weight = ema.ema_model.weight.data[0, 0].item()

                        steps_sampled.append(step)
                        original_weights.append(current_weight)
                        ema_weights.append(ema_weight)
                        weight_differences.append(abs(current_weight - ema_weight))

                # Per-step EMA weights: a value seen k steps ago contributes
                # (1 - decay) * decay**(k - 1), so the last n steps hold a
                # fraction 1 - decay**n of the total. Extend the horizon until
                # ~99% of the weight is covered (at least 100 steps). A fixed
                # 100-step horizon only reached ~9.5% at decay=0.999, so the
                # 50% / 90% marks were never found and printed as None.
                horizon = max(100, int(np.ceil(np.log(0.01) / np.log(decay))))
                steps_ago = np.arange(1, horizon + 1)
                effective_weights = (1 - decay) * decay ** (steps_ago - 1)
                cumulative_weights = np.cumsum(effective_weights)

                plt.figure(figsize=(15, 10))

                # Weight evolution
                plt.subplot(2, 2, 1)
                plt.plot(steps_sampled, original_weights, 'b-', alpha=0.7, label='Original Model', linewidth=1)
                plt.plot(steps_sampled, ema_weights, 'r-', label='EMA Model', linewidth=2)
                plt.xlabel('Training Step')
                plt.ylabel('Weight Value')
                plt.title('Weight Evolution: Original vs EMA')
                plt.legend()
                plt.grid(True, alpha=0.3)

                # Weight difference
                plt.subplot(2, 2, 2)
                plt.plot(steps_sampled, weight_differences, 'g-', linewidth=2)
                plt.xlabel('Training Step')
                plt.ylabel('|Original - EMA|')
                plt.title('Absolute Difference Between Models')
                plt.grid(True, alpha=0.3)

                # EMA decay analysis
                plt.subplot(2, 2, 3)
                plt.plot(steps_ago, effective_weights, 'purple', linewidth=2)
                plt.xlabel('Steps Ago')
                plt.ylabel('Effective Weight')
                plt.title('EMA Weight Distribution')
                plt.grid(True, alpha=0.3)

                # Cumulative weight
                plt.subplot(2, 2, 4)
                plt.plot(steps_ago, cumulative_weights, 'orange', linewidth=2)
                plt.axhline(y=0.5, color='red', linestyle='--', label='50% Weight')
                plt.axhline(y=0.9, color='blue', linestyle='--', label='90% Weight')
                plt.xlabel('Steps Ago')
                plt.ylabel('Cumulative Weight')
                plt.title('Cumulative EMA Weight')
                plt.legend()
                plt.grid(True, alpha=0.3)

                plt.tight_layout()
                plt.show()

                print("\n📊 EMA Statistics:")
                print(f"   Final original weight: {original_weights[-1]:.6f}")
                print(f"   Final EMA weight: {ema_weights[-1]:.6f}")
                print(f"   Final difference: {weight_differences[-1]:.6f}")
                print(f"   Mean difference: {np.mean(weight_differences):.6f}")
                print(f"   Max difference: {np.max(weight_differences):.6f}")

                def steps_to_reach(fraction):
                    """Smallest n such that the last n steps hold >= `fraction` of the weight, or None."""
                    # cumulative_weights is increasing, so searchsorted finds the first index >= fraction
                    idx = int(np.searchsorted(cumulative_weights, fraction))
                    return idx + 1 if idx < len(cumulative_weights) else None

                steps_for_50_percent = steps_to_reach(0.5)
                steps_for_90_percent = steps_to_reach(0.9)

                print("\n🎯 Effective Window Analysis:")
                print(f"   50% of weight from last {steps_for_50_percent} steps")
                print(f"   90% of weight from last {steps_for_90_percent} steps")
                print(f"   Theoretical window: {1/(1-decay):.1f} steps")

                # Compare different decay rates
                print("\n🔄 Decay Rate Comparison:")
                decay_rates = [0.99, 0.995, 0.999, 0.9995, 0.9999]

                for dr in decay_rates:
                    effective_window = 1 / (1 - dr)
                    print(f"   Decay {dr}: ~{effective_window:.1f} step window")

            except Exception as e:
                print(f"❌ Error testing EMA: {e}")
                import traceback
                traceback.print_exc()

    test_button.on_click(on_test_clicked)

    # Layout
    controls = widgets.VBox([
        widgets.HTML("<h3>EMA Explorer</h3>"),
        widgets.HBox([decay_rate, num_updates]),
        test_button
    ])

    return widgets.VBox([controls, output])

# Display EMA explorer
print("\n🧪 EMA Explorer:")
ema_explorer = create_ema_explorer()
display(ema_explorer)

print("\n✅ Utilities and Helpers notebook complete!")
print("\nThis notebook provides:")
print("• Interactive loss function exploration and testing")
print("• Optimizer and scheduler configuration testing")
print("• Comprehensive metrics computation and analysis")
print("• EMA behavior visualization and analysis")
print("• Utility function testing and validation")