# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
import numpy as np
import torch.utils.data
import pdb

class SimpleCNN(nn.Module):
    """Small two-stage CNN classifier for single-channel images.

    Each stage is Conv2d(kernel 3, stride 2, padding 1) -> BatchNorm2d -> ReLU,
    which maps an HxW input to ceil(H/2) x ceil(W/2). The final linear layer
    expects a 32x7x7 feature map, i.e. 28x28 inputs, and emits 10 logits.

    The submodule names (conv1/bn1/relu1, conv2/bn2/relu2, fc) are referenced
    by ``fuse_model`` and appear in saved state_dict keys, so keep them stable.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels (28x28 -> 14x14 for 28x28 input)
        self.conv1 = nn.Conv2d(1, 16, 3, stride=2, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        # Stage 2: 16 -> 32 channels (14x14 -> 7x7 for 28x28 input)
        self.conv2 = nn.Conv2d(16, 32, 3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        # Classifier over the flattened feature map; assumes 28x28 input images.
        self.fc = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        # Look the stages up at call time so swapped-in (fused/quantized)
        # submodules are used.
        stages = (self.conv1, self.bn1, self.relu1,
                  self.conv2, self.bn2, self.relu2)
        for stage in stages:
            x = stage(x)
        # Flatten every dimension except the batch dimension.
        flat = x.reshape(x.size(0), -1)
        return self.fc(flat)

# 计算模型参数大小
def get_model_size(model):
    """Return the storage, in bytes, taken by a model's weights and biases.

    For float modules this is ``numel() * element_size()`` summed over
    ``model.parameters()``. Eager-mode quantized modules (as produced by
    ``torch.quantization.convert``) keep their weights packed and expose no
    ``nn.Parameter``; for those, the ``weight()`` / ``bias()`` accessor methods
    are queried instead, so a converted model is not reported as 0 bytes.

    Buffers and other non-weight state are not counted.
    """
    def _nbytes(tensor):
        return tensor.numel() * tensor.element_size()

    total = sum(_nbytes(p) for p in model.parameters())

    for module in model.modules():
        weight = getattr(module, "weight", None)
        # Float modules expose ``weight`` as a tensor (already counted above);
        # quantized modules expose it as a method. Skip anything else,
        # including a submodule that happens to be named ``weight``.
        if not callable(weight) or isinstance(weight, nn.Module):
            continue
        bias = getattr(module, "bias", None)
        for tensor in (weight(), bias() if callable(bias) else None):
            if isinstance(tensor, torch.Tensor):
                total += _nbytes(tensor)
    return total

# 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=5):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    After each epoch, prints the mean loss over that epoch's batches.

    Args:
        model: the ``nn.Module`` to optimize; it is switched to training mode.
        train_loader: iterable yielding ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer constructed over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
    """
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        if num_batches == 0:
            # Empty loader: there is no loss to report for this epoch.
            print(f"Epoch [{epoch+1}/{num_epochs}], no batches")
            continue
        # Report the epoch mean rather than the last (noisy) mini-batch loss.
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss / num_batches:.4f}")

# 融合模型的函数
def fuse_model(model):
    """Fuse each conv -> batch-norm -> ReLU triple of ``model`` in place.

    The module names are hard-coded to ``SimpleCNN``'s layout. The model is
    switched to eval mode before fusing; in eval mode ``fuse_modules`` folds
    the batch-norm into the preceding convolution. Returns nothing.
    """
    fusion_groups = [
        ['conv1', 'bn1', 'relu1'],
        ['conv2', 'bn2', 'relu2'],
    ]
    model.eval()
    torch.quantization.fuse_modules(model, fusion_groups, inplace=True)


# 量化模型的函数
def quantize_model(model, calibration_data=None):
    """Apply eager-mode post-training static quantization to ``model`` in place.

    Steps:
    1. Wrap ``forward`` with a QuantStub/DeQuantStub pair.
    2. Attach a MinMax-observer qconfig.
    3. Insert observers with ``prepare``.
    4. Run calibration inputs through the model.
    5. ``convert`` the observed modules to quantized ones.

    Args:
        model: float ``nn.Module``, typically already fused. It must have a
            ``conv1`` submodule, whose quantized weight/bias dtypes are printed.
        calibration_data: optional iterable of input tensors used to collect
            activation ranges. Defaults to 10 random ``(1, 1, 28, 28)`` tensors.
            Pass representative data to get meaningful quantization scales.

    Returns:
        The same ``model`` object, now quantized. Its inputs and outputs
        remain float tensors.
    """
    # Add quantize / dequantize stubs at the model boundary.
    model.quant = torch.quantization.QuantStub()
    model.dequant = torch.quantization.DeQuantStub()

    # Wrap the existing forward. The stubs are looked up on ``model`` at call
    # time, so the modules that ``convert`` swaps in below are picked up.
    original_forward = model.forward

    def new_forward(x):
        x = model.quant(x)
        x = original_forward(x)
        return model.dequant(x)

    model.forward = new_forward

    # Configure the model that was passed in. prepare() propagates this
    # qconfig to its submodules. (The stock alternative is
    # torch.quantization.get_default_qconfig('fbgemm').)
    model.qconfig = torch.quantization.QConfig(
        activation=torch.quantization.MinMaxObserver.with_args(
            quant_min=0,
            quant_max=255,
            dtype=torch.quint8
        ),
        weight=torch.quantization.MinMaxObserver.with_args(
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8
        )
    )

    # Calibrate with inference-mode behavior (e.g. BatchNorm running stats).
    model.eval()
    torch.quantization.prepare(model, inplace=True)

    if calibration_data is None:
        calibration_data = [torch.randn(1, 1, 28, 28) for _ in range(10)]
    with torch.no_grad():
        for data in calibration_data:
            model(data)

    torch.quantization.convert(model, inplace=True)

    # Quantized modules expose weight/bias through accessor methods.
    print(model.conv1.weight().dtype)
    print(model.conv1.bias().dtype)

    return model


if __name__ == "__main__":
    # Build a small random training set: 10 single-channel 28x28 images and
    # integer labels in [0, 10).
    train_images = np.random.rand(10, 1, 28, 28).astype(np.float32)
    train_labels = np.random.randint(0, 10, size=(10,)).astype(np.int64)
    train_dataset = torch.utils.data.TensorDataset(
        torch.from_numpy(train_images), torch.from_numpy(train_labels)
    )
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)

    # Model, loss function and optimizer.
    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    train(model, train_loader, criterion, optimizer, num_epochs=5)

    # Round-trip the trained weights through a checkpoint file.
    torch.save(model.state_dict(), './simple_cnn.pth')
    print("模型已保存。")

    loaded_model = SimpleCNN()
    loaded_model.load_state_dict(torch.load('./simple_cnn.pth'))
    print("模型已加载。")

    # Fuse conv/bn/relu; fuse_model also switches the model to eval mode.
    fuse_model(loaded_model)

    size_before = get_model_size(loaded_model)
    print(f"量化前模型参数大小: {size_before / 1024:.2f} KB")

    # One random input shared by the float and quantized runs.
    test_data = torch.randn(1, 1, 28, 28)

    with torch.no_grad():
        output_before = loaded_model(test_data)
    print(f"量化前输出: {output_before}")

    # quantize_model converts in place and returns the same object.
    quantized_model = quantize_model(loaded_model)

    size_after = get_model_size(quantized_model)
    print(f"量化后模型参数大小: {size_after / 1024:.2f} KB")

    # Run inference through the handle returned by quantize_model.
    with torch.no_grad():
        output_after = quantized_model(test_data)
    print(f"量化后输出: {output_after}")

    # Largest element-wise deviation introduced by quantization.
    max_abs_err = (output_after - output_before).abs().max().item()
    print(f"量化前后输出最大绝对误差: {max_abs_err:.6f}")

    # References:
    # https://pytorch.org/docs/stable/quantization.html
    # https://pytorch.org/tutorials/advanced/static_quantization_tutorial.html