# Token Position Functions for Entity Binding

This notebook demonstrates the **template-aware token position system** for entity binding tasks. Token positions matter for causal interventions because they specify **WHERE** in the token sequence to intervene when testing a causal hypothesis.

## How It Works: The Mega Template

The system uses a **mega template** with unique variable names for each entity position:
- Statement entities: `{g0_e0}`, `{g0_e1}`, `{g1_e0}`, `{g1_e1}`, ...
- Question entities: `{person}`, `{query_entity}`, `{food}`, ...

Each entity position therefore has a **unique variable name** in the template, which removes ambiguity even when the same entity value appears more than once.

Example mega template:
```
"{g0_e0} loves {g0_e1} and {g1_e0} loves {g1_e1}. What does {person} love?"
```

Filled with values:
```
"Pete loves jam and Ann loves pie. What does Pete love?"
```

Even though "Pete" appears twice, the two occurrences map to different variables:
- `g0_e0` = Pete in the statement
- `person` = Pete in the question
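
To make the idea concrete, the following is a minimal sketch of how unique variable names can be turned into unambiguous character spans. It uses only the standard library; `fill_with_spans` is a hypothetical helper written for this illustration and is not part of the `tasks.entity_binding` package.

```python
from string import Formatter


def fill_with_spans(template, values):
    """Fill `template` and record the character span of each variable.

    Hypothetical helper for illustration only.
    """
    text = ""
    spans = {}
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion)
    for literal, field, _spec, _conv in Formatter().parse(template):
        text += literal
        if field is not None:
            start = len(text)
            text += values[field]
            spans[field] = (start, len(text))
    return text, spans


template = "{g0_e0} loves {g0_e1} and {g1_e0} loves {g1_e1}. What does {person} love?"
values = {"g0_e0": "Pete", "g0_e1": "jam", "g1_e0": "Ann", "g1_e1": "pie", "person": "Pete"}

text, spans = fill_with_spans(template, values)
print(text)
print("g0_e0:", spans["g0_e0"])   # statement Pete
print("person:", spans["person"])  # question Pete
```

Both variables are filled with "Pete", yet each one gets its own span because the lookup is keyed by variable name rather than by value.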

In [13]:
# Setup
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

from tasks.entity_binding.config import (
    create_sample_love_config,
    create_sample_action_config,
)
from tasks.entity_binding.causal_models import (
    create_positional_entity_causal_model,
    sample_valid_entity_binding_input,
)
from tasks.entity_binding.token_positions import (
    get_entity_token_positions,
    get_question_entity_token_positions,
    get_statement_entity_token_positions,
)
from neural.pipeline import LMPipeline



## Part 1: The Ambiguity Problem (Why We Need Structure)

We start with the problem this system solves: the same entity value can appear at more than one place in a prompt, so matching on the string alone is ambiguous.

In [14]:
# Load language model
print("Loading GPT-2 model...")
pipeline = LMPipeline("gpt2")
config = create_sample_love_config()
print("Model loaded!")

Loading GPT-2 model...
Model loaded!


In [15]:
# Create example where "Pete" appears in both statement and question
input_sample = {
    "entity_g0_e0": "Pete",
    "entity_g0_e1": "jam",
    "entity_g1_e0": "Ann",
    "entity_g1_e1": "pie",
    "query_group": 0,
    "query_indices": (0,),
    "answer_index": 1,
    "active_groups": 2,
    "entities_per_group": 2,
    "raw_input": "Pete loves jam, and Ann loves pie. What does Pete love?",
}

print("Prompt:", input_sample["raw_input"])
print("\nNote: 'Pete' appears TWICE in this prompt:")
print("  1. In the statement: 'Pete loves jam'")
print("  2. In the question: 'What does Pete love?'")
print("\nWhich 'Pete' do we want to intervene on?")

Prompt: Pete loves jam, and Ann loves pie. What does Pete love?

Note: 'Pete' appears TWICE in this prompt:
  1. In the statement: 'Pete loves jam'
  2. In the question: 'What does Pete love?'

Which 'Pete' do we want to intervene on?


In [16]:
# Show full tokenization
tokens = pipeline.load(input_sample)["input_ids"][0].tolist()
print("Full tokenization:")
print(f"Total tokens: {len(tokens)}\n")

for i, token_id in enumerate(tokens):
    token_str = pipeline.tokenizer.decode([token_id])
    # Indices below are hard-coded for GPT-2's tokenization of this prompt
    if i in [0, 1]:
        marker = " <-- FIRST Pete (statement)"
    elif i == 12:
        marker = " <-- SECOND Pete (question)"
    else:
        marker = ""
    print(f"Token {i:2d}: {repr(token_str):20s}{marker}")

Full tokenization:
Total tokens: 15

Token  0: 'P'                  <-- FIRST Pete (statement)
Token  1: 'ete'                <-- FIRST Pete (statement)
Token  2: ' loves'            
Token  3: ' jam'              
Token  4: ','                 
Token  5: ' and'              
Token  6: ' Ann'              
Token  7: ' loves'            
Token  8: ' pie'              
Token  9: '.'                 
Token 10: ' What'             
Token 11: ' does'             
Token 12: ' Pete'              <-- SECOND Pete (question)
Token 13: ' love'             
Token 14: '?'                 
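
The ambiguity can also be seen at the string level, without any tokenizer. The snippet below is a small standard-library illustration, not code from the repository: a value-based search finds two candidate locations and has no way to tell which one plays which role.

```python
import re

prompt = "Pete loves jam, and Ann loves pie. What does Pete love?"

# Searching by value returns every occurrence, with no role information
occurrences = [m.start() for m in re.finditer(r"\bPete\b", prompt)]
print(occurrences)

# str.find silently picks the first occurrence only
first_only = prompt.find("Pete")
print(first_only)
```

A value-based lookup needs an extra rule ("take the first", "take the last") that breaks as soon as the template changes; keying by template variable avoids that.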


## Part 2: The Mega Template System

Let's see how the mega template creates unique variable names for each entity position.

In [17]:
# Show the mega template for this sample
mega_template = config.build_mega_template(
    active_groups=input_sample["active_groups"],
    query_indices=input_sample["query_indices"],
    answer_index=input_sample["answer_index"],
)

print("Mega Template:")
print(f"  {mega_template}")
print("\nVariable names in template:")
print("  Statement: g0_e0 (Pete), g0_e1 (jam), g1_e0 (Ann), g1_e1 (pie)")
print("  Question: person (Pete), query_entity (Pete)")
print("\nEach position has a UNIQUE variable name!")

Mega Template:
  {g0_e0} loves {g0_e1} and {g1_e0} loves {g1_e1}. What does {person} love?

Variable names in template:
  Statement: g0_e0 (Pete), g0_e1 (jam), g1_e0 (Ann), g1_e1 (pie)
  Question: person (Pete), query_entity (Pete)

Each position has a UNIQUE variable name!


## Part 3: Getting Token Positions

Now let's use the token position functions to find where each entity is in the tokenized prompt.

In [18]:
# Get Pete from the STATEMENT (group 0, entity 0)
pete_statement = get_entity_token_positions(
    input_sample, pipeline, config, group_idx=0, entity_idx=0
)
print(f"Pete in STATEMENT (g0_e0): tokens {pete_statement}")

# Get Pete from the QUESTION (using role name)
pete_question = get_question_entity_token_positions(
    input_sample, pipeline, config, role_name="person"
)
print(f"Pete in QUESTION (person): tokens {pete_question}")

print("\n" + "=" * 70)
print("KEY INSIGHT: Same entity value, different token positions!")
print("=" * 70)
print(f"Statement Pete: tokens {pete_statement}")
print(f"Question Pete:  tokens {pete_question}")

Pete in STATEMENT (g0_e0): tokens [0, 1]
Pete in QUESTION (person): tokens []

KEY INSIGHT: Same entity value, different token positions!
Statement Pete: tokens [0, 1]
Question Pete:  tokens []
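
One plausible way to go from a template variable to token indices is to map the variable's character span onto per-token character offsets. The sketch below assumes this approach and uses a toy regex tokenizer so that it runs without a model; `toy_tokenize` and `tokens_in_span` are hypothetical names, and the actual `get_entity_token_positions` implementation may work differently. With a real Hugging Face fast tokenizer, the offsets could instead come from `return_offsets_mapping=True`.

```python
import re


def toy_tokenize(text):
    """Return (token, start, end) triples. A toy stand-in for real tokenizer offsets."""
    return [(m.group(), m.start(), m.end())
            for m in re.finditer(r"\s?\w+|[^\w\s]", text)]


def tokens_in_span(offsets, span):
    """Indices of tokens whose character range overlaps `span` = (start, end)."""
    start, end = span
    return [i for i, (_, s, e) in enumerate(offsets) if s < end and e > start]


text = "Pete loves jam and Ann loves pie. What does Pete love?"
offsets = toy_tokenize(text)

statement_span = (0, 4)    # characters filled in for {g0_e0}
question_span = (44, 48)   # characters filled in for {person}

statement_tokens = tokens_in_span(offsets, statement_span)
question_tokens = tokens_in_span(offsets, question_span)
print("statement:", statement_tokens)
print("question: ", question_tokens)
```

The toy tokenizer keeps "Pete" as a single token, whereas GPT-2 splits the sentence-initial "Pete" into two, so the indices here differ from the notebook output. The overlap test is what matters: it returns every token that covers any part of the variable's span.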


## Part 4: Visual Verification

Let's visualize which tokens are selected for each entity.

In [19]:
# Show which tokens correspond to which entities
print("Token-by-token breakdown:\n")

# Get all entity positions
entity_positions = {
    "g0_e0 (Pete)": get_entity_token_positions(input_sample, pipeline, config, 0, 0),
    "g0_e1 (jam)": get_entity_token_positions(input_sample, pipeline, config, 0, 1),
    "g1_e0 (Ann)": get_entity_token_positions(input_sample, pipeline, config, 1, 0),
    "g1_e1 (pie)": get_entity_token_positions(input_sample, pipeline, config, 1, 1),
    "question (Pete)": get_question_entity_token_positions(input_sample, pipeline, config, role_name="person"),
}

for i, token_id in enumerate(tokens):
    token_str = pipeline.tokenizer.decode([token_id])
    markers = []
    for name, positions in entity_positions.items():
        if i in positions:
            markers.append(name)
    marker_str = " <-- " + ", ".join(markers) if markers else ""
    print(f"Token {i:2d}: {repr(token_str):20s}{marker_str}")

Token-by-token breakdown:

Token  0: 'P'                  <-- g0_e0 (Pete)
Token  1: 'ete'                <-- g0_e0 (Pete)
Token  2: ' loves'            
Token  3: ' jam'               <-- g0_e1 (jam)
Token  4: ','                 
Token  5: ' and'               <-- g1_e0 (Ann)
Token  6: ' Ann'              
Token  7: ' loves'             <-- g1_e1 (pie)
Token  8: ' pie'              
Token  9: '.'                 
Token 10: ' What'             
Token 11: ' does'             
Token 12: ' Pete'             
Token 13: ' love'             
Token 14: '?'                 

Note that in this output the markers for `g1_e0` and `g1_e1` land one token early (on `' and'` and `' loves'` rather than `' Ann'` and `' pie'`), and the question lookup returns no tokens. The hand-written `raw_input` above contains a comma after "jam" that the mega template does not, so these offsets are consistent with positions computed against the template-filled text rather than the literal `raw_input`. Generating `raw_input` with `model.run_forward`, as in Part 6, should keep the prompt and the template in sync.


## Part 5: All Statement Entities

Let's find token positions for all entities in the statements.

In [20]:
# Find all entities in the prompt
print("All statement entities:\n")

for g in range(input_sample["active_groups"]):
    for e in range(input_sample["entities_per_group"]):
        entity_value = input_sample[f"entity_g{g}_e{e}"]
        role = config.entity_roles[e]
        positions = get_entity_token_positions(input_sample, pipeline, config, g, e)
        print(f"  g{g}_e{e} ({role:6}): '{entity_value:10}' -> tokens {positions}")

print("\nQuestion entity:")
pete_q = get_question_entity_token_positions(input_sample, pipeline, config, role_name="person")
print(f"  person: '{input_sample['entity_g0_e0']}' -> tokens {pete_q}")

All statement entities:

  g0_e0 (person): 'Pete      ' -> tokens [0, 1]
  g0_e1 (food  ): 'jam       ' -> tokens [3]
  g1_e0 (person): 'Ann       ' -> tokens [5]
  g1_e1 (food  ): 'pie       ' -> tokens [7]

Question entity:
  person: 'Pete' -> tokens []


## Part 6: Multi-Token Entities

The system handles multi-token entities and lets you select specific tokens within them.

In [21]:
# Create a sample with multi-token entities
model = create_positional_entity_causal_model(config)

multi_token_sample = {
    "entity_g0_e0": "Elizabeth",
    "entity_g0_e1": "strawberry jam",
    "entity_g1_e0": "Margaret",
    "entity_g1_e1": "apple pie",
    "query_group": 0,
    "query_indices": (0,),
    "answer_index": 1,
    "active_groups": 2,
    "entities_per_group": 2,
}
output = model.run_forward(multi_token_sample)
multi_token_sample["raw_input"] = output["raw_input"]

print("Multi-token example:")
print(f"Prompt: {multi_token_sample['raw_input']}\n")

# Get all tokens for 'strawberry jam'
all_jam_tokens = get_entity_token_positions(multi_token_sample, pipeline, config, 0, 1)
print(f"'strawberry jam' all tokens: {all_jam_tokens}")

# Get just the first token (token_idx=0)
first_token = get_entity_token_positions(multi_token_sample, pipeline, config, 0, 1, token_idx=0)
print(f"'strawberry jam' first token (token_idx=0): {first_token}")

# Get just the last token (token_idx=-1)
last_token = get_entity_token_positions(multi_token_sample, pipeline, config, 0, 1, token_idx=-1)
print(f"'strawberry jam' last token (token_idx=-1): {last_token}")

Multi-token example:
Prompt: Elizabeth loves strawberry jam and Margaret loves apple pie. What does Pete love?

'strawberry jam' all tokens: [2, 3]
'strawberry jam' first token (token_idx=0): [2]
'strawberry jam' last token (token_idx=-1): [3]
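
The `token_idx` behaviour shown above can be summarised with a small sketch. This is an assumed reading of the semantics based on the output (`None` returns all tokens, an integer selects one by ordinary Python indexing); `select_tokens` is a hypothetical helper, not the library function.

```python
from typing import List, Optional


def select_tokens(positions: List[int], token_idx: Optional[int] = None) -> List[int]:
    """Return all positions, or a single-element list when token_idx is given."""
    if token_idx is None:
        return list(positions)
    # Negative indices count from the end, as with any Python list
    return [positions[token_idx]]


entity_positions = [2, 3]
print(select_tokens(entity_positions))       # all tokens
print(select_tokens(entity_positions, 0))    # first token
print(select_tokens(entity_positions, -1))   # last token
```

Returning a list in every case keeps downstream code uniform: callers can iterate over the result whether they asked for one token or all of them.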


## Part 7: Action Tasks (3-Entity Groups)

The system also handles more complex templates. Here is an example with three entities per group (person, object, location).

In [22]:
# Create action configuration (person, object, location)
action_config = create_sample_action_config()
action_model = create_positional_entity_causal_model(action_config)

# Generate action example
action_sample = sample_valid_entity_binding_input(action_config)
action_output = action_model.run_forward(action_sample)
action_sample["raw_input"] = action_output["raw_input"]

print("Action task example (person, object, location):")
print(f"Prompt: {action_sample['raw_input']}")
print(f"Expected answer: {action_output['raw_output']}\n")

# Show the mega template
action_mega = action_config.build_mega_template(
    action_sample["active_groups"],
    action_sample["query_indices"],
    action_sample["answer_index"],
)
print(f"Mega template:\n  {action_mega}")

Action task example (person, object, location):
Prompt: Ann put book in the pocket, Tim put watch in the bag, and Sue put jam in the box. Who put something in the bag?
Expected answer: Tim

Mega template:
  {g0_e0} put {g0_e1} in the {g0_e2}, {g1_e0} put {g1_e1} in the {g1_e2}, and {g2_e0} put {g2_e1} in the {g2_e2}. Who put something in the {location}?
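
As an illustration of how such a template could be assembled, the sketch below builds the same strings from a per-group statement pattern and a question. It is an assumption about the construction, inferred from the printed templates; `build_template_sketch` is a hypothetical function and not the actual `build_mega_template` method.

```python
def build_template_sketch(statement, n_groups, n_entities, question):
    """Assemble a mega-template-like string with unique g{group}_e{entity} slots.

    `statement` uses positional slots ({0}, {1}, ...) for the entities of one group.
    """
    clauses = []
    for g in range(n_groups):
        # Each slot becomes a literal placeholder such as "{g1_e0}"
        slots = ["{" + f"g{g}_e{e}" + "}" for e in range(n_entities)]
        clauses.append(statement.format(*slots))
    if len(clauses) <= 2:
        joined = " and ".join(clauses)
    else:
        joined = ", ".join(clauses[:-1]) + ", and " + clauses[-1]
    return f"{joined}. {question}"


action_template = build_template_sketch(
    "{0} put {1} in the {2}", 3, 3, "Who put something in the {location}?"
)
love_template = build_template_sketch(
    "{0} loves {1}", 2, 2, "What does {person} love?"
)
print(action_template)
print(love_template)
```

Because the group and entity indices are baked into the slot names, adding a third group or a third entity never creates a name collision.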


## Part 8: Using Role Names

You can also use role names instead of entity indices for more readable code.

In [None]:
query_group = action_sample["query_group"]
# Statement entity, looked up by role_name instead of entity_idx
person_tokens = get_statement_entity_token_positions(
    action_sample, pipeline, action_config, 
    group_idx=query_group, role_name="person"
)
print(f"Person in statement (using role_name): {person_tokens}")

# Same thing using entity_idx
person_tokens_idx = get_entity_token_positions(
    action_sample, pipeline, action_config,
    group_idx=query_group, entity_idx=0  # person is entity 0
)
print(f"Person in statement (using entity_idx=0): {person_tokens_idx}")

# For question entities - role_name is often more natural
query_indices = action_sample["query_indices"]
for q_idx in query_indices:
    role = action_config.entity_roles[q_idx]
    try:
        q_tokens = get_question_entity_token_positions(
            action_sample, pipeline, action_config, role_name=role
        )
        print(f"\n{role.capitalize()} in question: {q_tokens}")
    except ValueError:
        print(f"\n{role.capitalize()} not in question (expected for some query patterns)")

Person in statement (using role_name): [7]
Person in statement (using entity_idx=0): [7]

Location in question: [27]
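
The equivalence between `role_name` and `entity_idx` comes down to a lookup in the configured role list. The sketch below assumes roles are stored in entity order, as the `entity_roles[q_idx]` indexing above suggests; `role_to_index` and the literal role list are hypothetical and only illustrate the mapping.

```python
def role_to_index(entity_roles, role_name):
    """Translate a role name into its entity index within a group."""
    try:
        return entity_roles.index(role_name)
    except ValueError:
        raise ValueError(
            f"Unknown role {role_name!r}; expected one of {entity_roles}"
        ) from None


# Assumed role order for the action task (person, object, location)
entity_roles = ["person", "object", "location"]

print(role_to_index(entity_roles, "person"))
print(role_to_index(entity_roles, "location"))
```

Raising a `ValueError` with the list of valid roles makes a misspelled role name fail loudly instead of silently selecting the wrong entity.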


## Part 9: Multiple Random Examples

Let's see the system work across different prompts.

In [None]:
print("Generating 5 random examples:\n")
print("=" * 80)

for i in range(5):
    # Generate new sample
    sample = sample_valid_entity_binding_input(config)
    output = model.run_forward(sample)
    sample["raw_input"] = output["raw_input"]

    print(f"\nExample {i + 1}:")
    print(f"Prompt: {sample['raw_input']}")
    print(f"Expected answer: {output['raw_output']}")

    # Get query entity from statement and question
    query_group = sample["query_group"]
    query_idx = sample["query_indices"][0]
    query_role = config.entity_roles[query_idx]

    statement_tokens = get_entity_token_positions(
        sample, pipeline, config, query_group, query_idx
    )
    print(f"  Query entity in statement (g{query_group}_e{query_idx}): tokens {statement_tokens}")

    try:
        question_tokens = get_question_entity_token_positions(
            sample, pipeline, config, role_name=query_role
        )
        print(f"  Query entity in question ({query_role}): tokens {question_tokens}")
    except ValueError:
        print("  Query entity not in question")

    print("-" * 80)

## Summary: Template-Aware Token Positions

### Key Concepts:

1. **Mega Template**: Combines statement and question templates with unique variable names
   - Statement entities: `g0_e0`, `g0_e1`, `g1_e0`, `g1_e1`, ...
   - Question entities: `person`, `food`, `query_entity`, ...

2. **No Ambiguity**: Each entity position has a unique variable name, even when the same value appears multiple times

3. **Flexible Indexing**: Use `token_idx` to select specific tokens within multi-token entities

### API Functions:

```python
# Statement entities (by group and entity index)
get_entity_token_positions(sample, pipeline, config, group_idx, entity_idx, token_idx=None)

# Statement entities (by role name)
get_statement_entity_token_positions(sample, pipeline, config, group_idx, role_name="person")

# Question entities (by role name or entity index)
get_question_entity_token_positions(sample, pipeline, config, role_name="person")
get_question_entity_token_positions(sample, pipeline, config, entity_idx=0)
```

### For Intervention Experiments:

```python
# Get entity from statement (for testing representation)
statement_tokens = get_entity_token_positions(sample, pipeline, config, group_idx=0, entity_idx=0)

# Get entity from question (for testing retrieval)
question_tokens = get_question_entity_token_positions(sample, pipeline, config, role_name="person")

# Get specific token within multi-token entity
last_token = get_entity_token_positions(sample, pipeline, config, 0, 1, token_idx=-1)
```
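
To show how such position lists might feed into an intervention, here is a minimal sketch that copies per-token vectors from a source run into a base run at the chosen positions. It uses plain Python lists in place of real activations, and `patch_positions` is a hypothetical helper; it is not the intervention API of this codebase.

```python
def patch_positions(base, source, base_positions, source_positions):
    """Copy source vectors into a copy of base at the paired token positions."""
    if len(base_positions) != len(source_positions):
        raise ValueError("position lists must have the same length")
    patched = [list(row) for row in base]  # leave the base run untouched
    for b, s in zip(base_positions, source_positions):
        patched[b] = list(source[s])
    return patched


# Toy "activations": one 2-dimensional vector per token
base = [[0.0, 0.0] for _ in range(4)]
source = [[float(i), float(i)] for i in range(1, 5)]

# Pretend the entity occupies tokens [0, 1] in base and [2, 3] in source
patched = patch_positions(base, source, base_positions=[0, 1], source_positions=[2, 3])
print(patched)
```

The position lists for the base and source prompts are computed separately, since the same entity slot can sit at different token indices in each prompt.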