Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python3Packages.jax: towards fixing dependencies #297146

Merged
merged 4 commits into from
Apr 2, 2024

Conversation

GaetanLepage
Copy link
Contributor

@GaetanLepage GaetanLepage commented Mar 19, 2024

Description of changes

Things done

  • Built on platform(s)
    • x86_64-linux
    • aarch64-linux
    • x86_64-darwin
    • aarch64-darwin
  • For non-Linux: Is sandboxing enabled in nix.conf? (See Nix manual)
    • sandbox = relaxed
    • sandbox = true
  • Tested, as applicable:
  • Tested compilation of all packages that depend on this change using nix-shell -p nixpkgs-review --run "nixpkgs-review rev HEAD". Note: all changes have to be committed, also see nixpkgs-review usage
  • Tested basic functionality of all binary files (usually in ./result/bin/)
  • 24.05 Release Notes (or backporting 23.05 and 23.11 Release notes)
    • (Package updates) Added a release notes entry if the change is major or breaking
    • (Module updates) Added a release notes entry if the change is significant
    • (Module addition) Added a release notes entry if adding a new NixOS module
  • Fits CONTRIBUTING.md.

Add a 👍 reaction to pull requests you find important.

@natsukium natsukium added the 8.has: upstream changes reviewed Reviewer checked the changelogs/commit logs associated with the release and did not find any issues. label Mar 19, 2024
@GaetanLepage
Copy link
Contributor Author

Result of nixpkgs-review pr 297146 run on x86_64-linux 1

4 packages marked as broken and skipped:
  • python311Packages.elegy
  • python311Packages.elegy.dist
  • python311Packages.treex
  • python311Packages.treex.dist
4 packages failed to build:
  • python311Packages.distrax
  • python311Packages.distrax.dist
  • python311Packages.rlax
  • python311Packages.rlax.dist
21 packages built:
  • python311Packages.bambi
  • python311Packages.bambi.dist
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.chex
  • python311Packages.chex.dist
  • python311Packages.dalle-mini
  • python311Packages.dalle-mini.dist
  • python311Packages.dm-haiku
  • python311Packages.dm-haiku.dist
  • python311Packages.equinox
  • python311Packages.equinox.dist
  • python311Packages.flax
  • python311Packages.flax.dist
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python311Packages.vqgan-jax
  • python311Packages.vqgan-jax.dist

@natsukium
Copy link
Member

google-deepmind/chex#333 breaks distrax's tests.
e.g.

distrax> ___________ DeterministicTest.test_sample_dtype_float64__with_device ___________
distrax> [gw1] linux -- Python 3.11.8 /nix/store/3v2ch16fkl50i85n05h5ckss8pxx6836-python3-3.11.8/bin/python3.11
distrax> 
distrax> self = <distrax._src.distributions.deterministic_test.DeterministicTest testMethod=test_sample_dtype_float64__with_device>
distrax> dtype = <class 'jax.numpy.float64'>
distrax> 
distrax>     @chex.all_variants
distrax>     @parameterized.named_parameters(
distrax>         ('int32', jnp.int32),
distrax>         ('int64', jnp.int64),
distrax>         ('float32', jnp.float32),
distrax>         ('float64', jnp.float64))
distrax>     def test_sample_dtype(self, dtype):
distrax>       dist = self.distrax_cls(loc=jnp.zeros((), dtype=dtype))
distrax>       samples = self.variant(dist.sample)(seed=self.key)
distrax>       self.assertEqual(samples.dtype, dist.dtype)
distrax> >     chex.assert_type(samples, dtype)
distrax> 
distrax> distrax/_src/distributions/deterministic_test.py:113: 
distrax> _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distrax> /nix/store/0h9rhxq0mj1y474g3zafm9arkrq41lq8-python3.11-chex-0.1.86/lib/python3.11/site-packages/chex/_src/asserts_internal.py:278: in _chex_assert_fn
distrax>     host_assertion_fn(
distrax> _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
distrax> 
distrax> custom_message = None, custom_message_format_vars = ()
distrax> include_default_message = True, exception_type = <class 'AssertionError'>
distrax> args = (Array(0., dtype=float32), <class 'jax.numpy.float64'>), kwargs = {}
distrax> assertion_exc = AssertionError('Error in type compatibility check: input 0 has type float32 but expected float64.')
distrax> value_exc = None
distrax> error_msg = '[Chex] Assertion assert_type failed: Error in type compatibility check: input 0 has type float32 but expected float64.'
distrax> default_msg = 'Assertion assert_type failed: '
distrax> 
distrax>     def _assert_on_host(*args,
distrax>                         custom_message: Optional[str] = None,
distrax>                         custom_message_format_vars: Sequence[Any] = (),
distrax>                         include_default_message: bool = True,
distrax>                         exception_type: Type[Exception] = AssertionError,
distrax>                         **kwargs) -> None:
distrax>       # Format error's stack trace to remove Chex' internal frames.
distrax>       assertion_exc = None
distrax>       value_exc = None
distrax>       try:
distrax>         assert_fn(*args, **kwargs)
distrax>       except AssertionError as e:
distrax>         assertion_exc = e
distrax>       except ValueError as e:
distrax>         value_exc = e
distrax>       finally:
distrax>         if value_exc is not None:
distrax>           raise ValueError(str(value_exc))
distrax>     
distrax>         if assertion_exc is not None:
distrax>           # Format the exception message.
distrax>           error_msg = str(assertion_exc)
distrax>     
distrax>           # Include only the name of the outermost chex assertion.
distrax>           if error_msg.startswith(ERR_PREFIX):
distrax>             error_msg = error_msg[error_msg.find("failed:") + len("failed:"):]
distrax>     
distrax>           # Whether to include the default error message.
distrax>           default_msg = (f"Assertion {name} failed: "
distrax>                          if include_default_message else "")
distrax>           error_msg = f"{ERR_PREFIX}{default_msg}{error_msg}"
distrax>     
distrax>           # Whether to include a custom error message.
distrax>           if custom_message:
distrax>             if custom_message_format_vars:
distrax>               custom_message = custom_message.format(*custom_message_format_vars)
distrax>             error_msg = f"{error_msg} [{custom_message}]"
distrax>     
distrax> >         raise exception_type(error_msg)
distrax> E         AssertionError: [Chex] Assertion assert_type failed: Error in type compatibility check: input 0 has type float32 but expected float64.
distrax> 
distrax> /nix/store/0h9rhxq0mj1y474g3zafm9arkrq41lq8-python3.11-chex-0.1.86/lib/python3.11/site-packages/chex/_src/asserts_internal.py:196: AssertionError

@stephen-huan
Copy link
Member

@natsukium there is an upstream PR (google-deepmind/distrax#270) fixing the issue. For now, I've opened #300648.

Note that chex fails to build on current nixos-unstable (d8fe5e6) but the version bump in this PR fixes the issue.

nix build .#python311Packages.chex
error: builder for '/nix/store/b4736piijw4fjn89b6dbr8y39i5gvrkn-python3.11-chex-0.1.85.drv' failed with exit code 1;
       last 10 log lines:
       >
       > chex/_src/variants_test.py::CountVariantsTest::test_counters
       >   /build/source/chex/_src/variants_test.py:524: DeprecationWarning: unittest.makeSuite() is deprecated and will be removed in Python 3.13. Please use unittest.TestLoader.loadTestsFromTestCase() instead.
       >     ts = unittest.makeSuite(self.InnerTest)  # pytype: disable=module-attr
       >
       > -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
       > =========================== short test summary info ============================
       > FAILED chex/_src/asserts_chexify_test.py::AssertsLibraryTest::test_assert_trees_all_close - AssertionError: AssertionError not raised
       > =========== 1 failed, 522 passed, 85 skipped, 54 warnings in 53.19s ============
       > /nix/store/c8dj731bkcdzhgrpawhc8qvdgls4xfjv-stdenv-linux/setup: line 1578: pop_var_context: head of shell_variables not a function context
       For full logs, run 'nix-store -l /nix/store/b4736piijw4fjn89b6dbr8y39i5gvrkn-python3.11-chex-0.1.85.drv'.

@GaetanLepage
Copy link
Contributor Author

Result of nixpkgs-review pr 297146 run on x86_64-linux 1

4 packages marked as broken and skipped:
  • python311Packages.elegy
  • python311Packages.elegy.dist
  • python311Packages.treex
  • python311Packages.treex.dist
28 packages failed to build:
  • python311Packages.bambi
  • python311Packages.bambi.dist
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.dalle-mini
  • python311Packages.dalle-mini.dist
  • python311Packages.distrax
  • python311Packages.distrax.dist
  • python311Packages.dm-haiku
  • python311Packages.dm-haiku.dist
  • python311Packages.equinox
  • python311Packages.equinox.dist
  • python311Packages.flax
  • python311Packages.flax.dist
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python311Packages.rlax
  • python311Packages.rlax.dist
  • python311Packages.vqgan-jax
  • python311Packages.vqgan-jax.dist
  • python312Packages.equinox
  • python312Packages.equinox.dist
  • python312Packages.optax
  • python312Packages.optax.dist
  • python312Packages.optax.testsout
4 packages built:
  • python311Packages.chex
  • python311Packages.chex.dist
  • python312Packages.chex
  • python312Packages.chex.dist

@GaetanLepage
Copy link
Contributor Author

Result of nixpkgs-review pr 297146 run on aarch64-linux 1

2 packages marked as broken and skipped:
  • python311Packages.bambi
  • python311Packages.bambi.dist
14 packages failed to build:
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.equinox
  • python311Packages.equinox.dist
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python312Packages.equinox
  • python312Packages.equinox.dist
  • python312Packages.optax
  • python312Packages.optax.dist
  • python312Packages.optax.testsout
4 packages built:
  • python311Packages.chex
  • python311Packages.chex.dist
  • python312Packages.chex
  • python312Packages.chex.dist

@stephen-huan
Copy link
Member

@GaetanLepage btw, chex, optax, distrax, and flax are all independently broken.

For flax, see my comments on your PR. Sorry that everything is so scattered!

@GaetanLepage
Copy link
Contributor Author

This one now fixes both chex and optax.

@GaetanLepage
Copy link
Contributor Author

Result of nixpkgs-review pr 297146 run on aarch64-linux 1

2 packages marked as broken and skipped:
  • python311Packages.bambi
  • python311Packages.bambi.dist
6 packages failed to build:
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.equinox
  • python311Packages.equinox.dist
  • python312Packages.equinox
  • python312Packages.equinox.dist
12 packages built:
  • python311Packages.chex
  • python311Packages.chex.dist
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python312Packages.chex
  • python312Packages.chex.dist
  • python312Packages.optax
  • python312Packages.optax.dist
  • python312Packages.optax.testsout

@GaetanLepage
Copy link
Contributor Author

Result of nixpkgs-review pr 297146 run on x86_64-linux 1

4 packages marked as broken and skipped:
  • python311Packages.elegy
  • python311Packages.elegy.dist
  • python311Packages.treex
  • python311Packages.treex.dist
16 packages failed to build:
  • python311Packages.dalle-mini
  • python311Packages.dalle-mini.dist
  • python311Packages.distrax
  • python311Packages.distrax.dist
  • python311Packages.dm-haiku
  • python311Packages.dm-haiku.dist
  • python311Packages.equinox
  • python311Packages.equinox.dist
  • python311Packages.flax
  • python311Packages.flax.dist
  • python311Packages.rlax
  • python311Packages.rlax.dist
  • python311Packages.vqgan-jax
  • python311Packages.vqgan-jax.dist
  • python312Packages.equinox
  • python312Packages.equinox.dist
16 packages built:
  • python311Packages.bambi
  • python311Packages.bambi.dist
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.chex
  • python311Packages.chex.dist
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python312Packages.chex
  • python312Packages.chex.dist
  • python312Packages.optax
  • python312Packages.optax.dist
  • python312Packages.optax.testsout

@GaetanLepage
Copy link
Contributor Author

I propose to merge this PR as it ships the fixes for chex and optax.
We will then be allow to work on the fixes for flax & co on subsequent PRs.

@SomeoneSerge
Copy link
Contributor

@GaetanLepage would you like to mark the broken packages broken before I merge? So they don't generate noise in further nixpkgs-reviews...

@GaetanLepage GaetanLepage changed the title python311Packages.chex: 0.1.85 -> 0.1.86 Fix some jax-related python packages Apr 1, 2024
@GaetanLepage GaetanLepage changed the title Fix some jax-related python packages Fix/update some jax-related python packages Apr 1, 2024
@ofborg ofborg bot requested a review from onny April 1, 2024 23:21
@GaetanLepage
Copy link
Contributor Author

Result of nixpkgs-review pr 297146 run on x86_64-linux 1

8 packages marked as broken and skipped:
  • python311Packages.distrax
  • python311Packages.distrax.dist
  • python311Packages.elegy
  • python311Packages.elegy.dist
  • python311Packages.rlax
  • python311Packages.rlax.dist
  • python311Packages.treex
  • python311Packages.treex.dist
24 packages built:
  • python311Packages.bambi
  • python311Packages.bambi.dist
  • python311Packages.blackjax
  • python311Packages.blackjax.dist
  • python311Packages.dalle-mini
  • python311Packages.dalle-mini.dist
  • python311Packages.dm-haiku
  • python311Packages.dm-haiku.dist
  • python311Packages.equinox
  • python311Packages.equinox.dist
  • python311Packages.flax
  • python311Packages.flax.dist
  • python311Packages.jaxopt
  • python311Packages.jaxopt.dist
  • python311Packages.optax
  • python311Packages.optax.dist
  • python311Packages.optax.testsout
  • python311Packages.vqgan-jax
  • python311Packages.vqgan-jax.dist
  • python312Packages.equinox
  • python312Packages.equinox.dist
  • python312Packages.optax
  • python312Packages.optax.dist
  • python312Packages.optax.testsout

@SomeoneSerge SomeoneSerge changed the title Fix/update some jax-related python packages python3Packages.jax: towards fixing dependencies Apr 2, 2024
@SomeoneSerge SomeoneSerge merged commit 4c2f2f1 into NixOS:master Apr 2, 2024
29 of 31 checks passed
@stephen-huan
Copy link
Member

@GaetanLepage is there a reason distrax is marked as broken rather than the suppressing the failing tests as is done already and as in #300648? As far as I can tell chex is used only in tests and the package functions as expected.

@stephen-huan
Copy link
Member

also the bump to the version of chex seems to have been lost

@GaetanLepage GaetanLepage deleted the chex branch April 2, 2024 06:42
@GaetanLepage
Copy link
Contributor Author

also the bump to the version of chex seems to have been lost

This was done in the meantime in d55fb2a

@GaetanLepage
Copy link
Contributor Author

@GaetanLepage is there a reason distrax is marked as broken rather than the suppressing the failing tests as is done already and as in #300648? As far as I can tell chex is used only in tests and the package functions as expected.

Maybe we should report those failures upstream.
I was not sure about their gravity.

@stephen-huan
Copy link
Member

There's already an upstream PR (google-deepmind/distrax#270) that (mostly) fixes the issue.

@GaetanLepage
Copy link
Contributor Author

There's already an upstream PR (google-deepmind/distrax#270) that (mostly) fixes the issue.

Oh cool ! I hope it will be merged soon then.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
6.topic: python 8.has: upstream changes reviewed Reviewer checked the changelogs/commit logs associated with the release and did not find any issues. 10.rebuild-darwin: 0 10.rebuild-linux: 11-100
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants