-
Notifications
You must be signed in to change notification settings - Fork 0
/
atari_util.py
58 lines (51 loc) · 2.22 KB
/
atari_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""Auxiliary utilities for those who want to solve Breakout with CEM or policy gradient."""
import numpy as np
import gym
from scipy.misc import imresize
from gym.core import Wrapper
from gym.spaces.box import Box
class PreprocessAtari(Wrapper):
    """Gym wrapper that crops, resizes and rescales frames and stacks the most recent ones.

    Observations are float32 frame stacks. The newest frame occupies the first
    channel slot(s) of the stack; older frames follow and the oldest is dropped
    on every update.

    NOTE(review): ``imresize`` comes from ``scipy.misc``, which no longer
    provides it since SciPy 1.3 -- this module only imports on older SciPy.
    """

    def __init__(self, env, height=42, width=42, color=False,
                 crop=lambda img: img, n_frames=4, dim_order='theano'):
        """A gym wrapper that reshapes, crops and scales image into the desired shapes.

        :param env: the gym environment to wrap.
        :param height: height of each processed frame, in pixels.
        :param width: width of each processed frame, in pixels.
        :param color: if True keep 3 color channels per frame, otherwise
            average the channels into a single one.
        :param crop: callable applied to each raw frame before resizing.
        :param n_frames: number of most recent frames kept in the stack (>= 1).
        :param dim_order: 'theano' for channels-first ``[c, h, w]`` observations,
            'tensorflow' for channels-last ``[h, w, c]``.
        :raises ValueError: if ``n_frames`` is smaller than 1.
        """
        super(PreprocessAtari, self).__init__(env)
        assert dim_order in ('theano', 'tensorflow')
        if n_frames < 1:
            # With zero frames the buffer would still hold one frame after the
            # first update, disagreeing with the advertised observation_space.
            raise ValueError("n_frames must be >= 1, got %r" % (n_frames,))
        self.img_size = (height, width)
        self.crop = crop
        self.color = color
        self.dim_order = dim_order

        n_channels = (3 * n_frames) if color else n_frames
        if dim_order == 'theano':
            obs_shape = (n_channels, height, width)
        else:
            obs_shape = (height, width, n_channels)
        self.observation_space = Box(0.0, 1.0, obs_shape)
        self.framebuffer = np.zeros(obs_shape, 'float32')

    def reset(self, **kwargs):
        """Reset the wrapped env and return the initial frame stack.

        Keyword arguments (e.g. a seed, on gym versions that accept one) are
        forwarded to ``env.reset``. After a reset every slot except the newest
        frame is zero.
        """
        self.framebuffer = np.zeros_like(self.framebuffer)
        self.update_buffer(self.env.reset(**kwargs))
        return self.framebuffer

    def step(self, action):
        """Play one step in the wrapped env.

        :returns: ``(framebuffer, reward, done, info)`` where ``framebuffer`` is
            the updated frame stack and the other fields come from ``env.step``.
        """
        new_img, r, done, info = self.env.step(action)
        self.update_buffer(new_img)
        return self.framebuffer, r, done, info

    ### image processing ###

    def update_buffer(self, img):
        """Preprocess ``img`` and push it onto the front of the frame stack,
        discarding the oldest frame."""
        img = self.preproc_image(img)
        offset = 3 if self.color else 1  # channels occupied by one frame
        if self.dim_order == 'theano':
            axis = 0
            older_frames = self.framebuffer[:-offset]
        else:
            axis = -1
            older_frames = self.framebuffer[:, :, :-offset]
        # np.concatenate allocates a fresh array, so stacks already returned to
        # the caller are not mutated by later steps.
        self.framebuffer = np.concatenate([img, older_frames], axis=axis)

    def preproc_image(self, img):
        """Crop, resize and rescale one raw frame.

        :returns: a float32 array shaped ``[c, h, w]`` ('theano') or
            ``[h, w, c]`` ('tensorflow'), with ``c`` = 3 if ``color`` else 1.
            Pixel values are divided by 255, giving [0, 1] for uint8 frames.
        """
        img = self.crop(img)
        img = imresize(img, self.img_size)
        if img.ndim == 2:
            # Single-channel frame: add an explicit channel axis so the channel
            # reduction / transpose below act on channels rather than on width.
            img = img[:, :, np.newaxis]
            if self.color:
                # Replicate gray into 3 channels to match the color buffer layout.
                img = np.repeat(img, 3, axis=-1)
        if not self.color:
            img = img.mean(-1, keepdims=True)
        if self.dim_order == 'theano':
            img = img.transpose([2, 0, 1])  # [h, w, c] to [c, h, w]
        img = img.astype('float32') / 255.
        return img