In [1]:
import numpy as np
import gym
from gym.spaces import Discrete,Box

# -------------------------------------------
# Policies
# -------------------------------------------

class DeterministicDiscreteActionLinearPolicy(object):
	"""
	Deterministic linear policy for a discrete action space.

	Every action gets an affine score of the observation (ob.W + b) and the
	action with the highest score is chosen.
	"""

	def __init__(self,theta,ob_space,ac_space):
		"""
		theta: flat vector of parameters (any array-like) of length
			(dim_ob + 1) * n_actions. The first dim_ob * n_actions entries
			are the weights, laid out row-major as (dim_ob, n_actions);
			the last n_actions entries are the biases.
		ob_space: observation space; only ob_space.shape[0] (dim_ob) is read
		ac_space: action space; only ac_space.n (n_actions) is read
		"""
		# Lists/tuples have no .reshape, so convert once up front.
		theta = np.asarray(theta)
		dim_ob = ob_space.shape[0]
		n_actions = ac_space.n
		n_weights = dim_ob * n_actions
		assert len(theta) == n_weights + n_actions, \
			"theta has length %d, expected %d" % (len(theta), n_weights + n_actions)
		self.W = theta[:n_weights].reshape(dim_ob,n_actions)
		self.b = theta[n_weights:].reshape(1,n_actions)

	def act(self,ob):
		"""
		Return the index of the highest-scoring action for observation ob.

		ob is expected to be a single observation of length dim_ob: the
		argmax is taken over the flattened score array, so it is only an
		action index when there is exactly one row of scores.
		"""
		y = np.asarray(ob).dot(self.W) + self.b
		a = y.argmax()
		return a

class DeterministicContinuousActionLinearPolicy(object):
	"""
	Deterministic linear policy for a continuous (Box) action space.

	The action is the affine map ob.W + b, clipped element-wise to the
	bounds of the action space.
	"""

	def __init__(self, theta,ob_space,ac_space):
		"""
		theta: flat vector of parameters (any array-like) of length
			(dim_ob + 1) * dim_ac. The first dim_ob * dim_ac entries are
			the weights, laid out row-major as (dim_ob, dim_ac); the last
			dim_ac entries are the biases.
		ob_space: observation space; only ob_space.shape[0] (dim_ob) is read
		ac_space: action space; ac_space.shape[0] (dim_ac) is read here and
			ac_space.low / ac_space.high are used for clipping in act()
		"""
		self.ac_space = ac_space
		# Lists/tuples have no .reshape, so convert once up front.
		theta = np.asarray(theta)
		dim_ob = ob_space.shape[0]
		dim_ac = ac_space.shape[0]
		n_weights = dim_ob * dim_ac
		assert len(theta) == n_weights + dim_ac, \
			"theta has length %d, expected %d" % (len(theta), n_weights + dim_ac)
		self.W = theta[:n_weights].reshape(dim_ob,dim_ac)
		self.b = theta[n_weights:]

	def act(self,ob):
		"""
		Return the action for observation ob: ob.W + b clipped to
		[ac_space.low, ac_space.high].
		"""
		a = np.clip(np.asarray(ob).dot(self.W) + self.b,self.ac_space.low,self.ac_space.high)
		return a

def do_episode(policy,env,num_steps,render=False):
	"""
	Roll out one episode of `policy` in `env` and return the summed reward.

	The episode runs for at most `num_steps` steps and stops early as soon
	as the environment reports it is done. When `render` is true the
	environment is rendered on every third step (steps 0, 3, 6, ...).
	"""
	episode_return = 0
	observation = env.reset()
	for step in range(num_steps):
		action = policy.act(observation)
		step_result = env.step(action)
		observation, step_reward, finished = step_result[0], step_result[1], step_result[2]
		episode_return += step_reward
		# Rendering happens before the termination check, so the final
		# step is still drawn when it falls on a multiple of three.
		if render and step % 3 == 0:
			env.render()
		if finished:
			break
	return episode_return

# Module-level environment handle; starts out unset.
env = None
def noisy_evaluation(theta):
	"""
	Score the parameter vector `theta`: build its policy, run one episode
	in the module-level `env` for up to `num_steps` steps, and return the
	total reward of that episode.
	"""
	return do_episode(make_policy(theta),env,num_steps)

def make_policy(theta):
	"""
	Build the linear policy matching the action space of the module-level
	`env`, parameterised by the flat vector `theta`.

	Discrete action spaces get the discrete-action policy, Box action spaces
	the continuous-action one; anything else raises NotImplementedError.
	"""
	action_space = env.action_space
	if isinstance(action_space,Discrete):
		policy_cls = DeterministicDiscreteActionLinearPolicy
	elif isinstance(action_space,Box):
		policy_cls = DeterministicContinuousActionLinearPolicy
	else:
		raise NotImplementedError
	return policy_cls(theta,env.observation_space,action_space)

# ---- Task settings ----
name = 'DistortionBoxCartPole-v0'
env = gym.make(name)
outdir = '/tmp/' + name + '-results'
# force=True lets the monitor reuse an existing results directory.
env.monitor.start(outdir,force=True)

num_steps = 500  # maximum length of an episode

# ---- Algorithm settings ----
n_iter = 100      # number of CEM iterations
batch_size = 25   # parameter vectors sampled per iteration
elite_frac = 0.2  # fraction of each batch kept as the elite set

# One weight per (observation dimension, policy output) pair plus one bias
# per policy output; the number of outputs depends on the action space type.
if isinstance(env.action_space,Discrete):
	_n_outputs = env.action_space.n
elif isinstance(env.action_space,Box):
	_n_outputs = env.action_space.shape[0]
else:
	raise NotImplementedError
dim_theta = (env.observation_space.shape[0] + 1) * _n_outputs

# Initial sampling distribution over parameters: zero mean, unit std,
# plus a small constant extra std.
theta_mean = np.zeros((dim_theta,))
theta_std = np.ones((dim_theta,))
extra_std = np.full((dim_theta,), 0.001)

std_decay_time = -1

# Now, for the algorithm: cross-entropy method (CEM).
# Each iteration samples a batch of parameter vectors from a diagonal
# Gaussian, evaluates each with one episode, and refits the Gaussian's
# mean and std to the best-scoring ("elite") samples.

# Size of the elite set is the same for every iteration.
n_elite = int(batch_size * elite_frac)

try:
	for iteration in range(n_iter):
		# Sample parameter vectors.
		# The extra variance decays linearly to zero over the first half
		# of the run; 2.0 keeps the division floating-point on Python 2.
		extra_var_multiplier = max(1.0 - iteration / (n_iter / 2.0), 0)
		# Variances add, so square theta_std before combining it with the
		# (squared) extra std and taking the root.
		sample_std = np.sqrt(np.square(theta_std) + np.square(extra_std) * extra_var_multiplier)
		# One row per sampled parameter vector.
		thetas = theta_mean + sample_std * np.random.randn(batch_size,dim_theta)
		rewards = [noisy_evaluation(theta) for theta in thetas]
		# Get elite parameters: the n_elite samples with the highest reward.
		elite_inds = np.argsort(rewards)[batch_size - n_elite:]
		elite_thetas = thetas[elite_inds]
		# Update theta_mean,theta_std
		theta_mean = np.mean(elite_thetas,axis = 0)
		theta_std = np.std(elite_thetas,axis = 0)
		print("iteration %i. mean f: %8.3g. max f: %8.3g" % (iteration,np.mean(rewards),np.max(rewards)))
		# Show the current mean policy.
		do_episode(make_policy(theta_mean),env,num_steps,render=True)
finally:
	# Close the monitor even if an episode raises, so results are flushed.
	env.monitor.close()

print("theta_mean: %s" % (theta_mean,))

[2016-06-05 21:14:02,011] Making new env: DistortionBoxCartPole-v0
[2016-06-05 21:14:02,022] Clearing 4 monitor files from previous run (because force=True was provided)
[2016-06-05 21:14:02,030] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000000.mp4
[2016-06-05 21:14:03,296] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000001.mp4
[2016-06-05 21:14:03,873] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000008.mp4
[2016-06-05 21:14:05,157] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000027.mp4


iteration 0. mean f:     14.6. max f:       43
iteration 1. mean f:     28.4. max f:      166


[2016-06-05 21:14:06,243] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000064.mp4


iteration 2. mean f:     28.9. max f:       77
iteration 3. mean f:     34.8. max f:      110


[2016-06-05 21:14:07,871] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000125.mp4


iteration 4. mean f:     27.4. max f:       74
iteration 5. mean f:     41.6. max f:      109


[2016-06-05 21:14:09,371] Observation '[-2.41680807 -1.27361197 -0.10843459 -0.34910011]' is not contained within observation space 'Box(4,)'.


iteration 6. mean f:       50. max f:      157
iteration 7. mean f:     57.2. max f:      214


[2016-06-05 21:14:12,242] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000216.mp4


iteration 8. mean f:      102. max f:      500
iteration 9. mean f:      177. max f:      500
iteration 10. mean f:      152. max f:      500
iteration 11. mean f:      177. max f:      500
iteration 12. mean f:      242. max f:      500


[2016-06-05 21:14:42,037] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000343.mp4


iteration 13. mean f:      275. max f:      500
iteration 14. mean f:      291. max f:      500
iteration 15. mean f:      290. max f:      500
iteration 16. mean f:      301. max f:      500
iteration 17. mean f:      297. max f:      500
iteration 18. mean f:      365. max f:      500


[2016-06-05 21:14:47,701] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000512.mp4


iteration 19. mean f:      375. max f:      500
iteration 20. mean f:      415. max f:      500
iteration 21. mean f:      365. max f:      500
iteration 22. mean f:      377. max f:      500
iteration 23. mean f:      308. max f:      500
iteration 24. mean f:      243. max f:      500
iteration 25. mean f:      252. max f:      500
iteration 26. mean f:      330. max f:      500


[2016-06-05 21:14:51,759] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video000729.mp4


iteration 27. mean f:      313. max f:      500
iteration 28. mean f:      298. max f:      500
iteration 29. mean f:      367. max f:      500
iteration 30. mean f:      438. max f:      500
iteration 31. mean f:      336. max f:      500
iteration 32. mean f:      380. max f:      500
iteration 33. mean f:      347. max f:      500
iteration 34. mean f:      313. max f:      500
iteration 35. mean f:      289. max f:      500
iteration 36. mean f:      353. max f:      500
iteration 37. mean f:      377. max f:      500


[2016-06-05 21:15:00,494] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video001000.mp4


iteration 38. mean f:      262. max f:      500
iteration 39. mean f:      413. max f:      500
iteration 40. mean f:      195. max f:      500
iteration 41. mean f:      211. max f:      500
iteration 42. mean f:      427. max f:      500
iteration 43. mean f:      355. max f:      500
iteration 44. mean f:      402. max f:      500
iteration 45. mean f:      414. max f:      500
iteration 46. mean f:      430. max f:      500
iteration 47. mean f:      467. max f:      500
iteration 48. mean f:      396. max f:      500
iteration 49. mean f:      434. max f:      500
iteration 50. mean f:      454. max f:      500
iteration 51. mean f:      451. max f:      500
iteration 52. mean f:      480. max f:      500
iteration 53. mean f:      476. max f:      500
iteration 54. mean f:      436. max f:      500
iteration 55. mean f:      480. max f:      500
iteration 56. mean f:      471. max f:      500
iteration 57. mean f:      435. max f:      500
iteration 58. mean f:      447. max f:  

[2016-06-05 21:15:25,074] Starting new video recorder writing to /tmp/DistortionBoxCartPole-v0-results/openaigym.video.0.34091.video002000.mp4


iteration 76. mean f:      412. max f:      500
iteration 77. mean f:      429. max f:      500
iteration 78. mean f:      472. max f:      500
iteration 79. mean f:      469. max f:      500
iteration 80. mean f:      487. max f:      500
iteration 81. mean f:      468. max f:      500
iteration 82. mean f:      441. max f:      500
iteration 83. mean f:      444. max f:      500
iteration 84. mean f:      404. max f:      500
iteration 85. mean f:      500. max f:      500
iteration 86. mean f:      477. max f:      500
iteration 87. mean f:      473. max f:      500
iteration 88. mean f:      480. max f:      500
iteration 89. mean f:      433. max f:      500
iteration 90. mean f:      481. max f:      500
iteration 91. mean f:      480. max f:      500
iteration 92. mean f:      480. max f:      500
iteration 93. mean f:      480. max f:      500
iteration 94. mean f:      500. max f:      500
iteration 95. mean f:      480. max f:      500
iteration 96. mean f:      468. max f:  

[2016-06-05 21:15:49,654] Finished writing results. You can upload them to the scoreboard via gym.upload('/tmp/DistortionBoxCartPole-v0-results')


iteration 99. mean f:      443. max f:      500
theta_mean: [ -2.03004586   1.11180758  -3.8517514    3.65156726 -10.41248821
   2.77791355  -4.34535918   2.68476674  -0.99304803  -0.60556377]
