DALI with Keras in multi-GPU mode #1852

Closed
jpnavarro-nv opened this issue Apr 2, 2020 · 4 comments
Labels
question Further information is requested

Comments

@jpnavarro-nv

Hi.
Is there any way to use DALI with Keras and 'multi_gpu_model'?
I didn't find any example of this anywhere.

@awolant
Contributor

awolant commented Apr 2, 2020

Hello, thanks for the question.
We haven't tested DALI with multi_gpu_model, but as far as I know that function is deprecated, if not already removed, so I don't think we will ever try it.
If you want multi-GPU training with Keras and DALI, Horovod is known to work, or at least it did some time ago.
We are also working on making DALI compatible with MirroredStrategy, but it's not yet at the point where you can use it on the GPU. If you want DALI on CPU only, look into our tests here.
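For reference, the generic Horovod + Keras training pattern looks roughly like the sketch below. This is a minimal outline with a random stand-in dataset (not part of the original script); the DALI-backed version is in the full script later in this thread.

import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()

# Pin one GPU per Horovod process.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')

# Stand-in dataset; this is where a DALI-backed tf.data.Dataset would go.
images = tf.random.uniform([640, 28, 28])
labels = tf.random.uniform([640], maxval=10, dtype=tf.int32)
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(32).repeat()

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Scale the learning rate by the number of workers and wrap the optimizer.
opt = hvd.DistributedOptimizer(tf.optimizers.Adam(0.001 * hvd.size()))
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'],
              experimental_run_tf_function=False)

# Broadcast initial weights from rank 0 so all workers start identical.
model.fit(dataset,
          steps_per_epoch=20 // hvd.size(),
          epochs=1,
          callbacks=[hvd.callbacks.BroadcastGlobalVariablesCallback(0)],
          verbose=1 if hvd.rank() == 0 else 0)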

@JanuszL JanuszL added the question Further information is requested label Apr 2, 2020
@jpnavarro-nv
Author

OK. I'll test the Keras + Horovod approach and post the solution here.

@awolant
Contributor

awolant commented Apr 3, 2020

import tensorflow as tf
import horovod.tensorflow.keras as hvd


# Horovod: initialize Horovod.
hvd.init()

import nvidia.dali.plugin.tf as dali_tf
import nvidia.dali as dali
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types

import os


# Path to MNIST dataset
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')



TARGET = 0.8
BATCH_SIZE = 50
DROPOUT = 0.2
IMAGE_SIZE = 28
NUM_CLASSES = 10
HIDDEN_SIZE = 128
EPOCHS = 3
NUM_GPUS = hvd.local_size()
GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_GPUS
DATASET_SIZE = 60000
ITERATIONS = DATASET_SIZE // GLOBAL_BATCH_SIZE


# DALI pipeline definition
class MnistPipeline(Pipeline):
    def __init__(self, num_threads, path, device, device_id=0, shard_id=0, num_shards=1, seed=0):
        super(MnistPipeline, self).__init__(
            BATCH_SIZE, num_threads, device_id, seed)
        self.device = device
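        # Each Horovod rank reads a different shard of the dataset (shard_id / num_shards).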
        self.reader = ops.Caffe2Reader(
            path=path, random_shuffle=True, shard_id=shard_id, num_shards=num_shards)
        self.decode = ops.ImageDecoder(
            device='mixed' if device == 'gpu' else 'cpu',
            output_type=types.GRAY)
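        # Normalize pixel values from [0, 255] to [0, 1] and output CHW layout.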
        self.cmn = ops.CropMirrorNormalize(
            device=device,
            output_dtype=types.FLOAT,
            image_type=types.GRAY,
            mean=[0.],
            std=[255.],
            output_layout="CHW")

    def define_graph(self):
        inputs, labels = self.reader(name="Reader")
        images = self.decode(inputs)
        if self.device == 'gpu':
            labels = labels.gpu()
        images = self.cmn(images)

        return (
            images,
            labels
        )


# Parameters settings
device = 'gpu'

# Parameters for DALI TF DATASET
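# Shapes below include the batch dimension: DALIDataset yields whole batches
# produced by the pipeline.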
shapes = (
    (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE),
    (BATCH_SIZE, 1)
    )
dtypes = (
    tf.float32,
    tf.int32
    )


def dataset_options():
    options = tf.data.Options()
    try:
        options.experimental_optimization.apply_default_optimizations = False
        options.experimental_optimization.autotune = False
    except AttributeError:
        print('Could not set TF Dataset Options')

    return options


# Horovod: pin GPU to be used to process local rank (one GPU per process)
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')


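# Each process sees only the GPU selected above, so '/gpu:0' and device_id=0
# below refer to this rank's GPU.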
with tf.device('/gpu:0'):
    mnist_pipeline = MnistPipeline(
        4, data_path, device, device_id=hvd.local_rank(), shard_id=hvd.local_rank(), num_shards=hvd.local_size())

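    # Wrap the DALI pipeline as a tf.data.Dataset that Keras fit() can consume.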
    dataset = dali_tf.DALIDataset(
        pipeline=mnist_pipeline,
        batch_size=BATCH_SIZE,
        output_shapes=shapes,
        output_dtypes=dtypes,
        num_threads=4,
        device_id=0)

    dataset = dataset.with_options(dataset_options())

    mnist_model = tf.keras.Sequential([
        tf.keras.layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE), name='images'),
        tf.keras.layers.Flatten(input_shape=(IMAGE_SIZE, IMAGE_SIZE)),
        tf.keras.layers.Dense(HIDDEN_SIZE, activation='relu'),
        tf.keras.layers.Dropout(DROPOUT),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])

    # Horovod: adjust learning rate based on number of GPUs.
    opt = tf.optimizers.Adam(0.001 * hvd.size())

    # Horovod: add Horovod DistributedOptimizer.
    opt = hvd.DistributedOptimizer(opt)

    # Horovod: Specify `experimental_run_tf_function=False` to ensure TensorFlow
    # uses hvd.DistributedOptimizer() to compute gradients.
    mnist_model.compile(loss=tf.losses.SparseCategoricalCrossentropy(),
                        optimizer=opt,
                        metrics=['accuracy'],
                        experimental_run_tf_function=False)

    callbacks = [
        # Horovod: broadcast initial variable states from rank 0 to all other processes.
        # This is necessary to ensure consistent initialization of all workers when
        # training is started with random weights or restored from a checkpoint.
        hvd.callbacks.BroadcastGlobalVariablesCallback(0),

        # Horovod: average metrics among workers at the end of every epoch.
        #
        # Note: This callback must be in the list before the ReduceLROnPlateau,
        # TensorBoard or other metrics-based callbacks.
        hvd.callbacks.MetricAverageCallback(),

        # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
        # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
        # the first three epochs. See https://arxiv.org/abs/1706.02677 for details.
        hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=3, verbose=1),
    ]

    # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
    # if hvd.rank() == 0:
    #     callbacks.append(tf.keras.callbacks.ModelCheckpoint('/tmp/checkpoint-{epoch}.h5'))

    # Horovod: write logs on worker 0.
    verbose = 0 if hvd.local_rank() > 0 else 1

    # Train the model.
    # Horovod: adjust number of steps based on number of GPUs.
    mnist_model.fit(
        dataset,
        steps_per_epoch=ITERATIONS // hvd.size(),
        callbacks=callbacks,
        epochs=EPOCHS,
        verbose=verbose)

@jpnavarro-nv This script should get you started with Horovod+Keras+DALI Dataset. You can run it with:
horovodrun -np 2 -H localhost:4 python dev/horovod_test.py

@jpnavarro-nv
Author

Wow! This is outstanding, @awolant. Many thanks for sharing!
