# Load an MLP model and convert it to an ONNX file
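Please complete the following steps before running this notebook:
1. Clone the repo by running `git clone https://github.com/abidihaider/RealTimeAlignment.git`.
2. Check out the `develop` branch of the repo (`git checkout develop`).
3. Run `python setup.py develop`.

The notebook reads `config.yaml` from the current working directory, so launch it from a directory that contains that file.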

In [1]:
import sys
import yaml
import torch
from mlp import MLP

In [2]:
# Build the MLP from the `model` section of config.yaml and switch it to
# inference mode so export captures eval-time behavior.

# Seed torch's global RNG so randomly-initialized layers (e.g. the Linear
# layers shown below) get identical values on every Restart & Run All;
# otherwise each run exports a different set of weights.
SEED = 0
torch.manual_seed(SEED)

with open('config.yaml', 'r', encoding='UTF-8') as handle:
    config = yaml.safe_load(handle)

model = MLP(**config['model'])
# NOTE(review): no checkpoint is loaded here, so `model` keeps the weights
# its constructor initializes. If trained weights are intended, load them
# (e.g. via `model.load_state_dict(...)`) before exporting — TODO confirm.

model.eval()

MLP(
  (embed): Sequential(
    (0): Linear(in_features=6, out_features=256, bias=True)
    (1): SineActivation()
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): SineActivation()
  )
  (solvers): ModuleList(
    (0-2): 3 x SubsetSolver(
      (model): Sequential(
        (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=768, out_features=256, bias=True)
        (2): SineActivation()
        (3): ResLinear(
          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (linear): Linear(in_features=256, out_features=256, bias=True)
          (activ): SineActivation()
        )
      )
    )
  )
  (output): Linear(in_features=256, out_features=27, bias=True)
)

In [3]:
# Trace `model` with an example input and write the graph plus its current
# parameter values to ONNX_PATH.
ONNX_PATH = "mlp.onnx"

# Example input of shape (1, num_particles, in_features), sized from config.
in_features = config['model']['in_features']
num_particles = config['data']['num_particles']
# NOTE(review): tracing with batch size 1 can let size-1-specific ops
# (e.g. squeeze/broadcast) bake the batch dim into the graph despite
# `dynamic_axes`; consider tracing with batch >= 2 — TODO confirm against MLP.forward.
dummy_input = torch.randn(1, num_particles, in_features)

torch.onnx.export(
    model,                      # module to trace
    dummy_input,                # example input (or a tuple for multiple inputs)
    ONNX_PATH,                  # destination file
    export_params=True,         # embed the model's current parameter values in the file
    opset_version=11,           # ONNX operator-set (opset) version to target
    do_constant_folding=True,   # pre-compute constant subexpressions at export time
    input_names=['input'],      # graph input name (arbitrary)
    output_names=['output'],    # graph output name (arbitrary)
    # Only axis 0 (batch) is dynamic; num_particles and in_features are
    # fixed to the sizes used for dummy_input.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

print("Model has been exported to ONNX format.")

Model has been exported to ONNX format.
