In [65]:
%load_ext autoreload
%autoreload 2

from epsilon_transformers.persistence import S3Persister, HackyPersister
from epsilon_transformers.training.configs.model_configs import RawModelConfig
from epsilon_transformers.process.processes import RRXOR
from epsilon_transformers.analysis.activation_analysis import get_beliefs_for_transformer_inputs
from epsilon_transformers.steering.steer import organize_activations,get_steering_vector,run_model_with_steering,get_inputs_ending_in_belief

import numpy as np
import torch
import plotly.express as px
import pathlib

from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA
from plotly.subplots import make_subplots
from plotly import graph_objects as go



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Get the model checkpoints (.pt files) and the config

In [66]:
# Point the persister at the local RRXOR training run, then enumerate its saved checkpoints.
persister = HackyPersister(
    dir_path=pathlib.Path('./models/vfs4q106-rrxor'),
)
checkpoints = persister.get_model_checkpoints()


# Initialize the RRXOR process
This library has a `Process` class that we use to define a data-generating HMM.

The most important part of the process is the transition matrix, which defines the transitions between states. You can get it via `process.transition_matrix`. It is a numpy array of shape (num_emissions, num_states, num_states); entry `[k, i, j]` holds the transition probability from state i to state j associated with emission k.

The process also has a `state_names_dict`, which is a dictionary mapping state names to state indices. You can get it via `process.state_names_dict`.


In [67]:
from epsilon_transformers.process.processes import TransitionMatrixProcess

# Build the RRXOR generator and show a summary of it.
process = RRXOR()
print(process)
process_matrix = process.transition_matrix

# Derive the mixed-state presentation (MSP) tree to depth 11 and its transition matrix.
mixed_state_tree = process.derive_mixed_state_presentation(depth=11)
MSP_transition_matrix = mixed_state_tree.build_msp_transition_matrix()

# NOTE: `process` is rebound here; from now on it is the MSP-derived process, not RRXOR itself.
process = TransitionMatrixProcess(transition_matrix=MSP_transition_matrix)


rrxor Process
Number of states: 5
Vocabulary length: 2
Transition matrix shape: (2, 5, 5)


# Simplex Analysis

## Ground Truth Simplex

To get the ground-truth simplex structure, we need every path (of any length) through the mixed-state tree, together with the belief state each path leads to. `mixed_state_tree.paths_and_belief_states` provides both as a pair of parallel lists: the first holds the paths, and the second holds the belief state associated with each path.



In [68]:
# Paths through the mixed-state tree, and the belief state reached by each path.
tree_paths, tree_beliefs = mixed_state_tree.paths_and_belief_states

# Round so that numerically-equal beliefs collapse onto the same hashable tuple.
msp_beliefs = [tuple(round(b, 5) for b in belief) for belief in tree_beliefs]
# The MSP states are the unique beliefs in the tree.
unique_msp_beliefs = set(msp_beliefs)
print(f"Number of Unique beliefs: {len(unique_msp_beliefs)} out of {len(msp_beliefs)}")

# Give each unique belief an integer index.
# NOTE(review): the hard-coded belief indices used later (31, 21, 10, ...) depend on this
# set's iteration order; changing how the set is built will silently change their meaning.
msp_belief_index = {belief: idx for idx, belief in enumerate(unique_msp_beliefs)}

indexed_beliefs = list(msp_belief_index)
for i in range(5):
    ith_belief = indexed_beliefs[i]
    print(f"{ith_belief} is indexed as {msp_belief_index[ith_belief]}")


Number of Unique beliefs: 36 out of 1723
(0.5, 0.25, 0.0, 0.0, 0.25) is indexed as 0
(0.33333, 0.0, 0.33333, 0.16667, 0.16667) is indexed as 1
(0.0, 0.0, 1.0, 0.0, 0.0) is indexed as 2
(0.0, 0.66667, 0.0, 0.0, 0.33333) is indexed as 3
(0.0, 0.0, 0.5, 0.5, 0.0) is indexed as 4


The simplex is a 4D structure, representing probability distributions over 5 hidden states, so we project down to 3D using PCA for visualization.



In [69]:
def run_visualization_pca(beliefs, n_components=3):
    """Fit a PCA used to project belief vectors into a low-dimensional space for plotting.

    Args:
        beliefs: array-like of shape [n_beliefs, belief_dim], e.g. a list of belief tuples.
        n_components: number of principal components to keep. Defaults to 3 so the
            4-D simplex over 5 hidden states can be viewed in 3-D.

    Returns:
        The fitted ``PCA`` object; call ``.transform`` on it to project new points.
    """
    pca = PCA(n_components=n_components)
    pca.fit(beliefs)
    return pca

# Fit the visualization PCA on the unique MSP belief states.
unique_belief_list = list(msp_belief_index.keys())
vis_pca = run_visualization_pca(unique_belief_list)
index = [belief_idx for belief_idx in msp_belief_index.values()]



# Find Simplex in Transformer Activations

First, we collect every sequence our HMM can generate at the model's context length (10 tokens) by taking all paths on the MSP of exactly that length. We put them all in one batch so we can get the activations for all of them at once.


In [70]:
# Every MSP path of full context length is a possible transformer input; batch them all.
device = 'cpu'
context_length = 10
full_length_paths = [path for path in tree_paths if len(path) == context_length]
transformer_inputs = torch.tensor(full_length_paths, dtype=torch.int).to(device)

# Peek at the first few sequences in the batch.
print(transformer_inputs[:5])

tensor([[0, 1, 0, 1, 1, 1, 0, 1, 0, 1],
        [0, 1, 0, 1, 1, 1, 0, 0, 1, 1],
        [0, 1, 0, 1, 1, 1, 0, 0, 0, 0],
        [0, 1, 0, 1, 1, 0, 1, 1, 1, 0],
        [0, 1, 0, 1, 1, 0, 1, 1, 1, 1]], dtype=torch.int32)


Then, we get the belief states associated with each transformer input, using the `get_beliefs_for_transformer_inputs` function. The output of this function is a tuple of two tensors, where the first tensor is the belief states associated with each transformer input, and the second tensor is the indices of the beliefs associated with each transformer input. The shape of the belief states tensor is [batch, n_ctx, belief_dim], and the shape of the indices tensor is [batch, n_ctx].


In [71]:
belief_lookup = get_beliefs_for_transformer_inputs(
    transformer_inputs, msp_belief_index, tree_paths, tree_beliefs
)
transformer_input_beliefs, transformer_input_belief_indices = belief_lookup
print(f"Transformer Input Beliefs: {transformer_input_beliefs.shape}, Transformer Input Belief Indices: {transformer_input_belief_indices.shape}")

# One belief index per (sequence, position), flattened to a numpy vector for the plotting code.
transformer_input_belief_indices_flattened = transformer_input_belief_indices.view(-1).cpu().numpy()



Transformer Input Beliefs: torch.Size([436, 10, 5]), Transformer Input Belief Indices: torch.Size([436, 10])


Next we get the transformer's activations for these inputs. We use the `names_filter` argument to keep only the `hook_resid_post` activations, i.e. the residual stream at the output of each block, after that block's attention and MLP outputs have been added in.

In [72]:
def _is_resid_post_hook(hook_name):
    # Cache only the residual-stream-after-block activations.
    return 'hook_resid_post' in hook_name


# CPU by default; use torch.device("cuda" if torch.cuda.is_available() else "cpu") to enable GPU.
device = torch.device("cpu")

# Load the last checkpoint in the list returned by the persister.
model = persister.load_model(ckpt_num=checkpoints[-1], device=device)
logits, activations = model.run_with_cache(transformer_inputs, names_filter=_is_resid_post_hook)
print(activations.keys())

dict_keys(['blocks.0.hook_resid_post', 'blocks.1.hook_resid_post', 'blocks.2.hook_resid_post', 'blocks.3.hook_resid_post'])


Next we define `run_activation_to_beliefs_regression`, which fits a linear regression from the activations to the ground-truth belief states.
For the RRXOR process we concatenate the residual-stream activations from all layers before regressing.
The function returns the fitted regression model and the belief predictions, of shape [batch, n_ctx, belief_dim]; that is, for each context position we get a predicted belief over the 5 hidden states.

In [73]:
# we now have activations [batch, n_ctx, d_model]
# and we have transformer_input_beliefs [batch, n_ctx, belief_dim]
# and we have transformer_input_belief_indices [batch, n_ctx]

# in the end we want to do linear regression between the activations and the transformer_input_beliefs
def run_activation_to_beliefs_regression(activations, ground_truth_beliefs):
    """Fit a linear map from activations to ground-truth belief states.

    Every (sequence, position) pair is treated as one regression sample.

    Args:
        activations: torch tensor or numpy array of shape [batch, n_ctx, d_model].
        ground_truth_beliefs: torch tensor or numpy array of shape [batch, n_ctx, belief_dim].

    Returns:
        (regression, belief_predictions): the fitted ``LinearRegression`` and its in-sample
        predictions reshaped to [batch, n_ctx, belief_dim].

    Raises:
        AssertionError: if the batch or context dimensions of the two inputs disagree.
    """
    # Both inputs must describe the same set of (sequence, position) samples.
    assert activations.shape[0] == ground_truth_beliefs.shape[0], (
        f"batch mismatch: {activations.shape[0]} vs {ground_truth_beliefs.shape[0]}"
    )
    assert activations.shape[1] == ground_truth_beliefs.shape[1], (
        f"n_ctx mismatch: {activations.shape[1]} vs {ground_truth_beliefs.shape[1]}"
    )

    batch_size, n_ctx, d_model = activations.shape
    belief_dim = ground_truth_beliefs.shape[-1]
    # reshape (not view): works for non-contiguous torch tensors and for numpy arrays alike.
    activations_flattened = activations.reshape(-1, d_model)  # [batch * n_ctx, d_model]
    ground_truth_beliefs_flattened = ground_truth_beliefs.reshape(-1, belief_dim)  # [batch * n_ctx, belief_dim]

    regression = LinearRegression()
    regression.fit(activations_flattened, ground_truth_beliefs_flattened)

    # In-sample predictions, restored to the per-position layout.
    belief_predictions = regression.predict(activations_flattened)  # [batch * n_ctx, belief_dim]
    belief_predictions = belief_predictions.reshape(batch_size, n_ctx, belief_dim)

    return regression, belief_predictions



In [74]:
# Concatenate every layer's residual stream along the feature axis, in cache order.
acts = torch.concatenate(list(activations.values()), dim=-1)
regression, belief_predictions = run_activation_to_beliefs_regression(acts, transformer_input_beliefs)
print(f"Shape of belief_predictions: {belief_predictions.shape}")
# Flatten (sequence, position) and project the predicted beliefs with the visualization PCA.
# Use the real belief dimension instead of a hard-coded 5.
belief_dim = belief_predictions.shape[-1]
belief_predictions_pca = vis_pca.transform(belief_predictions.reshape(-1, belief_dim))



Shape of belief_predictions: (436, 10, 5)


## Collect activations

In [75]:
# Group the cached activations per layer and per belief index (later cells index the result
# as per_layer_belief_activations[layer][belief_idx]). all_positions=True presumably keeps
# every context position rather than only the last one — TODO confirm in organize_activations.
per_layer_belief_activations = organize_activations(
    activations,
    transformer_input_belief_indices,
    all_positions=True,
)

## Compute steering vector

In [76]:
# Belief indices (values of msp_belief_index) to steer between.
# NOTE(review): these indices depend on the enumeration order used to build msp_belief_index.
state_1, state_2 = 31, 21
per_layer_steering_vector = get_steering_vector(
    per_layer_belief_activations, state_1, state_2
)


## Select prompts ending in each belief state

In [77]:
# Inputs whose belief index (presumably at the final position) is state_1 / state_2 respectively.
prompts_with_belief_state_1, prompts_with_belief_state_2 = (
    get_inputs_ending_in_belief(transformer_inputs, transformer_input_belief_indices, target_state)
    for target_state in (state_1, state_2)
)

In [78]:
# Unsteered baseline vs. steered run for each group of prompts. The scale argument is +1 for
# the state_1 prompts (named steered_to_2) and -1 for the state_2 prompts (named steered_to_1).
normal_1 = model(prompts_with_belief_state_1)
steered_to_2 = run_model_with_steering(
    model, prompts_with_belief_state_1, per_layer_steering_vector, 1
)

normal_2 = model(prompts_with_belief_state_2)
steered_to_1 = run_model_with_steering(
    model, prompts_with_belief_state_2, per_layer_steering_vector, -1
)

In [79]:
def _final_position_probs(model_logits):
    # Softmax over the vocabulary at the last context position, detached from autograd.
    return model_logits[:, -1, :].softmax(1).detach()


output_state_1 = _final_position_probs(normal_1)
output_state_2 = _final_position_probs(normal_2)
corrupted_output_state_1 = _final_position_probs(steered_to_2)
corrupted_output_state_2 = _final_position_probs(steered_to_1)

In [80]:
# Mean next-token probabilities for each prompt group, keyed in plotting order.
outputs = [output_state_1, output_state_2, corrupted_output_state_1, corrupted_output_state_2]
bar_keys = ["state_1", "state_2", "corrupted_state_1", "corrupted_state_2"]
zero_bars = {}
one_bars = {}
for key, output in zip(bar_keys, outputs):
    # Column k of the softmax output is the probability of emitting token k, so column 0
    # feeds the "output 0" bars and column 1 the "output 1" bars.
    zero_bars[key] = output[:, 0].mean().item()
    one_bars[key] = output[:, 1].mean().item()

In [81]:
# Two side-by-side bar charts: left panel shows `zero_bars` ("output 0"), right panel shows
# `one_bars` ("output 1"), for the unsteered and steered prompt groups.
bar_labels = ["State T", "State F", "State T->State F", "State F-> State T"]
fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter'}, {'type': 'scatter'}]])

fig.add_trace(go.Bar(x=bar_labels, y=list(zero_bars.values()), name='Probability to output 0'),
              row=1, col=1)
fig.add_trace(go.Bar(x=bar_labels, y=list(one_bars.values()), name='Probability to output 1'),
              row=1, col=2)
fig.update_layout(title='Output probabilities',
                  yaxis_title='Probabilities', xaxis_title='Model belief state',
                  width=800, height=400,
                  )
# Probabilities live in [0, 1]; pin both panels to that range so they are comparable.
fig.update_yaxes(range=[0, 1], row=1, col=1)
fig.update_yaxes(range=[0, 1], row=1, col=2)
fig



## Plotting steering vectors

In [82]:
state_1, state_2 = 31, 10
per_layer_steering_vector = get_steering_vector(per_layer_belief_activations, state_1, state_2)
# Join the per-layer steering vectors feature-wise (presumably in the same layer order as `acts`).
flattened_vector = torch.cat(list(per_layer_steering_vector.values()), dim=-1)

# For each layer, stack the activations recorded for the belief, then join layers feature-wise.
activations_state_1 = torch.cat(
    [torch.stack(layer_acts[state_1]) for layer_acts in per_layer_belief_activations.values()],
    dim=-1,
)
activations_state_2 = torch.cat(
    [torch.stack(layer_acts[state_2]) for layer_acts in per_layer_belief_activations.values()],
    dim=-1,
)

# NOTE(review): regression.predict is affine (it adds the intercept). If the steering vector is a
# difference of activations, this yields a point in belief space rather than a displacement —
# confirm that is what the plot intends.
transformed_steering_vector_full = vis_pca.transform(regression.predict(flattened_vector.reshape(1, -1)))
transformed_activations_1 = regression.predict(activations_state_1)
transformed_activations_2 = regression.predict(activations_state_2)


In [83]:
def get_plot_2d_pca(x_dim=0, y_dim=2, axis_limit=.85):
    """Scatter the center of mass of the predicted beliefs for each MSP belief state.

    Reads the notebook globals ``msp_belief_index``,
    ``transformer_input_belief_indices_flattened`` and ``belief_predictions_pca``.

    Args:
        x_dim: 0-based PCA component plotted on the x axis (default 0).
        y_dim: 0-based PCA component plotted on the y axis (default 2, i.e. the third component).
        axis_limit: both axes span [-axis_limit, axis_limit].

    Returns:
        A single-panel plotly Figure; callers may add more traces before showing it.
    """
    colors = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24 + px.colors.qualitative.Plotly

    fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter'}]])

    for b in msp_belief_index.values():
        relevant_indices = np.where(transformer_input_belief_indices_flattened == b)[0]
        relevant_data = belief_predictions_pca[relevant_indices]
        if len(relevant_data) == 0:
            continue  # no input position has this belief
        centers_of_mass = np.mean(relevant_data, axis=0)
        fig.add_trace(go.Scatter(x=[centers_of_mass[x_dim]], y=[centers_of_mass[y_dim]],
                                 mode='markers',
                                 marker=dict(size=4, color=colors[b], opacity=1),
                                 name=f'Belief {b}'),
                      row=1, col=1)

    fig.update_xaxes(range=[-axis_limit, axis_limit], row=1, col=1)
    fig.update_yaxes(range=[-axis_limit, axis_limit], row=1, col=1)
    # Axis titles are 1-based and derived from the components actually plotted.
    fig.update_layout(title='2D PCA Projection of Beliefs',
                      xaxis_title=f'PCA Dimension {x_dim + 1}', yaxis_title=f'PCA Dimension {y_dim + 1}',
                      width=600, height=400,
                      annotations=[
                          dict(text="Residual Stream", x=0.8, y=1.05, showarrow=False, xref="paper", yref="paper")
                      ])
    return fig

In [84]:
def get_plot_2d_pca_vector(fig, vector, start_state=None, end_state=None, colors=None):
    """Overlay belief centers of mass and a steering "vector" on a 2-D PCA figure.

    Each MSP belief's center of mass (from the global ``belief_predictions_pca``) is drawn as a
    small marker. The start and end beliefs are drawn as large labelled markers. Finally a line is
    drawn from the start belief's center of mass to ``vector``. PCA components 0 and 2 are used
    as the x and y coordinates.

    Args:
        fig: plotly figure with a (1, 1) subplot, e.g. from ``make_subplots``.
        vector: PCA-projected point of shape [1, n_components]; components 0 and 2 are the
            line's end point.
        start_state: belief index to highlight as the start; defaults to the global ``state_1``.
        end_state: belief index to highlight as the end; defaults to the global ``state_2``.
        colors: colors indexed by belief index; defaults to the notebook's standard palette.

    Returns:
        ``fig`` with the new traces added.

    Raises:
        ValueError: if no input position has ``start_state``, so the vector has no anchor.
    """
    if start_state is None:
        start_state = state_1
    if end_state is None:
        end_state = state_2
    if colors is None:
        colors = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24 + px.colors.qualitative.Plotly

    center_start = None
    for b in msp_belief_index.values():
        relevant_indices = np.where(transformer_input_belief_indices_flattened == b)[0]
        relevant_data = belief_predictions_pca[relevant_indices]
        if len(relevant_data) == 0:
            continue
        centers_of_mass = np.mean(relevant_data, axis=0)
        x, y = centers_of_mass[0], centers_of_mass[2]

        # if/elif/else so each belief is drawn exactly once.
        if b == start_state:
            fig.add_trace(go.Scatter(x=[x], y=[y],
                                     mode='markers',
                                     marker=dict(size=10, color=colors[b], opacity=1),
                                     name='Starting Belief'),
                          row=1, col=1)
            center_start = (x, y)
        elif b == end_state:
            fig.add_trace(go.Scatter(x=[x], y=[y],
                                     mode='markers',
                                     marker=dict(size=10, color=colors[b], opacity=1),
                                     name='Ending Belief'),
                          row=1, col=1)
        else:
            # Background beliefs: keep them out of the legend to avoid unnamed "trace N" entries.
            fig.add_trace(go.Scatter(x=[x], y=[y],
                                     mode='markers',
                                     marker=dict(size=2, color=colors[b], opacity=1),
                                     showlegend=False),
                          row=1, col=1)

    if center_start is None:
        raise ValueError(f"belief {start_state} has no predicted positions to anchor the vector")
    fig.add_trace(go.Scatter(x=[center_start[0], vector[0][0]], y=[center_start[1], vector[0][2]],
                             name="Vector"))
    return fig

In [85]:
# Palette indexed by belief index (also read as a global by the trajectory plotting cell).
colors = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24 + px.colors.qualitative.Plotly

# Single-panel figure onto which several steering vectors are overlaid.
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'scatter'}]])

# (start belief, end belief) pairs to draw. The loop rebinds the globals state_1/state_2,
# which get_plot_2d_pca_vector reads to decide which beliefs to highlight.
steering_pairs = [(31, 10), (31, 21), (10, 2)]
for state_1, state_2 in steering_pairs:
    per_layer_steering_vector = get_steering_vector(per_layer_belief_activations, state_1, state_2)
    flattened_vector = torch.cat(list(per_layer_steering_vector.values()), dim=-1)
    transformed_steering_vector_full = vis_pca.transform(regression.predict(flattened_vector.reshape(1, -1)))
    fig = get_plot_2d_pca_vector(fig, transformed_steering_vector_full)

# Fix both axes to [-0.85, 0.85].
fig.update_xaxes(range=[-.85, .85], row=1, col=1)
fig.update_yaxes(range=[-.85, .85], row=1, col=1)
# The points are plotted using PCA components 0 and 2, i.e. dimensions 1 and 3.
fig.update_layout(title='2D PCA Projection of Beliefs',
                  xaxis_title='PCA Dimension 1', yaxis_title='PCA Dimension 3',
                  width=600, height=400,
                  annotations=[
                      dict(text="Residual Stream", x=0.8, y=1.05, showarrow=False, xref="paper", yref="paper")
                  ])



## Plot trajectory

In [86]:
# Take sequence #1 from every layer's cache, stack over layers, then swap the first two axes so
# each row holds one token position's activations across all layers.
# (Assumes each cached activation is [batch, n_ctx, d_model] — TODO confirm.)
sentence_activation = torch.stack(
    [layer_acts[1] for layer_acts in activations.values()]
).permute(1, 0, 2)

In [87]:
plot = get_plot_2d_pca()
token_trajectory = []
belief_colors = []
sample_belief_indices = transformer_input_belief_indices[1]  # true beliefs of sequence #1

for position, per_layer_acts in enumerate(sentence_activation):
    # Flatten layers into one feature vector, then map through regression and PCA.
    projected = vis_pca.transform(regression.predict(per_layer_acts.reshape(1, -1)))[0]
    true_belief = sample_belief_indices[position].item()
    belief_colors.append(colors[true_belief])
    token_trajectory.append(projected)
    plot.add_trace(go.Scatter(x=[projected[0]], y=[projected[2]],
                              mode='lines+markers',
                              marker=dict(size=10, color=colors[true_belief], opacity=1),
                              name=f'Token {position}'),
                   row=1, col=1)

# After transposing, row k holds PCA component k for every token position.
token_trajectory = np.transpose(token_trajectory)
plot.add_trace(go.Scatter(x=token_trajectory[0], y=token_trajectory[2],
                          mode='lines',
                          name=f'Trajectory',
                          line=dict(color='firebrick', width=1, dash='dash')),
               row=1, col=1)
plot.show()

## Steering at an intermediate position

In [88]:

# Steer at an intermediate position: truncate every input to its first `cut_at` tokens and
# steer between two belief states reachable at that depth.
state_1 = 24
state_2 = 27
cut_at = 2
cut_inputs = transformer_inputs[:, :cut_at]
cut_belief_indices = transformer_input_belief_indices[:, :cut_at]

per_layer_steering_vector = get_steering_vector(per_layer_belief_activations, state_1, state_2)
# All truncated prompts ending in a given belief are the same, so keep only the first one.
prompts_with_belief_state_1 = get_inputs_ending_in_belief(cut_inputs, cut_belief_indices, state_1)[0]
prompts_with_belief_state_2 = get_inputs_ending_in_belief(cut_inputs, cut_belief_indices, state_2)[0]

# Unsteered baselines vs. steered runs (+1 on state_1 prompts, -1 on state_2 prompts).
normal_1 = model(prompts_with_belief_state_1)
steered_to_2 = run_model_with_steering(model, prompts_with_belief_state_1, per_layer_steering_vector, 1)
normal_2 = model(prompts_with_belief_state_2)
steered_to_1 = run_model_with_steering(model, prompts_with_belief_state_2, per_layer_steering_vector, -1)

# Next-token distribution at the final position, detached from autograd.
output_state_1 = normal_1[:, -1, :].softmax(1).detach()
output_state_2 = normal_2[:, -1, :].softmax(1).detach()
corrupted_output_state_1 = steered_to_2[:, -1, :].softmax(1).detach()
corrupted_output_state_2 = steered_to_1[:, -1, :].softmax(1).detach()

outputs = [output_state_1, output_state_2, corrupted_output_state_1, corrupted_output_state_2]
bar_keys = ["state_1", "state_2", "corrupted_state_1", "corrupted_state_2"]
zero_bars = {}
one_bars = {}
for key, output in zip(bar_keys, outputs):
    # Column k of the softmax output is the probability of emitting token k.
    zero_bars[key] = output[:, 0].mean().item()
    one_bars[key] = output[:, 1].mean().item()

# Left panel: mean P(token 0); right panel: mean P(token 1).
bar_labels = ["State T", "State F", "State T->State F", "State F-> State T"]
fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'scatter'}, {'type': 'scatter'}]])
fig.add_trace(go.Bar(x=bar_labels, y=list(zero_bars.values()), name='Probability to output 0'),
              row=1, col=1)
fig.add_trace(go.Bar(x=bar_labels, y=list(one_bars.values()), name='Probability to output 1'),
              row=1, col=2)
fig.update_layout(title='Output probabilities',
                  yaxis_title='Probabilities', xaxis_title='Model belief state',
                  width=800, height=400,
                  )
# Probabilities live in [0, 1]; pin both panels to that range so they are comparable.
fig.update_yaxes(range=[0, 1], row=1, col=1)
fig.update_yaxes(range=[0, 1], row=1, col=2)
fig


