In [1]:
from transformers import AutoTokenizer, EsmModel
import torch

# 1. Load the tokenizer and the base ESM-2 model (matching your notebook's checkpoint)
model_checkpoint = 'facebook/esm2_t6_8M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = EsmModel.from_pretrained(model_checkpoint)

# 2. Define your sequence (using the GCK glucokinase snippet from your notebook)
seq = 'MLDDRARMEAAKKEKVEQILAEFQLQEEDLKKVMRRMQKEMDRGLRLETHEEASVKMLPTYVRSTPEGSEVGDFLSLDLGGTNFRVMLVKVGEGEEGQWSVKTKHQMYS'

# 3. Tokenize the sequence
# return_tensors="pt" tells it to return PyTorch tensors
inputs = tokenizer(seq, return_tensors="pt")

# 4. Pass the inputs through the model to get the hidden states
# We use torch.no_grad() because we are just extracting features, not training ESM-2
with torch.no_grad():
    outputs = model(**inputs)

# 5. Extract the last hidden state (the embeddings)
# Shape: (batch_size, sequence_length + 2, hidden_dimension)
# The +2 accounts for the special <cls> (start) and <eos> (end) tokens ESM adds.
last_hidden_state = outputs.last_hidden_state

print(f"Full embedding tensor shape: {last_hidden_state.shape}")

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Full embedding tensor shape: torch.Size([1, 111, 320])
