In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
from pprint import pprint
from collections import defaultdict
from matplotalt import *
from matplotalt_helpers import pillow_image_to_base64_string
from PIL import Image
from api_helpers import get_openai_vision_response

# Directory holding the figures to caption. Later cells read `nb_<fig_id>.txt`
# (heuristic caption) and `nb_<fig_id>.jpg` (figure image) from it.
alt_figs_path = "./alt_figs"

In [3]:
# Collect the ids of all figures that have a heuristic caption on disk.
# Caption files are named `nb_<fig_id>.txt`, the same pattern the captioning
# loop opens, so the id is the file name minus that prefix and extension.
CAPTION_PREFIX = "nb_"
CAPTION_SUFFIX = ".txt"

# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and therefore the saved outputs) reproducible between runs.
# Files without the `nb_` prefix are skipped rather than blindly sliced, and
# only the trailing ".txt" is removed so ids containing dots stay intact.
fig_ids = [
    fname[len(CAPTION_PREFIX):-len(CAPTION_SUFFIX)]
    for fname in sorted(os.listdir(alt_figs_path))
    if fname.startswith(CAPTION_PREFIX) and fname.endswith(CAPTION_SUFFIX)
]

print(len(fig_ids))

203


In [4]:
# For every figure id, gather three captions into one record:
#   * "heuristic"              - the caption text read from nb_<fig_id>.txt
#   * "gpt-4-turbo-L4-300"     - vision-model caption from the image alone
#   * "gpt-4-turbo-alt-L4-300" - vision-model caption with the heuristic text
#                                passed as a starter description
# Figures whose processing raises are reported and counted, not retried.
fig_id_to_captions = []  # one dict per successfully captioned figure
OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
model = "TURBO"
n_errors = 0
n_passed = 0
# NOTE(review): the result keys say "L4" while desc_level=3 is requested here.
# The keys are left untouched because they become output column names --
# confirm which of the two is intended.
desc_level_prompt = get_desc_level_prompt(desc_level=3)
for fig_id in tqdm(fig_ids):
    try:
        fig_caption_dict = {"figure_id": fig_id}
        # Heuristic caption stored next to the figure image.
        with open(f"{alt_figs_path}/nb_{fig_id}.txt") as heuristic_cap_file:
            starter_alt = heuristic_cap_file.read()
        fig_caption_dict["heuristic"] = starter_alt
        # Encode the figure image as base64. The context manager closes the
        # underlying file once it is encoded, so handles do not accumulate
        # over the loop.
        with Image.open(f"{alt_figs_path}/nb_{fig_id}.jpg") as pil_img:
            base64_img = pillow_image_to_base64_string(pil_img)

        # Caption from the image alone.
        gpt4_caption = get_openai_vision_response(
            OPENAI_API_KEY,
            desc_level_prompt,
            base64_img,
            model=model,
            use_azure=True,
            max_tokens=300,
            return_full_response=False,
        )
        fig_caption_dict["gpt-4-turbo-L4-300"] = gpt4_caption

        # Caption with the heuristic text supplied as a starter description.
        # starter_alt is always a string here: read() either returned one or
        # raised, in which case this figure was already skipped.
        starter_alt_prompt = get_desc_level_prompt(desc_level=3, starter_desc=starter_alt)
        gpt4_alt_caption = get_openai_vision_response(
            OPENAI_API_KEY,
            starter_alt_prompt,
            base64_img,
            model=model,
            use_azure=True,
            max_tokens=300,
            return_full_response=False,
        )
        fig_caption_dict["gpt-4-turbo-alt-L4-300"] = gpt4_alt_caption

        fig_id_to_captions.append(fig_caption_dict)
        # Checkpoint everything collected so far after each figure. np.save
        # appends ".npy" and pickles the list of dicts as an object array, so
        # loading it back requires allow_pickle=True.
        np.save("./fig_id_captions_arr2", fig_id_to_captions)
        n_passed += 1
    except Exception as e:
        # Best effort: report the failing figure and continue with the rest.
        print(fig_id)
        n_errors += 1
        print(e)
    finally:
        # NOTE(review): `plt` is not imported explicitly in the visible cells;
        # presumably it comes from `from matplotalt import *` -- confirm.
        plt.clf()

print(f"Num passed: {n_passed}")
print(f"Num errors: {n_errors}")
# One row per captioned figure, written as JSON Lines.
combined_captions_df = pd.DataFrame(fig_id_to_captions)
combined_captions_df.to_json("./mpl_gallery_combined_captions3.jsonl", orient='records', lines=True)

  0%|          | 0/203 [00:00<?, ?it/s]

 34%|███▍      | 70/203 [19:11<35:51, 16.17s/it]  

301
Error code: 400 - {'error': {'inner_error': {'code': 'ResponsibleAIPolicyViolation', 'content_filter_results': {'jailbreak': {'filtered': True, 'detected': True}}}, 'code': 'content_filter', 'message': "The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: \r\nhttps://go.microsoft.com/fwlink/?linkid=2198766.", 'param': 'prompt', 'type': None}}


 57%|█████▋    | 116/203 [31:56<21:41, 14.96s/it]

377
Error code: 400 - {'error': {'inner_error': {'code': 'ResponsibleAIPolicyViolation', 'content_filter_results': {'jailbreak': {'filtered': True, 'detected': True}}}, 'code': 'content_filter', 'message': "The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: \r\nhttps://go.microsoft.com/fwlink/?linkid=2198766.", 'param': 'prompt', 'type': None}}


100%|██████████| 203/203 [54:56<00:00, 16.24s/it]

Num passed: 201
Num errors: 2





<Figure size 640x480 with 0 Axes>

In [5]:
# Shuffle the rows of the combined captions file and write the result to a
# separate JSON Lines file.
# NOTE(review): this reads "mpl_gallery_combined_captions.jsonl", whereas the
# captioning cell writes "mpl_gallery_combined_captions3.jsonl" -- confirm this
# is the intended input; on a fresh run it may be stale or missing.
SHUFFLE_SEED = 0  # fixed seed so re-running the cell reproduces the same order
combined_captions_df = pd.read_json(
    "./mpl_gallery_combined_captions.jsonl", orient="records", lines=True
).sample(frac=1.0, random_state=SHUFFLE_SEED)
combined_captions_df.to_json("./mpl_gallery_combined_captions_shuffled.jsonl", orient="records", lines=True)