# Use gpt-3.5-turbo to combine the pix2text & easy_ocr outputs

In [None]:
# === kernel system ===
import os
import gc
import time
import json
import shutil
import traceback
import csv
from collections import defaultdict
from pathlib import Path
# === packages ===
import pandas as pd
from tqdm import tqdm
from PIL import Image
from openai import OpenAI
import re


# Load API keys

In [None]:
def load_keys(path):
    """Load ``KEY=VALUE`` pairs from a plain-text secrets file.

    Each usable line is split on its first ``=`` only, so a value may itself
    contain ``=``. Whitespace around both the key and the value is discarded,
    which means ``KEY = value`` and ``KEY=value`` yield the same entry.
    Blank lines, lines starting with ``#`` and lines without ``=`` are
    skipped.

    Args:
        path: Location of the key file.

    Returns:
        dict: Key name mapped to its value. When a key appears more than
        once, the last occurrence wins.

    Raises:
        OSError: If the file cannot be opened.
    """
    keys = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            name, value = line.split("=", 1)
            # Strip both sides so spaces around "=" cannot produce a key
            # such as "OPENAI_KEY " that later lookups would miss.
            keys[name.strip()] = value.strip()
    return keys

# Read the local secrets file once, then pick out the two credentials this
# notebook uses. A missing entry surfaces immediately as a KeyError.
keys = load_keys("/home/ne6131039/Desktop/Confidential_Key.txt")
openai_key, huggingface_write_key = keys["OPENAI_KEY"], keys["HUGGINGFACE_WRITE_KEY"]

# Merge the two types of OCR text and label the caption_type

### System prompt

In [None]:
# Prompt construction logic omitted for proprietary reasons
# Contact authors for detailed evaluation methodology
#
# NOTE: the string below is a description of the omitted prompt, not the
# prompt itself. It is passed unchanged as the "system" message of every
# gpt-3.5-turbo caption request in this notebook. The reply parser looks for
# lines beginning with "caption_content:" and "caption_type:", so the real
# prompt presumably asks for exactly that layout (confirm with the authors).
system_prompt = """
Creates prompt for GPT to process and classify scientific captions.
    
The prompt instructs GPT to:
- Combine OCR results from Pix2Text (LaTeX) and EasyOCR (plain text)
- Reconstruct coherent, readable scientific captions
- Classify content into predefined categories
- Apply proper scientific formatting (Unicode superscripts, preserve formulas)
- Output in structured format: caption_content + caption_type

Args:
    pix2text_output (str): LaTeX-formatted OCR result
    easyocr_output (str): Plain text OCR result
    
Returns:
    str: Formatted caption processing prompt
"""

In [None]:
# Initialize OpenAI client
client = OpenAI(api_key=openai_key)

# Constants
INPUT_CSV = "../caption_analysis/ocr_output.csv"
OUTPUT_CSV = "caption_analysis_output.csv"
BATCH_SIZE = 100  # NOTE(review): not referenced below; one request is sent per row

# Row index of INPUT_CSV at which processing starts
START_INDEX = 0


def _parse_caption_reply(reply):
    """Extract ``(caption_content, caption_type)`` from a model reply.

    The reply is scanned line by line. A line that begins (after optional
    leading whitespace) with ``caption_content:`` or ``caption_type:``
    supplies the corresponding field; only that leading marker is removed, so
    the same text appearing later in the line is preserved. If a marker
    occurs on several lines the last one wins, and a field whose marker never
    appears is returned as ``""``.
    """
    content_marker = "caption_content:"
    type_marker = "caption_type:"
    caption_content = ""
    caption_type = ""
    for raw_line in reply.splitlines():
        line = raw_line.strip()
        if line.startswith(content_marker):
            caption_content = line[len(content_marker):].strip()
        elif line.startswith(type_marker):
            caption_type = line[len(type_marker):].strip()
    return caption_content, caption_type


# Load input; NaN cells become "" so the OCR columns can be formatted as text
df = pd.read_csv(INPUT_CSV).fillna("")
total_rows = len(df)
# total_rows = 40000  # uncomment to cap the number of rows handled

# Resume support: images already present in OUTPUT_CSV are skipped.
# An existing but empty file is treated like a missing one, so it neither
# crashes read_csv nor ends up without a header row.
if os.path.exists(OUTPUT_CSV) and os.path.getsize(OUTPUT_CSV) > 0:
    output_df = pd.read_csv(OUTPUT_CSV)
    processed_images = set(output_df["image_name"])
else:
    output_df = pd.DataFrame(columns=["image_name", "caption_content", "caption_type"])
    processed_images = set()

# Process the rows one at a time, appending each result to OUTPUT_CSV
for idx in range(START_INDEX, total_rows):
    row = df.iloc[idx]
    image_name = row["image_name"]

    # Skip if already processed (in an earlier run or earlier in this one)
    if image_name in processed_images:
        continue

    print(f" Processing {idx+1}/{total_rows}: {image_name}")

    # str() guards against a column that pandas parsed as numeric, which
    # would otherwise raise AttributeError on .strip() outside the try block.
    pix2text_text = str(row["pix2text_ocr"]).strip()
    easyocr_text = str(row["easyocr_ocr"]).strip()

    combined_text = f"""
[Pix2Text OCR]:
{pix2text_text}

[EasyOCR OCR]:
{easyocr_text}
"""

    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": combined_text}
            ],
            temperature=0.2,
            max_tokens=4000
        )

        reply = response.choices[0].message.content.strip()
        caption_content, caption_type = _parse_caption_reply(reply)

    except Exception as e:
        # Best effort: record the failure in the output row and keep going.
        # NOTE: such a row counts as processed, so a later run will not retry it.
        caption_content = f"[ERROR] {e}"
        caption_type = "unknown"
        time.sleep(1)

    # Save result
    new_row = pd.DataFrame([{
        "image_name": image_name,
        "caption_content": caption_content,
        "caption_type": caption_type
    }])

    write_header = not os.path.exists(OUTPUT_CSV) or os.path.getsize(OUTPUT_CSV) == 0
    new_row.to_csv(OUTPUT_CSV, mode="a", header=write_header, index=False)

    # Remember the image so a duplicated input row is not sent a second time
    processed_images.add(image_name)

    # Optional delay to avoid rate limits
    time.sleep(0.5)

print ("Processing Done")


# Image data distillation using GPT-4o to build the QA-pairs dataset

In [None]:
# ====== PROMPT definition ======
def build_system_prompt():
    """Return the system message text for the TEM VQA generation requests.

    The real prompt construction is omitted for proprietary reasons (contact
    the authors for the detailed evaluation methodology); the returned string
    is a description of what that prompt instructs the model to do.

    Returns:
        str: The placeholder prompt text, including its original indentation.
    """
    prompt_text = """
            Constructs prompt for GPT to generate TEM image VQA datasets.
    
            The prompt instructs GPT to:
            - Act as a scientific assistant for TEM image analysis
            - Generate question-answer pairs by task type for each sub-image:
              * Classification tasks: Imaging modality, image type identification
              * Recognition tasks: Observable features and measurements
              * Analysis tasks: Visual reasoning and interpretations
              * Description tasks: Comprehensive image summaries
            - Handle empty/blank images appropriately
            - Use Chain-of-Thought process for question generation
            - Return structured JSON output with specified format
            - Follow scientific accuracy guidelines (visual evidence only)
            
            Key constraints:
            - Don't use parent captions or labels in answers
            - Avoid material-specific names unless visible
            - Treat each sub-image independently
            - Generate multiple QA pairs per task type for diversity
            
            Args:
                caption (str): Parent figure caption (context only)
                sub_images (list): List of sub-image filenames
                predicted_labels (list): Predicted labels (reference only)
                
            Returns:
                str: Formatted VQA generation prompt
        """
    return prompt_text

In [None]:
# ====== OpenAI client used for the GPT-4o requests ======
client = OpenAI(api_key=openai_key)

# ====== Root URL under which the dataset images are hosted on Hugging Face ======
IMAGE_BASE_URL = "https://huggingface.co/datasets/LabSmart/TEM_Dataset/resolve/main"

# ====== Input / output file names ======
# NOTE: OUTPUT_CSV is rebound here; from this cell on it names the reply log.
OUTPUT_CSV = "gpt_reply_log.csv"
CAPTION_CSV = "caption_analysis_output.csv"
TEM_CSV = "balanced_tem_data.csv"
OUTPUT_JSONL = "distilled_llava_vqa.jsonl"
FAILED_LIST = "failed_images.txt"
PDF_BATCH_LOG = "pdf_batch_log.csv"
TEM_BATCH_LOG = "tem_batch_log.csv"

# ====== Read the caption table, the TEM table and both batch logs ======
df_caption = pd.read_csv(CAPTION_CSV)
df_tem = pd.read_csv(TEM_CSV)
df_pdf_batch = pd.read_csv(PDF_BATCH_LOG)
df_tem_batch = pd.read_csv(TEM_BATCH_LOG)

# ====== Lookup tables: image filename -> batch folder ======
parent_to_batch = {
    filename: folder
    for filename, folder in zip(df_pdf_batch["filename"], df_pdf_batch["batch_folder"])
}
sub_to_batch = {
    filename: folder
    for filename, folder in zip(df_tem_batch["filename"], df_tem_batch["batch_folder"])
}

# ====== Keep only captions of the two relevant types ======
df_caption = df_caption.loc[
    lambda frame: frame["caption_type"].isin(["nanomaterials", "crystallography"])
]

# ====== Attach captions to the TEM rows ======
# Left join: every TEM row is kept; rows without a matching caption get NaN.
df = pd.merge(
    df_tem,
    df_caption[["image_name", "caption_content"]],
    how="left",
    left_on="parent_image",
    right_on="image_name",
)
# The join key from the caption side duplicates parent_image, so discard it.
df = df.drop(columns=["image_name"])
assert "caption_content" in df.columns, "fail to pair caption"

# ====== One group of sub-image rows per parent image, plus run counters ======
grouped = df.groupby("parent_image")
MAX_GROUPS, total_calls, error_images = 22000, 0, []

In [None]:
# ====== Process each parent image and its sub-images ======
# Only groups whose position i satisfies START_GROUP <= i < MAX_GROUPS are sent.
# NOTE(review): 20000 looks like a resume offset from an earlier run — confirm.
START_GROUP = 20000


def _request_vqa_reply(messages, retries=1, retry_delay=30):
    """Send ``messages`` to gpt-4o and return the stripped reply text.

    Args:
        messages: Chat messages in the OpenAI chat-completions format.
        retries: Number of additional attempts after the first failure.
        retry_delay: Seconds to wait before each additional attempt.

    Returns:
        str: Content of the first choice with surrounding whitespace removed.

    Raises:
        Exception: The error from the final attempt when every attempt fails.
    """
    for attempt in range(retries + 1):
        try:
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=0.7,
            )
            return response.choices[0].message.content.strip()
        except Exception as exc:
            # Exception (not a bare except) so Ctrl-C still stops the run.
            if attempt == retries:
                raise
            print(f"GPT request failed ({exc}); retrying in {retry_delay}s")
            time.sleep(retry_delay)


with open(OUTPUT_CSV, "a", encoding="utf-8", newline="") as f_log, \
        open(FAILED_LIST, "a", encoding="utf-8") as f_fail:
    writer = csv.DictWriter(f_log, fieldnames=["parent_image", "sub_image_ids", "reply"])
    # Append mode starts at the end of the file, so offset 0 means it is empty.
    if f_log.tell() == 0:
        writer.writeheader()

    for i, (parent_image, group) in enumerate(grouped):
        if i < START_GROUP:
            continue
        if i >= MAX_GROUPS:
            break

        print(f"Dealing with {i+1} th parent image: {parent_image}\n")
        try:
            # All rows of a group share the parent's caption; NaN means the
            # parent had no caption attached, which is sent as an empty string.
            caption = group.iloc[0]["caption_content"]
            caption = "" if pd.isna(caption) else str(caption)

            parent_batch = parent_to_batch.get(parent_image)
            if not parent_batch:
                raise ValueError(f"Missing parent batch for: {parent_image}")
            parent_url = f"{IMAGE_BASE_URL}/{parent_batch}/{parent_image}"
            parent_image_input = {"type": "image_url", "image_url": {"url": parent_url}}

            user_prompts = []
            sub_image_ids = []
            # Alternating "Filename: ..." text parts and image parts, so each
            # image is preceded by an explicit reference to its name
            sub_image_with_text = []

            for _, row in group.iterrows():
                sub_image = row["sub_image"]
                tem_type = row["TEM_type"]
                sub_batch = sub_to_batch.get(sub_image)
                if not sub_batch:
                    # A sub-image without a known batch folder has no URL:
                    # record it and carry on with the remaining sub-images.
                    error_msg = f"{sub_image} | URL error: Missing sub batch for: {sub_image}"
                    print(error_msg)
                    f_fail.write(error_msg + "\n")
                    error_images.append(sub_image)
                    continue

                sub_url = f"{IMAGE_BASE_URL}/{sub_batch}/{sub_image}"
                user_prompts.append(f"Sub-image: {sub_image} (predicted TEM type: {tem_type})")
                sub_image_ids.append(sub_image)
                sub_image_with_text.append({"type": "text", "text": f"Filename: {sub_image}"})
                sub_image_with_text.append({"type": "image_url", "image_url": {"url": sub_url}})

            if not sub_image_ids:
                print(f" Skip：{parent_image} , no sub-images")
                f_fail.write(f"{parent_image} | No sub-images\n")
                continue

            # Construct user prompt clearly identifying all image filenames
            user_prompt = (
                f"The following images will be shown to you:\n"
                f"Parent image filename: {parent_image}\n"
                f"Caption: {caption.strip()}\n\n"
                + "\n".join(user_prompts) +
                "\n\nPlease generate multiple high-quality VQA (Visual Question Answering) pairs for each sub-image. "
                "Each VQA pair should include a meaningful question and answer that is visually grounded in the image content. "
                "Return your answer as a JSON list with the fields: sub_image, question, answer, level."
            )

            messages = [
                {"role": "system", "content": build_system_prompt()},
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt},
                        {"type": "text", "text": f"Filename: {parent_image}"},
                        parent_image_input,
                        *sub_image_with_text
                    ]
                }
            ]

            # Call GPT model (one retry). A request that still fails keeps the
            # empty-reply log row, and is also written to the failed list so
            # it can be found and re-run.
            try:
                reply = _request_vqa_reply(messages)
            except Exception as api_error:
                reply = ""
                print("❌", f"{parent_image} | GPT request failed after retry: {api_error}")
                for sub in sub_image_ids:
                    f_fail.write(f"{sub} | GPT error: {api_error}\n")
                    error_images.append(sub)

            total_calls += 1

            # Log result
            writer.writerow({
                "parent_image": parent_image,
                "sub_image_ids": ";".join(sub_image_ids),
                "reply": reply
            })

        except Exception as e:
            error_msg = f"{parent_image} | GPT API error: {e}"
            print("❌", error_msg)
            traceback.print_exc()
            # Iterate the column itself: one value per sub-image row.
            for sub in group["sub_image"]:
                f_fail.write(f"{sub} | GPT error: {e}\n")
                error_images.append(sub)
            time.sleep(1)

        finally:
            # Push both files to disk after every group so an interrupted
            # long run keeps everything written so far.
            f_log.flush()
            f_fail.flush()

        # Per-group request payloads are rebound on the next iteration; collect
        # eagerly to keep memory flat over thousands of groups.
        gc.collect()

# ====== Final report ======
print(f"\n✅ Finish first phase，sent {total_calls}  times GPT request。")
print(f"❌ Failed images total {len(error_images)} ，record to {FAILED_LIST}。")