Improve performance
VikParuchuri committed May 27, 2024
1 parent df534cf commit e5c9b44
Showing 3 changed files with 108 additions and 107 deletions.
42 changes: 28 additions & 14 deletions surya/model/recognition/decoder.py
@@ -4,15 +4,35 @@
 from transformers import MBartForCausalLM, MBartConfig
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_attention_mask
 from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder, \
-    MBartLearnedPositionalEmbedding
+from transformers.models.mbart.modeling_mbart import MBartPreTrainedModel, MBartDecoder
 from surya.model.recognition.config import MBartMoEConfig
 import torch
 import math


+class MBartLearnedPositionalEmbedding(nn.Embedding):
+    """
+    This module learns positional embeddings up to a fixed maximum size.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # MBart is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids` shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
+        positions = torch.arange(
+            past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
+        ).expand(bsz, -1)
+
+        return super().forward(positions + self.offset)
+
+
 class MBartExpertMLP(nn.Module):
     def __init__(self, config: MBartConfig, is_lg=False, is_xl=False):
         super().__init__()
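Note: MBartLearnedPositionalEmbedding is now defined locally rather than imported from transformers. As a rough sketch of its behavior (toy sizes chosen for illustration, assuming the class definition above), the offset of 2 shifts every position lookup past the two rows MBart reserves:

    import torch

    emb = MBartLearnedPositionalEmbedding(num_embeddings=512, embedding_dim=8)
    ids = torch.zeros(2, 5, dtype=torch.long)          # [bsz, seqlen]; only the shape is used
    step0 = emb(ids)                                   # positions 0..4 -> weight rows 2..6
    step5 = emb(ids[:, :1], past_key_values_length=5)  # next decode step -> weight row 7
    print(step0.shape, step5.shape)                    # torch.Size([2, 5, 8]) torch.Size([2, 1, 8])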
@@ -32,7 +52,6 @@ def __init__(self, config: MBartConfig, is_lg=False, is_xl=False):

     def forward(self, hidden_states):
         current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
-        current_hidden_states = self.dropout(current_hidden_states)
         current_hidden_states = self.w2(current_hidden_states)
         return current_hidden_states

@@ -65,8 +84,11 @@ def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch.Tensor:
         # Set weights to 1 if zero experts activated
         routing_weights[torch.isinf(routing_weights)] = 1

+        unique_langs = langs.view(-1).unique(sorted=True).tolist()
+        unique_langs = [l for l in unique_langs if l in self.lang_codes]
+
         # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx, expert_lang in enumerate(self.lang_codes):
+        for expert_lang in unique_langs:
             # Check which samples match with this expert
             lang_match = (langs == expert_lang).any(dim=-1)
             idx = torch.nonzero(lang_match, as_tuple=True)[0]
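This is the main change in the MoE block: instead of scanning every expert on every call, the routing loop now only visits languages that actually occur in the batch. A minimal standalone sketch of the idea, with hypothetical toy language ids:

    import torch

    lang_codes = [65, 66, 67, 68, 69, 70, 71, 72]   # 8 experts (hypothetical ids)
    langs = torch.tensor([[67], [67], [70]])        # per-sample language tokens

    unique_langs = langs.view(-1).unique(sorted=True).tolist()
    unique_langs = [l for l in unique_langs if l in lang_codes]
    print(unique_langs)  # [67, 70] -> 2 expert passes instead of looping over all 8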
@@ -78,7 +100,6 @@ def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch.Tensor:

             current_state = hidden_states[idx]
             current_hidden_states = expert_layer(current_state.view(-1, hidden_dim))
-            current_hidden_states = self.dropout(current_hidden_states)
             current_hidden_states = current_hidden_states.view(-1, sequence_length, hidden_dim)

             # Weight by number of languages in the input
@@ -190,8 +211,7 @@ def forward(

         attn_weights = nn.functional.softmax(attn_weights, dim=-1)

-        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-        attn_output = torch.bmm(attn_probs, value_states).view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1,2)
+        attn_output = torch.bmm(attn_weights, value_states).view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1,2)

         # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
         # partitioned across GPUs when using tensor-parallelism.
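The dropout removals throughout this file all rest on the same observation: with training=False, nn.functional.dropout is an identity op, so deleting the call changes nothing numerically for inference and skips an op per layer. A quick check of that assumption:

    import torch
    from torch import nn

    x = torch.randn(4, 4)
    # With training=False, dropout returns its input unchanged.
    assert torch.equal(nn.functional.dropout(x, p=0.1, training=False), x)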
@@ -261,7 +281,6 @@ def forward(
             is_prefill=is_prefill,
             attention_mask=attention_mask,
         )
-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
         hidden_states = residual + hidden_states

         # Cross-Attention Block
@@ -277,7 +296,6 @@ def forward(
                 attention_mask=encoder_attention_mask,
                 past_key_value=cross_kv_cache,
             )
-            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
             hidden_states = residual + hidden_states

             # add cross-attn to positions 3,4 of present_key_value tuple
@@ -290,9 +308,7 @@
             hidden_states = self.moe(hidden_states, langs)
         else:
             hidden_states = self.activation_fn(self.fc1(hidden_states))
-            hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
             hidden_states = self.fc2(hidden_states)
-            hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

         hidden_states = residual + hidden_states

@@ -358,8 +374,6 @@ def forward(
         hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
         hidden_states = self.layernorm_embedding(hidden_states)

-        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-
         # decoder layers
         all_hidden_states = None
         all_self_attns = None
36 changes: 11 additions & 25 deletions surya/model/recognition/processor.py
@@ -1,12 +1,11 @@
-from typing import Dict, Union, Optional, List, Tuple
+from typing import Dict, Union, Optional, List

 import cv2
 from torch import TensorType
-from transformers import DonutImageProcessor, DonutProcessor, AutoImageProcessor, DonutSwinConfig
-from transformers.image_processing_utils import BaseImageProcessor, get_size_dict, BatchFeature
-from transformers.image_transforms import to_channel_dimension_format, pad, _rescale_for_pil_conversion, to_pil_image
-from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, \
-    valid_images, to_numpy_array, is_scaled_image, infer_channel_dimension_format, get_image_size
+from transformers import DonutImageProcessor, DonutProcessor
+from transformers.image_processing_utils import BatchFeature
+from transformers.image_transforms import pad
+from transformers.image_utils import PILImageResampling, ImageInput, ChannelDimension, make_list_of_images, get_image_size
 import numpy as np
 from PIL import Image
 import PIL
@@ -47,7 +46,7 @@ def numpy_resize(self, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):

         return resized_image

-    def process_inner(self, images: List[np.ndarray], train=False):
+    def process_inner(self, images: List[np.ndarray]):
         assert images[0].shape[2] == 3  # RGB input images, channel dim last

         # Rotate if the bbox is wider than it is tall
@@ -71,25 +70,21 @@ def process_inner(self, images: List[np.ndarray], train=False):
             self.pad_image(
                 image=image,
                 size=max_size,
-                random_padding=train,  # Change amount of padding randomly during training
                 input_data_format=ChannelDimension.FIRST,
                 pad_value=settings.RECOGNITION_PAD_VALUE
             )
             for image in images
         ]
         # Rescale and normalize
-        images = [
-            self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
-            for img in images
-        ]
+        for idx in range(len(images)):
+            images[idx] = images[idx] * self.rescale_factor
         images = [
             self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
             for img in images
         ]

         return images

-
     def preprocess(
         self,
         images: ImageInput,
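The rescale step in the hunk above also changed: rather than building a new list through self.rescale, pixel values are multiplied by rescale_factor directly. A toy check that the two forms agree, assuming the usual 1/255 factor:

    import numpy as np

    imgs = [np.random.rand(3, 8, 8).astype(np.float32) * 255 for _ in range(2)]
    rescale_factor = 1 / 255  # assumed value; matches Donut-style scaling
    old = [img * rescale_factor for img in imgs]  # what the self.rescale call computed
    for idx in range(len(imgs)):
        imgs[idx] = imgs[idx] * rescale_factor    # the new direct multiply
    assert all(np.allclose(a, b) for a, b in zip(old, imgs))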
@@ -114,15 +109,14 @@ def preprocess(

         # Convert to numpy for later processing steps
         images = [np.array(img) for img in images]
-        images = self.process_inner(images, train=self.train)
+        images = self.process_inner(images)
         data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)

     def pad_image(
         self,
         image: np.ndarray,
         size: Dict[str, int],
-        random_padding: bool = False,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         pad_value: float = 0.0,
@@ -135,12 +129,8 @@ def pad_image(

         assert delta_width >= 0 and delta_height >= 0

-        if random_padding:
-            pad_top = np.random.randint(low=0, high=delta_height + 1)
-            pad_left = np.random.randint(low=0, high=delta_width + 1)
-        else:
-            pad_top = delta_height // 2
-            pad_left = delta_width // 2
+        pad_top = delta_height // 2
+        pad_left = delta_width // 2

         pad_bottom = delta_height - pad_top
         pad_right = delta_width - pad_left
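With random padding gone, pad_image always centers the image, with any odd remainder going to the bottom/right edge. A small sketch of the arithmetic (hypothetical helper, not part of the codebase):

    def centered_padding(height, width, target_h, target_w):
        # Mirrors the fixed padding now used in pad_image: split the slack
        # evenly, with the odd pixel landing on the bottom/right.
        delta_h, delta_w = target_h - height, target_w - width
        assert delta_h >= 0 and delta_w >= 0
        pad_top, pad_left = delta_h // 2, delta_w // 2
        return (pad_top, delta_h - pad_top), (pad_left, delta_w - pad_left)

    print(centered_padding(30, 50, 32, 53))  # ((1, 1), (1, 2))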
@@ -180,10 +170,6 @@ def __init__(self, image_processor=None, tokenizer=None, train=False, **kwargs):
         self._in_target_context_manager = False

     def __call__(self, *args, **kwargs):
-        # For backward compatibility
-        if self._in_target_context_manager:
-            return self.current_processor(*args, **kwargs)
-
         images = kwargs.pop("images", None)
         text = kwargs.pop("text", None)
         lang = kwargs.pop("lang", None)