# Entity Binding Task Implementation Walkthrough

This notebook demonstrates the step-by-step implementation of a flexible entity binding task system for causal abstraction experiments. Entity binding tasks test how language models understand relationships between entities in structured contexts.

## What are Entity Binding Tasks?

Entity binding tasks involve:
1. **Entity Groups**: Sets of entities that appear together (e.g., "Pete, jam" or "Ann, pie")
2. **Templates**: Structured text that describes relationships (e.g., "X loves Y")
3. **Queries**: Questions about specific entities or relationships

Example: "Pete loves jam, and Ann loves pie. What does Pete love?"

The goal is to understand how neural networks represent and manipulate these entity-relationship bindings.

## Step 1: Core Data Structures

First, we need flexible data structures that can handle arbitrary numbers of entity groups and entities per group.

In [1]:
# autoreload imports
%load_ext autoreload
%autoreload 2

# Setup imports
import sys

sys.path.append("..")

from causalab.tasks.entity_binding.config import (
    EntityGroup,
    BindingMatrix,
    create_sample_love_config,
    create_sample_action_config,
)

### EntityBindingTaskConfig: The Blueprint

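This configuration class defines the whole task: how many groups and entities are allowed, what role each entity position plays, which pools entities are drawn from, and the statement and question templates. The sample "love" configuration shows each field: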

In [2]:
# Create a sample "love" configuration
love_config = create_sample_love_config()

print("Love Task Configuration:")
print(f"Max groups: {love_config.max_groups}")
print(f"Max entities per group: {love_config.max_entities_per_group}")
print(f"Entity roles: {love_config.entity_roles}")
print(f"People: {love_config.entity_pools[0]}")
print(f"Foods: {love_config.entity_pools[1]}")
print(f"Statement template: {love_config.statement_template}")
print(f"Question templates: {love_config.question_templates}")

Love Task Configuration:
Max groups: 3
Max entities per group: 2
Entity roles: {0: 'person', 1: 'food'}
People: ['Pete', 'Ann', 'Tim', 'Bob', 'Sue', 'Kate']
Foods: ['jam', 'pie', 'cake', 'bread', 'soup', 'tea']
Statement template: {entity_e0} loves {entity_e1}
Question templates: {((0,), 1): 'What does {query_entity} love?', ((1,), 0): 'Who loves {query_entity}?'}
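
To make the printed fields concrete, here is a minimal sketch of a container with the same field names as the output above. The class name `SketchConfig` and the helper `query_patterns` are hypothetical, and storing `entity_pools` as a dict keyed by role index is an assumption; the real `EntityBindingTaskConfig` in `causalab` may be defined differently.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

# (query_indices, answer_index), e.g. ((0,), 1) = "query the person, answer the food"
QueryPattern = Tuple[Tuple[int, ...], int]


@dataclass
class SketchConfig:
    """Hypothetical stand-in; field names mirror the attributes printed above."""

    max_groups: int
    max_entities_per_group: int
    entity_roles: Dict[int, str]
    entity_pools: Dict[int, List[str]]
    statement_template: str
    question_templates: Dict[QueryPattern, str] = field(default_factory=dict)

    def query_patterns(self) -> List[QueryPattern]:
        """Return the query patterns that have a question template."""
        return list(self.question_templates.keys())


sketch_love_config = SketchConfig(
    max_groups=3,
    max_entities_per_group=2,
    entity_roles={0: "person", 1: "food"},
    entity_pools={
        0: ["Pete", "Ann", "Tim", "Bob", "Sue", "Kate"],
        1: ["jam", "pie", "cake", "bread", "soup", "tea"],
    },
    statement_template="{entity_e0} loves {entity_e1}",
    question_templates={
        ((0,), 1): "What does {query_entity} love?",
        ((1,), 0): "Who loves {query_entity}?",
    },
)

print(sketch_love_config.query_patterns())
```

Keying the question templates by `(query_indices, answer_index)` means a query pattern is valid exactly when a template exists for it.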


### EntityGroup and BindingMatrix: Representing Specific Instances

While the config is the blueprint, EntityGroup and BindingMatrix represent specific instances:

In [3]:
# Create specific entity groups
group1 = EntityGroup(["Pete", "jam"], group_index=0)
group2 = EntityGroup(["Ann", "pie"], group_index=1)

print("Entity Groups:")
print(f"Group 1: {group1}")
print(f"Group 2: {group2}")

# Create binding matrix
matrix = BindingMatrix(
    [group1, group2],
    max_groups=love_config.max_groups,
    max_entities_per_group=love_config.max_entities_per_group,
)

print(f"\nBinding Matrix: {matrix}")
print(f"Active groups: {matrix.active_groups}")

# Show entity access
print("\nEntity Access:")
print(f"Group 0, Entity 0 (person): {matrix.get_entity(0, 0)}")
print(f"Group 0, Entity 1 (food): {matrix.get_entity(0, 1)}")
print(f"Group 1, Entity 0 (person): {matrix.get_entity(1, 0)}")
print(f"Group 2, Entity 0 (inactive): {matrix.get_entity(2, 0)}")

Entity Groups:
Group 1: EntityGroup(['Pete', 'jam'], group_0)
Group 2: EntityGroup(['Ann', 'pie'], group_1)

Binding Matrix: BindingMatrix([EntityGroup(['Pete', 'jam'], group_0), EntityGroup(['Ann', 'pie'], group_1)], active=2)
Active groups: 2

Entity Access:
Group 0, Entity 0 (person): Pete
Group 0, Entity 1 (food): jam
Group 1, Entity 0 (person): Ann
Group 2, Entity 0 (inactive): None


### Flat Dictionary Representation

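For the causal model, the binding matrix is flattened into a dictionary with standardized keys of the form `entity_g{g}_e{e}` (group `g`, entity position `e`), where inactive slots hold `None`.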
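
The notebook does not show the flattening cell, so the following is only a sketch of what such a conversion might look like. The function `flatten_bindings` is hypothetical; the key names (`entity_g{g}_e{e}`, `active_groups`, `entities_per_group`) are taken from the sampled inputs used later in this notebook, and padding inactive slots with `None` is an assumption based on the `None` checks in those cells.

```python
from typing import Any, Dict, List


def flatten_bindings(
    groups: List[List[str]],
    max_groups: int,
    max_entities_per_group: int,
) -> Dict[str, Any]:
    """Flatten entity groups into entity_g{g}_e{e} keys, padding with None."""
    if len(groups) > max_groups:
        raise ValueError("more groups than max_groups")
    flat: Dict[str, Any] = {}
    for g in range(max_groups):
        for e in range(max_entities_per_group):
            # A slot is active only if both the group and the position exist.
            active = g < len(groups) and e < len(groups[g])
            flat[f"entity_g{g}_e{e}"] = groups[g][e] if active else None
    flat["active_groups"] = len(groups)
    flat["entities_per_group"] = max((len(group) for group in groups), default=0)
    return flat


flat = flatten_bindings([["Pete", "jam"], ["Ann", "pie"]], 3, 2)
for key, value in flat.items():
    print(f"{key}: {value}")
```

A fixed key set regardless of how many groups are active is what lets one causal model cover inputs with one, two, or three groups.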

## Step 2: Template Processing System

The template system converts abstract entity bindings into concrete text.

In [4]:
from causalab.tasks.entity_binding.templates import TemplateProcessor

processor = TemplateProcessor(love_config)

### Statement Template Filling

First, let's fill the statement template with our entities:

In [5]:
# Fill statement template
statement = processor.fill_statement_template(matrix)

print("Template Filling Process:")
print(f"Template: {love_config.statement_template}")
print(f"Result: '{statement}'")

Template Filling Process:
Template: {entity_e0} loves {entity_e1}
Result: 'Pete loves jam, and Ann loves pie.'
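
The result shows that the template is filled once per active group and the clauses are then joined into one sentence. A minimal sketch of that behavior, inferred only from the outputs in this notebook (the function `fill_statement` is hypothetical and is not the library's `fill_statement_template`):

```python
from typing import List


def fill_statement(template: str, groups: List[List[str]]) -> str:
    """Render one clause per group and join them into a single sentence."""
    clauses = [
        template.format(**{f"entity_e{e}": entity for e, entity in enumerate(group)})
        for group in groups
    ]
    if not clauses:
        return ""
    if len(clauses) == 1:
        body = clauses[0]
    else:
        # "A, and B" for two clauses; "A, B, and C" for three or more.
        body = ", ".join(clauses[:-1]) + f", and {clauses[-1]}"
    return body + "."


print(fill_statement("{entity_e0} loves {entity_e1}", [["Pete", "jam"], ["Ann", "pie"]]))
```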


### Question Template Selection and Filling

Different query patterns require different templates:

In [6]:
print("Available Question Templates:")
for (query_indices, answer_idx), template in love_config.question_templates.items():
    print(f"  Query {query_indices} → Answer {answer_idx}: '{template}'")

# Example 1: Query person (index 0), expect food answer (index 1)
template1 = processor.select_question_template((0,), 1)
question1 = processor.fill_question_template(template1, 0, (0,), matrix)

print("\nExample 1 - Query Person:")
print(f"Template: '{template1}'")
print(f"Filled: '{question1}'")

# Example 2: Query food (index 1), expect person answer (index 0)
template2 = processor.select_question_template((1,), 0)
question2 = processor.fill_question_template(template2, 0, (1,), matrix)

print("\nExample 2 - Query Food:")
print(f"Template: '{template2}'")
print(f"Filled: '{question2}'")

Available Question Templates:
  Query (0,) → Answer 1: 'What does {query_entity} love?'
  Query (1,) → Answer 0: 'Who loves {query_entity}?'

Example 1 - Query Person:
Template: 'What does {query_entity} love?'
Filled: 'What does Pete love?'

Example 2 - Query Food:
Template: 'Who loves {query_entity}?'
Filled: 'Who loves jam?'
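
Template selection amounts to a dictionary lookup on the `(query_indices, answer_index)` pair, followed by placeholder substitution. The sketch below is one way this could work; `select_template` and `fill_question` are hypothetical names, and supporting both `{query_entity}` and role-named placeholders such as `{person}` is an assumption based on the templates printed in this notebook.

```python
from typing import Dict, List, Tuple

QueryPattern = Tuple[Tuple[int, ...], int]


def select_template(
    templates: Dict[QueryPattern, str],
    query_indices: Tuple[int, ...],
    answer_index: int,
) -> str:
    """Look up the template for a query pattern, failing loudly if absent."""
    key = (tuple(query_indices), answer_index)
    if key not in templates:
        raise KeyError(f"no question template for query pattern {key}")
    return templates[key]


def fill_question(
    template: str,
    group: List[str],
    query_indices: Tuple[int, ...],
    roles: Dict[int, str],
) -> str:
    """Substitute the queried entities into the template."""
    # Role-named placeholders, e.g. {person} or {object}.
    values = {roles[i]: group[i] for i in query_indices}
    # Generic placeholder for the first queried entity.
    values["query_entity"] = group[query_indices[0]]
    return template.format(**values)


love_templates = {
    ((0,), 1): "What does {query_entity} love?",
    ((1,), 0): "Who loves {query_entity}?",
}
love_roles = {0: "person", 1: "food"}

chosen = select_template(love_templates, (1,), 0)
print(fill_question(chosen, ["Pete", "jam"], (1,), love_roles))
```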


### Complete Prompt Generation

Combining statements and questions into full prompts:

In [7]:
# Generate complete prompts
prompt1, answer1 = processor.generate_full_prompt(matrix, 0, (0,), 1)
prompt2, answer2 = processor.generate_full_prompt(matrix, 0, (1,), 0)

print("Complete Prompt Examples:")
print("\n1. Query about person:")
print(f"   Prompt: '{prompt1}'")
print(f"   Answer: '{answer1}'")

print("\n2. Query about food:")
print(f"   Prompt: '{prompt2}'")
print(f"   Answer: '{answer2}'")

print("\nKey insight: Same entities, different questions → different answers")
print("This tests whether models understand entity-relationship structure")

Complete Prompt Examples:

1. Query about person:
   Prompt: 'Pete loves jam, and Ann loves pie. What does Pete love?'
   Answer: 'jam'

2. Query about food:
   Prompt: 'Pete loves jam, and Ann loves pie. Who loves jam?'
   Answer: 'Pete'

Key insight: Same entities, different questions → different answers
This tests whether models understand entity-relationship structure


### More Complex Example: Action Tasks

Let's see how this works with 3-entity groups (person, object, location):

In [8]:
# Create action configuration
action_config = create_sample_action_config()
action_processor = TemplateProcessor(action_config)

print("Action Task Configuration:")
print(f"Entity roles: {action_config.entity_roles}")
print(f"Statement template: {action_config.statement_template}")
print("Question templates:")
for (query_indices, answer_idx), template in action_config.question_templates.items():
    print(f"  Query {query_indices} → Answer {answer_idx}: '{template}'")

# Create action binding
action_group = EntityGroup(["Pete", "jam", "cup"], 0)
action_matrix = BindingMatrix(
    [action_group], action_config.max_groups, action_config.max_entities_per_group
)

# Generate action statement
action_statement = action_processor.fill_statement_template(action_matrix)
print(f"\nAction statement: '{action_statement}'")

print("\nThis shows the system's flexibility:")
print("- Same core logic works for 2-entity and 3-entity groups")
print("- Templates can be arbitrarily complex")
print("- Easy to add new entity types and relationships")

Action Task Configuration:
Entity roles: {0: 'person', 1: 'object', 2: 'location'}
Statement template: {entity_e0} put {entity_e1} in the {entity_e2}
Question templates:
  Query (0,) → Answer 1: 'What did {person} put somewhere?'
  Query (0,) → Answer 2: 'Where did {person} put something?'
  Query (1,) → Answer 0: 'Who put {object} somewhere?'
  Query (1,) → Answer 2: 'Where was {object} put?'
  Query (2,) → Answer 0: 'Who put something in the {location}?'
  Query (2,) → Answer 1: 'What was put in the {location}?'
  Query (0, 1) → Answer 2: 'Where did {person} put {object}?'
  Query (0, 2) → Answer 1: 'What did {person} put in the {location}?'
  Query (1, 2) → Answer 0: 'Who put {object} in the {location}?'

Action statement: 'Pete put jam in the cup.'

This shows the system's flexibility:
- Same core logic works for 2-entity and 3-entity groups
- Templates can be arbitrarily complex
- Easy to add new entity types and relationships


## Step 3: Causal Model Integration

Now we integrate everything into a causal model that can generate prompts and answers.

In [12]:
from causalab.tasks.entity_binding.causal_models import (
    create_direct_causal_model,
    sample_valid_entity_binding_input,
)

### Causal Model Structure

The causal model represents entity bindings as a set of causal variables:

In [13]:
# Create causal model
causal_model = create_direct_causal_model(love_config)

print(f"Causal Model: {causal_model.id}")
print(f"Total variables: {len(causal_model.variables)}")

# Show different types of variables
entity_vars = [v for v in causal_model.variables if v.startswith("entity_")]
control_vars = [
    v
    for v in causal_model.variables
    if not v.startswith("entity_") and v not in ["raw_input", "raw_output"]
]

print(f"\nEntity variables ({len(entity_vars)}): {entity_vars[:5]}...")
print(f"Control variables ({len(control_vars)}): {control_vars}")
print("Required variables: raw_input, raw_output")

print("\nKey insight: The causal model has NO intermediate variables")
print("It directly maps from entity inputs to text outputs")
print("This tests whether neural networks learn similar direct mappings")

Causal Model: entity_binding_direct_3g_2e
Total variables: 15

Entity variables (6): ['entity_g0_e0', 'entity_g0_e1', 'entity_g1_e0', 'entity_g1_e1', 'entity_g2_e0']...
Control variables (7): ['query_group', 'query_indices', 'answer_index', 'active_groups', 'entities_per_group', 'statement_template', 'question_template']
Required variables: raw_input, raw_output

Key insight: The causal model has NO intermediate variables
It directly maps from entity inputs to text outputs
This tests whether neural networks learn similar direct mappings
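
The variable count above follows directly from the configuration: one entity variable per (group, position) slot, the seven control variables, and the two required text variables. A short sketch that reproduces the count (the helper `expected_variables` is hypothetical; the variable names are copied from the output above):

```python
from typing import List

CONTROL_VARS = [
    "query_group",
    "query_indices",
    "answer_index",
    "active_groups",
    "entities_per_group",
    "statement_template",
    "question_template",
]
IO_VARS = ["raw_input", "raw_output"]


def expected_variables(max_groups: int, max_entities_per_group: int) -> List[str]:
    """List the variable names a direct model of this size would contain."""
    entity_vars = [
        f"entity_g{g}_e{e}"
        for g in range(max_groups)
        for e in range(max_entities_per_group)
    ]
    return entity_vars + CONTROL_VARS + IO_VARS


# Love config: 3 groups x 2 entities = 6 entity variables, plus 7 + 2.
print(len(expected_variables(3, 2)))  # 15
```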


### Sampling Valid Inputs

The model includes logic to sample valid configurations:

In [14]:
# Sample a valid input
sample_input = sample_valid_entity_binding_input(love_config)

print("Sample Input:")
print(f"Active groups: {sample_input['active_groups']}")
print(f"Query group: {sample_input['query_group']}")
print(
    f"Query pattern: {sample_input['query_indices']} → {sample_input['answer_index']}"
)

print("\nActive entities:")
for g in range(sample_input["active_groups"]):
    entities = []
    for e in range(sample_input["entities_per_group"]):
        entity = sample_input[f"entity_g{g}_e{e}"]
        if entity is not None:
            entities.append(f"{love_config.entity_roles[e]}: {entity}")
    print(f"  Group {g}: {entities}")

print("\nValidation ensures:")
print("- Query group is within active groups")
print("- Query and answer entities exist")
print("- Question template exists for this query pattern")

Sample Input:
Active groups: 3
Query group: 2
Query pattern: (0,) → 1

Active entities:
  Group 0: ['person: Pete', 'food: soup']
  Group 1: ['person: Kate', 'food: pie']
  Group 2: ['person: Bob', 'food: tea']

Validation ensures:
- Query group is within active groups
- Query and answer entities exist
- Question template exists for this query pattern
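
The three validation rules can be satisfied by construction rather than by rejection. The sketch below shows one way to do that; `sample_input_sketch` is a hypothetical function, not the library's `sample_valid_entity_binding_input`. Drawing entities without replacement and allowing a single active group are both assumptions.

```python
import random
from typing import Any, Dict, List, Tuple

QueryPattern = Tuple[Tuple[int, ...], int]


def sample_input_sketch(
    pools: List[List[str]],
    templates: Dict[QueryPattern, str],
    max_groups: int,
    rng: random.Random,
) -> Dict[str, Any]:
    """Sample a flat input that satisfies the three validation rules."""
    n_roles = len(pools)
    active_groups = rng.randint(1, max_groups)
    # Assumption: entities are drawn without replacement within each role,
    # so no entity name appears in two groups.
    columns = [rng.sample(pool, active_groups) for pool in pools]

    sample: Dict[str, Any] = {}
    for g in range(max_groups):
        for e in range(n_roles):
            sample[f"entity_g{g}_e{e}"] = columns[e][g] if g < active_groups else None

    # Only patterns that have a question template are eligible.
    query_indices, answer_index = rng.choice(list(templates))
    sample.update(
        active_groups=active_groups,
        entities_per_group=n_roles,
        query_group=rng.randrange(active_groups),  # always an active group
        query_indices=query_indices,
        answer_index=answer_index,
    )
    return sample


pools = [
    ["Pete", "Ann", "Tim", "Bob", "Sue", "Kate"],
    ["jam", "pie", "cake", "bread", "soup", "tea"],
]
templates = {
    ((0,), 1): "What does {query_entity} love?",
    ((1,), 0): "Who loves {query_entity}?",
}
example = sample_input_sketch(pools, templates, max_groups=3, rng=random.Random(0))
print(example)
```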


### Running the Causal Model

Let's run the model to generate prompts and answers:

In [15]:
# Run the causal model forward with the sampled input as an intervention
# This ensures the model uses our specific entity values rather than random sampling
output = causal_model.run_forward(sample_input)

print("Causal Model Output:")
print(f"Generated prompt: '{output['raw_input']}'")
print(f"Expected answer: '{output['raw_output']}'")

# Verify consistency
query_entity_key = (
    f"entity_g{sample_input['query_group']}_e{sample_input['query_indices'][0]}"
)
answer_entity_key = (
    f"entity_g{sample_input['query_group']}_e{sample_input['answer_index']}"
)

print("\nConsistency check:")
print(f"Query entity: {sample_input[query_entity_key]}")
print(f"Answer entity: {sample_input[answer_entity_key]}")
print(
    f"Output matches answer entity: {output['raw_output'] == sample_input[answer_entity_key]}"
)

print("\nNote: We pass sample_input as an intervention to ensure")
print("the model uses our specific entity values rather than random sampling")

Causal Model Output:
Generated prompt: 'Pete loves soup, Kate loves pie, and Bob loves tea. What does Bob love?'
Expected answer: 'tea'

Consistency check:
Query entity: Bob
Answer entity: tea
Output matches answer entity: True

Note: We pass sample_input as an intervention to ensure
the model uses our specific entity values rather than random sampling
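
Because the model has no intermediate variables, both outputs can be written as plain functions of the input dictionary. The following sketch illustrates that idea for the love task; the two mechanism functions are hypothetical and only mimic the behavior seen in the output above, not the internals of `create_direct_causal_model`.

```python
from typing import Any, Dict


def raw_input_mechanism(v: Dict[str, Any]) -> str:
    """Build the prompt text directly from the flat input variables."""
    clauses = []
    for g in range(v["active_groups"]):
        slots = {
            f"entity_e{e}": v[f"entity_g{g}_e{e}"]
            for e in range(v["entities_per_group"])
        }
        clauses.append(v["statement_template"].format(**slots))
    if len(clauses) == 1:
        statement = clauses[0]
    else:
        statement = ", ".join(clauses[:-1]) + f", and {clauses[-1]}"
    queried = v[f"entity_g{v['query_group']}_e{v['query_indices'][0]}"]
    question = v["question_template"].format(query_entity=queried)
    return f"{statement}. {question}"


def raw_output_mechanism(v: Dict[str, Any]) -> str:
    """The answer is the entity at (query_group, answer_index)."""
    return v[f"entity_g{v['query_group']}_e{v['answer_index']}"]


# Values copied from the sampled input shown earlier in this notebook.
sample = {
    "entity_g0_e0": "Pete", "entity_g0_e1": "soup",
    "entity_g1_e0": "Kate", "entity_g1_e1": "pie",
    "entity_g2_e0": "Bob", "entity_g2_e1": "tea",
    "active_groups": 3, "entities_per_group": 2,
    "query_group": 2, "query_indices": (0,), "answer_index": 1,
    "statement_template": "{entity_e0} loves {entity_e1}",
    "question_template": "What does {query_entity} love?",
}
print(raw_input_mechanism(sample))
print(raw_output_mechanism(sample))
```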


### Multiple Examples

Let's generate several examples to see the variety:

In [16]:
print("Multiple Example Generations:")
print("=" * 50)

for i in range(5):
    # Sample new input configuration
    sample = sample_valid_entity_binding_input(love_config)
    # Run model with sample as intervention to use specific entities
    output = causal_model.run_forward(sample)

    # Extract key information
    query_group = sample["query_group"]
    entities = []
    for e in range(love_config.max_entities_per_group):
        entity = sample[f"entity_g{query_group}_e{e}"]
        if entity is not None:
            entities.append(entity)

    print(f"\nExample {i + 1}:")
    print(f"  Entities in query group: {entities}")
    print(f"  Query pattern: {sample['query_indices']} → {sample['answer_index']}")
    print(f"  Prompt: '{output['raw_input']}'")
    print(f"  Answer: '{output['raw_output']}'")

    # Verify answer is correct
    expected_answer = sample[f"entity_g{query_group}_e{sample['answer_index']}"]
    print(
        f"  Expected: '{expected_answer}' ✓"
        if output["raw_output"] == expected_answer
        else f"  Expected: '{expected_answer}' ✗"
    )

print("\nKey observations:")
print("- Same model generates different question types")
print("- Consistent entity-relationship mappings")
print("- Flexible handling of different entity combinations")
print("- Answers are now correct because we pass sample as intervention")

Multiple Example Generations:

Example 1:
  Entities in query group: ['Sue', 'tea']
  Query pattern: (1,) → 0
  Prompt: 'Sue loves tea, and Tim loves pie. Who loves tea?'
  Answer: 'Sue'
  Expected: 'Sue' ✓

Example 2:
  Entities in query group: ['Ann', 'bread']
  Query pattern: (1,) → 0
  Prompt: 'Ann loves bread, Pete loves cake, and Sue loves tea. Who loves bread?'
  Answer: 'Ann'
  Expected: 'Ann' ✓

Example 3:
  Entities in query group: ['Ann', 'tea']
  Query pattern: (0,) → 1
  Prompt: 'Ann loves tea, Tim loves jam, and Sue loves pie. What does Ann love?'
  Answer: 'tea'
  Expected: 'tea' ✓

Example 4:
  Entities in query group: ['Sue', 'jam']
  Query pattern: (0,) → 1
  Prompt: 'Sue loves jam, Tim loves cake, and Bob loves soup. What does Sue love?'
  Answer: 'jam'
  Expected: 'jam' ✓

Example 5:
  Entities in query group: ['Pete', 'soup']
  Query pattern: (1,) → 0
  Prompt: 'Pete loves soup, and Bob loves tea. Who loves soup?'
  Answer: 'Pete'
  Expected: 'Pete' ✓

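Key observations:
- Same model generates different question types
- Consistent entity-relationship mappings
- Flexible handling of different entity combinations
- Answers are now correct because we pass sample as intervention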

### Action System with Causal Model Integration

Now let's see how the 3-entity system works, first at the template level and then with the full causal model:

### Detailed 3-Entity Action System Demonstration

Let's explore the action system more thoroughly to understand how it handles multiple query patterns with 3 entities:

In [17]:
# Create multiple action scenarios with different entities
print("=== 3-Entity Action System Deep Dive ===")
print(f"Entity roles: {action_config.entity_roles}")
print("Entity pools:")
for role_idx, role_name in action_config.entity_roles.items():
    print(f"  {role_name} ({role_idx}): {action_config.entity_pools[role_idx]}")

# Create several different action scenarios
action_scenarios = [
    ["Pete", "jam", "cup"],
    ["Ann", "water", "box"],
    ["Bob", "book", "table"],
]

print("\n=== Different Action Scenarios ===")
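for i, entities in enumerate(action_scenarios):
    # Each scenario is a standalone matrix holding one group, so the group
    # sits at index 0 (matching the get_entity(0, ...) lookups below)
    group = EntityGroup(entities, 0)
    matrix = BindingMatrix(
        [group], action_config.max_groups, action_config.max_entities_per_group
    )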

    statement = action_processor.fill_statement_template(matrix)
    print(
        f"\nScenario {i + 1}: {dict(zip(['person', 'object', 'location'], entities))}"
    )
    print(f"Statement: '{statement}'")

    # Generate different types of questions for this scenario
    print("Questions:")
    for (
        query_indices,
        answer_idx,
    ), template in action_config.question_templates.items():
        try:
            filled_question = action_processor.fill_question_template(
                template, 0, query_indices, matrix
            )
            answer_entity = matrix.get_entity(0, answer_idx)
            print(
                f"  Query {query_indices} → Answer {answer_idx}: '{filled_question}' → '{answer_entity}'"
            )
        except Exception as e:
            print(
                f"  Query {query_indices} → Answer {answer_idx}: Template issue - {e}"
            )

print("\n=== Key Insights about 3-Entity System ===")
print("1. **Richer Query Patterns**: With 3 entities, we can ask about:")
print("   - What did [person] put in [location]? (query person + location → object)")
print("   - Who put [object] in [location]? (query object + location → person)")
print("   - Where did [person] put [object]? (query person + object → location)")
print(
    "2. **Multiple Context Elements**: Questions can reference 2 entities to ask about the 3rd"
)
print(
    "3. **Template Complexity**: Each query pattern needs its own natural language template"
)
print(
    "4. **Scalability**: Same framework works for N entities (person, object, location, time, etc.)"
)

=== 3-Entity Action System Deep Dive ===
Entity roles: {0: 'person', 1: 'object', 2: 'location'}
Entity pools:
  person (0): ['Pete', 'Ann', 'Bob', 'Sue', 'Tim', 'Kate', 'Dan', 'Lily']
  object (1): ['jam', 'water', 'book', 'coin', 'pen', 'key', 'phone', 'watch']
  location (2): ['cup', 'box', 'table', 'shelf', 'drawer', 'bag', 'pocket', 'basket']

=== Different Action Scenarios ===

Scenario 1: {'person': 'Pete', 'object': 'jam', 'location': 'cup'}
Statement: 'Pete put jam in the cup.'
Questions:
  Query (0,) → Answer 1: 'What did Pete put somewhere?' → 'jam'
  Query (0,) → Answer 2: 'Where did Pete put something?' → 'cup'
  Query (1,) → Answer 0: 'Who put jam somewhere?' → 'Pete'
  Query (1,) → Answer 2: 'Where was jam put?' → 'cup'
  Query (2,) → Answer 0: 'Who put something in the cup?' → 'Pete'
  Query (2,) → Answer 1: 'What was put in the cup?' → 'jam'
  Query (0, 1) → Answer 2: 'Where did Pete put jam?' → 'cup'
  Query (0, 2) → Answer 1: 'What did Pete put in the cup?' → 'jam'
 

In [19]:
# Create causal model for action configuration
action_config.max_groups = 4
action_causal_model = create_direct_causal_model(action_config)

print("=== 3-Entity Causal Model ===")
print(f"Model ID: {action_causal_model.id}")
print(f"Total variables: {len(action_causal_model.variables)}")

# Show entity variables for action model
action_entity_vars = [
    v for v in action_causal_model.variables if v.startswith("entity_")
]
print(f"\nEntity variables pattern ({len(action_entity_vars)} total):")
for g in range(action_config.max_groups):
    group_vars = [
        f"entity_g{g}_e{e}" for e in range(action_config.max_entities_per_group)
    ]
    roles = [
        action_config.entity_roles.get(e, f"role_{e}")
        for e in range(action_config.max_entities_per_group)
    ]
    print(f"  Group {g}: {list(zip(group_vars, roles))}")

print("\n=== Sample Action Generations ===")
# Generate several action examples
for i in range(3):
    try:
        sample = sample_valid_entity_binding_input(action_config)
        # IMPORTANT: Pass sample as intervention to use specific entity values
        output = action_causal_model.run_forward(sample)

        # Collect every active entity, grouped by role; inactive groups hold None
        entities = []
        for e in range(action_config.max_entities_per_group):
            for g in range(action_config.max_groups):
                entity = sample[f"entity_g{g}_e{e}"]
                if entity is not None:
                    role = action_config.entity_roles[e]
                    entities.append(f"{role}: {entity}")

        # Get expected answer
        expected_answer = sample[
            f"entity_g{sample['query_group']}_e{sample['answer_index']}"
        ]

        print(f"\nAction Example {i + 1}:")
        print(f"  Entities: {entities}")
        print(f"  Query pattern: {sample['query_indices']} → {sample['answer_index']}")
        print(f"  Full prompt: '{output['raw_input']}'")
        print(f"  Expected answer: '{expected_answer}'")
        print(f"  Actual answer: '{output['raw_output']}'")
        print(f"  Correct: {'✓' if output['raw_output'] == expected_answer else '✗'}")

    except Exception as e:
        print(f"\nAction Example {i + 1}: Generation failed - {e}")
        print("  (This may be due to template issues)")

print("\n=== Comparison: 2-Entity vs 3-Entity Complexity ===")
print("Love system (2 entities):")
print(f"  - {len(love_config.question_templates)} question templates")
print(f"  - {love_config.max_entities_per_group} entities per group")
print(f"  - Query patterns: {list(love_config.question_templates.keys())}")

print("\nAction system (3 entities):")
print(f"  - {len(action_config.question_templates)} question templates")
print(f"  - {action_config.max_entities_per_group} entities per group")
print(f"  - Query patterns: {list(action_config.question_templates.keys())}")

print("\nKey scaling insight: With N entity roles, you can potentially have:")
print('- N*(N-1) single-entity query patterns ("What did Pete put somewhere?")')
print('- C(N,2)*(N-2) two-entity query patterns ("What did Pete put in the cup?")')
print("- And so on for higher-order combinations")
print("The template system handles this complexity systematically!")

print("\nIMPORTANT: When using the causal model with specific entities,")
print("always pass the entity configuration as an intervention to run_forward().")

=== 3-Entity Causal Model ===
Model ID: entity_binding_direct_4g_3e
Total variables: 21

Entity variables pattern (12 total):
  Group 0: [('entity_g0_e0', 'person'), ('entity_g0_e1', 'object'), ('entity_g0_e2', 'location')]
  Group 1: [('entity_g1_e0', 'person'), ('entity_g1_e1', 'object'), ('entity_g1_e2', 'location')]
  Group 2: [('entity_g2_e0', 'person'), ('entity_g2_e1', 'object'), ('entity_g2_e2', 'location')]
  Group 3: [('entity_g3_e0', 'person'), ('entity_g3_e1', 'object'), ('entity_g3_e2', 'location')]

=== Sample Action Generations ===

Action Example 1:
  Entities: ['person: Bob', 'person: Ann', 'object: watch', 'object: water', 'location: table', 'location: box']
  Query pattern: (0,) → 1
  Full prompt: 'Bob put watch in the table, and Ann put water in the box. What did Bob put somewhere?'
  Expected answer: 'watch'
  Actual answer: 'watch'
  Correct: ✓

Action Example 2:
  Entities: ['person: Tim', 'person: Bob', 'person: Ann', 'object: watch', 'object: water', 'object: p
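
The action templates listed earlier cover six single-entity patterns and three two-entity patterns. Treating a query pattern as an unordered set of queried roles plus one answer role outside that set, the count can be checked by enumeration. The helper `enumerate_query_patterns` below is a hypothetical illustration, not part of `causalab`.

```python
from itertools import combinations
from typing import List, Tuple

QueryPattern = Tuple[Tuple[int, ...], int]


def enumerate_query_patterns(n_roles: int, query_size: int) -> List[QueryPattern]:
    """All (query_indices, answer_index) pairs with the answer outside the query."""
    patterns = []
    for query_indices in combinations(range(n_roles), query_size):
        for answer_index in range(n_roles):
            if answer_index not in query_indices:
                patterns.append((query_indices, answer_index))
    return patterns


single = enumerate_query_patterns(3, 1)
double = enumerate_query_patterns(3, 2)
print(len(single), len(double))  # 6 3
print(double)
```

In general there are C(N, k) * (N - k) patterns that query k of N roles, so each additional role increases the number of templates a configuration may need.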

## What We've Built

This system provides a flexible foundation for entity binding experiments:

### Core Components:
1. **Data Structures**: Handle arbitrary entity groups and relationships
2. **Template System**: Convert abstract bindings to concrete text  
3. **Causal Model**: Represent the task as a causal graph

### Key Design Decisions:
- **Scalable**: One model handles variable numbers of groups/entities
- **Flexible**: Easy to add new templates and entity types
- **Testable**: Clear separation between structure and content

### Next Steps:
1. **Counterfactual Generation**: Create strategic input pairs for testing
2. **Token Position Functions**: Locate entities in neural network inputs
3. **Integration**: Connect to intervention experiment framework
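
As a rough preview of the first next step, one simple kind of counterfactual pair keeps every entity fixed and changes only which group is queried, so the expected answer changes while the statement stays the same. This is a sketch under that assumption; `make_query_group_counterfactual` and `answer_of` are hypothetical helpers, and the actual counterfactual strategies may differ.

```python
from typing import Any, Dict


def answer_of(v: Dict[str, Any]) -> str:
    """Expected answer: the entity at (query_group, answer_index)."""
    return v[f"entity_g{v['query_group']}_e{v['answer_index']}"]


def make_query_group_counterfactual(
    base: Dict[str, Any], new_query_group: int
) -> Dict[str, Any]:
    """Copy the base input and point the query at a different active group."""
    if not 0 <= new_query_group < base["active_groups"]:
        raise ValueError("new_query_group must be an active group")
    counterfactual = dict(base)  # shallow copy; values are immutable here
    counterfactual["query_group"] = new_query_group
    return counterfactual


base = {
    "entity_g0_e0": "Pete", "entity_g0_e1": "soup",
    "entity_g1_e0": "Kate", "entity_g1_e1": "pie",
    "entity_g2_e0": "Bob", "entity_g2_e1": "tea",
    "active_groups": 3, "entities_per_group": 2,
    "query_group": 2, "query_indices": (0,), "answer_index": 1,
}
counterfactual = make_query_group_counterfactual(base, 0)
print(answer_of(base), "->", answer_of(counterfactual))
```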