In [1]:
!pip install triton tabulate pandas matplotlib

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.3[0m[39;49m -> [0m[32;49m23.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [2]:
# Given in the tutorial. Hardcoded dropout mask.

import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def _dropout(
        x_ptr,  # pointer to the input
        x_keep_ptr,  # pointer to a mask of 0s and 1s
        output_ptr,  # pointer to the output
        n_elements,  # number of elements in the `x` tensor
        p,  # probability that an element of `x` is changed to zero
        BLOCK_SIZE: tl.constexpr,
):
    # Every program instance works on its own run of BLOCK_SIZE consecutive indices.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last run may reach past the end of the tensor, so guard loads and stores.
    in_bounds = idx < n_elements
    values = tl.load(x_ptr + idx, mask=in_bounds)
    keep = tl.load(x_keep_ptr + idx, mask=in_bounds)
    # Elements whose mask entry is non-zero are scaled by 1 / (1 - p);
    # all other elements are replaced by zero.
    rescaled = values / (1 - p)
    result = tl.where(keep, rescaled, 0.0)
    tl.store(output_ptr + idx, result, mask=in_bounds)


def dropout(x, x_keep, p):
    """Apply dropout to `x` using an explicit keep mask.

    Args:
        x: contiguous input tensor.
        x_keep: contiguous tensor with one entry per element of `x`; non-zero
            entries mark the elements that are kept.
        p: probability that an element of `x` is changed to zero; kept
            elements are divided by `1 - p` by the kernel.

    Returns:
        A new tensor shaped like `x` holding the result.

    Raises:
        AssertionError: if `x` or `x_keep` is not contiguous, or if they do
            not hold the same number of elements.
    """
    # Validate before allocating or launching anything. The kernel indexes both
    # buffers with the same flat offsets, so a shorter or strided mask would be
    # read incorrectly.
    assert x.is_contiguous()
    assert x_keep.is_contiguous(), "x_keep must be contiguous"
    n_elements = x.numel()
    assert x_keep.numel() == n_elements, (
        f"x_keep must have as many elements as x, got {x_keep.numel()} and {n_elements}"
    )
    output = torch.empty_like(x)
    if n_elements == 0:
        # Nothing to compute, so no kernel is launched.
        return output
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Demo of the mask-based kernel on a 10-element vector.
p = 0.5
# Random input, moved to the GPU
x = torch.randn(size=(10,)).cuda()
# Explicit keep mask of 0s and 1s, one entry per input element
x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda()
output = dropout(x, x_keep=x_keep, p=p)
# One table row per tensor: a label followed by the tensor's values.
print(tabulate.tabulate([
    ["input", *x.tolist()],
    ["keep mask", *x_keep.tolist()],
    ["output", *output.tolist()],
]))

---------  ---------  --------  ---------  ----------  --------  -------  ---------  --------  ---------  --------
input      -0.396598  0.914868  -0.463365  -0.0863367  0.249196  -1.7128  -0.748133  -1.20401  -0.680449  -1.24109
keep mask   1         0          1          0          0          1        1          0         0          1
output     -0.793196  0         -0.92673    0          0         -3.4256  -1.49627    0         0         -2.48217
---------  ---------  --------  ---------  ----------  --------  -------  ---------  --------  ---------  --------


In [3]:
# Given in the tutorial. random seed controls dropout for a vector
@triton.jit
def _seeded_dropout(
        x_ptr,
        output_ptr,
        n_elements,
        p,
        seed,
        BLOCK_SIZE: tl.constexpr,
):
    # Flat indices of the elements owned by this program instance.
    idx = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard against running past the end of the tensor in the last block.
    in_bounds = idx < n_elements
    values = tl.load(x_ptr + idx, mask=in_bounds)
    # One random number per element, derived from the seed and the element's index,
    # so the mask never has to be stored in memory.
    noise = tl.rand(seed, idx)
    keep = noise > p
    # Kept elements are divided by (1 - p); the rest become zero.
    rescaled = values / (1 - p)
    result = tl.where(keep, rescaled, 0.0)
    tl.store(output_ptr + idx, result, mask=in_bounds)


def seeded_dropout(x, p, seed):
    """Apply dropout to `x` with a mask generated on the fly from `seed`.

    Args:
        x: contiguous input tensor.
        p: threshold forwarded to the kernel; elements whose random draw is not
            greater than `p` are zeroed, the rest are divided by `1 - p`.
        seed: seed forwarded to the kernel's random number generator.

    Returns:
        A new tensor shaped like `x` holding the result.

    Raises:
        AssertionError: if `x` is not contiguous.
    """
    # Check the layout first so no memory is allocated for an invalid input.
    assert x.is_contiguous()
    output = torch.empty_like(x)
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to compute, so no kernel is launched.
        return output
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,)).cuda()
# Compare this to the baseline - dropout mask is never instantiated!
# The first two calls share a seed; the third call uses a different one.
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

# One table row per tensor: a label followed by the tensor's values.
print(tabulate.tabulate([
    ["input", *x.tolist()],
    ["output (seed = 123)", *output.tolist()],
    ["output (seed = 123)", *output2.tolist()],
    ["output (seed = 512)", *output3.tolist()],
]))

-------------------  ---------  --------  --------  ---------  ---------  -------  --------  ---------  -------  ----------
input                -0.868897  0.403624  -1.51665  -0.749709  -0.951039  1.94522  -2.04819  0.0271463  2.01441  -0.0647303
output (seed = 123)   0         0.807247   0         0          0         3.89043   0        0          4.02882  -0.129461
output (seed = 123)   0         0.807247   0         0          0         3.89043   0        0          4.02882  -0.129461
output (seed = 512)   0         0         -3.0333   -1.49942    0         3.89043  -4.09638  0          0         0
-------------------  ---------  --------  --------  ---------  ---------  -------  --------  ---------  -------  ----------


In [5]:
# Exercise 1: dropout for matrix with vector of seeds, 1 seed per row
import numpy as np
import torch

import triton
import triton.language as tl

import os
# Ask CUDA to run kernel launches synchronously, so that an error is reported at the
# launch that caused it instead of at some later, unrelated call.
# NOTE(review): this variable is presumably only read when CUDA is initialized; earlier
# cells already call .cuda(), so confirm that setting it here actually takes effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Thank you to https://twitter.com/cis_female/ for helping debug this so that it works
# now! My original approach was really complex & buggy but you don't actually need
# anything more complicated than looping through each row in the block, generating a
# new mask for that row.
@triton.jit
def _seeded_matrix_dropout(
        x_ptr,  # pointer to the input matrix, read as rows of n_elements_per_row consecutive elements
        output_ptr,  # pointer to the output, written with the same layout as the input
        debug_random_mask_ptr,  # receives the keep mask computed for every element
        debug_seeds_ptr,  # receives the seed that was read for every row
        p,  # probability that an element is changed to zero
        seeds,  # pointer to a vector holding one seed per row
        n_elements_per_row: tl.constexpr,  # row length; used as the bound of tl.arange
        BLOCK_SIZE: tl.constexpr,  # number of rows handled by one program instance
):
    # Each program instance owns the BLOCK_SIZE consecutive rows starting at
    # program_id * BLOCK_SIZE, so no two instances ever write to the same location.
    # No load or store below is masked: the caller must make sure the number of rows
    # is a multiple of BLOCK_SIZE.
    first_row = tl.program_id(0) * BLOCK_SIZE
    # Column offsets are the same for every row, so compute them once.
    col_offsets = tl.arange(0, n_elements_per_row)

    for row in range(BLOCK_SIZE):
        row_idx = first_row + row
        # The matrix-shaped buffers are indexed by row_idx * n_elements_per_row + column,
        # while the per-row vectors (seeds, debug seeds) are indexed by row_idx alone.
        elem_offsets = row_idx * n_elements_per_row + col_offsets
        current_row = tl.load(x_ptr + elem_offsets)
        rand_seed = tl.load(seeds + row_idx)
        tl.store(debug_seeds_ptr + row_idx, rand_seed)
        # A fresh mask per row, determined by that row's seed. An element is kept when
        # its random draw exceeds p, matching the convention of _seeded_dropout where
        # p is the probability of dropping (the original `< p` kept with probability p).
        keep_mask = tl.rand(rand_seed, col_offsets) > p
        tl.store(debug_random_mask_ptr + elem_offsets, keep_mask)
        # Kept elements are divided by (1 - p), as in the other dropout kernels;
        # dropped elements become zero.
        masked_xs = tl.where(keep_mask, current_row / (1 - p), 0.0)
        tl.store(output_ptr + elem_offsets, masked_xs)

# Number of rows handled by one kernel program instance. It should be much smaller than
# the number of rows of the input. Note that the launch grid below is sized from the
# number of rows, not from the total number of elements (a previous mistake, copied
# from the grid code of the matmul tutorial).
BLOCK_SIZE = 4


def seeded_matrix_dropout(x, p, seeds):
    """Run `_seeded_matrix_dropout` on a matrix using one seed per row.

    Args:
        x: contiguous 2-D tensor. Its number of rows must be a multiple of
            `BLOCK_SIZE` and its row length must be a power of 2.
        p: dropout parameter forwarded unchanged to the kernel.
        seeds: 1-D tensor with one seed per row of `x`.

    Returns:
        A tuple `(output, debug_random_mask, debug_seeds)`: the kernel result shaped
        like `x`, an int tensor shaped like `x` with the mask the kernel wrote for each
        element, and an int tensor shaped like `seeds` with the seed the kernel read
        for each row.

    Raises:
        AssertionError: if any of the requirements on `x` or `seeds` is violated.
    """
    # Validate everything up front: the kernel performs no bounds checks of its own.
    assert x.dim() == 2, f"x should be a 2-D matrix but got x.shape {x.shape}"
    assert x.is_contiguous()
    assert seeds.shape == (x.shape[0],), f"seeds should be length of num rows but instead got seeds.shape {seeds.shape} and x.shape {x.shape}"
    n_rows, n_elements_per_row = x.shape
    # Every program instance processes exactly BLOCK_SIZE rows without masking, so a
    # partial last block would read and write past the end of the buffers.
    assert n_rows % BLOCK_SIZE == 0, (
        f"number of rows must be a multiple of BLOCK_SIZE={BLOCK_SIZE} but got x.shape {x.shape}"
    )
    # The kernel builds its column offsets with tl.arange(0, n_elements_per_row).
    assert n_elements_per_row > 0 and (n_elements_per_row & (n_elements_per_row - 1)) == 0, (
        f"n elements per row must be a power of 2 but got x.shape {x.shape}"
    )

    output = torch.empty_like(x)
    # The kernel should overwrite every debug value; the sentinels 22 and 21 only serve
    # to reveal entries that were not written for some reason. The buffers are created
    # directly on the device of x.
    debug_random_mask = torch.full(x.shape, 22, dtype=torch.int, device=x.device)  # random masks used for each row
    debug_seeds = torch.full(seeds.shape, 21, dtype=torch.int, device=x.device)
    if n_rows == 0:
        # Nothing to compute, so no kernel is launched.
        return output, debug_random_mask, debug_seeds

    # BLOCK_SIZE is num rows per block
    grid = lambda meta: (triton.cdiv(n_rows, meta['BLOCK_SIZE']),)
    _seeded_matrix_dropout[grid](x, output, debug_random_mask, debug_seeds, p, seeds, n_elements_per_row, BLOCK_SIZE=BLOCK_SIZE)
    return output, debug_random_mask, debug_seeds

x = torch.randn(size=(1000, 64)).cuda() # n elements per row must be a power of 2
# Compare this to the baseline - dropout mask is never instantiated!
# Smoke run with arbitrary per-row seeds. The seeds are drawn as integers: the values in
# [0, 1) produced by torch.rand are not meaningful seeds (presumably they would all
# collapse to the same integer seed once converted).
output1, debug_random_mask_ptr1, debug_seeds_ptr1 = seeded_matrix_dropout(
    x, p=0.75, seeds=torch.randint(0, 100, size=(x.shape[0],), dtype=torch.int).cuda())
# seeds match for the first 5 rows (hardcoded), but are drawn at random for the remaining ones
seed2 = torch.cat((torch.tensor([30, 50, 100, 4, 41]), torch.tensor(np.random.randint(0, 100, size=(x.shape[0]-5))))).cuda().to(torch.int)
seed3 = torch.cat((torch.tensor([30, 50, 100, 4, 41]), torch.tensor(np.random.randint(0, 100, size=(x.shape[0]-5))))).cuda().to(torch.int)
assert torch.all(torch.eq(seed2[0:5], seed3[0:5]))
assert not torch.all(torch.eq(seed2, seed3))

output2, debug_random_mask_ptr2, debug_seeds_ptr2 = seeded_matrix_dropout(x, p=0.5, seeds=seed2)
output3, debug_random_mask_ptr3, debug_seeds_ptr3 = seeded_matrix_dropout(x, p=0.5, seeds=seed3)

# Check seeds being used correctly: the kernel must have read back exactly the seeds
# that were passed in.
# print(seed2[0:10])
# print(debug_seeds_ptr2[0:10])
# print(seed3[0:10])
# print(debug_seeds_ptr3[0:10])
assert torch.all(torch.eq(seed2, debug_seeds_ptr2)), f"seed2 {seed2} debug {debug_seeds_ptr2}"
# The failure message reports seed3 (it previously printed seed2 under the seed3 label).
assert torch.all(torch.eq(seed3, debug_seeds_ptr3)), f"seed3 {seed3} debug {debug_seeds_ptr3}"

# only the first 5 seeds are hardcoded to be consistently the same
# print(f"debug_random_mask_ptr2{debug_random_mask_ptr2[0:5, :]}")
# print(f"debug_random_mask_ptr3{debug_random_mask_ptr3[0:5, :]}")
assert torch.all(torch.eq(debug_random_mask_ptr2[0:5, :], debug_random_mask_ptr3[0:5, :])), "expect mask for first 5 rows to match, because those seeds are hardcoded to be the same"

# first 5 rows should match, but the rest shouldn't!
assert torch.all(torch.isclose(output2[0:5, :], output3[0:5, :])), f"got output2\n{output2[0:5, :]} and otuput3\n{output3[0:5, :]} from x\n{x[0:5, :]}"
assert not torch.all(torch.isclose(output2, output3))
print("ok, all tests pass")


ok
