Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions chebifier/prediction_models/c3p_predictor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from pathlib import Path
from typing import List, Optional

import tqdm

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor

Expand All @@ -26,14 +28,22 @@ def __init__(
def predict_smiles_list(self, smiles_list: list[str]) -> list:
from c3p import classifier as c3p_classifier

result_list = c3p_classifier.classify(
list(smiles_list),
self.program_directory,
self.chemical_classes,
strict=False,
)
result_list = []
for batch_start in tqdm.tqdm(
range(0, len(smiles_list), 32), desc="Classifying with C3P"
):
batch_end = min(batch_start + 32, len(smiles_list))
result_list.extend(
c3p_classifier.classify(
smiles_list[batch_start:batch_end],
self.program_directory,
self.chemical_classes,
strict=False,
)
)

result_reformatted = [dict() for _ in range(len(smiles_list))]
for result in result_list:
for result in tqdm.tqdm(result_list, desc="Reformatting C3P results"):
chebi_id = result.class_id.split(":")[1]
result_reformatted[smiles_list.index(result.input_smiles)][
chebi_id
Expand Down Expand Up @@ -61,13 +71,13 @@ def explain_smiles(self, smiles):
highlights.append(
(
"text",
f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}",
f"For {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}",
)
)
highlights = [
(
"text",
f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:",
f"C3P made positive predictions for {len(highlights)} classes. {'The explanations are as follows:' if len(highlights) > 0 else ''}",
)
] + highlights

Expand Down
8 changes: 4 additions & 4 deletions chebifier/prediction_models/chebi_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from chebifier import modelwise_smiles_lru_cache
from chebifier.prediction_models import BasePredictor
from chebifier.utils import load_chebi_graph
from chebifier.utils import _smiles_to_mol, load_chebi_graph


class ChEBILookupPredictor(BasePredictor):
Expand Down Expand Up @@ -50,7 +50,7 @@ def build_smiles_lookup(self):
).items():
if smiles is not None:
try:
mol = Chem.MolFromSmiles(smiles)
mol = _smiles_to_mol(smiles)
if mol is None:
print(
f"Failed to parse SMILES {smiles} for ChEBI ID {chebi_id}"
Expand All @@ -72,7 +72,7 @@ def build_smiles_lookup(self):
def predict_smiles(self, smiles: str) -> Optional[dict]:
if not smiles:
return None
mol = Chem.MolFromSmiles(smiles)
mol = _smiles_to_mol(smiles)
if mol is None:
return None
canonical_smiles = Chem.MolToSmiles(mol)
Expand Down Expand Up @@ -110,7 +110,7 @@ def info_text(self):
return self._description

def explain_smiles(self, smiles: str) -> dict:
mol = Chem.MolFromSmiles(smiles)
mol = _smiles_to_mol(smiles)
if mol is None:
return {
"highlights": [
Expand Down
14 changes: 14 additions & 0 deletions chebifier/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import importlib.resources
import os
import pickle
Expand All @@ -6,6 +7,7 @@
import networkx as nx
import requests
import yaml
from rdkit import Chem

from chebifier.hugging_face import download_model_files

Expand Down Expand Up @@ -156,6 +158,18 @@ def process_config(config, model_registry):
return new_config


@functools.lru_cache(maxsize=128)
def _smiles_to_mol(smiles: str):
    """Parse a SMILES string into an RDKit molecule and try to kekulize it.

    The string is parsed with ``sanitize=False``; afterwards kekulization is
    attempted in place so that aromatic bond types become explicit
    single/double bonds. A kekulization failure is reported via ``print``
    and the (non-kekulized) molecule is still returned.

    Results are memoized per SMILES string (LRU cache, 128 entries), so
    repeated calls with the same string return the same object.
    NOTE(review): because the cached object is shared, callers presumably
    should not mutate the returned molecule -- confirm against call sites.

    Args:
        smiles: The SMILES string to parse.

    Returns:
        Whatever ``Chem.MolFromSmiles`` produced for ``smiles``; ``None`` is
        passed through unchanged when parsing yields no molecule.
    """
    parsed = Chem.MolFromSmiles(smiles, sanitize=False)
    if parsed is None:
        # Nothing to kekulize; hand the parse failure back to the caller.
        return None

    try:
        # Rewrite aromatic bond types as alternating single/double bonds.
        Chem.Kekulize(parsed)
    except Chem.KekulizeException as err:
        # Best effort: keep the molecule as parsed and only report the issue.
        print(f"Failed to Kekulize {smiles}: {err}")
    return parsed


if __name__ == "__main__":
chebi_graph = build_chebi_graph(chebi_version=244)
os.makedirs(os.path.join("data", "chebi_v244"), exist_ok=True)
Expand Down