# Multilingual Semantics Probe

## Step 1: Corpus Generation

In [131]:
from __future__ import annotations

import itertools
import json
from dataclasses import dataclass
from typing import Dict, List

import pandas as pd
import os

In [132]:
STIMULI_DIR = "./stimuli"

if not os.path.exists(STIMULI_DIR):
    os.mkdir(STIMULI_DIR)

In [133]:
# --- English lexicon ---
EN_SUBJECTS = [
    "shark",
    "robot",
    "chef",
    "dog",
]

EN_OBJECTS = [
    "pirate",
    "student",
    "doctor",
    "tourist",
]

# Use correct simple past forms
EN_VERBS_PAST = [
    "ate",
    "helped",
    "pushed",
    "chased",
]

# --- Mandarin lexicon ---
# Bare nouns only (no quantifiers inside)
ZH_SUBJECTS = [
    "鲨鱼",
    "机器人",
    "厨师",
    "狗",
]

ZH_OBJECTS = [
    "海盗",
    "学生",
    "医生",
    "游客",
]

# Verb stems compatible with 了
ZH_VERBS = [
    "吃",
    "帮助",
    "推",
    "追",
]

# Optional classifier map (defaults to 个)
ZH_CLASSIFIER: Dict[str, str] = {
    "鲨鱼": "只",
    "狗": "只",
    "机器人": "个",
    "厨师": "个",
    "海盗": "个",
    "学生": "个",
    "医生": "个",
    "游客": "个",
}

In [134]:
EN_TEMPLATES = [
    # Classic ambiguous English form
    "A {subj} {verb_past} every {obj}.",
]

ZH_TEMPLATES = [
    # Canonical Mandarin surface-scope reading
    "有一{cl}{subj}{verb}了每个{obj}。",
]

In [135]:
@dataclass(frozen=True)
class Stimulus:
    language: str
    template_id: str
    subj: str
    obj: str
    verb: str
    sentence: str


def get_classifier(noun: str, cl_map: Dict[str, str]) -> str:
    return cl_map.get(noun, "个")


def generate_english(
    subjects: List[str],
    objects: List[str],
    verbs_past: List[str],
) -> List[Stimulus]:
    out: List[Stimulus] = []
    for tid, tmpl in enumerate(EN_TEMPLATES):
        for subj, obj, verb in itertools.product(subjects, objects, verbs_past):
            out.append(
                Stimulus(
                    language="en",
                    template_id=f"en_{tid}",
                    subj=subj,
                    obj=obj,
                    verb=verb,
                    sentence=tmpl.format(
                        subj=subj,
                        obj=obj,
                        verb_past=verb,
                    ),
                )
            )
    return out


def generate_mandarin(
    subjects: List[str],
    objects: List[str],
    verbs: List[str],
    cl_map: Dict[str, str],
) -> List[Stimulus]:
    out: List[Stimulus] = []
    for tid, tmpl in enumerate(ZH_TEMPLATES):
        for subj, obj, verb in itertools.product(subjects, objects, verbs):
            cl = get_classifier(subj, cl_map)
            out.append(
                Stimulus(
                    language="zh",
                    template_id=f"zh_{tid}",
                    subj=subj,
                    obj=obj,
                    verb=verb,
                    sentence=tmpl.format(
                        cl=cl,
                        subj=subj,
                        obj=obj,
                        verb=verb,
                    ),
                )
            )
    return out

In [136]:
stimuli = []
stimuli += generate_english(EN_SUBJECTS, EN_OBJECTS, EN_VERBS_PAST)
stimuli += generate_mandarin(ZH_SUBJECTS, ZH_OBJECTS, ZH_VERBS, ZH_CLASSIFIER)

df = pd.DataFrame([s.__dict__ for s in stimuli])

# Stable IDs for downstream scoring
df.insert(
    0,
    "stimulus_id",
    [
        f"{row.language}-{row.template_id}-{row.Index:06d}"
        for row in df.itertuples()
    ],
)

In [137]:
print("Total stimuli:", len(df))
print(df["language"].value_counts())

display(
    df[df["language"] == "en"][["stimulus_id", "sentence"]].sample(
        min(5, (df["language"] == "en").sum()),
        random_state=0,
    )
)

display(
    df[df["language"] == "zh"][["stimulus_id", "sentence"]].sample(
        min(5, (df["language"] == "zh").sum()),
        random_state=0,
    )
)

Total stimuli: 128
language
en    64
zh    64
Name: count, dtype: int64


Unnamed: 0,stimulus_id,sentence
45,en-en_0-000045,A chef helped every tourist.
29,en-en_0-000029,A robot helped every tourist.
43,en-en_0-000043,A chef chased every doctor.
61,en-en_0-000061,A dog helped every tourist.
34,en-en_0-000034,A chef pushed every pirate.


Unnamed: 0,stimulus_id,sentence
109,zh-zh_0-000109,有一个厨师帮助了每个游客。
93,zh-zh_0-000093,有一个机器人帮助了每个游客。
107,zh-zh_0-000107,有一个厨师追了每个医生。
125,zh-zh_0-000125,有一只狗帮助了每个游客。
98,zh-zh_0-000098,有一个厨师推了每个海盗。


In [138]:
# Serialize
df.to_csv(os.path.join(STIMULI_DIR,"stimuli.csv"), index=False)

with open(os.path.join(STIMULI_DIR, "stimuli.jsonl"), "w", encoding="utf-8") as f:
    for row in df.to_dict(orient="records"):
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print("Wrote stimuli.csv and stimuli.jsonl")

Wrote stimuli.csv and stimuli.jsonl


### Add Natural Language Continuations

In [139]:
EN_CONTINUATIONS = {
    "surface": " There was only one {subj}.",
    "inverse": " There were many {subj}.",
}

# Mandarin: keep equally short.
# Note: plural is usually implicit; "很多" is a decent lexical cue.
ZH_CONTINUATIONS = {
    "surface": " 只有一{cl}{subj}。",
    "inverse": " 有很多{cl}{subj}。",
}

In [140]:
def add_continuations(df: pd.DataFrame) -> pd.DataFrame:
    rows = []
    for r in df.itertuples(index=False):
        base = r._asdict()

        if base["language"] == "en":
            # naive pluralization: add "s"
            # If you care about irregular plurals later, add a map.
            subj_plural = base["subj"] + "s"
            cont_map = {
                "surface": EN_CONTINUATIONS["surface"].format(subj=base["subj"]),
                "inverse": EN_CONTINUATIONS["inverse"].format(subj=subj_plural),
            }

        elif base["language"] == "zh":
            cl = ZH_CLASSIFIER.get(base["subj"], "个")
            cont_map = {
                "surface": ZH_CONTINUATIONS["surface"].format(cl=cl, subj=base["subj"]),
                "inverse": ZH_CONTINUATIONS["inverse"].format(cl=cl, subj=base["subj"]),
            }
        else:
            raise ValueError(f"Unknown language: {base['language']}")

        for cont_type, cont_text in cont_map.items():
            ex = dict(base)
            ex["continuation_type"] = cont_type            # "surface" or "inverse"
            ex["continuation_text"] = cont_text            # the thing you'll score
            ex["full_text"] = base["sentence"] + cont_text # convenient for debugging
            rows.append(ex)

    return pd.DataFrame(rows)



In [141]:
df_cont = add_continuations(df)

if "concept_id" not in df_cont.columns:
    concept_series = df_cont["subj"] + "|" + df_cont["obj"] + "|" + df_cont["verb"]
    df_cont.insert(1, "concept_id", concept_series)

In [142]:
df_cont.to_csv(os.path.join(STIMULI_DIR, "stimuli_with_continuations.csv"), index=False)

with open(os.path.join(STIMULI_DIR, "stimuli_with_continuations.jsonl"), "w", encoding="utf-8") as f:
    for row in df_cont.to_dict(orient="records"):
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print("Wrote stimuli_with_continuations.csv and stimuli_with_continuations.jsonl")

Wrote stimuli_with_continuations.csv and stimuli_with_continuations.jsonl


## Step 2: Evaluate Log Probs of Continuations

In [6]:
from collections import defaultdict

import json
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


### Load in Continuations

In [7]:
JSONL_PATH = "stimuli/stimuli_with_continuations.jsonl"

def read_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)

In [8]:
rows = list(read_jsonl(JSONL_PATH))
df = pd.DataFrame(rows)

### Load Model and Tokenizer

In [9]:
MODEL_NAME = "gpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if (DEVICE == "cuda") else torch.float32

In [10]:
BATCH_SIZE = 16

In [11]:
OUT_CSV = "logprob_surface_vs_inverse.csv"

In [12]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [13]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE
).to(DEVICE)

model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

### Calculate Prefix Log Prob
1) Tokenize Prompt and Continuation

2) Calculate Log Probs for prompt + continuation

3) Calculate continuation log probs: [a] get log prob of each next token from [0,T-1]; [b] sum log probs
- Given a prompt of length T, log_prob[:, :T-1, :] gives the probability of the first [0,T-1] tokens

```

a shark ate every pirate
                    ^ stop here; we already know the probability of this word
```

```python
for b in range(B):
    for t in range(T):
        target_log_probs[b,t] = shifted_log_probs[b, t, target_tokens[b,t]]
```


In [28]:
def get_log_probs(prompts: list[str], continuations: list[str]):
    # 1) Tokenize Prompt and Continuation
    enc_base_prompts = tokenizer(
        prompts,
        return_tensors="pt",

        # Handle different length prompts
        padding=True,
        truncation=True,

        # Avoid [EOS]/[BOS] from being inserted
        add_special_tokens=False
    )

    full_prompts = [p + c for p, c in zip(prompts, continuations)]
    enc_full_prompts = tokenizer(
        full_prompts,
        return_tensors="pt",

        # Handle different length prompts
        padding=True,
        truncation=True,

        # Avoid [EOS]/[BOS] from being inserted
        add_special_tokens=False
    )

    input_ids = enc_full_prompts["input_ids"].to(DEVICE)
    attention_mask = enc_full_prompts["attention_mask"].to(DEVICE)

    # 2) Calculate logProbs for prompt + continuation
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = out.logits  # [B, T, V]
    log_probs = torch.log_softmax(logits, dim=-1)  # [B, T, V]

    # 3) Calculate continuation log probs: [a] get log prob of each next token from [0,T-1]; [b] sum log probs
    target_tokens = input_ids[:, 1:]  # [B, T-1]
    shifted_log_probs = log_probs[:, :-1, :]  # [B, T-1, V]

    # Select the logProb for the selected token
    target_log_probs = torch.gather(
        input=shifted_log_probs,
        dim=-1,  # Select the log_prob for the selected prompt token in vocab
        index=target_tokens.unsqueeze(-1)  # [B, T-1, 1]
    ).squeeze(-1)  # [B, T-1]

    base_prompt_lens = enc_base_prompts["attention_mask"].sum(dim=1).tolist()
    full_prompt_lens = enc_full_prompts["attention_mask"].sum(dim=1).tolist()

    # The logProbs for the continuation live at (inclusive)
    # logits[m-1:n-2] -> logits[m-1:L-2] -> target_log_probs[m-1:L-1] since we already removed one token in the shift
    B, _ = target_log_probs.shape

    cont_log_probs_list = []

    for b in range(B):
        base_prompt_length = base_prompt_lens[b]
        full_prompt_length = full_prompt_lens[b]

        cont_log_probs = target_log_probs[b,
                                          base_prompt_length-1:full_prompt_length-1]
        cont_log_probs_sum = cont_log_probs.sum().item()
        cont_log_probs_mean = cont_log_probs.mean().item()
        n_cont_tokens = full_prompt_length - base_prompt_length  # Sanity Check

        cont_log_probs_list.append(
            {"cont_log_probs_sum": cont_log_probs_sum,
                "cont_log_probs_mean": cont_log_probs_mean, 
                "n_cont_tokens": n_cont_tokens
            }
        )
        # debugging
        # print(base_prompt_length, full_prompt_length)
        # print(attention_mask[b].shape, attention_mask[b])
        # print(target_log_probs[b].shape, target_log_probs[b])
        # print(cont_log_probs.shape,cont_log_probs)
        # print("="*10)
    return cont_log_probs_list

In [32]:
get_log_probs(
    prompts=["a shark ate every pirate", "a shark ate every pirate","a shark ate every pirate"],
    continuations=["there is exactly one shark", "there are many sharks","."]
)

[{'cont_log_probs_sum': -40.21692657470703,
  'cont_log_probs_mean': -5.745275020599365,
  'n_cont_tokens': 7},
 {'cont_log_probs_sum': -38.0067253112793,
  'cont_log_probs_mean': -6.33445405960083,
  'n_cont_tokens': 6},
 {'cont_log_probs_sum': -3.460556745529175,
  'cont_log_probs_mean': -3.460556745529175,
  'n_cont_tokens': 1}]