```bash
# create conda env
conda env create -f environment.yaml
conda activate a3
# install requirements.txt
pip install -r requirements.txt
```

Functionalities:
- Collect Rxx
- Approximate QK, VO, and FFN using Rxx
- Evaluate the approximated model's perplexity or downstream-task performance (lm-eval-harness)
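
Collecting Rxx amounts to accumulating a second-moment matrix of a layer's inputs over calibration tokens. The sketch below assumes Rxx denotes the autocorrelation matrix (1/N) Σₜ xₜ xₜᵀ of the input activations; the class `RxxAccumulator` is hypothetical and not part of this repo, and plain lists stand in for the tensors and forward hooks a real implementation would use.

```python
from typing import List, Sequence


class RxxAccumulator:
    """Hypothetical running estimate of Rxx = (1/N) * sum_t x_t x_t^T."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.count = 0
        # Running sum of outer products, stored as a dim x dim nested list.
        self.total = [[0.0] * dim for _ in range(dim)]

    def update(self, tokens: Sequence[Sequence[float]]) -> None:
        """Add a batch of token activations, each a length-`dim` vector."""
        for x in tokens:
            if len(x) != self.dim:
                raise ValueError(f"expected dim {self.dim}, got {len(x)}")
            for i in range(self.dim):
                for j in range(self.dim):
                    self.total[i][j] += x[i] * x[j]
            self.count += 1

    def rxx(self) -> List[List[float]]:
        """Return the token-averaged autocorrelation matrix."""
        if self.count == 0:
            raise ValueError("no tokens accumulated yet")
        return [[v / self.count for v in row] for row in self.total]


acc = RxxAccumulator(dim=2)
acc.update([[1.0, 2.0], [3.0, 0.0]])  # two calibration tokens
print(acc.rxx())  # [[5.0, 1.0], [1.0, 2.0]]
```

Summing outer products and dividing once at the end weights every token equally, regardless of how the calibration data is split into batches.
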

Command line interface:

```bash
cd experiments/llm
python run.py collect -h
python run.py approx -h
python run.py eval ppl -h
python run.py eval harness -h
```

For a full example, check this tutorial.

- The attention type (`attn_type`) is parsed from the model config by the `A3ModelHelpers.get_model_arch_meta` function in `src/a3/models/__init__.py`.
- The SVD-based A3-QK solution cannot be applied to multi-head attention with RoPE (`mha-rope`) or grouped query attention (`gqa-*`), so we use CR approximation instead.
- The SVD-based A3-VO solution cannot be applied to grouped query attention (`gqa-*`), so we use joint SVD instead.
- 🟢 denotes the A3 method and its variants; 🟡 denotes baselines for ablation studies and debugging.
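
The attention-type rules above can be sketched as a small classifier over a Hugging Face-style config dict, with the available modes transcribed from the tables in this README. This is a hypothetical illustration, not the repo's `A3ModelHelpers.get_model_arch_meta`: the keys `num_attention_heads`, `num_key_value_heads`, and `rope_theta` are assumptions based on Llama-style configs, and treating the presence of `rope_theta` as "uses RoPE" is a simplification.

```python
from typing import Any, Dict, List

# Modes per attention type, transcribed from the QK and VO tables in this README.
QK_MODES: Dict[str, List[str]] = {
    "mha": ["qk", "q-only", "k-only"],
    "mha-rope": ["rxx-w-rxx", "rxx-w-rxx-uniform"],
    "gqa-rope": ["rxx-w-rxx", "rxx-w-rxx-uniform"],
}
VO_MODES: Dict[str, List[str]] = {
    "mha": ["axkv", "xkv", "identity"],
    "mha-rope": ["axkv", "xkv", "identity"],
    "gqa-rope": ["xkv", "identity"],
}


def infer_attn_type(config: Dict[str, Any]) -> str:
    """Hypothetical classifier for a Llama-style config dict."""
    n_heads = config["num_attention_heads"]
    # Configs without GQA often omit num_key_value_heads; treat that as MHA.
    n_kv_heads = config.get("num_key_value_heads") or n_heads
    uses_rope = "rope_theta" in config  # simplification: key presence => RoPE
    base = "gqa" if n_kv_heads < n_heads else "mha"
    return f"{base}-rope" if uses_rope else base


def available_modes(config: Dict[str, Any]) -> Dict[str, List[str]]:
    """Look up the QK/VO modes for a config, mirroring the tables."""
    attn_type = infer_attn_type(config)
    if attn_type not in QK_MODES:
        # Plain gqa (without RoPE) is listed as "Not implemented yet".
        raise NotImplementedError(f"{attn_type} is not implemented yet")
    return {"qk": QK_MODES[attn_type], "vo": VO_MODES[attn_type]}


llama_like = {"num_attention_heads": 32, "num_key_value_heads": 8, "rope_theta": 500000.0}
print(infer_attn_type(llama_like))  # gqa-rope
print(available_modes(llama_like)["qk"])  # ['rxx-w-rxx', 'rxx-w-rxx-uniform']
```
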

QK approximation modes:

| Attn Type | Available approx mode | Class | Description |
|---|---|---|---|
| `mha` | `qk` | 🟢 | SVD using the Rxx of both Q and K |
| `mha` | `q-only` | 🟡 | SVD using the Rxx of Q only |
| `mha` | `k-only` | 🟡 | SVD using the Rxx of K only |
| `mha-rope` | `rxx-w-rxx` | 🟢 | CR; each (Q<sub>i</sub>, K<sub>i</sub>) head pair has its own index of columns to drop |
| `mha-rope` | `rxx-w-rxx-uniform` | 🟡 | CR; all (Q<sub>i</sub>, K<sub>i</sub>) head pairs in one attention layer share the same index of columns to drop |
| `gqa` | Not implemented yet | - | Most GQA models use RoPE. Due to time constraints, I only implemented `gqa-rope` |
| `gqa-rope` | `rxx-w-rxx` | 🟢 | CR; each group of (Q<sub>1</sub>, Q<sub>2</sub>, ..., K) heads has its own index of columns to drop |
| `gqa-rope` | `rxx-w-rxx-uniform` | 🟡 | CR; all (Q<sub>1</sub>, Q<sub>2</sub>, ..., K) heads in one decoder layer share the same index of columns to drop |
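
The difference between `rxx-w-rxx` and `rxx-w-rxx-uniform` is whether the kept columns are chosen per head pair (or GQA group) or shared across the whole layer. The sketch below assumes one plausible CR score for head-dimension column i: the product of the root-mean-square magnitudes of the i-th query and key outputs, each computed from Rxx and the matching weight column as sqrt(wᵀ Rxx w). This score and every function name here are hypothetical, and the repo's actual criterion may differ; a RoPE-aware implementation would also need to keep the two dimensions rotated together by each RoPE frequency, which this sketch ignores.

```python
import math
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def quad_form(w: Sequence[float], r: Matrix) -> float:
    """w^T R w: the mean squared output of projection column w when R = Rxx."""
    n = len(w)
    return sum(w[i] * r[i][j] * w[j] for i in range(n) for j in range(n))


def cr_scores(w_q: Matrix, w_k: Matrix, r_q: Matrix, r_k: Matrix) -> List[float]:
    """Hypothetical score per head-dim column i: RMS(q_i) * RMS(k_i).

    w_q, w_k are d_model x head_dim (q = x^T w_q); r_q, r_k are d_model x d_model.
    """
    head_dim = len(w_q[0])
    scores = []
    for i in range(head_dim):
        q_col = [row[i] for row in w_q]
        k_col = [row[i] for row in w_k]
        scores.append(math.sqrt(quad_form(q_col, r_q)) * math.sqrt(quad_form(k_col, r_k)))
    return scores


def top_k(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores (ties keep the lower index), sorted ascending."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:k])


def keep_per_head(head_scores: List[List[float]], k: int) -> List[List[int]]:
    """rxx-w-rxx style: each head pair (or GQA group) keeps its own columns."""
    return [top_k(s, k) for s in head_scores]


def keep_uniform(head_scores: List[List[float]], k: int) -> List[List[int]]:
    """rxx-w-rxx-uniform style: one shared index from scores summed over heads."""
    summed = [sum(col) for col in zip(*head_scores)]
    shared = top_k(summed, k)
    return [list(shared) for _ in head_scores]


eye = [[1.0, 0.0], [0.0, 1.0]]
w = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
print(cr_scores(w, w, eye, eye))  # [1.0, 1.0, 4.0]

scores = [[1.0, 3.0, 2.0], [4.0, 0.5, 1.0]]  # two heads, head_dim = 3
print(keep_per_head(scores, 2))  # [[1, 2], [0, 2]]
print(keep_uniform(scores, 2))   # [[0, 1], [0, 1]]
```

The per-head variant lets each head keep the columns that matter to it, while the uniform variant forces one compromise index, which is why it serves as the ablation baseline.
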

VO approximation modes:

| Attn Type | Available approx mode | Class | Description |
|---|---|---|---|
| `mha`, `mha-rope` | `axkv` | 🟢 | SVD using the Rxx of A X<sub>kv</sub> |
| `mha`, `mha-rope` | `xkv` | 🟡 | SVD using the Rxx of X<sub>kv</sub> |
| `mha`, `mha-rope` | `identity` | 🟡 | SVD on the fused weights without using activation information |
| `gqa` | Not implemented yet | - | Most GQA models use RoPE. Due to time constraints, I only implemented `gqa-rope` |
| `gqa-rope` | `xkv` | 🟢 | SVD using the Rxx of X<sub>kv</sub> |
| `gqa-rope` | `identity` | 🟡 | SVD on the fused weights without using activation information |
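
The SVD-based VO modes differ only in which autocorrelation matrix weights the approximation error. A minimal derivation, assuming the target is the fused per-head weight W = W<sub>V</sub> W<sub>O</sub> (d × d), the error is measured on row vectors x that feed it, and R is symmetric positive definite: under these assumptions x would be a row of A X<sub>kv</sub> for `axkv`, a row of X<sub>kv</sub> for `xkv`, and R = I for `identity`, where the result reduces to a plain truncated SVD of W.

$$
\begin{aligned}
\min_{\operatorname{rank}(\hat W)\le r}\ \mathbb{E}_x \big\lVert x^\top W - x^\top \hat W \big\rVert_2^2
&= \min_{\operatorname{rank}(\hat W)\le r}\ \operatorname{tr}\!\big((W-\hat W)^\top R\,(W-\hat W)\big) \\
&= \min_{\operatorname{rank}(\hat W)\le r}\ \big\lVert R^{1/2} W - R^{1/2}\hat W \big\rVert_F^2,
\qquad R = \mathbb{E}_x\big[x\,x^\top\big].
\end{aligned}
$$

Because $R^{1/2}$ is invertible, $\operatorname{rank}(R^{1/2}\hat W) = \operatorname{rank}(\hat W)$, so the Eckart–Young theorem applies to $R^{1/2} W = U \Sigma V^\top$:

$$
\hat W = R^{-1/2} U_r \Sigma_r V_r^\top
= \underbrace{\big(R^{-1/2} U_r \Sigma_r\big)}_{d \times r}\ \underbrace{V_r^\top}_{r \times d}.
$$

Storing the two thin factors instead of $W$ is what yields the compression.
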

FFN approximation modes:

For the 3-layer case, `rxx` may achieve better performance than `rxx-w`.

| FFN Type | Available approx mode | Class | Description |
|---|---|---|---|
| 2-layer, 3-layer | `rxx` | 🟡 | CR using only the information in Rxx |
| 2-layer, 3-layer | `rxx-w` | 🟢 | CR using the information in both Rxx and the weights |
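
A sketch of how the two FFN modes might rank intermediate neurons, assuming CR here means keeping matching columns of the down projection and rows of the intermediate activations h, and that Rxx is the autocorrelation of h (after the activation, or after the gate product in the 3-layer case). Under those assumptions, `rxx` ranks neuron i by its activation energy Rxx[i][i], while `rxx-w` also multiplies by the squared norm of the down-projection column that neuron i writes to. The function names are hypothetical.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def ffn_scores(rxx: Matrix, w_down: Matrix, mode: str) -> List[float]:
    """Hypothetical per-neuron importance for FFN CR approximation.

    rxx:    d_ff x d_ff autocorrelation of the intermediate activations h.
    w_down: d_model x d_ff down projection, so the FFN output is w_down @ h.
    """
    d_ff = len(rxx)
    energy = [rxx[i][i] for i in range(d_ff)]  # E[h_i^2]
    if mode == "rxx":
        return energy
    if mode == "rxx-w":
        # Squared norm of the down-projection column that neuron i writes to.
        col_norm_sq = [sum(row[i] ** 2 for row in w_down) for i in range(d_ff)]
        return [e * c for e, c in zip(energy, col_norm_sq)]
    raise ValueError(f"unknown mode: {mode}")


def neurons_to_keep(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k highest-scoring neurons, sorted ascending."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(order[:k])


rxx = [[4.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
w_down = [[0.5, 1.0, 1.0], [0.0, 1.0, 0.0]]  # d_model = 2, d_ff = 3
print(neurons_to_keep(ffn_scores(rxx, w_down, "rxx"), 2))    # [0, 2]
print(neurons_to_keep(ffn_scores(rxx, w_down, "rxx-w"), 2))  # [1, 2]
```

In this toy case neuron 0 has the largest activation energy but writes little to the output, so `rxx` keeps it while `rxx-w` drops it.
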