diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index b69fff43..c00756e6 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -43,8 +43,10 @@ def set_pos_weight(self, input: torch.Tensor) -> None: self.beta is not None and self.data_extractor is not None and all( - os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file)) - for raw_file in self.data_extractor.raw_file_names + os.path.exists( + os.path.join(self.data_extractor.processed_dir_main, file_name) + ) + for file_name in self.data_extractor.processed_main_file_names ) and self.pos_weight is None ): @@ -53,13 +55,13 @@ def set_pos_weight(self, input: torch.Tensor) -> None: pd.read_pickle( open( os.path.join( - self.data_extractor.raw_dir, - raw_file_name, + self.data_extractor.processed_dir_main, + file_name, ), "rb", ) ) - for raw_file_name in self.data_extractor.raw_file_names + for file_name in self.data_extractor.processed_main_file_names ] ) value_counts = [] diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index f163a9e6..dfa0f999 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -134,10 +134,15 @@ def base_dir(self) -> str: return self._base_dir return os.path.join("data", self._name) + @property + def processed_dir_main(self) -> str: + """Name of the directory where processed (but not tokenized) data is stored.""" + return os.path.join(self.base_dir, "processed") + @property def processed_dir(self) -> str: - """Name of the directory where the processed data is stored.""" - return os.path.join(self.base_dir, "processed", *self.identifier) + """Name of the directory where the processed and tokenized data is stored.""" + return os.path.join(self.processed_dir_main, *self.identifier) @property def raw_dir(self) -> str: @@ -394,45 +399,61 @@ def setup_processed(self): raise NotImplementedError @property - def processed_file_names(self) -> List[str]: + def processed_main_file_names_dict(self) -> dict: """ - Returns the list of processed file names. - - This property should be implemented by subclasses to provide the list of processed file names. + Returns a dictionary mapping processed data file names. Returns: - List[str]: The list of processed file names. + dict: A dictionary mapping dataset key to their respective file names. + For example, {"data": "data.pkl"}. """ raise NotImplementedError @property - def raw_file_names(self) -> List[str]: + def processed_main_file_names(self) -> List[str]: """ - Returns the list of raw file names. + Returns a list of file names for processed data (before tokenization). + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_main_file_names_dict.values()) - This property should be implemented by subclasses to provide the list of raw file names. + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary for the processed and tokenized data files. Returns: - List[str]: The list of raw file names. + dict: A dictionary mapping dataset keys to their respective file names. + For example, {"data": "data.pt"}. """ raise NotImplementedError @property - def processed_file_names_dict(self) -> dict: + def processed_file_names(self) -> List[str]: """ - Returns the dictionary of processed file names. + Returns a list of file names for processed data. - This property should be implemented by subclasses to provide the dictionary of processed file names. + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_file_names_dict.values()) + + @property + def raw_file_names(self) -> List[str]: + """ + Returns the list of raw file names. Returns: - dict: The dictionary of processed file names. + List[str]: The list of raw file names. """ - raise NotImplementedError + return list(self.raw_file_names_dict.values()) @property def raw_file_names_dict(self) -> dict: """ - Returns the dictionary of raw file names. + Returns the dictionary of raw file names (i.e., files that are directly obtained from an external source). This property should be implemented by subclasses to provide the dictionary of raw file names. @@ -705,7 +726,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ print("Checking for processed data in", self.processed_dir_main) - processed_name = self.processed_dir_main_file_names_dict["data"] + processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): print("Missing processed data file (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) @@ -796,7 +817,7 @@ def setup_processed(self) -> None: self._load_data_from_file( os.path.join( self.processed_dir_main, - self.processed_dir_main_file_names_dict["data"], + self.processed_main_file_names_dict["data"], ) ), os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), @@ -1118,25 +1139,12 @@ def processed_dir_main(self) -> str: ) @property - def processed_dir(self) -> str: + def processed_main_file_names_dict(self) -> dict: """ - Returns the specific directory path for processed data, including identifiers. + Returns a dictionary mapping processed data file names. Returns: - str: The path to the processed data directory, including additional identifiers. - """ - return os.path.join( - self.processed_dir_main, - *self.identifier, - ) - - @property - def processed_dir_main_file_names_dict(self) -> dict: - """ - Returns a dictionary mapping processed data file names, processed by `prepare_data` method. - - Returns: - dict: A dictionary mapping dataset types to their respective processed file names. + dict: A dictionary mapping dataset key to their respective file names. For example, {"data": "data.pkl"}. """ return {"data": "data.pkl"} @@ -1144,21 +1152,10 @@ def processed_dir_main_file_names_dict(self) -> dict: @property def processed_file_names_dict(self) -> dict: """ - Returns a dictionary mapping processed and transformed data file names to their final formats, which are - processed by `setup` method. + Returns a dictionary for the processed and tokenized data files. Returns: - dict: A dictionary mapping dataset types to their respective final file names. + dict: A dictionary mapping dataset keys to their respective file names. For example, {"data": "data.pt"}. """ return {"data": "data.pt"} - - @property - def processed_file_names(self) -> List[str]: - """ - Returns a list of file names for processed data. - - Returns: - List[str]: A list of file names corresponding to the processed data. - """ - return list(self.processed_file_names_dict.values()) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 9d80929a..d927a44c 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -185,7 +185,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: if not os.path.isfile( os.path.join( self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.processed_dir_main_file_names_dict[ + self._chebi_version_train_obj.processed_main_file_names_dict[ "data" ], ) @@ -216,9 +216,7 @@ def _load_chebi(self, version: int) -> str: Returns: str: The file path of the loaded ChEBI ontology. """ - chebi_name = ( - f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo" - ) + chebi_name = self.raw_file_names_dict["chebi"] chebi_path = os.path.join(self.raw_dir, chebi_name) if not os.path.isfile(chebi_path): print( @@ -540,6 +538,10 @@ def processed_dir(self) -> str: else: return os.path.join(res, f"single_{self.single_class}") + @property + def raw_file_names_dict(self) -> dict: + return {"chebi": "chebi.obo"} + class JCIExtendedBase(_ChEBIDataExtractor): diff --git a/chebai/preprocessing/datasets/protein_pretraining.py b/chebai/preprocessing/datasets/protein_pretraining.py index 8550db2b..6b5d1df0 100644 --- a/chebai/preprocessing/datasets/protein_pretraining.py +++ b/chebai/preprocessing/datasets/protein_pretraining.py @@ -64,7 +64,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None: *args: Additional positional arguments. **kwargs: Additional keyword arguments. """ - processed_name = self.processed_dir_main_file_names_dict["data"] + processed_name = self.processed_main_file_names_dict["data"] if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): print("Missing processed data file (`data.pkl` file)") os.makedirs(self.processed_dir_main, exist_ok=True) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index d015bd80..80bf56e2 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -66,6 +66,13 @@ def _run_batch(batch, model, collate): return preds, labels +def _concat_tuple(l): + if isinstance(l[0], tuple): + print(l[0]) + return tuple([torch.cat([t[i] for t in l]) for i in range(len(l[0]))]) + return torch.cat(l) + + def evaluate_model( model: ChebaiBaseNet, data_module: XYBaseDataModule, @@ -125,12 +132,12 @@ def evaluate_model( if buffer_dir is not None: if n_saved * batch_size >= save_batch_size: torch.save( - torch.cat(preds_list), + _concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), ) if labels_list[0] is not None: torch.save( - torch.cat(labels_list), + _concat_tuple(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), ) preds_list = [] @@ -141,20 +148,20 @@ def evaluate_model( n_saved += 1 if buffer_dir is None: - test_preds = torch.cat(preds_list) + test_preds = _concat_tuple(preds_list) if labels_list is not None: - test_labels = torch.cat(labels_list) + test_labels = _concat_tuple(labels_list) return test_preds, test_labels return test_preds, None else: torch.save( - torch.cat(preds_list), + _concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), ) if labels_list[0] is not None: torch.save( - torch.cat(labels_list), + _concat_tuple(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), )