Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a JAX Linker class #21

Merged
merged 8 commits into from
Sep 27, 2020
Merged

Conversation

brandonwillard
Copy link
Member

@brandonwillard brandonwillard commented Jul 31, 2020

This PR introduces a JAX Linker class that compiles graphs to XLA using jax.

Local PyMC3 Testing Example

For anyone interested in trying out this functionality, first install the latest PyMC3 from GitHub (e.g. pip install --no-deps git+https://github.com/pymc-devs/pymc3). We need the latest PyMC3 because it contains a Theano import fix that allows it to run using Theano-PyMC (i.e. this project). The option --no-deps is used to prevent pip from installing the old Theano. If you leave out that option, you'll have to pip uninstall theano.

Next, check out this PR branch (e.g. git clone git@github.com:brandonwillard/Theano.git -b jax-linker) and install (e.g. pip install -r requirements.txt) Theano-PyMC.

Finally, try a small example in PyMC3:

import theano

import pymc3 as pm

# Disable C compilation by default
theano.config.cxx = ""

# This will make the JAX Linker the default
theano.config.mode = "JAX"

# Create a simple model
with pm.Model() as model:
    a = pm.Exponential("a", 1.0)
    b = pm.InverseGamma("b", 0.5, 0.5)
    x = pm.Normal("x", a, b)

# Evaluate the log-likelihood in JAX!
model.logp(model.test_point)

The output should look similar to the following:

/home/bwillard/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lax/lax.py:5905: UserWarning: Explicitly requested dtype float64 requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
/home/bwillard/apps/anaconda3/envs/theano-36/lib/python3.7/site-packages/jax/lax/lax.py:5905: UserWarning: Explicitly requested dtype float64 requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
  warnings.warn(msg.format(dtype, fun_name , truncated_dtype))
Out[37]: DeviceArray(-2.749619, dtype=float32)

@brandonwillard brandonwillard added the enhancement New feature or request label Jul 31, 2020
@brandonwillard brandonwillard self-assigned this Jul 31, 2020
@brandonwillard brandonwillard linked an issue Jul 31, 2020 that may be closed by this pull request
@brandonwillard brandonwillard force-pushed the jax-linker branch 10 times, most recently from 3289194 to c95ae76 Compare August 6, 2020 02:40
@brandonwillard brandonwillard force-pushed the jax-linker branch 10 times, most recently from d6b5d52 to 047295a Compare August 18, 2020 18:17
This is a cheap work-around for binary operators that actually behave as
variadic operators.  See aesara-devs/aesara#26.
@junpenglao
Copy link
Contributor

Interestingly I found a small differences between the gradient from theano-jax and native jax in the code snippet example:

dlogpt = theano.grad(cost=model.logpt, wrt=model.free_RVs)
dlogp_logp_fn = theano.function(inputs=model.free_RVs, outputs=[model.logpt] + dlogpt)
dlogp_logp_fn(**model.test_point)

# [DeviceArray(-2.74961873, dtype=float64),
#  DeviceArray(0.30685282, dtype=float64),
#  DeviceArray(4.4408921e-16, dtype=float64),
#  DeviceArray(0., dtype=float64)]

fgraph = theano.gof.FunctionGraph(model.free_RVs, [model.logpt])
logp_fn_jax = jax_funcify(fgraph)[0]
dlogp_logp_fn = jax.value_and_grad(logp_fn_jax, argnums=range(len(model.free_RVs)))
dlogp_logp_fn(
    model.test_point['a_log__'],
    model.test_point['b_log__'],
    model.test_point['x'])

# (DeviceArray(-2.74961873, dtype=float64),
#  (DeviceArray(0.30685282, dtype=float64),
#   DeviceArray(2.22044605e-16, dtype=float64),
#   DeviceArray(-0., dtype=float64)))

https://gist.github.com/junpenglao/fe5e1b451c076cc7b4ca16acdd7d6472

@brandonwillard
Copy link
Member Author

brandonwillard commented Sep 26, 2020

I'm glad they're so similar, though!

Those 1e-16 values? The JAX value appears to be half the Theano value. Perhaps there's a step in there somewhere that's reducing precision. Regardless, those are both within common effectively-zero ranges, so it's probably not an error.

@junpenglao
Copy link
Contributor

that makes sense - Jax is by default float32, and to enable float64 there are a bit more steps: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision

@twiecki
Copy link
Contributor

twiecki commented Sep 26, 2020

As this is basically functional (even if basic) I propose we merge this and start doing individual PRs against this.

twiecki
twiecki previously approved these changes Sep 26, 2020
@twiecki twiecki marked this pull request as ready for review September 26, 2020 12:30
Closes Theano#10.

Well, at least the JAX part, but Cython and Numba implementations can follow
very similar approaches and accomplish the same thing.
@brandonwillard
Copy link
Member Author

I just pushed some major fixes that should add functionality and coverage for all the *Subtensor* Ops (i.e. all varieties of array indexing).

Doing this enables general use of `jax` >= 0.2.0; however, we still need to find
a better fix for the shape-related problems introduced by omnistaging.

Closes Theano#43.
@twiecki
Copy link
Contributor

twiecki commented Sep 27, 2020

I can now sample with jax 0.2.0.

@junpenglao
Copy link
Contributor

I am getting an error of Identity lacking nfunc_spec attr, could be fixed with adding the below to jaxify.py

@jax_funcify.register(Identity)
def jax_funcify_Identity(op):
    def identity(x):
        return x

    return identity

@twiecki
Copy link
Contributor

twiecki commented Sep 27, 2020

I think I'll merge, it's a bit of a nuisance we can't PR into this.

@twiecki twiecki merged commit 98875d1 into aesara-devs:master Sep 27, 2020
@twiecki
Copy link
Contributor

twiecki commented Sep 27, 2020

OK damn, I was hoping it was only the syntax error (which I fixed on master) but other tests are failing too :-/.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Use Numba, Cython, and JAX for Python-implemented Ops
3 participants