In [1]:
import numpy as np
import random
import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import DummyVecEnv
import os
import copy

  from pandas.core import (


## We will be taking X's as 1 and O's as -1 and by default empty is zero

In [2]:
class TicTacToeEnv:
    def __init__(self):
        self.board = np.zeros((3, 3), dtype=int)
        self.current_player = 1
        
    def reset(self):
        self.board[:] = 0
        self.current_player = 1
        return self.get_state()
    
    def get_state(self):
        return self.board.copy()
    
    def get_valid_actions(self):
        return [(i, j) for i in range(3) for j in range(3) if self.board[i, j] == 0]
    
    def step(self, action):
        i, j = action
        if self.board[i, j] != 0:
            raise ValueError("Invalid action: Cell is not empty")
            
        self.board[i, j] = self.current_player
        done, winner = self.is_done()
        reward = self.get_reward(done, winner)
        
        self.current_player *= -1
        return self.get_state(), reward, done, {"Winner": winner}
    
    def is_done(self):
        for i in range(3):
            if abs(sum(self.board[i, :])) == 3:
                return True, np.sign(sum(self.board[i, :]))
            if abs(sum(self.board[:, i])) == 3:
                return True, np.sign(sum(self.board[:, i]))
        if abs(sum([self.board[i, i] for i in range(3)])) == 3:
            return True, np.sign(sum([self.board[i, i] for i in range(3)]))
        if abs(sum([self.board[i, 2 - i] for i in range(3)])) == 3:
            return True, np.sign(sum([self.board[i, 2 - i] for i in range(3)]))
        
        if not self.get_valid_actions():
            return True, 0 # DRAW
        
        return False, None
    
    def get_reward(self, done, winner):
        if not done:
            return 0
        if winner == 1:
            return 1
        elif winner == -1:
            return -1
        else:
            return 0.5 # draw
        
    def render(self):
        symbols = {1: 'X', -1: 'O', 0: '.'}
        for row in self.board:
            print(" ".join(symbols[val] for val in row))
        print()

## Testing the class

In [3]:
env = TicTacToeEnv()

state = env.reset()
env.render()

. . .
. . .
. . .



In [4]:
for _ in range(5):
    valid_actions = env.get_valid_actions()
    action = random.choice(valid_actions)
    print(f"Player {env.current_player} plays {action}")
    
    state, reward, done, info = env.step(action)
    env.render()
    
    if done:
        print("Game Over!")
        if info["winner"] == 1:
            print("X wins!")
        elif info["winner"] == -1:
            print("O wins!")
        else:
            print("It's a draw!")
        break

Player 1 plays (0, 1)
. X .
. . .
. . .

Player -1 plays (2, 1)
. X .
. . .
. O .

Player 1 plays (0, 0)
X X .
. . .
. O .

Player -1 plays (1, 1)
X X .
. O .
. O .

Player 1 plays (2, 2)
X X .
. O .
. O X



In [5]:
env.reset()

env.render()

. . .
. . .
. . .



## Creating Gym Compatible Wrapper

In [6]:
class TicTacToeOpponentEnv(gym.Env):
    def __init__(self, opponent):
        super(TicTacToeOpponentEnv, self).__init__()
        self.env = TicTacToeEnv()
        self.opponent = opponent
        
        self.observation_space = spaces.Box(low=-1, high=1, shape=(3, 3), dtype=np.int8)
        self.action_space = spaces.Discrete(9)
        
    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)  # sets the RNG if needed

        state = self.env.reset()

        if self.env.current_player == -1:
            opp_action = self.opponent.select_action(state, self.env.get_valid_actions())
            self.env.step(opp_action)
            state = self.env.get_state()

        return state, {} 
    
    def step(self, action_index):
        valid_actions = self.env.get_valid_actions()

        row, col = divmod(action_index, 3)

        if (row, col) not in valid_actions:
            reward = -2.5
            row, col = valid_actions[np.random.randint(len(valid_actions))]
        else:
            reward = 0 
        
        state, step_reward, done, info = self.env.step((row, col))
        
        reward += step_reward

        if not done and self.env.current_player == -1:
            opp_action = self.opponent.select_action(state, self.env.get_valid_actions())
            state, reward, done, info = self.env.step(opp_action)

            if done:
                reward *= -1

        return state, reward, done, False, info
    
    def render(self):
        self.env.render()

## Logs saving

In [7]:
log_dir = "./ppo_logs"
model_dir = "./saved_models"
os.makedirs(log_dir, exist_ok=True)

## Testing against Human

In [8]:
def play_against_model(model, env):
    inner_env = env.envs[0].env  # unwrap the actual TicTacToeOpponentEnv
    obs = env.reset()[0]         # unwrap observation
    done = False

    print("You are 'O' (represented by -1), model is 'X' (represented by 1)")
    print("Cell numbers:\n1 | 2 | 3\n4 | 5 | 6\n7 | 8 | 9\n")

    while not done:
        if inner_env.current_player == -1:
            # Human's turn
            valid = inner_env.get_valid_actions()
            inner_env.render()
            print(f"Valid moves: {[r * 3 + c + 1 for r, c in valid]}")
            while True:
                try:
                    move = int(input("Enter your move (1-9): ")) - 1
                    action = (move // 3, move % 3)
                    if action in valid:
                        break
                    print("Invalid move. Try again.")
                except ValueError:
                    print("Please enter a number between 1 and 9.")
            obs, reward, done, _ = inner_env.step(action)
        else:
            # Model's turn
            action_flat, _ = model.predict(obs)
            action_flat = int(action_flat)
            action = (action_flat // 3, action_flat % 3)
            obs, reward, done, _ = inner_env.step(action)

    inner_env.render()
    if reward == 1:
        print("Model won!")
    elif reward == -1:
        print("You won!")
    else:
        print("It's a draw!")

## 1. Random Agent vs RL

In [9]:
class MinMaxAgent:
    def __init__(self):
        self.memo = {}

    def select_action(self, state, valid_actions):
        best_score = float('-inf')
        best_action = None
        
        for action in valid_actions:
            next_state = self.apply_action(state, action, player=1)  # MinMaxAgent is 1
            score = self.minimax(next_state, maximizing=False)
            if score > best_score:
                best_score = score
                best_action = action
        
        return best_action

    def minimax(self, state, maximizing):
        # Check if the state has already been evaluated
        state_tuple = tuple(map(tuple, state))  # Convert to tuple for immutability and hashing
        if state_tuple in self.memo:
            return self.memo[state_tuple]  # Return cached result

        # Evaluate the current state
        winner = self.check_winner(state)
        if winner == 1:
            self.memo[state_tuple] = 1.0
            return 1.0
        elif winner == -1:
            self.memo[state_tuple] = -1.0
            return -1.0
        elif self.is_full(state):
            self.memo[state_tuple] = 0.5
            return 0.5

        valid_actions = self.get_valid_actions(state)

        if maximizing:
            max_eval = float('-inf')
            for action in valid_actions:
                next_state = self.apply_action(state, action, player=1)
                eval = self.minimax(next_state, maximizing=False)
                max_eval = max(max_eval, eval)
            self.memo[state_tuple] = max_eval
            return max_eval
        else:
            min_eval = float('inf')
            for action in valid_actions:
                next_state = self.apply_action(state, action, player=-1)
                eval = self.minimax(next_state, maximizing=True)
                min_eval = min(min_eval, eval)
            self.memo[state_tuple] = min_eval
            return min_eval

    def apply_action(self, state, action, player):
        next_state = copy.deepcopy(state)
        i, j = action
        next_state[i][j] = player
        return next_state

    def get_valid_actions(self, state):
        actions = []
        for i in range(len(state)):
            for j in range(len(state[0])):
                if state[i][j] == 0:  # empty cells are 0, not '-'
                    actions.append((i, j))
        return actions

    def is_full(self, state):
        for row in state:
            for cell in row:
                if cell == 0:  # empty cells are 0
                    return False
        return True

    def check_winner(self, state):
        N = len(state)
        for i in range(N):
            if state[i, 0] != 0 and all(state[i, j] == state[i, 0] for j in range(N)):
                return state[i, 0]
            if state[0, i] != 0 and all(state[j, i] == state[0, i] for j in range(N)):
                return state[0, i]
        if state[0, 0] != 0 and all(state[i, i] == state[0, 0] for i in range(N)):
            return state[0, 0]
        if state[0, N-1] != 0 and all(state[i, N-1-i] == state[0, N-1] for i in range(N)):
            return state[0, N-1]
        return None


In [10]:
random_opponent = MinMaxAgent()

monitored_env = Monitor(TicTacToeOpponentEnv(random_opponent), log_dir)
checkpoint_callback = CheckpointCallback(save_freq=25000, save_path=model_dir, name_prefix="ppo_model")

model = PPO("MlpPolicy", monitored_env, verbose=1, tensorboard_log=log_dir)

model.learn(total_timesteps=100_000, callback=checkpoint_callback)

Using cpu device
Wrapping the env in a DummyVecEnv.
Logging to ./ppo_logs\PPO_5
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 3.73     |
|    ep_rew_mean     | 0        |
| time/              |          |
|    fps             | 1202     |
|    iterations      | 1        |
|    time_elapsed    | 1        |
|    total_timesteps | 2048     |
---------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.79        |
|    ep_rew_mean          | -0.005      |
| time/                   |             |
|    fps                  | 960         |
|    iterations           | 2           |
|    time_elapsed         | 4           |
|    total_timesteps      | 4096        |
| train/                  |             |
|    approx_kl            | 0.013399223 |
|    clip_fraction        | 0.171       |
|    clip_range           | 0.2         |
|    entropy_loss         | -2.19 

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.2         |
|    ep_rew_mean          | 0.975       |
| time/                   |             |
|    fps                  | 859         |
|    iterations           | 11          |
|    time_elapsed         | 26          |
|    total_timesteps      | 22528       |
| train/                  |             |
|    approx_kl            | 0.015729595 |
|    clip_fraction        | 0.231       |
|    clip_range           | 0.2         |
|    entropy_loss         | -1.49       |
|    explained_variance   | 0.179       |
|    learning_rate        | 0.0003      |
|    loss                 | -0.00306    |
|    n_updates            | 100         |
|    policy_gradient_loss | -0.0209     |
|    value_loss           | 0.103       |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.23  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.01        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 857         |
|    iterations           | 21          |
|    time_elapsed         | 50          |
|    total_timesteps      | 43008       |
| train/                  |             |
|    approx_kl            | 0.014105683 |
|    clip_fraction        | 0.169       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.891      |
|    explained_variance   | 0.17        |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0489     |
|    n_updates            | 200         |
|    policy_gradient_loss | -0.0122     |
|    value_loss           | 0.0211      |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.03  

-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3.01        |
|    ep_rew_mean          | 1           |
| time/                   |             |
|    fps                  | 854         |
|    iterations           | 31          |
|    time_elapsed         | 74          |
|    total_timesteps      | 63488       |
| train/                  |             |
|    approx_kl            | 0.018637965 |
|    clip_fraction        | 0.122       |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.724      |
|    explained_variance   | -0.0673     |
|    learning_rate        | 0.0003      |
|    loss                 | -0.0658     |
|    n_updates            | 300         |
|    policy_gradient_loss | -0.00246    |
|    value_loss           | 0.00837     |
-----------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3     

---------------------------------------
| rollout/                |           |
|    ep_len_mean          | 3         |
|    ep_rew_mean          | 1         |
| time/                   |           |
|    fps                  | 865       |
|    iterations           | 41        |
|    time_elapsed         | 96        |
|    total_timesteps      | 83968     |
| train/                  |           |
|    approx_kl            | 0.0059626 |
|    clip_fraction        | 0.068     |
|    clip_range           | 0.2       |
|    entropy_loss         | -0.579    |
|    explained_variance   | 0.908     |
|    learning_rate        | 0.0003    |
|    loss                 | 0.00881   |
|    n_updates            | 400       |
|    policy_gradient_loss | -0.00255  |
|    value_loss           | 1.92e-06  |
---------------------------------------
-----------------------------------------
| rollout/                |             |
|    ep_len_mean          | 3           |
|    ep_rew_mean          | 1     

<stable_baselines3.ppo.ppo.PPO at 0x18ace8b1f50>

In [None]:
model = PPO.load("./saved_models/ppo_model_100000_steps.zip")

opponent = MinMaxAgent()
opponent_env = TicTacToeOpponentEnv(opponent)
vec_env = DummyVecEnv([lambda: opponent_env])

play_against_model(model, vec_env)

You are 'O' (represented by -1), model is 'X' (represented by 1)
Cell numbers:
1 | 2 | 3
4 | 5 | 6
7 | 8 | 9

. . .
. X .
. . .

Valid moves: [1, 2, 3, 4, 6, 7, 8, 9]
Enter your move (1-9): 3
. . O
. X .
. X .

Valid moves: [1, 2, 4, 6, 7, 9]
Enter your move (1-9): 2
. O O
. X .
. X X

Valid moves: [1, 4, 6, 7]
