In [1]:
# %load My_unstructured_pruning.py
import argparse
import os
import time
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.nn.utils.prune as prune

import torchvision
import torchvision.transforms as transforms

from vgg_quant import *

# NOTE(review): `global` at module scope is a no-op; best_prec is not assigned in this cell.
global best_prec
use_gpu = torch.cuda.is_available()
print('=> Building model...')

batch_size = 128

# One fresh VGG16_quant per experiment: unpruned baseline, 2-step pruned, 4-step pruned.
model_1, model_2, model_3 = VGG16_quant(), VGG16_quant(), VGG16_quant()

# Per-channel normalization shared by the train and test pipelines.
normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

train_dataset = torchvision.datasets.CIFAR10(
    './data',
    train=True,
    transform=transforms.Compose([
        # Augmentation: pad-and-random-crop plus horizontal flip.
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]),
    download=True,
)
trainloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

test_dataset = torchvision.datasets.CIFAR10(
    './data',
    train=False,
    transform=transforms.Compose([transforms.ToTensor(), normalize]),
    download=True,
)
testloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Progress is printed every `print_freq` batches; each batch holds `batch_size` samples.
print_freq = 100
def validate(val_loader, model, criterion):
    """Run one evaluation pass over ``val_loader`` and return mean top-1 precision (%).

    Args:
        val_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; it is switched to eval mode here.
        criterion: loss function applied to ``(model(inputs), targets)``.

    Returns:
        Top-1 precision in percent, averaged over all samples (each batch is
        weighted by its size).

    Progress is printed every ``print_freq`` batches (module-level setting).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    # Send batches to the device holding the model's parameters instead of
    # hard-coding .cuda(), so evaluation also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        eval_device = first_param.device
    else:
        eval_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            images, target = images.to(eval_device), target.to(eval_device)

            output = model(images)
            loss = criterion(output, target)

            # Record loss and top-1 precision, weighted by the batch size.
            prec = accuracy(output, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec.item(), images.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec {top1.val:.3f}% ({top1.avg:.3f}%)'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec {top1.avg:.3f}% '.format(top1=top1))
    return top1.avg


def accuracy(output, target, topk=(1,)):
    """Compute precision@k, in percent, for each k in ``topk``.

    Args:
        output: model scores; classes are ranked along dim 1 (assumed shape
            (batch, num_classes)).
        target: ground-truth class index for each sample.
        topk: the k values to evaluate.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target is among the top-k predictions.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample (sorted, largest first),
    # transposed so row j holds every sample's j-th best guess.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # `correct` derives from a transposed (non-contiguous) tensor, so
        # `.view(-1)` on its slices can raise; `reshape` copies only if needed.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class AverageMeter:
    """Keep the most recent value of a metric plus its count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the running mean, the weighted sum and the sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, counting it ``n`` times toward the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

        
def save_checkpoint(state, is_best, fdir):
    """Save ``state`` to ``<fdir>/checkpoint.pth`` and, if ``is_best``, copy it
    to ``<fdir>/model_best.pth.tar``.

    Args:
        state: object passed to ``torch.save`` (typically a dict).
        is_best: whether this checkpoint should also be kept as the best one.
        fdir: output directory; created if missing. An empty string means the
            current working directory.
    """
    # torch.save does not create parent directories, so make sure fdir exists.
    if fdir:
        os.makedirs(fdir, exist_ok=True)
    filepath = os.path.join(fdir, 'checkpoint.pth')
    torch.save(state, filepath)
    if is_best:
        # File names (including the .pth.tar suffix) are kept so existing loaders still find them.
        shutil.copyfile(filepath, os.path.join(fdir, 'model_best.pth.tar'))


def adjust_learning_rate(optimizer, epoch, milestones=(150, 225), gamma=0.1):
    """Multiply every param group's learning rate by ``gamma`` at milestone epochs.

    With the defaults, the lr is divided by 10 when ``epoch`` equals 150 and
    again at 225. The optimizer is updated in place, so decays accumulate;
    call this once per epoch.

    Args:
        optimizer: optimizer whose ``param_groups`` hold an ``'lr'`` entry.
        epoch: current epoch number.
        milestones: epochs at which the decay is applied.
        gamma: multiplicative decay factor.
    """
    if epoch not in milestones:
        return
    for param_group in optimizer.param_groups:
        param_group['lr'] *= gamma

#model = nn.DataParallel(model).cuda()
#all_params = checkpoint['state_dict']
#model.load_state_dict(all_params, strict=False)
#criterion = nn.CrossEntropyLoss().cuda()
#validate(testloader, model, criterion)



# lr = 4e-3
# weight_decay = 1e-4
# epochs = 100
# best_prec = 0
# model = model.cuda()
# criterion = nn.CrossEntropyLoss().cuda()
# optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
# fdir = 'result/'+str(model_name)+'/model_best_valid.pth'
# epoch_parameters = 'epoch_parameters.pth'
# best_valid_parameters = 'best_valid_acc.pth'


def count_zeros_quant_only(model, layer_types=None):
    """Measure weight sparsity over the QuantConv2d layers of ``model`` only.

    Args:
        model: module whose submodules are scanned.
        layer_types: optional class or tuple of classes whose ``weight`` is
            counted instead; defaults to ``(QuantConv2d,)``.

    Returns:
        ``(total, zeros, ratio)``: number of weight elements in matching layers,
        how many are exactly zero, and ``zeros / total`` (0.0 if nothing matched).
    """
    if layer_types is None:
        layer_types = (QuantConv2d,)
    total = 0
    zeros = 0
    for module in model.modules():
        if isinstance(module, layer_types):
            weight = module.weight.detach()
            # Count on the weight's own device instead of copying it to host NumPy.
            total += weight.numel()
            zeros += int((weight == 0).sum().item())
    return total, zeros, zeros / total if total > 0 else 0.0


def load_pruned_state_dict(model, state_dict):
    """Load pruned-model weights, folding ``<name>_orig * <name>_mask`` into ``<name>``.

    Each ``<name>_orig`` entry becomes ``<name>``: multiplied by its
    ``<name>_mask`` when one is present, copied as-is otherwise. A ``_mask``
    entry is dropped only when it pairs with an ``_orig`` entry, so unrelated
    keys that merely end in ``_mask`` still load. All other entries pass
    through unchanged.

    Args:
        model: module to load into (strict ``load_state_dict``).
        state_dict: mapping of parameter/buffer names to tensors.

    Returns:
        ``model``, with the merged weights loaded in place.
    """
    orig_suffix = '_orig'
    mask_suffix = '_mask'
    new_state_dict = {}

    for key, value in state_dict.items():
        if key.endswith(orig_suffix):
            base_key = key[:-len(orig_suffix)]
            mask = state_dict.get(base_key + mask_suffix)
            # Apply the pruning mask so the pruned positions become zero.
            new_state_dict[base_key] = value if mask is None else value * mask
        elif key.endswith(mask_suffix) and key[:-len(mask_suffix)] + orig_suffix in state_dict:
            # Already folded into its `_orig` partner above.
            continue
        else:
            new_state_dict[key] = value

    model.load_state_dict(new_state_dict)
    return model


# Pick the device once and use it for the loss, checkpoint loading and models
# instead of hard-coding .cuda(), so the script can also run without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss().to(device)

# 1. Evaluate the original (unpruned) model
print("=" * 60)
print("1. 原始模型（无剪枝）测试")
print("=" * 60)

best_valid_parameters = '4bit_VGG_best_valid_acc.pth'
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
checkpoint = torch.load(best_valid_parameters, map_location=device)
model_1.load_state_dict(checkpoint['state_dict'])
model_1.to(device).eval()

prec_origin = validate(testloader, model_1, criterion)
total, zeros, sparsity = count_zeros_quant_only(model_1)
print(f"Total params: {total}")
print(f"Zero params: {zeros}")
print(f"Sparsity: {sparsity*100:.2f}%")


# 2. Evaluate the model pruned gradually in 2 steps
print("\n" + "=" * 60)
print("2. 分2次渐进式剪枝模型测试 (目标稀疏度: 80%)")
print("=" * 60)

PATH_2step = "gradual_2step_0.8_ep10_vgg_pruned.pth"
try:
    checkpoint_2step = torch.load(PATH_2step, map_location=device)
except FileNotFoundError:
    # Only a missing checkpoint is skipped; errors during evaluation still surface.
    print(f"文件 {PATH_2step} 不存在，跳过测试")
    prec_2step = None
else:
    model_2 = load_pruned_state_dict(model_2, checkpoint_2step['model_state_dict'])
    model_2.to(device).eval()

    prec_2step = validate(testloader, model_2, criterion)
    total, zeros, sparsity = count_zeros_quant_only(model_2)
    print(f"Total params: {total}")
    print(f"Zero params: {zeros}")
    print(f"Sparsity: {sparsity*100:.2f}%")
    print(f"精度下降: {prec_origin - prec_2step:.2f}%")


# 3. Evaluate the model pruned gradually in 4 steps
print("\n" + "=" * 60)
print("3. 分4次渐进式剪枝模型测试 (目标稀疏度: 80%)")
print("=" * 60)

PATH_4step = "gradual_4step_0.8_ep10_vgg_pruned.pth"
try:
    checkpoint_4step = torch.load(PATH_4step, map_location=device)
except FileNotFoundError:
    # Only a missing checkpoint is skipped; errors during evaluation still surface.
    print(f"文件 {PATH_4step} 不存在，跳过测试")
    prec_4step = None
else:
    model_3 = load_pruned_state_dict(model_3, checkpoint_4step['model_state_dict'])
    model_3.to(device).eval()

    prec_4step = validate(testloader, model_3, criterion)
    total, zeros, sparsity = count_zeros_quant_only(model_3)
    print(f"Total params: {total}")
    print(f"Zero params: {zeros}")
    print(f"Sparsity: {sparsity*100:.2f}%")
    print(f"精度下降: {prec_origin - prec_4step:.2f}%")


=> Building model...
Files already downloaded and verified
Files already downloaded and verified
1. 原始模型（无剪枝）测试
Test: [0/79]	Time 0.654 (0.654)	Loss 0.3867 (0.3867)	Prec 88.281% (88.281%)
 * Prec 88.670% 
Total params: 14710464
Zero params: 0
Sparsity: 0.00%

2. 分2次渐进式剪枝模型测试 (目标稀疏度: 80%)
Test: [0/79]	Time 0.344 (0.344)	Loss 0.4182 (0.4182)	Prec 85.156% (85.156%)
 * Prec 82.420% 
Total params: 14710464
Zero params: 11768371
Sparsity: 80.00%
精度下降: 6.25%

3. 分4次渐进式剪枝模型测试 (目标稀疏度: 80%)
Test: [0/79]	Time 0.241 (0.241)	Loss 0.4884 (0.4884)	Prec 84.375% (84.375%)
 * Prec 82.750% 
Total params: 14710464
Zero params: 11768371
Sparsity: 80.00%
精度下降: 5.92%
