Update model.py to fix NTK length bug. (#375)
hiworldwzj committed Mar 26, 2024
1 parent f5dc783 commit aa98b35
Showing 1 changed file with 1 addition and 1 deletion.
lightllm/models/llama/model.py (1 addition, 1 deletion)
@@ -144,7 +144,7 @@ def _init_to_get_dynamic_ntk_rotary(self):
             scaling_factor = 1.0
         else:
             scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
-        max_seq_len = self.max_seq_length
+        max_seq_len = max(self.max_seq_length, max_position_embeddings)
         self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
         self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")

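For context, below is a minimal, self-contained sketch of the dynamic-NTK rotary cache initialization this commit fixes, modeled on _init_to_get_dynamic_ntk_rotary. The function name, the demo defaults (head_dim, base, scaling_factor), and the CPU device are illustrative assumptions rather than values taken from the repository, and the partial_rotary_factor handling in the real code is omitted for brevity. The underlying bug: the first fill pass always writes max_position_embeddings rows, so sizing the cache to max_seq_length alone left it shorter than the range being filled whenever the server was configured with a smaller max_seq_length.

# A minimal sketch of dynamic-NTK rotary-embedding cache initialization,
# modeled on _init_to_get_dynamic_ntk_rotary. Names and defaults below are
# illustrative assumptions, not values taken from the repository.
import torch

def init_dynamic_ntk_rotary(max_seq_length: int,
                            max_position_embeddings: int = 2048,
                            head_dim: int = 128,
                            base: float = 10000.0,
                            scaling_factor: float = 4.0):
    # The fix: the first fill below always writes rows
    # 0 .. max_position_embeddings - 1, so the cache must cover at least
    # that range even when max_seq_length is configured smaller.
    max_seq_len = max(max_seq_length, max_position_embeddings)
    cos_cached = torch.zeros((max_seq_len, head_dim // 2), dtype=torch.float16)
    sin_cached = torch.zeros((max_seq_len, head_dim // 2), dtype=torch.float16)

    # Positions inside the trained window use the plain rope_theta base.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_position_embeddings).float()
    freqs = torch.outer(t, inv_freq)
    cos_cached[:max_position_embeddings] = torch.cos(freqs).to(torch.float16)
    sin_cached[:max_position_embeddings] = torch.sin(freqs).to(torch.float16)

    # Positions past the window recompute the base per position with the
    # dynamic-NTK scaling rule.
    for pos in range(max_position_embeddings, max_seq_len):
        new_base = base * (
            (scaling_factor * (pos + 1) / max_position_embeddings) - (scaling_factor - 1)
        ) ** (head_dim / (head_dim - 2))
        inv_freq = 1.0 / (new_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        freqs = pos * inv_freq
        cos_cached[pos] = torch.cos(freqs).to(torch.float16)
        sin_cached[pos] = torch.sin(freqs).to(torch.float16)

    return cos_cached, sin_cached

# Triggering the old bug: max_seq_length=1024 is below the 2048-token
# window, so the pre-fix sizing (max_seq_len = max_seq_length) allocated a
# 1024-row cache while the first fill still targeted 2048 rows.
cos, sin = init_dynamic_ntk_rotary(max_seq_length=1024)
print(cos.shape)  # torch.Size([2048, 64])

With the max(...) sizing, the cache length and the fill range stay consistent for both short and long max_seq_length configurations.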
