In [44]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from GomokuBoard import GomokuBoard
from Heuristics import Heuristics
from GomokuTools import GomokuTools as gt
from HeuristicPolicy import HeuristicGomokuPolicy
from google.cloud import bigquery
import google.datalab.bigquery as bq
import tensorflow_transform.tf_metadata as metadata
import datetime
import tempfile
import tensorflow_transform.beam.impl as beam_impl
import tensorflow_transform as tft
import apache_beam as beam

In [4]:
# NOTE(review): `{'game', "A"}` is a set literal, not a dict; if a mapping was
# intended it should read {'game': "A"} — confirm intent. The name also shadows
# the builtin `input` for the rest of the kernel session. No later cell visible
# in this notebook references it.
input = {'game', "A"}

In [5]:
# Side length of the N_p x N_p grid; later cells reshape the flat 'state'
# feature back to (N_p, N_p, 2).
N_p = 5

# Flat fixed-length float features: 2 values per grid cell for 'state' and
# 1 value per grid cell for 'qvalue'. Used both to derive the schema and to
# parse serialized tf.Example records.
feature_spec = dict(
    state=tf.FixedLenFeature([2 * N_p * N_p], tf.float32),
    qvalue=tf.FixedLenFeature([N_p * N_p], tf.float32),
)

In [6]:
# Derive a tf.Transform dataset schema from feature_spec; later cells hand it
# to tft.coders.ExampleProtoCoder to serialize records.
schema = metadata.dataset_schema.from_feature_spec(feature_spec)

In [7]:
def create_data(ignore_me, n_games=3, board_size=5):
    """Generate a random batch of (state, q-value) arrays for smoke-testing.

    Args:
        ignore_me: Unused. Present so the function can be called with a single
            positional argument.
        n_games: Number of samples along the leading axis. Defaults to 3.
        board_size: Length of each of the two spatial axes. Defaults to 5.

    Returns:
        A 2-tuple ``(states, qvalues)``. ``states`` is an integer array of 0/1
        values with shape ``(n_games, board_size, board_size, 2)``;
        ``qvalues`` is a float array drawn uniformly from [0, 1) with shape
        ``(n_games, board_size, board_size, 1)``.
    """
    # Draw order (states first, then q-values) is kept so that a seeded global
    # RNG yields the same arrays as before for the default sizes.
    states = np.random.randint(0, 2, size=[n_games, board_size, board_size, 2])
    qvalues = np.random.uniform(size=[n_games, board_size, board_size, 1])
    return states, qvalues

# Build one random batch and display the shapes of its two arrays.
# NOTE(review): `data` is rebound to a plain list in the "Pipelines" section,
# so cells that index data[0][...] must run before that rebinding.
data = create_data("whatever")
data[0].shape, data[1].shape

((3, 5, 5, 2), (3, 5, 5, 1))

In [8]:
# Channel-first view of the third sample's state (last axis moved to the
# front); kept as the reference for the read-back comparison further down.
state2 = np.moveaxis(data[0][2], 2, 0)

In [9]:
# Project / storage settings and the Beam runner used by the pipeline cells.
# BUCKET is not referenced by any cell visible in this notebook.
PROJECT = 'going-tfx'
BUCKET = 'going-tfx'
LOCAL_TMPDIR = "/tmp"
OUTPUT_DIR = "./out"
runner = 'DirectRunner'

# Job name = fixed prefix plus a yymmdd-HHMMSS timestamp taken when the cell runs.
job_name = '-'.join([
    'tournament_data',
    datetime.datetime.now().strftime('%y%m%d-%H%M%S'),
])

# Keyword options forwarded to PipelineOptions; staging lives below the
# temp location inside OUTPUT_DIR.
options = dict(
    staging_location=os.path.join(OUTPUT_DIR, 'tmp', 'staging'),
    temp_location=os.path.join(OUTPUT_DIR, 'tmp'),
    job_name=job_name,
    project=PROJECT,
    max_num_workers=24,
    teardown_policy='TEARDOWN_ALWAYS',
    no_save_main_session=True,
    requirements_file='requirements.txt',
)
opts = beam.pipeline.PipelineOptions(flags=[], **options)

In [10]:
def _floats_feature(value):
    """Wrap a sequence of numbers as a float-list ``tf.train.Feature``.

    Args:
        value: Sequence of values to store; the notebook passes flattened
            numpy arrays.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

In [11]:
# Re-display the shapes of the two arrays in the batch (same check as in the
# cell that created `data`).
data[0].shape, data[1].shape

((3, 5, 5, 2), (3, 5, 5, 1))

In [12]:
# Pair each state array with its q-value array: one (state, qvalue) tuple per sample.
s_and_q = [(state, qvalue) for state, qvalue in zip(data[0], data[1])]

In [13]:
# Unpack the first (state, q-value) pair and display the per-sample shapes.
s0, q0 = s_and_q[0]
s0.shape, q0.shape

((5, 5, 2), (5, 5, 1))

In [14]:
# Encode the first sample's flattened q-values as a Feature and display the
# resulting protobuf message.
f = _floats_feature(q0.flatten())
f

float_list {
  value: 0.0209263414144516
  value: 0.3686140775680542
  value: 0.9081296920776367
  value: 0.8073073029518127
  value: 0.8000392913818359
  value: 0.8261414766311646
  value: 0.4063945710659027
  value: 0.2874433994293213
  value: 0.10769154131412506
  value: 0.4927733838558197
  value: 0.4382688105106354
  value: 0.8112466335296631
  value: 0.30836591124534607
  value: 0.6808125972747803
  value: 0.8147794604301453
  value: 0.35377880930900574
  value: 0.40647491812705994
  value: 0.28317582607269287
  value: 0.29440751671791077
  value: 0.2760884463787079
  value: 0.23128415644168854
  value: 0.3272022008895874
  value: 0.9716960787773132
  value: 0.9115557074546814
  value: 0.2744579315185547
}

In [15]:
tfr_filename = "deleteme.tfr"

# Serialize every (state, q-value) pair as one tf.Example record, with both
# arrays flattened to 1-D float lists.
with tf.python_io.TFRecordWriter(tfr_filename) as writer:
    for state, qvalue in s_and_q:
        features = tf.train.Features(feature={
            'state': _floats_feature(state.flatten()),
            'qvalue': _floats_feature(qvalue.flatten()),
        })
        example = tf.train.Example(features=features)
        writer.write(example.SerializeToString())

### Read from File

In [16]:
def _parse_function(example):
    """Decode one serialized tf.Example according to ``feature_spec``.

    Args:
        example: A serialized ``tf.Example`` (one element of a TFRecord dataset).

    Returns:
        The dict produced by ``tf.parse_single_example``, keyed by the feature
        names in ``feature_spec`` ('state' and 'qvalue').
    """
    parsed = tf.parse_single_example(example, feature_spec)
    return parsed

In [17]:
# Read back the records written by the TFRecordWriter cell. Reuse tfr_filename
# rather than repeating the literal so the reader cannot drift from the path
# the writer used.
dataset = tf.data.TFRecordDataset(tfr_filename)

In [18]:
# Parse each record, then expose the "next element" op of a one-shot iterator.
decoded = (
    dataset
    .map(_parse_function)
    .make_one_shot_iterator()
    .get_next()
)

In [19]:
# Display the dict of symbolic tensors; actual values require a session run.
decoded

{'qvalue': <tf.Tensor 'IteratorGetNext:0' shape=(25,) dtype=float32>,
 'state': <tf.Tensor 'IteratorGetNext:1' shape=(50,) dtype=float32>}

In [20]:
# Evaluate the iterator three times and keep only the last result, i.e. the
# third record in the file.
with tf.Session() as sess:
    res2 = [sess.run(decoded) for _ in range(3)][-1]

In [21]:
# Shapes of the flat arrays read back from the file.
res2['state'].shape, res2['qvalue'].shape

((50,), (25,))

In [22]:
# Restore the (N_p, N_p, 2) layout, move the last axis to the front, and show
# the resulting shape.
np.moveaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0).shape

(2, 5, 5)

In [23]:
# Round-trip check: rebuild the channel-first state from the record read back
# from the TFRecord file (res2 holds the third record) and compare it with
# state2, which was derived from the in-memory third sample. Deriving state2_p
# from `data` again would compare an array with itself and always yield True.
state2_p = np.rollaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0)
np.equal(state2, state2_p).all()

True

### Pipelines


In [24]:
# Three dummy elements to drive the in-memory pipeline below; create_games
# does not use the element value.
# NOTE(review): this rebinds `data`, which earlier cells use for the
# (states, qvalues) tuple — those cells no longer work if re-run after this one.
data = [1,2,3]

In [25]:
def create_games(ignore_me, n_games=3, board_size=5):
    """Generate a random batch of (state, q-value) arrays.

    Used in this notebook as a per-element callback for ``beam.Map``, which is
    why it accepts (and discards) one positional argument.

    Args:
        ignore_me: Unused pipeline element.
        n_games: Number of samples along the leading axis. Defaults to 3.
        board_size: Length of each of the two spatial axes. Defaults to 5.

    Returns:
        A 2-tuple ``(states, qvalues)``. ``states`` is an integer array of 0/1
        values with shape ``(n_games, board_size, board_size, 2)``;
        ``qvalues`` is a float array drawn uniformly from [0, 1) with shape
        ``(n_games, board_size, board_size, 1)``.
    """
    # Draw order (states first, then q-values) is kept so that a seeded global
    # RNG yields the same arrays as before for the default sizes.
    states = np.random.randint(0, 2, size=[n_games, board_size, board_size, 2])
    qvalues = np.random.uniform(size=[n_games, board_size, board_size, 1])
    return states, qvalues

In [26]:
# One random batch for inspecting create_games outside of a pipeline.
games = create_games("whatever")

In [27]:
# Shapes of the state and q-value arrays in the batch.
games[0].shape, games[1].shape

((3, 5, 5, 2), (3, 5, 5, 1))

In [28]:
def recwise(games):
    """Turn a batched (states, qvalues) pair into one flat record per sample.

    Args:
        games: Indexable whose element 0 holds the per-sample state arrays and
            element 1 the per-sample q-value arrays.

    Returns:
        List of dicts, one per sample, with the flattened arrays stored under
        the keys 'state' and 'qvalue'.
    """
    records = []
    for state, qvalue in zip(games[0], games[1]):
        records.append({'state': state.flatten(), 'qvalue': qvalue.flatten()})
    return records

In [29]:
# Apply the transforms directly to the in-memory list: each element yields a
# random batch, which FlatMap then splits into per-sample records.
res = (
    data
    | beam.Map(create_games)
    | beam.FlatMap(recwise)
)

In [30]:
# Flattened q-values of the first record.
res[0]['qvalue']

array([4.14327187e-01, 1.77229180e-01, 3.84385350e-04, 2.01961726e-01,
       3.90489468e-01, 2.12715784e-01, 3.74710327e-01, 1.89744475e-01,
       7.64024794e-01, 2.54353764e-01, 2.58991329e-01, 7.24080874e-01,
       8.51405265e-01, 5.33671783e-01, 9.80782910e-01, 3.10965705e-01,
       4.52187551e-01, 1.68917436e-01, 5.68908672e-01, 6.10763193e-01,
       4.49291984e-01, 9.34408653e-01, 8.57797200e-01, 1.14181199e-01,
       5.95960327e-01])

In [31]:
# Flattened state of the first record.
res[0]['state']

array([1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1,
       0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0,
       1, 1, 0, 1, 0, 1])

In [32]:
# Coder that serializes record dicts to tf.Example protos according to `schema`.
tfr_encoder = tft.coders.ExampleProtoCoder(schema)

In [33]:
# Serialize the first record and display the resulting bytes.
tfr_encoder.encode(res[0])

b'\n\xcf\x02\n\xd8\x01\n\x05state\x12\xce\x01\x12\xcb\x01\n\xc8\x01\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\nr\n\x06qvalue\x12h\x12f\nd\xb1"\xd4>\x91{5>T\x87\xc99\x0e\xcfN><\xee\xc7>+\xd2Y>\x08\xda\xbf>`LB>!\x97C?\xa8:\x82>\x83\x9a\x84>]]9?\xb2\xf5Y?\xb7\x9e\x08?\x97\x14{?\xe66\x9f> \x85\xe7>\xb1\xf8,>\x00\xa4\x11

### Pipe to TFRecord

In [34]:
# BigQuery -> random games -> TFRecord file(s).
# The query result only determines how many times create_games is invoked;
# create_games does not use the row contents.
query = "select distinct(game) from `going-tfx.gomoku.tournaments` limit 2"
out_name="games"
out_prefix = os.path.join(LOCAL_TMPDIR, out_name)
phase='train'
with beam.Pipeline(runner, options=opts) as p:
    # NOTE(review): the directory created by tempfile.mkdtemp() is not removed
    # by this cell.
    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):


        # Read the query result from BigQuery (standard SQL dialect).
        #
        from_bq = p | "ReadFromBigQuery"  >> beam.io.Read(beam.io.BigQuerySource(
            query=query, use_standard_sql=True)) 

        # Generate a random batch per row, split it into per-sample records,
        # serialize each record as a tf.Example and write to TFRecord file(s)
        # whose names start with out_prefix + '_tfr'.
        #
        tfr_encoder = tft.coders.ExampleProtoCoder(schema)
        res = (from_bq
               | beam.Map(create_games)
               | beam.FlatMap(recwise)
               | ('EncodeTFRecord_' + phase) >> beam.Map(tfr_encoder.encode)
               | ('WriteTFRecord_' + phase) >> beam.io.WriteToTFRecord(out_prefix+'_tfr'))

# Display the output file prefix.
out_prefix + '_tfr'



'/tmp/games_tfr'

### Read from File

In [35]:
def _parse_function(example):
    """Decode one serialized tf.Example according to ``feature_spec``.

    NOTE(review): this repeats the `_parse_function` definition from the first
    "Read from File" section of this notebook; the later definition shadows
    the earlier one.

    Args:
        example: A serialized ``tf.Example`` (one element of a TFRecord dataset).

    Returns:
        The dict produced by ``tf.parse_single_example``, keyed by the feature
        names in ``feature_spec`` ('state' and 'qvalue').
    """
    parsed = tf.parse_single_example(example, feature_spec)
    return parsed

In [36]:
# Read the shard written by the pipeline cell. The path is derived from
# out_prefix so it stays in sync with the prefix the writer used.
# NOTE(review): the '-00000-of-00001' suffix assumes the writer produced
# exactly one shard — confirm if the runner or data volume changes.
dataset = tf.data.TFRecordDataset(out_prefix + '_tfr-00000-of-00001')

In [37]:
# Display the dataset object: scalar string elements, i.e. still-serialized records.
dataset

<TFRecordDatasetV1 shapes: (), types: tf.string>

In [38]:
# Dataset limited to the first record.
# NOTE(review): `record` is not used by any later cell visible in this notebook.
record = dataset.take(1)

In [39]:
# Parse each record, then expose the "next element" op of a one-shot iterator.
decoded = (
    dataset
    .map(_parse_function)
    .make_one_shot_iterator()
    .get_next()
)

In [40]:
# Display the dict of symbolic tensors; actual values require a session run.
decoded

{'qvalue': <tf.Tensor 'IteratorGetNext_1:0' shape=(25,) dtype=float32>,
 'state': <tf.Tensor 'IteratorGetNext_1:1' shape=(50,) dtype=float32>}

In [41]:
# Evaluate the iterator three times and keep only the last result, i.e. the
# third record in the file.
with tf.Session() as sess:
    res2 = [sess.run(decoded) for _ in range(3)][-1]

In [42]:
# Shapes of the flat arrays read back from the pipeline's output file.
res2['state'].shape, res2['qvalue'].shape

((50,), (25,))

In [43]:
# Restore the (N_p, N_p, 2) layout and move the last axis to the front so the
# two planes are displayed separately.
np.moveaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0)

array([[[0., 1., 1., 0., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [1., 1., 0., 0., 1.]],

       [[0., 0., 0., 1., 0.],
        [1., 0., 1., 1., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.]]], dtype=float32)