In [0]:
%pip install pytesseract pdf2image
# Restart the Python process so that the packages installed by %pip above are
# the ones imported by the cells that follow.
dbutils.library.restartPython()

: 

### Cluster Config to use for reproducibility:

```json
{
    "cluster_name": "alex miller OCR cluster udf",
    "spark_version": "15.4.x-scala2.12",
    "spark_conf": {
        "spark.databricks.pyspark.dataFrameChunk.enabled": "true"
    },
    "azure_attributes": {
        "availability": "ON_DEMAND_AZURE"
    },
    "node_type_id": "Standard_D16ads_v5",
    "autotermination_minutes": 120,
    "init_scripts": [
        {
            "workspace": {
                "destination": "/Users/alex.miller@databricks.com/ray-ocr/init.sh"
            }
        }
    ],
    "single_user_name": "alex.miller@databricks.com",
    "data_security_mode": "DATA_SECURITY_MODE_AUTO",
    "runtime_engine": "PHOTON",
    "kind": "CLASSIC_PREVIEW",
    "use_ml_runtime": true,
    "is_single_node": false,
    "num_workers": 6,
    "apply_policy_default_values": false
}
```

### Setup Ray cluster:
- `init.sh` is configured as an init script on the Databricks cluster so that the tesseract-ocr package is downloaded to all nodes
- The Spark cluster has 6 workers (`num_workers`); 4 of them are passed to Ray and 2 are left for Spark (so that Spark can handle the read and write steps)
- Ray is therefore given `min_worker_nodes` = `max_worker_nodes` = 4 (leaving 2 for Spark)

In [0]:
from ray.util.spark import setup_ray_cluster, MAX_NUM_WORKER_NODES, shutdown_ray_cluster
import ray

# Set to False to keep an already-running Ray cluster instead of recreating it.
restart = True
if restart:
  # Best-effort teardown of anything left over from a previous run. A failure
  # here (e.g. no cluster is running) must not block the setup below, but it is
  # reported instead of being silently swallowed.
  try:
    shutdown_ray_cluster()
  except Exception as exc:
    print(f"shutdown_ray_cluster() skipped: {exc}")
  try:
    ray.shutdown()
  except Exception as exc:
    print(f"ray.shutdown() skipped: {exc}")

# Ray allows you to define custom cluster configurations using setup_ray_cluster function
# This allows you to allocate CPUs and GPUs on Ray cluster
ray_context = setup_ray_cluster(
  min_worker_nodes=4,       # minimum number of worker nodes to start
  max_worker_nodes=4,       # maximum number of worker nodes to start (autoscaling)
  num_gpus_worker_node=0,   # number of GPUs to allocate per worker node
  num_gpus_head_node=0,     # number of GPUs to allocate on head node (driver)
  num_cpus_worker_node=12,  # number of CPUs to allocate to Ray on each worker node
  num_cpus_head_node=8,     # number of CPUs to allocate on head node (driver)
  # Keyword is `collect_log_to_path`; the previous spelling `collect_log_tp_path`
  # is not the documented parameter name for log collection.
  collect_log_to_path="/Volumes/alex_m/gen_ai/pdfs/ray_collected_logs"
)

# Pass any custom configuration to ray.init
ray.init(ignore_reinit_error=True)
print(ray.cluster_resources())

In [None]:
### Description of the Code

The code below is designed to process PDF documents using a combination of Spark and Ray for distributed computing. The workflow involves the following steps:

1. **Cluster Setup**: The Ray cluster is set up with specific configurations for the number of worker nodes, CPUs, and GPUs allocated to both the head node and worker nodes. This setup ensures efficient resource utilization between Spark and Ray.

2. **PDF Processing**: A `PDFProcessor` class is defined to convert PDF documents into images. The class includes methods to convert PDF data to image bytes and handle batches of PDF documents.

3. **OCR Processing**: An `OCRProcessor` class is defined to perform Optical Character Recognition (OCR) on the images generated from the PDF documents. The class includes methods to convert image bytes to PIL images and handle batches of images for OCR processing.

4. **Main Function**: The `main` function orchestrates the entire workflow:
    - Reads PDF documents from a Spark table.
    - Converts the Spark DataFrame to a Ray Dataset.
    - Processes the PDFs to convert them into images using the `PDFProcessor`.
    - Performs OCR on the images using the `OCRProcessor`.
    - Converts the OCR results to a Pandas DataFrame and displays the results.
    - Saves the processed data back to a Spark table.

5. **Shutdown**: Finally, the Ray cluster is shut down to release the resources.

This approach leverages the parallel processing capabilities of Ray and the distributed data handling capabilities of Spark to efficiently process large volumes of PDF documents and extract text using OCR.

In [0]:
import io
from typing import List, Dict, Any
import ray
from pdf2image import convert_from_bytes
from PIL import Image
import pytesseract
import time
from pyspark.sql import functions as F


class PDFProcessor:
    """
    Renders PDF documents into per-page PNG images.

    Instances are callable: each call takes one batch of PDFs and returns a
    column-oriented batch with one row per rendered page.
    """
    def __init__(self):
        pass

    @staticmethod
    def pdf_to_image_bytes(pdf_data: bytes) -> List[bytes]:
        """
        Render every page of a PDF and encode each page as PNG bytes.

        Args:
            pdf_data (bytes): Raw bytes of a single PDF document.

        Returns:
            List[bytes]: One PNG-encoded byte string per page, in page order.
        """
        rendered_pages = convert_from_bytes(pdf_data)
        encoded_pages = []
        for rendered in rendered_pages:
            encoded_pages.append(PDFProcessor._image_to_bytes(rendered))
        return encoded_pages

    @staticmethod
    def _image_to_bytes(image: Image.Image) -> bytes:
        """
        Serialize a PIL image into an in-memory PNG.

        Args:
            image (Image.Image): Image to encode.

        Returns:
            bytes: The PNG-encoded image.
        """
        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        return buffer.getvalue()

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, List[Any]]:
        """
        Convert one batch of PDFs into a batch of page images.

        Args:
            batch (Dict[str, Any]): Column-oriented batch with a "content"
                column (PDF bytes) and a parallel "path" column.

        Returns:
            Dict[str, List[Any]]: Columns "page", "path", "page_number" and
                "error". A PDF that cannot be converted yields a single row
                with an empty page, page_number -1 and the exception text in
                "error"; successful pages carry an empty "error".
        """
        page_col: List[Any] = []
        path_col: List[Any] = []
        number_col: List[Any] = []
        error_col: List[Any] = []

        for pdf_bytes, pdf_path in zip(batch["content"], batch["path"]):
            try:
                png_pages = PDFProcessor.pdf_to_image_bytes(pdf_bytes)
            except Exception as exc:
                # Keep a placeholder row so the failed document stays visible downstream.
                page_col.append(b"")
                path_col.append(pdf_path)
                number_col.append(-1)
                error_col.append(str(exc))
                continue

            # Page numbers are 1-based.
            for number, png in enumerate(png_pages, start=1):
                page_col.append(png)
                path_col.append(pdf_path)
                number_col.append(number)
                error_col.append("")

        return {
            "page": page_col,
            "path": path_col,
            "page_number": number_col,
            "error": error_col,
        }

class OCRProcessor:
    """
    Runs OCR over batches of page images.

    Instances are callable: each call takes one column-oriented batch of page
    images and returns a column-oriented batch of OCR results.
    """

    # Columns (and their order) of every batch returned by __call__.
    _OUTPUT_KEYS = ("text", "status", "error", "path", "page_number", "duration")

    def __init__(self):
        self.tesseract = pytesseract

    @staticmethod
    def bytes_to_pil(image_bytes: bytes) -> Image.Image:
        """
        Convert image bytes to PIL Image.

        Args:
            image_bytes (bytes): Byte string representation of an image.

        Returns:
            Image.Image: PIL Image object.
        """
        return Image.open(io.BytesIO(image_bytes))

    @staticmethod
    def _row(text: str, status: str, error: str, path: Any, page_number: Any,
             duration: float) -> Dict[str, Any]:
        """Build one result record with the keys listed in _OUTPUT_KEYS."""
        return {
            "text": text,
            "status": status,
            "error": error,
            "path": path,
            "page_number": page_number,
            "duration": duration,
        }

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, List[Any]]:
        """
        Process a batch of images with OCR.

        Args:
            batch (Dict[str, Any]): Column-oriented batch with "page" (image
                bytes), "path" and "page_number" columns. An optional "error"
                column carries failures from an earlier stage; rows with a
                non-empty value there are not OCR'd and that error is
                reported as-is.

        Returns:
            Dict[str, List[Any]]: Columns "text", "status" ("success" or
                "error"), "error", "path", "page_number" and "duration"
                (seconds spent on the row). An empty input batch yields the
                same columns with empty lists.
        """
        pages = batch["page"]
        # The "error" column is optional so batches without it keep working.
        upstream_errors = batch["error"] if "error" in batch else [""] * len(pages)

        results = []
        for page, path, page_number, upstream_error in zip(
            pages, batch["path"], batch["page_number"], upstream_errors
        ):
            # perf_counter is monotonic, so durations cannot go negative.
            start_time = time.perf_counter()

            if upstream_error:
                # The page was never rendered; running OCR on its placeholder
                # bytes would only replace the real error with an image-decoding one.
                results.append(self._row(
                    "", "error", str(upstream_error), path, page_number,
                    time.perf_counter() - start_time,
                ))
                continue

            try:
                # Context manager releases the image's resources once OCR is done.
                with self.bytes_to_pil(page) as image:
                    text = self.tesseract.image_to_string(image)
                results.append(self._row(
                    text or "", "success", "", path, page_number,
                    time.perf_counter() - start_time,
                ))
            except Exception as e:
                results.append(self._row(
                    "", "error", str(e), path, page_number,
                    time.perf_counter() - start_time,
                ))

        # Build from the fixed key list (not results[0]) so an empty batch
        # returns empty columns instead of raising IndexError.
        return {key: [item[key] for item in results] for key in self._OUTPUT_KEYS}

# def main():
#     # Define Spark DataFrame column names for PDFs
#     pdf_column = "pdf"
#     path_column = "__url__"

#     # Read PDFs from Spark table
#     sdf = (
#         spark.read.table("alex_m.gen_ai.pixparse_pdfs")
#         .select(F.col(pdf_column).alias("content"), F.col(path_column).alias("path"))
#         .limit(1000)
#     )

#     # Create Ray Dataset
#     ray_dataset = ray.data.from_spark(sdf)

#     # Set concurrency
#     min_concurrency = int(ray.cluster_resources().get("CPU", 1) * 0.4)
#     max_concurrency = int(ray.cluster_resources().get("CPU", 1))
#     max_concurrency = int(ray.cluster_resources().get("CPU", 1) * 0.9)

#     # Process PDFs
#     pages_dataset = ray_dataset.map_batches(
#         PDFProcessor,
#         # PDFProcessor.process_batch,
#         batch_size=100,
#         num_cpus=1,
#         # concurrency=max_concurrency
#         concurrency=(15, 28),
#     )

#     # Perform OCR
#     # ocr_processor = OCRProcessor()
#     # ocr_processor = OCRProcessor.remote()

#     ocr_dataset = pages_dataset.map_batches(
#         OCRProcessor,
#         # ocr_processor.process_batch,
#         batch_size=8,
#         num_cpus=1,
#         concurrency=(20, 28),   # recommendation from Ray docs?
#     )

#     # Convert to pandas and display results
#     ocr_dataset_pd = ocr_dataset.to_pandas()
#     ocr_dataset_pd.display()

#     processed_spark_df = spark.createDataFrame(ocr_dataset_pd)
#     processed_spark_df.write.mode("overwrite").saveAsTable("alex_m.gen_ai.ray_ocr")

# if __name__ == "__main__":
#     main()

def main(config: "Optional[ProcessingConfig]" = None) -> None:
    """
    Main processing pipeline for PDF OCR.

    Reads PDFs from a Spark table, converts them to page images and runs OCR
    on them with Ray Data, displays the result and writes it to a Spark table
    (overwriting it).

    Args:
        config: ProcessingConfig object containing all processing parameters.
            When omitted, a default ``ProcessingConfig()`` is created.
    """
    if config is None:
        # Imported lazily: a `ProcessingConfig()` default in the signature is
        # evaluated when the function is defined, which raised NameError because
        # the `config` module is only imported further down in the notebook.
        from config import ProcessingConfig
        config = ProcessingConfig()

    # Read PDFs from Spark table
    sdf = (
        spark.read.table(config.data.input_table)
        .select(
            F.col(config.data.pdf_column).alias("content"),
            F.col(config.data.path_column).alias("path")
        )
    )

    # A falsy limit (None / 0) means "process every row".
    if config.data.limit_rows:
        sdf = sdf.limit(config.data.limit_rows)

    # Create Ray Dataset
    ray_dataset = ray.data.from_spark(sdf)

    # Process PDFs: one output row per rendered page
    pages_dataset = ray_dataset.map_batches(
        PDFProcessor,
        batch_size=config.ray.pdf_batch_size,
        num_cpus=config.ray.pdf_num_cpus,
        concurrency=config.ray.pdf_concurrency,
    )

    # Perform OCR
    ocr_dataset = pages_dataset.map_batches(
        OCRProcessor,
        batch_size=config.ray.ocr_batch_size,
        num_cpus=config.ray.ocr_num_cpus,
        concurrency=config.ray.ocr_concurrency,
    )

    # Convert to pandas and display results.
    # NOTE(review): to_pandas() gathers the whole result into one DataFrame;
    # confirm this fits in memory when raising `limit_rows`.
    ocr_dataset_pd = ocr_dataset.to_pandas()
    # Stock pandas DataFrames have no `.display()` method, so use the
    # notebook's display() function instead.
    display(ocr_dataset_pd)

    # Save results
    processed_spark_df = spark.createDataFrame(ocr_dataset_pd)
    processed_spark_df.write.mode("overwrite").saveAsTable(config.data.output_table)

if __name__ == "__main__":
    from config import ProcessingConfig, RayConfig, DataConfig
    # Example usage with custom configuration
    config = ProcessingConfig(
        ray=RayConfig(
            pdf_batch_size=100,
            pdf_num_cpus=1,
            pdf_concurrency=(20, 28),
            ocr_batch_size=8,
            ocr_num_cpus=1,
            ocr_concurrency=(20, 28)
        ),
        data=DataConfig(
            limit_rows=1000,
            input_table="alex_m.gen_ai.pixparse_pdfs",
            pdf_column="pdf",
            path_column="__url__",
            output_table="alex_m.gen_ai.ray_ocr"
        )
    )
    try:
        main(config)
    finally:
        # Shut Ray down even when the pipeline fails, so a failed run does not
        # leave the Ray cluster running.
        shutdown_ray_cluster()
        ray.shutdown()