In [1]:
from trees import Tree
from wrappers import GPTneoX_DenseWrapper, ActivationWrapper

import torch
import numpy as np
import umap
import plotly.graph_objects as go

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Install 'desc2025.xml' from MeSH website
tm = Tree('desc2025.xml')

In [29]:
# View MeSH online browser to select a base tree number
key = 'A08.186'

#G, names = tm.tree_from_key(key, max_depth=4)
#print(names)

names = ['The heart is']
batch_size = len(names)

In [30]:
model = "EleutherAI/pythia-410m-deduped"
wrapper = ActivationWrapper(model)

layer_num = -1
layer = wrapper.make_layer_wrapper(layer_num, 'mlp')

In [31]:
wrapper.batch_logits(names, tokens = 'all').shape

torch.Size([1, 3, 50278])

In [32]:
vocab_dict = wrapper.get_vocab()
reversed_dict = {v: k for k, v in vocab_dict.items()}

In [33]:
#all_acts, all_toks = layer.batch_activations(names)
last_acts, last_toks = layer.batch_activations(names, tokens='all')

print(last_acts.shape)

distances = torch.cdist(last_acts, last_acts, p=2)

torch.Size([1, 3, 1024])


In [34]:
num_tok_samples = 10
temp = 1.0


all_to_embed = []
colors = []

next_tokens = wrapper.generate_next_token(names, num_tok_samples, temp)


thing, new_sens = wrapper.generate_and_prepare(names, num_tok_samples, temp)
res, to = layer.batch_activations(thing, tokens='last', tokenized_prior=True)
#vectorized_map = np.vectorize(reversed_dict.get)
#tokens_next = vectorized_map(next_tokens)
all_to_embed.extend(new_sens)
blue = ['red'] * len(new_sens)
colors.extend(blue)
#print(next_tokens)

new_thing, new_new_sens = wrapper.generate_and_prepare(new_sens, num_tok_samples, temp)
new_res, new_to = layer.batch_activations(new_thing, tokens='last', tokenized_prior=True)

all_to_embed.extend(new_new_sens)
green = ['green'] * len(new_new_sens)
colors.extend(green)


print(res.shape)

print(new_res.shape)

acts = torch.cat([res, new_res], dim=0)
print(acts.shape)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([10, 1024])
torch.Size([100, 1024])
torch.Size([110, 1024])


In [35]:
print(colors)

print(len(colors))
print(len(all_to_embed))

['red', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'red', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green', 'green']
110
110


In [36]:
umapper = umap.UMAP(n_components=3)
emb = umapper.fit_transform(acts.detach().numpy())
print(emb.shape)

x, y, z = emb[:, 0], emb[:, 1], emb[:, 2]
fig = go.Figure(data=[go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers',
    marker=dict(
        size=4,
        color=colors,
        opacity=0.8
    ),
    text=all_to_embed,
    hoverinfo='text'        
)])

fig.update_layout(
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z'
    ),
    margin=dict(l=0, r=0, b=0, t=0)
)

(110, 3)



'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.



In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

tokenized_og_batch, toks = wrapper.tokenize_inputs(names)

sen_ten = []


for i in range(batch_size):
    sen = toks[i]
    if '[PAD]' in sen:
        before_pad = sen[:sen.index('[PAD]')]
    else:
        before_pad = sen
    for j in range(num_tok_samples):
        merge = before_pad + [str(next_tokens[i,j])]
        sen_ten.append(tokenizer.convert_tokens_to_ids(merge))

max_len = max(len(seq) for seq in sen_ten)
print(max_len)

padded_inputs = [
    seq + [tokenizer.pad_token_id] * (max_len - len(seq)) for seq in sen_ten
]

input_ids = torch.tensor(padded_inputs)
attention_mask = (input_ids != tokenizer.pad_token_id).long()

inputs = input_ids, attention_mask

res, to = layer.batch_activations(inputs, tokens='last', tokenized_prior=True)

print(res.shape)