Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions chebifier/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
ChEBILookupPredictor,
ChemlogPeptidesPredictor,
ElectraPredictor,
ResGatedPredictor,
GNNPredictor,
)
from chebifier.prediction_models.c3p_predictor import C3PPredictor
from chebifier.prediction_models.chemlog_predictor import (
Expand All @@ -26,7 +26,7 @@

MODEL_TYPES = {
"electra": ElectraPredictor,
"resgated": ResGatedPredictor,
"resgated": GNNPredictor,
"gat": GATPredictor,
"chemlog": ChemlogAllPredictor,
"chemlog_peptides": ChemlogPeptidesPredictor,
Expand Down
4 changes: 2 additions & 2 deletions chebifier/prediction_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from .chebi_lookup import ChEBILookupPredictor
from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor
from .electra_predictor import ElectraPredictor
from .gnn_predictor import ResGatedPredictor
from .gnn_predictor import GNNPredictor

__all__ = [
"BasePredictor",
"ChemlogPeptidesPredictor",
"ElectraPredictor",
"ResGatedPredictor",
"GNNPredictor",
"ChEBILookupPredictor",
"ChemlogExtraPredictor",
"C3PPredictor",
Expand Down
30 changes: 5 additions & 25 deletions chebifier/prediction_models/electra_predictor.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
from typing import TYPE_CHECKING

import numpy as np

from .nn_predictor import NNPredictor

if TYPE_CHECKING:
from chebai.models.electra import Electra


def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):
n_nodes = len(node_labels)
Expand Down Expand Up @@ -40,37 +35,22 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0):

class ElectraPredictor(NNPredictor):
def __init__(self, model_name: str, ckpt_path: str, **kwargs):
from chebai.preprocessing.reader import ChemDataReader

super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
super().__init__(model_name, ckpt_path, **kwargs)
print(f"Initialised Electra model {self.model_name} (device: {self.device})")

def init_model(self, ckpt_path: str, **kwargs) -> "Electra":
from chebai.models.electra import Electra

model = Electra.load_from_checkpoint(
ckpt_path,
map_location=self.device,
criterion=None,
strict=False,
metrics=dict(train=dict(), test=dict(), validation=dict()),
pretrained_checkpoint=None,
)
model.eval()
return model

def explain_smiles(self, smiles) -> dict:
from chebai.preprocessing.reader import EMBEDDING_OFFSET

reader = self.reader_cls()
token_dict = reader.to_data(dict(features=smiles, labels=None))
token_dict = self._predictor._dm.reader.to_data(
dict(features=smiles, labels=None)
)
tokens = np.array(token_dict["features"]).astype(int).tolist()
result = self.calculate_results([token_dict])

token_labels = (
["[CLR]"]
+ [None for _ in range(EMBEDDING_OFFSET - 1)]
+ list(reader.cache.keys())
+ list(self._predictor._dm.reader.cache.keys())
)

graphs = [
Expand Down
112 changes: 2 additions & 110 deletions chebifier/prediction_models/gnn_predictor.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,12 @@
from typing import TYPE_CHECKING, Optional

import torch

from .nn_predictor import NNPredictor

if TYPE_CHECKING:
from chebai_graph.models.gat import GATGraphPred
from chebai_graph.models.resgated import ResGatedGraphPred


class ResGatedPredictor(NNPredictor):
class GNNPredictor(NNPredictor):
def __init__(
self,
model_name: str,
ckpt_path: str,
molecular_properties,
dataset_cls: Optional[str] = None,
**kwargs,
):
from chebai_graph.preprocessing.datasets.chebi import (
ChEBI50GraphProperties,
GraphPropertiesMixIn,
)
from chebai_graph.preprocessing.properties import MolecularProperty

# molecular_properties is a list of class paths
if molecular_properties is not None:
properties = [self.load_class(prop)() for prop in molecular_properties]
properties = sorted(
properties, key=lambda prop: f"{prop.name}_{prop.encoder.name}"
)
else:
properties = []
for property in properties:
property.encoder.eval = True
self.molecular_properties = properties
assert isinstance(self.molecular_properties, list) and all(
isinstance(prop, MolecularProperty) for prop in self.molecular_properties
)
# TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading
self.dataset_cls = (
self.load_class(dataset_cls)
if dataset_cls is not None
else ChEBI50GraphProperties
)
self.dataset: Optional[GraphPropertiesMixIn] = self.dataset_cls(
properties=molecular_properties
)

super().__init__(
model_name, ckpt_path, reader_cls=self.dataset.READER, **kwargs
)

super().__init__(model_name, ckpt_path, **kwargs)
print(f"Initialised GNN model {self.model_name} (device: {self.device})")

def load_class(self, class_path: str):
module_path, class_name = class_path.rsplit(".", 1)
module = __import__(module_path, fromlist=[class_name])
return getattr(module, class_name)

def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred":
import torch
from chebai_graph.models.resgated import ResGatedGraphPred

model = ResGatedGraphPred.load_from_checkpoint(
ckpt_path,
map_location=torch.device(self.device),
criterion=None,
strict=False,
metrics=dict(train=dict(), test=dict(), validation=dict()),
pretrained_checkpoint=None,
)
model.eval()
return model

def read_smiles(self, smiles):
from chebai_graph.preprocessing.datasets.chebi import GraphPropAsPerNodeType

d = self.dataset.READER().to_data(dict(features=smiles, labels=None))
property_data = d
# TODO merge props into base should not be a method of a dataset (or at least static)
for property in self.dataset.properties:
property.encoder.eval = True
property_value = self.reader.read_property(smiles, property)
if property_value is None or len(property_value) == 0:
encoded_value = None
else:
encoded_value = torch.stack(
[property.encoder.encode(v) for v in property_value]
)
if len(encoded_value.shape) == 3:
encoded_value = encoded_value.squeeze(0)
property_data[property.name] = encoded_value
# Augmented graphs need an additional argument
if isinstance(self.dataset, GraphPropAsPerNodeType):
d["features"] = self.dataset._merge_props_into_base(
property_data, max_len_node_properties=self.model.gnn.in_channels
)
else:
d["features"] = self.dataset._merge_props_into_base(property_data)
return d


class GATPredictor(ResGatedPredictor):

def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred":
import torch
from chebai_graph.models.gat import GATGraphPred

model = GATGraphPred.load_from_checkpoint(
ckpt_path,
map_location=torch.device(self.device),
criterion=None,
strict=False,
metrics=dict(train=dict(), test=dict(), validation=dict()),
pretrained_checkpoint=None,
)
model.eval()
return model
110 changes: 12 additions & 98 deletions chebifier/prediction_models/nn_predictor.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,29 @@
import numpy as np
import tqdm
from rdkit import Chem
from abc import ABC

from chebai.result.prediction import Predictor

from chebifier import modelwise_smiles_lru_cache

from .base_predictor import BasePredictor


class NNPredictor(BasePredictor):
class NNPredictor(BasePredictor, ABC):
def __init__(
self,
model_name: str,
ckpt_path: str,
reader_cls,
target_labels_path: str,
**kwargs,
):
import torch
self.batch_size = kwargs.get("batch_size", None)
# If batch_size is not provided, it will be set to default batch size used during training in Predictor
self._predictor: Predictor = Predictor(ckpt_path, self.batch_size)

super().__init__(model_name, **kwargs)
self.reader_cls = reader_cls
self.reader = reader_cls()

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = self.init_model(ckpt_path=ckpt_path)
self.target_labels = [
line.strip() for line in open(target_labels_path, encoding="utf-8")
]
self.batch_size = kwargs.get("batch_size", 1)

def init_model(self, ckpt_path: str, **kwargs):
raise NotImplementedError(
"Model initialization must be implemented in subclasses."
)

def calculate_results(self, batch):
collator = self.reader_cls.COLLATOR()
dat = self.model._process_batch(collator(batch).to(self.device), 0)
return self.model(dat, **dat["model_kwargs"])

def batchify(self, batch):
cache = []
for r in batch:
cache.append(r)
if len(cache) >= self.batch_size:
yield cache
cache = []
if cache:
yield cache

def read_smiles(self, smiles):
d = self.reader.to_data(dict(features=smiles, labels=None))
return d

@modelwise_smiles_lru_cache.batch_decorator
def predict_smiles_list(self, smiles_list: list[str]) -> list:
"""Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary
Of classes and predicted values."""
import torch

token_dicts = []
could_not_parse = []
index_map = dict()
for i, smiles in enumerate(smiles_list):
if not smiles:
print(
f"Model {self.model_name} received a missing SMILES string at position {i}."
)
could_not_parse.append(i)
continue
try:
d = self.read_smiles(smiles)
# This is just for sanity checks
rdmol = Chem.MolFromSmiles(smiles, sanitize=False)
if rdmol is None:
print(
f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}"
)
could_not_parse.append(i)
continue
except Exception:
could_not_parse.append(i)
print(
f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}"
)
continue
index_map[i] = len(token_dicts)
token_dicts.append(d)
results = []
if len(token_dicts) > 0:
for batch in tqdm.tqdm(
self.batchify(token_dicts),
desc=f"{self.model_name}",
total=len(token_dicts) // self.batch_size,
):
result = self.calculate_results(batch)
if isinstance(result, dict) and "logits" in result:
result = result["logits"]
results += torch.sigmoid(result).cpu().detach().tolist()
results = np.stack(results, axis=0)
preds = [
(
{
self.target_labels[j]: p
for j, p in enumerate(results[index_map[i]])
}
if i not in could_not_parse
else None
)
for i in range(len(smiles_list))
]
return preds
else:
return [None for _ in smiles_list]
"""
Returns a list with the length of smiles_list, each element is
either None (=failure) or a dictionary of classes and predicted values.
"""
return self._predictor.predict_smiles(smiles_list)
2 changes: 1 addition & 1 deletion data/disjoint_additional.csv
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
47923,48030
47923,48545
48030,48545
90799,155837
90799,155837