In [1]:
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import onnx
import io
import netron
import evaluate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class XSumNet(torch.nn.Module):
    """Two-layer fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of features in each input row (size of the last dim of ``model_input``).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output logits per row.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int):
        super().__init__()

        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)

    def forward(self, model_input: torch.Tensor) -> torch.Tensor:
        """Map ``(*, input_size)`` inputs to ``(*, num_classes)`` logits.

        No softmax is applied: the result is raw logits, suitable for a loss that
        takes unnormalised scores (e.g. cross-entropy).
        """
        out = self.fc1(model_input)
        out = self.relu(out)
        out = self.fc2(out)
        return out

# Seed torch's global RNG so that the random weight initialisation below, and later
# torch.randn calls such as the dummy export input, are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = "cpu"
# NOTE(review): input_size=784 / output_size=10 presumably mean flattened 28x28 images
# and 10 classes (e.g. MNIST, given the torchvision import) — confirm against the data used.
batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
pt_model = XSumNet(input_size, hidden_size, output_size).to(device)

In [3]:
# Export pt_model to an ONNX graph held in memory (no file is written here).
# Training mode and disabled constant folding keep the parameters as named graph
# initializers instead of folding them into constants. This is presumably what
# generate_artifacts relies on to find the `requires_grad` names — confirm.
model_inputs = (torch.randn(batch_size, input_size, device=device),)  # tracing input only

input_names = ["input"]
output_names = ["output"]
# Leave dim 0 symbolic on both ends so the graph accepts any batch size.
dynamic_axes = {tensor_name: {0: "batch_size"} for tensor_name in (*input_names, *output_names)}

f = io.BytesIO()
torch.onnx.export(
    pt_model,
    model_inputs,
    f,
    export_params=True,
    keep_initializers_as_inputs=False,
    training=torch.onnx.TrainingMode.TRAINING,
    do_constant_folding=False,
    opset_version=14,
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)
serialized_model = f.getvalue()
onnx_model = onnx.load_model_from_string(serialized_model)

In [4]:
# Split parameter names by trainability. Names with requires_grad=True are trained;
# the rest are passed as frozen. Nothing above freezes a parameter, so frozen_params
# is expected to be empty here.
param_trainable = {name: param.requires_grad for name, param in pt_model.named_parameters()}
requires_grad = [name for name, trainable in param_trainable.items() if trainable]
frozen_params = [name for name, trainable in param_trainable.items() if not trainable]

# Build the training artifacts from the exported graph: an AdamW optimizer and a
# cross-entropy loss over the model's logits. No artifact_directory is passed, so the
# output location is the library default (presumably the working directory — confirm).
# NOTE(review): additional_output_names presumably keeps the raw "output" tensor as an
# extra graph output alongside the loss — confirm against the onnxruntime docs.
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    additional_output_names=output_names,
)

2024-08-01 17:54:08.466003 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ConstantSharing modified: 0 with status: OK
2024-08-01 17:54:08.466280 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer ShapeInputMerge modified: 0 with status: OK
2024-08-01 17:54:08.466641 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer LayerNormFusion modified: 0 with status: OK
2024-08-01 17:54:08.466711 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer CommonSubexpressionElimination modified: 0 with status: OK
2024-08-01 17:54:08.466717 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer GeluFusion modified: 0 with status: OK
2024-08-01 17:54:08.466723 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer SimplifiedLayerNormFusion modified: 0 with status: OK
2024-08-01 17:54:08.466727 [I:onnxruntime:Default, graph_transformer.cc:15 Apply] GraphTransformer FastGeluFusion modified

In [5]:
# Serve the generated eval graph in the Netron viewer; the call returns the (host, port) served.
# NOTE(review): relative path — assumes generate_artifacts wrote its files to the
# current working directory (no artifact_directory was passed); confirm.
eval_model_path = "eval_model.onnx"
netron.start(eval_model_path)

Serving 'eval_model.onnx' at http://localhost:8080


('localhost', 8080)