In [48]:
pip install transformers torch numpy gudhi -q

Note: you may need to restart the kernel to use updated packages.


## Define the Protein Sequences with a Common Motif

In [49]:
sequences = [

"MMHPIMNTADPVYAVYFYVKWRHRYSNRAVEKHAGMTHADDGFRDHGTFSFGPKPLDGGQHACGTCFKLFNGPKACEAFKSYCSNLINQCITQGSTNMHS",

"NNAVNIVNGDIKNVNAGACGTCVPQFHGHLCSDNTCSLWQKLPPQHLDGQNLWGNGDLYPPLKEIFEGWNPSMALWPLEMLKPCVSLVPQTQAPCSGVKF",

"PTGSWYPDYCYTLVEHQKGNFMICDPNYTNFRLEIERELFCYTYYQVFCWSYCNCFKPFGRTKVFIWHAMHLRYMSLAYYLSKVTMGMCEHAACGTCTFL",

"TFVYCACGTCRRGSHEKWWIYTVDIYACILCLNGSSFMFRNTSVWCTYWFLRCWEVFLSTMWQPRCHRTCIPEIPKHYTQYFQHPSCWCNDVHYHDVQSI",

"WSDGSKSSGDTMWHANALTNHFLHMTFIKSTVMRMPWVNHRNCGFFACGTCCVKYPCGGECHWPTLPKVCLSLVMMKARVFKNSLHVPYTDEDIYNYKNY",

"RSSEVHRVQIWNVMDSNFNWTSYTSCRTKVDKPKVPWDNGACGTCIHNYCNCENGDKVKPNNIWANSHTDWIVPHKHAWMYFERTFRLWIHGEHTCTFGI",

"KYYRFGTQPAGHGPVGRGRACALSRQWSRCRFILVIWITYEGDHVRMLEHKFQATSPNMKLWHDNCRGACGTCEVDRAAHTPYIYCHWSGMSAVNMTQRS",

"VHWKMLAYARHTDHKSYMRSQQLMPQFGPFGYCYVMLCGLHKYAMNLVTGLEVLACGTCGPAATMKIKWVLGDPHLRFGAAGPNDHRSATDPNFYKQFHK",

"ENEKSEHNNWWMFVRNHTFVRIECCHQTWLRATMGAFKWTECVIKKFYYDYCCMNQWYDELAACGTCGQMMTSYQDWLLKNQNMEFLGYWWMQQIQSFVM",

"VMADINWSNGQYAQDWGHQDPGIWYNGNTLGSHEQEVSYRCLEDAVCISRWLNMNYGTQSIVDVVEPARWLAMKFAHACGTCGYAHNQQVITNKYQQRVG" 

]



In [50]:
import torch
from transformers import AutoModel 
from transformers import EsmModel, EsmConfig, AutoTokenizer
import transformers
from scipy.spatial import distance_matrix
import gudhi as gd
from gudhi.hera import wasserstein_distance
import numpy as np
import matplotlib.pyplot as plt


def compute_output_model(tokenizer, model, sentence, layer, head):
    # Load pre-trained model

    # Tokenize input and convert to tensor
    inputs = tokenizer(sentence, return_tensors="pt")

    # Forward pass
    # Specify `output_hidden_states=True` when calling the model
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

    # Obtain the attention weights
    attentions = outputs.attentions

    # Obtain the attention weights for the specific layer and head
    S = attentions[layer][0, head]

    # Obtain the value vectors
    model.eval()
    with torch.no_grad():
        hidden_states = outputs.hidden_states[layer]
        all_W_v = model.encoder.layer[layer].attention.self.value.weight
        num_heads = model.config.num_attention_heads
        head_dim = model.config.hidden_size // num_heads
        W_v_heads = all_W_v.view(num_heads, head_dim, model.config.hidden_size)
        W_v = W_v_heads[head]
        V = torch.matmul(hidden_states, W_v.t())

    # Compute the output O
    O = torch.matmul(S, V)

    return O



def compute_phrase_distances_and_homology(tokenizer, context_vectors, sentence, phrase):
    # Initialize the tokenizer

    # Tokenize the sentence and the phrase
    sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
    phrase_tokens = tokenizer.encode(phrase, add_special_tokens=False)

    # Find the indices of the phrase tokens in the sentence
    phrase_indices = []
    phrase_length = len(phrase_tokens)
    for i in range(len(sentence_tokens) - phrase_length + 1):
        if sentence_tokens[i:i+phrase_length] == phrase_tokens:
            phrase_indices.extend(range(i, i+phrase_length))
            break

    # Extract the context vectors for the phrase
    phrase_context_vectors = context_vectors[0, phrase_indices]

    # Detach the tensor and convert to numpy array
    phrase_context_vectors_np = phrase_context_vectors.detach().numpy()

    # Compute the pairwise Euclidean distances among the phrase context vectors
    distances = distance_matrix(phrase_context_vectors_np, phrase_context_vectors_np)

    # Compute the persistent homology of the distance matrix
    rips_complex = gd.RipsComplex(distance_matrix=distances, max_edge_length=np.max(distances))
    simplex_tree = rips_complex.create_simplex_tree(max_dimension=2)
    persistent_homology = simplex_tree.persistence(min_persistence=0.001)

    return persistent_homology



def transform_persistence_diagram(diagram, target_dimension=0):
    # Only include features of the target dimension and return the transformed diagram
    return [(birth, death) for dimension, (birth, death) in diagram if dimension == target_dimension]



def compute_wasserstein_distances(persistence_diagrams, p=2, target_dimension=0):
    n = len(persistence_diagrams)
    distances = np.zeros((n, n))
    for i in range(n):
        for j in range(i+1, n):
            diagram1 = transform_persistence_diagram(persistence_diagrams[i], target_dimension)
            diagram2 = transform_persistence_diagram(persistence_diagrams[j], target_dimension)
            distance = wasserstein_distance(diagram1, diagram2, order=1., internal_p=2.)
            distances[i, j] = distance
            distances[j, i] = distance
    return distances



def count_negative_entries_below_diagonal(matrix):
    count = 0
    total = 0
    n = len(matrix)
    for i in range(n):
        for j in range(i):
            if matrix[i][j] < 0:
                count += 1
            total += 1
    return count, total

### Compute the context vectors for all layers and heads for the model

In [51]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D", output_attentions=True)

# Get the number of layers and heads in the model
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads

context = []
for i in range(len(sequences)):
    sequence_context = []
    for layer in range(num_layers):
        layer_context = []
        for head in range(num_heads):
            layer_context.append(compute_output_model(tokenizer, model, sequences[i], layer, head))
        sequence_context.append(layer_context)
    context.append(sequence_context)


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Print the context vectors for the model for `sequence`, `layer`, and `head`

In [52]:
sequence = len(sequences)-1
layer = 5
head = 5

print(context[sequence][layer][head].shape)
print(context[sequence][layer][head])

torch.Size([1, 102, 16])
tensor([[[ 0.0716, -2.4374,  0.7318,  ..., -0.7640,  3.3909, -0.1210],
         [-0.2228, -2.6890,  0.9176,  ..., -0.9820,  2.6624, -0.5428],
         [-0.1338, -2.7014,  0.9474,  ..., -0.9658,  2.6073, -0.5419],
         ...,
         [ 0.7355,  0.7615,  0.9914,  ..., -2.3322,  1.9662,  1.1216],
         [ 0.7229,  0.9678,  0.9745,  ..., -2.2428,  1.8807,  1.2550],
         [ 0.6539, -0.4697,  0.6328,  ..., -1.8632,  2.6691,  0.4090]]],
       grad_fn=<CloneBackward0>)


In [53]:
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Get the number of layers and heads in the model
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads

persistent_homology = []

for i in range(len(sequences)):
    sentence_persistent_homology = []
    for layer in range(num_layers):
        layer_persistent_homology = []
        for head in range(num_heads):
            layer_persistent_homology.append(compute_phrase_distances_and_homology(tokenizer, context[i][layer][head], sequences[i], "ACGTC"))
        sentence_persistent_homology.append(layer_persistent_homology)
    persistent_homology.append(sentence_persistent_homology)


In [54]:
persistent_homology[sequence][layer][head]

[(0, (0.0, inf)),
 (0, (0.0, 2.187820107942586)),
 (0, (0.0, 1.499387111378892)),
 (0, (0.0, 1.4223606198879313)),
 (0, (0.0, 1.1088247399056559))]

In [55]:
# Get the number of layers and heads in the model
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads

persistence_diagrams = []

for i in range(len(sequences)):
    sentence_persistence_diagrams = []
    for layer in range(num_layers):
        layer_persistence_diagrams = []
        for head in range(num_heads):
            layer_persistence_diagrams.append(persistent_homology[i][layer][head])
        sentence_persistence_diagrams.append(layer_persistence_diagrams)
    persistence_diagrams.append(sentence_persistence_diagrams)

In [56]:
persistence_diagrams[len(sequences)-1][layer][head]

[(0, (0.0, inf)),
 (0, (0.0, 2.187820107942586)),
 (0, (0.0, 1.499387111378892)),
 (0, (0.0, 1.4223606198879313)),
 (0, (0.0, 1.1088247399056559))]

In [57]:
# Get the number of layers and heads in the model
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads

# Initialize w_distances
w_distances = [[None for _ in range(num_heads)] for _ in range(num_layers)]

for layer in range(num_layers):
    for head in range(num_heads):
        w_distances[layer][head] = compute_wasserstein_distances([persistent_homology[sequence][layer][head] for sequence in range(len(sequences))])


### Distance Matrices for `layer` and `head` of the first model

Here we print the Wasserstein distance matrix giving the pairwise Wasserstein distances between each pair of persistence diagrams for the sequence motif `ACGTC` in each of the contexts `sequences[i]`. This gives a summary of how well the attention head preserves the persistent homology of the motif subsequence. We can then compare each head in the model, of compare heads in different models. This will gives a way to analyze which heads preserve persistent homology of certain sequential motifs well. We can then use Low Rank Adaptations (LoRAs) to further modify attention heads to preserve persistent homology better, or use these heads in particular in a knowledge distillation process. 

In [58]:
print(w_distances[layer][head].shape)
print(w_distances[layer][head])

(10, 10)
[[ 0.          1.39238421  1.41594783  1.85822688  1.67857971  1.40880631
   1.61060503  8.42503068  5.70553847  1.22788006]
 [ 1.39238421  0.          2.04597967  2.07716571  3.07096392  2.62689958
   2.1560772   7.39770065  4.31315426  1.40475998]
 [ 1.41594783  2.04597967  0.          2.41268016  1.12412643  0.91650471
   0.31969838  9.34453814  6.25999175  1.57946912]
 [ 1.85822688  2.07716571  2.41268016  0.          3.53680658  3.09274224
   2.2286383   6.93185798  3.84731159  1.09080109]
 [ 1.67857971  3.07096392  1.12412643  3.53680658  0.          0.54468621
   1.30816828 10.04822742  7.38411817  2.70359555]
 [ 1.40880631  2.62689958  0.91650471  3.09274224  0.54468621  0.
   0.97550539  9.80831268  6.94005383  2.25953121]
 [ 1.61060503  2.1560772   0.31969838  2.2286383   1.30816828  0.97550539
   0.          9.16049628  6.07594989  1.39542727]
 [ 8.42503068  7.39770065  9.34453814  6.93185798 10.04822742  9.80831268
   9.16049628  0.          3.41105655  7.76506902]

In [59]:
layer_2 = 3
head_2 = 5

print(w_distances[layer_2][head_2].shape)
print(w_distances[layer_2][head_2])

(10, 10)
[[0.         0.42284069 0.26190206 0.33752455 0.35762748 0.43477798
  0.1590168  0.09297139 0.2608054  0.18499069]
 [0.42284069 0.         0.49504896 0.08963797 0.09590379 0.12751034
  0.27874397 0.32986931 0.16602771 0.24709696]
 [0.26190206 0.49504896 0.         0.40541099 0.42983575 0.49452138
  0.23122507 0.24911698 0.32902125 0.26854067]
 [0.33752455 0.08963797 0.40541099 0.         0.09260987 0.1413259
  0.22856443 0.27806722 0.07671915 0.19691742]
 [0.35762748 0.09590379 0.42983575 0.09260987 0.         0.0771505
  0.19861068 0.2646561  0.1008145  0.17263679]
 [0.43477798 0.12751034 0.49452138 0.1413259  0.0771505  0.
  0.27576118 0.34180659 0.177965   0.24978728]
 [0.1590168  0.27874397 0.23122507 0.22856443 0.19861068 0.27576118
  0.         0.07902644 0.15184528 0.05461064]
 [0.09297139 0.32986931 0.24911698 0.27806722 0.2646561  0.34180659
  0.07902644 0.         0.20134807 0.12783017]
 [0.2608054  0.16602771 0.32902125 0.07671915 0.1008145  0.177965
  0.15184528 0.

In [64]:
matrix = w_distances[layer_2][head_2] - w_distances[layer][head]
print(matrix)

[[ 0.         -0.96954352 -1.15404576 -1.52070233 -1.32095223 -0.97402833
  -1.45158823 -8.3320593  -5.44473307 -1.04288937]
 [-0.96954352  0.         -1.5509307  -1.98752774 -2.97506013 -2.49938923
  -1.87733324 -7.06783134 -4.14712655 -1.15766302]
 [-1.15404576 -1.5509307   0.         -2.00726916 -0.69429067 -0.42198333
  -0.0884733  -9.09542115 -5.93097049 -1.31092845]
 [-1.52070233 -1.98752774 -2.00726916  0.         -3.44419671 -2.95141634
  -2.00007387 -6.65379076 -3.77059244 -0.89388367]
 [-1.32095223 -2.97506013 -0.69429067 -3.44419671  0.         -0.46753571
  -1.1095576  -9.78357133 -7.28330367 -2.53095876]
 [-0.97402833 -2.49938923 -0.42198333 -2.95141634 -0.46753571  0.
  -0.69974421 -9.46650608 -6.76208884 -2.00974392]
 [-1.45158823 -1.87733324 -0.0884733  -2.00007387 -1.1095576  -0.69974421
   0.         -9.08146985 -5.92410462 -1.34081663]
 [-8.3320593  -7.06783134 -9.09542115 -6.65379076 -9.78357133 -9.46650608
  -9.08146985  0.         -3.20970848 -7.63723884]
 [-5.444

## Compute some Statistice

As we can see from the below analysis, the distance matrix `w_distances[layer][head]` has larger Wasserstein distances than the distance matrix `w_distances[layer_2][head_2]` as indicated by the $0\%$ non-negative entries percentage, meaning all of the entries in the lower triangular part of the matrix below the diagonal are larger in `w_distances[layer][head]`. This means that `w_distances[layer_2][head_2]` preserves persistent homology better than `w_distances[layer][head]` for the sequence motif `ACGTC`.


In [68]:
import numpy as np
from scipy import stats

def non_negative_percentage(matrix):
    n = len(matrix)
    lower_triangular_no_diag = matrix[np.tril_indices(n, -1)]
    non_negative_count = np.sum(lower_triangular_no_diag >= 0)
    total_count = len(lower_triangular_no_diag)
    percentage = (non_negative_count / total_count) * 100
    return percentage

def stats_analysis(matrix):
    n = len(matrix)
    lower_triangular_no_diag = matrix[np.tril_indices(n, -1)]
    max_value = np.max(lower_triangular_no_diag)
    min_value = np.min(lower_triangular_no_diag)
    median_value = np.median(lower_triangular_no_diag)
    mean_value = np.mean(lower_triangular_no_diag)
    std_dev = np.std(lower_triangular_no_diag)
    
    # Compute interquartile range to find outliers
    q1 = np.percentile(lower_triangular_no_diag, 25)
    q3 = np.percentile(lower_triangular_no_diag, 75)
    iqr = q3 - q1
    outlier_mask = (lower_triangular_no_diag < (q1 - 1.5 * iqr)) | (lower_triangular_no_diag > (q3 + 1.5 * iqr))
    outliers = lower_triangular_no_diag[outlier_mask]
    
    return {"Max": max_value, "Min": min_value, "Median": median_value, 
            "Mean": mean_value, "Standard Deviation": std_dev, "Outliers": outliers}


matrix = w_distances[layer_2][head_2] - w_distances[layer][head]
print(f"Non-negative entries percentage: {non_negative_percentage(matrix)}%")
print(f"Statistical Analysis: {stats_analysis(matrix)}")


Non-negative entries percentage: 0.0%
Statistical Analysis: {'Max': -0.08847330428996433, 'Min': -9.783571327699914, 'Median': -2.009743922992899, 'Mean': -3.480074829869965, 'Standard Deviation': 2.8731863437962626, 'Outliers': array([], dtype=float64)}


## Comparing Pairs of Attention Heads

Here we compare all pairs of the attention heads (except comparing a head to itself, which yields trivial statistics). 

In [75]:
# Assuming w_distances is a nested list where the first index is layers and the second index is heads
num_layers = len(w_distances)
num_heads = len(w_distances[0])

stats = {}
for i in range(num_layers):
    for j in range(num_heads):
        for k in range(i, num_layers):
            for l in range((j if i == k else 0), num_heads):  # if same layer, start from next head
                if i == k and j == l:  # Skip comparing the same head with itself
                    continue
                matrix = w_distances[k][l] - w_distances[i][j]
                stats[(i, j, k, l)] = {
                    "Non-negative entries percentage": non_negative_percentage(matrix),
                    "Statistical Analysis": stats_analysis(matrix)
                }

In [79]:
# Specify the pair of attention heads
layer1 = 3 
head1 = 5 
layer2 = 5  
head2 = 2  

# Print the statistics for the specified pair of attention heads
stats[(layer1, head1, layer2, head2)]

{'Non-negative entries percentage': 100.0,
 'Statistical Analysis': {'Max': 9.959264763771461,
  'Min': 1.088075241570382,
  'Median': 4.25026259849097,
  'Mean': 4.444191712729245,
  'Standard Deviation': 2.029278381635479,
  'Outliers': array([9.57412683, 9.95926476])}}

In [76]:
stats

{(0, 0, 0, 1): {'Non-negative entries percentage': 75.55555555555556,
  'Statistical Analysis': {'Max': 0.2737156397631336,
   'Min': -0.15416111928617268,
   'Median': 0.05830137358019381,
   'Mean': 0.05641814209541791,
   'Standard Deviation': 0.09657246064206684,
   'Outliers': array([ 0.27371564, -0.15416112])}},
 (0, 0, 0, 2): {'Non-negative entries percentage': 26.666666666666668,
  'Statistical Analysis': {'Max': 0.1816781003842158,
   'Min': -0.22302049773932014,
   'Median': -0.04974570992278435,
   'Mean': -0.03121254989953686,
   'Standard Deviation': 0.09986535761213923,
   'Outliers': array([-0.2230205 ,  0.16623232,  0.1816781 ,  0.17905587,  0.1523926 ])}},
 (0, 0, 0, 3): {'Non-negative entries percentage': 75.55555555555556,
  'Statistical Analysis': {'Max': 0.30193427450848037,
   'Min': -0.20784200392159122,
   'Median': 0.09866960624310361,
   'Mean': 0.10597497561231475,
   'Standard Deviation': 0.1348866622685403,
   'Outliers': array([], dtype=float64)}},
 (0, 0,