
BlazeBot8/EdgeFusion


EdgeFusion: Hybrid SSM-Transformer with Sub-4-Bit Quantization for Edge LLM Inference

Paper: "Optimizing Large Language Models for Efficient Edge Inference via Hybrid SSM-Transformer Architectures and Sub-4-Bit Quantization"
Authors: Jordan M. Ellis, Priya R. Shenoy, Marcus T. Okafor, Lin Wei
Venue: NeurIPS 2024 | arXiv: 2409.XXXXX
Code: github.com/edge-llm-lab/ssm-sparse (original)


Overview

EdgeFusion enables deployment of 7B-class language models on commodity edge hardware (phones, embedded boards) within a 4 GB memory envelope. It combines two techniques:

  1. Hybrid SSM-Transformer stack with differentiable sparse routing — each layer independently learns to route through either a Mamba-style SSM block (O(n), efficient) or a standard multi-head attention block (O(n²), expressive). At convergence, ~68% of layers route to SSM, reducing activated FLOPs by 38.6%.

  2. Adaptive sub-4-bit block-wise weight quantization — weights are quantized to an average of 3.1 bits per parameter (block size B=128), with the top 0.5% of weight outliers retained at 6-bit precision in a sparse correction matrix. Activations remain in bfloat16.
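The quantization scheme in item 2 can be sketched as follows. This is a minimal, dependency-free illustration built on several assumptions:

- Each 128-weight block uses symmetric absmax scaling.
- Codes have a fixed 3-bit width. The paper's adaptive per-block bit allocation is not reproduced.
- Outliers are the globally largest 0.5% of weights by magnitude.
- The 6-bit outlier values share one scale and are stored in COO form as (index, code).

All function names are hypothetical and are not this repository's API.

```python
import math


def quantize_symmetric(values, bits):
    """Symmetric absmax quantization of floats to signed integer codes of width `bits`."""
    qmax = 2 ** (bits - 1) - 1  # 3 for 3-bit codes, 31 for 6-bit codes
    absmax = max((abs(v) for v in values), default=0.0)
    scale = absmax / qmax if absmax > 0 else 1.0
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return codes, scale


def blockwise_quantize(weights, block_size=128, bits=3, outlier_frac=0.005, outlier_bits=6):
    """Quantize a flat weight list block by block, keeping the largest-magnitude
    `outlier_frac` of weights in a separate higher-precision COO correction."""
    n = len(weights)
    k = max(1, math.ceil(outlier_frac * n))
    # Indices of the k largest-magnitude weights (the outliers).
    outlier_idx = sorted(range(n), key=lambda i: abs(weights[i]), reverse=True)[:k]
    outlier_set = set(outlier_idx)
    # Zero the outliers so they do not inflate the per-block scales.
    dense = [0.0 if i in outlier_set else w for i, w in enumerate(weights)]
    blocks = [quantize_symmetric(dense[s:s + block_size], bits) for s in range(0, n, block_size)]
    # COO correction: outlier indices, their 6-bit codes, and one shared scale.
    o_codes, o_scale = quantize_symmetric([weights[i] for i in outlier_idx], outlier_bits)
    return blocks, (outlier_idx, o_codes, o_scale)


def reconstruct(blocks, correction):
    """Dequantize the dense blocks, then write the outlier values back in place."""
    out = [c * scale for codes, scale in blocks for c in codes]
    idx, o_codes, o_scale = correction
    for i, c in zip(idx, o_codes):
        out[i] = c * o_scale  # these positions were zeroed before block quantization
    return out
```

Zeroing the outliers before computing block scales matters. Otherwise a single large weight would coarsen the 3-bit grid for the other 127 weights in its block. For reference, if each block scale were stored in 16 bits (an assumption), the scales alone would add 16/128 = 0.125 bits per weight.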

Together, these achieve a 2.7× throughput improvement and a 61% peak-memory reduction versus a dense, unquantized FP16 Transformer baseline. Perplexity on Wikitext-103 rises by only 0.38 (6.14 → 6.52).


Expected Results (Paper Table 1 & 2, EdgeFusion-7B)

| Metric | EdgeFusion-7B | Dense FP16 | GPTQ-4bit | Mamba-pure |
|---|---|---|---|---|
| PPL (Wikitext-103) ↓ | 6.52 | 6.14 | 6.51 | 6.89 |
| MMLU 5-shot ↑ | 62.7% | 63.4% | 62.1% | 59.8% |
| HellaSwag ↑ | 80.3% | 81.2% | 79.9% | 76.4% |
| ARC-Challenge ↑ | 52.9% | 53.7% | 52.4% | 50.1% |
| GSM8K 8-shot ↑ | 46.8% | 47.9% | 45.3% | 41.2% |
| Peak Memory (GB) ↓ | 5.4 | 14.0 | 5.8 | 14.0 |
| Tokens/sec ↑ | 11.4 | 4.2 | 9.1 | 11.8 |
| GFLOPs/token ↓ | 8.8 | 14.4 | 14.4 | 8.8 |

Hardware: Snapdragon 8 Gen 3. Paper reports 7 edge targets; only one shown here.
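As a sanity check, the headline figures in the Overview can be recomputed from the table. The weight-footprint line adds one assumption: exactly 7.0 × 10^9 parameters stored at the stated 3.1-bit average. It counts weights only. With the rounded table values, the GFLOPs reduction comes out at about 38.9%, close to the 38.6% quoted above.

```python
# Values copied from the table above: EdgeFusion-7B vs. the Dense FP16 baseline.
dense = {"ppl": 6.14, "mem_gb": 14.0, "tok_s": 4.2, "gflops": 14.4}
edge = {"ppl": 6.52, "mem_gb": 5.4, "tok_s": 11.4, "gflops": 8.8}

throughput_gain = edge["tok_s"] / dense["tok_s"]
mem_reduction = 1 - edge["mem_gb"] / dense["mem_gb"]
flops_reduction = 1 - edge["gflops"] / dense["gflops"]
ppl_delta = edge["ppl"] - dense["ppl"]

# Assumption: exactly 7.0e9 parameters at the stated 3.1 bits/parameter average,
# weights only (no activations, KV cache, or runtime overhead).
weight_gb = 7.0e9 * 3.1 / 8 / 1e9

print(f"throughput gain:        {throughput_gain:.2f}x")   # 2.71x
print(f"peak memory reduction:  {mem_reduction:.1%}")      # 61.4%
print(f"GFLOPs/token reduction: {flops_reduction:.1%}")    # 38.9%
print(f"PPL increase:           {ppl_delta:.2f}")          # 0.38
print(f"bit-width compression:  {16 / 3.1:.2f}x")          # 5.16x
print(f"quantized weights:      {weight_gb:.1f} GB")       # 2.7 GB
```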


Quick Start

# 1. Clone and install
git clone https://github.com/your-org/edgefusion
cd edgefusion
pip install -e .
pip install mamba-ssm causal-conv1d  # CUDA 11.8+ required

# 2. Validate setup (no training)
python train.py --config configs/edgefusion_7b.yaml --output-dir ./checkpoints --dry-run

# 3. Train (router + quantization, 25K steps on The Pile)
python train.py --config configs/edgefusion_7b.yaml --output-dir ./checkpoints

# 4. Evaluate
python evaluate.py --checkpoint ./checkpoints/checkpoint_quantized

# 5. Interactive inference
python inference.py --checkpoint ./checkpoints/checkpoint_quantized \
    --prompt "Edge computing enables" --profile

Installation

pip

pip install -r requirements.txt
pip install mamba-ssm causal-conv1d  # separate — requires CUDA compilation

conda

conda env create -f environment.yaml
conda activate edgefusion
pip install mamba-ssm causal-conv1d

Docker

cd docker && docker-compose build
docker-compose run train

CPU / No CUDA (Raspberry Pi, Apple Silicon)

Set ssm_backend: reference in configs/edgefusion_7b.yaml. The reference backend uses a pure-PyTorch SSM implementation (slower, no CUDA dependency).
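The reference backend evaluates an SSM recurrence in plain code. Below is a dependency-free sketch of that linear-time recurrence, for intuition about what it computes. The sketch makes these assumptions:

- The SSM is diagonal, with a single input channel and a fixed step size.
- The repository's pure-PyTorch backend is not reproduced here.
- Mamba's selective scan goes further and makes the step size and the B and C projections functions of the input.

The function name `ssm_scan` is hypothetical.

```python
import math


def ssm_scan(x, a, b, c, delta):
    """Sequential scan of a diagonal linear SSM over a 1-D input sequence.

    x:     list of L scalar inputs
    a:     list of N negative reals (diagonal of the continuous-time A)
    b, c:  lists of N reals (input and output projections)
    delta: positive step size (fixed here; Mamba makes it input-dependent)
    Returns a list of L outputs. Cost is O(L * N), linear in sequence length.
    """
    n = len(a)
    a_bar = [math.exp(delta * a_i) for a_i in a]  # zero-order-hold discretization of A
    b_bar = [delta * b_i for b_i in b]            # simplified (Euler) discretization of B
    h = [0.0] * n                                 # hidden state, starts at zero
    y = []
    for x_t in x:
        h = [a_bar[i] * h[i] + b_bar[i] * x_t for i in range(n)]
        y.append(sum(c[i] * h[i] for i in range(n)))
    return y
```

Each step reads only the previous state, so cost grows linearly with sequence length. This is the O(n) behavior cited in the Overview. The `d_state=16` default listed under Reproducibility Notes would correspond to the length of `a`, `b` and `c` here.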


Training

Full training (router + quantization, ~25K steps):

python train.py \
    --config configs/edgefusion_7b.yaml \
    --base-model meta-llama/Meta-Llama-3.1-7B \
    --output-dir ./checkpoints

Router training only (skip quantization):

python train.py --config configs/edgefusion_7b.yaml \
    --output-dir ./checkpoints --skip-quantization

Frozen base (only train router and quantization params):

python train.py --config configs/edgefusion_7b.yaml \
    --output-dir ./checkpoints --freeze-base

Standalone quantization (after router training):

python quantize.py \
    --checkpoint ./checkpoints/checkpoint_final \
    --output-dir ./checkpoints/checkpoint_quantized

Debug mode (quick local test, 100 steps):

python train.py --config configs/edgefusion_7b.yaml --output-dir ./checkpoints --debug

Monitoring Convergence

The key convergence signal is the SSM routing fraction (target: 68%, per Section 5.4 of the paper):

python analyze_router.py --checkpoint ./checkpoints/checkpoint_final --plot

A model is converged when:

  • SSM fraction ≈ 68%
  • Mean gate entropy H < 0.15 bits (per Section 5.4)

Training logs print ssm_frac every 100 steps.
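The convergence criteria above can be made concrete with a small sketch of per-layer routing and the two diagnostics. It rests on four assumptions, none confirmed by the paper:

- Each layer has one scalar gate logit.
- `g = sigmoid(logit)` is the probability of routing to attention.
- A layer counts as SSM-routed when `g < 0.5`.
- "Mean gate entropy" is the binary entropy of `g` in bits, averaged over layers.

The function names are hypothetical, not what `analyze_router.py` implements.

```python
import math


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def route_layer(x_ssm, x_attn, gate_logit, hard=False):
    """Combine one layer's SSM and attention outputs using its gate.

    Training uses the soft mix so the gate stays differentiable;
    inference picks whichever path the gate prefers.
    """
    g = sigmoid(gate_logit)  # assumed: probability of routing to attention
    if hard:
        return x_attn if g >= 0.5 else x_ssm
    return [g * a + (1.0 - g) * s for s, a in zip(x_ssm, x_attn)]


def binary_entropy_bits(p):
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log2(p) + (1.0 - p) * math.log2(1.0 - p))


def router_stats(gate_logits):
    """Return (ssm_frac, mean gate entropy in bits) over all layers."""
    gates = [sigmoid(z) for z in gate_logits]
    ssm_frac = sum(g < 0.5 for g in gates) / len(gates)
    mean_entropy = sum(binary_entropy_bits(g) for g in gates) / len(gates)
    return ssm_frac, mean_entropy
```

An undecided gate (g = 0.5) has entropy 1 bit. The H < 0.15 criterion therefore requires gates that are close to a hard 0/1 choice, so the soft mix used in training behaves almost like the hard routing used at inference.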


Reproducibility Notes

⚠️ Read before running experiments.

This repository was generated by ArXivist from the paper's SIR (Scientific Intermediate Representation). Several implementation details were not specified in the paper:

| Component | Assumption | Confidence | Impact |
|---|---|---|---|
| Optimizer | AdamW | 0.55 | High (wrong LR can prevent convergence) |
| Learning rate | 1e-4 | 0.55 | High (see LR sweep advice below) |
| LR schedule | Cosine with 500-step warmup | 0.50 | Medium |
| SSM hyperparameters | Mamba defaults (d_state=16, d_conv=4) | 0.65 | Medium |
| Router sparsity loss | L1 regularization on gates | 0.52 | High (mechanism may differ) |
| Outlier correction format | COO sparse matrix | 0.58 | Low |
| Full vs. frozen fine-tuning | Full fine-tuning | 0.60 | Medium |
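The assumed schedule from the table, a 500-step warmup followed by cosine decay, might look like the sketch below. The following details are assumptions, not taken from train.py:

- The warmup is linear.
- The rate decays to `min_lr=0` at step 25,000.
- The function name `lr_at_step` is hypothetical.

In PyTorch, the same function divided by `base_lr` could serve as the multiplier passed to `torch.optim.lr_scheduler.LambdaLR`.

```python
import math


def lr_at_step(step, base_lr=1e-4, warmup_steps=500, total_steps=25_000, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        # Ramp from base_lr / warmup_steps up to base_lr at the last warmup step.
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)  # hold at min_lr past total_steps
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```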

Recommended validation sequence:

  1. Run --dry-run to validate setup
  2. Run --debug (100 steps) to check that router gates move away from 0.5
  3. Do a small LR sweep: {5e-5, 1e-4, 2e-4} for 2K steps each; pick the one where ssm_frac moves toward 0.68
  4. Train full 25K steps with best LR
  5. Check ssm_frac ≈ 0.68 and gate_entropy < 0.15 at convergence
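Step 3 can be automated along these lines. The sketch assumes the ssm_frac values printed during each 2K-step run have already been collected into a dict keyed by learning rate. It reads "moves toward 0.68" as "ends closest to 0.68". `pick_lr` is a hypothetical helper, not a script in this repository.

```python
def pick_lr(sweep, target=0.68):
    """Return the learning rate whose last logged ssm_frac is closest to target.

    sweep: dict mapping learning rate -> list of ssm_frac values, in log order.
    """
    return min(sweep, key=lambda lr: abs(sweep[lr][-1] - target))
```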

If ssm_frac is stuck near 0.5, try increasing router_sparsity_lambda in the config.
If ssm_frac is too high (>0.90), decrease router_sparsity_lambda or set it to 0.
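Under the assumed L1 formulation from the table, the router loss term might look like the following. It assumes gates are sigmoid outputs read as the probability of routing to attention, so the penalty directly discourages attention use. The paper's actual mechanism may differ (confidence 0.52). The function names are hypothetical; `router_sparsity_lambda` mirrors the config key mentioned above.

```python
def router_sparsity_penalty(gates, router_sparsity_lambda):
    """L1 penalty on per-layer attention gates (assumed formulation).

    Gates are sigmoid outputs in [0, 1], so |g| == g; penalizing them
    pushes layers toward the SSM path, harder for larger lambda.
    """
    return router_sparsity_lambda * sum(abs(g) for g in gates) / len(gates)


def total_loss(task_loss, gates, router_sparsity_lambda):
    return task_loss + router_sparsity_penalty(gates, router_sparsity_lambda)
```

The penalty's gradient with respect to each gate is router_sparsity_lambda divided by the number of layers. Raising the coefficient therefore pulls every gate toward the SSM path more strongly, and setting it to 0 removes the pull entirely.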


Citation

@inproceedings{ellis2024edgefusion,
  title={Optimizing Large Language Models for Efficient Edge Inference via
         Hybrid SSM-Transformer Architectures and Sub-4-Bit Quantization},
  author={Ellis, Jordan M. and Shenoy, Priya R. and Okafor, Marcus T. and Wei, Lin},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2024}
}

Generated by ArXivist | SIR confidence: 0.77 | ArXivist v1.0

About

EdgeFusion is a co-designed LLM architecture combining hybrid SSM-Transformer layers, sparse routing, and adaptive sub-4-bit quantization. It reduces activated FLOPs by 38.6%, achieves 5.2× memory compression with minimal accuracy loss, and enables 7B-class LLM inference in 4 GB memory, cutting peak memory 61% and boosting throughput 2.7×.
