-
-
Notifications
You must be signed in to change notification settings - Fork 156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce a JAX Linker class #21
Conversation
8c40cd3
to
23c4705
Compare
3289194
to
c95ae76
Compare
d6b5d52
to
047295a
Compare
07ea053
to
202bcf2
Compare
This is a cheap work-around for binary operators that actually behave as variadic operators. See aesara-devs/aesara#26.
202bcf2
to
a9facbd
Compare
Interestingly I found a small differences between the gradient from theano-jax and native jax in the code snippet example: dlogpt = theano.grad(cost=model.logpt, wrt=model.free_RVs)
dlogp_logp_fn = theano.function(inputs=model.free_RVs, outputs=[model.logpt] + dlogpt)
dlogp_logp_fn(**model.test_point)
# [DeviceArray(-2.74961873, dtype=float64),
# DeviceArray(0.30685282, dtype=float64),
# DeviceArray(4.4408921e-16, dtype=float64),
# DeviceArray(0., dtype=float64)]
fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
logp_fn_jax = jax_funcify(fgraph)[0]
dlogp_logp_fn = jax.value_and_grad(logp_fn_jax, argnums=range(len(model.free_RVs)))
dlogp_logp_fn(
model.test_point['a_log__'],
model.test_point['b_log__'],
model.test_point['x'])
# (DeviceArray(-2.74961873, dtype=float64),
# (DeviceArray(0.30685282, dtype=float64),
# DeviceArray(2.22044605e-16, dtype=float64),
# DeviceArray(-0., dtype=float64))) https://gist.github.com/junpenglao/fe5e1b451c076cc7b4ca16acdd7d6472 |
I'm glad they're so similar, though! Those |
that makes sense - Jax is by default float32, and to enable float64 there are a bit more steps: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision |
As this is basically functional (even if basic) I propose we merge this and start doing individual PRs against this. |
Closes Theano#10. Well, at least the JAX part, but Cython and Numba implementations can follow very similar approaches and accomplish the same thing.
a9facbd
to
0e4bf69
Compare
I just pushed some major fixes that should add functionality and coverage for all the |
Doing this enables general use of `jax` >= 0.2.0; however, we still need to find a better fix for the shape-related problems introduced by omnistaging. Closes Theano#43.
1097f81
to
82edad2
Compare
82edad2
to
a62dec2
Compare
I can now sample with jax 0.2.0. |
I am getting an error of
|
I think I'll merge, it's a bit of a nuisance we can't PR into this. |
OK damn, I was hoping it was only the syntax error (which I fixed on master) but other tests are failing too :-/. |
Quick fix for import issues introduced by #21
This PR introduces a JAX
Linker
class that compiles graphs to XLA usingjax
.Local PyMC3 Testing Example
For anyone interested in trying out this functionality, first install the latest PyMC3 from GitHub (e.g.
pip install --no-deps git+https://github.com/pymc-devs/pymc3
). We need the latest PyMC3 because it contains a Theano import fix that allows it to run using Theano-PyMC (i.e. this project). The option--no-deps
is used to preventpip
from installing the old Theano. If you leave out that option, you'll have topip uninstall theano
.Next, check out this PR branch (e.g.
git clone git@github.com:brandonwillard/Theano.git -b jax-linker
) and install (e.g.pip install -r requirements.txt
) Theano-PyMC.Finally, try a small example in PyMC3:
The output should look similar to the following: