In [None]:
import os

from mistralai import Mistral

# Comma-separated list of Mistral API keys; one client is created per key.
api_keys = os.environ["MISTRAL_API_KEYS"]
model = "mistral-large-2411"

# Surrounding whitespace is stripped and empty entries are dropped, so values
# such as "key1, key2," do not yield clients with blank or space-padded keys.
clients = [
    Mistral(
        api_key=api_key,
    )
    for api_key in map(str.strip, api_keys.split(","))
    if api_key
]
if not clients:
    # Fail fast: later code computes `request_id % len(clients)`.
    raise ValueError("MISTRAL_API_KEYS must contain at least one non-empty API key")

for client in clients:
    # Check API is alive: one tiny completion per key, printed for inspection.
    chat_response = client.chat.complete(
        model=model,
        messages=[
            {
                "role": "user",
                "content": "What is the best French cheese?",
            },
        ],
        max_tokens=10,
    )
    print(chat_response.model, chat_response.choices[0].message.content)

In [None]:
from time import sleep

from mistralai import SDKError
from openai import RateLimitError

# Default pause (in seconds) used by wait(); it gets shorter as more clients
# are available: 1.2 for fewer than two, 0.5 for two, 0.2 for three or more.
if len(clients) >= 3:
    SLEEP_DURATION = 0.2
elif len(clients) == 2:
    SLEEP_DURATION = 0.5
else:
    SLEEP_DURATION = 1.2

print("Sleep duration:", SLEEP_DURATION)


def wait(duration=SLEEP_DURATION):
    """Sleep for ``duration`` seconds.

    The default is the value ``SLEEP_DURATION`` had when this function was
    defined; reassigning the constant afterwards does not change it.
    """
    sleep(duration)


# Count of recorded rate-limit errors per client, keyed by the client's index
# in `clients`. Re-created with zeroed counters by init_api_limits().
api_limit_hits_by_client_ids = {}


def init_api_limits() -> None:
    """Replace the rate-limit counters with a fresh dict holding 0 for every client index."""
    global api_limit_hits_by_client_ids

    api_limit_hits_by_client_ids = dict.fromkeys(range(len(clients)), 0)


# Number of requests issued so far by query_model(), which uses
# `request_id % len(clients)` to pick the client for the next request.
request_id = 0


def repeat_if_hit_api_limit(f):  # (1)
    """Decorate ``f`` so that it is called again until it returns.

    Retry policy of the returned wrapper:

    * ``RateLimitError``, or ``SDKError`` with ``status_code == 429``: the hit
      is counted in ``api_limit_hits_by_client_ids`` and the call is retried
      after 2 seconds. Every 10th hit (in total) a summary is printed.
    * ``SDKError`` with any other status code: re-raised to the caller.
    * any other ``Exception``: printed, then retried after 60 seconds. This is
      a deliberate best-effort loop and never gives up on its own.

    Args:
        f: the callable to protect.

    Returns:
        A wrapper with the same call signature and metadata as ``f``.
    """
    from functools import wraps

    def record_api_limit_hit():
        # query_model() increments request_id *before* sending the request, so
        # the client that just failed is the previous slot, not the current one.
        client_id = (request_id - 1) % len(clients)
        # .get() keeps the handler from raising KeyError when the counters have
        # not been initialised by init_api_limits() yet.
        api_limit_hits_by_client_ids[client_id] = api_limit_hits_by_client_ids.get(client_id, 0) + 1

        total_hits = sum(api_limit_hits_by_client_ids.values())
        if (total_hits % 10) == 0:
            print(f"API limit hit {total_hits} times. Details: {api_limit_hits_by_client_ids}")

    @wraps(f)
    def wrapper(*args, **kw):  # (2)
        while True:
            try:
                return f(*args, **kw)
            except RateLimitError:
                record_api_limit_hit()
                wait(2)
            except SDKError as e:
                if e.status_code != 429:
                    # Not a rate limit: let the caller see the original error.
                    raise
                record_api_limit_hit()
                wait(2)
            except Exception as e:
                print("repeat_if_hit_api_limit -> unknown error", e)
                wait(60)

    return wrapper


@repeat_if_hit_api_limit
def query_model(messages):
    """Send ``messages`` to the chat model through the next client in rotation.

    The client is chosen as ``request_id % len(clients)`` and ``request_id``
    is advanced before the request is sent, so a retried call moves on to the
    following client.

    Returns:
        The response object returned by ``client.chat.complete``.
    """
    global request_id

    slot = request_id % len(clients)
    request_id += 1
    return clients[slot].chat.complete(model=model, messages=messages)

In [None]:
import os
import sys

# dir2 is the current working directory and dir1 its parent; the parent is
# appended to sys.path (once) so that packages located there can be imported.
dir2 = os.path.abspath(os.curdir)
dir1 = os.path.dirname(dir2)
if sys.path.count(dir1) == 0:
    sys.path.append(dir1)

import importlib

import utils.prompt as prompt

# Re-execute the already imported module so that edits made to it after the
# first import are picked up without restarting the kernel.
importlib.reload(prompt)

In [None]:
import csv
import os.path
import re

import pandas as pd
from tqdm import tqdm

import utils.prompt as prompt

# difficulty = ["middle_school", "high_school", "undergraduate", "postgraduate", "phd"]
# Integer scale 1..10 inclusive.
ratings = [*range(1, 11)]

# Counters reset at the start of estimate_dataset(). invalid_complexities is
# incremented for every response that cannot be parsed; invalid_ratings is
# only reset and reported by the code in this notebook.
invalid_complexities = invalid_ratings = 0

# Names of the dataframe columns that receive the parsed model answers.
FIELD_REQUIRES_KNOWLEDGE, FIELD_REQUIRES_REASONING, FIELD_NUM_REASONING_STEPS = (
    "masj_requires_knowledge",
    "masj_requires_reasoning",
    "masj_num_reasoning_steps",
)


def estimate_reasoning_complexity_with_model(df, index, system_prompt, user_prompt):
    """Query the model about one question and store its parsed answers in ``df``.

    The response text is searched for the tags ``[[Requires knowledge: ...]]``,
    ``[[Requires reasoning: ...]]`` and ``[[Number of reasoning steps: ...]]``.
    Each answer is stripped, lower-cased, checked against the accepted values
    from the ``prompt`` module and written to the matching column of row
    ``index``. Processing stops at the first tag that is missing or invalid
    (answers written before that point are kept); the response is then printed
    and ``invalid_complexities`` is incremented.

    Args:
        df: DataFrame that receives the answers (modified in place).
        index: label of the row to update.
        system_prompt: content of the system message.
        user_prompt: question text placed between the question markers.
    """
    global invalid_complexities

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": f"""
                [Question Start]
                {user_prompt}
                [Question End]
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content

    # (pattern capturing the answer, destination column, accepted answers)
    expected_tags = (
        (r"\[\[Requires knowledge:(.+?)\]\]", FIELD_REQUIRES_KNOWLEDGE, prompt.valid_requires_answers),
        (r"\[\[Requires reasoning:(.+?)\]\]", FIELD_REQUIRES_REASONING, prompt.valid_requires_answers),
        (
            r"\[\[Number of reasoning steps:(.+?)\]\]",
            FIELD_NUM_REASONING_STEPS,
            prompt.valid_num_reasoning_steps_answers,
        ),
    )

    try:
        for pattern, field, valid_answers in expected_tags:
            match = re.search(pattern, response)
            if match is None:
                raise ValueError(f"tag for {field} not found")
            answer = match.group(1).strip().lower()
            # Explicit check instead of `assert`, which is removed under `python -O`.
            if answer not in valid_answers:
                raise ValueError(f"unexpected answer for {field}: {answer!r}")
            df.at[index, field] = answer
    except Exception:
        # `Exception` rather than a bare except, so that KeyboardInterrupt and
        # SystemExit still stop a long-running annotation loop.
        print(f"Could not extract from response:\n{response}\n")
        invalid_complexities += 1


# Number of processed (non-skipped) rows between intermediate saves of the
# output file in estimate_dataset().
DUMP_EVERY = 50


def estimate_dataset(in_filename, out_filename, get_question_from_row, get_options_from_row, original_separators=True):
    """Annotate every row of a TSV dataset with the model's complexity answers.

    When ``out_filename`` already exists it is loaded instead of
    ``in_filename`` and rows whose three result columns already contain valid
    answers are skipped, so an interrupted run can be resumed. The frame is
    written to ``out_filename`` every ``DUMP_EVERY`` processed rows and once
    more at the end. The module-level counters ``invalid_complexities`` and
    ``invalid_ratings`` and the API-limit counters are reset at the start.

    Args:
        in_filename: path of the source TSV file.
        out_filename: path of the annotated TSV file (read if present, always written).
        get_question_from_row: callable that takes a row and returns the question.
        get_options_from_row: callable that takes a row and returns the options.
        original_separators: if True, ``in_filename`` is parsed with pandas'
            default quoting; otherwise with quoting disabled and backslash as
            the escape character (the dialect used for ``out_filename``).

    Returns:
        The annotated DataFrame.
    """
    global invalid_complexities
    global invalid_ratings

    result_fields = (FIELD_REQUIRES_KNOWLEDGE, FIELD_REQUIRES_REASONING, FIELD_NUM_REASONING_STEPS)
    # Dialect of the files written by this function.
    unquoted_tsv = {"sep": "\t", "quoting": csv.QUOTE_NONE, "quotechar": "", "escapechar": "\\"}

    if os.path.isfile(out_filename):
        # Resume from a previous run. The result columns are read back as
        # strings: with inferred dtypes an empty column becomes float64 and
        # digit-only answers become numbers, which would neither compare equal
        # to the string answers nor accept new string values.
        df = pd.read_csv(
            out_filename,
            header=0,
            dtype={field: str for field in result_fields},
            **unquoted_tsv,
        )
    elif original_separators:
        df = pd.read_csv(
            in_filename,
            sep="\t",
            header=0,
        )
    else:
        df = pd.read_csv(in_filename, header=0, **unquoted_tsv)

    invalid_complexities = 0
    invalid_ratings = 0
    init_api_limits()

    for field in result_fields:
        if field not in df.columns:
            df[field] = ""

    meaningful_iteration = 0
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        # Skip rows that are already fully annotated. Missing values (NaN or
        # "") are never members of the valid answers, so no type check is
        # needed; the previous `isinstance(..., float)` guard could never be
        # true together with the membership tests and disabled resuming.
        if (
            df.at[index, FIELD_REQUIRES_KNOWLEDGE] in prompt.valid_requires_answers
            and df.at[index, FIELD_REQUIRES_REASONING] in prompt.valid_requires_answers
            and df.at[index, FIELD_NUM_REASONING_STEPS] in prompt.valid_num_reasoning_steps_answers
        ):
            continue

        meaningful_iteration += 1

        complexity_user_prompt = prompt.get_user_prompt(get_question_from_row(row), get_options_from_row(row))

        estimate_reasoning_complexity_with_model(
            df, index, prompt.estimate_reasoning_complexity_system_prompt, complexity_user_prompt
        )

        if meaningful_iteration % DUMP_EVERY == 0:
            # Periodic checkpoint so that at most DUMP_EVERY answers are lost.
            df.to_csv(out_filename, index=False, **unquoted_tsv)
            total_hits = sum(api_limit_hits_by_client_ids.values())
            print(f"Over {meaningful_iteration} iterations we hit {total_hits} API limits")

    df.to_csv(out_filename, index=False, **unquoted_tsv)
    print(
        f"Processed dataset {out_filename}. Total entries: {df.shape[0]}. Invalid complexities: {invalid_complexities}. Invalid ratings: {invalid_ratings}"
    )
    return df

In [None]:
# MMLU
import ast

# Annotate the MMLU file. The "options" cell is parsed as a Python literal
# (presumably a list of answer options; confirm against the data file).
estimate_dataset(
    "../data/mmlu_final.tsv",
    "../data/mmlu_pro_stem_reasoning_score.tsv",
    original_separators=False,
    get_options_from_row=lambda record: ast.literal_eval(record["options"]),
    get_question_from_row=lambda record: record["question"],
)