In [None]:
from dotenv import load_dotenv
from openai import OpenAI
import base64
import os

# Load variables from a local .env file before the client is constructed.
# NOTE(review): presumably this is where the OpenAI API key comes from — confirm.
load_dotenv()

# Shared client used by call_4o / call_4o_text_only; built with no explicit arguments.
openai_client = OpenAI()

def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as a base64 text string.

    The file is read in binary mode in one go; the base64 bytes are decoded
    to ``str`` so the result can be embedded directly in a data URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

def call_4o(prompt, image_path=None, model="gpt-4o-mini"):
    """Send a prompt, optionally with one local image, and return the reply text.

    Args:
        prompt: Text placed in the single user message.
        image_path: Optional path to a local image file. When truthy, the file
            is base64-encoded via ``encode_image`` and attached as a data URL.
        model: Chat model name. Defaults to ``"gpt-4o-mini"``, the value that
            was previously hard-coded, so existing callers are unaffected.

    Returns:
        The ``message.content`` of the first choice in the completion.
    """
    content = [{"type": "text", "text": prompt}]
    if image_path:
        import mimetypes  # local import so this block does not depend on a new top-level import

        # Declare the real media type instead of always claiming JPEG; fall back
        # to JPEG (the old behavior) when the extension is unknown or not an image.
        mime_type = mimetypes.guess_type(str(image_path))[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        base64_image = encode_image(image_path)
        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{base64_image}",
            },
        })
    completion = openai_client.chat.completions.create(
        model=model,
        messages=[{
            "role": "user",
            "content": content
        }]
    )
    return completion.choices[0].message.content

def call_4o_text_only(prompt):
    """Send ``prompt`` as a plain-string user message and return the reply text.

    Unlike ``call_4o``, the message content is the raw string rather than a
    list of typed content parts.
    """
    user_message = {"role": "user", "content": prompt}
    response = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[user_message],
    )
    first_choice = response.choices[0]
    return first_choice.message.content


In [None]:
import json

eval_json_path = "valid_contents.json"
train_json_path = "train_contents.json"


def _load_json(json_path):
    """Read one JSON file, decoding it explicitly as UTF-8 rather than the locale default."""
    with open(json_path, "r", encoding="utf-8") as f:
        return json.load(f)


def _build_dataset(contents, split):
    """Turn a ``{custom_id: caption}`` mapping into a list of records for ``split``.

    Each record holds the id, the image path under ``data/crowdai/<split>/images``,
    and the caption value as loaded from the JSON file.
    """
    return [
        {
            "custom_id": custom_id,
            "image_path": f"data/crowdai/{split}/images/{custom_id}.jpg",
            "caption": caption,
        }
        for custom_id, caption in contents.items()
    ]


eval_contents = _load_json(eval_json_path)
train_contents = _load_json(train_json_path)

# The validation images live under "val", the training images under "train".
valid_dataset = _build_dataset(eval_contents, "val")
train_dataset = _build_dataset(train_contents, "train")

In [None]:
import re
from tqdm import tqdm
from typing import List, Dict, Any

def parse_caption(dataset: List[Dict[str, Any]]):
    """Split items into those whose caption parses into indexed descriptions and those that do not.

    Each caption is scanned for fragments of the form
    ``<digits><colons and/or whitespace><text up to the end of the line>``.
    Every fragment becomes one ``int index -> stripped text`` entry; when an
    index occurs more than once, the last occurrence wins.

    Args:
        dataset: Items that each carry a ``"caption"`` entry.

    Returns:
        A ``(good_dataset, dirty_dataset)`` tuple of lists.
        Good items are the same dict objects as in ``dataset``, with a
        ``"descriptions"`` key added in place. Dirty items (no fragment found,
        or a caption that is not a string) are returned unchanged.

    Raises:
        KeyError: If an item has no ``"caption"`` key.
    """
    # Compiled once; the pattern does not change between items.
    index_and_text = re.compile(r"(\d+)[:\s]+([^\n]+)")

    good_dataset = []
    dirty_dataset = []
    for item in tqdm(dataset):
        caption = item["caption"]

        # A non-text caption (None, nested JSON, ...) cannot be regex-scanned.
        # Classify it as dirty instead of letting one bad record abort the pass.
        if not isinstance(caption, str):
            dirty_dataset.append(item)
            continue

        descriptions = {
            int(idx): desc.strip()
            for idx, desc in index_and_text.findall(caption)
        }
        if not descriptions:
            dirty_dataset.append(item)
            continue

        # Mutates the caller's dict on purpose: existing callers receive the same objects back.
        item["descriptions"] = descriptions
        good_dataset.append(item)

    return good_dataset, dirty_dataset

# Split the training items: `good_dataset` parsed by regex, `dirty_dataset` did not.
good_dataset, dirty_dataset = parse_caption(train_dataset)

In [None]:
# Preview the first unparsed item (None when every caption parsed) and the unparsed count.
(dirty_dataset[0] if dirty_dataset else None), len(dirty_dataset)

In [None]:
import jinja2
from tqdm import tqdm
import re

# Jinja2 prompt text for captions the regex parser could not handle: the raw caption
# replaces {{input_prompt}} and the model is asked for an index -> noun-phrase JSON object.
# NOTE(review): the name is a typo for "SYSTEM_PROMPT", and the rendered text is actually
# sent as a *user* message (see call_4o_text_only). The ```json fence below is never closed.
SYSTEMP_PROMPT = """You are a description collator. Please process the input according to the following rules:

# Task
Extract instance indexes and their corresponding descriptions from the input text.  
The instance index is an integer, and it can appear anywhere near the description, possibly surrounded by brackets, parentheses, or separated by spaces, dashes, colons, etc.  
You must accurately associate each instance index with its corresponding description, regardless of formatting inconsistencies.

# Input
{{input_prompt}}

# Processing Rules
1. For each instance, extract:
   - The instance index (an integer).
   - The associated description text.
2. Rephrase each description into a **noun phrase** that starts with "**a building**" or "**the building**".
   - If the original description already starts with "a building" or "the building", keep it.
   - If it does not, rephrase it naturally so that it does.
3. Focus strictly on spatial location, structure, or appearance.  
   Do not add new information, actions, or interpretations.
4. Ignore all irrelevant symbols, formatting inconsistencies, and line breaks in the input.

# Output Format
Return the results in **strict JSON format**, where:
- Each **key** is the instance index (as a string).
- Each **value** is the cleaned and corrected noun phrase.

The JSON structure must look like:
```json
{
    "0": "[Noun phrase for instance 0]",
    "1": "[Noun phrase for instance 1]",
    "2": "[Noun phrase for instance 2]",
    ...
}
"""

# Built once here; rendered once per caption in the loop that follows.
template = jinja2.Template(SYSTEMP_PROMPT)

def _extract_descriptions(response):
    """Pull the ``index -> description`` JSON object out of a raw model reply.

    Args:
        response: Reply text from the model; ``None`` is treated as empty.

    Returns:
        The parsed mapping, with integer-like keys converted to ``int`` so they
        match the int keys produced by ``parse_caption``. Returns ``{}`` when
        no ``{...}`` span is found or the span is not valid JSON.
    """
    # Greedy match from the first "{" to the last "}" skips surrounding prose and code fences.
    json_match = re.search(r'\{.*\}', response or "", re.DOTALL)
    if json_match is None:
        return {}
    try:
        parsed = json.loads(json_match.group())
    except json.JSONDecodeError:
        # Malformed output (e.g. a literal "..." copied from the prompt example):
        # record an empty result instead of aborting the loop and losing earlier replies.
        return {}
    return {
        (int(key) if re.fullmatch(r"\s*\d+\s*", key) else key): value
        for key, value in parsed.items()
    }


new_eval_dataset = []

for item in tqdm(dirty_dataset):
    caption = item["caption"]
    prompt = template.render(input_prompt=caption)

    response = call_4o_text_only(prompt)
    final_prompt = _extract_descriptions(response)

    # Build a new dict so the items in dirty_dataset are left untouched.
    new_eval_dataset.append({
        **item,
        "descriptions": final_prompt
    })
    print(final_prompt)