In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="6"
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import time
from datetime import timedelta

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
import math

import triton
import triton.language as tl

%reload_ext autoreload
%autoreload 2

In [None]:
# CUDA_VISIBLE_DEVICES (set in the first cell) exposes a single physical GPU,
# which PyTorch therefore sees as device index 0.
device = torch.device("cuda", 0)
torch.cuda.device_count()

In [None]:
f = 4 # downscaling factor applied to N, H and V below: larger values make the tests run faster; f=1 is the target size

In [None]:
def cosim(x,y):
    """Cosine similarity of two tensors, each viewed as one flat vector.

    The dot product and norms are accumulated in float64; the result is
    returned as a 0-dim float32 tensor.
    """
    a = x.reshape(-1).double()
    b = y.reshape(-1).double()
    dot = (a * b).sum()
    return (dot / a.norm() / b.norm()).float()

@torch._dynamo.disable
def baseline_torch(x, y, A):
    V = A.shape[0]
    return F.cross_entropy(F.linear(x, A).view(-1, V).float(), y.view(-1))

@torch.compile # need to define this twice, otherwise there's weird shadowing happening in the notebook
def compiled_baseline(x, y, A):
    """``torch.compile``'d copy of the reference linear + cross-entropy loss.

    Kept as its own function body rather than ``torch.compile(baseline_torch)``
    (see the comment on the decorator line). Logits are flattened to
    (-1, A.shape[0]) and upcast to float32 before the loss.
    """
    n_classes = A.size(0)
    logits = F.linear(x, A).view(-1, n_classes).float()
    return F.cross_entropy(logits, y.view(-1))

def _maxauto_baseline_impl(x, y, A):
    """Reference linear + cross-entropy body, compiled below with ``mode="max-autotune"``.

    This is a separate, undecorated function on purpose. ``baseline_torch``
    carries ``@torch._dynamo.disable``, so wrapping it in ``torch.compile``
    never traces it and produces no max-autotune code. A dedicated body also
    mirrors how ``compiled_baseline`` is defined separately.
    """
    V = A.shape[0]
    return F.cross_entropy(F.linear(x, A).view(-1, V).float(), y.view(-1))

maxauto_baseline = torch.compile(_maxauto_baseline_impl, fullgraph=True, mode="max-autotune")

In [None]:
N, H, V = (4096 * 4) // f, 4096 // f, 131072 // f
# N, H, V =256, 1024 * 4, 32768

compute_dtype = torch.float16

# Fixed seed so the reference data, and the errors simple_bench reports against
# it, are reproducible across kernel restarts.
torch.manual_seed(0)

y = torch.randint(0, V, (N,), device=device) # vocab ** B S 
A = torch.randn(V, H, requires_grad=True, device=device, dtype=compute_dtype)
# Independent leaf copy of the weights in (H, V) layout, for kernels that take A transposed.
At = A.clone().detach().T.contiguous()
At.requires_grad_()

# x = torch.randn(B * S, H, requires_grad=True, device=device, dtype=torch.float32) # B S H
# x = A[y].clone().detach()
# Inputs partially aligned with their target rows A[y], plus unit Gaussian noise.
x = 0.1 * A[y].clone().detach() + torch.randn(N, H, device=device, dtype=compute_dtype)
x.requires_grad_()

# fp32 reference loss and gradients.
loss = baseline_torch(x.float(), y, A.float())
loss.backward()

reference_A_grad = A.grad.float().clone()
reference_x_grad = x.grad.float().clone()
reference_loss = loss.detach().float().clone()

# Reference logits with their row max and sum of max-shifted exponentials.
# These are plain values, so no autograd graph is needed.
with torch.no_grad():
    z_ref = F.linear(x, A).view(-1, V).float()
    m_ref = z_ref.max(dim=1)[0]
    s_ref = (z_ref - m_ref[:, None]).exp().sum(dim=1)

print(reference_loss)

In [None]:
def simple_bench(fn, reference_loss, reference_x_grad, reference_A_grad):
    """Time one forward and one backward pass of ``fn`` and print errors vs. references.

    ``fn`` is a zero-argument callable returning a scalar loss. It is expected
    to close over the notebook-level tensors ``x`` and ``A`` or ``At``. This
    function reads those globals directly and clears their ``.grad`` fields.

    Each phase gets one untimed warm-up call, so compilation and autotuning are
    excluded. It is then timed once with CUDA events.

    Args:
        fn: zero-argument callable returning the loss tensor.
        reference_loss: reference loss (0-dim tensor).
        reference_x_grad: reference gradient w.r.t. ``x``.
        reference_A_grad: reference gradient w.r.t. ``A``. It is compared
            transposed against ``At.grad`` when ``fn`` produced a gradient for ``At``.
    """
    def _clear_grads():
        x.grad, A.grad, At.grad = None, None, None

    def _timed_ms(op):
        # Run op() once between two CUDA events; return (op's result, elapsed ms).
        torch.cuda.synchronize()
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = op()
        end_event.record()
        torch.cuda.synchronize()
        return result, start_event.elapsed_time(end_event)

    # --- forward ---
    _clear_grads()
    fn()  # warmup
    loss_triton, estimate_ms_fwd = _timed_ms(fn)
    print(f"fwd : {estimate_ms_fwd}ms")
    # Compare the loss returned by fn. The notebook-level `loss` is the
    # reference itself, so comparing it would always give 0.
    fwd_error = torch.dist(loss_triton.detach().float(), reference_loss.float()).item()
    print(f"fwd error: {fwd_error}")

    # --- backward ---
    fn().backward()  # warmup
    _clear_grads()
    loss_triton = fn()
    _, estimate_ms_bwd = _timed_ms(loss_triton.backward)
    print(f"bwd : {estimate_ms_bwd}ms")
    if At.grad is not None:
        A_error = torch.dist(reference_A_grad.T.float(), At.grad.float()).item()
    else:
        A_error = torch.dist(reference_A_grad.float(), A.grad.float()).item()
    x_error = torch.dist(reference_x_grad.float(), x.grad.float()).item()
    print(f"bwd error: {x_error}, {A_error}")

In [None]:

def _inner_function(x_block, y_block, A, num_blocks):
    """Mean cross-entropy of one chunk, pre-scaled by ``1 / num_blocks``.

    Summing this over ``num_blocks`` equally sized chunks gives the mean loss
    over all rows.
    """
    return F.cross_entropy(F.linear(x_block, A), y_block) / num_blocks

# @torch.compile(dynamic=False)
def torch_compiled_checkpoint(x, y, A, default_chunk_size = 512):
    """Linear + cross-entropy computed in row chunks under activation checkpointing.

    Each chunk goes through ``checkpoint(..., use_reentrant=False)``, so its
    logits are recomputed in backward instead of being saved. Despite the name,
    the ``torch.compile`` decorator above is commented out, so this runs eagerly.

    Args:
        x: inputs whose last dimension is the hidden size; all leading dims are
            flattened into rows.
        y: integer labels, one per row of the flattened ``x``.
        A: projection weights in ``F.linear`` layout, i.e. (V, hidden).
        default_chunk_size: preferred number of rows per chunk.

    Returns:
        The mean cross-entropy over all rows (0-dim tensor).
    """
    # Use x's own hidden size. The notebook-level H is fixed, while the
    # benchmarks call this with many different hidden sizes.
    x_rows = x.view(-1, x.shape[-1])
    y_rows = y.view(-1)
    N = x_rows.shape[0]
    chunk_size = min(default_chunk_size, N)
    if N % chunk_size != 0:
        # The 1/num_blocks weighting needs equal-size chunks. gcd is the largest
        # size dividing both N and default_chunk_size.
        chunk_size = math.gcd(N, default_chunk_size)
    x_blocks = x_rows.split(chunk_size)
    y_blocks = y_rows.split(chunk_size)

    loss = 0.
    for x_block, y_block in zip(x_blocks, y_blocks):
        loss += checkpoint(_inner_function, x_block, y_block, A, num_blocks=len(y_blocks), use_reentrant=False)
    return loss
torch_compiled_checkpoint(x, y, A, default_chunk_size=512)  # smoke test on the notebook tensors (512 is the default)

In [None]:
simple_bench(
    fn=lambda: compiled_baseline(x, y, A),
    reference_loss=reference_loss,
    reference_x_grad=reference_x_grad,
    reference_A_grad=reference_A_grad,
)

In [None]:
simple_bench(
    fn=lambda: torch_compiled_checkpoint(x, y, A),
    reference_loss=reference_loss,
    reference_x_grad=reference_x_grad,
    reference_A_grad=reference_A_grad,
)

In [None]:
from malek_xent import linear_cross_entropy as efficient_xent, FusedProjectionPlusCrossEntropyLoss


# op = FusedProjectionPlusCrossEntropyLoss(H, V, 16).to(device)
# op(x, y).backward()



# efficient_xent is called with the (V, H) weight A, not the transposed At.
simple_bench(
    fn=lambda: efficient_xent(x, y, A),
    reference_loss=reference_loss,
    reference_x_grad=reference_x_grad,
    reference_A_grad=reference_A_grad,
)

In [None]:
from double_recomp2 import linear_cross_entropy as linear_cross_entropy_double_recomp


# This kernel takes the transposed (H, V) weight At.
simple_bench(
    fn=lambda: linear_cross_entropy_double_recomp(x, y, At),
    reference_loss=reference_loss,
    reference_x_grad=reference_x_grad,
    reference_A_grad=reference_A_grad,
)

In [None]:
# Free the correctness-check tensors before the benchmark sweeps. Besides the
# inputs, the N x V float32 reference logits z_ref and the reference gradients
# are still resident, which shrinks the memory available to the larger configs.
del x, y, A, At
del loss, reference_loss, reference_x_grad, reference_A_grad, z_ref, m_ref, s_ref
torch.cuda.empty_cache()

In [None]:
import sys,os
# Make modules in the working directory importable. The guard keeps re-runs of
# this cell from appending duplicate entries.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

In [None]:
# to try eventually:
# from litgpt.ops import linear_cross_entropy_checkerboard
# from litgpt.ops import linear_cross_entropy_nolock
# from litgpt.ops import linear_cross_entropy
# from litgpt.ops import linear_cross_entropy_double_recomp

from litgpt.ops import linear_cross_entropy_nolock
from highmem import linear_cross_entropy as linear_cross_entropy_highmem
from double_recomp2 import linear_cross_entropy as linear_cross_entropy_double_recomp
from double_recomp3 import linear_cross_entropy as linear_cross_entropy_parallel_recomp
from manyway_recomp import linear_cross_entropy as linear_cross_entropy_manyway
from malek_xent import linear_cross_entropy as efficient_xent

# Benchmarking throughput (TFLOP/s): FWD, BWD and FWD + BWD

In [None]:
# log2 of the swept sizes (each further divided by f) for each swept variable.
range_dict = {'H':range(10, 14, 1),'V': range(10, 18, 1),'N': range(8, 15, 1)}

# One line per implementation; the names double as the `provider` argument.
PERF_PROVIDERS = ['torch', 'torch-compile', 'torch-compile-checkpoint', 'triton', 'triton-recomp', 'triton-par-recomp', 'triton-many-recomp', 'malek']

configs = []
for mode in ["fwd", "bwd", "fwd-bwd"]:
    for variable in ['H', 'N', 'V']:
        configs.append(
            triton.testing.Benchmark(
                x_names=[variable],  # Argument names to use as an x-axis for the plot.
                x_vals=[(2**i)//f for i in range_dict[variable]],  # Different possible values for `x_name`.
                x_log=True,  # x axis is logarithmic.
                line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
                line_vals=list(PERF_PROVIDERS),
                line_names=list(PERF_PROVIDERS),
                ylabel='TFLOP/s',  # Label name for the y-axis.
                # Report the defaults actually used by `benchmark`, which are divided by f.
                plot_name=f'{mode}-Linear+Loss Performance. Defaults: N=B*S={(4096 * 4) // f}, H={2048 // f}, V={131072 // f}',
                args={"mode": mode},  # Values for function arguments not in `x_names` and `y_name`.
            ))
        
@triton.testing.perf_report(configs)
def benchmark(H=2048//f, V=131072//f, N=(4096 * 4)//f, provider="torch", mode="fwd"):
    """Measure throughput (TFLOP/s) of one linear + cross-entropy implementation.

    Fresh fp16 inputs are drawn on every call:
      * x (N, H) — the inputs.
      * y (N,) — the labels.
      * A (V, H) — the weights.
      * At = A.T (H, V) — an independent leaf copy for kernels that take the
        transposed layout.

    ``mode`` selects what ``triton.testing.do_bench`` times:
      * "fwd" — the forward under ``torch.no_grad``.
      * "bwd" — repeated backward through one retained graph.
      * "fwd-bwd" — forward plus backward.

    Any failure (e.g. OOM) is printed and recorded as 1e6 ms, so the sweep continues.

    Returns:
        (median, low, high) TFLOP/s, the triple ``perf_report`` expects.
    """
    print(provider, N, H, V, mode)

    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.float16) # B S H
    y = torch.randint(0, V, (N,), device=device) # vocab ** B S 
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.float16)
    At = A.detach().clone().T.contiguous()
    At.requires_grad_()

    # provider name -> zero-argument loss closure. "triton" is the high-memory kernel.
    loss_fns = {
        "torch": lambda: baseline_torch(x, y, A),
        "torch-compile": lambda: compiled_baseline(x, y, A),
        "torch-compile-checkpoint": lambda: torch_compiled_checkpoint(x, y, A),
        "triton": lambda: linear_cross_entropy_highmem(x, y, At),
        "triton-nolock-nowrite": lambda: linear_cross_entropy_nolock(x, y, At),
        "triton-recomp": lambda: linear_cross_entropy_double_recomp(x, y, At),
        "triton-par-recomp": lambda: linear_cross_entropy_parallel_recomp(x, y, At),
        "triton-many-recomp": lambda: linear_cross_entropy_manyway(x, y, At),
        "malek": lambda: efficient_xent(x, y, A),
    }
    if provider not in loss_fns:
        raise ValueError(f"unknown provider {provider!r}")
    fn = loss_fns[provider]

    try:
        if mode == "fwd":
            @torch.no_grad
            def test_fn():
                fn()
        elif mode == "bwd":
            loss = fn()
            test_fn = lambda: loss.backward(retain_graph=True)
        elif mode == "fwd-bwd":
            test_fn = lambda: fn().backward()
        else:
            test_fn = fn

        quantiles = [0.5, 0.2, 0.8]
        # grad_to_none resets .grad before every timed rep. Each backward then
        # writes fresh gradients instead of accumulating into the previous ones.
        ms, min_ms, max_ms = triton.testing.do_bench(
            test_fn, quantiles=quantiles, warmup=50, rep=200, grad_to_none=[x, A, At]
        )
    except Exception as err:  # e.g. OOM or an unsupported shape: record it and keep sweeping
        print(f"error when computing {provider} for N={N}, H={H}, V={V}: {type(err).__name__}: {err}")
        ms, min_ms, max_ms = 1e6, 1e6, 1e6

    # Nominal FLOPs: 2*N*H*V for the projection plus 3*N*V for the loss.
    # "bwd" counts as 2x the forward and "fwd-bwd" as 3x.
    flop = 2 * (N * H * V) + 3 * N * V
    if mode == "bwd":
        flop *= 2
    if mode == "fwd-bwd":
        flop *= 3

    perf = lambda ms: flop * 1e-12 / (ms * 1e-3)
    # Slower time -> lower throughput, hence max_ms gives the lower bound.
    return perf(ms), perf(max_ms), perf(min_ms)

benchmark.run(print_data=False, show_plots=True)

# Benchmarking peak GPU memory: FWD and FWD + BWD

In [None]:
def benchmark_with_memory_reporting(func, quantiles, *args, **kwargs):
    """Time ``func(*args, **kwargs)`` with ``do_bench`` and report its peak extra GPU memory.

    Timing comes from ``triton.testing.do_bench``, which also warms ``func`` up.
    Memory is then measured over one separate, already-warm call: the peak of
    ``torch.cuda.max_memory_allocated`` minus what was allocated just before it.

    Measuring outside ``do_bench`` keeps two things out of the number:
    do_bench's own scratch buffer (used to flush L2 between reps), and any
    first-call compilation or autotuning allocations.

    Tensors that persist from earlier calls (e.g. accumulated ``.grad``) form
    part of the baseline. Everything the call allocates on top of that,
    including transient buffers, is counted.

    Returns:
        (ms, min_ms, max_ms, memory_used_bytes)
    """
    def call():
        return func(*args, **kwargs)

    ms, min_ms, max_ms = triton.testing.do_bench(call, quantiles=quantiles, warmup = 5, rep = 1)

    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats(device=device)
    initial_memory = torch.cuda.memory_allocated(device=device)

    call()

    torch.cuda.synchronize()
    peak_memory = torch.cuda.max_memory_allocated(device=device)
    memory_used = peak_memory - initial_memory

    return ms, min_ms, max_ms, memory_used

# log2 of the swept sizes (not divided by f: the memory sweep uses full-size defaults).
range_dict = {'H':range(8, 14, 1),'V': range(10, 18, 1),'N': range(8, 15, 1)}

# One line per implementation; the names double as the `provider` argument.
MEMORY_PROVIDERS = ['torch', 'torch-compile', 'torch-compile-checkpoint', 'triton', 'triton-recomp', 'triton-par-recomp', 'triton-many-recomp', 'malek']

configs = []
for mode in ["fwd", "fwd-bwd"]:
    for variable in ['H', 'N', 'V']:
        configs.append(
            triton.testing.Benchmark(
                x_names=[variable],  # Argument names to use as an x-axis for the plot.
                x_vals=[2**i for i in range_dict[variable]],  # Different possible values for `x_name`.
                x_log=True,  # x axis is logarithmic.
                line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
                ylabel='Peak Memory in GB (excluding inputs)',  # Label name for the y-axis.
                line_vals=list(MEMORY_PROVIDERS),
                line_names=list(MEMORY_PROVIDERS),
                # plot_name is a required Benchmark argument; it also titles the plot.
                plot_name=f'{mode}-Linear+Loss Peak Memory. Defaults: N=B*S=16384, H=4096, V=131072',
                args={"mode": mode},  # Values for function arguments not in `x_names` and `y_name`.
            ))
        
@triton.testing.perf_report(configs)
def benchmark(H=4096, V=131072, N=4096 * 4, provider="torch", mode="fwd"):
    """Peak extra GPU memory (GB) of one linear + cross-entropy implementation.

    Draws bf16 inputs:
      * x (N, H) — the inputs.
      * y (N,) — the labels.
      * A (V, H) — the weights.
      * At = A.T (H, V) — an independent leaf copy.

    ``mode`` selects what is measured:
      * "fwd" — a forward with autograd enabled, so memory saved for backward
        is included.
      * "bwd" — backward through a retained graph.
      * "fwd-bwd" — forward plus backward.

    "triton" denotes the high-memory kernel (``linear_cross_entropy_highmem``).
    A failing configuration (e.g. OOM) is printed and reported as NaN instead
    of aborting the sweep.

    Returns:
        (peak GB, 0, 0) — the zeros fill ``perf_report``'s min/max slots.
    """
    x = torch.randn(N, H, requires_grad=True, device=device, dtype=torch.bfloat16) # B S H
    y = torch.randint(0, V, (N,), device=device) # vocab ** B S 
    A = torch.randn(V, H, requires_grad=True, device=device, dtype=torch.bfloat16)
    # Independent leaf. A non-leaf A.clone().T would route gradients back through
    # an extra (V, H) transpose copy into A.grad.
    At = A.detach().clone().T.contiguous()
    At.requires_grad_()

    # provider name -> zero-argument loss closure. F.linear-style callers get the (V, H) A.
    loss_fns = {
        "torch": lambda: baseline_torch(x, y, A),
        "torch-compile": lambda: compiled_baseline(x, y, A),
        "torch-compile-checkpoint": lambda: torch_compiled_checkpoint(x, y, A),
        "triton": lambda: linear_cross_entropy_highmem(x, y, At),
        "triton-nolock-nowrite": lambda: linear_cross_entropy_nolock(x, y, At),
        "triton-recomp": lambda: linear_cross_entropy_double_recomp(x, y, At),
        "triton-par-recomp": lambda: linear_cross_entropy_parallel_recomp(x, y, At),
        "triton-many-recomp": lambda: linear_cross_entropy_manyway(x, y, At),
        "malek": lambda: efficient_xent(x, y, A),
    }
    if provider not in loss_fns:
        raise ValueError(f"unknown provider {provider!r}")
    fn = loss_fns[provider]

    try:
        if mode == "bwd":
            loss = fn()
            measured_fn = lambda: loss.backward(retain_graph=True)
        elif mode == "fwd-bwd":
            measured_fn = lambda: fn().backward()
        else:  # "fwd"
            measured_fn = fn

        quantiles = [0.5, 0.2, 0.8]
        ms, min_ms, max_ms, max_memory_allocated = benchmark_with_memory_reporting(measured_fn, quantiles=quantiles)
    except Exception as err:  # e.g. OOM at large sizes: record it and keep sweeping
        print(f"error when measuring {provider} for N={N}, H={H}, V={V}, mode={mode}: {type(err).__name__}: {err}")
        return float("nan"), 0, 0

    return max_memory_allocated / 1024**3, 0, 0

benchmark.run(print_data=True, show_plots=True)