In [96]:
import torch
import utils

In [None]:
from transformers import EsmForMaskedLM, EsmTokenizer

# Load pretrained model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"  # Example ESM model
model = EsmForMaskedLM.from_pretrained(model_name)
tokenizer = EsmTokenizer.from_pretrained(model_name)


In [None]:
# Protein sequence with a masked token
sequence = "MKTLLILAVVFCALMAIVFV<mask>"

# Tokenize the input sequence
inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=True)

# Extract input IDs and attention mask
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]


In [None]:
# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)

# Logits for masked language modeling
logits = outputs.logits
print(logits.shape)  # Shape: (batch_size, seq_len, vocab_size)


In [None]:


# Find the index of the masked token
masked_index = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)

# Get logits for the masked token
masked_logits = logits[masked_index]

# Predicted token ID
predicted_token_id = torch.argmax(masked_logits, dim=-1)

# Convert token ID back to the residue
predicted_token = tokenizer.convert_ids_to_tokens(predicted_token_id)
print(f"Predicted token: {predicted_token}")


In [None]:
# Define loss function
loss_fn = torch.nn.CrossEntropyLoss()

# Mask some tokens in the sequence
labels = input_ids.clone()
labels[labels != tokenizer.mask_token_id] = -100  # Ignore non-masked tokens for loss

# Forward pass with labels
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
print(f"Loss: {loss.item()}")


In [2]:
from transformers import EsmForMaskedLM, EsmTokenizer
import torch

# Load pretrained ESM model and tokenizer
model_name = "facebook/esm2_t6_8M_UR50D"
model = EsmForMaskedLM.from_pretrained(model_name)
tokenizer = EsmTokenizer.from_pretrained(model_name)

# Protein sequence
sequence = "MKTLLILAVVFCALMAIVFV<mask><mask><mask><mask><mask>"

# Tokenize the sequence
inputs = tokenizer(sequence, return_tensors="pt", add_special_tokens=True)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]

# Forward pass
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits

# Predict masked token
masked_index = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
masked_logits = logits[masked_index]
predicted_token_id = torch.argmax(masked_logits, dim=-1)
predicted_token = tokenizer.convert_ids_to_tokens(predicted_token_id)

print(f"Predicted token: {predicted_token}")

Predicted token: ['V', 'L', 'L', 'L', 'R']


In [16]:
from esm import pretrained

# Load the ESM-MSA model
model, alphabet = pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()

# Example MSA
msa = [
    ("seq1", "MKTLLILAVVFCALMAIVFV"),
    ("seq2", "MKTLIILVVFCALMAVVVF."),
    ("seq3", "MKTLLILVVFCALMAIVF.."),
]

# Convert MSA into model-compatible format
batch_labels, batch_strs, batch_tokens = batch_converter([msa])

# Print batch tokens shape
print(batch_tokens.shape)  # (num_sequences, msa_depth, seq_length)

# Forward pass through the model
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[12])

# Extract token embeddings from the last layer (layer 12)
msa_embeddings = results["representations"][12]
print(msa_embeddings.shape)  # Shape: (msa_depth, seq_length, embedding_dim)



A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.3 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/guillaumebelissent/opt/anaconda3/envs/ml/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/guillaumebelissent/opt/anaconda3/envs/ml/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/guillaumebelissent/opt/anaconda3/envs/ml/lib/python3.11/site-packages/ipykernel/ker

torch.Size([1, 3, 21])


NameError: name 'torch' is not defined

In [28]:
import torch
import esm

# Load the MSA model and alphabet
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Example MSA (aligned sequences)
msa = [
    ("seq1", "MKTLLILAVVFCALMAIVFV"),
    ("seq2", "MKTLIILVVFCALMAVVVF."),
    ("seq3", "MKTLLILVVFCALMAIVF.."),
]

# Tokenize the MSA
batch_labels, batch_strs, batch_tokens = batch_converter([msa])

# Move tokens to the appropriate device
batch_tokens = batch_tokens.to(device)

# Forward pass
with torch.no_grad():
    results = model(batch_tokens)

# Extract per-residue embeddings
# Layer 12 corresponds to the final layer of the model
representations = results["representations"][12]  # Shape: (msa_depth, seq_len, hidden_size)
print("Per-residue embedding shape:", representations.shape)

query_sequence_embedding = representations[0]  # Shape: (seq_len, hidden_size)

# Example MSA with a masked token in the first sequence
msa_with_mask = [
    ("seq1", "MKTLLILAV<mask>CALMAIVFV"),  # Masked token at position 10
    ("seq2", "MKTLIILVVFCALMAVVVF."),
    ("seq3", "MKTLLILVVFCALMAIVF.."),
]

# Tokenize the MSA with masking
batch_labels, batch_strs, batch_tokens = batch_converter([msa_with_mask])
batch_tokens = batch_tokens.to(device)

# Forward pass
with torch.no_grad():
    results = model(batch_tokens)

# Extract logits for the masked position
masked_index = 10 + 1  # Adjust for the [CLS] token
masked_logits = results["logits"][0, 0, masked_index]  # First sequence, masked position
predicted_token_id = masked_logits.argmax().item()
predicted_token = alphabet.get_tok(predicted_token_id)
print(f"Predicted token at position 10: {predicted_token}")


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt" to /Users/guillaumebelissent/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t6_8M_UR50D-contact-regression.pt" to /Users/guillaumebelissent/.cache/torch/hub/checkpoints/esm2_t6_8M_UR50D-contact-regression.pt


Predicted token at position 10: C


In [69]:
# Load the model and alphabet
model, alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
batch_converter = alphabet.get_batch_converter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Example MSA with a masked token
msa_with_mask = [
    ("P69905","MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGS<mask>QVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR"),
    ("P01942","MVLSGEDKSNIKAAWGKIGGHGAEYGAEALERMFASFPTTKTYFPHFDVSHGS<mask>QVKGHGKKVADALASAAGHLDDLPGALSALSDLHAHKLRVDPVNFKLLSHCLLVTLASHHPADFTPAVHASLDKFLASVSTVLTSKYR"),
    ("P13786","MSLTRTERTIILSLWSKISTQADVIGTETLERLFSCYPQAKTYFPHFDLHSGS<mask>QLRAHGSKVVAAVGDAVKSIDNVTSALSKLSELHAYVLRVDPVNFKFLSHCLLVTLASHFPADFTADAHAAWDKFLSIVSGVLTEKYR"),
]

# Tokenize the MSA
batch_labels, batch_strs, batch_tokens = batch_converter(msa_with_mask)
batch_tokens = batch_tokens.to(device)

# Forward pass
with torch.no_grad():
    results = model(batch_tokens)

# Extract logits for the masked position
masked_index = 1+msa_with_mask[0][1].find('<')  # Adjust for the [CLS] token
masked_logits = results["logits"][0, 0, masked_index]  # First sequence, masked position
predicted_token_id = masked_logits.argmax().item()
predicted_token = alphabet.get_tok(predicted_token_id)
print(f"Predicted token at position {masked_index}: {predicted_token}")


Predicted token at position 54: A


In [26]:
batch_tokens

tensor([[[ 0, 20, 15, 11,  4,  4, 12,  4,  5,  7,  7, 18, 23,  5,  4, 20,  5,
          12,  7, 18,  7],
         [ 0, 20, 15, 11,  4, 12, 12,  4,  7,  7, 18, 23,  5,  4, 20,  5,  7,
           7,  7, 18, 29],
         [ 0, 20, 15, 11,  4,  4, 12,  4,  7,  7, 18, 23,  5,  4, 20,  5, 12,
           7, 18, 29, 29]]])

In [30]:
msa = [
    "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHGKKVADALTNAVAHVDDMPNALSALSDLHAHKLRVDPVNFKLLSHCLLVTLAAHLPAEFTPAVHASLDKFLASVSTVLTSKYR",
    "MVLSGEDKSNIKAAWGKIGGHGAEYGAEALERMFASFPTTKTYFPHFDVSHGSAQVKGHGKKVADALASAAGHLDDLPGALSALSDLHAHKLRVDPVNFKLLSHCLLVTLASHHPADFTPAVHASLDKFLASVSTVLTSKYR",
    "MSLTRTERTIILSLWSKISTQADVIGTETLERLFSCYPQAKTYFPHFDLHSGSAQLRAHGSKVVAAVGDAVKSIDNVTSALSKLSELHAYVLRVDPVNFKFLSHCLLVTLASHFPADFTADAHAAWDKFLSIVSGVLTEKYR",
]


In [22]:
esm2_t30_150M_UR50D=pretrained.esm2_t30_150M_UR50D()

In [9]:
import numpy as np

In [30]:
def mask(sequence, position):
    return sequence[: position] + '<mask>' + sequence[position + 1:]

In [2]:
blosum62 = {
        ('A', 'A'):  4, ('A', 'R'): -1, ('A', 'N'): -2, ('A', 'D'): -2, ('A', 'C'):  0,
        ('A', 'Q'): -1, ('A', 'E'): -1, ('A', 'G'):  0, ('A', 'H'): -2, ('A', 'I'): -1,
        ('A', 'L'): -1, ('A', 'K'): -1, ('A', 'M'): -1, ('A', 'F'): -2, ('A', 'P'): -1,
        ('A', 'S'):  1, ('A', 'T'):  0, ('A', 'W'): -3, ('A', 'Y'): -2, ('A', 'V'):  0,
        ('R', 'R'):  5, ('R', 'N'):  0, ('R', 'D'): -2, ('R', 'C'): -3, ('R', 'Q'):  1,
        ('R', 'E'):  0, ('R', 'G'): -2, ('R', 'H'):  0, ('R', 'I'): -3, ('R', 'L'): -2,
        ('R', 'K'):  2, ('R', 'M'): -1, ('R', 'F'): -3, ('R', 'P'): -2, ('R', 'S'): -1,
        ('R', 'T'): -1, ('R', 'W'): -3, ('R', 'Y'): -2, ('R', 'V'): -3,
    }

In [4]:
len(blosum62.keys())

39

In [7]:
def blosum62(AA1, AA2):
    blosum62 = {
        ('A', 'A'):  4, ('A', 'R'): -1, ('A', 'N'): -2, ('A', 'D'): -2, ('A', 'C'):  0,
        ('A', 'Q'): -1, ('A', 'E'): -1, ('A', 'G'):  0, ('A', 'H'): -2, ('A', 'I'): -1,
        ('A', 'L'): -1, ('A', 'K'): -1, ('A', 'M'): -1, ('A', 'F'): -2, ('A', 'P'): -1,
        ('A', 'S'):  1, ('A', 'T'):  0, ('A', 'W'): -3, ('A', 'Y'): -2, ('A', 'V'):  0,
        ('R', 'R'):  5, ('R', 'N'):  0, ('R', 'D'): -2, ('R', 'C'): -3, ('R', 'Q'):  1,
        ('R', 'E'):  0, ('R', 'G'): -2, ('R', 'H'):  0, ('R', 'I'): -3, ('R', 'L'): -2,
        ('R', 'K'):  2, ('R', 'M'): -1, ('R', 'F'): -3, ('R', 'P'): -2, ('R', 'S'): -1,
        ('R', 'T'): -1, ('R', 'W'): -3, ('R', 'Y'): -2, ('R', 'V'): -3,
        ('N', 'N'):  6, ('N', 'D'):  1, ('N', 'C'): -3, ('N', 'Q'):  0, ('N', 'E'):  0,
        ('N', 'G'):  0, ('N', 'H'):  1, ('N', 'I'): -3, ('N', 'L'): -3, ('N', 'K'):  0,
        ('N', 'M'): -2, ('N', 'F'): -3, ('N', 'P'): -2, ('N', 'S'):  1, ('N', 'T'):  0,
        ('N', 'W'): -4, ('N', 'Y'): -2, ('N', 'V'): -3,
        ('D', 'D'):  6, ('D', 'C'): -3, ('D', 'Q'):  0, ('D', 'E'):  2, ('D', 'G'): -1,
        ('D', 'H'): -1, ('D', 'I'): -3, ('D', 'L'): -4, ('D', 'K'): -1, ('D', 'M'): -3,
        ('D', 'F'): -3, ('D', 'P'): -1, ('D', 'S'):  0, ('D', 'T'): -1, ('D', 'W'): -4,
        ('D', 'Y'): -3, ('D', 'V'): -3,
        ('C', 'C'):  9, ('C', 'Q'): -3, ('C', 'E'): -4, ('C', 'G'): -3, ('C', 'H'): -3,
        ('C', 'I'): -1, ('C', 'L'): -1, ('C', 'K'): -3, ('C', 'M'): -1, ('C', 'F'): -2,
        ('C', 'P'): -3, ('C', 'S'): -1, ('C', 'T'): -1, ('C', 'W'): -2, ('C', 'Y'): -2,
        ('C', 'V'): -1,
        ('Q', 'Q'):  5, ('Q', 'E'):  2, ('Q', 'G'): -2, ('Q', 'H'):  0, ('Q', 'I'): -3,
        ('Q', 'L'): -2, ('Q', 'K'):  1, ('Q', 'M'):  0, ('Q', 'F'): -3, ('Q', 'P'): -1,
        ('Q', 'S'):  0, ('Q', 'T'): -1, ('Q', 'W'): -2, ('Q', 'Y'): -1, ('Q', 'V'): -2,
        ('E', 'E'):  5, ('E', 'G'): -2, ('E', 'H'):  0, ('E', 'I'): -3, ('E', 'L'): -3,
        ('E', 'K'):  1, ('E', 'M'): -2, ('E', 'F'): -3, ('E', 'P'): -1, ('E', 'S'):  0,
        ('E', 'T'): -1, ('E', 'W'): -3, ('E', 'Y'): -2, ('E', 'V'): -2,
        ('G', 'G'):  6, ('G', 'H'): -2, ('G', 'I'): -4, ('G', 'L'): -4, ('G', 'K'): -2,
        ('G', 'M'): -3, ('G', 'F'): -3, ('G', 'P'): -2, ('G', 'S'):  0, ('G', 'T'): -2,
        ('G', 'W'): -2, ('G', 'Y'): -3, ('G', 'V'): -3,
        ('H', 'H'):  8, ('H', 'I'): -3, ('H', 'L'): -3, ('H', 'K'): -1, ('H', 'M'): -2,
        ('H', 'F'): -1, ('H', 'P'): -2, ('H', 'S'): -1, ('H', 'T'): -2, ('H', 'W'): -2,
        ('H', 'Y'):  2, ('H', 'V'): -3,
        ('I', 'I'):  4, ('I', 'L'):  2, ('I', 'K'): -3, ('I', 'M'):  1, ('I', 'F'):  0,
        ('I', 'P'): -3, ('I', 'S'): -2, ('I', 'T'): -1, ('I', 'W'): -3, ('I', 'Y'): -1,
        ('I', 'V'):  3,
        ('L', 'L'):  4, ('L', 'K'): -2, ('L', 'M'):  2, ('L', 'F'):  0, ('L', 'P'): -3,
        ('L', 'S'): -2, ('L', 'T'): -1, ('L', 'W'): -2, ('L', 'Y'): -1, ('L', 'V'):  1,
        ('K', 'K'):  5, ('K', 'M'): -1, ('K', 'F'): -3, ('K', 'P'): -1, ('K', 'S'):  0,
        ('K', 'T'): -1, ('K', 'W'): -3, ('K', 'Y'): -2, ('K', 'V'): -2,
        ('M', 'M'):  5, ('M', 'F'):  0, ('M', 'P'): -2, ('M', 'S'): -1, ('M', 'T'): -1,
        ('M', 'W'): -1, ('M', 'Y'): -1, ('M', 'V'):  1,
        ('F', 'F'):  6, ('F', 'P'): -4, ('F', 'S'): -2, ('F', 'T'): -2, ('F', 'W'):  1,
        ('F', 'Y'):  3, ('F', 'V'): -1,
        ('P', 'P'):  7, ('P', 'S'): -1, ('P', 'T'): -1, ('P', 'W'): -4, ('P', 'Y'): -3,
        ('P', 'V'): -2,
        ('S', 'S'):  4, ('S', 'T'):  1, ('S', 'W'): -3, ('S', 'Y'): -2, ('S', 'V'): -2,
        ('T', 'T'):  5, ('T', 'W'): -2, ('T', 'Y'): -2, ('T', 'V'):  0,
        ('W', 'W'): 11, ('W', 'Y'):  2, ('W', 'V'): -3,
        ('Y', 'Y'):  7, ('Y', 'V'): -1,
        ('V', 'V'):  4
    }
    return blosum62[(str(AA1).upper(), str(AA2).upper())]

In [8]:
blosum62('A','Y')

-2

In [52]:
import torch
import numpy as np
from esm import pretrained  # Import ESM pretrained models

# Define the `mask` function
def mask(sequence, masked_index):
    """
    Replaces a character at the specified index with a mask token ('X').
    
    Parameters:
        sequence (str): The original sequence.
        masked_index (int): The index of the character to mask.

    Returns:
        list of tuples: Each tuple is a sequence identifier and sequence (required by batch_converter).
    """
    masked_sequence = sequence[:masked_index] + 'X' + sequence[masked_index + 1:]
    return [("sequence_id", masked_sequence)]

# Main test function
def test(sequence, model_loader):
    """
    Tests a protein sequence using a pretrained ESM model and predicts the masked amino acid.
    
    Parameters:
        sequence (str): Input protein sequence.
        model_loader (function): Function to load the pretrained model (e.g., `pretrained.esm2_t30_150M_UR50D`).

    Returns:
        Tuple: Masked index, predicted amino acid, and model results.
    """
    # Randomly mask an index in the sequence
    masked_index = np.random.randint(0, len(sequence))+1
    masked_sequence = mask(sequence, masked_index)

    # Load the model and alphabet
    model, alphabet = model_loader()
    batch_converter = alphabet.get_batch_converter()
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    # Tokenize the masked sequence
    print("Masked Sequence:", masked_sequence)
    batch_labels, batch_strs, batch_tokens = batch_converter(masked_sequence)
    batch_tokens = batch_tokens.to(device)
    
    # Perform inference
    with torch.no_grad():
        results = model(batch_tokens)
    
    # Extract logits for the masked position
    logits = results["logits"]  # Shape: (batch_size, seq_len, vocab_size)
    masked_logits = logits[0, masked_index, :]  # Select the masked token's logits
    
    # Find the predicted amino acid
    predicted_token_id = masked_logits.argmax().item()
    predicted_token = alphabet.get_tok(predicted_token_id)

    # Return the masked index, predicted amino acid, and full results
    return masked_index-1, predicted_token, results

# Example usage with pretrained ESM model
seq="MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPT"
masked_index, predicted_aa, results = test(seq, pretrained.esm2_t30_150M_UR50D)
print(f"Masked Index: {masked_index}")
print(f"Predicted Amino Acid: {predicted_aa}")
print(f"Predicted Amino Acid: {seq[masked_index]}")


Masked Sequence: [('sequence_id', 'MVLSPADKTNVKAAWGKVGAHAGEYGAEALERXFLSFPT')]
Masked Index: 31
Predicted Amino Acid: R
Predicted Amino Acid: R


# ESM single sequence (esm2_t30_150M_UR50D)

## Example Use

In [6]:
import torch
from esm import pretrained
import utils
%load_ext autoreload

### Mask size = 1

In [58]:
sequence="MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPT"
model, alphabet = pretrained.esm2_t30_150M_UR50D()
batch_converter = alphabet.get_batch_converter()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
masked_sequence=utils.mask(sequence, [4,5,6,7,9])
print(f'Masking residues:\n{sequence}\n{masked_sequence}\n{"".join(["-" if sequence[i]==masked_sequence[i] else "^" for i in range(len(sequence))])}')

Masking residues:
MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPT
MVLSXXXXTXVKAAWGKVGAHAGEYGAEALERMFLSFPT
----^^^^-^-----------------------------


In [59]:
batch_labels, batch_strs, batch_tokens = batch_converter([('ID',masked_sequence)])
batch_tokens = batch_tokens.to(device)

# Perform inference
with torch.no_grad():
    results = model(batch_tokens)
print(sequence)
print(''.join([alphabet.get_tok(item) for item in results["logits"][0].argmax(dim=-1)[1:-1]]))

MVLSXXXXTXVKAAWGKVGAHAGEYGAEALERMFLSFPT


In [54]:
results["logits"][0].argmax(dim=-1)

tensor([ 0, 20,  7,  4,  8,  5,  5, 13,  4, 11,  5,  7, 15,  5,  5, 22,  6, 15,
         7,  6,  5, 21,  5,  6,  9, 19,  6,  5,  9,  5,  4,  9, 10, 20, 18,  4,
         8, 18, 14, 11,  2])

In [12]:
%autoreload 2

In [78]:
import torch
import numpy as np
from esm import pretrained  # Import ESM pretrained models

# Initialize model and batch_converter globally for reuse
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, alphabet = pretrained.esm2_t30_150M_UR50D()
batch_converter = alphabet.get_batch_converter()
model = model.to(device)
model.eval()  # Switch to evaluation mode


def mask(sequence, masked_index):
    """
    Replaces a character at the specified index with a mask token ('X').

    Parameters:
        sequence (str): The original sequence.
        masked_index (int): The index of the character to mask.

    Returns:
        str: Masked sequence.
    """
    return sequence[:masked_index] + 'X' + sequence[masked_index + 1:]


def test(sequence,model,alphabet):
    """
    Tests a protein sequence using a pretrained ESM model and predicts the masked amino acid.

    Parameters:
        sequence (str): Input protein sequence.

    Returns:
        Tuple: Masked index, predicted amino acid, and original amino acid.
    """
    # Randomly mask an index in the sequence
    masked_index = np.random.randint(1, len(sequence))+1  # Masking starts from index 0

    # Tokenize the masked sequence
    batch_labels, batch_strs, batch_tokens = batch_converter([("sequence_id", mask(sequence, masked_index))])
    batch_tokens = batch_tokens.to(device)

    # Perform inference
    with torch.no_grad():
        results = model(batch_tokens)

    # Return masked index, predicted amino acid, and original amino acid
    return masked_index-1, alphabet.get_tok(results["logits"][0, masked_index, :].argmax().item()), sequence[masked_index-1]


# Example usage
seq = "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPT"
masked_index, predicted_aa, label = test(seq,model,alphabet)
print(f"Masked Index: {masked_index}")
print(f"Predicted Amino Acid: {predicted_aa}")
print(f"Original Amino Acid: {label}")


Masked Index: 13
Predicted Amino Acid: A
Original Amino Acid: A


In [79]:
from tqdm import tqdm

In [89]:
def read_fasta(file_path):
    sequences = []
    with open(file_path, 'r') as file:
        for line in file:
            if not line.startswith(">"):
                sequences.append(line.strip())
    return sequences

results=[]
for seq in tqdm(read_fasta('Project2/BindingDBTargetSequences.fasta')[::10]):
    results.append(test(seq,model,alphabet))

  0%|          | 0/938 [00:00<?, ?it/s]

100%|██████████| 938/938 [28:40<00:00,  1.83s/it]  


In [95]:
results

[(329, 'D', 'D'),
 (27, 'E', 'E'),
 (67, 'G', 'G'),
 (312, 'P', 'P'),
 (41, 'I', 'I'),
 (52, 'E', 'E'),
 (44, 'N', 'N'),
 (461, 'P', 'P'),
 (177, 'A', 'A'),
 (12, 'A', 'A'),
 (91, 'Q', 'Q'),
 (478, 'A', 'A'),
 (491, 'E', 'E'),
 (118, 'A', 'A'),
 (98, 'T', 'T'),
 (53, 'I', 'I'),
 (67, 'G', 'G'),
 (323, 'T', 'T'),
 (77, 'Q', 'Q'),
 (344, 'Q', 'Q'),
 (147, 'G', 'G'),
 (190, 'K', 'K'),
 (160, 'R', 'R'),
 (68, 'V', 'V'),
 (482, 'Y', 'Y'),
 (278, 'L', 'L'),
 (576, 'L', 'L'),
 (1063, 'E', 'E'),
 (56, 'L', 'L'),
 (24, 'D', 'D'),
 (41, 'G', 'G'),
 (333, 'S', 'S'),
 (157, 'Q', 'Q'),
 (5, 'L', 'L'),
 (410, 'N', 'N'),
 (51, 'S', 'S'),
 (494, 'D', 'D'),
 (86, 'A', 'A'),
 (196, 'Q', 'Q'),
 (318, 'L', 'L'),
 (113, 'Y', 'Y'),
 (554, 'L', 'L'),
 (61, 'Y', 'Y'),
 (11, 'S', 'S'),
 (696, 'E', 'E'),
 (89, 'V', 'V'),
 (467, 'R', 'R'),
 (499, 'L', 'L'),
 (534, 'N', 'N'),
 (113, 'E', 'E'),
 (964, 'I', 'I'),
 (73, 'V', 'V'),
 (327, 'L', 'L'),
 (505, 'P', 'P'),
 (159, 'S', 'S'),
 (305, 'E', 'E'),
 (282, 'E', 'E

In [92]:
s=0
for res in results:
    s+=res[1]==res[2]
s/len(results)

0.9829424307036247