diff --git a/adala/agents/base.py b/adala/agents/base.py
index f4efff3f..f634d60d 100644
--- a/adala/agents/base.py
+++ b/adala/agents/base.py
@@ -259,10 +259,18 @@ def learn(
runtime = self.get_runtime(runtime=runtime)
teacher_runtime = self.get_teacher_runtime(runtime=teacher_runtime)
- for iteration in range(learning_iterations):
- print_text(
- f"\n\n=> Iteration #{iteration}: Getting feedback, analyzing and improving ..."
- )
+ # We add 1 to the number of iterations to allow for evaluation of the accuracy of the skill we are training
+ # at the end of the last iteration. I.e. if we have 3 iterations, we will train 3 times and evaluate the
+ # accuracy of the skill we are training 4 times.
+ for iteration in range(learning_iterations+1):
+ if iteration == learning_iterations:
+ print_text(
+ f"\n\n=> Final evaluation of the improved skills ..."
+ )
+ else:
+ print_text(
+ f"\n\n=> Iteration #{iteration}: Getting feedback, analyzing and improving ..."
+ )
inputs = self.environment.get_data_batch(batch_size=batch_size)
predictions = self.skills.apply(inputs, runtime=runtime)
@@ -285,12 +293,24 @@ def learn(
first_skill_with_errors = skill_mismatch.any(axis=0).idxmax()
accuracy = feedback.get_accuracy()
+
+ # End the loop with evaluation of the accuracy of the skill we are training
+ if iteration == learning_iterations:
+ print_text("Reached maximum number of iterations, stopping ...")
+ break
+
# TODO: iterating over skill can be more complex, and we should take order into account
for skill_output, skill_name in self.skills.get_skill_outputs().items():
skill = self.skills[skill_name]
if skill.frozen:
continue
+ if accuracy[skill_output] >= accuracy_threshold:
+ print_text(
+ f'Output {skill_output} of skill "{skill_name}" is already accurate enough ({accuracy[skill_output]}), skipping ...'
+ )
+ continue
+
print_text(
f'Skill output to improve: "{skill_output}" (Skill="{skill_name}")\n'
f"Accuracy = {accuracy[skill_output] * 100:0.2f}%",
@@ -310,4 +330,4 @@ def learn(
if skill_name == first_skill_with_errors:
break
- print_text("Train is done!")
+ print_text("Training is done!")
diff --git a/adala/environments/base.py b/adala/environments/base.py
index fb9d3220..0d4c972b 100644
--- a/adala/environments/base.py
+++ b/adala/environments/base.py
@@ -207,9 +207,9 @@ def get_feedback(
[gt_pred_match.rename("match"), gt], axis=1
)
pred_feedback[pred_column] = match_concat.apply(
- lambda row: "Prediction is correct."
+ lambda row: f"Prediction for {gt_column} is correct."
if row["match"]
- else f'Prediction is incorrect. Correct answer: "{row[gt_column]}"'
+ else f'Prediction for {gt_column} is incorrect. Correct answer: "{row[gt_column]}"'
if not pd.isna(row["match"])
else np.nan,
axis=1,
diff --git a/adala/skills/_base.py b/adala/skills/_base.py
index e6e29466..8f034481 100644
--- a/adala/skills/_base.py
+++ b/adala/skills/_base.py
@@ -206,7 +206,7 @@ def improve(
f"### Example #{i}\n\n"
f"{self.input_template.format(**row)}\n\n"
f"{self.output_template.format(**row)}\n\n"
- f'User feedback: {row[f"{train_skill_output}__fb"]}\n\n'
+ f'User feedback for {train_skill_output}: {row[f"{train_skill_output}__fb"]}\n\n'
)
examples = "\n".join(examples)
@@ -298,10 +298,12 @@ def improve(
2. The new prompt should be similar to the current instruction, and only differ in the parts that address the issues you identified in Step 1.
Example:
- - Current prompt: "The model should generate a summary of the input text."
- - New prompt: "The model should generate a summary of the input text. Pay attention to the original style."
+ - Current prompt: "Generate a summary of the input text."
+ - New prompt: "Generate a summary of the input text. Pay attention to the original style."
-3. Reply only with the new prompt. Do not include input and output templates in the prompt.""",
+3. Do not rephrase or change parts of the prompt that are not related to the issues identified in Step 1.
+
+4. Reply only with the new prompt. Do not include input and output templates in the prompt.""",
},
]
diff --git a/adala/utils/logs.py b/adala/utils/logs.py
index 33e3a2d3..8db54b6d 100644
--- a/adala/utils/logs.py
+++ b/adala/utils/logs.py
@@ -34,21 +34,23 @@ def print_error(text: str):
error_console.print(text)
-def print_dataframe(dataframe: InternalDataFrame):
+def print_dataframe(dataframe: InternalDataFrame, num_rows: int = 5, print_index: bool = False):
"""
Print dataframe to console.
"""
- num_rows = 5
table = Table(show_header=True, header_style="bold magenta")
- # index_name = dataframe.index.name or 'index'
- # table.add_column(index_name)
+ if print_index:
+ index_name = dataframe.index.name or 'index'
+ table.add_column(index_name)
for column in dataframe.columns:
table.add_column(str(column))
for index, value_list in enumerate(dataframe.iloc[:num_rows].values.tolist()):
- # row = [str(index)]
- row = []
+ if print_index:
+ row = [str(dataframe.index[index])]
+ else:
+ row = []
row += [str(x) for x in value_list]
table.add_row(*row)
@@ -100,5 +102,5 @@ def highlight_differences(text1, text2):
if i[0] != "-"
]
)
- highlighted = highlighted.replace(" \n ", "
")
+ highlighted = highlighted.replace("\n", "
")
display(HTML(highlighted))