-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
147 lines (118 loc) · 4.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""
Tracks and logs all user actions, in addition to raw screen states.
Adapted from OpenAI Gym's example of human input: https://github.com/openai/gym
Includes more intuitive key bindings for the following games:
- SpaceInvaders
Usage:
main.py [options]
Options:
--env_id=<id> Environment ID [default: SpaceInvadersNoFrameskip-v4]
--skip-control=<n> Use previous control n times. [default: 0]
  --rollout-time=<t>   Max. number of timesteps to play the game for [default: 1000000]
--logdir=<path> Path to root of logs directory [default: ./logs]
--random Use a random agent.
--n_episodes=<n> Number of episodes to play [default: 1]
"""
import docopt
import sys
import gym
import time
import bindings
import random
import numpy as np
import os
import time
from typing import Dict
from typing import List
def key_press(key: int, mod, state: Dict):
    """Handle a key-press event: update control flags and forward to the binding.

    :param key: integer keycode of the pressed key (e.g. 0xff0d for Enter)
    :param mod: modifier flags supplied by the window toolkit; forwarded as-is
    :param state: shared game state dict; mutates 'restart' and 'pause'
    """
    if key == 0xff0d:  # Enter requests a restart of the current rollout
        state['restart'] = True
    if key == 32:  # Space toggles pause on/off
        state['pause'] = not state['pause']
    # Let the game-specific binding translate the key into an action.
    state['binding'].key_press(key, mod)
def key_release(key: int, mod, state: Dict):
    """Handle a key-release event by forwarding it to the active binding.

    :param key: integer keycode of the released key
    :param mod: modifier flags supplied by the window toolkit; forwarded as-is
    :param state: shared game state dict holding the active 'binding'
    """
    state['binding'].key_release(key, mod)
def rollout(env, state: Dict):
    """Play one episode, logging state-action-reward rows and rendering frames.

    Advances the environment for up to state['rollout_time'] steps, choosing
    either random actions or the keyboard-driven state['action'], repeating
    each chosen action for state['skip_control'] extra steps. On episode end
    the accumulated rows are written out via write_sar_log.

    :param env: gym environment (reset/step/render interface)
    :param state: shared game state dict (flags, action, config)
    """
    state['restart'] = False
    env.reset()  # initial observation unused; per-step observations are logged instead
    skip = 0
    episode_reward = 0
    sar = []  # flattened (observation, action, reward) rows for this episode
    for t in range(state['rollout_time']):
        if not skip:
            # Pick a fresh action, then repeat it for skip_control more steps.
            if state['random']:
                action = random.randint(0, state['actions'] - 1)
            else:
                action = state['action']
            skip = state['skip_control']
        else:
            skip -= 1
        observation, reward, done, _ = env.step(action)
        sar.append(np.hstack((np.ravel(observation), action, reward)))
        episode_reward += reward
        if done:
            print('Episode finished after %d timesteps with reward %d' % (
                t, episode_reward))
            write_sar_log(sar, state['logdir'], episode_reward)
            break
        if state['restart']:
            break
        env.render()
        # Keep the window alive while paused; key_press toggles 'pause' off.
        while state['pause']:
            env.render()
            time.sleep(0.1)
def write_sar_log(sars: List, logdir: str, episode_reward: int):
    """Persist one episode's state-action-reward rows as a compressed .npz file.

    :param sars: list of flattened (observation, action, reward) row arrays
    :param logdir: directory the archive is written into
    :param episode_reward: total episode reward, embedded in the filename
    """
    # Low-order digits of the wall clock keep successive filenames distinct.
    stamp = str(time.time())[-5:]
    destination = os.path.join(logdir, '%s_%s' % (stamp, episode_reward))
    stacked = np.vstack(sars)
    np.savez_compressed(destination, stacked)
def main():
    """Main runnable"""
    # Parse CLI options declared in the module docstring (docopt usage block).
    arguments = docopt.docopt(__doc__)
    env_id = arguments['--env_id']
    logdir = arguments['--logdir']
    n_episodes = int(arguments['--n_episodes'])
    env = gym.make(env_id)
    # Fixed seed so --random runs are reproducible across invocations.
    random.seed(0)
    os.makedirs(logdir, exist_ok=True)
    # Shared mutable state read and written by the key handlers and rollout().
    state = {
        'action': 0,
        'restart': False,
        'pause': False,
        'skip_control': int(arguments['--skip-control']),
        'rollout_time': int(arguments['--rollout-time']),
        'random': arguments['--random'],
        'env': env,
        'logdir': logdir
    }
    # Discrete action spaces expose `n`; anything else can't be keyboard-driven.
    if not hasattr(env.action_space, 'n'):
        raise Exception('Keyboard agent only supports discrete action spaces')
    state['actions'] = env.action_space.n
    # Setup custom key configuration if available
    if 'SpaceInvaders' in env_id:
        state['binding'] = bindings.SpaceInvadersBinding(state)
    else:
        state['binding'] = bindings.DefaultBinding(state)
    # Render once so the viewer window exists before handlers are attached.
    env.render()
    if arguments['--random']:
        print(' * Using random agent.')
    else:
        # Human player: hook keyboard events into the render window so
        # key_press/key_release can update the shared state.
        env.unwrapped.viewer.window.on_key_press = lambda key, mod: \
            key_press(key, mod, state)
        env.unwrapped.viewer.window.on_key_release = lambda key, mod: \
            key_release(key, mod, state)
        state['binding'].print_instructions()
    for _ in range(n_episodes):
        rollout(env, state)
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()