Skip to content

Commit

Permalink
Merge pull request #9 from TeiaLabs/add-model-settings
Browse files Browse the repository at this point in the history
Add model settings
  • Loading branch information
martinduartemore committed Jan 19, 2023
2 parents 1bf6e9d + 48142a9 commit 6dfa879
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 14 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,6 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
#.idea/

htmlcov/
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@ pip install "prompts @ git+https://github.com/TeiaLabs/prompts.git

```python
template = 'a photo of a <img_label>'
expected_var = 'img_label'

prompt = DynamicPrompt(template, expected_var)
prompt = DynamicPrompt(template)
filled_prompt = prompt.build(img_label='dog')

print(filled_prompt)
Expand All @@ -30,6 +29,11 @@ str_prompt = prompt.build(
)
```

You can also access recommended model settings (engine, temperature) that can be fed to the model input (e.g., openai.Completion.create()):

```python
prompt.get_model_settings()
```

## Improve Autocomplete with custom prompts
Alternatively, to get more control and better autocomplete suggestions, you can inherit from the `BasePrompt` class and override the build method with explicit arguments:
Expand Down
24 changes: 20 additions & 4 deletions prompts/prompt_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import abstractmethod
from typing import Optional

from .exceptions import MissingArgumentError, VariableNotInPromptError, UndefinedVariableError
from .utils import load_yaml
Expand All @@ -12,12 +13,22 @@ class BasePrompt:
- build(self, var1, var2, ...)
'''

def __init__(self, prompt, template_vars=None):
def __init__(
self,
prompt: str,
template_vars: Optional[list[str]] = None,
settings: Optional[dict[str, str]] = None,
):
self.prompt = prompt
self.template_vars = template_vars
self.settings = settings

if template_vars is not None:
self._check_vars()

def get_model_settings(self) -> dict[str, str]:
return self.settings

def _check_vars(self, check_build=True):
for var in self.template_vars:
# check if var is an argument of self.build
Expand All @@ -32,16 +43,20 @@ def set_prompt_values(self, strict=True, **kwargs):
for var, value in kwargs.items():
pattern = f"<{var}>"
if pattern not in prompt and strict:
raise UndefinedVariableError(f"Variable {var} was not found in prompt (expected vars={self.template_vars}).")
raise UndefinedVariableError(
f"Variable {var} was not found in prompt (expected vars={self.template_vars})."
)
prompt = prompt.replace(pattern, value)
return prompt

@classmethod
def from_file(cls, prompt_file: str):
prompt = load_yaml(prompt_file)
settings = prompt.get('settings', None)
return cls(
prompt=prompt['prompt'],
template_vars=prompt['vars'],
template_vars=prompt['vars'],
settings=settings,
)

@abstractmethod
Expand All @@ -61,9 +76,10 @@ class DynamicPrompt(BasePrompt):
```
"""

def __init__(self, prompt, template_vars=None):
def __init__(self, prompt, template_vars=None, settings=None):
self.prompt = prompt
self.template_vars = template_vars
self.settings = settings
if template_vars is not None:
self._check_vars(check_build=False)

Expand Down
8 changes: 5 additions & 3 deletions samples/sample.prompt → samples/sample.prompt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ prompt: |
<input_sentence>
Fixed sentence:
top-1: 1
temperature: 0.15
vars:
- input_sentence
- input_sentence
settings:
top-k: 1
temperature: 0.15
engine: text-davinci-003
14 changes: 12 additions & 2 deletions tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from prompts import DynamicPrompt
from prompts import exceptions


def test_ensemble():
templates = ['<label>', 'a photo of <label>', 'picture of <label>']
template_vars = ['label']
Expand All @@ -19,6 +20,8 @@ def test_ensemble():
with pytest.raises(exceptions.UndefinedVariableError):
_ = prompt.build(img_class='cat')

assert len(prompt) == 3


def test_build_missing_args_valid():
templates = ['<label>/<superclass>', 'a photo of <label>']
Expand All @@ -37,6 +40,13 @@ def test_build_missing_args_valid():
assert prompted_list == expected


def test_build_missing_args():
templates = ['<label>/<superclass>', 'a photo of <label>']

with pytest.raises(exceptions.ExpectedVarsArgumentError):
PromptEnsemble(templates, expected_vars=None)


def test_build_missing_args_invalid():
templates = ['<label>', 'test']
template_vars = ['label']
Expand Down Expand Up @@ -135,7 +145,7 @@ def test_invalid_ensemble_template():
def test_prompt_ensemble_from_file():
prompts = []
for i in range(3):
prompt_file = 'samples/sample.prompt'
prompt_file = 'samples/sample.prompt.yaml'
prompt = DynamicPrompt.from_file(prompt_file)
prompts.append(prompt)

Expand All @@ -153,4 +163,4 @@ def test_prompt_ensemble_from_file():
input_sentence='lets go'
)

assert prompt_ff == prompt_str
assert prompt_ff == prompt_str
16 changes: 14 additions & 2 deletions tests/test_prompt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from prompts import DynamicPrompt
from prompts import DynamicPrompt, BasePrompt
from prompts import exceptions

class TestPrompt:
Expand All @@ -9,7 +9,7 @@ class TestPrompt:

@staticmethod
def test_prompt_from_file():
prompt_file = 'samples/sample.prompt'
prompt_file = 'samples/sample.prompt.yaml'
prompt = DynamicPrompt.from_file(prompt_file)
prompt_str = prompt.build(input_sentence='lets go')
assert 'lets go' in prompt_str
Expand All @@ -21,13 +21,25 @@ def test_prompt_from_file():
))
assert expected_prompt == prompt_str

settings = prompt.get_model_settings()
assert isinstance(settings, dict)
assert isinstance(settings['temperature'], float)
assert settings['temperature'] == 0.15
assert settings['engine'] == 'text-davinci-003'

@staticmethod
def test_str_prompt():

prompt = DynamicPrompt(TestPrompt.template, TestPrompt.template_vars)
filled_prompt = prompt.build(img_label='dog')
assert filled_prompt == 'a photo of a dog'

@staticmethod
def test_base_prompt():
# it has to throw exception because the build method is not implemented
with pytest.raises(exceptions.MissingArgumentError):
BasePrompt(TestPrompt.template, TestPrompt.template_vars)

@staticmethod
def test_str_prompt_without_vars():
prompt = DynamicPrompt(TestPrompt.template)
Expand Down

0 comments on commit 6dfa879

Please sign in to comment.