In [1]:
!pip install triton tabulate pandas matplotlib

Defaulting to user installation because normal site-packages is not writeable
Collecting triton
  Downloading triton-2.0.0.post1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (63.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.2/63.2 MB[0m [31m17.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tabulate
  Downloading tabulate-0.9.0-py3-none-any.whl (35 kB)
Collecting lit
  Downloading lit-16.0.0.tar.gz (144 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.0/145.0 kB[0m [31m37.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25ldone
Collecting cmake
  Downloading cmake-3.26.0-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (24.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.0/24.0 MB[0m [31m106.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Building wheels for collected packages: lit
  Building wheel for lit (setup.py) ... [?25ldone
[?25h 

In [2]:
import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def _dropout(
        x_ptr,  # pointer to the input
        x_keep_ptr,  # pointer to a mask of 0s and 1s
        output_ptr,  # pointer to the output
        n_elements,  # number of elements in the `x` tensor
        p,  # probability that an element of `x` is changed to zero
        BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
x = torch.randn(size=(10,)).cuda()
# Dropout mask
p = 0.5
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist()
]))

---------  --------  --------  -------  --------  --------  ---------  --------  -------  -------  ---------
input      0.398164  -1.14106  1.46103  -1.21333  -1.61976  -0.918388  -1.26812  0.84649  2.82668  -0.473958
keep mask  1          1        0         0         1         1          1        0        1         0
output     0.796329  -2.28213  0         0        -3.23951  -1.83678   -2.53624  0        5.65337   0
---------  --------  --------  -------  --------  --------  ---------  --------  -------  -------  ---------


In [3]:
@triton.jit
def _seeded_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seed,
        BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random = tl.rand(seed, offsets)
    x_keep = random > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["output (seed = 123)"] + output.tolist(),
    ["output (seed = 123)"] + output2.tolist(),
    ["output (seed = 512)"] + output3.tolist()
]))

-------------------  ---------  -------  --------  --------  ---------  ------  --------  -------  ---------  --------
input                -0.151539  1.10336  -1.73476  0.939393  -0.961672  1.0487  -1.82303  2.13982  -0.300116  -1.22286
output (seed = 123)   0         2.20672   0        0          0         2.0974   0        0        -0.600231  -2.44572
output (seed = 123)   0         2.20672   0        0          0         2.0974   0        0        -0.600231  -2.44572
output (seed = 512)   0         0        -3.46951  1.87879    0         2.0974  -3.64605  0         0          0
-------------------  ---------  -------  --------  --------  ---------  ------  --------  -------  ---------  --------


In [282]:
import numpy as np
# Exercise 1: dropout for matrix with vector of seeds, 1 seed per row
@triton.jit
def _seeded_matrix_dropout(
        x_ptr,
        output_ptr,
        debug_random_mask_ptr,
        debug_seeds_ptr,
        n_elements,
        p,
        seeds,
        n_elements_per_row: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
):
    # compute memory offsets of elements handled by this instance
    pid = tl.program_id(axis=0)
    # index of first row in block.
    # e.g. if BLOCK_SIZE is 1024 and this is block 3, then the first row is row 3 * 1024 in the original tensor
    start_row_index = pid * BLOCK_SIZE
    block_start = start_row_index * n_elements_per_row

    # offsets is now a BLOCK_SIZE x n_elements_per_row matrix
    row_offsets = (tl.arange(0, BLOCK_SIZE) * n_elements_per_row)[:, None] # left operand is how many rows down we go in current block, right is num elements for that row
    col_offsets = tl.arange(0, n_elements_per_row)[None, :]
    offsets = block_start + row_offsets + col_offsets

    # load data from x
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # randomly prune it
    random_values = tl.zeros((BLOCK_SIZE, n_elements_per_row), dtype=tl.float32) # one random value per row
    for row in range(BLOCK_SIZE):
        # note that we get seed from start_row_index + row, NOT row by itself
        # we need start_row_index to get the right index into seeds vector.

        # random_mask one-hot for current row
        # offsets % block_start effectively makes offsets like a matrix [[0, 1, 2, ... n_elements_per_row - 1], [n_elements_per_row, etc]]
        # then just accept where offsets / n_elements_per_row == row
        random_mask = tl.where(((row_offsets + col_offsets) // n_elements_per_row) == row, 1.0, 0.0)
        # Note: the following doesn't work becaues block_start could be 0 for the first block and mod 0 is invalid
        # Triton won't error out, it'll just silently give you an invalid result.
        # random_mask = tl.where(((offsets % block_start) // n_elements_per_row) == row, 1.0, 0.0)
        if row == 1:
            tl.store(debug_random_mask_ptr + row_offsets + col_offsets, random_mask, mask=mask)

        random_values += (tl.rand(seeds + start_row_index + row, random_values) * random_mask)
        # I'm guessing that, because these are all pointers, you probably can't just index into arrays/tensors the way you would in C or python
        # That's why we have the pointer arithmetic for random_values instead
        # random_values += tl.rand(seeds[start_row_index + row], random_values) * random_mask
        seed_value = tl.load(seeds + start_row_index + row)
        # seed_value = start_row_index + row
        tl.store(debug_seeds_ptr + start_row_index + row, seed_value)
    x_keep = random_values > p
    # write-back
    output = tl.where(x_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_matrix_dropout(x, p, seeds):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    assert seeds.shape == (x.shape[0],), f"seeds should be length of num rows but instead got seeds.shape {seeds.shape} and x.shape {x.shape}"
    # raise AssertionError("ok")
    n_elements = x.numel()
    # print(triton.cdiv(n_elements, 1024))
    # print(n_elements)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) # for now, treat BLOCK_SIZE as the number of rows per block.
    n_elements_per_row = x.shape[1]
    debug_random_mask_ptr = torch.zeros((1024, n_elements_per_row)).cuda()
    debug_seeds_ptr = torch.zeros(seeds.shape).cuda()
    assert seeds.shape == debug_seeds_ptr.shape
    _seeded_matrix_dropout[grid](x, output, debug_random_mask_ptr, debug_seeds_ptr, n_elements, p, seeds, n_elements_per_row, BLOCK_SIZE=1024)
    return output, debug_random_mask_ptr, debug_seeds_ptr

x = torch.randn(size=(1000, 4)).cuda() # n elements per row must be a power of 2
# Compare this to the baseline - dropout mask is never instantiated!
output1, debug_random_mask_ptr1, debug_seeds_ptr1 = seeded_matrix_dropout(x, p=0.75, seeds=torch.rand(x.shape[0]).cuda())
# seeds should match for rows 0 and 1, but not for remaining ones
seed2 = torch.cat((torch.tensor([3]), torch.tensor([5]), torch.tensor(np.random.randint(0, 100, size=(x.shape[0]-2))))).cuda()
seed3 = torch.cat((torch.tensor([3]), torch.tensor([5]), torch.tensor(np.random.randint(0, 100, size=(x.shape[0]-2))))).cuda()
assert not torch.all(torch.eq(seed2, seed3))

# Synchronize calls-- wait for all previous cuda code to finish, so pid=0 is ready for each of these seeded_matrix_dropout calls
torch.cuda.synchronize()
output2, debug_random_mask_ptr2, debug_seeds_ptr2 = seeded_matrix_dropout(x, p=0.5, seeds=seed2)
torch.cuda.synchronize()
output3, debug_random_mask_ptr3, debug_seeds_ptr3 = seeded_matrix_dropout(x, p=0.5, seeds=seed3)
torch.cuda.synchronize()
print(debug_seeds_ptr2[0:7])
print(debug_seeds_ptr3[0:7])

assert torch.all(torch.eq(debug_random_mask_ptr1, debug_random_mask_ptr2))
assert torch.all(torch.eq(debug_random_mask_ptr2, debug_random_mask_ptr3))
random_mask_should_be = torch.zeros((1024, x.shape[1])).cuda()
random_mask_should_be[1, :] = torch.ones(x.shape[1])
assert torch.all(torch.eq(random_mask_should_be, debug_random_mask_ptr1)) # check that the random mask is set correctly each time


# TODO: sometimes these two assertions fail and it's unclear why
assert torch.all(torch.eq(output2[0:2, :], output3[0:2, :])), f"got output2 {output2[0:2, :]} and otuput3 {output3[0:2, :]} from x {x[0:2, :]}"
print(output2)
print(output3)
assert not torch.all(torch.eq(output2, output3))
# output2
# print("ok")

# GPT-4 SUGGESTIONS
suggest = """
It seems like the issue is related to the seed values being loaded incorrectly sometimes. One possible reason could be a race condition or synchronization issue in the Triton kernel. To debug this further, you can try the following steps:

1. Add more print statements in the Triton kernel to check the values of `start_row_index`, `row`, and `seeds` at different points in the kernel execution. This can help you identify if there's any issue with the indexing or pointer arithmetic.

2. Check if there's any issue with the input `seeds` tensor. You can print the `seeds` tensor before passing it to the Triton kernel and compare it with the `debug_seeds_ptr` tensor after the kernel execution.

3. Try running the Triton kernel with a smaller block size and see if the issue persists. This can help you identify if the issue is related to the block size or the way the blocks are being scheduled.

4. You can also try using Triton's built-in synchronization primitives like `tl.sync` and `tl.syncwarp` to ensure that all threads in a block are synchronized before loading the seed values. This can help you identify if there's any race condition or synchronization issue in the kernel.
"""

tensor([ 3.,  5., 64., 54., 31.,  6., 64.], device='cuda:0')
tensor([-2.4132e+18, -2.4139e+18,  6.8079e+18,  6.8090e+18,  6.8095e+18,
        -2.4133e+18, -2.4135e+18], device='cuda:0')


AssertionError: got output2 tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [-3.3235,  3.0144,  1.6038, -3.5172]], device='cuda:0') and otuput3 tensor([[-0.5733, -1.2331,  2.1167, -0.6417],
        [ 0.0000,  0.0000,  0.0000,  0.0000]], device='cuda:0') from x tensor([[-0.2867, -0.6166,  1.0584, -0.3209],
        [-1.6617,  1.5072,  0.8019, -1.7586]], device='cuda:0')

In [263]:
torch.all(torch.eq(debug_random_mask_ptr1.cpu(), torch.zeros((1024, x.shape[1]))))

tensor(False)

In [85]:
torch.all(torch.eq(debug_seeds_ptr2.cpu(), torch.arange(1024, 2024, 1)))

tensor(True)

In [42]:
1025 / 2

512.5

In [43]:
debug_row_indices

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        ...,
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')