In [10]:
# Step 1 - Import
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import optuna
from torch.ao.quantization import (
    prepare_qat, convert, get_default_qat_qconfig
)

import torch.onnx
print("Torch version", torch.__version__)

ModuleNotFoundError: No module named 'torch'

In [None]:
# Step 2 - Dataset & Transform
# IMPORTANT: For 3-class training, your folders should be:
#   myImages/train/cats
#   myImages/train/dogs
#   myImages/train/others
#   myImages/test/cats
#   myImages/test/dogs
#   myImages/test/others
# ImageFolder will discover classes in alphabetical order, so with
# ['cats', 'dogs', 'others'] we get:
#   index 0 -> cats, index 1 -> dogs, index 2 -> others
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor()
])

train_dataset = datasets.ImageFolder("myImages/train", transform=transform)
test_dataset = datasets.ImageFolder("myImages/test", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print("Classes:", train_dataset.classes)

Classes: ['cats', 'dogs']


In [None]:
# Step 3 - Model Definition
# NOTE:
# - We want 3 explicit classes: ['cats', 'dogs', 'others'].
# - So the final layer should output 3 logits (for each class).
# - ImageFolder will map folders in alphabetical order, so make sure your
#   train/test folders look like:
#     myImages/train/cats
#     myImages/train/dogs
#     myImages/train/others
#     myImages/test/cats
#     myImages/test/dogs
#     myImages/test/others
#   and we will interpret index 0 -> cats, 1 -> dogs, 2 -> others.
class SimpleClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        # 3 outputs: 0=cat, 1=dog, 2=other
        self.fc2 = nn.Linear(128, 3)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [None]:
# Step 4 - OPTUNA(Hyperparameter Tuning)
# Its good to choose best lr (learning rate) not like a fixed always, Optuna tries multiple learning rates automatically and picks best.
# Define Objective Function
def objective(trial):
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)

    model = SimpleClassifier()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(3):
        for images, labels in train_loader:
            optimizer.zero_grad()
            output = model(images)
            loss = loss_fn(output, labels)
            loss.backward()
            optimizer.step()

    # Evaluate
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)

    accuracy = correct / total
    return accuracy


In [None]:
# Run Optuna
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=10)

print("Best LR:", study.best_params)

[I 2026-01-01 19:20:54,714] A new study created in memory with name: no-name-52b93d62-17d7-4e3b-b9e7-0b9211d63af3
[I 2026-01-01 19:20:54,912] Trial 0 finished with value: 1.0 and parameters: {'lr': 0.00020893793011634825}. Best is trial 0 with value: 1.0.
[I 2026-01-01 19:20:55,099] Trial 1 finished with value: 0.5 and parameters: {'lr': 0.00012275960600518357}. Best is trial 0 with value: 1.0.
[I 2026-01-01 19:20:55,285] Trial 2 finished with value: 0.5 and parameters: {'lr': 0.00025534445320197726}. Best is trial 0 with value: 1.0.
[I 2026-01-01 19:20:55,469] Trial 3 finished with value: 0.5 and parameters: {'lr': 0.008656909350357936}. Best is trial 0 with value: 1.0.
[I 2026-01-01 19:20:55,663] Trial 4 finished with value: 0.5 and parameters: {'lr': 0.00020078310444734574}. Best is trial 0 with value: 1.0.
[I 2026-01-01 19:20:55,880] Trial 5 finished with value: 1.0 and parameters: {'lr': 0.00048144570991253213}. Best is trial 0 with value: 1.0.
[I 2026-01-01 19:20:56,067] Trial 6 

Best LR: {'lr': 0.00020893793011634825}


In [None]:
# STEP 5 - Train Final Model using best LR
best_lr = study.best_params["lr"]

model = SimpleClassifier()
optimizer = optim.Adam(model.parameters(), lr=best_lr)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(5):
    total_loss = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")


Epoch 1 Loss: 2.3864
Epoch 2 Loss: 2.1350
Epoch 3 Loss: 1.9161
Epoch 4 Loss: 1.7318
Epoch 5 Loss: 1.5756


In [None]:
# Save float32 before QAT 
torch.save(model.state_dict(), "fp32_model.pth")

In [None]:
# STEP 6 - Quantization Aware Training (QAT)
# PREPARE Model for QAT
model.train()
model.qconfig = get_default_qat_qconfig("fbgemm")

prepare_qat(model, inplace=True)

# Fine-tune with QAT
for epoch in range(3):
    for images, labels in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  prepare_qat(model, inplace=True)


In [None]:
# after QAT training reload float32 Model
fp32_model = SimpleClassifier()
fp32_model.load_state_dict(torch.load("fp32_model.pth"))
fp32_model.eval()


SimpleClassifier(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
# Step 7 - Evaluate after QAT
# STEP 7 - Evaluate QAT model
correct = 0
total = 0
model.eval()

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        predictions = torch.argmax(outputs, dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

print("QAT Model Accuracy:", correct / total)


QAT Model Accuracy: 0.5


In [None]:
# STEP 8 - Test single image
# Single image prediction
example_img, example_label = test_dataset[1]

model.eval()
with torch.no_grad():
    output = model(example_img.unsqueeze(0))
    pred = torch.argmax(output)

print("True Label:", example_label)
print("Predicted:", pred.item())


True Label: 1
Predicted: 1


In [None]:
# Step 9 - Convert to Quantized Model
#quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.ao.quantization.convert(model.eval(), inplace=False)


In [None]:
# Finally export model (ONNX for deployment)
# Export as a single, self-contained ONNX file (no external .data file),
# so it can be loaded directly in the browser with onnxruntime-web.

dummy_input = torch.randn(1, 1, 28, 28)

torch.onnx.export(
    fp32_model,
    dummy_input,
    "cat-dog_classification.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=18,
    use_external_data_format=False,
)

print("ONNX model exported successfully!")


[torch.onnx] Obtain model graph for `SimpleClassifier([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `SimpleClassifier([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
ONNX model exported successfully!
