# Visualizing Context Vectors of ESM2 for Motif Identification

In this notebook, we give a few examples of a visualization tool for the $1$-skeleton (graph) of the filtered simplicial complex built from ESM2 context vectors when computing their persistent homology. At a distance threshold $\varepsilon$, two tokens are joined by an edge whenever their context vectors are at most $\varepsilon$ apart, as in a Vietoris-Rips complex. A higher-dimensional simplex (triangle, tetrahedron, etc.) is born once all of its edges are present. This lets us look for simplices with low birth times, that is, simplices that appear at small distance thresholds. Such simplices can indicate which subsequences of the protein sequence might be motifs. Just as we analyze the persistent homology of keyphrases and collocations across multiple contexts in text, we can analyze the persistent homology of molecular motifs across different proteins, or of functional groups across different molecules. When we find a simplex with a low birth time, we can examine how the model treats the same motif or functional group in another context, that is, in a different protein sequence. Suppose the persistent homology is stable across contexts for models known to perform better on protein-sequence tasks such as protein folding. That would suggest that invariance of persistent homology across contexts is an important property for a model to have. We might then include a topological term in the training loss, or use this invariance for knowledge distillation, where the student model mimics the persistent homology of the teacher model's representations.
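In a Vietoris-Rips filtration, a simplex is born at the largest pairwise distance among its vertices. The sketch below assumes only a precomputed symmetric distance matrix, like the `distances` array in the cells below. It enumerates all simplices of one dimension and sorts them by birth time. `lowest_birth_simplices` is a hypothetical helper, not part of the notebook. Brute-force enumeration is only practical for short sequences. For longer ones, a library such as GUDHI (its `RipsComplex` class) is the better tool.

```python
import itertools
import numpy as np

def lowest_birth_simplices(distances, dim, top_k=5):
    """Return the top_k dim-simplices with the smallest Vietoris-Rips birth times.

    Assumes `distances` is a symmetric (n, n) matrix of pairwise distances
    and dim >= 1 (vertices are all born at 0).
    """
    if dim < 1:
        raise ValueError("dim must be at least 1")
    n = distances.shape[0]
    births = []
    for vertices in itertools.combinations(range(n), dim + 1):
        # Birth time = length of the longest edge among the simplex's vertices
        birth = max(distances[i, j] for i, j in itertools.combinations(vertices, 2))
        births.append((float(birth), vertices))
    births.sort()
    return births[:top_k]

# Toy example: points 0, 1, 2 are close together, point 3 is far away
points = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
D = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
print(lowest_birth_simplices(D, dim=2, top_k=1))
```

Applied to the notebook's `distances`, low-birth triangles and tetrahedra are the candidate motifs to inspect in the interactive plots.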

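In the example visualizations below, try setting the distance threshold to a low value. Then look for triangles or tetrahedra (groups of three or four tokens that are all connected to each other) that form at low distances. These may indicate a candidate motif or functional group in the protein sequence. You can also change the `layer` and `head` arguments and substitute your own protein sequence. The points are positioned by a 3D projection (t-SNE or UMAP), while edges are drawn from distances in the original high-dimensional space. So two points can look close in the plot without being connected, and vice versa.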
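To read a candidate motif off the plot, the token indices of a simplex have to be mapped back to residue positions. The ESM2 tokenizer prepends a `<cls>` token and appends an `<eos>` token, so token index `i` corresponds to residue `i - 1` of the input sequence. A minimal sketch follows, assuming one token per residue. The helper name `simplex_to_residues` is hypothetical.

```python
def simplex_to_residues(sequence, token_indices):
    """Map a simplex's token indices to (1-based residue positions, covering subsequence).

    Assumes one token per residue plus a leading <cls> and a trailing <eos>
    special token, as added by the ESM2 tokenizer.
    """
    n_tokens = len(sequence) + 2
    # Drop the special tokens at index 0 (<cls>) and n_tokens - 1 (<eos>)
    residue_idx = sorted(i - 1 for i in token_indices if 0 < i < n_tokens - 1)
    if not residue_idx:
        return [], ""
    positions = [r + 1 for r in residue_idx]  # 1-based positions, as is usual for proteins
    span = sequence[residue_idx[0]:residue_idx[-1] + 1]
    return positions, span

# A triangle on tokens 5, 6, 7 of "MGWGRKRR" covers residues 5 to 7
print(simplex_to_residues("MGWGRKRR", (5, 6, 7)))
```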

In [None]:
%pip install umap-learn gudhi plotly -q

## Using t-SNE

In [None]:
import torch
import gudhi as gd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from scipy.spatial import distance_matrix
from mpl_toolkits.mplot3d import Axes3D
from transformers import AutoTokenizer, EsmModel
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE


tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Load the pre-trained model once; eval mode disables dropout
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

def compute_output(sequence, layer, head):
    """Return the context vectors O = S V of one attention head, plus the tokenized inputs."""
    # Tokenize input (the tokenizer adds <cls> and <eos> special tokens)
    inputs = tokenizer(sequence, return_tensors="pt")

    with torch.no_grad():
        # Forward pass, requesting attention weights and hidden states
        outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

        # Attention weights S of the given layer and head, shape (seq_len, seq_len)
        S = outputs.attentions[layer][0, head]

        # hidden_states[layer] is the input to encoder layer `layer`
        # (hidden_states[0] is the embedding output)
        hidden_states = outputs.hidden_states[layer]

        # ESM2 is pre-LayerNorm: the value projection (which has a bias)
        # is applied to the layer-normalized input
        attention = model.encoder.layer[layer].attention
        V_all = attention.self.value(attention.LayerNorm(hidden_states))

        # Split the value vectors into heads and keep the requested head,
        # shape (1, seq_len, head_dim)
        num_heads = model.config.num_attention_heads
        head_dim = model.config.hidden_size // num_heads
        V = V_all.view(*V_all.shape[:2], num_heads, head_dim)[:, :, head, :]

        # Context vectors O = S V, shape (1, seq_len, head_dim)
        O = torch.matmul(S, V)

    return O, inputs

# Compute the output
output, inputs = compute_output("MGWGRKRR", 0, 0)

# Convert the output tensor to numpy array
output_np = output.detach().numpy()[0]

# Compute the pairwise Euclidean distance matrix
distances = distance_matrix(output_np, output_np)

# Compute the number of tokens
num_tokens = len(inputs["input_ids"][0])

# Decompose output to 3D for visualization
# pca = PCA(n_components=3)
# output_3d = pca.fit_transform(output_np)

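# Reduce the context vectors to 3D for visualization using t-SNE.
# Perplexity must be smaller than the number of points, so cap it at num_tokens - 1
# (30 is scikit-learn's default, used for longer sequences).
tsne = TSNE(n_components=3, perplexity=min(30, num_tokens - 1), random_state=0)
output_3d = tsne.fit_transform(output_np)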

# Define the function to update the plot
def update_plot(eps):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(output_3d[:, 0], output_3d[:, 1], output_3d[:, 2])
    
    # Add labels
    for i, token_id in enumerate(inputs['input_ids'][0]):
        ax.text(output_3d[i, 0], output_3d[i, 1], output_3d[i, 2], tokenizer.decode([token_id]))
    
    # Add edges of the 1-skeleton: connect tokens whose distance in the original space is at most eps
    for i in range(distances.shape[0]):
        for j in range(i+1, distances.shape[1]):
            if distances[i, j] <= eps:
                ax.plot([output_3d[i, 0], output_3d[j, 0]], 
                        [output_3d[i, 1], output_3d[j, 1]], 
                        [output_3d[i, 2], output_3d[j, 2]], 'b-')

    plt.show()

# Create the interactive widget for the threshold distance parameter
eps_slider = widgets.FloatSlider(min=0, max=np.max(distances), step=0.01, value=0)
widgets.interactive(update_plot, eps=eps_slider)
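The cell above computes each head's context vectors as $O = SV$. Here $S$ is the head's row-stochastic attention matrix and $V$ holds that head's value vectors. The value projection's output features are laid out head by head. So reshaping the full projection to `(num_heads, head_dim)` and taking one head is the same as using that head's block of rows of the weight matrix. The NumPy sketch below checks this on random data. The shapes and random inputs are illustrative assumptions, not ESM2's weights. For simplicity it omits any normalization or bias applied around the projection.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden_size, num_heads = 6, 8, 2
head_dim = hidden_size // num_heads

h = rng.normal(size=(seq_len, hidden_size))        # token representations
W_v = rng.normal(size=(hidden_size, hidden_size))  # value weight, (out, in) like torch.nn.Linear
scores = rng.normal(size=(seq_len, seq_len))       # stand-in attention logits for one head

# Row-wise softmax gives the attention matrix S
S = np.exp(scores - scores.max(axis=1, keepdims=True))
S /= S.sum(axis=1, keepdims=True)

head = 1
# Option A: project with all heads, then reshape and keep one head
V_all = h @ W_v.T                                   # (seq_len, hidden_size)
V_a = V_all.reshape(seq_len, num_heads, head_dim)[:, head, :]
# Option B: use only this head's rows of the weight matrix
V_b = h @ W_v[head * head_dim:(head + 1) * head_dim].T

O = S @ V_a                                         # context vectors, (seq_len, head_dim)
print(np.allclose(V_a, V_b), O.shape)
```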


## Using UMAP

In [None]:
import torch
import gudhi as gd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from scipy.spatial import distance_matrix
from mpl_toolkits.mplot3d import Axes3D
from transformers import AutoTokenizer, EsmModel
from sklearn.decomposition import PCA
from umap.umap_ import UMAP


tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Load the pre-trained model once; eval mode disables dropout
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

def compute_output(sequence, layer, head):
    """Return the context vectors O = S V of one attention head, plus the tokenized inputs."""
    # Tokenize input (the tokenizer adds <cls> and <eos> special tokens)
    inputs = tokenizer(sequence, return_tensors="pt")

    with torch.no_grad():
        # Forward pass, requesting attention weights and hidden states
        outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

        # Attention weights S of the given layer and head, shape (seq_len, seq_len)
        S = outputs.attentions[layer][0, head]

        # hidden_states[layer] is the input to encoder layer `layer`
        # (hidden_states[0] is the embedding output)
        hidden_states = outputs.hidden_states[layer]

        # ESM2 is pre-LayerNorm: the value projection (which has a bias)
        # is applied to the layer-normalized input
        attention = model.encoder.layer[layer].attention
        V_all = attention.self.value(attention.LayerNorm(hidden_states))

        # Split the value vectors into heads and keep the requested head,
        # shape (1, seq_len, head_dim)
        num_heads = model.config.num_attention_heads
        head_dim = model.config.hidden_size // num_heads
        V = V_all.view(*V_all.shape[:2], num_heads, head_dim)[:, :, head, :]

        # Context vectors O = S V, shape (1, seq_len, head_dim)
        O = torch.matmul(S, V)

    return O, inputs

# Compute the output
output, inputs = compute_output("MGWGRKRR", 0, 0)

# Convert the output tensor to numpy array
output_np = output.detach().numpy()[0]

# Compute the pairwise Euclidean distance matrix
distances = distance_matrix(output_np, output_np)

# Compute the number of tokens
num_tokens = len(inputs["input_ids"][0])

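# Reduce the context vectors to 3D for visualization using UMAP.
# UMAP's default n_neighbors=15 exceeds the number of points for short sequences, so cap it.
umap = UMAP(n_components=3, n_neighbors=min(15, num_tokens - 1))
output_3d = umap.fit_transform(output_np)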

# Define the function to update the plot
def update_plot(eps):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(output_3d[:, 0], output_3d[:, 1], output_3d[:, 2])
    
    # Add labels
    for i, token_id in enumerate(inputs['input_ids'][0]):
        ax.text(output_3d[i, 0], output_3d[i, 1], output_3d[i, 2], tokenizer.decode([token_id]))
    
    # Add edges of the 1-skeleton: connect tokens whose distance in the original space is at most eps
    for i in range(distances.shape[0]):
        for j in range(i+1, distances.shape[1]):
            if distances[i, j] <= eps:
                ax.plot([output_3d[i, 0], output_3d[j, 0]], 
                        [output_3d[i, 1], output_3d[j, 1]], 
                        [output_3d[i, 2], output_3d[j, 2]], 'b-')

    plt.show()

# Create the interactive widget for the threshold distance parameter
eps_slider = widgets.FloatSlider(min=0, max=np.max(distances), step=0.01, value=0)
widgets.interactive(update_plot, eps=eps_slider)
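Edges come from distances in the original space, while points are placed by UMAP or t-SNE. So it can help to check how well the 3D layout preserves neighborhoods before reading clusters off the plot. The sketch below assumes `X_high` is the original array (e.g. `output_np`) and `X_low` its 3D embedding (e.g. `output_3d`). For each point, it computes the fraction of its `k` nearest neighbors in the original space that are still among its `k` nearest neighbors in the embedding. `knn_overlap` is a hypothetical helper. scikit-learn's `sklearn.manifold.trustworthiness` is an established alternative.

```python
import numpy as np

def knn_overlap(X_high, X_low, k=5):
    """Mean fraction of each point's k nearest neighbors shared by both spaces (k < n)."""
    def knn(X):
        D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
        np.fill_diagonal(D, np.inf)             # exclude each point itself
        return np.argsort(D, axis=1)[:, :k]     # indices of the k nearest points
    nn_high, nn_low = knn(X_high), knn(X_low)
    shared = [len(set(a) & set(b)) / k for a, b in zip(nn_high, nn_low)]
    return float(np.mean(shared))

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 16))
print(knn_overlap(X, X, k=5))         # identical spaces give 1.0
print(knn_overlap(X, X[:, :3], k=5))  # keeping only 3 coordinates as a crude projection
```

A score well below 1 suggests that clusters seen in the 3D plot should be confirmed against the edges rather than trusted visually.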


## Interactive Plot with UMAP

In [25]:
import torch
import numpy as np
from scipy.spatial import distance_matrix
from transformers import AutoTokenizer, EsmModel
from umap.umap_ import UMAP
import plotly.graph_objects as go
from ipywidgets import interact, FloatSlider

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Load the pre-trained model once; eval mode disables dropout
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

def compute_output(sequence, layer, head):
    """Return the context vectors O = S V of one attention head, plus the tokenized inputs."""
    # Tokenize input (the tokenizer adds <cls> and <eos> special tokens)
    inputs = tokenizer(sequence, return_tensors="pt")

    with torch.no_grad():
        # Forward pass, requesting attention weights and hidden states
        outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

        # Attention weights S of the given layer and head, shape (seq_len, seq_len)
        S = outputs.attentions[layer][0, head]

        # hidden_states[layer] is the input to encoder layer `layer`
        # (hidden_states[0] is the embedding output)
        hidden_states = outputs.hidden_states[layer]

        # ESM2 is pre-LayerNorm: the value projection (which has a bias)
        # is applied to the layer-normalized input
        attention = model.encoder.layer[layer].attention
        V_all = attention.self.value(attention.LayerNorm(hidden_states))

        # Split the value vectors into heads and keep the requested head,
        # shape (1, seq_len, head_dim)
        num_heads = model.config.num_attention_heads
        head_dim = model.config.hidden_size // num_heads
        V = V_all.view(*V_all.shape[:2], num_heads, head_dim)[:, :, head, :]

        # Context vectors O = S V, shape (1, seq_len, head_dim)
        O = torch.matmul(S, V)

    return O, inputs

# Compute the output
output, inputs = compute_output("MAVESRVTQEEIKKEPEKPIDREKTCPLLLRVFTTNNGRHHRMDEFSRGNVPSSELQIYTWMDATLKELTSLVKEVYPEARKKGTHFNFAIVFTDVKRPGYRVKEIGSTMSGRKGTDDSMTLQSQKFQIGDYLDIAITPPNRAPPPSGRMRPY", 3, 2)

# Convert the output tensor to numpy array
output_np = output.detach().numpy()[0]

# Compute the pairwise Euclidean distance matrix
distances = distance_matrix(output_np, output_np)

# Compute the number of tokens
num_tokens = len(inputs["input_ids"][0])

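# Reduce the context vectors to 3D for visualization using UMAP.
# UMAP's default n_neighbors=15 exceeds the number of points for short sequences, so cap it.
umap = UMAP(n_components=3, n_neighbors=min(15, num_tokens - 1))
output_3d = umap.fit_transform(output_np)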

# Define the function to update the plot
def update_plot(eps):
    # Create a plotly graph object
    fig = go.Figure()

    # Get token labels
    token_labels = [tokenizer.decode([token_id]) for token_id in inputs['input_ids'][0]]
    
    # Add 3D scatter plot points
    fig.add_trace(go.Scatter3d(x=output_3d[:, 0], y=output_3d[:, 1], z=output_3d[:, 2],
                    mode='markers+text',        # 'text' displays the token labels next to the markers
                    text=token_labels,          # labels to display
                    textposition='top center',  # place each label above its marker
                    marker=dict(
                        size=6,
                        color='rgb(100, 150, 200)',  # a single fixed color, so no colorscale is needed
                        opacity=0.8
                    ),
                    hovertemplate = "%{text}<br>x: %{x}<br>y: %{y}<br>z: %{z}<extra></extra>",
    ))

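    # Add edges of the 1-skeleton: connect tokens whose distance in the original
    # (high-dimensional) space is at most eps; endpoints are drawn at their UMAP
    # coordinates. All edges go into one trace, separated by None, which is much
    # faster than one trace per edge and keeps the legend empty.
    edge_x, edge_y, edge_z = [], [], []
    for i in range(distances.shape[0]):
        for j in range(i+1, distances.shape[1]):
            if distances[i, j] <= eps:
                edge_x += [output_3d[i, 0], output_3d[j, 0], None]
                edge_y += [output_3d[i, 1], output_3d[j, 1], None]
                edge_z += [output_3d[i, 2], output_3d[j, 2], None]
    fig.add_trace(go.Scatter3d(x=edge_x, y=edge_y, z=edge_z,
                               mode='lines',
                               line=dict(color='darkblue', width=2),
                               hoverinfo='skip',
                               showlegend=False))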

    fig.update_layout(scene = dict(
                    xaxis_title='X AXIS',
                    yaxis_title='Y AXIS',
                    zaxis_title='Z AXIS'),
                    width=700,
                    margin=dict(r=20, l=10, b=10, t=10))
    fig.show()

# Create the interactive widget for the threshold distance parameter
eps_slider = FloatSlider(min=0, max=np.max(distances), step=0.01, value=0, description='eps:')
interact(update_plot, eps=eps_slider)


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmModel: ['lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


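For the 1-skeleton shown by the slider, the 0-dimensional persistent homology is determined by the edges that merge connected components. Every token is born at 0, and components die at the edge lengths of a minimum spanning tree. The largest death time is the smallest `eps` at which the whole graph is connected, which is a practical upper end for the slider. A minimal union-find (Kruskal) sketch follows, assuming a symmetric distance matrix such as `distances`. `h0_death_times` is a hypothetical helper. GUDHI's `RipsComplex` computes the same diagram along with higher-dimensional ones.

```python
import numpy as np

def h0_death_times(distances):
    """Finite death times of 0-dimensional classes in the Vietoris-Rips filtration.

    Kruskal's algorithm with union-find: process edges by increasing length;
    each edge that joins two different components kills one class at that length.
    """
    n = distances.shape[0]
    parent = list(range(n))

    def find(x):
        # Follow parent pointers to the root, halving the path on the way
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    i_idx, j_idx = np.triu_indices(n, k=1)
    order = np.argsort(distances[i_idx, j_idx], kind="stable")
    deaths = []
    for e in order:
        ri, rj = find(i_idx[e]), find(j_idx[e])
        if ri != rj:
            parent[ri] = rj
            deaths.append(float(distances[i_idx[e], j_idx[e]]))
            if len(deaths) == n - 1:
                break
    return deaths  # n - 1 finite deaths; one class never dies

points = np.array([[0.0], [1.0], [3.0], [10.0]])
D = np.abs(points - points.T)
deaths = h0_death_times(D)
print(deaths, "graph connected from eps =", max(deaths))
```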

## With Hidden States Instead of Context Vectors

In [2]:
import torch
import numpy as np
from scipy.spatial import distance_matrix
from transformers import AutoTokenizer, EsmModel
from umap.umap_ import UMAP
import plotly.graph_objects as go
from ipywidgets import interact, FloatSlider

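tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
# Load the pre-trained model once; eval mode disables dropout
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
model.eval()

def compute_hidden_states(sequence, layer):
    """Return the hidden states of the given layer as a NumPy array, plus the tokenized inputs."""
    # Tokenize input (the tokenizer adds <cls> and <eos> special tokens)
    inputs = tokenizer(sequence, return_tensors="pt")

    # Forward pass without gradient tracking
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output and hidden_states[k] is the output
    # of the k-th encoder layer, i.e. the input to model.encoder.layer[k];
    # shape (1, seq_len, hidden_size)
    hidden_states = outputs.hidden_states[layer].numpy()

    return hidden_states, inputs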


# Compute the output
output, inputs = compute_hidden_states("MAVESRVTQEEIKKEPEKPIDREKTCPLLLRVFTTNNGRHHRMDEFSRGNVPSSELQIYTWMDATLKELTSLVKEVYPEARKKGTHFNFAIVFTDVKRPGYRVKEIGSTMSGRKGTDDSMTLQSQKFQIGDYLDIAITPPNRAPPPSGRMRPY", 3)

# Drop the batch dimension: shape (seq_len, hidden_size)
output_np = output[0]

# Compute the pairwise Euclidean distance matrix
distances = distance_matrix(output_np, output_np)

# Compute the number of tokens
num_tokens = len(inputs["input_ids"][0])

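# Reduce the hidden states to 3D for visualization using UMAP.
# UMAP's default n_neighbors=15 exceeds the number of points for short sequences, so cap it.
umap = UMAP(n_components=3, n_neighbors=min(15, num_tokens - 1))
output_3d = umap.fit_transform(output_np)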

# Define the function to update the plot
def update_plot(eps):
    # Create a plotly graph object
    fig = go.Figure()

    # Get token labels
    token_labels = [tokenizer.decode([token_id]) for token_id in inputs['input_ids'][0]]
    
    # Add 3D scatter plot points
    fig.add_trace(go.Scatter3d(x=output_3d[:, 0], y=output_3d[:, 1], z=output_3d[:, 2],
                    mode='markers+text',        # 'text' displays the token labels next to the markers
                    text=token_labels,          # labels to display
                    textposition='top center',  # place each label above its marker
                    marker=dict(
                        size=6,
                        color='rgb(100, 150, 200)',  # a single fixed color, so no colorscale is needed
                        opacity=0.8
                    ),
                    hovertemplate = "%{text}<br>x: %{x}<br>y: %{y}<br>z: %{z}<extra></extra>",
    ))

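    # Add edges of the 1-skeleton: connect tokens whose distance in the original
    # (high-dimensional) space is at most eps; endpoints are drawn at their UMAP
    # coordinates. All edges go into one trace, separated by None, which is much
    # faster than one trace per edge and keeps the legend empty.
    edge_x, edge_y, edge_z = [], [], []
    for i in range(distances.shape[0]):
        for j in range(i+1, distances.shape[1]):
            if distances[i, j] <= eps:
                edge_x += [output_3d[i, 0], output_3d[j, 0], None]
                edge_y += [output_3d[i, 1], output_3d[j, 1], None]
                edge_z += [output_3d[i, 2], output_3d[j, 2], None]
    fig.add_trace(go.Scatter3d(x=edge_x, y=edge_y, z=edge_z,
                               mode='lines',
                               line=dict(color='darkblue', width=2),
                               hoverinfo='skip',
                               showlegend=False))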

    fig.update_layout(scene = dict(
                    xaxis_title='X AXIS',
                    yaxis_title='Y AXIS',
                    zaxis_title='Z AXIS'),
                    width=700,
                    margin=dict(r=20, l=10, b=10, t=10))
    fig.show()

# Create the interactive widget for the threshold distance parameter
eps_slider = FloatSlider(min=0, max=np.max(distances), step=0.01, value=0, description='eps:')
interact(update_plot, eps=eps_slider)


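Hidden states (`hidden_size` dimensions) and single-head context vectors (`head_dim` dimensions) generally live on different distance scales. So the same `eps` value does not mean the same thing in the two kinds of plots. One simple option is to rescale each distance matrix by its largest entry, so that thresholds become fractions in $[0, 1]$. Another is to use cosine distance. The sketch below shows both under that assumption. `normalized_distances` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def normalized_distances(X, metric="euclidean"):
    """Pairwise distances between rows of X, rescaled so the largest one equals 1."""
    if metric == "cosine":
        Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
        D = np.clip(1.0 - Xn @ Xn.T, 0.0, None)  # clip tiny negative round-off
        np.fill_diagonal(D, 0.0)
    else:
        D = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)
    return D / D.max()

rng = np.random.default_rng(0)
X = rng.normal(size=(10, 16))  # random stand-in for per-token vectors
D_small, D_large = normalized_distances(X), normalized_distances(20 * X)
print(np.allclose(D_small, D_large))  # rescaling X leaves relative thresholds unchanged
# Number of edges in the 1-skeleton at a relative threshold of 0.5
print(np.count_nonzero(np.triu(D_small <= 0.5, k=1)))
```

In either cell, `normalized_distances(output_np)` could stand in for `distance_matrix(output_np, output_np)`, with the slider then running from 0 to 1.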