In [379]:
from __future__ import annotations

import os

from enum import Enum
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.utils import EzPickle
import torch

from typing import Optional, Union

# Support functions

In [380]:
WINNING_COMBINATIONS = torch.tensor((
        (0, 1, 2),
        (3, 4, 5),
        (6, 7, 8),
        (0, 3, 6),
        (1, 4, 7),
        (2, 5, 8),
        (0, 4, 8),
        (2, 4, 6),
))

NULL_ACTION = -1

In [381]:
def n_empty_cells(cells):
    return (cells == 0).int().sum(dim=1)

def pos_to_coord(pos, size=9):
    """Shape (num_envs) -> (2, num_envs)"""
    return torch.stack((pos // size, pos % size))

def coord_to_pos(coord, size):
    """(2, num_envs) -> (num_envs)"""
    assert coord.shape[0] == 2
    return coord[0]*size + coord[1]

def coord_to_super_coord(coord):
    return coord // 3

def coord_to_sub_coord(coord):
    return coord - 3* coord_to_super_coord(coord)

def decompose(cells):
    """Returns subcells from cells. Takes (B, 81) or (B, 9, 9)"""
    return cells.view(-1, 3, 3, 3, 3).permute(0, 1, 3, 2, 4).reshape(-1, 9, 3, 3)

In [382]:
g.current_pos

tensor([40, 80, 21], dtype=torch.int32)

In [357]:
coord_to_super_coord(pos_to_coord(torch.tensor([2, 80, 40]), 9))

tensor([[0, 2, 1],
        [0, 2, 1]])

# Prove

In [373]:
temp = torch.randint(0, 3, (4, 81))
dec = decompose(temp).reshape(-1, 9)
get_game_status(dec).reshape(-1, 9)

tensor([[0, 0, 2, 0, 0, 0, 1, 1, 1],
        [2, 1, 1, 0, 1, 2, 2, 2, 0],
        [0, 0, 0, 0, 2, 0, 0, 1, 0],
        [2, 1, 1, 0, 2, 0, 1, 0, 0]])

In [342]:
temp = torch.randint(0, 3, (8,1))
temp = temp.reshape(-1, 9, 9)
print(temp)
temp_small = temp[:, 3:6, 0:3]

RuntimeError: shape '[-1, 9, 9]' is invalid for input of size 8

In [340]:
temp_small

tensor([[[2, 1, 2],
         [2, 2, 1],
         [0, 1, 2]],

        [[0, 1, 0],
         [2, 1, 2],
         [2, 2, 2]],

        [[0, 1, 1],
         [1, 1, 0],
         [2, 1, 0]],

        [[2, 1, 2],
         [1, 2, 2],
         [1, 1, 0]]])

In [341]:
get_game_status(temp_small.reshape(4, -1))

tensor([2, 2, 1, 0])

Variations:
- can/cannot play in won squares

In [318]:
temp = torch.randint(0, 2, (3, 9)).bool()
a = torch.randint(0, 9, (3,))
temp.view(-1, 3, 3), a

(tensor([[[False,  True, False],
          [ True,  True,  True],
          [ True, False, False]],
 
         [[False, False,  True],
          [False, False, False],
          [ True,  True, False]],
 
         [[ True, False, False],
          [ True, False,  True],
          [ True, False,  True]]]),
 tensor([7, 7, 2]))

In [319]:
temp[torch.arange(len(a)), a].all()

tensor(False)

# Class

In [497]:
t = torch.ones((4, 81))
s = t / t.sum(dim=1).unsqueeze(1)
s[torch.tensor([True, False, False, False])].shape

torch.Size([1, 81])

In [498]:
class Player:
    def get_actions(self, obs):
        raise NotImplementedError

class Bot(Player):

    def get_actions(self, obs: torch.Tensor) -> torch.Tensor:
        """Input a tensor (B, 4, 81). Output action tensor (num_envs,)"""
        mask = obs['mask']
        prob = mask / mask.sum(dim=1).unsqueeze(1)
        # deal with empty mask
        empty_mask = mask.sum(dim=1) == 0
        prob[empty_mask] = 1/81
        return torch.multinomial(prob, num_samples=1).squeeze(dim=1)

In [499]:
t = torch.stack([torch.arange(9) for _ in range(3)])
s = torch.tensor([0, 2, 3])
t[torch.arange(s.shape[0]), s] != 0

tensor([False,  True,  True])

In [512]:
class TorchEnv:
    metadata = {"render_modes": ["rgb_array"]}

    def __init__(
        self,
        num_envs: int = 1, # number of parallel envs
        max_episode_length: int  = 100, # max_episode_length
        max_steps: Optional[int] = None,
        num_rewards: int = 1,
        enable_open_sub_board: bool = False,
        bot: Optional[Player] = None,
        agent_role: Union[int, torch.Tensor] = 0,
        ) -> None:

        # ---ENVIRONMENT MANAGEMENT---
        self.num_envs = num_envs
        self.max_episode_length = max_episode_length
        self.max_steps = max_steps # is None, each episode is run only once
        self.enable_open_sub_board = enable_open_sub_board # allows to place symbols in completed subboards when choosing freely

        self.num_rewards = num_rewards # number of rewards (rewards classes)

        # roles: 0 -> X / 1 -> O
        if isinstance(agent_role, torch.Tensor):
            self.agent_role = agent_role
        elif isinstance(agent_role, int):
            self.agent_role = torch.ones(num_envs) * agent_role
        self.bot = bot if bot is not None else Bot()
        self.bot_role = (self.agent_role + 1) % 2

        self.expanded_cells = torch.zeros((self.num_envs, 82))
        self.cells = self.expanded_cells[:, :81] # cannot see dead cell of expanded cells
        self.super_cells = torch.zeros((self.num_envs, 9))
        self.augmented_super_cells = torch.zeros((self.num_envs, 81))
        # TODO: observation, action space

        # stats
        self._turn = 0 # global turn counter
        self.current_turn = torch.zeros(num_envs) # turn counters for each env

        self.game_status = torch.zeros(num_envs)
        self.terminated = torch.zeros(self.num_envs)
        self.truncated = torch.zeros(self.num_envs)

        self.last_action = torch.ones(self.num_envs) * -1
        self.current_pos = torch.ones(self.num_envs) * -1
        self.current_mask = torch.ones(self.num_envs, 81)
        self.mask1 = None
        self.mask2 = None

        # utils
        self._all_pos = torch.stack([torch.arange(81) for _ in range(self.num_envs)])
    
    def _get_obs(self):
        return {
            'cells': self.cells,
            'mask': self.current_mask,
            'last_action': self.last_action,
            'turn': self.current_turn,
        }

    def step(self, actions):
        # PLAYER TURN
        
        # TODO: validation must take into account finished games
        # assert self.valid_actions(actions), "Some actions are not valid"

        # Player turn
        self.turn(actions, self.agent_role)

        # BOT TURN
        bot_actions = self.bot.get_actions(self._get_obs())
        self.turn(bot_actions, self.bot_role)

        return self._get_obs(), self._get_rewards(), self._is_terminated(), self._is_truncated(), self._get_info()
    
    def turn(self, actions, player):
        """
        Manages player's turn (either agent or bot)
        """
        # nullify action if game ended
        # nullified actions do not propagate to actual cells
        # they are registered on a death cell
        old_actions = actions.clone()
        actions[self.game_status != 0] = -1

        # update cells
        self.expanded_cells[torch.arange(len(actions)), actions] = player + 1
        self.super_cells = self._get_game_status(decompose(self.cells).reshape(-1, 9)).reshape(-1, 9)
        self.augmented_super_cells = self.super_cells.reshape(-1, 3, 3).repeat_interleave(3, dim=1).repeat_interleave(3, dim=2).reshape(-1, 81)
        
        # check game status
        self._turn += 1
        self.current_turn[actions != -1] += 1
        self.game_status = self._get_game_status(self.super_cells)
        self.terminated = self._is_terminated()
        self.truncated = self._is_truncated()
        self.last_action = actions # (num_envs,)
        self.current_pos = actions # TODO set to -1 when the cell becomes a tris
        super_coord = coord_to_sub_coord(pos_to_coord(old_actions, 9))
        super_pos = coord_to_pos(super_coord, 3)

        occupied_mask = self.super_cells[torch.arange(super_pos.shape[0]), super_pos]!=0
        self.current_pos[occupied_mask] = -1

        # self.current_pos[self.super_cells[torch.arange(actions.shape[0]), actions]!=0] = -1
        self.current_mask = self._get_mask()

    def reset(self, agent_role=None):
        # TODO: initialize everything
        if agent_role is None:
            self.agent_role = torch.zeros(self.num_envs) #torch.randint(0, 2, (num_envs))
        else:
            self.agent_role = agent_role

        self.bot_role = (self.agent_role + 1) % 2
        self.expanded_cells = torch.zeros((self.num_envs, 82))
        self.cells = self.expanded_cells[:, :81] # cannot see dead cell of expanded cells
        self.super_cells = torch.zeros((self.num_envs, 9))
        self.augmented_super_cells = torch.zeros((self.num_envs, 81))

        self._turn = 0
        self.current_turn = torch.zeros(self.num_envs)
        self.game_status = torch.zeros(self.num_envs)
        self.last_action = torch.ones(self.num_envs) * -1
        self.current_pos = torch.ones(self.num_envs) * -1
        self.current_mask = torch.ones(self.num_envs, 81)

        # TODO: false start when agent is player 2
        if not ((self.agent_role==0).all() or (self.agent_role==1).all()):
            bot_actions = self.bot.get_actions(self._get_obs())
            bot_actions[self.agent_role==0] = -1 # keep the actions only for games where the bot starts
            self.turn(bot_actions, self.bot_role)
        return self._get_obs(), self._get_info()

    def _get_mask(self, verbose=False):
        # 1. mask based on last action
        super_coord = coord_to_sub_coord(pos_to_coord(self.current_pos, 9))
        # invalid actions
        super_coord[:, (self.current_pos==-1)] = -1
        super_coord = super_coord.unsqueeze(-1)
        all_super_coord = coord_to_super_coord(pos_to_coord(self._all_pos, 9))
        range_width = torch.where(super_coord!=-1, 1, 4)
        filter = (
            (super_coord <= all_super_coord) &
            (all_super_coord < (super_coord + range_width))
        )
        # mask1 should take into account also if the destination is occupied
        mask = filter[0] & filter[1]
        mask1 = mask.clone()

        # 2. filter out cells belonging to completed sub-boards
        if not self.enable_open_sub_board:
            mask &= (self.augmented_super_cells == 0)
        mask2 = mask.clone()

        # 3. filter out occupied cells
        mask &= (self.cells == 0)
        if not verbose:
            return mask.int()
        else:
            return torch.stack([mask1, mask2, mask])

    def valid_actions(self, actions):
        """Returns True if all actions are valid"""
        valid_positions_cond = (self.current_mask[torch.arange(actions.shape[0]), actions] == 1)
        game_ended_cond = (self.game_status != 0) & (actions == -1)
        return (valid_positions_cond | game_ended_cond).all()

    def _get_game_status(self, cells):
        """
        Takes in input a batch (B, 9), each representing a 3x3 game.
        Return status (winner, TTT_TIE if no winner, or TTT_GAME_NOT_OVER) for each game
        """
        results = torch.zeros(cells.shape[0], dtype=torch.int64)
        # check player winning
        for indices in WINNING_COMBINATIONS:
            triplets = cells[:, indices]
            pl1_triplets = (triplets == 1).all(dim=1)
            pl2_triplets = (triplets == 2).all(dim=1)

            # Note: this two conditions are never overlapping
            results[pl1_triplets] = 1 # TODO: use enum
            results[pl2_triplets] = 2

        # check ties
        cells_filled = n_empty_cells(cells) == 0
        games_not_won = results == 0
        results[games_not_won & cells_filled] = -1
        return results
    
    def _is_terminated(self):
        # all games reached the end
        return (self.game_status!=0).all()
    
    def _is_truncated(self):
        #TODO
        return torch.zeros(self.num_envs).bool()

    def _get_info(self):
        # TODO
        return None
    
    def _get_rewards(self):
        # TODO
        return torch.zeros((self.num_envs, self.num_rewards))

In [516]:
g = TorchEnv(5000)
obs, _ = g.reset()
# obs, rew, term, trunc, info = g.step(torch.tensor([0, 1, 9]).int()) #g.step(torch.tensor([36, 79, 2]).int())
pl = Bot()
while not (g.terminated.all() or g.truncated.all()):
    action = pl.get_actions(obs).int()
    obs, rew, term, trunc, info = g.step(action) #g.step(torch.tensor([36, 79, 2]).int())

In [518]:
g.current_turn.mean()

tensor(58.7838)