In [1]:
import sys
from pathlib import Path
import yaml
import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime
from onnxsim import simplify
from mlp import MLP

In [2]:
# Device name used by this notebook.
device = 'cpu'

# Parse the training configuration stored next to the checkpoints.
config_file = Path('checkpoints/config.yaml')
config = yaml.safe_load(config_file.read_text(encoding='UTF-8'))
print(config)

# Rebuild the network from the 'model' section of that configuration.
model = MLP(**config['model'])

{'checkpointing': {'checkpoint_path': '/home/yhuang2/PROJs/RealTimeAlignment/train/mlp_v2/checkpoints', 'resume': True, 'save_frequency': 20}, 'data': {'mode': 'raw', 'num_particles': 50, 'rounded': False}, 'model': {'embedding_features': [128, 128], 'in_features': 6, 'out_features': 27, 'rezero': True, 'norm': None, 'activ': {'name': 'leakyrelu', 'negative_slope': 0.1}, 'subset_config': [[6, 128, 128, 128, 128], [6, 128, 128, 128, 128], [6, 128, 128, 128, 128]]}, 'train': {'batch_size': 64, 'learning_rate': 0.0001, 'num_epochs': 200, 'num_warmup_epochs': 50, 'sched_gamma': 0.95, 'sched_steps': 20}}


In [3]:
import inspect

# Print the source of the first solver's `_assemble_shift` method for reference.
shift_source = inspect.getsource(model.solvers[0]._assemble_shift)
print(shift_source)

    def _assemble_shift(self, data):
        """
        assemble the subset by shifting the input
        """
        return torch.cat([torch.roll(data, i, dims=1)
                          for i in range(self.subset_size)],
                         dim=-1)



In [4]:
# Restore the trained weights from the last checkpoint and switch to inference mode.
# NOTE(review): the '.path' extension looks like a typo for '.pth'; it is kept
# because it has to match the file name on disk -- confirm.
ckpt_path = 'checkpoints/ckpt_last.path'
# torch.load unpickles the file, so only open checkpoints from a trusted source
# (consider weights_only=True if the installed torch and this checkpoint allow it).
# Tensors are mapped onto the `device` chosen earlier rather than a second
# hard-coded 'cpu', so the two settings cannot drift apart.
ckpt = torch.load(ckpt_path, map_location=device)
model_state_dict = ckpt['model']
model.load_state_dict(model_state_dict)
# Kept as the last expression on purpose: eval() returns the module, so the
# notebook displays the model structure.
model.eval()

MLP(
  (embed): Sequential(
    (0): Identity()
    (1): Linear(in_features=6, out_features=128, bias=True)
    (2): LeakyReLU(negative_slope=0.1)
    (3): Identity()
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): LeakyReLU(negative_slope=0.1)
  )
  (solvers): ModuleList(
    (0-2): 3 x SubsetSolver(
      (model): Sequential(
        (0): Identity()
        (1): Linear(in_features=768, out_features=128, bias=True)
        (2): LeakyReLU(negative_slope=0.1)
        (3): ResLinear(
          (norm_layer): Identity()
          (linear): Linear(in_features=128, out_features=128, bias=True)
          (activ): LeakyReLU(negative_slope=0.1)
        )
        (4): ResLinear(
          (norm_layer): Identity()
          (linear): Linear(in_features=128, out_features=128, bias=True)
          (activ): LeakyReLU(negative_slope=0.1)
        )
        (5): ResLinear(
          (norm_layer): Identity()
          (linear): Linear(in_features=128, out_features=128, bias=True)


In [5]:
# Random tracing input of shape (1, num_particles, in_features), with both
# sizes taken from the loaded configuration.
model_cfg, data_cfg = config['model'], config['data']
in_features, num_particles = model_cfg['in_features'], data_cfg['num_particles']
dummy_input = torch.randn((1, num_particles, in_features))

# Wrap only the embedding stage so it can be exported on its own:
# forward() runs model.embed and nothing else.
class EmbedWrapper(nn.Module):
    """Expose the ``embed`` submodule of a model as a standalone module."""

    def __init__(self, model):
        """Keep a reference to ``model.embed``; no other part of ``model`` is stored."""
        super().__init__()
        self.embed = model.embed

    def forward(self, x):
        """Return ``self.embed(x)``."""
        return self.embed(x)
    

# Build the wrapper directly in inference mode (eval() returns the module itself).
embed_model = EmbedWrapper(model).eval()

# Export options: only axis 0 is marked dynamic, for both the input and the output.
io_names = ('input', 'output')
export_options = dict(
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=[io_names[0]],
    output_names=[io_names[1]],
    dynamic_axes={name: {0: 'batch_size'} for name in io_names},
)
torch.onnx.export(embed_model, dummy_input, 'embed.onnx', **export_options)
print("Model has been exported to ONNX format.")

Model has been exported to ONNX format.


In [6]:
# Input and output files of the simplification step.
embed_onnx_file = "embed.onnx"
simplified_onnx_file = "mlp_simplified.onnx"

# Simplify the exported graph; `check` is the validation flag returned by simplify().
onnx_model = onnx.load(embed_onnx_file)
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"

onnx.save(model_simp, simplified_onnx_file)
print("Simplified ONNX model saved.")

Simplified ONNX model saved.


In [7]:
# Reload the simplified graph and report which operator types it contains,
# followed by an indexed listing of every node.
onnx_model = onnx.load("mlp_simplified.onnx")
graph_nodes = onnx_model.graph.node
ops = {node.op_type for node in graph_nodes}
print("Operators in simplified model:", ops)
for i, node in enumerate(graph_nodes):
    print(f"{i:02d}: {node.name} - {node.op_type}")

Operators in simplified model: {'MatMul', 'Add', 'LeakyRelu'}
00: /embed/embed.1/MatMul - MatMul
01: /embed/embed.1/Add - Add
02: /embed/embed.2/LeakyRelu - LeakyRelu
03: /embed/embed.4/MatMul - MatMul
04: /embed/embed.4/Add - Add
05: /embed/embed.5/LeakyRelu - LeakyRelu


In [None]:
import hls4ml

# Start from the configuration hls4ml derives for the simplified ONNX file,
# then override the model-wide strategy and fixed-point precision.
config_hls = hls4ml.utils.config_from_onnx_model('mlp_simplified.onnx')
config_hls['Model'].update({
    'Strategy': 'Latency',
    'Precision': 'ap_fixed<16,6>',
})
# NOTE(review): hls4ml's converters accept the I/O type as an `io_type` keyword
# argument; confirm that this top-level 'IOType' key is actually consumed.
config_hls['IOType'] = 'io_stream'

In [None]:
# Convert the simplified ONNX model into an hls4ml project for the Vivado backend.
output_dir = 'hls4ml_mlp_project'
hls_model = hls4ml.converters.convert_from_onnx_model(
    'mlp_simplified.onnx',
    hls_config=config_hls,
    output_dir=output_dir,
    backend='Vivado',
    part='xcu250-figd2104-2L-e',
    # The converter takes the I/O type from this keyword and does not pick up an
    # 'IOType' entry inside hls_config, so forward the configured value explicitly
    # (falling back to hls4ml's default when the key is absent).
    io_type=config_hls.get('IOType', 'io_parallel'),
)
print("hls4ml model conversion complete.")