# ICoT Replication: Linear Regression Probing

This notebook replicates the linear regression probing experiment from the ICoT (Implicit Chain-of-Thought) multiplication research. The goal is to test whether intermediate values ĉk (running sums during multiplication) can be decoded from the model's hidden states using linear probes.

## Experiment Overview

**Goal**: Determine if the model internally represents running sums (ĉk) during multi-digit multiplication.

**Method**:
1. Load pre-trained ICoT model (2 layers, 4 heads)
2. Extract hidden states at specific timesteps for each output digit position
3. Train linear regression probes to predict ĉk from hidden states
4. Evaluate probe accuracy (MAE) on validation set
5. Compare against SFT (standard fine-tuning) baseline

**Expected Result**: ICoT model should have lower MAE than SFT, indicating it learns to represent intermediate computation steps.

In [None]:
# Setup and imports
import os
import sys

# Work from the repository root so relative paths (data/, ckpts/, evaluation/)
# resolve, and put it first on sys.path so the `src.*` modules can be imported.
os.chdir('/home/smallyan/critic_model_mechinterp/icot')
sys.path.insert(0, '/home/smallyan/critic_model_mechinterp/icot')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Choose the compute device once; later cells move models and tensors to it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"CUDA available: {device == 'cuda'}")
if device == "cuda":
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
else:
    print("Warning: Running on CPU - this will be slow")

In [None]:
# Import model utilities
# These are project-local modules under src/. They resolve because the setup
# cell inserted the repository root at the front of sys.path.
# Import order is kept as-is in case any module has import-time side effects.
from src.model_utils import load_hf_model, load_c_hat_model
from src.HookedModel import convert_to_hooked_model
from src.ActivationCache import record_activations
from src.probes import RegressionProbe

print("Successfully imported custom modules")

## 1. Load Models

We load two models:
- **ICoT model**: Trained with implicit chain-of-thought (learns multiplication via intermediate steps)
- **SFT model**: Standard fine-tuning (learns direct input-output mapping)

In [None]:
# Load ICoT model: a config + state-dict checkpoint, converted to a hooked
# model so activations can be recorded at named hook points.
BASE_DIR = "/home/smallyan/critic_model_mechinterp/icot"
model_path = os.path.join(BASE_DIR, "ckpts/2L4H/")
config_path, state_dict_path = [
    os.path.join(model_path, fname) for fname in ("config.json", "state_dict.bin")
]

print("Loading ICoT model...")
icot_model, tokenizer = load_hf_model(
    config_path, state_dict_path, cpu=device == "cpu"
)
icot_model.to(device)
icot_model.eval()
convert_to_hooked_model(icot_model)
print("ICoT model loaded successfully")

# Load SFT (standard fine-tuning) baseline; its tokenizer is discarded and the
# ICoT tokenizer is reused for both models.
# NOTE(review): unlike the ICoT model, sft_model is not passed through
# convert_to_hooked_model here. Confirm whether load_c_hat_model already
# returns a hooked model before recording activations from it.
print("\nLoading SFT model...")
sft_model_path = os.path.join(BASE_DIR, "ckpts/vanilla_ft/ckpt.pt")
sft_model, _ = load_c_hat_model(sft_model_path)
sft_model.to(device)
sft_model.eval()
print("SFT model loaded successfully")

## 2. Load and Prepare Data

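The dataset contains 4×4-digit multiplication problems with each operand written in reverse order (least-significant digit first). In the file the digits are space-separated (e.g. `1 3 3 8 * 5 1 0 5`).
After removing spaces, the example `1338 * 5105` represents 8331 × 5015.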

In [None]:
# Load validation data
data_path = os.path.join(BASE_DIR, "data/processed_valid.txt")

with open(data_path, "r", encoding="utf-8") as f:
    texts = f.readlines()

# Parse operands from file format. Only the text before "||" is used; spaces
# and surrounding whitespace are dropped, then it is split on "*".
# Blank lines are detected with strip() so that whitespace-only lines and CRLF
# ("\r\n") blank lines are skipped too. A plain `text != "\n"` test lets those
# through, and they then break the `a, b` unpacking below.
texts = [
    text.replace(" ", "").strip().split("||")[0].split("*")
    for text in texts
    if text.strip()
]

# Convert to correct decimal order: the file stores digits least-significant
# first, so each operand string is reversed before int().
operands = [(int(a[::-1]), int(b[::-1])) for a, b in texts]

if not operands:
    raise ValueError(f"No multiplication problems parsed from {data_path}")

print(f"Loaded {len(operands)} multiplication problems")
print(f"Example: {operands[0][0]} × {operands[0][1]} = {operands[0][0] * operands[0][1]}")

## 3. Create Input Prompts

We create prompts that include the full answer sequence (teacher forcing). A single forward pass therefore produces a hidden state at every answer-digit position, up to the final timestep where all intermediate values have been computed.

In [None]:
# Utility function to format prompts
def multiply(a: int, b: int, return_reverse=False) -> str:
    """Return the decimal digits of ``a * b`` as a string.

    With ``return_reverse=True`` the digits come back least-significant first
    (the order used by the ICoT data), e.g. ``multiply(12, 34, True) == "804"``.
    No zero-padding is applied.
    """
    digits = str(a * b)
    return digits[::-1] if return_reverse else digits

def prompt_ci_operands(operands, i, tokenizer, device="cpu"):
    """Generate teacher-forced prompts positioned at answer digit ``i``.

    Each prompt contains, in order:
    - the operands, reversed and space-separated;
    - two EOS tokens;
    - the ``####`` separator;
    - the first ``i`` digits of the reversed product.

    For example, (8331, 5015) with i=2 gives, where ``<eos>`` stands for
    ``tokenizer.eos_token``::

        " 1 3 3 8 * 5 1 0 5 <eos><eos> #### 5 6"

    The reversed product is zero-padded to ``len(str(a)) + len(str(b))``
    digits. This is the maximum product length and equals the number of
    running sums ``get_c_hats`` returns. Without padding, shorter products
    (e.g. 1000 * 1000 = 1000000) give shorter prompts. After batch padding,
    their trailing positions would no longer line up with the fixed-width
    label columns.

    Args:
        operands: sequence of ``(a, b)`` integer pairs.
        i: number of answer digits to append (0 = stop right after ``####``).
        tokenizer: tokenizer providing ``eos_token``. It is called with
            ``padding=True``, which presumably requires a pad token to be set.
        device: device the token ids are moved to.

    Returns:
        ``(prompt_txts, prompt_token_ids)``: the prompt strings and the padded
        ``input_ids`` tensor on ``device``.
    """
    # Reversed (least-significant-first) products, zero-padded to full width.
    answers = [
        multiply(a, b).zfill(len(str(a)) + len(str(b)))[::-1]
        for a, b in operands
    ]
    suffixes = [" " + " ".join(ans[:i]) if i >= 1 else "" for ans in answers]

    # " ".join(str(a)) spaces out the digits; reversing that string yields the
    # least-significant-first order used by the data.
    prompt_txts = [
        " " + " ".join(str(a))[::-1] + " * " + " ".join(str(b))[::-1] + " "
        for a, b in operands
    ]
    eos = tokenizer.eos_token
    prompt_txts = [
        f"{txt}{eos}{eos} ####{suffix}" for txt, suffix in zip(prompt_txts, suffixes)
    ]

    prompt_token_ids = tokenizer(prompt_txts, return_tensors="pt", padding=True).input_ids
    prompt_token_ids = prompt_token_ids.to(device)
    return prompt_txts, prompt_token_ids

# Append all 8 answer digits (4x4-digit problems -> 8-digit products) so one
# forward pass exposes a hidden state at every answer position.
prompt_text, tokens = prompt_ci_operands(
    operands, i=8, tokenizer=tokenizer, device=device
)
print(f"Created {len(tokens)} prompts")
print(f"Example prompt: {prompt_text[0]}")

## 4. Compute Ground Truth Labels (ĉk values)

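For each operand pair (a, b), we compute the running sum $\hat{c}_k$ at each digit position $k$ ($k = 0$ is the least-significant column). Here $a_i$ and $b_j$ are the digits of $a$ and $b$, least-significant first:

$$\hat{c}_k = \sum_{i+j=k} a_i\, b_j + \left\lfloor \frac{\hat{c}_{k-1}}{10} \right\rfloor, \qquad \hat{c}_{-1} = 0.$$

The second term is the carry from the previous column; the output digit at position $k$ is $\hat{c}_k \bmod 10$.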

In [None]:
def get_c_hats(a, b):
    """Compute the per-column running sums (c_hat) of the long multiplication a * b.

    Column ``k`` (0 = least significant) collects every digit product
    ``a_i * b_j`` with ``i + j == k``. The running sum then adds the carry
    ``c_hats[k - 1] // 10`` from the previous column (0 for the first column).

    Returns:
        ``(c_hats, carrys, pair_sums)``, three lists of length
        ``len(str(a)) + len(str(b))``:
        - ``c_hats[k]``: running sum of column k, carry included;
        - ``carrys[k]``: ``c_hats[k] // 10``, the carry into column k + 1;
        - ``pair_sums[k]``: the column's digit-product sum without the carry.
    """
    a_digits = [int(ch) for ch in reversed(str(a))]
    b_digits = [int(ch) for ch in reversed(str(b))]
    n_a, n_b = len(a_digits), len(b_digits)

    c_hats, carrys, pair_sums = [], [], []
    carry_in = 0
    for k in range(n_a + n_b):
        # Valid index pairs satisfy 0 <= i < n_a and 0 <= k - i < n_b.
        lo = max(0, k - n_b + 1)
        hi = min(k, n_a - 1)
        column = sum(a_digits[i] * b_digits[k - i] for i in range(lo, hi + 1))
        pair_sums.append(column)

        running = column + carry_in
        c_hats.append(running)
        carry_in = running // 10
        carrys.append(carry_in)

    return c_hats, carrys, pair_sums

# Labels: one row of running sums (c_hats) per problem. Every row must have the
# same length for torch.tensor to build a 2-D float tensor.
labels = torch.tensor(
    [get_c_hats(a, b)[0] for a, b in operands],
    dtype=torch.float32,
)
print(f"Computed labels shape: {labels.shape}")
print(f"Example c_hats for {operands[0]}: {labels[0].tolist()}")

## 5. Split Data into Train/Val

We shuffle and split the data, using 1024 samples for validation.

In [None]:
# Shuffle and split
torch.manual_seed(123)  # For reproducibility
shuffle_idx = torch.randperm(len(tokens))
tokens = tokens[shuffle_idx]
labels = labels[shuffle_idx]

# Hold out (up to) the last 1024 shuffled examples for validation. An explicit
# split index is used instead of negative slicing (`x[:-val_size]`), which is
# empty whenever val_size is 0 and hides how much data is left for training.
val_size = min(1024, len(tokens))
split_idx = len(tokens) - val_size

val_tokens = tokens[split_idx:].to(device)
val_labels = labels[split_idx:].to(device)

train_tokens = tokens[:split_idx].to(device)
train_labels = labels[:split_idx].to(device)

print(f"Training samples: {len(train_tokens)}")
print(f"Validation samples: {len(val_tokens)}")
if len(train_tokens) == 0:
    # With <= 1024 examples every row goes to validation; a probe can then
    # only be loaded from disk, not trained.
    print(
        "Warning: training split is empty "
        f"(all {len(tokens)} examples are used for validation)."
    )

## 6. Record Activations from Models

We extract hidden states from specific hook points:
- Layer 0: mid-residual, post-residual
- Layer 1: mid-residual, post-residual

In [None]:
# Hook points to record: mid- and post-residual stream of layers 0 and 1,
# in the order 0.mid, 0.post, 1.mid, 1.post.
hook_modules = [
    f"{layer}.hook_resid_{point}" for layer in (0, 1) for point in ("mid", "post")
]

print("Recording validation activations from ICoT model...")
with torch.no_grad(), record_activations(icot_model, hook_modules) as cache:
    _ = icot_model(val_tokens)

# Keep only the trailing positions (one per label column) from each hook and
# stack hooks along a new leading axis: [num_modules, batch, seq_len, hidden_dim].
val_acts = torch.stack(
    [cache[hook][:, -val_labels.shape[1]:] for hook in hook_modules], dim=0
)

print(f"Validation activations shape: {val_acts.shape}")
print(f"[num_modules={val_acts.shape[0]}, batch={val_acts.shape[1]}, seq_len={val_acts.shape[2]}, hidden_dim={val_acts.shape[3]}]")

## 7. Train or Load Linear Probes

We load pre-trained probes when they exist. A missing ICoT probe is trained from scratch on the training split. A missing SFT probe causes the SFT comparison to be skipped.

In [None]:
# Probe configuration
num_modules, val_batch_size, seq, d_model = val_acts.shape
# NOTE(review): shape (num_modules, seq, d_model, 1) presumably means one
# d_model -> 1 linear regressor per (hook module, answer position). The second
# RegressionProbe argument (1e-3) is presumably the learning rate. Confirm both
# in src/probes.py.
probe_shape = (num_modules, seq, d_model, 1)

# Try to load pre-trained ICoT probe
probe_path = os.path.join(BASE_DIR, "ckpts/icot_c_hat_probe/probe.pth")
icot_probe = RegressionProbe(probe_shape, 1e-3)

if os.path.exists(probe_path):
    print(f"Loading pre-trained ICoT probe from {probe_path}")
    icot_probe.load_weights(probe_path)
else:
    print("Pre-trained ICoT probe not found. Training new probe...")
    # Fail loudly rather than "training" on an empty tensor, which at best
    # yields a meaningless probe with no visible error.
    if len(train_tokens) == 0:
        raise RuntimeError(
            "Cannot train the ICoT probe: the training split is empty "
            f"(all {len(val_tokens)} examples were assigned to validation). "
            f"Provide a pre-trained probe at {probe_path} or load a larger dataset."
        )

    # Record training activations
    print("Recording training activations...")
    with torch.no_grad():
        with record_activations(icot_model, hook_modules) as cache:
            _ = icot_model(train_tokens)

    # Same slicing/stacking as the validation activations.
    train_acts = torch.stack(
        [cache[m][:, -train_labels.shape[1]:] for m in hook_modules],
        dim=0,
    )

    # Train probe (full-batch steps)
    print("Training probe (this may take a few minutes)...")
    for epoch in range(100):
        loss = icot_probe.train_step(train_acts, train_labels)
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/100, Loss: {loss:.4f}")

# Try to load pre-trained SFT probe; unlike the ICoT probe, a missing SFT probe
# is not trained and the SFT comparison is skipped.
sft_probe_path = os.path.join(BASE_DIR, "ckpts/sft_c_hat_probe/probe.pth")
sft_probe = RegressionProbe(probe_shape, 1e-3)

if os.path.exists(sft_probe_path):
    print(f"\nLoading pre-trained SFT probe from {sft_probe_path}")
    sft_probe.load_weights(sft_probe_path)
else:
    print("\nWarning: Pre-trained SFT probe not found. Skipping SFT comparison.")
    sft_probe = None

## 8. Evaluate Probes

We compute the Mean Absolute Error (MAE) of the probe predictions at every output digit position. The plots and summary below focus on ĉ2 through ĉ6.

In [None]:
# Evaluate ICoT probe on the ICoT model's own validation activations.
print("Evaluating ICoT probe...")
with torch.no_grad():
    icot_val_preds = icot_probe(val_acts)

icot_metrics = icot_probe.evaluate_probe(val_acts, val_labels)
# Per-position MAE for layer 1 post-residual (the last hook module).
# NOTE(review): assumes evaluate_probe returns one record per hook module with
# the MAE at index 2. Confirm in src/probes.py.
icot_mae = icot_metrics[-1][2]

print("ICoT MAE by digit position:")
for i, mae in enumerate(icot_mae):
    print(f"  c{i}: {mae:.3f}")

# Evaluate SFT probe if available. It must see the SFT model's hidden states:
# feeding it `val_acts` (recorded from the ICoT model) would make the
# ICoT-vs-SFT comparison meaningless.
if sft_probe is not None:
    # NOTE(review): assumes sft_model exposes the same hook points as the
    # hooked ICoT model and shares its tokenizer. Confirm that
    # load_c_hat_model returns a hooked model.
    print("\nRecording validation activations from SFT model...")
    with torch.no_grad():
        with record_activations(sft_model, hook_modules) as sft_cache:
            _ = sft_model(val_tokens)

    sft_val_acts = torch.stack(
        [sft_cache[m][:, -val_labels.shape[1]:] for m in hook_modules],
        dim=0,
    )

    print("\nEvaluating SFT probe...")
    with torch.no_grad():
        sft_val_preds = sft_probe(sft_val_acts)

    sft_metrics = sft_probe.evaluate_probe(sft_val_acts, val_labels)
    sft_mae = sft_metrics[-1][2]

    print("SFT MAE by digit position:")
    for i, mae in enumerate(sft_mae):
        print(f"  c{i}: {mae:.3f}")
else:
    sft_val_preds = None
    sft_mae = None

## 9. Visualize Results

Create scatter plots comparing predicted vs. true ĉk values for positions c2-c6.

In [None]:
# Prepare data for plotting
val_labels_np = val_labels.cpu().numpy()
icot_val_preds_np = icot_val_preds.cpu().numpy()

# Predictions are plotted for one hook module: index 2 of hook_modules is
# "1.hook_resid_mid". Each panel's MAE is computed from exactly the plotted
# predictions, so title and scatter always agree. The per-position MAEs printed
# by the evaluation cell may refer to a different hook module.
PLOT_MODULE_IDX = 2

if sft_val_preds is not None:
    sft_val_preds_np = sft_val_preds.cpu().numpy()
    n_rows = 2
else:
    n_rows = 1

n_cols = 5
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(15, 3*n_rows), gridspec_kw={"hspace": 0.4}
)

if n_rows == 1:
    axes = axes.reshape(1, -1)

for row in range(n_rows):
    # With an SFT probe available, row 0 shows SFT and row 1 shows ICoT.
    if row == 0 and sft_val_preds is not None:
        probe_preds = sft_val_preds_np
        model_name = "SFT"
    else:
        probe_preds = icot_val_preds_np
        model_name = "ICoT"

    for col_idx, c_i in enumerate(range(2, 7)):
        ax = axes[row, col_idx]
        _val_labels = val_labels_np[:, c_i]
        # Flatten in case the probe output keeps a trailing singleton dim
        # (probe_shape ends in 1).
        _val_preds = probe_preds[PLOT_MODULE_IDX, :, c_i].reshape(-1)

        min_val = min(_val_labels.min(), _val_preds.min())
        max_val = max(_val_labels.max(), _val_preds.max())
        diagonal_line = np.linspace(min_val, max_val, 100)

        sorted_indices = np.argsort(_val_labels)
        sorted_labels = _val_labels[sorted_indices]
        sorted_preds = _val_preds[sorted_indices]

        mae = float(np.abs(_val_preds - _val_labels).mean())

        ax.plot(
            diagonal_line,
            diagonal_line,
            "r--",
            alpha=0.7,
            linewidth=2,
            label="Perfect predictions",
        )

        ax.scatter(
            sorted_labels,
            sorted_preds,
            alpha=0.5,
            s=5,
            label="Predictions",
            color="blue",
        )
        ax.set_title(f"{model_name}: c_hat_{c_i} (MAE {mae:.2f})", fontsize=12)

        if row == n_rows - 1:
            ax.set_xlabel(f"True c_hat_{c_i}", fontsize=11)

        if col_idx == 0:
            ax.set_ylabel("Predicted c_hat", fontsize=11)
            if row == 0:
                ax.legend(fontsize=9)
        ax.set_aspect("equal", adjustable="box")

plt.tight_layout()
output_path = "evaluation/replications/probe_results.png"
# savefig does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, dpi=150, bbox_inches="tight")
print(f"\nSaved visualization to {output_path}")
plt.show()

## 10. Summary of Results

Compare the MAE values between ICoT and SFT models to determine which better represents intermediate computation.

In [None]:
print("\n" + "="*60)
print("REPLICATION RESULTS SUMMARY")
print("="*60)

# Positions summarised below (c_hat_2 .. c_hat_6, matching the plots).
# float() accepts list entries, numpy scalars and 0-d torch tensors alike, so
# the summary does not depend on the container evaluate_probe returns.
summary_positions = range(2, 7)
icot_avg = float(np.mean([float(icot_mae[i]) for i in summary_positions]))

print("\nICoT Model - Mean Absolute Error by digit:")
for i in summary_positions:
    print(f"  c_hat_{i}: {float(icot_mae[i]):.3f}")
print(f"  Average (c2-c6): {icot_avg:.3f}")

if sft_mae is not None:
    sft_avg = float(np.mean([float(sft_mae[i]) for i in summary_positions]))

    print("\nSFT Model - Mean Absolute Error by digit:")
    for i in summary_positions:
        print(f"  c_hat_{i}: {float(sft_mae[i]):.3f}")
    print(f"  Average (c2-c6): {sft_avg:.3f}")

    print("\nImprovement (SFT MAE - ICoT MAE):")
    for i in summary_positions:
        sft_i, icot_i = float(sft_mae[i]), float(icot_mae[i])
        improvement = sft_i - icot_i
        if sft_i > 0:
            # Report the size of the relative change and its direction,
            # instead of labelling a negative change "% better".
            relative = abs(improvement) / sft_i * 100
            verdict = "better" if improvement >= 0 else "worse"
            print(f"  c_hat_{i}: {improvement:.3f} ({relative:.1f}% {verdict})")
        else:
            print(f"  c_hat_{i}: {improvement:.3f} (SFT MAE is 0; relative change undefined)")

print("\n" + "="*60)
print("INTERPRETATION:")
print("Lower MAE indicates the model better represents intermediate values.")
print("ICoT should show significantly lower MAE than SFT, confirming it")
print("learns to represent running sums during multiplication.")
if sft_mae is not None:
    outcome = "lower" if icot_avg < sft_avg else "not lower"
    print(f"Observed: ICoT average MAE ({icot_avg:.3f}) is {outcome} than SFT ({sft_avg:.3f}).")
print("="*60)