Update dim, fixed #8: norm_term_broadcastable
Beomi committed Apr 17, 2024
1 parent d3659c3 commit a82dbe9
Showing 2 changed files with 28 additions and 6 deletions.
19 changes: 13 additions & 6 deletions modeling_gemma.py
@@ -818,6 +818,8 @@ def forward(
         return final_output, None, None
 
     def _retrieve_from_memory(self, query_states):
+        # query_states: [batch_size, seq_len, num_heads, head_dim]
+
         # Check if memory is initialized
         if self.memory is None or self.norm_term is None:
             debug_print("[Retrieve] No memory or norm term found")
@@ -833,9 +835,9 @@ def _retrieve_from_memory(self, query_states):
debug_print("[Retrieve] memory_output.shape", memory_output.shape)
debug_print("[Retrieve] self.norm_term.shape", self.norm_term.shape)

# Ensure norm_term is broadcastable to memory_output shape: [batch_size, num_heads, seq_len, head_dim]
norm_term_broadcastable = self.norm_term.unsqueeze(1).expand(
-1, query_states.size(1), query_states.size(2), -1
# Broadcast norm_term to the shape of query_states, then sum across head_dim for normalization
norm_term_broadcastable = self.norm_term.expand_as(query_states).sum(
dim=3, keepdim=True
)
debug_print(
"[Broadcast] norm_term_broadcastable.shape", norm_term_broadcastable.shape
@@ -846,8 +848,11 @@ def _retrieve_from_memory(self, query_states):
         return memory_output
 
     def _update_memory(self, key_states, value_states):
-        # Ensure that norm_term is initialized
+        # key_states: [batch_size, seq_len, num_heads, head_dim]
+        # value_states: [batch_size, seq_len, num_heads, value_dim]
+
+        key_states = F.elu(key_states) + 1  # Apply ELU activation
 
         if self.memory is not None:
             self.memory = self.memory + torch.matmul(
                 key_states.transpose(-2, -1), value_states
@@ -857,10 +862,12 @@

         if self.norm_term is not None:
             self.norm_term = self.norm_term + key_states.sum(
-                dim=2
+                dim=2, keepdim=True
             )  # Update normalization term
         else:
-            self.norm_term = key_states.sum(dim=2)  # Initialize normalization term
+            self.norm_term = key_states.sum(
+                dim=2, keepdim=True
+            )  # Initialize normalization term
 
         debug_print("[Update] self.memory.shape", self.memory.shape)
         debug_print("[Update] self.norm_term.shape", self.norm_term.shape)
15 changes: 15 additions & 0 deletions test_basic.py
@@ -44,18 +44,33 @@
 # Generate some dummy input data
 tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
 text = """This work introduces an efficient method to scale Transformer-based"""
+longtext = """The new memory states M s and z s are then passed to the next segment S + 1, building in a recurrence in each attention layer. The right side term σ (K ) T V in Eq. (4) is known as an associative binding operator (Smolensky, 1990; Hebb, 2005; Schlag et al., 2020).
+Inspired by the success of delta rule (Munkhdalai et al., 2019; Schlag et al., 2020; 2021), we have also incorporated it into our Infini-attention. The delta rule attempts a slightly improved memory update by first retrieving existing value entries and subtracting them from the new values before applying the associative bindings as new update."""
 
 encoded = tokenizer(
     text,
     return_tensors="pt",
 )
 # attention_mask = torch.ones_like(input_ids)
 encoded["labels"] = encoded["input_ids"].clone()
 
+long_encoded = tokenizer(
+    longtext,
+    return_tensors="pt",
+)
+# attention_mask = torch.ones_like(input_ids)
+long_encoded["labels"] = long_encoded["input_ids"].clone()
+
 print(encoded)
 # Test the forward pass
 outputs = model(**encoded.to(device))  # position_ids=position_ids)
+print("Short Text Loss")
 print(outputs.loss)
 outputs.loss.backward()  # Test the backward pass
 
+outputs = model(**long_encoded.to(device))  # position_ids=position_ids)
+print("Long Text Loss")
+print(outputs.loss)
+outputs.loss.backward()  # Test the backward pass
+
 print("backprop done")
