# Thunder bindings for Liger operators



In [1]:
from collections.abc import Sequence
import math

import torch
from torch.testing import assert_close
import litgpt
import thunder
from thunder.core.proxies import TensorProxy, AnyProxy
from thunder.core.transforms import get_grad, put_grads
from thunder.torch import TensorLike
from thunder.executors.utils import Context, set_saved_tensors
from thunder.core.compile_data import get_compile_option
import thunder.extend

import liger_kernel.ops.rms_norm
import liger_kernel.ops.rope
import liger_kernel.ops.swiglu
import liger_kernel.ops.geglu
import liger_kernel.ops.cross_entropy

# Device used for the example inputs and models created further down.
device = torch.device("cuda")

# Create an operator executor named "liger" and register it with Thunder;
# every operator defined below is registered on this executor.
liger_ex = thunder.extend.OperatorExecutor("liger", version="0.1")
thunder.extend.register_executor(liger_ex)

thunder.extend.OperatorExecutor('liger')

## RMS Norm

The first thing to fuse is RMS Norm.

Once LitGPT is pointed at PyTorch's `rms_norm` function (see the next cell), Liger's implementation is a drop-in replacement. We define operators for the forward and backward passes, and then a gradient rule and an execution rule.

We register these as an implementation of the `rms_norm` operator, to which we divert calls to the PyTorch function.

In [2]:
# A tiny detail here is that PyTorch gained a `rms_norm` function somewhat
# recently and we need to tell LitGPT to use it.

def RMSNorm_forward(self, x):
    return torch.nn.functional.rms_norm(x, self.weight.shape, self.weight, self.eps)
# Monkey-patch LitGPT's RMSNorm so that its forward goes through RMSNorm_forward.
litgpt.model.RMSNorm.forward = RMSNorm_forward

In [3]:
import functools


def prod(*args):
    """Return the product of all positional arguments.

    With no arguments the empty product ``1`` is returned instead of raising,
    so callers that splat a possibly-empty list of dimensions keep working.
    For one or more arguments the result is the left-to-right product.
    """
    if not args:
        return 1
    return functools.reduce(lambda x, y: x * y, args)


In [4]:
# ******************************* RMS NORM *******************************   
import functools
def liger_rms_norm_forward_meta(X, W, eps, offset, casting_mode):
    """Meta function describing the outputs of ``liger_rms_norm_forward``.

    Args:
        X: input tensor proxy; its last dimension is the normalized one and
            all leading dimensions are collapsed into a single row count.
        W: weight tensor proxy (not inspected here).
        eps: epsilon (not inspected here).
        offset: offset (not inspected here).
        casting_mode: casting mode, either one of the numeric codes exposed by
            ``liger_kernel.ops.rms_norm`` or the string ``"llama"``/``"gemma"``.

    Returns:
        A tuple ``(Y, X_2d, RSTD, BLOCK_SIZE, num_warps, casting_mode)`` where
        ``Y`` is like ``X``, ``X_2d`` has shape ``(n_rows, n_cols)``, ``RSTD``
        has shape ``(n_rows,)``, and ``BLOCK_SIZE``/``num_warps`` come from
        Liger's ``calculate_settings``. ``casting_mode`` is passed through.
    """
    *leading_dims, n_cols = X.shape
    # A 1-D input has no leading dimensions; treat it as a single row.
    n_rows = prod(*leading_dims) if leading_dims else 1

    rms_ops = liger_kernel.ops.rms_norm
    fp32_modes = {
        "llama": rms_ops._CASTING_MODE_LLAMA.value,
        "gemma": rms_ops._CASTING_MODE_GEMMA.value,
    }
    # The transforms in this file pass the mode as a string ("llama"), which
    # never equals the numeric codes; translate it before comparing so the
    # fp32 branch below is actually reachable.
    # NOTE(review): presumably mirrors how Liger's rms_norm_forward maps string
    # modes to codes -- confirm against the installed liger_kernel version.
    if isinstance(casting_mode, str):
        mode_code = fp32_modes.get(casting_mode, casting_mode)
    else:
        mode_code = casting_mode

    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
    if mode_code in fp32_modes.values():
        rstd_dtype = thunder.dtypes.float32
    else:
        rstd_dtype = X.dtype

    Y = TensorProxy(like=X)
    RSTD = TensorProxy(like=X, shape=(n_rows,), dtype=rstd_dtype)
    BLOCK_SIZE, num_warps = rms_ops.calculate_settings(n_cols)
    return Y, TensorProxy(like=X, shape=(n_rows, n_cols)), RSTD, BLOCK_SIZE, num_warps, casting_mode

# Register Liger's RMS norm forward function as an operator of the Liger
# executor, with liger_rms_norm_forward_meta as its meta function.
liger_rms_norm_forward = liger_ex.register_operator(
    "liger_rms_norm_forward",
    meta=liger_rms_norm_forward_meta,
    fn=liger_kernel.ops.rms_norm.rms_norm_forward
)

def liger_rms_norm_backward_meta(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
    """Meta function for ``liger_rms_norm_backward``.

    Returns a pair of proxies: one like ``X`` and one like ``W``. The
    remaining arguments are accepted but not inspected.
    """
    grad_input = TensorProxy(like=X)
    grad_weight = TensorProxy(like=W)
    return grad_input, grad_weight

# Register Liger's RMS norm backward function as an operator of the Liger
# executor, with liger_rms_norm_backward_meta as its meta function.
liger_rms_norm_backward = liger_ex.register_operator(
    "liger_rms_norm_backward",
    meta=liger_rms_norm_backward_meta,
    fn=liger_kernel.ops.rms_norm.rms_norm_backward
)

def rms_norm_meta(x, shape, w, eps):
    """Meta function for the ``rms_norm`` operator: the output is like ``x``."""
    normalized = thunder.TensorProxy(like=x)
    return normalized

# Operator standing in for torch.nn.functional.rms_norm: `replaces=` diverts
# calls of the PyTorch function to this operator, and `fn` keeps the PyTorch
# function as its callable.
rms_norm = liger_ex.register_operator('rms_norm', meta=rms_norm_meta, fn=torch.nn.functional.rms_norm, replaces=torch.nn.functional.rms_norm)

def rms_norm_grad_transform(x, shape, weight, eps):
    """Gradient rule for ``rms_norm`` built from the Liger forward/backward operators.

    Runs the Liger forward operator, fetches the gradient of its output,
    runs the Liger backward operator, reshapes the input gradient to
    ``x.shape`` and records the gradients for ``x`` and ``weight``.
    ``shape`` is accepted but not used. Returns the forward output.
    """
    forward_results = liger_rms_norm_forward(x, weight, eps, offset=0.0, casting_mode="llama")
    out, flat_input, rstd, block_size, warps, _mode = forward_results

    grad_out = get_grad(out)
    grad_input, grad_weight = liger_rms_norm_backward(
        grad_out,
        flat_input,
        weight,
        rstd,
        offset=0.0,
        casting_mode="llama",
        BLOCK_SIZE=block_size,
        num_warps=warps,
    )
    # Give the input gradient the same shape as the original input.
    grad_input = grad_input.view(*x.shape)
    put_grads((x, weight), (grad_input, grad_weight))
    return out

def rms_norm_execution_transform(x, shape, weight, eps):
    """Execution rule for ``rms_norm``: run the Liger forward operator.

    The parameters match the ``rms_norm`` operator's signature
    ``(x, shape, w, eps)`` as declared by ``rms_norm_meta`` and used by
    ``rms_norm_grad_transform``. ``shape`` is accepted for that reason and is
    not forwarded to the Liger operator.

    Returns:
        The first output of ``liger_rms_norm_forward``; the other outputs are
        discarded.
    """
    Y, *_ = liger_rms_norm_forward(x, weight, eps, offset=0.0, casting_mode="llama")
    return Y

# Attach the execution and gradient rules defined above to the rms_norm operator.
liger_ex.register_implementation(
    rms_norm,
    execution_transform=rms_norm_execution_transform,
    grad_transform=rms_norm_grad_transform
)

### Testing RMS Norm

Let's test.

In [5]:
hidden_size = 64

# The input requires grad so that the backward pass can be compared as well.
example_input = torch.randn(32, 10, hidden_size, device=device, requires_grad=True)

# Create the module under the `device` context.
with device:
    model = litgpt.model.RMSNorm(hidden_size)
# Compile the module with the Liger executor.
thunder_model = thunder.jit(model, executors=[liger_ex])
# Reference result from the uncompiled module vs. result from the compiled one.
ref = model(example_input.clone())
res = thunder_model(example_input.clone())
# Random output gradient used for both backward passes.
go = torch.randn_like(ref)
grad_ref, grad_ref_weight = torch.autograd.grad(ref, (example_input, model.weight), go)
grad_res, grad_res_weight = torch.autograd.grad(res, (example_input, model.weight), go)


# Check that the Liger operators appear in the last forward and backward traces.
assert liger_rms_norm_forward in {bsym.sym for bsym in thunder.last_traces(thunder_model)[-1].bound_symbols}
assert liger_rms_norm_backward in {bsym.sym for bsym in thunder.last_backward_traces(thunder_model)[-1].bound_symbols}

# Outputs and both gradients must match the reference.
assert_close(ref, res)
assert_close(grad_ref, grad_res)
assert_close(grad_ref_weight, grad_res_weight)

## RoPE

Next is the RoPE implementation. Liger applies RoPE to both the query and the key in a single kernel, whereas
LitGPT uses two separate calls. So we define not only forward and backward operators and a symbol to capture the LitGPT version,
but also a small transform that fuses the two `apply_rope` calls into one `liger_rope` call.

In [6]:
def liger_rope_forward_meta(q, k, cos, sin):
    """Meta function for ``liger_rope_forward``.

    Returns new proxies like ``q`` and ``k``, followed by ``cos`` and ``sin``
    passed through unchanged.
    """
    q_rotated = TensorProxy(like=q)
    k_rotated = TensorProxy(like=k)
    return q_rotated, k_rotated, cos, sin

# Register Liger's RoPE forward function as an operator of the Liger executor,
# with liger_rope_forward_meta as its meta function.
liger_rope_forward = liger_ex.register_operator(
    "liger_rope_forward",
    meta=liger_rope_forward_meta,
    fn=liger_kernel.ops.rope.rope_forward,
)

def liger_rope_backward_meta(dq, dk, cos, sin):
    """Meta function for ``liger_rope_backward``.

    Returns a pair of proxies, one like ``dq`` and one like ``dk``. ``cos``
    and ``sin`` are accepted but not inspected.
    """
    # Build the outputs with TensorProxy, as every other meta function in this
    # file does, rather than with the TensorLike annotation helper.
    return TensorProxy(like=dq), TensorProxy(like=dk)

# Register Liger's RoPE backward function as an operator of the Liger executor,
# with liger_rope_backward_meta as its meta function.
liger_rope_backward = liger_ex.register_operator(
    "liger_rope_backward",
    meta=liger_rope_backward_meta,
    fn=liger_kernel.ops.rope.rope_backward,
)

def liger_rope_grad_transform(q, k, cos, sin):
    """Gradient rule for ``liger_rope``.

    Runs the Liger RoPE forward operator, fetches the gradients of both
    outputs, runs the Liger RoPE backward operator on them and records the
    resulting gradients for ``q`` and ``k``. Returns the two forward outputs.
    """
    rotated_q, rotated_k, _cos, _sin = liger_rope_forward(q, k, cos, sin)

    grad_rotated_q = get_grad(rotated_q)
    grad_rotated_k = get_grad(rotated_k)

    grad_q, grad_k = liger_rope_backward(grad_rotated_q, grad_rotated_k, cos, sin)
    put_grads((q, k), (grad_q, grad_k))
    return rotated_q, rotated_k

def liger_rope_execution_transform(q, k, cos, sin):
    """Execution rule for ``liger_rope``.

    Calls ``liger_rope_forward`` and returns only its first two outputs.
    """
    rotated_q, rotated_k, _cos, _sin = liger_rope_forward(q, k, cos, sin)
    return rotated_q, rotated_k

def liger_rope_impl(q, k, cos, sin):
    """Callable behind the ``liger_rope`` operator.

    Calls ``liger_rope_forward`` and returns only its first two outputs.
    """
    q_result, k_result, _cos, _sin = liger_rope_forward(q, k, cos, sin)
    return q_result, k_result

# Fused operator taking q and k together; liger_rope_impl serves both as the
# callable (`fn`) and as the `like` template.
# NOTE(review): `like=` presumably lets Thunder derive the meta from
# liger_rope_impl -- confirm against the thunder.extend documentation.
liger_rope = liger_ex.register_operator('liger_rope', fn=liger_rope_impl, like=liger_rope_impl)

# Attach the execution and gradient rules defined above to the liger_rope operator.
liger_ex.register_implementation(
    liger_rope,
    execution_transform=liger_rope_execution_transform,
    grad_transform=liger_rope_grad_transform,
)

def litgpt_apply_rope_meta(x, cos, sin):
    """Meta function for ``litgpt_apply_rope``: the output is like ``x``."""
    rotated = TensorProxy(like=x)
    return rotated

# Symbol that captures LitGPT's apply_rope: `replaces=` diverts calls of
# litgpt.model.apply_rope to this operator, and `fn` keeps the LitGPT function
# as its callable.
litgpt_apply_rope = liger_ex.register_operator(
    'litgpt_apply_rope', fn=litgpt.model.apply_rope, meta=litgpt_apply_rope_meta, replaces=litgpt.model.apply_rope
)

class MergeRopeTransform(thunder.core.transform_common.Transform):
    """Trace transform fusing pairs of ``litgpt_apply_rope`` calls into ``liger_rope``.

    The compute trace is scanned in order. Each ``litgpt_apply_rope`` bound
    symbol is paired with the next ``litgpt_apply_rope`` that follows it, and
    the pair is replaced, at the position of the first call, by one
    ``liger_rope`` bound symbol. All other bound symbols are copied unchanged.
    """

    def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):
        """Return the traces with ``apply_rope`` pairs in the compute trace merged.

        Args:
            prologue_trace: returned unchanged.
            compute_trace: trace whose bound symbols are scanned; it is not
                modified, a new trace is built from it.
            epilogue_trace: returned unchanged.
            **kwargs: ignored.

        Returns:
            ``(prologue_trace, new_compute_trace, epilogue_trace)``.

        Raises:
            RuntimeError: if a ``litgpt_apply_rope`` call is not followed by a
                second one to merge with.
        """
        new_compute_trace = thunder.core.trace.from_trace(compute_trace)
        # Work on a copy so symbols can be removed while scanning.
        pending = compute_trace.bound_symbols[:]
        while pending:
            bsym = pending.pop(0)
            if bsym.sym != litgpt_apply_rope:
                new_compute_trace.bound_symbols.append(bsym.from_bsym())
                continue

            # Look for the next apply_rope call among the remaining symbols.
            partner_index = None
            for index, candidate in enumerate(pending):
                # No symbol up to and including the partner may produce the
                # first call's output.
                assert not any(o is bsym.output for o in candidate.flat_outs)
                if candidate.sym == litgpt_apply_rope:
                    partner_index = index
                    break
            if partner_index is None:
                # Previously this fell through with an unbound or stale index
                # and popped an unrelated symbol; fail explicitly instead.
                raise RuntimeError(
                    "MergeRopeTransform: found a litgpt_apply_rope call without a "
                    "second litgpt_apply_rope call to merge it with"
                )
            partner = pending.pop(partner_index)

            output = (bsym.output, partner.output)
            # First argument of each call, then the remaining arguments of the
            # first call only.
            # NOTE(review): assumes both calls were given the same cos/sin -- confirm.
            args = (bsym.args[0], partner.args[0], *bsym.args[1:])

            new_compute_trace.bound_symbols.append(bsym.from_bsym(args=args, output=output, sym=liger_rope))
        new_compute_trace.set_provenance(thunder.core.trace.TraceProvenance(self.__class__))
        return prologue_trace, new_compute_trace, epilogue_trace

### Testing RoPE

We test with a scaled-down Llama.

In [7]:
# Llama-3.1-8B configuration with the number of layers overridden to one.
cfg = litgpt.Config.from_name('Llama-3.1-8B', n_layer=1)

# Create the model and the inputs under the `device` context.
with device:
    m = litgpt.GPT(cfg)
    m.max_seq_length = 1024
    m.set_kv_cache(1)
    # Values 1..5 with a leading dimension added by [None], and positions 0..4.
    inp = torch.arange(1, 6, dtype=torch.int64)[None]
    inp_pos = torch.arange(5)

# Compile with the Liger executor and the transform that merges the two rope calls.
jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))
res = jm(inp, inp_pos)
ref = m(inp, inp_pos)

# Compare gradients with respect to the 'transformer.wte.weight' parameter,
# using the same random output gradient for both.
go = torch.randn_like(res)
grad_res, = torch.autograd.grad(res, jm.get_parameter('transformer.wte.weight'), go)
grad_ref, = torch.autograd.grad(ref, m.get_parameter('transformer.wte.weight'), go)

assert_close(res, ref)
assert_close(grad_res, grad_ref)