Skip to content

Commit

Permalink
Merge pull request #223 from EricFillion/ef/gen-cuda
Browse files Browse the repository at this point in the history
Implemented Pipeline For Text Generation
  • Loading branch information
EricFillion committed May 8, 2021
2 parents fbb5eb5 + 5d9ae19 commit 7706600
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 30 deletions.
13 changes: 0 additions & 13 deletions docs/pages/1-text-generation/4-finetuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,6 @@ mlm_probability: The probability of masking a token.
args = GENTrainArgs(num_train_epochs=1)
happy_gen.train("../../data/gen/train-eval.txt", args=args)
```
Note: if you wish to use HappyGeneration.generate_text() after training then first run the following command:

```python
happy_gen.model.to("cpu")

```


### eval()
Input:
Expand Down Expand Up @@ -100,9 +93,3 @@ Output: An object with the field "loss"
print(result.loss) # 3.3437771797180176

```
Note: if you wish to use HappyGeneration.generate_text() after evaluating then first run the following command:

```python
happy_gen.model.to("cpu")

```
30 changes: 15 additions & 15 deletions happytransformer/happy_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
Contains the HappyGeneration class
"""
from dataclasses import dataclass
from transformers import AutoModelForCausalLM
from transformers import AutoModelForCausalLM, TextGenerationPipeline
from happytransformer.happy_transformer import HappyTransformer
from happytransformer.gen.trainer import GENTrainer, GENTrainArgs, GENEvalArgs
from happytransformer.adaptors import get_adaptor
from happytransformer.gen import ARGS_GEN_TRAIN, ARGS_GEN_EVAl, ARGS_GEN_TEST
from happytransformer.happy_trainer import EvalResult
from happytransformer.fine_tuning_util import create_args_dataclass
from happytransformer.cuda_detect import detect_cuda_device_number

"""
The main settings that users will adjust when performing experiments
Expand Down Expand Up @@ -50,6 +51,9 @@ def __init__(self, model_type: str = "GPT2", model_name: str = "gpt2", load_path
model = AutoModelForCausalLM.from_pretrained(model_name)

super().__init__(model_type, model_name, model)
device_number = detect_cuda_device_number()

self._pipeline = TextGenerationPipeline(model=self.model, tokenizer=self.tokenizer, device=device_number)

self._trainer = GENTrainer(self.model, model_type, self.tokenizer, self._device, self.logger)

Expand Down Expand Up @@ -78,20 +82,16 @@ def generate_text(self, text: str, args: GENSettings=GENSettings()) -> Generatio
adjusted_min_length = args.min_length + len(input_ids[0])
adjusted_max_length = args.max_length + len(input_ids[0])

output = self.model.generate(input_ids,
min_length=adjusted_min_length,
max_length=adjusted_max_length,
do_sample=args.do_sample,
early_stopping=args.early_stopping,
num_beams=args.num_beams,
temperature=args.temperature,
top_k=args.top_k,
no_repeat_ngram_size=args.no_repeat_ngram_size
)
result = self.tokenizer.decode(output[0], skip_special_tokens=True)
final_result = self.__post_process_generated_text(result, text)

return GenerationResult(text=final_result)
output = self._pipeline(text, min_length=adjusted_min_length,
return_full_text=False,
max_length=adjusted_max_length,
do_sample=args.do_sample,
early_stopping=args.early_stopping,
num_beams=args.num_beams,
temperature=args.temperature,
top_k=args.top_k,
no_repeat_ngram_size=args.no_repeat_ngram_size)
return GenerationResult(text=output[0]['generated_text'])


def __post_process_generated_text(self, result, text):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
setup(
name = 'happytransformer',
packages = find_packages(),
version = '2.2.1',
version = '2.2.2',
license='Apache 2.0',
description = "Happy Transformer is an API built on top of Hugging Face's Transformer library that makes it easy to utilize state-of-the-art NLP models.",
long_description= readme,
Expand Down
8 changes: 7 additions & 1 deletion tests/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def test_default_simple():
args = GENSettings(min_length=5, max_length=5)
output = happy_gen.generate_text("Artificial intelligence is ", args=args)
assert type(output.text) == str
print("default simple: ", output.text)


def test_default_min_max_length():
Expand Down Expand Up @@ -129,3 +128,10 @@ def test_gen_train_eval_with_dataclass():

assert type(after_result.loss) == float

def test_generate_after_train_eval():
    """Regression test: generate_text() must still work after train() and eval().

    Previously (pre-pipeline), the model had to be manually moved back to the
    CPU after training/evaluating before generate_text() could be called; this
    test guards against that requirement returning.
    """
    happy_gen = HappyGeneration()
    happy_gen.train("../data/gen/train-eval.txt")
    eval_result = happy_gen.eval("../data/gen/train-eval.txt")
    # Check the eval result too (consistent with the other train/eval tests),
    # rather than leaving eval_result unused.
    assert type(eval_result.loss) == float
    output = happy_gen.generate_text("Artificial intelligence is ")
    assert type(output.text) == str

0 comments on commit 7706600

Please sign in to comment.