In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from GomokuBoard import GomokuBoard
from Heuristics import Heuristics
from GomokuTools import GomokuTools as gt
from HeuristicPolicy import HeuristicGomokuPolicy
from google.cloud import bigquery
import google.datalab.bigquery as bq
import tensorflow_transform.tf_metadata as metadata
import datetime
import tempfile
import tensorflow_transform.beam.impl as beam_impl
import tensorflow_transform as tft
import apache_beam as beam

  'Running the Apache Beam SDK on Python 3 is not yet fully supported. '


In [2]:
# NOTE(review): this is a *set* literal ({'game': "A"} may have been the
# intent), and the name shadows the built-in `input()`. It is not referenced
# anywhere else in the visible notebook.
input = set(['game', "A"])

In [3]:
N_p = 5  # board side length; decoded states are reshaped to (N_p, N_p, 2)

# Parsing spec for the serialized records: `state` holds two N_p x N_p planes
# flattened into one vector, `qvalue` holds one N_p x N_p plane flattened.
feature_spec = dict(
    state=tf.FixedLenFeature([2 * N_p * N_p], tf.float32),
    qvalue=tf.FixedLenFeature([N_p * N_p], tf.float32),
)

In [4]:
# tf.Transform schema derived from `feature_spec`; it is used to build the
# ExampleProtoCoder that serializes records to tf.Example bytes.
schema = metadata.dataset_schema.from_feature_spec(feature_spec)

In [5]:
def create_data(ignore_me, n_records=3, board_size=5):
    """Generate a batch of random (state, q-value) boards for prototyping.

    Args:
        ignore_me: Unused placeholder argument.
        n_records: Number of boards to generate (default 3).
        board_size: Side length of the square board (default 5, matching
            ``N_p`` in the feature spec).

    Returns:
        A 2-tuple ``(states, qvalues)``. ``states`` is an int array of shape
        ``(n_records, board_size, board_size, 2)`` with entries in {0, 1}.
        ``qvalues`` is a float array of shape
        ``(n_records, board_size, board_size, 1)`` drawn uniformly from [0, 1).
    """
    states = np.random.randint(0, 2, size=[n_records, board_size, board_size, 2])
    qvalues = np.random.uniform(size=[n_records, board_size, board_size, 1])
    return states, qvalues

# Generate one random batch and show the shape of each of its two arrays.
data = create_data(ignore_me="whatever")
tuple(part.shape for part in data)

((3, 5, 5, 2), (3, 5, 5, 1))

In [6]:
# Channels-first copy of the third generated board: (5, 5, 2) -> (2, 5, 5).
# np.moveaxis is NumPy's recommended replacement for np.rollaxis; moving
# axis 2 to position 0 gives the same result with either function.
state2 = np.moveaxis(data[0][2], 2, 0)

In [7]:
PROJECT = 'going-tfx'
BUCKET = 'going-tfx'  # NOTE(review): not referenced in the visible notebook
LOCAL_TMPDIR = "/tmp"
OUTPUT_DIR = "./out"
runner = 'DirectRunner'
# Dataflow job names may only contain lowercase letters, digits and '-'
# (no underscores), so use a hyphen to keep the name valid if `runner` is
# switched to 'DataflowRunner'.
job_name = 'tournament-data' + '-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')

# Options passed to beam.Pipeline. Worker count and teardown policy are
# presumably only consulted by the Dataflow runner — confirm for DirectRunner.
options = {
    'staging_location': os.path.join(OUTPUT_DIR, 'tmp', 'staging'),
    'temp_location': os.path.join(OUTPUT_DIR, 'tmp'),
    'job_name': job_name,
    'project': PROJECT,
    'max_num_workers': 24,
    'teardown_policy': 'TEARDOWN_ALWAYS',
    'no_save_main_session': True,
    'requirements_file': 'requirements.txt'
}
opts = beam.pipeline.PipelineOptions(flags=[], **options)

In [8]:
def _floats_feature(value):
    """Return a ``tf.train.Feature`` whose FloatList holds ``value``.

    ``value`` should be a flat iterable of numbers; callers in this notebook
    pass ``ndarray.flatten()`` results.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

In [9]:
# Shapes of the (states, qvalues) arrays in `data`.
tuple(part.shape for part in data)

((3, 5, 5, 2), (3, 5, 5, 1))

In [10]:
# Pair each state board with its q-value board, one tuple per record.
s_and_q = [(state, qvalue) for state, qvalue in zip(data[0], data[1])]

In [11]:
# First (state, q-value) pair and its shapes.
s0, q0 = s_and_q[0]
(s0.shape, q0.shape)

((5, 5, 2), (5, 5, 1))

In [12]:
# Encode the first q-value board as a flat FloatList feature and display it.
f = _floats_feature(q0.reshape(-1))
f

float_list {
  value: 0.38569673895835876
  value: 0.9628795385360718
  value: 0.9119804501533508
  value: 0.8225455284118652
  value: 0.23530671000480652
  value: 0.48420655727386475
  value: 0.8617763519287109
  value: 0.926729142665863
  value: 0.7061881422996521
  value: 0.49137768149375916
  value: 0.9709430932998657
  value: 0.7995877265930176
  value: 0.8855231404304504
  value: 0.4085104465484619
  value: 0.7637038826942444
  value: 0.5410135984420776
  value: 0.12000403553247452
  value: 0.3285449147224426
  value: 0.39131662249565125
  value: 0.21633830666542053
  value: 0.8980754017829895
  value: 0.06974085420370102
  value: 0.5148058533668518
  value: 0.3054250180721283
  value: 0.19905908405780792
}

In [13]:
tfr_filename = "deleteme.tfr"

# Serialize each (state, qvalue) pair as one tf.train.Example record. Both
# arrays are flattened because `feature_spec` declares 1-D fixed-length
# features.
with tf.python_io.TFRecordWriter(tfr_filename) as writer:
    for state_arr, qvalue_arr in s_and_q:
        features = tf.train.Features(feature={
            'state': _floats_feature(state_arr.flatten()),
            'qvalue': _floats_feature(qvalue_arr.flatten()),
        })
        example = tf.train.Example(features=features)
        writer.write(example.SerializeToString())

### Read back from the TFRecord file

In [14]:
def _parse_function(example):
    """Decode one serialized tf.train.Example into a dict of dense tensors per `feature_spec`."""
    parsed = tf.parse_single_example(example, feature_spec)
    return parsed

In [15]:
# Reuse the filename from the writer cell instead of repeating the literal,
# so the two cells cannot drift apart.
dataset = tf.data.TFRecordDataset(tfr_filename)

In [16]:
# Symbolic next-element op: parse each record, then pull records one at a time.
decoded = (
    dataset
    .map(_parse_function)
    .make_one_shot_iterator()
    .get_next()
)

In [17]:
# Display the dict of symbolic tensors produced by the parse step; nothing is
# evaluated until sess.run is called on it.
decoded

{'qvalue': <tf.Tensor 'IteratorGetNext:0' shape=(25,) dtype=float32>,
 'state': <tf.Tensor 'IteratorGetNext:1' shape=(50,) dtype=float32>}

In [18]:
with tf.Session() as sess:
    # Step past the first two records and keep the third one in `res2`.
    for _skipped in range(2):
        sess.run(decoded)
    res2 = sess.run(decoded)

In [19]:
# Flat lengths of the decoded state and q-value vectors.
tuple(res2[key].shape for key in ('state', 'qvalue'))

((50,), (25,))

In [20]:
# Un-flatten the decoded state to (N_p, N_p, 2), then move the channel axis
# first to get (2, N_p, N_p). np.moveaxis is NumPy's recommended replacement
# for np.rollaxis.
np.moveaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0).shape

(2, 5, 5)

In [21]:
# Round-trip check. `res2` is the third record read back from the file (two
# records were skipped before it), so its state must equal the third generated
# board, `state2`. Recomputing the comparison value from `data` instead would
# compare `state2` with itself and always be True.
# np.array_equal also returns False (rather than broadcasting) on a shape mismatch.
state2_p = np.moveaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0)
np.array_equal(state2, state2_p)

True

### Pipelines


In [22]:
# Three dummy seed elements for the eager Beam run below; create_games ignores
# its argument. NOTE(review): this rebinds `data`, which previously held the
# (states, qvalues) tuple.
data = list(range(1, 4))

In [23]:
def create_games(ignore_me, n_records=3, board_size=5):
    """Generate a batch of random (state, q-value) boards for the pipeline.

    Used as a ``beam.Map`` callable: it is called once per input element and
    ignores that element.

    Args:
        ignore_me: The pipeline element; unused.
        n_records: Number of boards to generate (default 3).
        board_size: Side length of the square board (default 5, matching
            ``N_p`` in the feature spec).

    Returns:
        A 2-tuple ``(states, qvalues)``. ``states`` is an int array of shape
        ``(n_records, board_size, board_size, 2)`` with entries in {0, 1}.
        ``qvalues`` is a float array of shape
        ``(n_records, board_size, board_size, 1)`` drawn uniformly from [0, 1).
    """
    # Import locally so the function still resolves numpy on remote Beam
    # workers, where notebook globals are not necessarily available (the
    # pipeline options set 'no_save_main_session').
    import numpy as np

    states = np.random.randint(0, 2, size=[n_records, board_size, board_size, 2])
    qvalues = np.random.uniform(size=[n_records, board_size, board_size, 1])
    return states, qvalues

In [24]:
# One random batch for inspecting the record layout outside of Beam.
games = create_games(ignore_me="whatever")

In [25]:
# Shapes of the (states, qvalues) arrays in `games`.
tuple(part.shape for part in games)

((3, 5, 5, 2), (3, 5, 5, 1))

In [26]:
def recwise(games):
    """Split a (states, qvalues) batch into one flat feature dict per record.

    Returns a list with one ``{'state': ..., 'qvalue': ...}`` dict for each
    paired row of ``games[0]`` and ``games[1]``. Each value is a 1-D copy
    (``ndarray.flatten``) of that row.
    """
    records = []
    for state, qvalue in zip(games[0], games[1]):
        records.append({'state': state.flatten(), 'qvalue': qvalue.flatten()})
    return records

In [27]:
# Applying PTransforms to a plain list runs them eagerly and returns a list:
# one (states, qvalues) batch per seed element, then one dict per record.
res = (
    data
    | beam.Map(create_games)
    | beam.FlatMap(recwise)
)

In [28]:
# Flattened q-values of the first record.
res[0]['qvalue']

array([0.04658873, 0.90594238, 0.663904  , 0.18171573, 0.33790172,
       0.23647425, 0.76201465, 0.06078972, 0.70843801, 0.93941081,
       0.41140984, 0.15343284, 0.08211694, 0.43258062, 0.89847012,
       0.74435248, 0.16872548, 0.13533016, 0.26685144, 0.0087633 ,
       0.35557567, 0.92767592, 0.45090179, 0.55343526, 0.93775894])

In [29]:
# Flattened two-plane state of the first record (0/1 ints before encoding).
res[0]['state']

array([1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0,
       1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0,
       0, 0, 0, 1, 1, 0])

In [30]:
# Coder that serializes a feature dict (keys and lengths as in `schema`) into
# tf.Example bytes; it is exercised on one record in the next cell.
tfr_encoder = tft.coders.ExampleProtoCoder(schema)

In [31]:
# Serialize one record; the displayed result is the encoded bytes (integer
# state values are stored as float32, per the schema).
tfr_encoder.encode(res[0])

b'\n\xcf\x02\nr\n\x06qvalue\x12h\x12f\nd\xd2\xd3>=\xd7\xebg?\x9d\xf5)?\xb0\x13:>t\x01\xad>N&r>d\x13C?\xa3\xfex=2\\5?:}p?P\xa4\xd2>\x80\x1d\x1d>\xee,\xa8=5{\xdd>#\x02f?\xe2\x8d>?_\xc6,>\xfd\x93\n>\xc0\xa0\x88>\xf5\x93\x0f<\x04\x0e\xb6>+|m?\x99\xdc\xe6>\xef\xad\r?\xf8\x10p?\n\xd8\x01\n\x05state\x12\xce\x01\x12\xcb\x01\n\xc8\x01\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x80?\x00\x00\x80?\x00\x00\x80?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\

### Pipeline: BigQuery to TFRecord

In [55]:
query = "select distinct(game) from `going-tfx.gomoku.tournaments` limit 2"
out_name = "games"
out_prefix = os.path.join(LOCAL_TMPDIR, out_name)
tfr_prefix = out_prefix + '_tfr'  # output path prefix; shards get a -SSSSS-of-NNNNN suffix
phase = 'train'
with beam.Pipeline(runner, options=opts) as p:
    # NOTE(review): the mkdtemp() directory is never removed, and no
    # tf.Transform analyzer is used below, so this Context may be unnecessary.
    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):

        # Read from BigQuery. Each row only acts as a trigger: create_games
        # ignores its argument and generates random boards.
        from_bq = p | "ReadFromBigQuery" >> beam.io.Read(beam.io.BigQuerySource(
            query=query, use_standard_sql=True))

        # Encode back to file(s). The result gets its own name so it does not
        # clobber `res`, the eagerly computed record list from the previous section.
        tfr_encoder = tft.coders.ExampleProtoCoder(schema)
        tfr_written = (from_bq
                       | beam.Map(create_games)
                       | beam.FlatMap(recwise)
                       | ('EncodeTFRecord_' + phase) >> beam.Map(tfr_encoder.encode)
                       | ('WriteTFRecord_' + phase) >> beam.io.WriteToTFRecord(tfr_prefix))

tfr_prefix



'/tmp/games_tfr'

### Read the pipeline's TFRecord output back

In [33]:
# `_parse_function` is already defined, identically, in the parsing cell that
# follows the first TFRecord write. It is deliberately not redefined here, so
# there is a single definition of how records are parsed.

In [45]:
# Read every shard the pipeline wrote under its output prefix, rather than
# hard-coding the single-shard name "/tmp/games_tfr-00000-of-00001".
tfr_shards = sorted(tf.gfile.Glob(out_prefix + '_tfr-*'))
assert tfr_shards, 'no TFRecord shards found for prefix ' + out_prefix + '_tfr'
dataset = tf.data.TFRecordDataset(tfr_shards)

In [46]:
# Display the dataset: each element is one serialized record (a scalar string).
dataset

<TFRecordDatasetV1 shapes: (), types: tf.string>

In [43]:
# Dataset limited to its first record. NOTE(review): `record` is not used
# anywhere below.
record = dataset.take(count=1)

In [47]:
# Symbolic next-element op over the pipeline output: parse each record, then
# pull records one at a time.
decoded = (
    dataset
    .map(_parse_function)
    .make_one_shot_iterator()
    .get_next()
)

In [48]:
# Display the dict of symbolic tensors for the pipeline output; nothing is
# evaluated until sess.run is called on it.
decoded

{'qvalue': <tf.Tensor 'IteratorGetNext_2:0' shape=(25,) dtype=float32>,
 'state': <tf.Tensor 'IteratorGetNext_2:1' shape=(50,) dtype=float32>}

In [49]:
with tf.Session() as sess:
    # Step past the first two records and keep the third one in `res2`.
    for _skipped in range(2):
        sess.run(decoded)
    res2 = sess.run(decoded)

In [50]:
# Flat lengths of the decoded state and q-value vectors.
tuple(res2[key].shape for key in ('state', 'qvalue'))

((50,), (25,))

In [52]:
# Show the decoded state as two N_p x N_p planes: un-flatten to (N_p, N_p, 2),
# then move the channel axis first. np.moveaxis is NumPy's recommended
# replacement for np.rollaxis.
np.moveaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0)

array([[[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1.],
        [1., 0., 0., 1., 0.],
        [0., 1., 1., 1., 1.]],

       [[0., 1., 1., 1., 1.],
        [1., 1., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 0., 0., 1., 0.]]], dtype=float32)