In [154]:
#Impoting
import pandas as pd
import numpy as np
import tensorflow as tf
import random
from kaggle_environments import make
from collections import deque 

env = make("connectx", debug=True)

In [172]:
#Constants
ROWS = 6
COLS = 7
BATCH_SIZE = 64
REPLAY_MEMORY_SIZE = 10000
MIN_REPLAY_MEMORY_SIZE = 500
BATCH_SIZE = 32
DISCOUNT = 0.99
UPDATE_PREDICITON_EVERY = 5
EPSILON_DECAY = 0.99975
MIN_EPSILON = 0.001

In [156]:
def create_model():
    model = tf.keras.Sequential([tf.keras.layers.Conv2D(256,(4,4),input_shape=(ROWS,COLS,1)),
                                 
                                 tf.keras.layers.Flatten(),
                                 tf.keras.layers.Dense(128, activation = 'relu'),
                                 tf.keras.layers.Dropout(0.2),
                                 tf.keras.layers.Dense(64, activation = 'relu'),
                                 tf.keras.layers.Dropout(0.2),
                                 
                                 tf.keras.layers.Dense(COLS, activation = 'linear')])
    model.compile(tf.keras.optimizers.Adam(0.001), loss=tf.keras.losses.MeanSquaredError(), metrics=['accuracy'])
    return model              
                               

In [157]:
#Create new models
updated_model = create_model()

prediction_model = create_model()
prediction_model.set_weights(updated_model.get_weights())



In [158]:
def unflatten_board(board_batch):
    batch = np.empty((len(board_batch), ROWS, COLS, 1))
    for i in range(len(board_batch)):
        batch[i]=(np.array(board_batch[i]).reshape(ROWS, COLS, 1)) 
    return batch

In [159]:
def relative_board(board, player):
        board = [-1 if x != player and x != 0 else x for x in board]
        board = [1 if x == player else x for x in board]
        return board

In [174]:
def train_model():
    if len(replay_memory) >= MIN_REPLAY_MEMORY_SIZE:
        minibatch = random.sample(replay_memory, BATCH_SIZE)
        X = unflatten_board([x[0] for x in minibatch])
        
        current_qs_list = updated_model.predict(X)
        
        future_qs_list = prediction_model.predict(unflatten_board([x[3] for x in minibatch]))
        
        y = []
            
        for index, (current_board, action, reward, new_board, done) in enumerate(minibatch):
            reward = reward*20 #Scaling the rewards to help see what the netwrok is predicting
            if not done:
                for i in range(COLS):
                #If the move is unplayable it gets the reward of a loss (-20)
                    if current_board[i] != 0:
                        future_qs_list[index][i] = -20
                        
                max_future_q = np.max(future_qs_list[index])
                new_q = reward + DISCOUNT * max(-20, min(20, max_future_q)) #max and min to make sure it stays within the expected range
            else:
                new_q = reward

            #Update Q value for given state
            current_qs = current_qs_list[index]
            current_qs[action] = new_q

            #And append to our training data
            y.append(current_qs)
        updated_model.fit(X, np.array(y), batch_size=BATCH_SIZE, shuffle=False)    

In [177]:
def get_action(board):
    #if random.uniform(0, 1) <= epsilon:
        #options = [c for c in range(COLS) if board[c] == 0]
        #return options[random.randint(0, len(options)-1)]
    #else:
    q_vals = updated_model.predict(unflatten_board(np.array(board).reshape(-1, len(board))))  
    print(q_vals)
    best_playable = None                             
    for i in range(COLS):
        if board[i] == 0 and (best_playable == None  or q_vals[0][i] > q_vals[0][best_playable]):
            best_playable = i                      
    return best_playable                           

In [162]:
def agent(observation, configuration):
    board = relative_board(observation.board, observation.mark)
    return get_action(board)

In [179]:
#Trainer
#trainer = env.train([None, agent]);
#replay_memory = deque(maxlen = REPLAY_MEMORY_SIZE)
obs = trainer.reset()
prediction_update_couter = 0
print(len(replay_memory))
for _ in range(100):
    r_board = relative_board(obs.board, obs.mark)
    action = get_action(r_board)
    new_obs, reward, done, info = trainer.step(action);
    replay_memory.append([r_board, action, reward, relative_board(new_obs.board, new_obs.mark), done]);
    train_model()
    if done:
        obs = trainer.reset()
        #if epsilon > MIN_EPSILON:
            #epsilon *= EPSILON_DECAY
            #epsilon = max(MIN_EPSILON, epsilon)
        prediction_update_couter += 1
    else:
        obs = new_obs
    if(prediction_update_couter >= UPDATE_PREDICITON_EVERY):
        updated_model.save_weights('model')
        prediction_model.set_weights(updated_model.get_weights())
        prediction_update_couter = 0


765
[[-0.00727437 -0.48201263 -0.15225103  0.06162041  0.02023095 -0.00833592
  -0.08453599]]
[[ 0.01281223  0.12800333 -0.03656507 -0.05678732 -0.10724328  0.06372692
  -0.01434617]]
[[-0.07936531 -0.5622362  -0.05391463 -0.04527815  0.20693529 -0.12818953
  -0.11947542]]
[[ 0.02301363  0.3181936  -0.15290639  0.0028825   0.13898428  0.26826158
   0.02740951]]
[[-0.24508147 -1.1977715   0.05752116 -0.24924609  0.11170579 -0.47154152
  -0.23161264]]
[[ 0.36070386  3.5870583   0.02897456 -0.70664984 -0.24277203 -0.36931816
   0.1457765 ]]
[[-0.31901646 -4.6542215   0.00749896 -0.5170498   0.1914821  -0.6800768
  -1.23806   ]]
[[ 1.5045334  14.50755     0.7706396  -2.9996128  -0.9540988  -1.9729282
   0.88064337]]
[[-0.03111057 -0.50278443 -0.13848072  0.03812135  0.01118253 -0.02594661
  -0.08737137]]
[[ 0.01350557  0.27580073 -0.03639572 -0.08593491 -0.11878096  0.0495555
  -0.02150137]]
[[-0.09843865 -0.62644935 -0.03645425 -0.0756545   0.21599598 -0.17318423
  -0.13442071]]
[[ 0.0245

[[-0.06287293 -0.40888566 -0.12917185  0.04055132  0.15367138  0.04641285
  -0.02364662]]
[[ 0.10539938 -0.3759154   0.12185976  0.25070477  0.10895625  1.7118874
   0.15892185]]
[[-0.28107277 -1.7548479   0.00405543 -0.10022503 -0.11610004 -0.10777391
  -0.37444916]]
[[ 0.09921676 -0.50230193  0.313752    0.29340002  0.03019978  2.0540879
   0.14942524]]
[[-0.28978676 -3.5738072   0.1453501  -0.2502635  -0.448791    0.08145429
  -0.70125383]]
[[ 0.3035897  -1.0522726   0.6028333   0.35378277  0.19086878  4.058516
   0.20728616]]
[[-0.8146054  -9.075861    0.86174613 -0.891787   -1.962795    0.14756098
  -2.3116791 ]]
[[ 1.0588272  -3.5151792   1.896225    0.8520732   0.1141654  14.756985
   0.69560283]]
[[-0.10788359 -0.45222852 -0.11239737  0.04800691  0.2125887   0.14359857
  -0.02480749]]
[[ 0.10680063 -0.575028    0.2173291   0.29732987  0.07028732  2.0690205
   0.15290713]]
[[-0.40838015 -1.8989149   0.02237567 -0.13244197 -0.08367048 -0.13278586
  -0.4566754 ]]
[[ 0.12941876 -0.

[[-0.45877934 -1.4450388   0.06526955 -0.05083241 -0.40742368  0.05397166
  -0.30528697]]
[[ 0.34184915 -1.4850432   0.76624316  0.5587177   0.39434302  6.378837
   0.4316875 ]]
[[-2.480453   -4.8084908   1.9361696  -1.949626   -3.6452308  -0.98313504
  -2.0952086 ]]
[[ 1.0021662 -4.0527024  2.1285827  1.0145309  0.5568293 19.07208
   1.3106622]]
[[ -3.3483975 -10.14785     2.153599   -2.0198548  -5.4424047  -0.5755816
   -3.107847 ]]
[[ 1.071499   -4.097536    1.9811443   1.248653    0.60427904 19.703035
   1.4707665 ]]
[[-0.05535449 -0.210028    0.10332806  0.18512169  0.25573912  1.0082828
   0.20246772]]
[[-0.0323078  -0.38367265  0.13247682  0.21579503  0.2752887   1.7120233
   0.24630217]]
[[-0.53132653 -0.84154534  0.10754596 -0.10722691 -0.36273357  0.0419449
  -0.20455495]]
[[ 0.2372964  -1.3217034   0.6545534   0.6058321   0.52941024  5.958633
   0.49081105]]
[[-2.772817  -3.8261058  2.2679977 -2.226882  -3.9738376 -1.1891044
  -1.9330697]]
[[ 0.8136639  -3.6653905   1.777477

[[ 0.4508101  -1.3504826   0.52294564  0.63118464  0.7451407   8.915557
   1.1881931 ]]
[[-3.3797026  -9.082977    1.061216   -1.7145065  -4.527555   -0.58582765
  -2.426482  ]]
[[ 0.84622884 -1.1072433   0.60594416  0.81072706  1.1330092  11.463399
   1.6982273 ]]
[[-3.0231328  -8.16384     0.82312465 -1.3646171  -4.0031676  -0.519515
  -2.1601398 ]]
[[0.7461229  2.1582778  0.46198833 0.15127233 0.53770363 7.404787
  1.9387486 ]]
[[-0.05609238 -0.08415391  0.05993932  0.24066214  0.44308493  1.3277026
   0.3535712 ]]
[[-0.00373115  0.13664275  0.15404415  0.2106209   0.5955651   1.710513
   0.51610404]]
[[-0.6336739  -0.6603484   0.03883438 -0.0313156  -0.21167243  0.11022421
  -0.11527383]]
[[-1.5528563  -4.5811753   0.33789742 -0.76220405 -1.7669986  -0.27301204
  -1.068647  ]]
[[ 0.5428956  -1.0942968   0.48427382  0.66403013  1.0323765   9.585146
   1.4514487 ]]
[[-3.4019547  -9.151397    0.70063055 -1.5189296  -4.292101   -0.5009026
  -2.251199  ]]
[[ 0.9753135  -0.51217026  0.55

[[0.106311   0.3500917  0.08883402 0.19759798 0.6058255  1.4506079
  0.58722174]]
[[ 0.31603897  3.0342867   0.93216896 -0.22021672  0.48652792  0.9032728
   0.9484511 ]]
[[ 0.23529339 -0.3188854   0.09556132  0.31262088  0.69252676  3.0593286
   0.6813934 ]]
[[ 0.42142525  8.341456    2.5932977  -1.1808828   0.16744114 -0.6675238
   1.1010691 ]]
[[ 0.57518613 -1.4565135   0.48058122  0.78804666  2.2351496   9.255578
   1.8052522 ]]
[[-2.3089192   0.7780484   1.9208813  -2.109421   -2.7552893  -1.190889
  -0.78949326]]
[[ 1.8231981  -3.3602707   0.83614945  1.9274799   3.8233786  22.951395
   4.053745  ]]
[[0.18844147 0.45251435 0.11591824 0.25880575 0.74139845 1.6954355
  0.76428   ]]
[[ 0.42247754  3.5207918   1.087777   -0.20852283  0.573853    0.90142655
   1.1556945 ]]
[[ 0.2957139  -0.36557147  0.16893473  0.42186385  0.85626054  3.6897619
   0.9512513 ]]
[[ 0.5453545   9.00373     2.825841   -1.1848223   0.18385765 -0.81868285
   1.3512623 ]]
[[ 0.73947066 -1.5474187   0.6431877

[[ 3.2000272 -1.9795976  1.2776415  3.673264   5.0126476 23.27052
   7.7432284]]
[[0.815532   1.1994638  0.37447646 0.8775545  0.98754317 3.4004807
  1.9170468 ]]
[[1.1485462  5.464891   1.8067067  0.37120107 0.8706579  2.0960798
  2.4113953 ]]
[[ 1.0344113  -0.34432703  0.33145702  1.3269933   1.2110265   6.9533563
   2.6386545 ]]
[[ 1.4343147  11.469023    3.8544583  -0.6272539   0.35134572 -0.11257909
   2.5294633 ]]
[[ 1.9649507  -0.8949801   0.85576135  2.4220245   2.5356336  13.551835
   5.0396595 ]]
[[-0.87353474  2.437674    1.5393093  -0.97407705 -2.337235   -0.46227846
   0.52205616]]
[[ 3.6278841 -1.0834675  1.1298025  3.764141   4.841713  22.26793
   8.328431 ]]
[[1.047542   1.3367028  0.35651296 1.0685787  1.118408   3.7592704
  2.280101  ]]
[[1.355326  5.461384  1.7494394 0.6555384 0.9958276 2.343846  2.684794 ]]
[[ 1.3712219  -0.17543553  0.33330503  1.5455725   1.3131126   7.514057
   3.19245   ]]
[[ 1.4791181  10.642084    3.5219216  -0.2395768   0.4510623  -0.10072018

In [176]:
#Display a game of the model playing itself with the best moves 
env.run([agent, agent])
env.render(mode="ipython", width=600, height=500, header=False)




In [28]:
#Load saved models
updated_model.load_weights('model')
prediction_model.load_weights('model')

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x279d773d370>

In [171]:
updated_model.summary()

Model: "sequential_8"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_14 (Conv2D)          (None, 3, 4, 256)         4352      
                                                                 
 flatten_8 (Flatten)         (None, 3072)              0         
                                                                 
 dense_24 (Dense)            (None, 128)               393344    
                                                                 
 dropout_16 (Dropout)        (None, 128)               0         
                                                                 
 dense_25 (Dense)            (None, 64)                8256      
                                                                 
 dropout_17 (Dropout)        (None, 64)                0         
                                                                 
 dense_26 (Dense)            (None, 7)                