This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Implementation of Weighted CRF Tagger (handling unbalanced datasets) #5676

Merged
merged 12 commits into allenai:main on Jul 14, 2022

Conversation

@eraldoluis (Contributor) commented Jun 22, 2022

Closes #4619.

Dependency of allennlp-models PR #341

Changes proposed in this pull request:

  • I implemented and experimentally compared three sample weighting strategies for CrfTagger.
  • The three strategies are implemented in the files allennlp/modules/conditional_random_field_<strategy>.py, where <strategy> can be wemission, wtrans, or lannoy.
  • I generalized the methods _input_likelihood(...) and _joint_likelihood(...) of the ConditionalRandomField class so that they now receive the transition weights as an argument. This let me implement the two basic sample weighting strategies (wemission and wtrans) simply by subclassing and scaling the corresponding scores (logits and transitions) in the forward(...) method before calling _input_likelihood(...) and _joint_likelihood(...); see the sketch after this list. No modification to the underlying algorithms in these two methods was necessary.
  • On the other hand, the strategy proposed by Lannoy et al. (suggested by @dirkgr) required a quite different implementation.
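A minimal sketch of the wemission-style subclassing described above (the class name is invented here, and the generalized _input_likelihood(...)/_joint_likelihood(...) signatures are assumed from the PR description, not copied from the actual diff):

import torch
from allennlp.modules.conditional_random_field import ConditionalRandomField


class WeightedEmissionCrf(ConditionalRandomField):
    # Illustrative subclass: scales each label's emission scores by a
    # per-label weight, then delegates to the generalized base methods,
    # which (per the PR description) now take the transitions as an argument.

    def __init__(self, num_tags: int, label_weights: list, **kwargs):
        super().__init__(num_tags, **kwargs)
        # a buffer rather than a Parameter, so the weights never reach the optimizer
        self.register_buffer("label_weights", torch.Tensor(label_weights))

    def forward(self, inputs, tags, mask=None):
        if mask is None:
            mask = torch.ones(*tags.size(), dtype=torch.bool, device=tags.device)
        # weight the logits of every token per label; the transitions are
        # passed through to the base methods unchanged
        weighted_logits = inputs * self.label_weights.view(1, 1, -1)
        log_denominator = self._input_likelihood(weighted_logits, self.transitions, mask)
        log_numerator = self._joint_likelihood(weighted_logits, self.transitions, tags, mask)
        return torch.sum(log_numerator - log_denominator)

A wtrans-style subclass would additionally scale the transition scores before the two calls; the Lannoy strategy does not decompose this way, which is why it needed a separate implementation.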

Before submitting

  • I've read and followed all steps in the Making a pull request
    section of the CONTRIBUTING docs.
  • I've updated or added any relevant docstrings following the syntax described in the
    Writing docstrings section of the CONTRIBUTING docs.
  • If this PR adds a new feature, I've added tests that sufficiently cover my new functionality.

After submitting

  • All GitHub Actions jobs for my pull request have passed.
  • codecov/patch reports high test coverage (at least 90%).
    You can find this under the "Actions" tab of the pull request once the other checks have finished.

@epwalsh (Member) commented Jun 30, 2022

Hi @eraldoluis, thanks for this! I may not have time for a thorough review this week but this will be a priority next week.

@epwalsh self-assigned this Jun 30, 2022
@eraldoluis (Contributor, Author) replied:

> Hi @eraldoluis, thanks for this! I may not have time for a thorough review this week but this will be a priority next week.

Thank you, @epwalsh!

@epwalsh (Member) left a review comment:

Hey @eraldoluis, this looks really great! This is my first pass through (I still need to go through some of the code in more detail to get my head around it), but I just have some minor comments. In addition to the comments below, I will also say I think we should organize these modules into a common parent module. That is, create a new folder allennlp/modules/conditional_random_field/ and move the 3 implementations into there.

if label_weights is None:
    raise ConfigurationError("label_weights must be given")

self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False)
@epwalsh (Member):
I think it might be better to use self.register_buffer() here instead of defining the weights as a parameter. That way we can be sure the label weights aren't passed to the optimizer.

https://discuss.pytorch.org/t/what-is-the-difference-between-register-buffer-and-register-parameter-of-nn-module/32723/11
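For reference, the suggested pattern would look roughly like this (a sketch of the buffer idiom, not the PR's exact diff):

# before: a frozen Parameter, which is still listed by module.parameters()
self.label_weights = torch.nn.Parameter(torch.Tensor(label_weights), requires_grad=False)

# after: a buffer, which is saved in the state_dict and moved by .to(device),
# but is never handed to the optimizer
self.register_buffer("label_weights", torch.Tensor(label_weights))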

@eraldoluis (Contributor, Author):
Done!

Comment on lines 15 to 18
This module uses the "forward-backward" algorithm to compute
the log-likelihood of its inputs assuming a conditional random field model.

See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf
@epwalsh (Member):
Could you add something about the weighting strategy here?

@eraldoluis (Contributor, Author):
Done!

Comment on lines 15 to 18
This module uses the "forward-backward" algorithm to compute
the log-likelihood of its inputs assuming a conditional random field model.

See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf
@epwalsh (Member):
Should also have a note about the weighting strategy here.

@eraldoluis (Contributor, Author):
Done!

Comment on lines 161 to 165
This module uses the "forward-backward" algorithm to compute
the log-likelihood of its inputs assuming a conditional random field model.

See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf

@epwalsh (Member):
Also would be good to have a note about the weight strategy here.

@eraldoluis (Contributor, Author):
Done!

VITERBI_DECODING = Tuple[List[int], float] # a list of tags, and a viterbi score


def allowed_transitions(constraint_type: str, labels: Dict[int, str]) -> List[Tuple[int, int]]:
@epwalsh (Member):
Is this any different from the same function in conditional_random_field.py?

@eraldoluis (Contributor, Author):
I simply forgot to adapt this class. I think it is much better now.

    return allowed


def is_transition_allowed(
@epwalsh (Member):
Same question here: is this any different from the original?

@eraldoluis (Contributor, Author):
Same as above.

for i, j in constraints:
    constraint_mask[i, j] = 1.0

self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False)
@epwalsh (Member):
This should probably be a buffer as well, but I just realized this is how it's done in the original CRF module, so I guess I'm okay with this.
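For comparison, the buffer version would be a one-line change, analogous to the label_weights case above (a sketch, not code from the PR):

self.register_buffer("_constraint_mask", constraint_mask)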

@eraldoluis (Contributor, Author):
Ok, I didn't touch it, although I agree with you that this should be a buffer as well.

@eraldoluis (Contributor, Author) commented Jul 8, 2022

Thank you very much, @epwalsh, for the effort you have put into this.

I have tried to address your initial concerns. Let me know what you think about my changes.

I am looking forward to your feedback regarding the whole thing. Let me know if you have any questions. I will be happy to discuss this further if necessary.

@epwalsh (Member) commented Jul 12, 2022

Thanks for the quick responses/fixes! Changes look good. I should clarify what I meant by:

> create a new folder allennlp/modules/conditional_random_field/ and move the 3 implementations into there.

Looks like you left allennlp/modules/conditional_random_field.py where it is, and then moved the weighted CRFs into allennlp/modules/conditional_random_field_weighted/. I'd rather have a single submodule (folder) called allennlp/modules/conditional_random_field/ with all of the CRFs (including the non-weighted base class).
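Under that scheme the layout would look roughly like this (file names inferred from the <strategy> naming in the PR description, so treat them as illustrative):

allennlp/modules/conditional_random_field/
    __init__.py
    conditional_random_field.py        # the unweighted base class
    conditional_random_field_wemission.py
    conditional_random_field_wtrans.py
    conditional_random_field_lannoy.py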

@epwalsh (Member) commented Jul 13, 2022

I liked your blog post a lot by the way!

Commit: Renamed module allennlp.modules.conditional_random_field_weighted
to ...conditional_random_field
@eraldoluis (Contributor, Author) replied:

> Looks like you left allennlp/modules/conditional_random_field.py where it is, and then moved the weighted CRFs into allennlp/modules/conditional_random_field_weighted/. I'd rather have a single submodule (folder) called allennlp/modules/conditional_random_field/ with all of the CRFs (including the non-weighted base class).

Yes, I was unsure at first, but I have now renamed the module to conditional_random_field and moved the original class into it. I also updated the changelog, which I had forgotten.

I also updated allennlp-models to reflect the new module organization. Unfortunately, I pushed to the allennlp repository first, and the Model Tests failed because allennlp-models was outdated. These tests should pass now.

Let me know what you think.

@epwalsh (Member) left a review comment:

LGTM! I'll follow up with the PR in allennlp-models next.

@epwalsh merged commit 5a3acba into allenai:main on Jul 14, 2022
@eraldoluis (Contributor, Author):
Thank you very much, @epwalsh and @dirkgr! This was my first contribution to an open source project, and it was quite fun. I will definitely do it again soon. :)
