# 🚀 CIFAR-10 Image Classification with PyTorch and ONNX

## 📌 Project Description
In this project, I build a complete *end-to-end pipeline* for image classification:
1. Load and preprocess the CIFAR-10 dataset
2. Build a simple CNN in PyTorch
3. Train and evaluate the model
4. Export the trained model to *ONNX*
5. Run inference using *ONNX Runtime*
6. (Optional) Deploy the ONNX model with FastAPI

This notebook shows, *step by step*, how to go from data → training → ONNX export → deployment.

In [1]:
# ============================================
# Part 1: Install Dependencies
# ============================================
# I need PyTorch, TorchVision, ONNX, and ONNX Runtime
!pip install torch torchvision onnx onnxruntime

Collecting onnx
  Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (7.0 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.19.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (18.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18.2/18.2 MB 72.6 MB/s eta 0:00:00
Downloading onnxruntime-1.22.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (16.5 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16.5/16.5 MB 50.0 MB/s eta 0:00:00
Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 

# Part 2: Import Libraries
Here I import the libraries required for:
- *PyTorch* (training and model building)
- *TorchVision* (dataset and transforms)
- *ONNX & ONNX Runtime* (model export and inference)
- *PIL / NumPy* (image handling)

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import onnx
import onnxruntime as ort
import numpy as np
from PIL import Image

# Part 3: Data Preparation (CIFAR-10)
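- The CIFAR-10 dataset has *60,000 images (32x32, RGB, 10 classes)*: 50,000 for training and 10,000 for testing.
- I apply two transforms:
  - *ToTensor* (convert images to PyTorch tensors with values in [0, 1])
  - *Normalize* with mean 0.5 and std 0.5 (map pixel values from [0, 1] to [-1, 1])
- I use `DataLoader` to batch the data (batch size 64) and shuffle the training set.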
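As a quick check of the normalization arithmetic: `ToTensor` divides 8-bit pixel values by 255, and `Normalize` with mean 0.5 and std 0.5 then computes (x - 0.5) / 0.5 for each channel. The helper below is a hypothetical pure-Python illustration of that mapping for a single channel value, not a torchvision API:

```python
def to_normalized(pixel: int, mean: float = 0.5, std: float = 0.5) -> float:
    """Mimic ToTensor + Normalize for one 8-bit channel value (illustrative only)."""
    if not 0 <= pixel <= 255:
        raise ValueError("pixel must be in [0, 255]")
    x = pixel / 255.0          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - mean) / std    # Normalize: [0.0, 1.0] -> [-1.0, 1.0]


for p in (0, 128, 255):
    print(p, round(to_normalized(p), 4))
```

Black (0) maps to -1, white (255) maps to 1, and mid-grey lands close to 0, which keeps the inputs roughly centered for training.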

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Train set
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True)

# Test set
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False)

# CIFAR-10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

100%|██████████| 170M/170M [00:04<00:00, 41.9MB/s]


# Part 4: Define the CNN Model
I create a *simple CNN*:
1. Conv2d(3→32, 3×3, padding 1) → ReLU → MaxPool(2×2)
2. Conv2d(32→64, 3×3, padding 1) → ReLU → MaxPool(2×2)
3. Flatten (64 × 8 × 8 = 4096 features) → Fully connected (128 units) → ReLU
4. Output layer (10 classes, raw logits)
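The `64*8*8` input size of `fc1` follows from the usual output-size formula for convolution and pooling, floor((n + 2p - k) / s) + 1. A 3×3 convolution with padding 1 and stride 1 keeps the spatial size unchanged, and each 2×2 max pool with stride 2 halves it. This small pure-Python sketch (the helper name `out_size` is hypothetical) traces the spatial size through the network:

```python
def out_size(n: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one spatial axis for Conv2d / MaxPool2d (dilation 1, floor mode)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 32                                      # CIFAR-10 images are 32x32
size = out_size(size, kernel=3, padding=1)     # conv1: 32 -> 32
size = out_size(size, kernel=2, stride=2)      # pool:  32 -> 16
size = out_size(size, kernel=3, padding=1)     # conv2: 16 -> 16
size = out_size(size, kernel=2, stride=2)      # pool:  16 -> 8
flat_features = 64 * size * size               # 64 channels come out of conv2
print(size, flat_features)
```

The result matches `in_features=4096` reported for `fc1` in the model summary.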

In [4]:
class CNNClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64*8*8, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64*8*8)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Instantiate
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNClassifier().to(device)

# Check parameters
params = list(model.parameters())
print(model)
print(f"Number of parameter tensors: {len(params)}")
print(f"Shape of first weight tensor: {params[0].shape}")


CNNClassifier(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=4096, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
Number of parameter tensors: 8
Shape of first weight tensor: torch.Size([32, 3, 3, 3])
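The eight parameter tensors are one weight and one bias for each of the four layers. Their total size can be worked out by hand, as in the sketch below; in PyTorch, `sum(p.numel() for p in model.parameters())` should give the same total.

```python
import math

# (layer name, weight shape); each layer also has one bias per output unit
layers = [
    ("conv1", (32, 3, 3, 3)),    # (out_channels, in_channels, kH, kW)
    ("conv2", (64, 32, 3, 3)),
    ("fc1", (128, 64 * 8 * 8)),  # (out_features, in_features)
    ("fc2", (10, 128)),
]

total = 0
for name, shape in layers:
    count = math.prod(shape) + shape[0]  # weights + biases
    total += count
    print(f"{name}: {count:,} parameters")

print(f"Total: {total:,} parameters")
```

About 96% of the 545,098 parameters sit in `fc1`, so the dense layer after flattening dominates the model size.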


# Part 5: Training Loop
- Loss function: *CrossEntropyLoss*
- Optimizer: *Adam* (learning rate 0.001)
- Train for 50 epochs, printing the average training loss after each epoch
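`nn.CrossEntropyLoss` takes raw logits (the model applies no softmax) and combines log-softmax with negative log-likelihood, averaged over the batch by default. Here is a NumPy sketch of the same computation, using the log-sum-exp trick for numerical stability; it illustrates the math and is not PyTorch's implementation:

```python
import numpy as np


def cross_entropy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Mean cross-entropy for raw logits of shape (N, C) and integer labels of shape (N,)."""
    # Subtract each row's max before exponentiating to avoid overflow (log-sum-exp trick)
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each sample
    picked = log_probs[np.arange(len(labels)), labels]
    return float(-picked.mean())


# Uniform logits over 10 classes give the chance-level loss log(10), about 2.303
print(cross_entropy(np.zeros((4, 10)), np.array([0, 3, 5, 9])))
```

For reference, the first-epoch loss in the log (1.332) is already well below this chance-level value.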

In [5]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

epochs = 50
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"[Epoch {epoch+1}] Loss: {running_loss/len(trainloader):.3f}")

print("✅ Training finished")

[Epoch 1] Loss: 1.332
[Epoch 2] Loss: 0.959
[Epoch 3] Loss: 0.797
[Epoch 4] Loss: 0.682
[Epoch 5] Loss: 0.580
[Epoch 6] Loss: 0.483
[Epoch 7] Loss: 0.399
[Epoch 8] Loss: 0.310
[Epoch 9] Loss: 0.242
[Epoch 10] Loss: 0.185
[Epoch 11] Loss: 0.144
[Epoch 12] Loss: 0.118
[Epoch 13] Loss: 0.098
[Epoch 14] Loss: 0.086
[Epoch 15] Loss: 0.079
[Epoch 16] Loss: 0.066
[Epoch 17] Loss: 0.075
[Epoch 18] Loss: 0.063
[Epoch 19] Loss: 0.057
[Epoch 20] Loss: 0.058
[Epoch 21] Loss: 0.051
[Epoch 22] Loss: 0.055
[Epoch 23] Loss: 0.046
[Epoch 24] Loss: 0.052
[Epoch 25] Loss: 0.050
[Epoch 26] Loss: 0.044
[Epoch 27] Loss: 0.040
[Epoch 28] Loss: 0.043
[Epoch 29] Loss: 0.047
[Epoch 30] Loss: 0.038
[Epoch 31] Loss: 0.046
[Epoch 32] Loss: 0.047
[Epoch 33] Loss: 0.033
[Epoch 34] Loss: 0.045
[Epoch 35] Loss: 0.046
[Epoch 36] Loss: 0.033
[Epoch 37] Loss: 0.034
[Epoch 38] Loss: 0.042
[Epoch 39] Loss: 0.039
[Epoch 40] Loss: 0.027
[Epoch 41] Loss: 0.033
[Epoch 42] Loss: 0.039
[Epoch 43] Loss: 0.038
[Epoch 44] Loss: 0.0

# Part 6: Model Evaluation
I compute the *classification accuracy on the 10,000 test images*, with the model in evaluation mode and gradient tracking disabled.

In [6]:
correct, total = 0, 0
model.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 69.83%
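The evaluation cell treats the index of the largest logit in each row as the prediction (`torch.max(outputs, 1)` returns both the max values and their indices) and counts how many match the labels. The same logic in NumPy, shown on made-up toy logits:

```python
import numpy as np


def accuracy(logits: np.ndarray, labels: np.ndarray) -> float:
    """Percentage of rows whose highest-scoring class equals the label."""
    predicted = logits.argmax(axis=1)   # same role as the indices from torch.max(outputs, 1)
    return 100.0 * float((predicted == labels).mean())


toy_logits = np.array([[2.0, 0.1, -1.0],   # predicts class 0
                       [0.3, 0.2, 1.5],    # predicts class 2
                       [0.0, 3.0, 0.5]])   # predicts class 1
toy_labels = np.array([0, 1, 1])
print(f"Accuracy: {accuracy(toy_logits, toy_labels):.2f}%")
```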


# Part 7: Export to ONNX
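I export the trained PyTorch model to ONNX:
- Example input shape: (1, 3, 32, 32), with the batch dimension marked as dynamic so the exported model accepts any batch size
- Save as `cnn_cifar10.onnx` and validate it with the ONNX checker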
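Because dimension 0 is declared dynamic, the exported model should accept input of shape (N, 3, 32, 32) for any N, not only N = 1. ONNX Runtime takes NumPy `float32` arrays in NCHW order. The sketch below (the helper `make_batch` is hypothetical) stacks several already-normalized CHW images into one batch; random arrays stand in for real preprocessed images:

```python
import numpy as np


def make_batch(images: list) -> np.ndarray:
    """Stack CHW arrays of shape (3, 32, 32) into an NCHW float32 batch."""
    for img in images:
        if img.shape != (3, 32, 32):
            raise ValueError(f"expected (3, 32, 32), got {img.shape}")
    return np.stack(images, axis=0).astype(np.float32)


# Hypothetical stand-ins for preprocessed images (values in [-1, 1])
rng = np.random.default_rng(0)
images = [rng.uniform(-1.0, 1.0, size=(3, 32, 32)) for _ in range(5)]
batch = make_batch(images)
print(batch.shape, batch.dtype)

# With the ONNX Runtime session created in Part 8, something like:
# logits = session.run(None, {"input": batch})[0]   # one row of 10 logits per image
```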

In [24]:
# Optional tools for visualizing the exported graph (uncomment to use):
# !pip install --upgrade onnx onnxscript
# !pip install netron --quiet
# import netron
# from google.colab import output

model.eval()  # make sure the model is in inference mode before export
dummy_input = torch.randn(1, 3, 32, 32, device=device)
onnx_path = "cnn_cifar10.onnx"
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    input_names=["input"],
    output_names=["output"],
    # Mark dimension 0 as dynamic so the ONNX model accepts any batch size
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=11,
)
print(f"✅ Model exported to ONNX: {onnx_path}")

# Validate the structure of the exported ONNX model
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("✅ ONNX model checked and valid")

# Save a copy for Netron (optional visualization below)
onnx.save(onnx_model, "ms.onnx")
# Alternative: the TorchDynamo-based exporter
# exported = torch.onnx.dynamo_export(model, dummy_input)
# exported.save("ms.onnx")

# Launch Netron on port 8081 and display it in an iframe directly in Colab
# get_ipython().system_raw("netron ms.onnx --port 8081 &")
# output.serve_kernel_port_as_iframe(8081)

✅ Model exported to ONNX: cnn_cifar10.onnx
✅ ONNX model checked and valid




# Part 8: Inference with ONNX Runtime
I load the exported ONNX model into an ONNX Runtime session and run inference on the first test image.

In [8]:
# Create ONNX Runtime session
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Take one image from test set
image, label = testset[0]
img_numpy = image.unsqueeze(0).numpy().astype(np.float32)

# Run inference
inputs = {session.get_inputs()[0].name: img_numpy}
outputs = session.run(None, inputs)
pred = np.argmax(outputs[0], axis=1)

print(f"True label: {label} ({classes[label]}), Predicted: {pred[0]} ({classes[pred[0]]})")

True label: 3 (cat), Predicted: 3 (cat)
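The cell above reports only the arg-max class. Since the model outputs raw logits, a softmax turns them into probabilities if you also want a confidence score or the top few candidates. A NumPy sketch; the `top_k` helper is hypothetical and the logits are made up to stand in for `outputs[0][0]`:

```python
import numpy as np

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')


def top_k(logits: np.ndarray, k: int = 3) -> list:
    """Return the k most likely (class_name, probability) pairs for one logit vector."""
    shifted = logits - logits.max()                  # numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()  # softmax
    best = np.argsort(probs)[::-1][:k]               # indices by descending probability
    return [(classes[i], float(probs[i])) for i in best]


# Made-up logits standing in for outputs[0][0] from session.run
fake_logits = np.array([0.1, -1.2, 0.4, 3.0, 0.2, 2.1, -0.5, 0.0, -2.0, -1.0])
for name, p in top_k(fake_logits):
    print(f"{name}: {p:.3f}")
```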


# Part 9: Compare PyTorch vs ONNX Prediction
As a quick sanity check that the export preserved the model's behavior, I compare the PyTorch and ONNX Runtime predictions for the same test image.

In [9]:
model.eval()
with torch.no_grad():
    torch_out = model(image.unsqueeze(0).to(device))
    torch_pred = torch.argmax(torch_out, dim=1).cpu().numpy()

print(f"PyTorch prediction: {torch_pred[0]}, ONNX prediction: {pred[0]}")

PyTorch prediction: 3, ONNX prediction: 3
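Matching arg-max classes on a single image can hide small numerical drift. A stricter check, similar to the one in PyTorch's ONNX export tutorial, compares the raw logits element-wise within a tolerance. The sketch below uses `np.testing.assert_allclose`; the tolerances are assumptions, and in this notebook the two inputs would be `torch_out.cpu().numpy()` and `outputs[0]`. The example arrays are hypothetical:

```python
import numpy as np


def logits_match(torch_logits: np.ndarray, onnx_logits: np.ndarray,
                 rtol: float = 1e-3, atol: float = 1e-5) -> bool:
    """Return True if the two logit arrays agree element-wise within tolerance."""
    try:
        np.testing.assert_allclose(torch_logits, onnx_logits, rtol=rtol, atol=atol)
    except AssertionError as err:
        print(err)   # prints NumPy's mismatch summary
        return False
    return True


# Hypothetical logits: identical up to tiny float32 rounding differences
a = np.array([[0.12, -1.30, 2.50]], dtype=np.float32)
b = a + np.float32(1e-7)
print(logits_match(a, b))          # tiny drift is within tolerance
print(logits_match(a, a + 0.1))    # larger drift is reported as a mismatch
```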


# FastAPI Deployment of ONNX Model and Local Testing

## Description
This part assumes you already have the ONNX model (`cnn_cifar10.onnx`) exported from PyTorch above.  
I will:
1. Load the ONNX model using `onnxruntime`.
2. Define a FastAPI application in Colab.
3. Expose a `/predict` endpoint that receives an uploaded image and returns the predicted class.
4. Test the API from Python in the same notebook.

No external tunneling (ngrok) is required: the server runs inside the Colab runtime, and the notebook calls it at `http://127.0.0.1:8000`.
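For reference, the `transform` used by the endpoint resizes the upload to 32×32, converts it to a [0, 1] tensor in CHW order, and normalizes it to [-1, 1]. The last two steps can be reproduced with NumPy, which shows the exact array ONNX Runtime receives (and could let a server skip importing torchvision). This is a sketch: `to_model_input` is a hypothetical name, and it assumes the image was already converted to RGB and resized with PIL; torchvision's `Resize` defaults to bilinear interpolation, so `Image.BILINEAR` is the closest PIL choice.

```python
import numpy as np


def to_model_input(rgb: np.ndarray) -> np.ndarray:
    """Convert a (32, 32, 3) uint8 RGB array into a (1, 3, 32, 32) float32 model input.

    Mirrors ToTensor() followed by Normalize(mean=0.5, std=0.5).
    """
    if rgb.shape != (32, 32, 3) or rgb.dtype != np.uint8:
        raise ValueError("expected a (32, 32, 3) uint8 array")
    x = rgb.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = x.transpose(2, 0, 1)             # HWC -> CHW, as ToTensor does
    x = (x - 0.5) / 0.5                  # Normalize: [0, 1] -> [-1, 1]
    return x[np.newaxis, ...]            # add the batch dimension


# With PIL, something like:
# rgb = np.asarray(Image.open(path).convert("RGB").resize((32, 32), Image.BILINEAR))
dummy = np.full((32, 32, 3), 255, dtype=np.uint8)
print(to_model_input(dummy).shape)
```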

In [10]:
# ===================================================
# FastAPI Deployment of ONNX Model + Local Testing
# ===================================================

# Part 1: Install required libraries
!pip install fastapi uvicorn nest-asyncio onnxruntime Pillow torchvision requests -q

# Part 2: Import libraries
import io
import numpy as np
from PIL import Image
import onnxruntime as ort
import torchvision.transforms as transforms
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import nest_asyncio
import uvicorn
import threading
import time
import requests

# Part 3: Load ONNX Model
# Load the ONNX model using ONNX Runtime for inference
onnx_path = "cnn_cifar10.onnx"  # path to the exported ONNX model
session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

# Part 4: Define Preprocessing Function
# CIFAR-10 images need to be resized to 32x32, converted to tensor, and normalized
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # resize image to 32x32
    transforms.ToTensor(),        # convert to tensor [0,1]
    transforms.Normalize((0.5,), (0.5,))  # normalize
])

# CIFAR-10 class labels
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# Part 5: Create FastAPI App with /predict Endpoint
# Enable nested event loops to run FastAPI in Colab
nest_asyncio.apply()

app = FastAPI(title="CIFAR-10 ONNX Classifier")

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Accept an uploaded image and return the predicted CIFAR-10 class.
    """
    # Read image bytes
    image_bytes = await file.read()
    # Convert to PIL Image and RGB
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # Apply preprocessing
    img_tensor = transform(image).unsqueeze(0).numpy().astype(np.float32)
    # Run ONNX inference
    inputs = {session.get_inputs()[0].name: img_tensor}
    outputs = session.run(None, inputs)
    # Get predicted class index
    pred = int(np.argmax(outputs[0], axis=1)[0])
    # Return JSON response
    return JSONResponse({"class_id": pred, "class_name": classes[pred]})

# Part 6: Run FastAPI in Background Thread
def run_api():
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Start FastAPI in a background thread
thread = threading.Thread(target=run_api, daemon=True)
thread.start()
time.sleep(2)  # wait for server to start
print("FastAPI server is running locally on http://127.0.0.1:8000")

# Part 7: Test API from Python in the Same Notebook
# Upload a sample image from your computer
from google.colab import files
uploaded = files.upload()  # opens a file picker; returns {filename: bytes}
image_path = list(uploaded.keys())[0]

# Send the image to /predict as multipart form data
# (the form fields are passed inline so the google.colab `files` module is not shadowed)
with open(image_path, "rb") as f:
    response = requests.post("http://127.0.0.1:8000/predict", files={"file": f})

print("Prediction result:", response.json())

INFO:     Started server process [609]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


FastAPI server is running locally on http://127.0.0.1:8000


Saving plane_image.png to plane_image.png
INFO:     127.0.0.1:56548 - "POST /predict HTTP/1.1" 200 OK
Prediction result: {'class_id': 0, 'class_name': 'plane'}
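The server cell waits a fixed `time.sleep(2)` for Uvicorn to start, which may be too short on a slow runtime. A more robust alternative is to poll the port until it accepts TCP connections. This sketch uses only the standard library; the helper name `wait_for_port` and its timeouts are assumptions:

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 10.0, interval: float = 0.1) -> bool:
    """Poll until a TCP connection to (host, port) succeeds or `timeout` seconds pass."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=interval):
                return True          # something is listening on the port
        except OSError:
            time.sleep(interval)     # not up yet; retry shortly
    return False


# In the server cell, this could replace time.sleep(2):
# if not wait_for_port("127.0.0.1", 8000):
#     raise RuntimeError("FastAPI server did not start in time")
```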
