In [1]:
import numpy as np
import gym
import jax
import jax.numpy as jnp
from jax import random


from utils import experience




In [2]:
key = random.PRNGKey(42)

## test `Accumulator`

In [3]:
def random_transition(rng_key):
  keys = random.split(rng_key, 5)
  action = random.normal(keys[0])
  step_type = random.choice(keys[1], np.array([0,1]), (), p=np.array([0.8,0.2]))
  obsv = random.normal(keys[2], (5,))
  reward = random.normal(keys[3])
  discount = random.normal(keys[4])
  timestep = experience.TimeStep(step_type, obsv, reward, discount)
  return action, timestep

In [4]:
num_trial = 20
acc = experience.Accumulator(100,10)
for _ in range(num_trial):
  key, rng_key = random.split(key)
  acc.push(*random_transition(rng_key))

In [11]:
if acc.len_ep():
  key, rng_key = random.split(key)
  ep = acc.sample_one_ep(rng_key=rng_key)
  a_tm1, timesteps = ep
  print(a_tm1)
  print(timesteps.obsv.shape)
  print(timesteps.reward)

[-0.5247943   1.2919614   0.17375942 -0.16435573 -0.4007549   0.8670506
  0.68050987 -0.55330265  0.8703184   1.5482894 ]
(10, 5)
[-1.0339216  -0.37174553 -0.6897025   1.2958331   0.12408566  0.94940984
 -0.36921528  0.1504747   1.0099529   0.1570648 ]


### test `Accumulator` with `gym` environment

#### Blackjack-v1

In [54]:
env = gym.make('Blackjack-v1')

In [56]:
acc = experience.Accumulator(100,10)
num_ep = 10

for _ in range(num_ep):
    discount = 1
    gamma = 0.9 # could be gamma*lambda*lho
    observation = env.reset()
    acc.push(None, experience.TimeStep(obsv = np.array(observation)))
    for _ in range(100):
        action = env.action_space.sample()
        # print(observation)
        observation, reward, done, info = env.step(action)
        acc.push(action, experience.TimeStep(
            int(done),
            np.array(observation),
            reward,
            discount
        ))
        discount*=gamma
        if done:
            break

In [64]:
if acc.len_ep():
  key, rng_key = random.split(key)
  ep = acc.sample_one_ep(rng_key=rng_key)
  a_tm1, timesteps = ep
  print("actions", a_tm1)
  print("observations'", timesteps.obsv)
  print("rewards", timesteps.reward)

actions [0 0]
observations' [[10  4  0]
 [10  4  0]]
rewards [ 0. -1.]


#### CartPole

In [65]:
env = gym.make('CartPole-v1')

In [68]:
acc = experience.Accumulator(max_t = 100,max_ep = 10)
num_ep = 10

for _ in range(num_ep):
    discount = 1
    gamma = 0.9 # could be gamma*lambda*lho
    observation = env.reset()
    acc.push(None, experience.TimeStep(obsv = np.array(observation)))
    for _ in range(100):
        action = env.action_space.sample()
        # print(observation)
        observation, reward, done, info = env.step(action)
        acc.push(action, experience.TimeStep(
            int(done),
            np.array(observation),
            reward,
            discount
        ))
        discount*=gamma
        if done:
            break

In [75]:
if acc.len_ep():
  key, rng_key = random.split(key)
  ep = acc.sample_one_ep(rng_key=rng_key)
  a_tm1, timesteps = ep
  print("actions", a_tm1)
  print("observations' shape", timesteps.obsv.shape)
  print("cartpole angles", timesteps.obsv[:, 2]) # terminate when out-of-screen or |angle|> .21 rad
  print("rewards", timesteps.reward)

actions [0 1 1 0 1 0 1 0 0 1 0 1 1 1 0 0 0 1 1 0 0 1 0 1 1]
observations' shape (25, 4)
cartpole angles [ 0.01291977  0.01277802  0.0068647  -0.00482113 -0.01061021 -0.02228329
 -0.02816993 -0.0400491  -0.04625471 -0.04686455 -0.0536126  -0.05480978
 -0.06218904 -0.07575729 -0.09555745 -0.10999842 -0.11921654 -0.12331293
 -0.13396502 -0.15119417 -0.16346852 -0.17091283 -0.18514588 -0.19469127
 -0.21113361]
rewards [0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1.]


In [80]:
len(acc._episodes[1,2])

TypeError: sequence index must be integer, not 'tuple'

# Random stuff

In [3]:
from value_prediction import approximator
from value_prediction import td

In [4]:
print(td.td_zero_error(0, 1, 1, 1))
# jitted_nstep_td_errors = jax.jit(td.nstep_td_errors, static_argnums=(0,))
# print(jitted_nstep_td_errors(3, jnp.zeros(4), jnp.array([1, 2, 3, 4]), jnp.ones(4), .5))

print(td.nstep_td_errors(3, jnp.zeros(4), jnp.array([1, 2, 3, 4]), jnp.ones(4), .5))

2
[2.875 4.625]


In [6]:
%timeit td.nstep_td_errors(3, jnp.zeros(4), jnp.array([1, 2, 3, 4]), jnp.ones(4), .5)
# %timeit jitted_nstep_td_errors(3, jnp.zeros(3), jnp.array([1,1,1]), jnp.zeros(3), 1)

844 µs ± 15.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
appr = approximator.tabularApproximator([jnp.arange(0,2), jnp.arange(-2,2)])

In [5]:
appr.v([0,0])

(2,)
(2,)


DeviceArray(0., dtype=float32)

In [20]:
x = jnp.array([[0,1,2],[3,4,5]])
x

DeviceArray([[0, 1, 2],
             [3, 4, 5]], dtype=int32)

In [34]:
x[tuple(x[0,:2])]

DeviceArray(1, dtype=int32)

In [33]:
tuple(x[0,:2])

(DeviceArray(0, dtype=int32), DeviceArray(1, dtype=int32))