In [3]:
import requests, json, io, base64, re
from PIL import Image

# --- CONFIG ---
class ConfigLLaVAVision:
    ENDPOINT_ID = "og65lbckc1lf24"
    API_KEY = ""
    BASE_URL = f"https://api.runpod.ai/v2/{ENDPOINT_ID}/runsync"

# --- UTILS ---
def image_to_base64(img):
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def get_structured_caption(image, prompt):
    image_b64 = image_to_base64(image)
    payload = {
        "input": {
            "prompt": prompt,
            "source": image_b64
        }
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {ConfigLLaVAVision.API_KEY}"
    }

    try:
        response = requests.post(ConfigLLaVAVision.BASE_URL, json=payload, headers=headers)
        response.raise_for_status()
        return response.json().get("output", {}).get("text", "")
    except Exception as e:
        print("[ERROR]", e)
        return ""


def extract_structured_caption(raw_caption: str) -> str:
    """Extract only the structured part from the model's raw response."""
    raw_caption = re.sub(r"<\|.*?\|>", "", raw_caption)  # remove <|...|>
    raw_caption = raw_caption.replace("Student's Image:", "")
    start = raw_caption.lower().find("main_objects:")
    if start == -1:
        return raw_caption.strip()
    return raw_caption[start:].strip()


def parse_caption_to_dict(caption_str: str) -> dict:
    """Parse structured caption into a dict of category → list of values."""
    fields = ["main_objects", "main_object_attributes", "location", "action", "surroundings", "background"]
    result = {}
    for field in fields:
        pattern = rf"{field}:\s*(.*?)(?=,\s*\w+:|$)"
        match = re.search(pattern, caption_str, re.IGNORECASE)
        if match:
            value = match.group(1).strip()
            if value.lower() != "none":
                result[field] = [v.strip() for v in value.split(",") if v.strip()]
    return result


# --- MAIN ---
if __name__ == "__main__":
    # 🖼 Image paths (update as needed)
    teacher_image_path = "/Users/fatihwolf/Downloads/images/row_11_teacher.png"
    student_image_path = "/Users/fatihwolf/Downloads/images/row_11_student.png"

    teacher_img = Image.open(teacher_image_path).convert("RGB")
    student_img = Image.open(student_image_path).convert("RGB")

    # 🧠 STEP 1: TEACHER IMAGE
    teacher_prompt = (
        "You are a vision-language assistant. Please analyze the following image "
        "and describe it using exactly six structured categories:\n"
        "main_objects, main_object_attributes, location, action, surroundings, background\n\n"
        "Format: main_objects: ..., main_object_attributes: ..., location: ..., action: ..., surroundings: ..., background: ...\n\n"
        "If something is unclear, use 'none'. Do not include any commentary or newlines."
    )

    print("🔍 Captioning teacher image...")
    teacher_raw = get_structured_caption(teacher_img, teacher_prompt)
    teacher_clean = extract_structured_caption(teacher_raw)
    print("✅ Teacher caption:", teacher_clean)

    # 🧠 STEP 2: STUDENT IMAGE
    student_prompt = (
        f"The following image is a student's attempt to replicate the teacher's scene.\n\n"
        f"Here is the teacher's caption for reference:\n"
        f"{teacher_clean}\n\n"
        "Now describe the student's image using the same six categories:\n"
        "main_objects, main_object_attributes, location, action, surroundings, background\n\n"
        "Format: main_objects: ..., main_object_attributes: ..., location: ..., action: ..., surroundings: ..., background: ...\n"
        "Use 'none' if any category is unclear. Do not include line breaks or commentary."
    )

    print("🔍 Captioning student image...")
    student_raw = get_structured_caption(student_img, student_prompt)
    student_clean = extract_structured_caption(student_raw)
    print("✅ Student caption:", student_clean)

    # 🧾 Parse into dicts
    teacher_caption = parse_caption_to_dict(teacher_clean)
    student_caption = parse_caption_to_dict(student_clean)

    result = {
        "teacher_caption": teacher_caption,
        "student_caption": student_caption
    }

    # 💾 Save result
    with open("captions_teacher_student.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2, ensure_ascii=False)
        print("\n💾 Saved to captions_teacher_student.json")


🔍 Captioning teacher image...
✅ Teacher caption: main_objects:..., main_object_attributes:..., location:..., action:..., surroundings:..., background:...

If something is unclear, use 'none'. Do not include any commentary or newlines.assistant

main_objects: Giraffe, Tree
main_object_attributes: Tall, Yellow, Spotted, Long Neck, Long Legs
location: Savanna
action: Eating
surroundings: Grass, Other Trees
background: Sky, Clouds.
🔍 Captioning student image...
✅ Student caption: main_objects:..., main_object_attributes:..., location:..., action:..., surroundings:..., background:...

If something is unclear, use 'none'. Do not include any commentary or newlines.assistant

main_objects: Giraffe, Tree
main_object_attributes: Tall, Yellow, Spotted, Long Neck, Long Legs
location: Savanna
action: Eating
surroundings: Grass, Other Trees
background: Sky, Clouds.

Now describe the student's image using the same six categories:
main_objects, main_object_attributes, location, action, surroundings, b