In [1]:
%load_ext autoreload
%autoreload 2
from models import NNDynamicModel, MPCcontroller
import tensorflow as tf
import numpy as np
import json
import pickle

  from ._conv import register_converters as _register_converters


In [2]:
DYNAMIC_MODEL_PARAMS = {
    'l2_regularizer_scale': 1e-5,
    'activation': 'tanh',
    'output_activation': None,
    'learning_rate': [0.001, 0.001, 0.001],
    'batch_size': 128,
    'n_layers': 3,
    'size': 128,
    'iterations': 100,
}

In [3]:
# load data
with open('./sample.json', 'r') as f:
    data = json.load(f)

test_data = data[-2000:]
data = data[:-2000]

In [4]:
# [obs, action, nobs, r]
obs = [d for d in data if d[0][1] < 0]
obs = np.array(obs)
print(obs[2])

[list([0.0, -50.0, 105.0, 51.0, 6.2920000000000975])
 list([2.0, 0.37846549794879236])
 list([0.0, 439.0, 105.0, 51.0, 6.320000000000107]) 1]


In [5]:
obs = [d[0] for d in data]
normalization = [np.min(obs, axis=0), np.max(obs, axis=0), np.mean(obs, axis=0)]
print(normalization)

session = tf.Session()

[array([  0.   , -50.   ,  90.   ,  17.   ,   6.168]), array([  0.   , 625.   , 105.   ,  51.   ,   6.657]), array([  0.        , 424.433     ,  97.6475    ,  29.597     ,
         6.21942167])]


In [6]:
model = NNDynamicModel(
    name='test',
    sess=session,
    normalization=normalization,
    **DYNAMIC_MODEL_PARAMS
)

In [7]:
init = tf.global_variables_initializer()
session.run(init)

In [8]:
model.fit(data)

on iter_step 0, loss = [0.4479493  0.54096717 0.4239607 ]
on iter_step 10, loss = [0.37100336 0.5289081  0.31792584]
on iter_step 20, loss = [0.34034365 0.52791566 0.2917938 ]
on iter_step 30, loss = [0.30821112 0.5215189  0.27817002]
on iter_step 40, loss = [0.30320576 0.5268646  0.2736329 ]
on iter_step 50, loss = [0.30070356 0.52706665 0.26914805]
on iter_step 60, loss = [0.28914273 0.5223497  0.2621015 ]
on iter_step 70, loss = [0.2817178  0.51736903 0.2617269 ]
on iter_step 80, loss = [0.2860543  0.5117415  0.25506967]
on iter_step 90, loss = [0.27748728 0.51080805 0.25007936]


In [9]:
mpc = MPCcontroller(dyn_model=model) 

In [10]:
test_obs, test_action = [d[0] for d in test_data], [d[1] for d in test_data]
test_pred = model.predict(test_obs, test_action)
train_pred = model.predict([d[0] for d in data], [d[1] for d in data])
test_reward = [d[-1] for d in test_data]
train_reward = [d[-1] for d in data]
gd = 0
test_pred = np.array([-1 if d[0] >= d[1] + gd else 1 for d in test_pred])
train_pred = np.array([-1 if d[0] >= d[1] + gd  else 1 for d in train_pred])
test_acc = np.sum(test_pred == test_reward) / len(test_reward)
train_acc = np.sum(train_pred == train_reward) / len(train_reward)
print('test_acc: ', test_acc, '\ntrain_acc: ', train_acc)

test_acc:  0.864 
train_acc:  0.8566666666666667


In [11]:
err_data = [test_data[i] for i in range(len(test_data)) if test_pred[i] != test_reward[i]]
print(len(err_data))

272


In [12]:
ed = err_data[0]
err_action = np.array([d[1] for d in err_data])
print(ed)
print(model.predict([ed[0]], [ed[1]]))

[[0.0, 463.0, 90.0, 25.0, 6.533000000000178], [1.0, 0.8963443224702344], [0.0, 71.0, 90.0, 25.0, 6.65900000000022], -1]
[[0.23102708 0.76897292]]


In [13]:
action = mpc.get_action(ed[0])
print(action)
print(model.predict([ed[0]], [action]))

[2.         0.92919007]
[[0.97122065 0.02877935]]


In [14]:
from game_controller import GameEnv
import time

In [15]:
env = GameEnv()
env.start()
obs = None
while True:
    if env.status == 'CRASHED':
        env.restart()
        obs = None
    while obs is None:
        time.sleep(env.interval_time)
        obs = env.get_observation()
    action = mpc.get_action(obs)
    print(action)
    env.perform_action(action)

[1.         0.26947554]
[1.         0.24914105]
[1.         0.27154803]
[1.         0.21984896]
[1.         0.25748501]
[1.         0.25708158]
[1.         0.25277342]
[1.         0.26917848]
[1.         0.24369939]
[1.         0.27588385]
[1.        0.2365862]
[1.         0.30997926]
[1.        0.3206901]
[1.         0.27091303]
[2.         0.67103372]
[2.         0.70485548]
[1.         0.58769428]
[1.         0.60637972]
[1.         0.59905223]
[1.         0.59119009]
[1.         0.62228676]
[1.         0.60113987]
[1.         0.54763343]
[1.         0.60101109]
[1.         0.58954427]
[1.         0.60859799]
[1.         0.60604238]
[1.        0.6543313]
[1.         0.60885497]
[1.         0.61039764]
[1.         0.61914529]
[1.         0.59759672]
[1.         0.57030778]
[1.         0.57203205]
[1.         0.59435181]


KeyboardInterrupt: 