In [6]:
# need a virtual display for rendering in docker
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()


<pyvirtualdisplay.display.Display at 0x7f63dc66bf40>

In [54]:
import jax
import jax.numpy as jnp
import gymnasium as gym
import optax
from functools import partial 

env = gym.make('CartPole-v1', render_mode="rgb_array").env

# Define the dynamics model
def dynamics_model(state, action, params):
    A, B = params
    return jnp.dot(A, state) + jnp.dot(B, action)

# Define the policy


def policy(state, params):
    W, b = params
    logits = jnp.dot(W, state) + b
    probs = jax.nn.softmax(logits)
    action = jax.random.categorical(jax.random.PRNGKey(0), logits)
    return action, probs

# Define the loss function


def loss(params, rollout):
    A, B, W, b = params
    states, actions, rewards, next_states = rollout

    # Compute predicted next states using the dynamics model
    predicted_next_states = jax.vmap(dynamics_model, in_axes=(0, 0, None))(
        states, actions, (A, B))

    # Compute log likelihood of actions taken by the policy
    _, action_probs = jax.vmap(policy, in_axes=(0, None))(states, (W, b))
    log_probs = jnp.log(
        jnp.sum(action_probs * jax.nn.one_hot(actions, 2), axis=-1))

    # Compute the loss
    loss = -jnp.mean(log_probs * rewards)

    return loss


# Define the optimizer
optimizer = optax.adam(learning_rate=1e-3)

# Initialize parameters
key = jax.random.PRNGKey(0)
A = jax.random.normal(key, shape=(4, 4))
B = jax.random.normal(key, shape=(4,))
W = jax.random.normal(key, shape=(2, 4))
b = jax.random.normal(key, shape=(2,))

params = (A, B, W, b)

# Define the rollout function


def rollout(params, env):
    states = []
    actions = []
    rewards = []
    next_states = []

    A, B, W, b = params
    state, _ = env.reset(seed=0)
    done = False

    while not done:
        action, _ = policy(state, (W,b))
        next_state, reward, done, _, _ = env.step(int(action))
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        state = next_state

    return (jnp.stack(states), jnp.array(actions), jnp.array(rewards), jnp.stack(next_states))



In [67]:
state, _ = env.reset()
print(len(state))
action, probs = policy(state, (W, b))
dynamics_model(state, action, (A, B))
rollout(params, env)
params

4


(Array([[ 0.08482574,  1.9097648 ,  0.29561743,  1.120948  ],
        [ 0.33432344, -0.82606775,  0.6481277 ,  1.0434873 ],
        [-0.7824839 , -0.4539462 ,  0.6297971 ,  0.81524646],
        [-0.32787678, -1.1234448 , -1.6607416 ,  0.27290547]],      dtype=float32),
 Array([ 1.8160863 , -0.75488514,  0.33988908, -0.53483534], dtype=float32),
 Array([[-1.1897566 , -1.3263226 ,  0.91276866,  2.7610164 ],
        [-0.00519618,  3.0592732 , -2.1466362 ,  0.03855126]],      dtype=float32),
 Array([-2.101969 ,  2.1736479], dtype=float32))

In [58]:
# add some test

rollout_fn = jax.tree_util.Partial(rollout, params)
rollout_fn(env)

(Array([[ 0.01369617, -0.02302133, -0.04590265, -0.04834723],
        [ 0.01323574,  0.17272775, -0.04686959, -0.3551522 ],
        [ 0.0166903 ,  0.3684837 , -0.05397264, -0.66223824],
        [ 0.02405997,  0.5643134 , -0.0672174 , -0.97141534],
        [ 0.03534624,  0.76026994, -0.08664571, -1.2844334 ],
        [ 0.05055164,  0.95638156, -0.11233438, -1.6029392 ],
        [ 0.06967927,  1.1526395 , -0.14439316, -1.9284277 ],
        [ 0.09273206,  1.3489841 , -0.18296172, -2.262184  ]],      dtype=float32),
 Array([1, 1, 1, 1, 1, 1, 1, 1], dtype=int32),
 Array([1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32),
 Array([[ 0.01323574,  0.17272775, -0.04686959, -0.3551522 ],
        [ 0.0166903 ,  0.3684837 , -0.05397264, -0.66223824],
        [ 0.02405997,  0.5643134 , -0.0672174 , -0.97141534],
        [ 0.03534624,  0.76026994, -0.08664571, -1.2844334 ],
        [ 0.05055164,  0.95638156, -0.11233438, -1.6029392 ],
        [ 0.06967927,  1.1526395 , -0.14439316, -1.9284277 ],
      

In [78]:

# Train the model
states_opt = optimizer.init(params)
for i in range(100):
    key, subkey = jax.random.split(key)
    rollout_fn = partial(rollout, params)
    
    # this seems not working as [env] is not treated as jnp object
    #rollout_batch = jax.vmap(rollout_fn)(
    #    jax.random.split(subkey, 10), [env] * 10)
    rollout_batch = rollout_fn(env)

    rollout_batch = rollout_fn(env)
    grads = jax.grad(loss)(params, rollout_batch)
    updates, states_opt = optimizer.update(grads, states_opt, params)
    params = optax.apply_updates(params, updates)

    if i % 100 == 0:
        print(f'Iteration {i}, loss={loss(params, rollout_batch):.4f}')


Iteration 0, loss=0.0017


In [79]:
params
len(params)
params[2:]


(Array([[-1.2857947 , -1.4179808 ,  1.0088406 ,  2.8559885 ],
        [ 0.09084299,  3.1509333 , -2.2427087 , -0.05642055]],      dtype=float32),
 Array([-2.198084,  2.269763], dtype=float32))

In [80]:
# Test the model
state, _ = env.reset()
done = False
while not done:
    action, _ = policy(state, params[2:])
    state, _, done, _, _ = env.step(int(action))
    env.render()

env.close()
