In [1]:
import os
import re
import sys
import pandas as pd
import plotly.express as px
from types import SimpleNamespace

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

sys.path.append('sparsify')

from sparsify import Sae

Triton not installed, using eager implementation of SAE decoder.


In [2]:
# This notebook only runs inference and SAE encoding, so autograd is switched off globally.
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x27150825e20>

In [3]:
# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [4]:
# Tensors created by torch factory functions from here on default to DEVICE;
# the trailing expression echoes the active default device.
torch.set_default_device(device=DEVICE)
torch.get_default_device()

device(type='cuda', index=0)

### Usage Recommendations
Source: [DeepSeek-R1-Distill-Qwen-1.5B model card — Usage Recommendations](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B#usage-recommendations)


1. Set the temperature within the range of 0.5-0.7 (0.6 is recommended) to prevent endless repetitions or incoherent outputs.

2. Avoid adding a system prompt; all instructions should be contained within the user prompt.

3. For mathematical problems, it is advisable to include a directive in your prompt such as: "Please reason step by step, and put your final answer within \boxed{}."

4. When evaluating model performance, it is recommended to conduct multiple tests and average the results.

Additionally, we have observed that the DeepSeek-R1 series models tend to bypass the thinking pattern (i.e., outputting "<think>\n\n</think>") when responding to certain queries, which can adversely affect the model's performance. To ensure that the model engages in thorough reasoning, we recommend forcing the model to begin every response with "<think>\n".

In [5]:
# The tokenizer and the causal-LM weights come from the same Hugging Face checkpoint.
MODEL_ID = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

In [6]:
def inference(chat_prompt, max_new_tokens=400, temperature=0.6, add_generation_prompt=True, output_hidden_states=True):
    """Generate a completion for a single chat prompt with the global ``model``.

    Args:
        chat_prompt: A conversation such as ``[{'role': 'user', 'content': ...}]``,
            passed to ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (the model card recommends 0.5-0.7).
        add_generation_prompt: Forwarded to ``apply_chat_template``.
        output_hidden_states: Whether ``generate`` returns per-step hidden states.

    Returns:
        The ``generate`` output (``return_dict_in_generate=True``), exposing
        ``.sequences`` and, when requested, ``.hidden_states``.
    """
    tokenized_inputs = tokenizer.apply_chat_template(
        chat_prompt,
        tokenize=True,
        add_generation_prompt=add_generation_prompt,
        return_tensors='pt',
        return_dict=True
    )
    # Place the inputs on the model's device explicitly rather than relying on
    # torch.set_default_device having been called beforehand.
    input_ids = tokenized_inputs.input_ids.to(model.device)
    attention_mask = tokenized_inputs.attention_mask.to(model.device)

    model_out = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        # Sample explicitly so `temperature` takes effect independently of the
        # checkpoint's generation_config defaults.
        do_sample=True,
        temperature=temperature,
        output_hidden_states=output_hidden_states,
        return_dict_in_generate=True
    )
    return model_out

In [7]:
def decode_inference(model_out_sequences, skip_special_tokens=False):
    """Turn token-id sequences back into text with the global ``tokenizer``.

    Returns the list produced by ``tokenizer.batch_decode``, one string per
    element of ``model_out_sequences``.
    """
    return tokenizer.batch_decode(
        model_out_sequences,
        skip_special_tokens=skip_special_tokens,
    )

In [34]:
EXPERIMENT_NUM_STR = '01'

# Folder layout:
#   experiments_files/experiment_<NUM>/reasoning/
#       model_outputs_cache/
#       plots/all_tokens/
#       plots/generated_tokens_only/
EXPERIMENT_FOLDER_NAME = os.path.join('experiments_files', f'experiment_{EXPERIMENT_NUM_STR}')
EXPERIMENT_REASONING_PATH = os.path.join(EXPERIMENT_FOLDER_NAME, 'reasoning')
MODEL_OUT_CACHE_FOLDER_PATH = os.path.join(EXPERIMENT_REASONING_PATH, 'model_outputs_cache')

PLOTS_PATH = os.path.join(EXPERIMENT_REASONING_PATH, 'plots')
ALL_TOKENS_PLOTS_PATH = os.path.join(PLOTS_PATH, 'all_tokens')
GENERATED_ONLY_TOKENS_PLOTS_PATH = os.path.join(PLOTS_PATH, 'generated_tokens_only')

# Only the leaf folders are listed; makedirs creates their parents as needed.
paths_to_create = [MODEL_OUT_CACHE_FOLDER_PATH, ALL_TOKENS_PLOTS_PATH, GENERATED_ONLY_TOKENS_PLOTS_PATH]
for path in paths_to_create:
    os.makedirs(path, exist_ok=True)

In [9]:
def save_model_output(model_out, save_path):
    """Persist the sequences and hidden states of a ``generate`` output.

    Only ``sequences`` and ``hidden_states`` are written; every tensor is moved
    to the CPU before saving.

    Args:
        model_out: Object exposing ``.sequences`` (tensor) and ``.hidden_states``
            (tuple over generation steps of tuples over layers of tensors).
        save_path: Destination file for ``torch.save``. Its parent directory is
            created if it does not exist.
    """
    parent_dir = os.path.dirname(save_path)
    # os.makedirs('') raises, so a bare filename (no directory part) must skip it.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    data_to_save = {
        'sequences': model_out.sequences.cpu(),
        'hidden_states': tuple(
            tuple(state.cpu() for state in step_states)
            for step_states in model_out.hidden_states
        ),
    }
    torch.save(data_to_save, save_path)

def load_model_output(load_path, device=DEVICE):
    """Load a cache file written by ``save_model_output``.

    Args:
        load_path: Path of the ``.pt`` cache file.
        device: Device every tensor is mapped to while loading.

    Returns:
        ``SimpleNamespace(sequences, hidden_states)`` with the same nesting
        that was saved.
    """
    # The cache holds only tensors inside dicts/tuples, so the restricted
    # weights_only unpickler is sufficient and refuses arbitrary pickled objects.
    loaded_data = torch.load(load_path, map_location=device, weights_only=True)
    # map_location has already put every tensor on `device`; only rebuild the tuples.
    hidden_states = tuple(
        tuple(step_states) for step_states in loaded_data['hidden_states']
    )
    return SimpleNamespace(
        sequences=loaded_data['sequences'],
        hidden_states=hidden_states
    )

In [10]:
def create_chat_prompts(prompts, suffix=''):
    """Wrap each raw prompt string in a single-turn user conversation.

    Returns one ``[{'role': 'user', 'content': prompt + suffix}]`` list per
    input prompt, in input order.
    """
    chat_prompts = []
    for prompt_text in prompts:
        user_message = {'role': 'user', 'content': prompt_text + suffix}
        chat_prompts.append([user_message])
    return chat_prompts

In [11]:
# Appended to every GSM8K question; follows the model card's advice to ask for a \boxed{} answer.
REASONING_SUFFIX = r' Please reason step by step in a few words, and put your final answer within \boxed{}'

## Dataset Processing

In [12]:
# def get_processed_prompts(questions=None, answers=None, filename='processed_prompts.csv', suffix=''):
#     if os.path.exists(filename):
#         print(f'Loading from existing file: {filename}')
#         return pd.read_csv(filename)
    
#     if questions is None or answers is None:
#         print(f'No questions nor answers were provided and file {filename} wasn\'t found')

#     chat_prompts = create_chat_prompts(questions, suffix)
#     prompts_df = pd.DataFrame({
#         'chat_prompt': [str(prompt) for prompt in chat_prompts],
#         'correct_answer': answers
#     })
#     prompts_df.to_csv(filename, index=False)
#     print(f'Prompts file created: {filename}')
#     return prompts_df

In [13]:
# RESONING_PROMPTS_FILENAME = 'reasoning_prompts.csv'

In [14]:
# if not os.path.exists(RESONING_PROMPTS_FILENAME):
#     gsm8k_ds = load_dataset('openai/gsm8k', 'main')

#     reasoning_prompts = get_processed_prompts(
#         questions=gsm8k_ds['train']['question'],
#         answers=gsm8k_ds['train']['answer'],
#         suffix=REASONING_SUFFIX,
#         filename=RESONING_PROMPTS_FILENAME
#     )

#     del gsm8k_ds

# reasoning_prompts = get_processed_prompts(filename=RESONING_PROMPTS_FILENAME)

In [12]:
# Build one chat prompt per GSM8K train-split question (answers are not kept);
# the dataset object itself is not retained afterwards.
reasoning_chat_prompts = create_chat_prompts(
    load_dataset('openai/gsm8k', 'main')['train']['question'],
    suffix=REASONING_SUFFIX,
)

## Interpretability Experiments

### Utility functions

#### Utilities for individual SAE encoder outputs

In [13]:
def get_top_k_encoder_out_by_acts(encoder_out, top_k=None, no_feature_repetition=False):
    """Select the strongest activations from an SAE encoder output.

    Args:
        encoder_out: Object with ``top_acts`` and ``top_indices`` tensors of the
            same shape; both are flattened.
        top_k: Number of entries to keep; ``None`` or a value above the number
            of available entries keeps all of them.
        no_feature_repetition: When True, each feature index appears at most
            once, represented by its maximum activation.

    Returns:
        ``SimpleNamespace(top_acts, top_indices)`` sorted by descending activation.
    """
    top_acts_flat = encoder_out.top_acts.flatten()
    top_indices_flat = encoder_out.top_indices.flatten()

    top_acts_num = top_acts_flat.size(0)
    if top_k is None or top_k > top_acts_num:
        top_k = top_acts_num

    if no_feature_repetition:
        # Max activation per distinct feature, computed on-device in one pass
        # (a Python loop with .item() forces a device sync per element).
        unique_features, inverse = torch.unique(top_indices_flat, return_inverse=True)
        feature_max_acts = torch.zeros(
            unique_features.size(0), dtype=top_acts_flat.dtype, device=top_acts_flat.device
        ).scatter_reduce(0, inverse, top_acts_flat, reduce='amax', include_self=False)

        kept = min(top_k, unique_features.size(0))
        sorted_acts, order = torch.sort(feature_max_acts, descending=True, stable=True)
        return SimpleNamespace(
            top_acts=sorted_acts[:kept],
            top_indices=unique_features[order[:kept]]
        )

    top_acts_values, top_flat_acts_pos = torch.topk(top_acts_flat, top_k)
    corresponding_feature_indices = top_indices_flat[top_flat_acts_pos]

    return SimpleNamespace(
        top_acts=top_acts_values,
        top_indices=corresponding_feature_indices
    )

In [14]:
def get_top_k_feature_freq_encoder_out(all_steps_feature_indices, top_k=None):
    """Rank SAE features by how often they occur among the selected indices.

    Args:
        all_steps_feature_indices: Integer tensor of feature indices (any shape);
            it is flattened and counted with ``torch.bincount``, so values must
            be non-negative.
        top_k: Number of most frequent features to keep; ``None`` or a value
            above the number of distinct features keeps all of them.

    Returns:
        ``SimpleNamespace(feature_indices, freqs)`` sorted by descending frequency.
    """
    flat_freq = torch.bincount(all_steps_feature_indices.flatten())
    # flatten() instead of squeeze(): squeeze() turns a single occurring
    # feature into a 0-d tensor, on which .size(0) below raises IndexError.
    feature_indices_with_counts = torch.nonzero(flat_freq, as_tuple=False).flatten()
    feature_freqs = flat_freq[feature_indices_with_counts]

    features_num = feature_indices_with_counts.size(0)
    if top_k is None or top_k > features_num:
        top_k = features_num

    top_k_freqs, top_k_positions = torch.topk(feature_freqs, top_k)
    return SimpleNamespace(
        feature_indices=feature_indices_with_counts[top_k_positions],
        freqs=top_k_freqs
    )

In [15]:
def relate_feature_acts_to_tokens(encoder_out, feature_idx, sample_tokens, top_k=None, no_token_repetition=False):
    """Find the tokens on which one SAE feature activates most strongly.

    Args:
        encoder_out: SAE encoder output with 2-D ``top_acts`` / ``top_indices``
            (one row per token position, one column per selected latent).
        feature_idx: Index of the feature to inspect.
        sample_tokens: Token ids where ``sample_tokens[i]`` is the token of row
            ``i`` of ``encoder_out``; it may be longer than the number of rows.
        top_k: Keep only the ``top_k`` strongest entries; ``None`` keeps all.
        no_token_repetition: When True, each distinct token id appears once,
            with its maximum activation over all of its occurrences.

    Returns:
        ``SimpleNamespace(token_ids, top_acts)`` sorted by descending activation.
        Both tensors are empty when the feature never activates positively.
    """
    device = encoder_out.top_acts.device
    feature_acts = torch.where(
        encoder_out.top_indices == feature_idx,
        encoder_out.top_acts,
        torch.tensor(0.0, device=device)
    )
    # Strongest activation of this feature at each token position (0 if absent).
    max_acts_per_token = feature_acts.max(dim=1).values

    nonzero_max_acts_mask = max_acts_per_token > 0
    if not nonzero_max_acts_mask.any():
        print(f'Non-dead feature activations weren\'t found for feature with index {feature_idx}')
        # Same return type as the normal path so callers can read .token_ids / .top_acts.
        return SimpleNamespace(
            token_ids=torch.empty(0, dtype=torch.long, device=device),
            top_acts=torch.empty(0, dtype=max_acts_per_token.dtype, device=device)
        )

    nonzero_max_acts = max_acts_per_token[nonzero_max_acts_mask]
    nonzero_max_acts_indices = nonzero_max_acts_mask.nonzero(as_tuple=True)[0]
    nonzero_tokens = torch.as_tensor(sample_tokens, device=device)[nonzero_max_acts_indices]

    if no_token_repetition:
        unique_tokens, inverse_indices = torch.unique(nonzero_tokens, return_inverse=True)
        # Per-token maximum in one vectorised op instead of a Python loop.
        max_acts_per_unique_token = torch.zeros(
            unique_tokens.size(0), dtype=nonzero_max_acts.dtype, device=device
        ).scatter_reduce(0, inverse_indices, nonzero_max_acts, reduce='amax')
        candidate_tokens, candidate_acts = unique_tokens, max_acts_per_unique_token
    else:
        candidate_tokens, candidate_acts = nonzero_tokens, nonzero_max_acts

    candidates_num = candidate_acts.size(0)
    if top_k is None or top_k > candidates_num:
        top_k = candidates_num

    top_k_acts, top_k_positions = torch.topk(candidate_acts, top_k)
    return SimpleNamespace(
        token_ids=candidate_tokens[top_k_positions],
        top_acts=top_k_acts
    )

#### Utilities for plotting the analysis of an individual encoder output

In [25]:
def plot_simple_bar(df, x, y, title, labels, save_path=None):
    """Bar chart of ``df[y]`` against ``df[x]`` with a log-scaled y-axis.

    X tick labels are rotated vertically. When ``save_path`` is truthy the
    figure is written there as standalone HTML instead of being shown.
    """
    bar_fig = px.bar(
        df,
        x=x,
        y=y,
        title=title,
        labels=labels
    )
    bar_fig.update_layout(
        xaxis={
            'tickangle': -90,
        },
        yaxis={
            'type': 'log'
        },
        # px.bar defaults to the 'relative' barmode, which stacks rows sharing
        # an x category; a feature that appears several times among the top
        # activations would be drawn as the sum of its activations. Overlaying
        # keeps each bar at its own height, so the visible top is the maximum.
        barmode='overlay'
    )
    if save_path:
        # bar_fig.write_image(save_path)
        bar_fig.write_html(save_path)
        return
    bar_fig.show()


def plot_feature_acts(encoder_out, title, save_path=None):
    """Bar-plot every (feature index, activation) pair of ``encoder_out``.

    Feature indices are converted to strings so plotly uses a categorical
    x-axis instead of a numeric one.
    """
    acts_df = pd.DataFrame({
        'index': encoder_out.top_indices.cpu().numpy().flatten(),
        'act': encoder_out.top_acts.cpu().numpy().flatten()
    }).astype({'index': str})
    axis_labels = {'index': 'Feature Index', 'act': 'Activation'}
    plot_simple_bar(acts_df, x='index', y='act', title=title, labels=axis_labels, save_path=save_path)


def plot_feature_freqs(feature_freqs, title, save_path=None):
    """Bar-plot how often each feature index was selected.

    ``feature_freqs`` carries parallel ``feature_indices`` and ``freqs``
    tensors; indices become strings so the x-axis is categorical.
    """
    freqs_df = pd.DataFrame({
        'index': feature_freqs.feature_indices.cpu().numpy().flatten(),
        'freq': feature_freqs.freqs.cpu().numpy().flatten()
    }).astype({'index': str})
    axis_labels = {'index': 'Feature Index', 'freq': 'Frequency'}
    plot_simple_bar(freqs_df, x='index', y='freq', title=title, labels=axis_labels, save_path=save_path)


def plot_acts_histogram(encoder_out, title, nbins=50, save_path=None):
    """Histogram of every activation value in ``encoder_out.top_acts``.

    The count axis is log-scaled. When ``save_path`` is truthy the figure is
    written there as HTML instead of being shown.
    """
    acts_hist_fig = px.histogram(
        x=encoder_out.top_acts.cpu().numpy().flatten(),
        nbins=nbins,
        title=title,
        labels={'x': 'Activation'},
    )
    # A histogram's y-axis is the bin count rather than a data column, so a
    # labels={'y': ...} entry does not title it; set the title directly
    # (this also drops the stray leading space of the former ' Frequency').
    acts_hist_fig.update_yaxes(type='log', title_text='Frequency')

    if save_path:
        # acts_hist_fig.write_image(save_path)
        acts_hist_fig.write_html(save_path)
        return

    acts_hist_fig.show()


def plot_acts_and_tokens_relation_scatter(token_acts_relation, title, save_path=None):
    """Scatter of activation strength against token id for one feature.

    Saved as HTML when ``save_path`` is truthy, otherwise shown inline.
    """
    scatter_df = pd.DataFrame({
        'token_id': token_acts_relation.token_ids.cpu().numpy(),
        'top_acts': token_acts_relation.top_acts.cpu().numpy()
    })
    scatter_fig = px.scatter(
        scatter_df,
        x='token_id',
        y='top_acts',
        title=title,
        labels={'top_acts': 'Top Activations', 'token_id': 'Token ID'}
    )
    scatter_fig.update_traces(marker={'size': 12, 'opacity': 0.7})

    if not save_path:
        scatter_fig.show()
        return
    scatter_fig.write_html(save_path)

def plot_acts_and_tokens_relation_bar(token_acts_relation, title, save_path=None):
    """Horizontal bar chart of a feature's activation per token (log x-axis).

    Each bar is labelled ``<decoded token>:<token id>``; the id suffix
    disambiguates tokens that decode to the same text. ``' '`` and ``':'``
    keep their ``<space>`` / ``<colon>`` placeholders, and other
    whitespace-only tokens (e.g. ``'\\n\\n'``) are shown escaped so their
    labels are not blank. Saved as HTML when ``save_path`` is truthy,
    otherwise shown inline.
    """
    def _readable_token(token):
        if token == ' ':
            return '<space>'
        if token == ':':
            return '<colon>'
        if token.strip() == '':
            # Newlines/tabs would otherwise render as empty or broken tick labels.
            return repr(token)[1:-1]
        return token

    decoded_tokens = decode_inference(token_acts_relation.token_ids)
    token_labels = [
        f'{_readable_token(token)}:{token_id}'
        for token, token_id in zip(decoded_tokens, token_acts_relation.token_ids.tolist())
    ]
    token_act_df = pd.DataFrame({
        'token': token_labels,
        'top_acts': token_acts_relation.top_acts.cpu().numpy()
    })

    # Grow the figure with the number of bars so labels stay legible.
    plot_height = max(400, len(token_act_df) * 15)
    acts_tokens_bar_fig = px.bar(
        token_act_df,
        x='top_acts',
        y='token',
        orientation='h',
        title=title,
        labels={'top_acts': 'Top Activation', 'token': 'Token'},
    )
    acts_tokens_bar_fig.update_layout(
        yaxis={
            'categoryorder': 'total ascending',
            'tickmode': 'array',
            'tickvals': token_act_df['token'],
            'ticktext': token_act_df['token']
        },
        xaxis={
            'type': 'log'
        },
        margin=dict(l=200),
        font=dict(size=10),
        height=plot_height
    )

    if save_path:
        # acts_tokens_bar_fig.write_image(save_path)
        acts_tokens_bar_fig.write_html(save_path)
        return

    acts_tokens_bar_fig.show()

#### Functions for multiple layers

In [17]:
def get_hidden_state_by_layer_acts(model_out, model_sae_layers):
    """Stack per-step hidden states into one 2-D activation matrix per layer.

    ``model_out.hidden_states`` is indexed ``[step][layer]``. Step 0 is treated
    as the prompt; the later steps are the generated positions.

    Args:
        model_out: Object with a ``hidden_states`` attribute as described above.
        model_sae_layers: Layer indices to extract.

    Returns:
        ``(hidden_state_by_layer, generated_tokens_only_hidden_state_by_layer)``:
        dicts mapping each layer to a ``(rows, hidden_size)`` tensor, the second
        one excluding step 0.
    """
    hidden_state_by_layer = {}
    generated_tokens_only_hidden_state_by_layer = {}

    for layer in model_sae_layers:
        step_acts = []
        for step_states in model_out.hidden_states:
            layer_acts = step_states[layer]
            # Collapse all leading dims (assumed batch, seq) into rows of hidden_size.
            step_acts.append(layer_acts.reshape(-1, layer_acts.size(-1)))

        hidden_state_by_layer[layer] = torch.cat(step_acts, dim=0)
        if len(step_acts) > 1:
            generated_tokens_only_hidden_state_by_layer[layer] = torch.cat(step_acts[1:], dim=0)
        else:
            # Only the prompt step exists; torch.cat([]) would raise, so
            # return a zero-row matrix with the right width instead.
            generated_tokens_only_hidden_state_by_layer[layer] = step_acts[0].new_empty(
                (0, step_acts[0].size(-1))
            )

    return hidden_state_by_layer, generated_tokens_only_hidden_state_by_layer

In [18]:
def encode_hidden_states_by_layer(hidden_state_by_layer, saes):
    """Encode each layer's activations with the SAE registered for that layer.

    Returns a dict with the same keys (and order) as ``hidden_state_by_layer``,
    mapping each layer to ``saes[layer].encode(activations)``.
    """
    return {
        layer_num: saes[layer_num].encode(layer_acts)
        for layer_num, layer_acts in hidden_state_by_layer.items()
    }

In [29]:
def analyze_enocder_out_by_layer(
    encoder_out_by_layer,
    top_k_encoder_out_by_acts,
    top_k_encoder_out_feature_freq,
    top_k_relevant_features,
    token_acts_relation_top_k, # Not used in scatter plot
    sample_token_ids,
    plot_save_path,
    plot_title_prefix_init=''
):
    """Write the per-layer SAE analysis plots for one set of encoder outputs.

    For every layer this saves, as HTML files under ``plot_save_path``:
      * the top activations, the most frequent features and an activation histogram;
      * for each of the ``top_k_relevant_features`` strongest distinct features,
        a scatter of all its activating tokens and a bar chart of its
        ``token_acts_relation_top_k`` strongest tokens.

    Args:
        encoder_out_by_layer: Mapping ``layer -> SAE encoder output``.
        top_k_encoder_out_by_acts: Number of activations in the activation plot.
        top_k_encoder_out_feature_freq: Number of features in the frequency plot.
        top_k_relevant_features: Number of distinct features analysed per token.
        token_acts_relation_top_k: Tokens per feature bar chart (the scatter
            plot shows every activating token).
        sample_token_ids: Token ids aligned with the encoder-output rows.
        plot_save_path: Existing directory receiving the HTML files.
        plot_title_prefix_init: Text prepended to every plot title.
    """
    for layer_num, encoder_out in encoder_out_by_layer.items():
        top_k_encoder_out = get_top_k_encoder_out_by_acts(encoder_out, top_k=top_k_encoder_out_by_acts)
        top_k_feature_freq = get_top_k_feature_freq_encoder_out(encoder_out.top_indices, top_k=top_k_encoder_out_feature_freq)
        # Distinct features ranked by their strongest activation in this layer.
        relevant_features = get_top_k_encoder_out_by_acts(encoder_out, top_k=top_k_relevant_features, no_feature_repetition=True).top_indices

        # Create Plots
        plot_title_prefix = f'Layer {layer_num} |'
        if len(plot_title_prefix_init) > 0:
            plot_title_prefix = plot_title_prefix_init + plot_title_prefix

        filename_prefix = f'layer_{layer_num}'

        plot_feature_acts(
            top_k_encoder_out,
            title=f'{plot_title_prefix} Top {top_k_encoder_out_by_acts} Feature Activations',
            save_path=os.path.join(plot_save_path, f'{filename_prefix}_feature_activations.html')
        )
        plot_feature_freqs(
            top_k_feature_freq,
            title=f'{plot_title_prefix} Top {top_k_encoder_out_feature_freq} Feature Frequencies',
            save_path=os.path.join(plot_save_path, f'{filename_prefix}_feature_frequencies.html')
        )
        plot_acts_histogram(
            encoder_out,
            title=f'{plot_title_prefix} Activations Frequency',
            save_path=os.path.join(plot_save_path, f'{filename_prefix}_activations_frequencies_hist.html')
        )

        for relevant_feature in relevant_features:
            relevant_feature_idx = relevant_feature.item()
            # Every distinct activating token, strongest first.
            token_acts_relation = relate_feature_acts_to_tokens(
                encoder_out,
                feature_idx=relevant_feature_idx,
                sample_tokens=sample_token_ids,
                no_token_repetition=True
            )
            token_ids = getattr(token_acts_relation, 'token_ids', None)
            if token_ids is None or token_ids.numel() == 0:
                # The feature never activated positively: nothing to plot.
                continue

            # The relation is already sorted by descending activation (topk),
            # so its top-k is a prefix; slicing avoids recomputing it.
            top_k_token_acts_relation = SimpleNamespace(
                token_ids=token_acts_relation.token_ids[:token_acts_relation_top_k],
                top_acts=token_acts_relation.top_acts[:token_acts_relation_top_k]
            )

            plot_acts_and_tokens_relation_scatter(
                token_acts_relation,
                title=f'{plot_title_prefix} All Activating Tokens for Feature with index {relevant_feature_idx}',
                save_path=os.path.join(plot_save_path, f'{filename_prefix}_feature_{relevant_feature_idx}_activations_tokens_scatter.html')
            )
            plot_acts_and_tokens_relation_bar(
                top_k_token_acts_relation,
                title=f'{plot_title_prefix} Top {token_acts_relation_top_k} Activating Tokens for Feature with index {relevant_feature_idx}',
                save_path=os.path.join(plot_save_path, f'{filename_prefix}_feature_{relevant_feature_idx}_activations_tokens_bar.html')
            )

### Functions to run the experiments

In [24]:
def create_and_save_model_inferences(prompts, model_out_folder_path, max_prompts=None, skip_existing=False):
    """Run ``inference`` on each prompt and cache it as ``<idx>_model_inference.pt``.

    Args:
        prompts: Chat prompts; a prompt's position in this sequence becomes its
            file index.
        model_out_folder_path: Folder receiving the cache files.
        max_prompts: If given, only the first ``max_prompts`` prompts are
            processed, so a large prompt list can be run on a budget.
        skip_existing: If True, prompts whose cache file already exists are not
            regenerated, which lets an interrupted run resume.
    """
    for idx, prompt in enumerate(prompts):
        if max_prompts is not None and idx >= max_prompts:
            break
        save_path = os.path.join(model_out_folder_path, f'{idx}_model_inference.pt')
        if skip_existing and os.path.exists(save_path):
            continue
        model_out = inference(prompt)
        save_model_output(model_out, save_path=save_path)

In [32]:
def load_and_analyze_model_inferences(saes, prompts, model_out_folder_path):
    """Load every cached inference and write all-token and generated-only SAE plots.

    Cache files must be named ``<prompt index>_model_inference.pt``; any other
    file in the folder is ignored. Plots go to ``ALL_TOKENS_PLOTS_PATH`` and
    ``GENERATED_ONLY_TOKENS_PLOTS_PATH``, in a ``<prompt index>_model_out``
    sub-folder per cache file.

    Args:
        saes: Mapping ``layer -> SAE``; its keys select the hidden-state layers.
        prompts: The chat prompts the cache was generated from. Kept for
            backward compatibility; the prompt length used to split off the
            generated tokens is read from the cached hidden states instead.
        model_out_folder_path: Folder holding the cached ``.pt`` files.
    """
    file_pattern = re.compile(r'(\d+)_model_inference\.pt')
    indexed_files = []
    for filename in os.listdir(model_out_folder_path):
        match = file_pattern.fullmatch(filename)
        if match:
            # Use the index encoded in the filename, not the listing position,
            # so gaps in the cache don't shift results onto the wrong prompt.
            indexed_files.append((int(match.group(1)), filename))
    indexed_files.sort()

    # Top k Definitions Here
    analysis_top_ks = dict(
        top_k_encoder_out_by_acts=400,
        top_k_encoder_out_feature_freq=70,
        top_k_relevant_features=10,
        token_acts_relation_top_k=70,
    )

    for prompt_idx, filename in indexed_files:
        model_out = load_model_output(
            os.path.join(model_out_folder_path, filename),
            device=DEVICE
        )
        hidden_state_by_layer, generated_tokens_only_hidden_state_by_layer = get_hidden_state_by_layer_acts(model_out, saes.keys())

        sequence_token_ids = model_out.sequences.squeeze()
        # Step 0's hidden states span the prompt positions (the same assumption
        # get_hidden_state_by_layer_acts makes when it drops step 0), so their
        # sequence length is the prompt length actually fed to the model.
        input_prompt_tokens_len = model_out.hidden_states[0][0].size(-2)

        analysis_runs = (
            (hidden_state_by_layer, sequence_token_ids,
             ALL_TOKENS_PLOTS_PATH, f'Model Out #{prompt_idx} | '),
            (generated_tokens_only_hidden_state_by_layer, sequence_token_ids[input_prompt_tokens_len:],
             GENERATED_ONLY_TOKENS_PLOTS_PATH, f'Model Out #{prompt_idx} (Gen. Tokens Only) | '),
        )
        for acts_by_layer, token_ids, plots_root, title_prefix in analysis_runs:
            plots_dir = os.path.join(plots_root, f'{prompt_idx}_model_out')
            os.makedirs(plots_dir, exist_ok=True)
            analyze_enocder_out_by_layer(
                encoder_out_by_layer=encode_hidden_states_by_layer(acts_by_layer, saes),
                sample_token_ids=token_ids,
                plot_save_path=plots_dir,
                plot_title_prefix_init=title_prefix,
                **analysis_top_ks
            )

### Running the experiments

In [21]:
model_sae_layers = [5, 10, 20]

# One SAE per analysed layer, keyed by layer number.
# NOTE(review): these SAEs use the `layers.{n}.mlp` hookpoint, but
# get_hidden_state_by_layer_acts feeds them model_out.hidden_states[step][n],
# which in HF models is typically the decoder-layer (residual) output, offset by
# the embedding entry at index 0. Confirm these are the activations the SAEs were trained on.
saes = {}
for layer_num in model_sae_layers:
    layer_hookpoint = f'layers.{layer_num}.mlp'
    saes[layer_num] = Sae.load_from_hub(
        'EleutherAI/sae-DeepSeek-R1-Distill-Qwen-1.5B-65k',
        hookpoint=layer_hookpoint,
        device=DEVICE,
    )
    

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# create_and_save_model_inferences(
#     prompts=reasoning_chat_prompts,
#     model_out_folder_path=MODEL_OUT_CACHE_FOLDER_PATH
# )

# -> Finished 31 inferences

In [35]:
# Analyse every cached inference for the experiment with the loaded SAEs.
load_and_analyze_model_inferences(saes, reasoning_chat_prompts, MODEL_OUT_CACHE_FOLDER_PATH)