Paper: "Optimizing Large Language Models for Efficient Edge Inference via Hybrid SSM-Transformer Architectures and Sub-4-Bit Quantization"
Authors: Jordan M. Ellis, Priya R. Shenoy, Marcus T. Okafor, Lin Wei
Venue: NeurIPS 2024 | arXiv: 2409.XXXXX
Code: github.com/edge-llm-lab/ssm-sparse (original)
EdgeFusion enables deployment of 7B-class language models on commodity edge hardware (phones, embedded boards) within a 4 GB memory envelope. It combines two techniques:
- Hybrid SSM-Transformer stack with differentiable sparse routing: each layer independently learns to route through either a Mamba-style SSM block (O(n) in sequence length, efficient) or a standard multi-head attention block (O(n²), expressive). At convergence, ~68% of layers route to SSM, reducing activated FLOPs by 38.6%.
- Adaptive sub-4-bit block-wise weight quantization: weights are quantized to an average of 3.1 bits per parameter (block size B=128), with the top 0.5% of weight outliers retained at 6-bit precision in a sparse correction matrix. Activations remain in bfloat16.
Together, these achieve a 2.7× throughput improvement and a 61% peak-memory reduction versus the dense FP16 Transformer baseline, with only 0.38 PPL degradation on Wikitext-103.
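The per-layer routing in the first technique can be sketched as follows. This is a minimal illustration under assumptions the paper does not confirm: each layer holds one scalar gate logit, `sigmoid(gate_logit)` is read as the probability of taking the SSM path, training uses a soft mixture of both paths so the gate stays differentiable, and inference takes a hard decision at 0.5. `ssm_block` and `attn_block` are hypothetical toy stand-ins (a diagonal linear recurrence and single-head causal attention), not Mamba or the repository's attention layer.

```python
import numpy as np


def ssm_block(x, a=0.9, b=0.1):
    """Toy diagonal linear recurrence: O(n) in sequence length (not Mamba's selective scan)."""
    h = np.zeros(x.shape[1])
    out = np.empty_like(x)
    for t in range(x.shape[0]):
        h = a * h + b * x[t]  # constant work per step
        out[t] = h
    return out


def attn_block(x):
    """Toy single-head causal self-attention: O(n^2) in sequence length."""
    n, d = x.shape
    scores = x @ x.T / np.sqrt(d)
    causal = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(causal, scores, -np.inf)  # mask future positions
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ x


def routed_layer(x, gate_logit, training=True):
    """Hypothetical router: g = sigmoid(gate_logit) is the probability of the SSM path."""
    g = 1.0 / (1.0 + np.exp(-gate_logit))
    if training:
        # Soft mixture: both paths run, and the output is differentiable in gate_logit.
        return g * ssm_block(x) + (1.0 - g) * attn_block(x)
    # Hard routing at inference: only the chosen block runs, so only its FLOPs are paid.
    return ssm_block(x) if g >= 0.5 else attn_block(x)


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))  # (sequence length, hidden size)
y_train = routed_layer(x, gate_logit=0.0)                  # g = 0.5: equal mix
y_infer = routed_layer(x, gate_logit=2.0, training=False)  # g ≈ 0.88: SSM path
```

With the soft mixture both blocks run during training, so in this sketch the FLOP savings only appear once routing is hardened at inference.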
| Metric | EdgeFusion-7B | Dense FP16 | GPTQ-4bit | Mamba-pure |
|---|---|---|---|---|
| PPL (Wikitext-103) ↓ | 6.52 | 6.14 | 6.51 | 6.89 |
| MMLU 5-shot ↑ | 62.7% | 63.4% | 62.1% | 59.8% |
| HellaSwag ↑ | 80.3% | 81.2% | 79.9% | 76.4% |
| ARC-Challenge ↑ | 52.9% | 53.7% | 52.4% | 50.1% |
| GSM8K 8-shot ↑ | 46.8% | 47.9% | 45.3% | 41.2% |
| Peak Memory (GB) ↓ | 5.4 | 14.0 | 5.8 | 14.0 |
| Tokens/sec ↑ | 11.4 | 4.2 | 9.1 | 11.8 |
| GFLOPs/token ↓ | 8.8 | 14.4 | 14.4 | 8.8 |
Hardware: Snapdragon 8 Gen 3. Paper reports 7 edge targets; only one shown here.
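The headline ratios in the summary can be recomputed from the table. The sketch below also estimates weight storage at the reported 3.1-bit average, assuming a nominal 7.0 × 10⁹ parameters and that the 3.1-bit figure already covers block scales and the outlier correction (neither assumption is stated in this README). The FLOPs reduction computed from the rounded table entries (≈38.9%) agrees with the reported 38.6% only up to rounding of the table values.

```python
# EdgeFusion-7B vs. Dense FP16, copied from the table above.
edgefusion = {"ppl": 6.52, "mem_gb": 5.4, "tok_s": 11.4, "gflops": 8.8}
dense_fp16 = {"ppl": 6.14, "mem_gb": 14.0, "tok_s": 4.2, "gflops": 14.4}

throughput_gain = edgefusion["tok_s"] / dense_fp16["tok_s"]         # ≈ 2.71x
memory_reduction = 1 - edgefusion["mem_gb"] / dense_fp16["mem_gb"]  # ≈ 61.4%
ppl_delta = edgefusion["ppl"] - dense_fp16["ppl"]                   # ≈ 0.38
flops_reduction = 1 - edgefusion["gflops"] / dense_fp16["gflops"]   # ≈ 38.9% from rounded inputs

# Weight storage only (excludes activations and any runtime buffers).
params = 7.0e9                           # nominal "7B" (assumption)
fp16_weights_gb = params * 16 / 8 / 1e9  # 14.0 GB
q_weights_gb = params * 3.1 / 8 / 1e9    # ≈ 2.71 GB at 3.1 bits/param

print(f"throughput gain:  {throughput_gain:.2f}x")
print(f"memory reduction: {memory_reduction:.1%}")
print(f"PPL degradation:  {ppl_delta:.2f}")
print(f"FLOPs reduction:  {flops_reduction:.1%}")
print(f"weights FP16 / 3.1-bit: {fp16_weights_gb:.1f} GB / {q_weights_gb:.2f} GB")
```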
```bash
# 1. Clone and install
git clone https://github.com/your-org/edgefusion
cd edgefusion
pip install -e .
pip install mamba-ssm causal-conv1d  # CUDA 11.8+ required

# 2. Validate setup (no training)
python train.py --config configs/edgefusion_7b.yaml --output-dir ./checkpoints --dry-run

# 3. Train (router + quantization, 25K steps on The Pile)
python train.py --config configs/edgefusion_7b.yaml --output-dir ./checkpoints

# 4. Evaluate
python evaluate.py --checkpoint ./checkpoints/checkpoint_quantized

# 5. Interactive inference
python inference.py --checkpoint ./checkpoints/checkpoint_quantized \
  --prompt "Edge computing enables" --profile
```

Install with pip:

```bash
pip install -r requirements.txt
pip install mamba-ssm causal-conv1d  # installed separately; requires CUDA compilation
```

Install with conda:

```bash
conda env create -f environment.yaml
conda activate edgefusion
pip install mamba-ssm causal-conv1d
```

Install with Docker:

```bash
cd docker && docker-compose build
docker-compose run train
```

Without CUDA: set `ssm_backend: reference` in configs/edgefusion_7b.yaml. The reference backend uses a pure-PyTorch SSM implementation (slower, but with no CUDA dependency).
Full training (router + quantization, ~25K steps):

```bash
python train.py \
  --config configs/edgefusion_7b.yaml \
  --base-model meta-llama/Meta-Llama-3.1-7B \
  --output-dir ./checkpoints
```

Router training only (skip quantization):

```bash
python train.py --config configs/edgefusion_7b.yaml \
  --output-dir ./checkpoints --skip-quantization
```

Frozen base (train only the router and quantization parameters):

```bash
python train.py --config configs/edgefusion_7b.yaml \
  --output-dir ./checkpoints --freeze-base
```

Standalone quantization (after router training):

```bash
python quantize.py \
  --checkpoint ./checkpoints/checkpoint_final \
  --output-dir ./checkpoints/checkpoint_quantized
```

Debug mode (quick local test, 100 steps):

```bash
python train.py --config configs/edgefusion_7b.yaml --output-dir ./checkpoints --debug
```

The key convergence signal is the SSM routing fraction (target: 68%, per paper Section 5.4):

```bash
python analyze_router.py --checkpoint ./checkpoints/checkpoint_final --plot
```

A model is converged when:
- SSM fraction ≈ 68%
- Mean gate entropy H < 0.15 bits (per Section 5.4)
Training logs print ssm_frac every 100 steps.
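The two convergence criteria can be checked from per-layer gate probabilities with a sketch like the one below. It assumes `ssm_frac` is the fraction of layers whose SSM probability is at least 0.5 and that gate entropy is the binary entropy in bits averaged over layers; `analyze_router.py` may define them differently. The ±0.03 tolerance on "≈ 68%" is a hypothetical choice, not from the paper.

```python
import numpy as np


def router_stats(ssm_probs):
    """Return (ssm_frac, mean gate entropy in bits) from per-layer SSM probabilities."""
    p = np.clip(np.asarray(ssm_probs, dtype=float), 1e-12, 1 - 1e-12)  # avoid log2(0)
    ssm_frac = float(np.mean(p >= 0.5))  # hard routing decision per layer
    entropy_bits = -(p * np.log2(p) + (1 - p) * np.log2(1 - p))  # binary entropy per gate
    return ssm_frac, float(entropy_bits.mean())


def is_converged(ssm_probs, target_frac=0.68, frac_tol=0.03, max_entropy_bits=0.15):
    """Hypothetical check of the two criteria above (frac_tol is an assumed tolerance)."""
    ssm_frac, gate_entropy = router_stats(ssm_probs)
    return abs(ssm_frac - target_frac) <= frac_tol and gate_entropy < max_entropy_bits


# 17 of 25 layers confidently on SSM: ssm_frac = 0.68, entropy ≈ 0.081 bits.
print(router_stats([0.99] * 17 + [0.01] * 8), is_converged([0.99] * 17 + [0.01] * 8))
# Undecided gates stuck at 0.5 carry 1 bit of entropy each, so this is not converged.
print(is_converged([0.5] * 25))
```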
This repository was generated by ArXivist from the paper's SIR (Scientific Intermediate Representation). Several implementation details were not specified in the paper:
| Component | Assumption | Confidence | Impact |
|---|---|---|---|
| Optimizer | AdamW | 0.55 | High — wrong LR can prevent convergence |
| Learning rate | 1e-4 | 0.55 | High — see LR sweep advice below |
| LR schedule | Cosine with 500-step warmup | 0.50 | Medium |
| SSM hyperparameters | Mamba defaults (d_state=16, d_conv=4) | 0.65 | Medium |
| Router sparsity loss | L1 regularization on gates | 0.52 | High — mechanism may differ |
| Outlier correction format | COO sparse matrix | 0.58 | Low |
| Full vs. frozen fine-tuning | Full fine-tuning | 0.60 | Medium |
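The quantization scheme (block size B=128, top 0.5% of outliers kept at 6 bits in a COO sparse correction matrix, per the assumption table) can be sketched in NumPy as below. This is a simplified illustration, not the paper's method: it uses a fixed 3-bit symmetric code with one absmax scale per block instead of the adaptive bit allocation behind the 3.1-bit average, a single shared scale for the outliers, and plain index arrays in place of a sparse-matrix type. `quantize_blockwise_with_outliers` and `dequantize` are hypothetical names.

```python
import numpy as np


def quantize_symmetric(x, bits, scale):
    """Round x / scale to signed integers representable in `bits` bits."""
    qmax = 2 ** (bits - 1) - 1
    return np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int8)


def quantize_blockwise_with_outliers(w, bits=3, block_size=128,
                                     outlier_frac=0.005, outlier_bits=6):
    """Split a 2-D weight matrix into block-quantized dense codes plus COO outliers."""
    flat = w.astype(np.float32).ravel()
    # 1. Top outlier_frac of weights by magnitude, stored as COO (row, col, value).
    k = max(1, int(round(outlier_frac * flat.size)))
    idx = np.argpartition(np.abs(flat), -k)[-k:]
    rows, cols = np.unravel_index(idx, w.shape)
    out_scale = float(np.abs(flat[idx]).max()) / (2 ** (outlier_bits - 1) - 1)
    out_scale = out_scale or 1.0  # guard against an all-zero tensor
    out_q = quantize_symmetric(flat[idx], outlier_bits, out_scale)
    # 2. Zero the outlier slots, pad to a multiple of block_size, quantize per block.
    dense = flat.copy()
    dense[idx] = 0.0
    pad = (-dense.size) % block_size
    blocks = np.pad(dense, (0, pad)).reshape(-1, block_size)
    scales = np.abs(blocks).max(axis=1, keepdims=True) / (2 ** (bits - 1) - 1)
    scales[scales == 0] = 1.0  # all-zero blocks
    return {"shape": w.shape, "codes": quantize_symmetric(blocks, bits, scales),
            "scales": scales, "pad": pad, "coo": (rows, cols, out_q, out_scale)}


def dequantize(packed):
    """Reconstruct the dense matrix, then add the sparse outlier correction."""
    flat = (packed["codes"] * packed["scales"]).ravel()
    flat = flat[:flat.size - packed["pad"]]
    w = flat.reshape(packed["shape"])
    rows, cols, out_q, out_scale = packed["coo"]
    w[rows, cols] += out_q * out_scale  # outlier slots were zeroed, so this restores them
    return w


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 200)).astype(np.float32)
w[3, 5] = 50.0  # one planted outlier
packed = quantize_blockwise_with_outliers(w)
w_hat = dequantize(packed)
print(len(packed["coo"][0]), float(np.linalg.norm(w_hat - w) / np.linalg.norm(w)))
```

Removing the outliers before computing block scales keeps a single large weight from inflating the scale of its whole 128-element block, which is the motivation for the sparse correction matrix.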
Recommended validation sequence:
- Run `--dry-run` to validate the setup.
- Run `--debug` (100 steps) to check that the router gates move away from 0.5.
- Do a small LR sweep over {5e-5, 1e-4, 2e-4} for 2K steps each; pick the LR for which `ssm_frac` moves toward 0.68.
- Train the full 25K steps with the best LR.
- Check that `ssm_frac ≈ 0.68` and `gate_entropy < 0.15` at convergence.
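The assumed schedule (AdamW at 1e-4, cosine decay after a 500-step linear warmup, per the assumption table) can be written as a plain function, which makes the sweep candidates easy to compare. This is a sketch: decaying to zero, and shortening the cosine to the 2K-step sweep length, are both assumptions rather than settings taken from configs/edgefusion_7b.yaml.

```python
import math


def lr_at(step, peak_lr=1e-4, warmup_steps=500, total_steps=25_000, min_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to min_lr over the remaining steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Sweep candidates from the validation sequence, each run for 2K steps.
for peak in (5e-5, 1e-4, 2e-4):
    lrs = [lr_at(s, peak_lr=peak, total_steps=2_000) for s in range(2_000)]
    print(f"peak={peak:g}: step 0 -> {lrs[0]:.1e}, step 500 -> {lrs[500]:.1e}, "
          f"step 1999 -> {lrs[-1]:.1e}")
```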
If `ssm_frac` is stuck near 0.5, try increasing `router_sparsity_lambda` in the config.
If the SSM fraction is too high (>0.90), decrease `router_sparsity_lambda` or set it to 0.
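One reading of the "L1 regularization on gates" assumption that is consistent with the tuning advice above is sketched below: penalize each layer's attention probability `1 - sigmoid(gate_logit)`, so a larger `router_sparsity_lambda` pushes layers toward SSM and λ = 0 disables the push. Which gate is penalized, and whether the penalty is summed or averaged over layers, are assumptions; the gradient is written out by hand to keep the example dependency-free.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def router_sparsity_loss(gate_logits, router_sparsity_lambda):
    """Hypothetical L1 penalty on per-layer attention probabilities (1 - g_ssm)."""
    g_ssm = sigmoid(np.asarray(gate_logits, dtype=float))
    return router_sparsity_lambda * np.abs(1.0 - g_ssm).sum()


def router_sparsity_grad(gate_logits, router_sparsity_lambda):
    """d(loss)/d(gate_logits); |1 - g| = 1 - g because g lies in (0, 1)."""
    g = sigmoid(np.asarray(gate_logits, dtype=float))
    return -router_sparsity_lambda * g * (1.0 - g)


# Gradient descent on the penalty alone drives every gate toward SSM (g -> 1).
logits = np.zeros(4)  # all layers start undecided at g = 0.5
for _ in range(200):
    logits -= 1.0 * router_sparsity_grad(logits, router_sparsity_lambda=0.1)
print(sigmoid(logits))
```

In training this term would be added to the language-modeling loss. Under this reading the penalty alone sends every layer to SSM, so the ~68% balance would come from the LM loss keeping some layers on attention, which is why `ssm_frac` moves with `router_sparsity_lambda` in the direction described above.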
```bibtex
@inproceedings{ellis2024edgefusion,
  title={Optimizing Large Language Models for Efficient Edge Inference via
         Hybrid SSM-Transformer Architectures and Sub-4-Bit Quantization},
  author={Ellis, Jordan M. and Shenoy, Priya R. and Okafor, Marcus T. and Wei, Lin},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2024}
}
```

Generated by ArXivist | SIR confidence: 0.77 | ArXivist v1.0