In [22]:
import json
import os
import progressbar

In [2]:
data_p = '/root/data_alfred/unlabeled_12k_20201206/seen'
toy_splits_p = '/root/data_alfred/splits/debug_20201210.json'
splits_p = '/root/data_alfred/splits/unlabeled_12k_20201206.json'

In [23]:
def truncate_extra_subgoals(traj_data, num_subgoals, key):
    assert len(traj_data[key]['anns'][0]['high_descs']) >= num_subgoals
    if len(traj_data[key]['anns'][0]['high_descs']) > num_subgoals:
        traj_data[key]['anns'][0]['high_descs'] = traj_data[key]['anns'][0]['high_descs'][:num_subgoals]
        return True
    else:
        return False

def match_post_prediction_subgoal_lengths(split, data_p, overwrite_traj=False, debug=False):
    '''
    split: a list of tasks {'task':<task name>/<trial id>, 'repeat_idx':int, 'full_traj_success':boolean, 'collected_subgoals':int}
    '''
    adjusts_explainer_full = 0
    adjusts_explainer = 0
    adjusts_baseline = 0
    
    fail = 0
    for task in progressbar.progressbar(split):
        traj_data_p = os.path.join(data_p, task['task'], 'traj_data.json')
        with open(traj_data_p, 'r') as f:
            traj_data = json.load(f)

        if debug:
            print (task)
            print ('Explainer Full')
            print (traj_data['explainer_full_annotations'])
            print('pre truncation')
            print ('Explainer Aux Loss only')
            print (traj_data['explainer_annotations'])
            print ('Baseline')
            print (traj_data['baseline_annotations'])
            print ('----------------------------------------')
        
        if traj_data['plan']['high_pddl'][-1]['discrete_action']['action'] == 'NoOp':
            true_num_subgoals = len(traj_data['plan']['high_pddl']) - 1
        else:
            true_num_subgoals = len(traj_data['plan']['high_pddl'])
        
        # verify that the predicted instructions has # subgoals >= gold # subgoals
        # when predicting in a batch of different tasks, model can decode more than necessary
        assert len(traj_data['explainer_full_annotations']['anns'][0]['high_descs']) >= true_num_subgoals
        assert len(traj_data['explainer_annotations']['anns'][0]['high_descs']) >= true_num_subgoals
        assert len(traj_data['baseline_annotations']['anns'][0]['high_descs']) >= true_num_subgoals
        
        adjusts_explainer_full += int(truncate_extra_subgoals(traj_data, true_num_subgoals, key='explainer_full_annotations'))
        adjusts_explainer += int(truncate_extra_subgoals(traj_data, true_num_subgoals, key='explainer_annotations'))
        adjusts_baseline += int(truncate_extra_subgoals(traj_data, true_num_subgoals, key='baseline_annotations'))
        
        if debug:
            print (adjusts_explainer, adjusts_baseline)
            print (task)
            print('post truncation')
            print ('Explainer Full')
            print (traj_data['explainer_full_annotations'])
            print ('Explainer Aux Loss only')
            print (traj_data['explainer_annotations'])
            print ('Baseline')
            print (traj_data['baseline_annotations'])
            print ('\n\n\n\n\n\n\n\n')
        
        # be careful about overwriting!
        if overwrite_traj:
            with open(traj_data_p, 'w') as f:
                json.dump(traj_data, f)
    
    return adjusts_explainer_full, adjusts_explainer, adjusts_baseline
        

## Debug with Toy Split

In [4]:
with open(toy_splits_p, 'r') as f:
    toy_split = json.load(f)['augmentation']

In [17]:
adjusts_explainer_full, adjusts_explainer, adjusts_baseline = match_post_prediction_subgoal_lengths(toy_split, data_p, overwrite_traj=False, debug=False)

{'task': 'pick_cool_then_place_in_recep-Tomato-None-SinkBasin-21/trial_T20190910_135147_740373', 'repeat_idx': 0, 'full_traj_success': True, 'collected_subgoals': 6}
Explainer Full
{'anns': [{'task_desc': 'put a chilled tomato in the sink .', 'high_descs': ['turn around and walk to the sink .', 'pick up the tomato from the sink .', 'turn around and walk to the fridge .', 'open the fridge and put the tomato inside . shut the door and then open the door and take the tomato out and shut the door .', 'turn left and walk to the sink .', 'put the tomato in the sink .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .']}]}
pre truncation
Explainer Aux Loss only
{'anns': [{'task_desc': 'put a chilled tomato in the sink .', 'high_de

{'anns': [{'task_desc': 'put two keys on a dresser .', 'high_descs': ['turn around and walk to the dresser .', 'pick up the keys from the dresser .', 'turn right to face the dresser .', 'put the keys on the shelf', 'turn left and face the dresser .', 'pick up the keys from the dresser .', 'turn left and face the cabinet on the left .', 'put the keys on the shelf', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .']}]}
pre truncation
Explainer Aux Loss only
{'anns': [{'task_desc': 'put two keys on a dresser .', 'high_descs': ['turn to the dresser', 'pick up the keys on the dresser .', 'turn right to face the chest of drawers .', 'put the keys on the shelf', 'turn to the left to face the dresser .', 'pick up the keys on the dresser .', 'turn around and face the cabinet underneath the chest .', 'put the keys on the shelf', 'turn ri

{'task': 'pick_cool_then_place_in_recep-Potato-None-Microwave-12/trial_T20190909_233235_652561', 'repeat_idx': 0, 'full_traj_success': True, 'collected_subgoals': 7}
Explainer Full
{'anns': [{'task_desc': 'put a cold potato in the microwave .', 'high_descs': ['turn around and walk to the sink .', 'take the potato out of the microwave .', 'turn around and walk to the fridge .', 'open the fridge and place the potato inside . shut the door and then open the door and take the potato out again .', 'turn right and walk to the sink .', 'put the potato in the microwave .', 'pick up the yellow handled knife on the counter .', 'pick up the yellow handled knife on the counter .']}]}
pre truncation
Explainer Aux Loss only
{'anns': [{'task_desc': 'put a cold potato in the microwave .', 'high_descs': ['turn left and walk to the microwave above the sink .', 'take the potato from the microwave .', 'turn around and walk to the fridge .', 'open the fridge and put the potato inside . shut the door and th

In [22]:
fail

34

## Run with Real Split

In [20]:
with open(splits_p, 'r') as f:
    splits = json.load(f)['augmentation']

In [26]:
adjusts_explainer_full, adjusts_explainer, adjusts_baseline = match_post_prediction_subgoal_lengths(splits, data_p, overwrite_traj=True, debug=False)

100% (11947 of 11947) |##################| Elapsed Time: 0:02:15 Time:  0:02:15


In [27]:
adjusts_explainer

9619

In [28]:
adjusts_baseline

9619

In [29]:
adjusts_explainer_full

9619