# Load an MLP model and convert it to an ONNX file
Please take the following steps before running this notebook:
1. clone the repo by running `git clone https://github.com/abidihaider/RealTimeAlignment.git`
2. check out the `develop` branch of the repo (`git checkout develop`)
3. run `python setup.py develop`

In [1]:
import sys
from pathlib import Path
import yaml
import numpy as np

import torch
import onnx
import onnxruntime

from mlp import MLP

In [2]:
# Seed torch's global RNG so the random dummy input used for the ONNX export and the
# ONNX Runtime smoke test is identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
# Device for the PyTorch model inputs in this notebook (kept last so the cell shows no output).
device = 'cpu'

## Load the model configuration and use it to initialize the model

In [3]:
# Read the training configuration saved next to the checkpoint, show it, and
# build an untrained MLP from its `model` section.
config_path = Path('checkpoints') / 'config.yaml'
with config_path.open('r', encoding='UTF-8') as config_file:
    config = yaml.safe_load(config_file)

print(config)
model = MLP(**config['model'])

{'checkpointing': {'checkpoint_path': '/home/yhuang2/PROJs/RealTimeAlignment/train/mlp_v2/checkpoints', 'resume': True, 'save_frequency': 20}, 'data': {'mode': 'raw', 'num_particles': 50, 'rounded': False}, 'model': {'embedding_features': [128, 128], 'in_features': 6, 'out_features': 27, 'rezero': True, 'norm': None, 'activ': {'name': 'leakyrelu', 'negative_slope': 0.1}, 'subset_config': [[6, 128, 128, 128, 128], [6, 128, 128, 128, 128], [6, 128, 128, 128, 128]]}, 'train': {'batch_size': 64, 'learning_rate': 0.0001, 'num_epochs': 200, 'num_warmup_epochs': 50, 'sched_gamma': 0.95, 'sched_steps': 20}}


In [4]:
# NOTE(review): extension is '.path' rather than the usual '.pth'; presumably this is the
# actual file name on disk (the load below succeeds) — confirm before renaming.
ckpt_path = 'checkpoints/ckpt_last.path'
# Map tensors onto the notebook's configured `device` instead of a separately
# hard-coded 'cpu', so changing `device` in one place is enough.
# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted
# sources (consider weights_only=True if the checkpoint holds only tensors/containers).
ckpt = torch.load(ckpt_path, map_location=device)
model_state_dict = ckpt['model']
model.load_state_dict(model_state_dict)
# Switch to inference mode before export/evaluation; the returned module is the
# cell's displayed output (the model summary).
model.eval()

MLP(
  (embed): Sequential(
    (0): Identity()
    (1): Linear(in_features=6, out_features=128, bias=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Identity()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): LeakyReLU(negative_slope=0.1)
  )
  (solvers): ModuleList(
    (0-2): 3 x SubsetSolver(
      (model): Sequential(
        (0): Identity()
        (1): Linear(in_features=768, out_features=128, bias=True)
        (2): LeakyReLU(negative_slope=0.1)
        (3): ResLinear(
          (norm_layer): Identity()
          (linear): Linear(in_features=128, out_features=128, bias=True)
          (activ): LeakyReLU(negative_slope=0.1)
        )
        (4): ResLinear(
          (norm_layer): Identity()
          (linear): Linear(in_features=128, out_features=128, bias=True)
          (activ): LeakyReLU(negative_slope=0.1)
        )
        (5): ResLinear(
          (norm_layer): Identity()
          (linear): Linear(in_features=128, out_features=128, bias=True)


## Convert it to an ONNX file

In [5]:
# Create dummy input with the correct shape
in_features = config['model']['in_features']
num_particles = config['data']['num_particles']
dummy_input = torch.randn(1, num_particles, in_features)

# Export to ONNX
torch.onnx.export(
    model,                      # model being run
    dummy_input,                # model input (or a tuple for multiple inputs)
    "mlp.onnx",                 # where to save the model (filename)
    export_params=True,         # store the trained weights inside the model
    opset_version=11,           # the ONNX version to export to (11 is widely supported)
    do_constant_folding=True,   # optimize constants
    input_names=['input'],      # input name (can be arbitrary)
    output_names=['output'],    # output name
    dynamic_axes={              # support dynamic batch size
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    }
)

print("Model has been exported to ONNX format.")

Model has been exported to ONNX format.


## Load the ONNX model and check its validity

In [6]:
# Re-read the exported file from disk and run ONNX's structural checker on it.
# check_model raises on an invalid model, so the message below prints only on success.
onnx_path = "mlp.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

print("ONNX model is valid!")

ONNX model is valid!


## Test the ONNX model with ONNX Runtime

In [7]:
# Pin the execution provider explicitly: since onnxruntime 1.9, builds that ship several
# providers (e.g. onnxruntime-gpu) raise if `providers` is omitted.
ort_session = onnxruntime.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])

# ONNX Runtime consumes numpy arrays; reuse the exact tensor that was used for export.
input_data = dummy_input.numpy()

# Run the ONNX model (None -> return every graph output).
ort_inputs = {ort_session.get_inputs()[0].name: input_data}
ort_outputs = ort_session.run(None, ort_inputs)

# A run that merely doesn't crash proves little; check it agrees with PyTorch.
with torch.no_grad():
    expected_outputs = model(dummy_input).numpy()
np.testing.assert_allclose(ort_outputs[0], expected_outputs, rtol=1e-3, atol=1e-5)
print("ONNX model tested successfully!")

ONNX model tested successfully!


## Load toy detector data and compare the outputs of the PyTorch and ONNX models

In [8]:
from rtal.datasets.dataset import ROMDataset
from torch.utils.data import DataLoader

Load some example data.

In [9]:
data_root = 'data/rom_det-3_part-200_cont-and-rounded_excerpt/'
# Take the particle count from the model config (the value the ONNX input shape was
# exported with) instead of repeating it as a literal that could drift out of sync.
dataset = ROMDataset(data_root, split='train', num_particles=config['data']['num_particles'])
# shuffle=False keeps the first batch deterministic across runs.
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

Get one batch (it contains 4 events because the data loader uses `batch_size=4`) and run the PyTorch model.

In [10]:
event = next(iter(dataloader))
# Readout generated by the misaligned detectors.
# "_cont" means the coordinates are continuous, without any rounding.
readout = event['readout_curr_cont'].to(device)
# readout: (batch_size, num_detectors, num_particles, 2)
#        ->(batch_size, num_particles, num_detectors, 2)
#        ->(batch_size, num_particles, num_detectors x 2)
readout = torch.transpose(readout, 1, 2).flatten(-2, -1)
# Inference only: skip building the autograd graph.
with torch.no_grad():
    torch_outputs = model(readout)

Run the ONNX model on the same batch.

In [11]:
# ONNX Runtime takes numpy arrays; move to host memory first so this also works
# if `device` is switched away from 'cpu'.
ort_inputs = {ort_session.get_inputs()[0].name: readout.cpu().numpy()}
# The exported graph has a single output, so keep only the first element.
ort_outputs = ort_session.run(None, ort_inputs)[0]

Compare the PyTorch and ONNX outputs.

In [12]:
torch_outputs_np = torch_outputs.detach().cpu().numpy()
difference = np.abs(torch_outputs_np - ort_outputs)
max_diff = np.max(difference)
# Scientific notation: a fixed 6-decimal format prints tiny float32 round-off
# (e.g. 4e-7) as 0.000000, hiding whether the outputs actually differ.
print(f"Max difference between PyTorch and ONNX outputs: {max_diff:.3e}")
# Fail loudly instead of relying on a human to read the printed number.
np.testing.assert_allclose(ort_outputs, torch_outputs_np, rtol=1e-3, atol=1e-5)

Max difference between PyTorch and ONNX outputs: 0.000000
