Generate embeddings from protein sequence data using lightweight ESM2 model

In [2]:
from transformers import EsmModel, EsmTokenizer
import torch

# Load large ESM-2 model
model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name)

#Example protein sequence
sequence = "MKTLLVLLLGAAGG"
tokens = tokenizer(sequence, return_tensors="pt")

# Generate embeddings quickly
embeddings = model(**tokens).last_hidden_state
print(f"Embedding shape: {embeddings.shape}")

  from .autonotebook import tqdm as notebook_tqdm
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedding shape: torch.Size([1, 16, 320])


Protein mutation effect analysis

In [None]:
# Load mid-weight ESM-2 model
model_name = "facebook/esm2_t12_35M_UR50D" 
tokenizer = EsmTokenizer.from_pretrained(model_name)
model = EsmModel.from_pretrained(model_name)

# Example wild-type sequence
sequence = "MKTLLVLLLGAAGG"

# Introduce mutation (G -> A at position 13)
mutated_sequence = "MKTLLVLLLGAAGG".replace("G", "A", 1)

# Tokenize sequences
tokens_wt = tokenizer(sequence, return_tensors="pt")
tokens_mut = tokenizer(mutated_sequence, return_tensors="pt")

# Generate embeddings
embeddings_wt = model(**tokens_wt).last_hidden_state
embeddings_mut = model(**tokens_mut).last_hidden_state

# Compute similarity (cosine similarity, lower values means proteins phenotype is affected, closer to 1 means minimal or no effect)
similarity = torch.nn.functional.cosine_similarity(embeddings_wt.mean(dim=1), embeddings_mut.mean(dim=1))
print(f"Mutation Impact Score: {similarity.item()}")


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Mutation Impact Score: 0.9933467507362366
