A bunch of improvements for the classification skill #50

Draft · wants to merge 7 commits into master
30 changes: 25 additions & 5 deletions adala/agents/base.py
@@ -259,10 +259,18 @@ def learn(
runtime = self.get_runtime(runtime=runtime)
teacher_runtime = self.get_teacher_runtime(runtime=teacher_runtime)

for iteration in range(learning_iterations):
print_text(
f"\n\n=> Iteration #{iteration}: Getting feedback, analyzing and improving ..."
)
# We add 1 to the number of iterations to allow for evaluation of the accuracy of the skill we are training
# at the end of the last iteration. I.e. if we have 3 iterations, we will train 3 times and evaluate the
# accuracy of the skill we are training 4 times.
for iteration in range(learning_iterations+1):
if iteration == learning_iterations:
print_text(
f"\n\n=> Final evaluation of the improved skills ..."
)
else:
print_text(
f"\n\n=> Iteration #{iteration}: Getting feedback, analyzing and improving ..."
)

inputs = self.environment.get_data_batch(batch_size=batch_size)
predictions = self.skills.apply(inputs, runtime=runtime)
@@ -285,12 +293,24 @@ def learn(
first_skill_with_errors = skill_mismatch.any(axis=0).idxmax()

accuracy = feedback.get_accuracy()

# End the loop with evaluation of the accuracy of the skill we are training
if iteration == learning_iterations:
print_text("Reached maximum number of iterations, stopping ...")
break

# TODO: iterating over skill can be more complex, and we should take order into account
for skill_output, skill_name in self.skills.get_skill_outputs().items():
skill = self.skills[skill_name]
if skill.frozen:
continue

if accuracy[skill_output] >= accuracy_threshold:
print_text(
f'Output {skill_output} of skill "{skill_name}" is already accurate enough ({accuracy[skill_output]}), skipping ...'
)
continue

print_text(
f'Skill output to improve: "{skill_output}" (Skill="{skill_name}")\n'
f"Accuracy = {accuracy[skill_output] * 100:0.2f}%",
@@ -310,4 +330,4 @@ def learn(
if skill_name == first_skill_with_errors:
break

print_text("Train is done!")
print_text("Training is done!")
4 changes: 2 additions & 2 deletions adala/environments/base.py
@@ -207,9 +207,9 @@ def get_feedback(
[gt_pred_match.rename("match"), gt], axis=1
)
pred_feedback[pred_column] = match_concat.apply(
lambda row: "Prediction is correct."
lambda row: f"Prediction for {gt_column} is correct."
Contributor:

Could you explain the reason for this addition? The initial idea was to use a single column at a time, so pointing out a specific column name might not be necessary - but I'm probably missing your idea.

Contributor Author:

Classification (and Transform) skills support multiple outputs and I used this to classify each social media post into multiple categories (each one as a True/False field).

You're correct that at each step we evaluate only one output. But the base prompt doesn't mention which of the outputs is being evaluated at that step. This patch is the easiest way I found to make sure the model understands that the feedback relates to a specific output.

If there is only one output, we can simply say "Prediction is correct." as it used to be.

Here is the template from TransformSkill.improve(). Note that it doesn't mention anything about the output name:

"""
## Current prompt
{self.instructions}

## Examples
{examples}

Summarize your analysis about incorrect predictions and suggest changes to the prompt."""
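For illustration, a small self-contained sketch of what the per-output feedback strings from this diff end up looking like (the column name and data below are made up, not taken from Adala or its tests):

```python
import numpy as np
import pandas as pd

gt_column = "sentiment"  # hypothetical ground-truth column for one skill output
df = pd.DataFrame({
    "match": [True, False, np.nan],
    gt_column: ["positive", "negative", "neutral"],
})

# Mirror the lambda in the diff: correct / incorrect with the column name spelled out.
feedback = df.apply(
    lambda row: np.nan
    if pd.isna(row["match"])
    else f"Prediction for {gt_column} is correct."
    if row["match"]
    else f'Prediction for {gt_column} is incorrect. Correct answer: "{row[gt_column]}"',
    axis=1,
)
print(feedback.tolist())
```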

Contributor:

I got your idea about referencing specific columns to help the LLM make the correct assessment. However, the column name used there doesn't carry any signal for the LLM; here is an example from the tests:

Prediction for gt_0 is incorrect. Correct answer:     0 0 0   1 1 1   1 5 1   
  "1 1 1" 

"gt_0" keyword is not presented in input prompt which consists of the string "Input: ... Output: ...". In this case, I'd better create a string like
"Prediction for the field "Output" is incorrect"
assuming there can be multiple outputs.
Let me know if it makes sense.
Happy to merge your PR as soon as we have all tests passed. Thank you.

Contributor Author:

Makes sense.
I just need to figure out how to get the field name 🤔

if row["match"]
else f'Prediction is incorrect. Correct answer: "{row[gt_column]}"'
else f'Prediction for {gt_column} is incorrect. Correct answer: "{row[gt_column]}"'
if not pd.isna(row["match"])
else np.nan,
axis=1,
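Following up on the field-name question in the thread above, a hypothetical helper (not part of this PR) could map a ground-truth variable back to its label in the output template. Everything below is an assumption for illustration, not Adala code:

```python
import re


def field_label(output_template: str, variable: str) -> str:
    """Recover a human-readable label (e.g. "Output") for a template variable."""
    # Look for "<label>: {variable}" in the template; fall back to the raw name.
    match = re.search(rf"(\w[\w ]*?)\s*:\s*\{{{variable}\}}", output_template)
    return match.group(1) if match else variable


print(field_label("Input: {text} Output: {gt_0}", "gt_0"))  # -> Output
```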
10 changes: 6 additions & 4 deletions adala/skills/_base.py
@@ -206,7 +206,7 @@ def improve(
f"### Example #{i}\n\n"
f"{self.input_template.format(**row)}\n\n"
f"{self.output_template.format(**row)}\n\n"
f'User feedback: {row[f"{train_skill_output}__fb"]}\n\n'
f'User feedback for {train_skill_output}: {row[f"{train_skill_output}__fb"]}\n\n'
)

examples = "\n".join(examples)
@@ -298,10 +298,12 @@ def improve(

2. The new prompt should be similar to the current instruction, and only differ in the parts that address the issues you identified in Step 1.
Example:
- Current prompt: "The model should generate a summary of the input text."
- New prompt: "The model should generate a summary of the input text. Pay attention to the original style."
- Current prompt: "Generate a summary of the input text."
- New prompt: "Generate a summary of the input text. Pay attention to the original style."

3. Reply only with the new prompt. Do not include input and output templates in the prompt.""",
3. Do not rephrase or change parts of the prompt that are not related to the issues identified in Step 1.

4. Reply only with the new prompt. Do not include input and output templates in the prompt.""",
},
]

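For reference, a rough self-contained illustration of how a single training example is rendered with the new output-specific feedback line; the templates, field names, and feedback text below are invented for the example:

```python
# Made-up templates and row; only the rendering pattern matches the diff above.
input_template = "Input: {text}"
output_template = "Output: {sentiment}"
train_skill_output = "sentiment"

row = {
    "text": "Great product, would buy again!",
    "sentiment": "negative",
    "sentiment__fb": 'Prediction for sentiment is incorrect. Correct answer: "positive"',
}

feedback = row[f"{train_skill_output}__fb"]
example = (
    "### Example #0\n\n"
    f"{input_template.format(**row)}\n\n"
    f"{output_template.format(**row)}\n\n"
    f"User feedback for {train_skill_output}: {feedback}\n\n"
)
print(example)
```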
16 changes: 9 additions & 7 deletions adala/utils/logs.py
@@ -34,21 +34,23 @@ def print_error(text: str):
error_console.print(text)


def print_dataframe(dataframe: InternalDataFrame):
def print_dataframe(dataframe: InternalDataFrame, num_rows: int = 5, print_index: bool = False):
"""
Print dataframe to console.
"""
num_rows = 5
table = Table(show_header=True, header_style="bold magenta")
# index_name = dataframe.index.name or 'index'
# table.add_column(index_name)
if print_index:
index_name = dataframe.index.name or 'index'
table.add_column(index_name)

for column in dataframe.columns:
table.add_column(str(column))

for index, value_list in enumerate(dataframe.iloc[:num_rows].values.tolist()):
# row = [str(index)]
row = []
if print_index:
row = [str(dataframe.index[index])]
else:
row = []
row += [str(x) for x in value_list]
table.add_row(*row)
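A possible usage of the updated print_dataframe signature, assuming InternalDataFrame behaves like a pandas DataFrame (so a plain DataFrame stands in for it here):

```python
import pandas as pd
from adala.utils.logs import print_dataframe

df = pd.DataFrame(
    {"text": ["good", "bad"], "label": ["positive", "negative"]},
    index=["id-1", "id-2"],
)
print_dataframe(df, num_rows=2, print_index=True)  # now renders the index column too
```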

@@ -100,5 +102,5 @@ def highlight_differences(text1, text2):
if i[0] != "-"
]
)
highlighted = highlighted.replace(" \n ", "<br>")
highlighted = highlighted.replace("\n", "<br/>")
display(HTML(highlighted))
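A quick check of the behaviour change: every newline, not only the padded " \n " sequence, is now turned into an HTML line break before display:

```python
# Minimal reproduction of the new replacement logic on a sample string.
text = "first line\nsecond line \n third line"
print(text.replace("\n", "<br/>"))
# first line<br/>second line <br/> third line
```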