# 9.

In [20]:
import tensorflow as tf
from tensorflow import keras
import numpy as np


# Reset Keras' global state and seed NumPy / TensorFlow so reruns are repeatable.
keras.backend.clear_session()
np.random.seed(42)
tf.random.set_seed(42)


(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()

# Cast images and labels to float32 (create_example stores the label in a
# FloatList feature).
X_train_full = tf.cast(X_train_full, tf.float32)
y_train_full = tf.cast(y_train_full, tf.float32)
X_test = tf.cast(X_test, tf.float32)
y_test = tf.cast(y_test, tf.float32)

# The first 10,000 examples are held out for validation; examples
# 10,000-29,999 form the training split, so the two splits do not overlap.
X_valid, X_train = X_train_full[:10000], X_train_full[10000:30000]
y_valid, y_train = y_train_full[:10000], y_train_full[10000:30000]

# Build the training dataset from the training split only. Building it from
# X_train_full / y_train_full would also include every validation example,
# leaking validation data into training.
train_set = tf.data.Dataset.from_tensor_slices((X_train, y_train))
# A shuffle buffer as large as the dataset mixes all of its elements.
train_set = train_set.shuffle(train_set.cardinality())
test_set = tf.data.Dataset.from_tensor_slices((X_test, y_test))
valid_set = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))

In [30]:
from tensorflow.train import Example, Features, Feature, BytesList, FloatList

def create_example(image, label):
    """Build a ``tf.train.Example`` protobuf holding one image and its label.

    Args:
        image: Tensor to store; it is serialized with
            ``tf.io.serialize_tensor`` and saved as a single bytes value
            under the ``'image'`` key.
        label: Scalar label (anything ``float()`` accepts, such as a scalar
            eager tensor); saved as a single float under the ``'label'`` key.

    Returns:
        An ``Example`` message with the ``'image'`` and ``'label'`` features.
    """
    # .numpy() extracts the raw bytes that BytesList expects from the
    # serialized string tensor.
    image_bytes = tf.io.serialize_tensor(image).numpy()
    return Example(
        features=Features(
            feature={
                'image': Feature(bytes_list=BytesList(value=[image_bytes])),
                # Convert explicitly so protobuf receives a plain Python float
                # rather than a tensor object.
                'label': Feature(float_list=FloatList(value=[float(label)]))
            }))

In [None]:
from contextlib import ExitStack

def write_tfrecords(name, dataset, n_shards=10):
    """Write ``dataset`` to ``n_shards`` TFRecord files, round-robin.

    Args:
        name: Prefix used for every shard's file path.
        dataset: Dataset yielding ``(image, label)`` pairs; each pair is
            converted with ``create_example`` and serialized.
        n_shards: Number of shard files to create.

    Returns:
        The list of shard file paths, in shard order.
    """
    path_template = "{}.tfrecord-{:05d}-of-{:05d}"
    paths = []
    for shard_id in range(n_shards):
        paths.append(path_template.format(name, shard_id, n_shards))

    # The ExitStack closes every writer when the block exits, even on error.
    with ExitStack() as open_writers:
        writers = []
        for path in paths:
            writers.append(open_writers.enter_context(tf.io.TFRecordWriter(path)))

        # Element i goes to shard i % n_shards.
        for position, (image, label) in dataset.enumerate():
            record = create_example(image, label).SerializeToString()
            target = position % n_shards
            writers[target].write(record)

    return paths

In [None]:
# Serialize each split to sharded TFRecord files and keep the resulting paths.
train_filepaths = write_tfrecords(name="my_fashion_mnist.train", dataset=train_set)
valid_filepaths = write_tfrecords(name="my_fashion_mnist.valid", dataset=valid_set)
test_filepaths = write_tfrecords(name="my_fashion_mnist.test", dataset=test_set)

# 10.