<a href="https://colab.research.google.com/github/AlexKalll/Supervised-ML-Models/blob/main/Model_optimization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# ============================================================
# SECTION 0: Imports & Global Configuration
# ============================================================

import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.quantization as tq

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

# Explicitly set the quantization engine to QNNPACK
if 'qnnpack' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'qnnpack'

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

BATCH_SIZE = 64
EPOCHS = 2
KD_EPOCHS = 2
LR = 1e-3


# ============================================================
# SECTION 1: Data Loading (MNIST)
# ============================================================

def get_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader


# ============================================================
# SECTION 2: Baseline Feed-Forward Network
# ============================================================

class BaselineFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x


# ============================================================
# SECTION 3: Mixture of Experts (MoE)
# ============================================================

class MoELayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=4):
        super().__init__()
        self.num_experts = num_experts

        # Helper stubs to create a "Float Island"
        self.dequant = tq.DeQuantStub()
        self.quant = tq.QuantStub()

        self.gate = nn.Linear(input_dim, num_experts)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.ReLU()
            ) for _ in range(num_experts)
        ])

    def forward(self, x):
        # 1. Dequantize input to float (Float Island begins)
        x_float = self.dequant(x)

        # 2. Gate (Float op)
        gate_logits = self.gate(x_float)
        gate_probs = F.softmax(gate_logits, dim=1)

        # 3. Experts (Float op)
        expert_outputs = torch.stack([expert(x_float) for expert in self.experts], dim=1)

        # 4. Weighted Sum (Float op)
        output_float = torch.sum(gate_probs.unsqueeze(-1) * expert_outputs, dim=1)

        # 5. Quantize output back to int8 (Float Island ends)
        output_quantized = self.quant(output_float)

        return output_quantized


class MoEFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU()
        self.moe = MoELayer(256, 128, num_experts=4)
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.moe(x) # Handles internal dequant/quant
        x = self.fc3(x)
        x = self.dequant(x)
        return x


# ============================================================
# SECTION 4: Training & Evaluation Utils
# ============================================================

def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0
    for x, y in tqdm(dataloader, desc="Training", leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate(model, dataloader, device=None):
    if device is None: device = DEVICE
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

def measure_latency(model, dataloader, num_batches=50, device=None):
    if device is None: device = DEVICE
    model.to(device)
    model.eval()
    # Warmup
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= 5: break
    # Measure
    start = time.time()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= num_batches: break
    end = time.time()
    return (end - start) / num_batches

def count_parameters(model):
    total_params = 0
    for name, parameter in model.state_dict().items():
        if isinstance(parameter, torch.Tensor):
            is_weight_or_bias = ("weight" in name or "bias" in name)
            is_not_metadata = ("scale" not in name and "zero_point" not in name)
            if is_weight_or_bias and is_not_metadata:
                total_params += parameter.numel()
    return total_params

def get_model_size_mb(model):
    torch.save(model.state_dict(), "temp.p")
    size_mb = os.path.getsize("temp.p") / (1024 * 1024)
    os.remove("temp.p")
    return size_mb


# ============================================================
# SECTION 5: PTQ & QAT Helpers
# ============================================================

def apply_ptq_moe(model, calibration_loader):
    """Post-Training Quantization for MoE"""
    model.eval()
    model.qconfig = tq.get_default_qconfig("qnnpack")
    model_q = model.to("cpu")

    # Disable quantization for Float Island
    model_q.moe.gate.qconfig = None
    for expert in model_q.moe.experts:
        expert[0].qconfig = None
        expert[1].qconfig = None

    tq.fuse_modules(model_q, [["fc1", "relu1"]], inplace=True)
    tq.prepare(model_q, inplace=True)

    with torch.no_grad():
        for x, _ in calibration_loader:
            model_q(x.to("cpu"))

    tq.convert(model_q, inplace=True)
    return model_q

def apply_qat_moe(train_loader, test_loader):
    """Quantization Aware Training for MoE"""
    print("\nPreparing MoE for QAT...")

    # 1. Initialize a fresh model
    qat_model = MoEFFN().to(DEVICE)

    # 2. Set QAT Configuration
    qat_model.qconfig = tq.get_default_qat_qconfig("qnnpack")

    # 3. Disable quantization for the Float Island (Experts & Router)
    # We want these to remain high-precision during training and inference
    qat_model.moe.gate.qconfig = None
    for expert in qat_model.moe.experts:
        expert[0].qconfig = None
        expert[1].qconfig = None

    # 4. Fuse Modules (fc1 + relu1)
    tq.fuse_modules(qat_model, [["fc1", "relu1"]], inplace=True)

    # 5. Prepare for QAT (Inserts fake quantization nodes)
    tq.prepare_qat(qat_model, inplace=True)

    # 6. Train (Fine-tune with fake quantization noise)
    optimizer = optim.Adam(qat_model.parameters(), lr=1e-4) # Lower LR for QAT
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        qat_model.train() # Important: QAT happens in train mode
        total_loss = 0
        for x, y in tqdm(train_loader, desc=f"QAT Epoch {epoch+1}", leave=False):
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            logits = qat_model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"QAT Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}")

    # 7. Convert to actual INT8 model (Move to CPU first)
    qat_model.eval()
    qat_model = qat_model.to("cpu")
    tq.convert(qat_model, inplace=True)

    return qat_model


# ============================================================
# SECTION 6: Knowledge Distillation
# ============================================================

def train_kd(student, teacher, dataloader, optimizer, alpha=0.7, temperature=4.0):
    student.train()
    teacher.eval()
    total_loss = 0
    for x, y in tqdm(dataloader, desc="KD Training", leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)
        with torch.no_grad():
            teacher_logits = teacher(x)
        student_logits = student(x)

        soft_loss = F.kl_div(
            F.log_softmax(student_logits/temperature, dim=1),
            F.softmax(teacher_logits/temperature, dim=1),
            reduction="batchmean"
        ) * (temperature**2)
        hard_loss = F.cross_entropy(student_logits, y)

        loss = alpha*soft_loss + (1-alpha)*hard_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)


# ============================================================
# SECTION 7: Main Experiment Pipeline
# ============================================================

def main():
    train_loader, test_loader = get_dataloaders()

    # Calibration loader (subset of train data)
    calibration_subset = Subset(train_loader.dataset, range(1024))
    calibration_loader = DataLoader(calibration_subset, batch_size=BATCH_SIZE)

    criterion = nn.CrossEntropyLoss()

    # ---------------------------------------------------------
    # 1. Baseline Training
    # ---------------------------------------------------------
    print("\n=== 1. Training Baseline ===")
    baseline = BaselineFFN().to(DEVICE)
    opt = optim.Adam(baseline.parameters(), lr=LR)

    for epoch in range(EPOCHS):
        loss = train(baseline, train_loader, opt, criterion)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}")

    base_acc = evaluate(baseline, test_loader)
    base_lat = measure_latency(baseline, test_loader)


    # ---------------------------------------------------------
    # 2. MoE Training (Scratch)
    # ---------------------------------------------------------
    print("\n=== 2. Training MoE ===")
    moe = MoEFFN().to(DEVICE)
    opt_moe = optim.Adam(moe.parameters(), lr=LR)

    for epoch in range(EPOCHS):
        loss = train(moe, train_loader, opt_moe, criterion)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}")

    moe_acc = evaluate(moe, test_loader)
    moe_lat = measure_latency(moe, test_loader)

    # ---------------------------------------------------------
    # 3. MoE + PTQ
    # ---------------------------------------------------------
    print("\n=== 3. Applying PTQ to MoE ===")
    moe_ptq_model = MoEFFN()
    moe_ptq_model.load_state_dict(moe.state_dict())
    moe_ptq_model = apply_ptq_moe(moe_ptq_model, calibration_loader)

    moe_ptq_acc = evaluate(moe_ptq_model, test_loader, device="cpu")
    moe_ptq_lat = measure_latency(moe_ptq_model, test_loader, device="cpu")

    # ---------------------------------------------------------
    # 4. MoE + QAT (New!)
    # ---------------------------------------------------------
    print("\n=== 4. Training MoE with QAT ===")
    # QAT trains a fresh model from scratch (or fine-tunes)
    moe_qat_model = apply_qat_moe(train_loader, test_loader)

    moe_qat_acc = evaluate(moe_qat_model, test_loader, device="cpu")
    moe_qat_lat = measure_latency(moe_qat_model, test_loader, device="cpu")

    # ---------------------------------------------------------
    # 5. MoE + KD Training
    # ---------------------------------------------------------
    print("\n=== 5. Training MoE with Knowledge Distillation ===")
    kd_student = MoEFFN().to(DEVICE)
    opt_kd = optim.Adam(kd_student.parameters(), lr=LR)

    for epoch in range(KD_EPOCHS):
        loss = train_kd(kd_student, baseline, train_loader, opt_kd)
        print(f"KD Epoch {epoch+1}: Loss={loss:.4f}")

    kd_acc = evaluate(kd_student, test_loader)
    kd_lat = measure_latency(kd_student, test_loader)

    # ---------------------------------------------------------
    # 6. MoE + KD + PTQ
    # ---------------------------------------------------------
    print("\n=== 6. Applying PTQ to MoE (KD) ===")
    kd_ptq_model = MoEFFN()
    kd_ptq_model.load_state_dict(kd_student.state_dict())
    kd_ptq_model = apply_ptq_moe(kd_ptq_model, calibration_loader)

    kd_ptq_acc = evaluate(kd_ptq_model, test_loader, device="cpu")
    kd_ptq_lat = measure_latency(kd_ptq_model, test_loader, device="cpu")


    # ---------------------------------------------------------
    # FINAL RESULTS
    # ---------------------------------------------------------
    print("\n" + "="*85)
    print(f"{'Model':<20} | {'Acc':<8} | {'Lat(ms)':<8} | {'Params':<12} | {'Size(MB)':<10}")
    print("-" * 85)

    def print_row(name, acc, lat, model):
        size = get_model_size_mb(model)
        params = count_parameters(model)
        print(f"{name:<20} | {acc:.4f}   | {lat*1000:.2f}     | {params:<12,} | {size:.2f}")

    print_row("Baseline", base_acc, base_lat, baseline)
    print_row("MoE", moe_acc, moe_lat, moe)
    print_row("MoE + PTQ", moe_ptq_acc, moe_ptq_lat, moe_ptq_model)
    print_row("MoE + QAT", moe_qat_acc, moe_qat_lat, moe_qat_model)
    print_row("MoE + KD", kd_acc, kd_lat, kd_student)
    print_row("MoE + KD + PTQ", kd_ptq_acc, kd_ptq_lat, kd_ptq_model)

    print("="*85)

if __name__ == "__main__":
    main()


=== 1. Training Baseline ===




Epoch 1: Loss=0.2278




Epoch 2: Loss=0.0912

=== 2. Training MoE ===




Epoch 1: Loss=0.2370




Epoch 2: Loss=0.0949

=== 3. Applying PTQ to MoE ===

=== 4. Training MoE with QAT ===

Preparing MoE for QAT...




QAT Epoch 1: Loss=0.6094




QAT Epoch 2: Loss=0.1937

=== 5. Training MoE with Knowledge Distillation ===




KD Epoch 1: Loss=0.6862




KD Epoch 2: Loss=0.1100

=== 6. Applying PTQ to MoE (KD) ===

Model                | Acc      | Lat(ms)  | Params       | Size(MB)  
-------------------------------------------------------------------------------------
Baseline             | 0.9712   | 14.79     | 235,146      | 0.90
MoE                  | 0.9703   | 17.99     | 334,862      | 1.28
MoE + PTQ            | 0.9698   | 17.71     | 132,612      | 0.71
MoE + QAT            | 0.9509   | 18.16     | 132,612      | 0.71
MoE + KD             | 0.9712   | 15.17     | 334,862      | 1.28
MoE + KD + PTQ       | 0.9704   | 18.90     | 132,612      | 0.71


In [3]:
# ============================================================
# SECTION 0: Imports & Global Configuration
# ============================================================

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import torch.quantization as tq

import warnings
warnings.filterwarnings("ignore")

if torch.cuda.is_available():
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")

BATCH_SIZE = 64
EPOCHS = 5
KD_EPOCHS = 5
LR = 1e-3


# ============================================================
# SECTION 1: Data Loading (MNIST)
# ============================================================

def get_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_dataset = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform
    )

    test_dataset = datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    return train_loader, test_loader


# ============================================================
# SECTION 2: Baseline Feed-Forward Network
# ============================================================

class BaselineFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub() # Added for PTQ
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU() # Changed from F.relu for fusibility
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU() # Changed from F.relu for fusibility
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub() # Added for PTQ

    def forward(self, x):
        x = self.quant(x) # Added for PTQ
        x = x.view(x.size(0), -1)  # flatten
        x = self.relu1(self.fc1(x)) # Use nn.ReLU instance
        x = self.relu2(self.fc2(x)) # Use nn.ReLU instance
        x = self.fc3(x)
        x = self.dequant(x) # Added for PTQ
        return x


# ============================================================
# SECTION 3: Mixture of Experts (MoE)
# ============================================================

class MoELayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=4):
        super().__init__()
        self.num_experts = num_experts

        # Quantization stubs for the MoE Layer itself to define its boundaries
        self.dequant = tq.DeQuantStub() # Dequantize input when entering MoELayer
        self.quant = tq.QuantStub() # Quantize output when exiting MoELayer

        # Router (gate) - these layers will operate on float inputs after dequantization
        self.gate = nn.Linear(input_dim, num_experts)

        # Experts - these layers will operate on float inputs after dequantization
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.ReLU()
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        # Assume input 'x' is a quantized tensor coming from a previous layer (e.g., fc1 in MoEFFN)
        # Dequantize the input for internal float operations within MoELayer
        x_float = self.dequant(x)

        # Routing probabilities (these operations will happen on float tensors)
        gate_logits = self.gate(x_float)
        gate_probs = F.softmax(gate_logits, dim=1)

        # Expert outputs (each expert takes float input and produces float output)
        expert_outputs_list = []
        for expert in self.experts:
            expert_outputs_list.append(expert(x_float))

        expert_outputs = torch.stack(
            expert_outputs_list,
            dim=1
        )

        # Weighted sum (these operations will happen on float tensors)
        output_float = torch.sum(
            gate_probs.unsqueeze(-1) * expert_outputs,
            dim=1
        )

        # Quantize the output of MoELayer before passing it to the next quantized layer (e.g., fc3 in MoEFFN)
        output_quantized = self.quant(output_float)

        return output_quantized


class MoEFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub() # Quantization stub for the model's input
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU() # Changed from F.relu for fusibility
        self.moe = MoELayer(256, 128, num_experts=4) # MoELayer itself handles quant/dequant internally
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub() # Dequantization stub for the model's output

    def forward(self, x):
        x = self.quant(x) # Quantize input
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x)) # fc1 output -> relu1 output (will be quantized after fusion)
        x = self.moe(x) # MoELayer takes quantized input, outputs quantized
        x = self.fc3(x) # fc3 takes quantized input, outputs quantized
        x = self.dequant(x) # Dequantize final output
        return x


# ============================================================
# SECTION 4: Standard Training & Evaluation
# ============================================================

def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0

    for x, y in tqdm(dataloader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


def evaluate(model, dataloader, device=None):
    if device is None:
        device = DEVICE
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total


# ============================================================
# SECTION 5: Latency Measurement
# ============================================================

def measure_latency(model, dataloader, num_batches=50, device=None):
    if device is None:
        device = DEVICE
    model.eval()

    # Warm-up
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= 5:
                break

    start = time.time()

    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= num_batches:
                break

    end = time.time()
    return (end - start) / num_batches

# ============================================================
# SECTION 5.5: Model Size & Parameter Counting
# ============================================================

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def estimate_model_size(model, quantized=False):
    """
    Rough model size estimation:
    FP32 -> 4 bytes per parameter
    INT8 -> 1 byte per parameter
    """
    num_params = count_parameters(model)
    bytes_per_param = 1 if quantized else 4
    size_mb = (num_params * bytes_per_param) / (1024 ** 2)
    return size_mb


# ============================================================
# SECTION 6: Post-Training Quantization (PTQ)
# ============================================================

def apply_ptq(model, calibration_loader):
    model.eval()
    model.qconfig = tq.get_default_qconfig("qnnpack") # Changed from fbgemm

    # Move model to CPU before preparing and converting for FBGEMM,
    # as FBGEMM typically uses CPU-optimized kernels for x86.
    model_for_quant = model.to(torch.device("cpu"))

    # Explicitly disable quantization for the internal modules of MoELayer
    # as it's designed to operate as a float island internally.
    model_for_quant.moe.gate.qconfig = None
    for expert_seq in model_for_quant.moe.experts:
        expert_seq[0].qconfig = None  # nn.Linear in expert
        expert_seq[1].qconfig = None  # nn.ReLU in expert

    # Fuse modules for better quantization performance.
    # For MoEFFN's fc1 and relu1:
    tq.fuse_modules(
        model_for_quant,
        [["fc1", "relu1"]],
        inplace=True
    )
    # MoELayer's internal layers (gate, experts) are now designed to operate in float
    # due to explicit dequant/quant stubs in MoELayer, so no fusion is needed for them.

    tq.prepare(model_for_quant, inplace=True);

    # Calibration on CPU with CPU tensors
    with torch.no_grad():
        for x, _ in calibration_loader:
            model_for_quant(x.to(torch.device("cpu")))

    tq.convert(model_for_quant, inplace=True)
    return model_for_quant # Return the CPU-quantized model


# ============================================================
# SECTION 7: Knowledge Distillation Training
# ============================================================

def train_kd(student, teacher, dataloader, optimizer, alpha=0.7, temperature=4.0):
    student.train()
    teacher.eval()

    total_loss = 0

    for x, y in tqdm(dataloader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.no_grad():
            teacher_logits = teacher(x)

        student_logits = student(x)

        hard_loss = F.cross_entropy(student_logits, y)

        soft_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction="batchmean"
        ) * (temperature ** 2)

        loss = alpha * soft_loss + (1 - alpha) * hard_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


# ============================================================
# SECTION 8: Main Experiment Pipeline
# ============================================================

def main():
    # Explicitly set the quantization engine for CPU
    torch.backends.quantized.engine = 'qnnpack' # Changed from fbgemm

    train_loader, test_loader = get_dataloaders()

    # Define calibration_loader early, as it's used in multiple sections.
    calibration_subset = Subset(train_loader.dataset, range(1024))
    calibration_loader = DataLoader(calibration_subset, batch_size=BATCH_SIZE)

    # ---- Baseline ----
    print("\n=== Training Baseline ===")
    baseline = BaselineFFN().to(DEVICE)
    opt = optim.Adam(baseline.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        loss = train(baseline, train_loader, opt, criterion)
        acc = evaluate(baseline, test_loader, device=DEVICE)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")

    base_acc = evaluate(baseline, test_loader, device=DEVICE)
    base_lat = measure_latency(baseline, test_loader, device=DEVICE)

    # ---- MoE ----
    print("\n=== Training MoE ===")
    moe = MoEFFN().to(DEVICE)
    opt = optim.Adam(moe.parameters(), lr=LR)

    for epoch in range(EPOCHS):
        loss = train(moe, train_loader, opt, criterion)
        acc = evaluate(moe, test_loader, device=DEVICE)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")

    moe_acc = evaluate(moe, test_loader, device=DEVICE)
    moe_lat = measure_latency(moe, test_loader, device=DEVICE)

    # ---- Baseline + PTQ (PTQ-only control) ----
    print("\n=== Applying PTQ to Baseline ===")

    baseline_for_quant = BaselineFFN()
    baseline_for_quant.load_state_dict(baseline.state_dict())

    # Baseline does NOT need MoE-specific handling
    baseline_for_quant.qconfig = tq.get_default_qconfig("qnnpack") # Changed from fbgemm
    baseline_for_quant = baseline_for_quant.to(torch.device("cpu"))

    tq.fuse_modules(
        baseline_for_quant,
        [["fc1", "relu1"], ["fc2", "relu2"]], # Restored both fusions
        inplace=True
    )

    tq.prepare(baseline_for_quant, inplace=True)

    with torch.no_grad():
        for x, _ in calibration_loader:
            baseline_for_quant(x.to(torch.device("cpu")))

    tq.convert(baseline_for_quant, inplace=True)

    baseline_ptq_acc = evaluate(
        baseline_for_quant,
        test_loader,
        device=torch.device("cpu")
    )

    baseline_ptq_lat = measure_latency(
        baseline_for_quant,
        test_loader,
        device=torch.device("cpu")
    )


    # ---- Knowledge Distillation ----
    print("\n=== Training MoE with Knowledge Distillation ===")
    kd_student = MoEFFN().to(DEVICE)
    opt = optim.Adam(kd_student.parameters(), lr=LR)

    for epoch in range(KD_EPOCHS):
        loss = train_kd(kd_student, baseline, train_loader, opt)
        acc = evaluate(kd_student, test_loader, device=DEVICE)
        print(f"KD Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")

    kd_acc = evaluate(kd_student, test_loader, device=DEVICE)
    kd_lat = measure_latency(kd_student, test_loader, device=DEVICE)

    # ---- PTQ ----
    # calibration_subset = Subset(train_loader.dataset, range(1024)) # Moved up
    # calibration_loader = DataLoader(calibration_subset, batch_size=BATCH_SIZE) # Moved up

    # Create a new MoEFFN instance for quantization, as applying PTQ modifies the model in-place
    # and we want to quantize the KD student, not the original MoE model.
    # Also, move to CPU before applying PTQ and fusion.
    quant_kd_student = MoEFFN()
    quant_kd_student.load_state_dict(kd_student.state_dict())
    quant_moe = apply_ptq(quant_kd_student, calibration_loader)

    quant_acc = evaluate(quant_moe, test_loader, device=torch.device("cpu"))
    quant_lat = measure_latency(quant_moe, test_loader, device=torch.device("cpu"))


    print("\n===== FINAL RESULTS ===")

    print(f"Baseline      | Acc: {base_acc:.4f} | Lat: {base_lat:.6f} | "
          f"Params: {count_parameters(baseline):,} | "
          f"Size(MB): {estimate_model_size(baseline):.2f}")

    print(f"Baseline + PTQ| Acc: {baseline_ptq_acc:.4f} | Lat: {baseline_ptq_lat:.6f} | "
          f"Params: {count_parameters(baseline_for_quant):,} | "
          f"Size(MB): {estimate_model_size(baseline_for_quant, quantized=True):.2f}")

    print(f"MoE           | Acc: {moe_acc:.4f} | Lat: {moe_lat:.6f} | "
          f"Params: {count_parameters(moe):,} | "
          f"Size(MB): {estimate_model_size(moe):.2f}")

    print(f"MoE + KD      | Acc: {kd_acc:.4f} | Lat: {kd_lat:.6f} | "
          f"Params: {count_parameters(kd_student):,} | "
          f"Size(MB): {estimate_model_size(kd_student):.2f}")

    print(f"MoE + KD + PTQ| Acc: {quant_acc:.4f} | Lat: {quant_lat:.6f} | "
          f"Params: {count_parameters(quant_moe):,} | "
          f"Size(MB): {estimate_model_size(quant_moe, quantized=True):.2f}")


if __name__ == "__main__":
    main()


=== Training Baseline ===


100%|██████████| 938/938 [00:12<00:00, 72.86it/s]


Epoch 1: Loss=0.2881, Acc=0.9604


100%|██████████| 938/938 [00:13<00:00, 70.94it/s]


Epoch 2: Loss=0.1058, Acc=0.9704


100%|██████████| 938/938 [00:13<00:00, 68.30it/s]


Epoch 3: Loss=0.0706, Acc=0.9749


100%|██████████| 938/938 [00:14<00:00, 66.45it/s]


Epoch 4: Loss=0.0524, Acc=0.9791


100%|██████████| 938/938 [00:13<00:00, 67.53it/s]


Epoch 5: Loss=0.0411, Acc=0.9775

=== Training MoE ===


100%|██████████| 938/938 [00:15<00:00, 59.80it/s]


Epoch 1: Loss=0.2725, Acc=0.9611


100%|██████████| 938/938 [00:16<00:00, 57.78it/s]


Epoch 2: Loss=0.1028, Acc=0.9724


100%|██████████| 938/938 [00:15<00:00, 59.31it/s]


Epoch 3: Loss=0.0705, Acc=0.9729


100%|██████████| 938/938 [00:16<00:00, 58.08it/s]


Epoch 4: Loss=0.0512, Acc=0.9731


100%|██████████| 938/938 [00:16<00:00, 57.84it/s]


Epoch 5: Loss=0.0395, Acc=0.9749

=== Applying PTQ to Baseline ===

=== Training MoE with Knowledge Distillation ===


100%|██████████| 938/938 [00:17<00:00, 52.99it/s]


KD Epoch 1: Loss=1.5981, Acc=0.9597


100%|██████████| 938/938 [00:17<00:00, 54.60it/s]


KD Epoch 2: Loss=0.2881, Acc=0.9713


100%|██████████| 938/938 [00:17<00:00, 52.77it/s]


KD Epoch 3: Loss=0.1462, Acc=0.9739


100%|██████████| 938/938 [00:17<00:00, 54.70it/s]


KD Epoch 4: Loss=0.1000, Acc=0.9759


100%|██████████| 938/938 [00:18<00:00, 51.82it/s]


KD Epoch 5: Loss=0.0783, Acc=0.9778

===== FINAL RESULTS ===
Baseline      | Acc: 0.9775 | Lat: 0.008784 | Params: 235,146 | Size(MB): 0.90
Baseline + PTQ| Acc: 0.9779 | Lat: 0.011063 | Params: 0 | Size(MB): 0.00
MoE           | Acc: 0.9749 | Lat: 0.008887 | Params: 334,862 | Size(MB): 1.28
MoE + KD      | Acc: 0.9778 | Lat: 0.009381 | Params: 334,862 | Size(MB): 1.28
MoE + KD + PTQ| Acc: 0.9777 | Lat: 0.014814 | Params: 132,612 | Size(MB): 0.13


In [None]:
# ============================================================
# SECTION 0: Imports & Global Configuration
# ============================================================

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import torch.quantization as tq

import warnings
warnings.filterwarnings("ignore")

if torch.cuda.is_available():
  DEVICE = torch.device("cuda")
else:
  DEVICE = torch.device("cpu")

BATCH_SIZE = 64
EPOCHS = 5
KD_EPOCHS = 5
LR = 1e-3


# ============================================================
# SECTION 1: Data Loading (MNIST)
# ============================================================

def get_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_dataset = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform
    )

    test_dataset = datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    return train_loader, test_loader


# ============================================================
# SECTION 2: Baseline Feed-Forward Network
# ============================================================

class BaselineFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU() # Changed from F.relu for fusibility
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU() # Changed from F.relu for fusibility
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten
        x = self.relu1(self.fc1(x)) # Use nn.ReLU instance
        x = self.relu2(self.fc2(x)) # Use nn.ReLU instance
        x = self.fc3(x)
        return x


# ============================================================
# SECTION 3: Mixture of Experts (MoE)
# ============================================================

class MoELayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=4):
        super().__init__()
        self.num_experts = num_experts

        # Quantization stubs for the MoE Layer itself to define its boundaries
        self.dequant = tq.DeQuantStub() # Dequantize input when entering MoELayer
        self.quant = tq.QuantStub() # Quantize output when exiting MoELayer

        # Router (gate) - these layers will operate on float inputs after dequantization
        self.gate = nn.Linear(input_dim, num_experts)

        # Experts - these layers will operate on float inputs after dequantization
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.ReLU()
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        # Assume input 'x' is a quantized tensor coming from a previous layer (e.g., fc1 in MoEFFN)
        # Dequantize the input for internal float operations within MoELayer
        x_float = self.dequant(x)

        # Routing probabilities (these operations will happen on float tensors)
        gate_logits = self.gate(x_float)
        gate_probs = F.softmax(gate_logits, dim=1)

        # Expert outputs (each expert takes float input and produces float output)
        expert_outputs_list = []
        for expert in self.experts:
            expert_outputs_list.append(expert(x_float))

        expert_outputs = torch.stack(
            expert_outputs_list,
            dim=1
        )

        # Weighted sum (these operations will happen on float tensors)
        output_float = torch.sum(
            gate_probs.unsqueeze(-1) * expert_outputs,
            dim=1
        )

        # Quantize the output of MoELayer before passing it to the next quantized layer (e.g., fc3 in MoEFFN)
        output_quantized = self.quant(output_float)

        return output_quantized


class MoEFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub() # Quantization stub for the model's input
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU() # Changed from F.relu for fusibility
        self.moe = MoELayer(256, 128, num_experts=4) # MoELayer itself handles quant/dequant internally
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub() # Dequantization stub for the model's output

    def forward(self, x):
        x = self.quant(x) # Quantize input
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x)) # fc1 output -> relu1 output (will be quantized after fusion)
        x = self.moe(x) # MoELayer takes quantized input, outputs quantized
        x = self.fc3(x) # fc3 takes quantized input, outputs quantized
        x = self.dequant(x) # Dequantize final output
        return x


# ============================================================
# SECTION 4: Standard Training & Evaluation
# ============================================================

def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0

    for x, y in tqdm(dataloader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


def evaluate(model, dataloader, device=None):
    if device is None:
        device = DEVICE
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total


# ============================================================
# SECTION 5: Latency Measurement
# ============================================================

def measure_latency(model, dataloader, num_batches=50, device=None):
    if device is None:
        device = DEVICE
    model.eval()

    # Warm-up
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= 5:
                break

    start = time.time()

    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= num_batches:
                break

    end = time.time()
    return (end - start) / num_batches


# ============================================================
# SECTION 6: Post-Training Quantization (PTQ)
# ============================================================

def apply_ptq(model, calibration_loader):
    model.eval()
    model.qconfig = tq.get_default_qconfig("qnnpack")

    # Move model to CPU before preparing and converting for QNNPACK,
    # as QNNPACK typically uses CPU-optimized kernels.
    model_for_quant = model.to(torch.device("cpu"))

    # Explicitly disable quantization for the internal modules of MoELayer
    # as it's designed to operate as a float island internally.
    model_for_quant.moe.gate.qconfig = None
    for expert_seq in model_for_quant.moe.experts:
        expert_seq[0].qconfig = None  # nn.Linear in expert
        expert_seq[1].qconfig = None  # nn.ReLU in expert

    # Fuse modules for better quantization performance.
    # For MoEFFN's fc1 and relu1:
    tq.fuse_modules(model_for_quant, [["fc1", "relu1"]], inplace=True)
    # MoELayer's internal layers (gate, experts) are now designed to operate in float
    # due to explicit dequant/quant stubs in MoELayer, so no fusion is needed for them.

    tq.prepare(model_for_quant, inplace=True)

    # Calibration on CPU with CPU tensors
    with torch.no_grad():
        for x, _ in calibration_loader:
            model_for_quant(x.to(torch.device("cpu")))

    tq.convert(model_for_quant, inplace=True)
    return model_for_quant # Return the CPU-quantized model


# ============================================================
# SECTION 7: Knowledge Distillation Training
# ============================================================

def train_kd(student, teacher, dataloader, optimizer, alpha=0.7, temperature=4.0):
    student.train()
    teacher.eval()

    total_loss = 0

    for x, y in tqdm(dataloader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.no_grad():
            teacher_logits = teacher(x)

        student_logits = student(x)

        hard_loss = F.cross_entropy(student_logits, y)

        soft_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction="batchmean"
        ) * (temperature ** 2)

        loss = alpha * soft_loss + (1 - alpha) * hard_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


# ============================================================
# SECTION 8: Main Experiment Pipeline
# ============================================================

def main():
    train_loader, test_loader = get_dataloaders()

    # ---- Baseline ----
    print("\n=== Training Baseline ===")
    baseline = BaselineFFN().to(DEVICE)
    opt = optim.Adam(baseline.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        loss = train(baseline, train_loader, opt, criterion)
        acc = evaluate(baseline, test_loader, device=DEVICE)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")

    base_acc = evaluate(baseline, test_loader, device=DEVICE)
    base_lat = measure_latency(baseline, test_loader, device=DEVICE)

    # ---- MoE ----
    print("\n=== Training MoE ===")
    moe = MoEFFN().to(DEVICE)
    opt = optim.Adam(moe.parameters(), lr=LR)

    for epoch in range(EPOCHS):
        loss = train(moe, train_loader, opt, criterion)
        acc = evaluate(moe, test_loader, device=DEVICE)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")

    moe_acc = evaluate(moe, test_loader, device=DEVICE)
    moe_lat = measure_latency(moe, test_loader, device=DEVICE)

    # ---- Knowledge Distillation ----
    print("\n=== Training MoE with Knowledge Distillation ===")
    kd_student = MoEFFN().to(DEVICE)
    opt = optim.Adam(kd_student.parameters(), lr=LR)

    for epoch in range(KD_EPOCHS):
        loss = train_kd(kd_student, baseline, train_loader, opt)
        acc = evaluate(kd_student, test_loader, device=DEVICE)
        print(f"KD Epoch {epoch+1}: Loss={loss:.4f}, Acc={acc:.4f}")

    kd_acc = evaluate(kd_student, test_loader, device=DEVICE)
    kd_lat = measure_latency(kd_student, test_loader, device=DEVICE)

    # ---- PTQ ----
    calibration_subset = Subset(train_loader.dataset, range(1024))
    calibration_loader = DataLoader(calibration_subset, batch_size=BATCH_SIZE)

    # Create a new MoEFFN instance for quantization, as applying PTQ modifies the model in-place
    # and we want to quantize the KD student, not the original MoE model.
    # Also, move to CPU before applying PTQ and fusion.
    quant_kd_student = MoEFFN()
    quant_kd_student.load_state_dict(kd_student.state_dict())
    quant_moe = apply_ptq(quant_kd_student, calibration_loader)

    quant_acc = evaluate(quant_moe, test_loader, device=torch.device("cpu"))
    quant_lat = measure_latency(quant_moe, test_loader, device=torch.device("cpu"))

    # ---- Final Results ----
    print("\n===== FINAL RESULTS ===")
    print(f"Baseline      | Acc: {base_acc:.4f} | Latency: {base_lat:.6f}")
    print(f"MoE           | Acc: {moe_acc:.4f} | Latency: {moe_lat:.6f}")
    print(f"MoE + KD      | Acc: {kd_acc:.4f} | Latency: {kd_lat:.6f}")
    print(f"MoE + KD + PTQ| Acc: {quant_acc:.4f} | Latency: {quant_lat:.6f}")


if __name__ == "__main__":
    main()



=== Training Baseline ===


100%|██████████| 938/938 [00:12<00:00, 75.36it/s]


Epoch 1: Loss=0.2888, Acc=0.9573


100%|██████████| 938/938 [00:12<00:00, 76.67it/s]


Epoch 2: Loss=0.1100, Acc=0.9684


100%|██████████| 938/938 [00:12<00:00, 75.93it/s]


Epoch 3: Loss=0.0723, Acc=0.9737


100%|██████████| 938/938 [00:12<00:00, 76.50it/s]


Epoch 4: Loss=0.0543, Acc=0.9785


100%|██████████| 938/938 [00:13<00:00, 68.90it/s]


Epoch 5: Loss=0.0382, Acc=0.9777

=== Training MoE ===


100%|██████████| 938/938 [00:14<00:00, 66.44it/s]


Epoch 1: Loss=0.2808, Acc=0.9606


100%|██████████| 938/938 [00:14<00:00, 62.84it/s]


Epoch 2: Loss=0.1104, Acc=0.9699


100%|██████████| 938/938 [00:18<00:00, 51.69it/s]


Epoch 3: Loss=0.0752, Acc=0.9713


100%|██████████| 938/938 [00:15<00:00, 61.56it/s]


Epoch 4: Loss=0.0544, Acc=0.9702


100%|██████████| 938/938 [00:15<00:00, 61.72it/s]


Epoch 5: Loss=0.0427, Acc=0.9763

=== Training MoE with Knowledge Distillation ===


100%|██████████| 938/938 [00:15<00:00, 61.68it/s]


KD Epoch 1: Loss=1.5740, Acc=0.9624


100%|██████████| 938/938 [00:15<00:00, 59.22it/s]


KD Epoch 2: Loss=0.2850, Acc=0.9704


100%|██████████| 938/938 [00:16<00:00, 56.80it/s]


KD Epoch 3: Loss=0.1478, Acc=0.9738


100%|██████████| 938/938 [00:15<00:00, 59.48it/s]


KD Epoch 4: Loss=0.1017, Acc=0.9765


100%|██████████| 938/938 [00:15<00:00, 58.78it/s]


KD Epoch 5: Loss=0.0783, Acc=0.9785

===== FINAL RESULTS ===
Baseline      | Acc: 0.9777 | Latency: 0.007712
MoE           | Acc: 0.9763 | Latency: 0.007964
MoE + KD      | Acc: 0.9785 | Latency: 0.008079
MoE + KD + PTQ| Acc: 0.9774 | Latency: 0.008207


In [None]:
# ============================================================
# SECTION 0: Imports & Global Configuration
# ============================================================

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.quantization as tq

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

# Explicitly set the quantization engine to QNNPACK (standard for x86/ARM CPUs)
if 'qnnpack' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'qnnpack'

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

BATCH_SIZE = 64
EPOCHS = 5
KD_EPOCHS = 5
LR = 1e-3


# ============================================================
# SECTION 1: Data Loading (MNIST)
# ============================================================

def get_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Normalization helps training stability
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform
    )

    test_dataset = datasets.MNIST(
        root="./data",
        train=False,
        download=True,
        transform=transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False
    )

    return train_loader, test_loader


# ============================================================
# SECTION 2: Baseline Feed-Forward Network
# ============================================================

class BaselineFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x


# ============================================================
# SECTION 3: Mixture of Experts (MoE)
# ============================================================

class MoELayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts=4):
        super().__init__()
        self.num_experts = num_experts

        # Quantization stubs: MoE Internal logic stays float (Float Island)
        self.dequant = tq.DeQuantStub() # Dequantize input (int8 -> fp32)
        self.quant = tq.QuantStub()     # Quantize output (fp32 -> int8)

        # Router (gate)
        self.gate = nn.Linear(input_dim, num_experts)

        # Experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.ReLU()
            )
            for _ in range(num_experts)
        ])

    def forward(self, x):
        # 1. Dequantize incoming int8 tensor to float32
        x_float = self.dequant(x)

        # 2. Routing (Float32 operations)
        gate_logits = self.gate(x_float)
        gate_probs = F.softmax(gate_logits, dim=1)

        # 3. Expert computation (Float32 operations)
        expert_outputs_list = []
        for expert in self.experts:
            expert_outputs_list.append(expert(x_float))

        expert_outputs = torch.stack(expert_outputs_list, dim=1)

        # 4. Weighted sum (Float32 operations)
        output_float = torch.sum(
            gate_probs.unsqueeze(-1) * expert_outputs,
            dim=1
        )

        # 5. Quantize result back to int8 for the next layer
        output_quantized = self.quant(output_float)

        return output_quantized


class MoEFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub() # Input quantization
        self.fc1 = nn.Linear(28 * 28, 256)
        self.relu1 = nn.ReLU()
        self.moe = MoELayer(256, 128, num_experts=4)
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub() # Output dequantization

    def forward(self, x):
        x = self.quant(x)
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.moe(x) # Enters float island, returns int8
        x = self.fc3(x)
        x = self.dequant(x)
        return x


# ============================================================
# SECTION 4: Standard Training & Evaluation
# ============================================================

def train(model, dataloader, optimizer, criterion):
    model.train()
    total_loss = 0

    for x, y in tqdm(dataloader, desc="Training", leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)

        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


def evaluate(model, dataloader, device=None):
    if device is None:
        device = DEVICE
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    return correct / total


# ============================================================
# SECTION 5: Latency & Model Size Measurements
# ============================================================

def measure_latency(model, dataloader, num_batches=50, device=None):
    if device is None:
        device = DEVICE
    model.to(device)
    model.eval()

    # Warm-up
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= 5: break

    start = time.time()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            model(x.to(device))
            if i >= num_batches: break
    end = time.time()

    return (end - start) / num_batches

def count_parameters(model):
    """Counts total trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def estimate_model_size(model, quantized=False):
    """
    Estimates model size in MB.
    FP32 assumption: 4 bytes per param.
    INT8 assumption: 1 byte per param.
    """
    num_params = count_parameters(model)
    # Note: This is a heuristic. Quantized models pack weights, but overhead exists.
    bytes_per_param = 1 if quantized else 4
    size_mb = (num_params * bytes_per_param) / (1024 ** 2)
    return size_mb


# ============================================================
# SECTION 6: Post-Training Quantization (PTQ)
# ============================================================

def apply_ptq(model, calibration_loader):
    model.eval()
    # 1. Set qconfig for the model (QNNPACK)
    model.qconfig = tq.get_default_qconfig("qnnpack")

    # 2. Move to CPU (Quantization flow is typically CPU-based)
    model_for_quant = model.to(torch.device("cpu"))

    # 3. Disable quantization for the MoE Internals (Float Island)
    # We want MoE input/output to quantize, but the internal routing/experts to stay float.
    model_for_quant.moe.gate.qconfig = None
    for expert_seq in model_for_quant.moe.experts:
        expert_seq[0].qconfig = None # nn.Linear
        expert_seq[1].qconfig = None # nn.ReLU

    # 4. Fuse modules (only layers that are NOT in the float island)
    tq.fuse_modules(model_for_quant, [["fc1", "relu1"]], inplace=True)

    # 5. Prepare (Insert Observers)
    tq.prepare(model_for_quant, inplace=True)

    # 6. Calibrate
    with torch.no_grad():
        for x, _ in calibration_loader:
            model_for_quant(x.to(torch.device("cpu")))

    # 7. Convert (Float -> Int8)
    tq.convert(model_for_quant, inplace=True)

    return model_for_quant


# ============================================================
# SECTION 7: Knowledge Distillation Training
# ============================================================

def train_kd(student, teacher, dataloader, optimizer, alpha=0.7, temperature=4.0):
    student.train()
    teacher.eval()
    total_loss = 0

    for x, y in tqdm(dataloader, desc="KD Training", leave=False):
        x, y = x.to(DEVICE), y.to(DEVICE)

        with torch.no_grad():
            teacher_logits = teacher(x)

        student_logits = student(x)

        # Soft Target Loss (KL Divergence)
        soft_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction="batchmean"
        ) * (temperature ** 2)

        # Hard Target Loss (Cross Entropy)
        hard_loss = F.cross_entropy(student_logits, y)

        loss = alpha * soft_loss + (1 - alpha) * hard_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


# ============================================================
# SECTION 8: Main Experiment Pipeline
# ============================================================

def main():
    train_loader, test_loader = get_dataloaders()

    # Calibration loader for PTQ
    calibration_subset = Subset(train_loader.dataset, range(1024))
    calibration_loader = DataLoader(calibration_subset, batch_size=BATCH_SIZE)

    # ---------------------------------------------------------
    # 1. Baseline
    # ---------------------------------------------------------
    print("\n=== 1. Training Baseline ===")
    baseline = BaselineFFN().to(DEVICE)
    opt = optim.Adam(baseline.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(EPOCHS):
        loss = train(baseline, train_loader, opt, criterion)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}")

    base_acc = evaluate(baseline, test_loader, device=DEVICE)
    base_lat = measure_latency(baseline, test_loader, device=DEVICE)


    # ---------------------------------------------------------
    # 2. MoE (Training from scratch)
    # ---------------------------------------------------------
    print("\n=== 2. Training MoE ===")
    moe = MoEFFN().to(DEVICE)
    opt_moe = optim.Adam(moe.parameters(), lr=LR)

    for epoch in range(EPOCHS):
        loss = train(moe, train_loader, opt_moe, criterion)
        print(f"Epoch {epoch+1}: Loss={loss:.4f}")

    moe_acc = evaluate(moe, test_loader, device=DEVICE)
    moe_lat = measure_latency(moe, test_loader, device=DEVICE)


    # ---------------------------------------------------------
    # 3. MoE + KD
    # ---------------------------------------------------------
    print("\n=== 3. Training MoE with Knowledge Distillation ===")
    kd_student = MoEFFN().to(DEVICE)
    opt_kd = optim.Adam(kd_student.parameters(), lr=LR)

    for epoch in range(KD_EPOCHS):
        loss = train_kd(kd_student, baseline, train_loader, opt_kd)
        print(f"KD Epoch {epoch+1}: Loss={loss:.4f}")

    kd_acc = evaluate(kd_student, test_loader, device=DEVICE)
    kd_lat = measure_latency(kd_student, test_loader, device=DEVICE)


    # ---------------------------------------------------------
    # 4. MoE + KD + PTQ
    # ---------------------------------------------------------
    print("\n=== 4. Applying PTQ to MoE  ===")
    quant_kd_student = MoEFFN()
    quant_kd_student.load_state_dict(kd_student.state_dict())

    # Apply PTQ (Helper function handles the float island)
    quant_moe = apply_ptq(quant_kd_student, calibration_loader)

    quant_acc = evaluate(quant_moe, test_loader, device=torch.device("cpu"))
    quant_lat = measure_latency(quant_moe, test_loader, device=torch.device("cpu"))


    # ---------------------------------------------------------
    # Final Results
    # ---------------------------------------------------------
    print("\n" + "="*80)
    print(f"{'Model':<20} | {'Acc':<8} | {'Lat(ms)':<8} | {'Params':<12} | {'Size(MB)':<10}")
    print("-" * 80)

    # 1. Baseline
    print(f"{'Baseline':<20} | {base_acc:.4f}   | {base_lat*1000:.2f}     | {count_parameters(baseline):<12,} | {estimate_model_size(baseline):.2f}")

    # 2. MoE
    print(f"{'MoE':<20} | {moe_acc:.4f}   | {moe_lat*1000:.2f}     | {count_parameters(moe):<12,} | {estimate_model_size(moe):.2f}")

    # 3. MoE + KD
    print(f"{'MoE + KD':<20} | {kd_acc:.4f}   | {kd_lat*1000:.2f}     | {count_parameters(kd_student):<12,} | {estimate_model_size(kd_student):.2f}")

    # 4. MoE + KD + PTQ
    print(f"{'MoE + KD + PTQ':<20} | {quant_acc:.4f}   | {quant_lat*1000:.2f}     | {count_parameters(quant_moe):<12,} | {estimate_model_size(quant_moe, True):.2f}")
    print("="*80)

if __name__ == "__main__":
    main()


=== 1. Training Baseline ===




Epoch 1: Loss=0.2277




Epoch 2: Loss=0.0935




Epoch 3: Loss=0.0656




Epoch 4: Loss=0.0485




Epoch 5: Loss=0.0407

=== 2. Training MoE ===




Epoch 1: Loss=0.2332




Epoch 2: Loss=0.1012




Epoch 3: Loss=0.0721




Epoch 4: Loss=0.0530




Epoch 5: Loss=0.0446

=== 3. Training MoE with Knowledge Distillation ===




KD Epoch 1: Loss=1.3546




KD Epoch 2: Loss=0.2368




KD Epoch 3: Loss=0.1475




KD Epoch 4: Loss=0.1147




KD Epoch 5: Loss=0.0985

=== 4. Applying PTQ to MoE  ===

Model                | Acc      | Lat(ms)  | Params       | Size(MB)  
--------------------------------------------------------------------------------
Baseline             | 0.9742   | 14.87     | 235,146      | 0.90
MoE                  | 0.9751   | 13.96     | 334,862      | 1.28
MoE + KD             | 0.9762   | 13.12     | 334,862      | 1.28
MoE + KD + PTQ       | 0.9760   | 17.03     | 132,612      | 0.13


In [35]:
!pip install torch -q

In [49]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.quantization as tq

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# -------------------------------
# Global Config (Colab Safe)
# -------------------------------
DEVICE = torch.device("cpu")
BATCH_SIZE = 64
EPOCHS = 1
LR = 1e-3
NUM_EXPERTS = 4

engines = torch.backends.quantized.supported_engines
print(f"Supported engines: {engines}")

# ============================================================
# Utils
# ============================================================

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def estimate_model_size(model, quantized=False):
    bytes_per_param = 1 if quantized else 4
    size_mb = count_parameters(model) * bytes_per_param / (1024**2)
    return size_mb

def measure_latency(model, dataloader, num_batches=50):
    model.eval()

    start = time.time()
    with torch.no_grad():
        for i, (x, _) in enumerate(dataloader):
            _ = model(x)
            if i >= num_batches:
                break
    end = time.time()

    return (end - start) / num_batches

# ============================================================
# Data
# ============================================================

def get_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])

    train_ds = datasets.MNIST("./data", train=True, download=True, transform=transform)
    test_ds  = datasets.MNIST("./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

    calib_subset = Subset(train_ds, range(1024))
    calib_loader = DataLoader(calib_subset, batch_size=BATCH_SIZE, shuffle=False)

    return train_loader, test_loader, calib_loader

# ============================================================
# Baseline
# ============================================================

class BaselineFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(256, 128)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        # x = self.dequant(x)
        return x

# ============================================================
# MoE
# ============================================================

class MoELayer(nn.Module):
    def __init__(self, input_dim, output_dim, num_experts):
        super().__init__()
        self.num_experts = num_experts

        self.dequant = tq.DeQuantStub()
        self.quant = tq.QuantStub()

        self.gate = nn.Linear(input_dim, num_experts)

        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, output_dim),
                nn.ReLU()
            ) for _ in range(num_experts)])

    def forward(self, x):
        x_float = self.dequant(x)

        gate_logits = self.gate(x_float)
        gate_probs = F.softmax(gate_logits, dim=1)

        expert_outs = []
        for expert in self.experts:
            expert_outs.append(expert(x_float))

        expert_outs = torch.stack(expert_outs, dim=1)

        out_float = torch.sum(gate_probs.unsqueeze(-1) * expert_outs, dim=1)
        out_quant = self.quant(out_float)

        return out_quant, gate_probs


class MoEFFN(nn.Module):
    def __init__(self, num_experts):
        super().__init__()
        self.quant = tq.QuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu1 = nn.ReLU()
        self.moe = MoELayer(256, 128, num_experts)
        self.fc3 = nn.Linear(128, 10)
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = x.view(x.size(0), -1)
        x = self.relu1(self.fc1(x))
        x, gate_probs = self.moe(x)
        x = self.fc3(x)
        x = self.dequant(x)
        return x, gate_probs

# ============================================================
# Training
# ============================================================

def train_baseline(model, dataloader, optimizer, criterion):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for x, y in tqdm(dataloader, desc="Training", leave=False):
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    return total_loss / len(dataloader), correct / total


def train_moe(model, dataloader, optimizer, criterion):
    model.train()
    total_loss, correct, total = 0, 0, 0

    for x, y in tqdm(dataloader, desc="Training", leave=False):
        optimizer.zero_grad()
        logits, _ = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    return total_loss / len(dataloader), correct / total


def train_kd(student, teacher, dataloader, optimizer, alpha=0.7, temperature=4.0):
    student.train()
    teacher.eval()

    total_loss, correct, total = 0, 0, 0

    for x, y in tqdm(dataloader, desc="KD Training", leave=False):
        with torch.no_grad():
            teacher_logits = teacher(x)

        student_logits, _ = student(x)

        soft_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction="batchmean"
        ) * (temperature ** 2)

        hard_loss = F.cross_entropy(student_logits, y)
        loss = alpha * soft_loss + (1 - alpha) * hard_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = student_logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.size(0)

    return total_loss / len(dataloader), correct / total

# ============================================================
# Eval
# ============================================================

def eval_baseline(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total


def eval_moe(model, loader):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            logits, _ = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return correct / total

# ============================================================
# Router Load
# ============================================================

def router_load_analysis(model, loader, num_batches=100):
    model.eval()
    usage = torch.zeros(NUM_EXPERTS)

    with torch.no_grad():
        for i, (x, _) in enumerate(loader):
            _, gate_probs = model(x)
            top = gate_probs.argmax(1)
            for e in top:
                usage[e] += 1
            if i >= num_batches:
                break

    return usage / usage.sum()

# ============================================================
# PTQ
# ============================================================

def apply_ptq(model, calib_loader):
    model.eval()
    model.qconfig = tq.get_default_qconfig("qnnpack")

    # Float island
    model.moe.gate.qconfig = None
    for ex in model.moe.experts:
        ex[0].qconfig = None
        ex[1].qconfig = None

    tq.fuse_modules(model, [["fc1","relu1"]], inplace=True)

    tq.prepare(model, inplace=True)

    with torch.no_grad():
        for x, _ in calib_loader:
            model(x)

    tq.convert(model, inplace=True)
    return model

# ============================================================
# Main
# ============================================================

def main():
    train_loader, test_loader, calib_loader = get_dataloaders()

    # 1. Baseline
    baseline = BaselineFFN()
    opt = optim.Adam(baseline.parameters(), lr=LR)
    crit = nn.CrossEntropyLoss()

    print("\n---1. Baseline training ...")
    for e in range(EPOCHS):
        loss, acc = train_baseline(baseline, train_loader, opt, crit)
        print(f"Epoch {e+1} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

    base_acc = eval_baseline(baseline, test_loader)
    base_lat = measure_latency(baseline, test_loader)

    # 2. MoE
    moe = MoEFFN(NUM_EXPERTS)
    opt_moe = optim.Adam(moe.parameters(), lr=LR)

    print("\n---2. MoE training ...")
    for e in range(EPOCHS):
        loss, acc = train_moe(moe, train_loader, opt_moe, crit)
        print(f"Epoch {e+1} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

    moe_acc = eval_moe(moe, test_loader)
    moe_lat = measure_latency(moe, test_loader)

    # 3. KD
    kd_student = MoEFFN(NUM_EXPERTS)
    opt_kd = optim.Adam(kd_student.parameters(), lr=LR)

    print("\n---3. KD training ...")
    for e in range(EPOCHS):
        loss, acc = train_kd(kd_student, baseline, train_loader, opt_kd)
        print(f"Epoch {e+1} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

    kd_acc = eval_moe(kd_student, test_loader)
    kd_lat = measure_latency(kd_student, test_loader)

    # 4. PTQ
    quant_moe = MoEFFN(NUM_EXPERTS)
    quant_moe.load_state_dict(moe.state_dict())
    quant_moe = apply_ptq(quant_moe, calib_loader)

    quant_acc = eval_moe(quant_moe, test_loader)
    quant_lat = measure_latency(quant_moe, test_loader)

    quant_kd_student = MoEFFN(NUM_EXPERTS)
    quant_kd_student.load_state_dict(kd_student.state_dict())

    quant_moe_kd = apply_ptq(quant_kd_student, calib_loader)
    quant_acc_kd = eval_moe(quant_moe_kd, test_loader)
    quant_lat_kd = measure_latency(quant_moe_kd, test_loader)



    # Final table
    print("....Final table...")
    print("\n" + "="*70)
    print(f"{'Model':<20} | {'Acc':<8} | {'Lat(ms)':<8} | {'Params':<12} | {'Size(MB)':<10}")
    print("-"*70)

    print(f"{'Baseline':<20} | {base_acc:.4f} | {base_lat*1000:.2f} | {count_parameters(baseline):<12,} | {estimate_model_size(baseline):.2f}")
    print(f"{'MoE':<20} | {moe_acc:.4f} | {moe_lat*1000:.2f} | {count_parameters(moe):<12,} | {estimate_model_size(moe):.2f}")
    print(f"{'MoE + PTQ':<20} | {quant_acc:.4f} | {quant_lat*1000:.2f} | {count_parameters(quant_moe):<12,} | {estimate_model_size(quant_moe, True):.2f}")
    print(f"{'MoE + KD':<20} | {kd_acc:.4f} | {kd_lat*1000:.2f} | {count_parameters(kd_student):<12,} | {estimate_model_size(kd_student):.2f}")
    print(f"{'MoE + KD + PTQ':<20} | {quant_acc_kd:.4f} | {quant_lat_kd*1000:.2f} | {count_parameters(quant_moe_kd):<12,} | {estimate_model_size(quant_moe_kd, True):.2f}")
    print("="*70)

    moe_load = router_load_analysis(moe, test_loader)
    print("\nRouter Load Distribution (MoE):")
    for i, v in enumerate(moe_load):
        print(f"Expert {i}: {v.item()*100:.2f}%")

if __name__ == "__main__":
    main()


Supported engines: ['qnnpack', 'onednn', 'x86', 'fbgemm']

---1. Baseline training ...




Epoch 1 | Loss: 0.2286 | Accuracy: 0.9306

---2. MoE training ...




Epoch 1 | Loss: 0.2317 | Accuracy: 0.9308

---3. KD training ...




Epoch 1 | Loss: 0.4378 | Accuracy: 0.9306
....Final table...

Model                | Acc      | Lat(ms)  | Params       | Size(MB)  
----------------------------------------------------------------------
Baseline             | 0.9601 | 13.60 | 235,146      | 0.90
MoE                  | 0.9547 | 13.76 | 334,862      | 1.28
MoE + PTQ            | 0.9546 | 16.65 | 132,612      | 0.13
MoE + KD             | 0.9579 | 14.07 | 334,862      | 1.28
MoE + KD + PTQ       | 0.9586 | 33.99 | 132,612      | 0.13

Router Load Distribution (MoE):
Expert 0: 0.14%
Expert 1: 0.00%
Expert 2: 97.66%
Expert 3: 2.20%
