Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Fix use of scalar tensors in ConllCorefScores (#1604)
Browse files Browse the repository at this point in the history
* Added test for #1545, fixed test, and improved documentation

* pylint

* mypy
  • Loading branch information
matt-gardner authored and DeNeutoy committed Aug 14, 2018
1 parent 982cedd commit 4eaeff7
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 12 deletions.
19 changes: 19 additions & 0 deletions allennlp/tests/training/metrics/conll_coref_scores_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# pylint: disable=no-self-use,invalid-name,protected-access
import torch

from allennlp.common.testing import AllenNlpTestCase
from allennlp.training.metrics import ConllCorefScores

class ConllCorefScoresTest(AllenNlpTestCase):
    """Tests for the static helpers on ``ConllCorefScores``."""

    def test_get_predicted_clusters(self):
        """
        ``get_predicted_clusters`` should group a span with its predicted antecedent.

        Only the last span has a non-negative predicted antecedent, so exactly one
        cluster is expected, holding that span and the span it points to.
        """
        # Three candidate spans, each a (start, end) pair.
        span_rows = [[0, 1], [4, 6], [8, 9]]
        # Row i lists the antecedent candidates of span i; -1 fills unused slots.
        antecedent_rows = [[-1, -1, -1],
                           [0, -1, -1],
                           [0, 1, -1]]
        # Per-span column index into the antecedent table. The last span picks
        # column 1 of its row, which holds span index 1, i.e. the span (4, 6).
        chosen_columns = [-1, -1, 1]

        top_spans = torch.Tensor(span_rows).long()
        antecedent_indices = torch.Tensor(antecedent_rows).long()
        predicted_antecedents = torch.Tensor(chosen_columns).long()

        result = ConllCorefScores.get_predicted_clusters(top_spans,
                                                         antecedent_indices,
                                                         predicted_antecedents)
        found_clusters, span_to_cluster = result

        assert len(found_clusters) == 1
        only_cluster = found_clusters[0]
        assert set(only_cluster) == {(4, 6), (8, 9)}
        # Both mentions must map to that single cluster.
        assert span_to_cluster == {(4, 6): only_cluster, (8, 9): only_cluster}
58 changes: 46 additions & 12 deletions allennlp/training/metrics/conll_coref_scores.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple
from collections import Counter
import numpy as np
from sklearn.utils.linear_assignment_ import linear_assignment

from overrides import overrides
from sklearn.utils.linear_assignment_ import linear_assignment
import numpy as np
import torch

from allennlp.training.metrics.metric import Metric

Expand All @@ -13,7 +14,29 @@ def __init__(self) -> None:
self.scorers = [Scorer(m) for m in (Scorer.muc, Scorer.b_cubed, Scorer.ceafe)]

@overrides
def __call__(self, top_spans, antecedent_indices, predicted_antecedents, metadata_list):
def __call__(self, # type: ignore
top_spans: torch.Tensor,
antecedent_indices: torch.Tensor,
predicted_antecedents: torch.Tensor,
metadata_list: List[Dict[str, Any]]):
"""
Parameters
----------
top_spans : ``torch.Tensor``
(start, end) indices for all spans kept after span pruning in the model.
Expected shape: (batch_size, num_spans, 2)
antecedent_indices : ``torch.Tensor``
For each span, the indices of all allowed antecedents for that span. This is
independent of the batch dimension, as it's just based on order in the document.
Expected shape: (num_spans, num_antecedents)
predicted_antecedents: ``torch.Tensor``
For each span, this contains the index (into antecedent_indices) of the most likely
antecedent for that span.
Expected shape: (batch_size, num_spans)
metadata_list : ``List[Dict[str, Any]]``
A metadata dictionary for each instance in the batch. We use the "clusters" key from
this dictionary, which has the annotated gold coreference clusters for that instance.
"""
top_spans, antecedent_indices, predicted_antecedents = self.unwrap_to_tensors(top_spans,
antecedent_indices,
predicted_antecedents)
Expand Down Expand Up @@ -48,7 +71,17 @@ def get_gold_clusters(gold_clusters):
return gold_clusters, mention_to_gold

@staticmethod
def get_predicted_clusters(top_spans, antecedent_indices, predicted_antecedents):
def get_predicted_clusters(top_spans: torch.Tensor,
antecedent_indices: torch.Tensor,
predicted_antecedents: torch.Tensor) -> Tuple[List[Tuple[Tuple[int, int], ...]],
Dict[Tuple[int, int],
Tuple[Tuple[int, int], ...]]]:
# Pytorch 0.4 introduced scalar tensors, so our calls to tuple() and such below don't
# actually give ints unless we convert to numpy first. So we do that here.
top_spans = top_spans.numpy() # (num_spans, 2)
antecedent_indices = antecedent_indices.numpy() # (num_spans, num_antecedents)
predicted_antecedents = predicted_antecedents.numpy() # (num_spans,)

predicted_clusters_to_ids: Dict[Tuple[int, int], int] = {}
clusters: List[List[Tuple[int, int]]] = []
for i, predicted_antecedent in enumerate(predicted_antecedents):
Expand All @@ -59,7 +92,7 @@ def get_predicted_clusters(top_spans, antecedent_indices, predicted_antecedents)
predicted_index = antecedent_indices[i, predicted_antecedent]
# Must be a previous span.
assert i > predicted_index
antecedent_span = tuple(top_spans[predicted_index])
antecedent_span: Tuple[int, int] = tuple(top_spans[predicted_index]) # type: ignore

# Check if we've seen the span before.
if antecedent_span in predicted_clusters_to_ids.keys():
Expand All @@ -70,18 +103,19 @@ def get_predicted_clusters(top_spans, antecedent_indices, predicted_antecedents)
clusters.append([antecedent_span])
predicted_clusters_to_ids[antecedent_span] = predicted_cluster_id

mention = tuple(top_spans[i])
mention: Tuple[int, int] = tuple(top_spans[i]) # type: ignore
clusters[predicted_cluster_id].append(mention)
predicted_clusters_to_ids[mention] = predicted_cluster_id

# finalise the spans and clusters.
clusters = [tuple(cluster) for cluster in clusters]
final_clusters = [tuple(cluster) for cluster in clusters]
# Return a mapping of each mention to the cluster containing it.
predicted_clusters_to_ids: Dict[Tuple[int, int], List[Tuple[int, int]]] = \
{mention: clusters[cluster_id] for mention, cluster_id
in predicted_clusters_to_ids.items()}
mention_to_cluster: Dict[Tuple[int, int], Tuple[Tuple[int, int], ...]] = {
mention: final_clusters[cluster_id]
for mention, cluster_id in predicted_clusters_to_ids.items()
}

return clusters, predicted_clusters_to_ids
return final_clusters, mention_to_cluster


class Scorer:
Expand Down

0 comments on commit 4eaeff7

Please sign in to comment.