In [1]:
import functools
import os
import pickle
import sys

import numpy as np
import tensorflow as tf
sys.path.append('../../')

In [2]:
from modis_utils.misc import restore_data, cache_data

# Utility functions

In [3]:
def _float_feature(value):
    """Wrap a sequence of floats as a ``tf.train.Feature`` backed by a FloatList.

    Args:
        value: Sequence of numbers to store in the feature.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds ``value``.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

In [4]:
def convert_to_tfrecord(input_files, output_file, resize=None):
    """Serialize a list of cached (inputs, labels) files into one TFRecord file.

    Each path in ``input_files`` is loaded with ``restore_data`` and becomes
    exactly one ``tf.train.Example`` with two float features:

    * ``'inputs'`` -- the flattened ``inputs`` array
    * ``'labels'`` -- the flattened ``labels`` array

    Records are written in the iteration order of ``input_files``.

    Args:
        input_files: Iterable of paths readable by ``restore_data``; each
            must yield an ``(inputs, labels)`` pair.
        output_file: Path of the ``.tfrecords`` file to write.
        resize: Currently ignored. Kept so existing callers that pass it
            (e.g. ``convert_to_tfrecord(x, y, True)``) keep working.
    """
    print('Generating %s' % output_file)
    with tf.python_io.TFRecordWriter(output_file) as record_writer:
        for input_file in input_files:
            # NOTE(review): restore_data comes from modis_utils; if it
            # unpickles its input, only feed it trusted files.
            inputs, labels = restore_data(input_file)
            # Only the flattened values are stored; the original array shapes
            # are not recorded, so readers must know them to reshape.
            # np.ravel avoids the unconditional copy made by ndarray.flatten
            # and also accepts plain array-likes.
            # NOTE(review): FloatList holds 32-bit floats (per the TF proto
            # docs), so float64 data would be downcast -- confirm acceptable.
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'inputs': _float_feature(np.ravel(inputs).tolist()),
                    'labels': _float_feature(np.ravel(labels).tolist()),
                }))
            record_writer.write(example.SerializeToString())

In [9]:
def create_dataset(data_dir, output_dir, f, subsets=('val', 'test', 'train')):
    """Convert each data subset directory into one ``<subset>.tfrecords`` file.

    For every name in ``subsets``, the regular, non-hidden files found in
    ``data_dir/<subset>`` are collected in sorted order and passed to ``f``
    together with the target path ``output_dir/<subset>.tfrecords``.
    Any existing file at that target path is removed first.

    Args:
        data_dir: Directory containing one sub-directory per subset.
        output_dir: Directory that receives the ``.tfrecords`` files. It is
            not created by this function.
        f: Callable ``f(input_files, output_file)`` that performs the actual
            conversion (e.g. a wrapper around ``convert_to_tfrecord``).
        subsets: Names of the subset sub-directories to process, in order.
            Defaults to the original ``('val', 'test', 'train')``.

    Raises:
        OSError: If a subset directory cannot be listed, or a stale output
            file exists but cannot be removed.
    """
    for subset in subsets:
        input_dir = os.path.join(data_dir, subset)
        # os.listdir returns entries in arbitrary order; sort them so the
        # record order in the output file is reproducible across runs and
        # machines. Hidden entries (e.g. .DS_Store, .ipynb_checkpoints) and
        # sub-directories are not data files, so they are skipped.
        input_files = [
            os.path.join(input_dir, name)
            for name in sorted(os.listdir(input_dir))
            if not name.startswith('.')
            and os.path.isfile(os.path.join(input_dir, name))
        ]
        output_file = os.path.join(output_dir, subset + '.tfrecords')
        # Drop output left over from a previous run. A missing file is fine;
        # other failures (e.g. permissions) are no longer silently ignored.
        if os.path.exists(output_file):
            os.remove(output_file)
        # Convert to tf.train.Example and write them to TFRecords.
        f(input_files, output_file)
        print('Done {}!'.format(subset))

In [10]:
# Input: per-subset directories of cached sequence files.
# Output: one <subset>.tfrecords file per subset.
data_dir = 'multiple_output/12/sequence_data'
output_dir = 'multiple_output/12/data'
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(output_dir, exist_ok=True)
# functools.partial instead of a named lambda (PEP 8 E731). resize=True is
# forwarded exactly as before; convert_to_tfrecord currently ignores it.
f = functools.partial(convert_to_tfrecord, resize=True)
create_dataset(data_dir, output_dir, f)

Generating multiple_output/12/data/val.tfrecords
Done val!
Generating multiple_output/12/data/test.tfrecords
Done test!
Generating multiple_output/12/data/train.tfrecords
Done train!
