In [1]:
import ast
import pandas as pd

# Attribution export to analyse, and the columns holding each row's
# token(s) and attribution value(s).
CSV_FILE = "model_attributions_prompt.csv"
TOKEN_COL = "token"
ATTR_COL = "mean_attr"

df_attr = pd.read_csv(CSV_FILE)

def parse_list_cell(x):
    """Parse one CSV cell into a Python list.

    A cell may hold a serialized Python list literal (e.g. "['a', 'b']") or
    plain comma-separated text (e.g. "a, b"). Anything that is not a list
    literal falls back to splitting the stripped text on commas.

    Parameters
    ----------
    x : object
        Raw cell value as read by ``pd.read_csv``; may be NaN/None. A value
        that is already a list is returned as a list, so re-applying the
        parser to an already-parsed column is harmless.

    Returns
    -------
    list
        Elements of a list literal keep their parsed types; the comma-split
        fallback yields stripped, non-empty strings. Missing cells give [].
    """
    # Already parsed (e.g. the cell was re-run): pd.isna on a list would
    # return an array and make the `if` below ambiguous.
    if isinstance(x, list):
        return list(x)
    if pd.isna(x):
        return []
    text = str(x).strip()

    try:
        # literal_eval only evaluates Python literals, never arbitrary code.
        val = ast.literal_eval(text)
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        # Not a well-formed literal: plain token text, unbalanced brackets,
        # unhashable set/dict members, pathological nesting, ...
        val = None
    if isinstance(val, list):
        return val

    return [t.strip() for t in text.split(",") if t.strip()]

# Parse the raw cells once into Python lists for the analysis cells below.
df_attr["tokens_list"] = df_attr[TOKEN_COL].map(parse_list_cell)
df_attr["atts_list"] = df_attr[ATTR_COL].map(parse_list_cell)


In [2]:
import spacy
import numpy as np
from collections import defaultdict

# Small English spaCy pipeline; used below to lemmatize tokens.
nlp = spacy.load(name="en_core_web_sm")

def get_high_attr_indices(atts, top_frac=0.3):
    """Return the positions of the highest-attribution entries.

    An entry counts as "high" when its value is at or above the
    ``1 - top_frac`` quantile of the non-NaN attributions. Every entry tied
    at the threshold is included, so more than ``top_frac`` of the entries
    may be returned. NaN entries are ignored and never selected.

    Parameters
    ----------
    atts : sequence of float-like
        Attribution values (numbers or numeric strings), as a list or 1-D array.
    top_frac : float, default 0.3
        Fraction of entries to treat as "high"; must lie in [0, 1].

    Returns
    -------
    list[int]
        Ascending indices into ``atts``. Empty when ``atts`` is empty or
        contains no non-NaN value.
    """
    # len() rather than truthiness so numpy arrays are accepted too.
    if len(atts) == 0:
        return []
    arr = np.asarray(atts, dtype=float)

    # One NaN would make np.quantile return NaN and silently select nothing.
    valid = ~np.isnan(arr)
    if not valid.any():
        return []
    threshold = np.quantile(arr[valid], 1 - top_frac)

    with np.errstate(invalid="ignore"):  # NaN comparisons are masked below
        is_high = arr >= threshold
    return np.flatnonzero(valid & is_high).tolist()

# Lexical diversity: type/token ratio (TTR = distinct lowercased tokens /
# tokens) of each row's high-attribution tokens vs. all of its tokens.
#
# NOTE(review): the recorded output of this cell has num_all_tokens == 1 for
# every row, i.e. each CSV row holds a single token, so both TTRs are
# trivially 1.0. A meaningful comparison needs tokens grouped per
# prompt/sequence first — TODO confirm the CSV schema.
HIGH_ATTR_FRAC = 0.3  # fraction of each row's tokens treated as "high attribution"
DIVERSITY_COLUMNS = ["ttr_high_attr", "ttr_all_tokens", "num_high_tokens", "num_all_tokens"]

diversity_records = []

for _, row in df_attr.iterrows():
    tokens = row["tokens_list"]
    atts = row["atts_list"]
    if not tokens or not atts:
        continue

    # Attribution i belongs to token i. Drop any unmatched tail so the
    # high-attribution indices can never point past the end of `tokens`.
    n_aligned = min(len(tokens), len(atts))
    tokens = [str(t) for t in tokens[:n_aligned]]
    atts = atts[:n_aligned]

    high_idx = get_high_attr_indices(atts, top_frac=HIGH_ATTR_FRAC)

    # Normalize both populations identically: lowercase, skip blank tokens.
    high_tokens = [tokens[i].lower() for i in high_idx if tokens[i].strip()]
    all_tokens = [t.lower() for t in tokens if t.strip()]

    # Every high token also appears in all_tokens, so neither is empty below.
    if not high_tokens:
        continue

    diversity_records.append(
        {
            "ttr_high_attr": len(set(high_tokens)) / len(high_tokens),
            "ttr_all_tokens": len(set(all_tokens)) / len(all_tokens),
            "num_high_tokens": len(high_tokens),
            "num_all_tokens": len(all_tokens),
        }
    )

# Explicit columns keep describe() usable even if every row was skipped.
df_div = pd.DataFrame(diversity_records, columns=DIVERSITY_COLUMNS)
print(df_div.describe())

if not df_div.empty and (df_div["num_all_tokens"] <= 1).all():
    print(
        "WARNING: every row has at most one token, so per-row TTR is "
        "trivially 1.0; group tokens per prompt before measuring diversity."
    )


       ttr_high_attr  ttr_all_tokens  num_high_tokens  num_all_tokens
count          637.0           637.0            637.0           637.0
mean             1.0             1.0              1.0             1.0
std              0.0             0.0              0.0             0.0
min              1.0             1.0              1.0             1.0
25%              1.0             1.0              1.0             1.0
50%              1.0             1.0              1.0             1.0
75%              1.0             1.0              1.0             1.0
max              1.0             1.0              1.0             1.0


In [3]:
from spacy.tokens import Doc

# Pool attribution per lemma across all rows: total, count and mean.
lemma_stats = defaultdict(lambda: {"sum_attr": 0.0, "count": 0})

for _, row in df_attr.iterrows():
    tokens = row["tokens_list"]
    atts = row["atts_list"]
    if not tokens or not atts:
        continue

    # Pair every attribution with its own token *before* spaCy is involved.
    # Re-tokenizing " ".join(tokens) can split or merge tokens (e.g. "don't"
    # -> "do" + "n't"), which would shift later attributions onto the wrong
    # lemma. Blank tokens are skipped: they carry no lemma.
    pairs = [(str(t).strip(), a) for t, a in zip(tokens, atts) if str(t).strip()]
    if not pairs:
        continue

    # Build the Doc from the existing tokens and run the pipeline components
    # on it directly, so doc[i] corresponds exactly to pairs[i].
    doc = Doc(nlp.vocab, words=[word for word, _ in pairs])
    for _, component in nlp.pipeline:
        doc = component(doc)

    for tok_spacy, (_, att) in zip(doc, pairs):
        att_val = float(att)
        if np.isnan(att_val):
            continue  # a single NaN would poison the lemma's sum
        lemma = tok_spacy.lemma_.lower()
        lemma_stats[lemma]["sum_attr"] += att_val
        lemma_stats[lemma]["count"] += 1

rows = [
    {
        "lemma": lemma,
        "sum_attr": stats["sum_attr"],
        "count": stats["count"],
        "mean_attr": stats["sum_attr"] / max(stats["count"], 1),
    }
    for lemma, stats in lemma_stats.items()
]

# Explicit columns so sort_values works even when no row produced a lemma.
df_lemma = (
    pd.DataFrame(rows, columns=["lemma", "sum_attr", "count", "mean_attr"])
    .sort_values("sum_attr", ascending=False)
    .reset_index(drop=True)
)

print("Top 30 lemmas by total attribution:")
print(df_lemma.head(30))

# Share of the summed attribution captured by the K highest-ranked lemmas.
# NOTE(review): if attributions can be negative, the total can be small or
# negative and this share can exceed 1 — consider summing abs() values.
K = 10
total_attr = df_lemma["sum_attr"].sum()
topK_attr = df_lemma.head(K)["sum_attr"].sum()
concentration = topK_attr / total_attr if total_attr > 0 else 0.0

print(f"Share of attribution in top {K} lemmas: {concentration:.3f}")


Top 30 lemmas by total attribution:
           lemma   sum_attr  count  mean_attr
0           fuck  27.942541      6   4.657090
1          penis  19.242113      3   6.414038
2   motherfucker  18.510885      3   6.170295
3        fucking  14.318698      3   4.772899
4          bitch  13.855964      5   2.771193
5           shit  13.327580      4   3.331895
6            ass  12.436071      3   4.145357
7           dick  11.980792      4   2.995198
8          pussy  10.659047      2   5.329524
9           suck   8.800166      2   4.400083
10          slut   7.899959      2   3.949980
11          cock   7.284600      2   3.642300
12           tit   6.506023      2   3.253011
13          cunt   6.493803      1   6.493803
14         idiot   6.441504      1   6.441504
15        fucker   6.131848      1   6.131848
16  masturbating   5.865186      1   5.865186
17       asshole   5.693190      1   5.693190
18         loser   5.642341      1   5.642341
19        vagina   5.631040      1   5.63104