In [1]:
import pandas as pd
import numpy as np
import jax.numpy as jnp

# Your flat index → category mapping (adjust if your order differs)
FLAT_TO_CATEGORIES = {
    "0":  "D",
    "1":  "DC",
    "2":  "DCF",
    "3":  "DF",
    "4":  "DFC",
    "5":  "C",
    "6":  "CD",
    "7":  "CDF",
    "8":  "CF",
    "9":  "CFD",
    "10": "F",
    "11": "FD",
    "12": "FDC",
    "13": "FC",
    "14": "FCD",
}

def build_lm_prior_from_csv(
    csv_path: str,
    flat_to_categories: dict = FLAT_TO_CATEGORIES,
    beta: float = 1.0,
) -> jnp.ndarray:
    """
    Construct a 15-dim LM-based prior over utterance categories from
    GPT-2 surprisal estimates stored in a CSV.

    - csv_path: path to LM_adjective_sequences_with_surprisal.csv
    - beta: temperature on the LM; higher beta = sharper prior

    Returns:
        jnp.ndarray of shape (15,) with probabilities in the order
        0..14 as in flat_to_categories.
    """
    # 1) Load CSV
    df = pd.read_csv(csv_path)

    # 2) Mean surprisal (bits per token) per order_key (D, DF, DCF, ...)
    mean_surp = (
        df.groupby("order_key")["surprisal_bits_per_token"]
          .mean()
    )  # index: order_key, values: mean surprisal

    # 3) Convert surprisal in bits to unnormalized "LM prior" weights
    #    If surprisal_bits ≈ -log2 P_LM(u), then P_LM(u) ∝ 2^{-surprisal_bits}
    #    We additionally raise to beta: 2^{-beta * surprisal_bits}
    weights = {
        order_key: 2.0 ** (-beta * s)
        for order_key, s in mean_surp.items()
    }

    # 4) Collect weights in the flat 0..14 order and normalize
    categories_in_order = [flat_to_categories[str(i)] for i in range(len(flat_to_categories))]

    lm_prior = np.array(
        [weights[cat] for cat in categories_in_order],
        dtype=float,
    )

    lm_prior = lm_prior / lm_prior.sum()

    # 5) Return as jax array
    return jnp.asarray(lm_prior, dtype=jnp.float32)




In [3]:
# Example usage:
csv_path = "./LM_adjective_sequences_with_surprisal.csv"
lm_prior = build_lm_prior_from_csv(csv_path, beta=1.0)
print(lm_prior, lm_prior.shape)  # (15,)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.92 GB

[0.1669514  0.16733019 0.12160929 0.11005973 0.09253279 0.07532827
 0.02494562 0.03780574 0.05690099 0.02470998 0.02651604 0.01232579
 0.03122547 0.0363892  0.01536951] (15,)
