In [None]:
from trees import TreeMaker
from wrappers import GPTneoX_DenseWrapper, ActivationWrapper

import torch
import numpy as np
import umap
import plotly.graph_objects as go

In [None]:
# MeSH descriptor XML, downloaded from the NLM MeSH website. The path is
# relative, so it is resolved against the kernel's working directory.
MESH_DESC_XML = 'desc2025.xml'
tm = TreeMaker(MESH_DESC_XML)

In [None]:
# Root MeSH tree number to expand (pick one from the online MeSH browser)
# and the depth limit passed to TreeMaker.tree_from_key.
key = 'G'
MAX_DEPTH = 0  # NOTE(review): exact meaning of depth 0 is defined by TreeMaker -- confirm

G, names = tm.tree_from_key(key, max_depth=MAX_DEPTH)
print(names)

In [None]:
# Configuration: model identifier (a Hugging Face Hub id) and the index of
# the layer whose 'mlp' activations are inspected below.
model = "EleutherAI/pythia-2.8b-deduped"
layer_num = 3

# Build the activation wrapper, then a wrapper around the chosen MLP layer.
wrapper = ActivationWrapper(model)
layer = wrapper.make_layer_wrapper(layer_num, 'mlp')

In [None]:
# Sanity check: shape of the logits the wrapper returns for `names` with
# tokens='all'. The forward pass runs under no_grad because only the shape is
# needed; building an autograd graph for a 2.8B-parameter model here would
# only waste memory (the later `.detach()` suggests grads are otherwise tracked).
with torch.no_grad():
    logits_shape = wrapper.batch_logits(names, tokens='all').shape
logits_shape

In [None]:
# Vocabulary from the wrapper. NOTE(review): presumably a token-string -> id
# mapping (HF tokenizer style) -- confirm against ActivationWrapper.get_vocab.
vocab_dict = wrapper.get_vocab()

In [None]:
# Collect activations of the selected MLP layer for every name: once with the
# wrapper's default token selection and once with tokens='last'
# (presumably one row per name -- TODO confirm against batch_activations).
# Everything runs under no_grad: this is inference-only analysis, and the
# `.detach()` needed in the UMAP cell indicates the wrapper otherwise returns
# grad-tracking tensors whose autograd graphs would stay alive in memory.
with torch.no_grad():
    all_acts, all_toks = layer.batch_activations(names)
    last_acts, last_toks = layer.batch_activations(names, tokens='last')

    print(last_acts.shape)

    # Pairwise Euclidean (p=2) distances between the last-token activations.
    distances = torch.cdist(last_acts, last_acts, p=2)

In [None]:
# Reduce the last-token activations to 3-D with UMAP and show them in an
# interactive scatter plot. Hovering a point shows its entry from `names`
# (assumes row i of `last_acts` corresponds to names[i] -- TODO confirm).
UMAP_SEED = 42  # fixed seed so Restart & Run All reproduces the same layout
umapper = umap.UMAP(n_components=3, random_state=UMAP_SEED)
# .detach() drops any autograd history. .cpu() is required before .numpy()
# when the activations live on a GPU and is a no-op for CPU tensors.
emb = umapper.fit_transform(last_acts.detach().cpu().numpy())
print(emb.shape)

x, y, z = emb[:, 0], emb[:, 1], emb[:, 2]
fig = go.Figure(data=[go.Scatter3d(
    x=x,
    y=y,
    z=z,
    mode='markers',
    marker=dict(
        size=4,
        opacity=0.8
    ),
    text=names,
    hoverinfo='text'
)])

# Label axes by UMAP component and drop the default margins around the scene.
fig.update_layout(
    scene=dict(
        xaxis_title='UMAP 1',
        yaxis_title='UMAP 2',
        zaxis_title='UMAP 3'
    ),
    margin=dict(l=0, r=0, b=0, t=0)
)
fig
