feat: Add Flax Layers #195

Closed
wants to merge 1 commit into from

SauravMaheshkar

Adds Flax Modules for einops operations. More details can be found in #153.

einops supports JAX arrays, so I'm unsure whether I should add tests or not.

CC: @arogozhnikov
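The premise that einops already works on JAX arrays can be illustrated directly, which is also why the proposed layers can be thin wrappers. A minimal sketch, assuming `jax` and `einops` are installed (the demo is skipped otherwise); `demo_einops_on_jax` is a hypothetical helper name:

```python
try:
    import jax.numpy as jnp
    from einops import rearrange, reduce
    has_jax_einops = True
except ImportError:
    has_jax_einops = False


def demo_einops_on_jax():
    """Apply einops functions directly to a JAX array (no layer involved)."""
    x = jnp.arange(2 * 3 * 4 * 4, dtype=jnp.float32).reshape(2, 3, 4, 4)
    # Flatten every axis except the batch axis.
    flat = rearrange(x, "b c h w -> b (c h w)")
    # 2x2 average pooling written as a reduction over split spatial axes.
    pooled = reduce(x, "b c (h h2) (w w2) -> b c h w", "mean", h2=2, w2=2)
    return flat.shape, pooled.shape


if has_jax_einops:
    print(demo_einops_on_jax())
```

Working on arrays, however, says nothing about how a layer behaves inside a Flax model, which is what the tests discussed below are meant to cover.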

@arogozhnikov
Owner

I should add tests or not

Yes, tests are needed, because we need to ensure that the following things work:

  1. serialization/deserialization of weights
  2. serialization/deserialization/pickling/deepcopy-ing of models
  3. layers interact correctly with other layers

For other frameworks, I just have a LeNet-style model built from the layers, and I check that 1 and 2 work for the whole model.
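A minimal sketch of such a check for Flax, assuming `jax`, `flax` and `einops` are installed (it is skipped otherwise). `Rearrange` here is a hypothetical stand-in for the layers in this PR, a parameter-free `nn.Module` around `einops.rearrange`, not the code that was eventually merged; `LeNetLike` and `check_flax_model` are also hypothetical names. It covers point 1 through `flax.serialization.to_bytes`/`from_bytes` and point 3 by placing the einops layer between convolution and dense layers; pickling and deep-copying of the module (point 2) are left out.

```python
try:
    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from flax import serialization
    from einops import rearrange
    has_flax = True
except ImportError:
    has_flax = False

if has_flax:

    class Rearrange(nn.Module):
        """Hypothetical parameter-free einops layer wrapping einops.rearrange."""

        pattern: str

        def __call__(self, x):
            return rearrange(x, self.pattern)

    class LeNetLike(nn.Module):
        """Small conv net mixing the einops layer with standard Flax layers."""

        @nn.compact
        def __call__(self, x):
            # Flax convolutions expect NHWC inputs; default padding is "SAME".
            x = nn.relu(nn.Conv(features=6, kernel_size=(5, 5))(x))
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            # Flatten height, width and channels with the einops layer.
            x = Rearrange("b h w c -> b (h w c)")(x)
            x = nn.relu(nn.Dense(features=32)(x))
            return nn.Dense(features=10)(x)


def check_flax_model():
    """Round-trip the model weights through Flax's native serialization."""
    model = LeNetLike()
    x = jnp.ones((2, 28, 28, 1))
    variables = model.init(jax.random.PRNGKey(0), x)
    expected = model.apply(variables, x)

    # Point 1: serialize the weights to bytes and restore them into the same structure.
    blob = serialization.to_bytes(variables)
    restored = serialization.from_bytes(variables, blob)
    assert bool(jnp.allclose(model.apply(restored, x), expected))
    return expected.shape


if has_flax:
    print(check_flax_model())
```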

@arogozhnikov
Owner

@SauravMaheshkar here is an example for torch:
https://github.com/arogozhnikov/einops/blob/master/tests/test_layers.py#L200-L231

The goal of adding tests is to make sure that layers can be serialized/deserialized using Flax's native serialization mechanisms.

@SauravMaheshkar
Author

Request for clarification: would the following be the correct way?

has_flax = any(backend.framework_name == 'jax' for backend in collect_test_backends(symbolic=False, layers=True))

@arogozhnikov
Owner

Hm. Yes, that's tricky: flax doesn't have its own backend.

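# Detect flax by trying to import it; catch only ImportError so other errors still surface.
try:
    import flax
    has_flax = True
except ImportError:
    has_flax = False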
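A variation on this detection, sketched with only the standard library: `importlib.util.find_spec` checks whether `flax` is installed without importing it, and the resulting flag gates the Flax-only tests. The test class and method names below are hypothetical placeholders, not part of the einops test suite.

```python
import importlib
import importlib.util
import unittest

# True when flax is installed; find_spec looks the package up without importing it.
has_flax = importlib.util.find_spec("flax") is not None


@unittest.skipUnless(has_flax, "flax is not installed")
class FlaxLayerTests(unittest.TestCase):
    def test_flax_is_importable(self):
        # Only runs when flax is installed; real layer tests would go here.
        flax = importlib.import_module("flax")
        self.assertEqual(flax.__name__, "flax")
```

With pytest, the same flag can be passed to `pytest.mark.skipif(not has_flax, reason="flax is not installed")`.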

@arogozhnikov
Owner

Hi @SauravMaheshkar, I think something changed in flax, because the code does not work anymore.

I have another implementation, #214, that probably accounts for all the quirks of flax, so I'm closing this one. Testing is very welcome!

@SauravMaheshkar deleted the add-flax branch on October 3, 2022 09:10
2 participants