# Zero Shot Parsing
This notebook runs zero-shot parsing to generate examples for evaluation and few-shot prompting. The zero-shot outputs are only a starting point: we correct them manually, which is faster than labeling every example from scratch.

In [0]:
%pip install -U --quiet mlflow openai
%restart_python

In [0]:
from pathlib import Path
from openai import OpenAI
import base64
from PIL import Image
import IPython.display as display
import pandas as pd
import re
import json
import time

In [0]:
from mlflow.models import ModelConfig
config = ModelConfig(development_config="config.yaml").to_dict()

# Chat client passed as `client` to zero_shot_metadata and
# zero_shot_tag_with_retry by the query loops below, so it must exist before
# they run.
# OpenAI() reads OPENAI_API_KEY and OPENAI_BASE_URL from the environment; set
# OPENAI_BASE_URL to the OpenAI-compatible server that hosts
# config['fm_endpoint'].
# NOTE(review): confirm how credentials are supplied in this workspace.
client = OpenAI()

In [0]:
# Every tile of the example pages, pulled into pandas; these rows are the
# inputs to the tag extraction later in the notebook.
tile_query = f"""
    SELECT * 
    FROM {config['catalog']}.{config['schema']}.tile_info
    WHERE page_number in (12,24,27,31,32)
    """
tile_df = spark.sql(tile_query).toPandas()
tile_df

# Zero Shot - Metadata
This section runs the metadata prompt on the full image of each example page together with its last tile, which is always the lower-right tile. Given the tile dimensions and resolution, the last tile should contain most title blocks.

In [0]:
# Pull the last (highest tile_number) tile of each example page; per the
# section notes, that is the lower-right tile where the title block sits.
# Rows are numbered per (file, page) rather than per page alone, so that if
# tile_info holds several files, each file's page keeps its own last tile
# instead of one tile winning across files.
# NOTE(review): assumes file_path_hash is constant for every tile of a file.
# The page filter sits inside the subquery: it filters on a partition key, so
# row numbers are unchanged, and the window only scans the example pages.
example_metadata = spark.sql(f"""
SELECT *
FROM (
  SELECT
    *,
    ROW_NUMBER() OVER (
      PARTITION BY file_path_hash, page_number
      ORDER BY tile_number DESC
    ) as rn
  FROM {config['catalog']}.{config['schema']}.tile_info
  WHERE page_number in (12,24,27,31,32)
)
WHERE rn = 1
""")

example_metadata.display()

We use a naive sequential loop to query these few examples; for the larger set of queries we will move to Ray or Spark for parallelization. The code below sends each full page image and its lower-right tile to the model for a zero-shot extraction.

## Metadata Extraction
The code below extracts the metadata from each drawing and its lower-right tile.

In [0]:
def zero_shot_metadata(row: pd.Series, client: "OpenAI", output_dir="./examples"):
    """Run a zero-shot metadata extraction on one example page.

    Sends the full page image and its lower-right tile to the chat model named
    by ``config['fm_endpoint']``, with ``config['metadata_prompt']`` as the
    system prompt. The reply is parsed as JSON and saved as a draft label for
    manual correction.

    Args:
        row: Record providing ``page_path`` (full page image) and ``tile_path``
            (lower-right tile image). Both files are sent as JPEG data URLs.
        client: OpenAI-compatible client used for the chat completion.
        output_dir: Directory that receives ``<page image stem>.json``.
            Created if it does not exist.

    Returns:
        The parsed JSON reply (expected to be a dict of metadata fields).

    Raises:
        json.JSONDecodeError: If the reply is not valid JSON even after the
            quote repair below.
    """
    page_path = Path(row['page_path'])
    page_image_data = base64.b64encode(page_path.read_bytes()).decode("utf-8")
    tile_image_data = base64.b64encode(Path(row['tile_path']).read_bytes()).decode("utf-8")

    chat_completion = client.chat.completions.create(
      messages=[
        {
          "role": "system",
          "content": config['metadata_prompt']
        },
        {
          "role": "user",
          "content": [
                {
                    "type": "image_url", 
                    "image_url": {"url": f"data:image/jpeg;base64,{page_image_data}"},
                },
                {
                    "type": "image_url", 
                    "image_url": {"url": f"data:image/jpeg;base64,{tile_image_data}"},
                },
            ]
        }
      ],
      model=config['fm_endpoint'],
      temperature=config['temperature'],
      top_p=config['top_p']
    )

    parsed_text = chat_completion.choices[0].message.content

    # Repair values emitted with a stray quote after the leading digits,
    # e.g. "12"-A1" -> "12-A1", before parsing.
    fixed_json_str = re.sub(
        r'"(\d+)"-([A-Z\-0-9]+)"',
        r'"\1-\2"',
        parsed_text
    )

    parsed_dict = json.loads(fixed_json_str)

    # Save the draft label next to its siblings; the directory may not exist
    # on a fresh checkout.
    label_path = Path(output_dir) / page_path.with_suffix('.json').name
    label_path.parent.mkdir(parents=True, exist_ok=True)
    with open(label_path, 'w', encoding='utf-8') as f:
        json.dump(parsed_dict, f, indent=4)

    return parsed_dict

In [0]:
# Query the model once per example page, sequentially. results[i] holds the
# parsed metadata for row i of example_metadata_df; the next cell joins the two
# and writes the example_pages_parsed table.
example_metadata_df = example_metadata.toPandas()
results = []
for _, row in example_metadata_df.iterrows():
    print(row['page_number'])
    results.append(zero_shot_metadata(row, client))

In [0]:
# One row per example page: the file/page/tile columns, one column per
# top-level key of the parsed metadata, plus the full parsed reply. Rows line
# up by position because both frames carry a default 0..n-1 index.
page_cols = ['filename', 'file_path_hash', 'file_width', 'file_height', 'file_dpi', 'page_number', 'page_path', 'tile_number', 'tile_path']
parsed_fields_df = pd.DataFrame(results)
page_results = pd.concat([example_metadata_df[page_cols], parsed_fields_df], axis=1)
page_results['json_result'] = results

pages_table = f'{config["catalog"]}.{config["schema"]}.example_pages_parsed'
page_results_sdf = spark.createDataFrame(page_results)
(
    page_results_sdf.write
    .mode('overwrite')
    .option('mergeSchema', 'true')
    .saveAsTable(pages_table)
)

In [0]:
# Read the example_pages_parsed table back for inspection.
pages_parsed_sdf = spark.sql(f'SELECT * FROM {config["catalog"]}.{config["schema"]}.example_pages_parsed')
pages_parsed_sdf.display()

## Tag Extraction
Now we move on to tag extraction. We run a zero-shot extraction on every tile of the example pages (`tile_df`), and will manually correct the output to serve as the ground truth and evaluation set.

In [0]:
def zero_shot_tag_with_retry(row: pd.Series, client: "OpenAI", max_retries=2, retry_delay=1, output_dir="./examples"):
    """Run a zero-shot tag extraction on one tile, retrying on unparseable JSON.

    The tile image is sent to the chat model named by ``config['fm_endpoint']``
    with ``config['tag_prompt']`` as the system prompt. The reply is parsed as
    JSON and written to ``<output_dir>/<tile image stem>.json`` as a draft
    label.

    Failure handling:
      * Invalid JSON: the model is queried again, up to ``max_retries`` extra
        times, sleeping ``retry_delay`` seconds between attempts. If every
        attempt fails, the last (quote-repaired) raw text is saved instead.
      * Any other exception (API errors, etc.): no retry; ``str(e)`` is saved.

    Args:
        row: Tile record; must provide ``tile_path`` (sent as a JPEG data URL).
        client: OpenAI-compatible client used for the chat completion.
        max_retries: Extra attempts after the first when JSON parsing fails.
        retry_delay: Seconds to sleep before each retry.
        output_dir: Directory for the label file. Created if it does not exist.

    Returns:
        A copy of ``row`` with a ``tag_info`` entry. On success it holds the
        parsed JSON; on failure it holds the raw text or error message string.
        The input row is not modified.
    """
    row = row.copy()
    tile_path = Path(row['tile_path'])
    tile_image_data = base64.b64encode(tile_path.read_bytes()).decode("utf-8")

    # Resolved before the first request so the error handlers can always name
    # and write the label file, even when the very first API call fails.
    label_filename = tile_path.with_suffix('.json').name
    label_path = Path(output_dir) / label_filename
    label_path.parent.mkdir(parents=True, exist_ok=True)

    for attempt in range(max_retries + 1):
        try:
            chat_completion = client.chat.completions.create(
              messages=[
                {
                  "role": "system",
                  "content": config['tag_prompt']
                },
                {
                  "role": "user",
                  "content": [
                        {
                            "type": "image_url", 
                            "image_url": {"url": f"data:image/jpeg;base64,{tile_image_data}"},
                        },
                    ]
                }
              ],
              model=config['fm_endpoint'],
              temperature=config['temperature'],
              top_p=config['top_p']
            )

            parsed_text = chat_completion.choices[0].message.content

            # Repair values emitted with a stray quote after the leading
            # digits, e.g. "12"-A1" -> "12-A1", before parsing.
            fixed_json_str = re.sub(
                r'"(\d+)"-([A-Z\-0-9]+)"',
                r'"\1-\2"',
                parsed_text
            )

            parsed_dict = json.loads(fixed_json_str)

            with open(label_path, 'w', encoding='utf-8') as f:
                json.dump(parsed_dict, f, indent=4)

            row['tag_info'] = parsed_dict
            break  # Success, exit retry loop

        except json.JSONDecodeError as e:
            if attempt < max_retries:
                print(f"JSON parsing failed for {label_filename} (attempt {attempt + 1}/{max_retries + 1}): {e}")
                time.sleep(retry_delay)  # Wait before re-querying the model
            else:
                print(f"Failed to parse {label_filename} after {max_retries + 1} attempts: {e}")
                # Keep the unparseable text so it can be corrected by hand.
                with open(label_path, 'w', encoding='utf-8') as f:
                    json.dump(fixed_json_str, f)
                row['tag_info'] = fixed_json_str

        except Exception as e:
            # For non-JSON errors (API errors, etc.), fail immediately
            print(f"Non-parsing error for {label_filename}: {e}")
            with open(label_path, 'w', encoding='utf-8') as f:
                json.dump(str(e), f)
            row['tag_info'] = str(e)
            break

    return row

In [0]:
# Sequential tag extraction over every example tile; one run took about
# 7 minutes for 35 calls (~12 s per call).
results = []
for _, tile_row in tile_df.iterrows():
    print(tile_row['page_number'], tile_row['tile_number'])
    results.append(zero_shot_tag_with_retry(tile_row, client))

In [0]:
# Persist one row per tile: the tile_info columns plus `tag_info`.
# NOTE(review): `tag_info` is a dict on success but a str when parsing or the
# API call failed; Spark schema inference may reject that mix. Confirm there
# were no failures, or serialize `tag_info` to JSON strings first.
tiles_table = f'{config["catalog"]}.{config["schema"]}.example_tiles_parsed'
tile_results_df = pd.DataFrame(results)
spark.createDataFrame(tile_results_df).write.mode('overwrite').saveAsTable(tiles_table)

In [0]:
# Read the example_tiles_parsed table back for inspection.
tiles_parsed_sdf = spark.sql(f'SELECT * FROM {config["catalog"]}.{config["schema"]}.example_tiles_parsed')
tiles_parsed_sdf.display()