# Interactive Simplicial Complex from Text

This notebook builds an interactive simplicial complex from text, using the persistent homology of an information-theoretic analysis of transformer attention.

In [12]:
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2Model
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
import ipywidgets as widgets


def get_attention_matrix(text, model, tokenizer, layer, head):
    """Return the attention weights of one head for ``text`` as a NumPy array.

    Args:
        text: Input string handed to ``tokenizer``.
        model: Callable invoked as ``model(**inputs, output_attentions=True)``;
            its result must expose an ``attentions`` sequence indexed by layer.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result is unpacked as keyword arguments for ``model``.
        layer: Index into ``outputs.attentions``.
        head: Index of the attention head inside the selected layer.

    Returns:
        ``outputs.attentions[layer][0, head]`` (first batch element only),
        detached, moved to CPU and converted to a NumPy array.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Inference only: the result is detached right away, so there is no reason
    # to let autograd record the forward pass.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention


def compute_persistence(attention_matrix):
    """Build a Rips complex over the rows of an attention matrix and compute its persistence.

    Each row is passed through a softmax, pairwise distances between rows are
    taken as the square root of SciPy's ``jensenshannon`` value, and the
    resulting distance matrix is fed to ``gudhi.RipsComplex``.

    NOTE(review): SciPy documents ``jensenshannon`` as already returning the
    square root of the Jensen-Shannon divergence, so the extra ``np.sqrt``
    yields a fourth root of the divergence - confirm this is intended.
    NOTE(review): if ``attention_matrix`` comes from a transformer's attention
    output its rows are presumably already softmax-normalised; confirm that
    re-applying softmax is intended.

    Args:
        attention_matrix: 2-D array; one distribution is derived per row.

    Returns:
        Tuple ``(persistence, simplex_tree, distance_matrix)`` where
        ``persistence`` is the result of
        ``simplex_tree.persistence(min_persistence=0.01)``, ``simplex_tree`` is
        the Rips simplex tree built with ``max_dimension=2`` and
        ``distance_matrix`` is the square pairwise distance array.
    """
    # Exponentiate once and reuse it for both numerator and denominator.
    exp_attention = np.exp(attention_matrix)
    softmax_attention = exp_attention / np.sum(exp_attention, axis=-1)[:, np.newaxis]

    n_rows = softmax_attention.shape[0]
    distance_matrix = np.array([
        [np.sqrt(jensenshannon(softmax_attention[i], softmax_attention[j])) for j in range(n_rows)]
        for i in range(n_rows)
    ])

    # For (almost) identical rows, rounding can push the divergence slightly
    # below zero, which makes the square root NaN. A NaN distance would make
    # the Rips construction drop an edge that should sit at distance ~0, so map
    # those entries to 0. NaNs that originate from non-finite input are left
    # untouched so they stay visible.
    if np.isfinite(softmax_attention).all():
        distance_matrix[np.isnan(distance_matrix)] = 0.0

    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix


def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    """Show the edges of the complex with distance <= ``threshold`` as a 3-D Plotly figure.

    Only vertices that are an endpoint of at least one retained edge are added
    to the graph, so tokens without a qualifying edge are not drawn.

    Args:
        simplex_tree: Object whose ``get_filtration()`` yields
            ``(simplex, filtration_value)`` pairs.
        distance_matrix: Square pairwise distances, indexed by vertex id.
        title: Figure title (also used in the "no edges" message).
        threshold: Maximum distance for an edge to be drawn.
        tokens: Sequence mapping vertex id -> label text.

    Returns:
        None. The figure is displayed via ``fig.show()``; if no edge passes
        the threshold a message is printed instead of an empty figure.
    """
    g = nx.Graph()
    for simplex, _ in simplex_tree.get_filtration():
        # Keep 1-simplices (edges) only, filtered by their pairwise distance.
        if len(simplex) == 2 and distance_matrix[simplex[0]][simplex[1]] <= threshold:
            g.add_edge(simplex[0], simplex[1])

    # An empty graph would otherwise render as a blank, unexplained figure.
    if g.number_of_edges() == 0:
        print(f"No edges in the graph for '{title}'. Try adjusting the threshold value.")
        return

    labels = {node: tokens[node] for node in g.nodes()}
    # Fixed seed keeps the layout reproducible between slider updates.
    pos = nx.spring_layout(g, dim=3, seed=42)
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]

    # Each edge contributes (start, end, None); None breaks the line between edges.
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]

    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")

    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()


# Load the pretrained GPT-2 tokenizer and model. These are deliberately not
# wrapped in an `if __name__ == '__main__':` guard: every statement below uses
# them unconditionally, so guarding only the loading would just turn a
# non-main import into a NameError.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Two texts to compare.
text1 = "Quantum information theory is interesting"
text2 = "Quantum information theory allows us to study attention using entanglement"

# Attention layer / head to analyse.
layer = 1
head = 2

attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

# One label per encoded token id, used as node text in the plots.
tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

# continuous_update=False: redraw only when the slider is released.
threshold_slider = widgets.FloatSlider(
    value=0.05,
    min=0.01,
    max=0.4,
    step=0.01,
    description='Threshold:',
    continuous_update=False
)

def update_plot(threshold):
    """Redraw both complexes for the given distance ``threshold``."""
    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)

widgets.interact(update_plot, threshold=threshold_slider)


interactive(children=(FloatSlider(value=0.05, continuous_update=False, description='Threshold:', max=0.4, min=…

<function __main__.update_plot(threshold)>

In [9]:
import numpy as np
import torch
from transformers import GPT2Tokenizer, GPT2Model
import gudhi as gd
import matplotlib.pyplot as plt
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
import ipywidgets as widgets
from IPython.display import display


def get_attention_matrix(text, model, tokenizer, layer, head):
    """Extract a single head's attention weights for ``text``.

    Args:
        text: Input string handed to ``tokenizer``.
        model: Callable invoked as ``model(**inputs, output_attentions=True)``;
            its result must expose an ``attentions`` sequence indexed by layer.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result is unpacked as keyword arguments for ``model``.
        layer: Index into ``outputs.attentions``.
        head: Index of the attention head inside the selected layer.

    Returns:
        NumPy array holding ``outputs.attentions[layer][0, head]`` (first
        batch element only).
    """
    inputs = tokenizer(text, return_tensors="pt")
    # No gradients are needed for a pure read-out of the attention weights.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention

def compute_persistence(attention_matrix):
    """Compute Rips persistence from pairwise distances between attention rows.

    Rows are softmax-normalised, compared pairwise with the square root of
    SciPy's ``jensenshannon`` value, and the distance matrix is handed to
    ``gudhi.RipsComplex``.

    NOTE(review): SciPy documents ``jensenshannon`` as already being the
    square root of the divergence; the additional ``np.sqrt`` here is kept
    as-is - confirm it is intended.
    NOTE(review): attention rows from a transformer are presumably already
    probabilities; confirm the extra softmax is intended.

    Args:
        attention_matrix: 2-D array; one distribution is derived per row.

    Returns:
        Tuple ``(persistence, simplex_tree, distance_matrix)``: the output of
        ``simplex_tree.persistence(min_persistence=0.01)``, the simplex tree
        (built with ``max_dimension=2``) and the square distance array.
    """
    # Compute the exponentials a single time.
    exp_attention = np.exp(attention_matrix)
    softmax_attention = exp_attention / np.sum(exp_attention, axis=-1)[:, np.newaxis]

    n_rows = softmax_attention.shape[0]
    distance_matrix = np.array([
        [np.sqrt(jensenshannon(softmax_attention[i], softmax_attention[j])) for j in range(n_rows)]
        for i in range(n_rows)
    ])

    # Near-identical rows can give a slightly negative divergence through
    # rounding, and its square root is NaN. Treat those as distance 0 so the
    # corresponding edge is not lost; leave NaNs alone when the input itself
    # is non-finite.
    if np.isfinite(softmax_attention).all():
        distance_matrix[np.isnan(distance_matrix)] = 0.0

    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix

def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    """Render the thresholded 1-skeleton of the complex as an interactive 3-D figure.

    Vertices are only added as endpoints of retained edges, so tokens with no
    edge at or below ``threshold`` do not appear.

    Args:
        simplex_tree: Object whose ``get_filtration()`` yields
            ``(simplex, filtration_value)`` pairs.
        distance_matrix: Square pairwise distances, indexed by vertex id.
        title: Figure title (also used in the "no edges" message).
        threshold: Maximum distance for an edge to be drawn.
        tokens: Sequence mapping vertex id -> label text.

    Returns:
        None. Shows the figure, or prints a hint when no edge qualifies.
    """
    g = nx.Graph()
    for simplex, _ in simplex_tree.get_filtration():
        # Edges only; higher simplices and single vertices are skipped.
        if len(simplex) == 2 and distance_matrix[simplex[0]][simplex[1]] <= threshold:
            g.add_edge(simplex[0], simplex[1])

    # Without this guard an empty graph produces a blank figure with no hint why.
    if g.number_of_edges() == 0:
        print(f"No edges in the graph for '{title}'. Try adjusting the threshold value.")
        return

    # Create a token dictionary only for nodes in the graph
    labels = {node: tokens[node] for node in g.nodes()}

    # 3D layout (seeded so repeated draws place nodes identically)
    pos = nx.spring_layout(g, dim=3, seed=42)

    # Extract node coordinates
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]

    # Extract edge coordinates; the trailing None separates consecutive segments
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]

    # Create a trace for edges
    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))

    # Create a trace for nodes
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")

    # Create a layout
    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)

    # Create a plot
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()


from transformers import BertTokenizer, BertModel

# Load pre-trained transformer model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained('bert-base-multilingual-cased')

# Text inputs
text1 = "Quantum information theory is interesting"
text2 = "תיאוריה קוונטית של מידע מעניינת"  # Hebrew translation: "Quantum information theory is interesting"


# Choose a layer and head
layer = 1
head = 2

# Get attention matrices
attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

# Compute persistence and simplex trees
persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)


# Get one label per encoded token id (decoding each id of `tokenizer.encode`),
# so the labels come from the same encoding path as the model input.
tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

# Create a slider for the threshold value
threshold_slider = widgets.FloatSlider(
    value=0.37,
    min=0.37,
    max=0.6,
    step=0.001,
    description='Threshold:',
    continuous_update=False
)

# Define a function to update the plot based on the threshold value from the slider
def update_plot(threshold):
    """Redraw both complexes for the given distance ``threshold``."""
    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)

# Connect the slider to the update_plot function
widgets.interact(update_plot, threshold=threshold_slider)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


interactive(children=(FloatSlider(value=0.37, continuous_update=False, description='Threshold:', max=0.6, min=…

<function __main__.update_plot(threshold)>

Here is a slightly different version that handles plots where there are no edges connecting nodes (that is, the $1$-skeleton of the simplicial complex contains no $1$-simplices at the chosen threshold).

In [15]:
import numpy as np
import torch
from transformers import BertTokenizer, BertModel
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
import ipywidgets as widgets


def get_attention_matrix(text, model, tokenizer, layer, head):
    """Run ``model`` on ``text`` and return one attention head as a NumPy array.

    Args:
        text: Input string handed to ``tokenizer``.
        model: Callable invoked as ``model(**inputs, output_attentions=True)``;
            its result must expose an ``attentions`` sequence indexed by layer.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt")``;
            its result is unpacked as keyword arguments for ``model``.
        layer: Index into ``outputs.attentions``.
        head: Index of the attention head inside the selected layer.

    Returns:
        ``outputs.attentions[layer][0, head]`` (first batch element only) as a
        detached CPU NumPy array.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # The weights are only read, never back-propagated through, so disable
    # gradient tracking for the forward pass.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention


def compute_persistence(attention_matrix):
    """Turn an attention matrix into a Rips simplex tree and its persistence.

    Each row is softmax-normalised; the distance between two rows is the
    square root of SciPy's ``jensenshannon`` value; the resulting matrix is
    passed to ``gudhi.RipsComplex``.

    NOTE(review): SciPy documents ``jensenshannon`` as already returning the
    square root of the divergence, so the outer ``np.sqrt`` gives a fourth
    root - confirm this is intended.
    NOTE(review): rows of a transformer attention output are presumably
    already softmax-normalised; confirm the second softmax is intended.

    Args:
        attention_matrix: 2-D array; one distribution is derived per row.

    Returns:
        Tuple ``(persistence, simplex_tree, distance_matrix)``: the result of
        ``simplex_tree.persistence(min_persistence=0.01)``, the simplex tree
        built with ``max_dimension=2``, and the square distance array.
    """
    # One exponentiation shared by numerator and denominator.
    exp_attention = np.exp(attention_matrix)
    softmax_attention = exp_attention / np.sum(exp_attention, axis=-1, keepdims=True)

    n_rows = softmax_attention.shape[0]
    distance_matrix = np.array([
        [np.sqrt(jensenshannon(softmax_attention[i], softmax_attention[j])) for j in range(n_rows)]
        for i in range(n_rows)
    ])

    # Rounding can make the divergence of near-identical rows slightly
    # negative, turning the distance into NaN and silently removing that edge
    # from the complex. Map such entries to 0; NaNs caused by non-finite input
    # are left in place.
    if np.isfinite(softmax_attention).all():
        distance_matrix[np.isnan(distance_matrix)] = 0.0

    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix


def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    """Display the edges with distance <= ``threshold`` as a labelled 3-D graph.

    Only vertices that end up as an endpoint of a retained edge are drawn.
    If no edge qualifies, a message is printed and nothing is plotted.

    Args:
        simplex_tree: Object whose ``get_filtration()`` yields
            ``(simplex, filtration_value)`` pairs.
        distance_matrix: Square pairwise distances, indexed by vertex id.
        title: Figure title (also used in the "no edges" message).
        threshold: Maximum distance for an edge to be drawn.
        tokens: Sequence mapping vertex id -> label text.

    Returns:
        None.
    """
    graph = nx.Graph()
    for simplex, _filtration_value in simplex_tree.get_filtration():
        if len(simplex) != 2:
            continue  # not an edge
        u, v = simplex[0], simplex[1]
        if distance_matrix[u][v] <= threshold:
            graph.add_edge(u, v)

    if graph.number_of_edges() == 0:
        print(f"No edges in the graph for '{title}'. Try adjusting the threshold value.")
        return

    node_text = [tokens[node] for node in graph.nodes()]

    # Seeded spring layout in three dimensions; keys follow graph.nodes() order.
    coords = nx.spring_layout(graph, dim=3, seed=42)
    node_x, node_y, node_z = zip(*coords.values())

    # Every edge adds (start, end, None); None splits the polyline between edges.
    edge_x, edge_y, edge_z = [], [], []
    for a, b in graph.edges():
        start, end = coords[a], coords[b]
        edge_x.extend([start[0], end[0], None])
        edge_y.extend([start[1], end[1], None])
        edge_z.extend([start[2], end[2], None])

    edge_trace = go.Scatter3d(
        x=edge_x, y=edge_y, z=edge_z,
        mode='lines',
        line=dict(color='gray', width=1),
    )
    node_trace = go.Scatter3d(
        x=node_x, y=node_y, z=node_z,
        mode='markers+text',
        text=node_text,
        marker=dict(symbol='circle', size=10, color='lightblue'),
        textposition="top center",
    )
    figure_layout = go.Layout(
        title=title,
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        showlegend=False,
    )
    go.Figure(data=[edge_trace, node_trace], layout=figure_layout).show()


# Multilingual BERT checkpoint used for both the English and the Hebrew text.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained('bert-base-multilingual-cased')

text1 = "Quantum information theory is interesting"
text2 = "תיאוריה קוונטית של מידע מעניינת"

# Attention layer / head under study.
layer, head = 1, 2

# Text 1: attention -> distances -> Rips persistence.
attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)

# Text 2: same pipeline.
attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)
persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

# Node labels: decode every id produced by `tokenizer.encode` individually.
tokens1 = [tokenizer.decode(tid) for tid in tokenizer.encode(text1)]
tokens2 = [tokenizer.decode(tid) for tid in tokenizer.encode(text2)]

# Threshold control; plots refresh only when the slider is released.
threshold_slider = widgets.FloatSlider(
    value=0.37,
    min=0.37,
    max=0.6,
    step=0.001,
    description='Threshold:',
    continuous_update=False,
)


def update_plot(threshold):
    """Redraw the complexes of both texts for the given ``threshold``."""
    panels = (
        (simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", tokens1),
        (simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", tokens2),
    )
    for tree, distances, panel_title, panel_tokens in panels:
        plot_simplicial_complex_3d(tree, distances, panel_title, threshold, panel_tokens)


widgets.interact(update_plot, threshold=threshold_slider)


Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


interactive(children=(FloatSlider(value=0.37, continuous_update=False, description='Threshold:', max=0.6, min=…

<function __main__.update_plot(threshold)>