# Token Position Functions for Entity Binding

This notebook demonstrates the **template-aware token position system** for entity binding tasks. Token positions are crucial for causal interventions: they tell us **WHERE** in the transformer to intervene when testing a causal hypothesis.

## Why Template-Aware Token Positions Matter

The new system solves a critical problem: **ambiguity when entities appear multiple times**.

Consider this prompt: `"Pete loves jam, and Ann loves pie. What does Pete love?"`

- "Pete" appears **twice**: once in the statement, once in the question
- The old system always found the **first occurrence**
- The new system **understands the template structure** and can distinguish:
  - Pete in the **statement** (tokens 0-1)
  - Pete in the **question** (token 12)

This notebook shows how to use the new system effectively.
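To make the ambiguity concrete at the character level, here is a plain-Python sketch (no causalab calls) that contrasts a first-occurrence search with a search restricted to one region. The region spans in `REGIONS` are hard-coded assumptions that match the parser output shown in Part 2. This only illustrates why restricting the search span matters; the actual parser uses the template structure rather than substring search.

```python
# Hypothetical region spans, matching the parser output shown in Part 2.
REGIONS = {"statement": (0, 34), "question": (35, 55)}

prompt = "Pete loves jam, and Ann loves pie. What does Pete love?"

# Naive approach: str.find always returns the FIRST occurrence.
naive_start = prompt.find("Pete")

def find_in_region(text, needle, region, regions=REGIONS):
    """Return the (start, end) char span of `needle` inside `regions[region]`, or None."""
    lo, hi = regions[region]
    start = text.find(needle, lo, hi)
    return None if start == -1 else (start, start + len(needle))

print(naive_start)                                  # 0
print(find_in_region(prompt, "Pete", "statement"))  # (0, 4)
print(find_in_region(prompt, "Pete", "question"))   # (45, 49)
```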

In [2]:
# Setup
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")

from causalab.tasks.entity_binding.config import (
    create_sample_love_config,
    create_sample_action_config,
)
from causalab.tasks.entity_binding.causal_models import (
    create_direct_causal_model,
    sample_valid_entity_binding_input,
)
from causalab.tasks.entity_binding.token_positions import (
    PromptParser,
    PromptTokenizer,
    get_entity_token_indices_structured,
)
from causalab.neural.pipeline import LMPipeline

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


## Part 1: The Ambiguity Problem (Why We Need Structure)

Let's start by demonstrating the problem that the new system solves.

In [3]:
# Load language model
print("Loading GPT-2 model...")
pipeline = LMPipeline("gpt2")
config = create_sample_love_config()
print("Model loaded!")

Loading GPT-2 model...
Model loaded!


In [4]:
# Create example where "Pete" appears in both statement and question
input_sample = {
    "entity_g0_e0": "Pete",
    "entity_g0_e1": "jam",
    "entity_g1_e0": "Ann",
    "entity_g1_e1": "pie",
    "query_group": 0,
    "query_indices": (0,),
    "answer_index": 1,
    "active_groups": 2,
    "entities_per_group": 2,
    "raw_input": "Pete loves jam, and Ann loves pie. What does Pete love?",
}

print("Prompt:", input_sample["raw_input"])
print("\nNote: 'Pete' appears TWICE in this prompt:")
print("  1. In the statement: 'Pete loves jam'")
print("  2. In the question: 'What does Pete love?'")
print("\nWhich 'Pete' do we want to intervene on?")

Prompt: Pete loves jam, and Ann loves pie. What does Pete love?

Note: 'Pete' appears TWICE in this prompt:
  1. In the statement: 'Pete loves jam'
  2. In the question: 'What does Pete love?'

Which 'Pete' do we want to intervene on?


In [5]:
# Show full tokenization
tokens = pipeline.load(input_sample)["input_ids"][0].tolist()
print("Full tokenization:")
print(f"Total tokens: {len(tokens)}\n")

for i, token_id in enumerate(tokens):
    token_str = pipeline.tokenizer.decode([token_id])
    # Positions are hard-coded from inspecting this output; Part 3 computes them
    if i in [0, 1]:
        marker = " <-- FIRST Pete (statement)"
    elif i == 12:
        marker = " <-- SECOND Pete (question)"
    else:
        marker = ""
    print(f"Token {i:2d}: {repr(token_str):20s}{marker}")

Full tokenization:
Total tokens: 15

Token  0: 'P'                  <-- FIRST Pete (statement)
Token  1: 'ete'                <-- FIRST Pete (statement)
Token  2: ' loves'            
Token  3: ' jam'              
Token  4: ','                 
Token  5: ' and'              
Token  6: ' Ann'              
Token  7: ' loves'            
Token  8: ' pie'              
Token  9: '.'                 
Token 10: ' What'             
Token 11: ' does'             
Token 12: ' Pete'              <-- SECOND Pete (question)
Token 13: ' love'             
Token 14: '?'                 
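The two occurrences tokenize differently: the sentence-initial "Pete" is split into `'P'` + `'ete'`, while `' Pete'` (with its leading space) is a single token. The sketch below uses the token strings copied from the output above to show that a single-token match misses the statement occurrence, and that even a windowed match, which finds both, cannot say which one is in the statement. `find_token_windows` is a hypothetical helper, not part of causalab.

```python
# Token strings copied from the GPT-2 tokenization printed above.
token_strs = ["P", "ete", " loves", " jam", ",", " and", " Ann", " loves",
              " pie", ".", " What", " does", " Pete", " love", "?"]

# Naive single-token match: only the question occurrence is one " Pete" token.
single_hits = [i for i, tok in enumerate(token_strs) if tok.strip() == "Pete"]
print(single_hits)  # [12]

def find_token_windows(tokens, target, max_len=4):
    """Return (start, end) token windows whose joined text, stripped, equals `target`."""
    hits = []
    for i in range(len(tokens)):
        for j in range(i + 1, min(i + max_len, len(tokens)) + 1):
            if "".join(tokens[i:j]).strip() == target:
                hits.append((i, j))
    return hits

# Windowed matching finds both occurrences, but nothing here says which one
# belongs to the statement and which to the question.
print(find_token_windows(token_strs, "Pete"))  # [(0, 2), (12, 13)]
```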


## Part 2: Template-Aware Parsing

The new system **parses the prompt structure** using template knowledge.

In [6]:
# Parse the prompt into structured segments
parser = PromptParser(config)
parsed = parser.parse_prompt(input_sample)

print("Parsed prompt structure:")
print(f"  Total segments: {len(parsed.segments)}")
print(f"  Statement region: chars {parsed.statement_region}")
print(f"  Question region: chars {parsed.question_region}")
print("\nThe parser understands the template structure!")

Parsed prompt structure:
  Total segments: 11
  Statement region: chars (0, 34)
  Question region: chars (35, 55)

The parser understands the template structure!
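When a prompt is built from templates, its regions fall out of the construction itself. A minimal sketch, assuming statements are joined with `", "`, the last one is preceded by `", and "`, the statement ends with `"."`, and one space separates it from the question (true of every love-task prompt in this notebook). `build_love_prompt` is a hypothetical helper, not causalab's parser.

```python
def build_love_prompt(pairs, query_person):
    """Fill a '<person> loves <food>' template and return the prompt plus
    (start, end) character spans of the statement and question regions.

    Hypothetical helper: the join rule (", " between clauses, ", and " before
    the last one, "." at the end) mirrors the prompts printed in this notebook.
    """
    clauses = [f"{person} loves {food}" for person, food in pairs]
    if len(clauses) == 1:
        statement = clauses[0] + "."
    else:
        statement = ", ".join(clauses[:-1]) + ", and " + clauses[-1] + "."
    question = f"What does {query_person} love?"
    question_start = len(statement) + 1  # one space separates the regions
    prompt = statement + " " + question
    return prompt, (0, len(statement)), (question_start, question_start + len(question))

prompt, stmt_region, ques_region = build_love_prompt([("Pete", "jam"), ("Ann", "pie")], "Pete")
print(prompt)                    # Pete loves jam, and Ann loves pie. What does Pete love?
print(stmt_region, ques_region)  # (0, 34) (35, 55)
```

The space between the regions belongs to neither, which matches the gap between `(0, 34)` and `(35, 55)` in the output above.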


In [7]:
# Show the parsed segments
print("\nPrompt broken down into segments:\n")
for i, seg in enumerate(parsed.segments):
    # Determine which region this segment belongs to
    in_statement = (
        parsed.statement_region
        and parsed.statement_region[0] <= seg.char_start < parsed.statement_region[1]
    )
    in_question = (
        parsed.question_region
        and parsed.question_region[0] <= seg.char_start < parsed.question_region[1]
    )
    region = "STMT" if in_statement else ("QUES" if in_question else "???")

    print(f"{i:2d}. [{region}] {seg.segment_type:15s} | '{seg.text}'")


Prompt broken down into segments:

 0. [STMT] entity          | 'Pete'
 1. [STMT] template_text   | ' loves '
 2. [STMT] entity          | 'jam'
 3. [STMT] delimiter       | ', and '
 4. [STMT] entity          | 'Ann'
 5. [STMT] template_text   | ' loves '
 6. [STMT] entity          | 'pie'
 7. [STMT] delimiter       | '.'
 8. [QUES] question_text   | 'What does '
 9. [QUES] entity          | 'Pete'
10. [QUES] question_text   | ' love?'


### Understanding the Segments

Each segment has a **type**:
- **entity**: An entity from the binding matrix
- **template_text**: Text from the statement template
- **delimiter**: Punctuation between statements (`,`, `, and`, `.`)
- **question_text**: Text from the question template

Notice that **both Pete segments are tracked** - one in the statement, one in the question!
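One way to picture the segment list is as one record per span, tagged with a type, a character range, and (for entities) the group and entity indices. The sketch below rebuilds the statement segments of a love prompt. The `Segment` fields mirror the attributes used in this notebook (`segment_type`, `text`, `char_start`, `char_end`, `group_idx`, `entity_idx`), but the class and the `segment_statement` helper are illustrative assumptions, not causalab's definitions.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class Segment:
    segment_type: str  # "entity", "template_text", "delimiter", or "question_text"
    text: str
    char_start: int
    char_end: int
    group_idx: Optional[int] = None
    entity_idx: Optional[int] = None

def segment_statement(pairs: List[Tuple[str, str]]) -> List[Segment]:
    """Split 'A loves B, and C loves D.' into typed segments with char offsets."""
    segments: List[Segment] = []
    pos = 0

    def add(seg_type, text, g=None, e=None):
        nonlocal pos
        segments.append(Segment(seg_type, text, pos, pos + len(text), g, e))
        pos += len(text)

    for g, (person, food) in enumerate(pairs):
        add("entity", person, g, 0)
        add("template_text", " loves ")
        add("entity", food, g, 1)
        if g < len(pairs) - 2:
            add("delimiter", ", ")
        elif g == len(pairs) - 2:
            add("delimiter", ", and ")
        else:
            add("delimiter", ".")
    return segments

for seg in segment_statement([("Pete", "jam"), ("Ann", "pie")]):
    print(f"{seg.segment_type:15s} {seg.char_start:2d}-{seg.char_end:2d} {seg.text!r}")
```

The delimiter spans it produces (14-20 and 33-34) agree with the parser output printed in Part 8.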

In [8]:
# Add token positions to the parsed structure
tokenizer = PromptTokenizer(pipeline)
parsed = tokenizer.tokenize_prompt(parsed)

print("After tokenization, each segment has token positions:\n")
for seg in parsed.segments:
    if seg.segment_type == "entity":
        print(
            f"Entity g{seg.group_idx}_e{seg.entity_idx} ('{seg.text}'): tokens {seg.token_positions}"
        )

After tokenization, each segment has token positions:

Entity g0_e0 ('Pete'): tokens [0, 1]
Entity g0_e1 ('jam'): tokens [3]
Entity g1_e0 ('Ann'): tokens [6]
Entity g1_e1 ('pie'): tokens [8]
Entity g0_e0 ('Pete'): tokens [12]
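Conceptually, the tokenization step only has to intersect each segment's character span with each token's character span. A self-contained sketch using the GPT-2 token strings from Part 1 and entity spans consistent with the segments above; `token_offsets` and `attach_token_positions` are hypothetical names. In Hugging Face `transformers`, fast tokenizers can return such offsets directly via `return_offsets_mapping=True`; whether `PromptTokenizer` uses them is not shown here.

```python
# GPT-2 token strings from Part 1, and entity char spans consistent with the
# segment listing above: (label, char_start, char_end).
token_strs = ["P", "ete", " loves", " jam", ",", " and", " Ann", " loves",
              " pie", ".", " What", " does", " Pete", " love", "?"]
entity_spans = [("g0_e0", 0, 4), ("g0_e1", 11, 14), ("g1_e0", 20, 23),
                ("g1_e1", 30, 33), ("g0_e0", 45, 49)]

def token_offsets(tokens):
    """Half-open char span of each token, assuming the tokens concatenate to the prompt."""
    spans, pos = [], 0
    for tok in tokens:
        spans.append((pos, pos + len(tok)))
        pos += len(tok)
    return spans

def attach_token_positions(char_spans, offsets):
    """For each (label, start, end), list the tokens whose span overlaps [start, end)."""
    return [
        (label, [i for i, (s, e) in enumerate(offsets) if s < end and e > start])
        for label, start, end in char_spans
    ]

for label, positions in attach_token_positions(entity_spans, token_offsets(token_strs)):
    print(label, positions)
```

The overlap test (rather than exact boundary matching) is what lets the 4-character "Pete" in the question map to the 5-character `' Pete'` token.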


## Part 3: Region-Aware Token Position Queries

The key feature: you can **specify which region** (statement or question) to search.

In [9]:
# Get Pete from the STATEMENT region
pete_statement = get_entity_token_indices_structured(
    input_sample, pipeline, config, 0, 0, region="statement"
)
print(f"Pete in STATEMENT: tokens {pete_statement}")

# Get Pete from the QUESTION region
pete_question = get_entity_token_indices_structured(
    input_sample, pipeline, config, 0, 0, region="question"
)
print(f"Pete in QUESTION:  tokens {pete_question}")

# Get first occurrence (backward compatible)
pete_first = get_entity_token_indices_structured(
    input_sample, pipeline, config, 0, 0, region=None
)
print(f"Pete first occurrence (region=None): tokens {pete_first}")

print("\n" + "=" * 70)
print("KEY INSIGHT: We can now distinguish between duplicate entities!")
print("=" * 70)
assert pete_statement != pete_question, "Statement and question should differ!"
print("✓ Verified: Statement and question Petes have different token positions")

Pete in STATEMENT: tokens [0, 1]
Pete in QUESTION:  tokens [12]
Pete first occurrence (region=None): tokens [0, 1]

KEY INSIGHT: We can now distinguish between duplicate entities!
✓ Verified: Statement and question Petes have different token positions
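The behavior demonstrated above (filter by region, fall back to the first occurrence when `region=None`, raise `ValueError` when nothing matches) can be summarized in a few lines over a precomputed segment table. This is a sketch of the observed semantics, not causalab's source; `SEGMENTS` is copied from the Part 2 output and `lookup` is a hypothetical name.

```python
# (group_idx, entity_idx, region, token_positions), copied from the Part 2 output.
SEGMENTS = [
    (0, 0, "statement", [0, 1]),
    (0, 1, "statement", [3]),
    (1, 0, "statement", [6]),
    (1, 1, "statement", [8]),
    (0, 0, "question", [12]),
]

def lookup(group_idx, entity_idx, region=None):
    """Token positions of entity (group_idx, entity_idx).

    region=None returns the first occurrence; otherwise only segments in that
    region are considered. Raises ValueError when nothing matches.
    """
    for g, e, seg_region, positions in SEGMENTS:
        if (g, e) == (group_idx, entity_idx) and region in (None, seg_region):
            return positions
    where = f" in {region} region" if region else ""
    raise ValueError(f"Entity g{group_idx}_e{entity_idx} not found{where}")

print(lookup(0, 0, "statement"))  # [0, 1]
print(lookup(0, 0, "question"))   # [12]
print(lookup(0, 0))               # [0, 1]
```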


## Part 4: Visual Verification

Let's visualize which tokens are selected for each region.

In [10]:
# Show which tokens correspond to which occurrences
pete_statement_tokens = get_entity_token_indices_structured(
    input_sample, pipeline, config, 0, 0, region="statement"
)
pete_question_tokens = get_entity_token_indices_structured(
    input_sample, pipeline, config, 0, 0, region="question"
)

print("Token-by-token breakdown:\n")
for i, token_id in enumerate(tokens):
    token_str = pipeline.tokenizer.decode([token_id])
    markers = []
    if i in pete_statement_tokens:
        markers.append("PETE(statement)")
    if i in pete_question_tokens:
        markers.append("PETE(question)")
    marker_str = " <-- " + ", ".join(markers) if markers else ""
    print(f"Token {i:2d}: {repr(token_str):20s}{marker_str}")

Token-by-token breakdown:

Token  0: 'P'                  <-- PETE(statement)
Token  1: 'ete'                <-- PETE(statement)
Token  2: ' loves'            
Token  3: ' jam'              
Token  4: ','                 
Token  5: ' and'              
Token  6: ' Ann'              
Token  7: ' loves'            
Token  8: ' pie'              
Token  9: '.'                 
Token 10: ' What'             
Token 11: ' does'             
Token 12: ' Pete'              <-- PETE(question)
Token 13: ' love'             
Token 14: '?'                 


## Part 5: All Entities in the Prompt

Let's find token positions for all entities.

In [11]:
# Find all entities using the parsed structure
print(input_sample)
print("All entities in the prompt:\n")

entities = [
    (0, 0, "Pete", "statement"),
    (0, 1, "jam", "statement"),
    (1, 0, "Ann", "statement"),
    (1, 1, "pie", "statement"),
    (0, 0, "Pete", "question"),  # Pete appears again in question
]

for group_idx, entity_idx, name, region in entities:
    try:
        # Use a new name so the full-prompt `tokens` list from Part 1 is not overwritten
        entity_tokens = get_entity_token_indices_structured(
            input_sample, pipeline, config, group_idx, entity_idx, region=region
        )
        print(f"{name:10} in {region:10} -> tokens {entity_tokens}")
    except ValueError:
        print(f"{name:10} in {region:10} -> NOT FOUND")

{'entity_g0_e0': 'Pete', 'entity_g0_e1': 'jam', 'entity_g1_e0': 'Ann', 'entity_g1_e1': 'pie', 'query_group': 0, 'query_indices': (0,), 'answer_index': 1, 'active_groups': 2, 'entities_per_group': 2, 'raw_input': 'Pete loves jam, and Ann loves pie. What does Pete love?'}
All entities in the prompt:

Pete       in statement  -> tokens [0, 1]
jam        in statement  -> tokens [3]
Ann        in statement  -> tokens [6]
pie        in statement  -> tokens [8]
Pete       in question   -> tokens [12]


## Part 6: Multi-Token Entities with Region Specification

The system handles multi-token entities within specific regions.

In [12]:
# Create example with multi-token entity
multi_token_sample = {
    "entity_g0_e0": "Elizabeth",
    "entity_g0_e1": "strawberry jam",
    "entity_g1_e0": "Margaret",
    "entity_g1_e1": "apple pie",
    "query_group": 0,
    "query_indices": (0,),
    "answer_index": 1,
    "active_groups": 2,
    "entities_per_group": 2,
    "raw_input": "Elizabeth loves strawberry jam, and Margaret loves apple pie. What does Elizabeth love?",
}

print("Multi-token example:")
print(f"Prompt: {multi_token_sample['raw_input']}\n")

# Find multi-token entity in statement
jam_tokens = get_entity_token_indices_structured(
    multi_token_sample, pipeline, config, 0, 1, region="statement"
)
print(f"'strawberry jam' in statement: tokens {jam_tokens}")
print(f"  Token count: {len(jam_tokens)}")

# Get specific token within multi-token entity
if len(jam_tokens) > 1:
    first_token = get_entity_token_indices_structured(
        multi_token_sample, pipeline, config, 0, 1, region="statement", token_idx=0
    )
    print(f"  First token only (token_idx=0): {first_token}")

Multi-token example:
Prompt: Elizabeth loves strawberry jam, and Margaret loves apple pie. What does Elizabeth love?

'strawberry jam' in statement: tokens [2, 3]
  Token count: 2
  First token only (token_idx=0): [2]
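Within a multi-token entity, `token_idx` behaves like indexing into the entity's position list, as the `[2, 3]` to `[2]` output shows. A sketch of that selection; `select_positions` is hypothetical, and because it uses plain Python indexing, `-1` selects the last sub-token. Whether causalab accepts negative indices the same way is an assumption to check.

```python
def select_positions(positions, token_idx=None):
    """Return every position of a (possibly multi-token) entity, or just one.

    token_idx follows Python list indexing, so -1 picks the last sub-token;
    that negative-index behavior is an assumption of this sketch.
    """
    if token_idx is None:
        return list(positions)
    if not -len(positions) <= token_idx < len(positions):
        raise IndexError(f"token_idx {token_idx} out of range for {len(positions)} tokens")
    return [positions[token_idx]]

jam_tokens = [2, 3]  # 'strawberry jam' in the statement, from the output above
print(select_positions(jam_tokens))      # [2, 3]
print(select_positions(jam_tokens, 0))   # [2]
print(select_positions(jam_tokens, -1))  # [3]
```

Returning a one-element list (rather than a bare int) keeps the return type the same whether or not `token_idx` is given, which matches the `[2]` shown above.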


## Part 7: Action Tasks (3-Entity Groups)

The system works with more complex templates too.

In [13]:
# Create action configuration
action_config = create_sample_action_config()
action_model = create_direct_causal_model(action_config)

# Generate action example
action_sample = sample_valid_entity_binding_input(action_config)
action_output = action_model.run_forward(action_sample)
action_sample["raw_input"] = action_output["raw_input"]

print("Action task example (person, object, location):")
print(f"Prompt: {action_sample['raw_input']}")
print(f"Expected answer: {action_output['raw_output']}\n")

Action task example (person, object, location):
Prompt: Kate put watch in the box, Dan put coin in the basket, and Pete put key in the pocket. Where was watch put?
Expected answer: box



In [14]:
# Parse action task structure
action_parser = PromptParser(action_config)
action_parsed = action_parser.parse_prompt(action_sample)
action_tokenizer = PromptTokenizer(pipeline)
action_parsed = action_tokenizer.tokenize_prompt(action_parsed)

print("Action task structure:")
print(f"  {len(action_parsed.segments)} segments")
print(f"  Statement region: chars {action_parsed.statement_region}")
print(f"  Question region: chars {action_parsed.question_region}\n")

# Show entities in statement
query_group = action_sample["query_group"]
print(f"Entities in query group {query_group} (statement):")

for e in range(action_config.max_entities_per_group):
    entity_key = f"entity_g{query_group}_e{e}"
    entity = action_sample.get(entity_key)

    if entity is not None:
        role = action_config.entity_roles[e]
        # Use a new name so the full-prompt `tokens` list from Part 1 is not overwritten
        entity_tokens = get_entity_token_indices_structured(
            action_sample, pipeline, action_config, query_group, e, region="statement"
        )
        print(f"  {role.capitalize():10} ('{entity}'): tokens {entity_tokens}")

Action task structure:
  21 segments
  Statement region: chars (0, 86)
  Question region: chars (87, 107)

Entities in query group 0 (statement):
  Person     ('Kate'): tokens [0]
  Object     ('watch'): tokens [2]
  Location   ('box'): tokens [5]


## Part 8: Rich Structural Information

Beyond just finding entities, the parsed structure gives us rich metadata.

In [15]:
# Query the parsed structure
print("Querying the parsed structure:\n")

# Find all entity segments
entity_segments = [s for s in parsed.segments if s.segment_type == "entity"]
print(f"Total entity segments: {len(entity_segments)}")

# Find all delimiters
delimiter_segments = [s for s in parsed.segments if s.segment_type == "delimiter"]
print(f"Total delimiter segments: {len(delimiter_segments)}")
for seg in delimiter_segments:
    print(f"  Delimiter: '{seg.text}' at chars {seg.char_start}-{seg.char_end}")

# Get specific entity segments
print("\nPete segments:")
pete_segments = parsed.get_entity_segments(0, 0)
for seg in pete_segments:
    region = "statement" if seg.char_start < parsed.statement_region[1] else "question"
    print(f"  '{seg.text}' in {region}: tokens {seg.token_positions}")

Querying the parsed structure:

Total entity segments: 5
Total delimiter segments: 2
  Delimiter: ', and ' at chars 14-20
  Delimiter: '.' at chars 33-34

Pete segments:
  'Pete' in statement: tokens [0, 1]
  'Pete' in question: tokens [12]
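Because segments are plain records, structural questions reduce to ordinary filtering and counting. A sketch over a hand-copied version of the Part 2 segment list (the `template_text` and `question_text` spans are computed from their text, since only the delimiter spans are printed above). It also flags which entities occur more than once and therefore need an explicit region.

```python
from collections import Counter, defaultdict

# (segment_type, text, char_start, char_end, group_idx, entity_idx), hand-copied from Part 2.
segments = [
    ("entity", "Pete", 0, 4, 0, 0),
    ("template_text", " loves ", 4, 11, None, None),
    ("entity", "jam", 11, 14, 0, 1),
    ("delimiter", ", and ", 14, 20, None, None),
    ("entity", "Ann", 20, 23, 1, 0),
    ("template_text", " loves ", 23, 30, None, None),
    ("entity", "pie", 30, 33, 1, 1),
    ("delimiter", ".", 33, 34, None, None),
    ("question_text", "What does ", 35, 45, None, None),
    ("entity", "Pete", 45, 49, 0, 0),
    ("question_text", " love?", 49, 55, None, None),
]

type_counts = Counter(seg[0] for seg in segments)

# Group entity occurrences by (group_idx, entity_idx).
occurrences = defaultdict(list)
for seg_type, text, start, end, g, e in segments:
    if seg_type == "entity":
        occurrences[(g, e)].append((start, end))

# Entities seen more than once are exactly the ones that need a region argument.
duplicated = [key for key, spans in occurrences.items() if len(spans) > 1]

print(type_counts)
print(occurrences[(0, 0)])  # [(0, 4), (45, 49)]
print(duplicated)           # [(0, 0)]
```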


## Part 9: Multiple Random Examples

Let's see the system work across different prompts.
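Region-specific positions for the query entity rely on one invariant: the queried person occurs exactly once in each region. A self-contained sketch that checks this over many prompts from a hypothetical sampler; the name and food lists are taken from the examples printed below, and causalab's `sample_valid_entity_binding_input` may sample differently.

```python
import random
import re

NAMES = ["Kate", "Tim", "Ann", "Sue", "Bob", "Pete"]
FOODS = ["tea", "jam", "pie", "soup", "cake", "bread"]

def sample_prompt(rng):
    """Hypothetical sampler: 2-3 distinct (person, food) pairs and a query person."""
    n = rng.choice([2, 3])
    people = rng.sample(NAMES, n)
    foods = rng.sample(FOODS, n)
    clauses = [f"{p} loves {f}" for p, f in zip(people, foods)]
    statement = ", ".join(clauses[:-1]) + ", and " + clauses[-1] + "."
    query = rng.choice(people)
    return statement, f"What does {query} love?", query

def count_word(text, word):
    """Count whole-word, case-sensitive occurrences of `word` in `text`."""
    return len(re.findall(rf"\b{re.escape(word)}\b", text))

rng = random.Random(0)
for _ in range(100):
    statement, question, query = sample_prompt(rng)
    assert count_word(statement, query) == 1
    assert count_word(question, query) == 1
print("invariant holds for 100 samples")
```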

In [16]:
print("Generating 5 random examples with region-aware token positions:\n")
print("=" * 80)

model = create_direct_causal_model(config)

for i in range(5):
    # Generate new sample
    sample = sample_valid_entity_binding_input(config)
    output = model.run_forward(sample)
    sample["raw_input"] = output["raw_input"]

    print(f"\nExample {i + 1}:")
    print(f"Prompt: {sample['raw_input']}")
    print(f"Expected answer: {output['raw_output']}")

    # Get query entity from statement
    query_group = sample["query_group"]
    query_idx = sample["query_indices"][0]

    try:
        statement_tokens = get_entity_token_indices_structured(
            sample, pipeline, config, query_group, query_idx, region="statement"
        )
        print(f"  Query entity in statement: tokens {statement_tokens}")
    except ValueError:
        print("  Query entity NOT in statement")

    try:
        question_tokens = get_entity_token_indices_structured(
            sample, pipeline, config, query_group, query_idx, region="question"
        )
        print(f"  Query entity in question: tokens {question_tokens}")
    except ValueError:
        print("  Query entity NOT in question")

    print("-" * 80)

Generating 5 random examples with region-aware token positions:


Example 1:
Prompt: Kate loves tea, Tim loves jam, and Ann loves pie. What does Ann love?
Expected answer: pie
  Query entity in statement: tokens [9]
  Query entity in question: tokens [15]
--------------------------------------------------------------------------------

Example 2:
Prompt: Kate loves soup, Tim loves tea, and Sue loves cake. What does Sue love?
Expected answer: cake
  Query entity in statement: tokens [9]
  Query entity in question: tokens [15]
--------------------------------------------------------------------------------

Example 3:
Prompt: Bob loves pie, and Sue loves bread. What does Bob love?
Expected answer: pie
  Query entity in statement: tokens [0]
  Query entity in question: tokens [11]
--------------------------------------------------------------------------------

Example 4:
Prompt: Sue loves jam, Pete loves pie, and Tim loves soup. What does Sue love?
Expected answer: jam
  Query entity in 

## Part 10: Error Handling

The system provides clear errors when entities aren't found in the specified region.

In [17]:
# Try to get an entity from the wrong region
simple_sample = {
    "entity_g0_e0": "Pete",
    "entity_g0_e1": "jam",
    "active_groups": 1,
    "entities_per_group": 2,
    "raw_input": "Pete loves jam.",  # No question!
}

print("Example without question:")
print(f"Prompt: {simple_sample['raw_input']}\n")

# This works - Pete is in the statement
statement_tokens = get_entity_token_indices_structured(
    simple_sample, pipeline, config, 0, 0, region="statement"
)
print(f"Pete in statement: {statement_tokens}")

# This raises an error - no question region
try:
    question_tokens = get_entity_token_indices_structured(
        simple_sample, pipeline, config, 0, 0, region="question"
    )
    print(f"Pete in question: {question_tokens}")
except ValueError as e:
    print(f"\n✓ Expected error: {e}")
    print("  The system correctly reports that there's no question region!")

Example without question:
Prompt: Pete loves jam.

Pete in statement: [0, 1]

✓ Expected error: Entity g0_e0 not found in question region
  The system correctly reports that there's no question region!
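In batch experiments it can be more convenient to skip prompts that lack a region than to stop on the first error. A small wrapper sketch; `try_positions` and `stub_lookup` are hypothetical, and with causalab one would pass the real function instead, e.g. `try_positions(get_entity_token_indices_structured, simple_sample, pipeline, config, 0, 0, region="question")`.

```python
from typing import Callable, List, Optional

def try_positions(lookup: Callable[..., List[int]], *args, **kwargs) -> Optional[List[int]]:
    """Return lookup(*args, **kwargs), or None if it raises ValueError (entity/region missing)."""
    try:
        return lookup(*args, **kwargs)
    except ValueError:
        return None

# Stand-in for a region-aware lookup on "Pete loves jam." (no question region).
def stub_lookup(group_idx, entity_idx, region=None):
    table = {("statement", 0, 0): [0, 1]}  # Pete's positions, from the output above
    key = (region or "statement", group_idx, entity_idx)
    if key not in table:
        raise ValueError(f"Entity g{group_idx}_e{entity_idx} not found in {region} region")
    return table[key]

print(try_positions(stub_lookup, 0, 0, region="statement"))  # [0, 1]
print(try_positions(stub_lookup, 0, 0, region="question"))   # None
```

Converting the error to `None` is only appropriate where a missing region is expected; where a region must exist, letting the `ValueError` propagate surfaces bugs sooner.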


## Summary: Advantages of Template-Aware Token Positions

### Key Improvements:

1. **✓ Unambiguous**: Explicitly specify `region='statement'` or `region='question'`
2. **✓ Correct**: No more finding the wrong occurrence of a duplicated entity
3. **✓ Rich metadata**: Full structural information (segments, regions, types)
4. **✓ Template-aware**: Uses actual template structure, not blind substring matching
5. **✓ Extensible**: Easy to add new segment types or query capabilities

### API Functions:

- **`get_entity_token_indices_structured`**: Get token positions with region specification
- **`PromptParser`**: Parse prompts into structured segments
- **`PromptTokenizer`**: Add token position information to segments
- **`ParsedPrompt`**: Rich data structure with regions and segments

### For Intervention Experiments:

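```python
from causalab.tasks.entity_binding.token_positions import get_entity_token_indices_structured

# `sample`, `pipeline`, and `config` are set up as in the cells above;
# `group_idx` and `entity_idx` select one entity, e.g. (0, 0) for the first
# entity of the first group.

# Get entity from statement (for testing representation)
statement_tokens = get_entity_token_indices_structured(
    sample, pipeline, config, group_idx, entity_idx, region='statement'
)

# Get entity from question (for testing retrieval)
question_tokens = get_entity_token_indices_structured(
    sample, pipeline, config, group_idx, entity_idx, region='question'
)
```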

### Next Steps:

- Use region-specific token positions in intervention experiments
- Test how interventions at statement vs question positions differ
- Build precise causal abstraction claims about entity binding