In [54]:
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt

# Load CIFAR-10 as NumPy arrays (downloads the dataset on first use).
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Images are converted to float32; labels to flat int32 vectors.
X_train = np.asarray(X_train, dtype=np.float32)
X_test = np.asarray(X_test, dtype=np.float32)
y_train = np.asarray(y_train, dtype=np.int32).flatten()
y_test = np.asarray(y_test, dtype=np.int32).flatten()

# Split sizes: the validation examples are carved off the end of the training set.
num_training = 49000
num_validation = 1000
num_test = 10000

# Shuffle buffer size; taken BEFORE the split, so it counts every loaded
# training example (training + validation).
BUFFER_SIZE = len(X_train)

# Validation split: the num_validation examples right after the training range.
val_end = num_training + num_validation
X_val, y_val = X_train[num_training:val_end], y_train[num_training:val_end]
# Training split: the first num_training examples.
X_train, y_train = X_train[:num_training], y_train[:num_training]
# Test split: the first num_test test examples.
X_test, y_test = X_test[:num_test], y_test[:num_test]

# Per-channel statistics come from the training split only and are then
# applied to every split, so validation/test data never influence them.
mean_pixel = X_train.mean(axis=(0, 1, 2), keepdims=True)
std_pixel = X_train.std(axis=(0, 1, 2), keepdims=True)
X_train, X_val, X_test = (
    (split - mean_pixel) / std_pixel for split in (X_train, X_val, X_test)
)

In [55]:
# Training pipeline: slice the arrays into examples, shuffle, then batch.
# NOTE(review): the batch size of 64 is hard-coded here; it equals
# GLOBAL_BATCH_SIZE only for a single-replica strategy -- keep the two in
# sync if more devices are used.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(BUFFER_SIZE)
    .batch(64)
)

# Validation pipeline: same batching, no shuffling.
val_ds = (
    tf.data.Dataset.from_tensor_slices((X_val, y_val))
    .batch(64)
)

In [56]:
# Inspection only: show the concrete dataset class produced by the pipeline above.
type(train_ds)

tensorflow.python.data.ops.dataset_ops.BatchDataset

In [44]:
class MyModel(tf.keras.Model):
    """Small CNN classifier: three conv/ReLU/dropout stages and a dense head.

    The convolutions use 3x3 kernels with 32, 64 and 128 filters. The
    flattened features pass through two 64-unit ReLU dense layers and a final
    10-unit dense layer with no activation, so the model returns raw logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, 3, activation = None)
        self.relu1 = tf.keras.layers.ReLU()
        self.drop1 = tf.keras.layers.Dropout(rate = 0.5)
        self.conv2 = tf.keras.layers.Conv2D(64, 3, activation = None)
        self.relu2 = tf.keras.layers.ReLU()
        self.drop2 = tf.keras.layers.Dropout(rate = 0.5)
        self.conv3 = tf.keras.layers.Conv2D(128, 3, activation = None)
        self.relu3 = tf.keras.layers.ReLU()
        self.drop3 = tf.keras.layers.Dropout(rate = 0.5)

        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(64, activation = 'relu')
        self.fc2 = tf.keras.layers.Dense(64, activation = 'relu')
        self.fc3 = tf.keras.layers.Dense(10)

    def call(self, x, training=None):
        """Run the forward pass and return the class logits.

        Args:
            x: Batch of input images.
            training: Whether the call is in training mode. It is forwarded
                to every Dropout layer so dropout is active only while
                training. ``None`` (the default) lets Keras decide.

        Returns:
            Output of the final 10-unit dense layer (logits, no softmax).
        """
        x = self.conv1(x)
        x = self.relu1(x)
        # Dropout is the only mode-dependent layer here; pass the flag
        # explicitly instead of relying on implicit propagation.
        x = self.drop1(x, training=training)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.drop2(x, training=training)
        x = self.conv3(x)
        x = self.relu3(x)
        x = self.drop3(x, training=training)

        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return self.fc3(x)

In [57]:
import os

# Directory (relative to the working directory) for training checkpoints.
checkpoint_dir = './training_checkpoints'
# Path prefix handed to checkpoint.save() in the training loop.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

In [58]:
# Distribution strategy used for the rest of the notebook. The log line below
# shows it resolved to a single CPU device on this machine.
strategy = tf.distribute.MirroredStrategy()


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [59]:
# Number of examples each replica processes per step.
BATCH_SIZE_PER_REPLICA = 64
# Total examples per step across all replicas; used to scale the training loss.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [60]:
# Wrap the already-batched datasets so the strategy can feed them to its replicas.
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
val_dist_ds = strategy.experimental_distribute_dataset(val_ds)

In [61]:
# Inspection only: show the class of the distributed dataset wrapper.
type(train_dist_ds)

tensorflow.python.distribute.input_lib.DistributedDataset

In [62]:
# Sanity check: pull one batch from the distributed dataset and display the
# shape of its image tensor.
# NOTE(review): plain indexing works here because the strategy has a single
# replica; with several replicas the element would be a per-replica value --
# confirm before running on multiple devices.
imgbatch, labels = next(iter(train_dist_ds))
(imgbatch[:].shape)

TensorShape([64, 32, 32, 3])

In [81]:
# Everything that creates variables (model, optimizer, metrics) is built
# inside the strategy scope.
with strategy.scope():
    # Create the model FIRST: the checkpoint below must track this instance.
    # Building the checkpoint before the model raises NameError on a fresh
    # kernel, or silently tracks a stale model from an earlier run.
    model = MyModel()

    # Per-example loss (no reduction); train_step scales it by the global
    # batch size itself.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True,
                                                               reduction=tf.keras.losses.Reduction.NONE)
    learning_rate = 1e-3
    optimizer = tf.keras.optimizers.Adam(learning_rate)
    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

In [137]:
with strategy.scope():
    def train_step(inputs):
        """Run one optimisation step on a single replica's batch.

        Args:
            inputs: ``(images, labels)`` pair for this replica.

        Returns:
            The per-example losses combined by ``tf.nn.compute_average_loss``
            using ``GLOBAL_BATCH_SIZE``.
        """
        batch_images, batch_labels = inputs
        with tf.GradientTape() as tape:
            logits = model(batch_images, training = True)
            example_losses = loss_object(batch_labels, logits)
            loss = tf.nn.compute_average_loss(
                example_losses, global_batch_size=GLOBAL_BATCH_SIZE)

        variables = model.trainable_variables
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))

        # Accumulate training accuracy for the current epoch.
        train_accuracy(batch_labels, logits)
        return loss

    def val_step (inputs):
        """Evaluate one validation batch and update the validation metrics.

        Args:
            inputs: ``(images, labels)`` pair for this replica.
        """
        batch_images, batch_labels = inputs
        logits = model(batch_images, training=False)
        example_losses = loss_object(batch_labels, logits)

        # val_loss is a Mean metric fed with the per-example losses.
        val_loss(example_losses)
        val_accuracy(batch_labels, logits)

In [None]:
with strategy.scope():
    @tf.function
    def dist_train_step(inputs):
        """Run train_step on every replica and sum the per-replica losses."""
        per_replica_losses = strategy.run(train_step, args=(inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis = None)

    @tf.function
    def dist_val_step(inputs):
        """Run val_step on every replica; results live in the val metrics."""
        return strategy.run(val_step, args=(inputs,))

    EPOCHS = 10
    for epoch in range(EPOCHS):
        # Training pass: accumulate the reduced batch losses.
        total_loss = 0.0
        num_batches = 0
        for x in train_dist_ds:
            total_loss += dist_train_step(x)
            num_batches += 1
        # Stored under its own name so the `train_loss` Keras metric created
        # during setup is not overwritten by a plain tensor.
        epoch_train_loss = total_loss/num_batches

        # Validation pass: metrics are updated inside val_step.
        for x in val_dist_ds:
            dist_val_step(x)

        # Save a checkpoint every second epoch (epochs 1, 3, 5, ... as printed).
        if epoch %2 == 0:
            checkpoint.save(checkpoint_prefix)

        template = ("Epoch {}, Loss: {}, Accuracy: {}, Val Loss: {}, "
                    "Val Accuracy: {}")
        print(template.format(epoch+1, epoch_train_loss,
                               train_accuracy.result()*100, val_loss.result(),
                               val_accuracy.result()*100))
        # Reset the accumulating metrics so each epoch is reported on its own.
        val_loss.reset_states()
        train_accuracy.reset_states()
        val_accuracy.reset_states()

Epoch 1, Loss: 0.8660712838172913, Accuracy: 60.37810516357422, Test Loss: 0.9396776556968689, Test Accuracy: 67.0999984741211
