Fix incompatibility with jax_dataclasses > 1.3.0 #36

Closed
flferretti opened this issue Jun 15, 2023 · 0 comments · Fixed by #38
@flferretti
Collaborator

As briefly discussed in #35, the error occurs when using jax_dataclasses > 1.3.0. After some investigation, I found that the problem lies in the type-hint resolution (get_type_hints_partial in _get_type_hints.py), which was not present in the working version (jax_dataclasses==1.3.0). When it is called from _register_pytree_dataclass, it raises an error at line 80:

https://github.com/brentyi/jax_dataclasses/blob/08ae5bb5d0ebd271d9516e7d0c6d46b3fc48b246/jax_dataclasses/_get_type_hints.py#L66-L82

Specifically, it fails when it tries to evaluate the string annotation "jaxsim.high_level.model.Model", which appears to trigger a circular import.
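To illustrate why a fully qualified annotation is problematic, here is a minimal sketch of a hypothetical resolver (resolve_annotations_partial is an illustrative name, not the jax_dataclasses API) that evaluates string annotations in a module's globals and skips names it cannot find. Suppressing only NameError is an assumption made for the sake of the example. The point is that a dotted path whose root name does exist, but whose submodule attribute is not bound yet, fails with AttributeError rather than NameError, so a NameError-only guard would let it through.

```python
import types


def resolve_annotations_partial(annotations, module_globals):
    """Hypothetical resolver: eval string annotations, skipping unknown bare names."""
    resolved = {}
    for name, value in annotations.items():
        if isinstance(value, str):
            try:
                value = eval(value, module_globals)
            except NameError:
                # Unresolvable bare names are skipped (assumed behavior).
                continue
        resolved[name] = value
    return resolved


# Simulate a package object whose "high_level" attribute is not set yet.
pkg = types.ModuleType("pkg")
module_globals = {"pkg": pkg}

print(resolve_annotations_partial({"a": "int", "b": "UndefinedName"}, module_globals))
# {'a': <class 'int'>}

try:
    resolve_annotations_partial({"model": "pkg.high_level.model.Model"}, module_globals)
except AttributeError as exc:
    # The root name "pkg" resolves, so the failure is an AttributeError.
    print(type(exc).__name__, exc)
```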

Error

```
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jaxsim/__init__.py", line 63, in <module>
    from . import high_level, logging, math, sixd
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jaxsim/high_level/__init__.py", line 1, in <module>
    from . import common, joint, link, model
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jaxsim/high_level/joint.py", line 12, in <module>
    class Joint(JaxsimDataclass):
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jax_dataclasses/_dataclasses.py", line 47, in pytree_dataclass
    return wrap(cls)
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jax_dataclasses/_dataclasses.py", line 38, in wrap
    return _register_pytree_dataclass(dataclasses.dataclass(cls, **kwargs))
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jax_dataclasses/_dataclasses.py", line 76, in _register_pytree_dataclass
    type_from_name = get_type_hints_partial(cls, include_extras=True)  # type: ignore
  File "/home/flferretti/mambaforge/envs/test103/lib/python3.10/site-packages/jax_dataclasses/_get_type_hints.py", line 80, in get_type_hints_partial
    value = eval(value, base_globals)
  File "<string>", line 1, in <module>
AttributeError: partially initialized module 'jaxsim' has no attribute 'high_level' (most likely due to a circular import)
```
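The same failure can be reproduced without jaxsim or jax_dataclasses. The sketch below writes a throwaway package (demo_pkg, a hypothetical name) with the same import chain as the traceback: the package __init__ imports high_level, whose __init__ imports joint, and joint.py applies a decorator that evaluates string annotations while the class is being created (eager_dataclass is a stand-in, not the real jax_dataclasses decorator). Python only binds demo_pkg.high_level as an attribute of demo_pkg after the subpackage finishes importing, so the eval runs against a partially initialized module.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical package mirroring the import chain in the traceback:
# demo_pkg -> demo_pkg.high_level -> demo_pkg.high_level.joint
FILES = {
    "demo_pkg/__init__.py": "from . import high_level\n",
    "demo_pkg/high_level/__init__.py": "from . import joint, model\n",
    "demo_pkg/high_level/model.py": "class Model:\n    pass\n",
    "demo_pkg/high_level/joint.py": textwrap.dedent(
        '''
        import sys

        import demo_pkg  # returns the partially initialized package


        def eager_dataclass(cls):
            # Stand-in decorator: eval string annotations in the defining
            # module's globals while the class is being created.
            module_globals = vars(sys.modules[cls.__module__])
            for annotation in cls.__annotations__.values():
                eval(annotation, module_globals)
            return cls


        @eager_dataclass
        class Joint:
            parent_model: "demo_pkg.high_level.model.Model"
        '''
    ),
}


def reproduce():
    """Import demo_pkg from a temp dir; return the AttributeError message, if any."""
    with tempfile.TemporaryDirectory() as root:
        for rel_path, source in FILES.items():
            path = Path(root, rel_path)
            path.parent.mkdir(parents=True, exist_ok=True)
            path.write_text(source)
        sys.path.insert(0, root)
        importlib.invalidate_caches()
        try:
            importlib.import_module("demo_pkg")
        except AttributeError as exc:
            return str(exc)
        finally:
            # Undo the sys.path change and drop any cached demo_pkg modules.
            sys.path.remove(root)
            for name in list(sys.modules):
                if name == "demo_pkg" or name.startswith("demo_pkg."):
                    del sys.modules[name]
    return None


print(reproduce())
```

Running it should print an AttributeError message about the partially initialized module demo_pkg lacking high_level, the same pattern as in the jaxsim traceback.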

I'm still working on it. It may be useful to look at the jdc_update branch of my fork.

CC @diegoferigo @traversaro
