Use GenerationConfig
amaiya committed Apr 1, 2023
1 parent 349e6ae commit 5cef6ce
Showing 3 changed files with 27 additions and 14 deletions.
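
In outline, this commit moves generation settings such as `max_new_tokens` out of `execute` and into the `GenerativeAI` constructor. The following is a minimal sketch of that shape, not the ktrain implementation: `SketchGenerativeAI` and `fake_generator` are hypothetical stand-ins, a plain dict takes the place of `transformers.GenerationConfig`, and the settings are passed at call time here rather than to `pipeline(...)`, so the example runs without downloading a model.

```python
from typing import Any, Callable, Dict, List


class SketchGenerativeAI:
    """Hypothetical stand-in for ktrain's GenerativeAI (illustration only)."""

    def __init__(
        self,
        generator: Callable[..., List[Dict[str, str]]],
        max_new_tokens: int = 512,
        **kwargs: Any,
    ):
        # Generation settings are fixed once, at construction time.
        # A plain dict stands in for transformers.GenerationConfig (assumption).
        self.config = dict(max_new_tokens=max_new_tokens, **kwargs)
        self.generator = generator

    def execute(self, prompt: str) -> str:
        # execute() takes only the prompt; settings come from the constructor.
        prompt = prompt.strip() + "\n"
        result = self.generator(prompt, **self.config)[0]["generated_text"]
        if result.startswith(prompt):
            result = result[len(prompt):]
        return result


def fake_generator(prompt: str, **settings: Any) -> List[Dict[str, str]]:
    # Assumed stand-in for a text-generation pipeline: it echoes the prompt
    # and reports the settings it received.
    return [{"generated_text": prompt + f"settings={sorted(settings.items())}"}]


model = SketchGenerativeAI(fake_generator, max_new_tokens=64, min_new_tokens=5)
print(model.execute("Say hello."))
```

With the real class, the equivalent call would presumably look like `GenerativeAI(device='cpu', max_new_tokens=64)` followed by `model.execute(prompt)`; callers that previously passed `max_new_tokens` to `execute` would need to move it to the constructor.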
4 changes: 3 additions & 1 deletion README.md
@@ -13,8 +13,9 @@

### News and Announcements
- **2023-03-31**
- **ktrain 0.35.x** is released and supports Generative AI using an instruction-fine-tuned version of GPT-J that can run on your own machine. See the [example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/generative_ai_example.ipynb).
- **ktrain 0.35.x** is released and supports Generative AI using an instruction-fine-tuned version of GPT-J that can run on your own machine. See the [example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/generative_ai_example.ipynb) for more information.
```python
# Example: Generative AI in ktrain
from ktrain.text.generative_ai import GenerativeAI
model = GenerativeAI(device='cpu') # use device='cuda' if you have a good GPU!
prompt = """Extract the names of people in the supplied sentences. Here is an example:
@@ -58,6 +59,7 @@ print(model.execute(prompt))
- **Speech Transcription**: Extract text from audio files <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/speech_transcription_example.ipynb)]</sup></sub>
- **Universal Information Extraction**: extract any kind of information from documents by simply phrasing it in the form of a question <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/text/qa_information_extraction.ipynb)]</sup></sub>
- **Keyphrase Extraction**: extract keywords from documents <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/keyword_extraction_example.ipynb)]</sup></sub>
- **Generative AI with GPT**: Provide instructions to a lightweight ChatGPT-like model running on your own machine to solve various tasks. <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/develop/examples/text/generative_ai_example.ipynb)]</sup></sub>
- `vision` data:
- **image classification** (e.g., [ResNet](https://arxiv.org/abs/1512.03385), [Wide ResNet](https://arxiv.org/abs/1605.07146), [Inception](https://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf)) <sub><sup>[[example notebook](https://colab.research.google.com/drive/1WipQJUPL7zqyvLT10yekxf_HNMXDDtyR)]</sup></sub>
- **image regression** for predicting numerical targets from photos (e.g., age prediction) <sub><sup>[[example notebook](https://nbviewer.jupyter.org/github/amaiya/ktrain/blob/master/examples/vision/utk_faces_age_prediction-resnet50.ipynb)]</sup></sub>
8 changes: 4 additions & 4 deletions examples/text/generative_ai_example.ipynb
@@ -16,9 +16,9 @@
"source": [
"## Generative AI with *ktrain*\n",
"\n",
"*ktrain* supports a Generative AI that is currently based on an instruction-fine-tuned version of GPT-J. Think of it as a lightweight version of ChatGPT that can be run locally on your own machine. As a smaller model, it will not perform as well as GPT-4, ChatGPT, etc. However, since it does not communicate with external APIs like OpenAI, it can be used with non-public data.\n",
"*ktrain* supports a Generative AI module that is currently based on an instruction-fine-tuned version of GPT-J. Think of it as a lightweight version of ChatGPT that can be run locally on your own machine. As a smaller model, it will not perform as well as GPT-4, ChatGPT, etc. However, since it does not communicate with external APIs like OpenAI, it can be used with non-public data.\n",
"\n",
"The model requires a GPU with at least 16GB of GPU memory (VRAM). If you have less than this, you can use a CPU (provided it has at least 16GB of RAM), but output will be generated **very** slowly. We will use a CPU in this example, but you should supply `device='cuda'` if you have a GPU with at least 16GB of GPU memory."
"The model requires a GPU with at least 16GB of GPU memory (VRAM). If you have less than this, you can use a CPU (provided it has at least 16GB of RAM), but output will be generated very slowly (depending on the number of CPU cores). We will use a CPU in this example, but you should supply `device='cuda'` if you have a GPU with at least 16GB of GPU memory."
]
},
{
@@ -415,7 +415,7 @@
"[Tweet]: \n",
"Startups should not worry about how to put out fires, they should worry about how to start them.\n",
"###\n",
"[Keyword]: \n",
"[Keyword]: http://localhost:7999/notebooks/examples/text/generative_ai_example.ipynb#\n",
"climate change\n",
"[Tweet]:\"\"\"\n",
"print(model.execute(prompt))"
@@ -426,7 +426,7 @@
"metadata": {},
"source": [
"## Final Comments\n",
"The `execute` method accepts parameters that are fed directly to the generative model. You can change them as necessary. The default value for `max_new_tokens` (the upper limit on the length of generated answers) has been set to 512."
"The constructor for `GenerativeAI` accepts parameters that are fed directly to the `generate` method of the underlying model. You can change them as necessary. The default value for `max_new_tokens` (the upper limit on the length of generated answers) has been set to 512."
]
},
{
29 changes: 20 additions & 9 deletions ktrain/text/generative_ai/core.py
@@ -1,4 +1,4 @@
from transformers import pipeline
from transformers import pipeline, GenerationConfig
import torch
from ...torch_base import TorchBase
from typing import Optional
@@ -14,33 +14,45 @@ def __init__(
self,
model_name: str = "nlpcloud/instruct-gpt-j-fp16",
device: Optional[str] = None,
max_new_tokens: int = 512,
**kwargs
):
"""
```
interface to GenerativeAI models using the transformers library
Interface to GenerativeAI models using the transformers library.
Extra kwargs are supplied directly to the generate method of the model.
Args:
model_name(str): name of the model. Currently, only the nlpcloud/instruct-gpt-j-fp16 model is supported.
device(str): device to use ("cpu" for CPU, "cuda" for GPU, "cuda:0" for first GPU, "cuda:1" for second GPU, etc.)
max_new_tokens(int): The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
```
"""

super().__init__(device=device)
self.device_id = self.device_to_id()
self.config = GenerationConfig(max_new_tokens=max_new_tokens, **kwargs)
if self.device_id < 0:
self.generator = pipeline(model=model_name, device=self.device_id)
self.generator = pipeline(
model=model_name, device=self.device_id, generation_config=self.config
)
else:
self.generator = pipeline(
model=model_name, torch_dtype=torch.float16, device=self.device_id
model=model_name,
torch_dtype=torch.float16,
device=self.device_id,
generation_config=self.config,
)
self.generator.model.generation_config.pad_token_id = (
self.generator.model.generation_config.eos_token_id
)

def execute(self, prompt: str, max_new_tokens: int = 512, **kwargs):
def execute(self, prompt: str):
"""
```
Issue a prompt to the model. The default model is an instruction-fine-tuned model based on GPT-J.
This means that you should always construct your prompt in the form of an instruction.
In addition to max_new_tokens, additional parameters can be supplied that will be fed directly to the model.
Examples include min_new_tokens and max_time.
Example:
@@ -51,13 +63,12 @@ def execute(self, prompt: str, max_new_tokens: int = 512, **kwargs):
Args:
prompt(str): prompt to supply to model
max_new_tokens(int): The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
Returns:
str: generated text
```
"""
prompt = prompt.strip() + "\n"
result = self.generator(prompt, max_new_tokens=512, **kwargs)
result = self.generator(prompt)
result = result[0]["generated_text"]
if result.startswith(prompt):
result = result.replace(prompt, "")
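
A note on the unchanged context lines above: `str.replace` removes every occurrence of the prompt from the generated text, not only the leading echo. The difference only shows up if the model repeats the prompt later in its output, but a prefix-only strip avoids it. The sketch below illustrates the distinction; `strip_prompt_echo` is a hypothetical helper and is not part of ktrain.

```python
def strip_prompt_echo(generated: str, prompt: str) -> str:
    """Remove the prompt only when it appears as a leading echo."""
    if generated.startswith(prompt):
        return generated[len(prompt):]
    return generated


prompt = "Repeat: hi\n"
# Made-up output in which the model echoes the prompt and then repeats it again.
generated = prompt + "hi\nRepeat: hi\n"

print(repr(generated.replace(prompt, "")))          # all occurrences removed
print(repr(strip_prompt_echo(generated, prompt)))   # only the leading echo removed
```

On Python 3.9 and later, `generated.removeprefix(prompt)` does the same thing as the helper.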
