# nn.Module

nn.Module 是 PyTorch 神经网络的核心基类，所有神经网络模型都应该继承这个类。它提供了构建、训练和管理神经网络所需的基本功能。
nn.Module 是一个抽象基类，它定义了神经网络模块的标准接口和行为：

In [7]:
import torch.nn as nn
import torch

# 所有自定义网络都必须继承 nn.Module
class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()  # 必须调用父类初始化
        # 定义网络层
        self.layer1 = nn.Linear(10, 5)
        self.layer2 = nn.Linear(5, 1)
    
    def forward(self, x):
        # 定义前向传播
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

## 参数管理（最重要的特性）
nn.Module 自动跟踪所有注册的参数：

In [8]:
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 256)  # 自动跟踪参数
        self.weight = nn.Parameter(torch.randn(256, 10))  # 自定义参数
    
    def forward(self, x):
        x = self.linear(x)
        x = x @ self.weight
        return x

# 使用示例
model = SimpleNet()
print("参数数量:", sum(p.numel() for p in model.parameters()))
print("可训练参数:", sum(p.numel() for p in model.parameters() if p.requires_grad))

# 访问特定参数
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

参数数量: 203520
可训练参数: 203520
weight: torch.Size([256, 10])
linear.weight: torch.Size([256, 784])
linear.bias: torch.Size([256])


## 模块嵌套和层次结构
nn.Module 支持模块的嵌套，可以构建复杂的网络结构：

In [9]:
class ResidualBlock(nn.Module):
    """残差块 - 作为构建块"""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        identity = x
        x = self.relu(self.conv1(x))
        x = self.conv2(x)
        return self.relu(x + identity)

class ComplexNetwork(nn.Module):
    """复杂网络 - 由多个块组成"""
    def __init__(self):
        super().__init__()
        self.initial_conv = nn.Conv2d(3, 64, 7, stride=2, padding=3)
        
        # 嵌套多个残差块
        self.res_blocks = nn.ModuleList([
            ResidualBlock(64) for _ in range(4)
        ])
        
        self.classifier = nn.Linear(64, 10)
    
    def forward(self, x):
        x = self.initial_conv(x)
        for block in self.res_blocks:
            x = block(x)
        x = x.mean([2, 3])  # 全局平均池化
        return self.classifier(x)

## 常用方法和功能
### 参数管理方法

In [10]:
model = ComplexNetwork()

# 获取所有参数
params = list(model.parameters())

# 获取参数名称和值
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# 获取所有子模块
for name, module in model.named_children():
    print(f"子模块 {name}: {module}")

# 递归获取所有模块
for name, module in model.named_modules():
    print(f"模块 {name}: {type(module).__name__}")

initial_conv.weight: torch.Size([64, 3, 7, 7])
initial_conv.bias: torch.Size([64])
res_blocks.0.conv1.weight: torch.Size([64, 64, 3, 3])
res_blocks.0.conv1.bias: torch.Size([64])
res_blocks.0.conv2.weight: torch.Size([64, 64, 3, 3])
res_blocks.0.conv2.bias: torch.Size([64])
res_blocks.1.conv1.weight: torch.Size([64, 64, 3, 3])
res_blocks.1.conv1.bias: torch.Size([64])
res_blocks.1.conv2.weight: torch.Size([64, 64, 3, 3])
res_blocks.1.conv2.bias: torch.Size([64])
res_blocks.2.conv1.weight: torch.Size([64, 64, 3, 3])
res_blocks.2.conv1.bias: torch.Size([64])
res_blocks.2.conv2.weight: torch.Size([64, 64, 3, 3])
res_blocks.2.conv2.bias: torch.Size([64])
res_blocks.3.conv1.weight: torch.Size([64, 64, 3, 3])
res_blocks.3.conv1.bias: torch.Size([64])
res_blocks.3.conv2.weight: torch.Size([64, 64, 3, 3])
res_blocks.3.conv2.bias: torch.Size([64])
classifier.weight: torch.Size([10, 64])
classifier.bias: torch.Size([10])
子模块 initial_conv: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=

### 训练和评估模式

In [11]:
# 训练模式（默认）
model.train()
print("训练模式:", model.training)  # True

# 评估模式（影响Dropout、BatchNorm等层的行为）
model.eval()
print("训练模式:", model.training)  # False

# 在评估模式下进行推理
with torch.no_grad():  # 禁用梯度计算，节省内存
    model.eval()
    output = model(input_data)

训练模式: True
训练模式: False


NameError: name 'input_data' is not defined

### 设备管理

In [None]:
# 将模型移动到设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# 检查模型参数所在的设备
print(next(model.parameters()).device)

## 实际应用示例
### 示例1：完整的图像分类网络

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNClassifier(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(64 * 8 * 8, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )
    
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # 展平
        x = self.fc_layers(x)
        return x

# 使用网络
model = CNNClassifier(num_classes=10)
print(model)

# 模拟输入
x = torch.randn(4, 3, 32, 32)  # batch_size=4, 3通道, 32x32
output = model(x)
print("输出形状:", output.shape)

### 示例2：使用 nn.ModuleList 和 nn.ModuleDict

In [None]:
class DynamicNetwork(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        
        # 使用 ModuleList 管理动态数量的层
        self.layers = nn.ModuleList()
        for i in range(len(layer_sizes) - 1):
            self.layers.append(nn.Linear(layer_sizes[i], layer_sizes[i+1]))
        
        # 使用 ModuleDict 管理不同类型的激活函数
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'sigmoid': nn.Sigmoid(),
            'tanh': nn.Tanh()
        })
        
        self.current_activation = 'relu'
    
    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:  # 除了最后一层都加激活函数
                x = self.activations[self.current_activation](x)
        return x
    
    def set_activation(self, activation_name):
        self.current_activation = activation_name

## 模型保存和加载

In [None]:
# 保存整个模型
torch.save(model, 'model.pth')
loaded_model = torch.load('model.pth')

# 只保存状态字典（推荐）
torch.save(model.state_dict(), 'model_state.pth')

# 加载状态字典
new_model = CNNClassifier()
new_model.load_state_dict(torch.load('model_state.pth'))

## 最佳实践
### 1. 始终调用 super().__init__()

In [None]:
class CorrectNet(nn.Module):
    def __init__(self):
        super().__init__()  # 必须调用！
        self.layer = nn.Linear(10, 5)

### 2. 在 __init__ 中定义所有层，在 forward 中定义计算

In [None]:
class GoodDesign(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 20)  # 在 init 中定义
        self.layer2 = nn.Linear(20, 1)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = F.relu(self.layer1(x))  # 在 forward 中组织计算
        x = self.dropout(x)
        return self.layer2(x)

### 3. 使用 nn.Sequential 组织简单的层序列

In [None]:
class SequentialNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 10)
        )
    
    def forward(self, x):
        return self.features(x)