In [1]:
import tensorflow as tf
import numpy as np
import sys
sys.path.append('../general')
from pool import Pool
from collections import Counter
from metric import metric

  from ._conv import register_converters as _register_converters


In [2]:
pool = Pool('../data')
train_pool, test_pool = pool.train_test_split()


def _features_with_position(split):
    """Return ``split.features`` with ``split.positions`` appended as one extra, LAST column.

    The position must be the last column: ``Model.predict_positions`` appends
    each candidate position at the end of a raw feature row.
    """
    position_column = np.reshape(split.positions, (-1, 1))
    return np.concatenate((split.features, position_column), axis=1)


# Model inputs: features + position column; targets as an (n, 1) column.
train_features = _features_with_position(train_pool)
train_prediction = np.reshape(train_pool.targets, (-1, 1))
test_features = _features_with_position(test_pool)
test_prediction = np.reshape(test_pool.targets, (-1, 1))

In [3]:
class Model:
    """Linear scorer ``x @ matrix + bias`` trained with TF1 graph-mode Adam.

    Each input row ``x`` is a feature vector whose LAST column is a position.
    ``predict_positions`` takes raw feature rows (without that column), tries
    every candidate in ``self.POSITIONS`` and returns the best-scoring one.
    """

    def __init__(self, num_features=None, positions=None, l2=1.0, seed=None):
        """Build the model graph and an initialised session.

        Args:
            num_features: width of an input row (features plus the position
                column). Defaults to the width of the notebook-global
                ``train_features`` so that a bare ``Model()`` keeps working.
            positions: candidate positions tried by ``predict_positions``.
                Defaults to ``0..8`` plus ``100``.
            l2: weight of the ``mean(matrix ** 2)`` penalty added to the MSE
                (1.0 reproduces the original loss).
            seed: optional graph-level seed for reproducible initialisation.
        """
        if positions is None:
            self.POSITIONS = list(range(9)) + [100]
        else:
            self.POSITIONS = list(positions)
        if num_features is None:
            # Backward compatibility: the original always read this global.
            num_features = np.shape(train_features)[1]

        # Each model owns a private graph, so constructing a second Model (e.g.
        # re-running the training cell) does not fail with
        # "Variable matrix already exists" in the shared default graph.
        self.graph = tf.Graph()
        with self.graph.as_default():
            if seed is not None:
                tf.set_random_seed(seed)
            self.matrix = tf.get_variable(
                "matrix", shape=(num_features, 1),
                initializer=tf.glorot_uniform_initializer()
            )
            self.bias = tf.get_variable(
                "bias", shape=(1,),
                initializer=tf.glorot_uniform_initializer()
            )
            self.input_features = tf.placeholder('float32', shape=(None, num_features))
            self.input_prediction = tf.placeholder('float32', shape=(None, 1))
            self.output_prediction = tf.matmul(self.input_features, self.matrix) + self.bias
            # Mean squared error plus an L2 penalty on the weights (bias is not penalised).
            self.loss = (
                tf.reduce_mean((self.input_prediction - self.output_prediction) ** 2) +
                l2 * tf.reduce_mean(self.matrix ** 2)
            )
            # Despite the name, this is the training op returned by ``minimize``.
            self.optimizer = tf.train.AdamOptimizer().minimize(
                self.loss, var_list=[self.matrix, self.bias]
            )
            # Ops used to write a saved snapshot back into the variables.
            self._matrix_value = tf.placeholder('float32', shape=(num_features, 1))
            self._bias_value = tf.placeholder('float32', shape=(1,))
            self._restore_op = tf.group(
                tf.assign(self.matrix, self._matrix_value),
                tf.assign(self.bias, self._bias_value),
            )
            init_op = tf.global_variables_initializer()

        self.sess = tf.Session(graph=self.graph)
        self.sess.run(init_op)
        self.best_matrix, self.best_bias = self.sess.run([self.matrix, self.bias])

    @property
    def best_marix(self):
        """Deprecated, misspelled alias of ``best_matrix`` kept for old callers."""
        return self.best_matrix

    @best_marix.setter
    def best_marix(self, value):
        self.best_matrix = value

    def teach(self, train_features, train_prediction,
              test_features, test_prediction, verbose=True, iterations=1000,
              restore_best=True):
        """Run ``iterations`` full-batch Adam steps and track the best test loss.

        After every step the train and test losses are evaluated. Whenever
        the test loss improves, the weights are snapshotted into
        ``best_matrix`` / ``best_bias``.

        Args:
            train_features, test_features: 2-D rows including the position column.
            train_prediction, test_prediction: targets shaped ``(n, 1)``.
            verbose: print both losses after every step.
            iterations: number of optimisation steps.
            restore_best: when True (default), load the best snapshot back into
                the model after training, so predictions use it. The original
                saved the snapshot but never used it.
        """
        train_feed = {
            self.input_features: train_features,
            self.input_prediction: train_prediction,
        }
        test_feed = {
            self.input_features: test_features,
            self.input_prediction: test_prediction,
        }
        best_test_loss = float('inf')
        for _ in range(iterations):
            self.sess.run(self.optimizer, train_feed)
            # Losses are measured after the step, on the updated weights.
            train_loss = self.sess.run(self.loss, train_feed)
            test_loss = self.sess.run(self.loss, test_feed)
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                self.best_matrix, self.best_bias = self.sess.run([self.matrix, self.bias])
            if verbose:
                print(
                    "train loss: {}, test loss: {}".format(
                        train_loss, test_loss
                    )
                )
        if restore_best:
            self.sess.run(self._restore_op, {
                self._matrix_value: self.best_matrix,
                self._bias_value: self.best_bias,
            })

    def predict_score(self, features):
        """Return the model output ``(n, 1)`` for rows that include the position column."""
        return self.sess.run(self.output_prediction, {self.input_features: features})

    def predict_positions(self, test_features):
        """For each raw feature row, return the candidate position with the highest score.

        Every row is paired with every entry of ``self.POSITIONS`` (appended as
        the last column). All pairs are scored in a single ``predict_score``
        call, and the first best-scoring candidate is returned (same tie rule as a
        strict ``>`` scan). There is no score sentinel, so very negative scores
        still yield a real position.

        Args:
            test_features: 2-D sequence of feature rows WITHOUT the position column.

        Returns:
            list of positions taken from ``self.POSITIONS``, one per row.
        """
        rows = np.asarray(test_features, dtype=np.float32)
        if len(rows) == 0:
            return []
        candidates = np.asarray(self.POSITIONS, dtype=np.float32)
        n_rows, n_candidates = len(rows), len(candidates)
        # Row-major pairing: row0 x all candidates, then row1 x all candidates, ...
        batch = np.concatenate(
            (
                np.repeat(rows, n_candidates, axis=0),
                np.tile(candidates, n_rows).reshape(-1, 1),
            ),
            axis=1,
        )
        scores = np.reshape(self.predict_score(batch), (n_rows, n_candidates))
        return [self.POSITIONS[i] for i in np.argmax(scores, axis=1)]

In [4]:
# Fit the linear scorer; every step logs train/test loss.
model = Model()
model.teach(
    train_features=train_features,
    train_prediction=train_prediction,
    test_features=test_features,
    test_prediction=test_prediction,
)

train loss: 186.02474975585938, test loss: 53.172279357910156
train loss: 171.31396484375, test loss: 49.420372009277344
train loss: 157.787841796875, test loss: 46.3572883605957
train loss: 145.28469848632812, test loss: 43.8138313293457
train loss: 133.62879943847656, test loss: 41.60847473144531
train loss: 122.6788101196289, test loss: 39.59663772583008
train loss: 112.34425354003906, test loss: 37.68667984008789
train loss: 102.58025360107422, test loss: 35.83430099487305
train loss: 93.37439727783203, test loss: 34.02955627441406
train loss: 84.73379516601562, test loss: 32.28398132324219
train loss: 76.67411804199219, test loss: 30.61935806274414
train loss: 69.2108154296875, test loss: 29.05879783630371
train loss: 62.3524169921875, test loss: 27.62025260925293
train loss: 56.09633255004883, test loss: 26.312158584594727
train loss: 50.426822662353516, test loss: 25.131624221801758
train loss: 45.316165924072266, test loss: 24.065433502197266
train loss: 40.72813415527344, test

train loss: 0.8853310346603394, test loss: 1.1560897827148438
train loss: 0.8760091662406921, test loss: 1.1476653814315796
train loss: 0.8668987154960632, test loss: 1.1394611597061157
train loss: 0.8579916954040527, test loss: 1.1314677000045776
train loss: 0.8492807149887085, test loss: 1.1236766576766968
train loss: 0.8407590389251709, test loss: 1.1160808801651
train loss: 0.832420289516449, test loss: 1.108674168586731
train loss: 0.8242583274841309, test loss: 1.1014492511749268
train loss: 0.8162670135498047, test loss: 1.0944002866744995
train loss: 0.8084407448768616, test loss: 1.087520718574524
train loss: 0.800773561000824, test loss: 1.0808030366897583
train loss: 0.7932602763175964, test loss: 1.0742403268814087
train loss: 0.7858956456184387, test loss: 1.067825198173523
train loss: 0.7786746025085449, test loss: 1.0615497827529907
train loss: 0.7715923190116882, test loss: 1.0554070472717285
train loss: 0.7646443843841553, test loss: 1.0493910312652588
train loss: 0.75

train loss: 0.3754676580429077, test loss: 0.6733771562576294
train loss: 0.37395069003105164, test loss: 0.67128986120224
train loss: 0.37244853377342224, test loss: 0.669211745262146
train loss: 0.37096095085144043, test loss: 0.6671422123908997
train loss: 0.3694877326488495, test loss: 0.6650813817977905
train loss: 0.3680287003517151, test loss: 0.6630290150642395
train loss: 0.36658376455307007, test loss: 0.6609850525856018
train loss: 0.36515265703201294, test loss: 0.6589497923851013
train loss: 0.363735169172287, test loss: 0.6569229364395142
train loss: 0.36233118176460266, test loss: 0.6549041867256165
train loss: 0.360940545797348, test loss: 0.652894139289856
train loss: 0.3595629930496216, test loss: 0.6508920192718506
train loss: 0.35819846391677856, test loss: 0.648898184299469
train loss: 0.35684677958488464, test loss: 0.6469125747680664
train loss: 0.35550767183303833, test loss: 0.6449349522590637
train loss: 0.35418105125427246, test loss: 0.6429653167724609
train

train loss: 0.255715936422348, test loss: 0.458943635225296
train loss: 0.25519028306007385, test loss: 0.4577043056488037
train loss: 0.2546680271625519, test loss: 0.45646968483924866
train loss: 0.2541491985321045, test loss: 0.4552396237850189
train loss: 0.2536337673664093, test loss: 0.4540143311023712
train loss: 0.25312161445617676, test loss: 0.4527934193611145
train loss: 0.252612829208374, test loss: 0.4515772759914398
train loss: 0.25210732221603394, test loss: 0.45036572217941284
train loss: 0.2516050636768341, test loss: 0.4491584300994873
train loss: 0.25110602378845215, test loss: 0.4479561746120453
train loss: 0.25061023235321045, test loss: 0.44675812125205994
train loss: 0.25011754035949707, test loss: 0.4455646276473999
train loss: 0.24962802231311798, test loss: 0.44437548518180847
train loss: 0.2491416186094284, test loss: 0.4431908428668976
train loss: 0.24865829944610596, test loss: 0.4420106112957001
train loss: 0.24817800521850586, test loss: 0.440834879875183

train loss: 0.20768055319786072, test loss: 0.32950064539909363
train loss: 0.2074274718761444, test loss: 0.3287227749824524
train loss: 0.207175612449646, test loss: 0.3279477059841156
train loss: 0.2069249451160431, test loss: 0.327175110578537
train loss: 0.2066754698753357, test loss: 0.3264053463935852
train loss: 0.2064271718263626, test loss: 0.3256380259990692
train loss: 0.20618005096912384, test loss: 0.32487329840660095
train loss: 0.2059340924024582, test loss: 0.3241112530231476
train loss: 0.20568929612636566, test loss: 0.32335180044174194
train loss: 0.20544564723968506, test loss: 0.3225947916507721
train loss: 0.20520314574241638, test loss: 0.32184040546417236
train loss: 0.20496176183223724, test loss: 0.32108867168426514
train loss: 0.20472148060798645, test loss: 0.320339560508728
train loss: 0.20448236167430878, test loss: 0.31959283351898193
train loss: 0.20424431562423706, test loss: 0.3188488483428955
train loss: 0.20400740206241608, test loss: 0.318107068538

train loss: 0.18249841034412384, test loss: 0.24737140536308289
train loss: 0.18235434591770172, test loss: 0.24687977135181427
train loss: 0.18221086263656616, test loss: 0.24639010429382324
train loss: 0.18206791579723358, test loss: 0.24590209126472473
train loss: 0.18192556500434875, test loss: 0.24541598558425903
train loss: 0.1817837804555893, test loss: 0.24493153393268585
train loss: 0.1816425323486328, test loss: 0.24444903433322906
train loss: 0.1815018653869629, test loss: 0.2439681887626648
train loss: 0.18136171996593475, test loss: 0.24348925054073334
train loss: 0.18122217059135437, test loss: 0.24301210045814514
train loss: 0.18108315765857697, test loss: 0.24253669381141663
train loss: 0.18094466626644135, test loss: 0.24206317961215973
train loss: 0.18080675601959229, test loss: 0.24159127473831177
train loss: 0.180669367313385, test loss: 0.24112127721309662
train loss: 0.1805325299501419, test loss: 0.24065302312374115
train loss: 0.1803962141275406, test loss: 0.24

train loss: 0.16776396334171295, test loss: 0.19726823270320892
train loss: 0.16767659783363342, test loss: 0.19697992503643036
train loss: 0.16758956015110016, test loss: 0.19669285416603088
train loss: 0.16750288009643555, test loss: 0.19640712440013885
train loss: 0.167416512966156, test loss: 0.19612273573875427
train loss: 0.16733042895793915, test loss: 0.19583962857723236
train loss: 0.16724471747875214, test loss: 0.1955578774213791
train loss: 0.1671593189239502, test loss: 0.19527745246887207
train loss: 0.16707424819469452, test loss: 0.19499823451042175
train loss: 0.1669894903898239, test loss: 0.19472041726112366
train loss: 0.16690503060817719, test loss: 0.19444383680820465
train loss: 0.16682091355323792, test loss: 0.19416852295398712
train loss: 0.16673710942268372, test loss: 0.19389458000659943
train loss: 0.1666536033153534, test loss: 0.19362187385559082
train loss: 0.16657042503356934, test loss: 0.19335044920444489
train loss: 0.16648755967617035, test loss: 0.

train loss: 0.1587684452533722, test loss: 0.169694721698761
train loss: 0.1587149053812027, test loss: 0.16955094039440155
train loss: 0.15866157412528992, test loss: 0.16940808296203613
train loss: 0.15860843658447266, test loss: 0.16926607489585876
train loss: 0.15855549275875092, test loss: 0.169125035405159
train loss: 0.1585027575492859, test loss: 0.16898487508296967
train loss: 0.15845021605491638, test loss: 0.16884556412696838
train loss: 0.1583978533744812, test loss: 0.16870716214179993
train loss: 0.15834569931030273, test loss: 0.1685696691274643
train loss: 0.15829375386238098, test loss: 0.1684330701828003
train loss: 0.15824200212955475, test loss: 0.16829733550548553
train loss: 0.15819044411182404, test loss: 0.16816245019435883
train loss: 0.15813905000686646, test loss: 0.16802851855754852
train loss: 0.1580878645181656, test loss: 0.16789540648460388
train loss: 0.15803688764572144, test loss: 0.16776320338249207
train loss: 0.15798607468605042, test loss: 0.16763

In [5]:
# Pick the best position per test row (raw features, no position column),
# then score those picks with the project metric.
position_predictions = model.predict_positions(test_features=test_pool.features)
test_metric = metric(
    position_predictions,
    test_pool.positions,
    test_pool.targets,
    test_pool.probas,
)
test_metric

-0.611597130734467