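# The Complete PyTorch Deep Learning Tutorial

PyTorch is one of the most popular deep learning frameworks today. It is known for its dynamic computation graph, intuitive API design, and strong GPU acceleration. This tutorial takes you from the basics to advanced topics and shows how to apply PyTorch in real projects.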
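"Dynamic computation graph" means the graph is recorded while ordinary Python code runs. Control flow such as an `if` therefore decides which operations get differentiated on each call. The sketch below illustrates this; the function `f` is only an example.

```python
import torch

def f(x):
    # Plain Python control flow: the branch actually taken determines
    # which operations autograd records for this particular call.
    if x.sum() > 0:
        return (x ** 2).sum()
    return (x ** 3).sum()

x_pos = torch.tensor([1.0, 2.0], requires_grad=True)
f(x_pos).backward()
print(x_pos.grad)  # derivative of x**2 is 2*x

x_neg = torch.tensor([-1.0, -2.0], requires_grad=True)
f(x_neg).backward()
print(x_neg.grad)  # derivative of x**3 is 3*x**2
```

There is no separate graph-compilation step: every call builds a fresh graph from whatever code path ran.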

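## 🎯 Learning Objectives

After completing this tutorial, you will have learned:
- **PyTorch fundamentals**: tensor operations, automatic differentiation, computation graphs
- **Model building**: custom network architectures, composing layers, parameter management
- **Data handling**: custom datasets, data loaders, data augmentation
- **Training workflow**: complete train/validation/test loops, loss functions, optimizers
- **Model management**: saving and resuming checkpoints, model versioning, best-model selection
- **Experiment tracking**: TensorBoard visualization, logging, hyperparameter tracking
- **Production deployment**: model export, inference optimization, performance monitoring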

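## 📋 Full Outline

### 1. Environment Setup and PyTorch Basics
- Checking the CUDA environment
- Creating and manipulating tensors
- The automatic differentiation mechanism

### 2. Custom Neural Network Models
- Basic model architectures
- Convolutional neural networks (CNNs)
- Advanced network components

### 3. Custom Datasets and Data Loading
- Working with the MNIST dataset
- Custom Dataset classes
- Data preprocessing and augmentation

### 4. The Complete Training Workflow
- Train/validation/test splits
- Loss functions and optimizers
- Learning rate scheduling

### 5. Saving and Resuming Checkpoints
- Saving model state
- Resuming interrupted training
- Managing the best model

### 6. TensorBoard Visualization
- Tracking losses and metrics
- Visualizing model structure
- Monitoring parameter distributions

### 7. Advanced Logging and Experiment Management
- Structured logging
- Tracking hyperparameter experiments
- Performance profiling tools

### 8. A Real Project Case Study
- MNIST handwritten digit recognition
- The end-to-end project workflow
- Summary of best practices

**Core case study**: handwritten digit classification on the MNIST dataset. We build a complete deep learning project from scratch, including production-oriented features such as checkpointing, TensorBoard tracking, and structured logging.

Thanks to its Pythonic design philosophy and research-friendly features, PyTorch is widely used in both academia and industry. Let's get started!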

In [None]:
print("=== PyTorch Deep Learning Environment Setup ===")

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import os
import json
import pickle
from datetime import datetime
import logging
import warnings
from pathlib import Path
from collections import defaultdict
import time
import copy

# Filter warnings to keep the notebook output readable
warnings.filterwarnings('ignore')

# Set random seeds so that results are reproducible
def set_seed(seed=42):
    """Seed the random number generators used in this tutorial for reproducible experiments."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Check the environment
print(f"PyTorch version: {torch.__version__}")
print(f"Torchvision version: {torchvision.__version__}")
print(f"NumPy version: {np.__version__}")

# Detailed CUDA check
print(f"\n=== GPU Information ===")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"cuDNN version: {torch.backends.cudnn.version()}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"  Total memory: {torch.cuda.get_device_properties(i).total_memory / 1024**3:.1f} GB")
else:
    print("No CUDA device detected; training will run on the CPU")

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\nUsing device: {device}")

# Create the project directory structure
project_dirs = {
    'models': 'saved_models',
    'logs': 'logs',
    'tensorboard': 'runs',
    'data': 'data',
    'checkpoints': 'checkpoints',
    'outputs': 'outputs'
}

print(f"\n=== Creating Project Directories ===")
for name, path in project_dirs.items():
    Path(path).mkdir(parents=True, exist_ok=True)
    print(f"✓ {name}: {path}/")

# Configure matplotlib
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
sns.set_style("whitegrid")

# Speed vs. reproducibility
# cudnn.benchmark=True lets cuDNN auto-tune convolution algorithms for speed. This is NOT
# mixed-precision training, and it undoes the deterministic settings made in set_seed().
REPRODUCIBLE = True  # set to False to favor speed over run-to-run reproducibility
if torch.cuda.is_available() and not REPRODUCIBLE:
    torch.backends.cudnn.benchmark = True
    print(f"✓ Enabled cuDNN benchmark mode for faster convolutions")

print(f"\n=== Environment Setup Complete ===")
print(f"Project initialized at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

# Verify basic PyTorch functionality
print(f"\n=== PyTorch Sanity Check ===")
# Create a test tensor
test_tensor = torch.randn(2, 3, device=device)
print(f"Test tensor created: {test_tensor.shape} on {test_tensor.device}")

# Test autograd: f(x) = x^2 + 3x + 1, so f'(x) = 2x + 3
x = torch.tensor([2.0], requires_grad=True, device=device)
y = x ** 2 + 3 * x + 1
y.backward()
print(f"Autograd test: f(2) = {y.item():.2f}, f'(2) = {x.grad.item():.2f}")

print(f"✓ PyTorch environment verified. Ready to start!")
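A quick way to confirm that seeding works is to draw the same random tensor twice. The sketch below does this; the helper `draw` is only an illustration. It also shows a local `torch.Generator`, which produces a reproducible stream without touching the global generator that `set_seed` resets.

```python
import torch

def draw(seed):
    # Reseed the global RNG (as set_seed does), then sample
    torch.manual_seed(seed)
    return torch.randn(3)

a, b, c = draw(42), draw(42), draw(7)
same_seed_equal = torch.equal(a, b)       # identical samples
different_seed_equal = torch.equal(a, c)  # different samples (with overwhelming probability)
print(same_seed_equal, different_seed_equal)

# A local Generator gives a reproducible stream without touching the global RNG
g1 = torch.Generator().manual_seed(123)
g2 = torch.Generator().manual_seed(123)
local_equal = torch.equal(torch.randn(3, generator=g1), torch.randn(3, generator=g2))
print(local_equal)
```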

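## 2. PyTorch Tensor Basics

The tensor (`Tensor`) is PyTorch's core data structure. It is similar to a NumPy array but also supports GPU acceleration and automatic differentiation. Mastering tensor operations is the foundation for everything that follows.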
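The close relationship with NumPy has a practical consequence. On the CPU, `torch.from_numpy()` and `Tensor.numpy()` share memory with their source rather than copying it, while `torch.tensor()` always makes a copy. A minimal sketch:

```python
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])

shared = torch.from_numpy(arr)  # shares arr's buffer; float64 dtype is preserved
copied = torch.tensor(arr)      # independent copy

arr[0] = 100.0                  # modify the NumPy array in place
print(shared)                   # sees the change
print(copied)                   # unchanged

back = shared.numpy()           # also a view of the same buffer
back[1] = -1.0
print(arr)                      # arr now holds 100, -1, 3
```

Call `.clone()` (or use `torch.tensor`) when you need an independent copy.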

In [None]:
# 2.1 Creating tensors and basic operations
print("=== PyTorch Tensor Basics ===")

# 2.1.1 Ways to create tensors
print("1. Creating tensors:")

# From a Python list
tensor_from_list = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device)
print(f"From a list: \n{tensor_from_list}")

# Special tensors
zeros_tensor = torch.zeros(3, 4, device=device)
ones_tensor = torch.ones(2, 3, device=device)
random_tensor = torch.randn(2, 3, device=device)
arange_tensor = torch.arange(0, 10, 2, device=device)

print(f"\nZeros (3x4): \n{zeros_tensor}")
print(f"\nOnes (2x3): \n{ones_tensor}")
print(f"\nStandard normal random values (2x3): \n{random_tensor}")
print(f"\nArithmetic sequence: {arange_tensor}")

# From an existing tensor (same shape, dtype, and device)
like_tensor = torch.zeros_like(tensor_from_list)
rand_like_tensor = torch.randn_like(tensor_from_list)

print(f"\nZeros with the same shape: \n{like_tensor}")
print(f"\nRandom values with the same shape: \n{rand_like_tensor}")

# 2.1.2 Tensor attributes
print(f"\n2. Tensor attributes:")
sample_tensor = torch.randn(2, 3, 4, device=device)
print(f"Shape: {sample_tensor.shape}")
print(f"Number of dimensions: {sample_tensor.dim()}")
print(f"Data type: {sample_tensor.dtype}")
print(f"Device: {sample_tensor.device}")
print(f"Size (same as .shape): {sample_tensor.size()}")
print(f"Total number of elements: {sample_tensor.numel()}")

# 2.1.3 Tensor arithmetic
print(f"\n3. Tensor arithmetic:")

# Basic math
a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device=device)
b = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32, device=device)

print(f"Tensor a: \n{a}")
print(f"Tensor b: \n{b}")

# Element-wise operations
print(f"\nAddition: \n{a + b}")
print(f"Subtraction: \n{a - b}")
print(f"Element-wise multiplication: \n{a * b}")
print(f"Division: \n{a / b}")
print(f"Power: \n{a ** 2}")

# Matrix operations
print(f"\nMatrix multiplication: \n{torch.mm(a, b)}")
print(f"Matrix multiplication (@ operator): \n{a @ b}")

# Reductions
sample_data = torch.randn(100, device=device)
print(f"\nStatistics (100 random numbers):")
print(f"Mean: {sample_data.mean().item():.4f}")
print(f"Standard deviation: {sample_data.std().item():.4f}")
print(f"Max: {sample_data.max().item():.4f}")
print(f"Min: {sample_data.min().item():.4f}")
print(f"Sum: {sample_data.sum().item():.4f}")

# 2.1.4 Reshaping
print(f"\n4. Reshaping tensors:")
original = torch.arange(12, device=device)
print(f"Original tensor: {original}")

# Reshape (view shares memory with the original)
reshaped = original.view(3, 4)
print(f"Reshaped to 3x4: \n{reshaped}")

# Add a dimension
unsqueezed = original.unsqueeze(0)  # insert a new dimension at position 0
print(f"After unsqueeze: {unsqueezed.shape}")

# Remove a dimension
squeezed = unsqueezed.squeeze(0)  # remove dimension 0 (size 1)
print(f"After squeeze: {squeezed.shape}")

# Transpose
matrix = torch.randn(3, 4, device=device)
transposed = matrix.t()
print(f"Original shape: {matrix.shape}")
print(f"Transposed shape: {transposed.shape}")

# 2.1.5 Indexing and slicing
print(f"\n5. Indexing and slicing:")
data = torch.arange(24, device=device).view(4, 6)
print(f"Original data (4x6): \n{data}")

# Basic indexing
print(f"First row: {data[0]}")
print(f"First column: {data[:, 0]}")
print(f"Top-left 2x2: \n{data[:2, :2]}")

# Boolean (mask) indexing
mask = data > 10
print(f"Elements greater than 10: {data[mask]}")

# Advanced (integer-array) indexing
indices = torch.tensor([0, 2], device=device)
print(f"Rows 0 and 2: \n{data[indices]}")

# 2.1.6 In-place operations and memory
print(f"\n6. In-place operations and memory:")
x = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
print(f"Original tensor: {x}")
print(f"Data pointer: {x.data_ptr()}")

# Out-of-place operation: allocates a new tensor
y = x + 1
print(f"Out-of-place result: {y}")
print(f"Original tensor unchanged: {x}")

# In-place operation (trailing underscore): modifies x's existing memory
x.add_(1)  # equivalent to x += 1
print(f"After in-place op: {x}")
print(f"Data pointer unchanged: {x.data_ptr()}")

# Memory contiguity
non_contiguous = torch.randn(3, 4, device=device).t()
print(f"Contiguous after transpose: {non_contiguous.is_contiguous()}")
contiguous = non_contiguous.contiguous()
print(f"After .contiguous(): {contiguous.is_contiguous()}")

# 2.1.7 NumPy interoperability
print(f"\n7. NumPy interoperability:")
if device.type == 'cpu':
    # Only CPU tensors convert to/from NumPy directly (and they share memory)
    torch_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
    numpy_array = torch_tensor.numpy()
    print(f"PyTorch tensor: {torch_tensor}")
    print(f"As a NumPy array: {numpy_array}")

    # Create a tensor from a NumPy array
    new_torch = torch.from_numpy(numpy_array)
    print(f"From NumPy: {new_torch}")
else:
    # GPU tensors must be moved to the CPU first
    gpu_tensor = torch.tensor([1, 2, 3, 4], dtype=torch.float32, device=device)
    cpu_tensor = gpu_tensor.cpu()
    numpy_array = cpu_tensor.numpy()
    print(f"GPU tensor as NumPy: {numpy_array}")

print(f"\n✓ Tensor basics complete!")

# Visualize a few tensor operations
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# 1. A random tensor shown as an image (draw once, attach a colorbar to that image)
random_2d = torch.randn(10, 10, device='cpu')
im = axes[0, 0].imshow(random_2d.numpy(), cmap='viridis')
axes[0, 0].set_title('Random tensor')
fig.colorbar(im, ax=axes[0, 0])

# 2. Results of element-wise math functions
x = torch.linspace(-3, 3, 100)
y1 = torch.sin(x)
y2 = torch.cos(x)
y3 = torch.exp(-x**2)

axes[0, 1].plot(x.numpy(), y1.numpy(), label='sin(x)')
axes[0, 1].plot(x.numpy(), y2.numpy(), label='cos(x)')
axes[0, 1].plot(x.numpy(), y3.numpy(), label='exp(-x²)')
axes[0, 1].set_title('Tensor math functions')
axes[0, 1].legend()
axes[0, 1].grid(True)

# 3. Matrix multiplication operands (C = A @ B has shape 5x4)
A = torch.randn(5, 3)
B = torch.randn(3, 4)
C = torch.mm(A, B)

axes[1, 0].imshow(A.numpy(), cmap='RdBu', aspect='auto')
axes[1, 0].set_title('Matrix A (5x3)')

axes[1, 1].imshow(B.numpy(), cmap='RdBu', aspect='auto')
axes[1, 1].set_title('Matrix B (3x4)')

plt.tight_layout()
plt.show()

# Performance benchmark
print(f"\n=== Performance Benchmark ===")
def benchmark_operation(operation_func, tensor_size=(1000, 1000), iterations=100):
    """Return the average run time of operation_func in milliseconds."""
    tensor = torch.randn(tensor_size, device=device)

    # Warm-up: the first calls include one-time setup costs
    for _ in range(10):
        _ = operation_func(tensor)

    # CUDA kernels run asynchronously, so synchronize before starting and stopping the clock
    if device.type == 'cuda':
        torch.cuda.synchronize()

    start_time = time.perf_counter()
    for _ in range(iterations):
        _ = operation_func(tensor)

    if device.type == 'cuda':
        torch.cuda.synchronize()

    end_time = time.perf_counter()
    avg_time = (end_time - start_time) / iterations * 1000  # milliseconds
    return avg_time

# Compare the cost of different operations
operations = {
    'Matrix multiplication': lambda x: torch.mm(x, x.t()),
    'Element-wise addition': lambda x: x + x,
    'Sine': lambda x: torch.sin(x),
    'Exponential': lambda x: torch.exp(x),
    'Sum': lambda x: torch.sum(x)
}

print(f"Device: {device}")
print(f"Tensor size: 1000x1000")
print(f"Iterations: 100")
print("-" * 40)

for name, op in operations.items():
    avg_time = benchmark_operation(op)
    print(f"{name:22}: {avg_time:.2f} ms")

print(f"\n✓ Tensor basics and benchmark complete!")
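The contiguity check in 2.1.6 matters because `view()` only reinterprets existing memory. It fails on a non-contiguous tensor such as a transpose, whereas `reshape()` falls back to copying when necessary. A small sketch:

```python
import torch

m = torch.arange(6).view(2, 3)  # [[0, 1, 2], [3, 4, 5]]
t = m.t()                       # shape (3, 2): same storage, swapped strides
print(t.is_contiguous())        # False

# view() cannot flatten this layout without copying, so it raises
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True

flat = t.reshape(6)             # reshape() copies when a view is impossible
print(flat)                     # t in row-major order: 0, 3, 1, 4, 2, 5
print(torch.equal(flat, t.contiguous().view(6)))  # explicit copy, then view
```

In short, use `reshape()` when you do not care whether you get a view or a copy. Use `view()` when you want an error rather than a silent copy.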

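## 3. Automatic Differentiation (Autograd)

Automatic differentiation is one of PyTorch's core features. It computes gradients automatically and underpins the training of deep learning models. As operations run on tensors with `requires_grad=True`, PyTorch records them in a computation graph. Calling `backward()` then traverses that graph in reverse, applying the chain rule. Understanding this mechanism is essential to using PyTorch well.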
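Calling `backward()` with no arguments only works on a scalar output. For a non-scalar output, autograd computes a vector-Jacobian product, so you must pass the vector `v` explicitly. Passing `v = ones` is equivalent to calling `backward()` on `y.sum()`, the trick used in 3.1.6. A minimal sketch:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2  # non-scalar output of shape (3,)

# Without an explicit gradient, backward() only accepts scalar outputs
try:
    y.backward()
    scalar_only = False
except RuntimeError:
    scalar_only = True  # "grad can be implicitly created only for scalar outputs"

# Vector-Jacobian product with v = ones: same as y.sum().backward(), i.e. 2 * x
y.backward(gradient=torch.ones_like(y))
grad_ones = x.grad.clone()
print(grad_ones)

# A different v weights each output: x.grad = v * 2 * x
x.grad = None  # clear the accumulated gradient
(x ** 2).backward(gradient=torch.tensor([1.0, 0.0, 0.5]))
print(x.grad)
```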

In [None]:
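# 3.1 Autograd basics
print("=== Automatic Differentiation (Autograd) ===")

# 3.1.1 requires_grad and building the computation graph
print("1. requires_grad and building the computation graph:")

# Tensors that require gradients
x = torch.tensor([2.0], requires_grad=True, device=device)
y = torch.tensor([3.0], requires_grad=True, device=device)

print(f"x = {x}, requires_grad = {x.requires_grad}")
print(f"y = {y}, requires_grad = {y.requires_grad}")

# Define the computation graph
z = x * y  # z = 2 * 3 = 6
w = z + x  # w = 6 + 2 = 8
loss = w ** 2  # loss = 8^2 = 64

print(f"z = x * y = {z}")
print(f"w = z + x = {w}")
print(f"loss = w^2 = {loss}")

# Inspect the graph: each result records the function that created it
print(f"\nloss.grad_fn: {loss.grad_fn}")
print(f"w.grad_fn: {w.grad_fn}")
print(f"z.grad_fn: {z.grad_fn}")

# 3.1.2 Backpropagation
print(f"\n2. Computing gradients with backpropagation:")
# retain_graph=True keeps the graph alive so backward() can be called again in 3.1.3;
# without it, the second call raises a RuntimeError because the graph has been freed
loss.backward(retain_graph=True)

print(f"∂loss/∂x = {x.grad}")
print(f"∂loss/∂y = {y.grad}")

# Verify the gradients by hand
# loss = (x*y + x)^2 = (2*3 + 2)^2 = 8^2 = 64
# ∂loss/∂x = 2*(x*y + x) * (y + 1) = 2*8*(3+1) = 64
# ∂loss/∂y = 2*(x*y + x) * x = 2*8*2 = 32
print(f"By hand: ∂loss/∂x = 2*8*4 = {2*8*4}")
print(f"By hand: ∂loss/∂y = 2*8*2 = {2*8*2}")

# 3.1.3 Zeroing gradients
print(f"\n3. Gradient accumulation and zeroing:")
print(f"x.grad after the first backward pass: {x.grad}")

# A second backward pass ADDS to the existing .grad (gradients accumulate)
loss.backward()
print(f"x.grad after the second backward pass: {x.grad}")

# Zero the gradients (in a training loop, optimizer.zero_grad() resets parameter gradients)
x.grad.zero_()
y.grad.zero_()
print(f"x.grad after zeroing: {x.grad}")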

# 3.1.4 Gradient context managers
print(f"\n4. Gradient context managers:")

# Disable gradient tracking
x = torch.randn(3, 3, requires_grad=True, device=device)
print(f"x.requires_grad: {x.requires_grad}")

with torch.no_grad():
    y = x * 2
    print(f"Inside no_grad, y.requires_grad: {y.requires_grad}")

    # enable_grad re-enables tracking locally, even inside a no_grad block
    with torch.enable_grad():
        y = x * 2
        print(f"Inside enable_grad (nested in no_grad), y.requires_grad: {y.requires_grad}")

# 3.1.5 detach()
print(f"\n5. Using detach():")
x = torch.randn(3, requires_grad=True, device=device)
y = x * 2

# detach() returns a tensor that shares y's data but is cut off from the graph
y_detached = y.detach()
print(f"Original y.requires_grad: {y.requires_grad}")
print(f"y_detached.requires_grad: {y_detached.requires_grad}")

# 3.1.6 Autograd on a more complex function
print(f"\n6. Autograd on a more complex function:")

def complex_function(x):
    """f(x) = sin(x)·exp(-x²) + cos(x²)"""
    return torch.sin(x) * torch.exp(-x**2) + torch.cos(x**2)

# Inputs
x = torch.linspace(-2, 2, 100, requires_grad=True, device=device)
y = complex_function(x)

# Each y[i] depends only on x[i], so backpropagating y.sum() yields f'(x[i]) at every point
loss = y.sum()
loss.backward()

print(f"Gradient range on [-2, 2]: [{x.grad.min().item():.4f}, {x.grad.max().item():.4f}]")

# Plot the function and its derivative (.cpu() is a no-op for CPU tensors)
x_np = x.detach().cpu().numpy()
y_np = y.detach().cpu().numpy()
grad_np = x.grad.cpu().numpy()

fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

ax1.plot(x_np, y_np, 'b-', linewidth=2, label='f(x)')
ax1.set_title('f(x) = sin(x)·exp(-x²) + cos(x²)')
ax1.set_ylabel('f(x)')
ax1.grid(True)
ax1.legend()

ax2.plot(x_np, grad_np, 'r-', linewidth=2, label="f'(x)")
ax2.set_title('Derivative of f')
ax2.set_xlabel('x')
ax2.set_ylabel("f'(x)")
ax2.grid(True)
ax2.legend()

plt.tight_layout()
plt.show()

# 3.1.7 Custom autograd functions
print(f"\n7. Custom autograd functions:")

class SquareFunction(torch.autograd.Function):
    """Custom square function showing how to define your own forward and backward passes."""

    @staticmethod
    def forward(ctx, input):
        """Forward pass."""
        # Save the input for use in the backward pass
        ctx.save_for_backward(input)
        return input ** 2

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass."""
        # Retrieve the saved input
        input, = ctx.saved_tensors
        # Chain rule: d(x²)/dx = 2x, multiplied by the upstream gradient
        grad_input = 2 * input * grad_output
        return grad_input

# Custom Functions are invoked through .apply
square = SquareFunction.apply

x = torch.tensor([3.0], requires_grad=True, device=device)
y = square(x)
y.backward()

print(f"Custom square: f(3) = {y.item()}")
print(f"Custom gradient: f'(3) = {x.grad.item()}")
print(f"Expected gradient: 2*3 = 6")

# 3.1.8 Higher-order gradients
print(f"\n8. Higher-order gradients:")

x = torch.tensor([2.0], requires_grad=True, device=device)
y = x ** 3  # y = x³

# First derivative; create_graph=True makes the result itself differentiable
grad_1 = torch.autograd.grad(y, x, create_graph=True)[0]
print(f"First derivative dy/dx = {grad_1.item()}")  # 3x² = 12

# Second derivative
grad_2 = torch.autograd.grad(grad_1, x)[0]
print(f"Second derivative d²y/dx² = {grad_2.item()}")  # 6x = 12

# 3.1.9 Jacobian matrix
print(f"\n9. Computing a Jacobian:")

def vector_function(x):
    """Vector function f(x, y) = [x² + y, xy, y²], where x = x[0] and y = x[1]"""
    return torch.stack([
        x[0]**2 + x[1],
        x[0] * x[1],
        x[1]**2
    ])

x = torch.tensor([2.0, 3.0], requires_grad=True, device=device)
y = vector_function(x)

# Compute the Jacobian (shape: outputs x inputs = 3x2)
jacobian = torch.autograd.functional.jacobian(vector_function, x)
print(f"Input: {x}")
print(f"Output: {y}")
print(f"Jacobian:\n{jacobian}")

# Analytic Jacobian:
# f₁ = x² + y  →  ∂f₁/∂x = 2x, ∂f₁/∂y = 1
# f₂ = xy      →  ∂f₂/∂x = y,  ∂f₂/∂y = x
# f₃ = y²      →  ∂f₃/∂x = 0,  ∂f₃/∂y = 2y
x0, x1 = x.detach().tolist()  # plain Python floats, outside the autograd graph
theoretical_jacobian = torch.tensor([
    [2*x0, 1.0],      # [4, 1]
    [x1, x0],         # [3, 2]
    [0.0, 2*x1]       # [0, 6]
], device=device)
print(f"Analytic Jacobian:\n{theoretical_jacobian}")
print(f"Match: {torch.allclose(jacobian, theoretical_jacobian)}")

print(f"\n✓ Autograd basics complete!")

# 3.1.10 Gradient checking
print(f"\n10. Gradient checking (numerical verification):")

def numerical_gradient(f, x, h=1e-5):
    """Central-difference estimate of the gradient of a scalar function f at x."""
    x = x.detach().clone()  # work on a copy outside the autograd graph
    grad = torch.zeros_like(x)
    with torch.no_grad():
        for i in range(x.numel()):
            x_pos = x.clone()
            x_neg = x.clone()
            x_pos.view(-1)[i] += h
            x_neg.view(-1)[i] -= h
            grad.view(-1)[i] = (f(x_pos) - f(x_neg)) / (2 * h)
    return grad

# Test function
def test_function(x):
    return (x**2).sum()

# Use float64: with float32 and h=1e-5, rounding error in f swamps the tiny difference
x = torch.randn(3, dtype=torch.float64, requires_grad=True, device=device)

# Gradient from autograd
loss = test_function(x)
loss.backward()
auto_grad = x.grad.clone()

# Numerical gradient
numerical_grad = numerical_gradient(test_function, x)

# Compare the two
difference = torch.abs(auto_grad - numerical_grad)
print(f"Autograd gradient: {auto_grad}")
print(f"Numerical gradient: {numerical_grad}")
print(f"Difference: {difference}")
print(f"Max difference: {difference.max().item():.2e}")

# Plot the comparison
if x.numel() <= 10:  # only plot small tensors
    fig, ax = plt.subplots(figsize=(10, 6))
    indices = range(auto_grad.numel())

    auto_grad_np = auto_grad.cpu().numpy().flatten()
    numerical_grad_np = numerical_grad.cpu().numpy().flatten()

    ax.plot(indices, auto_grad_np, 'bo-', label='Autograd', markersize=8)
    ax.plot(indices, numerical_grad_np, 'r^-', label='Numerical', markersize=8)
    ax.set_xlabel('Parameter index')
    ax.set_ylabel('Gradient value')
    ax.set_title('Gradient check: autograd vs. numerical')
    ax.legend()
    ax.grid(True)
    plt.show()

print(f"\nAutograd verification finished. Gradients correct: {'✓' if difference.max() < 1e-4 else '✗'}")
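PyTorch also provides a built-in version of this check, `torch.autograd.gradcheck`. It compares analytical and finite-difference Jacobians and is the usual way to test a custom `autograd.Function`. The sketch below applies it to a copy of `SquareFunction` from 3.1.7, redefined so the block runs on its own. The deliberately broken `BadSquareFunction` is purely illustrative. With the default settings, `gradcheck` expects double-precision inputs.

```python
import torch
from torch.autograd import gradcheck

class SquareFunction(torch.autograd.Function):
    """Copy of the custom square function from 3.1.7."""

    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        return inp ** 2

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        return 2 * inp * grad_output

class BadSquareFunction(torch.autograd.Function):
    """Illustrative bug: backward is missing the factor 2."""

    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        return inp ** 2

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        return inp * grad_output

# Double-precision inputs that require grad
x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64, requires_grad=True)
good = gradcheck(SquareFunction.apply, (x,), eps=1e-6, atol=1e-4)
# raise_exception=False makes gradcheck return False instead of raising on mismatch
bad = gradcheck(BadSquareFunction.apply, (x,), eps=1e-6, atol=1e-4, raise_exception=False)
print(good, bad)
```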

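## 4. Custom Neural Network Models

In PyTorch, neural network models are built mainly by subclassing `nn.Module`. Layers are created in `__init__`, and the computation is defined in `forward`. We start with a simple fully connected network and progressively build more complex convolutional neural networks.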
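Subclassing `nn.Module` is convenient because assigning an `nn.Parameter` or a submodule as an attribute registers it automatically. Registered items then show up in `parameters()` and `state_dict()` and are moved by `.to(device)`. A plain tensor attribute is not registered. The toy module `Scaler` below is only an illustration, not part of the tutorial's models:

```python
import torch
import torch.nn as nn

class Scaler(nn.Module):
    """Toy module showing what nn.Module registers automatically."""

    def __init__(self, n):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n))      # registered parameter (trainable)
        self.proj = nn.Linear(n, n)                    # submodule: its parameters are registered too
        self.register_buffer("steps", torch.zeros(1))  # buffer: saved in state_dict, not trained
        self.plain = torch.zeros(n)                    # plain tensor attribute: NOT registered

    def forward(self, x):
        self.steps += 1
        return self.proj(x * self.weight)

m = Scaler(3)
param_names = [name for name, _ in m.named_parameters()]
state_keys = list(m.state_dict().keys())
print(param_names)  # own parameters first, then the submodule's
print(state_keys)   # parameters and buffers; 'plain' is absent

out = m(torch.randn(2, 3))
print(out.shape, m.steps.item())
```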

In [None]:
# 4.1 Basic neural network models
print("=== Building Custom Neural Network Models ===")

# 4.1.1 A simple fully connected network
class SimpleNN(nn.Module):
    """A simple fully connected (MLP) network."""

    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)  # raw logits, no softmax (nn.CrossEntropyLoss expects logits)
        return x

# Create and test the simple network
simple_model = SimpleNN(input_size=784, hidden_size=128, output_size=10).to(device)
print(f"1. Simple fully connected network:")
print(simple_model)

# Inspect the parameters
total_params = sum(p.numel() for p in simple_model.parameters())
trainable_params = sum(p.numel() for p in simple_model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Test the forward pass
test_input = torch.randn(32, 784, device=device)  # batch_size=32, 28*28=784 flattened pixels
output = simple_model(test_input)
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")

# 4.1.2 Building a network with nn.Sequential
print(f"\n2. Building a network with nn.Sequential:")

sequential_model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10)
).to(device)

print(sequential_model)

# 4.1.3 Convolutional neural network (CNN)
class ConvNet(nn.Module):
    """A convolutional neural network for image classification."""

    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()

        # Convolutional layers (a 3x3 kernel with padding=1 preserves height and width)
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

        # Pooling layer (halves height and width, rounding down)
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layers
        self.fc1 = nn.Linear(128 * 3 * 3, 512)  # 28->14->7->3 after three pooling layers
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

        # Activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.batch_norm3 = nn.BatchNorm2d(128)

    def forward(self, x):
        # First conv block
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = self.relu(x)
        x = self.pool(x)  # 28x28 -> 14x14

        # Second conv block
        x = self.conv2(x)
        x = self.batch_norm2(x)
        x = self.relu(x)
        x = self.pool(x)  # 14x14 -> 7x7

        # Third conv block
        x = self.conv3(x)
        x = self.batch_norm3(x)
        x = self.relu(x)
        x = self.pool(x)  # 7x7 -> 3x3 (floor(7 / 2) = 3)

        # Flatten to (batch, 128 * 3 * 3)
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)

        return x

# Create the CNN
cnn_model = ConvNet(num_classes=10).to(device)
print(f"\n3. Convolutional neural network:")
print(cnn_model)

# Test the CNN
test_image = torch.randn(32, 1, 28, 28, device=device)  # MNIST format (N, C, H, W)
cnn_output = cnn_model(test_image)
print(f"\nCNN test:")
print(f"Input shape: {test_image.shape}")
print(f"Output shape: {cnn_output.shape}")

# Count parameters
cnn_params = sum(p.numel() for p in cnn_model.parameters())
print(f"CNN parameter count: {cnn_params:,}")

# 4.1.4 Residual blocks and a ResNet-style network
class ResidualBlock(nn.Module):
    """Basic residual block: ReLU(F(x) + shortcut(x))."""

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut by default; if the stride or channel count changes,
        # project the input with a 1x1 convolution so the shapes match
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Add the residual (skip) connection
        out += self.shortcut(residual)
        out = F.relu(out)

        return out

class SimpleResNet(nn.Module):
    """A simplified ResNet."""

    def __init__(self, num_classes=10):
        super(SimpleResNet, self).__init__()

        self.conv1 = nn.Conv2d(1, 16, 3, 1, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)

        # Residual stages (each stride-2 stage halves the spatial size: 28 -> 14 -> 7)
        self.layer1 = self._make_layer(16, 16, 2, stride=1)
        self.layer2 = self._make_layer(16, 32, 2, stride=2)
        self.layer3 = self._make_layer(32, 64, 2, stride=2)

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        # Only the first block of a stage changes the stride or channel count
        layers = []
        layers.append(ResidualBlock(in_channels, out_channels, stride))
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels, 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Create the ResNet
resnet_model = SimpleResNet(num_classes=10).to(device)
print(f"\n4. Simplified ResNet:")
print(resnet_model)

# Test the ResNet
resnet_output = resnet_model(test_image)
print(f"\nResNet test:")
print(f"Input shape: {test_image.shape}")
print(f"Output shape: {resnet_output.shape}")

resnet_params = sum(p.numel() for p in resnet_model.parameters())
print(f"ResNet parameter count: {resnet_params:,}")

# 4.1.5 Custom activation functions
class Swish(nn.Module):
    """Swish activation: x * sigmoid(x) (PyTorch also provides it as nn.SiLU)."""

    def forward(self, x):
        return x * torch.sigmoid(x)

class GELU(nn.Module):
    """Tanh approximation of GELU (same formula as nn.GELU(approximate='tanh'))."""

    def forward(self, x):
        # (2 / np.pi) ** 0.5 is a plain Python float, so it works for CPU and GPU tensors alike
        return 0.5 * x * (1 + torch.tanh((2 / np.pi) ** 0.5 * (x + 0.044715 * torch.pow(x, 3))))

# Test the custom activations
print(f"\n5. Testing custom activation functions:")
x = torch.linspace(-3, 3, 100, device=device)

relu = nn.ReLU()
swish = Swish().to(device)
gelu = GELU().to(device)

with torch.no_grad():
    y_relu = relu(x)
    y_swish = swish(x)
    y_gelu = gelu(x)

# Plot the activations (.cpu() is a no-op for CPU tensors)
x_np = x.cpu().numpy()
y_relu_np = y_relu.cpu().numpy()
y_swish_np = y_swish.cpu().numpy()
y_gelu_np = y_gelu.cpu().numpy()

plt.figure(figsize=(12, 6))
plt.plot(x_np, y_relu_np, label='ReLU', linewidth=2)
plt.plot(x_np, y_swish_np, label='Swish', linewidth=2)
plt.plot(x_np, y_gelu_np, label='GELU', linewidth=2)
plt.plot(x_np, np.tanh(x_np), label='Tanh', linewidth=2, linestyle='--')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.title('Comparison of activation functions')
plt.legend()
plt.grid(True)
plt.show()

# 4.1.6 Model summary utility
def model_summary(model, input_size):
    """Print each leaf layer's output shape and parameter count, plus totals."""
    summary = {}
    hooks = []

    def register_hook(module):
        def hook(module, input, output):
            class_name = module.__class__.__name__
            m_key = f"{class_name}-{len(summary) + 1}"
            summary[m_key] = {
                "output_shape": list(output.size()),
                # Only this module's own parameters (recurse=False avoids double counting)
                "nb_params": sum(p.numel() for p in module.parameters(recurse=False)),
            }

        # Hook leaf modules only; containers such as Sequential are skipped
        if len(list(module.children())) == 0 and not isinstance(module, nn.Sequential):
            hooks.append(module.register_forward_hook(hook))

    # Dummy input on the same device as the model
    model_device = next(model.parameters()).device
    if isinstance(input_size, tuple):
        x = torch.rand(1, *input_size, device=model_device)
    else:
        x = torch.rand(input_size, device=model_device)

    # eval() + no_grad(): the dummy pass must not update BatchNorm running statistics
    was_training = model.training
    model.eval()
    model.apply(register_hook)
    with torch.no_grad():
        model(x)
    model.train(was_training)

    # Remove the hooks
    for h in hooks:
        h.remove()

    print("-" * 60)
    print(f"{'Layer (type)':>25} {'Output Shape':>20} {'Param #':>12}")
    print("=" * 60)
    for layer, info in summary.items():
        print(f"{layer:>25} {str(info['output_shape']):>20} {info['nb_params']:>12,}")

    # Totals come from model.parameters(), so layers reused in forward() are counted once
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("=" * 60)
    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,}")
    print(f"Non-trainable params: {total_params - trainable_params:,}")
    print("-" * 60)

print(f"\n6. Model summaries:")
print(f"\nCNN summary:")
model_summary(cnn_model, (1, 28, 28))

print(f"\nResNet summary:")
model_summary(resnet_model, (1, 28, 28))

# 4.1.7 Weight initialization
def init_weights(m):
    """Custom weight initialization; model.apply() calls this on every submodule."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

print(f"\n7. Weight initialization:")

# Create a fresh model for the demonstration
test_model = ConvNet().to(device)

# Weights before custom initialization (PyTorch's default init)
first_conv_weight = test_model.conv1.weight.detach().clone()
print(f"First conv layer weights before init:")
print(f"  Mean: {first_conv_weight.mean().item():.6f}")
print(f"  Std: {first_conv_weight.std().item():.6f}")

# Apply the custom initialization
test_model.apply(init_weights)

# Weights after initialization
after_conv_weight = test_model.conv1.weight.detach()
print(f"First conv layer weights after init:")
print(f"  Mean: {after_conv_weight.mean().item():.6f}")
print(f"  Std: {after_conv_weight.std().item():.6f}")

print(f"\n✓ Model building complete!")

# 4.1.8 Model complexity comparison
print(f"\n8. Model complexity comparison:")
models_comparison = {
    'Simple NN': {'model': simple_model, 'params': sum(p.numel() for p in simple_model.parameters())},
    'Sequential': {'model': sequential_model, 'params': sum(p.numel() for p in sequential_model.parameters())},
    'CNN': {'model': cnn_model, 'params': sum(p.numel() for p in cnn_model.parameters())},
    'ResNet': {'model': resnet_model, 'params': sum(p.numel() for p in resnet_model.parameters())}
}

print(f"{'Model':>12} {'Parameters':>15} {'Relative size':>14}")
print("-" * 45)

min_params = min(info['params'] for info in models_comparison.values())

for name, info in models_comparison.items():
    params = info['params']
    relative_complexity = params / min_params  # relative to the smallest model
    print(f"{name:>12} {params:>15,} {relative_complexity:>13.1f}x")

# Plot parameter counts
model_names = list(models_comparison.keys())
param_counts = [models_comparison[name]['params'] for name in model_names]

plt.figure(figsize=(10, 6))
bars = plt.bar(model_names, param_counts, color=['skyblue', 'lightgreen', 'lightcoral', 'gold'])
plt.yscale('log')
plt.ylabel('Number of parameters (log scale)')
plt.title('Parameter count by model')
plt.xticks(rotation=45)

# Value labels above each bar
for bar, count in zip(bars, param_counts):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() * 1.1,
             f'{count:,}', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

print(f"\nModel design takeaways:")
print(f"✓ Choose a model size that matches the complexity of the task")
print(f"✓ Use batch normalization to stabilize training")
print(f"✓ Use Dropout where appropriate to reduce overfitting")
print(f"✓ Residual connections make deeper networks easier to train")
print(f"✓ Sensible weight initialization matters")
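The `128 * 3 * 3` input size of `ConvNet.fc1` follows from the standard output-size rule for convolution and pooling layers (dilation 1, floor rounding): out = floor((in + 2 * padding - kernel) / stride) + 1. The sketch below traces a 28x28 input through the three conv/pool blocks and cross-checks the result against real layers. The helper `conv_out` is ours, not a PyTorch API.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer (dilation=1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
trace = [size]
for _ in range(3):
    size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv with padding=1: unchanged
    size = conv_out(size, kernel=2, stride=2)             # 2x2 max-pool, stride 2: halved, rounded down
    trace.append(size)
print(trace)

# Cross-check with actual layers (channel counts as in ConvNet)
features = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.MaxPool2d(2, 2),
    nn.Conv2d(32, 64, 3, padding=1), nn.MaxPool2d(2, 2),
    nn.Conv2d(64, 128, 3, padding=1), nn.MaxPool2d(2, 2),
)
with torch.no_grad():
    out = features(torch.zeros(1, 1, 28, 28))
print(out.shape)  # flattening gives 128 * 3 * 3 = 1152 features per image
```

If you would rather not compute this by hand, running a dummy tensor through the feature extractor, as in the cross-check, gives the same number.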

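## 5. Custom Datasets and Data Loading

Data handling is a key part of any deep learning project. PyTorch's data pipeline is flexible: a `Dataset` returns individual samples, and a `DataLoader` batches and shuffles them, optionally loading them in parallel worker processes. In this section we create a custom dataset, apply data augmentation, and configure efficient data loading.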
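The minimal contract for a map-style `Dataset` is `__len__` plus `__getitem__`, which returns one sample. The `DataLoader` then handles batching through its default collate function. A sketch with a toy dataset (`SquaresDataset` is only an illustration):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Toy map-style dataset: sample i is (i, i**2)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Samplers use this to know which indices exist
        return self.n

    def __getitem__(self, idx):
        # Return ONE sample; batching is the DataLoader's job
        x = torch.tensor([float(idx)])
        return x, x ** 2

loader = DataLoader(SquaresDataset(10), batch_size=4, shuffle=False)
batches = list(loader)
for xb, yb in batches:
    # Default collate stacks samples along a new batch dimension; the last batch is smaller
    print(xb.shape, yb.shape)
```

With 10 samples and `batch_size=4`, the loader yields batches of 4, 4, and 2 samples. Pass `drop_last=True` to discard the incomplete final batch.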

In [None]:
# 5.1 Loading and preprocessing MNIST
print("=== Custom Datasets and Data Loading ===")

# 5.1.1 Loading the standard MNIST dataset
print("1. Loading the standard MNIST dataset:")

# Define the transforms
basic_transform = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,))  # mean and std of the MNIST training set
])

# Load MNIST (downloaded to ./data on the first run)
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=basic_transform
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=basic_transform
)

print(f"Training set size: {len(train_dataset)}")
print(f"Test set size: {len(test_dataset)}")
print(f"Image shape: {train_dataset[0][0].shape}")
# Read labels from .targets; iterating the dataset would transform all 60,000 images
print(f"Label range: {train_dataset.targets.min().item()} - {train_dataset.targets.max().item()}")

# 5.1.2 A custom Dataset class
class CustomMNIST(Dataset):
    """Custom MNIST-style dataset built from raw image and label arrays."""

    def __init__(self, data, targets, transform=None, target_transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image, target = self.data[idx], int(self.targets[idx])

        # Convert raw uint8 tensors/arrays (H x W) to PIL images so torchvision transforms
        # apply as usual; PIL images are passed through unchanged
        if isinstance(image, (torch.Tensor, np.ndarray)):
            image = transforms.ToPILImage()(image)

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target = self.target_transform(target)

        return image, target

# Build a custom dataset from the raw MNIST tensors
custom_dataset = CustomMNIST(
    data=train_dataset.data[:1000],  # first 1,000 samples
    targets=train_dataset.targets[:1000],
    transform=basic_transform
)

print(f"\nCustom dataset size: {len(custom_dataset)}")

# 5.1.3 Data augmentation
print(f"\n2. Data augmentation:")

# Mild augmentation
data_augmentation = transforms.Compose([
    transforms.RandomRotation(10),  # random rotation within ±10 degrees
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # random shift of up to 10% of width/height
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# More aggressive augmentation
aggressive_augmentation = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=5, translate=(0.15, 0.15), scale=(0.85, 1.15)),
    transforms.RandomPerspective(distortion_scale=0.2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    transforms.RandomErasing(p=0.1, scale=(0.02, 0.1))  # random erasing works on tensors, so it comes after ToTensor
])

# Augmented dataset (same images, different transform)
augmented_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=data_augmentation
)

# Visualize the effect of augmentation
def visualize_augmentation(dataset, original_dataset, num_samples=8):
    """Show original images (top row) and augmented versions (bottom row)."""
    fig, axes = plt.subplots(2, num_samples, figsize=(16, 6))

    for i in range(num_samples):
        # Original image
        orig_img, label = original_dataset[i]
        if orig_img.dim() == 3 and orig_img.shape[0] == 1:
            orig_img = orig_img.squeeze(0)

        axes[0, i].imshow(orig_img, cmap='gray')
        axes[0, i].set_title(f'Original ({label})')
        axes[0, i].axis('off')

        # Augmented image (random transforms are re-sampled on every access)
        aug_img, _ = dataset[i]
        if aug_img.dim() == 3 and aug_img.shape[0] == 1:
            aug_img = aug_img.squeeze(0)

        axes[1, i].imshow(aug_img, cmap='gray')
        axes[1, i].set_title(f'Augmented ({label})')
        axes[1, i].axis('off')

    plt.suptitle('Original vs. augmented images', fontsize=16)
    plt.tight_layout()
    plt.show()

print("Visualizing data augmentation:")
visualize_augmentation(augmented_dataset, train_dataset)

# 5.1.4 数据加载器配置
print(f"\n3. 数据加载器配置:")

# 不同的数据加载器配置
dataloaders = {}

# 基础配置
dataloaders['basic'] = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2
)

# 高性能配置
dataloaders['optimized'] = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True
)

# 测试不同配置的加载速度
def benchmark_dataloader(dataloader, name, num_batches=50):
    """测试数据加载器性能（计时包含首次启动worker的开销）"""
    num_samples = 0
    start_time = time.time()
    
    for i, (data, target) in enumerate(dataloader):
        if i >= num_batches:
            break
        # 模拟数据传输到GPU
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        num_samples += data.size(0)  # 按实际样本数统计，最后一个批次可能不满
    
    # non_blocking 拷贝是异步的，停止计时前需等待GPU完成
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    
    total_time = time.time() - start_time
    samples_per_second = num_samples / total_time
    
    print(f"{name:12}: {total_time:.2f}s, {samples_per_second:.0f} samples/s")

print("数据加载器性能测试:")
for name, dataloader in dataloaders.items():
    benchmark_dataloader(dataloader, name)

# 5.1.5 数据集分割
print(f"\n4. 数据集分割:")

# 将训练集分割为训练集和验证集
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size

train_subset, val_subset = random_split(
    train_dataset, 
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"原始训练集: {len(train_dataset)}")
print(f"分割后训练集: {len(train_subset)}")
print(f"验证集: {len(val_subset)}")

# 创建对应的数据加载器
train_loader = DataLoader(train_subset, batch_size=64, shuffle=True, num_workers=2)
val_loader = DataLoader(val_subset, batch_size=64, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"训练批次数: {len(train_loader)}")
print(f"验证批次数: {len(val_loader)}")
print(f"测试批次数: {len(test_loader)}")

# 5.1.6 数据分析和可视化
print(f"\n5. 数据分析:")

# 分析类别分布
def analyze_dataset(dataset, name):
    """分析数据集的类别分布"""
    if hasattr(dataset, 'targets'):
        targets = dataset.targets
    else:
        # 对于 Subset，用其索引直接从底层数据集的 targets 中取出对应标签
        targets = torch.as_tensor(dataset.dataset.targets)[dataset.indices]
    
    unique, counts = torch.unique(targets, return_counts=True)
    
    print(f"\n{name} 类别分布:")
    for digit, count in zip(unique.tolist(), counts.tolist()):
        percentage = count / len(targets) * 100
        print(f"  数字 {digit}: {count:5d} 样本 ({percentage:5.1f}%)")
    
    return targets

# 分析各个数据集
train_targets = analyze_dataset(train_subset, "训练集")
val_targets = analyze_dataset(val_subset, "验证集")
test_targets = analyze_dataset(test_dataset, "测试集")

# 可视化类别分布
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
datasets_info = [
    ("训练集", train_targets),
    ("验证集", val_targets), 
    ("测试集", test_targets)
]

for idx, (name, targets) in enumerate(datasets_info):
    unique, counts = torch.unique(targets, return_counts=True)
    
    axes[idx].bar(unique.numpy(), counts.numpy())
    axes[idx].set_title(f'{name}类别分布')
    axes[idx].set_xlabel('数字')
    axes[idx].set_ylabel('样本数量')
    axes[idx].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# 5.1.7 批次数据可视化
def visualize_batch(dataloader, num_samples=16):
    """可视化一个批次的数据"""
    data_iter = iter(dataloader)
    images, labels = next(data_iter)
    
    # 选择要显示的样本数
    num_samples = min(num_samples, len(images))
    
    # 计算网格大小
    grid_size = int(np.ceil(np.sqrt(num_samples)))
    
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12))
    axes = axes.flatten()
    
    for i in range(num_samples):
        img = images[i]
        if img.dim() == 3 and img.shape[0] == 1:
            img = img.squeeze(0)
        
        # 反归一化以便显示
        img = img * 0.3081 + 0.1307
        img = torch.clamp(img, 0, 1)
        
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f'Label: {labels[i].item()}')
        axes[i].axis('off')
    
    # 隐藏多余的子图
    for i in range(num_samples, len(axes)):
        axes[i].axis('off')
    
    plt.suptitle('训练批次数据样本', fontsize=16)
    plt.tight_layout()
    plt.show()

print(f"\n6. 批次数据可视化:")
visualize_batch(train_loader, 16)

# 5.1.8 数据统计分析
print(f"\n7. 数据统计分析:")

def compute_dataset_stats(dataloader):
    """计算数据集逐通道的全局均值和标准差（基于所有像素，而非逐图像取平均）"""
    pixel_sum = 0.0
    pixel_sq_sum = 0.0
    num_pixels = 0
    
    for data, _ in dataloader:
        # data 形状为 [B, C, H, W]：在批次和空间维度上求和，保留通道维度
        pixel_sum += data.sum(dim=(0, 2, 3))
        pixel_sq_sum += (data ** 2).sum(dim=(0, 2, 3))
        num_pixels += data.size(0) * data.size(2) * data.size(3)
    
    mean = pixel_sum / num_pixels
    # Var = E[x^2] - E[x]^2，clamp 防止浮点误差导致负数
    std = (pixel_sq_sum / num_pixels - mean ** 2).clamp(min=0).sqrt()
    
    return mean, std

# 计算训练集统计信息
# 注意：train_loader 中的数据已经过 Normalize（见上方 visualize_batch 的反归一化），
# 因此结果应接近 0 和 1；若要复现下面的预定义参数，应在只包含 ToTensor() 的数据集上计算
train_mean, train_std = compute_dataset_stats(train_loader)
print(f"训练集统计信息（归一化后）:")
print(f"  均值: {train_mean.item():.4f}")
print(f"  标准差: {train_std.item():.4f}")

# 预定义的MNIST归一化参数（基于缩放到 [0, 1] 的原始像素计算）
print(f"预定义MNIST归一化参数:")
print(f"  均值: 0.1307")
print(f"  标准差: 0.3081")

# 5.1.9 内存和性能优化
print(f"\n8. 内存和性能优化:")

# 内存映射数据集
class MemoryMappedDataset(Dataset):
    """内存映射数据集（简化示意）
    
    真正的内存映射实现会用 np.memmap 等方式按需从磁盘读取数据；
    这里为了演示直接复用内存中的 MNIST 张量，data_path 参数未被使用。
    """
    
    def __init__(self, data_path, transform=None):
        self.data = train_dataset.data
        self.targets = train_dataset.targets
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # 按索引逐个取样本，变换只在取用时执行
        image = self.data[idx]
        target = self.targets[idx]
        
        if self.transform:
            # 转换为PIL图像以应用transform
            image = transforms.ToPILImage()(image)
            image = self.transform(image)
        else:
            image = image.float() / 255.0
            image = image.unsqueeze(0)  # 添加通道维度
        
        return image, target

# 预取数据加载器
class PrefetchLoader:
    """数据预取加载器：在独立的CUDA流上提前把下一个批次拷贝到GPU"""
    
    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
    
    def __len__(self):
        return len(self.loader)
    
    def __iter__(self):
        loader_iter = iter(self.loader)
        self.preload(loader_iter)
        
        while self.next_input is not None:
            if self.stream is not None:
                # 当前流等待预取流完成拷贝
                torch.cuda.current_stream().wait_stream(self.stream)
                # 声明张量会在当前流上使用，避免其显存被提前回收复用
                self.next_input.record_stream(torch.cuda.current_stream())
                self.next_target.record_stream(torch.cuda.current_stream())
            inputs = self.next_input
            targets = self.next_target
            self.preload(loader_iter)
            yield inputs, targets
    
    def preload(self, loader_iter):
        try:
            self.next_input, self.next_target = next(loader_iter)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        
        if self.stream is not None:
            # non_blocking 拷贝需配合 DataLoader 的 pin_memory=True 才能真正异步
            with torch.cuda.stream(self.stream):
                self.next_input = self.next_input.cuda(non_blocking=True)
                self.next_target = self.next_target.cuda(non_blocking=True)

# 测试优化后的数据加载
optimized_dataset = MemoryMappedDataset('./data', basic_transform)
optimized_loader = DataLoader(
    optimized_dataset, 
    batch_size=128, 
    shuffle=True, 
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True
)

print("数据加载优化技术:")
print("✓ 内存映射（如 np.memmap）按需读取数据，减少内存占用")
print("✓ pin_memory加速GPU传输")
print("✓ persistent_workers减少进程创建开销")
print("✓ non_blocking传输提高并行度")

print(f"\n✓ 数据集和数据加载学习完成！")

# 保存数据加载器配置信息
data_config = {
    'train_size': len(train_subset),
    'val_size': len(val_subset),
    'test_size': len(test_dataset),
    'batch_size': 64,
    'num_workers': 2,
    'pin_memory': torch.cuda.is_available(),
    'normalize_mean': 0.1307,
    'normalize_std': 0.3081
}

print(f"\n数据配置信息:")
for key, value in data_config.items():
    print(f"  {key}: {value}")

# 以上在Notebook顶层定义的 train_loader、val_loader、test_loader 和 data_config
# 本身就是全局变量，后续单元格可以直接使用
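计算归一化参数时有一个常见误区：把“每张图像的标准差”取平均，并不等于整个数据集所有像素的标准差。下面用纯 Python 写一个极简示意（不依赖 PyTorch，`global_mean_std` 和 `mean_of_per_image_std` 都是演示用的假设函数名），通过“像素和 + 像素平方和”一次遍历得到全局均值和标准差，并与逐图像平均的做法对比。

```python
import math

def global_mean_std(images):
    """一次遍历计算所有像素的全局均值与（总体）标准差"""
    total, total_sq, count = 0.0, 0.0, 0
    for img in images:  # img: 扁平化后的像素列表
        total += sum(img)
        total_sq += sum(p * p for p in img)
        count += len(img)
    mean = total / count
    var = total_sq / count - mean * mean  # Var = E[x^2] - E[x]^2
    return mean, math.sqrt(max(var, 0.0))  # max 防止浮点误差导致负数

def mean_of_per_image_std(images):
    """先算每张图像的标准差再取平均（一般不等于全局标准差）"""
    stds = []
    for img in images:
        m = sum(img) / len(img)
        stds.append(math.sqrt(sum((p - m) ** 2 for p in img) / len(img)))
    return sum(stds) / len(stds)

# 两张“图像”：一张全黑、一张全白
images = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
print(global_mean_std(images))        # (0.5, 0.5)
print(mean_of_per_image_std(images))  # 0.0
```

这个例子中每张图像内部没有任何变化，但图像之间差异很大，所以两种算法的结果完全不同。另外，如果 DataLoader 已经应用了 `Normalize`，统计出的值应接近 0 和 1；要复现 0.1307 / 0.3081，需要在只做 `ToTensor()` 的数据上计算。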

## 6. 完整训练流程

实现一个完整的深度学习训练流程，包括训练、验证、测试的循环，损失函数选择，优化器配置，以及学习率调度。
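下文配置中用到的 `StepLR` 和 `CosineAnnealingLR` 都有简单的闭式公式（按 PyTorch 文档：StepLR 每 `step_size` 个 epoch 把学习率乘以 `gamma`；CosineAnnealingLR 在 `T_max` 个 epoch 内按余弦曲线从初始学习率降到 `eta_min`）。下面的纯 Python 函数只是按公式手算的示意（函数名为假设），便于在训练前预估学习率的变化，它并不调用 PyTorch 的调度器本身。

```python
import math

def step_lr(base_lr, epoch, step_size, gamma):
    """StepLR 的闭式形式：base_lr * gamma ** (epoch // step_size)"""
    return base_lr * gamma ** (epoch // step_size)

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """CosineAnnealingLR 在单个周期内的闭式形式"""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

# 取与 TrainingConfig 默认值相同的参数：lr=0.001, step_size=7, gamma=0.1, epochs=20
for epoch in [0, 6, 7, 14, 19]:
    print(f"epoch {epoch:2d}: StepLR={step_lr(0.001, epoch, 7, 0.1):.2e}, "
          f"Cosine={cosine_annealing_lr(0.001, epoch, 20):.2e}")
```

可以看到 StepLR 呈阶梯式下降，而余弦退火是平滑下降，在第 `T_max` 个 epoch 降到 `eta_min`。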

In [None]:
# 6.1 训练配置和初始化
print("=== 完整训练流程 ===")

# 6.1.1 训练配置
class TrainingConfig:
    """训练配置类"""
    
    def __init__(self):
        # 模型配置
        self.model_name = "CNN_MNIST"
        self.num_classes = 10
        
        # 训练配置
        self.epochs = 20
        self.batch_size = 64
        self.learning_rate = 0.001
        self.weight_decay = 1e-4
        
        # 优化器配置
        self.optimizer_type = "Adam"  # Adam, SGD, AdamW
        self.momentum = 0.9  # for SGD
        
        # 学习率调度
        self.scheduler_type = "StepLR"  # StepLR, ReduceLROnPlateau, CosineAnnealingLR
        self.step_size = 7
        self.gamma = 0.1
        self.patience = 5  # for ReduceLROnPlateau
        
        # 早停配置
        self.early_stopping = True
        self.early_stopping_patience = 10
        self.min_delta = 0.001
        
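        # 设备和路径
        self.device = device
        self.save_dir = Path("checkpoints")
        self.log_dir = Path("logs")
        # 提前创建目录，否则 torch.save / history.save 会因目录不存在而报错
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir.mkdir(parents=True, exist_ok=True)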
        
        # 混合精度训练
        self.use_amp = torch.cuda.is_available()
        
        # 日志和保存
        self.save_best_only = True
        self.save_frequency = 5  # 每5个epoch保存一次
        self.log_frequency = 100  # 每100个batch记录一次

config = TrainingConfig()

print(f"训练配置:")
print(f"  模型: {config.model_name}")
print(f"  训练轮数: {config.epochs}")
print(f"  批次大小: {config.batch_size}")
print(f"  学习率: {config.learning_rate}")
print(f"  优化器: {config.optimizer_type}")
print(f"  设备: {config.device}")
print(f"  混合精度: {config.use_amp}")

# 6.1.2 模型初始化
# 使用之前定义的CNN模型
model = ConvNet(num_classes=config.num_classes).to(config.device)

# 应用权重初始化
model.apply(init_weights)

print(f"\n模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

# 6.1.3 损失函数
criterion = nn.CrossEntropyLoss()

# 6.1.4 优化器配置
def get_optimizer(model, config):
    """根据配置获取优化器"""
    if config.optimizer_type == "Adam":
        return optim.Adam(
            model.parameters(), 
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
    elif config.optimizer_type == "SGD":
        return optim.SGD(
            model.parameters(),
            lr=config.learning_rate,
            momentum=config.momentum,
            weight_decay=config.weight_decay
        )
    elif config.optimizer_type == "AdamW":
        return optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )
    else:
        raise ValueError(f"Unknown optimizer: {config.optimizer_type}")

optimizer = get_optimizer(model, config)

# 6.1.5 学习率调度器
def get_scheduler(optimizer, config):
    """根据配置获取学习率调度器"""
    if config.scheduler_type == "StepLR":
        return StepLR(optimizer, step_size=config.step_size, gamma=config.gamma)
    elif config.scheduler_type == "ReduceLROnPlateau":
        # 注：verbose 参数在较新的 PyTorch 版本中已弃用，这里不再传入
        return ReduceLROnPlateau(
            optimizer, mode='min', patience=config.patience,
            factor=config.gamma
        )
    elif config.scheduler_type == "CosineAnnealingLR":
        return CosineAnnealingLR(optimizer, T_max=config.epochs)
    else:
        return None

scheduler = get_scheduler(optimizer, config)

# 6.1.6 早停机制
class EarlyStopping:
    """早停机制"""
    
    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False
    
    def save_checkpoint(self, model):
        self.best_weights = copy.deepcopy(model.state_dict())

early_stopping = EarlyStopping(
    patience=config.early_stopping_patience,
    min_delta=config.min_delta
) if config.early_stopping else None

# 6.1.7 混合精度训练设置
scaler = torch.cuda.amp.GradScaler() if config.use_amp else None

print(f"\n训练组件初始化完成:")
print(f"✓ 模型: {type(model).__name__}")
print(f"✓ 损失函数: {type(criterion).__name__}")
print(f"✓ 优化器: {type(optimizer).__name__}")
print(f"✓ 调度器: {type(scheduler).__name__ if scheduler else None}")
print(f"✓ 早停: {'启用' if early_stopping else '禁用'}")
print(f"✓ 混合精度: {'启用' if scaler else '禁用'}")

# 6.2 训练和验证函数
def train_epoch(model, train_loader, criterion, optimizer, scaler, device, epoch, config):
    """训练一个epoch"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # 创建进度条
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.epochs} [Train]')
    
    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        
        optimizer.zero_grad()
        
        if scaler is not None:
            # 混合精度训练
            with torch.cuda.amp.autocast():
                output = model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # 常规训练
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        
        # 统计
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        # 更新进度条
        if batch_idx % config.log_frequency == 0:
            current_acc = 100. * correct / total
            current_loss = running_loss / (batch_idx + 1)
            pbar.set_postfix({
                'Loss': f'{current_loss:.4f}',
                'Acc': f'{current_acc:.2f}%'
            })
    
    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

def validate_epoch(model, val_loader, criterion, device, epoch, config):
    """验证一个epoch"""
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{config.epochs} [Val]')
        
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
            
            # 验证阶段没有反向传播，不需要 GradScaler，只按配置决定是否启用 autocast
            if config.use_amp:
                with torch.cuda.amp.autocast():
                    output = model(data)
                    loss = criterion(output, target)
            else:
                output = model(data)
                loss = criterion(output, target)
            
            val_loss += loss.item()
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            
            # 更新进度条
            current_acc = 100. * correct / total
            current_loss = val_loss / (batch_idx + 1)
            pbar.set_postfix({
                'Loss': f'{current_loss:.4f}',
                'Acc': f'{current_acc:.2f}%'
            })
    
    epoch_loss = val_loss / len(val_loader)
    epoch_acc = 100. * correct / total
    
    return epoch_loss, epoch_acc

# 6.3 训练历史记录
class TrainingHistory:
    """训练历史记录"""
    
    def __init__(self):
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []
        self.learning_rates = []
        self.epochs = []
        
    def update(self, epoch, train_loss, train_acc, val_loss, val_acc, lr):
        self.epochs.append(epoch)
        self.train_losses.append(train_loss)
        self.train_accuracies.append(train_acc)
        self.val_losses.append(val_loss)
        self.val_accuracies.append(val_acc)
        self.learning_rates.append(lr)
    
    def save(self, filepath):
        """保存训练历史"""
        history_dict = {
            'epochs': self.epochs,
            'train_losses': self.train_losses,
            'train_accuracies': self.train_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies,
            'learning_rates': self.learning_rates
        }
        
        with open(filepath, 'w') as f:
            json.dump(history_dict, f, indent=2)
    
    def load(self, filepath):
        """加载训练历史"""
        with open(filepath, 'r') as f:
            history_dict = json.load(f)
        
        self.epochs = history_dict['epochs']
        self.train_losses = history_dict['train_losses']
        self.train_accuracies = history_dict['train_accuracies']
        self.val_losses = history_dict['val_losses']
        self.val_accuracies = history_dict['val_accuracies']
        self.learning_rates = history_dict['learning_rates']
    
    def plot(self):
        """绘制训练历史"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        
        # 损失曲线
        axes[0, 0].plot(self.epochs, self.train_losses, 'b-', label='Train Loss')
        axes[0, 0].plot(self.epochs, self.val_losses, 'r-', label='Val Loss')
        axes[0, 0].set_title('Training and Validation Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()
        axes[0, 0].grid(True)
        
        # 准确率曲线
        axes[0, 1].plot(self.epochs, self.train_accuracies, 'b-', label='Train Acc')
        axes[0, 1].plot(self.epochs, self.val_accuracies, 'r-', label='Val Acc')
        axes[0, 1].set_title('Training and Validation Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy (%)')
        axes[0, 1].legend()
        axes[0, 1].grid(True)
        
        # 学习率曲线
        axes[1, 0].plot(self.epochs, self.learning_rates, 'g-')
        axes[1, 0].set_title('Learning Rate Schedule')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Learning Rate')
        axes[1, 0].set_yscale('log')
        axes[1, 0].grid(True)
        
        # 验证损失vs准确率散点图
        axes[1, 1].scatter(self.val_losses, self.val_accuracies, c=self.epochs, cmap='viridis')
        axes[1, 1].set_title('Validation Loss vs Accuracy')
        axes[1, 1].set_xlabel('Validation Loss')
        axes[1, 1].set_ylabel('Validation Accuracy (%)')
        axes[1, 1].grid(True)
        
        plt.tight_layout()
        plt.show()

history = TrainingHistory()

# 6.4 主训练循环
print(f"\n开始训练...")
print("=" * 50)

best_val_acc = 0.0
start_time = time.time()

try:
    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        
        # 训练阶段
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, scaler, device, epoch, config
        )
        
        # 验证阶段
        val_loss, val_acc = validate_epoch(
            model, val_loader, criterion, device, epoch, config
        )
        
        # 记录本epoch实际使用的学习率（需在调度器更新之前读取）
        current_lr = optimizer.param_groups[0]['lr']
        
        # 学习率调度
        if scheduler is not None:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
        
        # 更新训练历史
        history.update(epoch, train_loss, train_acc, val_loss, val_acc, current_lr)
        
        # 计算epoch时间
        epoch_time = time.time() - epoch_start_time
        
        # 打印epoch结果
        print(f"Epoch {epoch+1}/{config.epochs}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  Learning Rate: {current_lr:.6f}")
        print(f"  Time: {epoch_time:.2f}s")
        
        # 保存最佳模型
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'val_acc': val_acc,
                'val_loss': val_loss,
                'config': config.__dict__
            }, config.save_dir / 'best_model.pth')
            print(f"  ✓ 新的最佳模型已保存 (Val Acc: {val_acc:.2f}%)")
        
        # 定期保存检查点
        if (epoch + 1) % config.save_frequency == 0:
            checkpoint_path = config.save_dir / f'checkpoint_epoch_{epoch+1}.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                'scaler_state_dict': scaler.state_dict() if scaler else None,
                'val_acc': val_acc,
                'val_loss': val_loss,
                'best_val_acc': best_val_acc,  # 恢复训练时需要的历史最佳值
                'history': history.__dict__,
                'config': config.__dict__
            }, checkpoint_path)
            print(f"  ✓ 检查点已保存: {checkpoint_path.name}")
        
        # 早停检查
        if early_stopping:
            if early_stopping(val_loss, model):
                print(f"  Early stopping triggered after {epoch+1} epochs")
                break
        
        print("-" * 50)

except KeyboardInterrupt:
    print("\n训练被用户中断")

# 训练完成统计
total_time = time.time() - start_time
print(f"\n训练完成!")
print(f"总训练时间: {total_time/60:.1f} 分钟")
print(f"最佳验证准确率: {best_val_acc:.2f}%")

# 保存训练历史
history.save(config.log_dir / 'training_history.json')
print(f"训练历史已保存到: {config.log_dir / 'training_history.json'}")

# 可视化训练历史
print(f"\n绘制训练历史曲线:")
history.plot()

print(f"\n✓ 完整训练流程演示完成！")
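`EarlyStopping` 的计数逻辑可以脱离模型单独验证。下面是一个纯 Python 的示意函数（`early_stop_epoch` 为假设的名称），沿用上面类中相同的判定规则：只有当验证损失低于 `best_loss - min_delta` 时才算改进，连续 `patience` 个 epoch 没有改进就触发早停。

```python
def early_stop_epoch(val_losses, patience, min_delta=0.0):
    """返回触发早停的 epoch 下标（从 0 开始）；若从未触发则返回 None"""
    best_loss = None
    counter = 0
    for epoch, loss in enumerate(val_losses):
        if best_loss is None:
            best_loss = loss
        elif loss < best_loss - min_delta:
            best_loss = loss  # 有效改进：更新最优值并重置计数
            counter = 0
        else:
            counter += 1      # 没有有效改进：计数加一
        if counter >= patience:
            return epoch
    return None

losses = [1.00, 0.90, 0.95, 0.96, 0.85]
print(early_stop_epoch(losses, patience=2))                 # 3
print(early_stop_epoch(losses, patience=2, min_delta=0.2))  # 2
print(early_stop_epoch(losses, patience=5))                 # None
```

`min_delta` 越大，微小的下降越不会被视为改进，早停也就越早触发。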

## 7. 断点保存与恢复 (Checkpoint & Resume)

在深度学习训练中，断点保存与恢复功能至关重要，特别是对于长时间训练的模型。这可以让我们：
- 在训练意外中断后恢复训练
- 在新的实验中基于已保存的权重继续训练
- 保存和加载最佳模型
- 实现增量训练

### 7.1 断点保存策略
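如果进程恰好在写检查点文件的过程中被中断，磁盘上可能留下一个损坏的文件，而且它恰好覆盖了上一个可用的检查点。一个常见做法是“先写临时文件，再原子替换”：按 Python 文档，`os.replace` 成功时的重命名在 POSIX 系统上是原子操作（需位于同一文件系统内）。下面用 `json` 做一个纯标准库的示意（`atomic_save_json` 为假设的函数名）；换成 PyTorch 时思路相同，只需把写入那一步改为 `torch.save(state, tmp_path)`。

```python
import json
import os
import tempfile

def atomic_save_json(state, path):
    """先写入同目录下的临时文件，再用 os.replace 替换目标文件"""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
            f.flush()
            os.fsync(f.fileno())  # 确保数据真正写入磁盘
        os.replace(tmp_path, path)  # 同一文件系统内的重命名
    except BaseException:
        # 写入失败或被中断时删除临时文件，旧检查点保持不变
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "checkpoint_epoch_5.json")
    atomic_save_json({"epoch": 5}, ckpt)
    with open(ckpt) as f:
        print(json.load(f))  # {'epoch': 5}
```

这样即使中途出错，目标路径上要么是完整的旧文件，要么是完整的新文件。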

In [None]:
import json
import shutil
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace  # resume_training 还原配置时使用

import torch

class CheckpointManager:
    """
    断点管理器 - 统一管理模型的保存与恢复
    """
    def __init__(self, save_dir, max_checkpoints=5):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.max_checkpoints = max_checkpoints
        
    def save_checkpoint(self, state, epoch, is_best=False, suffix=""):
        """保存断点"""
        # 基本检查点
        checkpoint_name = f"checkpoint_epoch_{epoch}{suffix}.pth"
        checkpoint_path = self.save_dir / checkpoint_name
        
        # 添加时间戳
        state['save_time'] = datetime.now().isoformat()
        
        torch.save(state, checkpoint_path)
        print(f"✓ 检查点已保存: {checkpoint_path}")
        
        # 如果是最佳模型，额外保存一份
        if is_best:
            best_path = self.save_dir / "best_model.pth"
            shutil.copy2(checkpoint_path, best_path)
            print(f"✓ 最佳模型已保存: {best_path}")
        
        # 清理旧的检查点
        self._cleanup_checkpoints()
        
        return checkpoint_path
    
    def load_checkpoint(self, checkpoint_path, model, optimizer=None, scheduler=None):
        """加载断点"""
        checkpoint_path = Path(checkpoint_path)
        
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"检查点文件不存在: {checkpoint_path}")
        
        print(f"🔄 正在加载检查点: {checkpoint_path}")
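        # 检查点中包含 Path 等非张量对象；PyTorch 2.6 起 weights_only 默认为 True，
        # 因此这里显式设为 False（只应加载自己生成的可信文件）
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)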
        
        # 加载模型状态
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # 加载优化器状态
        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        # 加载调度器状态
        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            if checkpoint['scheduler_state_dict'] is not None:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        
        print(f"✓ 成功加载检查点 (Epoch {checkpoint.get('epoch', 'Unknown')})")
        return checkpoint
    
    def find_latest_checkpoint(self):
        """查找最新的检查点"""
        checkpoints = list(self.save_dir.glob("checkpoint_epoch_*.pth"))
        if not checkpoints:
            return None
        
        # 按修改时间排序
        latest = max(checkpoints, key=lambda x: x.stat().st_mtime)
        return latest
    
    def list_checkpoints(self):
        """列出所有检查点"""
        checkpoints = list(self.save_dir.glob("checkpoint_epoch_*.pth"))
        checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
        
        print(f"\n📁 检查点目录: {self.save_dir}")
        print("=" * 60)
        
        if not checkpoints:
            print("暂无检查点文件")
            return []
        
        for i, ckpt in enumerate(checkpoints):
            # 尝试读取检查点信息
            try:
                state = torch.load(ckpt, map_location='cpu', weights_only=False)
                epoch = state.get('epoch', 'Unknown')
                val_acc = state.get('val_acc', 'Unknown')
                save_time = state.get('save_time', 'Unknown')
                size = ckpt.stat().st_size / (1024 * 1024)  # MB
                
                print(f"{i+1:2d}. {ckpt.name}")
                print(f"     Epoch: {epoch}, Val Acc: {val_acc}, Size: {size:.1f}MB")
                print(f"     Time: {save_time}")
                print()
            except Exception as e:
                print(f"{i+1:2d}. {ckpt.name} (无法读取: {e})")
        
        return checkpoints
    
    def _cleanup_checkpoints(self):
        """清理旧的检查点，只保留最新的几个"""
        checkpoints = list(self.save_dir.glob("checkpoint_epoch_*.pth"))
        if len(checkpoints) <= self.max_checkpoints:
            return
        
        # 按修改时间排序，删除最旧的
        checkpoints.sort(key=lambda x: x.stat().st_mtime, reverse=True)
        for old_ckpt in checkpoints[self.max_checkpoints:]:
            old_ckpt.unlink()
            print(f"🗑️  已删除旧检查点: {old_ckpt.name}")

# 示例：创建检查点管理器
checkpoint_manager = CheckpointManager(save_dir="./checkpoints", max_checkpoints=3)

print("✓ 断点管理器创建完成")
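上面的 `_cleanup_checkpoints` 按文件修改时间排序。如果检查点被复制过或是从其他机器同步过来的，修改时间可能与训练顺序不一致。一个替代思路是从文件名中解析 epoch 编号再排序。下面是纯 Python 的示意（`select_checkpoints_to_delete` 为假设的函数名），只计算“应该删除哪些文件”，不实际删除。

```python
import re

_EPOCH_RE = re.compile(r"checkpoint_epoch_(\d+)")

def select_checkpoints_to_delete(filenames, max_keep):
    """按文件名中的 epoch 编号排序，返回超出保留数量的旧检查点"""
    parsed = []
    for name in filenames:
        match = _EPOCH_RE.match(name)
        if match:  # 忽略不符合命名规则的文件（如 best_model.pth）
            parsed.append((int(match.group(1)), name))
    parsed.sort(reverse=True)  # epoch 从大到小
    return [name for _, name in parsed[max_keep:]]

files = ["checkpoint_epoch_2.pth", "checkpoint_epoch_10.pth", "best_model.pth",
         "checkpoint_epoch_9.pth", "checkpoint_epoch_1.pth"]
print(select_checkpoints_to_delete(files, max_keep=2))
# ['checkpoint_epoch_2.pth', 'checkpoint_epoch_1.pth']
```

这里把 epoch 转成整数再比较，是为了避免按字符串排序时 `"checkpoint_epoch_10"` 排在 `"checkpoint_epoch_9"` 前面的问题。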

### 7.2 恢复训练功能
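要让“中断后恢复”与“不中断一次训练完”得到相同结果，除了模型、优化器和调度器的状态，随机数生成器的状态（数据打乱、Dropout、数据增强都依赖它）也需要一并保存；前面训练循环保存的检查点中并不包含这部分状态。下面用纯 Python 的 `random.Random` 做一个玩具示意（`run_training` 及其“训练”逻辑纯属假设，只用来模拟依赖随机数的训练过程）。在 PyTorch 中通常的对应做法是额外保存 `torch.get_rng_state()`（以及 `torch.cuda.get_rng_state_all()`），恢复时调用对应的 `set_rng_state` / `set_rng_state_all`。

```python
import random

def run_training(total_epochs, state=None, stop_after=None):
    """玩具“训练”：每个 epoch 用一个随机数更新参数。

    state: 从检查点恢复的状态（参数、已完成的 epoch、RNG 状态）
    stop_after: 模拟在完成该 epoch 后中断，并返回检查点
    """
    rng = random.Random(42)
    param, start_epoch = 0.0, 0
    if state is not None:
        param = state["param"]
        start_epoch = state["epoch"] + 1
        rng.setstate(state["rng_state"])  # 关键：恢复随机数状态
    for epoch in range(start_epoch, total_epochs):
        param += rng.random()  # 模拟依赖随机性的参数更新
        if stop_after is not None and epoch == stop_after:
            return {"param": param, "epoch": epoch, "rng_state": rng.getstate()}
    return {"param": param, "epoch": total_epochs - 1, "rng_state": rng.getstate()}

full = run_training(10)                 # 不中断
ckpt = run_training(10, stop_after=4)   # 完成第 5 个 epoch 后“中断”
resumed = run_training(10, state=ckpt)  # 从检查点恢复
print(full["param"] == resumed["param"])  # True

# 对照：如果没有保存 RNG 状态，恢复后会重新从种子开始抽样
bad = dict(ckpt, rng_state=random.Random(42).getstate())
print(run_training(10, state=bad)["param"] == full["param"])
```

第二个比较会重复使用前 5 个随机数，因此结果与不中断的训练不一致。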

In [None]:
def resume_training(checkpoint_path, model, optimizer, scheduler=None, config=None):
    """
    从断点恢复训练
    
    Args:
        checkpoint_path: 断点文件路径
        model: 模型实例
        optimizer: 优化器
        scheduler: 学习率调度器
        config: 配置对象
    
    Returns:
        start_epoch: 开始的epoch
        history: 训练历史
        best_val_acc: 最佳验证准确率
        config: 配置对象（未传入时从检查点还原）
    """
    
    checkpoint = checkpoint_manager.load_checkpoint(
        checkpoint_path, model, optimizer, scheduler
    )
    
    # 获取恢复信息
    start_epoch = checkpoint.get('epoch', 0) + 1
    # 优先使用保存的历史最佳值；旧检查点中只有保存时那个epoch的 val_acc
    best_val_acc = checkpoint.get('best_val_acc', checkpoint.get('val_acc', 0.0))
    
    # 恢复训练历史
    history = TrainingHistory()
    if 'history' in checkpoint:
        history.__dict__.update(checkpoint['history'])
    
    # 恢复配置
    if config is None and 'config' in checkpoint:
        config = SimpleNamespace(**checkpoint['config'])
    
    print(f"🚀 准备从 Epoch {start_epoch} 恢复训练")
    print(f"📊 当前最佳验证准确率: {best_val_acc:.2f}%")
    
    return start_epoch, history, best_val_acc, config

# 演示：模拟恢复训练流程
def demo_resume_training():
    """演示恢复训练的完整流程"""
    
    print("=== 断点恢复训练演示 ===\n")
    
    # 1. 查看可用的检查点
    print("1. 查看可用的检查点:")
    available_checkpoints = checkpoint_manager.list_checkpoints()
    
    if not available_checkpoints:
        print("❌ 没有找到检查点文件，请先运行一些训练")
        return
    
    # 2. 选择最新的检查点进行恢复
    latest_checkpoint = checkpoint_manager.find_latest_checkpoint()
    print(f"2. 选择最新检查点: {latest_checkpoint.name}")
    
    # 3. 重新创建模型和优化器（模拟新的训练会话）
    #    模型结构必须与保存检查点时一致（第6节训练使用的是 ConvNet），否则 load_state_dict 会报错
    print("\n3. 重新创建模型、优化器和调度器...")
    model = ConvNet(num_classes=config.num_classes)
    model = model.to(device)
    
    # 使用与训练时相同的优化器和调度器类型，其内部状态随后由检查点覆盖
    optimizer = get_optimizer(model, config)
    scheduler = get_scheduler(optimizer, config)
    
    # 4. 恢复训练状态
    print("\n4. 恢复训练状态...")
    try:
        start_epoch, history, best_val_acc, restored_config = resume_training(
            latest_checkpoint, model, optimizer, scheduler
        )
        
        print(f"✓ 成功恢复训练状态!")
        print(f"  - 下一个训练epoch: {start_epoch}")
        print(f"  - 历史最佳准确率: {best_val_acc:.2f}%")
        print(f"  - 已恢复训练历史: {len(history.train_losses)} 个epochs")
        
        # 5. 可以继续训练（这里只是演示，不实际运行）
        print(f"\n5. 现在可以从 epoch {start_epoch} 继续训练...")
        print("   (为了演示目的，这里不实际执行训练循环)")
        
    except Exception as e:
        print(f"❌ 恢复训练失败: {e}")

# 运行演示
demo_resume_training()
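`resume_training` 用 `SimpleNamespace(**checkpoint['config'])` 还原配置，而保存下来的是 `config.__dict__`，其中含有 `Path`、`torch.device` 等对象：它们只能依赖 pickle 保存，无法直接写成 JSON；而且从 PyTorch 2.6 起 `torch.load` 默认 `weights_only=True`，这类对象可能无法直接加载。一种更稳妥的思路是把配置定义为 dataclass，保存前转换成只含基本类型的字典。下面是纯标准库的示意（`ConfigSketch` 及其字段为假设，只挑了 `TrainingConfig` 中的部分字段）。

```python
import json
from dataclasses import asdict, dataclass, fields
from pathlib import Path

@dataclass
class ConfigSketch:
    model_name: str = "CNN_MNIST"
    epochs: int = 20
    learning_rate: float = 0.001
    optimizer_type: str = "Adam"
    save_dir: Path = Path("checkpoints")

    def to_dict(self):
        """转成只含 JSON 基本类型的字典（Path 转为字符串）"""
        d = asdict(self)
        d["save_dir"] = str(self.save_dir)
        return d

    @classmethod
    def from_dict(cls, d):
        """从字典还原：忽略未知字段，并把路径字符串转回 Path"""
        known = {f.name for f in fields(cls)}
        kwargs = {k: v for k, v in d.items() if k in known}
        if "save_dir" in kwargs:
            kwargs["save_dir"] = Path(kwargs["save_dir"])
        return cls(**kwargs)

cfg = ConfigSketch(epochs=30)
text = json.dumps(cfg.to_dict())
restored = ConfigSketch.from_dict(json.loads(text))
print(restored == cfg)  # True
```

忽略未知字段是为了兼容旧检查点：配置类增删字段后，旧文件仍然可以加载。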

## 8. TensorBoard 可视化

TensorBoard 是一个强大的可视化工具，可以帮助我们：
- 监控训练过程中的损失和指标
- 可视化模型结构
- 观察权重和梯度的分布
- 比较不同实验的结果

### 8.1 TensorBoard 基础使用

In [None]:
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
from datetime import datetime
import numpy as np

class TensorBoardLogger:
    """
    TensorBoard 日志记录器
    """
    def __init__(self, log_dir, experiment_name=None):
        if experiment_name is None:
            experiment_name = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        
        self.log_dir = Path(log_dir) / experiment_name
        self.log_dir.mkdir(parents=True, exist_ok=True)
        
        self.writer = SummaryWriter(log_dir=str(self.log_dir))
        self.step = 0
        
        print(f"📊 TensorBoard 日志目录: {self.log_dir}")
        print(f"💡 启动TensorBoard: tensorboard --logdir {log_dir}")
    
    def log_scalar(self, tag, value, step=None):
        """记录标量值"""
        if step is None:
            step = self.step
        self.writer.add_scalar(tag, value, step)
    
    def log_scalars(self, tag_dict, step=None):
        """记录多个标量值"""
        if step is None:
            step = self.step
        for tag, value in tag_dict.items():
            self.writer.add_scalar(tag, value, step)
    
    def log_histogram(self, tag, values, step=None):
        """记录直方图"""
        if step is None:
            step = self.step
        self.writer.add_histogram(tag, values, step)
    
    def log_image(self, tag, image, step=None):
        """记录图像"""
        if step is None:
            step = self.step
        self.writer.add_image(tag, image, step)
    
    def log_images(self, tag, images, step=None):
        """记录图像网格"""
        if step is None:
            step = self.step
        grid = vutils.make_grid(images, normalize=True, scale_each=True)
        self.writer.add_image(tag, grid, step)
    
    def log_model_graph(self, model, input_tensor):
        """记录模型计算图"""
        self.writer.add_graph(model, input_tensor)
    
    def log_model_weights(self, model, step=None):
        """记录模型权重分布"""
        if step is None:
            step = self.step
        
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.writer.add_histogram(f"weights/{name}", param.data, step)
                if param.grad is not None:
                    self.writer.add_histogram(f"gradients/{name}", param.grad.data, step)
    
    def log_learning_rate(self, optimizer, step=None):
        """记录学习率"""
        if step is None:
            step = self.step
        
        for i, param_group in enumerate(optimizer.param_groups):
            lr = param_group['lr']
            self.writer.add_scalar(f"learning_rate/group_{i}", lr, step)
    
    def increment_step(self):
        """增加步数"""
        self.step += 1
    
    def close(self):
        """关闭writer"""
        self.writer.close()

# 创建 TensorBoard logger
tb_logger = TensorBoardLogger(
    log_dir="./tensorboard_logs", 
    experiment_name="mnist_cnn_demo"
)

print("✓ TensorBoard 记录器创建完成")

### 8.2 TensorBoard 可视化演示

In [None]:
def demo_tensorboard_logging():
    """Demonstrate the main TensorBoard logging features."""
    
    print("=== TensorBoard visualization demo ===\n")
    
    # 1. Create a sample model and data
    model = SimpleCNN(num_classes=10)
    model = model.to(device)
    
    # Take a single batch to use throughout the demo
    sample_data, sample_labels = next(iter(train_loader))
    sample_data = sample_data.to(device)
    sample_labels = sample_labels.to(device)
    
    print("1. Logging the model graph...")
    tb_logger.log_model_graph(model, sample_data[:1])  # one sample is enough to trace the graph
    
    print("2. Logging sample images...")
    tb_logger.log_images("samples/train_images", sample_data[:8])  # first 8 images
    
    print("3. Logging a short simulated training run...")
    # Train for a few steps on the same batch purely to generate log entries
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    model.train()
    for step in range(10):  # only 10 steps for the demo
        # Forward pass
        outputs = model(sample_data)
        loss = criterion(outputs, sample_labels.to(device))
        
        # Compute batch accuracy
        _, predicted = torch.max(outputs.data, 1)
        accuracy = (predicted == sample_labels.to(device)).float().mean().item() * 100
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Log to TensorBoard
        tb_logger.log_scalar("loss/train", loss.item(), step)
        tb_logger.log_scalar("accuracy/train", accuracy, step)
        tb_logger.log_learning_rate(optimizer, step)
        
        # Log weight distributions every 5 steps
        if step % 5 == 0:
            tb_logger.log_model_weights(model, step)
        
        print(f"Step {step+1}/10: Loss={loss.item():.4f}, Acc={accuracy:.2f}%")
    
    print("\n4. Logging hyperparameters and final metrics...")
    # Hyperparameters
    hparams = {
        'lr': 0.001,
        'batch_size': 64,
        'model': 'SimpleCNN',
        'optimizer': 'Adam'
    }
    metrics = {
        'final_loss': loss.item(),
        'final_accuracy': accuracy
    }
    
    # Log hyperparameters together with the final metrics (shown in the HParams tab)
    tb_logger.writer.add_hparams(hparams, metrics)
    
    print("5. Building a confusion matrix...")
    # Simple prediction analysis (on training data, for demonstration only)
    model.eval()
    with torch.no_grad():
        all_predictions = []
        all_labels = []
        
        for data, labels in train_loader:
            data = data.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
            
            if len(all_predictions) > 500:  # only ~500 samples for the demo
                break
        
        # Build the confusion matrix (seaborn was already imported at the top of the notebook)
        from sklearn.metrics import confusion_matrix
        
        cm = confusion_matrix(all_labels[:500], all_predictions[:500])
        
        # Plot the confusion matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        
        # Write the figure to TensorBoard (add_figure closes it by default)
        tb_logger.writer.add_figure("confusion_matrix", plt.gcf(), 0)
        plt.close()
    
    print("\n✓ TensorBoard demo finished!")
    print(f"📊 Logs saved to: {tb_logger.log_dir}")
    print(f"🌐 Launch TensorBoard with: tensorboard --logdir {tb_logger.log_dir.parent}")
    print("   then open http://localhost:6006 in your browser")

# Run the demo
demo_tensorboard_logging()
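The demo builds its confusion matrix with `sklearn.metrics.confusion_matrix`. If you would rather avoid the scikit-learn dependency, the same matrix can be computed in PyTorch: encode each (true, predicted) pair as one index `true * num_classes + pred`, count with `torch.bincount`, and reshape. This is a minimal sketch; `confusion_matrix_torch` is a hypothetical helper name, not a PyTorch API.

```python
import torch

def confusion_matrix_torch(labels, preds, num_classes):
    """Rows = true class, columns = predicted class (same layout as sklearn)."""
    labels = labels.long().flatten()
    preds = preds.long().flatten()
    # Map each (true, pred) pair to a unique flat index in [0, num_classes**2)
    flat = labels * num_classes + preds
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

true_labels = torch.tensor([0, 1, 2, 2, 1, 0])
pred_labels = torch.tensor([0, 2, 2, 2, 1, 1])
cm = confusion_matrix_torch(true_labels, pred_labels, num_classes=3)
print(cm)
# tensor([[1, 1, 0],
#         [0, 1, 1],
#         [0, 0, 2]])

# Per-class recall = diagonal / row sums
recall = cm.diag().float() / cm.sum(dim=1).float()
print(recall)
```

Because everything stays on the tensor's device, this also works on GPU predictions without moving them to NumPy first.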

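## 9. Experiment Management and Logging

In deep learning research, experiment management and detailed logging are essential. They help us:
- Track the hyperparameters and results of different experiments
- Reproduce experimental results
- Compare the performance of different methods
- Record detailed information about the training process

### 9.1 Experiment Management System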

In [None]:
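import logging
import sys
from typing import Dict, Any, Optional
import psutil  # third-party package: pip install psutil
import platform

class ExperimentManager:
    """
    Experiment manager: keeps an experiment's configuration, logs, and results in one place.
    """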
    def __init__(self, experiment_name, base_dir="./experiments"):
        self.experiment_name = experiment_name
        self.base_dir = Path(base_dir)
        self.exp_dir = self.base_dir / experiment_name
        
        # Create the experiment directory layout
        self.exp_dir.mkdir(parents=True, exist_ok=True)
        (self.exp_dir / "checkpoints").mkdir(exist_ok=True)
        (self.exp_dir / "logs").mkdir(exist_ok=True)
        (self.exp_dir / "results").mkdir(exist_ok=True)
        (self.exp_dir / "tensorboard").mkdir(exist_ok=True)
        
        # Set up logging
        self.setup_logging()
        
        # Record the start time and system information
        self.start_time = datetime.now()
        self.log_system_info()
        
        print(f"🔬 Experiment '{experiment_name}' initialized")
        print(f"📁 Experiment directory: {self.exp_dir}")
    
    def setup_logging(self):
        """Configure file and console logging."""
        log_file = self.exp_dir / "logs" / "experiment.log"
        
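        # Create a named logger
        self.logger = logging.getLogger(self.experiment_name)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False  # avoid duplicate lines via the root logger
        
        # Add handlers only once (re-running the cell would otherwise duplicate output)
        if not self.logger.handlers:
            # File handler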
            file_handler = logging.FileHandler(log_file)
            file_handler.setLevel(logging.INFO)
            
            # Console handler
            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setLevel(logging.INFO)
            
            # Formatting
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                datefmt='%Y-%m-%d %H:%M:%S'
            )
            file_handler.setFormatter(formatter)
            console_handler.setFormatter(formatter)
            
            # Attach the handlers
            self.logger.addHandler(file_handler)
            self.logger.addHandler(console_handler)
    
    def log_system_info(self):
        """Record system and environment information."""
        info = {
            "experiment_name": self.experiment_name,
            "start_time": self.start_time.isoformat(),
            "python_version": platform.python_version(),
            "platform": platform.platform(),
            "cpu_count": psutil.cpu_count(),
            "memory_gb": round(psutil.virtual_memory().total / (1024**3), 2),
            "pytorch_version": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "cuda_version": torch.version.cuda if torch.cuda.is_available() else None,
            "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0
        }
        
        if torch.cuda.is_available():
            info["gpu_names"] = [torch.cuda.get_device_name(i) 
                               for i in range(torch.cuda.device_count())]
        
        # Save the system information
        with open(self.exp_dir / "system_info.json", "w") as f:
            json.dump(info, f, indent=2)
        
        self.logger.info(f"System info: {json.dumps(info, indent=2)}")
    
    def log_config(self, config: Dict[str, Any]):
        """Record the experiment configuration."""
        config_file = self.exp_dir / "config.json"
        
        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)
        
        self.logger.info(f"Configuration: {json.dumps(config, indent=2)}")
    
    def log_info(self, message: str):
        """Log an informational message."""
        self.logger.info(message)
    
    def log_results(self, results: Dict[str, Any], epoch: Optional[int] = None):
        """Append one result entry to results/results.jsonl."""
        timestamp = datetime.now().isoformat()
        
        result_entry = {
            "timestamp": timestamp,
            "epoch": epoch,
            "results": results
        }
        
        # Append to the results file (one JSON object per line)
        results_file = self.exp_dir / "results" / "results.jsonl"
        with open(results_file, "a") as f:
            f.write(json.dumps(result_entry) + "\n")
        
        self.logger.info(f"Results (epoch {epoch}): {json.dumps(results)}")
    
    def save_model(self, model, filename, **kwargs):
        """Save the model weights plus any extra metadata passed as keyword arguments."""
        save_path = self.exp_dir / "checkpoints" / filename
        
        save_dict = {
            "model_state_dict": model.state_dict(),
            "timestamp": datetime.now().isoformat(),
            **kwargs
        }
        
        torch.save(save_dict, save_path)
        self.logger.info(f"Model saved: {filename}")
        return save_path
    
    def get_tensorboard_dir(self):
        """Return the TensorBoard log directory."""
        return str(self.exp_dir / "tensorboard")
    
    def finalize(self, final_results: Optional[Dict[str, Any]] = None):
        """Finish the experiment and record the final results."""
        end_time = datetime.now()
        duration = end_time - self.start_time
        
        summary = {
            "experiment_name": self.experiment_name,
            "start_time": self.start_time.isoformat(),
            "end_time": end_time.isoformat(),
            "duration_seconds": duration.total_seconds(),
            "duration_formatted": str(duration),
            "final_results": final_results or {}
        }
        
        # Save the experiment summary
        with open(self.exp_dir / "experiment_summary.json", "w") as f:
            json.dump(summary, f, indent=2)
        
        self.logger.info(f"Experiment completed. Duration: {duration}")
        if final_results:
            self.logger.info(f"Final results: {json.dumps(final_results)}")

# Example: create an experiment manager
exp_manager = ExperimentManager("mnist_cnn_v1")

# Example configuration
sample_config = {
    "model": "SimpleCNN",
    "dataset": "MNIST",
    "batch_size": 64,
    "learning_rate": 0.001,
    "epochs": 10,
    "optimizer": "Adam",
    "loss_function": "CrossEntropyLoss"
}

exp_manager.log_config(sample_config)
exp_manager.log_info("Experiment manager demo finished")

print("✓ Experiment manager created")
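Because `log_results` appends one JSON object per line to `results/results.jsonl`, the file is easy to reload later, for example to find the best epoch of a run. A standard-library sketch, assuming the `{"timestamp", "epoch", "results"}` layout written above and a `val_acc` key like the one used in section 10; `load_results` and `best_epoch` are hypothetical helpers, and the accuracy values are dummy numbers for the demo file.

```python
import json
import tempfile
from pathlib import Path

def load_results(path):
    """Read a JSONL file written by ExperimentManager.log_results."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

def best_epoch(entries, metric="val_acc"):
    """Return (epoch, value) for the entry with the highest `metric`."""
    scored = [(e["epoch"], e["results"][metric]) for e in entries if metric in e["results"]]
    return max(scored, key=lambda item: item[1])

# Demo with a temporary file in the same format (dummy values)
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "results.jsonl"
    with open(path, "w") as f:
        for epoch, acc in enumerate([97.1, 98.4, 98.2]):
            entry = {"timestamp": "2024-01-01T00:00:00", "epoch": epoch, "results": {"val_acc": acc}}
            f.write(json.dumps(entry) + "\n")
    entries = load_results(path)

print(best_epoch(entries))  # (1, 98.4)
```

JSONL is append-friendly: a crash mid-run leaves every completed line readable, which a single JSON document rewritten each epoch would not guarantee.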

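## 10. Complete MNIST Classification Example

We now combine everything covered so far into a complete, production-ready MNIST classification project. The example includes:
- A complete project structure
- Configuration management
- Data loading and preprocessing
- Model definition and training
- Checkpoint saving (with enough state stored to resume training)
- TensorBoard visualization
- Experiment management
- Model evaluation and testing

### 10.1 Project Configuration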

In [None]:
class MNISTConfig:
    """Complete configuration for the MNIST project."""
    def __init__(self):
        # Data
        self.data_dir = "./data"
        self.batch_size = 64
        self.num_workers = 4
        self.pin_memory = True
        
        # Model
        self.model_name = "SimpleCNN"
        self.num_classes = 10
        self.dropout_rate = 0.5
        
        # Training
        self.epochs = 10
        self.learning_rate = 0.001
        self.weight_decay = 1e-4
        self.optimizer = "Adam"  # recorded for logging; complete_mnist_training() always builds Adam
        
        # Learning-rate scheduler
        self.scheduler = "StepLR"  # recorded for logging; StepLR is hard-coded in the training code
        self.step_size = 5
        self.gamma = 0.5
        
        # Early stopping
        self.early_stopping = True
        self.patience = 5
        self.min_delta = 0.001
        
        # Checkpointing
        self.save_frequency = 2
        self.max_checkpoints = 3
        
        # Experiment
        self.experiment_name = f"mnist_cnn_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        
        # Mixed-precision training (enabled only when CUDA is available)
        self.use_amp = torch.cuda.is_available()
        
        # TensorBoard
        self.log_every = 100  # log batch-level metrics every N steps
        
        # Random seed
        self.seed = 42

def complete_mnist_training():
    """Full MNIST training pipeline."""
    
    print("🚀 Starting the complete MNIST classification run")
    print("=" * 60)
    
    # 1. Initialize the configuration
    config = MNISTConfig()
    
    # Set random seeds (set_seed from the setup cell also seeds all GPUs and configures cuDNN)
    set_seed(config.seed)
    
    # 2. Initialize the experiment manager
    exp_manager = ExperimentManager(config.experiment_name)
    exp_manager.log_config(config.__dict__)
    
    # 3. Initialize TensorBoard
    tb_logger = TensorBoardLogger(
        log_dir=exp_manager.get_tensorboard_dir(),
        experiment_name="training"
    )
    
    # 4. Create the data loaders
    exp_manager.log_info("Creating data loaders...")
    train_loader = create_dataloader(
        dataset_type='train',
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory
    )
    
    val_loader = create_dataloader(
        dataset_type='val',
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory
    )
    
    test_loader = create_dataloader(
        dataset_type='test',
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory
    )
    
    # 5. Create the model
    exp_manager.log_info("Creating model...")
    model = SimpleCNN(num_classes=config.num_classes, dropout_rate=config.dropout_rate)
    model = model.to(config.device)
    
    # Log model size
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    exp_manager.log_info(f"Total parameters: {total_params:,}")
    exp_manager.log_info(f"Trainable parameters: {trainable_params:,}")
    
    # 6. Create the optimizer and learning-rate scheduler
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )
    
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=config.step_size,
        gamma=config.gamma
    )
    
    # 7. Create the loss function
    criterion = nn.CrossEntropyLoss()
    
    # 8. Create early stopping and the checkpoint manager
    early_stopping = EarlyStopping(
        patience=config.patience,
        min_delta=config.min_delta
    ) if config.early_stopping else None
    
    checkpoint_manager = CheckpointManager(
        save_dir=exp_manager.exp_dir / "checkpoints",
        max_checkpoints=config.max_checkpoints
    )
    
    # 9. Mixed-precision training: GradScaler scales the loss so float16 gradients do not underflow
    scaler = torch.cuda.amp.GradScaler() if config.use_amp else None
    
    # 10. Log the model graph
    sample_input = next(iter(train_loader))[0][:1].to(config.device)
    tb_logger.log_model_graph(model, sample_input)
    
    # 11. Training loop
    exp_manager.log_info("Starting training...")
    history = TrainingHistory()
    best_val_acc = 0.0
    
    for epoch in range(config.epochs):
        epoch_start_time = time.time()
        
        # Train
        train_loss, train_acc = train_epoch_with_amp(
            model, train_loader, criterion, optimizer, config.device,
            scaler if config.use_amp else None, epoch, config, tb_logger
        )
        
        # Validate
        val_loss, val_acc = validate_epoch(
            model, val_loader, criterion, config.device, epoch, config
        )
        
        # Step the scheduler once per epoch
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']
        
        # Update the training history
        history.update(epoch, train_loss, train_acc, val_loss, val_acc, current_lr)
        
        # Log epoch-level metrics to TensorBoard
        tb_logger.log_scalars({
            'Loss/Train': train_loss,
            'Loss/Validation': val_loss,
            'Accuracy/Train': train_acc,
            'Accuracy/Validation': val_acc,
            'Learning_Rate': current_lr
        }, epoch)
        
        # Log weight distributions every 2 epochs
        if epoch % 2 == 0:
            tb_logger.log_model_weights(model, epoch)
        
        # Epoch duration
        epoch_time = time.time() - epoch_start_time
        
        # Record results
        epoch_results = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
            'learning_rate': current_lr,
            'epoch_time': epoch_time
        }
        exp_manager.log_results(epoch_results, epoch)
        
        # Report progress
        exp_manager.log_info(
            f"Epoch {epoch+1}/{config.epochs} - "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% - "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}% - "
            f"LR: {current_lr:.6f} - Time: {epoch_time:.2f}s"
        )
        
        # Track the best model (by validation accuracy)
        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            exp_manager.log_info(f"New best model! Validation accuracy: {val_acc:.2f}%")
        
        # Save a checkpoint every save_frequency epochs, or whenever the model improves
        if (epoch + 1) % config.save_frequency == 0 or is_best:
            checkpoint_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict() if scaler else None,
                'val_acc': val_acc,
                'val_loss': val_loss,
                'history': history.__dict__,
                'config': config.__dict__,
                'best_val_acc': best_val_acc
            }
            
            checkpoint_manager.save_checkpoint(
                checkpoint_state, epoch, is_best=is_best
            )
        
        # Early-stopping check (based on validation loss)
        if early_stopping:
            if early_stopping(val_loss, model):
                exp_manager.log_info(f"Early stopping triggered; stopping after epoch {epoch+1}")
                break
    
    # 12. Evaluate the best model on the test set
    exp_manager.log_info("Evaluating the best model on the test set...")
    
    # Load the best checkpoint, if one was saved
    best_checkpoint = checkpoint_manager.save_dir / "best_model.pth"
    if best_checkpoint.exists():
        checkpoint = torch.load(best_checkpoint, map_location=config.device)
        model.load_state_dict(checkpoint['model_state_dict'])
    
    # Test
    test_loss, test_acc = test_model(model, test_loader, criterion, config.device)
    
    # 13. Final results
    final_results = {
        'best_val_acc': best_val_acc,
        'test_acc': test_acc,
        'test_loss': test_loss,
        'total_epochs': epoch + 1,
        'total_params': total_params,
        'trainable_params': trainable_params
    }
    
    exp_manager.log_info(f"Training finished! Best validation accuracy: {best_val_acc:.2f}%, test accuracy: {test_acc:.2f}%")
    
    # 14. Save the final results and release resources
    history.save(exp_manager.exp_dir / "training_history.json")
    exp_manager.finalize(final_results)
    tb_logger.close()
    
    print("\n🎉 Complete MNIST training run finished!")
    print(f"📊 Experiment directory: {exp_manager.exp_dir}")
    print(f"🏆 Final test accuracy: {test_acc:.2f}%")
    
    return model, history, final_results

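# Only create the configuration here; the full run is not started because training takes a while
config = MNISTConfig()
print("✓ MNIST project configuration created")
print(f"📋 Experiment name: {config.experiment_name}")
print("💡 Run the next cell to define the helper functions, then call complete_mnist_training() to start training")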
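`complete_mnist_training` stores everything needed to resume (model, optimizer, scheduler and scaler state, `epoch`, `best_val_acc`), but it always starts from epoch 0. A minimal resume sketch, assuming a checkpoint dict with the same keys as `checkpoint_state` above; `resume_from_checkpoint` is a hypothetical helper, the tiny `nn.Linear` stands in for `SimpleCNN`, and `98.5` is a dummy accuracy.

```python
import io
import torch
import torch.nn as nn

def resume_from_checkpoint(path_or_buffer, model, optimizer, scheduler=None, scaler=None, device="cpu"):
    """Restore training state in place; returns (start_epoch, best_val_acc)."""
    ckpt = torch.load(path_or_buffer, map_location=device)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    if scheduler is not None and ckpt.get("scheduler_state_dict") is not None:
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
    if scaler is not None and ckpt.get("scaler_state_dict") is not None:
        scaler.load_state_dict(ckpt["scaler_state_dict"])
    # The saved epoch has already finished, so training continues with the next one
    return ckpt["epoch"] + 1, ckpt.get("best_val_acc", 0.0)

# Demo: save a checkpoint "after epoch 3" into memory, then resume into fresh objects
demo_model = nn.Linear(4, 2)
demo_opt = torch.optim.Adam(demo_model.parameters(), lr=1e-3)
demo_sched = torch.optim.lr_scheduler.StepLR(demo_opt, step_size=5, gamma=0.5)
buffer = io.BytesIO()
torch.save({
    "epoch": 3,
    "model_state_dict": demo_model.state_dict(),
    "optimizer_state_dict": demo_opt.state_dict(),
    "scheduler_state_dict": demo_sched.state_dict(),
    "scaler_state_dict": None,
    "best_val_acc": 98.5,
}, buffer)
buffer.seek(0)

new_model = nn.Linear(4, 2)
new_opt = torch.optim.Adam(new_model.parameters(), lr=1e-3)
new_sched = torch.optim.lr_scheduler.StepLR(new_opt, step_size=5, gamma=0.5)
start_epoch, best_val_acc = resume_from_checkpoint(buffer, new_model, new_opt, new_sched)
print(start_epoch, best_val_acc)  # 4 98.5
```

In the training function, the loop would then become `for epoch in range(start_epoch, config.epochs):`, with `best_val_acc` initialized from the checkpoint instead of `0.0`.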

In [None]:
def train_epoch_with_amp(model, dataloader, criterion, optimizer, device, 
                        scaler, epoch, config, tb_logger):
    """Train for one epoch, optionally with automatic mixed precision (AMP)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        if scaler is not None:  # mixed-precision path
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                output = model(data)
                loss = criterion(output, target)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:  # standard full-precision path
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        
        # Running statistics
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()
        
        # Batch-level TensorBoard logging (global step counted across epochs)
        if batch_idx % config.log_every == 0:
            step = epoch * len(dataloader) + batch_idx
            tb_logger.log_scalar('Batch/Loss', loss.item(), step)
            tb_logger.log_scalar('Batch/Accuracy', 
                               100. * correct / total, step)
    
    avg_loss = running_loss / len(dataloader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

def test_model(model, test_loader, criterion, device):
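    """Evaluate the model on the test set; returns (average loss, accuracy in %)."""
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    
    # Collected for further analysis (e.g. a confusion matrix); not returned by this function
    all_predictions = []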
    all_targets = []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.size(0)
            
            all_predictions.extend(pred.cpu().numpy().flatten())
            all_targets.extend(target.cpu().numpy())
    
    test_loss /= len(test_loader)
    test_acc = 100. * correct / total
    
    return test_loss, test_acc

print("✓ Helper functions defined")
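`train_epoch_with_amp` only enables mixed precision on CUDA. To see what `autocast` does without a GPU, the sketch below runs a linear layer under CPU autocast with `bfloat16`. This is a demonstration of the mechanism, not a suggestion to train MNIST this way: inside the context, eligible ops such as the linear matmul run in reduced precision while the parameters themselves stay `float32`. That precision gap is also why the CUDA path pairs `float16` autocast with a `GradScaler`, so that small gradients do not underflow.

```python
import torch
import torch.nn as nn

layer = nn.Linear(16, 4)
x = torch.randn(8, 16)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = layer(x)  # linear is on CPU autocast's lower-precision op list

print(layer.weight.dtype)  # torch.float32 (parameters are not converted)
print(y.dtype)             # torch.bfloat16

# Outside the context, ops run in the default precision again
y_full = layer(x)
print(y_full.dtype)        # torch.float32
```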

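## 11. PyTorch Best Practices

This tutorial has covered many aspects of PyTorch. The most important best practices are summarized below.

### 11.1 Code Organization

1. **Modular design**: keep data loading, model definition, and the training loop separate
2. **Configuration management**: manage all hyperparameters in a single configuration class
3. **Experiment management**: give each experiment its own directory and log files
4. **Version control**: manage code with Git and record the code version (e.g. the commit hash) used for each experiment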
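`MNISTConfig` stores hyperparameters as plain attributes and is serialized through `config.__dict__`. One common alternative, suggested here rather than used by this tutorial, is a `dataclasses.dataclass`: it gives typed defaults, a readable `repr`, equality comparison, and `asdict` for JSON. A trimmed sketch with a subset of the fields above (`TrainConfig` is a hypothetical name):

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class TrainConfig:
    """Trimmed, typed counterpart of MNISTConfig (illustrative subset of fields)."""
    batch_size: int = 64
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    epochs: int = 10
    optimizer: str = "Adam"
    seed: int = 42

# Override one field; the others keep their defaults
train_config = TrainConfig(learning_rate=5e-4)

# asdict() returns a plain dict, ready for json.dump or ExperimentManager.log_config
text = json.dumps(asdict(train_config), indent=2)

# Rebuild the same configuration later, e.g. from a saved config.json
restored = TrainConfig(**json.loads(text))
print(train_config)
print(restored == train_config)  # True
```

A side benefit: a misspelled field name such as `TrainConfig(lerning_rate=0.1)` raises a `TypeError` instead of silently creating a new attribute.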

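### 11.2 Training Optimization

1. **Mixed-precision training**: use AMP to speed up training on hardware that supports it
2. **Data loading**: choose sensible values for `num_workers` and `pin_memory` for your machine
3. **Learning-rate scheduling**: use a scheduler to adjust the learning rate during training
4. **Early stopping**: stop once validation performance stops improving, which limits overfitting and saves training time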
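The training loop calls an `EarlyStopping(patience, min_delta)` object defined earlier in the tutorial. As a reminder of the underlying logic, here is a minimal framework-free sketch (the class name `SimpleEarlyStopping` is hypothetical, and unlike a full implementation it does not save model weights): training stops once the validation loss has failed to improve by at least `min_delta` for `patience` consecutive epochs.

```python
class SimpleEarlyStopping:
    """Stop when the validation loss hasn't improved by min_delta for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            # Meaningful improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience  # True means "stop training"

stopper = SimpleEarlyStopping(patience=2, min_delta=0.01)
losses = [0.50, 0.40, 0.395, 0.398, 0.30]  # dummy validation losses
stopped_at = None
for epoch, loss in enumerate(losses):
    if stopper(loss):
        stopped_at = epoch
        break

print(stopped_at)  # 3
```

Note that the improvement from 0.40 to 0.395 does not count, because it is smaller than `min_delta`; the run stops before it would have seen the later drop to 0.30, which is the trade-off a too-small `patience` makes.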

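### 11.3 Debugging and Monitoring

1. **TensorBoard visualization**: monitor the training process and model performance
2. **Detailed logging**: record everything needed to reproduce a run
3. **Checkpointing**: make it possible to resume training after an interruption
4. **Profiling**: monitor GPU utilization and memory usage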
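For profiling, `torch.profiler` reports where time is spent operator by operator. A CPU-only sketch with a small stand-in model (substitute `SimpleCNN` and a real batch in practice, and add `ProfilerActivity.CUDA` to `activities` when profiling on a GPU):

```python
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity

# Small stand-in model for illustration
profile_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
inputs = torch.randn(64, 1, 28, 28)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("forward_pass"):  # custom label that appears in the report
        profile_model(inputs)

# Top operators by total CPU time
averages = prof.key_averages()
print(averages.table(sort_by="cpu_time_total", row_limit=5))
```

`record_function` labels are handy for wrapping whole stages (data loading, forward, backward) so their share of each step shows up directly in the table.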

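### 11.4 Preparing for Deployment

1. **Model saving**: save the complete model state
2. **Inference optimization**: export the model with TorchScript (`torch.jit.trace` or `torch.jit.script`)
3. **Model compression**: techniques such as quantization and pruning
4. **End-to-end testing**: make sure the model behaves correctly in the production environment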
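A minimal TorchScript export sketch with `torch.jit.trace`, assuming a model without data-dependent control flow (for models with such branches, `torch.jit.script` is the usual choice). Tracing records the operations run on an example input; the saved module can then be loaded without the original Python class. The small `nn.Sequential` stands in for a trained `SimpleCNN`, and an in-memory buffer replaces a file path.

```python
import io
import torch
import torch.nn as nn

# Stand-in for a trained SimpleCNN
export_model = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
)
export_model.eval()  # disable dropout / batch-norm updates before exporting

example = torch.randn(1, 1, 28, 28)
traced = torch.jit.trace(export_model, example)

buffer = io.BytesIO()  # in practice: torch.jit.save(traced, "model_traced.pt")
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

with torch.no_grad():
    batch = torch.randn(4, 1, 28, 28)  # these ops do not bake in the traced batch size
    same = torch.allclose(export_model(batch), loaded(batch), atol=1e-6)
print(same)  # True
```

Comparing outputs of the original and reloaded model, as in the last lines, is a cheap first step toward the end-to-end testing mentioned above.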

### 11.5 Common Code Templates

Here are some commonly used PyTorch code templates:

In [None]:
# 1. Basic training-loop template
basic_training_template = """
model.train()
for epoch in range(num_epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.6f}')
"""

# 2. Checkpoint save/load template
save_load_template = """
# Save
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),  # store a plain float rather than a tensor
}, checkpoint_path)

# Load (map_location lets a GPU-saved checkpoint load on a CPU-only machine)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
"""

# 3. Custom dataset template
custom_dataset_template = """
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.samples = self._load_samples()
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        data = self._load_data(sample)
        
        if self.transform:
            data = self.transform(data)
        
        return data, sample['label']
    
    def _load_samples(self):
        # Build the sample index here (e.g. a list of dicts, each with a 'label' key)
        pass
    
    def _load_data(self, sample):
        # Load a single sample's data here
        pass
"""

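print("📋 Common code templates:")
print("1. Basic training loop")
print("2. Checkpoint saving and loading")
print("3. Custom dataset")
print("\n💡 Use these templates as starting points and adapt them to your project")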
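The dataset template leaves `_load_samples` and `_load_data` unimplemented. A minimal filled-in version, assuming the data already fits in memory as tensors (random tensors stand in for real files; `InMemoryDataset` is a hypothetical name), shows how such a class plugs into a `DataLoader`:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class InMemoryDataset(Dataset):
    """Dataset over pre-loaded image and label tensors."""

    def __init__(self, images, labels, transform=None):
        assert len(images) == len(labels), "images and labels must align"
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

demo_images = torch.rand(100, 1, 28, 28)
demo_targets = torch.randint(0, 10, (100,))
# Simple transform: rescale [0, 1] pixels to [-1, 1]
demo_dataset = InMemoryDataset(demo_images, demo_targets, transform=lambda x: (x - 0.5) / 0.5)
demo_loader = DataLoader(demo_dataset, batch_size=32, shuffle=True)

batch_images, batch_labels = next(iter(demo_loader))
print(len(demo_dataset), batch_images.shape, batch_labels.shape)
# 100 torch.Size([32, 1, 28, 28]) torch.Size([32])
print(len(demo_loader))  # 4 batches: 32 + 32 + 32 + 4
```

With `num_workers > 0` on platforms that spawn worker processes, replace the lambda with a named function or a `torchvision.transforms` object, since lambdas cannot be pickled.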

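### 11.6 Suggestions for Further Study

After finishing this tutorial, you can continue with the following advanced topics:

1. **Distributed training**:
   - Data parallelism (DataParallel, DistributedDataParallel; the PyTorch docs recommend DistributedDataParallel even on a single machine)
   - Model parallelism
   - Hybrid parallelism strategies

2. **Model optimization**:
   - Quantization
   - Pruning
   - Knowledge Distillation
   - TorchScript and ONNX

3. **Advanced architectures**:
   - Transformer models
   - Generative Adversarial Networks (GANs)
   - Variational Autoencoders (VAEs)
   - Graph Neural Networks (GNNs)

4. **Specialized domains**:
   - Computer vision: object detection, image segmentation
   - Natural language processing: BERT, GPT, etc.
   - Reinforcement learning: DQN, PPO, etc.
   - Time-series analysis: LSTM, Transformer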

### 11.7 Recommended Resources

- **Official documentation**: https://pytorch.org/docs/
- **Official tutorials**: https://pytorch.org/tutorials/
- **PyTorch examples**: https://github.com/pytorch/examples
- **Papers with Code**：https://paperswithcode.com/
- **Awesome PyTorch**：https://github.com/bharathgs/Awesome-pytorch-list

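## 12. Summary

Congratulations on completing this PyTorch deep learning tutorial! 🎉

In this tutorial we covered:

- ✅ **PyTorch fundamentals**: tensor operations and automatic differentiation
- ✅ **Building neural networks**: from simple fully connected networks to more complex CNNs
- ✅ **Data processing**: custom datasets and efficient data loading
- ✅ **Training optimization**: optimizers, learning-rate schedulers, mixed-precision training
- ✅ **Experiment management**: checkpointing, TensorBoard visualization, logging
- ✅ **A complete project**: an end-to-end MNIST classification example

You now have the foundational skills to use PyTorch for deep learning research and development. Deep learning moves quickly, so continued learning and practice are key.

**Suggested next steps**:
1. Apply what you have learned to your own datasets
2. Contribute to open-source projects
3. Read recent research papers and try to reproduce them
4. Follow PyTorch's latest developments and best practices

Good luck on your deep learning journey! 🚀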