# A Study of Attention in Transformers

This notebook studies attention in transformers from several perspectives. We first view attention through the lens of "(complete) graph attention networks" and give two ways of visualizing multihead masked self-attention: heatmaps of the attention matrices (as is typical) and weighted graphs. We then study attention from an information theory perspective, computing the KL-divergence and the Jensen-Shannon distance between the probability distributions that the attention mechanism associates to tokens. Next, we turn our *attention* to the persistent homology of these distributions, based on distance matrices computed with the Jensen-Shannon distance. We show how to visualize the persistent homology at different scales for different input texts, and we compare specific attention heads in different transformer models using the bottleneck distance between persistence diagrams (which could be replaced by the Wasserstein distance or some other metric on persistence diagrams). Finally, we discuss the notion of contextual mappings described in [Are Transformers universal approximators of sequence-to-sequence functions?](https://arxiv.org/abs/1912.10077). There is also some relation to the idea presented in [Centroid Transformers: Learning to Abstract with Attention](https://arxiv.org/abs/2102.08606). Namely, we present a way of clustering the probability distributions that could, in principle, allow for information reduction by choosing a representative of each cluster given by the persistent homology at some fixed scale. In theory, we would pick a value of the scale parameter for persistent homology; the clusters are then akin to what one would obtain from a density-based scan at that fixed radius. Persistent homology also provides *"higher dimensional clustering"* information, as can be seen in the plots of the $1$-skeleton of the simplicial complex obtained from persistent homology towards the end of the notebook.

In [1]:
pip install torch transformers gudhi -q

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install torchvision timm -q

Note: you may need to restart the kernel to use updated packages.


---

## Attention Matrices and Attention Graphs

---

In [3]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model, BertTokenizer, BertModel
import ipywidgets as widgets
from IPython.display import display

# Function to initialize the tokenizer and the model
def initialize_model(model_name):
    if model_name == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = GPT2Model.from_pretrained("gpt2")
    elif model_name == "bert":
        tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        model = BertModel.from_pretrained("bert-base-multilingual-cased")
    else:
        raise ValueError("Invalid model name")
    return tokenizer, model

# Initialize the tokenizer and the model
model_name = "gpt2"
tokenizer, model = initialize_model(model_name)

def visualize_attention(sentence1, sentence2, layer, head, model_name):
    tokenizer, model = initialize_model(model_name)
    
    sentences = [sentence1, sentence2]
    
    # Tokenize and create tensor inputs
    input_ids = [tokenizer.encode(sentence, return_tensors="pt") for sentence in sentences]

    # Process the input sentences with the selected model
    with torch.no_grad():
        outputs = [model(ids, output_attentions=True) for ids in input_ids]

    # Get the attention weights from the specified layer and head
    attention_weights = [output.attentions[layer].squeeze(0)[head] for output in outputs]

    # Plot the attention weights for both sentences
    for i, sentence in enumerate(sentences):
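        # Recover the token labels from the encoded ids so that special tokens added
        # by the tokenizer (e.g. BERT's [CLS] and [SEP]) line up with the matrix axes
        tokens = tokenizer.convert_ids_to_tokens(input_ids[i][0].tolist())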
        plt.figure(figsize=(6, 6))
        plt.imshow(attention_weights[i], cmap="viridis")
        plt.xticks(range(len(tokens)), tokens, rotation=90)
        plt.yticks(range(len(tokens)), tokens)
        plt.xlabel("Keys")
        plt.ylabel("Queries")
        plt.title(f"Sentence {i + 1}: Attention weights in layer {layer + 1}, head {head + 1}")
        plt.show()

# Create input widgets
sentence1_input = widgets.Text(value="Quantum information theory is interesting", description="Sentence 1:")
sentence2_input = widgets.Text(value="קוואַנטום אינפֿאָרמאַציע טעאָריע איז אינטערעסאַנט", description="Sentence 2:")
layer_slider = widgets.IntSlider(min=0, max=model.config.num_hidden_layers - 1, description="Layer:")
head_slider = widgets.IntSlider(min=0, max=model.config.num_attention_heads - 1, description="Head:")
model_dropdown = widgets.Dropdown(options=["gpt2", "bert"], value="gpt2", description="Model:")

# Create interactive output
interactive_output = widgets.interactive_output(
    visualize_attention,
    {
        "sentence1": sentence1_input,
        "sentence2": sentence2_input,
        "layer": layer_slider,
        "head": head_slider,
        "model_name": model_dropdown,
    },
)

# Display input widgets and interactive output
display(sentence1_input, sentence2_input, layer_slider, head_slider, model_dropdown, interactive_output)



Viewing the attention matrices in masked self-attention of GPT-2, we see that the entries above the diagonal are all zero. If we decide to make the matrices symmetric by replacing the zero in entry $(j,i)$ above the diagonal with its nonzero counterpart in entry $(i,j)$ below the diagonal, we can view the attention matrices as weighted adjacency matrices for a weighted "attention graph". Let's look at an example of how this can be done in code using tokenized text.
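Before working with a real model, the symmetrization described above can be sketched on a toy matrix. This is only an illustration under stated assumptions: the helper name `symmetrize_causal_attention` and the 3-token matrix are made up for the example and do not come from GPT-2.

```python
import numpy as np

def symmetrize_causal_attention(attn: np.ndarray) -> np.ndarray:
    """Mirror the entries below the diagonal onto the (zero) entries above it."""
    lower = np.tril(attn)                # diagonal and everything below it
    strict_lower = np.tril(attn, k=-1)   # entries strictly below the diagonal
    return lower + strict_lower.T        # copy entry (i, j) into entry (j, i)

# Hypothetical causal attention matrix: rows sum to 1, zeros above the diagonal
attn = np.array([
    [1.0, 0.0, 0.0],
    [0.3, 0.7, 0.0],
    [0.2, 0.5, 0.3],
])

adjacency = symmetrize_causal_attention(attn)
print(adjacency)
```

Note that the rows of the symmetrized matrix no longer sum to one, so it should be read as a weighted adjacency matrix with self-loops rather than as a collection of probability distributions.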

Note that in the bidirectional BERT model, the attention matrices cannot be visualized as graphs as effectively unless we consider a directed edge to and from every node, that is, a (bi)directed complete graph, where there is a directed edge from each vertex to every other vertex. This complicates matters when trying to visualize attention. Some other approaches to visualizing attention graphs can be found in [BertViz](https://github.com/jessevig/bertviz). It provides visualizations at the neuron level for GPT-2, and attention graphs for BERT models that have two copies of the nodes as opposed to only one. The nodes are aligned vertically in two columns and connected by weighted edges directed from left to right. To see the interactive tutorial, check out [this notebook on colab](https://colab.research.google.com/drive/1hXIQ77A4TYS4y3UthWF-Ci7V7vVUoxmQ).

### An Interactive Example to Compare Heads in Different Layers

In [4]:
import torch
import networkx as nx
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2Model
from ipywidgets import interact, IntSlider, Text

# Prepare model and tokenizer
model = GPT2Model.from_pretrained("gpt2", output_attentions=True)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def plot_attention_graph(layer, head, input_text):
    # Tokenize the input text and convert it to a tensor
    input_ids = torch.tensor(tokenizer.encode(input_text)).unsqueeze(0)

    # Get the model output with attention weights
    outputs = model(input_ids)

    # Extract the attention weights for the given layer and head
    attention_weights = outputs.attentions[layer - 1][0, head - 1].detach().numpy()

    # Convert the attention weights to a weighted adjacency matrix and create a directed graph
    graph = nx.DiGraph()
    sequence_length = attention_weights.shape[0]
    for i in range(sequence_length):
        for j in range(sequence_length):
            graph.add_edge(i, j, weight=attention_weights[i][j])

    # Plot the graph with edge and node labels
    plt.figure(figsize=(10, 10))
    pos = nx.circular_layout(graph)
    nx.draw_networkx_edges(graph, pos, width=2, edge_color='grey')
    edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in graph.edges(data=True)}
    nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=12)
    nx.draw_networkx_nodes(graph, pos, node_color='lightblue', node_size=2000)
    node_labels = {i: tokenizer.decode([input_ids[0][i].item()]) for i in range(sequence_length)}
    nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=12)
    plt.title(f"Layer {layer} Head {head}")
    plt.axis('off')
    plt.show()

# Interactive widgets
layer_slider = IntSlider(min=1, max=12, step=1, value=1, description='Layer:')
head_slider = IntSlider(min=1, max=12, step=1, value=1, description='Head:')
input_text = Text(value="Quantum information theory is interesting", description='Text:')

interact(plot_attention_graph, layer=layer_slider, head=head_slider, input_text=input_text);




### An Example with a Vision Transformer

In this example we plot the attention graph for a vision transformer applied to an image. This illustrates just how complicated these graphs can get. The interactive 3D plot allows us to see how each node attends to other nodes. With attention graphs this complicated, it becomes difficult to visualize how each token attends to every other token.

In [5]:
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import ViTImageProcessor, ViTModel
import networkx as nx
import plotly.graph_objects as go
import numpy as np

# Load and preprocess the input image
image_path = "heart.jpg"
image = Image.open(image_path)
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
image_tensor = transform(image)  # Do not add batch dimension

# Load the pre-trained ViT model and image processor
model_name = "google/vit-base-patch16-224"
image_processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTModel.from_pretrained(model_name)

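# Preprocess the image with the image processor. The tensor produced by `ToTensor()`
# is already scaled to [0, 1], so disable the processor's own 1/255 rescaling.
inputs = image_processor(images=[image_tensor], return_tensors="pt", do_rescale=False)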

# Obtain the embeddings and attention weights for specific `layer` and `head` using the pre-trained ViT model
layer = 2
head = 3
outputs = model(**inputs, output_attentions=True)
attention_scores = outputs.attentions[layer][0][head].detach().cpu().numpy()

# Function to generate Fibonacci sphere points
def fibonacci_sphere_points(samples):
    points = []
    offset = 2 / samples
    increment = np.pi * (3 - np.sqrt(5))

    for i in range(samples):
        y = ((i * offset) - 1) + (offset / 2)
        r = np.sqrt(1 - pow(y, 2))
        phi = ((i + 1) % samples) * increment

        x = np.cos(phi) * r
        z = np.sin(phi) * r
        points.append((x, y, z))

    return points

# Function to plot the attention graph
def plot_attention_graph(attention_scores):
    # Create graph object
    G = nx.from_numpy_array(attention_scores)

    # Set node positions
    pos = {i: coord for i, coord in enumerate(fibonacci_sphere_points(len(G.nodes())))}

    # Set edge trace
    edge_x = []
    edge_y = []
    edge_z = []
    for edge in G.edges():
        x0, y0, z0 = pos[edge[0]]
        x1, y1, z1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
        edge_z.extend([z0, z1, None])

    edge_trace = go.Scatter3d(x=edge_x, y=edge_y, z=edge_z,
                              mode='lines',  # draw edges only, without a marker at every endpoint
                              line=dict(width=0.5, color='rgba(255, 0, 5, 0.075)'))

    # Set node trace
    node_x = []
    node_y = []
    node_z = []
    node_text = []
    for node in G.nodes():
        x, y, z = pos[node]
        node_x.append(x)
        node_y.append(y)
        node_z.append(z)
        node_text.append(str(node))

    node_trace = go.Scatter3d(x=node_x, y=node_y, z=node_z,
                              mode='markers',
                              marker=dict(symbol='circle',
                                          size=7,
                                          color='#1f77b4',
                                          line=dict(width=0.5)),
                              text=node_text)

    # Set layout
    layout = go.Layout(scene=dict(xaxis=dict(title='X'),
                                   yaxis=dict(title='Y'),
                                   zaxis=dict(title='Z')), margin=dict(l=0, r=0, b=0, t=0))
    # Create figure and plot
    fig = go.Figure(data=[edge_trace, node_trace], layout=layout)
    fig.show()

#Plot the attention graph
plot_attention_graph(attention_scores)


---

## Permutation Equivariance of Attention

---

A description of when permutation equivariance is achieved for various implementations of attention with positional encodings can be found in $\S 4.2$ of [Group Equivariant Stand-Alone Self-Attention For Vision](https://openreview.net/forum?id=JkfYjnOEo6M). For now, let us discuss "pure self-attention" or "global self-attention" and why it is permutation equivariant. A more detailed discussion of equivariance requires "local self-attention" and "positional encodings", so we return to it when we discuss those topics.

We shall prove the permutation equivariance of the attention mechanism by direct calculation. Consider a simple self-attention mechanism, which is a key component of modern transformer architectures. The attention mechanism computes the output by weighting the input vectors based on their relevance to each other.

Let $X \in \mathbb{R}^{n \times d}$ be the input matrix, where $n$ denotes the number of input vectors and $d$ represents the dimensionality of each vector. The attention mechanism calculates the Query ($Q$), Key ($K$), and Value ($V$) matrices as follows:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

where $W_Q, W_K, W_V \in \mathbb{R}^{d \times d'}$ are learnable weight matrices. The attention weights are computed using scaled dot-product attention, with the softmax applied to each row:

$$A = \operatorname{softmax} \left(\frac{QK^T}{\sqrt{d'}}\right)$$

The output matrix $Y \in \mathbb{R}^{n \times d'}$ is then obtained by:

$$Y = AV$$

Now, let $\pi$ be a permutation of the set $\{1, 2, \ldots, n\}$. We will show that the attention mechanism is equivariant with respect to this permutation. Define the permutation matrix $P_\pi \in \mathbb{R}^{n \times n}$ such that $(P_\pi)_{ij} = 1$ if $j = \pi(i)$ and $(P_\pi)_{ij} = 0$ otherwise. Applying the permutation to the input matrix $X$ results in $X_\pi = P_\pi X$.

Next, calculate the Query, Key, and Value matrices for the permuted input:

$$Q_\pi = X_\pi W_Q = P_\pi XW_Q = P_\pi Q$$
$$K_\pi = X_\pi W_K = P_\pi XW_K = P_\pi K$$
$$V_\pi = X_\pi W_V = P_\pi XW_V = P_\pi V$$

Now, compute the attention weights for the permuted input:

$$A_\pi = \operatorname{softmax} \left(\frac{Q_\pi K_\pi^T}{\sqrt{d'}}\right) = \operatorname{softmax} \left(\frac{P_\pi Q (P_\pi K)^T}{\sqrt{d'}}\right) = \operatorname{softmax} \left(\frac{P_\pi Q K^T P_\pi^T}{\sqrt{d'}}\right) = P_\pi A P_\pi^T$$

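The last equality holds because the softmax acts on each row separately: permuting the rows and columns of its argument by $P_\pi$ permutes the rows and columns of its output in the same way. Finally, calculate the output for the permuted input, using $P_\pi^T P_\pi = I$: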

$$Y_\pi = A_\pi V_\pi = P_\pi A P_\pi^T (P_\pi V) = P_\pi (AV) = P_\pi Y$$

Thus, we have shown that the attention mechanism is permutation equivariant, as the output for the permuted input is the same as the permuted output:

$$Y_\pi = P_\pi Y$$

In conclusion, we have proven the permutation equivariance of the attention mechanism by direct calculation. This property means that, without positional encodings, self-attention treats its input as a set rather than a sequence: reordering the input vectors only reorders the outputs. This is a natural fit for set- and graph-structured data, and it is the reason positional encodings are needed when token order matters, as in natural language processing.
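The calculation above can also be checked numerically. The following is a minimal sketch with random weights rather than a trained model; the dimensions and the helper names `softmax_rows` and `self_attention` are assumptions made for the example.

```python
import numpy as np

def softmax_rows(scores):
    """Numerically stable softmax applied to each row."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, W_Q, W_K, W_V):
    """Single-head scaled dot-product self-attention, Y = softmax(QK^T / sqrt(d')) V."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    A = softmax_rows(Q @ K.T / np.sqrt(W_Q.shape[1]))
    return A @ V

rng = np.random.default_rng(0)
n, d, d_head = 5, 8, 4                       # illustrative sizes
X = rng.normal(size=(n, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d_head)) for _ in range(3))

perm = rng.permutation(n)
P = np.eye(n)[perm]                          # (P)_{ij} = 1 iff j = perm[i], so P @ X == X[perm]

Y = self_attention(X, W_Q, W_K, W_V)
Y_perm = self_attention(P @ X, W_Q, W_K, W_V)

equivariant = np.allclose(Y_perm, P @ Y)     # attention(P X) == P attention(X)
print(equivariant)
```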

### More on Getting Graphs from the Attention Matrix 

Can we really view attention as a graph? Well, kind of, and it has its benefits, but we should be precise about what we mean by attention and how we define this graph. Attention is a complicated mechanism, after all, and adding masking changes the picture a little. One way feels more meaningful than the other.

First, let's look at some nuance in the definition of attention. If $\text{score}(x_i, x_j) = \text{score}(x_j, x_i)$, then the matrix of scores is symmetric and naturally has the form of a weighted adjacency matrix. For example, this happens when the score is a simple dot product $\text{score}(x_i, x_j) = \langle x_i, x_j \rangle = x_i^Tx_j$. It also happens if we have a learnable matrix $A$ and $\text{score}(x_i, x_j) = x_i^TAx_j$, provided $A$ is symmetric. We can then view the scores as a weighted adjacency matrix where the nodes represent input tokens and the edge weights correspond to similarity scores (as defined by the inner product, scaled inner product, or the symmetric matrix $A$). Note, however, that the attention matrix $\alpha_{i,j} = \text{Attn}_{i,j}(X)$ is obtained by normalizing each row of scores with a softmax, which in general breaks this symmetry.

For the following definition of attention, the graph picture makes a little less sense, as the weight on the edge from the node for token $x_i$ to the node for token $x_j$ is, in general, not the same as the weight on the edge from $x_j$ to $x_i$. For masked self-attention, this is more feasible. We'll discuss this below. Suppose $X \in \mathbb{R}^{d \times n}$ has as columns $X_i$ the $d$-dimensional embeddings of the $n$ tokens $x_1, x_2, ..., x_n$ from your input. Now, let 

$$W_QX = Q$$
$$W_KX = K$$
$$W_VX = V$$

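where $W_Q$, $W_K$, and $W_V$ are the learned weight matrices giving the "query" $q_i = W_QX_i$, "key" $k_i = W_KX_i$, and "value" $v_i = W_VX_i$ vectors. Then we can define the attention matrix as 

\begin{align}
\text{Attn}(X) &= \text{softmax}\left( \frac{Q^TK}{\sqrt{d}} \right), \\
\text{Attn}_{i,j}(X) &= \text{softmax}_j\left(\text{score}(x_i, x_j)\right), \qquad \text{score}(x_i, x_j) = \frac{\langle q_i, k_j \rangle}{\sqrt{d}},
\end{align}

where the softmax is applied to each row, and the output for token $x_i$ is the weighted sum of value vectors $\sum_j \text{Attn}_{i,j}(X)\, v_j$. From this we can derive, 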

$$ \text{Attn}_{i,j}(X) = \frac{\exp \left( \frac{\langle q_i, k_j \rangle}{\sqrt{d}}\right)}{\sum_k \exp\left( \frac{\langle q_i, k_k \rangle}{\sqrt{d}} \right)}.$$

Now, note that we have turned each row of the attention matrix into a probability distribution by applying the softmax, so for each token $X_i$ we have 

$$
P(X_i) = \text{softmax}\begin{pmatrix}
\frac{\langle q_i, k_1 \rangle}{\sqrt{d}}\\
\frac{\langle q_i, k_2 \rangle}{\sqrt{d}}\\
\vdots \\
\frac{\langle q_i, k_n \rangle}{\sqrt{d}}
\end{pmatrix}.
$$

Now, if we adjust for masked self-attention, the entries of the attention matrix above the diagonal are zero, so we can represent the attention mechanism as a directed graph (with weighted self-loops). Drawing an edge from each query token to every key token it attends to with nonzero weight, every edge points from a token to itself or to a token that comes before it in the sequence.
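As a rough illustration of the masked case, the sketch below applies a causal mask to a matrix of random scores, normalizes each row with a softmax, and lists the resulting directed edges from query tokens to key tokens. The scores are random stand-ins for $\langle q_i, k_j \rangle / \sqrt{d}$, and the function name `causal_attention_weights` is an assumption made for the example.

```python
import numpy as np

def causal_attention_weights(scores):
    """Row-wise softmax with entries above the diagonal masked out."""
    n = scores.shape[0]
    mask = np.triu(np.ones((n, n), dtype=bool), k=1)   # True strictly above the diagonal
    masked = np.where(mask, -np.inf, scores)            # masked scores contribute exp(-inf) = 0
    masked = masked - masked.max(axis=-1, keepdims=True)
    e = np.exp(masked)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n_tokens = 4
scores = rng.normal(size=(n_tokens, n_tokens))          # stand-in for scaled query-key scores
weights = causal_attention_weights(scores)

# Directed edges (query i -> key j, weight): only j <= i survives the mask
edges = [(i, j, float(weights[i, j])) for i in range(n_tokens) for j in range(i + 1)]
for i, j, w in edges:
    print(f"{i} -> {j}: {w:.3f}")
```

Each row of `weights` is still a probability distribution, but it is supported only on the token itself and the tokens before it.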

There is another way to understand attention using graphs when we view attention through the lens of (complete) graph attention networks [as explained here](https://docs.dgl.ai/en/0.8.x/tutorials/models/4_old_wines/7_transformer.html).

### Action of Permutation Group on Weighted Graphs with Weighted Self Loops

Now, suppose we want to force the issue and treat the attention matrix for masked self-attention as a weighted adjacency matrix (with self-loops). We do this by filling in the entries above the diagonal (which are zero when the attention mechanism is masked): we set $\text{Attn}_{j,i}(X) = \text{Attn}_{i,j}(X)$ for $j < i$, that is, we replace the zero in the $(j,i)$-th entry by the corresponding nonzero value in the $(i,j)$-th entry below the diagonal.

Now, when acting on a weighted adjacency matrix by conjugating it with permutation matrices, the structure of the associated graph is preserved, although the order of vertices is altered. Let us examine this effect in detail.

Consider a graph $G = (V, E)$, where $V$ is the set of vertices and $E$ is the set of edges. Let $A \in \mathbb{R}^{n \times n}$ be the weighted adjacency matrix of $G$, where $n = |V|$. The $(i, j)$-th entry of $A$, denoted by $A_{ij}$, represents the weight of the edge connecting vertex $i$ and vertex $j$. If there is no edge between vertices $i$ and $j$, $A_{ij} = 0$.

Now, let $\pi$ be a permutation of the set $\{1, 2, \ldots, n\}$. Define the permutation matrix $P_\pi \in \mathbb{R}^{n \times n}$ such that $(P_\pi)_{ij} = 1$ if $j = \pi(i)$ and $(P_\pi)_{ij} = 0$ otherwise. Conjugating the weighted adjacency matrix $A$ by the permutation matrix $P_\pi$ yields a new weighted adjacency matrix $A_\pi$:

$$A_\pi = P_\pi A P_\pi^T$$

The effect of this conjugation on the weighted adjacency matrix and the associated graph can be analyzed as follows:

Vertex reordering: The conjugation by the permutation matrix effectively reorders the vertices of the graph according to the permutation $\pi$. The new adjacency matrix $A_\pi$ represents the same graph as the original adjacency matrix $A$, but with the vertices renumbered according to $\pi$.

Edge weights preservation: The edge weights in the graph are preserved under conjugation. With the convention above, $(A_\pi)_{ij} = A_{\pi(i)\pi(j)}$: the edge that joins vertices $\pi(i)$ and $\pi(j)$ in the original graph joins vertices $i$ and $j$ in the relabeled graph, with the same weight. Equivalently, an edge between vertices $i$ and $j$ with weight $A_{ij}$ in the original graph becomes an edge between $\pi^{-1}(i)$ and $\pi^{-1}(j)$ with the same weight.
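A small numerical check can make the effect of conjugation concrete. The sketch below uses a random symmetric matrix as a stand-in for a symmetrized attention matrix and a hand-picked permutation (both are assumptions for the example), with the convention $(P_\pi)_{ij} = 1$ if and only if $j = \pi(i)$, written 0-indexed in the code.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4
A = rng.random((n, n))
A = (A + A.T) / 2                      # symmetric weighted adjacency matrix with self-loops

perm = np.array([2, 0, 3, 1])          # pi(i) = perm[i], 0-indexed
P = np.eye(n)[perm]                    # (P)_{ij} = 1 iff j = perm[i]
A_perm = P @ A @ P.T                   # conjugation by the permutation matrix

# Entry (i, j) of the conjugated matrix is entry (pi(i), pi(j)) of the original
relabels_entries = np.allclose(A_perm, A[np.ix_(perm, perm)])

# The multiset of edge weights is unchanged, and symmetry is preserved
same_weights = np.allclose(np.sort(A_perm, axis=None), np.sort(A, axis=None))
still_symmetric = np.allclose(A_perm, A_perm.T)

print(relabels_entries, same_weights, still_symmetric)
```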

### Permutation Equivariance of Multihead Attention 

Suppose $PX$ was given as input, where $P$ is a permutation matrix. 

First note that for each attention head, 

$$(PXW^i_K)(PXW^i_Q)^T = P(XW_K^i)(XW^i_Q)^TP^T.$$ 

After the softmax operation, we again get 

$$\sigma[P(XW_K^i)(XW^i_Q)^TP^T] = P\sigma[(XW_K^i)(XW^i_Q)^T]P^T.$$ 

Next, we have, 

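$$\text{Attn}(PX) = PX + \sum_{i=1}^h P\sigma[(XW^i_K)(XW^i_Q)^T]P^T \cdot (PXW^i_V)W^i_O  = P\cdot\text{Attn}(X),$$ 

where we used $P^TP = I$. Permutation equivariance of the token-wise feed-forward layer can be shown similarly. Writing $\mathbf{1}_n$ for the all-ones column vector of length $n$ and $b_1, b_2$ for the bias vectors, and noting that $P\mathbf{1}_n = \mathbf{1}_n$, we have 

\begin{align}
\text{FF}(PX) &= P\cdot \text{Attn}(X) + \text{ReLU}(P\cdot\text{Attn}(X)W_1 + \mathbf{1}_n b_1^T)W_2 + \mathbf{1}_n b_2^T \\
              &= P\cdot \text{Attn}(X) + \text{ReLU}(P\cdot\text{Attn}(X)W_1 + P\mathbf{1}_n b_1^T)W_2 + P\mathbf{1}_n b_2^T \\
              &= P\cdot \text{Attn}(X) + P\cdot\text{ReLU}(\text{Attn}(X)W_1 + \mathbf{1}_n b_1^T)W_2 + P\mathbf{1}_n b_2^T \\
              &= P\cdot \text{FF}(X)
\end{align}

where $\text{ReLU}(PX) = P\,\text{ReLU}(X)$ was used, which holds because ReLU acts entrywise. This analysis shows that the function class $T_{h,m,r}(\cdot)$ (transformer blocks with $h$ attention heads of size $m$ and feed-forward hidden size $r$, in the notation of [Are Transformers universal approximators of sequence-to-sequence functions?](https://arxiv.org/abs/1912.10077)) is restricted to permutation equivariant functions. This "restriction" is removed when we consider positional encodings. $\blacksquare$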
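The same statement can be sanity-checked numerically for a whole block. The following is a sketch only: it uses random weights, rows as tokens, the standard query-key ordering inside the softmax, and a random matrix `E` as a hypothetical positional encoding; none of these values come from a trained model.

```python
import numpy as np

def softmax_rows(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def transformer_block(X, heads, W_1, b_1, W_2, b_2):
    """Residual multihead attention followed by a residual token-wise feed-forward layer."""
    attn = X.copy()
    for W_Q, W_K, W_V, W_O in heads:
        scores = (X @ W_Q) @ (X @ W_K).T / np.sqrt(W_Q.shape[1])
        attn = attn + softmax_rows(scores) @ (X @ W_V) @ W_O
    hidden = np.maximum(attn @ W_1 + b_1, 0.0)          # ReLU, bias broadcast over tokens
    return attn + hidden @ W_2 + b_2

rng = np.random.default_rng(0)
n, d, m, h, r = 6, 8, 4, 2, 16                          # tokens, width, head size, heads, hidden size
heads = [tuple(rng.normal(size=s) for s in [(d, m), (d, m), (d, m), (m, d)]) for _ in range(h)]
W_1, b_1 = rng.normal(size=(d, r)), rng.normal(size=r)
W_2, b_2 = rng.normal(size=(r, d)), rng.normal(size=d)

def block(Z):
    return transformer_block(Z, heads, W_1, b_1, W_2, b_2)

X = rng.normal(size=(n, d))
P = np.eye(n)[[1, 2, 3, 4, 5, 0]]                       # cyclic shift of the tokens

# Without positional encodings: block(P X) == P block(X)
equivariant = np.allclose(block(P @ X), P @ block(X))

# With a fixed (hypothetical) positional encoding E added per position, equivariance is lost
E = rng.normal(size=(n, d))
equivariant_with_positions = np.allclose(block(P @ X + E), P @ block(X + E))

print(equivariant, equivariant_with_positions)
```

The second check fails because the encoding stays attached to the positions while the tokens move, which is how positional encodings remove the restriction to permutation equivariant functions.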

### A second proof

In [Group Equivariant Stand-Alone Self-Attention For Vision](https://arxiv.org/abs/2010.00977), in **Appendix G**, we see an alternate functional proof of permutation equivariance of what they call *"global self attention without positional encodings"*. Here we will present their proof again. The notation is somewhat different, but illuminating. 

\begin{align}
m[\mathcal{L}_{\pi}[f]](i) &= \phi_{out}\left( \bigcup_{h \in [H]} \sum_{j \in S} \sigma_j\left(\langle \phi_{qry}^{(h)}(\mathcal{L}_{\pi}[f](i)) , \phi_{key}^{(h)}(\mathcal{L}_{\pi}[f](j)) \rangle \right) \phi_{val}^{(h)}(\mathcal{L}_{\pi}[f](j)) \right) \\
                           &= \phi_{out}\left( \bigcup_{h \in [H]} \sum_{j \in S} \sigma_j\left(\langle \phi_{qry}^{(h)}(f(\pi^{-1}(i))) , \phi_{key}^{(h)}(f(\pi^{-1}(j))) \rangle \right) \phi_{val}^{(h)}(f(\pi^{-1}(j))) \right)\\
                           &= \phi_{out}\left( \bigcup_{h \in [H]} \sum_{\pi(\overline{j}) \in S} \sigma_{\pi(\overline{j})}\left(\langle \phi_{qry}^{(h)}(f(\overline{i})) , \phi_{key}^{(h)}(f(\overline{j})) \rangle \right) \phi_{val}^{(h)}(f(\overline{j})) \right) \\
                           &= \phi_{out}\left( \bigcup_{h \in [H]} \sum_{\overline{j} \in S} \sigma_{\overline{j}}\left(\langle \phi_{qry}^{(h)}(f(\overline{i})) , \phi_{key}^{(h)}(f(\overline{j})) \rangle \right) \phi_{val}^{(h)}(f(\overline{j})) \right)\\
                           &= m[f](\overline{i}) \\
                           &= m[f](\pi^{-1}(i))\\
                           &= \mathcal{L}_{\pi}[m[f]](i)
\end{align}

Here we have used the substitution $\overline{i} = \pi^{-1}(i)$ and $\overline{j} = \pi^{-1}(j)$, so that $j = \pi(\overline{j})$. Since the summation runs over the entire set $S$ and $\pi$ is a bijection of $S$, we have that $\sum_{\pi(\overline{j}) \in S}[ \bullet ] = \sum_{\overline{j} \in S}[ \bullet ]$. In conclusion, we see that $m[\mathcal{L}_{\pi}[f]](i) = \mathcal{L}_{\pi}[m[f]](i)$. Hence, permutation equivariance holds. 

### Some Questions and Thoughts on Group Equivariance in General

1. Reference on permutation equivariance of attention: [Are Transformers universal approximators of sequence-to-sequence functions?](https://arxiv.org/abs/1912.10077)
2. [The general theory of permutation equivarant neural networks and higher order graph variational encoders](https://arxiv.org/abs/2004.03990) does not provide enough details on the equivariance of the feed-forward linear layers of transformers, especially at the level of the representation theory and combinatorics of the symmetric group $S_n$ acting on tokens (or the nodes of the attention graph). Can we write up more of the details and examples so that this is more accessible to a wider audience? 
3. How can Group Equivariant Neural Networks be used for NLP? See for example [LieTransformer: Equivariant self-attention for Lie Groups](https://arxiv.org/abs/2012.10885). 
4. Are there symmetries in language that we are missing? Invariant theory applications?
5. Is this related to graph grammars and/or the topology of language? See for example [Graph Grammars](http://www.its.caltech.edu/~matilde/GraphGrammarsLing.pdf).
6. How can we understand this through viewing transformers as ["graph attention networks"](https://docs.dgl.ai/en/0.8.x/tutorials/models/4_old_wines/7_transformer.html)? In particular, how can we apply transformers to graphs? Should we approximate arbitrary graphs with complete graphs with zero (or very small $\epsilon$) edge weights for non-existent edges? How can we use the "Laplacian positional encoding" to give node tokens a positional encoding according to the graph Laplacian? Should we tokenize edges too, as in [Pure Transformers are Powerful Graph Learners | Jinwoo Kim](https://www.youtube.com/watch?v=TAKyjYoimd0&ab_channel=datamolio)? 
7. How can this be made compatible with transformers? Can we use something like [Group Equivariant Stand-Alone Self-Attention For Vision](https://arxiv.org/abs/2010.00977)?
8. When do we need equivariance? What problems benefit from it? See for example [Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges](https://arxiv.org/abs/2104.13478).
9. What other group equivariance can we find? See for example [On the Generalization of Equivariance and Convolution in Neural Networks to the Action of Compact Groups](https://arxiv.org/abs/1802.03690) and [A General Theory of Equivariant CNNs on Homogeneous Spaces](https://arxiv.org/abs/1811.02017). What does it say about a language or text if a neural network with a given equivariance is highly effective or ineffective at modelling that language?
10. Can we apply [The Quantum Graph Recurrent Neural Network](https://pennylane.ai/qml/demos/tutorial_qgrnn.html) to understand weighted attention graphs, and is this connected to Question 6?


### Notes on Lie Group Equivariant Self-Attention

Say $V_1, V_2$ are vector spaces (usually spaces of feature maps, $V_i = \{f \mid f:S \to \mathbb{R}\}$) and let $\Phi:V_1 \to V_2$ be some map. Suppose the group $G$ acts on $V_1$ via $\rho_1$ and on $V_2$ via $\rho_2$. Then we call $\Phi$ a **$G$-equivariant map** if, for all $g \in G$ and $f \in V_1$,

\begin{align}
\Phi[\rho_1(g)f] = \rho_2(g)\Phi[f].
\end{align}

Now, if we let $V_1 = V_2 = \{f \mid f:G \to \mathbb{R}\}$, then $G$ acts on this space of scalar-valued functions on $G$ by the **regular representation**. Taking $\rho_1 = \rho_2 = \pi$, we have 

\begin{align}
[\pi(g)(f)](h) = f(g^{-1}h).
\end{align}

and to be equivariant we need

\begin{align}
\Phi[\pi(g)(f(h))] &= \Phi[f(g^{-1}h)] \\
                   &= \pi(g)\Phi[f(h)].
\end{align}
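To make the definition concrete, here is a small sketch that uses the finite cyclic group $\mathbb{Z}_n$ as a stand-in for a Lie group (an assumption made so the example fits in a few lines). Functions on the group are arrays of length $n$, the regular representation is a cyclic shift, and circular convolution plays the role of the equivariant map $\Phi$. All function names are illustrative.

```python
import numpy as np

n = 8  # order of the cyclic group Z_n = {0, 1, ..., n-1} under addition mod n

def regular_rep(g, f):
    """[pi(g) f](h) = f(g^{-1} h) = f((h - g) mod n), i.e. a cyclic shift by g."""
    return np.roll(f, g)

def circular_convolution(f, kernel):
    """(Phi f)(h) = sum_k kernel(k) f((h - k) mod n)."""
    return np.array([sum(kernel[k] * f[(h - k) % n] for k in range(n)) for h in range(n)])

def positionwise_scaling(f, weights):
    """A map that treats each group element differently, so it is not equivariant."""
    return weights * f

rng = np.random.default_rng(0)
f = rng.normal(size=n)
kernel = rng.normal(size=n)
weights = np.arange(1, n + 1, dtype=float)

# Phi[pi(g) f] == pi(g) Phi[f] for every group element g
conv_equivariant = all(
    np.allclose(circular_convolution(regular_rep(g, f), kernel),
                regular_rep(g, circular_convolution(f, kernel)))
    for g in range(n)
)
scaling_equivariant = all(
    np.allclose(positionwise_scaling(regular_rep(g, f), weights),
                regular_rep(g, positionwise_scaling(f, weights)))
    for g in range(n)
)
print(conv_equivariant, scaling_equivariant)
```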



---

## Information Theory and Attention 

---

### Probability Distributions from Attention

$$ 
P(X_i) = \left( \text{softmax}_j\left(\frac{\langle q_i, k_j \rangle}{\sqrt{d}}\right) \right)_{j=1}^n = \left( \frac{e^{\frac{\langle q_i, k_j \rangle }{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}} \right)_{j=1}^n 
$$

### Shannon Entropy
Given the probability distribution $P(X_i)$ for token $X_i$, the Shannon entropy $H(X_i)$ measures the average uncertainty or randomness associated with the attending behavior of token $X_i$. The formula for the Shannon entropy is as follows:

\begin{align}
H(X_i) &= -\sum_{j=1}^n P(X_i)_j \log_2 P(X_i)_j \\
       &= -\sum_{j=1}^n \frac{e^{\frac{\langle q_i, k_j \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}} \log_2 \left( \frac{e^{\frac{\langle q_i, k_j \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}} \right)
\end{align}
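The entropy formula can be sketched directly on the rows of an attention matrix. The toy matrix and the helper name `row_entropies` below are assumptions for illustration, and the code uses the usual convention $0 \log_2 0 = 0$ so that masked (zero) entries contribute nothing.

```python
import numpy as np

def row_entropies(attn):
    """Shannon entropy (in bits) of each row of a row-stochastic matrix."""
    safe = np.where(attn > 0, attn, 1.0)          # log2(1) = 0, so zero entries add nothing
    return -(attn * np.log2(safe)).sum(axis=-1)

# Hypothetical attention rows: one-hot, split over two tokens, uniform over four
attn = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.5, 0.0, 0.0],
    [0.25, 0.25, 0.25, 0.25],
])

entropies = row_entropies(attn)
print(entropies)
```

A token that attends to a single position has entropy 0 bits, one that splits its attention evenly over two positions has 1 bit, and one that attends uniformly over four positions has the maximum of $\log_2 4 = 2$ bits.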

### KL-Divergence
To measure the dissimilarity between the attending behaviors of tokens $X_i$ and $X_j$, we can compute the Kullback-Leibler (KL) divergence, denoted as $D_{KL}(P(X_i)|| P(X_j))$. The formula for KL divergence is:

\begin{align}
D_{KL}(P(X_i) || P(X_j)) &= \sum_{k=1}^n P(X_i)_k \log_2 \frac{P(X_i)_k}{P(X_j)_k} \\
                   &= \sum_{k=1}^n \frac{e^{\frac{\langle q_i, k_k \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}} \log_2 \left( \frac{\frac{e^{\frac{\langle q_i, k_k \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}}}{\frac{e^{\frac{\langle q_j, k_k \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_j, k_l \rangle}{\sqrt{d}}}}} \right)\\
                   &= \sum_{k=1}^n \frac{e^{\frac{\langle q_i, k_k \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}} \left( \log_2 \left( \frac{e^{\frac{\langle q_i, k_k \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_i, k_l \rangle}{\sqrt{d}}}} \right) - \log_2 \left( \frac{e^{\frac{\langle q_j, k_k \rangle}{\sqrt{d}}}}{\sum_{l=1}^n e^{\frac{\langle q_j, k_l \rangle}{\sqrt{d}}}} \right) \right)
\end{align}

In [6]:
import torch
import numpy as np
from transformers import GPT2Tokenizer, GPT2Model

def pairwise_kl_divergence_gpt2(layer: int, head: int):
    # Set up the GPT-2 model and tokenizer
    model_name = 'gpt2'
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    model = GPT2Model.from_pretrained(model_name, output_attentions=True)

    # Define the input text in English
    input_text = "Quantum Information Theory provides ways to study attention using entanglement."

    # Tokenize the input text and convert it to a tensor
    input_ids = torch.tensor(tokenizer.encode(input_text)).unsqueeze(0)

    # Get the model output with attention weights
    outputs = model(input_ids)

    # Extract the attention weights for the given layer and head
    attention_weights = outputs.attentions[layer - 1][0, head - 1].detach().numpy()

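    # NOTE: the rows of `attention_weights` are already softmax-normalized
    # probability distributions P(X_i). The line below applies a second softmax
    # on top of them, which replaces the exact zeros left by the causal mask with
    # small positive values, so the KL values reported here are for these
    # re-normalized (smoothed) distributions rather than for P(X_i) itself.
    softmax_weights = np.apply_along_axis(lambda x: np.exp(x) / np.sum(np.exp(x)), -1, attention_weights)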

    # Compute the pairwise KL-divergence matrix D_KL(P(X_i)||P(X_j))
    n = softmax_weights.shape[0]
    kl_matrix = np.zeros((n, n))

    for i in range(n):
        for j in range(n):
            p = softmax_weights[i, :]
            q = softmax_weights[j, :]
            kl_matrix[i, j] = np.sum(p * (np.log2(p + 1e-9) - np.log2(q + 1e-9)))

    return kl_matrix

# Example usage:
layer = 1
head = 1
kl_matrix = pairwise_kl_divergence_gpt2(layer, head)
print("Matrix dimensions:", kl_matrix.shape)
print(kl_matrix)

Matrix dimensions: (14, 14)
[[0.         0.00039808 0.01222909 0.01960545 0.02434652 0.03134007
  0.04848492 0.03427438 0.07781303 0.06014758 0.04200178 0.06383259
  0.06268326 0.06864341]
 [0.0003974  0.         0.00844174 0.01453189 0.01913523 0.02534562
  0.04101073 0.02803404 0.06806736 0.05179694 0.03503851 0.05495372
  0.05383174 0.05957552]
 [0.01162895 0.00809702 0.         0.00171488 0.00360817 0.00637985
  0.01556731 0.0086439  0.03213211 0.0231753  0.0123948  0.02351705
  0.02272867 0.02714372]
 [0.01819905 0.01356786 0.00171804 0.         0.00147194 0.00347581
  0.01044507 0.00565423 0.02419614 0.01495834 0.00708822 0.01516831
  0.0152272  0.01975537]
 [0.02208225 0.01758282 0.00360888 0.00145549 0.         0.00118469
  0.00658171 0.00348158 0.01847804 0.01042882 0.00387719 0.01032058
  0.01097969 0.01501198]
 [0.02767208 0.0226668  0.00618455 0.00343837 0.00119114 0.
  0.00230053 0.00129398 0.0126713  0.00727872 0.00198932 0.0068933
  0.00721544 0.01021702]
 [0.04140078 0.

In [7]:
import torch
import numpy as np
from transformers import BertTokenizer, BertModel

def pairwise_kl_divergence(layer: int, head: int):
    # Set up the BERT model and tokenizer
    model_name = 'bert-base-multilingual-cased'
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name, output_attentions=True)

    # Define the input text in Yiddish (Hebrew script)
    input_text = "קוואַנטן אינפֿאָרמאַציע טעאָריע גיט וועגן צו פֿאָרשן אויפֿמערקזאַם מיט פֿאַרשלינגן"

    # Tokenize the input text and convert it to a tensor
    input_ids = torch.tensor(tokenizer.encode(input_text)).unsqueeze(0)

    # Get the model output with attention weights
    outputs = model(input_ids)

    # Extract the attention weights for the given layer and head
    attention_weights = outputs.attentions[layer - 1][0, head - 1].detach().numpy()

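    # Compute the probability distributions P(X_i) using softmax.
    # Note: outputs.attentions already holds post-softmax attention probabilities,
    # so this step applies a second softmax to each row.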
    softmax_weights = np.apply_along_axis(lambda x: np.exp(x) / np.sum(np.exp(x)), -1, attention_weights)

    # Compute the pairwise KL-divergence matrix D_KL(P(X_i)||P(X_j))
    n = softmax_weights.shape[0]
    kl_matrix = np.zeros((n, n))

    for i in range(n):
        for j in range(n):
            p = softmax_weights[i, :]
            q = softmax_weights[j, :]
            kl_matrix[i, j] = np.sum(p * (np.log2(p + 1e-9) - np.log2(q + 1e-9)))

    return kl_matrix

# Example usage:
layer = 1  # Choose the layer number
head = 1  # Choose the head number
kl_matrix = pairwise_kl_divergence(layer, head)
print("Matrix dimensions:", kl_matrix.shape)
print(kl_matrix)



Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Matrix dimensions: (26, 26)
[[0.00000000e+00 1.93814710e-02 1.59873366e-02 1.90017484e-02
  1.74694359e-02 1.76622681e-02 1.67679116e-02 1.86270550e-02
  1.82237253e-02 1.91566683e-02 1.77614242e-02 1.68848373e-02
  1.78747624e-02 1.62867270e-02 1.63180642e-02 1.54490285e-02
  2.05162987e-02 1.76733769e-02 1.57894567e-02 1.56177320e-02
  1.81814171e-02 1.81295313e-02 1.82669945e-02 1.33567676e-02
  1.81078464e-02 4.55804542e-03]
 [1.59231164e-02 0.00000000e+00 4.18949174e-04 4.14589467e-03
  5.07617136e-04 5.78546315e-04 2.07086909e-03 3.80135025e-04
  7.45864236e-04 6.17076352e-04 4.72849188e-03 4.55387868e-04
  7.70849991e-04 1.06229773e-03 5.06504439e-04 4.49597370e-04
  4.79124719e-04 1.43430335e-03 1.00360205e-03 6.24092296e-04
  2.13528331e-03 2.10474525e-03 4.35120426e-04 1.50628109e-03
  2.08705547e-03 3.50543745e-02]
 [1.33888088e-02 4.18771990e-04 0.00000000e+00 3.14339600e-03
  3.56509583e-04 2.63554859e-04 1.19044154e-03 6.52827322e-04
  3.41096893e-04 3.80025711e-04 3.4699

If we want to do persistent homology, we need a pairwise distance matrix between tokens, but the Kullback-Leibler (KL) divergence is not a genuine distance metric because it is not symmetric and does not satisfy the triangle inequality. A natural starting point for building a genuine metric from the KL divergence is the Jensen-Shannon (JS) divergence, which is symmetric and bounded.

The Jensen-Shannon divergence is defined as the average of two KL divergences, using the average distribution of the two original distributions as a reference:

$$
JS(P, Q) = \frac{1}{2} D_{KL}(P || M) + \frac{1}{2} D_{KL}(Q || M)
$$

where $P$ and $Q$ are the original probability distributions, $D_{KL}$ is the KL divergence, and $M$ is the average distribution defined as:

$$
M = \frac{1}{2}(P + Q)
$$

However, the JS divergence is still not a proper distance metric because it does not satisfy the triangle inequality. To obtain a genuine distance metric, one can take the square root of the JS divergence. This is called the Jensen-Shannon distance:

$$
JS_{distance}(P, Q) = \sqrt{JS(P, Q)}
$$

The Jensen-Shannon distance is symmetric, non-negative, and satisfies the triangle inequality, making it a proper distance metric.
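
The difference between the divergence and the distance can be checked on a small example. The sketch below is a NumPy-only illustration with hypothetical helper names (`kl_bits`, `js_divergence`, `js_distance`) and three hand-picked distributions; it uses base-2 logarithms to match the KL formula above. Note that `scipy.spatial.distance.jensenshannon`, used in the cells that follow, already returns the distance (the square root of the divergence) and uses the natural logarithm unless `base` is passed.

```python
import numpy as np

def kl_bits(p, q):
    """D_KL(p || q) in bits; terms with p_k = 0 contribute nothing."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log2(p[mask] / q[mask])))

def js_divergence(p, q):
    """Jensen-Shannon divergence in bits, with M = (P + Q) / 2 as the reference."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    m = 0.5 * (p + q)
    return 0.5 * kl_bits(p, m) + 0.5 * kl_bits(q, m)

def js_distance(p, q):
    """Square root of the JS divergence (clipped at 0 to absorb rounding error)."""
    return float(np.sqrt(max(js_divergence(p, q), 0.0)))

P = np.array([1.0, 0.0])
Q = np.array([0.5, 0.5])
R = np.array([0.0, 1.0])

# The divergence violates the triangle inequality for this triple ...
print(js_divergence(P, R), js_divergence(P, Q) + js_divergence(Q, R))
# ... while its square root satisfies it.
print(js_distance(P, R), js_distance(P, Q) + js_distance(Q, R))
```

Here the divergence from `P` to `R` is 1 bit, while going through `Q` costs only about 0.62 bits in total, so the divergence itself cannot be a metric.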

### An Example with JS-Distance

In [8]:
import torch
from transformers import BertTokenizer, BertModel
import numpy as np
from scipy.spatial.distance import jensenshannon
from ipywidgets import interact, widgets

def js_distance_matrix(layer: int, head: int, input_text1: str, input_text2: str):
    def get_distance_matrix(input_text):
        model_name = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name, output_attentions=True)

        input_ids = torch.tensor(tokenizer.encode(input_text)).unsqueeze(0)
        outputs = model(input_ids)

        attention_weights = outputs.attentions[layer - 1][0, head - 1].detach().numpy()

        p_x = np.apply_along_axis(lambda x: np.exp(x) / np.sum(np.exp(x)), 1, attention_weights)

        num_tokens = p_x.shape[0]
        js_distance_mat = np.zeros((num_tokens, num_tokens))

        for i in range(num_tokens):
            for j in range(num_tokens):
                js_distance = jensenshannon(p_x[i], p_x[j])
                js_distance_mat[i, j] = js_distance

        return js_distance_mat

    js_distance_matrix_result1 = get_distance_matrix(input_text1)
    js_distance_matrix_result2 = get_distance_matrix(input_text2)

    return js_distance_matrix_result1, js_distance_matrix_result2


def interactive_js_distance_matrix(layer: int, head: int, input_text1: str, input_text2: str):
    js_distance_matrix_result1, js_distance_matrix_result2 = js_distance_matrix(layer, head, input_text1, input_text2)
    
    print("Matrix 1 dimensions:", js_distance_matrix_result1.shape)
    print(js_distance_matrix_result1)
    print("Matrix 2 dimensions:", js_distance_matrix_result2.shape)
    print(js_distance_matrix_result2)

interact(
    interactive_js_distance_matrix,
    layer=widgets.IntSlider(min=1, max=12, step=1, value=1),
    head=widgets.IntSlider(min=1, max=12, step=1, value=1),
    input_text1=widgets.Text(value='Example text 1', description='Text 1:'),
    input_text2=widgets.Text(value='Example text 2 is longer', description='Text 2:')
)


Now, because these distance matrices are genuine distance matrices (using the Jensen-Shannon distance), we can compute the persistent homology associated to them. We treat each token as a node in a weighted graph, with edge weights given by the JS distance between the probability distributions that the attention mechanism associates to the tokens. This gives a filtration of the weighted graph and an associated persistence diagram describing the persistent homology of the distributions. If we do this for two text inputs, and allow ourselves to compare the attention in different layers and heads of the transformer, we can gain insight into the behavior of the attention mechanism. In particular, we see how the distributions associated to tokens cluster in an information-theoretic way. 
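
To make the filtration concrete in the simplest case, the following sketch computes the 0-dimensional persistence (connected components) of a distance matrix directly, without GUDHI. It is an illustration under stated assumptions: the 4x4 matrix `D` is made up rather than computed from a model, and `h0_death_times` is a hypothetical helper. Every token is born as its own component at scale 0, and a component dies at the length of the edge that first merges it into another one, which is Kruskal's minimum spanning tree algorithm with a union-find structure.

```python
import numpy as np

def h0_death_times(distance_matrix):
    """Death times of connected components as the distance threshold grows."""
    d = np.asarray(distance_matrix, dtype=float)
    n = d.shape[0]
    parent = list(range(n))

    def find(x):
        # Follow parent pointers to the component representative (path halving).
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    # Process edges from shortest to longest, as the filtration does.
    edges = sorted((d[i, j], i, j) for i in range(n) for j in range(i + 1, n))
    deaths = []
    for weight, i, j in edges:
        ri, rj = find(i), find(j)
        if ri != rj:          # this edge merges two components: one of them dies
            parent[ri] = rj
            deaths.append(float(weight))
    return deaths             # n - 1 finite deaths; one component never dies

# Hypothetical JS distances between four tokens.
D = np.array([
    [0.00, 0.10, 0.50, 0.55],
    [0.10, 0.00, 0.40, 0.45],
    [0.50, 0.40, 0.00, 0.05],
    [0.55, 0.45, 0.05, 0.00],
])
print(h0_death_times(D))  # [0.05, 0.1, 0.4]
```

Reading the result as clustering: for any threshold between 0.1 and 0.4 the tokens form two clusters, {0, 1} and {2, 3}, which merge at scale 0.4.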

## Persistent Homology of Attention

![simplicial_complex_1.png](simplicial_complex_1.png)

This code is designed to visualize the attention mechanism of pre-trained transformer models using simplicial complexes. It clusters the probability distributions obtained by applying the softmax to the attention matrix of a text input. Each token's distribution is clustered based on the Jensen-Shannon distance (which could be substituted for another distance metric on distributions). It then computes the persistent homology and, using a slider, plots the associated simplicial complex for that scale parameter value. At each scale, a new ($1$-skeleton of a) simplicial complex is plotted. This form of clustering provides us a graph that connects nodes that are nearby. The simplicial complex gives higher dimensional information about how the distributions are related. Note, this could also be applied to vector embeddings of tokens using the Euclidean distance metric (or something similar for vectors). The code consists of several functions that handle attention matrix computation, persistence computation, and 3D visualization, as well as interactive widgets to adjust input parameters.

The code uses persistent homology to analyze the structure of attention mechanisms in transformer models by visualizing the relationships between tokens as simplicial complexes. Persistent homology is a mathematical method that studies the topological features of a shape across different scales. It provides a measure of the significance of topological features such as connected components, loops, and voids, which helps to understand the underlying structure in data.

In this code, persistent homology is used in the following way:

1. The `compute_persistence()` function calculates the persistence, simplex tree, and distance matrix based on the attention matrix. The attention matrix is first transformed into a probability distribution using the softmax function. Then, the Jensen-Shannon distance is calculated between each pair of probability distributions, creating a distance matrix.

2. The distance matrix is used to construct a Rips complex, which is a simplicial complex built by connecting points within a certain distance threshold. In this code, the Rips complex is created using the `gd.RipsComplex()` function from the Gudhi library.

3. The Rips complex is then converted into a simplex tree using the `create_simplex_tree()` method. A simplex tree is a data structure that represents a filtered simplicial complex, which is a simplicial complex where each simplex is associated with a value called its filtration value.

4. The persistence of the topological features is computed using the `persistence()` method. Persistence is a measure of the importance of a topological feature based on how long it persists across different scales (filtration values). 

5. The `plot_simplicial_complex_3d()` function generates a 3D visualization of the simplicial complex based on the simplex tree and distance matrix. The visualization shows the relationships between tokens as edges, with the `threshold` value (a cutoff on the pairwise distance) determining which edges are displayed.

By using persistent homology, the code provides a way to study the attention mechanism's structure in transformer models and visualize how tokens are related to each other.
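
The Rips construction in step 2 can be sketched without GUDHI for a single fixed scale. The example below is an illustration, not the library's implementation: `rips_simplices` is a hypothetical helper and the distance matrix `D` is made up. An edge is included when its two endpoints are within the threshold, and a triangle is included when all three of its edges are, which is what distinguishes the simplicial complex from the plain graph.

```python
from itertools import combinations

import numpy as np

def rips_simplices(distance_matrix, threshold):
    """Edges and triangles of the Vietoris-Rips complex at a fixed scale."""
    d = np.asarray(distance_matrix, dtype=float)
    n = d.shape[0]
    edges = [(i, j) for i, j in combinations(range(n), 2) if d[i, j] <= threshold]
    edge_set = set(edges)
    # A triangle belongs to the complex exactly when all three of its edges do.
    triangles = [
        (i, j, k)
        for i, j, k in combinations(range(n), 3)
        if (i, j) in edge_set and (i, k) in edge_set and (j, k) in edge_set
    ]
    return edges, triangles

# Hypothetical pairwise distances between four tokens.
D = np.array([
    [0.00, 0.10, 0.20, 0.60],
    [0.10, 0.00, 0.15, 0.50],
    [0.20, 0.15, 0.00, 0.30],
    [0.60, 0.50, 0.30, 0.00],
])

print(rips_simplices(D, 0.15))  # ([(0, 1), (1, 2)], [])
print(rips_simplices(D, 0.20))  # ([(0, 1), (0, 2), (1, 2)], [(0, 1, 2)])
```

Sweeping the threshold from 0 upward and recording when each simplex first appears gives the filtration values that a simplex tree stores.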

In [9]:
%load_ext autoreload
%autoreload 2

In [10]:
pip install numpy torch transformers gudhi matplotlib networkx scipy plotly ipywidgets -q

Note: you may need to restart the kernel to use updated packages.


In [11]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
from ipywidgets import interact, FloatSlider, IntSlider, Text, Dropdown, VBox, Label
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)


def plot_persistence_diagram(persistence, title):
    gd.plot_persistence_diagram(persistence)
    plt.title(title)
    plt.show()
    
def compute_bottleneck_distance(persistence1, persistence2):
    persistence1_array = np.array([(birth, death) for dim, (birth, death) in persistence1 if dim == 0])
    persistence2_array = np.array([(birth, death) for dim, (birth, death) in persistence2 if dim == 0])
    return gd.bottleneck_distance(persistence1_array, persistence2_array)

def get_attention_matrix(text, model, tokenizer, layer, head):
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention

def compute_persistence(attention_matrix):
    softmax_attention = np.exp(attention_matrix) / np.sum(np.exp(attention_matrix), axis=-1)[:, np.newaxis]
    distance_matrix = np.array([[jensenshannon(softmax_attention[i], softmax_attention[j]) for j in range(softmax_attention.shape[0])] for i in range(softmax_attention.shape[0])])
    
    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix

def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    g = nx.Graph()
    for (simplex, _) in simplex_tree.get_filtration():
        if len(simplex) == 2:
            if distance_matrix[simplex[0]][simplex[1]] <= threshold:
                g.add_edge(simplex[0], simplex[1])

    labels = {node: tokens[node] for node in g.nodes()}
    
    pos = nx.spring_layout(g, dim=3, seed=42)
    
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]
    
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]
    
    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")
    
    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)
    
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()

def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

model_dropdown = Dropdown(
    options=[
        ('GPT-2', 'gpt2'),
        ('DistilBERT', 'distilbert-base-uncased'),
        ('Bert', 'bert-base-uncased'),
        ('RoBERTa', 'roberta-base'),
        ('aleph-BeRT', 'onlplab/alephbert-base')
    ],
    value='gpt2',
    description='Model:'
)

tokenizer, model = load_model(model_dropdown.value)

text_input1 = Text(description='Text 1:', value='Quantum information theory is fascinating')
text_input2 = Text(description='Text 2:', value='Quantum information theory allows us to study attention using entanglement')

layer_slider = IntSlider(value=1, min=0, max=model.config.num_hidden_layers - 1, description='Layer:', continuous_update=False)
head_slider = IntSlider(value=2, min=0, max=model.config.num_attention_heads - 1, description='Head:', continuous_update=False)

threshold_slider = FloatSlider(value=0.05, min=0.00, max=0.5, step=0.001, description='Threshold:', continuous_update=False)

# Recompute attention, persistence, and all plots whenever a widget value changes
def update_plot(threshold, text1, text2, layer, head, model_name):
    tokenizer, model = load_model(model_name)
    layer_slider.max = model.config.num_hidden_layers - 1
    head_slider.max = model.config.num_attention_heads - 1

    attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
    attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

    persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
    persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

    # Decode every input id, including any special tokens the tokenizer adds
    # (e.g. [CLS] and [SEP] for BERT), so labels line up with the attention rows.
    tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
    tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)
    plot_persistence_diagram(persistence1, "Persistence Diagram for Text 1")
    plot_persistence_diagram(persistence2, "Persistence Diagram for Text 2")
    # Compute and display bottleneck distance between persistence diagrams
    bottleneck_distance = compute_bottleneck_distance(persistence1, persistence2)
    print("Bottleneck distance between Text 1 and Text 2:", bottleneck_distance)

interact(update_plot, threshold=threshold_slider, text1=text_input1, text2=text_input2, layer=layer_slider, head=head_slider, model_name=model_dropdown)



The persistence diagrams above encode the topological information given by persistent homology at all scales simultaneously. Each dot corresponds to a topological feature that lasts some amount of "time": it has a `birth` and a `death` coordinate, each of which is a value of the scale parameter controlled by the `threshold` slider. 
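
To connect a diagram back to a single slider position, one can count which features are alive at a given scale. The sketch below assumes persistence pairs in the `(dimension, (birth, death))` layout that the code above iterates over; the numerical values are made up for illustration, and `lifetimes` and `betti_numbers_at` are hypothetical helpers.

```python
import math

# Hypothetical persistence pairs: (dimension, (birth, death)).
persistence = [
    (0, (0.0, math.inf)),  # the connected component that never dies
    (0, (0.0, 0.12)),
    (0, (0.0, 0.04)),
    (1, (0.20, 0.35)),     # a loop that appears at 0.20 and is filled in at 0.35
]

def lifetimes(pairs):
    """Persistence (death - birth) of every feature."""
    return [(dim, death - birth) for dim, (birth, death) in pairs]

def betti_numbers_at(pairs, scale):
    """Count features alive at `scale` (birth <= scale < death), per dimension."""
    counts = {}
    for dim, (birth, death) in pairs:
        if birth <= scale < death:
            counts[dim] = counts.get(dim, 0) + 1
    return counts

print(betti_numbers_at(persistence, 0.05))  # {0: 2}
print(betti_numbers_at(persistence, 0.25))  # {0: 1, 1: 1}
```

Dots far from the diagonal have long lifetimes and correspond to structure that is stable across many threshold values.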

---
Now let's run one without the persistence diagrams plotted (and without the bottleneck distance between them).

In [12]:
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
from ipywidgets import interact, FloatSlider, IntSlider, Text, Dropdown, VBox, Label
import warnings
warnings.filterwarnings("ignore", category=UserWarning)


def get_attention_matrix(text, model, tokenizer, layer, head):
    inputs = tokenizer(text, return_tensors="pt")
    outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention

def compute_persistence(attention_matrix):
    softmax_attention = np.exp(attention_matrix) / np.sum(np.exp(attention_matrix), axis=-1)[:, np.newaxis]
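    # scipy's jensenshannon already returns the Jensen-Shannon distance
    # (the square root of the divergence), so no further square root is taken.
    distance_matrix = np.array([[jensenshannon(softmax_attention[i], softmax_attention[j]) for j in range(softmax_attention.shape[0])] for i in range(softmax_attention.shape[0])])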
    
    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix

def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    g = nx.Graph()
    for (simplex, _) in simplex_tree.get_filtration():
        if len(simplex) == 2:
            if distance_matrix[simplex[0]][simplex[1]] <= threshold:
                g.add_edge(simplex[0], simplex[1])

    labels = {node: tokens[node] for node in g.nodes()}
    
    pos = nx.spring_layout(g, dim=3, seed=42)
    
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]
    
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]
    
    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")
    
    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)
    
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()

def load_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

model_dropdown = Dropdown(
    options=[
        ('GPT-2', 'gpt2'),
        ('DistilBERT', 'distilbert-base-uncased'),
        ('Bert', 'bert-base-uncased'),
        ('RoBERTa', 'roberta-base'),
        ('aleph-BeRT', 'onlplab/alephbert-base')
    ],
    value='gpt2',
    description='Model:'
)

tokenizer, model = load_model(model_dropdown.value)

text_input1 = Text(description='Text 1:', value='Quantum information theory is interesting')
text_input2 = Text(description='Text 2:', value='Quantum information theory allows us to study attention using entanglement')

layer_slider = IntSlider(value=1, min=0, max=model.config.num_hidden_layers - 1, description='Layer:', continuous_update=False)
head_slider = IntSlider(value=2, min=0, max=model.config.num_attention_heads - 1, description='Head:', continuous_update=False)

threshold_slider = FloatSlider(value=0.15, min=0.00, max=0.5, step=0.001, description='Threshold:', continuous_update=False)

def update_plot(threshold, text1, text2, layer, head, model_name):
    tokenizer, model = load_model(model_name)
    layer_slider.max = model.config.num_hidden_layers - 1
    head_slider.max = model.config.num_attention_heads - 1

    attention_matrix1 = get_attention_matrix(text1, model, tokenizer, layer, head)
    attention_matrix2 = get_attention_matrix(text2, model, tokenizer, layer, head)

    persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
    persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

    # Decode every input id, including any special tokens the tokenizer adds
    # (e.g. [CLS] and [SEP] for BERT), so labels line up with the attention rows.
    tokens1 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text1)]
    tokens2 = [tokenizer.decode(token_id) for token_id in tokenizer.encode(text2)]

    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Text 1", threshold, tokens1)
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Text 2", threshold, tokens2)

interact(update_plot, threshold=threshold_slider, text1=text_input1, text2=text_input2, layer=layer_slider, head=head_slider, model_name=model_dropdown)




## Contextual Mappings

We will provide a code example using GPT-2 that illustrates the concept of "contextual mapping" as defined in the following passage (see [Are Transformers universal approximators of sequence-to-sequence functions?](https://arxiv.org/abs/1912.10077)): 

Let us consider a setting where we are interested in embedding two sentences: 1) I am happy; and 2) I am Bob. These sentences are fed to a sequence-to-sequence model as $X = [X_{:,1}, X_{:,2}, X_{:,3}] = [v_I, v_{am}, v_{happy}]$ and $X' = [X'_{:,1}, X'_{:,2}, X'_{:,3}] = [v_I, v_{am}, v_{Bob}]$, where $v_I, v_{am}, v_{happy}$, and $v_{Bob}$ denote $d$-dimensional embeddings for the tokens ‘I’, ‘am’, ‘happy’, and ‘Bob’, respectively. Since the word ‘I’ occurs in different contexts in these sentences, in order to implement arbitrary sequence-to-sequence functions, the sequence-to-sequence model should map the two occurrences of ‘I’ to **different values**. We formally define this requirement below. 

**Definition 3.1 (Contextual mapping)**. Consider a finite set $\mathbb{L} \subset \mathbb{R}^{d \times n}$. A map $q : \mathbb{L} \to \mathbb{R}^{1 \times n}$ defines a contextual mapping if the map satisfies the following: 

1. For any $L \in \mathbb{L}$, the $n$ entries in $q(L)$ are all distinct. 
2. For any $L, L' \in \mathbb{L}$, with $L \neq L'$, all entries of $q(L)$ and $q(L')$ are distinct. 

In other words, a contextual mapping maps each token (column) of $L \in \mathbb{L}$ to a unique value which depends on the entire $L$; as a result, capturing the precise context of $L$. This allows the subsequent token-wise function (e.g., defined by the feed-forward layers in case of Transformer networks) to realize the outputs of any arbitrary sequence-to-sequence functions.

While the self-attention layer does consider pair-wise interactions among different input tokens, it is not clear if this weak form of pair-wise interaction with shared projection weights is sufficient to extract the underlying context. The following result, which we sketch here, shows that self-attention layers can implement a **permutation equivariant contextual mapping** over almost all elements of a grid in $[0,1]^{d \times n}$. We defer the full statement to Section 4.2. 

**Lemma 6 (informal)**. Consider the grid $G_{\delta} := \{0, \delta, \ldots, 1 - \delta\}^{d \times n}$. Then, there exist a function $g_c : \mathbb{R}^{d \times n} \to \mathbb{R}^{d \times n}$ composed of $\delta^{-d}+1$ self-attention layers ($h = 2, m = 1$) and a vector $u \in \mathbb{R}^d$ such that $q(L) := u^Tg_c(L)$ satisfies the following properties, for a subset $\tilde{G_{\delta}} \subset G_{\delta}$ that contains almost all elements of $G_{\delta}$: 

1. For any $L \in \tilde{G_{\delta}}$, the entries of $q(L)$ are all distinct. 
2. For any $L, L' \in \tilde{G_{\delta}}$ such that $L$ is not a permutation of $L'$, all entries of $q(L)$, $q(L')$ are distinct. 

Lemma 6 shows that a series of self-attention layers can implement contextual mappings, despite the apparent restriction that each of them can only capture pair-wise interaction. However, the restriction of permutation equivariance still exists because attention layers are inherently permutation equivariant. Coupled with the ability of token-wise feed-forward layers to map different values in $q(L)$ to arbitrary output values, we can prove universal approximation capability of Transformers.
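
The permutation equivariance mentioned in the passage can be checked numerically. The following is a minimal sketch of a single-head softmax self-attention layer in NumPy with random weights and no positional encoding or mask; all names and values are assumptions made for illustration. Unlike the paper's $d \times n$ convention, tokens are stored as rows here, so permuting tokens means permuting rows.

```python
import numpy as np

def self_attention(X, Wq, Wk, Wv):
    """Single-head softmax self-attention; one token per row of X."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
n, d = 5, 4
X = rng.normal(size=(n, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
perm = rng.permutation(n)

out = self_attention(X, Wq, Wk, Wv)
out_perm = self_attention(X[perm], Wq, Wk, Wv)

# Permuting the input tokens permutes the outputs in the same way.
print(np.allclose(out[perm], out_perm))
```

This is why the lemma can only separate sequences that are not permutations of one another; in practice positional encodings break this symmetry.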

This text is discussing the process of embedding sentences in the context of a sequence-to-sequence (seq2seq) model, which is a type of neural network used for various natural language processing tasks. The main idea is to represent the sentences using embeddings, which are numerical representations that capture the meaning of words or tokens. In this case, the sentences "I am happy" and "I am Bob" are being considered.

The sentences are represented as sequences of embeddings: $X$ for the first sentence and $X'$ for the second sentence. The individual tokens ('I', 'am', 'happy', and 'Bob') have corresponding $d$-dimensional embeddings ($v_I, v_{am}, v_{happy}$, and $v_{Bob}$).

The text then introduces the concept of contextual mapping. This is important because the word 'I' appears in both sentences, but its meaning or context is different in each case. To ensure the seq2seq model can handle these different contexts, it should map the two occurrences of 'I' to different values.

A contextual mapping $q$ is a function that maps a sequence of embeddings $L$ to a row vector $q(L) \in \mathbb{R}^{1 \times n}$ with one value per token, ensuring that the same token in different contexts gets different values. Definition 3.1 gives two conditions for a contextual mapping:

1. All entries of $q(L)$ are distinct for any $L$, meaning that each token within a single sequence gets a unique value.
2. If $L$ and $L'$ are different sequences ($L \neq L'$), all entries of $q(L)$ and $q(L')$ are distinct, ensuring that the same token in different sequences gets different values.

By meeting these conditions, the contextual mapping function $q$ can capture the precise context of each token in a sequence. This is important because it enables the seq2seq model to understand and differentiate between the contextual meanings of words in different sentences.

The text also mentions that this contextual mapping can help subsequent token-wise functions, such as the feed-forward layers in Transformer networks, to realize the outputs of any arbitrary sequence-to-sequence functions. In other words, by ensuring that tokens in different contexts are properly distinguished, the seq2seq model can be more effective in tasks like machine translation, sentiment analysis, or text summarization, among others.

In summary, the text explains the concept of contextual mapping in the context of sequence-to-sequence models. This mapping function ensures that each token is assigned a unique value depending on its context within a sentence, which ultimately helps the model understand and process different meanings of the same word in various contexts. This is a crucial aspect of natural language processing, as it allows the model to capture and represent the nuances of human language more effectively.
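
The two conditions of Definition 3.1 can be checked mechanically on a finite set of sequences. The sketch below is a toy illustration: the two-dimensional embeddings, the vector `u`, and both maps are made up, and `q_contextual` is not the construction used in the paper. Taken together, the two conditions say that all entries of all $q(L)$ are pairwise distinct, which is what `is_contextual_mapping` tests.

```python
import numpy as np

def is_contextual_mapping(q, sequences, tol=1e-9):
    """True if all entries of q(L), over all L in `sequences`, are pairwise distinct."""
    values = np.sort(np.concatenate([np.ravel(q(L)) for L in sequences]))
    return bool(np.all(np.diff(values) > tol))

# Hypothetical 2-dimensional embeddings; columns of X are tokens, as in the paper.
v_I, v_am = np.array([1.0, 0.0]), np.array([0.0, 1.0])
v_happy, v_Bob = np.array([1.0, 1.0]), np.array([2.0, 1.0])
X = np.stack([v_I, v_am, v_happy], axis=1)      # "I am happy", shape (d, n) = (2, 3)
X_prime = np.stack([v_I, v_am, v_Bob], axis=1)  # "I am Bob"

u = np.array([1.0, 2.0])

def q_tokenwise(L):
    # Looks at each column in isolation, so 'I' gets the same value in both sentences.
    return u @ L

def q_contextual(L):
    # Adds a term that depends on the whole sequence, shifting every entry.
    return u @ L + 10.0 * L.sum()

print(q_tokenwise(X), q_tokenwise(X_prime))    # [1. 2. 3.] [1. 2. 4.]
print(is_contextual_mapping(q_tokenwise, [X, X_prime]))   # False
print(q_contextual(X), q_contextual(X_prime))  # [41. 42. 43.] [51. 52. 54.]
print(is_contextual_mapping(q_contextual, [X, X_prime]))  # True
```

The token-wise map fails condition 2 because 'I' and 'am' receive the same values in both sentences, whereas the second map gives every token a value that reflects the sentence it sits in.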

In [13]:
pip install transformers -q

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Note: you may need to restart the kernel to use updated packages.


In [14]:
import torch
from transformers import GPT2Tokenizer, GPT2Model

# Initialize the GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2")

# Define two sentences
sentence1 = "The cat sat on the mat."
sentence2 = "The cat chased the mouse."

# Tokenize the sentences
input_ids1 = tokenizer.encode(sentence1, return_tensors="pt")
input_ids2 = tokenizer.encode(sentence2, return_tensors="pt")

# Extract the context vectors (hidden states) for both sentences
with torch.no_grad():
    outputs1 = model(input_ids1)
    outputs2 = model(input_ids2)
    hidden_states1 = outputs1.last_hidden_state
    hidden_states2 = outputs2.last_hidden_state

# Choose the index of the token you want to compare
token_index = 1  # Index 1 corresponds to "cat" in both sentences

# Extract the context vectors for the chosen token in both sentences
context_vector1 = hidden_states1[0, token_index, :]
context_vector2 = hidden_states2[0, token_index, :]

# Compute the cosine similarity between the context vectors
cosine_similarity = torch.nn.functional.cosine_similarity(context_vector1.unsqueeze(0), context_vector2.unsqueeze(0))

print("Cosine similarity:", cosine_similarity.item())


Cosine similarity: 0.9999998807907104


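The cosine similarity is $1$ up to floating-point error, so at this position GPT-2 does not map "*cat*" to meaningfully different context vectors. This is a consequence of GPT-2's masked (causal) self-attention rather than a failure to model context: the vector at index 1 depends only on the tokens at indices 0 and 1, and both sentences begin with "The cat". The two contexts only differ from index 2 onward. Let's look at another example allowing us to compare multiple models to see whether they provide contextual mappings that differ for the same token in different contexts. As we will see, BERT, which attends to the whole sentence in both directions, gives a noticeably lower similarity (about $0.87$) at index 1. Keep in mind that BERT's tokenizer prepends a `[CLS]` token, so index 1 refers to "the" under BERT and "*cat*" sits at index 2.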
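When comparing across models, a given index need not point at the same word under both tokenizers. The following is a minimal sketch of how one might look up the position of a word in a token list. The helper `token_positions` is hypothetical, and the two token lists are written out by hand as we would expect them from `tokenizer.convert_ids_to_tokens(...)` for GPT-2 and `bert-base-uncased`, so that the sketch runs without downloading any model.

```python
def token_positions(tokens, word):
    """Return every index whose token matches `word`, ignoring case and
    the 'Ġ' prefix that GPT-2's byte-level BPE uses to mark a leading space."""
    target = word.lower()
    return [i for i, tok in enumerate(tokens) if tok.lstrip("Ġ").lower() == target]

# Hand-written token strings (an assumption, see the text above).
gpt2_tokens = ["The", "Ġcat", "Ġsat", "Ġon", "Ġthe", "Ġmat", "."]
bert_tokens = ["[CLS]", "the", "cat", "sat", "on", "the", "mat", ".", "[SEP]"]

print(token_positions(gpt2_tokens, "cat"))  # [1]
print(token_positions(bert_tokens, "cat"))  # [2]
```

With real tokenizers, the same lookup could be applied to the output of `convert_ids_to_tokens` before choosing which indices to compare.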

In this next example you can compare the context vectors of a token in two sentences of your choosing, using either model. Note that when the same token appears in both sentences, the cosine similarity should not be equal to $1$: a value of $1$ means the two context vectors point in the same direction, so the model has not separated the two contexts. We want a token appearing in two different contexts to have two different **contextual mappings**, that is, two different context vectors that depend on the sentence providing the context.
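For reference, the quantity computed by `torch.nn.functional.cosine_similarity` for two context vectors $u$ and $v$ is the usual cosine of the angle between them. It is also worth noting that the value printed above, $0.9999998807907104$, is exactly $1 - 2^{-23}$, which is within float32 rounding error of $1$:

$$
\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}, \qquad 1 - 2^{-23} = 0.9999998807907104\ldots
$$

A cosine similarity of $1$ only says that $u$ and $v$ are parallel. They coincide when they also have the same norm, which the cosine alone does not check.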

In [15]:
import torch
from transformers import GPT2Tokenizer, GPT2Model, BertTokenizer, BertModel
import ipywidgets as widgets

# Define a function to compute context vectors based on the selected model
def compute_context_vectors(model_name, sentence1, sentence2, token_index1, token_index2):
    if model_name == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        model = GPT2Model.from_pretrained("gpt2")
    elif model_name == "bert":
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        model = BertModel.from_pretrained("bert-base-uncased")

    input_ids1 = tokenizer.encode(sentence1, return_tensors="pt")
    input_ids2 = tokenizer.encode(sentence2, return_tensors="pt")

    with torch.no_grad():
        outputs1 = model(input_ids1)
        outputs2 = model(input_ids2)
        hidden_states1 = outputs1.last_hidden_state
        hidden_states2 = outputs2.last_hidden_state

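    # Compare the token ids at the chosen positions, not the positions themselves
    if input_ids1[0, token_index1].item() != input_ids2[0, token_index2].item():
        return "Tokens must be the same"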
    else:
        context_vector1 = hidden_states1[0, token_index1, :]
        context_vector2 = hidden_states2[0, token_index2, :]

        cosine_similarity = torch.nn.functional.cosine_similarity(context_vector1.unsqueeze(0), context_vector2.unsqueeze(0))

        return cosine_similarity.item()

#Create two text boxes for entering sentences
sentence1_box = widgets.Text(
    value='The cat sat on the mat.',
    placeholder='Enter sentence 1',
    description='Sentence 1:',
    disabled=False
)

sentence2_box = widgets.Text(
    value='The cat chased the mouse.',
    placeholder='Enter sentence 2',
    description='Sentence 2:',
    disabled=False
)

# Choose the index of the token you want to compare
token_index_box1 = widgets.IntText(
    value=1,
    description='Token index 1:',
    disabled=False
)

token_index_box2 = widgets.IntText(
    value=1,
    description='Token index 2:',
    disabled=False
)

#Create a dropdown menu for model selection
model_dropdown = widgets.Dropdown(
    options=["gpt2", "bert"],
    value="gpt2",
    description="Model:",
    disabled=False,
)

#Define a function to handle model selection changes
def on_model_change(change):
    if change["type"] == "change" and change["name"] == "value":
        model_name = change["new"]
        cosine_similarity = compute_context_vectors(model_name, sentence1_box.value, sentence2_box.value, token_index_box1.value, token_index_box2.value)
        print(f"Cosine similarity for {model_name}: {cosine_similarity}")
    
#Add an event listener for the dropdown menu
model_dropdown.observe(on_model_change)

#Create a button to trigger the computation of cosine similarity
button = widgets.Button(description="Compute cosine similarity")

def on_button_click(b):
    cosine_similarity = compute_context_vectors(model_dropdown.value, sentence1_box.value, sentence2_box.value, token_index_box1.value, token_index_box2.value)
    print(f"Cosine similarity for {model_dropdown.value}: {cosine_similarity}")
    
button.on_click(on_button_click)

#Display the UI
display(sentence1_box)
display(sentence2_box)
display(token_index_box1)
display(token_index_box2)
display(model_dropdown)
display(button)

Text(value='The cat sat on the mat.', description='Sentence 1:', placeholder='Enter sentence 1')

Text(value='The cat chased the mouse.', description='Sentence 2:', placeholder='Enter sentence 2')

IntText(value=1, description='Token index 1:')

IntText(value=1, description='Token index 2:')

Dropdown(description='Model:', options=('gpt2', 'bert'), value='gpt2')

Button(description='Compute cosine similarity', style=ButtonStyle())

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Cosine similarity for bert: 0.8720442056655884


Cosine similarity for gpt2: 0.9999998807907104
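
The gap between the two similarities at a shared prefix can be reproduced with a toy attention layer. The following is a minimal sketch, assuming a single attention head with identity query, key and value maps and hand-picked 2-d embeddings (all hypothetical; there are no positional encodings, layer norms or feed-forward layers, so this is not GPT-2 or BERT). With a causal mask, the output at index 1 sees only the shared prefix and is identical in both sentences. Without the mask, the third token changes it.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(embeddings, causal):
    """Toy single-head self-attention with identity query/key/value maps."""
    n, d = len(embeddings), len(embeddings[0])
    outputs = []
    for i in range(n):
        # With a causal mask, position i may only attend to positions 0..i.
        visible = range(i + 1) if causal else range(n)
        scores = [
            sum(a * b for a, b in zip(embeddings[i], embeddings[j])) / math.sqrt(d)
            for j in visible
        ]
        weights = softmax(scores)
        outputs.append(
            [sum(w * embeddings[j][k] for w, j in zip(weights, visible)) for k in range(d)]
        )
    return outputs

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

# Hypothetical 2-d embeddings for a toy vocabulary.
emb = {"the": [1.0, 0.0], "cat": [0.0, 1.0], "sat": [1.0, 1.0], "chased": [-1.0, 0.5]}
s1 = [emb[w] for w in ["the", "cat", "sat"]]
s2 = [emb[w] for w in ["the", "cat", "chased"]]

# Compare the output at index 1 ("cat") across the two sentences.
causal_sim = cosine(self_attention(s1, True)[1], self_attention(s2, True)[1])
full_sim = cosine(self_attention(s1, False)[1], self_attention(s2, False)[1])
print("causal:", causal_sim)
print("bidirectional:", full_sim)
```

In this toy setting the causal similarity is $1$ up to rounding, while the bidirectional one is clearly smaller. This suggests that a fairer comparison for GPT-2 would use a token that comes after the point where the two sentences diverge.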
