# Make a random subset of the training data

- Subset size: 250 entries, drawn from the train entries with `repeat_idx == 0`

- Check that the subset matches the full training distribution in number of subgoals, high-level task types, and scenes

In [1]:
import os
import sys

# Location of the local ALFRED checkout; exported so code that reads
# ALFRED_ROOT from the environment sees the same value.
os.environ['ALFRED_ROOT'] = '/home/hoyeung/alfred/'

# Make the repo root and its `models` directory importable
# (presumably so project modules such as `data.preprocess` resolve).
sys.path.extend([
    os.environ['ALFRED_ROOT'],
    os.path.join(os.environ['ALFRED_ROOT'], 'models'),
])

import torch
import pprint
import json
from data.preprocess import Dataset
from importlib import import_module
from collections import Counter

In [2]:
class args:
    # Plain namespace of settings read elsewhere in this notebook as `args.<name>`.
    # The `#!` markers are kept from the original
    # (NOTE(review): presumably values changed from the training defaults; confirm).

    # settings
    seed = 123
    data = 'data/json_feat_2.1.0'
    splits = 'data/splits/oct21.json'
    preprocess = False #!
    pp_folder = 'pp'
    save_every_epoch = False #!
    model = 'seq2seq_im'
    gpu = True
    dout = 'exp/model:seq2seq_im'
    resume = False #!

# splits


- The split-loading code below follows the training script:
https://github.com/Chucooleg/alfred/blob/6d2a6d9b210ea2ab57a3d6c6b2810f796e9ad2d1/models/train/train_seq2seq.py#L80

In [4]:
# load train/valid/tests splits
with open(args.splits) as f:
    splits = json.load(f)
    pprint.pprint({k: len(v) for k, v in splits.items()})

{'tests_seen': 1533,
 'tests_unseen': 1529,
 'train': 21023,
 'valid_seen': 820,
 'valid_unseen': 821}


In [5]:
# Sanity check: which container type holds the train split?
type(splits['train'])

list

In [6]:
# Peek at the first train entry to see the fields stored per entry.
splits['train'][0]

{'repeat_idx': 0,
 'task': 'pick_cool_then_place_in_recep-LettuceSliced-None-DiningTable-17/trial_T20190909_070538_437648'}

In [7]:
# Same peek for the first valid_seen entry.
splits['valid_seen'][0]

{'repeat_idx': 0,
 'task': 'pick_heat_then_place_in_recep-PotatoSliced-None-SinkBasin-13/trial_T20190909_115736_122556'}

In [8]:
# Same peek for the first valid_unseen entry.
splits['valid_unseen'][0]

{'repeat_idx': 0,
 'task': 'look_at_obj_in_light-CellPhone-None-FloorLamp-219/trial_T20190908_044113_026049'}

# Train


- The train split is loaded the same way as in the training script:
https://github.com/Chucooleg/alfred/blob/6d2a6d9b210ea2ab57a3d6c6b2810f796e9ad2d1/models/train/train_seq2seq.py#L80

In [9]:
# Total number of train entries, then how many of them have repeat_idx == 0.
print(len(splits['train']))
print(sum(1 for entry in splits['train'] if entry['repeat_idx'] == 0))

21023
6574


In [10]:
# Ratio of all train entries to entries with repeat_idx == 0
# (presumably the average number of annotations per trajectory).
# Computed from `splits` rather than hard-coded counts, so it stays correct
# if the split file changes.
len(splits['train']) / sum(1 for entry in splits['train'] if entry['repeat_idx'] == 0)

3.197900821417706

In [11]:
# Keep a handle on the full train split, plus only the entries with repeat_idx == 0.
train = splits['train']
train_ann_0 = [entry for entry in train if entry['repeat_idx'] == 0]

# Examine train distribution

In [12]:
def load_task_json(task):
    '''
    load preprocessed json from disk

    task: a split entry, i.e. a dict with a 'task' path and an integer 'repeat_idx'
    returns: parsed contents of <args.data>/<task['task']>/<args.pp_folder>/ann_<repeat_idx>.json
    '''
    # use the configured preprocessing folder (args.pp_folder) rather than a
    # hard-coded 'pp', so this stays in sync with the settings cell
    json_path = os.path.join(args.data, task['task'], args.pp_folder,
                             'ann_%d.json' % task['repeat_idx'])
    with open(json_path) as f:
        return json.load(f)

In [13]:
def load_task_stats(split):
    '''
    collect per-entry statistics for every entry in a split

    split: iterable of split entries (dicts with 'task' and 'repeat_idx')
    returns: four parallel lists, one item per entry:
        - length of ex['num']['action_high'] (used in this notebook as the subgoal count)
        - length of ex['num']['action_low']
        - first '-'-separated field of the task folder name (the task type)
        - last '-'-separated field of the task folder name (the scene id, as a string)
    '''
    n_subgoals, n_low_actions, task_types, scene_ids = [], [], [], []

    for entry in split:
        # the folder name before '/' looks like '<task type>-...-<scene id>'
        fields = entry['task'].split('/')[0].split('-')
        task_types.append(fields[0])
        scene_ids.append(fields[-1])

        numericalized = load_task_json(entry)['num']
        n_subgoals.append(len(numericalized['action_high']))
        n_low_actions.append(len(numericalized['action_low']))

    return n_subgoals, n_low_actions, task_types, scene_ids

In [19]:
# Gather statistics for every entry of the full train split.
(train_subgoals,
 train_lens,
 train_actions,
 train_scenes) = load_task_stats(train)

In [42]:
# Relative frequency of each subgoal count in the full train split, most common first.
ctr = Counter(train_subgoals).most_common()
[(value, count / len(train_subgoals)) for value, count in ctr]

[(7, 0.3194596394425153),
 (5, 0.2570993673595586),
 (9, 0.16753079960043762),
 (8, 0.108785615754174),
 (13, 0.08495457356228892),
 (11, 0.016886267421395613),
 (14, 0.016172763164153545),
 (12, 0.0156970936593255),
 (4, 0.004994529800694478),
 (6, 0.002854017028968273),
 (10, 0.0026637492270370545),
 (20, 0.0011416068115873092),
 (15, 0.0006183703562764591),
 (18, 0.0004281025543452409),
 (16, 0.0002854017028968273),
 (17, 0.0002854017028968273),
 (19, 0.00014270085144841365)]

In [44]:
# Relative frequency of each task type in the full train split, most common first.
ctr = Counter(train_actions).most_common()
[(value, count / len(train_actions)) for value, count in ctr]

[('pick_two_obj_and_place', 0.16905294201588736),
 ('pick_and_place_simple', 0.15435475431670076),
 ('pick_and_place_with_movable_recep', 0.15430718736621796),
 ('pick_cool_then_place_in_recep', 0.14003710222137658),
 ('pick_heat_then_place_in_recep', 0.13998953527089378),
 ('pick_clean_then_place_in_recep', 0.13518527327213054),
 ('look_at_obj_in_light', 0.10707320553679303)]

In [45]:
# Relative frequency of each scene id in the full train split, most common first.
ctr = Counter(train_scenes).most_common()
[(value, count / len(train_scenes)) for value, count in ctr]

[('1', 0.02606668886457689),
 ('21', 0.025781287161680065),
 ('18', 0.022879703182228987),
 ('20', 0.022641868429814964),
 ('15', 0.021690529420158873),
 ('4', 0.02145269466774485),
 ('24', 0.020548922608571565),
 ('23', 0.020453788707605957),
 ('11', 0.020263520905674737),
 ('16', 0.020120820054226322),
 ('30', 0.020025686153260714),
 ('19', 0.020025686153260714),
 ('27', 0.0198829853018123),
 ('17', 0.01955001664843267),
 ('13', 0.01955001664843267),
 ('7', 0.019312181896018646),
 ('26', 0.018979213242639015),
 ('5', 0.01850354373781097),
 ('25', 0.01821814203491414),
 ('6', 0.0177900394805689),
 ('3', 0.01745707082718927),
 ('14', 0.017124102173809636),
 ('28', 0.017076535223326833),
 ('2', 0.016981401322361225),
 ('12', 0.016267897065119157),
 ('22', 0.01412738429339295),
 ('8', 0.013128478333254055),
 ('214', 0.00879988583931884),
 ('427', 0.00784854682966275),
 ('218', 0.007753412928697141),
 ('224', 0.007753412928697141),
 ('305', 0.007705845978214337),
 ('303', 0.00756314512676

In [28]:
# Same statistics, restricted to the train entries with repeat_idx == 0.
(train_ann0_subgoals,
 train_ann0_lens,
 train_ann0_actions,
 train_ann0_scenes) = load_task_stats(train_ann_0)

In [46]:
# Relative frequency of each subgoal count among the repeat_idx == 0 entries.
ctr = Counter(train_ann0_subgoals).most_common()
[(value, count / len(train_ann0_subgoals)) for value, count in ctr]

[(7, 0.32354730757529665),
 (5, 0.2613325220565866),
 (9, 0.16686948585336173),
 (8, 0.10587161545482203),
 (13, 0.08229388500152114),
 (11, 0.016580468512321266),
 (14, 0.015211439002129602),
 (12, 0.014755095832065714),
 (4, 0.005019774870702769),
 (6, 0.0030422878004259203),
 (10, 0.0027380590203833284),
 (20, 0.0010648007301490721),
 (18, 0.00045634317006388805),
 (15, 0.00045634317006388805),
 (16, 0.000304228780042592),
 (17, 0.000304228780042592),
 (19, 0.000152114390021296)]

In [47]:
# Relative frequency of each task type among the repeat_idx == 0 entries.
ctr = Counter(train_ann0_actions).most_common()
[(value, count / len(train_ann0_actions)) for value, count in ctr]

[('pick_two_obj_and_place', 0.168542744143596),
 ('pick_and_place_simple', 0.15728627928202008),
 ('pick_and_place_with_movable_recep', 0.15028901734104047),
 ('pick_cool_then_place_in_recep', 0.13994523881959234),
 ('pick_heat_then_place_in_recep', 0.13964101003954973),
 ('pick_clean_then_place_in_recep', 0.13614237906905993),
 ('look_at_obj_in_light', 0.10815333130514147)]

In [48]:
# Relative frequency of each scene id among the repeat_idx == 0 entries.
ctr = Counter(train_ann0_scenes).most_common()
[(value, count / len(train_ann0_scenes)) for value, count in ctr]

[('1', 0.026163675083662914),
 ('21', 0.025403103133556433),
 ('20', 0.022817158503194403),
 ('18', 0.022360815333130515),
 ('15', 0.022056586553087922),
 ('4', 0.021600243383024034),
 ('23', 0.020383328262853665),
 ('16', 0.020383328262853665),
 ('19', 0.02023121387283237),
 ('11', 0.020079099482811075),
 ('24', 0.020079099482811075),
 ('27', 0.020079099482811075),
 ('13', 0.019774870702768482),
 ('7', 0.019318527532704594),
 ('17', 0.019014298752662),
 ('30', 0.019014298752662),
 ('25', 0.01871006997261941),
 ('26', 0.018253726802555523),
 ('6', 0.018253726802555523),
 ('5', 0.01794949802251293),
 ('14', 0.017341040462427744),
 ('2', 0.01718892607240645),
 ('3', 0.016884697292363856),
 ('28', 0.01673258290234256),
 ('12', 0.015363553392150897),
 ('22', 0.013842409491937937),
 ('8', 0.013386066321874049),
 ('214', 0.00836629145117128),
 ('305', 0.008062062671128689),
 ('427', 0.007909948281107393),
 ('218', 0.007757833891086097),
 ('224', 0.007757833891086097),
 ('212', 0.007453605111

In [14]:
# sample randomly
import numpy as np

In [15]:
# Draw a reproducible random subset, without replacement, of the
# repeat_idx == 0 train entries. The size is 250, matching the goal stated at
# the top of the notebook and the 250-entry 'train_sanity' split saved at the end.
np.random.seed(42)
train_subset = np.random.choice(train_ann_0, size=250, replace=False)

In [16]:
# Same statistics for the sampled subset, for comparison with the full split.
(train_subset_subgoals,
 train_subset_lens,
 train_subset_actions,
 train_subset_scenes) = load_task_stats(train_subset)

In [58]:
# seed 42, subset size 250
# Relative frequency of each subgoal count in the sampled subset (top 20 at most).
ctr = Counter(train_subset_subgoals).most_common(20)
[(value, count / len(train_subset_subgoals)) for value, count in ctr]

[(7, 0.308),
 (5, 0.288),
 (9, 0.144),
 (8, 0.096),
 (13, 0.092),
 (14, 0.036),
 (11, 0.02),
 (6, 0.012),
 (12, 0.004)]

In [59]:
# seed 42, subset size 250
# Relative frequency of each task type in the sampled subset.
ctr = Counter(train_subset_actions).most_common()
[(value, count / len(train_subset_actions)) for value, count in ctr]

[('pick_and_place_simple', 0.184),
 ('pick_and_place_with_movable_recep', 0.164),
 ('pick_cool_then_place_in_recep', 0.156),
 ('pick_two_obj_and_place', 0.152),
 ('pick_clean_then_place_in_recep', 0.136),
 ('look_at_obj_in_light', 0.112),
 ('pick_heat_then_place_in_recep', 0.096)]

In [60]:
# seed 42, subset size 250
# Relative frequency of each scene id in the sampled subset.
ctr = Counter(train_subset_scenes).most_common()
[(value, count / len(train_subset_scenes)) for value, count in ctr]

[('21', 0.036),
 ('15', 0.032),
 ('14', 0.028),
 ('16', 0.028),
 ('24', 0.024),
 ('5', 0.024),
 ('3', 0.024),
 ('27', 0.024),
 ('11', 0.024),
 ('30', 0.024),
 ('205', 0.024),
 ('18', 0.024),
 ('313', 0.02),
 ('22', 0.02),
 ('25', 0.02),
 ('23', 0.02),
 ('19', 0.02),
 ('203', 0.02),
 ('2', 0.02),
 ('17', 0.016),
 ('20', 0.016),
 ('4', 0.016),
 ('1', 0.016),
 ('310', 0.012),
 ('416', 0.012),
 ('202', 0.012),
 ('224', 0.012),
 ('420', 0.012),
 ('428', 0.012),
 ('222', 0.012),
 ('13', 0.012),
 ('314', 0.012),
 ('417', 0.012),
 ('218', 0.012),
 ('411', 0.012),
 ('12', 0.012),
 ('220', 0.008),
 ('426', 0.008),
 ('312', 0.008),
 ('427', 0.008),
 ('6', 0.008),
 ('403', 0.008),
 ('429', 0.008),
 ('223', 0.008),
 ('423', 0.008),
 ('307', 0.008),
 ('26', 0.008),
 ('318', 0.008),
 ('320', 0.008),
 ('402', 0.008),
 ('323', 0.008),
 ('326', 0.008),
 ('304', 0.008),
 ('303', 0.008),
 ('421', 0.008),
 ('208', 0.008),
 ('28', 0.008),
 ('422', 0.008),
 ('321', 0.008),
 ('409', 0.008),
 ('212', 0.008),
 

## save the splits

In [66]:
# Path of the split file that was loaded at the start of the notebook.
args.splits

'data/splits/oct21.json'

In [64]:
# Confirm the number of sampled entries before saving them as a split.
len(train_subset)

250

In [72]:
# Register the sampled subset as a new 'train_sanity' split. It is stored as a
# plain list because the NumPy array returned by np.random.choice cannot be
# passed to json.dump directly.
splits['train_sanity'] = [entry for entry in train_subset]
splits.keys()

dict_keys(['tests_seen', 'tests_unseen', 'train', 'valid_seen', 'valid_unseen', 'train_sanity'])

In [None]:
# Show the split sizes of the original file for comparison.
# Load into a separate name: rebinding `splits` here would discard the
# 'train_sanity' entry added above before it is written to disk.
with open(args.splits) as f:
    original_splits = json.load(f)
    pprint.pprint({k: len(v) for k, v in original_splits.items()})

In [73]:
# Write the current `splits` dict to a new split file.
with open('data/splits/apr13.json', 'w') as out_file:
    json.dump(splits, out_file)

In [74]:
# Read the new split file back and print the size of every split as a check.
with open('data/splits/apr13.json') as in_file:
    splits = json.load(in_file)

pprint.pprint({name: len(entries) for name, entries in splits.items()})

{'tests_seen': 1533,
 'tests_unseen': 1529,
 'train': 21023,
 'train_sanity': 250,
 'valid_seen': 820,
 'valid_unseen': 821}
