In [1]:
import json
import re
import os
import time
from typing import Dict, Any, List
from tqdm import tqdm
import litellm
from openai import OpenAI, AsyncOpenAI
from datasets import load_dataset
import asyncio
import random

In [2]:
def process_graph_entry(ex, idx: int):
    """
    Convert one Graph row into the shared record schema.

    The row's pre-written ``prompt`` field (which already describes the
    graph) is used verbatim, and ``solution`` becomes the gold answer.

    Args:
        ex: A dataset row; fields are read with ``.get`` so a missing field
            becomes ``None``.
        idx: Integer index assigned by the caller and stored as-is. In this
            notebook it is a *global* index shared across all tasks (not a
            per-dataset position), and the combined-dataset builder starts it
            at 0 -- it is NOT 1-based.

    Returns:
        dict with keys ``index``, ``task`` (always ``"graph"``), ``prompt``,
        ``gold_answer`` and ``meta`` (graph params, edge list and the row's
        original ``id``).
    """
    return {
        "index": idx,  # caller-supplied global index (unique across tasks)
        "task": "graph",
        "prompt": ex.get("prompt"),
        "gold_answer": ex.get("solution"),
        "meta": {
            "graph_params": ex.get("graph_params"),
            "edges": ex.get("edges"),
            "original_id": ex.get("id"),  # keep original unique id here
        },
    }

In [3]:
def process_infobench_entry(ex, idx: int):
    """
    Convert one Infobench row into the shared record schema.

    The prompt is the row's ``instruction``. When the row also carries a
    non-empty ``input`` (extra context), that context is appended after a
    blank line. Missing or ``None`` fields are treated as empty strings.

    Infobench has no single gold answer, so ``gold_answer`` is ``None``; the
    decomposed questions, category, subset and original id go into ``meta``.
    """
    instruction = ex.get("instruction", "") or ""
    extra_context = ex.get("input", "") or ""

    # Only insert the blank-line separator when there is context to append.
    prompt = f"{instruction}\n\n{extra_context}" if extra_context else instruction

    meta = {
        "decomposed_questions": ex.get("decomposed_questions"),
        "category": ex.get("category"),
        "subset": ex.get("subset"),
        "original_id": ex.get("id"),
    }
    record = {
        "index": idx,
        "task": "infobench",
        "prompt": prompt,
        "gold_answer": None,
        "meta": meta,
    }
    return record

In [4]:
def process_mmlu_med_entry(ex, idx: int):
    """
    Convert one MMLU-Med row into the shared record schema.

    The prompt shows the question, then each choice on its own line labelled
    with consecutive capital letters starting at "A". The row's ``answer``
    (a 0-based position in ``choices``) is mapped to the matching letter for
    ``gold_answer`` (0 -> "A", 1 -> "B", ...); a missing/None answer gives
    ``None``. The raw answer index and choice list are kept in ``meta``.
    """
    question = ex.get("question", "") or ""
    options = ex.get("choices", []) or []

    # Build "A. <choice>", "B. <choice>", ... lines.
    option_lines = []
    for pos, option in enumerate(options):
        letter = chr(ord("A") + pos)
        option_lines.append(f"{letter}. {option}")
    choices_block = "\n".join(option_lines)

    prompt = f"Question:\n{question}\n\nChoices:\n{choices_block}"

    raw_answer = ex.get("answer")
    if raw_answer is None:
        gold = None
    else:
        gold = chr(ord("A") + raw_answer)

    return {
        "index": idx,
        "task": "mmlu_med",
        "prompt": prompt,
        "gold_answer": gold,
        "meta": {
            "subject": ex.get("subject"),
            "answer_index": raw_answer,
            "raw_choices": options,
            "original_id": ex.get("id"),
        },
    }


In [5]:
def get_combined_dataset(
    n_each=5,
    seed=None,
    start_index=0,
    repo="vashistht/11763_datasets",
    split="dev_test",
):
    """
    Build one shuffled list of records from the Graph, Infobench and
    MMLU-Med splits.

    The first ``n_each`` rows of each source are converted to the shared
    record schema by the matching ``process_*_entry`` function. Every record
    gets a unique integer ``index`` assigned sequentially *before* shuffling
    (Graph first, then Infobench, then MMLU-Med). Indices therefore run
    ``start_index .. start_index + total - 1`` no matter how the list is
    ordered afterwards.

    NOTE: indices default to 0-based. The downstream request code keys its
    results by 0-based index, so starting at 1 would misalign evaluation.

    Args:
        n_each: Maximum number of rows to take from each source. If a split
            has fewer rows, all of its rows are used instead of raising.
        seed: Optional seed for the final shuffle. ``None`` (default) uses
            the global ``random`` state, as before; pass an int to get a
            reproducible order.
        start_index: First index to assign (default 0; see NOTE above).
        repo: Hugging Face dataset repository to load every config from.
        split: Split name loaded for every config.

    Returns:
        list[dict]: The combined records, in shuffled order.
    """
    # (HF config name, label used in progress messages, row converter)
    sources = [
        ("graph_dev", "Graph", process_graph_entry),
        ("infobench", "Infobench", process_infobench_entry),
        ("mmlu_med", "MMLU Med", process_mmlu_med_entry),
    ]

    print("Loading datasets...")
    # Load every split up front so a bad config fails before any processing.
    loaded = [
        (label, convert, load_dataset(repo, config, split=split))
        for config, label, convert in sources
    ]

    combined_data = []
    next_idx = start_index
    for label, convert, ds in loaded:
        # Clamp so a short split doesn't make `select` raise an IndexError.
        n_take = min(n_each, len(ds))
        print(f"Processing {n_take} examples from {label}...")
        for ex in ds.select(range(n_take)):
            combined_data.append(convert(ex, next_idx))
            next_idx += 1

    # Only the list order is randomized; each record keeps its index.
    rng = random.Random(seed) if seed is not None else random
    rng.shuffle(combined_data)

    return combined_data

In [6]:
N_EACH_PER_TASK = 100  # rows drawn from each of the graph / infobench / mmlu_med sources
dataset = get_combined_dataset(n_each=N_EACH_PER_TASK)


def print_sample(task_name, data):
    """Pretty-print the first record in ``data`` whose ``task`` is ``task_name``.

    Prints a header line and then the record as indented JSON. Prints nothing
    when no record matches.
    """
    match = None
    for record in data:
        if record["task"] == task_name:
            match = record
            break

    if not match:
        return

    print(f"\n--- Sample {task_name} ---")
    print(json.dumps(match, indent=2))


# Set to True to print one example record per task before saving.
SHOW_SAMPLES = False

out_path = "combined_dataset_full.jsonl"

if dataset:
    if SHOW_SAMPLES:
        for task in ("graph", "infobench", "mmlu_med"):
            print_sample(task, dataset)

    # Save as JSON Lines, one record per line. newline="\n" keeps "\n" line
    # endings on every platform (plain text mode would write "\r\n" on
    # Windows). ensure_ascii=False keeps non-ASCII prompt text readable.
    with open(out_path, "w", encoding="utf-8", newline="\n") as f:
        for item in dataset:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
    print(f"\nWrote {len(dataset)} items to {out_path}")
else:
    # Make it visible that any existing file from an earlier run is stale.
    print(f"Dataset is empty; {out_path} was not (re)written.")

Loading datasets...
Processing 100 examples from Graph...
Processing 100 examples from Infobench...
Processing 100 examples from MMLU Med...

Wrote 300 items to combined_dataset_full.jsonl
