1. 参数清零 zero_module

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# zero_module: zeroes a module's parameters; its message prints once per call
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``.

    A message is printed on every call. In this notebook the function is only
    invoked from ``SimpleBlock.__init__``, so the message shows up once per model
    construction and never during forward passes.

    Args:
        module: any ``nn.Module``; all of its ``parameters()`` are overwritten.

    Returns:
        The same ``module`` object, so the call can wrap a constructor inline.
    """
    print(">> zero_module 被调用！初始化中...")
    # no_grad lets us write into leaf tensors that require grad without autograd
    # recording the op -- the same effect as zeroing through p.detach().
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

# A minimal model used to show when zero_module runs
class SimpleBlock(nn.Module):
    """A 3->3 channel 3x3 convolution (all parameters zeroed) followed by ReLU."""

    def __init__(self):
        super().__init__()
        # zero_module is applied exactly once, here at construction time;
        # forward() never calls it.
        conv = nn.Conv2d(3, 3, 3, padding=1)
        self.conv = zero_module(conv)
        self.act = nn.ReLU()

    def forward(self, x):
        h = self.conv(x)
        return self.act(h)

# Build the model (zero_module's message is printed during this step)
print("=== 构建模型阶段 ===")
model = SimpleBlock()

# Dummy input batch
x = torch.randn(1, 3, 16, 16)

# Two forward passes: zero_module's message should NOT appear again here,
# because it only runs inside SimpleBlock.__init__.
# These passes are for inspection only, so skip building an autograd graph.
with torch.no_grad():
    print("\n=== 第一次前向传播 ===")
    out1 = model(x)

    print("\n=== 第二次前向传播 ===")
    out2 = model(x)

# Every conv weight and bias is zero, so conv(x) == 0 and ReLU(0) == 0:
# the mean printed below is expected to be exactly 0.
print("\n输出张量的平均值（应接近 0）:", out1.abs().mean().item())



2. 梯度检查点测试 torch.utils.checkpoint.checkpoint()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.checkpoint import checkpoint
import time

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === 1. Load MNIST (must already exist on disk: download=False) ===
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = torchvision.datasets.MNIST(
    root='../public/mnist',
    train=True,
    download=False,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1024,
    shuffle=True,
)

# === 2. Define a CNN that can optionally use gradient checkpointing ===
class Block(nn.Module):
    """Two 3x3 conv + ReLU layers; padding=1 keeps the spatial size unchanged."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class CNN(nn.Module):
    """Two conv Blocks, global average pooling and a 10-way linear classifier.

    Args:
        use_checkpoint: if True, each Block is run through
            ``torch.utils.checkpoint.checkpoint`` so its intermediate
            activations are recomputed during backward instead of stored.
    """

    # ------------ Edit forward() to change where checkpoints are placed ------------
    def __init__(self, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.block1 = Block(1, 32)
        self.block2 = Block(32, 64)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        if self.use_checkpoint:
            # use_reentrant=False is required for correctness here. With the
            # reentrant implementation (the historical default), a checkpointed
            # segment's output only requires grad if one of its *inputs* does.
            # The raw images do not, so block1 -- and consequently block2 --
            # would receive no gradients and never train.
            x = checkpoint(self.block1, x, use_reentrant=False)
            x = checkpoint(self.block2, x, use_reentrant=False)
        else:
            x = self.block1(x)
            x = self.block2(x)
        # (N, 64, 1, 1) -> (N, 64)
        x = torch.flatten(self.pool(x), 1)
        return self.fc(x)

# === 3. Model preparation ===
model = CNN(use_checkpoint=True).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# === 4. Time-boxed training: stop once max_duration seconds have elapsed ===
start_time = time.time()
max_duration = 45  # seconds

model.train()
step = 0
time_up = False
while not time_up:
    batches_seen = 0
    for images, labels in train_loader:
        batches_seen += 1
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        elapsed = time.time() - start_time
        if step % 10 == 0:
            print(f"Step {step}, Loss = {loss.item():.4f}, Time = {elapsed:.1f}s")
        step += 1

        if elapsed > max_duration:
            # Message derived from max_duration (previously hard-coded "2 minutes").
            print(f"⏰ 已达到 {max_duration} 秒训练时限，训练结束！")
            time_up = True
            break
    # An empty loader would otherwise make this while-loop spin forever.
    if batches_seen == 0:
        raise RuntimeError("train_loader yielded no batches; nothing to train on.")
