In [1]:
pip install numpy torch transformers gudhi networkx scipy plotly ipywidgets -q

Note: you may need to restart the kernel to use updated packages.


In [19]:
import functools
import warnings

import gudhi as gd
import networkx as nx
import numpy as np
import plotly.graph_objs as go
import torch
from ipywidgets import interact, interactive, FloatSlider, IntSlider, Text, Dropdown, VBox, Label, Button
from scipy.spatial.distance import jensenshannon
from transformers import AutoTokenizer, AutoModel
warnings.filterwarnings("ignore", category=UserWarning)


def get_attention_matrix(text, model, tokenizer, layer, head):
    """Return the attention weights of one layer/head for ``text``.

    Args:
        text: Input string to run through the model.
        model: Hugging Face model; called with ``output_attentions=True``.
        tokenizer: Tokenizer matching ``model``; called with
            ``return_tensors="pt"``.
        layer: Index into ``outputs.attentions`` (one entry per layer).
        head: Head index inside that layer's attention tensor.

    Returns:
        ``outputs.attentions[layer][0, head]`` as a NumPy array, i.e. the
        first batch element. Presumably ``(seq_len, seq_len)`` given the usual
        Hugging Face ``(batch, heads, seq, seq)`` layout -- TODO confirm for
        non-standard models.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Inference only: without no_grad every call records an autograd graph
    # for the whole forward pass, which wastes memory and time on each redraw.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    return outputs.attentions[layer][0, head].detach().cpu().numpy()

def compute_persistence(attention_matrix, min_persistence=0.01, max_dimension=2):
    """Build a Vietoris-Rips complex over attention rows and compute persistence.

    Each row of ``attention_matrix`` is treated as a probability distribution
    (how much one token attends to every position). Rows are compared with
    the Jensen-Shannon distance, and a Rips complex with unbounded edge length
    is built on that distance matrix.

    Hugging Face attention weights are already post-softmax, so rows are only
    L1-normalised here. Re-applying softmax would flatten every row towards
    uniform, and would give causally masked (zero) positions a weight of
    ``exp(0)``. ``scipy.spatial.distance.jensenshannon`` already returns the
    distance (square root of the divergence), so no extra ``sqrt`` is applied.

    Args:
        attention_matrix: 2-D array of non-negative attention weights; row ``i``
            becomes vertex ``i`` of the complex.
        min_persistence: Intervals shorter than this are dropped by gudhi.
        max_dimension: Highest simplex dimension inserted into the tree.

    Returns:
        ``(persistence, simplex_tree, distance_matrix)``. ``persistence`` is
        gudhi's list of ``(dimension, (birth, death))`` pairs, and
        ``distance_matrix`` is the symmetric ``(n, n)`` JS-distance array.
    """
    distributions = _rows_as_distributions(attention_matrix)
    distance_matrix = _pairwise_js_distance(distributions)

    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=max_dimension)
    persistence = simplex_tree.persistence(min_persistence=min_persistence)
    return persistence, simplex_tree, distance_matrix


def _rows_as_distributions(attention_matrix):
    """L1-normalise each row to sum to 1; all-zero rows become uniform."""
    weights = np.asarray(attention_matrix, dtype=float)
    totals = weights.sum(axis=-1, keepdims=True)
    uniform = np.full_like(weights, 1.0 / max(weights.shape[-1], 1))
    # where=totals > 0 avoids 0/0 for an all-zero row (it keeps the uniform fill).
    return np.divide(weights, totals, out=uniform, where=totals > 0)


def _pairwise_js_distance(distributions):
    """Return the symmetric matrix of Jensen-Shannon distances between rows."""
    n_rows = distributions.shape[0]
    distances = np.zeros((n_rows, n_rows))
    # The distance is symmetric, so compute the upper triangle once and mirror it.
    for i in range(n_rows):
        for j in range(i + 1, n_rows):
            distances[i, j] = distances[j, i] = jensenshannon(distributions[i], distributions[j])
    # Defensive: floating-point rounding can push the divergence slightly
    # negative, and its square root is then NaN; treat those pairs as identical.
    return np.nan_to_num(distances, nan=0.0)
def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens, highlight_tokens):
    """Show the 1-skeleton of the complex at ``threshold`` as a 3-D Plotly graph.

    Every vertex (token) of ``simplex_tree`` is drawn. Two tokens are joined
    by an edge when the complex contains that edge and its entry in
    ``distance_matrix`` is at most ``threshold``. Nodes whose label is in
    ``highlight_tokens`` are drawn red, all others light blue. Positions come
    from a seeded 3-D spring layout, so repeated calls on the same graph look
    the same.

    Args:
        simplex_tree: gudhi simplex tree; its vertex ids index ``tokens`` and
            ``distance_matrix``.
        distance_matrix: Pairwise distances between vertices.
        title: Figure title.
        threshold: Maximum distance for an edge to be drawn.
        tokens: Display label for each vertex id.
        highlight_tokens: Labels to colour red.
    """
    highlight = set(highlight_tokens)
    graph = nx.Graph()
    for simplex, _ in simplex_tree.get_filtration():
        if len(simplex) == 1:
            # Vertices belong to the complex at every threshold. Adding them
            # explicitly keeps tokens with no close neighbour visible instead
            # of silently dropping them from the plot.
            graph.add_node(simplex[0])
        elif len(simplex) == 2:
            u, v = simplex
            if distance_matrix[u][v] <= threshold:
                graph.add_edge(u, v)

    nodes = list(graph.nodes())
    labels = [tokens[node] for node in nodes]
    pos = nx.spring_layout(graph, dim=3, seed=42)

    node_x = [pos[node][0] for node in nodes]
    node_y = [pos[node][1] for node in nodes]
    node_z = [pos[node][2] for node in nodes]

    # Plotly draws disconnected line segments when they are separated by None.
    edge_x, edge_y, edge_z = [], [], []
    for u, v in graph.edges():
        edge_x += [pos[u][0], pos[v][0], None]
        edge_y += [pos[u][1], pos[v][1], None]
        edge_z += [pos[u][2], pos[v][2], None]

    trace_edges = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z, mode='lines', line=dict(color='gray', width=1))

    node_colors = ['red' if label in highlight else 'lightblue' for label in labels]
    trace_nodes = go.Scatter3d(x=node_x, y=node_y, z=node_z, mode='markers+text', text=labels, marker=dict(symbol='circle', size=10, color=node_colors), textposition="top center")

    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)

    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()

@functools.lru_cache(maxsize=2)
def load_model(model_name):
    """Load (and cache) the tokenizer and model for a Hugging Face checkpoint.

    ``update_plot`` calls this on every widget change. Without caching, each
    slider move reloaded the weights. The two most recently used checkpoints
    are kept, so toggling between two models is cheap while memory stays
    bounded.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained``.

    Returns:
        ``(tokenizer, model)``. Repeated calls with the same name return the
        same cached objects, so callers must not mutate them.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

# Checkpoints selectable in the UI: (display label, Hugging Face model id).
model_dropdown = Dropdown(
    options=[
        ('GPT-2', 'gpt2'),
        ('DistilBERT', 'distilbert-base-uncased'),
        ('Bert', 'bert-base-uncased'),
        ('RoBERTa', 'roberta-base'),
        ('aleph-BeRT', 'onlplab/alephbert-base')
    ],
    value='gpt2',
    description='Model:'
)

# The initially selected model's config sizes the layer/head sliders below.
tokenizer, model = load_model(model_dropdown.value)

# continuous_update=False: by default a Text widget publishes its value on
# every keystroke, which would run a full forward pass and replot per
# character typed. Now the value updates on Enter or when focus leaves.
text_input1 = Text(description='Text 1:', value='Quantum information theory is interesting', continuous_update=False)
text_input2 = Text(description='Text 2:', value='Quantum information theory allows us to study attention using entanglement', continuous_update=False)

layer_slider = IntSlider(value=1, min=0, max=model.config.num_hidden_layers - 1, description='Layer:', continuous_update=False)
head_slider = IntSlider(value=2, min=0, max=model.config.num_attention_heads - 1, description='Head:', continuous_update=False)

# Edge-length threshold. The Jensen-Shannon based distances from
# compute_persistence stay below 1.0, so a max of 1.0 lets the slider reach
# the fully connected complex (a max of 0.5 could not).
threshold_slider = FloatSlider(value=0.18, min=0.00, max=1.0, step=0.001, description='Threshold:', continuous_update=False)

# Text whose tokens are highlighted in red in both plots.
subset_input = Text(description='Subset:', value='Quantum information theory')

def update_plot(threshold, text1, text2, layer, head, model_name, subset):
    """Recompute and draw the simplicial complexes for both input texts.

    Args:
        threshold: Maximum distance for an edge to be drawn.
        text1: First text to analyse.
        text2: Second text to analyse.
        layer: Attention layer index to read.
        head: Attention head index to read.
        model_name: Checkpoint name passed to ``load_model``.
        subset: Text whose tokens are highlighted in both plots.

    If ``layer``/``head`` do not exist in the selected model, nothing is drawn
    (see the comment below).
    """
    tokenizer, model = load_model(model_name)
    max_layer = model.config.num_hidden_layers - 1
    max_head = model.config.num_attention_heads - 1
    layer_slider.max = max_layer
    head_slider.max = max_head
    if layer > max_layer or head > max_head:
        # For example, after switching to a model with fewer layers.
        # ipywidgets clamps a bounded slider's value when its max drops below
        # it, and that value change re-triggers the interactive callback with
        # valid indices. Stop here instead of indexing past the end of
        # outputs.attentions or drawing the figures twice.
        return

    # Decode highlight tokens exactly like node labels (one decode per id,
    # without special tokens), so they match for every tokenizer. RoBERTa's
    # 'Ġ'-prefixed tokenize() output never matched the decoded labels.
    highlight_tokens = _decode_each_token(tokenizer, subset, add_special_tokens=False)

    for text, title in ((text1, "Simplicial Complex for Text 1"), (text2, "Simplicial Complex for Text 2")):
        attention_matrix = get_attention_matrix(text, model, tokenizer, layer, head)
        _, simplex_tree, distance_matrix = compute_persistence(attention_matrix)
        # Same encoding as the model input (special tokens included), so
        # label i corresponds to attention row / vertex i.
        tokens = _decode_each_token(tokenizer, text)
        plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens, highlight_tokens)


def _decode_each_token(tokenizer, text, add_special_tokens=True):
    """Encode ``text`` and decode each token id on its own (one string per token)."""
    return [tokenizer.decode(token_id) for token_id in tokenizer.encode(text, add_special_tokens=add_special_tokens)]

# Control panel, listed top to bottom. The manual "Update Plot" button is
# not part of this initial list.
_controls = [model_dropdown, text_input1, text_input2, subset_input, layer_slider, head_slider, threshold_slider]
ui = VBox(_controls)

# Pass the subset_input value to update_plot function
def wrapped_update_plot(*args, **kwargs):
    return update_plot(*args, subset=subset_input.value, **kwargs)

# Re-run wrapped_update_plot whenever any of these controls changes. Each
# keyword name matches the update_plot parameter it feeds.
interactive_plot = interactive(
    wrapped_update_plot,
    threshold=threshold_slider,
    text1=text_input1,
    text2=text_input2,
    layer=layer_slider,
    head=head_slider,
    model_name=model_dropdown,
)

# Manual refresh: the subset text box is not an interactive control, so
# editing it alone does not redraw the figures.
update_button = Button(description="Update Plot")


def on_update_button_click(_):
    """Click handler: ask ``interactive_plot`` to recompute (the button argument is unused)."""
    interactive_plot.update()


update_button.on_click(on_update_button_click)

# Put the refresh button at the bottom of the control panel and render it.
ui.children = (*ui.children, update_button)
display(ui)

# Last expression of the cell, so Jupyter renders it: the interactive's final
# child widget, which is where the figures appear.
interactive_plot.children[-1]


VBox(children=(Dropdown(description='Model:', options=(('GPT-2', 'gpt2'), ('DistilBERT', 'distilbert-base-unca…

Output()