In [10]:
import torch
import triton
import triton.language as tl
from triton.runtime import driver
import tabulate

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10, ), device=device)
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(
    tabulate.tabulate([
        ["input"] + x.tolist(),
        ["output (seed = 123)"] + output.tolist(),
        ["output (seed = 123)"] + output2.tolist(),
        ["output (seed = 512)"] + output3.tolist(),
    ]))


-------------------  ---------  ---------  --------  ---------  --------  --------  ---------  ---------  ---------  -------
input                -0.809575  -0.445639  0.856024  -0.184584  -1.19222  -1.284    -0.397066  0.0317655  -0.265393  0.81817
output (seed = 123)   0         -0.891278  0          0          0        -2.56801   0         0          -0.530786  1.63634
output (seed = 123)   0         -0.891278  0          0          0        -2.56801   0         0          -0.530786  1.63634
output (seed = 512)   0          0         1.71205   -0.369169   0        -2.56801  -0.794132  0           0         0
-------------------  ---------  ---------  --------  ---------  --------  --------  ---------  ---------  ---------  -------


# Extend the kernel to operate over a matrix and use a vector of seeds - one per row.

In [27]:
@triton.jit
def _seeded_dropout(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    p,
    seeds,
    BLOCK_SIZE: tl.constexpr,
):

    # the rows of the dropout are independent, so we parallelize across those
    row_idx = tl.program_id(0)
    # The stride represents how much we need to increase the pointer to advance 1 row
    # = num columns for *contiguous* 2d matrices
    row_start_ptr = input_ptr + row_idx * input_row_stride

    # The block size is the next power of two greater than n_cols, so we can fit each
    # row in a single block
    col_offsets = tl.arange(0, BLOCK_SIZE)
    input_ptrs = row_start_ptr + col_offsets

    # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols
    row = tl.load(input_ptrs, mask=col_offsets < n_cols)

    # get the seed of the row
    seed = tl.load(seeds + row_idx)

    # filter and scale
    random = tl.rand(seed, col_offsets)
    row_keep = random > p
    dropout_output = tl.where(row_keep, row / (1 - p), 0.0)

    # Write back output to DRAM
    output_row_start_ptr = output_ptr + row_idx * output_row_stride
    output_ptrs = output_row_start_ptr + col_offsets
    tl.store(output_ptrs, dropout_output, mask=col_offsets < n_cols)

def seeded_dropout(x, p, seeds):

    n_rows, n_cols = x.shape

    BLOCK_SIZE = triton.next_power_of_2(n_cols)

    output = torch.empty_like(x)

    n_elements = x.numel()

    _seeded_dropout[(n_rows, )](
        x,
        output,
        x.stride(0),
        output.stride(0),
        n_cols,
        p,
        seeds,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output

x = torch.randn(size=(2, 5), device=device)
# Compare this to the baseline - dropout mask is never instantiated!
seeds = torch.tensor([123, 132], device=device)
seeds2 = torch.tensor([123, 321], device=device)

output = seeded_dropout(x, p=0.5, seeds=seeds)
output2 = seeded_dropout(x, p=0.5, seeds=seeds2)

print(
    tabulate.tabulate([
        ["input"] + x.tolist(),
        ["output (seed = 123, 132)"] + output.tolist(),
        ["output (seed = 123, 321)"] + output2.tolist(),
    ]))


------------------------  --------------------------------------------------------------------------------------------------------  -------------------------------------------------------------------------------------------------------
input                     [-0.4968452453613281, -1.6862846612930298, 0.6759570837020874, -0.10347044467926025, 2.0680792331695557]  [0.21053476631641388, -1.7139315605163574, 0.7079384326934814, -1.3756039142608643, 1.4101020097732544]
output (seed = 123, 132)  [0.0, -3.3725693225860596, 0.0, 0.0, 0.0]                                                                 [0.0, 0.0, 1.415876865386963, 0.0, 0.0]
output (seed = 123, 321)  [0.0, -3.3725693225860596, 0.0, 0.0, 0.0]                                                                 [0.0, -3.427863121032715, 0.0, 0.0, 2.820204019546509]
------------------------  --------------------------------------------------------------------------------------------------------  -------------------------------------