# Generative Adversarial Imitation Learning

In [None]:
import numpy as np
import tensorflow as tf

In [None]:
class Discriminator(object):
    """Discriminator network scoring (state, action) pairs for GAIL.

    The network takes a state and an action, concatenates them, and outputs a
    sigmoid score in (0, 1). Minimising ``d_loss`` drives the score towards 1
    on expert pairs and towards 0 on policy pairs.

    Attributes:
        state_in_expert, action_in_expert: float32 placeholders of shape
            ``[None, s_size]`` / ``[None, a_size]`` for the expert batch.
        state_in_policy, action_in_policy: float32 placeholders of shape
            ``[None, s_size]`` / ``[None, a_size]`` for the policy batch.
        d_expert, d_policy: per-sample scores for each batch.
        de, dp: mean score over the expert / policy batch.
        d_loss: scalar loss that ``update_batch`` minimises.
        update_batch: Adam training op for ``d_loss``.
    """

    def __init__(self, s_size, a_size, h_size, lr):
        """Create the input placeholders and build the training graph.

        Args:
            s_size: Width of a state vector.
            a_size: Width of an action vector.
            h_size: Number of units in each of the two hidden layers.
            lr: Learning rate passed to the Adam optimizer.
        """
        self.state_in_expert = tf.placeholder(shape=[None, s_size], dtype=tf.float32)
        self.action_in_expert = tf.placeholder(shape=[None, a_size], dtype=tf.float32)
        self.state_in_policy = tf.placeholder(shape=[None, s_size], dtype=tf.float32)
        self.action_in_policy = tf.placeholder(shape=[None, a_size], dtype=tf.float32)
        self.s_size = s_size
        self.h_size = h_size
        self.lr = lr
        self.update()

    def get_d(self, state_in, action_in, reuse):
        """Return the sigmoid score for each (state, action) row.

        Args:
            state_in: Tensor of states, one row per sample.
            action_in: Tensor of actions, one row per sample.
            reuse: Forwarded to every dense layer; pass False for the first
                call (creates the weights) and True afterwards (shares them).

        Returns:
            Tensor with one sigmoid output per input row.
        """
        with tf.variable_scope("discriminator"):
            concat_input = tf.concat([state_in, action_in], axis=1)
            hidden_1 = tf.layers.dense(concat_input, self.h_size, activation=tf.nn.tanh, use_bias=False, name="d_hidden_1", reuse=reuse)
            hidden_2 = tf.layers.dense(hidden_1, self.h_size, activation=tf.nn.tanh, use_bias=False, name="d_hidden_2", reuse=reuse)
            d = tf.layers.dense(hidden_2, 1, activation=tf.nn.sigmoid, use_bias=False, name="d_out", reuse=reuse)
            return d

    def update(self):
        """Build the scores, the loss and the Adam training op."""
        # The first call creates the weights; the second reuses them so both
        # batches are scored by the same network.
        self.d_expert = self.get_d(self.state_in_expert, self.action_in_expert, False)
        self.d_policy = self.get_d(self.state_in_policy, self.action_in_policy, True)
        self.de = tf.reduce_mean(self.d_expert)
        self.dp = tf.reduce_mean(self.d_policy)

        # Small constant keeps log() finite when a score saturates.
        eps = 1e-10
        # Average each term over its own batch before summing. Adding the two
        # per-sample tensors first would pair unrelated rows and requires the
        # expert and policy batches to have the same size.
        policy_term = tf.reduce_mean(tf.log(self.d_policy + eps))
        expert_term = tf.reduce_mean(tf.log(1 - self.d_expert + eps))
        self.d_loss = policy_term + expert_term

        optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
        self.update_batch = optimizer.minimize(self.d_loss)

In [None]:
# Expert demonstrations: 100 scalar states (standard normal scaled by 1/100)
# arranged as a column vector.
S_exp = (np.random.randn(100) / 100).reshape(-1, 1)
# Expert action is 1.0 when the state is positive, 0.0 otherwise.
A_exp = np.where(S_exp > 0, 1.0, 0.0)

In [None]:
# Start from an empty default graph before building the discriminator.
tf.reset_default_graph()

# Scalar states and actions, 32 hidden units, learning rate 1e-3.
disc = Discriminator(s_size=1, a_size=1, h_size=32, lr=1e-3)

# Initializer op for every variable created above; it is run in the next cell.
init = tf.global_variables_initializer()
sess = tf.InteractiveSession()

In [None]:
# Initialise the discriminator weights, then train for 50000 steps.
sess.run(init)

# Tensors evaluated on every step; the last entry applies the Adam update.
fetches = [disc.de, disc.dp, disc.d_loss, disc.update_batch]

for i in range(50000):
    # Fresh policy batch each step: same state distribution as the expert,
    # but the action is 1.0 for negative states (the opposite rule).
    S_pol = (np.random.randn(100) / 100).reshape(-1, 1)
    A_pol = np.where(S_pol < 0, 1.0, 0.0)
    fd = {
        disc.state_in_expert: S_exp,
        disc.action_in_expert: A_exp,
        disc.state_in_policy: S_pol,
        disc.action_in_policy: A_pol,
    }
    d_e, d_p, loss, _ = sess.run(fetches, feed_dict=fd)
    # Report mean expert score, mean policy score and loss every 1000 steps.
    if not i % 1000:
        print(d_e, d_p, loss)