In [26]:
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from einops import rearrange, repeat
import os
import numpy as np
from dataclasses import dataclass
import re
import csv
import json


In [3]:
model_name = "huggyllama/llama-13b"
# model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 3/3 [00:19<00:00,  6.49s/it]
Some weights of the model checkpoint at huggyllama/llama-13b were not used when initializing LlamaModel: ['lm_head.weight']
- This IS expected if you are initializing LlamaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
def load_all_reporters(num_layers: int, prefix_path: str) -> torch.Tensor:
    reporters_weights = []
    for i in range(num_layers):
        reporter_path = f"{prefix_path}/layer_{i}.pt"
        reporter = torch.load(reporter_path)
        reporters_weights.append(reporter['weight'].cpu())
    stacked = torch.cat(reporters_weights, dim=0)
    return stacked



In [41]:
@dataclass
class Reporter:
    path: str
    weights: torch.Tensor
    desc: str

    @staticmethod
    def from_path(path):
        weights = load_all_reporters(model.config.num_hidden_layers, path)
        desc = path.split('/')[-2]
        return Reporter(path, weights, desc)

reporter_paths = [
    ('/home/jon/elk-reporters/huggyllama/llama-13b/azhx/counterfact-easy/stoic-jang/reporters', 'lm negation'), # lm negation
    ('/home/jon/elk-reporters/huggyllama/llama-13b/azhx/counterfact-filtered-gptj6b/trusting-cori/reporters', 'dumb nots'), # dumb nots
    ('/home/jon/elk-reporters/huggyllama/llama-13b/azhx/counterfact-simple/jovial-lederberg/reporters', 'counterfact pairs') # counterfact pairs
]

In [6]:
# Tokenize input sequence
input_sequence = "This film is terrible, best to pass."
input_tokens = tokenizer.tokenize(input_sequence)
print(input_tokens)
input_tokens.insert(0, "<bos>") # beginning of sentence
print(input_tokens)

# Encode input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt")
print(input_ids)

decoded = tokenizer.convert_ids_to_tokens(input_ids[0])
print(decoded)

# Generate hidden states
outputs = model(input_ids, output_hidden_states=True)

['▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']
['<bos>', '▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']
tensor([[    1,   910,  2706,   338, 16403, 29892,  1900,   304,  1209, 29889]])
['<s>', '▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']


In [7]:
hidden_states = outputs.hidden_states
print(hidden_states[0].shape)
print(len(hidden_states))
cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
print(cat_hidden_states.shape)

torch.Size([1, 10, 5120])
41
torch.Size([40, 10, 5120])


In [8]:
# path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
# path = '/mnt/ssd-1/spar/jon/elk-reporters/gpt2/imdb/agitated-driscoll/reporters'
# reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
# print(reporter_weights.shape)
for path in reporter_paths:
    reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
    print(reporter_weights.shape)

torch.Size([40, 5120])


In [9]:
# Use einsum to do multiplication 
result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
print(result[0])
print(result.shape)

sigmoid_result = torch.sigmoid(result)
softmax_result = torch.softmax(result, dim = -1)
torch.set_printoptions(precision=1)

tensor([-0.0051,  0.0150,  0.0182,  0.0131,  0.0048, -0.0068, -0.0019, -0.0270,
         0.0224, -0.0266], grad_fn=<SelectBackward0>)
torch.Size([40, 10])


In [10]:
print(result.shape)

torch.Size([40, 10])


In [11]:
print(sigmoid_result.shape)

torch.Size([40, 10])


In [12]:
print(softmax_result.shape)

torch.Size([40, 10])


In [13]:
import plotly.graph_objs as go
import plotly.io as pio
import torch
from IPython.display import HTML
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

# Create color scale for heatmap
color_scale = [[0, '#FFFFFF'], [1, '#FF0000']] # white to red

# Convert tensor to numpy array and detach gradients
credences = result.detach().numpy()

padded_text = np.array([[" " * 2 + "{:.2f}".format(value) + " " * 2 for value in row] for row in credences])

# Create plotly heatmap
heatmap = go.Heatmap(
    z=credences, 
    colorscale=color_scale,
    text=padded_text,  # Set the text to be equal to z
    texttemplate="%{text}", 
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot
                   xaxis=dict(tickvals=list(range(len(input_tokens))),
                              ticktext=input_tokens,
                              tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)
# fig = fig.update_traces(text=input_tokens, texttemplate="%{text}", hovertemplate=None)

# Display plotly figure
# iplot(fig)

# print(padded_text)
input_tokens

['<bos>',
 '▁This',
 '▁film',
 '▁is',
 '▁terrible',
 ',',
 '▁best',
 '▁to',
 '▁pass',
 '.']

In [14]:
words_corresponding_to_credences = np.array([[" " * 2 + word + " " * 2 for word in input_tokens] for layer_num in range(len(credences))])

heatmap = go.Heatmap(
    z=credences,
    colorscale=color_scale,
    text=words_corresponding_to_credences,  # Set the text to be equal to z
    texttemplate="%{text}",
    textfont=dict(color='black', size=12),  # Set the text color and size
)

# Create plot layout
layout = go.Layout(title='Credences for Input Tokens',
                width=800,  # Set the width of the plot
                height=800,  # Set the height of the plot

                     xaxis=dict(tickvals=list(range(len(input_tokens))),
                                ticktext=input_tokens,
                                tickangle=45))

# Create plotly figure
fig = go.Figure(data=[heatmap], layout=layout)

# Display plotly figure
print(credences[-1].shape)
iplot(fig)


(10,)


In [15]:
import matplotlib.pyplot as plt

def colorize(words, credences):
    
    max_color = float(max(credences))
    min_color = float(min(credences))
    normalized_credences = (credences - min_color) / (max_color - min_color)
    
    cmap = plt.get_cmap('Blues')
    
    template = '<div style="background-color: white; display: flex; white-space: pre"><span class="barcode"; style="color: black; background-color: {}" title="{}">{}</span><div>'
    colored_string = ''
    
    for word, credence, norm_credence in zip(words, credences, normalized_credences):
        word = word.replace('▁', ' ')
        color = cmap(norm_credence)[:3]
        max_col_num = int(255) # half intensity
        # color = 'rgb(' + str(int(color[0]*max_col_num)) + ',' + str(int(color[1]*max_col_num)) + ',' + str(int(color[2]*max_col_num)) + ')'
        color = f'rgba({color[0]*255}, {color[1]*255}, {color[2]*255}, {0.5})' # TODO figure out how to give it white background
        # print(word, credence, norm_credence)
        colored_string += template.format(color, str(credence), word)
    
    return colored_string

# example usage:
words = input_tokens
colors = credences[-1] # associated values

# Colorizing the text
print(words, colors)
colored_string = colorize(words, colors)

# To display the colored string in Jupyter Notebook
display(HTML(colored_string))

['<bos>', '▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.'] [14.801093 21.4402   25.842049 25.527487  6.562271 45.627815 26.71216
 35.150642 -9.107204 41.0222  ]


In [67]:
# path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
path = '/mnt/ssd-1/spar/jon/elk-reporters/gpt2/imdb/agitated-driscoll/reporters'

def extract_sentences(text):
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    return [sentence for sentence in sentences if sentence != '']

def render_sentence(reporter, sentence):
    json_file = 'credences.json'
    try:
        credence_records = json.load(open(json_file, 'r'))
    except:
        credence_records = []

    for record in credence_records:
        if record['sentence'] == sentence and record['reporter_path'] == reporter.path:
            print('found')
            return record
    print('not found')

    input_ids = tokenizer.encode(sentence, return_tensors="pt", add_special_tokens=False)
    # get the indexes of the full stops
    words = tokenizer.convert_ids_to_tokens(input_ids[0])
    punct_indexes = [i for i, x in enumerate(words) if x == "."]
    outputs = model(input_ids, output_hidden_states=True)
    reporter_weights = reporter.weights
    hidden_states = outputs.hidden_states
    # print(hidden_states[0].shape)
    # print(len(hidden_states))
    cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
    # Use einsum to do multiplication 
    result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
    # print(result[0])
    # print(result.shape)

    # sigmoid_result = torch.sigmoid(result)
    # softmax_result = torch.softmax(result, dim = -1)
    torch.set_printoptions(precision=1)
    credences = result.detach().numpy()[-1] # last layer

    # for every punct_index, get punct_index - 1.
    # get the credence for that index, and copy it until the previous punct_index + 1.
    sentencewise_credence = credences.copy()
    print(punct_indexes)
    windows = [(-1, punct_indexes[0])] + [(punct_indexes[i-1], punct_indexes[i]) for i in range(1, len(punct_indexes))]
    for a, b in windows:
        last_token_credence = sentencewise_credence[b-1]
        sentencewise_credence[a+1:b+1] = last_token_credence

    # print(words, credences[-1])
    colored_string = colorize(words, credences)
    rendered_sentencewise = colorize(words, sentencewise_credence)

    # write output to json: reporter_desc, sentence, credences
    credence_record = {
        'reporter_path': reporter.path,
        'sentence': sentence,
        'credences': credences.tolist(),
        'rendered': colored_string,
        'rendered_sentencewise': rendered_sentencewise
    }
    # write to json
    credence_records.append(credence_record)
    
    with open(json_file, 'w') as f:
        json.dump(credence_records, f, indent=4)

    return credence_record
    
def viz(reporter, sentence, sentence_level=False):
    record = render_sentence(reporter, sentence)
    display(HTML(record['rendered']))
    display(HTML(record['rendered_sentencewise']))

reporter = Reporter.from_path(reporter_paths[0][0])

viz(reporter, "I love this movie so much.")
viz(reporter, "Wow this movie is terrible.")



not found
[6]



invalid value encountered in divide



not found
[6]


In [52]:
reporters = [Reporter.from_path(path) for path, desc in reporter_paths]

sentences = extract_sentences(mitochrondria_text)

for reporter in reporters[2:3]:
    for sentence in sentences:
        # print(sentence)
        # print(tokenizer.encode(sentence))
        # input_ids=tokenizer.encode(sentence)
        # print(tokenizer.convert_ids_to_tokens(input_ids[0]))
        # viz(reporter, sentence)
        


tensor([[  319,  1380,  2878,   898, 29878,   291, 20374, 30713,   655, 30312,
         29873, 30184, 30176, 29895, 31036,   299,   374, 30184, 29876, 29914,
         29936,   715,  1380,  2878,   898,  2849, 29897,   338,   385,  2894,
          1808,  1476,   297,   278,  9101,   310,  1556,   321,  2679,   653,
          4769, 29892,  1316,   408, 15006, 29892, 18577,   322, 26933, 29875,
         29889]])
['▁A', '▁mit', 'och', 'ond', 'r', 'ion', '▁(/', 'ˌ', 'ma', 'ɪ', 't', 'ə', 'ˈ', 'k', 'ɒ', 'nd', 'ri', 'ə', 'n', '/', ';', '▁pl', '▁mit', 'och', 'ond', 'ria', ')', '▁is', '▁an', '▁organ', 'elle', '▁found', '▁in', '▁the', '▁cells', '▁of', '▁most', '▁e', 'uk', 'ary', 'otes', ',', '▁such', '▁as', '▁animals', ',', '▁plants', '▁and', '▁fung', 'i', '.']
[50]
tensor([[  341,  2049,   305,   898,  2849,   505,   263,  3765,  3813, 10800,
          3829,   322,   671, 14911,   711,   293,  4613, 12232,   304,  5706,
           594,   264,   359,   457,  3367,   561, 25715,   403,   313,  129

In [70]:
for reporter in reporters[2:3]:
    # viz(reporter, sentence, sentence_level=True)
    sentence = ' '.join(sentences)
    print(sentence)
    
    viz(reporter, sentence)

A mitochondrion (/ˌmaɪtəˈkɒndriən/; pl mitochondria) is an organelle found in the cells of most eukaryotes, such as animals, plants and fungi. Mitochondria have a double membrane structure and use aerobic respiration to generate adenosine triphosphate (ATP), which is used throughout the cell as a source of chemical energy. They were discovered by Albert von Kölliker in 1857 in the voluntary muscles of insects. The term mitochondrion was coined by Carl Benda in 1898. The mitochondrion is popularly nicknamed the "powerhouse of the cell", a phrase coined by Philip Siekevitz in a 1957 article of the same name.
found


In [79]:
from datasets import load_dataset

dataset = load_dataset("wikipedia", "20220301.simple")
train_dataset = dataset['train']

print(train_dataset)
print(len(train_dataset))
print('hi')

Found cached dataset wikipedia (/mnt/ssd-2/hf_cache/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
100%|██████████| 1/1 [00:00<00:00, 435.86it/s]

Dataset({
    features: ['id', 'url', 'title', 'text'],
    num_rows: 205328
})





# Sanity Checks

In [20]:
# # Use einsum to do multiplication 

# reporter_weights_repeat = repeat(reporter_weights, 'b e -> b c e', c=len(input_tokens) + 1)
# result = torch.einsum('bse,bse->bs', cat_hidden_states, reporter_weights_repeat)
# print(result.shape)

# sigmoid_result = torch.sigmoid(result)
# softmax_result = torch.softmax(result, dim = -1)
# torch.set_printoptions(precision=1)

In [21]:
print(result)

tensor([[-5.1e-03,  1.5e-02,  1.8e-02,  1.3e-02,  4.8e-03, -6.8e-03, -1.9e-03,
         -2.7e-02,  2.2e-02, -2.7e-02],
        [-6.5e-01,  7.4e-02, -1.0e-01, -1.2e-01,  2.0e-01, -7.1e-02, -3.7e-02,
         -3.2e-02,  6.1e-02, -8.2e-02],
        [-2.3e+00, -1.1e+00, -2.6e-01, -1.1e+00, -5.4e-01, -1.6e+00, -4.8e-01,
         -5.9e-01, -8.8e-01, -1.6e+00],
        [-5.6e+01,  9.4e-01,  1.0e+00,  8.2e-01,  9.9e-01,  8.0e-01,  4.3e-01,
          3.7e-01,  6.7e-01,  6.2e-01],
        [-1.4e+01,  2.3e+00,  2.3e+00,  1.6e+00,  1.4e+00,  8.3e-01,  1.1e+00,
          2.5e-01,  4.8e-01,  5.1e-01],
        [-3.8e+01,  2.8e+00,  2.6e+00,  1.3e+00,  1.4e+00,  9.1e-01,  6.2e-01,
          2.1e-01,  1.2e-01,  4.3e-01],
        [ 4.5e+01, -2.6e+00, -3.2e+00, -1.5e+00, -1.9e+00, -1.1e+00, -4.0e-01,
         -4.9e-01, -1.7e-01, -4.8e-01],
        [ 3.3e+01, -2.7e+00, -3.8e+00, -1.4e+00, -1.9e+00, -1.3e+00, -9.0e-01,
         -8.8e-01, -3.5e-01, -4.4e-01],
        [ 1.7e+00,  3.7e+00,  5.2e+00,  2.7e+00,