# 03 — Self-Model Orthogonality Study v2: Factorial Framing Study

A systematic investigation of self-referential "I" token representations across
many prompt framing conditions using a **factorial design**.

## Framing Types

| Type | Description |
|------|-------------|
| **Identity** | "You are {persona}. Introduce yourself..." |
| **Behavior** | "Respond as {persona} would. Introduce yourself..." |
| **Bare** | No persona prefix — just the question |
| **Negation** | "You are not {persona}. Introduce yourself..." |
| **De-roling** | Explicit instruction to drop any persona |
| **Authenticity** | "Be genuine / honest / authentic" prefixes |
| **Quoted speech** | Ask for a character (human, AI, or generic) to say "I..." |

## Design

- Each condition uses 12 prompt variants to reduce prompt-specific noise.
- Multiple personas per framing type where applicable.
- **Multi-I extraction**: we extract representations for ALL "I" tokens in
  each generated response (not just the first), enabling within-response
  consistency analysis.

In [None]:
import matplotlib
matplotlib.use('Agg')

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from collections import defaultdict
import json
import datetime
import warnings
warnings.filterwarnings("ignore")

# ── Load model ──────────────────────────────────────────────────────────────────
# Hugging Face hub id of the instruction-tuned checkpoint used throughout.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.3"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Weights are loaded in float16; output_hidden_states=True is set at load time
# so forward passes expose per-layer hidden states.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    output_hidden_states=True,
)
# Inference only: switch off training-mode behaviour.
model.eval()

# When the tokenizer has no pad token, reuse EOS for padding and mirror that
# choice in the model config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

print(f"Loaded {MODEL_NAME}")
print(f"Layers: {model.config.num_hidden_layers}, Hidden dim: {model.config.hidden_size}")

def make_instruct_prompt(user_msg: str) -> str:
    """Return *user_msg* enclosed in the ``[INST] ... [/INST]`` instruct markers.

    A single space separates the message from each marker.
    """
    return "[INST] {} [/INST]".format(user_msg)

# Layer count reported by the model config. The extraction code indexes hidden
# states 0..N_LAYERS inclusive, i.e. N_LAYERS + 1 entries per token.
N_LAYERS = model.config.num_hidden_layers  # 32 for Mistral-7B

## Condition Definitions

We build the CONDITIONS dictionary programmatically from:
- **Personas** (assistant, pirate, oracle, yourself, wizard, robot)
- **12 question templates** for self-introduction
- **7 framing types**: identity, behavior, bare, negation, de-roling, authenticity, quoted speech

Quoted-speech conditions use 12 named speakers per sub-type (human, AI, generic).

In [None]:
# ── Personas ────────────────────────────────────────────────────────────────────
# Short key -> description that is interpolated into the framing sentence.
PERSONAS = {
    "assistant": "a helpful AI assistant",
    "pirate": "a pirate captain",
    "oracle": "a mystical oracle",
    "yourself": "yourself",
    "wizard": "a wise old wizard",
    "robot": "a robot",
}

# ── 12 question templates ────────────────────────────────────────────────────────
# Paraphrases of the same self-introduction request.
QUESTION_TEMPLATES = [
    "Introduce yourself in a few sentences.",
    "Tell me about yourself in a few sentences.",
    "Describe who you are in a few sentences.",
    "Share a bit about yourself in a few sentences.",
    "Give me a brief self-introduction in a few sentences.",
    "Explain what you are in a few sentences.",
    "Present yourself briefly in a few sentences.",
    "Who are you? Answer in a few sentences.",
    "What are you? Answer in a few sentences.",
    "Describe your nature in a few sentences.",
    "Tell me what defines you in a few sentences.",
    "Introduce yourself to someone new in a few sentences.",
]


def _questions_with_prefix(lead):
    """Return every question template preceded by *lead* and a single space."""
    return [f"{lead} {question}" for question in QUESTION_TEMPLATES]


# Condition name -> list of prompts (one per question template).
# Insertion order matters: the dict is iterated in order wherever it is looped over.
CONDITIONS = {}

# ── Identity framing: "You are {desc}. {question}" ─────────────────────────
CONDITIONS.update({
    f"identity_{name}": _questions_with_prefix(f"You are {blurb}.")
    for name, blurb in PERSONAS.items()
})

# ── Behavior framing: "Respond as {desc} would. {question}" ────────────────
# The "yourself" persona omits the trailing "would".
CONDITIONS.update({
    f"behavior_{name}": _questions_with_prefix(
        f"Respond as {blurb}." if name == "yourself" else f"Respond as {blurb} would."
    )
    for name, blurb in PERSONAS.items()
})

# ── Bare: no persona prefix ─────────────────────────────────────────────────────
# Stored as a copy so the template list itself is never aliased.
CONDITIONS["bare"] = QUESTION_TEMPLATES.copy()

# ── Negation: "You are not {desc}. {question}" ─────────────────────────────
NEGATION_PERSONAS = {
    "assistant": "a helpful AI assistant",
    "pirate": "a pirate captain",
    "robot": "a robot",
}
CONDITIONS.update({
    f"negation_{name}": _questions_with_prefix(f"You are not {blurb}.")
    for name, blurb in NEGATION_PERSONAS.items()
})

# ── De-roling: explicit drop-persona prefixes ──────────────────────────────
# NOTE(review): the second and third prefixes end mid-sentence, yet the question
# that follows keeps its capital letter (e.g. "... Just Introduce yourself ...").
DEROL_PREFIXES = [
    "Set aside any persona or role.",
    "Without playing any character,",
    "Don't pretend to be anything. Just",
]
CONDITIONS.update({
    f"derol_{position}": _questions_with_prefix(lead)
    for position, lead in enumerate(DEROL_PREFIXES)
})

# ── Authenticity: "Be genuine" etc. ─────────────────────────────────────────
AUTH_PREFIXES = [
    "Be genuine.",
    "Respond honestly, as yourself.",
    "Respond authentically.",
]
CONDITIONS.update({
    f"authentic_{position}": _questions_with_prefix(lead)
    for position, lead in enumerate(AUTH_PREFIXES)
})

# ── Quoted speech ───────────────────────────────────────────────────────────────
# Three speaker pools of 12 entries each; every entry becomes one quoted-speech
# prompt for its sub-condition.

# Named people and human archetypes.
HUMAN_SPEAKERS = [
    "Mary", "an old sailor", "John", "Sarah", "a detective", "a child",
    "Professor Williams", "a warrior", "Tom", "Maria", "a farmer", "Elena",
]
# Speakers explicitly described as AI or automated systems.
AI_SPEAKERS = [
    "an AI assistant", "a chatbot", "a virtual assistant",
    "a language model", "a digital helper", "an AI system",
    "an ML model", "a neural network", "an artificial intelligence",
    "a bot", "an automated assistant", "an AI companion",
]
# Unspecified speakers whose description names neither a human nor an AI.
GENERIC_SPEAKERS = [
    "a stranger", "someone in a crowd", "a voice from the shadows",
    "a figure at the door", "an unknown caller", "a mysterious visitor",
    "a passerby", "an anonymous sender", "someone behind them",
    "a newcomer", "a distant voice", "a traveler",
]

def make_quoted_prompts(speakers):
    """Build one quoted-speech writing prompt per entry of *speakers*.

    Each prompt asks for a short paragraph in which the given speaker says "I"
    while introducing themselves. Prompts are returned as a list in the same
    order as *speakers*.
    """
    template = (
        'Write a short paragraph where {} says "I" while '
        'introducing themselves. Use first person for the character.'
    )
    return [template.format(who) for who in speakers]

# Register the three quoted-speech sub-conditions, in this order.
for label, speaker_pool in (
    ("quoted_human", HUMAN_SPEAKERS),
    ("quoted_ai", AI_SPEAKERS),
    ("quoted_generic", GENERIC_SPEAKERS),
):
    CONDITIONS[label] = make_quoted_prompts(speaker_pool)

# ── Summary ─────────────────────────────────────────────────────────────────────
# One row per condition with its prompt count, followed by the grand total.
prompt_counts = {name: len(prompt_list) for name, prompt_list in CONDITIONS.items()}
print(f"Total conditions: {len(CONDITIONS)}")
print(f"{'Condition':<25} {'Prompts':>7}")
print("-" * 34)
for name, count in prompt_counts.items():
    print(f"{name:<25} {count:>7}")
print(f"{'TOTAL':<25} {sum(prompt_counts.values()):>7}")

## Multi-I Extraction Function

Unlike the basic extractor that returns a single vector for the **first** "I" token,
this version finds **all** occurrences of the "I" token in the generated output and
returns their hidden-state representations along with positional metadata.

In [None]:
# Vocabulary id used to locate "I" in generated output; only the first id of
# the encoding of the bare string "I" is kept.
# NOTE(review): with some tokenizers "I" maps to different ids depending on
# what precedes it (space, quote, newline), in which case occurrences with a
# different id would not be matched — TODO confirm for this tokenizer.
I_TOKEN_ID = tokenizer.encode("I", add_special_tokens=False)[0]

@torch.no_grad()
def get_all_i_token_representations(prompt_text: str, max_new_tokens: int = 150):
    """Sample a response to *prompt_text* and collect hidden states at every "I".

    The prompt is wrapped with ``make_instruct_prompt``, a continuation is
    sampled (temperature 0.7, top-p 0.9), and the complete sequence (prompt +
    continuation) is run through the model once more to read hidden states.
    Only tokens in the generated part whose id equals ``I_TOKEN_ID`` are kept.

    Parameters
    ----------
    prompt_text : str
        User message, without instruct markers.
    max_new_tokens : int
        Upper bound on the number of generated tokens.

    Returns
    -------
    list[dict]
        One dict per matched token, in order of appearance, with keys:
        - layer_reps : dict mapping layer_idx (0..N_LAYERS) -> numpy array
        - position_absolute : int (index within the full sequence)
        - position_relative : float (position_absolute / full sequence length)
        - position_index : int (0-based rank among the matched tokens)
        - token_id : int
    str
        The decoded generated text (special tokens skipped).
    """
    encoded = tokenizer(
        make_instruct_prompt(prompt_text), return_tensors="pt"
    ).to(model.device)
    n_prompt_tokens = encoded["input_ids"].shape[1]

    generation = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        return_dict_in_generate=True,
    )
    full_ids = generation.sequences[0]
    continuation_ids = full_ids[n_prompt_tokens:]
    gen_text = tokenizer.decode(continuation_ids, skip_special_tokens=True)

    # Second pass over prompt + continuation to obtain hidden states; gradient
    # tracking is already disabled by the decorator.
    per_layer_states = model(
        full_ids.unsqueeze(0), output_hidden_states=True
    ).hidden_states

    # Absolute positions of every matching token within the generated part.
    match_positions = [
        n_prompt_tokens + offset
        for offset, token in enumerate(continuation_ids.tolist())
        if token == I_TOKEN_ID
    ]
    sequence_length = full_ids.shape[0]

    results = [
        {
            "layer_reps": {
                layer_idx: per_layer_states[layer_idx][0, position, :].cpu().float().numpy()
                for layer_idx in range(N_LAYERS + 1)
            },
            "position_absolute": position,
            "position_relative": position / sequence_length,
            "position_index": rank,
            "token_id": I_TOKEN_ID,
        }
        for rank, position in enumerate(match_positions)
    ]
    return results, gen_text

## Run Extractions

We sample `SAMPLES_PER_CONDITION` prompts from each condition and extract
all I-token representations. This is the most time-consuming cell.

In [None]:
SAMPLES_PER_CONDITION = 10
# Seed both random sources: NumPy decides which prompts are drawn, while torch
# drives token sampling inside model.generate (do_sample=True). Seeding only
# NumPy leaves the generated text different on every run.
# NOTE(review): exact repeatability may still depend on hardware/kernels.
np.random.seed(42)
torch.manual_seed(42)

all_results = {}   # cond_key -> list (one item per sampled prompt) of I-token dict lists
all_texts = {}     # cond_key -> list of generated texts, aligned with all_results

for cond_key, prompts in CONDITIONS.items():
    # Draw prompt indices without replacement unless the condition has fewer
    # prompts than requested samples.
    idxs = np.random.choice(len(prompts), size=SAMPLES_PER_CONDITION,
                            replace=len(prompts) < SAMPLES_PER_CONDITION)
    cond_results = []
    cond_texts = []
    total_i_tokens = 0
    for idx in idxs:
        prompt = prompts[idx]
        i_reps, gen_text = get_all_i_token_representations(prompt)
        cond_results.append(i_reps)
        cond_texts.append(gen_text)
        total_i_tokens += len(i_reps)
    all_results[cond_key] = cond_results
    all_texts[cond_key] = cond_texts
    print(f"{cond_key:<25}  samples={SAMPLES_PER_CONDITION}  total_I_tokens={total_i_tokens}")

print(f"\nDone. {len(all_results)} conditions processed.")

## Compute Reference Directions

We define two reference axes and one direction derived from them:
1. **Self-model direction (SM)**: separates *bare* self-reference from *quoted human* speech.
2. **Assistant axis (AA)**: separates *identity_assistant* from *identity_pirate*.
3. **Orthogonal component**: the part of SM that is perpendicular to AA — this is
   the "pure self-model" signal that cannot be explained by persona identity.

In [None]:
# Layer indices (keys of each token's layer_reps dict) examined by the analysis.
analysis_layers = [0, 8, 16, 24, 32]

def get_all_activations_for_condition(cond_key, layer_idx):
    """Gather every I-token vector recorded for *cond_key* at *layer_idx*.

    Vectors from all sampled responses of the condition are pooled in order.
    Tokens that have no entry for *layer_idx* are skipped. Returns a 2-D array
    with one row per token, or an empty ``(0, hidden_size)`` array when no
    vector is available.
    """
    pooled = [
        token_entry["layer_reps"][layer_idx]
        for response_tokens in all_results[cond_key]
        for token_entry in response_tokens
        if layer_idx in token_entry["layer_reps"]
    ]
    if not pooled:
        return np.zeros((0, model.config.hidden_size))
    return np.array(pooled)

# Compute directions at each analysis layer
# NOTE(review): every vector belongs to the same token id, so if the layer-0
# hidden state is the plain token embedding, the layer-0 class means coincide
# and the directions there are numerical noise — TODO confirm.
directions = {}  # layer_idx -> dict with sm_dir, aa_dir, ortho_dir

for layer_idx in analysis_layers:
    bare_vecs = get_all_activations_for_condition("bare", layer_idx)
    quoted_human_vecs = get_all_activations_for_condition("quoted_human", layer_idx)
    asst_vecs = get_all_activations_for_condition("identity_assistant", layer_idx)
    pirate_vecs = get_all_activations_for_condition("identity_pirate", layer_idx)

    # A reference condition without any I-token vector would make its mean NaN
    # and silently poison every later projection (warnings are suppressed
    # file-wide), so stop with an explicit error instead.
    reference_sets = {
        "bare": bare_vecs,
        "quoted_human": quoted_human_vecs,
        "identity_assistant": asst_vecs,
        "identity_pirate": pirate_vecs,
    }
    missing = [name for name, vecs in reference_sets.items() if len(vecs) == 0]
    if missing:
        raise ValueError(
            f"Layer {layer_idx}: no 'I' token activations for reference "
            f"condition(s) {missing}; cannot compute direction vectors."
        )

    # Self-model direction: bare mean - quoted_human mean (unit length; the
    # epsilon guards against division by zero).
    sm_dir = bare_vecs.mean(axis=0) - quoted_human_vecs.mean(axis=0)
    sm_dir = sm_dir / (np.linalg.norm(sm_dir) + 1e-10)

    # Assistant axis: identity_assistant mean - identity_pirate mean
    aa_dir = asst_vecs.mean(axis=0) - pirate_vecs.mean(axis=0)
    aa_dir = aa_dir / (np.linalg.norm(aa_dir) + 1e-10)

    # Orthogonal component of SM w.r.t. AA: remove the projection of SM onto
    # AA, then renormalise the remainder.
    cos_sm_aa = np.dot(sm_dir, aa_dir)
    ortho_dir = sm_dir - cos_sm_aa * aa_dir
    ortho_dir = ortho_dir / (np.linalg.norm(ortho_dir) + 1e-10)

    directions[layer_idx] = {
        "sm_dir": sm_dir,
        "aa_dir": aa_dir,
        "ortho_dir": ortho_dir,
    }

    print(f"Layer {layer_idx:>2}: cos(SM, AA) = {cos_sm_aa:+.4f}  "
          f"|bare|={len(bare_vecs)} |quoted_human|={len(quoted_human_vecs)} "
          f"|asst|={len(asst_vecs)} |pirate|={len(pirate_vecs)}")

## Project All Conditions

For every condition we project onto the **assistant axis** and the
**orthogonal (pure self-model)** axis, then print a summary table.

In [None]:
# projections[layer_idx][cond_key] = (mean AA projection, mean ortho projection, n vectors)
projections = defaultdict(dict)

for layer_idx in analysis_layers:
    aa_axis = directions[layer_idx]["aa_dir"]
    ortho_axis = directions[layer_idx]["ortho_dir"]

    # Per-layer table header.
    print(f"\n{'='*70}")
    print(f"Layer {layer_idx}")
    print(f"{'Condition':<25} {'AA proj':>10} {'Ortho proj':>12} {'N':>5}")
    print("-" * 55)

    for cond_key in CONDITIONS:
        vecs = get_all_activations_for_condition(cond_key, layer_idx)
        n_vecs = len(vecs)
        if n_vecs:
            # Project every vector first, then average the scalar projections.
            mean_aa = (vecs @ aa_axis).mean()
            mean_ortho = (vecs @ ortho_axis).mean()
            projections[layer_idx][cond_key] = (float(mean_aa), float(mean_ortho), n_vecs)
            print(f"{cond_key:<25} {mean_aa:>+10.4f} {mean_ortho:>+12.4f} {n_vecs:>5}")
        else:
            # Conditions without any I token get a zero placeholder and no table row.
            projections[layer_idx][cond_key] = (0.0, 0.0, 0)

## Framing-Type Analysis

Group conditions by their framing type and compare the mean orthogonal
projection at the focus layer. This tells us which framing types produce
"I" tokens most similar to genuine self-reference vs. quoted speech.

In [None]:
# Layer at which the framing-level comparison is made.
focus_layer = 8

# Framing type -> condition keys belonging to it. "bare" is a single condition;
# every other framing collects the keys that start with "<framing>_".
FRAMING_GROUPS = {}
for framing_name in ("identity", "behavior", "bare", "negation", "derol", "authentic", "quoted"):
    if framing_name == "bare":
        FRAMING_GROUPS[framing_name] = ["bare"]
    else:
        FRAMING_GROUPS[framing_name] = [
            k for k in CONDITIONS if k.startswith(f"{framing_name}_")
        ]

print(f"Framing analysis at layer {focus_layer}")
print(f"{'Framing type':<15} {'Mean ortho':>12} {'Conditions':>12} {'Total N':>8}")
print("-" * 50)
for framing, cond_keys in FRAMING_GROUPS.items():
    member_stats = [projections[focus_layer][ck] for ck in cond_keys]
    # Repeat each condition's mean once per contributing token so the plain
    # average below is weighted by token count.
    replicated = [ort_p for _, ort_p, n in member_stats for _ in range(n)]
    total_n = sum(n for _, _, n in member_stats)
    mean_ort = np.mean(replicated) if replicated else 0.0
    print(f"{framing:<15} {mean_ort:>+12.4f} {len(cond_keys):>12} {total_n:>8}")

## Multi-I Consistency

For responses containing 2 or more "I" tokens, we compute the **within-response
variance** of orthogonal projections. Low variance indicates that the self-model
signal is stable across the response; high variance suggests it shifts.

In [None]:
print(f"Within-response variance of orthogonal projection (layer {focus_layer})")
print(f"{'Condition':<25} {'Mean var':>10} {'Responses>=2I':>15} {'Total I':>8}")
print("-" * 62)

d = directions[focus_layer]
for cond_key in CONDITIONS:
    responses = all_results[cond_key]
    # Only responses with at least two I tokens can have a within-response variance.
    multi_i_responses = [tokens for tokens in responses if len(tokens) >= 2]

    variances = []
    for tokens in multi_i_responses:
        projs = [
            np.dot(entry["layer_reps"][focus_layer], d["ortho_dir"])
            for entry in tokens
            if focus_layer in entry["layer_reps"]
        ]
        if len(projs) >= 2:
            variances.append(np.var(projs))

    multi_count = len(multi_i_responses)
    total_i = sum(len(tokens) for tokens in responses)
    # NaN marks conditions where no response had two or more usable I tokens.
    mean_var = np.mean(variances) if variances else float("nan")
    print(f"{cond_key:<25} {mean_var:>10.4f} {multi_count:>15} {total_i:>8}")

## Visualizations

1. **Scatter plot** — every I-token projected onto (AA, Orthogonal) axes, colored by framing type.
2. **Bar plot** — mean orthogonal projection per condition, sorted and colored by framing type.

In [None]:
# Plot colour for each framing type; the insertion order is also the legend order.
FRAMING_COLORS = dict(
    identity="red",
    behavior="orange",
    bare="green",
    negation="purple",
    derol="blue",
    authentic="cyan",
    quoted="gray",
)

def framing_type_of(cond_key):
    """Map a condition key to its framing type.

    Keys starting with one of the prefixed framing names resolve to that name,
    the exact key "bare" resolves to "bare", and anything else to "other".
    """
    prefixed = ("identity", "behavior", "negation", "derol", "authentic", "quoted")
    matched = next((ft for ft in prefixed if cond_key.startswith(ft)), None)
    if matched is not None:
        return matched
    return "bare" if cond_key == "bare" else "other"

# Scatter of every I token at the focus layer in the (assistant axis,
# orthogonal axis) plane, coloured by framing type.
d = directions[focus_layer]

fig, ax = plt.subplots(figsize=(14, 10))

# Plot individual points and collect means
framing_plotted = set()
for cond_key in CONDITIONS:
    ft = framing_type_of(cond_key)
    color = FRAMING_COLORS.get(ft, "black")
    vecs = get_all_activations_for_condition(cond_key, focus_layer)
    if len(vecs) == 0:
        continue
    aa_projs = vecs @ d["aa_dir"]
    ortho_projs = vecs @ d["ortho_dir"]
    # Label only the first condition of each framing type so the legend has
    # one entry per framing rather than one per condition.
    label = ft if ft not in framing_plotted else None
    ax.scatter(aa_projs, ortho_projs, c=color, alpha=0.3, s=15, label=label)
    framing_plotted.add(ft)
    # Mean marker
    ax.scatter(aa_projs.mean(), ortho_projs.mean(), c=color, marker="x", s=100,
               linewidths=2, zorder=5)

ax.set_xlabel("Assistant Axis projection", fontsize=12)
ax.set_ylabel("Orthogonal (pure self-model) projection", fontsize=12)
ax.set_title(f"All conditions — Layer {focus_layer}", fontsize=14)
ax.legend(loc="best", fontsize=10)
ax.grid(True, alpha=0.3)
fig.tight_layout()
# Derive the file name from focus_layer so it always matches the plotted layer
# (a hard-coded "layer8" would mislabel the figure if focus_layer changed).
scatter_path = f"v2_all_conditions_scatter_layer{focus_layer}.png"
fig.savefig(scatter_path, dpi=150)
print(f"Saved {scatter_path}")
plt.close(fig)

In [None]:
# Sorted bar plot of orthogonal projections by condition
# Conditions are ordered by their mean orthogonal projection at the focus
# layer (ascending) and coloured by framing type.
cond_keys_sorted = sorted(
    CONDITIONS.keys(),
    key=lambda ck: projections[focus_layer][ck][1],  # sort by ortho projection
)

fig, ax = plt.subplots(figsize=(16, 8))
x_pos = np.arange(len(cond_keys_sorted))
colors = [FRAMING_COLORS.get(framing_type_of(ck), "black") for ck in cond_keys_sorted]
ortho_vals = [projections[focus_layer][ck][1] for ck in cond_keys_sorted]

bars = ax.bar(x_pos, ortho_vals, color=colors, edgecolor="white", linewidth=0.5)
ax.set_xticks(x_pos)
ax.set_xticklabels(cond_keys_sorted, rotation=90, fontsize=7)
ax.set_ylabel("Mean orthogonal projection", fontsize=12)
ax.set_title(f"Orthogonal (pure self-model) projection by condition — Layer {focus_layer}", fontsize=13)
ax.axhline(0, color="black", linewidth=0.8)
ax.grid(True, axis="y", alpha=0.3)

# Legend: one patch per framing type, since bars carry no labels of their own.
from matplotlib.patches import Patch
legend_elements = [Patch(facecolor=c, label=ft) for ft, c in FRAMING_COLORS.items()]
ax.legend(handles=legend_elements, loc="upper left", fontsize=9)

fig.tight_layout()
# Derive the file name from focus_layer so it always matches the plotted layer
# (a hard-coded "layer8" would mislabel the figure if focus_layer changed).
barplot_path = f"v2_orthogonal_barplot_layer{focus_layer}.png"
fig.savefig(barplot_path, dpi=150)
print(f"Saved {barplot_path}")
plt.close(fig)