In [7]:
import os
from mistralai import Mistral

# The key is read from the environment so no secret is stored in the notebook;
# a missing MISTRAL_API_KEY raises KeyError right here.
api_key = os.environ["MISTRAL_API_KEY"]
model = "mistral-large-2411"

client = Mistral(api_key=api_key)

# Smoke test: request one short completion (capped at 10 tokens) and print the
# model name reported by the API next to the start of its answer.
chat_response = client.chat.complete(
    model=model,
    max_tokens=10,
    messages=[{"role": "user", "content": "What is the best French cheese?"}],
)
print(chat_response.model, chat_response.choices[0].message.content)

mistral-large-2411 Choosing the "best" French cheese can be


In [17]:
import os, sys

# Put the parent of the current working directory on the import path
# (presumably so that the `utils` package imported further down resolves).
dir2 = os.path.abspath("")
dir1 = os.path.split(dir2)[0]
if dir1 not in sys.path:
    # Append only once so re-running the cell does not add duplicates.
    sys.path.append(dir1)

import pandas as pd
import ast
import csv
from tqdm import tqdm
from time import sleep
from mistralai import SDKError
import re

import utils.prompt as prompt

import importlib

# Required to purge the module cache and use the latest version after an update
importlib.reload(prompt)

# Allowed complexity labels. The list is interpolated into the system prompt
# and used to validate the label extracted from each model reply.
difficulty = ["middle_school", "high_school", "undergraduate", "postgraduate", "phd"]

# Captures the text between the first "[[" and the next "]]" in a reply
# (non-greedy), i.e. the label in the "[[complexity]]" answer format.
extract_regex = re.compile("\\[\\[(.*?)\\]\\]")

# Module-level counter for replies whose extracted label is not in `difficulty`.
invalid_complexities = 0


def estimate_dataset(df, client, get_question_from_row, get_options_from_row):
    """Ask the chat model to rate the complexity of every question in ``df``.

    For each row a system prompt (listing the allowed labels from the
    module-level ``difficulty`` list) and a user prompt (built by
    ``prompt.get_user_prompt`` from the row's question and options) are sent
    to ``client.chat.complete`` using the module-level ``model``. The label
    found between ``[[`` and ``]]`` in the reply is stored in the
    ``masj_complexity`` column when it is one of the allowed labels.

    Replies without a ``[[...]]`` marker, with non-text content, or with a
    label outside ``difficulty`` leave the cell as ``""`` and increment the
    module-level ``invalid_complexities`` counter.

    Args:
        df: DataFrame to annotate. It is modified in place: the columns
            ``masj_complexity`` (reset to ``""``) and ``masj_mt_rating``
            (reset to ``0``; not written anywhere else in this function)
            are (re)created.
        client: Object exposing ``chat.complete(model=..., messages=...)``.
        get_question_from_row: Callable mapping a row to the question text.
        get_options_from_row: Callable mapping a row to the answer options.

    Returns:
        The same ``df`` object, with the columns described above.

    Raises:
        SDKError: Re-raised for any API error whose status code is not 429.
            A 429 response is retried indefinitely after a one-second pause.
    """
    # Without this declaration the augmented assignment below would make the
    # name local and raise UnboundLocalError on the first invalid reply.
    global invalid_complexities

    df["masj_complexity"] = ""
    df["masj_mt_rating"] = 0

    # Loop-invariant, so built once. Joining outside the f-string also avoids
    # nesting double quotes in it, which only parses on Python 3.12+.
    scale = ", ".join(difficulty)
    system_prompt = f"You are an expert in science. Please act as an impartial judge and evaluate the complexity of the multiple-choice question below. Be as objective as possible. You must rate the question complexity by strictly following the scale: {scale}. You must be concise and return only the complexity by strictly following this format: [[complexity]], for example: [[middle_school]]."

    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        user_prompt = prompt.get_user_prompt(get_question_from_row(row), get_options_from_row(row))

        while True:
            try:
                chat_response = client.chat.complete(
                    model=model,
                    messages=[
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_prompt},
                    ],
                )
            except SDKError as e:
                if e.status_code != 429:
                    raise
                # 429: wait briefly and retry the same row.
                sleep(1)
                continue

            response = chat_response.choices[0].message.content
            # A reply without "[[...]]" (or with non-text content) must not
            # abort the whole run; it is counted as invalid instead.
            match = extract_regex.search(response) if isinstance(response, str) else None
            complexity = match.group(1) if match else None

            if complexity in difficulty:
                df.at[index, "masj_complexity"] = complexity
            else:
                invalid_complexities += 1

            # Fixed pause between successful requests (presumably to stay
            # under the API rate limit).
            sleep(1.2)
            break
    return df


# Path of the dataset without its extension; ".tsv" is appended on load.
DATASET = "../data/mmlu_pro_stem"

# Load the tab-separated file (first line is the header) and keep only the
# first 10 rows for this run.
df = pd.read_csv(f"{DATASET}.tsv", sep="\t", header=0).head(10)

# The "options" column is stored as text, so it is parsed back into a Python
# literal before being handed to the prompt builder.
processed_df = estimate_dataset(
    df=df,
    client=client,
    get_options_from_row=lambda row: ast.literal_eval(row["options"]),
    get_question_from_row=lambda row: row["question"],
)
# Saving is currently disabled:
# processed_df.to_csv(f"{DATASET}_w_maj_complexity.tsv", sep="\t", quoting=csv.QUOTE_NONE)

  0%|          | 0/10 [00:00<?, ?it/s]

undergraduate


 10%|█         | 1/10 [00:01<00:13,  1.55s/it]

high_school


 20%|██        | 2/10 [00:03<00:12,  1.52s/it]

middle_school


 30%|███       | 3/10 [00:04<00:11,  1.57s/it]

high_school


 40%|████      | 4/10 [00:06<00:09,  1.57s/it]

high_school


 50%|█████     | 5/10 [00:07<00:08,  1.62s/it]

high_school


 60%|██████    | 6/10 [00:09<00:06,  1.59s/it]

middle_school


 70%|███████   | 7/10 [00:10<00:04,  1.55s/it]

middle_school


 80%|████████  | 8/10 [00:13<00:03,  1.89s/it]

undergraduate


 90%|█████████ | 9/10 [00:15<00:01,  1.76s/it]

high_school


100%|██████████| 10/10 [00:16<00:00,  1.68s/it]
