In [None]:

import os
# os.environ["TRITON_PRINT_AUTOTUNING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"]="6"
import time
from datetime import timedelta

import torch
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F

import triton
import triton.language as tl

import pdb


In [None]:
# CUDA_VISIBLE_DEVICES (set in the import cell) exposes only the selected GPU, as index 0.
device = torch.device("cuda", 0)
torch.cuda.device_count()  # sanity check: should display 1 when that GPU exists

In [None]:
def baseline_torch(x, y, A):
    """Reference loss: mean cross-entropy of the logits ``F.linear(x, A)`` against ``y``.

    The full logit matrix is materialised, flattened to (-1, V) with V = A.shape[0],
    and upcast to float32 before ``F.cross_entropy`` (default mean reduction).
    """
    vocab_size = A.shape[0]
    logits = F.linear(x, A).view(-1, vocab_size).float()
    targets = y.view(-1)
    return F.cross_entropy(logits, targets)


# torch.compile only wraps here; actual compilation happens lazily on the first call.
compiled_baseline = torch.compile(baseline_torch)
maxauto_baseline = torch.compile(baseline_torch, fullgraph=True, mode="max-autotune")

In [None]:
def _prune_matmul_configs(configs, named_args, **kwargs):
    """Autotune early-prune hook: keep only configs whose block sizes divide (V, N, H).

    ``linear_xent_fwd_kernel_matmul`` has no masking. Its loops run ``V // V_BLOCK_SIZE``
    and ``H // H_BLOCK_SIZE`` times, so a non-dividing V or H block silently drops the
    tail of the reduction. A non-dividing N block makes the last program load and store
    past the end of y / m_global / s_global. Skipping work is faster, so without this
    filter the autotuner could *select* such a config.

    Depending on the Triton version, keyword launch arguments (V, N, H here) may or may
    not be forwarded in ``kwargs``. When they are not available, the list is returned
    unchanged, which was the previous behaviour.
    """
    shape = {**named_args, **kwargs}
    if not all(name in shape for name in ("V", "N", "H")):
        return configs
    V, N, H = shape["V"], shape["N"], shape["H"]
    kept = [
        cfg for cfg in configs
        if V % cfg.kwargs["V_BLOCK_SIZE"] == 0
        and N % cfg.kwargs["N_BLOCK_SIZE"] == 0
        and H % cfg.kwargs["H_BLOCK_SIZE"] == 0
    ]
    # The launcher asserts V % 256, N % 64 and H % 64 == 0. The 64/64/64 config always
    # satisfies those, so `kept` is non-empty for every shape the launcher accepts.
    return kept or configs


@triton.autotune(
   configs=[
    triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 64}),
    triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 256}),
    triton.Config({'V_BLOCK_SIZE': 512, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 512}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),

    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 64}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 256}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
   ],
   key=['V', 'N', 'H'],
   # Drop configs whose tiles do not evenly divide the problem (the kernel has no masking).
   prune_configs_by={
       'early_config_prune': _prune_matmul_configs,
       'perf_model': None,
       'top_k': None,
   },
   # Snapshot/restore instead of zeroing. The kernel *reads* m_global, which must start
   # at the launcher's -1e6 sentinel rather than 0, and it accumulates into loss. Every
   # benchmarking run and the real launch must therefore see the launcher's initial values.
   restore_value=["loss_ptr", "m_global_ptr", "s_global_ptr"],
)
@triton.jit
def linear_xent_fwd_kernel_matmul(x_ptr,
                y_ptr,
                A_ptr,
                loss_ptr,
                m_global_ptr,
                s_global_ptr,
                stride_x_N, stride_x_H,
                stride_A_V, stride_A_H,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                V_BLOCK_SIZE: tl.constexpr,
                N_BLOCK_SIZE: tl.constexpr,
                H_BLOCK_SIZE: tl.constexpr,
               ):
    """Fused forward of the mean cross-entropy of logits ``x @ A.T``, with A stored as (V, H).

    Each program owns N_BLOCK_SIZE rows of x and sweeps the vocabulary in V_BLOCK_SIZE
    tiles. For each tile it builds the (Nc x Vc) logit tile with tl.dot, accumulated in
    fp32 over H_BLOCK_SIZE chunks. Per row it keeps an online-softmax running max ``m``
    and a rescaled sum ``s``, so the N x V logits are never materialised.

    The program's contribution ``sum_rows(m + log(s) - z[y]) / N`` is atomically added
    to ``loss_ptr``. The final per-row ``m`` and ``s`` are written back to m_global /
    s_global.

    Block sizes must divide V, N and H (no masking).
    """
    idx = tl.program_id(axis=0)
    # (Nc x Hc) tile over this program's rows of x; advanced along H in the inner loop.
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )
    # (Vc x Hc) tile of A; advanced along H, then rewound and stepped along V.
    A_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(V, H),
        strides=(stride_A_V, stride_A_H),
        offsets=(0, 0),
        block_shape=(V_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )
    offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    v_range = tl.arange(0, V_BLOCK_SIZE)  # vocab ids covered by the current V tile
    y = tl.load(y_ptr + offsets)

    # Online-softmax state per row: running max m and running sum s of exp(z - m).
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)
    loss = 0.

    for _ in range(V // V_BLOCK_SIZE):
        # Logit tile z[n, v] = x[n, :] . A[v, :], accumulated in fp32 over the H tiles.
        z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)
        local_x_block_ptr = x_block_ptr
        for _ in range(H // H_BLOCK_SIZE):
            x_chunk = tl.load(local_x_block_ptr) # Nc x Hc
            A_v = tl.load(A_block_ptr) # Vc x Hc

            z_j_to_k = tl.dot(x_chunk, A_v.trans(), z_j_to_k) # (Nc x Hc) @ (Hc x Vc)

            local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])
            A_block_ptr = tl.advance(A_block_ptr, [0, H_BLOCK_SIZE])

        # Rescale the running sum to the new max before adding this tile's terms.
        m_new = tl.maximum(m, tl.max(z_j_to_k, 1))

        s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)
        s = s * tl.exp(m - m_new) + s_update

        # Subtract the target logit z[n, y[n]]; for 0 <= y < V it lies in exactly one tile.
        mask = y[:, None] == v_range[None, :] # Nc x Vc
        loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N

        m = m_new
        # Rewind A along H and move to the next V tile.
        A_block_ptr = tl.advance(A_block_ptr, [V_BLOCK_SIZE, -H_BLOCK_SIZE * (H // H_BLOCK_SIZE)])
        v_range = v_range + V_BLOCK_SIZE

    # Add the per-row log-sum-exp, m + log(s).
    loss += tl.sum(m + tl.log(s)) / N

    # Float atomic add: partial losses from programs are summed in nondeterministic order.
    tl.atomic_add(loss_ptr, loss)
    tl.store(m_global_ptr + offsets, m)
    tl.store(s_global_ptr + offsets, s)

def linear_xent_matmul(x, y, A):
    """Mean cross-entropy of ``x @ A.T`` against ``y``, computed without materialising logits.

    Args:
        x: (N, H) activations.
        y: (N,) integer class targets.
        A: (V, H) output-projection weight.

    Returns:
        A 1-element float32 tensor on x's device holding the mean loss.
    """
    N, H = x.shape
    V, H_A = A.shape
    assert H_A == H
    assert y.shape == (N,)
    x, y, A = x.contiguous(), y.contiguous(), A.contiguous()

    # Shape constraints relied on by the kernel's unmasked tiling.
    assert V % 256 == 0, f"V is {V}"
    assert N % 64 == 0, f"N is {N}"
    assert H % 64 == 0, f"H is {H}"

    # Per-row online-softmax state: running max (starts at the -1e6 sentinel) and sum.
    fp32_on_device = dict(dtype=torch.float32, device=x.device)
    m_global = torch.full((N,), -10e5, **fp32_on_device)
    s_global = torch.zeros(N, **fp32_on_device)
    loss = torch.zeros(1, **fp32_on_device)

    def grid(meta):
        # One program per N_BLOCK_SIZE rows; the block size is picked by the autotuner.
        return (triton.cdiv(N, meta['N_BLOCK_SIZE']),)

    # Make x's device current for the launch (noted as required in the original code).
    with torch.cuda.device(x.device.index):
        linear_xent_fwd_kernel_matmul[grid](
            x, y, A, loss, m_global, s_global,
            x.stride(0), x.stride(1),
            A.stride(0), A.stride(1),
            V=V, N=N, H=H,
        )
    return loss

In [None]:
def _prune_matmul_t_configs(configs, named_args, **kwargs):
    """Autotune early-prune hook: keep only configs whose block sizes divide (V, N, H).

    ``linear_xent_fwd_kernel_matmul_t`` has no masking. Its loops run ``V // V_BLOCK_SIZE``
    and ``H // H_BLOCK_SIZE`` times, so a non-dividing V or H block silently drops the
    tail of the reduction. A non-dividing N block makes the last program load and store
    past the end of y / m_global / s_global. Skipping work is faster, so without this
    filter the autotuner could *select* such a config.

    Depending on the Triton version, keyword launch arguments (V, N, H here) may or may
    not be forwarded in ``kwargs``. When they are not available, the list is returned
    unchanged, which was the previous behaviour.
    """
    shape = {**named_args, **kwargs}
    if not all(name in shape for name in ("V", "N", "H")):
        return configs
    V, N, H = shape["V"], shape["N"], shape["H"]
    kept = [
        cfg for cfg in configs
        if V % cfg.kwargs["V_BLOCK_SIZE"] == 0
        and N % cfg.kwargs["N_BLOCK_SIZE"] == 0
        and H % cfg.kwargs["H_BLOCK_SIZE"] == 0
    ]
    # The launcher asserts V % 256, N % 64 and H % 64 == 0. The 64/64/64 config always
    # satisfies those, so `kept` is non-empty for every shape the launcher accepts.
    return kept or configs


@triton.autotune(
   configs=[
    triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 64}),
    triton.Config({'V_BLOCK_SIZE': 64, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 256}),
    triton.Config({'V_BLOCK_SIZE': 512, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 512}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 64, 'H_BLOCK_SIZE': 64}),

    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 64}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 256, 'H_BLOCK_SIZE': 256}),
    triton.Config({'V_BLOCK_SIZE': 256, 'N_BLOCK_SIZE': 16, 'H_BLOCK_SIZE': 16}),
   ],
   key=['V', 'N', 'H'],
   # Drop configs whose tiles do not evenly divide the problem (the kernel has no masking).
   prune_configs_by={
       'early_config_prune': _prune_matmul_t_configs,
       'perf_model': None,
       'top_k': None,
   },
   # Snapshot/restore instead of zeroing. The kernel *reads* m_global, which must start
   # at the launcher's -1e6 sentinel rather than 0, and it accumulates into loss. Every
   # benchmarking run and the real launch must therefore see the launcher's initial values.
   restore_value=["loss_ptr", "m_global_ptr", "s_global_ptr"],
)
@triton.jit
def linear_xent_fwd_kernel_matmul_t(x_ptr,
                y_ptr,
                A_t_ptr,
                loss_ptr,
                m_global_ptr,
                s_global_ptr,
                stride_x_N, stride_x_H,
                stride_A_H, stride_A_V,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                V_BLOCK_SIZE: tl.constexpr,
                N_BLOCK_SIZE: tl.constexpr,
                H_BLOCK_SIZE: tl.constexpr,
               ):
    """Fused forward of the mean cross-entropy of logits ``x @ At``, with At stored as (H, V).

    This is the same algorithm as ``linear_xent_fwd_kernel_matmul``, but the weight is read
    pre-transposed, so each tile is (Hc x Vc) and feeds tl.dot without a transpose.

    Each program owns N_BLOCK_SIZE rows, sweeps V in tiles, and keeps an online-softmax
    running max ``m`` and rescaled sum ``s`` per row. It atomically adds
    ``sum_rows(m + log(s) - z[y]) / N`` to ``loss_ptr`` and writes the final m / s back.

    Block sizes must divide V, N and H (no masking).
    """
    idx = tl.program_id(axis=0)
    # (Nc x Hc) tile over this program's rows of x; advanced along H in the inner loop.
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H_BLOCK_SIZE),
        order=(1, 0),
    )
    # (Hc x Vc) tile of At; advanced along H, then rewound and stepped along V.
    A_block_ptr = tl.make_block_ptr(
        base=A_t_ptr,
        shape=(H, V),
        strides=(stride_A_H, stride_A_V),
        offsets=(0, 0),
        block_shape=(H_BLOCK_SIZE, V_BLOCK_SIZE),
        order=(1, 0),
    )
    offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    v_range = tl.arange(0, V_BLOCK_SIZE)  # vocab ids covered by the current V tile
    y = tl.load(y_ptr + offsets)

    # Online-softmax state per row: running max m and running sum s of exp(z - m).
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)
    loss = 0.

    for _ in range(V // V_BLOCK_SIZE):
        # Logit tile z[n, v] = x[n, :] . At[:, v], accumulated in fp32 over the H tiles.
        z_j_to_k = tl.zeros((N_BLOCK_SIZE, V_BLOCK_SIZE), dtype=tl.float32)
        local_x_block_ptr = x_block_ptr
        for _ in range(H // H_BLOCK_SIZE):
            x_chunk = tl.load(local_x_block_ptr) # Nc x Hc
            A_v = tl.load(A_block_ptr) # Hc x Vc

            z_j_to_k = tl.dot(x_chunk, A_v, z_j_to_k) # (Nc x Hc) @ (Hc x Vc)

            local_x_block_ptr = tl.advance(local_x_block_ptr, [0, H_BLOCK_SIZE])
            A_block_ptr = tl.advance(A_block_ptr, [H_BLOCK_SIZE, 0])

        # Rescale the running sum to the new max before adding this tile's terms.
        m_new = tl.maximum(m, tl.max(z_j_to_k, 1))

        s_update = tl.sum(tl.exp(z_j_to_k - m_new[:, None]), axis=1)
        s = s * tl.exp(m - m_new) + s_update

        # Subtract the target logit z[n, y[n]]; for 0 <= y < V it lies in exactly one tile.
        mask = y[:, None] == v_range[None, :] # Nc x Vc
        loss -= tl.sum(tl.where(mask, z_j_to_k, float(0.0))) / N

        m = m_new
        # Rewind At along H and move to the next V tile.
        A_block_ptr = tl.advance(A_block_ptr, [-H_BLOCK_SIZE * (H // H_BLOCK_SIZE), V_BLOCK_SIZE])
        v_range = v_range + V_BLOCK_SIZE

    # Add the per-row log-sum-exp, m + log(s).
    loss += tl.sum(m + tl.log(s)) / N

    # Float atomic add: partial losses from programs are summed in nondeterministic order.
    tl.atomic_add(loss_ptr, loss)
    tl.store(m_global_ptr + offsets, m)
    tl.store(s_global_ptr + offsets, s)

@torch.no_grad
def linear_xent_matmul_At(x, y, At):
    """Fused linear + mean cross-entropy taking the weight pre-transposed.

    Args:
        x: (N, H) activations.
        y: (N,) integer class targets.
        At: (H, V) transposed output-projection weight.

    Returns:
        ``(loss, m_global, s_global)``:
        - ``loss`` is a 1-element float32 mean loss.
        - ``m_global`` and ``s_global`` are the per-row online-softmax running max and
          rescaled sum; the per-row log-sum-exp is ``m + log(s)``.
    """
    N, H = x.shape
    H_A, V = At.shape  # At is the transpose of a (V, H) weight
    assert H_A == H
    assert y.shape == (N,)
    x, y, At = x.contiguous(), y.contiguous(), At.contiguous()

    # Shape constraints relied on by the kernel's unmasked tiling.
    assert V % 256 == 0, f"V is {V}"
    assert N % 64 == 0, f"N is {N}"
    assert H % 64 == 0, f"H is {H}"

    fp32_on_device = dict(dtype=torch.float32, device=x.device)
    m_global = torch.full((N,), -10e5, **fp32_on_device)
    s_global = torch.zeros(N, **fp32_on_device)
    loss = torch.zeros(1, **fp32_on_device)

    def grid(meta):
        # One program per N_BLOCK_SIZE rows; the block size is picked by the autotuner.
        return (triton.cdiv(N, meta['N_BLOCK_SIZE']),)

    # Make x's device current for the launch (noted as required in the original code).
    with torch.cuda.device(x.device.index):
        linear_xent_fwd_kernel_matmul_t[grid](
            x, y, At, loss, m_global, s_global,
            x.stride(0), x.stride(1),
            At.stride(0), At.stride(1),
            V=V, N=N, H=H,
        )
    # print(linear_xent_fwd_kernel_matmul_t.best_config)
    return loss, m_global, s_global

In [None]:
# Row-streaming variant: each program holds N_BLOCK_SIZE full rows of x and walks the
# vocabulary one row of A at a time, without tl.dot. Only N_BLOCK_SIZE is autotuned.
@triton.autotune(
   configs=[
    # triton.Config({'N_BLOCK_SIZE': 1}),
    triton.Config({'N_BLOCK_SIZE': 2}),
    triton.Config({'N_BLOCK_SIZE': 4}),
    triton.Config({'N_BLOCK_SIZE': 8}),
    triton.Config({'N_BLOCK_SIZE': 16}),
    # triton.Config({'N_BLOCK_SIZE': 32}),
    # triton.Config({'N_BLOCK_SIZE': 64}),
    # triton.Config({'N_BLOCK_SIZE': 128}),
    # triton.Config({'N_BLOCK_SIZE': 256}),
    # triton.Config({'N_BLOCK_SIZE': 512}),

   ],
   # NOTE(review): the tuning cache is keyed on N only. A config tuned for one (V, H) is
   # reused for other (V, H) with the same N; this is harmless because N_BLOCK_SIZE
   # validity depends only on N.
   key=['N'],
   # The kernel reads and accumulates into these buffers, so they are snapshotted and
   # restored around each autotuning run instead of being zeroed.
   restore_value=["loss_ptr", "m_global_ptr", "s_global_ptr"]
)
@triton.jit
def linear_xent_fwd_kernel(x_ptr,
                y_ptr,
                A_ptr,
                loss_ptr,
                m_global_ptr,
                s_global_ptr,
                stride_x_N, stride_x_H,
                stride_A_V, stride_A_H,
                V: tl.constexpr, N: tl.constexpr, H: tl.constexpr,
                N_BLOCK_SIZE: tl.constexpr,
               ):
    """Mean cross-entropy of ``x @ A.T``, streaming one vocabulary row per loop step.

    For each vocab id v, the logit ``z_j = sum_h x[n, h] * A[v, h]`` is formed with an
    elementwise product and a row reduction. It then updates the per-row online-softmax
    max ``m`` and sum ``s``. The target logit (``y == v``) is subtracted from the
    program's loss.

    The program then atomically adds ``(sum(m + log(s)) - sum(z[y])) / N`` to loss_ptr
    and stores the final m / s.

    Rows are not masked, so N must be a multiple of N_BLOCK_SIZE. The caller asserts
    N % 16 == 0, which matches the largest enabled config.
    """
    idx = tl.program_id(axis=0)
    # The whole hidden dimension sits in one tile. Presumably this needs H to be a power
    # of two (Triton block shapes) — confirm for non-power-of-two H.
    x_block_ptr = tl.make_block_ptr(
        base=x_ptr,
        shape=(N, H),
        strides=(stride_x_N, stride_x_H),
        offsets=(idx * N_BLOCK_SIZE, 0),
        block_shape=(N_BLOCK_SIZE, H),
        order=(1, 0),
    )
    # One full row of A (a single vocab id) per step; advanced by one row in the loop.
    A_block_ptr = tl.make_block_ptr(
        base=A_ptr,
        shape=(V, H),
        strides=(stride_A_V, stride_A_H),
        offsets=(0, 0),
        block_shape=(1, H),
        order=(1, 0),
    )
    offsets = idx * N_BLOCK_SIZE + tl.arange(0, N_BLOCK_SIZE)
    y = tl.load(y_ptr + offsets)
    # Per-row online-softmax state: running max m and running sum s of exp(z - m).
    m = tl.load(m_global_ptr + offsets)
    s = tl.load(s_global_ptr + offsets)
    # log2_const = 1.4426950408889634 not precise enough with exp2
    loss = 0.

    # x rows are loaded once and reused for every vocab row.
    x_chunk = tl.load(x_block_ptr) # Nc x H

    for v in range(V):
        A_v = tl.load(A_block_ptr) # 1 x H
        # The product is formed in the input dtype and upcast to fp32 before the reduction.
        z_j = tl.sum((x_chunk * A_v).to(tl.float32), axis=1) # (Nc x H) @ (H x 1)

        # Online-softmax update: rescale the old sum to the new max, then add exp(z - m_new).
        m_new = tl.maximum(m, z_j)
        s = s * tl.exp(m - m_new) + tl.exp(z_j - m_new)
        # Subtract the target logit for rows whose label is v.
        loss -= tl.sum(tl.where(y == v, z_j, 0.))

        m = m_new
        A_block_ptr = tl.advance(A_block_ptr, [1, 0])

    # Add the per-row log-sum-exp (m + log(s)) and normalise by the global row count.
    loss = (loss + tl.sum(m + tl.log(s))) / N

    # Float atomic add: partial losses from programs are summed in nondeterministic order.
    tl.atomic_add(loss_ptr, loss)
    tl.store(m_global_ptr + offsets, m)
    tl.store(s_global_ptr + offsets, s)


# loss_triton = linear_cross_entropy(x, y, A) # type: ignore
# loss_triton, torch.dist(reference_loss, loss_triton).item()


In [None]:
@triton.jit
def fwd_kernel(
    x_ND_ptr,
    w_DV_ptr,
    c_N_ptr,
    output_N_ptr,
    l_N_ptr,
    N, D, V,
    stride_xn, stride_xd,
    stride_wd, stride_wv,
    # Meta-parameters
    BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr, BLOCK_V: tl.constexpr,
):
    """Per-row ``log_softmax(x_ND @ w_DV)[n, c_N[n]]``, written to output_N_ptr.

    One program handles BLOCK_N rows and sweeps V in BLOCK_V tiles. Each
    (BLOCK_N x BLOCK_V) logit tile is computed with tl.dot over D in BLOCK_D chunks. Per
    row, the kernel keeps an online-softmax running max ``m`` and rescaled sum ``l``, and
    picks the target logit ``z_c`` out of whichever tile contains it.

    The result is ``z_c - m - log(l)``. Computing it this way, rather than as
    ``log(exp(z_c - m) / l)``, stays finite even when ``exp(z_c - m)`` underflows in fp32.
    ``l``, relative to the final ``m``, is also stored to l_N_ptr.

    There is no masking: N, D and V must be multiples of BLOCK_N, BLOCK_D and BLOCK_V.
    """
    # TODO: more parallelism, e.g. tiled softmax
    #       w/ parallelization across tiles (intra-tile computation uses online softmax)
    # only parallelize along the N dimension
    # i is the same as n
    i = tl.program_id(axis=0)
    offs_n_bN = i * BLOCK_N + tl.arange(0, BLOCK_N)
    c_i_bN = tl.load(c_N_ptr + offs_n_bN)
    # Target-class logit per row; for 0 <= c < V exactly one V tile contributes to it.
    z_c_bN = tl.zeros([BLOCK_N], dtype=tl.float32)

    # statistics for online softmax (the first tile's alpha is exp(-inf) = 0)
    m_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32) - float('inf')
    l_i_bN = tl.zeros([BLOCK_N], dtype=tl.float32)
    for start_v in range(0, V, BLOCK_V):
        start_v = tl.multiple_of(start_v, BLOCK_V)
        offs_v_bV = start_v + tl.arange(0, BLOCK_V)
        x_ND_block_ptr = tl.make_block_ptr(
            base=x_ND_ptr,
            shape=(N, D),
            strides=(stride_xn, stride_xd),
            offsets=(i * BLOCK_N, 0),
            block_shape=(BLOCK_N, BLOCK_D),
            order=(1, 0),
        )
        w_DV_block_ptr = tl.make_block_ptr(
            base=w_DV_ptr,
            shape=(D, V),
            strides=(stride_wd, stride_wv),
            offsets=(0, start_v),
            block_shape=(BLOCK_D, BLOCK_V),
            order=(1, 0),
        )
        xw_bNbV = tl.zeros([BLOCK_N, BLOCK_V], dtype=tl.float32)
        for _ in range(0, D, BLOCK_D):
            # TODO: x load can be reduced?
            x_bNbD = tl.load(x_ND_block_ptr)
            w_bDbV = tl.load(w_DV_block_ptr)
            xw_bNbV = tl.dot(x_bNbD, w_bDbV, xw_bNbV)
            x_ND_block_ptr = tl.advance(x_ND_block_ptr, (0, BLOCK_D))
            w_DV_block_ptr = tl.advance(w_DV_block_ptr, (BLOCK_D, 0))

        # i for N
        # j for V
        m_ij_bN = tl.maximum(m_i_bN, tl.max(xw_bNbV, axis=1))
        p_ij_bNbV = tl.exp(xw_bNbV - m_ij_bN[:, None])
        l_ij_bN = tl.sum(p_ij_bNbV, axis=1)
        # update m_i and l_i
        alpha_bN = tl.exp(m_i_bN - m_ij_bN)
        l_i_bN = l_i_bN * alpha_bN + l_ij_bN
        m_i_bN = m_ij_bN
        # pick the raw target logit (not its exponential) out of this tile
        is_target_bNbV = c_i_bN[:, None] == offs_v_bV[None, :]
        z_c_bN += tl.sum(tl.where(is_target_bNbV, xw_bNbV, 0.0), axis=1)

    # log_softmax[c] = z_c - logsumexp(z) = z_c - m - log(l)
    output_i_bN = z_c_bN - m_i_bN - tl.log(l_i_bN)
    tl.store(output_N_ptr + offs_n_bN, output_i_bN)
    tl.store(l_N_ptr + offs_n_bN, l_i_bN)


# output_N[n] = log_softmax(x_ND @ w_DV, dim=1)[n, c_N[n]]
class LMHeadThenLogSoftmaxThenGather(torch.autograd.Function):
    """Fused LM head + log-softmax + gather: ``log_softmax(x_ND @ w_DV)[n, c_N[n]]``.

    Forward only; backward is not implemented yet.
    """

    @staticmethod
    def forward(ctx, x_ND: torch.Tensor, w_DV: torch.Tensor, c_N: torch.Tensor):
        """Run ``fwd_kernel`` and return the per-row target log-probabilities.

        Args:
            x_ND: (N, D) activations.
            w_DV: (D, V) weight (already transposed relative to an nn.Linear weight).
            c_N: (N,) integer target classes.

        Returns:
            (N,) float32 tensor of log-probabilities. It is float32 regardless of x's
            dtype, because the kernel computes in fp32 and a bf16 output would truncate.
        """
        N, D = x_ND.shape
        Dw, V = w_DV.shape
        Nc, = c_N.shape
        assert D == Dw and N == Nc
        BLOCK_N = 32
        BLOCK_D = 32
        BLOCK_V = 32
        # fwd_kernel does no masking, so every dimension must be tile-aligned.
        assert N % BLOCK_N == 0, f"N={N} must be a multiple of {BLOCK_N}"
        assert D % BLOCK_D == 0, f"D={D} must be a multiple of {BLOCK_D}"
        assert V % BLOCK_V == 0, f"V={V} must be a multiple of {BLOCK_V}"
        output_N = x_ND.new_empty(N, dtype=torch.float32)
        l_N = x_ND.new_empty(N, dtype=torch.float32)
        grid = (triton.cdiv(N, BLOCK_N),)
        # Launch with x's device current, like the other launchers in this notebook.
        with torch.cuda.device(x_ND.device.index):
            fwd_kernel[grid](
                x_ND, w_DV, c_N,
                output_N, l_N,
                N, D, V,
                x_ND.stride(0), x_ND.stride(1),
                w_DV.stride(0), w_DV.stride(1),
                BLOCK_N, BLOCK_D, BLOCK_V,
            )
        ctx.save_for_backward(w_DV, x_ND, c_N, l_N)
        return output_N

    @staticmethod
    def backward(ctx, g_N):
        # TODO: implement. Fail explicitly rather than hand autograd a wrong-arity result.
        raise NotImplementedError("LMHeadThenLogSoftmaxThenGather.backward is not implemented")

def YouJiacheng_linear_xent(x, y, At):
    """Mean cross-entropy of ``x @ At`` against ``y`` via the fused log-softmax-gather op.

    Args:
        x: (N, H) activations.
        y: (N,) integer targets.
        At: (H, V) transposed weight.

    Returns:
        A scalar mean loss. The sign and reduction match ``F.cross_entropy`` as used by
        ``baseline_torch``: the negative mean of the target log-probabilities.
    """
    target_log_probs_N = LMHeadThenLogSoftmaxThenGather.apply(x, At, y)
    return -target_log_probs_N.mean()


In [None]:
class LinearCrossEntropyLoss(torch.autograd.Function):
    """Forward-only fused linear + mean cross-entropy using ``linear_xent_fwd_kernel``.

    The kernel streams over the vocabulary one row of A at a time, so the (N, V) logits
    are never materialised. Backward is not implemented.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        y,
        A,
        ignore_index=-100,
    ):
        """Compute the mean cross-entropy of ``x @ A.T`` against ``y``.

        Args:
            x: (N, H) activations; N must be a multiple of 16.
            y: (N,) integer targets.
            A: (V, H) weight.
            ignore_index: accepted for API compatibility but currently NOT honoured.
                The kernel never skips rows, so a negative target matches no vocab id
                and still contributes its full log-sum-exp to the loss.

        Returns:
            A 1-element float32 tensor holding the mean loss.
        """
        N, H = x.shape
        V, H_A = A.shape
        assert H_A == H
        assert y.shape == (N,)
        x = x.contiguous()
        y = y.contiguous()
        A = A.contiguous()

        # Rows are not masked; 16 is the largest autotuned N_BLOCK_SIZE.
        assert N % 16 == 0

        # Per-row online-softmax state: running max (the -1e6 sentinel) and running sum.
        m_global = -10e5 * torch.ones(N, dtype=torch.float32, device=x.device)
        s_global = torch.zeros(N, dtype=torch.float32, device=x.device)
        loss = torch.zeros(1, dtype=torch.float32, device=x.device)

        grid = lambda meta: (triton.cdiv(N, meta['N_BLOCK_SIZE']), )
        with torch.cuda.device(x.device.index): # actually required
            linear_xent_fwd_kernel[grid](
                    x,
                    y,
                    A,
                    loss,
                    m_global,
                    s_global,
                    x.stride(0), x.stride(1),
                    A.stride(0), A.stride(1),
                    V=V, N=N, H=H,
                )

        ctx.save_for_backward(m_global, s_global)
        return loss

    @staticmethod
    def backward(ctx, grad_loss):
        # Gradients are not implemented; raise explicitly instead of returning placeholders.
        raise NotImplementedError("LinearCrossEntropyLoss.backward is not implemented")
    
def linear_cross_entropy(x, y, A):
    """Functional entry point for ``LinearCrossEntropyLoss`` (forward only)."""
    apply_loss = LinearCrossEntropyLoss.apply
    return apply_loss(x, y, A)

# loss_triton = linear_cross_entropy(x, y, A) # type: ignore
# loss_triton, torch.dist(reference_loss, loss_triton).item()

In [None]:
def benchmark_with_memory_reporting(func, quantiles, *args, **kwargs):
    """Time ``func(*args, **kwargs)`` with triton's do_bench and report its peak extra memory.

    Returns:
        ``(ms, min_ms, max_ms, memory_used)``:
        - The first three are do_bench's timings, in the order given by ``quantiles``.
        - ``memory_used`` is the peak number of bytes allocated on ``device`` during a
          single call, above what was allocated just before that call.

    Memory is measured on a separate call made *after* timing, for two reasons:
    ``triton.testing.do_bench`` allocates its own large cache-flush buffer while it runs,
    and the first calls may trigger autotuning or torch.compile. Measuring inside do_bench
    would add those allocations to every provider's number.
    """
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: func(*args, **kwargs), quantiles=quantiles)

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device=device)
    initial_memory = torch.cuda.memory_allocated(device=device)

    result = func(*args, **kwargs)

    torch.cuda.synchronize()
    peak_memory = torch.cuda.max_memory_allocated(device=device)
    del result
    memory_used = peak_memory - initial_memory

    return ms, min_ms, max_ms, memory_used

In [None]:
# Shrink factor: the benchmarks below divide their problem sizes by `f`. Use 1 for full
# sizes; larger values (up to 512) give quick runs, although the Triton launchers'
# divisibility asserts may reject very small shapes.
# default_H is the hidden size used whenever H is not the swept axis.
f, default_H = 1, 2048

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['H'],  # hidden size is the swept x-axis
        x_vals=[2**i for i in range(9, 14, 1)],  # 512 .. 8192
        x_log=True,
        line_arg='provider',  # one plotted line per implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t', 'triton-variant'],
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t', 'triton-variant'],
        # styles=[('blue', '-'), ('green', '-'), ('red', '-'), ('brown', '-')],
        ylabel='TFLOP/s',
        plot_name='Linear+Loss Performance',  # also used as the saved file name
        args={},
    ))
def benchmark(H, provider):
    """Forward throughput (TFLOP/s of the N x H x V matmul) as a function of hidden size H."""
    B, S, V = 4, 4096 // f, 131072 // f
    N = B * S
    H = H // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (B*S, H)
    y = torch.randint(0, V, (N,), device=device)  # one target id per row
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()  # (H, V) layout for the transposed-weight kernels

    # Zero-argument callables; only the selected provider is ever invoked.
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        "triton": lambda: linear_cross_entropy(x, y, A),
        "triton2": lambda: linear_xent_matmul(x, y, A),
        "triton2-t": lambda: linear_xent_matmul_At(x, y, At),
        "triton-variant": lambda: YouJiacheng_linear_xent(x, y, At),
    }
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(runners[provider], quantiles=quantiles)

    def tflops(t_ms):
        return 2 * (N * H * V) * 1e-12 / (t_ms * 1e-3)

    # A slower time means lower throughput, so the max/min bounds swap.
    return tflops(ms), tflops(max_ms), tflops(min_ms)


benchmark.run(print_data=True, show_plots=True)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # vocabulary size is the swept x-axis
        x_vals=[2**i for i in range(10, 18, 1)],  # 1024 .. 131072
        x_log=True,
        line_arg='provider',  # one plotted line per implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t', "triton-variant"],
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t', "triton-variant"],
        ylabel='TFLOP/s',
        plot_name='Linear+Loss Performance',  # also used as the saved file name
        args={},
    ))
def benchmark(V, provider):
    """Forward throughput (TFLOP/s) as a function of vocabulary size V."""
    B, S, H = 4, 4096 // f, default_H // f
    N = B * S
    V = V // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (B*S, H)
    y = torch.randint(0, V, (N,), device=device)  # one target id per row
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()  # (H, V) layout for the transposed-weight kernels

    # Zero-argument callables; only the selected provider is ever invoked.
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        "triton": lambda: linear_cross_entropy(x, y, A),
        "triton2": lambda: linear_xent_matmul(x, y, A),
        "triton2-t": lambda: linear_xent_matmul_At(x, y, At),
        "triton-variant": lambda: YouJiacheng_linear_xent(x, y, At),
    }
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(runners[provider], quantiles=quantiles)

    def tflops(t_ms):
        return 2 * N * H * V * 1e-12 / (t_ms * 1e-3)

    # A slower time means lower throughput, so the max/min bounds swap.
    return tflops(ms), tflops(max_ms), tflops(min_ms)


benchmark.run(print_data=True, show_plots=True)


In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # number of rows (tokens) is the swept x-axis
        x_vals=[2**i for i in range(9, 15, 1)],  # 512 .. 16384
        x_log=True,
        line_arg='provider',  # one plotted line per implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t', "triton-variant"],
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t', "triton-variant"],
        ylabel='TFLOP/s',
        plot_name='Linear+Loss Performance',  # also used as the saved file name
        args={},
    ))
def benchmark(N, provider):
    """Forward throughput (TFLOP/s) as a function of the number of rows N."""
    H, V = default_H // f, 131072 // f
    N = N // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # one target id per row
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()  # (H, V) layout for the transposed-weight kernels

    # Zero-argument callables; only the selected provider is ever invoked.
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        "triton": lambda: linear_cross_entropy(x, y, A),
        "triton2": lambda: linear_xent_matmul(x, y, A),
        "triton2-t": lambda: linear_xent_matmul_At(x, y, At),
        "triton-variant": lambda: YouJiacheng_linear_xent(x, y, At),
    }
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(runners[provider], quantiles=quantiles)

    def tflops(t_ms):
        return 2 * N * H * V * 1e-12 / (t_ms * 1e-3)

    # A slower time means lower throughput, so the max/min bounds swap.
    return tflops(ms), tflops(max_ms), tflops(min_ms)


benchmark.run(print_data=True, show_plots=True)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # number of rows (tokens) is the swept x-axis
        x_vals=[2**i for i in range(9, 15, 1)],  # 512 .. 16384
        x_log=True,
        line_arg='provider',  # one plotted line per implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],
        ylabel='GBs of Memory',
        plot_name='Linear+Loss Performance',  # also used as the saved file name
        args={},
    ))
# Forward-only measurement, matching the V-sweep memory cell. Without no_grad, the torch
# baselines would also keep tensors saved for backward (x and A require grad), while the
# Triton providers are forward-only.
@torch.no_grad()
def benchmark(N, provider):
    """Peak extra GPU memory (GiB) of one forward call as a function of the number of rows N."""
    H, V = default_H // f, 131072 // f
    N = N // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # one target id per row
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()  # (H, V) layout for the transposed-weight kernel

    # Zero-argument callables; only the selected provider is ever invoked.
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        "triton": lambda: linear_cross_entropy(x, y, A),
        "triton2": lambda: linear_xent_matmul(x, y, A),
        "triton2-t": lambda: linear_xent_matmul_At(x, y, At),
    }
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms, memory_used = benchmark_with_memory_reporting(runners[provider], quantiles=quantiles)
    return memory_used / 1024**3, 0, 0


benchmark.run(print_data=True, show_plots=True)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['V'],  # vocabulary size is the swept x-axis
        x_vals=[2**i for i in range(10, 19, 1)],  # 1024 .. 262144
        x_log=True,
        line_arg='provider',  # one plotted line per implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],
        ylabel='GBs of Memory',
        plot_name='Linear+Loss Performance',  # also used as the saved file name
        args={},
    ))
@torch.no_grad()
def benchmark(V, provider):
    """Peak extra GPU memory (GiB) of one forward call as a function of vocabulary size V."""
    B, S, H = 4, 4096 // f, default_H // f
    N = B * S
    V = V // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (B*S, H)
    y = torch.randint(0, V, (N,), device=device)  # one target id per row
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()  # (H, V) layout for the transposed-weight kernel

    # Zero-argument callables; only the selected provider is ever invoked.
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        "triton": lambda: linear_cross_entropy(x, y, A),
        "triton2": lambda: linear_xent_matmul(x, y, A),
        "triton2-t": lambda: linear_xent_matmul_At(x, y, At),
    }
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms, memory_used = benchmark_with_memory_reporting(runners[provider], quantiles=quantiles)
    return memory_used / 1024**3, 0, 0


benchmark.run(print_data=True, show_plots=True)

In [None]:
# triton.testing.do_bench_cudagraph refuses to capture on the default stream, so make a
# fresh side stream current for the CUDA-graph benchmark below.
side_stream = torch.cuda.Stream()
torch.cuda.set_stream(side_stream)

In [None]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # number of rows (tokens) is the swept x-axis
        x_vals=[2**i for i in range(9, 15, 1)],  # 512 .. 16384
        x_log=True,
        line_arg='provider',  # one plotted line per implementation
        line_vals=['torch', 'torch-compile', 'triton2', 'triton2-t'],
        line_names=['torch', 'torch-compile', 'triton2', 'triton2-t'],
        ylabel='TFLOP/s',
        plot_name='Linear+Loss Performance',  # also used as the saved file name
        args={},
    ))
def benchmark(N, provider):
    """Forward throughput (TFLOP/s) vs. N, timed via CUDA-graph replay.

    Requires a non-default current stream, which the preceding cell sets.
    """
    H, V = default_H // f, 131072 // f
    N = N // f

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16)  # (N, H)
    y = torch.randint(0, V, (N,), device=device)  # one target id per row
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    At = A.clone().T.contiguous()  # (H, V) layout for the transposed-weight kernel

    # One entry per provider; each provider is benchmarked exactly once.
    runners = {
        'torch': lambda: baseline_torch(x, y, A),
        'torch-compile': lambda: compiled_baseline(x, y, A),
        'torch-maxauto': lambda: maxauto_baseline(x, y, A),
        "triton": lambda: linear_cross_entropy(x, y, A),
        "triton2": lambda: linear_xent_matmul(x, y, A),
        "triton2-t": lambda: linear_xent_matmul_At(x, y, At),
    }
    ms = triton.testing.do_bench_cudagraph(runners[provider])
    perf = lambda ms: 2 * N * H * V * 1e-12 / (ms * 1e-3)
    return perf(ms)


benchmark.run(print_data=True, show_plots=True)