# CoT Phase Transition Experiment v3.1 - Curvature (κ) Measurement

**Version**: 3.1 (2024-12-24)

**Changes from v3.0**:
- Changed model from Llama 3 8B to **Mistral 7B** (no authentication required)
- Works with T4 GPU (A100 not required)

**Purpose**: Measure hidden state trajectory curvature (κ) to geometrically validate CoT collapse.

**Design**:
- Model: Mistral 7B v0.1 (open-weight, NO authentication needed)
- Conditions: Direct (no reasoning trace) / Standard CoT (A=1.0, clean trace) / Corrupted CoT (A=0.2, i.e. corruption rate λ=0.8: 8 of 10 steps replaced)
- Layers: L/2, 3L/4, L (3 layers)
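- Measurement: online curvature κ (mean turning angle between consecutive hidden-state steps) over the CoT token range; for the Direct condition, over the second half of the prompt tokens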

**Hypothesis**:
- Successful CoT → Low κ (smooth trajectory)
- Corrupted CoT → High κ (jagged trajectory)

## 0. Google Drive Connection

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
from datetime import datetime

EXPERIMENT_VERSION = 'v3.1'
EXPERIMENT_DATE = datetime.now().strftime('%Y%m%d')

# Root folder on Google Drive that holds all CoT experiment outputs.
SAVE_DIR = '/content/drive/MyDrive/CoT_Experiment'

# Locate Experiment 1's output directory (the source of the clean traces):
# the lexicographically last 'full_experiment_v3*' sub-directory, falling back
# to pilot_v2. Plain files with a matching name are skipped so they cannot be
# mistaken for a run directory.
exp1_dirs = [
    d for d in os.listdir(SAVE_DIR)
    if d.startswith('full_experiment_v3') and os.path.isdir(os.path.join(SAVE_DIR, d))
]
if exp1_dirs:
    EXP1_DIR = f'{SAVE_DIR}/{sorted(exp1_dirs)[-1]}'
else:
    EXP1_DIR = f'{SAVE_DIR}/pilot_v2'

print(f'Experiment 1 directory: {EXP1_DIR}')

# Output tree for this κ experiment (Experiment 2), versioned and dated.
SAVE_DIR_KAPPA = f'{SAVE_DIR}/kappa_experiment_{EXPERIMENT_VERSION}_{EXPERIMENT_DATE}'
os.makedirs(SAVE_DIR_KAPPA, exist_ok=True)
for subdir in ('results', 'checkpoints', 'figures'):
    os.makedirs(f'{SAVE_DIR_KAPPA}/{subdir}', exist_ok=True)

print(f'Kappa experiment save directory: {SAVE_DIR_KAPPA}')

## 1. Check GPU and Install Dependencies

In [None]:
!nvidia-smi

import torch
print(f'\nPyTorch version: {torch.__version__}')
print(f'CUDA available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Install the third-party packages this notebook imports (-q keeps pip quiet).
# NOTE(review): bitsandbytes and datasets are not used anywhere in this notebook as written; consider dropping them.
!pip install transformers accelerate bitsandbytes datasets matplotlib seaborn pandas tqdm scipy -q
print('Dependencies installed.')

## 2. Configuration

In [None]:
import json
import re
import math
import hashlib
import random
from typing import List, Dict, Tuple, Optional, Any
from dataclasses import dataclass, asdict
from datetime import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# =============================================================================
# Configuration
# =============================================================================
# Master seed; per-problem corruption seeds are derived from it via derive_seed().
GLOBAL_SEED = 20251224

# Base model, loaded from the Hugging Face Hub by name.
MODEL_NAME = "mistralai/Mistral-7B-v0.1"

# Numerical constants for the κ computation (spec-pack compliant).
EPSILON = 1e-8  # added to the norm product so zero-length steps never divide by zero
DELTA = 1e-6    # cosine is clamped to [-1 + DELTA, 1 - DELTA] before acos

# Experimental conditions, and the corruption rate for the 'corrupted' one.
CONDITIONS = ['direct', 'clean', 'corrupted']
CORRUPTION_LAMBDA = 0.8  # round(λ·I) of the I trace steps are replaced

# Write a checkpoint every this many problems.
CHECKPOINT_EVERY = 20

print(
    '=' * 60,
    'KAPPA EXPERIMENT CONFIGURATION',
    '=' * 60,
    f'  Model: {MODEL_NAME}',
    f'  Conditions: {CONDITIONS}',
    f'  Corruption λ: {CORRUPTION_LAMBDA}',
    '  NOTE: No authentication required for Mistral!',
    '=' * 60,
    sep='\n',
)

## 3. Load Model (Mistral 7B - No Auth Needed)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

print('Loading tokenizer...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# If the tokenizer defines no pad token, reuse EOS so padding-aware calls still work.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print('Loading model (this may take a few minutes)...')
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch.float16,
    output_hidden_states=True,
)
model.eval()

# Architecture facts used to pick the measurement layers.
NUM_LAYERS, HIDDEN_DIM = model.config.num_hidden_layers, model.config.hidden_size

print('\nModel loaded successfully!')
print(
    f'  Model: {MODEL_NAME}',
    f'  Layers: {NUM_LAYERS}',
    f'  Hidden dim: {HIDDEN_DIM}',
    sep='\n',
)

# Layers to measure (spec pack: L/2, 3L/4, L). These index the model's
# hidden_states tuple, where entry 0 is the embedding output, so L is the last layer.
MEASURE_LAYERS = [NUM_LAYERS * num // den for num, den in ((1, 2), (3, 4), (1, 1))]
print(f'  Measurement layers: {MEASURE_LAYERS}')

## 4. Load Clean Traces from Experiment 1

In [None]:
def load_json(filepath: str) -> Any:
    """Parse and return the UTF-8 JSON document stored at ``filepath``."""
    with open(filepath, encoding='utf-8') as handle:
        payload = json.load(handle)
    return payload

def save_json(data: Any, filepath: str):
    """Write ``data`` to ``filepath`` as indented UTF-8 JSON (non-ASCII kept as-is)."""
    with open(filepath, mode='w', encoding='utf-8') as out:
        json.dump(data, out, indent=2, ensure_ascii=False)

def _find_first_file(root_dir: str, predicate) -> Optional[str]:
    """Return the first file under ``root_dir`` whose basename satisfies ``predicate``, else None.

    Sub-directories and files are visited in sorted order so the choice is
    deterministic when several files match. The search stops at the first hit
    (a bare ``break`` inside an ``os.walk`` loop would only leave the inner loop
    and let later directories overwrite the result).
    """
    for root, dirs, files in os.walk(root_dir):
        dirs.sort()  # in-place sort controls the order os.walk descends
        for name in sorted(files):
            if predicate(name):
                return os.path.join(root, name)
    return None


def _resolve_json_path(candidates: List[str], predicate, label: str) -> str:
    """Return the first existing path in ``candidates``, else the first match under SAVE_DIR.

    Raises:
        FileNotFoundError: if no candidate exists and the recursive search finds nothing.
    """
    for path in candidates:
        if os.path.exists(path):
            return path
    found = _find_first_file(SAVE_DIR, predicate)
    if found is None:
        raise FileNotFoundError(
            f'Could not find {label}: none of {candidates} exist and no match under {SAVE_DIR}'
        )
    return found


# Clean traces (I=10): known locations first, then a recursive search of SAVE_DIR.
clean_traces_path = _resolve_json_path(
    [
        f'{EXP1_DIR}/clean_traces/clean_traces_I10_v3.json',
        f'{EXP1_DIR}/clean_traces_I10.json',
    ],
    lambda name: 'clean_traces_I10' in name and name.endswith('.json'),
    'clean traces (I=10)',
)

print(f'Loading clean traces from: {clean_traces_path}')
clean_traces_data = load_json(clean_traces_path)
print(f'Loaded {len(clean_traces_data)} clean traces')

# Problem definitions, located the same way.
problems_path = _resolve_json_path(
    [
        f'{EXP1_DIR}/problems_v3.json',
        f'{EXP1_DIR}/pilot_problems.json',
    ],
    lambda name: 'problems' in name and name.endswith('.json'),
    'problems file',
)

print(f'Loading problems from: {problems_path}')
problems_data = load_json(problems_path)
print(f'Loaded {len(problems_data)} problems')

# Index both collections by problem index for direct lookup.
problems_dict = {p['index']: p for p in problems_data}
traces_dict = {t['problem_index']: t for t in clean_traces_data}

## 5. Data Structures

In [None]:
@dataclass
class KappaResult:
    """One κ measurement: a single problem evaluated under a single condition."""

    problem_index: int                # the problem's 'index' field
    condition: str                    # one of CONDITIONS: 'direct', 'clean', 'corrupted'
    kappa_by_layer: Dict[int, float]  # measured layer -> mean turning angle (radians)
    kappa_mean: float                 # NaN-ignoring mean over kappa_by_layer values
    n_tokens: int                     # width of the measured token window
    timestamp: str                    # ISO-format time at which the result was recorded

## 6. Corruption Logic

In [None]:
def derive_seed(global_seed: int, problem_id: int, I: int, lam: float, replicate_id: int = 0) -> int:
    """Derive a deterministic 32-bit seed from the experiment coordinates.

    The coordinates are joined into a canonical key string, hashed with SHA-256,
    and the first 4 digest bytes are read big-endian (the same value as the
    first 8 hex characters of the hex digest).
    """
    parts = (f"{global_seed}", f"{problem_id}", f"I={I}", f"lam={lam}", f"rep={replicate_id}")
    digest = hashlib.sha256("|".join(parts).encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

def pick_corrupted_steps(I: int, lam: float, seed: int) -> List[int]:
    """Choose which 1-based step numbers to corrupt.

    ``round(lam * I)`` steps (Python rounding, ties to even) are taken from a
    seeded shuffle of 1..I and returned in ascending order.
    """
    n_corrupt = int(round(lam * I))
    if not n_corrupt:
        return []
    candidates = list(range(1, I + 1))
    random.Random(seed).shuffle(candidates)
    chosen = candidates[:n_corrupt]
    chosen.sort()
    return chosen

def assign_corruption_types(corrupted_steps: List[int], seed: int) -> Dict[int, str]:
    """Label each corrupted step as "IRR", "LOC" or "WRONG".

    With K = len(corrupted_steps), K // 5 steps become IRR (irrelevant),
    (2K) // 5 become LOC (local arithmetic error) and the rest WRONG (wrong
    constraint). Because K//5 + 2K//5 <= 3K/5 < K for K >= 1, at least one step
    is always WRONG, so no rebalancing is needed. Steps receive labels in the
    order of a shuffle seeded with ``seed + 1``.

    Returns:
        Mapping step number -> label; empty for an empty input.
    """
    n_steps = len(corrupted_steps)
    if n_steps == 0:
        return {}
    n_irr = n_steps // 5
    n_loc = (2 * n_steps) // 5
    n_wrong = n_steps - n_irr - n_loc  # always >= 1 (see docstring)

    shuffled = list(corrupted_steps)  # copy; also accepts tuples
    random.Random(seed + 1).shuffle(shuffled)
    labels = ["IRR"] * n_irr + ["LOC"] * n_loc + ["WRONG"] * n_wrong
    return dict(zip(shuffled, labels))

# Template for an "IRR" (irrelevant) corruption: a correct addition whose
# result the rest of the trace never uses. Filled by generate_irrelevant_step.
IRRELEVANT_TEMPLATES = [
    "Compute an auxiliary value: aux = {a} + {b} = {result}, but it will not be used later.",
]

# Template for a "WRONG" corruption: asserts an arbitrary value for a variable
# as a constraint on the remaining steps. Filled by generate_wrong_constraint_step.
WRONG_CONSTRAINT_TEMPLATES = [
    "Fix an intermediate condition: set {var} = {wrong_value} as a given constraint for the rest of the steps.",
]

def generate_irrelevant_step(step_num: int, seed: int) -> str:
    """Return an "IRR" step: a correct but unused addition of two seeded ints in [2, 20].

    ``step_num`` is accepted for signature parity with the other generators and
    is not used; the operands depend only on ``seed``.
    """
    draw = random.Random(seed).randint
    first = draw(2, 20)
    second = draw(2, 20)
    return IRRELEVANT_TEMPLATES[0].format(a=first, b=second, result=first + second)

def generate_local_error_step(original_step: str, seed: int) -> str:
    """Return ``original_step`` with its last integer perturbed (the "LOC" corruption).

    The last integer in the step is shifted by a nonzero offset in [-3, 3]
    (clamped at 0). The rewrite prefers a trailing "= N.", then a trailing "N.",
    and otherwise replaces the last integer wherever it occurs, so the returned
    step always differs from the input. A step with no digits is replaced by a
    fixed arithmetic statement with a wrong result.
    """
    rng = random.Random(seed)
    matches = list(re.finditer(r'\d+', original_step))
    if not matches:
        # 10 * 3 really is 30, so a draw of 30 would not be an error; redraw only
        # in that case so the other draws stay identical to earlier runs.
        value = rng.randint(28, 32)
        if value == 30:
            value = rng.choice([28, 29, 31, 32])
        return f"Compute t = 10 * 3 = {value} (using the previous values)."

    last = matches[-1]
    original_result = int(last.group())
    offset = rng.choice([-3, -2, -1, 1, 2, 3])
    wrong_result = max(0, original_result + offset)
    if wrong_result == original_result:
        # Only when original_result == 0 and offset < 0: the clamp undid the error.
        wrong_result = original_result + abs(offset)

    modified = re.sub(r'= (\d+)\.$', f'= {wrong_result}.', original_step)
    if modified == original_step:
        modified = re.sub(r'(\d+)\.$', f'{wrong_result}.', original_step)
    if modified == original_step:
        # No trailing number (e.g. "... = 8 apples."): replace the last integer in place.
        modified = original_step[:last.start()] + str(wrong_result) + original_step[last.end():]
    return modified

def generate_wrong_constraint_step(step_num: int, seed: int) -> str:
    """Return a "WRONG" step that fixes a seeded variable to a seeded value in [10, 100].

    ``step_num`` is unused; the output depends only on ``seed``.
    """
    rng = random.Random(seed)
    chosen_var = rng.choice(['x', 'total', 'result', 'n'])
    value = rng.randint(10, 100)
    return WRONG_CONSTRAINT_TEMPLATES[0].format(wrong_value=value, var=chosen_var)

def apply_corruption(trace_data: dict, lam: float, seed: int) -> str:
    """Render a corrupted CoT block from a clean trace.

    round(lam * I) of the I steps in ``trace_data['steps']`` are chosen by
    ``pick_corrupted_steps`` and labelled by ``assign_corruption_types``; each
    chosen step is regenerated with a per-step seed of ``seed + 1000 * step``.
    The result is wrapped in [[COT_START]] / [[COT_END]] markers with one
    "Step N: ..." line per step.
    NOTE(review): assumes the clean steps carry no "Step N:" prefix of their own — confirm against Experiment 1 output.
    """
    original_steps = list(trace_data['steps'])
    labels = assign_corruption_types(pick_corrupted_steps(len(original_steps), lam, seed), seed)

    def _rewrite(step_num: int, text: str) -> str:
        # Keep clean steps verbatim; regenerate corrupted ones by their label.
        kind = labels.get(step_num)
        if kind is None:
            return text
        step_seed = seed + 1000 * step_num
        if kind == 'IRR':
            return generate_irrelevant_step(step_num, step_seed)
        if kind == 'LOC':
            return generate_local_error_step(text, step_seed)
        return generate_wrong_constraint_step(step_num, step_seed)

    body = [f'Step {n}: {_rewrite(n, text)}' for n, text in enumerate(original_steps, start=1)]
    return '\n'.join(['[[COT_START]]', *body, '[[COT_END]]'])

print('Corruption logic defined.')

## 7. Prompt Construction

In [None]:
def create_prompt_direct(problem: dict) -> str:
    """Build the Direct-condition prompt: the question and an answer cue, no reasoning trace."""
    lines = [
        f"Problem: {problem['question']}",
        "",
        "Compute the final numerical answer.",
        "Answer:",
    ]
    return "\n".join(lines)

def create_prompt_with_cot(problem: dict, cot_text: str) -> str:
    """Build a prompt that shows ``cot_text`` as a reasoning trace before the answer cue."""
    lines = [
        f"Problem: {problem['question']}",
        "",
        "Reasoning trace:",
        cot_text,
        "",
        "Based on the trace above, compute the final numerical answer.",
        "Answer:",
    ]
    return "\n".join(lines)

print('Prompt functions defined.')

## 8. Curvature (κ) Computation

In [None]:
def find_cot_token_range(input_ids: torch.Tensor, tokenizer) -> Tuple[int, int]:
    """Locate the token span strictly between [[COT_START]] and [[COT_END]].

    Args:
        input_ids: token ids with a leading batch axis; only row 0 is used.
        tokenizer: object whose ``decode`` maps a list of ids to text.

    Returns:
        ``(start, end)`` such that ``input_ids[0, start:end]`` covers the
        reasoning steps: ``start`` is the first token after the start marker is
        complete, and ``end`` is the first token that reaches into the end
        marker. The end marker is searched only after the start marker. If either
        marker is missing, returns ``(0, seq_len)``. The span is clipped to at
        least 3 tokens where the sequence allows.

    Token boundaries are found by binary search on the decoded length of id
    prefixes. This avoids re-encoding decoded text, which can tokenize
    differently from the original ids (for example around BOS or a leading
    space) and shift the window.
    """
    ids = input_ids[0].tolist()
    n_tokens = len(ids)
    text = tokenizer.decode(ids)
    start_marker = '[[COT_START]]'
    end_marker = '[[COT_END]]'

    start_pos = text.find(start_marker)
    end_pos = -1 if start_pos == -1 else text.find(end_marker, start_pos + len(start_marker))
    if start_pos == -1 or end_pos == -1:
        return 0, n_tokens

    def decoded_len(k: int) -> int:
        # Number of characters produced by decoding the first k tokens.
        return len(tokenizer.decode(ids[:k]))

    def first_prefix_longer_than(n_chars: int) -> int:
        # Smallest k in [0, n_tokens] with decoded_len(k) > n_chars (n_tokens if none).
        lo, hi = 0, n_tokens
        while lo < hi:
            mid = (lo + hi) // 2
            if decoded_len(mid) > n_chars:
                hi = mid
            else:
                lo = mid + 1
        return lo

    # First prefix that fully contains the start marker -> content begins at that token.
    start_token_idx = first_prefix_longer_than(start_pos + len(start_marker) - 1)
    # Token (k - 1) is the first one reaching the end marker -> content ends before it.
    end_token_idx = first_prefix_longer_than(end_pos) - 1

    # Safety clip: keep indices in range and the span at least 3 tokens wide.
    start_token_idx = max(0, min(start_token_idx, n_tokens - 3))
    end_token_idx = max(start_token_idx + 3, min(end_token_idx, n_tokens))

    return start_token_idx, end_token_idx


def compute_curvature_online(hidden_states: torch.Tensor, eps: float = EPSILON, delta: float = DELTA) -> float:
    """Return the mean turning angle (radians) of a hidden-state trajectory (spec-pack compliant).

    The rows h_0..h_{T-1} of ``hidden_states`` are treated as points. Steps are
    v_t = h_{t+1} - h_t, and for each consecutive pair of steps the angle
    acos(clamp(<v_t, v_{t+1}> / (|v_t| * |v_{t+1}| + eps))) is taken, with the
    cosine clamped to [-1 + delta, 1 - delta]. κ is the mean of these T-2
    angles. Returns NaN when T < 3.
    """
    n_points = hidden_states.shape[0]
    if n_points < 3:
        return float('nan')

    # Displacement between consecutive hidden states: steps[i] = h[i+1] - h[i].
    steps = hidden_states[1:] - hidden_states[:-1]

    angles = []
    for incoming, outgoing in zip(steps[:-1], steps[1:]):
        numerator = torch.dot(incoming, outgoing).item()
        denominator = (torch.norm(incoming).item() * torch.norm(outgoing).item()) + eps
        cosine = max(-1.0 + delta, min(1.0 - delta, numerator / denominator))
        angles.append(math.acos(cosine))

    return sum(angles) / max(len(angles), 1)

print('Curvature computation functions defined.')

## 9. Main Measurement Function

In [None]:
@torch.no_grad()
def measure_kappa(prompt: str, measure_layers: List[int], has_cot: bool = True) -> Tuple[Dict[int, float], int]:
    """Run one forward pass on ``prompt`` and return (κ per layer, window width in tokens).

    With ``has_cot`` the window is the token span between the CoT markers
    (see find_cot_token_range). Without it (Direct condition) the window is the
    second half of the prompt tokens. Layer indices address the model's
    ``hidden_states`` tuple, where entry 0 is the embedding output and entry i
    is the output of layer i. Indices past the end fall back to the last entry.
    NOTE(review): in HF Llama/Mistral implementations the last hidden_states entry is,
    as far as we know, taken after the final norm — confirm; if so, layer L is not
    directly comparable with L/2 and 3L/4.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    token_ids = encoded['input_ids']

    layer_states = model(**encoded, output_hidden_states=True).hidden_states

    if has_cot:
        window_start, window_end = find_cot_token_range(token_ids, tokenizer)
    else:
        total = token_ids.shape[1]
        window_start, window_end = total // 2, total

    last_index = len(layer_states) - 1
    kappa_by_layer = {
        layer: compute_curvature_online(
            layer_states[min(layer, last_index)][0, window_start:window_end, :].float()
        )
        for layer in measure_layers
    }
    return kappa_by_layer, window_end - window_start


# Smoke test: one κ measurement on a problem paired with ITS OWN clean trace.
# problems_data[i] and clean_traces_data[i] need not describe the same problem,
# so the trace is looked up by problem index instead of by list position.
if clean_traces_data and problems_data:
    test_prob = next((p for p in problems_data if p['index'] in traces_dict), None)
    if test_prob is None:
        print('Skipping κ smoke test: no problem has a matching clean trace.')
    else:
        test_trace = traces_dict[test_prob['index']]
        test_prompt = create_prompt_with_cot(test_prob, test_trace['full_text'])

        print('Testing κ measurement...')
        kappa_result, n_tok = measure_kappa(test_prompt, MEASURE_LAYERS, has_cot=True)
        print(f'  n_tokens: {n_tok}')
        print(f'  κ by layer: {kappa_result}')
        print(f'  κ mean: {np.nanmean(list(kappa_result.values())):.4f}')
        print('\nTest successful!')

## 10. Run Pilot (5 problems)

In [None]:
def run_kappa_experiment(
    problems: List[dict],
    traces: Dict[int, dict],
    max_problems: Optional[int] = None
) -> List[KappaResult]:
    """Measure κ for every problem under every condition in CONDITIONS.

    Args:
        problems: problem records; each needs an 'index' key.
        traces: clean traces keyed by problem index. Problems without a trace are skipped.
        max_problems: if truthy, only the first ``max_problems`` problems are used.

    Returns:
        One KappaResult per successfully measured (problem, condition) pair.
        A failing condition is reported with print and skipped.

    Side effects:
        Every CHECKPOINT_EVERY processed problems, all results so far are written
        to ``{SAVE_DIR_KAPPA}/checkpoints/latest_kappa_v3.1.json``.
    """
    results = []

    if max_problems:
        problems = problems[:max_problems]

    # Checkpoint cadence counts processed problems rather than len(results), so a
    # failed condition (fewer than len(CONDITIONS) results for a problem) cannot
    # knock the cadence out of step and silently stop further checkpoints.
    n_processed = 0

    for prob in tqdm(problems, desc='Problems'):
        prob_idx = prob['index']

        if prob_idx not in traces:
            continue

        trace_data = traces[prob_idx]

        for condition in CONDITIONS:
            try:
                if condition == 'direct':
                    prompt = create_prompt_direct(prob)
                    has_cot = False
                elif condition == 'clean':
                    prompt = create_prompt_with_cot(prob, trace_data['full_text'])
                    has_cot = True
                else:
                    seed = derive_seed(GLOBAL_SEED, prob_idx, I=10, lam=CORRUPTION_LAMBDA)
                    corrupted_cot = apply_corruption(trace_data, CORRUPTION_LAMBDA, seed)
                    prompt = create_prompt_with_cot(prob, corrupted_cot)
                    has_cot = True

                kappa_by_layer, n_tokens = measure_kappa(prompt, MEASURE_LAYERS, has_cot)
                kappa_mean = np.nanmean(list(kappa_by_layer.values()))

                result = KappaResult(
                    problem_index=prob_idx,
                    condition=condition,
                    kappa_by_layer=kappa_by_layer,
                    kappa_mean=float(kappa_mean),
                    n_tokens=n_tokens,
                    timestamp=datetime.now().isoformat()
                )
                results.append(result)

            except Exception as e:
                # Best effort: report and move on to the next condition.
                print(f'\nError for problem {prob_idx}, condition {condition}: {e}')

        n_processed += 1
        if n_processed % CHECKPOINT_EVERY == 0 and results:
            save_json([asdict(r) for r in results], f'{SAVE_DIR_KAPPA}/checkpoints/latest_kappa_v3.1.json')

    return results


# Pilot run on the first 5 problems, to check the pipeline before the full run.
print('='*60)
print('PILOT: Testing with 5 problems')
print('='*60)

pilot_results = run_kappa_experiment(problems_data, traces_dict, max_problems=5)

# Pilot summary per condition. With zero results the DataFrame has no
# 'condition' column and groupby would raise KeyError, so report that instead.
pilot_df = pd.DataFrame([asdict(r) for r in pilot_results])
print('\nPilot Results:')
if pilot_df.empty:
    print('  No pilot results were produced; check the errors printed above.')
else:
    print(pilot_df.groupby('condition')['kappa_mean'].agg(['mean', 'std', 'count']))

## 11. Run Full Experiment

In [None]:
# Run only after reviewing the pilot results above.
print('=' * 60, 'FULL EXPERIMENT', '=' * 60, sep='\n')

all_kappa_results = run_kappa_experiment(problems_data, traces_dict, max_problems=None)

# Persist every measurement (one record per problem x condition) as JSON.
kappa_results_path = f'{SAVE_DIR_KAPPA}/results/kappa_results_v3.1.json'
save_json(list(map(asdict, all_kappa_results)), kappa_results_path)
print(f'\nResults saved to: {kappa_results_path}')

## 12. Analysis and Visualization

In [None]:
# One row per (problem, condition) measurement.
df_kappa = pd.DataFrame(list(map(asdict, all_kappa_results)))

print('Kappa Results Summary:')
print(df_kappa.groupby('condition').kappa_mean.agg(['mean', 'std', 'count']))

In [None]:
# Box plot of per-measurement mean κ, one box per condition.
fig, ax = plt.subplots(figsize=(10, 6))

condition_order = ['direct', 'clean', 'corrupted']
colors = ['#808080', '#2E8B57', '#DC143C']  # matched to condition_order by position

sns.boxplot(
    data=df_kappa,
    x='condition',
    y='kappa_mean',
    order=condition_order,
    palette=colors,
    ax=ax,
)

ax.set_xlabel('Condition', fontsize=12)
ax.set_ylabel('Mean Curvature (κ)', fontsize=12)
ax.set_title('Hidden State Trajectory Curvature by Condition (Mistral 7B)', fontsize=14)

# Label each box with its condition mean, placed just right of the box centre.
means = df_kappa.groupby('condition')['kappa_mean'].mean()
for pos, cond in enumerate(condition_order):
    if cond not in means:
        continue
    mu = means[cond]
    ax.annotate(f'μ={mu:.3f}', xy=(pos, mu), xytext=(pos + 0.2, mu), fontsize=10, ha='left')

fig.tight_layout()
fig.savefig(f'{SAVE_DIR_KAPPA}/figures/kappa_boxplot_v3.1.png', dpi=150)
plt.show()
print(f'Saved: {SAVE_DIR_KAPPA}/figures/kappa_boxplot_v3.1.png')

In [None]:
# Statistical tests: clean vs corrupted κ.
from scipy import stats

print('\n=== Statistical Tests ===')

clean_kappa = df_kappa[df_kappa['condition'] == 'clean']['kappa_mean'].dropna()
corrupted_kappa = df_kappa[df_kappa['condition'] == 'corrupted']['kappa_mean'].dropna()

# Each problem is measured under BOTH conditions, so the samples are paired by
# problem_index. Align them (one row per problem, complete pairs only) for a
# paired test on per-problem differences.
paired_kappa = (
    df_kappa[df_kappa['condition'].isin(['clean', 'corrupted'])]
    .pivot_table(index='problem_index', columns='condition', values='kappa_mean')
    .reindex(columns=['clean', 'corrupted'])
    .dropna()
)

if len(clean_kappa) > 0 and len(corrupted_kappa) > 0:
    print(f'\nClean vs Corrupted:')
    print(f'  Clean mean: {clean_kappa.mean():.4f}')
    print(f'  Corrupted mean: {corrupted_kappa.mean():.4f}')

    # Unpaired reference: Welch's t-test (does not assume equal variances).
    welch_t, welch_p = stats.ttest_ind(clean_kappa, corrupted_kappa, equal_var=False)
    print(f'  Welch t (unpaired): {welch_t:.4f}, p = {welch_p:.4e}')

    if len(paired_kappa) >= 2:
        # Primary test: paired t-test over problems measured in both conditions.
        kappa_diff = paired_kappa['corrupted'] - paired_kappa['clean']
        t_stat, p_val = stats.ttest_rel(paired_kappa['clean'], paired_kappa['corrupted'])
        corrupted_higher = kappa_diff.mean() > 0
        print(f'  Paired n: {len(paired_kappa)}')
        print(f'  t-statistic (paired): {t_stat:.4f}')
        print(f'  p-value (paired): {p_val:.4e}')
    else:
        # Too few complete pairs for a paired test: fall back to the Welch result.
        kappa_diff = None
        t_stat, p_val = welch_t, welch_p
        corrupted_higher = corrupted_kappa.mean() > clean_kappa.mean()
        print('  (fewer than 2 complete pairs; using the unpaired Welch test)')

    if p_val < 0.05 and corrupted_higher:
        print('  ✓ Significant: Corrupted CoT has HIGHER curvature (as hypothesized)')
    elif p_val < 0.05:
        print('  ! Significant difference found, but direction unexpected')
    else:
        print('  - No significant difference detected')

    # Effect sizes: unpaired Cohen's d (average-variance pooling) and paired d_z.
    pooled_std = np.sqrt((clean_kappa.var() + corrupted_kappa.var()) / 2)
    if pooled_std > 0:
        cohens_d = (corrupted_kappa.mean() - clean_kappa.mean()) / pooled_std
        print(f"  Cohen's d: {cohens_d:.4f}")
    if kappa_diff is not None and kappa_diff.std() > 0:
        cohens_dz = kappa_diff.mean() / kappa_diff.std()
        print(f"  Cohen's d_z (paired): {cohens_dz:.4f}")

## 13. Summary

In [None]:
print('=' * 60, 'KAPPA EXPERIMENT SUMMARY', '=' * 60, sep='\n')
print(
    f'Version: {EXPERIMENT_VERSION}',
    f'Model: {MODEL_NAME}',
    f'Total measurements: {len(all_kappa_results)}',
    sep='\n',
)
print('\nMean κ by condition:')
# Mean ± sample std of kappa_mean per condition, skipping conditions with no rows.
for cond in condition_order:
    cond_data = df_kappa.loc[df_kappa['condition'] == cond, 'kappa_mean']
    if cond_data.empty:
        continue
    print(f'  {cond:12s}: {cond_data.mean():.4f} ± {cond_data.std():.4f}')
print(f'\nFiles saved to: {SAVE_DIR_KAPPA}', '=' * 60, sep='\n')