In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
from pathlib import Path
from einops import rearrange, repeat
import os
import numpy as np
from dataclasses import dataclass
import re
import csv
import json
from IPython.display import HTML

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_name = "huggyllama/llama-13b"
# model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.float16)

Loading checkpoint shards: 100%|██████████| 3/3 [00:14<00:00,  4.69s/it]
Some weights of the model checkpoint at huggyllama/llama-13b were not used when initializing LlamaModel: ['lm_head.weight']
- This IS expected if you are initializing LlamaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LlamaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
model.to("cuda:2")

LlamaModel(
  (embed_tokens): Embedding(32000, 5120, padding_idx=0)
  (layers): ModuleList(
    (0-39): 40 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
        (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
        (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
        (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
        (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
        (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
        (act_fn): SiLUActivation()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

In [5]:
def load_all_reporters(num_layers: int, prefix_path: str) -> torch.Tensor:
    reporters_weights = []
    for i in range(num_layers):
        reporter_path = f"{prefix_path}/layer_{i}.pt"
        reporter = torch.load(reporter_path)
        reporters_weights.append(reporter['weight'].cpu())
    stacked = torch.cat(reporters_weights, dim=0)
    return stacked



In [19]:
@dataclass
class Reporter:
    path: str
    weights: torch.Tensor
    desc: str

    @staticmethod
    def from_path(path):
        weights = load_all_reporters(model.config.num_hidden_layers, path)
        desc = path.split('/')[-2]
        return Reporter(path, weights, desc)

reporter_paths = [
    ('/home/jon/elk-reporters/huggyllama/llama-13b/azhx/counterfact-easy/stoic-jang/reporters', 'lm negation'), # lm negation
    ('/home/jon/elk-reporters/huggyllama/llama-13b/azhx/counterfact-filtered-gptj6b/trusting-cori/reporters', 'dumb nots'), # dumb nots
    ('/home/jon/elk-reporters/huggyllama/llama-13b/azhx/counterfact-simple/jovial-lederberg/reporters', 'counterfact pairs') # counterfact pairs
]

reporters = [Reporter(path, load_all_reporters(model.config.num_hidden_layers, path), desc) for path, desc in reporter_paths]
reporter = reporters[2]

In [7]:
# Tokenize input sequence
input_sequence = "This film is terrible, best to pass."
input_tokens = tokenizer.tokenize(input_sequence)
print(input_tokens)
input_tokens.insert(0, "<bos>") # beginning of sentence
print(input_tokens)

# Encode input sequence
input_ids = tokenizer.encode(input_sequence, return_tensors="pt").to(model.device)
print(input_ids)

decoded = tokenizer.convert_ids_to_tokens(input_ids[0])
print(decoded)

# Generate hidden states
outputs = model(input_ids, output_hidden_states=True)

['▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']
['<bos>', '▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']
tensor([[    1,   910,  2706,   338, 16403, 29892,  1900,   304,  1209, 29889]],
       device='cuda:2')
['<s>', '▁This', '▁film', '▁is', '▁terrible', ',', '▁best', '▁to', '▁pass', '.']


In [8]:
hidden_states = outputs.hidden_states
print(hidden_states[0].shape)
print(len(hidden_states))
cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
print(cat_hidden_states.shape)

torch.Size([1, 10, 5120])
41
torch.Size([40, 10, 5120])


In [9]:
# path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
# path = '/mnt/ssd-1/spar/jon/elk-reporters/gpt2/imdb/agitated-driscoll/reporters'
# reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
# print(reporter_weights.shape)
# for path in reporter_paths:
#     reporter_weights = load_all_reporters(model.config.num_hidden_layers, path)
#     print(reporter_weights.shape)

In [10]:
# # Use einsum to do multiplication 
# result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
# print(result[0])
# print(result.shape)

# sigmoid_result = torch.sigmoid(result)
# softmax_result = torch.softmax(result, dim = -1)
# torch.set_printoptions(precision=1)

In [11]:
# print(result.shape)

In [12]:
# print(sigmoid_result.shape)

In [13]:
# print(softmax_result.shape)

In [14]:
# import plotly.graph_objs as go
# import plotly.io as pio
# import torch

# from plotly.offline import init_notebook_mode, iplot
# init_notebook_mode(connected=True)

# # Create color scale for heatmap
# color_scale = [[0, '#FFFFFF'], [1, '#FF0000']] # white to red

# # Convert tensor to numpy array and detach gradients
# credences = result.detach().numpy()

# padded_text = np.array([[" " * 2 + "{:.2f}".format(value) + " " * 2 for value in row] for row in credences])

# # Create plotly heatmap
# heatmap = go.Heatmap(
#     z=credences, 
#     colorscale=color_scale,
#     text=padded_text,  # Set the text to be equal to z
#     texttemplate="%{text}", 
#     textfont=dict(color='black', size=12),  # Set the text color and size
# )

# # Create plot layout
# layout = go.Layout(title='Credences for Input Tokens',
#                 width=800,  # Set the width of the plot
#                 height=800,  # Set the height of the plot
#                    xaxis=dict(tickvals=list(range(len(input_tokens))),
#                               ticktext=input_tokens,
#                               tickangle=45))

# # Create plotly figure
# fig = go.Figure(data=[heatmap], layout=layout)
# # fig = fig.update_traces(text=input_tokens, texttemplate="%{text}", hovertemplate=None)

# # Display plotly figure
# # iplot(fig)

# # print(padded_text)
# input_tokens

In [15]:
# words_corresponding_to_credences = np.array([[" " * 2 + word + " " * 2 for word in input_tokens] for layer_num in range(len(credences))])

# heatmap = go.Heatmap(
#     z=credences,
#     colorscale=color_scale,
#     text=words_corresponding_to_credences,  # Set the text to be equal to z
#     texttemplate="%{text}",
#     textfont=dict(color='black', size=12),  # Set the text color and size
# )

# # Create plot layout
# layout = go.Layout(title='Credences for Input Tokens',
#                 width=800,  # Set the width of the plot
#                 height=800,  # Set the height of the plot

#                      xaxis=dict(tickvals=list(range(len(input_tokens))),
#                                 ticktext=input_tokens,
#                                 tickangle=45))

# # Create plotly figure
# fig = go.Figure(data=[heatmap], layout=layout)

# # Display plotly figure
# print(credences[-1].shape)
# iplot(fig)


In [16]:
import matplotlib.pyplot as plt

def colorize(words, credences):
    
    max_color = float(max(credences))
    min_color = float(min(credences))
    normalized_credences = (credences - min_color) / (max_color - min_color)
    
    cmap = plt.get_cmap('Blues')
    
    template = '<span class="barcode"; style="color: black; background-color: {}" title="{}">{}</span>'
    colored_string = ''
    
    for word, credence, norm_credence in zip(words, credences, normalized_credences):
        word = word.replace('▁', ' ')
        color = cmap(norm_credence)[:3]
        max_col_num = int(255) # half intensity
        # color = 'rgb(' + str(int(color[0]*max_col_num)) + ',' + str(int(color[1]*max_col_num)) + ',' + str(int(color[2]*max_col_num)) + ')'
        color = f'rgba({color[0]*255}, {color[1]*255}, {color[2]*255}, {0.5})' # TODO figure out how to give it white background
        # print(word, credence, norm_credence)
        colored_string += template.format(color, str(credence), word)
    
    return colored_string

In [64]:
# path = '/home/waree/elk-reporters/huggyllama/llama-13b/sethapun/imdb_misspelled_0/llama13b-imdb0/reporters'
# path = '/mnt/ssd-1/spar/jon/elk-reporters/gpt2/imdb/agitated-driscoll/reporters'

def extract_sentences(text):
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text)
    # remove line feed characters
    sentences = [sentence.replace('\n', ' ') for sentence in sentences]

    return [sentence for sentence in sentences if sentence != '']

def compute_credences(reporter, sentence):

    input_ids = tokenizer.encode(sentence, return_tensors="pt", add_special_tokens=False).to(model.device)
    # get the indexes of the full stops
    words = tokenizer.convert_ids_to_tokens(input_ids[0])
    # print(words)
    punct_indexes = [i for i, x in enumerate(words) if "." in x]
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)
    reporter_weights = reporter.weights.to(model.device)
    hidden_states = outputs.hidden_states
    # print(hidden_states[0].shape)
    # print(len(hidden_states))
    cat_hidden_states = torch.cat(hidden_states[:-1], dim=0)
    # Use einsum to do multiplication 
    reporter_weights = reporter_weights.half()
    result = torch.einsum('bse,be->bs', cat_hidden_states, reporter_weights)
    # print(result[0])
    # print(result.shape)

    # sigmoid_result = torch.sigmoid(result)
    # softmax_result = torch.softmax(result, dim = -1)
    torch.set_printoptions(precision=1)
    credences = result.detach().cpu().numpy()[-1] # last layer

    # for every punct_index, get punct_index - 1.
    # get the credence for that index, and copy it until the previous punct_index + 1.
    sentencewise_credence = credences.copy()
    # print(f'punct_indexes: {punct_indexes}')

    windows = [(-1, punct_indexes[0])] + [(punct_indexes[i-1], punct_indexes[i]) for i in range(1, len(punct_indexes))]
    for a, b in windows:
        last_token_credence = sentencewise_credence[b-1]
        sentencewise_credence[a+1:b+1] = last_token_credence


    # write output to json: reporter_desc, sentence, credences
    credence_record = {
        'reporter_path': reporter.path,
        'sentence': sentence,
        'credences': credences.tolist(),
        'credences_sentencewise': sentencewise_credence.tolist(),
        'tokens': words,
        # 'rendered': colored_string,
        # 'rendered_sentencewise': rendered_sentencewise
    }

    return credence_record
    
def viz(reporter, sentence, outfilename):
    # if file already exists then don't
    if os.path.exists(Path('credences') / reporter.desc / f'{outfilename}.json'):
        # print(f'found {outfilename}.json, skipping')
        return
    else:
        record = compute_credences(reporter, sentence)
    outpath = Path('credences') / reporter.desc
    if not os.path.exists(outpath):
        os.makedirs(outpath)
    with open(outpath / f'{outfilename}.json', 'w') as f:
        json.dump(record, f, indent=4)
    # display(HTML(record['rendered']))
    # display(HTML(record['rendered_sentencewise']))


mitochrondria_text = """
A mitochondrion (/ˌmaɪtəˈkɒndriən/;[1] pl mitochondria) is an organelle found in the cells of most eukaryotes, such as animals, plants and fungi. Mitochondria have a double membrane structure and use aerobic respiration to generate adenosine triphosphate (ATP), which is used throughout the cell as a source of chemical energy.[2] They were discovered by Albert von Kölliker in 1857[3]
\n\n\n\n"""
viz(reporter, "I love this movie so much.", "positive")
viz(reporter, "Wow this movie is terrible.", "negative")
viz(reporter, mitochrondria_text, "mitochrondria")


In [65]:
from datasets import load_dataset
from tqdm import tqdm

print('hi')
dataset = load_dataset("wikipedia", "20220301.simple")
train_dataset = dataset['train']


print(train_dataset)
print(len(train_dataset))
print('hi')

for reporter in reporters:
    for i in tqdm(range(1000)):
        sentence = train_dataset[i]['text']
        title = train_dataset[i]['title']
        sentence = sentence.replace('\n', ' ')
        try:
            viz(reporter, sentence, title)
        except torch.cuda.OutOfMemoryError:
            continue
        except IndexError:
            continue
        except FileNotFoundError:
            print(f'FileNotFoundError: {title}')
            continue

hi


Found cached dataset wikipedia (/mnt/ssd-2/hf_cache/wikipedia/20220301.simple/2.0.0/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559)
100%|██████████| 1/1 [00:00<00:00, 355.84it/s]


Dataset({
    features: ['id', 'url', 'title', 'text'],
    num_rows: 205328
})
205328
hi


 90%|████████▉ | 899/1000 [00:03<00:00, 197.23it/s]

FileNotFoundError: Biel/Bienne, Switzerland


100%|██████████| 1000/1000 [00:28<00:00, 34.93it/s]
 90%|█████████ | 904/1000 [06:08<00:20,  4.69it/s]

FileNotFoundError: Biel/Bienne, Switzerland


100%|██████████| 1000/1000 [06:33<00:00,  2.54it/s]
 90%|█████████ | 904/1000 [06:01<00:20,  4.73it/s]

FileNotFoundError: Biel/Bienne, Switzerland


100%|██████████| 1000/1000 [06:26<00:00,  2.59it/s]


In [66]:
# now render the credences

for reporter in reporters:
    files = os.listdir(f'credences/{reporter.desc}')
    html_file = ""
    for file in files:
        with open(f'credences/{reporter.desc}/{file}', 'r') as f:
            record = json.load(f)
            record['credences'] = np.array(record['credences'])
            record['credences_sentencewise'] = np.array(record['credences_sentencewise'])
            
            colored_string = colorize(record['tokens'], record['credences'])
            colored_string_sentencewise = colorize(record['tokens'], record['credences_sentencewise'])

            # put header on it and concatenate it
            html_file += f'<h1>{file}</h1\n'
            html_file += colored_string
            html_file += '\n'
            html_file += '<br><br>'
            html_file += colored_string_sentencewise
            html_file += '\n'

    # save html file
    with open(f'credences/{reporter.desc}.html', 'w') as f:
        f.write(html_file)


# print(html_file)

            

  normalized_credences = (credences - min_color) / (max_color - min_color)


In [70]:
for reporter in reporters:
    files = os.listdir(f'credences/{reporter.desc}')
    # filter dirs
    files = [file for file in files if file.endswith('.json')]
    for file in tqdm(files):
        if os.path.exists(f'credences/{reporter.desc}/html/{file}.html'):
            continue

        html_file = ""
        with open(f'credences/{reporter.desc}/{file}', 'r') as f:
            record = json.load(f)
            record['credences'] = np.array(record['credences'])
            record['credences_sentencewise'] = np.array(record['credences_sentencewise'])
            
            colored_string = colorize(record['tokens'], record['credences'])
            colored_string_sentencewise = colorize(record['tokens'], record['credences_sentencewise'])

            # put header on it and concatenate it
            html_file += f'<h1>{file}</h1\n'
            html_file += colored_string
            html_file += '\n'
            html_file += '<br><br>'
            html_file += colored_string_sentencewise
            html_file += '\n'

        # save html file
        with open(f'credences/{reporter.desc}/html/{file}.html', 'w') as f:
            f.write(html_file)


100%|██████████| 971/971 [00:41<00:00, 23.17it/s]  
100%|██████████| 941/941 [01:08<00:00, 13.74it/s]
  normalized_credences = (credences - min_color) / (max_color - min_color)
100%|██████████| 946/946 [01:08<00:00, 13.85it/s]


# Sanity Checks

In [64]:
# # Use einsum to do multiplication 

# reporter_weights_repeat = repeat(reporter_weights, 'b e -> b c e', c=len(input_tokens) + 1)
# result = torch.einsum('bse,bse->bs', cat_hidden_states, reporter_weights_repeat)
# print(result.shape)

# sigmoid_result = torch.sigmoid(result)
# softmax_result = torch.softmax(result, dim = -1)
# torch.set_printoptions(precision=1)

In [65]:
print(result)