<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/test2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.quantization as quantization
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.quantization import QuantStub, DeQuantStub, fuse_modules, quantize_dynamic, prepare_qat
from torch.quantization import FakeQuantize, QConfig

from torch.quantization import HistogramObserver

class CustomHistogramObserverActivation(HistogramObserver):
    def __init__(self, num_bits, **kwargs):
        self.num_bits = num_bits
        super(CustomHistogramObserverActivation, self).__init__(**kwargs)

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        zero_point = 0
        return torch.tensor([scale]), torch.tensor([zero_point], dtype=torch.int64)

class CustomHistogramObserverWeight(HistogramObserver):
    def __init__(self, num_bits, **kwargs):
        self.num_bits = num_bits
        super(CustomHistogramObserverWeight, self).__init__(**kwargs)

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        zero_point = int(-min_val / scale)
        return torch.tensor([scale]), torch.tensor([zero_point], dtype=torch.int64)

def custom_qconfig(num_bits):
    return QConfig(
        activation=FakeQuantize.with_args(observer=CustomHistogramObserverActivation, num_bits=num_bits, dtype=torch.quint8),
        weight=FakeQuantize.with_args(observer=CustomHistogramObserverWeight, num_bits=num_bits, dtype=torch.qint8),
    )
# 定义一个简单的Wide ResNet网络
class WideResNet(nn.Module):
    def __init__(self, num_classes=10):
        super(WideResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layers = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layers(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

# 加载CIFAR10数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms.ToTensor())

# 定义数据加载器
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

# 定义训练参数
criterion = nn.CrossEntropyLoss()
net = WideResNet()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

# 使用PyTorch的量化感知训练工具为网络准备量化
net.qconfig = custom_qconfig(5)
print(net.qconfig)
quantization.convert(net, inplace=True)
quantization.prepare(net, inplace=True)
for name, module in net.named_modules():
    print('Module: {}, QConfig: {}'.format(name, module.qconfig))
quantization.prepare_qat(net, inplace=True)

# 训练模型
for epoch in range(1):
    net.train()
    for i, (inputs, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, loss.item()))

    scheduler.step()

# 量化模型
# net.qconfig = custom_qconfig(5)
q_net = quantization.convert(net, inplace=False)
print('=== Quantization Configuration ===')
for name, module in q_net.named_modules():
    if isinstance(module, quantization.QuantWrapper):
        print('{}: {}'.format(name, module.qconfig))
print(q_net)
correct = 0
total = 0
with torch.no_grad():
    q_net.eval()
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the quantized model on test images: %d %%' % (100 * correct / total))


Files already downloaded and verified
Files already downloaded and verified
QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class '__main__.CustomHistogramObserverActivation'>, num_bits=5, dtype=torch.quint8){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class '__main__.CustomHistogramObserverWeight'>, num_bits=5, dtype=torch.qint8){})
Module: , QConfig: None
Module: conv1, QConfig: None
Module: bn1, QConfig: None
Module: relu, QConfig: None
Module: layers, QConfig: None
Module: layers.0, QConfig: None
Module: layers.1, QConfig: None
Module: layers.2, QConfig: None
Module: layers.3, QConfig: None
Module: layers.4, QConfig: None
Module: layers.5, QConfig: None
Module: avgpool, QConfig: None
Module: fc, QConfig: None




[1,     1] loss: 2.336


KeyboardInterrupt: ignored