In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time

print("库导入成功！")

库导入成功！


In [2]:
from RepVit import repvit_m0_6

print("RepVit 模型导入成功！")
print("使用 repvit_m0_6（最小版本，CPU 运行最快）")

RepVit 模型导入成功！
使用 repvit_m0_6（最小版本，CPU 运行最快）


  from .autonotebook import tqdm as notebook_tqdm
  def repvit_m0_9(pretrained=False, num_classes = 1000, distillation=False):
  def repvit_m1_0(pretrained=False, num_classes = 1000, distillation=False):
  def repvit_m1_1(pretrained=False, num_classes = 1000, distillation=False):
  def repvit_m1_5(pretrained=False, num_classes = 1000, distillation=False):
  def repvit_m2_3(pretrained=False, num_classes = 1000, distillation=False):


In [3]:
# 数据预处理：定义 transform
# 关键点：
# - Resize(64): 防止 RepViT 下采样导致特征图消失
# - Grayscale(3): 把单通道变成 3 通道，适配 RepViT 的输入层
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
print("Transform 定义完成")

Transform 定义完成


In [4]:
# 下载/加载 MNIST 数据集
print("正在下载/加载 MNIST 数据...")
train_dataset = datasets.MNIST(
    "./data", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")

正在下载/加载 MNIST 数据...
训练集大小: 60000
测试集大小: 10000


In [5]:
# 创建 DataLoader
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(f"训练 DataLoader: {len(train_loader)} batches")
print(f"测试 DataLoader: {len(test_loader)} batches")


训练 DataLoader: 1875 batches
测试 DataLoader: 313 batches


In [6]:
# 初始化模型、优化器和损失函数
model = repvit_m0_6(num_classes=10)

optimizer = optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("模型初始化完成")
print(f"模型参数数量: {sum(p.numel() for p in model.parameters()):,}")


模型初始化完成
模型参数数量: 2,169,502


In [7]:
# 训练循环（为了演示，只跑 1 个 Epoch，且只训练 200 个 batch）
epochs = 1
model.train()

print(f"\n=== 开始训练 (共 {epochs} 轮) ===")

start_time = time.time()

for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):


        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(
                f"Epoch: {epoch + 1} | Batch: {batch_idx}/{len(train_loader)} | "
                f"Loss: {loss.item():.4f}"
            )

    break  # 只训练 1 个 epoch

total_time = time.time() - start_time
print(f"\n训练完成！")
print(f"总耗时: {total_time:.2f} 秒")



=== 开始训练 (共 1 轮) ===
Epoch: 1 | Batch: 0/1875 | Loss: 2.4205
Epoch: 1 | Batch: 10/1875 | Loss: 1.1626
Epoch: 1 | Batch: 20/1875 | Loss: 0.8315
Epoch: 1 | Batch: 30/1875 | Loss: 0.4282
Epoch: 1 | Batch: 40/1875 | Loss: 0.2008
Epoch: 1 | Batch: 50/1875 | Loss: 0.5908
Epoch: 1 | Batch: 60/1875 | Loss: 0.7364
Epoch: 1 | Batch: 70/1875 | Loss: 0.2092
Epoch: 1 | Batch: 80/1875 | Loss: 0.3546
Epoch: 1 | Batch: 90/1875 | Loss: 0.3742
Epoch: 1 | Batch: 100/1875 | Loss: 0.2603
Epoch: 1 | Batch: 110/1875 | Loss: 0.2863
Epoch: 1 | Batch: 120/1875 | Loss: 0.1915
Epoch: 1 | Batch: 130/1875 | Loss: 0.3074
Epoch: 1 | Batch: 140/1875 | Loss: 0.3266
Epoch: 1 | Batch: 150/1875 | Loss: 0.0371
Epoch: 1 | Batch: 160/1875 | Loss: 0.2933
Epoch: 1 | Batch: 170/1875 | Loss: 0.1254
Epoch: 1 | Batch: 180/1875 | Loss: 0.2497
Epoch: 1 | Batch: 190/1875 | Loss: 0.4148
Epoch: 1 | Batch: 200/1875 | Loss: 0.1210
Epoch: 1 | Batch: 210/1875 | Loss: 0.1602
Epoch: 1 | Batch: 220/1875 | Loss: 0.2440
Epoch: 1 | Batch: 230/1

In [8]:
# 见证奇迹的时刻：结构重参数化 (Fuse)
print("\n=== 正在进行结构重参数化 (Fuse) ===")
print("融合前结构 (打印第一个 Block):")
print(model.features[1].token_mixer[0])  # 打印一个未融合的 RepVGGDW


=== 正在进行结构重参数化 (Fuse) ===
融合前结构 (打印第一个 Block):
RepVGGDW(
  (conv): Conv2d_BN(
    (c): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
    (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1): Conv2d(40, 40, kernel_size=(1, 1), stride=(1, 1), groups=40)
  (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [9]:
# 执行融合：把多分支结构全部合并成单路卷积
# 这一步必须在 eval 模式下做，或者做完之后就不再训练了
model.eval()
for m in model.modules():
    if hasattr(m, "fuse"):
        m.fuse()

print("\n融合后结构 (注意看变成了单纯的 Conv2d):")
print(model.features[1].token_mixer[0])


融合后结构 (注意看变成了单纯的 Conv2d):
RepVGGDW(
  (conv): Conv2d_BN(
    (c): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
    (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv1): Conv2d(40, 40, kernel_size=(1, 1), stride=(1, 1), groups=40)
  (bn): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)


In [10]:
# 推理评估：在测试集上评估融合后的模型
print("\n=== 开始在测试集上评估 (使用融合后的模型) ===")


correct = 0
total = 0
eval_start = time.time()

with torch.no_grad():
    for data, target in test_loader:

        output = model(data)
        _, predicted = torch.max(output.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

        # 为了省时间，测试 20 个 batch 就停（约 640 张图）
        if total > 600:
            break

eval_time = time.time() - eval_start
accuracy = 100 * correct / total

print(f"测试集准确率: {accuracy:.2f}%")
print(f"评估耗时: {eval_time:.2f} 秒")



=== 开始在测试集上评估 (使用融合后的模型) ===
测试集准确率: 97.53%
评估耗时: 0.89 秒
