Встановлюємо потрібні для RL-моделі бібліотеки: Stable-Baselines3 та Gymnasium — бібліотеку ігрових середовищ.

In [2]:
!pip install stable-baselines3
!pip install gymnasium



Імпортуємо всі необхідні бібліотеки

In [3]:
import numpy as np
import random
import torch
import itertools
from torch import nn
from torch.nn import functional as func
import gymnasium as gym
from gymnasium import spaces
import stable_baselines3 as sb3
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor



Робимо базовий клас для гри в сапера

In [4]:
class MinesweeperGame:
    """Minimal Minesweeper engine on a square board.

    The hidden board stores, for every cell, either ``MINE`` or the number of
    mines among its (up to eight) neighbours.  A separate boolean mask tracks
    which cells are still closed.  ``score`` counts opened safe cells; it is
    set to ``MINE`` when a mine is opened and gets a ``num_mines`` bonus when
    every safe cell has been opened.
    """

    CLOSED_CELL = -1
    MINE = -2

    def __init__(self, field_size=8, num_mines=8, safe_start=False):
        """Allocate a ``field_size`` x ``field_size`` board and start a game.

        :param field_size: side length of the square board.
        :param num_mines: number of mines to place.
        :param safe_start: if true, one mine-free cell is opened up front.
        """
        self.__field = np.zeros((field_size, field_size), dtype=np.int8)
        self.__closed_cells = np.ones((field_size, field_size), dtype=bool)
        self.num_mines = num_mines
        self.reset(num_mines, safe_start)

    def reset(self, num_mines=8, safe_start=False):
        """Start a new game with ``num_mines`` mines on distinct cells.

        :param num_mines: number of mines for the new game; also stored in
            ``self.num_mines`` so the win check matches the actual board.
        :param safe_start: if true, open one random mine-free cell
            (preferring cells with no neighbouring mines) and count it.
        :raises ValueError: if ``num_mines`` does not fit on the board.
        """
        rows, cols = self.__field.shape
        if not 0 <= num_mines <= rows * cols:
            raise ValueError(
                f"num_mines must be in [0, {rows * cols}], got {num_mines}"
            )
        # Keep the win condition in make_move() in sync with the real board.
        self.num_mines = num_mines
        self.__field[:, :] = 0
        self.__closed_cells[:, :] = True
        # Sample flat indices WITHOUT replacement: drawing (y, x) pairs
        # independently could hit the same cell twice, leaving fewer mines
        # than requested and double-counting them in neighbour numbers.
        flat_cells = np.random.choice(rows * cols, size=num_mines, replace=False)
        mines = np.stack(np.unravel_index(flat_cells, (rows, cols)), axis=1)
        self.__set_mines_into_field(mines)
        self.endgame = False
        self.score = 0
        if safe_start:
            candidates = np.argwhere(self.__field == 0)
            if len(candidates) == 0:
                # Dense board: no zero cell exists, fall back to any safe cell.
                candidates = np.argwhere(self.__field != self.MINE)
            if len(candidates) > 0:
                # Use the same numpy RNG as the mine placement so a single
                # np.random.seed() call reproduces the whole board.
                safe_cell = candidates[np.random.randint(len(candidates))]
                self.__closed_cells[tuple(safe_cell)] = False
                self.score += 1

    def __set_mines_into_field(self, mines_coords):
        """Write mines and neighbour counts into the hidden board.

        :param mines_coords: integer array of shape ``(n, 2)`` holding
            ``(row, column)`` pairs.
        """
        rows, cols = self.__field.shape
        for mine_y, mine_x in mines_coords:
            # Bump the 3x3 window around the mine, clipped to the board on
            # both axes (the column bound follows the real board width).
            self.__field[
                max(mine_y - 1, 0) : min(mine_y + 2, rows),
                max(mine_x - 1, 0) : min(mine_x + 2, cols),
            ] += 1
        # Mine cells were incremented too; overwrite them with the marker.
        self.__field[tuple(mines_coords.T)] = self.MINE

    def get_mine_value(self):
        """Return the marker used for mine cells."""
        return self.MINE

    def get_closed_cell_value(self):
        """Return the marker used for closed cells in ``get_field``."""
        return self.CLOSED_CELL

    def get_field(self):
        """Return a copy of the board with closed cells masked out."""
        out_field = self.__field.copy()
        out_field[self.__closed_cells] = self.CLOSED_CELL
        return out_field

    def print_field(self):
        """Print the board: ``*`` for open mines, ``#`` for closed cells."""
        print_field = self.__field.astype(str)
        print_field[self.__field == self.MINE] = "*"
        print_field[self.__closed_cells] = "#"
        print(print_field)

    def get_moves(self):
        """Return the ``(row, column)`` coordinates of all closed cells."""
        return np.argwhere(self.__closed_cells)

    def invalid_move(self, move):
        """Return ``True`` if the cell at ``move`` is already open."""
        return not self.__closed_cells[tuple(move)]

    def make_move(self, move):
        """Open the cell at ``move`` and return the updated score.

        A move on a finished game or on an already open cell is a no-op.
        Opening a mine ends the game with ``score == MINE``; opening the last
        safe cell ends it with a ``num_mines`` bonus added to the score.  In
        both cases the whole board is revealed.
        """
        cell = tuple(move)
        if self.endgame or not self.__closed_cells[cell]:
            return self.score
        self.__closed_cells[cell] = False
        if self.__field[cell] == self.MINE:
            self.endgame = True
            self.__closed_cells[:, :] = False
            self.score = self.MINE
            return self.score
        self.score += 1
        # All safe cells opened <=> opened cells + mines cover the board.
        if self.score + self.num_mines == self.__field.size:
            self.score += self.num_mines
            self.__closed_cells[:, :] = False
            self.endgame = True
        return self.score

    def get_score(self):
        """Return the current score."""
        return self.score


Обгортаємо `MinesweeperGame` у середовище Gymnasium.

In [24]:
class MinesweeperEnv(gym.Env):
    """Gymnasium environment wrapping ``MinesweeperGame``.

    Observation: the ``field_size`` x ``field_size`` board as ``int8``, with
    closed cells and mines encoded by the game's marker values and open cells
    holding their neighbouring-mine count.
    Action: a flat cell index in ``[0, field_size ** 2)``.
    """

    metadata = {"render_modes": ["console"]}

    def __init__(self, field_size=8, num_mines=8, safe_start=True, render_mode="console"):
        """
        :param field_size: side length of the square board.
        :param num_mines: number of mines per game.
        :param safe_start: default for opening one safe cell on reset; can be
            overridden per reset through ``options["safe_start"]``.
        :param render_mode: ``"console"`` or ``None``.
        """
        super().__init__()
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        self.field_size = field_size
        self.safe_start = safe_start
        self.ms_game = MinesweeperGame(
            field_size=field_size,
            num_mines=num_mines,
            safe_start=safe_start,
        )
        self.observation_space = spaces.Box(
            low=self.ms_game.get_mine_value(),
            # A cell has up to 8 neighbours, so an open cell can show up to 8;
            # a smaller bound would make valid observations fall outside the
            # declared space.
            high=8,
            shape=(field_size, field_size),
            dtype=np.int8,
        )
        self.action_space = spaces.Discrete(field_size ** 2)

    def _get_info(self):
        """Return the auxiliary info dict (current game score)."""
        return {"score": self.ms_game.get_score()}

    def _get_obs(self):
        """Return the board as seen by the player."""
        return self.ms_game.get_field()

    def reset(self, seed=None, options=None):
        """Start a new game and return ``(observation, info)``.

        :param seed: forwarded to ``gym.Env.reset`` to seed ``self.np_random``.
        :param options: optional dict; ``options["safe_start"]`` overrides the
            ``safe_start`` value given to the constructor for this reset.
        """
        # We need the following line to seed self.np_random
        super().reset(seed=seed)
        # NOTE(review): the board itself is generated inside MinesweeperGame
        # with module-level RNGs rather than self.np_random, so `seed` does
        # not make the layout reproducible.
        safe_start = self.safe_start
        if options is not None:
            # Accept None and dicts without the key instead of raising.
            safe_start = options.get("safe_start", safe_start)
        self.ms_game.reset(self.ms_game.num_mines, safe_start)
        return self._get_obs(), self._get_info()

    def step(self, action):
        """Open the cell with flat index ``action``.

        Rewards: ``-1`` for choosing an already open cell (state unchanged),
        ``+field_size ** 2`` for winning, ``-field_size ** 2`` for hitting a
        mine, otherwise the score gained by the move.

        :return: ``(observation, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False``.
        """
        cell = np.unravel_index(action, (self.field_size, self.field_size))
        terminal_reward = self.field_size ** 2

        if self.ms_game.invalid_move(cell):
            reward = -1
        else:
            prev_score = self.ms_game.get_score()
            current_score = self.ms_game.make_move(cell)
            gained = current_score - prev_score
            if self.ms_game.endgame and current_score > 0:
                reward = terminal_reward
            elif self.ms_game.endgame and current_score < 0:
                reward = -terminal_reward
            elif gained > 0:
                reward = gained
            else:
                reward = 0
        return self._get_obs(), reward, self.ms_game.endgame, False, self._get_info()

    def render(self):
        """Print the current board to stdout."""
        self.ms_game.print_field()

    def close(self):
        """Nothing to release."""
        pass


Перевіряємо середовище на валідність

In [25]:
# Build an environment with the default board settings.
env = MinesweeperEnv()

# check_env raises if the environment does not follow the Gym interface;
# warn=True additionally prints non-fatal recommendations.
check_env(env=env, warn=True)



*Stable-Baselines3* рекомендує зробити власний Feature Extractor.

Спочатку метод `MSFeatureExtractor.transformer` перетворює вхідне поле на поле оцінок, що відображають
безпечність кроку в кожну клітинку. Після цього йдуть згорткові шари з подальшим повнозв'язним шаром.

```
Transformer -> CNN -> NN
```

In [237]:
class MSFeatureExtractor(BaseFeaturesExtractor):
    """Feature extractor: hand-crafted score map -> CNN -> linear layer.

    :param observation_space: (gym.Space) 2-D board observation space.
    :param field_size: (int) Board side length (stored, not otherwise used).
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of unit for the last layer.
    :param closed_cell_indicator: (int) Value marking closed cells in the
        observation.
    :param max_mines_indicator: (int) Largest cell number expected; multiplied
        by 8 to bound the summed danger score used for normalisation.
    """

    def __init__(
        self,
        observation_space: spaces.Box,
        field_size: int = 8,
        features_dim: int = 256,
        closed_cell_indicator: int = -1,
        max_mines_indicator: int = 4,
    ):
        super().__init__(observation_space, features_dim)
        self.closed_cell_indicator = closed_cell_indicator
        self.field_size = field_size
        self.max_danger_score = max_mines_indicator * 8
        # The CNN consumes the single-channel score map built by transformer().
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding="same"),
            nn.ReLU(),
            # nn.Conv2d(16, 8, kernel_size=3, padding="same"),
            # nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=5, padding="same"),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass on an explicit
        # (batch=1, channel=1, H, W) tensor.
        with torch.no_grad():
            probe = torch.as_tensor(observation_space.sample()[None, None]).float()
            n_flatten = self.cnn(probe).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def _get_predictable_moves_mask(self, unobserved_mask):
        """Mark cells that sit on an open/closed border along a row or column.

        :param unobserved_mask: bool tensor (batch, H, W), True for closed cells.
        :return: bool tensor of the same shape and device.
        """
        # XOR of horizontally / vertically adjacent cells: True where the
        # open/closed state changes.
        changes_x = unobserved_mask[:, :, 1:] ^ unobserved_mask[:, :, :-1]
        changes_y = unobserved_mask[:, 1:, :] ^ unobserved_mask[:, :-1, :]
        # zeros_like keeps the mask on the input's device (CPU or CUDA).
        border = torch.zeros_like(unobserved_mask)
        border[:, :, 1:] |= changes_x
        border[:, :, :-1] |= changes_x
        border[:, 1:, :] |= changes_y
        border[:, :-1, :] |= changes_y
        return border

    def _single_value_danger_score(self, field):
        """Sum every cell's 3x3 window (the cell itself included).

        :param field: float tensor (batch, H, W).
        :return: float tensor of the same shape and device.
        """
        padded = func.pad(field, (1, 1, 1, 1))
        window_sum = torch.zeros(field.shape, device=field.device)
        # Slices of the padded tensor: centred, shifted by -1 and shifted by +1.
        offsets = ((1, -1), (0, -2), (2, None))
        for (x_low, x_high), (y_low, y_high) in itertools.product(offsets, offsets):
            window_sum += padded[:, y_low:y_high, x_low:x_high]
        return window_sum

    def _get_danger_score(self, field):
        """Return per-cell danger: the sum of the numbers shown around it.

        Closed cells contribute 0.  Cells touching an open ``0`` get danger 0.
        The result is capped at ``max_danger_score`` so that, after
        normalisation in ``transformer``, a closed cell can never score below
        an already observed one.
        """
        numbers = field.clone()
        numbers[numbers == self.closed_cell_indicator] = 0.0
        danger_score = self._single_value_danger_score(numbers)
        danger_score.clamp_(max=self.max_danger_score)
        # Cast the bool mask to the field dtype before padding/summing.
        zero_neighbours = self._single_value_danger_score(
            (field == 0.0).to(field.dtype)
        )
        danger_score[zero_neighbours > 0] = 0.0
        return danger_score

    def transformer(self, mine_field):
        """Turn a batch of boards into a normalised move-priority map.

        Cells priority for the move
        Safe > unpredictable > dangerous > observed

        :param mine_field: float tensor (batch, H, W).
        :return: float tensor (batch, 1, H, W); observed cells map to -1.
        """
        unobserved_mask = mine_field == self.closed_cell_indicator
        safe_mask = self._get_predictable_moves_mask(unobserved_mask)
        danger_score = self._get_danger_score(mine_field)

        score_field = torch.zeros(mine_field.shape, device=mine_field.device)
        score_field[safe_mask] += 1.0
        score_field -= danger_score
        # Already opened cells get the lowest possible raw score.
        score_field[~unobserved_mask] = -self.max_danger_score - 1.0

        # Shift/scale from [-max-1, 1] to [0, 1], then to [-1, 1].
        score_field += self.max_danger_score + 1.0
        score_field /= self.max_danger_score + 2.0
        score_field = score_field * 2.0 - 1.0
        return score_field.unsqueeze(1)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Map board observations to ``features_dim`` features."""
        return self.linear(self.cnn(self.transformer(observations)))

# Policy settings for PPO: plug in the custom feature extractor and pass the
# constructor arguments it expects.
policy_kwargs = {
    "features_extractor_class": MSFeatureExtractor,
    "features_extractor_kwargs": {
        "field_size": 8,
        "features_dim": 256,
        "closed_cell_indicator": -1,
        "max_mines_indicator": 4,
    },
}

Запуск навчання моделі

In [13]:
# Fresh environment for training; requires `policy_kwargs` from the feature
# extractor cell to be defined first.
env = MinesweeperEnv()
env.reset()
model = PPO(
    policy="MlpPolicy",
    env=env,
    policy_kwargs=policy_kwargs,
    verbose=1,
)
model.learn(total_timesteps=500_000)

NameError: name 'policy_kwargs' is not defined

Пробна гра RL-моделі

In [14]:
# Play one episode (at most 100 moves) with the trained policy, printing the
# board after every move.
obs, info = env.reset()
env.render()
for _step in range(100):
    action, _state = model.predict(obs)
    obs, reward, done, trunc, info = env.step(action)
    if done and reward != 0:
        # Terminal step: the sign of the reward tells a loss from a win.
        print("AI lost" if reward < 0 else "AI won")
        env.render()
        break
    print(f"Reward = {reward}")
    env.render()

[['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '0' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']]
Reward = 1
[['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '0' '0' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']]
Reward = 1
[['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '0' '#' '#']
 ['#' '#' '#' '#' '#' '0' '0' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']]
Reward = 1
[['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['1' '#' '#' '#' '

Окремий клас для випробування алгоритму трансформації поля як самостійного класичного
ШІ для гри в сапера. Результатом є індекс клітинки з максимальною метрикою.

In [27]:
class MineSweeperExtractor():
    """Rule-based Minesweeper player built on the board-to-score transform.

    Every cell gets a priority score in ``[-1, 1]``; ``predict`` returns the
    flat index of the best-scoring cell.
    """

    def __init__(
        self,
        closed_cell_indicator=-1.0,
        max_mines_indicator=4,
    ):
        """
        :param closed_cell_indicator: value marking closed cells in the board.
        :param max_mines_indicator: largest cell number expected; multiplied
            by 8 to bound the summed danger score used for normalisation.
        """
        self.closed_cell_indicator = closed_cell_indicator
        self.max_danger_score = max_mines_indicator * 8

    def _get_predictable_moves_mask(self, unobserved_mask):
        """Mark cells that sit on an open/closed border along a row or column.

        :param unobserved_mask: bool tensor (batch, H, W), True for closed cells.
        :return: bool tensor of the same shape and device.
        """
        # XOR of horizontally / vertically adjacent cells: True where the
        # open/closed state changes.
        changes_x = unobserved_mask[:, :, 1:] ^ unobserved_mask[:, :, :-1]
        changes_y = unobserved_mask[:, 1:, :] ^ unobserved_mask[:, :-1, :]
        # zeros_like keeps the mask on the input's device.
        border = torch.zeros_like(unobserved_mask)
        border[:, :, 1:] |= changes_x
        border[:, :, :-1] |= changes_x
        border[:, 1:, :] |= changes_y
        border[:, :-1, :] |= changes_y
        return border

    def _single_value_danger_score(self, field):
        """Sum every cell's 3x3 window (the cell itself included).

        :param field: float tensor (batch, H, W).
        :return: float tensor of the same shape and device.
        """
        padded = func.pad(field, (1, 1, 1, 1))
        window_sum = torch.zeros(field.shape, device=field.device)
        # Slices of the padded tensor: centred, shifted by -1 and shifted by +1.
        offsets = ((1, -1), (0, -2), (2, None))
        for (x_low, x_high), (y_low, y_high) in itertools.product(offsets, offsets):
            window_sum += padded[:, y_low:y_high, x_low:x_high]
        return window_sum

    def _get_danger_score(self, field):
        """Return per-cell danger: the sum of the numbers shown around it.

        Closed cells contribute 0.  Cells touching an open ``0`` get danger 0.
        The result is capped at ``max_danger_score`` so that a closed cell can
        never score below an already observed one (which would make
        ``predict`` pick an invalid move).
        """
        numbers = field.clone()
        numbers[numbers == self.closed_cell_indicator] = 0.0
        danger_score = self._single_value_danger_score(numbers)
        danger_score.clamp_(max=self.max_danger_score)
        # Cast the bool mask to the field dtype before padding/summing.
        zero_neighbours = self._single_value_danger_score(
            (field == 0.0).to(field.dtype)
        )
        danger_score[zero_neighbours > 0] = 0.0
        return danger_score

    def forward(self, mine_field) -> torch.Tensor:
        """Score every cell of the board.

        Cells priority for the move
        Safe > unpredictable > dangerous > observed

        :param mine_field: board as a 2-D (H, W) or 3-D (batch, H, W) array
            or ``torch.Tensor``.
        :return: 1-D tensor of scores in row-major order; observed cells
            score -1.
        """
        # as_tensor handles both numpy arrays and tensors; the previous code
        # only bound the working variable for non-tensor input.
        mine_field_ = torch.as_tensor(mine_field, dtype=torch.float32)
        if mine_field_.ndim == 2:
            mine_field_ = mine_field_.unsqueeze(0)

        unobserved_mask = mine_field_ == self.closed_cell_indicator
        safe_mask = self._get_predictable_moves_mask(unobserved_mask)
        danger_score = self._get_danger_score(mine_field_)

        score_field = torch.zeros(mine_field_.shape, device=mine_field_.device)
        score_field[safe_mask] += 1.0
        score_field -= danger_score
        # Already opened cells get the lowest possible raw score.
        score_field[~unobserved_mask] = -self.max_danger_score - 1.0

        # Shift/scale from [-max-1, 1] to [0, 1], then to [-1, 1].
        score_field += self.max_danger_score + 1.0
        score_field /= self.max_danger_score + 2.0
        score_field = score_field * 2.0 - 1.0

        return torch.flatten(score_field)

    def predict(self, mine_field) -> int:
        """Return the flat index of the highest-scoring cell."""
        return self.forward(mine_field).argmax().item()


Пробний запуск трансформатора

In [28]:
# Play one episode (at most 1000 moves) with the rule-based extractor,
# printing the board after every move.
ext = MineSweeperExtractor()
obs, info = env.reset(options={"safe_start" : True})
env.render()

for _step in range(1000):
    action = ext.predict(obs)
    obs, reward, done, trunc, info = env.step(action)
    if done and reward != 0:
        # Terminal step: the sign of the reward tells a loss from a win.
        print("AI lost" if reward < 0 else "AI won")
        env.render()
        break
    print(f"Reward = {reward}")
    env.render()

[['#' '#' '#' '#' '#' '#' '#' '0']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']]
Reward = 1
[['#' '#' '#' '#' '#' '#' '1' '0']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']]
Reward = 1
[['#' '#' '#' '#' '#' '#' '1' '0']
 ['#' '#' '#' '#' '#' '#' '1' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '#' '#' '#' '#']]
Reward = 1
[['#' '#' '#' '#' '#' '#' '1' '0']
 ['#' '#' '#' '#' '#' '#' '1' '0']
 ['#' '#' '#' '#' '#' '#' '#' '#']
 ['#' '#' '#' '#' '

Хоча модель і заходить далеко, проте на кінцевій стадії такий ШІ дав збій. Тому є простір для покращення.