In [None]:
# PLAPT: protein–ligand binding affinity prediction from ProtBERT + ChemBERTa embeddings

In [None]:
import torch
from transformers import BertTokenizer, BertModel, RobertaTokenizer, RobertaModel
import re
import onnxruntime
import numpy as np
from typing import List, Dict, Union
from diskcache import Cache
from tqdm import tqdm
from contextlib import contextmanager, nullcontext

In [None]:
class PredictionModule:
    """ONNX affinity head that turns feature rows into binding-affinity values.

    The ONNX model outputs a normalized score per row; ``convert_to_affinity``
    undoes that normalization with ``mean``/``scale`` to obtain
    -log10(affinity in M) and the same affinity expressed in micromolar.
    """

    # De-normalization constants for the bundled model.
    # NOTE(review): presumably the label mean/std used when the model was
    # trained — confirm before reusing them with a different model file.
    DEFAULT_MEAN = 6.51286529169358
    DEFAULT_SCALE = 1.5614094578916633

    def __init__(self, model_path: str = "models/affinity_predictor.onnx",
                 mean: float = DEFAULT_MEAN, scale: float = DEFAULT_SCALE):
        """Open the ONNX model and store the de-normalization constants.

        Args:
            model_path: path to the ONNX affinity model.
            mean: offset added after scaling a normalized model output.
            scale: factor applied to a normalized model output.
        """
        self.session = onnxruntime.InferenceSession(model_path)
        # Name of the model's first input; features are fed under this key.
        self.input_name = self.session.get_inputs()[0].name
        self.mean = mean
        self.scale = scale

    def convert_to_affinity(self, normalized: float) -> Dict[str, float]:
        """Convert one normalized model output into affinity values.

        Returns:
            ``{"neg_log10_affinity_M": p, "affinity_uM": a}`` where ``p`` is
            ``normalized * scale + mean`` and ``a`` is ``10**-p`` molar
            expressed in micromolar.
        """
        neg_log10_affinity_M = float((normalized * self.scale) + self.mean)
        # 10**-p is the affinity in molar; 1e6 converts molar -> micromolar.
        affinity_uM = float((10**6) * (10**(-neg_log10_affinity_M)))
        return {
            "neg_log10_affinity_M": neg_log10_affinity_M,
            "affinity_uM": affinity_uM
        }

    def predict(self, batch_data: np.ndarray) -> List[Dict[str, float]]:
        """Score every feature row of ``batch_data``.

        Args:
            batch_data: 2-D array of shape ``(n_rows, n_features)``. A 1-D
                array is treated as a single row.

        Returns:
            One ``convert_to_affinity`` dict per row, in input order; ``[]``
            for an empty input.
        """
        batch = np.asarray(batch_data)
        if batch.size == 0:
            return []
        if batch.ndim == 1:
            # A lone feature vector: score it as a batch of one instead of
            # iterating over its scalar elements.
            batch = batch[np.newaxis, :]

        training_mode = np.array(False)
        affinities = []
        for feature in batch:
            # One row per run. NOTE(review): presumably the exported graph has
            # a fixed batch dimension of 1 — confirm before batching rows.
            outputs = self.session.run(
                None,
                {self.input_name: feature[np.newaxis, :], 'TrainingMode': training_mode},
            )
            affinities.append(self.convert_to_affinity(outputs[0][0][0]))
        return affinities

class Plapt:
    """Protein-ligand binding affinity predictor.

    Proteins are embedded with ProtBERT and molecules (SMILES) with ChemBERTa.
    The pooled embeddings of each protein/molecule pair are concatenated and
    scored by ``PredictionModule``. Embeddings are memoised in a ``diskcache``
    store, so a sequence or SMILES seen in an earlier batch or call is not
    re-encoded.
    """

    # Proteins and molecules share one cache. An all-uppercase SMILES such as
    # "CC" is also a valid amino-acid string, so keys are namespaced; otherwise a
    # protein embedding could be returned for a molecule (or vice versa).
    _PROT_KEY_PREFIX = "prot:"
    _MOL_KEY_PREFIX = "mol:"

    class _NullProgressBar:
        """No-op stand-in for a tqdm bar when progress display is disabled."""

        def update(self, n: int = 1) -> None:
            pass

    def __init__(self, prediction_module_path: str = "models/affinity_predictor.onnx", device: str = 'cuda', cache_dir: str = './embedding_cache', use_tqdm: bool = False):
        """Load both encoders, the ONNX affinity head and the embedding cache.

        Args:
            prediction_module_path: path to the ONNX affinity model.
            device: requested torch device; CPU is used whenever CUDA is unavailable.
            cache_dir: directory backing the on-disk embedding cache.
            use_tqdm: show tqdm progress bars while encoding / predicting.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.use_tqdm = use_tqdm

        self.prot_tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
        self.prot_encoder = BertModel.from_pretrained("Rostlab/prot_bert").to(self.device)

        self.mol_tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
        self.mol_encoder = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(self.device)

        self.prediction_module = PredictionModule(prediction_module_path)
        self.cache = Cache(cache_dir)

    @contextmanager
    def progress_bar(self, total: int, desc: str):
        """Yield a tqdm bar, or an object whose ``update`` is a no-op when disabled."""
        if self.use_tqdm:
            with tqdm(total=total, desc=desc) as pbar:
                yield pbar
        else:
            yield self._NullProgressBar()

    @staticmethod
    def preprocess_sequence(seq: str) -> str:
        """Map residues U/Z/O/B to X and space-separate residues for the protein tokenizer."""
        return " ".join(re.sub(r"[UZOB]", "X", seq))

    def tokenize_molecule(self, mol_smiles: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """Tokenize SMILES into padded PyTorch tensors (a tokenizer BatchEncoding)."""
        return self.mol_tokenizer(mol_smiles, padding=True, max_length=278, truncation=True, return_tensors='pt')

    def tokenize_protein(self, prot_seq: Union[str, List[str]]) -> Dict[str, torch.Tensor]:
        """Preprocess and tokenize one sequence or a list of sequences.

        List elements that are not strings are treated as sequences of strings,
        and only their first element is used.
        """
        if isinstance(prot_seq, str):
            # A bare string is one sequence, not an iterable of one-letter sequences.
            prot_seq = [prot_seq]
        preprocessed = [self.preprocess_sequence(seq if isinstance(seq, str) else seq[0]) for seq in prot_seq]
        return self.prot_tokenizer(preprocessed, padding=True, max_length=3200, truncation=True, return_tensors='pt')

    def _encode_cached(self, items: List, batch_size: int, key_prefix: str, tokenize, encoder, desc: str) -> torch.Tensor:
        """Encode ``items`` in batches, reusing and filling the embedding cache.

        ``tokenize`` maps a list of items to encoder kwargs, and ``encoder`` is a
        model whose output has ``pooler_output``. Returns the stacked pooled
        embeddings (one row per item, in input order) on ``self.device``.
        Raises RuntimeError (from ``torch.stack``) when ``items`` is empty.
        """
        embeddings = []
        with self.progress_bar(len(items), desc) as pbar:
            for batch in self.make_batches(items, batch_size):
                keys = [f"{key_prefix}{item}" for item in batch]
                batch_embeddings = [self.cache.get(key) for key in keys]
                missing = [i for i, emb in enumerate(batch_embeddings) if emb is None]

                if missing:
                    tokens = tokenize([batch[i] for i in missing])
                    with torch.no_grad():
                        fresh = encoder(**tokens.to(self.device)).pooler_output.cpu()
                    for i, emb in zip(missing, fresh):
                        # clone(): each row is a view of the whole batch tensor, and
                        # pickling a view would store the full batch in every entry.
                        emb = emb.clone()
                        batch_embeddings[i] = emb
                        self.cache[keys[i]] = emb

                embeddings.extend(batch_embeddings)
                pbar.update(len(batch))

        return torch.stack(embeddings).to(self.device)

    def encode_molecules(self, mol_smiles: List[str], batch_size: int) -> torch.Tensor:
        """Return pooled molecule embeddings (one row per SMILES) on ``self.device``."""
        return self._encode_cached(mol_smiles, batch_size, self._MOL_KEY_PREFIX,
                                   self.tokenize_molecule, self.mol_encoder, "Encoding molecules")

    def encode_proteins(self, prot_seqs: List[str], batch_size: int) -> torch.Tensor:
        """Return pooled protein embeddings (one row per sequence) on ``self.device``."""
        return self._encode_cached(prot_seqs, batch_size, self._PROT_KEY_PREFIX,
                                   self.tokenize_protein, self.prot_encoder, "Encoding proteins")

    @staticmethod
    def make_batches(iterable: List, n: int = 1):
        """Yield consecutive slices of ``iterable`` of length ``n`` (the last may be shorter)."""
        length = len(iterable)
        for ndx in range(0, length, n):
            yield iterable[ndx:min(ndx + n, length)]

    def _score_pairs(self, prot_encodings: torch.Tensor, mol_encodings: torch.Tensor, batch_size: int, desc: str) -> List[Dict[str, float]]:
        """Score row-aligned protein/molecule embeddings in chunks of ``batch_size``."""
        total = mol_encodings.shape[0]
        affinities = []
        with self.progress_bar(total, desc) as pbar:
            for start in range(0, total, batch_size):
                stop = min(start + batch_size, total)
                features = torch.cat((prot_encodings[start:stop], mol_encodings[start:stop]), dim=1).cpu().numpy()
                affinities.extend(self.prediction_module.predict(features))
                pbar.update(stop - start)
        return affinities

    def predict_affinity(self, prot_seqs: List[str], mol_smiles: List[str], prot_batch_size: int = 2, mol_batch_size: int = 16, affinity_batch_size: int = 128) -> List[Dict[str, float]]:
        """Predict the affinity of each (prot_seqs[i], mol_smiles[i]) pair.

        Raises:
            ValueError: if the two lists differ in length.

        Returns:
            One affinity dict per pair, in input order; ``[]`` for empty input.
        """
        if len(prot_seqs) != len(mol_smiles):
            raise ValueError("The number of proteins and molecules must be the same.")
        if len(prot_seqs) == 0:
            return []

        prot_encodings = self.encode_proteins(prot_seqs, prot_batch_size)
        mol_encodings = self.encode_molecules(mol_smiles, mol_batch_size)
        return self._score_pairs(prot_encodings, mol_encodings, affinity_batch_size, "Predicting affinities")

    def score_candidates(self, target_protein: str, mol_smiles: List[str], mol_batch_size: int = 16, affinity_batch_size: int = 128) -> List[Dict[str, float]]:
        """Predict the affinity of every SMILES in ``mol_smiles`` against one target.

        Returns:
            One affinity dict per molecule, in input order; ``[]`` for an empty list.
        """
        if len(mol_smiles) == 0:
            return []

        target_encoding = self.encode_proteins([target_protein], batch_size=1)
        mol_encodings = self.encode_molecules(mol_smiles, mol_batch_size)
        # expand() broadcasts the single target row without copying it.
        repeated_target = target_encoding.expand(mol_encodings.shape[0], -1)
        return self._score_pairs(repeated_target, mol_encodings, affinity_batch_size, "Scoring candidates")

if __name__ == "__main__":
    # Demo: score a handful of candidate ligands against one target sequence.
    plapt = Plapt()
    target_protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
    # NOTE(review): entry 2 has unbalanced parentheses and looks like a truncated
    # copy of the last entry; entries 1/8 and 4/9 are duplicates. Confirm the
    # intended candidate set.
    candidate_molecules = [
        'CC#C[C@]1(O)CC[C@H]2[C@@H]3CC=C4C(=O)CC[C@]4(C)[C@H]3CC[C@]21C',
        'CN1C=C(C=N1)S(=O)(=O)N2CCC3=CC4=C(C[C@@]3(C2)C(=O)C5=NC=CC(=C5)C(F)',
        'C1CC(CCC1C2=CC=CC=C2)C3=C(C(=O)NC(=O)N3)CC4=CC(=CC=C4)C(F)(F)F',
        'CC(C)C1=CC=C(C=C1)C2=CC(=O)C3=C(C2=O)C4=CC=CC=C4C5=CC=CC=C35',
        'CC#C[C@@]1(CC[C@@H]2[C@@]1(C[C@@H](C3=C4CCC(=O)C=C4CC[C@@H]23)C5=CC=C(C=C5)N(C)C(C)C)C)O',
        'CC(C)[C@@H]([C@@H](C1=CC=CC=C1)OC2=CC3=C(C=C2)N(N=C3)C4=CN(C(=O)C=C4)C)NC(=O)C(C)(F)F',
        # Raw string: the backslashes are SMILES bond markers, not Python escapes
        # ("\C" / "\c" are invalid escapes and warn on Python 3.12+).
        r'O=C(OCC(=O)[C@@]4(O)[C@H](C)C[C@H]5[C@@H]6/C=C(\C3=C\c1c(cnn1c2ccccc2)C[C@@]3([C@H]6[C@@H](O)C[C@]45C)C)C)C',
        'CC#C[C@]1(O)CC[C@H]2[C@@H]3CC=C4C(=O)CC[C@]4(C)[C@H]3CC[C@]21C',
        'CC(C)C1=CC=C(C=C1)C2=CC(=O)C3=C(C2=O)C4=CC=CC=C4C5=CC=CC=C35',
        'CN1C=C(C=N1)S(=O)(=O)N2CCC3=CC4=C(C[C@@]3(C2)C(=O)C5=NC=CC(=C5)C(F)(F)F)C=NN4C6=CC=C(C=C6)F',
    ]
    scores = plapt.score_candidates(target_protein, candidate_molecules, mol_batch_size=16, affinity_batch_size=128)

    # One line per candidate instead of dumping the raw list of dicts.
    print("\nScore Candidates Results:")
    for smiles, score in zip(candidate_molecules, scores):
        print(f"  {score['neg_log10_affinity_M']:7.3f} -log10(M)  {score['affinity_uM']:12.4g} uM  {smiles}")