## Instructions
Run the cells in order to:
1. Get reasoning traces from the LLM via the API
2. Structure the data in the format accepted by the LLM
3. Create a Hugging Face version of the dataset

Note: Set the Gemini API keys as environment variables (`GEMINI_API_KEY_0` to `GEMINI_API_KEY_5`) before running.
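The client cell reads six keys named `GEMINI_API_KEY_0` through `GEMINI_API_KEY_5`. A minimal sketch for failing fast when some of them are unset; `missing_keys` is a hypothetical helper, not part of the original notebook:

```python
import os

def missing_keys(count=6, prefix="GEMINI_API_KEY_"):
    """Return the names of expected API-key environment variables that are unset or empty."""
    names = [f"{prefix}{i}" for i in range(count)]
    return [name for name in names if not os.getenv(name)]

missing = missing_keys()
if missing:
    print("Missing API keys:", ", ".join(missing))
```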

In [None]:
import os, json
import base64
from openai import OpenAI
import datasets
IMG_DIR = "data/filtered_images_5k/"

### **Set the following paths as needed**

In [None]:
OUTPUT_FILE = "data/train_cot.json"
OUTPUT_HF_DIR = "data/hf_dataset/"

In [None]:
with open("data/filtered_data_5k.json", "r") as f:
    data = json.load(f)

len(data)

5000
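The generation loop reads `query` and `imgname` from every record and builds the image path from `IMG_DIR`. A quick sanity check, sketched as a hypothetical helper, lists records that would fail before any API quota is spent:

```python
import os

def find_bad_records(records, img_dir):
    """Return (index, reason) pairs for records the generation loop could not use."""
    bad = []
    for i, rec in enumerate(records):
        # The loop reads these two fields from every record
        missing = [key for key in ("query", "imgname") if not rec.get(key)]
        if missing:
            bad.append((i, f"missing fields: {missing}"))
        elif not os.path.isfile(os.path.join(img_dir, rec["imgname"])):
            bad.append((i, "image not found"))
    return bad
```

Running `find_bad_records(data, IMG_DIR)` before the generation loop should return an empty list; any reported index points at a record that would otherwise fail or be skipped later.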

In [None]:
# One OpenAI-compatible client per Gemini API key (GEMINI_API_KEY_0 ... GEMINI_API_KEY_5)
BASE_URL = "https://generativelanguage.googleapis.com/v1beta/openai/"
api_keys = [os.getenv(f"GEMINI_API_KEY_{i}") for i in range(6)]
all_clients = [OpenAI(api_key=key, base_url=BASE_URL) for key in api_keys]

client = all_clients[0]  # default client used by ask_lm
# Keys are tried in reverse order: client_5, client_4, ..., client_0
clients = list(reversed(all_clients))

In [None]:
model_id = "gemini-2.0-flash-thinking-exp-01-21"

In [None]:

SYSTEM_PROMPT = """You are a Vision Language Model specialized in interpreting visual data from chart images.

## MAIN GOAL
Your task is to analyze the provided chart image and respond to queries with concise answers, usually a single word, number, or short phrase.
The charts include a variety of types (e.g., line charts, bar charts) and contain colors, labels, and text.
Focus on delivering accurate, succinct answers based on the visual information. Avoid additional explanation unless absolutely necessary.


## THINKING PROCESS
- Use short, simple sentences that mirror natural thought patterns addressing the question.

## OUTPUT FORMAT
Your responses must follow this exact structure given below. Make sure to always include the final answer.

```
<think>
[Your thinking process goes here]   <--- SHOULD NOT BE MORE THAN 2-3 LINES.
</think>

<answer>
[The final answer to the question goes here] <--- should be very concise, usually a single word, number, or short phrase.
</answer>
```

## EXAMPLE RESPONSES
```
1. <think>\nThe user is asking to identify the largest bar in the chart.\nComparing the heights of the bars visually, the 'Total revenue' bar is the tallest, indicating the greatest value.\n</think>\n\n<answer>\nTotal revenue\n</answer>
2. <think>\nThe user wants the year with the minimum number of officers.\nI need to find the lowest point on the line graph and read the corresponding year from the x-axis.\nThe lowest value on the line is clearly above the '2013' label.\n</think>\n\n<answer>\n2013\n</answer>
3. <think>\nThe query asks for the average of the navy blue bars.\nFirst, identify the values of the four navy blue bars: 31, 47, 18, and 3.\nSum these values (31+47+18+3=99) and divide by the number of bars (4) to get the average, which is 99/4 = 24.75.\n</think>\n\n<answer>\n24.75\n</answer>
4. <think>\nThe user is asking for the total number of officers in 2015.\nI need to find the bar corresponding to 2015 and read the value from the y-axis.\nThe bar for 2015 is clearly labeled with a value of 50.\n</think>\n\n<answer>\n50\n</answer>
```

## STYLE GUIDELINES ON THINKING PROCESS

NOTE: Your response MUST follow the specified output format, and the answer inside the <answer> tag should be concise: generally a single word, a number, a yes or no, or a short name/phrase.
NOTE: Keep the thinking process VERY CONCISE and to the point; avoid unnecessary details.
NOTE: For numeric answers, write the number in digits rather than words, and write fractions as decimals.
NOTE: When two words/phrases are the answer, separate them by a comma. For example: "yes, no", "China, India", "2015, 2016"

NOTE: The reasoning process inside <think> tags MUST STRICTLY BE UNDER 3 LINES: extremely concise, to the point, and free of anything unnecessary.
"""

In [None]:
# Function to encode the image
def encode_image(image_path):
  with open(image_path, "rb") as image_file:
    return base64.b64encode(image_file.read()).decode('utf-8')
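`ask_lm` labels every image as `image/png`. If the image folder ever mixes formats, the MIME type can be guessed from the file extension with the standard-library `mimetypes` module. A sketch; the PNG fallback and the `to_data_url` name are assumptions:

```python
import base64
import mimetypes

def to_data_url(image_path, default_mime="image/png"):
    """Encode an image file as a base64 data URL, guessing the MIME type from its extension."""
    mime, _ = mimetypes.guess_type(image_path)
    if mime is None or not mime.startswith("image/"):
        # Fall back to PNG, matching what ask_lm assumes
        mime = default_mime
    with open(image_path, "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime};base64,{encoded}"
```

Guessing from the extension avoids reading image headers; it is only as reliable as the file names in `IMG_DIR`.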

In [None]:
def ask_lm(prompt, image_path=None, api_client=None):
    """Send a chart question (and optional image) to the model and return its reply text."""
    if api_client is None:
        api_client = client
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT
        }
    ]
    prompt = {"type": "text", "text": prompt}
    user_content = [prompt]

    if image_path:
        # Let file errors propagate so the caller records the item as failed,
        # instead of returning an error string that would be saved as output
        encoded_image = encode_image(image_path)

        user_content.append({
            "type": "image_url",
            "image_url": {
                # The images are assumed to be PNG; adjust the MIME type if needed
                "url": f"data:image/png;base64,{encoded_image}"
            }
        })

    messages.append({
        "role": "user",
        "content": user_content
    })

    response = api_client.chat.completions.create(
        model=model_id,
        messages=messages,
    )
    # print(response)
    return response.choices[0].message.content
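For debugging, it can help to inspect the exact request `ask_lm` sends without spending quota. A sketch that rebuilds the same OpenAI-compatible `messages` list (a system message, then one user message holding a text part followed by an `image_url` part); `build_messages` is a hypothetical helper:

```python
def build_messages(system_prompt, prompt, data_url=None):
    """Mirror the message structure that ask_lm passes to chat.completions.create."""
    user_content = [{"type": "text", "text": prompt}]
    if data_url is not None:
        user_content.append({"type": "image_url", "image_url": {"url": data_url}})
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_content},
    ]
```

Printing `json.dumps(build_messages(SYSTEM_PROMPT, qs, url), indent=2)` for one record shows the request body that the API receives.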


In [None]:

OUTPUT = []        # items with a model response
FAILED = []        # items for which every client failed
ERRORS = []        # one record per failed API call
CLIENT_COUNT = [0] * len(clients)  # successful calls per client
cnt = 0
x = 0              # start offset; set > 0 to resume a partial run
client_id = 0
NUM_CLIENTS = len(clients)

for item in data[x:]:
    qs = item["query"]
    img = IMG_DIR + item["imgname"]
    response = None

    # Try each client at most once, starting from the last one that worked
    try_cnt = 0
    while True:
        if try_cnt >= NUM_CLIENTS:
            FAILED.append(item)
            break
        try_cnt += 1
        try:
            client = clients[client_id]
            response = ask_lm(qs, img, client)
            CLIENT_COUNT[client_id] += 1
            break
        except Exception as e:
            print(f"Client {client_id} failed: {e}")
            ERRORS.append({"imgname": item["imgname"], "client_id": client_id, "error": str(e)})
            # Rotate to the next API key
            client_id = (client_id + 1) % NUM_CLIENTS

    if response is not None:
        item["output"] = response
        OUTPUT.append(item)
    cnt += 1
    if cnt % 25 == 0:
        print(f"{cnt} calls done")
        # Checkpoint the results collected so far
        with open(f"data/train_smolcot_left2k_{x}_on.json", "w") as f:
            json.dump(OUTPUT, f, indent = 4)
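The retry logic in the loop above is a round-robin failover: try the current client, and on any exception move to the next key, giving up once every key has been tried. The same logic as a standalone, testable function; `call_with_failover` is a hypothetical name, and `call` stands in for `ask_lm`:

```python
def call_with_failover(clients, start, call, errors=None):
    """Try call(client) on each client at most once, starting at index `start`.

    Returns (result, index_of_client_that_succeeded), or (None, start) if every
    client raised. Each exception is appended to `errors` when a list is given.
    """
    n = len(clients)
    idx = start
    for _ in range(n):
        try:
            return call(clients[idx]), idx
        except Exception as e:
            if errors is not None:
                errors.append((idx, e))
            # Rotate to the next client; after n failures idx is back at start
            idx = (idx + 1) % n
    return None, idx
```

Reusing the index of the client that last succeeded as the next starting point, as the loop does with `client_id`, keeps traffic on one key until it fails; a strict round-robin over every call would be a different design choice.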


In [None]:
with open(OUTPUT_FILE, "w") as f:
    json.dump(OUTPUT, f, indent = 4)

train_data = OUTPUT

In [None]:
print("Total calls:", cnt)
print("Total successful calls:", len(OUTPUT))
print("Total failed calls:", len(FAILED))
print("Total errors:", len(ERRORS))

for i in range(len(clients)):
    print(f"Client {i} calls:", CLIENT_COUNT[i])

In [None]:
from PIL import Image
import os

def transform_entry(data_entry, image_dir):
    """Convert one generated record into (PIL image, chat messages); return (None, None) if unusable."""

    img_filename = data_entry.get('imgname')
    query = data_entry.get('query')
    output = data_entry.get('output')

    if not all([img_filename, query, output]):
        print(f"Warning: Skipping entry due to missing fields: {data_entry}")
        return None, None

    img_path = os.path.join(image_dir, img_filename)

    try:
        image = Image.open(img_path).convert('RGB')
        # img_bytes = image["bytes"]
        # pil_image = PIL.Image.open(io.BytesIO(img_bytes))
    except FileNotFoundError:
        print(f"Warning: Image file not found at {img_path}. Skipping entry.")
        return None, None
    except Exception as e:
        print(f"Warning: Error loading image {img_path}: {e}. Skipping entry.")
        return None, None

    # Construct the 'messages' field according to the required structure
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": query},
                # 'index': 0 refers to the first image in the 'images' list below
                # 'text': None as seen in the example for image placeholders
                {"type": "image", "index": 0, "text": None}
            ]
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": output, "index": None}
            ]
        }
    ]

    return image, messages



In [None]:
images = []
messages = []
print(f"Starting transformation for {len(train_data)} entries...")
for i, entry in enumerate(train_data):
    if i % 300 == 0:
        print(f"Processing entry {i+1}/{len(train_data)}: {entry.get('imgname')}")
    img, msg = transform_entry(entry, IMG_DIR)
    if img is None or msg is None:
        continue
    images.append(img)
    messages.append(msg)

Starting transformation for 300 entries...
Processing entry 1/300: 12756.png


In [7]:
dataset = datasets.Dataset.from_dict(
    {'images': images, 'messages': messages},
)

In [11]:
from datasets import Image as HFImageFeature, Sequence

def wrap_image_in_list(batch):
    """
    Takes a batch dictionary and wraps each image in the 'images' column
    into its own single-element list.
    """
    batch['images'] = [[img] for img in batch['images']]
    return batch  # Return the modified batch


# Each row's 'images' field becomes a list of decoded images
new_image_feature = Sequence(feature=HFImageFeature(decode=True))

def transform(dataset):
    transformed_dataset = dataset.map(
        wrap_image_in_list,
        batched=True
    )
    transformed_dataset = transformed_dataset.cast_column("images", new_image_feature)
    return transformed_dataset
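With `batched=True`, `Dataset.map` passes `wrap_image_in_list` a dict mapping each column name to a list of row values, so the function turns the `images` column from one image per row into a one-element list per row. That matches the `Sequence(Image())` feature and the `"index": 0` placeholder in each user message. A pure-Python illustration with strings standing in for the PIL images (no `datasets` call involved):

```python
def wrap_image_in_list(batch):
    """Wrap each image in the batch's 'images' column in a single-element list."""
    batch["images"] = [[img] for img in batch["images"]]
    return batch

# A batch as Dataset.map(batched=True) would pass it: column name -> list of row values
batch = {"images": ["img_a", "img_b"], "messages": [["msg_a"], ["msg_b"]]}
wrapped = wrap_image_in_list(batch)
print(wrapped["images"])  # [['img_a'], ['img_b']]
```

Other columns such as `messages` pass through unchanged.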

In [12]:
final_data = transform(dataset)
final_data[0]

Map: 100%|██████████| 300/300 [00:27<00:00, 11.08 examples/s]
Casting the dataset: 100%|██████████| 300/300 [00:00<00:00, 70299.53 examples/s]


{'images': [<PIL.PngImagePlugin.PngImageFile image mode=RGB size=204x394>],
 'messages': [{'content': [{'index': None,
     'text': 'Is the value of "Don\'t know" segment 7%?',
     'type': 'text'},
    {'index': 0, 'text': None, 'type': 'image'}],
   'role': 'user'},
  {'content': [{'index': None,
     'text': '<think>\nThe user is asking to check if the value of "Don\'t know" segment is 7%.\nLooking at the pie chart, the "Don\'t know" segment is clearly labeled with "7%".\nTherefore, the statement is correct.\n</think>\n\n<answer>\nYes\n</answer>',
     'type': 'text'}],
   'role': 'assistant'}]}

In [None]:
final_data.save_to_disk(OUTPUT_HF_DIR)
print("Dataset saved to disk.")2