jitted predict #7

Closed
soldierofhell opened this issue Dec 21, 2021 · 2 comments

@soldierofhell

Hi, I'm starting to explore your framework. I'm familiar with JAX, but not with objax. I noticed that the training ops are jitted with objax.Jit, but since my goal is fast prediction embedded in some larger JAX code, I wonder whether predict() can also be jitted. Thanks in advance.

Regards,

@wil-j-wil
Collaborator

Yes, you can JIT-compile that part just like any other: simply include it in your larger chunk of jitted code.
Or, if you want to JIT-compile prediction on its own, you can do:

```python
predict = objax.Jit(model.predict, model.vars() + opt_hypers.vars())
```
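For readers coming from plain JAX, the two options above map onto `jax.jit` roughly as in the minimal sketch below. It uses plain JAX rather than objax, and `predict`, `params`, and `larger_computation` are hypothetical stand-ins: a linear model takes the place of the real `model.predict`. Objax modules keep their state in variables, which is why `objax.Jit` is given a variable collection. The functional style used here instead passes the parameters explicitly as arguments.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a model's predict(): a pure function of its
# parameters and inputs (objax would instead hold these in model.vars()).
def predict(params, x):
    return x @ params["w"] + params["b"]

# Option 1: JIT-compile prediction on its own.
predict_jit = jax.jit(predict)

# Option 2: call predict() inside a larger function and jit only the outer one.
# jax.jit traces every function the outer one calls, so predict() is compiled
# into the same XLA computation without needing a jit of its own.
@jax.jit
def larger_computation(params, x):
    y = predict(params, x)
    return jnp.sum(y ** 2)

params = {"w": jnp.array([2.0, -1.0]), "b": jnp.array(0.5)}
x = jnp.array([[1.0, 1.0], [3.0, 2.0]])

print(predict_jit(params, x))         # [1.5 4.5]
print(larger_computation(params, x))  # 22.5
```

Option 2 is usually the simpler choice when prediction is only one step of a bigger pipeline, because XLA can then optimize the whole computation together.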

@wil-j-wil
Collaborator

Closing, assuming the above solves this issue. If not, feel free to re-open.
