TL;DR: We present a significant performance optimization for PyTorch's multinomial sampling using the Gumbel-Max trick, achieving up to 3.5x speedup on A100 GPUs.
This project implements an efficient alternative to torch.multinomial, using the Gumbel-Max trick to draw a single sample from each categorical distribution. The approach yields substantial speedups across a range of batch and vocabulary sizes, which is particularly beneficial for large-scale language model inference.
Benchmarks conducted on an NVIDIA A100 80GB GPU:

| Benchmark | torch.multinomial | Gumbel-Max | Speedup |
|-----------|-------------------|------------|---------|
| 1 | 0.600 ms ± 0.058 ms | 0.214 ms ± 0.004 ms | 2.8x |
| 2 | 4.549 ms ± 2.609 ms | 1.294 ms ± 0.009 ms | 3.5x |
| 3 | 64.386 ms ± 2.748 ms | 30.544 ms ± 1.725 ms | 2.1x |
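The exact harness behind these numbers isn't included here; the following is a minimal sketch of how comparable timings could be collected with CUDA events. The tensor shape and iteration counts are illustrative assumptions, and `gumbel_sample` mirrors the implementation shown in the next section:

```python
import torch

def gumbel_sample(logits):
    # Same Gumbel-Max sampler as the implementation shown below
    return torch.argmax(logits - torch.log(-torch.log(torch.rand_like(logits))), dim=-1)

def time_fn(fn, *args, warmup=10, iters=100):
    # Time a CUDA op with events; warmup iterations absorb one-time launch costs
    for _ in range(warmup):
        fn(*args)
    times = []
    for _ in range(iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        fn(*args)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))  # milliseconds
    t = torch.tensor(times)
    return t.mean().item(), t.std().item()

# Illustrative shape only; the (batch, vocab) sizes behind the table above are not listed
logits = torch.randn(32, 50_000, device="cuda")
probs = torch.softmax(logits, dim=-1)
print("torch.multinomial: %.3f ms ± %.3f ms" % time_fn(torch.multinomial, probs, 1))
print("Gumbel-Max:        %.3f ms ± %.3f ms" % time_fn(gumbel_sample, logits))
```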
The Gumbel-Max trick converts sampling from a categorical distribution into an argmax operation: if g_i are i.i.d. Gumbel(0, 1) noise, then argmax_i(logits_i + g_i) is distributed exactly as Categorical(softmax(logits)). Unlike torch.multinomial, this requires no explicit normalization, and the elementwise add plus argmax pattern is highly efficient on modern GPUs:
```python
import torch

def gumbel_sample(logits):
    # Clamp avoids log(0): torch.rand_like draws from [0, 1)
    u = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
    gumbel_noise = -torch.log(-torch.log(u))  # Gumbel(0, 1) via inverse CDF
    return torch.argmax(logits + gumbel_noise, dim=-1)
```
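As a quick sanity check (a sketch, not part of the benchmark above), the empirical frequencies from `gumbel_sample` can be compared against the softmax probabilities that torch.multinomial would sample from; the logit values and sample count here are arbitrary:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([1.0, 2.0, 0.5, -1.0])
# Draw 200k independent samples by expanding the logits to a batch
samples = gumbel_sample(logits.expand(200_000, -1))
freqs = torch.bincount(samples, minlength=logits.numel()) / samples.numel()
print(freqs)                           # empirical sampling frequencies
print(torch.softmax(logits, dim=-1))   # target probabilities; should closely match
```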