<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/evaluate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import shutil

import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.quantization import FakeQuantize, QConfig
from torch.quantization import HistogramObserver
from torch.quantization import get_default_qconfig

from WRN_quant import Wide_ResNet as WRN_quant

class CustomHistogramObserverActivation(HistogramObserver):
    """Activation observer whose qparams come straight from the observed min/max.

    Only the running ``min_val``/``max_val`` tracked by ``HistogramObserver``
    are used; the parent's histogram-based range search is bypassed. The zero
    point is fixed at 0, so (with an unsigned dtype such as ``quint8``) values
    below 0 are presumably clamped by the quantizer. That suits non-negative
    (e.g. post-ReLU6) activations; NOTE(review): confirm every observed tensor
    is non-negative.

    Args:
        num_bits: Bit-width; the observed range is divided into
            ``2 ** num_bits - 1`` steps.
        **kwargs: Forwarded to ``HistogramObserver`` (e.g. ``dtype``).
    """

    def __init__(self, num_bits, **kwargs):
        # Stored before the parent constructor runs, as in the original design.
        self.num_bits = num_bits
        super().__init__(**kwargs)

    def calculate_qparams(self):
        """Return ``(scale, zero_point)`` as 1-element float32 / int64 tensors.

        The scale is clamped to float32 machine epsilon. Without the clamp, a
        collapsed range (``max_val == min_val``) or a not-yet-initialised
        observer (``min_val=inf``, ``max_val=-inf``) gives a zero or negative
        scale, and fake-quantization would divide by it.
        """
        min_val, max_val = self.min_val, self.max_val
        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
        zero_point = 0
        return scale.detach().reshape(1), torch.tensor([zero_point], dtype=torch.int64)


class CustomHistogramObserverWeight(HistogramObserver):
    """Asymmetric weight observer whose qparams come from the observed min/max.

    Only the running ``min_val``/``max_val`` tracked by ``HistogramObserver``
    are used; the parent's histogram-based range search is bypassed. The zero
    point is chosen so that ``min_val`` maps (approximately) to integer 0.

    Args:
        num_bits: Bit-width; the observed range is divided into
            ``2 ** num_bits - 1`` steps.
        **kwargs: Forwarded to ``HistogramObserver`` (e.g. ``dtype``).
    """

    def __init__(self, num_bits, **kwargs):
        # Stored before the parent constructor runs, as in the original design.
        self.num_bits = num_bits
        super().__init__(**kwargs)

    def calculate_qparams(self):
        """Return ``(scale, zero_point)`` as 1-element float32 / int64 tensors.

        The scale is clamped to float32 machine epsilon. Without the clamp, a
        collapsed range (e.g. an all-zero weight tensor) gives ``scale == 0``,
        and ``-min_val / scale`` becomes NaN/inf, which makes ``int()`` raise.
        """
        min_val, max_val = self.min_val, self.max_val
        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
        # NOTE(review): int() truncates toward zero; the usual affine scheme
        # rounds here. Kept as-is so existing checkpoints' numerics are unchanged.
        zero_point = int(-min_val / scale)
        return scale.detach().reshape(1), torch.tensor([zero_point], dtype=torch.int64)


def custom_qconfig(num_bits):
    """Build a ``QConfig`` that fake-quantizes to ``num_bits`` with the custom observers.

    Activations use ``CustomHistogramObserverActivation`` with ``torch.quint8``;
    weights use ``CustomHistogramObserverWeight`` with ``torch.qint8``.

    NOTE(review): no ``quant_min``/``quant_max`` are passed, so the integer
    clamp range presumably comes from the dtype defaults rather than from
    ``num_bits``; ``num_bits`` only controls the step size. Confirm this is
    intended.
    """
    activation_fq = FakeQuantize.with_args(
        observer=CustomHistogramObserverActivation,
        num_bits=num_bits,
        dtype=torch.quint8,
    )
    weight_fq = FakeQuantize.with_args(
        observer=CustomHistogramObserverWeight,
        num_bits=num_bits,
        dtype=torch.qint8,
    )
    return QConfig(activation=activation_fq, weight=weight_fq)

def evaluate(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Switches ``model`` to eval mode (the caller keeps that side effect) and
    runs every batch under ``torch.no_grad()``.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the class
            scores are assumed to lie along dim 1 (TODO confirm for new models).
        dataloader: Iterable of ``(inputs, targets)`` batches.
        device: Device that inputs and targets are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    if total == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("dataloader yielded no samples; cannot compute accuracy")
    return correct / total


def evaluate_num_bits_model(num_bits, depth, widen_factor, num_classes, dropout_rate, device, test_loader,
                            *, drive_dir="/content/drive/MyDrive/model", local_dir="/content"):
    """Load a saved ``num_bits``-bit quantized Wide-ResNet checkpoint and report its test accuracy.

    The quantized module structure is rebuilt with ``prepare`` + ``convert``,
    using the same qconfig family the checkpoint was presumably trained with.
    The saved state dict is then loaded on top of it, and the model is
    evaluated on ``test_loader``.

    Args:
        num_bits: Bit-width of the checkpoint. ``8`` uses PyTorch's default
            ``fbgemm`` qconfig; any other value uses ``custom_qconfig(num_bits)``.
        depth, widen_factor, num_classes, dropout_rate: ``WRN_quant`` constructor
            arguments. They must describe the same architecture as the
            checkpoint, or ``load_state_dict`` will fail.
        device: Accepted for backward compatibility only. The converted model is
            always built and evaluated on the CPU, because eager-mode quantized
            kernels run on CPU only.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        drive_dir: Directory the checkpoint is copied from (the Drive mount by default).
        local_dir: Directory the checkpoint is copied to before loading.

    Returns:
        Test accuracy as a fraction in ``[0, 1]``.

    Raises:
        FileNotFoundError: If the checkpoint is missing from ``drive_dir``.
    """
    model = WRN_quant(depth=depth, widen_factor=widen_factor, num_classes=num_classes,
                      dropout_rate=dropout_rate).to("cpu")

    model_file = f"WRN_{depth}_{widen_factor}_Relu6_{num_bits}bit.pth"
    # shutil.copy raises if the source is missing (the former `!cp` magic
    # ignored failures) and returns the destination path, so loading no longer
    # depends on the current working directory.
    local_path = shutil.copy(os.path.join(drive_dir, model_file), local_dir)

    model.eval()
    if num_bits == 8:
        model.qconfig = get_default_qconfig('fbgemm')
    else:
        model.qconfig = custom_qconfig(num_bits)
    # No calibration pass is run. prepare/convert only build the quantized
    # module structure; the quantization parameters are expected to come from
    # the state dict loaded below.
    torch.quantization.prepare(model, inplace=True)
    torch.quantization.convert(model, inplace=True)
    # map_location keeps loading working on CPU-only runtimes, even if some
    # tensors in the checkpoint were saved from a GPU.
    model.load_state_dict(torch.load(local_path, map_location="cpu"))

    quantized_accuracy = evaluate(model, test_loader, "cpu")
    print(f"{num_bits} bits model test accuracy: {quantized_accuracy * 100:.2f}%")
    return quantized_accuracy


def main():
    """Evaluate every quantized WRN checkpoint on the CIFAR-10 test split.

    Builds the CIFAR-10 test loader once, then calls
    ``evaluate_num_bits_model`` for each bit-width from ``min_bits`` to
    ``max_bits`` inclusive.
    """
    # Wide-ResNet hyperparameters; presumably these must match the saved checkpoints.
    depth, widen_factor = 28, 10
    num_classes = 10
    dropout_rate = 0.3
    batch_size = 128

    # Inclusive range of bit-widths to evaluate (smallest is 1, largest is 8).
    min_bits, max_bits = 1, 8

    device = torch.device("cpu")

    # CIFAR-10 test split, normalized with per-channel mean 0.5 and std 0.5.
    preprocess = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=preprocess)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

    bit_widths = range(min_bits, max_bits + 1)
    for bits in bit_widths:
        evaluate_num_bits_model(bits, depth, widen_factor, num_classes, dropout_rate, device, test_loader)


if __name__ == "__main__":
    # Entry point when executed as a script or run as a notebook cell.
    main()

Files already downloaded and verified
Creat WRN_quant successfully!


In [1]:
# Mount Google Drive at /content/drive. The evaluation code copies checkpoints
# from /content/drive/MyDrive/model, so this cell must run before it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Copy the full-precision checkpoint from the Colab VM's /content directory
# into the Drive model folder (requires Drive to be mounted first).
!cp /content/WRN_28_10_Relu6_full_precision.pth /content/drive/MyDrive/model