Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 74 additions & 40 deletions langtest/augmentation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ def __init__(
task: TaskManager,
generate_templates=False,
show_templates=False,
num_extra_templates=10,
) -> None:
"""This constructor for the TemplaticAugment class.

Expand All @@ -336,53 +337,25 @@ def __init__(
self.__task = task

if generate_templates:
if try_import_lib("openai"):
import openai
from pydantic import BaseModel

client = openai.OpenAI()

class Templates(BaseModel):
templates: List[str]

try:
given_template = self.__templates[:]
for template in given_template:
prompt = f"""Based on the template provided, create 10 new and unique templates that are variations on this theme. Present these as a Python list, with each template as a quoted string. The list should contain only the templates without any additional text or explanation.

Template:
"{template}"

"""

response = client.beta.chat.completions.parse(
model="gpt-4o-mini",
messages=[
{"role": "system", "content": "Action: Generate templates"},
{"role": "user", "content": prompt},
],
max_tokens=500,
temperature=0,
response_format=Templates,
generated_templates: List[str] = self.__generate_templates(
template, num_extra_templates
)

generated_response = response.choices[0].message.parsed
# Process the generated response
if generated_response:
# # Assuming the response format is a Python-like list in a string
# templates_list = generated_response.strip("[]").split('",')
# templates_list = [
# template.strip().strip('"')
# for template in templates_list
# if template.strip()
# ]
while len(generated_templates) < num_extra_templates:
temp_templates = self.__generate_templates(
template, num_extra_templates
)
generated_templates.extend(temp_templates)

if generated_templates:
# Extend the existing templates list
self.__templates.extend(generated_response.templates)
else:
print("No response or unexpected format.")

else:
raise RuntimeError(Errors.E084)
self.__templates.extend(generated_templates[:num_extra_templates])
except Exception as e:
raise Errors.E095(e)

if show_templates:
[print(template) for template in self.__templates]
Expand Down Expand Up @@ -622,3 +595,64 @@ def add_spaces_around_punctuation(text: str):
text = re.sub(r"\s+", " ", text).strip()

return text

def __generate_templates(self, template, num_extra_templates) -> List[str]:
    """Ask the OpenAI chat API for variations of ``template``.

    The model ("gpt-4o-mini", temperature 0, at most 500 completion tokens)
    is asked for ``num_extra_templates`` rewrites of ``template``. Its
    structured reply is parsed into a list of strings, surrounding double
    quotes are stripped from each entry, and every candidate whose set of
    ``{placeholder}`` names differs from the original template's is dropped.

    Args:
        template: The seed template containing ``{placeholder}`` variables.
        num_extra_templates: Number of variations to request; also the
            maximum number of templates returned.

    Returns:
        Up to ``num_extra_templates`` valid templates. The list may be
        shorter (or empty) when the model returns fewer usable candidates.

    Raises:
        RuntimeError: If the ``openai`` package is not available, or if the
            API reply carries no parsed payload (for example a refusal).
    """
    # Nothing to generate: skip the network round trip entirely.
    if num_extra_templates <= 0:
        return []

    # Fail loudly instead of implicitly returning None, which would break the
    # declared List[str] contract and surface later as an unrelated TypeError.
    # NOTE(review): E084 is the code this module used for the same missing
    # `openai` check before the logic moved here — confirm the message fits.
    if not try_import_lib("openai"):
        raise RuntimeError(Errors.E084)

    import openai
    from pydantic import BaseModel, validator

    # Matches one `{name}` placeholder and captures the text between braces.
    placeholder_pattern = re.compile(r"{([^{}]*)}")

    def placeholder_names(text: str) -> set:
        """Return the whitespace-trimmed placeholder names used in ``text``."""
        return {name.strip() for name in placeholder_pattern.findall(text)}

    class Templates(BaseModel):
        """Schema the structured-output parser fills from the model reply."""

        templates: List[str]

        @validator("templates", each_item=True)
        def check_templates(cls, v: str):
            # Reject empty entries and drop the quotes the prompt asks for.
            if not v:
                raise ValueError("No templates generated.")
            return v.strip('"')

        def remove_invalid_templates(self, original_template):
            """Keep only templates using exactly the original's variables."""
            original_vars = placeholder_names(original_template)
            self.templates = [
                candidate
                for candidate in self.templates
                if placeholder_names(candidate) == original_vars
            ]

    client = openai.OpenAI()

    prompt = (
        f"Based on the provided template, create {num_extra_templates} new and unique templates that are "
        "variations on this theme. Present these as a list, with each template as a quoted string. The list should "
        "contain only the templates, without any additional text or explanation. Ensure that the structure of "
        "these variables remains consistent in each generated template. Note: don't add any extra variables and ignore typo errors.\n\n"
        "Template:\n"
        f"{template}\n"
    )
    response = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": f"Action: Generate up to {num_extra_templates} templates and ensure that the structure of the variables within the templates remains unchanged and don't add any extra variables.",
            },
            {"role": "user", "content": prompt},
        ],
        max_tokens=500,
        temperature=0,
        response_format=Templates,
    )

    generated_response = response.choices[0].message.parsed
    # `parsed` is empty when the reply could not be mapped onto `Templates`;
    # raise a descriptive error rather than an AttributeError on None.
    if generated_response is None:
        raise RuntimeError(
            "Template generation returned no parsable response from the model."
        )

    generated_response.remove_invalid_templates(template)

    return generated_response.templates[:num_extra_templates]
1 change: 1 addition & 0 deletions langtest/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ class Errors(metaclass=ErrorsWithCodes):
E093 = ("Category cannot be None. Please provide a valid category.")
E094 = ("Unsupported category: '{category}'. Supported categories: {supported_category}")
E095 = ("Failed to make API request: {e}")
E096 = ("Failed to generate the templates in Augmentation: {e}")


class ColumnNameError(Exception):
Expand Down