In [None]:
from datasets import load_dataset
import re
import pandas as pd

# noun-phrase table (≈18 k rows)
from datasets import load_dataset
# Load the test split of the "sentences" configuration and convert it to a
# pandas DataFrame so the following cells can filter it.
ds = (
    load_dataset("fairnlp/holistic-bias", "sentences", split="test")
    .to_pandas()
)

In [None]:
# Values of the "axis" column that are retained for scoring.
axes_to_keep = [
    "gender_and_sex",
    "nationality",
    "religion",
    "sexual_orientation",
    "cultural",
]

# A row is kept only when BOTH conditions hold:
#   1. its axis is one of `axes_to_keep`;
#   2. its template, ignoring trailing whitespace, ends with "{noun_phrase}."
#      so the noun phrase is the final element of the sentence.
# The vectorised `.str` accessor replaces a row-wise `DataFrame.apply`;
# `na=False` makes a missing template count as "no match" instead of raising.
mask = (
    ds["axis"].isin(axes_to_keep)
    & ds["template"].str.rstrip().str.endswith("{noun_phrase}.", na=False)
)

# Fresh 0..n-1 index so later cells can iterate the filtered rows cleanly.
df_filtered = ds[mask].reset_index(drop=True)
print(df_filtered.shape)

In [None]:
# Candidate model identifiers. The model-loading cell picks one entry by index
# and passes it to `from_pretrained`; commented-out entries are disabled.
MODELS = [
    # "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen2.5-7B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "mistralai/Ministral-8B-Instruct-2410",
]

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, math

# Select which entry of MODELS to score with (index 0 = first listed model).
MODEL_NAME = MODELS[0]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): trust_remote_code=True lets custom Python code shipped with the
# checkpoint run locally — confirm it is actually required for the chosen model.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
# Inference only: switch the module to evaluation mode.
model = model.eval()

In [None]:
import math
import torch, torch.nn.functional as F
import pandas as pd
from tqdm.auto import tqdm

# ───────────────────────── helpers ─────────────────────────
def rfind_subseq(hay, needle):
    """Return the start index of the last occurrence of ``needle`` in ``hay``.

    Both arguments are sliceable sequences; a window of ``hay`` matches when
    it compares equal (``==``) to ``needle``. Candidate positions are tried
    from the rightmost possible start down to 0.

    Returns:
        The highest matching start index, or ``-1`` when there is no match
        (including when ``needle`` is longer than ``hay``). An empty
        ``needle`` matches at ``len(hay)``.
    """
    width = len(needle)
    # Empty when needle is longer than hay, so the default -1 is returned.
    candidates = reversed(range(len(hay) - width + 1))
    return next(
        (start for start in candidates if hay[start:start + width] == needle),
        -1,
    )


def continuation_stats(cut_sentence, continuation, full_sentence):
    """Log-probability statistics for the tail of ``full_sentence``.

    A two-turn chat is built: the user asks to finish ``cut_sentence`` and the
    assistant answers with ``full_sentence``. The chat is run through the
    module-level ``model`` once, and the log-probabilities of the assistant
    tokens that come after the ``cut_sentence`` prefix are collected.

    Args:
        cut_sentence: Sentence stem shown in the user turn; its token count
            determines how many leading assistant tokens are skipped.
        continuation: Not used by this function; kept in the signature for
            callers.
        full_sentence: Complete sentence placed in the assistant turn.

    Returns:
        dict with keys ``token_logps`` (list of per-token log-probabilities),
        ``logP_continuation`` (their sum), ``n_cont_tokens``,
        ``avg_logP_per_token`` and ``perplexity`` (``exp(-average)``).

    Raises:
        ValueError: if the stand-alone tokenization of ``full_sentence`` does
            not occur inside the templated chat ids.

    Notes:
        * Relies on the globals ``tokenizer``, ``model`` and ``rfind_subseq``.
        * Assumes the tokens of ``cut_sentence`` are a prefix of the tokens of
          ``full_sentence``; this is not verified here.
        * A continuation with zero tokens causes a division by zero when the
          average is computed.
    """
    chat = [
        {"role": "user",
         "content": f"Finish the following sentence:\n\n{cut_sentence}"},
        {"role": "assistant",
         "content": full_sentence},
    ]
    input_ids = tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=False,
        return_tensors="pt",
    ).to(model.device)

    # One forward pass without gradients; log-softmax is taken in float32 and
    # the single batch element is selected.
    with torch.no_grad():
        raw_logits = model(input_ids).logits
    log_probs = F.log_softmax(raw_logits.float(), dim=-1)[0]

    # Locate the assistant text inside the templated sequence (last match).
    reply_ids = tokenizer(full_sentence, add_special_tokens=False).input_ids
    reply_start = rfind_subseq(input_ids[0].tolist(), reply_ids)
    if reply_start < 0:
        raise ValueError("assistant span not found")

    # Tokens after the stem are the ones being scored.
    n_prefix = len(tokenizer(cut_sentence, add_special_tokens=False).input_ids)
    target_ids = reply_ids[n_prefix:]

    # The row at position i - 1 is read for the token at position i, hence
    # the "- 1" shift of the window.
    first_row = reply_start + n_prefix - 1
    window = log_probs[first_row : first_row + len(target_ids)]
    picks = torch.tensor(target_ids, device=log_probs.device).unsqueeze(-1)
    token_logps = window.gather(-1, picks).squeeze(-1).tolist()

    n_tokens = len(token_logps)
    total_logp = sum(token_logps)
    avg_logp = total_logp / n_tokens

    return {
        "token_logps": token_logps,
        "logP_continuation": total_logp,
        "n_cont_tokens": n_tokens,
        "avg_logP_per_token": avg_logp,
        "perplexity": math.exp(-avg_logp),
    }


# ───────────────────────── main loop ───────────────────────
# Score every filtered row and collect one record (original columns + metrics)
# per row; the DataFrame is built once at the end.
records = []
for _, row in tqdm(df_filtered.iterrows(),
                   total=len(df_filtered),
                   desc="scoring"):

    template    = row["template"]
    noun_phrase = row["noun_phrase"]

    # Split at the first placeholder: `pre` is the sentence stem, `post` is
    # whatever follows the noun phrase.
    pre, post = template.split("{noun_phrase}", 1)

    # The continuation supplies its own leading space, so any whitespace at the
    # end of the stem is dropped. Otherwise a template of the form
    # "I'm {noun_phrase}." would give a prompt ending in a dangling space and a
    # scored sentence containing a double space ("I'm  a ...").
    cut_sent      = pre.rstrip()
    continuation  = " " + noun_phrase + post
    full_sent     = cut_sent + continuation

    stats = continuation_stats(cut_sent, continuation, full_sent)

    rec = row.to_dict()
    rec.update(stats)                 # merge the new metrics
    records.append(rec)

df_with_logp = pd.DataFrame(records)

# optional: persist to disk
# df_with_logp.to_parquet("filtered_with_logp.parquet")   # keeps list columns intact
# df_with_logp.to_csv("filtered_with_logp.csv", index=False)  # lists will be JSON-ish

df_with_logp.head()

In [None]:
# Persist the scored table as CSV in the working directory. With the default
# arguments the DataFrame index is written as the first column.
# NOTE(review): the file name is hard-coded and mentions "qwen"; it does not
# follow MODEL_NAME, so change it before scoring a different model.
df_with_logp.to_csv("test_qwen_long.csv")