# Env

In [None]:
cd ..

In [None]:
import json
from tot.prompts.crosswords import propose_prompt, value_prompt
from tot.models import gpt
from tot.tasks.crosswords import MiniCrosswordsEnv

# Shared crossword environment: the prompt preview and the experiment loops
# below all reset/step this one global instance.
env = MiniCrosswordsEnv()

# Prompt

In [None]:
def prompt_wrap(obs):
    """Return the crossword propose prompt filled in with observation *obs*."""
    filled_prompt = propose_prompt.format(input=obs)
    return filled_prompt

# Sanity check: print the propose prompt built from the observation that
# env.reset(0) returns for puzzle 0.
print(prompt_wrap(env.reset(0)))
# Uncomment to also see the prompt after placing one word:
# print('---------')
# print(prompt_wrap(env.step('h2. value')[0]))

In [None]:
import re
import copy
from tot.models import gpt

def parse_line(input_str):
    """Split one proposal line such as ``h1. apple (certain)`` into its parts.

    Returns ``[position, word, confidence]``. The position is ``h1``-``v5``,
    the word has exactly five ASCII letters, and the confidence is one of
    certain/high/medium/low. Anything may follow the closing parenthesis.
    Returns ``None`` when the line does not match this format.
    """
    found = re.match(r'^([hv][1-5])\. ([a-zA-Z]{5,5}) \((certain|high|medium|low)\).*$', input_str)
    if found is None:
        return None
    return list(found.group(1, 2, 3))

# Score contributed by one proposal line, keyed by the confidence label that
# parse_line extracts; the LM-sample totals are sums of these values.
# The weights are hand-picked, not tuned.
confidence_to_value = {'certain': 1, 'high': 0.5, 'medium': 0.2, 'low': 0.1}  # TODO: ad hoc

def parse_response(response):
    """Extract scored word proposals from a multi-line LM response.

    Every line accepted by ``parse_line`` becomes a tuple
    ``('<position>. <word>', score)``. Both parts are lower-cased, and the score
    is looked up in ``confidence_to_value`` (0 if the label is unknown).
    Lines that do not match are skipped. Returns ``None`` instead of an
    empty list when no line matched.
    """
    proposals = []
    for raw_line in response.split('\n'):
        fields = parse_line(raw_line)
        if fields is None:
            continue
        position, word, confidence = fields
        action = position.lower() + '. ' + word.lower()
        proposals.append((action, confidence_to_value.get(confidence, 0)))
    return proposals or None


def get_candidates_to_scores(env):
    """Return LM-proposed words for env's current board with summed scores.

    The prompt built from ``env.render()`` is sent to GPT-4 with ``n=8``
    samples. Every well-formed proposal line in every sample adds its
    confidence value to that word's running total.

    Results are memoized in ``env.cache``, keyed by the rendered board.
    Later calls return the cached dict object itself, so treat it as
    read-only.
    """
    board_key = env.render()
    if board_key not in env.cache:
        print('call gpt')
        totals = {}
        for response in gpt(prompt_wrap(board_key), model='gpt-4', n=8):
            for candidate, score in parse_response(response) or []:
                totals[candidate] = totals.get(candidate, 0) + score
        env.cache[board_key] = totals
        return totals
    print('cache hit')
    return env.cache[board_key]

def propose_score(env, idx):
    """Greedily fill puzzle *idx*, taking one LM-proposed word per step.

    Each step samples GPT-4 five times (no caching) and sums the confidence
    scores per word. It then steps the best-scored word whose trial placement
    (on a deep copy of ``env``) leaves no status entry equal to 2. The loop
    ends when the env reports ``done`` or when no proposal could be parsed.
    Returns the list of ``info`` dicts from each real ``env.step``.
    """
    def leaves_no_conflict(candidate):
        # Try the word on a throwaway copy so the real env is untouched.
        trial_env = copy.deepcopy(env)
        trial_env.step(candidate)
        return not any(code == 2 for code in trial_env.status)

    infos = []
    obs = env.reset(idx)
    done = False
    while not done:
        scores = {}
        for response in gpt(prompt_wrap(obs), model='gpt-4', n=5):
            for candidate, score in parse_response(response) or []:
                scores[candidate] = scores.get(candidate, 0) + score
        # Highest summed score first.
        ranked = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        print(ranked)
        if not scores:
            break
        ordered = [candidate for candidate, _ in ranked]
        # NOTE(review): if every candidate conflicts, this falls back to the
        # last (lowest-scored) one and steps it anyway -- confirm intended.
        chosen = next((c for c in ordered if leaves_no_conflict(c)), ordered[-1])
        print(chosen)
        obs, r, done, info = env.step(chosen)
        print(obs)
        print(env.steps, info)
        print('-------------------\n\n\n')
        infos.append(info)
    return infos

# DFS

In [None]:
def dfs(env, actions, infos, time_limit, prune, max_per_state):
    """Depth-first search over LM-proposed crossword words.

    Candidates for the current board come from ``get_candidates_to_scores``
    and are tried in descending score order, each from the same backed-up
    state. Each accepted placement is logged to ``infos`` as a dict with keys
    ``total_step``, ``env_step``, ``actions``, ``info`` and ``count``. The
    search then recurses from the new board.

    Args:
        env: crossword environment. It is stepped during the search and reset
            to the backed-up board/status/steps after every candidate,
            including when the ``max_per_state`` cutoff stops the loop.
        actions: list holding the current path; appended/popped in place.
        infos: list that collects the log records (the search's result).
        time_limit: no placement is accepted once ``len(infos)`` reaches it.
        prune: if true, do not recurse from a state whose
            ``env.prompt_status()`` reports ``'impossible' >= 1``.
        max_per_state: maximum number of accepted placements per state.

    Returns:
        ``(0, [], [])`` when no candidates were proposed, otherwise ``None``.
    """
    # get candidate thoughts (word -> summed confidence score)
    candidates_to_scores = get_candidates_to_scores(env)
    if len(candidates_to_scores) == 0:
        return 0, [], []
    print(sorted(candidates_to_scores.items(), key=lambda x: x[1], reverse=True))

    # back up current state so every candidate starts from the same board
    board, status, steps = env.board.copy(), env.status.copy(), env.steps

    def restore():
        env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)

    # try each candidate, best score first
    cnt_per_state = 0
    for action in sorted(candidates_to_scores, key=candidates_to_scores.get, reverse=True):
        _, _, _, info = env.step(action)
        # Accept only while under the step budget, while env.steps < 10
        # (presumably 5 across + 5 down clues -- confirm), and when no status
        # entry is 2, i.e. not violating any existing constraints.
        if len(infos) < time_limit and env.steps < 10 and not any(code == 2 for code in env.status):
            cnt_per_state += 1
            if cnt_per_state > max_per_state:
                # Undo this step before leaving so env is not left mutated.
                restore()
                break
            count = env.prompt_status()
            actions.append(action)

            print(len(infos))
            print(actions)
            print(env.render_board())
            print(info)
            print(count)
            if infos:
                best = max(infos, key=lambda x: x['info']['r_word'])
                print('best', best)
            print('--------------')
            print()

            info = {'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count}
            infos.append(info)
            if not prune or count['impossible'] < 1:  # only continue if the current status is possible
                dfs(env, actions, infos, time_limit, prune, max_per_state)
            actions.pop()
        restore()

In [None]:
def crossword_execute(env, action_info, other_params):
    """Apply one queued crossword action for ``auto_search``.

    Restores the parent snapshot in ``action_info['env_info']`` (if any) and
    steps ``action_info['action']``. Accepted placements are logged into
    ``other_params['infos']`` with the same record format and acceptance
    rules as ``dfs``.

    Args:
        env: crossword environment (mutated in place).
        action_info: queue node with ``'action'``, ``'parent_actions'`` and
            ``'env_info'`` (a ``(board, status, steps)`` tuple, or None).
        other_params: dict with ``'infos'``, ``'time_limit'``, ``'prune'`` and
            ``'max_per_state'``. A ``'cnt_per_state'`` dict, keyed by the
            parent action path, is created on first use. Pass a fresh
            ``other_params`` per search.

    Returns:
        ``'generate'`` if the new state should be expanded. Otherwise
        ``'non-generate'``: the time limit is reached, env.steps >= 10, a
        status entry is 2, the parent already had ``max_per_state`` accepted
        children, or pruning found an impossible state.
    """
    action = action_info['action']
    if action_info['env_info'] is not None:
        board, status, steps = action_info['env_info']
        env.reset(env.idx, board=board.copy(), status=status.copy(), steps=steps)
    _, _, _, info = env.step(action)

    infos = other_params['infos']
    # Same acceptance test as dfs. A rejected state must not be expanded;
    # otherwise the search would never stop at time_limit.
    if len(infos) >= other_params['time_limit'] or env.steps >= 10 or any(code == 2 for code in env.status):
        return 'non-generate'

    # Per-parent quota of accepted children; dfs keeps this counter per
    # recursion frame. Skipping (rather than 'break') only drops this node
    # instead of ending the entire search.
    parent_key = tuple(action_info['parent_actions'])
    cnt_per_state = other_params.setdefault('cnt_per_state', {})
    cnt_per_state[parent_key] = cnt_per_state.get(parent_key, 0) + 1
    if cnt_per_state[parent_key] > other_params['max_per_state']:
        return 'non-generate'

    count = env.prompt_status()
    actions = list(action_info['parent_actions']) + [action]

    print(len(infos))
    print(actions)
    print(env.render_board())
    print(info)
    print(count)
    if infos:
        best = max(infos, key=lambda x: x['info']['r_word'])
        print('best', best)
    print('--------------')
    print()

    record = {'total_step': len(infos), 'env_step': env.steps, 'actions': actions.copy(), 'info': info, 'count': count}
    infos.append(record)
    if other_params['prune'] and count['impossible'] >= 1:  # only continue if the current status is possible
        return 'non-generate'
    return 'generate'

def crossword_generate(env, action_info, other_params):
    """Propose child actions for ``auto_search`` from the LM's scored words.

    Returns ``[]`` when nothing was proposed. Otherwise returns a dict whose
    ``result_list`` holds the candidates in ascending score order, so the
    last-pushed (highest-scored) one is what a plain stack ``pop()`` takes
    first. ``priority_list`` holds the matching scores.
    ``action_info`` and ``other_params`` are unused.
    """
    scored = get_candidates_to_scores(env)
    if not scored:
        return []
    ascending = sorted(scored.items(), key=lambda pair: pair[1])
    return {
        "result_list": [word for word, _ in ascending],
        "priority_list": [value for _, value in ascending],
    }

def auto_search(env, other_params, execute_func, generate_func, epsilon = 0.3, decay_rate = 0.9, sliding_window_size = None, heapify_queue_stack = False, queue_stack_valuate_func = None):
    """Generic epsilon-greedy tree search driven by pluggable callbacks.

    The frontier ``queue_stack`` starts with a root node and is used as a
    stack. Normally the most recently pushed node is expanded next
    (depth-first). With probability ``epsilon`` a random node is popped
    instead. ``epsilon`` is multiplied by ``decay_rate`` after each iteration
    that neither skips the node nor stops the search.

    Frontier entries are node dicts (``action``, ``parent_actions``,
    ``env_info``, ``level``) or ``(priority, node)`` tuples. The tuples are
    used when ``generate_func`` supplies priorities. A child's
    ``parent_actions`` is the full action path up to and including its
    parent's action.

    Args:
        env: environment exposing ``board``, ``status`` and ``steps``. Each
            child stores a ``(board.copy(), status.copy(), steps)`` snapshot
            so ``execute_func`` can restore the parent state.
        other_params: passed through unchanged to every callback.
        execute_func: ``(env, node, other_params)``; applies ``node['action']``
            and returns ``'generate'`` (expand), ``'non-generate'`` (skip) or
            ``'break'`` (stop the search).
        generate_func: ``(env, node, other_params)`` returning a list of
            actions, or a dict ``{'result_list': [...], 'priority_list': [...]}``
            (``priority_list`` may be None).
        epsilon: initial probability of popping a random node.
        decay_rate: multiplicative decay applied to ``epsilon``.
        sliding_window_size: limits random pops.
            - ``None``: the whole frontier.
            - ``int`` k: the k most recently pushed entries.
            - ``(lo, hi)`` tuple: an inclusive index range.
            Bounds are clamped to the frontier.
        heapify_queue_stack: if true, the random index is taken in min-heap
            order of the entries' priorities rather than insertion order.
        queue_stack_valuate_func: optional ``(env, queue_stack, other_params)``
            hook, called with the live frontier list before every pop.
    """
    import heapq
    import random

    def _node(entry):
        # Prioritised entries are ``(priority, node)`` tuples; unwrap them.
        return entry[1] if isinstance(entry, tuple) else entry

    def _priority(entry):
        # Un-prioritised entries (e.g. the root) get priority 0.
        return entry[0] if isinstance(entry, tuple) else 0

    def _window():
        # Inclusive (lo, hi) index range for random pops, clamped to bounds.
        last = len(queue_stack) - 1
        if sliding_window_size is None:
            lo, hi = 0, last
        elif isinstance(sliding_window_size, int):
            lo, hi = len(queue_stack) - sliding_window_size, last
        else:
            lo, hi = sliding_window_size
        lo = min(max(lo, 0), last)
        hi = min(max(hi, lo), last)
        return lo, hi

    def _child(action, path, level):
        return {
            'action': action,
            'parent_actions': path.copy(),
            'env_info': (env.board.copy(), env.status.copy(), env.steps),
            'level': level + 1,
        }

    queue_stack = [{'action': None, 'parent_actions': [], 'env_info': None, 'level': 0}]
    while queue_stack:
        if queue_stack_valuate_func is not None:
            queue_stack_valuate_func(env, queue_stack, other_params)
            if not queue_stack:  # the hook may have emptied the frontier
                break
        if random.random() < epsilon:
            lo, hi = _window()
            pick = random.randint(lo, hi)
            if heapify_queue_stack:
                # Heapify (priority, position) keys, not the entries: node
                # dicts are not orderable, and the position lets us remove
                # exactly the chosen entry.
                keyed = [(_priority(entry), pos) for pos, entry in enumerate(queue_stack)]
                heapq.heapify(keyed)
                pick = keyed[pick][1]
            entry = queue_stack.pop(pick)
        else:
            entry = queue_stack.pop()

        action_info = _node(entry)
        level = action_info['level']
        # Full path to this node, including its own action, for its children.
        path = list(action_info['parent_actions'])
        if action_info['action'] is not None:
            path.append(action_info['action'])

        to_generate_thoughts = 'generate'
        if action_info['action'] is not None:
            to_generate_thoughts = execute_func(env, action_info, other_params)

        if to_generate_thoughts == 'generate':
            new_thoughts = generate_func(env, action_info, other_params)
            if isinstance(new_thoughts, dict):
                results = new_thoughts.get('result_list') or []
                priorities = new_thoughts.get('priority_list')
            else:
                results, priorities = new_thoughts or [], None
            if priorities is None:
                children = [_child(action, path, level) for action in results]
            else:
                children = [(priority, _child(action, path, level))
                            for priority, action in zip(priorities, results)]
            if not children:
                continue
            queue_stack.extend(children)
        elif to_generate_thoughts == 'non-generate':
            continue
        elif to_generate_thoughts == 'break':
            break
        epsilon = epsilon * decay_rate

In [None]:
# dfs with pruning
# Runs the pruned DFS on every 5th puzzle (0, 5, ..., 95). Each run logs at
# most 100 steps and accepts at most 3 words per state. The JSON log is
# rewritten after each puzzle, so finished puzzles survive a later crash.
import os

# Create the log directory up front. Otherwise open(..., 'w') would only fail
# after the first (costly, LM-calling) search had already finished.
os.makedirs('logs/crosswords', exist_ok=True)
infoss = []
for i in range(0, 100, 5):
    env.reset(i)
    infos = []
    actions = []
    dfs(env, actions, infos, 100, prune=True, max_per_state=3)
    infoss.append(infos)
    with open('logs/crosswords/infoss_dfs_prune.json', 'w') as fout:
        json.dump(infoss, fout)

In [None]:
# dfs without pruning
# Same sweep as the pruned run (puzzles 0, 5, ..., 95; at most 100 logged
# steps; at most 3 words per state), but it recurses even from states that
# prompt_status() reports as impossible.
import os

# Create the log directory up front. Otherwise open(..., 'w') would only fail
# after the first (costly, LM-calling) search had already finished.
os.makedirs('logs/crosswords', exist_ok=True)
infoss = []
for i in range(0, 100, 5):
    env.reset(i)
    infos = []
    actions = []
    dfs(env, actions, infos, 100, prune=False, max_per_state=3)
    infoss.append(infos)
    with open('logs/crosswords/infoss_dfs_no_prune.json', 'w') as fout:
        json.dump(infoss, fout)

In [3]:
# Scratch check: list.pop() with no index removes from the end, so items come
# out in LIFO order (3, 2, 1). This is the same order auto_search's default
# (non-random) queue_stack.pop() uses.
things = [1,2,3]
while things:
    print('HAHAHA', things.pop())


HAHAHA 3
HAHAHA 2
HAHAHA 1
