# Persistent Homology of Attention in Vision

![simplicial_complex_image.png](simplicial_complex_image.png)

This notebook analyzes and compares the attention maps that a Vision Transformer (ViT) model produces for two images. The analysis uses topological data analysis (TDA), a family of mathematical techniques for studying the shape of data, to examine the structure of those attention maps.

Persistent homology is the TDA technique used here. It quantifies the topological features of data across different scales, which gives a way to summarize how the model attends to different parts of the input images.

The code constructs simplicial complexes using the Rips complex method, which is built on the pairwise distances between the rows of the attention matrix. Each row is the attention distribution of one token (an image patch, or the extra classification token that ViT prepends), and the distance between two rows is derived from the Jensen-Shannon divergence. The Rips complex captures the relationships between patches based on their attention values: each simplex represents a group of patches whose attention distributions are all within the chosen distance threshold of one another.
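To make the row-to-row distance concrete, here is a minimal pure-Python sketch of a Jensen-Shannon distance matrix. The helper names and the toy 3-token attention matrix are hypothetical and exist only for illustration; the notebook itself relies on `scipy.spatial.distance.jensenshannon`.

```python
import math


def js_distance(p, q):
    """Jensen-Shannon distance: square root of the JS divergence (natural log)."""
    m = [(a + b) / 2 for a, b in zip(p, q)]

    def kl(x, y):
        # Terms with x_i == 0 contribute nothing to the KL divergence.
        return sum(a * math.log(a / b) for a, b in zip(x, y) if a > 0)

    divergence = 0.5 * kl(p, m) + 0.5 * kl(q, m)
    # Clamp tiny negative round-off before taking the square root.
    return math.sqrt(max(divergence, 0.0))


def distance_matrix(rows):
    """Pairwise JS distances between the rows of a row-stochastic matrix."""
    return [[js_distance(p, q) for q in rows] for p in rows]


# Hypothetical attention matrix: each row sums to 1.
toy_attention = [
    [0.8, 0.1, 0.1],
    [0.7, 0.2, 0.1],
    [0.1, 0.1, 0.8],
]
D = distance_matrix(toy_attention)
```

Rows 0 and 1 attend to nearly the same tokens, so `D[0][1]` is small, while row 2 attends elsewhere and sits farther from both.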

Persistence diagrams are then generated from the simplicial complexes, representing the topological features (e.g., connected components) of the data across different scales. Each diagram consists of points whose coordinates (birth, death) give the scale at which a topological feature appears and the scale at which it disappears. For connected components, every patch starts as its own component at scale 0, and a component dies at the threshold distance at which it merges with another component.
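The birth and death of connected components can be sketched without any TDA library: sort the edges by length and merge components with a union-find structure, recording the edge length at each merge. This is only an illustrative sketch with hypothetical names and a toy distance matrix; the notebook computes persistence with GUDHI's simplex tree, which also covers higher-dimensional features.

```python
def h0_persistence(dist):
    """(birth, death) pairs of connected components in a Rips filtration."""
    n = len(dist)
    parent = list(range(n))

    def find(i):
        # Follow parent pointers to the root, shortening the path on the way.
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # Edges enter the filtration in order of increasing length.
    edges = sorted((dist[i][j], i, j) for i in range(n) for j in range(i + 1, n))
    pairs = []
    for length, i, j in edges:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            # Two components merge: one of them dies at this scale.
            parent[root_i] = root_j
            pairs.append((0.0, length))
    # One component survives at every scale.
    pairs.append((0.0, float("inf")))
    return pairs


# Toy distances between three points placed at 0, 1 and 5 on a line.
toy_dist = [
    [0.0, 1.0, 5.0],
    [1.0, 0.0, 4.0],
    [5.0, 4.0, 0.0],
]
diagram = h0_persistence(toy_dist)
```

Here the two nearby points merge at scale 1, the distant point joins at scale 4, and the final component never dies.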

Persistent homology is useful for analyzing vision transformers because it provides a way to study the structure and organization of the attention mechanisms in these models. By examining the topological features of the attention maps, we can gain insights into how the model attends to different parts of the input images and how the attention is organized across the image patches.

Comparing the persistence diagrams of two different images processed by the same ViT model can reveal similarities and differences in the attention patterns the model generates for each image. The bottleneck distance, which this code computes on the dimension-0 features (connected components), is a metric that quantifies how different two persistence diagrams are. A smaller bottleneck distance indicates that the topological features of the two attention maps are more similar, which suggests that the model organizes its attention over the two images in a similar way.
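For intuition, the bottleneck distance can be written as a brute-force search over matchings, where any point may also be matched to the diagonal at a cost of half its persistence. The sketch below is an assumption-laden illustration (hypothetical function name, finite points only, feasible only for a handful of points); the notebook uses `gd.bottleneck_distance` instead.

```python
from itertools import permutations


def bottleneck_bruteforce(diag1, diag2):
    """Bottleneck distance between two tiny diagrams of finite (birth, death) points."""
    n1, n2 = len(diag1), len(diag2)
    size = n1 + n2  # each side is padded with diagonal slots for the other side

    def cost(i, j):
        p = diag1[i] if i < n1 else None  # None stands for a diagonal slot
        q = diag2[j] if j < n2 else None
        if p is None and q is None:
            return 0.0
        if p is None:
            return (q[1] - q[0]) / 2  # L-infinity distance from q to the diagonal
        if q is None:
            return (p[1] - p[0]) / 2
        return max(abs(p[0] - q[0]), abs(p[1] - q[1]))

    best = float("inf")
    for perm in permutations(range(size)):
        # The cost of a matching is its single worst pairing.
        worst = max((cost(i, perm[i]) for i in range(size)), default=0.0)
        best = min(best, worst)
    return best


diag_a = [(0.0, 1.0), (0.0, 0.2)]
diag_b = [(0.0, 0.9)]
distance = bottleneck_bruteforce(diag_a, diag_b)
```

In this toy case the long-lived points are matched to each other and the short-lived point in `diag_a` is matched to the diagonal, so the worst pairing determines the distance.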

In summary, using persistent homology to analyze the attention mechanisms of vision transformers can help us:

1. Understand the structure and organization of the attention mechanisms within the model and how they change across different layers and heads.
2. Identify similarities and differences in the attention patterns generated by the model for different input images.
3. Evaluate the robustness and interpretability of the model's attention mechanisms by studying the topological features and their stability across different scales.

Overall, the application of persistent homology to study vision transformers can contribute to a better understanding of these models, potentially leading to improvements in their performance, robustness, and interpretability.

---

Run the cells below and upload two photos to compare. Once both photos are uploaded, select a layer and head to study, then select a scale with the threshold slider. This gives a simplicial complex (technically its $1$-skeleton) that represents the topology of the attention scores at that scale. The persistence diagrams are also computed, along with the bottleneck distance between them.
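As a rough sketch of what selecting a scale means, the $1$-skeleton at a given threshold keeps exactly the edges whose distance does not exceed that threshold. The function names and the toy distance matrix below are hypothetical and are not part of the notebook.

```python
def one_skeleton(dist, threshold):
    """Edges (i, j) whose distance is at most the threshold."""
    n = len(dist)
    return [(i, j) for i in range(n) for j in range(i + 1, n) if dist[i][j] <= threshold]


def count_components(n, edges):
    """Number of connected components of a graph on vertices 0..n-1."""
    adjacency = {i: set() for i in range(n)}
    for i, j in edges:
        adjacency[i].add(j)
        adjacency[j].add(i)
    seen, components = set(), 0
    for start in range(n):
        if start in seen:
            continue
        components += 1
        stack = [start]  # depth-first search from an unvisited vertex
        while stack:
            node = stack.pop()
            if node in seen:
                continue
            seen.add(node)
            stack.extend(adjacency[node] - seen)
    return components


# Toy distances between three points placed at 0, 1 and 5 on a line.
toy_dist = [
    [0.0, 1.0, 5.0],
    [1.0, 0.0, 4.0],
    [5.0, 4.0, 0.0],
]
for threshold in (0.5, 1.0, 4.0):
    edges = one_skeleton(toy_dist, threshold)
    print(threshold, edges, count_components(len(toy_dist), edges))
```

Raising the threshold adds edges and merges components, which is the same process the persistence diagram summarizes across all scales at once.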

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
pip install numpy torch transformers gudhi matplotlib networkx scipy plotly ipywidgets pillow -q

Note: you may need to restart the kernel to use updated packages.


In [3]:
import numpy as np
import torch
from transformers import AutoModel, ViTFeatureExtractor
from PIL import Image
from io import BytesIO
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
from ipywidgets import interact, FloatSlider, IntSlider, Dropdown, VBox, Label, FileUpload
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

def plot_persistence_diagram(persistence, title):
    gd.plot_persistence_diagram(persistence)
    plt.title(title)
    plt.show()
    
def compute_bottleneck_distance(persistence1, persistence2):
    # Compare only the dimension-0 features (connected components) of each diagram
    persistence1_array = np.array([(birth, death) for dim, (birth, death) in persistence1 if dim == 0])
    persistence2_array = np.array([(birth, death) for dim, (birth, death) in persistence2 if dim == 0])
    return gd.bottleneck_distance(persistence1_array, persistence2_array)

def load_model(model_name):
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return feature_extractor, model

# Vision transformer checkpoints available for analysis
model_dropdown = Dropdown(
    options=[
        ('ViT Base Patch16 224', 'google/vit-base-patch16-224'),
        ('ViT Large Patch16 224', 'google/vit-large-patch16-224')
    ],
    value='google/vit-base-patch16-224',
    description='Model:'
)

feature_extractor, model = load_model(model_dropdown.value)

# File upload widgets for the two images to compare
image_input1 = FileUpload(description='Image 1:')
image_input2 = FileUpload(description='Image 2:')

layer_slider = IntSlider(value=1, min=0, max=model.config.num_hidden_layers - 1, description='Layer:', continuous_update=False)
head_slider = IntSlider(value=2, min=0, max=model.config.num_attention_heads - 1, description='Head:', continuous_update=False)

threshold_slider = FloatSlider(value=0.05, min=0.00, max=0.5, step=0.001, description='Threshold:', continuous_update=False)


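def compute_persistence(attention_matrix):
    # The attention weights returned with output_attentions=True are already
    # softmax-normalized, so each row is used directly as a probability distribution.
    n = attention_matrix.shape[0]
    # scipy's jensenshannon returns the Jensen-Shannon distance (the square root
    # of the divergence), so no further square root is applied.
    distance_matrix = np.array([[jensenshannon(attention_matrix[i], attention_matrix[j]) for j in range(n)] for i in range(n)])
    # Guard against possible NaNs from floating-point round-off on near-identical rows
    distance_matrix = np.nan_to_num(distance_matrix)

    # Build the Rips complex up to triangles and compute its persistence
    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix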

def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    g = nx.Graph()
    for (simplex, _) in simplex_tree.get_filtration():
        if len(simplex) == 2:
            if distance_matrix[simplex[0]][simplex[1]] <= threshold:
                g.add_edge(simplex[0], simplex[1])

    labels = {node: tokens[node] for node in g.nodes()}
    
    pos = nx.spring_layout(g, dim=3, seed=42)
    
    Xn = [pos[k][0] for k in g.nodes()]
    Yn = [pos[k][1] for k in g.nodes()]
    Zn = [pos[k][2] for k in g.nodes()]
    
    Xe = []
    Ye = []
    Ze = []
    for e in g.edges():
        Xe += [pos[e[0]][0], pos[e[1]][0], None]
        Ye += [pos[e[0]][1], pos[e[1]][1], None]
        Ze += [pos[e[0]][2], pos[e[1]][2], None]
    
    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    
    trace_nodes = go.Scatter3d(x=Xn, y=Yn, z=Zn, mode='markers+text', text=list(labels.values()), marker=dict(symbol='circle', size=10, color='lightblue'), textposition="top center")
    
    layout = go.Layout(title=title, scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'), showlegend=False)
    
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()
    

def get_attention_matrix(image, model, feature_extractor, layer, head):
    # Convert the image to RGB format if it is not already
    if image.mode != "RGB":
        image = image.convert("RGB")

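    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only, so skip gradient tracking
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Attention weights of one head in one layer, shape (num_tokens, num_tokens)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()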

    return attention


def update_plot(threshold, layer, head, model_name):
    if len(image_input1.data) == 0 or len(image_input2.data) == 0:
        print("Please upload both images before proceeding.")
        return

    # Reload the selected model and keep the slider ranges in sync with it
    feature_extractor, model = load_model(model_name)
    layer_slider.max = model.config.num_hidden_layers - 1
    head_slider.max = model.config.num_attention_heads - 1

    # Load images from the uploaded files
    image1 = Image.open(BytesIO(image_input1.data[-1]))
    image2 = Image.open(BytesIO(image_input2.data[-1]))

    attention_matrix1 = get_attention_matrix(image1, model, feature_extractor, layer, head)
    attention_matrix2 = get_attention_matrix(image2, model, feature_extractor, layer, head)

    persistence1, simplex_tree1, distance_matrix1 = compute_persistence(attention_matrix1)
    persistence2, simplex_tree2, distance_matrix2 = compute_persistence(attention_matrix2)

    plot_simplicial_complex_3d(simplex_tree1, distance_matrix1, "Simplicial Complex for Image 1", threshold, range(attention_matrix1.shape[0]))
    plot_simplicial_complex_3d(simplex_tree2, distance_matrix2, "Simplicial Complex for Image 2", threshold, range(attention_matrix2.shape[0]))
    
    plot_persistence_diagram(persistence1, "Persistence Diagram for Image 1")
    plot_persistence_diagram(persistence2, "Persistence Diagram for Image 2")

    # Compute and display bottleneck distance between persistence diagrams
    bottleneck_distance = compute_bottleneck_distance(persistence1, persistence2)
    print("Bottleneck distance between Image 1 and Image 2:", bottleneck_distance)

interact(update_plot, threshold=threshold_slider, layer=layer_slider, head=head_slider, model_name=model_dropdown)

# Display the image upload widgets
VBox([image_input1, image_input2])


Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.bias', 'vit.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

