In [13]:
import scoring_functions as sfuncs
import predict_top_sequences as pts

import torch
from transformers import AutoTokenizer, EsmForMaskedLM

In [14]:
# Load the ESM-2 masked language model and its tokenizer, and define the reference
# sequence plus the triple mutant that the single-sequence scoring cells evaluate.
ESM_CHECKPOINT = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
model = EsmForMaskedLM.from_pretrained(ESM_CHECKPOINT)

sequence = "MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKVLPPPVRRIIGDLSNREKVLIGLDLLYEEIGDQAEDDLGLE"

# Each mutation is (position, wild-type residue, mutant residue). The wild-type letters
# match `sequence` under 0-based indexing: sequence[68] == 'L', sequence[83] == 'E',
# sequence[84] == 'K'.
mutations = [(68, 'L', 'R'), (83, 'E', 'D'), (84, 'K', 'A')]

# Guard against indexing mistakes: every listed wild-type residue must be present in `sequence`.
assert all(sequence[pos] == wt for pos, wt, _ in mutations), "wild-type residues do not match `sequence`"

In [15]:
# Masked-marginal score of the reference triple mutant on the original sequence.
reference_masked_marginal = sfuncs.masked_marginal_scoring(tokenizer, model, sequence, mutations)
print(reference_masked_marginal)

-0.9924547672271729


In [16]:
# Mutant-marginal score of the reference triple mutant on the original sequence.
reference_mutant_marginal = sfuncs.mutant_marginal_score(tokenizer, model, sequence, mutations)
print(reference_mutant_marginal)

-1.6451168060302734


In [17]:
# Wild-type-marginal score of the reference triple mutant on the original sequence.
reference_wild_type_marginal = sfuncs.wild_type_marginal_score(tokenizer, model, sequence, mutations)
print(reference_wild_type_marginal)

9.463499069213867


In [18]:
# Pseudolikelihood score of the reference triple mutant on the original sequence.
reference_pseudolikelihood = sfuncs.pseudolikelihood_score(tokenizer, model, sequence, mutations)
print(reference_pseudolikelihood)

9.509151458740234


In [19]:
# Ask for the highest-scoring ways to fill three masked positions of `sequence`.
# NOTE(review): these positions are one less than the reference mutations in `mutations`
# (68, 83, 84, whose wild-type letters match `sequence` 0-indexed), so the residues masked
# here are sequence[67] == 'V', sequence[82] == 'R', sequence[83] == 'E'. Confirm which
# indexing pts.predict_top_full_sequences expects and which positions were intended.
TOP_MASK_POSITIONS = [67, 82, 83]
N_TOP_SEQUENCES = 10
scores, top_sequences = pts.predict_top_full_sequences(sequence, TOP_MASK_POSITIONS, N_TOP_SEQUENCES)
top_sequences

['MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKKLPPPVRRIIGDLSNKEKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKKLPPPVRRIIGDLSNLEKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKLLPPPVRRIIGDLSNPKKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKGLPPPVRRIIGDLSNKKKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKKLPPPVRRIIGDLSNPKKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKLLPPPVRRIIGDLSNKKKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKKLPPPVRRIIGDLSNKKKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKKLPPPVRRIIGDLSNLKKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKGLPPPVRRIIGDLSNLKKVLIGLDLLY

The `predict_top_full_sequences` function does the following:

1. It first loads a pre-trained model and tokenizer from the Hugging Face model hub. Here, the `facebook/esm2_t6_8M_UR50D` model is used, which is a transformer-based model trained on a large corpus of protein sequences.

2. The function then takes a protein sequence and a list of positions where the amino acids are to be mutated. The positions are replaced with a special mask token in the sequence.

3. The masked sequence is passed through the model, and the output is the logits for each position in the sequence.

4. The logits are then converted to probabilities using the softmax function, specifically for the positions that were masked.

5. The function then computes all possible combinations of amino acids for the masked positions and computes the sum of the log probabilities of these combinations. 

6. It maintains a heap of size `m` (a parameter passed to the function) to keep the top `m` sequences based on their scores. 

The mathematical description can be presented as follows:

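Let $p_i(a)$ be the softmax probability the model assigns to amino acid $a$ at the $i$-th masked position, $C$ be the set of all possible combinations of amino acids over the $k$ masked positions, and $S$ be the set of the top $m$ sequences and their scores. Because all $k$ positions are masked in a single forward pass, the score treats them as independent given the unmasked context, so a combination's score is the sum of its per-position log probabilities. We can define $S$ as follows:

$$
S = \text{Top}_m \left\{ \left( \sum_{i=1}^{k} \log p_i(c_i),\; c \right) \;\middle|\; c = (c_1, c_2, \ldots, c_k) \in C \right\}
$$

where $\text{Top}_m$ is an operation that selects the $m$ elements with the highest scores, $c$ is a combination of amino acids, and $k$ is the number of masked positions.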

The function then generates the top sequences by replacing the masked positions in the original sequence with the amino acids from the top combinations.

The `predict_bottom_full_sequences_nst` and `predict_bottom_full_sequences_st` functions work in a similar manner, but keep the bottom `m` sequences based on their scores. The `nst` in `predict_bottom_full_sequences_nst` stands for "No Special Tokens", meaning this function ignores combinations that include special tokens. The `st` in `predict_bottom_full_sequences_st` stands for "Special Tokens", indicating this function allows combinations that include special tokens.

In [20]:
# The substitutions to evaluate: the same three target residues are applied to every
# top sequence, at the positions that were masked when generating it.
mutated_aas = ['R', 'D', 'A']

# 0-indexed positions passed to pts.predict_top_full_sequences above.
# NOTE(review): these are one less than the reference mutations in `mutations`
# (68, 83, 84); confirm the intended indexing.
mutation_positions = [67, 82, 83]
# zip() below would silently drop unmatched entries, so require one target per position.
assert len(mutation_positions) == len(mutated_aas), "one target residue per position"

# Residues of the original `sequence` at these positions (not used in this cell; kept for reference).
wt_aas = [sequence[i] for i in mutation_positions]

# One mutation list per top sequence, as (position, residue currently in that sequence,
# target residue). The "from" residue is read from the predicted sequence rather than from
# `sequence`, because each predicted sequence is the background it will be scored against.
# A comprehension is used so the global `mutations` defined earlier is not overwritten.
mutations_for_sequences = [
    [(pos, seq[pos], target_aa) for pos, target_aa in zip(mutation_positions, mutated_aas)]
    for seq in top_sequences
]

mutations_for_sequences


[[(67, 'K', 'R'), (82, 'K', 'D'), (83, 'E', 'A')],
 [(67, 'K', 'R'), (82, 'L', 'D'), (83, 'E', 'A')],
 [(67, 'L', 'R'), (82, 'P', 'D'), (83, 'K', 'A')],
 [(67, 'G', 'R'), (82, 'K', 'D'), (83, 'K', 'A')],
 [(67, 'K', 'R'), (82, 'P', 'D'), (83, 'K', 'A')],
 [(67, 'L', 'R'), (82, 'K', 'D'), (83, 'K', 'A')],
 [(67, 'K', 'R'), (82, 'K', 'D'), (83, 'K', 'A')],
 [(67, 'K', 'R'), (82, 'L', 'D'), (83, 'K', 'A')],
 [(67, 'G', 'R'), (82, 'L', 'D'), (83, 'K', 'A')],
 [(67, 'L', 'R'), (82, 'L', 'D'), (83, 'K', 'A')]]

In [21]:
# Masked-marginal score of the R/D/A substitutions applied to each top sequence.
# The loop name `seq_mutations` keeps the global `mutations` from the first cells intact,
# so re-running those cells still scores the reference mutant.
assert len(top_sequences) == len(mutations_for_sequences), "one mutation list per sequence"
masked_marginal_scores = [
    sfuncs.masked_marginal_scoring(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(top_sequences, mutations_for_sequences)
]

masked_marginal_scores

[-1.6732699871063232,
 -1.6997504234313965,
 -1.687041997909546,
 -1.8217713832855225,
 -1.7720277309417725,
 -1.841071367263794,
 -1.9260571002960205,
 -1.9525375366210938,
 -1.8482518196105957,
 -1.8675518035888672]

In [22]:
# Ask for the lowest-scoring ways to fill the same three masked positions, using the
# `nst` variant, which skips combinations that contain special tokens.
# NOTE(review): as for the top sequences, these positions are one less than the reference
# mutations in `mutations` (68, 83, 84); confirm the intended indexing.
BOTTOM_MASK_POSITIONS = [67, 82, 83]
N_BOTTOM_SEQUENCES = 10
scores_2, bottom_sequences = pts.predict_bottom_full_sequences_nst(sequence, BOTTOM_MASK_POSITIONS, N_BOTTOM_SEQUENCES)
bottom_sequences

['MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKCLPPPVRRIIGDLSNHCKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKWLPPPVRRIIGDLSNWCKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKMLPPPVRRIIGDLSNCCKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKCLPPPVRRIIGDLSNCMKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKCLPPPVRRIIGDLSNWWKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKWLPPPVRRIIGDLSNCWKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKCLPPPVRRIIGDLSNCCKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKCLPPPVRRIIGDLSNWCKVLIGLDLLYEEIGDQAEDDLGLE',
 'MAPLRKTYVLKLYVAGNTPNSVRALKTLNNILEKEFKGVYALKVIDVLKNPQLAEEDKILATPTLAKWLPPPVRRIIGDLSNCCKVLIGLDLLY

In [23]:
# The same three target residues used for the top sequences, so top and bottom scores
# are directly comparable.
mutated_aas = ['R', 'D', 'A']

# 0-indexed positions passed to pts.predict_bottom_full_sequences_nst above.
mutation_positions = [67, 82, 83]
# zip() below would silently drop unmatched entries, so require one target per position.
assert len(mutation_positions) == len(mutated_aas), "one target residue per position"

# One mutation list per bottom sequence, as (position, residue currently in that sequence,
# target residue). The "from" residue comes from the predicted sequence because that
# sequence is the background it will be scored against. A comprehension is used so the
# global `mutations` defined earlier is not overwritten.
mutations_for_bottom_sequences = [
    [(pos, seq[pos], target_aa) for pos, target_aa in zip(mutation_positions, mutated_aas)]
    for seq in bottom_sequences
]

mutations_for_bottom_sequences


[[(67, 'C', 'R'), (82, 'H', 'D'), (83, 'C', 'A')],
 [(67, 'W', 'R'), (82, 'W', 'D'), (83, 'C', 'A')],
 [(67, 'M', 'R'), (82, 'C', 'D'), (83, 'C', 'A')],
 [(67, 'C', 'R'), (82, 'C', 'D'), (83, 'M', 'A')],
 [(67, 'C', 'R'), (82, 'W', 'D'), (83, 'W', 'A')],
 [(67, 'W', 'R'), (82, 'C', 'D'), (83, 'W', 'A')],
 [(67, 'C', 'R'), (82, 'C', 'D'), (83, 'C', 'A')],
 [(67, 'C', 'R'), (82, 'W', 'D'), (83, 'C', 'A')],
 [(67, 'W', 'R'), (82, 'C', 'D'), (83, 'C', 'A')],
 [(67, 'C', 'R'), (82, 'C', 'D'), (83, 'W', 'A')]]

In [24]:
# Masked-marginal score of the R/D/A substitutions applied to each bottom sequence.
# The loop name `seq_mutations` keeps the global `mutations` from the first cells intact.
assert len(bottom_sequences) == len(mutations_for_bottom_sequences), "one mutation list per sequence"
bottom_masked_marginal_scores = [
    sfuncs.masked_marginal_scoring(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(bottom_sequences, mutations_for_bottom_sequences)
]

bottom_masked_marginal_scores


[5.583103895187378,
 5.61872935295105,
 5.6325719356536865,
 5.639241933822632,
 5.971154451370239,
 5.729832887649536,
 6.672896146774292,
 6.2664735317230225,
 6.025151968002319,
 6.377577066421509]

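The `masked_marginal_scores` and `bottom_masked_marginal_scores` are the masked-marginal scores of the same three substitutions (to R, D and A at positions 67, 82 and 83), applied to the top and bottom sequences respectively. They are not plain sums of log probabilities: such a sum can never be positive, yet every bottom score is above 5. Instead, they behave like a log-ratio between the substituted residues and the residues already present in each sequence. Presumably this is $\sum_i \left[\log p_i(\text{mt}_i) - \log p_i(\text{wt}_i)\right]$, as in the ESM-1v masked-marginal score; confirm against `scoring_functions`.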

At first glance the result looks inverted: every top sequence receives a lower masked-marginal score than every bottom sequence. This is expected. Here's why:

The sequences are ranked based on their marginal probabilities. When we say "top" sequences, we mean the sequences that have the highest marginal probabilities. Similarly, "bottom" sequences are the ones with the lowest marginal probabilities. 

However, when we compute the masked-marginal score with `sfuncs.masked_marginal_scoring`, we are not scoring the residues the model chose for each sequence. We are scoring a substitution *away* from those residues, to the fixed residues R, D and A. These target residues may or may not be ones the model considers likely at those positions.

So, even if a sequence is a "top" sequence, if the specific mutations that we are considering have low probabilities, the `masked_marginal_score` for that sequence could be low. Similarly, even if a sequence is a "bottom" sequence, if the specific mutations that we are considering have high probabilities, the `masked_marginal_score` for that sequence could be high.

Therefore, `masked_marginal_scores` being smaller than `bottom_masked_marginal_scores` does not contradict the ranking; it follows from it. A top sequence already carries residues the model finds highly likely at the masked positions. Replacing them with R, D and A therefore usually lowers the log-likelihood, which gives a negative score. A bottom sequence carries the residues the model finds least likely (mostly C, W and M here). Almost any substitution, including R, D and A, raises the log-likelihood and so improves the model-estimated fitness, which gives a positive score. We should see similar trends for the other scoring methods.

## Mutant Marginal Scores

In [26]:
# Mutant-marginal score (sfuncs.mutant_marginal_score) of the R/D/A substitutions applied
# to each top sequence. The loop name `seq_mutations` keeps the global `mutations` intact.
assert len(top_sequences) == len(mutations_for_sequences), "one mutation list per sequence"
mutant_marginal_scores = [
    sfuncs.mutant_marginal_score(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(top_sequences, mutations_for_sequences)
]

mutant_marginal_scores

[-3.606092691421509,
 -4.381468772888184,
 -0.9693958759307861,
 -1.646575927734375,
 -4.2036285400390625,
 -0.936593770980835,
 -4.1708269119262695,
 -4.946203231811523,
 -2.421952247619629,
 -1.7119700908660889]

In [27]:
# Mutant-marginal score (sfuncs.mutant_marginal_score) of the R/D/A substitutions applied
# to each bottom sequence. The loop name `seq_mutations` keeps the global `mutations` intact.
assert len(bottom_sequences) == len(mutations_for_bottom_sequences), "one mutation list per sequence"
bottom_mutant_marginal_scores = [
    sfuncs.mutant_marginal_score(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(bottom_sequences, mutations_for_bottom_sequences)
]

bottom_mutant_marginal_scores

[4.972687721252441,
 4.8160200119018555,
 5.436990261077881,
 4.930351734161377,
 5.081740379333496,
 5.983952045440674,
 6.070474147796631,
 4.992141246795654,
 5.894352912902832,
 6.160073280334473]

## Wild-Type Marginal Scores

In [28]:
# Wild-type-marginal score (sfuncs.wild_type_marginal_score) of the R/D/A substitutions
# applied to each top sequence. The loop name `seq_mutations` keeps the global `mutations` intact.
assert len(top_sequences) == len(mutations_for_sequences), "one mutation list per sequence"
wild_type_marginal_scores = [
    sfuncs.wild_type_marginal_score(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(top_sequences, mutations_for_sequences)
]

wild_type_marginal_scores

[6.440245628356934,
 6.361305236816406,
 11.316875457763672,
 6.238447189331055,
 7.522137641906738,
 6.694493770599365,
 2.8822383880615234,
 6.755340576171875,
 10.553121566772461,
 10.057172775268555]

In [29]:
# Wild-type-marginal score (sfuncs.wild_type_marginal_score) of the R/D/A substitutions
# applied to each bottom sequence. The loop name `seq_mutations` keeps the global `mutations` intact.
assert len(bottom_sequences) == len(mutations_for_bottom_sequences), "one mutation list per sequence"
bottom_wild_type_marginal_scores = [
    sfuncs.wild_type_marginal_score(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(bottom_sequences, mutations_for_bottom_sequences)
]

bottom_wild_type_marginal_scores

[13.42713737487793,
 14.374776840209961,
 11.102903366088867,
 13.391246795654297,
 11.295248031616211,
 15.266965866088867,
 10.980279922485352,
 14.061704635620117,
 11.688694953918457,
 14.730594635009766]

## Pseudolikelihood Scores

In [30]:
# Pseudolikelihood score (sfuncs.pseudolikelihood_score) of the R/D/A substitutions applied
# to each top sequence. The loop name `seq_mutations` keeps the global `mutations` intact.
assert len(top_sequences) == len(mutations_for_sequences), "one mutation list per sequence"
pseudolikelihood_scores = [
    sfuncs.pseudolikelihood_score(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(top_sequences, mutations_for_sequences)
]

pseudolikelihood_scores

[6.3451385498046875,
 6.325115203857422,
 11.335274696350098,
 6.095399379730225,
 7.592133522033691,
 6.561470031738281,
 2.7696115970611572,
 6.699493885040283,
 10.483247756958008,
 9.899065017700195]

In [31]:
# Pseudolikelihood score (sfuncs.pseudolikelihood_score) of the R/D/A substitutions applied
# to each bottom sequence. The loop name `seq_mutations` keeps the global `mutations` intact.
assert len(bottom_sequences) == len(mutations_for_bottom_sequences), "one mutation list per sequence"
bottom_pseudolikelihood_scores = [
    sfuncs.pseudolikelihood_score(tokenizer, model, seq, seq_mutations)
    for seq, seq_mutations in zip(bottom_sequences, mutations_for_bottom_sequences)
]

bottom_pseudolikelihood_scores

[13.314289093017578,
 14.322607040405273,
 11.03341293334961,
 13.302623748779297,
 11.18637466430664,
 15.177007675170898,
 10.890905380249023,
 13.978338241577148,
 11.652375221252441,
 14.624868392944336]