From a3284b6f2753747dd6fbaf4f5ac2a20fdad70154 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 26 Nov 2025 23:40:47 +0100 Subject: [PATCH 1/3] predict pipeline --- chebai_graph/preprocessing/datasets/chebi.py | 36 +++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 4ae441a..722df7c 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -77,7 +77,7 @@ def __init__( properties = self._sort_properties(properties) else: properties = [] - self.properties = properties + self.properties: list[MolecularProperty] = properties assert isinstance(self.properties, list) and all( isinstance(p, MolecularProperty) for p in self.properties ) @@ -361,6 +361,40 @@ def load_processed_data( return base_df[base_data[0].keys()].to_dict("records") + def _process_input_for_prediction(self, smiles_list: list[str]) -> list: + data = [ + self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": None} + ) + for idx, smiles in enumerate(smiles_list) + ] + # element of data is a dict with 'id' and 'features' (GeomData) + # GeomData has only edge_index filled but node and edges features are empty. + + assert len(data) == len(smiles_list), "Data length mismatch." 
+ data_df = pd.DataFrame(data) + + for idx, data_row in data_df.itertuples(index=True): + property_data = data_row + for property in self.properties: + property.encoder.eval = True + property_value = self.reader.read_property(smiles_list[idx], property) + if property_value is None or len(property_value) == 0: + encoded_value = None + else: + encoded_value = torch.stack( + [property.encoder.encode(v) for v in property_value] + ) + if len(encoded_value.shape) == 3: + encoded_value = encoded_value.squeeze(0) + property_data[property.name] = encoded_value + + property_data["features"] = property_data.apply( + lambda row: self._merge_props_into_base(row), axis=1 + ) + + return data_df.to_dict("records") + class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): def __init__(self, properties=None, transform=None, **kwargs): From 173788c5aa44abcd90c09aae2b0000813aa2124b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 20:25:44 +0100 Subject: [PATCH 2/3] fix pred pipe func --- chebai_graph/preprocessing/datasets/chebi.py | 120 +++++++++++++------ 1 file changed, 86 insertions(+), 34 deletions(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 722df7c..055cca9 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -184,6 +184,62 @@ def _after_setup(self, **kwargs) -> None: self._setup_properties() super()._after_setup(**kwargs) + def _process_input_for_prediction( + self, + smiles_list: list[str], + model_hparams: Optional[dict] = None, + ) -> list: + data_df = self._process_smiles_and_props(smiles_list) + data_df["features"] = data_df.apply( + lambda row: self._merge_props_into_base(row), axis=1 + ) + + # apply transformation, e.g. 
masking for pretraining task + if self.transform is not None: + data_df["features"] = data_df["features"].apply(self.transform) + + return data_df.to_dict("records") + + def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame: + """ + Process SMILES strings and compute molecular properties. + """ + data = [ + self.reader.to_data( + {"id": f"smiles_{idx}", "features": smiles, "labels": None} + ) + for idx, smiles in enumerate(smiles_list) + ] + # element of data is a dict with 'id' and 'features' (GeomData) + # GeomData has only edge_index filled but node and edges features are empty. + + assert len(data) == len(smiles_list), "Data length mismatch." + data_df = pd.DataFrame(data) + + props: list[dict] = [] + for data_row in data_df.itertuples(index=True): + row_prop_dict: dict = {} + for property in self.properties: + property.encoder.eval = True + property_value = self.reader.read_property( + smiles_list[data_row.Index], property + ) + if property_value is None or len(property_value) == 0: + encoded_value = None + else: + encoded_value = torch.stack( + [property.encoder.encode(v) for v in property_value] + ) + if len(encoded_value.shape) == 3: + encoded_value = encoded_value.squeeze(0) + row_prop_dict[property.name] = encoded_value + row_prop_dict["ident"] = data_row.ident + props.append(row_prop_dict) + + property_df = pd.DataFrame(props) + data_df = data_df.merge(property_df, on="ident", how="left") + return data_df + class GraphPropertiesMixIn(DataPropertiesSetter, ABC): def __init__( @@ -361,40 +417,6 @@ def load_processed_data( return base_df[base_data[0].keys()].to_dict("records") - def _process_input_for_prediction(self, smiles_list: list[str]) -> list: - data = [ - self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": None} - ) - for idx, smiles in enumerate(smiles_list) - ] - # element of data is a dict with 'id' and 'features' (GeomData) - # GeomData has only edge_index filled but node and edges features are 
empty. - - assert len(data) == len(smiles_list), "Data length mismatch." - data_df = pd.DataFrame(data) - - for idx, data_row in data_df.itertuples(index=True): - property_data = data_row - for property in self.properties: - property.encoder.eval = True - property_value = self.reader.read_property(smiles_list[idx], property) - if property_value is None or len(property_value) == 0: - encoded_value = None - else: - encoded_value = torch.stack( - [property.encoder.encode(v) for v in property_value] - ) - if len(encoded_value.shape) == 3: - encoded_value = encoded_value.squeeze(0) - property_data[property.name] = encoded_value - - property_data["features"] = property_data.apply( - lambda row: self._merge_props_into_base(row), axis=1 - ) - - return data_df.to_dict("records") - class GraphPropAsPerNodeType(DataPropertiesSetter, ABC): def __init__(self, properties=None, transform=None, **kwargs): @@ -605,6 +627,36 @@ def _merge_props_into_base( is_graph_node=is_graph_node, ) + def _process_input_for_prediction( + self, + smiles_list: list[str], + model_hparams: Optional[dict] = None, + ) -> list: + if ( + model_hparams is None + or "in_channels" not in model_hparams["config"] + or model_hparams["config"]["in_channels"] is None + ): + raise ValueError( + f"model_hparams must be provided for data class: {self.__class__.__name__}" + f" which should contain 'in_channels' key with valid value in 'config' dictionary." + ) + + max_len_node_properties = int(model_hparams["config"]["in_channels"]) + # Determine max_len_node_properties based on in_channels + + data_df = self._process_smiles_and_props(smiles_list) + data_df["features"] = data_df.apply( + lambda row: self._merge_props_into_base(row, max_len_node_properties), + axis=1, + ) + + # apply transformation, e.g. 
masking for pretraining task + if self.transform is not None: + data_df["features"] = data_df["features"].apply(self.transform) + + return data_df.to_dict("records") + class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50): READER = RandomFeatureInitializationReader From a11ba5067b6aec9f94e48f74db42f19f0482522f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 27 Nov 2025 23:29:09 +0100 Subject: [PATCH 3/3] fix irrelevant ident from reader error --- chebai_graph/preprocessing/datasets/chebi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_graph/preprocessing/datasets/chebi.py b/chebai_graph/preprocessing/datasets/chebi.py index 055cca9..8578160 100644 --- a/chebai_graph/preprocessing/datasets/chebi.py +++ b/chebai_graph/preprocessing/datasets/chebi.py @@ -206,7 +206,7 @@ def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame: """ data = [ self.reader.to_data( - {"id": f"smiles_{idx}", "features": smiles, "labels": None} + {"ident": f"smiles_{idx}", "features": smiles, "labels": None} ) for idx, smiles in enumerate(smiles_list) ]