In [None]:
# Load IPython's autoreload extension; mode 2 re-imports modules before each cell runs,
# so edits to the llm_synthesis package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Convert an image to a `FigureInfo` object

In [None]:
import base64
import json
import os

from llm_synthesis.models.figure import FigureInfo


def convert_raw_image_(img_path: str) -> FigureInfo:
    """
    Load an image from disk, base64-encode it, and wrap it in a FigureInfo object.

    Only the encoded image data is populated; the remaining FigureInfo fields
    (alt text, position, surrounding context, reference, class) are set to
    empty/zero placeholders, since this notebook works on a single standalone image.

    Parameters:
        img_path (str): Path to the image file. Its extension (case-insensitive)
            must be one of .png, .jpeg or .jpg.
    Returns:
        FigureInfo: An object whose ``base64_data`` holds the base64-encoded image
            bytes, with all other metadata fields left empty.
    Raises:
        ValueError: If the file extension is not a supported image format.
        OSError: If the file cannot be opened or read (e.g. FileNotFoundError).
    """
    supported_extensions = (".png", ".jpeg", ".jpg")

    ext = os.path.splitext(img_path)[1].lower()
    if ext not in supported_extensions:
        # Report "<none>" instead of an empty string for extension-less paths
        raise ValueError(
            f"Unsupported image format: {ext or '<none>'} "
            f"(expected one of {', '.join(supported_extensions)})"
        )

    with open(img_path, "rb") as f:
        # b64encode returns ASCII bytes; decode to str for the FigureInfo field
        encoded = base64.b64encode(f.read()).decode("utf-8")

    figure_info = FigureInfo(
        base64_data=encoded,
        alt_text="",
        position=0,
        context_before="",
        context_after="",
        figure_reference="",
        figure_class="",
    )

    return figure_info


# Path of the chart image to digitize -- replace the placeholder before running this cell
img_path = "<path_of_your_image.png>"
figure_info = convert_raw_image_(img_path=img_path)

### Call the Claude API to extract the plot data


In [None]:
from llm_synthesis.transformers.plot_extraction.claude_extraction.plot_data_extraction import (
    ClaudeLinePlotDataExtractor,
)

# Initialize the extractor
# Claude model used for the extraction; change it here to compare other models
MODEL_NAME = "claude-sonnet-4-20250514"

# Initialize the extractor
extractor = ClaudeLinePlotDataExtractor(model_name=MODEL_NAME)

# Perform the extraction on the figure loaded above
extracted_data = extractor.forward(figure_info)
print("Extracted data:")
print(extracted_data)

# Cost reported by the extractor (this instance has made only the call above);
# left as the last expression so it renders as the cell output
cost_info = extractor.get_cost()
cost_info

### Visualize the original figure and the extracted data series, with their labels and axes

In [None]:
# Original figure, decoded back from the base64 payload stored in figure_info
from IPython.display import Image

figure_bytes = base64.b64decode(figure_info.base64_data)
Image(data=figure_bytes)

In [None]:
# Visualize the ground truth and the extracted data
import warnings

from llm_synthesis.transformers.plot_extraction.claude_extraction.plot_data_extraction import (
    ExtractedLinePlotData,
)
from llm_synthesis.utils.visualization import visualize_line_chart

# load ground truth coordinates
ground_truth_path = "<path_to_ground_truth.json>"  # make sure the name of each series is consistent with the extracted data

# Read as UTF-8 explicitly: series names may contain non-ASCII characters, and
# open()'s default encoding is platform/locale dependent.
with open(ground_truth_path, encoding="utf-8") as f:
    gt_coordinates = json.load(f)
gt_extracted_data = ExtractedLinePlotData.model_validate(gt_coordinates)


def _sort_series_by_name(plot_data):
    """Return ``plot_data.name_to_coordinates`` as a new dict ordered alphabetically by series name."""
    return dict(
        sorted(plot_data.name_to_coordinates.items(), key=lambda item: item[0])
    )


# Sort the series names so both charts list the series in the same order
gt_extracted_data.name_to_coordinates = _sort_series_by_name(gt_extracted_data)
extracted_data.name_to_coordinates = _sort_series_by_name(extracted_data)

# Surface (without failing) any series names that do not match between the two
gt_names = set(gt_extracted_data.name_to_coordinates)
extracted_names = set(extracted_data.name_to_coordinates)
if gt_names != extracted_names:
    warnings.warn(
        "Series names differ between ground truth and extraction: "
        f"only in ground truth={sorted(gt_names - extracted_names)}, "
        f"only in extraction={sorted(extracted_names - gt_names)}"
    )

visualize_line_chart(gt_extracted_data)
visualize_line_chart(extracted_data)

### Compute the RMSE and MAE of the extracted data against the ground truth

In [None]:
from llm_synthesis.metrics.extraction_metric.figure_extraction_metric import (
    FigureExtractionMetric,
)

figure_extraction_metric = FigureExtractionMetric()

# Score the extraction against the ground truth once per error metric (RMSE first, then MAE)
rmse, mae = (
    figure_extraction_metric(
        extracted_data, gt_extracted_data, error_metric=metric_name
    )
    for metric_name in ("rmse", "mae")
)
rmse, mae
