Conversation

@ecole41
Collaborator

@ecole41 ecole41 commented Jun 30, 2025

This PR aims to fix the jax backend within n3fit. So far, the input to the model has been converted from a tracked dictionary into a standard Python dictionary and early stopping has been removed.

When KERAS_BACKEND=jax is set and n3fit is run, the model stops training before the first epoch.
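
For reference, Keras reads the backend choice when it is first imported, so the variable has to be in place before that import happens. Below is a minimal sketch of setting it from Python rather than from the shell. `select_keras_backend` is a hypothetical helper for illustration, not something that exists in n3fit.

```python
import os
import sys


def select_keras_backend(name: str) -> None:
    """Request a Keras backend; must run before keras is first imported."""
    if "keras" in sys.modules:
        # Keras has already chosen its backend, so setting the variable now is too late.
        raise RuntimeError("keras is already imported; set KERAS_BACKEND earlier")
    os.environ["KERAS_BACKEND"] = name


select_keras_backend("jax")
# A later `import keras` in this process would now pick up the jax backend.
```

Setting the variable in the shell (`KERAS_BACKEND=jax n3fit ...`) achieves the same thing; the helper only makes the ordering constraint explicit.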

@ecole41 ecole41 requested a review from scarlehoff June 30, 2025 13:34
@ecole41 ecole41 mentioned this pull request Jun 30, 2025
@ecole41
Collaborator Author

ecole41 commented Jul 4, 2025

@scarlehoff Just to update on this:

No Sum Rules: I have tested this using a run card without sum rules, and the same error appears because the xparams are tracked when they are initialised using keras, i.e. when you run self.x_in = x_in, self.x_in is tracked even if x_in is not.
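
To illustrate the kind of behaviour described here, the following is a toy stand-in, not Keras code: a base class whose `__setattr__` wraps plain dicts in a tracking container. That is roughly how an untracked `x_in` can end up tracked once it is assigned to `self`. The names `TrackedDict`, `TrackingBase` and `Model` are hypothetical, and the `dict(...)` call at the end mirrors the conversion to a standard Python dictionary mentioned in the PR description.

```python
class TrackedDict(dict):
    """Toy stand-in for a framework's tracking container (hypothetical)."""


class TrackingBase:
    """Toy base class that wraps dict attributes on assignment (hypothetical)."""

    def __setattr__(self, name, value):
        # Wrap plain dicts so the "framework" can track their contents.
        if isinstance(value, dict) and not isinstance(value, TrackedDict):
            value = TrackedDict(value)
        super().__setattr__(name, value)


class Model(TrackingBase):
    def __init__(self, x_in):
        self.x_in = x_in  # goes through TrackingBase.__setattr__


x_in = {"pdf_input": [0.1, 0.2], "xgrid_integration": [0.5]}
model = Model(x_in)

print(type(x_in).__name__)        # dict
print(type(model.x_in).__name__)  # TrackedDict

# Converting back to a plain dict before handing the inputs on.
plain = dict(model.x_in)
print(type(plain).__name__)       # dict
```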

Issue with Preprocessing: I have tried to investigate why uncommenting the stopping gives this error: #2318 (comment). I have found that the alphas and betas are accessible after being built, and their values are read successfully several times after that. However, for some reason, when this error arises, a.__jax_array__() cannot be run for any a in alphas.
I'm not sure why this happens, and I can't see in the trace why it would happen just this one time and not on the previous calls to this function.

@scarlehoff
Member

scarlehoff commented Jul 14, 2025

This last commit seems to fix several problems. The only one whose cause I'm not sure of yet is the "Issue with Preprocessing". It seems that calling the validation model's compute_losses fails (while calling the training model's doesn't).

But the interesting thing is that the layer is the same, so one can call the training model and, once the layer is built, call the validation one. This is obviously and morally wrong but seems to work... there must be something that can be done to the validation model in order to call it directly without the extra step.

@ecole41
Collaborator Author

ecole41 commented Jul 30, 2025

@scarlehoff I'm finding that with this version, I'm getting an indexing error when running with n3fit/runcards/examples/Basic_runcard.yml:

  File "/Users/ellacole/codes/nnpdf/nnpdfgit/nnpdf/n3fit/src/n3fit/backends/keras_backend/MetaModel.py", line 175, in perform_fit
    history = super().fit(
              ^^^^^^^^^^^^
  File "/Users/ellacole/miniconda3/envs/nnpdf_dev/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/Users/ellacole/miniconda3/envs/nnpdf_dev/lib/python3.12/site-packages/optree/ops.py", line 747, in tree_map
    return treespec.unflatten(map(func, *flat_args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 15 is out of bounds for axis 0 with size 1

The shapes of the inputs look correct, so I'm not sure why this is happening:

DEBUG: x_params shapes:
  pdf_input: (100, 152, 1)
  xgrid_integration: (100, 2000, 1)
DEBUG: y shapes:
  y[0]: (100, 1)
  y[1]: (100, 1)
  y[2]: (100, 1)
  y[3]: (100, 1)
  y[4]: (100, 1)
  y[5]: (100, 1)
DEBUG: steps_per_epoch: 100
DEBUG: epochs: 900

Do you maybe have a run card which works with this?

@scarlehoff
Member

scarlehoff commented Jul 30, 2025

Oh, this is a funny one. It has nothing to do with the shape; it is a GPU optimization we did.

Just put, at the beginning of this function:

num_replicas = self.output_shape[0]

if K.backend() == "jax":
    return 1

(you'll have to put from keras import backend as K at the top)
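
A stripped-down sketch of the guard, written as a free function so it can be run on its own. The function name, the `default_steps` argument and passing the backend name as a string are all assumptions made for illustration. In the snippet above the name comes from `K.backend()`, and the rest of the real function is not shown in this thread. That the return value is a number of steps per epoch is also an assumption, based on the debug output earlier in the conversation.

```python
def steps_for_backend(backend_name: str, default_steps: int) -> int:
    """Return 1 for jax, otherwise the value the caller would normally use."""
    if backend_name == "jax":
        # Mirrors the early `return 1` suggested above.
        return 1
    return default_steps


print(steps_for_backend("jax", 100))         # 1
print(steps_for_backend("tensorflow", 100))  # 100
```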

I'll try to fix it up in a more general way (I should benchmark jax on GPU, haven't done that yet) before merging.

Ps: I cannot give you a working runcard because the computer I was using for this is in a storage room right now ^^U but the BasicRuncard should work

Edit: I've added the change above and added a run_jax test just like run_torch. Seems to work fine ^^ The error is not related to jax.

@scarlehoff scarlehoff changed the title from "[WIP] Fixing Jax Backend" to "Fixing Jax Backend" Aug 13, 2025
ecole41 and others added 4 commits August 13, 2025 10:54
…ss attribute so that it does not get tracked and saves the weights as numpy objects instead of variables

in addition, a spurious call to compute_loss is performed at every epoch
which needs to be further studied
@scarlehoff scarlehoff force-pushed the fix_jax_backend branch 2 times, most recently from af85595 to dd4f878 August 13, 2025 08:59
@scarlehoff scarlehoff added the run-fit-bot label (Starts fit bot from a PR.) Aug 13, 2025
Member

@scarlehoff scarlehoff left a comment

This can be merged as soon as the tests pass. I've rebased on top of the latest master.

There might be a problem with the jax installation in the CI when eko is included because of the version of numpy, and with the regressions because STEPS_PER_EPOCHS is changed globally, so fingers crossed.

@github-actions

Greetings from your nice fit 🤖 !
I have good news for you, I just finished my tasks:

Check the report carefully, and please buy me a ☕ , or better, a GPU 😉!

@scarlehoff scarlehoff merged commit a2c8c3f into master Aug 14, 2025
15 checks passed
@scarlehoff scarlehoff deleted the fix_jax_backend branch August 14, 2025 07:21

3 participants