In [1]:
from langchain_community.chat_models import ChatOpenAI
import dotenv
import os
import joblib
from langchain_core.language_models.chat_models import BaseChatModel
from typing import Dict, List

from discussion_agents.cog.functional.expel import (
    # Experience Gathering.
    gather_experience,
    categorize_experiences,
    get_folds,
    # Insight Extraction.
    _build_compare_prompt,
    _prompt_compare_critique,
    parse_rules,
    retrieve_rule_index,
    is_existing_rule,
    remove_err_operations,
    update_rules,
)
from discussion_agents.utils.general import random_divide_list

# Populate the process environment from the project-level .env file, then read the
# OpenAI key from it (None when the variable is not set).
dotenv.load_dotenv("../.env")
openai_api_key = os.environ.get("OPENAI_API_KEY")

In [2]:
max_num_rules = 20

# Chat model used for every critique prompt in this notebook.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo-0125",
    openai_api_key=openai_api_key,
)

# NOTE: joblib.load unpickles the file; only load trusted, locally produced dumps.
experiences = joblib.load("exp_15_compare_fake.joblib")
categories = categorize_experiences(experiences)
folds = get_folds(categories, len(experiences["idxs"]))

In [3]:
from typing import Tuple
from discussion_agents.cog.functional.expel import _prompt_all_success_critique

def _apply_critique(
    rules_with_count: List[Tuple[str, int]],
    out: str,
    max_num_rules: int,
) -> Tuple[List[str], List[Tuple[str, int]]]:
    """Apply the rule operations contained in one LLM critique.

    Parses `out` into operations, removes erroneous/no-op operations with
    `remove_err_operations`, then applies the rest with `update_rules`.

    Args:
        rules_with_count: Current `(rule, count)` pairs.
        out: Raw LLM critique text.
        max_num_rules: Rule budget; `update_rules` is told the list is full once it
            already holds at least `max_num_rules + 5` entries.

    Returns:
        `(rules, rules_with_count)` after the update, where `rules` is the rule text
        (first element) of each pair.
    """
    operations = parse_rules(out)
    operations = remove_err_operations(rules_with_count, operations)
    rules_with_count = update_rules(
        rules_with_count,
        operations,
        is_full=max_num_rules + 5 <= len(rules_with_count),
    )
    return [rule[0] for rule in rules_with_count], rules_with_count


def create_rules(
    llm: BaseChatModel,
    experiences: Dict[str, List],
    categories: Dict[str, List[int]],
    train_idxs: List[int],
    rules: List[str],
    rules_with_count: List[Tuple[str, int]],
    max_num_rules: int,
    success_critique_num: int = 8
) -> Tuple[List[str], List[Tuple[str, int]]]:
    """Extract ExpeL-style insight rules from one training fold.

    Two passes over the fold's training indices:

    1. Compare: for each "compare" experience, critique the final (successful)
       trial against every earlier (failed) trial.
    2. Success: split the "success" experiences with `random_divide_list` and
       critique each batch of successful trajectories together.

    The rule list is updated after every LLM critique.

    Args:
        llm: Chat model used for all critiques.
        experiences: Must contain "questions" and "trajectories", indexed by
            experience index. Each trajectory is a list of trials, and a trial's
            last element is its output.
        categories: Maps a category name ("compare", "success", ...) to experience indices.
        train_idxs: Experience indices belonging to this training fold.
        rules: Current rule texts.
        rules_with_count: Current `(rule, count)` pairs.
        max_num_rules: Rule budget, used to tell prompts and updates whether the list is full.
        success_critique_num: Forwarded to `random_divide_list` when batching successes.

    Returns:
        The updated `(rules, rules_with_count)`.
    """
    # Intersect the fold's indices with each category (compare, success, fail).
    train_idxs_set = set(train_idxs)
    train_category_idxs = {
        category: list(train_idxs_set.intersection(set(category_idxs)))
        for category, category_idxs in categories.items()
    }

    # Compare.
    for train_idx in train_category_idxs["compare"]:
        question = experiences["questions"][train_idx]
        trajectory = experiences["trajectories"][train_idx]

        # Compare the successful (last) trial's output with each failed trial's output.
        success_trial = trajectory[-1][-1]
        for failed_trial in trajectory[:-1]:
            # Use the trial's output, exactly as done for `success_trial`; passing the
            # whole trial record made the two sides of the comparison inconsistent.
            failed_output = failed_trial[-1]
            out = _prompt_compare_critique(
                rules,
                question,
                success_trial,
                failed_output,
                max_num_rules < len(rules_with_count),
                llm
            )
            rules, rules_with_count = _apply_critique(rules_with_count, out, max_num_rules)

    # Success.
    batched_success_trajs_idxs = random_divide_list(train_category_idxs["success"], success_critique_num)
    for success_idxs in batched_success_trajs_idxs:
        # "question\noutput" per experience, using each successful trajectory's zero-th trial.
        success_trajs_str = "\n\n".join(
            f"{experiences['questions'][idx]}\n{experiences['trajectories'][idx][0][-1]}"
            for idx in success_idxs
        )
        out = _prompt_all_success_critique(
            rules,
            success_trajs_str,
            max_num_rules < len(rules_with_count),
            llm
        )
        rules, rules_with_count = _apply_critique(rules_with_count, out, max_num_rules)

    return rules, rules_with_count

In [4]:
rules = []
rules_with_count = []
# Exploratory run: only the first fold is processed (see the `break`).
for fold, train_idxs in folds.items():
    print(fold, train_idxs)
    rules, rules_with_count = create_rules(
        llm=llm,
        experiences=experiences,
        categories=categories,
        train_idxs=train_idxs,
        rules=rules,
        rules_with_count=rules_with_count,
        max_num_rules=max_num_rules,
    )
    break

0 [0, 1, 2, 6, 11, 14]


In [None]:
# Restrict every experience category to the indices of the current training fold.
train_category_idxs = {
    name: list(set(train_idxs) & set(idxs))
    for name, idxs in categories.items()
}
train_category_idxs

In [None]:
from discussion_agents.utils.general import random_divide_list

success_critique_num = 8

# Split the fold's successful experience indices into batches for the all-success critique.
batched_success_trajs_idxs = random_divide_list(
    train_category_idxs["success"],
    success_critique_num,
)
batched_success_trajs_idxs

In [None]:
# Rules and their (rule, count) pairs after the first fold.
(rules, rules_with_count)

In [None]:
# Build the prompt text for each batch. After the loop, `success_trajs_str` holds
# only the LAST batch, which the later cells reuse.
for success_idxs in batched_success_trajs_idxs:
    print(success_idxs)

    # One "question\noutput" entry per experience, taken from its zero-th trial.
    concat_success_trajs = [
        f"{experiences['questions'][idx]}\n{experiences['trajectories'][idx][0][-1]}"
        for idx in success_idxs
    ]
    success_trajs_str = "\n\n".join(concat_success_trajs)



In [None]:
from discussion_agents.cog.prompts.expel import (
    SYSTEM_TEMPLATE,
    SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
    NON_EXISTENT_RULES_AT_NAME,
    EXISTING_RULES_AI_NAME,
)
from langchain_core.prompts.chat import HumanMessagePromptTemplate

critique_history = []

# System prompt: the AI name depends on whether any rules exist yet.
prefix = HumanMessagePromptTemplate.from_template(SYSTEM_TEMPLATE).format_messages(
    ai_name=EXISTING_RULES_AI_NAME if rules else NON_EXISTENT_RULES_AT_NAME,
    instruction=SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
)
critique_history += prefix

In [None]:
# Inspect the system-prompt messages collected so far.
critique_history

In [None]:
# Values for the human template: the batched trials and a 1-based numbered rule list.
human_format_dict = dict(
    success_trajs=success_trajs_str,
    existing_rules="\n".join(f"{i}. {r}" for i, r in enumerate(rules, start=1)),
)
human_format_dict

In [None]:
from discussion_agents.cog.prompts.expel import FORMAT_RULES_OPERATION_TEMPLATE
# Human-turn template for the all-success critique: the trials, the numbered
# existing rules, then the shared rule-operation format instructions.
HUMAN_CRITIQUE_EXISTING_RULES_ALL_SUCCESS_TEMPLATE = (
    """
Here are the trials:
{success_trajs}

Here are the EXISTING RULES:
{existing_rules}

By examining the successful trials, and the list of existing rules, you can perform the following operations: add, edit, remove, or agree so that the new list of rules are general and high level insights of the successful trials or proposed way of Thought so they can be used as helpful tips to different tasks in the future. Have an emphasis on tips that help the agent perform better Thought and Action. Follow the below format:

"""
    + FORMAT_RULES_OPERATION_TEMPLATE
)

human_critique_summary_message = (
    HumanMessagePromptTemplate
    .from_template(HUMAN_CRITIQUE_EXISTING_RULES_ALL_SUCCESS_TEMPLATE)
    .format_messages(**human_format_dict)[0]
)
human_critique_summary_message

In [None]:
# Look up each successful training example; after the loop `question` and
# `trajectory` refer to the last one.
for train_idx in train_category_idxs["success"]:
    question, trajectory = (
        experiences["questions"][train_idx],
        experiences["trajectories"][train_idx],
    )

In [None]:
from langchain_core.prompts.chat import HumanMessage
from discussion_agents.cog.prompts.expel import (
    CRITIQUE_SUMMARY_SUFFIX_FULL,
    CRITIQUE_SUMMARY_SUFFIX_NOT_FULL
)

def _build_all_success_prompt(
    rules: List[str], 
    success_trajs_str: str,
    is_full: bool,
) -> List[HumanMessage]:
    """Build the message list for critiquing a batch of successful trajectories.

    Args:
        rules: Current rule texts; may be empty.
        success_trajs_str: Pre-formatted successful trajectories, inserted into the
            `{success_trajs}` slot of the human template.
        is_full: Selects `CRITIQUE_SUMMARY_SUFFIX_FULL` (True) or
            `CRITIQUE_SUMMARY_SUFFIX_NOT_FULL` (False) as the task-message suffix.
            NOTE(review): the implementation this was ported from computed it as
            `max_num_rules <= len(rules_with_count)`, while callers here use `<`.
            Confirm which is intended.

    Returns:
        The messages produced from `SYSTEM_TEMPLATE`, followed by one task message
        built from `HUMAN_CRITIQUE_EXISTING_RULES_ALL_SUCCESS_TEMPLATE` plus the suffix.
    """
    # Choose the AI name from the caller's rules BEFORE padding them. Padding first
    # made `not rules` always False, so the no-rules name could never be selected.
    ai_name = NON_EXISTENT_RULES_AT_NAME if not rules else EXISTING_RULES_AI_NAME

    # Render at least one (empty) numbered entry so the rules section is never blank.
    if not rules:
        rules = [""]

    critique_history = []

    # System prompt.
    prefix = HumanMessagePromptTemplate.from_template(SYSTEM_TEMPLATE).format_messages(
        ai_name=ai_name,
        instruction=SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
    )
    critique_history.extend(prefix)

    # Task prompt.
    human_format_dict = {
        "success_trajs": success_trajs_str,
        "existing_rules": "\n".join(f"{i}. {r}" for i, r in enumerate(rules, 1)),
    }
    human_critique_summary_message = HumanMessagePromptTemplate.from_template(
        HUMAN_CRITIQUE_EXISTING_RULES_ALL_SUCCESS_TEMPLATE
    ).format_messages(**human_format_dict)[0]
    critique_summary_suffix = (
        CRITIQUE_SUMMARY_SUFFIX_FULL if is_full else CRITIQUE_SUMMARY_SUFFIX_NOT_FULL
    )
    human_critique_summary_message.content += critique_summary_suffix
    critique_history.append(human_critique_summary_message)

    return critique_history

def _prompt_all_success_critique(
    rules: List[str], 
    success_trajs_str: str, 
    is_full: bool,
    llm: BaseChatModel, 
    replace_newline: bool = False
) -> str:
    """Run the all-success critique prompt and return the LLM's reply text.

    Args:
        rules: Current rule texts (may be empty).
        success_trajs_str: Pre-formatted batch of successful trajectories.
        is_full: Forwarded to `_build_all_success_prompt` to pick the summary suffix.
        llm: Chat model to query.
        replace_newline: If True, every newline is removed from the reply.

    Returns:
        The reply's content with surrounding whitespace stripped.
    """
    prompt_msgs = _build_all_success_prompt(
        rules=rules,
        success_trajs_str=success_trajs_str,
        is_full=is_full,
    )
    # NOTE(review): `collapse_prompts` is not imported in any visible cell of this
    # notebook, so this raises NameError on a fresh kernel. Import it from the
    # module that defines it before running.
    prompt_msgs = collapse_prompts(prompt_msgs)
    # `invoke` is the Runnable entry point; calling the model object directly is
    # deprecated. `.strip()` already removes surrounding newlines.
    out = llm.invoke(prompt_msgs).content.strip()
    if replace_newline:
        out = out.replace("\n", "")
    return out

In [None]:
# Critique the last batch in `success_trajs_str` against the current rules.
oo = _prompt_all_success_critique(
    rules,
    success_trajs_str,
    len(rules_with_count) > max_num_rules,
    llm,
)

In [None]:
# Raw critique text returned by the LLM.
oo

In [None]:
success_critique_num = 8

# Batch the fold's successful experience indices for the all-success critique.
# This notebook has no `self`, and `all_success` was never defined. Use the fold's
# success indices and the local `success_critique_num` instead.
all_success = random_divide_list(train_category_idxs["success"], success_critique_num)