Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,7 @@ lightning_logs
logs
.isort.cfg
/.vscode

*.out
*.err
*.sh
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
You can evaluate a model trained on the ontology extension task in one of two ways:

### 1. Using the Jupyter Notebook
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
- Load your finetuned model and run the evaluation cells to compute metrics on the test set.

### 2. Using the Lightning CLI
Expand Down
4 changes: 4 additions & 0 deletions chebai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def call_data_methods(data: Type[XYBaseDataModule]):
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
)

parser.link_arguments(
"data", "model.init_args.criterion.init_args.data_extractor"
)

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
"""
Expand Down
36 changes: 23 additions & 13 deletions chebai/loss/bce_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
"""
BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
If beta is None or data_extractor is None, the loss is unweighted.

This class computes weights based on the formula from the paper by Cui et al. (2019):
https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
Expand All @@ -22,7 +21,7 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):

def __init__(
self,
beta: Optional[float] = None,
beta: float = 0.99,
data_extractor: Optional[XYBaseDataModule] = None,
**kwargs,
):
Expand All @@ -32,11 +31,26 @@ def __init__(
if isinstance(data_extractor, LabeledUnlabeledMixed):
data_extractor = data_extractor.labeled
self.data_extractor = data_extractor

assert (
isinstance(beta, float) and beta > 0.0
), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."

assert (
self.data_extractor is not None
), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."

assert all(
os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
for file_name in self.data_extractor.processed_file_names
), "Dataset files not found. Make sure the dataset is processed before using this loss."

assert (
isinstance(self.data_extractor, _ChEBIDataExtractor)
or self.data_extractor is None
)
super().__init__(**kwargs)
self.pos_weight: Optional[torch.Tensor] = None

def set_pos_weight(self, input: torch.Tensor) -> None:
"""
Expand All @@ -45,17 +59,7 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
Args:
input (torch.Tensor): The input tensor for which to set the positive weights.
"""
if (
self.beta is not None
and self.data_extractor is not None
and all(
os.path.exists(
os.path.join(self.data_extractor.processed_dir, file_name)
)
for file_name in self.data_extractor.processed_file_names
)
and self.pos_weight is None
):
if self.pos_weight is None:
print(
f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
)
Expand Down Expand Up @@ -96,3 +100,9 @@ def forward(
"""
self.set_pos_weight(input)
return super().forward(input, target)


class UnWeightedBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
def forward(self, input, target, **kwargs):
# As the custom passed kwargs are not used in BCEWithLogitsLoss, we can ignore them
return super().forward(input, target)
1 change: 1 addition & 0 deletions tests/unit/cli/bce_loss.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
class_path: chebai.loss.bce_weighted.UnWeightedBCEWithLogitsLoss
3 changes: 1 addition & 2 deletions tests/unit/cli/testCLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ def setUp(self):
"--model.pass_loss_kwargs=false",
"--trainer.min_epochs=1",
"--trainer.max_epochs=1",
"--model.criterion=configs/loss/bce.yml",
"--model.criterion.init_args.beta=0.99",
"--model.criterion=tests/unit/cli/bce_loss.yml",
]

def test_mlp_on_chebai_cli(self):
Expand Down