In [1]:
import pandas as pd
import numpy as np
import datasets as ds
import copy
import dotenv
import json
import os
import warnings
from langchain_openai import ChatOpenAI

from agents.core.reflection.agent import ReflexionCoTAgent, SelfReflectionCoTAgent

from agents.core.prompts.math_number_theory import (
    MATH_NUMBERTHEORY_FEWSHOT_INSTRUCTION_AGENT
)

from agents.core.reflection.prompts.general import (
    ZEROSHOT_REFLECT_INSTRUCTION_NEGATIVE
)

from agents.core.fewshots.math_number_theory import (
    MATH_NUMBERTHEORY_FEWSHOT_EXAMPLES_AGENT_COT
)

from agents.core.reflection.reflect import (
    ReflexionCoTReflector
)

from experiments.utils.logging import log_response

# --- Experiment configuration -------------------------------------------
# Number of problems to sample across all difficulty levels (see downsample_datasets).
total_rows = 200
# Label for this prompt/example variant; it is embedded in the stats rows and
# in the log / results file names written by the experiment loop.
prompt_examples = "negative_prompt_negative_examples"

# Silence all Python warnings, then load variables from a local .env file into
# the environment (presumably OPENAI_API_KEY for the ChatOpenAI clients -- confirm).
warnings.filterwarnings('ignore')
dotenv.load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_algebra_dataset(file_path):
    """Load a MATH-style JSON problem file and split it by difficulty level.

    Despite the name, nothing here is algebra-specific: any JSON array of
    objects with ``problem`` / ``level`` / ``type`` / ``solution`` keys works
    (this notebook uses it for the number-theory test split).

    Args:
        file_path: Path to a UTF-8 encoded JSON file holding a list of records.

    Returns:
        A dict mapping each ``level`` value to a ``datasets.Dataset`` with that
        level's rows (fields missing from a record default to ``''``), or
        ``None`` if the file is missing, is not valid UTF-8, or is not valid JSON.
    """
    try:
        # JSON is UTF-8 by specification; be explicit so the platform default
        # encoding (e.g. cp1252 on Windows) cannot garble or reject LaTeX text.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: File not found at path: {file_path}")
        return None
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Error: Invalid JSON format in file: {file_path}")
        return None

    columns = ['problem', 'level', 'type', 'solution']
    df = pd.DataFrame(
        [{col: item.get(col, '') for col in columns} for item in data],
        columns=columns,
    )

    # One Dataset per level. groupby sorts the level keys and, by default,
    # drops records whose level is null.
    return {
        level: ds.Dataset.from_pandas(group.reset_index(drop=True))
        for level, group in df.groupby('level')
    }

def downsample_datasets(ds_dict, target_total_rows=300, seed=40):
    """
    Downsample a dict of datasets while keeping each group's share of rows.

    Each group is shuffled with a fixed seed and its first
    ``round(group_rows * target_total_rows / all_rows)`` rows are kept, so the
    selection is reproducible. Because every group is rounded independently,
    the resulting total can differ from ``target_total_rows`` by a few rows.

    Args:
        ds_dict: Mapping of group name -> ``datasets.Dataset``.
        target_total_rows: Desired total number of rows across all groups. If it
            exceeds the rows available, each group is returned in full
            (shuffled) instead of failing with an out-of-range selection.
        seed: Shuffle seed. The default (40) is the value previously hard-coded
            here, so existing samples are unchanged.

    Returns:
        A new dict with the same keys as ``ds_dict`` holding the downsampled
        datasets. An empty (or all-empty) input is returned as a shallow copy.
    """
    # Named to avoid shadowing the notebook-level ``total_rows`` setting.
    available_rows = sum(dataset.num_rows for dataset in ds_dict.values())
    if available_rows == 0:
        # Nothing to sample; also avoids a ZeroDivisionError below.
        return dict(ds_dict)

    scaling_factor = target_total_rows / available_rows

    new_ds_dict = {}
    for group_name, dataset in ds_dict.items():
        # Clamp: a dataset cannot select more rows than it contains.
        new_num_rows = min(dataset.num_rows, round(dataset.num_rows * scaling_factor))
        new_ds_dict[group_name] = dataset.shuffle(seed=seed).select(range(new_num_rows))

    return new_ds_dict

def create_stats(downsampled_ds_dict, prompt_examples):
    """
    Build an all-zero statistics table with one block of counter rows per level.

    Rows are keyed by (prompt_examples, method, transition, result_type):
      * 'cot_1' gets plain 'correct' / 'incorrect' rows, with NA
        prompt_examples and NA transition;
      * each reflection method with n trials gets 'correct_state_k' then
        'incorrect_state_k' rows for k = 1..n (NA transition), followed by four
        'start_to_final' transition rows.

    Arguments:
    - downsampled_ds_dict: Mapping whose keys are the levels (values are unused).
    - prompt_examples: Label stored in 'prompt_examples' on every reflection row.

    Returns:
    - A flat DataFrame with columns prompt_examples, method, transition,
      result_type, value (all 0) and level, with levels stacked in dict order.
    """
    reflection_methods = (
        ('selfreflection_cot_2', 2),
        ('selfreflection_cot_3', 3),
        ('reflexion_cot_2', 2),
        ('reflexion_cot_3', 3),
    )
    transition_outcomes = (
        'correct_to_incorrect',
        'incorrect_to_correct',
        'incorrect_to_incorrect',
        'correct_to_correct',
    )

    row_keys = [(pd.NA, 'cot_1', pd.NA, outcome) for outcome in ('correct', 'incorrect')]
    for method_name, n_trials in reflection_methods:
        # Per-trial state counters: all 'correct_*' first, then all 'incorrect_*'.
        row_keys.extend(
            (prompt_examples, method_name, pd.NA, f'{outcome}_state_{trial}')
            for outcome in ('correct', 'incorrect')
            for trial in range(1, n_trials + 1)
        )
        # First-trial -> final-trial transition counters.
        row_keys.extend(
            (prompt_examples, method_name, 'start_to_final', label)
            for label in transition_outcomes
        )

    template = pd.DataFrame(
        0,
        index=pd.MultiIndex.from_tuples(
            row_keys,
            names=['prompt_examples', 'method', 'transition', 'result_type'],
        ),
        columns=['value'],
    )

    # One copy of the template per level, tagged with that level.
    per_level_frames = [template.assign(level=level) for level in downsampled_ds_dict]
    return pd.concat(per_level_frames).reset_index()

def save_stats(new_stats, stats_filename='stats.csv'):
    """
    Overwrites the stats file with new stats, creating its directory if needed.

    This is deliberately best-effort: failures are reported with ``print``
    rather than raised, so a long experiment loop that checkpoints after every
    row keeps running.

    Arguments:
    - new_stats (pd.DataFrame): The new stats to write to the file.
    - stats_filename (str): The filename of the stats file.
    """
    try:
        # to_csv does not create missing parent folders (e.g. './results/').
        parent_dir = os.path.dirname(stats_filename)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Overwrite the existing file or create a new one
        new_stats.to_csv(stats_filename, index=False)
        print(f"Stats successfully saved to {stats_filename}.")

    except PermissionError:
        print(f"Permission denied: Unable to write to {stats_filename}.")
        print("Please ensure the file is not open in another application and you have write permissions.")

    except Exception as e:
        print(f"An error occurred while saving stats: {e}")

def get_general_stats(stats_df):
    """
    Aggregate per-level stats into totals across all levels.

    Arguments:
    - stats_df (pd.DataFrame): Table shaped like the output of ``create_stats``,
      with columns prompt_examples, method, transition, result_type, value, level.

    Returns:
    - pd.DataFrame with one row per (prompt_examples, method, transition,
      result_type) group and the summed 'value'.
    """
    # dropna=False is essential. 'cot_1' rows have NA prompt_examples and every
    # per-trial state row has NA transition. With the default dropna=True,
    # pandas silently discards those groups and leaves only transition rows.
    general_stats_df = stats_df.groupby(
        ['prompt_examples', 'method', 'transition', 'result_type'],
        dropna=False,
    ).agg({
        'value': 'sum'  # Sum up the values for each group
    }).reset_index()

    return general_stats_df

def update_stats(stats_df, level, prompt_examples, method, trial_num=None, state=None):
    """
    Increment, in place, the counter row that matches one observed outcome.

    Arguments:
    - stats_df: Statistics table (as built by ``create_stats``); its 'value'
      column is mutated.
    - level: Level whose block of rows should be updated.
    - prompt_examples: Prompt/example label. It is only matched for
      multi-trial methods; the 'cot_1' rows are matched without it.
    - method: Method name, e.g. 'cot_1' or 'selfreflection_cot_3'.
    - trial_num: 1-based trial number; not used for 'cot_1'.
    - state: 'correct' or 'incorrect'.

    Raises:
    - ValueError: if 'cot_1' gets a state other than 'correct'/'incorrect', or if
      a multi-trial method is missing trial_num or state.

    If no row matches, a warning is printed and nothing is changed.
    """
    is_single_trial = method == 'cot_1'

    # Validate the arguments and derive the result_type label first.
    if is_single_trial:
        if state != 'correct' and state != 'incorrect':
            raise ValueError("State must be 'correct' or 'incorrect' for 'cot_1'")
        result_type = state
    elif trial_num is None or state is None:
        raise ValueError("Both trial_num and state must be provided for multi-trial methods")
    else:
        result_type = f'{state}_state_{trial_num}'

    row_mask = (stats_df['level'] == level) & (stats_df['method'] == method)
    if not is_single_trial:
        # Per-trial state counters are the rows without a transition label.
        row_mask = (
            row_mask
            & (stats_df['prompt_examples'] == prompt_examples)
            & stats_df['transition'].isna()
        )
    row_mask = row_mask & (stats_df['result_type'] == result_type)

    if row_mask.any():
        stats_df.loc[row_mask, 'value'] += 1
    else:
        print(f"Warning: No matching row found for method '{method}', result_type '{result_type}'.")

def determine_transition(prev_correct, current_correct):
    """
    Name the transition between two correctness outcomes.

    Arguments:
    - prev_correct: Truthy if the earlier trial was correct.
    - current_correct: Truthy if the later trial was correct.

    Returns:
    - 'correct_to_correct', 'correct_to_incorrect', 'incorrect_to_correct'
      or 'incorrect_to_incorrect'.
    """
    def _outcome(flag):
        return 'correct' if flag else 'incorrect'

    return f"{_outcome(prev_correct)}_to_{_outcome(current_correct)}"

def process_transitions(stats_df, output_list, level, prompt_examples, method, transition_labels, max_trials):
    """
    Record per-trial states and first-to-final transitions for one agent run.

    If the run produced fewer than ``max_trials`` outputs (e.g. it stopped
    early), the missing trials are padded with the last observed state. If
    there were no outputs at all, they are padded with 'incorrect'. Every
    resulting state is counted through ``update_stats``. Then, for each label
    in ``transition_labels`` equal to 'start_to_final', the matching
    first-vs-final transition row is incremented. Other labels are ignored.

    Arguments:
    - stats_df: The DataFrame storing all statistics (mutated in place).
    - output_list: Sized sequence of outputs exposing a truthy/falsy ``is_correct``.
    - level: Level whose rows are updated.
    - prompt_examples: The specific prompt examples used.
    - method: The method being used (e.g., 'selfreflection_cot_3').
    - transition_labels: The transitions to process (e.g., ['start_to_final']).
    - max_trials: Number of trial states to record for the method (2 or 3).
    """
    shortfall = max_trials - len(output_list)
    states = ['correct' if output.is_correct else 'incorrect' for output in output_list]

    if shortfall > 0:
        filler = states[-1] if states else 'incorrect'
        states.extend([filler] * shortfall)

    for trial_number, trial_state in enumerate(states, start=1):
        update_stats(
            stats_df,
            level,
            prompt_examples,
            method,
            trial_num=trial_number,
            state=trial_state,
        )

    for label in transition_labels:
        if label != 'start_to_final' or len(states) < 2:
            continue

        transition = determine_transition(states[0] == 'correct', states[-1] == 'correct')
        transition_mask = (
            (stats_df['level'] == level)
            & (stats_df['prompt_examples'] == prompt_examples)
            & (stats_df['method'] == method)
            & (stats_df['transition'] == 'start_to_final')
            & (stats_df['result_type'] == transition)
        )
        if stats_df.loc[transition_mask, 'value'].empty:
            print(f"Warning: No matching transition row found for method '{method}', transition '{transition}'.")
        else:
            stats_df.loc[transition_mask, 'value'] += 1


In [3]:
# One client per experimental arm: llm_1 for plain CoT, llm_2/llm_3 for
# self-reflection and llm_4/llm_5 for Reflexion (2 and 3 trials each).
# No key is passed explicitly, so ChatOpenAI reads OPENAI_API_KEY from the
# environment (presumably populated by dotenv.load_dotenv() from .env -- confirm).
# Never hard-code credentials in a notebook, not even an empty placeholder.
llm_1, llm_2, llm_3, llm_4, llm_5 = (
    ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0) for _ in range(5)
)

In [4]:
DATA_PATH = '../../number_theory_test.json'

ds_dict = load_algebra_dataset(DATA_PATH)
if ds_dict is None:
    # load_algebra_dataset only prints on failure and returns None; stop here
    # with a clear message rather than an AttributeError on None further down.
    raise RuntimeError(f"Could not load dataset from {DATA_PATH}")

downsampled_ds_dict = downsample_datasets(ds_dict, total_rows)

new_total_rows = sum(dataset.num_rows for dataset in downsampled_ds_dict.values())
# Each level is rounded independently, so this can differ slightly from total_rows.
print("New total number of rows:", new_total_rows)

stats_df = create_stats(downsampled_ds_dict, prompt_examples)

New total number of rows: 200


In [5]:
# Each problem is run through five arms: plain CoT, then self-reflection and
# Reflexion with 2 and 3 trials each. Every arm appends to its own log in
# ./logs and updates stats_df. Stats are checkpointed to ./results after every
# problem, so an interrupted run keeps its progress.
os.makedirs("./logs", exist_ok=True)
os.makedirs("./results", exist_ok=True)

LOG_RULE = "---------------------------------"


def _append_log(log_path, title, question, key, output):
    """Append a QUESTION/KEY header titled `title`, then `output`, to the log at `log_path`."""
    # utf-8 for every log: problems and solutions contain LaTeX that the
    # platform default encoding may not be able to represent.
    with open(log_path, "a", encoding="utf-8") as log_file:
        log_file.write(
            f"\n{LOG_RULE}\n"
            f"QUESTION: {question}\n"
            f"KEY: {key}\n"
            f"{LOG_RULE}\n"
            f"{title}\n"
            f"{LOG_RULE}\n"
        )
        log_response(log_file, output)


def _run_reflection_arm(agent_cls, llm, method, n_trials, question, key, examples, level):
    """Run one reflection agent for up to `n_trials` trials, log it, and update stats_df.

    Returns the agent's output list.
    """
    agent = agent_cls(
        llm=llm,
        # NOTE(review): number-theory problems run under the "math_algebra"
        # benchmark name -- confirm this is intended.
        benchmark="math_algebra",
        reflector=ReflexionCoTReflector(llm=llm),
        max_reflections=n_trials,
        max_trials=n_trials,
    )
    output = agent.generate(
        question=question,
        key=key,
        examples=examples,
        prompt=MATH_NUMBERTHEORY_FEWSHOT_INSTRUCTION_AGENT,
        reflect_examples=None,
        reflect_prompt=ZEROSHOT_REFLECT_INSTRUCTION_NEGATIVE,
        reflect_strategy="reflexion",
        additional_keys={},
        reflect_additional_keys={},
        max_trials=n_trials,
        patience=n_trials,
        reset=True,
    )
    _append_log(f"./logs/{prompt_examples}_{method}.txt", "REFLECTION COT OUTPUT", question, key, output)
    # Pads early-stopped runs up to n_trials and records the start_to_final transition.
    process_transitions(stats_df, output, level, prompt_examples, method, ['start_to_final'], n_trials)
    return output


counter = 0
for group_name, downsampled_ds in downsampled_ds_dict.items():
    level = group_name

    for i, row in enumerate(downsampled_ds):
        counter += 1
        print(f"Processing row {i+1} of {len(downsampled_ds)} for level {level} ({counter} out of {total_rows} in total)...")

        # NOTE(review): the problem text and the instruction are concatenated
        # with no separator, and the Finish[...] example has an unbalanced quote.
        # Kept verbatim so results stay comparable with earlier runs.
        question = "Question: "+ row['problem'] + "Your response should provide the final answer with  'Action: Finish['\\boxed{latex_answer}], and follow the same solution format as shown above."
        key = row['solution']

        examples = MATH_NUMBERTHEORY_FEWSHOT_EXAMPLES_AGENT_COT[level]

        # --- CoT baseline: a single trial with no reflection ---
        agent_1 = SelfReflectionCoTAgent(
            llm=llm_1,
            benchmark="math_algebra",
            reflector=ReflexionCoTReflector(llm=llm_1),
            max_reflections=1,
            max_trials=1
        )
        out_cot_1 = agent_1.generate(
            question=question,
            key=key,
            prompt=MATH_NUMBERTHEORY_FEWSHOT_INSTRUCTION_AGENT,
            examples=examples,
            reflect_strategy=None,
            additional_keys={},
            reflect_additional_keys={},
            max_trials=1,
            patience=1,
            reset=True
        )
        _append_log(f"./logs/{prompt_examples}_cot_1.txt", "SIMPLE COT OUTPUT", question, key, out_cot_1)

        # An empty output list counts as incorrect.
        result_type_cot_1 = 'correct' if out_cot_1 and out_cot_1[0].is_correct else 'incorrect'
        update_stats(
            stats_df,
            level,
            prompt_examples,
            'cot_1',
            trial_num=None,  # Single-trial method
            state=result_type_cot_1
        )

        # --- Reflection arms (2 and 3 trials each) ---
        out_cot_selfreflect_2 = _run_reflection_arm(
            SelfReflectionCoTAgent, llm_2, 'selfreflection_cot_2', 2, question, key, examples, level
        )
        out_cot_selfreflect_3 = _run_reflection_arm(
            SelfReflectionCoTAgent, llm_3, 'selfreflection_cot_3', 3, question, key, examples, level
        )
        out_cot_reflect_2 = _run_reflection_arm(
            ReflexionCoTAgent, llm_4, 'reflexion_cot_2', 2, question, key, examples, level
        )
        out_cot_reflect_3 = _run_reflection_arm(
            ReflexionCoTAgent, llm_5, 'reflexion_cot_3', 3, question, key, examples, level
        )

        print(f"Processing row {i+1} of {len(downsampled_ds)} for group {group_name} ({counter} out of {total_rows} in total)... DONE")
        save_stats(stats_df, stats_filename=f'./results/{prompt_examples}_stats.csv')


Processing row 1 of 11 for level Level 1 (1 out of 200 in total)...


  out = llm(


Processing row 1 of 11 for group Level 1 (1 out of 200 in total)... DONE
Stats successfully saved to ./results/negative_prompt_negative_examples_stats.csv.
Processing row 2 of 11 for level Level 1 (2 out of 200 in total)...
Processing row 2 of 11 for group Level 1 (2 out of 200 in total)... DONE
Stats successfully saved to ./results/negative_prompt_negative_examples_stats.csv.
Processing row 3 of 11 for level Level 1 (3 out of 200 in total)...
Processing row 3 of 11 for group Level 1 (3 out of 200 in total)... DONE
Stats successfully saved to ./results/negative_prompt_negative_examples_stats.csv.
Processing row 4 of 11 for level Level 1 (4 out of 200 in total)...
Processing row 4 of 11 for group Level 1 (4 out of 200 in total)... DONE
Stats successfully saved to ./results/negative_prompt_negative_examples_stats.csv.
Processing row 5 of 11 for level Level 1 (5 out of 200 in total)...
Processing row 5 of 11 for group Level 1 (5 out of 200 in total)... DONE
Stats successfully saved to ./r