
Handling unbalanced datasets in the CRF tagger #4619

Closed
calusbr opened this issue Sep 1, 2020 · 8 comments · Fixed by #5676 or allenai/allennlp-models#341

calusbr commented Sep 1, 2020

I'm using an unbalanced NER corpus and would like to add weights to the entity classes during training via nn.CrossEntropyLoss. In which .py file could I hook in my own code and pass the weights to the model?

Some tutorials point to sequence_cross_entropy_with_logits in allennlp.nn.util:

def sequence_cross_entropy_with_logits(logits: torch.FloatTensor,
                                       targets: torch.LongTensor,
                                       weights: torch.FloatTensor,
                                       ...) -> torch.FloatTensor:
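For context, this is how per-class weights are usually passed to plain PyTorch's nn.CrossEntropyLoss for token classification (a minimal, self-contained sketch; the shapes and the assumption that tag 0 is "O" are made up, and this is not the loss the CrfTagger uses, since the CRF maximizes a sequence-level log-likelihood):

import torch
import torch.nn as nn

# Hypothetical shapes: a batch of 2 sentences, 5 tokens each, 9 BIOUL tags.
batch_size, seq_len, num_tags = 2, 5, 9
logits = torch.randn(batch_size, seq_len, num_tags)
tags = torch.randint(0, num_tags, (batch_size, seq_len))

# One weight per tag; up-weight everything except tag 0 (assumed to be "O").
class_weights = torch.ones(num_tags)
class_weights[1:] = 5.0

# nn.CrossEntropyLoss expects (N, C) inputs, so flatten the sequence dimension.
loss_fn = nn.CrossEntropyLoss(weight=class_weights)
loss = loss_fn(logits.reshape(-1, num_tags), tags.reshape(-1))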

AllenNLP version: 0.9.0
Config:

{
  "dataset_reader": {
    "type": "conll2003",
    "tag_label": "ner",
    "coding_scheme": "BIOUL",
    "token_indexers": {
      "tokens": {
        "type": "single_id",
        "lowercase_tokens": true
      },
      "token_characters": {
        "type": "characters",
        "min_padding_length": 3
      },
      "elmo": {
        "type": "elmo_characters"
      }
    }
  },
  "train_data_path": "train.txt",
  "validation_data_path": "dev.txt",
  "test_data_path": "test.txt",
  "model": {
    "type": "crf_tagger",
    "label_encoding": "BIOUL",
    "calculate_span_f1": true,
    "dropout": 0.5,
    "include_start_end_transitions": false,
    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
            "type": "embedding",
            "embedding_dim": 300,
            "pretrained_file": "/model/glove/glove_s300.zip",
            "trainable": true
        },
        "elmo":{
          "type": "elmo_token_embedder",
          "options_file": "/model/elmo/options.json",
          "weight_file": "/model/elmo/elmo_pt_weights_dgx1.hdf5",
          "do_layer_norm": false,
          "dropout": 0.0
        },
        "token_characters": {
            "type": "character_encoding",
            "embedding": {
            "embedding_dim": 16
            },
            "encoder": {
            "type": "cnn",
            "embedding_dim": 16,
            "num_filters": 128,
            "ngram_filter_sizes": [3],
            "conv_layer_activation": "relu"
            }
        }
      }
    },
    "encoder": {
      "type": "lstm",
      "input_size": 1452,
      "hidden_size": 200,
      "num_layers": 2,
      "dropout": 0.5,
      "bidirectional": true
    },
    "verbose_metrics": true,
    "regularizer": [
      [
        "scalar_parameters",
        {
          "type": "l2",
          "alpha": 0.1
        }
      ]
    ]
  },
  "iterator": {
    "type": "basic",
    "batch_size":32
  },
  "trainer": {
    "optimizer": {
        "type": "adam",
        "lr": 0.001
    },
    "validation_metric": "+f1-measure-overall",
    "num_serialized_models_to_keep": 3,
    "num_epochs": 10,
    "grad_norm": 5.0,
    "patience": 25,
    "cuda_device":[0] 
  },
}
matt-gardner (Contributor):

Hi @calusbr, the CrfTagger model doesn't currently have a way to specify these weights. I would recommend copying the code to your own repo and modifying it to add in the weighting that you want.
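For illustration, one crude place to hook weights into a copied CrfTagger is its emission scores before they enter the CRF (a sketch under that assumption; the helper function and the class_weights attribute are hypothetical, not AllenNLP API):

import torch

def weight_emissions(logits: torch.Tensor, class_weights: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq_len, num_tags) emission scores from the encoder.
    # class_weights: (num_tags,), e.g. larger values for rare entity tags.
    return logits * class_weights.view(1, 1, -1)

# In a copied CrfTagger.forward, just before the CRF call, one could write:
#     logits = weight_emissions(logits, self.class_weights)
#     log_likelihood = self.crf(logits, tags, mask)
#     loss = -log_likelihood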

@matt-gardner matt-gardner removed their assignment Sep 4, 2020
dirkgr (Member) commented Sep 11, 2020

Let me add to this: In case you develop a general way of giving class weights to CrfTagger, I'd love to review the pull request for it :-)

@dirkgr dirkgr changed the title Loss Imbalance NER Handling unbalanced datasets in the CRF tagger Sep 11, 2020
calusbr (Author) commented Sep 22, 2020

@matt-gardner Thanks for the feedback! @dirkgr I believe it would be possible to create a task for this problem. Do you have any idea how we could add the weights?

@matt-gardner If you have any tips on how to add these weights without changing the official code, I'd welcome the feedback!

dirkgr (Member) commented Sep 23, 2020

I did a quick Google of this problem, and unless I missed something, no major library has weighted CRFs. That tells me that a) it would be very cool if AllenNLP did, and b) it's not easy.

I did find this paper referenced a few times: https://perso.uclouvain.be/michel.verleysen/papers/ieeetbe12gdl.pdf

They give the math, but not code, on how to do it. I'd recommend copying the existing CrfTagger code, adding the math from the paper there, and submitting a PR to us. Or maybe do some more research first to find a more accessible paper, or maybe even an existing implementation somewhere that could be adapted.

epwalsh (Member) commented Sep 23, 2020

I glanced over the paper and it seems relatively straightforward, but it would definitely take some refactoring of the CrfTagger, and I'm not sure what the performance implications are.

But this definitely interests me, so if no one else decides to work on it, I might eventually do it.

eraldoluis (Contributor):

Hi @dirkgr and @epwalsh,

I implemented and experimentally compared three weighting strategies for the CRF in AllenNLP. One of them is the method proposed by Lannoy et al. (the paper @dirkgr mentioned above). The other two are straightforward approaches based on weighting the emission and/or transition scores.

I ran these experiments because I wasn't convinced by the theoretical basis of Lannoy et al.'s method, and the results (although the setup was quite limited) corroborate my concerns: the approach behaves somewhat inconsistently, which, in my opinion, makes it hard to use in practice.

If someone is willing to review a PR with a weighted CRF method, I can submit one soon. I just need some feedback on which strategy is preferable, or whether it would be interesting to make all three available. As mentioned at the end of the report, I think the best fit for AllenNLP is the emission-based strategy: it is the simplest of the three and behaves well.

In my current implementation (allennlp and allennlp-models), I added a weight_strategy parameter to CrfTagger so that any of the three strategies can be selected, which made the comparison experiments simpler. But, as mentioned, I don't think it is worth including all three methods in AllenNLP.
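Roughly, the two straightforward strategies look like this (an illustrative sketch; the function name, strategy strings, and the exact weighting here are placeholders, not the actual implementation):

import torch

def apply_weight_strategy(
    strategy: str,
    emissions: torch.Tensor,     # (batch, seq_len, num_tags)
    transitions: torch.Tensor,   # (num_tags, num_tags), transitions[i, j] = score of i -> j
    class_weights: torch.Tensor, # (num_tags,), one weight per tag
):
    # Lannoy et al.'s method reweights the likelihood itself and is omitted here.
    if strategy == "emission":
        # Scale each tag's emission scores by its class weight.
        return emissions * class_weights.view(1, 1, -1), transitions
    if strategy == "emission_transition":
        # Additionally weight each transition into tag j by the weight of tag j.
        return (
            emissions * class_weights.view(1, 1, -1),
            transitions * class_weights.view(1, -1),
        )
    raise ValueError(f"Unknown weight_strategy: {strategy}")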

I also included a new version of the FBetaMeasure metric, which I called FBetaMeasure2. It offers a few more options than the original (options better suited to my experiments). On top of that, it uses the more "modern" metric API, in which each metric value is stored under its own key (instead of one key mapping to a list of values, as in FBetaMeasure). So, in fact, I think this new implementation could replace the previous one. That raises another question: how should I split this into PRs? It seems better to submit it separately, but the weighted CRF implementation depends on it. Any suggestions?
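Concretely, the difference in the returned metric dictionaries looks roughly like this (label names and numbers are made up):

# FBetaMeasure returns one list per statistic, indexed by label id:
old_style = {"precision": [0.9, 0.5], "recall": [0.8, 0.4], "fscore": [0.85, 0.44]}

# The "modern" convention stores each value under its own key:
new_style = {
    "precision-PER": 0.9, "recall-PER": 0.8, "fscore-PER": 0.85,
    "precision-ORG": 0.5, "recall-ORG": 0.4, "fscore-ORG": 0.44,
}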

dirkgr (Member) commented May 18, 2022

Sounds like we should have two PRs: one with the improved FBetaMeasure, and then another with the improved CrfTagger. If it's not a huge amount of code, I think it would be good to have all three methods, with the default that you suggested. That way others can reproduce your results.

eraldoluis (Contributor):

Thank you for your comment, @dirkgr! It is definitely not a huge amount of code. I will polish it (there are some unnecessarily duplicated code chunks) and submit the two PRs.
