<a href="https://colab.research.google.com/github/aimerou/deep-learning/blob/main/notebooks/keras_fashion_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Train on TPU**  
1. On the main menu, click Runtime and select Change runtime type. Set "TPU" as the hardware accelerator.
2. Click Runtime again and select Run all.

In [1]:
import distutils
import os
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, BatchNormalization, Dropout, Conv2D, MaxPooling2D, Flatten
import numpy as np
from matplotlib import pyplot

**Load Fashion MNIST dataset**

In [None]:
# Fetch the Fashion-MNIST train/test split through the Keras dataset helper.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Append a trailing axis of size 1 to each image array, turning
# (samples, height, width) into (samples, height, width, 1) so the data
# matches the input shape the model is built with.
X_train = X_train.reshape(X_train.shape + (1,))
X_test = X_test.reshape(X_test.shape + (1,))

In [3]:
# Sanity check: after the reshape, the training array should carry a trailing
# channel axis of size 1.
X_train.shape

(60000, 28, 28, 1)

**Keras Model**

In [4]:
def create_model(input_shape=None, num_classes=10):
  """Build the (uncompiled) convolutional classifier used in this notebook.

  The network is three identical stages of
  BatchNormalization -> Conv2D(4x4, 'same', relu) -> MaxPooling2D -> Dropout(0.25)
  with 64, 128 and 256 filters, followed by Flatten, Dense(256, relu) and a
  softmax output layer.

  Args:
    input_shape: Shape of one input sample, without the batch axis. When None
      (the default) it is taken from the notebook-level `X_train` array as
      `X_train.shape[1:]`.
    num_classes: Number of units in the softmax output layer. Defaults to 10.

  Returns:
    The assembled `Sequential` model; compiling it is left to the caller.
  """
  if input_shape is None:
    # Resolved at call time so the function keeps working as `create_model()`.
    input_shape = X_train.shape[1:]

  model = Sequential()

  for stage, filters in enumerate((64, 128, 256)):
    if stage == 0:
      # Only the first layer of a Sequential model declares the input shape;
      # later layers get their input shape from the layer before them.
      model.add(BatchNormalization(input_shape=input_shape))
    else:
      model.add(BatchNormalization())
    model.add(Conv2D(filters, kernel_size=4, padding='same', activation='relu'))
    model.add(MaxPooling2D())
    model.add(Dropout(0.25))

  model.add(Flatten())
  model.add(Dense(256, activation='relu'))
  model.add(Dense(num_classes, activation='softmax'))
  return model

**Construct the model on TPU and compile it**

In [None]:
# Reset Keras' global state so re-running this cell starts from a clean slate.
tf.keras.backend.clear_session()
# Build a cluster resolver from the address in the COLAB_TPU_ADDR environment
# variable. This lookup raises KeyError when the variable is not set
# (presumably when the runtime has no TPU attached - confirm).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver('grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.config.experimental_connect_to_cluster(resolver)

# This is the TPU initialization code that has to be at the beginning.
tf.tpu.experimental.initialize_tpu_system(resolver)
# NOTE(review): newer TensorFlow releases expose this as tf.distribute.TPUStrategy
# and deprecate the experimental alias - confirm against the installed version.
strategy = tf.distribute.experimental.TPUStrategy(resolver)

# The model is both created and compiled inside the strategy scope, so its
# variables and optimizer are set up under the TPU strategy.
with strategy.scope():
  model = create_model()
  model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])

In [6]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
batch_normalization (BatchNo (None, 28, 28, 1)         4         
_________________________________________________________________
conv2d (Conv2D)              (None, 28, 28, 64)        1088      
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 64)        0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 14, 14, 64)        256       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 128)       131200    
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 128)         0

**Training and Validation**

In [7]:
# Train for 10 epochs and evaluate on the test split after every epoch.
# float32 copies of the arrays are handed to Keras; X_* / y_* are not modified.
model.fit(
    x=X_train.astype(np.float32),
    y=y_train.astype(np.float32),
    validation_data=(X_test.astype(np.float32), y_test.astype(np.float32)),
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7fe790f2ab90>

In [None]:
# Human-readable class names; position k is the name shown for predicted class index k.
LABEL_NAMES = ['t_shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle_boots']

%matplotlib inline

def plot_predictions(images, predictions, label_names=None, n_cols=4):
  """Show each image in a grid, captioned with its predicted label and confidence.

  Images are laid out in reading order (left to right, then top to bottom).
  Grid cells beyond the last image are left blank.

  Args:
    images: Array whose first axis indexes the images; `images[i]` is passed to
      `imshow` as-is.
    predictions: Per-image class scores; `predictions[i]` is the score vector for
      `images[i]`. The arg-max picks the label and the max is shown as confidence.
    label_names: Sequence mapping a class index to its display name. Defaults to
      the notebook-level `LABEL_NAMES`.
    n_cols: Number of grid columns. Defaults to 4.

  Returns:
    None. The figure is created as a side effect; nothing is drawn for an
    empty batch.

  Raises:
    ValueError: If `n_cols` is smaller than 1.
  """
  if n_cols < 1:
    raise ValueError('n_cols must be at least 1')
  if label_names is None:
    label_names = LABEL_NAMES

  n = images.shape[0]
  if n == 0:
    # subplots() rejects a zero-row grid, and there is nothing to show anyway.
    return

  n_rows = int(np.ceil(n / n_cols))
  # squeeze=False keeps `axes` two-dimensional even for a single row or column.
  fig, axes = pyplot.subplots(n_rows, n_cols, squeeze=False)

  # `axes.flat` walks the (row, col) grid in row-major order, so cell i hosts image i.
  for i, ax in enumerate(axes.flat):
    ax.axis('off')
    if i >= n:
      # Padding cell past the last image: leave it blank and never index
      # images/predictions out of range.
      continue
    label = label_names[np.argmax(predictions[i])]
    confidence = np.max(predictions[i])
    ax.imshow(images[i])
    ax.text(0.5, 0.5, label + '\n%.3f' % confidence, fontsize=14)

  fig.set_size_inches(8, 10)
    
# Preview the first 16 test images with the model's predicted label and confidence.
# Only the trailing channel axis is squeezed away, so the batch axis survives even
# if the slice is narrowed to a single image.
plot_predictions(np.squeeze(X_test[:16], axis=-1), model.predict(X_test[:16]))