In [1]:
import itertools
import pickle
import random
import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt

%matplotlib inline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [1]:
PATH = ""

In [0]:
# Игровое поле
class GameBoard:
    def __init__(self):        
        self.reset()
        
    def reset(self):
        self.board = np.zeros((3, 3)).astype(int)
        self.done = False
        self.X_reward = 0.
        self.O_reward = 0.
    
    def set_sign(self, coords, sign):
        # Выполняет ход, если возможно
        assert(type(coords) is tuple and len(coords)==2)
        if self.board[coords] != 0:
            return False
        if sign==1 or sign==-1:
            self.board[coords] = sign
            return True
        return False
    
    def get_state(self):
        return self.board
    
    def get_actions(self):
        # Возвращает допустимые действия
        return np.argwhere(self.board==0)
    
    def check_win(self, sign, state=None):
        # Проверяет выигрыш игрока
        if state is None:
            state = self.board
        win = False
        for i in range(3):
            if (state[i, :] == sign).sum() == 3:                
                win = True
            if (state[:, i] == sign).sum() == 3:
                win = True
        diag = np.array([state[0,0], state[1,1], state[2,2]])
        if (diag == sign).sum() == 3:
            win = True        
        diag = np.array([state[2,0], state[1,1], state[0,2]])
        if (diag == sign).sum() == 3:
            win = True
        return win
    
    def check_draw(self):
        # Проверяет ничью
        if 0 not in self.board:
            return True
    
    def check_done(self):
        # Проверяет завершение партии
        if self.check_win(1):
            #print("X won")
            self.X_reward = 1.
            self.O_reward = 0.
            self.done = True
        if self.check_win(-1):
            #print("O won")
            self.O_reward = 1.
            self.X_reward = 0.
            self.done = True
        if self.check_draw():
            #print("Draw")
            self.X_reward = 0.1
            self.O_reward = 0.1
            self.done = True        
        return self.done
    
    def get_reward(self, player):
        # Возвращает награду
        if player.sign == 1:
            return self.X_reward
        if player.sign == -1:
            return self.O_reward
    
    def print(self):
        # Печатает поле
        sign = lambda x: "X" if x == 1 else "O" if x == -1 else " "
        print("  0 1 2 ")
        print("  - - - ")
        print("0|{}|{}|{}|".format(sign(self.board[0,0]), sign(self.board[0,1]), sign(self.board[0,2])))
        print("  - - - ")
        print("1|{}|{}|{}|".format(sign(self.board[1,0]), sign(self.board[1,1]), sign(self.board[1,2])))
        print("  - - - ")
        print("2|{}|{}|{}|".format(sign(self.board[2,0]), sign(self.board[2,1]), sign(self.board[2,2])))
        print("  - - - ")
        

In [0]:
"""
Класс игрока.
Включает бот для игры, стремящийся поставить три в ряд и
блокирующий такие попытки оппонента. Пытается вилки делать.
Простенький, но играть может.
"""
class Player:
    def __init__(self, board, side="X"):
        self.board = board
        self.sign = 1 if side=="X" else -1 if side=="O" else 0
            
    def __check_danger(self, coords, sign, state):  
        # Проверяет опасную ситуацию      
        c0 = coords[0]
        c1 = coords[1]        
        if (state[c0, :]==-sign).sum()==2:
            return True
        if (state[:, c1]==-sign).sum()==2:
            return True
        if c0==c1:
            diag = np.array([state[0,0], state[1,1], state[2,2]])
            if (diag==-sign).sum()==2:
                return True
        if np.abs(c0-c1)==2:
            diag = np.array([state[2,0], state[1,1], state[0,2]])
            if (diag==-sign).sum()==2:
                return True
        if c0==1 and c1==1:
            diag = np.array([state[0,0], state[1,1], state[2,2]])
            if (diag==-sign).sum()==2:
                return True        
            diag = np.array([state[2,0], state[1,1], state[0,2]])
            if (diag==-sign).sum()==2:
                return True
        return False
    
    def __find_danger(self, sign, state):         
        return np.argwhere(state==0)[np.array([self.__check_danger(c, sign, state) for c in self.get_actions(state)])]
    
    def get_actions(self, state):
        return np.argwhere(state==0)
    
    def make_move(self, coords):
        # Для совершения хода вручную
        if self.board.check_done():
            return False
        if self.board.set_sign(coords, self.sign):
            self.board.check_done()
            return True
        print("Invalid move")
        return False
    
    def act(self):
        # Совершение хода ботом
        if self.board.check_done():
            return False
        wins = self.__find_danger(-self.sign, self.board.get_state())
        if len(wins) > 0:
            self.make_move(tuple(wins[0]))
            return True
        dangers = self.__find_danger(self.sign, self.board.get_state())
        if len(dangers) > 0:
            self.make_move(tuple(dangers[0]))
            return True
        actions = self.get_actions(self.board.get_state())
        if [1, 1] in actions.tolist():            
            self.make_move((1, 1))
            return True
        else:
            potential = 0
            action = actions[np.random.randint(len(actions))]
            for a in actions:                
                test_state = self.board.get_state().copy()
                test_state[tuple(a)] = self.sign
                if len(self.__find_danger(sign=-self.sign, state=test_state)) > potential:
                    potential = len(self.__find_danger(sign=-self.sign, state=test_state))
                    action = a                    
            self.make_move(tuple(action))        
        return True

In [0]:
gb = GameBoard()
pla = Player(gb)
bot = Player(gb, "O")
gb.print()

  0 1 2 
  - - - 
0| | | |
  - - - 
1| | | |
  - - - 
2| | | |
  - - - 


In [0]:
# Можно поиграть вручную, задавая коордтнаты в make_move()
if(pla.make_move((1,2))):
    gb.print()
    print(gb.get_reward(pla))
    if(bot.act()):
        gb.print()
        print(gb.get_reward(bot))

  0 1 2 
  - - - 
0|O|O|X|
  - - - 
1| |X|X|
  - - - 
2|O| |X|
  - - - 
1.0


In [0]:
# Класс таблицы Q(s,a)
class QTable:
    def __init__(self):
        self.__generate_positions()
        self.all_actions = np.array([p for p in itertools.product(range(3), repeat=2)])
        self.__generate_table()
    
    def __check_position(self, pos):
        # Проверяет валидность сгенерированных позиций
        valid = True
        count = 0           
        for i in range(3):
            if (pos[i, :] == 1).sum() == 3:
                count += 1
            if (pos[i, :] == -1).sum() == 3:
                count += 1
            if (pos[:, i] == 1).sum() == 3:
                count += 1
            if (pos[:, i] == -1).sum() == 3:
                count += 1
        diag = np.array([pos[0,0], pos[1,1], pos[2,2]])
        if (diag == 1).sum() == 3:
            count += 1
        if (diag == -1).sum() == 3:
            count += 1
        diag = np.array([pos[2,0], pos[1,1], pos[0,2]])
        if (diag == 1).sum() == 3:
            count += 1
        if (diag == -1).sum() == 3:
            count += 1
        if count >= 1:
            valid = False
        return valid
    
    def __generate_positions(self):
        # Создает список всех возможных позиций.
        positions = []
        for i, c in enumerate(itertools.product(range(3), repeat=9)):
            c = np.array(c)-1
            if (c==1).sum() >= (c==-1).sum() and (c==1).sum() <= (c==-1).sum()+1:
                c = c.reshape([3,3])
                if self.__check_position(c):
                    positions.append(c)
        self.positions = np.array(positions).astype(int)
    
    def __generate_table(self):
        # Создает q-таблицу с нулями и наградами на выигрышных позициях (для крестиков)
        self.table = np.zeros((9, len(self.positions)))
        for i, pos in enumerate(self.positions):
            if GameBoard().check_win(1, pos):
                self.table[:, i] = 1
    
    def patois(self, plaa):
        # Возвращает индексы доступных игроку действий
        return np.array([np.where(np.all(self.all_actions==plaa[i], axis=1))[0][0] for i in range(len(plaa))])
    
    def atoi(self, action):
        # Возвращает индекс действия
        return np.where(np.all(self.all_actions==action, axis=1))[0][0]
    
    def stoi(self, state):
        # Возвращает индекс состояния
        return np.where(np.all(self.positions == state, axis=(1,2)))[0][0]
    
    def get_states(self):
        # Возвращает все позиции
        return self.positions
    
    def get_table(self):
        # Возвращает таблицу
        return self.table
    
    def get_all_actions(self):
        # Возвращает все действия
        return self.all_actions
    
    def get_Q(self, state, action):
        # Возвращает Q(s,a)
        si = self.stoi(state)
        ai = self.atoi(action)
        return self.table[ai, si]
    
    def set_Q(self, state, action, value):
        # Устанавливает Q(s,a)
        si = self.stoi(state)
        ai = self.atoi(action)        
        self.table[ai, si] = value
        
    def get_Qs(self, state):
        # Возвращает значения Q для всех действий при состоянии
        si = self.stoi(state)        
        return self.table[:, si]
    
    def save(self):
        with open("qtable.pickle", "wb") as f:
            pickle.dump(self.table, f)
            
    def load(self):
        with open("qtable.pickle", "rb") as f:
            self.table = pickle.load(f)

In [0]:
# Табличная модель reinforcement learning
class RLTDmodel:
    def __init__(self, player, gameboard, qtable):
        self.alpha = 0.1
        self.gamma = 0.9
        self.epsilon = 0.9
        self.max_acts = 5
        self.act_n = 1
        self.r = 0.
        self.player = player
        self.gameboard = gameboard
        self.Qtable = qtable
        self.prev_a = None
        self.prev_s = None    
    
    def select_action(self):
        # Выбор действия через epsilon-greedy стратегию
        actions = self.player.get_actions(self.gameboard.get_state())
        if np.random.rand() < self.epsilon/self.act_n:            
            return actions[np.random.randint(len(actions))]
        else:            
            state = self.gameboard.get_state()            
            Qs = np.array([self.Qtable.get_Q(state, action) for action in actions])
            return actions[np.argmax(Qs)]
        
    def act(self):
        a = self.select_action()
        self.prev_a = a
        self.prev_s = self.gameboard.get_state().copy()        
        self.player.make_move(tuple(a))
        self.r = gb.get_reward(self.player)        
        self.act_n += 1
        
    def update_Qs(self):
        # Обновление Q по уравнению Беллмана
        actions = self.player.get_actions(self.gameboard.get_state())
        state = self.gameboard.get_state()        
        Q_old = self.Qtable.get_Q(self.prev_s, self.prev_a)
        if not self.gameboard.done:
            Q_max = np.max(np.array([self.Qtable.get_Q(state, action) for action in actions]))            
        else:
            Q_max = 0.
            self.act_n = 1        
        Q_new = Q_old + self.alpha*(self.r + self.gamma*Q_max - Q_old)        
        self.Qtable.set_Q(self.prev_s, self.prev_a, Q_new)

In [0]:
# Создание и загрузка обученной q таблицы
qt = QTable()
qt.load()

In [0]:
gb = GameBoard()
pla = Player(gb)
bot = Player(gb, "O")
# Создание модели
rtm = RLTDmodel(pla, gb, qt)
gb.print()

  0 1 2 
  - - - 
0| | | |
  - - - 
1| | | |
  - - - 
2| | | |
  - - - 


In [0]:
# Обучение модели. Постепенно начинает играть со 100% эффективностью, найдя все 
# уязвимости бота (если сделать epsilon маленьким)
Xr = 0
Or = 0
rtm.epsilon = 0.3
for i in range(1000):
    if i>1 and i % 1000 == 0:
        qt.save()
        print("Results {} :".format(i), Xr, Or, str(100*Xr/(Xr+Or))+"%")
    gb.reset()
    while not gb.done:
        rtm.act()
        #gb.print()
        if gb.get_reward(pla) == 1.:
            Xr += gb.get_reward(pla)
        if gb.done:
            rtm.update_Qs()
        if(bot.act()):
            #gb.print()
            if gb.get_reward(bot) == 1.:
                Or += gb.get_reward(bot)
            rtm.update_Qs()
print("Results", Xr, Or, str(100*Xr/(Xr+Or))+"%")

Results 655.0 126.0 83.86683738796415%


In [0]:
# Обертка для игры для обучения нейросети. По ходу игрока возвращает следующее
# состояние (после хода оппонента), награду, конец игры.
class Game:
    def __init__(self):
        self.gb = GameBoard()
        self.pla = Player(self.gb)
        self.bot = Player(self.gb, "O")
        self.actions_all = self.gb.get_actions()
        self.actions_n = len(self.actions_all)
    
    def atoi(self, action):
        return np.where(np.all(self.actions_all==action, axis=1))[0][0]
    
    def actions_valid_mask(self, state):
        mask = np.zeros((1, 9))
        v = self.pla.get_actions(state)
        for v in v:
            mask += np.all(self.actions_all==v, axis=1).astype(int)
        return mask.astype(bool)         
        
    def step(self, action):
        action = self.actions_all[np.array(action).astype(bool)][0]        
        self.pla.make_move(tuple(action))
        reward = self.gb.get_reward(self.pla)
        terminal = self.gb.done
        if not terminal:
            self.bot.act()
            reward = self.gb.get_reward(self.pla)
            terminal = self.gb.done
        state = self.gb.get_state().copy()
        return state, reward, terminal

In [0]:
param_input_shape = 9
param_n_actions = 9
param_loss_fn = F.mse_loss
param_gamma = 0.9
param_epsilon = 0.01
param_iterations = 10001
param_replay_memory_size = 4096
param_batch_size = 1024

In [0]:
# FullyConnected с маской
class FullyConnected(nn.Module):
    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh(), flatten=False, 
                 last_fn=None, first_fn=None, device='cpu'):
        super(FullyConnected, self).__init__()
        layers = []
        self.flatten = flatten
        if first_fn is not None:
            layers.append(first_fn)
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i+1]))
            if dropout:
                layers.append(nn.Dropout(dropout))
            layers.append(activation_fn) # нам не нужен дропаут и фнкция активации в последнем слое
        else: 
            layers.append(nn.Linear(sizes[-2], sizes[-1]))
        if last_fn is not None:
            layers.append(last_fn)
        self.model = nn.Sequential(*layers)
        self.to(device)
        
    def forward(self, x, mask=None):        
        if self.flatten:
            x = x.view(x.shape[0], -1)           
        x = self.model(x)
        # Маскирование невозможных действий нулями
        if mask is not None:
            x = x*mask.float()
        return self.model(x)

In [0]:
# Создание "policy net" и "target net"
dqn = FullyConnected([param_input_shape, 128, 256, 128, param_n_actions],
                     flatten=True, dropout=0.3, activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Softmax(dim=-1), device=device)

dqn_opt = optim.Adam(dqn.parameters(), lr=1e-5)

dqn_target = FullyConnected([param_input_shape, 128, 256, 128, param_n_actions],
                     flatten=True, dropout=0.3, activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Softmax(dim=-1), device=device)
dqn_target.load_state_dict(dqn.state_dict())
dqn_target.eval()

dqn_train_log = {'DQN': []}
dqn_test_log = {'DQN': []}

In [0]:
load_filename = "model_128_256_128_softmax_start_6_30000.pth"
filename = "model_128_256_128_softmax_start_7_{}.pth"

In [0]:
# Горячий старт
dqn = FullyConnected([param_input_shape, 128, 256, 128, param_n_actions],
                     flatten=True, dropout=0.3, activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Softmax(dim=-1), device=device)

dqn.load_state_dict(torch.load(os.path.join(PATH, load_filename)))
dqn_opt = optim.Adam(dqn.parameters(), lr=1e-5)

dqn_target = FullyConnected([param_input_shape, 128, 256, 128, param_n_actions],
                     flatten=True, dropout=0.3, activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Softmax(dim=-1), device=device)
dqn_target.load_state_dict(dqn.state_dict())
dqn_target.eval()

dqn_train_log = {'DQN': []}
dqn_test_log = {'DQN': []}

In [0]:
dqn_opt = optim.Adam(dqn.parameters(), lr=1e-6)

In [0]:
replay_memory = []

In [21]:
# Обучение
for epoch in range(param_iterations):
    g = Game()
    # Прохождение игры против бота для заполнения replay memory
    while not g.gb.done:
        dqn.eval()
        state_0 = torch.Tensor(g.gb.get_state().copy()).unsqueeze(0)
        action = torch.zeros((1, param_n_actions))
        if np.random.rand() < param_epsilon:
            output = torch.rand((1,9))    
        else:
            output = dqn(state_0.to(device))
        mask = torch.Tensor(g.actions_valid_mask(state_0.squeeze().numpy())).type(torch.uint8).to(device)
        output[~mask] = 0
        action[:, torch.argmax(output)] = 1
        state_1, reward, terminal = g.step(action[0].numpy())
        state_1 = torch.Tensor(state_1).unsqueeze(0).to(device)
        reward = torch.Tensor([reward]).unsqueeze(0).to(device)
        terminal = torch.Tensor([terminal]).type(torch.uint8).unsqueeze(0).to(device)

        replay_memory.append((state_0, action, reward, state_1, terminal))
        if len(replay_memory) > param_replay_memory_size:
            replay_memory.pop(0)
    
    # Обучение сэмплированным батчем из replay memory
    if len(replay_memory) > param_batch_size:
        dqn.train()
        dqn.zero_grad()
        batch = random.sample(replay_memory, param_batch_size)
    
        # Разделение батча на признаки
        batch_state_0 = torch.cat(tuple(d[0] for d in batch)).to(device)
        batch_action = torch.cat(tuple(d[1] for d in batch)).to(device)
        batch_reward = torch.cat(tuple(d[2] for d in batch)).to(device)
        batch_state_1 = torch.cat(tuple(d[3] for d in batch)).to(device)
        batch_terminal = torch.cat(tuple(d[4] for d in batch)).to(device)
    
        # Предсказание target net Q значений следующего состояниz
        batch_mask_1 = (batch_state_1 == 0)
        batch_mask_1 = batch_mask_1.view((param_batch_size, -1)).to(device)
        output_1 = dqn_target(batch_state_1, batch_mask_1)        
    
        # Создание "groun truth" меток (r(s,a) + gamma*max(Q(s',a)))
        batch_y = torch.cat(tuple(batch_reward[i] if batch_terminal[i] else batch_reward[i] + \
                                  param_gamma*torch.max(output_1[i]) for i in range(len(batch))))        

        # Получение Q текущего состояния для функции потерь
        batch_mask_0 = (batch_state_0 == 0)
        batch_mask_0 = batch_mask_0.view((param_batch_size, -1)).to(device)
        output_0 = dqn(batch_state_0, batch_mask_0) 

        # Значение Q, по которому было совершено действие
        batch_q = torch.sum(output_0*batch_action, dim=1)        
        
        batch_y = batch_y.detach()
        loss = param_loss_fn(batch_q, batch_y)        

        loss.backward()
        
        for param in dqn.parameters():
            param.grad.data.clamp_(-1, 1)
        dqn_opt.step()

        if epoch % 1000 == 0:
            # Обновление весов target net
            dqn_target.load_state_dict(dqn.state_dict())
            # Вывод результатов
            res = torch.cat(tuple([d[2] for d in replay_memory[-1000:] if d[4]]))
            res = torch.sum(res)/len(res)
            wins = len([d[2] for d in replay_memory[-1000:] if d[2]==1])
            draws = len([d[2] for d in replay_memory[-1000:] if d[2]==0.1])
            losses = len([d[2] for d in replay_memory[-1000:] if (d[2]==0 and d[4])])
            print("Epoch {} loss: {}, reward ratio: {:.3f}, wins: {}, draws: {}, losses {}".format(epoch, loss.item(), res, wins, draws, losses))
            if epoch % 10000 == 0:
                torch.save(dqn.state_dict(), os.path.join(PATH, filename.format(epoch)))

Epoch 1000 loss: 0.0017406800761818886, reward ratio: 0.961, wins: 241, draws: 2, losses 8
Epoch 2000 loss: 0.002646678127348423, reward ratio: 0.974, wins: 243, draws: 4, losses 3
Epoch 3000 loss: 0.0032253579702228308, reward ratio: 0.969, wins: 243, draws: 3, losses 5
Epoch 4000 loss: 0.0014677841681987047, reward ratio: 0.965, wins: 241, draws: 2, losses 7
Epoch 5000 loss: 0.003262968733906746, reward ratio: 0.992, wins: 248, draws: 1, losses 1
Epoch 6000 loss: 0.0026034191250801086, reward ratio: 0.992, wins: 249, draws: 0, losses 2
Epoch 7000 loss: 0.004114625044167042, reward ratio: 0.954, wins: 239, draws: 5, losses 7
Epoch 8000 loss: 0.0029856807086616755, reward ratio: 0.961, wins: 242, draws: 1, losses 9
Epoch 9000 loss: 0.0031430646777153015, reward ratio: 0.973, wins: 244, draws: 1, losses 6
Epoch 10000 loss: 0.002826766576617956, reward ratio: 0.985, wins: 246, draws: 3, losses 1


Сеть тоже научилась играть против бота, хотя здесь результаты выглядят менее стабильными. Потребовалось около 300000 батчей (по инерции назвал переменную эпохи) по 1024 наблюдения из replay memory. Эпсилон снижался от 0.3 до 0.01.