
Implementation of Weighted CRF Tagger (handling unbalanced datasets) #5676

Merged
merged 12 commits on Jul 14, 2022
3 changes: 3 additions & 0 deletions allennlp/modules/__init__.py
@@ -8,6 +8,9 @@
from allennlp.modules.backbones import Backbone
from allennlp.modules.bimpm_matching import BiMpmMatching
from allennlp.modules.conditional_random_field import ConditionalRandomField
from allennlp.modules.conditional_random_field_wemission import ConditionalRandomFieldWeightEmission
from allennlp.modules.conditional_random_field_wtrans import ConditionalRandomFieldWeightTrans
from allennlp.modules.conditional_random_field_lannoy import ConditionalRandomFieldLannoy
from allennlp.modules.elmo import Elmo
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.gated_sum import GatedSum
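For orientation, here is a minimal usage sketch (not part of the diff) of the base `ConditionalRandomField`; the three weighted variants imported above are presumably drop-in replacements exposing the same `forward(inputs, tags, mask)` interface, with any class-specific weight arguments (not shown in this hunk) supplied at construction time.

```python
# Minimal sketch of the base CRF's interface (assumption: the weighted
# variants imported above follow the same forward signature).
import torch
from allennlp.modules import ConditionalRandomField

num_tags = 4
crf = ConditionalRandomField(num_tags)

batch_size, seq_len = 2, 5
logits = torch.randn(batch_size, seq_len, num_tags)    # unnormalized per-tag scores
tags = torch.randint(num_tags, (batch_size, seq_len))  # gold tag sequences
mask = torch.ones(batch_size, seq_len, dtype=torch.bool)

log_likelihood = crf(logits, tags, mask)  # summed log P(y|x) over the batch
loss = -log_likelihood                    # negative log-likelihood training loss
```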
59 changes: 45 additions & 14 deletions allennlp/modules/conditional_random_field.py
@@ -214,10 +214,21 @@ def reset_parameters(self):
torch.nn.init.normal_(self.start_transitions)
torch.nn.init.normal_(self.end_transitions)

def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
"""
Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
def _input_likelihood(
self, logits: torch.Tensor, transitions: torch.Tensor, mask: torch.BoolTensor
) -> torch.Tensor:
"""Computes the (batch_size,) denominator term $Z(x)$, per example, for the log-likelihood

This is the sum of the likelihoods across all possible state sequences.

Args:
logits (torch.Tensor): a (batch_size, sequence_length, num_tags) tensor of
unnormalized log-probabilities
transitions (torch.Tensor): a (num_tags, num_tags) tensor of transition scores
mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags

Returns:
torch.Tensor: (batch_size,) denominator term $\log Z(x)$, per example, for the log-likelihood
"""
batch_size, sequence_length, num_tags = logits.size()

@@ -239,7 +250,7 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
# The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
emit_scores = logits[i].view(batch_size, 1, num_tags)
# Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
transition_scores = self.transitions.view(1, num_tags, num_tags)
transition_scores = transitions.view(1, num_tags, num_tags)
# Alpha is for the current_tag, so we broadcast along the next_tag axis.
broadcast_alpha = alpha.view(batch_size, num_tags, 1)

@@ -262,10 +273,23 @@ def _input_likelihood(self, logits: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
return util.logsumexp(stops)
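To make the recursion above concrete, here is a self-contained sketch (simplifying assumptions: no start/end transition scores and no masking, both of which the real method handles) that checks the forward-algorithm computation of $\log Z(x)$ against brute-force enumeration of all tag sequences:

```python
# Sanity-check sketch: forward recursion vs. brute-force enumeration.
# Assumptions: no start/end transitions, no masking.
import itertools
import torch

batch_size, seq_len, num_tags = 1, 3, 2
torch.manual_seed(0)
logits = torch.randn(batch_size, seq_len, num_tags)
transitions = torch.randn(num_tags, num_tags)  # transitions[current_tag, next_tag]

# Brute force: log Z(x) = logsumexp of the score of every possible tag sequence.
scores = []
for tags in itertools.product(range(num_tags), repeat=seq_len):
    score = logits[0, 0, tags[0]]
    for i in range(1, seq_len):
        score = score + transitions[tags[i - 1], tags[i]] + logits[0, i, tags[i]]
    scores.append(score)
log_z_brute = torch.logsumexp(torch.stack(scores), dim=0)

# Forward recursion: alpha[t] accumulates scores of all prefixes ending in tag t,
# using the same broadcasting layout as the method above.
alpha = logits[:, 0]  # (batch_size, num_tags)
for i in range(1, seq_len):
    inner = alpha.unsqueeze(2) + transitions.unsqueeze(0) + logits[:, i].unsqueeze(1)
    alpha = torch.logsumexp(inner, dim=1)  # sum out current_tag
log_z_forward = torch.logsumexp(alpha, dim=1)

assert torch.allclose(log_z_brute, log_z_forward[0])
```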

def _joint_likelihood(
self, logits: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor
self,
logits: torch.Tensor,
transitions: torch.Tensor,
tags: torch.Tensor,
mask: torch.BoolTensor,
) -> torch.Tensor:
"""
Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
"""Computes the numerator term for the log-likelihood, which is just score(inputs, tags)

Args:
logits (torch.Tensor): a (batch_size, sequence_length, num_tags) tensor of unnormalized
log-probabilities
transitions (torch.Tensor): a (num_tags, num_tags) tensor of transition scores
tags (torch.Tensor): a (batch_size, sequence_length) tensor of output tag sequences $y$ for each input sequence
mask (torch.BoolTensor): a (batch_size, sequence_length) tensor of masking flags

Returns:
torch.Tensor: (batch_size,) numerator term for the log-likelihood, which is just score(inputs, tags)
"""
batch_size, sequence_length, _ = logits.data.shape

@@ -286,7 +310,7 @@ def _joint_likelihood(
current_tag, next_tag = tags[i], tags[i + 1]

# The scores for transitioning from current_tag to next_tag
transition_score = self.transitions[current_tag.view(-1), next_tag.view(-1)]
transition_score = transitions[current_tag.view(-1), next_tag.view(-1)]

# The score for using current_tag
emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)
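In the same simplified setting as the sketch above (no start/end transitions, no masking), the numerator is just the score of one gold tag path, as this self-contained sketch illustrates:

```python
# Self-contained sketch of the gold-path score (same simplifying
# assumptions as before: no start/end transitions, no masking).
import torch

seq_len, num_tags = 3, 2
torch.manual_seed(0)
logits = torch.randn(seq_len, num_tags)        # one sequence, batch dim omitted
transitions = torch.randn(num_tags, num_tags)  # transitions[current_tag, next_tag]
gold = [0, 1, 1]                               # hypothetical gold tag sequence

score = logits[0, gold[0]]                     # emit score at position 0
for i in range(1, seq_len):
    score = score + transitions[gold[i - 1], gold[i]] + logits[i, gold[i]]
# The log-likelihood is then score - log Z(x), with log Z(x) computed by
# the forward recursion in _input_likelihood.
```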
@@ -318,18 +342,25 @@ def _joint_likelihood(
def forward(
self, inputs: torch.Tensor, tags: torch.Tensor, mask: torch.BoolTensor = None
) -> torch.Tensor:
"""
Computes the log likelihood.
"""
"""Computes the log likelihood for the given batch of input sequences $(x,y)$

Args:
inputs (torch.Tensor): (batch_size, sequence_length, num_tags) tensor of logits for the inputs $x$
tags (torch.Tensor): (batch_size, sequence_length) tensor of tags $y$
mask (torch.BoolTensor, optional): (batch_size, sequence_length) tensor of masking flags.
Defaults to None.

Returns:
torch.Tensor: the log likelihoods $\log P(y|x)$ summed over the batch (a scalar tensor)
"""
if mask is None:
mask = torch.ones(*tags.size(), dtype=torch.bool, device=inputs.device)
else:
# The code below fails in weird ways if this isn't a bool tensor, so we make sure.
mask = mask.to(torch.bool)

log_denominator = self._input_likelihood(inputs, mask)
log_numerator = self._joint_likelihood(inputs, tags, mask)
log_denominator = self._input_likelihood(inputs, self.transitions, mask)
log_numerator = self._joint_likelihood(inputs, self.transitions, tags, mask)

return torch.sum(log_numerator - log_denominator)
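The PR's weighted variants build on this refactored interface to handle unbalanced datasets. As a purely hypothetical illustration of the general idea (an assumption, not necessarily how `ConditionalRandomFieldWeightEmission` actually implements it), per-tag weighting of emission scores might look like:

```python
# Hypothetical illustration only: upweighting rare tags' emission scores
# so they contribute more to the loss on unbalanced data. This is an
# assumption about the approach, not the PR's actual implementation.
import torch

num_tags = 4
tag_weights = torch.tensor([0.1, 2.0, 2.0, 1.0])  # e.g. downweight a dominant "O" tag

batch_size, seq_len = 2, 5
logits = torch.randn(batch_size, seq_len, num_tags)
weighted_logits = logits * tag_weights  # broadcasts over batch and time
# weighted_logits would then feed the standard likelihood computation above.
```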
