In [1]:
import os
import sys
# Make the repository root (two directories above this notebook) importable
# so that project-local modules can be imported from later cells.
home_dir = "../../"
module_path = os.path.abspath(home_dir)
if sys.path.count(module_path) == 0:
    # Append only once so that re-running the cell does not grow sys.path.
    sys.path += [module_path]
    
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the ProtT5-XL (UniRef50) checkpoint as a seq2seq language model.
# The first call downloads the weights (about 11.3 GB per the progress bar below).
model_checkpoint = "Rostlab/prot_t5_xl_uniref50"
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

Downloading pytorch_model.bin: 100%|██████████| 11.3G/11.3G [02:47<00:00, 67.5MB/s]


In [4]:
# Display the module tree of the loaded model (resolves to T5ForConditionalGeneration).
model

T5ForConditionalGeneration(
  (shared): Embedding(128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              (dro

In [25]:
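# Loading the tokenizer through AutoTokenizer did not work for this checkpoint,
# so T5Tokenizer is instantiated directly in the next cell instead.
# tokenizer = AutoTokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50")
# tokenizer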

In [26]:
from transformers import T5Tokenizer

# Instantiate the slow (SentencePiece) T5 tokenizer that ships with the
# ProtT5 checkpoint, then display its configuration.
tokenizer_checkpoint = "Rostlab/prot_t5_xl_uniref50"
tokenizer = T5Tokenizer.from_pretrained(tokenizer_checkpoint)
tokenizer

T5Tokenizer(name_or_path='Rostlab/prot_t5_xl_uniref50', vocab_size=128, model_max_length=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '

In [32]:
import re

import torch

# --- Build a single-site masked input for the model ---------------------
raw_seq = "ASDX"
# Position to mask; must be 0-indexed.
# Author's observation: the outputs of mutant A_DX and AS_X are different at every position.
mut_pos_zero_idxed = 2
# Special mask token used by the model (maps to id 127, see the printout below).
mask_token = '<extra_id_0>'

# Replace unknown amino acids with the unknown token X, then split into residues.
residues = list(re.sub(r"[UZOB]", "X", raw_seq))

# Fail loudly on a bad position: a negative index would otherwise silently
# mask a residue counted from the end of the sequence.
if not 0 <= mut_pos_zero_idxed < len(residues):
    raise IndexError(
        f"mut_pos_zero_idxed={mut_pos_zero_idxed} is outside the sequence "
        f"(valid range: 0..{len(residues) - 1})"
    )
residues[mut_pos_zero_idxed] = mask_token

# The tokenizer expects space separated amino acids.
seq = " ".join(residues)
print(seq)

# --- Tokenize -----------------------------------------------------------
# An <eos> token (</s>) is appended at the end of the encoded sequence;
# token ids start from 0.
input_ids = tokenizer.batch_encode_plus(
    [seq], add_special_tokens=True, padding="longest"
)
print(input_ids)

# Sanity check of the id <-> token mapping for a residue and for the mask token.
print(tokenizer.convert_ids_to_tokens(3), tokenizer.convert_tokens_to_ids('▁A'))
print(tokenizer.convert_ids_to_tokens(127), tokenizer.convert_tokens_to_ids('<extra_id_0>'))

# --- Move model and inputs to the target device --------------------------
device = torch.device("cpu")
model = model.to(device)
tokenized_sequences = torch.tensor(input_ids["input_ids"]).to(device)
attention_mask = torch.tensor(input_ids["attention_mask"]).to(device)

A S <extra_id_0> X
{'input_ids': [[3, 7, 127, 23, 1]], 'attention_mask': [[1, 1, 1, 1, 1]]}
▁A 3
<extra_id_0> 127


In [23]:
# Forward pass without gradient tracking (inference only).
# NOTE(review): the encoder input ids are reused, unshifted, as
# decoder_input_ids — confirm this is the intended way to read
# per-position logits from this model.
with torch.no_grad():
    logits = model(
        input_ids=tokenized_sequences,
        attention_mask=attention_mask,
        decoder_input_ids=tokenized_sequences,
    ).logits

# Keep only the residue positions of the first (and only) sequence.
# The batch axis is indexed explicitly instead of using squeeze(), which
# would drop *any* size-1 axis. The residue count is derived from the
# attention mask (attended tokens minus the trailing </s>) instead of
# hard-coding the length of the example sequence.
n_residues = int(attention_mask[0].sum()) - 1
logits = logits[0, :n_residues].cpu().numpy()
print(logits)
print(logits.shape)

[[-17.390211   -9.293226  -50.622787   -7.236208  -10.624225   -8.695351
  -10.065985   -9.033567   -9.414229   -9.940994  -10.468842  -10.259579
  -12.276775   -9.225727  -11.641763  -12.024079  -11.323843  -12.420126
  -12.841396  -12.613511  -11.722964  -12.178924  -11.511997   -4.7033205
  -50.698643  -51.003498  -50.929955  -50.378647  -49.787445  -51.267986
  -50.573208  -51.00089   -50.63575   -51.379807  -50.905254  -51.337032
  -50.69402   -50.55966   -51.067028  -51.02701   -49.835644  -50.958694
  -51.259655  -51.352947  -50.01358   -50.387264  -50.794067  -49.55082
  -50.350464  -50.089256  -50.65592   -51.10532   -50.492172  -49.89929
  -51.1219    -50.76748   -49.833546  -49.934505  -50.972485  -50.68562
  -51.00729   -50.040417  -51.130943  -50.581726  -49.522335  -50.701366
  -49.25947   -49.574875  -50.061802  -51.41051   -51.184708  -51.082794
  -51.39099   -50.893303  -50.307026  -49.093086  -50.58528   -51.06971
  -50.63497   -51.10927   -50.554596  -50.864094  -50.