In [1]:
!pip install -q transformers accelerate ninja

In [2]:
import torch, os

if not torch.cuda.is_available():
    raise SystemError("GPU NOT FOUND")

gpu_name = torch.cuda.get_device_name(0)
compute_capability = torch.cuda.get_device_capability(0)
print(f"GPU Name: {gpu_name}  (compute capability {compute_capability})")

GPU Name: Tesla T4  (compute capability (7, 5))


In [3]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

MODEL_NAME = "bert-base-uncased"
DTYPE      = torch.float16

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = (
    AutoModelForMaskedLM
    .from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
    .to("cuda")
    .eval()
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
example_text = "hello world"
inputs = tokenizer(example_text, return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model(**inputs)
    logits  = outputs.logits

print("Logits tensor shape:", logits.shape)

Logits tensor shape: torch.Size([1, 4, 30522])


# Profiling

In [5]:
import torch
from torch.profiler import profile, record_function, ProfilerActivity

def run_profiler(model, tokenizer, text_batch, training=False, warmup_iters=2, row_limit=20):

    device = next(model.parameters()).device
    model.train() if training else model.eval()

    batch = tokenizer(text_batch, return_tensors="pt", padding=True).to(device)

    # 1. Warm-up (kept out of the trace)
    for _ in range(warmup_iters):
        if training:
            out = model(**batch, labels=batch["input_ids"])
            out.loss.backward()
        else:
            with torch.no_grad():
                _ = model(**batch)
        torch.cuda.synchronize()

    # 2. One profiled step
    tag       = "training" if training else "inference"
    tracefile = f"trace_{tag}.json"

    with profile(activities=[
                    ProfilerActivity.CPU,
                    ProfilerActivity.CUDA
                 ],
                 record_shapes=True,
                 profile_memory=True,
                 with_stack=True) as prof:
        with record_function(tag):
            if training:
                out = model(**batch, labels=batch["input_ids"])
                out.loss.backward()
            else:
                with torch.no_grad():
                    _ = model(**batch)


    print(f"Top {row_limit} CUDA kernels ({tag}):")
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=row_limit))

    prof.export_chrome_trace(tracefile)
    print(f"Full timeline saved to ./{tracefile}")

In [6]:
sentences = ["the quick brown fox jumps over the lazy dog"] * 8

run_profiler(model, tokenizer, sentences, training=False)  # inference trace

Top 20 CUDA kernels (inference):
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                              inference         0.00%       0.000us         0.00%       0.000us       0.000us      25.335ms       521.93%      25.335ms      25.335ms           0 b        

In [7]:
sentences = ["the quick brown fox jumps over the lazy dog"] * 8
run_profiler(model, tokenizer, sentences, training=True)   # training trace

Top 20 CUDA kernels (training):
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               training         0.00%       0.000us         0.00%       0.000us       0.000us      22.869ms       119.54%      22.869ms      22.869ms           0 b         

In [8]:
# Example payload: 8 sentences, padded to same length
batch_txt = [
    "the quick brown fox jumps over the lazy dog",
] * 8                                           # batch size = 8

# Tokenize → tensors → push to CUDA
batch = tokenizer(
    batch_txt,
    return_tensors="pt",
    padding=True,
).to("cuda")      # batch is a dict with input_ids, attention_mask, etc.

In [9]:
import torch, gc, statistics as stats, time

def _cuda_timer(step_fn, warm=50, iters=500):
    """
    step_fn() must run exactly ONE forward or train step.
    Returns median and stdev (in ms) across `iters` timed runs.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    gc.disable()

    for _ in range(warm):
        step_fn()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end   = torch.cuda.Event(enable_timing=True)
    times = []

    for _ in range(iters):
        start.record()
        step_fn()
        end.record()
        torch.cuda.synchronize()          # wait for kernel to finish
        times.append(start.elapsed_time(end))  # ms

    return stats.median(times), stats.stdev(times)

@torch.inference_mode()
def _fwd_step():
    _ = model(**batch)

def _train_step():
    loss = model(**batch, labels=batch["input_ids"]).loss
    loss.backward()
    model.zero_grad(set_to_none=True)

med_inf,  std_inf  = _cuda_timer(_fwd_step,   warm=100, iters=1000)
med_train, std_train = _cuda_timer(_train_step, warm=50,  iters=300)

print(f"Inference  : {med_inf:.2f} ± {std_inf:.2f} ms")
print(f"Train step : {med_train:.2f} ± {std_train:.2f} ms")


Inference  : 17.55 ± 11.15 ms
Train step : 57.21 ± 24.53 ms


## Kernel Optimizations

In [10]:
import triton, triton.language as tl
import torch, types, math, time, gc
from contextlib import nullcontext

@triton.jit
def _tanh(x):
    return 2.0 * tl.sigmoid(2.0 * x) - 1.0

@triton.jit
def _fast_gelu(x):
    return 0.5 * x * (1.0 + _tanh(0.79788456 * (x + 0.044715 * x * x * x)))


@triton.jit
def fbgemm_kernel(X, W, B, Y,
                  M, N, K,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)          # Process ID along M dimension
    pid_n = tl.program_id(1)          # along N dimension

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    X_ptrs = X + (offs_m[:, None] * K + offs_k[None, :])
    W_ptrs = W + (offs_n[None, :] * K + offs_k[:, None])
    acc = tl.zeros((BLOCK_M, BLOCK_N), tl.float32)

    for k in range(0, K, BLOCK_K):
        x = tl.load(X_ptrs + k,
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K),
                    other=0.0)
        w = tl.load(W_ptrs + k,
                    mask=(offs_n[None, :] < N) & (offs_k[:, None] < K),
                    other=0.0)
        acc += tl.dot(x, w)

    # add bias
    bias = tl.load(B + offs_n, mask=offs_n < N, other=0.0)
    acc = acc + bias[None, :]

    # fast-GELU in fp16 for storage
    acc_f16 = _fast_gelu(acc)
    acc_f16 = acc.to(tl.float16)

    tl.store(Y + (offs_m[:, None] * N + offs_n[None, :]),
             acc_f16,
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

def fused_bias_gelu_triton(x: torch.Tensor,
                            w: torch.Tensor,
                            b: torch.Tensor) -> torch.Tensor:
    """x:[M,K]  w:[N,K]  b:[N]  -> y:[M,N] (all fp16)"""
    M, K = x.shape
    N = w.shape[0]
    y = torch.empty((M, N), dtype=x.dtype, device=x.device)

    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
    grid = (triton.cdiv(M, BLOCK_M),
            triton.cdiv(N, BLOCK_N))

    fbgemm_kernel[grid](
        x, w, b, y,
        M, N, K,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        num_warps=4, num_stages=2
    )
    return y

In [11]:
# Autograd wrapper (fused fwd, stock bwd)
class TritonFBGGEMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, b):
        y = fused_bias_gelu_triton(x, w, b)
        ctx.save_for_backward(x, w, b, y)
        return y
    @staticmethod
    def backward(ctx, dy):
        x, w, b, y = ctx.saved_tensors
        # GELU backward (approx.)
        pre = y
        tanh_out = torch.tanh(0.79788456 *
                              (pre + 0.044715 * pre * pre * pre))
        gelu_grad = 0.5 * (1.0 + tanh_out) + \
                    0.5 * pre * (1 - tanh_out * tanh_out) * \
                    (0.79788456 + 0.134145 * pre * pre)
        dy_pre = dy * gelu_grad

        dx = torch.matmul(dy_pre, w)                 # [M,K]
        dw = torch.matmul(dy_pre.t(), x)         # [N,K]
        db = dy_pre.sum(dim=0)
        return dx, dw, db

In [12]:
# Monkey-patch every feed-forward “intermediate dense” layer
def patch_ffn(block):
    lin1, lin2 = block.intermediate.dense, block.output.dense
    W1 = lin1.weight.contiguous().half()   # [N,K] for row-major load
    B1 = lin1.bias.contiguous().half()
    def fused_forward(self, hidden_states):
        B, S, K = hidden_states.shape
        x = hidden_states.reshape(-1, K)                # [M,K]
        y = TritonFBGGEMM.apply(x, W1, B1)              # fused op
        return y.view(B, S, -1)
    block.intermediate.forward = types.MethodType(fused_forward,
                                                  block.intermediate)

In [13]:
for blk in model.bert.encoder.layer:
    patch_ffn(blk)
print("Patched FFN (intermediate dense) with Triton fused kernel")

Patched FFN (intermediate dense) with Triton fused kernel


In [14]:
@torch.inference_mode()
def _fwd_step():
    _ = model(**batch)

def _train_step():
    loss = model(**batch, labels=batch["input_ids"]).loss
    loss.backward()
    model.zero_grad(set_to_none=True)

med_inf,  std_inf  = _cuda_timer(_fwd_step,   warm=100, iters=1000)
med_train, std_train = _cuda_timer(_train_step, warm=50,  iters=300)

print(f"Inference  : {med_inf:.2f} ± {std_inf:.2f} ms")
print(f"Train step : {med_train:.2f} ± {std_train:.2f} ms")

Inference  : 11.65 ± 2.25 ms
Train step : 35.36 ± 7.23 ms
