In [None]:
import t3f
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from environments import Snake
from openai_methods import ReplayBuffer, PrioritizedReplayBuffer, LinearSchedule

# Experiments

In [None]:
def tf_repeat(x, num):
    """Repeat every element of `x` `num` times in a row and return a column.

    For x = [a, b] and num = [3] the result is [[a], [a], [a], [b], [b], [b]].

    Args:
        x: tensor whose elements are repeated; it is flattened to a column first.
        num: 1-D int32 tensor holding a single value, the repeat count. It is
            concatenated with a leading 1 to form the `multiples` of a rank-2
            `tf.tile`, hence the length-1 requirement.

    Returns:
        Tensor of shape (size(x) * num, 1).
    """
    column = tf.reshape(x, (-1, 1))
    # Tile along the second axis only: row i becomes [x_i, x_i, ..., x_i].
    multiples = tf.concat([tf.ones(1, dtype=tf.int32), num], axis=0)
    tiled = tf.tile(column, multiples)
    # Row-major flattening keeps the copies of each element adjacent.
    return tf.reshape(tiled, (-1, 1))

class QQTT:
    """Tabular Q-function stored as a Tensor-Train (TT) tensor built with t3f.

    The Q-table is a single TT variable named 'Q' with ``prod(state_shape)``
    modes of size ``num_colors`` (one per state cell) followed by a final mode
    of size ``num_actions``.  An entry is addressed by the flattened integer
    state concatenated with the action index.

    Graph endpoints exposed as attributes:
        input_states, input_actions, input_targets, prio_weights: placeholders.
        q_selected: table entries for the fed (state, action) pairs.
        q_argmax, q_max: best action and its value for each fed state.
        loss: Huber loss between `input_targets` and `q_selected`, weighted
            per sample by `prio_weights`.
        td_error: per-sample ``q_selected - input_targets``.
        update_model: one optimizer step on `loss`.
    """

    def __init__(self, num_actions, num_colors=4, state_shape=[8, 8, 1],
                 tt_rank=16, optimizer=None,
                 dtype=tf.float32, scope="qqtt_network", reuse=False):
        """Build the TT variable, the placeholders and the training ops.

        Args:
            num_actions: size of the last (action) mode of the Q-tensor.
            num_colors: size of every state mode; state entries are used as
                indices into these modes.
            state_shape: shape of one state (list or tuple); the state is
                flattened, so only its total number of cells matters for the
                tensor layout.
            tt_rank: TT-rank passed to `t3f.random_tensor`.
            optimizer: TensorFlow optimizer used to minimize the loss.  When
                None, a fresh ``tf.train.AdamOptimizer(2.5e-4)`` is created for
                this instance (a default built in the signature would be one
                object shared by every instance).
            dtype: dtype of the Q-tensor, the targets and the weights.
            scope: variable scope that holds the model.
            reuse: forwarded to `tf.variable_scope`.
        """
        # Copy into a list: accepts tuples and never touches the shared default.
        state_shape = list(state_shape)
        # Plain int so that `int * list` below is unambiguously list repetition.
        num_cells = int(np.prod(state_shape))
        if optimizer is None:
            optimizer = tf.train.AdamOptimizer(2.5e-4)

        # One mode per state cell, plus a trailing mode indexing the action.
        input_shape = num_cells * [num_colors] + [num_actions]

        with tf.variable_scope(scope, reuse=reuse):
            # Random initialization of the Q-tensor.
            q0init = t3f.random_tensor(shape=input_shape, tt_rank=tt_rank, stddev=10.)
            q0init = t3f.cast(q0init, dtype=dtype)
            q0 = t3f.get_variable('Q', initializer=q0init)

            self.input_states = tf.placeholder(dtype=tf.int32, shape=[None] + state_shape)
            self.input_actions = tf.placeholder(dtype=tf.int32, shape=[None])
            self.input_targets = tf.placeholder(dtype=dtype, shape=[None])
            self.prio_weights = tf.placeholder(dtype=dtype, shape=[None])

            # Q-value of the action actually taken: index = (flattened state, action).
            reshaped_s = tf.reshape(self.input_states, (-1, num_cells))
            reshaped_a = tf.reshape(self.input_actions, (-1, 1))
            input_s_and_a = tf.concat([reshaped_s, reshaped_a], axis=1)
            self.q_selected = t3f.gather_nd(q0, input_s_and_a, dtype=dtype)

            # Q-values of every action for every state in the batch.  The batch
            # of states is stacked num_actions times and paired with the action
            # column [0]*B + [1]*B + ..., so a single gather_nd call evaluates
            # all (state, action) pairs.
            tiled_states = tf.concat(num_actions * [reshaped_s], axis=0)
            actions_range = tf.range(start=0, limit=num_actions)
            batch_dim = tf.shape(self.input_states)[0:1]
            a_idx = tf_repeat(actions_range, batch_dim)
            s_a_idx = tf.concat([tiled_states, a_idx], axis=1)
            vals = t3f.gather_nd(q0, s_a_idx, dtype=dtype)
            # vals is ordered action-major; reshape and transpose to (batch, num_actions).
            q_values = tf.transpose(tf.reshape(vals, shape=(num_actions, -1)))

            self.q_argmax = tf.argmax(q_values, axis=1)
            self.q_max = tf.reduce_max(q_values, axis=1)

            # The per-sample weights now take part in the loss; with all-ones
            # weights this equals the plain mean Huber loss.
            self.loss = tf.losses.huber_loss(labels=self.input_targets,
                                             predictions=self.q_selected,
                                             weights=self.prio_weights)
            self.td_error = self.q_selected - self.input_targets
            self.update_model = optimizer.minimize(self.loss)

    def update(self, sess, states, actions, targets, weights=None):
        """Run one optimizer step and return the per-sample TD errors.

        Args:
            sess: TensorFlow session holding the model variables.
            states: batch of states fed to `input_states`.
            actions: batch of action indices fed to `input_actions`.
            targets: batch of regression targets fed to `input_targets`.
            weights: per-sample loss weights fed to `prio_weights`; None means
                a weight of one for every sample.

        Returns:
            Value of `td_error` fetched in the same `sess.run` call as the
            update op.
        """
        if weights is None:
            weights = np.ones(len(targets))
        feed_dict = {self.input_states: states,
                     self.input_actions: actions,
                     self.input_targets: targets,
                     self.prio_weights: weights}
        _, td = sess.run([self.update_model, self.td_error], feed_dict)
        return td

    def get_q_action_values(self, sess, states, actions):
        """Return the Q-tensor entries for the given (state, action) pairs."""
        feed_dict = {self.input_states: states,
                     self.input_actions: actions}
        return sess.run(self.q_selected, feed_dict=feed_dict)

    def get_q_argmax(self, sess, states):
        """Return, for each state, the index of the action with the largest Q-value."""
        feed_dict = {self.input_states: states}
        return sess.run(self.q_argmax, feed_dict=feed_dict)

    def get_q_max(self, sess, states):
        """Return, for each state, the largest Q-value over all actions."""
        feed_dict = {self.input_states: states}
        return sess.run(self.q_max, feed_dict=feed_dict)

In [None]:
# --- Replay / optimisation settings ---------------------------------------
batch_size = 64               # transitions sampled per gradient step
replay_memory_size = 50000    # capacity handed to the replay buffer
replay_start_size = 10000     # frames gathered by random play before training

# --- Exploration settings --------------------------------------------------
# NOTE(review): these three values are not read by the visible training loop,
# which hard-codes its exploration rate -- confirm whether that is intended.
init_eps = 1
final_eps = 0.1
annealing_steps = 100000

# --- Episode / return settings ---------------------------------------------
gamma = 0.999                 # discount factor used in the TD target
max_episode_length = 50       # step cap per episode
num_episodes = 1000000

# --- Prioritized-replay settings -------------------------------------------
prioritized_replay_alpha = 0.6
prioritized_replay_beta0 = 0.4
prioritized_replay_beta_iters = None
prioritized_replay_eps = 1e-6

# NOTE(review): with no explicit value, beta is annealed over a single
# episode's worth of steps -- confirm this is meant rather than the total
# number of training steps.
if prioritized_replay_beta_iters is None:
    prioritized_replay_beta_iters = max_episode_length

beta_schedule = LinearSchedule(
    prioritized_replay_beta_iters,
    initial_p=prioritized_replay_beta0,
    final_p=1.0,
)

# Uniform (non-prioritized) replay buffer.
replay_buffer = ReplayBuffer(replay_memory_size)

In [5]:
# Agent for the 4x4 Snake grid: 4 actions, one single-channel cell per grid position.
qqtt_agent = QQTT(num_actions=4, state_shape=[4, 4, 1], scope="qtt_snake")

In [6]:
# Environment used both for the warm-up below and for training.
train_env = Snake(grid_size=(4, 4))

# Warm-up: play uniformly random actions until at least replay_start_size
# transitions are stored.  The threshold is only checked between episodes, so
# the final episode may overshoot it.
frame_count = 0
while frame_count < replay_start_size:
    state = train_env.reset()
    for time_step in range(max_episode_length):
        action = np.random.randint(4)
        next_state, reward, done = train_env.step(action)
        # A zero reward is stored as a small per-step penalty.
        if reward == 0:
            reward = -0.01
        replay_buffer.add(state, action, reward, next_state, float(done))
        frame_count += 1
        state = next_state
        if done:
            break

In [7]:
# Open a session and initialise every variable defined in the graph so far.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [8]:
# Sum of raw environment rewards per episode (before the zero-reward penalty).
train_rewards = []
for episode in range(num_episodes):
    state = train_env.reset()
    episode_return = 0
    for time_step in range(max_episode_length):
        # Epsilon-greedy action selection with a fixed exploration rate of 0.01.
        if np.random.rand() <= 0.01:
            action = np.random.randint(0, 4)
        else:
            action = qqtt_agent.get_q_argmax(sess, [state])[0]

        next_state, reward, done = train_env.step(action)
        episode_return += reward
        # A zero reward is stored as a small per-step penalty.
        if reward == 0:
            reward = -0.01
        replay_buffer.add(state, action, reward, next_state, float(done))
        state = next_state

        # Train after every environment step (`% 1` is always 0; raise the
        # modulus to train less often).
        if time_step % 1 == 0:
            obses_t, actions, rewards, obses_tp1, dones = replay_buffer.sample(batch_size)
            # Uniform replay: unit weight for every sample, no indices to update.
            weights, batch_idxes = np.ones_like(rewards), None
            next_q_max = qqtt_agent.get_q_max(sess, obses_tp1)
            # One-step TD target; (1 - done) drops the bootstrap term on terminal steps.
            targets = rewards + gamma * next_q_max * (1. - dones)

            td_errors = qqtt_agent.update(sess, obses_t, actions, targets, weights)
            # Only consumed by the disabled prioritized-replay call below.
            new_priorities = np.abs(td_errors) + prioritized_replay_eps
            # replay_buffer.update_priorities(batch_idxes, new_priorities)

        if done:
            break

    train_rewards.append(episode_return)
    # Every 500 episodes report mean and max return over the last 500 episodes.
    if episode % 500 == 0:
        recent = train_rewards[-500:]
        avg_reward = np.mean(recent)
        max_reward = np.max(recent)
        print("Train info:", avg_reward, max_reward)

Train info: -1.0 -1
Train info: -0.816 2
Train info: -0.676 2
Train info: -0.294 3




Train info: -0.112 2
Train info: -0.902 1
Train info: -0.878 1
Train info: -0.88 1
Train info: -0.91 1
Train info: -0.888 1
Train info: -0.88 1


KeyboardInterrupt: 