In [40]:
# System prompt for the VLM judge. It defines the output contract that
# `assess_response` parses: step-by-step reasoning followed by a final line
# containing `<YES>` or `<NO>`, with `<`/`>` reserved for that verdict only.
# Items 1-3 of the format list end in two spaces (presumably markdown hard
# line breaks) — they are part of the prompt text, keep them if editing.
# NOTE(review): the prompt says an original and a modified image are given,
# but neither SYSTEM nor PROMPT tells the model which attached image is which
# (`check_modification` attaches the original first, the modified second).
SYSTEM = """
You are an image classification agent. Your role is to evaluate whether a given instruction has been correctly applied to an image.
You are given the original image, the modified image and an instruction.
 Response Format:
1. Provide a step-by-step analysis of the image in relation to the instruction.  
2. Conclude your response with either `<YES>` or `<NO>` on a new line, depending on whether the instruction was applied.  
3. Ensure that `<YES>` or `<NO>` is enclosed within less than (`<`) and greater than (`>`) signs and appears on a separate line at the end of the response.  
4. Ensure the less than (`<`) and greater than (`>`) signs are only used at the end of the response and nowhere else .
"""
# Per-example user message; `{instruction}` is filled in via str.format.
PROMPT = """
Was the instruction "{instruction}" applied to the image?
"""

In [45]:
import base64
from io import BytesIO
from PIL import Image
from openai import OpenAI
import os

# Chat-completions client shared by `check_modification`. Constructed with no
# arguments, so the SDK takes its credentials from the environment
# (OPENAI_API_KEY).
# To target Groq's OpenAI-compatible endpoint instead, construct it as
#     OpenAI(base_url="https://api.groq.com/openai/v1",
#            api_key=os.environ.get("GROQ_API_KEY"))
client = OpenAI()

import re


def assess_response(response: str) -> bool:
    """Parse the judge's final ``<YES>`` / ``<NO>`` verdict.

    ``SYSTEM`` asks the model to end its answer with ``<YES>`` or ``<NO>`` on
    its own line, so the *last* such tag in the text is taken as the verdict.
    Earlier tags (e.g. the model quoting the format, or changing its mind) are
    ignored. Matching is case-insensitive and tolerates spaces inside the
    brackets (``< yes >``). Any other bracketed token, e.g. ``<img>``, is
    never mistaken for a verdict.

    Args:
        response: Raw model output. ``None`` (the SDK may return no content)
            is treated as an empty string.

    Returns:
        True if the final verdict tag is YES; False if it is NO or if no
        verdict tag is present at all.
    """
    # Previous version used r"<(.{3})>": it could never match "<NO>" (two
    # letters) and picked the first 3-char tag anywhere in the text.
    verdicts = re.findall(r"<\s*(YES|NO)\s*>", response or "", flags=re.IGNORECASE)
    return bool(verdicts) and verdicts[-1].upper() == "YES"


def _encode_image_jpeg(image: Image.Image) -> str:
    """Return ``image`` as a base64-encoded JPEG string for a data URL.

    JPEG has no alpha channel or palette, so PIL refuses to save e.g. RGBA or
    P images as JPEG. Any image that is not already RGB or grayscale ("L") is
    therefore flattened onto an opaque white background first. A plain
    ``convert("RGB")`` would instead expose whatever colour the transparent
    pixels happen to hold (often black).
    """
    if image.mode not in ("RGB", "L"):
        # NOTE(review): assumes white is the intended canvas for transparent
        # renders — confirm against how the dataset images were produced.
        rgba = image.convert("RGBA")
        canvas = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        image = Image.alpha_composite(canvas, rgba).convert("RGB")
    buffer = BytesIO()
    image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def check_modification(
    image_solution: Image.Image,
    instruction: str,
    image_input: Image.Image,
    model: str = "gpt-4o-mini",
    detail: str = "low",
    temperature: float = 1,
) -> tuple[bool, str]:
    """Ask a vision model whether ``instruction`` was applied to ``image_input``.

    Sends ``SYSTEM`` as the system message. The user message contains
    ``PROMPT`` (formatted with the instruction), then the original image,
    then the modified image. The reply is parsed with ``assess_response``.

    Args:
        image_solution: The modified image to be judged.
        instruction: The edit instruction, inserted into ``PROMPT``.
        image_input: The original, unmodified image.
        model: Chat-completions model name. Any vision model served by
            ``client`` works, e.g. "llama-3.2-90b-vision-preview" when
            ``client`` points at Groq.
        detail: Image detail level forwarded to the API ("low", "high",
            "auto").
        temperature: Sampling temperature. The default of 1 matches the
            original setting. Lower values make the judge more repeatable.

    Returns:
        ``(applied, response)``: the parsed verdict and the raw model text
        (``""`` if the model returned no content).
    """
    image_parts = [
        {
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{_encode_image_jpeg(image)}",
                "detail": detail,
            },
        }
        # Order matters: the original image first, the modified one second.
        for image in (image_input, image_solution)
    ]

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": SYSTEM},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": PROMPT.format(instruction=instruction)},
                    *image_parts,
                ],
            },
        ],
        temperature=temperature,
        max_completion_tokens=4096,
        top_p=1,
        stream=False,
    )

    # `message.content` is Optional in the SDK (e.g. None on a refusal).
    # Normalise to "" so parsing never raises and the stored column is a str.
    response = completion.choices[0].message.content or ""
    return assess_response(response), response

In [48]:

def classification(row):
    """``datasets.map`` callback: judge one example and store the result.

    Only the first entry of ``row["image_solution"]`` is judged. Two columns
    are written onto the row: ``instruction_applied`` (the parsed verdict)
    and ``response`` (the raw model text).
    """
    first_solution = row["image_solution"][0]
    verdict, raw_text = check_modification(
        first_solution, row["instruction"], row["image_input"]
    )
    row["instruction_applied"] = verdict
    row["response"] = raw_text
    return row

from datasets import load_dataset

# Load the tikz benchmark split, keep only the columns the judge needs, and
# run the VLM judge on every example. Each example costs one API call, so a
# re-run of this cell re-queries the model.
ds = (
    load_dataset("CharlyR/varbench", "tikz", split="benchmark")
    .select_columns(["id", "instruction", "image_solution", "image_input"])
    .map(classification)
)


Map: 100%|██████████| 100/100 [05:06<00:00,  3.06s/ examples]


In [49]:
# Publish the judged examples as config "input_provided_gpt4o-mini", split "test".
ds.push_to_hub(
    "CharlyR/vTikz-vlm_oracl_benchmark",
    config_name="input_provided_gpt4o-mini",
    split="test",
)

Map: 100%|██████████| 100/100 [00:00<00:00, 6847.62 examples/s]t/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 166.57ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it]


CommitInfo(commit_url='https://huggingface.co/datasets/CharlyR/vTikz-vlm_oracl_benchmark/commit/c18c623f552654b99226f18344850e71f48f23d9', commit_message='Upload dataset', commit_description='', oid='c18c623f552654b99226f18344850e71f48f23d9', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/CharlyR/vTikz-vlm_oracl_benchmark', endpoint='https://huggingface.co', repo_type='dataset', repo_id='CharlyR/vTikz-vlm_oracl_benchmark'), pr_revision=None, pr_num=None)

In [50]:
# Switch to pandas for the summary statistics below.
df = ds.to_pandas()

In [51]:
# Cast the boolean verdict to 0/1 so describe() reports numeric statistics.
df = df.assign(instruction_applied=df["instruction_applied"].astype(int))


In [52]:
# On a 0/1 column, `mean` is the fraction of examples judged as applied.
df.loc[:, "instruction_applied"].describe()

count    100.000000
mean       0.720000
std        0.451261
min        0.000000
25%        0.000000
50%        1.000000
75%        1.000000
max        1.000000
Name: instruction_applied, dtype: float64

In [54]:
# Number of examples the judge marked as NOT applied.
int((df["instruction_applied"] == 0).sum())

28