In [3]:
import asyncio
import aiohttp
from datasets import load_dataset

# Load MMLU Dataset
#dataset_name = "LFrancis/MMLU-NoOp-Plus"
dataset_name = "cais/mmlu"
dataset = load_dataset(dataset_name, "all")["test"]

# VLLM API Configuration
BASE_URL = "http://134.76.18.30:8080/v1/chat/completions"  # VLLM chat completions endpoint
HEADERS = {"Content-Type": "application/json", "Authorization": "Bearer 9c89c616-649e-4d77-a6ad-1b1e525f94b5"}
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"  # Replace with the actual model name on your server
options = ["A", "B", "C", "D"]

In [13]:

# Helper Functions
def create_chat_messages(question, choices, subject, sys_msg):
    """
    Create a formatted list of chat messages for the chat model.
    """
    user_prompt = f"{question}\n" + "\n".join(
        [f"{opt}. {choice}" for opt, choice in zip(options, choices)]
    ) + "\nAnswer:"
    return [
        {"role": "system", "content": sys_msg.format(subject)},
        {"role": "user", "content": user_prompt}
    ]

async def query_vllm_api(session, payload):
    """
    Send a query to the VLLM API and return the response.
    """
    async with session.post(BASE_URL, json=payload, headers=HEADERS) as response:
        result = await response.json()
        return result

from datasets import Dataset
import aiohttp
from tqdm.asyncio import tqdm_asyncio, tqdm


async def evaluate_question(session, entry):
    """
    Process a single dataset entry: Generate CoT and calculate logprobs.
    Adds reasoning and logprobs as new fields in the entry.
    """
    sys_msg = "The following are multiple choice questions (with answers) about {}."
    question, choices, subject = entry["question"], entry["choices"], entry["subject"]

    # Step 1: Generate reasoning (CoT) response
    messages = create_chat_messages(question, choices, subject, sys_msg)
    cot_payload = {
        "model": MODEL_NAME,  # Specify model
        "messages": messages,
        "max_tokens": 200,
        "temperature": 0.0,
    }

    cot_response = await query_vllm_api(session, cot_payload)
    if "object" in cot_response.keys() and cot_response["object"] == "error":
        raise Exception(cot_response["message"])
    cot_text = cot_response["choices"][0]["message"]["content"].strip()  # Extract CoT reasoning

    # Step 2: Calculate logprobs for each choice
    final_prompt = f"{cot_text}\nFinal Answer: "
    logprobs = {}
    for idx, option in enumerate(choices):
        choice_messages = [
            *messages,
            {"role": "system", "content": sys_msg.format(subject)},
            {"role": "user", "content": final_prompt + f" {option}"}
        ]
        choice_payload = {
            "model": MODEL_NAME,
            "messages": choice_messages,
            "max_tokens": 1,
            "temperature": 0.0,
            "prompt_logprobs": 0
        }
        choice_response = await query_vllm_api(session, choice_payload)
        if "prompt_logprobs" not in choice_response:
            raise Exception(f"No prompt logprobs found for {option}")
        logprobs[option] = list(choice_response["prompt_logprobs"][-6].values())[0]["logprob"]
    entry["logprobs"] = logprobs
    return entry

from itertools import islice

import asyncio
from itertools import islice
from tqdm.asyncio import tqdm_asyncio
import aiohttp

# Configuration
BATCH_SIZE = 50
CONCURRENT_REQUESTS = 10  # Adjust based on your system's capabilities
MAX_RETRIES = 0
RETRY_DELAY = 5
timeout = aiohttp.ClientTimeout(total=120)

# Semaphore for limiting concurrency
semaphore = asyncio.Semaphore(CONCURRENT_REQUESTS)

# Chunking the dataset
def chunk_dataset(dataset, batch_size):
    it = iter(dataset)
    for i in range(0, len(dataset), batch_size):
        yield list(islice(it, batch_size))

# Retry logic
async def evaluate_question_with_retries(session, entry, max_retries=MAX_RETRIES):
    for attempt in range(max_retries):
        try:
            async with semaphore:  # Control concurrency
                return await evaluate_question(session, entry)
        except Exception as e:
            print(f"Attempt {attempt + 1} failed for entry {entry}: {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(RETRY_DELAY)
            else:
                # Log or handle the failure after retries
                return {"entry": entry, "error": str(e)}

async def process_batch(session, batch):
    tasks = [evaluate_question_with_retries(session, entry) for entry in batch]
    results = []

    # Use asyncio.gather with exception handling
    for coro in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing batch"):
        try:
            result = await coro  # Await individual tasks
            results.append(result)
        except Exception as e:
            print(f"Task failed with exception: {e}")
            results.append({"error": str(e)})

    return results


# Process the entire dataset in batches
async def process_dataset_in_batches(dataset):
    async with aiohttp.ClientSession(timeout=timeout) as session:
        all_results = []
        for batch in chunk_dataset(dataset, BATCH_SIZE):
            batch_results = await process_batch(session, batch)
            all_results.extend(batch_results)
        return all_results

def is_correct(entry):
    """
    Determines if the choice with the lowest log probability corresponds to the correct answer.

    Args:
        entry (dict): A dictionary containing the question, choices, answer index, and logprobs.

    Returns:
        bool: True if the option with the lowest logprob matches the correct answer index, False otherwise.
    """
    # Extract logprobs and the correct answer index
    logprobs = entry['logprobs']
    correct_answer_index = entry['answer']

    # Find the key (A, B, C, D) with the lowest logprob
    highest_logprob_option = max(logprobs, key=logprobs.get)

    # Map the key to its corresponding index (0 for 'A', 1 for 'B', etc.)
    options = ["A", "B", "C", "D"]
    highest_logprob_index = options.index(highest_logprob_option)

    # Check if the lowest logprob index matches the correct answer index
    return highest_logprob_index == correct_answer_index

In [15]:

async def main():
    from tqdm.asyncio import tqdm_asyncio
    """
    Main function to evaluate the dataset asynchronously.
    """
    # Process the dataset asynchronously
    processed_dataset = await process_dataset_in_batches(dataset)
    print(processed_dataset)
    # Convert back to Hugging Face Dataset if needed
    processed_dataset = Dataset.from_list(processed_dataset)

    # Save the updated dataset
    processed_dataset.save_to_disk(dataset_name+"_evaluated_"+MODEL_NAME)

    # Save or process results as needed
    score = [is_correct(result) for result in processed_dataset]
    score = sum(score) / len(score)
    print("Accuracy", score)

# Run the script
await main()

Processing batch: 100%|██████████| 50/50 [00:00<00:00, 17524.46it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 58012.50it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 38255.24it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 63704.50it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 88975.48it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 98457.84it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 104335.92it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 64847.00it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 90903.86it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 89967.91it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 73765.46it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 106184.91it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 94551.49it/s]
Processing batch: 100%|██████████| 50/50 [00:00<00:00, 93289.68it/s]
Processing batch: 100%|█████████

[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, Non

TypeError: 'NoneType' object is not iterable