In [1]:
import torch
from tqdm import tqdm
import numpy as np
import os

from transformers import T5EncoderModel, T5Tokenizer
from transformers import BertModel, BertTokenizer
from transformers import XLNetModel, XLNetTokenizer
from transformers import AlbertModel, AlbertTokenizer


def build_pretrained_model(model_name):
    """Load the tokenizer and model that match a pretrained checkpoint name.

    The model family is chosen by case-sensitive substring matching on
    ``model_name``. The checks run in the order t5, albert, bert, xlnet;
    "albert" must be tested before "bert" because every ALBERT name also
    contains the substring "bert".

    Args:
        model_name: Checkpoint identifier passed unchanged to
            ``from_pretrained`` (e.g. ``"Rostlab/prot_t5_xl_uniref50"``).

    Returns:
        A ``(tokenizer, model)`` tuple. For T5 names the model is a
        ``T5EncoderModel`` (encoder only).

    Raises:
        ValueError: If ``model_name`` matches none of the known families.
    """
    # Pick the class pair first so the loading calls are written only once.
    if "t5" in model_name:
        tokenizer_cls, model_cls = T5Tokenizer, T5EncoderModel
    elif "albert" in model_name:
        tokenizer_cls, model_cls = AlbertTokenizer, AlbertModel
    elif "bert" in model_name:
        tokenizer_cls, model_cls = BertTokenizer, BertModel
    elif "xlnet" in model_name:
        tokenizer_cls, model_cls = XLNetTokenizer, XLNetModel
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    # do_lower_case=False keeps the upper-case characters of the input intact.
    tokenizer = tokenizer_cls.from_pretrained(model_name, do_lower_case=False)
    model = model_cls.from_pretrained(model_name)
    return tokenizer, model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier of the pretrained model used to embed the sequences.
ppm = "Rostlab/prot_t5_xl_uniref50"
# Device on which the model and its inputs are placed.
device = "cpu"

tokenizer, embeder = build_pretrained_model(ppm)
# Inference only: switch to evaluation mode, then move the model to `device`.
embeder.eval()
embeder = embeder.to(device)

Some weights of the model checkpoint at Rostlab/prot_t5_xl_uniref50 were not used when initializing T5EncoderModel: ['decoder.block.11.layer.2.layer_norm.weight', 'decoder.block.23.layer.0.layer_norm.weight', 'decoder.block.8.layer.0.SelfAttention.v.weight', 'decoder.block.18.layer.1.layer_norm.weight', 'decoder.block.12.layer.2.layer_norm.weight', 'decoder.block.23.layer.1.EncDecAttention.k.weight', 'decoder.block.22.layer.1.EncDecAttention.k.weight', 'decoder.block.9.layer.1.EncDecAttention.q.weight', 'decoder.block.3.layer.2.layer_norm.weight', 'decoder.block.0.layer.2.DenseReluDense.wi.weight', 'decoder.block.7.layer.0.SelfAttention.k.weight', 'decoder.block.14.layer.1.EncDecAttention.o.weight', 'decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight', 'decoder.block.13.layer.1.EncDecAttention.q.weight', 'decoder.block.0.layer.2.layer_norm.weight', 'decoder.block.23.layer.0.SelfAttention.o.weight', 'decoder.block.10.layer.0.SelfAttention.v.weight', 'decoder.block.15.l

In [3]:
# Read both pair files fully into memory. The `with` blocks close the file
# handles right away instead of leaving them open until garbage collection.
with open("train.seq", "r") as train_file:
    train_lines = train_file.readlines()
with open("test.seq", "r") as test_file:
    test_lines = test_file.readlines()

In [4]:
def emb(seq):
    """Embed one sequence with the module-level ``tokenizer`` / ``embeder``.

    The sequence is stripped of surrounding whitespace and its characters are
    joined with single spaces before tokenization, so each character is
    presented to the tokenizer as a separate word.

    Args:
        seq: Sequence string; leading/trailing whitespace is ignored.

    Returns:
        A NumPy array holding the last hidden state for every attended
        position except the final one, i.e. ``attended_positions - 1`` rows.

    NOTE(review): dropping only the last attended position presumably removes
    the single end-of-sequence token a T5 tokenizer appends. Tokenizers that
    add special tokens elsewhere (e.g. a leading [CLS]) would need a
    different slice. Confirm before using a non-T5 checkpoint.
    """
    spaced = " ".join(seq.strip())
    encoded = tokenizer.batch_encode_plus([spaced], add_special_tokens=True, padding=True)

    # Convert every encoded field to a tensor on the target device.
    model_inputs = {}
    for key, values in encoded.items():
        model_inputs[key] = torch.tensor(values).to(device)

    # Number of positions the attention mask marks as real (non-padding).
    n_attended = (model_inputs['attention_mask'][0] == 1).sum()

    # Forward pass only; no gradients are needed.
    with torch.no_grad():
        outputs = embeder(**model_inputs)
    hidden = outputs.last_hidden_state.cpu().numpy()

    # Exactly one sequence was submitted, so the batch dimension must be 1.
    assert hidden.shape[0] == 1
    return hidden[0, :n_attended-1]

In [5]:
# np.save does not create missing directories, so make sure the target exists.
os.makedirs("embs", exist_ok=True)

# Protein ID -> sequence whose embedding was most recently written in this run.
# A protein that recurs with the same sequence is skipped, which leaves the
# files on disk identical to recomputing and overwriting them every time.
saved_seqs = {}

# Wrap the list itself (not enumerate(...)) so tqdm knows the total and can
# show a real progress bar with an ETA instead of a bare iteration counter.
for line in tqdm(train_lines + test_lines):
    fields = line.split()
    if not fields:
        # Tolerate blank lines (e.g. a trailing newline at the end of a file).
        continue

    # Each record holds: first protein ID, first sequence, second protein ID,
    # second sequence. Any other field count raises ValueError here.
    fpro, fseq, spro, sseq = fields

    for pro, seq in ((fpro, fseq), (spro, sseq)):
        if saved_seqs.get(pro) == seq:
            continue

        embedding = emb(seq)
        # One embedding row is expected per character of the sequence.
        assert embedding.shape[0] == len(seq)

        np.save(f"embs/{pro}.npy", embedding)
        saved_seqs[pro] = seq

50it [02:39,  3.19s/it]
