Code to accompany the paper Learning Fast Algorithms for Linear Transforms Using Butterfly Factorizations.
- The module `butterfly/butterfly.py` can be used as a drop-in replacement for an `nn.Linear` layer. The files in the `butterfly` directory are all that are needed for this use.
The butterfly multiplication is written in C++ and CUDA as a PyTorch extension. To install it:

```shell
cd butterfly/factor_multiply
python setup.py install
```
Without the C++/CUDA version, butterfly multiplication is still usable, but quite slow. A variable in the module controls whether to use the C++/CUDA version or the pure PyTorch version.
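To illustrate what the pure PyTorch path computes, here is a minimal NumPy sketch of multiplying by a product of butterfly factors. This is an illustrative assumption about the structure described in the paper, not the repository's implementation: the function name `butterfly_multiply` and the twiddle-array layout are hypothetical. Each factor touches every entry once, so a factor costs O(n) and the full product of log n factors costs O(n log n) instead of O(n^2) for a dense matrix.

```python
import numpy as np

def butterfly_multiply(twiddles, x):
    """Multiply x (length n = 2**m) by a product of butterfly factors.

    twiddles[k] has shape (n // 2, 2, 2): the 2x2 blocks of the k-th
    factor, which mixes pairs of entries at stride 2**k. This is a
    pure-NumPy sketch of the butterfly structure, not the repo's code.
    """
    n = x.shape[0]
    y = x.astype(float)  # astype copies, so x is left untouched
    for k, t in enumerate(twiddles):
        stride = 2 ** k
        out = y.copy()
        for b in range(0, n, 2 * stride):   # blocks of size 2 * stride
            for j in range(stride):         # pairs within a block
                i0, i1 = b + j, b + j + stride
                m2 = t[b // 2 + j]          # 2x2 block for this pair
                out[i0] = m2[0, 0] * y[i0] + m2[0, 1] * y[i1]
                out[i1] = m2[1, 0] * y[i0] + m2[1, 1] * y[i1]
        y = out
    return y
```

As a sanity check, setting every 2x2 block to `[[1, 1], [1, -1]]` makes this compute the Walsh-Hadamard transform, one of the fixed transforms that butterfly factorizations capture exactly.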
For training, we've had better results with the Adam optimizer than SGD.
- The directory `learning_transforms` contains code to learn the transforms presented in the paper. This directory is under active development and refactoring.