In [1]:
import os

from mistralai import Mistral

# Comma-separated list of Mistral API keys. One client is created per key;
# query_model later rotates through them round-robin.
api_keys = os.environ["MISTRAL_API_KEYS"]
model = "mistral-large-2411"

# Strip whitespace around keys and ignore empty entries (e.g. "k1, k2," ),
# which would otherwise produce clients with invalid/empty keys.
clients = [Mistral(api_key=api_key.strip()) for api_key in api_keys.split(",") if api_key.strip()]
if not clients:
    # Fail here with a clear message instead of a ZeroDivisionError later
    # when request_id % len(clients) is evaluated.
    raise ValueError("MISTRAL_API_KEYS does not contain any API key")

for client in clients:
    # Smoke test: check every key can reach the API before the long runs below.
    chat_response = client.chat.complete(
        model=model,
        messages=[
            {
                "role": "user",
                "content": "What is the best French cheese?",
            },
        ],
        max_tokens=10,
    )
    print(chat_response.model, chat_response.choices[0].message.content)

mistral-large-2411 Choosing the "best" French cheese can be
mistral-large-2411 Determining the "best" French cheese can


In [2]:
from time import sleep

from mistralai import SDKError
from openai import RateLimitError

# Pause (seconds) inserted after each model request. Requests rotate across
# the clients, so with more API keys a shorter pause is used.
if len(clients) >= 3:
    SLEEP_DURATION = 0.2
elif len(clients) == 2:
    SLEEP_DURATION = 0.5
else:
    SLEEP_DURATION = 1.2

print("Sleep duration:", SLEEP_DURATION)


def wait(duration=SLEEP_DURATION):
    """Sleep for `duration` seconds.

    The default is the value SLEEP_DURATION had when this function was defined
    (Python evaluates default arguments once, at definition time).
    """
    sleep(duration)


# Number of rate-limit hits per client, keyed by the client's index in `clients`.
api_limit_hits_by_client_ids = {}


def init_api_limits() -> None:
    """Replace the global hit counter with a fresh one: 0 hits for every client index."""
    global api_limit_hits_by_client_ids

    api_limit_hits_by_client_ids = {client_id: 0 for client_id in range(len(clients))}


import functools

# Round-robin request counter shared with query_model, which picks
# clients[request_id % len(clients)] and increments the counter *before*
# sending the request.
request_id = 0


def repeat_if_hit_api_limit(f):
    """Decorator that keeps calling `f` until it returns, backing off on API errors.

    - Rate limits (openai ``RateLimitError`` or mistralai ``SDKError`` with
      ``status_code == 429``) are counted per client in
      ``api_limit_hits_by_client_ids``. The call is retried after a 2 s pause.
      A summary is printed whenever the total hit count across all clients is
      a multiple of 10.
    - Any other ``SDKError`` is re-raised unchanged.
    - Any other ``Exception`` is printed and the call is retried after 60 s.

    NOTE(review): the generic branch retries forever, so a persistent
    (non-transient) failure loops until the kernel is interrupted.

    Per-client attribution assumes the wrapped function advances ``request_id``
    before issuing its request, as ``query_model`` does.
    """

    def record_rate_limit_hit():
        # The failed request used the client chosen *before* request_id was
        # incremented, i.e. index request_id - 1. Python's % keeps the result
        # non-negative even if request_id is 0.
        client_id = (request_id - 1) % len(clients)
        api_limit_hits_by_client_ids[client_id] += 1

        total_hits = sum(api_limit_hits_by_client_ids.values())
        if total_hits % 10 == 0:
            print(f"API limit hit {total_hits} times. Details: {api_limit_hits_by_client_ids}")
        wait(2)

    @functools.wraps(f)
    def wrapper(*args, **kw):
        while True:
            try:
                return f(*args, **kw)
            except RateLimitError:
                record_rate_limit_hit()
            except SDKError as e:
                if e.status_code != 429:
                    raise
                record_rate_limit_hit()
            except Exception as e:
                print("repeat_if_hit_api_limit -> unknown error", e)
                wait(60)

    return wrapper


@repeat_if_hit_api_limit
def query_model(messages):
    """Send a chat completion request for `messages` to `model`.

    Clients are used round-robin via the global `request_id`. The counter is
    advanced before the request is sent, so when the decorator retries, the
    retry goes through the next client.
    """
    global request_id

    selected_client = clients[request_id % len(clients)]
    request_id = request_id + 1
    return selected_client.chat.complete(messages=messages, model=model)

Sleep duration: 0.5


In [3]:
import os
import sys

# Put the parent of the kernel's working directory on sys.path so the
# project-local `utils` package can be imported from it.
dir2 = os.path.abspath("")  # kernel working directory
dir1 = os.path.split(dir2)[0]  # its parent directory (same as os.path.dirname(dir2))
if dir1 not in sys.path:
    sys.path.append(dir1)

import importlib

import utils.prompt as prompt

# Re-execute utils/prompt.py so edits to it take effect without restarting the kernel.
importlib.reload(prompt)

<module 'utils.prompt' from '/Users/aigoncharov/dev/sktech/phi-4/utils/prompt.py'>

In [9]:
import csv
import os.path
import re

import pandas as pd
from tqdm import tqdm

import utils.prompt as prompt

# Ratings the judge may return: the integers 1 through 10 inclusive.
ratings = list(range(1, 11))

# Counts of model replies whose value could not be parsed or was out of range.
# Both are reset to 0 at the start of every estimate_dataset run.
invalid_complexities = invalid_ratings = 0


def model_as_judge(df, index, system_prompt, user_prompt, answer):
    """Ask the model to rate `answer` from 1 to 10 and store the rating in `df`.

    The judge is shown the assistant's instructions (`system_prompt`), the
    question (`user_prompt`) and the assistant's answer, and is told to reply
    with "[[N]]". The first "[[<digits>]]" in the reply is parsed:

    - If the value is in `ratings`, it is written to
      df.at[index, "masj_num_rating"].
    - Otherwise, or if no marker is found, the global `invalid_ratings`
      counter is incremented.

    Returns:
        The raw judge reply text.
    """
    global invalid_ratings

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": 'Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user request displayed below. Your evaluation should consider factors such as the following all the settings in the system prompt, correspondences to the context of the user, the helpfulness, relevance and accuracy. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example:"Rating: [[6]]".',
            },
            {
                "role": "user",
                "content": f"""
                [Instructions for Assistant]
                {system_prompt}
                [End of Instructions for Assistant]

                [Question]
                {user_prompt}
                [End of Question]

                [The Start of Assistant’s Answer]
                {answer}
                [The End of Assistant’s Answer]
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content

    try:
        rating_int = int(re.search(r"\[\[(\d+?)\]\]", response).group(1))
    except (AttributeError, TypeError, ValueError):
        # AttributeError: no "[[N]]" marker (re.search returned None).
        # TypeError: the reply content is not a string (e.g. None).
        print(f"Could not extract rating from response:\n{response}\n")
        invalid_ratings += 1
        return response

    if rating_int in ratings:
        df.at[index, "masj_num_rating"] = rating_int
    else:
        invalid_ratings += 1
    return response


def estimate_complextiy_with_model(df, index, system_prompt, user_prompt):
    """Ask the model for a question-complexity score in [0, 1] and store it in `df`.

    The first "[[...]]" marker in the reply is parsed as a float:

    - If it lies in [0.0, 1.0], it is written to
      df.at[index, "masj_num_complexity"].
    - Otherwise (out of range, NaN, not a number, or no marker), the global
      `invalid_complexities` counter is incremented.

    (The misspelled function name is kept for compatibility with callers.)

    Returns:
        The raw model reply text, which is later handed to the judge.
    """
    global invalid_complexities

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": f"""
                [Question Start]
                {user_prompt}
                [Question End]
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content

    try:
        complexity = float(re.search(r"\[\[(.+?)\]\]", response).group(1))
    except (AttributeError, TypeError, ValueError):
        # AttributeError: no "[[...]]" marker; TypeError: reply is not a string;
        # ValueError: marker content is not a number (e.g. "[[complexity]]").
        print(f"Could not extract complexity from response:\n{response}\n")
        invalid_complexities += 1
        return response

    # NaN fails both comparisons and is therefore counted as invalid.
    if 0.0 <= complexity <= 1.0:
        df.at[index, "masj_num_complexity"] = complexity
    else:
        invalid_complexities += 1

    return response


# Write a checkpoint of the output TSV after this many processed rows.
DUMP_EVERY = 50


def _read_escaped_tsv(filename):
    """Read a TSV stored without quoting and with backslash escapes."""
    return pd.read_csv(
        filename,
        sep="\t",
        header=0,
        quoting=csv.QUOTE_NONE,
        quotechar="",
        escapechar="\\",
    )


def _write_escaped_tsv(df, filename):
    """Write `df` as a TSV without quoting, with backslash escapes and no index column."""
    df.to_csv(filename, sep="\t", quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\", index=False)


def _is_row_estimated(df, index):
    """True when the row already has a complexity in [0, 1] and a rating in `ratings`."""
    complexity = df.at[index, "masj_num_complexity"]
    # numpy.float64 (what pandas returns for a float column) subclasses float.
    return (
        isinstance(complexity, float)
        and 0.0 <= complexity <= 1.0
        and df.at[index, "masj_num_rating"] in ratings
    )


def estimate_dataset(in_filename, out_filename, get_question_from_row, get_options_from_row):
    """Estimate question complexity and a judge rating for every row of a TSV dataset.

    The run is resumable:
    - If `out_filename` exists, it is loaded instead of `in_filename`.
    - Rows that already have a complexity in [0, 1] and a rating in `ratings`
      are skipped.

    For each remaining row, the model estimates complexity (column
    "masj_num_complexity") and then judges that estimate (column
    "masj_num_rating"). Progress is written to `out_filename` every DUMP_EVERY
    processed rows. It is written again when the loop ends, including when the
    loop is interrupted or raises, so a re-run resumes from there.

    Args:
        in_filename: source TSV, used only when `out_filename` does not exist yet.
        out_filename: TSV to resume from and to write results to.
        get_question_from_row: callable mapping a row (pandas Series) to the question text.
        get_options_from_row: callable mapping a row to the options passed to
            prompt.get_user_prompt.

    Returns:
        The DataFrame with the "masj_num_*" columns filled where parsing succeeded.
    """
    global invalid_complexities
    global invalid_ratings

    if os.path.isfile(out_filename):
        df = _read_escaped_tsv(out_filename)
    else:
        try:
            df = _read_escaped_tsv(in_filename)
        except ValueError:
            # pandas.errors.ParserError and UnicodeDecodeError both subclass
            # ValueError. Fall back to pandas' default quote handling for inputs
            # not written in the escaped format.
            df = pd.read_csv(
                in_filename,
                sep="\t",
                header=0,
            )

    invalid_complexities = 0
    invalid_ratings = 0
    init_api_limits()

    # Sentinels for "not computed yet": valid complexities are in [0, 1] and
    # valid ratings are in `ratings` (1-10).
    if "masj_num_complexity" not in df.columns:
        df["masj_num_complexity"] = -1.0
    if "masj_num_rating" not in df.columns:
        df["masj_num_rating"] = 0

    # Identical for every row, so build it once.
    complexity_system_prompt = 'You are an expert in the topic of the question. Please act as an impartial judge and evaluate the complexity of the multiple-choice question with options below. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must not answer the question. You must rate the question complexity as a number from 0 to 1 following the following scale as a reference: middle_school - 0.0-0.2, high_school - 0.2-0.4, undergraduate - 0.4-0.6, postgraduate - 0.6-0.8, phd - 0.8-1.0. You must return the complexity by strictly following this format: "[[complexity]]", for example: "Complexity: [[0.55]]", which corresponds to the undergraduate level.'

    meaningful_iteration = 0
    try:
        for index, row in tqdm(df.iterrows(), total=len(df)):
            if _is_row_estimated(df, index):
                continue  # finished in a previous run

            meaningful_iteration += 1

            complexity_user_prompt = prompt.get_user_prompt(get_question_from_row(row), get_options_from_row(row))

            response_complexity = estimate_complextiy_with_model(
                df, index, complexity_system_prompt, complexity_user_prompt
            )
            wait()

            model_as_judge(df, index, complexity_system_prompt, complexity_user_prompt, response_complexity)
            wait()

            if meaningful_iteration % DUMP_EVERY == 0:
                _write_escaped_tsv(df, out_filename)
                total_hits = sum(api_limit_hits_by_client_ids.values())
                print(f"Over {meaningful_iteration} iterations we hit {total_hits} API limits")
    finally:
        # Persist progress even on KeyboardInterrupt or an error. A partially
        # processed row is simply re-processed on the next run.
        _write_escaped_tsv(df, out_filename)

    print(
        f"Processed dataset {out_filename}. Total entries: {df.shape[0]}. Invalid complexities: {invalid_complexities}. Invalid ratings: {invalid_ratings}"
    )
    return df

In [None]:
# # MMLU
# import ast

# estimate_dataset(
#     in_filename="../data/mmlu_pro_stem_w_maj_w_entropyphi4.tsv",
#     out_filename="../data/mmlu_pro_stem_w_numerical_maj_w_entropyphi4.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=lambda row: ast.literal_eval(row["options"]),
# )

In [None]:
# ARC CH


def get_options_arc(row):
    """Parse the answer options of an ARC row into a list of strings.

    `row["text"]` is expected to look like "['opt a' 'opt b'\\n 'opt c']": a
    bracketed, space-separated, single-quoted list that may wrap over lines.
    This looks like a NumPy string-array repr (TODO confirm against the data
    export). `row["leng"]` is the expected number of options.

    NOTE(review): an option containing an apostrophe would presumably be
    rendered with different quoting and then fail the count check below.

    Raises:
        AssertionError: if the parsed option count differs from `row["leng"]`
            or an option is empty. The row id is printed before re-raising.
    """
    try:
        options_len = int(row["leng"])
        options_str = row["text"]
        # Join wrapped lines back into a single line.
        options_str_without_newline = options_str.replace("\n", "")
        # Drop the surrounding "[" and "]".
        options_str_without_brackets = options_str_without_newline[1:-1]
        # Options are separated by a closing quote, a space and an opening quote.
        options_split = options_str_without_brackets.split("' '")
        # Remove leading and trailing quotes from first and last options
        options_split[0] = options_split[0][1:]
        options_split[-1] = options_split[-1][:-1]
        assert len(options_split) == options_len, (
            f"expected {options_len} options, parsed {len(options_split)}: {options_split!r}"
        )
        for option_idx, option in enumerate(options_split):
            assert len(option) > 0, f"option {option_idx} is empty: {options_split!r}"
        return options_split
    except AssertionError as e:
        print(f"get_options_arc: {row['id']} -> AssertionError: {e}")
        raise


# Run the estimation for every ARC split, keeping each resulting DataFrame
# (previously only the last split's result was shown and the others were discarded).
arc_dfs = {}
for arc_split in ("train", "test", "validation"):
    arc_dfs[arc_split] = estimate_dataset(
        in_filename=f"../data/arc_ch_{arc_split}_w_maj_complexity.tsv",
        out_filename=f"../data/arc_ch_{arc_split}_w_numerical_maj_complexity.tsv",
        get_question_from_row=lambda row: row["question"],
        get_options_from_row=get_options_arc,
    )

  0%|          | 0/1117 [00:00<?, ?it/s]

  0%|          | 1/1117 [00:10<3:20:47, 10.80s/it]


KeyboardInterrupt: 

In [None]:
# # SCIQ
# from random import shuffle


# def get_options_sciq(row):
#     try:
#         options_len = 4
#         correct_option = row["correct"]
#         other_options = [row["incorrect1"], row["incorrect2"], row["incorrect3"]]
#         all_options = [correct_option] + other_options
#         shuffle(all_options)

#         assert len(all_options) == options_len
#         for option in all_options:
#             assert len(option) > 0
#         return all_options
#     except AssertionError as e:
#         print(f"get_options_sciq: {row['id']} -> AssertionError: {e}")
#         raise e


# estimate_dataset(
#     in_filename="../data/sciq_train.tsv",
#     out_filename="../data/sciq_train_w_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_sciq,
# )
# estimate_dataset(
#     in_filename="../data/sciq_test.tsv",
#     out_filename="../data/sciq_test_w_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_sciq,
# )
# estimate_dataset(
#     in_filename="../data/sciq_validation.tsv",
#     out_filename="../data/sciq_validation_w_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_sciq,
# )
