In [46]:
import torch
import processing.tokenization
import processing.parse_pdb
import pandas as pd
import matplotlib.pyplot as plt
import transformers
from tqdm import tqdm

In [47]:
# Load the 'protein' entries from the PDB seqres dump, order them by Name, and
# pass the sorted frame through the context-size filter in a single expression.
# NOTE(review): context_size_filter presumably drops/windows sequences outside
# the 100-4000 residue range for a 1024-token context -- confirm in parse_pdb.
data = processing.parse_pdb.context_size_filter(
    processing.parse_pdb.load_pdb(
        ['protein'],
        data_loc='./data/pdb_seqres.txt',
    ).sort_values(by='Name'),
    context_length=1024,
    min_length=100,
    max_length=4000,
    window=True,
)
# Bare last expression so the notebook renders the resulting frame.
data

Unnamed: 0,Name,Sequence,MolType,length
0,,GPLGSMALRACGLIIFRRCLIPKVDNNAIEFLLLQASDGIHHWTPP...,protein,153.0
1,,GPLGSMALRACGLIIFRRCLIPKVDNNAIEFLLLQASDGIHHWTPP...,protein,153.0
2,,MLVVPAIDLFRGKVARMIKGRKENTIFYEKDPVELVEKLIEEGFTL...,protein,241.0
3,,GPLGSMALRACGLIIFRRCLIPKVDNNAIEFLLLQASDGIHHWTPP...,protein,153.0
4,,MSFRFGQCLIKPSVVFLKTELSFALVNRKPVVPGHVLVCPLRPVER...,protein,163.0
...,...,...,...,...
963714,zinc peptidase active subunit,ADPAASTFETTLPNGLKVVVREDHRAPTLVHMVWYRVGSMDETTGT...,protein,437.0
963715,zinc peptidase inactive subunit,AIKIEHWTAPSGAQVYYVENRTLPMLDVQVDFDAGSAREPADQVGV...,protein,424.0
963716,zinc peptidase inactive subunit,AIKIEHWTAPSGAQVYYVENRTLPMLDVQVDFDAGSAREPADQVGV...,protein,424.0
963717,zinc-dependent metalloproteinase,QQRFPQRYVMLAIVADHGMVTKYSGNSSAITTRVHQMVSHVTEMYS...,protein,202.0


In [48]:
# Tokenizer built from the project's own vocabulary file rather than a hub
# checkpoint.
tokenizer = transformers.EsmTokenizer(
    vocab_file='./processing/basic_tokens.txt',
    return_tensors='pt',
)

# Masked-language-model head of ESM-2 loaded from the Hugging Face hub.
# Alternative checkpoint: 'facebook/esm2_t33_650M_UR50D'
esm_model = transformers.EsmForMaskedLM.from_pretrained(
    'facebook/esm2_t48_15B_UR50D'
)

config.json:   0%|          | 0.00/725 [00:00<?, ?B/s]

pytorch_model.bin.index.json:   0%|          | 0.00/73.7k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/7 [00:00<?, ?it/s]

pytorch_model-00001-of-00007.bin:   0%|          | 0.00/9.67G [00:00<?, ?B/s]

pytorch_model-00002-of-00007.bin:   0%|          | 0.00/9.65G [00:00<?, ?B/s]

pytorch_model-00003-of-00007.bin:   0%|          | 0.00/9.96G [00:00<?, ?B/s]

In [49]:
# Stream of masked examples built from `data`; the evaluation loop unpacks each
# item as (input_values, mask, true_values).
# NOTE(review): mask_probability is presumably the per-token masking rate --
# confirm in processing.parse_pdb.
data_generator = processing.parse_pdb.masked_dataset_generator(
    data,
    tokenizer=tokenizer,
    mask_token_id=tokenizer.mask_token_id,
    mask_probability=0.05,
)

In [50]:
import os

# Masked-token accuracy evaluation with a resumable, append-only checkpoint log.
#
# Every scored protein appends one row `idx,total_correct,num_masks` to
# `file_loc`, where the two counters are cumulative. If the log already exists,
# the run resumes after the last complete row and restores both counters from
# it, so rows written before and after a restart stay consistent.
#
# NOTE(review): if `data_generator` is a one-shot generator, re-create it (re-run
# the cell that builds it) before re-running this cell; otherwise `enumerate`
# restarts at 0 while the generator is already partially consumed.
file_loc = './bigger_run.txt'

total_correct = 0   # cumulative count of correctly predicted masked tokens
num_masks = 0       # cumulative count of masked tokens scored
last_protein = 0    # number of leading proteins already scored in earlier runs
needs_newline = False  # True when the log ends in a truncated (unterminated) row

if os.path.exists(file_loc):
    # Read the resume state from the SAME file we append to.
    with open(file_loc, 'r') as file:
        for line in file:
            if not line.endswith('\n'):
                # Row was cut off mid-write (e.g. kernel killed); do not trust it
                # and make sure the next appended row starts on its own line.
                needs_newline = True
                continue
            fields = line.strip().split(',')
            if len(fields) != 3:
                continue  # blank or malformed row
            try:
                done_idx, done_correct, done_masks = (int(field) for field in fields)
            except ValueError:
                continue  # non-numeric row; ignore
            last_protein = done_idx + 1
            total_correct = done_correct
            num_masks = done_masks

try:
    # The `with` block closes (and therefore flushes) the log on normal
    # completion, on KeyboardInterrupt, and on any other exception.
    with torch.no_grad(), open(file_loc, 'a') as current_output:
        if needs_newline:
            current_output.write('\n')
        for idx, (input_values, mask, true_values) in enumerate(tqdm(data_generator, total=len(data))):
            if idx < last_protein:
                # Already scored in a previous run; skip without calling the model.
                continue
            model_output = esm_model(**input_values)
            # Most likely token at every position, restricted to the masked
            # positions of the first (only) sequence in the batch.
            predictions = model_output.logits.argmax(dim=-1)[0, mask]
            total_correct += torch.sum(predictions == true_values).item()
            num_masks += torch.sum(mask).item()
            current_output.write(f'{idx},{total_correct},{num_masks}\n')
except KeyboardInterrupt:
    # Interrupting is the intended way to pause the run; the file has already
    # been closed by the `with` block by the time we get here.
    print("Flushed and closed file!")

  0%|          | 1070/963719 [08:42<130:37:08,  2.05it/s]

Flushed and closed file!



