In [1]:
pip install ipywidgets -q

Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
from transformers import GPT2Model, GPT2TokenizerFast
from scipy.spatial import distance
import ipywidgets as widgets
from IPython.display import display

# Load pre-trained model and tokenizer
model = GPT2Model.from_pretrained('gpt2', output_hidden_states=True)
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

# Create input widget for the sentence
sentence_widget = widgets.Textarea(
    value='',
    placeholder='Enter sentence',
    description='Sentence:',
    layout=widgets.Layout(width='auto')
)

# Create widgets for layer and head selection
layer_widget = widgets.IntText(value=0, description='Layer:')
head_widget = widgets.IntText(value=0, description='Head:')

# Create button to tokenize sentence
tokenize_button = widgets.Button(description='Tokenize')

# Placeholder for tokens widget
tokens_widget = None

# Placeholder for compute distances button
compute_button = None

def tokenize_sentence(b):
    global tokens_widget, compute_button
    
    # Tokenize the sentence
    sentence = sentence_widget.value
    inputs = tokenizer(sentence, return_tensors="pt", return_offsets_mapping=True)
    tokens = [sentence[start:end] for start, end in inputs['offset_mapping'][0]]
    
    # Create a multi-select list with one entry per token
    tokens_widget = widgets.SelectMultiple(
        options=tokens,
        rows=10,
        description='Tokens',
        disabled=False,
        layout=widgets.Layout(width='auto')
    )
    
    # Create button to compute distances
    compute_button = widgets.Button(description='Compute Distances')
    compute_button.on_click(compute_distances)
    
    # Display widgets
    display(tokens_widget, compute_button)

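def compute_distances(b):
    # Positions of the selected tokens. Using the widget's index (rather than
    # matching decoded strings) keeps repeated tokens distinct.
    selected_positions = list(tokens_widget.index)
    if len(selected_positions) < 2:
        print('Select at least two tokens.')
        return

    # Tokenize the sentence (it must be unchanged since "Tokenize" was clicked,
    # otherwise the selected positions no longer line up)
    sentence = sentence_widget.value
    inputs = tokenizer(sentence, return_tensors="pt")

    # Forward pass through model (no gradients are needed here)
    with torch.no_grad():
        outputs = model(**inputs)

    # Extract the hidden states: a tuple with the embedding output at index 0,
    # followed by the output of each transformer block
    hidden_states = outputs.hidden_states

    # Get layer and head values
    layer = layer_widget.value
    head = head_widget.value  # NOTE: not used below; hidden states are not split by head
    if not 0 <= layer < len(hidden_states):
        print(f'Layer must be between 0 and {len(hidden_states) - 1}.')
        return

    # Extract context vectors for the specified layer, shape (seq_len, hidden_size)
    context_vectors = hidden_states[layer][0]
    if max(selected_positions) >= context_vectors.shape[0]:
        print('The sentence changed; click Tokenize again.')
        return

    # Select the context vectors for the selected tokens
    selected_vectors = context_vectors[selected_positions]

    # Compute pairwise distances
    pairwise_distances = distance.pdist(selected_vectors.numpy(), 'euclidean')

    # Convert the condensed distance matrix to a square form
    pairwise_distances = distance.squareform(pairwise_distances)

    # Print the pairwise distance matrix
    print(pairwise_distances)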

# Link button click to function
tokenize_button.on_click(tokenize_sentence)

# Display widgets
display(sentence_widget, layer_widget, head_widget, tokenize_button)

Textarea(value='', description='Sentence:', layout=Layout(width='auto'), placeholder='Enter sentence')

IntText(value=0, description='Layer:')

IntText(value=0, description='Head:')

Button(description='Tokenize', style=ButtonStyle())

SelectMultiple(description='Tokens', layout=Layout(width='auto'), options=('Quant', 'um', ' information', ' th…

Button(description='Compute Distances', style=ButtonStyle())

[[  0.         139.95894228 140.80757815 141.79311334]
 [139.95894228   0.          33.9144922   37.4516537 ]
 [140.80757815  33.9144922    0.          36.55526241]
 [141.79311334  37.4516537   36.55526241   0.        ]]


The pairwise distances between the context vectors of a set of words reflect how the model represents the relationships between those words in a given context (a sentence or paragraph).
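
Concretely, each entry of the matrix printed above is a Euclidean distance between two hidden-state vectors. Using notation introduced here for illustration (it does not appear in the code), let $h_i^{(\ell)} \in \mathbb{R}^{768}$ be the layer-$\ell$ hidden state of the $i$-th selected token, where 768 is the hidden size of the `gpt2` checkpoint:

$$
D_{ij} = \lVert h_i^{(\ell)} - h_j^{(\ell)} \rVert_2, \qquad D_{ii} = 0, \qquad D_{ij} = D_{ji}.
$$

Because the matrix is symmetric with a zero diagonal, only the entries above the diagonal carry information. For the four tokens selected above, that is six distances.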

1. **If the pairwise distances do not change in different contexts:**

    This could reflect a few things about the words and the model:
    
    * **Linguistic property:** It may suggest that the semantic relationships between the words are relatively stable across different contexts. This could be the case for words that have a fixed or limited set of meanings, or words that are strongly related to each other in some way.
    
    * **About the model:** It might indicate that the model is not sensitive to changes in context for these words, or that the specific layer and head selected do not encode context-sensitive information. GPT-2, like other transformers, is designed to capture context in its representations, but not all layers and heads capture it equally.

2. **If the pairwise distances change significantly in different contexts:**

    This could reflect a few things about the words and the model:
    
    * **Linguistic property:** It may suggest that the words have multiple meanings (polysemy) or that their semantic relationships are highly context-dependent. This could be the case for words that change their meaning based on the context they are used in.
    
    * **About the model:** It suggests that the model is sensitive to changes in context for these words, and that the specific layer and head selected are capturing this context-sensitive information. This would be expected behavior for a language model like GPT-2, which is designed to encode context in its representations.

Remember that the interpretation of these distances also depends on the layer and head chosen for the analysis. Different layers and heads of the model capture different types of information, so the same words may have different distances in different layers or heads.

Now, because we want to understand the concept encoded by the subset of words as a whole, we don't want to simply compare distance matrices as matrices by using something like the Frobenius norm. We lose geometric information this way. What we really want to do is compute the persistent homology of the point cloud, using something like Gudhi, and then compare the resulting persistence diagrams (barcodes) using something like the bottleneck or Wasserstein distance. This allows us to consider the topology of the concept encoded by the subset of words as a whole.

If the persistence diagrams are very near to each other, then the concept is stable across different contexts. Note that, depending on how often the words co-occur in the data empirically, this may or may not be a good thing. It could indicate that the model is properly encoding a strong relationship among the words, but it could also indicate that the model is not sensitive enough to context and is therefore not expressive enough.
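
As a dependency-free illustration of the persistent-homology idea above, the sketch below computes only the 0-dimensional barcode (connected components) of a distance matrix. This is a deliberate simplification and not the Gudhi pipeline itself: it relies on the fact that, in a Vietoris-Rips filtration, every point is born at scale 0 and components merge exactly at the edge lengths of a minimum spanning tree. The helper name `h0_death_times` is hypothetical, and the input is the matrix printed earlier, rounded to two decimals. Higher-dimensional features and the bottleneck or Wasserstein comparison would still need a library such as Gudhi.

```python
def h0_death_times(dist):
    """Return the finite H0 death times for a square, symmetric distance matrix.

    Every point is born at scale 0. A component dies when it merges into
    another one, which happens at the edge lengths chosen by Kruskal's
    minimum-spanning-tree algorithm.
    """
    n = len(dist)
    parent = list(range(n))

    def find(i):
        # Union-find lookup with path halving
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # All edges (i < j), shortest first
    edges = sorted((dist[i][j], i, j) for i in range(n) for j in range(i + 1, n))

    deaths = []
    for length, i, j in edges:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            # Two components merge at this scale: one H0 bar ends here
            parent[root_i] = root_j
            deaths.append(length)
    return deaths


# Distance matrix printed by the cell above, rounded to two decimals
D = [
    [0.00, 139.96, 140.81, 141.79],
    [139.96, 0.00, 33.91, 37.45],
    [140.81, 33.91, 0.00, 36.56],
    [141.79, 37.45, 36.56, 0.00],
]

print(h0_death_times(D))  # [33.91, 36.56, 139.96]
```

For this matrix, three of the selected tokens merge into one cluster by scale 36.56, while the first selected token only joins at 139.96. Running the same helper on the distance matrix from a different sentence gives a second list of death times, which is the kind of object a bottleneck or Wasserstein distance would then compare.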