# ICoT Multiplication Research Assessment

This notebook contains questions that assess understanding of the research documented in the ICoT (Implicit Chain-of-Thought) multiplication study, in multiple-choice, free-generation, and code-based formats.

**Important:** All questions should be answered based solely on the documentation provided. Reference the documentation sections as needed.

## Knowledge Key Points

The assessment covers the following key knowledge areas:

### 1. Research Problem & Motivation
- Why Transformers fail at multi-digit multiplication
- The significance of 4×4 digit multiplication as a test case
- The gap between model scale and reasoning capability

### 2. Data Format & Processing  
- Least-significant-digit-first representation
- Chain-of-thought token format and gradual removal
- Special delimiter tokens and their purpose
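Below is a minimal sketch of least-significant-digit-first formatting. It assumes operands are plain Python integers and follows the `a * b` layout used in the Question 1 options. The helper names are hypothetical. The actual ICoT data files may also space-separate digits and append delimiter and CoT tokens, which this sketch omits. The product is reversed the same way, on the assumption that outputs follow the c₀, c₁, … order used throughout.

```python
def to_lsd_first(n: int) -> str:
    """Return the decimal digits of n, least-significant digit first."""
    return str(n)[::-1]


def format_problem(a: int, b: int) -> str:
    """Hypothetical helper: format a x b with both operands reversed (digit order only)."""
    return f"{to_lsd_first(a)} * {to_lsd_first(b)}"


print(format_problem(1234, 5678))  # 4321 * 8765
print(to_lsd_first(1234 * 5678))   # 2566007
```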

### 3. Model Architectures & Training
- ICoT vs SFT vs Auxiliary Loss models
- Architecture specifications (layers, heads, dimensions)
- Training hyperparameters and convergence

### 4. Long-Range Dependencies
- Definition and requirement for multiplication
- Logit attribution patterns
- Linear probe experiments and ĉₖ prediction
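The running sum ĉₖ can be illustrated with the standard long-multiplication recurrence over LSD-first digit lists. This sketch assumes ĉₖ = Σ_{i+j=k} aᵢbⱼ + rₖ₋₁, that is, the partial products whose indices sum to k plus the incoming carry. It also assumes cₖ = ĉₖ mod 10 and rₖ = ⌊ĉₖ/10⌋. The function names are hypothetical and are not taken from the study's codebase.

```python
def digits_lsd(n):
    """Decimal digits of n, least-significant first (hypothetical helper)."""
    return [int(ch) for ch in str(n)[::-1]]


def running_sums(a_digits, b_digits):
    """Return (chat, c): chat[k] = s_k + carry_in, c[k] = chat[k] % 10."""
    chat, c = [], []
    carry = 0
    for k in range(len(a_digits) + len(b_digits)):
        # s_k: all partial products a_i * b_j whose indices satisfy i + j = k
        s_k = sum(
            a_digits[i] * b_digits[k - i]
            for i in range(len(a_digits))
            if 0 <= k - i < len(b_digits)
        )
        value = s_k + carry
        chat.append(value)
        c.append(value % 10)
        carry = value // 10  # r_k, passed on to position k + 1
    return chat, c


a, b = 1234, 5678
chat, c = running_sums(digits_lsd(a), digits_lsd(b))
print(chat)  # [32, 55, 66, 66, 40, 20, 7, 0]
print(c)     # [2, 5, 6, 6, 0, 0, 7, 0]
print(sum(d * 10**k for k, d in enumerate(c)) == a * b)  # True
```

Each carry rₖ feeds the next position. As a result, a middle digit such as c₄ depends on every partial product at positions 0 through 4, not only on the pairs with i+j=4.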

### 5. Attention Tree Mechanism
- Two-layer caching and retrieval structure
- Binary tree-like information flow
- Role of each layer in the computation

### 6. Geometric Representations
- Minkowski sums in attention heads
- Fourier basis with frequencies k ∈ {0,1,2,5}
- Pentagonal prism structure in 3D PCA

### 7. Training Dynamics & Failure Modes
- SFT learning pattern (c₀, c₁, c₇ first)
- Gradient plateau for middle digits
- Local optimum trap vs capacity limitation

### 8. ICoT Success Factors
- Implicit supervision through CoT removal
- Guided discovery of computational structures
- Internalization of intermediate computations

### 9. Alternative Approaches
- Auxiliary loss for ĉₖ prediction
- Inductive biases that enable learning
- Comparison with explicit supervision

### 10. Broader Implications
- Optimization challenges in transformer training
- Value of mechanistic interpretability
- Process supervision for complex reasoning

---


## Question 1

**Type:** Multiple Choice

In the ICoT multiplication task, the numbers are represented in a specific digit order. How would the multiplication problem 8331 × 5015 be formatted as input to the model?

A) 8331 * 5015
B) 1338 * 5105
C) 3318 * 1505
D) 5015 * 8331

**Your answer:**

## Question 2

**Type:** Multiple Choice

Which of the following statements correctly describes the performance difference between ICoT and standard fine-tuning (SFT) models on 4×4 digit multiplication?

A) Both achieve >95% accuracy, but ICoT trains faster
B) SFT achieves 81% accuracy while ICoT achieves 100%
C) ICoT achieves 100% accuracy while SFT achieves <1% accuracy
D) ICoT achieves 99% accuracy only when using auxiliary loss

**Your answer:**

## Question 3

**Type:** Free Generation

Describe the two-layer attention tree mechanism discovered in the ICoT model. What role does each layer play in computing the output digits?

**Your answer:**

## Question 4

**Type:** Multiple Choice

The ICoT model represents digits using Fourier bases. Which set of frequency components (k values) is used in the discovered Fourier basis representation?

A) k ∈ {0, 1, 2, 3, 4}
B) k ∈ {0, 1, 2, 5}
C) k ∈ {0, 2, 4, 5}
D) k ∈ {1, 2, 3, 4, 5}

**Your answer:**

## Question 5

**Type:** Free Generation

Explain why standard fine-tuning (SFT) fails to learn multi-digit multiplication, despite the model having sufficient capacity. What specific pattern is observed in gradient norms and loss during SFT training?

**Your answer:**

## Question 6

**Type:** Multiple Choice

What geometric structure emerges in the ICoT model's attention head outputs, and how is it mathematically characterized?

A) Minkowski sums: ATT¹(i,j) = αAᵢ + (1-α)Bⱼ + ε, creating nested cluster structures
B) Euclidean products: ATT¹(i,j) = Aᵢ · Bⱼ, creating orthogonal representations
C) Tensor products: ATT¹(i,j) = Aᵢ ⊗ Bⱼ, creating high-dimensional embeddings
D) Cartesian products: ATT¹(i,j) = (Aᵢ, Bⱼ), creating paired representations

**Your answer:**

## Question 7

**Type:** Multiple Choice

How many chain-of-thought (CoT) tokens are removed per epoch during ICoT training, and how many total epochs are needed for convergence?

A) 5 tokens per epoch, 10 epochs total
B) 10 tokens per epoch, 8 epochs total
C) 13 tokens per epoch, 8 epochs total
D) 8 tokens per epoch, 13 epochs total

**Your answer:**

## Question 8

**Type:** Free Generation

What do the linear probe experiments reveal about the ICoT model's internal representations? Specifically, what is being probed, and what do the results indicate about the difference between ICoT and SFT models?

**Your answer:**

## Question 9

**Type:** Free Generation

Based on the discovered mechanisms in the ICoT model for 4×4 digit multiplication, predict what architectural changes would be necessary to successfully learn 6×6 digit multiplication using the same ICoT training approach. Justify your prediction.

**Your answer:**

## Question 10

**Type:** Multiple Choice

Suppose you wanted to train a model to perform multi-digit division using insights from the ICoT multiplication study. Which training approach would most likely succeed?

A) Standard fine-tuning on division problems with a 12-layer model
B) Standard fine-tuning on division problems with more training data
C) ICoT-style training with gradual removal of long-division step tokens
D) Training only on the final quotient without any intermediate steps

**Your answer:**

## Question 11

**Type:** Free Generation

The ICoT model shows a characteristic logit attribution pattern describing how each input digit position affects each output digit. Based on the multiplication algorithm, explain why digit aᵢ should affect output digit cₖ most strongly when there exists a digit bⱼ such that i+j=k. What does this pattern reveal about the model's learned algorithm?

**Your answer:**

## Question 12

**Type:** Free Generation

The auxiliary loss model achieves 99% accuracy without explicit chain-of-thought tokens by predicting ĉₖ values. Why does this approach work? What inductive bias does it provide that standard fine-tuning lacks?

**Your answer:**

## Question 13

**Type:** Multiple Choice

A researcher hypothesizes that SFT fails at 4×4 multiplication due to insufficient model capacity, and proposes training a 24-layer, 16-head model. Based on the paper's findings, what would be the most likely outcome?

A) The larger model would achieve >90% accuracy due to increased capacity
B) The larger model would still achieve <1% accuracy, failing identically to smaller models
C) The larger model would achieve 50-60% accuracy, showing partial improvement
D) The larger model would require less training data to achieve 100% accuracy

**Your answer:**

## Question 14

**Type:** Free Generation

Describe the pentagonal prism geometry discovered in ICoT's 3D PCA analysis. What do the three principal components represent, and why does this geometry emerge?

**Your answer:**

## Question 15

**Type:** Multiple Choice

When training linear probes to decode the running sum ĉₖ, which location in the model provides the best decoding accuracy?

A) Layer 0 residual stream (after first layer MLP)
B) Layer 1 residual stream (after second layer MLP)
C) Layer 2 mid-point (after attention, before MLP)
D) Final hidden layer before unembedding

**Your answer:**

## Question 16

**Type:** Free Generation

The paper identifies that SFT fails due to an 'optimization problem' rather than a 'capacity problem'. Explain what this distinction means and provide evidence from the paper supporting this conclusion.

**Your answer:**

## Question 17

**Type:** Multiple Choice

In the Fourier basis analysis, what median R² value is achieved when fitting the ICoT model's final hidden layer representations using all Fourier frequencies k ∈ {0,1,2,3,4,5}?

A) 0.84
B) 0.95
C) 0.99
D) 1.0

**Your answer:**

## Question 18 [CQ1]

**Type:** Code-Based Question

Write code to verify the logit attribution pattern discovered in ICoT models. Specifically, test the hypothesis that input digit position aᵢ has the strongest effect on output digit cₖ when there exists a position j such that i+j=k.

Your code should:
1. Load the ICoT model checkpoint
2. Create a batch of 100 random 4×4 multiplication problems
3. For each problem, compute logit attribution by: (a) recording baseline logits for output positions c₂ through c₆, (b) for each input digit position, creating a counterfactual by replacing that digit with a different digit, and (c) measuring the change in logits
4. For each output position cₖ (k=2 to 6), compute the average absolute logit change for each input position
5. Identify which input positions have the highest attribution scores for each output position
6. Verify if the pattern matches i+j=k (print the top-3 most influential input positions for each output)

Expected output: for each output digit cₖ, the most influential inputs should be the positions aᵢ (and bⱼ) that have a partner index satisfying i+j=k.
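Before running the attribution loop on the real checkpoint, it can help to test the counterfactual procedure on a function whose dependencies are known. The sketch below uses a hypothetical `toy_logits` stand-in, which is not the ICoT model. Its logits for cₖ depend only on Σ_{i+j=k} aᵢbⱼ, with no carries. The harness should therefore assign exactly zero attribution to positions that have no partner index. The helper names, the one-hot logits, and the use of the mean absolute difference over the 10 digit logits are all illustrative assumptions.

```python
import random

N_DIGITS = 4
OUTPUTS = range(2, 7)  # output positions c2 .. c6


def toy_logits(a, b):
    """Hypothetical stand-in model: one-hot logits of (sum_{i+j=k} a_i b_j) mod 10."""
    logits = {}
    for k in OUTPUTS:
        s_k = sum(a[i] * b[k - i] for i in range(N_DIGITS) if 0 <= k - i < N_DIGITS)
        logits[k] = [1.0 if d == s_k % 10 else 0.0 for d in range(10)]
    return logits


def attribution(logits_fn, n_problems=100, seed=0):
    """Average |logit change| per (output k, input position) under one-digit counterfactuals."""
    rng = random.Random(seed)
    positions = [("a", i) for i in range(N_DIGITS)] + [("b", j) for j in range(N_DIGITS)]
    totals = {k: {p: 0.0 for p in positions} for k in OUTPUTS}
    for _ in range(n_problems):
        a = [rng.randrange(10) for _ in range(N_DIGITS)]
        b = [rng.randrange(10) for _ in range(N_DIGITS)]
        base = logits_fn(a, b)
        for name, idx in positions:
            a2, b2 = list(a), list(b)
            digits = a2 if name == "a" else b2
            # Counterfactual: replace this one digit with a different random digit.
            digits[idx] = rng.choice([d for d in range(10) if d != digits[idx]])
            cf = logits_fn(a2, b2)
            for k in OUTPUTS:
                diff = sum(abs(x - y) for x, y in zip(base[k], cf[k])) / 10
                totals[k][(name, idx)] += diff / n_problems
    return totals


scores = attribution(toy_logits)
for k, per_pos in scores.items():
    top3 = sorted(per_pos, key=per_pos.get, reverse=True)[:3]
    print(f"c{k}:", [f"{name}{idx}" for name, idx in top3])
```

With the real model, `logits_fn` would run a forward pass and read the logits at each cₖ timestep. That wiring depends on `src.model_utils` and is not shown here. For c₆, only a₃ and b₃ have a partner index, so the third entry in its top-3 is a zero-score tie.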

In [None]:
# CQ1: Logit Attribution Pattern Verification
import torch
import numpy as np
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model
from src.data_utils import prompt_ci_raw_format_batch

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# TODO: Load the ICoT model
# model, tokenizer = load_hf_model(...)

# TODO: Generate 100 random 4×4 multiplication problems
# operands = [(a, b), ...] where a, b are 4-digit strings (reversed)

# TODO: For each problem, compute baseline logits for output positions c2-c6

# TODO: For each input digit position, create counterfactuals and measure logit changes

# TODO: Aggregate attribution scores and identify top-3 influential positions for each output

# TODO: Print results showing which input positions most affect each output position
# Expected: positions where i+j=k should show highest attribution

print("Top-3 influential input positions for each output digit:")
# Print your results here


## Question 19 [CQ2]

**Type:** Code-Based Question

Write code to replicate the linear probe experiment that demonstrates ICoT models encode the running sum ĉₖ in their hidden states.

Your code should:
1. Load the ICoT model and extract hidden states at layer 2 mid-point (after attention, before MLP)
2. For a validation set of 200 multiplication problems, extract hidden states at the timestep where cₖ is being computed (for k=2,3,4)
3. Compute the ground-truth ĉₖ values (the sum of the partial products aᵢbⱼ with i+j=k, plus the carry from position k-1)
4. Train a simple linear regression probe (using sklearn or torch) to predict ĉₖ from the hidden states
5. Evaluate the Mean Absolute Error (MAE) on a held-out test set of 100 problems
6. Print the MAE for each k ∈ {2,3,4}

Expected outcome: MAE should be low (<5.0) for all three positions, demonstrating that ĉ values are linearly decodable from hidden states.
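The probe-fitting and evaluation steps can be sanity-checked on synthetic data before touching the model. In this sketch, the "hidden states" are fabricated as ĉₖ times a random direction plus Gaussian noise. They are purely synthetic, not model activations. The 200/100 split mirrors the task, and plain least squares via `numpy.linalg.lstsq` stands in for sklearn's `Ridge`. The running-sum definition is the same long-multiplication assumption as above (ĉₖ = Σ_{i+j=k} aᵢbⱼ plus the incoming carry), and `chat_k` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)


def chat_k(a, b, k):
    """Running sum at position k for LSD-first digit sequences a and b (assumed definition)."""
    carry = 0
    for m in range(k + 1):
        s_m = sum(a[i] * b[m - i] for i in range(len(a)) if 0 <= m - i < len(b))
        value = s_m + carry
        carry = value // 10
    return value


n_train, n_test, d_model = 200, 100, 32
n_total = n_train + n_test
A = rng.integers(0, 10, size=(n_total, 4))
B = rng.integers(0, 10, size=(n_total, 4))
w = rng.normal(size=d_model)  # synthetic direction that "encodes" the running sum

maes = {}
for k in (2, 3, 4):
    y = np.array([chat_k(A[t], B[t], k) for t in range(n_total)], dtype=float)
    H = np.outer(y, w) + 0.1 * rng.normal(size=(n_total, d_model))  # fake hidden states
    X = np.hstack([H, np.ones((n_total, 1))])  # add an intercept column
    coef, *_ = np.linalg.lstsq(X[:n_train], y[:n_train], rcond=None)
    maes[k] = float(np.abs(X[n_train:] @ coef - y[n_train:]).mean())
    print(f"ĉ{k} synthetic probe MAE: {maes[k]:.3f}")
```

Real hidden states are 768-dimensional, and there are only 200 training problems. An unregularized fit is therefore underdetermined, which is one reason a regularized probe such as `Ridge` is the safer choice there.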

In [None]:
# CQ2: Running Sum Linear Probe Accuracy
import torch
import numpy as np
from sklearn.linear_model import Ridge
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model
from src.data_utils import read_operands, get_ci

# Set random seed
np.random.seed(42)
torch.manual_seed(42)

# TODO: Load ICoT model
# model, tokenizer = load_hf_model(...)

# TODO: Load validation data (300 problems: 200 train, 100 test)
# operands = read_operands('/home/smallyan/critic_model_mechinterp/icot/data/processed_valid.txt')

# TODO: For k in {2, 3, 4}:
#   1. Extract hidden states at layer 2 mid-point at timestep t_ck
#   2. Compute ground truth ĉₖ (sum of partial products aᵢbⱼ with i+j=k, plus the carry from position k-1)
#   3. Train linear regression probe on 200 samples
#   4. Evaluate MAE on 100 test samples

# TODO: Print MAE for each k
print("Mean Absolute Error for running sum prediction:")
# print(f"ĉ₂: {mae_2:.2f}")
# print(f"ĉ₃: {mae_3:.2f}")
# print(f"ĉ₄: {mae_4:.2f}")


## Question 20 [CQ3]

**Type:** Code-Based Question

Write code to compute the R² fit of the Fourier basis representation for digit embeddings in the ICoT model.

Your code should:
1. Load the ICoT model and extract the embedding matrix for digits 0-9
2. Construct the Fourier basis matrix Φ with frequencies k ∈ {0, 1, 2, 5}:
   - Column 0: constant (all 1s)
   - Columns 1-2: cos(2πn/10), sin(2πn/10) for n=0..9
   - Columns 3-4: cos(2πn/5), sin(2πn/5) for n=0..9  
   - Column 5: parity p(n) = (-1)ⁿ for n=0..9
3. For each dimension d in the embedding (768 dims), extract the vector x_d of length 10 (values for digits 0-9)
4. Fit coefficients: C_d = argmin_C ||x_d - ΦC||²
5. Compute R² = 1 - ||x_d - ΦC_d||² / ||x_d - mean(x_d)||²
6. Report the median R² across all embedding dimensions

Expected outcome: Median R² should be high (>0.80), confirming Fourier basis structure in embeddings.
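This is a small self-check of the R² computation, assuming the six-column basis listed above. The columns are mutually orthogonal over n = 0..9 (ΦᵀΦ is diagonal). A vector built from the basis should therefore score R² ≈ 1. A pure k = 3 component is orthogonal to every column, so it should score R² ≈ 0. The function name `fourier_r2` is hypothetical.

```python
import numpy as np

n = np.arange(10)
# Basis for k in {0, 1, 2, 5}: constant, k=1 cos/sin, k=2 cos/sin, k=5 parity.
phi = np.column_stack([
    np.ones(10),
    np.cos(2 * np.pi * n / 10), np.sin(2 * np.pi * n / 10),
    np.cos(2 * np.pi * n / 5), np.sin(2 * np.pi * n / 5),
    (-1.0) ** n,
])


def fourier_r2(x, basis=phi):
    """R^2 of the least-squares fit of a length-10 vector x onto the basis columns."""
    coef, *_ = np.linalg.lstsq(basis, x, rcond=None)
    resid = x - basis @ coef
    centered = x - x.mean()
    return 1.0 - (resid @ resid) / (centered @ centered)


in_span = 2.0 + 0.5 * phi[:, 1] - 1.5 * phi[:, 5]  # lies in the basis span
k3_only = np.cos(2 * np.pi * 3 * n / 10)           # frequency outside the basis

print(np.round(phi.T @ phi, 6))  # diagonal: 10, 5, 5, 5, 5, 10
print(fourier_r2(in_span))
print(fourier_r2(k3_only))
```

For a hypothetical embedding matrix `E` of shape (10, 768), the median would be `np.median([fourier_r2(E[:, d]) for d in range(E.shape[1])])`. The helper divides by the variance of x, so a constant dimension would need to be skipped.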

In [None]:
# CQ3: Fourier Basis R² Computation
import torch
import numpy as np
import sys
sys.path.append('/home/smallyan/critic_model_mechinterp/icot')
from src.model_utils import load_hf_model

# Set random seed
np.random.seed(42)
torch.manual_seed(42)

# TODO: Load ICoT model and extract embedding matrix for digits 0-9
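# model, tokenizer = load_hf_model(...)
# NOTE: wte rows 0-9 are the digit embeddings only if the digit tokens have ids 0-9.
# With a GPT-2 tokenizer they do not ('0' is id 15), so look the ids up first, e.g.
# digit_ids = [tokenizer.encode(f" {d}")[0] for d in range(10)]  # assumes space-separated digits
# embeddings = model.transformer.wte.weight[digit_ids, :].detach().cpu().numpy()  # Shape: (10, 768)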

# TODO: Construct Fourier basis matrix Φ (shape: 10 x 6)
# Columns: [constant, cos(2πn/10), sin(2πn/10), cos(2πn/5), sin(2πn/5), parity]
# n = np.arange(10)
# phi = np.column_stack([
#     np.ones(10),
#     np.cos(2 * np.pi * n / 10),
#     np.sin(2 * np.pi * n / 10),
#     np.cos(2 * np.pi * n / 5),
#     np.sin(2 * np.pi * n / 5),
#     (-1) ** n
# ])

# TODO: For each embedding dimension d:
#   1. Extract vector x_d of length 10 (values for digits 0-9)
#   2. Fit: C_d = argmin_C ||x_d - Φ @ C||² (use least squares)
#   3. Compute R² = 1 - ||x_d - Φ @ C_d||² / ||x_d - mean(x_d)||²

# TODO: Compute median R² across all dimensions
# median_r2 = np.median(r2_values)

print("Median R² for Fourier basis fit:")
# print(f"R² = {median_r2:.4f}")
