In [1]:
import tensorflow as tf

In [9]:
def vgg_block(num_convs, num_channels):
    """Build one VGG block as a `tf.keras.models.Sequential`.

    The block stacks `num_convs` Conv2D layers (kernel size 3, 'same'
    padding, ReLU activation, `num_channels` filters each) and ends with a
    MaxPool2D layer of pool size 2 and stride 2.
    """
    conv_layers = [
        tf.keras.layers.Conv2D(num_channels, kernel_size=3, padding='same', activation='relu')
        for _ in range(num_convs)
    ]
    downsample = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)
    return tf.keras.models.Sequential(conv_layers + [downsample])

In [10]:
# One (num_convs, num_channels) pair per VGG block, in network order.
conv_arch = (
    (1, 64),
    (1, 128),
    (2, 256),
    (2, 512),
    (2, 512),
)

In [11]:
def vgg(conv_arch, num_classes=10):
    """Assemble a VGG-style network.

    Args:
        conv_arch: iterable of (num_convs, num_channels) pairs, one per
            convolutional block; each pair is passed to `vgg_block`.
        num_classes: number of units in the final Dense layer. Defaults to
            10, the value that used to be hard-coded.

    Returns:
        A `tf.keras.models.Sequential` whose layers are the convolutional
        blocks followed by one nested Sequential holding the dense head.
    """
    net = tf.keras.models.Sequential()
    # Convolutional part: one nested block per entry of conv_arch.
    for (num_convs, num_channels) in conv_arch:
        net.add(vgg_block(num_convs, num_channels))
    # Dense head, kept as a single nested Sequential so that iterating over
    # net.layers yields exactly one entry for it.
    net.add(tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(4096, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(4096, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        # No activation on the output layer: it emits raw scores.
        tf.keras.layers.Dense(num_classes)]))
    return net
# Instantiate the full-size network described by conv_arch.
net = vgg(conv_arch=conv_arch)

In [12]:
# Feed one random (1, 224, 224, 1) input through the network, one
# top-level layer at a time, printing the output shape after each of them.
X = tf.random.normal(shape=(1, 224, 224, 1))
for layer in net.layers:
    X = layer(X)
    print(type(layer).__name__, 'output_shape:\t', X.shape)

Sequential output_shape:	 (1, 112, 112, 64)
Sequential output_shape:	 (1, 56, 56, 128)
Sequential output_shape:	 (1, 28, 28, 256)
Sequential output_shape:	 (1, 14, 14, 512)
Sequential output_shape:	 (1, 7, 7, 512)
Sequential output_shape:	 (1, 10)


In [24]:
# Use a narrower network (every channel count integer-divided by `ratio`)
# so that training is less time-consuming.
ratio = 4
small_conv_arch = [(num_convs, num_channels // ratio)
                   for num_convs, num_channels in conv_arch]
# `net` becomes a zero-argument factory that builds a fresh small network.
net = lambda: vgg(small_conv_arch)

In [18]:
# At this point `net` is a zero-argument factory, not a model, so it has no
# `.layers` attribute; build an instance first and walk that instead.
X = tf.random.normal(shape = (1, 224, 224, 1))
for layer in net().layers:
    X = layer(X)
    # One line per top-level layer: class name and output shape.
    print(layer.__class__.__name__, 'output_shape:\t', X.shape)

Sequential output_shape:	 (1, 112, 112, 16)
Sequential output_shape:	 (1, 56, 56, 32)
Sequential output_shape:	 (1, 28, 28, 64)
Sequential output_shape:	 (1, 14, 14, 128)
Sequential output_shape:	 (1, 7, 7, 128)
Sequential output_shape:	 (1, 10)


In [25]:
class TrainCallback(tf.keras.callbacks.Callback):
    """Keras callback that prints loss, train accuracy and test accuracy
    at the end of every epoch.

    Args:
        net: model evaluated on `test_iter` after each epoch. Element 1 of
            its `evaluate` result is reported as the test accuracy.
        train_iter: training dataset (stored; not read by this callback).
        test_iter: dataset passed to `net.evaluate`.
        num_epochs: total number of epochs; used to detect the last epoch
            and print one extra summary line.
    """

    def __init__(self, net, train_iter, test_iter, num_epochs):
        # Initialise the base Callback first so the attributes Keras
        # expects on every callback exist before ours are added.
        super().__init__()
        self.net = net
        self.train_iter = train_iter
        self.test_iter = test_iter
        self.num_epochs = num_epochs

    def on_epoch_end(self, epoch, logs = None):
        """Evaluate on the test data and print this epoch's metrics."""
        # `logs` defaults to None; fall back to NaN for missing entries
        # instead of failing on a None subscript or a missing key.
        logs = logs or {}
        missing = float('nan')
        test_acc = self.net.evaluate(self.test_iter, verbose = 0)
        metrics = (logs.get("loss", missing), logs.get("accuracy", missing), test_acc[1])
        print(f'epoch {epoch}, loss {metrics[0]:.3f}, train acc {metrics[1]:.3f}, 'f'test acc {metrics[2]:.3f}')

        # `epoch` is zero-based, so the last epoch is num_epochs - 1.
        if epoch == self.num_epochs - 1:
            print(f'loss {metrics[0]:.3f}, train acc {metrics[1]:.3f}, 'f'test acc {metrics[2]:.3f}')

In [26]:
def load_data_fashion_mnist(batch_size, resize=None):   #@save
    """Download the Fashion-MNIST dataset and then load it into memory.

    Args:
        batch_size: number of examples per batch.
        resize: if truthy, each image batch is passed through
            `tf.image.resize_with_pad` with target height and width both
            equal to `resize`; otherwise images are left unchanged.

    Returns:
        A `(train, test)` tuple of batched `tf.data.Dataset` objects that
        yield `(images, labels)`. Only the training dataset is shuffled.
    """
    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
    # Divide all numbers by 255 so that all pixel values are between
    # 0 and 1, add a trailing channel dimension, and cast labels to int32.
    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
                            tf.cast(y, dtype='int32'))
    # Mapped after `batch`, so resizing is applied to whole batches.
    resize_fn = lambda X, y: (
        tf.image.resize_with_pad(X, resize, resize) if resize else X, y)
    # Shuffle individual examples *before* batching. Shuffling after
    # `batch` only permutes whole batches, so every epoch would reuse the
    # same groups of examples.
    train = tf.data.Dataset.from_tensor_slices(process(*mnist_train)).shuffle(
        len(mnist_train[0])).batch(batch_size).map(resize_fn)
    test = tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
        batch_size).map(resize_fn)
    return (train, test)

In [27]:
# Load Fashion-MNIST in batches of `batch_size`, resizing images to 224.
batch_size = 128
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size, resize=224)

In [29]:
lr, num_epochs = 0.05, 10
# Train on the first GPU when TensorFlow can see one, and fall back to the
# CPU otherwise (the device used to be hard-coded to the CPU).
device_name = "/gpu:0" if tf.config.list_physical_devices("GPU") else "/cpu:0"
strategy = tf.distribute.OneDeviceStrategy(device=device_name)
with strategy.scope():
    # Optimizer, loss and model are all created inside the strategy scope.
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
    # from_logits=True: the loss is given the network's raw output scores.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # `net` is a factory here; calling it builds a fresh model.
    net_model = net()
    net_model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
# Progress is reported by the callback, so Keras' own output is silenced.
callback = TrainCallback(net_model, train_iter, test_iter, num_epochs)
net_model.fit(train_iter, epochs=num_epochs, verbose=0, callbacks=[callback])

epoch 0, loss 0.922, train acc 0.668, test acc 0.838


KeyboardInterrupt: 