In [1]:
import os

from mistralai import Mistral

# Comma-separated list of Mistral API keys; one client is created per key.
api_keys = os.environ["MISTRAL_API_KEYS"]
model = "mistral-large-2411"

# Tolerate whitespace around the commas and stray/trailing commas, which would
# otherwise create clients holding a padded or empty (invalid) key.
clients = [
    Mistral(
        api_key=api_key.strip(),
    )
    for api_key in api_keys.split(",")
    if api_key.strip()
]
if not clients:
    # Fail early with a clear message instead of failing later on the first request.
    raise ValueError("MISTRAL_API_KEYS does not contain any API key")

for client in clients:
    # Check API is alive: one tiny completion per key, printing the model name
    # and the (truncated) answer.
    chat_response = client.chat.complete(
        model=model,
        messages=[
            {
                "role": "user",
                "content": "What is the best French cheese?",
            },
        ],
        max_tokens=10,
    )
    print(chat_response.model, chat_response.choices[0].message.content)

mistral-large-2411 Choosing the "best" French cheese can be
mistral-large-2411 Determining the "best" French cheese can
mistral-large-2411 Determining the "best" French cheese can


In [2]:
from time import sleep

from mistralai import SDKError
from openai import RateLimitError

# Pause (in seconds) between consecutive requests, picked from the number of
# configured clients: none with three or more, 0.5 with exactly two, 1.2 otherwise.
if len(clients) >= 3:
    SLEEP_DURATION = 0
elif len(clients) == 2:
    SLEEP_DURATION = 0.5
else:
    SLEEP_DURATION = 1.2

print("Sleep duration:", SLEEP_DURATION)


def wait(duration=SLEEP_DURATION):
    """Sleep for ``duration`` seconds.

    The default is the value ``SLEEP_DURATION`` had when this function was
    defined; later changes to ``SLEEP_DURATION`` do not affect it.
    """
    sleep(duration)


# Running total of rate-limit errors seen by all decorated functions.
api_limit_hits = 0


def repeat_if_hit_api_limit(f):
    """Decorator that retries ``f`` until it returns.

    Retry policy:
      * ``RateLimitError`` or an ``SDKError`` with status 429: count the hit
        (reporting every 10th one), wait 1 second and retry.
      * ``SDKError`` with any other status: re-raised to the caller.
      * any other ``Exception``: printed, then retried after 60 seconds, with
        no upper bound on the number of attempts.

    The wrapper keeps the name and docstring of ``f``.
    """
    # Local import so the decorator does not depend on another cell's imports.
    from functools import wraps

    def note_rate_limit():
        # Shared bookkeeping for both rate-limit exception flavours.
        global api_limit_hits
        api_limit_hits += 1
        if (api_limit_hits % 10) == 0:
            print(f"API limit hit {api_limit_hits} times")
        wait(1)

    @wraps(f)
    def wrapper(*args, **kw):
        while True:
            try:
                return f(*args, **kw)
            except RateLimitError:
                note_rate_limit()
            except SDKError as e:
                if e.status_code != 429:
                    # Not a rate limit: propagate with the original traceback.
                    raise
                note_rate_limit()
            except Exception as e:
                print("repeat_if_hit_api_limit -> unknown error", e)
                wait(60)

    return wrapper

Sleep duration: 0


In [3]:
# Number of requests issued so far; also selects the client for the next request.
request_id = 0


@repeat_if_hit_api_limit
def query_model(messages):
    """Send ``messages`` to the chat model and return the completion response.

    Clients are used in round-robin order: every call (including retries made
    by the decorator) advances ``request_id`` and therefore moves to the next client.
    """
    global request_id
    slot = request_id % len(clients)
    request_id += 1
    return clients[slot].chat.complete(model=model, messages=messages)

In [4]:
import os
import sys

# Add the parent of the current working directory to sys.path so that the
# `utils` package can be imported below. This must run before the import.
dir2 = os.path.abspath("")
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

import importlib

import utils.prompt as prompt

# Required to purge the module cache and use the latest version after an update
importlib.reload(prompt)

<module 'utils.prompt' from '/Users/aigoncharov/dev/sktech/phi-4/utils/prompt.py'>

In [28]:
import ast
import csv
import re

from tqdm import tqdm

import utils.prompt as prompt

# Complexity labels accepted from the model.
difficulty = [
    "middle_school",
    "high_school",
    "undergraduate",
    "postgraduate",
    "phd",
]
# Ratings accepted from the judge: the integers 1 through 10.
ratings = [*range(1, 11)]

# Counters of model answers that could not be used.
invalid_complexities = invalid_ratings = 0


def model_as_judge(df, index, system_prompt, user_prompt, answer):
    """Ask the model to grade ``answer`` and store the rating in ``df``.

    The judge is shown the instructions (``system_prompt``), the question
    (``user_prompt``) and the assistant's ``answer``, and must reply with a
    rating in the form ``[[N]]``.

    Side effects:
      * a rating that is one of ``ratings`` is written to
        ``df.at[index, "masj_rating"]``;
      * a missing or out-of-range rating increments the global
        ``invalid_ratings`` counter and leaves ``df`` untouched.

    Returns ``None``.
    """
    global invalid_ratings

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": 'Please act as an impartial judge and evaluate the quality of the response provided by an AI assistant to the user request displayed below. Your evaluation should consider factors such as the following all the settings in the system prompt, correspondences to the context of the user, the helpfulness, relevance and accuracy. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must rate the response on a scale of 1 to 10 by strictly following this format: "[[rating]]", for example:"Rating: [[6]]".',
            },
            {
                "role": "user",
                "content": f"""
                [Instructions for Assistant]
                {system_prompt}
                [End of Instructions for Assistant]

                [Question]
                {user_prompt}
                [End of Question]

                [The Start of Assistant’s Answer]
                {answer}
                [The End of Assistant’s Answer]
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content

    # Treat only "no [[N]] marker" (or a non-text reply) as an unparseable
    # rating. Unlike a bare ``except``, this does not swallow KeyboardInterrupt
    # or errors raised while writing to the DataFrame.
    match = re.search("\\[\\[(\\d+?)\\]\\]", response) if isinstance(response, str) else None
    if match is None:
        print(f"Could not extract rating from response:\n{response}\n")
        invalid_ratings += 1
        return

    rating_int = int(match.group(1))
    if rating_int in ratings:
        df.at[index, "masj_rating"] = rating_int
    else:
        invalid_ratings += 1


@repeat_if_hit_api_limit
def estimate_complextiy_with_model(df, index, system_prompt, user_prompt):
    """Ask the model for the complexity of a question and store it in ``df``.

    The model receives ``system_prompt`` and the question (``user_prompt``)
    and must reply with a label in the form ``[[label]]``.

    Side effects:
      * a label that is one of ``difficulty`` is written to
        ``df.at[index, "masj_complexity"]``;
      * a missing or unknown label increments the global
        ``invalid_complexities`` counter and leaves ``df`` untouched.

    Returns the raw text of the model's reply.
    """
    global invalid_complexities

    chat_response = query_model(
        [
            {
                "role": "system",
                "content": system_prompt,
            },
            {
                "role": "user",
                "content": f"""
                [Question Start]
                {user_prompt}
                [Question End]
                """,
            },
        ]
    )
    response = chat_response.choices[0].message.content

    # Treat only "no [[label]] marker" (or a non-text reply) as unparseable.
    # Unlike a bare ``except``, this does not swallow KeyboardInterrupt or
    # errors raised while writing to the DataFrame.
    match = re.search("\\[\\[(.+?)\\]\\]", response) if isinstance(response, str) else None
    if match is None:
        print(f"Could not extract complexity from response:\n{response}\n")
        invalid_complexities += 1
        return response

    complexity = match.group(1)
    if complexity in difficulty:
        df.at[index, "masj_complexity"] = complexity
    else:
        invalid_complexities += 1

    return response


# Number of processed-loop positions between intermediate dumps of the results.
DUMP_EVERY = 100


def estimate_dataset(df, get_question_from_row, get_options_from_row, out_filename):
    """Estimate complexity and judge rating for every row of ``df``.

    Args:
        df: DataFrame with one question per row. It is modified in place:
            the columns ``masj_complexity`` and ``masj_rating`` are added when
            missing and filled row by row.
        get_question_from_row: callable returning the question text of a row.
        get_options_from_row: callable returning the answer options of a row.
        out_filename: path of the TSV file the results are written to, both
            periodically (every ``DUMP_EVERY`` rows) and at the end.

    Rows that already hold a valid complexity and a valid rating are skipped,
    so an interrupted run can be resumed from a previously written file.

    Returns:
        The same ``df`` object.
    """
    global invalid_complexities
    global invalid_ratings
    invalid_complexities = 0
    invalid_ratings = 0

    if "masj_complexity" not in df.columns:
        df["masj_complexity"] = ""
    if "masj_rating" not in df.columns:
        df["masj_rating"] = 0

    def dump():
        # Single definition of the on-disk format used for checkpoints and the final result.
        df.to_csv(out_filename, sep="\t", quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\", index=False)

    # The system prompt does not depend on the row, so build it once.
    complexity_system_prompt = f'You are an expert in the topic of the question. Please act as an impartial judge and evaluate the complexity of the multiple-choice question with options below. Begin your evaluation by providing a short explanation. Be as objective as possible. After providing your explanation, you must not answer the question. You must rate the question complexity by strictly following the scale: {", ".join(difficulty)}. You must return the complexity by strictly following this format: "[[complexity]]", for example: "Complexity: [[middle_school]]".'

    # ``position`` counts rows in iteration order, so checkpointing works for
    # any index (non-integer, shuffled, filtered), not only a default RangeIndex.
    for position, (index, row) in enumerate(tqdm(df.iterrows(), total=df.shape[0])):
        if df.at[index, "masj_complexity"] in difficulty and df.at[index, "masj_rating"] in ratings:
            continue

        complexity_user_prompt = prompt.get_user_prompt(get_question_from_row(row), get_options_from_row(row))

        response_complexity = estimate_complextiy_with_model(
            df, index, complexity_system_prompt, complexity_user_prompt
        )
        wait()

        # The judge grades the complexity explanation produced just above.
        model_as_judge(df, index, complexity_system_prompt, complexity_user_prompt, response_complexity)
        wait()

        if position % DUMP_EVERY == 0:
            dump()
            print(f"Over {position} iterations we hit {api_limit_hits} API limits")

    dump()
    print(
        f"Processed dataset {out_filename}. Total entries: {df.shape[0]}. Invalid complexities: {invalid_complexities}. Invalid ratings: {invalid_ratings}"
    )
    return df

In [34]:
# MMLU
import os.path

import pandas as pd

ORIGINAL_DATASET = "../data/mmlu_pro_stem"
original_filename = f"{ORIGINAL_DATASET}.tsv"
out_filename = f"{ORIGINAL_DATASET}_w_maj_complexity.tsv"

# Resume from an earlier run's output when it exists; otherwise start from the
# original dataset. The output file is read with the same quoting/escaping it
# was written with.
if not os.path.isfile(out_filename):
    df = pd.read_csv(original_filename, sep="\t", header=0)
else:
    df = pd.read_csv(
        out_filename,
        sep="\t",
        header=0,
        quoting=csv.QUOTE_NONE,
        quotechar="",
        escapechar="\\",
    )
# For a quick trial run on a few rows: df = df.head(10)


# Last expression of the cell, so the returned frame is displayed.
estimate_dataset(
    df=df,
    get_question_from_row=lambda record: record["question"],
    get_options_from_row=lambda record: ast.literal_eval(record["options"]),
    out_filename=out_filename,
)

100%|██████████| 12032/12032 [02:43<00:00, 73.54it/s]  

Processed dataset ../data/mmlu_pro_stem_w_maj_complexity.tsv. Total entries: 12032. Invalid complexities: 0. Invalid ratings: 0





Unnamed: 0,src,answer,options,category,question,cot_content,question_id,answer_index,total_tokens,meta_cluster,base_cluster,masj_complexity,masj_rating
0,ori_mmlu-jurisprudence,C,['There is no distinction between the two form...,law,Which of the following criticisms of Llewellyn...,,1286,2,81,Legal Interpretation,Legal Theory Interpretations,postgraduate,9
1,ori_mmlu-international_law,E,"['Article 19', 'Article 11', 'Article 12', 'Ar...",law,Which of the following articles are not qualif...,,1293,4,38,Legal Interpretation,Constitutional Law,undergraduate,9
2,ori_mmlu-management,D,"['Work delegation', 'Workload balancing', 'Wor...",business,As what is ensuring that one individual does n...,,83,3,49,Economics & Finance MCQs,Business & Marketing Queries,high_school,9
3,stemez-Business,J,"['$308.25', '$142.75', '$199.99', '$225.85', '...",business,Margaret Denault recently rented a truck to dr...,,94,9,118,Economics & Finance MCQs,Business Finance Questions,high_school,9
4,stemez-Business,I,"['$60,000', '$43,200', '$1,794', '$25,000', '$...",business,The tax rate in the town of Centerville is 11(...,,104,8,102,Economics & Finance MCQs,Business Finance Questions,middle_school,9
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12027,ori_mmlu-high_school_macroeconomics,F,['Higher interest rates that result from borro...,economics,"The ""crowding-out"" effect refers to which of t...",,7681,5,150,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,9
12028,ori_mmlu-high_school_macroeconomics,A,['Lower reserve requirements; lower the discou...,economics,Which of the following lists contains only Fed...,,7683,0,124,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,10
12029,ori_mmlu-high_school_macroeconomics,I,['The productivity of labor in country X is 75...,economics,Output in country X is 30000 units and there a...,,7684,8,206,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,9
12030,ori_mmlu-high_school_macroeconomics,B,"['an increase in net exports', 'a decrease in ...",economics,A use of easy money (expansionary) policy by t...,,7685,1,58,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,9


In [35]:
# ARC CH
import os.path

import pandas as pd

ORIGINAL_DATASET = "../data/arc_ch_validation"
original_filename = f"{ORIGINAL_DATASET}.tsv"
out_filename = f"{ORIGINAL_DATASET}_w_maj_complexity.tsv"

# Resume from an earlier run's output when it exists; otherwise start from the
# original dataset. The output file is read with the same quoting/escaping it
# was written with.
if not os.path.isfile(out_filename):
    df = pd.read_csv(original_filename, sep="\t", header=0)
else:
    df = pd.read_csv(
        out_filename,
        sep="\t",
        header=0,
        quoting=csv.QUOTE_NONE,
        quotechar="",
        escapechar="\\",
    )
# For a quick trial run on a few rows: df = df.head(10)


def get_options_arc(row):
    """Parse the answer options of an ARC row into a list of strings.

    ``row["text"]`` holds the options as a bracketed, whitespace-separated
    sequence of quoted strings (e.g. ``['fog' 'rain']``), possibly wrapped
    over several lines. ``row["leng"]`` holds the expected number of options.

    Each quoted item is decoded as a Python string literal, so options written
    with double quotes (those containing an apostrophe) and escape sequences
    are handled, instead of relying on a literal ``' '`` separator.

    Raises:
        AssertionError: if the number of parsed options differs from
            ``row["leng"]`` or an option is empty. The row id is printed first.
    """
    try:
        options_len = int(row["leng"])
        options_str_without_newline = row["text"].replace("\n", "")
        # One alternative per quoting style; backslash escapes are kept inside the token.
        token_pattern = r"'(?:[^'\\]|\\.)*'" r'|"(?:[^"\\]|\\.)*"'
        options_split = [
            ast.literal_eval(token) for token in re.findall(token_pattern, options_str_without_newline)
        ]
        assert len(options_split) == options_len
        for option in options_split:
            assert len(option) > 0
        return options_split
    except AssertionError as e:
        print(f"get_options_arc: {row['id']} -> AssertionError: {e}")
        raise e


# Run the estimation on the ARC frame loaded above. As the last expression of
# the cell, the frame returned by estimate_dataset is displayed.
estimate_dataset(
    df=df,
    get_question_from_row=lambda row: row["question"],
    get_options_from_row=get_options_arc,
    out_filename=out_filename,
)

100%|██████████| 295/295 [00:23<00:00, 12.64it/s]

Processed dataset ../data/arc_ch_validation_w_maj_complexity.tsv. Total entries: 295. Invalid complexities: 0. Invalid ratings: 0





Unnamed: 0,id,question,answerKey,text,leng,answer,masj_complexity,masj_rating
0,Mercury_SC_407695,Juan and LaKeisha roll a few objects down a ra...,D,['Put the objects in groups.' 'Change the heig...,4,3,middle_school,9
1,Mercury_7103565,High-pressure systems stop air from rising int...,C,['fog' 'rain' 'drought' 'tornado'],4,2,high_school,9
2,MDSA_2009_5_16,Students visited the Morris W. Offit telescope...,D,['The sun revolves around Earth.' 'Earth rotat...,4,3,middle_school,10
3,Mercury_7027230,Which topic area would be the best to research...,A,['converting sunlight into electricity' 'looki...,4,0,middle_school,10
4,Mercury_SC_405487,"One year, the oak trees in a park began produc...",B,['Shady areas increased.' 'Food sources increa...,4,1,middle_school,9
...,...,...,...,...,...,...,...,...
290,Mercury_7090598,Which of these processes involves the transfer...,C,['erosion' 'sedimentation' 'subduction' 'cemen...,4,2,high_school,9
291,OHAT_2007_5_24,"In a forest, how do decomposers help other org...",B,['They release oxygen into the air that animal...,4,1,middle_school,10
292,Mercury_SC_402239,What is the best way to conserve natural resou...,C,['Throw all glass in the trash.' 'Use paper to...,4,2,middle_school,10
293,Mercury_7245088,Which describes the composition of carbohydrates?,D,['lipids bonding to form phospholipids'\n 'mon...,4,3,high_school,9
