In [3]:
import numpy as np
import torch
from transformers import ViTModel, ViTFeatureExtractor
import gudhi as gd
import networkx as nx
from scipy.spatial.distance import jensenshannon
import plotly.graph_objs as go
from PIL import Image
import ipywidgets as widgets


def get_attention_matrix(image, model, tokenizer, layer, head):
    """Return the attention weights of one head for a single image.

    Args:
        image: The image to analyse; it is handed to ``tokenizer`` as a
            one-element batch.
        model: Callable model. It is invoked with the preprocessor output as
            keyword arguments plus ``output_attentions=True`` and must return
            an object exposing an ``attentions`` sequence of tensors.
        tokenizer: Callable preprocessor, invoked as
            ``tokenizer(images=[image], return_tensors="pt")``.
        layer: Index into ``outputs.attentions`` selecting the layer.
        head: Index of the attention head inside the selected layer.

    Returns:
        numpy.ndarray: ``outputs.attentions[layer][0, head]`` (first batch
        element, requested head) copied to CPU memory.
    """
    inputs = tokenizer(images=[image], return_tensors="pt")
    # Pure inference: disabling autograd avoids recording a computation graph
    # (and keeping every intermediate activation alive) for a result that is
    # detached immediately afterwards anyway.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    attention = outputs.attentions[layer][0, head].detach().cpu().numpy()
    return attention

def compute_persistence(attention_matrix):
    """Build a Rips complex over the rows of an attention matrix and compute its persistence.

    Each row is turned into a probability distribution with a row-wise
    softmax; the pairwise dissimilarity between rows is
    ``sqrt(jensenshannon(row_i, row_j))``. A Rips complex (simplices up to
    dimension 2, no edge-length cut-off) is built from that matrix.

    Args:
        attention_matrix: 2-D array; one row per token.

    Returns:
        tuple: ``(persistence, simplex_tree, distance_matrix)`` where
        ``persistence`` is the result of
        ``simplex_tree.persistence(min_persistence=0.01)``, ``simplex_tree``
        is the GUDHI simplex tree and ``distance_matrix`` is the square,
        symmetric array of pairwise row dissimilarities.
    """
    # Row-wise softmax. Subtracting the row maximum is mathematically a no-op
    # but keeps np.exp from overflowing (inf / inf -> NaN) on large inputs.
    # NOTE(review): Hugging Face documents `attentions` as post-softmax
    # weights, so this second softmax flattens them -- confirm it is intended.
    shifted = attention_matrix - np.max(attention_matrix, axis=-1, keepdims=True)
    exp_attention = np.exp(shifted)
    softmax_attention = exp_attention / np.sum(exp_attention, axis=-1, keepdims=True)

    # The dissimilarity is symmetric and zero on the diagonal, so only the
    # upper triangle is evaluated and then mirrored (half the scipy calls).
    # NOTE(review): scipy's jensenshannon already returns the JS *distance*
    # (square root of the divergence); the extra np.sqrt is kept so existing
    # threshold values keep their meaning -- confirm it is intended.
    n_tokens = softmax_attention.shape[0]
    distance_matrix = np.zeros((n_tokens, n_tokens), dtype=softmax_attention.dtype)
    for i in range(n_tokens):
        for j in range(i + 1, n_tokens):
            distance = np.sqrt(jensenshannon(softmax_attention[i], softmax_attention[j]))
            distance_matrix[i, j] = distance
            distance_matrix[j, i] = distance

    rips_complex = gd.RipsComplex(distance_matrix=distance_matrix, max_edge_length=np.inf)
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistence = simplex_tree.persistence(min_persistence=0.01)
    return persistence, simplex_tree, distance_matrix

def process_image(image_path, model, tokenizer, layer, head):
    """Load an image and compute the persistence data for one attention head.

    Args:
        image_path: Path of the image file to open.
        model: Model forwarded to ``get_attention_matrix``.
        tokenizer: Image preprocessor forwarded to ``get_attention_matrix``.
        layer: Attention layer index forwarded to ``get_attention_matrix``.
        head: Attention head index forwarded to ``get_attention_matrix``.

    Returns:
        tuple: ``(persistence, simplex_tree, distance_matrix, tokens)`` -- the
        three values returned by ``compute_persistence`` plus ``tokens``, the
        list of row indices ``0 .. n-1`` of the attention matrix.
    """
    # The context manager releases the file handle deterministically.
    # convert("RGB") forces the pixel data to load before the file closes and
    # guarantees three channels for grayscale / RGBA / CMYK inputs.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    attention_matrix = get_attention_matrix(image, model, tokenizer, layer, head)
    persistence, simplex_tree, distance_matrix = compute_persistence(attention_matrix)
    tokens = list(range(attention_matrix.shape[0]))
    return persistence, simplex_tree, distance_matrix, tokens

def plot_simplicial_complex_3d(simplex_tree, distance_matrix, title, threshold, tokens):
    """Show the edges of the complex no longer than ``threshold`` as a 3-D Plotly graph.

    Only vertices that take part in at least one retained edge are drawn.
    Node positions come from a seeded 3-D spring layout, so the picture is
    reproducible for a given edge set.

    Args:
        simplex_tree: GUDHI simplex tree; assumed to have been built from
            ``distance_matrix`` (as ``compute_persistence`` does), so that
            edges appear in non-decreasing distance order in the filtration.
        distance_matrix: Square matrix indexed as ``distance_matrix[i][j]``.
        title: Figure title.
        threshold: Maximum distance for an edge to be drawn.
        tokens: Sequence mapping a vertex index to its text label.

    Returns:
        None. The figure is displayed via ``fig.show()``.
    """
    g = nx.Graph()
    for simplex, _ in simplex_tree.get_filtration():
        if len(simplex) != 2:
            continue
        u, v = simplex[0], simplex[1]
        if distance_matrix[u][v] <= threshold:
            g.add_edge(u, v)
        else:
            # get_filtration() yields simplices by increasing filtration
            # value, so every later edge is at least this long: stop here
            # instead of walking the (much larger) remainder of the complex.
            break

    pos = nx.spring_layout(g, dim=3, seed=42)

    nodes = list(g.nodes())
    node_text = [tokens[node] for node in nodes]
    Xn = [pos[node][0] for node in nodes]
    Yn = [pos[node][1] for node in nodes]
    Zn = [pos[node][2] for node in nodes]

    # Each edge contributes (start, end, None); the None breaks the line so
    # consecutive edges are not joined into one polyline.
    Xe, Ye, Ze = [], [], []
    for start, end in g.edges():
        Xe += [pos[start][0], pos[end][0], None]
        Ye += [pos[start][1], pos[end][1], None]
        Ze += [pos[start][2], pos[end][2], None]

    trace_edges = go.Scatter3d(x=Xe, y=Ye, z=Ze, mode='lines', line=dict(color='gray', width=1))
    trace_nodes = go.Scatter3d(
        x=Xn,
        y=Yn,
        z=Zn,
        mode='markers+text',
        text=node_text,
        marker=dict(symbol='circle', size=10, color='lightblue'),
        textposition="top center",
    )
    layout = go.Layout(
        title=title,
        scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
        showlegend=False,
    )
    fig = go.Figure(data=[trace_edges, trace_nodes], layout=layout)
    fig.show()


In [6]:
# Pre-trained Vision Transformer and its matching image preprocessor. The
# preprocessor is bound to `tokenizer` because that is the parameter name the
# helper functions use for it.
tokenizer = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')
model = ViTModel.from_pretrained('google/vit-base-patch16-224')

# Images to analyse
image_paths = ["heart.jpg", "platonic_metatron.jpg"]

# Attention layer and head whose weights are inspected
layer, head = 1, 2

# One (persistence, simplex_tree, distance_matrix, tokens) tuple per image
results = [
    process_image(path, model, tokenizer, layer, head)
    for path in image_paths
]

# Slider controlling the edge-length threshold; the plot only refreshes when
# the handle is released (continuous_update=False).
threshold_slider = widgets.FloatSlider(
    description='Threshold:',
    min=0.00001,
    max=0.2,
    step=0.00001,
    value=0.001,
    continuous_update=False,
)

# Slider callback: redraw every image's complex at the selected threshold
def update_plot(threshold):
    """Plot the thresholded simplicial complex of each entry in ``results``.

    Args:
        threshold: Maximum edge length passed to ``plot_simplicial_complex_3d``.
    """
    for image_number, result in enumerate(results, start=1):
        _persistence, simplex_tree, distance_matrix, tokens = result
        figure_title = f"Simplicial Complex for Image {image_number}"
        plot_simplicial_complex_3d(simplex_tree, distance_matrix, figure_title, threshold, tokens)

# Connect the slider to the update_plot function. interact() returns the
# wrapped function; the trailing semicolon keeps the notebook from echoing
# that return value as the cell's output.
widgets.interact(update_plot, threshold=threshold_slider);

Some weights of the model checkpoint at google/vit-base-patch16-224 were not used when initializing ViTModel: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


interactive(children=(FloatSlider(value=0.001, continuous_update=False, description='Threshold:', max=0.2, min…

<function __main__.update_plot(threshold)>