diff --git a/benchmarks/run_offline.py b/benchmarks/run_offline.py
index 07128881..a41ca986 100644
--- a/benchmarks/run_offline.py
+++ b/benchmarks/run_offline.py
@@ -136,7 +136,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
   }
 
 
-def main():
+def main(argv):
   """Main function to run engine offline."""
   engine = create_engine()
 
diff --git a/default_shardings/llama-2.yaml b/default_shardings/llama-2.yaml
index 35859a21..d6bd2bc0 100644
--- a/default_shardings/llama-2.yaml
+++ b/default_shardings/llama-2.yaml
@@ -6,6 +6,7 @@
 
 freqs_cis : -1 # torch.complex64 (2048, 64)
 tok_embeddings.weight : 1 # torch.float32 (32000, 4096)
+tok_embeddings.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.attention.wo.weight : 1 # torch.int8 (4096, 4096)
 layers.*.attention.wo.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.attention.wq.weight : 0 # torch.int8 (4096, 4096)
@@ -15,9 +16,13 @@ layers.*.attention.wk.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.attention.wv.weight : 0 # torch.int8 (4096, 4096)
 layers.*.attention.wv.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.feed_forward.w1.weight : 0 # torch.float32 (11008, 4096)
+layers.*.feed_forward.w1.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.feed_forward.w2.weight : 1 # torch.float32 (4096, 11008)
+layers.*.feed_forward.w2.weight_scaler : 0 # torch.bfloat16 (11008,)
 layers.*.feed_forward.w3.weight : 0 # torch.float32 (11008, 4096)
+layers.*.feed_forward.w3.weight_scaler : 0 # torch.bfloat16 (4096,)
 layers.*.attention_norm.weight : -1 # torch.float32 (4096,)
 layers.*.ffn_norm.weight : -1 # torch.float32 (4096,)
 norm.weight : -1 # torch.float32 (4096,)
 output.weight : 0 # torch.float32 (32000, 4096)
+output.weight_scaler : 0 # torch.float32 (4096,)