Merge pull request #491 from QData/CUDA_out_of_memory_fix
Documents the fix for the new training commands.
qiyanjun committed Jul 27, 2021
2 parents ac8872a + 17ab22c commit 8f36926
Showing 10 changed files with 41 additions and 49 deletions.
README.md (9 changes: 2 additions & 7 deletions)
@@ -380,18 +380,13 @@ automatically loaded using the `datasets` package.
#### Training Examples
*Train our default LSTM for 50 epochs on the Yelp Polarity dataset:*
```bash
-textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 50 --learning-rate 1e-5
+textattack train --model-name-or-path lstm --dataset yelp_polarity --epochs 50 --learning-rate 1e-5
```

The training process has data augmentation built-in:
```bash
textattack train --model lstm --dataset rotten_tomatoes --augment eda --pct-words-to-swap .1 --transformations-per-example 4
```
This uses the `EasyDataAugmenter` recipe to augment the `rotten_tomatoes` dataset before training.
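
The same recipe is available from Python. A minimal sketch, assuming the `EasyDataAugmenter` constructor mirrors the CLI flags above:

```python
from textattack.augmentation import EasyDataAugmenter

# Mirrors `--pct-words-to-swap .1 --transformations-per-example 4`;
# keyword names assumed to match the CLI flags.
augmenter = EasyDataAugmenter(pct_words_to_swap=0.1, transformations_per_example=4)
print(augmenter.augment("the film was a quiet, affecting surprise."))
```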

*Fine-Tune `bert-base` on the `CoLA` dataset for 5 epochs:*
```bash
-textattack train --model bert-base-uncased --dataset glue^cola --batch-size 32 --epochs 5
+textattack train --model-name-or-path bert-base-uncased --dataset glue^cola --per-device-train-batch-size 8 --epochs 5
```
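
These are the new training commands this PR documents. A minimal sketch of the same fine-tuning run driven from Python, assuming the `textattack.Trainer` and `textattack.TrainingArgs` API that ships alongside these flags:

```python
import textattack
import transformers

model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

train_dataset = textattack.datasets.HuggingFaceDataset("glue", subset="cola", split="train")
eval_dataset = textattack.datasets.HuggingFaceDataset("glue", subset="cola", split="validation")

# A smaller per-device batch is the first knob to turn when CUDA
# runs out of memory; 8 matches the command above.
training_args = textattack.TrainingArgs(
    num_epochs=5,
    learning_rate=1e-5,
    per_device_train_batch_size=8,
)
trainer = textattack.Trainer(
    model_wrapper,
    "classification",
    None,  # no adversarial attack: plain fine-tuning
    train_dataset,
    eval_dataset,
    training_args,
)
trainer.train()
```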


README_ZH.md (9 changes: 2 additions & 7 deletions)
@@ -362,18 +362,13 @@ it's a enigma how the filmmaking wo be publicized in this condition .,0
#### Training Examples
*Train TextAttack's default LSTM for 50 epochs on the Yelp Polarity dataset:*
```bash
-textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 50 --learning-rate 1e-5
+textattack train --model-name-or-path lstm --dataset yelp_polarity --epochs 50 --learning-rate 1e-5
```

The training interface also has data augmentation built in:
```bash
textattack train --model lstm --dataset rotten_tomatoes --augment eda --pct-words-to-swap .1 --transformations-per-example 4
```
This example uses the `EasyDataAugmenter` recipe to augment the `rotten_tomatoes` dataset before training.

*Fine-tune `bert-base` on the `CoLA` dataset for 5 epochs:*
```bash
-textattack train --model bert-base-uncased --dataset glue^cola --batch-size 32 --epochs 5
+textattack train --model-name-or-path bert-base-uncased --dataset glue^cola --per-device-train-batch-size 8 --epochs 5
```


docs/1start/FAQ.md (9 changes: 2 additions & 7 deletions)
@@ -54,18 +54,13 @@ pip install .[dev]

For example, you can *train our default LSTM for 50 epochs on the Yelp Polarity dataset:*
```bash
-textattack train --model lstm --dataset yelp_polarity --batch-size 64 --epochs 50 --learning-rate 1e-5
+textattack train --model-name-or-path lstm --dataset yelp_polarity --epochs 50 --learning-rate 1e-5
```

The training process has data augmentation built-in:
```bash
textattack train --model lstm --dataset rotten_tomatoes --augment eda --pct-words-to-swap .1 --transformations-per-example 4
```
This uses the `EasyDataAugmenter` recipe to augment the `rotten_tomatoes` dataset before training.

*Fine-Tune `bert-base` on the `CoLA` dataset for 5 epochs:*
```bash
-textattack train --model bert-base-uncased --dataset glue^cola --batch-size 32 --epochs 5
+textattack train --model-name-or-path bert-base-uncased --dataset glue^cola --per-device-train-batch-size 8 --epochs 5
```


docs/3recipes/augmenter_recipes.rst (29 changes: 28 additions & 1 deletion)
@@ -1,7 +1,34 @@
Augmenter Recipes API
=====================

-Transformations and constraints can be used for simple NLP data augmentations. Here is a list of recipes for NLP data augmentations
+Summary: Transformations and constraints can be used for simple NLP data augmentations.

In addition to the command-line interface, you can augment text dynamically by importing the
`Augmenter` in your own code. All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter` in a Python script:

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> augmenter = EmbeddingAugmenter()
>>> s = 'What I cannot create, I do not understand.'
>>> augmenter.augment(s)
['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
```
You can also create your own augmenter from scratch by importing transformations/constraints from `textattack.transformations` and `textattack.constraints`. Here's an example that generates augmentations of a string using `WordSwapRandomCharacterDeletion`:

```python
>>> from textattack.transformations import WordSwapRandomCharacterDeletion
>>> from textattack.transformations import CompositeTransformation
>>> from textattack.augmentation import Augmenter
>>> transformation = CompositeTransformation([WordSwapRandomCharacterDeletion()])
>>> augmenter = Augmenter(transformation=transformation, transformations_per_example=5)
>>> s = 'What I cannot create, I do not understand.'
>>> augmenter.augment(s)
['What I cannot creae, I do not understand.', 'What I cannot creat, I do not understand.', 'What I cannot create, I do not nderstand.', 'What I cannot create, I do nt understand.', 'Wht I cannot create, I do not understand.']
```
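
`augment_many` is mentioned above but never shown. A minimal sketch, assuming it maps `augment` over a list and returns one list of augmentations per input string:

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> augmenter = EmbeddingAugmenter(transformations_per_example=2)
>>> sentences = ['What I cannot create, I do not understand.', 'Hell is other people.']
>>> augmenter.augment_many(sentences)  # one list of augmented strings per input
```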


Here is a list of recipes for NLP data augmentation:

.. automodule:: textattack.augmentation.recipes
:members:
docs/3recipes/augmenter_recipes_cmd.md (24 changes: 0 additions & 24 deletions)
@@ -61,27 +61,3 @@ it's a enigma how the filmmaking wo be publicized in this condition .,0

The 'embedding' augmentation recipe uses counterfitted embedding nearest-neighbors to augment data.

### Augmentation Python API Interface
In addition to the command-line interface, you can augment text dynamically by importing the
`Augmenter` in your own code. All `Augmenter` objects implement `augment` and `augment_many` to generate augmentations
of a string or a list of strings. Here's an example of how to use the `EmbeddingAugmenter` in a python script:

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> augmenter = EmbeddingAugmenter()
>>> s = 'What I cannot create, I do not understand.'
>>> augmenter.augment(s)
['What I notable create, I do not understand.', 'What I significant create, I do not understand.', 'What I cannot engender, I do not understand.', 'What I cannot creating, I do not understand.', 'What I cannot creations, I do not understand.', 'What I cannot create, I do not comprehend.', 'What I cannot create, I do not fathom.', 'What I cannot create, I do not understanding.', 'What I cannot create, I do not understands.', 'What I cannot create, I do not understood.', 'What I cannot create, I do not realise.']
```
You can also create your own augmenter from scratch by importing transformations/constraints from `textattack.transformations` and `textattack.constraints`. Here's an example that generates augmentations of a string using `WordSwapRandomCharacterDeletion`:

```python
>>> from textattack.transformations import WordSwapRandomCharacterDeletion
>>> from textattack.transformations import CompositeTransformation
>>> from textattack.augmentation import Augmenter
>>> transformation = CompositeTransformation([WordSwapRandomCharacterDeletion()])
>>> augmenter = Augmenter(transformation=transformation, transformations_per_example=5)
>>> s = 'What I cannot create, I do not understand.'
>>> augmenter.augment(s)
['What I cannot creae, I do not understand.', 'What I cannot creat, I do not understand.', 'What I cannot create, I do not nderstand.', 'What I cannot create, I do nt understand.', 'Wht I cannot create, I do not understand.']
```
examples/train/train_albert_snli_entailment.sh (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@
# Trains `albert-base-v2` on the SNLI entailment task for 5 epochs. This is a
# demonstration of how our training script can handle different `transformers`
# models and customize them for different datasets.
-textattack train --model albert-base-v2 --dataset snli --batch-size 128 --epochs 5 --max-length 128 --learning-rate 1e-5 --allowed-labels 0 1 2
+textattack train --model-name-or-path albert-base-v2 --dataset snli --per-device-train-batch-size 8 --epochs 5 --learning-rate 1e-5
examples/train/train_bert_stsb_similarity.sh (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
#!/bin/bash
# Trains `bert-base-cased` on the STS-B task for 3 epochs. This is a demonstration
# of how our training script handles regression.
-textattack train --model bert-base-cased --dataset glue^stsb --batch-size 128 --epochs 3 --max-length 128 --learning-rate 1e-5
+textattack train --model-name-or-path bert-base-cased --dataset glue^stsb --epochs 3 --learning-rate 1e-5
@@ -1,4 +1,4 @@
#!/bin/bash
# Trains TextAttack's default LSTM on the Rotten Tomatoes dataset for 50 epochs. This is a basic
# demonstration of our training script and `datasets` integration.
-textattack train --model lstm --dataset rotten_tomatoes --batch-size 64 --epochs 50 --learning-rate 1e-5
+textattack train --model-name-or-path lstm --dataset rotten_tomatoes --epochs 50 --learning-rate 1e-5
textattack/models/wrappers/huggingface_model_wrapper.py (2 changes: 2 additions & 0 deletions)
@@ -10,6 +10,8 @@

from .pytorch_model_wrapper import PyTorchModelWrapper

+torch.cuda.empty_cache()


class HuggingFaceModelWrapper(PyTorchModelWrapper):
"""Loads a HuggingFace ``transformers`` model and tokenizer."""
textattack/models/wrappers/pytorch_model_wrapper.py (2 changes: 2 additions & 0 deletions)
@@ -11,6 +11,8 @@

from .model_wrapper import ModelWrapper

+torch.cuda.empty_cache()


class PyTorchModelWrapper(ModelWrapper):
"""Loads a PyTorch model (`nn.Module`) and tokenizer.
