
PipeModeDataset leads to infinite loop / memory exhaustion when re-using dataset with tf.keras #46

Open
fmannhardt opened this issue Aug 12, 2019 · 10 comments


@fmannhardt

I have tried several configurations of PipeModeDataset together with tf.keras, and I run into trouble when re-using the same dataset (e.g. the validation set) in both fit and evaluate. It seems that on the second call the SageMaker instance exhausts all available GPU memory and goes into some kind of loop.

This is my current training script (I will try to strip it down further). It works perfectly when using `File` mode but fails on the evaluate call when executed in `Pipe` mode:

import argparse, os

import logging
import math
import json

logging.getLogger().setLevel(logging.INFO)
logging.getLogger("tensorflow").setLevel(logging.INFO)
import tensorflow as tf

import glob

def load_data_as_dataset(channel_name, channel, data_config):

    def get_filenames(channel):
        return(glob.glob(channel + "/*.tfrecord"))

    mode = data_config[channel_name]['TrainingInputMode']    

    logging.info("Running {} in mode: {}".format(channel_name, mode))

    if mode == 'Pipe':
        # Construct a `PipeModeDataset` reading from a 'training' channel, using
        # the TF Record encoding.        
        from sagemaker_tensorflow import PipeModeDataset
        ds = PipeModeDataset(channel=channel_name, record_format='TFRecord')
    else:    
        filenames = get_filenames(channel)
        logging.info("Loading files {}".format(filenames))
        ds = tf.data.TFRecordDataset(filenames) 
    
    return ds

def extract_example(example_proto):

    image_feature_description = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/class/obj1_center_x': tf.io.FixedLenFeature([], tf.float32),
        'image/class/obj1_center_y': tf.io.FixedLenFeature([], tf.float32),
        'image/class/obj2_center_x': tf.io.FixedLenFeature([], tf.float32),
        'image/class/obj2_center_y': tf.io.FixedLenFeature([], tf.float32)
    }

    feature = tf.io.parse_single_example(example_proto, image_feature_description)

    image = feature['image/encoded']
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)

    return (image, tf.convert_to_tensor([feature['image/class/obj1_center_x'], 
                                         feature['image/class/obj1_center_y'],
                                         feature['image/class/obj2_center_x'],
                                         feature['image/class/obj2_center_y']]))

def train_preprocess(image, label):
    
    image = tf.image.random_brightness(image, max_delta=32.0 / 255.0)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)

    #Make sure the image is still in [0, 1]
    image = tf.clip_by_value(image, 0.0, 1.0)

    return image, label


def build_model(input_shape):
    
    model = tf.keras.Sequential()

    model.add(tf.keras.layers.Conv2D(8, kernel_size=(3,3), padding='same', 
                                     input_shape=input_shape))
    model.add(tf.keras.layers.LeakyReLU())
    
    model.add(tf.keras.layers.Conv2D(8, kernel_size=(3,3), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())    
    
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2), strides=2))

    model.add(tf.keras.layers.Conv2D(16, kernel_size=(3,3), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Conv2D(16, kernel_size=(3,3), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())
    
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2,2), strides=2))

    model.add(tf.keras.layers.Conv2D(32, kernel_size=(3,3), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.Conv2D(32, kernel_size=(3,3), padding='same'))
    model.add(tf.keras.layers.LeakyReLU())

    model.add(tf.keras.layers.GlobalAveragePooling2D())

    model.add(tf.keras.layers.Dense(32))
    model.add(tf.keras.layers.LeakyReLU())
    model.add(tf.keras.layers.Dropout(rate = 0.1))

    model.add(tf.keras.layers.Dense(4, activation='linear'))
    
    return (model)

if __name__ == '__main__':
        
    parser = argparse.ArgumentParser()
    
    #
    # Standard parameters required by Sagemaker
    #
    parser.add_argument('--gpu-count', type=int, default=os.environ['SM_NUM_GPUS'])
    parser.add_argument('--output-dir', type=str, default=os.environ.get('SM_OUTPUT_DIR'))        
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    
    parser.add_argument('--data-config',type=json.loads,default=os.environ.get('SM_INPUT_DATA_CONFIG'))

    #
    # Input Channels
    #
    parser.add_argument('--training', type=str, required=False, default=os.environ.get('SM_CHANNEL_TRAINING'))
    parser.add_argument('--validation', type=str, required=False, default=os.environ.get('SM_CHANNEL_VALIDATION'))       

    #
    # Input Parameters
    #
    parser.add_argument('--num-channels', type=int, default=3)
    parser.add_argument('--img-height', type=int, default=416)
    parser.add_argument('--img-width', type=int, default=416)
    
    parser.add_argument('--num-samples', type=int, required=True)
    parser.add_argument('--num-validation', type=int, default=64)
    
    #
    # Training Parameters
    #
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=0.01)
    parser.add_argument('--batch-size', type=int, default=16)
        
    args, _ = parser.parse_known_args()
    
    # NCWH format
    input_shape = (args.img_width,
                   args.img_height,
                   args.num_channels)     
    
    #
    # Build model
    #
    
    model = build_model(input_shape)
       
    model.compile(optimizer = tf.keras.optimizers.Adam(lr = args.learning_rate),
                  loss= "mse", 
                  metrics=['mse', 'mae'])
    
    
    len_train = args.num_samples
    logging.info("Training samples: {}".format(len_train))
    
    len_val = args.num_validation
    logging.info("Validation samples: {}".format(len_val))
    
    #
    # Prepare tf.dataset input
    #
    
    x_train = load_data_as_dataset("training", args.training, args.data_config)    
    x_train = x_train.shuffle(buffer_size = 8 * args.batch_size)    
    x_train = x_train.repeat()
    
    x_train = x_train.map(extract_example, num_parallel_calls=4)
    x_train = x_train.map(train_preprocess, num_parallel_calls=4)
    x_train = x_train.batch(args.batch_size)
    
    x_train = x_train.prefetch(1)
    
    if args.validation is not None and args.num_validation > 0:
        x_val = load_data_as_dataset("validation", args.validation, args.data_config)
        x_val = x_val.repeat()        
        x_val = x_val.map(extract_example, num_parallel_calls=4)
        x_val = x_val.batch(args.batch_size)
        
        x_val = x_val.prefetch(1)
    else:
        x_val = None
    
    #
    # Train model
    #
    
    logging.info("Batch size: {}".format(args.batch_size))
    
    steps_per_epoch = len_train // args.batch_size
    logging.info("Training steps per epoch {}".format(steps_per_epoch))
    
    validation_steps = len_val // args.batch_size
    logging.info("Validation steps per epoch {}".format(validation_steps))    
    
    if x_val is not None:
        history = model.fit(x = x_train,
                            validation_data = x_val, 
                            epochs = args.epochs,
                            steps_per_epoch = steps_per_epoch,
                            validation_steps = validation_steps,
                            verbose = 2)
    else:
        history = model.fit(x = x_train,
                            epochs = args.epochs,
                            steps_per_epoch = steps_per_epoch,
                            verbose = 2)        
    
    logging.info("Result: {}".format(history.history))
    
    #
    # Evaluate model
    #    
    
    if x_val is not None:
        score = model.evaluate(x_val,
                       steps=validation_steps,
                       verbose=2)        
        
        logging.info('Validation: {}'.format(list(zip(score, model.metrics_names))))
              
    #
    # Save/Export model
    #

    tf.contrib.saved_model.save_keras_model(model, os.path.join(args.model_dir, 'model/1'))
    model.save(os.path.join(args.output_dir, 'model.h5'))

I tried framework versions v1.13 and v1.14; both show the same behaviour, and it seems to be related to re-using the dataset after model.fit is done. If I don't call model.evaluate, everything is fine.
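Stripped down, the failing pattern boils down to something like the sketch below (hypothetical record layout with a single float feature and label, just to illustrate the dataset re-use; the real pipeline is the full script above):

import tensorflow as tf
from sagemaker_tensorflow import PipeModeDataset

def parse(example_proto):
    # Hypothetical record layout: one float feature 'x' and one float label 'y'.
    parsed = tf.io.parse_single_example(example_proto, {
        'x': tf.io.FixedLenFeature([1], tf.float32),
        'y': tf.io.FixedLenFeature([1], tf.float32),
    })
    return parsed['x'], parsed['y']

def make_pipe(channel):
    ds = PipeModeDataset(channel=channel, record_format='TFRecord')
    return ds.repeat().map(parse).batch(16).prefetch(1)

x_train = make_pipe('training')
x_val = make_pipe('validation')

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
model.compile(optimizer='adam', loss='mse')

# First use of x_val (as validation_data) works ...
model.fit(x_train, validation_data=x_val, epochs=2,
          steps_per_epoch=10, validation_steps=4, verbose=2)

# ... but this second use of the same x_val object hangs in Pipe mode.
model.evaluate(x_val, steps=4, verbose=2)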

Unfortunately there is not much logging output except for this warning:

2019-08-12 07:44:06.993021: W tensorflow/core/framework/dataset.cc:393] Input of PipeModeDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
2019-08-12 07:44:07.056622: W tensorflow/core/framework/dataset.cc:393] Input of PipeModeDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.

@ChoiByungWook (Contributor) commented Aug 30, 2019

Hello @fmannhardt,

Apologies for the late response.

Let me look into this and I'll try to respond as soon as possible.

This looks like it may require some dedicated investigation time to root-cause. Let me speak with my team about this.

Thank you for your patience!

@fmannhardt (Author)

Thanks. If you need more information let me know.

@mvsusp (Contributor) commented Sep 3, 2019

Hi @fmannhardt,

I've noticed that you are also passing validation_data = x_val when you call model.fit. Are you sure the issue is happening during evaluate rather than during fit?

Would you mind sharing a complete example allowing us to reproduce the issue?

Thanks for using SageMaker

Márcio

@fmannhardt (Author)

I will try to set up a complete example including data.

Based on the log messages, the issue appears on the second use of the dataset: it hangs on the call to evaluate, as that is the second time the x_val dataset is used.
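One workaround I am considering in the meantime is to rebuild the validation pipeline right before evaluate, so the exact same dataset object is never used twice. This is an untested sketch reusing the helper functions from my script above; I don't know yet whether Pipe mode even allows constructing the validation channel dataset a second time:

def make_val_dataset():
    # Rebuild the validation pipeline from scratch instead of reusing x_val.
    ds = load_data_as_dataset("validation", args.validation, args.data_config)
    ds = ds.repeat()
    ds = ds.map(extract_example, num_parallel_calls=4)
    ds = ds.batch(args.batch_size)
    return ds.prefetch(1)

history = model.fit(x=x_train,
                    validation_data=make_val_dataset(),
                    epochs=args.epochs,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    verbose=2)

# A fresh dataset object for the second pass instead of the one used in fit.
score = model.evaluate(make_val_dataset(), steps=validation_steps, verbose=2)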

@mvsusp (Contributor) commented Sep 5, 2019

Thanks @fmannhardt, a minimal example would be very helpful for diagnosing the issue.

@kafka399 commented Nov 5, 2019

Same issue; here is an example of the code:

train_data_single = load_dataset('training', batch_size)
val_data_single = load_dataset('validation', batch_size)

print(train_data_single)

model = tf.keras.models.Sequential()

model.add(tf.keras.layers.LSTM(hidden_size, input_shape=(window_size, features), return_sequences=True))
model.add(tf.keras.layers.Dropout(rate = dropout))
model.add(tf.keras.layers.LSTM(hidden_size, return_sequences=False))
model.add(tf.keras.layers.Dropout(rate= dropout))
model.add(tf.keras.layers.Dense(1, kernel_initializer='normal'))
model.compile(loss='mse', optimizer='adam',metrics=['mse', 'mae', 'mape'])  # Using mse loss results in faster convergence

model.summary()

if gpu_count > 1:
    model = multi_gpu_model(model, gpus=gpu_count)
    print('going GPU')


single_step_history = model.fit(train_data_single,
                                epochs=epochs,
                                steps_per_epoch=10,
                                validation_steps=2,
                                validation_data=val_data_single)

If I remove validation_data, the training job completes successfully; the same happens if the number of epochs is removed or set to one.
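To make the conditions explicit, roughly (using the variables from the snippet above):

# Completes: no validation_data.
model.fit(train_data_single, epochs=epochs, steps_per_epoch=10)

# Completes: validation_data supplied, but only a single epoch.
model.fit(train_data_single, epochs=1, steps_per_epoch=10,
          validation_steps=2, validation_data=val_data_single)

# Hangs: validation_data supplied and epochs > 1.
model.fit(train_data_single, epochs=epochs, steps_per_epoch=10,
          validation_steps=2, validation_data=val_data_single)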

@athewsey commented Apr 17, 2020

I think I'm seeing the same behaviour as @kafka399, but I'm not sure whether it's the same as this parent issue or should be tracked separately: for me the hang is at 0% CPU utilization with static memory consumption, which looks more like a deadlock than an infinite loop.

Setup

My script creates a tf.keras.Model object detection model and trains it in two stages, unfreezing some layers after the first round. The setup is something like:

ds_train = PipeModeDataset(channel="train") \
    .repeat(args.epochs) \
    .batch(2) \
    .map(data.get_tf_parse_mapper(args.data_shape, randomize=True)) \
    .batch(args.batch_size) \
    .map(data.get_tf_train_batch_mapper(args.batch_size, args.data_shape, args.num_classes))
ds_val = [Pretty much the same with different channel name]

# Do the pre-training:
train_model.fit(
    ds_train,
    epochs=args.epochs_stabilize,
    initial_epoch=0,
    shuffle=False,
    steps_per_epoch=args.num_samples_train // args.batch_size,
    validation_data=ds_val,
    validation_steps=args.num_samples_validation // args.batch_size,
    verbose=2,
)

# [Unfreeze some layers] then recompile the model:
train_model.compile(
    optimizer=Adam(lr=1e-4),
)

# Train for remaining epochs:
train_model.fit(
    ds_train,
    callbacks=train_callbacks,
    epochs=args.epochs,
    initial_epoch=args.epochs_stabilize,
    shuffle=False,
    steps_per_epoch=args.num_samples_train // args.batch_size,
    validation_data=ds_val,
    validation_steps=args.num_samples_validation // args.batch_size,
    verbose=2,
)

Findings

On SM TensorFlow container v1.15.2 (which I was originally targeting), the code ran (both training rounds) as long as I removed either the validation_* arguments or the epochs argument (although one epoch of training was not very useful). With both present, it would freeze in Epoch 1 of the first training round.

I noticed in the README that PipeModeDataset only advertises support for TF v1.7-1.12, so I tried dropping back to TF v1.12.

This fixed the issue, but only if I made sure my dataset sizes were exact multiples of the batch size; otherwise it froze for me in the first epoch of the second model.fit() call.

So my asks would be:

  • How should we deal with partial final batches? Is there something I should do differently in my PipeModeDataset pipeline to handle it?
  • Is there anything we can do to get it working on newer versions of TensorFlow?

Edit: Some additional observations:

  • Adding drop_remainder=True in the Dataset batch() call doesn't seem to help (see the trimming sketch after this list)
  • Having a .prefetch(4) at the end of my Dataset pipeline also caused the script to freeze at the start of the second round of training even on TFv1.12 (though the first fit completed OK). My data does not contain an exact multiple of 4 batches, so I guess that's related to the batch-size requirement.
  • On TFv1.13, the first round of training completes but the script freezes at the start of the second (before printing "Epoch 1/X")
  • TFv1.14 freezes on first epoch of first training round when a validation dataset is supplied, like v1.15.2. (after printing "Epoch 1/X")
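Since drop_remainder doesn't help, the only partial-batch mitigation I can think of so far is trimming each pass to a whole number of batches before batching. This is an untested sketch, simplified relative to my actual pipeline above (parse_fn stands in for my mapping functions):

from sagemaker_tensorflow import PipeModeDataset

# Untested: drop the trailing records that would otherwise form a partial batch,
# so the pipeline only ever emits full batches of args.batch_size records.
whole_batches = args.num_samples_train // args.batch_size
records_per_pass = whole_batches * args.batch_size

ds_train = PipeModeDataset(channel="train") \
    .take(records_per_pass) \
    .map(parse_fn, num_parallel_calls=4) \
    .batch(args.batch_size) \
    .repeat(args.epochs)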

@athewsey

Here's my attempt at a full minimal-ish reproducible example on TF 1.15, with MNIST digits classification:

https://github.com/athewsey/sagemaker-workshop-101/tree/feat/pipemode/migration_challenge_keras_image

It's adapted from a workshop, so to run you need to:

  • First run (at least the first parts of) the "Local Notebook" (on SageMaker) to download MNIST and process (a fraction of) it into folders of JPEGs
  • Then create a bucket and run "Instructions" (the main notebook), which uploads the data to S3 and starts a training job. Don't click "Run All", because the S3 upload happens in the background and needs to complete before the training job starts.

On TF v1.15 (as in the repo) the training freezes on the first epoch with 0% CPU/GPU utilization. On TF v1.12 (taking care that the batch size is a factor of both the number of training samples and the number of test samples), the training completes successfully.

@oreade16

I'm also having similar struggles here. Does anyone know how to eliminate this error:

Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.

@Dex247 commented Aug 4, 2022

Same issue here. In my case I am trying to train my model on a CPU machine with 6 GB of RAM. Could it be that the memory is not adequate for the model training?

Same error: GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
