In [None]:

import os
os.environ["CUDA_VISIBLE_DEVICES"]="3"
import time
from datetime import timedelta

import torch
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

import triton
import triton.language as tl

import pdb


In [None]:
# CUDA_VISIBLE_DEVICES="3" is set before torch is imported, so the single visible GPU is index 0.
device = torch.device("cuda", 0)
torch.cuda.device_count()

In [None]:
def cosim(x, y):
    """Cosine similarity of two tensors treated as flat vectors.

    The computation runs in float64 to limit rounding error; the result is a 0-d float32 tensor.
    """
    a = x.reshape(-1).double()
    b = y.reshape(-1).double()
    dot = (a * b).sum()
    return (dot / a.norm() / b.norm()).float()

def baseline_torch(x, y, A):
    """Reference loss: mean cross-entropy of the logits ``x @ A.T`` against targets ``y``.

    Logits are flattened to (-1, V), with V = A.shape[0], and upcast to float32 before the loss.
    """
    vocab_size = A.shape[0]
    logits = F.linear(x, A).view(-1, vocab_size).float()
    targets = y.view(-1)
    return F.cross_entropy(logits, targets)

# torch.compile'd versions of baseline_torch for the benchmarks: default mode, and a
# full-graph "max-autotune" variant.
compiled_baseline = torch.compile(baseline_torch)
maxauto_baseline = torch.compile(baseline_torch, fullgraph=True, mode="max-autotune")

In [None]:
f = 1  # divisor applied to S, H and V of the correctness problem in the next cell (1 = full size)

In [None]:
# Small correctness problem; S, H and V are divided by the shrink factor f.
B, S , H, V = 4, 4 // f, 16 // f, 512 // f
N = B * S

# Fixed seed so the reference values and every comparison below are reproducible across runs.
torch.manual_seed(0)

y = torch.randint(0, V, (B * S,), device=device)  # one target in [0, V) per row
A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
# Transposed (H, V) copy for the A^T kernels. It is detached so it carries no autograd graph back to A.
At = A.detach().T.contiguous()

# Inputs: random noise plus a small multiple of each row's target weight vector.
x = 0.05 * A[y].clone().detach() + torch.randn(B * S, H, device=device, dtype=torch.bfloat16)
x.requires_grad_()

# float64 autograd reference for the loss and both gradients.
loss = baseline_torch(x.double(), y, A.double())
loss.backward()

reference_A_grad = A.grad.float().clone()
reference_x_grad = x.grad.float().clone()
reference_loss = loss.detach().float().clone()

# Chunk sizes used by the CPU chunked forward reference below.
chunk_size = 16
V_chunk_size = 16
print(reference_loss)

In [None]:
# Fused linear + cross-entropy forward that reads A in its stored (V, H) layout.
# Only the 16/16/16 config is active; the others are commented out.
# reset_to_zero lists the buffers the autotuner zeroes before its benchmark runs (Triton autotune API).
@triton.autotune(
   configs=[
    triton.Config({'V_BLOCK_SIZE': 16, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
    # triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 256}),
    # triton.Config({'V_BLOCK_SIZE': 512, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 512}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),

    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 64}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 256}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
   ],
   key=['V', 'N', 'H'],
   reset_to_zero=["loss_ptr", "m_global_ptr", "s_global_ptr"]
)
@triton.jit
def linear_xent_fwd_kernel_matmul(x_ptr,
                y_ptr,
                A_ptr,
                loss_ptr,  
                m_global_ptr,
                s_global_ptr,
                stride_x_N, stride_x_H,
                stride_A_V, stride_A_H,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                V_BLOCK_SIZE: tl.constexpr,
                N_BLOCK_SIZE: tl.constexpr,
                H_BLOCK_SIZE: tl.constexpr,
               ):
    """Online-logsumexp cross-entropy over z = x @ A.T for one block of N_BLOCK_SIZE rows.

    For each V tile, the (Nc x Vc) logits tile is accumulated over H in H_BLOCK_SIZE steps.
    Then the running row max m and the running sum s of exp(z - m) are updated, and the
    target logits (columns where y == column index) are subtracted from the partial loss.
    Finally sum(m + log s) / N is added, the block's partial loss is atomically added to
    loss_ptr, and m / s are written back to m_global_ptr / s_global_ptr.

    No masks or boundary checks are used, and the loops use floor division. N, V and H are
    therefore presumably required to be multiples of their block sizes.
    """
    idx = tl.program_id(axis=0)  
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )
    A_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(V, H),
        strides=(stride_A_V, stride_A_H),
        offsets=(0, 0),
        block_shape=(V_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )
    offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    v_range = tl.arange(0, V_BLOCK_SIZE)  # vocab indices covered by the current V tile
    y = tl.load(y_ptr + offsets)

    # Online-logsumexp state starts from whatever the host put in m_global / s_global.
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)
    loss = 0.  # this block's partial loss (already divided by N)

    for _ in range(V // V_BLOCK_SIZE):

        z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)
        # Restart at H offset 0: the x tile is re-read for every V tile.
        local_x_block_ptr = x_block_ptr
        for _ in range(H // H_BLOCK_SIZE):
            x_chunk = tl.load(local_x_block_ptr) # Nc x Hc
            A_v = tl.load(A_block_ptr) # Vc x Hc

            z_j_to_k = tl.dot(x_chunk, A_v.trans(), z_j_to_k) # (Nc x Hc) @ (Hc x Vc)

            local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])
            A_block_ptr = tl.advance(A_block_ptr, [0, H_BLOCK_SIZE])

        m_new = tl.maximum(m, tl.max(z_j_to_k, 1))

        # Rescale the old sum to the new max before adding this tile's terms.
        s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)
        s = s * tl.exp(m - m_new) + s_update

        mask = y[:, None] == v_range[None, :] # Nc x Vc
        loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N

        m = m_new
        # Next V tile, rewinding the H offset consumed by the inner loop.
        A_block_ptr = tl.advance(A_block_ptr, [V_BLOCK_SIZE, -H_BLOCK_SIZE * (H // H_BLOCK_SIZE)])
        v_range = v_range + V_BLOCK_SIZE

    # tl.device_print("m", m)
    # tl.device_print("s", s)
    loss += tl.sum(m + tl.log(s)) / N  # logsumexp term of each row

    tl.atomic_add(loss_ptr, loss)
    tl.store(m_global_ptr + offsets, m)
    tl.store(s_global_ptr + offsets, s)

def linear_xent_matmul(x, y, A):
    """Mean linear cross-entropy ``cross_entropy(x @ A.T, y)`` via ``linear_xent_fwd_kernel_matmul``.

    Args:
        x: (N, H) inputs.
        y: (N,) integer targets.
        A: (V, H) weight matrix.

    Returns:
        A (1,)-shaped float32 tensor holding the loss. The kernel accumulates it with atomic adds.
    """
    N, H = x.shape
    V, H_A = A.shape
    assert H_A == H
    assert y.shape == (N,)
    x, y, A = x.contiguous(), y.contiguous(), A.contiguous()

    dev = x.device
    # Online-logsumexp state: running max (starts very negative) and running sum of exps.
    m_global = -10e5 * torch.ones(N, dtype=torch.float32, device=dev)
    s_global = torch.zeros(N, dtype=torch.float32, device=dev)
    loss = torch.zeros(1, dtype=torch.float32, device=dev)

    def grid(meta):
        # One program per block of N_BLOCK_SIZE rows; the block size comes from the autotune config.
        return (triton.cdiv(N, meta['N_BLOCK_SIZE']),)

    with torch.cuda.device(dev.index):  # actually required
        linear_xent_fwd_kernel_matmul[grid](
            x, y, A, loss, m_global, s_global,
            x.stride(0), x.stride(1),
            A.stride(0), A.stride(1),
            V=V, N=N, H=H,
        )
    return loss

# Warm-up call (autotune / compile), then time one call with CUDA events.
loss = linear_xent_matmul(x, y, A) # autotune
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
loss = linear_xent_matmul(x, y, A)
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"Simple timing: {estimate_ms}ms")
# Separate statements: `print(a), print(b)` as the last line also displays the tuple (None, None).
print(loss)
print(torch.dist(loss, reference_loss).item())

In [None]:
# Same forward as linear_xent_fwd_kernel_matmul, but reading the transposed weight At (H, V),
# so tl.dot needs no transpose. reset_to_zero lists the buffers the autotuner zeroes before its
# benchmark runs (Triton autotune API).
@triton.autotune(
   configs=[
    triton.Config({'V_BLOCK_SIZE': 16, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
    # triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 256}),
    # triton.Config({'V_BLOCK_SIZE': 512, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 512}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),

    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 64}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 256}),
    # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
   ],
   key=['V', 'N', 'H'],
   reset_to_zero=["loss_ptr", "m_global_ptr", "s_global_ptr"]
)
@triton.jit
def linear_xent_fwd_kernel_matmul_t(x_ptr,
                y_ptr,
                A_t_ptr,
                loss_ptr,  
                m_global_ptr,
                s_global_ptr,
                stride_x_N, stride_x_H,
                stride_A_H, stride_A_V,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                V_BLOCK_SIZE: tl.constexpr,
                N_BLOCK_SIZE: tl.constexpr,
                H_BLOCK_SIZE: tl.constexpr,
               ):
    """Online-logsumexp cross-entropy over z = x @ At for one block of N_BLOCK_SIZE rows.

    Tiles of At are (H_BLOCK_SIZE x V_BLOCK_SIZE). For each V tile, the logits are accumulated
    over H, then the running max m and the running sum s of exp(z - m) are updated, and the
    target logits are subtracted. At the end sum(m + log s) / N is added and the partial loss
    is atomically added to loss_ptr; m and s are stored back.

    No masks or boundary checks are used, so N, V and H presumably must be multiples of
    their block sizes.
    """
    idx = tl.program_id(axis=0)  
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )
    A_block_ptr = tl.make_block_ptr(
        base=A_t_ptr,
        shape=(H, V),
        strides=(stride_A_H, stride_A_V),
        offsets=(0, 0),
        block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),
        order=(1, 0),
    )
    offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    v_range = tl.arange(0, V_BLOCK_SIZE)  # vocab indices covered by the current V tile
    y = tl.load(y_ptr + offsets)

    # Online-logsumexp state starts from whatever the host put in m_global / s_global.
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)
    loss = 0.  # this block's partial loss (already divided by N)

    for _ in range(V // V_BLOCK_SIZE):

        z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)
        # Restart at H offset 0: the x tile is re-read for every V tile.
        local_x_block_ptr = x_block_ptr
        for _ in range(H // H_BLOCK_SIZE):
            x_chunk = tl.load(local_x_block_ptr) # Nc x Hc
            A_v = tl.load(A_block_ptr) # Hc x Vc (tile of At)

            z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)

            local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])
            A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])

        m_new = tl.maximum(m, tl.max(z_j_to_k, 1))

        # Rescale the old sum to the new max before adding this tile's terms.
        s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)
        s = s * tl.exp(m - m_new) + s_update

        mask = y[:, None] == v_range[None, :] # Nc x Vc
        loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N

        m = m_new
        # Rewind the H offset consumed by the inner loop and step to the next V tile.
        A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])
        v_range = v_range + V_BLOCK_SIZE

    # tl.device_print("m", m)
    # tl.device_print("s", s)
    loss += tl.sum(m + tl.log(s)) / N  # logsumexp term of each row

    tl.atomic_add(loss_ptr, loss)
    tl.store(m_global_ptr + offsets, m)
    tl.store(s_global_ptr + offsets, s)

@torch.no_grad
def linear_xent_matmul_At(x, y, At):
    """Mean linear cross-entropy using the transposed weight via ``linear_xent_fwd_kernel_matmul_t``.

    Args:
        x: (N, H) inputs.
        y: (N,) integer targets.
        At: (H, V) transposed weight matrix (A.T).

    Returns:
        (loss, m_global, s_global): the (1,)-shaped float32 loss, and the per-row running max
        logit and sum of exp(logit - max) left by the kernel. The backward can reuse the last two.
    """
    N, H = x.shape
    H_A, V = At.shape  # At is A.T, i.e. (H, V)
    assert H_A == H
    assert y.shape == (N,)
    x = x.contiguous()
    y = y.contiguous()
    At = At.contiguous()

    # Online-logsumexp state. s starts as the empty sum 0, matching linear_xent_matmul. (Any start
    # value is multiplied by exp(-1e6 - m) ~ 0 on the first tile, but 0 states the intent.)
    m_global = -10e5 * torch.ones(N, dtype=torch.float32, device=x.device)
    s_global = torch.zeros(N, dtype=torch.float32, device=x.device)
    loss = torch.zeros(1, dtype=torch.float32, device=x.device)

    # One program per block of N_BLOCK_SIZE rows; the block size comes from the autotune config.
    grid = lambda meta: (triton.cdiv(N, meta['N_BLOCK_SIZE']), )

    with torch.cuda.device(x.device.index): # actually required
        linear_xent_fwd_kernel_matmul_t[grid](
                x,
                y,
                At,
                loss,
                m_global,
                s_global,
                x.stride(0), x.stride(1),
                At.stride(0), At.stride(1),
                V=V, N=N, H=H)
    # print(linear_xent_fwd_kernel_matmul_t.best_config)
    return loss, m_global, s_global

# Warm-up call (autotune / compile), then time one call with CUDA events.
loss, _, _ = linear_xent_matmul_At(x, y, At) # autotune
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
loss, m_global, s_global = linear_xent_matmul_At(x, y, At)
end_event.record()
torch.cuda.synchronize()
estimate_ms = start_event.elapsed_time(end_event)
print(f"Simple timing: {estimate_ms}ms")
# Separate statements: `print(a), print(b)` as the last line also displays the tuple (None, None).
print(loss)
print(torch.dist(loss, reference_loss).item())

In [None]:
@torch.no_grad
def manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2(x, y, A):
    x = x.cpu()
    A = A.cpu()
    y = y.cpu()


    x_chunks = x.view(-1, H).split(chunk_size)
    y_chunks = y.view(-1).split(chunk_size)
    num_chunks = len(y_chunks)
    m_global = torch.zeros(N, dtype=torch.float32)
    s_global = torch.zeros(N, dtype=torch.float32)
    losses = torch.zeros(num_chunks, dtype=torch.float32)

    A_chunks = A.split(V_chunk_size)
    
    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        m = -10e5 * torch.ones(chunk_size, dtype=torch.float32)
        s = torch.zeros(chunk_size, dtype=torch.float32)
        v_offsets = torch.arange(0, V_chunk_size)

        for v_idx, A_v in enumerate(A_chunks):
            m_prev = m.clone()
            z_j_to_k = (x_chunk @ A_v.T).to(dtype=torch.float32)
            m = torch.maximum(m_prev, z_j_to_k.max(dim=-1)[0])
            s = s * (m_prev - m).exp() + (z_j_to_k-  m[:, None]).exp().sum(dim=-1)

            mask = y_chunk[:, None] == v_offsets[None, :]
            losses[idx] -= torch.where(mask, z_j_to_k, 0.0).sum() / N

            v_offsets += V_chunk_size


        losses[idx] += (m + s.log()).sum() / N
        m_global[idx * chunk_size: (idx+1)*chunk_size] = m
        s_global[idx * chunk_size: (idx+1)*chunk_size] = s
    
    return losses.sum(), m_global, s_global

# CPU chunked reference: loss, plus per-row max (m_ref2) and sum of exps (s_ref2), which are reused by the backward checks below.
loss, m_ref2, s_ref2 = manual_implementation_chunked_online_lse_single_loop_bf16_blockV_2(x, y, A)
loss.item(), torch.dist(loss, reference_loss).item() # , f"{torch.dist(loss, reference_loss).item():2.4e}"

In [None]:
# Dense logits, then compare the A^T kernel's running max (m_global) with the true row max.
z = F.linear(x, A).view(-1, V).float()
torch.dist(m_global, z.max(dim=1)[0]), cosim(z.max(dim=1)[0], m_global)

In [None]:
# Compare the kernel's running sum with sum(exp(z - m_global)) recomputed densely.
torch.dist(s_global, (z - m_global[:, None]).exp().sum(dim=1)), cosim(s_global, (z - m_global[:, None]).exp().sum(dim=1))

In [None]:
# Mean ratios of true vs kernel statistics; values near 1 indicate agreement.
(z.max(dim=1)[0] / m_global).mean(), 1 / (s_global / (z - m_global[:, None]).exp().sum(dim=1)).mean()

In [None]:
# Same max / sum-of-exps checks for the CPU reference statistics.
torch.dist(m_ref2, z.cpu().max(dim=1)[0]), torch.dist(s_ref2, (z.cpu() - m_ref2[:, None]).exp().sum(dim=1))

In [None]:
# Chunk sizes for the CPU chunked backward reference (the same size along every axis).
N_chunk_size = V_chunk_size = H_chunk_size = 16
print(N_chunk_size, V_chunk_size, H_chunk_size)
print(N, V, H)

@torch.no_grad
def manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH_At(x, y, A, m_global, s_global):
    compute_dtype = torch.float32
    x = x.to(dtype=compute_dtype).cpu() # float16 is much more accurate for "sane" values of x and A
    At = A.T.to(dtype=compute_dtype).cpu()
    y = y.cpu()

    x_chunks = x.view(-1, H).split(N_chunk_size)
    y_chunks = y.view(-1).split(N_chunk_size)

    At_chunks = At.split(V_chunk_size, dim=1)
    v_offsets = torch.arange(V)

    Atgrad = torch.zeros_like(At, dtype=torch.float32)
    global_x_grad = torch.zeros_like(x.view(-1, H), dtype=torch.float32)

    for idx, (x_chunk, y_chunk) in enumerate(zip(x_chunks, y_chunks)):
        s = s_global[idx * N_chunk_size: (idx+1)*N_chunk_size]
        m = m_global[idx * N_chunk_size: (idx+1)*N_chunk_size]
        # xgrad = torch.zeros_like(x_chunk)
        Nc = x_chunk.shape[0]
        Vc = At_chunks[0].shape[1]

        for v_idx, A_v in enumerate(At_chunks): # can parallelize
            v_range = v_offsets[v_idx * V_chunk_size : (v_idx+1) * V_chunk_size]

            z_j_to_k = torch.zeros(Nc, Vc)
            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]

                z_j_to_k += (x_chunk_h @ A_chunk_h).to(dtype=torch.float32)
            
            softmax_z = ((z_j_to_k - m[:, None]).exp() / s[:, None]).to(dtype=compute_dtype)
            mask = (y_chunk[:, None] == v_range[None, :])[:,:,None] # needs to be N_BLOCK_SIZE x V_BLOCK_SIZE x 1 ?

            for h_idx in range(H // H_chunk_size):
                x_chunk_h = x_chunk[:, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size]
                A_chunk_h = A_v[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, :]


                xgrad_temp = (softmax_z @  A_chunk_h.T).to(dtype=torch.float32)
                xgrad_temp -= torch.where(mask, A_chunk_h.T[None, :, :], 0.0).sum(dim=1)
                global_x_grad[idx * N_chunk_size: (idx+1)*N_chunk_size, h_idx*H_chunk_size : (h_idx+1) * H_chunk_size] += xgrad_temp / N
                Agrad_temp = (softmax_z.T @ x_chunk_h).to(dtype=torch.float32)
                Agrad_temp -= torch.where(mask, x_chunk_h[:, None, :], 0.0).sum(dim=0)
                Atgrad[h_idx*H_chunk_size : (h_idx+1) * H_chunk_size, v_range] += Agrad_temp.T
               
    
    return (Atgrad / N).T, global_x_grad.view_as(x)

# CPU chunked backward with the CPU reference statistics, compared against the float64 autograd
# gradients (L2 distance and cosine similarity).
Agrad, xgrad = manual_implementation_chunked_online_lse_bwd_single_loop_bf16_blockV_blockH_At(x, y, A, m_ref2, s_ref2)
torch.dist(reference_x_grad.cpu(), xgrad), cosim(reference_x_grad.cpu(), xgrad), torch.dist(reference_A_grad.cpu(), Agrad), cosim(reference_A_grad.cpu(), Agrad)

In [None]:
# @triton.autotune(
#    configs=[
#     triton.Config({'V_BLOCK_SIZE': 16, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}, num_warps=1, num_ctas=1, num_stages=1),
#     # triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),
#     # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 256}),
#     # triton.Config({'V_BLOCK_SIZE': 512, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 512}),
#     # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),

#     # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 64}),
#     # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 256}),
#     # triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
#    ],
#    key=['V', 'N', 'H'],
#    reset_to_zero=["A_grad_ptr", "x_grad_ptr"]
# )
@triton.jit()
def linear_xent_bwd_kernel_matmul_t(x_ptr,
                y_ptr,
                A_t_ptr,
                m_global_ptr,
                s_global_ptr,
                A_grad_ptr,
                x_grad_ptr,
                locks_N_ptr,
                locks_V_ptr,
                stride_x_N, stride_x_H,
                stride_A_H, stride_A_V,
                stride_x_grad_N, stride_x_grad_H,
                stride_A_grad_H, stride_A_grad_V,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                V_BLOCK_SIZE: tl.constexpr = 16,
                N_BLOCK_SIZE: tl.constexpr = 16,
                H_BLOCK_SIZE: tl.constexpr = 16,
               ):
    """Backward of the mean linear cross-entropy for one (N-block, V-block) tile.

    Program (idx_N, idx_V) recomputes z = x[rows] @ At[:, cols] and forms
    softmax(z) = exp(z - m) / s from the forward's statistics. Per H tile, it accumulates
        x_grad[rows, h]  += ((softmax(z) - onehot(y)) @ At[h, cols].T) / N
        A_grad[h, cols]  += (x[rows, h].T @ (softmax(z) - onehot(y))).T / N   (stored as H x V)
    No masks or boundary checks are used, so N, V and H presumably must be multiples of the
    block sizes.

    Several programs update the same gradient tiles. Each read-modify-write is therefore
    guarded by a spin lock keyed by the tile being written:
      * x_grad rows belong to N-block idx_N and are shared by every idx_V  -> locks_N[idx_N]
      * A_grad cols belong to V-block idx_V and are shared by every idx_N  -> locks_V[idx_V]
    """
    idx_N = tl.program_id(axis=0)
    idx_V = tl.program_id(axis=1)

    offsets = idx_N * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    v_range = idx_V * V_BLOCK_SIZE + tl.arange(0, V_BLOCK_SIZE)

    y = tl.load(y_ptr + offsets)
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)

    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx_N * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )

    x_grad_block_ptr = tl.make_block_ptr(
        base=x_grad_ptr,
        shape=(N, H),
        strides=(stride_x_grad_N, stride_x_grad_H),
        offsets=(idx_N * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )

    A_block_ptr = tl.make_block_ptr(
        base=A_t_ptr,
        shape=(H, V),
        strides=(stride_A_H, stride_A_V),
        offsets=(0, idx_V * V_BLOCK_SIZE),
        block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),
        order=(1, 0),
    )

    A_grad_block_ptr = tl.make_block_ptr(
        base=A_grad_ptr,
        shape=(H, V),
        strides=(stride_A_grad_H, stride_A_grad_V),
        offsets=(0, idx_V * V_BLOCK_SIZE),
        block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),
        order=(1, 0),
    )

    # Pass 1: full logits tile, accumulated over H. Local pointers are advanced so that
    # x_block_ptr / A_block_ptr stay at H offset 0 for pass 2.
    z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)

    local_x_block_ptr = x_block_ptr
    local_A_block_ptr = A_block_ptr
    for _ in range(H // H_BLOCK_SIZE):
        x_chunk = tl.load(local_x_block_ptr) # Nc x Hc
        A_v = tl.load(local_A_block_ptr) # Hc x Vc

        z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)

        local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])
        local_A_block_ptr = tl.advance(local_A_block_ptr, [H_BLOCK_SIZE, 0])

    mask = (y[:, None] == v_range[None, :])[:, :, None] # N_BLOCK_SIZE x V_BLOCK_SIZE x 1
    # The softmax needs the complete dot product, hence the second pass over H.
    softmax_z = ((z_j_to_k - m[:, None]).exp() / s[:, None])

    # Pass 2: per-H-tile gradient contributions.
    for _ in range(H // H_BLOCK_SIZE):
        x_chunk = tl.load(x_block_ptr).to(tl.float32) # Nc x Hc
        A_v = tl.load(A_block_ptr).to(tl.float32) # Hc x Vc

        # xgrad: softmax(z) @ At_tile.T minus the target rows of A.
        temp_xgrad = tl.dot(softmax_z, A_v.trans())
        temp_xgrad -= tl.sum(tl.where(mask, A_v.trans()[None, :, :], 0.0), axis=1)

        # x_grad[rows, h-tile] is shared by all V-blocks of this N-block: lock on idx_N.
        while tl.atomic_cas(locks_N_ptr + idx_N, 0, 1) == 1:
            pass
        temp_xgrad = temp_xgrad / N + tl.load(x_grad_block_ptr)
        tl.store(x_grad_block_ptr, temp_xgrad)
        tl.atomic_xchg(locks_N_ptr + idx_N, 0)

        # Agrad: softmax(z).T @ x_tile minus x rows whose target falls in this V-block.
        temp_Agrad = tl.dot(softmax_z.trans(), x_chunk)
        temp_Agrad -= tl.sum(tl.where(mask, x_chunk[:, None, :], 0.0), axis=0)
        temp_Agrad = temp_Agrad.trans() # to H x V layout

        # A_grad[h-tile, cols] is shared by all N-blocks of this V-block: lock on idx_V.
        while tl.atomic_cas(locks_V_ptr + idx_V, 0, 1) == 1:
            pass
        temp_Agrad = temp_Agrad / N + tl.load(A_grad_block_ptr)
        tl.store(A_grad_block_ptr, temp_Agrad)
        tl.atomic_xchg(locks_V_ptr + idx_V, 0)

        x_block_ptr = tl.advance(x_block_ptr, [0, H_BLOCK_SIZE])
        x_grad_block_ptr = tl.advance(x_grad_block_ptr, [0, H_BLOCK_SIZE])

        A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])
        A_grad_block_ptr = tl.advance(A_grad_block_ptr, [H_BLOCK_SIZE, 0])
    

@torch.no_grad
def linear_xent_matmul_bwd(x, y, At, m_global, s_global,
                           N_BLOCK_SIZE=16, V_BLOCK_SIZE=16, H_BLOCK_SIZE=16):
    """Gradients of the mean linear cross-entropy via ``linear_xent_bwd_kernel_matmul_t``.

    Args:
        x: (N, H) inputs.
        y: (N,) integer targets.
        At: (H, V) transposed weight matrix.
        m_global, s_global: (N,) per-row max logit and sum of exp(logit - max) from the forward.
        N_BLOCK_SIZE, V_BLOCK_SIZE, H_BLOCK_SIZE: tile sizes passed to the kernel. The default
            16 matches the kernel's own defaults. The kernel has no masking, so each size must
            divide its dimension.

    Returns:
        (xgrad, Atgrad): float32 gradients of shapes (N, H) and (H, V).
    """
    N, H = x.shape
    H_A, V = At.shape # V, H_A = A.shape
    assert H_A == H
    assert y.shape == (N,)
    # The kernel uses block pointers without boundary checks; partial tiles would go out of bounds.
    assert N % N_BLOCK_SIZE == 0, f"N is {N}"
    assert V % V_BLOCK_SIZE == 0, f"V is {V}"
    assert H % H_BLOCK_SIZE == 0, f"H is {H}"
    x = x.contiguous()
    y = y.contiguous()
    At = At.contiguous()

    xgrad = torch.zeros_like(x, dtype=torch.float32)
    Atgrad = torch.zeros_like(At, dtype=torch.float32)

    n_blocks = triton.cdiv(N, N_BLOCK_SIZE)
    v_blocks = triton.cdiv(V, V_BLOCK_SIZE)
    grid = (n_blocks, v_blocks)
    # One spin lock per N-block and per V-block, sized from the same block sizes the kernel gets.
    locks_N = torch.zeros(n_blocks, dtype=torch.int32, device=x.device)
    locks_V = torch.zeros(v_blocks, dtype=torch.int32, device=x.device)

    with torch.cuda.device(x.device.index): # actually required
        linear_xent_bwd_kernel_matmul_t[grid](
                x,
                y,
                At,
                m_global.contiguous(),
                s_global.contiguous(),
                Atgrad,
                xgrad,
                locks_N, locks_V,
                x.stride(0), x.stride(1),
                At.stride(0), At.stride(1),
                xgrad.stride(0), xgrad.stride(1),
                Atgrad.stride(0), Atgrad.stride(1),
                V=V, N=N, H=H,
                V_BLOCK_SIZE=V_BLOCK_SIZE, N_BLOCK_SIZE=N_BLOCK_SIZE, H_BLOCK_SIZE=H_BLOCK_SIZE)
    return xgrad, Atgrad

# Triton backward with the CPU reference statistics moved to the GPU. Atgrad is (H, V), so it is
# transposed before comparing with the (V, H) reference.
xgrad, Atgrad = linear_xent_matmul_bwd(x, y, At, m_ref2.to(device), s_ref2.to(device))
torch.dist(reference_x_grad, xgrad), cosim(reference_x_grad, xgrad), torch.dist(reference_A_grad, Atgrad.T), cosim(reference_A_grad, Atgrad.T)


In [None]:
xgrad.device  # xgrad is allocated with zeros_like(x), so it lives on x's device

In [None]:
xgrad.min(), reference_x_grad.min(), torch.dist(reference_x_grad, xgrad)  # spot check vs the float64 reference

In [None]:
# Row-at-a-time fused forward, autotuned over N_BLOCK_SIZE only.
# restore_value: buffers the autotuner saves and restores around its benchmark runs (Triton
# autotune API). This matters here because m_global starts at -1e6, not 0.
@triton.autotune(
   configs=[
    # triton.Config({'N_BLOCK_SIZE': 1}),
    triton.Config({'N_BLOCK_SIZE': 2}),
    triton.Config({'N_BLOCK_SIZE': 4}),
    triton.Config({'N_BLOCK_SIZE': 8}),
    triton.Config({'N_BLOCK_SIZE': 16}),
    # triton.Config({'N_BLOCK_SIZE': 32}),
    # triton.Config({'N_BLOCK_SIZE': 64}),
    # triton.Config({'N_BLOCK_SIZE': 128}),
    # triton.Config({'N_BLOCK_SIZE': 256}),
    # triton.Config({'N_BLOCK_SIZE': 512}),

   ],
   key=['N'],
   restore_value=["loss_ptr", "m_global_ptr", "s_global_ptr"]
)
@triton.jit
def linear_xent_fwd_kernel(x_ptr,
                y_ptr,
                A_ptr,
                loss_ptr,  
                m_global_ptr,
                s_global_ptr,
                stride_x_N, stride_x_H,
                stride_A_V, stride_A_H,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                N_BLOCK_SIZE: tl.constexpr,
               ):
    """Cross-entropy forward that visits one vocabulary row of A per loop iteration.

    Each program loads its (N_BLOCK_SIZE x H) tile of x once. Then, for every v in [0, V), it
    computes z[:, v] = sum_h x[:, h] * A[v, h] (product in input precision, summed in
    float32), updates the running max m and the running sum s of exp(z - m), and subtracts
    the target logit. The partial loss (sum(-z_target) + sum(m + log s)) / N is atomically
    added to loss_ptr, and m / s are written back.

    The full H is one block dimension, and Triton block shapes are generally powers of two,
    so H presumably must be one. Row loads are unmasked, so N presumably must be a multiple
    of N_BLOCK_SIZE.
    """
    idx = tl.program_id(axis=0)  
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H),
        order=(1, 0),
    )
    A_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(V, H),
        strides=(stride_A_V, stride_A_H),
        offsets=(0, 0),
        block_shape=(1, H),
        order=(1, 0),
    )
    offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    y = tl.load(y_ptr + offsets)
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)
    # log2_const = 1.4426950408889634 not precise enough with exp2 
    loss = 0.  # sum of -z_target over this block (divided by N at the end)

    x_chunk = tl.load(x_block_ptr) # Nc x H 

    for v in range(V):
        A_v = tl.load(A_block_ptr) # 1 x H (one vocabulary row)
        z_j = tl.sum((x_chunk * A_v).to(tl.float32), axis=1) # (Nc x H) @ (H x 1)

        m_new = tl.maximum(m, z_j)
        # Rescale the old sum to the new max before adding exp(z_j - m_new).
        s = s * tl.exp(m - m_new) + tl.exp(z_j - m_new)
        loss -= tl.sum(tl.where(y == v, z_j, 0.))

        m = m_new
        A_block_ptr = tl.advance(A_block_ptr, [1, 0])

    loss = (loss + tl.sum(m + tl.log(s))) / N

    tl.atomic_add(loss_ptr, loss)
    tl.store(m_global_ptr + offsets, m)
    tl.store(s_global_ptr + offsets, s)


# loss_triton = linear_cross_entropy(x, y, A) # type: ignore
# loss_triton, torch.dist(reference_loss, loss_triton).item()


In [None]:
@triton.jit
def fwd_kernel(
    x_ND_ptr,
    w_DV_ptr,
    c_N_ptr,
    output_N_ptr,
    l_N_ptr,
    N, D, V,
    stride_xn, stride_xd,
    stride_wd, stride_wv,
    # Meta-parameters
    BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_V: tl.constexpr,
):
    """Per-row target log-probability: output_N[n] = log_softmax(x_ND @ w_DV)[n, c_N[n]].

    Each program handles BLOCK_N rows and streams over V in BLOCK_V tiles with an online
    softmax (running max m, running normalizer l). Tails are handled explicitly:
      * x / w tiles are loaded with boundary checks and zero padding (N, D and V tails);
      * logits of padded vocabulary columns (>= V) are forced to -inf, so they add nothing
        to the normalizer or the max;
      * the target load and the final stores are masked to rows < N.
    l_N receives the final normalizer relative to the final running max.
    """
    # TODO: more parallelism, e.g. tiled softmax
    #       w/ parallelization across tiles (intra-tile computation uses online softmax)
    # only parallelize along the N dimension
    # i is the same as n
    i = tl.program_id(axis=0)
    offs_n_bN = i * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n_bN = offs_n_bN < N
    c_i_bN = tl.load(c_N_ptr + offs_n_bN, mask=mask_n_bN, other=0)
    output_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32)

    # statistics for online softmax
    m_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) - float('inf')
    l_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) + 1.0
    for start_v in range(0, V, BLOCK_V):
        start_v = tl.multiple_of(start_v, BLOCK_V)
        offs_v_bV = start_v + tl.arange(0, BLOCK_V)
        x_ND_block_ptr = tl.make_block_ptr(
            base=x_ND_ptr,
            shape=(N, D),
            strides=(stride_xn, stride_xd),
            offsets=(i * BLOCK_N, 0),
            block_shape=(BLOCK_N, BLOCK_D),
            order=(1, 0),
        )
        w_DV_block_ptr = tl.make_block_ptr(
            base=w_DV_ptr,
            shape=(D, V),
            strides=(stride_wd, stride_wv),
            offsets=(0, start_v),
            block_shape=(BLOCK_D, BLOCK_V),
            order=(1, 0),
        )
        xw_bNbV = tl.zeros([BLOCK_N, BLOCK_V], dtype=tl.float32)
        for _ in range(0, D, BLOCK_D):
            # TODO: x load can be reduced?
            # Zero padding: out-of-range D entries contribute 0 to the dot product.
            x_bNbD = tl.load(x_ND_block_ptr, boundary_check=(0, 1), padding_option="zero")
            w_bDbV = tl.load(w_DV_block_ptr, boundary_check=(0, 1), padding_option="zero")
            xw_bNbV = tl.dot(x_bNbD, w_bDbV, xw_bNbV)
            x_ND_block_ptr = tl.advance(x_ND_block_ptr, (0, BLOCK_D))
            w_DV_block_ptr = tl.advance(w_DV_block_ptr, (BLOCK_D, 0))

        # Padded vocabulary columns would otherwise enter the softmax as logit 0.
        xw_bNbV = tl.where(offs_v_bV[None, :] < V, xw_bNbV, float('-inf'))

        # i for N
        # j for V
        m_ij_bN = tl.maximum(m_i_bN, tl.max(xw_bNbV, axis=1))
        p_ij_bNbV = tl.exp(xw_bNbV - m_ij_bN[:, None])
        l_ij_bN = tl.sum(p_ij_bNbV, axis=1)
        # update m_i and l_i (alpha rescales the old statistics to the new max)
        alpha_bN = tl.exp(m_i_bN - m_ij_bN)
        l_i_bN = l_i_bN * alpha_bN + l_ij_bN
        m_i_bN = m_ij_bN
        # update output: unnormalized probability of the target class
        p_ic_bN = tl.sum(tl.where(c_i_bN[:, None] == offs_v_bV[None, :], p_ij_bNbV, 0.0), axis=1)
        output_i_bN = output_i_bN * alpha_bN + p_ic_bN

    output_i_bN = tl.log(output_i_bN) - tl.log(l_i_bN)
    tl.store(output_N_ptr + offs_n_bN, output_i_bN, mask=mask_n_bN)
    tl.store(l_N_ptr + offs_n_bN, l_i_bN, mask=mask_n_bN)


# output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
class LMHeadThenLogSoftmaxThenGather(torch.autograd.Function):
    """Autograd wrapper around ``fwd_kernel``: per-row target log-probabilities of ``x_ND @ w_DV``.

    Only the forward is implemented; the backward raises NotImplementedError.
    """

    @staticmethod
    def forward(ctx, x_ND: torch.Tensor, w_DV: torch.Tensor, c_N: torch.Tensor):
        """Return a float32 (N,) tensor with log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]].

        Args:
            x_ND: (N, D) inputs.
            w_DV: (D, V) weight (the transposed LM head).
            c_N: (N,) integer target classes.
        """
        N, D = x_ND.shape
        Dw, V = w_DV.shape
        Nc, = c_N.shape
        assert D == Dw and N == Nc
        c_N = c_N.contiguous()  # the kernel indexes c_N with unit stride
        # The kernel computes in float32; storing into x's (e.g. bf16) dtype would drop precision.
        output_N = x_ND.new_empty(N, dtype=torch.float32)
        l_N = x_ND.new_empty(N, dtype=torch.float32)
        BLOCK_N = 32
        BLOCK_D = 32
        BLOCK_V = 32
        grid = (triton.cdiv(N, BLOCK_N),)
        fwd_kernel[grid](
            x_ND, w_DV, c_N,
            output_N, l_N,
            N, D, V,
            x_ND.stride(0), x_ND.stride(1),
            w_DV.stride(0), w_DV.stride(1),
            BLOCK_N, BLOCK_D, BLOCK_V,
        )
        ctx.save_for_backward(w_DV, x_ND, c_N, l_N)
        return output_N

    @staticmethod
    def backward(ctx, g_NV):
        # TODO: implement. Fail loudly rather than hand autograd an implicit None.
        raise NotImplementedError("LMHeadThenLogSoftmaxThenGather.backward is not implemented")

def YouJiacheng_linear_xent(x, y, At):
    """Mean cross-entropy of the logits ``x @ At`` against targets ``y``.

    ``LMHeadThenLogSoftmaxThenGather`` returns per-row log-probabilities of the target class, so
    the loss is their negated mean. This matches ``baseline_torch`` (``F.cross_entropy`` with the
    default mean reduction).
    """
    return -LMHeadThenLogSoftmaxThenGather.apply(x, At, y).mean()


# Smoke test of the YouJiacheng-style Triton forward (its autograd backward is not implemented).
loss = YouJiacheng_linear_xent(x, y, At)
print(loss)

In [None]:
class LinearCrossEntropyLoss(torch.autograd.Function):
    """Mean cross-entropy of the logits ``x @ A.T`` against ``y``, with a Triton forward.

    The forward runs ``linear_xent_fwd_kernel`` (an online logsumexp; the logits are never
    stored) and keeps each row's running max ``m`` and the sum ``s`` of ``exp(z - m)``. The
    backward reuses them to form ``softmax(z) = exp(z - m) / s``. It recomputes the logits
    ``_BWD_ROW_CHUNK`` rows at a time, so the largest temporary is (_BWD_ROW_CHUNK, V), not (N, V).
    """

    # Rows of logits recomputed per step in backward.
    _BWD_ROW_CHUNK = 1024

    @staticmethod
    def forward(
        ctx,
        x,
        y,
        A,
        ignore_index=-100, # NOTE: not honored. The kernel treats out-of-range targets as matching no column but still counts the row.
    ):
        """Return the (1,)-shaped float32 mean loss for x (N, H), y (N,) and A (V, H)."""
        N, H = x.shape
        V, H_A = A.shape
        assert H_A == H
        assert y.shape == (N,)
        x = x.contiguous()
        y = y.contiguous()
        A = A.contiguous()

        # Every autotune config uses an N_BLOCK_SIZE that divides 16, and the kernel does not mask rows.
        assert N % 16 == 0

        m_global = -10e5 * torch.ones(N, dtype=torch.float32, device=x.device)
        s_global = torch.zeros(N, dtype=torch.float32, device=x.device)
        loss = torch.zeros(1, dtype=torch.float32, device=x.device)

        grid = lambda meta: (triton.cdiv(N, meta['N_BLOCK_SIZE']), )
        with torch.cuda.device(x.device.index): # actually required
            linear_xent_fwd_kernel[grid](
                    x,
                    y,
                    A,
                    loss,
                    m_global,
                    s_global,
                    x.stride(0), x.stride(1),
                    A.stride(0), A.stride(1),
                    V=V, N=N, H=H,
                )

        # x, y and A are needed to recompute the logits; m and s give the softmax without another LSE pass.
        ctx.save_for_backward(x, y, A, m_global, s_global)
        return loss

    @staticmethod
    def backward(ctx, losses):
        """Gradients of the mean loss with respect to ``x`` and ``A``.

        Args:
            losses: the upstream gradient of the (1,)-shaped loss.

        Returns:
            (grad_x, None, grad_A, None). ``y`` and ``ignore_index`` get no gradient; autograd
            accepts the trailing None even when ``ignore_index`` was not passed.
        """
        x, y, A, m_global, s_global = ctx.saved_tensors
        N, V = x.shape[0], A.shape[0]
        A32 = A.float()
        grad_x = torch.empty(x.shape, dtype=torch.float32, device=x.device)
        grad_A = torch.zeros(A.shape, dtype=torch.float32, device=A.device)
        # d(mean loss)/dz = (softmax(z) - onehot(y)) / N, scaled by the upstream gradient.
        scale = losses.float().reshape(-1) / N

        chunk = LinearCrossEntropyLoss._BWD_ROW_CHUNK
        for n0 in range(0, N, chunk):
            n1 = min(n0 + chunk, N)
            x_c = x[n0:n1].float()
            probs = torch.exp(x_c @ A32.T - m_global[n0:n1, None]) / s_global[n0:n1, None]
            # Subtract the one-hot target. Out-of-range (e.g. negative) targets match no column,
            # mirroring the forward kernel's `y == v` test.
            tgt = y[n0:n1]
            hit = (tgt >= 0) & (tgt < V)
            rows = torch.arange(n1 - n0, device=x.device)
            probs[rows[hit], tgt[hit]] -= 1.0
            probs *= scale
            grad_x[n0:n1] = probs @ A32
            grad_A += probs.T @ x_c

        grad_x_out = grad_x.to(x.dtype) if ctx.needs_input_grad[0] else None
        grad_A_out = grad_A.to(A.dtype) if ctx.needs_input_grad[2] else None
        return grad_x_out, None, grad_A_out, None

def linear_cross_entropy(x, y, A):
    """Mean cross-entropy of ``x @ A.T`` against ``y`` via ``LinearCrossEntropyLoss``."""
    return LinearCrossEntropyLoss.apply(x, y, A)

# Smoke test of the autotuned row-at-a-time forward against the float64 reference loss.
loss_triton = linear_cross_entropy(x, y, A) # type: ignore
loss_triton, torch.dist(reference_loss, loss_triton).item()

In [None]:
def benchmark_with_memory_reporting(func, quantiles, *args, **kwargs):
    """Time ``func(*args, **kwargs)`` with ``triton.testing.do_bench`` and report peak extra CUDA memory.

    Returns:
        (ms, min_ms, max_ms, memory_used): the three timings do_bench returns for ``quantiles``,
        in that order, and the peak allocated bytes minus the bytes allocated before the run.
    """
    def run():
        return func(*args, **kwargs)

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device=device)
    baseline_bytes = torch.cuda.memory_allocated(device=device)

    ms, min_ms, max_ms = triton.testing.do_bench(run, quantiles=quantiles)

    torch.cuda.synchronize()
    peak_bytes = torch.cuda.max_memory_allocated(device=device)
    return ms, min_ms, max_ms, peak_bytes - baseline_bytes

In [None]:
# make space for benchmarking
del A
del At
del x
del y
del reference_loss
torch.cuda.empty_cache()

In [None]:
f = 1 # number from 1 to 512 to get fast results: the benchmark divides S, V and H by f

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['H'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(9, 14, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t', 'triton-variant'],  # Possible values for `line_arg`.
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t', 'triton-variant'],  # Label name for the lines.
        # styles=[('blue', '-'), ('green', '-'), ('red', '-'), ('brown', '-')],  # Line styles.
        ylabel='TFLOP/s',  # Label name for the y-axis.
        plot_name='Linear+Loss Performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(H, provider):
    """Throughput of one loss implementation at hidden size ``H``.

    The result is TFLOP/s of the x @ A.T matmul (2*N*H*V flops). Returns (median, lower, upper),
    where the bounds use the 80th / 20th percentile times.

    Raises:
        ValueError: if ``provider`` is not a known implementation.
    """
    B, S, V = 4, 4096 // f, 131072 // f
    N = B * S
    H = H // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16) # B S H
    y = torch.randint(0, V, (N,), device=device) # vocab ** B S
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()

    quantiles = [0.5, 0.2, 0.8]
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        'triton': lambda: linear_cross_entropy(x, y, A),
        'triton2': lambda: linear_xent_matmul(x, y, A),
        'triton2-t': lambda: linear_xent_matmul_At(x, y, At),
        'triton-variant': lambda: YouJiacheng_linear_xent(x, y, At),
    }
    if provider not in runners:
        # Without this check an unknown provider surfaced later as an UnboundLocalError on `ms`.
        raise ValueError(f"unknown provider: {provider!r}")
    ms, min_ms, max_ms = triton.testing.do_bench(runners[provider], quantiles=quantiles)

    perf = lambda ms: 2 * N * H * V * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # sweep the vocabulary size
        x_vals=[2**i for i in range(10, 18, 1)],  # V = 2**10 .. 2**17 (divided by f inside `benchmark`)
        x_log=True,  # log-scaled x axis
        line_arg='provider',  # `benchmark` argument that selects the implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # one curve per provider key
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # legend labels
        ylabel='TFLOP/s',  # y-axis label
        plot_name='Linear+Loss Performance',  # plot title (also the file name if saved)
        args={},  # no extra fixed keyword arguments
    ))
def benchmark(V, provider):
    """Forward throughput (TFLOP/s) of linear + cross-entropy while sweeping vocabulary size V.

    Times the selected provider with ``triton.testing.do_bench`` and converts the
    ``2*N*H*V`` FLOPs of the logits matmul into TFLOP/s, returned as
    ``(median, lower, upper)``.
    """
    B, S, H = 4, 4096 // f, 4096 // f
    N = B * S
    V = V // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # (N,) target ids in [0, V)
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (V, H)
    At = A.clone().T.contiguous()  # contiguous (H, V) copy for the transposed provider

    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        'triton': lambda: linear_cross_entropy(x, y, A),
        'triton2': lambda: linear_xent_matmul(x, y, A),
        'triton2-t': lambda: linear_xent_matmul_At(x, y, At),
    }
    if provider in runners:
        # Quantiles come back in the requested order: median, 20th, 80th percentile.
        ms, min_ms, max_ms = triton.testing.do_bench(runners[provider], quantiles=[0.5, 0.2, 0.8])

    def tflops(elapsed_ms):
        return 2 * N * H * V * 1e-12 / (elapsed_ms * 1e-3)

    # A longer time means lower throughput, so max_ms yields the lower bound.
    return tflops(ms), tflops(max_ms), tflops(min_ms)

benchmark.run(print_data=True, show_plots=True)


In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # sweep the number of tokens
        x_vals=[2**i for i in range(9, 15, 1)],  # N = 2**9 .. 2**14 (divided by f inside `benchmark`)
        x_log=True,  # log-scaled x axis
        line_arg='provider',  # `benchmark` argument that selects the implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # one curve per provider key
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # legend labels
        ylabel='TFLOP/s',  # y-axis label
        plot_name='Linear+Loss Performance',  # plot title (also the file name if saved)
        args={},  # no extra fixed keyword arguments
    ))
def benchmark(N, provider):
    """Forward throughput (TFLOP/s) of linear + cross-entropy while sweeping token count N.

    Times the selected provider with ``triton.testing.do_bench`` and converts the
    ``2*N*H*V`` FLOPs of the logits matmul into TFLOP/s, returned as
    ``(median, lower, upper)``.
    """
    H, V = 4096 // f, 131072 // f
    N = N // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # (N,) target ids in [0, V)
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (V, H)
    At = A.clone().T.contiguous()  # contiguous (H, V) copy for the transposed provider

    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        'triton': lambda: linear_cross_entropy(x, y, A),
        'triton2': lambda: linear_xent_matmul(x, y, A),
        'triton2-t': lambda: linear_xent_matmul_At(x, y, At),
    }
    if provider in runners:
        # Quantiles come back in the requested order: median, 20th, 80th percentile.
        ms, min_ms, max_ms = triton.testing.do_bench(runners[provider], quantiles=[0.5, 0.2, 0.8])

    def tflops(elapsed_ms):
        return 2 * N * H * V * 1e-12 / (elapsed_ms * 1e-3)

    # A longer time means lower throughput, so max_ms yields the lower bound.
    return tflops(ms), tflops(max_ms), tflops(min_ms)

benchmark.run(print_data=True, show_plots=True)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # sweep the number of tokens
        x_vals=[2**i for i in range(9, 15, 1)],  # N = 2**9 .. 2**14 (divided by f inside `benchmark`)
        x_log=True,  # log-scaled x axis
        line_arg='provider',  # `benchmark` argument that selects the implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # one curve per provider key
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # legend labels
        ylabel='GBs of Memory',  # y-axis label (values are computed in GiB, i.e. / 1024**3)
        plot_name='Linear+Loss Performance',  # plot title (also the file name if saved)
        args={},  # no extra fixed keyword arguments
    ))
def benchmark(N, provider):
    """Extra GPU memory (GiB) reported by ``benchmark_with_memory_reporting`` while sweeping N.

    NOTE(review): unlike the memory-vs-V cell, this one is not wrapped in
    ``torch.no_grad()``, so the two memory plots may not be directly comparable --
    confirm which autograd mode is intended.
    """
    H, V = 4096 // f, 131072 // f
    N = N // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # (N,) target ids in [0, V)
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (V, H)
    At = A.clone().T.contiguous()  # contiguous (H, V) copy for the transposed provider

    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        'triton': lambda: linear_cross_entropy(x, y, A),
        'triton2': lambda: linear_xent_matmul(x, y, A),
        'triton2-t': lambda: linear_xent_matmul_At(x, y, At),
    }
    if provider in runners:
        # Only the memory figure is plotted; the timings are discarded.
        *_, memory_used = benchmark_with_memory_reporting(runners[provider], quantiles=[0.5, 0.2, 0.8])

    # The two trailing zeros are placeholder min/max values for perf_report.
    return memory_used / 1024**3, 0, 0

benchmark.run(print_data=True, show_plots=True)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # sweep the vocabulary size
        x_vals=[2**i for i in range(10, 19, 1)],  # V = 2**10 .. 2**18 (divided by f inside `benchmark`)
        x_log=True,  # log-scaled x axis
        line_arg='provider',  # `benchmark` argument that selects the implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # one curve per provider key
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # legend labels
        ylabel='GBs of Memory',  # y-axis label (values are computed in GiB, i.e. / 1024**3)
        plot_name='Linear+Loss Performance',  # plot title (also the file name if saved)
        args={},  # no extra fixed keyword arguments
    ))
@torch.no_grad()
def benchmark(V, provider):
    """Extra GPU memory (GiB) reported by ``benchmark_with_memory_reporting`` while sweeping V.

    Runs under ``torch.no_grad()``. NOTE(review): the memory-vs-N cell does not,
    so the two memory plots may not be directly comparable -- confirm which
    autograd mode is intended.
    """
    B, S, H = 4, 4096 // f, 4096 // f
    N = B * S
    V = V // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # (N,) target ids in [0, V)
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (V, H)
    At = A.clone().T.contiguous()  # contiguous (H, V) copy for the transposed provider

    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        'triton': lambda: linear_cross_entropy(x, y, A),
        'triton2': lambda: linear_xent_matmul(x, y, A),
        'triton2-t': lambda: linear_xent_matmul_At(x, y, At),
    }
    if provider in runners:
        # Only the memory figure is plotted; the timings are discarded.
        *_, memory_used = benchmark_with_memory_reporting(runners[provider], quantiles=[0.5, 0.2, 0.8])

    # The two trailing zeros are placeholder min/max values for perf_report.
    return memory_used / 1024**3, 0, 0

benchmark.run(print_data=True, show_plots=True)

In [None]:
# Make a freshly created (non-default) CUDA stream current for the rest of the session.
# NOTE(review): presumably done so the next cell's triton.testing.do_bench_cudagraph
# does not try to capture CUDA graphs on the default stream -- confirm it is still needed.
torch.cuda.set_stream(stream=torch.cuda.Stream())

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(9, 15, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # Possible values for `line_arg`.
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],  # Label name for the lines.
        ylabel='TFLOP/s',  # Label name for the y-axis.
        plot_name='Linear+Loss Performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(N, provider):
    """Forward throughput (TFLOP/s) vs token count N, timed with CUDA-graph replay.

    Same sweep as the plain ``do_bench`` N benchmark, but each provider is timed
    with ``triton.testing.do_bench_cudagraph``. Each provider is now timed exactly
    once per point; previously the 'triton2' branch was duplicated, so it ran twice
    and the first measurement was discarded.

    Raises:
        ValueError: if ``provider`` is not one of the handled keys (previously
            this surfaced as an UnboundLocalError on ``ms``).
    """
    H, V = 4096 // f, 131072 // f
    N = N // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # (N,) target ids in [0, V)
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (V, H)
    At = A.clone().T.contiguous()  # contiguous (H, V) copy for the transposed provider

    if provider == 'torch':
        ms = triton.testing.do_bench_cudagraph(lambda: baseline_torch(x, y, A))
    elif provider == 'torch-compile':
        ms = triton.testing.do_bench_cudagraph(lambda: compiled_baseline(x, y, A))
    elif provider == 'torch-maxauto':
        ms = triton.testing.do_bench_cudagraph(lambda: maxauto_baseline(x, y, A))
    elif provider == "triton":
        ms = triton.testing.do_bench_cudagraph(lambda: linear_cross_entropy(x, y, A))
    elif provider == "triton2":
        ms = triton.testing.do_bench_cudagraph(lambda: linear_xent_matmul(x, y, A))
    elif provider == "triton2-t":
        ms = triton.testing.do_bench_cudagraph(lambda: linear_xent_matmul_At(x, y, At))
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    # FLOPs of the (N, H) x (H, V) logits matmul only.
    perf = lambda ms: 2 * N * H * V * 1e-12 / (ms * 1e-3)
    return perf(ms)

benchmark.run(print_data=True, show_plots=True)