# Neural Network Compression - Haiku
Implementation of several neural network compression techniques (knowledge distillation, pruning, quantization, factorization), in Haiku.
For an introduction to neural network compression, see 4-popular-model-compression-techniques-explained.
To install locally:
```shell
git clone https://github.com/Brandhsu/nn-compress-haiku/
cd nn-compress-haiku
git lfs pull
pip install -r requirements.txt
```
First, train a model on CIFAR-10.
```shell
python scripts/01_train.py --save-dir models
```
Next, optionally train a model with knowledge distillation.
```shell
python scripts/02_train_kd.py --model-path models/params.pkl --save-dir models
```
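The distillation objective presumably blends the ordinary supervised loss with a soft-target loss against the teacher, weighted by an `alpha` coefficient (see the `student-alpha-*` checkpoints below). A NumPy sketch of such a loss follows; the exact weighting convention and temperature used in `scripts/02_train_kd.py` are assumptions, not the repo's verified API:

```python
import numpy as np

def softmax(logits, temperature=1.0):
    z = logits / temperature
    z = z - z.max(axis=-1, keepdims=True)  # for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kd_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=1.0):
    """Hypothetical distillation loss: alpha weights the soft (teacher) term,
    (1 - alpha) weights the hard (supervised-label) term."""
    # Hard-label cross-entropy against the ground-truth labels.
    probs = softmax(student_logits)
    hard = -np.log(probs[np.arange(len(labels)), labels]).mean()
    # Soft-label cross-entropy against the teacher's softened distribution.
    t = softmax(teacher_logits, temperature)
    s = np.log(softmax(student_logits, temperature))
    soft = -(t * s).sum(axis=-1).mean()
    return alpha * soft + (1 - alpha) * hard
```

Under this convention, `alpha=0.0` is pure supervised training and `alpha=1.0` is pure distillation.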
Then compress it!
```shell
python scripts/03_compress.py --model-path models/params.pkl --compression-func svd --save-dir figs --log-dir logs
```
Note: Compression happens post-training in a layer-by-layer (local) fashion.
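Concretely, local compression can be sketched as a map over a Haiku-style nested params dict, compressing each 2-D weight matrix independently of every other layer. This is an illustrative sketch, not the repo's actual code; it only assumes Haiku's usual `{module: {"w": ..., "b": ...}}` layout:

```python
import numpy as np

def compress_params(params, compress_fn):
    """Apply compress_fn to each 2-D weight matrix independently (local,
    layer-by-layer compression); biases and other tensors pass through."""
    out = {}
    for module, tensors in params.items():
        out[module] = {
            name: compress_fn(t) if t.ndim == 2 else t
            for name, t in tensors.items()
        }
    return out
```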
The following training checkpoints (without exponential moving averaging of weights) were saved after 100,001 training iterations.
In this case, simply training a smaller model on supervised labels performed best, followed by distilling a smaller model from a larger one, and finally training a larger model on supervised labels.
Name | Test Accuracy | Latency | Size |
---|---|---|---|
teacher | 62.36% | 650 ms ± 415 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) | 121.9 MB |
student-alpha-0.0 | 62.97% | 423 ms ± 16.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) | 295 kB |
student-alpha-0.5 | 63.24% | 421 ms ± 4.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) | 295 kB |
student-alpha-1.0 | 63.59% | 418 ms ± 5.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) | 295 kB |
Note: These checkpoints were created directly after training without post-training compression.
Several other experiments were conducted to measure the impact of post-training compression with respect to accuracy and latency using the following techniques:
- pruning: masked pruning via weight magnitude
- quantization: linear quantization via uniform sampling
- factorization: low-rank reconstruction via SVD
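Minimal NumPy sketches of the three techniques (illustrative only; function names and defaults are not the repo's API):

```python
import numpy as np

def prune(w, sparsity=0.5):
    """Masked magnitude pruning: zero out the smallest-|w| fraction of weights."""
    k = int(w.size * sparsity)
    if k == 0:
        return w.copy()
    threshold = np.sort(np.abs(w), axis=None)[k - 1]
    return np.where(np.abs(w) <= threshold, 0.0, w)

def quantize(w, n_levels=16):
    """Linear quantization: snap each weight to the nearest of n_levels
    uniformly spaced values spanning [w.min(), w.max()]."""
    levels = np.linspace(w.min(), w.max(), n_levels)
    idx = np.abs(w[..., None] - levels).argmin(axis=-1)
    return levels[idx]

def factorize(w, rank=8):
    """Low-rank reconstruction: rebuild w from its top-`rank` SVD components."""
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    return (u[:, :rank] * s[:rank]) @ vt[:rank]
```

Note that pruning and quantization act on individual entries (unstructured), while the SVD truncation constrains the whole matrix to a low-rank structure.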
*(Figure: accuracy and latency plots for each compression technique.)*
Note: These results were attained with the teacher model on the CIFAR-10 test set.
- Accuracy tends to decrease with compression; however, both linear quantization and weight pruning were surprisingly robust.
- This result is intriguing, as it suggests that unstructured compression techniques can outperform structured ones.
- Some hypotheses for further investigation:
- Important learned weights may lack structure because of random weight initialization.
- Poor low-rank approximation may be due to noise/outliers (weights having poor SNR).
- Latency does not decrease with compression (in the plot above) since the number of matrix multiplication operations remains the same (NOTE: 0% compression is an outlier due to JAX compilation overhead).
- Additional work is required to realize the latency benefits of compression, such as actually storing the factored weights (in the case of SVD), using sparse matrix formats (in the case of quantization and pruning), sparse computation hardware, etc.
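For the SVD case, the point can be made concrete: reconstructing the full matrix leaves the matmul cost unchanged, but keeping the two factors and applying them in sequence does reduce the work. A NumPy sketch, with illustrative names and sizes (not taken from the repo):

```python
import numpy as np

def factor(w, rank):
    """Split w (m x n) into u (m x rank) and v (rank x n) via truncated SVD."""
    U, s, vt = np.linalg.svd(w, full_matrices=False)
    return U[:, :rank] * s[:rank], vt[:rank]

m, n, rank = 512, 512, 32
u, v = factor(np.random.randn(m, n), rank)

# Per input row: the dense product x @ w costs m*n multiply-adds, while the
# factored form (x @ u) @ v costs m*rank + rank*n -- an 8x reduction here.
dense_macs = m * n                    # 262,144
factored_macs = m * rank + rank * n   # 32,768
```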