# Exporting PyTorch Models

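A short example of exporting a DeepLabCut PyTorch pose model to ONNX and running it with ONNX Runtime. Export and deployment options: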

- [ExecuTorch Runtime Overview](https://pytorch.org/executorch/stable/runtime-overview.html)
- [Deploying Torch-TensorRT Programs](https://pytorch.org/TensorRT/tutorials/runtime.html)
- [`torch.onnx`](https://pytorch.org/docs/stable/onnx.html)
- [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx_dynamo.html#torchdynamo-based-onnx-exporter)
  - this one might be more difficult to get working
- [TorchScript-based ONNX Exporter](https://pytorch.org/docs/stable/onnx_torchscript.html)
  - this is what I used to export the model here

In [None]:
from pathlib import Path

import torch
import deeplabcut as dlc

from deeplabcut.pose_estimation_pytorch.config import read_config_as_dict, pretty_print
from deeplabcut.pose_estimation_pytorch.models import PoseModel

In [None]:
root = Path("/Users/annastuckert/Documents/DLC_AI_Residency/DLC_AI2024/DeepLabCut-live/Ventral_gait_model/train")
model_cfg = read_config_as_dict(root / "pytorch_config.yaml")
weights_path = root / "snapshot-200.pt"
#dest_dict = Path("/media1/data/anna/DLC_AI2024/DeepLabCut-live/ONNX_files")

#print(weights_path)
#print()

pretty_print(model_cfg["model"])

In [None]:
model = PoseModel.build(model_cfg["model"])
weights = torch.load(weights_path, map_location="cpu")
model.load_state_dict(weights["model"])
model.eval()  # inference mode: use stored batch-norm statistics, disable dropout


In [None]:
#!pip install --upgrade onnx onnxscript


In [None]:
# Sanity check: run the PyTorch model on a dummy batch before exporting
with torch.no_grad():
    outputs = model(torch.ones((1, 3, 128, 128)))
predictions = model.get_predictions(outputs)

print(predictions)

In [6]:
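# The exporter traces the model with this input, so it needs the same
# (batch, channels, height, width) layout that is used at inference time.
dummy_input = torch.zeros((1, 3, 640, 480))

model.eval()
torch.onnx.export(
    model,
    dummy_input,
    str(root / "resnet.onnx"),
    verbose=True,
    # Name the input so the ONNX Runtime cell can feed it by name.
    input_names=["actual_input_1"],
    # Keep the batch axis dynamic so batch sizes other than 1 can be run.
    dynamic_axes={"actual_input_1": {0: "batch"}},
)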
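
The `dynamic_axes` argument of `torch.onnx.export` is a dict that maps each input or output name to a dict of `{axis_index: label}`. The sketch below builds that mapping for the batch axis only. The helper `batch_dynamic_axes` and the output name `output_0` are hypothetical and only for illustration; the real output names depend on the heads of the exported model.

```python
def batch_dynamic_axes(input_names, output_names, axis_label="batch"):
    """Mark axis 0 of every named input and output as dynamic."""
    names = list(input_names) + list(output_names)
    return {name: {0: axis_label} for name in names}


# Keyword arguments that could be passed on to torch.onnx.export.
# "output_0" is a placeholder name, not one taken from the model above.
export_kwargs = {
    "input_names": ["actual_input_1"],
    "output_names": ["output_0"],
}
export_kwargs["dynamic_axes"] = batch_dynamic_axes(
    export_kwargs["input_names"], export_kwargs["output_names"]
)

print(export_kwargs["dynamic_axes"])
# {'actual_input_1': {0: 'batch'}, 'output_0': {0: 'batch'}}
```

These could then be passed as `torch.onnx.export(model, dummy_input, path, **export_kwargs)`. Height and width stay fixed at the dummy input's size unless their axes are added to the mapping as well.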

In [None]:
import onnx

# Load the ONNX model
onnx_model = onnx.load(root / "resnet.onnx")

# Check that the model is well formed
onnx.checker.check_model(onnx_model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(onnx_model.graph))

In [None]:
#!pip install onnxruntime

In [None]:
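import onnxruntime as ort
import numpy as np

ort_session = ort.InferenceSession(str(root / "resnet.onnx"))

# Query the input name and shape from the session rather than hard-coding them
input_meta = ort_session.get_inputs()[0]
# Dynamic axes are reported as strings or None; use a size of 10 for those
input_shape = [dim if isinstance(dim, int) else 10 for dim in input_meta.shape]

outputs = ort_session.run(
    None,
    {input_meta.name: np.random.randn(*input_shape).astype(np.float32)},
)
print(outputs[0])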
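
A common sanity check after an export is to feed the same input to the PyTorch model and to the ONNX Runtime session and compare the results. The sketch below covers only the comparison step, in NumPy. The helper `outputs_match` and its tolerances are assumptions chosen for illustration, and the arrays are stand-ins for real model outputs.

```python
import numpy as np


def outputs_match(reference, candidate, rtol=1e-3, atol=1e-5):
    """Return (ok, max_abs_diff) for two arrays that should be nearly equal."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        # A shape mismatch usually points to a layout or batching problem
        return False, float("inf")
    if reference.size == 0:
        return True, 0.0
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    ok = bool(np.allclose(candidate, reference, rtol=rtol, atol=atol))
    return ok, max_abs_diff


# Stand-in arrays; in practice these would be one PyTorch output converted
# with .detach().numpy() and the matching entry of ort_session.run(...)
rng = np.random.default_rng(0)
torch_out = rng.standard_normal((1, 4, 8, 8)).astype(np.float32)
onnx_out = torch_out + np.float32(1e-6)

ok, diff = outputs_match(torch_out, onnx_out)
print(ok, diff)
```

Small differences are expected because the two runtimes use different kernels, so the tolerances may need loosening for deeper models.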