In [3]:
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
import os
from Bio.PDB import PDBParser
from tqdm import tqdm
import torch
import esm
import json
from classes import *

In [4]:
import warnings
warnings.filterwarnings("ignore",message=".*?Chain .*? is discontinuous.*?")

In [5]:
def parse_pdb(id_, path):
    structure = PDBParser().get_structure(id_, path)
    pdb = {
        "id": id_,
        "name": structure.header["name"],
        "chains": {chain.id:''.join(residue.resname for residue in chain if residue.get_id()[0]==" ") for chain in structure.get_chains()}
    }
    return pdb

In [6]:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval();  # disables dropout for deterministic results

In [6]:
pdbs_folder = "/Users/arturmkrtchyan/Desktop/pdb2"

In [7]:
folders = sorted(filter(lambda x: not x.startswith("."), os.listdir(pdbs_folder)))
for folder in folders:
    for file in os.listdir(f"{pdbs_folder}/{folder}"):
        if file.endswith(".ent"):
            id_ = file[3:7].upper()
            os.rename(f"{pdbs_folder}/{folder}/{file}",f"{pdbs_folder}/{folder}/{id_}.pdb")

In [8]:
engine = create_engine("mysql+pymysql://root:@localhost/pdb?charset=utf8mb4")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()

In [10]:
for folder in tqdm(folders):
    files_list = filter(lambda x: not x.startswith("."), os.listdir(f"{pdbs_folder}/{folder}"))
    chains_to_evaluate = []
    for file in files_list:
        path = f"{pdbs_folder}/{folder}/{file}"
        id_ = file[:4]
        data = parse_pdb(id_, path)
        protein = session.query(Protein).filter(Protein.id==id_).first()
        if protein is None:
            protein = Protein(id_, data["name"], path)
            session.add(protein)
        chain_ids_in_db = {i.chain_id for i in protein.chains}
        for chain_id, sequence in data["chains"].items():
            if chain_id not in chain_ids_in_db:
                chain = Chain(protein.id, chain_id, sequence)
                session.add(chain)
                protein.chains.append(chain)
                chains_to_evaluate.append(chain)
    if len(chains_to_evaluate):
        data = [(i.id, i.sequence) for i in chains_to_evaluate]
        batch_labels, batch_strs, batch_tokens = batch_converter(data)
        with torch.no_grad():
            results = model(batch_tokens)
        for tokens, chain in zip(results["logits"], chains_to_evaluate):
            avg_tokens = tokens.mean(axis=0)
            chain.esm = json.dumps(avg_tokens.cpu().numpy().tolist())
    session.commit()

100%|██████████| 4/4 [00:00<00:00,  4.22it/s]


In [8]:
parse_pdb("+000","/Users/arturmkrtchyan/Desktop/pdb2/generated/+000.pdb")

{'id': '+000',
 'name': 'esmfold v1 prediction for input',
 'chains': {'A': 'METLYSTHRVALARGGLNGLUARGLEULYSSERILEVALARGILELEUGLUARGSERLYSGLUPROVALSERGLYALAGLNLEUALAGLUGLULEUSERVALSERARGGLNVALILEVALGLNASPILEALATYRLEUARGSERLEUGLYTYRASNILEVALALATHRPROARGGLYTYRVALLEUALAGLYGLY'}}