Commit

Fix #8 again
Beomi committed Apr 17, 2024
1 parent ca0b9ea commit c2f0746
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions modeling_gemma.py
@@ -836,8 +836,10 @@ def _retrieve_from_memory(self, query_states):
         debug_print("[Retrieve] self.norm_term.shape", self.norm_term.shape)
 
         # Broadcast norm_term to the shape of query_states, then sum across head_dim for normalization
-        norm_term_broadcastable = self.norm_term.expand_as(query_states).sum(
-            dim=3, keepdim=True
+        norm_term_broadcastable = torch.matmul(
+            query_states,
+            self.norm_term
+            .transpose(-2, -1),
         )
         debug_print(
             "[Broadcast] norm_term_broadcastable.shape", norm_term_broadcastable.shape
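The effect of this change can be checked with a small standalone sketch. It is not taken from the repository: the tensor sizes are made up, and the layout of `norm_term` as `(batch, heads, 1, head_dim)` is an assumption inferred from the `expand_as` and `transpose(-2, -1)` calls in the diff. Under that assumption, the removed expression yields the same value at every sequence position because the query never enters the computation, while the added expression is a per-position dot product between each query vector and `norm_term`.

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
batch, heads, seq_len, head_dim = 1, 2, 3, 4
query_states = torch.rand(batch, heads, seq_len, head_dim)
# Assumption: norm_term keeps a singleton sequence axis.
norm_term = torch.rand(batch, heads, 1, head_dim)

# Removed expression: expand, then sum over head_dim.
# The values of query_states are never used, only its shape.
old = norm_term.expand_as(query_states).sum(dim=3, keepdim=True)

# Added expression: dot product of each query vector with norm_term.
new = torch.matmul(query_states, norm_term.transpose(-2, -1))

# Elementwise form of the same dot product, for comparison.
new_elementwise = (query_states * norm_term).sum(dim=3, keepdim=True)

print(old.shape, new.shape)
print(torch.allclose(new, new_elementwise))
```

Both expressions produce a tensor of shape `(batch, heads, seq_len, 1)`, so the change does not affect downstream shapes, only the values used for normalization. The unchanged comment above the assignment still describes the removed broadcast-and-sum approach.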
