In [None]:
%matplotlib inline

# Testing and Benchmarking Triton

**IMPORTANT**: Because the Triton compiler is involved, the kernel must be restarted after every modification of the `optimized.rmsnorm` module.

In [None]:
import sys
sys.path.append('..')
import torch
from baselines.rmsnorm import RMSNormL3, RMSNormGT1, MyRMSNorm
# Note, since triton compiler is involved, we must restart the kernel after each modification of the 
# optimized.rmsnorm module
from optimized.rmsnorm import RMSNormTriton
from matplotlib import pyplot as plt

## Results for Ground Truth

In [None]:
def printRMSNorm(rmsnorm_class, x : torch.Tensor = None):
    """Run an RMSNorm implementation on ``x`` and print the input and output.

    The module is built as ``rmsnorm_class(dim=x.shape[-1])`` so that the
    ``dim`` argument always matches the last axis of the input instead of
    being fixed to 3.

    Args:
        rmsnorm_class: Class to instantiate; it must accept a ``dim`` keyword
            argument and expose ``forward(x)``.
        x: Input tensor with at least one dimension.

    Raises:
        ValueError: If ``x`` is None.
    """
    if x is None:
        raise ValueError("printRMSNorm requires an input tensor `x`")

    print("--------------------")
    # Label the output with the class that is actually being exercised.
    print(f"{rmsnorm_class.__name__}: x.shape: ", end="")
    print(x.shape)

    # Compute the output of the implementation under test
    rmsnorm_test = rmsnorm_class(dim=x.shape[-1])
    expected_output = rmsnorm_test.forward(x)
    print(x)
    print(expected_output)

# Inputs for the ground-truth implementation: a 4x3 float32 matrix, followed by
# float16 inputs of rank 2 (1x3), rank 1 (3,) and rank 3 (2x2x3).
ground_truth_inputs = [
    torch.tensor([[1, 2, 3], [3, 3, 3], [7, 8, 9], [100, 150, 150]], dtype=torch.float32),
    torch.tensor([[1, 2, 3], ], dtype=torch.float16),
    torch.tensor([1, 2, 3], dtype=torch.float16),
    torch.tensor([[[1, 2, 3], [7, 8, 9]], [[100, 150, 150], [100, 150, 150]]], dtype=torch.float16),
]
for sample in ground_truth_inputs:
    printRMSNorm(RMSNormGT1, sample)

## Results for Triton

In [None]:
def printRMSNorm(rmsnorm_class, x : torch.Tensor = None):
    """Run an RMSNorm implementation on the GPU and print the input and output.

    Both the module and the input are moved with ``.cuda()`` before the
    forward pass. The module is built as ``rmsnorm_class(dim=x.shape[-1])``
    so that ``dim`` always matches the last axis of the input instead of
    being fixed to 3.

    Args:
        rmsnorm_class: Class to instantiate; it must accept a ``dim`` keyword
            argument and expose ``cuda()`` and ``forward(x)``.
        x: Input tensor with at least one dimension.

    Raises:
        ValueError: If ``x`` is None.
    """
    if x is None:
        raise ValueError("printRMSNorm requires an input tensor `x`")

    print("--------------------")
    # Label the output with the class that is actually being exercised.
    print(f"{rmsnorm_class.__name__}: x.shape: ", end="")
    print(x.shape)

    # Compute the output of the implementation under test on the GPU
    rmsnorm_test = rmsnorm_class(dim=x.shape[-1]).cuda()
    expected_output = rmsnorm_test.forward(x.cuda())
    # The original (un-moved) input is printed, followed by the GPU result.
    print(x)
    print(expected_output)

# One sequence: a single row (shape 1x3), float16
printRMSNorm(RMSNormTriton, torch.tensor([[1, 2, 3], ], dtype=torch.float16))

# Multiple sequences: four rows (shape 4x3), float32
printRMSNorm(RMSNormTriton, torch.tensor([[1, 2, 3], [3, 3, 3], [7, 8, 9], [100, 150, 150]], dtype=torch.float32))

# Multiple batches, with multiple sequences: 3-D input (shape 2x2x3), float16
printRMSNorm(RMSNormTriton, torch.tensor([[[1, 2, 3], [7, 8, 9]], [[100, 150, 150], [100, 150, 150]]], dtype=torch.float16))

In [None]:
# NOTE(review): this cell only prints a label (without a trailing newline) and no value;
# it looks like a leftover scratch cell — confirm whether it can be removed.
print("RMSNormTriton: x.shape: ", end="")

# Actual Performance Measurement

## Verify that outputs match

In [None]:
import triton

# Sanity check: both implementations must produce the same (or nearly the same) numbers.
# A fixed seed makes the random input reproducible.
torch.manual_seed(0)
seq_len, model_dim = 100, 4096
x = torch.randn(seq_len, model_dim, device='cuda')

# Triton implementation (kept around: the benchmark cells reuse this module)
rmsnormtriton = RMSNormTriton(dim=model_dim).cuda()
triton_output = rmsnormtriton.forward(x)

# Ground-truth PyTorch implementation (also reused by the benchmark cells)
rmsnormgt1 = RMSNormGT1(dim=model_dim).cuda()
torch_output = rmsnormgt1.forward(x)

# Compare using torch.allclose with its default tolerances and report the verdict
print(
    "✅ Triton and Torch match"
    if torch.allclose(triton_output, torch_output)
    else "❌ Triton and Torch differ"
)

## Measure Runtime

In [None]:
# Sequence lengths to sweep: 1, 128, 256, then every multiple of 512 from 512 up to 12,288
x_vals = [1, 128, 256] + [512 * i for i in range(1, 25)]

# Value of model_dim
model_dim = model_dim # no-op self-assignment: model_dim comes from the cell above, where the modules are pre-allocated with it

In [None]:
plot_name = 'runtime'

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['SeqLen'],  # argument names to use as an x-axis for the plot
        # Sweep sequence lengths from 1 up to the largest entry of x_vals.
        # Long sequences are included because longer-context extensions are possible and
        # because the author observed they were needed to saturate V100 memory bandwidth;
        # 1, 128 and 256 are included to show the effect of small sequence lengths.
        x_vals=x_vals,

        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            'triton',
            'torch-native',
            'torch-compile'
        ],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (compiled)"
        ],  # label name for the lines

        styles=[('blue', '-'), ('green', '-.'), ('green', '--')],  # line styles
        ylabel="ms",  # label name for the y-axis
        plot_name=plot_name,  # name for the plot. Used also as a file name for saving the plot.
        # The modules under test were pre-allocated with dim=model_dim, so the
        # benchmark input must use the same feature size.
        args={'N': model_dim},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(SeqLen, N, provider):
    """Time one RMSNorm forward pass on a (SeqLen, N) float16 CUDA tensor.

    Args:
        SeqLen: Number of rows of the random input.
        N: Number of columns of the random input.
        provider: One of 'triton', 'torch-native' or 'torch-compile'.

    Returns:
        (median_ms, p20_ms, p80_ms), i.e. the central value followed by the
        lower and upper bounds, in that order.

    Raises:
        ValueError: If ``provider`` is not one of the known values.
    """
    x = torch.randn(SeqLen, N, device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8] # report median, 20th and 80th percentiles

    if provider == 'triton':
        fn = rmsnormtriton
    elif provider == 'torch-native':
        fn = rmsnormgt1
    elif provider == 'torch-compile':
        # Wrap the module once, outside the timed callable, so that the cost of
        # calling torch.compile itself is not measured on every iteration.
        fn = torch.compile(rmsnormgt1)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    # do_bench returns one value per requested quantile, in the order given above.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: fn(x), quantiles=quantiles)

    # For a plot in ms the 20th percentile is the lower bound and the 80th the
    # upper bound, so the order is (value, min, max).
    return ms, min_ms, max_ms


benchmark.run(show_plots=False, print_data=True, save_path='./prof/')
plt.show()

## Memory Bandwidth

In [None]:
plot_name = 'bandwidth'

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['SeqLen'],  # argument names to use as an x-axis for the plot
        # Sweep sequence lengths from 1 up to the largest entry of x_vals.
        # Long sequences are included because longer-context extensions are possible and
        # because the author observed they were needed to saturate V100 memory bandwidth;
        # 1, 128 and 256 are included to show the effect of small sequence lengths.
        x_vals=x_vals,

        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            'triton',
            'torch-native',
            'torch-compile'
        ],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "Torch (native)",
            "Torch (compiled)"
        ],  # label name for the lines

        styles=[('blue', '-'), ('green', '-.'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name=plot_name,  # name for the plot. Used also as a file name for saving the plot.
        # The modules under test were pre-allocated with dim=model_dim, so the
        # benchmark input must use the same feature size.
        args={'N': model_dim},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(SeqLen, N, provider):
    """Measure the effective memory bandwidth of one RMSNorm forward pass.

    Args:
        SeqLen: Number of rows of the random float16 CUDA input.
        N: Number of columns of the random float16 CUDA input.
        provider: One of 'triton', 'torch-native' or 'torch-compile'.

    Returns:
        (GB/s at the median time, lower bound, upper bound).

    Raises:
        ValueError: If ``provider`` is not one of the known values.
    """
    x = torch.randn(SeqLen, N, device='cuda', dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8] # report median, 20th and 80th percentiles

    if provider == 'triton':
        fn = rmsnormtriton
    elif provider == 'torch-native':
        fn = rmsnormgt1
    elif provider == 'torch-compile':
        # Wrap the module once, outside the timed callable, so that the cost of
        # calling torch.compile itself is not measured on every iteration.
        fn = torch.compile(rmsnormgt1)
    else:
        raise ValueError(f"Unknown provider: {provider!r}")

    # do_bench returns one value per requested quantile, in the order given above.
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: fn(x), quantiles=quantiles)

    def gbps(elapsed_ms):
        # The byte count covers the tensor twice (presumably one read of the
        # input plus one write of an equally sized output — TODO confirm).
        return 2 * x.nelement() * x.element_size() * 1e-9 / (elapsed_ms * 1e-3)

    # Bandwidth is inversely proportional to time, so the slowest quantile gives
    # the lower bound and the fastest quantile gives the upper bound.
    return gbps(ms), gbps(max_ms), gbps(min_ms)


benchmark.run(show_plots=False, print_data=True, save_path='./prof/')
plt.show()

## Torch Profiler

In [None]:

# Problem size for the profiler run; both values also end up in the trace file name
seq_len = 4096
model_dim = 4096

def runRMSNorm(rmsnorm_class, x : torch.Tensor = None):
    """Instantiate ``rmsnorm_class`` on CUDA and run a single forward pass on ``x``.

    The module is created with ``dim=model_dim`` (a notebook-level variable)
    and the result of the forward pass is discarded; nothing is returned.
    """
    module = rmsnorm_class(dim=model_dim).cuda()
    gpu_input = x.cuda()
    module.forward(gpu_input)

# In this example with wait=1, warmup=1, active=2, repeat=1,
# the profiler will skip the first step/iteration,
# start warming up on the second, record
# the third and the fourth iterations,
# after which the trace will become available
# and on_trace_ready (when set) is called;
# the cycle repeats starting with the next step
wait_val = 1
warmup_val = 1
active_val = 2
repeat_val = 1

def trace_handler(prof):
    """Profiler callback: print the averaged op table and export a Chrome trace.

    The table is sorted by self CUDA time with no row limit. The trace is
    written under ``prof/`` with the notebook-level ``seq_len`` and
    ``model_dim`` and the profiler's current step number in the file name.
    """
    summary_table = prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)
    print(summary_table)

    trace_path = "prof/rmsnorm_trace_{}x{}_{}.json".format(seq_len, model_dim, prof.step_num)
    prof.export_chrome_trace(trace_path)

# Profile CUDA activity only, following the wait/warmup/active/repeat schedule
# configured above; trace_handler is invoked whenever a trace becomes available.
with torch.profiler.profile(
    activities=[
        # torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ],

    schedule=torch.profiler.schedule(wait=wait_val,
                                     warmup=warmup_val,
                                     active=active_val,
                                     repeat=repeat_val),
    on_trace_ready=trace_handler
    # Outputting for tensorboard? Use this instead:
    # on_trace_ready=torch.profiler.tensorboard_trace_handler('./prof')
) as p:
    # The loop variable is deliberately not called `iter`: at notebook scope that
    # would shadow the builtin for every later cell.
    for step_idx in range(10):
        x = torch.randn(seq_len, model_dim, device='cuda', dtype=torch.float16)
        runRMSNorm(RMSNormTriton, x)
        # Tell the profiler that one step/iteration has finished.
        p.step()


In [None]:
import re
def get_cuda_time_in_ms(the_table_str: str):
    """Extract the "Self CUDA time total" value from a profiler table, in ms.

    Args:
        the_table_str: Text of a profiler summary table containing a line such
            as ``Self CUDA time total: 1.234ms``.

    Returns:
        The time as a float in milliseconds, or None when no line carries the
        label together with a parsable ``<number><unit>`` value.
    """
    label = "Self CUDA time total:"
    for line in the_table_str.splitlines():
        if label not in line:
            continue
        # Only look at the text after the label; accept values with or without
        # a fractional part.
        match = re.search(r"(\d+(?:\.\d+)?)\s*([A-Za-z]+)", line.split(label, 1)[1])
        if not match:
            continue
        time_value = float(match.group(1))
        unit = match.group(2)
        if unit == 'us':  # microseconds -> milliseconds
            time_value /= 1000
        elif unit == 's':  # seconds -> milliseconds
            time_value *= 1000
        # Any other unit (normally 'ms') is returned unchanged.
        return time_value
    return None

# Parse the total self CUDA time (in ms) out of the summary table of the profiler object `p`
cuda_time_in_ms = get_cuda_time_in_ms(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=1))
print(cuda_time_in_ms)

## Log Artifacts with MLFlow

In [None]:
import mlflow

experiment_name = "rmsnorm"
implementation_name = "RMSNormTriton"
implementation_type = "triton"

mlflow.set_experiment(experiment_name)
experiment = mlflow.get_experiment_by_name(experiment_name)

# The trace handler names its file after seq_len, model_dim and the profiler step
# at which the trace became available. With the schedule used above, that step
# is wait + warmup + active for the first cycle, so the name is derived here
# instead of being hard-coded.
trace_step = wait_val + warmup_val + active_val

# List of files to log
files_to_log = [
    '../optimized/rmsnorm.py',
    'prof/bandwidth.csv',
    'prof/bandwidth.png',
    'prof/runtime.csv',
    'prof/runtime.png',
    f'prof/rmsnorm_trace_{seq_len}x{model_dim}_{trace_step}.json',
]

# Start an MLflow run
with mlflow.start_run(experiment_id=experiment.experiment_id) as run:
    # Log each file as an artifact
    for artifact_path in files_to_log:
        mlflow.log_artifact(artifact_path)

    # Log the experiment configuration
    mlflow.log_param("seq_len", seq_len)
    mlflow.log_param("model_dim", model_dim)
    mlflow.log_param("implementation", implementation_name)
    mlflow.log_param("implementation_type", implementation_type)

    # The parser returns None when the profiler table had no CUDA total;
    # skip the metric in that case instead of failing the whole run.
    if cuda_time_in_ms is not None:
        mlflow.log_metric("cuda_time_in_ms", cuda_time_in_ms)