Description
❓ Question
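I would like to save a PPO model as a TorchScript file (`model.pt`) with `torch.jit.script` so that it can be used outside of Stable-Baselines3. However, when I run:
```python
import torch as th
scripted_model = th.jit.script(model.policy)  # model is a PPO instance
```
I get the following error: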
```
ValueError Traceback (most recent call last)
Cell In[49], line 1
----> 1 th.jit.script(model.policy)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1432, in script(obj, optimize, _frames_up, _rcb, example_inputs)
1429 _TOPLEVEL = False
1431 try:
-> 1432 return _script_impl(
1433 obj=obj,
1434 optimize=optimize,
1435 _frames_up=_frames_up + 1,
1436 _rcb=_rcb,
1437 example_inputs=example_inputs,
1438 )
1439 finally:
1440 _TOPLEVEL = prev
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_script.py:1146, in _script_impl(obj, optimize, _frames_up, _rcb, example_inputs)
1144 if isinstance(obj, torch.nn.Module):
1145 obj = call_prepare_scriptable_func(obj)
-> 1146 return torch.jit._recursive.create_script_module(
1147 obj, torch.jit._recursive.infer_methods_to_compile
1148 )
1149 else:
1150 obj = obj.prepare_scriptable() if hasattr(obj, "prepare_scriptable") else obj # type: ignore[operator]
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:556, in create_script_module(nn_module, stubs_fn, share_types, is_tracing)
554 assert not isinstance(nn_module, torch.jit.RecursiveScriptModule)
555 check_module_initialized(nn_module)
--> 556 concrete_type = get_module_concrete_type(nn_module, share_types)
557 if not is_tracing:
558 AttributeTypeIsSupportedChecker().check(nn_module)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:505, in get_module_concrete_type(nn_module, share_types)
501 return nn_module._concrete_type
503 if share_types:
504 # Look into the store of cached JIT types
--> 505 concrete_type = concrete_type_store.get_or_create_concrete_type(nn_module)
506 else:
507 # Get a concrete type directly, without trying to re-use an existing JIT
508 # type from the type store.
509 concrete_type_builder = infer_concrete_type_builder(nn_module, share_types)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:437, in ConcreteTypeStore.get_or_create_concrete_type(self, nn_module)
435 def get_or_create_concrete_type(self, nn_module):
436 """Infer a ConcreteType from this nn.Module instance. Underlying JIT types are re-used if possible."""
--> 437 concrete_type_builder = infer_concrete_type_builder(nn_module)
439 nn_module_type = type(nn_module)
440 if nn_module_type not in self.type_store:
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:272, in infer_concrete_type_builder(nn_module, share_types)
269 if name in user_annotated_ignored_attributes:
270 continue
--> 272 attr_type, _ = infer_type(name, item)
273 if item is None:
274 # Modules can be None. We don't have direct support for optional
275 # Modules, so the register it as an NoneType attribute instead.
276 concrete_type_builder.add_attribute(name, attr_type.type(), False, False)
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/_recursive.py:228, in infer_concrete_type_builder.<locals>.infer_type(name, item)
222 try:
223 if (
224 name in class_annotations
225 and class_annotations[name]
226 != torch.nn.Module.__annotations__["forward"]
227 ):
--> 228 ann_to_type = torch.jit.annotations.ann_to_type(
229 class_annotations[name], fake_range()
230 )
231 attr_type = torch._C.InferredType(ann_to_type)
232 elif isinstance(item, torch.jit.Attribute):
File ~/PycharmProjects/PythonProject/.venv/lib/python3.8/site-packages/torch/jit/annotations.py:514, in ann_to_type(ann, loc, rcb)
512 if the_type is not None:
513 return the_type
--> 514 raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
ValueError: Unknown type annotation: '<class 'stable_baselines3.common.torch_layers.BaseFeaturesExtractor'>' at
```
Is there a way to export the PPO model with `torch.jit.script` or `torch.jit.trace`? Thanks!
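The traceback shows `torch.jit.script` failing inside `ann_to_type` while it tries to turn the class-level annotation `BaseFeaturesExtractor` (on one of the policy's submodule attributes) into a TorchScript type. A commonly used alternative is `torch.jit.trace` on a thin wrapper whose `forward` takes only an observation tensor. My understanding is that tracing records the tensor operations executed on an example input instead of compiling the class, so it does not hit that annotation check; the SB3 documentation's model-export section describes a similar wrapper-and-trace approach. The sketch below is a minimal, self-contained illustration under stated assumptions: `TinyPolicy` is a hypothetical pure-PyTorch stand-in that mimics the `(actions, values, log_prob)` return of an SB3 policy (with a real model you would wrap `model.policy` instead), `TraceablePolicy` is a hypothetical wrapper name, and it assumes a 4-dimensional observation and discrete actions.

```python
import os
import tempfile

import torch as th
from torch import nn


class TinyPolicy(nn.Module):
    """Hypothetical stand-in for an SB3 discrete-action policy.

    Like an SB3 policy's forward, it takes (obs, deterministic) and returns
    (actions, values, log_prob); the layer sizes are illustrative only.
    """

    def __init__(self, obs_dim: int = 4, n_actions: int = 2):
        super().__init__()
        self.features_extractor = nn.Flatten()
        self.mlp = nn.Sequential(nn.Linear(obs_dim, 64), nn.Tanh())
        self.action_net = nn.Linear(64, n_actions)
        self.value_net = nn.Linear(64, 1)

    def forward(self, obs: th.Tensor, deterministic: bool = False):
        latent = self.mlp(self.features_extractor(obs))
        logits = self.action_net(latent)
        log_probs = th.log_softmax(logits, dim=-1)
        if deterministic:
            actions = th.argmax(logits, dim=-1)  # greedy action
        else:
            actions = th.multinomial(log_probs.exp(), 1).squeeze(-1)  # sampled action
        log_prob = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        return actions, self.value_net(latent), log_prob


class TraceablePolicy(nn.Module):
    """Hypothetical wrapper: fixes deterministic=True and returns only actions."""

    def __init__(self, policy: nn.Module):
        super().__init__()
        self.policy = policy

    def forward(self, observation: th.Tensor) -> th.Tensor:
        actions, _values, _log_prob = self.policy(observation, deterministic=True)
        return actions


th.manual_seed(0)
policy = TinyPolicy()  # assumption: with SB3, pass model.policy here instead
wrapper = TraceablePolicy(policy).eval()

# Tracing runs forward once on an example input and records the tensor ops.
dummy_input = th.randn(1, 4)
with th.no_grad():
    traced = th.jit.trace(wrapper, dummy_input)

# Save and reload as a standalone TorchScript file (the Python classes are not needed to load it).
path = os.path.join(tempfile.mkdtemp(), "model.pt")
th.jit.save(traced, path)
loaded = th.jit.load(path)

obs = th.randn(5, 4)  # a different batch size than the example input
with th.no_grad():
    eager_actions = wrapper(obs)
    loaded_actions = loaded(obs)
print(loaded_actions)
```

The traced graph bakes in `deterministic=True` and only the branch taken with the example input, so data-dependent Python control flow inside `forward` is frozen. As far as I know, it also does not include any post-processing that `model.predict()` applies (for example, clipping continuous actions to the action space), so that would have to be reproduced wherever `model.pt` is loaded.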
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- If code there is, it is minimal and working
- If code there is, it is formatted using the markdown code blocks for both code and stack traces.