In [1]:
from pathlib import Path
import torch
import re
from transformers import BertModel, BertTokenizer

In [2]:
# Directory in which the ProtBert weights are cached.
# The relative location below almost certainly differs on your machine:
# point it at the folder that holds your protbert_weights.
protbert_cache = Path('../../protbert_weights')
# Turn the relative location into an absolute path.
protbert_cache = protbert_cache.resolve()

In [3]:
# Identifier of the pretrained checkpoint to load
protbert_version = 'Rostlab/prot_bert_bfd'

# ProtBert tokenizer: casing is preserved (do_lower_case=False) and files are
# cached in protbert_cache
tokenizer = BertTokenizer.from_pretrained(
    protbert_version,
    do_lower_case=False,
    cache_dir=protbert_cache,
)

# ProtBert encoder, cached in the same directory
model = BertModel.from_pretrained(
    protbert_version,
    cache_dir=protbert_cache,
)

# Switch to evaluation mode; as the cell's last expression this also displays the model
model.eval()

Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30, 1024, padding_idx=0)
    (position_embeddings): Embedding(40000, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.0, inplace=False

In [4]:
# Example input: residues separated by spaces. The residue codes U, Z, O and B
# are each replaced by the placeholder X before tokenization.
seq = [re.compile(r'[UZOB]').sub('X', 'M E T C Z A O')]

In [5]:
# Run the tokenizer on the example sequence, adding the special tokens and
# requesting PyTorch tensors
tokens = tokenizer(seq, add_special_tokens=True, return_tensors='pt')
input_ids = tokens['input_ids']
attention_mask = tokens['attention_mask']

# Dimension 1 of both tensors can be compared with the length of the sequence
input_ids.shape, attention_mask.shape

(torch.Size([1, 9]), torch.Size([1, 9]))

In [6]:
# Forward pass without gradient tracking; the first element of the model
# output is kept as the embedding
with torch.no_grad():
    model_output = model(input_ids=input_ids, attention_mask=attention_mask)
    embedding = model_output[0]

embedding.shape

torch.Size([1, 9, 1024])

In [7]:
# Along dimension 1, the first and the last attended positions usually hold the
# special tokens [CLS] and [SEP]. They are sliced away here so that only the
# positions in between remain.
# Caveat: the attention mask is summed over all of its entries, so this
# shortcut is only meant for a single sequence - be careful if you use
# multiple sequences.
seq_len = torch.sum(attention_mask)
residue_span = slice(1, seq_len - 1)
actual_embedding = embedding[:, residue_span]
actual_embedding.shape

torch.Size([1, 7, 1024])

In [8]:
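# If you want to do your own project, consider using ProtT5 instead
# (version = "Rostlab/prot_t5_xl_uniref50").
# - It is a much larger model and takes more time to compute; ProtBert is sufficient for this exercise.
# - Note: ProtT5 is a T5 checkpoint, so it needs the T5 classes from transformers
#   rather than BertModel / BertTokenizer.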

# Transformer: https://jalammar.github.io/illustrated-transformer/