[Question] How to save the model.policy by torch.jit.script #2099

@yucthonni

Description

❓ Question

I would like to save the PPO model as a 'model.pt' file with torch.jit.script, for general usage. But when I try
scripted_model = th.jit.script(model.policy)
I get this error:
`ValueError Traceback (most recent call last)
Cell In[49], line 1
----> 1 th.jit.script(model.policy)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1432, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1429 _TOPLEVEL = False
1431 try:
-> 1432 return _script_impl(
1433 obj=obj,
1434 optimize=optimize,
1435 _frames_up=_frames_up + 1,
1436 _rcb=_rcb,
1437 example_inputs=example_inputs,
1438 )
1439 finally:
1440 _TOPLEVEL = prev

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1146, in _script_impl(obj, optimize, _frames_up, _rcb, example_inputs)
1144 if isinstance(obj, torch.nn.Module):
1145 obj = call_prepare_scriptable_func(obj)
-> 1146 return torch.jit._recursive.create_script_module(
1147 obj, torch.jit._recursive.infer_methods_to_compile
1148 )
1149 else:
1150 obj = obj.prepare_scriptable() if hasattr(obj, "prepare_scriptable") else obj # type: ignore[operator]

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:556, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
554 assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
555 check_module_initialized(nn_module)
--> 556 concrete_type = get_module_concrete_type(nn_module, share_types)
557 if not is_tracing:
558 AttributeTypeIsSupportedChecker().check(nn_module)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:505, in get_module_concrete_type(nn_module, share_types)
501 return nn_module._concrete_type
503 if share_types:
504 # Look into the store of cached JIT types
--> 505 concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
506 else:
507 # Get a concrete type directly, without trying to re-use an existing JIT
508 # type from the type store.
509 concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:437, in ConcreteTypeStore.get_or_create_concrete_type(self, nn_module)
435 def get_or_create_concrete_type(self, nn_module):
436 """Infer a ConcreteType from this nn.Module instance. Underlying JIT types are re-used if possible."""
--> 437 concrete_type_builder = infer_concrete_type_builder(nn_module)
439 nn_module_type = type(nn_module)
440 if nn_module_type not in self.type_store:

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:272, in infer_concrete_type_builder(nn_module, share_types)
269 if name in user_annotated_ignored_attributes:
270 continue
--> 272 attr_type, _ = infer_type(name, item)
273 if item is None:
274 # Modules can be None. We don't have direct support for optional
275 # Modules, so the register it as an NoneType attribute instead.
276 concrete_type_builder.add_attribute(name, attr_type.type(), False, False)

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:228, in infer_concrete_type_builder.<locals>.infer_type(name, item)
222 try:
223 if (
224 name in class_annotations
225 and class_annotations[name]
226 != torch.nn.Module.__annotations__["forward"]
227 ):
--> 228 ann_to_type = torch.jit.annotations.ann_to_type(
229 class_annotations[name], fake_range()
230 )
231 attr_type = torch._C.InferredType(ann_to_type)
232 elif isinstance(item, torch.jit.Attribute):

File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/annotations.py:514, in ann_to_type(ann, loc, rcb)
512 if the_type is not None:
513 return the_type
--> 514 raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")

ValueError: Unknown type annotation: '<class 'stable_baselines3.common.torch_layers.BaseFeaturesExtractor'>' at`
Is there some method to save the PPO model with torch.jit.script or torch.jit.trace? Thanks.
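One common workaround, sketched below under assumptions: torch.jit.script chokes on the `BaseFeaturesExtractor` class annotation on the whole `ActorCriticPolicy`, so instead of scripting the policy, wrap just the submodules needed for inference in a plain `nn.Module` and export it with torch.jit.trace. The `JitWrapper` name and the stand-in networks are hypothetical; with a real model you would pass `model.policy.mlp_extractor`, `model.policy.action_net`, and `model.policy.value_net` (this mirrors the wrapper pattern in the SB3 "Exporting models" documentation).

```python
import torch as th
import torch.nn as nn


class JitWrapper(nn.Module):
    """Expose only the deterministic inference path of an actor-critic policy."""

    def __init__(self, extractor: nn.Module, action_net: nn.Module, value_net: nn.Module):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, obs: th.Tensor):
        # mlp_extractor returns (latent for the actor, latent for the critic)
        action_hidden, value_hidden = self.extractor(obs)
        return self.action_net(action_hidden), self.value_net(value_hidden)


# With a trained SB3 model this would be:
#   wrapper = JitWrapper(model.policy.mlp_extractor,
#                        model.policy.action_net, model.policy.value_net)
# Stand-in networks keep the sketch self-contained (identity extractor,
# 4-dim observation, 2 action logits, scalar value):
class DummyExtractor(nn.Module):
    def forward(self, obs: th.Tensor):
        return obs, obs


wrapper = JitWrapper(DummyExtractor(), nn.Linear(4, 2), nn.Linear(4, 1))
example_obs = th.zeros(1, 4)
traced = th.jit.trace(wrapper, example_obs)
traced.save("model.pt")
```

Tracing avoids the annotation problem entirely because it records tensor operations from an example input instead of compiling the Python class. The saved file can then be loaded with `th.jit.load("model.pt")` in any PyTorch runtime, without stable-baselines3 installed.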

    Labels

    RTFM (answer is in the documentation)
    check the checklist (you have checked the required items in the checklist but you didn't do what is written...)
    question (further information is requested)
