# Setup

In [None]:
!pip install --upgrade keras
!pip install --upgrade jax jaxlib
!pip install --upgrade numpy



In [None]:
import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import layers, Input

import numpy as np
from connect4 import Connect4

# Discount factor for future rewards.
gamma = 0.99

# Epsilon-greedy exploration schedule: epsilon starts at epsilon_max and is
# annealed down to epsilon_min.
epsilon_min = 0.1
epsilon_max = 1.0
epsilon = epsilon_max
epsilon_interval = epsilon_max - epsilon_min  # total amount epsilon shrinks by

# Bonus / penalty applied to every move of a game once its outcome is known.
reward_punishment_const = 10_000

# Big training
max_games = 10_000
# Number of finished games collected between two fit() calls.
trainging_game_batch_size = 100
# Games (compared against num_games) during which every move is random.
epsilon_random_games = 1_000
# Divisor of the epsilon decay step: epsilon_interval / epsilon_greedy_games.
epsilon_greedy_games = 2_000
# The target network is synced whenever num_games is a multiple of this.
update_network = 100

# Small training (uncomment to override the values above for a quick run)
# max_games = 320
# trainging_game_batch_size = 32
# epsilon_random_games = 32
# epsilon_greedy_games = 64
# update_network = 64

# Single shared game environment; the training loop drives it through
# connect_4.reset() and connect_4.move(action).
connect_4 = Connect4()

# Deep Q-Network

In [None]:
# Size of the Q-network's output layer (one Q-value per action).
# Presumably one action per column of the 6x7 board -- confirm against Connect4.
num_actions = 7

def create_q_model(learning_rate=0.00025):
    """Build and compile the Q-network.

    The network maps a (6, 7, 1) board tensor to `num_actions` linear
    Q-values: one 4x4 convolution with 16 filters, a 512-unit dense layer and
    a linear output head. (This is a much smaller network than the one in the
    DeepMind DQN paper; only the overall conv -> dense -> linear layout is
    similar.)

    Args:
        learning_rate: Step size for the Adam optimizer. Defaults to 0.00025,
            the value used by the Keras DQN example; the previous hard-coded
            0.25 is orders of magnitude above Adam's usual range.

    Returns:
        A compiled `keras.Sequential` model (MSE loss, Adam with clipnorm=1.0).
    """
    model = keras.Sequential(
        [
            # Keras 3 expects the input spec as an explicit Input object;
            # passing `input_shape=` to the first layer is deprecated and
            # emits a UserWarning.
            Input(shape=(6, 7, 1)),
            layers.Conv2D(16, (4, 4), activation="relu"),
            layers.Flatten(),
            layers.Dense(512, activation="relu"),
            layers.Dense(num_actions, activation="linear"),
        ]
    )
    model.compile(
        loss="mse",
        # clipnorm bounds each gradient's norm, which matters here because the
        # regression targets are on the scale of reward_punishment_const.
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate, clipnorm=1.0),
    )
    return model

# `model` is the online network: its Q-value predictions are used to pick
# actions and it is the one being fitted.
# `model_target` is a second, independently initialised copy that only
# receives `model`'s weights periodically (every `update_network` games in the
# training loop), so it changes more slowly than the online network.
model, model_target = create_q_model(), create_q_model()


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


# Train

In [None]:
# Index of the game currently being played (starts at 1).
num_games = 1

# History
memory = []
# One inner list per finished game; filled by the training loop and emptied
# again after each fit() call.
games_states, games_next_states = [], []
games_actions, games_rewards, games_done = [], [], []

def transform_state(state):
  """Turn a raw board into a single-sample batch tensor of shape (1, 6, 7, 1).

  `state` must hold exactly 6 * 7 = 42 values (e.g. a 6x7 nested list);
  the reshape fails otherwise.
  """
  board = keras.ops.array(state)
  batched = keras.ops.convert_to_tensor(board).reshape(1, 6, 7, 1)
  return batched

def flatten_list(lst):
  """Flatten exactly one level of nesting: [[a, b], [c]] -> [a, b, c]."""
  flat = []
  for inner in lst:
    flat.extend(inner)
  return flat

# Main self-play / training loop.
#
# Each iteration plays one full game, labels every move of that game with a
# win bonus or loss penalty, and stores it. After `trainging_game_batch_size`
# games the online network is fitted once on all stored moves. Every
# `update_network` games the online weights are copied into `model_target`
# (which is only kept in sync here; it is not used to compute targets).
while num_games < max_games:
  print("num_games", num_games)
  state = connect_4.reset()
  state = transform_state(state)

  done = False
  game_states = []
  game_next_states = []
  game_turns = []
  game_actions = []
  game_rewards = []
  game_done = []

  random_phase = num_games < epsilon_random_games
  if random_phase:
    print("Random game")
  else:
    print("Normal game")

  while not done:
    if random_phase or epsilon > np.random.rand(1)[0]:
      # Explore: uniformly random action.
      action = int(np.random.choice(num_actions))
    else:
      # Exploit: greedy action w.r.t. the online network. Convert the backend
      # tensor to a plain int so it can be stored and used as an index.
      action_probs = model(state, training=False)
      action = int(keras.ops.argmax(action_probs[0]))

    # NOTE(review): epsilon is decayed once per *move* (including during the
    # random phase) although epsilon_greedy_games is expressed in games --
    # confirm this is the intended annealing speed.
    epsilon -= epsilon_interval / epsilon_greedy_games
    epsilon = max(epsilon, epsilon_min)

    # NOTE(review): nothing prevents the agent from choosing an illegal move;
    # how connect_4.move() handles that is not visible here.
    next_state, reward, turn, done = connect_4.move(action)
    next_state = transform_state(next_state)
    game_states.append(state)
    game_next_states.append(next_state)
    game_turns.append(turn)
    game_actions.append(action)
    game_rewards.append(reward)
    game_done.append(abs(done))
    state = next_state

  # The final truthy `done` value is compared against each move's `turn`
  # below, i.e. it identifies the winning side.
  winner = done
  print("Winner", winner)

  # Credit assignment: every move by the winner gets a bonus, every other
  # move gets a penalty of the same size.
  # NOTE(review): if connect_4 can end a game without a winner, all moves are
  # penalised by this rule -- confirm that is intended.
  for i, turn in enumerate(game_turns):
    if turn == winner:
      game_rewards[i] += reward_punishment_const
    else:
      game_rewards[i] -= reward_punishment_const

  print(model(game_states[0], training=False))
  print("Number of moves", len(game_states))
  games_states.append(game_states)
  games_next_states.append(game_next_states)
  games_actions.append(game_actions)
  games_rewards.append(game_rewards)
  games_done.append(game_done)

  if len(games_states) == trainging_game_batch_size:
    state_sample = flatten_list(games_states)
    action_sample = flatten_list(games_actions)
    rewards_sample = flatten_list(games_rewards)
    print(rewards_sample[0:10])

    # Stack the per-move (1, 6, 7, 1) tensors into ONE (N, 6, 7, 1) array.
    # Passing a Python list of tensors to fit() would be interpreted as a
    # multi-input structure instead of a batch of N samples.
    states_batch = keras.ops.convert_to_numpy(
      keras.ops.concatenate(state_sample, axis=0)
    ).astype("float32")

    # Start from the network's current predictions so that only the Q-value
    # of the action actually taken is pushed toward its reward; the other
    # actions contribute zero error instead of being regressed toward 0.
    q_targets = np.array(
      keras.ops.convert_to_numpy(model(states_batch, training=False)),
      dtype="float32",
    )
    move_index = np.arange(len(action_sample))
    q_targets[move_index, np.asarray(action_sample, dtype=int)] = np.asarray(
      rewards_sample, dtype="float32"
    )

    model.fit(
      states_batch,
      q_targets,
      batch_size=trainging_game_batch_size,
      epochs=1,
      verbose=0
    )
    games_states.clear()
    games_next_states.clear()
    games_actions.clear()
    games_rewards.clear()
    games_done.clear()

  if num_games % update_network == 0:
    model_target.set_weights(model.get_weights())
    print("Updated Network")
  num_games += 1

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
ILLEGAL
Winner -1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.5853
   3.5228004]]
Number of moves 9
num_games 9288
Normal game
ILLEGAL
Winner -1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.5853
   3.5228004]]
Number of moves 7
num_games 9289
Normal game
ILLEGAL
Winner 1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.5853
   3.5228004]]
Number of moves 8
num_games 9290
Normal game
ILLEGAL
Winner -1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.5853
   3.5228004]]
Number of moves 7
num_games 9291
Normal game
ILLEGAL
Winner -1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.5853
   3.5228004]]
Number of moves 7
num_games 9292
Normal game
ILLEGAL
Winner 1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.5853
   3.5228004]]
Number of moves 8
num_games 9293
Normal game
ILLEGAL
Winner 1
[[-2.2894437 -2.8306122 -3.4553542  3.1141994  3.3824804 15.58

In [None]:
# Copy the online network's current weights into the target network so the
# two are identical.
model_target.set_weights(model.get_weights())
# Write the trained online network to 'connect4AI.keras' in the working directory.
model.save('connect4AI.keras')