From 98e3f9cf5b40624c2b5cb2f645608ed350b9ccd9 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 Nov 2024 13:29:39 +0100 Subject: [PATCH 1/5] handle predictions and labels that are tuples (e.g. atom and bond predictions for gnn pretraining) in evaluation --- chebai/result/utils.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index d015bd80..80bf56e2 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -66,6 +66,13 @@ def _run_batch(batch, model, collate): return preds, labels +def _concat_tuple(l): + if isinstance(l[0], tuple): + print(l[0]) + return tuple([torch.cat([t[i] for t in l]) for i in range(len(l[0]))]) + return torch.cat(l) + + def evaluate_model( model: ChebaiBaseNet, data_module: XYBaseDataModule, @@ -125,12 +132,12 @@ def evaluate_model( if buffer_dir is not None: if n_saved * batch_size >= save_batch_size: torch.save( - torch.cat(preds_list), + _concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), ) if labels_list[0] is not None: torch.save( - torch.cat(labels_list), + _concat_tuple(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) preds_list = [] @@ -141,20 +148,20 @@ def evaluate_model( n_saved += 1 if buffer_dir is None: - test_preds = torch.cat(preds_list) + test_preds = _concat_tuple(preds_list) if labels_list is not None: - test_labels = torch.cat(labels_list) + test_labels = _concat_tuple(labels_list) return test_preds, test_labels return test_preds, None else: torch.save( - torch.cat(preds_list), + _concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), ) if labels_list[0] is not None: torch.save( - torch.cat(labels_list), + _concat_tuple(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) From 16013af7f6f021f300c0c4ebdd9596c5fb36a5b6 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 Nov 2024 14:48:18 +0100 Subject: [PATCH 2/5] add processed_main file names --- chebai/preprocessing/datasets/base.py | 65 ++++++++++++++------------ chebai/preprocessing/datasets/chebi.py | 8 ++-- 2 files changed, 40 insertions(+), 33 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f163a9e6..bb0f50d2 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -394,45 +394,61 @@ def setup_processed(self): raise NotImplementedError @property - def processed_file_names(self) -> List[str]: + def processed_dir_main_file_names_dict(self) -> dict: """ - Returns the list of processed file names. - - This property should be implemented by subclasses to provide the list of processed file names. + Returns a dictionary mapping processed data file names. Returns: - List[str]: The list of processed file names. + dict: A dictionary mapping dataset key to their respective file names. + For example, {"data": "data.pkl"}. """ raise NotImplementedError @property - def raw_file_names(self) -> List[str]: + def processed_dir_main_file_names(self) -> List[str]: """ - Returns the list of raw file names. + Returns a list of file names for processed data (before tokenization). + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_dir_main_file_names_dict.values()) - This property should be implemented by subclasses to provide the list of raw file names. + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary for the processed and tokenized data files. Returns: - List[str]: The list of raw file names. + dict: A dictionary mapping dataset keys to their respective file names. + For example, {"data": "data.pt"}. """ raise NotImplementedError @property - def processed_file_names_dict(self) -> dict: + def processed_file_names(self) -> List[str]: + """ + Returns a list of file names for processed data. + + Returns: + List[str]: A list of file names corresponding to the processed data. """ - Returns the dictionary of processed file names. + return list(self.processed_file_names_dict.values()) - This property should be implemented by subclasses to provide the dictionary of processed file names. + @property + def raw_file_names(self) -> List[str]: + """ + Returns the list of raw file names. Returns: - dict: The dictionary of processed file names. + List[str]: The list of raw file names. """ - raise NotImplementedError + return list(self.raw_file_names_dict.values()) @property def raw_file_names_dict(self) -> dict: """ - Returns the dictionary of raw file names. + Returns the dictionary of raw file names (i.e., files that are directly obtained from an external source). This property should be implemented by subclasses to provide the dictionary of raw file names. @@ -1133,10 +1149,10 @@ def processed_dir(self) -> str: @property def processed_dir_main_file_names_dict(self) -> dict: """ - Returns a dictionary mapping processed data file names, processed by `prepare_data` method. + Returns a dictionary mapping processed data file names. Returns: - dict: A dictionary mapping dataset types to their respective processed file names. + dict: A dictionary mapping dataset key to their respective file names. For example, {"data": "data.pkl"}. """ return {"data": "data.pkl"} @@ -1144,21 +1160,10 @@ def processed_dir_main_file_names_dict(self) -> dict: @property def processed_file_names_dict(self) -> dict: """ - Returns a dictionary mapping processed and transformed data file names to their final formats, which are - processed by `setup` method. + Returns a dictionary for the processed and tokenized data files. Returns: - dict: A dictionary mapping dataset types to their respective final file names. + dict: A dictionary mapping dataset keys to their respective file names. For example, {"data": "data.pt"}. """ return {"data": "data.pt"} - - @property - def processed_file_names(self) -> List[str]: - """ - Returns a list of file names for processed data. - - Returns: - List[str]: A list of file names corresponding to the processed data. - """ - return list(self.processed_file_names_dict.values()) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9d80929a..1b49d0e2 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -216,9 +216,7 @@ def _load_chebi(self, version: int) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ - chebi_name = ( - f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo" - ) + chebi_name = self.raw_file_names_dict["chebi"] chebi_path = os.path.join(self.raw_dir, chebi_name) if not os.path.isfile(chebi_path): print( @@ -540,6 +538,10 @@ def processed_dir(self) -> str: else: return os.path.join(res, f"single_{self.single_class}") + @property + def raw_file_names_dict(self) -> dict: + return {"chebi": "chebi.obo"} + class JCIExtendedBase(_ChEBIDataExtractor): From 9731f897466071f742ee06ad59444df5cb18b833 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 Nov 2024 15:00:43 +0100 Subject: [PATCH 3/5] shorten processed_dir_main_file_names to processed_main_file_names --- chebai/preprocessing/datasets/base.py | 12 ++++++------ chebai/preprocessing/datasets/chebi.py | 2 +- chebai/preprocessing/datasets/protein_pretraining.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index bb0f50d2..73c2b2cd 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -394,7 +394,7 @@ def setup_processed(self): raise NotImplementedError @property - def processed_dir_main_file_names_dict(self) -> dict: + def processed_main_file_names_dict(self) -> dict: """ Returns a dictionary mapping processed data file names. @@ -405,14 +405,14 @@ def processed_dir_main_file_names_dict(self) -> dict: raise NotImplementedError @property - def processed_dir_main_file_names(self) -> List[str]: + def processed_main_file_names(self) -> List[str]: """ Returns a list of file names for processed data (before tokenization). Returns: List[str]: A list of file names corresponding to the processed data. """ - return list(self.processed_dir_main_file_names_dict.values()) + return list(self.processed_main_file_names_dict.values()) @property def processed_file_names_dict(self) -> dict: @@ -721,7 +721,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ print("Checking for processed data in", self.processed_dir_main) - processed_name = self.processed_dir_main_file_names_dict["data"] + processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): print("Missing processed data file (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) @@ -812,7 +812,7 @@ def setup_processed(self) -> None: self._load_data_from_file( os.path.join( self.processed_dir_main, - self.processed_dir_main_file_names_dict["data"], + self.processed_main_file_names_dict["data"], ) ), os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), @@ -1147,7 +1147,7 @@ def processed_dir(self) -> str: ) @property - def processed_dir_main_file_names_dict(self) -> dict: + def processed_main_file_names_dict(self) -> dict: """ Returns a dictionary mapping processed data file names. diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 1b49d0e2..d927a44c 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -185,7 +185,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: if not os.path.isfile( os.path.join( self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.processed_dir_main_file_names_dict[ + self._chebi_version_train_obj.processed_main_file_names_dict[ "data" ], ) diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/protein_pretraining.py index 8550db2b..6b5d1df0 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/protein_pretraining.py @@ -64,7 +64,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: *args: Additional positional arguments. **kwargs: Additional keyword arguments. """ - processed_name = self.processed_dir_main_file_names_dict["data"] + processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): print("Missing processed data file (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) From ccc5aea6f89b64993e9b9a1eec4dc978668612a2 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 Nov 2024 17:44:56 +0100 Subject: [PATCH 4/5] move processed_dir_main to XYBaseDataModule --- chebai/preprocessing/datasets/base.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 73c2b2cd..dfa0f999 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -134,10 +134,15 @@ def base_dir(self) -> str: return self._base_dir return os.path.join("data", self._name) + @property + def processed_dir_main(self) -> str: + """Name of the directory where processed (but not tokenized) data is stored.""" + return os.path.join(self.base_dir, "processed") + @property def processed_dir(self) -> str: - """Name of the directory where the processed data is stored.""" - return os.path.join(self.base_dir, "processed", *self.identifier) + """Name of the directory where the processed and tokenized data is stored.""" + return os.path.join(self.processed_dir_main, *self.identifier) @property def raw_dir(self) -> str: @@ -1133,19 +1138,6 @@ def processed_dir_main(self) -> str: "processed", ) - @property - def processed_dir(self) -> str: - """ - Returns the specific directory path for processed data, including identifiers. - - Returns: - str: The path to the processed data directory, including additional identifiers. - """ - return os.path.join( - self.processed_dir_main, - *self.identifier, - ) - @property def processed_main_file_names_dict(self) -> dict: """ From dfea71ee12dcdb687e23560ef98178881e20be91 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 6 Nov 2024 17:45:30 +0100 Subject: [PATCH 5/5] use processed-main instead of raw file for BCE weights --- chebai/loss/bce_weighted.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index b69fff43..c00756e6 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -43,8 +43,10 @@ def set_pos_weight(self, input: torch.Tensor) -> None: self.beta is not None and self.data_extractor is not None and all( - os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file)) - for raw_file in self.data_extractor.raw_file_names + os.path.exists( + os.path.join(self.data_extractor.processed_dir_main, file_name) + ) + for file_name in self.data_extractor.processed_main_file_names ) and self.pos_weight is None ): @@ -53,13 +55,13 @@ def set_pos_weight(self, input: torch.Tensor) -> None: pd.read_pickle( open( os.path.join( - self.data_extractor.raw_dir, - raw_file_name, + self.data_extractor.processed_dir_main, + file_name, ), "rb", ) ) - for raw_file_name in self.data_extractor.raw_file_names + for file_name in self.data_extractor.processed_main_file_names ] ) value_counts = []