In [14]:
import os
import sys
import pickle
import numpy as np
import tensorflow as tf

In [3]:
from modis_utils.misc import restore_data, cache_data

In [4]:
# Root of the cached sequence data. Later cells read batch files from its
# 'train', 'val' and 'test' sub-directories and write '<subset>.tfrecords'
# files directly into it.
data_dir = '../sequence_data/12'

In [6]:
def _get_file_names():
    """Returns the file names expected to exist in the input_dir."""
    file_names = {}
    file_names['train'] = ['data_batch_%d' % i for i in range(1, 5)]
    file_names['validation'] = ['data_batch_5']
    file_names['eval'] = ['test_batch']
    return file_names

In [44]:
def _float_feature(value):
    """Wrap an iterable of numbers in a ``tf.train.Feature`` backed by a FloatList.

    Args:
        value: iterable of numbers (here, a flattened array).

    Returns:
        tf.train.Feature: feature whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

In [60]:
def convert_to_tfrecord(input_files, output_file):
    """Write every (input, label) pair from several cached batch files into one TFRecord file.

    Each pair becomes a ``tf.train.Example`` with two float features,
    ``'inputs'`` and ``'labels'``, holding the flattened arrays. The original
    array shapes are NOT stored, so a reader must know them to reshape the
    parsed values.

    Args:
        input_files: iterable of paths accepted by ``restore_data``; each must
            unpack into ``(inputs, labels, inputs_pw, labels_pw)``. Only
            ``inputs`` and ``labels`` are serialized.
        output_file: path of the TFRecord file to create.

    Returns:
        int: number of examples written.

    Raises:
        ValueError: if a file's ``inputs`` and ``labels`` differ in length.
    """
    print('Generating %s' % output_file)
    n_written = 0
    # NOTE(review): tf.python_io is the TF1 API; under TF2 the equivalent is
    # tf.io.TFRecordWriter.
    with tf.python_io.TFRecordWriter(output_file) as record_writer:
        for input_file in input_files:
            # The two *_pw entries are not serialized; unpacking still enforces
            # the expected 4-item structure.
            inputs, labels, _, _ = restore_data(input_file)
            if len(inputs) != len(labels):
                raise ValueError(
                    '{}: {} inputs but {} labels'.format(
                        input_file, len(inputs), len(labels)))
            for x, y in zip(inputs, labels):
                example = tf.train.Example(features=tf.train.Features(
                    feature={
                        # np.ravel accepts ndarrays and nested lists alike.
                        'inputs': _float_feature(np.ravel(x)),
                        'labels': _float_feature(np.ravel(y)),
                    }))
                record_writer.write(example.SerializeToString())
                n_written += 1
    return n_written

In [61]:
# Write one TFRecord file per split. os.listdir returns entries in arbitrary
# order, so the listing is sorted to make the record order reproducible. Only
# regular files are converted, so stray sub-directories are skipped.
for subset in ('val', 'test', 'train'):
    input_dir = os.path.join(data_dir, subset)
    input_files = [
        os.path.join(input_dir, f)
        for f in sorted(os.listdir(input_dir))
        if os.path.isfile(os.path.join(input_dir, f))
    ]
    output_file = os.path.join(data_dir, subset + '.tfrecords')
    # Remove stale output from a previous run; other OS errors are not hidden.
    if os.path.exists(output_file):
        os.remove(output_file)
    # Convert to tf.train.Example and write them to TFRecords.
    convert_to_tfrecord(input_files, output_file)
    print('Done {}!'.format(subset))

Generating ../sequence_data/12/val.tfrecords
Done val!
Generating ../sequence_data/12/test.tfrecords
Done test!
Generating ../sequence_data/12/train.tfrecords
Done train!


In [30]:
def _count_examples(subset_dir):
    """Sum ``len(restore_data(path)[0])`` over every entry listed in ``subset_dir``."""
    return sum(len(restore_data(os.path.join(subset_dir, fname))[0])
               for fname in os.listdir(subset_dir))


# Count the examples in each split by re-reading every cached batch file.
n_examples = {subset: _count_examples(os.path.join(data_dir, subset))
              for subset in ('train', 'val', 'test')}
print(n_examples)

{'train': 15534, 'val': 1380, 'test': 2454}
