Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ vector observations to be used simultaneously. (#3981) Thank you @shakenes !
### Bug Fixes
- Fixed an issue where SAC would perform too many model updates when resuming from a
checkpoint, and too few when using `buffer_init_steps`. (#4038)
- Fixed a bug in the onnx export that would cause constants needed for inference to not be visible to some versions of
the Barracuda importer. (#4073)
#### com.unity.ml-agents (C#)
#### ml-agents / ml-agents-envs / gym-unity (Python)

Expand Down
32 changes: 3 additions & 29 deletions ml-agents/mlagents/model_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from distutils.version import LooseVersion

try:
import onnx
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
from tf2onnx import optimizer

Expand Down Expand Up @@ -126,16 +125,6 @@ def convert_frozen_to_onnx(
) -> Any:
# This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py

# Some constants in the graph need to be read by the inference system.
# These aren't used by the model anywhere, so trying to make sure they propagate
# through conversion and import is a losing battle. Instead, save them now,
# so that we can add them back later.
constant_values = {}
for n in frozen_graph_def.node:
if n.name in MODEL_CONSTANTS:
val = n.attr["value"].tensor.int_val[0]
constant_values[n.name] = val

inputs = _get_input_node_names(frozen_graph_def)
outputs = _get_output_node_names(frozen_graph_def)
logger.info(f"onnx export - inputs:{inputs} outputs:{outputs}")
Expand All @@ -157,26 +146,9 @@ def convert_frozen_to_onnx(
onnx_graph = optimizer.optimize_graph(g)
model_proto = onnx_graph.make_model(settings.brain_name)

# Save the constant values back the graph initializer.
# This will ensure the importer gets them as global constants.
constant_nodes = []
for k, v in constant_values.items():
constant_node = _make_onnx_node_for_constant(k, v)
constant_nodes.append(constant_node)
model_proto.graph.initializer.extend(constant_nodes)
return model_proto


def _make_onnx_node_for_constant(name: str, value: int) -> Any:
tensor_value = onnx.TensorProto(
data_type=onnx.TensorProto.INT32,
name=name,
int32_data=[value],
dims=[1, 1, 1, 1],
)
return tensor_value


def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
"""
Get the list of input node names from the graph.
Expand All @@ -201,10 +173,12 @@ def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
def _get_output_node_names(frozen_graph_def: Any) -> List[str]:
"""
Get the list of output node names from the graph.
Also include constants, so that they will be readable by the
onnx importer.
Names are suffixed with ":0"
"""
node_names = _get_frozen_graph_node_names(frozen_graph_def)
output_names = node_names & POSSIBLE_OUTPUT_NODES
output_names = node_names & (POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS)
# Append the port
return [f"{n}:0" for n in output_names]

Expand Down