In [14]:
import logging
import tensorflow as tf
import keras

# Timestamped, INFO-level logging for everything the notebook logs below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s : %(levelname)s : %(message)s',
)
log = logging.getLogger()  # the root logger; used by the exercise functions below

%config Completer.use_jedi = False # turn off Jedi-based completion to make autocompletion work in Jupyter

# Display the TensorFlow version this notebook was executed with.
tf.__version__

'2.4.1'

In [85]:
import tensorflow_datasets as tfds
from tensorflow.train import Feature, Example, Features, BytesList, Int64List
from pathlib import Path
from tensorflow import keras

def ch13_ex_9(train_fraction=0.8, seed=101, train_dir='data-ignored/tf/train',
              val_dir='data-ignored/tf/val', batch_size=32, epochs=3, n_shards=10):
    """Chapter 13, exercise 9: Fashion MNIST -> sharded TFRecords -> tf.data -> Keras.

    Loads the Fashion MNIST ``train`` split with TFDS and splits it into a training part
    and a validation part. Each part is serialized into ``n_shards`` TFRecord files of
    ``tf.train.Example`` protos and read back with a ``tf.data`` pipeline. A small dense
    classifier is then trained on the result.

    Args:
        train_fraction: fraction of the TFDS ``train`` split used for training; the
            remainder is used for validation.
        seed: seed for TensorFlow's global RNG and for the shuffle operations.
        train_dir: directory the training TFRecord shards are written to.
        val_dir: directory the validation TFRecord shards are written to.
        batch_size: batch size of both input pipelines.
        epochs: number of training epochs.
        n_shards: number of TFRecord files each split is written to.

    Raises:
        ValueError: if ``n_shards`` is smaller than 1.
    """
    from contextlib import ExitStack

    if n_shards < 1:
        raise ValueError(f'n_shards must be >= 1, got {n_shards}')

    image_shape = (28, 28, 1)  # Fashion MNIST image shape as reported by TFDS
    pre_shuffle_buffer_size = 101  # mixes examples before they are dealt out to shards

    # Make weight initialization and the shuffles reproducible across runs.
    tf.random.set_seed(seed)

    def serialize_dataset(dataset, data_dir):
        """Write ``dataset`` round-robin into ``n_shards`` TFRecord files; return their paths.

        A handful of shards instead of one file per example: tens of thousands of tiny
        files are slow to write and force every epoch to open every single file.
        """
        Path(data_dir).mkdir(parents=True, exist_ok=True)
        filepaths = [f'{data_dir}/data.tfrecord-{shard:05d}-of-{n_shards:05d}'
                     for shard in range(n_shards)]
        # ExitStack closes every writer that was opened, even if a later one fails.
        with ExitStack() as stack:
            writers = [stack.enter_context(tf.io.TFRecordWriter(path)) for path in filepaths]
            for i, d in enumerate(dataset):
                ser_image = tf.io.serialize_tensor(d['image']).numpy()
                image_example = Example(
                    features=Features(
                        feature={
                            'image': Feature(bytes_list=BytesList(value=[ser_image])),
                            'label': Feature(int64_list=Int64List(value=[int(d['label'])]))
                        }
                    )
                )
                writers[i % n_shards].write(image_example.SerializeToString())
        return filepaths

    def create_dataset(filepaths, shuffle_buffer_size=None):
        """Build a batched, prefetched ``(image, label)`` pipeline over TFRecord files.

        ``image`` is float32 with shape ``image_shape``, scaled to [0, 1]; ``label`` is
        the int64 label. Records are shuffled only if ``shuffle_buffer_size`` is given.
        """
        image_descr = {
            'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
            'label': tf.io.FixedLenFeature([], tf.int64, default_value=-1)
        }

        def parse_image_example(ser):
            parsed_image_example = tf.io.parse_single_example(ser, image_descr)
            image = tf.io.parse_tensor(parsed_image_example["image"], out_type=tf.uint8)
            # parse_tensor yields an unknown static shape; pin it to what the model expects.
            image = tf.ensure_shape(image, image_shape)
            # Scale raw uint8 pixels (0..255) to [0, 1] before they reach the Dense layers.
            image = tf.cast(image, tf.float32) / 255.0
            label = parsed_image_example["label"]
            return image, label

        autotune = tf.data.experimental.AUTOTUNE
        # Read all shards in an interleaved fashion rather than one file after another.
        dataset = tf.data.Dataset.from_tensor_slices(filepaths).interleave(
            lambda filepath: tf.data.TFRecordDataset(filepath),
            cycle_length=len(filepaths))
        if shuffle_buffer_size:
            # Shuffle the compact serialized records, not the decoded float images.
            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size, seed=seed)
        dataset = dataset.map(parse_image_example, num_parallel_calls=autotune)
        return dataset.batch(batch_size).prefetch(autotune)

    def create_model():
        """Return a compiled dense classifier for ``image_shape`` inputs and 10 classes."""
        model = keras.models.Sequential()
        model.add(keras.layers.InputLayer(input_shape=image_shape))
        model.add(keras.layers.Flatten())
        model.add(keras.layers.Dense(300, activation='elu', kernel_initializer='he_normal'))
        model.add(keras.layers.Dense(10, activation='softmax'))
        model.compile(loss='sparse_categorical_crossentropy',
                      optimizer=keras.optimizers.Nadam(learning_rate=0.001),
                      metrics=[keras.metrics.sparse_categorical_accuracy])
        return model

    dataset = tfds.load("fashion_mnist")
    log.info(f'Loaded dataset: {dataset}')
    full_train = dataset['train']
    train_idx = int(len(full_train) * train_fraction)
    # take/skip happen before any shuffling, so the split itself is deterministic.
    train = full_train.take(train_idx).shuffle(
        pre_shuffle_buffer_size, seed=seed, reshuffle_each_iteration=True)
    log.info(f'Train dataset len: {len(train)}')
    val = full_train.skip(train_idx)
    log.info(f'Val dataset len: {len(val)}')

    train_filepaths = serialize_dataset(dataset=train, data_dir=train_dir)
    val_filepaths = serialize_dataset(dataset=val, data_dir=val_dir)

    # Shuffle the training records over the whole split; validation order does not
    # affect the reported metrics, so that pipeline is left unshuffled.
    train_dataset = create_dataset(train_filepaths, shuffle_buffer_size=len(train))
    val_dataset = create_dataset(val_filepaths)

    model = create_model()
    model.fit(train_dataset, epochs=epochs, validation_data=val_dataset)

             
# Run the exercise end to end: split the data, write the TFRecords, then train and validate.
ch13_ex_9()

2021-04-19 14:24:24,368 : INFO : Load dataset info from /Users/mkhokhlush/tensorflow_datasets/fashion_mnist/3.0.1
2021-04-19 14:24:24,370 : INFO : Reusing dataset fashion_mnist (/Users/mkhokhlush/tensorflow_datasets/fashion_mnist/3.0.1)
2021-04-19 14:24:24,371 : INFO : Constructing tf.data.Dataset for split None, from /Users/mkhokhlush/tensorflow_datasets/fashion_mnist/3.0.1
2021-04-19 14:24:24,423 : INFO : Loaded dataset: {'train': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>, 'test': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>}
2021-04-19 14:24:24,425 : INFO : Train dataset len: 48000
2021-04-19 14:24:24,426 : INFO : Val dataset len: 12000


Epoch 1/3
Epoch 2/3
Epoch 3/3
