In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from wgomoku import GomokuBoard
from wgomoku import Heuristics
from wgomoku import GomokuTools as gt
from wgomoku import HeuristicGomokuPolicy
from google.cloud import bigquery
import google.datalab.bigquery as bq
import tensorflow_transform.tf_metadata as metadata
import datetime
import tempfile
import tensorflow_transform.beam.impl as beam_impl
import tensorflow_transform as tft
import apache_beam as beam

In [None]:
# NOTE(review): this rebinds the builtin `input` and builds a *set*
# {'game', 'A'}, not a dict — presumably a leftover; it is not referenced
# anywhere else in this notebook. Consider removing or renaming it.
input = {'game', "A"}

In [None]:
# Board edge length; both feature widths below are derived from it.
N_p = 5
# Per-example parsing spec: 'state' carries two values per board cell
# (N_p * N_p * 2), 'qvalue' one value per cell (N_p * N_p); both are
# stored as flat float32 vectors.
feature_spec = dict(
    state=tf.FixedLenFeature([N_p * N_p * 2], tf.float32),
    qvalue=tf.FixedLenFeature([N_p * N_p], tf.float32),
)

In [None]:
# Turn feature_spec into tf.Transform dataset metadata; the ExampleProtoCoder
# cells further down are built from this schema.
# NOTE(review): `dataset_schema` is a submodule that the `tf_metadata` import
# at the top does not import explicitly; it presumably becomes available via
# the other tensorflow_transform imports — confirm.
schema = metadata.dataset_schema.from_feature_spec(feature_spec)

In [None]:
def create_data(ignore_me, n_games=3, board_size=5):
    """Generate a batch of random (state, qvalue) training samples.

    Args:
        ignore_me: Unused; every call produces fresh random data.
        n_games: Number of samples in the batch (default 3).
        board_size: Edge length of the square board (default 5). Must equal
            ``N_p`` for the flattened records to match ``feature_spec``.

    Returns:
        A tuple ``(states, qvalues)``: ``states`` is an int array of 0/1
        values with shape ``(n_games, board_size, board_size, 2)``;
        ``qvalues`` is a float array in ``[0, 1)`` with shape
        ``(n_games, board_size, board_size, 1)``.
    """
    states = np.random.randint(0, 2, size=[n_games, board_size, board_size, 2])
    qvalues = np.random.uniform(size=[n_games, board_size, board_size, 1])
    return states, qvalues

data = create_data("whatever")
data[0].shape, data[1].shape

In [None]:
# Channels-first view (2, 5, 5) of the third generated state; it is the
# reference value for the TFRecord round-trip comparison further down.
state2 = np.moveaxis(data[0][2], 2, 0)

In [None]:
# Pipeline configuration.
PROJECT='going-tfx'
# NOTE(review): BUCKET is not referenced anywhere else in this notebook.
BUCKET='going-tfx'
# Local directory that receives the pipeline's TFRecord output (see out_prefix).
LOCAL_TMPDIR="/tmp"
OUTPUT_DIR="./out"
runner='DirectRunner'
# Timestamp suffix keeps job names unique across runs.
job_name = 'tournament_data' + '-' + datetime.datetime.now().strftime('%y%m%d-%H%M%S')    

# NOTE(review): max_num_workers, teardown_policy, staging_location and
# requirements_file look like remote-runner (Dataflow) settings; presumably
# they have no effect under DirectRunner — confirm. Also confirm that Beam
# recognizes the 'no_save_main_session' key rather than discarding it.
options = {
    'staging_location': os.path.join(OUTPUT_DIR, 'tmp', 'staging'),
    'temp_location': os.path.join(OUTPUT_DIR, 'tmp'),
    'job_name': job_name,
    'project': PROJECT,
    'max_num_workers': 24,
    'teardown_policy': 'TEARDOWN_ALWAYS',
    'no_save_main_session': True,
    'requirements_file': 'requirements.txt'
}
opts = beam.pipeline.PipelineOptions(flags=[], **options)

In [None]:
def _floats_feature(value):
    """Return a ``tf.train.Feature`` holding ``value`` as a float list.

    ``value`` is expected to be a 1-D iterable of numbers; the callers in
    this notebook pass ``ndarray.flatten()`` results.
    """
    floats = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=floats)

In [None]:
# Re-display the generated batch shapes: (3, 5, 5, 2) states, (3, 5, 5, 1) q-values.
data[0].shape, data[1].shape

In [None]:
# Pair each game's state with its q-values: one (state, qvalue) tuple per game.
s_and_q = list(zip(data[0], data[1]))

In [None]:
# Inspect the first pair.
s0 = s_and_q[0][0]
q0 = s_and_q[0][1]
s0.shape, q0.shape

In [None]:
# Sanity check: wrap the first game's flattened q-values in a float-list Feature.
f = _floats_feature(q0.flatten())
f

In [None]:
# Serialize every (state, qvalue) pair as a tf.train.Example into a local
# TFRecord file. Both arrays are flattened so their lengths (50 and 25)
# match the 1-D shapes declared in feature_spec.
tfr_filename = "deleteme.tfr"
with tf.python_io.TFRecordWriter(tfr_filename) as writer:
    for vec in s_and_q:
        # Create an example protocol buffer
        # vec[0] is the state, vec[1] the q-values (see how s_and_q is built).
        example = tf.train.Example(features=tf.train.Features(feature={
            'state': _floats_feature(vec[0].flatten()),
            'qvalue' : _floats_feature(vec[1].flatten()),
            }))
        writer.write(example.SerializeToString())

### Read back the local TFRecord file

In [None]:
def _parse_function(example):
    """Decode one serialized ``tf.train.Example`` using ``feature_spec``.

    Returns the dict of tensors keyed by the feature names in
    ``feature_spec`` ('state' and 'qvalue').
    """
    parsed = tf.parse_single_example(serialized=example, features=feature_spec)
    return parsed

In [None]:
# Read the file written above; reuse tfr_filename so the read path cannot
# drift from the write path.
dataset = tf.data.TFRecordDataset(tfr_filename)

In [None]:
# Parse each record with _parse_function; `decoded` is the dict of tensors
# for the iterator's next element.
decoded = dataset.map(_parse_function).make_one_shot_iterator().get_next()

In [None]:
decoded

In [None]:
# Skip the first two records and keep the third, presumably so it lines up
# with data[0][2] (state2) for the round-trip comparison below.
with tf.Session() as sess:
    sess.run(decoded)
    sess.run(decoded)
    res2 = sess.run(decoded)

In [None]:
res2['state'].shape, res2['qvalue'].shape

In [None]:
# Restore the (N_p, N_p, 2) board and move the channel axis first.
np.rollaxis(res2['state'].reshape(N_p,N_p,2), 2, 0).shape

In [None]:
# Round-trip check: rebuild the channel-first board from the record read back
# from disk (res2, the third record) and compare it with the third generated
# state (state2, built from data[0][2]). States are 0/1 values, so comparing
# the int source against the float32 decoded values is exact.
state2_p = np.moveaxis(res2['state'].reshape(N_p, N_p, 2), 2, 0)
np.equal(state2, state2_p).all()

### Beam pipeline on in-memory data


In [None]:
# Seed elements for the in-memory Beam pipeline below: each element triggers
# one create_games call, whose argument is ignored.
# NOTE(review): this reuses the name `data`, which earlier held the
# (states, qvalues) tuple — consider a distinct name.
data = [1,2,3]

In [None]:
def create_games(ignore_me, n_games=3, board_size=5):
    """Generate one random batch of (state, qvalue) samples for the pipeline.

    Args:
        ignore_me: The incoming pipeline element (a seed int or a BigQuery
            row); it is not used, and every call produces fresh random data.
        n_games: Number of games in the batch (default 3).
        board_size: Edge length of the square board (default 5). Must equal
            ``N_p`` so the flattened records match ``feature_spec``.

    Returns:
        ``(states, qvalues)``: a 0/1 int array of shape
        ``(n_games, board_size, board_size, 2)`` and a float array in
        ``[0, 1)`` of shape ``(n_games, board_size, board_size, 1)``.
    """
    shape = [n_games, board_size, board_size]
    states = np.random.randint(0, 2, size=shape + [2])
    qvalues = np.random.uniform(size=shape + [1])
    return states, qvalues

In [None]:
# Generate one standalone batch to inspect outside of Beam.
games = create_games("whatever")

In [None]:
games[0].shape, games[1].shape

In [None]:
def recwise(games):
    """Split a batched ``(states, qvalues)`` pair into one record per game.

    Each record is a dict with the game's state and q-values flattened to
    1-D arrays, under the keys 'state' and 'qvalue'.
    """
    states, qvalues = games[0], games[1]
    records = []
    for state, qvalue in zip(states, qvalues):
        records.append({'state': state.flatten(), 'qvalue': qvalue.flatten()})
    return records

In [None]:
# Pipe the plain seed list through the transforms locally: one random batch
# per seed, flattened into one record dict per game. The cells below index
# the result like a list (res[0]).
res = data | beam.Map(create_games) | beam.FlatMap(recwise)

In [None]:
res[0]['qvalue']

In [None]:
res[0]['state']

In [None]:
# Coder that serializes a record dict into tf.Example bytes according to schema.
tfr_encoder = tft.coders.ExampleProtoCoder(schema)

In [None]:
tfr_encoder.encode(res[0])

### Pipeline: BigQuery to TFRecord files

In [None]:
# Fetch up to two distinct `game` values from BigQuery, turn each row into a
# random batch of games, and write the encoded examples as TFRecord shards
# whose names start with /tmp/games_tfr.
query = "select distinct(game) from `going-tfx.gomoku.tournaments` limit 2"
out_name="games"
out_prefix = os.path.join(LOCAL_TMPDIR, out_name)
phase='train'
# The pipeline executes when the outer `with` block exits.
with beam.Pipeline(runner, options=opts) as p:
    # NOTE(review): the mkdtemp() directory is never removed afterwards.
    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):


        #   Read from Big Query
        #
        # NOTE(review): the rows read here are never inspected — create_games
        # ignores its argument and emits random data for every row.
        from_bq = p | "ReadFromBigQuery"  >> beam.io.Read(beam.io.BigQuerySource(
            query=query, use_standard_sql=True)) 

        # Encode back to file(s)
        #
        # row -> (states, qvalues) batch -> one dict per game -> serialized
        # tf.Example bytes -> TFRecord files named from out_prefix + '_tfr'.
        tfr_encoder = tft.coders.ExampleProtoCoder(schema)
        res = (from_bq
               | beam.Map(create_games)
               | beam.FlatMap(recwise)
               | ('EncodeTFRecord_' + phase) >> beam.Map(tfr_encoder.encode)
               | ('WriteTFRecord_' + phase) >> beam.io.WriteToTFRecord(out_prefix+'_tfr'))

# Display the output prefix; the writer appends shard suffixes to it.
out_prefix + '_tfr'

### Read back the pipeline's TFRecord output

In [None]:
# NOTE(review): identical redefinition of the earlier _parse_function
# (presumably so this section can run on its own); consider keeping only one.
def _parse_function(example):
    """Decode one serialized tf.train.Example according to feature_spec."""
    return tf.parse_single_example(example, feature_spec)

In [None]:
# Read back every shard written by the pipeline instead of hard-coding one
# shard name: WriteToTFRecord appends '-SSSSS-of-NNNNN' suffixes to the
# prefix, and the shard count is not fixed by the pipeline above. Sorting
# keeps the record order deterministic.
tfr_files = sorted(tf.gfile.Glob(out_prefix + '_tfr-*'))
assert tfr_files, "no TFRecord shards found for prefix " + out_prefix + '_tfr'
dataset = tf.data.TFRecordDataset(tfr_files)

In [None]:
dataset

In [None]:
# NOTE(review): `record` is not used afterwards.
record = dataset.take(1)

In [None]:
decoded = dataset.map(_parse_function).make_one_shot_iterator().get_next()

In [None]:
decoded

In [None]:
# As in the local-file check: skip two records and keep the third.
with tf.Session() as sess:
    sess.run(decoded)
    sess.run(decoded)
    res2 = sess.run(decoded)

In [None]:
res2['state'].shape, res2['qvalue'].shape

In [None]:
# Show the third record's state as a channel-first (2, N_p, N_p) board.
np.rollaxis(res2['state'].reshape(N_p,N_p,2), 2, 0)