In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "microsoft/phi-4"

# Select the compute device with an explicit precedence: Apple Metal (MPS) first,
# then CUDA, then CPU. This is the same precedence the previous pair of independent
# `if` statements produced (the later MPS check overrode the CUDA one).
# `torch.backends.mps.is_available()` is the documented availability check;
# NOTE(review): `torch.mps.is_available()` appears to exist only in newer PyTorch
# releases, so the `torch.backends` spelling is used for portability.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Alternative (disabled): load int4 weight-only quantized weights via torchao.
# Requires `from transformers import TorchAoConfig`, which this cell does not import.
# quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_NAME,
#     low_cpu_mem_usage=True,
#     torch_dtype=torch.bfloat16,
#     quantization_config=quantization_config,
#     device_map=DEVICE,
# )

# Load the checkpoint with library defaults, then move every parameter to DEVICE.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)

Using device: mps
import error: No module named 'triton'


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [3]:
# Sanity check after loading: size of the model in units of 10**9 bytes, and the
# device that holds its first parameter.
footprint_gb = model.get_memory_footprint() / 10**9
print(footprint_gb)
first_parameter = next(model.parameters())
print(first_parameter.device)

29.319014656
mps:0


In [9]:
from dataclasses import dataclass
from typing import Any

import torch

# Placeholder type for the language model: any callable that accepts
# `(input_ids, output_hidden_states=..., output_attentions=...)` and returns an
# object exposing a `.logits` tensor.
LLMModel = Any


# TODO: Cite https://github.com/abazarova/tda4hallucinations/
@dataclass
class TokenwiseEntropy:
    """Mean Shannon entropy of a language model's per-position output distributions.

    The model is run once over a token sequence; the logits of the last ``n``
    positions are turned into probability distributions, the entropy (in nats)
    of each distribution is computed, and the average over those positions is
    returned.
    """

    # Model (or any callable) returning an object with a `.logits` attribute.
    llm_model: LLMModel
    # Not read by any method of this class; kept so existing callers that pass
    # `device=` keep working. `input_ids` must already be on the model's device.
    device: str = "cuda"

    @torch.no_grad()
    def calculate(self, input_ids, n) -> float:
        """Return the mean entropy over the last ``n`` positions of ``input_ids``.

        Parameters:
        ----------
        input_ids :
            Token ids passed unchanged to ``llm_model``. Only the first sequence
            of the batch is used.
        n : int
            Number of trailing positions to average over. Must be >= 1. If it
            exceeds the sequence length, all available positions are used.

        Returns:
        -------
        float
            Average entropy (nats) of the selected positions.

        Raises:
        ------
        ValueError
            If ``n`` is smaller than 1 (``-0:`` would silently select the whole
            sequence and the average would be divided by zero).
        """
        if n < 1:
            raise ValueError(f"n must be a positive integer, got {n}")

        token_distribution = self._get_token_distribution(input_ids, n)
        entropies = self._compute_entropy_from_logits(token_distribution)
        # Average over the positions actually selected. Previously only the first
        # position's entropy was kept and then divided by n.
        return entropies.mean().detach().cpu().item()

    def _get_token_distribution(self, input_ids, n) -> torch.Tensor:
        """Run the model and return the logits of the last ``n`` positions.

        The result is indexed ``[position, vocabulary]`` for the first batch item.
        """
        # Hidden states and attentions are not used here, so neither is requested.
        output = self.llm_model(
            input_ids,
            output_hidden_states=False,
            output_attentions=False,
        )

        # NOTE(review): for a causal LM the logits at position t usually describe
        # the distribution of token t+1, so "last n positions" is shifted by one
        # relative to "the last n tokens" — confirm this is the intended window.
        return output.logits[0, -n:]

    def _compute_entropy_from_logits(
        self,
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute entropy from logits.

        Parameters:
        ----------
        logits : torch.Tensor
            Logits from the model; the last dimension is the vocabulary.

        Returns:
        -------
        torch.Tensor
            Entropy values (nats), one per leading index of ``logits``.
        """
        # Work in float32 and use log_softmax so low-precision logits (e.g.
        # bfloat16) do not lose the small-probability terms.
        log_probabilities = torch.log_softmax(logits.float(), dim=-1)
        probabilities = log_probabilities.exp()
        # Define 0 * log(0) as 0 so that -inf logits do not produce NaN.
        weighted = torch.where(
            probabilities > 0,
            probabilities * log_probabilities,
            torch.zeros_like(probabilities),
        )
        return -weighted.sum(dim=-1)

In [12]:
from typing import List


def get_sys_prompt(subject: str | None = None):
    if subject is not None:
        sys_msg = f"The following is a multiple choice question about {subject}."
    else:
        sys_msg = "The following is a multiple choice question."

    sys_msg += 'Please act as an expert and answer the question. Begin your answer by providing a short explanation. After providing your explanation, you must write down the NUMBER of the correct answer by strictly following the format: "[[answer]]".'
    return sys_msg


# Answer labels "1" .. "20", in order; also the upper bound on listed options.
option_ids = list(map(str, range(1, 21)))


def get_user_prompt(question: str, options: List[str]):
    """Render the user message: the question, its numbered options, and the answer format.

    Options are paired with ``option_ids`` in order, so at most
    ``len(option_ids)`` options are listed. Each rendered option line and the
    question text are stripped of surrounding whitespace.
    """
    numbered_lines = []
    for label, text in zip(option_ids, options):
        numbered_lines.append(f"{label}. {text}".strip())
    options_str = "\n".join(numbered_lines)

    return f'Question: {question.strip()}\nOptions:\n{options_str}\nProvide a short explanation and choose one of the answers. After that, write down the NUMBER of the correct answer in the format: "[[answer]]"'

In [16]:
import ast
import csv
import gc
import os.path
import re

import pandas as pd
from tqdm import tqdm

# Number of newly processed rows between intermediate dumps of the results file.
DUMP_EVERY = 100
# Running count of rows whose generation produced no usable answer; it is
# incremented by every call to `estimate_dataset` and never reset.
invalid_answers = 0

# The final-answer marker requested by the prompts, e.g. "[[3]]".
_ANSWER_PATTERN = re.compile("\\[\\[(\\d+?)\\]\\]")


def _dump_results(df, out_filename):
    """Write ``df`` to ``out_filename`` as an unquoted, backslash-escaped TSV."""
    df.to_csv(out_filename, sep="\t", quoting=csv.QUOTE_NONE, quotechar="", escapechar="\\", index=False)


def _empty_device_cache(device):
    """Ask PyTorch to release cached accelerator memory for ``device``'s backend."""
    device_type = torch.device(device).type
    if device_type == "cuda":
        torch.cuda.empty_cache()
    elif device_type == "mps":
        torch.mps.empty_cache()


def estimate_dataset(
    df,
    model,
    tokenizer,
    get_subject_from_row,
    get_question_from_row,
    get_options_from_row,
    verify_answer,
    out_filename,
    max_new_tokens=500,
):
    """Generate an answer for every unprocessed row and record its entropy and correctness.

    Two columns, named after the model type, are added to ``df`` if missing:
    ``entropy_cot_value_<type>`` (initialised to 0.0) and
    ``entropy_cot_ans_correct_<type>`` (initialised to False). A row whose
    entropy value is not 0.0 is treated as already processed and skipped, which
    lets an interrupted run resume from its dumped file.

    Parameters:
    ----------
    df : pandas.DataFrame
        Dataset; modified in place and also returned.
    model, tokenizer :
        Causal LM and its tokenizer. The tokenizer must provide a chat template.
    get_subject_from_row, get_question_from_row, get_options_from_row : callable
        Extract the subject, the question text and the list of options from a row.
    verify_answer : callable
        ``verify_answer(row, answer)`` -> whether the extracted answer is correct.
    out_filename : str
        TSV file receiving the intermediate and final dumps.
    max_new_tokens : int
        Generation budget per row. Answers cut off before the ``[[N]]`` marker
        are counted as invalid.

    Returns:
    -------
    pandas.DataFrame
        The same ``df`` object.
    """
    global invalid_answers

    # Kept as in the original so the column names of previously dumped files
    # still match when resuming.
    model_name = model.config_class().model_type
    print(model_name)

    field_ans_correct = f"entropy_cot_ans_correct_{model_name}"
    field_entropy_value = f"entropy_cot_value_{model_name}"

    if field_ans_correct not in df.columns:
        df[field_ans_correct] = False
    if field_entropy_value not in df.columns:
        df[field_entropy_value] = 0.0

    # Use the device the model actually lives on instead of a notebook-level
    # global, so inputs and model cannot end up on different devices.
    device = model.device
    entropy_estimator = TokenwiseEntropy(llm_model=model, device=device)

    processed_since_dump = 0
    try:
        for index, row in tqdm(df.iterrows(), total=df.shape[0]):
            if df.at[index, field_entropy_value] != 0.0:
                continue

            sys_prompt = get_sys_prompt(get_subject_from_row(row))
            user_prompt = get_user_prompt(get_question_from_row(row), get_options_from_row(row))
            messages = [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": user_prompt},
            ]
            formatted_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)

            # `outputs` holds prompt + continuation; everything after the prompt
            # length is the generated answer.
            input_length = inputs.input_ids.shape[1]
            n = outputs.shape[1] - input_length
            answer_str = tokenizer.decode(outputs[0, input_length:], skip_special_tokens=True)

            match = _ANSWER_PATTERN.search(answer_str)
            answer = match.group(1) if match is not None else None
            if answer in option_ids:
                try:
                    entropy = entropy_estimator.calculate(outputs, n)
                    is_correct = verify_answer(row, answer)
                except Exception as exc:
                    # Best effort: one failing row must not abort a multi-hour
                    # run, but the failure is reported instead of hidden.
                    # KeyboardInterrupt is deliberately not caught here.
                    tqdm.write(f"Row {index}: scoring failed ({type(exc).__name__}: {exc})")
                    invalid_answers += 1
                else:
                    df.at[index, field_entropy_value] = entropy
                    df.at[index, field_ans_correct] = is_correct
            else:
                # No "[[N]]" marker, or N is not a known option label.
                invalid_answers += 1

            # Count processed rows rather than testing the index label, so the
            # cadence also holds for non-integer or non-contiguous indexes.
            processed_since_dump += 1
            if processed_since_dump >= DUMP_EVERY:
                _dump_results(df, out_filename)
                processed_since_dump = 0

            gc.collect()
            _empty_device_cache(device)
    finally:
        # Runs on normal completion and on interruption/errors alike, so no
        # finished rows are lost.
        _dump_results(df, out_filename)

    print(f"Processed dataset {out_filename}. Total entries: {df.shape[0]}. Invalid answers: {invalid_answers}")
    return df


# Path stem shared by the raw dataset and the results file derived from it.
ORIGINAL_DATASET = "../data/mmlu_pro_stem"
original_filename = ORIGINAL_DATASET + ".tsv"
out_filename = ORIGINAL_DATASET + "_w_phi4_entropy_cot.tsv"

# Resume from an existing results file when there is one; otherwise start from
# the raw dataset. The results file is read with quoting disabled and backslash
# escaping, mirroring the options it is written with; the raw file is read with
# the pandas defaults.
if not os.path.isfile(out_filename):
    df = pd.read_csv(original_filename, sep="\t", header=0)
else:
    df = pd.read_csv(
        out_filename,
        sep="\t",
        header=0,
        quoting=csv.QUOTE_NONE,
        quotechar="",
        escapechar="\\",
    )
# df = df.head(10)


def verify_model_answer(row, model_answer):
    """Check a model's 1-based answer against the row's 0-based ``answer_index``.

    Parameters:
    ----------
    row :
        Mapping-like record providing ``row["answer_index"]``.
    model_answer :
        The answer number produced by the model, as a string or integer.

    Returns:
    -------
    bool
        True when ``answer_index + 1`` equals ``model_answer``. False when they
        differ or when either value is missing or cannot be converted to int.
    """
    try:
        return int(row["answer_index"]) + 1 == int(model_answer)
    except (KeyError, IndexError, TypeError, ValueError, OverflowError):
        # Only lookup/conversion problems mean "not verifiable"; anything else
        # (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
        return False


# Score the loaded dataset with the loaded model. The three accessors map a row
# to its subject ("base_cluster"), question text, and list of options; the
# options column stores a Python list literal, hence `ast.literal_eval`.
# The call is the cell's last expression, so the returned frame is displayed.
estimate_dataset(
    df=df,
    model=model,
    tokenizer=tokenizer,
    out_filename=out_filename,
    verify_answer=verify_model_answer,
    get_subject_from_row=lambda record: record["base_cluster"],
    get_question_from_row=lambda record: record["question"],
    get_options_from_row=lambda record: ast.literal_eval(record["options"]),
)

phi3


  0%|          | 0/12032 [00:00<?, ?it/s]

Answer: Karl Llewellyn, a prominent legal realist, distinguished between the "grand style" and "formal style" of legal reasoning. The grand style involves a more creative and policy-oriented approach, where judges may consider broader social implications and policy goals. In contrast, the formal style is more rigid, focusing strictly on applying established legal rules and precedents.

Criticism of Llewellyn's distinction often centers on the practical application of these styles in judicial decision-making. One compelling criticism is that it is misleading to pigeon-hole judges into these distinct categories. In reality, judicial reasoning often involves a blend of both styles. Judges may start with a formal analysis but then shift to a grand style when the formal rules are indeterminate or lead to unjust outcomes. This criticism highlights the fluidity and complexity of judicial reasoning, which cannot be neatly categorized.

Therefore, the most compelling criticism is that it is mis

  0%|          | 1/12032 [01:16<255:10:16, 76.35s/it]

Answer: In the context of the Indian Constitution, "qualified rights" are those rights that are not absolute and can be restricted under certain conditions. Article 19, which deals with the freedom of speech and expression, assembly, association, movement, residence, and profession, is a qualified right because these freedoms can be restricted on grounds such as the sovereignty and integrity of India, security of the state, public order, decency, morality, etc.

On the other hand, "unqualified rights" or "absolute rights" are those that cannot be restricted or are not subject to any conditions. Article 21, which guarantees the right to life and personal liberty, is often considered an unqualified right, although it has been subject to reasonable restrictions through judicial interpretation.

Article 12 defines the term "State" for the purposes of Part III of the Constitution, which deals with Fundamental Rights. It is not a right itself but a definitional provision.

Article 11 deals w

  0%|          | 2/12032 [02:45<279:35:59, 83.67s/it]

Answer: Ensuring that one individual does not carry the burden of a whole work task involves assigning parts of the task to different individuals or teams. This process helps in optimizing efficiency, utilizing diverse skills, and preventing burnout. The term that best describes this concept is "work delegation." Work delegation involves assigning responsibility and authority to others to complete specific tasks or projects, thereby distributing the workload among team members.

[[1]]


  0%|          | 3/12032 [03:25<213:53:25, 64.01s/it]

tensor([1.1325e-05, 5.1880e-04, 2.1582e-01, 6.3782e-03, 1.6113e-02, 6.6161e-06,
        1.0840e-01, 2.3365e-04, 1.5918e-01, 1.3447e-04, 5.5469e-01, 4.3945e-03,
        3.0884e-02, 7.1526e-05, 1.1406e+00, 1.7812e+00, 1.7031e+00, 2.2656e-01,
        8.3984e-01, 2.8125e-01, 2.5586e-01, 8.3594e-01, 6.6016e-01, 1.1094e+00,
        5.5859e-01, 6.6797e-01, 2.2827e-02, 2.0156e+00, 1.4219e+00, 1.2734e+00,
        2.2812e+00, 1.7031e+00, 6.4844e-01, 2.5781e+00, 1.9297e+00, 5.6250e-01,
        2.8711e-01, 2.2754e-01, 1.1484e+00, 1.4219e+00, 3.9673e-03, 1.5938e+00,
        1.8750e+00, 2.2031e+00, 9.4531e-01, 1.0938e+00, 8.0078e-01, 2.2070e-01,
        1.8906e+00, 2.8906e-01, 1.8652e-01, 7.1094e-01, 6.3672e-01, 7.5391e-01,
        8.5547e-01, 6.6280e-05, 1.1953e+00, 1.7109e+00, 1.1797e+00, 5.9766e-01,
        7.0801e-02, 9.9219e-01, 1.6797e+00, 8.1250e-01, 1.0469e+00, 1.4688e+00,
        2.0703e-01, 9.1406e-01, 2.3750e+00, 9.8047e-01, 2.1250e+00, 1.3750e+00,
        2.1387e-01, 3.6328e-01, 1.8828e+

  0%|          | 3/12032 [03:46<252:32:56, 75.58s/it]


KeyboardInterrupt: 