In [2]:
import json
from collections import defaultdict
import torch
from vocab import Vocab

In [3]:
# to run just a single example
# Hard-coded path to the collected states of one trajectory; handy for trying
# the functions below on a single file without looping over a whole split.
states_path = '/root/data_alfred/json_feat_2.1.0/pick_and_place_simple-TennisRacket-None-Bed-303/trial_T20190906_193617_277654/pp/metadata_states.json'

# NOTE: `saved_states` is re-bound by every per-split extraction loop further down.
with open(states_path, 'r') as f:
    saved_states = json.load(f)

## Setup

In [4]:
# Per-object metadata fields that are compared between consecutive time steps
# to decide whether an object's state changed.
object_state_types = [
    'isToggled',
    'isBroken',
    'isFilledWithLiquid',
    'isDirty',
    'isUsedUp',
    'isCooked',
    'ObjectTemperature',
    'isSliced',
    'isOpen',
    'isPickedUp',
    'mass',
    'receptacleObjectIds',
]

In [5]:
# Vocabularies shared by every split processed below; the object-type
# vocabulary starts out holding only the padding token.
object_vocab = {'object_type': Vocab(['<<pad>>'])}

## Transformations

In [6]:
def get_object_states(obj, required_state_types):
    """Project one object's metadata onto the requested state fields.

    Args:
        obj: mapping holding the metadata of a single object.
        required_state_types: iterable of field names to extract.

    Returns:
        dict with one entry per requested field, in request order; fields
        that are missing from ``obj`` are mapped to ``None``.
    """
    return {
        field: (obj[field] if field in obj.keys() else None)
        for field in required_state_types
    }

def detect_state_change(last_obj_states, curr_obj_states, objectId_list):
    """Flag, per object id, whether its tracked state differs between two steps.

    Args:
        last_obj_states: mapping object id -> state dict at the previous step.
        curr_obj_states: mapping object id -> state dict at the current step.
        objectId_list: object ids to report on; the result follows this order.

    Returns:
        list of bool aligned with ``objectId_list``: ``False`` when the object
        is absent from both steps or its state is unchanged, ``True`` when it
        newly appeared or its state differs.

    Raises:
        Exception: if an object present at the previous step is missing from
            the current one.
    """
    changed_flags = []
    for object_id in objectId_list:
        seen_before = object_id in last_obj_states
        seen_now = object_id in curr_obj_states

        if seen_before and not seen_now:
            raise Exception('Objects should always appear in next time step.')

        if not seen_before:
            # absent in both steps -> False, newly appeared -> True
            changed_flags.append(seen_now)
        else:
            unchanged = last_obj_states[object_id] == curr_obj_states[object_id]
            changed_flags.append(False if unchanged else True)
    return changed_flags

In [7]:
def _any_flag_per_type(instance_flags, object_instance_list_sorted, object_type_list_sorted):
    """Collapse per-instance flags into one 0/1 flag per object type.

    The object type of an instance is the part of its id before the first '|'.
    The result is aligned with ``object_type_list_sorted`` by *name*, not by
    the order in which types happen to appear in the instance list. The two
    orders differ whenever one type name is a prefix of another: '|' sorts
    after lowercase letters, so a sorted id list has 'pencil|...' before
    'pen|...', while a sorted type list has 'pen' before 'pencil'.

    Args:
        instance_flags: sequence of truthy/falsy flags, one per instance.
        object_instance_list_sorted: instance ids, index-aligned with
            ``instance_flags``.
        object_type_list_sorted: object types defining the output order.

    Returns:
        list of int (0 or 1), one per entry of ``object_type_list_sorted``;
        1 when at least one instance of that type has a truthy flag.

    Raises:
        AssertionError: if the types found in the instance list are not
            exactly the types in ``object_type_list_sorted``.
    """
    seen_types = set()
    flagged_types = set()
    for i, instance_name in enumerate(object_instance_list_sorted):
        instance_typ = instance_name.split('|')[0]
        seen_types.add(instance_typ)
        if instance_flags[i]:
            flagged_types.add(instance_typ)

    # same sanity check the length assertion used to provide, but by name
    assert seen_types == set(object_type_list_sorted), \
        'object types of the instances do not match the object type list'

    return [int(typ in flagged_types) for typ in object_type_list_sorted]


def detect_type_state_change(instance_state_change, object_instance_list_sorted, object_type_list_sorted):
    """Per object type: did any instance of that type change state?

    Args:
        instance_state_change: per-instance change flags, index-aligned with
            ``object_instance_list_sorted``.
        object_instance_list_sorted: instance ids ('<type>|...').
        object_type_list_sorted: object types defining the output order.

    Returns:
        list of int (0 or 1) aligned with ``object_type_list_sorted``.
    """
    return _any_flag_per_type(instance_state_change, object_instance_list_sorted, object_type_list_sorted)


def detect_type_visibility(instance_visible, object_instance_list_sorted, object_type_list_sorted):
    """Per object type: is any instance of that type visible?

    Args:
        instance_visible: per-instance visibility flags, index-aligned with
            ``object_instance_list_sorted``.
        object_instance_list_sorted: instance ids ('<type>|...').
        object_type_list_sorted: object types defining the output order.

    Returns:
        list of int (0 or 1) aligned with ``object_type_list_sorted``.
    """
    return _any_flag_per_type(instance_visible, object_instance_list_sorted, object_type_list_sorted)

In [8]:
def shift_states(saved_states):
    """Pair every time step with the object metadata of the following step.

    Each returned state is a shallow copy of the input state in which
    'objects_metadata' is taken from step ``t + 1``; the final step keeps its
    own metadata. The input list and its dicts are not modified.

    Args:
        saved_states: list of per-time-step dicts, each with an
            'objects_metadata' entry.

    Returns:
        list of dicts with the same length as ``saved_states``.
    """
    final_t = len(saved_states) - 1
    shifted = []
    for t, state in enumerate(saved_states):
        entry = {key: val for key, val in state.items() if key != 'objects_metadata'}
        # look one step ahead, clamped at the final step
        source_t = t + 1 if t < final_t else t
        entry['objects_metadata'] = saved_states[source_t]['objects_metadata']
        shifted.append(entry)
    return shifted

In [9]:
def extract_states_for_model(shifted_states, saved_states, root):
    """Build the per-subgoal feature dictionary for one trajectory.

    Object instances are taken from the 'objects_metadata' of the last shifted
    state (ids lower-cased and sorted); object types are the sorted set of the
    id prefixes before the first '|'. Every time step is appended to the
    bucket of its 'subgoal_step'.

    Args:
        shifted_states: list of per-time-step dicts (see ``shift_states``);
            each needs 'subgoal_step', 'subgoal', 'new_subgoal' and
            'objects_metadata' (a list of dicts with 'objectId' and 'visible').
        saved_states: the unshifted states; only ``saved_states[0]`` is read,
            as the baseline for the first state-change comparison.
        root: unused; kept so existing callers keep working.

    Returns:
        dict with the object lists, their vocabulary indices, and for each
        subgoal bucket the lists 'subgoal', 'subgoal_i', 'subgoal_t',
        'overall_t', 'instance_visibile', 'instance_state_change',
        'type_visibile' and 'type_state_change'.

    Raises:
        ValueError: if a state contains an object id that is not present in
            the last shifted state.
        Exception: propagated from ``detect_state_change`` when an object
            disappears between consecutive steps.
    """
    objectId_list = sorted(o['objectId'].lower() for o in shifted_states[-1]['objects_metadata'])
    objectTyp_list = sorted(set(o.split('|')[0].lower() for o in objectId_list))
    # NOTE(review): buckets are indexed by 'subgoal_step', so this assumes the
    # steps are exactly 0..num_subgoals-1 -- confirm against the collector.
    num_subgoals = len(set(s['subgoal_step'] for s in shifted_states))

    def per_subgoal():
        # one independent, initially empty bucket per subgoal
        return [[] for _ in range(num_subgoals)]

    # Failures here (e.g. inside the vocabulary lookup) now propagate to the
    # caller; previously a bare `except` dropped into pdb and left `feat`
    # undefined. The two word2index calls stay in their original order.
    feat = {
        'subgoal': per_subgoal(),
        'subgoal_i': per_subgoal(),
        'subgoal_t': per_subgoal(),
        'overall_t': per_subgoal(),

        'objectInstanceList': objectId_list,
        'objectInstanceList_TypeNum': object_vocab['object_type'].word2index([o.split('|')[0] for o in objectId_list], train=True),
        'instance_visibile': per_subgoal(),  # [subgoal_i][subgoal timestep] = [T/F for ob in objectId_list]
        'instance_state_change': per_subgoal(),

        'objectTypeList': objectTyp_list,
        'objectTypeList_TypeNum': object_vocab['object_type'].word2index(objectTyp_list, train=True),
        'type_visibile': per_subgoal(),  # [subgoal_i][subgoal timestep] = [T/F for ob in objectTyp_list]
        'type_state_change': per_subgoal()
    }

    # id -> position lookup built once instead of list.index() per object per
    # step; setdefault keeps the first position, exactly like list.index()
    position_of = {}
    for pos, object_id in enumerate(objectId_list):
        position_of.setdefault(object_id, pos)

    # defined up front so a trajectory whose first state is not flagged as
    # 'new_subgoal' no longer hits an unbound local
    subgoal_t = 0

    last_obj_states = {obj['objectId'].lower(): get_object_states(obj, object_state_types)
                       for obj in saved_states[0]['objects_metadata']}
    for t, state in enumerate(shifted_states):

        subgoal_i = state['subgoal_step']
        subgoal = state['subgoal']

        if state['new_subgoal']:
            subgoal_t = 0

        feat['subgoal_i'][subgoal_i].append(subgoal_i)
        feat['subgoal_t'][subgoal_i].append(subgoal_t)
        feat['subgoal'][subgoal_i].append(subgoal)
        feat['overall_t'][subgoal_i].append(t)

        # get state change
        curr_obj_states = {obj['objectId'].lower(): get_object_states(obj, object_state_types)
                           for obj in state['objects_metadata']}

        # list same order as objectId_list
        state_change = detect_state_change(last_obj_states, curr_obj_states, objectId_list)
        last_obj_states = curr_obj_states

        feat['instance_state_change'][subgoal_i].append(state_change)

        # get visibility
        visible = [False for _ in objectId_list]
        for ob in state['objects_metadata']:
            object_id = ob['objectId'].lower()
            pos = position_of.get(object_id)
            if pos is None:
                # same exception type list.index() raised for an unknown id
                raise ValueError('{!r} is not in the object instance list'.format(object_id))
            visible[pos] = ob['visible']
        feat['instance_visibile'][subgoal_i].append(visible)

        # get type state change
        type_state_change = detect_type_state_change(state_change, objectId_list, objectTyp_list)
        feat['type_state_change'][subgoal_i].append(type_state_change)

        # get type visibility
        type_visible = detect_type_visibility(visible, objectId_list, objectTyp_list)
        feat['type_visibile'][subgoal_i].append(type_visible)

        subgoal_t += 1

    return feat

## TRAIN

In [10]:
# Results of the state-collection run for this split: the output paths that
# were collected successfully and the roots for which collection failed.
date = '20200511'
split_name = 'train'

success_paths_json = '/root/data_alfred/splits/collect_states_{}_{}_notebook_success_paths.json'.format(date, split_name)
failed_roots_json = '/root/data_alfred/splits/collect_states_{}_{}_notebook_failed_roots.json'.format(date, split_name)

with open(success_paths_json, 'r') as f:
    success_outpaths = json.load(f)

with open(failed_roots_json, 'r') as f:
    failed_roots = json.load(f)

print('success paths = ', len(success_outpaths))
print('failed paths = ', len(failed_roots))

success paths =  6505
failed paths =  69


In [11]:
# Extract model features for every successfully collected trajectory of the
# current split and write them next to the corresponding metadata file.
extract_feat_dum = []        # every feature dict, kept in memory (only its length is printed)
extract_feat_outpaths = []   # paths of the feature files written by this cell

for metadata_outpath in success_outpaths:

    with open(metadata_outpath, 'r') as f:
        saved_states = json.load(f)

    # trajectory root = everything before the trailing 'pp/' directory; rindex
    # keeps this right even if an earlier path component also ends in 'pp'
    root = metadata_outpath[:metadata_outpath.rindex('pp/')]

    shifted_states = shift_states(saved_states)
    extract_feat = extract_states_for_model(shifted_states, saved_states, root)
    extract_feat['metadata_path'] = metadata_outpath
    extract_feat['root'] = root

    extract_feat_outpath = metadata_outpath.replace('/metadata_states.json', '/extracted_feature_states.json')
    if extract_feat_outpath == metadata_outpath:
        # nothing was replaced: writing would overwrite the source metadata file
        raise ValueError('unexpected metadata path: {}'.format(metadata_outpath))
    with open(extract_feat_outpath, 'w') as f:
        json.dump(extract_feat, f)

    extract_feat_outpaths.append(extract_feat_outpath)
    extract_feat_dum.append(extract_feat)

print(len(extract_feat_dum))

6505


In [12]:
# Record which feature files were written for the current split.
# NOTE: this re-binds `date`; the split-loading cells set it again themselves.
date = '20200513'
outpaths_json = '/root/data_alfred/splits/extract_feat_states_{}_{}_notebook_outpaths.json'.format(date, split_name)
with open(outpaths_json, 'w') as f:
    json.dump(extract_feat_outpaths, f)

## VALID SEEN

In [15]:
# Results of the state-collection run for this split: the output paths that
# were collected successfully and the roots for which collection failed.
date = '20200511'
split_name = 'valid_seen'

success_paths_json = '/root/data_alfred/splits/collect_states_{}_{}_notebook_success_paths.json'.format(date, split_name)
failed_roots_json = '/root/data_alfred/splits/collect_states_{}_{}_notebook_failed_roots.json'.format(date, split_name)

with open(success_paths_json, 'r') as f:
    success_outpaths = json.load(f)

with open(failed_roots_json, 'r') as f:
    failed_roots = json.load(f)

print('success paths = ', len(success_outpaths))
print('failed paths = ', len(failed_roots))

success paths =  249
failed paths =  2


In [16]:
# Extract model features for every successfully collected trajectory of the
# current split and write them next to the corresponding metadata file.
extract_feat_dum = []        # every feature dict, kept in memory (only its length is printed)
extract_feat_outpaths = []   # paths of the feature files written by this cell

for metadata_outpath in success_outpaths:

    with open(metadata_outpath, 'r') as f:
        saved_states = json.load(f)

    # trajectory root = everything before the trailing 'pp/' directory; rindex
    # keeps this right even if an earlier path component also ends in 'pp'
    root = metadata_outpath[:metadata_outpath.rindex('pp/')]

    shifted_states = shift_states(saved_states)
    extract_feat = extract_states_for_model(shifted_states, saved_states, root)
    extract_feat['metadata_path'] = metadata_outpath
    extract_feat['root'] = root

    extract_feat_outpath = metadata_outpath.replace('/metadata_states.json', '/extracted_feature_states.json')
    if extract_feat_outpath == metadata_outpath:
        # nothing was replaced: writing would overwrite the source metadata file
        raise ValueError('unexpected metadata path: {}'.format(metadata_outpath))
    with open(extract_feat_outpath, 'w') as f:
        json.dump(extract_feat, f)

    extract_feat_outpaths.append(extract_feat_outpath)
    extract_feat_dum.append(extract_feat)

print(len(extract_feat_dum))

249


In [17]:
# Record which feature files were written for the current split.
# NOTE: this re-binds `date`; the split-loading cells set it again themselves.
date = '20200513'
outpaths_json = '/root/data_alfred/splits/extract_feat_states_{}_{}_notebook_outpaths.json'.format(date, split_name)
with open(outpaths_json, 'w') as f:
    json.dump(extract_feat_outpaths, f)

## VALID UNSEEN

In [19]:
# Results of the state-collection run for this split: the output paths that
# were collected successfully and the roots for which collection failed.
date = '20200511'
split_name = 'valid_unseen'

success_paths_json = '/root/data_alfred/splits/collect_states_{}_{}_notebook_success_paths.json'.format(date, split_name)
failed_roots_json = '/root/data_alfred/splits/collect_states_{}_{}_notebook_failed_roots.json'.format(date, split_name)

with open(success_paths_json, 'r') as f:
    success_outpaths = json.load(f)

with open(failed_roots_json, 'r') as f:
    failed_roots = json.load(f)

print('success paths = ', len(success_outpaths))
print('failed paths = ', len(failed_roots))

success paths =  254
failed paths =  1


In [20]:
# Extract model features for every successfully collected trajectory of the
# current split and write them next to the corresponding metadata file.
extract_feat_dum = []        # every feature dict, kept in memory (only its length is printed)
extract_feat_outpaths = []   # paths of the feature files written by this cell

for metadata_outpath in success_outpaths:

    with open(metadata_outpath, 'r') as f:
        saved_states = json.load(f)

    # trajectory root = everything before the trailing 'pp/' directory; rindex
    # keeps this right even if an earlier path component also ends in 'pp'
    root = metadata_outpath[:metadata_outpath.rindex('pp/')]

    shifted_states = shift_states(saved_states)
    extract_feat = extract_states_for_model(shifted_states, saved_states, root)
    extract_feat['metadata_path'] = metadata_outpath
    extract_feat['root'] = root

    extract_feat_outpath = metadata_outpath.replace('/metadata_states.json', '/extracted_feature_states.json')
    if extract_feat_outpath == metadata_outpath:
        # nothing was replaced: writing would overwrite the source metadata file
        raise ValueError('unexpected metadata path: {}'.format(metadata_outpath))
    with open(extract_feat_outpath, 'w') as f:
        json.dump(extract_feat, f)

    extract_feat_outpaths.append(extract_feat_outpath)
    extract_feat_dum.append(extract_feat)

print(len(extract_feat_dum))

254


In [21]:
# Record which feature files were written for the current split.
# NOTE: this re-binds `date`; the split-loading cells set it again themselves.
date = '20200513'
outpaths_json = '/root/data_alfred/splits/extract_feat_states_{}_{}_notebook_outpaths.json'.format(date, split_name)
with open(outpaths_json, 'w') as f:
    json.dump(extract_feat_outpaths, f)

## SAVE OBJECT VOCAB

In [25]:
# save vocab in dout path
import os

vocab_dout_dir = '/root/data_alfred/json_feat_2.1.0'
vocab_filename = '%s.vocab' % 'objects'
vocab_dout_path = os.path.join(vocab_dout_dir, vocab_filename)

# the whole {'object_type': Vocab} dict is serialized in one file
torch.save(object_vocab, vocab_dout_path)

In [26]:
# Display the object-type vocabulary after all three splits were processed
# (presumably grown by the word2index(..., train=True) calls -- confirm in vocab.Vocab).
object_vocab['object_type']

Vocab(109)