In [1]:
import numpy as np
import seaborn as sns
from time import time

import tensorflow as tf
from tensorflow import keras
from keras import layers, models
from keras.regularizers import L1L2
import keras.backend as K

import os
import gc
from pathlib import Path

tf.config.experimental.set_visible_devices([], 'GPU')

Using TensorFlow backend.


In [2]:
init_state = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0]
])

In [3]:
class Game:
    def __init__(self, state, FIRST=1):
        self.state = state
        self.empty = self.make_empty(state)
        self.first_player = FIRST
        
    def make_empty(self, state):
        emp = []
        for i in range(3):
            for j in range(3):
                if state[i][j] == 0:
                    emp.append(3*i + j)
        
        return emp
    
    def is_lose(self):
        a = self.next_opp()
        
        for i in range(3):
            if self.state[i][0] == self.state[i][1] == self.state[i][2] == a:
                return True
            elif self.state[0][i] == self.state[1][i] == self.state[2][i] == a:
                return True
        if self.state[0][0] == self.state[1][1] == self.state[2][2] == a:
            return True
        if self.state[0][2] == self.state[1][1] == self.state[2][0] == a:
            return True
        return 0
    
    def is_draw(self):
        a = self.next_opp()
        if self.is_lose():
            return 0
        if np.all(self.state):
            return 1
        else:
            return 0
        
    def is_done(self):
        if self.is_lose() or self.is_draw():
            return 1
        else:
            return 0
        
        
    def update(self, target):
        state = self.state.copy()
        x, y = target//3, target%3
        a = self.next_opp()
        state[x][y] = a
        return Game(state)
    
    
    def next_opp(self):
        a = b = 0
        for i in range(len(self.state)):
            for j in range(len(self.state)):
                if self.state[i][j] == self.first_player:
                    a += 1
                elif self.state[i][j] != 0:
                    b += 1
                    
        if a == b:
            return self.first_player
        else:
            return 2 + min(0, 1-self.first_player)

In [4]:
class Random:
    def action(self, game):
        return np.random.choice(game.empty)

In [5]:
def playout(game):
    if game.is_lose():
        return -1

    if game.is_draw():
        return 0

    return -playout(game.update(np.random.choice(game.empty)))


def action(game):
    values = [0] * 9
    
    for i in range(9):
        if i in game.empty:
            for _ in range(100):
                g = game.update(a)
                values[i] += playout(g)

    return np.argmax(values)

def value(game):
    n_steps = 50
    values = [0] * 9
    
    for i in range(9):
        if i in game.empty:
            for _ in range(n_steps):
                g = game.update(i)
                values[i] += playout(g)
    for j in range(9):
        values[j] /= n_steps

    return values

In [9]:
DN_FILTERS = 128  # 컨볼루션 레이어 커널 수(오리지널 256）
DN_RESIDUAL_NUM = 16  # 레지듀얼 블록 수(오리지널 19)
DN_INPUT_SHAPE = (1, 3, 3)  # 입력 셰이프
DN_OUTPUT_SIZE = 9  # 행동 수(배치 수(3*3))
    
def residual_block():
    def f(x):
        sc = x
        x = layers.Conv2D(DN_FILTERS, 3, padding='same', use_bias=False,
              kernel_initializer='he_normal', kernel_regularizer=L1L2(l2=0.0005))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
        x = layers.Conv2D(DN_FILTERS, 3, padding='same', use_bias=False,
              kernel_initializer='he_normal', kernel_regularizer=L1L2(l2=0.0005))(x)
        x = layers.BatchNormalization()(x)
        x = layers.Add()([x, sc])
        x = layers.Activation('relu')(x)
        return x

    return f
    
def dual_network():
    # 입력 레이어
    input = layers.Input(shape=DN_INPUT_SHAPE)

    # 컨볼루션 레이어
    x = layers.Conv2D(DN_FILTERS, 3, padding='same', use_bias=False,
              kernel_initializer='he_normal', kernel_regularizer=L1L2(l2=0.0005))(input)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    # 레지듀얼 블록 x 16
    for i in range(DN_RESIDUAL_NUM):
        x = residual_block()(x)

    # 풀링 레이어
    x = layers.GlobalAveragePooling2D()(x)

    # policy 출력
    p = layers.Dense(DN_OUTPUT_SIZE, kernel_regularizer=L1L2(l2=0.0005),
                     activation='tanh', name='pi')(x)
    # 모델 생성
    model = models.Model(inputs=input, outputs=p)

    model.compile(optimizer = 'adam',
                 loss = 'logcosh')

    return model

GAMMA = 0.95
class CNN:
    def __init__(self):
        K.clear_session()         
#         self.model = models.load_model('./tanh2.h5')
        self.model = dual_network()
        
    def action(self, game):
        state = self.make_state(game)
        res = self.model.predict(state)[0]
        a = np.argmax(res)
        while a not in game.empty:
            res[a] = -float('inf')
            a = np.argmax(res)
        return a

    def value(self, game):
        values = [0] * 9
        n_steps = 20

        for i in range(9):
            if i in game.empty:
                for _ in range(n_steps):
                    g = game.update(i)
                    values[i] += playout(g)
            values[i] /= n_steps
            
        return values

    def train(self, warmup = False):
        if warmup:
            self.warmup()
            
#         for i in [1, 2]:
        game = Game(init_state, 1)
        X = []
        y = []
        while 1:
            state = self.make_state(game)

            a = self.action(game)
            temp_y = [0] * 9
            temp_target = self.model.predict(state)[0]
            game = game.update(a)
            state_next = self.make_state(game)

            if game.is_done():
                r = game.is_lose() * -1
            else:
                r = 0

            X.append(state)

            target = np.max(self.model.predict(state_next)[0])
            temp_y[a] += (r + temp_target[a] + -GAMMA*target)
            print(r)
            print(temp_y)

            y.append(temp_y)
            
            if game.is_done():
                break


        print('############')
        X = np.reshape(X, (len(X), 1, 3, 3))
        y = np.reshape(y, (len(y), 9))   
        self.model.fit(X, y, epochs=1, verbose=0)
                
                
    def warmup(self, n=50):
        for _ in range(n):
            game = Game(init_state, 1)
            X = []
            y = []
            while 1:
                state = self.make_state(game)

                a = np.random.choice(game.empty)
                temp_y = [0] * 9
                temp_target = value(game)
                game = game.update(a)
                state_next = self.make_state(game)

                if game.is_done():
                    r = game.is_lose() * -1
                else:
                    r = 0

                X.append(state)

                target = np.mean(value(game))
                temp_y[a] += (r + temp_target[a] + -GAMMA*target)
                
                y.append(temp_y)

                if game.is_done():
                    break

        X = np.reshape(X, (len(X), 1, 3, 3))
        y = np.reshape(y, (len(y), 9))   
        self.model.fit(X, y, epochs=1, verbose=0)
                

    def make_state(self, game):
        status = game.next_opp()
        state = np.reshape(game.state, (1, 1, 3, 3)).astype('float')
        state = np.where(state==status, 1., np.where(state==0, 0, -1.))
        return state
    
#     def opp(self, status):
#         return 2 + min(0, 1-status)


In [10]:
dd = CNN()

In [11]:
dd.warmup(500)

In [None]:
dd.train()

In [18]:
for _ in range(200):
    print(_)
    dd.train()
    
# dd.model.save('./tanh4.h5')

0
0
[0, 0, 0, 0, -0.38381389211863276, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.3028894454240799, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.14911594986915588, 0]
0
[0, 0, 0, -0.5219645351171494, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.6717596426606178]
0
[0, 0, 0, 0, 0, -0.786053417623043, 0, 0, 0]
0
[0, -0.9107467025518416, 0, 0, 0, 0, 0, 0, 0]
-1
[0, 0, -2.634351706504822, 0, 0, 0, 0, 0, 0]
############
1
0
[0, 0, 0, 0, -0.33709340319037434, 0, 0, 0, 0]
0
[0, 0, 0, -0.2659411370754242, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.2910441905260086, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.2272878319025039, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.672182522714138]
0
[0, -0.702257177233696, 0, 0, 0, 0, 0, 0, 0]
0
[-1.2340794116258622, 0, 0, 0, 0, 0, 0, 0, 0]
-1
[0, 0, 0, 0, 0, -2.0645235151052477, 0, 0, 0]
############
2
0
[0, 0, 0, 0, -0.35761276111006735, 0, 0, 0, 0]
0
[0, 0, 0, -0.1987566769123077, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.3041458457708358, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.19780147373676293, 0]
0
[0,

0
[0, 0, -1.2057982712984083, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -1.258391445875168, 0, 0, 0]
0
[0, -1.6529288709163665, 0, 0, 0, 0, 0, 0, 0]
############
22
0
[0, 0, 0, 0, 0, 0, 0, 0.20927056148648263, 0]
0
[0, 0, 0, -0.10123191401362419, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.22864724397659297, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.13208682090044022, 0, 0]
0
[-0.651964557915926, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.7429749459028243]
0
[0, 0, -1.1741699516773223, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -1.1823719084262847, 0, 0, 0]
0
[0, -1.64155233502388, 0, 0, 0, 0, 0, 0, 0]
############
23
0
[0, 0, 0, 0, 0, 0, 0, 0.21024608425796032, 0]
0
[0, 0, 0, -0.07835073918104171, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.2603208854794502, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.11425849199295043, 0, 0]
0
[-0.6384246192872524, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.7702553421258926]
0
[0, 0, -1.1548882603645323, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -1.1301133781671524, 0, 0, 0]
0
[0, -

0
[0, 0, 0, 0, 0, 0, 0, 0, -1.152254083752632]
0
[0, -1.4290190637111664, 0, 0, 0, 0, 0, 0, 0]
############
41
0
[0, 0, 0, 0, 0, 0, 0, 0.2198599047958851, 0]
0
[0, 0, 0, -0.033261735737323744, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.06010322123765946, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.10696022734045982, 0, 0]
0
[-0.632149401307106, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.7947017267346381, 0, 0, 0]
0
[0, 0, -0.7372744433581828, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -1.125201776623726]
0
[0, -1.418027889728546, 0, 0, 0, 0, 0, 0, 0]
############
42
0
[0, 0, 0, 0, 0, 0, 0, 0.22776921391487123, 0]
0
[0, 0, 0, -0.03615039587020874, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.04501946792006492, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.11327432468533516, 0, 0]
0
[-0.6293142162263393, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.7536573618650436, 0, 0, 0]
0
[0, 0, -0.7059630647301673, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -1.1008298426866532]
0
[0, -1.3919491440057754, 0, 0, 0, 0, 0, 0, 0]
######

0
[0, -0.24237281978130337, 0, 0, 0, 0, 0, 0, 0]
-1
[0, 0, 0, -1.3030302979052066, 0, 0, 0, 0, 0]
############
62
0
[0, 0, 0, 0, 0, 0, 0, 0.12766319513320923, 0]
0
[0, 0, 0.00037871152162552435, 0, 0, 0, 0, 0, 0]
0
[0, -0.25863667130470275, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.1905746743083]
0
[0, 0, 0, 0, -0.37825683839619156, 0, 0, 0, 0]
-1
[0, 0, 0, -1.63840491771698, 0, 0, 0, 0, 0]
############
63
0
[0, 0, 0, 0, 0, 0, 0, 0.1411359027028084, 0]
0
[0, 0, 0.0008995853364467676, 0, 0, 0, 0, 0, 0]
0
[0, -0.2259813196957111, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.18573930561542512]
0
[0, 0, 0, 0, -0.41465264558792114, 0, 0, 0, 0]
-1
[-1.5915640249848366, 0, 0, 0, 0, 0, 0, 0, 0]
############
64
0
[0, 0, 0, 0, 0, 0, 0, 0.16294617876410486, 0]
0
[0, 0, -0.002228158712387074, 0, 0, 0, 0, 0, 0]
0
[0, -0.19098378717899323, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.18408440351486205]
0
[0, 0, 0, 0, -0.44673114418983456, 0, 0, 0, 0]
-1
[-1.5741581529378892, 0, 

0
[0, -0.23991141766309737, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.173700144700706]
0
[-0.8474057376384735, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.521926436573267, 0, 0, 0]
-1
[0, 0, 0, 0, 0, 0, -1.9748806297779082, 0, 0]
############
83
0
[0, 0, 0, 0, 0, 0, 0, 0.10419136062264445, 0]
0
[0, 0, 0.14796404652297496, 0, 0, 0, 0, 0, 0]
0
[0, -0.24267452210187912, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.19336138125509023]
0
[-0.811134672164917, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.5278810404241085, 0, 0, 0]
-1
[0, 0, 0, 0, 0, 0, -1.8759886801242829, 0, 0]
############
84
0
[0, 0, 0, 0, 0, 0, 0, 0.10113812386989596, 0]
0
[0, 0, 0.15182683244347572, 0, 0, 0, 0, 0, 0]
0
[0, -0.2639596283435821, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.19343325961381197]
0
[-0.7956646114587783, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.48932993113994594, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.6705560356378555, 0, 0]
0
[0, 0, 0, 0, 0, -0.6533151000738144, 0, 0, 0]
-1
[

0
[0, 0, 0, 0, 0, 0, 0, 0, -0.44465309530496594]
0
[0, 0, 0, 0, 0, 0, -0.9466540455818176, 0, 0]
0
[0, 0, 0, 0, -0.9786636799573898, 0, 0, 0, 0]
############
103
0
[0, 0, 0, 0, 0, 0, 0, -0.01667216420173645, 0]
0
[0, 0, 0.136700489372015, 0, 0, 0, 0, 0, 0]
0
[-0.19180080480873585, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, -0.035293339192867285, 0, 0, 0, 0, 0]
0
[0, -0.13450322449207303, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.24844992496073243, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.48010275363922117]
0
[0, 0, 0, 0, -1.0979998499155044, 0, 0, 0, 0]
-1
[0, 0, 0, 0, 0, 0, -1.9030258297920226, 0, 0]
############
104
0
[0, 0, 0, 0, 0, 0, 0, -0.008460700511932373, 0]
0
[0, 0, 0.13487302511930466, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.427087889611721, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.25624054223299025]
0
[0, -0.27890971004962917, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, -0.21465554125607011, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.5447147943079471, 0, 0, 0, 0]
-1
[-1.4156644284725188, 0, 0, 0, 0, 0, 0, 

0
[0, 0, 0, 0, -1.4317363321781158, 0, 0, 0, 0]
############
123
0
[0, 0, 0.059225909411907196, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.267326632142067, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.37365991845726965, 0]
0
[0, 0, 0, 0, 0, -0.12528599351644515, 0, 0, 0]
0
[0, 0, 0, -0.39620348215103146, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.21451046466827395]
0
[0, -0.3205125033855438, 0, 0, 0, 0, 0, 0, 0]
0
[-1.511491060256958, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -1.5157508373260498, 0, 0, 0, 0]
############
124
0
[0, 0, 0.054083384573459625, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.2117055382579565, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.3786092888563871, 0]
0
[0, 0, 0, 0, 0, -0.08620980232954023, 0, 0, 0]
0
[0, 0, 0, -0.3468835290521383, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.25198460221290586]
0
[0, -0.32785353660583494, 0, 0, 0, 0, 0, 0, 0]
0
[-1.4863127529621125, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -1.5682019144296646, 0, 0, 0, 0]
############
125
0
[0, 0, 0.04761806279420855, 

0
[0, 0, 0, 0, 0, 0, 0, -0.5511227786540985, 0]
0
[0, -1.1695489525794982, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.9854485183954238, 0, 0, 0, 0]
############
144
0
[0, 0, 0.010985499620437628, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.02147368043661116, 0, 0]
0
[0, 0, 0, -0.3845736481249332, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.32446761187165973]
0
[-0.2672025769948959, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.6661317318677902, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.569666963815689, 0]
0
[0, 0, 0, 0, -1.2444365620613098, 0, 0, 0, 0]
0
[0, -1.3494553178548814, 0, 0, 0, 0, 0, 0, 0]
############
145
0
[0, 0, 0.01332723647356035, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.0435067169368267, 0, 0]
0
[0, 0, 0, -0.4214024983346462, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.3977409705519676]
0
[0, 0, 0, 0, 0, -0.3696875788271427, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.3795418247580528, 0]
-1
[-0.6604339063167572, 0, 0, 0, 0, 0, 0, 0, 0]
############
146
0
[0, 0, 0.016170991957187686, 0, 

0
[0, 0, 0, 0, 0, -0.3400326490402221, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.5638157725334167]
0
[0, 0, 0, 0, 0, 0, -0.07041229084134101, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.03950206413865087, 0]
0
[0, 0, 0, -0.1536789409816265, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.7628052577376365, 0, 0, 0, 0]
0
[0, -0.5341293916106225, 0, 0, 0, 0, 0, 0, 0]
0
[-1.046386080980301, 0, 0, 0, 0, 0, 0, 0, 0]
############
164
0
[0, 0, 0.03745550215244292, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.34297346025705333, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.5501728355884552, 0, 0]
0
[0, -0.11421209350228309, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, -0.34712261855602267, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, 0, -0.20416461378335954]
0
[0, 0, 0, 0, 0, 0, 0, -0.20862739831209182, 0]
0
[-0.6692537084221839, 0, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.7067982591688633, 0, 0, 0, 0]
############
165
0
[0, 0, 0.03556139022111893, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.2497332006692886, 0, 0]
0
[0, 0, 0, 0, 0, -0.26378328949213026, 0, 0, 0]

0
[0, 0, 0, 0, 0, 0, 0, -0.36024763360619544, 0]
0
[0, 0, 0, 0, 0, -0.16902199685573577, 0, 0, 0]
0
[0, -0.7834323018789291, 0, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, -0.6909629970788955, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, -0.8463765777647495, 0, 0, 0, 0]
-1
[0, 0, 0, 0, 0, 0, 0, 0, -2.336162103712559]
############
185
0
[0, 0, 0.013114631175994873, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.20067742541432382, 0, 0]
0
[0, 0, 0, 0, -0.5336537881521508, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.3275224611163139, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.6197223559021949, 0]
0
[0, 0, 0, -0.41510728001594543, 0, 0, 0, 0, 0]
0
[0, -0.871392998099327, 0, 0, 0, 0, 0, 0, 0]
-1
[0, 0, 0, 0, 0, 0, 0, 0, -2.3510365217924116]
############
186
0
[0, 0, 0.013392648845911043, 0, 0, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, -0.1986985459923744, 0, 0]
0
[0, 0, 0, 0, -0.5235970301553606, 0, 0, 0, 0]
0
[0, 0, 0, 0, 0, -0.3044390872120857, 0, 0, 0]
0
[0, 0, 0, 0, 0, 0, 0, -0.5975855618715286, 0]
0
[0, 0, 0, -0.29792671389877795, 0, 0, 0, 0, 0]

In [19]:
a = np.array([
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0]
])
a = np.reshape(a, (1, 1, 3, 3)).astype('float')
a = np.where(a==1, 1., np.where(a==0, 0, -1.))

dd.model.predict(a)


array([[-0.24207835, -0.16172211,  0.27029547, -0.01331075, -0.04390388,
        -0.04703951,  0.05251299,  0.05345737, -0.33706117]],
      dtype=float32)

In [20]:
gc.collect()

122390

In [15]:
def play(game, m1, m2):
    global score
    while 1:
        a1 = m1.action(game)
        game = game.update(a1)
#         print(game.state)
        if game.is_lose():
            score[0] += 1
#             print(game.state)
            return 
        elif game.is_draw():
            score[2] += 1
#             print(game.state)
            return 

        a2 = m2.action(game)
        game = game.update(a2)
#         print(game.state)
        if game.is_lose():
            score[1] += 1
#             print(game.state)
            return 
        elif game.is_draw():
            score[2] += 1
#             print(game.state)
            return 
        

In [16]:
game = Game(init_state)
m1 = Random()
# m2 = CNN()

In [21]:
%%time
score = [0, 0, 0]
for _ in range(100):
    play(game, m1, dd)

print(score)
# score = [0, 0, 0]
# for _ in range(100):
#     play(game, dd, m1)
# print(score)

[25, 42, 33]
Wall time: 920 ms


In [None]:
# sns.barplot(x = [1, 2], y = score[:2])

In [None]:
# tanh + mse
# v1: 732vs601 // 350vs323
# v2: 738vs399 // 411vs149
# v3: 1161vs773 // 791vs170
# v4: 1305vs567 // 832vs86