
Cleanup processors
VikParuchuri committed May 23, 2024
1 parent d549465 commit da6d43b
Showing 5 changed files with 53 additions and 79 deletions.
4 changes: 2 additions & 2 deletions surya/detection.py
@@ -7,7 +7,7 @@
from surya.model.detection.segformer import SegformerForRegressionMask
from surya.postprocessing.heatmap import get_and_clean_boxes
from surya.postprocessing.affinity import get_vertical_lines
-from surya.input.processing import prepare_image, split_image, get_total_splits
+from surya.input.processing import prepare_image_detection, split_image, get_total_splits
from surya.schema import TextDetectionResult
from surya.settings import settings
from tqdm import tqdm
@@ -62,7 +62,7 @@ def batch_detection(images: List, model: SegformerForRegressionMask, processor,
split_index.extend([image_idx] * len(image_parts))
split_heights.extend(split_height)

-image_splits = [prepare_image(image, processor) for image in image_splits]
+image_splits = [prepare_image_detection(image, processor) for image in image_splits]
# Batch images in dim 0
batch = torch.stack(image_splits, dim=0).to(model.dtype).to(model.device)

2 changes: 1 addition & 1 deletion surya/input/processing.py
@@ -45,7 +45,7 @@ def split_image(img, processor):
return [img.copy()], [img_height]


-def prepare_image(img, processor):
+def prepare_image_detection(img, processor):
new_size = (processor.size["width"], processor.size["height"])

img.thumbnail(new_size, Image.Resampling.LANCZOS) # Shrink largest dimension to fit new size
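
For intuition, here is a self-contained sketch of the resize behavior visible above. The helper name and the 1200x1200 target size are stand-ins for illustration, not necessarily surya's configured detection size.

from PIL import Image

def prepare_image_detection_sketch(img: Image.Image, size: dict) -> Image.Image:
    # Shrink the largest dimension so the image fits inside the processor's target size,
    # mirroring the img.thumbnail(...) call in prepare_image_detection above.
    new_size = (size["width"], size["height"])
    img = img.copy()
    img.thumbnail(new_size, Image.Resampling.LANCZOS)
    return img

page = Image.new("RGB", (2400, 3000), "white")
print(prepare_image_detection_sketch(page, {"width": 1200, "height": 1200}).size)  # (960, 1200)
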
29 changes: 10 additions & 19 deletions surya/model/ordering/processor.py
@@ -31,34 +31,25 @@ def __init__(self, *args, **kwargs):

self.patch_size = kwargs.get("patch_size", (4, 4))

-def process_inner(self, images: List[List]):
-# This will be in list of lists format, with height x width x channel
-assert isinstance(images[0], (list, np.ndarray))
+def process_inner(self, images: List[np.ndarray]):
+images = [img.transpose(2, 0, 1) for img in images] # convert to CHW format

-# convert list of lists format to array
-if isinstance(images[0], list):
-# numpy unit8 needed for augmentation
-np_images = [np.array(img, dtype=np.uint8) for img in images]
-else:
-np_images = [img.astype(np.uint8) for img in images]
-np_images = [img.transpose(2, 0, 1) for img in np_images] # convert to CHW format

-assert np_images[0].shape[0] == 3 # RGB input images, channel dim last
+assert images[0].shape[0] == 3 # RGB input images, channel dim last

# Convert to float32 for rescale/normalize
-np_images = [img.astype(np.float32) for img in np_images]
+images = [img.astype(np.float32) for img in images]

# Rescale and normalize
-np_images = [
+images = [
self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
-for img in np_images
+for img in images
]
-np_images = [
+images = [
self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
-for img in np_images
+for img in images
]

-return np_images
+return images

def process_boxes(self, boxes):
padded_boxes = []
@@ -152,7 +143,7 @@ def preprocess(
boxes = new_boxes

# Convert to numpy for later processing steps
-images = [to_numpy_array(image) for image in images]
+images = [np.array(image) for image in images]

images = self.process_inner(images)
boxes, box_mask, box_counts = self.process_boxes(boxes)
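
To make the simplified flow concrete, here is a standalone numpy sketch of what process_inner now does end to end. The rescale factor, mean, and std are placeholder values for illustration, not the processor's actual configuration.

import numpy as np

def process_inner_sketch(images, rescale_factor=1 / 255.0,
                         image_mean=(0.5, 0.5, 0.5), image_std=(0.5, 0.5, 0.5)):
    # Inputs are HWC uint8 arrays; move channels first (CHW) as the model expects.
    images = [img.transpose(2, 0, 1) for img in images]
    assert images[0].shape[0] == 3  # RGB, channel dim first

    mean = np.array(image_mean, dtype=np.float32).reshape(3, 1, 1)
    std = np.array(image_std, dtype=np.float32).reshape(3, 1, 1)

    out = []
    for img in images:
        img = img.astype(np.float32) * rescale_factor  # rescale to [0, 1]
        img = (img - mean) / std                       # per-channel normalize
        out.append(img)
    return out

pixel_values = process_inner_sketch([np.zeros((64, 64, 3), dtype=np.uint8)])
print(pixel_values[0].shape)  # (3, 64, 64)
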
80 changes: 27 additions & 53 deletions surya/model/recognition/processor.py
@@ -1,5 +1,6 @@
from typing import Dict, Union, Optional, List, Tuple

+import cv2
from torch import TensorType
from transformers import DonutImageProcessor, DonutProcessor, AutoImageProcessor, DonutSwinConfig
from transformers.image_processing_utils import BaseImageProcessor, get_size_dict, BatchFeature
@@ -29,84 +30,64 @@ def __init__(self, *args, max_size=None, train=False, **kwargs):
self.max_size = max_size
self.train = train

-def numpy_resize(self, image: np.ndarray, size, resample):
-image = PIL.Image.fromarray(image)
-resized = self.pil_resize(image, size, resample)
-resized = np.array(resized, dtype=np.uint8)
-resized_image = resized.transpose(2, 0, 1)

-return resized_image

-def pil_resize(self, image: PIL.Image.Image, size, resample):
-width, height = image.size
+def numpy_resize(self, image: np.ndarray, size, interpolation=cv2.INTER_LANCZOS4):
+height, width = image.shape[:2]
max_width, max_height = size["width"], size["height"]
-if width != max_width or height != max_height:
-# Shrink to fit within dimensions
-width_scale = max_width / width
-height_scale = max_height / height
-scale = min(width_scale, height_scale)

-new_width = min(int(width * scale), max_width)
-new_height = min(int(height * scale), max_height)
-image = image.resize((new_width, new_height), resample)
+if (height == max_height and width <= max_width) or (width == max_width and height <= max_height):
+return image

-image.thumbnail((max_width, max_height), resample)
+scale = min(max_width / width, max_height / height)

-assert image.width <= max_width and image.height <= max_height

-return image
+new_width = int(width * scale)
+new_height = int(height * scale)

-def process_inner(self, images: List[List], train=False):
-# This will be in list of lists format, with height x width x channel
-assert isinstance(images[0], (list, np.ndarray))
+resized_image = cv2.resize(image, (new_width, new_height), interpolation=interpolation)
+resized_image = resized_image.transpose(2, 0, 1)

-# convert list of lists format to array
-if isinstance(images[0], list):
-# numpy unit8 needed for augmentation
-np_images = [np.array(img, dtype=np.uint8) for img in images]
-else:
-np_images = [img.astype(np.uint8) for img in images]
+return resized_image

-assert np_images[0].shape[2] == 3 # RGB input images, channel dim last
+def process_inner(self, images: List[np.ndarray], train=False):
+assert images[0].shape[2] == 3 # RGB input images, channel dim last

# Rotate if the bbox is wider than it is tall
-np_images = [self.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in np_images]
+images = [self.align_long_axis(image, size=self.max_size, input_data_format=ChannelDimension.LAST) for image in images]

# Verify that the image is wider than it is tall
-for img in np_images:
+for img in images:
assert img.shape[1] >= img.shape[0]

# This also applies the right channel dim format, to channel x height x width
-np_images = [self.numpy_resize(img, self.max_size, self.resample) for img in np_images]
-assert np_images[0].shape[0] == 3 # RGB input images, channel dim first
+images = [self.numpy_resize(img, self.max_size, self.resample) for img in images]
+assert images[0].shape[0] == 3 # RGB input images, channel dim first

# Convert to float32 for rescale/normalize
-np_images = [img.astype(np.float32) for img in np_images]
+images = [img.astype(np.float32) for img in images]

# Pads with 255 (whitespace)
# Pad to max size to improve performance
max_size = self.max_size
-np_images = [
+images = [
self.pad_image(
image=image,
size=max_size,
random_padding=train, # Change amount of padding randomly during training
input_data_format=ChannelDimension.FIRST,
-pad_value=255.0
+pad_value=settings.RECOGNITION_PAD_VALUE
)
-for image in np_images
+for image in images
]
# Rescale and normalize
-np_images = [
+images = [
self.rescale(img, scale=self.rescale_factor, input_data_format=ChannelDimension.FIRST)
-for img in np_images
+for img in images
]
-np_images = [
+images = [
self.normalize(img, mean=self.image_mean, std=self.image_std, input_data_format=ChannelDimension.FIRST)
-for img in np_images
+for img in images
]

-return np_images
+return images


def preprocess(
@@ -131,15 +112,8 @@ def preprocess(
) -> PIL.Image.Image:
images = make_list_of_images(images)

-if not valid_images(images):
-raise ValueError(
-"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
-"torch.Tensor, tf.Tensor or jax.ndarray."
-)

# Convert to numpy for later processing steps
-images = [to_numpy_array(image) for image in images]

+images = [np.array(img) for img in images]
images = self.process_inner(images, train=self.train)
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
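
As a rough standalone illustration of the new cv2-based resize path: the early-return fast path is omitted here, and the 800x200 target box is just an example, not the processor's actual max_size.

import cv2
import numpy as np

def resize_to_fit_chw(image: np.ndarray, max_width: int, max_height: int) -> np.ndarray:
    # Scale an HWC uint8 image to fit inside (max_width, max_height), then return it channel-first.
    height, width = image.shape[:2]
    scale = min(max_width / width, max_height / height)
    new_width, new_height = int(width * scale), int(height * scale)

    # cv2.resize takes (width, height); INTER_LANCZOS4 roughly matches the PIL LANCZOS resampling used before.
    resized = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
    return resized.transpose(2, 0, 1)  # CHW, as the downstream rescale/normalize steps expect

crop = np.zeros((400, 1600, 3), dtype=np.uint8)  # a dummy 1600x400 (width x height) text-line crop
print(resize_to_fit_chw(crop, max_width=800, max_height=200).shape)  # (3, 200, 800)
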
17 changes: 13 additions & 4 deletions surya/recognition.py
@@ -1,6 +1,7 @@
from typing import List
import torch
from PIL import Image
+from transformers import GenerationConfig

from surya.postprocessing.math.latex import fix_math, contains_math
from surya.postprocessing.text import truncate_repetitions
@@ -25,6 +26,15 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
assert all([isinstance(image, Image.Image) for image in images])
assert len(images) == len(languages)

+generation_config = GenerationConfig(
+max_new_tokens=settings.RECOGNITION_MAX_TOKENS,
+eos_token_id=processor.tokenizer.eos_id,
+output_scores=True,
+return_dict_in_generate=True,
+bos_token_id=processor.tokenizer.eos_id,
+pad_token_id=processor.tokenizer.pad_id,
+)

if batch_size is None:
batch_size = get_batch_size()

@@ -45,16 +55,15 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
batch_langs = torch.from_numpy(np.array(batch_langs, dtype=np.int64)).to(model.device)
batch_pixel_values = torch.tensor(np.array(batch_pixel_values), dtype=model.dtype).to(model.device)
batch_decoder_input = torch.from_numpy(np.array(batch_decoder_input, dtype=np.int64)).to(model.device)
+decoder_attention_mask = torch.ones_like(batch_decoder_input, device=model.device)

with torch.inference_mode():
return_dict = model.generate(
pixel_values=batch_pixel_values,
decoder_input_ids=batch_decoder_input,
+decoder_attention_mask=decoder_attention_mask,
decoder_langs=batch_langs,
-eos_token_id=processor.tokenizer.eos_id,
-max_new_tokens=settings.RECOGNITION_MAX_TOKENS,
-output_scores=True,
-return_dict_in_generate=True
+generation_config=generation_config,
)
generated_ids = return_dict["sequences"]

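
For reference, this is the standard transformers pattern the call switches to: a single GenerationConfig built once carries the token ids, limits, and output flags instead of passing them to every generate() call. The sketch below uses a small public checkpoint rather than surya's recognition model.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Generic illustration only; "gpt2" is a public checkpoint, not surya's model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

generation_config = GenerationConfig(
    max_new_tokens=8,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
    output_scores=True,
    return_dict_in_generate=True,
)

inputs = tokenizer("hello world", return_tensors="pt")
with torch.inference_mode():
    out = model.generate(**inputs, generation_config=generation_config)  # config is reused across batches
print(out.sequences.shape)
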
