In [2]:
!pip install ipywidgets -q

# Interactive Simplicial Complex GUI

Below, we provide an interactive GUI for exploring the simplicial complexes in the filtration: a slider sets the filtration parameter `threshold`, and the plots show the simplicial complex at that value.

In [5]:
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2Model
import gudhi as gd
import matplotlib.pyplot as plt
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go


def get_attention_matrix(text, model, tokenizer, layer, head):
    """Return the attention weights of one head of one layer for ``text``.

    Args:
        text: Input passed to ``tokenizer``.
        model: Callable invoked as ``model(**inputs, output_attentions=True)``;
            its result must expose an ``attentions`` sequence of tensors.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``
            whose result is unpacked as keyword arguments for ``model``.
        layer: Index into ``outputs.attentions``.
        head: Index along the second axis of the selected attention tensor.

    Returns:
        ``outputs.attentions[layer][0, head]`` (first batch element) as a
        NumPy array.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Inference only: the result is detached and converted to NumPy, so there
    # is no reason to record an autograd graph during the forward pass.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention

def compute_persistence(attention_matrix):
    """Compute Rips persistence over the rows of ``attention_matrix``.

    Each row is passed through a softmax, pairwise Jensen-Shannon distances
    between the resulting rows form a distance matrix, and a Rips complex
    (simplices up to dimension 2, no edge-length cap) is built from it.

    Args:
        attention_matrix: 2-D array-like; one distribution per row.

    Returns:
        Tuple ``(persistence, simplex_tree, distance_matrix)`` where
        ``persistence`` is the result of
        ``simplex_tree.persistence(min_persistence=0.01)`` and
        ``distance_matrix`` is a symmetric ``(n_rows, n_rows)`` array with a
        zero diagonal.

    Raises:
        ValueError: If ``attention_matrix`` is not 2-D or contains NaN.
    """
    attention = np.asarray(attention_matrix)
    if attention.ndim != 2:
        raise ValueError(
            f"attention_matrix must be 2-D, got {attention.ndim} dimension(s)"
        )
    if np.isnan(attention).any():
        raise ValueError("attention_matrix contains NaN")

    # NOTE(review): if attention_matrix already holds softmax probabilities
    # (as the model's attention outputs presumably do), this second softmax
    # flattens them - confirm that this is intended.
    # Subtracting the row maximum leaves the softmax unchanged mathematically
    # but keeps np.exp from overflowing on large inputs.
    if attention.size:
        attention = attention - attention.max(axis=-1, keepdims=True)
    exp_attention = np.exp(attention)
    softmax_attention = exp_attention / exp_attention.sum(axis=-1, keepdims=True)

    # The distance is symmetric and zero on the diagonal, so only the upper
    # triangle is computed and then mirrored.
    n_rows = softmax_attention.shape[0]
    distance_matrix = np.zeros((n_rows, n_rows))
    for i in range(n_rows):
        for j in range(i + 1, n_rows):
            distance = jensenshannon(softmax_attention[i], softmax_attention[j])
            # For (nearly) identical rows, round-off can make the divergence
            # slightly negative so its square root comes out as NaN; the
            # inputs are NaN-free here, so treat that case as distance zero.
            if np.isnan(distance):
                distance = 0.0
            distance_matrix[i, j] = distance_matrix[j, i] = distance

    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix

def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    """Show the vertices and edges of the complex at ``threshold`` as a 3-D plot.

    Every vertex in the filtration is drawn, including vertices that have no
    edge at this threshold; an edge is drawn when the distance between its
    endpoints is at most ``threshold``. Simplices of dimension two or more
    are not drawn.

    Args:
        simplex_tree: Object whose ``get_filtration()`` yields
            ``(simplex, filtration_value)`` pairs, ``simplex`` being a
            sequence of vertex indices.
        distance_matrix: Indexable as ``distance_matrix[i][j]``; distance
            between vertices ``i`` and ``j``.
        title: Figure title.
        threshold: Largest edge length to include.
        tokens: Sequence indexed by vertex, giving each vertex's text label.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    g = nx.Graph()
    for simplex, _ in simplex_tree.get_filtration():
        if len(simplex) == 1:
            # Vertices belong to the complex at every threshold, so add them
            # explicitly; otherwise tokens with no short edge would vanish.
            g.add_node(simplex[0])
        elif len(simplex) == 2:
            u, v = simplex
            if distance_matrix[u][v] <= threshold:
                g.add_edge(u, v)

    # Fix one node order so coordinates and labels stay aligned.
    nodes = list(g.nodes())
    labels = [tokens[node] for node in nodes]

    # 3D layout (seeded so the picture is reproducible)
    pos = nx.spring_layout(g, dim=3, seed=42)

    # Node coordinates, one list per axis
    Xn, Yn, Zn = ([pos[node][axis] for node in nodes] for axis in range(3))

    # Edge coordinates: each edge contributes its two endpoints followed by
    # None, which breaks the line so consecutive edges are not joined.
    Xe, Ye, Ze = [], [], []
    for u, v in g.edges():
        Xe += [pos[u][0], pos[v][0], None]
        Ye += [pos[u][1], pos[v][1], None]
        Ze += [pos[u][2], pos[v][2], None]

    trace_edges = go.Scatter3d(
        x=Xe, y=Ye, z=Ze,
        mode='lines',
        line=dict(color='gray', width=1),
    )
    trace_nodes = go.Scatter3d(
        x=Xn, y=Yn, z=Zn,
        mode='markers+text',
        text=labels,
        marker=dict(symbol='circle', size=10, color='lightblue'),
        textposition="top center",
    )
    layout = go.Layout(
        title=title,
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        showlegend=False,
    )

    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()


# Load pre-trained transformer model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Text inputs
text1 = "Quantum information theory is interesting"
text2 = "Quantum information theory allows us to study attention using entanglement"

# Choose a layer and head
layer = 1
head = 2

# Get attention matrices
attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

# Compute persistence and simplex trees
persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

# Threshold value used for the two static plots below
threshold = 0.1

# One label per encoded position: decode every token id on its own so that
# label k can be looked up by vertex index k in the plots.
# NOTE(review): assumes tokenizer.encode(text) yields the same ids, in the
# same order, as the tokenizer(text) call inside get_attention_matrix - confirm.
tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)

In [4]:
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2Model
import gudhi as gd
import matplotlib.pyplot as plt
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display

# Slider that selects the filtration threshold. With continuous_update off,
# the new value is only reported once the handle is released.
threshold_slider = widgets.FloatSlider(
    description='Threshold:',
    min=0.01,
    max=0.2,
    step=0.01,
    value=0.05,
    continuous_update=False,
)

# Callback for the slider: redraw both complexes at the selected threshold
def update_plot(threshold):
    """Plot the complexes of both texts at ``threshold``, text 1 first."""
    panels = (
        (simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", tokens1),
        (simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", tokens2),
    )
    for tree, distances, heading, labels in panels:
        plot_simplicial_complex_3d(tree, distances, heading, threshold, labels)

# Connect the slider to the update_plot function. The trailing semicolon
# keeps the notebook from echoing the function object that interact() returns.
widgets.interact(update_plot, threshold=threshold_slider);

interactive(children=(FloatSlider(value=0.05, continuous_update=False, description='Threshold:', max=0.2, min=…

<function __main__.update_plot(threshold)>