# ICoT Code Questions - Solutions

This notebook contains complete solutions for the three code-based assessment questions.

**Important:** This notebook is for evaluation purposes. Students should only see the stubs in question_documentation.ipynb.

Each code question includes:
1. Student-facing stub (identical to question_documentation.ipynb)
2. Complete solution with implementation
3. Auto-check cell to validate results


---
## CQ1: Logit Attribution Pattern Verification

**Objective:** Verify that the input digits aᵢ and bⱼ have the strongest effect on output digit cₖ when i + j = k.


In [None]:
# CQ1: Logit Attribution Pattern Verification (STUB)
import torch
import numpy as np
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model
from src.data_utils import prompt_ci_raw_format_batch

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# TODO: Load the ICoT model
# TODO: Generate 100 random 4×4 multiplication problems
# TODO: Compute logit attribution for each input-output pair
# TODO: Print top-3 influential positions for each output


In [None]:
# CQ1: SOLUTION
import torch
import numpy as np
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model

np.random.seed(42)
torch.manual_seed(42)

# Checkpoint files handed to load_hf_model
config_path = "/home/smallyan/critic_model_mechinterp/icot/ckpts/2L4H/config.json"
state_dict_path = "/home/smallyan/critic_model_mechinterp/icot/ckpts/1_to_4_revops_2L_H4.pt"

print("Loading ICoT model...")
model, tokenizer = load_hf_model(config_path, state_dict_path)
# Use the GPU when one is visible, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)
model.eval()

# Generate 100 random 4×4 multiplication problems
n_samples = 100


def _random_digit_string(n_digits=4):
    """Draw n_digits digits one at a time from the global NumPy RNG and join them."""
    return ''.join(str(np.random.randint(0, 10)) for _ in range(n_digits))


# For every sample the first operand is drawn before the second, so the
# seeded RNG stream is consumed in a fixed order.
operands = [
    (_random_digit_string(), _random_digit_string())
    for _ in range(n_samples)
]

print(f"Generated {n_samples} random multiplication problems")

# Function to format input and get logits
def get_output_logits(model, tokenizer, operand_pair, device, n_outputs=8):
    """Return next-token logits for each of the ``n_outputs`` answer positions.

    The prompt ``"{a} * {b}%%%####"`` stops right before the first answer
    token, so a single forward pass over it only yields a prediction for c0.
    To obtain one logit row per answer position the prompt is extended
    greedily: at every step the logits at the last position are recorded and
    their argmax token is appended before the next forward pass.

    Args:
        model: callable such that ``model(tokens).logits`` is indexable as
            ``[batch, seq_len, vocab_size]``.
        tokenizer: object whose ``encode(prompt, return_tensors='pt')`` returns
            a ``[1, seq_len]`` tensor of token ids.
        operand_pair: ``(a, b)`` digit strings.
        device: device the token tensor is moved to.
        n_outputs: number of answer positions to collect (default 8, c0..c7).

    Returns:
        ``(logits, tokens)`` where ``logits`` has shape
        ``[n_outputs, vocab_size]`` (row k is the prediction for c_k) and
        ``tokens`` is the encoded prompt without the generated continuation.

    Raises:
        ValueError: if ``n_outputs`` is smaller than 1.
    """
    if n_outputs < 1:
        raise ValueError("n_outputs must be at least 1")

    a, b = operand_pair
    # Format: "a0a1a2a3 * b0b1b2b3%%%####"
    prompt = f"{a} * {b}%%%####"
    tokens = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # NOTE(review): assumes every answer digit is emitted as exactly one token
    # with nothing in between — confirm against the tokenizer / data format.
    # NOTE(review): each input follows its own greedy continuation, so a
    # baseline and a counterfactual may be conditioned on different previous
    # answer tokens.
    sequence = tokens
    step_logits = []
    with torch.no_grad():
        for _ in range(n_outputs):
            next_logits = model(sequence).logits[0, -1]  # Shape: [vocab_size]
            step_logits.append(next_logits)
            next_token = next_logits.argmax().view(1, 1).to(sequence.dtype)
            sequence = torch.cat([sequence, next_token], dim=1)

    return torch.stack(step_logits), tokens

# Compute attribution for each input position on each output position
print("Computing logit attribution...")
attribution_matrix = np.zeros((8, 8, n_samples))  # [input_pos, output_pos, sample]


def _perturb_operands(a, b, input_pos):
    """Return ``(a, b)`` with one digit replaced by ``(digit + 5) % 10``.

    ``input_pos`` indexes the concatenation ``a + b``: positions
    ``0..len(a)-1`` select a digit of ``a``, the remaining ones a digit of ``b``.
    """
    digits = list(a + b)
    digits[input_pos] = str((int(digits[input_pos]) + 5) % 10)
    return ''.join(digits[:len(a)]), ''.join(digits[len(a):])


# Number of (input, output, sample) entries that could not be filled because
# get_output_logits returned fewer rows than the output position requires.
skipped_entries = 0

for sample_idx, (a, b) in enumerate(operands):
    # Get baseline logits
    baseline_logits, _ = get_output_logits(model, tokenizer, (a, b), device)

    # For each input position (8 positions: 4 for a, 4 for b)
    for input_pos in range(8):
        # Counterfactual: same problem with a single digit changed
        counterfactual = _perturb_operands(a, b, input_pos)
        cf_logits, _ = get_output_logits(model, tokenizer, counterfactual, device)

        n_available = min(len(baseline_logits), len(cf_logits))

        # Focus on output positions c2-c6
        for output_pos in range(2, 7):
            if output_pos >= n_available:
                skipped_entries += 1
                continue
            # Attribution = largest absolute logit change over the whole
            # vocabulary at this output position (not only the correct digit).
            diff = torch.abs(baseline_logits[output_pos] - cf_logits[output_pos]).max().item()
            attribution_matrix[input_pos, output_pos, sample_idx] = diff

# Entries that were skipped stay at 0 and would silently flatten the averages,
# so make the situation visible instead of hiding it.
if skipped_entries:
    print(
        f"WARNING: {skipped_entries} attribution entries were skipped because "
        "fewer logit rows than output positions were returned; "
        "their attribution is left at 0."
    )

# Average across samples
mean_attribution = attribution_matrix.mean(axis=2)

# Identify top-3 positions for each output
def _expected_input_positions(output_pos, n_digits=4):
    """Return the sorted input indices taking part in a product with i + j == output_pos.

    Digit a_i lives at index ``i`` and digit b_j at index ``j + n_digits``.
    The full set is returned; nothing is truncated.
    """
    positions = set()
    for i in range(n_digits):
        j = output_pos - i
        if 0 <= j < n_digits:
            positions.update((i, j + n_digits))
    return sorted(positions)


print("\nTop-3 influential input positions for each output digit:")
print("=" * 60)

results = {}
for output_pos in range(2, 7):
    scores = mean_attribution[:, output_pos]
    # Indices of the three largest mean attributions, largest first
    top_3_idx = np.argsort(scores)[-3:][::-1]

    print(f"\nOutput c{output_pos}:")
    print(f"  Expected pattern: positions where i+j={output_pos}")
    print(f"  Top-3 positions: {top_3_idx.tolist()}")

    # Check if top positions match i+j=k pattern.
    # Note: for output_pos == 3 every one of the 8 input positions is
    # expected, so the overlap for c3 is 3/3 by construction.
    expected_positions = _expected_input_positions(output_pos)

    print(f"  Expected positions (i+j={output_pos}): {expected_positions}")

    # Check overlap
    overlap = len(set(top_3_idx.tolist()) & set(expected_positions))
    print(f"  Match quality: {overlap}/3 top positions match expected")

    results[f"c{output_pos}"] = {
        "top_3": top_3_idx.tolist(),
        "expected": expected_positions,
        "overlap": overlap
    }

print("\n" + "=" * 60)
print("VERIFICATION COMPLETE")


In [None]:
# CQ1: AUTO-CHECK
# Sanity-check the attribution summary produced by the CQ1 solution cell.

print("\nAuto-check for CQ1:")
print("-" * 40)

# The solution cell must have been executed first
assert 'results' in locals(), "Results not computed"

# Every analysed output digit (c2..c6) needs an entry
for k in range(2, 7):
    assert f"c{k}" in results, f"Missing results for c{k}"

# Average, over the five outputs, of how many top-3 positions are expected
# ones; the check passes when that average is at least 1.5 out of 3.
overlaps = [results[f"c{k}"]["overlap"] for k in range(2, 7)]
total_overlap = sum(overlaps)
avg_overlap = total_overlap / 5

print(f"Average overlap with expected pattern: {avg_overlap:.1f}/3")

passed = avg_overlap >= 1.5
print(
    "✓ PASS: Attribution pattern shows i+j=k structure"
    if passed
    else "✗ FAIL: Attribution pattern does not match expected structure"
)

print(f"\nExpected: Average overlap ≥ 1.5/3")
print(f"Achieved: {avg_overlap:.1f}/3")


---
## CQ2: Running Sum Linear Probe Accuracy

**Objective:** Demonstrate that ĉₖ values are linearly decodable from ICoT hidden states.


In [None]:
# CQ2: Running Sum Linear Probe Accuracy (STUB)
import torch
import numpy as np
from sklearn.linear_model import Ridge
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')

# Set random seed
np.random.seed(42)
torch.manual_seed(42)

# TODO: Load ICoT model
# TODO: Extract hidden states at layer 2 mid-point
# TODO: Compute ground truth ĉₖ values
# TODO: Train linear probes and evaluate MAE


In [None]:
# CQ2: SOLUTION
import torch
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_absolute_error
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model
from src.data_utils import read_operands

np.random.seed(42)
torch.manual_seed(42)

# Checkpoint files handed to load_hf_model
config_path = "/home/smallyan/critic_model_mechinterp/icot/ckpts/2L4H/config.json"
state_dict_path = "/home/smallyan/critic_model_mechinterp/icot/ckpts/1_to_4_revops_2L_H4.pt"

print("Loading ICoT model...")
model, tokenizer = load_hf_model(config_path, state_dict_path)
# Use the GPU when one is visible, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)
model.eval()

# Validation problems: keep at most the first 300 entries of the file
data_path = "/home/smallyan/critic_model_mechinterp/icot/data/processed_valid.txt"
all_operands = read_operands(data_path)[:300]  # Use 300 samples

print(f"Loaded {len(all_operands)} multiplication problems")

# First 200 problems train the probes, the following (up to) 100 evaluate them
train_operands, test_operands = all_operands[:200], all_operands[200:300]

# Function to compute ground truth ĉₖ
def compute_c_hat(a_str, b_str, k):
    '''Sum the place-value-weighted digit products a_i * b_j with i + j <= k.

    The digit at string index i carries weight 10**i, so both operand
    strings are read with index 0 as the ones digit.

    Args:
        a_str: digit string of the first operand.
        b_str: digit string of the second operand.
        k: highest combined index i + j included in the sum.

    Returns:
        The integer sum of a_i * b_j * 10**(i + j) over all pairs with
        i + j <= k (0 when no pair qualifies).

    NOTE(review): values grow like 10**(k + 2), while the CQ2 auto-check
    uses an MAE threshold of 5.0 — confirm this is the intended definition
    of the running sum ĉₖ.
    '''
    a_digits = list(map(int, a_str))
    b_digits = list(map(int, b_str))

    return sum(
        digit_a * digit_b * (10 ** (i + j))
        for i, digit_a in enumerate(a_digits)
        for j, digit_b in enumerate(b_digits)
        if i + j <= k
    )

# Function to extract hidden states from the second transformer block
def extract_hidden_states(model, tokenizer, operand_pair, device, n_generate=8):
    """Return the hidden states captured at ``model.transformer.h[1]`` and the separator index.

    The prompt ``"{a} * {b}%%%####"`` ends at the last ``#``; a forward pass
    over the prompt alone has no positions after it, so the sequence is first
    extended greedily by ``n_generate`` tokens (argmax of the last-position
    logits at each step). The hidden states are then captured in one final
    forward pass over the extended sequence.

    Args:
        model: module exposing ``model.transformer.h[1]`` and returning an
            object with ``.logits`` indexable as ``[batch, seq_len, vocab]``.
        tokenizer: object whose ``encode(prompt, return_tensors='pt')`` returns
            a ``[1, seq_len]`` tensor of token ids.
        operand_pair: ``(a, b)`` digit strings.
        device: device the token tensor is moved to.
        n_generate: number of tokens appended after the prompt (default 8).

    Returns:
        ``(hidden, sep_pos)``: ``hidden`` is the first element of the hooked
        block's output for the extended sequence, and ``sep_pos`` is the index
        of the last prompt token, so ``sep_pos + k`` indexes the position
        whose next-token prediction is the k-th generated token.
    """
    a, b = operand_pair
    prompt = f"{a} * {b}%%%####"
    tokens = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Output starts after ####; reuse the tokens already encoded above
    sep_pos = tokens.shape[1] - 1

    hidden_states = {}

    def hook_fn(module, input, output):
        # A forward hook on the whole block sees the block's *output*.
        # NOTE(review): this is the state after the full block, not the
        # "mid-point" (after attention, before MLP) that the task text asks
        # for — if the blocks are HF GPT-2 blocks, a forward pre-hook on
        # h[1].ln_2 would capture the mid-point; confirm the architecture.
        hidden_states['layer2_mid'] = output[0].detach()

    with torch.no_grad():
        # NOTE(review): assumes every answer digit is exactly one token —
        # confirm against the tokenizer / data format.
        sequence = tokens
        for _ in range(n_generate):
            next_token = model(sequence).logits[0, -1].argmax().view(1, 1)
            sequence = torch.cat([sequence, next_token.to(sequence.dtype)], dim=1)

        # Register hook on layer 1 (second layer, 0-indexed) only for the
        # final pass, and always remove it, even if the forward pass fails.
        hook = model.transformer.h[1].register_forward_hook(hook_fn)
        try:
            _ = model(sequence)
        finally:
            hook.remove()

    return hidden_states['layer2_mid'], sep_pos

# Extract features and targets for k in {2, 3, 4}
probe_results = {}


def _collect_probe_data(operand_pairs, k):
    """Build the probe dataset for output index ``k``.

    For every ``(a, b)`` pair the hidden state at position ``sep_pos + k`` is
    used as the feature vector and ``compute_c_hat(a, b, k)`` as the target.
    Pairs whose hidden-state sequence is too short to contain that position
    are skipped.

    Returns:
        ``(X, y)`` as NumPy arrays.

    Raises:
        RuntimeError: if no pair provides a hidden state at ``sep_pos + k``,
            which would otherwise surface as an obscure error inside the
            regression fit.
    """
    features, targets = [], []
    for a, b in operand_pairs:
        hidden, sep_pos = extract_hidden_states(model, tokenizer, (a, b), device)
        # Get hidden state at position where c_k is computed
        if sep_pos + k < hidden.shape[1]:
            features.append(hidden[0, sep_pos + k, :].cpu().numpy())
            targets.append(compute_c_hat(a, b, k))

    if not features:
        raise RuntimeError(
            f"No hidden state available at position sep_pos + {k} for any of the "
            f"{len(operand_pairs)} problems; the sequence given to the model must "
            f"extend at least {k} tokens past the '####' separator."
        )
    return np.array(features), np.array(targets)


for k in [2, 3, 4]:
    print(f"\nProcessing k={k}...")

    # Collect training data
    X_train, y_train = _collect_probe_data(train_operands, k)

    print(f"  Training samples: {len(X_train)}")

    # Train linear probe with Ridge regression
    probe = Ridge(alpha=0.01)
    probe.fit(X_train, y_train)

    # Collect test data
    X_test, y_test = _collect_probe_data(test_operands, k)

    # Evaluate
    y_pred = probe.predict(X_test)
    mae = mean_absolute_error(y_test, y_pred)

    probe_results[k] = {
        'mae': mae,
        'n_train': len(X_train),
        'n_test': len(X_test)
    }

    print(f"  Test samples: {len(X_test)}")
    print(f"  MAE: {mae:.2f}")

# Print final results
print("\n" + "=" * 60)
print("Mean Absolute Error for running sum prediction:")
print("=" * 60)
for k in [2, 3, 4]:
    print(f"ĉ_{k}: {probe_results[k]['mae']:.2f}")
print("=" * 60)


In [None]:
# CQ2: AUTO-CHECK
print("\nAuto-check for CQ2:")
print("-" * 40)

# The solution cell must have been executed first
assert 'probe_results' in locals(), "Probe results not computed"

# Report each probe's MAE against the 5.0 threshold
for k in [2, 3, 4]:
    assert k in probe_results, f"Missing results for k={k}"
    mae = probe_results[k]['mae']
    print(f"ĉ_{k} MAE: {mae:.2f} (threshold: <5.0)")

    k_passed = mae < 5.0
    print(f"  ✓ PASS for k={k}" if k_passed else f"  ✗ FAIL for k={k}")

# Overall verdict: every probe has to be below the threshold
all_pass = all([probe_results[k]['mae'] < 5.0 for k in [2, 3, 4]])

print("-" * 40)
print(
    "✓ ALL TESTS PASSED: Running sums are linearly decodable"
    if all_pass
    else "✗ SOME TESTS FAILED: Check implementation"
)


---
## CQ3: Fourier Basis R² Computation for Digit Embeddings

**Objective:** Verify that digit embeddings follow Fourier basis structure with k ∈ {0,1,2,5}.


In [None]:
# CQ3: Fourier Basis R² Computation (STUB)
import torch
import numpy as np
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')

# Set random seed
np.random.seed(42)
torch.manual_seed(42)

# TODO: Load ICoT model and extract embeddings
# TODO: Construct Fourier basis matrix
# TODO: Fit coefficients and compute R² for each dimension
# TODO: Report median R²


In [None]:
# CQ3: SOLUTION
import torch
import numpy as np
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model

np.random.seed(42)
torch.manual_seed(42)

# Checkpoint files handed to load_hf_model
config_path = "/home/smallyan/critic_model_mechinterp/icot/ckpts/2L4H/config.json"
state_dict_path = "/home/smallyan/critic_model_mechinterp/icot/ckpts/1_to_4_revops_2L_H4.pt"

print("Loading ICoT model...")
model, tokenizer = load_hf_model(config_path, state_dict_path)

# Look up one token id per digit 0-9: the first id the tokenizer returns for
# the single-character string.
# NOTE(review): assumes encode() does not prepend a special token — confirm.
digit_strings = [str(d) for d in range(10)]
digit_token_ids = [tokenizer.encode(digit)[0] for digit in digit_strings]
print(f"Digit token IDs: {digit_token_ids}")

# Rows of the token-embedding matrix belonging to those ten ids
token_embedding_weight = model.transformer.wte.weight
embeddings = token_embedding_weight[digit_token_ids, :].detach().cpu().numpy()
print(f"Embedding shape: {embeddings.shape}")  # (10, embedding_dim)

# Build the 10 x 6 Fourier design matrix Φ over the digits n = 0..9.
# Columns, in order: constant (k=0), cos/sin with period 10 (k=1),
# cos/sin with period 5 (k=2), and the alternating sign (-1)^n (k=5).

n = np.arange(10)
angle_k1 = 2 * np.pi * n / 10   # one full turn over the ten digits
angle_k2 = 2 * np.pi * n / 5    # two full turns over the ten digits

phi = np.column_stack((
    np.ones(10),
    np.cos(angle_k1),
    np.sin(angle_k1),
    np.cos(angle_k2),
    np.sin(angle_k2),
    (-1) ** n,
))

print(f"Fourier basis shape: {phi.shape}")  # (10, 6)

# R² of the Fourier fit, computed separately for every embedding dimension
def _fourier_r2(x_d):
    """Least-squares fit of ``x_d`` (one value per digit) onto ``phi``; returns R².

    R² is ``1 - SS_res / SS_tot``; when the total sum of squares is not
    greater than 1e-10 the value 0.0 is returned to avoid dividing by zero.
    """
    coefficients = np.linalg.lstsq(phi, x_d, rcond=None)[0]
    fitted = phi @ coefficients

    ss_res = np.sum((x_d - fitted) ** 2)
    ss_tot = np.sum((x_d - np.mean(x_d)) ** 2)

    if ss_tot > 1e-10:
        return 1 - (ss_res / ss_tot)
    return 0.0


n_dims = embeddings.shape[1]
# Column d of `embeddings` holds dimension d across all ten digits
r2_values = np.array([_fourier_r2(embeddings[:, d]) for d in range(n_dims)])

# Compute median R²
median_r2 = np.median(r2_values)

# Summary statistics of the per-dimension R² values
separator = "=" * 60
mean_r2 = np.mean(r2_values)
std_r2 = np.std(r2_values)
min_r2 = np.min(r2_values)
max_r2 = np.max(r2_values)

print("\n" + separator)
print("Fourier Basis R² Analysis")
print(separator)
print(f"Number of embedding dimensions: {n_dims}")
print(f"Median R²: {median_r2:.4f}")
print(f"Mean R²: {mean_r2:.4f}")
print(f"Std R²: {std_r2:.4f}")
print(f"Min R²: {min_r2:.4f}")
print(f"Max R²: {max_r2:.4f}")
print(separator)

# Show distribution
percentiles = [25, 50, 75, 90, 95]
print("\nR² Distribution:")
for p, val in zip(percentiles, (np.percentile(r2_values, q) for q in percentiles)):
    print(f"  {p}th percentile: {val:.4f}")


In [None]:
# CQ3: AUTO-CHECK
print("\nAuto-check for CQ3:")
print("-" * 40)

# The solution cell must have been executed first
assert 'median_r2' in locals(), "Median R² not computed"

print(f"Median R²: {median_r2:.4f} (threshold: >0.80)")

# Pass above 0.80; an extra note is printed above 0.84
meets_threshold = median_r2 > 0.80
if not meets_threshold:
    print("✗ FAIL: Fourier basis does not explain sufficient variance")
else:
    print("✓ PASS: Fourier basis explains >80% of variance")
    if median_r2 > 0.84:
        print("✓ EXCELLENT: Exceeds documented performance (0.84)")

print("-" * 40)
print(f"Expected: R² > 0.80")
print(f"Achieved: R² = {median_r2:.4f}")
