In [1]:
!pip install triton tabulate pandas matplotlib

Defaulting to user installation because normal site-packages is not writeable
Collecting triton
  Downloading triton-2.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (63.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.2/63.2 MB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Collecting cmake
  Downloading cmake-3.25.2-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m102.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting lit
  Downloading lit-15.0.7.tar.gz (132 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.3/132.3 kB[0m [31m37.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: lit
  Building wheel for lit (setup.py) ... [?25ldone
[?25h  Creat

In [3]:
import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def _dropout(
        x_ptr,  # pointer to the input
        x_keep_ptr,  # pointer to a mask of 0s and 1s
        output_ptr,  # pointer to the output
        n_elements,  # number of elements in the `x` tensor
        p,  # probability that an element of `x` is changed to zero
        BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist()
]))

---------  --------  --------  ---------  --------  -------  ---------  --------  ---------  --------  -------
input      -1.75378  0.495332  -0.793169  -1.17336  2.27975  0.0690164  0.410273  -0.796761  0.769278  0.55273
keep mask   0        0          1          0        0        0          0          1         1         0
output      0        0         -1.58634    0        0        0          0         -1.59352   1.53856   0
---------  --------  --------  ---------  --------  -------  ---------  --------  ---------  --------  -------


In [3]:
@triton.jit
def _seeded_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seed,
        BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (seed = 123)"] + output.tolist(),
    ["output (seed = 123)"] + output2.tolist(),
    ["output (seed = 512)"] + output3.tolist()
]))

-------------------  --------  -------  ---------  -------  ---------  ---------  --------  --------  --------  --------
input                0.503197  1.18018  -0.714294  1.78129  -0.106496  -0.174402  0.734475  0.350418  0.307363  -1.54676
output (seed = 123)  0         2.36036   0         0         0         -0.348805  0         0         0.614726  -3.09352
output (seed = 123)  0         2.36036   0         0         0         -0.348805  0         0         0.614726  -3.09352
output (seed = 512)  0         0        -1.42859   3.56258   0         -0.348805  1.46895   0         0          0
-------------------  --------  -------  ---------  -------  ---------  ---------  --------  --------  --------  --------


In [21]:
# Exercise 1: dropout for matrix with vector of seeds, 1 seed per row
@triton.jit
def _seeded_matrix_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seeds,
        n_elements_per_row: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    # index of first row in block.
    # e.g. if BLOCK_SIZE is 1024 and this is block 3, then the first row is row 3 * 1024 in the original tensor
    start_row_index = pid * BLOCK_SIZE
    block_start = start_row_index * n_elements_per_row

    # offsets is now a BLOCK_SIZE x n_elements_per_row matrix
    row_offsets = (tl.arange(0, BLOCK_SIZE) * n_elements_per_row)[:, None] # left operand is how many rows down we go in current block, right is num elements for that row
    col_offsets = tl.arange(0, n_elements_per_row)[None, :]
    offsets = block_start + row_offsets + col_offsets

    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    # my_seed = 13
    random_values = tl.zeros((BLOCK_SIZE, n_elements_per_row), dtype=tl.float32) # one random value per row
    for row in range(0, BLOCK_SIZE, 1):
        # note that we get seed from start_row_index + row, NOT row by itself
        # we need start_row_index to get the right index into seeds vector.

        # random_mask one-hot for current row
        # offsets % block_start effectively makes offsets like a matrix [[0, 1, 2, ... n_elements_per_row - 1], [n_elements_per_row, etc]]
        # then just accept where offsets / n_elements_per_row == row
        random_mask = tl.where((offsets % block_start) / n_elements_per_row, 1.0, 0.0)
        # TODO: not sure how exactly to do this? idea is to select the seed from a specific index in seeds
        # but that fails with basically no error message besides
        # ValueError: Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)
        # which is wrong obviously.     
        my_seed = seeds
        random_values += tl.rand(my_seed, random_values) * random_mask
        # random_values += tl.rand(seeds[start_row_index + row], random_values) * random_mask
        # random_values[row] = tl.rand(seeds[start_row_index + row], col_offsets) # fill out one row at a time just isn't supported
    # random = tl.rand(my_seed, offsets)
    x_keep = random_values > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_matrix_dropout(x, p, seeds):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    assert seeds.shape == (x.shape[0],), f"seeds should be length of num rows but instead got seeds.shape {seeds.shape} and x.shape {x.shape}"
    # raise AssertionError("ok")
    n_elements = x.numel()
    # print(triton.cdiv(n_elements, 1024))
    # print(n_elements)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) # for now, treat BLOCK_SIZE as the number of rows per block.
    n_elements_per_row = x.shape[1]
    _seeded_matrix_dropout[grid](x, output, n_elements, p, seeds, n_elements_per_row, BLOCK_SIZE=1024)
    return output

x = torch.randn(size=(3, 2)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_matrix_dropout(x, p=0.5, seeds=torch.tensor([123, 101, 1]).cuda())
output2 = seeded_matrix_dropout(x, p=0.5, seeds=torch.tensor([123, 101, 2]).cuda())
output3 = seeded_matrix_dropout(x, p=0.5, seeds=torch.tensor([512, 101, 3]).cuda())

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (first set of seeds)"] + output.tolist(),
    ["output (2nd set)"] + output2.tolist(),
    ["output (3rd set)"] + output3.tolist()
]))

---------------------------  -------------------------------------------  ----------------------------------------  -----------------------------------------
input                        [-0.8444323539733887, 0.004563711117953062]  [-1.417240858078003, 0.5809898972511292]  [0.4518235921859741, 0.47092223167419434]
output (first set of seeds)  [-1.6888647079467773, 0.009127422235906124]  [-2.834481716156006, 1.1619797945022583]  [0.9036471843719482, 0.9418444633483887]
output (2nd set)             [-1.6888647079467773, 0.009127422235906124]  [-2.834481716156006, 1.1619797945022583]  [0.9036471843719482, 0.9418444633483887]
output (3rd set)             [-1.6888647079467773, 0.009127422235906124]  [-2.834481716156006, 1.1619797945022583]  [0.9036471843719482, 0.9418444633483887]
---------------------------  -------------------------------------------  ----------------------------------------  -----------------------------------------


In [21]:
block_size = 1024
n_elements_per_row = 10
n_elements_per_row * torch.arange(0, 1024)[:, None] + torch.arange(0, n_elements_per_row)[None, :]

tensor([[    0,     1,     2,  ...,     7,     8,     9],
        [   10,    11,    12,  ...,    17,    18,    19],
        [   20,    21,    22,  ...,    27,    28,    29],
        ...,
        [10210, 10211, 10212,  ..., 10217, 10218, 10219],
        [10220, 10221, 10222,  ..., 10227, 10228, 10229],
        [10230, 10231, 10232,  ..., 10237, 10238, 10239]])