In [8]:
import tensorflow as tf
import numpy as np
from tf_agents.agents.dqn import dqn_agent
from tf_agents.networks import q_network
from tf_agents.environments import tf_py_environment
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common
from tf_agents.specs import array_spec
from tf_agents.environments import py_environment
from tf_agents.trajectories import time_step as ts


In [9]:

class ConnectFourEnv(py_environment.PyEnvironment):
    """Connect Four on a 6x7 board, exposed as a TF-Agents ``PyEnvironment``.

    Observation: the board as a ``(6, 7)`` ``int32`` array where ``0`` is an
    empty cell and ``1`` / ``2`` mark the two players' discs.
    Action: an ``int32`` column index in ``[0, 6]``; the disc is placed in the
    highest-indexed empty row of that column.

    The two players alternate after every non-terminal ``_step`` call,
    starting with player 1. An episode ends with reward ``10`` when the player
    who just moved connects four, ``0`` when the board fills up without a
    winner, and ``-10`` when the chosen column is already full.

    Note: the win reward is issued for whichever player made the move, so a
    caller that steps both sides through this environment sees ``+10`` for a
    win by either player.
    """

    # Board geometry and winning run length, shared by the specs, the state
    # allocation and the win scan so they cannot drift apart.
    _ROWS = 6
    _COLS = 7
    _CONNECT = 4

    def __init__(self):
        """Create the specs and an empty board with player 1 to move."""
        # Let the base class set up its own bookkeeping before we add ours.
        super(ConnectFourEnv, self).__init__()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=self._COLS - 1,
            name='action')
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(self._ROWS, self._COLS), dtype=np.int32, minimum=0,
            maximum=2, name='observation')
        self._state = np.zeros((self._ROWS, self._COLS), dtype=np.int32)
        self._current_player = 1
        self._episode_ended = False

    def action_spec(self):
        """Return the spec of a valid action (a column index)."""
        return self._action_spec

    def observation_spec(self):
        """Return the spec of the board observation."""
        return self._observation_spec

    def _board_copy(self):
        """Return a copy of the board so emitted observations never alias state."""
        return np.array(self._state, dtype=np.int32)

    def _reset(self):
        """Clear the board, hand the move to player 1 and start a new episode."""
        self._state = np.zeros((self._ROWS, self._COLS), dtype=np.int32)
        self._current_player = 1
        self._episode_ended = False
        return ts.restart(self._board_copy())

    def _step(self, action):
        """Drop the current player's disc into column ``action``.

        Args:
            action: Column index to play.

        Returns:
            A ``TimeStep``: a fresh restart if the previous step ended the
            episode; a termination with reward ``-10`` (column full), ``10``
            (the mover connected four) or ``0`` (board full, no winner); or
            otherwise a transition with reward ``0.0`` and discount ``1.0``.
        """
        if self._episode_ended:
            # The previous step was terminal; begin a new episode instead.
            return self.reset()

        empty_rows = np.flatnonzero(self._state[:, action] == 0)
        if empty_rows.size == 0:
            # Illegal move: the column has no free cell left.
            self._episode_ended = True
            return ts.termination(self._board_copy(), -10)

        # The disc settles in the highest-indexed free row of the column.
        row = empty_rows[-1]
        self._state[row, action] = self._current_player

        if self._check_win(self._current_player):
            self._episode_ended = True
            return ts.termination(self._board_copy(), 10)

        if np.all(self._state != 0):
            # Board is full and nobody connected four: a draw.
            self._episode_ended = True
            return ts.termination(self._board_copy(), 0)

        self._current_player = 1 if self._current_player == 2 else 2
        return ts.transition(self._board_copy(), reward=0.0, discount=1.0)

    def _check_win(self, player):
        """Return True if ``player`` has four discs in a straight line.

        Horizontal, vertical and both diagonal directions are scanned from
        every start cell whose run of four fits on the board.
        """
        owned = self._state == player
        rows, cols = owned.shape
        span = self._CONNECT - 1
        # (row step, column step): horizontal, vertical, and the two diagonals.
        for d_row, d_col in ((0, 1), (1, 0), (1, 1), (-1, 1)):
            for r in range(rows):
                if not 0 <= r + span * d_row < rows:
                    continue  # the run would leave the board vertically
                for c in range(cols - span * d_col):
                    if all(owned[r + k * d_row, c + k * d_col]
                           for k in range(self._CONNECT)):
                        return True
        return False




In [10]:

# Wrap the Python environment as TensorFlow environments. `train_env` is used
# for data collection and training below; `eval_env` is created here but is not
# used in the cells shown.
train_env = tf_py_environment.TFPyEnvironment(ConnectFourEnv())
eval_env = tf_py_environment.TFPyEnvironment(ConnectFourEnv())

# Initialize the QNetwork from the environment's observation and action specs.
# `fc_layer_params` is forwarded to QNetwork; presumably one entry per fully
# connected layer, i.e. a single layer of 100 units here (confirm in the
# TF-Agents docs).
fc_layer_params = (100,)
q_net = q_network.QNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    fc_layer_params=fc_layer_params)

# Initialize the DQN agent with an Adam optimizer and an element-wise squared
# TD-error loss. `train_step_counter` is handed to the agent and is read by the
# training loop to decide when to log.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()

# Initialize the replay buffer. Its data spec comes from the agent and its
# batch size from the training environment, so `add_batch` below accepts the
# trajectories produced by stepping `train_env`.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=100000)

# Initialize the data collection policy and collect some initial data.
# The warm-up uses `random_policy` for every step; because the environment
# switches players after each move, both sides play randomly here.
# NOTE(review): `collect_policy` is assigned but not referenced in the cells
# shown (the training cell reads `agent.collect_policy` directly).
collect_policy = agent.collect_policy
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
initial_collect_steps = 1000
for _ in range(initial_collect_steps):
    # Observe, act, step, then store the (before, action, after) transition.
    time_step = train_env.current_time_step()
    action_step = random_policy.action(time_step)
    next_time_step = train_env.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    replay_buffer.add_batch(traj)

# Set up the dataset: batches of 64 samples, each spanning `num_steps=2`
# entries of the buffer, with a prefetch depth of 3.
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=64,
    num_steps=2).prefetch(3)

# Iterator consumed by the training cell, one batch per training step.
iterator = iter(dataset)


In [11]:

# Train the agent.
# `num_iterations` is the number of times the loop below calls train_step();
# the loss is printed whenever the agent's train step counter is a multiple of
# `log_interval`.
num_iterations = 20000
log_interval = 200

@tf.function
def train_step():
    """Play one agent move (plus a random reply if the game goes on), then train.

    Each move played on ``train_env`` is written to ``replay_buffer`` as a
    single transition. Afterwards one batch is pulled from ``iterator`` and
    used for a single ``agent.train`` call.

    Returns:
        The ``loss`` field of the value returned by ``agent.train``.
    """

    def play_and_record(policy):
        """Advance ``train_env`` by one move chosen by ``policy`` and store it."""
        before = train_env.current_time_step()
        chosen = policy.action(before)
        after = train_env.step(chosen.action)
        replay_buffer.add_batch(
            trajectory.from_transition(before, chosen, after))
        return after

    # The learning agent moves first, using its collect policy.
    next_time_step = play_and_record(agent.collect_policy)

    # Opponent's turn: a random move, skipped when the agent's move already
    # ended the episode.
    if not next_time_step.is_last():
        next_time_step = play_and_record(random_policy)

    experience, _ = next(iterator)
    return agent.train(experience).loss

# Run the training iterations, reporting the most recent loss each time the
# agent's train step counter reaches a multiple of `log_interval`.
for _ in range(num_iterations):
    train_loss = train_step()
    completed_steps = train_step_counter.numpy()
    if completed_steps % log_interval != 0:
        continue
    print('step = {0}: loss = {1}'.format(completed_steps, train_loss))


step = 200: loss = 15.778207778930664
step = 400: loss = 44.207244873046875
step = 600: loss = 53.20875549316406
step = 800: loss = 31.317123413085938
step = 1000: loss = 8.032379150390625
step = 1200: loss = 7.57925271987915
step = 1400: loss = 1.5052998065948486
step = 1600: loss = 10.81866455078125
step = 1800: loss = 2.0842370986938477
step = 2000: loss = 17.604158401489258
step = 2200: loss = 2.1102828979492188
step = 2400: loss = 6.711139678955078
step = 2600: loss = 3.1410837173461914
step = 2800: loss = 0.8111436367034912
step = 3000: loss = 2.0690407752990723
step = 3200: loss = 2.0183353424072266
step = 3400: loss = 2.814595937728882
step = 3600: loss = 0.9827808141708374
step = 3800: loss = 5.350802421569824
step = 4000: loss = 3.635037899017334
step = 4200: loss = 1.4304511547088623
step = 4400: loss = 1.0700972080230713
step = 4600: loss = 3.3593015670776367
step = 4800: loss = 0.6524667739868164
step = 5000: loss = 1.956712007522583
step = 5200: loss = 2.027306318283081
s