# Low Memory Dropout

在 `dropout` 的一般实现中，我们在 `Forward` 时需要生成一个和输入 Tensor 形状相同的 Mask Tensor，一般称为 `x_keep`：其中值为 0 的位置表示丢弃该输入（输出为 0），值为 1 的位置表示保留该输入，并按 `1 / (1 - p)` 缩放后输出（`p` 为丢弃概率）。同时我们需要将 `x_keep` 保存起来，用于 Backward 时计算梯度。

<div class="wy-nav-content-img">
    <img src="assets/low-memory-dropout_mask.drawio.svg" width="600px" alt="Dropout 的执行示意图">
    <p>Dropout 的执行示意图</p>
</div>


这就相当于多占用了一份内存，我们来考虑，能否在 Backward 的时候，将 `x_keep` 重新计算出来呢？实际上就是和 Forward 的时候生成相同的随机数就可以了，那么我们只需要保存好随机种子就行了。

`triton.language.rand(seed, offsets)` 接口会根据 `seed` 和 `offsets`，为每个偏移位置生成一个在 [0, 1) 区间内均匀分布的随机数，这样我们就可以在运行时直接计算出 `x_keep`：

```python
x_keep = tl.rand(seed, offsets) > p
```

`tl.rand` 的接口除了 seed 外还要求传入 offsets，而不是一个 shape，这是因为随机数的生成需要考虑到所有的线程块，避免不同的线程块生成相同的随机数。

In [1]:
import tabulate
import torch

import triton
import triton.language as tl


@triton.jit
def dropout_kernel(
    x_ptr, x_keep_ptr, output_ptr, n_elements, p, BLOCK_SIZE: tl.constexpr
):
    """Apply dropout to one block of the input using a precomputed keep mask.

    Each program instance processes ``BLOCK_SIZE`` consecutive elements
    starting at ``program_id * BLOCK_SIZE``. Positions whose ``x_keep`` entry
    is truthy are written as ``x / (1 - p)``; all other positions are written
    as 0. Lanes at or beyond ``n_elements`` are masked out of every load and
    store.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the buffers; guard those lanes.
    in_bounds = idx < n_elements

    keep = tl.load(x_keep_ptr + idx, mask=in_bounds)
    values = tl.load(x_ptr + idx, mask=in_bounds)

    # Rescale the survivors so the expected value of the output matches x.
    rescaled = values / (1 - p)
    result = tl.where(keep, rescaled, 0)
    tl.store(output_ptr + idx, result, mask=in_bounds)


def dropout(x: torch.Tensor, x_keep: torch.Tensor, p):
    """Apply dropout to ``x`` using an explicit, caller-supplied keep mask.

    Args:
        x: Contiguous input tensor.
        x_keep: Contiguous keep mask with the same number of elements as
            ``x``; non-zero entries keep the corresponding input element.
        p: Drop probability; kept elements are scaled by ``1 / (1 - p)``.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)`` holding the
        dropout result.

    Raises:
        AssertionError: If ``x`` or ``x_keep`` is not contiguous, or if their
            element counts differ.
    """
    assert x.is_contiguous()
    # The kernel reads x_keep at the same flat offsets as x and only bounds
    # checks against x.numel(), so a strided or shorter mask would be read
    # incorrectly (or out of bounds). Reject it up front.
    assert x_keep.is_contiguous(), "x_keep must be contiguous"
    assert x_keep.numel() == x.numel(), "x_keep must have as many elements as x"

    output = torch.empty_like(x)
    n_elements = x.numel()
    # One program per BLOCK_SIZE-sized chunk, rounding up.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    dropout_kernel[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


@triton.jit
def seeded_dropout_kernel(
    x_ptr, output_ptr, n_elements, p, seed, BLOCK_SIZE: tl.constexpr
):
    """Apply dropout to one block of the input, deriving the mask from a seed.

    Instead of loading a stored mask, the keep decision for each element is
    recomputed as ``tl.rand(seed, offset) > p``, where ``offset`` is the
    element's flat index. Kept elements are written as ``x / (1 - p)``; the
    rest are written as 0. Lanes at or beyond ``n_elements`` are masked out
    of the load and the store.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    # The last block may run past the end of the buffers; guard those lanes.
    in_bounds = idx < n_elements

    values = tl.load(x_ptr + idx, mask=in_bounds)
    # Random numbers are keyed on the flat element offsets, so the same seed
    # yields the same keep decisions for the same offsets.
    keep = tl.rand(seed, idx) > p

    rescaled = values / (1 - p)
    result = tl.where(keep, rescaled, 0)
    tl.store(output_ptr + idx, result, mask=in_bounds)


def seeded_dropout(x: torch.Tensor, p, seed=123):
    """Apply dropout to ``x`` with a keep mask regenerated from ``seed``.

    Args:
        x: Contiguous input tensor.
        p: Drop probability forwarded to the kernel.
        seed: Seed forwarded to the kernel's random number generation.

    Returns:
        A new tensor allocated with ``torch.empty_like(x)`` holding the
        dropout result.

    Raises:
        AssertionError: If ``x`` is not contiguous.
    """
    result = torch.empty_like(x)
    assert x.is_contiguous()
    total = x.numel()

    def launch_grid(meta):
        # One program per BLOCK_SIZE-sized chunk, rounding up.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    seeded_dropout_kernel[launch_grid](
        x, result, total, p, seed, BLOCK_SIZE=1024
    )
    return result


def dropout_test():
    """Run mask-based dropout on 10 random values and print a summary table.

    The table has three rows: the input, the keep mask, and the output.
    """
    x = torch.randn((10,), device="cuda")
    p = 0.5
    # The mask is sampled on the CPU and then moved to the GPU as int32.
    keep_mask = (torch.rand(x.shape) > p).to(dtype=torch.int32, device="cuda")
    result = dropout(x, keep_mask, p)

    rows = [
        ["input", *x.tolist()],
        ["keep mask", *keep_mask.tolist()],
        ["output", *result.tolist()],
    ]
    print(tabulate.tabulate(rows))


def seeded_dropout_test():
    """Run seeded dropout with several seeds and print a comparison table.

    The first and last runs share a seed, so their rows can be compared
    against each other; the middle run uses a different seed. Each row label
    is built from the seed that was actually passed to ``seeded_dropout``,
    so the labels cannot drift from the values used.
    """
    x = torch.randn((10,), device="cuda")
    p = 0.5
    # Repeat the first seed at the end to show the result is reproducible.
    seeds = (423, 123, 423)

    rows = [["input"] + x.tolist()]
    for seed in seeds:
        output = seeded_dropout(x, p, seed)
        rows.append([f"output (seed {seed})"] + output.tolist())
    print(tabulate.tabulate(rows))


# Run both demos; each prints a table of its input and output values.
# Both create their tensors with device="cuda".
dropout_test()
seeded_dropout_test()

---------  --------  -------  ----------  --------  -------  --------  ---------  --------  ---------  -------
input      0.799002  1.09926  -0.0294367  0.408484  1.90967  -2.65934  -0.446925  0.318453  -0.166543  1.86979
keep mask  1         0         1          0         1         0         1         0          1         1
output     1.598     0        -0.0588734  0         3.81934   0        -0.89385   0         -0.333087  3.73957
---------  --------  -------  ----------  --------  -------  --------  ---------  --------  ---------  -------
-----------------  --------  --------  ----------  -------  -------  -------  --------  ---------  -------  ---------
input              0.932829  0.556717  -0.0195228  1.48022  -0.9475  -0.9432  -1.13308  -0.733073  1.15927  -0.355865
output (seed 42)   0         0          0          0        -1.895    0        0        -1.46615   0        -0.71173
output (seed 123)  0         1.11343    0          0         0       -1.8864   0         0        