In [None]:
import numpy as np
import seaborn as sns
from time import time

import tensorflow as tf
from tensorflow import keras
from keras import layers, models
from keras.regularizers import L1L2
import keras.backend as K

import os
import gc
from pathlib import Path

# Pass an empty device list for the 'GPU' type so TensorFlow sees no GPUs.
tf.config.experimental.set_visible_devices([], 'GPU')

In [None]:
# Empty 3x3 board: 0 marks an empty cell (marks 1 and 2 are placed by Game.update).
init_state = np.zeros((3, 3), dtype=int)

In [None]:
class Game:
    """A tic-tac-toe position on a 3x3 board.

    Cells hold 0 (empty) or a player mark (1 or 2). The side to move is not
    stored; `next_opp` derives it from the mark counts and `first_player`.
    `update` returns a new `Game` and leaves the current one untouched.

    Attributes:
        state: 3x3 board (indexed as state[row][col]).
        empty: flat indices (3 * row + col) of the empty cells, ascending.
        first_player: mark of the side that makes the first move.
    """

    def __init__(self, state, FIRST=1):
        self.state = state
        self.empty = self.make_empty(state)
        self.first_player = FIRST

    def make_empty(self, state):
        """Return the flat indices (3 * row + col) of all cells equal to 0."""
        return [3 * i + j for i in range(3) for j in range(3) if state[i][j] == 0]

    def is_lose(self):
        """Return True if the side to move (`next_opp()`) owns a complete line.

        NOTE(review): the mark tested is the one that moves *next*, so a line
        completed by the most recent move is not reported until one more move
        has been made (and, if that move filled the board, `is_draw` reports a
        draw instead). `playout`, `action` and `play` in this notebook consume
        the result with exactly this convention, so it is kept as is; changing
        it means revisiting their signs / score slots together.
        """
        mark = self.next_opp()
        s = self.state
        lines = [[s[i][0], s[i][1], s[i][2]] for i in range(3)]      # rows
        lines += [[s[0][i], s[1][i], s[2][i]] for i in range(3)]     # columns
        lines.append([s[0][0], s[1][1], s[2][2]])                    # main diagonal
        lines.append([s[0][2], s[1][1], s[2][0]])                    # anti-diagonal
        return any(all(cell == mark for cell in line) for line in lines)

    def is_draw(self):
        """Return 1 if `is_lose()` is false and no cell is empty, else 0."""
        if self.is_lose():
            return 0
        return 1 if np.all(self.state) else 0

    def is_done(self):
        """Return 1 if the position is terminal (`is_lose()` or `is_draw()`), else 0."""
        return 1 if (self.is_lose() or self.is_draw()) else 0

    def update(self, target):
        """Return a new Game with the side to move placed on flat index `target`.

        The board is copied, so this position is not modified. `first_player`
        is carried over to the new position; without it a game started with
        FIRST=2 would fall back to the default and never alternate turns.
        """
        state = self.state.copy()
        row, col = divmod(target, 3)
        state[row][col] = self.next_opp()
        return Game(state, self.first_player)

    def next_opp(self):
        """Return the mark of the side whose turn it is.

        When both sides have placed the same number of marks it is
        `first_player`'s turn; otherwise it is the other side's.
        """
        board = np.asarray(self.state)
        first_marks = np.count_nonzero(board == self.first_player)
        other_marks = np.count_nonzero(board) - first_marks

        if first_marks == other_marks:
            return self.first_player
        # Evaluates to the other mark for first_player in {1, 2}: 1 -> 2, 2 -> 1.
        return 2 + min(0, 1 - self.first_player)

In [None]:
class Random:
    """Baseline player that picks uniformly among the empty cells."""

    def action(self, game):
        """Return one element of `game.empty`, drawn with `np.random.choice`."""
        candidates = game.empty
        return np.random.choice(candidates)

In [None]:
def playout(game):
    """Play uniformly random moves from `game` until a terminal position.

    Returns 0 if the terminal position is a draw. Otherwise `is_lose()` held
    at the terminal position and the result is -1 with its sign flipped once
    for every move played on the way there (so -1 if `game` itself satisfies
    `is_lose()`, +1 if the position one move later does, and so on).
    """
    sign = 1
    position = game
    while True:
        if position.is_lose():
            return -sign
        if position.is_draw():
            return 0
        move = np.random.choice(position.empty)
        position = position.update(move)
        sign = -sign


def action(game, n_playouts=100):
    """Choose a move by flat Monte-Carlo search.

    For every empty cell, the move is applied once and `n_playouts` random
    playouts are run from the resulting position; the `playout` results are
    summed as returned (no sign flip) and the move with the largest total is
    chosen (lowest index wins ties, as with `np.argmax`).

    Args:
        game: position exposing `empty` and `update(move)`.
        n_playouts: number of random playouts per candidate move.

    Returns:
        The chosen element of `game.empty`.

    Raises:
        ValueError: if `game.empty` is empty.
    """
    moves = game.empty
    if len(moves) == 0:
        raise ValueError('action() called on a position with no empty cells')

    totals = []
    for move in moves:
        # The successor does not depend on the playout index, so build it once.
        successor = game.update(move)
        totals.append(sum(playout(successor) for _ in range(n_playouts)))

    return moves[np.argmax(totals)]

In [None]:
DN_FILTERS = 128  # number of convolution kernels per layer (256 in the original)
DN_RESIDUAL_NUM = 16  # number of residual blocks (19 in the original)
# NOTE(review): with Keras' default channels_last data format this shape is read
# as height=1, width=3, channels=3 — confirm that this is the intended layout.
DN_INPUT_SHAPE = (1, 3, 3)  # input shape
DN_OUTPUT_SIZE = 9  # number of actions (number of board cells, 3*3)
    
def residual_block():
    """Return a function that applies one residual block to a tensor.

    The block is Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm, followed by
    adding the block's input back in and a final ReLU. Both convolutions use
    DN_FILTERS 3x3 kernels, 'same' padding, no bias, He-normal init and
    L2 regularisation (0.0005). Fresh layers are created on every call of the
    returned function.
    """
    def apply(tensor):
        shortcut = tensor

        out = layers.Conv2D(
            DN_FILTERS, 3,
            padding='same',
            use_bias=False,
            kernel_initializer='he_normal',
            kernel_regularizer=L1L2(l2=0.0005),
        )(tensor)
        out = layers.BatchNormalization()(out)
        out = layers.Activation('relu')(out)

        out = layers.Conv2D(
            DN_FILTERS, 3,
            padding='same',
            use_bias=False,
            kernel_initializer='he_normal',
            kernel_regularizer=L1L2(l2=0.0005),
        )(out)
        out = layers.BatchNormalization()(out)

        # Skip connection: add the block input before the last activation.
        out = layers.Add()([out, shortcut])
        return layers.Activation('relu')(out)

    return apply
    
def dual_network():
    """Build and compile the residual network used by `CNN`.

    Architecture: input of shape DN_INPUT_SHAPE -> 3x3 convolution with
    BatchNorm and ReLU -> DN_RESIDUAL_NUM residual blocks -> global average
    pooling -> a single Dense head named 'pi' with DN_OUTPUT_SIZE tanh units.
    The model is compiled with the Adam optimizer and the 'logcosh' loss.

    A separate scalar value head was sketched here earlier and is disabled,
    so the model has exactly one output.
    """
    # Input layer
    board_in = layers.Input(shape=DN_INPUT_SHAPE)

    # Stem convolution
    h = layers.Conv2D(
        DN_FILTERS, 3,
        padding='same',
        use_bias=False,
        kernel_initializer='he_normal',
        kernel_regularizer=L1L2(l2=0.0005),
    )(board_in)
    h = layers.BatchNormalization()(h)
    h = layers.Activation('relu')(h)

    # Stack of residual blocks
    for _ in range(DN_RESIDUAL_NUM):
        h = residual_block()(h)

    # Pooling layer
    h = layers.GlobalAveragePooling2D()(h)

    # Per-move output head
    move_out = layers.Dense(
        DN_OUTPUT_SIZE,
        kernel_regularizer=L1L2(l2=0.0005),
        activation='tanh',
        name='pi',
    )(h)

    # Assemble and compile the model
    net = models.Model(inputs=board_in, outputs=move_out)
    net.compile(optimizer='adam', loss='logcosh')

    return net

class CNN:
    """Player backed by the network from `dual_network()`.

    The network maps an encoded board to 9 per-cell scores. Training targets
    are Monte-Carlo playout averages produced by `value()`.
    """

    def __init__(self):
        K.clear_session()
        # To start from a saved model instead:
        #     self.model = models.load_model('./tanh2.h5')
        self.model = dual_network()

    @staticmethod
    def _encode(game):
        """Encode `game.state` as a (1, 1, 3, 3) float array for the network.

        Cells holding the mark of the side to move (`game.next_opp()`) become
        1.0, empty cells 0.0 and all other cells -1.0.
        """
        status = game.next_opp()
        board = np.reshape(game.state, (1, 1, 3, 3)).astype('float')
        return np.where(board == status, 1., np.where(board == 0, 0, -1.))

    def action(self, game):
        """Return the empty cell with the highest predicted score.

        Ties go to the lowest position in `game.empty`.

        Raises:
            ValueError: if `game.empty` is empty (the previous retry loop
                would never terminate in that case).
        """
        legal = list(game.empty)
        if not legal:
            raise ValueError('CNN.action() called on a position with no empty cells')

        # verbose=0 keeps Keras from printing a progress bar on every move.
        scores = self.model.predict(self._encode(game), verbose=0)[0]
        # Only compare the scores of legal moves.
        return legal[int(np.argmax(scores[legal]))]

    def value(self, game, n_steps=20):
        """Estimate a score for each of the 9 cells by random playouts.

        For every empty cell the move is applied and `playout` is averaged
        over `n_steps` runs from the resulting position (results are used as
        returned, no sign flip). Occupied cells get 0.0.

        Returns:
            A list of 9 floats indexed by flat cell index.
        """
        values = [0.0] * 9

        for cell in game.empty:
            successor = game.update(cell)
            total = sum(playout(successor) for _ in range(n_steps))
            values[cell] = total / n_steps

        return values

    def train(self):
        """Play one game per starting mark (1, then 2) and fit on each.

        Moves are chosen by the module-level `action` function. For every
        position visited before a move, the encoded board and the `value()`
        targets are recorded; after the game ends the model is fitted for one
        epoch on that game's samples.
        """
        for first in [1, 2]:
            game = Game(init_state, first)
            X = []
            y = []
            while True:
                X.append(self._encode(game))
                y.append(self.value(game))

                game = game.update(action(game))

                if game.is_done():
                    break

            X = np.reshape(X, (len(X), 1, 3, 3))
            y = np.reshape(y, (len(y), 9))
            self.model.fit(X, y, epochs=1, verbose=0)

    def opp(self, status):
        """Return the other mark for `status` in {1, 2} (1 -> 2, 2 -> 1)."""
        return 2 + min(0, 1 - status)


In [None]:
# Build a fresh agent; CNN.__init__ clears the Keras session and builds a new network.
dd = CNN()

In [None]:
# Single training pass: CNN.train() plays and fits on one game per starting mark.
dd.train()

In [None]:
# Run 200 more training passes, printing the pass index as progress.
# The counter has a real name because it is read, and because `_` is
# IPython's "last output" variable, which a loop would otherwise clobber.
for round_idx in range(200):
    print(round_idx)
    dd.train()

# dd.model.save('./tanh4.h5')

In [None]:
# Encode the empty starting board the way CNN does (treating mark 1 as the side
# to move: 1 -> 1.0, empty -> 0.0, anything else -> -1.0) and display the
# network's raw outputs for it.
a = np.reshape(init_state.copy(), (1, 1, 3, 3)).astype('float')
a = np.where(a == 1, 1., np.where(a == 0, 0, -1.))

dd.model.predict(a)


In [None]:
# Force a full garbage-collection pass; the cell output is the value gc.collect() returns.
gc.collect()

In [None]:
def play(game, m1, m2):
    """Play one game between `m1` (moves first) and `m2`, updating `score`.

    Players alternate; after each move the new position is checked:
      * `is_lose()` true right after m1's move -> `score[0] += 1`
      * `is_lose()` true right after m2's move -> `score[1] += 1`
      * `is_draw()` true after either move     -> `score[2] += 1`
    The function returns None as soon as one of these fires.

    `score` is a module-level list of three counters that must exist before
    this function is called.

    NOTE(review): `Game.is_lose()` tests the mark of the side that is about to
    move, not the side that just moved — confirm that slots 0 and 1 count the
    winner you expect.
    """
    global score
    turn_order = ((m1, 0), (m2, 1))

    while True:
        for mover, slot in turn_order:
            game = game.update(mover.action(game))

            if game.is_lose():
                score[slot] += 1
                return
            if game.is_draw():
                score[2] += 1
                return
        

In [None]:
# Starting position for the evaluation games (default first mark) and a random opponent.
game = Game(init_state)
m1 = Random()
# m2 = CNN()

In [None]:
%%time
score = [0, 0, 0]
for _ in range(100):
    play(game, m1, dd)

print(score)
score = [0, 0, 0]
for _ in range(100):
    play(game, dd, m1)
print(score)

In [None]:
# sns.barplot(x = [1, 2], y = score[:2])

In [None]:
# tanh + mse
# v1: 732vs601 // 350vs323
# v2: 738vs399 // 411vs149
# v3: 1161vs773 // 791vs170
# v4: 1305vs567 // 832vs86