diff --git a/.gitignore b/.gitignore
index bafec1d9..be471dd0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -176,3 +176,7 @@ lightning_logs
 logs
 .isort.cfg
 /.vscode
+
+*.out
+*.err
+*.sh
diff --git a/README.md b/README.md
index eeecd714..73179ac0 100644
--- a/README.md
+++ b/README.md
@@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 You can evaluate a model trained on the ontology extension task in one of two ways:
 
 ### 1. Using the Jupyter Notebook
-An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
 - Load your finetuned model and run the evaluation cells to compute metrics on the test set.
 
 ### 2. Using the Lightning CLI
diff --git a/chebai/cli.py b/chebai/cli.py
index 96262447..de9ef893 100644
--- a/chebai/cli.py
+++ b/chebai/cli.py
@@ -70,6 +70,10 @@ def call_data_methods(data: Type[XYBaseDataModule]):
             "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )
 
+        parser.link_arguments(
+            "data", "model.init_args.criterion.init_args.data_extractor"
+        )
+
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
         """
diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 993d535e..1f21b04b 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -10,7 +10,6 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
-    If beta is None or data_extractor is None, the loss is unweighted.
 
     This class computes weights based on the formula from the paper by Cui et al. (2019):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
 
@@ -22,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
 
     def __init__(
         self,
-        beta: Optional[float] = None,
+        beta: float = 0.99,
         data_extractor: Optional[XYBaseDataModule] = None,
         **kwargs,
     ):
@@ -32,11 +31,26 @@ def __init__(
         if isinstance(data_extractor, LabeledUnlabeledMixed):
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
+
+        assert (
+            isinstance(beta, float) and beta > 0.0
+        ), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+
+        assert (
+            self.data_extractor is not None
+        ), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
+
+        assert all(
+            os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
+            for file_name in self.data_extractor.processed_file_names
+        ), "Dataset files not found. Make sure the dataset is processed before using this loss."
+
         assert (
             isinstance(self.data_extractor, _ChEBIDataExtractor)
             or self.data_extractor is None
         )
         super().__init__(**kwargs)
+        self.pos_weight: Optional[torch.Tensor] = None
 
     def set_pos_weight(self, input: torch.Tensor) -> None:
         """
@@ -45,17 +59,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
         Args:
             input (torch.Tensor): The input tensor for which to set the positive weights.
         """
-        if (
-            self.beta is not None
-            and self.data_extractor is not None
-            and all(
-                os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir, file_name)
-                )
-                for file_name in self.data_extractor.processed_file_names
-            )
-            and self.pos_weight is None
-        ):
+        if self.pos_weight is None:
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
@@ -96,3 +100,9 @@ def forward(
         """
         self.set_pos_weight(input)
         return super().forward(input, target)
+
+
+class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
+    def forward(self, input, target, **kwargs):
+        # As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
+        return super().forward(input, target)
diff --git a/tests/unit/cli/bce_loss.yml b/tests/unit/cli/bce_loss.yml
new file mode 100644
index 00000000..ed0a00b6
--- /dev/null
+++ b/tests/unit/cli/bce_loss.yml
@@ -0,0 +1 @@
+class_path: chebai.loss.bce_weighted.UnWeightedBCEWithLogitsLoss
diff --git a/tests/unit/cli/testCLI.py b/tests/unit/cli/testCLI.py
index d76b5a33..863a6df3 100644
--- a/tests/unit/cli/testCLI.py
+++ b/tests/unit/cli/testCLI.py
@@ -15,8 +15,7 @@ def setUp(self):
             "--model.pass_loss_kwargs=false",
             "--trainer.min_epochs=1",
             "--trainer.max_epochs=1",
-            "--model.criterion=configs/loss/bce.yml",
-            "--model.criterion.init_args.beta=0.99",
+            "--model.criterion=tests/unit/cli/bce_loss.yml",
         ]
 
     def test_mlp_on_chebai_cli(self):