# Puddleworld Data Loader Prototyping

Spatial Navigation (Janner et. al) [https://github.com/JannerM/spatial-reasoning]

In [4]:
from preliminary.exploration_utils import *
from data.dataset_loading import *
%matplotlib inline 


## SR
## Spatial Reasoning (Janner)
def make_sr_dataset(raw_train, raw_test, verbose):
	sr_dataset = {'train': [], 'test': []}
	for i, dataset in enumerate((raw_train, raw_test)):
		layouts, objects, rewards, terminal, instructions, values, goals = dataset
		mode = 'train' if i == 0 else 'test'
		if verbose:
			print("Found %d %s instructions." % (len(instructions), mode))
		
		for instruction in instructions:
			sr_dataset[mode].append({
				'hints_aug': instruction.split()
			})
	return sr_dataset

def load_sr(verbose=False):
	"""
	Returns: local_sr, global_sr
	"""
	import data.spatialreasoning
	import data.spatialreasoning.environment
	import data.spatialreasoning.data as srdata
	
	annotations='human'
	n_local_train, n_local_val = 10000, 10000
	n_global_train, n_global_val = 10000, 10000
	data_path = os.path.join(TOP_LEVEL, "data", "spatialreasoning/data")
	local_train, local_val = srdata.load(data_path, 'local', annotations, n_local_train, n_local_val)
	global_train, global_val = srdata.load(data_path, 'global', annotations, n_global_train, n_global_val)
	return local_train, local_val, global_train, global_val		

local_train, local_val, global_train, global_val = load_sr()


<Data> Loading local train environments with human annotations
<Data> Found 1566 annotations

<Data> Loading local test environments with human annotations
<Data> Found 399 annotations

<Data> Loading global train environments with human annotations
<Data> Found 1071 annotations

<Data> Loading global test environments with human annotations
<Data> Found 272 annotations


In [22]:
# Create a new version of the datasets
def simple_dataset(dataset):
    layouts, objects, rewards, terminal, instructions, values, goals = dataset
    return zip(layouts, objects, instructions, goals)




# Get all the layouts
def show_layouts(dataset, inds):
#     for j in inds:
        layouts, objects, instructions, goals = simple_dataset(dataset)[1]
        print(layouts)
    
#         print(instructions)
#         print(goals)

#         print instructions[j]
#         from matplotlib import pyplot as plt
#         plt.figure(figsize=(6,6))
#         plt.imshow(objects[j].squeeze()+layouts[j].squeeze())
#         def obj2text(obj, end):
#             obj_strs = ['P', "St", "Ci", "Tri", "He", "Sp", "Di", "Ro", "Tr", "Ho", "Hr"]
#             for i in range(len(obj)):
#                 for j in range(len(obj)):
#                     if obj[i][j] != 0:
#                         plt.text(j-0.3, i, obj_strs[int(obj[i][j])], fontsize=20,color="red")
#             plt.text(end[1]-0.2, end[0]+0.7, "*", fontsize=50,color="red")
#         obj2text(objects[j].squeeze(), goals[j])
#         plt.show()

show_layouts(local_train, range(10))

[[[0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 1. 1. 0. 0. 1.]
  [0. 0. 0. 0. 0. 1. 1. 0. 0. 1.]
  [0. 0. 0. 0. 0. 1. 1. 0. 0. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 0. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]
  [0. 0. 0. 0. 0. 0. 1. 1. 1. 1.]]]


In [27]:
object_strings = ["NULL", "puddle", "star", "circle", "triangle", "heart", "spade", "diamond", "rock", "tree", "house", "horse"]
simple_local_train, \
simple_local_val, \
simple_global_train, \
simple_global_val = simple_dataset(local_train), simple_dataset(local_val), simple_dataset(global_train), simple_dataset(global_val)

In [29]:
puddleworld = {
    'object_keys' : object_strings,
    'local_train' : simple_local_train,
    'local_test' : simple_local_val,
    'global_train' : simple_global_train,
    'global_test' : simple_global_val
}
import pickle
with open('puddleworld.pickle', 'wb') as handle:
    pickle.dump(puddleworld, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open('puddleworld.pickle', 'rb') as handle:
    check = pickle.load(handle)
    print check.keys()
    print len(check['local_train'])
    print check['local_train'][0]

['local_train', 'local_test', 'object_keys', 'global_train', 'global_test']
1566
(array([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1., 0., 0., 1., 1.]]]), array([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  8.,  2.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  7.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  6.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 10.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0., 10., 10.,  5.,  8.,  

In [12]:
import pickle
with open('puddleworld.pickle', 'rb') as handle:
    check = pickle.load(handle)
    print check.keys()
    print len(check['local_train'])
    print check['local_train'][0]

['local_train', 'local_test', 'object_keys', 'global_train', 'global_test']
1566
(array([[[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1., 1., 0., 0., 1., 1.]]]), array([[[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  8.,  2.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  7.,  1.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  6.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 10.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0., 10., 10.,  5.,  8.,  

In [13]:
# Convert all the arrays to list:
for key in check.keys():
    if key != 'object_keys':
        for i in range(len(check[key])):
            layouts, objects, instructions, goals = check[key][i]
            check[key][i] = layouts.tolist(), objects.tolist(), instructions.encode('ascii', 'ignore'), goals


import json
with open('puddleworld.json', 'w') as fp:
    json.dump(check, fp)

In [2]:
import json
with open('puddleworld.json') as f:
    check = json.load(f)
    print check.keys()
    print len(check['local_train'])
    print check['local_train'][0]

[u'local_train', u'local_test', u'global_test', u'global_train', u'object_keys']
1566
[[[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]]], [[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 8.0, 2.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0, 1.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 10.0, 10.0, 5.0, 8.0, 0.0], [0.0, 0.0, 0.0, 0.0,