#### Source for the architecture

https://arxiv.org/pdf/1712.01815.pdf

https://www.chessprogramming.org/AlphaZero#:~:text=AlphaZero%20evaluates%20positions%20using%20non,policy)%20and%20a%20position%20evaluation.

http://www.talkchess.com/forum3/viewtopic.php?f=2&t=69175&start=93

residual block : https://www.chessprogramming.org/Neural_Networks#Residual

# Input size : 5 x 5 x 5
- board size : 5 x 5 
- P1 unique pieces : 2 
- P2 unique pieces : 2
- Colour

# Output size 5 x 5 x 13
- board size : 5 x 5 
- 13 possibles moves, 1 for each directions in ["N", "NE", "E", "SE", "S", "SW", "W", "NW"], plus the dragon, crab and tiger differents moves  

So a total of 5x5x13 = 325 moves

In [1]:
import numpy as np

### Layer codes

In [2]:
codes, i = {}, 0 
for nSquares in range(1,2):
    for direction in ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]:
        codes[(nSquares, direction)] = i
        i += 1
        
for direction in ["E", "W", "S"]:
    codes[(2, direction)] = i
    i += 1

codes[("2E", "1S")] = 11
codes[("2W", "1S")] = 12

In [3]:
codes

{(1, 'N'): 0,
 (1, 'NE'): 1,
 (1, 'E'): 2,
 (1, 'SE'): 3,
 (1, 'S'): 4,
 (1, 'SW'): 5,
 (1, 'W'): 6,
 (1, 'NW'): 7,
 (2, 'E'): 8,
 (2, 'W'): 9,
 (2, 'S'): 10,
 ('2E', '1S'): 11,
 ('2W', '1S'): 12}

In [4]:
policy = np.zeros((5, 5, 13))

In [5]:
columns = {k:v for v,k in enumerate("abcde")}

In [6]:
d2d3policy = np.zeros((5, 5, 13))
d2d3policy[columns["d"], 3-1, codes[(1, 'N')]] = 1

In [7]:
c2d3policy = np.zeros((5, 5, 13))
c2d3policy[columns["d"], 3-1, codes[(1, 'NE')]] = 1

In [8]:
# 50/50 chance for each move
openingPolicy = (d2d3policy + c2d3policy) / 2

openingPolicy[2][3][8] mean that on the column C, on the line 3, the move was (2, 'E'). so the piece to move is in the column C and in the line 4

# Functionnal NN

In [9]:
# needed to do a class in order to have multiple instances of the nnet ? 

In [10]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model

In [22]:
# input_block 
input_block = layers.Input(shape=(5, 5, 5))

In [23]:
# convolutionnal_layer 
x = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", activation="linear")(input_block)
x = layers.BatchNormalization()(x)
x = layers.LeakyReLU()(x)

In [24]:
# 19 Residual blocks with a skip connection
for _ in range(19):
    y = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", strides=1, activation="linear")(x)
    y = layers.BatchNormalization()(y)
    y = layers.LeakyReLU()(y)
    
    y = layers.Conv2D(filters=256, kernel_size=(3,3), padding="same", strides=1, activation="linear")(y)
    y = layers.BatchNormalization()(y)
    y = layers.LeakyReLU()(y)
    
    x = layers.Add()([x, y])

In [25]:
# policy_head with a final convolution of 13 filters
policy_head = layers.Conv2D(filters=256, kernel_size=(1, 1), padding="same", activation="linear")(x)
policy_head = layers.BatchNormalization()(policy_head)
policy_head = layers.LeakyReLU()(policy_head)
policy_head = layers.Conv2D(filters=13, kernel_size=(1, 1), padding="same", activation="linear")(x)

In [None]:
The value head applies an additional rectified, batch-normalized convolution of 1 filter of kernel size 1 ⇥ 1 with stride 1,
followed by a rectified linear layer of size 256 and a tanh-linear layer of size 1.

In [26]:
# value_head
value_head = layers.Conv2D(filters=1, kernel_size=(1,1), padding="same", strides=1, activation="linear")(x)
value_head = layers.BatchNormalization()(value_head)
value_head = layers.LeakyReLU()(value_head)
value_head = layers.Flatten()(value_head)
value_head = layers.Dense(256, activation="linear")(value_head)
value_head = layers.LeakyReLU()(value_head)
value_head = layers.BatchNormalization()(value_head)
value_head = layers.Dense(1, activation="tanh", name="value_head")(value_head)

In [27]:
model = Model(inputs=[input_block], outputs=[policy_head, value_head])
model.compile(loss=['categorical_crossentropy','mean_squared_error'], optimizer="Adam")

In [28]:
# input size (1,5,5,5)

In [29]:
model.predict(np.ones((1,5,5,5)))

[array([[7.4443847e-08, 2.8526601e-09, 1.5705516e-10, 9.3892404e-06,
         9.6059126e-01, 2.1774908e-13, 4.3645874e-04, 3.8961716e-02,
         1.1754929e-13, 4.9246911e-08, 3.0024462e-14, 1.4869140e-14,
         1.0087599e-06]], dtype=float32),
 array([[0.53387797]], dtype=float32)]