In [1]:
# ========== CONFIG ==========
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

# ----- 参数可调 -----
HIDDEN_SIZE = 128       # 隐藏层维度
BASE_EPOCHS = 2         # 基线训练 epoch（建议小一些，跑快）
BASE_LR     = 0.01
BASE_OPTIM  = "SGD"     # "SGD" or "Adam"

PRUNE_AMOUNTS = [0.3, 0.5, 0.7, 0.9]  # 剪枝比例

FT_EPOCHS  = 2          # 微调 epoch
FT_OPTIM   = "Adam"     # "SGD" or "Adam"
FT_LR      = 0.001

BATCH_SIZE = 128
USE_PLOTS  = True

import torch
torch.set_num_threads(1)
DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

Using device: cpu


In [2]:
# ========== DATA (MNIST) ==========
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

trainset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
testset  = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader  = DataLoader(testset,  batch_size=BATCH_SIZE, shuffle=False)

In [3]:
# ========== MODEL / UTILS ==========
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.prune as prune
import pandas as pd
import matplotlib.pyplot as plt
import copy

class MLP(nn.Module):
    def __init__(self, hidden=HIDDEN_SIZE):
        super().__init__()
        self.fc1 = nn.Linear(28*28, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, 10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

def count_sparsity(model):
    total_params, total_zeros = 0, 0
    for _, p in model.named_parameters():
        numel = p.numel()
        zeros = int((p == 0).sum().item())
        total_params += numel
        total_zeros += zeros
    return total_zeros/total_params, total_params, total_zeros

def train(model, loader, opt):
    model.train()
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        opt.zero_grad()
        loss = F.cross_entropy(model(X), y)
        loss.backward()
        opt.step()

def test(model, loader):
    model.eval(); correct=0; total=0
    with torch.no_grad():
        for X,y in loader:
            X, y = X.to(DEVICE), y.to(DEVICE)
            pred = model(X).argmax(1)
            correct += (pred==y).sum().item(); total += y.size(0)
    return correct/total

def make_optimizer(name, params, lr):
    return optim.SGD(params, lr=lr, momentum=0.9) if name=="SGD" else optim.Adam(params, lr=lr)

def prune_model(model, amount):
    for m in model.modules():
        if isinstance(m, nn.Linear):
            prune.l1_unstructured(m, "weight", amount)
            prune.remove(m, "weight")

In [4]:
# ========== RUN EXPERIMENT (NO PLOTS, SAFE) ==========
import gc
torch.manual_seed(2024)

# 控制每个 epoch 的最大 steps，避免长时间占用内存/线程（可调小/大）
MAX_TRAIN_STEPS_PER_EPOCH = 300   # 对 MNIST 128 batch 大约 ~468 steps/epoch

def train_one_epoch_capped(model, loader, opt, max_steps=MAX_TRAIN_STEPS_PER_EPOCH):
    model.train()
    steps = 0
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        opt.zero_grad(set_to_none=True)
        loss = F.cross_entropy(model(X), y)
        loss.backward()
        opt.step()
        steps += 1
        if steps >= max_steps:
            break
    # 显式清理
    del X, y, loss
    gc.collect()

@torch.inference_mode()
def test_safe(model, loader):
    model.eval()
    correct = 0
    total = 0
    for X, y in loader:
        X, y = X.to(DEVICE), y.to(DEVICE)
        pred = model(X).argmax(1)
        correct += (pred == y).sum().item()
        total += y.size(0)
    # 显式清理
    del X, y, pred
    gc.collect()
    return correct / total

# 1) 基线
baseline = MLP()
opt_base = make_optimizer(BASE_OPTIM, baseline.parameters(), BASE_LR)
for _ in range(BASE_EPOCHS):
    train_one_epoch_capped(baseline, trainloader, opt_base)

acc_base = test_safe(baseline, testloader)
s_base, tot_base, zeros_base = count_sparsity(baseline)
print(f"[Baseline] acc={acc_base:.4f}, sparsity={s_base:.3f} ({zeros_base}/{tot_base})")

# 2) 多剪枝比例实验 + 微调
records = []
for amt in PRUNE_AMOUNTS:
    pruned = MLP()
    pruned.load_state_dict(copy.deepcopy(baseline.state_dict()))

    # L1 剪枝并永久化
    prune_model(pruned, amt)

    acc_pruned = test_safe(pruned, testloader)
    s_after, tot_p, zeros_p = count_sparsity(pruned)

    opt_ft = make_optimizer(FT_OPTIM, pruned.parameters(), FT_LR)
    for _ in range(FT_EPOCHS):
        train_one_epoch_capped(pruned, trainloader, opt_ft)

    acc_ft = test_safe(pruned, testloader)

    records.append({
        "prune_amount": amt,
        "sparsity": round(s_after, 3),
        "acc_pruned": round(acc_pruned, 4),
        "acc_ft": round(acc_ft, 4),
        "zeros_after": zeros_p,
        "total_params": tot_p,
        "ft_optim": FT_OPTIM, "ft_lr": FT_LR
    })

    # 清理当前模型占用
    del pruned, opt_ft
    gc.collect()

import pandas as pd
df = pd.DataFrame(records)
print(df)


[Baseline] acc=0.9260, sparsity=0.000 (0/101770)
   prune_amount  sparsity  acc_pruned  acc_ft  zeros_after  total_params  \
0           0.3     0.300      0.9212  0.9447        30490        101770   
1           0.5     0.499      0.9148  0.9551        50816        101770   
2           0.7     0.699      0.8687  0.9496        71142        101770   
3           0.9     0.899      0.7569  0.9508        91469        101770   

  ft_optim  ft_lr  
0     Adam  0.001  
1     Adam  0.001  
2     Adam  0.001  
3     Adam  0.001  


acc_pruned：剪枝后立刻测试的准确率。
👉 随着剪枝比例上升，准确率明显下降（90% 剪掉后直接掉到 75.7%）。

acc_ft：剪枝后再用训练集微调 2 个 epoch 的结果。
👉 几乎完全恢复到 ~95% 左右！即便剪掉 90% 权重，也能恢复到和轻剪枝差不多的水平。

剪掉 90% 权重后，仍然能达到 95% 精度 → 说明网络里大部分连接是“多余的”。