In [1]:
!wget https://www.lamsade.dauphine.fr/~cazenave/project2022.zip
!unzip project2022.zip

--2023-04-28 14:32:26--  https://www.lamsade.dauphine.fr/~cazenave/project2022.zip
Resolving www.lamsade.dauphine.fr (www.lamsade.dauphine.fr)... 193.48.71.250
Connecting to www.lamsade.dauphine.fr (www.lamsade.dauphine.fr)|193.48.71.250|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 138884468 (132M) [application/zip]
Saving to: ‘project2022.zip’


2023-04-28 14:32:28 (98.3 MB/s) - ‘project2022.zip’ saved [138884468/138884468]

Archive:  project2022.zip
  inflating: Board.h                 
  inflating: Game.h                  
  inflating: Rzone.h                 
  inflating: compile.sh              
  inflating: compileMAC.sh           
  inflating: games.data              
  inflating: golois.cpp              
  inflating: golois.cpython-310-x86_64-linux-gnu.so  
  inflating: golois.cpython-37m-x86_64-linux-gnu.so  
  inflating: golois.cpython-38-x86_64-linux-gnu.so  
  inflating: golois.cpython-39-x86_64-linux-gnu.so  
  inflating: golois.py             


The training data comes from self-play games of the KataGo Go program; the training set contains 1,000,000 different games in total. The input data consists of 31 planes of size 19x19 (color to play, ladders, current state on two planes, two previous states on four planes). The output targets are the policy (a vector of size 361 with 1.0 for the move played and 0.0 for the other moves) and the value (close to 1.0 if White wins, close to 0.0 if Black wins). The policy network samples actions and the value network evaluates positions; together, they are used to reduce the breadth and the depth of the search tree.
The networks must have fewer than 100,000 parameters. The metrics used to evaluate the model are the total loss, policy loss, value loss, policy accuracy, and value MSE.

In [None]:
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
from tensorflow.keras import layers 
from tensorflow.keras import regularizers
import matplotlib.pyplot as plt
import json
import gc
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.models import load_model
import golois

# Board encoding and training configuration.
planes, moves = 31, 361    # input feature planes per 19x19 board; policy outputs (19 * 19 points)
N = 20000                  # positions held in each preallocated data buffer
epochs, batch = 5, 128     # outer training iterations; mini-batch size passed to fit()
filters = 80               # channel width of the residual tower

# Preallocated host buffers. Their initial contents are random placeholders;
# the same arrays are handed to golois.getValidation / golois.getBatch to be
# filled (presumably in place -- confirm against golois.cpp).
input_data = np.random.randint(2, size=(N, 19, 19, planes)).astype('float32')

# One-hot policy targets. num_classes is pinned to `moves` so the buffer is
# always (N, moves) wide, even if the random draw happens to miss the highest
# label (to_categorical otherwise infers the width from max(label) + 1).
policy = keras.utils.to_categorical(np.random.randint(moves, size=(N,)), num_classes=moves)

value = np.random.randint(2, size=(N,)).astype('float32')

end = np.random.randint(2, size=(N, 19, 19, 2)).astype('float32')

# Only passed to golois.getBatch in this notebook.
groups = np.zeros((N, 19, 19, 1), dtype='float32')

print ("getValidation", flush = True)
golois.getValidation (input_data, policy, value, end)

def SE_block(x, filters, ratio = 16):
    """Squeeze-and-Excitation: rescale each channel of ``x`` by a learned gate.

    The feature map is global-average-pooled to one value per channel
    ("squeeze"). That vector passes through a bias-free two-layer bottleneck
    ("excitation": swish, then sigmoid). The resulting per-channel weights,
    which lie in (0, 1), are multiplied back onto ``x``.

    Args:
        x: 4-D feature map. Its channel count should equal ``filters`` so the
            (1, 1, filters) gate broadcasts across it.
        filters: number of channels of ``x``.
        ratio: bottleneck reduction factor. The hidden layer has
            ``filters // ratio`` units, clamped to at least 1.

    Returns:
        Tensor of the same shape as ``x``.
    """
    se_shape = (1, 1, filters)
    # Keep the bottleneck non-empty when filters < ratio (filters // ratio == 0).
    bottleneck = max(1, filters // ratio)
    se = layers.GlobalAveragePooling2D()(x)
    # Re-insert the two spatial axes so the gate broadcasts over height and width.
    se = layers.Reshape(se_shape)(se)
    se = layers.Dense(bottleneck, activation='swish', use_bias=False)(se)
    se = layers.Dense(filters, activation='sigmoid', use_bias=False)(se)
    return layers.multiply([x, se])


# Network: a 1x1 stem, then 11 residual blocks of
# (1x1 conv -> BN -> swish -> 3x3 depthwise conv -> BN -> swish -> SE gating),
# each added back onto its input, followed by separate policy and value heads.
input = keras.Input(shape=(19, 19, planes), name='board')
# Stem: project the `planes` input channels to `filters` channels.
x = layers.Conv2D(filters, 1,padding='same')(input)
x = layers.BatchNormalization()(x)
x = layers.Activation('swish')(x)
loop_layer = x  # skip-connection source for the first residual block
for i in range (11):
    # Pointwise (1x1) conv: mixes channels, width stays at `filters`.
    x = layers.Conv2D(filters, 1 , padding='same')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('swish')(x)
    # Depthwise 3x3 conv: spatial filtering applied separately to each channel.
    x = layers.DepthwiseConv2D(3, padding='same', kernel_regularizer=regularizers.l2(0.0001))(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('swish')(x)
    x = SE_block(x,filters)
    # Identity skip connection; the sum becomes the next block's skip source.
    x = layers.add([x,loop_layer])
    loop_layer = x
    

# Policy head: one logit per board point (19x19x1), flattened to 361 and softmaxed.
# NOTE(review): the ReLU clips negative logits to 0 before the softmax, so all
# such points get the same probability -- confirm a ReLU (not linear) was intended.
policy_head = layers.Conv2D(1, 1, activation='relu', padding='same', use_bias = False, kernel_regularizer=regularizers.l2(0.0001))(x)
policy_head = layers.Flatten()(policy_head)
policy_head = layers.Activation('softmax', name='policy')(policy_head)

# Value head: pool to one value per channel, small MLP, sigmoid output in (0, 1).
# The Flatten is redundant: GlobalAveragePooling2D already yields (batch, channels).
value_head = layers.GlobalAveragePooling2D()(x)
value_head = layers.Flatten()(value_head)
value_head = layers.Dense(50, activation='relu', kernel_regularizer=regularizers.l2(0.0001))(value_head)
value_head = layers.Dense(1, activation='sigmoid', name='value', kernel_regularizer=regularizers.l2(0.0001))(value_head)

model = keras.Model(inputs=input, outputs=[policy_head, value_head])

# NOTE(review): a hand count gives ~99.3k trainable parameters (~103k in total
# including BatchNorm moving statistics). That is tight against the 100k budget;
# confirm against the summary below.
model.summary ()



# Two-headed objective: cross-entropy on the move distribution, binary
# cross-entropy on the game outcome, weighted equally. Adam with lr = 5e-4.
head_losses = {'policy': 'categorical_crossentropy', 'value': 'binary_crossentropy'}
head_loss_weights = {'policy' : 1.0, 'value' : 1.0}
head_metrics = {'policy': 'categorical_accuracy', 'value': 'mse'}
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0005),
    loss=head_losses,
    loss_weights=head_loss_weights,
    metrics=head_metrics,
)

# Per-epoch training curves keyed by epoch number (1-based). Each value is the
# list that `history.history` records for that single-epoch `fit` call.
loss = {}
policy_loss = {}
value_loss = {}
policy_accuracy = {}
value_mse = {}


for i in range(1, epochs + 1):
    print('epoch', i, '/', epochs, flush=True)
    # Refill the shared buffers with N fresh positions. `i * N` presumably
    # selects where in the game data this batch starts (see golois).
    golois.getBatch(input_data, policy, value, end, groups, i * N)
    history = model.fit(input_data,
                        {'policy': policy, 'value': value},
                        epochs=1, batch_size=batch)
    loss[i] = history.history['loss']
    policy_loss[i] = history.history['policy_loss']
    value_loss[i] = history.history['value_loss']
    policy_accuracy[i] = history.history['policy_categorical_accuracy']
    value_mse[i] = history.history['value_mse']

    # Release unreachable objects every 5 epochs.
    if i % 5 == 0:
        gc.collect()
    # Validate once, after the final epoch. The check is tied to `epochs` (not
    # a fixed epoch number), so it fires for any run length. getValidation
    # reuses the training buffers, which is safe because no training follows.
    if i == epochs:
        golois.getValidation(input_data, policy, value, end)
        val = model.evaluate(input_data,
                             {'policy': policy, 'value': value},
                             verbose=0, batch_size=batch)
        print("val =", val)









In [None]:
# Plot the per-epoch training curves recorded by the training loop.
# Each history dict maps epoch -> list of values from a single-epoch fit();
# the last entry of each list is plotted. Dict views are materialised as
# plain lists so the plotting call does not depend on matplotlib's view handling.
metric_panels = [
    (loss, 'Loss'),
    (policy_loss, 'Policy Loss'),
    (value_loss, 'Value Loss'),
    (policy_accuracy, 'Policy Accuracy'),
    (value_mse, 'Value MSE'),
]

fig, axs = plt.subplots(3, 2, figsize=(10,10))
# axs.flat walks the grid row by row: [0,0], [0,1], [1,0], [1,1], [2,0].
for ax, (history_by_epoch, title) in zip(axs.flat, metric_panels):
    epoch_ids = list(history_by_epoch.keys())
    epoch_values = [np.ravel(v)[-1] for v in history_by_epoch.values()]
    ax.plot(epoch_ids, epoch_values)
    ax.set_title(title)
    ax.set_xlabel('Epochs')

# Five metrics on a 3x2 grid: drop the unused sixth panel *before* computing
# the layout, so tight_layout accounts for its removal.
fig.delaxes(axs[2][1])
fig.tight_layout(pad=2.0)
plt.show()