In [None]:
import argparse

import crafter
import tqdm
import numpy as np

class Object(object):
    """Bare attribute container; instantiated below as the run configuration ``args``."""
    pass

# Run configuration, set as plain attributes.
args = Object()
args.outdir='logs/test_env'
# Number of steps between uploads of the buffered wandb tables in the rollout loop.
args.wandb_log_interval=10
# Upper bound on the number of environment steps taken by the rollout loop.
args.steps=1000000

In [None]:
# LLM backend: `query_model` is the callable that topological_traverse invokes
# with a list of chat messages; its return value is stripped and used as the answer.
from llm_api import get_query
query_model = get_query("gpt-4")
# Start the Weights & Biases run that the rollout loop logs rewards and tables to.
import wandb
wandb.init(project='SPRING')

In [None]:
import gym
import smartplay  # not referenced by name; presumably needed so the "smartplay:" env id resolves -- TODO confirm
# Environment driven by the rollout loop below.
env = gym.make("smartplay:Crafter-v0")
# Action names exposed by the env; an action is passed to env.step as its index in this list.
action_list = env.action_list

In [None]:
def match_act(string):
    """Translate free-form LLM output into an index into ``action_list``.

    Action names are tried in list order, and the first one that occurs
    anywhere in ``string`` (case-insensitive substring test) is chosen, so an
    earlier name shadows any later name that also occurs. When no name occurs
    at all, a message is printed and the index of "Do" is returned.
    """
    hit = next(
        (idx for idx, name in enumerate(action_list) if name.lower() in string.lower()),
        None,
    )
    if hit is not None:
        return hit

    # Nothing recognisable in the answer: report it and fall back to "Do".
    print("LLM failed with output \"{}\", taking action Do...".format(string))
    return action_list.index("Do")

In [None]:
def convert_to_game(act, info):
    """Return ``act`` unchanged; ``info`` is accepted but not used.

    The rollout loop passes the matched action index through this function
    before handing it to ``env.step``.
    """
    return act

In [None]:
# Individual questions of the reasoning chain. Each string is sent verbatim to
# the LLM and also serves as the node identifier in the dependency graph.
Q_LAST_ACTION = "What was the last action taken by the player?"
Q_LAST_SUCCESS = "Did the last player action succeed? If not, why?"
Q_LIST_OBJECTS = "List objects in the current observation. For each object, breifly answer what resource it provides and its requirement."
Q_OBJECT_REQS = "For each object in the list, are the requirements met for interaction?"
Q_SUBTASKS = "List top 3 sub-tasks the player should follow. Indicate their priority out of 5."
Q_TOP_SUBTASK = "What are the requirements for the top sub-task? What should the player do first?"
Q_TOP_ACTIONS = "List top 5 actions the player should take and the requirement for each action. Choose ONLY from the list of all actions. Indicate their priority out of 5."
Q_ACTION_REQS = "For each action in the list, are the requirements met?"

# Final question; its answer is the one parsed into a game action.
q_act = "Choose the best executable action from above."

# question -> questions whose (question, answer) pairs are shown to the LLM
# before it is asked. Questions that appear only on the right have no
# prerequisites of their own.
question_dependencies = {
    Q_LAST_SUCCESS: [Q_LAST_ACTION],
    Q_OBJECT_REQS: [Q_LIST_OBJECTS],
    Q_SUBTASKS: [Q_LIST_OBJECTS, Q_OBJECT_REQS],
    Q_TOP_SUBTASK: [Q_SUBTASKS],
    Q_TOP_ACTIONS: [Q_TOP_SUBTASK],
    Q_ACTION_REQS: [Q_TOP_ACTIONS],
    q_act: [Q_LAST_ACTION, Q_LAST_SUCCESS, Q_TOP_ACTIONS, Q_ACTION_REQS],
}

In [None]:
def compose_prompt(CTXT, text_obs, Q_CTXT, question):
    """Assemble the chat-message list for one question.

    Layout of the returned list:
      1. three system messages: the role instruction, ``CTXT``, and the
         recent observations ``text_obs``;
      2. for every ``(q, a)`` pair in ``Q_CTXT``, a user message ``q``
         followed by an assistant message ``a``;
      3. a final user message carrying ``question``.

    Returns:
        list of ``{"role": ..., "content": ...}`` dicts.
    """
    system_contents = [
        "You’re a player trying to play the game of crafter.",
        CTXT,
        "Most recent two steps of the player's in-game observation:\n{}".format(text_obs),
    ]
    messages = [{"role": "system", "content": content} for content in system_contents]

    # Replay previously answered questions as a user/assistant exchange.
    for earlier_question, earlier_answer in Q_CTXT:
        messages.extend([
            {"role": "user", "content": earlier_question},
            {"role": "assistant", "content": earlier_answer},
        ])

    messages.append({"role": "user", "content": question})
    return messages

In [None]:
def topological_sort(dependencies):
    """Order every node of a dependency graph, deterministically.

    Args:
        dependencies: dict mapping a node to the list of nodes it depends on.
            Nodes that appear only inside those lists are included too and
            are treated as having no dependencies.

    Returns:
        A list containing each node once, arranged so that every node comes
        BEFORE the nodes it depends on (the DFS post-order, reversed).

    Raises:
        ValueError: if the graph contains a cycle.
    """
    def dfs(node):
        if visited[node] == 1:
            raise ValueError("There is a cycle in the dependency graph.")
        if visited[node] == 0:
            visited[node] = 1
            for neighbor in dependencies.get(node, []):
                dfs(neighbor)
            visited[node] = 2
            result.append(node)

    # Keys first, then dependency-only nodes, each once, in insertion order.
    # Seeding from a set would make the result depend on string hash
    # randomisation and therefore differ from one interpreter run to the next.
    items = list(dict.fromkeys(
        list(dependencies) + [dep for deps in dependencies.values() for dep in deps]
    ))
    visited = {item: 0 for item in items}  # 0: unvisited, 1: visiting, 2: visited
    result = []

    for item in items:
        if visited[item] == 0:
            dfs(item)

    return result[::-1]


def topological_traverse(CTXT, text_obs, dependencies):
    """Ask the LLM every question in the graph, dependencies first.

    Before a question is asked, all questions it depends on are answered, and
    their ``(question, answer)`` pairs are passed to ``compose_prompt`` as
    prior context. Each question is asked exactly once.

    Args:
        CTXT: context string forwarded to ``compose_prompt``.
        text_obs: observation text forwarded to ``compose_prompt``.
        dependencies: dict mapping a question to the list of questions it
            depends on.

    Returns:
        dict mapping each question to the stripped LLM answer, with keys in
        the order the questions were asked.

    Raises:
        ValueError: if the graph contains a cycle.
    """

    def dfs(node):
        if visited[node] == 1:
            raise ValueError("There is a cycle in the dependency graph.")
        if visited[node] == 0:
            visited[node] = 1
            q_ctxt = []
            for neighbor in dependencies.get(node, []):
                if neighbor not in result:
                    dfs(neighbor)
                q_ctxt.append((neighbor, result[neighbor]))
            visited[node] = 2
            prompt = compose_prompt(CTXT, text_obs, q_ctxt, node)
            answer = query_model(prompt).strip()
            print("Question: {}\nAnswer: {}".format(node, answer))
            result[node] = answer

    # Keys first, then dependency-only nodes, each once, in insertion order.
    # Seeding from a set would make the order of LLM calls (and of the printed
    # transcript) depend on string hash randomisation.
    items = list(dict.fromkeys(
        list(dependencies) + [dep for deps in dependencies.values() for dep in deps]
    ))
    visited = {item: 0 for item in items}  # 0: unvisited, 1: visiting, 2: visited
    result = {}

    for item in items:
        if visited[item] == 0:
            dfs(item)

    return result

In [None]:
# Flat list of every question in the graph. A question's position in this list
# fixes its "Lvl-<i>" column in the wandb rollout table built by the loop below.
questions_lvls = topological_sort(question_dependencies)

In [None]:
# Keys read from info['achievements'] at every step. The order of this list is
# the column order of the wandb achievements table.
achievements = [
    # collect_*
    'collect_coal', 'collect_diamond', 'collect_drink', 'collect_iron',
    'collect_sapling', 'collect_stone', 'collect_wood',
    # defeat_*
    'defeat_skeleton', 'defeat_zombie',
    # eat_*
    'eat_cow', 'eat_plant',
    # make_*
    'make_iron_pickaxe', 'make_iron_sword',
    'make_stone_pickaxe', 'make_stone_sword',
    'make_wood_pickaxe', 'make_wood_sword',
    # place_*
    'place_furnace', 'place_plant', 'place_stone', 'place_table',
    # misc
    'wake_up',
]

In [None]:
# Rollout: at every step, show the LLM the two most recent text observations,
# walk the question graph, and execute the action named in the final answer.
done = True  # makes the first loop iteration call env.reset()
step = 0
_, info = env.reset()
CTXT = info['manual']  # sent as system context in every prompt (see compose_prompt)
trajectories = []  # (step, observation text) pairs; only the last two go into the prompt
R = 0  # cumulative reward
a = action_list.index("Noop")  # action index passed to the first env.step

columns=["Step", "OBS", "Score", "Reward"] + ["Lvl-{}: {}".format(i,q) for i, q in enumerate(questions_lvls)] + ["Action"]
wandb_table = wandb.Table(columns=columns)
achievement_table = wandb.Table(columns=achievements)

last_log = 0  # first step covered by the tables currently being filled

while step < args.steps:

    if done:
        env.reset()
    done = False
    a_row = [0] * len(questions_lvls)

    _, reward, done, info = env.step(a)
    R += reward

    print("=="*15, "Step: {}, Reward: {}".format(step, R), "=="*15)
    desc = info['obs']
    print(desc)
    new_row = [step, desc, R, reward]
    wandb.log({"reward": reward, "total reward": R})

    trajectories.append((step, desc))
    # Label each observation with its 1-based step number (previously `i+i`,
    # which doubled the step index instead of offsetting it).
    text_obs = "\n\n".join(["Player Observation Step {}:\n{}".format(i+1,d) for i, d in trajectories[-2:]])

    print("--"*10 + " QA " + "--"*10)
    results = topological_traverse(CTXT, text_obs, question_dependencies)

    # One table cell per question, positioned by the question's index in
    # questions_lvls. The loop variable is deliberately not named `a`, so the
    # current action index is not clobbered by an answer string.
    for q, answer in results.items():
        a_row[questions_lvls.index(q)] = answer
    answer_act = results[q_act]

    new_row = new_row + a_row

    a = match_act(answer_act)
    new_row.append(action_list[a])
    a = convert_to_game(a, info)
    step += 1
    wandb_table.add_data(*new_row)
    achievement_table.add_data(*[info['achievements'][k] for k in achievements])

    print()
    # Upload the buffered rows every `wandb_log_interval` steps, when the
    # episode ends, and on the last step of the budget so that a trailing
    # partial interval is not dropped.
    if step % args.wandb_log_interval == 0 or done or step >= args.steps:
        wandb.log({"rollout {}~{}".format(last_log, step-1): wandb_table, 
                   "achievements {}~{}".format(last_log, step-1): achievement_table, 
                  })
        wandb_table = wandb.Table(columns=columns)
        achievement_table = wandb.Table(columns=achievements)
        last_log = step
    if done:
        break

In [None]:
# Close the wandb run opened by wandb.init above.
wandb.finish()