### MNIST multi GPUs implementation
inspiration from https://keras.io/guides/distributed_training_with_tensorflow/

In [18]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow import keras
from tensorflow.keras import layers

In [19]:
def get_dataset():
    # batch_size = 32
    
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    #print(x_test.shape)

    x_train = x_train.reshape((60000, 28 * 28))
    x_train = x_train.astype("float32") / 255
    x_test = x_test.reshape((10000, 28 * 28))
    x_test = x_test.astype("float32") / 255
    y_train = y_train.astype("float32")
    y_test = y_test.astype("float32")

    return (
        #tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size),
        #tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),
        (x_train, y_train),
        (x_test, y_test)
    )

In [20]:
def get_compiled_model():
    # Make a simple 2-layer densely-connected neural network.
    model = keras.Sequential([
        layers.Dense(512, activation="relu"),
        layers.Dense(10, activation="softmax")
    ])

    model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
    return model

In [21]:
# Create a MirroredStrategy.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))

# Open a strategy scope.
with strategy.scope():
    # Everything that creates variables should be under the strategy scope.
    # In general this is only model construction & `compile()`.
    model = get_compiled_model()

    # Train the model on all available devices.
    (x_train, y_train), (x_test, y_test) = get_dataset()
    model.fit(x_train, y_train, epochs=5, batch_size=64)

    # Test the model on all available devices.
    test_loss, test_acc = model.evaluate(x_test, y_test)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2


2024-02-01 17:19:24.511535: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:656] In AUTO-mode, and switching to DATA-based sharding, instead of FILE-based sharding as we cannot find appropriate reader dataset op(s) to shard. Error: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_9"
op: "FlatMapDataset"
input: "PrefetchDataset/_8"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_slice_batch_indices_90743"
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT64
    }
  }
}
. Consider either turning off auto-sharding or switching the auto_shard_policy to DATA to shard this dataset. You can do this by creating a new `tf.data.Options()` object then setting `options.experimental_distribute.au

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


2024-02-01 17:19:40.123936: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:656] In AUTO-mode, and switching to DATA-based sharding, instead of FILE-based sharding as we cannot find appropriate reader dataset op(s) to shard. Error: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_9"
op: "FlatMapDataset"
input: "PrefetchDataset/_8"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_slice_batch_indices_101716"
    }
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT64
    }
  }
}
. Consider either turning off auto-sharding or switching the auto_shard_policy to DATA to shard this dataset. You can do this by creating a new `tf.data.Options()` object then setting `options.experimental_distribute.a



In [23]:
print(f"test_acc: {test_acc}")

test_acc: 0.98089998960495
