From 2fe8a8429f4c0ef0c8509c043610b6b9b494aea4 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 12:22:31 +0100 Subject: [PATCH 1/2] use the generalized prediction pipeline --- chebifier/model_registry.py | 10 +- chebifier/prediction_models/__init__.py | 4 +- .../prediction_models/electra_predictor.py | 30 +---- chebifier/prediction_models/gnn_predictor.py | 112 +----------------- chebifier/prediction_models/nn_predictor.py | 110 ++--------------- 5 files changed, 26 insertions(+), 240 deletions(-) diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 3632662..669247e 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -7,7 +7,7 @@ ChEBILookupPredictor, ChemlogPeptidesPredictor, ElectraPredictor, - ResGatedPredictor, + GNNPredictor, ) from chebifier.prediction_models.c3p_predictor import C3PPredictor from chebifier.prediction_models.chemlog_predictor import ( @@ -26,7 +26,7 @@ MODEL_TYPES = { "electra": ElectraPredictor, - "resgated": ResGatedPredictor, + "resgated": GNNPredictor, "gat": GATPredictor, "chemlog": ChemlogAllPredictor, "chemlog_peptides": ChemlogPeptidesPredictor, @@ -38,6 +38,6 @@ common_keys = MODEL_TYPES.keys() & ENSEMBLES.keys() -assert ( - not common_keys -), f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" +assert not common_keys, ( + f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" +) diff --git a/chebifier/prediction_models/__init__.py b/chebifier/prediction_models/__init__.py index bc29580..3eb2de1 100644 --- a/chebifier/prediction_models/__init__.py +++ b/chebifier/prediction_models/__init__.py @@ -3,13 +3,13 @@ from .chebi_lookup import ChEBILookupPredictor from .chemlog_predictor import ChemlogExtraPredictor, ChemlogPeptidesPredictor from .electra_predictor import ElectraPredictor -from .gnn_predictor import ResGatedPredictor +from .gnn_predictor import GNNPredictor __all__ = [ "BasePredictor", "ChemlogPeptidesPredictor", "ElectraPredictor", - "ResGatedPredictor", + "GNNPredictor", "ChEBILookupPredictor", "ChemlogExtraPredictor", "C3PPredictor", diff --git a/chebifier/prediction_models/electra_predictor.py b/chebifier/prediction_models/electra_predictor.py index 4843cad..62cf635 100644 --- a/chebifier/prediction_models/electra_predictor.py +++ b/chebifier/prediction_models/electra_predictor.py @@ -1,12 +1,7 @@ -from typing import TYPE_CHECKING - import numpy as np from .nn_predictor import NNPredictor -if TYPE_CHECKING: - from chebai.models.electra import Electra - def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0): n_nodes = len(node_labels) @@ -40,37 +35,22 @@ def build_graph_from_attention(att, node_labels, token_labels, threshold=0.0): class ElectraPredictor(NNPredictor): def __init__(self, model_name: str, ckpt_path: str, **kwargs): - from chebai.preprocessing.reader import ChemDataReader - - super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs) + super().__init__(model_name, ckpt_path, **kwargs) print(f"Initialised Electra model {self.model_name} (device: {self.device})") - def init_model(self, ckpt_path: str, **kwargs) -> "Electra": - from chebai.models.electra import Electra - - model = Electra.load_from_checkpoint( - ckpt_path, - map_location=self.device, - criterion=None, - strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), - pretrained_checkpoint=None, - ) - model.eval() - return model - def explain_smiles(self, smiles) -> dict: from chebai.preprocessing.reader import EMBEDDING_OFFSET - reader = self.reader_cls() - token_dict = reader.to_data(dict(features=smiles, labels=None)) + token_dict = self._predictor._dm.reader.to_data( + dict(features=smiles, labels=None) + ) tokens = np.array(token_dict["features"]).astype(int).tolist() result = self.calculate_results([token_dict]) token_labels = ( ["[CLR]"] + [None for _ in range(EMBEDDING_OFFSET - 1)] - + list(reader.cache.keys()) + + list(self._predictor._dm.reader.cache.keys()) ) graphs = [ diff --git a/chebifier/prediction_models/gnn_predictor.py b/chebifier/prediction_models/gnn_predictor.py index 214d906..3df39e5 100644 --- a/chebifier/prediction_models/gnn_predictor.py +++ b/chebifier/prediction_models/gnn_predictor.py @@ -1,120 +1,12 @@ -from typing import TYPE_CHECKING, Optional - -import torch - from .nn_predictor import NNPredictor -if TYPE_CHECKING: - from chebai_graph.models.gat import GATGraphPred - from chebai_graph.models.resgated import ResGatedGraphPred - -class ResGatedPredictor(NNPredictor): +class GNNPredictor(NNPredictor): def __init__( self, model_name: str, ckpt_path: str, - molecular_properties, - dataset_cls: Optional[str] = None, **kwargs, ): - from chebai_graph.preprocessing.datasets.chebi import ( - ChEBI50GraphProperties, - GraphPropertiesMixIn, - ) - from chebai_graph.preprocessing.properties import MolecularProperty - - # molecular_properties is a list of class paths - if molecular_properties is not None: - properties = [self.load_class(prop)() for prop in molecular_properties] - properties = sorted( - properties, key=lambda prop: f"{prop.name}_{prop.encoder.name}" - ) - else: - properties = [] - for property in properties: - property.encoder.eval = True - self.molecular_properties = properties - assert isinstance(self.molecular_properties, list) and all( - isinstance(prop, MolecularProperty) for prop in self.molecular_properties - ) - # TODO it should not be necessary to refer to the whole dataset class, disentangle dataset and molecule reading - self.dataset_cls = ( - self.load_class(dataset_cls) - if dataset_cls is not None - else ChEBI50GraphProperties - ) - self.dataset: Optional[GraphPropertiesMixIn] = self.dataset_cls( - properties=molecular_properties - ) - - super().__init__( - model_name, ckpt_path, reader_cls=self.dataset.READER, **kwargs - ) - + super().__init__(model_name, ckpt_path, **kwargs) print(f"Initialised GNN model {self.model_name} (device: {self.device})") - - def load_class(self, class_path: str): - module_path, class_name = class_path.rsplit(".", 1) - module = __import__(module_path, fromlist=[class_name]) - return getattr(module, class_name) - - def init_model(self, ckpt_path: str, **kwargs) -> "ResGatedGraphPred": - import torch - from chebai_graph.models.resgated import ResGatedGraphPred - - model = ResGatedGraphPred.load_from_checkpoint( - ckpt_path, - map_location=torch.device(self.device), - criterion=None, - strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), - pretrained_checkpoint=None, - ) - model.eval() - return model - - def read_smiles(self, smiles): - from chebai_graph.preprocessing.datasets.chebi import GraphPropAsPerNodeType - - d = self.dataset.READER().to_data(dict(features=smiles, labels=None)) - property_data = d - # TODO merge props into base should not be a method of a dataset (or at least static) - for property in self.dataset.properties: - property.encoder.eval = True - property_value = self.reader.read_property(smiles, property) - if property_value is None or len(property_value) == 0: - encoded_value = None - else: - encoded_value = torch.stack( - [property.encoder.encode(v) for v in property_value] - ) - if len(encoded_value.shape) == 3: - encoded_value = encoded_value.squeeze(0) - property_data[property.name] = encoded_value - # Augmented graphs need an additional argument - if isinstance(self.dataset, GraphPropAsPerNodeType): - d["features"] = self.dataset._merge_props_into_base( - property_data, max_len_node_properties=self.model.gnn.in_channels - ) - else: - d["features"] = self.dataset._merge_props_into_base(property_data) - return d - - -class GATPredictor(ResGatedPredictor): - - def init_model(self, ckpt_path: str, **kwargs) -> "GATGraphPred": - import torch - from chebai_graph.models.gat import GATGraphPred - - model = GATGraphPred.load_from_checkpoint( - ckpt_path, - map_location=torch.device(self.device), - criterion=None, - strict=False, - metrics=dict(train=dict(), test=dict(), validation=dict()), - pretrained_checkpoint=None, - ) - model.eval() - return model diff --git a/chebifier/prediction_models/nn_predictor.py b/chebifier/prediction_models/nn_predictor.py index eb7e997..e63f5aa 100644 --- a/chebifier/prediction_models/nn_predictor.py +++ b/chebifier/prediction_models/nn_predictor.py @@ -1,115 +1,29 @@ -import numpy as np -import tqdm -from rdkit import Chem +from abc import ABC + +from chebai.result.prediction import Predictor from chebifier import modelwise_smiles_lru_cache from .base_predictor import BasePredictor -class NNPredictor(BasePredictor): +class NNPredictor(BasePredictor, ABC): def __init__( self, model_name: str, ckpt_path: str, - reader_cls, - target_labels_path: str, **kwargs, ): - import torch + self.batch_size = kwargs.get("batch_size", None) + # If batch_size is not provided, it will be set to default batch size used during training in Predictor + self._predictor: Predictor = Predictor(ckpt_path, self.batch_size) super().__init__(model_name, **kwargs) - self.reader_cls = reader_cls - self.reader = reader_cls() - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = self.init_model(ckpt_path=ckpt_path) - self.target_labels = [ - line.strip() for line in open(target_labels_path, encoding="utf-8") - ] - self.batch_size = kwargs.get("batch_size", 1) - - def init_model(self, ckpt_path: str, **kwargs): - raise NotImplementedError( - "Model initialization must be implemented in subclasses." - ) - - def calculate_results(self, batch): - collator = self.reader_cls.COLLATOR() - dat = self.model._process_batch(collator(batch).to(self.device), 0) - return self.model(dat, **dat["model_kwargs"]) - - def batchify(self, batch): - cache = [] - for r in batch: - cache.append(r) - if len(cache) >= self.batch_size: - yield cache - cache = [] - if cache: - yield cache - - def read_smiles(self, smiles): - d = self.reader.to_data(dict(features=smiles, labels=None)) - return d @modelwise_smiles_lru_cache.batch_decorator def predict_smiles_list(self, smiles_list: list[str]) -> list: - """Returns a list with the length of smiles_list, each element is either None (=failure) or a dictionary - Of classes and predicted values.""" - import torch - - token_dicts = [] - could_not_parse = [] - index_map = dict() - for i, smiles in enumerate(smiles_list): - if not smiles: - print( - f"Model {self.model_name} received a missing SMILES string at position {i}." - ) - could_not_parse.append(i) - continue - try: - d = self.read_smiles(smiles) - # This is just for sanity checks - rdmol = Chem.MolFromSmiles(smiles, sanitize=False) - if rdmol is None: - print( - f"Model {self.model_name} received a SMILES string RDKit can't read at position {i}: {smiles}" - ) - could_not_parse.append(i) - continue - except Exception: - could_not_parse.append(i) - print( - f"Model {self.model_name} failed to parse a SMILES string at position {i}: {smiles}" - ) - continue - index_map[i] = len(token_dicts) - token_dicts.append(d) - results = [] - if len(token_dicts) > 0: - for batch in tqdm.tqdm( - self.batchify(token_dicts), - desc=f"{self.model_name}", - total=len(token_dicts) // self.batch_size, - ): - result = self.calculate_results(batch) - if isinstance(result, dict) and "logits" in result: - result = result["logits"] - results += torch.sigmoid(result).cpu().detach().tolist() - results = np.stack(results, axis=0) - preds = [ - ( - { - self.target_labels[j]: p - for j, p in enumerate(results[index_map[i]]) - } - if i not in could_not_parse - else None - ) - for i in range(len(smiles_list)) - ] - return preds - else: - return [None for _ in smiles_list] + """ + Returns a list with the length of smiles_list, each element is + either None (=failure) or a dictionary of classes and predicted values. + """ + return self._predictor.predict_smiles(smiles_list) From 2b2978b54cd018fed042c47aba1ec1f147acfd35 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 6 Dec 2025 13:27:43 +0100 Subject: [PATCH 2/2] pre-commit --- chebifier/model_registry.py | 6 +++--- data/disjoint_additional.csv | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chebifier/model_registry.py b/chebifier/model_registry.py index 669247e..b4b911c 100644 --- a/chebifier/model_registry.py +++ b/chebifier/model_registry.py @@ -38,6 +38,6 @@ common_keys = MODEL_TYPES.keys() & ENSEMBLES.keys() -assert not common_keys, ( - f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" -) +assert ( + not common_keys +), f"Overlapping keys between MODEL_TYPES and ENSEMBLES: {common_keys}" diff --git a/data/disjoint_additional.csv b/data/disjoint_additional.csv index e087ffb..e42b547 100644 --- a/data/disjoint_additional.csv +++ b/data/disjoint_additional.csv @@ -11,4 +11,4 @@ 47923,48030 47923,48545 48030,48545 -90799,155837 \ No newline at end of file +90799,155837