In [1]:
import os
# Expose only CUDA device 0 to this process. This is deliberately done before
# `jax` is imported (next cell): the CUDA runtime presumably reads this variable
# when JAX initializes its GPU backend, so setting it later would have no effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Pallas is an extension to JAX that enables writing custom kernels for GPUs and TPUs. It aims to provide fine-grained control over the generated code, combined with the high-level ergonomics of JAX tracing and the jax.numpy API.

Pallas requires users to think about memory access and about how to divide computation across multiple compute units.

On GPUs, Pallas lowers to Triton; on TPUs, Pallas lowers to Mosaic.

In [3]:
from functools import partial
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl
import numpy as np

In [8]:
# Kernel that adds two vectors element-wise.
def add_vectors_kernel(x_ref, y_ref, o_ref):
    """Store the element-wise sum of two input Refs into the output Ref.

    Both inputs are read in full with ``ref[...]``; the sum is written back
    through ``o_ref``. Nothing is returned -- the output Ref is the only channel
    for the result.
    """
    total = x_ref[...] + y_ref[...]
    o_ref[...] = total

In [5]:
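# Kernel inputs and outputs are Ref objects, which behave like mutable pointers to (blocks of) arrays in memory.
# The kernel function does not return anything; it writes its result into the output Ref (o_ref).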

In [6]:
# Reading the Ref object with the ellipsis [...] means reading the whole array. Alternatively, you could also use [:]

In [7]:
# Use the pallas_call function to launch the kernel

In [9]:
@partial(jax.jit, static_argnames="interpret")
def add_vectors(x: jax.Array, y: jax.Array, *, interpret: bool = False) -> jax.Array:
    """Add two arrays element-wise using ``add_vectors_kernel``.

    Args:
      x: First operand. The output takes its shape and dtype.
      y: Second operand, read in full by the kernel alongside ``x``.
      interpret: If True, run the kernel with Pallas' interpreter instead of
        compiling it for the accelerator. This is useful where compilation is
        rejected, e.g. GPUs below compute capability 8.0, for which the Triton
        backend refuses to run (see this cell's error output). Marked static
        for ``jax.jit`` because it selects a different lowering.

    Returns:
      A new array holding ``x + y``. pallas_call allocates the output buffer
      from ``out_shape``; the caller does not pass one in.
    """
    return pl.pallas_call(
        add_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


add_vectors(jnp.arange(8), jnp.arange(8))

# Notice that, unlike in Triton, you do not need to allocate output memory and pass it in as a pointer. Instead, pallas_call only
# needs to know the output's shape and dtype (out_shape). It allocates the buffer, passes it to the kernel as the last argument, and returns it to you.

XlaRuntimeError: FAILED_PRECONDITION: Triton support is only enabled for Ampere GPUs (compute capability 8.0) and up, but got compute capability 7.5.

In [10]:
def iota_kernel(o_ref):
    """Write the current program instance's index along grid axis 0 into o_ref.

    When launched with a 1-D grid whose size equals the output length,
    instance ``i`` stores ``i`` at position ``i``, producing 0, 1, 2, ...
    """
    idx = pl.program_id(0)
    o_ref[idx] = idx

In [None]:
def iota(size: int, *, interpret: bool = False) -> jax.Array:
    """Return the int32 array [0, 1, ..., size - 1] computed by ``iota_kernel``.

    One program instance is launched per output element (``grid=(size,)``), and
    instance ``i`` writes ``i`` into slot ``i``.

    Args:
      size: Number of output elements, which is also the grid size.
      interpret: If True, run the kernel with Pallas' interpreter instead of
        compiling it for the accelerator (useful on hardware the compiled
        backend does not support).
    """
    return pl.pallas_call(
        iota_kernel,
        out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
        grid=(size,),
        interpret=interpret,
    )()


iota(8)

In [None]:
def matmul_kernel(x_ref, y_ref, z_ref):
    """Write the matrix product of the current x and y blocks into z's block.

    Each Ref holds whatever block its BlockSpec selected for this program
    instance. Both input blocks are read in full and their product is stored
    into the output block.
    """
    lhs_block = x_ref[...]
    rhs_block = y_ref[...]
    z_ref[...] = lhs_block @ rhs_block

def matmul(x: jax.Array, y: jax.Array, *, grid: tuple[int, int] = (2, 2),
           interpret: bool = False) -> jax.Array:
    """Compute ``x @ y`` with a blocked Pallas kernel (``matmul_kernel``).

    The output is tiled into ``grid[0] x grid[1]`` blocks. Program instance
    ``(i, j)`` receives the i-th row-block of ``x`` (all of its columns) and
    the j-th column-block of ``y`` (all of its rows), and writes output
    block ``(i, j)``.

    Args:
      x: 2-D array of shape (M, K).
      y: 2-D array of shape (K, N).
      grid: Number of blocks along the output's rows and columns. Defaults to
        (2, 2), the original fixed tiling. M must be divisible by grid[0] and
        N by grid[1].
      interpret: If True, run the kernel with Pallas' interpreter instead of
        compiling it for the accelerator.

    Returns:
      Array of shape (M, N) with the dtype of ``x``.

    Raises:
      ValueError: if the inner dimensions differ, if a grid entry is not
        positive, or if (M, N) does not divide evenly into the grid. A ragged
        split would give blocks of size M // grid[0], which do not cover the
        trailing rows/columns, leaving part of the output unwritten.
    """
    m, k = x.shape
    k_y, n = y.shape
    if k != k_y:
        raise ValueError(
            f"inner dimensions must match, got x{tuple(x.shape)} @ y{tuple(y.shape)}")
    grid_m, grid_n = grid
    if grid_m <= 0 or grid_n <= 0:
        raise ValueError(f"grid entries must be positive, got {tuple(grid)}")
    if m % grid_m or n % grid_n:
        raise ValueError(
            f"output shape {(m, n)} is not divisible by grid {tuple(grid)}")
    block_m, block_n = m // grid_m, n // grid_n
    return pl.pallas_call(
        matmul_kernel,
        out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
        grid=(grid_m, grid_n),
        in_specs=[
            # Row-block i of x, spanning the full contraction dimension.
            pl.BlockSpec((block_m, k), lambda i, j: (i, 0)),
            # Column-block j of y, spanning the full contraction dimension.
            pl.BlockSpec((k, block_n), lambda i, j: (0, j)),
        ],
        out_specs=pl.BlockSpec((block_m, block_n), lambda i, j: (i, j)),
        interpret=interpret,
    )(x, y)
# Check the Pallas matmul against XLA's matmul on random inputs (fixed seed).
k1, k2 = jax.random.split(jax.random.key(0))
x = jax.random.normal(k1, (1024, 1024))
y = jax.random.normal(k2, (1024, 1024))
z = matmul(x, y)
# The two implementations accumulate in different orders and may use
# reduced-precision float32 matmul defaults (e.g. TF32 on Ampere GPUs). That
# makes assert_allclose's default rtol=1e-7 far too strict for 1024-term dot
# products. These tolerances are deliberately loose; tighten them if both
# sides run at full float32 precision.
np.testing.assert_allclose(z, x @ y, rtol=1e-2, atol=1e-1)

In [12]:
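# The grid is just a tuple, analogous to a CUDA grid: the kernel runs once per grid index, and on GPUs each of these
# program instances plays the role of a CUDA thread block.
# Blocks are described with pl.BlockSpec, which pairs a block shape (again just a tuple) with an index map that chooses
# which block each program instance sees. Unlike a CUDA block size, the block shape is not limited to 1024 elements:
# the block's elements get spread across the (at most 1024) threads that are actually launched.
# Pallas, like Triton, operates on arrays/vectors. Do not think in threads and warps; think only in blocks.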

Functions that call `pallas_call` can be transformed with `jax.vmap`.

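To get the equivalent of CUDA's `blockIdx` (the index of the current program instance along a grid axis), use `jax.experimental.pallas.program_id(axis=0|1|2)`, e.g. `pl.program_id(0)`.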

Use `pl.num_programs(axis)` to get the grid size along a given axis (the equivalent of CUDA's `gridDim`).