diff --git a/adala/agents/base.py b/adala/agents/base.py index f4efff3f..f634d60d 100644 --- a/adala/agents/base.py +++ b/adala/agents/base.py @@ -259,10 +259,18 @@ def learn( runtime = self.get_runtime(runtime=runtime) teacher_runtime = self.get_teacher_runtime(runtime=teacher_runtime) - for iteration in range(learning_iterations): - print_text( - f"\n\n=> Iteration #{iteration}: Getting feedback, analyzing and improving ..." - ) + # We add 1 to the number of iterations to allow for evaluation of the accuracy of the skill we are training + # at the end of the last iteration. I.e. if we have 3 iterations, we will train 3 times and evaluate the + # accuracy of the skill we are training 4 times. + for iteration in range(learning_iterations+1): + if iteration == learning_iterations: + print_text( + f"\n\n=> Final evaluation of the improved skills ..." + ) + else: + print_text( + f"\n\n=> Iteration #{iteration}: Getting feedback, analyzing and improving ..." + ) inputs = self.environment.get_data_batch(batch_size=batch_size) predictions = self.skills.apply(inputs, runtime=runtime) @@ -285,12 +293,24 @@ def learn( first_skill_with_errors = skill_mismatch.any(axis=0).idxmax() accuracy = feedback.get_accuracy() + + # End the loop with evaluation of the accuracy of the skill we are training + if iteration == learning_iterations: + print_text("Reached maximum number of iterations, stopping ...") + break + # TODO: iterating over skill can be more complex, and we should take order into account for skill_output, skill_name in self.skills.get_skill_outputs().items(): skill = self.skills[skill_name] if skill.frozen: continue + if accuracy[skill_output] >= accuracy_threshold: + print_text( + f'Output {skill_output} of skill "{skill_name}" is already accurate enough ({accuracy[skill_output]}), skipping ...' 
+ ) + continue + print_text( f'Skill output to improve: "{skill_output}" (Skill="{skill_name}")\n' f"Accuracy = {accuracy[skill_output] * 100:0.2f}%", @@ -310,4 +330,4 @@ def learn( if skill_name == first_skill_with_errors: break - print_text("Train is done!") + print_text("Training is done!") diff --git a/adala/environments/base.py b/adala/environments/base.py index fb9d3220..0d4c972b 100644 --- a/adala/environments/base.py +++ b/adala/environments/base.py @@ -207,9 +207,9 @@ def get_feedback( [gt_pred_match.rename("match"), gt], axis=1 ) pred_feedback[pred_column] = match_concat.apply( - lambda row: "Prediction is correct." + lambda row: f"Prediction for {gt_column} is correct." if row["match"] - else f'Prediction is incorrect. Correct answer: "{row[gt_column]}"' + else f'Prediction for {gt_column} is incorrect. Correct answer: "{row[gt_column]}"' if not pd.isna(row["match"]) else np.nan, axis=1, diff --git a/adala/skills/_base.py b/adala/skills/_base.py index e6e29466..8f034481 100644 --- a/adala/skills/_base.py +++ b/adala/skills/_base.py @@ -206,7 +206,7 @@ def improve( f"### Example #{i}\n\n" f"{self.input_template.format(**row)}\n\n" f"{self.output_template.format(**row)}\n\n" - f'User feedback: {row[f"{train_skill_output}__fb"]}\n\n' + f'User feedback for {train_skill_output}: {row[f"{train_skill_output}__fb"]}\n\n' ) examples = "\n".join(examples) @@ -298,10 +298,12 @@ def improve( 2. The new prompt should be similar to the current instruction, and only differ in the parts that address the issues you identified in Step 1. Example: - - Current prompt: "The model should generate a summary of the input text." - - New prompt: "The model should generate a summary of the input text. Pay attention to the original style." + - Current prompt: "Generate a summary of the input text." + - New prompt: "Generate a summary of the input text. Pay attention to the original style." -3. Reply only with the new prompt. 
Do not include input and output templates in the prompt.""", +3. Do not rephrase or change parts of the prompt that are not related to the issues identified in Step 1. + +4. Reply only with the new prompt. Do not include input and output templates in the prompt.""", }, ] diff --git a/adala/utils/logs.py b/adala/utils/logs.py index 33e3a2d3..8db54b6d 100644 --- a/adala/utils/logs.py +++ b/adala/utils/logs.py @@ -34,21 +34,23 @@ def print_error(text: str): error_console.print(text) -def print_dataframe(dataframe: InternalDataFrame): +def print_dataframe(dataframe: InternalDataFrame, num_rows: int = 5, print_index: bool = False): """ Print dataframe to console. """ - num_rows = 5 table = Table(show_header=True, header_style="bold magenta") - # index_name = dataframe.index.name or 'index' - # table.add_column(index_name) + if print_index: + index_name = dataframe.index.name or 'index' + table.add_column(index_name) for column in dataframe.columns: table.add_column(str(column)) for index, value_list in enumerate(dataframe.iloc[:num_rows].values.tolist()): - # row = [str(index)] - row = [] + if print_index: + row = [str(dataframe.index[index])] + else: + row = [] row += [str(x) for x in value_list] table.add_row(*row) @@ -100,5 +102,5 @@ def highlight_differences(text1, text2): if i[0] != "-" ] ) - highlighted = highlighted.replace(" \n ", "<br>") + highlighted = highlighted.replace("\n", "<br>") display(HTML(highlighted))