From 28e95a20a67e9d3209518dd3bd643cfc69b53e56 Mon Sep 17 00:00:00 2001 From: Vik Paruchuri Date: Thu, 23 May 2024 16:09:23 -0700 Subject: [PATCH] Fix colors --- surya/input/processing.py | 5 +++-- surya/model/recognition/processor.py | 2 +- surya/ocr.py | 1 + surya/recognition.py | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/surya/input/processing.py b/surya/input/processing.py index 17ce4ab..ead35cc 100644 --- a/surya/input/processing.py +++ b/surya/input/processing.py @@ -81,7 +81,7 @@ def slice_bboxes_from_image(image: Image.Image, bboxes): def slice_polys_from_image(image: Image.Image, polys): - image_array = np.array(image) + image_array = np.array(image, dtype=np.uint8) lines = [] for idx, poly in enumerate(polys): lines.append(slice_and_pad_poly(image_array, poly)) @@ -98,8 +98,9 @@ def slice_and_pad_poly(image_array: np.array, coordinates): coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates] # Pad the area outside the polygon with the pad value - mask = np.zeros_like(cropped_polygon, dtype=np.uint8) + mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8) cv2.fillPoly(mask, [np.int32(coordinates)], 1) + mask = np.stack([mask] * 3, axis=-1) cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE rectangle_image = Image.fromarray(cropped_polygon) diff --git a/surya/model/recognition/processor.py b/surya/model/recognition/processor.py index 645197a..f85f3b1 100644 --- a/surya/model/recognition/processor.py +++ b/surya/model/recognition/processor.py @@ -155,7 +155,7 @@ def align_long_axis( data_format: Optional[Union[str, ChannelDimension]] = None, input_data_format: Optional[Union[str, ChannelDimension]] = None, ) -> np.ndarray: - input_height, input_width = get_image_size(image, channel_dim=input_data_format) + input_height, input_width = image.shape[:2] output_height, output_width = size["height"], size["width"] if (output_width < output_height and input_width > input_height) or ( diff --git a/surya/ocr.py b/surya/ocr.py index 0847762..17d8fb6 100644 --- a/surya/ocr.py +++ b/surya/ocr.py @@ -62,6 +62,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr all_slices = [] slice_map = [] all_langs = [] + for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)): polygons = [p.polygon for p in det_pred.bboxes] slices = slice_polys_from_image(image, polygons) diff --git a/surya/recognition.py b/surya/recognition.py index 6122cb0..d28cc46 100644 --- a/surya/recognition.py +++ b/surya/recognition.py @@ -37,6 +37,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor has_math = ["_math" in lang for lang in batch_langs] batch_images = images[i:i+batch_size] batch_images = [image.convert("RGB") for image in batch_images] + model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs) batch_pixel_values = model_inputs["pixel_values"]