Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/2658 add argilla training module for openai with several bug fixes #2691

Conversation

davidberenstein1957
Copy link
Member

Description

Updated the argilla.training integration

Closes #2658
Closes #2665
Closes #2659

Type of change

(Please delete options that are not relevant. Remember to title the PR according to the type of change)

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Refactor (change restructuring the codebase without changing functionality)
  • Improvement (change adding some improvement to an existing functionality)
  • Documentation update

How Has This Been Tested

(Please describe the tests that you ran to verify your changes. And ideally, reference tests)

argilla/tests/training/*

Checklist

  • I have merged the original branch into my forked branch
  • I added relevant documentation
  • follows the style guidelines of this project
  • I did a self-review of my code
  • I made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/)

@davidberenstein1957 davidberenstein1957 marked this pull request as draft April 13, 2023 11:35
Copy link
Contributor

@tomaarsen tomaarsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lots of small improvements, I like it. I found a few nits that should be pretty easy to resolve.

src/argilla/client/datasets.py Outdated Show resolved Hide resolved
src/argilla/client/datasets.py Outdated Show resolved Hide resolved
src/argilla/client/datasets.py Outdated Show resolved Hide resolved
src/argilla/client/datasets.py Outdated Show resolved Hide resolved
src/argilla/client/datasets.py Outdated Show resolved Hide resolved
src/argilla/training/openai.py Outdated Show resolved Hide resolved
src/argilla/training/openai.py Outdated Show resolved Hide resolved
src/argilla/training/spacy.py Outdated Show resolved Hide resolved
src/argilla/training/spacy.py Show resolved Hide resolved
src/argilla/training/transformers.py Outdated Show resolved Hide resolved
@tomaarsen
Copy link
Contributor

Can we change this function name?

def cleanup(trainer: ArgillaTrainer, output_dir: str, train: bool = True):
try:
if train:
trainer.train(output_dir)
else:
trainer.save(output_dir)
assert Path(output_dir).exists()
finally:
shutil.rmtree(output_dir)

Perhaps into train_with_cleanup or something. It's not particularly intuitive that a "cleanup" method actually performs training.

Hello!

## Pull Request overview
* Add [SpanMarker](https://github.com/tomaarsen/SpanMarkerNER) Argilla
Trainer for Named Entity Recognition.

## Details
The SpanMarker Argilla trainer is based on the Transformers Trainer, as
SpanMarker is tightly implemented on top of transformers. However, we
don't need to do the tokenization, data collation or evaluation on the
Argilla side, unlike with Transformers. This makes the SpanMarker
Argilla trainer relatively small.

## Usage
First, we need an annotated dataset:
```python
import argilla as rg
from datasets import load_dataset

dataset = "conll2003"
dataset_ds = load_dataset(
    "conll2003",
    split="train[:1000]",
)
dataset_ds = dataset_ds.rename_column("ner_tags", "tags")
dataset_rb = rg.read_datasets(dataset_ds, task="TokenClassification")

rg.delete(dataset)
rg.log(name=dataset, records=dataset_rb)
```
And then we can use the new Trainer to train with this dataset:
```python
import argilla as rg
from argilla.training.base import ArgillaTrainer

dataset = "conll2003"

trainer = ArgillaTrainer(name=dataset, framework="span_marker", train_size=0.8)
trainer.update_config(
    num_train_epochs=10,
    bf16=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    marker_max_length=128,
    entity_max_length=8,
)
trainer.train(output_dir="tmp_span_marker_train")
```
(You can use lower batch sizes or `model_max_length=256` if you have
memory issues. You can also use `fp16` instead of `bf16` if you get an
error.)

This produces the following logs:

<details><summary>Click to see the logs</summary>

```
[04/13/23 16:25:25] WARNING  WARNING:argilla.client.datasets:No label schema provided. Using all_labels: TokenClassificationSettings(['LOC', 'MISC', 'ORG', 'PER']). We recommend      datasets.py:1222
                             providing a `TokenClassificationSettings()` or setting it `rg.configure_dataset_settings()`/`rg.load_dataset_settings()` to ensure reproducibility.       
[04/13/23 16:25:30] WARNING  WARNING:ArgillaTrainer:ArgillaBaseTrainer info:                                                                                                                base.py:175
                             _________________________________________________________________
                             These baseline params are fixed:
                                 dataset: conll2003
                                 task: DatasetForTokenClassification
                                 multi_label: False
                                 train_size: 0.8
                                 seed: None

                             <class 'argilla.training.span_marker.ArgillaSpanMarkerTrainer'> info:
                             _________________________________________________________________
                             The parameters are configurable via `trainer.update_config()`:
                                 'SpanMarkerModel'
                             pretrained_model_name_or_path: bert-base-cased
                             labels: ['O', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-ORG', 'I-ORG', 'B-PER', 'I-PER']
                             'Trainer'
                             overwrite_output_dir: False
                             do_train: False
                             do_eval: False
                             do_predict: False
                             evaluation_strategy: epoch
                             prediction_loss_only: False
                             per_device_train_batch_size: 8
                             per_device_eval_batch_size: 8
                             per_gpu_train_batch_size: None
                             per_gpu_eval_batch_size: None
                             gradient_accumulation_steps: 1
                             eval_accumulation_steps: None
                             eval_delay: 0
                             learning_rate: 5e-05
                             weight_decay: 0.01
                             adam_beta1: 0.9
                             adam_beta2: 0.999
                             adam_epsilon: 1e-08
                             max_grad_norm: 1.0
                             num_train_epochs: 3.0
                             max_steps: -1
                             lr_scheduler_type: linear
                             warmup_ratio: 0.0
                             warmup_steps: 0
                             log_level: passive
                             log_level_replica: warning
                             log_on_each_node: True
                             logging_dir: None
                             logging_strategy: steps
                             logging_first_step: False
                             logging_steps: 30
                             logging_nan_inf_filter: True
                             save_strategy: steps
                             save_steps: 500
                             save_total_limit: None
                             save_on_each_node: False
                             no_cuda: False
                             use_mps_device: False
                             seed: 42
                             data_seed: None
                             jit_mode_eval: False
                             use_ipex: False
                             bf16: False
                             fp16: False
                             fp16_opt_level: O1
                             half_precision_backend: auto
                             bf16_full_eval: False
                             fp16_full_eval: False
                             tf32: None
                             local_rank: -1
                             xpu_backend: None
                             tpu_num_cores: None
                             tpu_metrics_debug: False
                             debug:
                             dataloader_drop_last: False
                             eval_steps: None
                             dataloader_num_workers: 0
                             past_index: -1
                             run_name: None
                             disable_tqdm: None
                             remove_unused_columns: True
                             label_names: None
                             load_best_model_at_end: False
                             metric_for_best_model: None
                             greater_is_better: None
                             ignore_data_skip: False
                             sharded_ddp:
                             fsdp:
                             fsdp_min_num_params: 0
                             fsdp_config: None
                             fsdp_transformer_layer_cls_to_wrap: None
                             deepspeed: None
                             label_smoothing_factor: 0.0
                             optim: adamw_hf
                             optim_args: None
                             adafactor: False
                             group_by_length: False
                             length_column_name: length
                             report_to: None
                             ddp_find_unused_parameters: None
                             ddp_bucket_cap_mb: None
                             dataloader_pin_memory: True
                             skip_memory_metrics: True
                             use_legacy_prediction_loop: False
                             push_to_hub: False
                             resume_from_checkpoint: None
                             hub_model_id: None
                             hub_strategy: every_save
                             hub_token: None
                             hub_private_repo: False
                             gradient_checkpointing: False
                             include_inputs_for_metrics: False
                             fp16_backend: auto
                             push_to_hub_model_id: None
                             push_to_hub_organization: None
                             push_to_hub_token: None
                             mp_parameters:
                             auto_find_batch_size: False
                             full_determinism: False
                             torchdynamo: None
                             ray_scope: last
                             ddp_timeout: 1800
                             torch_compile: False
                             torch_compile_backend: None
                             torch_compile_mode: None
                             output_dir: None

                             Using the trainer:
                             _________________________________________________________________
                             `trainer.train(output_dir)` to train to start training. `output_dir` is the directory to save the model automatically.
                             `trainer.predict(text, as_argilla_records=True)` to make predictions.
                             `trainer.save(output_dir)` to save the model manually.
Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
C:\Users\tom\.conda\envs\argilla\lib\site-packages\transformers\optimization.py:395: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  FutureWarning,
{'loss': 0.2601, 'learning_rate': 4.754098360655738e-05, 'epoch': 0.49}
{'loss': 0.066, 'learning_rate': 4.508196721311476e-05, 'epoch': 0.98}                                                                                                                                 
{'eval_loss': 0.03822080418467522, 'eval_overall_precision': 0.8810289389067524, 'eval_overall_recall': 0.6372093023255814, 'eval_overall_f1': 0.7395411605937922, 'eval_overall_accuracy': 0.9204204204204204, 'eval_runtime': 4.5471, 'eval_samples_per_second': 57.619, 'eval_steps_per_second': 3.739, 'epoch': 1.0}                                                                                      
{'loss': 0.0378, 'learning_rate': 4.262295081967213e-05, 'epoch': 1.48}                                                                                                                                
{'loss': 0.029, 'learning_rate': 4.016393442622951e-05, 'epoch': 1.97}
{'eval_loss': 0.021275097504258156, 'eval_overall_precision': 0.8808933002481389, 'eval_overall_recall': 0.8255813953488372, 'eval_overall_f1': 0.8523409363745499, 'eval_overall_accuracy': 0.9602102102102102, 'eval_runtime': 4.5328, 'eval_samples_per_second': 57.8, 'eval_steps_per_second': 3.75, 'epoch': 2.0}                                                                                        
{'loss': 0.0169, 'learning_rate': 3.7704918032786885e-05, 'epoch': 2.46}
{'loss': 0.015, 'learning_rate': 3.524590163934427e-05, 'epoch': 2.95}
{'eval_loss': 0.013853945769369602, 'eval_overall_precision': 0.9527363184079602, 'eval_overall_recall': 0.8906976744186047, 'eval_overall_f1': 0.920673076923077, 'eval_overall_accuracy': 0.9774774774774775, 'eval_runtime': 4.4899, 'eval_samples_per_second': 58.353, 'eval_steps_per_second': 3.786, 'epoch': 3.0}                                                                                      
{'loss': 0.01, 'learning_rate': 3.2786885245901635e-05, 'epoch': 3.44}
{'loss': 0.0089, 'learning_rate': 3.0327868852459017e-05, 'epoch': 3.93}
{'eval_loss': 0.013151775114238262, 'eval_overall_precision': 0.948780487804878, 'eval_overall_recall': 0.9046511627906977, 'eval_overall_f1': 0.9261904761904762, 'eval_overall_accuracy': 0.9786036036036037, 'eval_runtime': 4.5472, 'eval_samples_per_second': 57.618, 'eval_steps_per_second': 3.739, 'epoch': 4.0}                                                                                      
{'loss': 0.0056, 'learning_rate': 2.7868852459016392e-05, 'epoch': 4.43}                                                                                                                               
{'loss': 0.0059, 'learning_rate': 2.540983606557377e-05, 'epoch': 4.92}
{'eval_loss': 0.012591547332704067, 'eval_overall_precision': 0.9587378640776699, 'eval_overall_recall': 0.9186046511627907, 'eval_overall_f1': 0.9382422802850356, 'eval_overall_accuracy': 0.9812312312312312, 'eval_runtime': 4.5184, 'eval_samples_per_second': 57.986, 'eval_steps_per_second': 3.762, 'epoch': 5.0}                                                                                     
{'loss': 0.0036, 'learning_rate': 2.295081967213115e-05, 'epoch': 5.41}
{'loss': 0.0044, 'learning_rate': 2.0491803278688525e-05, 'epoch': 5.9}
{'eval_loss': 0.012911035679280758, 'eval_overall_precision': 0.9539951573849879, 'eval_overall_recall': 0.9162790697674419, 'eval_overall_f1': 0.9347568208778173, 'eval_overall_accuracy': 0.9808558558558559, 'eval_runtime': 4.3026, 'eval_samples_per_second': 60.893, 'eval_steps_per_second': 3.951, 'epoch': 6.0}                                                                                     
{'loss': 0.0031, 'learning_rate': 1.8032786885245903e-05, 'epoch': 6.39}
{'loss': 0.0024, 'learning_rate': 1.557377049180328e-05, 'epoch': 6.89}
{'eval_loss': 0.010593990795314312, 'eval_overall_precision': 0.9567307692307693, 'eval_overall_recall': 0.9255813953488372, 'eval_overall_f1': 0.9408983451536642, 'eval_overall_accuracy': 0.9838588588588588, 'eval_runtime': 4.5172, 'eval_samples_per_second': 58.001, 'eval_steps_per_second': 3.763, 'epoch': 7.0}                                                                                     
{'loss': 0.0018, 'learning_rate': 1.3114754098360657e-05, 'epoch': 7.38}
{'loss': 0.0025, 'learning_rate': 1.0655737704918032e-05, 'epoch': 7.87}
{'eval_loss': 0.010297469794750214, 'eval_overall_precision': 0.9478672985781991, 'eval_overall_recall': 0.9302325581395349, 'eval_overall_f1': 0.9389671361502349, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.4597, 'eval_samples_per_second': 58.748, 'eval_steps_per_second': 3.812, 'epoch': 8.0}                                                                                     
{'loss': 0.0011, 'learning_rate': 8.196721311475409e-06, 'epoch': 8.36}                                                                                                                                
{'loss': 0.0009, 'learning_rate': 5.737704918032787e-06, 'epoch': 8.85}
{'eval_loss': 0.009832620620727539, 'eval_overall_precision': 0.9478672985781991, 'eval_overall_recall': 0.9302325581395349, 'eval_overall_f1': 0.9389671361502349, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.6756, 'eval_samples_per_second': 56.036, 'eval_steps_per_second': 3.636, 'epoch': 9.0}                                                                                     
{'loss': 0.0015, 'learning_rate': 3.278688524590164e-06, 'epoch': 9.34}                                                                                                                                
{'loss': 0.0013, 'learning_rate': 8.19672131147541e-07, 'epoch': 9.84}
{'eval_loss': 0.010347267612814903, 'eval_overall_precision': 0.9457547169811321, 'eval_overall_recall': 0.9325581395348838, 'eval_overall_f1': 0.9391100702576113, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.7101, 'eval_samples_per_second': 55.625, 'eval_steps_per_second': 3.609, 'epoch': 10.0}                                                                                    
{'train_runtime': 297.0633, 'train_samples_per_second': 32.855, 'train_steps_per_second': 2.053, 'train_loss': 0.023517039891515597, 'epoch': 10.0}                                                   
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 610/610 [04:57<00:00,  2.05it/s] 
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:04<00:00,  3.84it/s] 
[04/13/23 16:30:34] INFO     INFO:ArgillaSpanMarkerTrainer:{'eval_loss': 0.010347267612814903, 'eval_overall_precision': 0.9457547169811321, 'eval_overall_recall':                  span_marker.py:133
                             0.9325581395348838, 'eval_overall_f1': 0.9391100702576113, 'eval_overall_accuracy': 0.9846096096096096, 'eval_runtime': 4.6647,
                             'eval_samples_per_second': 56.166, 'eval_steps_per_second': 3.644, 'epoch': 10.0}
```
</details>

In short, I trained to 0.939 eval F1 on CoNLL03 in 5 minutes.

### Type of change

- [x] New feature (non-breaking change which adds functionality)

### How Has This Been Tested
Tests still need to be written. I'll be working on this - but I'll
publish this as a draft already so it's available for reviews already.

<!--
**Checklist**

- [ ] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [ ] follows the style guidelines of this project
- [ ] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
-->

- Tom Aarsen

---------

Co-authored-by: David Berenstein <david.m.berenstein@gmail.com>
src/argilla/training/openai.py Outdated Show resolved Hide resolved
src/argilla/datasets/__init__.py Show resolved Hide resolved
davidberenstein1957 and others added 2 commits April 26, 2023 09:45
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Copy link
Member

@alvarobartt alvarobartt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just some minor comments cc. @davidberenstein1957

src/argilla/training/openai.py Outdated Show resolved Hide resolved
tests/training/test_openai.py Outdated Show resolved Hide resolved
Copy link
Contributor

@tomaarsen tomaarsen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

openai isn't in the CI dependencies yet.

@davidberenstein1957 davidberenstein1957 marked this pull request as ready for review April 29, 2023 07:01
@tomaarsen tomaarsen mentioned this pull request May 2, 2023
pyproject.toml Outdated Show resolved Hide resolved
davidberenstein1957 and others added 2 commits May 2, 2023 11:07
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
fix: replaced pass for docstrings abc
chore: added tests spacy wo training
chore: added tests transformers wo training
chore: added specific setfit versioning
@davidberenstein1957 davidberenstein1957 changed the base branch from develop to releases/1.7.0 May 3, 2023 09:44
@davidberenstein1957 davidberenstein1957 merged commit fbb3533 into releases/1.7.0 May 3, 2023
@davidberenstein1957 davidberenstein1957 deleted the feat/2658-add-argilla-training-module-for-openai branch May 3, 2023 10:40
@frascuchon frascuchon restored the feat/2658-add-argilla-training-module-for-openai branch May 8, 2023 12:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants