In [1]:
import gym
import random

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
import json
import math
import random
import numpy as np
import scipy as sp
import scipy.stats as st
import scipy.integrate as integrate
from scipy.stats import multivariate_normal
from sklearn import linear_model
from sklearn.utils.testing import ignore_warnings
from sklearn.exceptions import ConvergenceWarning
import statsmodels.api as sm
from matplotlib.colors import LogNorm
import pickle

from joblib import Parallel, delayed
import multiprocessing
from collections import namedtuple
from itertools import count

import cProfile
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T
from torch.autograd import Variable

sns.set_style("whitegrid")
sns.set_palette("colorblind")
palette = sns.color_palette()
figsize = (15,8)
legend_fontsize = 16

from matplotlib import rc
rc('font',**{'family':'sans-serif'})
# rc('text', usetex=True)
# rc('text.latex',preamble=r'\usepackage[utf8]{inputenc}')
# rc('text.latex',preamble=r'\usepackage[russian]{babel}')
rc('figure', **{'dpi': 300})



In [2]:
from collections import defaultdict

In [3]:
N_ROWS, N_COLS, N_WIN = 3, 3, 3

In [4]:
class TicTacToe(gym.Env):
    def __init__(self, n_rows=N_ROWS, n_cols=N_COLS, n_win=N_WIN, clone=None):
        if clone is not None:
            self.n_rows, self.n_cols, self.n_win = clone.n_rows, clone.n_cols, clone.n_win
            self.board = copy.deepcopy(clone.board)
            self.curTurn = clone.curTurn
            self.emptySpaces = None
            self.boardHash = None
        else:
            self.n_rows = n_rows
            self.n_cols = n_cols
            self.n_win = n_win

            self.reset()

    def getEmptySpaces(self):
        if self.emptySpaces is None:
            res = np.where(self.board == 0)
            self.emptySpaces = np.array([ (i, j) for i,j in zip(res[0], res[1]) ])
        return self.emptySpaces

    def makeMove(self, player, i, j):
        self.board[i, j] = player
        self.emptySpaces = None
        self.boardHash = None

    def getHash(self):
        if self.boardHash is None:
            self.boardHash = ''.join(['%s' % (x+1) for x in self.board.reshape(self.n_rows * self.n_cols)])
        return self.boardHash

    def isTerminal(self):
        # проверим, не закончилась ли игра
        cur_marks, cur_p = np.where(self.board == self.curTurn), self.curTurn
        for i,j in zip(cur_marks[0], cur_marks[1]):
            win = False
            if i <= self.n_rows - self.n_win:
                if np.all(self.board[i:i+self.n_win, j] == cur_p):
                    win = True
            if not win:
                if j <= self.n_cols - self.n_win:
                    if np.all(self.board[i,j:j+self.n_win] == cur_p):
                        win = True
            if not win:
                if i <= self.n_rows - self.n_win and j <= self.n_cols - self.n_win:
                    if np.all(np.array([ self.board[i+k,j+k] == cur_p for k in range(self.n_win) ])):
                        win = True
            if not win:
                if i <= self.n_rows - self.n_win and j >= self.n_win-1:
                    if np.all(np.array([ self.board[i+k,j-k] == cur_p for k in range(self.n_win) ])):
                        win = True
            if win:
                self.gameOver = True
                return self.curTurn

        if len(self.getEmptySpaces()) == 0:
            self.gameOver = True
            return 0

        self.gameOver = False
        return None

    def printBoard(self):
        for i in range(0, self.n_rows):
            print('----'*(self.n_cols)+'-')
            out = '| '
            for j in range(0, self.n_cols):
                if self.board[i, j] == 1:
                    token = 'x'
                if self.board[i, j] == -1:
                    token = 'o'
                if self.board[i, j] == 0:
                    token = ' '
                out += token + ' | '
            print(out)
        print('----'*(self.n_cols)+'-')

    def getState(self):
        return (self.getHash(), self.getEmptySpaces(), self.curTurn)

    def action_from_int(self, action_int):
        return ( int(action_int / self.n_cols), int(action_int % self.n_cols))

    def int_from_action(self, action):
        return action[0] * self.n_cols + action[1]
    
    def step(self, action):
        if self.board[action[0], action[1]] != 0:
            return self.getState(), -10, True, {}
        self.makeMove(self.curTurn, action[0], action[1])
        reward = self.isTerminal()
        self.curTurn = -self.curTurn
        return self.getState(), 0 if reward is None else reward, reward is not None, {}

    def reset(self):
        self.board = np.zeros((self.n_rows, self.n_cols), dtype=int)
        self.boardHash = None
        self.gameOver = False
        self.emptySpaces = None
        self.curTurn = 1

In [5]:
def plot_board(env, pi, showtext=True, verbose=True, fontq=20, fontx=60):
    '''Рисуем доску с оценками из стратегии pi'''
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    X, Y = np.meshgrid(np.arange(0, env.n_rows), np.arange(0, env.n_rows))
    Z = np.zeros((env.n_rows, env.n_cols)) + .01
    s, actions = env.getHash(), env.getEmptySpaces()
    if pi is not None and s in pi.Q:
        for i, a in enumerate(actions):
            Z[a[0], a[1]] = pi.Q[s][i]
    ax.set_xticks([])
    ax.set_yticks([])
    surf = ax.imshow(Z, cmap=plt.get_cmap('Accent', 10), vmin=-1, vmax=1)
    if showtext:
        for i,a in enumerate(actions):
            if pi is not None and s in pi.Q:
                ax.text( a[1] , a[0] , "%.3f" % pi.Q[s][i], fontsize=fontq, horizontalalignment='center', verticalalignment='center', color="w" )
    for i in range(env.n_rows):
        for j in range(env.n_cols):
            if env.board[i, j] == -1:
                ax.text(j, i, "O", fontsize=fontx, horizontalalignment='center', verticalalignment='center', color="w" )
            if env.board[i, j] == 1:
                ax.text(j, i, "X", fontsize=fontx, horizontalalignment='center', verticalalignment='center', color="w" )
#     cbar = plt.colorbar(surf, ticks=[0, 1])
    ax.grid(False)
    plt.show()

def get_and_print_move(env, pi, s, actions, random=False, verbose=True, fontq=20, fontx=60):
    '''Делаем ход, рисуем доску'''
    plot_board(env, pi, fontq=fontq, fontx=fontx)
    if verbose and (pi is not None):
        if s in pi.Q:
            for i,a in enumerate(actions):
                print(i, a, pi.Q[s][i])
        else:
            print("Стратегия не знает, что делать...")
    if random:
        return np.random.randint(len(actions))
    else:
        return pi.getActionGreedy(s, len(actions))

In [6]:
def plot_test_game(env, pi1, pi2, random_crosses=False, random_naughts=True, verbose=True, fontq=20, fontx=60):
    '''Играем тестовую партию между стратегиями или со случайными ходами, рисуем ход игры'''
    done = False
    env.reset()
    while not done:
        s, actions = env.getHash(), env.getEmptySpaces()
        if env.curTurn == 1:
            a = get_and_print_move(env, pi1, s, actions, random=random_crosses, verbose=verbose, fontq=fontq, fontx=fontx)
        else:
            a = get_and_print_move(env, pi2, s, actions, random=random_naughts, verbose=verbose, fontq=fontq, fontx=fontx)
        observation, reward, done, info = env.step(actions[a])
        if reward == 1:
            print("Крестики выиграли!")
            plot_board(env, None, showtext=False, fontq=fontq, fontx=fontx)
        if reward == -1:
            print("Нолики выиграли!")
            plot_board(env, None, showtext=False, fontq=fontq, fontx=fontx)

### Table Q-Learning

In [83]:
class QLearning:
    def __init__(self, env):
        self.env = env
        self.env.reset()
        self.actions = list(range(len(env.getState()[1]))) 
        self.q_table = defaultdict(lambda: np.zeros(9)) 
        
            
    def choose_action(self, state, epsilon):
        impossible_positions = np.where(np.array(list(state)) != "1")[0]
        possible_positions = np.where(np.array(list(state)) == "1")[0]
        for k in impossible_positions:
                self.q_table[state][k] = -np.inf
        action = np.argmax(self.q_table[state])       
                
        action = (np.random.choice(possible_positions) if epsilon > random.uniform(0, 1) else action)
        return action
        
           
    def training(self, n_iterations, epsilon=0.5, alpha=0.01, gamma=0.99):
        for i in range(n_iterations):
            self.env.reset()
            state_cross = self.env.getHash()
            action_cross = self.choose_action(state_cross, epsilon)
            (state_nought, _, turn), reward, done, info = self.env.step(self.env.action_from_int(action_cross))
            action_nought = self.choose_action(state_nought, epsilon)
            (state_cross, _, turn), reward, done, info = self.env.step(self.env.action_from_int(action_nought))
            done = False

            while not done:
                
                if self.env.curTurn == 1:            
                    action_cross = self.choose_action(state_cross, epsilon)
                    (state_nought_new, _, turn), reward, done, info = self.env.step(
                        self.env.action_from_int(action_cross))
                    if reward == 1:
                        self.q_table[state_cross][action_cross] == 1
                    
                    update_q_nought = alpha * (-1 * reward  + gamma * np.max(self.q_table[state_nought_new]) - self.q_table[state_nought][action_nought])
                    self.q_table[state_nought][action_nought] += update_q_nought
                    
                    state_nought = state_nought_new
                else:
                    action_nought = self.choose_action(state_nought, epsilon)
                    (state_cross_new, _, turn), reward, done, info = self.env.step(
                        self.env.action_from_int(action_nought))
        
                    if reward == -1:
                        self.q_table[state_nought][action_nought] == 1
                    
                    if reward == 10:
                        self.q_table[state_nought][action_nought] == -10
                    
                    update_q_cross = alpha * (1 * reward  + gamma * np.max(self.q_table[state_cross_new]) - self.q_table[state_cross][action_cross])
                    self.q_table[state_cross][action_cross] += update_q_cross
                    
                    state_cross = state_cross_new

                    
    def play(self, player):
        self.env.reset()
        done = False
        while not done:
            state = self.env.getHash()
            if self.env.curTurn == player:
                action = np.argmax(self.q_table[state])
            else:
                action = self.choose_action(state, epsilon=1)
            (state, _, _), reward, done, info = self.env.step(self.env.action_from_int(action))
        return reward * player
    
    
    def evaluate(self, player, n_iter):
        player_reward = []
        for i in range(n_iter):
            reward = self.play(player)
            if abs(reward) != 10:
                player_reward.append(reward)
        return(player_reward)

In [84]:
ql = QLearning(TicTacToe())

In [85]:
crosses = []
noughts = []
for i in list(range(1, 7)):
    ql.training(50000)
    print(f"After {50000 * i} iterations")
    n_iter = 1000
    crosses.append(np.mean(ql.evaluate(-1, n_iter)))
    noughts.append(np.mean(ql.evaluate(1, n_iter)))
    print(f"   Win rate for noughts: {np.mean(ql.evaluate(-1, n_iter))}")
    print(f"   Win rate for crosses: {np.mean(ql.evaluate(1, n_iter))}")

After 50000 iterations
   Win rate for noughts: 0.607
   Win rate for crosses: 0.939
After 100000 iterations
   Win rate for noughts: 0.69
   Win rate for crosses: 0.927
After 150000 iterations
   Win rate for noughts: 0.71
   Win rate for crosses: 0.935
After 200000 iterations
   Win rate for noughts: 0.723
   Win rate for crosses: 0.938
After 250000 iterations
   Win rate for noughts: 0.686
   Win rate for crosses: 0.921
After 300000 iterations
   Win rate for noughts: 0.702
   Win rate for crosses: 0.945
