In [1]:
import tensorflow as tf

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D,Dense,Flatten,Dropout,GlobalMaxPooling2D,MaxPooling2D,BatchNormalization

In [3]:
# Load CIFAR-10 through the Keras dataset helper (the recorded cell output
# shows it downloading the archive).
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Divide every pixel value by 255.0
# (presumably maps 8-bit intensities into [0, 1] -- confirm against the raw dtype).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Collapse the label arrays to 1-D so they can be used as sparse class indices.
y_train = y_train.flatten()
y_test = y_test.flatten()

print("x_train.shape:", x_train.shape)
print("y_train.shape", y_train.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 0us/step
x_train.shape: (50000, 32, 32, 3)
y_train.shape (50000,)


In [4]:
# Number of distinct labels present in the (flattened) training labels
K = len(np.unique(y_train))
print("number of classes:", K)

number of classes: 10


In [5]:
# Build the model using the functional API
def create_model(input_shape=None, num_classes=None):
  """Build the (uncompiled) CNN image classifier with the Keras functional API.

  The network is three convolutional stages (32, 64, then 128 filters). Each
  stage is two 3x3 'same'-padded ReLU convolutions, each followed by batch
  normalization, and the stage ends with a 2x2 max-pool. The result is
  flattened and passed through dropout, a 1024-unit ReLU layer, dropout again,
  and a softmax output layer with one unit per class.

  Args:
    input_shape: Shape of a single input sample. Defaults to
      `x_train[0].shape` from the notebook's global training data.
    num_classes: Number of output classes. Defaults to the notebook global
      `K` (the count of distinct training labels).

  Returns:
    A `Model` mapping an input image batch to per-class probabilities.
    The caller is responsible for compiling it.
  """
  if input_shape is None:
    input_shape = x_train[0].shape
  if num_classes is None:
    num_classes = K

  i = Input(shape=input_shape)

  x = i
  for filters in (32, 64, 128):
    # Two conv + batch-norm pairs per stage, then halve the spatial size.
    for _ in range(2):
      x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
      x = BatchNormalization()(x)
    x = MaxPooling2D((2, 2))(x)

  x = Flatten()(x)
  x = Dropout(0.2)(x)
  x = Dense(1024, activation='relu')(x)
  x = Dropout(0.2)(x)
  # The softmax head must have exactly one unit per class; the previous
  # hard-coded 512 units produced a 512-way distribution for a K-class problem.
  x = Dense(num_classes, activation='softmax')(x)

  model = Model(i, x)
  return model

In [6]:
# Distribution strategy; the model is later created and compiled inside
# `strategy.scope()`.
strategy = tf.distribute.MirroredStrategy()
# Alternative strategy, kept for experimentation:
# strategy = tf.distribute.experimental.CentralStorageStrategy()

In [7]:
# Report how many replicas the strategy keeps in sync (the recorded run shows 1).
print(f'Number of devices: {strategy.num_replicas_in_sync}')

Number of devices: 1


In [8]:
# Critical part: the model is both created and compiled inside the
# strategy's scope.
with strategy.scope():
  model = create_model()
  model.compile(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'],
  )

In [9]:
# Train the strategy-scoped model, validating on the test split after each
# epoch; the value returned by fit() is kept in `r`.
r = model.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    batch_size=128,
    epochs=15,
)

Epoch 1/15
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 30ms/step - accuracy: 0.4362 - loss: 1.9152 - val_accuracy: 0.3807 - val_loss: 1.8283
Epoch 2/15
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 29ms/step - accuracy: 0.6945 - loss: 0.8599 - val_accuracy: 0.7000 - val_loss: 0.8594
Epoch 3/15
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 29ms/step - accuracy: 0.7722 - loss: 0.6451 - val_accuracy: 0.7647 - val_loss: 0.6765
Epoch 4/15
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 30ms/step - accuracy: 0.8175 - loss: 0.5207 - val_accuracy: 0.7438 - val_loss: 0.7636
Epoch 5/15
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 30ms/step - accuracy: 0.8551 - loss: 0.4181 - val_accuracy: 0.7910 - val_loss: 0.6355
Epoch 6/15
[1m400/400[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 30ms/step - accuracy: 0.8804 - loss: 0.3384 - val_accuracy: 0.7881 - val_loss: 0.7024
Epoch 7/15
[1m4

In [10]:
# Sanity check: 50000 training samples spread over 391 steps per epoch
# gives the average number of samples per step (about 128).
50000/391

127.8772378516624

In [11]:
# Same check for another split -- presumably 10000 validation samples over
# 79 validation steps; TODO confirm both numbers against the run.
10000/79

126.58227848101266

In [12]:
# Baseline for comparison: the same architecture and training settings, but
# created and compiled outside any strategy scope (non-distributed training).
model2 = create_model()
model2.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
# NOTE(review): this rebinds `r`, so the result of the earlier model.fit()
# call is no longer reachable through that name.
r = model2.fit(
    x_train,
    y_train,
    validation_data=(x_test, y_test),
    batch_size=128,
    epochs=15,
)

Epoch 1/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 37ms/step - accuracy: 0.4411 - loss: 1.8524 - val_accuracy: 0.3910 - val_loss: 1.9148
Epoch 2/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 16ms/step - accuracy: 0.7058 - loss: 0.8356 - val_accuracy: 0.6934 - val_loss: 0.8845
Epoch 3/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 16ms/step - accuracy: 0.7795 - loss: 0.6248 - val_accuracy: 0.7286 - val_loss: 0.7990
Epoch 4/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 16ms/step - accuracy: 0.8221 - loss: 0.5007 - val_accuracy: 0.7907 - val_loss: 0.6175
Epoch 5/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 16ms/step - accuracy: 0.8601 - loss: 0.3992 - val_accuracy: 0.7446 - val_loss: 0.8419
Epoch 6/15
[1m391/391[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.8851 - loss: 0.3245 - val_accuracy: 0.7857 - val_loss: 0.6544
Epoch 7/15
[1m391