
How to integrate CompositionalMetric into a LightningModule? #1108

Closed
celsofranssa opened this issue Jun 22, 2022 · 3 comments
Labels
documentation Improvements or additions to documentation


@celsofranssa

Currently, I am computing two metrics in validation_epoch_end, as shown in the code snippet below.

class LTModel(pl.LightningModule):

    def __init__(self, hparams):
        super(LTModel, self).__init__()
        self.entity_cls_head = ...  # entity classifier
        self.relation_cls_head = ...  # relation classifier

        # metrics
        self.entity_metrics = self._get_metrics(prefix="entity_")
        self.relation_metrics = self._get_metrics(prefix="relation_")

        # loss
        self.loss = ...

    def _get_metrics(self, prefix):
        return MetricCollection(
            metrics={
                "Wei-F1": F1(num_classes=self.hparams.num_entities, average="weighted"),
                ...
            },
            prefix=prefix)

However, I would like to aggregate these two metrics into a single one using CompositionalMetric.
Is there an example I can use to achieve this goal?

@celsofranssa celsofranssa added the documentation Improvements or additions to documentation label Jun 22, 2022
@SkafteNicki
Member

Hi @celsofranssa,
CompositionalMetric is not meant to be used directly. Instead, all metrics support composition through simple arithmetic operations (https://torchmetrics.readthedocs.io/en/stable/pages/overview.html#metric-arithmetics). So if you want the sum of two metrics, you just do:

metric = MetricClass1() + MetricClass2()

Calling metric.compute() will then essentially correspond to calling MetricClass1().compute() + MetricClass2().compute().

However, this only works for base metrics, not for collections of metrics (because collections can have different sizes). What you could do is something like this:

class LTModel(pl.LightningModule):

    def __init__(self, hparams):
        super(LTModel, self).__init__()
        self.entity_cls_head = ...  # entity classifier
        self.relation_cls_head = ...  # relation classifier

        # metrics
        self.entity_metrics = self._get_metrics(prefix="entity_")
        self.relation_metrics = self._get_metrics(prefix="relation_")
        self.composition_metrics = MetricCollection({})
        for k in self.entity_metrics:
            self.composition_metrics.add_metrics({k: self.entity_metrics[k] + self.relation_metrics[k]})

@celsofranssa
Author

For now, I have implemented a custom metric combining the entity and relation metrics (as each of them has its own preds and targets):

from torchmetrics import Metric, F1

class FERMetric(Metric):
    def __init__(self, params):
        super().__init__()
        self.entity_metric = F1(num_classes=params.num_entities, average=params.average)
        self.relation_metric = F1(num_classes=params.num_relations, average=params.average)

    def update(self, entity_pred, entity_true, relation_pred, relation_true):
        self.entity_metric.update(entity_pred, entity_true)
        self.relation_metric.update(relation_pred, relation_true)

    def compute(self):
        return 0.5 * (self.entity_metric.compute() + self.relation_metric.compute())

However, something seems wrong:

Epoch 6:  36%|██               | ... entity_Mac-F1=0.588, relation_Mac-F1=0.477, val_FER=0.260]

since 0.5 * (entity_Mac-F1 + relation_Mac-F1) should be equal to val_FER, but 0.5 * (0.588 + 0.477) = 0.5325, not 0.260.
What am I missing?

@SkafteNicki
Member

@celsofranssa we do not explicitly support nesting of metrics, so I guess that is the problem here: the underlying logic is not supported. Could you try also implementing the reset method:

def reset(self):
    self.entity_metric.reset()
    self.relation_metric.reset()

reset normally only resets state that was added with self.add_state, so I assume this is what is wrong (the two sub-metrics are never reset between epochs).
