In [1]:
import gym

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from collections import deque
from tqdm import tqdm
import copy

In [2]:
N_ROWS, N_COLS, N_WIN = 3, 3, 3

In [3]:
class TicTacToe(gym.Env):
    def __init__(self, n_rows=N_ROWS, n_cols=N_COLS, n_win=N_WIN, clone=None):
        if clone is not None:
            self.n_rows, self.n_cols, self.n_win = clone.n_rows, clone.n_cols, clone.n_win
            self.board = copy.deepcopy(clone.board)
            self.curTurn = clone.curTurn
            self.gameOver = clone.gameOver
            self.emptySpaces = None
            self.boardHash = None
        else:
            self.n_rows = n_rows
            self.n_cols = n_cols
            self.n_win = n_win
            self.gameOver = False

            self.reset()

    def getEmptySpaces(self):
        if self.emptySpaces is None:
            res = np.where(self.board == 0)
            self.emptySpaces = np.array([ (i, j) for i,j in zip(res[0], res[1]) ])
        return self.emptySpaces

    def makeMove(self, player, i, j):
        self.board[i, j] = player
        self.emptySpaces = None
        self.boardHash = None

    def getHash(self):
        if self.boardHash is None:
            self.boardHash = ''.join(['%s' % (x+1) for x in self.board.reshape(self.n_rows * self.n_cols)])
        return self.boardHash
    
    def decodeHash(self, boardHash):
        return tuple(list(map(int,list(boardHash))))    
    
    def giveObsSpace(self):
        return tuple(3 for _ in range(self.n_cols*self.n_rows))
    
    def giveActSpace(self):
        return self.n_cols*self.n_rows

    def isTerminal(self):
        # проверим, не закончилась ли игра
        cur_marks, cur_p = np.where(self.board == self.curTurn), self.curTurn
        for i,j in zip(cur_marks[0], cur_marks[1]):
            win = False
            if i <= self.n_rows - self.n_win:
                if np.all(self.board[i:i+self.n_win, j] == cur_p):
                    win = True
            if not win:
                if j <= self.n_cols - self.n_win:
                    if np.all(self.board[i,j:j+self.n_win] == cur_p):
                        win = True
            if not win:
                if i <= self.n_rows - self.n_win and j <= self.n_cols - self.n_win:
                    if np.all(np.array([ self.board[i+k,j+k] == cur_p for k in range(self.n_win) ])):
                        win = True
            if not win:
                if i <= self.n_rows - self.n_win and j >= self.n_win-1:
                    if np.all(np.array([ self.board[i+k,j-k] == cur_p for k in range(self.n_win) ])):
                        win = True
            if win:
                self.gameOver = True
                return self.curTurn

        if len(self.getEmptySpaces()) == 0:
            self.gameOver = True
            return 0
        #self.gameOver = False
        return None

    def printBoard(self):
        for i in range(0, self.n_rows):
            print('----'*(self.n_cols)+'-')
            out = '| '
            for j in range(0, self.n_cols):
                if self.board[i, j] == 1:
                    token = 'x'
                if self.board[i, j] == -1:
                    token = 'o'
                if self.board[i, j] == 0:
                    token = ' '
                out += token + ' | '
            print(out)
        print('----'*(self.n_cols)+'-')

    def getState(self):
        return (self.getHash(), self.getEmptySpaces(), self.curTurn)

    def action_from_int(self, action_int):
        return ( int(action_int / self.n_cols), int(action_int % self.n_cols))

    def int_from_action(self, action):
        return action[0] * self.n_cols + action[1]
    
    def step(self, action):
        if self.board[action[0], action[1]] != 0:
            return self.getState(), -10 * self.curTurn, True, {}
        self.makeMove(self.curTurn, action[0], action[1])
        reward = self.isTerminal()
        #if reward is None:
        self.curTurn = -self.curTurn
        return self.getState(), 0 if reward is None else reward, reward is not None, {}

    def reset(self):
        self.board = np.zeros((self.n_rows, self.n_cols), dtype=int)
        self.boardHash = None
        self.gameOver = False
        self.emptySpaces = None
        self.curTurn = 1

In [4]:
def plot_board(env, pi, showtext=True, verbose=True, fontq=20, fontx=60):
    '''Рисуем доску с оценками из стратегии pi'''
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    X, Y = np.meshgrid(np.arange(0, env.n_rows), np.arange(0, env.n_rows))
    Z = np.zeros((env.n_rows, env.n_cols)) + .01
    s, actions = env.getHash(), env.getEmptySpaces()
    if pi is not None and s in pi.Q:
        for i, a in enumerate(actions):
            Z[a[0], a[1]] = pi.Q[s][i]
    ax.set_xticks([])
    ax.set_yticks([])
    surf = ax.imshow(Z, cmap=plt.get_cmap('Accent', 10), vmin=-1, vmax=1)
    if showtext:
        for i,a in enumerate(actions):
            if pi is not None and s in pi.Q:
                ax.text( a[1] , a[0] , "%.3f" % pi.Q[s][i], fontsize=fontq, horizontalalignment='center', verticalalignment='center', color="w" )
    for i in range(env.n_rows):
        for j in range(env.n_cols):
            if env.board[i, j] == -1:
                ax.text(j, i, "O", fontsize=fontx, horizontalalignment='center', verticalalignment='center', color="w" )
            if env.board[i, j] == 1:
                ax.text(j, i, "X", fontsize=fontx, horizontalalignment='center', verticalalignment='center', color="w" )
#     cbar = plt.colorbar(surf, ticks=[0, 1])
    ax.grid(False)
    plt.show()

def get_and_print_move(env, pi, s, actions, random=False, verbose=True, fontq=20, fontx=60):
    '''Делаем ход, рисуем доску'''
    plot_board(env, pi, fontq=fontq, fontx=fontx)
    if verbose and (pi is not None):
        if s in pi.Q:
            for i,a in enumerate(actions):
                print(i, a, pi.Q[s][i])
        else:
            print("Стратегия не знает, что делать...")
    if random:
        return np.random.randint(len(actions))
    else:
        return pi.getActionGreedy(s, len(actions))

In [5]:
def plot_test_game(env, pi1, pi2, random_crosses=False, random_naughts=True, verbose=True, fontq=20, fontx=60):
    '''Играем тестовую партию между стратегиями или со случайными ходами, рисуем ход игры'''
    done = False
    env.reset()
    while not done:
        s, actions = env.getHash(), env.getEmptySpaces()
        if env.curTurn == 1:
            a = get_and_print_move(env, pi1, s, actions, random=random_crosses, verbose=verbose, fontq=fontq, fontx=fontx)
        else:
            a = get_and_print_move(env, pi2, s, actions, random=random_naughts, verbose=verbose, fontq=fontq, fontx=fontx)
        observation, reward, done, info = env.step(actions[a])
        if reward == 1:
            print("Крестики выиграли!")
            plot_board(env, None, showtext=False, fontq=fontq, fontx=fontx)
        if reward == -1:
            print("Нолики выиграли!")
            plot_board(env, None, showtext=False, fontq=fontq, fontx=fontx)

In [None]:
env = TicTacToe(n_rows=7, n_cols=7, n_win=4)
plot_test_game(env, None, None, random_crosses=True, random_naughts=True, verbose=True, fontx=60)

# Часть 1

In [14]:
class Qlearning:
    def __init__(self, state_space, action_space, player=1, eps=0.1, alpha=0.1, gamma=1):
        self.eps = eps
        self.alpha = alpha
        self.gamma = gamma
        self.s_space = state_space
        self.a_space = action_space
        self.actions = np.arange(action_space)
        self.Q = np.zeros((*state_space, action_space))
        self.player = player
    
    def step(self, state, next_state, action, reward):
        self.Q[state][action] += self.alpha * (reward * self.player + self.gamma * self.Q[next_state].max() - self.Q[state][action])
    
    def act(self, state):
        if np.random.rand() <= self.eps:
            return np.random.choice(self.actions)
        return self.Q[state].argmax()

In [15]:
def train(agent1, agent2, env, steps):
    x_d_o = [0, 0, 0]
    actions = deque(maxlen=2)
    states = deque(maxlen=3)
    counter = 0
    for i in tqdm(range(steps)):
        env.reset()
        sHash, _, _ = env.getState()
        states.append(env.decodeHash(sHash))
        actions.append(agent1.act(states[-1]))
        observation, reward, done, inf = env.step(env.action_from_int(actions[-1]))
        states.append(env.decodeHash(observation[0]))
        while done is not True:
            turn = observation[2]
            if turn == 1:
                actions.append(agent1.act(states[-1]))
            else:
                actions.append(agent2.act(states[-1]))
            
            observation, reward, done, inf = env.step(env.action_from_int(actions[-1]))
            states.append(env.decodeHash(observation[0]))
            if abs(reward) > 1:
                if turn == 1:
                    agent1.step(states[1], states[2], actions[1], reward)
                else:
                    agent2.step(states[1], states[2], actions[1], reward)
            else:
            
                if turn == 1:
                    agent2.step(states[0], states[2], actions[0], reward)
                else:
                    agent1.step(states[0], states[2], actions[0], reward)
        if reward > 0:
            x_d_o[0] += 1
        elif reward == 0:
            x_d_o[1] += 1
        else:
            x_d_o[2] += 1
        
        if abs(reward) <= 1:
            counter += 1
        
        if i % 100000==0:
            print(counter)
            print(x_d_o)
            
    
    

In [16]:
env = TicTacToe(n_rows=3, n_cols=3, n_win=3)
obs_space = env.giveObsSpace()
act_space = env.giveActSpace()

In [17]:
agent_x = Qlearning(obs_space, act_space)
agent_o = Qlearning(obs_space, act_space, player=-1)

In [18]:
train(agent_x, agent_o, env, int(5e5))

  0%|                                    | 216/500000 [00:00<08:12, 1015.80it/s]

0
[1, 0, 0]


 20%|██████▊                           | 100173/500000 [02:01<08:05, 823.13it/s]

64431
[22573, 55045, 22383]


 40%|█████████████▌                    | 200252/500000 [04:01<05:48, 859.34it/s]

131972
[45554, 110845, 43602]


 60%|████████████████████▍             | 300184/500000 [06:03<04:06, 809.04it/s]

199909
[68565, 166798, 64638]


 80%|███████████████████████████▏      | 400212/500000 [08:05<02:10, 763.43it/s]

267965
[91368, 222941, 85692]


100%|██████████████████████████████████| 500000/500000 [10:09<00:00, 820.97it/s]


In [25]:
def test(agent1, agent2, env, n=100):
    agent1.eps = 0
    agent2.eps = 0
    total_reward = 0
    for game in range(n):
        env.reset()
        done = False
        sHash, _, _ = env.getState()
        state = env.decodeHash(sHash)
        while not done:
            act = agent1.act(state)
            obs, rew, done, _ = env.step(env.action_from_int(act))
            if done:
                break
            state = env.decodeHash(obs[0])
            act = agent2.act(state)
            obs, rew, done, _ = env.step(env.action_from_int(act))
            if done:
                break
            state = env.decodeHash(obs[0])
        total_reward += rew
        #env.printBoard()
    return total_reward / n
        

In [26]:
test(agent_x, agent_o, env)

0.0

агенты обучились играть друг с другом и выходят на ничью

Более сложная среда. При увеличении n_rows и n_cols увеличивается размер таблицы Q, при чём экспоненциально (3^(n_cols*n_rows))

In [29]:
env = TicTacToe(n_rows=4, n_cols=4, n_win=3)
obs_space = env.giveObsSpace()
act_space = env.giveActSpace()
agent_x = Qlearning(obs_space, act_space)
agent_o = Qlearning(obs_space, act_space, player=-1)

In [30]:
train(agent_x, agent_o, env, int(1e6))

  0%|                                     | 65/1000000 [00:00<25:42, 648.13it/s]

0
[1, 0, 0]


 10%|███▎                             | 100088/1000000 [03:07<29:45, 504.11it/s]

15813
[51562, 268, 48171]


 20%|██████▌                          | 200123/1000000 [06:35<32:33, 409.54it/s]

57537
[117386, 619, 81996]


 30%|█████████▉                       | 300123/1000000 [09:50<21:21, 546.30it/s]

118115
[194546, 894, 104561]


 40%|█████████████▏                   | 400064/1000000 [12:47<19:37, 509.69it/s]

179735
[271348, 976, 127677]


 50%|████████████████▌                | 500197/1000000 [15:50<12:10, 684.43it/s]

241555
[347696, 1094, 151211]


 60%|███████████████████▊             | 600140/1000000 [18:55<15:09, 439.61it/s]

303297
[423938, 1243, 174820]


 70%|███████████████████████          | 700185/1000000 [21:58<08:25, 593.13it/s]

365381
[500028, 1598, 198375]


 80%|██████████████████████████▍      | 800096/1000000 [25:04<07:27, 446.46it/s]

427993
[576217, 1827, 221957]


 90%|█████████████████████████████▋   | 900149/1000000 [28:01<03:13, 515.78it/s]

490642
[651850, 2214, 245937]


100%|████████████████████████████████| 1000000/1000000 [30:57<00:00, 538.29it/s]


Крестики побеждают в большинстве случаев, так и должно быть. Алгоритм обучается.

In [31]:
test(agent_x, agent_o, env)

1.0

# Часть 2

In [34]:
import torch
from torch import nn
from torch.optim import Adam

from copy import deepcopy
from collections import deque
import random

In [35]:
LR = 1e-3
BATCH_SIZE = 1
GAMMA = 0.99
TAU = 0.1
EPS = 0.15
PER_UPDATE = 1

In [40]:
class DQN:
    def __init__(self, player=1):
        self.steps = 0
        self.player = player
        self.model = nn.Sequential(
        nn.Conv2d(1,9,3),
        nn.ELU(),
        nn.Flatten(),
        nn.Linear(9, 9),
        nn.ELU(),
        nn.Linear(9, 9)            
        )
        self.target_model = deepcopy(self.model)
        self.buffer = deque(maxlen=9)
        self.optimizer = Adam(self.model.parameters(), lr=LR)
        self.criterion = nn.SmoothL1Loss()
        self.actions = np.arange(9)
        self.eps = 0.1
    
    def sample_batch(self):
        batch = random.sample(self.buffer, BATCH_SIZE)
        batch = list(zip(*batch))
        return batch
    
    def soft_update(self):
        for tp, sp in zip(self.target_model.parameters(), self.model.parameters()):
            tp.data.copy_((1 - TAU) * tp.data + TAU * sp.data)
    
    def step(self, batch):
        
        state, action, next_state, reward, done = batch
        
        state = torch.tensor(np.array(state, dtype=np.float32))
        action = np.array(action, dtype=np.int)
        next_state = torch.tensor(np.array(next_state, dtype=np.float32))
        reward = torch.tensor(np.array(reward, dtype=np.float32)*self.player)
        done = torch.tensor(np.array(done, dtype=np.float32))
        
        Q = self.model(state)[np.arange(BATCH_SIZE), action]
        QT = self.target_model(next_state)
        QT[done==True] = 0
        
        self.optimizer.zero_grad()
        target = reward + GAMMA * QT.max(dim=1)[0]
        loss = self.criterion(Q, target)
        loss.backward()
        self.optimizer.step()

        if self.steps % 32 == 0:
            self.target_model = deepcopy(self.model)
        
    
    def update(self, transition):
        self.steps += 1
        self.buffer.append(transition)
        if len(self.buffer)==9 and self.steps % PER_UPDATE==0:
            batch = self.sample_batch()
            self.step(batch)
    
    def act(self, state, possible, env):
        if np.random.rand() <= self.eps:
            idx = np.random.randint(len(possible))
            return env.int_from_action(possible[idx])
        state = torch.tensor(np.array(state, dtype=np.float32)).unsqueeze(0)
        act = self.model(state).argmax()
        return act.item()
            
            

In [41]:
def train(agent1, agent2, env, steps):
    x_d_o = [0, 0, 0]
    actions = deque(maxlen=2)
    states = deque(maxlen=3)
    counter = 0
    k = 0
    xo = [0,0]
    for i in tqdm(range(steps)):
        env.reset()
        sHash, possible, _ = env.getState()
        states.append([deepcopy(env.board)])
        actions.append(agent1.act(states[-1], possible, env))
        observation, reward, done, inf = env.step(env.action_from_int(actions[-1]))
        states.append([deepcopy(env.board)])
        while done is not True:
            turn = observation[2]
            possible = observation[1]
            if turn == 1:
                actions.append(agent1.act(states[-1], possible, env))
            else:
                actions.append(agent2.act(states[-1], possible, env))
            
            observation, reward, done, inf = env.step(env.action_from_int(actions[-1]))
            states.append([deepcopy(env.board)])
            if done:
                if abs(reward)==10:
                    k += 1
                    if reward == -10:
                        xo[0] += 1
                    else:
                        xo[1] += 1
                if turn == 1:
                    agent1.update((states[1], actions[1], states[2], reward, done))
                    agent2.update((states[0], actions[0], states[2], reward, done))
                else:
                    agent2.update((states[1], actions[1], states[2], reward, done))
                    agent1.update((states[0], actions[0], states[2], reward, done))
            else:            
                if turn == 1:
                    agent2.update((states[0], actions[0], states[2], reward, done))
                else:
                    agent1.update((states[0], actions[0], states[2], reward, done))

                    
                    
                    
                    
        if i % 10000==0:
            print(k)
            print(xo)
            k = 0
            xo = [0, 0]
            
    

In [42]:
agent1 = DQN()
agent2 = DQN(player=-1)
env = TicTacToe(n_rows=3, n_cols=3, n_win=3)

In [43]:
train(agent1, agent2, env, int(1e6))

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  action = np.array(action, dtype=np.int)
  0%|                                     | 8/1000000 [00:00<3:42:28, 74.91it/s]

1
[0, 1]


  1%|▎                                | 10011/1000000 [03:09<4:03:51, 67.66it/s]

8695
[4329, 4366]


  2%|▋                                | 20007/1000000 [06:07<5:04:33, 53.63it/s]

8986
[4406, 4580]


  3%|▉                                | 30009/1000000 [08:54<5:27:57, 49.29it/s]

9040
[4525, 4515]


  4%|█▎                               | 40007/1000000 [12:02<5:56:46, 44.85it/s]

6845
[4135, 2710]


  5%|█▋                               | 50011/1000000 [15:21<4:39:43, 56.60it/s]

5197
[2515, 2682]


  6%|█▉                               | 60008/1000000 [18:34<5:22:56, 48.51it/s]

7000
[2828, 4172]


  7%|██▎                              | 70007/1000000 [21:49<6:37:23, 39.00it/s]

9293
[2980, 6313]


  8%|██▋                              | 80005/1000000 [25:16<5:32:20, 46.14it/s]

8127
[3814, 4313]


  9%|██▉                              | 90007/1000000 [28:37<4:34:21, 55.28it/s]

5008
[2395, 2613]


 10%|███▏                            | 100006/1000000 [31:56<5:47:51, 43.12it/s]

5239
[2931, 2308]


 11%|███▌                            | 110008/1000000 [35:08<4:57:10, 49.91it/s]

4995
[2508, 2487]


 12%|███▊                            | 120012/1000000 [38:21<4:14:44, 57.58it/s]

3536
[1979, 1557]


 13%|████▏                           | 130009/1000000 [41:30<4:24:04, 54.91it/s]

4376
[2332, 2044]


 14%|████▍                           | 140007/1000000 [45:11<5:47:03, 41.30it/s]

4984
[2203, 2781]


 15%|████▊                           | 150007/1000000 [48:34<5:21:55, 44.01it/s]

4108
[2222, 1886]


 16%|█████                           | 160007/1000000 [51:43<4:21:05, 53.62it/s]

3538
[1942, 1596]


 17%|█████▍                          | 170008/1000000 [54:55<4:19:23, 53.33it/s]

3430
[1523, 1907]


 18%|█████▊                          | 180010/1000000 [58:23<4:28:19, 50.93it/s]

5147
[2503, 2644]


 19%|█████▋                        | 190006/1000000 [1:01:43<4:35:35, 48.99it/s]

4518
[2102, 2416]


 20%|██████                        | 200012/1000000 [1:05:08<4:04:44, 54.48it/s]

6407
[2579, 3828]


 21%|██████▎                       | 210010/1000000 [1:08:25<4:22:03, 50.24it/s]

3610
[1609, 2001]


 22%|██████▌                       | 220005/1000000 [1:11:54<5:10:30, 41.87it/s]

3611
[1903, 1708]


 23%|██████▉                       | 230009/1000000 [1:15:35<4:23:16, 48.74it/s]

4078
[2289, 1789]


 24%|███████▏                      | 240007/1000000 [1:19:48<5:09:31, 40.92it/s]

8050
[5553, 2497]


 25%|███████▌                      | 250006/1000000 [1:24:07<6:20:34, 32.84it/s]

5005
[3229, 1776]


 26%|███████▊                      | 260008/1000000 [1:27:57<4:17:57, 47.81it/s]

6023
[3820, 2203]


 27%|████████                      | 270010/1000000 [1:31:29<4:26:33, 45.64it/s]

4068
[1840, 2228]


 28%|████████▍                     | 280005/1000000 [1:35:21<4:32:45, 43.99it/s]

3488
[1992, 1496]


 29%|████████▋                     | 290009/1000000 [1:39:32<4:03:05, 48.68it/s]

3248
[1781, 1467]


 30%|█████████                     | 300007/1000000 [1:42:59<3:51:03, 50.49it/s]

3455
[1976, 1479]


 31%|█████████▎                    | 310010/1000000 [1:46:02<3:31:26, 54.39it/s]

5117
[2162, 2955]


 32%|█████████▌                    | 320011/1000000 [1:49:13<3:21:13, 56.32it/s]

7022
[2796, 4226]


 33%|█████████▉                    | 330008/1000000 [1:52:22<3:43:28, 49.97it/s]

4155
[2302, 1853]


 34%|██████████▏                   | 340012/1000000 [1:55:22<3:07:36, 58.63it/s]

4784
[2197, 2587]


 35%|██████████▌                   | 350010/1000000 [1:58:30<3:25:22, 52.75it/s]

4374
[2660, 1714]


 36%|██████████▊                   | 360006/1000000 [2:01:40<3:27:02, 51.52it/s]

7952
[3426, 4526]


 37%|███████████                   | 370006/1000000 [2:04:52<3:09:08, 55.51it/s]

5609
[3274, 2335]


 38%|███████████▍                  | 380006/1000000 [2:08:13<3:32:38, 48.59it/s]

5929
[3490, 2439]


 39%|███████████▋                  | 390011/1000000 [2:11:25<2:59:53, 56.52it/s]

4078
[2878, 1200]


 40%|████████████                  | 400005/1000000 [2:14:28<3:19:39, 50.09it/s]

3418
[2148, 1270]


 41%|████████████▎                 | 410009/1000000 [2:17:28<3:00:12, 54.57it/s]

5817
[3802, 2015]


 42%|████████████▌                 | 420012/1000000 [2:20:34<2:49:59, 56.86it/s]

4733
[3108, 1625]


 43%|████████████▉                 | 430007/1000000 [2:23:43<2:48:17, 56.45it/s]

4407
[2423, 1984]


 44%|█████████████▏                | 440011/1000000 [2:26:48<2:40:19, 58.21it/s]

2931
[1541, 1390]


 45%|█████████████▌                | 450009/1000000 [2:30:01<3:23:55, 44.95it/s]

3409
[1875, 1534]


 46%|█████████████▊                | 460011/1000000 [2:33:12<2:43:46, 54.95it/s]

3557
[1947, 1610]


 47%|██████████████                | 470007/1000000 [2:36:25<2:51:11, 51.60it/s]

4855
[2752, 2103]


 48%|██████████████▍               | 480009/1000000 [2:39:41<2:50:22, 50.87it/s]

6320
[3417, 2903]


 49%|██████████████▋               | 490011/1000000 [2:42:56<2:34:09, 55.13it/s]

4268
[2064, 2204]


 50%|███████████████               | 500013/1000000 [2:46:23<2:22:35, 58.44it/s]

4860
[2075, 2785]


 51%|███████████████▎              | 510008/1000000 [2:49:40<2:26:43, 55.66it/s]

5452
[3098, 2354]


 52%|███████████████▌              | 520009/1000000 [2:52:56<2:56:56, 45.21it/s]

5865
[3979, 1886]


 53%|███████████████▉              | 530007/1000000 [2:56:11<3:20:39, 39.04it/s]

5904
[3793, 2111]


 54%|████████████████▏             | 540007/1000000 [2:59:45<2:30:28, 50.95it/s]

4259
[2808, 1451]


 55%|████████████████▌             | 550007/1000000 [3:03:12<2:46:49, 44.95it/s]

5757
[3589, 2168]


 56%|████████████████▊             | 560008/1000000 [3:06:11<2:20:31, 52.18it/s]

6473
[4103, 2370]


 57%|█████████████████             | 570009/1000000 [3:09:26<2:11:35, 54.46it/s]

6519
[4204, 2315]


 58%|█████████████████▍            | 580009/1000000 [3:12:37<1:57:50, 59.40it/s]

3656
[2267, 1389]


 59%|█████████████████▋            | 590012/1000000 [3:16:11<2:08:56, 53.00it/s]

4557
[2997, 1560]


 60%|██████████████████            | 600007/1000000 [3:19:30<2:14:40, 49.50it/s]

3447
[2182, 1265]


 61%|██████████████████▎           | 610011/1000000 [3:22:24<1:55:48, 56.13it/s]

6268
[4100, 2168]


 62%|██████████████████▌           | 620005/1000000 [3:25:37<1:59:18, 53.08it/s]

4333
[2706, 1627]


 63%|██████████████████▉           | 630008/1000000 [3:28:41<1:55:16, 53.49it/s]

4579
[3219, 1360]


 64%|███████████████████▏          | 640009/1000000 [3:31:46<1:58:51, 50.48it/s]

4202
[3291, 911]


 65%|███████████████████▌          | 650007/1000000 [3:34:59<1:44:58, 55.57it/s]

3852
[2732, 1120]


 66%|███████████████████▊          | 660008/1000000 [3:38:14<1:36:26, 58.75it/s]

6893
[4945, 1948]


 67%|████████████████████          | 670009/1000000 [3:41:33<1:42:42, 53.55it/s]

6566
[5063, 1503]


 68%|████████████████████▍         | 680009/1000000 [3:44:56<1:58:53, 44.86it/s]

6930
[5629, 1301]


 69%|████████████████████▋         | 690007/1000000 [3:48:16<1:56:09, 44.48it/s]

6807
[4926, 1881]


 70%|█████████████████████         | 700005/1000000 [3:51:29<1:58:33, 42.17it/s]

4656
[3153, 1503]


 71%|█████████████████████▎        | 710007/1000000 [3:54:58<1:36:19, 50.18it/s]

6157
[4945, 1212]


 72%|█████████████████████▌        | 720007/1000000 [3:58:09<1:57:32, 39.70it/s]

4274
[2921, 1353]


 73%|█████████████████████▊        | 725480/1000000 [4:00:29<1:30:59, 50.28it/s]


KeyboardInterrupt: 

Очень долго я боролся с этой задачей, но завести её так и не получилось. Пробовал разный размер буфера опыта, различные константы, soft- и hard- обновление таргет-сети. Однако, сеть всё равно не обучается не ставить х и o на запрещенные позиции (там, где уже стоят значки). В процессе обучения, выводимое число - это кол-во игр, которые закончились, потому что агент поставил значок на занятую позицию. Видно, что со временем это число падает, так что скорее всего, если усложнить сеть и и добавить больше эпох, то она обучится. Также я пришёл к выводу, что размер батча лучше оставить равным единице. При обучениие dqn на различных компьютерных играх последовательные стейты не сильно отличаются друг от друга, а здесь каждый стейт особенный.

# Часть 3

In [6]:
from collections import defaultdict
from math import log
from IPython.display import clear_output

In [7]:
class Rollouts:
    def __init__(self, player=1):
        self.player = player
    
    def rand_policy(self, possib):
        idx = np.random.randint(len(possib))
        return possib[idx]
    
    def smart_policy(self, env, possib): #если можем закончить игру - заканчиваем
        for ac in possib:
            cop = TicTacToe(clone=env)
            if cop.step(ac)[2]:
                return ac
        return self.rand_policy(possib)
        
    
    def roll(self, env):
        temp_env = TicTacToe(clone=env)
        _, empty, turn = temp_env.getState()
        reward = temp_env.isTerminal()
        done = reward is not None
        while not done:
            #action = self.rand_policy(empty)
            action = self.smart_policy(env, empty)
            observation, reward, done, _ = temp_env.step(action)
            empty = observation[1]
        reward *= self.player
        return reward
    
    def n_roll_outs(self, env, n=100):
        results = 0
        for _ in range(n):
            res = self.roll(env)
            results += res
        return results / n
        
        

In [8]:
class Node:
    def __init__(self, state):
        self.state = state
        self.parent = None
        self.childs = []
        self.nxt_states_set = set()
    
    def __hash__(self):
        return self.state.env.getHash().__hash__()
    
#     def __repr__(self):
#         self.state.env.printBoard()
#         return ''
        
    
    def __str__(self):
        self.state.env.printBoard()
        print('visits', self.state.visits)
        print('score', self.state.score)
        return ''
        
        

In [9]:
class State:
    def __init__(self, env):
        self.env = env
        self.visits = 0
        self.score = 0
    
    def next_states(self):
        if self.env.isTerminal():
            return None
        _, empty, _ = self.env.getState()
        states = []
        for act in empty:
            temp = TicTacToe(clone=self.env)
            temp.step(act)
            states.append(State(temp))
        return states

In [11]:
class MCTS:
    def __init__(self, n_cols, n_rows, n_wins, player=1):
        env = TicTacToe(n_rows, n_cols, n_wins)
        state = State(env)
        self.root = Node(state)
        self.current_node = self.root
        self.ro = Rollouts(player)
        self.player = player
        if player < 0:
            self.run(n_cols*n_rows)
    
    def utb(self, node):
        if node.state.visits == 0:
            return 1e9
        utb_score = -1 if self.player==node.state.env.curTurn else 1
        utb_score *= node.state.score / node.state.visits
        utb_score += 1.41 * (log(node.parent.state.visits) / node.state.visits)**0.5
        return utb_score
    
    def choose_next(self, nodes):
        scores = []
        for nxt in nodes:
            scores.append(self.utb(nxt))
        return nodes[np.array(scores).argmax()]
    
    def selection(self, node, path=None):
        if not path:
            path = []
        path.append(node)
        if node.state.env.gameOver:
            return path

        if len(node.childs) != len(node.state.next_states()):
            return path
        nxt = self.choose_next(node.childs)
        return self.selection(nxt, path)
    
    def expansion(self, node):
        if node.state.env.gameOver:
            return node
        states = node.state.next_states()
        for state in states:
            if state.env.getHash() not in node.nxt_states_set:
                new_node = Node(state)
                node.childs.append(new_node)
                node.nxt_states_set.add(new_node.state.env.getHash())
                new_node.parent = node
                break
        return new_node
    
    def simulation(self, node):
        if node.state.env.gameOver:
            rew = node.state.env.curTurn * -1 * self.player * len(node.state.env.getEmptySpaces())
            node.state.visits += 1
            node.state.score += rew
            #print(node.state.env.isTerminal())
            return rew
        mean_reward = self.ro.n_roll_outs(node.state.env)
        node.state.visits += 1
        node.state.score += mean_reward
        return mean_reward
    
    def backup(self, path, reward):
        for node in path[::-1]:
            node.state.visits += 1
            node.state.score += reward * 0.9
    
    def run(self, n=1):
        for _ in range(n):
            path = self.selection(self.current_node)
            new_node = self.expansion(path[-1])
            rew = self.simulation(new_node)
            self.backup(path, rew)
    
    def step(self, n=100):
        self.run(n)
        max_rew = -10
        action = None
        for node in self.current_node.childs:
            score = node.state.score / node.state.visits
            if score > max_rew:
                max_rew = score
                nxt_node = node
        
        action = np.argmax(np.abs(self.current_node.state.env.board - nxt_node.state.env.board))
        self.current_node = nxt_node
        return action, nxt_node.state
        
        
    def opponent_step(self, state):
        st_hash = state.env.getHash()
        if st_hash in self.current_node.nxt_states_set:
            for node in self.current_node.childs:
                if st_hash == node.state.env.getHash():
                    self.current_node = node
                    return
        copied_env = TicTacToe(clone=state.env)
        new_state = State(copied_env)
        new_node = Node(new_state)
        new_node.parent = self.current_node
        self.current_node.childs.append(new_node)
        self.current_node.nxt_states_set.add(new_node.state.env.getHash())
        self.current_node = new_node
        return
    
    def reset(self):
        self.current_node = self.root

In [12]:
from time import sleep

In [13]:
def play(agent1, agent2, env):
    agent1.reset()
    agent2.reset()
    env.reset()
    done = False
    while not done:
        #clear_output()
        env.printBoard()
        sleep(1)
        ac, state = agent1.step(100)
        _, _, done, _ = env.step(env.action_from_int(ac))
        if done:
            break
        #clear_output()
        env.printBoard()
        sleep(1)
        agent2.opponent_step(state)
        ac, state = agent2.step(100)
        _, _, done, _ = env.step(env.action_from_int(ac))
        if done:
            break
        agent1.opponent_step(state)
    #clear_output()
    env.printBoard()
        
        

In [314]:
ag1 = MCTS(3, 3, 3)
ag2 = MCTS(3,3,3,-1)
env = TicTacToe(3,3,3)


в качестве доказательства работоспособности кода я решил смодулировать игру в крестики-нолики 3х3. При больших размерах нужно проводить больше запусков цикла (selection-expansion-simulating-backup), что занимает намного больше времени. В 3х3 MSTS всегда выходит на ничью, как и должно быть.

In [315]:
play(ag1, ag2, env)

-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
|   |   |   | 
-------------
-------------
|   |   |   | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
-------------
|   |   | o | 
-------------
|   | x |   | 
-------------
|   |   |   | 
-------------
-------------
|   |   | o | 
-------------
|   | x | x | 
-------------
|   |   |   | 
-------------
-------------
|   |   | o | 
-------------
| o | x | x | 
-------------
|   |   |   | 
-------------
-------------
| x |   | o | 
-------------
| o | x | x | 
-------------
|   |   |   | 
-------------
-------------
| x |   | o | 
-------------
| o | x | x | 
-------------
|   |   | o | 
-------------
-------------
| x |   | o | 
-------------
| o | x | x | 
-------------
|   | x | o | 
-------------
-------------
| x | o | o | 
-------------
| o | x | x | 
-------------
|   | x | o | 
-------------
-------------
| x | o | o | 
-------------
| o | x | x | 
-------------
| x | x | o | 
----