In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from einops import rearrange, repeat
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load pre-trained model and tokenizer
# model_name = "huggyllama/llama-13b"
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Tokenize input sequence
input_sequence = "This film is terrible, best to pass."
input_tokens = tokenizer.tokenize(input_sequence)
print(input_tokens)
input_tokens.insert(0, "<bos>") # beginning of sentence
print(input_tokens)

# Encode input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
print(input_ids)

# Generate hidden states
outputs = model(input_ids, output_hidden_states=True)

['This', 'Ġfilm', 'Ġis', 'Ġterrible', ',', 'Ġbest', 'Ġto', 'Ġpass', '.']
['<bos>', 'This', 'Ġfilm', 'Ġis', 'Ġterrible', ',', 'Ġbest', 'Ġto', 'Ġpass', '.']
tensor([[1212, 2646,  318, 7818,   11, 1266,  284, 1208,   13]])


In [3]:
hidden_states = outputs.hidden_states
print(hidden_states[0].shape)
print(len(hidden_states))
cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
print(cat_hidden_states.shape)

torch.Size([1, 9, 768])
13
torch.Size([12, 9, 768])


In [4]:
def load_all_reporters(num_layers: int, prefix_path: str) -> torch.Tensor:
    reporters_weights = []
    for i in range(num_layers):
        reporter_path = f"{prefix_path}/layer_{i}.pt"
        reporter = torch.load(reporter_path)
        reporters_weights.append(reporter['weight'].cpu())
    stacked = torch.cat(reporters_weights, dim=0)
    return stacked

# path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
path = '/mnt/ssd-1/spar/jon/elk-reporters/gpt2/imdb/agitated-driscoll/reporters'
reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
print(reporter_weights.shape)

torch.Size([12, 768])


In [5]:
# Use einsum to do multiplication 
result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
print(result[0])
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

tensor([ 0.5020, -0.2215, -0.1510, -0.2354, -0.3721, -0.5517, -0.5882, -0.6715,
        -0.7445], grad_fn=<SelectBackward0>)
torch.Size([12, 9])


In [6]:
print(result)

tensor([[ 5.0e-01, -2.2e-01, -1.5e-01, -2.4e-01, -3.7e-01, -5.5e-01, -5.9e-01,
         -6.7e-01, -7.4e-01],
        [ 5.5e+00,  7.0e-01,  3.4e+00,  3.6e+00,  2.7e+00,  4.7e+00,  2.8e+00,
          1.8e+00,  3.2e+00],
        [-3.7e+01, -5.8e+00, -8.0e+00, -9.5e+00, -8.5e+00, -7.3e+00, -6.6e+00,
         -5.1e+00, -9.6e+00],
        [-7.0e+02, -9.8e+00, -7.8e+00, -1.1e+01, -6.9e+00, -5.0e+00, -7.2e+00,
         -4.8e+00, -1.0e+01],
        [ 3.4e+02,  6.8e+00,  4.8e+00,  2.4e+00,  3.3e+00,  3.0e+00,  3.3e+00,
          3.8e+00,  4.8e+00],
        [-2.5e+02, -8.0e+00, -4.3e+00, -4.0e+00, -3.3e+00, -2.4e+00, -2.6e+00,
         -3.9e+00, -3.9e+00],
        [-9.3e+02, -2.4e+01, -1.9e+01, -2.7e+01, -1.8e+01, -1.0e+01, -1.0e+01,
         -9.2e+00, -1.6e+01],
        [-8.2e+02, -2.5e+01, -2.1e+01, -3.2e+01, -2.1e+01, -1.6e+01, -1.4e+01,
         -1.4e+01, -1.8e+01],
        [-8.0e+02, -3.3e+01, -2.9e+01, -4.2e+01, -3.0e+01, -2.4e+01, -2.1e+01,
         -2.1e+01, -2.9e+01],
        [ 4.3e+02, 

In [7]:
print(sigmoid_result)

tensor([[6.2e-01, 4.4e-01, 4.6e-01, 4.4e-01, 4.1e-01, 3.7e-01, 3.6e-01, 3.4e-01,
         3.2e-01],
        [1.0e+00, 6.7e-01, 9.7e-01, 9.7e-01, 9.4e-01, 9.9e-01, 9.4e-01, 8.6e-01,
         9.6e-01],
        [1.4e-16, 2.9e-03, 3.3e-04, 7.4e-05, 2.1e-04, 7.0e-04, 1.3e-03, 6.2e-03,
         6.5e-05],
        [0.0e+00, 5.6e-05, 4.1e-04, 1.3e-05, 1.0e-03, 6.5e-03, 7.1e-04, 8.2e-03,
         4.4e-05],
        [1.0e+00, 1.0e+00, 9.9e-01, 9.2e-01, 9.6e-01, 9.5e-01, 9.6e-01, 9.8e-01,
         9.9e-01],
        [0.0e+00, 3.5e-04, 1.3e-02, 1.7e-02, 3.5e-02, 8.6e-02, 7.0e-02, 2.0e-02,
         2.0e-02],
        [0.0e+00, 5.4e-11, 5.4e-09, 2.0e-12, 1.9e-08, 3.3e-05, 3.0e-05, 1.0e-04,
         1.3e-07],
        [0.0e+00, 9.3e-12, 1.1e-09, 1.7e-14, 5.2e-10, 1.5e-07, 8.4e-07, 1.2e-06,
         1.1e-08],
        [0.0e+00, 4.7e-15, 2.2e-13, 6.3e-19, 9.1e-14, 3.5e-11, 1.1e-09, 5.1e-10,
         2.5e-13],
        [1.0e+00, 1.0e+00, 1.0e+00, 1.0e+00, 1.0e+00, 1.0e+00, 1.0e+00, 1.0e+00,
         1.0e+00],


In [8]:
print(softmax_result)

tensor([[2.4e-01, 1.2e-01, 1.2e-01, 1.1e-01, 1.0e-01, 8.3e-02, 8.0e-02, 7.4e-02,
         6.9e-02],
        [5.0e-01, 4.1e-03, 6.5e-02, 7.6e-02, 3.2e-02, 2.2e-01, 3.4e-02, 1.3e-02,
         5.2e-02],
        [1.2e-14, 2.5e-01, 2.8e-02, 6.2e-03, 1.7e-02, 5.9e-02, 1.1e-01, 5.2e-01,
         5.4e-03],
        [0.0e+00, 3.3e-03, 2.4e-02, 7.4e-04, 6.0e-02, 3.8e-01, 4.2e-02, 4.8e-01,
         2.6e-03],
        [1.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00,
         0.0e+00],
        [0.0e+00, 1.2e-03, 4.9e-02, 6.4e-02, 1.3e-01, 3.4e-01, 2.7e-01, 7.5e-02,
         7.2e-02],
        [0.0e+00, 3.3e-07, 3.3e-05, 1.2e-08, 1.1e-04, 2.0e-01, 1.8e-01, 6.2e-01,
         7.8e-04],
        [0.0e+00, 4.2e-06, 4.8e-04, 8.0e-09, 2.4e-04, 6.7e-02, 3.9e-01, 5.4e-01,
         5.0e-03],
        [0.0e+00, 2.9e-06, 1.3e-04, 3.9e-10, 5.6e-05, 2.2e-02, 6.6e-01, 3.2e-01,
         1.5e-04],
        [1.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00, 0.0e+00,
         0.0e+00],


In [9]:
result.shape

torch.Size([12, 9])

In [36]:
import plotly.graph_objs as go
import plotly.io as pio
import torch
from IPython.display import HTML
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

# Create color scale for heatmap
color_scale = [[0, '#FFFFFF'], [1, '#FF0000']] # white to red

# Convert tensor to numpy array and detach gradients
credences = result.detach().numpy()

padded_text = np.array([[" " * 2 + "{:.2f}".format(value) + " " * 2 for value in row] for row in credences])

# Create plotly heatmap
heatmap = go.Heatmap(
    z=credences, 
    colorscale=color_scale,
    text=padded_text,  # Set the text to be equal to z
    texttemplate="%{text}", 
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot
                   xaxis=dict(tickvals=list(range(len(input_tokens))),
                              ticktext=input_tokens,
                              tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)
# fig = fig.update_traces(text=input_tokens, texttemplate="%{text}", hovertemplate=None)

# Display plotly figure
iplot(fig)

In [39]:
words_corresponding_to_credences = np.array([[" " * 2 + word + " " * 2 for word in input_tokens] for layer_num in range(len(credences))])

heatmap = go.Heatmap(
    z=credences,
    colorscale=color_scale,
    text=words_corresponding_to_credences,  # Set the text to be equal to z
    texttemplate="%{text}",
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot

                     xaxis=dict(tickvals=list(range(len(input_tokens))),
                                ticktext=input_tokens,
                                tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)

# Display plotly figure
iplot(fig)


# Sanity Checks

In [12]:
# Use einsum to do multiplication 

reporter_weights_repeat = repeat(reporter_weights, 'b e -> b c e', c=len(input_tokens) + 1)
result = torch.einsum('bse,bse->bs', cat_hidden_states, reporter_weights_repeat)
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

RuntimeError: einsum(): subscript s has size 11 for operand 1 which does not broadcast with previously seen size 9

In [None]:
print(result)

tensor([[ 6.2e-03,  5.5e-03, -2.2e-02, -9.1e-03, -4.4e-02, -7.0e-03,  6.1e-02,
         -7.1e-04, -9.6e-03,  4.9e-03],
        [-1.2e+00, -5.9e-01, -1.8e-01, -6.5e-01, -3.5e-01, -9.4e-01, -2.4e-01,
         -4.6e-01, -1.8e-01, -9.0e-01],
        [-6.9e-01, -2.9e-01, -1.6e-01, -2.5e-01, -4.1e-01, -2.7e-01, -1.7e-01,
         -7.4e-02,  5.3e-02, -2.7e-01],
        [ 3.3e+01,  8.6e-02, -1.0e-01,  1.9e-02,  3.1e-01,  6.0e-03, -1.1e-01,
         -1.5e-01, -2.3e-01, -1.5e-01],
        [-3.5e+01, -8.6e-03,  7.5e-01,  2.8e-01, -4.6e-01, -1.6e-01,  2.4e-01,
          2.7e-01,  2.4e-01,  1.0e-01],
        [ 2.8e+01,  1.7e-01, -3.2e-01, -1.5e-01,  6.5e-01,  3.0e-01, -3.1e-01,
         -6.8e-02, -2.5e-01, -1.6e-02],
        [ 3.2e+01,  3.4e-01, -3.5e-01,  8.9e-02,  3.6e-01,  2.3e-01, -6.3e-01,
         -4.9e-01, -3.1e-01, -1.8e-01],
        [ 2.5e+01,  3.1e-01, -1.1e+00,  3.1e-02,  2.3e-01,  3.5e-01, -9.6e-01,
         -5.5e-01, -4.7e-01,  2.8e-01],
        [-2.7e+01,  4.0e-01,  1.0e+00,  3.1e-01,