In [1]:
from langchain_community.chat_models import ChatOpenAI
import dotenv
import os
import joblib
from langchain_core.language_models.chat_models import BaseChatModel
from typing import Dict, List

from discussion_agents.cog.train.expel import (
    # Experience Gathering.
    gather_experience,
    categorize_experiences,
    get_folds,
    # Insight Extraction.
    _build_compare_prompt,
    collapse_prompts,
    _prompt_compare_critique,
    parse_rules,
    retrieve_rule_index,
    is_existing_rule,
    remove_err_operations,
    update_rules,
)
from discussion_agents.utils.general import random_divide_list

# Load variables from the .env file one directory up, then read the OpenAI key
# from the process environment (the result is None when the variable is unset).
dotenv.load_dotenv(dotenv_path="../.env")
openai_api_key = os.environ.get("OPENAI_API_KEY")

In [2]:
# Threshold compared against the current rule count in `create_rules`.
max_num_rules = 20

# Chat model used for insight extraction.
llm = ChatOpenAI(
    model_name="gpt-3.5-turbo-0125",
    openai_api_key=openai_api_key,
)

# NOTE(review): joblib.load unpickles the file; only load trusted artifacts.
experiences = joblib.load("exp_15_compare_fake.joblib")

# Group the experience indices by category, then derive the folds from those
# categories and the total number of experiences.
categories = categorize_experiences(experiences)
folds = get_folds(categories, len(experiences["idxs"]))

In [3]:
from typing import Tuple

def create_rules(
    llm: BaseChatModel,
    experiences: Dict[str, List],
    categories: Dict[str, List[int]],
    train_idxs: List[int],
    rules: List[str],
    rules_with_count: List[Tuple[str, int]],
    max_num_rules: int,
) -> Tuple[List[str], List[Tuple[str, int]]]:
    """Update the rule set from the experiences selected by `train_idxs`.

    Args:
        llm: Chat model forwarded to `_prompt_compare_critique`.
        experiences: Gathered experiences; the "questions" and "trajectories"
            entries are read here, both indexed by experience index.
        categories: Maps a category name to the experience indices in that
            category. The "compare" and "success" keys must be present.
        train_idxs: Experience indices belonging to the current training fold.
        rules: Current rule strings.
        rules_with_count: Current rules paired with their counts.
        max_num_rules: Threshold compared against the number of rules to decide
            whether the rule list is treated as full.

    Returns:
        The updated `(rules, rules_with_count)` pair.

    Raises:
        KeyError: If `categories` lacks a "compare" or "success" entry.
    """
    # Restrict every category (compare, success, fail) to the training indices.
    train_idx_set = set(train_idxs)
    train_category_idxs = {
        category: list(train_idx_set.intersection(set(category_idxs)))
        for category, category_idxs in categories.items()
    }

    # Compare.
    for train_idx in train_category_idxs["compare"]:
        question = experiences["questions"][train_idx]
        trajectory = experiences["trajectories"][train_idx]

        # Compare the successful (last) trial with all previous failed trials.
        # Only the final element of each trial is sent to the prompt, so the
        # failed trial is unpacked the same way as the successful one.
        success_trial = trajectory[-1][-1]
        for failed_trial in trajectory[:-1]:
            out = _prompt_compare_critique(
                rules,
                question,
                success_trial,
                failed_trial[-1],
                max_num_rules < len(rules_with_count),
                llm,
            )
            operations = parse_rules(out)
            operations = remove_err_operations(rules_with_count, operations)

            # Update rules_with_count and rules with comparison insights.
            # NOTE(review): the prompt above uses `max_num_rules < len(...)`
            # while this update uses `max_num_rules + 5 <= len(...)`; confirm
            # the two different "is full" thresholds are intended.
            rules_with_count = update_rules(
                rules_with_count,
                operations,
                is_full=max_num_rules + 5 <= len(rules_with_count),
            )
            rules = [rule[0] for rule in rules_with_count]

    # Success.
    # TODO: success-only insight extraction is not implemented yet; the loop
    # currently only looks up each question/trajectory without using them.
    for train_idx in train_category_idxs["success"]:
        question = experiences["questions"][train_idx]
        trajectory = experiences["trajectories"][train_idx]

    return rules, rules_with_count

In [4]:
# Start from an empty rule set and build rules from the first fold only.
rules = []
rules_with_count = []
for fold, train_idxs in folds.items():
    print(fold, train_idxs)
    rules, rules_with_count = create_rules(
        llm=llm,
        experiences=experiences,
        categories=categories,
        train_idxs=train_idxs,
        rules=rules,
        rules_with_count=rules_with_count,
        max_num_rules=max_num_rules,
    )
    # Stop after the first fold; `fold` and `train_idxs` stay bound afterwards.
    break

0 [1, 2, 6, 9, 12, 14]


In [5]:
# Split the current fold's training indices by category. `train_idxs` is the
# value left bound by the fold loop in the previous cell.
train_category_idxs = {
    category: list(set(train_idxs) & set(category_idxs))
    for category, category_idxs in categories.items()
}
train_category_idxs

{'compare': [12, 14], 'success': [1, 6], 'fail': [9, 2]}

In [21]:
from discussion_agents.utils.general import random_divide_list

# Second argument passed to random_divide_list below.
# NOTE(review): presumably the maximum number of success trajectories per
# batch -- confirm against random_divide_list.
success_critique_num = 8
# Divide the fold's success indices into batches.
batched_success_trajs_idxs = random_divide_list(train_category_idxs['success'], success_critique_num)
batched_success_trajs_idxs

[[1, 6]]

In [26]:
# Scratch cell: for each batch of success indices, build one string that joins
# "<question>\n<zero-th trial output>" entries with blank lines. The string is
# built twice, once with a comprehension and once with an explicit loop.
for success_idxs in batched_success_trajs_idxs:
    print(success_idxs)

    # Variant 1: comprehension with an f-string per success index.
    concat_success_trajs = [
        f"{experiences['questions'][idx]}\n{experiences['trajectories'][idx][0][-1]}"
        for idx in success_idxs
    ]
    success_trajs_str = "\n\n".join(concat_success_trajs)
    print("it works")

    # Variant 2: the same construction written as an explicit loop; the result
    # is kept under a separate name (_success_trajs_str).
    concat_success_trajs = []
    for idx in success_idxs:
        question = experiences["questions"][idx]
        trajectory = experiences['trajectories'][idx]
        out = question + "\n" + trajectory[0][-1]  # Get this successful trajectory's zero-th trial output.
        concat_success_trajs.append(out)
    _success_trajs_str = "\n\n".join(concat_success_trajs)

    # NOTE(review): the commented-out line below looks like reference code kept
    # for comparison; it uses `self` and `success_chunk`, which this cell does
    # not define.
    # success_trials = '\n\n'.join([self.remove_task_suffix(task) + '\n' + trajectory for task, trajectory in success_chunk])


[1, 6]
it works


In [None]:
# Placeholder loop over the fold's success indices: it only looks up each
# question and trajectory and does not use them yet.
for train_idx in train_category_idxs["success"]:
    question = experiences["questions"][train_idx]
    trajectory = experiences["trajectories"][train_idx]

In [None]:
# Second argument intended for random_divide_list (same value as the earlier
# batching cell).
success_critique_num = 8


# NOTE(review): this cell has no execution count. `self` and `all_success` are
# not defined anywhere in the visible notebook, so this line would raise
# NameError as written -- it looks like a pasted reference snippet; confirm
# before running.
all_success = random_divide_list(all_success, self.success_critique_num)