# **Toward Mechanistic Explanation of Deductive Reasoning in Language Models**

This notebook demonstrates the visualization tool developed in [D. Maltoni and M. Ferrara, *"Toward Mechanistic Explanation of Deductive Reasoning in Language Models"*, arXiv:2510.09340, 2025](https://arxiv.org/abs/2510.09340) and reproduces Figures 2 to 5 of the paper. It relies on the following Python scripts:
- **logic_data.py** - provides functions to create the dataset used in the experimentation.
- **model_with_hooks.py** - contains a modified version of the nanoGPT model originally implemented by [Andrej Karpathy](https://github.com/karpathy/nanoGPT), combined with [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens), a library for mechanistic interpretability of generative language models.
- **utility_functions.py** - provides several utility functions.
- **EnvUtilities.py** - contains the *setup_environment* function, which properly configures the environment.

The following code imports all necessary modules and functions required to run this notebook.

In [None]:
import os
import torch
import numpy as np
import ipywidgets as widgets
from ipywidgets import interact

from logic_data import create_dics, create_dataset
from utility_functions import load_model, generate, to_tensor_prompts, to_str_tokens, prepare_data_flows_from_attention_matrices, from_logits_to_chars, from_logits_to_top_k_chars, truncated_pseudoinverse, draw_sankey, compute_avg_attention_layer_patterns
from EnvUtilities import setup_environment

The following code cell sets the model checkpoint file name and the prompt to analyze.

In [None]:
# Model checkpoint
model_file_name="VarLenPos_NewLitNeg_CoT_NoUlC_DB4096_mb128_Ep150_AtL2_He1_Emb128_NoMLP_NoFlAt_se127_r8.mdl"

model_checkpoint_file_path=os.path.join(os.path.abspath(''),model_file_name)
#---

# Prompt to analyze
prompt_str = 'C>D,A>B,B>C,E>F,D>E|A>F@'
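The prompt appears to encode a set of implication rules and a query: comma-separated rules of the form `X>Y` before `|`, the query `A>F` after it, and `@` marking the end of the prompt. This format is inferred from the example, not taken from `logic_data.py`. Under that assumption, the minimal sketch below parses the prompt and follows the rules from the query's antecedent to its consequent, giving a reference chain to compare with the model's output (assuming the model's chain-of-thought lists exactly this chain). `parse_prompt` and `find_chain` are hypothetical helpers, not part of the repository.

```python
def parse_prompt(prompt):
    """Split a prompt such as 'C>D,A>B|A>F@' into (rules, query).

    Assumed format: comma-separated rules 'X>Y' (X implies Y), then '|',
    then the query 'X>Y', terminated by '@'.
    """
    rules_part, query_part = prompt.rstrip("@").split("|")
    rules = [tuple(rule.split(">")) for rule in rules_part.split(",")]
    query = tuple(query_part.split(">"))
    return rules, query


def find_chain(rules, query):
    """Follow the rules from the query's antecedent; return the chain or None."""
    successor = dict(rules)  # assumes each literal is the antecedent of at most one rule
    current, goal = query
    chain, visited = [], set()
    while current != goal:
        if current not in successor or current in visited:
            return None  # dead end or cycle: no chain under the assumption above
        visited.add(current)
        chain.append((current, successor[current]))
        current = successor[current]
    return chain


rules, query = parse_prompt('C>D,A>B,B>C,E>F,D>E|A>F@')
chain = find_chain(rules, query)
print(",".join(f"{a}>{b}" for a, b in chain))  # A>B,B>C,C>D,D>E,E>F
```

The chain has five rules, which matches the five output premise positions used in the index lists of this notebook.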

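The code cell below creates index lists used to display only selected links between the two attention layers. Each index is a position in the sequence formed by the 24-character prompt followed by the generated output, so the lists must be adapted if the prompt length or the number of generated rules changes.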

In [None]:
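all_token_indices=list(range(44))                      # all positions: prompt (0-23) + output (24-43)
all_ouptut_token_indices=list(range(23,44))            # output positions, starting from the final prompt token '@'
output_premise_token_indices=list(range(24,41,4))      # premise (antecedent) of each output rule
output_greater_token_indices=list(range(25,42,4))      # '>' of each output rule
output_consequent_token_indices=list(range(26,43,4))   # consequent of each output rule
output_comma_token_indices=list(range(27,43,4))        # ',' separating output rules
output_minus_token_index=list(range(43,44))            # final '-' token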
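The hard-coded indices follow from the sequence layout: the 24-character prompt occupies positions 0 to 23, and each generated rule `X>Y,` takes four positions starting at 24. The sketch below recomputes the lists under the assumption that the output contains five rules followed by a single `-` token. `output_index_lists` is a hypothetical helper, useful only as a check when adapting the lists to another prompt.

```python
def output_index_lists(prompt_len=24, n_rules=5):
    """Recompute the notebook's index lists from the assumed sequence layout.

    Assumed layout: prompt at positions 0..prompt_len-1, then n_rules output
    rules written as 'X>Y' and separated by ',', then one '-' token.
    """
    premise = [prompt_len + 4 * i for i in range(n_rules)]  # antecedent of each rule
    greater = [p + 1 for p in premise]                      # '>' of each rule
    consequent = [p + 2 for p in premise]                   # consequent of each rule
    comma = [p + 3 for p in premise[:-1]]                   # no ',' after the last rule
    minus = [consequent[-1] + 1]                            # final '-' token
    total = minus[0] + 1
    return {
        "all": list(range(total)),
        "all_output": list(range(prompt_len - 1, total)),   # starts at the last prompt token '@'
        "premise": premise,
        "greater": greater,
        "consequent": consequent,
        "comma": comma,
        "minus": minus,
    }


idx = output_index_lists(prompt_len=len('C>D,A>B,B>C,E>F,D>E|A>F@'), n_rules=5)
print(idx["premise"], idx["comma"], idx["minus"])  # [24, 28, 32, 36, 40] [27, 31, 35, 39] [43]
```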

Executing the following code cell will:
- initialize the environment,
- create two dictionaries for character-to-index and index-to-character mappings,
- load the model, and
- extract the model parameters from its configuration.

In [None]:
# Environment initialization
env = setup_environment(use_gpu = True)
device = env['device']
ctx = env['ctx']
# ---

# Mapping creation between characters and indices
_, stoi, itos = create_dics(alphabet=20)

# Model loading
model, gptconf=load_model(model_checkpoint_file_path, device)
model.eval()
torch.set_grad_enabled(False)

# Model parameters
n_layer = gptconf.n_layer
n_head = gptconf.n_head
n_embd = gptconf.n_embd

The code cell below performs the following steps:
- uses the model to generate an output from the input prompt (*prompt_str*),
- creates a string (*str_without_last_token*) containing the prompt followed by the generated output, without its last token,
- converts *str_without_last_token* and the generated output into two lists of characters for visualization,
- runs the model on *str_without_last_token* with `run_with_cache` (TransformerLens functionality) to store the activations at every hook point.

In [None]:
# Generate output from the model using the prompt example
output = generate(model, 6,True, stoi, itos, prompt_str, device, ctx)
print(f"Model output: {output}")

# Prepare a string containing: prompt + generated output without the last token
str_without_last_token=output[:-1]

# Convert prompt + generated output without the last token to list of characters
prompt_last_token_charlist=to_str_tokens(str_without_last_token,"")

# Convert generated output (the characters after the prompt) to list of characters
output_charlist=to_str_tokens(output[len(prompt_str):],"")

# Run the model to get cache using TransformerLens library
model.return_all_logits = True
_,cache=model.run_with_cache(to_tensor_prompts(str_without_last_token, stoi, device),remove_batch_dim=True)
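Dropping the last token follows the usual next-token alignment: position *t* of the model input is where the model predicts character *t*+1, so the final character never needs to be fed back. This is presumably also why position 23 (the `@` token) is included among the output positions. The sketch below illustrates the pairing on a string made of the prompt and a hypothetical continuation; the continuation and the toy vocabulary `toy_stoi` are assumptions, and the real character-to-index mapping comes from `create_dics`.

```python
def shift_for_next_token(sequence):
    """Pair each input position with the character the model is expected to emit there."""
    inputs = sequence[:-1]   # what the model reads (cf. str_without_last_token)
    targets = sequence[1:]   # the next character for every input position
    return [(t, inp, tgt) for t, (inp, tgt) in enumerate(zip(inputs, targets))]


# Hypothetical continuation: assumes the model outputs the chain A>B,...,E>F.
seq = 'C>D,A>B,B>C,E>F,D>E|A>F@' + 'A>B,B>C,C>D,D>E,E>F'
# Toy vocabulary for illustration only; create_dics() defines the real mapping.
toy_stoi = {ch: i for i, ch in enumerate(sorted(set(seq)))}
encoded = [toy_stoi[ch] for ch in seq[:-1]]

pairs = shift_for_next_token(seq)
for t, inp, tgt in pairs[22:26]:
    print(f"position {t}: reads {inp!r}, predicts {tgt!r}")
```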

The following code processes the attention patterns stored in the cached hooks and prepares them for visualization.

In [None]:
# Extract attention layer patterns from cache
attention_layer_patterns={}
for i in range(n_layer):
    key = f"transformer.h.{i}.attn.hook_pattern"
    attention_layer_patterns[i]=cache[key]

# Process attention matrices for visualization
attention_matrices= [attention_layer_patterns[i].cpu().numpy() for i in range(n_layer)]
attention_layer_sources, attention_layer_targets, attention_layer_values = prepare_data_flows_from_attention_matrices(n_layer, n_head, attention_matrices)
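Conceptually, a Sankey diagram of attention needs one node per (layer, position) pair and one link per query/key pair. The sketch below shows one way to flatten per-layer attention patterns into the `(sources, targets, values)` lists a Sankey plot expects, and to filter links by weight and by target position, loosely mirroring the *Threshold* and *L1->L2 Links* controls of the figure cells. It assumes patterns indexed as `[head, query_pos, key_pos]` (the TransformerLens convention once the batch dimension is removed) and averages over heads; the actual `prepare_data_flows_from_attention_matrices` and `draw_sankey` may organize nodes differently, and `attention_to_flows` / `filter_flows` are hypothetical names.

```python
import numpy as np


def attention_to_flows(attention_matrices):
    """Flatten per-layer attention patterns into Sankey link lists.

    attention_matrices: one array per layer, shape (n_head, T, T), indexed as
    [head, query_pos, key_pos] (assumed). Node id for layer l, position p is
    l * T + p, so layer l + 1 holds the nodes that receive layer l's attention.
    """
    sources, targets, values = [], [], []
    for layer, pattern in enumerate(attention_matrices):
        T = pattern.shape[-1]
        weights = pattern.mean(axis=0)  # average over heads (assumption)
        q_idx, k_idx = np.nonzero(np.tril(np.ones((T, T), dtype=bool)))  # causal pairs k <= q
        sources.extend((layer * T + k_idx).tolist())        # key position in this layer
        targets.extend(((layer + 1) * T + q_idx).tolist())  # query position one layer up
        values.extend(weights[q_idx, k_idx].tolist())
    return sources, targets, values


def filter_flows(sources, targets, values, weight_thr, target_positions, T):
    """Keep links whose weight reaches weight_thr and whose target position is selected."""
    wanted = set(target_positions)
    kept = [(s, t, v) for s, t, v in zip(sources, targets, values)
            if v >= weight_thr and t % T in wanted]
    return [list(col) for col in zip(*kept)] if kept else [[], [], []]


# Toy single-layer, single-head pattern over 3 positions (rows sum to 1).
A = np.array([[[1.0, 0.0, 0.0],
               [0.5, 0.5, 0.0],
               [0.2, 0.3, 0.5]]])
s, t, v = attention_to_flows([A])
fs, ft, fv = filter_flows(s, t, v, weight_thr=0.4, target_positions=[2], T=3)
print(fs, ft, fv)
```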

The code cell below selects the characters corresponding to the top-ranked and second-ranked tokens decoded from the residual stream after Layer 1. This decoding follows the [logit lens](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) approach, which applies the final LM head to the residual stream and identifies the tokens with the highest logit values.

In [None]:
# Extract residual stream after first attention layer from cache
post_l0_emb=cache["hook_post_li_emb.0"]

# Get logits from residual stream after first attention layer
post_l0_logits=model.lm_head(post_l0_emb)

# Get top-k characters from logits after first attention layer
l1_top_k_subs,_=from_logits_to_top_k_chars(post_l0_logits,itos,top_k=2)
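The logit-lens step amounts to multiplying each residual-stream vector by the unembedding matrix and ranking the resulting logits. The NumPy sketch below does this on a toy residual stream. `W_U` plays the role of `model.lm_head.weight`, assuming, as in nanoGPT, that the LM head is a bias-free `nn.Linear` whose weight has shape `(vocab_size, n_embd)`. `top_k_tokens` is a hypothetical stand-in for `from_logits_to_top_k_chars`, whose exact behavior is not shown in this notebook.

```python
import numpy as np


def top_k_tokens(residual, W_U, itos, k=2):
    """Logit-lens decoding: the k highest-logit characters at each position.

    residual: (T, n_embd) residual-stream vectors.
    W_U:      (vocab_size, n_embd) unembedding, laid out like nn.Linear.weight (assumed).
    itos:     mapping from token index to character.
    """
    logits = residual @ W_U.T                     # (T, vocab_size)
    ranked = np.argsort(-logits, axis=-1)[:, :k]  # indices by decreasing logit
    return [[itos[int(i)] for i in row] for row in ranked]


# Toy example: a 3-character vocabulary with an identity unembedding.
toy_itos = {0: 'A', 1: 'B', 2: 'C'}
toy_residual = np.array([[0.1, 0.9, 0.5],
                         [2.0, 1.0, 3.0]])
print(top_k_tokens(toy_residual, np.eye(3), toy_itos))  # [['B', 'C'], ['C', 'A']]
```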

The following code cell creates a dataset of 4096 examples (2048 positive and 2048 negative) to compute average attention patterns, which are used to visualize token-independent links.

In [None]:
# Create dataset of 4096 examples
src_data, tgt_data,_,_=create_dataset(literals=6, alphabet=20, dataset_size=4096, positive_sample_with_fixed_length=False, negative_sample_with_new_literal=True, encfunc=stoi)
data = np.concatenate((src_data, tgt_data), axis=1)

# Split dataset into negative and positive examples
negative_data = data[data[:, -1] == 6]  # Assuming 6 is the label for negative examples
positive_data = data[data[:, -1] == 7]  # Assuming 7 is the label for positive examples

The code cell below computes the average attention patterns over all, positive, and negative examples and prepares them for visualization.

In [None]:
# Compute average attention patterns for the entire dataset
avg_attention_layer_patterns=compute_avg_attention_layer_patterns(n_head,n_layer,data,model,stoi,itos,device)
avg_attention_layer_sources, avg_attention_layer_targets, avg_attention_layer_values = prepare_data_flows_from_attention_matrices(n_layer, n_head, avg_attention_layer_patterns)

# Compute average attention patterns for positive examples
pos_avg_attention_layer_patterns=compute_avg_attention_layer_patterns(n_head,n_layer,positive_data,model,stoi,itos,device)
pos_avg_attention_layer_sources, pos_avg_attention_layer_targets, pos_avg_attention_layer_values = prepare_data_flows_from_attention_matrices(n_layer, n_head, pos_avg_attention_layer_patterns)

# Compute average attention patterns for negative examples
neg_avg_attention_layer_patterns=compute_avg_attention_layer_patterns(n_head,n_layer,negative_data,model,stoi,itos,device)
neg_avg_attention_layer_sources, neg_avg_attention_layer_targets, neg_avg_attention_layer_values = prepare_data_flows_from_attention_matrices(n_layer, n_head, neg_avg_attention_layer_patterns)
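Averaging attention patterns is an element-wise mean over examples, optionally restricted to one class. The sketch below assumes the per-example patterns have already been stacked into an array of shape `(N, n_layer, n_head, T, T)`; how `compute_avg_attention_layer_patterns` actually runs the model and aggregates is not shown here, and `average_patterns` is a hypothetical name. Because a mean of row-stochastic matrices is itself row-stochastic, the averaged patterns can go through the same flow-preparation step as a single prompt.

```python
import numpy as np


def average_patterns(patterns, labels=None, label=None):
    """Per-layer mean attention pattern, optionally restricted to one class.

    patterns: (N, n_layer, n_head, T, T) attention weights for N examples (assumed stack).
    labels:   (N,) class token of each example, e.g. data[:, -1].
    """
    if labels is not None:
        patterns = patterns[np.asarray(labels) == label]  # boolean mask over examples
    mean = patterns.mean(axis=0)                          # (n_layer, n_head, T, T)
    return [mean[layer] for layer in range(mean.shape[0])]


# Toy check: averaging row-stochastic attention matrices keeps every row summing to 1.
rng = np.random.default_rng(0)
toy = rng.random((8, 2, 1, 5, 5)) * np.tril(np.ones((5, 5)))  # causal mask
toy /= toy.sum(axis=-1, keepdims=True)                         # normalize rows
toy_labels = np.array([6, 7] * 4)
pos_avg = average_patterns(toy, toy_labels, label=7)  # 7 = positive class, per the notebook's comment
print(np.allclose(pos_avg[0].sum(axis=-1), 1.0))  # True
```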

The following code cell applies the truncated pseudoinverse approach proposed in the paper to decode into token space the information carried by the queries, keys, and values of the second attention layer.

In [None]:
# Thresholds passed to truncated_pseudoinverse (thr_sk) for Q, K, and V
thr_Q=0.7
thr_K=0.99
thr_V=0.7

# Extract Q, K, and V projection matrices of the second attention layer (index 1).
# c_attn stacks the three projections along dim 0; transposing gives x @ W_* = q, k, v.
W_Q, W_K, W_V = (w.T for w in model.transformer.h[1].attn.c_attn.weight.split(n_embd, dim=0))

res_stream = post_l0_emb
res_stream = model.transformer.h[1].ln_1(res_stream)  # Layer normalization applied to residual stream

# Compute truncated pseudoinverse
W_Q_inv, n_Q = truncated_pseudoinverse(W_Q.cpu().numpy(), thr_sk=thr_Q)
W_K_inv, n_K = truncated_pseudoinverse(W_K.cpu().numpy(), thr_sk=thr_K)
W_V_inv, n_V = truncated_pseudoinverse(W_V.cpu().numpy(), thr_sk=thr_V)

# Project residual stream onto Q, K, V subspaces
res_stream_Q = res_stream @ W_Q @ torch.tensor(W_Q_inv, device=device)
res_stream_K = res_stream @ W_K @ torch.tensor(W_K_inv, device=device)  
res_stream_V = res_stream @ W_V @ torch.tensor(W_V_inv, device=device)  

# Convert projected residual streams back to character space
res_stream_Q_letters,_ = from_logits_to_chars(model.lm_head(res_stream_Q),itos)
res_stream_K_letters,_ = from_logits_to_chars(model.lm_head(res_stream_K),itos)
res_stream_V_letters,_ = from_logits_to_chars(model.lm_head(res_stream_V),itos)
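A truncated pseudoinverse can be built from the SVD `W = U S Vᵀ` by inverting only the `n` dominant singular values. Then `W @ W_inv_n = U_n U_nᵀ`, so `res_stream @ W @ W_inv_n` is the orthogonal projection of the residual stream onto the input directions that `W` reads most strongly, which is what gets decoded by the LM head above. The NumPy sketch below assumes that `thr_sk` is a cumulative fraction of the singular-value sum; the real `truncated_pseudoinverse` may use another criterion (for example squared singular values), and `truncated_pinv_sketch` is a hypothetical name.

```python
import numpy as np


def truncated_pinv_sketch(W, thr=0.7):
    """Pseudoinverse of W restricted to its dominant singular directions.

    Assumed criterion: keep the smallest n such that the first n singular values
    account for at least `thr` of their total sum.
    Returns (W_inv, n), mirroring the two values unpacked in the notebook.
    """
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    cumulative = np.cumsum(s) / s.sum()
    n = min(int(np.searchsorted(cumulative, thr)) + 1, len(s))
    W_inv = Vt[:n].T @ np.diag(1.0 / s[:n]) @ U[:, :n].T
    return W_inv, n


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 8))  # stand-in for W_Q, W_K or W_V (x @ W gives q, k or v)
x = rng.normal(size=(3, 8))  # stand-in residual stream: 3 positions

W_inv, n = truncated_pinv_sketch(W, thr=0.7)
P = W @ W_inv                # equals U_n U_n^T, an orthogonal projector of rank n
x_proj = x @ P               # residual stream restricted to the dominant input directions
print(n, np.allclose(P @ P, P))
```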

Executing the cell below will reproduce Figure 2.

In [None]:
#Figure 2

@interact(weight_thr=widgets.FloatSlider(min=0.0, max=1.0, step=0.05, value=0.0,continuous_update=False,description="Threshold:"),
          target_flows=widgets.Dropdown(options={'All tokens': all_token_indices,
                                                 'All output tokens': all_ouptut_token_indices,
                                                 'Output premise tokens': output_premise_token_indices,
                                                 'Output \'>\' tokens': output_greater_token_indices,
                                                 'Output consequent tokens': output_consequent_token_indices,
                                                 'Output \',\' tokens': output_comma_token_indices,
                                                 'Output \'-\' token': output_minus_token_index
                                                 },
                                        value=all_token_indices,
                                        description='L1->L2 Links:'),
          attention_data=widgets.Dropdown(options={'Current prompt': [attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                  'Average on all examples': [avg_attention_layer_sources, avg_attention_layer_targets, avg_attention_layer_values],
                                                  'Average on positive examples': [pos_avg_attention_layer_sources, pos_avg_attention_layer_targets, pos_avg_attention_layer_values],
                                                  'Average on negative examples': [neg_avg_attention_layer_sources, neg_avg_attention_layer_targets, neg_avg_attention_layer_values],
                                                  },
                                                   value=[attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                   description='Attention:'),
          show_kqv=widgets.Checkbox(value=False, description='Show K, Q, V', disabled=False))
def draw(weight_thr,target_flows, attention_data,show_kqv):
   if show_kqv:
      res_stream_data=[res_stream_Q_letters,res_stream_K_letters,res_stream_V_letters]
   else:
      res_stream_data=None

   draw_sankey(prompt_last_token_charlist, output_charlist, l1_top_k_subs, attention_data,res_stream_data, all_token_indices, weight_thr, target_flows)

Executing the cell below will reproduce Figure 3.

In [None]:
#Figure 3

@interact(weight_thr=widgets.FloatSlider(min=0.0, max=1.0, step=0.05, value=0.4,continuous_update=False,description="Threshold:"),
          target_flows=widgets.Dropdown(options={'All tokens': all_token_indices,
                                                 'All output tokens': all_ouptut_token_indices,
                                                 'Output premise tokens': output_premise_token_indices,
                                                 'Output \'>\' tokens': output_greater_token_indices,
                                                 'Output consequent tokens': output_consequent_token_indices,
                                                 'Output \',\' tokens': output_comma_token_indices,
                                                 'Output \'-\' token': output_minus_token_index
                                                 },
                                        value=output_greater_token_indices,
                                        description='L1->L2 Links:'),
          attention_data=widgets.Dropdown(options={'Current prompt': [attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                  'Average on all examples': [avg_attention_layer_sources, avg_attention_layer_targets, avg_attention_layer_values],
                                                  'Average on positive examples': [pos_avg_attention_layer_sources, pos_avg_attention_layer_targets, pos_avg_attention_layer_values],
                                                  'Average on negative examples': [neg_avg_attention_layer_sources, neg_avg_attention_layer_targets, neg_avg_attention_layer_values],
                                                  },
                                                   value=[attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                   description='Attention:'),
          show_kqv=widgets.Checkbox(value=True, description='Show K, Q, V', disabled=False))
def draw(weight_thr,target_flows, attention_data,show_kqv):
   if show_kqv:
      res_stream_data=[res_stream_Q_letters,res_stream_K_letters,res_stream_V_letters]
   else:
      res_stream_data=None

   draw_sankey(prompt_last_token_charlist, output_charlist, l1_top_k_subs, attention_data,res_stream_data, all_token_indices, weight_thr, target_flows)

Executing the cell below will reproduce Figure 4.

In [None]:
#Figure 4

@interact(weight_thr=widgets.FloatSlider(min=0.0, max=1.0, step=0.05, value=0.4,continuous_update=False,description="Threshold:"),
          target_flows=widgets.Dropdown(options={'All tokens': all_token_indices,
                                                 'All output tokens': all_ouptut_token_indices,
                                                 'Output premise tokens': output_premise_token_indices,
                                                 'Output \'>\' tokens': output_greater_token_indices,
                                                 'Output consequent tokens': output_consequent_token_indices,
                                                 'Output \',\' tokens': output_comma_token_indices,
                                                 'Output \'-\' token': output_minus_token_index
                                                 },
                                        value=output_comma_token_indices,
                                        description='L1->L2 Links:'),
          attention_data=widgets.Dropdown(options={'Current prompt': [attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                  'Average on all examples': [avg_attention_layer_sources, avg_attention_layer_targets, avg_attention_layer_values],
                                                  'Average on positive examples': [pos_avg_attention_layer_sources, pos_avg_attention_layer_targets, pos_avg_attention_layer_values],
                                                  'Average on negative examples': [neg_avg_attention_layer_sources, neg_avg_attention_layer_targets, neg_avg_attention_layer_values],
                                                  },
                                                   value=[attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                   description='Attention:'),
          show_kqv=widgets.Checkbox(value=True, description='Show K, Q, V', disabled=False))
def draw(weight_thr,target_flows, attention_data,show_kqv):
   if show_kqv:
      res_stream_data=[res_stream_Q_letters,res_stream_K_letters,res_stream_V_letters]
   else:
      res_stream_data=None

   draw_sankey(prompt_last_token_charlist, output_charlist, l1_top_k_subs, attention_data,res_stream_data, all_token_indices, weight_thr, target_flows)

Executing the cell below will reproduce Figure 5.

In [None]:
#Figure 5

@interact(weight_thr=widgets.FloatSlider(min=0.0, max=1.0, step=0.05, value=0.1,continuous_update=False,description="Threshold:"),
          target_flows=widgets.Dropdown(options={'All tokens': all_token_indices,
                                                 'All output tokens': all_ouptut_token_indices,
                                                 'Output premise tokens': output_premise_token_indices,
                                                 'Output \'>\' tokens': output_greater_token_indices,
                                                 'Output consequent tokens': output_consequent_token_indices,
                                                 'Output \',\' tokens': output_comma_token_indices,
                                                 'Output \'-\' token': output_minus_token_index
                                                 },
                                        value=output_minus_token_index,
                                        description='L1->L2 Links:'),
          attention_data=widgets.Dropdown(options={'Current prompt': [attention_layer_sources, attention_layer_targets, attention_layer_values],
                                                  'Average on all examples': [avg_attention_layer_sources, avg_attention_layer_targets, avg_attention_layer_values],
                                                  'Average on positive examples': [pos_avg_attention_layer_sources, pos_avg_attention_layer_targets, pos_avg_attention_layer_values],
                                                  'Average on negative examples': [neg_avg_attention_layer_sources, neg_avg_attention_layer_targets, neg_avg_attention_layer_values],
                                                  },
                                                   value=[pos_avg_attention_layer_sources, pos_avg_attention_layer_targets, pos_avg_attention_layer_values],
                                                   description='Attention:'),
          show_kqv=widgets.Checkbox(value=False, description='Show K, Q, V', disabled=False))
def draw(weight_thr,target_flows, attention_data,show_kqv):
   if show_kqv:
      res_stream_data=[res_stream_Q_letters,res_stream_K_letters,res_stream_V_letters]
   else:
      res_stream_data=None

   draw_sankey(prompt_last_token_charlist, output_charlist, l1_top_k_subs, attention_data,res_stream_data, all_token_indices, weight_thr, target_flows)