In [1]:
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from contextlib import ExitStack
from tensorflow import keras
from tensorflow.train import BytesList, FloatList, Int64List
from tensorflow.train import Features, Feature, Example
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from keras.datasets import fashion_mnist

In [2]:
# Hold out the first N_VALID training images as the validation split.
N_VALID = 5000
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist.load_data()
X_train, y_train = X_train_full[N_VALID:], y_train_full[N_VALID:]
X_val, y_val = X_train_full[:N_VALID], y_train_full[:N_VALID]
# Sanity check: the train/validation split partitions the original training set.
assert len(X_train) + len(X_val) == len(X_train_full)

In [3]:
# Reset Keras' global state so a re-run of the notebook starts from a fresh session.
keras.backend.clear_session()

In [4]:
# Seed the shuffle so the training shards written later get the same
# example order on every Restart & Run All (an op-level seed is used even
# when no global TF seed is set).
SHUFFLE_SEED = 42
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=len(X_train), seed=SHUFFLE_SEED)
)
val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))

In [5]:
def create_dir(dir_prefix, root="datasets"):
    """Create the directory ``<root>/<dir_prefix>`` if needed and return its path.

    Args:
        dir_prefix: Name of the sub-directory to create under ``root``.
        root: Parent directory. Defaults to ``"datasets"``, which is resolved
            relative to the current working directory.

    Returns:
        The joined path ``os.path.join(root, dir_prefix)``.
    """
    dir_path = os.path.join(root, dir_prefix)
    # exist_ok=True makes re-running this cell harmless.
    os.makedirs(dir_path, exist_ok=True)
    return dir_path

In [6]:
def create_proto_example(image, label):
    """Wrap one ``(image, label)`` pair in a ``tf.train.Example`` protobuf.

    The image is stored as a serialized tensor (``tf.io.serialize_tensor``)
    under the ``"image"`` key, so readers decode it with ``tf.io.parse_tensor``
    using the same dtype it was written with. The label is stored as a single
    int64 value under the ``"label"`` key.

    Args:
        image: Tensor or array-like image to serialize.
        label: Scalar integer class label (Python int, NumPy integer, or a
            scalar integer eager tensor).

    Returns:
        A ``tf.train.Example`` with an ``"image"`` bytes feature and a
        ``"label"`` int64 feature.
    """
    image_bytes = tf.io.serialize_tensor(image).numpy()
    return Example(features=Features(feature={
        "image": Feature(bytes_list=BytesList(value=[image_bytes])),
        # int() normalizes NumPy / tensor scalars to a plain Python int for the proto field.
        "label": Feature(int64_list=Int64List(value=[int(label)]))
    }))

In [7]:
def write_tfrecords(name, dataset, n_shards=10):
    """Write ``dataset`` round-robin into ``n_shards`` TFRecord files.

    The example at position ``i`` (0-based, in iteration order) goes to shard
    ``i % n_shards``. Files are named ``"<name>-XXXXX-of-YYYYY.tfrecord"``,
    with 1-based shard numbers ``XXXXX`` and ``YYYYY == n_shards``. The parent
    directory of ``name`` is not created here (see ``create_dir``).

    Args:
        name: Path prefix for the shard files.
        dataset: Iterable of ``(image, label)`` pairs, for example an
            unbatched ``tf.data.Dataset`` iterated eagerly.
        n_shards: Number of output files; must be at least 1.

    Returns:
        The list of ``n_shards`` shard paths, in shard order.

    Raises:
        ValueError: If ``n_shards`` is less than 1.
    """
    if n_shards < 1:
        raise ValueError("n_shards must be >= 1, got {}".format(n_shards))
    paths = ["{}-{:05d}-of-{:05d}.tfrecord".format(name, shard_number, n_shards)
             for shard_number in range(1, n_shards + 1)]
    # ExitStack closes every writer, even if serialization fails part-way through.
    with ExitStack() as stack:
        writers = [stack.enter_context(tf.io.TFRecordWriter(path))
                   for path in paths]
        # Python enumerate yields plain int indices, so list indexing needs no tensor conversion.
        for index, (image, label) in enumerate(dataset):
            example = create_proto_example(image, label)
            writers[index % n_shards].write(example.SerializeToString())
    return paths

In [9]:
dir_path = create_dir("fashion_mnist")
# Serialize each raw split from cell 4 into the default 10 TFRecord shards,
# in the order train -> val -> test.
# NOTE(review): cell 28 later rebinds train_dataset / val_dataset /
# test_dataset to the TFRecord-backed pipelines, so re-run cell 4 before
# re-running this cell.
train_filepaths, valid_filepaths, test_filepaths = (
    write_tfrecords(os.path.join(dir_path, split_prefix), split_data)
    for split_prefix, split_data in (
        ("train_dataset", train_dataset),
        ("val_dataset", val_dataset),
        ("test_dataset", test_dataset),
    )
)

In [17]:
def preprocess(tfrecord, image_shape=(28, 28)):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Expects the layout written by ``create_proto_example``: a serialized-tensor
    ``"image"`` bytes feature and an int64 ``"label"`` feature.

    Args:
        tfrecord: Scalar string tensor holding one serialized Example.
        image_shape: Target shape for the decoded image. Defaults to
            ``(28, 28)``, the Fashion-MNIST image size.

    Returns:
        ``(image, label)``: a ``tf.uint8`` tensor reshaped to ``image_shape``,
        and the int64 label (``-1`` if the feature is absent).
    """
    features_description = {
        "image": tf.io.FixedLenFeature([], tf.string, default_value=""),
        "label": tf.io.FixedLenFeature([], tf.int64, default_value=-1)
    }

    example = tf.io.parse_single_example(tfrecord, features_description)
    # out_type must match the dtype the image tensor was serialized with.
    image = tf.io.parse_tensor(example["image"], out_type=tf.uint8)
    # parse_tensor loses static shape info; reshape restores it for downstream layers.
    image = tf.reshape(image, list(image_shape))
    return image, example["label"]

def mnist_dataset(filepaths, buffer_size=None, n_threads=5, batch_size=32):
    """Build a batched, prefetched input pipeline from TFRecord shard files.

    Args:
        filepaths: TFRecord file path(s) passed to ``tf.data.TFRecordDataset``.
        buffer_size: Shuffle buffer size; shuffling is skipped when this is
            falsy (``None`` or ``0``).
        n_threads: ``num_parallel_calls`` used when decoding with ``preprocess``.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` of batched ``(image, label)`` pairs, with one
        batch prefetched.
    """
    records = tf.data.TFRecordDataset(filepaths)
    decoded = records.map(preprocess, num_parallel_calls=n_threads)
    maybe_shuffled = (
        decoded.shuffle(buffer_size=buffer_size) if buffer_size else decoded
    )
    batched = maybe_shuffled.batch(batch_size=batch_size)
    return batched.prefetch(1)

In [28]:
# TFRecord-backed pipelines for each split. Only the training pipeline is
# shuffled, with a buffer covering the whole training set.
# NOTE(review): these assignments rebind the names created in cell 4. If the
# shard-writing cell is re-run after this one, it would receive these batched
# pipelines (which read the very files being rewritten) instead of the raw
# splits. Re-run cell 4 first.
train_dataset = mnist_dataset(filepaths=train_filepaths, buffer_size=len(X_train))
val_dataset = mnist_dataset(filepaths=valid_filepaths)
test_dataset = mnist_dataset(filepaths=test_filepaths)