In [1]:
"""
Based upon the following article: Predicting Protein-Protein
Interactions Using a Protein Language Model and Linear Sum Assignment
(https://huggingface.co/blog/AmelieSchreiber/protein-binding-partners
-with-esm2)
"""

# Importing required libraries
import time

import numpy as np
from scipy.optimize import linear_sum_assignment
from transformers import AutoTokenizer, EsmForMaskedLM
import torch
import pandas as pd
from biotite.database import uniprot
import biotite.sequence.io.fasta as fasta

In [2]:
# Initialising the model and tokeniser
tokeniser = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")

# Ensure that the model is in evaluation mode
model.eval()

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_a

In [3]:
# Check if a GPU is available and if so, set the model to run on it
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [4]:
# HVIDB has 456 recorded human-VACV PPIs; for each of them, the Masked
# Language Model (MLM) loss is computed in order to determine a suitable
# threshold value
# As a first step, import the respective CSV file and extract the
# interaction pairs, which are stored in the column "Human-virus PPI"
HVIDB_VACV_interactions_df = pd.read_csv(
    "all_HVIDB_VACV_interactions.csv"
)
interaction_pair_IDs = HVIDB_VACV_interactions_df[
    "Human-virus PPI"
].to_list()

# Querying the UniProt database is not possible for the following IDs:
# A0A0A0MR88, H0Y4G9 and A0A087X117 (they are all human protein IDs)
# Therefore, the interactions they are involved in are removed from the
# list
elements_to_remove = []

for interaction in interaction_pair_IDs:
    if (
        ("A0A0A0MR88" in interaction)
        or
        ("H0Y4G9" in interaction)
        or
        ("A0A087X117" in interaction)
    ):
        elements_to_remove.append(interaction)

for element in elements_to_remove:
    interaction_pair_IDs.remove(element)

# In the individual strings representing an interaction pair, the human
# and VACV UniProt ID are separated from each other by a hyphen
split_IDs_list = [
    pair_str.split("-") for pair_str in interaction_pair_IDs
]

In [5]:
# Determine the largest combined sequence length among the HVIDB
# interaction pairs
combined_lengths_list = []

for interaction_pair in split_IDs_list:
    human_ID, virus_ID = interaction_pair

    # Take the colon into account
    combined_length = 1

    # Suspend code execution for 2 seconds in order to obviate
    # server-side errors
    #time.sleep(2)
    io_object_human = uniprot.fetch(human_ID, "fasta")
    # Read the StringIO object into a FASTA file
    human_fasta = fasta.FastaFile.read(io_object_human)
    # The FASTA file contains only one entry; hence, the first and only
    # element is retrieved from the iterator returned by the `items()`
    # method
    _, seq_str = list(human_fasta.items())[0]
    combined_length += len(seq_str)

    #time.sleep(2)
    io_object_virus = uniprot.fetch(virus_ID, "fasta")
    virus_fasta = fasta.FastaFile.read(io_object_virus)
    _, seq_str = list(virus_fasta.items())[0]
    combined_length += len(seq_str)

    combined_lengths_list.append(combined_length)

print(
    f"The largest combined sequence length is {max(combined_lengths_list)}."
)

The largest combined sequence length is 36196.


In [18]:
print(combined_lengths_list)

[616, 1149, 388, 280, 528, 474, 809, 1047, 585, 1122, 930, 685, 689, 568, 673, 798, 873, 635, 654, 718, 459, 1825, 812, 438, 1204, 553, 544, 280, 396, 635, 532, 922, 906, 309, 705, 472, 546, 657, 1217, 953, 482, 852, 615, 1055, 420, 617, 478, 639, 1278, 977, 758, 7090, 594, 964, 543, 1513, 856, 725, 237, 824, 944, 713, 447, 286, 2514, 1012, 1421, 497, 472, 919, 1861, 1449, 1305, 387, 2602, 640, 34555, 347, 420, 1383, 449, 2530, 1357, 1016, 347, 1060, 1156, 1111, 474, 715, 281, 525, 1598, 405, 703, 644, 353, 321, 678, 557, 915, 462, 382, 474, 2496, 913, 577, 407, 1217, 812, 599, 1046, 652, 636, 1412, 1836, 874, 795, 1185, 507, 2551, 1781, 1752, 2048, 1085, 1063, 298, 795, 398, 553, 1950, 1055, 3005, 1603, 280, 472, 604, 263, 278, 420, 267, 489, 847, 1398, 572, 472, 245, 2447, 405, 563, 1916, 2692, 1348, 355, 327, 460, 554, 344, 412, 912, 828, 632, 713, 2002, 347, 631, 472, 843, 808, 789, 776, 647, 1683, 653, 1013, 382, 582, 808, 465, 654, 313, 1122, 569, 824, 936, 280, 498, 1259, 476, 4

In [24]:
print(np.count_nonzero(np.array(combined_lengths_list) > 4000))

4


In [6]:
def compute_mlm_loss(protein_1, protein_2, max_length, iterations=3):
    """
    Computes the Masked Language Model (MLM) loss between a pair of
    proteins using the ESM-2 model in order to assess the probability of
    interaction between the respective proteins.

    Usage of this function assumes ESM-2 to have already been
    initialised and to be stored in a variable named "model".

    Parameters
    ----------
    protein_1: str
        A string representing the amino acid sequence of the first
        protein in one-letter code.
    protein_2: str
        A string representing the amino acid sequence of the second
        protein in one-letter code.
    max_length: int
        The maximum combined sequence length including the colon
        character.
    iterations: int, optional
        The amount of times the procedure of randomly masking tokens and
        computing the MLM loss is repeated.

    Returns
    -------
    avg_mlm_loss: float
        The average MLM loss for the given pair of proteins.
    """
    total_loss = 0.0

    for _ in range(iterations):
        # Concatenate the two protein sequences with a colon as
        # separator between them
        concatenated_sequence = protein_1 + ":" + protein_2

        # Mask a subset of amino acids in the concatenated sequence
        # (excluding the separator)
        tokens = list(concatenated_sequence)
        mask_rate = 0.15 # Masking 15% of the sequence; optimal values
        # should be determined empirically/heuristically by testing
        # different values
        num_masks = int(len(tokens) * mask_rate)

        # Exclude the separator from potential mask indices
        available_indices = [
            i for i, token in enumerate(tokens) if token != ":"
        ]
        probs = torch.ones(len(available_indices))
        mask_indices = torch.multinomial(probs, num_masks, replacement=False)

        # Note that an intermediate step is taken by first indexing the
        # list `available_indices` instead of indexing the list `tokens`
        # directly
        # This is due to the fact that there is no one-to-one
        # correspondence between the indices randomly chosen by
        # `torch.multinomial` and the indices of the tokens as the colon
        # is not taken into account
        for idx in mask_indices:
            tokens[available_indices[idx]] = tokeniser.mask_token
        
        masked_sequence = "".join(tokens)
        inputs = tokeniser(
            masked_sequence,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
            padding="max_length"
        )

        # Comppute the MLM loss
        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
        
        total_loss += loss.item()
    
    # Return the average loss
    return total_loss / iterations

In [16]:
# Compute the MLM loss for the first ten interaction pairs
first_ten_interaction_pairs = split_IDs_list[:10]
mlm_loss_list = []

for interaction_pair in first_ten_interaction_pairs:
    # Retrieve for each UniProt ID the corresponding sequence by
    # querying the UniProt database
    seqs = []
    for uniprot_ID in interaction_pair:
        io_object = uniprot.fetch(uniprot_ID, "fasta")
        fasta_file = fasta.FastaFile.read(io_object)
        _, seq_str = list(fasta_file.items())[0]
        seqs.append(seq_str)
    
    human_seq, virus_seq = seqs
    
    avg_loss = compute_mlm_loss(human_seq, virus_seq, max_length=4000)

    mlm_loss_list.append(avg_loss)

In [9]:
print(mlm_loss_list)

[16.05976454416911, 12.061362584431967, 18.526925404866535, 19.77448018391927, 17.920093536376953, 18.357943852742512, 15.585258801778158, 13.492405891418457, 17.451945622762043, 12.384026209513346]


In [13]:
print("Values for max_length=1000")
print(max(mlm_loss_list))
print(np.mean(mlm_loss_list))
print(np.std(mlm_loss_list))

Values for max_length=1000
17.667570114135742
10.302842132250467
4.833453504918497


In [11]:
print("Values for max_length=2000")
print(max(mlm_loss_list))
print(np.mean(mlm_loss_list))
print(np.std(mlm_loss_list))

Values for max_length=2000
19.77448018391927
16.161420663197834
2.5841445831938548


In [15]:
print("Values for max_length=3000")
print(max(mlm_loss_list))
print(np.mean(mlm_loss_list))
print(np.std(mlm_loss_list))

Values for max_length=3000
20.974369049072266
18.607375717163087
1.5803331755458345


In [17]:
print("Values for max_length=4000")
print(max(mlm_loss_list))
print(np.mean(mlm_loss_list))
print(np.std(mlm_loss_list))

Values for max_length=4000
21.16938845316569
19.61531899770101
1.049258936351101
