# From prototyping in PyTorch to running with custom CUDA kernels
In this demo we show how to prototype a model in PyTorch and then iteratively replace parts of it with custom code using Thunder. Here we deal with custom CUDA kernels, although Thunder is quite flexible and can be adapted to work with other compute environments. We highly recommend checking the [Zero to Thunder tutorial](./zero_to_thunder.ipynb) for a general overview of Thunder's capabilities.

As an example, we will use a GPT-2 implementation from [llm.c](https://github.com/karpathy/llm.c). We will start with the PyTorch reference implementation and then replace many of its parts with the native [llm.c](https://github.com/karpathy/llm.c) CUDA kernels using Thunder. This serves to demonstrate the following iterative strategy for model runtime performance optimization:
* One starts with a prototype model implemented in PyTorch.
* Once the critical sections of the PyTorch program are identified, they are replaced with custom implementations for better performance and/or control. It is quite convenient to do so in Python, and Thunder is especially well-suited for this task.
* Once these critical sections are performing as expected, one can go further and re-implement the whole model for the native environment to remove the remaining overhead coming from Python, PyTorch, and similar layers. For example, [llm.c](https://github.com/karpathy/llm.c) implements a very lean C/CUDA GPT-2 model that depends on neither PyTorch nor CPython.
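
The second step presupposes some way of measuring where the time goes. As a minimal sketch (not part of llm.c or Thunder), the hypothetical helper `time_fn` below times a callable after a few warm-up calls, which mirrors the warm-up-then-measure pattern used later in this notebook. The `sync` hook is an assumption: for GPU work one could pass `torch.cuda.synchronize` so that queued kernels are included in the measurement. `toy_section` is a made-up stand-in workload.

```python
import time


def time_fn(fn, *args, warmup=5, iters=10, sync=None):
    """Return the mean wall-clock time of fn(*args) in milliseconds."""
    # Warm-up calls absorb one-time costs such as compilation and caching.
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()
    timings = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()  # e.g. torch.cuda.synchronize, so queued GPU work is counted
        timings.append(time.perf_counter() - t0)
    return 1000.0 * sum(timings) / len(timings)


# Hypothetical stand-in workload; replace with the model section of interest.
def toy_section(n):
    return sum(i * i for i in range(n))


mean_ms = time_fn(toy_section, 10_000)
print(f"toy_section: {mean_ms:.3f} ms per call")
```

Comparing such timings across candidate sections gives a rough ranking of where a custom kernel is most likely to pay off.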

## [llm.c](https://github.com/karpathy/llm.c)
[llm.c](https://github.com/karpathy/llm.c) shows that training LLMs can be quite simple, and provides a very lean C/CUDA implementation of GPT-2 without the additional overhead that comes from using PyTorch and CPython. It also includes a PyTorch reference implementation, which we will work with throughout this demo.

Before we dive deeper, let us do some preparatory work with [llm.c](https://github.com/karpathy/llm.c).

In [1]:
%%bash
git clone https://github.com/karpathy/llm.c.git
cd llm.c
git checkout 954077fb887d2770e4d537bafea056473d4bb4ce
pip install -r requirements.txt
python prepro_tinyshakespeare.py
python train_gpt2.py

fatal: destination path 'llm.c' already exists and is not an empty directory.
Previous HEAD position was 50acc12 Merge branch 'ngc92-split-file' Separates out common error-checking wrapper utils, that are broadly useful across all file
HEAD is now at 954077f TRAINING WORKSgit add train_gpt2.cu! ITS SLOW BUT IT WORKS WOOT


data/tiny_shakespeare.txt already exists, skipping download...
Saved 32768 tokens to data/tiny_shakespeare_val.bin
Saved 305260 tokens to data/tiny_shakespeare_train.bin
using device: cuda
wrote gpt2_tokenizer.bin
loading weights from pretrained gpt: gpt2
loading cached tokens in data/tiny_shakespeare_val.bin
wrote gpt2_124M.bin
wrote gpt2_124M_debug_state.bin
iteration 0, loss: 5.270008563995361, time: 3038.967ms
iteration 1, loss: 4.059720993041992, time: 33.077ms
iteration 2, loss: 3.3751838207244873, time: 35.995ms
iteration 3, loss: 2.800813913345337, time: 36.473ms
iteration 4, loss: 2.315413475036621, time: 36.510ms
iteration 5, loss: 1.8490413427352905, time: 36.740ms
iteration 6, loss: 1.3946460485458374, time: 42.617ms
iteration 7, loss: 0.9992104768753052, time: 42.732ms
iteration 8, loss: 0.6240706443786621, time: 42.645ms
iteration 9, loss: 0.3764864206314087, time: 42.529ms
final 20 iters avg: 338.828ms
<|endoftext|>One year ago today:
This is the first week since we last

The lines above clone the repository, check out commit `954077fb887d2770e4d537bafea056473d4bb4ce` (we pin it because the project is being developed very rapidly), and install all the necessary dependencies. `python prepro_tinyshakespeare.py` downloads the [tinyshakespeare](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt) dataset and tokenizes it with the GPT-2 tokenizer. `python train_gpt2.py` loads the pretrained GPT-2 (124M) weights, writes them to `gpt2_124M.bin`, and runs the reference PyTorch model for 10 iterations.
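
As a rough illustration of what the tokenized data looks like to the training loop: the loader used later in this notebook reads the `.bin` file as a flat array of 32-bit integers and forms the targets by shifting the inputs by one token. The sketch below mimics that with the standard library only. The file name `toy_val.bin`, the token values, and the little-endian byte order are assumptions made for illustration, and `read_tokens` and `first_batch` are hypothetical helpers rather than llm.c functions.

```python
import os
import struct
import tempfile


def read_tokens(path):
    """Read a flat, headerless file of int32 token ids (assumed little-endian)."""
    with open(path, "rb") as f:
        raw = f.read()
    count = len(raw) // 4
    return list(struct.unpack(f"<{count}i", raw[: count * 4]))


def first_batch(tokens, B, T):
    """Return inputs x and next-token targets y, each as B rows of T tokens."""
    assert B * T + 1 <= len(tokens), "not enough tokens"
    x_flat = tokens[0 : B * T]
    y_flat = tokens[1 : B * T + 1]  # same window shifted by one token
    x = [x_flat[r * T : (r + 1) * T] for r in range(B)]
    y = [y_flat[r * T : (r + 1) * T] for r in range(B)]
    return x, y


# Write a toy token file (made-up ids), then read it back.
toy_tokens = list(range(100, 113))
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_val.bin")
    with open(path, "wb") as f:
        f.write(struct.pack(f"<{len(toy_tokens)}i", *toy_tokens))
    tokens = read_tokens(path)

x, y = first_batch(tokens, B=2, T=3)
print(x)  # [[100, 101, 102], [103, 104, 105]]
print(y)  # [[101, 102, 103], [104, 105, 106]]
```

Each target row is the corresponding input row advanced by one position, which is what makes the loss a next-token prediction loss.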

## Reference PyTorch implementation
This is where it all begins for us! The code below is a modified version of the original [train_gpt2.py](./llm.c/train_gpt2.py).

In [2]:
# Based off train_gpt2.py from https://github.com/karpathy/llm.c/tree/954077fb887d2770e4d537bafea056473d4bb4ce

"""
Reference code for GPT-2 training and inference.
Will save the model weights into files, to be read from C as initialization.

References:
1) the official GPT-2 TensorFlow implementation released by OpenAI:
https://github.com/openai/gpt-2/blob/master/src/model.py
2) huggingface/transformers PyTorch implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
"""

import os
import math
import struct
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

# Attention-related {
def _permute(qkv, n_embd, n_head, B, T, C):
    q, k, v = qkv.split(n_embd, dim=2)
    q = q.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
    k = k.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
    v = v.view(B, T, n_head, C // n_head).transpose(1, 2) # (B, nh, T, hs)
    return q, k, v


def _unpermute(y, B, T, C):
    return y.transpose(1, 2).contiguous().view(B, T, C)


def _manual_attention(qkv, bias, n_embd, n_head, B, T, C):
    q, k, v = _permute(qkv, n_embd, n_head, B, T, C)

    # manual implementation of attention
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
    att = att.masked_fill(bias[:,:,:T,:T] == 0, float('-inf'))
    att = F.softmax(att, dim=-1)
    y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

    y = _unpermute(y, B, T, C) # re-assemble all head outputs side by side

    return y
# }

class NewGELU(nn.Module):
    """Careful there are a few versions of GeLU, this one is the exact one used by OpenAI"""
    def forward(self, input):
        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # not really a 'bias', more of a mask, but following the OpenAI/HF naming though
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)

        y = _manual_attention(qkv, self.bias, self.n_embd, self.n_head, B, T, C)

        # output projection
        y = self.c_proj(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd)
        self.gelu    = NewGELU()
        self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50257
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768

class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = tok_emb + pos_emb

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    @classmethod
    def from_pretrained(cls, model_type):
        """Loads pretrained GPT-2 model weights from huggingface"""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# a few utilities for saving params/grads/activations to files for loading in C
def write_fp32(tensor, file):
    file.write(tensor.detach().numpy().astype("float32").tobytes())

def write_tensors(model_tensors, L, file):
    write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
    write_fp32(model_tensors["transformer.wpe.weight"], file) # (T, C)
    for i in range(L): # (L, C)
        write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
    for i in range(L): # (L, C)
        write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
    for i in range(L): # (L, 3C, C)
        write_fp32(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
    for i in range(L): # (L, 3C)
        write_fp32(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file)
    for i in range(L): # (L, C, C)
        write_fp32(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
    for i in range(L): # (L, C)
        write_fp32(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
    for i in range(L): # (L, C)
        write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
    for i in range(L): # (L, C)
        write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
    for i in range(L): # (L, 4C, C)
        write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
    for i in range(L): # (L, 4C)
        write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file)
    for i in range(L): # (L, C, 4C)
        write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
    for i in range(L): # (L, C)
        write_fp32(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)
    write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
    write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )

def write_model(model, filename):
    # everything we need to instantiate the model
    # 1) header is: version int, GPTConfig ints, padding to 1024 bytes
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240326 # magic
    header[1] = 1 # checkpoint version = 1
    header[2] = model.config.block_size
    header[3] = model.config.vocab_size
    header[4] = model.config.n_layer
    header[5] = model.config.n_head
    header[6] = model.config.n_embd
    # 2) the parameters on CPU are next
    params = {name: param.cpu() for name, param in model.named_parameters()}
    # now write
    with open(filename, "wb") as file:
        # header
        file.write(header.numpy().tobytes())
        # model parameters
        write_tensors(params, model.config.n_layer, file)
    print(f"wrote {filename}")

def write_state(model, x, y, logits, loss, filename):
    # the state is used for debugging.
    # it contains information about the input, logits, loss, and the parameter gradients
    # this can be used for checking the computation correctness in C
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240327 # magic
    header[1] = 1 # run state version = 1
    header[2] = x.size(0) # batch size of the batch, B
    header[3] = x.size(1) # temporal extent of the batch, T
    grads = {name: param.grad.cpu() for name, param in model.named_parameters()}
    with open(filename, "wb") as file:
        # header
        file.write(header.numpy().tobytes())
        # input x
        file.write(x.cpu().numpy().astype("int32").tobytes()) # (B, T)
        # targets y
        file.write(y.cpu().numpy().astype("int32").tobytes()) # (B, T)
        # logits (result of the model forward pass)
        write_fp32(logits.cpu(), file)
        # loss (single float, result of the cross entropy loss)
        write_fp32(loss.cpu(), file)
        # gradients
        write_tensors(grads, model.config.n_layer, file)
    print(f"wrote {filename}")

def write_tokenizer(enc, filename):
    n = enc.max_token_value + 1
    header = torch.zeros(256, dtype=torch.int32)
    header[0] = 20240328 # magic
    header[1] = 1 # tokenizer version = 1
    header[2] = n # number of tokens
    with open(filename, "wb") as file:
        file.write(header.numpy().tobytes())
        for i in range(n):
            b = enc.decode_bytes([i])
            length = len(b)
            assert length < 256, f"Token length exceeds 255: {length}"
            file.write(struct.pack("<B", length))  # Write the length as a 1-byte unsigned integer
            file.write(b)  # Write the actual bytes
    print(f"wrote {filename}")

The helper functions `_permute`, `_unpermute` and `_manual_attention` at the very top group some parts of the original code. Later on, we will use Thunder to map them to the CUDA kernels that [llm.c](https://github.com/karpathy/llm.c) provides!

The code below is a simplification of the `__main__` block from [llm.c/train_gpt2.py](./llm.c/train_gpt2.py) that we wrap into a callable for convenience.
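
To make the computation inside `_manual_attention` concrete, here is a minimal single-head sketch in plain Python, without batching or multiple heads. It is an illustration under stated assumptions rather than the notebook's implementation: `causal_attention` is a hypothetical name, the inputs are made-up toy vectors, and masked positions are simply skipped, which gives the same weights as filling them with `-inf` before the softmax.

```python
import math


def causal_attention(q, k, v):
    """Single-head causal attention on lists of T vectors of size hs."""
    T, hs = len(q), len(q[0])
    scale = 1.0 / math.sqrt(hs)
    out, weights = [], []
    for t in range(T):
        # Scores against positions 0..t only; later positions are masked out.
        scores = [scale * sum(a * b for a, b in zip(q[t], k[s])) for s in range(t + 1)]
        # Numerically stable softmax over the visible positions.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        w = [e / total for e in exps] + [0.0] * (T - t - 1)
        weights.append(w)
        # Weighted sum of value vectors.
        out.append([sum(w[s] * v[s][d] for s in range(T)) for d in range(hs)])
    return out, weights


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = causal_attention(q, k, v)
print(weights[0])  # [1.0, 0.0, 0.0]: the first token can only attend to itself
print(out[0])      # [1.0, 2.0]: so its output is exactly v[0]
```

The tensor version in the notebook does the same per head, with `_permute` moving the head dimension forward and `_unpermute` concatenating the head outputs again.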

In [3]:
# based on train_gpt2.py from https://github.com/karpathy/llm.c/tree/954077fb887d2770e4d537bafea056473d4bb4ce
import time
import argparse
import tiktoken
import thunder


def get_data_iter(B, T, device):
    # load the tokens
    # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
    # we're using val instead of train split just because it is smaller/faster
    shake_tokens_bin = "./llm.c/data/tiny_shakespeare_val.bin"
    story_tokens_bin = "./llm.c/data/TinyStories_val.bin"
    assert os.path.isfile(shake_tokens_bin) or os.path.isfile(story_tokens_bin), "you must run prepro on some dataset"
    tokens_bin = shake_tokens_bin if os.path.isfile(shake_tokens_bin) else story_tokens_bin
    assert os.path.isfile(tokens_bin)
    #print(f"loading cached tokens in {tokens_bin}")
    with open(tokens_bin, "rb") as f:
        tokens = np.frombuffer(f.read(), dtype=np.int32)

    # np -> tensor, long, on device
    tokens = torch.tensor(tokens)
    tokens = tokens.to(torch.long)
    tokens = tokens.to(device)

    # lightweight dataloader
    def get_batch():
        assert B*T+1 <= len(tokens), "not enough tokens"
        # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
        i = 0
        while True:
            x = tokens[i:i+B*T].view(B, T)
            y = tokens[i+1:i+B*T+1].view(B, T)
            yield x, y
            i += B*T
            if i + B*T + 1 >= len(tokens):
                i = 0 # in prod we'd want to randomize the start point a bit

    data_iter = iter(get_batch())
    return data_iter


def demo_model(
    *,
    write_tensors: int = 0,  # write tensors to disk
    inference_only: int = 0,  # only run inference
    compile: int = 0,  # torch.compile the model
    tensorcores: int = 0,  # use tensorcores
    num_iterations: int = 10,  # number of iterations to run
    batch_size: int = 4,  # batch size
    sequence_length: int = 64,  # sequence length
    thunder_jit: int = 1,  # thunder.jit the model
    thunder_executors = None,
):
    # default settings will overfit a tiny batch of data
    # and save model weights and debug state to disk on the first iteration
    # if you'd like to e.g. time the forward pass only, call this function with:
    # inference_only=1, write_tensors=0, sequence_length=1024

    B, T = batch_size, sequence_length
    assert 1 <= T <= 1024

    # we use CUDA globally
    device = "cuda"
    
    # seed the random number generators
    torch.cuda.manual_seed(42)

    # init the tokenizer
    enc = tiktoken.get_encoding("gpt2")
    encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
    decode = lambda l: enc.decode(l)
    write_tokenizer(enc, "./llm.c/gpt2_tokenizer.bin")

    if tensorcores:
        torch.set_float32_matmul_precision('high')

    # load the GPT-2 model weights
    model = GPT.from_pretrained("gpt2")
    model.train()
    model.to(device)
    if compile:
        print("compiling the model with torch.compile...")
        model = torch.compile(model)
    if thunder_jit:
        print("compiling the model with thunder.jit...")
        if thunder_executors is None:
            thunder_executors = ()
        model = thunder.jit(model, executors=tuple(thunder_executors) + thunder.get_default_executors())

    # forward backward for a few iterations
    data_iter = get_data_iter(B, T, device)
    x, y = next(data_iter) # we'll overfit this batch below
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    timings = []
    # Warm-up runs!
    for _ in range(5):
        model(x, y)
        
    for i in range(num_iterations):
        # Now measure the runtime!
        t0 = time.time()
        logits, loss = model(x, y)
        if not inference_only:
            optimizer.zero_grad()
            loss.backward()
            if i == 0:
                #print(thunder.last_backward_traces(model)[-1])
                pass
            print(f"{loss=}")
            # TODO: investigate missing keys
            # on the first iteration only, save the state dict to file for later reference
            #if i == 0 and args.write_tensors:
            #    write_model(model, "gpt2_124M.bin")
            #    write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
            optimizer.step()
        # wait for queued GPU work so the timing is meaningful in inference-only runs too
        torch.cuda.synchronize()
        t1 = time.time()
        if i > num_iterations - 20:
            timings.append(t1-t0)
        print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")
    if len(timings) > 0:
        print(f"final 20 iters avg: {np.mean(timings)*1000:.3f}ms")

    # before we end, let's also do one round of inference
    # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
    start = "<|endoftext|>"
    start_ids = encode(start)
    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])

    # run generation for 16 time steps (tokens)
    max_new_tokens = 16
    temperature = 1.0
    top_k = 40
    model.eval()
    y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
    print(decode(y[0].tolist()))
    print('---------------')
    return model

Let's see whether our modifications actually work.
We first try to reproduce the output of `python train_gpt2.py`.

In [4]:
default_model = demo_model(thunder_jit=0)
del default_model

wrote ./llm.c/gpt2_tokenizer.bin


loading weights from pretrained gpt: gpt2
loss=tensor(5.2700, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 0, loss: 5.270008563995361, time: 204.454ms
loss=tensor(4.0597, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 1, loss: 4.059720993041992, time: 35.370ms
loss=tensor(3.3752, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 2, loss: 3.3751838207244873, time: 36.551ms
loss=tensor(2.8008, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 3, loss: 2.800813913345337, time: 37.395ms
loss=tensor(2.3154, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 4, loss: 2.315413475036621, time: 38.213ms
loss=tensor(1.8490, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 5, loss: 1.8490413427352905, time: 38.244ms
loss=tensor(1.3946, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 6, loss: 1.3946460485458374, time: 44.990ms
loss=tensor(0.9992, device='cuda:0', grad_fn=<NllLossBackward0>)
iteration 7, loss: 0.9992104768753052, time: 50.108ms
loss=tens

Great! The numbers match. Let's see whether we can run the model directly through Thunder.

In [5]:
thunder_model = demo_model(thunder_jit=1)

wrote ./llm.c/gpt2_tokenizer.bin
loading weights from pretrained gpt: gpt2
compiling the model with thunder.jit...
loss=tensor(5.2700, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 0, loss: 5.270007610321045, time: 8706.774ms
loss=tensor(4.0597, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 1, loss: 4.059719562530518, time: 54.336ms
loss=tensor(3.3752, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 2, loss: 3.375183582305908, time: 58.817ms
loss=tensor(2.8008, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 3, loss: 2.8008131980895996, time: 58.689ms
loss=tensor(2.3154, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 4, loss: 2.315413475036621, time: 66.542ms
loss=tensor(1.8490, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 5, loss: 1.8490394353866577, time: 65.207ms
loss=tensor(1.3946, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 6, loss: 1.3946452140808105, time: 65.406ms
loss=
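
The losses printed by the Thunder run agree with the eager run only up to floating-point rounding, not bit for bit. A quick way to check this is sketched below, using the loss values for iterations 0 to 6 copied from the two logs above. The tolerance `REL_TOL` is an assumption chosen for this illustration, not a threshold defined by Thunder.

```python
import math

# Losses for iterations 0-6, copied from the eager and Thunder logs above.
eager_losses = [
    5.270008563995361, 4.059720993041992, 3.3751838207244873, 2.800813913345337,
    2.315413475036621, 1.8490413427352905, 1.3946460485458374,
]
thunder_losses = [
    5.270007610321045, 4.059719562530518, 3.375183582305908, 2.8008131980895996,
    2.315413475036621, 1.8490394353866577, 1.3946452140808105,
]

REL_TOL = 1e-5  # assumed tolerance for float32 training losses

# Largest relative deviation between the two runs.
max_rel = max(abs(a - b) / abs(a) for a, b in zip(eager_losses, thunder_losses))
all_close = all(
    math.isclose(a, b, rel_tol=REL_TOL) for a, b in zip(eager_losses, thunder_losses)
)
print(f"max relative difference: {max_rel:.2e}, all close: {all_close}")
```

Small differences of this kind are expected when operations are fused or reordered, since floating-point addition is not associative.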

We can inspect the forward trace:

In [6]:
forward_trace = thunder.last_traces(thunder_model)[-1]
print(forward_trace)

# Constructed by Delete Last Used (took 8 milliseconds)
import torch
from torch import Tensor
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def augmented_forward_fn(idx, targets, t_transformer_h_0_ln_1_bias, t_transformer_h_0_attn_c_attn_bias, t_transformer_h_0_attn_c_proj_bias, t_transformer_h_0_ln_2_bias, t_transformer_h_0_mlp_c_fc_bias, t_transformer_h_0_mlp_c_proj_bias, t_transformer_h_1_ln_1_bias, t_transformer_h_1_attn_c_attn_bias, t_transformer_h_1_attn_c_proj_bias, t_transformer_h_1_ln_2_bias, t_transformer_h_1_mlp_c_fc_bias, t_transformer_h_1_mlp_c_proj_bias, t_transformer_h_2_ln_1_bias, t_transformer_h_2_attn_c_attn_bias, t_transformer_h_2_attn_c_proj_bias, t_transformer_h_2_ln_2_bias, t_transformer_h_2_mlp_c_fc_bias, t_transformer_h_2_mlp_c_proj_bias, t_transformer_h_3_ln_1_bias, t_transformer_h_3_attn_c_attn_bias, t_transformer_h_3_attn_c_proj_bias, t_transformer_h_3_ln_2_bias, t_transformer_h_3_mlp_c_fc_bias, 

And the backward trace:

In [7]:
backward_trace = thunder.last_backward_traces(thunder_model)[-1]
print(backward_trace)
del thunder_model

# Constructed by Delete Last Used (took 15 milliseconds)
import torch
from torch import Tensor
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast()
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  clear_collection(saved_for_backward)
  del saved_for_backward
  t945, t946, = cotangents
  clear_collection(cotangents)
  del cotangents
  idx, t0, t10, t103, t107, t108, t113, t123, t127, t129, t133, t137, t139, t14, \
  t142, t145, t146, t154, t156, t16, t160, t164, t166, t169, t172, t178, t182, \
  t183, t188, t19, t198, t202, t204, t208, t212, t214, t217, t22, t220, t221, \
  t229, t231, t235, t239, t241, t244, t247, t253, t257, t258, t263, t273, t277, \
  t279, t28, t283, t287, t289, t292, t295, t296, t304, t306, t310, t314, t316, \
  t319, t32, t322, t328, t33, t332, t333, t338, t348, t352, t354, t358, t362, \
  t364, t367, t370, t371, t379, t38, t381, t385, t

## Compiling CUDA kernels with [CUDA-Python](https://github.com/NVIDIA/cuda-python)
We are using [CUDA-Python](https://github.com/NVIDIA/cuda-python) to compile CUDA kernels and provide bindings for Python.
We already have an excellent resource, [Extend Thunder with CUDA-Python](./extend_thunder_with_cuda_python.ipynb), that explains the topic in greater detail. For now, we will reuse some of its helper functions that allow us to compile and run CUDA kernels.

In [8]:
import numpy  # launch_kernel below refers to the module by its full name
from cuda import cuda, nvrtc


def check_error(results):
    err, *results = results
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"CUDA error: {cuda.cuGetErrorString(err)}")
    elif isinstance(err, nvrtc.nvrtcResult):
        if err != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            raise RuntimeError(f"NVRTC error: {nvrtc.nvrtcGetErrorString(err)}")
    else:
        raise TypeError(f"Unknown error type: {err}")
    if len(results) == 0:
        return
    if len(results) == 1:
        return results[0]
    return results


def compile_program_and_get_module(cuda_src, program_name):
    """
    Compiles the CUDA source code provided in the string `cuda_src` under the name `program_name`
    and returns the resulting PTX loaded as a CUDA module.

    The module can then be used to retrieve CUDA kernels, which in turn can be run with `launch_kernel`.
    """

    torch.cuda.current_stream()  # this initializes the device context for us. we don't need the stream specifically.
    
    # Create program
    prog = check_error(nvrtc.nvrtcCreateProgram(str.encode(cuda_src), (program_name + '.cu').encode(), 0, [], []))    
    
    # Compile program
    major, minor = torch.cuda.get_device_capability()  # returns (major, minor)
    opts = [
        f"--gpu-architecture=compute_{major}{minor}".encode(),
        b"--include-path=/usr/local/cuda/include/",
        b"--include-path=/usr/include/",
        b"--use_fast_math",
        b"--dopt=on",
    ] #, b"--expt-relaxed-constexpr"]
    check_error(nvrtc.nvrtcCompileProgram(prog, len(opts), opts))
    
    ## Get PTX from compilation
    ptxSize = check_error(nvrtc.nvrtcGetPTXSize(prog))
    ptx = b" " * ptxSize
    check_error(nvrtc.nvrtcGetPTX(prog, ptx))
    
    logSize = check_error(nvrtc.nvrtcGetProgramLogSize(prog))
    log = b" " * logSize
    check_error(nvrtc.nvrtcGetProgramLog(prog, log))
    print(log.decode())
    
    
    # Load PTX as module data and retrieve function
    module = check_error(cuda.cuModuleLoadData(ptx))
    return module


def launch_kernel(kernel, grid, block, /, *args, shmem=0):
    """Utility function to launch kernels.
    Args can be tensors (corresponding to `float*` etc. kernel params) or numpy scalars (which carry precision info).
    """

    # collect values (data_ptr as uint64 array for tensors, the values as an array for values)
    addresses = []
    wrapped_args = []
    for a in args:
        if isinstance(a, torch.Tensor):
            # for tensor pass in data_ptr
            wrapped_args.append(numpy.array(a.data_ptr(), dtype=numpy.uint64))
        elif isinstance(a, numpy.number):
            wrapped_args.append(numpy.array([a]))
        else:
            raise TypeError("please only pass tensors and numpy numbers to launch_kernel")

    # assemble an array of pointers to the args
    args = numpy.array([a.ctypes.data for a in wrapped_args], dtype=numpy.uint64)

    # set up grid / block layout to be 3d
    grid = tuple(grid)
    block = tuple(block)
    assert 1 <= len(block) <= 3 and 1 <= len(grid) <= 3
    grid = grid + (3 - len(grid)) * (1,)
    block = block + (3 - len(block)) * (1,)

    # Launch!
    err, = cuda.cuLaunchKernel(
       kernel,
       *grid, *block, # xyz each
       shmem,  # dynamic shared memory
       torch.cuda.current_stream().cuda_stream,  # stream (raw CUDA stream handle, not PyTorch's stream_id)
       args.ctypes.data,  # kernel arguments
       0,  # extra (ignore)
    )
    if err != cuda.CUresult.CUDA_SUCCESS:
        raise RuntimeError(f"CUDA error: {err}")
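
The argument packing in `launch_kernel` follows the `void**` convention of `cuLaunchKernel`: every argument lives in its own buffer, and the driver receives an array holding the addresses of those buffers. The sketch below reproduces that layout with `ctypes` only, so it runs without a GPU. The helper name `pack_args` and the pointer value are made up for illustration and are not part of the notebook.

```python
import ctypes


def pack_args(values):
    """Build a void**-style array: one slot per argument, each slot holding
    the address of the buffer that stores that argument's value."""
    # The array stores raw addresses and does not own the buffers, so the
    # caller must keep `values` alive for as long as the array is in use.
    return (ctypes.c_void_p * len(values))(
        *[ctypes.addressof(v) for v in values]
    )


# Hypothetical arguments: a device pointer (as a 64-bit integer) and an int32 size.
fake_device_ptr = ctypes.c_uint64(0x7F0000001000)
n = ctypes.c_int32(1024)
values = [fake_device_ptr, n]

slots = pack_args(values)

# Dereferencing each slot recovers the original argument at its own width.
ptr_back = ctypes.c_uint64.from_address(slots[0]).value
n_back = ctypes.c_int32.from_address(slots[1]).value
print(hex(ptr_back), n_back)
```

This is also why `launch_kernel` insists on numpy scalars rather than plain Python numbers: the scalar's dtype fixes the byte width that the kernel reads from the buffer.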

## [llm.c](https://github.com/karpathy/llm.c): CUDA kernels
Below we list all the CUDA kernels from [llm.c/train_gpt2.py](./llm.c/train_gpt2.py). Note that some of these kernels are defined with `extern "C"`. This is needed to avoid name mangling when accessing the kernels with [CUDA-Python](https://github.com/NVIDIA/cuda-python).

In [9]:
all_kernels = r"""
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

#define NEG_INFINITY __int_as_float(0xff800000)
#define FLT_MAX 3.402823466e+38F
#define M_PI 3.14159265358979323846


// convenience macro for calculating grid/block dimensions for kernels
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

// ----------------------------------------------------------------------------
// all the kernels

// warp-level reduction for finding the maximum value
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

// warp-level reduction for summing values
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}

__global__ void encoder_forward_kernel2(float* out,
                               int* inp, float* wte, float* wpe,
                               int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int N = B * T * C;

    if (idx < N) {
        int bt = idx / C;
        int b = bt / T;
        int t = bt % T;
        int c = idx % C;

        int ix = inp[b * T + t];

        float* out_btc = out + b * T * C + t * C + c;
        float* wte_ix = wte + ix * C + c;
        float* wpe_tc = wpe + t * C + c;
        *out_btc = *wte_ix + *wpe_tc;
    }
}


extern "C"
__global__ void layernorm_forward_kernel3(float* __restrict__ out, float* __restrict__ mean, float* __restrict__ rstd,
                                    const float*  __restrict__ inp, const float*  __restrict__ weight,
                                    const float* __restrict__ bias, int N, int C) {
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N) {
        return;
    }

    // the row of input that this group of threads is responsible for
    const float* x = inp + idx * C;

    // mean
    float sum = 0.0f;
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        sum += x[i];
    }
    sum = cg::reduce(warp, sum, cg::plus<float>{});
    float m = sum / C;
    if(warp.thread_rank() == 0 && mean != nullptr) {
        __stcs(mean + idx, m);
    }

    // rstd
    sum = 0.0f;
    for (int i = warp.thread_rank(); i < C; i += warp.size()) {
        float diff = x[i] - m;
        sum += diff * diff;
    }
    sum = cg::reduce(warp, sum, cg::plus<float>{});
    float s = rsqrtf(sum / C + 1e-5f);
    if(warp.thread_rank() == 0 && rstd != nullptr) {
        __stcs(rstd + idx, s);
    }

    // final normalization and scaling by weight/bias
    float* o = out + idx * C;
    for (int c = warp.thread_rank(); c < C; c += warp.size()) {
        // load and store using the .cs "streaming" hint to the compiler,
        // indicating that this data will not be reused soon, and can be streamed through the caches
        // this allows the threads to get more cache-hits for the (shared) weight and bias parameters
        float n = s * (__ldcs(x+c) - m);
        __stcs(o+c, n * weight[c] + bias[c]);
    }
}

__global__ void add_bias(float* out, float* bias, int B, int T, int OC) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = blockDim.x * gridDim.x;
    for (int i = idx; i < B*T*OC; i += stride) {
        int col = i % OC;
        out[i] += bias[col];
    }
}


extern "C"
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int N, int NH, int d) {
    // okay so now, this kernel wants Q,K,V to all be of shape (B, NH, N, d)
    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, NH, d)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // Q[b][nh_][n][d_] = inp[b][n][0][nh_][d_]

    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx = \
            (b * N * 3 * NH * d)
            +   (n * 3 * NH * d)
            +       (0 * NH * d)
            +          (nh_ * d)
            +                d_;

        q[idx] = __ldcs(&inp[inp_idx]);
        k[idx] = __ldcs(&inp[inp_idx + NH * d]);
        v[idx] = __ldcs(&inp[inp_idx + 2 * (NH * d)]);
    }
}

extern "C"
__global__ void permute_kernel_backward(float* dinp,
                                        const float* dq, const float* dk, const float* dv,
                                        int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int inp_idx = (b * N * 3 * NH * d) + (n * 3 * NH * d) + (0 * NH * d) + (nh_ * d) + d_;
        dinp[inp_idx] += dq[idx];
        dinp[inp_idx + NH * d] += dk[idx];
        dinp[inp_idx + 2 * (NH * d)] += dv[idx];
    }
}

extern "C"
__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {
   // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
    int idx = blockIdx.x * blockDim.x + threadIdx.x;

    // out[b][n][nh_][d_] <- inp[b][nh_][n][d_]
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        out[other_idx] = __ldcs(&inp[idx]);
    }
}

extern "C"
__global__ void unpermute_kernel_backward(float* dinp, const float *dout, int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < B * NH * N * d) {
        int b = idx / (NH * N * d);
        int rest = idx % (NH * N * d);
        int nh_ = rest / (N * d);
        rest = rest % (N * d);
        int n = rest / d;
        int d_ = rest % d;

        int other_idx = (b * NH * N * d) + (n * NH * d) + (nh_ * d) + d_;
        dinp[idx] += dout[other_idx];
    }
}

__device__ float& vec_at(float4& vec, int index) {
    return reinterpret_cast<float*>(&vec)[index];
}

__device__ float vec_at(const float4& vec, int index) {
    return reinterpret_cast<const float*>(&vec)[index];
}

__global__ void softmax_forward_kernel5(float* out, float inv_temperature, const float* inp, int N, int T) {
    // inp, out shape: (N, T, T), where N = B * NH
    // fuses the multiplication by scale inside attention
    // directly autoregressive, so we only compute the lower triangular part
    // uses the online softmax algorithm
    assert(T % 4  == 0);
    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    if(idx >= N * T) {
        return;
    }
    int own_pos = idx % T;
    int pos_by_4 = own_pos / 4;

    // one row of inp, i.e. inp[idx, :] of shape (T,)
    const float* x = inp + idx * T;

    // not INF, so we don't get NaNs accidentally when subtracting two values.
    float maxval = -FLT_MAX;
    float sumval = 0.0f;

    const float4* x_vec = reinterpret_cast<const float4*>(x);
    for (int i = warp.thread_rank(); i < pos_by_4; i += warp.size()) {
        float4 v = x_vec[i];
        float old_maxval = maxval;
        for(int k = 0; k < 4; ++k) {
            maxval = fmaxf(maxval, vec_at(v, k));
        }
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        for(int k = 0; k < 4; ++k) {
            sumval += expf(inv_temperature * (vec_at(v, k) - maxval));
        }
    }

    if(4*pos_by_4 + warp.thread_rank() <= own_pos) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, x[4*pos_by_4 + warp.thread_rank()]);
        sumval *= expf(inv_temperature * (old_maxval - maxval));
        sumval += expf(inv_temperature * (x[4*pos_by_4 + warp.thread_rank()] - maxval));
    }

    float global_maxval = cg::reduce(warp, maxval, cg::greater<float>{});
    sumval *= expf(inv_temperature * (maxval - global_maxval));

    float sum = cg::reduce(warp, sumval, cg::plus<float>{});
    float norm = 1.f / sum;

    // divide the whole row by the sum
    for (int i = warp.thread_rank(); i <= own_pos; i += warp.size()) {
        // recalculation is faster than doing the round-trip through memory.
        float ev = expf(inv_temperature * (__ldcs(x + i) - global_maxval));
        __stcs(out + idx * T + i, ev * norm);
    }
}

__global__ void residual_forward_kernel(float* out, float* inp1, float* inp2, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        out[idx] = __ldcs(&inp1[idx]) + __ldcs(&inp2[idx]);
    }
}

#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

extern "C"
__global__ void gelu_forward_kernel(float* out, const float* inp, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        float xi = inp[i];
        float cube = 0.044715f * xi * xi * xi;
        out[i] = 0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube)));
    }
}

extern "C"
__global__ void gelu_backward_kernel(float* dinp, const float* inp, const float* dout, const int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        dinp[i] += local_grad * dout[i];
    }
}

extern "C"
__global__ void crossentropy_forward_kernel1(float* losses,
                            float* probs, int* targets,
                            int B, int T, int V) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < B * T) {
        int b = i / T;
        int t = i % T;
        float* probs_bt = probs + b * T * V + t * V;
        int ix = targets[b * T + t];
        losses[b * T + t] = -logf(probs_bt[ix]);
    }
}

extern "C"
__global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int C) {
    // out is (N, C) just like inp. Each row of inp will get softmaxed.
    // same as kernel4, but optimised for very large Cs with advanced unrolling

    // The trick is to read into a register array (all indices known at compile time)
    // and always read UNROLL_FACTOR values to maximise memory level parallelism
    // even if we would be out of bounds, we set the index to min(C-1, idx)
    // so we just do some unnecessary reads (obviously bad for small C)
    // the writes are in a separate loop with a conditional check for out of bounds
    // making it separate is necessary to convince the compiler to do the right thing
    const int UNROLL_FACTOR = 8;
    const int warpsPerBlock = blockDim.x / 32;

    extern __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    int warpId = threadIdx.x / 32; // warp index within a block
    int laneId = threadIdx.x % 32; // thread index within a warp

    // shared[] must be allocated to have 2 * warpsPerBlock elements
    // first half for max values, the second half for sum values
    float* maxvals = shared;
    float* sumvals = &shared[warpsPerBlock];

    if (tid >= C) {
        maxvals[warpId] = NEG_INFINITY;
        sumvals[warpId] = 0.0f;
        return;
    }

    const float* x = inp + idx * C; // input
    float* y = out + idx * C; // output

    // first, thread coarsening by directly accessing global memory in series
    float maxval = NEG_INFINITY;
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            maxval = fmaxf(maxval, x[min(C - 1, i + u*blockDim.x)]);
        }
    }

    // now within-warp reductions for maxval
    maxval = warpReduceMax(maxval);
    // the 0th thread of each warp writes the maxval of that warp to shared memory
    if (laneId == 0) maxvals[warpId] = maxval;
    __syncthreads();
    // now the 0th thread reduces the maxvals in shared memory, i.e. across warps
    if (tid == 0) {
        float val = maxvals[tid];
        #pragma unroll
        for (int i = 1; i < warpsPerBlock; i++) {
            val = fmaxf(val, maxvals[i]);
        }
        // store the final max in the first position
        maxvals[0] = val;
    }
    __syncthreads();
    // broadcast the max to all threads
    float offset = maxvals[0];

    // compute expf and write the result to global memory
    // + thread coarsening for sum
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        float reg_array[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            reg_array[u] = __ldcs(&x[min(C - 1, i + u*blockDim.x)]);
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            if (i + u*blockDim.x < C) {
                float output = expf(reg_array[u] - offset);
                y[min(C - 1, i + u*blockDim.x)] = output; // compiler likes redundant min()?!
                sumval += output; // combined into the same loop unlike kernel3
            }
        }
    }

    // okay now we calculated exp(x - max(x))
    // step 2: sum all the values and divide by the sum

    // within-warp reduction for sumval
    sumval = warpReduceSum(sumval);
    // write sumval to shared memory
    if (laneId == 0) sumvals[warpId] = sumval;
    __syncthreads();
    // inter-thread reduction of sum
    if (tid == 0) {
        float val = sumvals[tid];
        #pragma unroll
        for (int i = 1; i < warpsPerBlock; ++i) {
            val += sumvals[i];
        }
        sumvals[0] = val;
    }
    __syncthreads();
    // broadcast the sum to all threads
    float sum = sumvals[0];

    // divide the whole row by the sum
    for (int i = tid; i < C; i += blockDim.x * UNROLL_FACTOR) {
        float reg_array[UNROLL_FACTOR];
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            reg_array[u] = y[min(C - 1, i + u*blockDim.x)];
        }
        #pragma unroll
        for (int u = 0; u < UNROLL_FACTOR; u++) {
            if (i + u*blockDim.x < C) {
                y[i + u*blockDim.x] = reg_array[u] / sum;
            }
        }
    }
}

extern "C"
__global__ void crossentropy_softmax_backward_kernel1(float* dlogits,
                           const float* dlosses, const float* probs, const int* targets,
                           int B, int T, int V) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < B * T * V) {
        int b = i / (T * V);
        int t = (i / V) % T;
        int v = i % V;
        float* dlogits_bt = dlogits + b * T * V + t * V;
        const float* probs_bt = probs + b * T * V + t * V;
        float dloss = dlosses[b * T + t];
        int ix = targets[b * T + t];
        float p = probs_bt[v];
        float indicator = v == ix ? 1.0f : 0.0f;
        dlogits_bt[v] += (p - indicator) * dloss;
    }
}

__global__ void matmul_backward_bias_kernel_faster(float* dbias, const float* dout, int B, int T, int OC) {
    extern __shared__ float shared[];
    int o = blockIdx.x; // range [0, OC)
    int tid = threadIdx.x; // range [0, block_size)
    int block_size = blockDim.x;
    const float* x = dout + o;
    // thread coarsening
    double sum = 0.0f;
    for (int i = tid; i < B * T; i += block_size) {
        sum += x[i * OC];
    }
    shared[tid] = (float) sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        dbias[o] = shared[0];
    }
}

// super naive layernorm backward kernel that just parallelizes over B,T and loops over C
extern "C"
__global__ void layernorm_backward_kernel1(float* dinp, float* dweight, float* dbias,
                        float* dout, float* inp, float* weight, float* mean, float* rstd,
                        int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B*T) return;
    int b = idx / T;
    int t = idx % T;

    float* dout_bt = dout + b * T * C + t * C;
    float* inp_bt = inp + b * T * C + t * C;
    float* dinp_bt = dinp + b * T * C + t * C;
    float mean_bt = mean[b * T + t];
    float rstd_bt = rstd[b * T + t];

    // first: two reduce operations
    float dnorm_mean = 0.0f;
    float dnorm_norm_mean = 0.0f;
    for (int i = 0; i < C; i++) {
        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
        float dnorm_i = weight[i] * dout_bt[i];
        dnorm_mean += dnorm_i;
        dnorm_norm_mean += dnorm_i * norm_bti;
    }
    dnorm_mean = dnorm_mean / C;
    dnorm_norm_mean = dnorm_norm_mean / C;

    // now iterate again and accumulate all the gradients
    for (int i = 0; i < C; i++) {
        float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
        float dnorm_i = weight[i] * dout_bt[i];
        // gradient contribution to bias
        atomicAdd(&dbias[i], dout_bt[i]);
        // gradient contribution to weight
        atomicAdd(&dweight[i], norm_bti * dout_bt[i]);
        // gradient contribution to input
        float dval = 0.0f;
        dval += dnorm_i; // term 1
        dval -= dnorm_mean; // term 2
        dval -= norm_bti * dnorm_norm_mean; // term 3
        dval *= rstd_bt; // final scale
        dinp_bt[i] += dval;
    }
}

__global__ void setConstant(float* vec, float constant, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N) {
        vec[idx] = constant;
    }
}
"""

Let us compile the source code in `all_kernels` using the helper function defined above.

In [10]:
llmc_cuda_module = compile_program_and_get_module(all_kernels, "llmc_cuda_module")

      extern __shared__ float shared[];
                              ^


 


Having compiled the code, we need to be able to access specific kernels by their names.

In [11]:
import functools

@functools.cache
def extract_cuda_kernel(function_name):
    return check_error(cuda.cuModuleGetFunction(llmc_cuda_module, function_name.encode()))

Note that we decorated `extract_cuda_kernel` with `functools.cache` for potentially faster kernel look-ups.
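
As a minimal illustration of what the decorator buys us, the sketch below caches a stand-in loader and counts how often the underlying lookup actually runs. `load_kernel` is a hypothetical placeholder for an expensive call such as `cuModuleGetFunction`; it is not part of the notebook.

```python
import functools

lookups = []  # records every call that actually reaches the loader


@functools.cache
def load_kernel(name):
    """Hypothetical stand-in for an expensive kernel lookup."""
    lookups.append(name)
    return f"<kernel {name}>"


first = load_kernel("gelu_forward_kernel")
second = load_kernel("gelu_forward_kernel")  # served from the cache
other = load_kernel("permute_kernel")

print(lookups)
print(load_kernel.cache_info())
```

Note that the cache is keyed on the kernel name only. If `llmc_cuda_module` is recompiled, the cached handles still point into the old module, so it is worth calling `extract_cuda_kernel.cache_clear()` after recompiling.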

## High-level Python wrapper for launching CUDA kernels

Here we present high-level Python wrappers that accept PyTorch tensors and execute the compiled CUDA kernels.

In [12]:
import numpy


def map_arg_tensors(map_fn):
    def arg_mapper(f):
        def wrapper(*args):
            def map_tensor_fn(inp):
                if isinstance(inp, torch.Tensor):
                    return map_fn(inp)
                else:
                    return inp

            new_args = map(map_tensor_fn, args)
            return f(*new_args)
        return wrapper
    return arg_mapper


def force_contiguous_inputs(f):
    return map_arg_tensors(lambda t: t.contiguous())(f)


def index_tensors_to_int32(f):
    def map_fn(t):
        if t.dtype.is_floating_point:
            return t
        else:
            return t.to(torch.int32)

    return map_arg_tensors(map_fn)(f)


@force_contiguous_inputs
def permute_split(input):
    assert input.ndim == 5
    assert input.shape[2] == 3

    B, N, _, NH, d = input.shape

    permute_kernel = extract_cuda_kernel("permute_kernel")

    q, k, v = torch.empty(3, B, NH, N, d, device=input.device, dtype=input.dtype).unbind(dim=0)

    block_size = 256
    grid_size = (input.numel() // 3 + block_size - 1) // block_size
    launch_kernel(
        permute_kernel,
        (grid_size,),
        (block_size,),
        q, k, v,
        input,
        numpy.int32(B),
        numpy.int32(N),
        numpy.int32(NH),
        numpy.int32(d),
    )

    return q, k, v


@force_contiguous_inputs
def permute_split_backward(q_grad, k_grad, v_grad):
    assert q_grad.shape == k_grad.shape == v_grad.shape

    permute_kernel_backward = extract_cuda_kernel("permute_kernel_backward")

    B, NH, N, d = q_grad.shape

    grad = torch.zeros(B, N, 3, NH, d, dtype=q_grad.dtype, device=q_grad.device)

    block_size = 256
    grid_size = (q_grad.numel() + block_size - 1) // block_size
    launch_kernel(
        permute_kernel_backward,
        (grid_size,),
        (block_size,),
        grad,
        q_grad, k_grad, v_grad,
        numpy.int32(B),
        numpy.int32(N),
        numpy.int32(NH),
        numpy.int32(d),
    )

    return grad


@force_contiguous_inputs
def unpermute_forward(input):
    assert input.ndim == 4

    unpermute_kernel = extract_cuda_kernel("unpermute_kernel")

    B, NH, N, d = input.shape

    output = torch.empty(B, N, NH, d, dtype=input.dtype, device=input.device)

    block_size = 256
    grid_size = (input.numel() + block_size - 1) // block_size
    launch_kernel(
        unpermute_kernel,
        (grid_size,),
        (block_size,),
        input,
        output,
        numpy.int32(B),
        numpy.int32(N),
        numpy.int32(NH),
        numpy.int32(d),
    )

    return output


@force_contiguous_inputs
def unpermute_backward(grad):
    assert grad.ndim == 4

    unpermute_backward_kernel = extract_cuda_kernel("unpermute_kernel_backward")

    B, N, NH, d = grad.shape

    res = torch.zeros(B, NH, N, d, dtype=grad.dtype, device=grad.device)

    block_size = 256
    grid_size = (grad.numel() + block_size - 1) // block_size
    launch_kernel(
        unpermute_backward_kernel,
        (grid_size,),
        (block_size,),
        res,
        grad,
        numpy.int32(B),
        numpy.int32(N),
        numpy.int32(NH),
        numpy.int32(d),
    )

    return res


@force_contiguous_inputs
def softmax_forward(input):
    softmax_kernel = extract_cuda_kernel("softmax_forward_kernel7")

    output = input.new_empty(*input.shape)

    if input.ndim == 0:
        C = 1
    else:
        C = input.shape[-1]
    N = input.numel() // C

    block_size = 512
    grid_size = N
    # integer division keeps the byte count an int, as the launch API expects
    shared_mem_size = 2 * (block_size // 32) * input.dtype.itemsize
    launch_kernel(
        softmax_kernel,
        (grid_size,),
        (block_size,),
        output,
        input,
        numpy.int32(N),
        numpy.int32(C),
        shmem=shared_mem_size,
    )

    return output


@force_contiguous_inputs
@index_tensors_to_int32
def crossentropy_forward(probs, targets):
    assert probs.ndim - targets.ndim == 1

    if probs.ndim <= 1:
        probs = probs.unsqueeze(0)
        targets = targets.unsqueeze(0)

    *BT, V = probs.shape
    assert tuple(targets.shape) == tuple(BT)

    if len(BT) == 2:
        B, T = BT
    else:
        B, T = 1, BT[0]

    crossentropy_kernel = extract_cuda_kernel("crossentropy_forward_kernel1")

    losses = probs.new_empty(*targets.shape)

    block_size = 128
    grid_size = (probs.numel() // V + block_size - 1) // block_size

    launch_kernel(
        crossentropy_kernel,
        (grid_size,),
        (block_size,),
        losses,
        probs,
        targets,
        numpy.int32(B),
        numpy.int32(T),
        numpy.int32(V),
    )

    return losses


@force_contiguous_inputs
def crossentropy_softmax_forward(scores, targets):
    probs = softmax_forward(scores)
    return crossentropy_forward(probs, targets)


@force_contiguous_inputs
@index_tensors_to_int32
def crossentropy_softmax_backward(glogits, glosses, probs, targets):
    crossentropy_softmax_backward_kernel = extract_cuda_kernel("crossentropy_softmax_backward_kernel1")

    *BT, V = probs.shape
    if len(BT) == 2:
        B, T = BT
    else:
        B, T = 1, BT[0]

    block_size = 256
    grid_size = (probs.numel() + block_size - 1) // block_size

    launch_kernel(
        crossentropy_softmax_backward_kernel,
        (grid_size,),
        (block_size,),
        glogits,
        glosses,
        probs,
        targets,
        numpy.int32(B),
        numpy.int32(T),
        numpy.int32(V),
    )

    return glogits


@force_contiguous_inputs
def gelu_forward(input):
    gelu_kernel = extract_cuda_kernel("gelu_forward_kernel")

    output = input.new_empty(*input.shape)

    block_size = 128
    grid_size = (input.numel() + block_size - 1) // block_size
    launch_kernel(gelu_kernel, (grid_size,), (block_size,), output, input, numpy.int32(input.numel()))

    return output


@force_contiguous_inputs
def gelu_backward(input, grad):
    gelu_backward_kernel = extract_cuda_kernel("gelu_backward_kernel")

    input_grad = torch.zeros(*input.shape, dtype=input.dtype, device=input.device)

    block_size = 128
    grid_size = (input.numel() + block_size - 1) // block_size
    launch_kernel(gelu_backward_kernel, (grid_size,), (block_size,), input_grad, input, grad, numpy.int32(input.numel()))

    return input_grad


@force_contiguous_inputs
def layernorm_forward(input, weight, bias):
    layernorm_forward_kernel = extract_cuda_kernel("layernorm_forward_kernel3")

    B, T, C = input.shape
    N = B * T

    out = input.new_empty(*input.shape)
    mean = input.new_empty(B, T)
    rstd = input.new_empty(B, T)

    block_size = 512
    grid_size = (N * 32 + block_size - 1) // block_size
    launch_kernel(
        layernorm_forward_kernel,
        (grid_size,),
        (block_size,),
        out, mean, rstd,
        input, weight, bias,
        numpy.int32(N), numpy.int32(C),
    )

    return out, mean, rstd


@force_contiguous_inputs
def layernorm_backward(grad, input, weight, bias, mean, rstd):
    layernorm_backward_kernel = extract_cuda_kernel("layernorm_backward_kernel1")

    B, T, C = input.shape
    N = B * T

    input_grad = torch.zeros(*input.shape, dtype=input.dtype, device=input.device)
    weight_grad = torch.zeros(*weight.shape, dtype=weight.dtype, device=weight.device)
    bias_grad = torch.zeros(*bias.shape, dtype=bias.dtype, device=bias.device)

    block_size = 64
    grid_size = (N + block_size - 1) // block_size
    launch_kernel(
        layernorm_backward_kernel,
        (grid_size,),
        (block_size,),
        input_grad, weight_grad, bias_grad,
        grad, input, weight, mean, rstd,
        numpy.int32(B), numpy.int32(T), numpy.int32(C),
    )

    return input_grad, weight_grad, bias_grad

A couple of notes. All the kernels expect contiguous inputs, hence every wrapper is decorated with `force_contiguous_inputs`. Some kernels take index tensors of type `int32`, whereas PyTorch index tensors are `int64` by default, so we additionally pre-process index inputs with `index_tensors_to_int32`. We do not convert floating-point inputs, because the kernels operate on `float` and this is the `dtype` we use by default. Aside from these nuances, each wrapper follows the same three steps: acquire the kernel, prepare the inputs and the launch configuration, and run the kernel with `launch_kernel`.
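
The launch configuration used by the wrappers comes down to a ceiling division. The sketch below isolates that arithmetic in plain Python for the two layouts that appear above: one thread per element (GELU, permute) and one warp per row (layernorm forward). The helper names are illustrative and do not exist in the notebook.

```python
def ceil_div(n, d):
    """Smallest integer k such that k * d >= n (for positive d)."""
    return (n + d - 1) // d


def elementwise_launch(numel, block_size):
    """One thread per element, as in the GELU and permute wrappers."""
    return ceil_div(numel, block_size), block_size


def warp_per_row_launch(rows, block_size, warp_size=32):
    """One warp per row, as in the layernorm forward wrapper."""
    return ceil_div(rows * warp_size, block_size), block_size


# 1000 elements with 128 threads per block need 8 blocks (1024 threads);
# the kernel's `if (i < N)` guard masks off the 24 surplus threads.
grid, block = elementwise_launch(1000, 128)
surplus = grid * block - 1000
print(grid, block, surplus)

# 100 rows at one warp each with 512-thread blocks: 16 warps per block, 7 blocks.
grid_rows, block_rows = warp_per_row_launch(100, 512)
print(grid_rows, block_rows)
```

Rounding up means the last block is usually only partly used, which is why every kernel above begins with a bounds check on its index.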

Let's test one of the custom CUDA kernels! 

In [13]:
x = torch.rand(256, 256, device='cuda')
torch_gelu = F.gelu(x, approximate='tanh')
custom_gelu = gelu_forward(x)
print((torch_gelu - custom_gelu).abs().max())

tensor(1.1921e-07, device='cuda:0')
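
The comparison above exercises only the forward kernel. One way to gain confidence in the backward formula is to compare it against a finite-difference derivative. The sketch below does this in pure Python and double precision, transcribing the expressions from `gelu_forward_kernel` and `gelu_backward_kernel`. It checks the math only, not the float32 CUDA kernels themselves, and the function names are illustrative.

```python
import math

SCALE = math.sqrt(2.0 / math.pi)  # GELU_SCALING_FACTOR in the kernels
COEFF = 0.044715


def gelu_tanh(x):
    """Forward pass, same expression as gelu_forward_kernel."""
    return 0.5 * x * (1.0 + math.tanh(SCALE * (x + COEFF * x**3)))


def gelu_tanh_grad(x):
    """Local gradient, same expression as gelu_backward_kernel."""
    arg = SCALE * (x + COEFF * x**3)
    sech2 = 1.0 / math.cosh(arg) ** 2
    return 0.5 * (1.0 + math.tanh(arg)) + 0.5 * x * sech2 * SCALE * (
        1.0 + 3.0 * COEFF * x * x
    )


def central_difference(f, x, h=1e-5):
    """Second-order finite-difference approximation of f'(x)."""
    return (f(x + h) - f(x - h)) / (2.0 * h)


points = [-3.0, -1.0, -0.1, 0.0, 0.5, 2.0]
max_err = max(
    abs(gelu_tanh_grad(x) - central_difference(gelu_tanh, x)) for x in points
)
print(max_err)
```

On a GPU, the same idea applies to `gelu_backward` by comparing its output with the gradient that PyTorch autograd computes for `F.gelu(x, approximate='tanh')`.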


## Swap PyTorch implementations with custom ones

In this section we finally achieve our goal: we replace PyTorch operations with the wrappers defined above, which call the native CUDA kernels from [llm.c](https://github.com/karpathy/llm.c).

Before reading on, however, we highly recommend checking the following resources first:
* [Zero to Thunder](./zero_to_thunder.ipynb) for a very short introduction to Thunder and its capabilities.
* [Defining new Thunder operators](./adding_custom_operator.ipynb) for understanding how to introduce new operators and executors to Thunder.
* [Defining custom forward and backward for existing operators](./adding_custom_operator_backward.ipynb) to learn how to make these new operators also differentiable using custom backward implementations.

In a nutshell, what we are doing below is as follows:
* A new executor is created.
* Forward and backward symbols are registered with this executor, each backed by the corresponding wrapper defined in the previous section.
* A composite symbol is registered with the executor. Its `execution_transform` calls the "forward" symbol, and its `grad_transform` calls both the "forward" and the "backward" symbols from the previous step. This composite symbol is what specific PyTorch implementations are going to be mapped to.

In [14]:
from thunder.core.transforms import get_grad, put_grad, put_grads
from thunder.core.langctxs import langctx, Languages
from thunder import TensorProxy


# Register a new executor
llmc = thunder.extend.OperatorExecutor("llm.c")
thunder.extend.register_executor(llmc)

# Permute/unpermute + backward {
def permute_meta(input: TensorProxy):
    B, N, _, NH, d = input.shape
    shape = (B, NH, N, d)
    return (
        TensorProxy(like=input, shape=shape),
        TensorProxy(like=input, shape=shape),
        TensorProxy(like=input, shape=shape),
    )


# permute forward symbol
llmc_permute = llmc.register_operator(
    "llmc_permute",
    like=permute_meta,
    fn=permute_split,
)


def unpermute_meta(input: TensorProxy):
    B, NH, N, d = input.shape
    return TensorProxy(like=input, shape=(B, N, NH, d))


# unpermute forward symbol
llmc_unpermute = llmc.register_operator(
    "llmc_unpermute",
    like=unpermute_meta,
    fn=unpermute_forward,
)


def permute_backward_meta(q_grad: TensorProxy, k_grad: TensorProxy, v_grad: TensorProxy):
    B, NH, N, d = q_grad.shape
    shape = (B, N, 3, NH, d)
    return TensorProxy(like=q_grad, shape=shape)


# permute backward symbol
llmc_permute_backward = llmc.register_operator(
    "llmc_permute_backward",
    like=permute_backward_meta,
    fn=permute_split_backward,
)


def unpermute_backward_meta(grad: TensorProxy):
    B, N, NH, d = grad.shape
    return TensorProxy(like=grad, shape=(B, NH, N, d))


# unpermute backward symbol
llmc_unpermute_backward = llmc.register_operator(
    "llmc_unpermute_backward",
    like=unpermute_backward_meta,
    fn=unpermute_backward,
)

def llmc_permute_util_meta(qkv, n_embd, n_head, B, T, C):
    shape = (B, n_head, T, C // n_head)
    return (
        TensorProxy(like=qkv, shape=shape),
        TensorProxy(like=qkv, shape=shape),
        TensorProxy(like=qkv, shape=shape),
    )


# composite permute symbol
llmc_permute_util = llmc.register_operator(
    "llmc_permute_util",
    like=llmc_permute_util_meta,
    fn=_permute,
    replaces=_permute,
)


def llmc_permute_util_execution_transform(qkv, n_embd, n_head, B, T, C):
    qkv = thunder.torch.reshape(qkv, (B, T, 3, n_head, C // n_head))
    return llmc_permute(qkv)


@langctx(Languages.TORCH)
def llmc_permute_util_grad_transform(qkv, n_embd, n_head, B, T, C):
    qkv_reshaped = thunder.torch.reshape(qkv, (B, T, 3, n_head, C // n_head))

    q, k, v = llmc_permute(qkv_reshaped)

    q_grad = get_grad(q)
    k_grad = get_grad(k)
    v_grad = get_grad(v)

    grad = llmc_permute_backward(q_grad, k_grad, v_grad)
    grad = thunder.torch.reshape(grad, qkv.shape)

    put_grad(qkv, grad)

    return q, k, v


llmc.register_implementation(
    llmc_permute_util,
    checker=lambda *args, **kwargs: True,
    execution_transform=llmc_permute_util_execution_transform,
    grad_transform=llmc_permute_util_grad_transform,
)


def llmc_unpermute_util_meta(input: TensorProxy, B: int, T: int, C: int):
    return TensorProxy(like=input, shape=(B, T, C))


# composite unpermute symbol
llmc_unpermute_util = llmc.register_operator(
    "llmc_unpermute_util",
    like=llmc_unpermute_util_meta,
    fn=_unpermute,
    replaces=_unpermute,
)


def llmc_unpermute_util_execution_transform(input: TensorProxy, B: int, T: int, C: int):
    res = llmc_unpermute(input)
    return thunder.torch.view(res, B, T, C)


def llmc_unpermute_util_grad_transform(input: TensorProxy, B: int, T: int, C: int):
    x = llmc_unpermute(input)
    fwd = thunder.torch.view(x, B, T, C)

    # NOTE: get_grad(x) breaks things. Why?
    fwd_grad = get_grad(fwd)
    fwd_grad = thunder.torch.view(fwd_grad, *x.shape)

    input_grad = llmc_unpermute_backward(fwd_grad)
    put_grad(input, input_grad)

    return fwd


llmc.register_implementation(
    llmc_unpermute_util,
    checker=lambda *args, **kwargs: True,
    execution_transform=llmc_unpermute_util_execution_transform,
    grad_transform=llmc_unpermute_util_grad_transform,
)
# }
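
The meta functions above only do shape bookkeeping. As a quick sanity check, the following standalone sketch mirrors that arithmetic with plain tuples. It involves no Thunder or CUDA, and the helper names are hypothetical stand-ins for the meta functions. It confirms that the backward shapes invert the forward ones for the GPT-2 sizes used in this demo.

```python
# Plain-tuple mirror of the shape logic in the meta functions above.
# These helpers are illustrative stand-ins, not part of Thunder or llm.c.

def permute_shape(qkv_shape):
    # (B, N, 3, NH, d) -> shape shared by each of q, k, v: (B, NH, N, d)
    B, N, three, NH, d = qkv_shape
    assert three == 3, "expected q, k, v packed along dim 2"
    return (B, NH, N, d)


def permute_backward_shape(q_grad_shape):
    # (B, NH, N, d) -> gradient of the packed input: (B, N, 3, NH, d)
    B, NH, N, d = q_grad_shape
    return (B, N, 3, NH, d)


def unpermute_shape(shape):
    # (B, NH, N, d) -> (B, N, NH, d)
    B, NH, N, d = shape
    return (B, N, NH, d)


def unpermute_backward_shape(grad_shape):
    # (B, N, NH, d) -> (B, NH, N, d)
    B, N, NH, d = grad_shape
    return (B, NH, N, d)


# Batch size, sequence length, channels and heads as used in this demo
B, T, C, n_head = 4, 64, 768, 12
packed = (B, T, 3, n_head, C // n_head)

qkv_each = permute_shape(packed)
print("q/k/v shape:", qkv_each)

# Each backward shape function undoes the corresponding forward one
assert permute_backward_shape(qkv_each) == packed
assert unpermute_backward_shape(unpermute_shape(qkv_each)) == qkv_each
```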

In [15]:
# Softmax {
llmc_softmax = llmc.register_operator(
    "llmc_softmax",
    like=lambda input: TensorProxy(like=input),
    fn=softmax_forward,
)

def llmc_softmax_checker(x: TensorProxy, /, dim: int, *, dtype=None) -> bool:
    if not (dim == -1 or dim == x.ndim - 1):
        return False
    return True


def llmc_softmax_execution_transform(x: TensorProxy, /, dim: int, *, dtype=None) -> TensorProxy:
    return llmc_softmax(x)


llmc.register_implementation(
    thunder.torch.softmax,
    checker=llmc_softmax_checker,
    execution_transform=llmc_softmax_execution_transform,
)
# }

In [16]:
from typing import Any


# Crossentropy + Softmax {
def llmc_crossentropy_meta(
    a: TensorProxy,
    target: TensorProxy,
) -> TensorProxy:
    return TensorProxy(like=target, dtype=a.dtype)


llmc_crossentropy = llmc.register_operator(
    "llmc_crossentropy",
    like=llmc_crossentropy_meta,
    fn=crossentropy_forward,
)


llmc_crossentropy_softmax_backward = llmc.register_operator(
    "llmc_crossentropy_softmax_backward",
    like=lambda x, *args, **kwargs: TensorProxy(like=x),
    fn=crossentropy_softmax_backward,
)


def llmc_crossentropy_softmax_checker(
    a: TensorProxy,
    /,
    target: TensorProxy,
    weight: None | TensorProxy = None,
    size_average: None | Any = None,
    ignore_index: int = -100,
    reduce: None | Any = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
):
    return (
        thunder.dtypes.is_integer_dtype(target.dtype)
        and weight is None
        and size_average is None
        and ignore_index < 0
        and reduce is None
        and reduction == "mean"
        and label_smoothing == 0.0
    )


# Setting torch langctx to be able to use binary ops with torch.Tensor
@langctx(Languages.TORCH)
def llmc_crossentropy_softmax_execution_transform(
    a: TensorProxy,
    /,
    target: TensorProxy,
    weight: None | TensorProxy = None,
    size_average: None | Any = None,
    ignore_index: int = -100,
    reduce: None | Any = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
):
    probs = llmc_softmax(a)
    loss = llmc_crossentropy(probs, target)
    return loss.sum() / loss.numel


# Setting torch langctx to be able to use binary ops with torch.Tensor
@langctx(Languages.TORCH)
def llmc_crossentropy_softmax_grad_transform(
    logits: TensorProxy,
    /,
    targets: TensorProxy,
    weight: None | TensorProxy = None,
    size_average: None | Any = None,
    ignore_index: int = -100,
    reduce: None | Any = None,
    reduction: str = "mean",
    label_smoothing: float = 0.0,
):
    probs = llmc_softmax(logits)
    losses = llmc_crossentropy(probs, targets)
    loss = losses.sum() / losses.numel

    loss_grad = get_grad(loss)

    losses_grad = thunder.torch.ones_like(losses) / losses.numel
    logits_grad = thunder.torch.zeros_like(logits)

    logits_grad = llmc_crossentropy_softmax_backward(logits_grad, losses_grad, probs, targets)
    put_grad(logits, logits_grad)

    return loss


llmc.register_implementation(
    thunder.torch.cross_entropy,
    checker=llmc_crossentropy_softmax_checker,
    execution_transform=llmc_crossentropy_softmax_execution_transform,
    grad_transform=llmc_crossentropy_softmax_grad_transform,
)
# }
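
The grad transform above seeds the per-token loss gradient with `1 / N` and passes the softmax probabilities to a fused backward kernel. Note that, as written, it does not use `loss_grad`, so it implicitly assumes an upstream gradient of 1. The pure-Python sketch below illustrates why this seeding gives the right result, under the assumption that the fused kernel computes `(p_i - 1[i == target]) * dloss` for each logit (the standard softmax/cross-entropy derivative; the kernel itself is not reproduced here). The analytic gradient is compared against central finite differences.

```python
import math


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def mean_cross_entropy(logits, targets):
    # Mean over rows of -log(softmax(row)[target])
    losses = [-math.log(softmax(row)[t]) for row, t in zip(logits, targets)]
    return sum(losses) / len(losses)


def analytic_grad(logits, targets):
    # d(mean loss) / d(logits[r][i]) = (p[r][i] - 1[i == targets[r]]) / N
    n = len(logits)
    grads = []
    for row, t in zip(logits, targets):
        probs = softmax(row)
        grads.append([(p - (1.0 if i == t else 0.0)) / n for i, p in enumerate(probs)])
    return grads


def numeric_grad(logits, targets, eps=1e-6):
    # Central finite differences, perturbing one logit at a time
    grads = []
    for r, row in enumerate(logits):
        g_row = []
        for i in range(len(row)):
            plus = [list(x) for x in logits]
            minus = [list(x) for x in logits]
            plus[r][i] += eps
            minus[r][i] -= eps
            diff = mean_cross_entropy(plus, targets) - mean_cross_entropy(minus, targets)
            g_row.append(diff / (2 * eps))
        grads.append(g_row)
    return grads


# Toy data: two "tokens", three classes
logits = [[2.0, -1.0, 0.5], [0.1, 0.2, 0.3]]
targets = [0, 2]

max_err = max(
    abs(a - b)
    for ra, rb in zip(analytic_grad(logits, targets), numeric_grad(logits, targets))
    for a, b in zip(ra, rb)
)
print(f"max abs difference: {max_err:.2e}")
```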

In [17]:
# GELU {
llmc_gelu = llmc.register_operator(
    "llmc_gelu",
    like=lambda input: TensorProxy(like=input),
    fn=gelu_forward,
)


llmc_gelu_backward = llmc.register_operator(
    "llmc_gelu_backward",
    like=lambda input, grad: TensorProxy(like=input),
    fn=gelu_backward,
)


llmc_module_gelu = llmc.register_operator(
    "llmc_module_gelu",
    like=lambda self, input: TensorProxy(like=input),
    fn=NewGELU.forward,
    replaces=NewGELU.forward,
)


def llmc_module_gelu_grad_transform(self, input: TensorProxy):
    fwd = llmc_gelu(input)

    fwd_grad = get_grad(fwd)

    input_grad = llmc_gelu_backward(input, fwd_grad)

    put_grad(input, input_grad)

    return fwd


llmc.register_implementation(
    llmc_module_gelu,
    checker=lambda *args, **kwargs: True,
    execution_transform=lambda self, input: llmc_gelu(input),
    grad_transform=llmc_module_gelu_grad_transform,
)
# }
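
The backward symbol above takes the original input together with the upstream gradient. The following sketch shows the per-element math such a pair of kernels would implement, assuming `NewGELU` is the tanh approximation of GELU used in GPT-2 (an assumption here, since the module definition is not repeated in this cell). It is plain Python for illustration, not the CUDA kernel, and it checks the derivative against finite differences.

```python
import math

SCALE = math.sqrt(2.0 / math.pi)


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(SCALE * (x + 0.044715 * x**3)))


def gelu_grad(x, upstream):
    # Chain rule: local derivative of gelu at x, scaled by the upstream gradient
    u = SCALE * (x + 0.044715 * x**3)
    tanh_u = math.tanh(u)
    sech2 = 1.0 - tanh_u * tanh_u  # sech^2(u)
    local = 0.5 * (1.0 + tanh_u) + 0.5 * x * sech2 * SCALE * (1.0 + 3.0 * 0.044715 * x * x)
    return local * upstream


# Compare against central finite differences at a few sample points
eps = 1e-6
worst = 0.0
for x in (-3.0, -0.5, 0.0, 0.7, 2.5):
    numeric = (gelu(x + eps) - gelu(x - eps)) / (2 * eps)
    worst = max(worst, abs(gelu_grad(x, 1.0) - numeric))

print(f"max abs difference: {worst:.2e}")
```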

In [18]:
# Replace LayerNorm {
def layernorm_forward(input, weight, bias):
    layernorm_forward_kernel = extract_cuda_kernel("layernorm_forward_kernel3")

    B, T, C = input.shape
    N = B * T

    out = input.new_empty(*input.shape)
    mean = input.new_empty(B, T)
    rstd = input.new_empty(B, T)

    block_size = 512
    grid_size = (N * 32 + block_size - 1) // block_size
    launch_kernel(
        layernorm_forward_kernel,
        (grid_size,),
        (block_size,),
        out, mean, rstd,
        input, weight, bias,
        numpy.int32(N), numpy.int32(C),
    )

    return out, mean, rstd


def layernorm_meta(input, weight, bias):
    B, T, C = input.shape
    return TensorProxy(like=input), TensorProxy(like=input, shape=(B, T)), TensorProxy(like=input, shape=(B, T))


llmc_layernorm = llmc.register_operator(
    "llmc_layernorm",
    like=layernorm_meta,
    fn=layernorm_forward,
)


def layernorm_backward(grad, input, weight, bias, mean, rstd):
    layernorm_backward_kernel = extract_cuda_kernel("layernorm_backward_kernel1")

    B, T, C = input.shape
    N = B * T

    input_grad = torch.zeros(*input.shape, dtype=input.dtype, device=input.device)
    weight_grad = torch.zeros(*weight.shape, dtype=weight.dtype, device=weight.device)
    bias_grad = torch.zeros(*bias.shape, dtype=bias.dtype, device=bias.device)

    block_size = 64
    grid_size = (N + block_size - 1) // block_size
    launch_kernel(
        layernorm_backward_kernel,
        (grid_size,),
        (block_size,),
        input_grad, weight_grad, bias_grad,
        grad, input, weight, mean, rstd,
        numpy.int32(B), numpy.int32(T), numpy.int32(C),
    )

    return input_grad, weight_grad, bias_grad


def layernorm_backward_meta(grad, input, weight, bias, mean, rstd):
    return TensorProxy(like=input), TensorProxy(like=weight), TensorProxy(like=bias)


llmc_layernorm_backward = llmc.register_operator(
    "llmc_layernorm_backward",
    like=layernorm_backward_meta,
    fn=layernorm_backward,
)


llmc_layer_norm = llmc.register_operator(
    "llmc_layer_norm",
    like=lambda input, normalized_shape, weight, bias, eps=1e-5: TensorProxy(like=input),
    fn=F.layer_norm,
    replaces=F.layer_norm,
)


def llmc_layer_norm_execution_transform(input, normalized_shape, weight, bias, eps=1e-5):
    return llmc_layernorm(input, weight, bias)[0]


def llmc_layer_norm_grad_transform(input, normalized_shape, weight, bias, eps=1e-5):
    out, mean, rstd = llmc_layernorm(input, weight, bias)

    out_grad = get_grad(out)
    mean_grad = get_grad(mean)
    rstd_grad = get_grad(rstd)

    input_grad, weight_grad, bias_grad = llmc_layernorm_backward(out_grad, input, weight, bias, mean, rstd)

    put_grads((input, weight, bias), (input_grad, weight_grad, bias_grad))

    return out


llmc.register_implementation(
    llmc_layer_norm,
    checker=lambda *args, **kwargs: True,
    execution_transform=llmc_layer_norm_execution_transform,
    grad_transform=llmc_layer_norm_grad_transform,
)
# }
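
The launch configurations above are ceiling divisions. The factor of 32 in the forward launch suggests that the kernel assigns one warp per `(b, t)` row, while the backward launch appears to use one thread per row; both readings are assumptions based on the arithmetic, not on the kernel source. A small worked example for the sizes used in this demo:

```python
def ceil_div(a, b):
    # Smallest integer k such that k * b >= a
    return (a + b - 1) // b


def layernorm_forward_grid(B, T, block_size=512, threads_per_row=32):
    # Enough blocks to give every (b, t) row threads_per_row threads
    return ceil_div(B * T * threads_per_row, block_size)


def layernorm_backward_grid(B, T, block_size=64):
    # Enough blocks to give every (b, t) row one thread
    return ceil_div(B * T, block_size)


B, T = 4, 64  # batch size and sequence length used in this demo
fwd_grid = layernorm_forward_grid(B, T)
bwd_grid = layernorm_backward_grid(B, T)
print(f"forward: {fwd_grid} blocks of 512 threads")
print(f"backward: {bwd_grid} blocks of 64 threads")

# A size that does not divide evenly gets rounded up to a whole block
print(layernorm_forward_grid(3, 50), layernorm_backward_grid(3, 50))
```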

## Checking whether all the pieces fit
Now that re-mapping is complete, let's check its correctness.

In [19]:
custom_model = demo_model(thunder_jit=1, thunder_executors=(llmc,))

wrote ./llm.c/gpt2_tokenizer.bin
loading weights from pretrained gpt: gpt2
compiling the model with thunder.jit...
loss=tensor(5.2700, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 0, loss: 5.27000617980957, time: 1815.470ms
loss=tensor(4.0597, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 1, loss: 4.059720993041992, time: 53.967ms
loss=tensor(3.3752, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 2, loss: 3.37518310546875, time: 54.288ms
loss=tensor(2.8008, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 3, loss: 2.800813913345337, time: 55.244ms
loss=tensor(2.3154, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 4, loss: 2.315415143966675, time: 53.213ms
loss=tensor(1.8490, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 5, loss: 1.8490420579910278, time: 57.130ms
loss=tensor(1.3946, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 6, loss: 1.3946458101272583, time: 57.649ms
loss=ten

The outcome indicates that correctness, at least, is preserved! Next, we check the forward and backward traces to see whether all the symbols we defined appear there. Note that symbols with a defined execution/grad transform do not appear in the traces, because these transforms define decompositions that replace such symbols.

In [20]:
import itertools


# NOTE: "meta"-symbols with defined execution/grad transform are not included.
# This is because they are being decomposed into symbols with the names from the set defined below.
new_symbol_names = {
    "llmc_permute",
    "llmc_unpermute",
    "llmc_permute_backward",
    "llmc_unpermute_backward",
    "llmc_softmax",
    "llmc_crossentropy",
    "llmc_gelu",
    "llmc_gelu_backward",
    "llmc_layernorm",
    "llmc_layernorm_backward",
}

model_fwd_trace = thunder.last_traces(custom_model)[-1]
model_bkw_trace = thunder.last_backward_traces(custom_model)[-1]

seen_new_symbols = set()
for bsym in itertools.chain(model_fwd_trace.bound_symbols, model_bkw_trace.bound_symbols):
    seen_new_symbols.update({bsym.sym.name} & new_symbol_names)
    
assert seen_new_symbols == new_symbol_names

All the forward/backward symbols that call into the custom CUDA kernels appear in the traces. So there we have it: a PyTorch implementation with parts replaced by the user's code!

## Sanity check benchmarking
Let's do some basic benchmarking and see how the custom CUDA kernels perform. Each timed step consists of a forward pass, a backward pass, and an Adam optimizer update.

In [21]:
def fwd_bkw(model):
    B, T = 4, 64
    device = "cuda"
    
    data_iter = get_data_iter(B, T, device)
    x, y = next(data_iter)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    def do_fwd_bkw():
        logits, loss = model(x, y)
        del logits
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    # Warm-up runs!
    for _ in range(2):
        do_fwd_bkw()

    return do_fwd_bkw


custom_model_bench = fwd_bkw(custom_model)

thunder_model = demo_model(thunder_jit=1)
thunder_model_bench = fwd_bkw(thunder_model)

ref_model = demo_model(thunder_jit=0)
ref_model_bench = fwd_bkw(ref_model)

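def benchmark(bench_f, num_runs=30):
    import time
    # Drain any pending GPU work so it is not attributed to the timed region
    torch.cuda.synchronize()
    st = time.perf_counter()
    for _ in range(num_runs):
        bench_f()
    # Kernel launches are asynchronous: wait for them before stopping the clock
    torch.cuda.synchronize()
    et = time.perf_counter()
    avg_time_in_seconds = (et - st) / num_runs
    print(f"Elapsed average time (n={num_runs}): {avg_time_in_seconds * 1000:.4f}ms")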

print()

benchmark(custom_model_bench)
benchmark(thunder_model_bench)
benchmark(ref_model_bench)

%reset -f

wrote ./llm.c/gpt2_tokenizer.bin
loading weights from pretrained gpt: gpt2
compiling the model with thunder.jit...
loss=tensor(5.2700, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 0, loss: 5.270007610321045, time: 142.715ms
loss=tensor(4.0597, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 1, loss: 4.059719562530518, time: 52.394ms
loss=tensor(3.3752, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 2, loss: 3.375183582305908, time: 50.392ms
loss=tensor(2.8008, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 3, loss: 2.8008131980895996, time: 57.834ms
loss=tensor(2.3154, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 4, loss: 2.315413475036621, time: 57.975ms
loss=tensor(1.8490, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 5, loss: 1.8490394353866577, time: 57.686ms
loss=tensor(1.3946, device='cuda:0', grad_fn=<ThunderFunctionBackward>)
iteration 6, loss: 1.3946452140808105, time: 57.700ms
loss=t

One can see that the model with custom CUDA kernels is generally on par with (or even faster than!) the reference PyTorch implementation. This is quite remarkable given that not all of the custom CUDA kernels are necessarily optimal, and that PyTorch's kernels should be quite performant and efficient in this setting. Taking memory layouts into account and fusing computation definitely helps!

## Gap with the C implementation
Let us see how the Python implementations compare to the C implementation. For that we will use the testing code from [llm.c](https://github.com/karpathy/llm.c).

In [22]:
%%bash
cd llm.c
make test_gpt2cu
./test_gpt2cu

NICE Compiling with OpenMP support
nvcc -O3 --use_fast_math test_gpt2.cu -lcublas -lcublasLt -o test_gpt2cu
[System]
Device 0: NVIDIA A100-SXM4-40GB
enable_tf32: 0
[GPT-2]
max_seq_len: 1024
vocab_size: 50257
num_layers: 12
num_heads: 12
channels: 768
num_parameters: 124439808
[State]
batch_size: 4
seq_len: 64
num_activations: 82760960
-43.431671 -43.431725
-39.836388 -39.836433
-43.065968 -43.066017
OK (LOGITS)
LOSS OK: 5.270009 5.270009
grads
OK -0.002320 -0.002320
OK 0.002072 0.002072
OK 0.003717 0.003717
OK 0.001307 0.001307
OK 0.000632 0.000632
TENSOR OK
step 0: loss 5.270009 (took 34.621211 ms)
step 1: loss 4.059718 (took 10.644219 ms)
step 2: loss 3.375185 (took 246.578083 ms)
step 3: loss 2.800815 (took 246.652834 ms)
step 4: loss 2.315419 (took 247.894914 ms)
step 5: loss 1.849051 (took 246.635729 ms)
step 6: loss 1.394659 (took 246.620490 ms)
step 7: loss 0.999220 (took 246.692001 ms)
step 8: loss 0.624077 (took 246.682988 ms)
step 9: loss 0.376495 (took 246.609135 ms)
loss ok

This is what we get when using [llm.c](https://github.com/karpathy/llm.c) checked out at `954077fb887d2770e4d537bafea056473d4bb4ce`. However, a more recent version, `50acc125f39694ee43f285e5cf8fc123cbb911fa`, is more performant.

In [23]:
%%bash
cd llm.c
git checkout 50acc125f39694ee43f285e5cf8fc123cbb911fa
python train_gpt2.py
make test_gpt2fp32cu
./test_gpt2fp32cu

Previous HEAD position was 954077f TRAINING WORKSgit add train_gpt2.cu! ITS SLOW BUT IT WORKS WOOT
HEAD is now at 50acc12 Merge branch 'ngc92-split-file' Separates out common error-checking wrapper utils, that are broadly useful across all file


Running pytorch 2.4.0.dev20240430+cu121
using device: cuda
wrote gpt2_tokenizer.bin
loading weights from pretrained gpt: gpt2
loading cached tokens in data/tiny_shakespeare_val.bin
padded vocab size from 50257 to 50304
wrote gpt2_124M.bin
padded vocab size from 50257 to 50304
wrote gpt2_124M_bf16.bin
padded vocab size in reference grads from 50257 to 50304
wrote gpt2_124M_debug_state.bin
iteration 1, loss: 4.175999641418457, time: 42.633ms
iteration 2, loss: 3.8168132305145264, time: 36.040ms
iteration 3, loss: 3.706939935684204, time: 45.228ms
iteration 4, loss: 3.9031288623809814, time: 45.948ms
iteration 5, loss: 3.2775278091430664, time: 45.743ms
iteration 6, loss: 2.9421181678771973, time: 45.576ms
iteration 7, loss: 2.8467600345611572, time: 45.889ms
iteration 8, loss: 2.6670680046081543, time: 49.916ms
iteration 9, loss: 2.408634901046753, time: 49.969ms
final 9 iters avg: 45.216ms
peak memory consumption: 2341 MiB
<|endoftext|>One of the most important issues in our political s

Makefile:84: OpenMPI is not found, disabling multi-GPU support
Makefile:85: On Linux you can try install OpenMPI with `sudo apt install openmpi-bin openmpi-doc libopenmpi-dev`


nvcc found, including CUDA builds
/usr/local/cuda-12/bin/nvcc -O3 -t=0 --use_fast_math test_gpt2_fp32.cu -lcublas -lcublasLt   -lcublas -lcublasLt -o test_gpt2fp32cu


      int V = config.vocab_size;
          ^




[System]
Device 0: NVIDIA A100-SXM4-40GB
enable_tf32: 0
[State]
batch_size: 4
seq_len: 64
allocated 221 MiB for activations
-43.431671, -43.431725
-39.836388, -39.836433
-43.065968, -43.066017
-42.828091, -42.828136
-43.529598, -43.529644
-44.318451, -44.318501
-41.227470, -41.227524
-41.270821, -41.270866
-42.541451, -42.541515
-42.395061, -42.395103
OK (LOGITS)
allocated 474 MiB for parameter gradients
allocated 4 MiB for activation gradients
LOSS OK: 5.270009 5.270009
grads
OK -0.002320 -0.002320
OK 0.002072 0.002072
OK 0.003717 0.003717
OK 0.001307 0.001307
OK 0.000632 0.000632
TENSOR OK
allocated 474 MiB for AdamW optimizer state m
allocated 474 MiB for AdamW optimizer state v
step 0: loss 5.270009 (took 13.917726 ms)
step 1: loss 4.059717 (took 12.114419 ms)
step 2: loss 3.375185 (took 24.071226 ms)
step 3: loss 2.800816 (took 24.149647 ms)
step 4: loss 2.315418 (took 24.197707 ms)
step 5: loss 1.849047 (took 24.187108 ms)
step 6: loss 1.394655 (took 24.166976 ms)
step 7: loss 0.

This later version also outperforms the Thunder version with custom kernels!
To bridge this performance gap, we plan the following future work:
* Replace older kernels with newer ones.
* Improve memory management. The Thunder version delegates memory management to PyTorch and does so in a model-agnostic fashion which is not guaranteed to be optimal. In contrast, the C implementation is not model-agnostic and assumes full control over efficient memory allocation and management.
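
To put a rough number on the gap discussed above, the sketch below parses per-step timings from two of the log excerpts shown earlier (the Thunder run with custom kernels and the newer C build) and averages them after skipping the initial steps. The helper name and the choice of how many steps to skip are illustrative assumptions. The two programs also time their steps differently, so the result is indicative only.

```python
import re
from statistics import mean

# Excerpts copied from the outputs shown earlier in this notebook
thunder_log = """\
iteration 0, loss: 5.27000617980957, time: 1815.470ms
iteration 1, loss: 4.059720993041992, time: 53.967ms
iteration 2, loss: 3.37518310546875, time: 54.288ms
iteration 3, loss: 2.800813913345337, time: 55.244ms
iteration 4, loss: 2.315415143966675, time: 53.213ms
iteration 5, loss: 1.8490420579910278, time: 57.130ms
iteration 6, loss: 1.3946458101272583, time: 57.649ms
"""

c_log = """\
step 0: loss 5.270009 (took 13.917726 ms)
step 1: loss 4.059717 (took 12.114419 ms)
step 2: loss 3.375185 (took 24.071226 ms)
step 3: loss 2.800816 (took 24.149647 ms)
step 4: loss 2.315418 (took 24.197707 ms)
step 5: loss 1.849047 (took 24.187108 ms)
step 6: loss 1.394655 (took 24.166976 ms)
"""

# Matches both "time: 53.967ms" and "took 24.071226 ms"
TIME_RE = re.compile(r"(?:time: |took )([\d.]+) ?ms")


def steady_state_ms(log, skip):
    # Average step time in milliseconds, ignoring the first `skip` steps
    times = [float(m.group(1)) for m in TIME_RE.finditer(log)]
    return mean(times[skip:])


# Skip the compilation step for Thunder, and the first two C steps,
# whose timings differ markedly from the rest of the run.
thunder_ms = steady_state_ms(thunder_log, skip=1)
c_ms = steady_state_ms(c_log, skip=2)

print(f"Thunder + custom kernels: {thunder_ms:.2f} ms/step")
print(f"llm.c (newer checkout):   {c_ms:.2f} ms/step")
print(f"ratio: {thunder_ms / c_ms:.2f}x")
```

On these excerpts, the Thunder run takes a bit more than twice as long per step as the newer C build.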

Stay tuned and may your kernels run fast!