# Zero to Thunder

Here we take a very short tour of what is possible with Thunder.

To get started, we import it (and a bunch of things for this notebook).

In [1]:
import sys
sys.path.insert(0, '..')
import inspect


import torch, thunder


## Compiling a first module with Thunder

So let's get started! As a "Hello World", let us apply it to a small model, say, the MLP part found in Llama 2. We take it from LitGPT.

In [2]:
class LLaMAMLP(torch.nn.Module):
    def __init__(self, n_embd, intermediate_size) -> None:
        super().__init__()
        self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False)
        self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fc_1 = self.fc_1(x)
        x_fc_2 = self.fc_2(x)
        x = torch.nn.functional.silu(x_fc_1) * x_fc_2
        return self.proj(x)


with torch.device("cuda"):
    m = LLaMAMLP(4096, 11008)
for p in m.parameters():
    p.requires_grad_(False)

print(m)

LLaMAMLP(
  (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
  (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
  (proj): Linear(in_features=11008, out_features=4096, bias=False)
)


Now we can apply Thunder. This uses the most important function of Thunder, `thunder.jit`, which can be used to compile a `torch.nn.Module` or a function. It wraps our MLP in a `ThunderModule`.

In [3]:
thunder_model = thunder.jit(m)

In [4]:
thunder_model

ThunderModule(
  (_model): LLaMAMLP(
    (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
    (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
    (proj): Linear(in_features=11008, out_features=4096, bias=False)
  )
)

Our Thunder module computes (up to numerical accuracy) the same thing as our original model, and for a small model like this, it also has approximately the same performance.

In [5]:
x = torch.randn(2, 2048, 4096, device="cuda")
print('deviation:', (thunder_model(x) - m(x)).abs().max().item())

%timeit thunder_model(x); torch.cuda.synchronize()
%timeit m(x); torch.cuda.synchronize()

deviation: 1.4901161193847656e-07
58.2 ms ± 306 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
58.7 ms ± 50.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


So what has changed?
Quite a bit!

When we call the Thunder module, it does the computation in a single function without control flow. What's more, it applies optimizations, such as creating fusions for NVFuser to execute. We can see all this by showing the last computation trace:

In [6]:
thunder.last_traces(thunder_model)[-1]

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight):
  # x: "cuda:0 f32[2, 2048, 4096]" 
  # t_fc_1_weight: "cuda:0 f32[11008, 4096]" 
  # t_fc_2_weight: "cuda:0 f32[11008, 4096]" 
  # t_proj_weight: "cuda:0 f32[4096, 11008]" 
  x_fc_1 = torch.nn.functional.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
    # x_fc_1 = ltorch.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
      # x_fc_1 = prims.linear(x, t_fc_1_weight, None)  # x_fc_1: "cuda:0 f32[2, 2048, 11008]"
  del t_fc_1_weight
  x_fc_2 = torch.nn.functional.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
    # x_fc_2 = ltorch.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2, 2048, 11008]"
      # x_fc_2 = prims.linear(x, t_fc_2_weight, None)  # x_fc_2: "cuda:0 f32[2

In more detail, here is what is going on in this trace:
- Thunder has transformed the computation (more precisely, `m.__call__`) into a single function that takes all the MLP parameters as arguments.
- It has recorded the tensor metadata (device, dtype and shape), shown in the comments.
- Operations have been mapped from the PyTorch functions to their `thunder.torch` (aka `ltorch`) equivalents and decomposed into _primitive operations_.
- The multiplication and activation (`x = torch.nn.functional.silu(x_fc_1) * x_fc_2`) have been put into one NVFuser fusion. (NVFuser is a particularly important optimization, but it is only one of many, and we make it easy to add your own.)
- The prologue shows how the parameters are obtained and how the metadata is checked; you can get it through `thunder.last_prologue_traces(thunder_model)[-1]`.

You can actually see the whole series of traces: `last_traces` gives you a list of transformed traces in chronological order. For example, the initial trace `thunder.last_traces(thunder_model)[0]` does not have the fusion yet.
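To make the idea of a trace more concrete, here is a toy sketch in plain Python. It is not how Thunder is implemented. A hypothetical `Tracer` records each operation as one flat line while the Python function runs, so an `if` disappears from the result and only the branch actually taken is recorded. A hypothetical pass, `fuse_silu_mul`, then produces a new trace in which the activation and multiplication are merged into one op. Keeping both traces in a list loosely mirrors the chronological list that `last_traces` returns.

```python
from dataclasses import dataclass, field
from typing import List, Tuple

Line = Tuple[str, str, Tuple[str, ...]]  # (output name, op name, input names)


@dataclass
class Proxy:
    """Stand-in for a tensor: only a name is recorded, no data."""
    name: str


@dataclass
class Tracer:
    """Hypothetical recorder: every op call appends one line to the trace."""
    lines: List[Line] = field(default_factory=list)
    counter: int = 0

    def op(self, opname: str, *args: Proxy) -> Proxy:
        out = Proxy(f"t{self.counter}")
        self.counter += 1
        self.lines.append((out.name, opname, tuple(a.name for a in args)))
        return out


def mlp(tr: Tracer, x: Proxy, w1: Proxy, w2: Proxy, w3: Proxy, gated: bool = True) -> Proxy:
    a = tr.op("linear", x, w1)
    b = tr.op("linear", x, w2)
    if gated:  # Python control flow is resolved while tracing; only this branch is recorded
        h = tr.op("mul", tr.op("silu", a), b)
    else:
        h = tr.op("silu", a)
    return tr.op("linear", h, w3)


def fuse_silu_mul(lines: List[Line]) -> List[Line]:
    """Toy pass: merge a 'silu' directly consumed by the next 'mul' into one op.

    Assumes the silu result is not used anywhere else (true for this toy trace).
    """
    out, i = [], 0
    while i < len(lines):
        name, op, args = lines[i]
        if op == "silu" and i + 1 < len(lines):
            next_name, next_op, next_args = lines[i + 1]
            if next_op == "mul" and next_args[0] == name:
                out.append((next_name, "fused_silu_mul", args + next_args[1:]))
                i += 2
                continue
        out.append(lines[i])
        i += 1
    return out


tr = Tracer()
mlp(tr, Proxy("x"), Proxy("w1"), Proxy("w2"), Proxy("w3"))
traces = [tr.lines, fuse_silu_mul(tr.lines)]  # chronological, like last_traces
for name, op, args in traces[-1]:
    print(f"{name} = {op}({', '.join(args)})")
```

The final trace has four lines instead of five, because `silu` and `mul` became a single `fused_silu_mul` call. A real fusion pass would additionally have to check that the intermediate result has no other consumers.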


## Compiling a more complex model

Obviously, we aim for larger models, so we can do the same with the entire Llama 2 (well, we use a smaller model here to be mild on our CI, but if you have a large GPU, just skip reducing the number of layers):

In [7]:
from lit_gpt import GPT
from thunder.tests.lit_gpt_model import Config
cfg = Config.from_name('Llama-2-7b-hf')
cfg.n_layer = 4 # fewer layers
with torch.device('cuda'):
    m = GPT(cfg)
m


GPT(
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
  (transformer): ModuleDict(
    (wte): Embedding(32000, 4096)
    (h): ModuleList(
      (0-3): 4 x Block(
        (norm_1): RMSNorm()
        (attn): CausalSelfAttention(
          (attn): Linear(in_features=4096, out_features=12288, bias=False)
          (proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): RMSNorm()
        (mlp): LLaMAMLP(
          (fc_1): Linear(in_features=4096, out_features=11008, bias=False)
          (fc_2): Linear(in_features=4096, out_features=11008, bias=False)
          (proj): Linear(in_features=11008, out_features=4096, bias=False)
        )
      )
    )
    (ln_f): RMSNorm()
  )
)

Again we jit our model and compare the output...

In [8]:
thunder_model = thunder.jit(m)

inp = torch.randint(1, m.config.vocab_size, (1, 512), device="cuda")

actual = thunder_model(inp)
expected = m(inp)

print("deviation:", (actual - expected).abs().max().item())


deviation: 1.8477439880371094e-06


Just like before, we can see the program it ran:

In [9]:
thunder.last_traces(thunder_model)[-1]

# Constructed by Delete Last Used (took 1 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(*args):
  # args: "Collection" 
  t0, \
  t1, \
  t2, \
  t3, \
  t4, \
  t5, \
  t6, \
  t7, \
  t8, \
  t9, \
  t10, \
  t11, \
  t12, \
  t13, \
  t14, \
  t15, \
  t16, \
  t17, \
  t18, \
  t19, \
  t20, \
  t21, \
  t22, \
  t23, \
  t24, \
  t25, \
  t26, \
  t27, \
  t28, \
  t29, \
  t30, \
  t31, \
  t32, \
  t33, \
  = args
  del args
  t38 = torch.nn.functional.embedding(t0, t33, None, None, 2.0, False, False)  # t38: "cuda:0 f32[1, 512, 4096]"
    # t38 = ltorch.embedding(t0, t33, None, None, 2.0, False, False)  # t38: "cuda:0 f32[1, 512, 4096]"
      # t334 = ltorch.reshape(t0, [512])  # t334: "cuda:0 i64[512]"
        # t334 = prims.reshape(t0, (512,))  # t334: "cuda:0 i64[512]"
      # t335 = prims.take(t33, t334, 0)  # t335: "cuda:0 f32[512, 40

Well, that is quite a bit to look through.
But here is a key thing: the function now returns a bunch of things. This is because Thunder applies the same treatment to the backward, and to this end it saves information from the forward. You can see a hint of this because the output has a `ThunderFunctionBackward` as its `grad_fn`. (You can see the backward trace with `thunder.last_backward_traces(thunder_model)[-1]`.)
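Conceptually, the forward is split into an "augmented" forward and a backward. The augmented forward returns the output together with whatever the backward will need. The backward consumes those saved values plus the incoming gradient (the cotangent). The plain-Python sketch below shows this split for the scalar gated activation `silu(a) * b` from the MLP above. The function names only echo the traces, and Thunder generates these functions itself, so treat this as an illustration of the idea rather than of Thunder's code.

```python
import math


def sigmoid(v: float) -> float:
    return 1.0 / (1.0 + math.exp(-v))


def augmented_forward(a: float, b: float):
    """Return the output plus the values the backward will need."""
    s = sigmoid(a)
    silu_a = a * s
    out = silu_a * b
    saved_for_backward = (a, b, s, silu_a)
    return out, saved_for_backward


def backward(saved_for_backward, cotangent: float):
    """Map the gradient of the output to gradients of the inputs."""
    a, b, s, silu_a = saved_for_backward
    dsilu_da = s * (1.0 + a * (1.0 - s))  # derivative of a * sigmoid(a)
    grad_a = cotangent * b * dsilu_da
    grad_b = cotangent * silu_a
    return grad_a, grad_b


def f(a_: float, b_: float) -> float:
    return augmented_forward(a_, b_)[0]


# Compare with central finite differences.
a, b = 0.7, -1.3
out, saved = augmented_forward(a, b)
grad_a, grad_b = backward(saved, 1.0)
h = 1e-6
num_a = (f(a + h, b) - f(a - h, b)) / (2 * h)
num_b = (f(a, b + h) - f(a, b - h)) / (2 * h)
print(abs(grad_a - num_a) < 1e-6, abs(grad_b - num_b) < 1e-6)
```

Saving `sigmoid(a)` avoids recomputing it in the backward. Thunder makes this kind of save-versus-recompute trade-off when it builds the forward and backward traces.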

In [10]:
actual

tensor([[[-0.9922,  0.5946, -0.2173,  ..., -0.0981, -0.5058,  0.2747],
         [-1.1552,  0.5770, -0.7432,  ...,  0.0688,  0.1238,  0.6786],
         [-0.7813,  0.6960,  0.1235,  ..., -0.4840,  0.1373,  0.6490],
         ...,
         [ 0.3711,  0.1656,  0.3350,  ..., -0.0294,  0.3670,  0.5099],
         [-0.2544, -0.8470,  0.2063,  ..., -0.1341,  0.1877,  0.2612],
         [ 0.3420, -1.1421,  0.9222,  ...,  0.5636,  0.1666,  0.6947]]],
       device='cuda:0', grad_fn=<ThunderFunctionBackward>)

One thing to keep in mind here is that for bf16, the numerical accuracy impact of rearranging operations can be quite pronounced.
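The reason reordering hurts more in bf16 is that bf16 keeps only 8 bits of significand precision (7 stored), so the spacing between representable numbers just above 1.0 is 2**-7. The stdlib-only sketch below emulates bf16 by rounding a float32 bit pattern to its upper 16 bits (round to nearest, ties to even; NaN and overflow are not handled). It shows that adding the same three numbers in two different orders gives two different bf16 results.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even); NaN/Inf not handled."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    lsb = (bits >> 16) & 1                               # lowest kept bit, for ties-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000            # round away the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_add(a: float, b: float) -> float:
    """Add two bf16 values and round the result back to bf16."""
    return to_bf16(to_bf16(a) + to_bf16(b))


small = 2.0 ** -8  # half the bf16 spacing just above 1.0 (which is 2**-7)

left_to_right = bf16_add(bf16_add(1.0, small), small)  # each step rounds back to 1.0
small_first = bf16_add(1.0, bf16_add(small, small))    # 2**-7 survives the final addition

print(left_to_right, small_first)  # 1.0 1.0078125
```

A fused kernel that accumulates in a different order, or in float32 before rounding, can therefore legitimately differ from eager PyTorch by much more in bf16 than in float32.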

In [11]:
actual_grads = torch.autograd.grad(actual.sum(), m.parameters())
expected_grads = torch.autograd.grad(expected.sum(), m.parameters())
print("maximum deviation grads:", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))

maximum deviation grads: 0.00042724609375


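But is it faster? Yes, a little: in the run below, forward plus backward takes about 150 ms per iteration with Thunder versus about 154 ms with eager PyTorch.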

In [12]:
import gc
gc.collect()
%timeit r = m(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()
%timeit r = thunder_model(inp); torch.autograd.grad(r.sum(), m.parameters()); torch.cuda.synchronize()

154 ms ± 281 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
150 ms ± 342 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
del m, thunder_model
import gc
gc.collect()
torch.cuda.empty_cache()

So far, so good! Thunder should work with LitGPT today, and we are busy adding the support required to run other models as well!

## Distributed with Thunder

Those Large Language Models are called Large for a reason, and the memory of a single GPU is limited. So we need multiple GPUs.

Happily, Thunder sports an FSDP interface to use multiple cards in our box.

You still need to set up the process group, but as far as the model is concerned,

```python
model = thunder.jit(thunder.distributed.fsdp(model))
```

is all you need. Because it is tricky to run multiprocessing from notebooks, we write a small example into a file and run it through `torchrun`.

Check out our LitGPT Thunder examples for complete distributed training and finetuning!
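Why sharding matters is simple arithmetic. The back-of-the-envelope sketch below counts parameters from the layer shapes printed for the `GPT` module earlier. It assumes 32 blocks, as in the full `Llama-2-7b-hf` config, and 2 bytes per parameter in bf16. It then splits the total evenly across ranks with padding. For simplicity, this treats all parameters as one flat buffer and ignores gradients, optimizer state, activations and the temporarily all-gathered parameters. It illustrates the idea of per-rank shards and is not necessarily how `thunder.distributed.fsdp` lays them out.

```python
import math
from typing import List


def llama_param_count(n_layer: int, n_embd: int = 4096, vocab: int = 32000,
                      qkv_out: int = 12288, intermediate: int = 11008) -> int:
    """Count parameters using the layer shapes printed for the GPT module above."""
    per_block = (
        n_embd * qkv_out              # attn.attn (fused q, k, v)
        + n_embd * n_embd             # attn.proj
        + 3 * n_embd * intermediate   # mlp.fc_1, mlp.fc_2, mlp.proj
        + 2 * n_embd                  # norm_1 and norm_2 weights
    )
    embeddings = 2 * vocab * n_embd   # wte and lm_head (separate matrices)
    return n_layer * per_block + embeddings + n_embd  # + ln_f weight


def shard_sizes(numel: int, world_size: int) -> List[int]:
    """Even split with padding: every rank holds ceil(numel / world_size) elements."""
    per_rank = math.ceil(numel / world_size)
    return [per_rank] * world_size


n_params = llama_param_count(n_layer=32)  # assumption: 32 blocks as in full Llama 2 7B
bytes_bf16 = 2 * n_params                  # 2 bytes per bf16 parameter
print(f"{n_params:,} parameters, {bytes_bf16 / 2**30:.2f} GiB in bf16")
for world_size in (1, 2, 8):
    per_rank_bytes = shard_sizes(n_params, world_size)[0] * 2
    print(f"world_size={world_size}: {per_rank_bytes / 2**30:.2f} GiB of parameters per GPU")
```

The parameters alone come to roughly 12.5 GiB in bf16. Gradients and optimizer states multiply that several times over during training, which is where splitting everything across GPUs pays off.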

In [14]:
%%writefile zero_to_thunder_fsdp_simple_example.py
import sys
sys.path.insert(0, '..')
from thunder.tests.lit_gpt_model import GPT, Config

import torch
import torch.distributed
import thunder
import thunder.distributed
import os

# Create the model.
# NOTE: We create the model on the CPU (the default device), in bf16.
torch.set_default_dtype(torch.bfloat16)
model = GPT.from_name('llama2-like')
# Setup for distributed
torch.distributed.init_process_group(backend='nccl')
rank = int(os.environ["LOCAL_RANK"])

device = f"cuda:{rank}"
x = torch.randint(1, model.config.vocab_size, (1, 1024), device=device)

# thunder.distributed.fsdp takes care of moving the parameter
# shard to the correct GPU for the current process.
model = thunder.jit(thunder.distributed.fsdp(model)) #  <---------------------------------------

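# Run the forward and backward passes, twice.
res = model(x)
res.sum().backward()

res = model(x)
res.sum().backward()

# Clean up the process group before exiting.
torch.distributed.destroy_process_group()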


Overwriting zero_to_thunder_fsdp_simple_example.py


In [15]:
!torchrun --nproc_per_node=2 zero_to_thunder_fsdp_simple_example.py

W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] 
W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************
W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0316 11:53:02.156000 140513675427904 torch/distributed/run.py:757] *****************************************


So there it is: FSDP just by wrapping the model in `fsdp`.

## Extending Thunder

But we promised that Thunder is extensible. Let's find out what's up with that.

Specifically, we will incorporate the RMSNorm kernel from the great [Unsloth project](https://github.com/unslothai/unsloth/) into our model (note that NVFuser also creates a fused kernel for this).

In Thunder, extensions (as well as most built-in optimizations, which use exactly the same mechanism) work through _executors_ that handle operations. Let us define one.

In [16]:
my_ex = thunder.extend.OperatorExecutor('my_ex', version='0.0.1')
thunder.extend.register_executor(my_ex)

my_ex

For our base implementation, we take the code from [LitGPT's RMSNorm implementation](https://github.com/Lightning-AI/litgpt/blob/7c1574925f973e64c0a53e056b77229bedee1619/lit_gpt/rmsnorm.py).

In Thunder, for each operator we define a *meta* function, which only computes the metadata (like shapes) of the outputs, and the actual implementation. We then register the pair with our executor using its `register_operator` method.
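Pairing a meta function with an implementation can be illustrated without Thunder. In the plain-Python sketch below, a hypothetical `ToyExecutor` stores `(meta, fn)` pairs. Running the meta on a `ShapeProxy`, which carries only shape and dtype, yields the output metadata without doing any computation. That is roughly what is needed while tracing. The real function runs only at execution time. All names here are made up for illustration and are not Thunder's API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass(frozen=True)
class ShapeProxy:
    """Carries only metadata, like a TensorProxy, but no values."""
    shape: Tuple[int, ...]
    dtype: str


class ToyExecutor:
    """Hypothetical executor: maps op names to (meta, implementation) pairs."""

    def __init__(self, name: str):
        self.name = name
        self.ops: Dict[str, Tuple[Callable, Callable]] = {}

    def register_operator(self, opname: str, meta: Callable, fn: Callable) -> str:
        self.ops[opname] = (meta, fn)
        return opname

    def infer(self, opname: str, *args):
        meta, _ = self.ops[opname]
        return meta(*args)  # metadata only, as needed while tracing

    def run(self, opname: str, *args):
        _, fn = self.ops[opname]
        return fn(*args)    # the real computation, at execution time


def rms_norm_rows(rows: List[List[float]], weight: List[float], eps: float) -> List[List[float]]:
    """Plain-Python RMSNorm over the last dimension of a list of rows."""
    out = []
    for row in rows:
        inv = (sum(v * v for v in row) / len(row) + eps) ** -0.5
        out.append([v * inv * w for v, w in zip(row, weight)])
    return out


def rms_norm_rows_meta(x: ShapeProxy, weight: ShapeProxy, eps: float) -> ShapeProxy:
    return ShapeProxy(shape=x.shape, dtype=x.dtype)  # same metadata as the input


ex = ToyExecutor("toy_ex")
op = ex.register_operator("rms_norm_rows", meta=rms_norm_rows_meta, fn=rms_norm_rows)
meta_out = ex.infer(op, ShapeProxy((2, 4), "float32"), ShapeProxy((4,), "float32"), 1e-6)
real_out = ex.run(op, [[3.0, 4.0]], [1.0, 1.0], 0.0)
print(meta_out)
print(real_out)
```

Because the meta never touches data, it stays cheap even for huge tensors. This is also why, in the Thunder cell that follows, `rms_norm_meta` can simply return `TensorProxy(like=x)`.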


In [17]:
from thunder import TensorProxy

# Taken from LitGPT, who in turn credit:
# Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
#    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

def rms_norm_impl(x: torch.Tensor, weight, dim: int, eps: float, add_unit_offset: bool) -> torch.Tensor:
    dtype = x.dtype
    x = x.float()
    # NOTE: the original RMSNorm paper implementation is not equivalent
    norm_x = torch.mean(x * x, dim=dim, keepdim=True)
    x_normed = x * torch.rsqrt(norm_x + eps)
    x_normed = x_normed.to(dtype=dtype)
    if add_unit_offset:
        # Gemma model requires a unit offset
        # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
        return x_normed * (1 + weight)
    return x_normed * weight

def rms_norm_meta(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool) -> TensorProxy:
    return TensorProxy(like=x)

rms_norm = my_ex.register_operator('rms_norm', meta=rms_norm_meta, fn=rms_norm_impl)


Because evil monkey-patching is acceptable in short demos, let's replace LitGPT's own implementation. For your own model, you might put this directly into your code instead.

In [18]:
import lit_gpt.rmsnorm
if not hasattr(lit_gpt.rmsnorm, 'ThunderOrigRMSNorm'):
    lit_gpt.rmsnorm.ThunderOrigRMSNorm = lit_gpt.rmsnorm.RMSNorm

class ThunderizedRMSNorm(lit_gpt.rmsnorm.ThunderOrigRMSNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This isn't the best paradigm. :/
        if thunder.core.interpreter.is_jitting():
            return rms_norm(x, self.weight, self.dim, self.eps, self.add_unit_offset)
        else:
            return super().forward(x)

lit_gpt.rmsnorm.RMSNorm = ThunderizedRMSNorm
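The `hasattr` guard in the cell above matters in a notebook, where cells get re-run. Without it, a second run would store the already-patched class as the "original", so each re-run would stack another subclass on top of the previous patch. Here is a stdlib-only sketch of the same pattern. `types.SimpleNamespace` stands in for the `lit_gpt.rmsnorm` module, and all names are hypothetical.

```python
import types


class ToyRMSNorm:  # stand-in for the library's original class
    def forward(self, x):
        return f"orig({x})"


def patch(module, guarded: bool = True) -> None:
    """Replace module.RMSNorm with a subclass, saving the original once if guarded."""
    if not guarded or not hasattr(module, "OrigRMSNorm"):
        module.OrigRMSNorm = module.RMSNorm

    class Patched(module.OrigRMSNorm):
        def forward(self, x):
            return f"patched({super().forward(x)})"

    module.RMSNorm = Patched


# Re-running the "cell" twice with the guard keeps exactly one layer of patching.
lib = types.SimpleNamespace(RMSNorm=ToyRMSNorm)
patch(lib)
patch(lib)
print(lib.RMSNorm().forward("x"))   # patched(orig(x))

# Without the guard, the second run wraps the first patch.
lib2 = types.SimpleNamespace(RMSNorm=ToyRMSNorm)
patch(lib2, guarded=False)
patch(lib2, guarded=False)
print(lib2.RMSNorm().forward("x"))  # patched(patched(orig(x)))
```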

We can try our new RMSNorm: 

In [19]:
with torch.device('cuda'):
    norm_module = ThunderizedRMSNorm(4096)
    x = torch.randn(256, 4096)

# we are not quite ready to handle forward and backward yet; we will re-enable gradients below
for p in norm_module.parameters():  
    p.requires_grad_(False)

thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors())    

expected = norm_module(x)
actual = thunder_norm_module(x)

print("deviation:", (expected - actual).abs().max().item())

thunder.last_traces(thunder_norm_module)[-1]

deviation: 0.0


# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def computation(x, t_weight):
  # x: "cuda:0 f32[256, 4096]" 
  # t_weight: "cuda:0 f32[4096]" 
  t7 = rms_norm(x, t_weight, -1, 1e-06, False)  # t7: "cuda:0 f32[256, 4096]"
  del x, t_weight
  return t7

But why did we do this? Well, we can now layer a faster implementation on top.
For this, we take the [unsloth RMSNorm](https://github.com/unslothai/unsloth/blob/42076f6580e71522ed1c122043edfba595be64e4/unsloth/kernels/rms_layernorm.py) kernels. We move the code that was in the forward and backward of their `autograd.Function` into our implementation functions and define the corresponding metas.

In [20]:
# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import triton
import triton.language as tl
import torch

MAX_FUSED_SIZE = 65536
next_power_of_2 = triton.next_power_of_2

def calculate_settings(n):
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
                           f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
    num_warps = 4
    if   BLOCK_SIZE >= 32768: num_warps = 32
    elif BLOCK_SIZE >=  8192: num_warps = 16
    elif BLOCK_SIZE >=  2048: num_warps = 8
    return BLOCK_SIZE, num_warps

@triton.jit
def _rms_layernorm_forward(
    Y, Y_row_stride,
    X, X_row_stride,
    W, W_row_stride,
    r, r_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr
):
    """
        Fast RMS Layernorm kernel
        Inspiration from a Triton tutorial:
        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    Y += row_idx * Y_row_stride
    X += row_idx * X_row_stride
    r += row_idx * r_row_stride

    X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row = tl.load(W + col_offsets, mask = mask, other = 0)#.to(tl.float32)

    row_var = tl.sum(X_row * X_row, axis = 0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
    tl.store(r, inv_var)
    normed = X_row * inv_var
    normed = normed.to(W_row.dtype) # Exact copy from HF
    output = normed * W_row
    tl.store(Y + col_offsets, output, mask = mask)


@triton.jit
def _rms_layernorm_backward(
    dY, dY_row_stride,
    X,   X_row_stride,
    W,   W_row_stride,
    r,   r_row_stride,
    dW, dW_row_stride,
    n_cols, eps,
    BLOCK_SIZE : tl.constexpr,
):
    """
        Fast RMS Layernorm kernel for the backward pass
        Inspiration from a Triton tutorial:
        https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
    """
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dY += row_idx * dY_row_stride
    X  += row_idx *  X_row_stride
    r  += row_idx *  r_row_stride

    dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)
    X_row  = tl.load(X  + col_offsets, mask = mask, other = 0).to(tl.float32)
    W_row  = tl.load(W  + col_offsets, mask = mask, other = 0).to(tl.float32)

    # Get saved row variance
    inv_var = tl.load(r).to(tl.float32)
    normed = X_row * inv_var

    dY_W = dY_row * W_row

    rowsum_dY_normed = tl.sum(dY_W * normed, axis = 0)
    output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)
    tl.store(dY + col_offsets, output, mask = mask)
    
def rms_layernorm_forward_impl(X, W, eps):
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
    n_rows, n_cols = X.shape
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)

    Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = X.device)
    r = torch.empty(n_rows, dtype = torch.float32, device = X.device)

    _rms_layernorm_forward[(n_rows,)](
        Y, Y.stride(0),
        X, X.stride(0),
        W, W.stride(0),
        r, r.stride(0),
        n_cols, eps,
        BLOCK_SIZE = BLOCK_SIZE,
        num_warps  = num_warps,
    )
    return Y.view(*shape), (r, BLOCK_SIZE, num_warps)

def rms_layernorm_forward_meta(X, W, eps):
    n_cols = X.shape[-1]
    n_rows = 1
    for i in X.shape[:-1]:
        n_rows *= i
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    Y = TensorProxy(like=X, requires_grad=True)
    return (Y,
            (TensorProxy(shape=(n_rows,), device=X.device, dtype=thunder.dtypes.float32, requires_grad=False),
             BLOCK_SIZE, 
             num_warps,
            )
           )

def rms_layernorm_backward_impl(X, W, r, eps, BLOCK_SIZE, num_warps, dY):
    shape = dY.shape
    dim = shape[-1]
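    dY = dY.view(-1, dim)
    X = X.view(-1, dim)  # flatten like the forward, so that X.stride(0) is the row stride
    n_rows, n_cols = dY.shape
    dW = X  # placeholder only: the kernel does not compute the weight gradient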
    dX = dY.clone()
    _rms_layernorm_backward[(n_rows,)](
        dX, dX.stride(0),
        X,  X .stride(0),
        W,  W .stride(0),
        r,  r .stride(0),
        dW, dW.stride(0),
        n_cols, eps,
        BLOCK_SIZE = BLOCK_SIZE,
        num_warps  = num_warps,
    )
    dX = dX.view(*shape)
    return dX

def rms_layernorm_backward_meta(X, W, r, eps, BLOCK_SIZE, num_warps, dY):
    return TensorProxy(like=dY)
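For reference, `_rms_layernorm_backward` computes the gradient of `y = W * X * r` with respect to `X`, where `r = 1/sqrt(mean(X**2) + eps)` is the saved inverse RMS. Differentiating `r` gives `dr/dX_j = -r**3 * X_j / n`. With `g = dY * W` and `x_hat = X * r`, this leads to `dX = r/n * (n * g - x_hat * sum(g * x_hat))`, which is the expression stored by the kernel. The kernel never writes `dW`, so it provides no weight gradient. The sketch below checks that formula on a single row against central finite differences in plain Python, so it runs without Triton or a GPU. The helper names are hypothetical.

```python
import random


def rms_norm_row(x, w, eps):
    """Forward for one row: y_i = w_i * x_i * r with r = 1/sqrt(mean(x^2) + eps)."""
    r = (sum(v * v for v in x) / len(x) + eps) ** -0.5
    return [wi * xi * r for xi, wi in zip(x, w)]


def rms_norm_row_backward(x, w, eps, dy):
    """Gradient w.r.t. x, following the same formula as _rms_layernorm_backward."""
    n = len(x)
    r = (sum(v * v for v in x) / n + eps) ** -0.5
    normed = [xi * r for xi in x]                 # x_hat
    g = [dyi * wi for dyi, wi in zip(dy, w)]      # dY * W
    rowsum = sum(gi * ni for gi, ni in zip(g, normed))
    return [r / n * (n * gi - ni * rowsum) for gi, ni in zip(g, normed)]


random.seed(0)
n_cols = 8
x_row = [random.uniform(-2, 2) for _ in range(n_cols)]
w_row = [random.uniform(0.5, 1.5) for _ in range(n_cols)]
dy_row = [random.uniform(-1, 1) for _ in range(n_cols)]
eps = 1e-6


def loss(xs):
    # Scalar whose gradient w.r.t. the RMSNorm output is exactly dy_row.
    return sum(d * yi for d, yi in zip(dy_row, rms_norm_row(xs, w_row, eps)))


h = 1e-6
numeric = []
for j in range(n_cols):
    x_plus = list(x_row)
    x_plus[j] += h
    x_minus = list(x_row)
    x_minus[j] -= h
    numeric.append((loss(x_plus) - loss(x_minus)) / (2 * h))

analytic = rms_norm_row_backward(x_row, w_row, eps, dy_row)
max_err = max(abs(a - b) for a, b in zip(analytic, numeric))
print(f"max abs difference: {max_err:.2e}")
```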

With this, we can just register the additional operators:

In [21]:
unsloth_rms_norm_forward = my_ex.register_operator('unsloth_rms_norm_forward', meta=rms_layernorm_forward_meta, fn=rms_layernorm_forward_impl)
unsloth_rms_norm_backward = my_ex.register_operator('unsloth_rms_norm_backward', meta=rms_layernorm_backward_meta, fn=rms_layernorm_backward_impl)

But instead of monkey-patching more, we can now register the kernel as an _implementation_ of the `rms_norm` operator defined above. For this we need an _execution transform_, which is a fancy name for a function that implements the original operator (`rms_norm`) in terms of our new operator, so it has the same call signature as `rms_norm`. Like many fast implementations, the unsloth RMSNorm does not implement the operator in full generality (to do them justice, they have a variant that adds the unit offset; we just didn't copy it over). So we implement a checker function, too: it takes the arguments of the operator we want to specialize and returns a bool indicating whether our implementation handles the given inputs.
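How a checker gates a specialized implementation can be sketched in plain Python. Each candidate pairs a checker, which inspects the arguments, with a transform. The first candidate whose checker accepts the inputs is used, and otherwise we fall back to the generic implementation. This is a simplified illustration with hypothetical names (`ArgInfo`, `dispatch`). It is not how Thunder actually selects executors, although the checker mirrors `rms_norm_to_unsloth_checker` from the cell that follows.

```python
from dataclasses import dataclass
from typing import Callable, List, Tuple


@dataclass(frozen=True)
class ArgInfo:
    """Metadata a checker can look at (stand-in for TensorProxy attributes)."""
    device: str
    requires_grad: bool = False


Candidate = Tuple[str, Callable[..., bool], Callable[..., str]]


def dispatch(candidates: List[Candidate], *args) -> str:
    """Return the result of the first implementation whose checker accepts args."""
    for _name, checker, transform in candidates:
        if checker(*args):
            return transform(*args)
    raise RuntimeError("no implementation accepts these arguments")


def unsloth_checker(x: ArgInfo, weight: ArgInfo, dim: int, eps: float,
                    add_unit_offset: bool) -> bool:
    if dim != -1 or add_unit_offset or weight.requires_grad:
        return False
    return x.device == "cuda" and weight.device == "cuda"


candidates: List[Candidate] = [
    ("unsloth", unsloth_checker, lambda *args: "unsloth_rms_norm_forward"),
    ("generic", lambda *args: True, lambda *args: "rms_norm_impl"),
]

cuda, cpu = ArgInfo("cuda"), ArgInfo("cpu")
print(dispatch(candidates, cuda, cuda, -1, 1e-6, False))                   # unsloth_rms_norm_forward
print(dispatch(candidates, cuda, cuda, -1, 1e-6, True))                    # rms_norm_impl
print(dispatch(candidates, cpu, cpu, -1, 1e-6, False))                     # rms_norm_impl
print(dispatch(candidates, cuda, ArgInfo("cuda", True), -1, 1e-6, False))  # rms_norm_impl
```

Keeping the checker separate from the transform means that unsupported inputs never reach the fast kernel. They silently take the general path instead of failing.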

In [28]:
def rms_norm_to_unsloth(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):
    assert dim == -1 and not add_unit_offset
    res, _ = unsloth_rms_norm_forward(x, weight, eps)
    return res

def rms_norm_to_unsloth_checker(x: TensorProxy, weight: TensorProxy, dim: int, eps: float, add_unit_offset: bool):
    if dim != -1 or add_unit_offset:
        return False
    if weight.requires_grad:
        return False  # the unsloth rms norm backward only gives the grad w.r.t. x
    return x.device.devicetype == thunder.devices.DeviceType.CUDA and weight.device.devicetype == thunder.devices.DeviceType.CUDA

my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker, execution_transform=rms_norm_to_unsloth)


So let us give that a try! Works great...

In [23]:
with torch.device('cuda'):
    norm_module = ThunderizedRMSNorm(4096)

# unfortunately, we meet dragons if we don't do this at this stage (the checker also rejects weights that require grad)
for p in norm_module.parameters():  
    p.requires_grad_(False)

thunder_norm_module = thunder.jit(norm_module, executors=[my_ex,])    
x = torch.randn(2048, 4096, device="cuda")

expected = norm_module(x)
actual = thunder_norm_module(x)

print("deviation:", (expected - actual).abs().max().item())

thunder.last_traces(thunder_norm_module)[-1]

deviation: 9.5367431640625e-07


# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def computation(x, t_weight):
  # x: "cuda:0 f32[2048, 4096]" 
  # t_weight: "cuda:0 f32[4096]" 
  (t7, (_, _, _)) = unsloth_rms_norm_forward(x, t_weight, 1e-06)
  del x, t_weight
  return t7

And this is also automatic when we instantiate a larger llama2-like model:

In [24]:
torch.set_default_dtype(torch.float32)
with torch.device('cuda'):
    m = GPT(Config.from_name('llama2-like'))

for p in m.parameters():
    p.requires_grad_(False)

thunder_model = thunder.jit(m, executors=(my_ex,) + thunder.get_default_executors())

inp = torch.randint(1, m.config.vocab_size, (1, 128), device="cuda")
actual = thunder_model(inp)
expected = m(inp)

print("deviation:", (actual - expected).abs().max().item())

deviation: 4.76837158203125e-07


By peeking into the trace, we can see that it actually used the unsloth RMS kernels:

In [25]:
[s for s in str(thunder.last_traces(thunder_model)[-1]).split('\n') if 'rms' in s]

['  (n_1, (_, _, _)) = unsloth_rms_norm_forward(x, t_transformer_h_0_norm_1_weight, 1e-05)',
 '  (t110, (_, _, _)) = unsloth_rms_norm_forward(t102, t_transformer_h_0_norm_2_weight, 1e-05)',
 '  (t139, (_, _, _)) = unsloth_rms_norm_forward(t130, t_transformer_h_1_norm_1_weight, 1e-05)',
 '  (t215, (_, _, _)) = unsloth_rms_norm_forward(t207, t_transformer_h_1_norm_2_weight, 1e-05)',
 '  (t243, (_, _, _)) = unsloth_rms_norm_forward(t235, t_transformer_ln_f_weight, 1e-05)']

But what about the backward?

Well, we have to connect forward and backward with a grad transformation. With our specialized ops, this is very simple: we compute the forward, call `get_grad` for the output, compute the backward, and put the result on the input with `put_grads`.
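The shape of a grad transform (compute the forward, fetch the output's gradient, run the backward, put the result on the input) can be imitated with a tiny reverse-mode tape in plain Python. The helpers `toy_get_grad` and `toy_put_grad` below are hypothetical stand-ins that work on a dictionary of gradients. They mimic only the calling pattern of Thunder's `get_grad` and `put_grads`, not their implementation.

```python
from typing import Callable, Dict, List

grads: Dict[str, float] = {}
tape: List[Callable[[], None]] = []


def toy_get_grad(name: str) -> float:
    """Gradient accumulated so far for a named value (0.0 if none)."""
    return grads.get(name, 0.0)


def toy_put_grad(name: str, value: float) -> None:
    """Accumulate a gradient contribution for a named value."""
    grads[name] = grads.get(name, 0.0) + value


def square(x_name: str, x: float, out_name: str) -> float:
    res = x * x                             # forward
    def backward() -> None:
        g = toy_get_grad(out_name)          # gradient of the output
        toy_put_grad(x_name, 2.0 * x * g)   # backward rule, put on the input
    tape.append(backward)
    return res


def scale(x_name: str, x: float, c: float, out_name: str) -> float:
    res = c * x
    def backward() -> None:
        toy_put_grad(x_name, c * toy_get_grad(out_name))
    tape.append(backward)
    return res


# z = 3 * x**2 at x = 2.0, so dz/dx = 6 * x = 12.0
y = square("x", 2.0, "y")
z = scale("y", y, 3.0, "z")
toy_put_grad("z", 1.0)        # seed: dz/dz = 1
for step in reversed(tape):   # run the backward rules in reverse order
    step()
print(z, toy_get_grad("x"))   # 12.0 12.0
```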

In [26]:
from thunder.core.transforms import get_grad, put_grads

def unsloth_rms_norm_grad(x: TensorProxy, weight, dim: int, eps: float, add_unit_offset: bool):
    res, (r, BLOCK_SIZE, num_warps) = unsloth_rms_norm_forward(x, weight, eps)
    grad_res = get_grad(res)
    grad_x = unsloth_rms_norm_backward(x, weight, r, eps, BLOCK_SIZE, num_warps, grad_res)
    put_grads((x,), (grad_x,))
    return res

my_ex.register_implementation(rms_norm, checker=rms_norm_to_unsloth_checker,
                              execution_transform=rms_norm_to_unsloth,
                              grad_transform=unsloth_rms_norm_grad 
                              )



In [27]:
with torch.device('cuda'):
    norm_module = ThunderizedRMSNorm(4096)
    norm_module.weight.requires_grad_(False)
    x = torch.randn(256, 4096, requires_grad=True)

thunder_norm_module = thunder.jit(norm_module, executors=(my_ex,) + thunder.get_default_executors())    

actual = thunder_norm_module(x)
expected = norm_module(x)
actual_grads = torch.autograd.grad(actual.sum(), x)
expected_grads = torch.autograd.grad(expected.sum(),  x)

print("maximum deviation grads:", max((a-e).abs().max().item() for a, e in zip(actual_grads, expected_grads)))

maximum deviation grads: 3.5762786865234375e-07


And here is the backward trace of our module, which now uses the unsloth backward kernel:

In [29]:
thunder.last_backward_traces(thunder_norm_module)[-1]

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection" 
  # cotangents: "Collection" 
  C0, \
  C1, \
  = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t4, \
  = cotangents
  clear_collection(cotangents)
  del cotangents
  t0, \
  t1, \
  t3, \
  = C0
  clear_collection(C0)
  del C0
  f0, \
  = C1
  clear_collection(C1)
  del C1
  t2 = unsloth_rms_norm_backward(t0, t1, t3, f0, 4096, 8, t4)  # t2: "cuda:0 f32[256, 4096]"
  del t0, t1, t3, f0, t4
  return (t2, None)

That's it! Do check out our LitGPT studios and the other tutorial notebooks.
