Skip to content

Commit

Permalink
Fix colors
Browse files (browse the repository at this point in the history)
  • Loading branch information
VikParuchuri committed May 23, 2024
1 parent 06a9a8b commit 28e95a2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 3 deletions.
5 changes: 3 additions & 2 deletions surya/input/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def slice_bboxes_from_image(image: Image.Image, bboxes):


def slice_polys_from_image(image: Image.Image, polys):
- image_array = np.array(image)
+ image_array = np.array(image, dtype=np.uint8)
lines = []
for idx, poly in enumerate(polys):
lines.append(slice_and_pad_poly(image_array, poly))
Expand All @@ -98,8 +98,9 @@ def slice_and_pad_poly(image_array: np.array, coordinates):
coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]

# Pad the area outside the polygon with the pad value
- mask = np.zeros_like(cropped_polygon, dtype=np.uint8)
+ mask = np.zeros(cropped_polygon.shape[:2], dtype=np.uint8)
cv2.fillPoly(mask, [np.int32(coordinates)], 1)
+ mask = np.stack([mask] * 3, axis=-1)

cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE
rectangle_image = Image.fromarray(cropped_polygon)
Expand Down
2 changes: 1 addition & 1 deletion surya/model/recognition/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def align_long_axis(
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
) -> np.ndarray:
- input_height, input_width = get_image_size(image, channel_dim=input_data_format)
+ input_height, input_width = image.shape[:2]
output_height, output_width = size["height"], size["width"]

if (output_width < output_height and input_width > input_height) or (
Expand Down
1 change: 1 addition & 0 deletions surya/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_pr
all_slices = []
slice_map = []
all_langs = []
+
for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
Expand Down
1 change: 1 addition & 0 deletions surya/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
has_math = ["_math" in lang for lang in batch_langs]
batch_images = images[i:i+batch_size]
batch_images = [image.convert("RGB") for image in batch_images]
+
model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)

batch_pixel_values = model_inputs["pixel_values"]
Expand Down

0 comments on commit 28e95a2

Please sign in to comment.