In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TorchAoConfig

MODEL_NAME = "microsoft/phi-4"

# Device precedence is MPS, then CUDA, then CPU as the fallback.
if torch.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Disabled alternative: load the checkpoint with int4 weight-only quantization
# directly onto DEVICE instead of loading it unquantized.
# quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     low_cpu_mem_usage=True,
#     torch_dtype=torch.bfloat16,
#     quantization_config=quantization_config,
#     device_map=DEVICE,
# )

# Load the checkpoint first, then move the whole model to the selected device.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)

Using device: mps
import error: No module named 'triton'


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [11]:
# Footprint reported by the model, expressed in decimal gigabytes (bytes / 10**9).
print(model.get_memory_footprint() / 1_000_000_000)
# Device that holds the first parameter tensor of the model.
print(next(iter(model.parameters())).device)

29.319014656
mps:0
phi3


In [14]:
import os, sys

# dir2 is the current working directory and dir1 is its parent.
# The parent is appended to sys.path (once) so that the local `utils` and
# `complexity_estimation` packages imported below can be resolved from there.
dir2 = os.path.abspath("")
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

import pandas as pd
import ast
import csv
from tqdm import tqdm

import utils.prompt as prompt
import complexity_estimation.tokenwise_entropy as tokenwise_entropy

import importlib

# The results file is rewritten whenever the row index is a multiple of this value.
DUMP_EVERY = 100
# Running count of model answers that were not valid option ids;
# incremented by estimate_dataset.
invalid_answers = 0

# Reload the local helper modules so that edits made after the kernel first
# imported them are picked up without a kernel restart.
importlib.reload(prompt)
importlib.reload(tokenwise_entropy)


def _is_pending(stored_answer):
    """Tell whether a stored answer means the row still has to be evaluated.

    A row is pending when its answer cell is missing (NaN/None, which is what
    ``pd.read_csv`` yields for an empty cell), an empty string, or ``"INVALID"``.
    """
    if isinstance(stored_answer, str):
        return stored_answer in ("", "INVALID")
    return bool(pd.isna(stored_answer))


def _dump_tsv(df, out_filename):
    """Write ``df`` to ``out_filename`` as an unquoted, backslash-escaped TSV."""
    df.to_csv(out_filename, sep="\t", quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\", index=False)


def estimate_dataset(
    df, model, tokenizer, get_subject_from_row, get_question_from_row, get_options_from_row, verify_answer, out_filename
):
    """Ask the model every pending question in ``df`` and record answer, correctness and entropy.

    For each row a system/user prompt pair is built with ``prompt``, a single
    new token is generated, and the decoded token is taken as the answer. When
    the answer is one of ``prompt.option_ids`` the entropy returned by
    ``TokenwiseEntropy.calculate`` is stored together with the answer and the
    result of ``verify_answer``; otherwise the row is marked ``"INVALID"`` and
    the module-level ``invalid_answers`` counter is incremented.

    Rows whose answer column already holds a valid answer are skipped, so the
    function can resume from a previously dumped file. ``df`` is modified in
    place and written to ``out_filename`` every time the row index is a
    multiple of ``DUMP_EVERY`` and once more at the end.

    Args:
        df: DataFrame with one question per row; result columns are added if missing.
        model: Causal LM exposing ``config_class``, ``generate`` and ``get_memory_footprint``.
        tokenizer: Tokenizer matching ``model`` (chat template, encode, decode).
        get_subject_from_row: Callable ``row -> subject`` passed to ``prompt.get_sys_prompt``.
        get_question_from_row: Callable ``row -> question`` passed to ``prompt.get_user_prompt``.
        get_options_from_row: Callable ``row -> options`` passed to ``prompt.get_user_prompt``.
        verify_answer: Callable ``(row, answer) -> bool`` telling whether the answer is correct.
        out_filename: Path of the TSV file the results are written to.

    Returns:
        The same ``df`` object, with the result columns filled in.
    """
    global invalid_answers

    model_name = model.config_class().model_type
    print(model_name)

    field_ans_correct = f"entropy_ans_correct_{model_name}"
    field_ans_output = f"entropy_ans_output_{model_name}"
    field_entropy_value = f"field_entropy_value_{model_name}"

    if field_ans_correct not in df.columns:
        df[field_ans_correct] = False
    if field_ans_output not in df.columns:
        df[field_ans_output] = ""
    else:
        # A reloaded TSV may have parsed this column as numeric; keep it as
        # object so string answers and "INVALID" can be stored in it.
        df[field_ans_output] = df[field_ans_output].astype(object)
    if field_entropy_value not in df.columns:
        df[field_entropy_value] = 0.0

    entropy_estimator = tokenwise_entropy.TokenwiseEntropy(llm_model=model, device=DEVICE)

    for index, row in tqdm(df.iterrows(), total=df.shape[0]):
        # Skip rows that already carry a valid answer (resume support).
        if not _is_pending(df.at[index, field_ans_output]):
            continue

        print(f"loop {index} -> start: {model.get_memory_footprint(return_buffers=True) / 10**9} GB")

        sys_prompt = prompt.get_sys_prompt(get_subject_from_row(row))
        user_prompt = prompt.get_user_prompt(get_question_from_row(row), get_options_from_row(row))
        messages = [
            {"role": "system", "content": sys_prompt},
            {"role": "user", "content": user_prompt},
        ]
        formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)

        outputs = model.generate(**inputs, max_new_tokens=1, pad_token_id=tokenizer.eos_token_id)
        print(f"loop {index} -> after generate: {model.get_memory_footprint(return_buffers=True) / 10**9} GB")

        # Everything after the prompt tokens is the generated answer.
        input_length = inputs.input_ids.shape[1]
        answer_raw = outputs[0, input_length:]
        answer = tokenizer.decode(answer_raw, skip_special_tokens=True)

        # Reset per row so an invalid answer never reports a previous row's entropy.
        entropy = None
        if answer in prompt.option_ids:
            entropy = entropy_estimator.calculate(outputs)
            print(f"loop {index} -> after entropy: {model.get_memory_footprint(return_buffers=True) / 10**9} GB")
            df.at[index, field_entropy_value] = entropy
            df.at[index, field_ans_output] = answer
            df.at[index, field_ans_correct] = verify_answer(row, answer)
        else:
            # Mark the row in the model-specific answer column so it is retried on resume.
            df.at[index, field_ans_output] = "INVALID"
            invalid_answers += 1

        print(f"Answer: {answer}\nEntropy: {entropy}\nis_correct: {df.at[index, field_ans_correct]}\n\n")

        if index % DUMP_EVERY == 0:
            _dump_tsv(df, out_filename)

    _dump_tsv(df, out_filename)
    print(f"Processed dataset {out_filename}. Total entries: {df.shape[0]}. Invalid answers: {invalid_answers}")
    return df


ORIGINAL_DATASET = "../data/mmlu_pro_stem"
original_filename = f"{ORIGINAL_DATASET}.tsv"
out_filename = f"{ORIGINAL_DATASET}_w_phi4_entropy.tsv"

# Start from the pristine dataset unless a results file already exists; in that
# case continue from it, reading it with the same unquoted, backslash-escaped
# dialect that the results are written with.
if not os.path.isfile(out_filename):
    df = pd.read_csv(original_filename, sep="\t", header=0)
else:
    df = pd.read_csv(
        out_filename, sep="\t", header=0, quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\"
    )
# df = df.head(10)


def verify_model_answer(row, model_answer):
    """Check whether the model's answer matches the row's reference answer.

    The value stored under ``row["answer_index"]`` is incremented by one before
    being compared with ``model_answer`` (presumably a zero-based index versus
    a one-based option id — confirm against the prompt format).

    Args:
        row: Mapping-like object providing ``"answer_index"``.
        model_answer: The model's answer; anything ``int()`` can convert.

    Returns:
        True when both values convert to integers and match, False when they
        differ or when either value is missing or not convertible.
    """
    try:
        return int(row["answer_index"]) + 1 == int(model_answer)
    except (KeyError, TypeError, ValueError, OverflowError):
        # Only lookup/conversion problems count as "wrong answer"; interrupts
        # and unrelated errors must propagate.
        return False


# Evaluate the loaded dataset: the subject comes from "base_cluster", the
# question text from "question", and the options are parsed from the string
# stored in "options".
estimate_dataset(
    df,
    model,
    tokenizer,
    get_subject_from_row=lambda entry: entry["base_cluster"],
    get_question_from_row=lambda entry: entry["question"],
    get_options_from_row=lambda entry: ast.literal_eval(entry["options"]),
    verify_answer=verify_model_answer,
    out_filename=out_filename,
)

phi3


  0%|          | 0/12032 [00:00<?, ?it/s]

loop 0 -> start: 29.319014656 GB
loop 0 -> after generate: 29.319014656 GB


  0%|          | 1/12032 [00:15<50:30:09, 15.11s/it]

loop 0 -> after entropy: 29.319014656 GB
Answer: 3
Entropy: 0.06396484375
is_correct: True


loop 1 -> start: 29.319014656 GB
loop 1 -> after generate: 29.319014656 GB


  0%|          | 2/12032 [00:26<43:07:23, 12.90s/it]

loop 1 -> after entropy: 29.319014656 GB
Answer: 2
Entropy: 1.09375
is_correct: False


loop 2 -> start: 29.319014656 GB
loop 2 -> after generate: 29.319014656 GB


  0%|          | 3/12032 [00:40<45:31:48, 13.63s/it]

loop 2 -> after entropy: 29.319014656 GB
Answer: 1
Entropy: 0.00836181640625
is_correct: False


loop 3 -> start: 29.319014656 GB
loop 3 -> after generate: 29.319014656 GB


  0%|          | 4/12032 [01:03<56:42:48, 16.97s/it]

loop 3 -> after entropy: 29.319014656 GB
Answer: 7
Entropy: 0.09130859375
is_correct: False


loop 4 -> start: 29.319014656 GB
loop 4 -> after generate: 29.319014656 GB


  0%|          | 5/12032 [01:23<60:51:42, 18.22s/it]

loop 4 -> after entropy: 29.319014656 GB
Answer: 7
Entropy: 0.10302734375
is_correct: False


loop 5 -> start: 29.319014656 GB
loop 5 -> after generate: 29.319014656 GB


  0%|          | 6/12032 [01:41<60:36:59, 18.15s/it]

loop 5 -> after entropy: 29.319014656 GB
Answer: 3
Entropy: 0.25
is_correct: False


loop 6 -> start: 29.319014656 GB
loop 6 -> after generate: 29.319014656 GB


  0%|          | 7/12032 [01:55<55:43:07, 16.68s/it]

loop 6 -> after entropy: 29.319014656 GB
Answer: 3
Entropy: 0.083984375
is_correct: False


loop 7 -> start: 29.319014656 GB
loop 7 -> after generate: 29.319014656 GB


  0%|          | 8/12032 [02:13<57:32:00, 17.23s/it]

loop 7 -> after entropy: 29.319014656 GB
Answer: 1
Entropy: 0.058837890625
is_correct: False


loop 8 -> start: 29.319014656 GB
loop 8 -> after generate: 29.319014656 GB


  0%|          | 8/12032 [02:21<58:53:22, 17.63s/it]


KeyboardInterrupt: 