
Feat: Add SpanMarker Argilla Trainer for NER #2693

Conversation

tomaarsen
Contributor

Hello!

Pull Request overview

  • Add SpanMarker Argilla Trainer for Named Entity Recognition.

Details

The SpanMarker Argilla trainer is based on the Transformers Trainer, as SpanMarker is implemented directly on top of transformers. However, unlike with the Transformers integration, we don't need to handle tokenization, data collation, or evaluation on the Argilla side, which keeps the SpanMarker Argilla trainer relatively small.
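
For context, here is a minimal sketch of what SpanMarker itself takes care of when used directly. This snippet is illustrative only and not part of this PR; it assumes the span_marker package is installed and uses a plain conll2003 load:

# Illustrative only: SpanMarker's Trainer subclasses the Transformers Trainer and
# performs tokenization, data collation, and seqeval-style evaluation internally,
# which is why the Argilla wrapper only needs to pass a word-level dataset and config.
from datasets import load_dataset
from transformers import TrainingArguments
from span_marker import SpanMarkerModel, Trainer

conll = load_dataset("conll2003")
labels = conll["train"].features["ner_tags"].feature.names

model = SpanMarkerModel.from_pretrained("bert-base-cased", labels=labels)
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="tmp_sketch", num_train_epochs=1),
    train_dataset=conll["train"].select(range(1000)),
    eval_dataset=conll["validation"].select(range(200)),
)
trainer.train()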

Usage

First, we need an annotated dataset:

import argilla as rg
from datasets import load_dataset

dataset = "conll2003"
# Load the first 1000 training samples of CoNLL-2003 and rename the tag column
# to the name that `rg.read_datasets` expects for TokenClassification.
dataset_ds = load_dataset(
    dataset,
    split="train[:1000]",
)
dataset_ds = dataset_ds.rename_column("ner_tags", "tags")
dataset_rb = rg.read_datasets(dataset_ds, task="TokenClassification")

# Remove any existing dataset with this name, then log the records to Argilla.
rg.delete(dataset)
rg.log(name=dataset, records=dataset_rb)

And then we can use the new Trainer to train with this dataset:

import argilla as rg
from argilla.training.base import ArgillaTrainer

dataset = "conll2003"

# Train a SpanMarker model on 80% of the logged records; the remaining 20% is
# used for evaluation during training.
trainer = ArgillaTrainer(name=dataset, framework="span_marker", train_size=0.8)
trainer.update_config(
    num_train_epochs=10,
    bf16=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    marker_max_length=128,
    entity_max_length=8,
)
trainer.train(output_dir="tmp_span_marker_train")

(If you run into memory issues, you can lower the batch sizes or set model_max_length=256. If bf16 raises an error on your hardware, use fp16 instead. A sample configuration along those lines is shown below.)
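
For example, a lower-memory configuration could look like this (the exact values are illustrative and not tested here):

# Lower-memory alternative: smaller batches, shorter model_max_length,
# and fp16 for GPUs without bfloat16 support.
trainer.update_config(
    num_train_epochs=10,
    fp16=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    model_max_length=256,
    marker_max_length=128,
    entity_max_length=8,
)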

This produces the following logs:

Click to see the logs
[04/13/23 16:25:25] WARNING  WARNING:argilla.client.datasets:No label schema provided. Using all_labels: TokenClassificationSettings(['LOC', 'MISC', 'ORG', 'PER']). We recommend      datasets.py:1222
                             providing a `TokenClassificationSettings()` or setting it `rg.configure_dataset_settings()`/`rg.load_dataset_settings()` to ensure reproducibility.       
[04/13/23 16:25:30] WARNING  WARNING:ArgillaTrainer:ArgillaBaseTrainer info:                                                                                                                base.py:175
                             _________________________________________________________________
                             These baseline params are fixed:
                                 dataset: conll2003
                                 task: DatasetForTokenClassification
                                 multi_label: False
                                 train_size: 0.8
                                 seed: None

                             <class 'argilla.training.span_marker.ArgillaSpanMarkerTrainer'> info:
                             _________________________________________________________________
                             The parameters are configurable via `trainer.update_config()`:
                                 'SpanMarkerModel'
                             pretrained_model_name_or_path: bert-base-cased
                             labels: ['O', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
                             'Trainer'
                             overwrite_output_dir: False
                             do_train: False
                             do_eval: False
                             do_predict: False
                             evaluation_strategy: epoch
                             prediction_loss_only: False
                             per_device_train_batch_size: 8
                             per_device_eval_batch_size: 8
                             per_gpu_train_batch_size: None
                             per_gpu_eval_batch_size: None
                             gradient_accumulation_steps: 1
                             eval_accumulation_steps: None
                             eval_delay: 0
                             learning_rate: 5e-05
                             weight_decay: 0.01
                             adam_beta1: 0.9
                             adam_beta2: 0.999
                             adam_epsilon: 1e-08
                             max_grad_norm: 1.0
                             num_train_epochs: 3.0
                             max_steps: -1
                             lr_scheduler_type: linear
                             warmup_ratio: 0.0
                             warmup_steps: 0
                             log_level: passive
                             log_level_replica: warning
                             log_on_each_node: True
                             logging_dir: None
                             logging_strategy: steps
                             logging_first_step: False
                             logging_steps: 30
                             logging_nan_inf_filter: True
                             save_strategy: steps
                             save_steps: 500
                             save_total_limit: None
                             save_on_each_node: False
                             no_cuda: False
                             use_mps_device: False
                             seed: 42
                             data_seed: None
                             jit_mode_eval: False
                             use_ipex: False
                             bf16: False
                             fp16: False
                             fp16_opt_level: O1
                             half_precision_backend: auto
                             bf16_full_eval: False
                             fp16_full_eval: False
                             tf32: None
                             local_rank: -1
                             xpu_backend: None
                             tpu_num_cores: None
                             tpu_metrics_debug: False
                             debug:
                             dataloader_drop_last: False
                             eval_steps: None
                             dataloader_num_workers: 0
                             past_index: -1
                             run_name: None
                             disable_tqdm: None
                             remove_unused_columns: True
                             label_names: None
                             load_best_model_at_end: False
                             metric_for_best_model: None
                             greater_is_better: None
                             ignore_data_skip: False
                             sharded_ddp:
                             fsdp:
                             fsdp_min_num_params: 0
                             fsdp_config: None
                             fsdp_transformer_layer_cls_to_wrap: None
                             deepspeed: None
                             label_smoothing_factor: 0.0
                             optim: adamw_hf
                             optim_args: None
                             adafactor: False
                             group_by_length: False
                             length_column_name: length
                             report_to: None
                             ddp_find_unused_parameters: None
                             ddp_bucket_cap_mb: None
                             dataloader_pin_memory: True
                             skip_memory_metrics: True
                             use_legacy_prediction_loop: False
                             push_to_hub: False
                             resume_from_checkpoint: None
                             hub_model_id: None
                             hub_strategy: every_save
                             hub_token: None
                             hub_private_repo: False
                             gradient_checkpointing: False
                             include_inputs_for_metrics: False
                             fp16_backend: auto
                             push_to_hub_model_id: None
                             push_to_hub_organization: None
                             push_to_hub_token: None
                             mp_parameters:
                             auto_find_batch_size: False
                             full_determinism: False
                             torchdynamo: None
                             ray_scope: last
                             ddp_timeout: 1800
                             torch_compile: False
                             torch_compile_backend: None
                             torch_compile_mode: None
                             output_dir: None

                             Using the trainer:
                             _________________________________________________________________
                             `trainer.train(output_dir)` to train to start training. `output_dir` is the directory to save the model automatically.
                             `trainer.predict(text, as_argilla_records=True)` to make predictions.
                             `trainer.save(output_dir)` to save the model manually.
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
C:\Users\tom\.conda\envs\argilla\lib\site-packages\transformers\optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
{'loss': 0.2601, 'learning_rate': 4.754098360655738e-05, 'epoch': 0.49}
{'loss': 0.066, 'learning_rate': 4.508196721311476e-05, 'epoch': 0.98}                                                                                                                                 
{'eval_loss': 0.03822080418467522, 'eval_overall_precision': 0.8810289389067524, 'eval_overall_recall': 0.6372093023255814, 'eval_overall_f1': 0.7395411605937922, 'eval_overall_accuracy': 0.9204204204204204, 'eval_runtime': 4.5471, 'eval_samples_per_second': 57.619, 'eval_steps_per_second': 3.739, 'epoch': 1.0}                                                                                      
{'loss': 0.0378, 'learning_rate': 4.262295081967213e-05, 'epoch': 1.48}                                                                                                                                
{'loss': 0.029, 'learning_rate': 4.016393442622951e-05, 'epoch': 1.97}
{'eval_loss': 0.021275097504258156, 'eval_overall_precision': 0.8808933002481389, 'eval_overall_recall': 0.8255813953488372, 'eval_overall_f1': 0.8523409363745499, 'eval_overall_accuracy': 0.9602102102102102, 'eval_runtime': 4.5328, 'eval_samples_per_second': 57.8, 'eval_steps_per_second': 3.75, 'epoch': 2.0}                                                                                        
{'loss': 0.0169, 'learning_rate': 3.7704918032786885e-05, 'epoch': 2.46}
{'loss': 0.015, 'learning_rate': 3.524590163934427e-05, 'epoch': 2.95}
{'eval_loss': 0.013853945769369602, 'eval_overall_precision': 0.9527363184079602, 'eval_overall_recall': 0.8906976744186047, 'eval_overall_f1': 0.920673076923077, 'eval_overall_accuracy': 0.9774774774774775, 'eval_runtime': 4.4899, 'eval_samples_per_second': 58.353, 'eval_steps_per_second': 3.786, 'epoch': 3.0}                                                                                      
{'loss': 0.01, 'learning_rate': 3.2786885245901635e-05, 'epoch': 3.44}
{'loss': 0.0089, 'learning_rate': 3.0327868852459017e-05, 'epoch': 3.93}
{'eval_loss': 0.013151775114238262, 'eval_overall_precision': 0.948780487804878, 'eval_overall_recall': 0.9046511627906977, 'eval_overall_f1': 0.9261904761904762, 'eval_overall_accuracy': 0.9786036036036037, 'eval_runtime': 4.5472, 'eval_samples_per_second': 57.618, 'eval_steps_per_second': 3.739, 'epoch': 4.0}                                                                                      
{'loss': 0.0056, 'learning_rate': 2.7868852459016392e-05, 'epoch': 4.43}                                                                                                                               
{'loss': 0.0059, 'learning_rate': 2.540983606557377e-05, 'epoch': 4.92}
{'eval_loss': 0.012591547332704067, 'eval_overall_precision': 0.9587378640776699, 'eval_overall_recall': 0.9186046511627907, 'eval_overall_f1': 0.9382422802850356, 'eval_overall_accuracy': 0.9812312312312312, 'eval_runtime': 4.5184, 'eval_samples_per_second': 57.986, 'eval_steps_per_second': 3.762, 'epoch': 5.0}                                                                                     
{'loss': 0.0036, 'learning_rate': 2.295081967213115e-05, 'epoch': 5.41}
{'loss': 0.0044, 'learning_rate': 2.0491803278688525e-05, 'epoch': 5.9}
{'eval_loss': 0.012911035679280758, 'eval_overall_precision': 0.9539951573849879, 'eval_overall_recall': 0.9162790697674419, 'eval_overall_f1': 0.9347568208778173, 'eval_overall_accuracy': 0.9808558558558559, 'eval_runtime': 4.3026, 'eval_samples_per_second': 60.893, 'eval_steps_per_second': 3.951, 'epoch': 6.0}                                                                                     
{'loss': 0.0031, 'learning_rate': 1.8032786885245903e-05, 'epoch': 6.39}
{'loss': 0.0024, 'learning_rate': 1.557377049180328e-05, 'epoch': 6.89}
{'eval_loss': 0.010593990795314312, 'eval_overall_precision': 0.9567307692307693, 'eval_overall_recall': 0.9255813953488372, 'eval_overall_f1': 0.9408983451536642, 'eval_overall_accuracy': 0.9838588588588588, 'eval_runtime': 4.5172, 'eval_samples_per_second': 58.001, 'eval_steps_per_second': 3.763, 'epoch': 7.0}                                                                                     
{'loss': 0.0018, 'learning_rate': 1.3114754098360657e-05, 'epoch': 7.38}
{'loss': 0.0025, 'learning_rate': 1.0655737704918032e-05, 'epoch': 7.87}
{'eval_loss': 0.010297469794750214, 'eval_overall_precision': 0.9478672985781991, 'eval_overall_recall': 0.9302325581395349, 'eval_overall_f1': 0.9389671361502349, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.4597, 'eval_samples_per_second': 58.748, 'eval_steps_per_second': 3.812, 'epoch': 8.0}                                                                                     
{'loss': 0.0011, 'learning_rate': 8.196721311475409e-06, 'epoch': 8.36}                                                                                                                                
{'loss': 0.0009, 'learning_rate': 5.737704918032787e-06, 'epoch': 8.85}
{'eval_loss': 0.009832620620727539, 'eval_overall_precision': 0.9478672985781991, 'eval_overall_recall': 0.9302325581395349, 'eval_overall_f1': 0.9389671361502349, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.6756, 'eval_samples_per_second': 56.036, 'eval_steps_per_second': 3.636, 'epoch': 9.0}                                                                                     
{'loss': 0.0015, 'learning_rate': 3.278688524590164e-06, 'epoch': 9.34}                                                                                                                                
{'loss': 0.0013, 'learning_rate': 8.19672131147541e-07, 'epoch': 9.84}
{'eval_loss': 0.010347267612814903, 'eval_overall_precision': 0.9457547169811321, 'eval_overall_recall': 0.9325581395348838, 'eval_overall_f1': 0.9391100702576113, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.7101, 'eval_samples_per_second': 55.625, 'eval_steps_per_second': 3.609, 'epoch': 10.0}                                                                                    
{'train_runtime': 297.0633, 'train_samples_per_second': 32.855, 'train_steps_per_second': 2.053, 'train_loss': 0.023517039891515597, 'epoch': 10.0}                                                   
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 610/610 [04:57<00:00,  2.05it/s] 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:04<00:00,  3.84it/s] 
[04/13/23 16:30:34] INFO     INFO:ArgillaSpanMarkerTrainer:{'eval_loss': 0.010347267612814903, 'eval_overall_precision': 0.9457547169811321, 'eval_overall_recall':                  span_marker.py:133
                             0.9325581395348838, 'eval_overall_f1': 0.9391100702576113, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.6647,
                             'eval_samples_per_second': 56.166, 'eval_steps_per_second': 3.644, 'epoch': 10.0}

In short, I trained to 0.939 eval F1 on CoNLL03 in 5 minutes.
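
Once training has finished, the trainer also exposes prediction and manual saving, as listed under "Using the trainer" in the logs above. A quick sketch (the input sentence and save directory are made up):

# Predict on new text and get the result back as Argilla records.
records = trainer.predict(
    "The European Union was founded in Maastricht.",  # example sentence (made up)
    as_argilla_records=True,
)

# Save the fine-tuned SpanMarker model manually.
trainer.save("span_marker_conll2003")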

Type of change

  • New feature (non-breaking change which adds functionality)

How Has This Been Tested

Tests still need to be written. I'll be working on this, but I'm publishing the PR as a draft already so it's available for review in the meantime. A rough sketch of the kind of test I have in mind is shown below.
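
(Sketch only; it assumes a small TokenClassification dataset has already been logged under `dataset_name`. The function name and asserts are illustrative, not the tests that will ship with this PR.)

from argilla.training.base import ArgillaTrainer

def test_span_marker_trainer(dataset_name: str):
    # Train for a single quick epoch on the pre-logged dataset.
    trainer = ArgillaTrainer(name=dataset_name, framework="span_marker", train_size=0.8)
    trainer.update_config(num_train_epochs=1, per_device_train_batch_size=4)
    trainer.train(output_dir="tmp_span_marker_test")

    # A trained model should return records for a simple input sentence.
    records = trainer.predict("This is a test sentence.", as_argilla_records=True)
    assert records is not None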

  • Tom Aarsen

@tomaarsen tomaarsen marked this pull request as ready for review April 18, 2023 14:00
@codecov

codecov bot commented Apr 19, 2023

Codecov Report

Patch coverage: 97.91% and project coverage change: +0.08 🎉

Comparison is base (857c569) 92.26% compared to head (127e311) 92.34%.

❗ Current head 127e311 differs from pull request most recent head 3715ad9. Consider uploading reports for the commit 3715ad9 to get more accurate results

Additional details and impacted files
@@                                 Coverage Diff                                  @@
##           feat/2658-add-argilla-training-module-for-openai    #2693      +/-   ##
====================================================================================
+ Coverage                                             92.26%   92.34%   +0.08%     
====================================================================================
  Files                                                   171      172       +1     
  Lines                                                  9029     9124      +95     
====================================================================================
+ Hits                                                   8331     8426      +95     
  Misses                                                  698      698              
Flag     Coverage Δ
pytest   92.34% <97.91%> (+0.08%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files                         Coverage Δ
src/argilla/training/span_marker.py    97.77% <97.77%> (ø)
src/argilla/client/datasets.py         87.72% <100.00%> (+0.16%) ⬆️
src/argilla/client/models.py           94.52% <100.00%> (+0.02%) ⬆️
src/argilla/training/base.py           89.53% <100.00%> (+1.73%) ⬆️

... and 1 file with indirect coverage changes


☔ View full report in Codecov by Sentry.

@tomaarsen
Contributor Author

The failing test originates upstream; see #2691.

Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>
@davidberenstein1957 davidberenstein1957 merged commit f891a47 into argilla-io:feat/2658-add-argilla-training-module-for-openai Apr 25, 2023
@tomaarsen tomaarsen deleted the feat/span_marker_trainer branch April 25, 2023 10:22
@frascuchon frascuchon mentioned this pull request May 9, 2023
@frascuchon frascuchon added this to the v1.7.0 milestone May 10, 2023
frascuchon added a commit that referenced this pull request May 10, 2023
## [1.7.0](v1.6.0...v1.7.0)

### Added

- Added `max_retries` and `num_threads` parameters to `rg.log` to run data
logging requests concurrently with a backoff retry policy. See
[#2458](#2458) and
[#2533](#2533)
- `rg.load` accepts `include_vectors` and `include_metrics` when loading
data. Closes [#2398](#2398)
- Added `settings` param to `prepare_for_training`
([#2689](#2689))
- Added `prepare_for_training` for `openai`
([#2658](#2658))
- Added `ArgillaOpenAITrainer`
([#2659](#2659))
- Added `ArgillaSpanMarkerTrainer` for Named Entity Recognition
([#2693](#2693))
- Added `ArgillaTrainer` CLI support. Closes
([#2809](#2809))

### Changed

- Argilla quickstart image dependencies are externalized into
`quickstart.requirements.txt`. See
[#2666](#2666)
- bulk endpoints will upsert data when record `id` is present. Closes
[#2535](#2535)
- moved from `click` to `typer` CLI support. Closes
([#2815](#2815))
- Argilla server docker image is built with PostgreSQL support. Closes
[#2686](#2686)
- `rg.log` now computes all batches and raises an error for all failed
batches.
- The default batch size for `rg.log` is now 100.

### Fixed

- `argilla.training` bugfixes and unification
([#2665](#2665))
- Resolved several small bugs in the `ArgillaTrainer`.

### Deprecated

- The `rg.log_async` function is deprecated and will be removed in the next
minor release.