In [5]:
from langchain_community.chat_models import ChatOpenAI
import dotenv
import os
import joblib
from langchain_core.language_models.chat_models import BaseChatModel
from typing import Dict, List

from discussion_agents.cog.train.expel import (
    # Experience Gathering.
    gather_experience,
    categorize_experiences,
    get_folds,
    # Insight Extraction.
    _build_compare_prompt,
    collapse_prompts,
    _prompt_compare_critique,
    parse_rules,
    retrieve_rule_index,
    is_existing_rule,
    remove_err_operations,
    update_rules,
)

# Load variables from the .env file one directory up, then read the OpenAI key
# from the process environment (None when the variable is not set).
dotenv.load_dotenv("../.env")
openai_api_key = os.environ.get("OPENAI_API_KEY")

In [6]:
# Passed to create_rules as `max_num_rules`.
max_num_rules = 20

# Chat model used for the compare critiques.
llm = ChatOpenAI(model_name="gpt-3.5-turbo-0125", openai_api_key=openai_api_key)

# NOTE: joblib.load unpickles the file, so only load artifacts from a trusted source.
experiences = joblib.load("exp_15_compare_fake.joblib")

# Group the experiences by category, then derive the folds from those groups
# and the total number of experience indices.
categories = categorize_experiences(experiences)
num_experiences = len(experiences["idxs"])
folds = get_folds(categories, num_experiences)

In [7]:
from typing import Tuple

def create_rules(
    llm: BaseChatModel,
    experiences: Dict[str, List],
    categories: Dict[str, List[int]],
    train_idxs: List[int],
    rules: List[str],
    rules_with_count: List[Tuple[str, int]],
    max_num_rules: int,
) -> Tuple[List[str], List[Tuple[str, int]]]:
    """Extract rules from the "compare" experiences of one training fold.

    For every index that is both in `train_idxs` and in `categories["compare"]`,
    the final element of the last trial (`trajectory[-1][-1]`) is compared
    against each earlier trial of the same trajectory. Every comparison builds
    a prompt, queries `llm`, parses the response into operations and applies
    them to the running rule set, so later comparisons see the rules produced
    by earlier ones.

    Args:
        llm: Chat model handed to `_prompt_compare_critique`.
        experiences: Mapping read at the "questions" and "trajectories" keys,
            each indexed by experience index.
        categories: Category name -> experience indices. Must contain
            "compare"; other categories are intersected but not used yet.
        train_idxs: Experience indices belonging to the current fold.
        rules: Rule strings passed to the first compare prompt.
        rules_with_count: Rule entries whose first element is the rule text;
            `rules` is re-derived from it after every update.
        max_num_rules: Threshold used to compute the `is_full` flags.

    Returns:
        The `(rules, rules_with_count)` pair after all comparisons. When the
        fold contains no "compare" index the inputs are returned as given.

    Raises:
        KeyError: If `categories` has no "compare" entry.
    """
    # Build the fold's index set once instead of once per category.
    train_idx_set = set(train_idxs)

    # Intersect train_idxs with each category (compare, success, fail).
    # Sorting fixes the processing order: rule updates are order-dependent, so
    # the order should not hinge on set iteration details.
    train_category_idxs = {
        category: sorted(train_idx_set.intersection(category_idxs))
        for category, category_idxs in categories.items()
    }

    # Compare.
    for train_idx in train_category_idxs["compare"]:
        question = experiences["questions"][train_idx]
        trajectory = experiences["trajectories"][train_idx]

        # Compare the successful trial with all previous failed trials.
        # NOTE(review): success_trial is the last element of the final trial,
        # whereas failed_trial is the whole trial entry -- confirm that
        # _build_compare_prompt expects this asymmetry.
        success_trial = trajectory[-1][-1]
        for failed_trial in trajectory[:-1]:
            # NOTE(review): the prompt counts as "full" strictly above
            # max_num_rules, the update only from max_num_rules + 5 on;
            # confirm the different thresholds are intentional.
            compare_prompt_msgs = _build_compare_prompt(
                rules,
                question,
                success_trial,
                failed_trial,
                is_full=max_num_rules < len(rules_with_count),
            )
            compare_prompt_msgs = collapse_prompts(compare_prompt_msgs)
            out = _prompt_compare_critique(compare_prompt_msgs, llm)
            operations = parse_rules(out)
            operations = remove_err_operations(rules_with_count, operations)

            # Update rules_with_count and rules with comparison insights.
            rules_with_count = update_rules(
                rules_with_count,
                operations,
                is_full=max_num_rules + 5 <= len(rules_with_count),
            )
            rules = [rule[0] for rule in rules_with_count]

    # TODO: insight extraction for the "success" category is not implemented.
    # The previous placeholder loop only bound unused locals and required a
    # "success" key in `categories`, so it was removed.

    return rules, rules_with_count

In [8]:
# Start from an empty rule set; create_rules returns the updated pair.
rules = []
rules_with_count = []
for fold, train_idxs in folds.items():
    print(fold, train_idxs)
    rules, rules_with_count = create_rules(
        llm=llm,
        experiences=experiences,
        categories=categories,
        train_idxs=train_idxs,
        rules=rules,
        rules_with_count=rules_with_count,
        max_num_rules=max_num_rules,
    )
    # Only the first fold is processed: leave the loop right after it.
    break

0 [2, 4, 6, 7, 10, 12]
