In [None]:
from dotenv import load_dotenv

# Read key/value pairs from a local .env file into the process environment.
# NOTE(review): presumably this supplies credentials such as a Hugging Face
# token for the model/dataset downloads below — confirm what the .env must hold.
load_dotenv()

In [None]:
import torch

# Turn autograd off globally: the visible cells only run inference
# (generation and logit read-outs), so no gradients need to be tracked.
torch.set_grad_enabled(False)

In [None]:
from transformer_lens import HookedTransformer

# Load Qwen3-1.7B as a TransformerLens HookedTransformer on the GPU; every
# later cell uses this module-level `model` (and its `model.tokenizer`).
# NOTE(review): trust_remote_code=True presumably permits executing code shipped
# with the checkpoint — confirm it is actually required for this model.
model = HookedTransformer.from_pretrained("Qwen/Qwen3-1.7B", device="cuda", trust_remote_code=True)

In [None]:
import re
from typing import Literal
from nltk.tokenize import sent_tokenize
from openai.types.chat import ChatCompletionAssistantMessageParam, ChatCompletionSystemMessageParam, ChatCompletionUserMessageParam

Message = ChatCompletionAssistantMessageParam | ChatCompletionSystemMessageParam | ChatCompletionUserMessageParam

def get_think_sentences(assistant_content: str) -> list[str]:
    """
    Split the reasoning inside the first <think>...</think> pair into sentences.

    Returns an empty list when the content holds no complete
    <think>...</think> pair. Sentence boundaries come from nltk.sent_tokenize;
    fragments that are blank after stripping are dropped.
    """
    think_block = re.search(r"<think>(.*?)</think>", assistant_content, re.DOTALL)
    if think_block is None:
        return []
    reasoning = think_block.group(1).strip()
    return [sentence for sentence in sent_tokenize(reasoning) if sentence.strip()]

def truncate_last_assistant(messages: list[Message], sentences: list[str], sentence_idx: int):
    """
    Cut the last assistant message off right before think sentence `sentence_idx`.

    The result is a shallow copy of `messages` whose last entry is a new
    assistant message holding everything that precedes the chosen sentence
    (including the opening <think> tag). `messages` itself is not modified.

    The cut position is found by walking `sentences` in order through the
    content. A sentence whose text also occurs earlier (a repeated "Wait.",
    or a short sentence embedded in a longer earlier one) is therefore cut at
    its own occurrence rather than at the first match of its text.

    Args:
        messages: conversation whose last entry is the assistant message to truncate.
        sentences: the think sentences of that message, in order of appearance.
        sentence_idx: index into `sentences` of the first sentence to drop.

    Raises:
        ValueError: if the last message is not an assistant message, or its
            content lacks a <think> or </think> tag.
        IndexError: if `sentence_idx` is out of range for `sentences`.
    """
    if messages[-1]["role"] != 'assistant':
        raise ValueError("Last Message is not assistant")

    content = messages[-1]["content"]
    think_start = content.find('<think>')
    think_end = content.find('</think>')
    if think_start == -1 or think_end == -1:
        raise ValueError("No <think> tags found")

    target = sentences[sentence_idx]  # IndexError here when the index is out of range
    if sentence_idx < 0:
        # Normalise negative indices so the slice below covers the right sentences.
        sentence_idx += len(sentences)

    # Move a cursor past each earlier sentence so the target is only searched
    # for after them; earlier sentences that cannot be located are skipped.
    cursor = think_start + len('<think>')
    for earlier in sentences[:sentence_idx]:
        found = content.find(earlier, cursor)
        if found != -1:
            cursor = found + len(earlier)

    cut = content.find(target, cursor)
    if cut == -1:
        # Fall back to the first occurrence anywhere in the content.
        cut = content.find(target)
    # If the sentence is not present at all, the content is kept whole.
    new_content = content if cut == -1 else content[:cut]

    new_messages = messages.copy()
    new_messages[-1] = ChatCompletionAssistantMessageParam(role='assistant', content=new_content)
    return new_messages

def add_probe_phrase(messages: list[Message], mode: Literal["MCQ"]):
    """
    Close the reasoning block of the last assistant message and append the probe.

    "</think>" is always appended to the last message's content. In "MCQ" mode
    a JSON answer opener follows, ending right after the opening quote of the
    answer value. Any other mode appends nothing further.

    Returns a new list; `messages` is left untouched.

    Raises:
        ValueError: if the last message is not an assistant message.
    """
    last = messages[-1]
    if last["role"] != 'assistant':
        raise ValueError("Last Message is not assistant")

    probe = "\n```json{\"answer\": \"" if mode == "MCQ" else ""
    probed = ChatCompletionAssistantMessageParam(
        role='assistant',
        content=last["content"] + "</think>" + probe,
    )
    return messages[:-1] + [probed]

def get_new_messages(history: list[Message], message: Message, sentence_index: int, mode="MCQ"):
    """
    Build the probed conversation for a single think-sentence cut point.

    `message` is appended to `history`, its reasoning is truncated right
    before think sentence `sentence_index`, and the probe phrase for `mode`
    is appended to the truncated assistant message.
    """
    sentences = get_think_sentences(message["content"])
    conversation = history + [message]
    truncated = truncate_last_assistant(conversation, sentences, sentence_index)
    return add_probe_phrase(truncated, mode=mode)

In [None]:
import datasets

# Validation split of CommonsenseQA; rows are read by index in later cells.
dataset = datasets.load_dataset("tau/commonsense_qa")["validation"]
# Other multiple-choice datasets could be swapped in here, e.g. ARC.
# NOTE(review): the Hub id is likely "allenai/ai2_arc" — confirm, and check that
# it exposes the same question / choices.label / choices.text fields.

In [None]:
def get_messages(question: str, labels: list[str], texts: list[str]) -> list[Message]:
    """
    Build the system + user prompt for one multiple-choice question.

    Each option is rendered on its own line as "<label>) <text>". The user
    message asks for the final answer inside a JSON document.
    """
    choices = "\n".join(f"{label}) {text}" for label, text in zip(labels, texts))
    system_prompt = "You are a helpful assistant. For each multiple-choice question, analyze carefully and answer with the correct option only."
    user_prompt = f"Answer the following mcq question {question}\n{choices}\nEnsure your final answer is within a json document. Example:\n```json{{answer: \"A\"}}```"
    return [
        ChatCompletionSystemMessageParam(role="system", content=system_prompt),
        ChatCompletionUserMessageParam(role="user", content=user_prompt),
    ]


In [None]:
def get_history_completion(row_index: int):
    """
    Sample a completion from the model for one dataset row.

    Returns:
        (history, assistant_message): the system + user prompt built from the
        row, and the generated continuation decoded into an assistant message
        (prompt tokens removed, special tokens skipped).
    """
    row = dataset[row_index]
    history = get_messages(row["question"], row["choices"]["label"], row["choices"]["text"])
    prompt_tokens = model.tokenizer.apply_chat_template(history, add_generation_prompt=True, return_tensors="pt").to("cuda")
    generated = model.generate(prompt_tokens, max_new_tokens=4000, temperature=0.6, return_type="tokens")
    # Drop the leading prompt tokens so only the newly generated part is decoded.
    completion_tokens = generated[:, prompt_tokens.shape[1]:]
    completion_text = model.tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
    return history, ChatCompletionAssistantMessageParam(role="assistant", content=completion_text)


In [None]:
def get_all_think_prefixes(history: list[Message], message: Message, mode="MCQ") -> list[list[Message]]:
    """
    Build one probed conversation per reasoning prefix of `message`.

    Entry i keeps the reasoning that precedes think sentence i (entry 0 has
    no sentences yet). A final entry keeps everything before "</think>".
    Every entry ends with the probe phrase for `mode`.
    """
    think_sentences = get_think_sentences(message["content"])
    conversation = history + [message]
    prefixes: list[list[Message]] = [
        add_probe_phrase(truncate_last_assistant(conversation, think_sentences, idx), mode=mode)
        for idx in range(len(think_sentences))
    ]
    # Last step: the complete reasoning, i.e. all text up to the closing tag.
    full_reasoning = message["content"].partition("</think>")[0]
    final_message = ChatCompletionAssistantMessageParam(role='assistant', content=full_reasoning)
    prefixes.append(add_probe_phrase(history + [final_message], mode=mode))
    return prefixes

In [None]:
import json_repair
def extract_answer(content: str):
    """
    Pull the final answer out of an assistant message.

    Only the text after the last "</think>" is considered. The first
    ``` fenced block is parsed with json_repair (an optional leading "json"
    language tag is removed). When there is no fence, the whole
    post-reasoning text is parsed instead.

    Returns:
        The value stored under "answer", or None when nothing parseable is
        found (for example a generation that was cut off before answering),
        so that one bad completion does not abort a whole batch run.
    """
    non_thinking = content.split("</think>")[-1].strip()
    _, fence, after_fence = non_thinking.partition("```")
    if fence:
        candidate = after_fence.strip().partition("```")[0].strip()
        candidate = candidate.removeprefix("json").strip()
    else:
        # No code fence: the model may have emitted the JSON bare.
        candidate = non_thinking
    if not candidate:
        return None
    parsed = json_repair.loads(candidate)
    # json_repair may hand back a non-dict (string, list, ...) for text that is not a JSON object.
    if isinstance(parsed, dict):
        return parsed.get("answer")
    return None

In [None]:
from tqdm import tqdm

def get_reasoning_prefix_top_logprobs(reasoning_prefix: list[Message], choices: list[str]):
    """
    Read the model's last-position logit for each answer choice.

    Despite the name, the values returned are raw logits, not
    log-probabilities.

    Args:
        reasoning_prefix: conversation ending in the probed assistant message.
        choices: choice strings; each is mapped with model.to_single_token.

    Returns:
        A list of floats, one logit per entry of `choices`, in the same order.
    """
    choice_tokens = [model.to_single_token(choice) for choice in choices]
    # NOTE(review): with continue_final_message=False the chat template presumably
    # closes the assistant turn (end-of-turn tokens after the probe phrase), so
    # position -1 may not be the position right after the probe — confirm by
    # decoding the last few entries of `tokens`.
    tokens = model.tokenizer.apply_chat_template(reasoning_prefix, add_generation_prompt=False, continue_final_message=False, return_tensors="pt").to("cuda")
    # Slice the last position before moving to the CPU: only one row of the
    # (batch, seq, vocab) logits is needed, so the rest is never copied.
    last_position_logits = model(tokens)[0, -1].to("cpu")
    return [last_position_logits[t].item() for t in choice_tokens]

def get_logit_evolution(history: list[Message], message: Message, choices: list[str], mode="MCQ"):
    """
    Compute the per-choice logits for every reasoning prefix of `message`.

    Returns a list with one entry per prefix (in prefix order), each entry
    being the list of logits for `choices`. Progress is shown with tqdm.
    """
    evolution = []
    for prefix in tqdm(get_all_think_prefixes(history, message, mode=mode)):
        evolution.append(get_reasoning_prefix_top_logprobs(prefix, choices))
    return evolution


In [None]:
import plotly.graph_objects as go
from ipywidgets import Output
from IPython.display import display, Markdown

import os
import json

def plot_logit_evolution(logits_list, message, choices_labels, choices, question, filename=None):
    """
    Plot how each choice's logit evolves across the reasoning prefixes.

    Args:
        logits_list: one entry per step, each a list with one logit per choice
            (steps x choices).
        message: the assistant message dict (with 'content'); its extracted
            answer picks the highlighted trace and its think sentences label
            the steps.
        choices_labels: one label per choice (e.g. ['A', 'B', 'C', 'D']).
        choices: the choice texts, aligned with `choices_labels`.
        question: the question text.
        filename: base filename (no extension). When given, the plot is saved
            as HTML and PNG in 'plots/'.

    The figure and an ipywidgets Output are displayed; nothing is returned.
    """
    import html  # local import: only needed for escaping text injected into the saved page

    # Get answer
    answer = extract_answer(message["content"])

    # Get think sentences for hover text and display
    think_sentences = get_think_sentences(message["content"])

    # Offset sentences: first point is '', then each sentence
    hover_sentences = [''] + think_sentences

    # Prepare data for plotly
    x = list(range(len(logits_list)))
    fig = go.Figure()

    # More colours than the usual four options, so a fifth (or later) choice does
    # not wrap around and share a colour with an earlier trace. Black is left
    # out because it is reserved for the answer.
    color_map = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'deeppink', 'teal']
    answer_color = 'black'

    # Same hover text for every trace; blank it when the lengths do not line up.
    hover = hover_sentences if len(hover_sentences) == len(logits_list) else [''] * len(logits_list)

    for i, (label, choice) in enumerate(zip(choices_labels, choices)):
        y = [logits[i] for logits in logits_list]
        is_answer = (label == answer)
        fig.add_trace(go.Scatter(
            x=x, y=y,
            mode='lines+markers',
            name=f"Choice {label}: {choice}",
            line=dict(color=answer_color if is_answer else color_map[i % len(color_map)], width=4 if is_answer else 2),
            marker=dict(size=8),
            hovertext=hover,
            hoverinfo='y'
        ))

    html_safe_question = question.replace("'", "&#39;").replace('"', '&quot;').replace("\n", "<br>")
    fig.update_layout(
        title="Logit Evolution",
        xaxis_title="Step (Think Sentence)",
        yaxis_title="Logit Value",
        legend_title="Choices",
        hovermode="x unified",
        margin=dict(t=80, b=80),
        # Add the question as a description below the title (for notebook/plotly view)
        annotations=[{
            'text': f'<b>Question:</b> {html_safe_question}',
            'align': 'left',
            'showarrow': False,
            'xref': 'paper',
            'yref': 'paper',
            'x': 0,
            'y': 1.13,
            'bordercolor': 'rgba(0,0,0,0)',
            'font': {'size': 14},
            'bgcolor': 'rgba(255,255,255,0.95)',
        }]
    )

    # Save plot if filename is not None
    if filename is not None:
        os.makedirs('plots', exist_ok=True)
        html_path = os.path.join('plots', filename + '.html')
        png_path = os.path.join('plots', filename + '.png')
        # Both strings below are assigned to innerHTML in the saved page, so the
        # model/dataset text is HTML-escaped first: characters like '<' or '&'
        # must show up literally instead of being parsed as markup.
        js_sentences = json.dumps([html.escape(sentence) for sentence in hover_sentences])
        block_question = html.escape(question).replace("\n", "<br>")
        js_question = json.dumps(f'<div id="question-block" style="margin-bottom:18px;padding:10px 12px;background:#f5f5f5;border-left:4px solid #0074D9;font-size:15px;"><b>Question:</b> {block_question}</div>')
        # JS: inject question div above plot, and sentence div below plot
        post_script = f'''
        var plot = document.getElementsByClassName('plotly-graph-div')[0];
        if (plot) {{
            var qdiv = document.createElement('div');
            qdiv.innerHTML = {js_question};
            plot.parentNode.insertBefore(qdiv, plot);
            var out = document.createElement('div');
            out.id = 'sentence-output';
            out.style.margin = '20px 0';
            out.style.padding = '10px';
            out.style.border = '1px solid #ccc';
            out.style.minHeight = '2em';
            out.style.fontFamily = 'monospace';
            out.style.background = '#f9f9f9';
            out.innerHTML = '(No think sentence for this step)';
            plot.parentNode.insertBefore(out, plot.nextSibling);
            var thinkSentences = {js_sentences};
            plot.on('plotly_hover', function(data) {{
                var idx = data.points[0].x;
                if(idx == 0) {{ out.innerHTML = '(No think sentence for this step)'; }}
                else if(idx < thinkSentences.length) {{ out.innerHTML = '<b>Step ' + idx + ':</b> ' + thinkSentences[idx]; }}
            }});
            plot.on('plotly_unhover', function(data) {{
                out.innerHTML = '(No think sentence for this step)';
            }});
        }}
        '''
        fig.write_html(html_path, include_plotlyjs='cdn', full_html=True, post_script=post_script)
        fig.write_image(png_path)

    # Output widget for displaying the sentence (for notebook interactivity)
    out = Output()
    display(fig)
    display(out)

    def update_sentence(trace, points, selector):
        """Show the think sentence for the hovered step in the Output widget."""
        if points.point_inds:
            idx = points.point_inds[0]
            with out:
                out.clear_output()
                if idx == 0:
                    display(Markdown("*(No think sentence for this step)*"))
                elif idx < len(hover_sentences):
                    display(Markdown(f"**Step {idx}:** {hover_sentences[idx]}"))

    # Attach the callback to all traces
    # NOTE(review): plotly documents trace.on_hover callbacks as firing only for
    # traces of a go.FigureWidget shown in an ipywidgets context; `fig` is a
    # plain go.Figure, so this callback likely never runs — confirm, and
    # consider displaying go.FigureWidget(fig) instead.
    for trace in fig.data:
        trace.on_hover(update_sentence)

In [None]:
# Number of dataset rows to process, and how many generation attempts each row gets.
NUM_QUESTIONS = 40
MAX_ATTEMPTS = 4

for i in range(NUM_QUESTIONS):
    # Retry generation a few times; a row is skipped once every attempt has failed.
    completion = None
    for _ in range(MAX_ATTEMPTS):
        try:
            completion = get_history_completion(i)
            break
        except Exception as e:
            print(f"Error at index {i}: {e}")
    if completion is None:
        continue
    history, message = completion

    row = dataset[i]
    choices_labels = row["choices"]["label"]
    choices = row["choices"]["text"]
    logits = get_logit_evolution(history, message, choices=choices_labels, mode="MCQ")
    plot_logit_evolution(logits, message, choices_labels, choices, row["question"], f"commonsenseqa_{i}")