# Counterfactuals and Positional Variables

This notebook demonstrates two key concepts for entity binding experiments:

1. **Counterfactual datasets** - How to generate counterfactuals by swapping entity groups
2. **Positional causal models** - Two different hypotheses about how retrieval works

## The Core Question

When a language model sees "Pete loves jam, and Ann loves pie. What does Ann love?", how does it retrieve the answer?

**Hypothesis 1 (Direct)**: The model learns that "group 1" (Ann's group) is at a specific position in its representation, and directly retrieves from that position.

**Hypothesis 2 (Positional)**: The model searches for "Ann" in its stored bindings, finds where Ann appears, then retrieves the associated value from that location.

We can test these hypotheses using **counterfactual experiments** where we swap entity groups.
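
As a rough illustration of the two hypotheses, the sketch below writes each one as a tiny retrieval function over a list of bindings. The helpers `direct_retrieve` and `positional_retrieve` are hypothetical names invented for this example and are not part of `causalab`; the real causal models are created in the next cell.

```python
# Toy illustration only: these helpers are hypothetical, not the causalab API.
bindings = [("Pete", "jam"), ("Ann", "pie")]  # one (person, food) tuple per group


def direct_retrieve(bindings, query_group, answer_index=1):
    """Hypothesis 1: read the answer straight from a known group index."""
    return bindings[query_group][answer_index]


def positional_retrieve(bindings, query_entity, answer_index=1):
    """Hypothesis 2: search for the queried entity, then read from its group."""
    for group in bindings:
        if group[0] == query_entity:
            return group[answer_index]
    return None  # the queried entity is not bound anywhere


direct_answer = direct_retrieve(bindings, query_group=1)
positional_answer = positional_retrieve(bindings, query_entity="Ann")
print(direct_answer, positional_answer)
```

On an unmodified input both functions return the same value, which is why interventions are needed to tell the hypotheses apart.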

In [16]:
# Setup
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")

from causalab.tasks.entity_binding.config import create_sample_love_config
from causalab.tasks.entity_binding.causal_models import (
    create_direct_causal_model,
    create_positional_causal_model,
)
from causalab.tasks.entity_binding.counterfactual import swap_query_group, random_counterfactual



## Part 1: Understanding the Two Causal Models

Let's start by seeing how the two models work on the same input.

In [17]:
# Create both models
config = create_sample_love_config()
direct_model = create_direct_causal_model(config)
positional_model = create_positional_causal_model(config)

print("Created two causal models:")
print(f"  Direct model:     {direct_model.id}")
print(f"  Positional model: {positional_model.id}")

Created two causal models:
  Direct model:     entity_binding_direct_3g_2e
  Positional model: entity_binding_positional_3g_2e


In [18]:
# Create a specific example
input_sample = {
    "entity_g0_e0": "Pete",
    "entity_g0_e1": "jam",
    "entity_g1_e0": "Ann",
    "entity_g1_e1": "pie",
    "entity_g2_e0": "Bob",
    "entity_g2_e1": "cake",
    "query_group": 1,
    "query_indices": (0,),
    "answer_index": 1,
    "active_groups": 3,
    "entities_per_group": 2,
}

print("Input configuration:")
print("  Entities: g0=(Pete, jam), g1=(Ann, pie), g2=(Bob, cake)")
print(
    f"  Query: group {input_sample['query_group']}, entity {input_sample['query_indices'][0]}"
)
print(
    f"  Answer: entity at group {input_sample['query_group']}, position {input_sample['answer_index']}"
)

Input configuration:
  Entities: g0=(Pete, jam), g1=(Ann, pie), g2=(Bob, cake)
  Query: group 1, entity 0
  Answer: entity at group 1, position 1


In [19]:
# Run both models
direct_output = direct_model.run_forward(input_sample)
positional_output = positional_model.run_forward(input_sample)

print("\nPrompt generated:")
print(f"  {direct_output['raw_input']}")
print()

print("Direct model (index-based):")
print(f"  Mechanism: Uses query_group={input_sample['query_group']} directly")
print(f"  Retrieves: entity_g1_e1 = {input_sample['entity_g1_e1']}")
print(f"  Answer: {direct_output['raw_output']}")
print()

print("Positional model (search-based):")
print(
    f"  Step 1: Extract query entity from g1_e0 = {positional_output['query_entity']}"
)
print(
    f"  Step 2: Search for {positional_output['query_entity']} at position e0 across groups"
)
print(f"  Step 3: Found at group {positional_output['positional_query_group']}")
print(
    f"  Step 4: Retrieve from g{positional_output['positional_query_group']}_e{input_sample['answer_index']}"
)
print(f"  Answer: {positional_output['raw_output']}")
print()

print(
    "On this input, both models agree:",
    direct_output["raw_output"] == positional_output["raw_output"],
)


Prompt generated:
  Pete loves jam, Ann loves pie, and Bob loves cake. What does Ann love?

Direct model (index-based):
  Mechanism: Uses query_group=1 directly
  Retrieves: entity_g1_e1 = pie
  Answer: pie

Positional model (search-based):
  Step 1: Extract query entity from g1_e0 = ('Ann',)
  Step 2: Search for ('Ann',) at position e0 across groups
  Step 3: Found at group 1
  Step 4: Retrieve from g1_e1
  Answer: pie

On this input, both models agree: True


## Part 2: What Happens When We Swap Groups?

Now let's swap entity groups and see how the models behave.

In [20]:
# Swap groups 1 and 2
swapped_sample = input_sample.copy()
swapped_sample["entity_g1_e0"] = "Bob"  # g1 gets g2's entities
swapped_sample["entity_g1_e1"] = "cake"
swapped_sample["entity_g2_e0"] = "Ann"  # g2 gets g1's entities
swapped_sample["entity_g2_e1"] = "pie"
# query_group STAYS 1

print("After swapping groups 1 and 2:")
print("  Entities: g0=(Pete, jam), g1=(Bob, cake), g2=(Ann, pie)")
print(
    f"  Query: still group {swapped_sample['query_group']}, entity {swapped_sample['query_indices'][0]}"
)
print("  Now querying Bob (who moved from g2 to g1)")

After swapping groups 1 and 2:
  Entities: g0=(Pete, jam), g1=(Bob, cake), g2=(Ann, pie)
  Query: still group 1, entity 0
  Now querying Bob (who moved from g2 to g1)


In [21]:
# Run both models on swapped input
direct_swapped = direct_model.run_forward(swapped_sample)
positional_swapped = positional_model.run_forward(swapped_sample)

print("Prompt generated:")
print(f"  {direct_swapped['raw_input']}")
print()

print("Direct model:")
print("  Still uses query_group=1 directly")
print(f"  Retrieves: entity_g1_e1 = {swapped_sample['entity_g1_e1']}")
print(f"  Answer: {direct_swapped['raw_output']}")
print()

print("Positional model:")
print(f"  Extracts query entity from g1_e0: {positional_swapped['query_entity']}")
print(f"  Searches for {positional_swapped['query_entity']} at position e0")
print(f"  Found at group: {positional_swapped['positional_query_group']}")
print(f"  Retrieves from g{positional_swapped['positional_query_group']}_e1")
print(f"  Answer: {positional_swapped['raw_output']}")
print()

print(
    "After swap, both models still agree:",
    direct_swapped["raw_output"] == positional_swapped["raw_output"],
)
print("(Because Bob is at g1 in the swapped configuration)")

Prompt generated:
  Pete loves jam, Bob loves cake, and Ann loves pie. What does Bob love?

Direct model:
  Still uses query_group=1 directly
  Retrieves: entity_g1_e1 = cake
  Answer: cake

Positional model:
  Extracts query entity from g1_e0: ('Bob',)
  Searches for ('Bob',) at position e0
  Found at group: 1
  Retrieves from g1_e1
  Answer: cake

After swap, both models still agree: True
(Because Bob is at g1 in the swapped configuration)


### Why Do They Still Agree?

The two **symbolic** causal models give the same answer on any ordinary input, because both simply compute what the correct answer should be.

The difference appears when we test **neural networks** with interventions:

- Suppose we intervene on the neural representation of "group 1" so that it encodes Bob's binding instead of Ann's, while the question still asks about Ann:
  - **Direct model prediction**: The answer is read from whatever is now in group 1, so the output should be "cake".
  - **Positional model prediction**: The model searches for the entity named in the question (Ann). Group 1 no longer holds Ann, so the search may resolve to a different group or fail, and the output need not be "cake".

The models represent different hypotheses we're testing about the neural network's mechanism.
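
The divergence described above can be mimicked on a toy binding table. This is only a sketch under simplifying assumptions: the "intervention" is simulated by overwriting one slot of a dictionary, and `direct_prediction` and `positional_prediction` are hypothetical helpers, not `causalab` functions.

```python
# Hypothetical toy code, not the causalab API.
clean = {0: ("Pete", "jam"), 1: ("Ann", "pie"), 2: ("Bob", "cake")}

# Simulated intervention: slot 1 now holds Bob's binding,
# while the question still names Ann.
patched = dict(clean)
patched[1] = clean[2]


def direct_prediction(slots, query_group):
    """Read the answer from a fixed group index."""
    return slots[query_group][1]


def positional_prediction(slots, query_entity):
    """Search all slots for the queried entity; None if it is not found."""
    matches = [g for g, (person, _) in sorted(slots.items()) if person == query_entity]
    return slots[matches[0]][1] if matches else None


direct_pred = direct_prediction(patched, query_group=1)
positional_pred = positional_prediction(patched, query_entity="Ann")
print(direct_pred, positional_pred)
```

In this toy setting the direct lookup returns Bob's food, while the search for Ann finds nothing. A real network could behave differently in the failure case; the point is only that the two hypotheses make different predictions.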

## Part 3: Generating Counterfactual Datasets

For intervention experiments, we need pairs of (input, counterfactual) to test the models.

In [22]:
# Generate a counterfactual using swap_query_group
example = swap_query_group(config)

input_ex = example["input"]
counter_ex = example["counterfactual_inputs"][0]

print("Counterfactual pair generated by swap_query_group:")
print()
print("INPUT:")
print(f"  Prompt: {input_ex['raw_input']}")
print(f"  Query group: {input_ex['query_group']}")
print()
print("COUNTERFACTUAL:")
print(f"  Prompt: {counter_ex['raw_input']}")
print(f"  Query group: {counter_ex['query_group']} (the query follows its group to the new position)")

Counterfactual pair generated by swap_query_group:

INPUT:
  Prompt: Bob loves jam, and Tim loves bread. Who loves jam?
  Query group: 0

COUNTERFACTUAL:
  Prompt: Tim loves bread, and Bob loves jam. Who loves jam?
  Query group: 1 (the query follows its group to the new position)


In [23]:
# Show what changed
print("\nWhat changed between input and counterfactual:\n")

query_group = input_ex["query_group"]
print(f"Entities in query group {query_group}:")
for e in range(config.max_entities_per_group):
    key = f"entity_g{query_group}_e{e}"
    input_val = input_ex.get(key)
    counter_val = counter_ex.get(key)
    changed = "<-- CHANGED" if input_val != counter_val else ""
    print(f"  Position e{e}:")
    print(f"    Input:          {input_val}")
    print(f"    Counterfactual: {counter_val} {changed}")


What changed between input and counterfactual:

Entities in query group 0:
  Position e0:
    Input:          Bob
    Counterfactual: Tim <-- CHANGED
  Position e1:
    Input:          jam
    Counterfactual: bread <-- CHANGED
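
Judging only from the pairs printed above, the swap appears to exchange the queried group with another group while the question itself stays the same, so `query_group` follows the entities. The sketch below reproduces that behaviour on plain dictionaries; `toy_swap_query_group` is a hypothetical stand-in, and the actual `swap_query_group` may differ in how it samples and what it returns.

```python
import random


def toy_swap_query_group(sample, num_groups, rng, entities_per_group=2):
    """Swap the queried group with another group; the question stays the same."""
    q = sample["query_group"]
    other = rng.choice([g for g in range(num_groups) if g != q])
    counterfactual = dict(sample)
    for e in range(entities_per_group):
        counterfactual[f"entity_g{q}_e{e}"] = sample[f"entity_g{other}_e{e}"]
        counterfactual[f"entity_g{other}_e{e}"] = sample[f"entity_g{q}_e{e}"]
    counterfactual["query_group"] = other  # the query follows its entities
    return counterfactual


sample = {
    "entity_g0_e0": "Bob", "entity_g0_e1": "jam",
    "entity_g1_e0": "Tim", "entity_g1_e1": "bread",
    "query_group": 0,
}
counterfactual = toy_swap_query_group(sample, num_groups=2, rng=random.Random(0))
print(counterfactual)
```

With only two groups there is a single possible swap partner, so the result matches the pair shown above regardless of the random seed.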


## Part 4: Interchange Interventions

Like the MCQA example, we can perform interchange interventions to test causal variable localization.

In [24]:
# Perform interchange intervention on query_entity
# Reuse input_sample as the base and build a counterfactual that queries a different group

counterfactual_sample = {
    "entity_g0_e0": "Tim",
    "entity_g0_e1": "soup",
    "entity_g1_e0": "Ann",  # Same entity as input_sample at this position
    "entity_g1_e1": "bread",
    "entity_g2_e0": "Pete",
    "entity_g2_e1": "tea",
    "query_group": 2,
    "query_indices": (0,),
    "answer_index": 1,
    "active_groups": 3,
    "entities_per_group": 2,
}

print("Original input:")
print(f"  Entity at g1_e0: {input_sample['entity_g1_e0']}")
original_output = positional_model.run_forward(input_sample)
print(f"  query_entity: {original_output['query_entity']}")
print(f"  positional_query_group: {original_output['positional_query_group']}")
print(f"  raw_output: {original_output['raw_output']}")
print()

print("Counterfactual input:")
print(f"  Entity at g1_e0: {counterfactual_sample['entity_g1_e0']}")
counter_output = positional_model.run_forward(counterfactual_sample)
print(f"  query_entity: {counter_output['query_entity']}")
print(f"  positional_query_group: {counter_output['positional_query_group']}")
print(f"  raw_output: {counter_output['raw_output']}")
print()

# Interchange intervention on query_entity
intervened = positional_model.run_interchange(
    input_sample, {"query_entity": counterfactual_sample}
)

print("Intervened output (using query_entity from counterfactual):")
print(f"  query_entity: {intervened['query_entity']}")
print(f"  positional_query_group: {intervened['positional_query_group']}")
print(f"  raw_output: {intervened['raw_output']}")

# Second intervention: patch positional_query_group instead (result is not printed here)
intervened_group = positional_model.run_interchange(
    input_sample, {"positional_query_group": counterfactual_sample}
)

Original input:
  Entity at g1_e0: Ann
  query_entity: ('Ann',)
  positional_query_group: 1
  raw_output: pie

Counterfactual input:
  Entity at g1_e0: Ann
  query_entity: ('Pete',)
  positional_query_group: 2
  raw_output: tea

Intervened output (using query_entity from counterfactual):
  query_entity: ('Pete',)
  positional_query_group: 0
  raw_output: jam
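
The intervened output above can be traced by hand with a toy re-implementation of the positional model. This is a sketch under stated assumptions: `run_positional` and `toy_interchange` are hypothetical helpers that mirror the variable names in the output, entities are plain strings rather than tuples, and none of this is the `causalab` implementation.

```python
# Toy re-implementation for illustration; not the causalab API.
def run_positional(sample, overrides=None):
    """Compute query_entity -> positional_query_group -> raw_output, honoring overrides."""
    overrides = overrides or {}

    # Variable 1: the entity named in the question.
    query_entity = overrides.get(
        "query_entity", sample[f"entity_g{sample['query_group']}_e0"]
    )

    # Variable 2: the group where that entity is found (None if absent).
    if "positional_query_group" in overrides:
        group = overrides["positional_query_group"]
    else:
        group = next(
            (g for g in range(sample["active_groups"])
             if sample[f"entity_g{g}_e0"] == query_entity),
            None,
        )

    # Output: the answer stored in the located group.
    answer_key = f"entity_g{group}_e{sample['answer_index']}"
    raw_output = None if group is None else sample[answer_key]
    return {"query_entity": query_entity, "positional_query_group": group,
            "raw_output": raw_output}


def toy_interchange(base, counterfactual, variable):
    """Take one variable's value from the counterfactual run and recompute on the base."""
    donor_value = run_positional(counterfactual)[variable]
    return run_positional(base, {variable: donor_value})


shared = {"query_indices": (0,), "answer_index": 1, "active_groups": 3}
base = {"entity_g0_e0": "Pete", "entity_g0_e1": "jam", "entity_g1_e0": "Ann",
        "entity_g1_e1": "pie", "entity_g2_e0": "Bob", "entity_g2_e1": "cake",
        "query_group": 1, **shared}
counterfactual = {"entity_g0_e0": "Tim", "entity_g0_e1": "soup", "entity_g1_e0": "Ann",
                  "entity_g1_e1": "bread", "entity_g2_e0": "Pete", "entity_g2_e1": "tea",
                  "query_group": 2, **shared}

via_entity = toy_interchange(base, counterfactual, "query_entity")
via_group = toy_interchange(base, counterfactual, "positional_query_group")
print(via_entity)
print(via_group)
```

Patching `query_entity` makes the base input search for Pete, who sits in group 0 of the base, so the answer becomes "jam". Patching `positional_query_group` in this toy version skips the search and reads group 2 of the base directly, which gives "cake".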


## Part 5: Generate Counterfactual Dataset

Now let's generate a full dataset of counterfactual pairs for testing.

In [25]:
from causalab.causal.counterfactual_dataset import CounterfactualDataset

config.num_groups = 3

# Generate 64 counterfactual pairs using swap_query_group
swap_dataset = CounterfactualDataset.from_sampler(64, lambda: swap_query_group(config))

print(f"Generated {len(swap_dataset)} counterfactual pairs")
print(f"  Input:  {swap_dataset[0]['input']['raw_input']}")
print(f"  Counter: {swap_dataset[0]['counterfactual_inputs'][0]['raw_input']}")

Generated 64 counterfactual pairs
  Input:  Kate loves tea, Ann loves pie, and Bob loves cake. What does Kate love?
  Counter: Ann loves pie, Kate loves tea, and Bob loves cake. What does Kate love?


## Part 6: Testing Distinguishability

We can test if this counterfactual dataset can distinguish between different variables.
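
One plausible reading of "distinguishable" is that, for a given pair, intervening on the first variable yields a different output than intervening on the second (or than not intervening at all). The sketch below counts such pairs for a toy lookup model. This criterion is an assumption inferred from the printed results, and every name here is hypothetical; it is not the implementation of `can_distinguish_with_dataset`.

```python
# Hypothetical sketch of a distinguishability check; not the causalab implementation.
def lookup(bindings, name):
    """Return the food bound to `name`, or None if the name is absent."""
    for person, food in bindings:
        if person == name:
            return food
    return None


def output_with_query_entity_swapped(base, cf):
    # Search the base bindings for the counterfactual's queried entity.
    return lookup(base["bindings"], cf["query"])


def output_with_raw_output_swapped(base, cf):
    # Patching the final output simply yields the counterfactual's answer.
    return lookup(cf["bindings"], cf["query"])


def toy_can_distinguish(pairs, output_a, output_b):
    """Count pairs on which the two interventions predict different outputs."""
    count = sum(1 for base, cf in pairs if output_a(base, cf) != output_b(base, cf))
    return {"proportion": count / len(pairs), "count": count}


base = {"bindings": [("Bob", "jam"), ("Tim", "bread")], "query": "Bob"}
pairs = [
    # Swap-style pair: same entities reordered, so both interventions give "jam".
    (base, {"bindings": [("Tim", "bread"), ("Bob", "jam")], "query": "Bob"}),
    # Unrelated pair: "bread" versus "soup", so the interventions disagree.
    (base, {"bindings": [("Tim", "soup"), ("Sue", "pie")], "query": "Tim"}),
]
result = toy_can_distinguish(
    pairs, output_with_query_entity_swapped, output_with_raw_output_swapped
)
print(result)
```

A pair contributes to the count only when the two interventions disagree, so the proportion depends on which variables are compared and on how the counterfactuals were sampled.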

In [26]:
# Test if we can distinguish query_entity from positional_query_group
print("Can we distinguish query_entity from positional_query_group?")
positional_model.can_distinguish_with_dataset(
    swap_dataset, ["query_entity"], ["positional_query_group"]
)

Can we distinguish query_entity from positional_query_group?
Can distinguish between ['query_entity'] and ['positional_query_group']: 64 out of 64 examples
Proportion of distinguishable examples: 1.00


{'proportion': 1.0, 'count': 64}

In [27]:
# Test if we can distinguish positional_query_group from raw_output
print("Can we distinguish positional_query_group from raw_output?")
positional_model.can_distinguish_with_dataset(
    swap_dataset, ["positional_query_group"], ["raw_output"]
)

Can we distinguish positional_query_group from raw_output?
Can distinguish between ['positional_query_group'] and ['raw_output']: 64 out of 64 examples
Proportion of distinguishable examples: 1.00


{'proportion': 1.0, 'count': 64}

In [28]:
# Test if we can distinguish positional_query_group from no intervention
print("Can we distinguish positional_query_group from no intervention?")
positional_model.can_distinguish_with_dataset(
    swap_dataset, ["positional_query_group"], None
)

Can we distinguish positional_query_group from no intervention?
Can distinguish between ['positional_query_group'] and None: 64 out of 64 examples
Proportion of distinguishable examples: 1.00


{'proportion': 1.0, 'count': 64}

## Part 7: Random Counterfactuals (Baseline)

Let's compare with random counterfactuals to see the difference in discrimination power.

In [29]:
# Generate random counterfactual dataset
random_dataset = CounterfactualDataset.from_sampler(
    64, lambda: random_counterfactual(config)
)

print(f"Generated {len(random_dataset)} random counterfactual pairs")
print("Example pair (completely independent):")
print(f"  Input:  {random_dataset[0]['input']['raw_input']}")
print(f"  Counter: {random_dataset[0]['counterfactual_inputs'][0]['raw_input']}")

Generated 64 random counterfactual pairs
Example pair (completely independent):
  Input:  Tim loves cake, and Sue loves bread. What does Sue love?
  Counter: Sue loves soup, and Bob loves bread. What does Bob love?


In [30]:
# Test distinguishability with random counterfactuals
print("With random counterfactuals:")
print("query_entity vs positional_query_group:")
positional_model.can_distinguish_with_dataset(
    random_dataset, ["query_entity"], ["positional_query_group"]
)
print("positional_query_group vs raw_output:")
positional_model.can_distinguish_with_dataset(
    random_dataset, ["positional_query_group"], ["raw_output"]
)
print("query_entity vs no intervention:")
positional_model.can_distinguish_with_dataset(random_dataset, ["query_entity"], None)

With random counterfactuals:
query_entity vs positional_query_group:
Can distinguish between ['query_entity'] and ['positional_query_group']: 56 out of 64 examples
Proportion of distinguishable examples: 0.88
positional_query_group vs raw_output:
Can distinguish between ['positional_query_group'] and ['raw_output']: 58 out of 64 examples
Proportion of distinguishable examples: 0.91
query_entity vs no intervention:
Can distinguish between ['query_entity'] and None: 60 out of 64 examples
Proportion of distinguishable examples: 0.94


{'proportion': 0.9375, 'count': 60}