Skip to content

Commit

Permalink
Fix typo in comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Beomi committed Apr 17, 2024
1 parent c2f0746 commit cffcaa2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,7 +818,7 @@ def forward(
return final_output, None, None

def _retrieve_from_memory(self, query_states):
# query_states: [batch_size, seq_len, num_heads, head_dim]
# query_states: [batch_size, num_heads, seq_len, head_dim]

# Check if memory is initialized
if self.memory is None or self.norm_term is None:
Expand All @@ -838,8 +838,7 @@ def _retrieve_from_memory(self, query_states):
# Broadcast norm_term to the shape of query_states, then sum across head_dim for normalization
norm_term_broadcastable = torch.matmul(
query_states,
self.norm_term
.transpose(-2, -1),
self.norm_term.transpose(-2, -1),
)
debug_print(
"[Broadcast] norm_term_broadcastable.shape", norm_term_broadcastable.shape
Expand All @@ -850,8 +849,8 @@ def _retrieve_from_memory(self, query_states):
return memory_output

def _update_memory(self, key_states, value_states):
# key_states: [batch_size, seq_len, num_heads, head_dim]
# value_states: [batch_size, seq_len, num_heads, value_dim]
# key_states: [batch_size, num_heads, seq_len, head_dim]
# value_states: [batch_size, num_heads, seq_len, value_dim]

key_states = F.elu(key_states) + 1 # Apply ELU activation

Expand Down

0 comments on commit cffcaa2

Please sign in to comment.