In [11]:
import jax
from jax.experimental import pallas as pl   
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
import functools

In [16]:
def hbm_vmem_kernel(x_hbm_ref, out_vmem_ref, scratch_vmem_ref):
  """Stage the first three rows of the input in scratch, then emit scratch + 1.

  Args:
    x_hbm_ref: Input ref; only rows 0..2 are read.
    out_vmem_ref: Output ref, overwritten in full with ``scratch + 1``.
    scratch_vmem_ref: Scratch ref used as the destination of the copy.
  """
  leading_rows = slice(0, 3)
  # Copy rows [0, 3) of the input ref into the same rows of the scratch ref.
  pltpu.sync_copy(x_hbm_ref.at[leading_rows], scratch_vmem_ref.at[leading_rows])
  staged = scratch_vmem_ref[...]
  out_vmem_ref[...] = staged + 1

# Run hbm_vmem_kernel on a random (8, 128) float32 input and check that the
# (3, 128) result equals the first three input rows plus one.
x = jax.random.uniform(jax.random.key(0), (8, 128), jnp.float32)
hbm_vmem_call = pl.pallas_call(
  hbm_vmem_kernel,
  # The input is passed with memory_space=ANY; the kernel copies from it explicitly.
  in_specs=[pl.BlockSpec(memory_space=pltpu.MemorySpace.ANY)],
  out_shape=jax.ShapeDtypeStruct((3, 128), jnp.float32),
  scratch_shapes=(pltpu.MemorySpace.VMEM(shape=(3, 128), dtype=jnp.float32),),
)
out = hbm_vmem_call(x)

expected = x[0:3] + 1
np.testing.assert_allclose(out, expected)

In [24]:
def add_matrices_kernel(x_vmem_ref, y_vmem_ref, out_vmem_ref):
    """Write the element-wise sum of the two input refs into the output ref.

    Args:
        x_vmem_ref: First operand; read in full.
        y_vmem_ref: Second operand; read in full.
        out_vmem_ref: Destination; overwritten in full with ``x + y``.
    """
    out_vmem_ref[...] = x_vmem_ref[...] + y_vmem_ref[...]


In [32]:
def add_matrices(x:jax.Array,y:jax.Array,*,bm:int=256,bn:int=512)->jax.Array:
    """Element-wise ``x + y`` computed block by block with ``add_matrices_kernel``.

    The grid is derived from the input shape, so every ``(bm, bn)`` tile of the
    operands is visited exactly once regardless of how many tiles there are.

    Args:
        x: First 2D operand.
        y: Second 2D operand; must have the same shape as ``x``.
        bm: Rows per block (clamped to the number of rows of ``x``).
        bn: Columns per block (clamped to the number of columns of ``x``).

    Returns:
        An array with the shape and dtype of ``x`` holding ``x + y``.

    Raises:
        ValueError: If the operands are not 2D or their shapes differ.
    """
    if x.ndim!=2 or x.shape!=y.shape:
        raise ValueError(
            f"add_matrices expects two 2D arrays of equal shape, got {x.shape} and {y.shape}"
        )
    rows,cols = x.shape
    # A block never needs to be larger than the array it tiles.
    bm,bn = min(bm,rows),min(bn,cols)
    block_spec = pl.BlockSpec((bm,bn),lambda i,j:(i,j))
    return pl.pallas_call(add_matrices_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape,x.dtype),
        in_specs=[block_spec,block_spec],
        out_specs=block_spec,
        # Ceil-division so a trailing partial block is still covered.
        grid=(pl.cdiv(rows,bm),pl.cdiv(cols,bn)),
        compiler_params=pltpu.CompilerParams(
            # Blocks are independent of each other along both grid axes.
            dimension_semantics=("parallel","parallel")
        )
    )(x,y)


In [33]:
# Two uniform (256, 512) float32 operands drawn from different PRNG keys.
x = jax.random.uniform(jax.random.key(0), (256, 512), jnp.float32)
y = jax.random.uniform(jax.random.key(42), (256, 512), jnp.float32)

out = add_matrices(x, y)

In [3]:
# Shape/dtype descriptor for a (256, 512) float32 array; displayed as the cell output.
jax.ShapeDtypeStruct(shape=(256, 512), dtype=jnp.float32)

ShapeDtypeStruct(shape=(256, 512), dtype=float32)

In [4]:
def matmul_small(x: np.ndarray, y: np.ndarray) -> np.ndarray:
  """Multiply two matrices whose dimensions are each at most 256.

  Args:
    x: Left operand of shape ``(m, k)``.
    y: Right operand of shape ``(k, n)``.

  Returns:
    ``np.matmul(x, y)`` of shape ``(m, n)``.

  Raises:
    AssertionError: If ``m``, ``k`` or ``n`` exceeds 256.
  """
  # n is the number of *columns* of y; y.shape[0] is the contraction dim k.
  m, k, n = x.shape[0], x.shape[1], y.shape[1]
  assert m <= 256, f"m={m} exceeds 256"
  assert k <= 256, f"k={k} exceeds 256"
  assert n <= 256, f"n={n} exceeds 256"
  return np.matmul(x, y)

def block_matmul(
    x: np.ndarray,
    y: np.ndarray,
    *,
    bm: int = 256,
    bk: int = 256,
    bn: int = 256,
) -> np.ndarray:
  """Compute ``x @ y`` by accumulating products of ``(bm, bk) x (bk, bn)`` tiles.

  Dimensions that are not multiples of the block size are handled: the last
  tile along each axis is simply smaller (NumPy slices clamp at the array
  bounds), so no rows or columns are dropped.

  Args:
    x: Left operand of shape ``(m, k)``.
    y: Right operand of shape ``(k, n)``.
    bm: Tile height along ``m``.
    bk: Tile depth along the contraction dimension ``k``.
    bn: Tile width along ``n``.

  Returns:
    The product, shape ``(m, n)``, with the dtype of ``x``.

  Raises:
    ValueError: If the inner dimensions of ``x`` and ``y`` differ.
  """
  m, k = x.shape
  k_y, n = y.shape
  if k != k_y:
    raise ValueError(f"inner dimensions differ: x is {x.shape}, y is {y.shape}")

  z = np.zeros((m, n), dtype=x.dtype)
  # Step by block size instead of counting whole blocks so a trailing partial
  # block is still visited.
  for m_start in range(0, m, bm):
    m_slice = slice(m_start, m_start + bm)
    for n_start in range(0, n, bn):
      n_slice = slice(n_start, n_start + bn)
      for k_start in range(0, k, bk):
        k_slice = slice(k_start, k_start + bk)
        z[m_slice, n_slice] += matmul_small(x[m_slice, k_slice], y[k_slice, n_slice])
  return z

In [5]:
def matmul_flops(m: int, k: int, n: int):
  """Return ``2 * m * k * n`` for an ``(m, k) x (k, n)`` matmul."""
  multiply_adds = m * k * n
  return 2 * multiply_adds

def matmul_membw(m: int, k: int, n: int, dtype: jnp.dtype):
  """Return the combined size in bytes of an ``(m, k)``, a ``(k, n)`` and an ``(m, n)`` array of ``dtype``."""
  element_count = m * k + k * n + m * n
  bytes_per_element = np.dtype(dtype).itemsize
  return element_count * bytes_per_element

# Cost of a square 1024 x 1024 x 1024 float32 matmul.
square_dim = 1024
print(matmul_flops(square_dim, square_dim, square_dim))
print(matmul_membw(square_dim, square_dim, square_dim, jnp.float32))

2147483648
12582912


In [6]:
# NOTE(review): presumably TPU v5e peak compute (FLOP/s) and memory bandwidth
# (bytes/s), going by the names — confirm the figures against the hardware spec.
v5e_flops = 197 * 1e12
v5e_membw = 819 * 1e9
# Ratio of the two figures above, in FLOPs per byte (~240.5).
v5e_op_intensity = v5e_flops / v5e_membw

In [7]:
def matmul_flops_intensity(m: int, k: int, n: int, dtype: jnp.dtype):
  """Return ``matmul_flops(m, k, n) / matmul_membw(m, k, n, dtype)`` (FLOPs per byte)."""
  return matmul_flops(m, k, n) / matmul_membw(m, k, n, dtype)

In [8]:
# FLOPs per byte of a 1024^3 matmul in float32.
intensity_f32 = matmul_flops_intensity(1024, 1024, 1024, jnp.float32)
print(f"{intensity_f32} flops/byte")

170.66666666666666 flops/byte


In [9]:
# FLOPs per byte of a 1024^3 matmul in bfloat16.
intensity_bf16 = matmul_flops_intensity(1024, 1024, 1024, jnp.bfloat16)
print(f"{intensity_bf16} flops/byte")

341.3333333333333 flops/byte


In [28]:
def matmul_kernel(x_ref,y_ref,z_ref,acc_ref,*,nsteps):
    """Accumulate one ``x_block @ y_block`` product per grid step along axis 2.

    Args:
        x_ref: Left-operand block.
        y_ref: Right-operand block.
        z_ref: Output block; written only on the last step of grid axis 2.
        acc_ref: Accumulator; zeroed on the first step of grid axis 2.
        nsteps: Number of steps along grid axis 2.
    """
    k_step = pl.program_id(2)

    @pl.when(k_step==0)
    def _zero_accumulator():
        acc_ref[...] = jnp.zeros_like(acc_ref)

    # Products are requested in float32 regardless of the input dtype.
    acc_ref[...]+=jnp.dot(x_ref[...],y_ref[...],preferred_element_type=jnp.float32)

    @pl.when(k_step==nsteps-1)
    def _write_output():
        # Cast the accumulated result to the input dtype before storing it.
        z_ref[...] = acc_ref[...].astype(x_ref.dtype)

In [29]:
def matmul(x:jax.Array,y:jax.Array,*,bm=128,bk=128,bn=128)->jax.Array:
    """Blocked matrix multiply ``x @ y`` built on ``matmul_kernel``.

    The grid is ``(m // bm, n // bn, k // bk)``; the last axis walks the
    contraction dimension and is the one the kernel accumulates over.

    Args:
        x: Left operand of shape ``(m, k)``.
        y: Right operand of shape ``(k, n)``.
        bm: Block size along ``m``.
        bk: Block size along the contraction dimension ``k``.
        bn: Block size along ``n``.

    Returns:
        The product, shape ``(m, n)``, with the dtype of ``x``.

    Raises:
        AssertionError: If the inner dimensions differ or a dimension is not a
            multiple of its block size.
    """
    m,k = x.shape
    k_y,n = y.shape
    assert k==k_y, f"inner dimensions differ: x is {x.shape}, y is {y.shape}"
    assert m%bm==0 and k%bk==0 and n%bn==0
    m_blocks = m//bm
    k_blocks = k//bk
    n_blocks = n//bn
    grid_spec = pltpu.PrefetchScalarGridSpec(
        num_scalar_prefetch=0,
        in_specs=[
            pl.BlockSpec((bm,bk),lambda i,j,kk:(i,kk)),
            pl.BlockSpec((bk,bn),lambda i,j,kk:(kk,j)),
        ],
        out_specs=pl.BlockSpec((bm,bn),lambda i,j,kk:(i,j)),
        # The accumulator holds one output tile, i.e. (bm, bn) — not (bm, bk),
        # which only happens to match when bk == bn.
        scratch_shapes=[pltpu.MemorySpace.VMEM(shape=(bm,bn),dtype=jnp.float32)],
        grid=(m_blocks,n_blocks,k_blocks),
    )
    return pl.pallas_call(functools.partial(matmul_kernel,nsteps=k_blocks),
        grid_spec=grid_spec,
        out_shape=jax.ShapeDtypeStruct((m,n),x.dtype),
        compiler_params=pltpu.CompilerParams(
            # The last axis carries the accumulation, so it must run in order.
            dimension_semantics=("parallel","parallel","arbitrary")
        ),
    )(x,y)

In [31]:
# Check the Pallas matmul against the built-in one on bfloat16 operands.
m,k,n = 4096,4096,4096
# Use independent keys: drawing both operands from the same key with the same
# shape would make y identical to x and only exercise x @ x.
x_key,y_key = jax.random.split(jax.random.key(0))
x = jax.random.normal(x_key, (m,k), jnp.bfloat16)
y = jax.random.normal(y_key, (k,n), jnp.bfloat16)
z = matmul(x,y)
np.testing.assert_array_equal(z,x@y)

In [39]:
@jax.jit
def f(x:jax.Array,y:jax.Array)->jax.Array:
    """Return ``g(x) + y`` as one jitted function."""
    incremented = g(x)
    return incremented+y


@functools.partial(jax.jit,donate_argnums=(0,))
def g(x:jax.Array)->jax.Array:
    """Return ``x + 1``; jitted with positional argument 0 marked as donated."""
    incremented = x+1
    return incremented

# Call f on two (1024, 1024) float32 normal samples (both drawn from key 0).
operand_shape = (1024, 1024)
x = jax.random.normal(jax.random.key(0), operand_shape, jnp.float32)
y = jax.random.normal(jax.random.key(0), operand_shape, jnp.float32)
z = f(x, y)

In [None]:
class myClass:
    """Callable that builds a one-element array, passes it to ``g`` and adds ``y``.

    The constructor argument is stored on the instance but ``__call__`` never
    reads it; the ``x`` used there is a new local array.
    """
    def __init__(self,x:int):
        # Kept on the instance only; not used by __call__.
        self.x = x

    def __call__(self,y:int)->int:
        x = jnp.array([1])
        # g is jitted with donate_argnums=[0], so x is handed over as a donated argument.
        out = g(x)+y
        # NOTE(review): x is read here after being passed to g; the recorded
        # "Array has been deleted" error presumably originates from this line.
        print(x)
        return out


# Build an instance and call it once. This rebinds the notebook-level name
# ``x`` (previously an array) to a myClass instance.
x = myClass(1)
y = x(2)
print(y)

RuntimeError: Array has been deleted with shape=int32[1].

: 

In [58]:
# Call the instance again with a different addend.
# NOTE(review): the recorded output of this cell carries a different execution
# count than the class definition — confirm it was produced by the current code.
y = x(3)
print(y)

1
5


In [3]:
import numpy as np
import jax.numpy as jnp



In [4]:
import numpy as np

def precompute_freqs_cis_numpy(
    dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False, dtype=np.float32
):
    """
    NumPy equivalent of the precompute_freqs_cis torch function.

    Args:
        dim: Feature dimension; ``dim // 2`` frequencies are produced.
        end: Number of positions (``0 .. end - 1``).
        theta: Base of the geometric frequency schedule.
        use_scaled: Must be False; frequency scaling is not implemented here.
        dtype: Floating dtype used for the frequencies and positions.

    Returns:
        Array of shape ``(end, dim // 2, 2)`` whose last axis holds
        ``(cos(angle), sin(angle))`` with ``angle = position * frequency``.

    Raises:
        NotImplementedError: If ``use_scaled`` is True.
    """
    # Note: If use_scaled is True, you will need to provide a NumPy
    # implementation of the `apply_scaling` function.
    if use_scaled:
        raise NotImplementedError(
            "The `apply_scaling` function is not implemented for NumPy in this snippet."
        )

    # Keep exactly dim // 2 exponents so an odd `dim` does not produce an
    # extra, unpaired frequency column.
    exponents = np.arange(0, dim, 2)[: dim // 2].astype(dtype) / dim
    inv_freqs = 1.0 / (theta ** exponents)

    # One row per position, one column per frequency.
    positions = np.arange(end, dtype=dtype)
    angles = np.outer(positions, inv_freqs)

    # e^(i*angle) = cos(angle) + i*sin(angle): points on the unit circle.
    freqs_cis = np.exp(1j * angles)

    # Split into (real, imag) along a new trailing axis.
    return np.stack([np.real(freqs_cis), np.imag(freqs_cis)], axis=-1)


In [None]:
from utils.ops import precompute_freqs_cis as precompute_freqs_cis_jax
from experiments.torch_llama import precompute_freqs_cis as precompute_freqs_cis_torch
import torch
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_default_matmul_precision", "highest")

# Same three positional arguments for every implementation under comparison.
freq_dim, freq_end, freq_theta = 128, 1024, 500000.0

freqs_cis_jax = precompute_freqs_cis_jax(freq_dim, freq_end, freq_theta, dtype=jnp.bfloat16)
freqs_cis_numpy = precompute_freqs_cis_numpy(freq_dim, freq_end, freq_theta)
freqs_cis_torch = precompute_freqs_cis_torch(freq_dim, freq_end, freq_theta).to(dtype=torch.bfloat16)

# NOTE(review): the recorded run of this cell fails this check (max absolute
# difference about 2.1e-3). Confirm whether a 1e-4 tolerance is intended given
# that the JAX result is requested with dtype=jnp.bfloat16.
np.testing.assert_allclose(freqs_cis_jax, freqs_cis_numpy, rtol=1e-4, atol=1e-4)

AssertionError: 
Not equal to tolerance rtol=0.0001, atol=0.0001

Mismatched elements: 7639 / 131072 (5.83%)
Max absolute difference among violations: 0.00213042
Max relative difference among violations: 1.4310915
 ACTUAL: array([[[ 1.000000e+00,  0.000000e+00],
        [ 1.000000e+00,  0.000000e+00],
        [ 1.000000e+00,  0.000000e+00],...
 DESIRED: array([[[ 1.000000e+00,  0.000000e+00],
        [ 1.000000e+00,  0.000000e+00],
        [ 1.000000e+00,  0.000000e+00],...

In [11]:
# freqs_cis_torch is created as a bfloat16 tensor, which NumPy cannot ingest
# directly, so widen it to float32 before comparing.
freqs_cis_torch_f32 = freqs_cis_torch.detach().cpu().to(dtype=torch.float32).numpy()
# The torch values were rounded to bfloat16 (8 significand bits), so compare at
# that resolution rather than at float32 precision.
np.testing.assert_allclose(freqs_cis_numpy, freqs_cis_torch_f32, rtol=1e-2, atol=1e-2)