# Counterfactuals and Positional Variables

This notebook demonstrates two key concepts for entity binding experiments:

1. **Counterfactual datasets** - How to generate counterfactuals by swapping entity groups
2. **Positional causal models** - How the positional model searches for query entities and retrieves answers

## The Core Question

When a language model sees "Pete loves jam, and Ann loves pie. What does Ann love?", how does it retrieve the answer?

The **positional model** hypothesis: the model searches its stored bindings for "Ann", finds the position (entity group) where Ann appears, and then retrieves the associated value from that position.

We can test this hypothesis using **counterfactual experiments** where we swap entity groups and perform interchange interventions.
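
To make the hypothesis concrete, here is a minimal pure-Python sketch of positional lookup on the example sentence. It is a toy, not the causalab implementation: the `groups` list, the hypothetical `positional_lookup` helper, and the choice of `query_indices=[0]` / `answer_index=1` are illustrative assumptions that mirror the variable names used later in this notebook.

```python
# Toy sketch of positional retrieval (hypothetical helper, not causalab's API).
# Each group is one binding from the prompt: (person, food).
groups = [("Pete", "jam"), ("Ann", "pie")]


def positional_lookup(groups, query, query_indices, answer_index):
    """Return (answer_group, answer) for a query such as {0: "Ann"}.

    `query` maps an entity position e to the entity queried at that position.
    """
    # Step 1: for each queried position e, find every group whose e-th entity matches.
    matches = [
        {g for g, group in enumerate(groups) if group[e] == query[e]}
        for e in query_indices
    ]
    # Step 2: the answer group is the intersection of those match sets.
    candidates = set.intersection(*matches)
    if len(candidates) != 1:
        raise ValueError("query must identify exactly one group")
    answer_group = candidates.pop()
    # Step 3: read the answer entity from that group.
    return answer_group, groups[answer_group][answer_index]


# "What does Ann love?" queries position 0 (the person) and asks for position 1.
print(positional_lookup(groups, {0: "Ann"}, query_indices=[0], answer_index=1))
# → (1, 'pie')
```

The same helper answers the reverse question ("Who loves pie?") by querying position 1 and reading position 0.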

In [None]:
# Setup
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("..")

from causalab.tasks.entity_binding.config import create_sample_love_config
from causalab.tasks.entity_binding.causal_models import (
    create_positional_entity_causal_model,
    sample_valid_entity_binding_input,
)
from causalab.tasks.entity_binding.counterfactuals import swap_query_group, random_counterfactual
from causalab.causal.causal_utils import generate_counterfactual_samples

## Part 1: Understanding the Positional Causal Model

Let's start by seeing how the positional model works on an input sample.

In [None]:
# Create the positional entity causal model
config = create_sample_love_config()
causal_model = create_positional_entity_causal_model(config)

print(f"Created causal model: {causal_model.id}")
print("\nModel variables:")
print(f"  Inputs: {causal_model.inputs[:8]}...")  # First 8 inputs
print(f"  Total inputs: {len(causal_model.inputs)}")

In [None]:
# Sample a valid input and run the model
input_sample = sample_valid_entity_binding_input(config, model=causal_model)

print("Input configuration:")
print(f"  Active groups: {input_sample['active_groups']}")
print(f"  Query group: {input_sample['query_group']}")
print(f"  Query indices: {input_sample['query_indices']}")
print(f"  Answer index: {input_sample['answer_index']}")
print()

# Show entities
for g in range(input_sample['active_groups']):
    entities = [input_sample[f'entity_g{g}_e{e}'] for e in range(input_sample['entities_per_group'])]
    print(f"  Group {g}: {entities}")
print()
print(f"  Query entities: query_e0={input_sample['query_e0']}, query_e1={input_sample['query_e1']}")

In [None]:
# Show the model's computed variables
print("Positional model computation:")
print("\n1. Prompt generated:")
print(f"   {input_sample['raw_input']}")
print()

print("2. Positional search:")
query_indices = input_sample['query_indices']
for e in query_indices:
    query_entity = input_sample[f'query_e{e}']
    positional_query = input_sample[f'positional_query_e{e}']
    print(f"   Search for query_e{e}={query_entity} at position e{e} → found at groups: {positional_query}")
print()

print("3. Determine answer position:")
print(f"   positional_answer (intersection): {input_sample['positional_answer']}")
print()

print("4. Retrieve answer:")
answer_pos = input_sample['positional_answer']
answer_idx = input_sample['answer_index']
print(f"   entity_g{answer_pos}_e{answer_idx} = {input_sample[f'entity_g{answer_pos}_e{answer_idx}']}")
print(f"   raw_output: {input_sample['raw_output']}")

## Part 2: Generating Counterfactuals with swap_query_group

The `swap_query_group` function generates counterfactual pairs by swapping entity groups so that the queried group moves to a new position.
Both prompts contain the same bindings; only their positions in the prompt differ.
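
A minimal sketch of the idea, assuming a prompt is just a list of groups plus a `query_group` index. `toy_swap_query_group` is a hypothetical stand-in for `swap_query_group`; the real function samples its inputs from `config` and returns full causal-model traces.

```python
import random


def toy_swap_query_group(groups, query_group, rng):
    """Swap the queried group with another random group (hypothetical toy).

    Returns the counterfactual groups and the new index of the queried group.
    """
    other = rng.choice([g for g in range(len(groups)) if g != query_group])
    swapped = list(groups)
    swapped[query_group], swapped[other] = swapped[other], swapped[query_group]
    # The queried entities now live at index `other`, so query_group follows them.
    return swapped, other


groups = [("Pete", "jam"), ("Ann", "pie"), ("Tim", "tea")]
cf_groups, cf_query_group = toy_swap_query_group(groups, query_group=1, rng=random.Random(0))

# Same queried binding, different position in the prompt.
assert cf_groups[cf_query_group] == groups[1]
print(cf_groups, cf_query_group)
```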

In [None]:
# Generate a counterfactual using swap_query_group
example = swap_query_group(config)

input_ex = example["input"]
counter_ex = example["counterfactual_inputs"][0]

print("Counterfactual pair generated by swap_query_group:")
print()
print("INPUT:")
print(f"  Prompt: {input_ex['raw_input']}")
print(f"  Query group: {input_ex['query_group']}")
print(f"  Answer: {input_ex['raw_output']}")
print()
print("COUNTERFACTUAL:")
print(f"  Prompt: {counter_ex['raw_input']}")
print(f"  Query group: {counter_ex['query_group']}")
print(f"  Answer: {counter_ex['raw_output']}")

In [None]:
# Show what changed between input and counterfactual
print("What changed between input and counterfactual:\n")

query_group = input_ex["query_group"]
active_groups = input_ex["active_groups"]

print(f"Query group in input: {query_group}")
print(f"Query group in counterfactual: {counter_ex['query_group']}")
print()

# Find which groups were swapped
print("Entity positions:")
for g in range(active_groups):
    print(f"  Group {g}:")
    for e in range(input_ex['entities_per_group']):
        key = f"entity_g{g}_e{e}"
        input_val = input_ex.get(key)
        counter_val = counter_ex.get(key)
        changed = " ← SWAPPED" if input_val != counter_val else ""
        print(f"    e{e}: {input_val} → {counter_val}{changed}")

### Key insight: swap_query_group tracks where entities move

When we swap groups, the counterfactual's `query_group` variable points to where the originally queried entities are NOW located. The counterfactual asks about the SAME entities, which now sit at a different position in the prompt, so both prompts have the same answer.

This is what makes the pair useful for testing positional vs. direct retrieval hypotheses:
- **If the model uses positional lookup**: the swap changes WHERE it searches and retrieves from
- **The causal model tracks this**: `positional_answer` follows the queried entities to their new group
- **If the model retrieved directly by content**: the position change would not matter, since the same entities are bound together in both prompts

## Part 3: Interchange Interventions

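An interchange intervention runs the causal model on the original input but fixes one variable to the value it takes on the counterfactual input, then recomputes everything downstream. If the output changes as the hypothesis predicts, that variable is a candidate for where the computation is localized.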
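
A self-contained toy version of this procedure, assuming a single queried position (the person) and a two-group prompt. The mechanisms, `run`, and `run_interchange` here are hypothetical illustrations; causalab's `run_interchange` takes a different signature (a dict mapping variables to counterfactual inputs).

```python
# Toy causal model for "What does <person> love?" (hypothetical, not causalab's API).
# Variables are computed in order; each mechanism reads the trace built so far.


def positional_query(trace):
    # Groups whose first entity matches the queried person.
    return {g for g, (person, _) in enumerate(trace["groups"]) if person == trace["query"]}


def positional_answer(trace):
    # With a single queried position, the answer group is the unique match.
    (group,) = trace["positional_query"]
    return group


def raw_output(trace):
    # Retrieve the food stored in the answer group.
    return trace["groups"][trace["positional_answer"]][1]


MECHANISMS = [
    ("positional_query", positional_query),
    ("positional_answer", positional_answer),
    ("raw_output", raw_output),
]


def run(inputs, interventions=None):
    """Compute every variable; variables in `interventions` are fixed instead of computed."""
    interventions = interventions or {}
    trace = dict(inputs)
    for name, fn in MECHANISMS:
        trace[name] = interventions[name] if name in interventions else fn(trace)
    return trace


def run_interchange(base, source, variable):
    """Run `base`, but set `variable` to the value it takes when running `source`."""
    return run(base, {variable: run(source)[variable]})


base = {"groups": [("Pete", "jam"), ("Ann", "pie")], "query": "Ann"}
source = {"groups": [("Ann", "pie"), ("Pete", "jam")], "query": "Ann"}  # groups swapped

print(run(base)["raw_output"])    # → pie
print(run(source)["raw_output"])  # → pie
print(run_interchange(base, source, "positional_answer")["raw_output"])  # → jam
```

Without intervention the swapped pair gives the same answer; only patching the position variable exposes the difference, which is why swaps are informative for positional hypotheses.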

In [None]:
# Generate a fresh counterfactual pair for intervention demonstration
example = swap_query_group(config)
input_sample = example["input"]
counterfactual_sample = example["counterfactual_inputs"][0]

print("Original input:")
print(f"  Prompt: {input_sample['raw_input']}")
print(f"  positional_answer: {input_sample['positional_answer']}")
print(f"  raw_output: {input_sample['raw_output']}")
print()

print("Counterfactual input:")
print(f"  Prompt: {counterfactual_sample['raw_input']}")
print(f"  positional_answer: {counterfactual_sample['positional_answer']}")
print(f"  raw_output: {counterfactual_sample['raw_output']}")

In [None]:
# Show the key causal variables that differ
print("Key positional variables:")
print()
print("In original:")
for e in range(input_sample['entities_per_group']):
    pq = input_sample[f'positional_query_e{e}']
    if pq:
        print(f"  positional_query_e{e}: {pq}")
print(f"  positional_answer: {input_sample['positional_answer']}")
print()
print("In counterfactual:")
for e in range(counterfactual_sample['entities_per_group']):
    pq = counterfactual_sample[f'positional_query_e{e}']
    if pq:
        print(f"  positional_query_e{e}: {pq}")
print(f"  positional_answer: {counterfactual_sample['positional_answer']}")

### Interchange intervention: patch positional_answer from counterfactual

We patch the counterfactual's `positional_answer` value into the original trace, keeping every other input from the original.
This tests what the model would output if it retrieved from the counterfactual's answer group instead of its own.

In [None]:
# Perform interchange intervention on positional_answer
intervened = causal_model.run_interchange(
    input_sample, {"positional_answer": counterfactual_sample}
)

print("Interchange intervention: positional_answer <- counterfactual")
print()
print(f"Original positional_answer: {input_sample['positional_answer']}")
print(f"Counterfactual positional_answer: {counterfactual_sample['positional_answer']}")
print(f"Intervened positional_answer: {intervened['positional_answer']}")
print()
print(f"Original raw_output: {input_sample['raw_output']}")
print(f"Intervened raw_output: {intervened['raw_output']}")
print()
if intervened['raw_output'] != input_sample['raw_output']:
    print("The intervention changed the output!")
    print(f"  Now retrieving from group {intervened['positional_answer']} instead of {input_sample['positional_answer']}")
else:
    print("The intervention did not change the output.")

## Part 4: Generating Counterfactual Datasets

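For experiments, we need many counterfactual pairs, not just one. `generate_counterfactual_samples` builds a dataset by calling a pair generator (here `swap_query_group`) the requested number of times.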
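
Conceptually, building such a dataset means calling a zero-argument pair generator repeatedly. The sketch assumes each pair is a dict with an `"input"` and a list of `"counterfactual_inputs"`, matching the structure used in this notebook; `make_dataset` and `toy_pair` are hypothetical and do not reproduce `generate_counterfactual_samples`.

```python
import random
from typing import Callable


def make_dataset(n: int, make_pair: Callable[[], dict]) -> list[dict]:
    """Call a zero-argument pair generator n times (hypothetical helper)."""
    return [make_pair() for _ in range(n)]


rng = random.Random(0)  # seeded so the toy dataset is reproducible
PEOPLE = ["Pete", "Ann", "Tim", "Sue"]
FOODS = ["jam", "pie", "tea", "rice"]


def toy_pair() -> dict:
    # Original prompt: two random (person, food) bindings; query the first group.
    groups = list(zip(rng.sample(PEOPLE, 2), rng.sample(FOODS, 2)))
    base = {"groups": groups, "query_group": 0}
    # Counterfactual: same bindings in swapped order, so the queried group moves to index 1.
    counterfactual = {"groups": groups[::-1], "query_group": 1}
    return {"input": base, "counterfactual_inputs": [counterfactual]}


dataset = make_dataset(64, toy_pair)
print(len(dataset))  # → 64
print(dataset[0]["input"]["groups"], dataset[0]["counterfactual_inputs"][0]["groups"])
```

Passing a `lambda` (as the notebook does) lets the generator close over `config` while keeping the dataset builder generic.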

In [None]:
# Generate counterfactual dataset using swap_query_group
swap_dataset = generate_counterfactual_samples(64, lambda: swap_query_group(config))

print(f"Generated {len(swap_dataset)} counterfactual pairs using swap_query_group")
print()
print("Example pair:")
print(f"  Input:  {swap_dataset[0]['input']['raw_input']}")
print(f"  Counter: {swap_dataset[0]['counterfactual_inputs'][0]['raw_input']}")

## Part 5: Testing Distinguishability

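A key property: can the dataset distinguish interventions on different variables? Two interventions count as distinguishable on a pair when they lead to different outputs (`None` means no intervention).
The proportion of distinguishable pairs tells us whether the dataset is useful for separating the corresponding causal hypotheses.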
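
A sketch of the idea behind such a check: for each pair, run both interventions on a toy model and count the pairs where the outputs differ. This is an assumption about what `can_distinguish_with_dataset` measures, not its implementation; the toy model, `output_under`, and `can_distinguish` are hypothetical, and the `count`/`proportion` keys simply mirror the ones printed in this notebook.

```python
def run(groups, query, fixed_answer_group=None):
    """Toy positional model: find the queried person's group, then read its food.

    If fixed_answer_group is given, it overrides the computed positional_answer.
    """
    answer_group = next(g for g, (person, _) in enumerate(groups) if person == query)
    if fixed_answer_group is not None:
        answer_group = fixed_answer_group
    return answer_group, groups[answer_group][1]


def output_under(variable, base, source):
    """Output of `base` after patching `variable` from `source` (None = no intervention)."""
    if variable is None:
        return run(*base)[1]
    source_answer_group, source_output = run(*source)
    if variable == "positional_answer":
        return run(*base, fixed_answer_group=source_answer_group)[1]
    if variable == "raw_output":
        return source_output
    raise ValueError(f"unknown variable: {variable}")


def can_distinguish(dataset, var_a, var_b):
    """Count pairs where intervening on var_a and on var_b yield different outputs."""
    count = sum(output_under(var_a, b, s) != output_under(var_b, b, s) for b, s in dataset)
    return {"count": count, "proportion": count / len(dataset)}


groups = [("Pete", "jam"), ("Ann", "pie")]
swapped = groups[::-1]
swap_pairs = [((groups, "Ann"), (swapped, "Ann")), ((groups, "Pete"), (swapped, "Pete"))]

print(can_distinguish(swap_pairs, "positional_answer", None))          # → {'count': 2, 'proportion': 1.0}
print(can_distinguish(swap_pairs, "positional_answer", "raw_output"))  # → {'count': 2, 'proportion': 1.0}
print(can_distinguish(swap_pairs, "raw_output", None))                 # → {'count': 0, 'proportion': 0.0}
```

In this toy, patching `raw_output` from a swapped pair is indistinguishable from doing nothing, because both prompts share the same answer; swaps separate hypotheses only through position-level variables.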

In [None]:
# Test: Can we distinguish positional_answer intervention from no intervention?
print("Test 1: positional_answer vs no intervention")
result1 = causal_model.can_distinguish_with_dataset(
    swap_dataset, ["positional_answer"], None
)
print(f"  Distinguishable: {result1['count']}/{len(swap_dataset)} = {result1['proportion']:.0%}")
print()

In [None]:
# Test: Can we distinguish positional_query_e0 from positional_answer?
print("Test 2: positional_query_e0 vs positional_answer")
result2 = causal_model.can_distinguish_with_dataset(
    swap_dataset, ["positional_query_e0"], ["positional_answer"]
)
print(f"  Distinguishable: {result2['count']}/{len(swap_dataset)} = {result2['proportion']:.0%}")
print()

In [None]:
# Test: Can we distinguish positional_answer from raw_output?
print("Test 3: positional_answer vs raw_output")
result3 = causal_model.can_distinguish_with_dataset(
    swap_dataset, ["positional_answer"], ["raw_output"]
)
print(f"  Distinguishable: {result3['count']}/{len(swap_dataset)} = {result3['proportion']:.0%}")

## Part 6: Random Counterfactuals (Baseline)

As a baseline, let's compare against completely random counterfactuals, where the counterfactual input is sampled independently of the original, to see how much `swap_query_group` helps.

In [None]:
# Generate random counterfactual dataset
random_dataset = generate_counterfactual_samples(64, lambda: random_counterfactual(config))

print(f"Generated {len(random_dataset)} random counterfactual pairs")
print()
print("Example pair (completely independent):")
print(f"  Input:  {random_dataset[0]['input']['raw_input']}")
print(f"  Counter: {random_dataset[0]['counterfactual_inputs'][0]['raw_input']}")

In [None]:
# Compare distinguishability with random counterfactuals
print("With RANDOM counterfactuals:")
print()

print("positional_answer vs no intervention:")
r1 = causal_model.can_distinguish_with_dataset(random_dataset, ["positional_answer"], None)
print(f"  {r1['proportion']:.0%}")

print("positional_query_e0 vs positional_answer:")
r2 = causal_model.can_distinguish_with_dataset(random_dataset, ["positional_query_e0"], ["positional_answer"])
print(f"  {r2['proportion']:.0%}")

print("positional_answer vs raw_output:")
r3 = causal_model.can_distinguish_with_dataset(random_dataset, ["positional_answer"], ["raw_output"])
print(f"  {r3['proportion']:.0%}")

print()
print("With SWAP counterfactuals (from above):")
print(f"  positional_answer vs None: {result1['proportion']:.0%}")
print(f"  positional_query_e0 vs positional_answer: {result2['proportion']:.0%}")
print(f"  positional_answer vs raw_output: {result3['proportion']:.0%}")

## Summary

This notebook demonstrated:

1. **Positional causal model**: How the model searches for query entities and retrieves answers
   - `positional_query_e{e}`: Finds which groups contain the query entity at position e
   - `positional_answer`: Intersection of positional queries (the group to retrieve from)

2. **Counterfactual generation**: Using `swap_query_group()` to create structured counterfactuals
   - Swaps entity groups while tracking where queried entities move
   - Intended to be more informative than random counterfactuals for testing positional hypotheses (compare the distinguishability scores in Part 6)

3. **Interchange interventions**: Testing causal variable localization
   - Patching values from the counterfactual into the original trace
   - Observing downstream effects on retrieval

4. **Distinguishability testing**: Verifying that our counterfactual dataset can separate different causal variables

For more details on interchange interventions with neural networks, see notebook 04.