# Import the necessary libraries

In [None]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# pathlib: Python standard-library module that provides object-oriented filesystem path handling
from pathlib import Path
import os

# Load the MNIST dataset

In [None]:
# Fix the global random seed so weight initialization and data shuffling are reproducible.
# Assigning the return value to `_` keeps the notebook from echoing it.
_ = torch.manual_seed(0)

In [None]:
# Preprocessing: convert images to tensors, then normalize with the MNIST pixel
# mean (0.1307) and standard deviation (0.3081).
# MNIST is grayscale, so mean and std are single-channel tuples. An RGB dataset
# needs one value per channel, e.g.
#   transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
#   i.e. transforms.Normalize((mean_R, mean_G, mean_B), (std_R, std_G, std_B))
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Load the MNIST dataset (downloaded into ./data on first use)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training; batches are reshuffled every epoch
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set. It is only used for evaluation and calibration, so it
# is not shuffled: accuracy and min/max statistics do not depend on batch order,
# and a fixed order makes runs limited by `total_iterations` reproducible.
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=False)

# Define the device: first CUDA GPU if available, otherwise the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define the model

In [None]:
class VerySimpleNet(nn.Module):
    """Fully connected MNIST classifier: 784 -> hidden_size_1 -> hidden_size_2 -> 10.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        in_features = 28 * 28
        self.linear1 = nn.Linear(in_features, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` into rows of 784 values and return raw logits (10 per row)."""
        hidden = img.view(-1, 28 * 28)
        # Both hidden layers are Linear followed by ReLU; the output layer has no activation
        for layer in (self.linear1, self.linear2):
            hidden = self.relu(layer(hidden))
        return self.linear3(hidden)

In [None]:
# Instantiate the float model with the default hidden sizes (100, 100)
net = VerySimpleNet()

# Train the model

In [None]:
# This training loop is for classification only:
#   output shape = [batch_size, num_classes] (per-class logits)
#   y shape      = [batch_size], each element an integer class index
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    """Train ``model`` with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: iterable of ``(x, y)`` batches.
        model: network to train; it is moved to the module-level ``device``.
        epochs: number of passes over ``train_loader``.
        total_iterations_limit: if not None, stop after this many batches in
            total, counted across epochs.
    """
    model.to(device)

    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Number of batches processed so far, accumulated across epochs
    total_iterations = 0

    for epoch in range(epochs):
        if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
            return

        model.train()

        loss_sum = 0.0      # sum of per-batch losses in this epoch
        num_iterations = 0  # batches processed in this epoch

        # The context manager closes the progress bar even on an early return.
        with tqdm(train_loader, desc=f'Epoch {epoch+1}') as data_iterator:
            # When a limit is set, shrink the bar's total to the batches actually left.
            # total is None when the loader has no length.
            if total_iterations_limit is not None:
                remaining = total_iterations_limit - total_iterations
                if data_iterator.total is None or remaining < data_iterator.total:
                    data_iterator.total = remaining

            # Each iteration advances the tqdm bar by one
            for x, y in data_iterator:
                num_iterations += 1
                total_iterations += 1
                x = x.to(device)
                y = y.to(device)
                optimizer.zero_grad()

                # Other models may need a different input format here
                output = model(x.view(-1, 28*28))

                loss = cross_el(output, y)
                loss_sum += loss.item()
                # Show the running average loss of this epoch after the progress bar
                data_iterator.set_postfix(loss=loss_sum / num_iterations)

                loss.backward()
                optimizer.step()

                # Stop right after the last allowed step instead of fetching another batch
                if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                    return
            
def print_size_of_model(model):
    """Print the size in KB of ``model``'s serialized ``state_dict``.

    The state_dict is written with ``torch.save`` into a fresh temporary
    directory, measured, and removed together with that directory. No file is
    left behind even if saving or measuring fails, and no existing file in the
    working directory is overwritten.

    Args:
        model: an ``nn.Module`` (float or quantized).
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "temp_delme.p")
        torch.save(model.state_dict(), path)
        # os.path.getsize returns bytes; divide by 1024 to get KB
        size_kb = os.path.getsize(path) / 1024

    print('Size (KB):', size_kb)

MODEL_FILENAME = 'simplenet_ptq.pt'

# Reuse previously trained float weights if they exist; otherwise train and cache them.
if Path(MODEL_FILENAME).exists():
    # train() may have run on a GPU, so the saved tensors can be CUDA tensors.
    # map_location="cpu" lets them load on CPU-only machines; load_state_dict
    # then copies the values into net's own parameters.
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location="cpu"))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=5)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

# Define the testing loop

In [None]:
# This evaluation loop is for classification only:
#   output shape = [batch_size, num_classes] (per-class logits)
#   y shape      = [batch_size], each element an integer class index
def test(test_loader, model: nn.Module, total_iterations: int = None):
    """Evaluate ``model`` on ``test_loader`` and print its top-1 accuracy.

    Args:
        test_loader: iterable of ``(x, y)`` batches.
        model: network to evaluate; moved to the module-level ``device`` and
            put in eval mode.
        total_iterations: if not None, stop after this many batches. The check
            runs after each batch, so at least one batch is always evaluated.

    Raises:
        ValueError: if no samples were evaluated (e.g. an empty loader).
    """
    correct = 0
    total = 0
    iterations = 0

    model.to(device)
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc='Testing'):
            x = x.to(device)
            y = y.to(device)

            # Other models may need a different input format here
            output = model(x.view(-1, 28*28))

            # Predicted class = index of the largest logit in each row,
            # compared against the labels for the whole batch at once
            predictions = output.argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += predictions.size(0)

            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        raise ValueError('test_loader yielded no samples; cannot compute accuracy')
    accuracy = correct / total
    print(f'Accuracy: {round(accuracy * 100, 2)}%')

# Print weights and size of the model before quantization

In [None]:
# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
# dtype of the float weights, for comparison with the quantized weights printed later
print(net.linear1.weight.dtype)

In [None]:
# Serialized state_dict size of the float model (baseline for the quantized model)
print('Size of the model before quantization')
print_size_of_model(net)

In [None]:
# Baseline accuracy of the float model on the test set
print('Accuracy of the model before quantization: ')
test(test_loader, net)

# Insert min-max observers in the model

In [None]:
class QuantizedVerySimpleNet(nn.Module):
    """VerySimpleNet wrapped in QuantStub/DeQuantStub for eager-mode static quantization.

    The layer names and shapes are the same as VerySimpleNet's (linear1,
    linear2, linear3, relu), so a VerySimpleNet state_dict loads directly.

    Args:
        hidden_size_1: width of the first hidden layer.
        hidden_size_2: width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # QuantStub is an nn.Module placeholder ("stub") marking where float inputs
        # become quantized:
        # - after prepare() an observer is attached to it to collect input min/max
        # - after convert() it is swapped for a real quantize operation
        #   (e.g. torch.quantize_per_tensor)
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        # Counterpart of QuantStub: marks where the output is converted back to float
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Flatten ``img`` into rows of 784 values and return float logits (10 per row)."""
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

In [None]:
net_quantized = QuantizedVerySimpleNet()
# Copy weights from unquantized model (layer names match VerySimpleNet, so the keys line up)
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

# Attach the default eager-mode quantization config (qconfig) to the whole model;
# submodules inherit it from this top-level setting.
# NOTE(review): per the PyTorch docs, default_qconfig is per-tensor, not per-channel:
# activations use MinMaxObserver (quint8) and weights use a per-tensor symmetric
# MinMaxObserver (qint8). Per-channel weights would need
# default_per_channel_qconfig or get_default_qconfig("fbgemm"). Confirm against
# the installed PyTorch version.
net_quantized.qconfig = torch.ao.quantization.default_qconfig

# prepare() returns the model with observer modules inserted. They record the
# min/max of activations during the calibration forward passes that follow.
# - the model is still floating point at this stage
# - observers are only inserted where a .qconfig is in effect (set on the module
#   itself or inherited from a parent, as with the model-level qconfig above)
# - NOTE(review): the QuantStub/DeQuantStub modules are not replaced here; per the
#   eager-mode docs, QuantStub only gets an observer attached, and both stubs are
#   swapped for real (de)quantize modules by convert()
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
# Bare expression: the notebook displays the prepared model, including its observers
net_quantized

# Calibrate the model using the test set

In [None]:
# Calibration: run forward passes so the inserted observers record activation ranges.
# Reuses test(), which also prints the (still float) model's accuracy.
# NOTE(review): calibrating on the test set lets evaluation data shape the
# quantization ranges; a held-out slice of the training set would avoid that.
test(test_loader, net_quantized)

In [None]:
# Display the calibrated model; its observers show the statistics they collected
print('Check statistics of the various layers')
net_quantized

# Quantize the model using the statistics collected

In [None]:
# The model must be on the CPU before observers are replaced with actual
# quantize/dequantize operations, so move it there first.
net_quantized.to("cpu")
# Use the collected statistics to swap the float modules and stubs for quantized ones
net_quantized = torch.ao.quantization.convert(net_quantized)

In [None]:
# Display the converted model (quantized modules in place of the float ones)
print('Check statistics of the various layers')
net_quantized

# Print weights of the model after quantization

In [None]:
# Print the weights matrix of the model after quantization
print('Weights after quantization')
# torch.int_repr() returns the integer values stored in a quantized tensor.
# After convert(), `weight` on the quantized Linear is a method, hence the call.
print(torch.int_repr(net_quantized.linear1.weight()))
print(net_quantized.linear1.weight())

# Compare the dequantized weights and the original weights

In [None]:
# Compare the original float weights with the dequantized weights to see the rounding error
print('Original weights: ')
print(net.linear1.weight)
print('')
print('Dequantized weights: ')
# torch.dequantize(...) converts the quantized weights back to float32
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

# Print size and accuracy of the quantized model

In [None]:
# Serialized state_dict size of the quantized model (compare with the float model's size)
print('Size of the model after quantization')
print_size_of_model(net_quantized)

In [None]:
# Inspect the converted model: its Python type, full structure, and the module now in linear1
print(type(net_quantized))
print(net_quantized)
print(net_quantized.linear1)

In [None]:
def Quantizedtest(test_loader, model: nn.Module, total_iterations: int = None):
    """Evaluate a (quantized) ``model`` on the CPU and print its top-1 accuracy.

    Like ``test``, but the model and every batch are kept on the CPU.
    NOTE(review): presumably needed because the converted eager-mode quantized
    modules run on the CPU — confirm for the quantization backend in use.

    Args:
        test_loader: iterable of ``(x, y)`` batches.
        model: network to evaluate; moved to the CPU and put in eval mode.
        total_iterations: if not None, stop after this many batches. The check
            runs after each batch, so at least one batch is always evaluated.

    Raises:
        ValueError: if no samples were evaluated (e.g. an empty loader).
    """
    correct = 0
    total = 0
    iterations = 0

    model.to("cpu")
    model.eval()

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc='Testing'):
            x = x.to("cpu")
            y = y.to("cpu")

            # Flatten each sample to one row; other models may need a different input format
            output = model(x.view(x.size(0), -1))

            # Predicted class = index of the largest logit in each row,
            # compared against the labels for the whole batch at once
            predictions = output.argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += predictions.size(0)

            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        raise ValueError('test_loader yielded no samples; cannot compute accuracy')
    accuracy = correct / total
    print(f'Accuracy: {round(accuracy * 100, 2)}%')

In [None]:
# Accuracy of the converted (quantized) model, evaluated on the CPU
print('Testing the model after quantization')
Quantizedtest(test_loader, net_quantized)