-
Notifications
You must be signed in to change notification settings - Fork 13
Fixing Jax Backend #2342
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixing Jax Backend #2342
Conversation
|
@scarlehoff Just to update on this: No Sum Rules: I have tested this using a run card without sum rules and the same error appears as the xparams are tracked when they are initialised using keras i.e when you run self.x_in = x_in , self.x_in is tracked even if x_in is not. Issue with Preprocessing: I have tried to investigate why when uncommenting the stopping, we get this error: #2318 (comment). I have found that the alphas and betas are accessible after being built and their values are successfully accessed several times after being built. However, for some reason when this error arises, |
|
This last commit seems to fix several problems. The only one I am not sure yet why is it happening is the "Issue with Preprocessing". It seems that calling the validation model's But the interesting thing is that the layer is the same, so one can call the training, and once it is built, call the validation. This is obviously and morally wrong but seems to work... there must be something that can be done to the validation model in order to call it directly without the extra step. |
e478e25 to
1ad74c1
Compare
|
@scarlehoff I'm finding that with this version, I'm getting an indexing error when running with File "/Users/ellacole/codes/nnpdf/nnpdfgit/nnpdf/n3fit/src/n3fit/backends/keras_backend/MetaModel.py", line 175, in perform_fit
history = super().fit(
^^^^^^^^^^^^
File "/Users/ellacole/miniconda3/envs/nnpdf_dev/lib/python3.12/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/Users/ellacole/miniconda3/envs/nnpdf_dev/lib/python3.12/site-packages/optree/ops.py", line 747, in tree_map
return treespec.unflatten(map(func, *flat_args))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 15 is out of bounds for axis 0 with size 1The shapes of the inputs look correct so I'm not sure why this is happening. : DEBUG: x_params shapes:
pdf_input: (100, 152, 1)
xgrid_integration: (100, 2000, 1)
DEBUG: y shapes:
y[0]: (100, 1)
y[1]: (100, 1)
y[2]: (100, 1)
y[3]: (100, 1)
y[4]: (100, 1)
y[5]: (100, 1)
DEBUG: steps_per_epoch: 100
DEBUG: epochs: 900Do you maybe have a run card which works with this |
|
oh, this is a funny one. It has nothing to do with the shape, it is a GPU optimization we did. Just put, at the beginning of this function:
if K.backend() == "jax":
return 1(you'll have to put I'll try to fix it up in a more general way (I should benchmark jax in gpu, haven't done that yet) before merging. Ps: I cannot give you a working runcard because the computer I was using for this is in a storage room right now ^^U but the BasicRuncard should work Edit: I've added the change above and added a |
…ss attribute so that it is not get tracked and saves the weights as numpy objects instead of variables in addition, a spurious call to compute_loss is perform at every epoch which needs to be further studied
af85595 to
dd4f878
Compare
dd4f878 to
d2aa556
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be merged as soon as the tests pass. I've rebased on top of the latest master.
There might be a problem with the jax installation in the CI when eko is included because of the version of numpy, and with the regressions because STEPS_PER_EPOCHS is changed globally, so fingers crossed.
|
Greetings from your nice fit 🤖 !
Check the report carefully, and please buy me a ☕ , or better, a GPU 😉! |
This PR aims to fix the jax backend within
n3fit. So far, the input to the model has been converted from a tracked dictionary into a standard Python dictionary and early stopping has been removed.When setting the
KERAS_BACKEND=jaxand running n3fit, this model stops training before the first epoch.