# Persistent Homology of Attention

![simplicial_complex_2.png](simplicial_complex_2.png)

This notebook visualizes the attention mechanism of pre-trained transformer models using simplicial complexes. For a text input, it applies a softmax to each row of a chosen attention matrix, so that each token is represented by a probability distribution over the input tokens. It then groups these distributions by proximity under a Jensen-Shannon-based distance (which could be replaced by another distance between distributions). Next, it computes the persistent homology and, using a slider, plots the simplicial complex at the selected value of the scale parameter. At each scale, the $1$-skeleton of a new simplicial complex is plotted. This is a graph that connects tokens whose distributions are close, and the full simplicial complex carries higher-dimensional information about how the distributions are related. The same approach could be applied to vector embeddings of tokens with the Euclidean distance (or a similar metric on vectors).

The notebook consists of functions for attention-matrix computation, persistence computation, and 3D visualization, together with interactive widgets for adjusting the input parameters.
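For reference, here is a minimal sketch of the Jensen-Shannon distance on two hypothetical distributions. The arrays `p`, `q`, and `row` and the helper `js_divergence` are illustrative only. The sketch computes the divergence by hand with natural logarithms (SciPy's default base) and checks that `scipy.spatial.distance.jensenshannon` returns its square root. This means the notebook's extra `np.sqrt(jensenshannon(...))` stores the fourth root of the divergence. The sketch also shows that a softmax applied to a row that is already a probability distribution pulls it toward uniform.

```python
import numpy as np
from scipy.spatial.distance import jensenshannon


def js_divergence(p, q):
    """Jensen-Shannon divergence with natural logarithms (SciPy's default base)."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = p / p.sum(), q / q.sum()
    m = 0.5 * (p + q)

    def kl(a, b):
        # Kullback-Leibler divergence; entries with a == 0 contribute nothing
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


# Two hypothetical attention rows (each already sums to 1)
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.1, 0.3, 0.6])

scipy_distance = jensenshannon(p, q)            # square root of the divergence
manual_distance = np.sqrt(js_divergence(p, q))  # should match scipy_distance
notebook_value = np.sqrt(jensenshannon(p, q))   # what compute_persistence stores
print(scipy_distance, manual_distance, notebook_value)

# A softmax of a row that is already a distribution flattens it toward uniform
row = np.array([0.9, 0.05, 0.05])
resoftmaxed = np.exp(row) / np.exp(row).sum()
print(row.max(), resoftmaxed.max())
```

Taking a square root of a metric gives another metric, so the Rips construction remains valid. However, this stretches small distances, which affects the useful range of the `threshold` slider.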

The code uses persistent homology to analyze the structure of attention mechanisms in transformer models by visualizing the relationships between tokens as simplicial complexes. Persistent homology is a mathematical method that studies the topological features of a shape across different scales. It provides a measure of the significance of topological features such as connected components, loops, and voids, which helps to understand the underlying structure in data.
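To make "connected components persisting across scales" concrete, here is a minimal sketch of dimension-0 persistence for a Rips filtration built from a distance matrix, written without Gudhi. It assumes that every point is born at scale 0, as vertices are in a Rips complex. It then processes edges in order of length with a union-find structure, and each merge of two components kills one of them. This is the same procedure as single-linkage clustering. The function name `h0_persistence` and the toy matrix `D` are hypothetical.

```python
import itertools
import math

import numpy as np


def h0_persistence(distance_matrix):
    """(birth, death) pairs of connected components in a Rips filtration."""
    D = np.asarray(distance_matrix, dtype=float)
    n = D.shape[0]
    parent = list(range(n))

    def find(i):
        # Union-find root lookup with path halving
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Edges enter the filtration in order of length
    edges = sorted((D[i, j], i, j) for i, j in itertools.combinations(range(n), 2))
    pairs = []
    for length, i, j in edges:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            parent[root_j] = root_i
            pairs.append((0.0, float(length)))  # one component born at 0 dies here
    if n > 0:
        pairs.append((0.0, math.inf))  # the last surviving component never dies
    return pairs


# Hypothetical distances between three token distributions
D = [[0.0, 0.1, 0.4],
     [0.1, 0.0, 0.3],
     [0.4, 0.3, 0.0]]
print(h0_persistence(D))
```

Here, all three components are born at 0. Two of them die when the edges of length 0.1 and 0.3 appear, and the last one never dies, giving the pairs `(0.0, 0.1)`, `(0.0, 0.3)`, and `(0.0, inf)`.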

In this code, persistent homology is used in the following way:

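1. The `compute_persistence()` function computes the persistence, the simplex tree, and the distance matrix from the attention matrix. First, each row of the attention matrix is passed through a softmax. The attention weights returned by Hugging Face models are already softmax-normalized, so this second softmax flattens each row toward the uniform distribution. Then, SciPy's `jensenshannon` computes the Jensen-Shannon distance between every pair of rows, producing a distance matrix. Because `jensenshannon` already returns the distance (the square root of the Jensen-Shannon divergence), the extra `np.sqrt` in the code yields the fourth root of the divergence. This is still a metric, but it stretches small distances.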

2. The distance matrix is used to construct a Vietoris-Rips complex with the `gd.RipsComplex()` function from the Gudhi library. At scale $\varepsilon$, the Rips complex contains every simplex whose vertices are pairwise at distance at most $\varepsilon$. Since the code passes `max_edge_length=np.inf`, no edges are cut off, and the complexes for all scales are built.

3. The Rips complex is then converted into a simplex tree with `create_simplex_tree(max_dimension=2)`, which keeps vertices, edges, and triangles. A simplex tree is a data structure that represents a filtered simplicial complex, in which each simplex carries a filtration value. For a Rips complex, this is the largest pairwise distance among the simplex's vertices, that is, the scale at which the simplex enters the complex.

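4. The persistence of the topological features is computed with the `persistence()` method. The persistence of a feature is its lifetime (death minus birth) across filtration values, and long-lived features are considered significant. The call uses `min_persistence=0.01`, so features that live for 0.01 or less are discarded. Because the complex stops at triangles, only connected components ($H_0$) and loops ($H_1$) are reported.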

5. The `plot_simplicial_complex_3d()` function draws a 3D view of the simplicial complex from the simplex tree and distance matrix. Only the $1$-skeleton is drawn: tokens are nodes, and an edge joins two tokens whose distance is at most the value of the `threshold` slider. Node positions come from a 3D spring layout (`nx.spring_layout`), so the connectivity reflects the distances, but the positions themselves do not.
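Steps 2, 3, and 5 can be sketched without Gudhi. The following minimal sketch, with hypothetical function names, does three things:

- It enumerates the Vietoris-Rips simplices up to triangles, as `create_simplex_tree(max_dimension=2)` is asked to do.
- It assigns each simplex the filtration value a Rips complex uses.
- It lists the edges drawn at one slider value.

It is a brute-force illustration of the filtration, not how Gudhi's simplex tree stores it.

```python
import itertools

import numpy as np


def rips_filtration(distance_matrix, max_dimension=2):
    """List every Rips simplex up to `max_dimension` with its filtration value.

    A vertex enters at 0; a higher simplex enters at the largest pairwise
    distance among its vertices. Exponential in the number of points, so
    only suitable for tiny examples.
    """
    D = np.asarray(distance_matrix, dtype=float)
    n = D.shape[0]
    simplices = []
    for k in range(1, max_dimension + 2):  # k vertices form a (k - 1)-simplex
        for verts in itertools.combinations(range(n), k):
            if k == 1:
                value = 0.0
            else:
                value = max(D[i, j] for i, j in itertools.combinations(verts, 2))
            simplices.append((list(verts), float(value)))
    # Order by filtration value, breaking ties by dimension so faces precede cofaces
    simplices.sort(key=lambda s: (s[1], len(s[0])))
    return simplices


def threshold_edges(distance_matrix, threshold):
    """Edges of the 1-skeleton at one scale: pairs at distance <= threshold."""
    D = np.asarray(distance_matrix, dtype=float)
    return [(i, j) for i, j in itertools.combinations(range(D.shape[0]), 2)
            if D[i, j] <= threshold]


# Hypothetical distances between three token distributions
D = [[0.0, 0.1, 0.4],
     [0.1, 0.0, 0.3],
     [0.4, 0.3, 0.0]]

for simplex, value in rips_filtration(D):
    print(simplex, value)
print(threshold_edges(D, 0.3))
```

The triangle `[0, 1, 2]` enters at 0.4, together with its longest edge. At a threshold of 0.3, the three tokens therefore form a path, `(0, 1)` and `(1, 2)`, rather than a filled triangle.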

By using persistent homology, the code provides a way to study the attention mechanism's structure in transformer models and visualize how tokens are related to each other.

In [5]:
%load_ext autoreload
%autoreload 2



In [6]:
%pip install nbconvert -q



In [2]:
%pip install numpy torch transformers gudhi matplotlib networkx scipy plotly ipywidgets -q



In [7]:
!jupyter nbconvert --to html --template=nbextensions --ExecutePreprocessor.enabled=True --ExecutePreprocessor.timeout=120 notebook.ipynb


In [3]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
from ipywidgets import interact, FloatSlider, IntSlider, Text, Dropdown, VBox, Label
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)


def plot_persistence_diagram(persistence, title):
    gd.plot_persistence_diagram(persistence)
    plt.title(title)
    plt.show()
    
def compute_bottleneck_distance(persistence1, persistence2):
    persistence1_array = np.array([(birth, death) for dim, (birth, death) in persistence1 if dim == 0])
    persistence2_array = np.array([(birth, death) for dim, (birth, death) in persistence2 if dim == 0])
    return gd.bottleneck_distance(persistence1_array, persistence2_array)

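def get_attention_matrix(text, model, tokenizer, layer, head):
    """Return the (seq_len x seq_len) attention weights of one head for `text`."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs, output_attentions=True)
    # outputs.attentions[layer] has shape (batch, heads, seq_len, seq_len);
    # these weights are already softmax-normalized along the last axis.
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention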

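def compute_persistence(attention_matrix):
    # Row-wise softmax. The attention weights are already softmax-normalized,
    # so this second softmax flattens each row toward uniform.
    exp_attention = np.exp(attention_matrix)
    softmax_attention = exp_attention / exp_attention.sum(axis=-1, keepdims=True)

    # Pairwise distances between token distributions. SciPy's jensenshannon already
    # returns the JS distance (sqrt of the divergence); the extra np.sqrt gives the
    # fourth root of the divergence, which is still a metric.
    n = softmax_attention.shape[0]
    distance_matrix = np.array([[np.sqrt(jensenshannon(softmax_attention[i], softmax_attention[j]))
                                 for j in range(n)] for i in range(n)])

    # Full Vietoris-Rips filtration (no edge-length cutoff) up to triangles
    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    # Discard features whose lifetime (death - birth) is at most 0.01
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix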

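def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    # 1-skeleton at this scale: every token is a vertex, and an edge joins two
    # tokens whose distance is at most `threshold`.
    g = nx.Graph()
    g.add_nodes_from(range(len(distance_matrix)))  # keep isolated tokens visible
    for (simplex, _) in simplex_tree.get_filtration():
        if len(simplex) == 2:
            if distance_matrix[simplex[0]][simplex[1]] <= threshold:
                g.add_edge(simplex[0], simplex[1])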

    labels = {node: tokens[node] for node in g.nodes()}
    
    pos = nx.spring_layout(g, dim=3, seed=42)
    
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]
    
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]
    
    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")
    
    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)
    
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()

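from functools import lru_cache


@lru_cache(maxsize=2)  # avoid reloading weights from disk on every widget change
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model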

model_dropdown = Dropdown(
    options=[
        ('GPT-2', 'gpt2'),
        ('DistilBERT', 'distilbert-base-uncased'),
        ('Bert', 'bert-base-uncased'),
        ('RoBERTa', 'roberta-base'),
        ('aleph-BeRT', 'onlplab/alephbert-base')
    ],
    value='gpt2',
    description='Model:'
)

tokenizer, model = load_model(model_dropdown.value)

text_input1 = Text(description='Text 1:', value='Quantum information theory is fascinating')
text_input2 = Text(description='Text 2:', value='Quantum information theory allows us to study attention using entanglement')

layer_slider = IntSlider(value=1, min=0, max=model.config.num_hidden_layers - 1, description='Layer:', continuous_update=False)
head_slider = IntSlider(value=2, min=0, max=model.config.num_attention_heads - 1, description='Head:', continuous_update=False)

threshold_slider = FloatSlider(value=0.05, min=0.00, max=0.5, step=0.001, description='Threshold:', continuous_update=False)

# Recompute attention, persistence, and plots whenever a widget changes
def update_plot(threshold, text1, text2, layer, head, model_name):
    tokenizer, model = load_model(model_name)
    layer_slider.max = model.config.num_hidden_layers - 1
    head_slider.max = model.config.num_attention_heads - 1

    attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
    attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

    persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
    persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

    # Decode each input id separately so the labels line up one-to-one with the
    # rows of the attention matrix (including special tokens such as [CLS]).
    tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
    tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)
    plot_persistence_diagram(persistence1, "Persistence Diagram for Text 1")
    plot_persistence_diagram(persistence2, "Persistence Diagram for Text 2")
    # Compute and display bottleneck distance between persistence diagrams
    bottleneck_distance = compute_bottleneck_distance(persistence1, persistence2)
    print("Bottleneck distance between Text 1 and Text 2:", bottleneck_distance)

interact(update_plot, threshold=threshold_slider, text1=text_input1, text2=text_input2, layer=layer_slider, head=head_slider, model_name=model_dropdown)



The persistence diagrams above encode the topological information from persistent homology at all scales simultaneously. Each dot is a topological feature, either a connected component ($H_0$) or a loop ($H_1$). A feature has a `birth` and a `death` coordinate, which are the values of the scale parameter (the same scale as the `threshold` slider) at which it appears and disappears. Dots far from the diagonal persist over a wide range of scales, while dots close to it are short-lived.
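The bottleneck distance printed by `update_plot` compares two such diagrams. It is the smallest $\varepsilon$ for which the points of one diagram can be matched to points of the other diagram, or to the diagonal, with every matched pair within $\varepsilon$ in the $L^\infty$ norm. The following brute-force sketch implements that definition for tiny diagrams of finite points only. It enumerates every matching, so its cost is exponential. It is an illustration of the definition, not Gudhi's `bottleneck_distance` algorithm, and the function names are hypothetical.

```python
from itertools import permutations


def linf(p, q):
    """L-infinity distance between two (birth, death) points."""
    return max(abs(p[0] - q[0]), abs(p[1] - q[1]))


def diagonal_projection(p):
    """Closest point to p on the diagonal birth == death."""
    mid = (p[0] + p[1]) / 2
    return (mid, mid)


def bottleneck_bruteforce(diagram1, diagram2):
    """Bottleneck distance between two small diagrams of finite (birth, death) points."""
    # Pad each side with diagonal projections of the other side's points, so that
    # every point can be matched either to a real point or to the diagonal.
    left = [(p, False) for p in diagram1] + [(diagonal_projection(q), True) for q in diagram2]
    right = [(q, False) for q in diagram2] + [(diagonal_projection(p), True) for p in diagram1]

    def cost(a, b):
        # Two diagonal points can be matched for free
        if a[1] and b[1]:
            return 0.0
        return linf(a[0], b[0])

    best = float("inf")
    for perm in permutations(range(len(right))):
        worst = max((cost(left[i], right[j]) for i, j in enumerate(perm)), default=0.0)
        best = min(best, worst)
    return best


print(bottleneck_bruteforce([(0.0, 1.0)], [(0.0, 0.8)]))
print(bottleneck_bruteforce([(0.0, 1.0), (0.0, 0.1)], [(0.0, 1.0)]))
print(bottleneck_bruteforce([(0.0, 1.0)], []))
```

The first call returns about 0.2 because the two points are matched directly. The second returns about 0.05 because the short-lived point `(0.0, 0.1)` is matched to the diagonal. The third returns 0.5, half the lifetime of the unmatched point.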

---
Now let's run a version without the persistence diagrams (and without the bottleneck distance between them).

In [4]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
from ipywidgets import interact, FloatSlider, IntSlider, Text, Dropdown, VBox, Label
import warnings
warnings.filterwarnings("ignore", category=UserWarning)


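def get_attention_matrix(text, model, tokenizer, layer, head):
    """Return the (seq_len x seq_len) attention weights of one head for `text`."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs, output_attentions=True)
    # outputs.attentions[layer] has shape (batch, heads, seq_len, seq_len);
    # these weights are already softmax-normalized along the last axis.
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention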

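def compute_persistence(attention_matrix):
    # Row-wise softmax. The attention weights are already softmax-normalized,
    # so this second softmax flattens each row toward uniform.
    exp_attention = np.exp(attention_matrix)
    softmax_attention = exp_attention / exp_attention.sum(axis=-1, keepdims=True)

    # Pairwise distances between token distributions. SciPy's jensenshannon already
    # returns the JS distance (sqrt of the divergence); the extra np.sqrt gives the
    # fourth root of the divergence, which is still a metric.
    n = softmax_attention.shape[0]
    distance_matrix = np.array([[np.sqrt(jensenshannon(softmax_attention[i], softmax_attention[j]))
                                 for j in range(n)] for i in range(n)])

    # Full Vietoris-Rips filtration (no edge-length cutoff) up to triangles
    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    # Discard features whose lifetime (death - birth) is at most 0.01
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix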

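def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    # 1-skeleton at this scale: every token is a vertex, and an edge joins two
    # tokens whose distance is at most `threshold`.
    g = nx.Graph()
    g.add_nodes_from(range(len(distance_matrix)))  # keep isolated tokens visible
    for (simplex, _) in simplex_tree.get_filtration():
        if len(simplex) == 2:
            if distance_matrix[simplex[0]][simplex[1]] <= threshold:
                g.add_edge(simplex[0], simplex[1])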

    labels = {node: tokens[node] for node in g.nodes()}
    
    pos = nx.spring_layout(g, dim=3, seed=42)
    
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]
    
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]
    
    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")
    
    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)
    
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()

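from functools import lru_cache


@lru_cache(maxsize=2)  # avoid reloading weights from disk on every widget change
def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model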

model_dropdown = Dropdown(
    options=[
        ('GPT-2', 'gpt2'),
        ('DistilBERT', 'distilbert-base-uncased'),
        ('Bert', 'bert-base-uncased'),
        ('RoBERTa', 'roberta-base'),
        ('aleph-BeRT', 'onlplab/alephbert-base')
    ],
    value='gpt2',
    description='Model:'
)

tokenizer, model = load_model(model_dropdown.value)

text_input1 = Text(description='Text 1:', value='Quantum information theory is interesting')
text_input2 = Text(description='Text 2:', value='Quantum information theory allows us to study attention using entanglement')

layer_slider = IntSlider(value=1, min=0, max=model.config.num_hidden_layers - 1, description='Layer:', continuous_update=False)
head_slider = IntSlider(value=2, min=0, max=model.config.num_attention_heads - 1, description='Head:', continuous_update=False)

threshold_slider = FloatSlider(value=0.05, min=0.00, max=0.5, step=0.001, description='Threshold:', continuous_update=False)

def update_plot(threshold, text1, text2, layer, head, model_name):
    tokenizer, model = load_model(model_name)
    layer_slider.max = model.config.num_hidden_layers - 1
    head_slider.max = model.config.num_attention_heads - 1

    attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
    attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

    persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
    persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

    # Decode each input id separately so the labels line up one-to-one with the
    # rows of the attention matrix (including special tokens such as [CLS]).
    tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
    tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)

interact(update_plot, threshold=threshold_slider, text1=text_input1, text2=text_input2, layer=layer_slider, head=head_slider, model_name=model_dropdown)


