-
Notifications
You must be signed in to change notification settings - Fork 621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DALI with Keras in multi-GPU mode #1852
Comments
Hello, thanks for the question.
OK. I'll test the Keras + Horovod approach and post the solution here.
import os

import tensorflow as tf
import horovod.tensorflow.keras as hvd

# Horovod: initialize Horovod. This must happen before the hvd.local_size()
# query used for NUM_GPUS below.
hvd.init()

import nvidia.dali.plugin.tf as dali_tf
import nvidia.dali as dali
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types

# Path to the MNIST training data. The DALI_EXTRA_PATH environment variable
# must be set; os.environ[...] raises KeyError otherwise.
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')

TARGET = 0.8            # not referenced anywhere else in this script
BATCH_SIZE = 50         # batch size of each process's DALI pipeline
DROPOUT = 0.2
IMAGE_SIZE = 28
NUM_CLASSES = 10
HIDDEN_SIZE = 128
EPOCHS = 3
NUM_GPUS = hvd.local_size()                 # Horovod processes on this node
GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_GPUS   # samples consumed per step on this node
DATASET_SIZE = 60000
# Steps needed for one pass over DATASET_SIZE at GLOBAL_BATCH_SIZE per step.
ITERATIONS = DATASET_SIZE // GLOBAL_BATCH_SIZE
# DALI pipeline definition
class MnistPipeline(Pipeline):
    """DALI pipeline that reads MNIST, decodes it to grayscale and normalizes it.

    Outputs an ``(images, labels)`` pair per batch of ``BATCH_SIZE`` samples.
    Images go through CropMirrorNormalize with mean 0 and std 255 (i.e. they
    are divided by 255) and are laid out as CHW floats.

    Args:
        num_threads: number of CPU threads for the pipeline.
        path: directory read by ``ops.Caffe2Reader``.
        device: ``'gpu'`` to decode with the ``'mixed'`` decoder and run
            normalization and labels on the GPU; ``'cpu'`` keeps it all on CPU.
        device_id: GPU id passed to the ``Pipeline`` base class.
        shard_id: index of the shard this pipeline reads.
        num_shards: total number of shards the reader splits the data into.
        seed: random seed passed to the ``Pipeline`` base class.
    """

    def __init__(self, num_threads, path, device, device_id=0, shard_id=0, num_shards=1, seed=0):
        super(MnistPipeline, self).__init__(
            BATCH_SIZE, num_threads, device_id, seed)
        self.device = device
        # Compare strings by value. `is` tests object identity, which only
        # matches by accident of string interning (and raises a SyntaxWarning
        # on Python 3.8+); a dynamically built 'gpu' would silently go to CPU.
        use_gpu = device == 'gpu'
        self.reader = ops.Caffe2Reader(
            path=path, random_shuffle=True, shard_id=shard_id, num_shards=num_shards)
        self.decode = ops.ImageDecoder(
            device='mixed' if use_gpu else 'cpu',
            output_type=types.GRAY)
        self.cmn = ops.CropMirrorNormalize(
            device=device,
            output_dtype=types.FLOAT,
            image_type=types.GRAY,
            mean=[0.],
            std=[255.],
            output_layout="CHW")

    def define_graph(self):
        """Wire reader -> decoder -> normalization; return (images, labels)."""
        inputs, labels = self.reader(name="Reader")
        images = self.decode(inputs)
        if self.device == 'gpu':
            # Move labels to the GPU so both outputs live on the same device.
            labels = labels.gpu()
        images = self.cmn(images)
        return (
            images,
            labels
        )
# Parameters settings: where the DALI pipeline runs ('gpu' or 'cpu').
device = 'gpu'

# Parameters for DALI TF DATASET: the declared per-batch shape and dtype of
# each pipeline output, in the same (images, labels) order the pipeline returns.
shapes = ((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE), (BATCH_SIZE, 1))
dtypes = (tf.float32, tf.int32)
def dataset_options():
    """Return ``tf.data.Options`` with default optimizations and autotune off.

    Setting these options is best effort. If it fails (presumably because an
    attribute is unavailable in the installed TensorFlow version - confirm),
    a message with the reason is printed and the options object is still
    returned, so training can proceed.

    Returns:
        tf.data.Options: options to pass to ``Dataset.with_options``.
    """
    options = tf.data.Options()
    try:
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.autotune = False
    except Exception as err:
        # Deliberately best effort, but unlike a bare `except:` this neither
        # swallows KeyboardInterrupt/SystemExit nor hides the cause.
        print('Could not set TF Dataset Options:', err)
    return options
# Horovod: pin GPU to be used to process local rank (one GPU per process).
# When any GPU is present: enable memory growth on every physical GPU, then
# make only the GPU at index hvd.local_rank() visible to TensorFlow.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
# Build this process's DALI pipeline and wrap it as a TensorFlow dataset.
with tf.device('/gpu:0'):
    # Shard by the *global* rank/size so every Horovod process reads a
    # distinct part of the data. Sharding by local rank/size makes processes
    # with the same local rank on different nodes read the same shard with
    # the same (default) seed, i.e. train on duplicate batches. On a single
    # node the two choices are identical.
    mnist_pipeline = MnistPipeline(
        4, data_path, device,
        device_id=hvd.local_rank(),
        shard_id=hvd.rank(),
        num_shards=hvd.size())
    # NOTE(review): the pipeline is built with device_id=hvd.local_rank() but
    # DALIDataset is given device_id=0 - confirm which GPU DALI actually uses
    # for processes other than local rank 0.
    dataset = dali_tf.DALIDataset(
        pipeline=mnist_pipeline,
        batch_size=BATCH_SIZE,
        output_shapes=shapes,
        output_dtypes=dtypes,
        num_threads=4,
        device_id=0)
    dataset = dataset.with_options(dataset_options())
# Classifier: flatten each IMAGE_SIZE x IMAGE_SIZE input, one ReLU hidden
# layer followed by dropout, then a softmax over NUM_CLASSES.
mnist_model = tf.keras.Sequential()
mnist_model.add(tf.keras.layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='images'))
mnist_model.add(tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)))
mnist_model.add(tf.keras.layers.Dense(HIDDEN_SIZE, activation='relu'))
mnist_model.add(tf.keras.layers.Dropout(DROPOUT))
mnist_model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
# Horovod: scale the base Adam learning rate (0.001) by the number of
# workers, and wrap the optimizer in hvd.DistributedOptimizer.
opt = hvd.DistributedOptimizer(tf.optimizers.Adam(0.001 * hvd.size()))

# Horovod: `experimental_run_tf_function=False` ensures TensorFlow uses
# hvd.DistributedOptimizer() to compute gradients. The model's last layer is
# a softmax, which matches SparseCategoricalCrossentropy's default
# (probabilities, not logits).
mnist_model.compile(
    optimizer=opt,
    loss=tf.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
    experimental_run_tf_function=False,
)
callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: This callback must be in the list before the ReduceLROnPlateau,
    # TensorBoard or other metrics-based callbacks.
    hvd.callbacks.MetricAverageCallback(),

    # Horovod: starting at the full scaled learning rate (0.001 * hvd.size(),
    # as set on the optimizer) from the very beginning leads to worse final
    # accuracy. Ramp `lr = 0.001` ---> `lr = 0.001 * hvd.size()` during the
    # first three epochs. See https://arxiv.org/abs/1706.02677 for details.
    # (EPOCHS is also 3 in this script.)
    # NOTE(review): newer Horovod releases take an `initial_lr` argument here -
    # confirm against the installed Horovod version.
    hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
]
# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
# if hvd.rank() == 0:
# callbacks.append(tf.keras.callbacks.ModelCheckpoint('/tmp/checkpoint-{epoch}.h5'))
# Horovod: write logs on worker 0.
verbose = 0 if hvd.local_rank() > 0 else 1
# Train the model.
# Horovod: adjust number of steps based on number of GPUs.
mnist_model.fit(dataset, steps_per_epoch=ITERATIONS // hvd.size(), callbacks=callbacks, epochs=EPOCHS, verbose=verbose) @jpnavarro-nv This script should get you started with Horovod+Keras+DALI Dataset. You can run it with: |
Wow! This is outstanding, @awolant. Many thanks for sharing!
Hi.
Is there any way to use DALI with Keras and `multi_gpu_model`?
I couldn't find any example of this.
The text was updated successfully, but these errors were encountered: