Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 87 additions & 1 deletion chebai_graph/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
properties = self._sort_properties(properties)
else:
properties = []
self.properties = properties
self.properties: list[MolecularProperty] = properties
assert isinstance(self.properties, list) and all(
isinstance(p, MolecularProperty) for p in self.properties
)
Expand Down Expand Up @@ -184,6 +184,62 @@ def _after_setup(self, **kwargs) -> None:
self._setup_properties()
super()._after_setup(**kwargs)

def _process_input_for_prediction(
    self,
    smiles_list: list[str],
    model_hparams: Optional[dict] = None,
) -> list:
    """
    Turn raw SMILES strings into prediction-ready records.

    The SMILES are first expanded into a DataFrame by
    ``self._process_smiles_and_props``; each row is then collapsed into a
    single ``features`` object by ``self._merge_props_into_base`` and, if a
    transform is configured, passed through ``self.transform``.

    Args:
        smiles_list: SMILES strings to prepare for prediction.
        model_hparams: Accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        One dict per input SMILES (``DataFrame.to_dict("records")``), or an
        empty list when ``smiles_list`` is empty.
    """
    if not smiles_list:
        # Nothing to process; the DataFrame pipeline below cannot handle a
        # frame without rows/columns.
        return []

    data_df = self._process_smiles_and_props(smiles_list)
    data_df["features"] = data_df.apply(self._merge_props_into_base, axis=1)

    # apply transformation, e.g. masking for pretraining task
    if self.transform is not None:
        data_df["features"] = data_df["features"].apply(self.transform)

    return data_df.to_dict("records")

def _process_smiles_and_props(self, smiles_list: list[str]) -> pd.DataFrame:
    """
    Process SMILES strings and compute molecular properties.

    Every SMILES string is passed through ``self.reader.to_data`` under the
    identifier ``smiles_<position>``. For each molecule, every property in
    ``self.properties`` is read via ``self.reader.read_property`` and its
    values are encoded with the property's encoder. The encoded properties
    are joined to the reader output on the ``ident`` column.

    Args:
        smiles_list: SMILES strings to process.

    Returns:
        DataFrame with one row per input SMILES, holding the reader output
        plus one column per property name. A property cell is a stacked
        tensor of the encoded values, or ``None`` when the reader returned
        no values. An empty DataFrame is returned for empty input.
    """
    if not smiles_list:
        # Without rows there is no "ident" column to merge on.
        return pd.DataFrame()

    data_df = pd.DataFrame(
        [
            self.reader.to_data(
                {"ident": f"smiles_{idx}", "features": smiles, "labels": None}
            )
            for idx, smiles in enumerate(smiles_list)
        ]
    )
    # NOTE(review): each reader result is expected to carry at least an
    # 'ident' entry (used as merge key below); presumably 'features' holds
    # the graph object with only the edge index filled — confirm with reader.

    # The flag does not depend on the molecule, so set it once up front.
    # NOTE(review): presumably puts the encoder into inference mode — confirm.
    for prop in self.properties:
        prop.encoder.eval = True

    props: list[dict] = []
    for smiles, ident in zip(smiles_list, data_df["ident"]):
        row_props: dict = {}
        for prop in self.properties:
            raw_values = self.reader.read_property(smiles, prop)
            # Explicit None/len check: truthiness is ambiguous for
            # array-like return values.
            if raw_values is None or len(raw_values) == 0:
                encoded = None
            else:
                encoded = torch.stack([prop.encoder.encode(v) for v in raw_values])
                if encoded.ndim == 3:
                    # Drops the leading axis only when it has size 1.
                    encoded = encoded.squeeze(0)
            row_props[prop.name] = encoded
        row_props["ident"] = ident
        props.append(row_props)

    return data_df.merge(pd.DataFrame(props), on="ident", how="left")


class GraphPropertiesMixIn(DataPropertiesSetter, ABC):
def __init__(
Expand Down Expand Up @@ -571,6 +627,36 @@ def _merge_props_into_base(
is_graph_node=is_graph_node,
)

def _process_input_for_prediction(
    self,
    smiles_list: list[str],
    model_hparams: Optional[dict] = None,
) -> list:
    """
    Turn raw SMILES strings into prediction-ready records.

    The value of ``model_hparams["config"]["in_channels"]`` is converted to
    ``int`` and forwarded as the second argument of
    ``self._merge_props_into_base`` for every row produced by
    ``self._process_smiles_and_props``. If a transform is configured, it is
    applied to the merged ``features`` afterwards.

    Args:
        smiles_list: SMILES strings to prepare for prediction.
        model_hparams: Model hyperparameters; must contain a ``"config"``
            mapping with a non-``None`` ``"in_channels"`` entry.

    Returns:
        One dict per input SMILES (``DataFrame.to_dict("records")``), or an
        empty list when ``smiles_list`` is empty.

    Raises:
        ValueError: If ``model_hparams`` is missing, has no usable
            ``"config"`` entry, or ``"in_channels"`` is absent or ``None``.
    """
    # Tolerate a missing/None 'config' so every malformed input ends in the
    # same ValueError instead of a KeyError/TypeError from raw indexing.
    config = (model_hparams or {}).get("config") or {}
    in_channels = config.get("in_channels")
    if in_channels is None:
        raise ValueError(
            f"model_hparams must be provided for data class: {self.__class__.__name__}"
            f" which should contain 'in_channels' key with valid value in 'config' dictionary."
        )

    # Determine max_len_node_properties based on in_channels
    max_len_node_properties = int(in_channels)

    if not smiles_list:
        # Nothing to process; the DataFrame pipeline below cannot handle a
        # frame without rows/columns.
        return []

    data_df = self._process_smiles_and_props(smiles_list)
    data_df["features"] = data_df.apply(
        self._merge_props_into_base,
        axis=1,
        args=(max_len_node_properties,),
    )

    # apply transformation, e.g. masking for pretraining task
    if self.transform is not None:
        data_df["features"] = data_df["features"].apply(self.transform)

    return data_df.to_dict("records")


class ChEBI50_StaticGNI(DataPropertiesSetter, ChEBIOver50):
READER = RandomFeatureInitializationReader
Expand Down