In [1]:
import numpy as np
import random
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv
import os

  from pandas.core import (


## Board encoding: X is 1, O is -1, and an empty cell is 0

In [2]:
class TicTacToeEnv:
    """Minimal two-player Tic-Tac-Toe board.

    Cells hold 1 for X, -1 for O and 0 when empty. X moves first and
    ``current_player`` flips sign after every accepted move.
    """

    def __init__(self):
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1

    def reset(self):
        """Clear the board, give the move back to X and return the empty state."""
        self.board[:] = 0
        self.current_player = 1
        return self.get_state()

    def get_state(self):
        """Return a copy of the board so callers cannot mutate internal state."""
        return self.board.copy()

    def get_valid_actions(self):
        """Return the ``(row, col)`` coordinates of every empty cell, row-major."""
        return [(i, j) for i in range(3) for j in range(3) if self.board[i, j] == 0]

    def step(self, action):
        """Place the current player's mark at ``action`` and advance the turn.

        Args:
            action: ``(row, col)`` pair addressing a cell of the board.

        Returns:
            ``(state, reward, done, info)`` where ``info["winner"]`` is 1, -1,
            0 (draw) or ``None`` while the game is still running.

        Raises:
            ValueError: if the addressed cell is already occupied.
        """
        i, j = action
        if self.board[i, j] != 0:
            raise ValueError("Invalid action: Cell is not empty")

        self.board[i, j] = self.current_player
        done, winner = self.is_done()
        reward = self.get_reward(done, winner)

        self.current_player *= -1
        # Callers in this notebook read the lower-case "winner" key; the
        # capitalised spelling that used to be the only key is kept as an alias.
        info = {"winner": winner, "Winner": winner}
        return self.get_state(), reward, done, info

    def _line_sums(self):
        """Yield the sum of each row/column pair, then both diagonals."""
        for i in range(3):
            yield self.board[i, :].sum()
            yield self.board[:, i].sum()
        yield np.trace(self.board)
        yield np.trace(np.fliplr(self.board))  # anti-diagonal

    def is_done(self):
        """Return ``(done, winner)`` for the current board.

        ``winner`` is 1 or -1 when a line is complete, 0 when the board is
        full without a complete line, and ``None`` while moves remain.
        """
        for total in self._line_sums():
            # A line sums to +3 / -3 only when all three cells share one mark.
            if abs(total) == 3:
                return True, int(np.sign(total))

        if not self.get_valid_actions():
            return True, 0  # DRAW

        return False, None

    def get_reward(self, done, winner):
        """Map a game outcome to a reward expressed from X's point of view.

        Returns 0 while the game is running, 1 when X wins, -1 when O wins
        and 0.5 for a draw.
        """
        if not done:
            return 0
        if winner == 1:
            return 1
        elif winner == -1:
            return -1
        else:
            return 0.5  # draw

    def render(self):
        """Print the board using X, O and '.' followed by a blank line."""
        symbols = {1: 'X', -1: 'O', 0: '.'}
        for row in self.board:
            print(" ".join(symbols[val] for val in row))
        print()

## Testing the class

In [3]:
# Build a fresh board, reset it and print the empty grid as a smoke test.
env = TicTacToeEnv()

state = env.reset()
env.render()

. . .
. . .
. . .



In [4]:
# Play up to five uniformly random moves on the shared `env`, printing the
# board after each one and stopping early if the game finishes.
for _ in range(5):
    valid_actions = env.get_valid_actions()
    action = random.choice(valid_actions)
    print(f"Player {env.current_player} plays {action}")
    
    state, reward, done, info = env.step(action)
    env.render()
    
    if done:
        print("Game Over!")
        # NOTE(review): this reads the lower-case "winner" key; confirm that
        # env.step() emits exactly that spelling, otherwise this raises KeyError.
        if info["winner"] == 1:
            print("X wins!")
        elif info["winner"] == -1:
            print("O wins!")
        else:
            print("It's a draw!")
        break

Player 1 plays (0, 2)
. . X
. . .
. . .

Player -1 plays (2, 1)
. . X
. . .
. O .

Player 1 plays (1, 2)
. . X
. . X
. O .

Player -1 plays (0, 0)
O . X
. . X
. O .

Player 1 plays (1, 1)
O . X
. X X
. O .



In [5]:
# Clear the board left over from the random-play cell and show it is empty.
env.reset()

env.render()

. . .
. . .
. . .



## Creating a Gymnasium-compatible wrapper

In [3]:
class TicTacToeSelfPlayEnv(gym.Env):
    """Gymnasium wrapper around ``TicTacToeEnv`` with an optional opponent.

    Actions are integers 0-8 addressing the cells in row-major order. An
    action that points at an occupied cell is silently replaced by a random
    empty cell. When ``opponent_model`` is given, it answers as O (-1) after
    every move of the learning agent; without it the wrapper simply alternates
    sides on successive ``step`` calls.

    Rewards come straight from ``TicTacToeEnv.get_reward`` and are therefore
    expressed from X's point of view.
    """

    def __init__(self, opponent_model=None):
        super().__init__()
        self.env = TicTacToeEnv()
        self.opponent_model = opponent_model

        self.observation_space = spaces.Box(low=-1, high=1, shape=(3, 3), dtype=np.int8)
        self.action_space = spaces.Discrete(9)

    def reset(self, *, seed=None, options=None):
        """Start a new game and return ``(observation, info)``."""
        super().reset(seed=seed)
        obs = self.env.reset()

        # If opponent goes first
        if self.env.current_player == -1 and self.opponent_model:
            obs = self._opponent_step(obs)

        return self._get_obs(), {}

    def step(self, action_index):
        """Play the agent's move and, if configured, the opponent's reply.

        Returns the Gymnasium 5-tuple
        ``(observation, reward, terminated, truncated, info)``; ``truncated``
        is always ``False``. On the two paths that involve a finished or
        opponent-answered move, ``info`` carries a ``"winner"`` entry.
        """
        obs, reward, done, raw_info = self.env.step(self._resolve_move(action_index))
        # Expose one consistent key regardless of how the inner env spells it.
        info = {"winner": raw_info.get("winner", raw_info.get("Winner"))}

        if done:
            return self._get_obs(), float(reward), done, False, info

        if self.env.current_player == -1 and self.opponent_model:
            obs = self._opponent_step(obs)

            # Check if game ended after opponent played
            done, winner = self.env.is_done()
            reward = self.env.get_reward(done, winner)

            info["winner"] = winner
            return self._get_obs(), float(reward), done, False, info

        return self._get_obs(), 0.0, False, False, {}

    def _resolve_move(self, action_index):
        """Turn a flat action into a playable ``(row, col)`` cell.

        Falls back to a random empty cell, drawn from the env's own seeded
        generator, when the requested cell is occupied.
        """
        # .item() accepts Python ints, NumPy scalars and size-1 arrays alike.
        cell = divmod(int(np.asarray(action_index).item()), 3)
        valid_actions = self.env.get_valid_actions()
        if cell not in valid_actions:
            cell = valid_actions[self.np_random.integers(len(valid_actions))]
        return cell

    def _opponent_step(self, obs):
        """Let the opponent model move and return the resulting raw board."""
        action, _ = self.opponent_model.predict(obs * -1)  # Flip for opponent's perspective
        self.env.step(self._resolve_move(action))
        return self.env.get_state()

    def _get_obs(self):
        """Return the board multiplied by the sign of the player to move.

        The result is cast to the dtype declared in ``observation_space`` so
        that it is a member of that space.
        """
        perspective = self.env.get_state() * self.env.current_player  # Perspective-based obs
        return perspective.astype(self.observation_space.dtype)

    def render(self):
        """Print the underlying board."""
        self.env.render()

## Log and checkpoint directories

In [4]:
# Output locations: `log_dir` is created here; `model_dir` is only defined
# here and is later handed to the checkpoint callback as its save path.
log_dir = "./ppo_logs"
model_dir = "./saved_models"
os.makedirs(log_dir, exist_ok=True)

In [8]:
# Opponent policy: a PPO model built on an opponent-less copy of the env.
# NOTE(review): within this cell `opponent_model` is never passed to .learn(),
# so the agent trains against the opponent's initial weights — confirm this is
# the intended "self-play" setup.
opponent_model = PPO("MlpPolicy", DummyVecEnv([lambda: TicTacToeSelfPlayEnv()]), verbose=0)

# Training env: the agent's moves are answered by `opponent_model`, and the
# Monitor wrapper is pointed at `log_dir`.
env = TicTacToeSelfPlayEnv(opponent_model=opponent_model)
monitored_env = Monitor(env, log_dir)

# Checkpoint callback configured with save_freq=25000 and `model_dir` as save path.
checkpoint_callback = CheckpointCallback(save_freq=25000, save_path=model_dir, name_prefix="ppo_selfplay_model")

model = PPO("MlpPolicy", monitored_env, verbose=1, tensorboard_log=log_dir)

model.learn(total_timesteps=100_000, callback=checkpoint_callback)

Using cpu device
Wrapping the env in a DummyVecEnv.
Logging to ./ppo_logs\PPO_1
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4.16     |
|    ep_rew_mean     | 0.31     |
| time/              |          |
|    fps             | 961      |
|    iterations      | 1        |
|    time_elapsed    | 2        |
|    total_timesteps | 2048     |
---------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 4.22         |
|    ep_rew_mean          | 0.335        |
| time/                   |              |
|    fps                  | 798          |
|    iterations           | 2            |
|    time_elapsed         | 5            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0070934896 |
|    clip_fraction        | 0.0405       |
|    clip_range           | 0.2          |
|    entropy_loss    

------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 4.06         |
|    ep_rew_mean          | 0.595        |
| time/                   |              |
|    fps                  | 688          |
|    iterations           | 11           |
|    time_elapsed         | 32           |
|    total_timesteps      | 22528        |
| train/                  |              |
|    approx_kl            | 0.0126979165 |
|    clip_fraction        | 0.112        |
|    clip_range           | 0.2          |
|    entropy_loss         | -1.87        |
|    explained_variance   | 0.0708       |
|    learning_rate        | 0.0003       |
|    loss                 | 0.269        |
|    n_updates            | 100          |
|    policy_gradient_loss | -0.0171      |
|    value_loss           | 0.493        |
------------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_m

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.93        |
|    ep_rew_mean          | 0.59        |
| time/                   |             |
|    fps                  | 683         |
|    iterations           | 21          |
|    time_elapsed         | 62          |
|    total_timesteps      | 43008       |
| train/                  |             |
|    approx_kl            | 0.009628574 |
|    clip_fraction        | 0.0963      |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.53       |
|    explained_variance   | 0.0929      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.185       |
|    n_updates            | 200         |
|    policy_gradient_loss | -0.0156     |
|    value_loss           | 0.357       |
-----------------------------------------
----------------------------------------
| rollout/                |            |
|    ep_len_mean          | 3.9     

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.63        |
|    ep_rew_mean          | 0.83        |
| time/                   |             |
|    fps                  | 685         |
|    iterations           | 31          |
|    time_elapsed         | 92          |
|    total_timesteps      | 63488       |
| train/                  |             |
|    approx_kl            | 0.009111259 |
|    clip_fraction        | 0.104       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.05       |
|    explained_variance   | 0.0589      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.0289      |
|    n_updates            | 300         |
|    policy_gradient_loss | -0.017      |
|    value_loss           | 0.212       |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.72  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.47        |
|    ep_rew_mean          | 0.94        |
| time/                   |             |
|    fps                  | 689         |
|    iterations           | 41          |
|    time_elapsed         | 121         |
|    total_timesteps      | 83968       |
| train/                  |             |
|    approx_kl            | 0.010001609 |
|    clip_fraction        | 0.128       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.805      |
|    explained_variance   | 0.0947      |
|    learning_rate        | 0.0003      |
|    loss                 | 0.043       |
|    n_updates            | 400         |
|    policy_gradient_loss | -0.0182     |
|    value_loss           | 0.115       |
-----------------------------------------
------------------------------------------
| rollout/                |              |
|    ep_len_mean          | 3.45

<stable_baselines3.ppo.ppo.PPO at 0x1a3e32c62d0>

## Testing against a human player

In [5]:
def play_against_model(model, env):
    """Run one interactive game: the model plays X (1), the human plays O (-1).

    Args:
        model: object exposing ``predict(obs, deterministic=True)`` that
            returns ``(flat_action, state)`` with ``flat_action`` in 0-8.
        env: vectorised env; ``env.envs[0].env`` must be the raw board object
            (``current_player``, ``get_valid_actions``, ``step``, ``render``)
            and ``env.reset()`` must return a batch whose first item is the
            initial observation.

    The game is driven directly on the raw board, so the wrapper's own
    ``step`` is never used. The outcome message is derived from the reward of
    the final move (1 -> model won, -1 -> human won, anything else -> draw).
    """
    board_env = env.envs[0].env  # raw board inside the first wrapped env
    obs = env.reset()[0]         # first (only) observation of the batch

    print("You are 'O' (represented by -1), model is 'X' (represented by 1)")
    print("Cell numbers:\n1 | 2 | 3\n4 | 5 | 6\n7 | 8 | 9\n")

    finished = False
    while not finished:
        if board_env.current_player != -1:
            # Model's turn: decode the flat index into (row, col).
            predicted, _ = model.predict(obs, deterministic=True)
            chosen = divmod(int(predicted), 3)
            open_cells = board_env.get_valid_actions()

            if chosen not in open_cells:
                # Occupied cell requested: substitute a random empty one.
                print("Model chose invalid move, selecting random valid move.")
                chosen = random.choice(open_cells)
        else:
            # Human's turn: show the board and keep asking until the answer
            # names an empty cell.
            open_cells = board_env.get_valid_actions()
            board_env.render()
            print(f"Valid moves: {[r * 3 + c + 1 for r, c in open_cells]}")

            chosen = None
            while chosen is None:
                try:
                    zero_based = int(input("Enter your move (1-9): ")) - 1
                except ValueError:
                    print("Please enter a number between 1 and 9.")
                    continue
                candidate = divmod(zero_based, 3)
                if candidate in open_cells:
                    chosen = candidate
                else:
                    print("Invalid move. Try again.")

        obs, reward, finished, _ = board_env.step(chosen)

    board_env.render()
    verdicts = {1: "Model won!", -1: "You won!"}
    print(verdicts.get(reward, "It's a draw!"))

In [7]:
# Load the checkpoint file written at the 100000-step mark of training.
model = PPO.load("./saved_models/ppo_selfplay_model_100000_steps.zip")

# An opponent-less wrapper is enough here: play_against_model drives the raw
# board itself and reads the human's moves from stdin.
opponent_env = TicTacToeSelfPlayEnv()
vec_env = DummyVecEnv([lambda: opponent_env])

play_against_model(model, vec_env)

You are 'O' (represented by -1), model is 'X' (represented by 1)
Cell numbers:
1 | 2 | 3
4 | 5 | 6
7 | 8 | 9

. . .
. X .
. . .

Valid moves: [1, 2, 3, 4, 6, 7, 8, 9]
Enter your move (1-9): 1
O . X
. X .
. . .

Valid moves: [2, 4, 6, 7, 8, 9]
Enter your move (1-9): 7
O . X
X X .
O . .

Valid moves: [2, 6, 8, 9]
Enter your move (1-9): 6
Model chose invalid move, selecting random valid move.
O X X
X X O
O . .

Valid moves: [8, 9]
Enter your move (1-9): 8
Model chose invalid move, selecting random valid move.
O X X
X X O
O O X

It's a draw!
