# Importing necessary libraries

In [33]:
import os, time, math, copy
from pathlib import Path
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from sklearn.cluster import KMeans
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping

# Gloab Configuration

In [34]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CPU = torch.device("cpu")
BATCH_SIZE = 64
EPOCHS = 5
LEARNING_RATE = 0.001

torch.backends.quantized.engine = "fbgemm"
torch.manual_seed(7); np.random.seed(7)

# Data Loading and Handeling

In [35]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)

test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True
)


# Fully connected and convolution NN architecture

In [36]:
class FCNN(nn.Module):
    def __init__(self):
        super(FCNN, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.drop1(self.relu1(self.fc1(x)))
        x = self.drop2(self.relu2(self.fc2(x)))
        lg = self.fc3(x)
        return lg

class FCNN_QAT(nn.Module):
    def __init__(self):
        super(FCNN_QAT, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(256, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.flatten(x)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        logits = self.fc3(x)
        x = self.dequant(logits)
        return x

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.drop3 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.drop3(self.relu3(self.fc1(x)))
        lg = self.fc2(x)
        return lg

class CNN_QAT(nn.Module):
    def __init__(self):
        super(CNN_QAT, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        logits = self.fc2(x)
        x = self.dequant(logits)
        return x

# Training and evaluation functions

In [37]:
def train_model(model, train_loader, epochs=EPOCHS, fine_tune=False):
    if fine_tune:
        epochs = 3
        print(f"Fine-tuning {model.__class__.__name__}")
    else:
        print(f"Training {model.__class__.__name__}")
    model.to(DEVICE)
    loss_func = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()
    for epoch in range(epochs):
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f'Epoch [{epoch+1}/{epochs}]\t | Loss: {loss.item():.4f}')

def evaluate_model(model, test_loader):
    model.to(DEVICE)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

# Helper functions

In [38]:
def _kmeans_quantize_tensor(t: torch.Tensor, n_clusters: int) -> torch.Tensor:
    """K-Means on a tensor with safe k selection."""
    arr = t.detach().cpu().view(-1, 1).numpy()
    N = arr.shape[0]
    if N < 2:
        return t
    uniq = np.unique(arr).size
    k_eff = int(max(1, min(n_clusters, N, uniq)))
    if k_eff == 1:
        mean_val = float(arr.mean())
        return torch.full_like(t, mean_val)
    km = KMeans(n_clusters=k_eff, random_state=0, n_init=10)
    labels = km.fit_predict(arr)
    cents  = km.cluster_centers_.astype(np.float32)
    q = torch.from_numpy(cents[labels].reshape(t.shape))
    return q.to(t.dtype)

def apply_kmeans_quantization(model, n_clusters=16):
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            W = module.weight.data
            module.weight.data = _kmeans_quantize_tensor(W, n_clusters).to(W.device)
    return model

def apply_linear_quantization(model, bits=4):
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            W = module.weight.data
            v_max, v_min = W.max(), W.min()
            if (v_max - v_min).abs() < 1e-12:
                module.weight.data = torch.full_like(W, W.mean())
                continue
            amax = W.abs().max()
            if amax < 1e-12:
                module.weight.data = torch.zeros_like(W)
                continue
            scale = amax / (2**(bits-1) - 1)
            q = torch.round(W / scale).clamp(-(2**(bits-1)), 2**(bits-1)-1)
            Wdq = q * scale
            module.weight.data = Wdq.to(W.device)
    return model

# Helper functions for comparision

In [39]:
def model_size_bytes_from_state_dict(sd, tmp_path):
    torch.save(sd, tmp_path)
    n = Path(tmp_path).stat().st_size
    return n

def param_count(model):
    return sum(p.numel() for p in model.parameters())

def latency_ms(model, device=torch.device("cpu"), runs=50, warmup=10):
    model = copy.deepcopy(model)
    model.to(device).eval()
    x = torch.randn(1,1,28,28, device=device)
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    with torch.no_grad():
        for _ in range(runs):
            _ = model(x)
    if device.type == "cuda":
        torch.cuda.synchronize()
    return (time.time()-t0)/runs * 1000.0

def theoretical_kmeans_size_bits(model, n_clusters=16):
    total_bits = 0
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            t = module.weight.detach().cpu().view(-1).numpy()
            N = t.size
            if N == 0:
                continue
            uniq = np.unique(t).size
            k_eff = int(max(1, min(n_clusters, N, uniq)))
            codebook_bits = k_eff * 32
            index_bits = 0 if k_eff == 1 else math.ceil(math.log2(k_eff)) * N
            total_bits += codebook_bits + index_bits
    return total_bits

def theoretical_linear_size_bits(model, bits=4):
    total_bits = 0
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            total_bits += module.weight.numel() * bits
    return total_bits

In [40]:
def run_qat(model_class, original_model_path, model_name):
    """
    FX-QAT flow, also return quantized model + file size.
    """
    qat_device = torch.device("cpu")
    model = model_class().to(qat_device)
    model.load_state_dict(torch.load(original_model_path, map_location=qat_device))
    model.eval()

    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    example_inputs = (next(iter(train_loader))[0],)  # a batch of images

    prepared_model = prepare_qat_fx(model, qconfig_mapping, example_inputs)

    print(f"[QAT] Fine-tuning {model_name} on CPU...")
    global DEVICE
    old_device = DEVICE
    DEVICE = torch.device("cpu")  # ensure train_model uses CPU
    prepared_model.train()
    train_model(prepared_model, train_loader, fine_tune=True)
    DEVICE = old_device

    prepared_model.eval()
    quantized_model = convert_fx(prepared_model)

    # Evaluate and measure size
    old_device = DEVICE
    DEVICE = torch.device("cpu")
    accuracy = evaluate_model(quantized_model, test_loader)
    DEVICE = old_device

    size_bytes = model_size_bytes_from_state_dict(quantized_model.state_dict(), f"{model_name}_qat_int8_sd.pth")
    return quantized_model, accuracy, size_bytes

In [41]:
def train_original_models():
    print("Training Original Models")
    fcnn_original = FCNN()
    train_model(fcnn_original, train_loader)
    torch.save(fcnn_original.state_dict(), "FCNN_original.pth")
    original_fcnn_accuracy = evaluate_model(fcnn_original, test_loader)
    print(f"Original FCNN Test Accuracy: {original_fcnn_accuracy:.2f}%\n")

    cnn_original = CNN()
    train_model(cnn_original, train_loader)
    torch.save(cnn_original.state_dict(), "CNN_original.pth")
    original_cnn_accuracy = evaluate_model(cnn_original, test_loader)
    print(f"Original CNN Test Accuracy: {original_cnn_accuracy:.2f}%\n")
    return original_fcnn_accuracy, original_cnn_accuracy

# Main Function

In [42]:
if __name__ == '__main__':
    # 1) Train baselines
    original_fcnn_accuracy, original_cnn_accuracy = train_original_models()

    # Reload FP32 models on CPU for standardized measurements
    fcnn_fp32 = FCNN().to(CPU)
    fcnn_fp32.load_state_dict(torch.load("FCNN_original.pth", map_location=CPU))
    cnn_fp32 = CNN().to(CPU)
    cnn_fp32.load_state_dict(torch.load("CNN_original.pth", map_location=CPU))

    # Params & measured state-dict sizes
    fcnn_params = param_count(fcnn_fp32); cnn_params = param_count(cnn_fp32)
    fcnn_fp32_size = model_size_bytes_from_state_dict(fcnn_fp32.state_dict(), "FCNN_fp32_sd.pth")
    cnn_fp32_size  = model_size_bytes_from_state_dict(cnn_fp32.state_dict(),  "CNN_fp32_sd.pth")

    fcnn_lat_cpu = latency_ms(fcnn_fp32, device=CPU)
    cnn_lat_cpu  = latency_ms(cnn_fp32,  device=CPU)

    # 2) PTQ: K-Means
    print("\nRunning K-Means Weight Quantization (Post-Training)")
    fcnn_kmeans = copy.deepcopy(fcnn_fp32)
    fcnn_kmeans = apply_kmeans_quantization(fcnn_kmeans, n_clusters=16)
    kmeans_fcnn_accuracy = evaluate_model(fcnn_kmeans, test_loader)

    cnn_kmeans = copy.deepcopy(cnn_fp32)
    cnn_kmeans = apply_kmeans_quantization(cnn_kmeans, n_clusters=16)
    kmeans_cnn_accuracy = evaluate_model(cnn_kmeans, test_loader)

    fcnn_km_size = model_size_bytes_from_state_dict(fcnn_kmeans.state_dict(), "FCNN_kmeans_sd.pth")
    cnn_km_size  = model_size_bytes_from_state_dict(cnn_kmeans.state_dict(),  "CNN_kmeans_sd.pth")

    fcnn_km_bits = theoretical_kmeans_size_bits(fcnn_fp32, n_clusters=16)
    cnn_km_bits  = theoretical_kmeans_size_bits(cnn_fp32,  n_clusters=16)

    # 3) PTQ: Linear (4-bit symmetric)
    print("\nRunning Linear Quantization (Post-Training)")
    fcnn_linear = copy.deepcopy(fcnn_fp32)
    fcnn_linear = apply_linear_quantization(fcnn_linear, bits=4)
    linear_fcnn_accuracy = evaluate_model(fcnn_linear, test_loader)

    cnn_linear = copy.deepcopy(cnn_fp32)
    cnn_linear = apply_linear_quantization(cnn_linear, bits=4)
    linear_cnn_accuracy = evaluate_model(cnn_linear, test_loader)

    fcnn_lq_size = model_size_bytes_from_state_dict(fcnn_linear.state_dict(), "FCNN_linear4_sd.pth")
    cnn_lq_size  = model_size_bytes_from_state_dict(cnn_linear.state_dict(),  "CNN_linear4_sd.pth")

    fcnn_lq_bits = theoretical_linear_size_bits(fcnn_fp32, bits=4)
    cnn_lq_bits  = theoretical_linear_size_bits(cnn_fp32,  bits=4)

    # 4) QAT (FX)
    print("\nRunning Quantization Aware Training (QAT)")
    fcnn_q_model, qat_fcnn_accuracy, fcnn_q_size = run_qat(FCNN_QAT, "FCNN_original.pth", "FCNN")
    cnn_q_model,  qat_cnn_accuracy,  cnn_q_size  = run_qat(CNN_QAT,  "CNN_original.pth",  "CNN")

    # 5) Results table
    rows = []
    def add_row(name, method, model_obj, acc, size_bytes, theo_bits=None, lat_cpu=None, params=None):
        rows.append(dict(
            model=name,
            method=method,
            acc_pct=round(float(acc), 2),
            params=int(params if params is not None else param_count(model_obj)),
            file_mb_measured=round(size_bytes/(1024**2), 3) if size_bytes is not None else None,
            theoretical_mb=round((theo_bits/8)/(1024**2), 3) if theo_bits is not None else None,
            latency_ms_cpu=round(lat_cpu, 3) if lat_cpu is not None else None
        ))

    add_row("FCNN", "FP32 (baseline)", fcnn_fp32, original_fcnn_accuracy, fcnn_fp32_size,
            theo_bits=None, lat_cpu=fcnn_lat_cpu, params=fcnn_params)
    add_row("CNN",  "FP32 (baseline)", cnn_fp32,  original_cnn_accuracy,  cnn_fp32_size,
            theo_bits=None, lat_cpu=cnn_lat_cpu,  params=cnn_params)

    add_row("FCNN", "PTQ KMeans(K=16)", fcnn_kmeans, kmeans_fcnn_accuracy, fcnn_km_size,
            theo_bits=fcnn_km_bits, lat_cpu=latency_ms(fcnn_kmeans, CPU), params=fcnn_params)
    add_row("CNN",  "PTQ KMeans(K=16)", cnn_kmeans,  kmeans_cnn_accuracy,  cnn_km_size,
            theo_bits=cnn_km_bits,  lat_cpu=latency_ms(cnn_kmeans,  CPU), params=cnn_params)

    add_row("FCNN", "PTQ Linear(4-bit sym)", fcnn_linear, linear_fcnn_accuracy, fcnn_lq_size,
            theo_bits=fcnn_lq_bits, lat_cpu=latency_ms(fcnn_linear, CPU), params=fcnn_params)
    add_row("CNN",  "PTQ Linear(4-bit sym)", cnn_linear,  linear_cnn_accuracy,  cnn_lq_size,
            theo_bits=cnn_lq_bits,  lat_cpu=latency_ms(cnn_linear,  CPU), params=cnn_params)

    add_row("FCNN", "QAT INT8 (FX)", fcnn_q_model, qat_fcnn_accuracy, fcnn_q_size,
            theo_bits=None, lat_cpu=latency_ms(fcnn_q_model, CPU), params=fcnn_params)
    add_row("CNN",  "QAT INT8 (FX)", cnn_q_model,  qat_cnn_accuracy,  cnn_q_size,
            theo_bits=None, lat_cpu=latency_ms(cnn_q_model,  CPU), params=cnn_params)

    df = pd.DataFrame(rows, columns=[
        "model","method","acc_pct","params","file_mb_measured","theoretical_mb","latency_ms_cpu"
    ])
    print("\n=== RESULTS (before/after) ===")
    print(df.to_string(index=False))
    df.to_csv("quant_results.csv", index=False)
    print("\nSaved: quant_results.csv")


Training Original Models
Training FCNN
Epoch [1/5]	 | Loss: 0.0143
Epoch [2/5]	 | Loss: 0.0099
Epoch [3/5]	 | Loss: 0.0056
Epoch [4/5]	 | Loss: 0.2504
Epoch [5/5]	 | Loss: 0.0085
Original FCNN Test Accuracy: 98.10%

Training CNN
Epoch [1/5]	 | Loss: 0.0089
Epoch [2/5]	 | Loss: 0.0040
Epoch [3/5]	 | Loss: 0.0880
Epoch [4/5]	 | Loss: 0.0043
Epoch [5/5]	 | Loss: 0.0008
Original CNN Test Accuracy: 99.10%


Running K-Means Weight Quantization (Post-Training)

Running Linear Quantization (Post-Training)

Running Quantization Aware Training (QAT)
[QAT] Fine-tuning FCNN on CPU...
Fine-tuning GraphModule


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepared_model = prepare_qat_fx(model, qconfig_mapping, example_inputs)


Epoch [1/3]	 | Loss: 0.0383
Epoch [2/3]	 | Loss: 0.0244
Epoch [3/3]	 | Loss: 0.0014


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = convert_fx(prepared_model)


[QAT] Fine-tuning CNN on CPU...
Fine-tuning GraphModule
Epoch [1/3]	 | Loss: 0.0129
Epoch [2/3]	 | Loss: 0.0003
Epoch [3/3]	 | Loss: 0.0443

=== RESULTS (before/after) ===
model                method  acc_pct  params  file_mb_measured  theoretical_mb  latency_ms_cpu
 FCNN       FP32 (baseline)    98.10  535818             2.047             NaN           0.181
  CNN       FP32 (baseline)    99.10  421642             1.612             NaN           0.678
 FCNN      PTQ KMeans(K=16)    98.06  535818             2.047           0.255           0.194
  CNN      PTQ KMeans(K=16)    99.09  421642             1.612           0.201           0.713
 FCNN PTQ Linear(4-bit sym)    98.13  535818             2.047           0.255           0.186
  CNN PTQ Linear(4-bit sym)    98.99  421642             1.612           0.201           0.694
 FCNN         QAT INT8 (FX)    97.75  535818             0.532             NaN           0.234
  CNN         QAT INT8 (FX)    99.07  421642             0.415      