In [1]:
import jax
import jax.numpy as jnp
import gym_2048
import gym
import numpy as onp
import haiku as hk

## Environment Set-Up

In [2]:
def example():
    env = gym.make('2048-v0')
    env.seed(42)

    env.reset()
    env.render()

    done = False
    moves = 0
    while not done:
        action = env.action_space.sample()
        next_state, reward, done, info = env.step(action)
        moves += 1
        
        print('Next Action: "{}"\n\nReward: {}'.format(
          gym_2048.Base2048Env.ACTION_STRING[action], reward))
        env.render()
        break

    print('\nTotal Moves: {}'.format(moves))

example()
# action is represented by number 0-4 (LEFT, UP, RIGHT, DOWN)
# env provides state is as 2D matrix

0 	0 	0 	0
0 	0 	0 	2
0 	0 	0 	0
0 	2 	0 	0
Next Action: "left"

Reward: 0
0 	0 	2 	0
2 	0 	0 	0
0 	0 	0 	0
2 	0 	0 	0

Total Moves: 1


  board[tile_locs] = tiles


In [3]:
def action_vector(action_index: int) -> jnp.array:
    base_vector = jnp.array([0, 0, 0, 0])
    return jax.nn.one_hot(base_vector, action_index)

def state_vector(observation: onp.array) -> jnp.array:
    # convert to log2
    state = onp.where(observation!=0, onp.log2(observation), 0)
    # (x, y, value)
    state = onp.expand_dims(state, axis=2)
    state = onp.pad(state, pad_width=[(0,0),(0,0),(0,12)], constant_values=0)
    for x in range(4):
        for y in range(4):
            value = int(state[x, y, 0])
            state[x, y, value], state[x, y, 0] = 1, 0
            
    # for the moment lets flatten into a single vector (208 Dimmension)
    return jnp.array(state.flatten())

## Deep Q-Learner

In [4]:
## Replay Buffer to store training episodes for TD error
## Linked list of size 2 over Queue
import collections
from typing import Tuple

EnvironmentOuput = collections.namedtuple("EnvironmentOuput", "state reward discount")

class ReplayBuffer:
    def __init__(self, length):
        self._prev = None
        self._action = None
        self._latest = None
        self.buffer = collections.deque(maxlen=length)
        
    def push(self, env_output: EnvironmentOuput, action: int):
        self._prev = self._latest
        self._latest = env_output
        self._action = action
        
        if action is not None:
            # append s_t-1, a_t, r_t, gamma_t, s_t
            self.buffer.append(
                (self._prev.state,
                 self._action,
                 self._latest.reward,
                 self._latest.discount,
                 self._latest.state)
            )
        
    def sample(self, batch_size: int) -> Tuple:
        samples = random.sample(self.buffer, batch_size)
        prev_state, action, reward, discount, state = zip(*samples)
        
        return (jnp.stack(state),
                jnp.asarray(action),
                jnp.asarray(reward),
                jnp.asarray(discount),
                jnp.stack(next_state))
         
    def is_ready(self, batch_size: int) -> bool:
        return len(self.buffer) > batch_size

In [5]:
## Q Network
class QNetwork(hk.Module):
    def __init__(self, num_actions: int):
        super(QNetwork, self).__init__()
        self._num_actions = num_actions

    def __call__(self, state) -> jnp.array:
        net = hk.Sequential([
            hk.Linear(128), jax.nn.relu,
            hk.Linear(64), jax.nn.relu,
            hk.Linear(self._num_actions)
        ])
        return net(state)

def build_network(num_actions: int) -> hk.Transformed:
    def q_network(x):
        model = QNetwork(num_actions)
        return model(x)

    return hk.transform(q_network)

In [6]:
# define TD error as pure JAX function so can be vmapped
def td_error(initial_q, action, reward, discount, final_q):
    target_q = reward + discount* jnp.amax(final_q)
    
    # want to update predicted to target
    error = jax.lax.stop_gradients(target_q) - predicted_q
    return jnp.squared(error)

In [7]:
from jax.experimental import optix

ActorState = collections.namedtuple("ActorState", "count")
LearnerState = collections.namedtuple("LearnerState", "count opt_state")

class DQN:
    def __init__(self, state_size, action_size):
        # parameters
        self._epsilon = 1.0 # exploration rate
        self._epsilon_decay = 0.995
        self._epsilon_min = 0.01
        
        # model and optimiser
        self._network = build_network(action_size)
        self._opt = optix.adam(1e-3)
        
        # jit methods
        self.learner_step = jax.jit(self.learner_step)
        self.action_step = jax.jit(self.action_step)

    def init_network(self, sample_state: jnp.array, key: jnp.array) -> jnp.ndarray:
        """return initialized parameters for network"""
        return self._network.init(key, sample_state)
    
    def init_learner(self, params: hk.Params) -> jnp.ndarray:
        """return optimiser parameters"""
        opt_state = self._opt.init(params)
        return opt_state
        
    def _loss(self,
              params: hk.Params,
              previous_state,
              action,
              reward,
              discount,
              state
             ) -> jnp.ndarray:
        """Batched loss"""
        initial_q = self._network.apply(params, previous_state)[action] 
        final_q = self._network.apply(params, state)

        batched_loss = jax.vmap(td_error)
        batched_td_error = td_error(initial_q, action, reward, discount, final_q)

        return jnp.mean(batched_td_error)
    
    def _epsilon_schedule(self, step_count: int):
        """Return the epsilon value for a given step"""
        return self._epsilon * (self._epsilon_decay ** step_count)
        
    def learner_step(self,
                     params: hk.Params,
                     opt_state: hk.Params,
                     previous_states: jnp.ndarray,
                     actions: jnp.array,
                     rewards: jnp.array,
                     discounts: jnp.array,
                     next_states: jnp.array
                    ) -> Tuple[jnp.array, jnp.array]:
        """Perform gradient update on model"""  
        # get gradients
        grads = jax.grad(self._loss)(params, previous_states, actions, rewards, discounts, nextstates)
        
        # update network
        updates, opt_state = self._opt.update(grads, opt_state)
        new_params = optix.apply_updates(params, updates)
        return new_params, opt_state

    def action_step(self,
                    params: hk.Params,
                    state: jnp.array,
                    actor_state: ActorState,
                    key: jnp.array,
                    evaluation: bool
                   )-> Tuple[jnp.array, ActorState]:
        """Compute a single action""" 
        # add (dummy) batch dimmension
        state = jnp.expand_dims(state, 0)
        q = self._network.apply(params, state)[0]
        
        # can't use if statements in Jax so calculate both random and greedy actions
        greedy_a = jnp.argmax(q) 
        random_a = jax.random.randint(key, (), 0, 4)
        epsilon = self._epsilon_schedule(actor_state.count)
        
        # training action should be chosen using epsilon distribution betweeen random and greedy action
        train_a = jax.lax.select(jax.random.uniform(key, minval=0, maxval=1) < epsilon, random_a, greedy_a)
        
        # training or evaluation
        return jax.lax.select(evaluation, greedy_a, train_a), ActorState(actor_state.count + 1)

In [8]:
state_size = 208
action_size = 4
key = hk.PRNGSequence(42)

agent = DQN(state_size, action_size)
actor_state = ActorState(0)
params = agent.init_network(sample_state=jnp.ones(208),key=next(key))
opt_state = agent.init_learner(params)

action, actor_state = agent.action_step(params, jnp.ones(208), actor_state, next(key), False)
action, actor_state = agent.action_step(params, jnp.ones(208), actor_state, next(key), False)



In [9]:
def collect_experience(num_episodes: int, env: gym.Env, agent: DQN, actor_state: ActorState, replay_buffer: ReplayBuffer):
    """Use the agent to play a number of games, store experiences in the reply buffer"""
    for episode in range(num_episodes):
        done = False
        moves = 0
        
        # add initial env state to buffer
        observation = env.reset()
        replay_buffer.push(EnvironmentOuput(state_vector(observation), reward=0, discount=0.95), None)
        
        while not done:
            # convert observation into jnp array
            state = state_vector(observation)
            
            # get next step from agent
            action, actor_state = agent.action_step(params, state, actor_state, next(key), False)

            # interact with the environment
            observation, reward, done, info = env.step(action)
            moves += 1
                        
            # store interaction in replay buffer
            replay_buffer.push(EnvironmentOuput(state, reward, discount=0.95), action)
            
            if moves > 1000:
                env.render()
                print(action)
                print(done)
                print(info)
            
        print(f'\n Episode: {episode}, Total Moves: {moves} Final Reward: {reward}')


In [10]:
# for environment
env = gym.make('2048-v0')
env.seed(42)
buffer = ReplayBuffer(5000)

# for agent
state_size = 208
action_size = 4
key = hk.PRNGSequence(42)

# init agent
agent = DQN(state_size, action_size)
actor_state = ActorState(0)
params = agent.init_network(sample_state=jnp.ones(208),key=next(key))
opt_state = agent.init_learner(params)

# collect experience
collect_experience(10, env, agent, ActorState(0), buffer)

  import sys



 Episode: 0, Total Moves: 74 Final Reward: 4

 Episode: 1, Total Moves: 109 Final Reward: 4

 Episode: 2, Total Moves: 281 Final Reward: 4

 Episode: 3, Total Moves: 267 Final Reward: 8
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2 	16
4 	8 	4 	2
1
False
{}
4 	32 	8 	2
64 	16 	4 	4
16 	4 	2

KeyboardInterrupt: 