In [0]:
%pip install kagglehub openai
%restart_python

In [0]:
import base64
import io
import json
import os
import shutil
from pathlib import Path

import kagglehub
from openai import OpenAI
from PIL import Image
from pyspark.sql import functions as F
from pyspark.sql.types import (
    DoubleType,
    FloatType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

# Read the workspace API URL and the notebook's API token from the Databricks notebook
# context (looked up once); both authenticate the OpenAI client against the workspace's
# /serving-endpoints/ API.
_notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
URL = _notebook_ctx.apiUrl().get()
TOKEN = _notebook_ctx.apiToken().get()

## Download Check Data
We use the Kaggle bank-checks dataset as our first dataset and store it in a Unity Catalog volume: the files are first downloaded to a local directory and then copied into the volume.

In [0]:
%sql
-- Create the Unity Catalog volume (catalog `shm`, schema `default`) that the next cell
-- copies the check images into (/Volumes/shm/default/bank_checks).
-- IF NOT EXISTS makes re-running this cell a no-op.
CREATE VOLUME IF NOT EXISTS shm.default.bank_checks

In [0]:
# Download the dataset with kagglehub (it returns the local directory holding the files),
# then copy the test-set check images into the Unity Catalog volume.
VOLUME_DIR = Path("/Volumes/shm/default/bank_checks")

path = Path(kagglehub.dataset_download("saifkhichi96/bank-checks-signatures-segmentation-dataset"))

subdir = "TestSet/X/"
file_dir = path / subdir
for file_path in sorted(file_dir.iterdir()):
    # Copy regular files only: shutil.copy raises when handed a sub-directory.
    if file_path.is_file():
        shutil.copy(file_path, VOLUME_DIR / file_path.name)

# Load every file in the volume into a DataFrame with Spark's binaryFile reader
# (columns: path, modificationTime, length, content).
images_df = (
  spark.read.format("binaryFile")
  .load(f"{VOLUME_DIR}/")
)

## Parse Check Data
Now that we have a table of images (the same approach works for PDF pages and other documents), we pass them to a multimodal model: first a single check, so we can inspect the model's output, and then the whole table in batch.

In [0]:
# Preview one check before calling the model. first() fetches a single row, whereas
# collect()[0] would pull every image in the volume back to the driver.
first_row = images_df.select("content").first()
assert first_row is not None, "No files were loaded from the bank_checks volume"
first_image_bytes = first_row["content"]

# Base64 string of the image; the next cell embeds it as a data URL in the request.
image_data = base64.b64encode(first_image_bytes).decode("utf-8")

# Render directly from the raw bytes; there is no need to decode the base64 string again.
display(Image.open(io.BytesIO(first_image_bytes)))

In [0]:
# OpenAI-compatible client pointed at the workspace's Model Serving endpoints.
client = OpenAI(api_key=TOKEN, base_url=f"{URL}/serving-endpoints/")

# Schema prompt shared by the single call below and the batch UDF. The example object
# itself is valid JSON (double quotes, every field comma-separated) so it does not
# contradict the "ONLY valid JSON" instruction the model is asked to follow.
structured_prompt = """Extract all information from this check and return ONLY valid JSON with:
{
  "date": "yyyy-mm-dd",
  "payee_name": "string",
  "dollar_amount": "float",
  "reason_for_payment": "string",
  "check_number": "integer",
  "sender_name": "string",
  "sender_address": "string"
}
Use JSON numbers (not strings) for dollar_amount and check_number, and null for any field you cannot read.
"""

chat_completion = client.chat.completions.create(
  messages=[
    {
      "role": "system",
      "content": structured_prompt
    },
    {
      "role": "user",
      "content": [
        {"type": "text", "text": "Extract all information from this image"},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}}
      ]
    }
  ],
  model="databricks-claude-3-7-sonnet",
  max_tokens=512,
)

# print() renders the reply's newlines instead of showing an escaped string repr.
parsed_text = chat_completion.choices[0].message.content
print(parsed_text)

## Run it at Scale
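We now move from a single call to processing the whole table. Because each image is processed independently, the work scales horizontally across the cluster, which massively reduces wall-clock time (the number of model calls stays the same).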

What if you could parse every document in your legal department in a couple of days?

In [0]:
# Define It! Spark struct schema for the dict returned by process_image; the field names
# match the JSON keys requested in structured_prompt, and every field is nullable.
# dollar_amount is DoubleType rather than FloatType: a 32-bit float has a 24-bit mantissa,
# so above $131,072 (2**17) adjacent representable values are more than one cent apart.
output_schema = StructType([
    StructField("date", StringType(), nullable=True),
    StructField("payee_name", StringType(), nullable=True),
    StructField("dollar_amount", DoubleType(), nullable=True),
    StructField("reason_for_payment", StringType(), nullable=True),
    StructField("check_number", IntegerType(), nullable=True),
    StructField("sender_name", StringType(), nullable=True),
    StructField("sender_address", StringType(), nullable=True)
])

def _extract_json_object(raw_text):
    """Return the span from the first ``{`` to the last ``}`` of a model reply, or None.

    Replies are often wrapped in Markdown code fences (```json ... ```) or surrounded by a
    sentence of prose. Slicing to the outermost braces drops that wrapper without editing
    the payload; a blanket ``str.replace('json', '')`` would also delete "json" from inside
    extracted values.
    """
    if not raw_text:
        return None
    start = raw_text.find("{")
    end = raw_text.rfind("}")
    if start == -1 or end < start:
        return None
    return raw_text[start:end + 1]


def _to_float(value):
    """Best-effort float conversion of a model-supplied amount such as 150, "1234.50" or "$1,234.50"."""
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        return float(value)
    text = str(value).replace("$", "").replace(",", "").strip()
    try:
        return float(text)
    except ValueError:
        return None


def _to_int(value):
    """Best-effort int conversion of a model-supplied check number such as 1024, 1024.0 or "#1024"."""
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value) if value.is_integer() else None
    text = str(value).strip().lstrip("#").strip()
    try:
        return int(text)
    except ValueError:
        return None


def process_image(content):
    """Extract check fields from one image using the Databricks-served Claude model.

    Called once per row by the Spark UDF, so it builds its own OpenAI client from the
    ``TOKEN``/``URL`` globals and sends ``structured_prompt`` together with the image.

    Args:
        content: Raw image bytes (the binaryFile ``content`` column); may be None.

    Returns:
        The dict parsed from the model's JSON reply, with ``dollar_amount`` coerced to
        float and ``check_number`` to int (each None when it cannot be converted).
        Returns None -- a NULL struct in Spark -- when ``content`` is None or the reply
        contains no valid JSON object, so one bad row does not fail the whole job.
    """
    if content is None:
        return None

    image_data = base64.b64encode(content).decode("utf-8")

    # The context manager closes the client's HTTP connections once the call finishes,
    # instead of leaving one open client behind for every processed row.
    with OpenAI(api_key=TOKEN, base_url=f"{URL}/serving-endpoints/") as client:
        response = client.chat.completions.create(
            messages=[{
                "role": "user",
                "content": [
                    {"type": "text", "text": structured_prompt},
                    {"type": "image_url", "image_url": {
                        "url": f"data:image/jpeg;base64,{image_data}"
                    }}
                ]
            }],
            model="databricks-claude-3-7-sonnet",
            max_tokens=512
        )
        raw_text = response.choices[0].message.content

    json_text = _extract_json_object(raw_text)
    if json_text is None:
        return None
    try:
        extracted = json.loads(json_text)
    except json.JSONDecodeError:
        return None

    # Numbers sometimes come back as strings ("$1,234.50", "1024") or as the other numeric
    # type (150 instead of 150.0); normalize them to the Python types the schema declares.
    extracted["dollar_amount"] = _to_float(extracted.get("dollar_amount"))
    extracted["check_number"] = _to_int(extracted.get("check_number"))
    return extracted

# Scale It! Wrap process_image as a Spark UDF returning an output_schema struct.
# Spark treats UDFs as deterministic by default and may then invoke them more (or fewer)
# times than they appear in the query. Each call here is a network request to the model
# endpoint whose answer can vary, so mark the UDF non-deterministic.
process_image_udf = F.udf(process_image, output_schema).asNondeterministic()

# Run It! Each row's image bytes are sent to the model from the executors.
(
  images_df
  .withColumn('extracted_info', process_image_udf(F.col('content')))
  .select('path','extracted_info')
  .display()
)