# Lecture 09: PyTorch and JAX
A primer on compilation, when it's useful and when it's not. 

In [1]:
import os
os.environ["TORCH_LOGS"] = "output_code, guards, recompiles"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
import jax.numpy as jnp

import torch
import torch.nn as nn
import torch.nn.functional as F

## Where does compilation really matter? (JAX vs PyTorch)

In [2]:
@jax.jit
def relu(X):
    return jnp.maximum(0, X)
    
@jax.jit
def two_matmuls(X, A, B):
    Y = X @ A
    Y = relu(Y)
    Y = Y @ B

    return Y
    
rng_key = jax.random.PRNGKey(seed=20)
X_jax = jax.random.normal(rng_key, (256, 1024))
A_jax = jax.random.normal(rng_key, (1024, 2048))
B_jax = jax.random.normal(rng_key, (2048, 1024))

### JAX Timing

In [17]:
%timeit two_matmuls(X_jax, A_jax, B_jax)

76 μs ± 45.7 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


This is suspiciously fast. Why?

JAX uses an asynchronous model to avoid Python overheads - this lets Python's control flow go almost uninterrupted [Source](https://docs.jax.dev/en/latest/async_dispatch.html)

In [18]:
%timeit two_matmuls(X_jax, A_jax, B_jax).block_until_ready()

205 μs ± 2.85 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
def two_matmuls_pytorch(X, A, B):
    Y = X @ A
    Y = F.relu(Y)
    Y = Y @ B

    return Y

X = torch.randn(256, 1024).to('cuda')
A = torch.randn(1024, 2048).to('cuda')
B = torch.randn(2048, 1024).to('cuda')

compiled_pt_fn = torch.compile(two_matmuls_pytorch)

# warmup
_ = compiled_pt_fn(X, A, B)

def pt_fn(compiled=False):
    if compiled:
        Y = compiled_pt_fn(X, A, B)
    else:
        Y = two_matmuls_pytorch(X, A, B)
    torch.cuda.synchronize()
    return Y


V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] Output code: 
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] # AOT ID: ['0_inference']
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] import torch
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] import math
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] import random
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] import os
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] import tempfile
V1023 23:09:25.842000 163 site-packages/torch/_inductor/graph.py:2345] [0/0] [__output_code] from math impor

### PyTorch Timing

In [6]:
%timeit -n100 -r5 pt_fn()

183 μs ± 2.05 μs per loop (mean ± std. dev. of 5 runs, 100 loops each)


In [7]:
pt_fn(compiled=True)

tensor([[ 1553.5459,  -450.3364,  1118.1101,  ...,  -346.2913,   189.7560,
          -763.6236],
        [  725.9422,  -810.8474,  -106.8808,  ...,  -561.0588,  -770.7185,
          -948.0106],
        [  777.3891, -1309.3801, -1330.3395,  ...,   272.8092,   662.3004,
          1540.0090],
        ...,
        [  141.1997,   165.5197,  1213.4856,  ...,  -385.2852,   670.5660,
         -1739.9691],
        [ -974.1588,  -289.8643,  1005.3350,  ...,   196.1749,   569.9285,
         -1241.6023],
        [ -274.6496, -1811.7827,  1221.9260,  ...,  1461.4171,  -200.6417,
         -1501.7313]], device='cuda:0')

In [8]:
%timeit -n100 -r5 pt_fn(compiled=True)

214 μs ± 4.64 μs per loop (mean ± std. dev. of 5 runs, 100 loops each)


## Speedups with a lot of reads and writes to HBM

In [9]:
@jax.jit
def chain_jit(x):
    return (jnp.tanh(0.1*x + 1.7) * jax.nn.sigmoid(x) + jnp.exp(-x*x)).sum()

x = jnp.ones((10_000_000,), dtype=jnp.float32)

chain_jit(x).block_until_ready()

%timeit chain_jit(x).block_until_ready()

124 μs ± 1.2 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [10]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.ones(10_000_000, device=device, dtype=torch.float32)

def chain_pt(x):
    return (torch.tanh(0.1*x + 1.7) * torch.sigmoid(x) + torch.exp(-(x*x))).sum()

compiled = torch.compile(chain_pt)  # requires PyTorch 2.x

# Warm up
_ = compiled(x)
if device == 'cuda': torch.cuda.synchronize()

# Fair timing (sync on GPU)
if device == 'cuda':
    %timeit (chain_pt(x)); torch.cuda.synchronize()
    %timeit (compiled(x)); torch.cuda.synchronize()
else:
    %timeit chain_pt(x)
    %timeit compiled(x)

V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] Output code: 
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] # AOT ID: ['1_inference']
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] import torch
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] import math
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] import random
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] import os
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] import tempfile
V1023 23:10:52.463000 163 site-packages/torch/_inductor/graph.py:2345] [1/0] [__output_code] from math impor

549 μs ± 135 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
81.9 μs ± 127 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [17]:
!cat /tmp/torchinductor_root/tq/ctqd7mrkcjr7yx6darqxxwqev5klc3pf62xqcsb3rb2csqyfaqaa.py


import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.reduction(
    size_hints={'x': 512, 'r0_': 32768},
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_red

## Speedups in MLP with Post Norm

In [12]:
B, S, D, H = 256, 1024, 512, 2048  # batch, model dim, hidden dim

@jax.jit
def mlp_block_jit(x, w1, b1, w2, b2, gamma, beta, eps=1e-5):
    h = x @ w1 + b1
    h = jax.nn.gelu(h, approximate=True)
    h = h @ w2 + b2
    y = x + h
    m = y.mean(-1, keepdims=True)
    v = ((y - m) ** 2).mean(-1, keepdims=True)
    y = (y - m) / jnp.sqrt(v + eps) * gamma + beta
    return y

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (B, S, D), dtype=jnp.float32)
w1 = jax.random.normal(key, (D, H), dtype=jnp.float32); b1 = jnp.zeros((H,), jnp.float32)
w2 = jax.random.normal(key, (H, D), dtype=jnp.float32); b2 = jnp.zeros((D,), jnp.float32)
gamma = jnp.ones((D,), jnp.float32); beta = jnp.zeros((D,), jnp.float32)

# warmup compile
_ = mlp_block_jit(x, w1, b1, w2, b2, gamma, beta).block_until_ready()


In [13]:
%timeit mlp_block_jit(x, w1, b1, w2, b2, gamma, beta).block_until_ready()

14.8 ms ± 16 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
os.environ['XLA_FLAGS'] = "--xla_dump_to=./xla_dump --xla_dump_hlo_as_text"
mlp_block_jit(x, w1, b1, w2, b2, gamma, beta).block_until_ready()

Array([[[-1.4833924 ,  0.3511368 ,  0.21911843, ...,  1.4281534 ,
          0.06670557,  1.1523116 ],
        [-0.05532628,  0.30686823, -2.0907853 , ..., -0.4386818 ,
         -0.09272943, -0.37294406],
        [-0.21030568, -0.65031123, -0.8474163 , ...,  0.9286902 ,
          0.6720757 ,  1.1566288 ],
        ...,
        [ 0.7291884 , -0.17409474, -1.855061  , ...,  0.95941633,
          1.1517587 ,  1.0067558 ],
        [-0.40763727, -1.0683631 , -0.8175719 , ...,  0.6376863 ,
          1.420852  , -0.0816533 ],
        [-0.89247984, -0.32276464, -0.95661914, ...,  0.40030655,
         -0.5419413 ,  1.0331538 ]],

       [[ 0.08665023,  1.0070175 ,  0.34418038, ...,  0.9715808 ,
         -1.7578357 ,  1.0011321 ],
        [-0.3371217 , -0.7290098 , -0.8793314 , ...,  1.4327533 ,
         -0.84283054,  0.594506  ],
        [-1.1071863 ,  0.83588153, -1.2886665 , ...,  1.7351596 ,
         -0.67959344,  2.0760362 ],
        ...,
        [-1.4165086 , -0.01250071, -1.3415816 , ...,  

In [18]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [19]:
torch.set_float32_matmul_precision('high')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

B, S, D, H = 256, 1024, 512, 2048
x  = torch.randn(B, S, D, device=device)
w1 = torch.randn(D, H, device=device); b1 = torch.zeros(H, device=device)
w2 = torch.randn(H, D, device=device); b2 = torch.zeros(D, device=device)
gamma = torch.ones(D, device=device);  beta = torch.zeros(D, device=device)

def mlp_block_pt(x, w1, b1, w2, b2, gamma, beta, eps=1e-5):
    h = x @ w1 + b1
    h = F.gelu(h, approximate='tanh')
    h = h @ w2 + b2
    y = x + h
    m = y.mean(dim=-1, keepdim=True)
    v = (y - m).pow(2).mean(dim=-1, keepdim=True)
    y = (y - m) / torch.sqrt(v + eps) * gamma + beta
    return y

compiled = torch.compile(mlp_block_pt)  # PyTorch 2.x

# warmup
_ = compiled(x, w1, b1, w2, b2, gamma, beta)
if device == 'cuda': torch.cuda.synchronize()

V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] Output code: 
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] # AOT ID: ['2_inference']
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] import torch
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] import math
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] import random
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] import os
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] import tempfile
V1023 23:14:45.837000 163 site-packages/torch/_inductor/graph.py:2345] [2/0] [__output_code] from math impor

In [22]:
!cat /tmp/torchinductor_root/wj/cwjdxk6jkioa2ctaiariho5gosqwybyvhihnmhnoc43we3psgh6a.py


import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'x': 536870912}, 
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*fp32', 'in_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=108, cc=80, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_gelu_0', 'mutated_arg_names': ['in_out_ptr0']

In [23]:
!cat /tmp/torchinductor_root/ge/cgeocwfjblpsir7kadiun3xv7zck5fj3mzh334sfrhk4kd5ohd7a.py

# AOT ID: ['2_inference']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment

In [20]:
%timeit (mlp_block_pt(x, w1, b1, w2, b2, gamma, beta)); torch.cuda.synchronize()
%timeit (compiled(x, w1, b1, w2, b2, gamma, beta)); torch.cuda.synchronize()

22.5 ms ± 62.8 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
14.4 ms ± 21.2 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## JAX Internals

In [36]:
print(jax.make_jaxpr(two_matmuls)(X_jax, A_jax, B_jax))

{ [34;1mlambda [39;22m; a[35m:f32[256,1024][39m b[35m:f32[1024,2048][39m c[35m:f32[2048,1024][39m. [34;1mlet
    [39;22md[35m:f32[256,1024][39m = jit[
      name=two_matmuls
      jaxpr={ [34;1mlambda [39;22m; a[35m:f32[256,1024][39m b[35m:f32[1024,2048][39m c[35m:f32[2048,1024][39m. [34;1mlet
          [39;22me[35m:f32[256,2048][39m = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] a b
          f[35m:f32[256,2048][39m = jit[
            name=relu
            jaxpr={ [34;1mlambda [39;22m; e[35m:f32[256,2048][39m. [34;1mlet
                [39;22mf[35m:f32[256,2048][39m = max 0.0:f32[] e
              [34;1min [39;22m(f,) }
          ] e
          d[35m:f32[256,1024][39m = dot_general[
            dimension_numbers=(([1], [0]), ([], []))
            preferred_element_type=float32
          ] f c
        [34;1min [39;22m(d,) }
    ] a b c
  [34;1min [39;22m(d,) }


In [37]:
traced_fn = two_matmuls.trace(X_jax, A_jax, B_jax)
print(traced_fn.jaxpr)
hlo = traced_fn.lower()
print("HLO:\n", hlo.as_text())
hlo_optimized = hlo.compile()
print(hlo_optimized.as_text())

{ [34;1mlambda [39;22m; a[35m:f32[256,1024][39m b[35m:f32[1024,2048][39m c[35m:f32[2048,1024][39m. [34;1mlet
    [39;22md[35m:f32[256,2048][39m = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a b
    e[35m:f32[256,2048][39m = jit[
      name=relu
      jaxpr={ [34;1mlambda [39;22m; d[35m:f32[256,2048][39m. [34;1mlet
          [39;22me[35m:f32[256,2048][39m = max 0.0:f32[] d
        [34;1min [39;22m(e,) }
    ] d
    f[35m:f32[256,1024][39m = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] e c
  [34;1min [39;22m(f,) }
HLO:
 module @jit_two_matmuls attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<256x1024xf32>, %arg1: tensor<1024x2048xf32>, %arg2: tensor<2048x1024xf32>) -> (tensor<256x1024xf32> {jax.result_info = "result"}) {
    %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims 