# In-place Softmax Example

In [1]:
import torch

In [2]:
# Stand-in for a model's pre-softmax activations: 32 rows of 50,000,000 uniform
# random values, created directly on the GPU (about 5.96 GiB as float32).
x = torch.rand(32, 50_000_000, device="cuda")

In [3]:
# The out-of-place softmax has to allocate a second tensor as large as x, so this
# will most likely run out of memory depending on your GPU. The failure is the
# point of the demonstration, so it is caught and shown rather than left to abort
# a "Restart & Run All".
try:
    out = torch.nn.Softmax(dim=1)(x)
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: CUDA out of memory. Tried to allocate 5.96 GiB (GPU 0; 10.92 GiB total capacity; 5.96 GiB already allocated; 4.27 GiB free; 5.96 GiB reserved in total by PyTorch)

In [4]:
def in_place_softmax(x, dim=1):
    """Overwrite ``x`` with its softmax along ``dim`` and return it.

    Every step writes into ``x`` itself, so the only extra memory is for the
    per-slice maximum and sum, which have size 1 along ``dim``.

    Args:
        x: Floating-point tensor. Its contents are destroyed and replaced by
            the softmax values.
        dim: Dimension along which the values are normalised. Defaults to 1,
            i.e. each row of a 2-D ``(batch, classes)`` tensor.

    Returns:
        The same tensor object ``x``, now holding the softmax values.
    """
    # Nothing to normalise; also avoids reducing over a zero-length dimension.
    if x.numel() == 0:
        return x
    # In-place operations in PyTorch are generally represented with an ending underscore.
    # Shift every slice by its maximum before exponentiating: softmax is unchanged
    # by a constant shift, and the largest exponent becomes 0, so exp() cannot
    # overflow to inf (which would turn the result into NaN).
    slice_max = torch.max(x, dim=dim, keepdim=True)[0]
    x.sub_(slice_max)
    # x = e^(x - max)
    x.exp_()
    # Compute the denominator; keepdim=True keeps a size-1 axis so it broadcasts
    # against x directly, without transposing x back and forth.
    z = torch.sum(x, dim=dim, keepdim=True)
    # Divide in place so no second full-size tensor is allocated.
    x.div_(z)
    return x

In [5]:
# This shouldn't run out of memory: x is overwritten with its own softmax, so no
# second full-size tensor is needed. Note that x no longer holds the original
# values afterwards, so re-running this cell applies softmax a second time.
in_place_softmax(x)

tensor([[2.4643e-08, 3.1151e-08, 2.4846e-08,  ..., 1.2965e-08, 1.2426e-08,
         1.3507e-08],
        [1.2732e-08, 1.8277e-08, 2.1367e-08,  ..., 1.8857e-08, 1.6789e-08,
         2.4582e-08],
        [2.2021e-08, 1.5001e-08, 2.6280e-08,  ..., 2.8042e-08, 1.1684e-08,
         1.4510e-08],
        ...,
        [1.3430e-08, 2.2775e-08, 2.7292e-08,  ..., 1.7282e-08, 2.1122e-08,
         1.4342e-08],
        [1.1880e-08, 1.1644e-08, 1.8411e-08,  ..., 1.2490e-08, 2.8656e-08,
         2.0938e-08],
        [3.0777e-08, 2.1498e-08, 1.1740e-08,  ..., 1.6667e-08, 1.5999e-08,
         2.4465e-08]], device='cuda:0')

## Testing

In [6]:
# Small batch (32 x 10) of standard-normal values on the GPU for the correctness check.
a = torch.randn(32, 10, device="cuda")

In [7]:
# Compute the reference first: in_place_softmax overwrites its argument, so the
# comparison should not depend on the order in which the two calls are evaluated.
expected = torch.nn.Softmax(dim=1)(a)
# Run the in-place version on a copy so `a` stays intact and this cell can be re-run.
actual = in_place_softmax(a.clone())
torch.isclose(expected, actual).all()

tensor(True, device='cuda:0')