# Testing Interchange Interventions on Positional Entity Variables

This notebook demonstrates:
1. How interchange interventions work on the positional entity causal model
2. The difference between causal model interventions and neural network interventions
3. Why positional_entity variables are trivial but still useful for testing neural networks

## Key Insight

The `positional_entity_g{g}_e{e}` variables in the causal model are **trivial** - they always return the group index (0 or 1). However, this doesn't mean the neural network experiment is meaningless! The neural network might store rich positional information at these token positions, even though the causal model's abstraction is simple.
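
To make "trivial" concrete, here is a minimal sketch of such a mechanism. It mirrors the lambda printed in Step 5 of this notebook; the factory name `make_positional_entity` is hypothetical and is not part of causalab.

```python
def make_positional_entity(group_index):
    """Hypothetical factory: the returned mechanism ignores which entity
    fills the slot and reports only the group index (None for an empty slot)."""
    def mechanism(entity_val):
        return group_index if entity_val is not None else None
    return mechanism


pos_g0_e1 = make_positional_entity(0)
pos_g1_e1 = make_positional_entity(1)

# The value depends only on the slot, never on the entity placed in it.
values = [pos_g0_e1("pie"), pos_g0_e1("bread"), pos_g1_e1("pie"), pos_g1_e1(None)]
print(values)
```

Because the value is fixed by the slot, the original and the counterfactual always agree on each variable when no intervention is applied.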

In [1]:
# Setup
import sys

sys.path.append("..")

from causalab.tasks.entity_binding.config import create_sample_love_config
from causalab.tasks.entity_binding.causal_models import create_positional_entity_causal_model
from causalab.tasks.entity_binding.counterfactuals import swap_query_group
from causalab.causal.counterfactual_dataset import CounterfactualDataset

## Step 1: Create Task Configuration and Causal Model

In [2]:
# Create config
config = create_sample_love_config()
config.max_groups = 2
config.prompt_prefix = "We will ask a question about the following sentences.\n\n"
config.statement_question_separator = "\n\n"
config.prompt_suffix = "\nAnswer:"
config.fixed_query_indices = (0,)

print("Task configuration:")
print(f"  Max groups: {config.max_groups}")
print(f"  Entities per group: {config.max_entities_per_group}")
print(f"  Query indices: FIXED to {config.fixed_query_indices}")

# Create positional entity causal model
causal_model = create_positional_entity_causal_model(config)
print(f"\nCausal model: {causal_model.id}")

Task configuration:
  Max groups: 2
  Entities per group: 2
  Query indices: FIXED to (0,)

Causal model: entity_binding_positional_entity_2g_2e


## Step 2: Generate a Counterfactual Pair

We'll use `swap_query_group()` which swaps entity groups between original and counterfactual.
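
As a rough illustration of what such a pair looks like, the sketch below reverses the group order so that the same question keeps the same answer but the answer sits in the other group. The helper `make_swapped_pair` and the dictionary layout are assumptions made for this example; they are not the implementation of `swap_query_group()`.

```python
def make_swapped_pair(groups, query_group):
    """Hypothetical helper: reverse the group order so the queried entity
    (and its answer) move to the other position."""
    original = {"groups": list(groups), "query_group": query_group}
    counterfactual = {
        "groups": list(reversed(groups)),
        "query_group": len(groups) - 1 - query_group,
    }
    return original, counterfactual


def answer(sample):
    # Each group is (subject, object); the answer is the queried group's object.
    return sample["groups"][sample["query_group"]][1]


original, counterfactual = make_swapped_pair(
    [("Pete", "pie"), ("Ann", "bread")], query_group=1
)
print(original["query_group"], answer(original))
print(counterfactual["query_group"], answer(counterfactual))
```

In this toy pair the answer string is identical in both samples while the queried group index differs, which is the property the rest of the notebook relies on.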

In [None]:
# Generate a counterfactual pair
example = swap_query_group(config)

input_sample = example["input"]
counterfactual_sample = example["counterfactual_inputs"][0]

print("Original input:")
print(f"  Entity g0_e0: {input_sample['entity_g0_e0']}")
print(f"  Entity g0_e1: {input_sample['entity_g0_e1']}")
print(f"  Entity g1_e0: {input_sample['entity_g1_e0']}")
print(f"  Entity g1_e1: {input_sample['entity_g1_e1']}")
print(f"  query_group: {input_sample['query_group']}")
print(f"  query_e0: {input_sample['query_e0']}, query_e1: {input_sample['query_e1']}")
print(f"  Prompt: {input_sample['raw_input']}")

print("\nCounterfactual input:")
print(f"  Entity g0_e0: {counterfactual_sample['entity_g0_e0']}")
print(f"  Entity g0_e1: {counterfactual_sample['entity_g0_e1']}")
print(f"  Entity g1_e0: {counterfactual_sample['entity_g1_e0']}")
print(f"  Entity g1_e1: {counterfactual_sample['entity_g1_e1']}")
print(f"  query_group: {counterfactual_sample['query_group']}")
print(f"  query_e0: {counterfactual_sample['query_e0']}, query_e1: {counterfactual_sample['query_e1']}")
print(f"  Prompt: {counterfactual_sample['raw_input']}")

print("\nNote: The entity groups are swapped between original and counterfactual!")
print("The query_group and query_e{e} variables track which group/entities are being queried.")

## Step 3: Run Causal Model Without Intervention

First, let's see what the causal model computes for both inputs without any intervention.

In [4]:
# Run model on original
print("Original input (no intervention):")
original_output = causal_model.run_forward(input_sample)
print(f"  positional_entity_g0_e1: {original_output['positional_entity_g0_e1']}")
print(f"  positional_entity_g1_e1: {original_output['positional_entity_g1_e1']}")
print(f"  positional_answer: {original_output['positional_answer']}")
print(f"  raw_output: {original_output['raw_output']}")

# Run model on counterfactual
print("\nCounterfactual input (no intervention):")
counter_output = causal_model.run_forward(counterfactual_sample)
print(f"  positional_entity_g0_e1: {counter_output['positional_entity_g0_e1']}")
print(f"  positional_entity_g1_e1: {counter_output['positional_entity_g1_e1']}")
print(f"  positional_answer: {counter_output['positional_answer']}")
print(f"  raw_output: {counter_output['raw_output']}")

print("\nObservation: positional_entity variables are always 0 and 1!")
print(
    "They just return the group index, regardless of what entities are in the groups."
)

Original input (no intervention):
  positional_entity_g0_e1: 0
  positional_entity_g1_e1: 1
  positional_answer: 1
  raw_output: bread

Counterfactual input (no intervention):
  positional_entity_g0_e1: 0
  positional_entity_g1_e1: 1
  positional_answer: 0
  raw_output: bread

Observation: positional_entity variables are always 0 and 1!
They just return the group index, regardless of what entities are in the groups.


## Step 4: Perform Interchange Intervention

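Now we'll perform an interchange intervention using arrow syntax (`target<-source`), where the source value is read from the counterfactual run and written into the target variable of the original run:
- `positional_entity_g0_e1 <- positional_entity_g1_e1` (from counterfactual)
- `positional_entity_g1_e1 <- positional_entity_g0_e1` (from counterfactual)

Note that the `run_interchange` call in the cell below also applies the same cross-group swap to the `e0` variables, which are not listed in `target_variables`.

This tests: "What happens if we swap the positional entity values?"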
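
The mechanics of a `target<-source` interchange can be sketched on a toy model. Everything below (the variable names `pos_g0`, `pos_g1`, `answer_pos`, and both functions) is an assumption made for illustration and is much simpler than causalab's model; it only shows how a value read from the counterfactual run overrides a variable in the original run.

```python
def run_forward(inputs, overrides=None):
    """Toy two-group model: positional variables feed the answer lookup."""
    overrides = overrides or {}
    v = dict(inputs)
    for name, default in (("pos_g0", 0), ("pos_g1", 1)):
        # An override replaces the mechanism's own (trivial) value.
        v[name] = overrides.get(name, default)
    # The queried group's positional value selects which group to read from.
    v["answer_pos"] = v[f"pos_g{v['query_group']}"]
    v["output"] = v[f"attr_g{v['answer_pos']}"]
    return v


def run_interchange(original, interventions):
    """Each key is 'target<-source': read `source` from the counterfactual
    run, then force `target` to that value in the original run."""
    overrides = {}
    for key, counterfactual in interventions.items():
        target, source = key.split("<-")
        overrides[target] = run_forward(counterfactual)[source]
    return run_forward(original, overrides)


original = {"attr_g0": "pie", "attr_g1": "bread", "query_group": 1}
counterfactual = {"attr_g0": "bread", "attr_g1": "pie", "query_group": 0}

base = run_forward(original)
swapped = run_interchange(
    original,
    {"pos_g0<-pos_g1": counterfactual, "pos_g1<-pos_g0": counterfactual},
)
print(base["output"], swapped["output"])
```

In this toy, the cross-variable swap changes the values even though every positional variable is trivial, because the target receives the value of a different variable.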

In [14]:
# Define target variables
target_variables = [
    "positional_entity_g0_e1<-positional_entity_g1_e1",
    "positional_entity_g1_e1<-positional_entity_g0_e1",
]

print(f"Target variables: {target_variables}")
print("\nThis means:")
print(
    "  - Take positional_entity_g1_e1 from counterfactual → patch into positional_entity_g0_e1 of original"
)
print(
    "  - Take positional_entity_g0_e1 from counterfactual → patch into positional_entity_g1_e1 of original"
)

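# Perform interchange intervention.
# Note: this call swaps all four positional_entity variables (e0 and e1),
# not only the two e1 variables listed in target_variables above.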
intervened_output = causal_model.run_interchange(
    input_sample,
    {
        "positional_entity_g0_e1<-positional_entity_g1_e1": counterfactual_sample,
        "positional_entity_g1_e1<-positional_entity_g0_e1": counterfactual_sample,
        "positional_entity_g0_e0<-positional_entity_g1_e0": counterfactual_sample,
        "positional_entity_g1_e0<-positional_entity_g0_e0": counterfactual_sample,
    },
)

for variable in intervened_output:
    print(f"{variable}: {intervened_output[variable]}")

Target variables: ['positional_entity_g0_e1<-positional_entity_g1_e1', 'positional_entity_g1_e1<-positional_entity_g0_e1']

This means:
  - Take positional_entity_g1_e1 from counterfactual → patch into positional_entity_g0_e1 of original
  - Take positional_entity_g0_e1 from counterfactual → patch into positional_entity_g1_e1 of original
entity_g0_e0: Pete
entity_g0_e1: pie
entity_g1_e0: Ann
entity_g1_e1: bread
query_group: 1
query_indices: (0,)
answer_index: 1
active_groups: 2
entities_per_group: 2
statement_template: {entity_e0} loves {entity_e1}
positional_entity_g0_e0: 1
positional_entity_g0_e1: 1
positional_entity_g1_e0: 0
positional_entity_g1_e1: 0
question_template: What does {query_entity} love?
positional_query_e0: (0,)
positional_query_e1: ()
positional_answer: 0
raw_input: We will ask a question about the following sentences.

Pete loves pie, and Ann loves bread.

What does Ann love?
Answer:
raw_output: pie


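## Step 5: Understanding When the Intervention Has an Effect

Absent an intervention, the positional_entity variables are **deterministic functions** that always return their group index, so the original and the counterfactual always agree on them. Patching a variable with its own counterfactual value would therefore be a no-op. The cross-variable swap is different: the Step 4 output shows the patched values reversed (the `g0` variables hold 1, the `g1` variables hold 0), and because that call swapped both the `e0` and `e1` variables, `positional_answer` and `raw_output` changed as well. Swapping only the two `e1` variables in `target_variables` leaves the output unchanged, which Step 6 checks on a dataset.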

In [None]:
print("Understanding positional_entity variables:")
print("\nIn the causal model mechanism:")
print(
    "  positional_entity_g0_e1 = lambda entity_val: 0 if entity_val is not None else None"
)
print(
    "  positional_entity_g1_e1 = lambda entity_val: 1 if entity_val is not None else None"
)
print("\nThey ALWAYS return their group index (0 or 1)!")

print("\nKey causal model variables:")
print("  - query_group: Which group contains the query entity (input variable)")
print("  - query_e0, query_e1: The actual query entities from that group (input variables)")
print("  - positional_entity_g{g}_e{e}: Group index for each entity position (trivial)")
print("  - positional_answer: Which group to retrieve from (computed from search)")

print("\nIntervention effect on values:")
print(
    f"  Original:   g0_e1={original_output['positional_entity_g0_e1']}, g1_e1={original_output['positional_entity_g1_e1']}"
)
print(
    f"  Counter:    g0_e1={counter_output['positional_entity_g0_e1']}, g1_e1={counter_output['positional_entity_g1_e1']}"
)
print(
    f"  Intervened: g0_e1={intervened_output['positional_entity_g0_e1']}, g1_e1={intervened_output['positional_entity_g1_e1']}"
)
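# Report the comparison instead of assuming the result.
values_changed = any(
    intervened_output[name] != original_output[name]
    for name in ("positional_entity_g0_e1", "positional_entity_g1_e1")
)
if values_changed:
    print(
        "\nChanged: each variable now holds the other group's index from the counterfactual."
    )
else:
    print("\nNo change: the patched values equal the original values.")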

# Check downstream effects
print("\nDownstream effects:")
if intervened_output["positional_answer"] != original_output["positional_answer"]:
    print(
        f"  positional_answer changed: {original_output['positional_answer']} -> {intervened_output['positional_answer']}"
    )
else:
    print(f"  positional_answer unchanged: {original_output['positional_answer']}")

if intervened_output["raw_output"] != original_output["raw_output"]:
    print(
        f"  raw_output changed: '{original_output['raw_output']}' -> '{intervened_output['raw_output']}'"
    )
else:
    print(f"  raw_output unchanged: '{original_output['raw_output']}'")

## Step 6: Test Distinguishability on a Dataset

Let's generate a small dataset and test if the intervention is distinguishable from no intervention.
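
The idea behind a distinguishability test can be sketched as counting the pairs on which two intervention settings yield different outputs. The function below and its toy interventions are assumptions for illustration; only the `count` and `proportion` keys are modeled on the result dictionary used in this notebook, and the sketch does not reproduce `can_distinguish_with_dataset`.

```python
def count_distinguishable(pairs, run_a, run_b):
    """Count pairs where two intervention settings give different outputs."""
    count = sum(1 for orig, cf in pairs if run_a(orig, cf) != run_b(orig, cf))
    return {"count": count, "proportion": count / len(pairs) if pairs else 0.0}


# Toy setting: the output is the attribute stored in the queried group.
def no_intervention(orig, cf):
    return orig["attrs"][orig["query_group"]]


def patch_unused_variable(orig, cf):
    # Patching a variable the output never reads cannot change the output.
    return orig["attrs"][orig["query_group"]]


def patch_query_group(orig, cf):
    # Patching the queried group from the counterfactual redirects the lookup.
    return orig["attrs"][cf["query_group"]]


pairs = [
    ({"attrs": ["pie", "bread"], "query_group": 1},
     {"attrs": ["bread", "pie"], "query_group": 0}),
    ({"attrs": ["tea", "jam"], "query_group": 0},
     {"attrs": ["jam", "tea"], "query_group": 1}),
]
r1 = count_distinguishable(pairs, patch_unused_variable, no_intervention)
r2 = count_distinguishable(pairs, patch_unused_variable, patch_query_group)
print(r1, r2)
```

A proportion of 0 means the two settings are indistinguishable on this dataset, not that they are the same intervention on every possible input.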

In [7]:
# Generate dataset
print("Generating dataset of 16 counterfactual pairs...")
dataset = CounterfactualDataset.from_sampler(
    16, lambda: swap_query_group(config), id="swap_query_group_test"
)
print(f"✓ Generated {len(dataset)} pairs")

# Test distinguishability from no intervention
print("\nTest 1: Can we distinguish the intervention from NO intervention?")
print(f"  Variables: {target_variables}")
print("  Comparing: intervention on these variables vs. None")

result = causal_model.can_distinguish_with_dataset(dataset, target_variables, None)

print("\nResult:")
print(f"  Distinguishable examples: {result['count']}/{len(dataset)}")
print(f"  Proportion: {result['proportion']:.2%}")

if result["proportion"] == 0:
    print("\n✓ As expected: 0% distinguishable!")
    print(
        "  The intervention has no effect because positional_entity values are always 0 and 1."
    )

Generating dataset of 16 counterfactual pairs...
✓ Generated 16 pairs

Test 1: Can we distinguish the intervention from NO intervention?
  Variables: ['positional_entity_g0_e1<-positional_entity_g1_e1', 'positional_entity_g1_e1<-positional_entity_g0_e1']
  Comparing: intervention on these variables vs. None
Can distinguish between ['positional_entity_g0_e1<-positional_entity_g1_e1', 'positional_entity_g1_e1<-positional_entity_g0_e1'] and None: 0 out of 16 examples
Proportion of distinguishable examples: 0.00

Result:
  Distinguishable examples: 0/16
  Proportion: 0.00%

✓ As expected: 0% distinguishable!
  The intervention has no effect because positional_entity values are always 0 and 1.


## Step 7: Test Distinguishability from Another Variable

Even though swapping the `e1` positional_entity variables does not change the causal model's output, that intervention should still be distinguishable from intervening on other variables like `positional_answer`.

In [8]:
print("Test 2: Can we distinguish from positional_answer intervention?")

result2 = causal_model.can_distinguish_with_dataset(
    dataset, target_variables, ["positional_answer"]
)

print("\nResult:")
print(f"  Distinguishable examples: {result2['count']}/{len(dataset)}")
print(f"  Proportion: {result2['proportion']:.2%}")

if result2["proportion"] == 1.0:
    print("\n✓ 100% distinguishable!")
    print(
        "  This confirms positional_entity and positional_answer are different variables."
    )

Test 2: Can we distinguish from positional_answer intervention?
Can distinguish between ['positional_entity_g0_e1<-positional_entity_g1_e1', 'positional_entity_g1_e1<-positional_entity_g0_e1'] and ['positional_answer']: 16 out of 16 examples
Proportion of distinguishable examples: 1.00

Result:
  Distinguishable examples: 16/16
  Proportion: 100.00%

✓ 100% distinguishable!
  This confirms positional_entity and positional_answer are different variables.


## Key Takeaways

### 1. Causal Model Behavior

- **positional_entity variables are trivial**: Without intervention they always return their group index (0 or 1)
- **Cross-variable swaps still change the values**: In Step 4 the `g0` variables became 1 and the `g1` variables became 0, and swapping `e0` and `e1` together changed `raw_output` from bread to pie
- **0% distinguishable from no intervention**: Swapping only the two `e1` variables left the causal model's output unchanged on all 16 examples
- **100% distinguishable from positional_answer**: The two interventions produced different results on all 16 examples

### 2. Why This Still Matters for Neural Networks

Even though the causal model's positional_entity variables are trivial, the neural network experiment is meaningful because:

1. **The causal model is an abstraction**: It captures the logical structure, not the actual computation
2. **Neural representations need not be trivial**: The network might store rich positional information at the e1 token positions
3. **The intervention tests a hypothesis**: Does the network encode "which group this entity belongs to" at the entity token positions?
4. **Token positions matter**: We're intervening at specific locations (last token of e1 entities) in the neural network

### 3. Causal Abstraction Claim

The experiment tests whether:
- The neural network's activations at positions `[g0_e1_last_token, g1_e1_last_token]`
- Are causally aligned with the causal model's `positional_entity_g0_e1` and `positional_entity_g1_e1` variables
- Even though those causal variables are trivial (always 0, 1), the neural representations might be non-trivial
- If swapping those neural representations affects retrieval, it suggests the network is encoding positional information there
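
The claim above can be illustrated with a toy stand-in for a network. This is only a sketch under strong assumptions: the "hidden states" are small dictionaries rather than real activations, `OBJECT_POSITIONS` is an assumed pair of token indices, and `encode` and `readout` are hypothetical functions. A real experiment would patch activations inside an actual model.

```python
OBJECT_POSITIONS = (2, 6)  # assumed last-token positions of the two e1 entities


def encode(tokens):
    """Toy 'layer': each hidden state holds the token and a group tag
    (the number of 'and' tokens seen before it)."""
    hidden, group = [], 0
    for tok in tokens:
        hidden.append({"tok": tok, "group": group})
        if tok == "and":
            group += 1
    return hidden


def readout(hidden):
    """Toy 'head': find the queried name's group tag, then return the
    object whose hidden state carries the same tag."""
    query = hidden[-1]["tok"]
    query_group = next(h["group"] for h in hidden[:-1] if h["tok"] == query)
    return next(
        hidden[p]["tok"] for p in OBJECT_POSITIONS if hidden[p]["group"] == query_group
    )


original = ["Pete", "loves", "pie", "and", "Ann", "loves", "bread", "?", "Ann"]
counterfactual = ["Ann", "loves", "bread", "and", "Pete", "loves", "pie", "?", "Ann"]

hidden = encode(original)
cf_hidden = encode(counterfactual)

# Cross-position interchange: each object position receives the hidden
# state found at the other object position in the counterfactual run.
patched = list(hidden)
patched[2], patched[6] = cf_hidden[6], cf_hidden[2]

print(readout(hidden), readout(patched))
```

Here the readout changes only because the patched states carry a group tag that the readout consumes; if the toy layer stored no group information at those positions, the same swap would leave the output untouched.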

### 4. Experimental Design Insight

This is an example of testing an **intermediate representation hypothesis**:
- We're not testing if the network computes 0 or 1 at those positions
- We're testing if those positions **causally influence** downstream retrieval
- The token positions act as "probes" into the network's internal computations
- Swapping representations tests if they carry information that affects the output