<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Exporting_and_Serving_a_PyTorch_Model_with_ONNX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install onnx

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define a simple neural network
class SimpleNet(nn.Module):
    """Small two-convolution classifier for single-channel 28x28 images.

    Layers: conv(1->32, k=3) -> ReLU -> conv(32->64, k=3) -> ReLU
    -> flatten -> linear(64*24*24 -> 128) -> ReLU -> linear(128 -> 10).

    Args:
        verbose: When True, print the intermediate tensor shapes on every
            forward pass (handy while debugging layer sizes). Defaults to
            False so that training loops and ONNX tracing are not flooded
            with output.
    """

    def __init__(self, verbose: bool = False):
        super().__init__()
        self.verbose = verbose
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        # Each unpadded 3x3 convolution shrinks H and W by 2 (28 -> 26 -> 24),
        # so the flattened feature map holds 64 * 24 * 24 values per sample.
        self.fc1 = nn.Linear(64 * 24 * 24, 128)
        self.fc2 = nn.Linear(128, 10)

    def _log_shape(self, stage: str, x: torch.Tensor) -> None:
        """Print the shape of ``x`` after ``stage`` when verbose mode is on."""
        if self.verbose:
            print(f'Shape after {stage}: {x.shape}')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class logits of shape (batch, 10).

        Args:
            x: Image batch. The spatial size must be 28x28 with one channel,
                i.e. (batch, 1, 28, 28), for fc1's input size to match.
        """
        x = torch.relu(self.conv1(x))
        self._log_shape('conv1', x)
        x = torch.relu(self.conv2(x))
        self._log_shape('conv2', x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        self._log_shape('flatten', x)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Load data: MNIST digits, scaled to [0, 1] by ToTensor and then to [-1, 1]
# by Normalize(mean=0.5, std=0.5).
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Instantiate and train the model
model = SimpleNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop (simplified for brevity)
num_epochs = 1
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Report the mean batch loss so it is visible that training progressed.
    mean_loss = running_loss / max(1, len(train_loader))
    print(f'Epoch {epoch + 1}/{num_epochs}: mean training loss {mean_loss:.4f}')

# Put the model in inference mode before exporting, so any train-only
# behavior (dropout, batch-norm statistic updates) is not baked into the graph.
model.eval()

# Dummy input for ONNX export; its values are irrelevant, only the shape is used.
dummy_input = torch.randn(1, 1, 28, 28)

# Export the model to an ONNX file. Marking dim 0 as dynamic lets the served
# model accept any batch size rather than only the dummy input's batch of 1.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
)

print("Model has been exported to ONNX format as 'model.onnx'")