# 모델 양자화 - QuantStub/DeQuantStub으로 입출력을 변환하는 간단한 정적 양자화 예제 (LeNet, MNIST)

In [None]:
import copy
import itertools
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.quantization
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Select the quantized backend. QNNPACK is preferred (as originally intended),
# but some PyTorch builds do not ship it and assigning 'qnnpack' there raises
# RuntimeError, so fall back to FBGEMM in that case.
torch.backends.quantized.engine = (
    'qnnpack' if 'qnnpack' in torch.backends.quantized.supported_engines else 'fbgemm'
)

# 1. Load MNIST (ToTensor only: pixel values become floats, no normalization).
transform = transforms.Compose([transforms.ToTensor()])
# Build the train split first, then the test split (same order as before).
train_set, test_set = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Shuffled mini-batches for training; large, fixed-order batches for evaluation.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=1000, shuffle=False)

# 2. LeNet model with QuantStub / DeQuantStub
class QuantLeNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images, producing 10 logits.

    ``quant`` (QuantStub) and ``dequant`` (DeQuantStub) mark the boundaries of
    the region handled by eager-mode quantization. After
    ``torch.quantization.prepare`` / ``convert``, the input is quantized on
    entry and the logits are turned back into float tensors on exit.

    Shapes for an ``(N, 1, 28, 28)`` input:
    conv1 + pool -> ``(N, 6, 12, 12)``, conv2 + pool -> ``(N, 16, 4, 4)``,
    flatten -> ``(N, 256)``, fc1 / fc2 / fc3 -> ``(N, 10)``.
    """

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized entry point
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)            # 16 channels x 4 x 4 after two conv+pool stages
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float exit point

    def forward(self, x):
        x = self.quant(x)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        # Flatten everything except the batch dimension. Like reshape (and
        # unlike view), flatten also copes with non-contiguous tensors, which
        # may appear once the model is quantized. Unlike reshape(-1, 16 * 4 * 4),
        # it keeps the batch size fixed. An input whose feature map is not
        # 16x4x4 therefore makes fc1 raise a shape error, instead of being
        # silently split into extra "samples".
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x

# 3. Training function
def train(model, loader, optimizer, criterion, epoch):
    """Run one training epoch of ``model`` over ``loader``.

    The model is put into train mode. Then, for every ``(inputs, targets)``
    batch: clear gradients, forward pass, compute ``criterion``, backpropagate,
    and take one optimizer step. ``epoch`` is accepted for call-site symmetry
    but is not used. Returns ``None``.
    """
    model.train()
    for inputs, targets in loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

# 4. Evaluation function
def evaluate(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch of ``loader``.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. A prediction is the argmax over dim 1 of the model
    output.

    The denominator counts the labels actually evaluated, not
    ``len(loader.dataset)``. The result therefore stays correct when the loader
    skips samples (``drop_last=True``, a subset sampler), and any iterable of
    ``(inputs, labels)`` batches works, even without a ``.dataset`` attribute.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        loader: iterable of ``(inputs, labels)`` batches.

    Returns:
        Accuracy as a float in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples (same as before).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            pred = model(x)
            correct += (pred.argmax(1) == y).sum().item()
            total += y.numel()
    return correct / total

# 5. Train the float32 model.
model_fp32 = QuantLeNet()
optimizer = torch.optim.Adam(model_fp32.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

print("Training float32 model...")
for epoch in range(3):  # just 3 epochs to keep the demo short
    train(model_fp32, train_loader, optimizer, criterion, epoch)
    acc = evaluate(model_fp32, test_loader)
    print(f"Epoch {epoch+1}: Accuracy = {acc:.4f}")

# 6. Quantization setup.
# prepare() inserts observer modules into the model it is given, so run it on a
# deep copy. That keeps model_fp32 a clean float baseline. Otherwise the
# "Float32" accuracy/size numbers below would be measured on the observed
# model: its state_dict carries observer buffers, and evaluating it keeps
# feeding test data into the observers.
model_fp32.eval()  # eager-mode post-training quantization expects eval mode
model_prepared = copy.deepcopy(model_fp32)
# Use the qconfig that matches the quantized engine selected at the top of the notebook.
model_prepared.qconfig = torch.quantization.get_default_qconfig(
    torch.backends.quantized.engine
)
torch.quantization.prepare(model_prepared, inplace=True)

# 7. Calibration: collect activation statistics from a handful of training batches.
# A single 64-image batch may give noisy ranges; a few batches are still cheap.
NUM_CALIBRATION_BATCHES = 10
print("Calibrating...")
with torch.no_grad():
    for x, _ in itertools.islice(train_loader, NUM_CALIBRATION_BATCHES):
        model_prepared(x)

# 8. Apply quantization: swap the observed modules for quantized ones.
model_int8 = torch.quantization.convert(model_prepared)

# 9. Accuracy comparison.
acc_fp32 = evaluate(model_fp32, test_loader)
acc_int8 = evaluate(model_int8, test_loader)

print("\n🎯 Accuracy Comparison")
print(f"Float32 model accuracy: {acc_fp32:.4f}")
print(f"INT8 quantized model accuracy: {acc_int8:.4f}")

# 10. Model size comparison
def get_size_of_model(model, path="temp.pth"):
    """Return the serialized size of ``model.state_dict()`` in KB (bytes / 1024).

    The state dict is written to ``path`` with ``torch.save``, measured, and
    the file is removed again. Removal also happens when saving or measuring
    fails, so no stale temporary file is left behind. Any existing file at
    ``path`` is overwritten and then deleted.
    """
    try:
        torch.save(model.state_dict(), path)
        return os.path.getsize(path) / 1024  # bytes -> KB
    finally:
        # Guard the removal so a save that failed before creating the file
        # does not mask the original exception with FileNotFoundError.
        if os.path.exists(path):
            os.remove(path)

# Report the serialized state_dict size of each model (float first, then int8).
print("\n📦 Model Size Comparison")
_fp32_kb = get_size_of_model(model_fp32)
print(f"Float32 model size: {_fp32_kb:.2f} KB")
_int8_kb = get_size_of_model(model_int8)
print(f"INT8 quantized model size: {_int8_kb:.2f} KB")

Training float32 model...
Epoch 1: Accuracy = 0.9730
Epoch 2: Accuracy = 0.9764
Epoch 3: Accuracy = 0.9839
Calibrating...

🎯 Accuracy Comparison
Float32 model accuracy: 0.9839
INT8 quantized model accuracy: 0.9839

📦 Model Size Comparison
Float32 model size: 231.19 KB
INT8 quantized model size: 50.76 KB
