Efficient Multihead Attention

PyTorch implementation of "Self-attention Does Not Need O(n²) Memory"

This single-file repository provides a PyTorch implementation of the work by Rabe et al., whose accompanying code is in JAX: https://arxiv.org/abs/2112.05682

The attention operation implemented here is mathematically identical to the standard multi-head attention of Vaswani et al., but it processes the keys and values in chunks, accumulates the softmax numerator and denominator incrementally, and applies gradient checkpointing to the per-chunk computation, so the full attention matrix is never materialized and memory overhead is significantly reduced.
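To illustrate the idea, here is a minimal single-head sketch of chunked attention with an incrementally accumulated, numerically stabilized softmax and gradient checkpointing. The function names, tensor layout, and chunk size are assumptions made for illustration and are not the repository's actual API.

```python
# Minimal sketch of memory-efficient attention (single head, no masking).
# q, k, v are assumed to have shape (batch, seq_len, dim); names and the
# chunk size are illustrative, not the repo's actual interface.
import math
import torch
from torch.utils.checkpoint import checkpoint


def _chunk_summary(q, k_chunk, v_chunk):
    # Attention scores of all queries against one key/value chunk.
    s = torch.einsum("bqd,bkd->bqk", q, k_chunk)        # (B, Q, Kc)
    m = s.max(dim=-1, keepdim=True).values              # per-query chunk max
    p = torch.exp(s - m)                                 # stabilized exponentials
    num = torch.einsum("bqk,bkd->bqd", p, v_chunk)       # unnormalized output
    den = p.sum(dim=-1, keepdim=True)                    # softmax denominator part
    return num, den, m


def efficient_attention(q, k, v, key_chunk=256):
    q = q / math.sqrt(q.shape[-1])                       # softmax scaling up front
    nums, dens, maxes = [], [], []
    for start in range(0, k.shape[1], key_chunk):
        k_c = k[:, start:start + key_chunk]
        v_c = v[:, start:start + key_chunk]
        # Checkpointing recomputes each chunk's summary during the backward
        # pass, so the chunk-sized score matrices are not all kept in memory.
        num, den, m = checkpoint(_chunk_summary, q, k_c, v_c, use_reentrant=False)
        nums.append(num)
        dens.append(den)
        maxes.append(m)
    # Combine the chunk summaries under a global per-query max for stability.
    global_max = torch.stack(maxes, dim=0).max(dim=0).values
    out = torch.zeros_like(q)
    norm = torch.zeros_like(global_max)
    for num, den, m in zip(nums, dens, maxes):
        scale = torch.exp(m - global_max)
        out = out + num * scale
        norm = norm + den * scale
    return out / norm
```

For the multi-head case, the same routine can be applied per head, for example by folding the head dimension into the batch dimension before calling it.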
