In [None]:
import sys
from pathlib import Path
import yaml
import numpy as np
import torch
import torch.nn as nn
import onnx
import onnxruntime
from onnxsim import simplify
from mlp import MLP

In [None]:
device = 'cpu'

# Rebuild the model architecture from the saved training configuration.
# Only the structure is created here; no weights are loaded in this cell.
config_path = Path('checkpoints') / 'config.yaml'
with config_path.open('r', encoding='UTF-8') as config_file:
    config = yaml.safe_load(config_file)
print(config)
model = MLP(**config['model'])

In [None]:
# Debug aid: print the source of the first solver's `_assemble_shift` method.
# NOTE(review): leftover inspection cell; consider removing before sharing.
import inspect

first_solver = model.solvers[0]
print(inspect.getsource(first_solver._assemble_shift))

In [None]:
# Restore the trained weights into the model built from config.yaml.
# NOTE(review): '.path' is an unusual extension (PyTorch checkpoints are usually
# '.pt' / '.pth'); confirm this matches the file the training run wrote.
ckpt_path = 'checkpoints/ckpt_last.path'
# Map tensors onto the notebook's configured `device` instead of a second,
# hardcoded 'cpu' literal, so the device is chosen in one place.
# torch.load may unpickle arbitrary objects (depending on the installed torch's
# `weights_only` default) -- only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=device)
model_state_dict = ckpt['model']
model.load_state_dict(model_state_dict)
# Put the model in inference mode (affects layers such as dropout / batch-norm,
# if present) before it is wrapped and exported.
model.eval()

In [None]:
# Random input used only to trace the model for ONNX export. It is shaped
# (1, num_particles, in_features), with both sizes read from the training config.
in_features, num_particles = (
    config['model']['in_features'],
    config['data']['num_particles'],
)
dummy_input = torch.randn(1, num_particles, in_features)

# Wrap the model so that forward() runs only `model.embed`. torch.onnx.export
# then traces just that sub-network rather than the full MLP forward().
class EmbedWrapper(nn.Module):
    """Expose ``model.embed`` as a standalone ``nn.Module`` for export.

    Args:
        model: Any object with an ``embed`` attribute that is callable on a
            tensor (in this notebook, the loaded ``MLP``). If ``embed`` is an
            ``nn.Module``, it is registered as a submodule of the wrapper, so its
            parameters are shared with ``model``, not copied.
    """

    def __init__(self, model):
        super().__init__()
        self.embed = model.embed

    def forward(self, x):
        """Return ``self.embed(x)`` unchanged."""
        return self.embed(x)
    

# nn.Module.eval() returns the module itself, so wrap and switch to inference
# mode in one step.
embed_model = EmbedWrapper(model).eval()

# Only axis 0 (batch) is marked dynamic. The remaining dimensions are exported
# with the fixed sizes of dummy_input at trace time.
onnx_export_options = dict(
    export_params=True,  # store the trained weights inside the .onnx file
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)
torch.onnx.export(embed_model, dummy_input, 'embed.onnx', **onnx_export_options)
print("Model has been exported to ONNX format.")

In [None]:
# Simplify the graph written by the export cell before handing it to hls4ml.
# The export cell writes 'embed.onnx'. Loading 'mlp.onnx' here would pick up a
# stale file from an earlier run, or fail on a fresh kernel/checkout.
exported_onnx_path = "embed.onnx"
onnx_model = onnx.load(exported_onnx_path)
# onnxsim returns the simplified model plus a flag saying whether its own
# validation check passed.
model_simp, check = simplify(onnx_model)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, "mlp_simplified.onnx")
print("Simplified ONNX model saved.")

In [None]:
# Collect the distinct operator types left after simplification, so they can be
# compared against what the hls4ml ONNX frontend supports.
onnx_model = onnx.load("mlp_simplified.onnx")
ops = {graph_node.op_type for graph_node in onnx_model.graph.node}
print("Operators in simplified model:", ops)

In [None]:
import hls4ml

# Start from hls4ml's generated config for the simplified graph, then override
# the model-level synthesis strategy and default fixed-point precision.
config_hls = hls4ml.utils.config_from_onnx_model('mlp_simplified.onnx')
model_level = config_hls['Model']
model_level['Strategy'] = 'Latency'
model_level['Precision'] = 'ap_fixed<16,6>'
# NOTE(review): hls4ml converters take the I/O interface from their `io_type`
# argument; this key only has an effect if it is forwarded there.
config_hls['IOType'] = 'io_stream'

In [None]:
# Convert the simplified ONNX graph into an hls4ml project for the Vivado backend
# and the given FPGA part.
output_dir = 'hls4ml_mlp_project'
hls_model = hls4ml.converters.convert_from_onnx_model(
    'mlp_simplified.onnx',
    hls_config=config_hls,
    output_dir=output_dir,
    backend='Vivado',
    part='xcu250-figd2104-2L-e',
    # hls4ml reads the I/O interface from this converter argument, not from an
    # 'IOType' key inside hls_config. Without it, the requested io_stream is
    # dropped and the project is generated with the default io_parallel.
    io_type=config_hls['IOType'],
)
print("hls4ml model conversion complete.")