Skip to content

refactor token classification#35

Merged
ArneBinder merged 66 commits into
mainfrom
refactor_token_classification
Jan 19, 2024
Merged

refactor token classification#35
ArneBinder merged 66 commits into
mainfrom
refactor_token_classification

Conversation

@ArneBinder
Copy link
Copy Markdown
Owner

@ArneBinder ArneBinder commented Jan 12, 2024

This PR introduces several changes to the token classification based labeled span extraction setups. Especially, metric setup happens in the task module now and span based metrics are logged during training.

Task Modules and Metrics

WrappedMetricWithPrepareFunction: newly added

  • wrapper around torchmetrics.Metric that pre-processes the predictions and targets with a prepare_function
  • prepare_function can unbatch the predictions / targets in which case the wrapped metric is updated multiple times

PrecisionRecallAndF1ForLabeledAnnotations: make it a real torchmetrics.Metric

  • save tp, fn, fp as tensors in the state instead of the correct / gold / predicted annotations

LabeledSpanExtractionByTokenClassificationTaskModule: formerly TokenClassificationTaskModule

  • rename label_token_pad_id to label_pad_id (default as before: -100), i.e. this is breaking
  • unbatch_output() expects a LongTensor representing the predicted label indices (and label_pad_id on pad positions)
  • include special_tokens_mask in input encodings
  • implement configure_model_metric() that creates the following metrics:
    • token metrics: micro and macro F1, without ignoring any class
    • span metrics: micro, macro, and per-class Precision, Recall, and F1
  • add parameter log_precision_recall_metrics (default: True) to disable logging of precision and recall metrics. Disabling this is useful to not get overwhelmed by to many logged metrics

Models

WithMetricsFromTaskModule: new mixin

  • handles metric setup, update, and logging in the models

Model: new abstract model

  • contains boiler plate code that is used in most of the models
  • uses WithMetricsFromTaskModule

SimpleTokenClassificationModel: newly added

  • taken from pytorch_ie.models.TransformerTokenClassificationModel, simple wrapper around transformers.AutoModelForTokenClassification

For both models, SimpleTokenClassificationModel and TokenClassificationModelWithSeq2SeqEncoderAndCrf:

  • rename label_pad_token_id to label_pad_id, i.e. this is breaking
  • remove any manual metric setup/updating/logging (now handled by taskmodule.configure_model_metric() in WithMetricsFromTaskModule)
  • outsourced much code to Model

Requires

Follow-up

TODOs:

  • metric:
    • do not prepare input in step()
    • add span F1 metric
    • enable macro averaging
    • fix metric log keys
  • adjust README.md
  • documentation
  • rename to LabeledSpanExtraction(TaskmoDule|Model) taskmodule to LabeledSpanExtractionByTokenClassificationTaskModule

@ArneBinder ArneBinder added refactoring Refactoring breaking Breaking Changes labels Jan 12, 2024
@codecov-commenter
Copy link
Copy Markdown

Codecov Report

Attention: 20 lines in your changes are missing coverage. Please review.

Comparison is base (5f96e5a) 95.62% compared to head (8a2b890) 95.20%.

Files Patch % Lines
.../pie_modules/models/simple_token_classification.py 86.48% 10 Missing ⚠️
...ken_classification_with_seq2seq_encoder_and_crf.py 81.48% 10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #35      +/-   ##
==========================================
- Coverage   95.62%   95.20%   -0.43%     
==========================================
  Files          40       41       +1     
  Lines        3316     3417     +101     
==========================================
+ Hits         3171     3253      +82     
- Misses        145      164      +19     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Jan 13, 2024

Codecov Report

Attention: 13 lines in your changes are missing coverage. Please review.

Comparison is base (5f96e5a) 95.62% compared to head (3fd3066) 95.79%.

Files Patch % Lines
...ules/models/mixins/with_metrics_from_taskmodule.py 87.71% 7 Missing ⚠️
...es/metrics/wrapped_metric_with_prepare_function.py 91.11% 4 Missing ⚠️
src/pie_modules/models/model.py 97.82% 1 Missing ⚠️
.../pie_modules/models/simple_token_classification.py 97.43% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #35      +/-   ##
==========================================
+ Coverage   95.62%   95.79%   +0.16%     
==========================================
  Files          40       45       +5     
  Lines        3316     3544     +228     
==========================================
+ Hits         3171     3395     +224     
- Misses        145      149       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

…also for TokenClassificationModelWithSeq2SeqEncoderAndCrf
…ssificationTaskModule ad remove deprecated parameters
@ArneBinder ArneBinder merged commit 8e6d5b2 into main Jan 19, 2024
@ArneBinder ArneBinder deleted the refactor_token_classification branch January 19, 2024 15:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

breaking Breaking Changes refactoring Refactoring

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants