diff --git a/jetstream_pt/third_party/gemma/model.py b/jetstream_pt/third_party/gemma/model.py index 10592fa8..8cd87b13 100644 --- a/jetstream_pt/third_party/gemma/model.py +++ b/jetstream_pt/third_party/gemma/model.py @@ -341,7 +341,7 @@ def forward( hidden_states = self.norm(hidden_states) embedder_weight = self.embedder.weight - if self.env.enable_weight_quantization: + if self.env.quant_config.enable_weight_quantization: embedder_weight = embedder_weight * self.embedder.weight_scaler logits = torch.matmul(hidden_states, embedder_weight.t()) return logits