### Abstract
 Activation sparsity provides a dynamic, input-dependent alternative to weight pruning for accelerating inference in large language models (LLMs), effectively reducing unnecessary computations and memory accesses during the forward pass. Despite its promise, existing activation sparsification methods suffer from two major limitations: (1) relying solely on activation magnitude for sparsification, ignoring its coupling with the corresponding weights, and (2) applying uniform sparsity rates across all blocks without considering block-wise sparsity sensitivity. To address these issues, this paper proposes a novel training-free weight-aware activation sparsity framework, called WAS. First, by analyzing the coupling relationship between weights and activations, we introduce a weight-aware scoring method to measure activation importance during sparsification. Then, a novel constrained Bayesian optimization algorithm is devised to set a suitable sparsity ratio for each block based on its sparsity sensitivity. Finally, we implement a custom GPU sparsity kernel to support the resulting sparsity patterns for wall-clock decoding speed-ups. WAS achieves competitive performance at 60\% model-level sparsity and significantly outperforms prior methods at higher sparsity levels, delivering up to 1.68× inference speed-up with no retraining or weight updates.

In [1]:
import sys,os
import torch
from tqdm import tqdm
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import argparse
import typing
from utils.utils import get_tokenizer, get_sparse_model
from eval_test.evaluate import eval_tasks

from was.model import LlamaSparseForCausalLM, LlamaSparseConfig
from was.model import MistralSparseForCausalLM, MistralSparseConfig

from transformers import AutoConfig, AutoModelForCausalLM

# Sparse architectures made available through the Hugging Face Auto* factories.
# Each entry is (model_type string, config class, causal-LM class).
_SPARSE_ARCHITECTURES = (
    ("llama_sparse", LlamaSparseConfig, LlamaSparseForCausalLM),
    ("mistral_sparse", MistralSparseConfig, MistralSparseForCausalLM),
)

# Register all configs first, then all models (same order as calling each
# register() by hand: both configs, then both model classes).
for _type_name, _config_cls, _unused_model_cls in _SPARSE_ARCHITECTURES:
    AutoConfig.register(_type_name, _config_cls)

for _unused_type_name, _config_cls, _model_cls in _SPARSE_ARCHITECTURES:
    AutoModelForCausalLM.register(_config_cls, _model_cls)

  from .autonotebook import tqdm as notebook_tqdm


### Load the optimal sparsity rates obtained from Bayesian optimization.

In [2]:
def _projection_weights(mlp_weight):
    """Return the per-projection weight table for one model.

    The attention entries (q, k, v, o) are identical for every model listed
    in this cell; only the MLP entries (gate, up, down) differ, and all three
    of them share ``mlp_weight``.
    """
    table = {'q': 1, 'k': 1/8, 'v': 1/8, 'o': 1}
    table.update(dict.fromkeys(('gate', 'up', 'down'), mlp_weight))
    return table


# Relative weight of each linear projection, keyed by model name.
# NOTE(review): 3.5 and 2.6875 equal 14336/4096 and 11008/4096, the MLP widths
# shown in the printed module summaries, so these presumably weight each
# projection by its parameter count relative to q_proj — confirm.
# NOTE(review): the printed Llama-2-7B summary shows k_proj/v_proj as
# 4096 -> 4096, so the 1/8 entries may have been carried over from a different
# configuration — confirm before relying on them.
weight_dict = {
    "Mistral-7B": _projection_weights(3.5),
    "Llama-2-7B": _projection_weights(2.6875),
}
# Sparsity-rate tables, keyed by model name (the notebook heading describes
# them as the optimal rates obtained from Bayesian optimization).
# sps[model_name] is a list of three lists of floats. The evaluation cells
# index it by sparsity level: [0] for 0.4, [1] for 0.6, and [2] is the table
# used by the 0.75 run in this notebook.
# NOTE(review): each inner list appears to hold 32 values, presumably one per
# decoder block (the printed models have 32 layers) — confirm the ordering
# expected by load_greedy_sparsities.
sps={
"Llama-2-7B" : [[0.3720121707624662, 0.37498553980085, 0.376696045782326, 0.37769695534272746, 0.37846142740563493, 0.37882172794166535, 0.3798171617619194, 0.38074865530775154, 0.38109715204992267, 0.3820352978287211, 0.383533766815422, 0.38775520433175414, 0.3904868305517332, 0.3916333827297726, 0.3919130484524919, 0.3920800042447037, 0.3924149854253026, 0.39270114103356535, 0.3933065710110301, 0.3950363430939929, 0.3967854329668602, 0.4016094384892526, 0.40208250102195237, 0.40210378276978204, 0.40304778157101795, 0.40343212943818424, 0.4084919309351512, 0.41839918068408477, 0.4271652402472407, 0.42861908583054953, 0.42966470560633424, 0.42986860433344914],
        [0.5713442230921849, 0.5729181462011875, 0.5751639345726678, 0.5765424467032879, 0.577812537008091, 0.5788743852102369, 0.580099415373599, 0.5811054410117681, 0.5817902779633589, 0.5824603532022575, 0.5831172735445226, 0.5847044104133572, 0.5865813829885432, 0.5907346920796416, 0.5932687962435536, 0.5943618828546953, 0.59522312014206, 0.5967951211449557, 0.5976843344014812, 0.5986977179977958, 0.5996316135565588, 0.6002875202634645, 0.6008302710337199, 0.6008843171192886, 0.6022971107009933, 0.6054288504626175, 0.6124941033486824, 0.6206212181934292, 0.6271763150339046, 0.6284894662250705, 0.6294855924705518, 0.6296899443219208],
        [0.7012713644882108, 0.7014350676570817, 0.7015280287890517, 0.70180846964121, 0.7027857076118573, 0.7086458806420335, 0.7094407211949303, 0.7097986219585672, 0.7121104654860442, 0.7128657978637374, 0.7132423515600318, 0.7163301931877049, 0.7227644916204667, 0.7310197223644707, 0.7458138392890786, 0.7545503694559393, 0.7594123738446797, 0.7657121497330586, 0.7690828972998875, 0.7717372693123751, 0.7741929299961626, 0.7768396395880095, 0.7775849192360305, 0.7780837574048741, 0.7784494308440698, 0.7784499331817754, 0.7784501452774784, 0.7784524379257576, 0.7784596848464203, 0.7785002267222303, 0.7785018032213286, 0.7785059298425341]
        ],
"Mistral-7B" : [[0.3775609028955764, 0.3789939381523285, 0.3804538615950444, 0.3815614306996619, 0.38224594047171234, 0.3837599738379501, 0.3850329422840084, 0.386834105212761, 0.3870595402345686, 0.3876852650395319, 0.3881267457750328, 0.3889023527537393, 0.3900865502739806, 0.39175575848080446, 0.39224317793138114, 0.39225849220071995, 0.39254832551000446, 0.3926732644425294, 0.393007677847419, 0.39349594360508006, 0.3940331265477053, 0.394223388650777, 0.3955972728965453, 0.4000183496171237, 0.4030123777110685, 0.403901405721781, 0.4137566951141697, 0.4233744438906073, 0.4255043498401377, 0.42635197414865533, 0.4292518869062295, 0.4298838892980136],
                    [0.5748498485047117, 0.5757952376670732, 0.5769846532769749, 0.5782101895963475, 0.5787955489200871, 0.579472568153386, 0.5799823774012058, 0.580419755292789, 0.580978258557723, 0.5819787511593875, 0.5829053709598148, 0.5835538956395441, 0.5847833999203588, 0.5848275320468758, 0.5853440810683013, 0.5854080832503255, 0.5868607378443262, 0.5869686115514824, 0.589263039481553, 0.5894255165579633, 0.589620505249984, 0.5914595898953209, 0.5987830078798729, 0.6084131095001821, 0.6179649528103596, 0.6266688657041428, 0.6290965076257066, 0.6297254074497288, 0.6298517403927066, 0.6298892550028785, 0.6299066985056049, 0.629929676390124],
                    [0.7013955019836529, 0.701678849542781, 0.7034597867432563, 0.7037652257529714, 0.7095079310696334, 0.7097794738662017, 0.7097960409992665, 0.7104157390405182, 0.710444691488421, 0.711179273810033, 0.7128782501544767, 0.7159002291726712, 0.7240184027608946, 0.7338599653129236, 0.7483119353687419, 0.756557240370774, 0.7606095022810708, 0.7720511405223425, 0.7745868566643107, 0.77648281522203, 0.7774244157562832, 0.7782131110083039, 0.7786373733056609, 0.7789397259021511, 0.7791482410663874, 0.7791492910957347, 0.7791498764021699, 0.7791502718643817, 0.7792017164493609, 0.7793270648687193, 0.7793277404717623, 0.7793284377876534]
]

}

### Evaluate PPL and 0-shot tasks

In [None]:

# --- Run configuration: Llama-2-7B, perplexity + 0-shot tasks ---------------
model_path = "/path/to/llama2-7b"      # base checkpoint passed to the loaders
net = model_path.split("/")[-1]        # last path component of the checkpoint
was_path = "./models/Llama-2-7B"       # holds the "histograms" and "lookup" dirs
model_type = "Llama-2-7B"              # key into `sps`
greedy_flag = True                     # True: rates from `sps`; False: uniform
sparsity = 0.6
eval_ppl = True
tasks = "piqa,arc_easy,arc_challenge,hellaswag,winogrande"
batch_size = "auto:4.0"
fewshot = 0
seed = 2

# Position of each supported sparsity level inside sps[model_type].
# Previously every level other than 0.4 and 0.6 silently fell through to the
# third table, so e.g. sparsity = 0.5 would be evaluated with the wrong rates.
# NOTE(review): index 2 is assumed to be the 0.75 table because 0.75 is the
# only other level this notebook evaluates — confirm.
sparsity_table_index = {0.4: 0, 0.6: 1, 0.75: 2}

if greedy_flag:
    # Resolve the rate table before the (slow) model load so that an
    # unsupported level fails immediately. round() absorbs float noise such
    # as 0.1 * 6 != 0.6.
    tmp = sps[model_type]
    table_index = sparsity_table_index.get(round(sparsity, 6))
    if table_index is None:
        raise ValueError(
            f"No sparsity table for sparsity={sparsity!r}; "
            f"supported levels: {sorted(sparsity_table_index)}. "
            "Set greedy_flag = False to evaluate a uniform rate instead."
        )
    sp = tmp[table_index]

tokenizer = get_tokenizer(model_path)
model = get_sparse_model(model_path, device="cpu", histogram_path=os.path.join(was_path, "histograms"))

print("Evaluating sparse PPL at sparsity level: ", sparsity)
print("="*40)
if greedy_flag:
    print("Evaluating greedy PPL")
    greedy_path = os.path.join(was_path, "lookup")
    model.load_greedy_sparsities(greedy_path, sparsity, sparsities=sp)
else:
    print("Evaluating uniform PPL")
    model.set_uniform_sparsity(sparsity)

eval_tasks(model, tokenizer, eval_ppl, model_path, tasks, fewshot, batch_size, seed)

You are using a model of type llama to instantiate a model of type llama_sparse. This is not supported for all configurations of models and can yield errors.


LlamaForCausalLM


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  6.52it/s]
  histogram = torch.load(f"{self.file_path}/histograms.pt")


Evaluating sparse PPL at sparsity level:  0.6
Evaluating greedy PPL
get_wikitext2
Evaluating ...
wikitext2 Perplexity: 6.560258
get_c4_new
Evaluating ...
c4_new Perplexity: 8.782909


2025-05-12:16:07:16,758 INFO     [huggingface.py:481] Using model type 'default'
2025-05-12:16:07:16,778 INFO     [evaluator.py:164] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2025-05-12:16:07:16,779 INFO     [evaluator.py:217] Using pre-initialized model
Using the latest cached version of the module from /root/.cache/huggingface/modules/datasets_modules/datasets/piqa/6c611c1a9bf220943c4174e117d3b660859665baf1d43156230116185312d011 (last modified on Fri Jun 14 16:08:58 2024) since it couldn't be found locally at piqa, or remotely on the Hugging Face Hub.
Using the latest cached version of the dataset since allenai/ai2_arc couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'ARC-Easy' at /root/.cache/huggingface/datasets/allenai___ai2_arc/ARC-Easy/0.0.0/210d026faf9955653af8916fad021475a3f00453 (last modified on Wed Oct 23 12:30:10 2024).
2025-05-12:16:07:57,453 INFO  

Passed argument batch_size = auto:4.0. Detecting largest batch size
Determined largest batch size: 64


Running loglikelihood requests:  25%|██▍       | 14871/60566 [04:53<12:30, 60.89it/s]

Passed argument batch_size = auto:4.0. Detecting largest batch size
Determined largest batch size: 64


Running loglikelihood requests: 100%|██████████| 60566/60566 [11:34<00:00, 87.23it/s] 
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
          (sparse_fns): ModuleDict(
            (q): SparsifyFn()
            (k): SparsifyFn()
            (v): SparsifyFn()
            (o): SparsifyFn()
          )
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linea

|    Tasks    |Version|Filter|n-shot| Metric |   |Value |   |Stderr|
|-------------|------:|------|-----:|--------|---|-----:|---|-----:|
|arc_challenge|      1|none  |     0|acc     |↑  |0.3959|±  |0.0143|
|             |       |none  |     0|acc_norm|↑  |0.4096|±  |0.0144|
|arc_easy     |      1|none  |     0|acc     |↑  |0.7311|±  |0.0091|
|             |       |none  |     0|acc_norm|↑  |0.7062|±  |0.0093|
|hellaswag    |      1|none  |     0|acc     |↑  |0.5442|±  |0.0050|
|             |       |none  |     0|acc_norm|↑  |0.7304|±  |0.0044|
|piqa         |      1|none  |     0|acc     |↑  |0.7720|±  |0.0098|
|             |       |none  |     0|acc_norm|↑  |0.7688|±  |0.0098|
|winogrande   |      1|none  |     0|acc     |↑  |0.6654|±  |0.0133|



### Evaluate PPL and 5-shot tasks

In [None]:
# --- Run configuration: Mistral-7B, perplexity + 5-shot tasks ---------------
model_path = "/path/to/Mistral-7B-v0.1"  # base checkpoint passed to the loaders
net = model_path.split("/")[-1]          # last path component of the checkpoint
was_path = "./models/Mistral-7B"         # holds the "histograms" and "lookup" dirs
model_type = "Mistral-7B"                # key into `sps`
greedy_flag = True                       # True: rates from `sps`; False: uniform
sparsity = 0.75
eval_ppl = True
tasks = "gsm8k,mmlu"
batch_size = "auto:4.0"
fewshot = 5
seed = 2

# Position of each supported sparsity level inside sps[model_type].
# Previously every level other than 0.4 and 0.6 silently fell through to the
# third table, so e.g. sparsity = 0.5 would be evaluated with the wrong rates.
# NOTE(review): index 2 is assumed to be the 0.75 table because 0.75 is the
# only other level this notebook evaluates — confirm.
sparsity_table_index = {0.4: 0, 0.6: 1, 0.75: 2}

if greedy_flag:
    # Resolve the rate table before the (slow) model load so that an
    # unsupported level fails immediately. round() absorbs float noise such
    # as 0.1 * 6 != 0.6.
    tmp = sps[model_type]
    table_index = sparsity_table_index.get(round(sparsity, 6))
    if table_index is None:
        raise ValueError(
            f"No sparsity table for sparsity={sparsity!r}; "
            f"supported levels: {sorted(sparsity_table_index)}. "
            "Set greedy_flag = False to evaluate a uniform rate instead."
        )
    sp = tmp[table_index]

tokenizer = get_tokenizer(model_path)
model = get_sparse_model(model_path, device="cpu", histogram_path=os.path.join(was_path, "histograms"))

print("Evaluating sparse PPL at sparsity level: ", sparsity)
print("="*40)
if greedy_flag:
    print("Evaluating greedy PPL")
    greedy_path = os.path.join(was_path, "lookup")
    model.load_greedy_sparsities(greedy_path, sparsity, sparsities=sp)
else:
    print("Evaluating uniform PPL")
    model.set_uniform_sparsity(sparsity)

eval_tasks(model, tokenizer, eval_ppl, model_path, tasks, fewshot, batch_size, seed)

You are using a model of type mistral to instantiate a model of type mistral_sparse. This is not supported for all configurations of models and can yield errors.


MistralForCausalLM


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.33it/s]
  histogram = torch.load(f"{self.file_path}/histograms.pt")


Evaluating sparse PPL at sparsity level:  0.75
Evaluating greedy PPL
get_wikitext2
Evaluating ...
wikitext2 Perplexity: 10.336608
get_c4_new
Evaluating ...
c4_new Perplexity: 14.033845


2025-05-13:19:40:42,745 INFO     [huggingface.py:481] Using model type 'default'
2025-05-13:19:40:42,759 INFO     [evaluator.py:164] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2025-05-13:19:40:42,760 INFO     [evaluator.py:217] Using pre-initialized model
Using the latest cached version of the dataset since gsm8k couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'main' at /root/.cache/huggingface/datasets/gsm8k/main/0.0.0/e53f048856ff4f594e959d75785d2c2d37b678ee (last modified on Wed Jan 22 12:54:13 2025).
2025-05-13:19:42:34,449 INFO     [task.py:415] Building contexts for mmlu_abstract_algebra on rank 0...
100%|██████████| 100/100 [00:00<00:00, 160.26it/s]
2025-05-13:19:42:35,093 INFO     [task.py:415] Building contexts for mmlu_computer_security on rank 0...
100%|██████████| 100/100 [00:00<00:00, 168.13it/s]
2025-05-13:19:42:35,696 INFO     [task.py:415] Buildin

Passed argument batch_size = auto:4.0. Detecting largest batch size
Determined largest batch size: 8


Running loglikelihood requests:  25%|██▍       | 13953/56168 [15:29<21:17, 33.05it/s] 

Passed argument batch_size = auto:4.0. Detecting largest batch size


Running loglikelihood requests:  25%|██▍       | 13984/56168 [15:41<21:16, 33.05it/s]

Determined largest batch size: 32


Running loglikelihood requests:  50%|████▉     | 27821/56168 [22:02<10:58, 43.06it/s]  

Passed argument batch_size = auto:4.0. Detecting largest batch size


Running loglikelihood requests:  50%|████▉     | 27948/56168 [22:14<10:55, 43.06it/s]

Determined largest batch size: 64


Running loglikelihood requests: 100%|██████████| 56168/56168 [30:46<00:00, 30.42it/s]
2025-05-13:20:17:11,437 INFO     [evaluator.py:489] Running generate_until requests
Running generate_until requests:   0%|          | 0/1319 [00:00<?, ?it/s]

Passed argument batch_size = auto. Detecting largest batch size
Determined Largest batch size: 1


Running generate_until requests: 100%|██████████| 1319/1319 [3:11:25<00:00,  8.71s/it]  
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
          (sparse_fns): ModuleDict(
            (q): SparsifyFn()
            (k): SparsifyFn()
            (v): SparsifyFn()
            (o): SparsifyFn()
          )
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_

|                 Tasks                 |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|---------------------------------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k                                  |      3|flexible-extract|     5|exact_match|↑  |0.0417|±  |0.0055|
|                                       |       |strict-match    |     5|exact_match|↑  |0.0402|±  |0.0054|
|mmlu                                   |      2|none            |      |acc        |↑  |0.3934|±  |0.0040|
| - humanities                          |      2|none            |      |acc        |↑  |0.3509|±  |0.0068|
|  - formal_logic                       |      1|none            |     5|acc        |↑  |0.1905|±  |0.0351|
|  - high_school_european_history       |      1|none            |     5|acc        |↑  |0.4000|±  |0.0383|
|  - high_school_us_history             |      1|none            |     5|acc        |↑  |0.3873|±  |0.0342|
|  - high_school_world_histo