In [1]:
import os
import random

import numpy as np
import stable_baselines3
from gym import Env
from gym.spaces import Box, Discrete
from stable_baselines3.common import env_checker
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import game as connect4_lib

In [91]:
class Connect4(Env):
    """Single-agent Gym environment for Connect 4 played against a fixed rival.

    The learning agent always moves first: each ``step`` plays the agent's move and,
    unless that ends the game, one reply from ``rival_model``. In this notebook's
    rendered boards the agent's stones show up as 1 and the rival's as 2.

    Rewards: ``WIN_REWARD`` if the agent wins, ``LOSS_REWARD`` if the rival wins,
    ``DRAW_REWARD`` when the board fills up, and ``STEP_REWARD`` for every other turn.

    Args:
        rival_model: the opponent. Either an object with ``predict(obs)`` returning an
            ``(action, state)`` pair (e.g. a stable-baselines3 model), or a plain callable
            ``rival_model(obs) -> action``. Actions are column indices 0..6.
        swap_rival_view: when True (default), stones 1 and 2 are exchanged in the
            observation handed to the rival. A model trained in the agent's seat then
            sees its own stones as 1 while playing the second seat.
    """

    WIN_REWARD = 1000
    LOSS_REWARD = -1000
    DRAW_REWARD = -10
    STEP_REWARD = -1  # you pay for spending time and not winning

    def __init__(self, rival_model, swap_rival_view=True):
        self.game = None  # created by reset()
        self.rival_model = rival_model
        self.swap_rival_view = swap_rival_view
        self.action_space = Discrete(7)
        # NOTE(review): assumes connect4_lib's board is a (6, 7) array with values in {0, 1, 2}.
        self.observation_space = Box(low=0, high=2, shape=(1, 6, 7), dtype=np.uint8)
        self.actions = ["a", "b", "c", "d", "e", "f", "g"]  # column labels passed to do_turn

    def step(self, action):
        """Play the agent's move, then the rival's reply.

        Returns:
            ``(obs, reward, done, info)`` in the 4-tuple Gym API.
        """
        # NOTE(review): how do_turn treats a full column is up to connect4_lib (not visible here).
        self.game.do_turn(self.actions[int(action)])
        if self.game.is_winning():
            return self.get_obs(), self.WIN_REWARD, True, {}
        # Checked before the rival moves so it is never asked to play on a full board.
        if self.game.is_full():
            return self.get_obs(), self.DRAW_REWARD, True, {}

        self.game.do_turn(self.actions[self._rival_action(self.get_obs())])
        if self.game.is_winning():
            return self.get_obs(), self.LOSS_REWARD, True, {}
        if self.game.is_full():
            return self.get_obs(), self.DRAW_REWARD, True, {}

        return self.get_obs(), self.STEP_REWARD, False, {}

    def _rival_action(self, obs):
        """Ask the rival for its move on ``obs`` and return it as a plain int column index."""
        if self.swap_rival_view:
            # Empty cells (0) stay 0; 1 <-> 2 via 3 - v, so the rival sees its own stones as 1.
            obs = np.where(obs == 0, obs, 3 - obs).astype(np.uint8)
        predict = getattr(self.rival_model, "predict", None)
        if callable(predict):
            action, _ = predict(obs)
        else:
            action = self.rival_model(obs)
        # predict may return a numpy scalar/array; int() makes it a valid list index.
        return int(action)

    def close(self):
        """Nothing to release."""
        pass

    def render(self, mode="human"):
        """Display the current board via the game's ``plot``; ``mode`` is accepted for gym API compatibility."""
        self.game.plot()

    def reset(self):
        """Start a new game and return the initial observation."""
        self.game = connect4_lib.GAME()
        return self.get_obs()

    def get_obs(self):
        """Return the board with a leading axis added, cast to uint8."""
        return self.game.board[None].astype(np.uint8)

In [26]:
# Smoke test: build the environment and take its initial observation.
# The opponent is a local random policy so this cell does not depend on the `rival`
# cell further down (which runs after this one under Restart & Run All).
env = Connect4(lambda obs: random.randint(0, 6))
obs = env.reset()

In [13]:
class RandomRival:
    """Opponent that plays a uniformly random column, ignoring the board.

    Usable both as a plain policy function, ``rival(obs) -> column``, and as the
    ``rival_model`` of ``Connect4``, which calls ``rival_model.predict(obs)`` and unpacks
    an ``(action, state)`` pair the way stable-baselines3 models return it.

    Args:
        n_columns: number of playable columns; moves are drawn from ``0..n_columns-1``.
    """

    def __init__(self, n_columns=7):
        self.n_columns = n_columns

    def __call__(self, obs):
        """Return a random column index; ``obs`` is ignored."""
        return random.randint(0, self.n_columns - 1)

    def predict(self, obs, *args, **kwargs):
        """SB3-style interface: return ``(column, None)``; extra args are ignored."""
        return self(obs), None


rival = RandomRival()

In [31]:
# Run stable-baselines3's environment checker on the custom env (Gym API conformance).
env = Connect4(rival_model=rival)
env_checker.check_env(env)



In [6]:
# Smoke test: one episode in which the random `rival` also picks the agent's moves;
# prints the summed reward of the episode.
for game in range(1):
    env = Connect4(rival)
    obs = env.reset()
    total_reward = 0
    while True:
        obs, reward, done, _ = env.step(rival(obs))
        total_reward += reward
        if done:
            break
    print(total_reward)

-1012


In [100]:
from stable_baselines3.common.monitor import Monitor

In [105]:
# Import base callback 
from stable_baselines3.common.callbacks import BaseCallback
import os
class TrainAndLoggingCallback(BaseCallback):
    """Periodically save the model being trained.

    Every ``check_freq`` calls to ``_on_step`` the model is saved to
    ``<save_path>/best_model_<n_calls>``. Despite the file name these are periodic
    snapshots, not a selection of the best model.
    NOTE(review): if one instance is reused across several ``learn`` calls, numbering
    presumably continues from the previous run's ``n_calls`` — confirm for your SB3 version.

    Args:
        check_freq: save every this many steps; must be a positive integer.
        save_path: output directory, created when the callback is initialised.
            ``None`` disables saving.
        verbose: verbosity level forwarded to ``BaseCallback``.

    Raises:
        ValueError: if ``check_freq`` is not positive.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        # Fail at construction instead of with ZeroDivisionError mid-training.
        if check_freq <= 0:
            raise ValueError(f"check_freq must be positive, got {check_freq!r}")
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        # Mirror _init_callback: without a save_path there is nowhere to write, so skip
        # rather than crash in os.path.join(None, ...).
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, f"best_model_{self.n_calls}")
            self.model.save(model_path)

        return True  # returning False would stop training



In [107]:
# Output locations: Monitor CSVs and TensorBoard runs go under LOG_DIR; the callback
# saves a model snapshot to ./train every 10,000 callback steps.
LOG_DIR = './logs/'
callback = TrainAndLoggingCallback(save_path="./train", check_freq=10_000)

In [113]:
# Self-play, stage 1: each round trains a fresh PPO agent against the previous round's
# model, then promotes the new model to be the next round's rival.
N_ROUNDS_STAGE1 = 8
TIMESTEPS_STAGE1 = 100_000

if "last_model" not in globals():
    # Nothing earlier in the notebook defines `last_model` on a fresh kernel: seed the first
    # rival with an untrained PPO policy. The bootstrap env only supplies the spaces; it is
    # never stepped, so its rival (None) is never consulted.
    last_model = stable_baselines3.PPO('MlpPolicy', Connect4(rival_model=None))

os.makedirs(LOG_DIR, exist_ok=True)
for i in range(N_ROUNDS_STAGE1):
    # A per-round Monitor filename keeps each round's episode log; passing LOG_DIR itself
    # rewrote the same monitor.csv every round.
    env = Monitor(Connect4(last_model), os.path.join(LOG_DIR, f"stage1_round_{i}"))
    model = stable_baselines3.PPO('MlpPolicy', env, tensorboard_log=LOG_DIR, verbose=1)
    model.learn(TIMESTEPS_STAGE1, callback=callback)
    last_model = model

Using cpu device
Wrapping the env in a DummyVecEnv.
Logging to ./logs/PPO_3
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 8.8      |
|    ep_rew_mean     | -308     |
| time/              |          |
|    fps             | 386      |
|    iterations      | 1        |
|    time_elapsed    | 5        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 9.57         |
|    ep_rew_mean          | -329         |
| time/                   |              |
|    fps                  | 372          |
|    iterations           | 2            |
|    time_elapsed         | 10           |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0068449588 |
|    clip_fraction        | 0.0444       |
|    clip_range           | 0.2          |
|    entropy_loss        

In [114]:
# Self-play, stage 2: more and longer rounds, continuing from stage 1's `last_model`.
# Each round trains a fresh PPO agent against the previous model and then promotes it.
N_ROUNDS_STAGE2 = 10
TIMESTEPS_STAGE2 = 300_000

os.makedirs(LOG_DIR, exist_ok=True)
for i in range(N_ROUNDS_STAGE2):
    # A per-round Monitor filename keeps each round's episode log instead of rewriting
    # LOG_DIR/monitor.csv every round.
    env = Monitor(Connect4(last_model), os.path.join(LOG_DIR, f"stage2_round_{i}"))
    model = stable_baselines3.PPO('MlpPolicy', env, tensorboard_log=LOG_DIR, verbose=1)
    model.learn(TIMESTEPS_STAGE2, callback=callback)
    last_model = model

Using cpu device
Wrapping the env in a DummyVecEnv.
Logging to ./logs/PPO_11
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 8.03     |
|    ep_rew_mean     | -407     |
| time/              |          |
|    fps             | 509      |
|    iterations      | 1        |
|    time_elapsed    | 4        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 8.36         |
|    ep_rew_mean          | -307         |
| time/                   |              |
|    fps                  | 450          |
|    iterations           | 2            |
|    time_elapsed         | 9            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0064760996 |
|    clip_fraction        | 0.0137       |
|    clip_range           | 0.2          |
|    entropy_loss       

In [115]:
# Watch 10 games of the trained `model`, which also acts as the rival: the board is
# rendered after every step and each game's total reward is printed at the end.
for game in range(10):
    env = Connect4(model)
    obs = env.reset()
    total_reward = 0
    while True:
        action, _ = model.predict(obs)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        env.render()
        if done:
            break
    print(total_reward)

  0   0   0   0   0   0   0 
  0   0   0   0   0   0   0 
  0   0   0   0   0   0   0 
  0   0   0   0   0   0   0 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  A   B   C   D   E   F   G 
  0   0   0   0   0   0   0 
  0   0   0   0   0   0   0 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  A   B   C   D   E   F   G 
  0   0   0   0   0   0   0 
  0   0   0   0   0   0   0 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  0   0   0   2   0   0   0 
  0   1   2   1   0   0   0 
  A   B   C   D   E   F   G 
  0   0   0   0   0   0   0 
  0   0   0   1   0   0   0 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  0   2   0   2   0   0   0 
  0   1   2   1   0   0   0 
  A   B   C   D   E   F   G 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  0   0   0   2   0   0   0 
  0   0   0   1   0   0   0 
  0   2   1   2   0   0   0 
  0   1   2   1   0   0   0 
  A   B   C   

In [94]:
# Mean and standard deviation of the episode reward of `model` over 100 episodes.
# NOTE(review): `env` is whatever the previous cell left behind (a Connect4 whose rival is
# `model` itself); build a dedicated evaluation env if another opponent is intended.
# evaluate_policy expects a Monitor-wrapped env to report true episode returns (it warns otherwise).
evaluate_policy(model, env if isinstance(env, Monitor) else Monitor(env), n_eval_episodes=100)

(713.87, 698.4004818870044)