In [None]:
# === Cell 1: Shared Cache Bootstrap ===
import os, pathlib, torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from typing import Dict, List, Optional, Tuple
import time

# Shared cache setup: point every HF / torch cache at one configurable root
# (override via the AI_CACHE_ROOT environment variable) and create each
# directory up front so libraries never hit a missing path.
AI_CACHE_ROOT = os.getenv("AI_CACHE_ROOT", "/mnt/ai/cache")
for env_var, cache_dir in {
    "HF_HOME": f"{AI_CACHE_ROOT}/hf",
    "TRANSFORMERS_CACHE": f"{AI_CACHE_ROOT}/hf/transformers",
    "HF_DATASETS_CACHE": f"{AI_CACHE_ROOT}/hf/datasets",
    "HUGGINGFACE_HUB_CACHE": f"{AI_CACHE_ROOT}/hf/hub",
    "TORCH_HOME": f"{AI_CACHE_ROOT}/torch",
}.items():
    os.environ[env_var] = cache_dir
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

print(f"[Cache] Root: {AI_CACHE_ROOT}")
print(f"[GPU] Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"[GPU] Device: {gpu_name}")
    print(f"[GPU] Memory: {gpu_total_gb:.1f} GB")

# Later cells pass this device to the training code.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"[Device] Using: {device}")

In [None]:
# === Cell 2: Custom nn.Module - SimpleMLP Implementation ===
class SimpleMLP(nn.Module):
    """
    Configurable multi-layer perceptron (used here for binary classification).

    Each hidden width ``h`` contributes ``Linear -> [BatchNorm1d] -> ReLU ->
    [Dropout]``; a final ``Linear`` maps to ``output_dim`` raw outputs
    (logits — no activation is applied on the output layer).

    Args:
        input_dim: number of input features.
        hidden_dims: widths of the hidden layers, in order (may be empty).
        output_dim: number of outputs of the final linear layer.
        dropout_rate: dropout probability; ``0`` omits the Dropout layers.
        use_batch_norm: insert ``BatchNorm1d`` after each hidden Linear.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int = 1,
        dropout_rate: float = 0.2,
        use_batch_norm: bool = True,
    ):
        super().__init__()

        # Keep the configuration around for inspection / reconstruction.
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.output_dim = output_dim
        self.dropout_rate = dropout_rate
        self.use_batch_norm = use_batch_norm

        # Consecutive (fan_in, fan_out) pairs: input -> h1 -> h2 -> ...
        widths = [input_dim, *hidden_dims]
        stack: List[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            if use_batch_norm:
                stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.ReLU())
            if dropout_rate > 0:
                stack.append(nn.Dropout(dropout_rate))

        # Output projection from the last hidden width (or the input if no hidden layers).
        stack.append(nn.Linear(widths[-1], output_dim))

        self.layers = nn.Sequential(*stack)
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every Linear layer: Xavier-uniform weights, zero biases."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the raw outputs."""
        return self.layers(x)

    def get_num_parameters(self) -> int:
        """Return the number of elements across all parameters with ``requires_grad``."""
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        return sum(p.numel() for p in trainable)


# Build the reference architecture and report its structure and size.
model = SimpleMLP(output_dim=1, hidden_dims=[128, 64, 32], input_dim=20)
print(
    f"[Model] Architecture: {model}\n"
    f"[Model] Total parameters: {model.get_num_parameters():,}"
)

# Shape check on a random batch (32 rows x 20 features); no autograd graph needed.
dummy_input = torch.randn(32, 20)
with torch.no_grad():
    output = model(dummy_input)
print(f"[Test] Input shape: {dummy_input.shape}\n[Test] Output shape: {output.shape}")

In [None]:
# === Cell 3: Data Loading & Preprocessing ===
def create_synthetic_dataset(
    n_samples: int = 1000,
    n_features: int = 20,
    n_informative: int = 15,
    random_state: int = 42,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Create a standardized synthetic binary classification dataset.

    Args:
        n_samples: number of rows to generate.
        n_features: number of feature columns.
        n_informative: number of informative features. Clamped to
            ``n_features`` because ``make_classification`` rejects more
            informative features than total features; without the clamp the
            default of 15 fails for every ``n_features < 15``.
        random_state: seed forwarded to ``make_classification``.

    Returns:
        ``(X, y)``: ``X`` is a float32 tensor ``(n_samples, n_features)``
        standardized per column; ``y`` is a float32 tensor ``(n_samples, 1)``
        of 0/1 labels (the trailing dim matches a 1-logit model output for
        ``BCEWithLogitsLoss``).

    Note:
        The scaler is fit on the full dataset before any train/val split, so
        validation rows influence the normalization statistics.
    """
    # make_classification requires n_informative (+ redundant + repeated) <= n_features.
    n_informative = min(n_informative, n_features)

    X, y = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=0,
        n_clusters_per_class=1,
        random_state=random_state,
    )

    # Zero-mean / unit-variance per feature.
    X_normalized = StandardScaler().fit_transform(X)

    X_tensor = torch.FloatTensor(X_normalized)
    y_tensor = torch.FloatTensor(y).unsqueeze(1)  # (N,) -> (N, 1) for BCEWithLogitsLoss

    return X_tensor, y_tensor


def create_data_loaders(
    X: torch.Tensor,
    y: torch.Tensor,
    batch_size: int = 32,
    train_ratio: float = 0.8,
    random_state: int = 42,
) -> Tuple[DataLoader, DataLoader]:
    """
    Split ``(X, y)`` into train/validation subsets and wrap each in a DataLoader.

    The first ``int(train_ratio * len(X))`` randomly chosen rows go to
    training (shuffled every epoch); the rest go to validation (fixed order).

    Side effect: calls ``torch.manual_seed(random_state)``, which reseeds the
    global torch RNG — the split and any later random draws in the session
    are reproducible from this point.

    Returns:
        ``(train_loader, val_loader)``.
    """
    torch.manual_seed(random_state)

    full_dataset = TensorDataset(X, y)
    n_total = len(full_dataset)
    n_train = int(train_ratio * n_total)
    train_subset, val_subset = random_split(full_dataset, [n_train, n_total - n_train])

    loaders = (
        DataLoader(train_subset, batch_size=batch_size, shuffle=True),
        DataLoader(val_subset, batch_size=batch_size, shuffle=False),
    )

    for label, count in (
        ("Total samples", n_total),
        ("Train samples", len(train_subset)),
        ("Validation samples", len(val_subset)),
    ):
        print(f"[Data] {label}: {count}")

    return loaders


# Build the 1000-sample / 20-feature synthetic problem and wrap it in loaders.
X, y = create_synthetic_dataset(n_features=20, n_samples=1000)
train_loader, val_loader = create_data_loaders(X, y, batch_size=32)

print(
    f"[Dataset] Feature shape: {X.shape}\n"
    f"[Dataset] Label shape: {y.shape}\n"
    f"[Dataset] Class distribution: {torch.bincount(y.squeeze().long())}"
)

In [None]:
# === Cell 4: Manual Training Loop Implementation ===
def train_epoch_manual(
    model: nn.Module,
    train_loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> Dict[str, float]:
    """
    Run one hand-written training epoch over ``train_loader``.

    Per batch: zero grads -> forward -> loss -> backward -> optimizer step.

    Returns:
        ``{"loss": mean of per-batch losses, "accuracy": correct / rows}``
        where a prediction is ``sigmoid(output) > 0.5`` compared element-wise
        with the target (intended for single-logit binary targets).
    """
    model.train()  # Dropout active, BatchNorm uses batch statistics

    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        # Logits -> probabilities -> hard 0/1 predictions at a 0.5 threshold.
        hard_preds = (torch.sigmoid(logits) > 0.5).float()
        n_correct += (hard_preds == labels).sum().item()
        n_seen += labels.size(0)

    return {"loss": running_loss / len(train_loader), "accuracy": n_correct / n_seen}


def validate_epoch_manual(
    model: nn.Module, val_loader: DataLoader, criterion: nn.Module, device: torch.device
) -> Dict[str, float]:
    """
    Evaluate ``model`` on ``val_loader`` once without updating any weights.

    Runs in eval mode under ``torch.no_grad()``.

    Returns:
        ``{"loss": mean of per-batch losses, "accuracy": correct / rows}``
        where a prediction is ``sigmoid(output) > 0.5`` compared element-wise
        with the target.
    """
    model.eval()  # BatchNorm uses running stats; Dropout disabled

    batch_losses: List[float] = []
    hits = 0
    rows = 0

    with torch.no_grad():  # no autograd graph needed for evaluation
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            batch_losses.append(criterion(logits, labels).item())
            hits += ((torch.sigmoid(logits) > 0.5).float() == labels).sum().item()
            rows += labels.size(0)

    return {"loss": sum(batch_losses) / len(val_loader), "accuracy": hits / rows}


# One epoch of the hand-written loops on a smaller MLP.
model = SimpleMLP(input_dim=20, hidden_dims=[64, 32], output_dim=1)
model = model.to(device)
criterion, optimizer = nn.BCEWithLogitsLoss(), optim.Adam(model.parameters(), lr=1e-3)

print("[Manual Training] Testing one epoch...")
train_metrics = train_epoch_manual(model, train_loader, criterion, optimizer, device)
val_metrics = validate_epoch_manual(model, val_loader, criterion, device)

for split_name, split_metrics in (("Train", train_metrics), ("Val", val_metrics)):
    print(
        f"{split_name} - Loss: {split_metrics['loss']:.4f}, "
        f"Accuracy: {split_metrics['accuracy']:.4f}"
    )

In [None]:
# === Cell 5: Reusable Trainer Class Design ===
class Trainer:
    """
    Reusable trainer for PyTorch models that emit logits shaped like the target.

    Features:
    - early stopping on validation loss;
    - optional LR scheduling (``ReduceLROnPlateau`` is stepped with the
      validation loss, any other scheduler once per epoch);
    - optional gradient-norm clipping;
    - an in-memory snapshot of the best-validation weights, restored at the
      end of ``fit``.

    Accuracy thresholds ``sigmoid(output)`` at 0.5 and compares element-wise
    with ``target``. It is meant for binary targets such as ``(batch, 1)``;
    TODO confirm before using it with other output shapes.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        optimizer: optim.Optimizer,
        device: torch.device,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
        gradient_clip_val: Optional[float] = None,
        early_stopping_patience: int = 10,
    ):
        """
        Args:
            model: network to train; moved to ``device`` (in place).
            criterion: loss, called as ``criterion(output, target)``.
            optimizer: optimizer over ``model``'s parameters.
            device: device every batch is moved to.
            scheduler: optional LR scheduler, stepped once per epoch.
            gradient_clip_val: max total grad norm; ``None`` disables clipping.
            early_stopping_patience: epochs without val-loss improvement
                before training stops.
        """
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.scheduler = scheduler
        self.gradient_clip_val = gradient_clip_val
        self.early_stopping_patience = early_stopping_patience

        # Per-epoch metrics, appended by ``fit``.
        self.history = {
            "train_loss": [],
            "train_accuracy": [],
            "val_loss": [],
            "val_accuracy": [],
            "learning_rates": [],
        }

        # Early-stopping state. ``best_model_state`` holds an independent copy
        # of the weights from the best epoch so far.
        self.best_val_loss = float("inf")
        self.patience_counter = 0
        self.best_model_state = None

    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """Train for one epoch; return mean per-batch loss and accuracy."""
        self.model.train()

        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        for data, target in train_loader:
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()

            # Clip after backward and before the step so the step sees clipped grads.
            if self.gradient_clip_val is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.gradient_clip_val
                )

            self.optimizer.step()

            total_loss += loss.item()
            predicted = (torch.sigmoid(output) > 0.5).float()
            correct_predictions += (predicted == target).sum().item()
            total_samples += target.size(0)

        return {
            "loss": total_loss / len(train_loader),
            "accuracy": correct_predictions / total_samples,
        }

    def validate_epoch(self, val_loader: DataLoader) -> Dict[str, float]:
        """Evaluate once in eval mode without gradients; return loss and accuracy."""
        self.model.eval()

        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                total_loss += loss.item()
                predicted = (torch.sigmoid(output) > 0.5).float()
                correct_predictions += (predicted == target).sum().item()
                total_samples += target.size(0)

        return {
            "loss": total_loss / len(val_loader),
            "accuracy": correct_predictions / total_samples,
        }

    def _snapshot_state(self) -> Dict[str, torch.Tensor]:
        """
        Return an independent copy of the model's ``state_dict``.

        ``state_dict()`` returns tensors that share storage with the live
        parameters and buffers, and ``dict.copy()`` is shallow. Without cloning,
        the stored "best" weights would keep changing as training continues.
        """
        return {
            name: tensor.detach().clone()
            for name, tensor in self.model.state_dict().items()
        }

    def fit(
        self,
        train_loader: DataLoader,
        val_loader: DataLoader,
        epochs: int,
        verbose: bool = True,
    ) -> Dict[str, List[float]]:
        """
        Train for up to ``epochs`` epochs with early stopping on validation loss.

        After the loop, the weights from the epoch with the lowest validation
        loss are loaded back into ``self.model``.

        Returns:
            ``self.history`` (metric name -> per-epoch list).
        """
        print(f"[Trainer] Starting training for {epochs} epochs...")
        print(
            f"[Trainer] Model parameters: {sum(p.numel() for p in self.model.parameters()):,}"
        )

        for epoch in range(epochs):
            start_time = time.time()

            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.validate_epoch(val_loader)

            # ReduceLROnPlateau needs the monitored metric; other schedulers step per epoch.
            if self.scheduler is not None:
                if isinstance(self.scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                    self.scheduler.step(val_metrics["loss"])
                else:
                    self.scheduler.step()

            self.history["train_loss"].append(train_metrics["loss"])
            self.history["train_accuracy"].append(train_metrics["accuracy"])
            self.history["val_loss"].append(val_metrics["loss"])
            self.history["val_accuracy"].append(val_metrics["accuracy"])
            self.history["learning_rates"].append(self.optimizer.param_groups[0]["lr"])

            # Early stopping bookkeeping: snapshot weights on every improvement.
            if val_metrics["loss"] < self.best_val_loss:
                self.best_val_loss = val_metrics["loss"]
                self.patience_counter = 0
                self.best_model_state = self._snapshot_state()
            else:
                self.patience_counter += 1

            if verbose:
                epoch_time = time.time() - start_time
                print(
                    f"Epoch {epoch+1:3d}/{epochs} | "
                    f"Train Loss: {train_metrics['loss']:.4f} | "
                    f"Train Acc: {train_metrics['accuracy']:.4f} | "
                    f"Val Loss: {val_metrics['loss']:.4f} | "
                    f"Val Acc: {val_metrics['accuracy']:.4f} | "
                    f"LR: {self.optimizer.param_groups[0]['lr']:.6f} | "
                    f"Time: {epoch_time:.2f}s"
                )

            if self.patience_counter >= self.early_stopping_patience:
                print(f"[Trainer] Early stopping triggered after {epoch+1} epochs")
                print(f"[Trainer] Best validation loss: {self.best_val_loss:.4f}")
                break

        # Restore the best-validation weights (not the last epoch's).
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            print(f"[Trainer] Loaded best model (val_loss: {self.best_val_loss:.4f})")

        return self.history

    def plot_training_history(self):
        """Plot loss, accuracy, learning rate and a combined loss/accuracy view."""
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Loss plot
        axes[0, 0].plot(self.history["train_loss"], label="Train Loss", color="blue")
        axes[0, 0].plot(self.history["val_loss"], label="Val Loss", color="red")
        axes[0, 0].set_title("Loss")
        axes[0, 0].set_xlabel("Epoch")
        axes[0, 0].set_ylabel("Loss")
        axes[0, 0].legend()
        axes[0, 0].grid(True)

        # Accuracy plot
        axes[0, 1].plot(self.history["train_accuracy"], label="Train Acc", color="blue")
        axes[0, 1].plot(self.history["val_accuracy"], label="Val Acc", color="red")
        axes[0, 1].set_title("Accuracy")
        axes[0, 1].set_xlabel("Epoch")
        axes[0, 1].set_ylabel("Accuracy")
        axes[0, 1].legend()
        axes[0, 1].grid(True)

        # Learning rate plot
        axes[1, 0].plot(self.history["learning_rates"], color="green")
        axes[1, 0].set_title("Learning Rate")
        axes[1, 0].set_xlabel("Epoch")
        axes[1, 0].set_ylabel("Learning Rate")
        axes[1, 0].grid(True)

        # Combined metrics: loss on the left axis, accuracy on a twin right axis.
        axes[1, 1].plot(self.history["train_loss"], label="Train Loss", alpha=0.7)
        axes[1, 1].plot(self.history["val_loss"], label="Val Loss", alpha=0.7)
        ax2 = axes[1, 1].twinx()
        ax2.plot(
            self.history["train_accuracy"], label="Train Acc", color="orange", alpha=0.7
        )
        ax2.plot(
            self.history["val_accuracy"], label="Val Acc", color="purple", alpha=0.7
        )
        axes[1, 1].set_title("Combined Metrics")
        axes[1, 1].set_xlabel("Epoch")
        axes[1, 1].set_ylabel("Loss")
        ax2.set_ylabel("Accuracy")
        axes[1, 1].legend(loc="upper left")
        ax2.legend(loc="upper right")

        plt.tight_layout()
        plt.show()

In [None]:
# === Cell 6: Model Training & Validation Experiment ===
# Initialize fresh model and training components
model = SimpleMLP(input_dim=20, hidden_dims=[128, 64, 32], output_dim=1)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)

# Halve the LR after 5 epochs without val-loss improvement.
# `verbose=True` is intentionally not passed: it is deprecated in recent PyTorch
# releases, and the Trainer already prints the current LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=5
)

# Create trainer (it moves the model to `device`)
trainer = Trainer(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    scheduler=scheduler,
    gradient_clip_val=1.0,
    early_stopping_patience=15,
)

# Train the model
print("=" * 60)
print("TRAINING EXPERIMENT")
print("=" * 60)

history = trainer.fit(
    train_loader=train_loader, val_loader=val_loader, epochs=50, verbose=True
)

# Plot training history
trainer.plot_training_history()

# Final evaluation of the restored best model (both splits in eval mode)
print("\n" + "=" * 40)
print("FINAL EVALUATION")
print("=" * 40)

final_train_metrics = trainer.validate_epoch(train_loader)
final_val_metrics = trainer.validate_epoch(val_loader)

print(
    f"Final Train - Loss: {final_train_metrics['loss']:.4f}, Accuracy: {final_train_metrics['accuracy']:.4f}"
)
print(
    f"Final Val   - Loss: {final_val_metrics['loss']:.4f}, Accuracy: {final_val_metrics['accuracy']:.4f}"
)

In [None]:
# === Cell 7: Model Save/Load & Checkpoint Mechanism ===
import json
from datetime import datetime


class ModelCheckpoint:
    """
    Save and load training checkpoints.

    A checkpoint bundles the model and optimizer ``state_dict``s with the
    epoch, metrics, history, a timestamp and basic model info. A small
    human-readable JSON summary is written next to it.
    """

    @staticmethod
    def save_checkpoint(
        model: nn.Module,
        optimizer: optim.Optimizer,
        epoch: int,
        metrics: Dict[str, float],
        history: Dict[str, List[float]],
        filepath: str,
        metadata: Optional[Dict] = None,
    ):
        """
        Save a complete training checkpoint plus a JSON summary.

        Args:
            model: model whose ``state_dict`` is stored.
            optimizer: optimizer whose ``state_dict`` is stored.
            epoch: epoch number recorded in the checkpoint.
            metrics: metrics recorded in the checkpoint and the summary.
            history: training history recorded in the checkpoint.
            filepath: destination of the ``torch.save`` payload (``str`` or
                path-like). The summary goes to ``<stem>_summary.json`` in the
                same directory, e.g. ``run.pth`` -> ``run_summary.json``.
            metadata: optional extra dict, stored only when non-empty.
        """
        checkpoint = {
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "metrics": metrics,
            "history": history,
            "timestamp": datetime.now().isoformat(),
            "model_info": {
                "class_name": model.__class__.__name__,
                "num_parameters": sum(p.numel() for p in model.parameters()),
            },
        }

        if metadata:
            checkpoint["metadata"] = metadata

        torch.save(checkpoint, filepath)
        print(f"[Checkpoint] Saved to: {filepath}")

        # Derive the summary name from the file name only. A plain string
        # replace of ".pth" would return the checkpoint path itself for other
        # suffixes (overwriting the checkpoint with JSON). It would also rewrite
        # ".pth" occurring in directory names.
        checkpoint_file = pathlib.Path(filepath)
        summary_path = checkpoint_file.with_name(f"{checkpoint_file.stem}_summary.json")
        summary = {
            "model_class": checkpoint["model_info"]["class_name"],
            "num_parameters": checkpoint["model_info"]["num_parameters"],
            "epoch": epoch,
            "metrics": metrics,
            "timestamp": checkpoint["timestamp"],
        }

        with open(summary_path, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2)

    @staticmethod
    def load_checkpoint(
        filepath: str,
        model: nn.Module,
        optimizer: Optional[optim.Optimizer] = None,
        device: torch.device = torch.device("cpu"),
    ):
        """
        Load a checkpoint written by ``save_checkpoint`` into ``model``.

        Args:
            filepath: checkpoint path (``str`` or path-like).
            model: model whose weights are overwritten in place.
            optimizer: if given and the checkpoint has optimizer state, it is
                restored as well.
            device: ``map_location`` for the loaded tensors.

        Returns:
            ``(epoch, metrics, history)`` as stored in the checkpoint.
        """
        # NOTE(review): torch.load unpickles its input; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(filepath, map_location=device)

        model.load_state_dict(checkpoint["model_state_dict"])

        if optimizer is not None and "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        print(f"[Checkpoint] Loaded from: {filepath}")
        print(f"[Checkpoint] Epoch: {checkpoint['epoch']}")
        print(f"[Checkpoint] Timestamp: {checkpoint['timestamp']}")

        return checkpoint["epoch"], checkpoint["metrics"], checkpoint["history"]


# Save current model checkpoint under the shared cache root
checkpoint_dir = pathlib.Path(AI_CACHE_ROOT) / "checkpoints" / "nb02_experiments"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

checkpoint_path = checkpoint_dir / f"simple_mlp_best_{int(time.time())}.pth"

ModelCheckpoint.save_checkpoint(
    model=trainer.model,
    optimizer=trainer.optimizer,
    epoch=len(trainer.history["train_loss"]),
    metrics=final_val_metrics,
    history=trainer.history,
    filepath=str(checkpoint_path),
    metadata={
        "dataset_info": "synthetic_binary_classification",
        "model_config": {
            "input_dim": 20,
            "hidden_dims": [128, 64, 32],
            "output_dim": 1,
        },
    },
)

# Test loading checkpoint into a freshly initialised model of the same architecture
print("\n[Test] Loading checkpoint...")
new_model = SimpleMLP(input_dim=20, hidden_dims=[128, 64, 32], output_dim=1)
new_optimizer = optim.Adam(new_model.parameters(), lr=0.001)

epoch, metrics, loaded_history = ModelCheckpoint.load_checkpoint(
    checkpoint_path, new_model, new_optimizer, device
)

print(f"[Test] Loaded model metrics: {metrics}")
print(f"[Test] History length: {len(loaded_history['train_loss'])} epochs")

# Verify the round trip instead of only printing. The restored model must hold
# exactly the trained weights and buffers (e.g. BatchNorm running stats).
trained_state = trainer.model.state_dict()
restored_state = new_model.state_dict()
assert list(trained_state) == list(restored_state), "State-dict keys differ after reload"
mismatched = [
    name
    for name in trained_state
    if not torch.equal(trained_state[name].cpu(), restored_state[name].cpu())
]
assert not mismatched, f"Weights changed during checkpoint round trip: {mismatched}"
assert loaded_history == trainer.history, "Training history changed during round trip"
print("[Test] Checkpoint round trip verified")

In [None]:
# === Cell 8: Advanced Techniques - Gradient Accumulation & Mixed Precision ===
from torch.cuda.amp import autocast, GradScaler


class AdvancedTrainer(Trainer):
    """
    Trainer extension with gradient accumulation and CUDA mixed precision.

    Mixed precision is used only when ``use_mixed_precision`` is True *and*
    ``device.type == "cuda"`` (a ``GradScaler`` is created only then).
    Otherwise training falls back to full precision.
    """

    def __init__(
        self,
        model: nn.Module,
        criterion: nn.Module,
        optimizer: optim.Optimizer,
        device: torch.device,
        scheduler: Optional[optim.lr_scheduler._LRScheduler] = None,
        gradient_clip_val: Optional[float] = None,
        early_stopping_patience: int = 10,
        gradient_accumulation_steps: int = 1,
        use_mixed_precision: bool = False,
    ):
        """
        Args:
            (first seven): as in ``Trainer``.
            gradient_accumulation_steps: number of batches whose gradients are
                summed before each optimizer step (must be >= 1).
            use_mixed_precision: enable autocast + gradient scaling on CUDA.

        Raises:
            ValueError: if ``gradient_accumulation_steps < 1``.
        """
        # Validate before anything else: 0 would divide by zero in train_epoch.
        if gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
            )

        super().__init__(
            model,
            criterion,
            optimizer,
            device,
            scheduler,
            gradient_clip_val,
            early_stopping_patience,
        )

        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.use_mixed_precision = use_mixed_precision

        # Loss scaler exists only for CUDA mixed precision.
        self.scaler = (
            GradScaler() if use_mixed_precision and device.type == "cuda" else None
        )

        print(
            f"[AdvancedTrainer] Gradient accumulation steps: {gradient_accumulation_steps}"
        )
        print(f"[AdvancedTrainer] Mixed precision: {use_mixed_precision}")

    def train_epoch(self, train_loader: DataLoader) -> Dict[str, float]:
        """
        Train one epoch with gradient accumulation and optional mixed precision.

        The optimizer steps after every ``gradient_accumulation_steps`` batches
        and on the final batch, so no gradients carry over into the next epoch.
        The reported loss is the mean of the unscaled per-batch losses.
        """
        self.model.train()

        total_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        accum_steps = self.gradient_accumulation_steps
        num_batches = len(train_loader)
        use_amp = self.use_mixed_precision and self.scaler is not None

        # Start from clean gradients so nothing left on the parameters before
        # this epoch is folded into the first update.
        self.optimizer.zero_grad()

        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(self.device), target.to(self.device)

            if use_amp:
                with autocast():
                    output = self.model(data)
                    # Divide so the accumulated gradient averages over the window.
                    loss = self.criterion(output, target) / accum_steps
                self.scaler.scale(loss).backward()
            else:
                output = self.model(data)
                loss = self.criterion(output, target) / accum_steps
                loss.backward()

            # Step at the end of each window *and* on the last batch. Otherwise a
            # trailing partial window (num_batches % accum_steps != 0) would be
            # silently carried into the next epoch. That last partial window is
            # still divided by accum_steps, so its update is proportionally smaller.
            window_complete = (batch_idx + 1) % accum_steps == 0
            if window_complete or batch_idx + 1 == num_batches:
                self._apply_accumulated_gradients(use_amp)

            # Undo the window division so the reported loss is per-batch.
            total_loss += loss.item() * accum_steps
            predicted = (torch.sigmoid(output) > 0.5).float()
            correct_predictions += (predicted == target).sum().item()
            total_samples += target.size(0)

        return {
            "loss": total_loss / len(train_loader),
            "accuracy": correct_predictions / total_samples,
        }

    def _apply_accumulated_gradients(self, use_amp: bool) -> None:
        """Optionally clip, take one optimizer step, then reset gradients."""
        if use_amp:
            if self.gradient_clip_val is not None:
                # Unscale first so the clip threshold applies to the true gradient norm.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.gradient_clip_val
                )
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            if self.gradient_clip_val is not None:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.gradient_clip_val
                )
            self.optimizer.step()
        self.optimizer.zero_grad()


# Demonstrate advanced training (optional for GPU users). Mixed precision needs
# CUDA, so CPU-only machines just get an explanation.
if not torch.cuda.is_available():
    print("\n[Advanced] Skipping advanced features (CPU mode)")
    print("Advanced features (mixed precision) require CUDA-enabled GPU")
else:
    print("\n" + "=" * 50 + "\nADVANCED TRAINING DEMO (GPU)\n" + "=" * 50)

    # Wider network + AdamW; 4 accumulation steps simulate a 4x larger batch.
    advanced_model = SimpleMLP(input_dim=20, hidden_dims=[256, 128, 64], output_dim=1)
    advanced_optimizer = optim.AdamW(
        advanced_model.parameters(), weight_decay=1e-4, lr=0.001
    )

    advanced_trainer = AdvancedTrainer(
        model=advanced_model,
        optimizer=advanced_optimizer,
        criterion=nn.BCEWithLogitsLoss(),
        device=device,
        gradient_clip_val=1.0,
        early_stopping_patience=10,
        gradient_accumulation_steps=4,
        use_mixed_precision=True,
    )

    print("[Advanced] Running 10 epochs with gradient accumulation + mixed precision...")
    advanced_history = advanced_trainer.fit(
        train_loader, val_loader, epochs=10, verbose=True
    )

In [None]:
# === Cell 9: Smoke Test - Verification ===
def run_smoke_tests():
    """
    Run quick end-to-end checks of the notebook's components and print a report.

    Checks, in order:
    - SimpleMLP construction and forward pass;
    - parameter counting;
    - dataset and loader creation;
    - a 2-epoch ``Trainer.fit`` on CPU;
    - a checkpoint save/load round trip in the system temp directory.

    Each check records a "✅"/"❌" line instead of raising, so one failure does
    not hide the others. Later checks reuse objects from earlier ones and also
    fail if those were never created.

    Returns:
        True if every check passed, otherwise False.
    """
    import tempfile  # only needed here, for a portable scratch location

    print("🧪 RUNNING SMOKE TESTS...")
    print("=" * 40)

    test_results = []

    # Test 1: Model creation and forward pass
    try:
        test_model = SimpleMLP(input_dim=10, hidden_dims=[32, 16], output_dim=1)
        test_input = torch.randn(8, 10)
        test_output = test_model(test_input)
        assert test_output.shape == (
            8,
            1,
        ), f"Expected shape (8, 1), got {test_output.shape}"
        test_results.append("✅ Model creation and forward pass")
    except Exception as e:
        test_results.append(f"❌ Model creation failed: {e}")

    # Test 2: Parameter counting
    try:
        param_count = test_model.get_num_parameters()
        assert param_count > 0, "Parameter count should be positive"
        test_results.append(f"✅ Parameter counting: {param_count:,} params")
    except Exception as e:
        test_results.append(f"❌ Parameter counting failed: {e}")

    # Test 3: Data loading
    try:
        # make_classification requires n_informative <= n_features, and the
        # function's default (15) exceeds the 10 features used here.
        X_test, y_test = create_synthetic_dataset(
            n_samples=100, n_features=10, n_informative=8
        )
        train_loader_test, val_loader_test = create_data_loaders(
            X_test, y_test, batch_size=16
        )
        assert len(train_loader_test) > 0, "Train loader should not be empty"
        assert len(val_loader_test) > 0, "Val loader should not be empty"
        test_results.append("✅ Data loading and splitting")
    except Exception as e:
        test_results.append(f"❌ Data loading failed: {e}")

    # Test 4: Training loop
    try:
        smoke_trainer = Trainer(
            model=test_model,
            criterion=nn.BCEWithLogitsLoss(),
            optimizer=optim.Adam(test_model.parameters(), lr=0.01),
            device=torch.device("cpu"),  # Force CPU for reliability
            early_stopping_patience=5,
        )

        # Train for just 2 epochs
        smoke_history = smoke_trainer.fit(
            train_loader_test, val_loader_test, epochs=2, verbose=False
        )
        assert len(smoke_history["train_loss"]) == 2, "Should have 2 epochs of history"
        test_results.append("✅ Training loop execution")
    except Exception as e:
        test_results.append(f"❌ Training loop failed: {e}")

    # Test 5: Checkpoint save/load (portable temp dir; always cleaned up)
    checkpoint_path = os.path.join(
        tempfile.gettempdir(), f"smoke_test_checkpoint_{int(time.time())}.pth"
    )
    summary_path = os.path.splitext(checkpoint_path)[0] + "_summary.json"
    try:
        ModelCheckpoint.save_checkpoint(
            model=test_model,
            optimizer=smoke_trainer.optimizer,
            epoch=2,
            metrics={"loss": 0.5, "accuracy": 0.8},
            history=smoke_history,
            filepath=checkpoint_path,
        )

        new_test_model = SimpleMLP(input_dim=10, hidden_dims=[32, 16], output_dim=1)
        epoch, metrics, history = ModelCheckpoint.load_checkpoint(
            checkpoint_path, new_test_model, device=torch.device("cpu")
        )

        assert epoch == 2, f"Expected epoch 2, got {epoch}"
        assert "loss" in metrics, "Metrics should contain loss"

        test_results.append("✅ Checkpoint save/load")
    except Exception as e:
        test_results.append(f"❌ Checkpoint save/load failed: {e}")
    finally:
        # Remove artifacts even when an assertion above failed.
        for leftover in (checkpoint_path, summary_path):
            if os.path.exists(leftover):
                os.remove(leftover)

    # Print results
    print("\nSMOKE TEST RESULTS:")
    for result in test_results:
        print(f"  {result}")

    passed_tests = sum(1 for r in test_results if r.startswith("✅"))
    total_tests = len(test_results)

    print(f"\n📊 SUMMARY: {passed_tests}/{total_tests} tests passed")

    if passed_tests == total_tests:
        print("🎉 ALL TESTS PASSED! Ready for production use.")
        return True
    else:
        print("⚠️  Some tests failed. Please review the issues above.")
        return False


# Run smoke tests
smoke_test_success = run_smoke_tests()

In [None]:
# === Cell 10: Chapter Summary & Next Steps ===
# Chapter recap, driven by a (heading, bullets) table instead of one print per line.
chapter_sections = [
    (
        "✅ COMPLETED ITEMS:",
        [
            "自訂 nn.Module 類別設計 (Custom nn.Module class design)",
            "完整訓練迴圈實作 (Complete training loop implementation)",
            "可重用 Trainer 類別 (Reusable Trainer class)",
            "Early stopping 與 learning rate scheduling",
            "模型檢查點保存與載入 (Model checkpointing)",
            "進階技巧：gradient accumulation & mixed precision",
            "完整的驗收測試機制 (Comprehensive smoke testing)",
        ],
    ),
    (
        "\n🧠 CORE CONCEPTS:",
        [
            "nn.Module 繼承與 forward() 方法實作",
            "訓練迴圈核心組件：loss.backward() → optimizer.step() → optimizer.zero_grad()",
            "train()/eval() 模式切換的重要性",
            "梯度累積用於模擬大 batch size",
            "Mixed precision 提升訓練效率",
            "檢查點機制確保訓練可恢復性",
        ],
    ),
    (
        "\n⚠️  COMMON PITFALLS:",
        [
            "忘記 optimizer.zero_grad() 導致梯度累積",
            "train()/eval() 模式不正確影響 BatchNorm/Dropout",
            "忽略 gradient clipping 導致梯度爆炸",
            "未設置隨機種子導致結果不可重現",
            "GPU memory leak (忘記 detach() 或清理變數)",
        ],
    ),
    (
        "\n🚀 NEXT STEPS:",
        [
            "進入 nb03: HF Datasets 前處理管線",
            "學習處理大規模文本/圖像/語音資料",
            "準備進入 Transformer 架構學習",
            "建立可擴展的資料處理工作流程",
        ],
    ),
    (
        "\n💡 PRACTICAL RECOMMENDATIONS:",
        [
            "始終從小模型開始驗證，再擴展到大模型",
            "使用 tensorboard 或 wandb 追蹤實驗（可選）",
            "建立模型架構實驗的標準化模板",
            "在移到 GPU 前先在 CPU 上驗證邏輯正確性",
            "設置適當的 early stopping 避免過擬合",
        ],
    ),
]

print("\n🎯 CHAPTER SUMMARY\n" + "=" * 50)
for heading, bullets in chapter_sections:
    print(heading)
    for bullet in bullets:
        print(f"  • {bullet}")

# Readiness verdict depends on the smoke-test cell above.
if not smoke_test_success:
    print("\n⚠️  PLEASE REVIEW FAILED TESTS BEFORE PROCEEDING")
else:
    print("\n🎊 READY FOR NEXT CHAPTER!")
    print("    All components verified and working correctly.")

print("\n" + "=" * 50)

In [None]:
# Quick validation that the core building blocks still construct and run.
print("🔍 FINAL VALIDATION", "=" * 30, sep="\n")

# Tiny model: 5 features -> 16 -> 8 -> 1 logit, forward pass on 4 random rows.
quick_model = SimpleMLP(input_dim=5, hidden_dims=[16, 8], output_dim=1)
test_data = torch.randn(4, 5)
output = quick_model(test_data)

print(f"✅ Model output shape: {output.shape}")
print(f"✅ Model parameters: {quick_model.get_num_parameters():,}")
print("✅ Forward pass successful")

# Training components can be instantiated against the model.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(quick_model.parameters(), lr=0.01)

print(f"✅ Loss function: {type(criterion).__name__}")
print(f"✅ Optimizer: {type(optimizer).__name__}")
print("🎉 All core components ready!")


## 6. 本章小結

### ✅ 完成項目
- **自訂 nn.Module 實作**：完整的 SimpleMLP 類別，支援動態層數配置
- **標準訓練迴圈**：手動實作與 Trainer 類別封裝兩種方法
- **進階訓練功能**：early stopping、learning rate scheduling、gradient clipping
- **檢查點機制**：完整的模型狀態保存與恢復功能
- **訓練效率與記憶體**：gradient accumulation（以小 batch 模擬大 batch）與 mixed precision（CUDA 上加速並節省記憶體）支援
- **品質保證**：smoke tests 快速驗證各核心元件（模型、資料、訓練、檢查點）可正常執行

### 🧠 核心原理要點
- **nn.Module 設計模式**：透過 `__init__` 定義層結構，`forward` 實作前向傳播
- **訓練迴圈本質**：`loss.backward()` → `optimizer.step()` → `optimizer.zero_grad()` 循環
- **模式切換重要性**：`train()/eval()` 影響 BatchNorm、Dropout 行為
- **記憶體管理**：適當使用 `torch.no_grad()` 與變數清理避免記憶體洩漏
- **可重現性**：設定隨機種子確保實驗結果一致

### 🚀 下一步建議
1. **立即進行**：開始 `nb03_data_preprocessing.ipynb`，學習 HF Datasets 處理大規模資料
2. **中期目標**：完成 Part A 基礎部分，為 Transformer 學習做準備
3. **長期規劃**：將 Trainer 類別擴展支援分散式訓練與更多優化器

**準備就緒進入下一章！** 🎉