From 3166bb36105fb73dbdc27c4968645732a2f06e64 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 1 Nov 2025 11:44:07 +0100 Subject: [PATCH 1/7] add error when beta is set but not dataextractor --- chebai/loss/bce_weighted.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index 993d535e..de202ff0 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -32,6 +32,11 @@ def __init__( if isinstance(data_extractor, LabeledUnlabeledMixed): data_extractor = data_extractor.labeled self.data_extractor = data_extractor + + # If beta is provided, require a data_extractor. + if self.beta is not None and self.data_extractor is None: + raise ValueError("When 'beta' is set, 'data_extractor' must also be set.") + assert ( isinstance(self.data_extractor, _ChEBIDataExtractor) or self.data_extractor is None From 6c99441031faf2f6e166432c55016e4883296f30 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 1 Nov 2025 13:00:52 +0100 Subject: [PATCH 2/7] make beta mandatory for this loss func --- chebai/loss/bce_weighted.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index de202ff0..f35d95fd 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -33,6 +33,10 @@ def __init__( data_extractor = data_extractor.labeled self.data_extractor = data_extractor + assert ( + beta is not None + ), f"Beta parameter must be provided if this loss ({self.__class__.__name__}) is used." + # If beta is provided, require a data_extractor. if self.beta is not None and self.data_extractor is None: raise ValueError("When 'beta' is set, 'data_extractor' must also be set.") From 65c448b6ee568a4a107c4d2666ccb1f45a099aa1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 1 Nov 2025 13:12:33 +0100 Subject: [PATCH 3/7] bceweighted class should only be used if weighting is the intention, else another class is provided --- chebai/loss/bce_weighted.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index f35d95fd..e31b8dbc 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -10,7 +10,6 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss): """ BCEWithLogitsLoss with weights automatically computed according to the beta parameter. - If beta is None or data_extractor is None, the loss is unweighted. This class computes weights based on the formula from the paper by Cui et al. (2019): https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf @@ -33,19 +32,22 @@ def __init__( data_extractor = data_extractor.labeled self.data_extractor = data_extractor - assert ( - beta is not None - ), f"Beta parameter must be provided if this loss ({self.__class__.__name__}) is used." + assert self.beta is not None and self.data_extractor is not None, ( + f"Beta parameter must be provided along with data_extractor, " + f"if this loss class ({self.__class__.__name__}) is used." + ) - # If beta is provided, require a data_extractor. - if self.beta is not None and self.data_extractor is None: - raise ValueError("When 'beta' is set, 'data_extractor' must also be set.") + assert all( + os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name)) + for file_name in self.data_extractor.processed_file_names + ), "Dataset files not found. Make sure the dataset is processed before using this loss." assert ( isinstance(self.data_extractor, _ChEBIDataExtractor) or self.data_extractor is None ) super().__init__(**kwargs) + self.pos_weight: Optional[torch.Tensor] = None def set_pos_weight(self, input: torch.Tensor) -> None: """ @@ -54,17 +56,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None: Args: input (torch.Tensor): The input tensor for which to set the positive weights. """ - if ( - self.beta is not None - and self.data_extractor is not None - and all( - os.path.exists( - os.path.join(self.data_extractor.processed_dir, file_name) - ) - for file_name in self.data_extractor.processed_file_names - ) - and self.pos_weight is None - ): + if self.pos_weight is None: print( f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})" ) @@ -105,3 +97,9 @@ def forward( """ self.set_pos_weight(input) return super().forward(input, target) + + +class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss): + def forward(self, input, target, **kwargs): + # As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them + return super().forward(input, target) From c8ddfb6c5a73d7d1c7b39173f3c321d7e087f1b6 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 9 Nov 2025 12:47:52 +0100 Subject: [PATCH 4/7] mlp unit - use unweighted bce loss --- tests/unit/cli/bce_loss.yml | 1 + tests/unit/cli/testCLI.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 tests/unit/cli/bce_loss.yml diff --git a/tests/unit/cli/bce_loss.yml b/tests/unit/cli/bce_loss.yml new file mode 100644 index 00000000..ed0a00b6 --- /dev/null +++ b/tests/unit/cli/bce_loss.yml @@ -0,0 +1 @@ +class_path: chebai.loss.bce_weighted.UnWeightedBCEWithLogitsLoss diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py index d76b5a33..863a6df3 100644 --- a/tests/unit/cli/testCLI.py +++ b/tests/unit/cli/testCLI.py @@ -15,8 +15,7 @@ def setUp(self): "--model.pass_loss_kwargs=false", "--trainer.min_epochs=1", "--trainer.max_epochs=1", - "--model.criterion=configs/loss/bce.yml", - "--model.criterion.init_args.beta=0.99", + "--model.criterion=tests/unit/cli/bce_loss.yml", ] def test_mlp_on_chebai_cli(self): From 9844e6808e21f4640ef45acee9f8efbaacaa63fd Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 9 Nov 2025 13:05:44 +0100 Subject: [PATCH 5/7] link data extractor to weighted bce loss --- chebai/cli.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/chebai/cli.py b/chebai/cli.py index 96262447..de9ef893 100644 --- a/chebai/cli.py +++ b/chebai/cli.py @@ -70,6 +70,10 @@ def call_data_methods(data: Type[XYBaseDataModule]): "data.num_of_labels", "trainer.callbacks.init_args.num_labels" ) + parser.link_arguments( + "data", "model.init_args.criterion.init_args.data_extractor" + ) + @staticmethod def subcommands() -> Dict[str, Set[str]]: """ From e7ed0cdae08a2f7cf04d009d8b85cfefbbf2006d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 14 Nov 2025 13:00:09 +0100 Subject: [PATCH 6/7] fix beta type checks --- .gitignore | 4 ++++ chebai/loss/bce_weighted.py | 15 ++++++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index bafec1d9..d75fc498 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,7 @@ lightning_logs logs .isort.cfg /.vscode + +*.out +*.err +*.sh \ No newline at end of file diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index e31b8dbc..48fe5bf2 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -21,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss): def __init__( self, - beta: Optional[float] = None, + beta: float = 0.99, data_extractor: Optional[XYBaseDataModule] = None, **kwargs, ): @@ -32,15 +32,20 @@ def __init__( data_extractor = data_extractor.labeled self.data_extractor = data_extractor - assert self.beta is not None and self.data_extractor is not None, ( - f"Beta parameter must be provided along with data_extractor, " - f"if this loss class ({self.__class__.__name__}) is used." + assert isinstance(beta, float) and beta > 0.0, ( + f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}." + ) + + assert self.data_extractor is not None, ( + f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used." ) assert all( os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name)) for file_name in self.data_extractor.processed_file_names - ), "Dataset files not found. Make sure the dataset is processed before using this loss." + ), ( + "Dataset files not found. Make sure the dataset is processed before using this loss." + ) assert ( isinstance(self.data_extractor, _ChEBIDataExtractor) From a86e0f7d0a7adbf6b978a6e381ee1e4d0f1d2823 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 14 Nov 2025 13:03:57 +0100 Subject: [PATCH 7/7] pre-commit --- .gitignore | 4 ++-- README.md | 2 +- chebai/loss/bce_weighted.py | 16 +++++++--------- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index d75fc498..be471dd0 100644 --- a/.gitignore +++ b/.gitignore @@ -177,6 +177,6 @@ logs .isort.cfg /.vscode -*.out +*.out *.err -*.sh \ No newline at end of file +*.sh diff --git a/README.md b/README.md index eeecd714..73179ac0 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont You can evaluate a model trained on the ontology extension task in one of two ways: ### 1. Using the Jupyter Notebook -An example notebook is provided at `tutorials/eval_model_basic.ipynb`. +An example notebook is provided at `tutorials/eval_model_basic.ipynb`. - Load your finetuned model and run the evaluation cells to compute metrics on the test set. ### 2. Using the Lightning CLI diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index 48fe5bf2..1f21b04b 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -32,20 +32,18 @@ def __init__( data_extractor = data_extractor.labeled self.data_extractor = data_extractor - assert isinstance(beta, float) and beta > 0.0, ( - f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}." - ) + assert ( + isinstance(beta, float) and beta > 0.0 + ), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}." - assert self.data_extractor is not None, ( - f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used." - ) + assert ( + self.data_extractor is not None + ), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used." assert all( os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name)) for file_name in self.data_extractor.processed_file_names - ), ( - "Dataset files not found. Make sure the dataset is processed before using this loss." - ) + ), "Dataset files not found. Make sure the dataset is processed before using this loss." assert ( isinstance(self.data_extractor, _ChEBIDataExtractor)