In [1]:
%load_ext autoreload
%autoreload 2
from models import NNDynamicModel, MPCcontroller
import tensorflow as tf
import numpy as np
import json
import pickle

  from ._conv import register_converters as _register_converters


In [2]:
# Hyper-parameters forwarded verbatim to NNDynamicModel(**DYNAMIC_MODEL_PARAMS).
DYNAMIC_MODEL_PARAMS = dict(
    l2_regularizer_scale=1e-7,
    activation='tanh',
    output_activation=None,
    # One rate per entry; the training log prints a 3-element loss vector,
    # presumably one per output head — TODO confirm in models.NNDynamicModel.
    learning_rate=[0.001] * 3,
    batch_size=128,
    n_layers=3,
    size=64,
    iterations=100,
)

In [3]:
# Load the recorded transition buffer and split off a hold-out test set.
# NOTE(review): pickle.load can execute arbitrary code — only load buffers that
# this project wrote itself.
TEST_SIZE = 2000  # number of sampled transitions held out (taken from the end)

with open('./gao/data_buffer.pkl', 'rb') as f:
    data_buffer = pickle.load(f)

# `_tail` is a private attribute of the buffer class; presumably the number of
# stored transitions — TODO confirm / expose a public accessor in the buffer.
samples = data_buffer.sample(data_buffer._tail)

# Guard against a buffer so small that `samples[:-TEST_SIZE]` would silently be
# empty and every transition would land in the test split.
assert len(samples) > TEST_SIZE, (
    'need more than %d transitions, got %d' % (TEST_SIZE, len(samples)))

test_data = samples[-TEST_SIZE:]
data = samples[:-TEST_SIZE]

In [4]:
# Each transition is [obs, action, nobs, r]. Inspect those whose observation
# has a negative second feature (obs[1] < 0).
# dtype=object is required: rows mix lists of different lengths and a scalar,
# and NumPy >= 1.24 refuses to build such ragged arrays implicitly.
neg_feature_transitions = np.array(
    [d for d in data if d[0][1] < 0], dtype=object)

# Show one example; guarded so the cell does not fail when fewer than 3 match.
if len(neg_feature_transitions) > 2:
    print(neg_feature_transitions[2])
else:
    print('only %d matching transitions' % len(neg_feature_transitions))

[list([0.0, -9.0, 105.0, 17.0, 6.335000000000112])
 list([2.0, 0.7425015558113478])
 list([0.0, 152.0, 90.0, 25.0, 6.3520000000001176]) -1]


In [5]:
# Per-feature statistics of the training observations, [min, max, mean], handed
# to NNDynamicModel as `normalization`. Computed on the training split only.
# NOTE(review): in the output below feature 0 has min == max; if the model
# scales by (max - min) that is a divide-by-zero — TODO confirm in models.
obs = [transition[0] for transition in data]
obs_matrix = np.asarray(obs)
normalization = [
    obs_matrix.min(axis=0),
    obs_matrix.max(axis=0),
    obs_matrix.mean(axis=0),
]
print(normalization)

session = tf.Session()

[array([  0.   , -50.   ,  90.   ,  17.   ,   6.153]), array([  0.   , 625.   , 105.   ,  51.   ,   6.752]), array([  0.        , 434.52996667,  97.109     ,  29.5238    ,
         6.2301027 ])]


In [6]:
# Build the learned dynamics model on the shared session, using the training-set
# statistics from `normalization` and the hyper-parameters in DYNAMIC_MODEL_PARAMS.
model = NNDynamicModel(name='test', sess=session,
                       normalization=normalization, **DYNAMIC_MODEL_PARAMS)

In [7]:
# Initialise every global variable created so far (including the model's).
# Must run after the model is built; re-running it after training resets the weights.
init = tf.global_variables_initializer()
session.run(fetches=init)

In [8]:
# Train the dynamics model on the training split; the output below shows the
# loss vector logged every 10 iteration steps.
model.fit(data)

on iter_step 0, loss = [0.5706578  0.6751954  0.53191787]
on iter_step 10, loss = [0.44149676 0.65795594 0.4981232 ]
on iter_step 20, loss = [0.4299685 0.6495223 0.474533 ]
on iter_step 30, loss = [0.4304059  0.64419097 0.465176  ]
on iter_step 40, loss = [0.4281967  0.64076144 0.45886502]
on iter_step 50, loss = [0.42332742 0.6364676  0.45450646]
on iter_step 60, loss = [0.4182011  0.6348425  0.45283675]
on iter_step 70, loss = [0.41330203 0.6314098  0.44820625]
on iter_step 80, loss = [0.4082443  0.6287709  0.44338024]
on iter_step 90, loss = [0.40234295 0.6273373  0.4414254 ]


In [9]:
# Wrap the trained dynamics model in an MPC controller; used below via
# mpc.get_action(observation).
mpc = MPCcontroller(dyn_model=model) 

In [10]:
# Evaluate the dynamics model as a -1 / 1 reward classifier on both splits.
# NOTE(review): model.predict appears to return one [score0, score1] row per
# sample (e.g. [[0.94, 0.06]] further below); score0 is treated as the "-1"
# class here — confirm the column semantics in models.NNDynamicModel.
gd = 0  # decision margin: predict -1 only if score0 >= score1 + gd


def _labels_from_scores(scores, margin):
    """Map per-sample [score0, score1] rows to labels: -1 if score0 >= score1 + margin, else 1."""
    scores = np.asarray(scores)
    return np.where(scores[:, 0] >= scores[:, 1] + margin, -1, 1)


def _accuracy(pred_labels, rewards):
    """Fraction of samples whose predicted label equals the recorded reward."""
    return np.mean(np.asarray(pred_labels) == np.asarray(rewards))


test_obs, test_action = [d[0] for d in test_data], [d[1] for d in test_data]
test_reward = [d[-1] for d in test_data]
train_reward = [d[-1] for d in data]

# Keep raw scores and labels in separate names (labels are reused below).
test_scores = model.predict(test_obs, test_action)
train_scores = model.predict([d[0] for d in data], [d[1] for d in data])
test_pred = _labels_from_scores(test_scores, gd)
train_pred = _labels_from_scores(train_scores, gd)

test_acc = _accuracy(test_pred, test_reward)
train_acc = _accuracy(train_pred, train_reward)
print('test_acc: ', test_acc, '\ntrain_acc: ', train_acc)

test_acc:  0.713 
train_acc:  0.7213


In [11]:
# Collect the test transitions whose predicted label disagrees with the reward.
err_data = [
    transition
    for transition, predicted, actual in zip(test_data, test_pred, test_reward)
    if predicted != actual
]
print(len(err_data))

574


In [22]:
# Inspect the first misclassified test transition ([obs, action, nobs, r]) and
# the model's scores for its recorded (obs, action) pair.
ed = err_data[0]
# Actions of all misclassified transitions, stacked for ad-hoc inspection.
err_action = np.array([act for _obs, act, *_rest in err_data])
print(ed)
print(model.predict([ed[0]], [ed[1]]))

[list([0.0, 612.0, 105.0, 34.0, 6.1840000000000614])
 list([0.0, 0.9428808617415323])
 list([0.0, 280.0, 105.0, 34.0, 6.242000000000081]) 1]
[[0.93979515 0.06020485]]


In [23]:
# Ask the MPC controller for an action from the same observation, then compare
# the model's scores for that action with the scores printed above for the
# recorded action.
action = mpc.get_action(ed[0])
print(action)
print(model.predict([ed[0]], [action]))

[0.         0.96218234]
[[0.94764915 0.05235085]]


In [27]:
from game_controller import GameEnv
import time

In [None]:
# Drive the live game with the MPC controller until the kernel is interrupted.
env = GameEnv()
env.start()
obs = None
try:
    while True:
        if env.status == 'CRASHED':
            env.restart()
        # Reset obs every step so the controller always plans from a fresh
        # observation rather than re-using the one from the previous step.
        # NOTE(review): if get_observation() can return None while the game is
        # crashed, this poll loop will not notice the crash — confirm GameEnv.
        obs = None
        while obs is None:
            time.sleep(env.interval_time)
            obs = env.get_observation()
        action = mpc.get_action(obs)
        print(action)
        env.perform_action(action)
except KeyboardInterrupt:
    # Interrupting the kernel is the intended way to stop this loop; end the
    # cell cleanly instead of leaving a traceback in the saved notebook.
    print('control loop stopped')

[2.         0.87819288]
[2.         0.78665201]
[0.         0.99321678]
[0.         0.99780051]
[0.         0.93470674]
[0.         0.87911879]
[0.         0.99242935]
[0.         0.88628478]
[0.         0.97900383]
[0.         0.97536597]
[0.         0.92732769]
[0.         0.94348627]
[0.         0.94566645]
[0.         0.91669065]
[0.         0.95671532]
[0.         0.96484812]
[0.         0.97171999]
[0.         0.95283831]
[0.         0.91805181]
