In [1]:
%load_ext autoreload
%autoreload 2

import random

import gym
import jax.numpy as jnp
import jax.random as jrandom
import matplotlib.pyplot as plt
import numpy as np
import optax
from jax import jit, vmap
from sklearn import metrics
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF

from replay_buffer import ReplayBuffer
from constants import SEED
from policy import policy_grad, lr_gradients, trajectory_value
from pilco_utils import get_trajectories, rollout_episode
from rff import phi_X, phi_X_batch
from trans_model import prior, train_transition_models, predict, predict_batch



In [2]:
env = gym.make('InvertedPendulum-v4')
# Observation vector length = first axis of the observation Box shape.
obs_shape = env.observation_space.shape
state_dim = obs_shape[0]

  "Agent's minimum observation space value is -infinity. This is probably too low."
  "Agent's maxmimum observation space value is infinity. This is probably too high"
  "We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "


### Run random policy to collect training data

In [3]:
# Seed the replay buffer with transitions from a uniformly random policy.
INIT_EPISODES = 1000
BUFFER_CAPACITY = 1000
HORIZON = 20

replay_buffer = ReplayBuffer(capacity=BUFFER_CAPACITY)

for ep in range(INIT_EPISODES):
    cur_state = env.reset()

    for t in range(HORIZON):
        # Length-1 action array with a value drawn uniformly from [-1, 1].
        action = np.array([random.uniform(-1., 1.)])
        prev_state = cur_state

        cur_state, _, done, _ = env.step(action)
        replay_buffer.push(prev_state, action.squeeze(), cur_state)

        # Stepping an environment after it reported `done` is undefined in the
        # Gym API; start a fresh episode instead of continuing this one.
        if done:
            break


print(len(replay_buffer))

1000


### Compute posterior distribution for transition models

In [None]:
# Policy parameters (unused in this cell) and model hyperparameters.
theta = jnp.full((state_dim + 1,), .9)
ALPHA = .3
BETA = .1
N = 1000

# Random-Fourier-feature settings shared by every per-dimension model.
NUM_FEATURES = 1000
lengthscales = jnp.full((NUM_FEATURES,), .1)
coefs = jnp.full_like(lengthscales, 0.000001)

# One independent prior per state dimension (four here).
trans_models = [prior(NUM_FEATURES, alpha=ALPHA) for _ in range(4)]
model_d1, model_d2, model_d3, model_d4 = trans_models

# Condition the priors on the transitions collected in the replay buffer.
trans_models = train_transition_models(
    replay_buffer, BETA, trans_models, NUM_FEATURES, lengthscales, coefs, SEED
)

### Test Next State Predictions

In [None]:
# Compare the learned models' one-step prediction from a fresh reset state
# with the environment's true next state for the same fixed action.
start_state = env.reset()
action = jnp.array([.1])
N = 1000

model_d1, model_d2, model_d3, model_d4 = trans_models

# The model parameters are fixed for this check, so assemble the stacked mean
# and covariance once instead of inside every traced call. The covariance keeps
# only the diagonal of each model's covariance (off-diagonal terms are dropped).
pred_means = jnp.concatenate([model_d1[0], model_d2[0], model_d3[0], model_d4[0]])
pred_covs = jnp.diag(jnp.concatenate([
    jnp.diag(model_d1[1]), jnp.diag(model_d2[1]), jnp.diag(model_d3[1]), jnp.diag(model_d4[1])
]))


@jit
def pred(trans_eps, state_eps, omega, phi):
    """Sample one predicted next state for ``start_state`` under ``action``.

    Args:
        trans_eps: noise sample forwarded to ``predict``.
        state_eps: additive noise on the predicted next state (scaled by 0.1).
        omega, phi: random-Fourier-feature frequencies and phases passed to
            ``phi_X_batch``.

    Returns:
        ``start_state`` plus the four model outputs, plus scaled ``state_eps``.
    """
    # One row per state dimension: (state value, action value).
    model_input = jnp.stack([start_state, jnp.full((4,), action)]).T
    features = phi_X_batch(model_input, NUM_FEATURES, lengthscales, coefs, omega, phi)

    d1, d2, d3, d4 = predict(pred_means, pred_covs, BETA, features.reshape(-1), trans_eps)

    # Model outputs are added to the start state, i.e. treated as deltas.
    next_mean = jnp.array([d1, d2, d3, d4]) + start_state
    return next_mean + state_eps * 0.1


# Inner vmap batches over state_eps; outer vmap batches (trans_eps, omega, phi) jointly.
foo = vmap(vmap(pred, (None, 0, None, None)), (0, None, 0, 0))

# Draw every noise source from its own key: reusing SEED for all four draws made
# `trans` and `state` identical and correlated `omega` with `phi`.
trans_key, state_key, omega_key, phi_key = jrandom.split(SEED, num=4)
trans = jrandom.normal(trans_key, shape=(N, 4))
state = jrandom.normal(state_key, shape=(N, 4))
omega = jrandom.normal(omega_key, shape=(N, NUM_FEATURES, 2))
phi = jrandom.uniform(phi_key, minval=0, maxval=2 * jnp.pi, shape=(N, NUM_FEATURES,))

# Arguments follow pred's (trans_eps, state_eps, omega, phi) order.
predictions = foo(trans, state, omega, phi).mean(axis=(1, 0))
print(f"Pred: {predictions}")
print(f"True: {env.step(action)[0]}")

### Compute Policy Gradients

In [6]:
# Compare likelihood-ratio (LR) and reparameterisation (RP) policy gradients
# at a fixed theta, using freshly trained transition models.
theta = jnp.full((state_dim + 1,), .9)
BETA = 1.
ALPHA = 0.3
N = 100
HORIZON = 10
NUM_FEATURES = 1000
NOISE = 0.1
lengthscales = jnp.full((NUM_FEATURES,), .1)
coefs = jnp.full_like(lengthscales, 0.001)

# One prior per state dimension, then condition on the replay buffer.
model_d1, model_d2, model_d3, model_d4 = (prior(NUM_FEATURES, alpha=ALPHA) for _ in range(4))
trans_models = [model_d1, model_d2, model_d3, model_d4]

trans_models = train_transition_models(replay_buffer, BETA, trans_models, NUM_FEATURES, lengthscales, coefs, SEED)

keys = jrandom.split(SEED, num=4)
state_epsilons = jrandom.normal(key=keys[0], shape=(N, HORIZON, 4))
# Use N rather than a literal so every per-particle array stays the same size.
trans_epsilons = jrandom.normal(key=keys[1], shape=(N, HORIZON, 4))
omegas = jrandom.normal(key=keys[2], shape=(N, HORIZON, NUM_FEATURES, 2))
phis = jrandom.uniform(key=keys[3], minval=0, maxval=2 * jnp.pi, shape=(N, HORIZON, NUM_FEATURES))

# LR gradients
trajectories = get_trajectories(
    theta, BETA, env, *trans_models, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, state_epsilons, trans_epsilons, omegas, phis
)
lr_grads = lr_gradients(
    theta, BETA, *trans_models, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, trajectories, trans_epsilons, omegas, phis
)
print(f"LR1: {lr_grads}")

# RP gradients
rp_grads = policy_grad(
    theta, BETA, *trans_models, env, HORIZON, NUM_FEATURES, lengthscales, coefs, NOISE, state_epsilons, trans_epsilons, omegas, phis
)
print(f"RP1: {rp_grads}")


LR1: [ 2.3047597e-04  3.2374344e-05 -9.6576869e-06  4.8669985e-08
 -1.0052458e-05]
RP1: [ 4.7717482e-08 -1.2887130e-08 -9.7361273e-08 -5.6822959e-08
  2.2752496e-07]


### Particle-based PILCO

In [None]:
# Particle-based PILCO: each episode refits the transition models on the whole
# buffer, takes a few RP policy-gradient steps on the model, then rolls the
# policy out in the real environment (which adds the new data to the buffer).
MAX_EPISODES = 50
NOISE = 0.1
HORIZON = 10
TEST_HORIZON = 50
N_STATES = 100
N_TRANS = 100
TRAIN_LOOPS = 5
key = SEED

theta = jnp.full((state_dim + 1,), 0.)

# Optimise theta directly: its shape matches the gradients from policy_grad,
# so no (1, d) wrapper / broadcasting is needed.
# NOTE(review): optax.adam steps *against* the gradient; if policy_grad returns
# the gradient of a value to maximise, chain optax.scale(-1.0) — TODO confirm.
optimizer = optax.adam(learning_rate=0.005)
opt_state = optimizer.init(theta)

for ep in range(MAX_EPISODES):
    subkeys = jrandom.split(key, num=6)
    key = subkeys[0]

    # Common random numbers, reused across the TRAIN_LOOPS steps of this episode.
    state_epsilons = jrandom.normal(key=subkeys[1], shape=(N_STATES, HORIZON, 4))
    trans_epsilons = jrandom.normal(key=subkeys[2], shape=(N_TRANS, HORIZON, 4))
    omegas = jrandom.normal(key=subkeys[3], shape=(N_STATES, HORIZON, NUM_FEATURES, 2))
    phis = jrandom.uniform(key=subkeys[4], minval=0, maxval=2 * jnp.pi, shape=(N_STATES, HORIZON, NUM_FEATURES))

    # Refit from fresh priors on all buffered data. Feeding last episode's
    # posterior back in as the prior would count transitions that are still
    # in the buffer again every episode.
    fresh_priors = [prior(NUM_FEATURES, alpha=ALPHA) for _ in range(4)]
    trans_models = train_transition_models(
        replay_buffer, BETA, fresh_priors, NUM_FEATURES, lengthscales, coefs, subkeys[5]
    )

    for _ in range(TRAIN_LOOPS):
        # Reparameterisation (RP) policy gradient through the learned models.
        grads = policy_grad(
            theta,
            BETA,
            *trans_models,
            env,
            HORIZON,
            NUM_FEATURES,
            lengthscales,
            coefs,
            NOISE,
            state_epsilons,
            trans_epsilons,
            omegas,
            phis
        )

        updates, opt_state = optimizer.update(grads, opt_state, theta)
        theta = optax.apply_updates(theta, updates)

    # Roll-out policy online for an episode and add to replay buffer
    cur_obj = rollout_episode(env, TEST_HORIZON, replay_buffer, theta)

    if ep % 5 == 0:
        print(f"Ep {ep}, objective: {cur_obj}, theta: {theta}")


final_obj = rollout_episode(env, TEST_HORIZON, replay_buffer, theta)
print(f"Final score: {final_obj}")


### Collect and organise training and testing data

In [None]:
# Pull train/test arrays from the buffer, then slice one view per state dimension.
train_X, train_y, test_X, test_y = replay_buffer.get_train_test_arrays()

# Inputs: dimension d (0-based) uses the column pair [2d, 2d + 2).
train_X_d1, train_X_d2, train_X_d3, train_X_d4 = (train_X[:, 2 * d:2 * d + 2] for d in range(4))
test_X_d1, test_X_d2, test_X_d3, test_X_d4 = (test_X[:, 2 * d:2 * d + 2] for d in range(4))

# Targets: dimension d (0-based) is column d.
train_y_d1, train_y_d2, train_y_d3, train_y_d4 = (train_y[:, d] for d in range(4))
test_y_d1, test_y_d2, test_y_d3, test_y_d4 = (test_y[:, d] for d in range(4))

### Apply Random Fourier Features

In [None]:
# Random-Fourier-feature map shared by all four state dimensions.
# NOTE: this overrides the global NUM_FEATURES used by the PILCO cells above.
NUM_FEATURES = 2000

# Separate keys for frequencies and phases: drawing both from the same key
# derives them from the same random bits, correlating omega with phi.
omega_key, phi_key = jrandom.split(SEED)
omega = jrandom.normal(key=omega_key, shape=(NUM_FEATURES, 2))
phi = jrandom.uniform(key=phi_key, minval=0, maxval=2 * jnp.pi, shape=(NUM_FEATURES, 1))
lengthscales = jnp.full((NUM_FEATURES, 1), 1.)
coefs = jnp.full_like(lengthscales, 1.)


def rff_features(X):
    """Apply the shared RFF map (omega, phi, lengthscales, coefs) to inputs ``X``."""
    return phi_X(X, NUM_FEATURES, lengthscales, coefs, omega, phi)


phi_X_train_d1, phi_X_train_d2, phi_X_train_d3, phi_X_train_d4 = map(
    rff_features, (train_X_d1, train_X_d2, train_X_d3, train_X_d4)
)
phi_X_test_d1, phi_X_test_d2, phi_X_test_d3, phi_X_test_d4 = map(
    rff_features, (test_X_d1, test_X_d2, test_X_d3, test_X_d4)
)

### Train RFF transition model

In [None]:

# Fit one Bayesian linear regression per state dimension on the RFF features
# and report the test mean absolute error.
# FIXME: `BLR` is not imported anywhere in this notebook, so this cell raises
# NameError as written — import it from its defining module (TODO confirm which).
model_d1 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)
model_d2 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)
model_d3 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)
model_d4 = BLR(NUM_FEATURES, alpha=0.3, beta=1.)

model_d1.posterior(phi_X_train_d1, train_y_d1)
model_d2.posterior(phi_X_train_d2, train_y_d2)
model_d3.posterior(phi_X_train_d3, train_y_d3)
model_d4.posterior(phi_X_train_d4, train_y_d4)

# Each predict call already covers the whole test set, so one call per model is
# enough (the previous per-row loop recomputed the same result len(test) times).
# NOTE(review): `.mean()` is taken on whatever `predict` returns — presumably a
# predictive distribution; confirm it is not an array being averaged to a scalar.
d1_preds = model_d1.predict(phi_X_test_d1).mean()
d2_preds = model_d2.predict(phi_X_test_d2).mean()
d3_preds = model_d3.predict(phi_X_test_d3).mean()
d4_preds = model_d4.predict(phi_X_test_d4).mean()

# The metric is mean *absolute* error, so label it MAE.
print(f"MAE d1: {metrics.mean_absolute_error(test_y_d1, d1_preds)}")
print(f"MAE d2: {metrics.mean_absolute_error(test_y_d2, d2_preds)}")
print(f"MAE d3: {metrics.mean_absolute_error(test_y_d3, d3_preds)}")
print(f"MAE d4: {metrics.mean_absolute_error(test_y_d4, d4_preds)}")

### Train Gaussian process transition model

In [None]:

# Baseline: exact GP regression per state dimension with an RBF kernel.
# Fixed seed so the optimizer restarts (and hence the fitted kernels) are reproducible.
GP_RANDOM_STATE = 0


def make_rbf_gp():
    """GP regressor whose RBF lengthscale is fitted within [1e-4, 1e3] with 9 restarts."""
    return GaussianProcessRegressor(
        kernel=RBF(length_scale_bounds=(1e-4, 1e3)),
        n_restarts_optimizer=9,
        random_state=GP_RANDOM_STATE,
    )


gp_d1, gp_d2, gp_d3, gp_d4 = (make_rbf_gp() for _ in range(4))

gp_d1.fit(train_X_d1, train_y_d1)
gp_d2.fit(train_X_d2, train_y_d2)
gp_d3.fit(train_X_d3, train_y_d3)
gp_d4.fit(train_X_d4, train_y_d4)

# One batched prediction per GP covers the whole test set; no per-row loop needed.
d1_preds_gp, d1_std_gp = gp_d1.predict(test_X_d1, return_std=True)
d2_preds_gp, d2_std_gp = gp_d2.predict(test_X_d2, return_std=True)
d3_preds_gp, d3_std_gp = gp_d3.predict(test_X_d3, return_std=True)
d4_preds_gp, d4_std_gp = gp_d4.predict(test_X_d4, return_std=True)

# The metric is mean *absolute* error, so label it MAE.
print(f"MAE d1: {metrics.mean_absolute_error(test_y_d1, d1_preds_gp)}, kernel: {gp_d1.kernel_}")
print(f"MAE d2: {metrics.mean_absolute_error(test_y_d2, d2_preds_gp)}, kernel: {gp_d2.kernel_}")
print(f"MAE d3: {metrics.mean_absolute_error(test_y_d3, d3_preds_gp)}, kernel: {gp_d3.kernel_}")
print(f"MAE d4: {metrics.mean_absolute_error(test_y_d4, d4_preds_gp)}, kernel: {gp_d4.kernel_}")