
# Chain‑of‑Thought Faithfulness Analysis

*Generated 2025-04-27*

This Jupyter notebook implements, end‑to‑end, the pipeline for analysing whether
LLM chains‑of‑thought (CoT) are **faithful** to the model’s actual decision
process.  It uses only the files already present in your directory tree:

```
/root/CoTFaithChecker/data/mmlu/
    input_mcq_data.json
    hints_sycophancy.json
    hints_induced_urgency.json
    hints_unethical_information.json
/root/CoTFaithChecker/g_cot_cluster/outputs/mmlu/
    DeepSeek-R1-Distill-Llama-8B/
        [...]
```

You may freely tweak paths, sample sizes, etc.  Each section is heavily
commented so you can swap in a different model or dataset with minimal effort.



## Notebook Road‑map

| Section | Purpose |
|---------|---------|
| **0. Setup & utility functions** | imports, paths, JSON helpers |
| **1. Load data** | questions, hints, switch metadata, *segmented* CoTs |
| **2. Tokenisation & category masks** | char→token mapping, category / answer / hint masks, `Sample` objects |
| **3. Load the model** | DeepSeek‑R1‑Distill‑Llama‑8B in 8‑bit |
| **4. Forward pass & attention metrics** | grab *attentions* & *hidden states*; answer→category shares, entropy |
| **5. Representational probing** | PCA on category‑pooled hidden states |
| **6. Causal ablation test** | generation helper; zeroing out reasoning vs. hint tokens and the flip‑rate are still TODO |
| **7. Compute metrics** | run section 4 on a random subset |
| **8. Visualisation** | bar plot, PCA scatter |
| **9. Save metrics** | CSV export |
| **10. Next steps & TODOs** | open items |

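> **Tip:** A GPU is highly recommended. Section 3 already loads the model in 8‑bit
> (`load_in_8bit=True`), so the weights alone need roughly 8 GB or more. On top of
> that, `output_attentions=True` adds memory that grows with the square of the
> sequence length, so on smaller cards run ~100 examples per pass and keep
> completions short.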
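The tip above can be checked with rough arithmetic. Holding every layer's attention weights for one sequence costs layers × heads × T² × bytes per element. The sketch below assumes a Llama‑3.1‑8B‑style configuration (32 layers, 32 query heads) and fp16 weights. Read the real values from `model.config` before relying on the numbers. `attention_bytes` is a hypothetical helper.

```python
def attention_bytes(n_layers: int, n_heads: int, seq_len: int, bytes_per_el: int = 2) -> int:
    """Memory for the full attention tensors of one sequence (batch size 1)."""
    return n_layers * n_heads * seq_len * seq_len * bytes_per_el

# ASSUMPTION: 32 layers x 32 heads (Llama-3.1-8B-style config), fp16 = 2 bytes per weight
for T in (512, 1024, 2048):
    gib = attention_bytes(32, 32, T) / 1024**3
    print(f"T={T:5d}: {gib:.2f} GiB of attention weights")
```

Under these assumptions, a 2,048‑token sequence already needs 8 GiB for the attention tensors alone, on top of the model weights.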


In [None]:

# SECTION 0 – Imports & configuration
from pathlib import Path
import json, re, random, logging
from dataclasses import dataclass
from typing import List, Dict, Tuple

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logging.basicConfig(level=logging.INFO, format='[%(asctime)s] %(levelname)s %(message)s')


In [None]:

# Edit these if your clone lives elsewhere
DATA_ROOT   = Path('/root/CoTFaithChecker/data/mmlu')
SEG_ROOT    = Path('/root/CoTFaithChecker/g_cot_cluster/outputs/mmlu/DeepSeek-R1-Distill-Llama-8B')
MODEL_NAME  = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'   # must match the model that produced the completions
DEVICE      = 'cuda' if torch.cuda.is_available() else 'cpu'

print('Using device:', DEVICE)


In [None]:

def load_jsonl(path: Path):
    """Load conventional JSON array or JSON‑Lines file."""
    txt = path.read_text(encoding='utf-8')
    try:
        data = json.loads(txt)
    except json.JSONDecodeError:
        data = [json.loads(line) for line in txt.splitlines() if line.strip()]
    return data


### 1. Load questions, hints & completions

In [None]:

# Questions
questions = {d['question_id']: d for d in load_jsonl(DATA_ROOT / 'input_mcq_data.json')}
print(f'Loaded {len(questions):,} questions')

# Hints by type
HINT_TYPES = ['sycophancy', 'induced_urgency', 'unethical_information']
hints = {
    htype: {d['question_id']: d for d in load_jsonl(DATA_ROOT / f'hints_{htype}.json')}
    for htype in HINT_TYPES
}
print({k: len(v) for k,v in hints.items()})

# Verification & switch metadata (per hint type + none)
switch_meta = {}
for htype in ['none'] + HINT_TYPES:
    path = SEG_ROOT / htype / 'switch_analysis_with_500.json'
    if path.exists():
        switch_meta[htype] = {d['question_id']: d for d in load_jsonl(path)}
print({k: len(v) for k,v in switch_meta.items()})

# Segmented CoTs. Only the `correct_indices` directory is loaded here;
# point SEG_PATH at the incorrect-answer directory as well if you need both.
# Segments are keyed by question_id alone, so a later file overwrites an earlier
# one for the same question.
SEG_PATH = SEG_ROOT / 'correct_indices'
seg_files = sorted(SEG_PATH.glob('*.json'))
segmented = {}
for fp in seg_files:
    for d in load_jsonl(fp):
        segmented[d['question_id']] = d['segments']
print(f'Segmented CoTs loaded for {len(segmented):,} questions')


### 2. Tokenisation and category masks

In [None]:

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

@dataclass
class Sample:
    qid: int
    prompt: str
    completion: str
    segments: List[dict]        # as stored
    hint_type: str              # 'none' | sycophancy | ...
    switched: bool
    is_correct: bool
    hint_option: str
    hint_span: Tuple[int,int]   # char indices within prompt if present

    # will be filled later
    input_ids: torch.Tensor     = None
    category_masks: Dict[str, torch.Tensor] = None
    answer_mask: torch.Tensor   = None
    hint_token_mask: torch.Tensor = None

def build_char_to_tok_map(text: str, enc, input_ids=None):
    """Return a list of length len(text) that maps each character to the index of
    the token covering it (-1 if no token does, e.g. skipped whitespace).

    Offsets come from re-encoding `text` with the same settings as `build_masks`,
    so the token indices line up. `input_ids` is unused and kept for compatibility.
    """
    offsets = enc(text, return_offsets_mapping=True, add_special_tokens=False).offset_mapping
    char2tok = [-1] * len(text)
    for tok_i, (s, e) in enumerate(offsets):
        for c in range(s, e):
            char2tok[c] = tok_i
    return char2tok

def span_to_mask(char2tok, start, end, n_tokens):
    """Boolean token mask for the half-open char span [start, end); out-of-range chars are ignored."""
    mask = torch.zeros(n_tokens, dtype=torch.bool)
    for c in range(max(start, 0), min(end, len(char2tok))):
        if char2tok[c] != -1:
            mask[char2tok[c]] = True
    return mask

ANSWER_RE = re.compile(r"\[\s*([A-D])\s*\]")   # bracketed option letter, e.g. "[B]"

def build_masks(sample: Sample):
    text = sample.prompt + sample.completion
    # NOTE: no special tokens (no BOS) are added; add them if the model expects a BOS token
    ids = tokenizer(text, return_tensors='pt', add_special_tokens=False)
    sample.input_ids = ids.input_ids.to(DEVICE)
    n_tok = ids.input_ids.shape[-1]
    char2tok = build_char_to_tok_map(text, tokenizer)
    comp_offset = len(sample.prompt)   # completion characters start after the prompt

    # category masks: seg['start'] / seg['end'] are treated as *inclusive* char offsets.
    # ASSUMPTION (as before): they index prompt+completion; set seg_offset = comp_offset
    # if your segmenter reports offsets into the completion alone.
    seg_offset = 0
    masks = {}
    for seg in sample.segments:
        m = span_to_mask(char2tok, seg_offset + seg['start'], seg_offset + seg['end'] + 1, n_tok)
        if m.any():
            cat = seg['phrase_category']
            masks[cat] = masks.get(cat, torch.zeros_like(m)) | m
    sample.category_masks = masks

    # answer mask: the *last* bracketed option letter in the completion
    matches = list(ANSWER_RE.finditer(sample.completion))
    if matches:
        m = matches[-1]
        sample.answer_mask = span_to_mask(char2tok, comp_offset + m.start(), comp_offset + m.end(), n_tok)
    else:
        sample.answer_mask = torch.zeros(n_tok, dtype=torch.bool)

    # hint mask: hint_span is a half-open [start, end) char span in the prompt, or (None, None)
    start, end = sample.hint_span
    if isinstance(start, int) and isinstance(end, int) and start >= 0:
        sample.hint_token_mask = span_to_mask(char2tok, start, end, n_tok)
    else:
        sample.hint_token_mask = torch.zeros(n_tok, dtype=torch.bool)
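The masks above all rest on one idea. A fast tokenizer's offset mapping assigns each token a half‑open `[start, end)` character range, so any character span (a CoT segment, the answer bracket, the hint) can be turned into a set of token indices. The sketch below substitutes a hypothetical whitespace tokenizer (`toy_offsets`) for the real one so it runs without downloading a model. It picks the *last* `[A]`‑to‑`[D]` bracket with `re.finditer`, since a CoT may mention several options before committing.

```python
import re
from typing import List, Tuple

def toy_offsets(text: str) -> List[Tuple[int, int]]:
    """Hypothetical stand-in for a fast tokenizer's offset_mapping: one token per word."""
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]

def toy_char_to_token(text: str, offsets: List[Tuple[int, int]]) -> List[int]:
    """Map each character index to the token covering it, or -1 (e.g. whitespace)."""
    char2tok = [-1] * len(text)
    for tok_i, (s, e) in enumerate(offsets):
        for c in range(s, e):
            char2tok[c] = tok_i
    return char2tok

def toy_span_to_mask(char2tok: List[int], start: int, end: int, n_tokens: int) -> List[bool]:
    """Boolean token mask for the half-open character span [start, end)."""
    mask = [False] * n_tokens
    for c in range(max(start, 0), min(end, len(char2tok))):
        if char2tok[c] != -1:
            mask[char2tok[c]] = True
    return mask

text = "First I guessed [A] but on reflection the answer is [C]"
offsets = toy_offsets(text)
char2tok = toy_char_to_token(text, offsets)

# take the last bracketed option letter, not the first one mentioned
last = list(re.finditer(r"\[\s*([A-D])\s*\]", text))[-1]
answer_mask = toy_span_to_mask(char2tok, last.start(), last.end(), len(offsets))
print(last.group(1), [i for i, on in enumerate(answer_mask) if on])
```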


### 2.b – Build `Sample` objects

In [None]:

# Load each condition's completions once, not once per question.
# The file is looked for next to switch_analysis_with_500.json (SEG_ROOT / <type>) first,
# then under SEG_ROOT.parent / <type>; adjust if your layout differs.
completions = {}
for htype in ['none'] + HINT_TYPES:
    for base in (SEG_ROOT, SEG_ROOT.parent):
        comp_path = base / htype / 'completions_with_500.json'
        if comp_path.exists():
            completions[htype] = {d['question_id']: d['completion'] for d in load_jsonl(comp_path)}
            break
print({k: len(v) for k, v in completions.items()})

def extract_prompt(completion: str) -> str:
    """Crude: the text between the first 'user' and the first 'assistant' marker.
    Adjust to your chat template; falls back to everything before 'assistant'."""
    head = completion.split('assistant')[0]
    return head.split('user', 1)[1] if 'user' in head else head

SAMPLES = []
for qid in questions:
    for htype, comp_map in completions.items():
        if qid not in comp_map:
            continue
        completion = comp_map[qid]
        prompt = extract_prompt(completion)

        # hint span: half-open [start, end) char span in the prompt, or (None, None)
        hint_data = hints.get(htype, {}).get(qid)   # 'none' has no hint file
        hint_span, hint_option = (None, None), None
        if hint_data:
            hint_option = hint_data['hint_option']
            span_start = prompt.find(hint_data['hint_text'])
            if span_start != -1:                    # hint not reproduced verbatim -> no span
                hint_span = (span_start, span_start + len(hint_data['hint_text']))

        sw = switch_meta.get(htype, {}).get(qid, {})
        sample = Sample(
            qid=qid,
            prompt=prompt,
            completion=completion,
            segments=segmented.get(qid, []),        # NOTE: keyed by question only, shared across hint types
            hint_type=htype,
            switched=sw.get('switched', False),
            # correctness rule kept as written; check the field names in your switch files
            is_correct=(not sw.get('switched', False)) if htype == 'none' else sw.get('is_correct_option', False),
            hint_option=hint_option,
            hint_span=hint_span,
        )
        build_masks(sample)
        SAMPLES.append(sample)
print(f'Total Sample objects: {len(SAMPLES):,}')
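Locating the hint is the fragile part of building samples. `str.find` returns `-1` when the hint text is not reproduced verbatim, and that sentinel must not leak into a span. Below is a minimal helper for that step. It assumes spans are half‑open `[start, end)` character offsets into the prompt and that `(None, None)` means "no hint". The name `locate_hint` and the example prompt are hypothetical.

```python
from typing import Optional, Tuple

Span = Tuple[Optional[int], Optional[int]]

def locate_hint(prompt: str, hint_text: Optional[str]) -> Span:
    """Return the [start, end) character span of hint_text in prompt, or (None, None)."""
    if not hint_text:
        return (None, None)
    start = prompt.find(hint_text)
    if start == -1:   # not found verbatim, e.g. whitespace altered by the chat template
        return (None, None)
    return (start, start + len(hint_text))

prompt = "Which planet is largest?\nA professor thinks the answer is B.\n(A) Mars (B) Jupiter"
hint = "A professor thinks the answer is B."
s, e = locate_hint(prompt, hint)
print(s, e, repr(prompt[s:e]))
```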


### 3. Load the Llama‑8B model (8‑bit)

In [None]:

bnb = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, device_map='auto',
                                             torch_dtype=torch.float16,
                                             quantization_config=bnb,
                                             # eager attention is needed to return attention weights
                                             attn_implementation='eager')
model.eval()


### 4. Forward pass helper

In [None]:

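@torch.no_grad()
def forward_with_cache(sample: Sample, layers=None):
    """Run one sample; return (rolled_attention, hidden_states).

    rolled_attention: (T, T) attention rollout (Abnar & Zuidema, 2020). Heads are
    averaged, 0.5*I is mixed in for the residual stream, rows are renormalised, and
    layers are composed bottom-up: R = A_L @ ... @ A_1.
    hidden_states: (L+1, T, d), the embeddings plus every layer's output.
    layers: optional list or slice of layer indices to include in the rollout.
    """
    outs = model(sample.input_ids,
                 output_attentions=True,
                 output_hidden_states=True,
                 use_cache=False)
    dev = sample.input_ids.device        # with device_map='auto' layers may sit on several GPUs
    # each outs.attentions[l] is (batch=1, H, T, T): drop batch, average heads -> (L, T, T)
    attns = torch.stack([a[0].float().mean(0).to(dev) for a in outs.attentions])
    if layers is not None:
        attns = attns[layers]
    eye = torch.eye(attns.size(-1), device=dev)
    rolled = eye
    for a in attns:                      # lowest layer first
        a = 0.5 * a + 0.5 * eye
        a = a / a.sum(-1, keepdim=True)
        rolled = a @ rolled
    hidden = torch.stack([h[0].to(dev) for h in outs.hidden_states])   # (L+1, T, d)
    return rolled, hidden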
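Attention rollout (Abnar & Zuidema, 2020) treats each layer's head‑averaged attention, mixed with the identity to account for the residual stream, as a transition matrix. It composes the layers from the bottom up, R = Â_L ⋯ Â_1 with Â_l = normalise(0.5·A_l + 0.5·I). The pure‑Python sketch below runs that recipe on two hypothetical 3‑token causal attention matrices so the ordering and renormalisation can be checked by hand. It is an illustration, not the GPU code path used in the notebook.

```python
from typing import List

Matrix = List[List[float]]

def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def with_residual(a: Matrix) -> Matrix:
    """0.5*A + 0.5*I, then renormalise each row to sum to 1."""
    n = len(a)
    mixed = [[0.5 * a[i][j] + (0.5 if i == j else 0.0) for j in range(n)] for i in range(n)]
    return [[x / sum(row) for x in row] for row in mixed]

def rollout(layers: List[Matrix]) -> Matrix:
    """Compose head-averaged attention maps bottom-up: R = A_L ... A_1."""
    n = len(layers[0])
    rolled = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for a in layers:                 # layers[0] is the lowest layer
        rolled = matmul(with_residual(a), rolled)
    return rolled

# hypothetical causal attention maps (rows = queries, each row sums to 1)
layer1 = [[1.0, 0.0, 0.0],
          [0.5, 0.5, 0.0],
          [0.2, 0.3, 0.5]]
layer2 = [[1.0, 0.0, 0.0],
          [0.0, 1.0, 0.0],
          [0.6, 0.2, 0.2]]
R = rollout([layer1, layer2])
print([round(x, 3) for x in R[2]])   # how much the last token draws on each input token
```

Putting the later layer on the left matters. Row i of R then reads "how much token i at the top depends on each input token".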


### 4.b – Attention metrics

In [None]:

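def answer_category_shares(rolled_attn: torch.Tensor, sample: Sample) -> Dict[str, float]:
    """Fraction of the answer tokens' attention that lands on each category (and the hint).

    Accepts (T, T), (H, T, T) or (1, H, T, T). Leading dims are averaged, which
    leaves the ratios unchanged.
    """
    attn = rolled_attn.float()
    while attn.dim() > 2:
        attn = attn.mean(0)
    ans_idx = sample.answer_mask.nonzero(as_tuple=False).squeeze(1).to(attn.device)
    if ans_idx.numel() == 0:
        return {}
    rows = attn[ans_idx]                 # (n_answer_tokens, T)
    total = rows.sum().item()
    named = dict(sample.category_masks)
    if sample.hint_token_mask.any():
        named['HINT'] = sample.hint_token_mask
    return {name: (rows[:, m.to(attn.device)].sum().item() / total if total > 0 else 0.0)
            for name, m in named.items()}

def attention_entropy(shares: Dict[str, float]) -> float:
    """Entropy (bits) over categories, after renormalising the shares to sum to 1.
    They need not sum to 1 on their own, since some tokens belong to no category."""
    p = np.array(list(shares.values()), dtype=float)
    p = p[p > 0]
    if not p.size:
        return np.nan
    p = p / p.sum()
    return float(-np.sum(p * np.log2(p)))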
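The two metrics can be made concrete on a toy example. For each answer token, take its (rolled) attention row and sum the weight falling on each category's tokens. Dividing by the total weight of all answer rows gives the shares, and the entropy is then taken over the category shares. The sketch below works in plain Python on a hypothetical 6‑token sequence. It renormalises the shares before taking the entropy, because shares need not sum to 1 when some tokens (e.g. the rest of the prompt) belong to no category; that renormalisation is an assumption of this sketch.

```python
import math
from typing import Dict, List

def category_shares(answer_rows: List[List[float]], masks: Dict[str, List[bool]]) -> Dict[str, float]:
    """Share of the answer tokens' attention landing on each category's tokens."""
    total = sum(sum(row) for row in answer_rows)
    if total <= 0:
        return {cat: 0.0 for cat in masks}
    return {cat: sum(w for row in answer_rows for w, on in zip(row, mask) if on) / total
            for cat, mask in masks.items()}

def share_entropy(shares: Dict[str, float]) -> float:
    """Entropy in bits of the shares after renormalising them to a distribution."""
    p = [v for v in shares.values() if v > 0]
    z = sum(p)
    return -sum((v / z) * math.log2(v / z) for v in p) if p else float("nan")

# hypothetical: 6 tokens, attention rows of two answer tokens (each row sums to 1)
answer_rows = [[0.1, 0.1, 0.4, 0.2, 0.2, 0.0],
               [0.2, 0.0, 0.2, 0.2, 0.2, 0.2]]
masks = {
    "HINT":              [True,  True,  False, False, False, False],
    "logical_deduction": [False, False, True,  True,  False, False],
}
shares = category_shares(answer_rows, masks)
print(shares, share_entropy(shares))
```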


### 5. Representational probing (PCA)

In [None]:

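def pooled_hidden(hidden: torch.Tensor, sample: Sample):
    """Mean-pool the final layer's hidden states over each category's tokens (and the hint)."""
    h = hidden[-1]                       # (T, d); tolerate a kept batch dim (1, T, d)
    if h.dim() == 3:
        h = h[0]
    h = h.float()                        # fp16 -> fp32 before numpy / PCA
    pools = {}
    for cat, mask in sample.category_masks.items():
        if mask.any():
            pools[cat] = h[mask.to(h.device)].mean(0).cpu().numpy()
    if sample.hint_token_mask.any():
        pools['HINT'] = h[sample.hint_token_mask.to(h.device)].mean(0).cpu().numpy()
    return pools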


### 6. Causal ablation test

In [None]:

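def regenerate_answer(sample: Sample, edited_hidden=None, answer_len=3):
    """Greedily generate `answer_len` new tokens after `sample.input_ids` and decode only those.

    `edited_hidden` is reserved for the hidden-state ablation, which is not implemented yet.
    NOTE: `input_ids` already contains the full completion (answer included), so a
    flip-rate test should first truncate the input before the answer span.
    """
    if edited_hidden is not None:
        raise NotImplementedError('Hidden‑state surgery demo left as TODO')
    input_ids = sample.input_ids.to(DEVICE)
    out = model.generate(input_ids,
                         attention_mask=torch.ones_like(input_ids),
                         max_new_tokens=answer_len,
                         do_sample=False)
    # keep only the newly generated tokens (generation may stop early at EOS)
    return tokenizer.decode(out[0, input_ids.shape[-1]:], skip_special_tokens=True)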
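Once `regenerate_answer` can run on an ablated input, the flip‑rate named in the road‑map is the fraction of samples whose extracted answer letter changes. The sketch below assumes answers are compared as option letters pulled from the generated text with the same bracket pattern used for the answer mask. `extract_option` and `flip_rate` are hypothetical helpers, and the example strings are made up.

```python
import re
from typing import List, Optional

OPTION_RE = re.compile(r"\[\s*([A-D])\s*\]")

def extract_option(text: str) -> Optional[str]:
    """Last bracketed option letter in text, or None."""
    matches = OPTION_RE.findall(text)
    return matches[-1] if matches else None

def flip_rate(baseline: List[str], ablated: List[str]) -> float:
    """Fraction of pairs whose option changes; pairs without a parseable option are skipped."""
    pairs = [(extract_option(b), extract_option(a)) for b, a in zip(baseline, ablated)]
    pairs = [(b, a) for b, a in pairs if b is not None and a is not None]
    return sum(b != a for b, a in pairs) / len(pairs) if pairs else float("nan")

baseline = ["... so the answer is [B]", "[A]", "[C]", "no answer"]
ablated  = ["... so the answer is [D]", "[A]", "[C]", "[B]"]
print(flip_rate(baseline, ablated))
```

Skipping unparseable pairs keeps the rate from being inflated by truncated generations. Report how many pairs were dropped alongside the rate.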


### 7. Compute metrics for a subset

In [None]:

records = []
SAMPLE_LIMIT = 200   # change as needed
subset = random.sample(SAMPLES, min(SAMPLE_LIMIT, len(SAMPLES)))
for i, sample in enumerate(subset):
    rolled, hidden = forward_with_cache(sample, layers=None)
    shares = answer_category_shares(rolled, sample)
    record = dict(
        qid       = sample.qid,
        hint_type = sample.hint_type,
        switched  = sample.switched,
        correct   = sample.is_correct,
        entropy   = attention_entropy(shares),
        shares    = shares,
    )
    # flatten shares into one column per category (share_<category>)
    for k, v in shares.items():
        record[f'share_{k}'] = v
    records.append(record)
    del rolled, hidden                   # release GPU memory before the next sample
    if (i + 1) % 20 == 0:
        print(f'Processed {i+1} / {len(subset)}')
metrics_df = pd.DataFrame(records)


### 8. Visualisation examples

In [None]:

import seaborn as sns
sns.set()

# Barplot of mean answer→category share by correctness
melt = metrics_df.melt(id_vars=['correct','hint_type'],
                       value_vars=[c for c in metrics_df.columns if c.startswith('share_')],
                       var_name='category', value_name='share')
melt['category']=melt['category'].str.replace('share_','')
plt.figure(figsize=(10,4))
sns.barplot(data=melt, x='category', y='share', hue='correct', errorbar='sd')   # seaborn >= 0.12; older versions take ci='sd'
plt.title('Answer‑to‑Category Attention Share')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


In [None]:

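# Build pooled vectors (only hidden states are needed, so attentions are not requested)
vecs, labels = [], []
for s in random.sample(SAMPLES, min(500, len(SAMPLES))):
    with torch.no_grad():
        out = model(s.input_ids, output_hidden_states=True, use_cache=False)
    hidden = out.hidden_states[-1][0].unsqueeze(0)   # final layer as (1, T, d); hidden[-1] is (T, d)
    pools = pooled_hidden(hidden, s)
    if 'logical_deduction' in pools:
        vecs.append(pools['logical_deduction'])
        labels.append('correct' if s.is_correct else 'wrong')
if len(vecs) < 2:
    raise ValueError("PCA needs at least two samples with a 'logical_deduction' segment")
X = np.stack(vecs).astype(np.float32)
coords = PCA(n_components=2).fit_transform(X)
plt.figure(figsize=(6,6))
for lab in sorted(set(labels)):
    idx = [i for i, l in enumerate(labels) if l == lab]
    plt.scatter(coords[idx, 0], coords[idx, 1], label=lab, alpha=0.7, s=20)
plt.legend(); plt.title('PCA of logical_deduction pooled hidden states')
plt.show()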


### 9. Save metrics

In [None]:

OUT_CSV = Path('cot_attention_metrics.csv')
metrics_df.to_csv(OUT_CSV, index=False)
print('Saved', OUT_CSV.resolve())



## 10. Next steps & TODOs

* Increase `SAMPLE_LIMIT` or implement batch processing to cover the full set.
* Complete the *causal ablation* (`regenerate_answer`) using hidden‑state
  zero‑out or prompt masking.
* Add entropy/hint‑Δ plots and statistical tests (e.g. paired t‑test).
* Swap PCA for UMAP if non‑linear structure is suspected.
* Profile GPU memory – `output_attentions=True` grows quadratically with
  sequence length, so consider capturing only selected layers' attentions
  (e.g. via forward hooks) or truncating long completions.
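A paired t‑test fits the statistical test in the TODOs because it compares the same questions under two conditions, e.g. the answer→reasoning share without vs. with a hint. It computes t = mean(d) / (s_d / √n) on the per‑question differences d, with n − 1 degrees of freedom. The sketch below computes only the statistic with the standard library. In practice `scipy.stats.ttest_rel` also returns the p‑value. The numbers are made up for illustration.

```python
import math
import statistics
from typing import Sequence, Tuple

def paired_t(x: Sequence[float], y: Sequence[float]) -> Tuple[float, int]:
    """Paired t statistic and degrees of freedom for matched samples x, y."""
    if len(x) != len(y) or len(x) < 2:
        raise ValueError("need two equal-length samples with at least 2 pairs")
    d = [a - b for a, b in zip(x, y)]
    sd = statistics.stdev(d)          # sample standard deviation (n - 1 denominator)
    if sd == 0:
        raise ValueError("all differences are identical; t is undefined")
    return statistics.mean(d) / (sd / math.sqrt(len(d))), len(d) - 1

# hypothetical per-question shares for the same questions without / with a hint
no_hint   = [0.42, 0.38, 0.51, 0.47, 0.40]
with_hint = [0.35, 0.36, 0.44, 0.41, 0.39]
t, dof = paired_t(no_hint, with_hint)
print(f"t = {t:.2f}, dof = {dof}")
```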
