<a href="https://colab.research.google.com/github/Carba6/deeplearning/blob/main/test3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.quantization import QuantStub, DeQuantStub, fuse_modules, quantize_dynamic
from torch.quantization import FakeQuantize, QConfig
from torch.quantization import HistogramObserver

class CustomHistogramObserverActivation(HistogramObserver):
    def __init__(self, num_bits, **kwargs):
        self.num_bits = num_bits
        super(CustomHistogramObserverActivation, self).__init__(**kwargs)

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        zero_point = 0
        return torch.tensor([scale]), torch.tensor([zero_point], dtype=torch.int64)

class CustomHistogramObserverWeight(HistogramObserver):
    def __init__(self, num_bits, **kwargs):
        self.num_bits = num_bits
        super(CustomHistogramObserverWeight, self).__init__(**kwargs)

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val

        scale = (max_val - min_val) / (2 ** self.num_bits - 1)
        zero_point = int(-min_val / scale)
        return torch.tensor([scale]), torch.tensor([zero_point], dtype=torch.int64)

def custom_qconfig(num_bits):
    return QConfig(
        activation=FakeQuantize.with_args(observer=CustomHistogramObserverActivation, num_bits=num_bits, dtype=torch.quint8),
        weight=FakeQuantize.with_args(observer=CustomHistogramObserverWeight, num_bits=num_bits, dtype=torch.qint8),
    )
class SimpleQuantizedModel(nn.Module):
    def __init__(self):
        super(SimpleQuantizedModel, self).__init__()
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 64)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.maxpool1(self.relu1(self.conv1(x)))
        x = self.maxpool2(self.relu2(self.conv2(x)))
        x = x.reshape(x.size(0), -1)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        x = self.dequant(x)
        return x
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
model = SimpleQuantizedModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for i, (images, labels) in enumerate(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 1


for epoch in range(num_epochs):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss:.4f}")
def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total

accuracy = evaluate(model, test_loader, device)
print(f"Test accuracy: {accuracy * 100:.2f}%")
model.to('cpu')
model.eval()
model_fused = torch.quantization.fuse_modules(model, [['conv1', 'relu1'], ['conv2', 'relu2'], ['fc1', 'relu3']])
model_fused.qconfig = custom_qconfig(5)
torch.quantization.prepare(model_fused, inplace=True)
def calibrate(model, loader, device):
    model.eval()
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            _ = model(images)

calibrate(model_fused, train_loader, 'cpu')
print(model_fused.qconfig)



def print_quantization_info(model):
    qconfig = model.qconfig
    print(f"Activation observer: {qconfig.activation()}")
    print(f"Weight observer: {qconfig.weight()}")
    print(f"min: {qconfig.activation().quant_min}")
    print(f"max: {qconfig.weight().quant_max}")

# 在量化准备之后，打印量化信息
print_quantization_info(model_fused)


torch.quantization.convert(model_fused, inplace=True)
quantized_accuracy = evaluate(model_fused, test_loader, 'cpu')
print(f"Quantized test accuracy: {quantized_accuracy * 100:.2f}%")

print(model_fused.conv1.scale, model_fused.conv1.zero_point)
print(model_fused.conv2.scale, model_fused.conv2.zero_point)
print(model_fused.fc1.scale, model_fused.fc1.zero_point)
print(model_fused.fc2.scale, model_fused.fc2.zero_point)
def check_quantization_after_convert(model):
    quantized_layers = 0
    for layer_name, layer in model.named_modules():
        if isinstance(layer, (nn.quantized.Conv2d, nn.quantized.Linear)):
            quantized_layers += 1
            print(f"Layer {layer_name} is quantized.")
        elif isinstance(layer, (nn.Conv2d, nn.Linear)):
            print(f"Layer {layer_name} is not quantized.")
    return quantized_layers

# 在量化后检查实际量化的层数
quantized_layers = check_quantization_after_convert(model_fused)
print(f"Number of quantized layers: {quantized_layers}")



Epoch [1/1], Loss: 0.1680
Test accuracy: 97.79%
QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class '__main__.CustomHistogramObserverActivation'>, num_bits=5, dtype=torch.quint8){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f4c7d1a9af0>}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FakeQuantize'>, observer=<class '__main__.CustomHistogramObserverWeight'>, num_bits=5, dtype=torch.qint8){'factory_kwargs': <function _add_module_to_qconfig_obs_ctr.<locals>.get_factory_kwargs_based_on_module_device at 0x7f4c7d1a9af0>})
Activation observer: FakeQuantize(
  fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
  (activation_post_process): Cu