From 83fe459e82324eced6aa74c7981bc6f7b7e4a45d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 8 Oct 2025 18:04:10 +0200 Subject: [PATCH 1/2] persistent workers can be set through CLI for GNI --- chebai/preprocessing/datasets/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 12eb634c..4e1b9e2a 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -76,6 +76,7 @@ def __init__( label_filter: Optional[int] = None, balance_after_filter: Optional[float] = None, num_workers: int = 1, + persistent_workers=True, chebi_version: int = 200, inner_k_folds: int = -1, # use inner cross-validation if > 1 fold_index: Optional[int] = None, @@ -99,6 +100,7 @@ def __init__( ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers + self.persistent_workers: bool = bool(persistent_workers) self.chebi_version = chebi_version assert type(inner_k_folds) is int self.inner_k_folds = inner_k_folds @@ -360,7 +362,7 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader "train", shuffle=True, num_workers=self.num_workers, - persistent_workers=True, + persistent_workers=self.persistent_workers, **kwargs, ) @@ -379,7 +381,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]] "validation", shuffle=False, num_workers=self.num_workers, - persistent_workers=True, + persistent_workers=self.persistent_workers, **kwargs, ) From c72a2d9217da66497e772cfed3c812ae2d042a34 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 10 Oct 2025 12:58:52 +0200 Subject: [PATCH 2/2] typehint for argparse --- chebai/preprocessing/datasets/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4e1b9e2a..3ac0a803 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -76,7 +76,7 @@ def __init__( label_filter: Optional[int] = None, balance_after_filter: Optional[float] = None, num_workers: int = 1, - persistent_workers=True, + persistent_workers: bool = True, chebi_version: int = 200, inner_k_folds: int = -1, # use inner cross-validation if > 1 fold_index: Optional[int] = None,