In [32]:
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from einops import rearrange, repeat
import os
import numpy as np

# todo: 
# - modularize
# - add slider for diff sentences


In [21]:
# Load pre-trained model and tokenizer
# model_name = "huggyllama/llama-13b"
model_name = "huggyllama/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Tokenize input sequence
input_sequence = "The horse raced past the barn fell."
input_tokens = tokenizer.tokenize(input_sequence)
print(input_tokens)
input_tokens.insert(0, "<bos>") # beginning of sentence
print(input_tokens)

# Encode input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
print(input_ids)

# Generate hidden states
outputs = model(input_ids, output_hidden_states=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.81s/it]
Some weights of the model checkpoint at huggyllama/llama-7b were not used when initializing LlamaModel: ['lm_head.weight']
- This IS expected if you are initializing LlamaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['▁The', '▁horse', '▁rac', 'ed', '▁past', '▁the', '▁bar', 'n', '▁fell', '.']
['<bos>', '▁The', '▁horse', '▁rac', 'ed', '▁past', '▁the', '▁bar', 'n', '▁fell', '.']
tensor([[    1,   450, 10435, 11021,   287,  4940,   278,  2594, 29876,  8379,
         29889]])


In [22]:
hidden_states = outputs.hidden_states
print(hidden_states[0].shape)
print(len(hidden_states))
cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
print(cat_hidden_states.shape)

torch.Size([1, 11, 4096])
33
torch.Size([32, 11, 4096])


In [23]:
def load_all_reporters(num_layers: int, prefix_path: str) -> torch.Tensor:
    reporters_weights = []
    for i in range(num_layers):
        reporter_path = f"{prefix_path}/layer_{i}.pt"
        reporter = torch.load(reporter_path).weight
        reporters_weights.append(reporter.cpu())
    stacked = torch.cat(reporters_weights, dim=0)
    return stacked

# path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
path = '/home/jon/elk-reporters/huggyllama/llama-7b/imdb/funny-robinson/reporters'

reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
print(reporter_weights.shape)


torch.Size([32, 4096])


In [24]:
# Use einsum to do multiplication 
result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
print(result[0])
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

tensor([ 0.0,  0.0,  0.0,  0.0,  0.0, -0.0,  0.1,  0.0, -0.0, -0.0, -0.0],
       grad_fn=<SelectBackward0>)
torch.Size([32, 11])


In [25]:
print(result)

tensor([[ 8.8e-03,  4.6e-02,  4.8e-03,  3.4e-03,  9.1e-03, -4.9e-03,  1.3e-01,
          2.1e-02, -4.1e-03, -1.3e-02, -3.2e-02],
        [-1.5e-01, -2.1e-01, -1.6e-01, -7.6e-02, -5.6e-02,  2.4e-02, -1.5e-01,
         -3.8e-02, -1.5e-02, -1.5e-03, -6.8e-02],
        [ 1.9e-01,  1.0e-01,  4.3e-02, -2.1e-02,  7.5e-02, -2.6e-01,  5.1e-02,
         -9.0e-02, -1.3e-02, -3.4e-02, -5.0e-03],
        [ 2.3e+01, -2.5e-01, -1.3e-01,  1.5e-01,  6.0e-02, -2.3e-01, -3.6e-01,
         -2.0e-01, -1.8e-01, -2.3e-01, -2.0e-01],
        [ 5.9e+00, -2.1e-01,  1.3e-01,  2.2e-01,  1.8e-01, -9.9e-02, -1.8e-01,
         -5.3e-02, -2.1e-01, -1.3e-01, -3.5e-02],
        [-1.7e+01, -1.3e-01, -4.4e-01, -2.6e-01, -3.6e-01, -3.5e-02,  3.5e-03,
          8.4e-02, -8.5e-02,  3.0e-01, -1.1e-01],
        [ 1.2e+01,  2.1e-01,  5.3e-01,  3.5e-01,  4.5e-01,  6.3e-02, -1.5e-01,
         -9.4e-02, -1.5e-01, -3.5e-01,  3.1e-02],
        [ 2.1e+00,  2.5e-01,  4.1e-01,  4.2e-01,  6.1e-01, -2.3e-01, -1.5e-01,
         -1.6e-01,

In [26]:
print(sigmoid_result)

tensor([[5.0e-01, 5.1e-01, 5.0e-01, 5.0e-01, 5.0e-01, 5.0e-01, 5.3e-01, 5.1e-01,
         5.0e-01, 5.0e-01, 4.9e-01],
        [4.6e-01, 4.5e-01, 4.6e-01, 4.8e-01, 4.9e-01, 5.1e-01, 4.6e-01, 4.9e-01,
         5.0e-01, 5.0e-01, 4.8e-01],
        [5.5e-01, 5.3e-01, 5.1e-01, 4.9e-01, 5.2e-01, 4.4e-01, 5.1e-01, 4.8e-01,
         5.0e-01, 4.9e-01, 5.0e-01],
        [1.0e+00, 4.4e-01, 4.7e-01, 5.4e-01, 5.1e-01, 4.4e-01, 4.1e-01, 4.5e-01,
         4.6e-01, 4.4e-01, 4.5e-01],
        [1.0e+00, 4.5e-01, 5.3e-01, 5.6e-01, 5.5e-01, 4.8e-01, 4.6e-01, 4.9e-01,
         4.5e-01, 4.7e-01, 4.9e-01],
        [5.4e-08, 4.7e-01, 3.9e-01, 4.4e-01, 4.1e-01, 4.9e-01, 5.0e-01, 5.2e-01,
         4.8e-01, 5.7e-01, 4.7e-01],
        [1.0e+00, 5.5e-01, 6.3e-01, 5.9e-01, 6.1e-01, 5.2e-01, 4.6e-01, 4.8e-01,
         4.6e-01, 4.1e-01, 5.1e-01],
        [8.9e-01, 5.6e-01, 6.0e-01, 6.0e-01, 6.5e-01, 4.4e-01, 4.6e-01, 4.6e-01,
         4.3e-01, 5.2e-01, 6.7e-01],
        [1.0e-01, 5.0e-01, 4.4e-01, 4.6e-01, 4.2e-01, 5.

In [27]:
print(softmax_result)

tensor([[9.0e-02, 9.4e-02, 9.0e-02, 9.0e-02, 9.0e-02, 8.9e-02, 1.0e-01, 9.1e-02,
         8.9e-02, 8.8e-02, 8.7e-02],
        [8.4e-02, 7.9e-02, 8.4e-02, 9.1e-02, 9.3e-02, 1.0e-01, 8.5e-02, 9.5e-02,
         9.7e-02, 9.8e-02, 9.2e-02],
        [1.1e-01, 1.0e-01, 9.4e-02, 8.8e-02, 9.7e-02, 6.9e-02, 9.5e-02, 8.2e-02,
         8.9e-02, 8.7e-02, 9.0e-02],
        [1.0e+00, 5.0e-11, 5.6e-11, 7.4e-11, 6.8e-11, 5.1e-11, 4.5e-11, 5.2e-11,
         5.4e-11, 5.1e-11, 5.3e-11],
        [9.7e-01, 2.3e-03, 3.2e-03, 3.5e-03, 3.3e-03, 2.5e-03, 2.3e-03, 2.6e-03,
         2.2e-03, 2.5e-03, 2.7e-03],
        [5.9e-09, 9.5e-02, 7.0e-02, 8.4e-02, 7.6e-02, 1.0e-01, 1.1e-01, 1.2e-01,
         1.0e-01, 1.5e-01, 9.7e-02],
        [1.0e+00, 6.4e-06, 8.9e-06, 7.4e-06, 8.2e-06, 5.6e-06, 4.5e-06, 4.8e-06,
         4.5e-06, 3.7e-06, 5.4e-06],
        [3.9e-01, 6.3e-02, 7.4e-02, 7.5e-02, 9.0e-02, 3.9e-02, 4.2e-02, 4.2e-02,
         3.7e-02, 5.3e-02, 9.8e-02],
        [1.3e-02, 1.1e-01, 8.7e-02, 9.6e-02, 8.2e-02, 1.

In [28]:
result.shape

torch.Size([32, 11])

In [29]:
import plotly.graph_objs as go
import plotly.io as pio
import torch
from IPython.display import HTML
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

# Create color scale for heatmap
color_scale = [[0, '#FFFFFF'], [1, '#FF0000']] # white to red

# Convert tensor to numpy array and detach gradients
credences = result.detach().numpy()

padded_text = np.array([[" " * 2 + "{:.2f}".format(value) + " " * 2 for value in row] for row in credences])

# Create plotly heatmap
heatmap = go.Heatmap(
    z=credences, 
    colorscale=color_scale,
    text=padded_text,  # Set the text to be equal to z
    texttemplate="%{text}", 
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot
                   xaxis=dict(tickvals=list(range(len(input_tokens))),
                              ticktext=input_tokens,
                              tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)
# fig = fig.update_traces(text=input_tokens, texttemplate="%{text}", hovertemplate=None)

# Display plotly figure
iplot(fig)

In [30]:
words_corresponding_to_credences = np.array([[" " * 2 + word + " " * 2 for word in input_tokens] for layer_num in range(len(credences))])

heatmap = go.Heatmap(
    z=credences,
    colorscale=color_scale,
    text=words_corresponding_to_credences,  # Set the text to be equal to z
    texttemplate="%{text}",
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot

                     xaxis=dict(tickvals=list(range(len(input_tokens))),
                                ticktext=input_tokens,
                                tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)

# Display plotly figure
iplot(fig)


# Sanity Checks

In [31]:
# Use einsum to do multiplication 

reporter_weights_repeat = repeat(reporter_weights, 'b e -> b c e', c=len(input_tokens) + 1)
result = torch.einsum('bse,bse->bs', cat_hidden_states, reporter_weights_repeat)
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

RuntimeError: einsum(): subscript s has size 12 for operand 1 which does not broadcast with previously seen size 11

In [None]:
print(result)

tensor([[ 6.2e-03,  5.5e-03, -2.2e-02, -9.1e-03, -4.4e-02, -7.0e-03,  6.1e-02,
         -7.1e-04, -9.6e-03,  4.9e-03],
        [-1.2e+00, -5.9e-01, -1.8e-01, -6.5e-01, -3.5e-01, -9.4e-01, -2.4e-01,
         -4.6e-01, -1.8e-01, -9.0e-01],
        [-6.9e-01, -2.9e-01, -1.6e-01, -2.5e-01, -4.1e-01, -2.7e-01, -1.7e-01,
         -7.4e-02,  5.3e-02, -2.7e-01],
        [ 3.3e+01,  8.6e-02, -1.0e-01,  1.9e-02,  3.1e-01,  6.0e-03, -1.1e-01,
         -1.5e-01, -2.3e-01, -1.5e-01],
        [-3.5e+01, -8.6e-03,  7.5e-01,  2.8e-01, -4.6e-01, -1.6e-01,  2.4e-01,
          2.7e-01,  2.4e-01,  1.0e-01],
        [ 2.8e+01,  1.7e-01, -3.2e-01, -1.5e-01,  6.5e-01,  3.0e-01, -3.1e-01,
         -6.8e-02, -2.5e-01, -1.6e-02],
        [ 3.2e+01,  3.4e-01, -3.5e-01,  8.9e-02,  3.6e-01,  2.3e-01, -6.3e-01,
         -4.9e-01, -3.1e-01, -1.8e-01],
        [ 2.5e+01,  3.1e-01, -1.1e+00,  3.1e-02,  2.3e-01,  3.5e-01, -9.6e-01,
         -5.5e-01, -4.7e-01,  2.8e-01],
        [-2.7e+01,  4.0e-01,  1.0e+00,  3.1e-01,