In [1]:
import torch
import processing.tokenization
import processing.parse_pdb
import pandas as pd
import matplotlib.pyplot as plt
import transformers

In [2]:
# Read sequence records with load_pdb (called with ['protein']), sort them by
# the Name column, and pass the result straight to context_size_filter.
# NOTE(review): how `window=True` and the min/max length bounds interact with
# `context_length` is defined in processing.parse_pdb — confirm there.
data = processing.parse_pdb.context_size_filter(
    processing.parse_pdb.load_pdb(
        ['protein'],
        data_loc='./data/pdb_seqres.txt',
    ).sort_values(by='Name'),
    context_length=1024,
    min_length=100,
    max_length=4000,
    window=True,
)
data

Unnamed: 0,Name,Sequence,MolType,length
0,,GPLGSMALRACGLIIFRRCLIPKVDNNAIEFLLLQASDGIHHWTPP...,protein,153.0
1,,GPLGSMALRACGLIIFRRCLIPKVDNNAIEFLLLQASDGIHHWTPP...,protein,153.0
2,,MLVVPAIDLFRGKVARMIKGRKENTIFYEKDPVELVEKLIEEGFTL...,protein,241.0
3,,GPLGSMALRACGLIIFRRCLIPKVDNNAIEFLLLQASDGIHHWTPP...,protein,153.0
4,,MSFRFGQCLIKPSVVFLKTELSFALVNRKPVVPGHVLVCPLRPVER...,protein,163.0
...,...,...,...,...
963714,zinc peptidase active subunit,ADPAASTFETTLPNGLKVVVREDHRAPTLVHMVWYRVGSMDETTGT...,protein,437.0
963715,zinc peptidase inactive subunit,AIKIEHWTAPSGAQVYYVENRTLPMLDVQVDFDAGSAREPADQVGV...,protein,424.0
963716,zinc peptidase inactive subunit,AIKIEHWTAPSGAQVYYVENRTLPMLDVQVDFDAGSAREPADQVGV...,protein,424.0
963717,zinc-dependent metalloproteinase,QQRFPQRYVMLAIVADHGMVTKYSGNSSAITTRVHQMVSHVTEMYS...,protein,202.0


In [22]:
# Tokenizer built from the project's local vocabulary file.
# `return_tensors` is a per-call tokenizer argument, not a constructor option,
# so it is not passed here; request tensors when encoding instead,
# e.g. tokenizer(sequence, return_tensors='pt').
# NOTE(review): the ids in basic_tokens.txt presumably line up with the
# vocabulary of the checkpoint loaded below — confirm.
tokenizer = transformers.EsmTokenizer(vocab_file='./processing/basic_tokens.txt')

# Masked-language-model head on top of the pretrained ESM-2 checkpoint.
esm_model = transformers.EsmForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D')

In [23]:
# Generator over `data`; it is handed the tokenizer, the tokenizer's mask
# token id, and the masking probability to apply.
data_generator = processing.parse_pdb.dataset_generator(
    data,
    tokenizer=tokenizer,
    mask_token_id=tokenizer.mask_token_id,
    mask_probability=0.05,
)

In [24]:
# Draw one item from the generator and run a forward pass with gradient
# tracking disabled.
with torch.no_grad():
    sample = next(data_generator)
    # Each item unpacks into the model inputs, the mask, and the true values.
    input_values, mask, true_values = sample
    model_output = esm_model(**input_values)

In [25]:
# Dimensions of the mask tensor.
mask.size()

torch.Size([155])

In [26]:
# Total of the mask entries (the number of selected positions if the mask is boolean).
mask.sum()

tensor(7)

In [27]:
# Highest-scoring token id at each masked position of the first (index 0)
# sequence: select the masked rows first, then take the argmax over the last axis.
model_output.logits[0, mask].argmax(dim=-1)

tensor([ 5, 19,  4,  4,  5, 10,  9])

In [28]:
# Show the true-values tensor returned by the generator, for comparison with
# the model's predictions at the masked positions.
true_values

tensor([23, 18,  4,  4,  5, 10, 16])

In [29]:
# Fraction of masked positions where the argmax prediction equals the true
# token id: number of matches divided by the mask total.
(model_output.logits[0, mask].argmax(dim=-1) == true_values).sum() / mask.sum()

tensor(0.5714)

In [30]:
# NOTE(review): this repeats the `true_values` display from an earlier cell
# and shows the same tensor; candidate for removal.
true_values

tensor([23, 18,  4,  4,  5, 10, 16])