In [1]:
!pip install triton tabulate pandas matplotlib

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3[0m[39;49m -> [0m[32;49m23.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [2]:
import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def _dropout(
        x_ptr,  # pointer to the input
        x_keep_ptr,  # pointer to a mask of 0s and 1s
        output_ptr,  # pointer to the output
        n_elements,  # number of elements in the `x` tensor
        p,  # probability that an element of `x` is changed to zero
        BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist()
]))

---------  ---------  --------  ---------  -------  --------  --------  -------  ---------  ---------  -------
input      0.0185069  -1.49556  0.0994412  1.10682  0.193011  0.279261  1.63456  -0.169297  -0.602044  1.17376
keep mask  0           1        1          1        1         1         1         0          1         1
output     0          -2.99111  0.198882   2.21363  0.386023  0.558523  3.26912   0         -1.20409   2.34752
---------  ---------  --------  ---------  -------  --------  --------  -------  ---------  ---------  -------


In [3]:
@triton.jit
def _seeded_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seed,
        BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (seed = 123)"] + output.tolist(),
    ["output (seed = 123)"] + output2.tolist(),
    ["output (seed = 512)"] + output3.tolist()
]))

-------------------  ---------  --------  -------  ---------  ----------  --------  --------  --------  ---------  ----------
input                -0.957525  -1.15484  1.08971  -0.288357  -0.0398638  -1.21333  0.159923  -1.41682  -0.16531   -0.0420467
output (seed = 123)   0         -2.30969  0         0          0          -2.42667  0          0        -0.330619  -0.0840935
output (seed = 123)   0         -2.30969  0         0          0          -2.42667  0          0        -0.330619  -0.0840935
output (seed = 512)   0          0        2.17943  -0.576715   0          -2.42667  0.319847   0         0          0
-------------------  ---------  --------  -------  ---------  ----------  --------  --------  --------  ---------  ----------


In [11]:
# Exercise 1: dropout for matrix with vector of seeds, 1 seed per row
@triton.jit
def _seeded_matrix_dropout(
        x_ptr,
        output_ptr,
        debug_random_mask_ptr,
        n_elements,
        p,
        seeds,
        n_elements_per_row: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    # index of first row in block.
    # e.g. if BLOCK_SIZE is 1024 and this is block 3, then the first row is row 3 * 1024 in the original tensor
    start_row_index = pid * BLOCK_SIZE
    block_start = start_row_index * n_elements_per_row

    # offsets is now a BLOCK_SIZE x n_elements_per_row matrix
    row_offsets = (tl.arange(0, BLOCK_SIZE) * n_elements_per_row)[:, None] # left operand is how many rows down we go in current block, right is num elements for that row
    col_offsets = tl.arange(0, n_elements_per_row)[None, :]
    offsets = block_start + row_offsets + col_offsets

    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random_values = tl.zeros((BLOCK_SIZE, n_elements_per_row), dtype=tl.float32) # one random value per row
    for row in range(BLOCK_SIZE):
        # note that we get seed from start_row_index + row, NOT row by itself
        # we need start_row_index to get the right index into seeds vector.

        # random_mask one-hot for current row
        # offsets % block_start effectively makes offsets like a matrix [[0, 1, 2, ... n_elements_per_row - 1], [n_elements_per_row, etc]]
        # then just accept where offsets / n_elements_per_row == row
        random_mask = tl.where(((row_offsets + col_offsets) // n_elements_per_row) == row, 1.0, 0.0)
        if row == 1:
            tl.store(debug_random_mask_ptr + row_offsets + col_offsets, random_mask, mask=mask)
        # TODO: ?? why does offsets % block_start not work????
        # random_mask = tl.where(((offsets % block_start) // n_elements_per_row) == row, 1.0, 0.0)
        # random_mask = tl.where(((offsets % block_start) // n_elements_per_row) == 1, 1.0, 0.0)
        # TODO: not sure how exactly to do this? idea is to select the seed from a specific index in seeds
        # but that fails with basically no error message besides
        # ValueError: Did you forget to add @triton.jit ? (`_builder` argument must be provided outside of JIT functions.)
        # which is wrong obviously.     
        random_values += (tl.rand(seeds + start_row_index + row, random_values) * random_mask)
        # random_values += tl.rand(seeds[start_row_index + row], random_values) * random_mask
        # random_values[row] = tl.rand(seeds[start_row_index + row], col_offsets) # fill out one row at a time just isn't supported
    # random = tl.rand(my_seed, offsets)
    x_keep = random_values > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_matrix_dropout(x, p, seeds):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    assert seeds.shape == (x.shape[0],), f"seeds should be length of num rows but instead got seeds.shape {seeds.shape} and x.shape {x.shape}"
    # raise AssertionError("ok")
    n_elements = x.numel()
    # print(triton.cdiv(n_elements, 1024))
    # print(n_elements)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) # for now, treat BLOCK_SIZE as the number of rows per block.
    n_elements_per_row = x.shape[1]
    debug_random_mask_ptr = torch.zeros((1024, n_elements_per_row)).cuda()
    _seeded_matrix_dropout[grid](x, output, debug_random_mask_ptr, n_elements, p, seeds, n_elements_per_row, BLOCK_SIZE=1024)
    return output, debug_random_mask_ptr

x = torch.randn(size=(3, 4)).cuda() # n elements per row must be a power of 2
# Compare this to the baseline - dropout mask is never instantiated!
output1, debug_random_mask_ptr1 = seeded_matrix_dropout(x, p=0.75, seeds=torch.rand(x.shape[0]).cuda())
output2, debug_random_mask_ptr2 = seeded_matrix_dropout(x, p=0.5, seeds=torch.tensor([3, 5, 2]).cuda())
output3, debug_random_mask_ptr3 = seeded_matrix_dropout(x, p=0.5, seeds=torch.tensor([0, 1, 3]).cuda())

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (first set of seeds)"] + output.tolist(),
    ["output (2nd set)"] + output2.tolist(),
    ["output (3rd set)"] + output3.tolist()
]))

assert torch.all(torch.eq(debug_random_mask_ptr1, debug_random_mask_ptr2))
assert torch.all(torch.eq(debug_random_mask_ptr2, debug_random_mask_ptr3))
random_mask_should_be = torch.zeros((1024, x.shape[1])).cuda()
random_mask_should_be[1, :] = torch.ones(x.shape[1])
assert torch.all(torch.eq(random_mask_should_be, debug_random_mask_ptr3))


---------------------------  ------------------------------------------------------------------------------------  -----------------------------------------------------------------------------------  ---------------------------------------------------------------------------------  -  -  --------  -  -  ---------  ----------
input                        [-1.3508756160736084, 0.11341709643602371, -0.24785758554935455, 1.0647841691970825]  [-2.004702091217041, -0.03761491924524307, 1.1699249744415283, -2.3878366947174072]  [1.4394214153289795, 1.8197963237762451, 0.48485055565834045, 0.5667889714241028]
output (first set of seeds)  0.0                                                                                   -2.3096888065338135                                                                  0.0                                                                                0  0  -2.42667  0  0  -0.330619  -0.0840935
output (2nd set)             [0.0, 0.0, 0.0, 0.0]               

AssertionError: 

In [8]:
debug_random_mask_ptr1.shape
debug_random_mask_ptr1


tensor([[0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')

In [41]:
1025 // 2

512

In [42]:
1025 / 2

512.5

In [66]:
block_size = 3
per_row = 4
data = torch.rand(block_size * 2, per_row)
# data
offsets = torch.arange(block_size * per_row).reshape(block_size, per_row) + block_size * per_row
block_start = block_size * per_row
row = 2
one_hot = ((offsets % block_start) // per_row) == row

print(data)
random_mask = data[:block_size, :per_row] * one_hot
p = 0.5
random_mask > p

tensor([[0.6744, 0.3973, 0.4403, 0.8433],
        [0.7333, 0.2533, 0.9708, 0.7336],
        [0.0207, 0.7678, 0.6073, 0.8185],
        [0.7854, 0.7236, 0.6716, 0.3804],
        [0.6580, 0.0869, 0.6180, 0.8752],
        [0.4964, 0.0204, 0.6187, 0.2358]])


  one_hot = ((offsets % block_start) // per_row) == row


tensor([[False, False, False, False],
        [False, False, False, False],
        [False,  True,  True,  True]])