We start by preparing structured input data for a transformer-based summarization model: we extract and clean the category, annotation, and question data from the AI2D visual question answering dataset, combine them into a refined textual description for each image, and save the results as a JSON file (`transformer_inputs.json`).

In [None]:
import os
import json
import re

# Folder that holds categories.json plus the images/, annotations/ and questions/ sub-folders
dataset_root = "/app/ai2d"


def prepare_transformer_input():
    """Build one text description per image and write them to transformer_inputs.json.

    For every image file found under ``<dataset_root>/images`` the function
    gathers up to three pieces of text:

    * ``Category: ...`` - the value stored in categories.json under the image's
      file name,
    * ``Labels: ...``   - the ``"value"`` fields of the ``"text"`` entries in the
      matching annotation file, minus labels hit by the noise regex,
    * ``Question: ...`` - the first key of the ``"questions"`` mapping in the
      matching question file, unless it is hit by the noise regex.

    The pieces are joined with ``"; "``. Images for which nothing was gathered
    are skipped. The resulting list of ``{"image", "input_text"}`` dicts is
    written to ``transformer_inputs.json`` in the current working directory.
    Progress and debug information is printed along the way; nothing is returned.
    """
    print("Preparing refined input data for transformer summarization task\n")

    # categories.json is queried below with the full image file name as key
    with open(os.path.join(dataset_root, "categories.json"), 'r', encoding='utf-8') as cat_fh:
        categories = json.load(cat_fh)

    images_dir = os.path.join(dataset_root, "images")
    annotations_dir = os.path.join(dataset_root, "annotations")
    questions_dir = os.path.join(dataset_root, "questions")

    def _json_files_by_stem(folder):
        # Two extensions are stripped, so a name like "x.png.json" is stored under "x"
        by_stem = {}
        for name in os.listdir(folder):
            if name.endswith('.json'):
                by_stem[os.path.splitext(os.path.splitext(name)[0])[0]] = name
        return by_stem

    # Image stem -> path relative to the dataset root
    image_files = {}
    for name in os.listdir(images_dir):
        if name.endswith(('.png', '.jpg', '.jpeg')):
            image_files[os.path.splitext(name)[0]] = os.path.join("images", name)
    annot_files = _json_files_by_stem(annotations_dir)
    question_files = _json_files_by_stem(questions_dir)

    print(f"Total images: {len(image_files)}")
    print(f"Total annotation files: {len(annot_files)}")
    print(f"Total question files: {len(question_files)}\n{'-'*50}")

    # Anything matching this is treated as noise (copyright marks, long digit runs, dotted tokens, ...)
    noise_pattern = re.compile(r'copyright|©|\d{4,}|\w+\.\w+|\(previewing purposes only\)', re.IGNORECASE)

    inputs = []

    for img_base, img_path in image_files.items():
        input_text = []
        full_img_name = os.path.basename(img_path)

        # --- category -------------------------------------------------------
        if full_img_name in categories:
            cat = categories[full_img_name]
            input_text.append(f"Category: {cat}")
            print(f"Debug - Category for {img_path}: {cat}")
        else:
            print(f"Debug - No category for {img_path} with key {full_img_name}")

        # --- annotation labels ----------------------------------------------
        if img_base in annot_files:
            annot_path = os.path.join(annotations_dir, annot_files[img_base])
            with open(annot_path, 'r', encoding='utf-8') as annot_fh:
                annot_data = json.load(annot_fh)
            if "text" in annot_data and annot_data["text"]:
                raw_labels = [
                    text_info["value"]
                    for text_info in annot_data["text"].values()
                    if "value" in text_info
                ]
                labels = [label for label in raw_labels if not noise_pattern.search(label)]
                if labels:
                    input_text.append(f"Labels: {', '.join(labels)}")
                    print(f"Debug - Labels for {img_path}: {', '.join(labels)}")

        # --- first question -------------------------------------------------
        if img_base in question_files:
            question_path = os.path.join(questions_dir, question_files[img_base])
            with open(question_path, 'r', encoding='utf-8') as question_fh:
                q_data = json.load(question_fh)
            print(f"Debug - Raw question data for {img_path}: {q_data}")
            has_questions = isinstance(q_data, dict) and "questions" in q_data and q_data["questions"]
            if not has_questions:
                print(f"Debug - No valid questions found for {img_path}")
            else:
                # The question text itself is the dictionary key
                question = next(iter(q_data["questions"].keys()))
                if not noise_pattern.search(question):
                    input_text.append(f"Question: {question}")
                    print(f"Debug - Parsed question for {img_path}: {question}")

        # --- collect --------------------------------------------------------
        if not input_text:
            print(f"No input prepared for {img_path}")
            continue
        combined_input = "; ".join(input_text)
        inputs.append({"image": img_path, "input_text": combined_input})
        print(f"Prepared input for {img_path}: {combined_input}")

    print(f"Total inputs prepared: {len(inputs)}")

    # Persist for the captioning stage
    with open("transformer_inputs.json", 'w', encoding='utf-8') as out_fh:
        json.dump(inputs, out_fh, indent=2)
    print("Inputs saved to 'transformer_inputs.json'.")


if __name__ == "__main__":
    prepare_transformer_input()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Prepared input for images/4496.png: Category: typesOf; Labels: PLATE 4. LEAF MARGINS, CILIATE, LACERATE, LACINIATE, LOBED, PINNATIFID, DOUBLY SERRATE, SERRATE, REVOLUTE, SINUATE, SERRULATE, CLEFT, CRENATE, CRENULATE, INCISED, ENTIRE, DENTICULATE, DENTATE, INVOLUTE; Question: What represents cleft in the diagram?
Debug - Category for images/4497.png: typesOf
Debug - Labels for images/4497.png: Littorella. 3 flower, natural size,, Littorella. 3 flower cut vertically (mag.)., Littorella. Young ovary cut vertically to show the conducting tissue (mag.)., Littorella. Young ovule with the conducting tissue (mag.)., Littorella. Pistil (mag.)., Littorella. flower after expansion with long pendent filaments deprived of their anthers (mag.)., Littorella. Diagram 3. , Littorella lacustris. Monoecious inflorescece: flower solitary, pedicelled; numerous, sessile, at the base of the peduncle., Littorella. Germinating seed (mag.)., Litto

In [None]:
!pip install tqdm --quiet

[0m

We now use a language model (Zephyr-7B) to generate one-sentence educational captions for images in the AI2D dataset. We will filter out low-quality inputs, preprocess them by removing noisy parts (like questions, which could force the language model to generate answers instead of captions), and feed them into the model with a base prompt. Valid captions will be generated in batches and saved to `ai2d_educational_captions_complete.json`.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import torch
import json
import re
from concurrent.futures import ThreadPoolExecutor

# Checkpoint identifier handed to from_pretrained for both tokenizer and model
model_id = "HuggingFaceH4/zephyr-7b-alpha"

# Device and precision follow from CUDA availability: fp16 on GPU, fp32 on CPU
cuda_ready = torch.cuda.is_available()
device = "cuda" if cuda_ready else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if cuda_ready else torch.float32
).to(device)

# torch.compile only exists on PyTorch 2.x and is only applied when running on CUDA
if device == "cuda" and hasattr(torch, 'compile'):
    model = torch.compile(model)

print(f"Model loaded on {device}")

# Read the per-image descriptions written by the preparation step
with open("transformer_inputs.json", 'r', encoding='utf-8') as inputs_fh:
    all_data = json.load(inputs_fh)

# Every loaded record is processed; no sub-sampling is applied
samples = all_data
print(f"Loaded {len(samples)} samples for processing.")

# Patterns shared by the validation and preprocessing helpers (compiled once)
QUESTION_PATTERN = re.compile(r'Question:.*?(?=\n|$|\n\s*\n)', flags=re.DOTALL)
CAPTION_PATTERN = re.compile(r'Caption:.*?(?=\n|$)', flags=re.DOTALL)
CATEGORY_PATTERN = re.compile(r'Category:\s*([^\n;]+)', flags=re.IGNORECASE)
LABELS_PATTERN = re.compile(r'Labels:\s*([^\n;]+)', flags=re.IGNORECASE)
ANNOTATION_PATTERN = re.compile(r'(Category|Labels|Question):.*?(?=\n|$)', flags=re.IGNORECASE)


def _parse_labels(raw_labels):
    """Split a raw label string on ',' / ';' and return (all_labels, descriptive_labels).

    A label counts as descriptive when it is longer than two characters and
    contains at least one ASCII letter.
    """
    all_labels = []
    for piece in re.split(r'[,;]', raw_labels):
        piece = piece.strip()
        if piece:
            all_labels.append(piece)
    descriptive = [lab for lab in all_labels if len(lab) > 2 and re.search(r'[a-zA-Z]', lab)]
    return all_labels, descriptive


def is_valid_input(input_text):
    """
    Decide whether input_text carries enough meaningful content to caption.

    Returns True when the ``Labels:`` section holds at least one descriptive
    label, or when the text left over after removing the tagged sections
    contains more than one word longer than three characters. Returns False
    for empty/whitespace-only input.

    Note: ANNOTATION_PATTERN removes everything from a tag up to the end of
    its line, so leftover text only exists when the input spans several lines
    or has text in front of the first tag.
    """
    if not input_text or input_text.isspace():
        return False
    text = input_text.strip()

    # Free text remaining once every "Category:/Labels:/Question:" run is removed
    leftover = ANNOTATION_PATTERN.sub('', text).strip()
    long_words = [tok for tok in leftover.split() if len(tok) > 3 and re.search(r'[a-zA-Z]', tok)]
    leftover_ok = len(long_words) > 1

    # Category: longer than one character and not a placeholder such as "unknown" or a bare number
    cat_hit = CATEGORY_PATTERN.search(text)
    cat_value = cat_hit.group(1).strip() if cat_hit else ''
    cat_ok = len(cat_value) > 1 and re.match(r'^(unknown|misc|figure|\d+)$', cat_value, re.IGNORECASE) is None

    # Labels: valid as soon as one descriptive label exists
    labels_ok = False
    label_hit = LABELS_PATTERN.search(text)
    if label_hit is not None:
        all_labels, descriptive = _parse_labels(label_hit.group(1).strip())
        labels_ok = bool(all_labels) and bool(descriptive)

        # 'partsOf' diagrams with only non-descriptive labels are judged on the leftover text alone
        only_terse_labels = bool(all_labels) and not descriptive
        if only_terse_labels and cat_hit is not None and 'partsof' in cat_hit.group(1).lower():
            return leftover_ok

    if labels_ok:
        return True
    if cat_ok and leftover_ok:
        return True
    return leftover_ok


def preprocess_input(input_text):
    """
    Strip any ``Question:`` run and any ``Caption:`` run (each up to the end of its line).
    """
    without_question = QUESTION_PATTERN.sub('', input_text).strip()
    return CAPTION_PATTERN.sub('', without_question).strip()

# Base prompt that remains constant for all the inputs
BASE_PROMPT = ("Write one concise educational caption sentence that describes the content of a diagram based on this description, focusing only on the diagram's labels and category: ")

# Caption used when the model output contains no usable sentence
_FALLBACK_CAPTION = "The diagram describes its labeled components."


def _extract_first_sentence(generated):
    """Reduce decoded model output to a single caption ending in '.', '!' or '?'."""
    # Drop a leading "Answer:" / "Caption:" tag the model may have echoed
    generated = re.sub(r'^(Answer|Caption):?\s*', '', generated, flags=re.IGNORECASE).strip()
    sentences = re.split(r'(?<=[.!?])\s+', generated) if generated else []
    # Keep sentences that contain letters; purely numeric/punctuation fragments are ignored
    valid_sentences = [
        s.strip() for s in sentences
        if re.search(r'[a-zA-Z]', s) and not re.match(r'^[\d\s,.]+$', s)
    ]
    caption = valid_sentences[0] if valid_sentences else _FALLBACK_CAPTION
    # Cut at the first period so the result is a single sentence
    if '.' in caption:
        caption = caption.split('.')[0].strip() + '.'
    if caption and caption[-1] not in ".!?":
        caption += "."
    return caption


def generate_one_sentence_caption(batch_inputs, tokenizer, model, max_new_tokens=80):
    """
    Generates a one-sentence caption for a batch of preprocessed inputs.

    Args:
        batch_inputs: list of description strings; each is appended to BASE_PROMPT.
        tokenizer: tokenizer used to encode the prompts and decode the output.
        model: causal language model exposing ``generate``.
        max_new_tokens: upper bound on the number of generated tokens per prompt.

    Returns:
        A list with one caption string per entry of ``batch_inputs`` (empty
        list for an empty batch).

    The prompt is removed by slicing off the prompt tokens from the generated
    sequences instead of comparing decoded text against the prompt string.
    The string comparison silently failed whenever the prompt had been
    truncated to 512 tokens or did not survive the tokenizer round trip
    unchanged, which made the prompt itself end up in the caption.
    """
    if not batch_inputs:
        return []

    # Constructing prompts by appending the specific input description to the base prompt.
    prompts = [BASE_PROMPT + input_text for input_text in batch_inputs]

    # Tokenizing the batch of prompts
    # NOTE(review): batched generation with a decoder-only model expects left padding -
    # confirm tokenizer.padding_side for the checkpoint in use.
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    # Generating captions with provided generation parameters
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=0.1,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            early_stopping=True,
            num_beams=3  # Reduced for faster generation
        )

    # For causal LMs the returned sequences start with the (padded) prompt tokens,
    # so everything from this offset onwards is newly generated text.
    prompt_len = input_ids.shape[1]

    captions = []
    for sequence in output_ids[:len(prompts)]:
        generated = tokenizer.decode(sequence[prompt_len:], skip_special_tokens=True).strip()
        captions.append(_extract_first_sentence(generated))

    return captions

# Worker applied to every sample by the thread pool: cleans the text and flags its validity.
def process_sample(item):
    """Return ``(image, cleaned_text, is_valid)`` for one input record.

    Validity is judged on the original ``input_text``, not on the cleaned text.
    """
    raw_text = item["input_text"]
    cleaned = preprocess_input(raw_text)
    keep = is_valid_input(raw_text)
    return item["image"], cleaned, keep

# Clean and validate every sample through a thread pool
with ThreadPoolExecutor() as pool:
    processed_samples = list(pool.map(process_sample, samples))

# Keep (image, cleaned_text) only for samples flagged as valid
valid_samples = []
for image_path, clean_text, keep in processed_samples:
    if keep:
        valid_samples.append((image_path, clean_text))
skipped_count = len(samples) - len(valid_samples)

# Caption the valid samples batch_size at a time
batch_size = 16
results = []
batch_starts = range(0, len(valid_samples), batch_size)
for start in tqdm(batch_starts, desc="Generating Captions for Valid Samples"):
    chunk = valid_samples[start:start + batch_size]
    chunk_texts = [clean_text for _, clean_text in chunk]

    chunk_captions = generate_one_sentence_caption(chunk_texts, tokenizer, model)

    # Pair every image of the chunk with the caption generated for it
    results.extend(
        {"image": image_path, "caption": caption_text}
        for (image_path, _), caption_text in zip(chunk, chunk_captions)
    )

print(f"Processed {len(results)} samples, skipped {skipped_count} due to lack of semantic meaning.")

# Persist the image/caption pairs for the filtering stage
output_file = "ai2d_educational_captions_complete.json"
with open(output_file, "w", encoding="utf-8") as out_fh:
    json.dump(results, out_fh, indent=2)

print(f"Saved {len(results)} captions to {output_file}")

Loading checkpoint shards: 100%|██████████| 8/8 [00:06<00:00,  1.27it/s]


Model loaded on cuda
Loaded 4903 samples for processing.


Generating Captions for Valid Samples: 100%|██████████| 286/286 [32:44<00:00,  6.87s/it]

Processed 4567 samples, skipped 336 due to lack of semantic meaning.
Saved 4567 captions to ai2d_educational_captions_complete.json





We will now filter out low-quality or non-informative captions from the previously generated set of image-caption pairs (`ai2d_educational_captions_complete.json`). We will remove captions that are too short, repetitive, list-like, or contain unwanted patterns like URLs or poor sentence structure. The cleaned and valid captions will then be saved to a new file called `ai2d_educational_captions_filtered.json`.

In [None]:
import json
import re
from tqdm import tqdm

# Read back the image/caption pairs written by the caption-generation step
input_file = "ai2d_educational_captions_complete.json"
with open(input_file, encoding='utf-8') as captions_fh:
    image_caption_pairs = json.load(captions_fh)

print(f"Loaded {len(image_caption_pairs)} image-caption pairs for filtering.")

# Defining the criteria for filtering captions
_LINK_RE = re.compile(r'http[s]?://|www\.|\.com|\.org|\.net', re.IGNORECASE)
_CHAR_RUN_RE = re.compile(r'(.)\1{4,}')
# One or more "item," groups at the start of the caption. The character class
# already contains \s, so no separate \s* is placed in front of it: the extra
# \s* gave every item two possible parses and made a failing match backtrack
# exponentially in the number of items.
_LIST_PREFIX = r"^(?:,?[A-Za-z0-9'\u03A9\s]+,)+"
_LIST_PREFIX_RE = re.compile(_LIST_PREFIX)
_LIST_LIKE_RE = re.compile(_LIST_PREFIX + r".*\.?$")
_VERB_RE = re.compile(r'\b(is|are|shows|illustrates|depicts|contains|has|includes)\b')


def has_unwanted_patterns(caption):
    """
    Checks if a caption is too short, a list-like structure, contains links, or has repetitive patterns.
    Returns True if the caption should be removed, False otherwise.

    The checks, in order:
      1. not a non-blank string,
      2. contains a link marker ("http://", "www.", ".com", ".org", ".net"),
      3. contains the same character five or more times in a row (e.g. "aaaaa"),
      4. contains the same whitespace-separated token three times in a row,
      5. has fewer than four words,
      6. is a comma-separated list followed by fewer than two other words,
      7. starts with a comma,
      8. has fewer than six words and none of the listed verbs.
    """
    if not caption or not isinstance(caption, str) or caption.isspace():
        return True

    # Surrounding whitespace must not influence any of the checks below
    caption = caption.strip()

    # Checking for links (e.g., contains "http", "www", or ".com")
    if _LINK_RE.search(caption):
        return True

    # Checking for a single character repeated 5+ times in a row (e.g., "aaaaa" or ".....")
    if _CHAR_RUN_RE.search(caption):
        return True

    # Token repeated 3+ times in a row (e.g., "hello hello hello" or "N, N, N, N")
    words = caption.split()
    if any(a == b == c for a, b, c in zip(words, words[1:], words[2:])):
        return True

    # Checking if the caption is too short (fewer than 4 words)
    if len(words) < 4:
        return True

    # Checking for list-like captions (e.g., "leaf, stem, root, flower."): comma-separated
    # items with fewer than two words remaining after the last comma
    if _LIST_LIKE_RE.match(caption):
        remainder = _LIST_PREFIX_RE.sub('', caption).strip('., ')
        if len(remainder.split()) < 2:
            return True

    # Checking for captions starting with a comma
    if caption.startswith(','):
        return True

    # Checking for fragments lacking proper sentence structure (e.g., "'s 5th Grade Science Class.").
    # Captions without one of the listed verbs are only accepted when they have at least six words.
    if len(words) < 6 and not _VERB_RE.search(caption.lower()):
        return True

    return False

# Split the pairs into kept and removed according to has_unwanted_patterns
filtered_pairs = []
removed_count = 0

for pair in tqdm(image_caption_pairs, desc="Filtering Captions"):
    # A missing "caption" key is treated like an empty caption
    if has_unwanted_patterns(pair.get("caption", "")):
        removed_count += 1
        continue
    filtered_pairs.append(pair)

print(f"Kept {len(filtered_pairs)} pairs, removed {removed_count} pairs due to unwanted patterns.")

# Persist the surviving pairs
output_file = "ai2d_educational_captions_filtered.json"
with open(output_file, "w", encoding="utf-8") as out_fh:
    json.dump(filtered_pairs, out_fh, indent=2)

print(f"Saved {len(filtered_pairs)} filtered captions to {output_file}")

Loaded 4567 image-caption pairs for filtering.


Filtering Captions: 100%|██████████| 4567/4567 [00:00<00:00, 64855.69it/s]

Kept 3963 pairs, removed 604 pairs due to unwanted patterns.
Saved 3963 filtered captions to ai2d_educational_captions_filtered.json



