In [1]:
import tensorflow as tf
import numpy as np
from sklearn.model_selection import RepeatedKFold

In [2]:
# Network input: a batch of one-hot encoded game states.
# float32 is required by the operations applied downstream.
X = tf.placeholder(dtype=tf.float32, shape=(None, 13, 26, 9), name="X")

# Regression target: the correct number of actions, one value per sample in the batch.
Y_corr = tf.placeholder(dtype=tf.float32, shape=(None, 1), name="Y")

In [3]:
# Neural Network Architecture

# Collapse every non-batch dimension of the input into a single feature vector
X_flat = tf.layers.flatten(X)

# First Layer - fully connected, 512 ReLU units, Xavier-initialised weights
layer_1 = tf.layers.dense(
    X_flat,
    512,
    activation=tf.nn.relu,
    use_bias=True,
    kernel_initializer=tf.contrib.layers.xavier_initializer(),
    name="layer_1",
)

# Second Layer - currently disabled: it only exists as the inert string literal
# below, so it is never added to the graph.
"""layer_2 = tf.layers.dense(inputs = layer_1,
                                  units = 32,
                                  activation = tf.nn.relu,
                                  kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                  use_bias = True,
                                  name="layer_2")"""

# Output Layer - one linear unit (no activation) fed directly by layer_1
output = tf.layers.dense(
    layer_1,
    1,
    activation=None,
    use_bias=True,
    kernel_initializer=tf.contrib.layers.xavier_initializer(),
    name="output",
)

In [4]:
# Training

# Mean squared error between the network output and the correct values
squared_error = tf.square(output - Y_corr)
loss = tf.reduce_mean(squared_error, name="loss")

# Plain gradient descent; train_op performs one update step on the loss
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.00001, name="optimizer")
train_op = optimizer.minimize(loss=loss, name="train_op")

In [5]:
# TensorBoard

# Create the event writer and hand it the default graph at construction time
writer = tf.summary.FileWriter("ModeloPlanificacion_log", graph=tf.get_default_graph())

# Summaries
tf.summary.scalar(name='loss', tensor=loss)

# Single op that evaluates every summary registered so far
merged_summary = tf.summary.merge_all()

In [6]:
# Session

# Load Dataset
dataset = np.load('datasets/dataset_train_3.npz')
dataset_x = dataset['X']
dataset_y = dataset['Y']

# Number of epochs
num_epochs = 10
# Minibatch size
batch_size = 32

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # Initialize all variables
    
    rkf = RepeatedKFold(n_splits=int(len(dataset_y) // batch_size),
                  n_repeats=num_epochs) # Get randomized indexes for minibatches
    
    it = 0
    
    for _, batch_index in rkf.split(dataset_y):
        batch_x = np.take(dataset_x, batch_index, axis=0) # Obtain current batch
        batch_y = np.take(dataset_y, batch_index)
        batch_y = np.reshape(batch_y, (-1, 1))
        
        data_dict = {X:batch_x, Y_corr:batch_y}
        
        summary = sess.run(merged_summary, feed_dict=data_dict) # Get Summaries
        writer.add_summary(summary, it) # Write it to log in disk
        
        sess.run(train_op, feed_dict=data_dict) # Execute one training step
        #perd = sess.run([train_op, loss], feed_dict=data_dict)[1] # Execute one training step
        #print("Loss:", perd)
        
        it += 1
    

In [7]:
# Now predict some values (OF TRAINING SET -> CAN BE OVERFITTED!)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer()) # Initialize all variables
    
    rkf = RepeatedKFold(n_splits=int(len(dataset_y) // batch_size),
                  n_repeats=1) # Get randomized indexes for minibatches
    
    for _, batch_index in rkf.split(dataset_y):
        batch_x = np.take(dataset_x, batch_index, axis=0) # Obtain current batch
        batch_y = np.take(dataset_y, batch_index)
        batch_y = np.reshape(batch_y, (-1, 1))
        
        data_dict = {X:batch_x, Y_corr:batch_y}
        
        pred = sess.run(output, feed_dict=data_dict) # Get Predictions for current batch
        
        print("Predictions:", pred)
        print("Real values:", batch_y)
        print("\n\n")


Predictions: [[ 0.4381218 ]
 [ 0.39164707]
 [ 0.02017336]
 [ 0.21652949]
 [ 0.40511325]
 [ 0.47342798]
 [ 0.0276497 ]
 [-0.05009077]
 [ 0.1173408 ]
 [ 0.46506637]
 [ 0.07443713]
 [-0.15246132]
 [ 0.02948943]
 [ 0.14227141]
 [ 0.07452695]
 [-0.01828644]
 [ 0.22906253]
 [ 0.34116513]
 [ 0.2599635 ]
 [ 0.3142672 ]
 [ 0.34094456]
 [ 0.4162283 ]
 [-0.02124607]
 [ 0.01543297]
 [ 0.4759834 ]
 [ 0.17589867]
 [-0.19191295]
 [ 0.39877772]
 [ 0.4381218 ]
 [ 0.3030164 ]
 [ 0.32711288]
 [ 0.2357158 ]
 [ 0.11492565]]
Real values: [[16]
 [ 8]
 [23]
 [31]
 [14]
 [14]
 [12]
 [27]
 [13]
 [12]
 [14]
 [ 8]
 [14]
 [21]
 [ 4]
 [25]
 [31]
 [30]
 [22]
 [ 2]
 [25]
 [29]
 [ 6]
 [10]
 [ 8]
 [20]
 [14]
 [27]
 [16]
 [11]
 [21]
 [19]
 [15]]



Predictions: [[ 0.03572428]
 [ 0.4216823 ]
 [ 0.40158   ]
 [ 0.41742095]
 [ 0.25467366]
 [ 0.37244818]
 [ 0.44801623]
 [-0.03058693]
 [ 0.4208577 ]
 [ 0.44303402]
 [ 0.4184775 ]
 [-0.15250118]
 [ 0.3317619 ]
 [-0.13647276]
 [ 0.100977  ]
 [-0.0669228 ]
 [ 0.09000449]
 [ 0.303

Predictions: [[ 0.3473613 ]
 [ 0.22791867]
 [ 0.11070132]
 [ 0.44004712]
 [ 0.398691  ]
 [ 0.3659808 ]
 [ 0.3744596 ]
 [ 0.2982948 ]
 [-0.0588201 ]
 [ 0.11296961]
 [ 0.2382497 ]
 [ 0.3950296 ]
 [ 0.44041693]
 [ 0.07553408]
 [-0.08050306]
 [-0.04436024]
 [-0.0051235 ]
 [ 0.3663777 ]
 [ 0.24194464]
 [ 0.02932371]
 [ 0.02057351]
 [ 0.0427302 ]
 [ 0.05715673]
 [ 0.20686495]
 [ 0.31890357]
 [ 0.03689145]
 [ 0.138528  ]
 [ 0.38692886]
 [ 0.19271769]
 [ 0.10985729]
 [ 0.05883783]
 [ 0.22710118]]
Real values: [[ 1]
 [13]
 [ 6]
 [10]
 [11]
 [10]
 [11]
 [25]
 [11]
 [10]
 [18]
 [ 1]
 [12]
 [12]
 [11]
 [12]
 [12]
 [24]
 [27]
 [ 2]
 [ 8]
 [24]
 [ 7]
 [14]
 [ 8]
 [11]
 [11]
 [ 9]
 [27]
 [18]
 [ 9]
 [ 9]]



Predictions: [[ 0.4535846 ]
 [ 0.25435635]
 [ 0.4474471 ]
 [ 0.21991193]
 [ 0.37667167]
 [ 0.34787732]
 [ 0.04607011]
 [ 0.01132798]
 [ 0.43582648]
 [ 0.40715563]
 [ 0.44967157]
 [ 0.12454945]
 [ 0.19371715]
 [ 0.35713595]
 [ 0.36479658]
 [ 0.42333287]
 [ 0.37860817]
 [ 0.36753735]
 [ 0.4558961 ]

Predictions: [[ 0.4209933 ]
 [ 0.06256793]
 [ 0.4207185 ]
 [ 0.21803069]
 [ 0.38704172]
 [ 0.1845631 ]
 [ 0.15989766]
 [ 0.44041693]
 [ 0.4805415 ]
 [ 0.45931584]
 [ 0.2382497 ]
 [ 0.33334395]
 [ 0.09510283]
 [ 0.1609147 ]
 [ 0.44768265]
 [ 0.2296932 ]
 [ 0.28211266]
 [ 0.4759834 ]
 [ 0.25467366]
 [-0.10651636]
 [ 0.1789353 ]
 [ 0.1173408 ]
 [ 0.21195668]
 [ 0.20686305]
 [ 0.00734398]
 [ 0.29584152]
 [ 0.09510283]
 [-0.07408163]
 [ 0.40861982]
 [ 0.45272124]
 [ 0.34967405]
 [ 0.2976526 ]]
Real values: [[24]
 [20]
 [12]
 [13]
 [25]
 [30]
 [20]
 [12]
 [ 4]
 [23]
 [18]
 [20]
 [ 6]
 [ 7]
 [ 6]
 [19]
 [16]
 [ 8]
 [20]
 [10]
 [14]
 [13]
 [ 8]
 [ 1]
 [24]
 [19]
 [ 6]
 [23]
 [38]
 [10]
 [17]
 [ 6]]



Predictions: [[ 0.44967157]
 [ 0.34940434]
 [-0.07449675]
 [ 0.29835308]
 [ 0.13104501]
 [ 0.25467366]
 [ 0.46051484]
 [ 0.29502434]
 [-0.11511649]
 [ 0.01857042]
 [ 0.17570974]
 [ 0.21984015]
 [-0.02053729]
 [ 0.0137378 ]
 [ 0.07452695]
 [ 0.25467366]
 [ 0.32479185]
 [ 0.36715633]
 [ 0.26005098]

Predictions: [[ 0.01819298]
 [ 0.1845631 ]
 [ 0.3287138 ]
 [ 0.0552963 ]
 [-0.0303833 ]
 [ 0.42400262]
 [ 0.44869784]
 [-0.16267471]
 [ 0.07580794]
 [ 0.18028165]
 [ 0.25079143]
 [ 0.33629525]
 [ 0.20885098]
 [ 0.18129818]
 [ 0.22325468]
 [ 0.06448215]
 [-0.07251304]
 [ 0.01142387]
 [ 0.3782857 ]
 [ 0.2326531 ]
 [ 0.3525005 ]
 [ 0.3586399 ]
 [ 0.40364474]
 [ 0.45619196]
 [ 0.19157486]
 [ 0.11296961]
 [ 0.02561396]
 [-0.04878169]
 [ 0.29360545]
 [ 0.3938701 ]
 [ 0.41823527]
 [ 0.10363485]]
Real values: [[15]
 [30]
 [17]
 [11]
 [27]
 [36]
 [16]
 [30]
 [30]
 [29]
 [19]
 [24]
 [15]
 [ 7]
 [ 3]
 [28]
 [12]
 [14]
 [24]
 [13]
 [32]
 [14]
 [15]
 [29]
 [18]
 [10]
 [ 1]
 [17]
 [20]
 [ 8]
 [18]
 [15]]



Predictions: [[ 0.38095474]
 [ 0.1173408 ]
 [ 0.25658014]
 [ 0.3319552 ]
 [-0.07289623]
 [ 0.06717393]
 [ 0.4005879 ]
 [ 0.07443713]
 [-0.05143492]
 [ 0.38224155]
 [ 0.43332037]
 [-0.03130434]
 [ 0.18028165]
 [ 0.2505104 ]
 [ 0.06882551]
 [-0.04258077]
 [ 0.11168224]
 [ 0.36861223]
 [ 0.01056813]

Predictions: [[ 0.44484985]
 [ 0.448372  ]
 [ 0.12819543]
 [ 0.4303377 ]
 [-0.05863182]
 [-0.13354631]
 [ 0.3932209 ]
 [ 0.01527053]
 [ 0.22710118]
 [-0.00355166]
 [ 0.17804262]
 [-0.00285526]
 [-0.01164167]
 [ 0.19280365]
 [ 0.34165812]
 [ 0.3172443 ]
 [ 0.07443713]
 [-0.11035027]
 [-0.01863791]
 [ 0.100977  ]
 [ 0.15050566]
 [ 0.29161918]
 [ 0.23242866]
 [ 0.05449456]
 [ 0.00310414]
 [ 0.12703092]
 [ 0.4601273 ]
 [ 0.31543022]
 [ 0.31343928]
 [ 0.31827307]
 [ 0.22850902]
 [ 0.04417236]]
Real values: [[ 6]
 [ 7]
 [19]
 [ 1]
 [15]
 [15]
 [17]
 [11]
 [ 9]
 [10]
 [20]
 [13]
 [26]
 [12]
 [20]
 [14]
 [14]
 [15]
 [10]
 [ 5]
 [26]
 [ 3]
 [20]
 [19]
 [27]
 [31]
 [10]
 [20]
 [24]
 [20]
 [21]
 [11]]



Predictions: [[ 0.45067486]
 [ 0.11841348]
 [-0.15522374]
 [ 0.33373272]
 [ 0.4381218 ]
 [ 0.3833352 ]
 [ 0.09623924]
 [ 0.13409846]
 [ 0.07443713]
 [ 0.3606908 ]
 [ 0.37704554]
 [ 0.19180912]
 [ 0.31625497]
 [ 0.19190425]
 [-0.02896817]
 [-0.03770249]
 [ 0.02244284]
 [ 0.42032138]
 [ 0.3816473 ]