Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 67 additions & 7 deletions chebai/preprocessing/datasets/chebi.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,7 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame:
data = pd.DataFrame(data)
data = data[~data["SMILES"].isnull()]
data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]]
# This filters the DataFrame to include only the rows where at least one value in the row from 4th column
# onwards is True/non-zero.
data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)]

return data

# ------------------------------ Phase: Setup data -----------------------------------
Expand Down Expand Up @@ -702,18 +700,24 @@ class ChEBIOverXPartial(ChEBIOverX):
top_class_id (int): The ID of the top class from which to extract subclasses.
"""

def __init__(self, top_class_id: int, **kwargs):
def __init__(self, top_class_id: int, external_data_ratio: float, **kwargs):
    """
    Initializes the ChEBIOverXPartial dataset.

    Args:
        top_class_id (int): The ID of the top class from which to extract subclasses.
        external_data_ratio (float): How much external data (i.e., samples for which
            top_class_id is not a positive label) to include in the dataset. Must lie
            in [0, 1]: 0 means no external data, 1 means the maximum amount (i.e., the
            complete ChEBI dataset).
        **kwargs: Additional keyword arguments passed to the superclass initializer.

    Raises:
        ValueError: If external_data_ratio is not within [0, 1] (this also rejects NaN).
    """
    # Fail early on out-of-range ratios: the documented contract only defines
    # values from 0 (no external data) to 1 (complete dataset). The chained
    # comparison is False for NaN as well, so NaN is rejected too.
    if not 0 <= external_data_ratio <= 1:
        raise ValueError(
            f"external_data_ratio must be within [0, 1], got {external_data_ratio!r}"
        )

    # Both values are named parameters, so they can never arrive through
    # **kwargs; add them unconditionally so the superclass initializer
    # receives them together with the remaining keyword arguments.
    kwargs["top_class_id"] = top_class_id
    kwargs["external_data_ratio"] = external_data_ratio

    self.top_class_id: int = top_class_id
    self.external_data_ratio: float = external_data_ratio
    super().__init__(**kwargs)

@property
Expand All @@ -727,7 +731,7 @@ def processed_dir_main(self) -> str:
return os.path.join(
self.base_dir,
self._name,
f"partial_{self.top_class_id}",
f"partial_{self.top_class_id}_ext_ratio_{self.external_data_ratio:.2f}",
"processed",
)

Expand All @@ -746,9 +750,53 @@ def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph:
descendants of the top class ID.
"""
g = super()._extract_class_hierarchy(chebi_path)
g = g.subgraph(list(g.successors(self.top_class_id)) + [self.top_class_id])
top_class_successors = list(g.successors(self.top_class_id)) + [
self.top_class_id
]
external_nodes = list(set(n for n in g.nodes if n not in top_class_successors))
if 0 < self.external_data_ratio < 1:
n_external_nodes = int(
len(top_class_successors)
* self.external_data_ratio
/ (1 - self.external_data_ratio)
)
print(
f"Extracting {n_external_nodes} external nodes from the ChEBI dataset (ratio: {self.external_data_ratio:.2f})"
)
external_nodes = external_nodes[: int(n_external_nodes)]
elif self.external_data_ratio == 0:
external_nodes = []

g = g.subgraph(top_class_successors + external_nodes)
print(
f"Subgraph contains {len(g.nodes)} nodes, of which {len(top_class_successors)} are subclasses of the top class ID {self.top_class_id}."
)
return g

def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
    """
    Select the label classes of the partial dataset and write them to disk.

    A node of `g` is selected when both conditions hold:
      * at least `self.THRESHOLD` of its successors have a "smiles" attribute
        that is not None, and
      * `self.top_class_id` is one of its predecessors, or the node is
        `self.top_class_id` itself.

    The selected node ids are written, one per line, to `classes.txt` inside
    `self.processed_dir_main`.

    Args:
        g (nx.DiGraph): The class hierarchy graph.
        *args: Not used by this method.
        **kwargs: Not used by this method.

    Returns:
        List: The selected nodes, sorted in ascending order.
    """
    smiles_by_node = nx.get_node_attributes(g, "smiles")

    def has_enough_smiles(candidate) -> bool:
        # Count the successors whose "smiles" attribute is set.
        n_with_smiles = sum(
            1
            for child in g.successors(candidate)
            if smiles_by_node[child] is not None
        )
        return n_with_smiles >= self.THRESHOLD

    def is_in_top_class(candidate) -> bool:
        return candidate == self.top_class_id or (
            self.top_class_id in g.predecessors(candidate)
        )

    selected = set()
    for candidate in g.nodes:
        if has_enough_smiles(candidate) and is_in_top_class(candidate):
            selected.add(candidate)
    ordered = sorted(selected)

    classes_path = os.path.join(self.processed_dir_main, "classes.txt")
    with open(classes_path, "wt") as fout:
        fout.writelines([str(node) + "\n" for node in ordered])
    return ordered


class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50):
"""
Expand Down Expand Up @@ -842,7 +890,7 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:


atom_index = (
"\*",
r"\*",
"H",
"He",
"Li",
Expand Down Expand Up @@ -1473,3 +1521,15 @@ def term_callback(doc: fastobo.term.TermFrame) -> Union[Dict, bool]:
]

# Integer form of each JCI_500_COLUMNS entry: the part after the last ":".
JCI_500_COLUMNS_INT = [
    int(column.rsplit(":", 1)[-1]) for column in JCI_500_COLUMNS
]

if __name__ == "__main__":
    # Manual run: build a ChEBIOver50Partial data module for top class 22712
    # with an external data ratio of 0.5, then prepare and set up its data.
    splits_csv = os.path.join(
        "data", "chebi_v241", "ChEBI50", "splits_80_10_10.csv"
    )
    partial_module = ChEBIOver50Partial(
        chebi_version=241,
        splits_file_path=splits_csv,
        top_class_id=22712,
        external_data_ratio=0.5,
    )
    partial_module.prepare_data()
    partial_module.setup()
4 changes: 3 additions & 1 deletion tests/unit/dataset_classes/testChebiOverXPartial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def setUpClass(cls, mock_makedirs) -> None:
"""
Set up the ChEBIOverXPartial instance with a mock processed directory path and a test graph.
"""
cls.chebi_extractor = ChEBIOverXPartial(top_class_id=11111, chebi_version=231)
cls.chebi_extractor = ChEBIOverXPartial(
top_class_id=11111, external_data_ratio=0, chebi_version=231
)
cls.test_graph = ChebiMockOntology.get_transitively_closed_graph()

@patch(
Expand Down