首先训练CNN模型

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torch.nn.utils import prune

# 定义 CNN 模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)  # -> 8 x 28 x 28
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1) # -> 16 x 28 x 28
        self.pool = nn.MaxPool2d(2, 2)                           # 每次池化尺寸减半
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(16 * 7 * 7, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))              # -> (B, 8, 28, 28)
        x = self.pool(F.relu(self.conv2(x)))   # -> (B, 16, 14, 14)
        x = self.pool(x)                       # -> (B, 16, 7, 7)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)              # -> (B, 784)
        x = F.relu(self.fc1(x))                # -> (B, 32)
        x = self.fc2(x)                        # -> (B, 10)                 
        return x

# 权重剪枝（移除20%的权重）
def prune_model(model, amount=0.2):
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name='weight', amount=amount)
            prune.remove(module, 'weight')
    return model


# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # 标准化
])

# 加载 MNIST 数据集
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

# 设置设备
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device('cpu')
model = CNN().to(device)
model = prune_model(model, amount=0.2)  # 剪枝20%的权重

# 设置优化器和损失函数
optimizer = optim.AdamW(model.parameters(), lr=0.0005,weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

# 训练模型
for epoch in range(1, 11):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} 完成")

# 测试模型
model.eval()
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()


print(f"\n测试准确率：{correct / len(test_loader.dataset):.4f}")


Epoch 1 完成
Epoch 2 完成
Epoch 3 完成
Epoch 4 完成
Epoch 5 完成
Epoch 6 完成
Epoch 7 完成
Epoch 8 完成
Epoch 9 完成
Epoch 10 完成

测试准确率：0.9890


保存参数，转成numpy数组，并以二进制文件的形式存储

In [11]:
import numpy as np

state_dict = model.state_dict()
w_conv1 = state_dict['conv1.weight']
b_conv1 = state_dict['conv1.bias']
w_conv2 = state_dict['conv2.weight']
b_conv2 = state_dict['conv2.bias']
w_fc1 = state_dict['fc1.weight']
b_fc1 = state_dict['fc1.bias']
w_fc2 = state_dict['fc2.weight']
b_fc2 = state_dict['fc2.bias']
with open("parameters_cnn.bin", "wb") as out_bin:
    w_conv1.cpu().numpy().astype(np.float32).tofile(out_bin)
    b_conv1.cpu().numpy().astype(np.float32).tofile(out_bin)
    w_conv2.cpu().numpy().astype(np.float32).tofile(out_bin)
    b_conv2.cpu().numpy().astype(np.float32).tofile(out_bin)
    w_fc1.cpu().numpy().T.astype(np.float32).tofile(out_bin)
    b_fc1.cpu().numpy().astype(np.float32).tofile(out_bin)
    w_fc2.cpu().numpy().T.astype(np.float32).tofile(out_bin)
    b_fc2.cpu().numpy().astype(np.float32).tofile(out_bin)



打印各个参数的维度

In [12]:
for key, value in model.state_dict().items():
    print(f"{key}: {value.shape}")

conv1.bias: torch.Size([8])
conv1.weight: torch.Size([8, 1, 3, 3])
conv2.bias: torch.Size([16])
conv2.weight: torch.Size([16, 8, 3, 3])
fc1.weight: torch.Size([32, 784])
fc1.bias: torch.Size([32])
fc2.weight: torch.Size([10, 32])
fc2.bias: torch.Size([10])
