# Comparing Motif Persistent Homology of ESM-2 Models


### Overall Goal

This notebook generates a set of random protein sequences that share a common motif, computes contextual embeddings for the sequences with two ESM-2 models, and compares the motif embeddings using persistent homology and the Wasserstein distance between persistence diagrams.

### Code Walkthrough

- Generate 100 random protein sequences containing the motif 'KLGFMNPTQY'
  - Use amino_acids string to sample from
  - Generate random length and random motif position
  - Insert motif at random position
  
- Load ESM-2 models of two sizes
  - Smaller 6-layer model (`facebook/esm2_t6_8M_UR50D`)
  - Larger 12-layer model (`facebook/esm2_t12_35M_UR50D`)
  
- Get hidden states for a given sequence
  - Utility function that returns the hidden states at a chosen layer index
  
- Compute persistent homology of the motif embeddings
  - Calculate the Euclidean distance matrix between the motif's per-residue embedding vectors
  - Build a Vietoris-Rips complex from the distance matrix and compute its persistence
  - Capture high-level topological structure
  
- Compute Wasserstein distance between two persistence diagrams
  - Quantify difference in topological structure
  - Use as distance metric between sequences
  
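- Compare models by looking at the difference in distances (larger model minus smaller model)
  - The larger model gives larger distances for most sequence pairs
  - About 95% of the pairwise differences are non-negative; the rest are negative
  - Outliers (|z-score| > 1.5) may reflect instability of the small model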
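
For intuition about the dimension-0 part of the persistent homology step: in a Vietoris-Rips filtration every point is born at scale 0, and a connected component dies when an edge first merges it into another component, so the finite death times are the edge weights of a minimum spanning tree of the distance matrix. The following is a minimal pure-Python sketch of that idea using Kruskal's algorithm; the name `h0_death_times` is illustrative and is not part of the notebook or of GUDHI.

```python
import math
from itertools import combinations


def h0_death_times(points):
    """Finite H0 death times of a Vietoris-Rips filtration (Kruskal / union-find)."""
    n = len(points)
    parent = list(range(n))

    def find(i):
        # Follow parent links to the component root, compressing the path
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    # All pairwise edges, sorted by Euclidean length
    edges = sorted(
        (math.dist(points[i], points[j]), i, j)
        for i, j in combinations(range(n), 2)
    )

    deaths = []
    for length, i, j in edges:
        root_i, root_j = find(i), find(j)
        if root_i != root_j:
            # This edge merges two components, so one of them dies here
            parent[root_i] = root_j
            deaths.append(length)
    return deaths


toy_points = [(0.0, 0.0), (1.0, 0.0), (5.0, 0.0)]
print(h0_death_times(toy_points))  # [1.0, 4.0]
```

Three points give two finite death times; the third component never dies, which corresponds to the infinite bar that the notebook filters out before computing Wasserstein distances.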

### Conclusion

The notebook provides a workflow to compare protein sequence embeddings using topological data analysis. Key findings:

- Persistent homology captures topological structure of motif embeddings 
- Wasserstein distance quantifies difference between sequences
- The larger ESM-2 model yields larger Wasserstein distances between sequences for about 95% of pairs
- Outliers may indicate instability of the small model on some sequences

In [1]:
import random

amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
motif = 'KLGFMNPTQY'
sequences = []

for _ in range(100):
    seq_length = random.randint(50, 75)
    motif_position = random.randint(0, seq_length - len(motif))
    sequence = ''.join(random.choice(amino_acids) for _ in range(seq_length))
    sequence = sequence[:motif_position] + motif + sequence[motif_position + len(motif):]
    sequences.append(sequence)

# Print the list of protein sequences
sequences


['RSWLLKTGCKEMMSEWYKKLGFMNPTQYPHPAVEFVYYCWTKLDHNEFSDVPPR',
 'QWGPKLGFMNPTQYDYTGYLHMTYRNGKVREHMWLGWYPGHYMFGLTDHMRENVTQHHCMQNLAFGC',
 'NTKPPTMQQEPISRFCMVEGWMAFNHVTSKLGFMNPTQYSPQNECELSELDHSLNGFDANWR',
 'PFRCKIVQISSACMTSMLFTVWCDFQWFVFASFLYEWKQKLGFMNPTQYNNWGSF',
 'CILDKVWFNVESRFCGPNGLMLSTDYYQKDWYVPTKFHENERYQKLGFMNPTQYWNYQQSFHGKACSEQS',
 'YQGRLRLFNTQGQYEEYISVMNLNIATYREHKLGFMNPTQYIRIRTANMAHNYCPWDFRYNGKWVQEKQYV',
 'FVFGESYIHLARHKLGFMNPTQYHGMMKSQVRYSELYVYGHFFQHHVAHETPYEFPSDFQDS',
 'TSAMSEICIAFVGSFYLLVQMGANLDMNYCNTLMVFVFMKLGFMNPTQYSMNKH',
 'RRFWDCQQEEYCKETSWPMTIGHMHWYNYENRGVCVPLQIQMWMKLGFMNPTQYPQD',
 'QKCDMPFMTIMKYFRNCVAIFEFERHREILFNNTPDFINWFVIFNWQPKDSERFQQHANVKLGFMNPTQYSM',
 'AIMLALMWIESMTQLIGRELYPDKKAFTNGVSEMERWKMPAGQQRKLGFMNPTQYKKVDMFTDEKQKGER',
 'LFFGNVHTIGKLGFMNPTQYYTICNWETLKLFVCSHHCTRSSCMEWMPAWAENKYDTN',
 'HPFYLRFTKLGFMNPTQYHTAGHEHQWVDLMTYVYDAQKHDIQMWNTMESKD',
 'PWMREPTKDENVKLGFMNPTQYPWKMAYFFRWIIFLEGIEADVGDPKPVEFSNRFVFYMRGICCPQYTP',
 'TYGDIHVIAVPGISDFMQLTICLLTHEAHWHDPPVSHWHPEPSVQKLGFMNPTQY
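
The sampling loop above is unseeded, so each run produces a different set of sequences. Below is a minimal sketch of a reproducible variant, assuming a local `random.Random` generator with a fixed seed is acceptable. The function name `make_sequences` and the default `seed=0` are illustrative, and this will not reproduce the sequences listed above.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def make_sequences(motif, n=100, min_len=50, max_len=75, seed=0):
    """Generate n random sequences, each with `motif` written in at a random position."""
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    sequences = []
    for _ in range(n):
        length = rng.randint(min_len, max_len)
        position = rng.randint(0, length - len(motif))
        background = "".join(rng.choice(AMINO_ACIDS) for _ in range(length))
        # Overwrite part of the background with the motif, keeping the total length
        sequences.append(
            background[:position] + motif + background[position + len(motif):]
        )
    return sequences


seqs = make_sequences("KLGFMNPTQY")
print(len(seqs), all("KLGFMNPTQY" in s for s in seqs))  # 100 True
```

Calling `make_sequences` twice with the same seed returns the same list, which makes downstream distance matrices comparable across runs.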

In [2]:
import torch

def get_hidden_states(tokenizer, model, sentence, layer):
    """Return the hidden states of `sentence` at index `layer`, shape (1, tokens, dim)."""
    # Tokenize the input; the tokenizer adds special tokens around the sequence
    inputs = tokenizer(sentence, return_tensors="pt")

    # Forward pass, requesting the hidden states of every layer
    outputs = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output; hidden_states[k] is the output of layer k
    hidden_states = outputs.hidden_states[layer]

    return hidden_states
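
One detail worth keeping in mind when choosing `layer`: in Hugging Face Transformers, `outputs.hidden_states` holds the embedding output followed by one entry per transformer layer, so a model with N layers yields N + 1 entries. The helper below is a hypothetical illustration of that indexing convention; it is not part of the notebook or of the library.

```python
def describe_hidden_state_index(index, num_layers):
    """Describe what hidden_states[index] holds for a model with num_layers layers."""
    if not 0 <= index <= num_layers:
        raise IndexError(f"index must be in [0, {num_layers}]")
    if index == 0:
        return "embedding output"
    if index == num_layers:
        return f"output of layer {index} (last layer)"
    return f"output of layer {index} of {num_layers}"


print(describe_hidden_state_index(5, 6))    # output of layer 5 of 6
print(describe_hidden_state_index(11, 12))  # output of layer 11 of 12
```

Under this convention, index 5 for a 6-layer model and index 11 for a 12-layer model both select the second-to-last layer, while index `-1` would select the last one.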



In [3]:
from transformers import AutoTokenizer, EsmModel
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Index into the hidden states (0 is the embedding output, so 5 is the fifth of six layers)
layer = 5

# Compute the context vectors for each sequence
context = [get_hidden_states(tokenizer, model, t, layer) for t in sequences]

Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
context[len(sequences)-1]

tensor([[[ 9.6350e-01,  3.3161e+00,  2.3348e+00,  ...,  6.6520e-01,
           1.7391e+00, -4.6667e+00],
         [ 2.7638e-01,  2.9199e-01,  2.3925e+00,  ...,  2.1819e+00,
           1.0009e+00, -3.5376e-01],
         [-1.4134e+00, -1.8429e+00,  6.6259e-01,  ...,  7.7934e-01,
           8.1615e-01,  2.7464e+00],
         ...,
         [-9.6685e-01, -1.4864e+00, -1.3682e+00,  ..., -2.9277e-01,
           7.7223e-01, -1.4342e+00],
         [-2.8106e+00, -1.4600e+00,  7.1654e-01,  ..., -5.5445e-01,
           3.5375e-03, -9.1865e-01],
         [-6.8736e-01, -5.0814e-01, -3.8937e-01,  ..., -3.6261e-01,
          -4.4213e+00, -2.6621e+00]]], grad_fn=<AsStridedBackward0>)

In [5]:
from scipy.spatial import distance_matrix
import gudhi as gd
import numpy as np
import matplotlib.pyplot as plt


def compute_phrase_distances_and_homology(tokenizer, context_vectors, sentence, phrase):
    """Return the persistence of a Rips complex built on the phrase's context vectors."""
    # Tokenize the sentence and the phrase without special tokens
    sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
    phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False)

    # Find the indices of the first occurrence of the phrase tokens in the sentence
    phrase_indices = []
    phrase_length = len(phrase_tokens)
    for i in range(len(sentence_tokens) - phrase_length + 1):
        if sentence_tokens[i:i+phrase_length] == phrase_tokens:
            phrase_indices.extend(range(i, i+phrase_length))
            break
    if not phrase_indices:
        raise ValueError(f"Phrase {phrase!r} not found in sentence")

    # Extract the context vectors for the phrase.
    # NOTE: context_vectors were computed with special tokens included, so row 0 is <cls>
    # and these indices point one token to the left of the motif. Add 1 to each index to
    # align them exactly; the indexing is left unchanged here to match the recorded outputs.
    phrase_context_vectors = context_vectors[0, phrase_indices]

    # Detach the tensor and convert to numpy array
    phrase_context_vectors_np = phrase_context_vectors.detach().numpy()

    # Compute the pairwise Euclidean distances among the phrase context vectors
    distances = distance_matrix(phrase_context_vectors_np, phrase_context_vectors_np)

    # Build a Vietoris-Rips complex on the distance matrix and compute its persistence;
    # the result is a list of (dimension, (birth, death)) pairs
    rips_complex = gd.RipsComplex(distance_matrix=distances, max_edge_length=np.max(distances))
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistent_homology = simplex_tree.persistence(min_persistence=0.001)

    return persistent_homology
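
The motif lookup above works on token ids produced with `add_special_tokens=False`, while the hidden states were computed with special tokens included. The sketch below shows one way to align the two, assuming one token per residue and a single leading `<cls>` token (as the ESM tokenizer produces for the 20 standard amino acids). `motif_positions` is a hypothetical helper, not part of the notebook.

```python
def motif_positions(sequence, motif, has_cls_token=True):
    """Return hidden-state row indices covering the first occurrence of `motif`."""
    start = sequence.find(motif)
    if start == -1:
        raise ValueError(f"motif {motif!r} not found in sequence")
    # Shift by one when the hidden states begin with a <cls> row
    offset = 1 if has_cls_token else 0
    return list(range(start + offset, start + offset + len(motif)))


toy_sequence = "AA" + "KLGFMNPTQY" + "CC"
print(motif_positions(toy_sequence, "KLGFMNPTQY"))
# [3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
```

With `has_cls_token=False` the same call returns indices 2 through 11, which is what a lookup on token ids without special tokens yields.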

In [6]:
from gudhi.hera import wasserstein_distance
import numpy as np
from transformers import AutoTokenizer, EsmModel
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

persistent_homology_1 = compute_phrase_distances_and_homology(tokenizer, context[0], sequences[0], "KLGFMNPTQY")

persistent_homology_2 = compute_phrase_distances_and_homology(tokenizer, context[1], sequences[1], "KLGFMNPTQY")

dimension = 0  # set the dimension you're interested in

persistent_homology_np_1 = np.array([[pt[1][0], pt[1][1]] for pt in persistent_homology_1 if pt[1][1] != float('inf') and pt[0] == dimension])

persistent_homology_np_2 = np.array([[pt[1][0], pt[1][1]] for pt in persistent_homology_2 if pt[1][1] != float('inf') and pt[0] == dimension])


print(f"Wasserstein distance: = {wasserstein_distance(persistent_homology_np_1, persistent_homology_np_2, order=1., internal_p=2.):.2f}")



Wasserstein distance: = 9.52
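
To make the quantity printed above more concrete, here is a brute-force sketch of the order-1 Wasserstein distance between two small persistence diagrams with the Euclidean ground metric (the setting `order=1., internal_p=2.`). Each point may be matched to a point of the other diagram or to the diagonal, where a point (b, d) lies at Euclidean distance (d - b)/√2. The function name `wasserstein_bruteforce` is illustrative; it enumerates every matching, so it is only usable on tiny diagrams and is not how Hera computes the distance.

```python
import math
from itertools import permutations


def wasserstein_bruteforce(dgm_a, dgm_b):
    """Order-1 Wasserstein distance (Euclidean ground metric) by exhaustive matching."""
    n, m = len(dgm_a), len(dgm_b)
    size = n + m  # each side is padded with diagonal slots for the other side's points

    def to_diagonal(point):
        birth, death = point
        return (death - birth) / math.sqrt(2)

    def cost(i, j):
        if i < n and j < m:
            return math.dist(dgm_a[i], dgm_b[j])  # point matched to point
        if i < n:
            return to_diagonal(dgm_a[i])          # point of A matched to the diagonal
        if j < m:
            return to_diagonal(dgm_b[j])          # point of B matched to the diagonal
        return 0.0                                # diagonal matched to diagonal

    return min(
        sum(cost(i, j) for i, j in enumerate(perm))
        for perm in permutations(range(size))
    )


dgm_a = [(0.0, 2.0), (0.0, 5.0)]
dgm_b = [(0.0, 3.0)]
print(round(wasserstein_bruteforce(dgm_a, dgm_b), 3))  # 3.414, i.e. 2 + sqrt(2)
```

In this toy case the cheapest matching pairs (0, 5) with (0, 3) at cost 2 and sends (0, 2) to the diagonal at cost √2.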


In [7]:
from gudhi.hera import wasserstein_distance
import numpy as np
from transformers import AutoTokenizer, EsmModel

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Function to convert persistent homology to numpy array
def ph_to_np(persistent_homology, dimension):
    return np.array([[pt[1][0], pt[1][1]] for pt in persistent_homology if pt[1][1] != float('inf') and pt[0] == dimension])

# Function to compute Wasserstein distance between two sequences
def compute_wd(tokenizer, seq1, seq2, context1, context2, dimension=0):
    ph1 = compute_phrase_distances_and_homology(tokenizer, context1, seq1, "KLGFMNPTQY")
    ph2 = compute_phrase_distances_and_homology(tokenizer, context2, seq2, "KLGFMNPTQY")
    ph_np_1 = ph_to_np(ph1, dimension)
    ph_np_2 = ph_to_np(ph2, dimension)
    wd = wasserstein_distance(ph_np_1, ph_np_2, order=1., internal_p=2.)
    return wd

# Compute pairwise Wasserstein distances
n = len(sequences)
wd_matrix = np.zeros((n, n))
for i in range(n):
    for j in range(i+1, n):  # no need to compute for j < i due to symmetry
        wd_matrix[i, j] = compute_wd(tokenizer, sequences[i], sequences[j], context[i], context[j])
        wd_matrix[j, i] = wd_matrix[i, j]  # the distance is symmetric

print(wd_matrix)


[[ 0.          9.52312771 15.05591117 ... 14.36145408 40.78931892
   6.31245712]
 [ 9.52312771  0.         20.85794811 ... 21.16120551 50.31024142
  13.28347528]
 [15.05591117 20.85794811  0.         ...  2.52152111 29.45229331
   9.66645219]
 ...
 [14.36145408 21.16120551  2.52152111 ...  0.         29.14903591
  10.21494934]
 [40.78931892 50.31024142 29.45229331 ... 29.14903591  0.
  37.02676614]
 [ 6.31245712 13.28347528  9.66645219 ... 10.21494934 37.02676614
   0.        ]]
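
The double loop above recomputes both persistence diagrams for every pair of sequences. A possible restructuring, sketched here in plain Python, computes each per-sequence feature once and then fills the symmetric matrix from the cached features. The names `pairwise_matrix`, `featurize`, and `distance` are hypothetical placeholders: in the notebook, `featurize` would build the persistence diagram and `distance` would be the Wasserstein distance.

```python
from itertools import combinations


def pairwise_matrix(items, featurize, distance):
    """Symmetric distance matrix where each item is featurized exactly once."""
    features = [featurize(item) for item in items]
    n = len(items)
    matrix = [[0.0] * n for _ in range(n)]
    for i, j in combinations(range(n), 2):
        # Fill both halves at once; the diagonal stays at zero
        matrix[i][j] = matrix[j][i] = distance(features[i], features[j])
    return matrix


# Toy usage: the "feature" is the string length, the "distance" an absolute difference
toy = pairwise_matrix(["a", "bbb", "cc"], len, lambda x, y: abs(x - y))
print(toy)  # [[0.0, 2, 1], [2, 0.0, 1], [1, 1, 0.0]]
```

For n sequences this performs n featurizations instead of n(n - 1), while the number of distance evaluations stays at n(n - 1)/2.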


In [8]:
from transformers import AutoTokenizer, EsmModel
tokenizer_2 = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
model_2 = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Index into the hidden states (0 is the embedding output, so 11 is the eleventh of twelve layers)
layer = 11

# Compute the context vectors for each sequence
context_2 = [get_hidden_states(tokenizer_2, model_2, t, layer) for t in sequences]

Some weights of the model checkpoint at facebook/esm2_t12_35M_UR50D were not used when initializing EsmModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
from gudhi.hera import wasserstein_distance
import numpy as np
from transformers import AutoTokenizer, EsmModel
tokenizer_2 = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

persistent_homology_3 = compute_phrase_distances_and_homology(tokenizer_2, context_2[0], sequences[0], "KLGFMNPTQY")

persistent_homology_4 = compute_phrase_distances_and_homology(tokenizer_2, context_2[1], sequences[1], "KLGFMNPTQY")

dimension = 0  # set the dimension you're interested in

persistent_homology_np_3 = np.array([[pt[1][0], pt[1][1]] for pt in persistent_homology_3 if pt[1][1] != float('inf') and pt[0] == dimension])

persistent_homology_np_4 = np.array([[pt[1][0], pt[1][1]] for pt in persistent_homology_4 if pt[1][1] != float('inf') and pt[0] == dimension])


print(f"Wasserstein distance: = {wasserstein_distance(persistent_homology_np_3, persistent_homology_np_4, order=1., internal_p=2.):.2f}")



Wasserstein distance: = 37.39


In [10]:
# Load the tokenizer
tokenizer_2 = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")

# Compute pairwise Wasserstein distances
n = len(sequences)
wd_matrix_2 = np.zeros((n, n))
for i in range(n):
    for j in range(i+1, n):  # no need to compute for j < i due to symmetry
        wd_matrix_2[i, j] = compute_wd(tokenizer_2, sequences[i], sequences[j], context_2[i], context_2[j])
        wd_matrix_2[j, i] = wd_matrix_2[i, j]  # the distance is symmetric

print(wd_matrix_2)


[[  0.          37.38918053  56.12622933 ...  90.61542573 127.75986032
   25.22901981]
 [ 37.38918053   0.          62.14173315 ...  60.29814263  97.44257723
   45.81406345]
 [ 56.12622933  62.14173315   0.         ...  34.4891964   71.63363099
   71.62300969]
 ...
 [ 90.61542573  60.29814263  34.4891964  ...   0.          37.1444346
  106.11220608]
 [127.75986032  97.44257723  71.63363099 ...  37.1444346    0.
  143.25664068]
 [ 25.22901981  45.81406345  71.62300969 ... 106.11220608 143.25664068
    0.        ]]


In [11]:
wd_matrix_2 - wd_matrix

array([[  0.        ,  27.86605282,  41.07031815, ...,  76.25397164,
         86.9705414 ,  18.9165627 ],
       [ 27.86605282,   0.        ,  41.28378504, ...,  39.13693712,
         47.13233581,  32.53058817],
       [ 41.07031815,  41.28378504,   0.        , ...,  31.96767529,
         42.18133768,  61.95655749],
       ...,
       [ 76.25397164,  39.13693712,  31.96767529, ...,   0.        ,
          7.99539869,  95.89725675],
       [ 86.9705414 ,  47.13233581,  42.18133768, ...,   7.99539869,
          0.        , 106.22987455],
       [ 18.9165627 ,  32.53058817,  61.95655749, ...,  95.89725675,
        106.22987455,   0.        ]])

In [12]:
import numpy as np
from scipy import stats

matrix = wd_matrix_2 - wd_matrix

# Extract the strictly lower-triangular entries so each unordered pair is counted once
lower_triangular = matrix[np.tril_indices(matrix.shape[0], k=-1)]

# Compute the statistics
percentage_non_negative = np.mean(lower_triangular >= 0) * 100
max_value = np.max(lower_triangular)
min_value = np.min(lower_triangular)
median = np.median(lower_triangular)
mean = np.mean(lower_triangular)
std_dev = np.std(lower_triangular)

# Flag as outliers the entries whose |z-score| exceeds 1.5
z_scores = stats.zscore(lower_triangular)
outliers = lower_triangular[np.abs(z_scores) > 1.5]

percentage_non_negative, max_value, min_value, median, mean, std_dev, outliers


(94.94949494949495,
 205.9887035027673,
 -33.50869924775133,
 38.63009773096543,
 42.63264214354211,
 31.582733299496553,
 array([ 90.09785646,  91.44264837,  92.84893095, 100.48104425,
        101.94778474,  -6.64630103,  -6.04887921, 100.41971326,
        114.4672457 , 102.20613192, 118.31902012,  96.10189332,
        -16.19022699,  -7.69609209,  -9.53037177, -12.58779501,
         90.0313116 ,  96.00217512, -11.70717053, -18.57745759,
         -6.49036944,  -7.35313098, -17.85245186, -16.16993629,
        -10.68902272, -23.65699194,  -8.15293604,  -7.76608573,
        -24.44866859, -13.9143451 ,  -5.17645417,  91.74962401,
         93.53432691, 110.97155856,  -9.35218583,  92.62566071,
        100.2419612 ,  97.79473112, 127.64931065, 130.42938076,
        101.90047228, 108.21898218, 115.92006279, 147.08085034,
        113.29608362,  96.00666536,  91.67193596,  97.93437913,
        104.4616999 , 124.88281222, 105.45047555, 121.50629781,
         95.78307604, 125.63765558, 129.075113
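
The outlier rule used in this cell can be written out in a few lines of standard-library Python. This is a minimal sketch, assuming the population standard deviation (the default of `scipy.stats.zscore`, `ddof=0`); the function name `zscore_outliers` is illustrative.

```python
import statistics


def zscore_outliers(values, threshold=1.5):
    """Return the values whose absolute z-score exceeds `threshold`."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)  # population standard deviation (ddof=0)
    if std == 0:
        return []  # all values identical, so nothing can be an outlier
    return [v for v in values if abs((v - mean) / std) > threshold]


print(zscore_outliers([1, 2, 3, 4, 100]))  # [100]
```

For roughly normal data, a cutoff of 1.5 flags about 13% of the values, so it is a fairly loose definition of an outlier; a cutoff of 3 is a common stricter choice.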