# Few Shot Parsing
Now that we have some manually constructed ground truths, we move on to few shot parsing with examples.

This notebook has been tested on serverless.

In [0]:
%pip install -U --quiet mlflow openai
%restart_python

In [0]:
from pathlib import Path
from openai import OpenAI
import base64
from PIL import Image
import IPython.display as display
import pandas as pd
import re
import json
import time

In [0]:
# Load the notebook settings from config.yaml into a plain dict. Later cells read
# the 'catalog', 'schema', 'metadata_prompt', 'tag_prompt', 'fm_endpoint',
# 'temperature' and 'top_p' keys from it.
from mlflow.models import ModelConfig
config = ModelConfig(development_config="config.yaml").to_dict()

In [0]:
# Pull every tile row for pages 12, 24, 27, 31 and 32 into pandas.
# The tag-extraction loop near the end of the notebook iterates over this frame.
tile_df = spark.sql(f"""
    SELECT * 
    FROM {config['catalog']}.{config['schema']}.tile_info
    WHERE page_number in (12,24,27,31,32)
    """).toPandas()
tile_df

# Few Shot - Metadata
This section runs the metadata prompt using the entire image from each example and the last tile (which is always the lower right). The last tile should contain most title blocks due to the dimensions of the tiles and resolution.

In [0]:
# This query pulls the last tile of each page: rows are ranked per page_number by
# tile_number descending and only rank 1 is kept.
# NOTE: the page filter inside the SQL is commented out, so every page in tile_info
# is returned, not only the five example pages.
example_metadata = spark.sql(f"""
SELECT *
FROM (
  SELECT 
    *,
    ROW_NUMBER() OVER (PARTITION BY page_number ORDER BY tile_number DESC) as rn
  FROM {config['catalog']}.{config['schema']}.tile_info
)
WHERE rn = 1
-- AND page_number in (12,24,27,31,32)
""")

example_metadata.display()

We are going to use a naive loop to query the examples, but will move to Ray or Spark for parallelization for the larger set of queries. The code below sets up the client that sends the tiles, together with the few shot examples, to our model for extraction.

In [0]:
# Reuse the token of the running notebook session so no credential is stored here.
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# The serving-endpoints URL is workspace specific. config.yaml may override it with
# an optional `serving_base_url` key; the previously hard-coded workspace URL stays
# as the fallback, so existing setups behave exactly as before.
SERVING_BASE_URL = config.get(
    "serving_base_url",
    "https://adb-984752964297111.11.azuredatabricks.net/serving-endpoints",
)

client = OpenAI(
  api_key=DATABRICKS_TOKEN,
  base_url=SERVING_BASE_URL,
)

## Metadata Extraction
The code below extracts the metadata from each drawing + lower right tile. Our examples are saved in `example_pages_parsed`. We select the first two rows and use them as our few shot examples; the remaining rows are kept as test cases.

In [0]:
# Read the parsed example pages once and split that single snapshot: the first two
# rows become the few-shot examples and the rest are held out as tests.
# Running the unordered query twice (once per slice) gave no guarantee that the two
# result sets came back in the same row order, so the slices could overlap.
example_pages_parsed_df = spark.sql(
    f'SELECT * FROM {config["catalog"]}.{config["schema"]}.example_pages_parsed'
).toPandas()
examples = example_pages_parsed_df.iloc[0:2]
tests = example_pages_parsed_df.iloc[2:]

In [0]:
# Few-shot material: for each of the two example rows, the tile image as a base64
# string and the stringified reference extraction stored in 'json_result'.
first_example = examples.iloc[0]
first_tile_bytes = Path(first_example['tile_path']).read_bytes()
ex1_tile_img_data = base64.b64encode(first_tile_bytes).decode("utf-8")
ex1_text = str(first_example['json_result'])

second_example = examples.iloc[1]
second_tile_bytes = Path(second_example['tile_path']).read_bytes()
ex2_tile_img_data = base64.b64encode(second_tile_bytes).decode("utf-8")
ex2_text = str(second_example['json_result'])

In [0]:
def extract_json_from_markdown(s):
    """Parse the JSON payload of a model reply, with or without a markdown fence.

    Args:
        s: Raw reply text. It may be bare JSON, or JSON wrapped in a markdown code
            fence (three backticks, optionally followed by ``json``). Surrounding
            whitespace is ignored and a missing closing fence is tolerated.

    Returns:
        The decoded JSON value, as produced by ``json.loads``.

    Raises:
        json.JSONDecodeError: If the text left after unwrapping is not valid JSON.
    """
    # Repair tags that come back split by a stray quote: "12"-AB-3" -> "12-AB-3"
    s = re.sub(
        r'"(\d+)"-([A-Z\-0-9]+)"',
        r'"\1-\2"',
        s,
    )

    # Strip first so a fence preceded by whitespace or newlines is still detected.
    s = s.strip()

    # Unwrap a fenced block. The match is anchored to the whole string so backticks
    # that merely appear inside a JSON string value are never mistaken for a fence.
    fenced = re.fullmatch(
        r'```(?:json)?\s*(.*?)\s*(?:```)?',
        s,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced:
        s = fenced.group(1)

    # Parse and return JSON
    return json.loads(s)

In [0]:
def few_shot_metadata(row: pd.Series, client: OpenAI):
    """Extract title-block metadata from one tile using two worked examples.

    The request contains the system prompt ``config['metadata_prompt']``, the two
    example tiles with their reference extractions (``ex1_*`` / ``ex2_*`` from the
    notebook scope) and finally the tile at ``row['tile_path']``. Only the tile is
    sent; the full page image is not part of the request, so it is no longer read.

    Args:
        row: Record providing ``tile_path`` (image to extract from) and
            ``page_path`` (used only to name the saved JSON file).
        client: OpenAI-compatible client used for the chat completion.

    Returns:
        The parsed JSON value when the reply can be parsed, otherwise the raw reply
        text. A failure to save the JSON file is reported but does not discard a
        successfully parsed result.
    """
    tile_image_data = base64.b64encode(Path(row['tile_path']).read_bytes()).decode("utf-8")

    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": config['metadata_prompt']
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": "Here are two examples of tiles with their example extractions"
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{ex1_tile_img_data}"},
                    },
                    {
                        "type": "text",
                        "text": ex1_text,
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{ex2_tile_img_data}"},
                    },
                    {
                        "type": "text",
                        "text": ex2_text,
                    },
                    {
                        "type": "text",
                        "text": "Now, extract this title block"
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{tile_image_data}"},
                    },
                ]
            }
        ],
        model=config['fm_endpoint'],
        temperature=config['temperature'],
        top_p=config['top_p']
    )

    parsed_text = chat_completion.choices[0].message.content

    # Parsing is best effort: on any failure report it and hand back the raw reply.
    try:
        parsed_dict = extract_json_from_markdown(parsed_text)
    except Exception as e:
        print(e)
        return parsed_text

    # Saving is handled separately so that a write problem (for example a missing
    # ./examples directory) cannot turn a good parse into a raw-text result.
    try:
        label_filename = Path(row['page_path']).name.replace('.jpg', '.json')
        Path('./examples').mkdir(parents=True, exist_ok=True)
        with open(f'./examples/{label_filename}', 'w') as f:
            json.dump(parsed_dict, f, indent=4)
    except Exception as e:
        print(f"Could not save parsed metadata: {e}")

    return parsed_dict

In [0]:
# Try the few-shot metadata prompt on the second row of the held-out `tests` frame.
few_shot_metadata(tests.iloc[1], client)

In [0]:
# Bring the last-tile-per-page rows into pandas.
example_metadata_df = example_metadata.toPandas()
# NOTE(review): the extraction loop below is disabled and refers to
# `zero_shot_metadata`, which is not defined in the cells shown here, so `results`
# stays an empty list when the notebook is run top to bottom.
results = []
# for idx, row in example_metadata_df.iterrows():
#     print(row['page_number'])
#     result = zero_shot_metadata(row, client)
#     results.append(result)

In [0]:
# Page-level columns kept next to the model outputs.
page_cols = ['filename', 'file_path_hash', 'file_width', 'file_height', 'file_dpi', 'page_number', 'page_path']
# Put the extraction outputs beside the page columns. axis=1 aligns on the index, so
# this assumes `results` holds one entry per row of example_metadata_df, in the same
# order. With an empty `results` only the page columns remain.
page_results = pd.concat([example_metadata_df[page_cols], pd.DataFrame(results)], axis=1)

In [0]:
# Only overwrite the table when the extraction loop actually produced output.
# With `results` empty, page_results contains just the page columns; writing it would
# replace example_pages_parsed (read earlier in this notebook for its 'tile_path' and
# 'json_result' columns) with a table that lacks them.
if results:
    spark.createDataFrame(page_results).write.mode('overwrite').saveAsTable(f'{config["catalog"]}.{config["schema"]}.example_pages_parsed')
else:
    print("No metadata results to save; leaving example_pages_parsed untouched.")

## Tag Extraction
Now we move on to tag extraction from every tile. We take the entire table and run a zero shot extraction, which we will manually correct as a ground truth and evaluation set.

In [0]:
def zero_shot_tag_with_retry(row: pd.Series, client: OpenAI, max_retries=2, retry_delay=1):
    """Run the tag prompt on one tile, retrying when the reply is not valid JSON.

    The image at ``row['tile_path']`` is sent with the system prompt
    ``config['tag_prompt']``. The outcome is written to a file under ``./examples``
    (the tile file name with ``.jpg`` replaced by ``.json``) and stored in
    ``row['tag_info']``.

    Args:
        row: Record with a ``tile_path`` entry; ``tag_info`` is set on it in place.
        client: OpenAI-compatible client used for the chat completion.
        max_retries: Extra attempts allowed after a JSON parsing failure.
        retry_delay: Seconds to sleep before each of those extra attempts.

    Returns:
        The same ``row``. ``tag_info`` holds the parsed JSON on success, the
        unparsed reply text when every attempt failed to parse, or the error message
        when any other exception (for example an API error) occurred.
    """
    tile_image_data = base64.b64encode(Path(row['tile_path']).read_bytes()).decode("utf-8")

    # Resolve the output location before the first request. The error handlers below
    # use it in their messages, so it must exist even if the API call itself fails.
    label_filename = Path(row['tile_path']).name.replace('.jpg', '.json')
    label_path = f'./examples/{label_filename}'
    Path('./examples').mkdir(parents=True, exist_ok=True)

    for attempt in range(max_retries + 1):
        try:
            chat_completion = client.chat.completions.create(
                messages=[
                    {
                        "role": "system",
                        "content": config['tag_prompt']
                    },
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{tile_image_data}"},
                            },
                        ]
                    }
                ],
                model=config['fm_endpoint'],
                temperature=config['temperature'],
                top_p=config['top_p']
            )

            parsed_text = chat_completion.choices[0].message.content

            # Repair tags that come back split by a stray quote: "12"-AB-3" -> "12-AB-3"
            fixed_json_str = re.sub(
                r'"(\d+)"-([A-Z\-0-9]+)"',
                r'"\1-\2"',
                parsed_text
            )

            # Try to parse JSON
            parsed_dict = json.loads(fixed_json_str)

            with open(label_path, 'w') as f:
                json.dump(parsed_dict, f, indent=4)

            row['tag_info'] = parsed_dict
            break  # Success, exit retry loop

        except json.JSONDecodeError as e:
            if attempt < max_retries:
                print(f"JSON parsing failed for {label_filename} (attempt {attempt + 1}/{max_retries + 1}): {e}")
                time.sleep(retry_delay)  # Wait before retrying
            else:
                # Out of attempts: keep the unparsed reply so it can be fixed by hand.
                print(f"Failed to parse {label_filename} after {max_retries + 1} attempts: {e}")
                with open(label_path, 'w') as f:
                    json.dump(fixed_json_str, f)
                row['tag_info'] = fixed_json_str

        except Exception as e:
            # For non-JSON errors (API errors, etc.), fail immediately
            print(f"Non-parsing error for {label_filename}: {e}")
            with open(label_path, 'w') as f:
                json.dump(str(e), f)
            row['tag_info'] = str(e)
            break

    return row

In [0]:
# 7 minutes for 35 calls ~ 12s/call
# Sequential tag extraction over every row of tile_df. Each call returns the row with
# a 'tag_info' entry set, and the returned rows are collected in `results`.
# NOTE: this rebinds `results`, which an earlier cell also uses for the metadata outputs.
results = []
for idx, row in tile_df.iterrows():
    print(row['page_number'], row['tile_number'])
    row_out = zero_shot_tag_with_retry(row, client)
    results.append(row_out)