In [1]:
import heapq
import json
import os
from types import SimpleNamespace

import esm
import numpy as np
import py3Dmol
import requests
import torch
from Bio.Align import PairwiseAligner
from sqlalchemy import create_engine
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import sessionmaker

from classes import *

In [2]:
# Database connection URL. It can be overridden with the PDB_DATABASE_URL
# environment variable so credentials do not have to live in the notebook;
# the fallback is the original local development URL.
DATABASE_URL = os.environ.get(
    "PDB_DATABASE_URL",
    "mysql+pymysql://root:@localhost/pdb?charset=utf8mb4",
)
engine = create_engine(DATABASE_URL)
Session = sessionmaker(bind=engine)
# Single module-level session shared by every helper defined below.
session = Session()

In [3]:
def get_pdb(id_):
    """Return the ``Protein`` row whose ``id`` equals ``id_``.

    Args:
        id_: Identifier compared against ``Protein.id``.

    Returns:
        The first matching ``Protein`` instance.

    Raises:
        LookupError: If no row matches. ``LookupError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` keep working.
    """
    protein = session.query(Protein).filter(Protein.id == id_).first()
    if protein is None:
        raise LookupError(f"There is no file with id {id_!r}")
    return protein

In [4]:
def get_pdb_data(id_or_pdb):
    """Read and return the text of the structure file belonging to a protein.

    Args:
        id_or_pdb: Either a ``Protein`` instance, used as-is, or an id string,
            which is upper-cased and resolved through ``get_pdb``.

    Returns:
        The full contents of the file located at the protein's ``structure``
        path.
    """
    protein = id_or_pdb if isinstance(id_or_pdb, Protein) else get_pdb(id_or_pdb.upper())
    with open(protein.structure) as handle:
        return handle.read()

In [5]:
def view_pdb(id_or_pdb=None, chain_id=None, data=None):
    """Render a structure as a spectrum-coloured cartoon in a py3Dmol viewer.

    Args:
        id_or_pdb: ``Protein`` instance or id string; only used to load the
            structure text (via ``get_pdb_data``) when ``data`` is not given.
        chain_id: Chain to apply the cartoon style to. When ``None`` the style
            is applied to the whole model.
        data: Structure text to display directly.

    Raises:
        ValueError: If neither ``id_or_pdb`` nor ``data`` is provided.
    """
    if data is None:
        if id_or_pdb is None:
            raise ValueError("Either id_or_pdb or data must be provided.")
        data = get_pdb_data(id_or_pdb)

    view = py3Dmol.view(width=600, height=400)
    view.addModelsAsFrames(data)

    # Only constrain the selection by chain when one was requested. Passing
    # 'chain': None would put a null chain into the selection, which is not
    # expected to match any atom, so nothing would receive the cartoon style.
    selection = {'model': -1}
    if chain_id is not None:
        selection['chain'] = chain_id
    view.setStyle(selection, {"cartoon": {'color': 'spectrum'}})
    view.zoomTo()
    view.show()

In [6]:
def get_chains_to_compare(offset, limit, only_chain_id=None):
    """Return one page of ``Chain`` rows for similarity comparison.

    Args:
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.
        only_chain_id: When not ``None``, restrict the page to rows whose
            ``chain_id`` equals this value.

    Returns:
        A list of ``Chain`` instances (empty once ``offset`` is past the end).
    """
    query = session.query(Chain)
    if only_chain_id is not None:
        query = query.filter(Chain.chain_id == only_chain_id)
    # LIMIT/OFFSET without ORDER BY has no guaranteed row order, so successive
    # pages could overlap or skip rows. Ordering by the mapped primary key
    # makes the paging deterministic.
    query = query.order_by(*sa_inspect(Chain).primary_key)
    return query.limit(limit).offset(offset).all()

In [7]:
def get_most_similars_by_MSA(sequence, only_chain_id=None, top_count=5):
    """Rank stored chains by pairwise-alignment score against ``sequence``.

    Chains are read page by page through ``get_chains_to_compare`` and scored
    with a default-configured ``PairwiseAligner``.

    Args:
        sequence: Query sequence.
        only_chain_id: Optional chain id filter forwarded to
            ``get_chains_to_compare``.
        top_count: Maximum number of results to return.

    Returns:
        A list of ``(score, chain)`` tuples sorted by descending score; chains
        with equal scores keep the order in which they were read.
    """
    aligner = PairwiseAligner()
    page_size = 20

    def scored_chains():
        # Lazily walk every page so only ``top_count`` candidates are retained.
        offset = 0
        while True:
            chains = get_chains_to_compare(offset, page_size, only_chain_id)
            if not chains:
                return
            offset += page_size
            for chain in chains:
                # A chain without a sequence cannot be aligned; skip it rather
                # than letting one bad row abort the whole search.
                if not chain.sequence:
                    continue
                # ``score`` computes only the optimal score and avoids building
                # alignment objects that were never used.
                yield aligner.score(chain.sequence, sequence), chain

    # nlargest is stable (equivalent to sorted(..., reverse=True)[:n]), which
    # matches the tie order of the previous manual insertion.
    return heapq.nlargest(top_count, scored_chains(), key=lambda pair: pair[0])

In [17]:
def evaluate(model, batch_converter, sequence):
    """Run ``model`` on one sequence and return its position-averaged logits.

    Args:
        model: Callable that takes the token batch and returns a mapping with
            a ``"logits"`` entry.
        batch_converter: Callable that turns ``[(label, sequence)]`` into a
            ``(labels, strings, tokens)`` triple.
        sequence: Sequence to evaluate; it is submitted under the label "seq".

    Returns:
        A NumPy array holding the mean over the first axis of the logits of
        the single batch item.
    """
    _, _, tokens = batch_converter([("seq", sequence)])
    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(tokens)
    first_item_logits = output["logits"][0]
    averaged = first_item_logits.mean(dim=0)
    return averaged.cpu().numpy()

In [18]:
def get_most_similars_by_esm(model, batch_converter, sequence, only_chain_id=None, top_count=5):
    """Rank stored chains by distance between their stored vector and the query's.

    The query vector comes from ``evaluate``; each chain's vector is decoded
    from the JSON text in ``chain.esm``. The distance is the mean absolute
    element-wise difference, so smaller means more similar.

    Args:
        model: Model forwarded to ``evaluate``.
        batch_converter: Batch converter forwarded to ``evaluate``.
        sequence: Query sequence.
        only_chain_id: Optional chain id filter forwarded to
            ``get_chains_to_compare``.
        top_count: Maximum number of results to return.

    Returns:
        A list of ``(distance, chain)`` tuples sorted by ascending distance.
    """
    query_vector = evaluate(model, batch_converter, sequence)
    page_size = 20
    start = 0
    best = []

    page = get_chains_to_compare(start, page_size, only_chain_id)
    while len(page):
        start += page_size
        stored_vectors = np.array([json.loads(chain.esm) for chain in page])
        distances = np.abs(stored_vectors - query_vector).mean(axis=1)

        # Only the closest ``top_count`` chains of this page can make the cut.
        for candidate in np.argsort(distances)[:top_count]:
            distance = distances[candidate]
            # Walk left past every kept entry that is strictly farther away.
            position = len(best)
            while position > 0 and distance < best[position - 1][0]:
                position -= 1
            if position < top_count:
                best.insert(position, (distance, page[candidate]))
                del best[top_count:]

        page = get_chains_to_compare(start, page_size, only_chain_id)

    return list(best)

In [36]:
def fold_sequence(sequence, timeout=120):
    """Request a predicted structure for ``sequence`` from the ESM Atlas API.

    Args:
        sequence: Non-empty string made only of the upper-case letters
            accepted below (every letter except ``O`` and ``U``).
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled connection cannot hang the notebook indefinitely.

    Returns:
        The response body as text (the request goes to the endpoint's
        ``/pdb/`` path).

    Raises:
        ValueError: If ``sequence`` is empty or contains a disallowed token.
        RuntimeError: If the service answers with a status other than 200.
    """
    allowed = set("ABCDEFGHIJKLMNPQRSTVWXYZ")
    if not sequence:
        raise ValueError("Sequence must not be empty.")
    invalid = set(sequence) - allowed
    if invalid:
        # Sorted output keeps the message stable and names the offenders.
        raise ValueError(
            f"Sequence is only allowed to have {sorted(allowed)} tokens; "
            f"found invalid tokens {sorted(invalid)}."
        )
    res = requests.post(
        "https://api.esmatlas.com/foldSequence/v1/pdb/",
        data=sequence,
        timeout=timeout,
    )
    if res.status_code != 200:
        raise RuntimeError(
            f"Folding request failed with HTTP {res.status_code}: {res.text}"
        )
    return res.text

In [11]:
def to_base_36(s):
    """Convert a non-negative integer to its upper-case base-36 representation.

    Args:
        s: Non-negative integer to convert.

    Returns:
        The base-36 digits (``0-9`` then ``A-Z``), most significant first;
        ``"0"`` for zero.

    Raises:
        ValueError: If ``s`` is negative. Floor division of a negative number
            converges to -1, never 0, so the conversion loop would not end.
    """
    if s < 0:
        raise ValueError("to_base_36 only supports non-negative integers.")
    digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    collected = []
    while s:
        s, remainder = divmod(s, 36)
        collected.append(digits[remainder])
    # Digits were produced least-significant first.
    return "".join(reversed(collected)) or "0"

In [12]:
def save_generated_pdb(sequence, pdb_content, model=None, batch_converter=None,
                       pdbs_folder="/Users/arturmkrtchyan/Desktop/pdb2/generated"):
    """Persist a generated structure to disk and register it in the database.

    The structure text is written to ``<pdbs_folder>/<pdb_id>.pdb`` and a
    ``Protein`` plus a single chain "A" ``Chain`` row (carrying the sequence
    and its JSON-encoded ``evaluate`` vector) are committed.

    Args:
        sequence: Sequence the structure was generated from.
        pdb_content: Structure text to write to the file.
        model: Model passed to ``evaluate``. When it or ``batch_converter`` is
            ``None``, the ESM-2 model and its batch converter are loaded here.
        batch_converter: Batch converter passed to ``evaluate``.
        pdbs_folder: Directory for generated files; created if missing.

    Returns:
        The generated id: ``"+"`` followed by the base-36 count of entries
        already in ``pdbs_folder``, zero-padded to three characters.
    """
    if model is None or batch_converter is None:
        # Previously this function referenced an undefined global ``model`` and
        # omitted the batch converter; load both when the caller has none.
        model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        model.eval()

    # Compute the embedding before touching the filesystem so a model failure
    # does not leave an orphaned structure file behind.
    tokens = evaluate(model, batch_converter, sequence)
    # Named ``esm_json`` so it does not shadow the ``esm`` module used above.
    esm_json = json.dumps(tokens.tolist())

    os.makedirs(pdbs_folder, exist_ok=True)
    # NOTE(review): the id derives from the number of directory entries, so
    # stray files in the folder or deleted structures affect numbering.
    generated_count = len(os.listdir(pdbs_folder))
    pdb_id = "+" + to_base_36(generated_count).zfill(3)
    pdb_path = os.path.join(pdbs_folder, pdb_id + ".pdb")
    with open(pdb_path, "w") as f:
        f.write(pdb_content)

    protein = Protein(pdb_id, "Generated Protein " + str(generated_count), pdb_path)
    chain = Chain(pdb_id, "A", sequence, esm_json)
    session.add(protein)
    session.add(chain)
    protein.chains.append(chain)
    try:
        session.commit()
    except Exception:
        # Keep the shared session usable for later cells after a failed commit.
        session.rollback()
        raise
    return pdb_id

In [19]:
def get_similars(sequence, only_chain_id=None, save_to_db=True, algorithm="esm", count=5):
    """Fold ``sequence``, display it, and list the most similar stored chains.

    Args:
        sequence: Query sequence.
        only_chain_id: Optional chain id filter forwarded to the ranking
            function.
        save_to_db: When true, the folded structure is stored through
            ``save_generated_pdb``.
        algorithm: ``"esm"`` ranks with ``get_most_similars_by_esm`` (loading
            the ESM-2 model first); any other value ranks with
            ``get_most_similars_by_MSA``.
        count: Number of similar chains to report.

    Returns:
        None. Results are printed and rendered with ``view_pdb``.
    """
    use_esm = algorithm == "esm"
    # Label shown next to each score in the printed report.
    score_title = "Difference" if use_esm else "Similarity score"

    if use_esm:
        model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        batch_converter = alphabet.get_batch_converter()
        model.eval()
        similars = get_most_similars_by_esm(
            model, batch_converter, sequence, only_chain_id, top_count=count
        )
    else:
        similars = get_most_similars_by_MSA(sequence, only_chain_id, top_count=count)

    generated_pdb = fold_sequence(sequence)
    if save_to_db:
        save_generated_pdb(sequence, generated_pdb)

    # Show the freshly folded query structure first.
    print(f"Visual representation of this sequence` {sequence}")
    view_pdb(data=generated_pdb, chain_id="A")

    # Then one text summary plus one viewer per ranked chain.
    print(f"Below you can see the most {count} similar proteins")
    for similarity, similar in similars:
        print(f"{score_title}: {round(float(similarity), 2)}\n Protein: {similar.protein.name}({similar.protein.id})\n Chain Id: {similar.chain_id}\n Sequence of the chain: {similar.sequence}")
        view_pdb(similar.protein, similar.chain_id)

In [55]:
# Example run: rank stored chains against the query with the pairwise-alignment
# path (get_similars uses it for any algorithm value other than "esm"), show the
# top 5, and skip saving the folded structure to the database.
sequence = "DTDADADCDCDC"
get_similars(sequence, algorithm="msa", save_to_db=False, count=5)

Visual representation of this sequence` DTDADADCDCDC


Below you can see the most 5 similar proteins
Similarity score: 12.0
 Protein: stable loop in the crystal structure of the intercalated four-stranded cytosine-rich metazoan telomere(200D)
 Chain Id: A
 Sequence of the chain: DTDADADCDCDC


Similarity score: 12.0
 Protein: stable loop in the crystal structure of the intercalated four-stranded cytosine-rich metazoan telomere(200D)
 Chain Id: B
 Sequence of the chain: DTDADADCDCDC


Similarity score: 10.0
 Protein: structure of a dna in low salt conditions d(gaccgcggtc)(401D)
 Chain Id: A
 Sequence of the chain: DGDADCDCDGDCDGDGDTDC


Similarity score: 10.0
 Protein: the unusual structure of the human centromere (gga)2 motif: unpaired guanosine residues stacked between sheared g(dot)a pairs(103D)
 Chain Id: A
 Sequence of the chain: DGDTDGDGDADADTDGDGDADADC


Similarity score: 10.0
 Protein: the unusual structure of the human centromere (gga)2 motif: unpaired guanosine residues stacked between sheared g(dot)a pairs(103D)
 Chain Id: B
 Sequence of the chain: DGDTDGDGDADADTDGDGDADADC


In [None]:
import unittest
from unittest.mock import patch

class TestGetSimilars(unittest.TestCase):
    """Unit tests for ``get_similars`` with every external collaborator mocked.

    ``patch`` requires a dotted ``module.attribute`` target, so each target is
    built from ``__name__`` (``"__main__"`` when run inside the notebook).
    """

    def setUp(self):
        self.sequence = "MKLTIYPDELVQIVSDKIASNIVAAQCVQGFNADPTVEIKLVEEVNNGDPEILQVSPVAKKVGVIGLNFGRQKASIIICLPDDDQIETIYQA"
        self.only_chain_id = None
        self.save_to_db = False
        self.algorithm = "esm"
        self.count = 5

    @staticmethod
    def _fake_chain(protein_name, protein_id, chain_id, sequence):
        """Build a stand-in exposing the attributes ``get_similars`` reads."""
        # get_similars uses attribute access (similar.protein.name, ...), so a
        # plain dict would raise AttributeError.
        return SimpleNamespace(
            protein=SimpleNamespace(name=protein_name, id=protein_id),
            chain_id=chain_id,
            sequence=sequence,
        )

    # Decorators apply bottom-up: the lowest patch is the first mock argument.
    @patch(f"{__name__}.esm")
    @patch(f"{__name__}.fold_sequence")
    @patch(f"{__name__}.save_generated_pdb")
    @patch(f"{__name__}.view_pdb")
    @patch(f"{__name__}.get_most_similars_by_esm")
    def test_get_similars_with_esm_algorithm(self, mock_similars, mock_view_pdb, mock_save_pdb, mock_fold_sequence, mock_esm):
        # Replace the model loader so the test never loads real weights.
        fake_model = unittest.mock.MagicMock(name="model")
        fake_alphabet = unittest.mock.MagicMock(name="alphabet")
        mock_esm.pretrained.esm2_t33_650M_UR50D.return_value = (fake_model, fake_alphabet)

        chain = self._fake_chain("Protein A", "ABC123", "A", self.sequence)
        mock_similars.return_value = [(0.8, chain)]
        mock_fold_sequence.return_value = "Generated PDB"

        get_similars(self.sequence, self.only_chain_id, self.save_to_db, self.algorithm, self.count)

        # The ESM path forwards the loaded model and its batch converter.
        mock_similars.assert_called_once_with(
            fake_model,
            fake_alphabet.get_batch_converter.return_value,
            self.sequence,
            self.only_chain_id,
            top_count=self.count,
        )
        mock_fold_sequence.assert_called_once_with(self.sequence)
        # view_pdb runs once for the generated structure and once per result.
        mock_view_pdb.assert_any_call(data="Generated PDB", chain_id="A")
        mock_view_pdb.assert_any_call(chain.protein, chain.chain_id)
        self.assertEqual(mock_view_pdb.call_count, 2)
        mock_save_pdb.assert_not_called()

    @patch(f"{__name__}.fold_sequence")
    @patch(f"{__name__}.view_pdb")
    @patch(f"{__name__}.get_most_similars_by_MSA")
    def test_get_similars_with_msa_algorithm(self, mock_similars, mock_view_pdb, mock_fold_sequence):
        chain = self._fake_chain("Protein B", "XYZ789", "B", self.sequence)
        mock_similars.return_value = [(0.7, chain)]
        mock_fold_sequence.return_value = "Generated PDB"

        # Select the alignment path explicitly; setUp's default is "esm".
        get_similars(self.sequence, self.only_chain_id, self.save_to_db, "msa", self.count)

        mock_similars.assert_called_once_with(
            self.sequence, self.only_chain_id, top_count=self.count
        )
        mock_fold_sequence.assert_called_once_with(self.sequence)
        mock_view_pdb.assert_any_call(data="Generated PDB", chain_id="A")
        mock_view_pdb.assert_any_call(chain.protein, chain.chain_id)
        self.assertEqual(mock_view_pdb.call_count, 2)

# Inside a Jupyter kernel sys.argv holds the kernel's own arguments and the
# process must stay alive, so pass an explicit argv and disable sys.exit().
if __name__ == '__main__':
    unittest.main(argv=["first-arg-is-ignored"], exit=False)