In [1]:
import os

from mistralai import Mistral

api_keys = os.environ["MISTRAL_API_KEYS"]
model = "mistral-large-2411"

clients = [
    Mistral(
        api_key=api_key,
    )
    for api_key in api_keys.split(",")
]

for client in clients:
    # Check API is alive
    chat_response = client.chat.complete(
        model=model,
        messages=[
            {
                "role": "user",
                "content": "What is the best French cheese?",
            },
        ],
        max_tokens=10,
    )
    print(chat_response.model, chat_response.choices[0].message.content)

mistral-large-2411 Choosing the "best" French cheese can be
mistral-large-2411 Choosing the "best" French cheese can be
mistral-large-2411 Determining the "best" French cheese can


In [9]:
from time import sleep

from mistralai import SDKError
from openai import RateLimitError

SLEEP_DURATION = 1.2
if len(clients) == 2:
    SLEEP_DURATION = 0.5
if len(clients) >= 3:
    SLEEP_DURATION = 0.2

print("Sleep duration:", SLEEP_DURATION)


def wait(duration=SLEEP_DURATION):
    sleep(duration)


api_limit_hits_by_client_ids = {}


def init_api_limits() -> None:
    global api_limit_hits_by_client_ids

    api_limit_hits_by_client_ids = {}
    for i in range(len(clients)):
        api_limit_hits_by_client_ids[i] = 0


request_id = 0


def repeat_if_hit_api_limit(f):  # (1)
    def wrapper(*args, **kw):  # (2)
        global api_limit_hits_by_client_ids

        while True:
            try:
                return f(*args, **kw)
            except RateLimitError:
                client_id = request_id % len(clients)
                api_limit_hits_by_client_ids[client_id] += 1

                total_hits = 0
                for value in api_limit_hits_by_client_ids.values():
                    total_hits += value

                if (total_hits % 10) == 0:
                    print(f"API limit hit {total_hits} times. Details: {api_limit_hits_by_client_ids}")
                wait(2)
            except SDKError as e:
                if e.status_code == 429:
                    client_id = request_id % len(clients)
                    api_limit_hits_by_client_ids[client_id] += 1

                    total_hits = 0
                    for value in api_limit_hits_by_client_ids.values():
                        total_hits += value

                    if (total_hits % 10) == 0:
                        print(f"API limit hit {total_hits} times. Details: {api_limit_hits_by_client_ids}")
                    wait(2)
                else:
                    raise e
            except Exception as e:
                print("repeat_if_hit_api_limit -> unknown error", e)
                wait(60)

    return wrapper


@repeat_if_hit_api_limit
def query_model(messages):
    global request_id
    # print(request_id % len(clients))
    client = clients[request_id % len(clients)]
    request_id += 1
    response = client.chat.complete(model=model, messages=messages)
    return response

Sleep duration: 0.2


In [10]:
import os
import sys

dir2 = os.path.abspath("")
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

import importlib

import utils.prompt as prompt

# Required to purge the module cache and use the latest version after an update
importlib.reload(prompt)

<module 'utils.prompt' from '/Users/aigoncharov/dev/sktech/phi-4/utils/prompt.py'>

In [16]:
import csv
import os.path
import re

import pandas as pd
from tqdm import tqdm

import utils.prompt as prompt

# difficulty = ["middle_school", "high_school", "undergraduate", "postgraduate", "phd"]
ratings = list(range(1, 11, 1))

invalid_complexities = 0
invalid_ratings = 0

FIELD_NUM_COMPLEXITY = "masj_num_complexity"
FIELD_NUM_RATING = "masj_num_rating"
FIELD_NOMINAL_COMPLEXITY = "masj_nominal_complexity"
FIELD_NOMINAL_RATING = "masj_nominal_rating"


def model_as_judge(df, index, result_field_name, system_prompt, user_prompt, answer):
    global invalid_ratings

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": 'Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user request displayed below. Your evaluation should consider factors such as the following all the settings in the system prompt, correspondences to the context of the user, the helpfulness, relevance and accuracy. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example:"Rating: [[6]]".',
            },
            {
                "role": "user",
                "content": f"""
                [Instructions for Assistant]
                {system_prompt}
                [End of Instructions for Assistant]

                [Question]
                {user_prompt}
                [End of Question]

                [The Start of Assistant’s Answer]
                {answer}
                [The End of Assistant’s Answer]

                You must rate the assistant's response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example:"Rating: [[6]]"
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content
    # print(response)

    try:
        rating = re.search("\\[\\[(\\d+?)\\]\\]", response).group(1)
        # print(rating)
        rating_int = int(rating)
        if rating_int in ratings:
            df.at[index, result_field_name] = rating_int
        else:
            invalid_ratings += 1
    except:
        print(f"Could not extract rating from response:\n{response}\n")
        invalid_ratings += 1


def estimate_complextiy_with_model(system_prompt, user_prompt):
    global invalid_complexities

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": f"""
                [Question Start]
                {user_prompt}
                [Question End]
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content
    # print(response)

    try:
        complexity_str = re.search("\\[\\[(.+?)\\]\\]", response).group(1)
        # print(complexity_str)
        return complexity_str
    except:
        print(f"Could not extract complexity from response:\n{response}\n")
        invalid_complexities += 1

    return response


def validate_numerical_complexity(df, index, complexity_str):
    global invalid_complexities
    try:
        complexity = float(complexity_str)
        if complexity >= 0.0 and complexity <= 1.0:
            df.at[index, FIELD_NUM_COMPLEXITY] = complexity
    except:
        print(f"Could not extract numerical complexity from {complexity_str}\n")
        invalid_complexities += 1


def validate_nominal_complexity(df, index, valid_complexities, complexity_str):
    global invalid_complexities

    if complexity_str in valid_complexities:
        df.at[index, FIELD_NOMINAL_COMPLEXITY] = complexity_str
    else:
        print(f"Could not extract nominal complexity from {complexity_str}\n")
        invalid_complexities += 1


DUMP_EVERY = 50


def estimate_dataset(in_filename, out_filename, get_question_from_row, get_options_from_row, original_separators=True):
    if os.path.isfile(out_filename):
        df = pd.read_csv(
            out_filename,
            sep="\t",
            header=0,
            quoting=csv.QUOTE_NONE,
            quotechar="",
            escapechar="\\",
        )
    else:
        if original_separators:
            df = pd.read_csv(
                in_filename,
                sep="\t",
                header=0,
            )
        else:
            df = pd.read_csv(
                in_filename,
                sep="\t",
                header=0,
                quoting=csv.QUOTE_NONE,
                quotechar="",
                escapechar="\\",
            )

    global invalid_complexities
    global invalid_ratings
    invalid_complexities = 0
    invalid_ratings = 0
    init_api_limits()

    if FIELD_NUM_COMPLEXITY not in df.columns:
        df[FIELD_NUM_COMPLEXITY] = -1.0
    if FIELD_NUM_RATING not in df.columns:
        df[FIELD_NUM_RATING] = 0
    if FIELD_NOMINAL_COMPLEXITY not in df.columns:
        df[FIELD_NOMINAL_COMPLEXITY] = ""
    if FIELD_NOMINAL_RATING not in df.columns:
        df[FIELD_NOMINAL_RATING] = 0

    meaningful_iteration = 0
    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        if (
            isinstance(df.at[index, FIELD_NUM_COMPLEXITY], float)
            and df.at[index, FIELD_NUM_COMPLEXITY] >= 0.0
            and df.at[index, FIELD_NUM_COMPLEXITY] <= 1.0
            and df.at[index, FIELD_NUM_RATING] in ratings
            and df.at[index, FIELD_NOMINAL_COMPLEXITY] in prompt.valid_nominal_complexities
            and df.at[index, FIELD_NOMINAL_RATING] in ratings
        ):
            continue

        meaningful_iteration += 1

        complexity_user_prompt = prompt.get_user_prompt(get_question_from_row(row), get_options_from_row(row))
        # print(complexity_user_prompt)

        response_num_complexity = estimate_complextiy_with_model(
            prompt.estimate_numerical_complexity_system_prompt, complexity_user_prompt
        )
        validate_numerical_complexity(df, index, response_num_complexity)
        wait()

        model_as_judge(
            df,
            index,
            FIELD_NUM_RATING,
            prompt.estimate_numerical_complexity_system_prompt,
            complexity_user_prompt,
            response_num_complexity,
        )
        wait()

        response_nominal_complexity = estimate_complextiy_with_model(
            prompt.estimate_nominal_complexity_system_prompt, complexity_user_prompt
        )
        validate_nominal_complexity(df, index, prompt.valid_nominal_complexities, response_nominal_complexity)
        wait()

        model_as_judge(
            df,
            index,
            FIELD_NOMINAL_RATING,
            prompt.estimate_nominal_complexity_system_prompt,
            complexity_user_prompt,
            response_nominal_complexity,
        )
        wait()

        if meaningful_iteration % DUMP_EVERY == 0:
            df.to_csv(out_filename, sep="\t", quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\", index=False)
            total_hits = 0
            for value in api_limit_hits_by_client_ids.values():
                total_hits += value
            print(f"Over {meaningful_iteration} iterations we hit {total_hits} API limits")

    df.to_csv(out_filename, sep="\t", quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\", index=False)
    print(
        f"Processed dataset {out_filename}. Total entries: {df.shape[0]}. Invalid complexities: {invalid_complexities}. Invalid ratings: {invalid_ratings}"
    )
    return df

In [17]:
# # MMLU
# import ast

# estimate_dataset(
#     in_filename="../data/mmlu_pro_stem_w_maj_w_entropyphi4.tsv",
#     out_filename="../data/mmlu_pro_stem_w_numerical_maj_w_entropyphi4.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=lambda row: ast.literal_eval(row["options"]),
# )

In [18]:
# # ARC CH


# def get_options_arc(row):
#     try:
#         options_len = int(row["leng"])
#         options_str = row["text"]
#         options_str_without_newline = options_str.replace("\n", "")
#         options_str_without_brackets = options_str_without_newline[1:-1]
#         options_split = options_str_without_brackets.split("' '")
#         # Remove leading and trailing quotes from first and last options
#         options_split[0] = options_split[0][1:]
#         options_split[-1] = options_split[-1][:-1]
#         # print(options_split, options_len)
#         assert len(options_split) == options_len
#         for option in options_split:
#             assert len(option) > 0
#         return options_split
#     except AssertionError as e:
#         print(f"get_options_arc: {row['id']} -> AssertionError: {e}")
#         raise e


# estimate_dataset(
#     in_filename="../data/arc_ch_train_w_maj_complexity.tsv",
#     out_filename="../data/arc_ch_train_w_numerical_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_arc,
# )
# estimate_dataset(
#     in_filename="../data/arc_ch_test_w_maj_complexity.tsv",
#     out_filename="../data/arc_ch_test_w_numerical_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_arc,
# )
# estimate_dataset(
#     in_filename="../data/arc_ch_validation_w_maj_complexity.tsv",
#     out_filename="../data/arc_ch_validation_w_numerical_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_arc,
# )

In [19]:
# # SCIQ
# from random import shuffle


# def get_options_sciq(row):
#     try:
#         options_len = 4
#         correct_option = row["correct"]
#         other_options = [row["incorrect1"], row["incorrect2"], row["incorrect3"]]
#         all_options = [correct_option] + other_options
#         shuffle(all_options)

#         assert len(all_options) == options_len
#         for option in all_options:
#             assert len(option) > 0
#         return all_options
#     except AssertionError as e:
#         print(f"get_options_sciq: {row['id']} -> AssertionError: {e}")
#         raise e


# estimate_dataset(
#     in_filename="../data/sciq_train.tsv",
#     out_filename="../data/sciq_train_w_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_sciq,
# )
# estimate_dataset(
#     in_filename="../data/sciq_test.tsv",
#     out_filename="../data/sciq_test_w_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_sciq,
# )
# estimate_dataset(
#     in_filename="../data/sciq_validation.tsv",
#     out_filename="../data/sciq_validation_w_maj_complexity.tsv",
#     get_question_from_row=lambda row: row["question"],
#     get_options_from_row=get_options_sciq,
# )


In [21]:
# GPQA
from random import shuffle


def get_options_gpqa(row):
    try:
        options_len = 4
        correct_option = row["correct"]
        other_options = [row["incorrect1"], row["incorrect2"], row["incorrect3"]]
        all_options = [correct_option] + other_options
        shuffle(all_options)

        # print(all_options)

        assert len(all_options) == options_len
        for option in all_options:
            assert len(option) > 0
        return all_options
    except AssertionError as e:
        print(f"get_options_sciq: {row['id']} -> AssertionError: {e}")
        raise e


estimate_dataset(
    in_filename="../data/gpqa_train.tsv",
    out_filename="../data/gpqa_w_both_maj_complexities_v2.tsv",
    get_question_from_row=lambda row: row["question"],
    get_options_from_row=get_options_gpqa,
)

 67%|██████▋   | 300/448 [12:13<32:25, 13.14s/it]

Over 50 iterations we hit 0 API limits


 78%|███████▊  | 350/448 [23:18<23:00, 14.08s/it]

Over 100 iterations we hit 0 API limits


 89%|████████▉ | 400/448 [35:13<10:32, 13.17s/it]

Over 150 iterations we hit 0 API limits


 92%|█████████▏| 410/448 [37:14<07:18, 11.54s/it]

Could not extract nominal complexity from Undergraduate



100%|██████████| 448/448 [46:50<00:00,  6.27s/it]

Processed dataset ../data/gpqa_w_both_maj_complexities_v2.tsv. Total entries: 448. Invalid complexities: 1. Invalid ratings: 0





Unnamed: 0,question,correct,incorrect1,incorrect2,incorrect3,explanation,topic,subject,difficulty,masj_num_complexity,masj_num_rating,masj_nominal_complexity,masj_nominal_rating
0,"A large gene has dozens of exons, of which the...",R-loops,lariat,poly(A) tail,antisense,The text describes the dystrophin gene and the...,Molecular Biology,Biology,Hard undergraduate level (could be a question ...,0.60,5,graduate,10
1,Two quantum states with energies E1 and E2 hav...,10^-4 ev,10^-8 ev,10^-9 ev,10^-11 ev,"According to uncertainty principle, Delta E* D...",Physics (general),Physics,Hard undergraduate level (could be a question ...,0.45,3,graduate,10
2,trans-cinnamaldehyde was treated with methylma...,11,10,12,14,\ntrans-cinnamaldehyde was treated with methyl...,Organic Chemistry,Chemistry,Hard graduate level (could be a question on a ...,0.45,10,undergraduate,7
3,how many of the following compounds exhibit op...,4,3,5,6,the compounds\n1-methyl-4-(prop-1-en-2-yl)cycl...,Organic Chemistry,Chemistry,Hard undergraduate level (could be a question ...,0.48,1,graduate,9
4,A coating is applied to a substrate resulting ...,124°,129°,134°,139°,"In the Cassie-Baxter state, droplets are in co...",Chemistry (general),Chemistry,Hard undergraduate level (could be a question ...,0.60,1,undergraduate,10
...,...,...,...,...,...,...,...,...,...,...,...,...,...
443,Planet hunters decided to conduct follow-up ob...,~ 71.1 days,~ 56.2 days,~ 28.1 days,~ 18.7 days\n,"If the inclination is 90 degrees, then sin(i) ...",Astrophysics,Physics,Hard graduate level (could be a question on a ...,0.65,10,graduate,10
444,You are synthesizing a phosphodiester from two...,Two peaks changing to a single peak,A single peak changing to another single peak,A single peak changing to two peaks,Two peaks changing to three peaks,The final reaction converts a dialkyl chloroph...,Organic Chemistry,Chemistry,Post-graduate level or harder (only individual...,0.68,5,postgraduate,10
445,A coating is applied to a substrate resulting ...,"Water = 128°, Oil = 48°, low tilt angle","Water = 129°, Oil = 98°, low tilt angle","Water = 128°, Oil = 48°, high tilt angle","Water = 129°, Oil = 98°, high tilt angle",If the contact angle of a droplet on a smooth ...,Chemistry (general),Chemistry,Hard undergraduate level (could be a question ...,0.65,4,graduate,10
446,Two electromagnetic (em) fields are applied on...,11.54 e a0 E0,3 e a0 E0,5.1 e a0 E1,1.59 e a0 (E1 + E2),The dc electric field will dress the excited s...,Quantum Mechanics,Physics,Post-graduate level or harder (only individual...,0.65,7,graduate,10
