In [307]:
import json
from collections import defaultdict

In [308]:
states_path = '/root/data_alfred/json_feat_2.1.0/pick_and_place_simple-TennisRacket-None-Bed-303/trial_T20190906_193617_277654/pp/metadata_states.json'

with open(states_path, 'r') as f:
    saved_states = json.load(f)

## Setup

In [309]:
object_state_types = ['isToggled', 'isBroken', 'isFilledWithLiquid', 'isDirty',
                      'isUsedUp', 'isCooked', 'ObjectTemperature', 'isSliced',
                      'isOpen', 'isPickedUp', 'mass', 'receptacleObjectIds']

In [None]:
date = '20200511'
split_name = 'train'

In [None]:
import logging
logging.basicConfig(filename='/root/data_alfred/splits/collect_states_{}_{}.log'.format(date, split_name),level=logging.DEBUG)

## Try on one task

In [310]:
# shift all the states to the last time step
shifted_states = [{} for _ in range(len(saved_states))]

for t, state in enumerate(saved_states):
    new_state = {k:v for k,v in state.items() if k != 'objects_metadata'}
    if t < len(saved_states) - 1:
        new_state['objects_metadata'] = saved_states[t+1]['objects_metadata']
    elif t == len(saved_states) - 1:
        new_state['objects_metadata'] = saved_states[t]['objects_metadata']
    shifted_states[t] = new_state

In [311]:
def get_object_states(obj, required_state_types):
    o_states = {}
    for state_typ in required_state_types:
        if state_typ in obj.keys():
            o_states[state_typ] = obj[state_typ]
        else:
            o_states[state_typ] = None
    return o_states

def detect_state_change(last_obj_states, curr_obj_states, objectId_list):
    state_change = []
    for obj_Id in objectId_list:
        if (not obj_Id in last_obj_states) and (not obj_Id in curr_obj_states):
            state_change.append(False)
        elif not obj_Id in last_obj_states:
            state_change.append(True)
        elif not obj_Id in curr_obj_states:
            raise Exception('Objects should always appear in next time step.')
        else:
            if last_obj_states[obj_Id] == curr_obj_states[obj_Id]:
                state_change.append(False)
            else:
                state_change.append(True)
    return state_change

In [312]:
def detect_type_state_change(instance_state_change, object_instance_list_sorted, object_type_list_sorted):
    state_change = []
    
    curr_typ = object_instance_list_sorted[0].split('|')[0]
    changes = []
    for i, instance_name in enumerate(object_instance_list_sorted):
        instance_typ = instance_name.split('|')[0]
        if instance_typ != curr_typ:
            state_change.append(True in changes)
            changes = []
            curr_typ = instance_typ
        changes.append(instance_state_change[i])
            
    # don't forget about last typ!
    state_change.append(True in changes)
    assert len(state_change) == len(object_type_list_sorted)
    return state_change

def detect_type_visibility(instance_visible, object_instance_list_sorted, object_type_list_sorted):
    type_visible = []
    
    curr_typ = object_instance_list_sorted[0].split('|')[0]
    visible = []
    for i, instance_name in enumerate(object_instance_list_sorted):
        instance_typ = instance_name.split('|')[0]
        if instance_typ != curr_typ:
            type_visible.append(True in visible)
            visible = []
            curr_typ = instance_typ
        visible.append(instance_visible[i])
            
    # don't forget about last typ!
    type_visible.append(True in visible)
    assert len(type_visible) == len(object_type_list_sorted)
    return type_visible

In [313]:
objectId_list = sorted([o['objectId'] for o in shifted_states[-1]['objects_metadata']])
objectTyp_list = sorted(set([o.split('|')[0] for o in objectId_list]))
num_subgoals = len(set([s['subgoal_step'] for s in shifted_states]))

feat = {
    'subgoal': [[] for _ in range(num_subgoals)],
    'subgoal_i': [[] for _ in range(num_subgoals)],
    'subgoal_t': [[] for _ in range(num_subgoals)],
    'overall_t': [[] for _ in range(num_subgoals)],
    
    'objectInstanceList': objectId_list,
    'instance_visibile': [[] for _ in range(num_subgoals)], # [subgoal_i][subgoal timestep] = [T/F for ob in objectId_list]
    'instance_state_change': [[] for _ in range(num_subgoals)],
    
    'objectTypeList': objectTyp_list,
    'type_visibile': [[] for _ in range(num_subgoals)], # [subgoal_i][subgoal timestep] = [T/F for ob in objectTyp_list]
    'type_state_change': [[] for _ in range(num_subgoals)]
}

last_obj_states = {obj['objectId']:get_object_states(obj, object_state_types) for obj in saved_states[0]['objects_metadata']}
for t, state in enumerate(shifted_states):
    
    subgoal_i = state['subgoal_step']
    subgoal = state['subgoal']
    
    if state['new_subgoal']:
        subgoal_t = 0

    feat['subgoal_i'][subgoal_i].append(subgoal_i)
    feat['subgoal_t'][subgoal_i].append(subgoal_t)
    feat['subgoal'][subgoal_i].append(subgoal)
    feat['overall_t'][subgoal_i].append(t)
        
    # get state change
    curr_obj_states = {obj['objectId']:get_object_states(obj, object_state_types) for obj in state['objects_metadata']}
    # list same order as objectId_list
    state_change = detect_state_change(last_obj_states, curr_obj_states, objectId_list)
    last_obj_states = curr_obj_states
    
    feat['instance_state_change'][subgoal_i].append(state_change)
    
    # get visibility
    visible = [False for _ in objectId_list]
    for ob in state['objects_metadata']:
        pos = objectId_list.index(ob['objectId'])
        visible[pos] = ob['visible']
    feat['instance_visibile'][subgoal_i].append(visible)
    
    # get type state change
    type_state_change = detect_type_state_change(state_change, objectId_list, objectTyp_list)
    feat['type_state_change'][subgoal_i].append(type_state_change)
    
    # get type visibility
    type_visible = detect_type_visibility(visible, objectId_list, objectTyp_list)
    feat['type_visibile'][subgoal_i].append(type_visible)
    
    subgoal_t += 1

In [314]:
feat['subgoal']

[['GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation'],
 ['PickupObject'],
 ['GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation',
  'GotoLocation'],
 ['PutObject'],
 ['NoOp']]

In [315]:
feat.keys()

dict_keys(['subgoal', 'subgoal_i', 'subgoal_t', 'overall_t', 'objectInstanceList', 'instance_visibile', 'instance_state_change', 'objectTypeList', 'type_visibile', 'type_state_change'])

In [316]:
feat['objectInstanceList']

['AlarmClock|-01.85|+00.50|-01.17',
 'BaseballBat|-01.63|+00.65|+00.42',
 'Bed|-01.05|+00.00|-01.85',
 'Blinds|+00.45|+02.16|-02.85',
 'Book|+00.34|+00.72|-02.65',
 'Box|+02.00|+00.32|-01.04',
 'CD|+00.18|+00.41|-02.55',
 'CD|+00.74|+00.72|-02.55',
 'CellPhone|+00.63|+00.11|-02.55',
 'CellPhone|-01.58|+00.60|-00.39',
 'Chair|-01.56|+00.00|-00.31',
 'Cloth|+02.18|+00.00|-00.57',
 'CreditCard|+00.65|+00.73|-02.47',
 'CreditCard|-01.52|+00.60|+00.03',
 'DeskLamp|-01.85|+00.59|-00.79',
 'Desk|-01.71|+00.00|-00.37',
 'Drawer|+00.45|+00.20|-02.43',
 'Drawer|+00.45|+00.50|-02.43',
 'Drawer|-01.79|+00.10|-01.15',
 'Drawer|-01.79|+00.25|-01.15',
 'Drawer|-01.79|+00.39|-01.15',
 'Dresser|+00.44|+00.04|-02.65',
 'GarbageCan|+02.22|+00.00|+00.31',
 'KeyChain|+00.08|+00.73|-02.53',
 'KeyChain|+00.77|+00.73|-02.71',
 'KeyChain|-01.85|+00.60|-00.39',
 'Laptop|-01.76|+00.59|-00.11',
 'LightSwitch|-00.64|+01.15|+00.50',
 'Mirror|-01.07|+01.47|+00.50',
 'Mug|+00.08|+00.73|-02.65',
 'Pencil|+00.65|+00.73

In [322]:
feat['subgoal'][1]

['PickupObject']

In [321]:
list(zip(feat['objectInstanceList'], feat['instance_state_change'][1][0]))

[('AlarmClock|-01.85|+00.50|-01.17', False),
 ('BaseballBat|-01.63|+00.65|+00.42', False),
 ('Bed|-01.05|+00.00|-01.85', False),
 ('Blinds|+00.45|+02.16|-02.85', False),
 ('Book|+00.34|+00.72|-02.65', False),
 ('Box|+02.00|+00.32|-01.04', False),
 ('CD|+00.18|+00.41|-02.55', False),
 ('CD|+00.74|+00.72|-02.55', False),
 ('CellPhone|+00.63|+00.11|-02.55', False),
 ('CellPhone|-01.58|+00.60|-00.39', False),
 ('Chair|-01.56|+00.00|-00.31', False),
 ('Cloth|+02.18|+00.00|-00.57', False),
 ('CreditCard|+00.65|+00.73|-02.47', False),
 ('CreditCard|-01.52|+00.60|+00.03', False),
 ('DeskLamp|-01.85|+00.59|-00.79', False),
 ('Desk|-01.71|+00.00|-00.37', False),
 ('Drawer|+00.45|+00.20|-02.43', False),
 ('Drawer|+00.45|+00.50|-02.43', False),
 ('Drawer|-01.79|+00.10|-01.15', False),
 ('Drawer|-01.79|+00.25|-01.15', False),
 ('Drawer|-01.79|+00.39|-01.15', False),
 ('Dresser|+00.44|+00.04|-02.65', False),
 ('GarbageCan|+02.22|+00.00|+00.31', False),
 ('KeyChain|+00.08|+00.73|-02.53', False),
 ('K

In [319]:
feat['root']

KeyError: 'root'

In [132]:
objectTyp_list

['AlarmClock',
 'BaseballBat',
 'Bed',
 'Blinds',
 'Book',
 'Box',
 'CD',
 'CellPhone',
 'Chair',
 'Cloth',
 'CreditCard',
 'Desk',
 'DeskLamp',
 'Drawer',
 'Dresser',
 'GarbageCan',
 'KeyChain',
 'Laptop',
 'LightSwitch',
 'Mirror',
 'Mug',
 'Pen',
 'Pencil',
 'Pillow',
 'Poster',
 'Shelf',
 'SideTable',
 'TennisRacket',
 'Vase',
 'Window']

## Loop through some tasks in train

In [341]:
def shift_states(saved_states):
    # shift all the states to the last time step
    shifted_states = [{} for _ in range(len(saved_states))]

    for t, state in enumerate(saved_states):
        new_state = {k:v for k,v in state.items() if k != 'objects_metadata'}
        if t < len(saved_states) - 1:
            new_state['objects_metadata'] = saved_states[t+1]['objects_metadata']
        elif t == len(saved_states) - 1:
            new_state['objects_metadata'] = saved_states[t]['objects_metadata']
        shifted_states[t] = new_state
        
    return shifted_states

In [342]:
def extract_states_for_model(shifted_states, saved_states, root):

    objectId_list = sorted([o['objectId'] for o in shifted_states[-1]['objects_metadata']])
    objectTyp_list = sorted(set([o.split('|')[0] for o in objectId_list]))
    num_subgoals = len(set([s['subgoal_step'] for s in shifted_states]))

    feat = {
        'subgoal': [[] for _ in range(num_subgoals)],
        'subgoal_i': [[] for _ in range(num_subgoals)],
        'subgoal_t': [[] for _ in range(num_subgoals)],
        'overall_t': [[] for _ in range(num_subgoals)],

        'objectInstanceList': objectId_list,
        'instance_visibile': [[] for _ in range(num_subgoals)], # [subgoal_i][subgoal timestep] = [T/F for ob in objectId_list]
        'instance_state_change': [[] for _ in range(num_subgoals)],

        'objectTypeList': objectTyp_list,
        'type_visibile': [[] for _ in range(num_subgoals)], # [subgoal_i][subgoal timestep] = [T/F for ob in objectTyp_list]
        'type_state_change': [[] for _ in range(num_subgoals)]
    }

    last_obj_states = {obj['objectId']:get_object_states(obj, object_state_types) for obj in saved_states[0]['objects_metadata']}
    for t, state in enumerate(shifted_states):

        subgoal_i = state['subgoal_step']
        subgoal = state['subgoal']

        if state['new_subgoal']:
            subgoal_t = 0

        feat['subgoal_i'][subgoal_i].append(subgoal_i)
        feat['subgoal_t'][subgoal_i].append(subgoal_t)
        feat['subgoal'][subgoal_i].append(subgoal)
        feat['overall_t'][subgoal_i].append(t)

        # get state change
        curr_obj_states = {obj['objectId']:get_object_states(obj, object_state_types) for obj in state['objects_metadata']}
        
        # list same order as objectId_list
        state_change = detect_state_change(last_obj_states, curr_obj_states, objectId_list)
        last_obj_states = curr_obj_states

        feat['instance_state_change'][subgoal_i].append(state_change)
    
        # get visibility
        visible = [False for _ in objectId_list]
        for ob in state['objects_metadata']:
            pos = objectId_list.index(ob['objectId'])
            visible[pos] = ob['visible']
        feat['instance_visibile'][subgoal_i].append(visible)

        # get type state change
        type_state_change = detect_type_state_change(state_change, objectId_list, objectTyp_list)
        feat['type_state_change'][subgoal_i].append(type_state_change)

        # get type visibility
        type_visible = detect_type_visibility(visible, objectId_list, objectTyp_list)
        feat['type_visibile'][subgoal_i].append(type_visible)

        subgoal_t += 1
        
    return feat

## TRAIN

In [410]:
date = '20200511'
split_name = 'train'

with open('/root/data_alfred/splits/collect_states_{}_{}_notebook_success_paths.json'.format(date, split_name), 'r') as f:
    success_outpaths = json.load(f)
    
len(success_outpaths)

with open('/root/data_alfred/splits/collect_states_{}_{}_notebook_failed_roots.json'.format(date, split_name), 'r') as f:
    failed_roots = json.load(f)
    
print('success paths = ', len(success_outpaths))
print('failed paths = ', len(failed_roots))

success paths =  6505
failed paths =  69


In [411]:
extract_feat_dum = []
extract_feat_outpaths = []

for metadata_outpath in success_outpaths:
    
    with open(metadata_outpath, 'r') as f:
        saved_states = json.load(f)
        
    root = metadata_outpath[:metadata_outpath.index('pp/')]
    
    shifted_states = shift_states(saved_states)
    extract_feat = extract_states_for_model(shifted_states, saved_states, root)
    extract_feat['metadata_path'] = metadata_outpath
    extract_feat['root'] = root
    
    extract_feat_outpath = metadata_outpath.replace('/metadata_states.json', '/extracted_feature_states.json')
    with open(extract_feat_outpath, 'w') as f:
        json.dump(extract_feat ,f)
    
    extract_feat_outpaths.append(extract_feat_outpath)
    extract_feat_dum.append(extract_feat)
    
print(len(extract_feat_dum))    

6505


In [413]:
len(extract_feat_outpaths)

6505

In [412]:
date = '20200513'
with open('/root/data_alfred/splits/extract_feat_states_{}_{}_notebook_outpaths.json'.format(date, split_name), 'w') as f:
    json.dump(extract_feat_outpaths, f)

## VALID SEEN

In [397]:
date = '20200511'
split_name = 'valid_seen'

with open('/root/data_alfred/splits/collect_states_{}_{}_notebook_success_paths.json'.format(date, split_name), 'r') as f:
    success_outpaths = json.load(f)
    
len(success_outpaths)

with open('/root/data_alfred/splits/collect_states_{}_{}_notebook_failed_roots.json'.format(date, split_name), 'r') as f:
    failed_roots = json.load(f)
    
print('success paths = ', len(success_outpaths))
print('failed paths = ', len(failed_roots))

success paths =  249
failed paths =  2


In [398]:
extract_feat_dum = []
extract_feat_outpaths = []

for metadata_outpath in success_outpaths:
    
    with open(metadata_outpath, 'r') as f:
        saved_states = json.load(f)
        
    root = metadata_outpath[:metadata_outpath.index('pp/')]
    
    shifted_states = shift_states(saved_states)
    extract_feat = extract_states_for_model(shifted_states, saved_states, root)
    extract_feat['metadata_path'] = metadata_outpath
    extract_feat['root'] = root
    
    extract_feat_outpath = metadata_outpath.replace('/metadata_states.json', '/extracted_feature_states.json')
    with open(extract_feat_outpath, 'w') as f:
        json.dump(extract_feat ,f)
    
    extract_feat_outpaths.append(extract_feat_outpath)
    extract_feat_dum.append(extract_feat)
    
print(len(extract_feat_dum))    

249


In [402]:
date = '20200513'
with open('/root/data_alfred/splits/extract_feat_states_{}_{}_notebook_outpaths.json'.format(date, split_name), 'w') as f:
    json.dump(extract_feat_outpaths, f)

## VALID UNSEEN

In [403]:
date = '20200511'
split_name = 'valid_unseen'

with open('/root/data_alfred/splits/collect_states_{}_{}_notebook_success_paths.json'.format(date, split_name), 'r') as f:
    success_outpaths = json.load(f)
    
len(success_outpaths)

with open('/root/data_alfred/splits/collect_states_{}_{}_notebook_failed_roots.json'.format(date, split_name), 'r') as f:
    failed_roots = json.load(f)
    
print('success paths = ', len(success_outpaths))
print('failed paths = ', len(failed_roots))

success paths =  254
failed paths =  1


In [404]:
extract_feat_dum = []
extract_feat_outpaths = []

for metadata_outpath in success_outpaths:
    
    with open(metadata_outpath, 'r') as f:
        saved_states = json.load(f)
        
    root = metadata_outpath[:metadata_outpath.index('pp/')]
    
    shifted_states = shift_states(saved_states)
    extract_feat = extract_states_for_model(shifted_states, saved_states, root)
    extract_feat['metadata_path'] = metadata_outpath
    extract_feat['root'] = root
    
    extract_feat_outpath = metadata_outpath.replace('/metadata_states.json', '/extracted_feature_states.json')
    with open(extract_feat_outpath, 'w') as f:
        json.dump(extract_feat ,f)
    
    extract_feat_outpaths.append(extract_feat_outpath)
    extract_feat_dum.append(extract_feat)
    
print(len(extract_feat_dum))   

254


In [407]:
date = '20200513'
with open('/root/data_alfred/splits/extract_feat_states_{}_{}_notebook_outpaths.json'.format(date, split_name), 'w') as f:
    json.dump(extract_feat_outpaths, f)