In [4]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Device count:", torch.cuda.device_count())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")

CUDA available: True
Device count: 1
Device name: NVIDIA GeForce RTX 5070 Ti


In [5]:
from transformers import (
    AutoTokenizer, 
    AutoModel, # If we only need the embeddings
    AutoModelForMaskedLM # If we want to work with the masked LM embeddings (additional logits output)
)
import torch
import pandas as pd
import random 


device = 'cuda' if torch.cuda.is_available() else 'cpu2'
print(f"Using device: {device}")
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using device: {device}")
# Load ESM2 8M model and tokenizer
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, force_download=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, force_download=True)

# TODO: for later usage with LeoMed, save your model (and tokenizer) and scp this directory to the server
# Then set force_download=False if you specify a local path to load from 
# path_to_save = ...
# model.save_pretrained(path_to_save) 


# Load sample DMS data
df = pd.read_csv('data/gfp_ground_truth.csv')

# Extract sample from df and mask random 15%
seq = df['sequence'][0]
mask_prob = 0.15
masked_chars = []
for aa in seq:
    if random.random() < mask_prob:
        masked_chars.append(tokenizer.mask_token)
    else:
        masked_chars.append(aa)
masked_seq = "".join(masked_chars)

# Tokenize and input to ESM2
inputs = tokenizer(masked_seq, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits  # (1, seq_len, vocab_size)

# Decode predictions by taking argmax
pred_ids = logits.argmax(dim=-1)  # (1, seq_len)
pred_seq = tokenizer.decode(
    pred_ids[0],
    skip_special_tokens=True
).replace(" ", "")

print("\nGround truth sequence:")
print(seq)
print("\nMasked input sequence:")
print(masked_seq)
print("\nPredicted sequence:")
print(pred_seq)


Using device: cuda
Using device: cuda


config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/775 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/31.4M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]


Ground truth sequence:
SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK

Masked input sequence:
<mask>KGEEL<mask><mask>G<mask>VPILVELDGDVN<mask><mask>KFSVSGEG<mask>GDAT<mask><mask>KL<mask>LK<mask>I<mask>TTG<mask>LPVPWP<mask>LVTTLS<mask><mask>VQCF<mask><mask>YPDHMKQ<mask>DFFK<mask><mask><mask>PEGYVQ<mask>R<mask>IF<mask>KDDGNY<mask><mask>R<mask>E<mask>KFEGDTLVNRIELKGI<mask>FKED<mask><mask>ILGH<mask>LEYNY<mask>SHNVYIM<mask>D<mask>QK<mask><mask>IKVNFKIR<mask>N<mask>EDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEK<mask>D<mask>MVLLEF<mask>TAAGITHGMDELYK

Predicted sequence:
MKGEELKEGDVPILVELDGDVNSLKFSVSGEGKGDATGLKLELKGIDTTGELPVPWPELVTTLSGGVQCFSDYPDHMKQLDFFKKLLPEGYVQGRGIFKKDDGNYLLRLELKFEGDTLVNRIELKGIDFKEDGKILGHKLEYNYDSHNVYIMKDGQKKPIKVNFKIRLNKEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKLDGMVLLEFLTAAGITHGMDELYK


In [7]:
seq = df['sequence'][0]
seq

'SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK'

In [8]:
mask_prob = 0.15
masked_chars = []
for aa in seq:
    if random.random() < mask_prob:
        masked_chars.append(tokenizer.mask_token)
    else:
        masked_chars.append(aa)
masked_seq = "".join(masked_chars)
masked_seq

'<mask>KG<mask>ELFTGVVPILVE<mask>D<mask>DVNG<mask>KFSV<mask>GEGEG<mask>ATYGKL<mask>LKFICT<mask>GKLPV<mask>WPTLVTTLSYGV<mask>CFSRYPDHMKQHDFFK<mask>AMPEGYVQ<mask>RTI<mask>FKDDGNY<mask><mask>RA<mask>VKFEGDTLVNR<mask>E<mask>KGIDFKEDGNILGHKLEYNY<mask>SHN<mask>YIMADKQKNGIK<mask>NF<mask>IRHNIE<mask>GSVQ<mask>ADHYQQNTP<mask>GDGP<mask>L<mask>PD<mask>HYLS<mask>QSA<mask>SKDPNEKRDHMVL<mask>EFVT<mask><mask>G<mask>THGMDEL<mask>K'

In [11]:
inputs = tokenizer(masked_seq, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits  # (1, seq_len, vocab_size)

outputs

MaskedLMOutput(loss=None, logits=tensor([[[ 13.9560,  -7.6173,  -5.9903,  ..., -15.5301, -15.7619,  -7.6145],
         [ -7.2889, -14.7171,  -7.4395,  ..., -15.7856, -15.9946, -14.7180],
         [-11.6578, -20.0635, -12.1473,  ..., -16.2404, -16.1713, -20.0570],
         ...,
         [-10.9424, -17.2669,  -9.6918,  ..., -16.1643, -16.1947, -17.2611],
         [-10.8844, -17.1586, -10.6407,  ..., -16.1206, -16.0598, -17.1437],
         [ -6.1914,  -6.8257,  16.5295,  ..., -16.7358, -16.6179,  -6.8626]]]), hidden_states=None, attentions=None)

In [12]:
pred_ids = logits.argmax(dim=-1)  # (1, seq_len)
pred_ids

tensor([[ 0, 20, 15,  6, 15,  9,  4, 18, 11,  6,  7,  7, 14, 12,  4,  7,  9,  6,
         13, 12, 13,  7, 17,  6,  4, 15, 18,  8,  7, 15,  6,  9,  6,  9,  6, 15,
          5, 11, 19,  6, 15,  4, 15,  4, 15, 18, 12, 23, 11, 13,  6, 15,  4, 14,
          7, 13, 22, 14, 11,  4,  7, 11, 11,  4,  8, 19,  6,  7,  6, 23, 18,  8,
         10, 19, 14, 13, 21, 20, 15, 16, 21, 13, 18, 18, 15, 15,  5, 20, 14,  9,
          6, 19,  7, 16,  9, 10, 11, 12,  9, 18, 15, 13, 13,  6, 17, 19,  9,  7,
         10,  5,  9,  7, 15, 18,  9,  6, 13, 11,  4,  7, 17, 10, 12,  9, 12, 15,
          6, 12, 13, 18, 15,  9, 13,  6, 17, 12,  4,  6, 21, 15,  4,  9, 19, 17,
         19, 13,  8, 21, 17,  4, 19, 12, 20,  5, 13, 15, 16, 15, 17,  6, 12, 15,
          4, 17, 18, 11, 12, 10, 21, 17, 12,  9, 13,  6,  8,  7, 16, 12,  5, 13,
         21, 19, 16, 16, 17, 11, 14,  4,  6, 13,  6, 14,  4,  4,  4, 14, 13,  6,
         21, 19,  4,  8, 12, 16,  8,  5,  4,  8, 15, 13, 14, 17,  9, 15, 10, 13,
         21, 20,  7,  4,  7,

In [14]:
pred_seq = tokenizer.decode(
    pred_ids[0],
    skip_special_tokens=True
).replace(" ", "")
pred_seq

'MKGKELFTGVVPILVEGDIDVNGLKFSVKGEGEGKATYGKLKLKFICTDGKLPVDWPTLVTTLSYGVGCFSRYPDHMKQHDFFKKAMPEGYVQERTIEFKDDGNYEVRAEVKFEGDTLVNRIEIKGIDFKEDGNILGHKLEYNYDSHNLYIMADKQKNGIKLNFTIRHNIEDGSVQIADHYQQNTPLGDGPLLLPDGHYLSIQSALSKDPNEKRDHMVLVEFVTKDGKTHGMDELKK'

In [None]:

# Decode predictions by taking argmax
pred_ids = logits.argmax(dim=-1)  # (1, seq_len)
pred_seq = tokenizer.decode(
    pred_ids[0],
    skip_special_tokens=True
).replace(" ", "")

print("\nGround truth sequence:")
print(seq)
print("\nMasked input sequence:")
print(masked_seq)
print("\nPredicted sequence:")
print(pred_seq)