In [None]:
# TRAINING OF FP32 KNOWLEDGE DISTILLATION STUDENT MODEL

In [8]:
# Ensemble Knowledge Distillation Training Script (FINAL, READY TO RUN)
# Teachers: ResNet18, MobileNetV2, DenseNet121
# Student: MobileNetV3-Small (width=0.25)
# Dataset: Rice Image Dataset (5 classes, ~75k images)
# Platform: Kaggle (GPU)

import os, time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import numpy as np
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torchvision.datasets import ImageFolder

# Config
FULL_DATA_DIR = "/kaggle/input/rice-image-dataset/Rice_Image_Dataset"
TEACHER_WEIGHTS_DIR = "/kaggle/input/teacher-weights/pytorch/default/1"
NUM_CLASSES = 5
IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 25
LR = 3e-4
WEIGHT_DECAY = 1e-4
ALPHA = 0.8
TEMPERATURE = 4.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Logging
print("="*60)
print("Starting Ensemble Knowledge Distillation Training")
print("Device:", DEVICE)
print("="*60)

# Augementations
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE + 32, IMG_SIZE + 32)),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomVerticalFlip(0.2),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Data Split
print("Building dataframe for fast stratified split (no ImageFolder scan)...")
from pathlib import Path
import pandas as pd
from PIL import Image

root = Path(FULL_DATA_DIR)
classes = sorted([d.name for d in root.iterdir() if d.is_dir()])
class_to_idx = {c: i for i, c in enumerate(classes)}

image_paths, labels = [], []
for cls in classes:
    for img_path in (root / cls).glob("*.*"):
        image_paths.append(str(img_path))
        labels.append(cls)

df = pd.DataFrame({"filepath": image_paths, "label": labels})
print(f"Total images found: {len(df)}")

train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["label"],
    random_state=42
)

class RiceDataset(torch.utils.data.Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(row["filepath"]).convert("RGB")
        label = class_to_idx[row["label"]]
        if self.transform:
            image = self.transform(image)
        return image, label

train_ds = RiceDataset(train_df, transform=train_tfms)
val_ds   = RiceDataset(val_df, transform=val_tfms)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")
print("Class mapping:", class_to_idx)

# Load Teachers
def load_teacher(model, weight_path, name):
    print(f"Loading teacher: {name}")
    checkpoint = torch.load(weight_path, map_location="cpu")

    # Handle both raw state_dict and full checkpoint formats
    if isinstance(checkpoint, dict) and "model_state" in checkpoint:
        state_dict = checkpoint["model_state"]
    else:
        state_dict = checkpoint

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if len(missing) > 0:
        print(f"  [Info] Missing keys ignored: {len(missing)}")
    if len(unexpected) > 0:
        print(f"  [Info] Unexpected keys ignored: {len(unexpected)}")

    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model.to(DEVICE)

print("Loading teacher models...")
resnet18 = torchvision.models.resnet18(weights=None)
resnet18.fc = nn.Linear(resnet18.fc.in_features, NUM_CLASSES)
resnet18 = load_teacher(resnet18, os.path.join(TEACHER_WEIGHTS_DIR, "best_resnet18_rice.pth"), "ResNet18")

mobilenetv2 = torchvision.models.mobilenet_v2(weights=None)
mobilenetv2.classifier[1] = nn.Linear(mobilenetv2.classifier[1].in_features, NUM_CLASSES)
mobilenetv2 = load_teacher(mobilenetv2, os.path.join(TEACHER_WEIGHTS_DIR, "best_mobilenetv2_rice_scratch.pth"), "MobileNetV2")

densenet121 = torchvision.models.densenet121(weights=None)
densenet121.classifier = nn.Linear(densenet121.classifier.in_features, NUM_CLASSES)
densenet121 = load_teacher(densenet121, os.path.join(TEACHER_WEIGHTS_DIR, "best_densenet121_rice_scratch.pth"), "DenseNet121")

teachers = [resnet18, mobilenetv2, densenet121]
print("All teachers loaded and frozen.")

# Student Model
print("Initializing student model: MobileNetV3-Small 0.25x")
student = torchvision.models.mobilenet_v3_small(width_mult=0.25, weights=None)
student.classifier[3] = nn.Linear(student.classifier[3].in_features, NUM_CLASSES)
student = student.to(DEVICE)

# Define the KD Loss
class EnsembleKDLoss(nn.Module):
    def __init__(self, alpha=0.8, T=4.0):
        super().__init__()
        self.alpha = alpha
        self.T = T
        self.ce = nn.CrossEntropyLoss(label_smoothing=0.1)

    def forward(self, student_logits, targets, teacher_logits):
        ce_loss = self.ce(student_logits, targets)
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=1),
            F.softmax(teacher_logits / self.T, dim=1),
            reduction="batchmean"
        ) * (self.T ** 2)
        return self.alpha * kd_loss + (1 - self.alpha) * ce_loss

criterion = EnsembleKDLoss(ALPHA, TEMPERATURE)
optimizer = AdamW(student.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Train/Eval Functions
def evaluate(model):
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            logits = model(x)
            preds = logits.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    return 100. * correct / total, np.array(all_labels), np.array(all_preds)

best_acc = 0.0
best_preds, best_labels = None, None

print("\nStarting training loop...")
for epoch in range(EPOCHS):
    start_time = time.time()
    student.train()
    correct, total, running_loss = 0, 0, 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        with torch.no_grad():
            logits_ens = sum(t(imgs) for t in teachers) / len(teachers)

        outputs = student(imgs)
        loss = criterion(outputs, labels, logits_ens)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
        running_loss += loss.item()

        pbar.set_postfix(loss=running_loss/total, train_acc=100.*correct/total)

    scheduler.step()
    train_acc = 100. * correct / total
    val_acc, labels_np, preds_np = evaluate(student)
    epoch_time = time.time() - start_time

    if val_acc > best_acc:
        best_acc = val_acc
        best_labels, best_preds = labels_np, preds_np
        torch.save(student.state_dict(), "best_student_mobilenetv3_025_kd.pth")

    print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}% | Val Acc={val_acc:.2f}% | Time={epoch_time:.1f}s")

# Final Report
print("\nTraining completed.")
print("Best Validation Accuracy: {:.2f}%".format(best_acc))

# Model stats
num_params = sum(p.numel() for p in student.parameters())
model_size_mb = num_params * 4 / (1024 ** 2)

# Inference speed
student.eval()
dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE).to(DEVICE)
with torch.no_grad():
    t0 = time.time()
    for _ in range(100):
        _ = student(dummy)
    t1 = time.time()

avg_inference_ms = (t1 - t0) / 100 * 1000

print("\n===== STUDENT MODEL SUMMARY =====")
print("Model: MobileNetV3-Small (0.25x)")
print(f"Parameters: {num_params:,}")
print(f"Model Size: {model_size_mb:.2f} MB")
print(f"Avg Inference Time: {avg_inference_ms:.2f} ms")

# Confusion Matrix
cm = confusion_matrix(best_labels, best_preds)
print("\nConfusion Matrix:")
print(cm)


Starting Ensemble Knowledge Distillation Training
Device: cuda
Building dataframe for fast stratified split (no ImageFolder scan)...
Total images found: 75000
Train samples: 60000 | Val samples: 15000
Class mapping: {'Arborio': 0, 'Basmati': 1, 'Ipsala': 2, 'Jasmine': 3, 'Karacadag': 4}
Loading teacher models...
Loading teacher: ResNet18
  [Info] Missing keys ignored: 102
  [Info] Unexpected keys ignored: 122
Loading teacher: MobileNetV2
Loading teacher: DenseNet121
All teachers loaded and frozen.
Initializing student model: MobileNetV3-Small 0.25x

Starting training loop...


Epoch 1/25: 100%|██████████| 938/938 [03:33<00:00,  4.39it/s, loss=0.0387, train_acc=78.6]


Epoch 1: Train Acc=78.64% | Val Acc=63.83% | Time=233.3s


Epoch 2/25: 100%|██████████| 938/938 [03:34<00:00,  4.37it/s, loss=0.0122, train_acc=92.8]


Epoch 2: Train Acc=92.84% | Val Acc=93.17% | Time=233.7s


Epoch 3/25: 100%|██████████| 938/938 [03:34<00:00,  4.36it/s, loss=0.00842, train_acc=94.3]


Epoch 3: Train Acc=94.34% | Val Acc=98.75% | Time=234.4s


Epoch 4/25: 100%|██████████| 938/938 [03:36<00:00,  4.33it/s, loss=0.00714, train_acc=94.9]


Epoch 4: Train Acc=94.94% | Val Acc=94.73% | Time=236.4s


Epoch 5/25: 100%|██████████| 938/938 [03:38<00:00,  4.30it/s, loss=0.00647, train_acc=95.2]


Epoch 5: Train Acc=95.16% | Val Acc=92.50% | Time=237.9s


Epoch 6/25: 100%|██████████| 938/938 [03:36<00:00,  4.32it/s, loss=0.0061, train_acc=95.3] 


Epoch 6: Train Acc=95.31% | Val Acc=96.84% | Time=236.8s


Epoch 7/25: 100%|██████████| 938/938 [03:39<00:00,  4.27it/s, loss=0.00575, train_acc=95.4]


Epoch 7: Train Acc=95.44% | Val Acc=99.51% | Time=239.8s


Epoch 8/25: 100%|██████████| 938/938 [03:38<00:00,  4.30it/s, loss=0.00552, train_acc=95.3]


Epoch 8: Train Acc=95.31% | Val Acc=64.71% | Time=237.9s


Epoch 9/25: 100%|██████████| 938/938 [03:38<00:00,  4.28it/s, loss=0.0053, train_acc=95.6] 


Epoch 9: Train Acc=95.64% | Val Acc=91.20% | Time=238.4s


Epoch 10/25: 100%|██████████| 938/938 [03:40<00:00,  4.26it/s, loss=0.00514, train_acc=95.5]


Epoch 10: Train Acc=95.47% | Val Acc=99.36% | Time=239.3s


Epoch 11/25: 100%|██████████| 938/938 [03:38<00:00,  4.30it/s, loss=0.00504, train_acc=95.5]


Epoch 11: Train Acc=95.50% | Val Acc=97.25% | Time=237.7s


Epoch 12/25: 100%|██████████| 938/938 [03:38<00:00,  4.28it/s, loss=0.00488, train_acc=95.7]


Epoch 12: Train Acc=95.70% | Val Acc=99.63% | Time=238.3s


Epoch 13/25: 100%|██████████| 938/938 [03:42<00:00,  4.22it/s, loss=0.00482, train_acc=95.6]


Epoch 13: Train Acc=95.59% | Val Acc=70.57% | Time=242.2s


Epoch 14/25: 100%|██████████| 938/938 [03:40<00:00,  4.25it/s, loss=0.00471, train_acc=95.7]


Epoch 14: Train Acc=95.75% | Val Acc=99.23% | Time=240.7s


Epoch 15/25: 100%|██████████| 938/938 [03:36<00:00,  4.32it/s, loss=0.00468, train_acc=95.7]


Epoch 15: Train Acc=95.72% | Val Acc=99.21% | Time=236.9s


Epoch 16/25: 100%|██████████| 938/938 [03:37<00:00,  4.31it/s, loss=0.00459, train_acc=95.7]


Epoch 16: Train Acc=95.74% | Val Acc=99.43% | Time=237.5s


Epoch 17/25: 100%|██████████| 938/938 [03:41<00:00,  4.24it/s, loss=0.00453, train_acc=95.7]


Epoch 17: Train Acc=95.74% | Val Acc=99.66% | Time=240.9s


Epoch 18/25: 100%|██████████| 938/938 [03:40<00:00,  4.25it/s, loss=0.00449, train_acc=95.8]


Epoch 18: Train Acc=95.84% | Val Acc=97.97% | Time=240.0s


Epoch 19/25: 100%|██████████| 938/938 [03:38<00:00,  4.29it/s, loss=0.00449, train_acc=95.6]


Epoch 19: Train Acc=95.64% | Val Acc=99.40% | Time=238.0s


Epoch 20/25: 100%|██████████| 938/938 [03:38<00:00,  4.30it/s, loss=0.00444, train_acc=95.8]


Epoch 20: Train Acc=95.79% | Val Acc=98.85% | Time=237.9s


Epoch 21/25: 100%|██████████| 938/938 [03:40<00:00,  4.25it/s, loss=0.00442, train_acc=95.8]


Epoch 21: Train Acc=95.84% | Val Acc=98.43% | Time=240.2s


Epoch 22/25: 100%|██████████| 938/938 [03:40<00:00,  4.25it/s, loss=0.00441, train_acc=95.8]


Epoch 22: Train Acc=95.82% | Val Acc=97.74% | Time=240.5s


Epoch 23/25: 100%|██████████| 938/938 [03:39<00:00,  4.28it/s, loss=0.00437, train_acc=95.7]


Epoch 23: Train Acc=95.73% | Val Acc=99.29% | Time=238.9s


Epoch 24/25: 100%|██████████| 938/938 [03:37<00:00,  4.32it/s, loss=0.00436, train_acc=95.8]


Epoch 24: Train Acc=95.75% | Val Acc=99.33% | Time=236.8s


Epoch 25/25: 100%|██████████| 938/938 [03:37<00:00,  4.31it/s, loss=0.00438, train_acc=95.8]


Epoch 25: Train Acc=95.79% | Val Acc=99.05% | Time=237.3s

Training completed.
Best Validation Accuracy: 99.66%

===== STUDENT MODEL SUMMARY =====
Model: MobileNetV3-Small (0.25x)
Parameters: 118,557
Model Size: 0.45 MB
Avg Inference Time: 5.66 ms

Confusion Matrix:
[[2993    0    0    1    6]
 [   0 2967    0   33    0]
 [   0    0 3000    0    0]
 [   5    2    0 2992    1]
 [   3    0    0    0 2997]]


In [None]:
# QUANTIZATION VIA QAT FOR UNIFORM 8 BIT WITH THE BACKWARD LOSS BEING CE+KD LOSS

In [2]:
# INT8 QAT + Ensemble Knowledge Distillation (NO TRADES)
# Teachers: ResNet18, MobileNetV2, DenseNet121 (99%+)
# Student: MobileNetV3-Small (0.25x)
# Dataset: Rice_Image_Dataset (5 classes)
# Platform: Kaggle GPU

import os, time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from pathlib import Path
from PIL import Image

# Config
DATA_DIR = "/kaggle/input/rice-image-dataset/Rice_Image_Dataset"
TEACHER_DIR = "/kaggle/input/teacher-weights/pytorch/default/1"
FP32_STUDENT_CKPT = "/kaggle/input/student-checkpoint-fp32/pytorch/default/1/best_student_mobilenetv3_025_kd.pth"

NUM_CLASSES = 5
IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 25
LR = 2e-4
WEIGHT_DECAY = 1e-4
ALPHA = 0.8
TEMPERATURE = 4.0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print("="*70)
print("INT8 QAT + ENSEMBLE KD TRAINING (NO TRADES)")
print("Device:", DEVICE)
print("="*70)

# Transforms
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE+32, IMG_SIZE+32)),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.75, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2,0.2,0.2,0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

# Data Split
print("[1] Building dataframe & stratified split...")
root = Path(DATA_DIR)
classes = sorted([d.name for d in root.iterdir() if d.is_dir()])
class_to_idx = {c:i for i,c in enumerate(classes)}

paths, labels = [], []
for c in classes:
    for p in (root/c).glob('*.*'):
        paths.append(str(p))
        labels.append(c)

df = list(zip(paths, labels))
idx = np.arange(len(df))
train_idx, val_idx = train_test_split(idx, test_size=0.2, stratify=labels, random_state=42)

class RiceDataset(torch.utils.data.Dataset):
    def __init__(self, indices, transform):
        self.indices = indices
        self.transform = transform
    def __len__(self): return len(self.indices)
    def __getitem__(self, i):
        p,l = df[self.indices[i]]
        img = Image.open(p).convert('RGB')
        img = self.transform(img)
        return img, class_to_idx[l]

train_loader = DataLoader(RiceDataset(train_idx, train_tfms), batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(RiceDataset(val_idx, val_tfms), batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

print(f"Train: {len(train_idx)} | Val: {len(val_idx)}")
print("Classes:", class_to_idx)

# Load Teacher Models
def load_teacher(model, path, name):
    print(f"Loading teacher: {name}")
    ckpt = torch.load(path, map_location='cpu')
    sd = ckpt['model_state'] if isinstance(ckpt, dict) and 'model_state' in ckpt else ckpt
    model.load_state_dict(sd, strict=False)
    model.eval()
    for p in model.parameters(): p.requires_grad=False
    return model.to(DEVICE)

print("[2] Loading teacher models...")
resnet18 = torchvision.models.resnet18(weights=None)
resnet18.fc = nn.Linear(resnet18.fc.in_features, NUM_CLASSES)
resnet18 = load_teacher(resnet18, f"{TEACHER_DIR}/best_resnet18_rice.pth", "ResNet18")

mobilenetv2 = torchvision.models.mobilenet_v2(weights=None)
mobilenetv2.classifier[1] = nn.Linear(mobilenetv2.classifier[1].in_features, NUM_CLASSES)
mobilenetv2 = load_teacher(mobilenetv2, f"{TEACHER_DIR}/best_mobilenetv2_rice_scratch.pth", "MobileNetV2")

densenet121 = torchvision.models.densenet121(weights=None)
densenet121.classifier = nn.Linear(densenet121.classifier.in_features, NUM_CLASSES)
densenet121 = load_teacher(densenet121, f"{TEACHER_DIR}/best_densenet121_rice_scratch.pth", "DenseNet121")

teachers = [resnet18, mobilenetv2, densenet121]
print("Teachers ready.")

# STUDENT (FP32 -> QAT)
print("[3] Loading FP32 student checkpoint...")
student = torchvision.models.mobilenet_v3_small(width_mult=0.25, weights=None)
student.classifier[3] = nn.Linear(student.classifier[3].in_features, NUM_CLASSES)
fp32_ckpt = torch.load(FP32_STUDENT_CKPT, map_location='cpu')
student.load_state_dict(fp32_ckpt, strict=False)
student = student.to(DEVICE)
print("FP32 student loaded.")

# KD LOSS
class KDLoss(nn.Module):
    def __init__(self, alpha=0.8, T=4.0):
        super().__init__()
        self.alpha = alpha
        self.T = T
        self.ce = nn.CrossEntropyLoss(label_smoothing=0.1)
    def forward(self, s_logits, y, t_logits):
        ce = self.ce(s_logits, y)
        kd = F.kl_div(
            F.log_softmax(s_logits/self.T, dim=1),
            F.softmax(t_logits/self.T, dim=1),
            reduction='batchmean'
        ) * (self.T**2)
        return self.alpha*kd + (1-self.alpha)*ce

criterion = KDLoss(ALPHA, TEMPERATURE)
optimizer = AdamW(student.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

# TRAIN / EVAL
def evaluate(model):
    model.eval()
    correct,total = 0,0
    preds_all, labels_all = [],[]
    with torch.no_grad():
        for x,y in val_loader:
            x,y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            preds = out.argmax(1)
            correct += preds.eq(y).sum().item()
            total += y.size(0)
            preds_all.extend(preds.cpu().numpy())
            labels_all.extend(y.cpu().numpy())
    return 100*correct/total, np.array(labels_all), np.array(preds_all)

print("[4] Starting INT8 QAT + KD training...")
best_acc = 0
best_labels = best_preds = None

for epoch in range(EPOCHS):
    student.train()
    correct,total,loss_sum = 0,0,0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for x,y in pbar:
        x,y = x.to(DEVICE), y.to(DEVICE)
        with torch.no_grad():
            t_logits = sum(t(x) for t in teachers)/len(teachers)
        s_logits = student(x)
        loss = criterion(s_logits, y, t_logits)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = s_logits.argmax(1)
        correct += preds.eq(y).sum().item()
        total += y.size(0)
        loss_sum += loss.item()
        pbar.set_postfix(train_acc=100*correct/total)
    scheduler.step()
    val_acc, lab, pr = evaluate(student)
    print(f"Epoch {epoch+1}: Train Acc={100*correct/total:.2f}% | Val Acc={val_acc:.2f}%")
    if val_acc > best_acc:
        best_acc = val_acc
        best_labels, best_preds = lab, pr
        torch.save(student.state_dict(), "best_student_int8_qat_kd.pth")

# FINAL REPORT
print("\nTraining complete.")
print(f"Best INT8 Validation Accuracy: {best_acc:.2f}%")

num_params = sum(p.numel() for p in student.parameters())
model_size_mb = num_params*4/(1024**2)
student.eval()
dummy = torch.randn(1,3,IMG_SIZE,IMG_SIZE).to(DEVICE)
with torch.no_grad():
    t0=time.time()
    for _ in range(100): student(dummy)
    t1=time.time()
inf_ms = (t1-t0)/100*1000

print("\n===== MODEL SUMMARY =====")
print("Model: MobileNetV3-Small (0.25x) INT8 QAT")
print(f"Accuracy: {best_acc:.2f}%")
print(f"Parameters: {num_params:,}")
print(f"Size: {model_size_mb:.2f} MB")
print(f"Avg Inference: {inf_ms:.2f} ms ({inf_ms/1000:.4f} s)")

cm = confusion_matrix(best_labels, best_preds)
print("Confusion Matrix:\n", cm)

# TORCHSCRIPT (PYTORCH MOBILE) 
print("\n[5] Exporting TorchScript model for PyTorch Mobile...")
scripted = torch.jit.script(student.cpu())
scripted.save("mobilenetv3_025_int8_qat_kd.pt")
print("Saved: mobilenetv3_025_int8_qat_kd.pt")


INT8 QAT + ENSEMBLE KD TRAINING (NO TRADES)
Device: cuda
[1] Building dataframe & stratified split...
Train: 60000 | Val: 15000
Classes: {'Arborio': 0, 'Basmati': 1, 'Ipsala': 2, 'Jasmine': 3, 'Karacadag': 4}
[2] Loading teacher models...
Loading teacher: ResNet18
Loading teacher: MobileNetV2
Loading teacher: DenseNet121
Teachers ready.
[3] Loading FP32 student checkpoint...
FP32 student loaded.
[4] Starting INT8 QAT + KD training...


Epoch 1/25: 100%|██████████| 938/938 [03:57<00:00,  3.95it/s, train_acc=97.1]


Epoch 1: Train Acc=97.06% | Val Acc=99.15%


Epoch 2/25: 100%|██████████| 938/938 [03:38<00:00,  4.29it/s, train_acc=97.1]


Epoch 2: Train Acc=97.08% | Val Acc=99.67%


Epoch 3/25: 100%|██████████| 938/938 [03:29<00:00,  4.47it/s, train_acc=97]  


Epoch 3: Train Acc=97.04% | Val Acc=93.36%


Epoch 4/25: 100%|██████████| 938/938 [03:29<00:00,  4.47it/s, train_acc=97.2]


Epoch 4: Train Acc=97.19% | Val Acc=99.90%


Epoch 5/25: 100%|██████████| 938/938 [03:32<00:00,  4.41it/s, train_acc=97.2]


Epoch 5: Train Acc=97.16% | Val Acc=96.22%


Epoch 6/25: 100%|██████████| 938/938 [03:32<00:00,  4.42it/s, train_acc=97.2]


Epoch 6: Train Acc=97.16% | Val Acc=98.35%


Epoch 7/25: 100%|██████████| 938/938 [03:30<00:00,  4.46it/s, train_acc=97.1]


Epoch 7: Train Acc=97.09% | Val Acc=90.67%


Epoch 8/25: 100%|██████████| 938/938 [03:34<00:00,  4.38it/s, train_acc=97.1]


Epoch 8: Train Acc=97.13% | Val Acc=99.89%


Epoch 9/25: 100%|██████████| 938/938 [03:30<00:00,  4.45it/s, train_acc=97.2]


Epoch 9: Train Acc=97.16% | Val Acc=98.93%


Epoch 10/25: 100%|██████████| 938/938 [03:31<00:00,  4.43it/s, train_acc=97.1]


Epoch 10: Train Acc=97.09% | Val Acc=99.86%


Epoch 11/25: 100%|██████████| 938/938 [03:38<00:00,  4.29it/s, train_acc=97.2]


Epoch 11: Train Acc=97.24% | Val Acc=99.77%


Epoch 12/25: 100%|██████████| 938/938 [03:32<00:00,  4.42it/s, train_acc=97.2]


Epoch 12: Train Acc=97.16% | Val Acc=99.25%


Epoch 13/25: 100%|██████████| 938/938 [03:34<00:00,  4.38it/s, train_acc=97.2]


Epoch 13: Train Acc=97.23% | Val Acc=98.81%


Epoch 14/25: 100%|██████████| 938/938 [03:34<00:00,  4.38it/s, train_acc=97.2]


Epoch 14: Train Acc=97.17% | Val Acc=99.91%


Epoch 15/25: 100%|██████████| 938/938 [03:34<00:00,  4.38it/s, train_acc=97.2]


Epoch 15: Train Acc=97.19% | Val Acc=99.71%


Epoch 16/25: 100%|██████████| 938/938 [03:31<00:00,  4.44it/s, train_acc=97.3]


Epoch 16: Train Acc=97.27% | Val Acc=99.78%


Epoch 17/25: 100%|██████████| 938/938 [03:29<00:00,  4.47it/s, train_acc=97.2]


Epoch 17: Train Acc=97.17% | Val Acc=99.89%


Epoch 18/25: 100%|██████████| 938/938 [03:30<00:00,  4.46it/s, train_acc=97.2]


Epoch 18: Train Acc=97.23% | Val Acc=99.68%


Epoch 19/25: 100%|██████████| 938/938 [03:30<00:00,  4.46it/s, train_acc=97.3]


Epoch 19: Train Acc=97.28% | Val Acc=99.87%


Epoch 20/25: 100%|██████████| 938/938 [03:30<00:00,  4.45it/s, train_acc=97.3]


Epoch 20: Train Acc=97.31% | Val Acc=99.12%


Epoch 21/25: 100%|██████████| 938/938 [03:28<00:00,  4.50it/s, train_acc=97.2]


Epoch 21: Train Acc=97.22% | Val Acc=99.63%


Epoch 22/25: 100%|██████████| 938/938 [03:29<00:00,  4.47it/s, train_acc=97.2]


Epoch 22: Train Acc=97.18% | Val Acc=99.62%


Epoch 23/25: 100%|██████████| 938/938 [03:30<00:00,  4.46it/s, train_acc=97.1]


Epoch 23: Train Acc=97.14% | Val Acc=99.59%


Epoch 24/25: 100%|██████████| 938/938 [03:31<00:00,  4.43it/s, train_acc=97.2]


Epoch 24: Train Acc=97.19% | Val Acc=99.65%


Epoch 25/25: 100%|██████████| 938/938 [03:33<00:00,  4.39it/s, train_acc=97.2]


Epoch 25: Train Acc=97.24% | Val Acc=99.59%

Training complete.
Best INT8 Validation Accuracy: 99.91%

===== MODEL SUMMARY =====
Model: MobileNetV3-Small (0.25x) INT8 QAT
Accuracy: 99.91%
Parameters: 118,557
Size: 0.45 MB
Avg Inference: 5.52 ms (0.0055 s)
Confusion Matrix:
 [[2997    0    0    1    2]
 [   0 2998    0    2    0]
 [   0    0 3000    0    0]
 [   0    4    0 2996    0]
 [   4    0    0    0 2996]]

[5] Exporting TorchScript model for PyTorch Mobile...
Saved: mobilenetv3_025_int8_qat_kd.pt


In [3]:
import torch
import torchvision

NUM_CLASSES = 5
IMG_SIZE = 224

model = torchvision.models.mobilenet_v3_small(width_mult=0.25)
model.classifier[3] = torch.nn.Linear(
    model.classifier[3].in_features, NUM_CLASSES
)
model.load_state_dict(torch.load("/kaggle/working/best_student_int8_qat_kd.pth", map_location="cpu"))
model.eval()

dummy = torch.randn(1, 3, IMG_SIZE, IMG_SIZE)

torch.onnx.export(
    model,
    dummy,
    "student_fp32_qat.onnx",
    input_names=["input"],
    output_names=["logits"],
    opset_version=13,
    do_constant_folding=True,
    dynamic_axes={"input": {0: "batch"}}
)

print("Exported student_fp32_qat.onnx")


Exported student_fp32_qat.onnx


In [2]:
!pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.23.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnxruntime-1.23.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (17.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m17.4/17.4 MB[0m [31m89.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstal

In [7]:
from onnxruntime.quantization import quantize_dynamic, QuantType

quantize_dynamic(
    model_input="/kaggle/working/student_fp32_qat.onnx",
    model_output="/kaggle/working/student_int8.onnx",
    weight_type=QuantType.QInt8
)

print("INT8 ONNX model saved as student_int8.onnx")




INT8 ONNX model saved as student_int8.onnx


In [14]:
# STATIC INT8 ONNX QUANTIZATION + EVALUATION

# IMPORTS
import os
import time
import shutil
import random
import numpy as np
import onnx
import onnxruntime as ort

from onnxruntime.quantization import (
    quantize_static,
    QuantType,
    CalibrationDataReader
)

from pathlib import Path
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

# CONFIG
FP32_ONNX_INPUT = "/kaggle/input/student-onnx/onnx/default/1/student_fp32_qat.onnx"
WORK_FP32_ONNX  = "/kaggle/working/student_fp32_qat.onnx"
INT8_ONNX_PATH  = "/kaggle/working/student_int8_static.onnx"

DATA_DIR = "/kaggle/input/rice-image-dataset/Rice_Image_Dataset"

IMG_SIZE = 224
BATCH_SIZE = 32
NUM_CLASSES = 5
CALIB_SAMPLES = 200 

CLASSES = ["Arborio", "Basmati", "Ipsala", "Jasmine", "Karacadag"]

print("="*70)
print("STATIC INT8 ONNX CONVERSION + EVALUATION")
print("="*70)

# COPY FP32 ONNX TO WRITABLE DIRECTORY
print("[0] Copying FP32 ONNX to /kaggle/working...")
shutil.copy(FP32_ONNX_INPUT, WORK_FP32_ONNX)
print("FP32 ONNX ready:", WORK_FP32_ONNX)

# CALIBRATION DATA READER
print("[1] Preparing calibration data reader...")

class RiceCalibrationReader(CalibrationDataReader):
    def __init__(self, num_samples):
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        paths = []
        for c in CLASSES:
            paths.extend(list((Path(DATA_DIR) / c).glob("*.*")))

        random.shuffle(paths)
        random.seed(42)
        paths = paths[:num_samples]

        self.data = []
        for p in paths:
            img = Image.open(p).convert("RGB")
            x = self.transform(img).unsqueeze(0).numpy()
            self.data.append({"input": x})

        self.iterator = iter(self.data)

    def get_next(self):
        return next(self.iterator, None)


# STATIC INT8 QUANTIZATION
print("[2] Converting FP32 QAT ONNX → Static INT8 ONNX...")

quantize_static(
    model_input=WORK_FP32_ONNX,
    model_output=INT8_ONNX_PATH,
    calibration_data_reader=RiceCalibrationReader(CALIB_SAMPLES),
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QInt8
)

print("INT8 model saved:", INT8_ONNX_PATH)

# MODEL SIZE & QUANT INFO
size_mb = os.path.getsize(INT8_ONNX_PATH) / (1024 ** 2)
print(f"Model size: {size_mb:.2f} MB")

onnx_model = onnx.load(INT8_ONNX_PATH)
int8_tensors = sum(
    1 for t in onnx_model.graph.initializer
    if t.data_type == onnx.TensorProto.INT8
)
fp32_tensors = sum(
    1 for t in onnx_model.graph.initializer
    if t.data_type == onnx.TensorProto.FLOAT
)

print(f"INT8 tensors: {int8_tensors}")
print(f"FP32 tensors: {fp32_tensors}")
print("Quantization scheme: Static INT8")

# VALIDATION DATASET
print("[3] Preparing validation dataset...")

paths, labels = [], []
for c in CLASSES:
    for p in (Path(DATA_DIR) / c).glob("*.*"):
        paths.append(str(p))
        labels.append(CLASSES.index(c))

idx = np.arange(len(paths))
_, val_idx = train_test_split(
    idx, test_size=0.2, stratify=labels, random_state=42
)

class RiceValDataset(Dataset):
    def __init__(self, indices):
        self.indices = indices
        self.transform = transforms.Compose([
            transforms.Resize((IMG_SIZE, IMG_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        idx = self.indices[i]
        img = Image.open(paths[idx]).convert("RGB")
        return self.transform(img).numpy(), labels[idx]

val_loader = DataLoader(
    RiceValDataset(val_idx),
    batch_size=BATCH_SIZE,
    shuffle=False
)

print(f"Validation samples: {len(val_idx)}")

# ONNX RUNTIME SESSION
print("[4] Creating ONNX Runtime session...")
sess = ort.InferenceSession(
    INT8_ONNX_PATH,
    providers=["CPUExecutionProvider"]
)
input_name = sess.get_inputs()[0].name


# ACCURACY EVALUATION
print("[5] Running accuracy evaluation...")

correct = 0
total = 0
all_preds = []
all_labels = []

for x, y in val_loader:
    logits = sess.run(None, {input_name: x.numpy()})[0]
    preds = logits.argmax(axis=1)

    correct += (preds == y.numpy()).sum()
    total += y.size(0)

    all_preds.extend(preds.tolist())
    all_labels.extend(y.numpy().tolist())

acc = 100.0 * correct / total

# LATENCY TEST
dummy = np.random.randn(1, 3, IMG_SIZE, IMG_SIZE).astype(np.float32)

# Warm-up
for _ in range(10):
    sess.run(None, {input_name: dummy})

t0 = time.time()
for _ in range(100):
    sess.run(None, {input_name: dummy})
t1 = time.time()

avg_ms = (t1 - t0) / 100 * 1000

# FINAL REPORT
print("\n" + "="*30 + " FINAL REPORT " + "="*30)
print("Model: MobileNetV3-Small – Static INT8")
print(f"Calibration samples: {CALIB_SAMPLES}")
print(f"Accuracy: {acc:.2f}%")
print(f"Model size: {size_mb:.2f} MB")
print(f"Avg inference latency: {avg_ms:.2f} ms")

cm = confusion_matrix(all_labels, all_preds)
print("\nConfusion Matrix:")
print(cm)
print("="*80)

STATIC INT8 ONNX CONVERSION + EVALUATION
[0] Copying FP32 ONNX to /kaggle/working...
FP32 ONNX ready: /kaggle/working/student_fp32_qat.onnx
[1] Preparing calibration data reader...
[2] Converting FP32 QAT ONNX → Static INT8 ONNX...




INT8 model saved: /kaggle/working/student_int8_static.onnx
Model size: 0.31 MB
INT8 tensors: 237
FP32 tensors: 237
Quantization scheme: Static INT8
[3] Preparing validation dataset...
Validation samples: 15000
[4] Creating ONNX Runtime session...
[5] Running accuracy evaluation...

Model: MobileNetV3-Small – Static INT8
Calibration samples: 200
Accuracy: 98.07%
Model size: 0.31 MB
Avg inference latency: 3.95 ms

Confusion Matrix:
[[2990    0    1    2    7]
 [   0 2998    0    2    0]
 [   0    0 2999    1    0]
 [  10  209    0 2730   51]
 [   5    2    0    0 2993]]


In [15]:
import onnxruntime as ort
import numpy as np
import time
from PIL import Image
from torchvision import transforms

# CONFIG
MODEL_PATH = "/kaggle/working/student_int8_static.onnx"
IMAGE_PATH = "/kaggle/input/rice-image-dataset/Rice_Image_Dataset/Arborio/Arborio (1000).jpg"

CLASSES = ["Arborio", "Basmati", "Ipsala", "Jasmine", "Karacadag"]
IMG_SIZE = 224
NUM_RUNS = 100    

def softmax(x):
    x = x - np.max(x, axis=1, keepdims=True)
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=1, keepdims=True)

# PREPROCESS
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

img = Image.open(IMAGE_PATH).convert("RGB")
x = transform(img).unsqueeze(0).numpy() 

# ONNX RUNTIME
print("Loading ONNX Runtime session...")
sess = ort.InferenceSession(
    MODEL_PATH,
    providers=["CPUExecutionProvider"]
)

input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

# WARM-UP
for _ in range(10):
    sess.run([output_name], {input_name: x})

# TIMING
start = time.time()
for _ in range(NUM_RUNS):
    logits = sess.run([output_name], {input_name: x})[0]
end = time.time()

avg_time_sec = (end - start) / NUM_RUNS
avg_time_ms = avg_time_sec * 1000

# PREDICTION
probs = softmax(logits)[0]
pred_id = probs.argmax()
pred_class = CLASSES[pred_id]
confidence = probs[pred_id] * 100

# OUTPUT
print("Prediction:", pred_class)
print(f"Confidence: {confidence:.2f}%")
print(f"Average inference time: {avg_time_ms:.2f} ms ({avg_time_sec:.4f} s)")

Loading ONNX Runtime session...
Prediction: Arborio
Confidence: 98.25%
Average inference time: 2.99 ms (0.0030 s)
