In [2]:
%load_ext autoreload
%autoreload 2

import jax.numpy as jnp
import numpy as np
import optax
import random
import gym
import jax.random as jrandom
from jax import vmap, jit
import matplotlib.pyplot as plt

from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

from replay_buffer import ReplayBuffer
from constants import SEED
from policy import policy_grad, lr_gradients, get_sequence_rewards
from pilco_utils import get_trajectories, rollout_episode
from rff import phi_X, phi_X_batch
from trans_model import prior, train_transition_models, predict, lklhood_grad, marg_lklhood



In [3]:
env = gym.make('InvertedPendulum-v4')
state_dim = env.observation_space.shape[0]

  "Agent's minimum observation space value is -infinity. This is probably too low."
  "Agent's maxmimum observation space value is infinity. This is probably too high"
  "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "


### Run random policy to collect training data

In [4]:
INIT_EPISODES = 500
BUFFER_CAPACITY = 1000
HORIZON = 15

replay_buffer = ReplayBuffer(capacity=BUFFER_CAPACITY)

for ep in range(INIT_EPISODES):
    cur_state = env.reset()
    done = False
    
    for t in range(HORIZON):
        action = np.array([random.uniform(-1., 1.)])
        prev_state = cur_state

        cur_state, _, done, _ = env.step(action)
        replay_buffer.push(prev_state, action.squeeze(), cur_state)


print(len(replay_buffer))

7500


### Compute posterior distribution for transition models

In [41]:
theta = jnp.full((state_dim + 1,), .9)
# BETAS = jnp.full((4,), 20.)
BETAS = jnp.load('optimal_beta.npy')
ALPHA = .3
N = 1000
NUM_FEATURES = 1000 
# lengthscales = jnp.asarray(np.load('optimal_lengthscales.npy'))
lengthscales = jnp.full((4, NUM_FEATURES, 1), .6) 
# coefs = jnp.full_like(lengthscales, 0.3)
coefs = jnp.load('optimal_coefs.npy')

model_d1 = prior(NUM_FEATURES, alpha=ALPHA)
model_d2 = prior(NUM_FEATURES, alpha=ALPHA)
model_d3 = prior(NUM_FEATURES, alpha=ALPHA)
model_d4 = prior(NUM_FEATURES, alpha=ALPHA)

trans_models_pre = [model_d1, model_d2, model_d3, model_d4]

trans_models_post = train_transition_models(
    replay_buffer, BETAS, trans_models_pre, NUM_FEATURES, lengthscales, coefs, SEED
)

### Optimise Transition Model Parameters

In [42]:
key = SEED
cur_keys = jrandom.split(key, num=3)
N = 1000
SEQ_LEN = 20
MODEL_NOISE = 0.1

transitions = replay_buffer.memory[:SEQ_LEN]
states = jnp.array(list(map(lambda t: t.state, transitions)))
actions = jnp.array(list(map(lambda t: t.action, transitions)))
next_states = jnp.array(list(map(lambda t: t.next_state, transitions)))

params = jnp.array([BETAS])
optimizer = optax.chain(
    optax.adam(learning_rate=0.05),
    # optax.scale(-1.0)
)
opt_state = optimizer.init(params)

for i in range(100):
    trans = jrandom.normal(cur_keys[0], shape=(N, SEQ_LEN, 4))
    omega = jrandom.normal(cur_keys[1], shape=(N, NUM_FEATURES, 2))
    phi = jrandom.uniform(cur_keys[2], minval=0, maxval=2 * jnp.pi, shape=(N, NUM_FEATURES, 1))

    grads = lklhood_grad(states, NUM_FEATURES, lengthscales, coefs, params[0], MODEL_NOISE, 4, actions, next_states, trans, omega, phi, *trans_models_post)

    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    if i % 5 == 0:
        objective = marg_lklhood(states, NUM_FEATURES, lengthscales, coefs, params[0], MODEL_NOISE, actions, next_states, trans, omega, phi, *trans_models_post)
        print(f"Step {i}, objective: {objective}")

np.save('optimal_beta.npy', np.asarray(params[0]))

Step 0, objective: 371.58062744140625
Step 5, objective: 369.9600830078125
Step 10, objective: 368.3751525878906
Step 15, objective: 366.8271484375
Step 20, objective: 365.3169250488281
Step 25, objective: 363.8446960449219
Step 30, objective: 362.410400390625
Step 35, objective: 361.01336669921875
Step 40, objective: 359.6528625488281


### Test Next State Predictions

In [None]:
lengthscales = jnp.asarray(np.load('optimal_lengthscales.npy'))

start_state = env.reset()
action = jnp.array([.5])
N = 1000

@jit
def pred(start_st, trans_eps, state_eps, omega, phi, m_d1, m_d2, m_d3, m_d4):
    model_input = jnp.stack([start_st, jnp.full((4,), action)]).T
    input = phi_X_batch(model_input, NUM_FEATURES, lengthscales, coefs, omega, phi)

    means = jnp.concatenate([m_d1[0], m_d2[0], m_d3[0], m_d4[0]])
    covs = jnp.zeros((means.shape[0], means.shape[0]))
    covs = covs.at[:NUM_FEATURES, :NUM_FEATURES].set(m_d1[1])
    covs = covs.at[NUM_FEATURES:NUM_FEATURES * 2, NUM_FEATURES:NUM_FEATURES * 2].set(m_d2[1])
    covs = covs.at[NUM_FEATURES * 2:NUM_FEATURES * 3, NUM_FEATURES * 2:NUM_FEATURES * 3].set(m_d3[1])
    covs = covs.at[NUM_FEATURES * 3:NUM_FEATURES * 4, NUM_FEATURES * 3:NUM_FEATURES * 4].set(m_d4[1])
    d1, d2, d3, d4 = predict(means, covs, BETA, input.reshape(-1), trans_eps)

    next_mean = jnp.array([d1, d2, d3, d4]) + start_st
    next_state = next_mean + state_eps * 0.3
    return next_state

foo = vmap(pred, (None, 0, 0, 0, 0, None, None, None, None))

keys = jrandom.split(SEED, num=4)
state = jrandom.normal(keys[0], shape=(N, 4))
trans = jrandom.normal(keys[1], shape=(N, 4))
omega = jrandom.normal(keys[2], shape=(N, NUM_FEATURES, 2))
phi = jrandom.uniform(keys[3], minval=0, maxval=2 * jnp.pi, shape=(N, NUM_FEATURES,))


predictions_post = foo(jnp.zeros((4,)), state, trans, omega, phi, *trans_models_post).mean(axis=0)
print(f"Pred: {predictions_post}")
print(f"True: {env.step(action)[0]}")

In [None]:
transitions = replay_buffer.memory[:100]
states = jnp.array(list(map(lambda t: t.state, transitions)))
actions = jnp.array(list(map(lambda t: t.action, transitions)))
next_states = jnp.array(list(map(lambda t: t.next_state, transitions)))

key = SEED
num_states = 100
pred_next_states = np.zeros((num_states,))
xs = np.linspace(-.3, .3, num_states)

for i, start_state in enumerate(states):
    key, subkey = jrandom.split(key)
    keys = jrandom.split(subkey, num=4)

    state_eps = jrandom.normal(keys[0], shape=(N, 4))
    trans_eps = jrandom.normal(keys[1], shape=(N, 4))
    omega = jrandom.normal(keys[2], shape=(N, NUM_FEATURES, 2))
    phi = jrandom.uniform(keys[3], minval=0, maxval=2 * jnp.pi, shape=(N, NUM_FEATURES,))

    pred_st = foo(start_state, state_eps, trans_eps, omega, phi, *trans_models_post).mean(axis=0)[0]
    pred_next_states[i] = pred_st

    if i % 20 == 0:
        print(f"Step {i}")


In [None]:
plt.figure(figsize=(20, 6))
# plt.vlines(states[:, 0], -0.32, next_states[:, 0], color='orange', linestyle='dashed')
plt.scatter(states[:, 0], next_states[:, 0], color='orange', label="True transitions")
# plt.vlines(states[:, 0], -0.32, pred_next_states, color='cornflowerblue', linestyle='dashed')
plt.scatter(states[:, 0], pred_next_states, color='cornflowerblue', label="Predicted transitions")
plt.show()

### Compute Policy Gradients

In [None]:
NOISE = 0.1
N = 10000
eps = 0.0001

costs = np.zeros((2,))
key = SEED
for j, i in enumerate([-1, 1]):
    keys = jrandom.split(key, num=4)
    state_epsilons = jrandom.normal(key=keys[0], shape=(N, HORIZON, 4))
    trans_epsilons = jrandom.normal(key=keys[1], shape=(N, HORIZON, 4))
    omegas = jrandom.normal(key=keys[2], shape=(N, HORIZON, NUM_FEATURES, 2))
    phis = jrandom.uniform(key=keys[3], minval=0, maxval=2 * jnp.pi, shape=(N, HORIZON, NUM_FEATURES))

    theta = jnp.array([1., 1., 1., 1., 1 + i * eps])
    trajectories = get_trajectories(
        theta, BETA, env, *trans_models_post, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, state_epsilons, trans_epsilons, omegas, phis
        )
    avg_cost = vmap(get_sequence_rewards, (0,))(trajectories).mean()
    costs[j] = avg_cost
    print(avg_cost)

diff_grad = (costs[1] - costs[0]) / (2 * eps)
print(diff_grad)

In [8]:
theta = jnp.full((state_dim + 1,), 1.)
BETA = 1.
ALPHA = 0.3
N = 1000
HORIZON = 10
NUM_FEATURES = 1000
NOISE = 0.2
lengthscales = jnp.full((NUM_FEATURES,), .1)
coefs = jnp.full_like(lengthscales, 0.06)

model_d1 = prior(NUM_FEATURES, alpha=ALPHA)
model_d2 = prior(NUM_FEATURES, alpha=ALPHA)
model_d3 = prior(NUM_FEATURES, alpha=ALPHA)
model_d4 = prior(NUM_FEATURES, alpha=ALPHA)

trans_models = [model_d1, model_d2, model_d3, model_d4]

trans_models_post = train_transition_models(replay_buffer, BETA, trans_models, NUM_FEATURES, lengthscales, coefs, SEED)

model_d1, model_d2, model_d3, model_d4 = trans_models_post

keys = jrandom.split(SEED, num=7)
state_epsilons = jrandom.normal(key=keys[0], shape=(N, HORIZON, 4))
trans_epsilons = jrandom.normal(key=keys[1], shape=(N, HORIZON, 4))
omegas = jrandom.normal(key=keys[2], shape=(N, NUM_FEATURES, 2))
phis = jrandom.uniform(key=keys[3], minval=0, maxval=2 * jnp.pi, shape=(N, NUM_FEATURES))
extra_epsilons = jrandom.normal(key=keys[4], shape=(N, HORIZON, 4))
extra_omegas = jrandom.normal(key=keys[5], shape=(N, NUM_FEATURES, 2))
extra_phis = jrandom.uniform(key=keys[6], minval=0, maxval=2 * jnp.pi, shape=(N, NUM_FEATURES,))

# LR gradients
trajectories = get_trajectories(
    theta, BETA, env, *trans_models_post, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, state_epsilons, trans_epsilons, omegas, phis
    )
lr_grads = lr_gradients(
    theta, BETA, *trans_models_post, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, trajectories, trans_epsilons, omegas, phis
    )
print(f"LR: {lr_grads}")

# RP gradients
rp_grads = policy_grad(
    theta, BETA, *trans_models_post, env, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, state_epsilons, trans_epsilons, omegas, phis
)
print(f"RP: {rp_grads}")


LR: [-0.33891183 -0.05050422 -0.64287704  0.84783596  0.2908255 ]
RP: [ 0.00662371 -0.00318408 -0.00130731 -0.02781972  0.01692043]


### Particle-based PILCO

In [None]:
MAX_EPISODES = 50
NOISE = 0.1
HORIZON = 10
TEST_HORIZON = 50
N = 1000
TRAIN_LOOPS = 5
key = SEED

trans_models = trans_models_post

theta = jnp.full((state_dim + 1,), 0.)
params = jnp.array([theta])

optimizer = optax.chain(
    optax.adam(learning_rate=0.008),
    # optax.scale(-1.0)
)

opt_state = optimizer.init(params)

for ep in range(MAX_EPISODES):
    subkeys = jrandom.split(key, num=6)
    key = subkeys[0]

    state_epsilons = jrandom.normal(key=subkeys[1], shape=(N, HORIZON, 4))
    trans_epsilons = jrandom.normal(key=subkeys[2], shape=(N, HORIZON, 4))
    omegas = jrandom.normal(key=subkeys[3], shape=(N, HORIZON, NUM_FEATURES, 2))
    phis = jrandom.uniform(key=subkeys[4], minval=0, maxval=2 * jnp.pi, shape=(N, HORIZON, NUM_FEATURES))

    # Train model with all available data
    trans_models = train_transition_models(
        replay_buffer, BETA, trans_models, NUM_FEATURES, lengthscales, coefs, subkeys[5]
    )

    for _ in range(TRAIN_LOOPS):
        # Compute the policy gradient
        # trajectories = get_trajectories(theta, BETA, env, *trans_models, HORIZON, NOISE, state_epsilons, trans_epsilons)
        # grads = lr_gradients(theta, BETA, *trans_models, HORIZON, NOISE, trajectories, trans_epsilons)

        # RP gradients
        grads = policy_grad(
            theta,
            BETA,
            *trans_models,
            env,
            HORIZON,
            NUM_FEATURES,
            lengthscales,
            coefs,
            NOISE,
            state_epsilons,
            trans_epsilons,
            omegas,
            phis
        )

        params = jnp.array([theta])
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        theta = params[0]

    # Roll-out policy online for an episode and add to replay buffer
    cur_obj = rollout_episode(env, TEST_HORIZON, replay_buffer, theta)

    if ep % 1 == 0:
        print(f"Ep {ep}, objective: {cur_obj}, theta: {theta}")


final_obj = rollout_episode(env, TEST_HORIZON, replay_buffer, theta)
print(f"Final score: {final_obj}")


### Collect and organise training and testing data

In [None]:
train_X, train_y, test_X, test_y = replay_buffer.get_train_test_arrays()

train_X_d1 = train_X[:, :2]
train_X_d2 = train_X[:, 2:4]
train_X_d3 = train_X[:, 4:6]
train_X_d4 = train_X[:, 6:8]
test_X_d1 = test_X[:, :2]
test_X_d2 = test_X[:, 2:4]
test_X_d3 = test_X[:, 4:6]
test_X_d4 = test_X[:, 6:8]

train_y_d1 = train_y[:, 0]
train_y_d2 = train_y[:, 1]
train_y_d3 = train_y[:, 2]
train_y_d4 = train_y[:, 3]
test_y_d1 = test_y[:, 0]
test_y_d2 = test_y[:, 1]
test_y_d3 = test_y[:, 2]
test_y_d4 = test_y[:, 3]

### Apply Random Fourier Features

In [None]:
NUM_FEATURES = 2000

omega = jrandom.normal(key=SEED, shape=(NUM_FEATURES, 2))
phi = jrandom.uniform(key=SEED, minval=0, maxval=2 * jnp.pi, shape=(NUM_FEATURES, 1))
lengthscales = jnp.full((NUM_FEATURES, 1), 1.)
coefs = jnp.full_like(lengthscales, 1.)

phi_X_train_d1 = phi_X(train_X_d1, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_test_d1 = phi_X(test_X_d1, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_train_d2 = phi_X(train_X_d2, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_test_d2 = phi_X(test_X_d2, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_train_d3 = phi_X(train_X_d3, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_test_d3 = phi_X(test_X_d3, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_train_d4 = phi_X(train_X_d4, NUM_FEATURES, lengthscales, coefs, omega, phi)
phi_X_test_d4 = phi_X(test_X_d4, NUM_FEATURES, lengthscales, coefs, omega, phi)

### Train RFF transition model

In [None]:

model_d1 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)
model_d2 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)
model_d3 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)
model_d4 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)

model_d1.posterior(phi_X_train_d1, train_y_d1)
model_d2.posterior(phi_X_train_d2, train_y_d2)
model_d3.posterior(phi_X_train_d3, train_y_d3)
model_d4.posterior(phi_X_train_d4, train_y_d4)

d1_preds = np.zeros_like(test_y_d1)
d2_preds = np.zeros_like(test_y_d2)
d3_preds = np.zeros_like(test_y_d3)
d4_preds = np.zeros_like(test_y_d4)

for i in range(test_y_d1.shape[0]):
    d1_preds = model_d1.predict(phi_X_test_d1).mean()
    d2_preds = model_d2.predict(phi_X_test_d2).mean()
    d3_preds = model_d3.predict(phi_X_test_d3).mean()
    d4_preds = model_d4.predict(phi_X_test_d4).mean()



print(f"MSE d1: {metrics.mean_absolute_error(test_y_d1, d1_preds)}")
print(f"MSE d2: {metrics.mean_absolute_error(test_y_d2, d2_preds)}")
print(f"MSE d3: {metrics.mean_absolute_error(test_y_d3, d3_preds)}")
print(f"MSE d4: {metrics.mean_absolute_error(test_y_d4, d4_preds)}")

### Train Gaussian process transition model

In [None]:

gp_d1 = GaussianProcessRegressor(kernel=RBF(length_scale_bounds=(1e-4, 1e3)), n_restarts_optimizer=9)
gp_d2 = GaussianProcessRegressor(kernel=RBF(length_scale_bounds=(1e-4, 1e3)), n_restarts_optimizer=9)
gp_d3 = GaussianProcessRegressor(kernel=RBF(length_scale_bounds=(1e-4, 1e3)), n_restarts_optimizer=9)
gp_d4 = GaussianProcessRegressor(kernel=RBF(length_scale_bounds=(1e-4, 1e3)), n_restarts_optimizer=9)

gp_d1.fit(train_X_d1, train_y_d1)
gp_d2.fit(train_X_d2, train_y_d2)
gp_d3.fit(train_X_d3, train_y_d3)
gp_d4.fit(train_X_d4, train_y_d4)

d1_preds_gp = np.zeros_like(test_y_d1)
d2_preds_gp = np.zeros_like(test_y_d2)
d3_preds_gp = np.zeros_like(test_y_d3)
d4_preds_gp = np.zeros_like(test_y_d4)

d1_std_gp = np.zeros_like(test_y_d1)
d2_std_gp = np.zeros_like(test_y_d2)
d3_std_gp = np.zeros_like(test_y_d3)
d4_std_gp = np.zeros_like(test_y_d4)

for i in range(test_y_d1.shape[0]):
    d1_preds_gp, d1_std_gp = gp_d1.predict(test_X_d1, return_std=True)
    d2_preds_gp, d2_std_gp = gp_d2.predict(test_X_d2, return_std=True)
    d3_preds_gp, d3_std_gp = gp_d3.predict(test_X_d3, return_std=True)
    d4_preds_gp, d4_std_gp = gp_d4.predict(test_X_d4, return_std=True)

print(f"MSE d1: {metrics.mean_absolute_error(test_y_d1, d1_preds_gp)}, kernel: {gp_d1.kernel_}")
print(f"MSE d2: {metrics.mean_absolute_error(test_y_d2, d2_preds_gp)}, kernel: {gp_d2.kernel_}")
print(f"MSE d3: {metrics.mean_absolute_error(test_y_d3, d3_preds_gp)}, kernel: {gp_d3.kernel_}")
print(f"MSE d4: {metrics.mean_absolute_error(test_y_d4, d4_preds_gp)}, kernel: {gp_d4.kernel_}")