In [None]:
%load_ext autoreload
%autoreload 2
import gym
from gym import spaces
import numpy as np
from gym_examples.envs.dubins_car import DubinsCarEnv


import jax
import jax.numpy as jnp
import haiku as hk
import optax

In [None]:
# Generate a supervised dataset from the environment: for many reset states,
# query env.sample once per discrete action and record (state, target) pairs.
N_EPISODES = 200_000  # number of reset states to sample; lower it for quick iteration

env = DubinsCarEnv()
X = []
y = []
for _ in range(N_EPISODES):
    state = env.reset()
    for action in range(env.action_space.n):
        # NOTE(review): only the state is stored as a feature, not the action, so
        # each state appears once per action with possibly different targets and
        # the regressor fits an action-averaged target. Confirm this is intended.
        X.append(state)
        # NOTE(review): the meaning of the third argument (0) of env.sample is not
        # visible here — TODO document it.
        r = env.sample(state, action, 0)
        y.append(r)

X = np.array(X)
y = np.array(y)
assert len(X) == len(y), "features and targets must stay row-aligned"



In [None]:
# Features are used unscaled: X_train is the same array object as X (not a copy).
# Disabled alternative kept for reference: standardise with a single global
# mean/std over all entries, (X - X.mean())/(X.std()); per-column statistics
# would need axis=0.
X_train = X  # (X - X.mean())/(X.std())
X_train

In [None]:
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt

# Scatter the first two state components, coloured by the sampled target r.
# Each state is stored once per action (possibly with different r), so
# overlapping points show only the colour drawn last.
# The colormap is passed explicitly instead of calling plt.jet(), which would
# also change the global default colormap for every later figure.
fig, ax = plt.subplots()
points = ax.scatter(X_train[:, 0], X_train[:, 1], c=y, s=1, cmap="jet")  # s = marker size
ax.set(xlabel="state[0]", ylabel="state[1]", title="Sampled targets over the state space")
fig.colorbar(points, ax=ax, label="r (env.sample target)")
plt.show()

In [None]:
# Row counts of features and targets, one per line (they should match).
print(*map(len, (X, y)), sep="\n")

In [None]:
# Dump the training-set size, then the targets, then the features, each on its own line.
print(len(X_train), y, X_train, sep="\n")

In [None]:
# Display the raw feature matrix (the same array object that X_train aliases).
X

In [None]:
# Parameters are defined *implicitly* in Haiku: the hk.Linear layers below create
# their weights when `forward` is traced by the transformed function's init.
def forward(X):
    """One-hidden-layer MLP regressor: Linear(9) -> ReLU -> Linear(1), flattened.

    The layers act on the last axis of ``X``, so a batch of shape (N, d) gives a
    prediction array of shape (N,), and a single state of shape (d,) gives an
    array of shape (1,).
    """
    # Build the hidden layer before the output layer (same construction order as
    # before) so the parameter tree produced by init is unchanged.
    hidden = jax.nn.relu(hk.Linear(9)(X))
    return hk.Linear(1)(hidden).ravel()

# Haiku modules only run inside a transformed function. hk.transform turns
# `forward` into a pure (init, apply) pair, and without_apply_rng drops the rng
# argument from apply because the model (Linear/ReLU only) uses no randomness.
_transformed = hk.without_apply_rng(hk.transform(forward))

# Create the initial parameter tree by tracing the model once on the training inputs.
rng = jax.random.PRNGKey(seed=14)
params = _transformed.init(rng, X_train)

# From here on `forward` is the pure apply function: forward(params, X) -> predictions.
forward = _transformed.apply


def loss_fn(params, X, y):
    """Mean squared error of the model on a batch.

    Args:
        params: parameter tree accepted by the module-level ``forward``.
        X: model inputs.
        y: regression targets, broadcast-compatible with ``forward(params, X)``.

    Returns:
        Scalar mean of the squared residuals.
    """
    residuals = forward(params, X) - y
    return jnp.square(residuals).mean()





N_EPOCHS = 300
LOG_EVERY = 25  # print the loss every LOG_EVERY epochs (and on the last one)

optimizer = optax.adam(learning_rate=1e-2)
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state, X, y):
    """Run one full-batch Adam step on the MSE loss.

    Returns the updated params, the updated optimizer state, and the loss
    evaluated at the *incoming* params (i.e. before this update).
    Note: jit captures `loss_fn`, `forward` and `optimizer` at trace time; re-run
    this cell if any of them is redefined.
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, X, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss


# Move the training data to the device once rather than on every epoch.
X_train_dev = jnp.asarray(X_train)
y_dev = jnp.asarray(y)

for epoch in range(N_EPOCHS):
    params, opt_state, loss = train_step(params, opt_state, X_train_dev, y_dev)
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        print("progress:", "epoch:", epoch, "loss", float(loss))
    
# Post-training report: dump the learned parameter tree, then query the model
# on a freshly reset environment state.
print("estimation of the parameters:", params, sep="\n")

estimate = forward(params=params, X=env.reset())
print("estimate", estimate)

In [None]:
# Query the trained model on a single hand-picked state vector.
test = np.array([-1.0, 1.1, 1.0])
estimate = forward(params=params, X=test)
print("estimate", estimate)

In [None]:
# Sanity-check the fit on the first 100 training rows. The model acts on the
# last axis, so one batched call returns all 100 predictions at once. They are
# shown as column 0, next to the corresponding targets in column 1.
y_preds = forward(X=X_train[:100], params=params)
np.column_stack([np.asarray(y_preds), y[:100]])

In [None]:
# Greedy rollout: at each step, simulate every action with a one-step lookahead,
# score the resulting next state with the learned model, and take the action
# with the highest score. Frames are rendered and collected into a GIF.
env = DubinsCarEnv()
state = env.reset()
done = False
max_iter = 100  # hard cap on rollout length
counter = 0
while not done and counter < max_iter:
    counter += 1
    action_scores = []
    for a in range(env.action_space.n):
        # NOTE(review): assumes state_action_step simulates from `state` without
        # advancing the real episode (the real step happens via env.step below).
        # Lookahead outputs get their own names so they do not clobber the
        # episode's `done`/`reward`.
        next_state, lookahead_reward, _lookahead_done, _ = env.state_action_step(state, a)
        # NOTE(review): the model was fit on (state -> env.sample target) pairs,
        # but here it scores next_state — confirm this value interpretation.
        estimate = forward(X=next_state, params=params)
        print('estimate', estimate, 'actual', lookahead_reward)
        action_scores.append(float(estimate[0]))
    action = int(np.argmax(action_scores))
    print(action, action_scores)

    state, reward, done, _ = env.step(action)
    env.render()
    print(counter)

env.make_gif()