JAX qfunctions #781
My concern with 1 is that it doesn't (on its own) allow comparing JAX backends to those we've already written, and that the JIT might be very slow (based on the times you've reported in CRIKit experiments). Regarding the device array handoff, there are a few places in the source that reference https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
So the reason the CRIKit compiles are slow is that the compilation has to unroll a large Python loop. No such issue should be present in libCEED if we write a pure JAX backend, especially since most of the operations can be expressed with looping primitives that JAX does support.
Ah, so we need to better use JAX looping primitives in CRIKit. CeedElemRestriction transpose does accumulation so we'd need a way to avoid conflicting overlapping writes (atomics are faster, though non-deterministic; libCEED docs classify backends by determinism https://libceed.readthedocs.io/en/latest/gettingstarted/#backends). Has this been addressed in some later PR?
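A deterministic alternative to atomics for the transpose accumulation is a scatter-add. A minimal sketch (illustrated with NumPy so it runs anywhere; the JAX analog would be `out.at[indices].add(evals)`, which lowers to a deterministic XLA scatter):

```python
import numpy as np

# Element restriction transpose: element-local values are summed back into
# shared global nodes. Repeated indices must accumulate, not overwrite.
indices = np.array([0, 1, 1, 2])         # element dof -> global node (node 1 is shared)
evals = np.array([1.0, 2.0, 3.0, 4.0])   # element-local values

out = np.zeros(3)
np.add.at(out, indices, evals)           # unbuffered scatter-add, handles repeats
# out == [1., 5., 4.]
```

The `np.add.at` / `.at[].add()` form avoids the conflicting-overlapping-writes problem entirely, at the cost of being slower than atomics on GPU.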
As to why we can't use the JAX loop primitives in CRIKit right now: we need to be indexing into a Python list, but the loop counter variable has to be a JAX tracer object, which doesn't implement the interface list indexing requires. Anyway, I just looked through the JAX PRs and it doesn't look like the reverse direction has happened yet. I'll look into doing it myself, though; it doesn't look like there's all that much code that would have to change, and I can probably take some hints from the numba codebase. Another possibility (?) would be to use DLPack, which JAX supports. Not sure if that will work on GPU, but if it does, that would be easier than using the CUDA array interface.
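A sketch of the restructuring that makes JAX loop primitives usable: replace the Python list (which a traced counter can't index) with a stacked array, which a traced counter can index. This is a hypothetical illustration written in plain NumPy, with a Python `for` loop standing in for `jax.lax.fori_loop`:

```python
import numpy as np

# A Python list can't be indexed by a JAX tracer inside lax.fori_loop,
# but a stacked array can: stacked[i] is a traceable gather.
chunks = [np.full(3, 1.0), np.full(3, 2.0), np.full(3, 3.0)]
stacked = np.stack(chunks)            # shape (3, 3); list -> single array

def body(i, acc):                     # same signature lax.fori_loop's body takes
    return acc + stacked[i]          # array indexing works under tracing

acc = np.zeros(3)
for i in range(stacked.shape[0]):     # stand-in for jax.lax.fori_loop(0, 3, body, acc)
    acc = body(i, acc)
# acc == [6., 6., 6.]
```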
Sounds great. PETSc has DLPack support and I think it's at a level that makes sense for libCEED too.
So where would be the natural place to put the code for this? Would we add a new pair of functions to `ceed.h`, say `int CeedVectorToDLPack(CeedVector, DLManagedTensor **)` and `int CeedVectorFromDLPack(CeedVector *, DLManagedTensor *)`, and implement them for every backend, or something else? EDIT: after taking a closer look at ceed.h, a separate home for these functions makes more sense to me. Open to suggestions though. Development is going in https://github.com/CEED/libCEED/tree/emily/dlpack for now.
I'm not big on hiding includes behind ifdefs at compile time.
I'm perhaps not understanding, but does this need to be in the C interface? I thought this was for passing data back and forth in Python?
If it has to be in the C interface, perhaps a separate header and file, like ceed/cuda.h and ceed/hip.h
Well, it's about getting data to and from C and Python (so that you can use JAX qfunctions with any CEED backend). I suppose it doesn't have to be in the C API, but there has to be a function we can call from the cffi in the Python implementation, and it should be the same function regardless of what backend is in use. I think this would be a useful function to have in the main C interface, though, because it would more easily enable users to use, say, C++ TensorFlow code or any other C++ ML code as a qfunction instead of just JAX.
That makes sense. I'd be in favor of adding ceed/dlpack.h and interface/ceed-dlpack.c if we add this to the C interface in anticipation of future flexibility.
Is there currently a good way to determine which backend a given `CeedVector` is associated with?
You can give a host array to a GPU backend just fine. CPU backends getting a device array will error. See the PETSc example for how to query the backend's preferred memory type and set that as the handoff memory type. We make the assumption that users will hand off the correct type of device array for the backend.
So should I assume that the user has downloaded the data to the host before handing it to a CPU backend? By the way, in case you want to see an example of this in use, I can point you to one.
Oh, after thinking about it a bit more, we'll also need functions to transform the input/output context data for a QFunction to/from DLPack, right? (So that we can pass the input fields as DLPack tensors.)
Hmm, the PETSc interface to DLPack is pure Cython. The contexts are plain data (no pointers) of specified size, so they can just be copied. A context is technically unnecessary in languages with closures or some other dynamic way to construct functions, which enables qfunctions to be parametrized without depending on global variables.
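To illustrate the closure point above: in Python, qfunction parameters can be captured by a closure instead of being packed into a copied context struct. A minimal hypothetical sketch (the qfunction shape here is invented for illustration, not libCEED's actual signature):

```python
def make_scaled_mass(scale):
    """Build a qfunction parametrized by `scale`. The closure plays the
    role of the context struct, with no global variables involved."""
    def qf(qdata, u):
        return scale * qdata * u   # pointwise operation at quadrature points
    return qf

qf = make_scaled_mass(2.0)
result = qf(3.0, 4.0)              # 2.0 * 3.0 * 4.0 == 24.0
```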
I don't think I understand your envisioned outcome. If we go the Python route, I would actually make the changes on the Python side. If we go the C route, we'd make a brand new backend that delegates back to the different backends. I feel like the Python route might be easier (I've wanted to add native Python QFunctions for a while but haven't had the time).
So my envisioned outcome is essentially what you're describing with the Python route.
I think writing QFunctions in Python/JAX is what we want. I'd like to preserve the ability to use QFunctions that were written in C from Python. I'm not sure of the value of a straight C interface.
@jedbrown just to be entirely unambiguous about what you're thinking about, are you thinking of a `ceed/dlpack.h` with something like `int CeedVectorToDLPack(CeedVector, DLManagedTensor **)` and `int CeedVectorFromDLPack(CeedVector *, DLManagedTensor *)` (perhaps with functions related to a QFunction instead of a Vector), or were you thinking of copy-pasting the contents of the upstream `dlpack.h` into libCEED?
Yes, the above functions look about right. I think we should do what others do and "vendor" the header from upstream -- copy it into the libceed repository. That'll allow us to implement those public interfaces without configuration options. One choice would be to keep the header private (don't install it) and only include it in `interface/ceed-dlpack.c`.
I'm testing that right now. FWIW, I'm hitting an error, and the same happens with other DLPack types. Any idea what I might be doing wrong? (I should have everything updated in the remote, so if you want to see the errors for yourself, cloning my branch and pip installing should work.) EDIT: ah, figured it out; it was a missing include.
So it turns out that tensorflow uses pybind11, which complicates working with the returned capsule. EDIT: figured this one out.
But I thought you were inside of Python?
Yes, but jax calls out to XLA, which is a C++ library (and a part of tensorflow). See https://github.com/google/jax/blob/master/jax/_src/dlpack.py#L43 and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/dlpack.cc#L250
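For reference, the DLPack protocol that JAX/XLA implements is the same one NumPy supports (since NumPy 1.22), so the zero-copy handoff can be illustrated without JAX installed:

```python
import numpy as np

# Any producer exposing __dlpack__ can be consumed zero-copy by
# np.from_dlpack; here NumPy plays both producer and consumer.
a = np.arange(4, dtype=np.float64)
b = np.from_dlpack(a)        # consumes a.__dlpack__(), no copy made

a[0] = 42.0                  # b shares memory, so it sees the write
# b == [42., 1., 2., 3.]
```

With JAX the producer side would be `jax.dlpack.to_dlpack` (or the array's `__dlpack__` method), but the capsule contents are the same `DLManagedTensor` either way.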
But why can't the conversion to DLPack tensors occur in Python? When a QFunction is created, a user provides a pointer to a function (or the interface takes the user's function and converts it into an appropriate pointer; note that this function can be written in any language) which takes in raw pointers to the context data, input arrays, and output arrays. Why not convert the raw array pointers to DLPack tensors inside of a Python function call?
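To sketch what wrapping those raw pointers inside a Python call could look like: a ctypes buffer stands in for the raw `double *` a QFunction receives, and is wrapped zero-copy into a NumPy array. (This is an illustration of the idea, not libCEED's actual calling convention.)

```python
import ctypes
import numpy as np

# Stand-in for the raw double* an input/output field would arrive as.
raw = (ctypes.c_double * 4)(1.0, 2.0, 3.0, 4.0)

# Zero-copy view over the C buffer; writes go straight to the C memory,
# so the caller (libCEED) sees the qfunction's output without a copy.
view = np.ctypeslib.as_array(raw)
view *= 2.0
# raw now holds [2.0, 4.0, 6.0, 8.0]
```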
Oh, I see what you mean; no reason it couldn't happen inside of Python. I just don't see what advantage that would confer over also having the code in C, especially now that I figured out how to get around the problem I was just describing. So I could re-write the C portion of this code in Python using CFFI and/or ctypes if you have a good reason for it to be there instead of in C, but as far as I can see the only difference would be that the code would then only be accessible through the Python API, as opposed to potentially being available in every libCEED API. Either way, the C code I have works to get a read-only array from JAX. Still working on a read-write compatible one; for some reason weird things happen with that.
Unrelated to the comments immediately above this one, I do have one question of semantics: if the incoming DLPack tensor's precision doesn't match the `CeedVector`'s, should we error or convert?
Erroring is the right behavior in case of precision mismatch. See also #778. As for where the conversion code lives, we just want to keep build-time configuration as simple as possible. Python is fully dynamic, but the cffi code is basically equivalent to vendoring the header.
Just posting an update here in case anyone has any ideas about how to get around the issue I'm working through; the main problem I'm dealing with is that my current approach to creating the `PyCapsule` segfaults for some reason I can't yet figure out. An alternative that should "just work" would be to write a C or C++ function that creates the capsule directly.
@caidao22 Do you have experience/suggestions for this issue connecting vectors exposed via DLPack with JAX?
Shouldn’t PyCapsule_New take three input arguments, with the third one being a destructor?
The destructor is optional (see https://docs.python.org/3/c-api/capsule.html), and although omitting it might cause a memory leak, it shouldn't cause a segfault.
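A quick check of that claim, calling `PyCapsule_New` with a NULL destructor through ctypes (the capsule name `"dltensor"` is what DLPack consumers conventionally expect; everything else here is a stand-in payload for illustration):

```python
import ctypes

PyCapsule_New = ctypes.pythonapi.PyCapsule_New
PyCapsule_New.restype = ctypes.py_object
PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p)

PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer
PyCapsule_GetPointer.restype = ctypes.c_void_p
PyCapsule_GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p)

NAME = b"dltensor"                     # conventional DLPack capsule name
payload = (ctypes.c_double * 3)(1.0, 2.0, 3.0)
addr = ctypes.cast(payload, ctypes.c_void_p)

# Third argument (destructor) is NULL: legal per the C API docs, at the
# cost of a leak if nobody frees the payload.
cap = PyCapsule_New(addr, NAME, None)
assert PyCapsule_GetPointer(cap, NAME) == addr.value   # no segfault
```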
What if you comment out these two lines?
I would also try passing NULL as the destructor before the destructor is implemented.
If you comment out those two lines, then…
In PETSc I used PyCapsule_New through the Cython header Python.h and did not need to set the types. Perhaps it is easier to follow https://gitlab.com/petsc/petsc/-/blob/main/src/binding/petsc4py/src/PETSc/petscvec.pxi#L580
@jedbrown how do you feel about having a Cython module with a single function that handles the `PyCapsule` creation? It would make Cython a build dependency, though.
Agreed that it's a heavy dependency, but let's try that, and once it's working, maybe we'll understand the problem well enough that we can drop the Cython dependency. And if not, no big deal.
This probably won't matter too much for this Cython code, but once I have this working, am I correct in assuming the next step would be to write a pair of functions like `CeedVectorToDLPack` and `CeedVectorFromDLPack`?
Just leaving an update on my progress here -- getting data from JAX into a `CeedVector` is still failing, sometimes with one error and sometimes with another, which seems to indicate, as I mentioned above, that TensorFlow is reading in the wrong fields. I say that in part because I know every field in the struct I'm passing is initialized, per the output of the above-mentioned printing function.
@jedbrown and I have been discussing the possibility of using JAX to write qfunctions, since it supports JIT compilation and automatic differentiation. I see several ways to go about this, and several potential roadblocks, so I'm opening this issue for discussion. First, we need to decide what sort of architecture we want -- here are a few options. One option would be to hand data off to JAX qfunctions as `DeviceArray` instances. The major advantage of this approach would be that it's at least to some extent backend-independent (perhaps not the part about getting data into `DeviceArray`s) and would require writing less Python code (i.e. not having to implement most libCEED functions in Python), but the biggest disadvantage would probably be that it's not necessarily easy to get the data into a JAX array on the device with no copying. The goal would be to avoid having to write C++ code that depends on XLA itself, since such code can really only be compiled in any reasonable manner by Bazel.

If any of the libCEED devs have thoughts on this or are interested in working with me on implementing it, please let me know.
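To make the discussion concrete, a qfunction is just a pure pointwise operation at quadrature points, which is the kind of function JAX handles well. A hypothetical sketch (written with NumPy so it runs without JAX; under JAX one would substitute `jax.numpy` and wrap with `jax.jit` to get compilation and autodiff):

```python
import numpy as np

def mass_qfunction(qdata, u):
    """Pointwise 'mass operator' qfunction: scale the solution value at
    each quadrature point by precomputed quadrature data (weight * detJ)."""
    return qdata * u

# With JAX, the identical function body operates on jax.numpy arrays and
# can be compiled as: mass_qf = jax.jit(mass_qfunction)
out = mass_qfunction(np.array([2.0, 3.0]), np.array([4.0, 5.0]))
# out == [8., 15.]
```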