In [9]:
from transformers import AutoTokenizer
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

In [10]:
class ProteinLM(nn.Module):
    def __init__(self, model_name="Rostlab/prot_bert_bfd"):
        super(ProteinLM, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=False)
        self.bert = BertModel.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return output.last_hidden_state

In [11]:
def esm_tokenizer(sequences, max_length=512):
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    tokenized = tokenizer(
        sequences,
        return_tensors="pt",
        padding='max_length',
        truncation=True,
        max_length=max_length
    )
    return tokenized['input_ids'], tokenized['attention_mask']

In [13]:
if __name__ == "__main__":
    with open("uniref50_sequences.txt") as f:
        sequences = [line.strip() for line in f]

    input_ids, attention_mask = esm_tokenizer(sequences)

    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    print(f"First token ID: {tokenizer.decode(input_ids[0].tolist())}")

    model = ProteinLM()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

    embeddings = model(input_ids=input_ids, attention_mask=attention_mask)

    print("Embeddings shape:", embeddings.shape)

First token ID: <cls> M G R I R V W V G T S I P N P V N A H Q L V Y L K G M A K T K K L I L L L F V A A Q P N F K E W S L D V D A S T L V L T F E A N S V L S V K P D C S K V T I H S T A N G V K N V T L T N S G N G T L D A A N D Q A S C T I D A K D L D N I K L E T T L G T N T T N T F L E V K A G F G T K N G T T E F T Q G S P Y T A A A L V T P D V T A P E I S A T V G F S E F D L N S G R V T I A F T E A V D V S T L K F T K L A F R D A K L T G K T S T T G Y C N V T K D G K C D A A F C K N G A T V V L E V D N V D L N C I K S K R G L C T K D S D C I I T L E E D D F I Q D M A G N K L G K Y E S G T T A N A A E T L L H K F V P D I T S P T L D N F D L D L N A N T L T L E F S E T V D A K T L K A D G L T I Q G N G N T A D V S L Q V K L T S E S T T E S S D S A T I I V D I A P A D G A K L K M S T N I A T K T G D S Y I A V A T S A M N D M S G N A V K P I S S T A A K Q V R R F T N D T S A A V L S K F S L D L N T N Q L T L T F D E P V K V D S L N F T L F T L Q S T A A G G T E V K L T G S T T M T T G T 

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 8388608000 bytes.