## Convert model to ONNX

## Import libraries

In [3]:
import os
import boto3
import random
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
from torch.autograd import Variable

## Load the trained model

In [4]:
# MNIST dataset parameters.
num_classes = 10  # total classes (0-9 digits).


class CNN(nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Each stage is ``Conv2d(kernel_size=5, stride=1, padding=2) -> ReLU ->
    MaxPool2d(2)``; channels go 1 -> 32 -> 64. The padding keeps the spatial
    size through each convolution and each pool halves it, so a 28x28 input
    becomes 14x14 and then 7x7. The final linear layer maps the flattened
    ``64 * 7 * 7`` features to ``num_classes`` logits.

    Submodule names (``conv1``, ``conv2``, ``out``) determine the state_dict
    keys, so they must stay stable for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # fully connected layer, output `num_classes` (10) classes
        self.out = nn.Linear(64 * 7 * 7, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: input batch, assumed to be (batch_size, 1, 28, 28). The spatial
                size must reduce to 7x7 after the two pooling stages, or the
                linear layer's input size will not match.

        Returns:
            Tuple ``(output, x)``:
            - ``output``: logits of shape (batch_size, num_classes).
            - ``x``: flattened conv2 features of shape (batch_size, 64 * 7 * 7),
              returned for visualization.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # flatten the output of conv2 to (batch_size, 64 * 7 * 7)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output, x  # return x for visualization

In [7]:
# Path to data
directory_path = Path.cwd().parents[0]
# TRAINED_MODEL_PATH is joined onto the parent of the current working directory;
# the "../models" default is preserved. Environment values are already str.
trained_model_path = directory_path.joinpath(
    os.environ.get("TRAINED_MODEL_PATH", "../models")
)

path = trained_model_path.joinpath("torch-210921163030-5341ad0f6f389a55")
trained_model = CNN()
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only host; load_state_dict copies the values into the CPU model anyway.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
trained_model.load_state_dict(
    torch.load(path / "pytorch_model.pt", map_location="cpu")
)

<All keys matched successfully>

## Export the model to ONNX (and optionally upload it to object storage)

In [9]:
use_ceph = bool(int(os.getenv("USE_CEPH", 0)))
automation = bool(int(os.getenv("AUTOMATION", 0)))


# Version id: timestamp plus 64 random bits rendered as exactly 16 hex digits.
# The zero padding keeps every version string the same length.
time_version = f"torch-{datetime.now():%y%m%d%H%M%S}-{random.getrandbits(64):016x}"

# Path to data
directory_path = Path.cwd().parents[0]
trained_model_path = directory_path.joinpath(
    os.environ.get("TRAINED_MODEL_PATH", "../models")
)
# Single definition of the exported file, shared by the export and the upload
# so the two cannot disagree on the file name.
onnx_path = trained_model_path / f"{time_version}-model.onnx"

# torch.onnx.export traces the model with an example input of shape
# (batch=1, channels=1, 28, 28). Variable() is unnecessary: plain tensors
# work directly.
dummy_input = torch.randn(1, 1, 28, 28)
trained_model.eval()
torch.onnx.export(trained_model, dummy_input, str(onnx_path))

if automation or use_ceph:
    # Upload the exported ONNX model to S3-compatible object storage.
    s3_endpoint_url = os.environ["OBJECT_STORAGE_ENDPOINT_URL"]
    s3_access_key = os.environ["AWS_ACCESS_KEY_ID"]
    s3_secret_key = os.environ["AWS_SECRET_ACCESS_KEY"]
    s3_bucket = os.environ["OBJECT_STORAGE_BUCKET_NAME"]
    # NOTE(review): `project_name` was previously referenced without ever being
    # defined (NameError). It is now read from the environment like the other
    # storage settings - confirm PROJECT_NAME is set in the pipeline.
    project_name = os.environ["PROJECT_NAME"]

    # Create an S3 client
    s3 = boto3.client(
        service_name="s3",
        aws_access_key_id=s3_access_key,
        aws_secret_access_key=s3_secret_key,
        endpoint_url=s3_endpoint_url,
    )

    # Key on the file name only. Interpolating the full local path would embed
    # the host's directory layout (including "..") into the object key.
    key = f"{project_name}/models/{onnx_path.name}"
    print(key)
    s3.upload_file(Bucket=s3_bucket, Key=key, Filename=str(onnx_path))