In [3]:
import os
import shutil
import re
import pickle
import json
from typing import Any, Dict, List
from dotenv import load_dotenv
from tqdm import tqdm
from openai import OpenAI

from openpmcvl.granular.pipeline.utils import load_dataset, save_jsonl

# Load variables from a local .env file. It must define OPENAI_API_KEY (used by
# the OpenAI client below) and may also define PMC_ROOT.
load_dotenv()

# Directory containing meta.jsonl; requests.jsonl and subcaptions.jsonl are written here too.
# Set PMC_ROOT in the environment / .env file, or edit the fallback value below.
PMC_ROOT = os.environ.get("PMC_ROOT", "set this directory")

client = OpenAI()

# **Subcaption Extraction**

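Extracts subfigure captions (subcaptions) from figure captions using GPT-4o via OpenAI's Batch API.

## Pipeline
1. Input: JSONL file of figure metadata (captions + IDs)
2. Generate Batch API requests (a single batch accepts at most 50k requests)
3. Submit the requests to OpenAI for batch processing
4. Retrieve the results: for each caption the model answers "NO" (no subfigure labels) or "YES" followed by one subcaption per subfigure
5. Save the results to a JSONL file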

In [None]:
# Prompt for subcaption extraction; the figure caption is substituted for %s.
# Changing the wording between batches would make results from different
# batches inconsistent with each other.
PROMPT = """
Subfigure labels are letters referring to individual subfigures within a larger figure.
This is a caption: "%s"
Check if the caption contains explicit subfigure label. 
If not, output "NO" and end the generation. 
If yes, output "YES", then generate the subcaption of the subfigures according to the caption. 
The output should use the template:
    YES
    Subfigure-A: ...
    Subfigure-B: ...
    ...
The label should be removed from subcaption.
""".strip()

# Replace with a real figure caption to sanity-check the prompt before building the batch.
caption = "Try sample caption!"


# Use the same model, temperature and token budget as the batch requests
# (max_tokens=2000), so this spot check is not truncated earlier than the batch
# output would be.
completion = client.chat.completions.create(
    model="gpt-4o-2024-08-06",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": PROMPT % caption},
    ],
    temperature=0,
    max_tokens=2000,
)

print(completion.choices[0].message.content)

In [117]:
def generate_api_request(
    custom_id,
    system_content,
    user_content,
    model="gpt-4o-2024-08-06",
    temperature=0,
    max_tokens=2000,
):
    """Build one request line for the OpenAI Batch API (chat completions endpoint).

    Args:
        custom_id: Identifier of the request; the Batch API echoes it back in the
            output so each response can be matched to its input record.
        system_content: Text of the system message.
        user_content: Text of the user message.
        model: Chat model name. Defaults to the model used throughout this notebook.
        temperature: Sampling temperature.
        max_tokens: Upper bound on the number of generated tokens per response.

    Returns:
        A dict ready to be serialized as one JSON line of a batch input file.
    """
    return {
        "custom_id": custom_id,
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": model,
            "messages": [
                {"role": "system", "content": system_content},
                {"role": "user", "content": user_content},
            ],
            "temperature": temperature,
            "max_tokens": max_tokens,
        },
    }


def create_prompt(caption):
    """Fill the subcaption-extraction prompt with one figure caption.

    Args:
        caption: Figure caption text, substituted for the ``%s`` placeholder
            of the module-level ``PROMPT``.

    Returns:
        The user-message string to send to the model.
    """
    template = PROMPT.strip()
    filled_prompt = template % caption
    return filled_prompt


def generate_jsonl(dataset, requests_file, skip=50000, max_requests=None, max_words=400):
    """Write one Batch API request per eligible caption to a JSONL file.

    Records are visited in order. The first ``skip`` records are ignored, and
    records whose caption is missing, not a string, or longer than
    ``max_words`` words are not turned into requests.

    Args:
        dataset: Iterable of metadata records. Each record is read with
            ``record.get("caption")`` and ``record["id"]`` (assumed to be dicts
            loaded from JSONL).
        requests_file: Path of the request JSONL file to (over)write.
        skip: Number of leading records to ignore. Defaults to 50000 because the
            first 50k records were already processed in an earlier batch; pass 0
            to start from the beginning.
        max_requests: Optional cap on the number of requests written. The Batch
            API accepts at most 50,000 requests per batch, so pass 50000 when the
            remaining records could exceed that. ``None`` means no cap.
        max_words: Captions with more whitespace-separated words than this are
            skipped.

    Returns:
        The number of request lines written.
    """
    written = 0

    with open(requests_file, "w", encoding="utf-8") as f:
        for index, data in enumerate(dataset):
            # Records before `skip` belong to a batch that was already processed.
            if index < skip:
                continue

            # Stop once the per-batch cap is reached.
            if max_requests is not None and written >= max_requests:
                break

            # Skip records without a usable caption instead of crashing mid-file.
            caption = data.get("caption")
            if not isinstance(caption, str):
                continue

            # Only process captions of at most `max_words` words.
            if len(caption.split()) > max_words:
                continue

            request = generate_api_request(
                custom_id=f"{data['id']}",
                system_content="You are a helpful assistant.",
                user_content=create_prompt(caption),
            )
            f.write(json.dumps(request) + "\n")
            written += 1

    return written

In [None]:
# Destination of the Batch API request file
requests_file = os.path.join(PMC_ROOT, "requests.jsonl")

# Caption metadata; records provide the `id` and `caption` fields read by generate_jsonl
dataset = load_dataset(os.path.join(PMC_ROOT, "meta.jsonl"))

# Write one request line per eligible caption
generate_jsonl(dataset, requests_file)

In [None]:
# Upload the request file for batch processing; the `with` block closes the
# file handle as soon as the upload call returns.
with open(requests_file, "rb") as requests_fh:
    batch_input_file = client.files.create(file=requests_fh, purpose="batch")
batch_input_file_id = batch_input_file.id

# Create a batch job over the uploaded requests. The job has up to 24 hours to
# finish, and a single batch may contain at most 50k requests.
batch = client.batches.create(
    input_file_id=batch_input_file_id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"description": "50k subcaptions"},
)

# Keep batch.id: it is needed to retrieve this batch's results later.
batch

In [None]:
# Run this cell once per submitted batch: set BATCH_ID to the id of the batch
# returned when it was created.
BATCH_ID = "batch_xxxxx"

# Check the status of the batch job
batch_status = client.batches.retrieve(BATCH_ID)
print(batch_status)

# Take the output file id from the batch itself so it cannot be mismatched with
# the batch being collected. Per the OpenAI Batch docs it is populated once the
# batch finishes (including expired batches with partial results).
if batch_status.output_file_id:
    file_response = client.files.content(batch_status.output_file_id)
    results_text = file_response.text
    # Make sure the next batch's results start on a new line.
    if results_text and not results_text.endswith("\n"):
        results_text += "\n"

    # Append rather than overwrite so results from earlier batches are kept.
    # NOTE: re-running this cell for the same batch appends its results again.
    with open(os.path.join(PMC_ROOT, "subcaptions.jsonl"), "a", encoding="utf-8") as f:
        f.write(results_text)
else:
    print(f"No output file yet for {BATCH_ID} (status: {batch_status.status}); nothing written.")