In [43]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pickle
import json
import ipdb
from overcooked_ai_py.mdp.overcooked_mdp import OvercookedState, PlayerState, OvercookedGridworld, ObjectState
%matplotlib inline
%config Completer.use_jedi = False

In [2]:
# Load the cleaned human-trial dataframe. The `with` block closes the file handle once loaded.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open("../human_aware_rl/data/human/anonymized/clean_train_trials.pkl", "rb") as pkl_file:
    df = pickle.load(pkl_file)

In [87]:
# Select one (worker, layout) pair and order its rows by game tick.
# The worker and the layout are chosen independently, so the pair may not exist in the data.
trial_id = df.workerid_num.unique()[1]
layout_id = df.layout_name.unique()[1]
trial_data = df[(df.workerid_num == trial_id) & (df.layout_name == layout_id)]
trial_data = trial_data.sort_values('cur_gameloop')
assert len(trial_data) > 0, f"no rows for worker {trial_id!r} on layout {layout_id!r}"

In [88]:
def get_player_states(state, grid):
    """Build one PlayerState per entry of ``state["players"]``.

    NOTE: unused right now. ``grid`` is accepted but not read.
    The held object is forwarded only when the player dict has a 'held_object' key.
    It is passed through as-is (for raw dataset dicts that is presumably a plain dict,
    not an ObjectState — confirm before relying on this).
    """
    def _build(player_dict):
        ctor_args = (player_dict["position"], player_dict["orientation"])
        if "held_object" in player_dict:
            ctor_args = ctor_args + (player_dict["held_object"],)
        return PlayerState(*ctor_args)

    return [_build(player_dict) for player_dict in state["players"]]
    
def is_adjacent_to_serve(state, player, grid):
    """Return True if ``player`` is facing a serving location.

    Args:
        state: raw state dict. ``state["players"][player]`` must have 'position'
            and 'orientation' coordinate pairs.
        player: index of the player in ``state["players"]``.
        grid: object exposing ``get_serving_locations()`` (e.g. OvercookedGridworld).

    Returns:
        bool: whether position + orientation lands on a serving location.
        Returns False when the layout has no serving locations at all.
    """
    player_dict = state["players"][player]
    # The tile the player is facing: position shifted by the orientation vector.
    facing = tuple(p + o for p, o in zip(player_dict["position"], player_dict["orientation"]))
    # Set membership avoids numpy broadcasting, which breaks on an empty list of locations.
    serve_locs = {tuple(loc) for loc in grid.get_serving_locations()}
    return facing in serve_locs

N_ACTIONS = 8

# High-level action ids (encoded as integers here; one-hot encoding is done later):
# 0: pick up onion
# 1: pick up dish
# 2: pick up soup (from counter)
# 3: get cooked soup (from pot)
# 4: put down onion
# 5: put down dish
# 6: put down soup (on counter)
# 7: serve soup
# (held item before, held item after) -> high-level action id.
# soup -> nothing (6 vs 7) is resolved separately, because it depends on what the player faces.
_ITEM_TRANSITION_TO_HL_ACTION = {
    (None, "onion"): 0,
    (None, "dish"): 1,
    (None, "soup"): 2,
    ("dish", "soup"): 3,
    ("onion", None): 4,
    ("dish", None): 5,
}


def _held_item_name(player_dict):
    """Return the held object's 'name', or None if 'held_object' is missing or None."""
    held = player_dict.get("held_object")
    return None if held is None else held["name"]


def get_hl_action(state, next_state, player, grid):
    """Map the change in ``player``'s held item between two states to a high-level action id.

    Only the held item is inspected (possible items: onion, soup, dish). A missing
    'held_object' key is treated the same as holding nothing.

    Args:
        state, next_state: raw state dicts with a "players" list.
        player: index into ``state["players"]``.
        grid: passed to ``is_adjacent_to_serve`` to split soup placement into serve vs counter.

    Returns:
        int in [0, N_ACTIONS) as listed above, or None if the transition is not a high-level action.
    """
    curr_item = _held_item_name(state["players"][player])
    next_item = _held_item_name(next_state["players"][player])

    if curr_item == "soup" and next_item is None:
        # Soup was either served or put on a counter; decide by the tile being faced.
        return 7 if is_adjacent_to_serve(state, player, grid) else 6
    return _ITEM_TRANSITION_TO_HL_ACTION.get((curr_item, next_item))

def fix_held_obj(state):
    """
    Ensure every dict in state["players"] has a 'held_object' key, mutating in place.

    A missing key is set to None; an existing value (including None) is left as-is.
    Returns None.
    """
    for player_dict in state["players"]:
        player_dict.setdefault("held_object", None)

def create_obj_states(objects):
    """Convert a raw {position-string: object-dict} mapping into {position-tuple: ObjectState}.

    Later code requires tuple keys rather than strings, so each key is passed through eval().
    NOTE(review): eval executes arbitrary code. It is acceptable only for trusted data;
    ast.literal_eval would be a safer drop-in if keys are always plain tuple literals.
    """
    parsed = {}
    for pos_str, obj_dict in objects.items():
        obj_state = ObjectState.from_dict(obj_dict)
        parsed[eval(pos_str)] = obj_state
    return parsed
            
def _to_overcooked_state(state):
    """Build an OvercookedState for the first two players of a raw state dict.

    The player dicts must already have 'held_object' keys; PlayerState.from_dict errors otherwise.
    """
    player_states = [PlayerState.from_dict(state["players"][player]) for player in range(2)]
    # NOTE: OvercookedState.from_dict fails because the raw dict has an extra key
    # 'pot_explosion', so the state is assembled field by field.
    return OvercookedState(players=player_states,
                           objects=create_obj_states(state["objects"]),
                           order_list=state["order_list"])


def extract_hl_actions(data):
    """Extract the high-level INTERACT actions of both players from one trial.

    Movement actions are ignored. Only INTERACT steps that change a player's held item
    (see get_hl_action) are recorded.

    Args:
        data: DataFrame of one trial's transitions, with columns 'time_elapsed', 'layout',
            'state', 'next_state' and 'joint_action'. The last four hold string reprs that
            are eval'd. NOTE(review): eval runs arbitrary code, so only use this on trusted data.

    Returns:
        (times, hl_actions, state_encodings):
            times[p]: step indices (row positions after sorting by 'time_elapsed') at which
                player p took a high-level action. Each one indexes into state_encodings.
            hl_actions[p]: the matching action ids from get_hl_action.
            state_encodings: grid.lossless_state_encoding(...) of every row's state, in order.
    """
    data = data.sort_values('time_elapsed')
    grid = None
    times = [[], []]
    hl_actions = [[], []]
    state_encodings = []
    for t, (_, row) in enumerate(data.iterrows()):
        if grid is None:
            # The grid is built once from the first row; every row is assumed to share a layout.
            grid = OvercookedGridworld.from_grid(eval(row["layout"]))
        state = eval(row["state"])
        next_state = eval(row["next_state"])

        # NOTE: loading PlayerState from dict errors if the held_object key does not exist.
        # get_hl_action also reads next_state's held objects.
        fix_held_obj(state)
        fix_held_obj(next_state)
        state_encodings.append(grid.lossless_state_encoding(_to_overcooked_state(state)))

        joint_action = eval(row["joint_action"])  # parsed once, shared by both players
        for player in range(2):
            if joint_action[player] != "INTERACT":
                continue
            hl_action = get_hl_action(state, next_state, player, grid)
            if hl_action is not None:
                times[player].append(t)
                hl_actions[player].append(hl_action)
    return times, hl_actions, state_encodings

In [89]:
# times[p] / hl_actions[p]: step indices and high-level action ids for player p.
# state_encodings: one lossless state encoding per row of trial_data (sorted by time_elapsed).
times, hl_actions, state_encodings = extract_hl_actions(trial_data)

In [90]:
# High-level action id sequence for player 0 (ids as listed in get_hl_action).
print(hl_actions[0])

[0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 1, 3, 7, 0, 4, 1, 3, 7, 0, 4, 0, 4, 0, 4, 0, 4, 1, 3, 7, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 0, 4, 1, 3, 7, 0, 4, 0, 4, 0, 4, 1, 3, 7, 0, 4, 0, 4, 0, 4, 1, 3, 7, 0, 4, 0, 4, 0, 4, 1, 3, 7, 0, 4, 0, 4, 0, 4, 0, 4, 0]


In [83]:
# Inspect the encoding of the first step. Its length is 2 (presumably one entry per player — confirm).
len(state_encodings[0])

2

In [12]:
# Debug view: for one player, print the game tick, the held item before and after, and the action.
player = 1
for index, row in trial_data.iterrows():
    time = row["cur_gameloop"]
    # NOTE(review): eval runs arbitrary code; acceptable only because the dataset is trusted.
    action = eval(row["joint_action"])[player]
    state = eval(row["state"])
    next_state = eval(row["next_state"])

    # 'held_object' is absent when the player holds nothing; treat an explicit None the same way.
    curr_obj = state["players"][player].get("held_object")
    next_obj = next_state["players"][player].get("held_object")
    curr_item = None if curr_obj is None else curr_obj["name"]
    next_item = None if next_obj is None else next_obj["name"]
    print(time, curr_item, next_item, action)

0.0 None None [0, 0]
1.0 None None [0, 0]
2.0 None None [0, 0]
3.0 None None [0, 0]
4.0 None None [0, 0]
5.0 None None [0, 0]
6.0 None None [0, 0]
7.0 None None [0, -1]
8.0 None None [0, 0]
9.0 None None [0, -1]
10.0 None None [0, -1]
11.0 None None [0, 0]
12.0 None None [-1, 0]
13.0 None None [0, 0]
14.0 None None [0, 0]
15.0 None onion INTERACT
16.0 onion onion [0, 0]
17.0 onion onion [0, 0]
18.0 onion onion [0, 1]
19.0 onion onion [0, 0]
20.0 onion onion [1, 0]
21.0 onion onion [0, 0]
22.0 onion onion INTERACT
23.0 onion onion [0, 0]
24.0 onion onion [0, 0]
25.0 onion onion [1, 0]
26.0 onion None INTERACT
27.0 None None [0, 0]
28.0 None None INTERACT
29.0 None None [0, 0]
30.0 None None [0, 0]
31.0 None None [-1, 0]
32.0 None None [-1, 0]
33.0 None None [0, 0]
34.0 None None [-1, 0]
35.0 None None [0, 0]
36.0 None None [0, 0]
37.0 None None [0, -1]
38.0 None None [0, -1]
39.0 None None [-1, 0]
40.0 None None [0, 0]
41.0 None onion INTERACT
42.0 onion onion [0, 0]
43.0 onion onion [0

529.0 soup soup [0, -1]
530.0 soup soup [0, -1]
531.0 soup soup [0, -1]
532.0 soup None INTERACT
533.0 None None [0, 0]
534.0 None None [0, 0]
535.0 None None [0, 0]
536.0 None None [0, 0]
537.0 None None INTERACT
538.0 None None [0, 1]
539.0 None None [0, 0]
540.0 None None [0, 1]
541.0 None dish INTERACT
542.0 dish dish [0, 0]
543.0 dish dish [0, -1]
544.0 dish dish [0, -1]
545.0 dish dish [0, -1]
546.0 dish dish [1, 0]
547.0 dish dish [0, 0]
548.0 dish soup INTERACT
549.0 soup soup [0, 0]
550.0 soup soup [0, 0]
551.0 soup soup [0, -1]
552.0 soup soup [0, 0]
553.0 soup soup [0, 0]
554.0 soup soup [0, 0]
555.0 soup soup [0, 0]
556.0 soup soup [0, 0]
557.0 soup None INTERACT
558.0 None None [0, 0]
559.0 None None [0, 0]
560.0 None None [0, 1]
561.0 None None [0, 0]
562.0 None dish INTERACT
563.0 dish dish [0, 0]
564.0 dish dish [0, 0]
565.0 dish dish [1, 0]
566.0 dish dish [0, 0]
567.0 dish dish [0, 0]
568.0 dish dish [0, 0]
569.0 dish dish [0, 0]
570.0 dish dish [0, 0]
571.0 dish dish

1028.0 dish dish [0, 0]
1029.0 dish dish [0, 0]
1030.0 dish dish [0, 0]
1031.0 dish dish [0, 0]
1032.0 dish dish [0, 0]
1033.0 dish dish [0, 0]
1034.0 dish soup INTERACT
1035.0 soup soup [0, 0]
1036.0 soup soup [0, -1]
1037.0 soup soup [0, 0]
1038.0 soup None INTERACT
1039.0 None None [0, 0]
1040.0 None None [0, 1]
1041.0 None None [0, 0]
1042.0 None dish INTERACT
1043.0 dish dish INTERACT
1044.0 dish dish [1, 0]
1045.0 dish dish [0, 0]
1046.0 dish dish [0, 0]
1047.0 dish dish [0, 0]
1048.0 dish dish [0, 0]
1049.0 dish dish [0, 0]
1050.0 dish dish [0, 0]
1051.0 dish dish [0, 0]
1052.0 dish dish [0, 0]
1053.0 dish dish [0, 0]
1054.0 dish dish [0, 0]
1055.0 dish dish [0, 0]
1056.0 dish dish [0, 0]
1057.0 dish dish [0, 0]
1058.0 dish dish [0, 0]
1059.0 dish soup INTERACT
1060.0 soup soup [0, 0]
1061.0 soup soup [0, -1]
1062.0 soup soup [0, 0]
1063.0 soup None INTERACT
1064.0 None None [0, 0]
1065.0 None None [0, 0]
1066.0 None None [0, 1]
1067.0 None None [0, 0]
1068.0 None dish INTERACT
