> **DEPRECATED:** This task is outdated and may not reflect current best practices.
> See `causalab/tasks/MCQA/` for an up-to-date example.

# Carry Variables and Cascade Testing

This notebook demonstrates three key concepts for testing carry propagation in 3-digit addition:

1. **Carry variables (C_1, C_2, C_3)** - Explicit intermediate variables representing carries
2. **Random counterfactuals** - Testing distinguishability with baseline counterfactuals
3. **Sum-to-nine counterfactuals** - Testing carry cascades starting from carry patterns such as [0, 1, 0] and [0, 0, 0]

## The Core Question

When a language model computes "The sum of 123 and 456 is", does it use explicit carry variables that propagate from right to left?

**Hypothesis**: The model uses carry variables C_i at each digit position, and we can test this by:
- Intervening on carry variables
- Testing distinguishability with counterfactual datasets
- Creating cascades where C_1 → C_2 → C_3

## Setup

In [1]:
# Autoreload
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("../..")

from causalab.tasks.general_addition.config import create_two_number_three_digit_config
from causalab.tasks.general_addition.causal_models import (
    create_intermediate_addition_model,
)
from causalab.tasks.general_addition.counterfactuals import (
    random_counterfactual,
    sum_to_nine_counterfactual,
)
from causalab.causal.counterfactual_dataset import CounterfactualDataset

## Part 1: Understanding Carry Variables

The intermediate model explicitly represents carry propagation with variables C_i (carry from position i) and O_i (output digit at position i).

### Position Indexing (Right-to-Left)
- **Position 1**: Ones place (rightmost)
- **Position 2**: Tens place
- **Position 3**: Hundreds place (leftmost for 3-digit)

### Variables
- **C_1**: Carry from ones position (1 if ones digits sum ≥ 10)
- **C_2**: Carry from tens position (1 if tens digits + C_1 ≥ 10)
- **C_3**: Carry from hundreds position (1 if hundreds digits + C_2 ≥ 10)
- **O_1, O_2, O_3, O_4**: Output digits at each position (O_4 is the thousands digit and equals C_3)
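
The definitions above amount to ordinary ripple-carry addition. As a minimal plain-Python sketch, independent of `causalab` (the function name `carry_trace` is hypothetical and not part of the library), the variables could be computed like this:

```python
def carry_trace(a: int, b: int, num_digits: int = 3) -> dict:
    """Ripple-carry addition of two numbers, recording every C_i and O_i.

    Positions are 1-based and counted right-to-left (position 1 = ones).
    """
    trace = {}
    carry = 0
    for pos in range(1, num_digits + 1):
        digit_a = (a // 10 ** (pos - 1)) % 10
        digit_b = (b // 10 ** (pos - 1)) % 10
        total = digit_a + digit_b + carry
        carry = 1 if total >= 10 else 0
        trace[f"C_{pos}"] = carry
        trace[f"O_{pos}"] = total % 10
    # The extra leading output digit is simply the final carry.
    trace[f"O_{num_digits + 1}"] = carry
    return trace


print(carry_trace(368, 475))
```

For 368 + 475 this yields C_1 = 1, C_2 = 1, C_3 = 0 and output digits 3, 4, 8, 0 (read right-to-left), i.e. 843.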

In [2]:
# Create 3-digit configuration and intermediate model
config = create_two_number_three_digit_config()
model = create_intermediate_addition_model(config)

print(f"Model: {model.id}")
print(f"Total variables: {len(model.variables)}")
print()

# Show carry and output variables
carry_vars = [v for v in model.variables if v.startswith("C_")]
output_vars = [v for v in model.variables if v.startswith("O_")]

print(f"Carry variables: {carry_vars}")
print(f"Output variables: {output_vars}")

Model: addition_intermediate_2n_3d
Total variables: 17

Carry variables: ['C_1', 'C_2', 'C_3']
Output variables: ['O_1', 'O_2', 'O_3', 'O_4']


### Example: Carry Propagation in 368 + 475 = 843

In [3]:
# Example: 368 + 475 = 843
example = {
    "digit_0_0": 3,
    "digit_0_1": 6,
    "digit_0_2": 8,  # 368
    "digit_1_0": 4,
    "digit_1_1": 7,
    "digit_1_2": 5,  # 475
    "num_digits": 3,
    "template": config.templates[0],
}

output = model.new_trace(example)

print("Example: 368 + 475 = 843")
print("=" * 70)
print()
print(f"Prompt: {output['raw_input']}")
print(f"Answer: {output['raw_output'].strip()}")
print()
print("Carry propagation (right-to-left):")
print()
print("Position 1 (ones): 8 + 5 = 13")
print(f"  C_1 = {output['C_1']} (13 >= 10, generates carry)")
print(f"  O_1 = {output['O_1']} (13 % 10)")
print()
print(f"Position 2 (tens): 6 + 7 + C_1 = 6 + 7 + {output['C_1']} = 14")
print(f"  C_2 = {output['C_2']} (14 >= 10, generates carry)")
print(f"  O_2 = {output['O_2']} (14 % 10)")
print()
print(f"Position 3 (hundreds): 3 + 4 + C_2 = 3 + 4 + {output['C_2']} = 8")
print(f"  C_3 = {output['C_3']} (8 < 10, no carry)")
print(f"  O_3 = {output['O_3']}")
print()
print(f"Position 4 (thousands): O_4 = {output['O_4']} (equals C_3)")
print()
print("✓ Carries propagate: Position 1 → Position 2 (but stop at Position 3)")

Example: 368 + 475 = 843

Prompt: The sum of 368 and 475 is
Answer: 843

Carry propagation (right-to-left):

Position 1 (ones): 8 + 5 = 13
  C_1 = 1 (13 >= 10, generates carry)
  O_1 = 3 (13 % 10)

Position 2 (tens): 6 + 7 + C_1 = 6 + 7 + 1 = 14
  C_2 = 1 (14 >= 10, generates carry)
  O_2 = 4 (14 % 10)

Position 3 (hundreds): 3 + 4 + C_2 = 3 + 4 + 1 = 8
  C_3 = 0 (8 < 10, no carry)
  O_3 = 8

Position 4 (thousands): O_4 = 0 (equals C_3)

✓ Carries propagate: Position 1 → Position 2 (but stop at Position 3)


### Interventions on Carry Variables

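Intervening on a carry variable in the causal model shows which downstream effects to expect. The same intervention pattern can later be used to test whether neural networks use a similar structure.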
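
To make the idea of an intervention concrete, here is a small sketch in plain Python. It does not use `causalab`; the helper `add_with_interventions` and its `forced` argument are hypothetical names chosen for illustration. A forced carry overrides whatever the digits would have produced, and everything downstream is recomputed from it.

```python
def add_with_interventions(a, b, forced=None, num_digits=3):
    """Ripple-carry addition where selected carries are overwritten.

    `forced` maps a carry name such as "C_1" to the value it should take,
    regardless of what the digit sum at that position would produce.
    """
    forced = forced or {}
    carry = 0
    digits_out = []  # output digits, right-to-left (O_1 first)
    for pos in range(1, num_digits + 1):
        total = (a // 10 ** (pos - 1)) % 10 + (b // 10 ** (pos - 1)) % 10 + carry
        digits_out.append(total % 10)
        natural_carry = 1 if total >= 10 else 0
        carry = forced.get(f"C_{pos}", natural_carry)
    digits_out.append(carry)  # O_4 equals C_3
    # Assemble the right-to-left digits into an integer.
    return sum(d * 10 ** i for i, d in enumerate(digits_out))


print(add_with_interventions(123, 234))                 # 357
print(add_with_interventions(123, 234, {"C_1": 1}))     # 367
print(add_with_interventions(123, 234, {"C_2": 1}))     # 457
```

Forcing C_1 changes only the tens digit here, and forcing C_2 changes only the hundreds digit, because neither forced carry pushes the next position to 10 or more.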

In [4]:
# Example without carry: 123 + 234 = 357
no_carry_example = {
    "digit_0_0": 1,
    "digit_0_1": 2,
    "digit_0_2": 3,  # 123
    "digit_1_0": 2,
    "digit_1_1": 3,
    "digit_1_2": 4,  # 234
    "num_digits": 3,
    "template": config.templates[0],
}

output_original = model.new_trace(no_carry_example)

print("Original: 123 + 234 = 357 (no carries)")
print("=" * 70)
print(f"  C_1 = {output_original['C_1']} (3+4=7, no carry)")
print(f"  C_2 = {output_original['C_2']} (2+3+0=5, no carry)")
print(f"  C_3 = {output_original['C_3']} (1+2+0=3, no carry)")
print(f"  Answer: {output_original['raw_output'].strip()}")
print()

Original: 123 + 234 = 357 (no carries)
  C_1 = 0 (3+4=7, no carry)
  C_2 = 0 (2+3+0=5, no carry)
  C_3 = 0 (1+2+0=3, no carry)
  Answer: 357



In [5]:
# Intervene: Force C_1 = 1
output_intervened = model.new_trace({**no_carry_example, 'C_1': 1})

print("Intervention: Force C_1 = 1 (pretend ones generated a carry)")
print("=" * 70)
print(f"  C_1 = {output_intervened['C_1']} (intervened)")
print(f"  C_2 = {output_intervened['C_2']} (2+3+1=6, still no carry)")
print(f"  C_3 = {output_intervened['C_3']} (1+2+0=3, no carry)")
print(f"  O_2 = {output_intervened['O_2']} (changed from 5 to 6!)")
print(f"  Answer: {output_intervened['raw_output'].strip()} (was 357, now 367)")
print()
print("Key insight: Intervening on C_1 affects downstream O_2!")

Intervention: Force C_1 = 1 (pretend ones generated a carry)
  C_1 = 1 (intervened)
  C_2 = 0 (2+3+1=6, still no carry)
  C_3 = 0 (1+2+0=3, no carry)
  O_2 = 6 (changed from 5 to 6!)
  Answer: 367 (was 357, now 367)

Key insight: Intervening on C_1 affects downstream O_2!


In [6]:
# Intervene: Force C_2 = 1
output_intervened_c2 = model.new_trace({**no_carry_example, 'C_2': 1})

print("Intervention: Force C_2 = 1 (pretend tens generated a carry)")
print("=" * 70)
print(f"  C_2 = {output_intervened_c2['C_2']} (intervened)")
print(f"  C_3 = {output_intervened_c2['C_3']} (1+2+1=4, still no carry)")
print(f"  O_3 = {output_intervened_c2['O_3']} (changed from 3 to 4!)")
print(f"  Answer: {output_intervened_c2['raw_output'].strip()} (was 357, now 457)")
print()
print("Key insight: Intervening on C_2 affects downstream O_3!")

Intervention: Force C_2 = 1 (pretend tens generated a carry)
  C_2 = 1 (intervened)
  C_3 = 0 (1+2+1=4, still no carry)
  O_3 = 4 (changed from 3 to 4!)
  Answer: 457 (was 357, now 457)

Key insight: Intervening on C_2 affects downstream O_3!


## Part 2: Random Counterfactuals - Testing Distinguishability

We generate a dataset of random counterfactual pairs and test whether we can distinguish carry variables from other variables.

### What is Distinguishability?

Two variables V1 and V2 are **distinguishable** on a counterfactual pair if:
- Intervening on V1 gives a different output than intervening on V2

If we can distinguish C_1 from all other variables, it suggests C_1 represents a unique computational role.
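
The check can be sketched on a toy model. This sketch assumes that "intervening" means an interchange intervention, i.e. copying the variable's value from the counterfactual run into the base run; that reading is an assumption, and the internals of `can_distinguish_with_dataset` may differ. The functions `trace` and `distinguishable` below are hypothetical and not part of `causalab`.

```python
def trace(digits_a, digits_b, overrides=None):
    """Toy causal model of 3-digit addition.

    digits_a / digits_b are (hundreds, tens, ones) tuples. `overrides` pins
    any C_i or O_i to a fixed value, which models an intervention.
    """
    overrides = overrides or {}
    values = {}
    carry_in = 0
    for pos in (1, 2, 3):  # right-to-left
        total = digits_a[-pos] + digits_b[-pos] + carry_in
        values[f"C_{pos}"] = overrides.get(f"C_{pos}", int(total >= 10))
        values[f"O_{pos}"] = overrides.get(f"O_{pos}", total % 10)
        carry_in = values[f"C_{pos}"]
    values["O_4"] = overrides.get("O_4", values["C_3"])
    digits = "".join(str(values[f"O_{p}"]) for p in (4, 3, 2, 1))
    values["output"] = digits.lstrip("0") or "0"
    return values


def distinguishable(base, counterfactual, var_1, var_2):
    """True if swapping in var_1 and swapping in var_2 give different outputs."""
    source = trace(*counterfactual)
    out_1 = trace(*base, overrides={var_1: source[var_1]})["output"]
    out_2 = trace(*base, overrides={var_2: source[var_2]})["output"]
    return out_1 != out_2


base = ((1, 2, 3), (2, 3, 4))  # 123 + 234
print(distinguishable(base, ((3, 6, 8), (4, 7, 5)), "C_1", "O_2"))  # True
print(distinguishable(base, ((1, 0, 0), (2, 0, 0)), "C_1", "O_4"))  # False
```

In the second pair the counterfactual has C_1 = 0 and O_4 = 0, the same values the base already has, so neither swap changes the output and the pair cannot tell the two variables apart. This is why distinguishability is reported as a proportion over a dataset.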

In [7]:
# Generate random counterfactual dataset
print("Generating random counterfactual dataset...")
random_dataset = CounterfactualDataset.from_sampler(
    128, lambda: random_counterfactual(config, num_addends=2, num_digits=3)
)

print(f"✓ Generated {len(random_dataset)} counterfactual pairs")
print()
print("Example pair:")
input_ex = random_dataset[0]["input"]
counter_ex = random_dataset[0]["counterfactual_inputs"][0]
print(f"  Input:  {input_ex['raw_input']}")
print(f"  Counter: {counter_ex['raw_input']}")
print()
print("These are completely independent examples (baseline).")

Generating random counterfactual dataset...
✓ Generated 128 counterfactual pairs

Example pair:
  Input:  The sum of 981 and 234 is
  Counter: The sum of 955 and 550 is

These are completely independent examples (baseline).


### Testing C_1 Distinguishability from All Variables

We test if C_1 can be distinguished from every other variable in the model.

In [8]:
# Get all variables except C_1
all_vars = model.variables
test_vars = [
    v
    for v in all_vars
    if v
    not in ["C_1", "raw_input", "raw_output", "template", "num_addends", "num_digits"]
]

print("Testing C_1 distinguishability from all variables:")
print("=" * 70)
print()

distinguishable_count = 0
total_count = len(test_vars)

# Test each variable
for var in test_vars:
    result = model.can_distinguish_with_dataset(random_dataset, ["C_1"], [var])
    if result["proportion"] > 0:
        distinguishable_count += 1
        status = "✓"
    else:
        status = "✗"
    print(
        f"{status} C_1 vs {var:20s}: {result['proportion']:.2%} ({result['count']}/{len(random_dataset)})"
    )

print()
print("=" * 70)
print(
    f"Summary: C_1 is distinguishable from {distinguishable_count}/{total_count} variables"
)
print()
print("This measures how often C_1 has a unique effect compared to other variables.")

Testing C_1 distinguishability from all variables:

Can distinguish between ['C_1'] and ['digit_0_0']: 124 out of 128 examples
Proportion of distinguishable examples: 0.97
✓ C_1 vs digit_0_0           : 96.88% (124/128)
Can distinguish between ['C_1'] and ['digit_0_1']: 115 out of 128 examples
Proportion of distinguishable examples: 0.90
✓ C_1 vs digit_0_1           : 89.84% (115/128)
Can distinguish between ['C_1'] and ['digit_0_2']: 115 out of 128 examples
Proportion of distinguishable examples: 0.90
✓ C_1 vs digit_0_2           : 89.84% (115/128)
Can distinguish between ['C_1'] and ['digit_1_0']: 123 out of 128 examples
Proportion of distinguishable examples: 0.96
✓ C_1 vs digit_1_0           : 96.09% (123/128)
Can distinguish between ['C_1'] and ['digit_1_1']: 122 out of 128 examples
Proportion of distinguishable examples: 0.95
✓ C_1 vs digit_1_1           : 95.31% (122/128)
Can distinguish between ['C_1'] and ['digit_1_2']: 110 out of 128 examples
Proportion of distinguishable exa

### Testing C_1 vs Output Variables (Most Important)

The most critical tests are whether C_1 is distinguishable from the output digits O_1, O_2, O_3, and O_4.

In [9]:
print("Testing C_1 vs Output Variables:")
print("=" * 70)
print()

for output_var in ["O_1", "O_2", "O_3", "O_4"]:
    result = model.can_distinguish_with_dataset(random_dataset, ["C_1"], [output_var])
    print(f"C_1 vs {output_var}: {result['proportion']:.2%} distinguishable")
    print(f"  ({result['count']}/{len(random_dataset)} pairs)")
    print()

print("Key insight: C_1 should be highly distinguishable from O_2, O_3")
print("because C_1 affects those output positions through carry propagation.")

Testing C_1 vs Output Variables:

Can distinguish between ['C_1'] and ['O_1']: 116 out of 128 examples
Proportion of distinguishable examples: 0.91
C_1 vs O_1: 90.62% distinguishable
  (116/128 pairs)

Can distinguish between ['C_1'] and ['O_2']: 118 out of 128 examples
Proportion of distinguishable examples: 0.92
C_1 vs O_2: 92.19% distinguishable
  (118/128 pairs)

Can distinguish between ['C_1'] and ['O_3']: 119 out of 128 examples
Proportion of distinguishable examples: 0.93
C_1 vs O_3: 92.97% distinguishable
  (119/128 pairs)

Can distinguish between ['C_1'] and ['O_4']: 94 out of 128 examples
Proportion of distinguishable examples: 0.73
C_1 vs O_4: 73.44% distinguishable
  (94/128 pairs)

Key insight: C_1 should be highly distinguishable from O_2, O_3
because C_1 affects those output positions through carry propagation.


## Part 3: Sum-to-Nine Counterfactuals - Testing Carry Cascade

### The Binary Vector [0, 1, 0]

We want to test a specific carry pattern:
- **Position 1 (ones)**: Digits sum to 9 → **C_1 = 0** (no carry)
- **Position 2 (tens)**: Digits plus C_1 = 0 sum to ≥ 10 → **C_2 = 1** (carry!)
- **Position 3 (hundreds)**: Digits plus C_2 = 1 sum to < 10 → **C_3 = 0** (no carry)

But what if we intervene and force C_1 = 1? Then:
- C_1 = 1 (intervened)
- Position 2: The sum grows by 1 and is still ≥ 10, so C_2 = 1 (unchanged) while O_2 increases by 1
- Position 3: C_2 did not change, so **C_3 = 0** and the pattern becomes [1, 1, 0]

A forced carry only cascades through a position whose digits sum to exactly 9: the extra 1 pushes the sum to 10 and flips that position's carry from 0 to 1. A full cascade to **C_3 = 1** therefore needs the tens and hundreds digits to sum to 9 as well, which means starting from [0, 0, 0] and ending at [1, 1, 1].
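
One way to check which inputs let a forced C_1 = 1 reach C_3 is to enumerate the possible digit-pair sums. The sketch below is plain Python and independent of `causalab`; the helper `carries` is a hypothetical name. It works on the sum of the two digits at each position (0 to 18) rather than on the digits themselves.

```python
from itertools import product


def carries(ones, tens, hundreds, forced_c1=None):
    """Carry pattern [C_1, C_2, C_3] given the digit-pair sum at each position."""
    c1 = int(ones >= 10) if forced_c1 is None else forced_c1
    c2 = int(tens + c1 >= 10)
    c3 = int(hundreds + c2 >= 10)
    return [c1, c2, c3]


# Ones digits sum to 9, so C_1 = 0 without intervention.
# Find every (tens, hundreds) sum where forcing C_1 = 1 changes C_3.
c3_flips = [
    (tens, hundreds)
    for tens, hundreds in product(range(19), repeat=2)
    if carries(9, tens, hundreds)[2] != carries(9, tens, hundreds, forced_c1=1)[2]
]
print(c3_flips)  # [(9, 9)]

# The [0, 1, 0] pattern (e.g. tens sum 10, hundreds sum 8) does not cascade.
print(carries(9, 10, 8))               # [0, 1, 0]
print(carries(9, 10, 8, forced_c1=1))  # [1, 1, 0]
```

Only the case where both the tens and hundreds digits sum to 9 lets the intervention change C_3, and in that case the unintervened pattern is [0, 0, 0].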

### Creating the Pattern

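We use `sum_to_nine_counterfactual(config, digit_position=2)`, which makes the **ones digits** sum to 9. Here `digit_position` is a 0-based, left-to-right index (matching the `digit_0_2` and `digit_1_2` keys), unlike the 1-based, right-to-left positions used for C_i and O_i.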
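
Because the notebook mixes two indexing schemes, a small conversion helper can prevent off-by-one mistakes. This is a sketch under the assumption that `digit_position` follows the same 0-based, left-to-right convention as the `digit_k_j` input keys; `digit_index_to_position` is a hypothetical helper, not a `causalab` function.

```python
def digit_index_to_position(digit_index: int, num_digits: int = 3) -> int:
    """Map a 0-based left-to-right digit index to a 1-based right-to-left position."""
    if not 0 <= digit_index < num_digits:
        raise ValueError(f"digit_index must be in [0, {num_digits})")
    return num_digits - digit_index


for idx, name in enumerate(["hundreds", "tens", "ones"]):
    print(f"digit index {idx} ({name}) -> position {digit_index_to_position(idx)}")
```

With three digits, index 2 maps to position 1, so `digit_position=2` targets the digits that determine C_1.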

In [10]:
# Generate sum-to-nine dataset (ones place)
print("Generating sum-to-nine counterfactual dataset (ones digits sum to 9)...")
sum_to_nine_dataset = CounterfactualDataset.from_sampler(
    128,
    lambda: sum_to_nine_counterfactual(
        config, digit_position=2
    ),  # digit_position=2 is ones (rightmost)
)

print(f"✓ Generated {len(sum_to_nine_dataset)} counterfactual pairs")
print()
print("Example pairs where input has ones digits summing to 9:")
print()

for i in range(3):
    input_ex = sum_to_nine_dataset[i]["input"]
    counter_ex = sum_to_nine_dataset[i]["counterfactual_inputs"][0]

    # Check ones digits
    ones_sum = input_ex["digit_0_2"] + input_ex["digit_1_2"]

    print(f"Pair {i + 1}:")
    print(f"  Input:  {input_ex['raw_input']}")
    print(f"    Ones: {input_ex['digit_0_2']} + {input_ex['digit_1_2']} = {ones_sum} ✓")
    print(f"  Counter: {counter_ex['raw_input']}")
    print()

Generating sum-to-nine counterfactual dataset (ones digits sum to 9)...
✓ Generated 128 counterfactual pairs

Example pairs where input has ones digits summing to 9:

Pair 1:
  Input:  The sum of 985 and 224 is
    Ones: 5 + 4 = 9 ✓
  Counter: The sum of 674 and 143 is

Pair 2:
  Input:  The sum of 282 and 137 is
    Ones: 2 + 7 = 9 ✓
  Counter: The sum of 457 and 202 is

Pair 3:
  Input:  The sum of 704 and 725 is
    Ones: 4 + 5 = 9 ✓
  Counter: The sum of 178 and 553 is



### Demonstrating the Cascade Pattern

Let's manually create an example with carry pattern [0, 1, 0] and check how far a forced C_1 = 1 propagates.

In [11]:
# Example: 369 + 540
# We want the ones digits to sum to 9 (no carry) and the tens digits to carry
# with or without C_1.
# Ones: 9 + 0 = 9 (C_1 = 0)
# Tens: 6 + 4 = 10 (C_2 = 1)
# Hundreds: 3 + 5 = 8, + C_2 = 9 (C_3 = 0)
# This gives pattern [0, 1, 0]

cascade_example = {
    "digit_0_0": 3,
    "digit_0_1": 6,
    "digit_0_2": 9,  # 369
    "digit_1_0": 5,
    "digit_1_1": 4,
    "digit_1_2": 0,  # 540
    "num_digits": 3,
    "template": config.templates[0],
}

output_normal = model.new_trace(cascade_example)

print("Example: 369 + 540 = 909")
print("=" * 70)
print()
print("Normal computation:")
print("  Position 1 (ones): 9 + 0 = 9")
print(f"    C_1 = {output_normal['C_1']} (no carry, since 9 < 10)")
print(f"    O_1 = {output_normal['O_1']}")
print()
print(f"  Position 2 (tens): 6 + 4 + C_1 = 6 + 4 + {output_normal['C_1']} = 10")
print(f"    C_2 = {output_normal['C_2']} (carry!)")
print(f"    O_2 = {output_normal['O_2']} (10 % 10)")
print()
print(f"  Position 3 (hundreds): 3 + 5 + C_2 = 3 + 5 + {output_normal['C_2']} = 9")
print(f"    C_3 = {output_normal['C_3']} (no carry, since 9 < 10)")
print(f"    O_3 = {output_normal['O_3']}")
print()
print(
    f"  Carry pattern: [C_1={output_normal['C_1']}, C_2={output_normal['C_2']}, C_3={output_normal['C_3']}] = [0, 1, 0] ✓"
)
print(f"  Answer: {output_normal['raw_output'].strip()}")

Example: 369 + 540 = 909

Normal computation:
  Position 1 (ones): 9 + 0 = 9
    C_1 = 0 (no carry, since 9 < 10)
    O_1 = 9

  Position 2 (tens): 6 + 4 + C_1 = 6 + 4 + 0 = 10
    C_2 = 1 (carry!)
    O_2 = 0 (10 % 10)

  Position 3 (hundreds): 3 + 5 + C_2 = 3 + 5 + 1 = 9
    C_3 = 0 (no carry, since 9 < 10)
    O_3 = 9

  Carry pattern: [C_1=0, C_2=1, C_3=0] = [0, 1, 0] ✓
  Answer: 909


In [12]:
# Now intervene: Force C_1 = 1 and watch the cascade
output_cascade = model.new_trace({**cascade_example, 'C_1': 1})

print("Intervention: Force C_1 = 1 (cascade test)")
print("=" * 70)
print()
print("Intervened computation:")
print(f"  Position 1: C_1 = {output_cascade['C_1']} (forced to 1)")
print()
print(f"  Position 2 (tens): 6 + 4 + C_1 = 6 + 4 + {output_cascade['C_1']} = 11")
print(f"    C_2 = {output_cascade['C_2']} (still carries, now 11 >= 10)")
print(f"    O_2 = {output_cascade['O_2']} (11 % 10, changed from 0 to 1!)")
print()
print(f"  Position 3 (hundreds): 3 + 5 + C_2 = 3 + 5 + {output_cascade['C_2']} = 9")
print(f"    C_3 = {output_cascade['C_3']} (still no carry, 9 < 10)")
print(f"    O_3 = {output_cascade['O_3']} (unchanged)")
print()
print(
    f"  Carry pattern: [C_1={output_cascade['C_1']}, C_2={output_cascade['C_2']}, C_3={output_cascade['C_3']}] = [1, 1, 0]"
)
print(f"  Answer: {output_cascade['raw_output'].strip()} (changed from 909 to 919!)")
print()
print("✓ The C_1 intervention changed O_2, but C_2 was already 1, so nothing propagated to C_3")

Intervention: Force C_1 = 1 (cascade test)

Intervened computation:
  Position 1: C_1 = 1 (forced to 1)

  Position 2 (tens): 6 + 4 + C_1 = 6 + 4 + 1 = 11
    C_2 = 1 (still carries, now 11 >= 10)
    O_2 = 1 (11 % 10, changed from 0 to 1!)

  Position 3 (hundreds): 3 + 5 + C_2 = 3 + 5 + 1 = 9
    C_3 = 0 (still no carry, 9 < 10)
    O_3 = 9 (unchanged)

  Carry pattern: [C_1=1, C_2=1, C_3=0] = [1, 1, 0]
  Answer: 919 (changed from 909 to 919!)

✓ The C_1 intervention changed O_2, but C_2 was already 1, so nothing propagated to C_3


### Better Example: Full Cascade [0, 0, 0] → [1, 1, 1]

Let's create an example where forcing C_1=1 causes the cascade to go all the way to C_3.

In [13]:
# Example: 459 + 550
# Ones: 9 + 0 = 9 (C_1 = 0, but if forced to 1...)
# Tens: 5 + 5 = 10, +C_1 = 10 or 11 (C_2 = 1 either way)
# Hundreds: 4 + 5 = 9, +C_2 = 10 (C_3 = 1 if C_2=1)

full_cascade_example = {
    "digit_0_0": 4,
    "digit_0_1": 5,
    "digit_0_2": 9,  # 459
    "digit_1_0": 5,
    "digit_1_1": 5,
    "digit_1_2": 0,  # 550
    "num_digits": 3,
    "template": config.templates[0],
}

output_before = model.new_trace(full_cascade_example)

print("Full Cascade Example: 459 + 550")
print("=" * 70)
print()
print("Normal computation (no intervention):")
print(f"  Position 1: 9 + 0 = 9, C_1 = {output_before['C_1']}")
print(
    f"  Position 2: 5 + 5 + {output_before['C_1']} = 10, C_2 = {output_before['C_2']}"
)
print(
    f"  Position 3: 4 + 5 + {output_before['C_2']} = 10, C_3 = {output_before['C_3']}"
)
print(
    f"  Pattern: [C_1={output_before['C_1']}, C_2={output_before['C_2']}, C_3={output_before['C_3']}] = [0, 1, 1]"
)
print(f"  Answer: {output_before['raw_output'].strip()}")
print()

Full Cascade Example: 459 + 550

Normal computation (no intervention):
  Position 1: 9 + 0 = 9, C_1 = 0
  Position 2: 5 + 5 + 0 = 10, C_2 = 1
  Position 3: 4 + 5 + 1 = 10, C_3 = 1
  Pattern: [C_1=0, C_2=1, C_3=1] = [0, 1, 1]
  Answer: 1009



In [14]:
# 459 + 550 already has C_3 = 1 without any intervention, so try a different example.
# Example: 369 + 630
# Ones: 9 + 0 = 9 (C_1 = 0 naturally; the intervention forces it to 1)
# Tens: 6 + 3 = 9, +C_1=0 → 9, +C_1=1 → 10 (C_2 changes from 0 to 1!)
# Hundreds: 3 + 6 = 9, +C_2=0 → 9, +C_2=1 → 10 (C_3 changes from 0 to 1!)
# This should give the [0, 0, 0] → [1, 1, 1] cascade!

perfect_cascade = {
    "digit_0_0": 3,
    "digit_0_1": 6,
    "digit_0_2": 9,  # 369
    "digit_1_0": 6,
    "digit_1_1": 3,
    "digit_1_2": 0,  # 630
    "num_digits": 3,
    "template": config.templates[0],
}

output_pre_cascade = model.new_trace(perfect_cascade)

print("Perfect Cascade Example: 369 + 630 = 999")
print("=" * 70)
print()
print("Normal computation:")
print(f"  Position 1: 9 + 0 = 9, C_1 = {output_pre_cascade['C_1']}")
print(
    f"  Position 2: 6 + 3 + {output_pre_cascade['C_1']} = 9, C_2 = {output_pre_cascade['C_2']}"
)
print(
    f"  Position 3: 3 + 6 + {output_pre_cascade['C_2']} = 9, C_3 = {output_pre_cascade['C_3']}"
)
print(
    f"  Pattern: [C_1={output_pre_cascade['C_1']}, C_2={output_pre_cascade['C_2']}, C_3={output_pre_cascade['C_3']}] = [0, 0, 0]"
)
print(f"  Answer: {output_pre_cascade['raw_output'].strip()}")
print()

Perfect Cascade Example: 369 + 630 = 999

Normal computation:
  Position 1: 9 + 0 = 9, C_1 = 0
  Position 2: 6 + 3 + 0 = 9, C_2 = 0
  Position 3: 3 + 6 + 0 = 9, C_3 = 0
  Pattern: [C_1=0, C_2=0, C_3=0] = [0, 0, 0]
  Answer: 999



In [15]:
# Now the cascade intervention
output_full_cascade = model.new_trace({**perfect_cascade, 'C_1': 1})

print("CASCADE: Force C_1 = 1")
print("=" * 70)
print()
print(f"  Position 1: C_1 = {output_full_cascade['C_1']} (forced)")
print(f"    O_1 = {output_full_cascade['O_1']}")
print()
print(f"  Position 2: 6 + 3 + {output_full_cascade['C_1']} = 10")
print(f"    C_2 = {output_full_cascade['C_2']} ✓ (CASCADE! Changed from 0 to 1)")
print(f"    O_2 = {output_full_cascade['O_2']} (10 % 10, changed from 9 to 0)")
print()
print(f"  Position 3: 3 + 6 + {output_full_cascade['C_2']} = 10")
print(f"    C_3 = {output_full_cascade['C_3']} ✓ (CASCADE! Changed from 0 to 1)")
print(f"    O_3 = {output_full_cascade['O_3']} (10 % 10, changed from 9 to 0)")
print()
print(
    f"  Position 4: O_4 = {output_full_cascade['O_4']} (equals C_3, changed from 0 to 1)"
)
print()
print(
    f"  Pattern: [C_1={output_full_cascade['C_1']}, C_2={output_full_cascade['C_2']}, C_3={output_full_cascade['C_3']}] = [1, 1, 1]"
)
print(f"  Answer: {output_full_cascade['raw_output'].strip()}")
print()
print("=" * 70)
print("✓ FULL CASCADE: [0,0,0] → [1,1,1]")
print("✓ C_1 intervention propagated through C_2 to C_3, setting the thousands digit to 1!")
print("=" * 70)

CASCADE: Force C_1 = 1

  Position 1: C_1 = 1 (forced)
    O_1 = 9

  Position 2: 6 + 3 + 1 = 10
    C_2 = 1 ✓ (CASCADE! Changed from 0 to 1)
    O_2 = 0 (10 % 10, changed from 9 to 0)

  Position 3: 3 + 6 + 1 = 10
    C_3 = 1 ✓ (CASCADE! Changed from 0 to 1)
    O_3 = 0 (10 % 10, changed from 9 to 0)

  Position 4: O_4 = 1 (equals C_3, changed from 0 to 1)

  Pattern: [C_1=1, C_2=1, C_3=1] = [1, 1, 1]
  Answer: 1009

✓ FULL CASCADE: [0,0,0] → [1,1,1]
✓ C_1 intervention propagated through C_2 to C_3, setting the thousands digit to 1!


### Testing Cascade with Sum-to-Nine Dataset

Now we test whether the sum-to-nine dataset lets us distinguish C_1 from the downstream carries (C_2, C_3) and output digits (O_3, O_4).

In [16]:
print("Testing carry distinguishability with sum-to-nine dataset:")
print("=" * 70)
print()

# Test C_1 vs C_2
result_c1_c2 = model.can_distinguish_with_dataset(sum_to_nine_dataset, ["C_1"], ["C_2"])
print(f"C_1 vs C_2: {result_c1_c2['proportion']:.2%} distinguishable")
print(f"  ({result_c1_c2['count']}/{len(sum_to_nine_dataset)} pairs)")
print()

# Test C_1 vs C_3
result_c1_c3 = model.can_distinguish_with_dataset(sum_to_nine_dataset, ["C_1"], ["C_3"])
print(f"C_1 vs C_3: {result_c1_c3['proportion']:.2%} distinguishable")
print(f"  ({result_c1_c3['count']}/{len(sum_to_nine_dataset)} pairs)")
print()

# Test C_1 vs O_3 (hundreds output digit)
result_c1_o3 = model.can_distinguish_with_dataset(sum_to_nine_dataset, ["C_1"], ["O_3"])
print(f"C_1 vs O_3: {result_c1_o3['proportion']:.2%} distinguishable")
print(f"  ({result_c1_o3['count']}/{len(sum_to_nine_dataset)} pairs)")
print()

# Test C_1 vs O_4 (thousands output digit)
result_c1_o4 = model.can_distinguish_with_dataset(sum_to_nine_dataset, ["C_1"], ["O_4"])
print(f"C_1 vs O_4: {result_c1_o4['proportion']:.2%} distinguishable")
print(f"  ({result_c1_o4['count']}/{len(sum_to_nine_dataset)} pairs)")
print()

print("Key insight: When ones digits sum to 9, C_1 should be distinguishable")
print("from downstream carries and outputs because it can trigger cascades.")

Testing carry distinguishability with sum-to-nine dataset:

Can distinguish between ['C_1'] and ['C_2']: 94 out of 128 examples
Proportion of distinguishable examples: 0.73
C_1 vs C_2: 73.44% distinguishable
  (94/128 pairs)

Can distinguish between ['C_1'] and ['C_3']: 90 out of 128 examples
Proportion of distinguishable examples: 0.70
C_1 vs C_3: 70.31% distinguishable
  (90/128 pairs)

Can distinguish between ['C_1'] and ['O_3']: 121 out of 128 examples
Proportion of distinguishable examples: 0.95
C_1 vs O_3: 94.53% distinguishable
  (121/128 pairs)

Can distinguish between ['C_1'] and ['O_4']: 90 out of 128 examples
Proportion of distinguishable examples: 0.70
C_1 vs O_4: 70.31% distinguishable
  (90/128 pairs)

Key insight: When ones digits sum to 9, C_1 should be distinguishable
from downstream carries and outputs because it can trigger cascades.


## Summary: Testing Neural Networks

### What We've Demonstrated

1. **Carry Variables (C_1, C_2, C_3)**: Explicit intermediate representations of carries
   - Interventions on C_i affect downstream positions
   - C_1 → O_2, C_2 → O_3, etc.

2. **Random Counterfactuals**: Baseline distinguishability testing
   - C_1 is distinguishable from most other variables on most pairs
   - Against the output variables: 90.62% (O_1), 92.19% (O_2), 92.97% (O_3), and 73.44% (O_4) of pairs

3. **Sum-to-Nine Counterfactuals**: Cascade pattern testing
   - Starting carry pattern [0, 0, 0] with the digits at every position summing to 9 (e.g. 369 + 630)
   - Intervening C_1=1 creates cascade [1, 1, 1]
   - C_1 intervention propagates through C_2 to C_3
   - This flips the tens and hundreds digits (O_2, O_3) from 9 to 0 and sets the thousands digit (O_4) to 1

### The Neural Network Question

**Do neural networks represent explicit carry variables?**

To test this:
1. Generate these counterfactual datasets (random and sum-to-nine)
2. Run interchange interventions on neural activations
3. Test if swapping activations produces cascade effects like swapping C_1
4. If yes → evidence for explicit carry representation
5. If no → the model may use a different computational structure
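
Steps 3 to 5 are commonly scored by comparing, pair by pair, the network's output after patching with the causal model's output under the matching C_1 interchange. The sketch below shows only that scoring step; `interchange_accuracy` is a hypothetical helper, and the two lists are made-up placeholder strings, not results from any model.

```python
def interchange_accuracy(patched_outputs, expected_outputs):
    """Fraction of pairs where the patched network output matches the
    causal model's prediction for the same interchange intervention."""
    if len(patched_outputs) != len(expected_outputs):
        raise ValueError("need exactly one expected output per patched output")
    if not patched_outputs:
        return 0.0
    matches = sum(
        patched.strip() == expected.strip()
        for patched, expected in zip(patched_outputs, expected_outputs)
    )
    return matches / len(patched_outputs)


# Placeholder values for illustration only (not measured results).
patched = ["1009", "919", "367", "357"]    # network outputs after patching
expected = ["1009", "919", "367", "367"]   # causal model with C_1 swapped
print(interchange_accuracy(patched, expected))  # 0.75
```

A score near 1 on cascade-inducing pairs would count as evidence for step 4, while a low score corresponds to step 5 for the tested activations.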

### Key Metrics

- **Distinguishability**: Can we distinguish C_1 from other variables?
- **Cascade detection**: Does intervening on C_1 create the expected cascade pattern?
- **Position specificity**: Is the effect localized to the right positions (O_2, O_3, O_4)?