In [1]:
import os
import copy
import random
import numpy as np
import pandas as pd
import prompt_utils as utils
from optimization_utils import fill_missing

In [2]:
def _wrap_generated_function(raw_response):
    """Extract the function between <start>/<end> and guard each body line.

    Everything before the first ``def`` is dropped. The first line (the
    signature) and the last line are kept verbatim; every line in between
    that is longer than two characters after stripping is rewritten as a
    one-line ``try: ... / except: pass`` so a single failing statement does
    not abort the whole generated function.

    Raises:
        IndexError: if ``raw_response`` contains no ``<start>`` marker.
    """
    fct_str = raw_response.split('<start>')[1].split('<end>')[0].strip()
    fct_str = 'def' + 'def'.join(fct_str.split('def')[1:])

    fct_lines = fct_str.split('\n')
    guarded_body = "\n".join(
        "    try: " + fct_piece.strip() + "\n    except: pass"
        for fct_piece in fct_lines[1:-1]
        if len(fct_piece.strip()) > 2
    )
    return "\n".join([fct_lines[0], guarded_body, fct_lines[-1]])


def run(_DATA, _SEED, _RULE_DIR, _NUM_COL, _NUM_QUERY, _API_KEY):
    """Collect ``_NUM_QUERY`` LLM-proposed feature functions for one dataset/seed.

    For every query the LLM is first asked for new column ideas, then asked
    (up to 10 times) to turn those ideas into a ``column_appender`` function
    that is executed against the imputed training frame. A query only counts,
    and its outputs are only recorded, once the function ran and the column
    descriptions were parsed. Nothing is done if the function file for this
    dataset/seed already exists.

    Args:
        _DATA: Dataset name; passed to ``utils.get_dataset`` and used in file names.
        _SEED: Seed passed to ``utils.set_seed`` / ``utils.get_dataset`` and used in file names.
        _RULE_DIR: Sub-directory of ``./LLM_results`` that receives the output files.
        _NUM_COL: Forwarded as ``num_col`` to the prompt builders.
        _NUM_QUERY: Number of successful queries to collect.
        _API_KEY: Forwarded to ``utils.query_gpt``.

    Side effects:
        Writes ``columns-{_DATA}-{_SEED}.out`` (raw LLM answers) and
        ``function-{_DATA}-{_SEED}.out`` (generated functions), both joined by
        a ``---DIVIDER---`` line and aligned entry-for-entry.

    Note:
        There is no upper bound on the number of LLM calls: a query that keeps
        failing is retried until ``_NUM_QUERY`` successes have been gathered.
    """
    max_function_attempts = 10

    utils.set_seed(_SEED)
    _, X_train, X_test, _, _, _, label_list, is_cat = utils.get_dataset(_DATA, _SEED)
    X_train_org, _ = fill_missing(X_train.copy(), X_test.copy())
    original_columns = set(X_train_org.columns)

    # Feature bagging: for wide tables build groups of 10 columns. The list is
    # reshuffled on every iteration, so the groups are independent draws and
    # may overlap.
    if len(X_train.columns) >= 20:
        total_column_list = []
        for i in range(len(X_train.columns) // 10):
            column_list = X_train.columns.tolist()
            random.shuffle(column_list)
            total_column_list.append(column_list[i*10:(i+1)*10])
    else:
        total_column_list = [X_train.columns.tolist()]

    meta_data_name = f"./data/{_DATA}-metadata.json"
    function_file_name = './template/ask_for_function.txt'

    rule_file_name = f'./LLM_results/{_RULE_DIR}/columns-{_DATA}-{_SEED}.out'
    saved_file_name = f'./LLM_results/{_RULE_DIR}/function-{_DATA}-{_SEED}.out'

    # Results for this dataset/seed were already produced by an earlier run.
    if os.path.isfile(saved_file_name):
        return

    results = []
    fct_strs = []
    prev_modules_list = []
    current_query_num = 0
    while current_query_num < _NUM_QUERY:
        print(f"Extracting columns {_DATA}/{_SEED} - {current_query_num}/{_NUM_QUERY}")
        # Ask llm to extract features
        if not prev_modules_list:
            ask_file_name = './template/ask_columns.txt'
            templates, feature_desc = utils.get_prompt_for_asking(
                _DATA, X_train, label_list, ask_file_name, meta_data_name, is_cat, num_col=_NUM_COL, num_query=1
            )
            template = templates[0]
        else:
            ask_file_name = './template/ask_columns_diversity.txt'
            template, feature_desc = utils.get_prompt_for_asking_with_diversity(
                _DATA, X_train, prev_modules_list, label_list,
                ask_file_name, meta_data_name, is_cat, total_column_list,
                current_query_num, num_col=_NUM_COL
            )

        result = utils.query_gpt([template], api_key=_API_KEY, max_tokens=1500, temperature=0.5, verbose=False)[0]

        # Parse text to feature generation function
        fct_str_handled = None
        X_train_new_col = None
        for _ in range(max_function_attempts):
            fct_template = utils.get_prompt_for_generating_function(
                [result], feature_desc, function_file_name
            )
            fct_result = utils.query_gpt(fct_template, api_key=_API_KEY, max_tokens=1500, temperature=0.5, verbose=False)

            try:
                # A reply without <start>/<end> markers is just another failed attempt.
                candidate = _wrap_generated_function(fct_result[0])

                # SECURITY: this executes code written by the LLM; only run it in
                # an environment where that is acceptable.
                # A fresh, explicit namespace avoids relying on exec() writing into
                # the frame's locals() (not guaranteed, and no longer the case on
                # Python 3.13) and prevents reusing a function from an earlier attempt.
                exec_namespace = {}
                exec(candidate, globals(), exec_namespace)

                # Hand over a copy: a generated function that assigns columns in
                # place would otherwise alter X_train_org and make the column
                # difference below empty.
                appended = exec_namespace['column_appender'](X_train_org.copy())

                # Keep the frame's own column order (set iteration order of
                # strings changes between interpreter runs).
                new_columns = list(dict.fromkeys(
                    col for col in appended.columns if col not in original_columns
                ))
                X_train_new_col = appended[new_columns]
            except Exception:
                continue

            fct_str_handled = candidate
            break

        if fct_str_handled is None: # Skip for failed cases
            continue

        try:
            # Rows of the LLM answer look like "| name | description | ..."; the
            # first description seen for a name wins.
            discovered_desc = {}
            for result_str in result.split('\n'):
                if '|' not in result_str:
                    continue
                result_str_list = result_str.split('|')
                column_name = result_str_list[1].strip()
                column_desc = result_str_list[2].strip()
                discovered_desc.setdefault(column_name, column_desc)

            # Built in a temporary so prev_modules_list stays untouched on failure.
            new_modules_list = [
                [new_column, discovered_desc.get(new_column, new_column)]
                for new_column in X_train_new_col.columns
            ]
        except Exception:
            continue

        # Commit only now, so both output files stay aligned and contain
        # exactly one entry per counted query.
        prev_modules_list = new_modules_list
        results.append(result)
        fct_strs.append(fct_str_handled)
        current_query_num += 1

    os.makedirs(f'./LLM_results/{_RULE_DIR}', exist_ok=True)

    with open(rule_file_name, 'w') as f:
        total_rules = "\n\n---DIVIDER---\n\n".join(results)
        f.write(total_rules)

    with open(saved_file_name, 'w') as f:
        total_str = "\n\n---DIVIDER---\n\n".join(fct_strs)
        f.write(total_str)

In [None]:
# Experiment configuration: columns requested per query, successful queries per
# dataset/seed, and the output sub-directory derived from both.
_NUM_COL = 5
_NUM_QUERY = 40
_RULE_DIR = f'diversity_{_NUM_COL}rules_{_NUM_QUERY}trials'
# Read the key from the environment so it is never saved inside the notebook;
# the placeholder is kept as a fallback. The variable name is a suggestion:
# run() simply forwards the string to utils.query_gpt.
_API_KEY = os.environ.get('OPENAI_API_KEY', '<Get your own API key>')

# Each dataset is listed once; run() skips dataset/seed pairs whose function
# file already exists, so the loop can be re-executed to resume.
for _DATA in [
    'adult', 'blood', 'tic-tac-toe', 'sequence-type', 'insurance', 'heart',
    'car', 'communities', 'credit-g', 'diabetes', 'bank', 'myocardial',
    'junglechess', 'housing', 'solution-mix', 'forest-fires', 'eucalyptus', 'balance-scale', 'vehicle'
]:
    for _SEED in [0, 1, 2]:
        run(_DATA, _SEED, _RULE_DIR, _NUM_COL, _NUM_QUERY, _API_KEY)