In [1]:
from langchain_community.chat_models import ChatOpenAI
import dotenv
import os
import joblib
from langchain_core.language_models.chat_models import BaseChatModel
from typing import Dict, List

from discussion_agents.cog.train.expel import (
    # Experience Gathering.
    gather_experience,
    categorize_experiences,
    get_folds,
    # Insight Extraction.
    _build_compare_prompt,
    collapse_prompts,
    _prompt_compare_critique,
    parse_rules,
    retrieve_rule_index,
    is_existing_rule,
    remove_err_operations,
    update_rules,
)
from discussion_agents.utils.general import random_divide_list

# Load credentials from ../.env (relative to the working directory) instead of
# hard-coding them in the notebook.
dotenv.load_dotenv("../.env")
# None when OPENAI_API_KEY is not set in the environment or the .env file.
openai_api_key = os.environ.get("OPENAI_API_KEY")

In [2]:
# Rule-count threshold create_rules uses to decide when the rule set is "full".
max_num_rules = 20
llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", openai_api_key=openai_api_key)

# Previously gathered experiences (presumably saved from gather_experience — confirm).
# NOTE: joblib.load unpickles the file; only load artifacts you produced yourself.
experiences = joblib.load("exp_15_compare_fake.joblib")
categories = categorize_experiences(experiences)
folds = get_folds(categories, len(experiences["idxs"]))

In [3]:
from typing import Tuple

def create_rules(
    llm: BaseChatModel,
    experiences: Dict[str, List],
    categories: Dict[str, List[int]],
    train_idxs: List[int],
    rules: List[str],
    rules_with_count: List[Tuple[str, int]],
    max_num_rules: int,
) -> Tuple[List[str], List[Tuple[str, int]]]:
    """Extract rules from one training fold by having the LLM critique experiences.

    Only the "compare" phase is implemented so far. For every training task in the
    "compare" category, the final trial (treated as the successful one) is contrasted
    with each earlier (failed) trial. The LLM proposes rule operations, which are
    validated and applied to the rule set.

    Args:
        llm: Chat model used for the critique prompts.
        experiences: Gathered experiences; must provide "questions" and
            "trajectories" indexable by task index.
        categories: Maps a category name (e.g. "compare", "success", "fail") to the
            task indices in that category.
        train_idxs: Task indices in this fold's training split.
        rules: Current rule texts, kept in sync with ``rules_with_count``.
        rules_with_count: Current rules paired with their counts.
        max_num_rules: Threshold used to decide when the rule set counts as full.

    Returns:
        The updated ``(rules, rules_with_count)``.
    """
    # Restrict each category to this fold's training indices.
    train_idx_set = set(train_idxs)
    train_category_idxs = {
        category: list(train_idx_set.intersection(set(category_idxs)))
        for category, category_idxs in categories.items()
    }

    # Compare phase.
    for train_idx in train_category_idxs["compare"]:
        question = experiences["questions"][train_idx]
        trajectory = experiences["trajectories"][train_idx]

        # The last element of a trial is its output. Compare like with like: the
        # successful (last) trial's output against each failed trial's output.
        success_trial = trajectory[-1][-1]
        for failed_trial in trajectory[:-1]:
            out = _prompt_compare_critique(
                rules,
                question,
                success_trial,
                failed_trial[-1],
                max_num_rules < len(rules_with_count),
                llm,
            )
            operations = parse_rules(out)
            operations = remove_err_operations(rules_with_count, operations)

            # NOTE(review): "full" here uses max_num_rules + 5, while the prompt
            # flag above uses max_num_rules — confirm the asymmetry is intended.
            rules_with_count = update_rules(
                rules_with_count,
                operations,
                is_full=max_num_rules + 5 <= len(rules_with_count),
            )
            rules = [rule[0] for rule in rules_with_count]

    # TODO: success phase (critiquing batches of successful trajectories) is still
    # being prototyped in later cells and is not applied here yet.
    return rules, rules_with_count

In [4]:
# Prototype on the first fold only. The loop variables (`fold`, `train_idxs`)
# remain bound afterwards and are reused by later cells.
rules, rules_with_count = [], []
for fold, train_idxs in list(folds.items())[:1]:
    print(fold, train_idxs)
    rules, rules_with_count = create_rules(
        llm,
        experiences,
        categories,
        train_idxs,
        rules,
        rules_with_count,
        max_num_rules,
    )

0 [0, 3, 5, 7, 11, 13]


In [5]:
# Same per-category split create_rules computes internally, exposed here for
# inspection; `train_idxs` comes from the fold loop above.
train_category_idxs = {
    category: list(set(train_idxs) & set(category_idxs))
    for category, category_idxs in categories.items()
}
train_category_idxs

{'compare': [11, 13], 'success': [3, 7], 'fail': [0, 5]}

In [6]:
from discussion_agents.utils.general import random_divide_list

# Batch-size argument handed to random_divide_list for the successful tasks.
success_critique_num = 8

# NOTE(review): random_divide_list is presumably randomized and no seed is set,
# so batches may differ between runs — confirm and seed if reproducibility matters.
batched_success_trajs_idxs = random_divide_list(
    train_category_idxs["success"], success_critique_num
)
batched_success_trajs_idxs

[[3, 7]]

In [7]:
# Inspect the extracted rules: plain texts and (rule, count) pairs.
(rules, rules_with_count)

(['Always focus on extracting specific information relevant to the question at hand.',
  'Always consider narrowing down your search if the initial search results are too broad or not directly related to the question.',
  'When searching for specific information, utilize keywords or phrases related to the topic to refine search results.',
  'When researching historical events, consider the broader context in which the event took place to provide a more comprehensive answer.',
  'Ensure to provide a direct and concise answer to the question without unnecessary elaboration or tangents.'],
 [('Always focus on extracting specific information relevant to the question at hand.',
   2),
  ('Always consider narrowing down your search if the initial search results are too broad or not directly related to the question.',
   2),
  ('When searching for specific information, utilize keywords or phrases related to the topic to refine search results.',
   2),
  ('When researching historical events, c

In [26]:
# Build one prompt string per batch of successful tasks: each task contributes its
# question followed by its first trial's output.
for success_idxs in batched_success_trajs_idxs:
    print(success_idxs)

    concat_success_trajs = [
        f"{experiences['questions'][idx]}\n{experiences['trajectories'][idx][0][-1]}"  # Get this successful trajectory's zero-th trial output.
        for idx in success_idxs
    ]
    success_trajs_str = "\n\n".join(concat_success_trajs)

    # TODO: feed `success_trajs_str` into the success-critique step once it exists.
    # The draft call used names not defined in this notebook (`self`, `extend_rules`,
    # `success_trials`) and would raise NameError:
    #   rules_with_count, llm_output = extend_rules(rules_with_count, success_trajs_str.strip(), None)



[1, 6]
it works


In [8]:
from discussion_agents.cog.prompts.expel import (
    SYSTEM_TEMPLATE,
    SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
    NON_EXISTENT_RULES_AT_NAME,
    EXISTING_RULES_AI_NAME,
)
from langchain_core.prompts.chat import HumanMessagePromptTemplate

# Message list that will accumulate the critique prompt.
critique_history = []

# Opening prompt built from SYSTEM_TEMPLATE. The AI name switches once any rules
# exist.
# NOTE(review): this is rendered as a HumanMessage, not a system message —
# confirm that is intended.
prefix = HumanMessagePromptTemplate.from_template(SYSTEM_TEMPLATE).format_messages(
    ai_name=EXISTING_RULES_AI_NAME if rules else NON_EXISTENT_RULES_AT_NAME,
    instruction=SYSTEM_CRITIQUE_ALL_SUCCESS_EXISTING_RULES_INSTRUCTION,
)
critique_history.extend(prefix)

In [9]:
# Inspect the messages accumulated so far for the critique prompt.
critique_history

[HumanMessage(content='You are an advanced reasoning agent that can add, edit or remove rules from your existing rule set, based on forming new critiques of past task trajectories. You will be given successful tasks trials in which you were given access to a Docstore API environment and a question to answer.')]

In [None]:
# Placeholder for the success phase: currently only looks up each successful
# task's question and trajectory (TODO: critique them).
for train_idx in train_category_idxs["success"]:
    question, trajectory = (
        experiences["questions"][train_idx],
        experiences["trajectories"][train_idx],
    )

In [None]:
# Batch-size argument for splitting successful tasks before critiquing.
success_critique_num = 8

# Split this fold's successful task indices via random_divide_list.
# (The draft referenced undefined `all_success` and `self.success_critique_num`,
# which raised NameError in a notebook.)
all_success = random_divide_list(train_category_idxs["success"], success_critique_num)