#### PyTorch Lightning

In [3]:
import torch
from pytorch_lightning import LightningModule


class SimpleModel(LightningModule):
    """Single linear layer (64 -> 4) followed by ReLU, used to test Lightning's ``to_onnx``.

    ``forward`` flattens every dimension after the first (batch) dimension, so
    inputs of shape ``(N, 64)`` or e.g. ``(N, 8, 8)`` are accepted.
    """

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        # torch.flatten (unlike Tensor.view) also works on non-contiguous inputs
        # and on an empty batch (N == 0), where ``view(0, -1)`` cannot infer -1.
        return torch.relu(self.l1(torch.flatten(x, start_dim=1)))


from pathlib import Path

model = SimpleModel()
filepath = "exports/test_pytorch_lightning.onnx"
# The exporter writes straight to ``filepath``; make sure its directory exists.
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
input_sample = torch.randn((1, 64))
model.to_onnx(filepath, input_sample, export_params=True)

#### PyTorch + PyTorch Lightning

In [4]:
import torch
import torch.nn as nn
import pytorch_lightning as pl


class SimpleModel(nn.Module):
    """Two-layer MLP whose output is squared when any output element exceeds 0.5.

    Args:
        input_size: Number of input features.
        hidden_size: Width of the ReLU hidden layer.
        output_size: Number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        # The condition is evaluated over the whole tensor (every sample in the
        # batch), so a single sample can switch the transform for all of them.
        # torch.where keeps this data-dependent choice inside the graph; a
        # Python ``if`` on a tensor is frozen to one branch when the model is
        # traced (e.g. by ``to_onnx`` / ``torch.onnx.export``).
        return torch.where(x.max() > 0.5, x ** 2, x)

In [5]:
class LightningModel(pl.LightningModule):
    """LightningModule that trains ``SimpleModel`` with MSE loss and Adam.

    Args:
        input_size: Number of input features passed to ``SimpleModel``.
        hidden_size: Width of ``SimpleModel``'s hidden layer.
        output_size: Number of output features.
        learning_rate: Adam learning rate (default ``1e-3``).
    """

    def __init__(self, input_size, hidden_size, output_size, learning_rate=1e-3):
        super().__init__()
        # The plain PyTorch network lives under ``self.model``.
        self.model = SimpleModel(input_size, hidden_size, output_size)
        self.loss_fn = nn.MSELoss()
        self.learning_rate = learning_rate

    def forward(self, x):
        return self.model(x)

    def _batch_loss(self, batch):
        """Unpack an ``(inputs, targets)`` batch and return the MSE loss."""
        inputs, targets = batch
        predictions = self(inputs)
        return self.loss_fn(predictions, targets)

    def training_step(self, batch, batch_idx):
        loss = self._batch_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self._batch_loss(batch))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
from pathlib import Path

model = LightningModel(input_size=10, hidden_size=32, output_size=1)
filepath = "exports/test_pytorch_lightning_2.onnx"
# The exporter writes straight to ``filepath``; make sure its directory exists.
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
input_sample = torch.randn((1, 10))
model.to_onnx(filepath, input_sample, export_params=True)

#### Two submodels

In [None]:
import torch
import torch.nn as nn
import torch.onnx


class SubModelA(nn.Module):
    """First stage: a 10 -> 5 linear projection followed by ReLU.

    The computation is exposed as ``score`` and there is no ``forward``, so
    calling the module directly is not supported; callers use ``score``.
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def score(self, x):
        """Return ``relu(fc(x))``."""
        projected = self.fc(x)
        return projected.relu()


class SubModelB(nn.Module):
    """Second stage: a 5 -> 2 linear layer followed by softmax over dim 1."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(5, 2)

    def forward(self, x):
        logits = self.fc(x)
        return logits.softmax(dim=1)


class MainModel(nn.Module):
    """Chain two sub-models: ``model_a.score`` followed by ``model_b``.

    Both sub-models are registered in an ``nn.ModuleDict`` under the keys
    ``"model_a"`` and ``"model_b"``, so their parameters are tracked by this
    module.
    """

    def __init__(self, model_a, model_b):
        super().__init__()
        registry = {"model_a": model_a, "model_b": model_b}
        self.models = nn.ModuleDict(registry)

    def forward(self, x):
        # Stage 1 calls the custom ``score`` method directly (not ``__call__``),
        # so module hooks registered on model_a do not fire here.
        stage_one = self.models["model_a"].score(x)
        # Stage 2 is invoked as a regular sub-module.
        return self.models["model_b"](stage_one)


# Instantiate and test the model
# Build the two-stage model from fresh sub-models and run one forward pass on
# a random (1, 10) input as a quick smoke test.
model_a, model_b = SubModelA(), SubModelB()
model = MainModel(model_a=model_a, model_b=model_b)
sample_input = torch.randn(1, 10)
output = model(sample_input)
print(output)

In [10]:
from pathlib import Path

filepath = "exports/test_pytorch_2_sub_model.onnx"
# The exporter writes straight to ``filepath``; make sure its directory exists.
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
model.eval()  # Set model to evaluation mode
# Feature size (10) must match the model; the batch axis is declared dynamic below.
dummy_input = torch.randn(1, 10)
torch.onnx.export(
    model,
    dummy_input,
    filepath,
    export_params=True,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}
)

In [None]:
import onnxruntime as ort
import numpy as np

filepath = "exports/test_pytorch_2_sub_model.onnx"
# Load the ONNX model. Providers are listed explicitly because onnxruntime
# builds that ship more than one execution provider (e.g. GPU builds) refuse
# to create a session without them.
session = ort.InferenceSession(filepath, providers=["CPUExecutionProvider"])

# Take the input name from the model itself instead of hard-coding it.
input_name = session.get_inputs()[0].name

# Create a sample input matching the ONNX model input shape (batch, 10)
dummy_input = np.random.randn(1, 10).astype(np.float32)

# Run inference (None -> return every model output)
outputs = session.run(None, {input_name: dummy_input})
print("ONNX Model Output:", outputs[0])

#### Dynamo export

In [None]:
import torch
import torch.nn as nn


class MLPModel(nn.Module):
    """Four-layer MLP (8 -> 8 -> 4 -> 2 -> 2) with sigmoid between layers.

    The last layer's output is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        # Registration order (fc0..fc3) fixes parameter order and state_dict keys.
        self.fc0 = nn.Linear(8, 8, bias=True)
        self.fc1 = nn.Linear(8, 4, bias=True)
        self.fc2 = nn.Linear(4, 2, bias=True)
        self.fc3 = nn.Linear(2, 2, bias=True)

    def forward(self, tensor_x: torch.Tensor):
        hidden = tensor_x
        for layer in (self.fc0, self.fc1, self.fc2):
            hidden = torch.sigmoid(layer(hidden))
        return self.fc3(hidden)


from pathlib import Path

model = MLPModel()
tensor_x = torch.rand((97, 8), dtype=torch.float32)
# NOTE(review): torch.onnx.dynamo_export / ExportOptions are deprecated in
# recent PyTorch releases in favour of torch.onnx.export(..., dynamo=True);
# kept as-is because that keyword requires a newer PyTorch. Confirm the target
# version before migrating.
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_program = torch.onnx.dynamo_export(model, tensor_x, export_options=export_options)
output_path = Path("exports/test_pytorch_dynamo.onnx")
# save() writes straight to the path; make sure its directory exists.
output_path.parent.mkdir(parents=True, exist_ok=True)
onnx_program.save(str(output_path))