# Multilingual Semantics Probe

## Step 2: Evaluate Log Probs of Continuations

In [None]:
import numpy as np
import json
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM

### Load in Continuations

In [25]:
# Input stimuli in JSON Lines format (one JSON record per line), consumed by
# read_jsonl below. Later cells read the record fields "sentence",
# "continuation_text", "continuation_type", "stimulus_id" and "full_text".
JSONL_PATH = "stimuli/stimuli_with_continuations.jsonl"

def read_jsonl(path):
    """Lazily yield one parsed JSON value per non-blank line of a UTF-8 file.

    Args:
        path: Location of the JSON Lines file to read.

    Yields:
        The result of ``json.loads`` for each line, after stripping
        surrounding whitespace. Lines that are empty after stripping are
        skipped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        # Whitespace-only lines become "" and are filtered out here.
        yield from (json.loads(text) for text in stripped if text)

In [26]:
# Materialise the whole JSONL file in memory, then build one DataFrame row per
# record; the record keys become the column names.
rows = list(read_jsonl(JSONL_PATH))
continuation_df = pd.DataFrame(rows)

### Load Model and Tokenizer

In [None]:
# Model to score with. Exactly one MODEL_NAME assignment should be active;
# the commented-out lines are alternatives that can be swapped in.
# MODEL_NAME = "gpt2"
# MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
# MODEL_NAME = "meta-llama/Llama-3.2-1B"
# MODEL_NAME = "Qwen/Qwen3-0.6B"
MODEL_NAME = "google/gemma-3-1b-it"
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision on GPU was the earlier choice; kept for reference.
# DTYPE = torch.float16 if (DEVICE == "cuda") else torch.float32
DTYPE = torch.float32  # Full precision on every device to avoid float16 overflow

In [28]:
# Number of (prompt, continuation) pairs sent through the model per forward
# pass; used as the default and explicit b_sz of process_log_probs.
BATCH_SIZE = 16

In [29]:
# CSV output filename.
# NOTE(review): not referenced by any visible cell (results are written as
# JSONL under RESULTS_DIR) - confirm whether this is still needed.
OUT_CSV = "logprob_surface_vs_inverse.csv"

In [None]:
# Load the tokenizer that belongs to MODEL_NAME, requesting the fast
# implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
# get_log_probs tokenizes whole batches with padding=True, so a pad token is
# needed; reuse the EOS token when the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

: 

In [None]:
# Load the causal language model weights for MODEL_NAME in DTYPE and move the
# model to DEVICE.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=DTYPE
).to(DEVICE)

# Put the model in evaluation mode before scoring. As the last expression of
# the cell, this also displays the module returned by eval().
model.eval()

Fetching 2 files: 100%|██████████| 2/2 [08:14<00:00, 247.34s/it]
Loading checkpoint shards:  50%|█████     | 1/2 [00:31<00:31, 31.48s/it]

### Calculate Continuation Log Prob
1) Tokenize Prompt and Continuation

2) Calculate Log Probs for prompt + continuation

3) Calculate continuation log probs: [a] get log prob of each next token from [0,T-1]; [b] sum log probs
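- Given an input sequence of length T, `log_probs[:, :T-1, :]` holds the next-token distributions produced at positions 0..T-2, i.e. the predictions for tokens 1..T-1. The distribution at the last position predicts a token beyond the end of the input, so it is dropped.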

```

a shark ate every pirate
                    ^ stop here; we already know the probability of this word
```

```python
for b in range(B):
    for t in range(T):
        target_log_probs[b,t] = shifted_log_probs[b, t, target_tokens[b,t]]
```


In [None]:
def get_log_probs(prompts: list[str], continuations: list[str]) -> list[dict]:
    """Score each continuation given its prompt under the notebook's causal LM.

    Uses the notebook-level ``tokenizer``, ``model`` and ``DEVICE``. For every
    pair, the string ``prompt + continuation`` is run through the model and the
    log-probabilities of the tokens that follow the prompt are collected.

    Args:
        prompts: Prompt strings, one per example.
        continuations: Continuation strings, aligned with ``prompts``. They are
            concatenated to the prompt as-is (no separator is inserted).

    Returns:
        One dict per example, in input order, with keys:
            ``cont_log_probs_sum``: sum of the continuation token log-probs.
            ``cont_log_probs_mean``: mean of the continuation token log-probs
                (``nan`` when no continuation token could be scored).
            ``n_cont_tokens``: token count of the full text minus token count
                of the prompt alone.

    Raises:
        ValueError: If ``prompts`` and ``continuations`` differ in length.

    Notes:
        The number of prompt tokens is measured by tokenizing the prompt on its
        own. NOTE(review): this assumes the prompt's tokens are a prefix of the
        tokens of ``prompt + continuation``; a tokenizer that merges characters
        across the boundary would shift the split point - confirm per model.
    """
    if len(prompts) != len(continuations):
        raise ValueError(
            f"prompts ({len(prompts)}) and continuations ({len(continuations)}) "
            "must have the same length"
        )

    # 1) Tokenize Prompt and Continuation
    tokenize_kwargs = dict(
        return_tensors="pt",
        # Handle different length prompts
        padding=True,
        truncation=True,
        # Avoid [EOS]/[BOS] from being inserted
        add_special_tokens=False,
    )
    enc_base_prompts = tokenizer(prompts, **tokenize_kwargs)

    full_prompts = [p + c for p, c in zip(prompts, continuations)]
    enc_full_prompts = tokenizer(full_prompts, **tokenize_kwargs)

    input_ids = enc_full_prompts["input_ids"].to(DEVICE)
    attention_mask = enc_full_prompts["attention_mask"].to(DEVICE)

    # 2) Calculate logProbs for prompt + continuation.
    # This is pure inference: disable autograd so no graph or activations are
    # retained for the [B, T, V] tensors.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # [B, T, V]
        log_probs = torch.log_softmax(logits, dim=-1)  # [B, T, V]

        # 3a) Position t predicts token t+1, so pair log_probs[:, :-1] with
        # input_ids[:, 1:] and pick the log-prob of the token actually observed.
        target_tokens = input_ids[:, 1:]  # [B, T-1]
        target_log_probs = torch.gather(
            input=log_probs[:, :-1, :],  # [B, T-1, V]
            dim=-1,
            index=target_tokens.unsqueeze(-1),  # [B, T-1, 1]
        ).squeeze(-1)  # [B, T-1]

    base_prompt_lens = enc_base_prompts["attention_mask"].sum(dim=1).tolist()
    full_mask = enc_full_prompts["attention_mask"]
    full_prompt_lens = full_mask.sum(dim=1).tolist()

    # Index of the first real (non-pad) token in each row: 0 when the tokenizer
    # pads on the right, the number of pad tokens when it pads on the left.
    # argmax returns the first maximal index, i.e. the first 1 in the mask.
    # Deriving the offset from the mask keeps the slicing below correct for
    # either padding side.
    first_real_positions = full_mask.argmax(dim=1).tolist()

    cont_log_probs_list = []

    for b in range(len(full_prompt_lens)):
        base_prompt_length = base_prompt_lens[b]
        full_prompt_length = full_prompt_lens[b]
        offset = first_real_positions[b]

        # Continuation tokens sit at sequence positions
        # [offset + m, offset + L); after the one-step shift their log-probs
        # live at target_log_probs[offset + m - 1 : offset + L - 1].
        # Clamp the start so an empty prompt cannot produce a negative index
        # (the very first token has no context and therefore no score), and
        # clamp the stop so an empty sequence yields an empty slice.
        start = offset + max(base_prompt_length - 1, 0)
        stop = max(offset + full_prompt_length - 1, start)

        cont_log_probs = target_log_probs[b, start:stop]
        cont_log_probs_sum = cont_log_probs.sum().item()
        cont_log_probs_mean = cont_log_probs.mean().item()
        n_cont_tokens = full_prompt_length - base_prompt_length  # Sanity Check

        cont_log_probs_list.append(
            {
                "cont_log_probs_sum": cont_log_probs_sum,
                "cont_log_probs_mean": cont_log_probs_mean,
                "n_cont_tokens": n_cont_tokens,
            }
        )

    return cont_log_probs_list

In [None]:
def process_log_probs(df: pd.DataFrame, b_sz: int = BATCH_SIZE) -> pd.DataFrame:
    """Score every row of ``df`` in batches and append the score columns.

    Args:
        df: Frame with a ``"sentence"`` column (the prompt) and a
            ``"continuation_text"`` column (the continuation to score).
        b_sz: Number of rows passed to ``get_log_probs`` per call.

    Returns:
        A copy of ``df`` with a fresh 0..n-1 index and the columns produced by
        ``get_log_probs`` (``cont_log_probs_sum``, ``cont_log_probs_mean``,
        ``n_cont_tokens``) appended. The input frame is not modified.
    """
    # Fresh positional index so the column-wise concat below aligns row i of
    # the scores with row i of the data.
    df = df.reset_index(drop=True).copy()
    all_stats = []

    for start in range(0, len(df), b_sz):
        # Slice the frame that was passed in, not the notebook-level
        # continuation_df, so the function scores whatever frame it is given.
        batch = df.iloc[start:start + b_sz]

        all_stats.extend(
            get_log_probs(
                prompts=batch["sentence"].tolist(),
                continuations=batch["continuation_text"].tolist(),
            )
        )

    stats_df = pd.DataFrame(all_stats)

    return pd.concat([df, stats_df], axis=1)

In [None]:
# AI Generated pivot table magic to get the metrics
def collapse_surface_inverse(df: pd.DataFrame) -> pd.DataFrame:
    """Reshape long per-continuation scores into one wide row per stimulus.

    Rows whose ``continuation_type`` is neither ``"surface"`` nor
    ``"inverse"`` are dropped. Score and text columns are pivoted into
    ``<column>_<type>`` columns; all other columns are treated as
    per-stimulus metadata and taken via ``groupby("stimulus_id").first()``.

    Args:
        df: Long frame with ``stimulus_id``, ``continuation_type``,
            ``cont_log_probs_sum``, ``cont_log_probs_mean``,
            ``n_cont_tokens``, ``continuation_text`` and ``full_text``.

    Returns:
        One row per ``stimulus_id`` with the metadata, the pivoted columns,
        and the derived columns ``delta_sum``, ``ratio_sum``, ``delta_mean``,
        ``ratio_mean``, ``preference_mean`` and ``preference_sum``. Deltas are
        inverse minus surface; ratios are ``exp(delta)``; preferences are
        ``"inverse"`` for a positive delta, ``"surface"`` for a negative one
        and ``"tie"`` otherwise.
    """
    id_col = "stimulus_id"
    type_col = "continuation_type"
    score_fields = ["cont_log_probs_sum", "cont_log_probs_mean", "n_cont_tokens"]
    text_fields = ["continuation_text", "full_text"]

    # Keep only the two readings that are compared against each other.
    is_compared = df[type_col].isin(["surface", "inverse"])
    long_df = df.loc[is_compared].copy()

    def widen(fields):
        """Pivot `fields` into one `<field>_<type>` column per continuation type."""
        wide = long_df[[id_col, type_col] + fields].pivot(
            index=id_col, columns=type_col, values=fields
        )
        # Flatten the (field, type) column MultiIndex into plain names.
        wide.columns = [f"{field}_{kind}" for field, kind in wide.columns]
        return wide.reset_index()

    # Everything that is not type-specific is shared across a stimulus' rows.
    per_type_fields = [type_col] + score_fields + text_fields
    shared_fields = [name for name in long_df.columns if name not in per_type_fields]
    shared = long_df.groupby(id_col, as_index=False)[shared_fields].first()

    # Scores and texts are pivoted separately so the numeric columns are not
    # mixed with object-typed text during the reshape.
    scores_wide = widen(score_fields)
    texts_wide = widen(text_fields)

    out = shared.merge(scores_wide, on=id_col, how="left")
    out = out.merge(texts_wide, on=id_col, how="left")

    # Inverse-minus-surface differences and the matching probability ratios.
    for agg in ("sum", "mean"):
        delta = out[f"cont_log_probs_{agg}_inverse"] - out[f"cont_log_probs_{agg}_surface"]
        out[f"delta_{agg}"] = delta
        out[f"ratio_{agg}"] = np.exp(delta)

    # Label which reading wins under each aggregate.
    for agg in ("mean", "sum"):
        delta = out[f"delta_{agg}"]
        out[f"preference_{agg}"] = np.select(
            [delta > 0, delta < 0],
            ["inverse", "surface"],
            default="tie",
        )

    return out

### Save Outputs
- Add model name as a column

In [None]:
from pathlib import Path
# Directory that receives the scored output files. Created together with any
# missing parents; an already existing directory is not an error.
RESULTS_DIR = Path("./results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
import re
# Filesystem-friendly model label: every character that is not a word
# character, "-" or "." is replaced by "_" (so the "/" in a hub id becomes "_").
safe_model = re.sub(r"[^\w\-\.]", "_", MODEL_NAME)

In [None]:
# Score every (sentence, continuation) row, then record which model produced
# the scores as the first column.
scored_long = process_log_probs(continuation_df, b_sz=BATCH_SIZE)
scored_long.insert(0,"model", safe_model)

# One row per stimulus with surface/inverse scores side by side plus deltas.
scored_wide = collapse_surface_inverse(scored_long)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [None]:
# Display the wide results table for a visual sanity check.
scored_wide

Unnamed: 0,model,stimulus_id,concept_id,language,template_id,subj,obj,verb,sentence,cont_log_probs_sum_inverse,...,continuation_text_inverse,continuation_text_surface,full_text_inverse,full_text_surface,delta_sum,delta_mean,ratio_sum,ratio_mean,preference_mean,preference_sum
0,google_gemma-3-1b-it,en-en_0-000000,shark|pirate|ate,en,en_0,shark,pirate,ate,A shark ate every pirate.,-27.478607,...,There were many sharks.,There was only one shark.,A shark ate every pirate. There were many sharks.,A shark ate every pirate. There was only one s...,2.323645,-0.528679,10.212829,0.589383,surface,inverse
1,google_gemma-3-1b-it,en-en_0-000001,shark|pirate|helped,en,en_0,shark,pirate,helped,A shark helped every pirate.,-25.750093,...,There were many sharks.,There was only one shark.,A shark helped every pirate. There were many s...,A shark helped every pirate. There was only on...,2.552305,-0.432952,12.836661,0.648591,surface,inverse
2,google_gemma-3-1b-it,en-en_0-000002,shark|pirate|pushed,en,en_0,shark,pirate,pushed,A shark pushed every pirate.,-22.141607,...,There were many sharks.,There was only one shark.,A shark pushed every pirate. There were many s...,A shark pushed every pirate. There was only on...,6.287125,0.309801,537.605292,1.363153,inverse,inverse
3,google_gemma-3-1b-it,en-en_0-000003,shark|pirate|chased,en,en_0,shark,pirate,chased,A shark chased every pirate.,-28.556351,...,There were many sharks.,There was only one shark.,A shark chased every pirate. There were many s...,A shark chased every pirate. There was only on...,1.695395,-0.669313,5.448795,0.512060,surface,inverse
4,google_gemma-3-1b-it,en-en_0-000004,shark|student|ate,en,en_0,shark,student,ate,A shark ate every student.,-23.287512,...,There were many sharks.,There was only one shark.,A shark ate every student. There were many sha...,A shark ate every student. There was only one ...,6.519796,0.310382,678.440221,1.363947,inverse,inverse
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
123,google_gemma-3-1b-it,zh-zh_0-000123,狗|医生|追,zh,zh_0,狗,医生,追,有一只狗追了每个医生。,-27.165693,...,有很多只狗。,只有一只狗。,有一只狗追了每个医生。 有很多只狗。,有一只狗追了每个医生。 只有一只狗。,-3.175335,-0.635067,0.041780,0.529900,surface,surface
124,google_gemma-3-1b-it,zh-zh_0-000124,狗|游客|吃,zh,zh_0,狗,游客,吃,有一只狗吃了每个游客。,-29.891842,...,有很多只狗。,只有一只狗。,有一只狗吃了每个游客。 有很多只狗。,有一只狗吃了每个游客。 只有一只狗。,2.408596,0.481719,11.118340,1.618856,inverse,inverse
125,google_gemma-3-1b-it,zh-zh_0-000125,狗|游客|帮助,zh,zh_0,狗,游客,帮助,有一只狗帮助了每个游客。,-28.112892,...,有很多只狗。,只有一只狗。,有一只狗帮助了每个游客。 有很多只狗。,有一只狗帮助了每个游客。 只有一只狗。,2.223234,0.444647,9.237157,1.559939,inverse,inverse
126,google_gemma-3-1b-it,zh-zh_0-000126,狗|游客|推,zh,zh_0,狗,游客,推,有一只狗推了每个游客。,-30.290436,...,有很多只狗。,只有一只狗。,有一只狗推了每个游客。 有很多只狗。,有一只狗推了每个游客。 只有一只狗。,0.486448,0.097290,1.626529,1.102179,inverse,inverse


In [None]:
# Persist both tables as JSON Lines (one record per line) under RESULTS_DIR.
# The file names embed safe_model, so runs with different models do not
# overwrite each other.
scored_long.to_json(
    RESULTS_DIR / f"scored_long_{safe_model}.jsonl",
    orient="records",
    lines=True
)

# Wide table: one record per stimulus with surface/inverse comparison columns.
scored_wide.to_json(
    RESULTS_DIR / f"scored_wide_{safe_model}.jsonl",
    orient="records",
    lines=True
)