In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from einops import rearrange, repeat
import os
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load pre-trained model and tokenizer
model_name = "huggyllama/llama-13b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Tokenize input sequence
input_sequence = "This film is terrible, best to pass."
input_tokens = tokenizer.tokenize(input_sequence)
print(input_tokens)
input_tokens.insert(0, "<bos>") # beginning of sentence
print(input_tokens)

# Encode input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
print(input_ids)

# Generate hidden states
outputs = model(input_ids, output_hidden_states=True)

Loading checkpoint shards: 100%|██████████| 3/3 [00:26<00:00,  8.89s/it]
Some weights of the model checkpoint at huggyllama/llama-13b were not used when initializing LlamaModel: ['lm_head.weight']
- This IS expected if you are initializing LlamaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']
['<bos>', '▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']
tensor([[    1,   910,  2706,   338, 16403, 29892,  1900,   304,  1209, 29889]])


In [3]:
hidden_states = outputs.hidden_states
print(hidden_states[0].shape)
print(len(hidden_states))
cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
print(cat_hidden_states.shape)

torch.Size([1, 10, 5120])
41
torch.Size([40, 10, 5120])


In [4]:
def load_all_reporters(num_layers: int, prefix_path: str) -> torch.Tensor:
    reporters_weights = []
    for i in range(num_layers):
        reporter_path = f"{prefix_path}/layer_{i}.pt"
        reporter = torch.load(reporter_path)
        reporters_weights.append(reporter.weight.cpu())
    stacked = torch.cat(reporters_weights, dim=0)
    return stacked

path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
print(reporter_weights.shape)

torch.Size([40, 5120])


In [12]:
# Use einsum to do multiplication 
result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
print(result[0])
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

tensor([ 0.0,  0.0, -0.0, -0.0, -0.0, -0.0,  0.1, -0.0, -0.0,  0.0],
       grad_fn=<SelectBackward0>)
torch.Size([40, 10])


In [6]:
print(result)

tensor([[ 6.2e-03,  5.5e-03, -2.2e-02, -9.1e-03, -4.4e-02, -7.0e-03,  6.1e-02,
         -7.1e-04, -9.6e-03,  4.9e-03],
        [-1.2e+00, -5.9e-01, -1.8e-01, -6.5e-01, -3.5e-01, -9.4e-01, -2.4e-01,
         -4.6e-01, -1.8e-01, -9.0e-01],
        [-6.9e-01, -2.9e-01, -1.6e-01, -2.5e-01, -4.1e-01, -2.7e-01, -1.7e-01,
         -7.4e-02,  5.3e-02, -2.7e-01],
        [ 3.3e+01,  8.6e-02, -1.0e-01,  1.9e-02,  3.1e-01,  6.0e-03, -1.1e-01,
         -1.5e-01, -2.3e-01, -1.5e-01],
        [-3.5e+01, -8.6e-03,  7.5e-01,  2.8e-01, -4.6e-01, -1.6e-01,  2.4e-01,
          2.7e-01,  2.4e-01,  1.0e-01],
        [ 2.8e+01,  1.7e-01, -3.2e-01, -1.5e-01,  6.5e-01,  3.0e-01, -3.1e-01,
         -6.8e-02, -2.5e-01, -1.6e-02],
        [ 3.2e+01,  3.4e-01, -3.5e-01,  8.9e-02,  3.6e-01,  2.3e-01, -6.3e-01,
         -4.9e-01, -3.1e-01, -1.8e-01],
        [ 2.5e+01,  3.1e-01, -1.1e+00,  3.1e-02,  2.3e-01,  3.5e-01, -9.6e-01,
         -5.5e-01, -4.7e-01,  2.8e-01],
        [-2.7e+01,  4.0e-01,  1.0e+00,  3.1e-01,

In [7]:
print(sigmoid_result)

tensor([[5.0e-01, 5.0e-01, 4.9e-01, 5.0e-01, 4.9e-01, 5.0e-01, 5.2e-01, 5.0e-01,
         5.0e-01, 5.0e-01],
        [2.2e-01, 3.6e-01, 4.6e-01, 3.4e-01, 4.1e-01, 2.8e-01, 4.4e-01, 3.9e-01,
         4.5e-01, 2.9e-01],
        [3.3e-01, 4.3e-01, 4.6e-01, 4.4e-01, 4.0e-01, 4.3e-01, 4.6e-01, 4.8e-01,
         5.1e-01, 4.3e-01],
        [1.0e+00, 5.2e-01, 4.7e-01, 5.0e-01, 5.8e-01, 5.0e-01, 4.7e-01, 4.6e-01,
         4.4e-01, 4.6e-01],
        [8.3e-16, 5.0e-01, 6.8e-01, 5.7e-01, 3.9e-01, 4.6e-01, 5.6e-01, 5.7e-01,
         5.6e-01, 5.3e-01],
        [1.0e+00, 5.4e-01, 4.2e-01, 4.6e-01, 6.6e-01, 5.8e-01, 4.2e-01, 4.8e-01,
         4.4e-01, 5.0e-01],
        [1.0e+00, 5.8e-01, 4.1e-01, 5.2e-01, 5.9e-01, 5.6e-01, 3.5e-01, 3.8e-01,
         4.2e-01, 4.6e-01],
        [1.0e+00, 5.8e-01, 2.5e-01, 5.1e-01, 5.6e-01, 5.9e-01, 2.8e-01, 3.7e-01,
         3.9e-01, 5.7e-01],
        [3.1e-12, 6.0e-01, 7.4e-01, 5.8e-01, 5.6e-01, 6.1e-01, 8.3e-01, 6.5e-01,
         5.6e-01, 3.7e-01],
        [4.4e-04, 2

In [8]:
print(softmax_result)

tensor([[1.0e-01, 1.0e-01, 9.8e-02, 9.9e-02, 9.6e-02, 9.9e-02, 1.1e-01, 1.0e-01,
         9.9e-02, 1.0e-01],
        [4.8e-02, 9.3e-02, 1.4e-01, 8.8e-02, 1.2e-01, 6.6e-02, 1.3e-01, 1.1e-01,
         1.4e-01, 6.9e-02],
        [6.4e-02, 9.5e-02, 1.1e-01, 9.9e-02, 8.4e-02, 9.7e-02, 1.1e-01, 1.2e-01,
         1.3e-01, 9.7e-02],
        [1.0e+00, 4.2e-15, 3.5e-15, 3.9e-15, 5.2e-15, 3.9e-15, 3.4e-15, 3.3e-15,
         3.0e-15, 3.3e-15],
        [7.6e-17, 9.1e-02, 1.9e-01, 1.2e-01, 5.8e-02, 7.8e-02, 1.2e-01, 1.2e-01,
         1.2e-01, 1.0e-01],
        [1.0e+00, 7.5e-13, 4.6e-13, 5.4e-13, 1.2e-12, 8.6e-13, 4.6e-13, 5.9e-13,
         4.9e-13, 6.2e-13],
        [1.0e+00, 2.3e-14, 1.2e-14, 1.8e-14, 2.4e-14, 2.1e-14, 8.9e-15, 1.0e-14,
         1.2e-14, 1.4e-14],
        [1.0e+00, 1.4e-11, 3.6e-12, 1.1e-11, 1.3e-11, 1.5e-11, 4.1e-12, 6.1e-12,
         6.7e-12, 1.4e-11],
        [1.8e-13, 8.7e-02, 1.7e-01, 7.9e-02, 7.3e-02, 9.3e-02, 2.9e-01, 1.1e-01,
         7.3e-02, 3.5e-02],
        [6.8e-05, 5

In [29]:
result.shape

torch.Size([40, 10])

In [52]:
import plotly.graph_objs as go
import plotly.io as pio
import torch
from IPython.display import HTML
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

# Create color scale for heatmap
color_scale = [[0, '#FFFFFF'], [1, '#FF0000']] # white to red

# Convert tensor to numpy array and detach gradients
credences = result.detach().numpy()

padded_text = np.array([[" " * 2 + "{:.2f}".format(value) + " " * 2 for value in row] for row in credences])

# Create plotly heatmap
heatmap = go.Heatmap(
    z=credences, 
    colorscale=color_scale,
    text=padded_text,  # Set the text to be equal to z
    texttemplate="%{text}", 
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot
                   xaxis=dict(tickvals=list(range(len(input_tokens))),
                              ticktext=input_tokens,
                              tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)
# fig = fig.update_traces(text=input_tokens, texttemplate="%{text}", hovertemplate=None)

# Display plotly figure
iplot(fig)

# Sanity Checks

In [21]:
# Use einsum to do multiplication 

reporter_weights_repeat = repeat(reporter_weights, 'b e -> b c e', c=len(input_tokens) + 1)
result = torch.einsum('bse,bse->bs', cat_hidden_states, reporter_weights_repeat)
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

RuntimeError: einsum(): subscript s has size 11 for operand 1 which does not broadcast with previously seen size 10

In [None]:
print(result)

tensor([[ 6.2e-03,  5.5e-03, -2.2e-02, -9.1e-03, -4.4e-02, -7.0e-03,  6.1e-02,
         -7.1e-04, -9.6e-03,  4.9e-03],
        [-1.2e+00, -5.9e-01, -1.8e-01, -6.5e-01, -3.5e-01, -9.4e-01, -2.4e-01,
         -4.6e-01, -1.8e-01, -9.0e-01],
        [-6.9e-01, -2.9e-01, -1.6e-01, -2.5e-01, -4.1e-01, -2.7e-01, -1.7e-01,
         -7.4e-02,  5.3e-02, -2.7e-01],
        [ 3.3e+01,  8.6e-02, -1.0e-01,  1.9e-02,  3.1e-01,  6.0e-03, -1.1e-01,
         -1.5e-01, -2.3e-01, -1.5e-01],
        [-3.5e+01, -8.6e-03,  7.5e-01,  2.8e-01, -4.6e-01, -1.6e-01,  2.4e-01,
          2.7e-01,  2.4e-01,  1.0e-01],
        [ 2.8e+01,  1.7e-01, -3.2e-01, -1.5e-01,  6.5e-01,  3.0e-01, -3.1e-01,
         -6.8e-02, -2.5e-01, -1.6e-02],
        [ 3.2e+01,  3.4e-01, -3.5e-01,  8.9e-02,  3.6e-01,  2.3e-01, -6.3e-01,
         -4.9e-01, -3.1e-01, -1.8e-01],
        [ 2.5e+01,  3.1e-01, -1.1e+00,  3.1e-02,  2.3e-01,  3.5e-01, -9.6e-01,
         -5.5e-01, -4.7e-01,  2.8e-01],
        [-2.7e+01,  4.0e-01,  1.0e+00,  3.1e-01,