In [6]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import defaultdict
from matplotalt import *
from matplotalt_helpers import pillow_image_to_base64_string
from PIL import Image
from api_helpers import get_openai_vision_response

# Directory holding the inputs for each figure: a heuristic caption
# ("nb_<fig_id>.txt") and the rendered figure image ("nb_<fig_id>.jpg").
alt_figs_path = "./alt_figs"

In [7]:
# Collect the ids of all figures that have a heuristic caption file.
# Caption files are named "nb_<fig_id>.txt"; the id is whatever sits between
# the "nb_" prefix and the ".txt" extension.
caption_prefix = "nb_"
caption_ext = ".txt"
fig_ids = sorted(
    # splitext removes only the final extension, so an id that itself contains
    # a "." still maps back to the file it came from.
    os.path.splitext(file_name)[0][len(caption_prefix):]
    for file_name in os.listdir(alt_figs_path)
    # Require the prefix explicitly instead of blindly dropping three characters.
    if file_name.startswith(caption_prefix) and file_name.endswith(caption_ext)
)
# os.listdir returns entries in arbitrary order; sorting (lexicographic, since
# ids are strings) makes the processing order reproducible across runs.

print(len(fig_ids))

204


In [14]:
# Caption every figure with the vision model twice -- once from the image
# alone and once with the heuristic alt text supplied as a starter description
# -- then export all captions as JSON lines.
fig_id_to_captions = []
# NOTE(review): os.environ.get returns None when AZURE_OPENAI_API_KEY is unset;
# presumably every request then fails and is only counted as an error below --
# confirm how get_openai_vision_response treats a missing key.
OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
model = "TURBO"
n_errors = 0
n_passed = 0
DESC_LEVEL = 3
MAX_TOKENS = 225
desc_level_prompt = get_desc_level_prompt(desc_level=DESC_LEVEL)
# Keyword arguments shared by both requests made for each figure.
vision_kwargs = dict(model=model, use_azure=True, max_tokens=MAX_TOKENS, return_full_response=False)

# The base prompt is identical for every figure, so show it once up front.
print(desc_level_prompt)
for fig_id in tqdm(fig_ids):
    try:
        fig_caption_dict = {"figure_id": fig_id}
        # Heuristic caption text, also used as the starter description below.
        with open(f"{alt_figs_path}/nb_{fig_id}.txt") as heuristic_cap_file:
            starter_alt = heuristic_cap_file.read()
        fig_caption_dict["heuristic"] = starter_alt
        # Encode the figure image as base64; the context manager closes the
        # underlying file handle as soon as the encoding is done.
        with Image.open(f"{alt_figs_path}/nb_{fig_id}.jpg") as pil_img:
            base64_img = pillow_image_to_base64_string(pil_img)

        print(starter_alt)
        # NOTE(review): the result keys are labelled "L4" while desc_level is 3
        # -- confirm the label is intended before relying on it downstream.
        # gpt4-turbo, image only
        gpt4_caption = get_openai_vision_response(OPENAI_API_KEY, desc_level_prompt, base64_img, **vision_kwargs)
        fig_caption_dict["gpt-4-turbo-L4"] = gpt4_caption
        # gpt4-turbo + starter alt
        starter_alt_prompt = get_desc_level_prompt(desc_level=DESC_LEVEL, starter_desc=starter_alt)
        gpt4_alt_caption = get_openai_vision_response(OPENAI_API_KEY, starter_alt_prompt, base64_img, **vision_kwargs)
        fig_caption_dict["gpt-4-turbo-alt-L4"] = gpt4_alt_caption

        fig_id_to_captions.append(fig_caption_dict)
        # Checkpoint after every figure so partial results survive an
        # interruption. The list of dicts is stored as a pickled object array,
        # so reading it back needs np.load(..., allow_pickle=True).
        np.save("./fig_id_captions_arr", fig_id_to_captions)
        n_passed += 1
    except Exception as e:
        # Best effort: report the failing figure and carry on with the rest.
        n_errors += 1
        print(f"{fig_id}: {type(e).__name__}: {e}")

print(f"Num passed: {n_passed}")
print(f"Num errors: {n_errors}")
# One row per successfully captioned figure.
combined_captions_df = pd.DataFrame(fig_id_to_captions)
combined_captions_df.to_json("./mpl_gallery_combined_captions.jsonl", orient='records', lines=True)

  0%|          | 0/204 [00:00<?, ?it/s]

230
Data table:

You are a helpful assistant that describes figures. Here are two example descriptions:
1. 'This is a vertical bar chart entitled 'COVID-19 mortality rate by age' that plots Mortality rate by Age. Mortality rate is plotted on the vertical y-axis from 0 to 15%. Age is plotted on the horizontal x-axis in bins: 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70-79, 80+. The highest COVID-19 mortality rate is in the 80+ age range, while the lowest mortality rate is in 10-19, 20-29, 30-39, sharing the same rate. COVID-19 mortality rate does not linearly correspond to the demographic age. The mortality rate increases with age, especially around 40-49 years and upwards. The mortality rate increases exponentially with older people.'
2. 'This is a line chart titled 'Big Tech Stock Prices' that plots price by date. The corporations include AAPL (Apple), AMZN (Amazon), GOOG (Google), IBM (IBM), and MSFT (Microsoft). The years are plotted on the horizontal x-axis from 2000 to 2010 with a

  0%|          | 1/204 [00:19<1:07:34, 19.97s/it]

This is a heat map with a 7x9 grid, possibly representing data values through color variations. The horizontal x-axis ranges from 0 to 8 and the vertical y-axis ranges from 0 to 7, both divided into integer increments. The colors range from dark blue to dark purple, representing lower values, and yellow to bright turquoise for higher values. 

Overlaid are six arrows, three in a teal color labeled "X" on the horizontal axis pointing to the right, "Y" on the vertical axis pointing upwards, and a pair crossed, one labeled "111" and the other "112", suggesting a three-dimensional orientation or additional data layer. Three other arrows, in red and labeled "A" and "B", point to specific colored squares, indicating a transition or comparison between those points. 

Text labeling two squares shows "20" on a yellow square and "30" on a turquoise, indicating higher values in those positions. There's no explicit title or detailed description of what the axes or colors quantitatively represent. 

  0%|          | 1/204 [00:26<1:27:59, 26.01s/it]


KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>