chebai/loss/bce_weighted.py (12 changes: 7 additions & 5 deletions)
@@ -43,8 +43,10 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             self.beta is not None
             and self.data_extractor is not None
             and all(
-                os.path.exists(os.path.join(self.data_extractor.raw_dir, raw_file))
-                for raw_file in self.data_extractor.raw_file_names
+                os.path.exists(
+                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                )
+                for file_name in self.data_extractor.processed_main_file_names
             )
             and self.pos_weight is None
         ):
@@ -53,13 +55,13 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
                 pd.read_pickle(
                     open(
                         os.path.join(
-                            self.data_extractor.raw_dir,
-                            raw_file_name,
+                            self.data_extractor.processed_dir_main,
+                            file_name,
                         ),
                         "rb",
                     )
                 )
-                for raw_file_name in self.data_extractor.raw_file_names
+                for file_name in self.data_extractor.processed_main_file_names
             ]
         )
         value_counts = []
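For context, a minimal self-contained sketch of the access pattern these hunks establish: the positive-class weight statistics are now derived from the pickled data in `processed_dir_main` instead of the raw files. The helper name and the error handling below are illustrative, not part of the PR.

import os
import pandas as pd

def load_processed_frames(data_extractor) -> pd.DataFrame:
    # Mirror of the guard in the hunk: only read once every processed
    # (pre-tokenization) file exists under processed_dir_main.
    paths = [
        os.path.join(data_extractor.processed_dir_main, file_name)
        for file_name in data_extractor.processed_main_file_names
    ]
    if not all(os.path.exists(p) for p in paths):
        raise FileNotFoundError("run prepare_data first")  # illustrative
    # Same read-and-concat step as in set_pos_weight.
    return pd.concat(pd.read_pickle(open(p, "rb")) for p in paths)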
chebai/preprocessing/datasets/base.py (93 changes: 45 additions & 48 deletions)
@@ -134,10 +134,15 @@ def base_dir(self) -> str:
             return self._base_dir
         return os.path.join("data", self._name)
 
+    @property
+    def processed_dir_main(self) -> str:
+        """Name of the directory where processed (but not tokenized) data is stored."""
+        return os.path.join(self.base_dir, "processed")
+
     @property
     def processed_dir(self) -> str:
-        """Name of the directory where the processed data is stored."""
-        return os.path.join(self.base_dir, "processed", *self.identifier)
+        """Name of the directory where the processed and tokenized data is stored."""
+        return os.path.join(self.processed_dir_main, *self.identifier)
 
     @property
     def raw_dir(self) -> str:
@@ -394,45 +399,61 @@ def setup_processed(self):
         raise NotImplementedError
 
     @property
-    def processed_file_names(self) -> List[str]:
+    def processed_main_file_names_dict(self) -> dict:
         """
-        Returns the list of processed file names.
-
-        This property should be implemented by subclasses to provide the list of processed file names.
+        Returns a dictionary mapping processed data file names.
 
         Returns:
-            List[str]: The list of processed file names.
+            dict: A dictionary mapping dataset key to their respective file names.
+                For example, {"data": "data.pkl"}.
         """
         raise NotImplementedError
 
     @property
-    def raw_file_names(self) -> List[str]:
+    def processed_main_file_names(self) -> List[str]:
         """
-        Returns the list of raw file names.
-
-        This property should be implemented by subclasses to provide the list of raw file names.
+        Returns a list of file names for processed data (before tokenization).
 
         Returns:
-            List[str]: The list of raw file names.
+            List[str]: A list of file names corresponding to the processed data.
         """
-        raise NotImplementedError
+        return list(self.processed_main_file_names_dict.values())
 
     @property
     def processed_file_names_dict(self) -> dict:
         """
-        Returns the dictionary of processed file names.
-
-        This property should be implemented by subclasses to provide the dictionary of processed file names.
+        Returns a dictionary for the processed and tokenized data files.
 
         Returns:
-            dict: The dictionary of processed file names.
+            dict: A dictionary mapping dataset keys to their respective file names.
+                For example, {"data": "data.pt"}.
         """
         raise NotImplementedError
 
     @property
+    def processed_file_names(self) -> List[str]:
+        """
+        Returns a list of file names for processed data.
+
+        Returns:
+            List[str]: A list of file names corresponding to the processed data.
+        """
+        return list(self.processed_file_names_dict.values())
+
+    @property
+    def raw_file_names(self) -> List[str]:
+        """
+        Returns the list of raw file names.
+
+        Returns:
+            List[str]: The list of raw file names.
+        """
+        return list(self.raw_file_names_dict.values())
+
+    @property
     def raw_file_names_dict(self) -> dict:
         """
-        Returns the dictionary of raw file names.
+        Returns the dictionary of raw file names (i.e., files that are directly obtained from an external source).
 
         This property should be implemented by subclasses to provide the dictionary of raw file names.
 
@@ -705,7 +726,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
         """
         print("Checking for processed data in", self.processed_dir_main)
 
-        processed_name = self.processed_dir_main_file_names_dict["data"]
+        processed_name = self.processed_main_file_names_dict["data"]
         if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
             print("Missing processed data file (`data.pkl` file)")
             os.makedirs(self.processed_dir_main, exist_ok=True)
@@ -796,7 +817,7 @@ def setup_processed(self) -> None:
             self._load_data_from_file(
                 os.path.join(
                     self.processed_dir_main,
-                    self.processed_dir_main_file_names_dict["data"],
+                    self.processed_main_file_names_dict["data"],
                 )
             ),
             os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
@@ -1118,47 +1139,23 @@ def processed_dir_main(self) -> str:
         )
 
     @property
-    def processed_dir(self) -> str:
+    def processed_main_file_names_dict(self) -> dict:
         """
-        Returns the specific directory path for processed data, including identifiers.
+        Returns a dictionary mapping processed data file names.
 
         Returns:
-            str: The path to the processed data directory, including additional identifiers.
-        """
-        return os.path.join(
-            self.processed_dir_main,
-            *self.identifier,
-        )
-
-    @property
-    def processed_dir_main_file_names_dict(self) -> dict:
-        """
-        Returns a dictionary mapping processed data file names, processed by `prepare_data` method.
-
-        Returns:
-            dict: A dictionary mapping dataset types to their respective processed file names.
+            dict: A dictionary mapping dataset key to their respective file names.
                 For example, {"data": "data.pkl"}.
         """
         return {"data": "data.pkl"}
 
     @property
     def processed_file_names_dict(self) -> dict:
         """
-        Returns a dictionary mapping processed and transformed data file names to their final formats, which are
-        processed by `setup` method.
+        Returns a dictionary for the processed and tokenized data files.
 
         Returns:
-            dict: A dictionary mapping dataset types to their respective final file names.
+            dict: A dictionary mapping dataset keys to their respective file names.
                 For example, {"data": "data.pt"}.
         """
         return {"data": "data.pt"}
-
-    @property
-    def processed_file_names(self) -> List[str]:
-        """
-        Returns a list of file names for processed data.
-
-        Returns:
-            List[str]: A list of file names corresponding to the processed data.
-        """
-        return list(self.processed_file_names_dict.values())
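Taken together, the base.py changes settle on one naming scheme: the `*_dict` properties are the single source of truth and the list variants are derived from them, while `processed_dir_main` holds the processed-but-untokenized data.pkl and `processed_dir` holds the tokenized data.pt. A compact sketch of the resulting contract; the class name and identifier value here are illustrative, not from the PR.

import os
from typing import List

class ToyDataModule:
    """Illustrative stand-in for the data module's directory/file contract."""

    _name = "toy"
    identifier = ("smiles_token",)  # hypothetical reader identifier

    @property
    def base_dir(self) -> str:
        return os.path.join("data", self._name)

    @property
    def processed_dir_main(self) -> str:
        # processed but not tokenized: holds data.pkl
        return os.path.join(self.base_dir, "processed")

    @property
    def processed_dir(self) -> str:
        # processed and tokenized: holds data.pt
        return os.path.join(self.processed_dir_main, *self.identifier)

    @property
    def processed_main_file_names_dict(self) -> dict:
        return {"data": "data.pkl"}

    @property
    def processed_main_file_names(self) -> List[str]:
        # list variant is always derived from the dict variant
        return list(self.processed_main_file_names_dict.values())

print(ToyDataModule().processed_dir)  # e.g. data/toy/processed/smiles_token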
chebai/preprocessing/datasets/chebi.py (10 changes: 6 additions & 4 deletions)
@@ -185,7 +185,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
             if not os.path.isfile(
                 os.path.join(
                     self._chebi_version_train_obj.processed_dir_main,
-                    self._chebi_version_train_obj.processed_dir_main_file_names_dict[
+                    self._chebi_version_train_obj.processed_main_file_names_dict[
                         "data"
                     ],
                 )
@@ -216,9 +216,7 @@ def _load_chebi(self, version: int) -> str:
         Returns:
             str: The file path of the loaded ChEBI ontology.
         """
-        chebi_name = (
-            f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo"
-        )
+        chebi_name = self.raw_file_names_dict["chebi"]
         chebi_path = os.path.join(self.raw_dir, chebi_name)
         if not os.path.isfile(chebi_path):
             print(
@@ -540,6 +538,10 @@ def processed_dir(self) -> str:
         else:
             return os.path.join(res, f"single_{self.single_class}")
 
+    @property
+    def raw_file_names_dict(self) -> dict:
+        return {"chebi": "chebi.obo"}
+
 
 class JCIExtendedBase(_ChEBIDataExtractor):
 
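With the new override in place, `_load_chebi` no longer builds the ontology file name inline. A condensed sketch of the lookup; the directory value is illustrative and the download branch is elided, since the hunk does not show it.

import os

class ChebiSketch:
    """Condensed illustration of the raw-file lookup after this change."""

    raw_dir = os.path.join("data", "chebi", "raw")  # illustrative path

    @property
    def raw_file_names_dict(self) -> dict:
        return {"chebi": "chebi.obo"}

    def _load_chebi(self) -> str:
        chebi_name = self.raw_file_names_dict["chebi"]
        chebi_path = os.path.join(self.raw_dir, chebi_name)
        # The real method downloads the ontology here if the file is missing.
        return chebi_path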
chebai/preprocessing/datasets/protein_pretraining.py (2 changes: 1 addition & 1 deletion)
@@ -64,7 +64,7 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
             *args: Additional positional arguments.
             **kwargs: Additional keyword arguments.
         """
-        processed_name = self.processed_dir_main_file_names_dict["data"]
+        processed_name = self.processed_main_file_names_dict["data"]
         if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)):
             print("Missing processed data file (`data.pkl` file)")
             os.makedirs(self.processed_dir_main, exist_ok=True)
chebai/result/utils.py (19 changes: 13 additions & 6 deletions)
@@ -66,6 +66,13 @@ def _run_batch(batch, model, collate):
     return preds, labels
 
 
+def _concat_tuple(l):
+    if isinstance(l[0], tuple):
+        print(l[0])
+        return tuple([torch.cat([t[i] for t in l]) for i in range(len(l[0]))])
+    return torch.cat(l)
+
+
 def evaluate_model(
     model: ChebaiBaseNet,
     data_module: XYBaseDataModule,
@@ -125,12 +132,12 @@ def evaluate_model(
         if buffer_dir is not None:
             if n_saved * batch_size >= save_batch_size:
                 torch.save(
-                    torch.cat(preds_list),
+                    _concat_tuple(preds_list),
                     os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
                 )
                 if labels_list[0] is not None:
                     torch.save(
-                        torch.cat(labels_list),
+                        _concat_tuple(labels_list),
                         os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
                     )
                 preds_list = []
@@ -141,20 +148,20 @@
             n_saved += 1
 
     if buffer_dir is None:
-        test_preds = torch.cat(preds_list)
+        test_preds = _concat_tuple(preds_list)
         if labels_list is not None:
-            test_labels = torch.cat(labels_list)
+            test_labels = _concat_tuple(labels_list)
 
             return test_preds, test_labels
         return test_preds, None
     else:
         torch.save(
-            torch.cat(preds_list),
+            _concat_tuple(preds_list),
             os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
         )
         if labels_list[0] is not None:
             torch.save(
-                torch.cat(labels_list),
+                _concat_tuple(labels_list),
                 os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
             )
 
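A short, self-contained demonstration of what the new `_concat_tuple` helper handles: prediction batches may be plain tensors or tuples of tensors (for instance a model returning several outputs per batch — an assumption, since the model's output shape is not shown in this PR), and the tuple case is concatenated position-wise. The stray debug `print` from the hunk is omitted here.

import torch

def _concat_tuple(batches):
    if isinstance(batches[0], tuple):
        # Tuple case: concatenate element-wise, i.e. all first elements,
        # then all second elements, and so on.
        return tuple(
            torch.cat([t[i] for t in batches]) for i in range(len(batches[0]))
        )
    # Plain-tensor case: behaves exactly like torch.cat.
    return torch.cat(batches)

plain = [torch.zeros(4, 3), torch.zeros(2, 3)]
paired = [(torch.zeros(4, 3), torch.zeros(4, 8)), (torch.zeros(2, 3), torch.zeros(2, 8))]

assert _concat_tuple(plain).shape == (6, 3)
assert _concat_tuple(paired)[1].shape == (6, 8)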