From 2b55962483643c9e4fe341365aaf4bf4af8c041e Mon Sep 17 00:00:00 2001 From: Lance Wang Date: Wed, 22 May 2024 16:18:29 -0700 Subject: [PATCH] Fix gemma model: enable_weight_quantization is available through quant_config. --- jetstream_pt/third_party/gemma/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 10592fa8..8cd87b13 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -341,7 +341,7 @@ def forward( hidden_states = self.norm(hidden_states) embedder_weight = self.embedder.weight - if self.env.enable_weight_quantization: + if self.env.quant_config.enable_weight_quantization: embedder_weight = embedder_weight * self.embedder.weight_scaler logits = torch.matmul(hidden_states, embedder_weight.t()) return logits