In [2]:
import jax
import awkward as ak
import numba
import numpy as np

# Enable Awkward Array's JAX integration; the cells below build backend="jax"
# arrays and differentiate through them with jax.grad / jax.jvp.
ak.jax.register_and_check()

In [2]:
a = ak.Array([[1.0, 2, 3], [5, 6]], backend="jax")


def f(x):
    """Return ``sum(x) * sum(x)`` for a (possibly ragged) Awkward array.

    Every element is scaled by the array's total and the scaled values are
    summed, so the value is ``sum(x)**2`` and each gradient entry is
    ``2 * sum(x)``.
    """
    total = ak.sum(x)
    scaled = total * x
    return ak.sum(scaled)


f(a), jax.grad(f)(a)

(Array(289., dtype=float32),
 <Array [[34.0, 34.0, 34.0], [34.0, 34.0]] type='2 * var * float32'>)

In [None]:
# NOTE(review): repeats the imports and register_and_check call from the first
# cell — presumably so this section can be run on its own; redundant on a
# linear Restart & Run All.
import jax
import awkward as ak

ak.jax.register_and_check()

In [3]:
a = ak.Array([[1.0, 2, 3], [5, 6]], backend="jax")


# NOTE(review): this redefines ``f`` from the earlier ak.sum cell.
def f(x):
    """Return the mean of ``x`` after scaling every element by ``sum(x)``.

    Equivalent to ``sum(x)**2 / n``, where ``n`` is the total element count
    across all (ragged) rows.
    """
    scaled = ak.sum(x) * x
    return ak.mean(scaled)


f(a), jax.grad(f)(a)

(Array(57.8, dtype=float32),
 <Array [[6.8, 6.8, 6.8], [6.8, 6.8]] type='2 * var * float32'>)

In [7]:
# Behavior registry; the @ak.mixin_class decorator on SomeClass populates it.
behavior = {}

input_arr = ak.Array([1.0], backend="jax")


# Signature order is kept as-is: ufunc loop selection tries loops in order, so
# listing float32 first keeps float32 inputs on the float32 loop.
@numba.vectorize(
    [
        numba.float32(numba.float32, numba.float32),
        numba.float64(numba.float64, numba.float64),
    ]
)
def _some_kernel(lhs, rhs):
    """Elementwise sum of squares: ``lhs * lhs + rhs * rhs``."""
    sum_of_squares = lhs * lhs + rhs * rhs
    return sum_of_squares

In [8]:
@ak.mixin_class(behavior)
class SomeClass:
    """Behavior for ``with_name="SomeClass"`` records that carry ``x`` and ``y`` fields."""

    @property
    def some_kernel(self):
        """Evaluate ``_some_kernel`` elementwise on this record's ``x`` and ``y``."""
        x_field = self.x
        y_field = self.y
        return _some_kernel(x_field, y_field)


# Publish the mixin in the global registry; ak.zip below is not given behavior=.
ak.behavior.update(behavior)

arr = ak.zip({"x": input_arr, "y": input_arr}, with_name="SomeClass")

# NOTE(review): no output is recorded for this cell. _some_kernel is a numba
# ufunc and input_arr uses backend="jax" — confirm awkward can dispatch it there.
arr.some_kernel

In [1]:
# Setup for the coffea candidate example. The execution counter restarts at
# In [1] here, so this appears to be a fresh kernel session.
import awkward as ak
from coffea.nanoevents.methods import candidate
import numpy as np
import uproot

ak.jax.register_and_check()

Issue: coffea.nanoevents.methods.vector will be removed and replaced with scikit-hep vector. Nanoevents schemas internal to coffea will be migrated. Otherwise please consider using that package!.
  from coffea.nanoevents.methods import vector


In [4]:
# Make coffea's candidate mixins (e.g. "PtEtaPhiECandidate") available to the
# ak.zip call below. NOTE: this mutates the global ak.behavior registry.
ak.behavior.update(candidate.behavior)

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

# Bind the file handle to a descriptive name rather than ``f``, so running this
# cell does not clobber the function ``f`` defined in other cells.
with uproot.open(ttbar_file) as root_file:
    arr = root_file["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
                                      "Electron_mass", "Electron_charge"])

# Cartesian momentum components from (pt, eta, phi), then E = sqrt(m^2 + |p|^2).
px = arr.Electron_pt * np.cos(arr.Electron_phi)
py = arr.Electron_pt * np.sin(arr.Electron_phi)
pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)

# Keep only events with at least two electrons, so els[:, 0] below always exists.
evtfilter = ak.num(arr["Electron_pt"]) >= 2

els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
              "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
els = ak.to_backend(els, "jax")

# NOTE(review): the recorded masses (0.03125, 0.0, nan, ...) suggest the candidate
# recomputes mass from E and the momentum, which cancels badly in low precision
# because electron masses are tiny next to their momenta. Confirm before relying on it.
print(els[:, 0].mass)

[0.03125, 0.0, nan, 0.0, 0.03125]


In [9]:
# NOTE(review): repeats the imports and register_and_check call from the first
# cell — presumably so the jvp / grad experiments that follow can run standalone.
import jax
import awkward as ak
import numba
import numpy as np

ak.jax.register_and_check()

In [3]:
def f(x):
    """Pick rows 2, 2, 0 (row 2 twice), reverse each row, and cube every element.

    Works on NumPy arrays and on (ragged) Awkward arrays alike.
    """
    picked = x[[2, 2, 0], ::-1]
    return np.power(picked, 3)

In [4]:
# Ragged input (note the empty middle row) and a tangent direction that is 1.0
# only at primals[0][1]; jax.jvp then differentiates along that single element.
primals = ak.Array([[1.0, 2, 3], [], [5, 6]], backend="jax")
tangents = ak.Array([[0.0, 1, 0], [], [0, 0]], backend="jax")

In [6]:
# jax.jvp returns a pair: (f evaluated at primals, forward-mode derivative of f
# along `tangents`). NOTE: despite its name, ``grad`` receives that JVP tangent
# output (same structure as ``val``), not a gradient.
val, grad = jax.jvp(f, primals=(primals,), tangents=(tangents,))

In [7]:
# Show the primal output and the JVP tangent output side by side.
(val, grad)

(<Array [[216.0, 125.0], [...], [27.0, 8.0, 1.0]] type='3 * var * float32'>,
 <Array [[0.0, 0.0], [0.0, ...], [0.0, 12.0, 0.0]] type='3 * var * float32'>)

In [12]:
# Reverse-mode gradient of a plain sum over a ragged array: every element gets 1.0.
sum_gradient = jax.grad(np.sum)
print(sum_gradient(primals))

[[1.0, 1.0, 1.0], [], [1.0, 1.0]]


In [11]:
# Full reduction over every axis (1 + 2 + 3 + 5 + 6 = 17); axis=None is the default.
ak.sum(primals, axis=None)

Array(17., dtype=float32)

In [13]:
# NOTE(review): rebinds ``primals`` (previously an Awkward jax-backend Array) to a
# plain NumPy array; a distinct name would avoid hidden-state confusion.
primals = np.asarray([[1.0, 2.0, 3.0], [5.0, 6.0, 7.0]], dtype=np.float64)

In [14]:
# Same gradient on a regular NumPy array for comparison: all ones, input's shape.
np_sum_gradient = jax.grad(np.sum)
print(np_sum_gradient(primals))

[[1. 1. 1.]
 [1. 1. 1.]]
