In [3]:
# %load 渐进式剪枝测试_0.9.py
# 渐进式剪枝测试 - 目标稀疏度 90%
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.utils.prune as prune

import torchvision
import torchvision.transforms as transforms

from vgg_quant import *

global best_prec
use_gpu = torch.cuda.is_available()
print('=> Building model...')

batch_size = 128

model_1 = VGG16_quant()
model_2 = VGG16_quant()
model_3 = VGG16_quant()

normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ]))

testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print_freq = 100


def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            input, target = input.cuda(), target.cuda()

            output = model(input)
            loss = criterion(output, target)

            prec = accuracy(output, target)[0]
            losses.update(loss.item(), input.size(0))
            top1.update(prec.item(), input.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                       i, len(val_loader), batch_time=batch_time, loss=losses,
                       top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count



def count_zeros_quant_only(model):
    """只统计 QuantConv2d 层的稀疏度"""
    total = 0
    zeros = 0
    for name, module in model.named_modules():
        if isinstance(module, QuantConv2d):
            weight = module.weight.data.cpu().numpy()
            total += weight.size
            zeros += (weight == 0).sum()
    return total, zeros, zeros / total if total > 0 else 0


def count_zeros_quant_only(model):
    """只统计 QuantConv2d 层的稀疏度"""
    total = 0
    zeros = 0
    for name, module in model.named_modules():
        if isinstance(module, QuantConv2d):
            weight = module.weight.data.cpu().numpy()
            total += weight.size
            zeros += (weight == 0).sum()
    return total, zeros, zeros / total if total > 0 else 0


def load_pruned_state_dict(model, state_dict):

    new_state_dict = {}
    
    for key, value in state_dict.items():
        if key.endswith('_orig'):
            base_key = key[:-5]
            mask_key = base_key + '_mask'
            if mask_key in state_dict:
                new_state_dict[base_key] = value * state_dict[mask_key]
            else:
                new_state_dict[base_key] = value
        elif key.endswith('_mask'):
            continue
        else:
            new_state_dict[key] = value
    
    model.load_state_dict(new_state_dict)
    return model

criterion = nn.CrossEntropyLoss().cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1.无剪枝
print("=" * 60)
print("1. 原始模型（无剪枝）测试")
print("=" * 60)

best_valid_parameters = '4bit_VGG_best_valid_acc.pth'
checkpoint = torch.load(best_valid_parameters)
model_1.load_state_dict(checkpoint['state_dict'])
model_1.eval()
model_1.cuda()

prec_origin = validate(testloader, model_1, criterion)
total, zeros, sparsity = count_zeros_quant_only(model_1)
print(f"全模型 - Total params: {total}, Zero params: {zeros}, Sparsity: {sparsity*100:.2f}%")
total_q, zeros_q, sparsity_q = count_zeros_quant_only(model_1)
print(f"QuantConv2d层 - Total params: {total_q}, Zero params: {zeros_q}, Sparsity: {sparsity_q*100:.2f}%")

# 2 3次渐进剪枝
print("\n" + "=" * 60)
print("2. 分3次渐进式剪枝模型测试 (目标稀疏度: 90%)")
print("=" * 60)

PATH_3step = "gradual_3step_0.9_ep10_vgg_pruned.pth"
try:
    checkpoint_3step = torch.load(PATH_3step)
    model_2 = load_pruned_state_dict(model_2, checkpoint_3step['model_state_dict'])
    model_2.eval()
    model_2.cuda()
    
    prec_3step = validate(testloader, model_2, criterion)
    total, zeros, sparsity = count_zeros_quant_only(model_2)
    print(f"全模型 - Total params: {total}, Zero params: {zeros}, Sparsity: {sparsity*100:.2f}%")
    total_q, zeros_q, sparsity_q = count_zeros_quant_only(model_2)
    print(f"QuantConv2d层 - Total params: {total_q}, Zero params: {zeros_q}, Sparsity: {sparsity_q*100:.2f}%")
    print(f"精度下降: {prec_origin - prec_3step:.2f}%")
except FileNotFoundError:
    print(f"文件 {PATH_3step} 不存在，跳过测试")
    prec_3step = None


#3 6次渐进剪枝
print("\n" + "=" * 60)
print("3. 分6次渐进式剪枝模型测试 (目标稀疏度: 90%)")
print("=" * 60)

PATH_6step = "gradual_6step_0.9_ep10_vgg_pruned.pth"
try:
    checkpoint_6step = torch.load(PATH_6step)
    model_3 = load_pruned_state_dict(model_3, checkpoint_6step['model_state_dict'])
    model_3.eval()
    model_3.cuda()
    
    prec_6step = validate(testloader, model_3, criterion)
    total, zeros, sparsity = count_zeros_quant_only(model_3)
    print(f"全模型 - Total params: {total}, Zero params: {zeros}, Sparsity: {sparsity*100:.2f}%")
    total_q, zeros_q, sparsity_q = count_zeros_quant_only(model_3)
    print(f"QuantConv2d层 - Total params: {total_q}, Zero params: {zeros_q}, Sparsity: {sparsity_q*100:.2f}%")
    print(f"精度下降: {prec_origin - prec_6step:.2f}%")
except FileNotFoundError:
    print(f"文件 {PATH_6step} 不存在，跳过测试")
    prec_6step = None

=> Building model...
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [01:59<00:00, 1431493.98it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
1. 原始模型（无剪枝）测试
Test: [0/79]	Time 0.693 (0.693)	Loss 0.3867 (0.3867)	Prec 88.281% (88.281%)
 * Prec 88.670% 
全模型 - Total params: 14710464, Zero params: 0, Sparsity: 0.00%
QuantConv2d层 - Total params: 14710464, Zero params: 0, Sparsity: 0.00%

2. 分3次渐进式剪枝模型测试 (目标稀疏度: 90%)
Test: [0/79]	Time 0.376 (0.376)	Loss 1.0111 (1.0111)	Prec 67.969% (67.969%)
 * Prec 61.430% 
全模型 - Total params: 14710464, Zero params: 13239416, Sparsity: 90.00%
QuantConv2d层 - Total params: 14710464, Zero params: 13239416, Sparsity: 90.00%
精度下降: 27.24%

3. 分6次渐进式剪枝模型测试 (目标稀疏度: 90%)
Test: [0/79]	Time 0.267 (0.267)	Loss 0.7874 (0.7874)	Prec 73.438% (73.438%)
 * Prec 70.330% 
全模型 - Total params: 14710464, Zero params: 13239416, Sparsity: 90.00%
QuantConv2d层 - Total params: 14710464, Zero params: 13239416, Sparsity: 90.00%
精度下降: 18.34%
