# Metadata Then Tags
We showed that tiles are the path forward for tag extraction, but also that the model hallucinates titles that don't exist. To improve our metadata capture, we are going to do two passes over the image. The first is a few-shot extraction of image metadata (titles, locations, revisions, etc.). The second is a tile-by-tile few-shot extraction, for which we will curate samples ourselves to guide the model.

In [0]:
%pip install mlflow openai
%restart_python

In [0]:
from pathlib import Path
from openai import OpenAI
import base64
from PIL import Image
import IPython.display as display

In [0]:
from mlflow.models import ModelConfig
# Load the notebook settings from config_alb.yaml as a plain dict. Keys read
# later in this notebook: processed_path, metadata_prompt, tag_prompt,
# fm_endpoint, temperature, top_p.
config = ModelConfig(development_config="config_alb.yaml").to_dict()

In [0]:
# Take the API token from the running notebook's context so that no secret
# has to be written into the notebook itself.
_notebook_context = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
DATABRICKS_TOKEN = _notebook_context.apiToken().get()

# OpenAI client pointed at the workspace's serving endpoints.
# NOTE(review): the workspace URL is hard-coded — confirm it matches the
# workspace this notebook runs in.
SERVING_ENDPOINTS_URL = "https://adb-984752964297111.11.azuredatabricks.net/serving-endpoints"
client = OpenAI(api_key=DATABRICKS_TOKEN, base_url=SERVING_ENDPOINTS_URL)

## Metadata Extraction


In [0]:
load_sheet_df = spark.sql("SELECT * FROM shm.pid.load_sheet_alb")

# Keep only the rows flagged as few-shot examples and pull them into pandas.
_is_example = load_sheet_df.for_examples == True
examples_df = load_sheet_df.filter(_is_example).toPandas()


def _first_page_image_path(filename):
  """Return the page-1 JPEG path under config["processed_path"] for a PDF filename."""
  stem = filename.replace(".pdf", "")
  return Path(config["processed_path"] + stem + "/" + stem + "_page_1.jpeg")


# Get the image path
# TODO: Abstract a bit more
examples_df["image_path"] = examples_df["closest_filename"].apply(_first_page_image_path)

In [0]:
examples_df

## Few Shot Metadata Parsing
Testing metadata extraction only

In [0]:
# Use the example page at index label 7 as the test image and render it inline
# for a visual check.
test_page_path = examples_df.loc[7, "image_path"]
test_image = Image.open(test_page_path)
display.display(test_image)

In [0]:
import pandas as pd
def few_shot_parse_metadata(image_path: str, examples: pd.DataFrame):
  """Ask the chat endpoint for the metadata of one image, guided by few-shot examples.

  Each row of ``examples`` contributes an (image, expected output) pair to a
  single user message, and the image to parse is appended last. The system
  prompt is ``config['metadata_prompt']``. The function relies on the
  notebook-level ``client`` and ``config`` objects.

  Args:
    image_path: Path (``str`` or ``Path``) of the image to parse.
    examples: DataFrame with an ``image_path`` column (example image file) and
      a ``json_output`` column (the text shown to the model as that example's
      answer). An empty frame sends the target image on its own.

  Returns:
    The message content of the first completion choice.

  Raises:
    OSError: If the target image or an example image cannot be read.
  """
  # Suffix -> MIME type for the data URL, so that e.g. a .webp file is not
  # labelled as JPEG. Unknown suffixes fall back to image/jpeg, which is what
  # used to be sent for every image.
  mime_types = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
    ".gif": "image/gif",
  }

  def image_part(path):
    """Build an ``image_url`` content part holding the file as a base64 data URL."""
    path = Path(path)
    mime_type = mime_types.get(path.suffix.lower(), "image/jpeg")
    encoded = base64.b64encode(path.read_bytes()).decode("utf-8")
    return {
      "type": "image_url",
      "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
    }

  # Read the target first so a bad target path fails before any example is loaded.
  target_part = image_part(image_path)

  content = []
  for _, example in examples.iterrows():
    content.append(image_part(example["image_path"]))
    content.append({"type": "text", "text": example["json_output"]})
  content.append(target_part)

  chat_completion = client.chat.completions.create(
    messages=[
      {"role": "system", "content": config["metadata_prompt"]},
      {"role": "user", "content": content},
    ],
    model=config["fm_endpoint"],
    temperature=config["temperature"],
    top_p=config["top_p"],
  )

  return chat_completion.choices[0].message.content

In [0]:
# Zero-shot baseline: with an empty examples frame only the test page is sent.
few_shot_parse_metadata(test_page_path, pd.DataFrame())

## Few Shot Tile Tag Extraction


In [0]:
import pandas as pd
def few_shot_parse_tags(image_path: str, examples: pd.DataFrame):
  """Ask the chat endpoint for the tags in one image tile, guided by few-shot examples.

  Each row of ``examples`` contributes an (image, expected output) pair to a
  single user message, and the image to parse is appended last. The system
  prompt is ``config['tag_prompt']``. The function relies on the
  notebook-level ``client`` and ``config`` objects.

  Args:
    image_path: Path (``str`` or ``Path``) of the tile image to parse.
    examples: DataFrame with an ``image_path`` column (example image file) and
      a ``json_output`` column (the text shown to the model as that example's
      answer). An empty frame sends the target image on its own.

  Returns:
    The message content of the first completion choice.

  Raises:
    OSError: If the target image or an example image cannot be read.
  """
  # Suffix -> MIME type for the data URL, so that a .webp tile is not labelled
  # as JPEG. Unknown suffixes fall back to image/jpeg, which is what used to
  # be sent for every image.
  mime_types = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
    ".webp": "image/webp",
    ".gif": "image/gif",
  }

  def image_part(path):
    """Build an ``image_url`` content part holding the file as a base64 data URL."""
    path = Path(path)
    mime_type = mime_types.get(path.suffix.lower(), "image/jpeg")
    encoded = base64.b64encode(path.read_bytes()).decode("utf-8")
    return {
      "type": "image_url",
      "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
    }

  # Read the target first so a bad target path fails before any example is loaded.
  target_part = image_part(image_path)

  content = []
  for _, example in examples.iterrows():
    content.append(image_part(example["image_path"]))
    content.append({"type": "text", "text": example["json_output"]})
  content.append(target_part)

  chat_completion = client.chat.completions.create(
    messages=[
      {"role": "system", "content": config["tag_prompt"]},
      {"role": "user", "content": content},
    ],
    model=config["fm_endpoint"],
    temperature=config["temperature"],
    top_p=config["top_p"],
  )

  return chat_completion.choices[0].message.content

In [0]:
# A single WebP tile (page 1, tile 4) used to try out tag extraction; rendered
# inline for a visual check.
test_tile_path = "/Volumes/shm/pid/tiled_pdfs/with_load_sheet/MRP-520-PID-PR-000351_F267/MRP-520-PID-PR-000351_F267.1_page_1_tile_4.webp"
test_tile = Image.open(test_tile_path)
display.display(test_tile)

In [0]:
# Use only the row at position 7 as the few-shot example; the double brackets
# keep the selection a DataFrame rather than a Series.
# NOTE(review): that row's image_path is a full-page JPEG while the query is a
# single tile — confirm this pairing is intended.
examples = examples_df.iloc[[7]]
inf_few_shot_tags = few_shot_parse_tags(test_tile_path, examples)

In [0]:
inf_few_shot_tags