In [9]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.onnx

# MNIST Dataset and DataLoader for both training and testing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Define the Neural Network Model
class SimpleMNIST(nn.Module):
    def __init__(self):
        super(SimpleMNIST, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
    
class SimplestMNIST(nn.Module):
    def __init__(self):
        super(SimplestMNIST, self).__init__()
        self.fc = nn.Linear(28*28, 10)  # One fully connected layer

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.fc(x)


model = SimplestMNIST()

# Training Parameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)

# Training Loop
for epoch in range(5):  # 5 epochs for demonstration, increase as needed
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}")

# Testing Loop
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        output = model(data)
        test_loss += criterion(output, target).item()  # Sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)
print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')



Test set: Average loss: 0.0003, Accuracy: 9206/10000 (92%)



In [10]:

# Export the model to ONNX format
dummy_input = torch.randn(1, 1, 28, 28)  # Dummy input for the model
torch.onnx.export(model, dummy_input, "simple_mnist.onnx", verbose=True)

print("Model has been exported to ONNX format.")


Exported graph: graph(%onnx::Reshape_0 : Float(1, 1, 28, 28, strides=[784, 784, 28, 1], requires_grad=0, device=cpu),
      %fc.weight : Float(10, 784, strides=[784, 1], requires_grad=1, device=cpu),
      %fc.bias : Float(10, strides=[1], requires_grad=1, device=cpu)):
  %/Constant_output_0 : Long(2, strides=[1], device=cpu) = onnx::Constant[value=  -1  784 [ CPULongType{2} ], onnx_name="/Constant"](), scope: __main__.SimplestMNIST:: # /var/folders/_r/rcjcpfc15fxf66bn1kd3l1r00000gn/T/ipykernel_18344/2081141353.py:44:0
  %/Reshape_output_0 : Float(1, 784, strides=[784, 1], requires_grad=0, device=cpu) = onnx::Reshape[allowzero=0, onnx_name="/Reshape"](%onnx::Reshape_0, %/Constant_output_0), scope: __main__.SimplestMNIST:: # /var/folders/_r/rcjcpfc15fxf66bn1kd3l1r00000gn/T/ipykernel_18344/2081141353.py:44:0
  %5 : Float(1, 10, strides=[10, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1, onnx_name="/fc/Gemm"](%/Reshape_output_0, %fc.weight, %fc.bias), scope: __