In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, TensorDataset
import time

In [15]:
# Define a simple CNN model
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for single-channel images.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU and
    a 2x2 max-pool. The flattened features then pass through a 128-unit hidden
    layer and a 10-way output layer. ``fc1`` expects ``64 * 7 * 7`` input
    features, which is what two 2x2 poolings leave from a 28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Both convolutions share the same spatial-size-preserving geometry.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(1, 32, **conv_kwargs)
        self.conv2 = nn.Conv2d(32, 64, **conv_kwargs)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the raw class scores (logits), one row of 10 per input sample."""
        # conv -> ReLU -> 2x2 max-pool, applied once per convolution stage.
        for conv in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv(x)), 2)
        # Collapse every non-batch dimension into one feature vector per sample.
        flat = x.view(x.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

In [31]:
# Generate synthetic data: 15000 single-channel 28x28 images and matching labels.
num_samples = 15000
image_size = (28, 28)

# Draw from a dedicated, seeded generator so the synthetic dataset is identical
# on every Restart & Run All, without altering the global RNG state.
data_seed = 42
data_rng = torch.Generator().manual_seed(data_seed)

# Random image tensor of shape [num_samples, 1, 28, 28].
images = torch.randn(num_samples, 1, *image_size, generator=data_rng)

# Random integer labels for a 10-class task (values 0-9), one per image.
labels = torch.randint(0, 10, (num_samples,), generator=data_rng)

In [32]:
# Pair each image with its label in a TensorDataset, then serve the pairs
# through a DataLoader in shuffled mini-batches of 256.
dataset = TensorDataset(images, labels)
train_loader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=256,
)

In [None]:
# Build the network and move its parameters onto the GPU
# (nn.Module.cuda() modifies the module in place).
model = SimpleCNN()
model.cuda()

# Gradient scaler used for automatic mixed precision training.
scaler = GradScaler()

# Cross-entropy loss, and an Adam optimizer over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Record the wall-clock time at which training starts.
start_time = time.time()
# Training loop.
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    samples_seen = 0
    # The batch variables deliberately do NOT reuse the name `labels`: doing so
    # would overwrite the notebook-level `labels` tensor built in the data cell,
    # and re-running the dataset cell after training would then pair `images`
    # with a single GPU batch of labels.
    for batch_inputs, batch_labels in train_loader:
        batch_inputs, batch_labels = batch_inputs.cuda(), batch_labels.cuda()

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        with autocast():  # run the forward pass under automatic mixed precision
            outputs = model(batch_inputs)
            loss = criterion(outputs, batch_labels)

        # Scale the loss before backpropagation.
        scaler.scale(loss).backward()

        # Step the optimizer through the scaler, then update the scale factor.
        scaler.step(optimizer)
        scaler.update()

        # `loss` is the mean over this batch (CrossEntropyLoss default reduction),
        # so weight it by the batch size; otherwise a smaller final batch would
        # count as much as a full one in the epoch average.
        n_batch = batch_labels.size(0)
        running_loss += loss.item() * n_batch
        samples_seen += n_batch

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / samples_seen}")
# Compute and print the total training time.
end_time = time.time()
training_time = end_time - start_time
print(f"Training time: {training_time:.2f} seconds")