### Dataset creation

In [None]:
'''
Install requirements
'''

# IPython shell escapes ("!"): these lines only run inside a notebook kernel,
# not under a plain Python interpreter.
# Install everything listed in requirements.txt.
!pip3 install -r requirements.txt
# prettytable provides the PrettyTable class imported in the next cell; it is
# installed separately here (presumably it is not listed in requirements.txt
# -- TODO confirm).
!pip3 install prettytable

In [None]:
'''
Imports external and own libraries
'''

import pickle

import torch.nn as nn
import torch.optim as optim
from torchsummary import summary
from torch.utils.data import DataLoader

from prettytable import PrettyTable

# own
import collector
import action
import world
import plot
import preprocess
import nets
import train

In [None]:
'''
Create and visualize the world
'''
# Build the environment from its registered name using the project's
# world helper; the result is kept in the notebook-global `env`.
env = world.init_env('MiniWorld-Maze-v0')
# Plot the environment (by its name, presumably observation / top view /
# depth panels -- TODO confirm against plot.plot_obs_top_dep).
plot.plot_obs_top_dep(env)
# Print the environment's parameters for reference.
world.print_env_parameters(env)

In [None]:
'''
Create Oracle dataset
'''
def _save_pickle(obj, path):
    """Serialize `obj` to `path` using the highest available pickle protocol.

    The parent directory of `path` is created first if it does not exist, so
    a fresh checkout without a `datasets/` folder does not fail with
    FileNotFoundError.

    Args:
        obj: Any picklable object (here: the data returned by collector.collect).
        path: Destination file path; an existing file is overwritten.
    """
    import os  # local import: keeps this cell runnable without touching the import cell

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


# Every stanza below builds its own environment and action sequence so that
# no stanza depends on state left behind by an earlier one.
oracle_actions = action.oracle_actions()
env = world.init_env('MiniWorld-Maze-v0')
oracle_data = collector.collect(oracle_actions, env, img_size=32, show=False)

_save_pickle(oracle_data, 'datasets/oracle_data.pickle')

'''
Create Oracle + reversed dataset
'''
oracle_actions = action.oracle_actions()
env = world.init_env('MiniWorld-Maze-v0')
reversed_actions = action.reverse_action_seq(oracle_actions)
# The oracle sequence is followed by its reversed counterpart in one run.
oracle_reversed_data = collector.collect(oracle_actions + reversed_actions, env, img_size=32, show=False)

_save_pickle(oracle_reversed_data, 'datasets/oracle_reversed_data.pickle')

'''
Create Oracle dataset with random actions inbetween
'''
# Regenerate the oracle sequence instead of reusing the list from the stanza
# above, so this stanza still works if a helper ever modifies its input in
# place (whether action.reverse_action_seq / action.add_randomness do so is
# not visible from here).
oracle_actions = action.oracle_actions()
env = world.init_env('MiniWorld-Maze-v0')
oracle_random_actions = action.add_randomness(oracle_actions, env)
oracle_random_data = collector.collect(oracle_random_actions, env, img_size=32, show=False)

_save_pickle(oracle_random_data, 'datasets/oracle_random_data.pickle')

'''
Create Oracle dataset with random actions inbetween + reversed
'''
oracle_actions = action.oracle_actions()
env = world.init_env('MiniWorld-Maze-v0')
oracle_random_actions = action.add_randomness(oracle_actions, env)
reversed_actions = action.reverse_action_seq(oracle_random_actions)
oracle_reversed_random_data = collector.collect(oracle_random_actions + reversed_actions, env, img_size=32, show=False)

_save_pickle(oracle_reversed_random_data, 'datasets/oracle_reversed_random_data.pickle')

'''
Random dataset
'''
env = world.init_env('MiniWorld-Maze-v0')
# 5000 is the first argument to action.random_actions (presumably the number
# of random actions to draw -- TODO confirm).
random_act = action.random_actions(5000, env)
random_data = collector.collect(random_act, env, img_size=32, show=False)

_save_pickle(random_data, 'datasets/random_data.pickle')



In [None]:
'''
Load data with pickle (deserialize)
'''
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    Only use this on files produced by this notebook: unpickling data from
    an untrusted source can execute arbitrary code.
    """
    with open(path, 'rb') as stream:
        return pickle.load(stream)


# One notebook-global per dataset file written by the dataset creation cell.
oracle_data = _load_pickle('datasets/oracle_data.pickle')
oracle_reversed_data = _load_pickle('datasets/oracle_reversed_data.pickle')
oracle_random_data = _load_pickle('datasets/oracle_random_data.pickle')
oracle_reversed_random_data = _load_pickle('datasets/oracle_reversed_random_data.pickle')
random_data = _load_pickle('datasets/random_data.pickle')

In [None]:
# Show which fields the loaded random dataset provides (it is used as a
# mapping here; 'actions' is the only key this notebook reads directly).
print(random_data.keys())
# Plot example entries from the random dataset (by its name, a 3x3 grid --
# TODO confirm against plot.plot_3x3_examples).
plot.plot_3x3_examples(random_data)

In [None]:
'''
Use Turtle to plot the agents trajectory. 
A gif making function
'''
# Rebuild the oracle action sequence so this cell does not depend on the
# dataset creation cell having been run.
oracle_actions = action.oracle_actions()
# Trace the action sequence with the project's turtle-based plot helper.
plot.turtle_tracing(oracle_actions)
# Save the same sequence as a gif (output location is decided inside
# plot.save_gif_of_sequence; not visible from here).
plot.save_gif_of_sequence(oracle_actions)

In [None]:
'''
Compare different datasets
'''
def count_actions(data):
    """Count how often each distinct action occurs in a dataset.

    Args:
        data: Mapping with an 'actions' entry holding a re-iterable
            collection of hashable action identifiers (a list in this
            notebook; any iterable that can be traversed twice works).

    Returns:
        A list with one occurrence count per distinct action, ordered like
        the iteration order of ``set(data['actions'])``. The list carries
        no action labels. An empty collection yields an empty list.
    """
    from collections import Counter  # local import keeps the cell self-contained

    actions = data['actions']
    # Tally everything in one pass instead of rescanning the whole
    # collection once per distinct action.
    tally = Counter(actions)
    # Iterate the set (not the Counter) so the output order stays exactly
    # what callers of the previous implementation received.
    return [tally[act] for act in set(actions)]

# Side-by-side summary of all datasets: number of recorded actions, the
# per-action counts from count_actions, and the environment they came from.
myTable = PrettyTable(["Dataset Name", "Length", "Distribution", "Environment"])
myTable.align["Dataset Name"] = "l"

# (row label, dataset) pairs in display order.
labelled_datasets = [
    ("Oracle data", oracle_data),
    ("Oracle data + reversed", oracle_reversed_data),
    ("Oracle data + random", oracle_random_data),
    ("Oracle data + random + reversed", oracle_reversed_random_data),
    ("Random", random_data),
]

# Add rows
for label, dataset in labelled_datasets:
    n_actions = len(dataset['actions'])
    distribution = str(count_actions(dataset))
    myTable.add_row([label, n_actions, distribution, 'MiniWorld-Maze-v0'])

print(myTable)