Dynamic or static
VikParuchuri committed May 26, 2024
1 parent 860ed17 commit acc9b37
Showing 3 changed files with 106 additions and 120 deletions.
110 changes: 29 additions & 81 deletions surya/model/recognition/decoder.py
@@ -90,19 +90,6 @@ def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch
return final_hidden_states


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
From llama
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class MBartGQAttention(nn.Module):
def __init__(
self,
@@ -161,43 +148,43 @@ def forward(
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
if (
is_cross_attention
and not is_prefill
):
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
past_key_value = (None, None)
elif is_cross_attention:
# cross_attentions
key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
past_key_value = (key_states, value_states)
elif not is_prefill:
# reuse k, v, self_attention
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states[:, :, -tgt_len:], value_states[:, :, -tgt_len:])
if is_cross_attention:
if is_prefill:
# cross_attentions
key_states = self._shape_key_value(self.k_proj(key_value_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(key_value_states), -1, bsz)
past_key_value = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)
else:
# reuse k,v, cross_attentions
key_states = past_key_value[0]
value_states = past_key_value[1]
past_key_value = None
# Self-attention
else:
# self_attention
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
past_key_value = (key_states[:, :, -tgt_len:], value_states[:, :, -tgt_len:])
if is_prefill:
# initial prompt
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
past_key_value = torch.cat([key_states[:, :, -tgt_len:].unsqueeze(0), value_states[:, :, -tgt_len:].unsqueeze(0)], dim=0)
else:
# reuse k, v, self_attention
key_states = self._shape_key_value(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape_key_value(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = torch.cat([key_states[:, :, -tgt_len:].unsqueeze(0), value_states[:, :, -tgt_len:].unsqueeze(0)], dim=0)

proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)

# Expand kv heads, then match query shape
key_states = repeat_kv(key_states, self.num_kv_groups).reshape(*proj_shape)
value_states = repeat_kv(value_states, self.num_kv_groups).reshape(*proj_shape)
key_states = key_states.repeat_interleave(self.num_kv_groups, dim=1).reshape(*proj_shape)
value_states = value_states.repeat_interleave(self.num_kv_groups, dim=1).reshape(*proj_shape)

src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

if attention_mask is not None:
if not is_cross_attention:
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
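The removed `repeat_kv` helper and the `repeat_interleave` call that replaces it are equivalent ways of expanding grouped KV heads to match the query heads. A minimal, self-contained sketch of that equivalence, using illustrative shapes rather than surya's actual configuration:

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Expand/reshape version, as in the removed helper:
    # (batch, kv_heads, seq, dim) -> (batch, kv_heads * n_rep, seq, dim)
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

# Illustrative shapes: 2 sequences, 4 KV heads, 5 cached positions, 64-dim heads,
# and 4 query heads per KV head (the role played by self.num_kv_groups above).
kv = torch.randn(2, 4, 5, 64)
assert torch.equal(repeat_kv(kv, 4), kv.repeat_interleave(4, dim=1))  # identical (2, 16, 5, 64) result
```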

@@ -295,7 +282,7 @@ def forward(
hidden_states = residual + hidden_states

# add cross-attn to positions 3,4 of present_key_value tuple
present_key_value = present_key_value + cross_attn_present_key_value
present_key_value = (present_key_value, cross_attn_present_key_value)

# Fully Connected
residual = hidden_states
@@ -317,6 +304,7 @@ def forward(

return outputs


class MBartMoEDecoder(MBartDecoder):
def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None):
MBartPreTrainedModel.__init__(self, config)
@@ -337,7 +325,6 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N
)
# Language-specific MoE goes at second and second-to-last layer
self.layers = nn.ModuleList([MBartMoEDecoderLayer(config, has_moe=(i in config.moe_layers) and config.use_moe) for i in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.layer_norm = nn.LayerNorm(config.d_model)

@@ -369,18 +356,6 @@ def forward(
past_key_values_length = past_token_count if kv_caches is not None else 0
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _prepare_4d_attention_mask(
encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)

# embed positions
positions = self.embed_positions(input, past_key_values_length)
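With `_prepare_4d_causal_attention_mask` removed here, the caller (see `surya/recognition.py` below) now builds the additive 4D masks itself: 0 where a key position may be attended to, the dtype minimum where it may not. A small sketch of that convention, with illustrative sizes:

```python
import torch

bsz, prompt_len, max_tokens = 2, 8, 16          # illustrative sizes
dtype = torch.float32
min_val = torch.finfo(dtype).min

# Additive mask over a fixed-size KV cache: start fully blocked ...
kv_mask = torch.full((bsz, 1, 1, max_tokens + 1), min_val, dtype=dtype)
# ... then open the slots that actually hold keys/values
kv_mask[:, :, :, :prompt_len] = 0               # positions written during prefill
kv_mask[:, :, :, -1] = 0                        # the token being decoded this step

# The mask is simply added to the raw attention scores before the softmax,
# so blocked slots end up with (effectively) zero probability.
scores = torch.randn(bsz, 1, 1, max_tokens + 1, dtype=dtype)
probs = torch.softmax(scores + kv_mask, dim=-1)
```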

@@ -511,33 +486,6 @@ def forward(
cross_attentions=outputs.cross_attentions,
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, langs=None, use_cache=None, **kwargs
):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)

if past_key_values:
past_length = past_key_values[0][0].shape[2]

# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1

input_ids = input_ids[:, remove_prefix_length:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"use_cache": use_cache,
"langs": langs
}

def prune_moe_experts(self, keep_keys: List[int]):
# Remove experts not specified in keep_keys
str_keep_keys = [str(key) for key in keep_keys]
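The change from a `(key, value)` tuple to a single stacked tensor for `past_key_value` is what allows the caller to keep every layer's cache in one fixed-shape tensor. A standalone sketch of the packing and unpacking, with illustrative shapes:

```python
import torch

bsz, kv_heads, seq_len, head_dim = 2, 4, 5, 64   # illustrative shapes

key_states = torch.randn(bsz, kv_heads, seq_len, head_dim)
value_states = torch.randn(bsz, kv_heads, seq_len, head_dim)

# Pack: one (2, bsz, kv_heads, seq_len, head_dim) tensor instead of a (k, v) tuple.
past_key_value = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0)

# Unpack on the next decode step.
k_cached, v_cached = past_key_value[0], past_key_value[1]
assert torch.equal(k_cached, key_states) and torch.equal(v_cached, value_states)
```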
115 changes: 76 additions & 39 deletions surya/recognition.py
@@ -48,76 +48,113 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
batch_decoder_input = torch.from_numpy(np.array(batch_decoder_input, dtype=np.int64)).to(model.device)

inference_token_count = batch_decoder_input.shape[-1]
token_count = 0
encoder_outputs = None
batch_predictions = [[] for _ in range(len(batch_images))]
sequence_scores = None

attention_mask = torch.ones_like(batch_decoder_input, device=model.device)
all_done = torch.zeros(len(batch_images), dtype=torch.bool, device=model.device)

# Decoder kv cache
# 7 (layers) x 2 (kv) x bs x 4 (heads) x max tokens x 64 (head dim)
dec_config = model.config.decoder
layer_count = dec_config.decoder_layers
kv_heads = dec_config.kv_heads
head_dim = int(dec_config.d_model / dec_config.decoder_attention_heads)
decoder_cache = torch.zeros((layer_count, 2, len(batch_images), kv_heads, settings.RECOGNITION_MAX_TOKENS, head_dim), dtype=model.dtype, device=model.device)
kv_mask = torch.zeros((len(batch_images), settings.RECOGNITION_MAX_TOKENS), device=model.device)
if settings.RECOGNITION_STATIC_CACHE:
decoder_cache = [torch.zeros((2, len(batch_images), kv_heads, settings.RECOGNITION_MAX_TOKENS, head_dim), dtype=model.dtype, device=model.device) for _ in range(layer_count)]

min_val = torch.finfo(model.dtype).min
kv_mask = torch.full((len(batch_images), 1, 1, settings.RECOGNITION_MAX_TOKENS + 1), min_val, dtype=model.dtype, device=model.device)
kv_mask[:, :, :, -1] = 0
kv_mask[:, :, :, :inference_token_count] = 0
else:
kv_mask = torch.zeros((len(batch_images), 1, 1, inference_token_count + 1), dtype=model.dtype, device=model.device)
decoder_cache = [1 for _ in range(layer_count)]

# Encoder kv cache
# 7 (layers) x 2 (kv) x bs x 4 (heads) x 196 (max tokens) x 64 (head dim)
encoder_cache = torch.zeros((layer_count, 2, len(batch_images), kv_heads, 196, head_dim), dtype=model.dtype, device=model.device)
encoder_cache = [torch.zeros((2, len(batch_images), kv_heads, 196, head_dim), dtype=model.dtype, device=model.device) for _ in range(layer_count)]

attention_mask = torch.zeros((len(batch_images), 1, inference_token_count, inference_token_count), dtype=model.dtype, device=model.device)

with torch.inference_mode():
while token_count < settings.RECOGNITION_MAX_TOKENS:
is_prefill = token_count == 0
inference_token_count = batch_decoder_input.shape[-1]
# Run prefill tokens
return_dict = model(
decoder_input_ids=batch_decoder_input,
decoder_attention_mask=attention_mask,
decoder_kv_caches=None,
decoder_past_token_count=token_count,
decoder_langs=batch_langs,
pixel_values=batch_pixel_values,
encoder_outputs=None,
return_dict=True,
)

logits = return_dict["logits"]
preds = torch.argmax(logits[:, -1], dim=-1)
all_done = preds == processor.tokenizer.eos_id
sequence_scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
batch_decoder_input = preds.unsqueeze(1)

encoder_outputs = (return_dict["encoder_last_hidden_state"],)
past_key_values = return_dict["past_key_values"]
token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)
for layer_idx, layer in enumerate(past_key_values):
if settings.RECOGNITION_STATIC_CACHE:
decoder_cache[layer_idx][:, :, :, token_range, :] = layer[0]
else:
decoder_cache[layer_idx] = layer[0]
encoder_cache[layer_idx] = layer[1]

token_count = inference_token_count
attention_mask = kv_mask

# Run post-prefill tokens
while token_count < settings.RECOGNITION_MAX_TOKENS:
inference_token_count = batch_decoder_input.shape[-1]
with torch.inference_mode():
return_dict = model(
decoder_input_ids=batch_decoder_input,
decoder_attention_mask=attention_mask,
decoder_kv_caches=None if token_count == 0 else [decoder_cache, encoder_cache],
decoder_kv_caches=[decoder_cache, encoder_cache],
decoder_past_token_count=token_count,
decoder_langs=batch_langs,
pixel_values=batch_pixel_values,
encoder_outputs=encoder_outputs,
return_dict=True,
)

logits = return_dict["logits"]
preds = torch.argmax(logits[:, -1], dim=-1)
scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
done = preds == processor.tokenizer.eos_id
all_done = all_done | done
logits = return_dict["logits"]
preds = torch.argmax(logits[:, -1], dim=-1)
scores = torch.max(F.softmax(logits, dim=-1), dim=-1).values
done = preds == processor.tokenizer.eos_id
all_done = all_done | done

if sequence_scores is None:
sequence_scores = scores
else:
scores[all_done == 1] = 0
sequence_scores = torch.cat([sequence_scores, scores], dim=1)
scores[all_done == 1] = 0
sequence_scores = torch.cat([sequence_scores, scores], dim=1)

encoder_outputs = (return_dict["encoder_last_hidden_state"],)
past_key_values = return_dict["past_key_values"]
for layer_idx, layer in enumerate(past_key_values):
decoder_cache[layer_idx, 0, :, :, token_count:(token_count + inference_token_count), :] = layer[0]
decoder_cache[layer_idx, 1, :, :, token_count:(token_count + inference_token_count), :] = layer[1]
past_key_values = return_dict["past_key_values"]
token_range = torch.arange(token_count, token_count + inference_token_count, device=model.device)
for layer_idx, layer in enumerate(past_key_values):
if settings.RECOGNITION_STATIC_CACHE:
decoder_cache[layer_idx][:, :, :, token_range, :] = layer[0]
else:
decoder_cache[layer_idx] = torch.cat([decoder_cache[layer_idx], layer[0]], dim=3)

if is_prefill:
encoder_cache[layer_idx, 0, :, :, :, :] = layer[2]
encoder_cache[layer_idx, 1, :, :, :, :] = layer[3]
if all_done.all():
break

if all_done.all():
break
if settings.RECOGNITION_STATIC_CACHE:
kv_mask[:, :, :, token_count:(token_count + inference_token_count)] = 0
else:
kv_mask = torch.cat([kv_mask, torch.zeros((len(batch_images), 1, 1, inference_token_count), dtype=model.dtype, device=model.device)], dim=-1)

kv_mask[:, token_count:(token_count + inference_token_count)] = 1
attention_mask = torch.cat([kv_mask, ~all_done.unsqueeze(1)], dim=1)
attention_mask = kv_mask

for j, (pred, status) in enumerate(zip(preds, all_done)):
if not status:
batch_predictions[j].append(int(pred))
for j, (pred, status) in enumerate(zip(preds, all_done)):
if not status:
batch_predictions[j].append(int(pred))

batch_decoder_input = preds.unsqueeze(1)
token_count += inference_token_count
batch_decoder_input = preds.unsqueeze(1)
token_count += inference_token_count

sequence_scores = torch.sum(sequence_scores, dim=-1) / torch.sum(sequence_scores != 0, dim=-1)
detected_text = processor.tokenizer.batch_decode(batch_predictions)
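The decode loop switches between two cache strategies: a statically pre-allocated buffer that is written into at the current token positions, and a dynamically growing buffer that is extended with `torch.cat` each step. A condensed, self-contained sketch of both update paths (the sizes and the `new_kv` stand-in are illustrative, not surya's API):

```python
import torch

bsz, kv_heads, head_dim, max_tokens = 2, 4, 64, 16   # illustrative sizes
static_cache = True                                  # stands in for settings.RECOGNITION_STATIC_CACHE

if static_cache:
    # Fixed-shape buffer per layer: (2 for k/v, bsz, kv_heads, max_tokens, head_dim).
    cache = torch.zeros(2, bsz, kv_heads, max_tokens, head_dim)
else:
    cache = None                                      # grown with torch.cat as tokens arrive

token_count = 0
for step_len in (4, 1, 1):                            # e.g. a 4-token prefill, then single tokens
    new_kv = torch.randn(2, bsz, kv_heads, step_len, head_dim)  # stands in for layer[0] from the model
    if static_cache:
        token_range = torch.arange(token_count, token_count + step_len)
        cache[:, :, :, token_range, :] = new_kv       # in-place write; shape never changes
    else:
        cache = new_kv if cache is None else torch.cat([cache, new_kv], dim=3)
    token_count += step_len

print(cache.shape)  # static: (2, 2, 4, 16, 64); dynamic: (2, 2, 4, 6, 64)
```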
1 change: 1 addition & 0 deletions surya/settings.py
@@ -72,6 +72,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
RECOGNITION_FONT_DL_BASE: str = "https://github.com/satbyy/go-noto-universal/releases/download/v7.0"
RECOGNITION_BENCH_DATASET_NAME: str = "vikp/rec_bench"
RECOGNITION_PAD_VALUE: int = 255 # Should be 0 or 255
RECOGNITION_STATIC_CACHE: bool = False # Static cache for torch compile

# Layout
LAYOUT_MODEL_CHECKPOINT: str = "vikp/surya_layout2"
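The flag matters mainly for `torch.compile`: with a statically allocated cache and mask, every decode step sees the same tensor shapes, so a compiled step can be reused rather than re-traced (or handled with dynamic shapes) as the cache grows. A toy illustration of the fixed-shape pattern, not surya's actual decode step:

```python
import torch

@torch.compile  # may re-trace (or fall back to dynamic shapes) when input shapes change
def attend(q, cache, mask):
    # q: (bsz, heads, 1, dim); cache: (bsz, heads, max_tokens, dim); mask: (bsz, 1, 1, max_tokens)
    scores = q @ cache.transpose(-1, -2) + mask
    return torch.softmax(scores, dim=-1) @ cache

bsz, heads, dim, max_tokens = 1, 4, 64, 16            # illustrative sizes
q = torch.randn(bsz, heads, 1, dim)

# Static layout: the cache and additive mask keep one shape for the whole decode.
cache = torch.zeros(bsz, heads, max_tokens, dim)
mask = torch.full((bsz, 1, 1, max_tokens), torch.finfo(torch.float32).min)
for t in range(3):
    cache[:, :, t] = torch.randn(bsz, heads, dim)      # write the new KV slot in place
    mask[:, :, :, t] = 0                               # unmask it
    out = attend(q, cache, mask)                       # same shapes on every call
```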
