# Import the necessary libraries

In [None]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from pathlib import Path
import os

# Load the MNIST dataset

In [None]:
# Seed PyTorch's default random number generator with a fixed value so that
# repeated runs draw the same random numbers. The return value is bound to `_`
# so the notebook cell does not echo it.
_ = torch.manual_seed(0)

In [None]:
# Choose where tensors and models will live: the first CUDA GPU when one is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Preprocessing applied to every image: convert to a tensor, then normalise
# with the mean / std constants given below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training split (downloaded into ./data if missing), served in shuffled
# mini-batches of 10 samples.
mnist_trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    mnist_trainset,
    batch_size=10,
    shuffle=True,
)

# Test split, preprocessed and batched the same way as the training split.
mnist_testset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    mnist_testset,
    batch_size=10,
    shuffle=True,
)

# Define the model

In [None]:
class VerySimpleNet(nn.Module):
    """Three-layer fully connected classifier wrapped in quantization stubs.

    The network maps a flattened 28x28 image to 10 output scores through two
    hidden ReLU layers. ``QuantStub`` / ``DeQuantStub`` mark where tensors
    enter and leave the quantized region; before the model is prepared and
    converted they pass their input through unchanged.

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # Use the torch.ao.quantization namespace, matching the rest of the file.
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28 * 28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Return the 10 class scores for each image in ``img``.

        ``img`` is flattened to rows of 28*28 values, so any tensor whose
        element count is a multiple of 784 is accepted.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = img.reshape(-1, 28 * 28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

# Instantiate the float model with its default hidden sizes (100 and 100).
net = VerySimpleNet()

# Insert min-max observers in the model

In [None]:
# Attach a quantization-aware-training configuration for the active quantized
# engine. Unlike `default_qconfig` (plain observers that only record min/max),
# the QAT qconfig uses FakeQuantize modules, which record statistics AND
# simulate quantization rounding in the forward pass, so training can adapt the
# weights to the quantization error.
net.qconfig = torch.ao.quantization.get_default_qat_qconfig(torch.backends.quantized.engine)

# prepare_qat() expects the model to be in training mode.
net.train()

# prepare_qat() is the preparation step for quantization-aware training: it
# returns a copy of the model with the fake-quantize / observer modules inserted.
net_quantized = torch.ao.quantization.prepare_qat(net)
net_quantized

# Train the model

In [None]:
def train(train_loader, model, epochs=5, total_iterations_limit=None):
    """Train ``model`` on ``train_loader`` with Adam and cross-entropy loss.

    The model is moved to the module-level ``device``. Each batch is flattened
    to rows of 28*28 values before the forward pass, and the running average
    loss of the current epoch is shown on the progress bar.

    Args:
        train_loader: Iterable yielding ``(x, y)`` batches.
        model: Network to optimise (updated in place).
        epochs: Number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of batches
            processed across all epochs; training returns as soon as the cap
            is reached. ``None`` means no cap.
    """
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0

    for epoch in range(epochs):
        model.train()
        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if has_limit:
            # Make the bar's length reflect the iteration cap.
            progress.total = total_iterations_limit

        for batch_number, (inputs, targets) in enumerate(progress, start=1):
            steps_done += 1
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            logits = model(inputs.view(-1, 28*28))
            loss = criterion(logits, targets)

            # Report the mean loss over the batches seen so far in this epoch.
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / batch_number)

            loss.backward()
            optimizer.step()

            if has_limit and steps_done >= total_iterations_limit:
                return

def print_size_of_model(model):
    """Print the serialized size of ``model``'s state dict in kilobytes.

    The state dict is saved to a uniquely named temporary file, measured, and
    the file is removed again. Using a unique name avoids overwriting (and then
    deleting) an unrelated file in the working directory, and the ``finally``
    clause guarantees cleanup even if saving or measuring fails.

    Args:
        model: Module whose ``state_dict()`` is serialized with ``torch.save``.
    """
    import tempfile

    fd, tmp_path = tempfile.mkstemp(suffix=".p")
    os.close(fd)  # torch.save reopens the file by path
    try:
        torch.save(model.state_dict(), tmp_path)
        # os.path.getsize returns bytes; dividing by 1024 converts to KB.
        size_kb = os.path.getsize(tmp_path) / 1024
    finally:
        os.remove(tmp_path)

    print('Size (KB):', size_kb)

# Train the prepared model for 5 epochs over the training loader.
train(train_loader, net_quantized, epochs=5)

# Check the collected statistics during training

In [None]:
# Echo the prepared model so the modules inserted by prepare_qat (and the
# values they hold after training) are shown in the cell output.
print('Check statistics of the various layers')
net_quantized

# Define the test loop

In [None]:
def test(test_loader, model: nn.Module, total_iterations: int = None):
    """Evaluate a classifier on ``test_loader`` and print its accuracy.

    This routine only applies to classification tasks: the model output is
    treated as ``[batch_size, num_classes]`` scores (one row of logits per
    sample) and ``y`` as ``[batch_size]`` integer class indices.

    Args:
        test_loader: Iterable yielding ``(x, y)`` batches.
        model: Network to evaluate; it is moved to the module-level ``device``
            and switched to eval mode.
        total_iterations: Optional cap on the number of batches evaluated.
            ``None`` evaluates the whole loader.

    Raises:
        ValueError: If the loader yields no samples, so no accuracy can be
            computed.
    """
    correct = 0
    total = 0

    model.to(device)
    model.eval()

    with torch.no_grad():
        batches = tqdm(test_loader, desc='Testing')
        for iterations, (x, y) in enumerate(batches, start=1):
            x = x.to(device)
            y = y.to(device)

            # The input layout may need adjusting when used with other models.
            output = model(x.view(-1, 28*28))

            # Predicted class = index of the largest score in each row.
            # Comparing whole batches at once avoids a Python-level loop (and
            # a device synchronisation) per sample.
            predictions = output.argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += output.size(0)

            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        raise ValueError('test_loader yielded no samples; cannot compute accuracy')

    accuracy = correct / total
    print(f'Accuracy: {round(accuracy * 100, 2)}%')

# Evaluate the trained, not-yet-converted model on the test loader.
test(test_loader, net_quantized)

# Quantize the model using the statistics collected

In [None]:
# Switch the prepared model to eval mode, move it to the CPU, and convert it
# into its quantized form using the statistics gathered so far.
# (Module.eval() and Module.to() both return the module itself, so the calls
# can be chained.)
net_quantized = torch.ao.quantization.convert(net_quantized.eval().to("cpu"))

In [None]:
# Echo the converted model so its layers are shown in the cell output.
print('Check statistics of the various layers')
net_quantized

# Print weights and size of the model after quantization

In [None]:
# Print the weight matrix of the first linear layer AFTER quantization:
# the model has already been converted, and torch.int_repr exposes the
# underlying integer values of the quantized weight tensor.
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

# Define the test loop for the quantized model (`Quantizedtest`)

In [None]:
def Quantizedtest(test_loader, model: nn.Module, total_iterations: int = None):
    """Evaluate a classifier on the CPU and print its accuracy.

    Same procedure as a regular classification test loop, but the model and
    every batch are explicitly placed on the CPU. The model output is treated
    as one row of class scores per sample and ``y`` as integer class indices.

    Args:
        test_loader: Iterable yielding ``(x, y)`` batches.
        model: Network to evaluate; it is moved to ``"cpu"`` and switched to
            eval mode.
        total_iterations: Optional cap on the number of batches evaluated.
            ``None`` evaluates the whole loader.

    Raises:
        ValueError: If the loader yields no samples, so no accuracy can be
            computed.
    """
    correct = 0
    total = 0

    model.to("cpu")
    model.eval()

    with torch.no_grad():
        batches = tqdm(test_loader, desc='Testing')
        for iterations, (x, y) in enumerate(batches, start=1):
            x = x.to("cpu")
            y = y.to("cpu")

            # The input layout may need adjusting when used with other models.
            # Here each sample is flattened to a single row.
            output = model(x.view(x.size(0), -1))

            # Predicted class = index of the largest score in each row;
            # compare the whole batch at once instead of looping per sample.
            predictions = output.argmax(dim=1)
            correct += (predictions == y).sum().item()
            total += output.size(0)

            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        raise ValueError('test_loader yielded no samples; cannot compute accuracy')

    accuracy = correct / total
    print(f'Accuracy: {round(accuracy * 100, 2)}%')

In [None]:
# Evaluate the converted (quantized) model on the test loader using the
# CPU-only test loop.
print('Testing the model after quantization')
Quantizedtest(test_loader, net_quantized)