In [None]:
import jax
import triton
import triton.language as tl
import jax_triton as jt
import jax.numpy as jnp

In [3]:
def leaky_relu_dropout(x, rate=0.5, key=None):
    """
    Applies Leaky ReLU activation followed by (inverted) Dropout.

    Parameters
    ----------
    x : jax.numpy.ndarray
        The input array to which the activation and dropout mask are applied.
    rate : float, optional
        The dropout rate, i.e. the probability of zeroing an element.
        Default is 0.5.
    key : jax PRNG key, optional
        Key used to draw the dropout mask. When omitted,
        ``jax.random.PRNGKey(0)`` is used, so repeated calls without a key
        always produce the same mask.

    Returns
    -------
    y: jax.numpy.ndarray
        The input array x after applying Leaky ReLU and dropout. Kept
        elements are scaled by ``1 / (1 - rate)``; dropped elements are 0.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    # Apply Leaky ReLU (negative slope 0.01)
    activated = jnp.where(x >= 0, x, 0.01 * x)

    # Apply Dropout
    keep_prob = 1.0 - rate
    rand_tensor = jax.random.uniform(key, x.shape)
    keep = rand_tensor > rate

    # Select instead of multiplying by a 0/1 mask: with rate == 1.0 the
    # multiplicative form evaluates 0 / 0 and returns NaN everywhere, whereas
    # the select returns exact zeros for every dropped element.
    y = jnp.where(keep, activated / keep_prob, 0.0)

    return y

In [None]:
# Build the fixed-seed 1823 x 781 input shared by the timing cells below.
# NOTE(review): with jax.random.uniform's default range the samples should all
# be non-negative, so the negative Leaky ReLU branch is presumably never
# exercised by this input -- confirm that is intended.
key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, (1823, 781))
# Time the un-jitted reference; block_until_ready() waits for the result so
# the computation itself is included in the measurement.
%timeit leaky_relu_dropout(x).block_until_ready()

In [5]:
# Wrap the reference implementation in jax.jit and time it on the same input x.
jit_leaky_relu_dropout = jax.jit(leaky_relu_dropout)
%timeit jit_leaky_relu_dropout(x).block_until_ready()

In [32]:
@triton.jit
def leaky_dropout_kernel(
    x_ptr,
    output_ptr,
    rows: tl.constexpr,
    cols: tl.constexpr,
    p: tl.constexpr,
    seed: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Triton kernel that applies leaky ReLU and dropout to the input tensor x.

    The tensor is processed as a flat buffer of ``rows * cols`` elements; each
    program instance handles one contiguous chunk of ``block_size`` elements.

    Args:
    x_ptr: The pointer to the input tensor.
    output_ptr: The pointer to the output tensor.
    rows: The number of rows in the input tensor.
    cols: The number of columns in the input tensor.
    p: The probability of an element to be zeroed.
    seed: The seed for the random number generator.
    block_size: The block size for parallelization.
    """

    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * block_size
    offsets = block_start + tl.arange(0, block_size)

    # The operation is elementwise, so only the flat bound matters: lanes past
    # the last element (in the final, partially filled block) are masked off.
    mask = offsets < rows * cols

    # load data from x
    x = tl.load(x_ptr + offsets, mask=mask)

    # apply leaky relu
    x = tl.where(x >= 0.0, x, 0.01 * x)

    # draw one random number per element, keyed by (seed, flat offset)
    random = tl.rand(tl.full([], seed, tl.int32), offsets)
    x_keep = random > p

    # apply dropout: kept elements are rescaled by 1 / (1 - p), the rest are 0
    output = tl.where(x_keep, x / (1 - p), 0.0)

    # store output
    tl.store(output_ptr + offsets, output, mask=mask)

In [34]:
def triton_leaky_dropout(x, p=0.5, seed=123) -> jnp.ndarray:
    """
    Helper function to call leaky_dropout_kernel.

    Args:
    x: The input tensor. Any rank is accepted; the kernel is elementwise and
       walks the buffer as a flat array.
    p: The probability of an element to be zeroed. Defaults to 0.5.
    seed: The seed for the random number generator. Defaults to 123.

    Returns:
    A tensor with the same shape and dtype as x, but with leaky relu dropout applied.
    """

    n_elements = x.size

    # Nothing to compute for an empty input; returning here also avoids
    # requesting a kernel launch with a zero-sized grid.
    if n_elements == 0:
        return jnp.zeros(x.shape, x.dtype)

    out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)

    # The kernel only needs rows * cols == n_elements for its bounds mask, so
    # non-2-D inputs are described as a single column of n_elements rows.
    if x.ndim == 2:
        rows, cols = x.shape
    else:
        rows, cols = n_elements, 1

    # one program instance per block_size-element chunk of the flat buffer
    grid = lambda meta: (triton.cdiv(n_elements, meta['block_size']), )

    return jt.triton_call(
        x,
        kernel=leaky_dropout_kernel,
        out_shape=out_shape,
        grid=grid,
        rows=rows,
        cols=cols,
        p=p,
        seed=seed,
        block_size=1024,
        )

In [None]:
# Run the Triton version with p=0.5 and seed=1, then display the fraction of
# exactly-zero outputs; this is expected to be close to p if dropout behaves
# as intended.
y = triton_leaky_dropout(x, 0.5, 1)
jnp.sum(y==0.0)/(y.shape[0]*y.shape[1])

In [38]:
# Wrap the Triton-backed helper in jax.jit and time it on the same input x,
# using the helper's default p and seed.
jit_leaky_dropout_kernel = jax.jit(triton_leaky_dropout)
%timeit jit_leaky_dropout_kernel(x).block_until_ready()

In [45]:
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 100)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            'triton',
            'xla',
            'naive',
        ],  # possible values for `line_arg`
        line_names=[
            "Triton",
            "JAX (XLA)",
            "JAX (Naive)",
        ],  # label name for the lines
        styles=[('blue', '-'), ('green', '-'), ('green', '--')],  # line styles
        ylabel="GB/s",  # label name for the y-axis
        plot_name="dropout-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={'M': 4096},  # values for function arguments not in `x_names` and `y_name`
    ))
def benchmark(M, N, provider):
    """
    Measure the throughput of one leaky-ReLU + dropout implementation.

    Args:
    M: Number of rows of the random input.
    N: Number of columns of the random input.
    provider: Which implementation to time: 'triton', 'xla' or 'naive'.

    Returns:
    A tuple of GB/s figures computed from the 0.5, 0.8 and 0.2 timing
    quantiles, in that order.

    Raises:
    ValueError: If provider is not one of the known implementations.
    """
    x = jax.random.normal(jax.random.PRNGKey(0), (M, N))
    quantiles = [0.5, 0.2, 0.8]

    # Names are resolved only in the selected branch, so one provider can be
    # benchmarked even if another implementation is not defined.
    if provider == 'xla':
        fn = jit_leaky_relu_dropout
    elif provider == 'triton':
        fn = jit_leaky_dropout_kernel
    elif provider == 'naive':
        fn = leaky_relu_dropout
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    # Block on the result so each timed call includes the computation itself,
    # as the %timeit cells in this notebook do.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(x).block_until_ready(), quantiles=quantiles)

    # Each element is read once and written once. Plain Python ints are used
    # so the byte count cannot wrap in a fixed-width integer array.
    element_size = jnp.dtype(x.dtype).itemsize
    bytes_moved = 2 * x.size * element_size
    gbps = lambda ms: bytes_moved * 1e-9 / (ms * 1e-3)
    # the slowest quantile gives the lowest throughput, hence max_ms before min_ms
    return gbps(ms), gbps(max_ms), gbps(min_ms)

In [None]:
# Run the benchmark sweep, requesting both the printed data and the plots.
benchmark.run(print_data=True, show_plots=True)