Skip to content

Commit

Permalink
Split up caches
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed May 27, 2024
1 parent 8f3b370 commit df534cf
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 24 deletions.
40 changes: 18 additions & 22 deletions surya/model/recognition/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
langs: Optional[torch.LongTensor] = None,
kv_caches: Optional[List[torch.Tensor]] = None,
self_kv_cache: Optional[torch.Tensor] = None,
cross_kv_cache: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = False,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
Expand All @@ -253,11 +254,10 @@ def forward(

# Self Attention
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = kv_caches[0] if kv_caches is not None else None
# add present self-attn cache to positions 1,2 of present_key_value tuple
hidden_states, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=self_attn_past_key_value,
past_key_value=self_kv_cache,
is_prefill=is_prefill,
attention_mask=attention_mask,
)
Expand All @@ -270,13 +270,12 @@ def forward(
hidden_states = self.encoder_attn_layer_norm(hidden_states)

# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
cross_attn_past_key_value = kv_caches[1] if kv_caches is not None else None
hidden_states, cross_attn_present_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
is_prefill=is_prefill,
attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
past_key_value=cross_kv_cache,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
Expand Down Expand Up @@ -336,24 +335,21 @@ def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_caches: Optional[List[torch.Tensor]] = None,
self_kv_cache: Optional[torch.Tensor] = None,
cross_kv_cache: Optional[torch.Tensor] = None,
past_token_count: Optional[int] = None,
langs: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = True
return_dict = True

input = input_ids
input_shape = input.size()
input_ids = input_ids.view(-1, input_shape[-1])

# past_key_values_length
past_key_values_length = past_token_count if kv_caches is not None else 0
past_key_values_length = past_token_count if self_kv_cache is not None else 0
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

# embed positions
Expand All @@ -372,15 +368,17 @@ def forward(

for idx, decoder_layer in enumerate(self.layers):
is_prefill = past_token_count == 0
kv_cache = [kv_caches[0][idx], kv_caches[1][idx]] if kv_caches is not None else None
layer_self_kv_cache = self_kv_cache[idx] if self_kv_cache is not None else None
layer_cross_kv_cache = cross_kv_cache[idx] if cross_kv_cache is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
langs=langs,
kv_caches=kv_cache,
self_kv_cache=layer_self_kv_cache,
cross_kv_cache=layer_cross_kv_cache,
is_prefill=is_prefill,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
encoder_attention_mask=None,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
Expand Down Expand Up @@ -440,7 +438,8 @@ def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
kv_caches: Optional[List[torch.FloatTensor]] = None,
self_kv_cache: Optional[torch.FloatTensor] = None,
cross_kv_cache: Optional[torch.FloatTensor] = None,
past_token_count: Optional[int] = None,
langs: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
Expand All @@ -461,14 +460,11 @@ def forward(
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
kv_caches=kv_caches,
self_kv_cache=self_kv_cache,
cross_kv_cache=cross_kv_cache,
past_token_count=past_token_count,
langs=langs,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
return_dict=return_dict,
)

logits = self.lm_head(outputs[0])
Expand Down
5 changes: 3 additions & 2 deletions surya/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
return_dict = model(
decoder_input_ids=batch_decoder_input,
decoder_attention_mask=attention_mask,
decoder_kv_caches=None if is_prefill else [decoder_cache, encoder_cache],
decoder_self_kv_cache=None if is_prefill else decoder_cache,
decoder_cross_kv_cache=None if is_prefill else encoder_cache,
decoder_past_token_count=token_count,
pixel_values=batch_pixel_values,
decoder_langs=batch_langs,
pixel_values=batch_pixel_values,
encoder_outputs=encoder_outputs,
return_dict=True,
)
Expand Down

0 comments on commit df534cf

Please sign in to comment.