# Entity Binding Task Implementation Walkthrough

This notebook demonstrates the step-by-step implementation of a flexible entity binding task system for causal abstraction experiments. Entity binding tasks test how language models understand relationships between entities in structured contexts.

## What are Entity Binding Tasks?

Entity binding tasks involve:
1. **Entity Groups**: Sets of entities that appear together (e.g., "Pete, jam" or "Ann, pie")
2. **Templates**: Structured text that describes relationships (e.g., "X loves Y")
3. **Queries**: Questions about specific entities or relationships

Example: "Pete loves jam, Ann loves pie. What does Pete love?"

## Step 1: Core Data Structures

First, we need flexible data structures that can handle arbitrary numbers of entity groups and entities per group.

In [None]:
# autoreload imports
%load_ext autoreload
%autoreload 2

from causalab.tasks.entity_binding.config import (
    EntityGroup,
    BindingMatrix,
    EntityBindingTaskConfig,
    create_sample_love_config,
    create_sample_action_config,
)
from causalab.tasks.entity_binding.causal_models import (
    create_positional_entity_causal_model,
    sample_valid_entity_binding_input,
)
from causalab.neural.token_position_builder import Template

### EntityBindingTaskConfig: The Blueprint

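This configuration class defines the task: the maximum number of groups and entities per group, the role of each entity position, the candidate entities for each role, and the statement and question templates: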

In [None]:
# Create a sample "love" configuration
love_config = create_sample_love_config()

print("Love Task Configuration:")
print(f"Max groups: {love_config.max_groups}")
print(f"Max entities per group: {love_config.max_entities_per_group}")
print(f"Entity roles: {love_config.entity_roles}")
print(f"People: {love_config.entity_pools[0]}")
print(f"Foods: {love_config.entity_pools[1]}")
print(f"Statement template: {love_config.statement_template}")
print(f"Question templates: {love_config.question_templates}")

Love Task Configuration:
Max groups: 3
Max entities per group: 2
Entity roles: {0: 'person', 1: 'food'}
People: ['Pete', 'Ann', 'Tim', 'Bob', 'Sue', 'Kate']
Foods: ['jam', 'pie', 'cake', 'bread', 'soup', 'tea']
Statement template: {e0} loves {e1}
Question templates: {((0,), 1): 'What does {person} love?', ((1,), 0): 'Who loves {food}?'}
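The question templates are keyed by `(query_indices, answer_index)`, and in the printed config each template's placeholders are exactly the role names of the queried positions. The sketch below checks that convention on a plain-dict copy of the values printed above. `template_fields` and `check_question_templates` are hypothetical helpers written for this walkthrough, not part of `causalab`.

```python
from string import Formatter

def template_fields(template_str):
    """Return the set of {name} placeholders in a format string."""
    return {field for _, field, _, _ in Formatter().parse(template_str) if field}

def check_question_templates(entity_roles, question_templates):
    """Hypothetical consistency check (not a causalab API): each question template
    should mention exactly the roles it queries, and must not also query the answer."""
    problems = []
    for (query_indices, answer_index), template_str in question_templates.items():
        expected = {entity_roles[i] for i in query_indices}
        if template_fields(template_str) != expected:
            problems.append(f"{template_str!r}: expected placeholders {sorted(expected)}")
        if answer_index in query_indices:
            problems.append(f"{template_str!r}: answer index {answer_index} is also queried")
    return problems

# Values copied from the printed love configuration
entity_roles = {0: "person", 1: "food"}
question_templates = {((0,), 1): "What does {person} love?", ((1,), 0): "Who loves {food}?"}
print(check_question_templates(entity_roles, question_templates))  # → []
```

An empty list means every template agrees with the role naming; a misnamed placeholder or a pattern that queries its own answer shows up as a readable message.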


### EntityGroup and BindingMatrix: Representing Specific Instances

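While the config is the blueprint, EntityGroup and BindingMatrix describe one concrete instance: each EntityGroup holds the entities of a single group, and the BindingMatrix collects the groups, returning None for slots in inactive groups: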

In [None]:
# Create specific entity groups
group1 = EntityGroup(["Pete", "jam"], group_index=0)
group2 = EntityGroup(["Ann", "pie"], group_index=1)

print("Entity Groups:")
print(f"Group 1: {group1}")
print(f"Group 2: {group2}")

# Create binding matrix
matrix = BindingMatrix(
    [group1, group2],
    max_groups=love_config.max_groups,
    max_entities_per_group=love_config.max_entities_per_group,
)

print(f"\nBinding Matrix: {matrix}")
print(f"Active groups: {matrix.active_groups}")

# Show entity access
print("\nEntity Access:")
print(f"Group 0, Entity 0 (person): {matrix.get_entity(0, 0)}")
print(f"Group 0, Entity 1 (food): {matrix.get_entity(0, 1)}")
print(f"Group 1, Entity 0 (person): {matrix.get_entity(1, 0)}")
print(f"Group 2, Entity 0 (inactive): {matrix.get_entity(2, 0)}")

Entity Groups:
Group 1: EntityGroup(['Pete', 'jam'], group_0)
Group 2: EntityGroup(['Ann', 'pie'], group_1)

Binding Matrix: BindingMatrix([EntityGroup(['Pete', 'jam'], group_0), EntityGroup(['Ann', 'pie'], group_1)], active=2)
Active groups: 2

Entity Access:
Group 0, Entity 0 (person): Pete
Group 0, Entity 1 (food): jam
Group 1, Entity 0 (person): Ann
Group 2, Entity 0 (inactive): None


### Flat Dictionary Representation

The causal model instead expects a flat dictionary with standardized keys, one per entity slot (for example, entity_g0_e0 for group 0, entity 0).
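Here is a minimal sketch of that flattening, assuming the key scheme used by the causal model variables listed in Step 3 (`entity_g{g}_e{e}`, `active_groups`, `entities_per_group`) and `None` for unused slots, just as `get_entity` returns `None` for inactive groups. `flatten_groups` is a hypothetical helper that works on plain lists; it is not the library's own conversion method.

```python
def flatten_groups(groups, max_groups, max_entities_per_group):
    """Hypothetical helper: flatten entity groups into standardized causal-model keys.

    Slots beyond the active groups (or beyond a group's length) are filled with None.
    """
    flat = {
        "active_groups": len(groups),
        # Assumption: entities_per_group is the size of the largest active group
        "entities_per_group": max((len(g) for g in groups), default=0),
    }
    for g in range(max_groups):
        for e in range(max_entities_per_group):
            in_range = g < len(groups) and e < len(groups[g])
            flat[f"entity_g{g}_e{e}"] = groups[g][e] if in_range else None
    return flat

flat = flatten_groups([["Pete", "jam"], ["Ann", "pie"]], max_groups=3, max_entities_per_group=2)
print(flat["entity_g0_e1"], flat["entity_g1_e0"], flat["entity_g2_e0"])  # → jam Ann None
```

Padding every slot up to `max_groups` keeps the set of keys fixed, so the causal model has the same variables no matter how many groups a particular input activates.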

## Step 2: Template Processing System

The template system converts abstract entity bindings into concrete text.

In [None]:
# Build a mega template using the config
# This combines the statement template (conjoined for multiple groups) with a question template
mega_template_str = love_config.build_mega_template(
    active_groups=2,
    query_indices=(0,),  # Query by person
    answer_index=1,       # Answer is food
)

print("Mega template for 2 groups, query person -> answer food:")
print(f"'{mega_template_str}'")

Mega template for 2 groups, query person -> answer food:
'We will ask a question about the following sentences.

{g0_e0} loves {g0_e1} and {g1_e0} loves {g1_e1}. What does {person} love?
Answer:'
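The outputs in this notebook show the conjunction pattern. One group produces a single clause, two groups are joined with "and", and three groups use commas with a final ", and". Below is a sketch of how such a mega template could be assembled, assuming the preamble and the `Answer:` suffix are passed in separately (the action prompts later in the notebook have neither). `build_mega_template_sketch` and `conjoin` are hypothetical stand-ins, not the implementation of `build_mega_template`.

```python
import re

def conjoin(clauses):
    """Join clauses as 'A', 'A and B', or 'A, B, and C'."""
    if len(clauses) == 1:
        return clauses[0]
    if len(clauses) == 2:
        return f"{clauses[0]} and {clauses[1]}"
    return ", ".join(clauses[:-1]) + f", and {clauses[-1]}"

def build_mega_template_sketch(statement_template, question_template, active_groups,
                               preamble="", suffix=""):
    """Hypothetical re-implementation: one statement clause per group, then the question."""
    clauses = [
        # Rewrite each {eN} placeholder as the group-specific {gG_eN}
        re.sub(r"\{e(\d+)\}", lambda m, g=g: f"{{g{g}_e{m.group(1)}}}", statement_template)
        for g in range(active_groups)
    ]
    return f"{preamble}{conjoin(clauses)}. {question_template}{suffix}"

print(build_mega_template_sketch("{e0} loves {e1}", "Who loves {food}?", active_groups=3))
# → {g0_e0} loves {g0_e1}, {g1_e0} loves {g1_e1}, and {g2_e0} loves {g2_e1}. Who loves {food}?
```

Renaming `{eN}` to `{gG_eN}` gives every entity slot a unique placeholder. A single flat dictionary of values can then fill the whole prompt.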


### Filling Templates with Entity Values

Now let's fill the mega template with actual entity values using the Template class:

In [None]:
# Create a Template object from the mega template string
template = Template(mega_template_str)

# Fill the template with entity values
# Statement entities use: g0_e0, g0_e1, g1_e0, g1_e1, ...
# Question entities use: query_entity, person, food, ...
values = {
    "g0_e0": "Pete",    # Group 0, entity 0 (person)
    "g0_e1": "jam",     # Group 0, entity 1 (food)
    "g1_e0": "Ann",     # Group 1, entity 0 (person)
    "g1_e1": "pie",     # Group 1, entity 1 (food)
    "person": "Pete",   # Query entity (person from group 0)
}

filled_prompt = template.fill(values)

print("Template Filling Process:")
print(f"Template: '{mega_template_str}'")
print(f"Filled: '{filled_prompt}'")

Template Filling Process:
Template: 'We will ask a question about the following sentences.

{g0_e0} loves {g0_e1} and {g1_e0} loves {g1_e1}. What does {person} love?
Answer:'
Filled: 'We will ask a question about the following sentences.

Pete loves jam and Ann loves pie. What does Pete love?
Answer:'
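For templates like these, filling comes down to named-placeholder substitution. We make no assumption about how `Template.fill` is implemented. The sketch below produces the same string with Python's built-in `str.format_map`, and it first checks for missing values so that every missing placeholder is reported together rather than stopping at the first `KeyError`. `fill_sketch` is a hypothetical helper.

```python
from string import Formatter

def fill_sketch(template_str, values):
    """Hypothetical helper: substitute {name} placeholders, reporting all missing names at once."""
    fields = [field for _, field, _, _ in Formatter().parse(template_str) if field]
    missing = sorted(set(fields) - set(values))
    if missing:
        raise KeyError(f"no values for placeholders: {missing}")
    return template_str.format_map(values)

mega = ("We will ask a question about the following sentences.\n\n"
        "{g0_e0} loves {g0_e1} and {g1_e0} loves {g1_e1}. What does {person} love?\nAnswer:")
values = {"g0_e0": "Pete", "g0_e1": "jam", "g1_e0": "Ann", "g1_e1": "pie", "person": "Pete"}
print(fill_sketch(mega, values))
```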


### Different Question Templates

The config maps each query pattern, a (query_indices, answer_index) pair, to its own question template:

In [None]:
print("Available Question Templates:")
for (query_indices, answer_idx), template_str in love_config.question_templates.items():
    print(f"  Query {query_indices} → Answer {answer_idx}: '{template_str}'")

# Example 1: Query person (index 0), expect food answer (index 1)
mega1 = love_config.build_mega_template(active_groups=2, query_indices=(0,), answer_index=1)
filled1 = Template(mega1).fill({
    "g0_e0": "Pete", "g0_e1": "jam",
    "g1_e0": "Ann", "g1_e1": "pie",
    "person": "Pete",
})

print("\nExample 1 - Query Person:")
print(f"Filled prompt: '{filled1}'")
print("Expected answer: 'jam'")

# Example 2: Query food (index 1), expect person answer (index 0)
mega2 = love_config.build_mega_template(active_groups=2, query_indices=(1,), answer_index=0)
filled2 = Template(mega2).fill({
    "g0_e0": "Pete", "g0_e1": "jam",
    "g1_e0": "Ann", "g1_e1": "pie",
    "food": "jam",
})

print("\nExample 2 - Query Food:")
print(f"Filled prompt: '{filled2}'")
print("Expected answer: 'Pete'")

Available Question Templates:
  Query (0,) → Answer 1: 'What does {person} love?'
  Query (1,) → Answer 0: 'Who loves {food}?'

Example 1 - Query Person:
Filled prompt: 'We will ask a question about the following sentences.

Pete loves jam and Ann loves pie. What does Pete love?
Answer:'
Expected answer: 'jam'

Example 2 - Query Food:
Filled prompt: 'We will ask a question about the following sentences.

Pete loves jam and Ann loves pie. Who loves jam?
Answer:'
Expected answer: 'Pete'


### Key Insight

Same entities, different questions → different answers. This tests whether models understand the entity-relationship structure.

In [None]:
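# Demonstrate the flexibility: same entities, different questions.
# Read the answers from the binding matrix instead of hard-coding them.
print("Same entities, different query patterns:")
print(f"  Query person '{matrix.get_entity(0, 0)}' -> Answer food: '{matrix.get_entity(0, 1)}'")
print(f"  Query food '{matrix.get_entity(0, 1)}' -> Answer person: '{matrix.get_entity(0, 0)}'")
print("\nThis tests whether models track which entities are bound together.")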

Same entities, different query patterns:
  Query person 'Pete' -> Answer food: 'jam'
  Query food 'jam' -> Answer person: 'Pete'

This tests whether models track which entities are bound together.
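A single function can express the lookup behind both answers: find the group whose entities match every queried position, then return the entity at the answer position. The sketch below uses plain lists, and `answer_query` is hypothetical rather than a `causalab` function. It raises an error when the query does not pick out exactly one group. That ambiguity is the reason sampled inputs keep entities distinct across groups.

```python
def answer_query(groups, query_indices, query_values, answer_index):
    """Hypothetical helper: return the answer entity, given groups as lists of entities."""
    matches = [
        group for group in groups
        if all(group[i] == v for i, v in zip(query_indices, query_values))
    ]
    if len(matches) != 1:
        raise ValueError(f"query matches {len(matches)} groups; the answer is ambiguous")
    return matches[0][answer_index]

groups = [["Pete", "jam"], ["Ann", "pie"]]
print(answer_query(groups, (0,), ("Pete",), 1))  # → jam
print(answer_query(groups, (1,), ("jam",), 0))   # → Pete
```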


### More Complex Example: Action Tasks

Let's see how this works with 3-entity groups (person, object, location):

In [None]:
# Create action configuration
action_config = create_sample_action_config()

print("Action Task Configuration:")
print(f"Entity roles: {action_config.entity_roles}")
print(f"Statement template: {action_config.statement_template}")
print("Question templates:")
for (query_indices, answer_idx), template_str in action_config.question_templates.items():
    print(f"  Query {query_indices} → Answer {answer_idx}: '{template_str}'")

# Build and fill a 3-entity action template
action_mega = action_config.build_mega_template(active_groups=1, query_indices=(0,), answer_index=1)
action_filled = Template(action_mega).fill({
    "g0_e0": "Pete",
    "g0_e1": "jam", 
    "g0_e2": "cup",
    "person": "Pete",
})

print(f"\nAction prompt: '{action_filled}'")

print("\nThis shows the system's flexibility:")
print("- Same core logic works for 2-entity and 3-entity groups")
print("- Templates can be arbitrarily complex")
print("- Easy to add new entity types and relationships")

Action Task Configuration:
Entity roles: {0: 'person', 1: 'object', 2: 'location'}
Statement template: {e0} put {e1} in the {e2}
Question templates:
  Query (0,) → Answer 1: 'What did {person} put somewhere?'
  Query (0,) → Answer 2: 'Where did {person} put something?'
  Query (1,) → Answer 0: 'Who put {object} somewhere?'
  Query (1,) → Answer 2: 'Where was {object} put?'
  Query (2,) → Answer 0: 'Who put something in the {location}?'
  Query (2,) → Answer 1: 'What was put in the {location}?'
  Query (0, 1) → Answer 2: 'Where did {person} put {object}?'
  Query (0, 2) → Answer 1: 'What did {person} put in the {location}?'
  Query (1, 2) → Answer 0: 'Who put {object} in the {location}?'

Action prompt: 'Pete put jam in the cup. What did Pete put somewhere?'

This shows the system's flexibility:
- Same core logic works for 2-entity and 3-entity groups
- Templates can be arbitrarily complex
- Easy to add new entity types and relationships
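Some action questions name more than one role, for example `Where did {person} put {object}?`. The question placeholders are therefore filled from the query group by role name. The sketch below shows that mapping, assuming the placeholder names are exactly the values of `entity_roles` as printed above. `question_values` is a hypothetical helper.

```python
def question_values(group, query_indices, entity_roles):
    """Hypothetical helper: map each queried role name to that role's entity in the group."""
    return {entity_roles[i]: group[i] for i in query_indices}

entity_roles = {0: "person", 1: "object", 2: "location"}
group = ["Pete", "jam", "cup"]
values = question_values(group, (0, 2), entity_roles)
print("What did {person} put in the {location}?".format_map(values))
# → What did Pete put in the cup?
```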


## Step 3: Causal Model Integration

Now we integrate everything into a causal model that can generate prompts and answers.

### Causal Model Structure

The causal model represents entity bindings as a set of causal variables:

In [None]:
# Create causal model
causal_model = create_positional_entity_causal_model(love_config)

print(f"Causal Model: {causal_model.id}")
print(f"Total variables: {len(causal_model.variables)}")

# Show different types of variables
entity_vars = [v for v in causal_model.variables if v.startswith("entity_")]
control_vars = [
    v
    for v in causal_model.variables
    if not v.startswith("entity_") and v not in ["raw_input", "raw_output"]
    and not v.startswith("positional_") and not v.startswith("query_")
]

print(f"\nEntity variables ({len(entity_vars)}): {entity_vars[:5]}...")
print(f"Control variables ({len(control_vars)}): {control_vars}")
print("Output variables: raw_input, raw_output")

print("\nKey insight: The positional model computes position through intermediate variables:")
print("- positional_entity_g{g}_e{e}: Position (group index) of each entity")
print("- positional_query_e{e}: Groups where query entity appears at position e")
print("- positional_answer: Final group position to retrieve from")

Causal Model: entity_binding_positional_entity_3g_2e
Total variables: 26

Entity variables (6): ['entity_g0_e0', 'entity_g0_e1', 'entity_g1_e0', 'entity_g1_e1', 'entity_g2_e0']...
Control variables (5): ['answer_index', 'active_groups', 'entities_per_group', 'statement_template', 'question_template']
Output variables: raw_input, raw_output

Key insight: The positional model computes position through intermediate variables:
- positional_entity_g{g}_e{e}: Position (group index) of each entity
- positional_query_e{e}: Groups where query entity appears at position e
- positional_answer: Final group position to retrieve from
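We can mimic the positional variables on a flat dictionary. The sketch below rests on an assumption about their meaning, drawn only from the descriptions printed above. For each query position `e`, it collects the set of active groups whose entity at `e` equals the query value. The answer group is the intersection of those sets, and the output is read from that group. Under this reading, `positional_entity_g{g}_e{e}` is simply `g`, so the sketch leaves it implicit. The function name and return structure are hypothetical, and the real model may compute these variables differently.

```python
def positional_sketch(flat, query_group, query_indices, answer_index):
    """Hypothetical illustration of positional_query_e{e}, positional_answer, and raw_output."""
    active = range(flat["active_groups"])
    # positional_query_e{e}: groups whose entity at position e equals the query entity
    positional_query = {
        e: {g for g in active
            if flat[f"entity_g{g}_e{e}"] == flat[f"entity_g{query_group}_e{e}"]}
        for e in query_indices
    }
    # positional_answer: the single group consistent with every queried position
    candidates = set.intersection(*positional_query.values())
    if len(candidates) != 1:
        raise ValueError(f"ambiguous query: candidate groups {sorted(candidates)}")
    (positional_answer,) = candidates
    # raw_output: read the answer role from the selected group
    output = flat[f"entity_g{positional_answer}_e{answer_index}"]
    return positional_query, positional_answer, output

flat = {"active_groups": 2,
        "entity_g0_e0": "Kate", "entity_g0_e1": "jam",
        "entity_g1_e0": "Bob", "entity_g1_e1": "cake"}
print(positional_sketch(flat, query_group=1, query_indices=(0,), answer_index=1))
# → ({0: {1}}, 1, 'cake')
```

The answer is computed from a group index and not from the entity strings. That separation lets an intervention on a positional variable redirect which group the output is read from.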


### Sampling Valid Inputs

sample_valid_entity_binding_input draws a random input that satisfies the task's validity constraints and returns it as a CausalTrace:

In [None]:
# Sample a valid input - returns a CausalTrace with computed variables
sample_trace = sample_valid_entity_binding_input(love_config, causal_model)

print("Sample Input (from CausalTrace):")
print(f"Active groups: {sample_trace['active_groups']}")
print(f"Query group: {sample_trace['query_group']}")
print(f"Query pattern: {sample_trace['query_indices']} → {sample_trace['answer_index']}")

print("\nActive entities:")
for g in range(sample_trace["active_groups"]):
    entities = []
    for e in range(sample_trace["entities_per_group"]):
        entity = sample_trace[f"entity_g{g}_e{e}"]
        if entity is not None:
            entities.append(f"{love_config.entity_roles[e]}: {entity}")
    print(f"  Group {g}: {entities}")

print("\nValidation ensures:")
print("- Query group is within active groups")
print("- Query and answer entities exist")
print("- Question template exists for this query pattern")
print("- Entities at each position are distinct across groups (for uniqueness)")

Sample Input (from CausalTrace):
Active groups: 2
Query group: 1
Query pattern: (0,) → 1

Active entities:
  Group 0: ['person: Kate', 'food: jam']
  Group 1: ['person: Bob', 'food: cake']

Validation ensures:
- Query group is within active groups
- Query and answer entities exist
- Question template exists for this query pattern
- Entities at each position are distinct across groups (for uniqueness)
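The sketch below shows a sampler that satisfies the listed constraints, working on plain lists instead of a `CausalTrace`. It draws each role's entities with `random.sample`, which samples without replacement, so entities at the same position stay distinct across groups. Three choices are assumptions made for illustration: the minimum of 2 active groups, the uniform choice of query pattern, and the uniform choice of query group. `sample_input_sketch` is hypothetical, and the library's sampler may weight these choices differently.

```python
import random

def sample_input_sketch(entity_pools, max_groups, question_templates, rng, min_groups=2):
    """Hypothetical sampler enforcing the validity constraints listed above."""
    active_groups = rng.randint(min_groups, max_groups)  # inclusive on both ends
    # One column per role; sampling without replacement keeps each column distinct
    columns = [rng.sample(pool, active_groups) for pool in entity_pools]
    groups = [[column[g] for column in columns] for g in range(active_groups)]
    # Only patterns that have a question template can be chosen
    query_indices, answer_index = rng.choice(list(question_templates))
    query_group = rng.randrange(active_groups)  # always within the active groups
    return {"groups": groups, "query_group": query_group,
            "query_indices": query_indices, "answer_index": answer_index}

pools = [["Pete", "Ann", "Tim", "Bob", "Sue", "Kate"],
         ["jam", "pie", "cake", "bread", "soup", "tea"]]
templates = {((0,), 1): "What does {person} love?", ((1,), 0): "Who loves {food}?"}
sample = sample_input_sketch(pools, max_groups=3, question_templates=templates,
                             rng=random.Random(0))
print(sample)
```

Passing a seeded `random.Random` instance makes the samples reproducible without touching the global random state.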


### Running the Causal Model

Let's run the model to generate prompts and answers:

In [None]:
# The CausalTrace already contains computed values (raw_input, raw_output, etc.)
# since sample_valid_entity_binding_input calls model.new_trace() internally

print("Causal Model Output (from trace):")
print(f"Generated prompt: '{sample_trace['raw_input']}'")
print(f"Expected answer: '{sample_trace['raw_output']}'")

# Verify consistency
query_entity_key = f"entity_g{sample_trace['query_group']}_e{sample_trace['query_indices'][0]}"
answer_entity_key = f"entity_g{sample_trace['query_group']}_e{sample_trace['answer_index']}"

print("\nConsistency check:")
print(f"Query entity: {sample_trace[query_entity_key]}")
print(f"Answer entity: {sample_trace[answer_entity_key]}")
print(f"Output matches answer entity: {sample_trace['raw_output'] == sample_trace[answer_entity_key]}")

print("\nNote: sample_valid_entity_binding_input returns a CausalTrace")
print("which includes both input values and computed outputs.")

Causal Model Output (from trace):
Generated prompt: 'We will ask a question about the following sentences.

Kate loves jam and Bob loves cake. What does Bob love?
Answer:'
Expected answer: 'cake'

Consistency check:
Query entity: Bob
Answer entity: cake
Output matches answer entity: True

Note: sample_valid_entity_binding_input returns a CausalTrace
which includes both input values and computed outputs.


### Multiple Examples

Let's generate several examples to see the variety:

In [None]:
print("Multiple Example Generations:")
print("=" * 50)

for i in range(5):
    # Sample new trace
    trace = sample_valid_entity_binding_input(love_config, causal_model)
    
    # Extract key information
    query_group = trace["query_group"]
    entities = []
    for e in range(love_config.max_entities_per_group):
        entity = trace[f"entity_g{query_group}_e{e}"]
        if entity is not None:
            entities.append(entity)

    print(f"\nExample {i + 1}:")
    print(f"  Entities in query group: {entities}")
    print(f"  Query pattern: {trace['query_indices']} → {trace['answer_index']}")
    print(f"  Prompt: '{trace['raw_input']}'")
    print(f"  Answer: '{trace['raw_output']}'")

    # Verify answer is correct
    expected_answer = trace[f"entity_g{query_group}_e{trace['answer_index']}"]
    print(
        f"  Expected: '{expected_answer}' ✓"
        if trace["raw_output"] == expected_answer
        else f"  Expected: '{expected_answer}' ✗"
    )

print("\nKey observations:")
print("- Same model generates different question types")
print("- Consistent entity-relationship mappings")
print("- Flexible handling of different entity combinations")

Multiple Example Generations:

Example 1:
  Entities in query group: ['Bob', 'cake']
  Query pattern: (0,) → 1
  Prompt: 'We will ask a question about the following sentences.

Bob loves cake and Kate loves bread. What does Bob love?
Answer:'
  Answer: 'cake'
  Expected: 'cake' ✓

Example 2:
  Entities in query group: ['Ann', 'pie']
  Query pattern: (0,) → 1
  Prompt: 'We will ask a question about the following sentences.

Ann loves pie and Sue loves bread. What does Ann love?
Answer:'
  Answer: 'pie'
  Expected: 'pie' ✓

Example 3:
  Entities in query group: ['Bob', 'jam']
  Query pattern: (1,) → 0
  Prompt: 'We will ask a question about the following sentences.

Kate loves cake, Bob loves jam, and Ann loves pie. Who loves jam?
Answer:'
  Answer: 'Bob'
  Expected: 'Bob' ✓

Example 4:
  Entities in query group: ['Sue', 'bread']
  Query pattern: (0,) → 1
  Prompt: 'We will ask a question about the following sentences.

Bob loves cake, Ann loves tea, and Sue loves bread. What does Sue lo

### Action System with Causal Model Integration

Now let's see how the 3-entity system works with the full causal model:

### Detailed 3-Entity Action System Demonstration

Let's explore the action system more thoroughly to understand how it handles multiple query patterns with 3 entities:

In [None]:
# Inspect the action configuration: entity pools and question patterns
print("=== 3-Entity Action System Deep Dive ===")
print(f"Entity roles: {action_config.entity_roles}")
print("Entity pools:")
for role_idx, role_name in action_config.entity_roles.items():
    print(f"  {role_name} ({role_idx}): {action_config.entity_pools[role_idx][:4]}...")

# Demonstrate different question patterns with the action config
print("\n=== Question Patterns for 3-Entity Groups ===")
for (query_indices, answer_idx), template_str in action_config.question_templates.items():
    print(f"Query {query_indices} → Answer {answer_idx}: '{template_str}'")

=== 3-Entity Action System Deep Dive ===
Entity roles: {0: 'person', 1: 'object', 2: 'location'}
Entity pools:
  person (0): ['Pete', 'Ann', 'Bob', 'Sue']...
  object (1): ['jam', 'water', 'book', 'coin']...
  location (2): ['cup', 'box', 'table', 'shelf']...

=== Question Patterns for 3-Entity Groups ===
Query (0,) → Answer 1: 'What did {person} put somewhere?'
Query (0,) → Answer 2: 'Where did {person} put something?'
Query (1,) → Answer 0: 'Who put {object} somewhere?'
Query (1,) → Answer 2: 'Where was {object} put?'
Query (2,) → Answer 0: 'Who put something in the {location}?'
Query (2,) → Answer 1: 'What was put in the {location}?'
Query (0, 1) → Answer 2: 'Where did {person} put {object}?'
Query (0, 2) → Answer 1: 'What did {person} put in the {location}?'
Query (1, 2) → Answer 0: 'Who put {object} in the {location}?'
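Each of the nine patterns above pairs a non-empty, proper subset of query roles with one of the remaining roles as the answer. The action config above defines all nine. The sketch below enumerates such patterns for any number of roles with `itertools.combinations` and compares the count against the closed form $N(2^{N-1} - 1)$, which comes from summing $\binom{N}{k}(N-k)$ over $k = 1, \dots, N-1$. `query_patterns` is a hypothetical helper.

```python
from itertools import combinations

def query_patterns(n_roles):
    """Hypothetical helper: all (query_indices, answer_index) pairs that query a
    non-empty proper subset of roles and answer with one role outside that subset."""
    patterns = []
    for k in range(1, n_roles):
        for query_indices in combinations(range(n_roles), k):
            for answer_index in range(n_roles):
                if answer_index not in query_indices:
                    patterns.append((query_indices, answer_index))
    return patterns

for n in (2, 3, 4):
    # Columns: number of roles, enumerated count, closed form n * (2 ** (n - 1) - 1)
    print(n, len(query_patterns(n)), n * (2 ** (n - 1) - 1))
# → 2 2 2
#   3 9 9
#   4 28 28
```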


In [None]:
# Create causal model for action configuration
action_causal_model = create_positional_entity_causal_model(action_config)

print("=== 3-Entity Causal Model ===")
print(f"Model ID: {action_causal_model.id}")
print(f"Total variables: {len(action_causal_model.variables)}")

# Show entity variables for action model
action_entity_vars = [v for v in action_causal_model.variables if v.startswith("entity_")]
print(f"\nEntity variables ({len(action_entity_vars)} total):")
for g in range(action_config.max_groups):
    group_vars = [f"entity_g{g}_e{e}" for e in range(action_config.max_entities_per_group)]
    roles = [action_config.entity_roles.get(e, f"role_{e}") for e in range(action_config.max_entities_per_group)]
    print(f"  Group {g}: {list(zip(group_vars, roles))}")

print("\n=== Sample Action Generations ===")
for i in range(3):
    try:
        trace = sample_valid_entity_binding_input(action_config, action_causal_model)
        
        expected_answer = trace[f"entity_g{trace['query_group']}_e{trace['answer_index']}"]
        
        print(f"\nAction Example {i + 1}:")
        print(f"  Query pattern: {trace['query_indices']} → {trace['answer_index']}")
        print(f"  Full prompt: '{trace['raw_input']}'")
        print(f"  Expected answer: '{expected_answer}'")
        print(f"  Actual answer: '{trace['raw_output']}'")
        print(f"  Correct: {'✓' if trace['raw_output'] == expected_answer else '✗'}")
    except Exception as e:
        print(f"\nAction Example {i + 1}: Generation failed - {e}")

print("\n=== Comparison: 2-Entity vs 3-Entity Complexity ===")
print("Love system (2 entities):")
print(f"  - {len(love_config.question_templates)} question templates")
print(f"  - Query patterns: {list(love_config.question_templates.keys())}")

print("\nAction system (3 entities):")
print(f"  - {len(action_config.question_templates)} question templates")
print(f"  - Query patterns: {list(action_config.question_templates.keys())}")

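print("\nKey scaling insight: With N entity roles, you can have:")
print('- N*(N-1) single-entity queries: N query roles, each with N-1 possible answer roles ("What did Pete put somewhere?")')
print('- N*(N-1)/2 pairs of query roles, each with N-2 possible answer roles ("What did Pete put in the cup?")')
print("For N = 3 that gives 6 + 3 = 9 templates, matching the action config's question templates.")
print("The template system handles this complexity systematically!")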

=== 3-Entity Causal Model ===
Model ID: entity_binding_positional_entity_3g_3e
Total variables: 34

Entity variables (9 total):
  Group 0: [('entity_g0_e0', 'person'), ('entity_g0_e1', 'object'), ('entity_g0_e2', 'location')]
  Group 1: [('entity_g1_e0', 'person'), ('entity_g1_e1', 'object'), ('entity_g1_e2', 'location')]
  Group 2: [('entity_g2_e0', 'person'), ('entity_g2_e1', 'object'), ('entity_g2_e2', 'location')]

=== Sample Action Generations ===

Action Example 1:
  Query pattern: (0,) → 2
  Full prompt: 'Pete put phone in the box and Bob put jam in the bag. Where did Bob put something?'
  Expected answer: 'bag'
  Actual answer: 'bag'
  Correct: ✓

Action Example 2:
  Query pattern: (0,) → 2
  Full prompt: 'Dan put phone in the drawer and Ann put pen in the cup. Where did Dan put something?'
  Expected answer: 'drawer'
  Actual answer: 'drawer'
  Correct: ✓

Action Example 3:
  Query pattern: (0,) → 1
  Full prompt: 'Bob put coin in the drawer, Ann put phone in the basket, and S

## What We've Built

This system provides a flexible foundation for entity binding experiments:

### Core Components:
1. **Data Structures**: Represent up to max_groups entity groups, each holding up to max_entities_per_group entities with fixed roles
2. **Template System**: Convert abstract bindings to concrete text
3. **Causal Model**: Represent the task as a causal graph

### Next Steps:
1. **Counterfactual Generation**: Create strategic input pairs for testing
2. **Token Position Functions**: Locate entities in neural network inputs
3. **Integration**: Connect to intervention experiment framework