In [1]:
from phi35.utils import *
# Load everything the rest of the notebook works with. All helpers (and PROMPT, which is not
# defined in any visible cell) come from the `phi35.utils` star import above — NOTE(review):
# confirm PROMPT is exported there, otherwise this cell fails on a fresh kernel.
# The logged output shows these helpers reuse cached results ("Recalculate: False").
tokenizer = load_tokenizer()
# Model hidden states for the prompt (loaded from cache per the log below).
hidden_states = get_hidden_states(PROMPT, tokenizer)
# Vocabulary embedding table; logged shape is (32064, 3072).
embeddings_table = get_token_embeddings_table()
# UMAP model fitted on the embedding table, and the table projected to 2-D (logged (32064, 2)).
umap_model = get_umap_model(embeddings_table)
umap_2d = get_2d_representation(embeddings_table, umap_model)
# Prompt token ids; the widget cell later calls .flatten()/.item() on this, so it is presumably
# a tensor/array — TODO confirm the exact type returned by get_prompt_token_ids.
token_ids = get_prompt_token_ids(PROMPT, tokenizer)

Loading tokenizer...
Tokenizer loaded successfully.
Getting hidden states. Recalculate: False
Loading cached output


  return torch.load(io.BytesIO(b))


Retrieving token embeddings table...
Loading existing token embeddings table
Token embeddings table shape: (32064, 3072)
Getting UMAP model. Recalculate: False
Loading existing UMAP model
Loading existing 2D representation
2D representation shape: (32064, 2)
Tokenizing prompt: 'An android named Apple was los...'
Prompt tokenized. Number of tokens: 182


In [None]:
# Token ids for the word "pizza" with encode()'s default arguments (unlike the prompt cell
# further down, add_special_tokens is not overridden here).
tokenizer.encode("pizza")

In [None]:
# Compare three views of the "izza" piece in a single displayed result:
#   1. the token strings for ids 282 and 24990 (presumably the ids encode("pizza") returned),
#   2. the vocabulary id of the bare token string 'izza',
#   3. what encode() produces for the standalone text "izza".
# Only a cell's last expression is displayed, so all three are returned together as a tuple
# rather than as separate statements whose results would be silently discarded.
(
    tokenizer.convert_ids_to_tokens([282, 24990]),
    tokenizer.convert_tokens_to_ids(['izza']),
    tokenizer.encode("izza"),
)




In [None]:
# batch_decode treats each element as its own sequence, so passing a flat id list yields one
# decoded string per id — i.e. the text of each piece separately. The ids are presumably the
# ones returned by encode("pizza") above — TODO confirm.
tokenizer.batch_decode([282, 24990])

In [None]:
# Ids for "izza" encoded as standalone text, to compare with the ids it gets inside "pizza".
tokenizer.encode("izza")

In [None]:
# NOTE(review): this is an ASCII underscore, not the SentencePiece word marker "▁" (U+2581) that
# the prompt cell below replaces. If the goal was to probe the word-initial form of the piece,
# the string probably should use "▁" — confirm intent.
tokenizer.encode("_izza")

In [None]:
# Decode each id separately (one string per id). The ids are presumably taken from one of the
# encodings above — TODO confirm where 5951 and 1362 come from.
tokenizer.batch_decode([5951, 1362])


In [None]:
# First entry of the model's hidden_states tuple (conventionally the embedding-layer output —
# confirm for this model), with a leading size-1 batch dim squeezed away.
# NOTE(review): this renders the whole tensor; consider showing `.shape` or a slice instead.
hidden_states.hidden_states[0].squeeze(0)


In [None]:
# Re-tokenize the prompt without special tokens and build readable token strings.
# NOTE(review): this rebinds `token_ids` to a plain Python list, replacing the value from the
# first cell.
token_ids = tokenizer.encode(PROMPT, add_special_tokens=False)
# Map the SentencePiece word marker "▁" to a space and drop newline byte tokens ("<0x0A>").
# Because those entries are dropped, `tokens` is NOT index-aligned with `token_ids`.
tokens = [
    piece.replace("▁", " ")
    for piece in tokenizer.convert_ids_to_tokens(token_ids)
    if piece != "<0x0A>"
]


In [None]:
# Display the prompt token strings. If the previous cell ran, "▁" has already been replaced,
# so this replace is a no-op and the result equals `tokens`.
[token.replace("▁", " ") for token in tokens]


In [None]:
# Print the raw token strings for `token_ids`, one per line, without the "▁" → space mapping
# or the "<0x0A>" filtering applied in the cell above.
for token in tokenizer.convert_ids_to_tokens(token_ids):
    print(token)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from io import BytesIO
import base64
from sklearn.metrics.pairwise import cosine_similarity
import ipywidgets as widgets
from IPython.display import display

def get_color(similarity):
    """Return the hex color for ``similarity`` on a red → yellow → green ramp.

    ``similarity`` is fed directly to the colormap, so 0 maps to red, 0.5 to yellow and
    1 to green; the caller in this notebook clips values into [0, 1] first.
    """
    red_to_green = colors.LinearSegmentedColormap.from_list("", ["red", "yellow", "green"])
    rgba = red_to_green(similarity)
    return colors.rgb2hex(rgba)

def get_min_max_similarity_except_self(emb, embeddings):
    """Return ``(min, max)`` cosine similarity of ``emb`` against the rows of ``embeddings``.

    The maximum skips every entry within 1e-3 of 1.0, so the query's own row (and any
    row parallel to it, e.g. a repeated token) does not pin the top of the range at 1.
    The minimum is taken over all entries.

    Args:
        emb: query embedding as a 2-D array with a single row (e.g. ``vec.reshape(1, -1)``).
        embeddings: 2-D array of embeddings to compare against, same width as ``emb``.

    Returns:
        ``(min_sim, max_sim)``. If every similarity is ~1 (nothing but "self" to compare
        against), ``max_sim`` falls back to the overall maximum instead of raising on an
        empty selection.
    """
    prompt_similarities = cosine_similarity(emb, embeddings)
    is_self_like = np.isclose(prompt_similarities, 1, atol=1e-3)
    others = prompt_similarities[~is_self_like]

    prompt_min_sim = np.min(prompt_similarities)
    # np.max on an empty array raises ValueError; fall back to the overall maximum.
    prompt_max_sim = np.max(others) if others.size else np.max(prompt_similarities)
    return prompt_min_sim, prompt_max_sim

def create_colorbar(min_sim, max_sim):
    """Render a horizontal red → yellow → green similarity colorbar as an inline HTML image.

    Args:
        min_sim: value shown at the left (red) end of the bar.
        max_sim: value shown at the right (green) end of the bar.

    Returns:
        An ``<img>`` tag string with the PNG embedded as base64, for use in ``widgets.HTML``.
    """
    ramp = colors.LinearSegmentedColormap.from_list("", ["red", "yellow", "green"])
    scale = colors.Normalize(vmin=min_sim, vmax=max_sim)
    mappable = plt.cm.ScalarMappable(norm=scale, cmap=ramp)

    fig, ax = plt.subplots(figsize=(6, 1))
    # The whole axes is used as the colorbar's drawing area.
    fig.colorbar(mappable, cax=ax, orientation='horizontal', label='Similarity')
    fig.tight_layout()

    png_buffer = BytesIO()
    fig.savefig(png_buffer, format='png')
    # Close explicitly so repeated clicks don't accumulate open figures.
    plt.close(fig)

    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return f'<img src="data:image/png;base64,{encoded}">'


def on_token_click(b):
    """Button callback: show the clicked token's nearest tokens and recolor the prompt.

    ``b`` is the clicked ``widgets.Button``; it carries the vocabulary id in the custom
    ``token_id`` attribute set when the buttons were built.

    Side effects:
      * ``result.value`` is set to the top-5 similar tokens for the clicked token.
      * Every button with the clicked ``token_id`` is painted light blue; every other
        button is colored by the cosine similarity of its prompt embedding to the clicked
        token's embedding.
      * ``colorbar.value`` is replaced with a legend for the similarity range used.

    Scaling: with ``normalize_toggle`` on, similarities are stretched over the prompt's own
    [min, max] range (max excluding near-self matches). With it off, raw cosine similarity
    is shown on a fixed [0, 1] scale, with negative values clipped to red.
    """
    token_id = b.token_id
    # NOTE(review): `token_embeddings`, `prompt_embeddings` and `get_top_similar_tokens` are not
    # defined in any visible cell; presumably they come from `phi35.utils` or an earlier session.
    # Confirm, otherwise clicking raises NameError on a fresh kernel. The loop below also assumes
    # the rows of `prompt_embeddings` line up one-to-one with `token_buttons`.
    emb = token_embeddings[token_id]
    similar_tokens = get_top_similar_tokens(emb, token_embeddings, n=5, decode=True, tokenizer=tokenizer)
    result.value = f"Top 5 similar tokens for '{b.description}': {', '.join(similar_tokens)}"

    # cosine_similarity expects 2-D inputs: make the query a single row.
    emb_reshaped = emb.reshape(1, -1)
    similarities = cosine_similarity(emb_reshaped, prompt_embeddings).flatten()

    if normalize_toggle.value:
        scale_min, scale_max = get_min_max_similarity_except_self(emb_reshaped, prompt_embeddings)
    else:
        # Same fixed 0-1 scale as the initial colorbar.
        scale_min, scale_max = 0.0, 1.0

    span = scale_max - scale_min
    if span > 0:
        normalized = np.clip((similarities - scale_min) / span, 0, 1)
    else:
        # Degenerate range: 0/0 would give NaN, which the colormap renders as its "bad"
        # color rather than a meaningful shade. Treat every token as the top of the range.
        normalized = np.ones_like(similarities)

    for btn, sim in zip(token_buttons, normalized):
        if btn.token_id == token_id:
            btn.style.button_color = 'lightblue'  # Highlight the clicked token
        else:
            btn.style.button_color = get_color(sim)

    # Keep the legend in sync with the scale actually used for coloring.
    colorbar.value = create_colorbar(scale_min, scale_max)


# Flatten the prompt ids to plain ints. `token_ids` is array-like after the first cell, but the
# exploratory tokenizer cell rebinds it to a plain list. Accept both so this cell does not
# depend on which cells ran before it.
prompt_id_list = (
    token_ids.flatten().tolist() if hasattr(token_ids, "flatten") else list(token_ids)
)

# Decode each prompt token once; the text is used both for the width scaling and as the label.
decoded_tokens = [tokenizer.decode([tid]) for tid in prompt_id_list]

# Button width scales with label length relative to the longest label, clamped to these bounds.
min_width = 50  # Minimum width in pixels
max_width = 200  # Maximum width in pixels
# `default` covers an empty prompt; `or 1` avoids dividing by zero if every label is empty.
max_token_length = max((len(text) for text in decoded_tokens), default=0) or 1

# Create one button per prompt token, in prompt order.
token_buttons = []
for tid, decoded_token in zip(prompt_id_list, decoded_tokens):
    width = min(max(min_width, int(len(decoded_token) / max_token_length * max_width)), max_width)

    btn = widgets.Button(
        description=decoded_token,
        layout=widgets.Layout(width=f'{width}px', height='30px'),
    )
    # Custom attribute read back by on_token_click to know which vocabulary id was clicked.
    btn.token_id = tid
    btn.on_click(on_token_click)
    token_buttons.append(btn)

# Lay the token buttons out left-to-right and wrap onto new rows, like running text.
button_layout = widgets.Layout(align_items='flex-start', flex_flow='row wrap')
buttons_box = widgets.Box(layout=button_layout, children=token_buttons)

# "Normalize" toggle (on by default). Flipping it re-runs on_token_click for the currently
# highlighted token so the coloring is refreshed.
normalize_toggle = widgets.ToggleButton(
    description='Normalize',
    tooltip='Toggle normalization',
    value=True,
    disabled=False,
    button_style='',  # one of 'success', 'info', 'warning', 'danger' or ''
    icon='check',  # FontAwesome name without the `fa-` prefix
)


def on_toggle_change(change):
    """Replay the last token click after the toggle's ``value`` changes.

    The clicked token is found by its highlight color: on_token_click paints it
    'lightblue'. If no token has been clicked yet, nothing happens.
    """
    if change['type'] != 'change' or change['name'] != 'value':
        return
    for btn in token_buttons:
        if btn.style.button_color == 'lightblue':
            on_token_click(btn)
            break


normalize_toggle.observe(on_toggle_change, names='value')

# Text label that on_token_click overwrites with the nearest-token summary.
result = widgets.Label(value="Click on a token to see similar tokens")

# Placeholder legend on a 0-1 scale; on_token_click replaces it after each click.
colorbar = widgets.HTML(value=create_colorbar(0, 1))

# Stack the controls vertically: toggle, prompt tokens, result text, legend.
display(
    widgets.VBox(children=[normalize_toggle, buttons_box, result, colorbar])
)