In [48]:
from transformers import T5Tokenizer, T5EncoderModel
import torch
import re

In [49]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [35]:
device

device(type='cpu')

In [36]:
tokenizer = T5Tokenizer.from_pretrained('Rostlab/prot_t5_xl_half_uniref50-enc', do_lower_case=False)

In [37]:
model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_half_uniref50-enc").to(device)

In [38]:
model.full() if device=='cpu' else model.half()

T5EncoderModel(
  (shared): Embedding(128, 1024)
  (encoder): T5Stack(
    (embed_tokens): Embedding(128, 1024)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=1024, out_features=4096, bias=False)
              (k): Linear(in_features=1024, out_features=4096, bias=False)
              (v): Linear(in_features=1024, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=1024, bias=False)
              (relative_attention_bias): Embedding(32, 32)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=1024, out_features=16384, bias=False)
              (wo): Linear(in_features=16384, out_features=1024, bias=False)
              (dropout): Dropo

In [39]:
seqs = open('../xylo.aa').readlines()

In [40]:
seqs[:3]

['>KAK1830313.1 xylose reductase [Schizothecium conicum]\n',
 'MVAVPNIKLNSGHDMPQVGFGLWKVGNDVASDVVYNAIKAGYRLFDGACDYGNEVECGQGVARAIKDGLV\n',
 'KREELFIVSKLWNTFHDGERVVPIVQKQLADWGLEYFDLYLIHFPVALEYVDPSVRYPPGWHYQGDEIRR\n']

In [41]:
from Bio import SeqIO

In [50]:
records = list(SeqIO.parse("../xylo.aa", "fasta"))

In [51]:
seqs = [str(prot.seq) for prot in records]

In [44]:
fmt = [" ".join(list(re.sub(r"[UZOB]", "X", sequence))) for sequence in seqs]

In [45]:
ids = tokenizer(fmt, add_special_tokens=True, padding="longest")

In [46]:
ip = torch.tensor(ids['input_ids']).to(device)
mask = torch.tensor(ids['attention_mask']).to(device)

In [47]:
with torch.no_grad():
    embedding = model(input_ids=ip, attention_mask=mask)

RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

In [28]:
ip

tensor([[19,  6,  3,  ...,  0,  0,  0],
        [19,  6, 13,  ...,  0,  0,  0],
        [19,  6, 13,  ...,  0,  0,  0],
        ...,
        [19, 11, 18,  ...,  0,  0,  0],
        [19, 13,  7,  ...,  0,  0,  0],
        [19, 13,  7,  ...,  0,  0,  0]])

In [34]:
mask[0].type()

'torch.LongTensor'

In [52]:
len(seqs)

1244

In [53]:
records

[SeqRecord(seq=Seq('MVAVPNIKLNSGHDMPQVGFGLWKVGNDVASDVVYNAIKAGYRLFDGACDYGNE...IFG'), id='KAK1830313.1', name='KAK1830313.1', description='KAK1830313.1 xylose reductase [Schizothecium conicum]', dbxrefs=[]),
 SeqRecord(seq=Seq('MVPAIKLNSGFDMPQVGFGLWKVDGEIASDVVYNAIKAGYRLFDGACDYGNEVE...IFG'), id='KAK1781853.1', name='KAK1781853.1', description='KAK1781853.1 xylose reductase [Copromyces sp. CBS 386.78]', dbxrefs=[]),
 SeqRecord(seq=Seq('MVPNIKLSTGQDMPQVGFGLWKVDNAICADTVYNAIKVGYRLFDGACDYGNEVE...IFG'), id='KAK1764851.1', name='KAK1764851.1', description='KAK1764851.1 xylose reductase [Phialemonium atrogriseum]', dbxrefs=[]),
 SeqRecord(seq=Seq('MSASIPDIKLSSGHLMPSIGFGCWKLANATAGEQVYQAIKAGYRLFDGAEDYGN...IFV'), id='pdb|1SM9|D', name='pdb|1SM9|D', description='pdb|1SM9|D Chain D, xylose reductase', dbxrefs=[]),
 SeqRecord(seq=Seq('MSASIPDIKLSSGHLMPSIGFGCWKLANATAGEQVYQAIKAGYRLFDGAEDYGN...IFV'), id='pdb|1SM9|C', name='pdb|1SM9|C', description='pdb|1SM9|C Chain C, xylose reductase', dbxrefs=[]),
 SeqR

In [None]:
[[f'>{seq.id}\n{seq.seq}\n' 
    for seq in 
    records[:10]]
    for k in range(0, len(records), len(records) / 4)]

In [63]:
[[]k for k in range(0, len(records), len(records) // 4)]

[0, 311, 622, 933]

In [71]:
[(k, k + (k + 100) % len(records)) for k in range(0, len(records), 100)]

[(0, 100),
 (100, 300),
 (200, 500),
 (300, 700),
 (400, 900),
 (500, 1100),
 (600, 1300),
 (700, 1500),
 (800, 1700),
 (900, 1900),
 (1000, 2100),
 (1100, 2300),
 (1200, 1256)]

In [72]:
len(records)

1244

In [73]:
seqs

['MVAVPNIKLNSGHDMPQVGFGLWKVGNDVASDVVYNAIKAGYRLFDGACDYGNEVECGQGVARAIKDGLVKREELFIVSKLWNTFHDGERVVPIVQKQLADWGLEYFDLYLIHFPVALEYVDPSVRYPPGWHYQGDEIRRSKATIQETWTAMESLVEKKLSKSIGISNFQSQLIYDLLRHAKIPPATLQIEHHPFLVQQELLNLAKNEGIAVTAYSSFGPASFLEFNMDHAVQLKPLIEDETIKSIAAKHGRDPSQVLLRWATQRGLAIIPKSTREALMVSNLASLEFDLTEDEIKTISGFNRGIRFNQPSNYFPTQDLWIFG',
 'MVPAIKLNSGFDMPQVGFGLWKVDGEIASDVVYNAIKAGYRLFDGACDYGNEVECGQGVARAIKEGIVKREELFIVSKLWNTFHDGDRVEPIVRKQLADWGVDYFDLYLIHFPVALEYVDPSVRYPPGWHFDGQSEIRPSKATIQETWTAMESLVEKGLAKSIGVSNFQAQLLYDLLRYAKIRPATLQIEHHPYLVQQNLLNLAKAEGIAVTAYSSFGPASFREFNMEHAQKLQPLLEDATIKSIADKYNKDPAQVLLRWATQRGLAIIPKSSREATMKSNLNCLDFDLSEEDIKTISAFDRGIRFNQPTNYFSAENLWIFG',
 'MVPNIKLSTGQDMPQVGFGLWKVDNAICADTVYNAIKVGYRLFDGACDYGNEVEAGQGVARAIKEGIVKREELFIVSKLWNTFHDGDRVEPIVRKQLADWGVDYFDLYLIHFPVALEYVDPSVRYPPGWHYDGSSEIRPSKASIQETWTAMESLVGSGLAKNIGVSNFQAQLLYDLLRYAKIKPATLQIEHHPYLVQPELLRLAKTEGIAVTAYSSFGPASFAEFNMAHAAAITPLLEEATVTAIAAKHGKEPSQVLLRWATQRGLAVIPKSVREKYMQSNLASIEFDLEQSEIDQISNLDKGLRFNQPANYFPTEALWIFG',
 'MSASIPDIKLSSGHLM