In [6]:
!uv pip install tabulate

[2mUsing Python 3.12.3 environment at: /home/allen/miniconda3[0m
[2K[2mResolved [1m1 package[0m [2min 458ms[0m[0m                                          [0m
[2K[37m⠙[0m [2mPreparing packages...[0m (0/1)                                                   
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/1)----[0m[0m     0 B/34.43 KiB                     [1A
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/1)----[0m[0m 16.00 KiB/34.43 KiB                   [1A
[2K[1A[37m⠙[0m [2mPreparing packages...[0m (0/1)2m--[0m[0m 32.00 KiB/34.43 KiB                   [1A
[2K[2mPrepared [1m1 package[0m [2min 150ms[0m[0m                                                  [1A
[2K[2mInstalled [1m1 package[0m [2min 14ms[0m[0m                                 [0m
 [32m+[39m [1mtabulate[0m[2m==0.9.0[0m


In [1]:
import torch
import triton
import triton.language as tl

# Device handle for the CUDA device PyTorch currently has selected; the demo
# cells below pass it as `device=` when creating their input tensors.
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}')

In [2]:
@triton.jit
def _dropout(
    x_ptr,
    x_keep_ptr,
    output_ptr,
    n_elements,
    p,
    BLOCK_SIZE: tl.constexpr
):
    """Dropout over one BLOCK_SIZE-wide slice, driven by a caller-supplied keep mask.

    Elements whose mask entry is truthy are written as x / (1 - p); all other
    elements are written as 0.0.
    """
    # Each program instance covers the flat index range
    # [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE).
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the data; mask those lanes out of
    # every load and store.
    in_bounds = idx < n_elements
    values = tl.load(x_ptr + idx, mask=in_bounds)
    keep = tl.load(x_keep_ptr + idx, mask=in_bounds)
    result = tl.where(keep, values / (1 - p), 0.0)
    tl.store(output_ptr + idx, result, mask=in_bounds)

In [8]:
def dropout(x, x_keep, p):
    """Apply dropout to `x` with an explicit keep mask by launching `_dropout`.

    Args:
        x: Contiguous input tensor.
        x_keep: Contiguous mask on the same device as `x`, with one entry per
            element of `x`. A truthy entry keeps the element, a falsy entry
            zeroes it.
        p: Drop probability used for rescaling; kept elements are divided by
            (1 - p).

    Returns:
        A new tensor created with `torch.empty_like(x)` and filled by the kernel.

    Raises:
        AssertionError: If `x` is not contiguous.
        ValueError: If `x_keep` does not have the same number of elements as
            `x`, is not contiguous, or is on a different device.
    """
    assert x.is_contiguous()
    # The kernel reads x_keep at the same flat offsets it uses for x, bounded
    # only by x.numel(). A shorter, strided or other-device mask would give
    # out-of-bounds or wrong reads, so reject it before launching.
    if x_keep.numel() != x.numel():
        raise ValueError(
            f"x_keep has {x_keep.numel()} elements but x has {x.numel()}"
        )
    if not x_keep.is_contiguous():
        raise ValueError("x_keep must be contiguous")
    if x_keep.device != x.device:
        raise ValueError(
            f"x_keep is on {x_keep.device} but x is on {x.device}"
        )
    output = torch.empty_like(x)
    n_elements = x.numel()
    # One program per BLOCK_SIZE-wide slice, rounding up so the tail is covered.
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output



In [None]:
import tabulate

x = torch.randn(size=(10,), device=DEVICE)
p = 0.5
# Draw the mask from a uniform distribution (torch.rand) so that each element
# is kept with probability 1 - p. A standard-normal draw (torch.randn) compared
# against p would keep noticeably fewer elements than intended.
x_keep = (torch.rand(size=(10,), device=DEVICE) > p).to(torch.int32)

output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist(),
]))

tensor([0, 1, 0, 0, 1, 0, 1, 0, 0, 0], device='cuda:0', dtype=torch.int32)
---------  -------  --------  -------  --------  -------  --------  ----------  --------  -------  -------
input      -1.8765  -1.47489  -0.7433  -1.74919  2.126    -1.70117  -0.0978969  0.477327  1.08858  -1.2364
keep mask   0        1         0        0        1         0         1          0         0         0
output      0       -2.94979   0        0        4.25201   0        -0.195794   0         0         0
---------  -------  --------  -------  --------  -------  --------  ----------  --------  -------  -------


In [10]:
@triton.jit
def _seeded_dropout(
    x_ptr,
    output_ptr,
    n_elements,
    p,    # float32 [0,1]
    seed, # int32
    BLOCK_SIZE: tl.constexpr,
):
    """Dropout over one BLOCK_SIZE-wide slice, with the keep mask generated in-kernel.

    The keep decision for each element comes from `tl.rand(seed, offset) > p`,
    so no mask tensor is read. Kept elements are written as x / (1 - p), the
    rest as 0.0.
    """
    # Each program instance covers the flat index range
    # [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE).
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask off lanes that fall past the end of the data in the final block.
    in_bounds = idx < n_elements
    values = tl.load(x_ptr + idx, mask=in_bounds)
    # The random value is a function of (seed, flat offset) only.
    keep = tl.rand(seed, idx) > p
    result = tl.where(keep, values / (1 - p), 0.0)
    tl.store(output_ptr + idx, result, mask=in_bounds)

In [11]:
def seeded_dropout(x, p, seed):
    """Apply dropout to `x` by launching `_seeded_dropout` with the given seed.

    Args:
        x: Contiguous input tensor.
        p: Drop probability; forwarded to the kernel, which keeps an element
            when its random draw exceeds `p` and divides kept values by (1 - p).
        seed: Seed forwarded to the kernel's random number generation.

    Returns:
        A new tensor created with `torch.empty_like(x)` and filled by the kernel.

    Raises:
        AssertionError: If `x` is not contiguous.
    """
    result = torch.empty_like(x)
    assert x.is_contiguous()
    total = x.numel()

    def launch_grid(meta):
        # One program per BLOCK_SIZE-wide slice, rounded up to cover the tail.
        return (triton.cdiv(total, meta['BLOCK_SIZE']),)

    _seeded_dropout[launch_grid](x, result, total, p, seed, BLOCK_SIZE=1024)
    return result


In [12]:
# Unlike the mask-based baseline, no keep-mask tensor is ever created here:
# the kernel derives the keep decision from the seed.
x = torch.randn(size=(10, ), device=DEVICE)
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)  # same seed as `output`
output3 = seeded_dropout(x, p=0.5, seed=512)  # different seed

# Build the comparison table first, then render it in one call.
table_rows = [
    ["input", *x.tolist()],
    ["output (seed = 123)", *output.tolist()],
    ["output (seed = 123)", *output2.tolist()],
    ["output (seed = 512)", *output3.tolist()],
]
print(tabulate.tabulate(table_rows))

-------------------  ---------  -------  --------  ------------  -------  ---------  --------  --------  --------  ---------
input                -0.103223  1.48454  -1.93938  -0.000645571  1.17313  -0.079865  -1.35829  -0.50152  -1.40758  -0.724626
output (seed = 123)   0         2.96907   0         0            0        -0.15973    0         0        -2.81516  -1.44925
output (seed = 123)   0         2.96907   0         0            0        -0.15973    0         0        -2.81516  -1.44925
output (seed = 512)   0         0        -3.87876  -0.00129114   0        -0.15973   -2.71659   0         0         0
-------------------  ---------  -------  --------  ------------  -------  ---------  --------  --------  --------  ---------
