In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf
import tflearn
import numpy as np
from sklearn.model_selection import train_test_split

import drqn

from experience_buffer import ExperienceBuffer
import dataset_utils as d_utils
import utils
import models_dict_utils

%load_ext autoreload
%autoreload 2
%reload_ext autoreload

### Preprocessing Data for DRQN
We take the data from the data generator and convert it into traces of (s, a, r, sp) tuples.

Each trajectory corresponds to a trace.

If a trajectory has length n, its trace has length n-1, since every tuple also needs the next state sp.

In [2]:
# Read the synthetic dataset from its pickle file.
data = d_utils.load_data(
    filename="../synthetic_data/test-n10000-l3-random.pickle",
)

In [3]:
# Turn the loaded data into a list of traces, using the "dense" reward model.
dqn_data = d_utils.preprocess_data_for_dqn(
    data,
    reward_model="dense",
)

In [17]:
# Inspect one complete trace (a list of [s, a, r, sp] entries).
first_trace = dqn_data[0]
print(first_trace)

[[array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.]), array([ 1.,  0.,  0.,  0.,  0.]), 0.20000000000000001, array([ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])], [array([ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]), array([ 0.,  0.,  0.,  0.,  1.]), 0.20000000000000001, array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.])]]


In [18]:
# Unpack the first tuple of the first trace and show each component on its own line.
s, a, r, sp = dqn_data[0][0]
for component in (s, a, r, sp):
    print(component)

[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
[ 1.  0.  0.  0.  0.]
0.2
[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]


In [19]:
# Unpack the last tuple of the first trace and show each component on its own line.
s, a, r, sp = dqn_data[0][-1]
for component in (s, a, r, sp):
    print(component)

[ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
[ 0.  0.  0.  0.  1.]
0.2
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]


In [4]:
# Hold out 20% of the traces for validation.
# The shuffle is seeded so the partition is identical on every run: a later cell
# reloads a saved checkpoint and reports val_loss on dqn_data_test, and with an
# unseeded split a kernel restart could place traces the checkpoint was trained
# on into the held-out set.
SPLIT_SEED = 42
dqn_data_train, dqn_data_test = train_test_split(
    dqn_data,
    test_size=0.2,
    random_state=SPLIT_SEED,
)

### Creating a DRQN model and training it

In [5]:
# Identifier used to derive the tensorboard and checkpoint locations below.
model_id = "test_model_drqn"

# Logging / checkpointing cadence (not referenced by the cells shown here;
# presumably measured in training steps -- TODO confirm).
summary_interval = 100
checkpoint_interval = 200

# Per-model output locations: tensorboard summaries and checkpoints.
tensorboard_dir = '../tensorboard_logs/' + model_id + '/'
checkpoint_dir = '../checkpoints/' + model_id + '/'
checkpoint_path = checkpoint_dir + '_/'

# Run the project's path check on both output directories
# (by its name it creates them when missing -- verify in utils).
for output_dir in (tensorboard_dir, checkpoint_dir):
    utils.check_if_path_exists_or_create(output_dir)

In [None]:
# Start from an empty default graph so this cell can be re-run in the same
# kernel without stacking a second copy of the model, initializer and saver
# nodes onto the graph left behind by a previous run (the tflearn training cell
# further down does the same).
tf.reset_default_graph()

# Build the DRQN model for 2-step traces, then the initializer op and a saver
# that keeps only the 3 most recent checkpoints.
drqn_model = drqn.DRQNModel(model_id=model_id, timesteps=2)
init = tf.global_variables_initializer()
saver = tf.train.Saver(max_to_keep=3)

with tf.Session() as session:
    session.run(init)

    # Fill the experience buffer directly with the training traces.
    train_buffer = ExperienceBuffer()
    train_buffer.buffer = dqn_data_train
    train_buffer.buffer_sz = len(train_buffer.buffer)

    # Train without loading a checkpoint; checkpoint_dir is passed as ckpt_path.
    drqn.train(
        drqn_model,
        session,
        saver,
        train_buffer,
        load_checkpoint=False,
        ckpt_path=checkpoint_dir,
    )

In [None]:
# Using tflearn Trainer

In [25]:
def _buffer_from_traces(traces):
    """Return an ExperienceBuffer whose contents are exactly `traces`."""
    filled = ExperienceBuffer()
    filled.buffer = traces
    filled.buffer_sz = len(filled.buffer)
    return filled


# Drop any graph left over from earlier cells before building the tflearn graph.
tf.reset_default_graph()

# Network dimensions come from the stored model dictionary for this model_id.
model_dict = models_dict_utils.load_model_dict(model_id)
n_inputdim, n_hidden, n_outputdim = (
    model_dict[key] for key in ("n_inputdim", "n_hidden", "n_outputdim")
)
graph_ops = drqn.build_tf_graph_drqn_tflearn(
    n_timesteps=2,
    n_inputdim=n_inputdim,
    n_hidden=n_hidden,
    n_actions=n_outputdim,
)

init = tf.global_variables_initializer()
with tf.Session() as session:
    session.run(init)

    # Training and validation buffers wrap the two halves of the split.
    train_buffer = _buffer_from_traces(dqn_data_train)
    val_buffer = _buffer_from_traces(dqn_data_test)

    # load_checkpoint=True with load_ckpt_path=checkpoint_dir; new checkpoints
    # are directed to checkpoint_path.
    drqn.train_tflearn(
        graph_ops,
        train_buffer,
        val_buffer,
        n_epoch=256,
        tensorboard_dir=tensorboard_dir,
        run_id="test_run",
        load_checkpoint=True,
        load_ckpt_path=checkpoint_dir,
        save_ckpt_path=checkpoint_path,
    )

Training Step: 35999  | total loss: [1m[32m0.09022[0m[0m | time: 2.390s
| Optimizer | epoch: 256 | loss: 0.09022 -- iter: 7936/8000
Training Step: 36000  | total loss: [1m[32m0.08851[0m[0m | time: 3.414s
| Optimizer | epoch: 256 | loss: 0.08851 | val_loss: 0.09024 -- iter: 8000/8000
--
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'list' object has no attribute 'name'
