How to integrate `CompositionalMetric` into a `LightningModule`?
#1108
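Hi @celsofranssa, a `CompositionalMetric` is created by applying arithmetic operators to metrics, e.g. `metric = MetricClass1() + MetricClass2()`; calling `metric` then updates both metrics and returns the sum of their values. However, this only works for base metrics and not for collections of metrics (because collections can have different sizes). What you could do is something like this:

```python
import pytorch_lightning as pl
from torchmetrics import MetricCollection


class LTModel(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.entity_cls_head = ...  # entity classifier (placeholder)
        self.relation_cls_head = ...  # relation classifier (placeholder)
        # metrics: `_get_metrics` is your own helper, assumed to return a MetricCollection
        self.entity_metrics = self._get_metrics(prefix="entity_")
        self.relation_metrics = self._get_metrics(prefix="relation_")
        # Combine the two collections key by key: adding two base metrics
        # yields a CompositionalMetric (adding two collections does not).
        self.composition_metrics = MetricCollection(
            {k: self.entity_metrics[k] + self.relation_metrics[k] for k in self.entity_metrics}
        )
```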
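To make the operator-based composition concrete, here is a minimal pure-Python sketch. It does not use torchmetrics: `ToyMetric`, `ToyMean`, `Composite` and `combine_by_key` are hypothetical names, and the sketch assumes that applying `+` to two metrics (or multiplying a metric by a scalar) yields a new metric that forwards `update`/`reset` to its operands and applies the operator in `compute`. Pairing two collections only works name by name, which is what the loop over `entity_metrics[k]` and `relation_metrics[k]` does.

```python
import operator


class ToyMetric:
    """Hypothetical minimal metric interface (not the torchmetrics API)."""

    def update(self, *args):
        raise NotImplementedError

    def compute(self):
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    # Arithmetic between metrics builds a composite metric instead of a number.
    def __add__(self, other):
        return Composite(operator.add, self, other)

    def __rmul__(self, scalar):
        return Composite(operator.mul, scalar, self)


class ToyMean(ToyMetric):
    """Running mean of the values passed to update()."""

    def __init__(self):
        self.reset()

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count

    def reset(self):
        self.total, self.count = 0.0, 0


def _value(operand):
    # Operands are either metrics (computed on demand) or plain constants.
    return operand.compute() if isinstance(operand, ToyMetric) else operand


class Composite(ToyMetric):
    """Applies `op` to the computed values of its two operands."""

    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right

    def _children(self):
        return [x for x in (self.left, self.right) if isinstance(x, ToyMetric)]

    def update(self, *args):
        # Every child metric receives the same inputs.
        for child in self._children():
            child.update(*args)

    def compute(self):
        return self.op(_value(self.left), _value(self.right))

    def reset(self):
        for child in self._children():
            child.reset()


def combine_by_key(first, second):
    """Pair two dicts of metrics name by name; mismatched names cannot be paired."""
    if first.keys() != second.keys():
        raise ValueError("both collections must contain the same metric names")
    return {name: first[name] + second[name] for name in first}


# Usage: the average of two running means as a single composite metric.
entity, relation = ToyMean(), ToyMean()
fer = 0.5 * (entity + relation)
entity.update(0.5)
relation.update(0.25)
print(fer.compute())  # 0.375
```

Note that in this toy the composite forwards the same inputs to both operands, so on its own it cannot feed entity predictions to one score and relation predictions to the other; that case needs a wrapper metric with its own `update` signature.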
Currently, I implemented a custom metric combining the entity and relation metrics (as each of them has its own `num_classes`):

```python
from torchmetrics import Metric, F1


class FERMetric(Metric):
    def __init__(self, params):
        super().__init__()
        self.entity_metric = F1(num_classes=params.num_entities, average=params.average)
        self.relation_metric = F1(num_classes=params.num_relations, average=params.average)

    def update(self, entity_pred, entity_true, relation_pred, relation_true):
        self.entity_metric.update(entity_pred, entity_true)
        self.relation_metric.update(relation_pred, relation_true)

    def compute(self):
        return 0.5 * (
            self.entity_metric.compute() + self.relation_metric.compute()
        )
```

However, the result does not seem right:

```
Epoch 6: 36%|██ | ... entity_Mac-F1=0.588, relation_Mac-F1=0.477, val_FER=0.260]
```

since `val_FER` should be the mean of the two scores, 0.5 * (0.588 + 0.477) ≈ 0.53.
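@celsofranssa we do not explicitly support nesting of metrics. I suspect this is the problem here, because the underlying logic does not handle it. Could you try implementing the `reset` method as well:

```python
def reset(self):
    super().reset()  # clear the wrapper's own (base-class) state
    # Nested metrics are not reset automatically, so reset them explicitly.
    self.entity_metric.reset()
    self.relation_metric.reset()
```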
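Why forwarding `reset` matters can be illustrated with a small pure-Python toy. It assumes, as the reply above suggests, that the base class's `reset` only clears state the wrapper itself owns, so child metrics stored as attributes keep accumulating across epochs unless the wrapper resets them. `RunningMean`, `PairWithoutReset`, `PairWithReset` and `epoch_two_value` are hypothetical names; the running mean stands in for an F1 score, and the toy only shows how stale child state leaks into the next epoch, not the exact numbers from the log.

```python
class RunningMean:
    """Stand-in for one child metric (e.g. an F1 score): accumulates until reset()."""

    def __init__(self):
        self.reset()

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count

    def reset(self):
        self.total, self.count = 0.0, 0


class PairWithoutReset:
    """Averages two children fed with different inputs, like FERMetric."""

    def __init__(self):
        self.entity_metric = RunningMean()
        self.relation_metric = RunningMean()

    def update(self, entity_value, relation_value):
        self.entity_metric.update(entity_value)
        self.relation_metric.update(relation_value)

    def compute(self):
        return 0.5 * (self.entity_metric.compute() + self.relation_metric.compute())

    def reset(self):
        # Assumed base behaviour: only the wrapper's own state (none here) is cleared.
        pass


class PairWithReset(PairWithoutReset):
    def reset(self):
        super().reset()
        # Forward the reset so the children start each epoch from scratch.
        self.entity_metric.reset()
        self.relation_metric.reset()


def epoch_two_value(metric):
    """Simulate two epochs and return the value reported for the second one."""
    metric.update(1.0, 0.5)  # epoch 1
    metric.compute()
    metric.reset()  # end of epoch 1
    metric.update(0.0, 0.0)  # epoch 2
    return metric.compute()


print(epoch_two_value(PairWithoutReset()))  # 0.375: epoch-1 data leaks into epoch 2
print(epoch_two_value(PairWithReset()))  # 0.0
```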
Currently, I am computing two metrics in `validation_epoch_end`. However, I would like to aggregate these two metrics into a single one using `CompositionalMetric`. Is there an example I can use to achieve this?