# Use gpt-3.5-turbo to merge the Pix2Text & EasyOCR outputs

In [None]:
# === kernel system ===
import os
import gc
import time
import json
import shutil
import traceback
import csv
from collections import defaultdict
from pathlib import Path
# === packages ===
import pandas as pd
from tqdm import tqdm
from PIL import Image
from openai import OpenAI
import re


# Load API keys

In [None]:
def load_keys(path):
    """Parse a simple ``KEY=VALUE`` credentials file into a dict.

    Blank lines, lines without ``=``, and lines starting with ``#`` are
    ignored. Whitespace around both the key and the value is stripped, so
    ``OPENAI_KEY = sk-...`` and ``OPENAI_KEY=sk-...`` produce the same entry.
    Only the first ``=`` splits a line, so values may themselves contain ``=``.
    A later duplicate key overwrites an earlier one.

    Args:
        path: Path to the credentials file (read as UTF-8).

    Returns:
        dict mapping key names to their string values.
    """
    keys = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            # Skip blanks and comments; a commented-out "KEY=..." must not
            # become a real entry.
            if not line or line.startswith("#") or "=" not in line:
                continue
            k, v = line.split("=", 1)
            keys[k.strip()] = v.strip()
    return keys

# Location of the KEY=VALUE credentials file. The default is the author's
# machine; set CONFIDENTIAL_KEY_FILE to point elsewhere without editing code.
KEY_FILE = os.environ.get(
    "CONFIDENTIAL_KEY_FILE", "/home/ne6131039/Desktop/Confidential_Key.txt"
)
keys = load_keys(KEY_FILE)
openai_key = keys['OPENAI_KEY']
huggingface_write_key = keys['HUGGINGFACE_WRITE_KEY']

# Merge the two OCR texts and label each caption's caption_type

### System prompt

In [None]:
# System prompt for the caption-merging step. The reply is expected as two
# lines prefixed "caption_content:" and "caption_type:".
# NOTE: no trailing backslashes inside this string -- in a normal (non-raw)
# triple-quoted literal a "\" at end of line is a line continuation and would
# glue two bullet points together.
system_prompt = """
You are a scientific caption corrector and classifier.

You are given two OCR outputs from a scientific figure caption:
- One from Pix2Text (in LaTeX-style format)
- One from EasyOCR (standard plain text)

Your job is to:
1. Reconstruct the most accurate and readable caption in proper English.
2. Classify the content into one of the following categories:
   ["nanomaterials", "bioimaging", "crystallography", "device structures", "other"]

--- Output format ---
caption_content: <your refined caption>
caption_type: <category>

Note:
- Always use Unicode superscript formatting for units (e.g., m⁻², cm⁻³).
- Preserve chemical formulas as they appear (e.g., CsPbBr₃, LaNi₀.₅Cu₀.₁O₃) without expanding them into full names.
"""

In [None]:
# Initialize OpenAI client
client = OpenAI(api_key=openai_key)

# Constants
INPUT_CSV = "../caption_analysis/ocr_output.csv"
OUTPUT_CSV = "caption_analysis_output.csv"
BATCH_SIZE = 100  # NOTE(review): unused below -- rows are sent one request at a time.
# Prefix written into caption_content when the API call fails. Rows carrying it
# are not treated as processed, so they are retried on the next run.
ERROR_PREFIX = "[ERROR]"

# Initialize index
START_INDEX = 0


def _parse_caption_reply(reply):
    """Extract ``(caption_content, caption_type)`` from a model reply.

    The prompt asks for ``caption_content: ...`` and ``caption_type: ...``
    lines. Prefix matching ignores case and leading whitespace. A caption
    that wraps onto several lines is joined with single spaces instead of
    being truncated to its first line. ``caption_type`` is lower-cased and
    stripped of surrounding quotes, brackets, backticks and trailing periods
    so it matches the category names used for filtering later.

    Returns empty strings for fields that are not found.
    """
    content_parts = []
    caption_type = ""
    in_content = False
    for raw_line in reply.splitlines():
        line = raw_line.strip()
        lowered = line.lower()
        if lowered.startswith("caption_content:"):
            content_parts = [line[len("caption_content:"):].strip()]
            in_content = True
        elif lowered.startswith("caption_type:"):
            caption_type = line[len("caption_type:"):].strip()
            in_content = False
        elif in_content and line:
            # Continuation line of a multi-line caption.
            content_parts.append(line)
    caption_content = " ".join(part for part in content_parts if part)
    caption_type = caption_type.strip(" \"'[]`.").lower()
    return caption_content, caption_type


# Load input
df = pd.read_csv(INPUT_CSV).fillna("")
total_rows = len(df)
# total_rows = 40000

# Images with a successful row in OUTPUT_CSV are skipped. Rows written for a
# failed API call (caption_content starting with ERROR_PREFIX) do not count,
# so transient failures are retried; the retry appends a new row and the old
# error row (caption_type "unknown") stays in the file.
if os.path.exists(OUTPUT_CSV):
    output_df = pd.read_csv(OUTPUT_CSV)
    ok_rows = ~output_df["caption_content"].astype(str).str.startswith(ERROR_PREFIX)
    processed_images = set(output_df.loc[ok_rows, "image_name"])
else:
    output_df = pd.DataFrame(columns=["image_name", "caption_content", "caption_type"])
    processed_images = set()

# Process rows one at a time; each result is appended to OUTPUT_CSV right away
# so an interrupted run can be resumed.
for idx in range(START_INDEX, total_rows):
    row = df.iloc[idx]
    image_name = row["image_name"]

    # Skip if already processed (also covers duplicate input rows in this run)
    if image_name in processed_images:
        continue

    print(f" Processing {idx+1}/{total_rows}: {image_name}")

    combined_text = f"""
[Pix2Text OCR]:
{str(row['pix2text_ocr']).strip()}

[EasyOCR OCR]:
{str(row['easyocr_ocr']).strip()}
"""

    try:
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": combined_text}
            ],
            temperature=0.2,
            max_tokens=4000
        )
        # A None content raises here and is recorded as an error row.
        reply = response.choices[0].message.content.strip()
        caption_content, caption_type = _parse_caption_reply(reply)
    except Exception as e:
        caption_content = f"{ERROR_PREFIX} {e}"
        caption_type = "unknown"
        time.sleep(1)
    else:
        processed_images.add(image_name)

    # Save result
    new_row = pd.DataFrame([{
        "image_name": image_name,
        "caption_content": caption_content,
        "caption_type": caption_type
    }])

    new_row.to_csv(OUTPUT_CSV, mode="a", header=not os.path.exists(OUTPUT_CSV), index=False)

    # Optional delay to avoid rate limits
    time.sleep(0.5)

print("Processing Done")


# Select nanomaterials and crystallography images
> Remove bioimaging data and balance the dataset across TEM types

In [None]:
# Inputs: GPT caption labels per parent figure, and metadata for every TEM
# sub-image cropped from those figures.
caption_df = pd.read_csv("caption_analysis_output.csv")
tem_df = pd.read_csv("../tem_images_description.csv")

# Parent figures whose caption was labelled with one of the wanted types.
target_types = ["nanomaterials","crystallography"]
valid_parents = caption_df.loc[
    caption_df["caption_type"].isin(target_types), "image_name"
].unique()

# Sub-images that were cut from one of those parent figures.
matched_df = tem_df.loc[tem_df["parent_image"].isin(valid_parents)]


def _cap_group(group, limit=10000, seed=42):
    """Down-sample one TEM_type group to `limit` rows (seeded); smaller groups pass through unchanged."""
    if len(group) < limit:
        return group
    return group.sample(n=limit, random_state=seed)


# At most 10,000 sub-images per TEM type.
balanced_subs = matched_df.groupby("TEM_type", group_keys=False).apply(_cap_group)

print("number of every TEM_type ：")
print(balanced_subs["TEM_type"].value_counts())

selected_parents = balanced_subs["parent_image"].nunique()
print(f"corresponding parent_image numbers：{selected_parents}")

balanced_subs.to_csv("balanced_tem_data.csv", index=False)


In [None]:
# Copy the balanced selection held in `balanced_subs` (produced by the
# balancing cell) into a standalone test dataset:
#   - parent figures: .../LLaVA Dataset/PDF_images -> .../Test_Dataset/PDF_images
#   - TEM sub-images: .../LLaVA Dataset/TEM_images -> .../Test_Dataset/TEM_images


# Source paths
pdf_src = "/home/ne6131039/Desktop/TEM_DATAS/LLaVA Dataset/PDF_images"
tem_src = "/home/ne6131039/Desktop/TEM_DATAS/LLaVA Dataset/TEM_images"

# Destination paths
pdf_dst = "/home/ne6131039/Desktop/Test_Dataset/PDF_images"
tem_dst = "/home/ne6131039/Desktop/Test_Dataset/TEM_images"

os.makedirs(pdf_dst, exist_ok=True)
os.makedirs(tem_dst, exist_ok=True)

# Extract selected parent and sub image names from balanced dataset
selected_parent_images = balanced_subs["parent_image"].unique()
selected_sub_images = balanced_subs["sub_image"].tolist()


def _copy_images(names, src_dir, dst_dir, desc, label):
    """Copy each file in `names` from src_dir to dst_dir with shutil.copy2.

    Missing source files are reported as "<label> not found: <path>" and
    skipped. Returns the number of files actually copied.
    """
    copied = 0
    for name in tqdm(names, desc=desc):
        src = os.path.join(src_dir, name)
        if not os.path.exists(src):
            print(f"{label} not found: {src}")
            continue
        shutil.copy2(src, os.path.join(dst_dir, name))
        copied += 1
    return copied


n_pdf_copied = _copy_images(selected_parent_images, pdf_src, pdf_dst, "Copying parent", "PDF")
n_tem_copied = _copy_images(selected_sub_images, tem_src, tem_dst, "Copying subs", "TEM")

print(
    f"Copied {n_pdf_copied}/{len(selected_parent_images)} parent images and "
    f"{n_tem_copied}/{len(selected_sub_images)} sub-images."
)
print("Balanced PDF & TEM images have been successfully copied.")



# Prepare the TEM dataset for upload to Hugging Face
> Split the TEM images into fixed-size batch folders; the uploaded images serve as image inputs for the next stage

In [None]:
# ==== Settings ====
SOURCE_DIR = "/home/ne6131039/Desktop/Test_Dataset/TEM_images"
OUTPUT_DIR = "/home/ne6131039/Desktop/Batch_upload_images/TEM_images"
CSV_PATH = "tem_batch_log.csv"
BATCH_SIZE = 1000

os.makedirs(OUTPUT_DIR, exist_ok=True)

# Image files in SOURCE_DIR (extension matched case-insensitively), sorted so
# the batch assignment is deterministic between runs.
valid_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff'}
image_files = [
    name
    for name in sorted(os.listdir(SOURCE_DIR))
    if Path(name).suffix.lower() in valid_exts
]

# Copy each image into TEM_Batch_NNN (1-based, BATCH_SIZE images per folder)
# and remember which folder it landed in.
batch_log = []
for position, filename in enumerate(image_files):
    folder_name = f"TEM_Batch_{position // BATCH_SIZE + 1:03d}"
    folder_path = os.path.join(OUTPUT_DIR, folder_name)
    os.makedirs(folder_path, exist_ok=True)
    shutil.copy2(
        os.path.join(SOURCE_DIR, filename),
        os.path.join(folder_path, filename),
    )
    batch_log.append({"filename": filename, "batch_folder": folder_name})

# filename -> batch_folder mapping, consumed when building Hugging Face URLs.
with open(CSV_PATH, "w", newline='', encoding="utf-8") as log_file:
    log_writer = csv.DictWriter(log_file, fieldnames=["filename", "batch_folder"])
    log_writer.writeheader()
    for entry in batch_log:
        log_writer.writerow(entry)

print(f"Finish：dealed {len(image_files)} images，CSV record in：{CSV_PATH}")

# Distill the image data with GPT-4o into a QA-pair dataset

In [None]:
# ====== PROMPT definition ======
def build_system_prompt():
    """Return the system prompt for the GPT-4o TEM VQA distillation step.

    The text is written indented to line up with the code and is passed
    through ``textwrap.dedent`` and ``strip`` so the model receives flush-left
    text. Indented lines cost tokens, and in Markdown 4+ leading spaces turn
    headings such as ``### Step 1`` into code blocks.

    The ``\\n\\n`` in the Level 4 section is intentionally a literal
    backslash-n pair: it tells the model how to separate paragraphs inside a
    JSON string value.
    """
    import textwrap  # local import keeps this cell self-contained

    return textwrap.dedent(
        """
        You are a scientific assistant constructing a visual question-answering (VQA) dataset based on Transmission Electron Microscopy (TEM) sub-images.

        You will be provided with:
        - A caption describing the parent figure (for background context only; do not use it in answers)
        - Several TEM sub-images cropped from the parent figure
        - Predicted image type labels for each sub-image (for reference only; do not rely on them when answering)

        ---

        Your task is to generate high-quality question–answer (QA) pairs for each individual TEM sub-image, following the rules below.

        ### Step 1: Handle Blank or Empty Images

        First, check whether the sub-image contains any meaningful visual structure:
        - Are there visible features such as particles, lattice fringes, contrast zones, porosity, or surface textures?
        - Or is the image blank, blurry, purely noisy, or lacks discernible structures?

        → If the image appears blank or uninformative, return only one Level 0 QA pair for it (as one element of the output JSON array):

        {
          "sub_image": "<filename>",
          "question": "What can be seen in this image?",
          "answer": "The image appears empty or lacks visible structures.",
          "level": 0
        }

        Then skip further QA generation for this image.

        ---

        ### Step 2: QA Generation

        If the image contains observable structures, you must generate the following question–answer (QA) pairs, divided into:

        - Level 1 (several questions): Basic classification
        - Level 2 (several questions): Direct visual identification
        - Level 3 (several questions): Visual reasoning
        - 1 total description (Level 4): A comprehensive 3-paragraph summary

        #### Level 1 – Basic Classification (several QA)
            Focus on identifying the image modality and type based on visible characteristics.
            Do not rely on metadata or provided labels.
        Example topics:
            - Modality: CTEM, HRTEM, STEM, diffraction
            - Type: nanoparticle, lattice image, porous film, diffraction pattern
        Example:
            Q: What is the imaging modality used in this image?
            A: The visible lattice fringes suggest this is an HRTEM image.

        #### Level 2 – Direct Visual Identification (several QA)
            Focus only on clearly observable features, no reasoning.
            Describe what you can count, measure, or visually confirm.
        Example topics:
            - Particle count
            - Shape (e.g., spherical, rod-like)
            - Contrast, porosity, texture
            - Characteristic size (~100 nm)
            - Spatial arrangement (aligned vs random)
        Example:
            Q: How many particles are visible in the image?
            A: Approximately 20 particles are visible across the field of view.

        #### Level 3 – Visual Reasoning (several QA)
            Use interpretation based on visual evidence, not speculation.
            Infer implications or behaviors from observed features.
        Example topics:
            - Thickness estimation based on contrast
            - Particle density or uniformity
            - Self-assembly indicators
            - Implications for material structure or properties
        Example:
            Q: Does the distribution of particles suggest uniformity?
            A: The particles appear unevenly spaced, indicating non-uniform distribution and potential aggregation.

        #### Level 4 – Full Image Description (1 QA)
            Summarize the image in 3 paragraphs, each reflecting insights from one level:
            Paragraph 1: Basic classification (Level 1)
            Paragraph 2: Visual identification (Level 2)
            Paragraph 3: Reasoned interpretation (Level 3)
            Use academic language and separate paragraphs with \\n\\n.
        Example:
            Question: Describe this image
            Answer:
            The image appears to be acquired using high-resolution transmission electron microscopy (HRTEM), as suggested by the presence of clear lattice fringes and high spatial resolution.

            Approximately 15 nearly spherical particles are distributed across the field. The contrast is sharp, and the particle sizes are relatively consistent, estimated to be under 100 nm. No obvious porosity is visible.

            The slight variation in contrast and particle overlap suggests differences in local thickness. The loosely clustered arrangement of particles may indicate weak self-assembly or aggregation effects during sample preparation.
        ---

        ### Step 3: Use a Chain-of-Thought (CoT) process

        - Before selecting final QA items, you must first brainstorm at least 5 distinct candidate questions per level. Avoid overly vague or generic phrasing (e.g., "What is the feature of this image?").

        Then, select the best few per level based on:
        - Topical diversity (covering different visual features)
        - Clarity and precision of question and answer
        - Scientific relevance and reasoning depth

        ### Step 4: Output Format

        Return a **JSON list** of all QA pairs. Each item must include:
        - `sub_image`: the filename of the sub-image
        - `question`: a question that is answerable solely from the image
        - `answer`: an accurate answer based on visual observation
        - `level`: integer (0 = empty image, 1 = classification, 2 = recognition, 3 = reasoning, 4 = summary)

        ### Important Rules:
        - Do NOT use parent captions or predicted labels in answers (they are for context only).
        - Do NOT mention material-specific names (e.g., Cu, Au) unless they are visibly labeled in the image.
        - Use general scientific terms such as "the material", "the nanoparticles", or "the structure".
        - Each sub-image must be treated independently. Do NOT compare it to other images.
        - All questions and answers must be **answerable based only on the visual content** of the sub-image.
        - Return only a valid JSON array of objects. No explanation, no formatting, no markdown.
        - Do not include any headings, explanations, markdown, or extra text.
        - Return a valid raw JSON array only. Do not wrap in triple quotes or code fences.
        """
    ).strip()

In [None]:
# ====== Initialize OpenAI client ======
client = OpenAI(api_key=openai_key)

# ====== Base URL for image access on Hugging Face ======
IMAGE_BASE_URL = "https://huggingface.co/datasets/LabSmart/TEM_Dataset/resolve/main"

# ====== File paths ======
OUTPUT_CSV = "gpt_reply_log.csv"
CAPTION_CSV = "caption_analysis_output.csv"
TEM_CSV = "balanced_tem_data.csv"
OUTPUT_JSONL = "distilled_llava_vqa.jsonl"
FAILED_LIST = "failed_images.txt"
PDF_BATCH_LOG = "pdf_batch_log.csv"
TEM_BATCH_LOG = "tem_batch_log.csv"

# ====== Load metadata and batch logs ======
df_caption = pd.read_csv(CAPTION_CSV)
df_tem = pd.read_csv(TEM_CSV)
df_pdf_batch = pd.read_csv(PDF_BATCH_LOG)
df_tem_batch = pd.read_csv(TEM_BATCH_LOG)

# ====== Map image filename to batch folder ======
parent_to_batch = dict(zip(df_pdf_batch["filename"], df_pdf_batch["batch_folder"]))
sub_to_batch = dict(zip(df_tem_batch["filename"], df_tem_batch["batch_folder"]))

# ====== Filter relevant captions ======
df_caption = df_caption[df_caption["caption_type"].isin(["nanomaterials", "crystallography"])]
# CAPTION_CSV is built by appending one row per processed image, so the same
# image_name can occur more than once. Keep the last row per image so the
# merge stays many-to-one: duplicates would multiply sub-image rows and send
# the same sub-image to GPT more than once.
df_caption = df_caption.drop_duplicates(subset="image_name", keep="last")
df = df_tem.merge(
    df_caption[["image_name", "caption_content"]],
    left_on="parent_image",
    right_on="image_name",
    how="left",
    validate="many_to_one",
)
df.drop(columns=["image_name"], inplace=True)
assert "caption_content" in df.columns, "fail to pair caption"
# The column always exists after the merge; the real failure mode is a parent
# image with no caption row (NaN), which is later sent with an empty caption.
missing_captions = int(df["caption_content"].isna().sum())
if missing_captions:
    print(f"Warning: {missing_captions} sub-images have no matching caption; an empty caption will be used.")

# ====== Group sub-images by their parent image ======
grouped = df.groupby("parent_image")
MAX_GROUPS = 22000
total_calls = 0
error_images = []

In [None]:
# ====== Process each parent image and its sub-images ======
# Parent groups are visited in groupby order. Indices below START_GROUP are
# skipped and MAX_GROUPS (from the setup cell) is the exclusive upper bound,
# so a long run can be done in slices by moving the two bounds.
START_GROUP = 20000
RETRY_WAIT_SECONDS = 30


def _request_reply(messages, retry_wait=RETRY_WAIT_SECONDS):
    """Send `messages` to gpt-4o, retrying once after `retry_wait` seconds.

    Returns the stripped reply text. If both attempts fail (API error, or a
    missing/blank content field), the last exception is re-raised.
    KeyboardInterrupt is not caught, so the run can still be stopped.
    """
    last_error = None
    for attempt in range(2):
        if attempt:
            time.sleep(retry_wait)
        try:
            response = client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=0.7,
            )
            content = response.choices[0].message.content
            if not content or not content.strip():
                raise ValueError("empty reply from model")
            return content.strip()
        except Exception as e:
            last_error = e
    raise last_error


def _record_failures(f_fail, sub_images, error):
    """Write one "<sub> | GPT error: <error>" line per sub-image and add it to error_images."""
    for sub in sub_images:
        f_fail.write(f"{sub} | GPT error: {error}\n")
        error_images.append(sub)


# The VQA system prompt is identical for every request; build it once.
vqa_system_prompt = build_system_prompt()

with open(OUTPUT_CSV, "a", encoding="utf-8", newline="") as f_log, \
        open(FAILED_LIST, "a", encoding="utf-8") as f_fail:
    writer = csv.DictWriter(f_log, fieldnames=["parent_image", "sub_image_ids", "reply"])
    if f_log.tell() == 0:
        writer.writeheader()

    for i, (parent_image, group) in enumerate(grouped):
        if i < START_GROUP:
            continue
        if i >= MAX_GROUPS:
            break

        print(f"Dealing with {i+1} th parent image: {parent_image}\n")
        try:
            caption = group.iloc[0]["caption_content"]
            parent_batch = parent_to_batch.get(parent_image)
            if not parent_batch:
                raise ValueError(f"Missing parent batch for: {parent_image}")
            parent_url = f"{IMAGE_BASE_URL}/{parent_batch}/{parent_image}"
            parent_image_input = {"type": "image_url", "image_url": {"url": parent_url}}

            user_prompts = []    # one descriptive text line per sub-image
            sub_image_ids = []   # sub-images actually included in the request
            # Interleaved [filename text, image] parts so the model can tie
            # each image to its filename.
            sub_image_with_text = []

            for _, row in group.iterrows():
                sub_image = row["sub_image"]
                tem_type = row["TEM_type"]
                sub_batch = sub_to_batch.get(sub_image)
                if not sub_batch:
                    error_msg = f"{sub_image} | URL error: Missing sub batch for: {sub_image}"
                    print(error_msg)
                    f_fail.write(error_msg + "\n")
                    error_images.append(sub_image)
                    continue
                sub_url = f"{IMAGE_BASE_URL}/{sub_batch}/{sub_image}"
                user_prompts.append(f"Sub-image: {sub_image} (predicted TEM type: {tem_type})")
                sub_image_ids.append(sub_image)
                sub_image_with_text.append({"type": "text", "text": f"Filename: {sub_image}"})
                sub_image_with_text.append({"type": "image_url", "image_url": {"url": sub_url}})

            if not sub_image_ids:
                print(f" Skip：{parent_image} , no sub-images")
                f_fail.write(f"{parent_image} | No sub-images\n")
                continue

            # Parents without a caption row come through the left merge as NaN.
            caption = "" if pd.isna(caption) else str(caption).strip()

            # Construct user prompt clearly identifying all image filenames
            user_prompt = (
                f"The following images will be shown to you:\n"
                f"Parent image filename: {parent_image}\n"
                f"Caption: {caption}\n\n"
                + "\n".join(user_prompts) +
                "\n\nPlease generate multiple high-quality VQA (Visual Question Answering) pairs for each sub-image. "
                "Each VQA pair should include a meaningful question and answer that is visually grounded in the image content. "
                "Return your answer as a JSON list with the fields: sub_image, question, answer, level."
            )

            messages = [
                {"role": "system", "content": vqa_system_prompt},
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt},
                        {"type": "text", "text": f"Filename: {parent_image}"},
                        parent_image_input,
                        *sub_image_with_text
                    ]
                }
            ]

            # Call GPT model (one retry inside _request_reply)
            total_calls += 1
            try:
                reply = _request_reply(messages)
            except Exception as e:
                # Record the real API error instead of writing an empty reply row.
                print("❌", f"{parent_image} | GPT API error: {e}")
                _record_failures(f_fail, sub_image_ids, e)
            else:
                writer.writerow({
                    "parent_image": parent_image,
                    "sub_image_ids": ";".join(sub_image_ids),
                    "reply": reply
                })

        except Exception as e:
            error_msg = f"{parent_image} | GPT API error: {e}"
            print("❌", error_msg)
            traceback.print_exc()
            # group["sub_image"] yields plain values; the previous two-name
            # unpacking of 1-tuples raised ValueError inside this handler.
            _record_failures(f_fail, group["sub_image"], e)
            time.sleep(1)

        gc.collect()

# ====== Final report ======
print(f"\n✅ Finish first phase，sent {total_calls}  times GPT request。")
print(f"❌ Failed images total {len(error_images)} ，record to {FAILED_LIST}。")