In [1]:
from discussion_agents.cog.agent.reflexion import ReflexionReActAgent
from langchain_community.chat_models import ChatOpenAI
from typing import List, Dict, Optional

import dotenv
import os

# Load environment variables from ../.env (path is relative to the working directory).
dotenv.load_dotenv("../.env")
# Read the OpenAI key from the environment; os.getenv returns None when the variable is unset.
openai_api_key = os.getenv("OPENAI_API_KEY")
# Chat model handle; a later cell passes it to the agent as both the action and the self-reflection LLM.
llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", openai_api_key=openai_api_key)

In [None]:
import joblib
# Load the sampled HotpotQA data (presumably a pandas DataFrame -- TODO confirm) and reset its index.
# NOTE(review): joblib.load unpickles the file, so only load artifacts from a trusted source.
hotpot = joblib.load('../agent/hotpot-qa-distractor-sample.joblib').reset_index(drop=True)
# Preview the first rows.
hotpot.head()

In [4]:
def gather_experience(
    reflexion_react_agent: ReflexionReActAgent,
    questions: List[str],
    keys: List[str],
    strategy: Optional[str] = "reflexion",
) -> Dict[str, List]:
    """Run the agent on every (question, key) pair and collect the results column-wise.

    Args:
        reflexion_react_agent: Agent whose ``generate`` method is called once
            per pair, always with ``reset=True``.
        questions: Questions to run.
        keys: Answer keys, paired positionally with ``questions`` via ``zip``
            (surplus items in the longer list are ignored).
        strategy: Forwarded unchanged to ``generate``.

    Returns:
        A dict of parallel lists under "idxs", "questions", "keys",
        "trajectories" and "reflections"; entry ``i`` of every list belongs to
        the ``i``-th pair.
    """
    columns = ("idxs", "questions", "keys", "trajectories", "reflections")
    experiences = {column: [] for column in columns}

    for position, pair in enumerate(zip(questions, keys)):
        question, key = pair
        trajectory = reflexion_react_agent.generate(
            question=question, key=key, strategy=strategy, reset=True
        )
        # The reflections are read off the agent's reflector after generate() has returned.
        reflections = reflexion_react_agent.reflector.reflections

        row = (position, question, key, trajectory, reflections)
        for column, value in zip(columns, row):
            experiences[column].append(value)

    return experiences

In [None]:
# NOTE(review): in this cell `k` is only read by the commented-out call at the bottom.
k = 5

# Reflexion ReAct agent; the same `llm` serves both the self-reflection and the action role.
agent = ReflexionReActAgent(
    self_reflect_llm=llm,
    action_llm=llm,
    max_steps=7,
    max_trials=3,
)

# Disabled: would gather experiences for the k questions starting at position 10 of `hotpot`.
# experiences_tmp = gather_experience(agent, questions=hotpot.question.values.tolist()[10:10+k], keys=hotpot.answer.values.tolist()[10:10+k])

In [None]:
# joblib.dump(experiences, "experiences_10.joblib")

In [None]:
import joblib
# Load previously saved experiences from disk.
# NOTE(review): joblib.load unpickles the file, so only load artifacts from a trusted source.
experiences = joblib.load("exp_15_compare_fake.joblib")

In [5]:

def categorize_experiences(experiences: Dict[str, List]) -> Dict[str, List]:
    """Bucket task indices into "compare", "success" and "fail".

    Every trajectory is a sequence of trials whose first element is the
    correctness flag, i.e. ``(is_correct, answer, output)``. A task is
    classified from those flags:

    - "success": exactly one trial, and it is correct.
    - "compare": more than one trial, and the last one is correct.
    - "fail": the last trial is incorrect.

    Args:
        experiences: Mapping with at least "idxs" and "trajectories". Each
            value in "idxs" is used as a position into "trajectories".

    Returns:
        A dict with the keys "compare", "success" and "fail", each holding the
        list of task indices that fell into that bucket.

    Raises:
        ValueError: If a trajectory contains no trials.
    """
    count_dict = {
        "compare": [],
        "success": [],
        "fail": []
    }

    for idx in experiences["idxs"]:  # Index for a particular task.
        trajectory = experiences["trajectories"][idx]
        trials_are_correct = [trial[0] for trial in trajectory]  # (is_correct, answer, output)[0]

        # A trajectory without trials cannot be classified; fail loudly with the
        # dedicated error rather than an IndexError from the [-1] lookup below.
        if not trials_are_correct:
            raise ValueError(f"Unhandled scenario for trajectory at index {idx}.")

        if not trials_are_correct[-1]:
            bucket = "fail"  # The final trial did not reach the answer.
        elif len(trials_are_correct) == 1:
            bucket = "success"  # Correct at the first and only trial.
        else:
            bucket = "compare"  # Earlier trial(s) precede a correct final trial.
        count_dict[bucket].append(idx)

    return count_dict

# Bucket the loaded experiences by outcome and display the buckets.
categories = categorize_experiences(experiences)
categories

{'compare': [10, 11, 12, 13, 14],
 'success': [1, 3, 6, 7, 8],
 'fail': [0, 2, 4, 5, 9]}

In [None]:
# Show which top-level keys the loaded experiences mapping contains.
experiences.keys()

In [6]:
import random

def get_folds(categories: Dict[str, List], n_instances: int, n_folds: int = 2) -> Dict[int, List[int]]:
    """Build per-fold training index lists, stratified by category.

    The indices of every category are shuffled with the global ``random``
    state (seed it beforehand for reproducible folds) and dealt round-robin
    into ``n_folds`` validation sets. The training set of a fold is every
    index in ``range(n_instances)`` that is not in that fold's validation set.

    Args:
        categories: Mapping from category name to the task indices in it.
            The lists are not modified.
        n_instances: Total number of tasks; indices are taken from
            ``range(n_instances)``.
        n_folds: Number of folds to create.

    Returns:
        A dict mapping fold number (``0 .. n_folds - 1``) to the sorted list
        of training indices for that fold.
    """
    validation = {fold: set() for fold in range(n_folds)}

    # Deal each category ('compare', 'success', 'fail') across the folds.
    for indices in categories.values():
        # Shuffle a copy: random.shuffle works in place and would otherwise
        # reorder the caller's category lists.
        shuffled = list(indices)
        random.shuffle(shuffled)
        for count, idx in enumerate(shuffled):
            validation[count % n_folds].add(idx)

    # Each fold is a validation set. Take the difference to get the training set of each fold.
    all_idxs = set(range(n_instances))
    return {fold: sorted(all_idxs - held_out) for fold, held_out in validation.items()}

# Size the split from the loaded experiences so it stays valid if the dataset changes.
folds = get_folds(categories, len(experiences["idxs"]))
folds

{0: [1, 4, 5, 8, 11, 13], 1: [0, 2, 3, 6, 7, 9, 10, 12, 14]}

# Insight Extraction

In [None]:
# Fold 0's training indices, then the same indices bucketed by outcome category.
train_idxs = folds[0]

train_category_idxs = {
    category: list(set(train_idxs) & set(category_idxs))
    for category, category_idxs in categories.items()
}

In [None]:
# Display the training indices taken from fold 0.
train_idxs

In [None]:
# Display the training indices grouped by category.
train_category_idxs

In [7]:
from langchain_core.prompts.chat import HumanMessagePromptTemplate
from langchain_core.messages.human import HumanMessage

from discussion_agents.cog.prompts.expel import (
    SYSTEM_TEMPLATE, 
    SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION,
    EXISTING_RULES_AI_NAME,
    NON_EXISTENT_RULES_AT_NAME,
    HUMAN_CRITIQUE_EXISTING_RULES_TEMPLATE,
    CRITIQUE_SUMMARY_SUFFIX_FULL,
    CRITIQUE_SUMMARY_SUFFIX_NOT_FULL
)

def _build_compare_prompt(
    rules: List[str],
    question: str,
    success_trial: str,
    failed_trial: str,
    is_full: bool,
) -> List[HumanMessage]:
    """Assemble the messages for the success-vs-failure critique prompt.

    Two messages are produced: a system-style preamble formatted from
    ``SYSTEM_TEMPLATE`` and a task message formatted from
    ``HUMAN_CRITIQUE_EXISTING_RULES_TEMPLATE`` with a suffix appended.

    Args:
        rules: Current rule texts. When empty, the rules section is rendered
            as a single blank numbered entry and the preamble uses
            ``NON_EXISTENT_RULES_AT_NAME`` instead of ``EXISTING_RULES_AI_NAME``.
        question: The task question.
        success_trial: Text of the successful trial.
        failed_trial: Text of the failed trial.
        is_full: Selects ``CRITIQUE_SUMMARY_SUFFIX_FULL`` when true, otherwise
            ``CRITIQUE_SUMMARY_SUFFIX_NOT_FULL``. The reference formula noted
            here is ``max_num_rules <= len(rules_with_count)``.

    Returns:
        The prompt messages in order: preamble message(s), then the task message.
    """
    # Choose the persona from the caller's rule list *before* an empty list is
    # swapped for the placeholder below; testing afterwards would never see it empty.
    ai_name = EXISTING_RULES_AI_NAME if rules else NON_EXISTENT_RULES_AT_NAME

    # With no rules yet, render one blank numbered entry ("1. ").
    listed_rules = rules if rules else ['']

    critique_history = []

    # System prompt. It is built with the human-message template, so it is emitted as a human message.
    prefix = (
        HumanMessagePromptTemplate.from_template(SYSTEM_TEMPLATE)
        .format_messages(
            ai_name=ai_name,
            instruction=SYSTEM_CRITIQUE_EXISTING_RULES_INSTRUCTION,
        )
    )
    critique_history.extend(prefix)

    # Task prompt.
    human_format_dict = {
        'question': question,
        'failed_traj': failed_trial,
        'success_traj': success_trial,
        'existing_rules': '\n'.join([f'{i}. {r}' for i, r in enumerate(listed_rules, 1)])
    }

    human_critique_summary_message = HumanMessagePromptTemplate.from_template(
        HUMAN_CRITIQUE_EXISTING_RULES_TEMPLATE
    ).format_messages(**human_format_dict)[0]
    # The suffix depends on whether the rule list is at capacity.
    critique_summary_suffix = CRITIQUE_SUMMARY_SUFFIX_FULL if is_full else CRITIQUE_SUMMARY_SUFFIX_NOT_FULL
    human_critique_summary_message.content = human_critique_summary_message.content + critique_summary_suffix
    critique_history.append(human_critique_summary_message)

    return critique_history

In [None]:
# Task 11 serves as the worked example (it is listed under 'compare' in the categories output shown above).
# The last element of a trial is taken here -- the output text, going by the (is_correct, answer, output) layout noted in categorize_experiences.
# First trial of the task.
failed_traj = experiences['trajectories'][11][0][-1]
# Last trial of the task.
success_traj = experiences['trajectories'][11][-1][-1]
question = experiences['questions'][11]

In [None]:
max_num_rules = 20
rules = []
rules_with_count = []
# The rule list counts as full once it has reached max_num_rules entries.
is_full = max_num_rules <= len(rules_with_count)
# Pass the trials by keyword: the signature takes success_trial before failed_trial,
# so positional arguments are easy to swap.
compare_prompt_msgs = _build_compare_prompt(
    rules,
    question,
    success_trial=success_traj,
    failed_trial=failed_traj,
    is_full=is_full,
)
compare_prompt_msgs

In [8]:
from langchain_core.messages.chat import ChatMessage

def collapse_prompts(prompt_history: List[ChatMessage]) -> List[ChatMessage]:
    """Merge runs of consecutive same-type messages into single messages.

    The contents of adjacent messages of the same class are joined with a
    newline, and each run is rebuilt as ``MessageClass(content=joined_text)``.
    (Original implementation credited to GPT-4.)

    Args:
        prompt_history: Messages in conversation order.

    Returns:
        A new list with one message per run; an empty list for empty input.
    """
    if not prompt_history:
        return []

    head, tail = prompt_history[0], prompt_history[1:]
    merged = []
    run_type, run_text = type(head), head.content

    for msg in tail:
        if type(msg) != run_type:
            # The run ended: emit it and start a new one with this message.
            merged.append(run_type(content=run_text))
            run_type, run_text = type(msg), msg.content
            continue
        run_text += '\n' + msg.content

    # Emit the run that was still open when the input ended.
    merged.append(run_type(content=run_text))

    return merged

In [None]:
# Merge consecutive same-type messages, then display the result.
compare_prompt_msgs = collapse_prompts(compare_prompt_msgs)
compare_prompt_msgs

In [9]:
from langchain_core.language_models.chat_models import BaseChatModel

def _prompt_compare_critique(compare_prompt_msgs: List[HumanMessage], llm: BaseChatModel, replace_newline: bool = False) -> str:
    """Send the critique prompt to the chat model and return its trimmed text reply.

    Args:
        compare_prompt_msgs: Prompt messages to send.
        llm: Chat model to query.
        replace_newline: When true, every newline in the reply is removed
            (lines are joined without a separator).

    Returns:
        The reply content with leading and trailing whitespace stripped.
    """
    # `invoke` is the supported entry point; calling the model object directly
    # is deprecated in langchain-core.
    reply = llm.invoke(compare_prompt_msgs).content.strip()
    if replace_newline:
        reply = reply.replace('\n', '')
    return reply

In [None]:
# Query the LLM with the comparison prompt and keep the critique text in `out`.
out = _prompt_compare_critique(compare_prompt_msgs, llm)

In [None]:
# Print the critique so its line breaks are rendered.
print(out)

In [10]:
import re

def parse_rules(llm_text):
    """Extract ``(operation, rule_text)`` pairs from an LLM critique.

    Operations are recognised as ``REMOVE``, ``EDIT``, ``ADD`` or ``AGREE``,
    optionally followed by a rule number, then ``": "`` and the rule text
    (an optional ``"<label>: "`` prefix before the text is skipped).

    A match is kept only when its stripped text is non-empty, ends with a
    period (guards against cut-off sentences) and contains none of the words
    ``ADD``, ``AGREE`` or ``EDIT`` (guards against malformed output). For
    ``ADD`` the rule number is dropped from the returned operation.

    Args:
        llm_text: Raw text returned by the model.

    Returns:
        A list of ``(operation, rule_text)`` tuples in order of appearance.
    """
    pattern = r'((?:REMOVE|EDIT|ADD|AGREE)(?: \d+|)): (?:[a-zA-Z\s\d]+: |)(.*)'
    banned_words = ('ADD', 'AGREE', 'EDIT')

    def _is_usable(candidate):
        # Non-empty, properly terminated, and free of stray operation keywords.
        if candidate == '' or not candidate.endswith('.'):
            return False
        return all(word not in candidate for word in banned_words)

    parsed = []
    for operation, raw_text in re.findall(pattern, llm_text):
        rule_text = raw_text.strip()
        if not _is_usable(rule_text):
            continue
        label = 'ADD' if 'ADD' in operation else operation.strip()
        parsed.append((label, rule_text))
    return parsed

In [None]:
# Parse the critique text into (operation, rule_text) pairs and display them.
operations = parse_rules(out)
operations

In [11]:
from typing import Tuple

def retrieve_rule_index(rules: List[Tuple[str, int]], operation_rule_text: str) -> int:
    """Return the position of the first rule whose text occurs in ``operation_rule_text``.

    Matching is by substring containment (rule text ``in`` operation text).

    Args:
        rules: ``(rule_text, count)`` pairs.
        operation_rule_text: Text attached to an operation.

    Returns:
        The zero-based position of the first matching rule, or -1 if none matches.
    """
    return next(
        (position for position, rule in enumerate(rules) if rule[0] in operation_rule_text),
        -1,
    )

def is_existing_rule(rules: List[Tuple[str, int]], operation_rule_text: str) -> bool:
    """Tell whether any rule's text occurs inside ``operation_rule_text``.

    Args:
        rules: ``(rule_text, count)`` pairs.
        operation_rule_text: Text attached to an operation.

    Returns:
        True if at least one rule text is a substring of ``operation_rule_text``.
    """
    return any(rule[0] in operation_rule_text for rule in rules)

def remove_err_operations(rules: List[Tuple[str, int]], operations: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Drop or repair operations that cannot be applied to ``rules``.

    Per operation type:

    - ``ADD``: dropped when its text already matches an existing rule.
    - ``EDIT``: rewritten to ``AGREE <n>`` (with the existing rule's text) when
      its text matches an existing rule; otherwise dropped unless its rule
      number lies in ``1..len(rules)``.
    - ``REMOVE`` / ``AGREE``: dropped when the text matches no existing rule.

    "Matches" means substring containment as implemented by
    ``is_existing_rule`` / ``retrieve_rule_index``.

    Args:
        rules: Current ``(rule_text, count)`` pairs.
        operations: ``(operation, rule_text)`` pairs, where ``operation`` is the
            type optionally followed by a space and a 1-based rule number.

    Returns:
        A new list with the surviving (possibly rewritten) operations in their
        original order. ``operations`` itself is not modified.
    """
    cleaned_operations = []

    for operation, operation_rule_text in operations:
        # Split the operation into action type and optional rule number.
        parts = operation.split(' ')
        operation_type = parts[0]
        rule_num = int(parts[1]) if len(parts) > 1 else None

        matches_existing = is_existing_rule(rules, operation_rule_text)

        if operation_type == 'ADD':
            if matches_existing:  # The "new" rule already exists.
                continue
        elif operation_type == 'EDIT':
            if matches_existing:
                # Editing a rule into text that already exists is an agreement with that rule.
                matched_index = retrieve_rule_index(rules, operation_rule_text)
                cleaned_operations.append((f'AGREE {matched_index+1}', rules[matched_index][0]))
                continue
            # Rule numbers are 1-based: 0 would later resolve to index -1 and
            # silently hit the last rule, so it is rejected like any other
            # out-of-range number.
            if rule_num is None or not 1 <= rule_num <= len(rules):
                continue
        elif operation_type == 'REMOVE' or operation_type == 'AGREE':
            if not matches_existing:  # Nothing to remove or agree with.
                continue

        cleaned_operations.append((operation, operation_rule_text))

    return cleaned_operations

def update_rules(rules: List[Tuple[str, int]], operations: List[Tuple[str, str]], is_full: bool = False) -> List[Tuple[str, int]]:
    """Apply critique operations to the counted rule list.

    Operations are applied grouped by type, in the order REMOVE, AGREE, EDIT,
    ADD (EDIT must come after REMOVE and AGREE, which locate rules by text):

    - ``REMOVE``: subtract 1 from the matching rule's count (3 when ``is_full``).
    - ``AGREE``: add 1 to the matching rule's count.
    - ``EDIT <n>``: replace the text of rule ``n`` (1-based) and add 1 to its count.
    - ``ADD``: append the rule with an initial count of 2.

    Operations that do not resolve to an existing rule (no text match for
    REMOVE/AGREE, missing or out-of-range number for EDIT) are skipped instead
    of being applied to the last rule through a negative index.

    Args:
        rules: Current ``(rule_text, count)`` pairs; not modified.
        operations: ``(operation, rule_text)`` pairs.
        is_full: Whether the rule list is considered full; strengthens REMOVE.

    Returns:
        A new list containing the rules whose count is still positive, sorted
        by count in descending order.
    """
    updated_rules = rules.copy()
    remove_strength = 3 if is_full else 1

    for op in ['REMOVE', 'AGREE', 'EDIT', 'ADD']:  # Order is important.
        for operation, operation_rule_text in operations:
            parts = operation.split(' ')
            operation_type = parts[0]
            if operation_type != op:
                continue

            if operation_type in ('REMOVE', 'AGREE'):
                # Located by text, so a stale rule number does not matter.
                rule_index = retrieve_rule_index(updated_rules, operation_rule_text)
                if rule_index == -1:
                    continue  # No such rule; -1 would otherwise address the last one.
                delta = -remove_strength if operation_type == 'REMOVE' else 1
                rule_text, count = updated_rules[rule_index]
                updated_rules[rule_index] = (rule_text, count + delta)
            elif operation_type == 'EDIT':
                if len(parts) < 2 or not parts[1].isdigit():
                    continue  # EDIT needs an explicit rule number.
                rule_index = int(parts[1]) - 1
                if not 0 <= rule_index < len(updated_rules):
                    continue  # Number 0 or past the end of the list.
                updated_rules[rule_index] = (operation_rule_text, updated_rules[rule_index][1] + 1)
            elif operation_type == 'ADD':
                updated_rules.append((operation_rule_text, 2))

    # Drop rules whose counter is no longer positive, strongest rules first.
    updated_rules = [rule for rule in updated_rules if rule[1] > 0]
    updated_rules.sort(key=lambda rule: rule[1], reverse=True)

    return updated_rules

In [None]:
# NOTE(review): this threshold adds a margin of 5 on top of max_num_rules, unlike the
# is_full computed for the prompt earlier -- presumably intentional; confirm.
is_full = max_num_rules+5 <= len(rules_with_count)

# Remove problematic operations.
operations = remove_err_operations(rules_with_count, operations)
# Apply the surviving operations; with is_full set, a REMOVE costs 3 instead of 1.
rules_with_count = update_rules(rules_with_count, operations, is_full=is_full)
# Plain rule texts, in the order update_rules returned them (count, descending).
rules = [rule[0] for rule in rules_with_count]

In [12]:
from typing import Tuple

def create_rules(
    llm: BaseChatModel,
    experiences: Dict[str, List],
    categories: Dict[str, List[int]],
    train_idxs: List[int],
    rules: List[str],
    rules_with_count: List[Tuple[str, int]],
    max_num_rules: int,
) -> Tuple[List[str], List[Tuple[str, int]]]:
    """Extract rules from the training tasks by contrasting failed and successful trials.

    For every training task in the "compare" category, each failed trial is
    paired with the task's final (successful) trial. The LLM is asked to
    critique the pair against the current rules, and the parsed operations are
    applied to the counted rule list.

    Args:
        llm: Chat model used for the critique.
        experiences: Mapping with "questions" and "trajectories", indexed by
            task index. A trial's last element is used as its text.
        categories: Mapping from category name ("compare", "success", "fail")
            to the task indices in it.
        train_idxs: Task indices available for training.
        rules: Current rule texts.
        rules_with_count: Current ``(rule_text, count)`` pairs.
        max_num_rules: Rule-list capacity. The prompt treats the list as full
            at ``max_num_rules`` entries; rule updates treat it as full at
            ``max_num_rules + 5``.

    Returns:
        ``(rules, rules_with_count)`` after all compare tasks were processed.
        The lists passed in are not modified.
    """
    # Intersect train_idxs with each category (compare, success, fail).
    # Sorted so tasks are processed in a stable order across runs.
    train_idx_set = set(train_idxs)
    train_category_idxs = {
        category: sorted(train_idx_set.intersection(category_idxs))
        for category, category_idxs in categories.items()
    }

    # Compare.
    for train_idx in train_category_idxs["compare"]:
        question = experiences["questions"][train_idx]
        trajectory = experiences["trajectories"][train_idx]
        success_trial = trajectory[-1][-1]
        for failed in trajectory[:-1]:
            # Use the trial's text (its last element), mirroring how
            # success_trial is extracted; the prompt expects strings.
            failed_trial = failed[-1]
            compare_prompt_msgs = _build_compare_prompt(
                rules,
                question,
                success_trial=success_trial,
                failed_trial=failed_trial,
                is_full=max_num_rules <= len(rules_with_count),
            )
            compare_prompt_msgs = collapse_prompts(compare_prompt_msgs)
            out = _prompt_compare_critique(compare_prompt_msgs, llm)
            operations = parse_rules(out)
            operations = remove_err_operations(rules_with_count, operations)
            rules_with_count = update_rules(rules_with_count, operations, is_full=max_num_rules+5 <= len(rules_with_count))
            rules = [rule[0] for rule in rules_with_count]

    # TODO: the "success" and "fail" categories are not used for rule extraction yet.

    return rules, rules_with_count

In [14]:
import joblib
max_num_rules = 20
llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", openai_api_key=openai_api_key)
# NOTE(review): joblib.load unpickles the file, so only load artifacts from a trusted source.
experiences = joblib.load("exp_15_compare_fake.joblib")
categories = categorize_experiences(experiences)
# NOTE(review): get_folds shuffles with the global `random` module and no seed is set in this cell,
# so the split can differ between runs.
folds = get_folds(categories, len(experiences['idxs']))

# Start from an empty rule set.
rules, rules_with_count = [], []
for fold, train_idxs in folds.items():
    print(fold, train_idxs)
    rules, rules_with_count = create_rules(llm, experiences, categories, train_idxs, rules, rules_with_count, max_num_rules)
    # Only the first fold is processed.
    break

0 [0, 2, 6, 8, 13, 14]


In [None]:
# NOTE(review): none of the names used here (eval_idx_list, starting_fold,
# num_training_tasks, SAVE_PATH, log, react_agent, cfg, dicts, resume,
# resume_starting_fold, critique_summary_log) are defined in the visible part
# of this notebook. It looks like reference code kept for comparison --
# confirm, or remove the cell.
for k, eval_idxs in enumerate(eval_idx_list):
    # Skip folds that come before the starting fold.
    if k < starting_fold:
        continue
    # Training ids are all task ids except this fold's evaluation ids.
    training_ids = set(range(num_training_tasks)) - set(eval_idxs)
    # Per-fold output directory (no error if it already exists).
    (SAVE_PATH / f"fold_{k}").mkdir(exist_ok=True)
    log += f'################## FOLD {k} ##################\n'
    # The return value of create_rules is appended to the running log.
    log += react_agent.create_rules(
        list(training_ids),
        cache_fold=k,
        logging_dir=str(SAVE_PATH / f"fold_{k}"),
        run_name=cfg.run_name,
        # The loaded state is only passed when resuming at the fold the run was resumed from.
        loaded_dict=dicts[-1] if resume and resume_starting_fold == starting_fold else None,
        loaded_log=critique_summary_log if resume and resume_starting_fold == starting_fold else None,
        eval_idx_list=eval_idx_list,
        saving_dict=True,
    )
    starting_fold += 1