In [3]:
import logging
import os
import re
from pathlib import Path
from typing import List, Tuple, Any

import cv2
import matplotlib.pyplot as plt
from numpy import ndarray
from PIL import Image
from surya.detection import DetectionPredictor
from surya.recognition import OCRResult, RecognitionPredictor

class SuryaOCR:
    """Convenience wrapper around Surya's detection and recognition predictors.

    The class bundles three kinds of helpers:

    * running OCR on an image and reading the first detected text line
      (text, bounding box, confidence);
    * cropping / displaying regions of an image with OpenCV and matplotlib;
    * batch helpers that crop every region proposed by an external detector,
      OCR each crop, and collect the recognised text.
    """

    # Scratch file used to OCR an in-memory crop (written to the current
    # working directory and removed again afterwards).
    _TMP_CROP_NAME = "tmp_cropped_image.jpg"

    def __init__(self)-> None:
        """Instantiate both Surya predictors and set the OCR languages."""
        # NOTE(review): surya is already imported when this line runs. If surya
        # reads its settings from the environment at import time, this
        # assignment has no effect — confirm, and if so set the variable
        # before the surya imports instead.
        os.environ["RECOGNITION_BATCH_SIZE"] = "512"

        self.detection_predictor = DetectionPredictor()
        self.recognition_predictor = RecognitionPredictor()
        self.langs = ['es', 'en']

    @staticmethod
    def _read_image(image_path: str) -> ndarray:
        """Read an image with OpenCV, raising instead of returning ``None``.

        Raises:
            ValueError: if OpenCV could not read the file (missing, unreadable
                or not an image).
        """
        image = cv2.imread(str(image_path))
        if image is None:
            raise ValueError(f'Could not read image: {image_path}')
        return image

    def get_predictions(self, image: Image.Image)-> List[OCRResult]:
        """Run the recognition predictor (with detection) on a single image."""
        return self.recognition_predictor(images=[image], langs=[self.langs], det_predictor=self.detection_predictor)

    def get_text_from_predictions(self, predictions: List[OCRResult]) -> str:
        """Return the text of the first text line of the first result.

        Returns the sentinel string ``'No text detected.'`` when the first
        result has no text lines.

        Raises:
            ValueError: if ``predictions`` is empty.
        """
        if not predictions:
            raise ValueError('Predictions are required')
        if not predictions[0].text_lines:
            return 'No text detected.'
        return predictions[0].text_lines[0].text

    def get_bouding_boxes_from_predictions(self, predictions: List[OCRResult])-> List[float]:
        """Return the bbox of the first text line of the first result.

        Returns ``[0, 0, 0, 0]`` when the first result has no text lines.

        Raises:
            ValueError: if ``predictions`` is empty.
        """
        if not predictions:
            raise ValueError('Predictions are required')
        if not predictions[0].text_lines:
            return [0, 0, 0, 0]
        return predictions[0].text_lines[0].bbox

    def get_confidence_from_predictions(self, predictions: List[OCRResult])-> float:
        """Return the confidence of the first text line of the first result.

        Returns ``0.0`` when there are no text lines or the confidence is
        ``None``.

        Raises:
            ValueError: if ``predictions`` is empty.
        """
        if not predictions:
            raise ValueError('Predictions are required')
        if not predictions[0].text_lines:
            return 0.0
        confidence: float | None = predictions[0].text_lines[0].confidence
        return confidence if confidence is not None else 0.0

    def show_image_with_bounding_boxes(self, image_path: str, bounding_boxes: List[float], text: str = "Image with boundings")-> None:
        """Display the image with one rectangle drawn on it.

        Args:
            image_path: path of the image to load.
            bounding_boxes: a single box as ``[x1, y1, x2, y2]``.
            text: title shown above the figure.

        Raises:
            ValueError: if an argument is empty or the image cannot be read.
        """
        if not image_path:
            raise ValueError('Image is required')
        if not bounding_boxes:
            raise ValueError('Bounding boxes are required')

        image = self._read_image(image_path)

        x1, y1, x2, y2 = map(int, bounding_boxes)

        cv2.rectangle(img=image, pt1=(x1, y1), pt2=(x2, y2), color=(0, 255, 0), thickness=2)

        # OpenCV loads BGR; matplotlib expects RGB.
        plt.imshow(X=cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.title(label=text)
        plt.show()

    def cropped_image(self, image_path: str, bounding_boxes: List[float], showImage: bool) -> ndarray:
        """Return the region ``[x1, y1, x2, y2]`` cut out of the image.

        Args:
            image_path: path of the image to load.
            bounding_boxes: a single box as ``[x1, y1, x2, y2]``.
            showImage: when true, also display the full image with the box
                drawn on it (titled ``'Cropped Image'``).

        Raises:
            ValueError: if an argument is empty or the image cannot be read.
        """
        if not image_path:
            raise ValueError('Image is required')
        if not bounding_boxes:
            raise ValueError('Bounding boxes are required')

        image = self._read_image(image_path)

        x1, y1, x2, y2 = map(int, bounding_boxes)
        # Clamp to the image: negative indices would otherwise wrap around in
        # numpy slicing and silently produce the wrong region.
        height, width = image.shape[:2]
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        cropped_image = image[y1:y2, x1:x2]

        if showImage:
            self.show_image_with_bounding_boxes(image_path=image_path, bounding_boxes=bounding_boxes, text='Cropped Image')

        return cropped_image

    def save_cropped_image(self, image: ndarray, output_dir: Path, name: str)-> None:
        """Write ``image`` to ``output_dir / name``.

        The output directory is created when missing. The process working
        directory is left untouched, so relative paths used by the caller stay
        valid across calls.

        Raises:
            ValueError: if the image is missing/empty or ``output_dir`` is empty.
            OSError: if OpenCV reports that the file could not be written.
        """
        if image is None or image.size == 0:
            raise ValueError('Image is required')
        if not output_dir:
            raise ValueError('Output directory is required')

        target_dir = Path(output_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        target = target_dir / name
        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(target), image):
            raise OSError(f'Could not write image: {target}')

    def predict(self, image_path: str, show_result: bool = False)-> Tuple[str, List[float], float]:
        """OCR an image file and return ``(text, bbox, confidence)``.

        All three values describe the first detected text line only.

        Args:
            image_path: path of the image to OCR.
            show_result: when true, display the image with the box drawn and
                the recognised text as title.

        Raises:
            ValueError: if ``image_path`` is empty.
        """
        if not image_path:
            raise ValueError('Image is required')

        # Close the file handle as soon as recognition is done.
        with Image.open(fp=image_path) as image:
            predictions: List[OCRResult] = self.get_predictions(image=image)

        extracted_text: str = self.get_text_from_predictions(predictions=predictions)
        bounding_boxes: List[float] = self.get_bouding_boxes_from_predictions(predictions=predictions)
        confidence: float = self.get_confidence_from_predictions(predictions=predictions)

        if show_result:
            self.show_image_with_bounding_boxes(image_path=image_path, bounding_boxes=bounding_boxes, text=extracted_text)

        return extracted_text, bounding_boxes, confidence

    def cropped_image_from_predictions(self, image_path: str, predictions: dict, show_image: bool = False, show_text: bool = False) -> ndarray:
        """Crop the centre-based box described by ``predictions``.

        Args:
            image_path: path of the image to load.
            predictions: mapping with the keys ``"x"``, ``"y"`` (box centre)
                and ``"width"``, ``"height"``.
            show_image: when true, display a grayscale version of the crop.
            show_text: when true, OCR the grayscale crop; the recognised text
                is used as the figure title if ``show_image`` is also true.

        Returns:
            The crop taken from the original (colour) image. The grayscale
            copy is only used for OCR and display.

        Raises:
            ValueError: if an argument is empty, the image cannot be read, or
                the box does not intersect the image.
        """
        if not image_path:
            raise ValueError('Image is required')
        if not predictions:
            raise ValueError('Predictions are required')

        image: ndarray = self._read_image(image_path)

        x_center: Any = predictions["x"]
        y_center: Any = predictions["y"]
        w: Any = predictions["width"]
        h: Any = predictions["height"]

        # Convert centre/size to corner coordinates, clamped to the image.
        x1: int = max(0, int(x_center - w/2))
        y1: int = max(0, int(y_center - h/2))
        x2: int = min(image.shape[1], int(x_center + w/2))
        y2: int = min(image.shape[0], int(y_center + h/2))

        if x2 <= x1 or y2 <= y1:
            raise ValueError('Bounding box does not intersect the image')

        cropped_image: ndarray = image[y1:y2, x1:x2]
        # Drop colour information but keep three channels.
        gray_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY)
        channels_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2BGR)

        title: str = "Cropped Image"
        if show_text:
            tmp_path = Path(self._TMP_CROP_NAME)
            try:
                if not cv2.imwrite(str(tmp_path), channels_image):
                    raise OSError(f'Could not write image: {tmp_path}')
                title, _, _ = self.predict(image_path=str(tmp_path))
            finally:
                # Always remove the scratch file, even when OCR fails.
                tmp_path.unlink(missing_ok=True)

        if show_image:
            plt.imshow(X=channels_image)
            plt.title(label=title)
            plt.axis("off")
            plt.show()

        return cropped_image

    def get_image_name_from_path(self, image_path: str)-> str:
        """Return the file name of ``image_path`` without its extension.

        Raises:
            ValueError: if ``image_path`` is empty.
        """
        if not image_path:
            raise ValueError('Image is required')

        return Path(image_path).stem

    def predict_and_save_cropped_image(self, image_path: str, predictions: list[Any], output_dir: Path)-> list[Any]:
        """Crop, save and OCR every predicted region of one image.

        For each entry of ``predictions`` the region is cropped (and displayed
        with its recognised text), saved as ``<image stem>_<index>.jpg`` in
        ``output_dir``, OCR'd from the saved file, and the text is stored under
        the entry's ``"text"`` key. Entries that fail are logged and left
        without a ``"text"`` key.

        Returns:
            The same ``predictions`` list, mutated in place.

        Raises:
            ValueError: if any argument is empty.
        """
        if not image_path:
            raise ValueError('Image is required')
        if not predictions:
            raise ValueError('Predictions are required')
        if not output_dir:
            raise ValueError('Output directory is required')

        image_name: str = self.get_image_name_from_path(image_path=image_path)

        for index, prediction in enumerate(predictions):
            try:
                crop_name = f"{image_name}_{index}.jpg"
                cropped_image: ndarray = self.cropped_image_from_predictions(image_path=image_path, predictions=prediction, show_image=True, show_text=True)
                self.save_cropped_image(image=cropped_image, output_dir=output_dir, name=crop_name)
                new_image_path = str(Path(output_dir) / crop_name)
                text, _, _ = self.predict(image_path=new_image_path)
                prediction["text"] = text
            except Exception as e:
                # Best effort: one bad region must not abort the whole image.
                logging.error(msg=f"There was an error adding this prediction: {e}")
                continue

        return predictions

    def natural_key(self, filename: str) -> List[Any]:
        """Sort key that orders embedded numbers numerically (``a2`` < ``a10``).

        The name is split into alternating text / digit-run parts; text parts
        are lower-cased and digit runs become ``int``.
        """
        parts = re.split(r'(\d+)', filename)
        # With one capture group, re.split puts the captured digit runs at the
        # odd positions; only those are converted to int.
        return [int(part) if position % 2 else part.lower() for position, part in enumerate(parts)]

    def extract_text_and_y(self, predictions: list[Any]) -> list[dict]:
        """Keep only ``{"text", "y"}`` for entries with non-blank text.

        Entries lacking a ``"text"`` or ``"y"`` key, or whose text is not a
        string or is blank after stripping, are skipped.
        """
        result = []
        for prediction in predictions:
            if "text" not in prediction or "y" not in prediction:
                continue
            raw_text = prediction["text"]
            if not isinstance(raw_text, str):
                continue
            text = raw_text.strip()
            if text:  # Included only when the text is not empty
                result.append({"text": text, "y": prediction["y"]})
        return result

    def predict_and_save_from_dir(self, input_dir: Path, output_dir: Path, yolov12: Yolo) -> list[dict]:
        """Process every ``*.png`` in ``input_dir`` in natural file-name order.

        For each image, ``yolov12.predict_code`` supplies the regions
        (presumably a list of dicts with ``x``/``y``/``width``/``height`` —
        verify against the detector), which are then cropped, saved to
        ``output_dir`` and OCR'd.

        Returns:
            One ``{"file_name", "predictions"}`` dict per image, where
            ``predictions`` holds the ``{"text", "y"}`` entries with text.

        Raises:
            ValueError: if ``input_dir`` or ``output_dir`` is empty, or the
                detector returns no predictions for an image.
        """
        if not input_dir:
            raise ValueError('Input directory is required')
        if not output_dir:
            raise ValueError('Output directory is required')

        results = []

        # Accept plain strings as well as Path objects.
        image_paths = sorted(Path(input_dir).glob("*.png"), key=lambda p: self.natural_key(p.name))
        for image_path in image_paths:
            logging.info(msg=f"{image_path}")
            print(f"{image_path}")

            predictions: list[Any] = yolov12.predict_code(image=image_path)
            cropped_predictions = self.predict_and_save_cropped_image(
                image_path=str(image_path),
                predictions=predictions,
                output_dir=output_dir
            )

            filtered_predictions = self.extract_text_and_y(predictions=cropped_predictions)

            results.append({
                "file_name": str(image_path),
                "predictions": filtered_predictions
            })

        return results

In [4]:
# Build the OCR wrapper once; its constructor instantiates the Surya detection
# and recognition predictors (the model-loading messages appear in the cell output).
suryaOCR: SuryaOCR = SuryaOCR()

Loaded detection model s3://text_detection/2025_02_28 on device cpu with dtype torch.float32
Loaded recognition model s3://text_recognition/2025_02_18 on device cpu with dtype torch.float32


In [None]:
# Path (as a string) to a single sample PNG. This is an absolute location under
# the Kaggle working directory — adjust it when running elsewhere.
image: str = "/kaggle/working/detects_valid_scenes/test_01-Scene-005-02.png"

In [None]:
# Wrap the string path defined in the previous cell in a pathlib.Path.
image_path: Path = Path(image)