Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AttributeError: 'NoneType' object has no attribute 'taskgraph' #27

Open
Waterpine opened this issue Apr 10, 2023 · 3 comments
Open

AttributeError: 'NoneType' object has no attribute 'taskgraph' #27

Waterpine opened this issue Apr 10, 2023 · 3 comments

Comments

@Waterpine
Copy link

Waterpine commented Apr 10, 2023

Hi EPL team,

When I use epl library to train the following code:

import os
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import tensorflow as tf
import epl

def preprocess_image(image):
    # Resize and crop
    width, height = image.size
    if width > height:
        new_width = int(224 * width / height)
        image = image.resize((new_width, 224))
        left = (new_width - 224) / 2
        image = image.crop((left, 0, left + 224, 224))
    else:
        new_height = int(224 * height / width)
        image = image.resize((224, new_height))
        top = (new_height - 224) / 2
        image = image.crop((0, top, 224, top + 224))

    # Normalize pixel values
    image = np.array(image, dtype=np.float32) / 255.0
    mean = np.array([0.485, 0.456, 0.406])[None, None, :]
    std = np.array([0.229, 0.224, 0.225])[None, None, :]
    image = (image - mean) / std

    return image


def load_and_preprocess_image(path):
    image = Image.open(path).convert('RGB')
    return preprocess_image(image)


train_image_dir = '/users/Master/imagenet/train'
val_image_dir = '/users/Master/imagenet/val'
class_names = sorted(os.listdir(train_image_dir))
num_classes = len(class_names)

train_image_paths = []
train_labels = []
val_image_paths = []
val_labels = []

for label, class_name in enumerate(class_names):
    train_class_dir = os.path.join(train_image_dir, class_name)
    val_class_dir = os.path.join(val_image_dir, class_name)

    for img_name in os.listdir(train_class_dir):
        img_path = os.path.join(train_class_dir, img_name)
        train_image_paths.append(img_path)
        train_labels.append(label)

    for img_name in os.listdir(val_class_dir):
        img_path = os.path.join(val_class_dir, img_name)
        val_image_paths.append(img_path)
        val_labels.append(label)


def load_images_parallel(image_paths, num_workers=16):
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        images = list(executor.map(load_and_preprocess_image, image_paths))
    return np.array(images)


def load_images_chunk(image_paths, labels, batch_size):
    num_batches = int(np.ceil(len(image_paths) / batch_size))
    for i in range(num_batches):
        batch_image_paths = image_paths[i * batch_size:(i + 1) * batch_size]
        batch_labels = labels[i * batch_size:(i + 1) * batch_size]
        batch_images = load_images_parallel(batch_image_paths)
        batch_labels_one_hot = tf.keras.utils.to_categorical(batch_labels, num_classes=num_classes)
        yield batch_images, batch_labels_one_hot


def conv2d_bn(x, filters, kernel_size, strides=1, padding='same', activation=tf.nn.relu, name=None):
    x = tf.layers.conv2d(x, filters, kernel_size, strides=strides, padding=padding, use_bias=False, name=name)
    x = tf.layers.batch_normalization(x, training=True)
    if activation is not None:
        x = activation(x)
    return x


def identity_block(input_tensor, filters, stage, block):
    filters1, filters2, filters3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = conv2d_bn(input_tensor, filters1, 1, name=conv_name_base + '2a')
    x = conv2d_bn(x, filters2, 3, name=conv_name_base + '2b')
    x = conv2d_bn(x, filters3, 1, activation=None, name=conv_name_base + '2c')

    x = tf.add(x, input_tensor)
    x = tf.nn.relu(x)
    return x


def conv_block(input_tensor, filters, stage, block, strides=2):
    filters1, filters2, filters3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    x = conv2d_bn(input_tensor, filters1, 1, strides=strides, name=conv_name_base + '2a')
    x = conv2d_bn(x, filters2, 3, name=conv_name_base + '2b')
    x = conv2d_bn(x, filters3, 1, activation=None, name=conv_name_base + '2c')

    shortcut = conv2d_bn(input_tensor, filters3, 1, strides=strides, activation=None, name=conv_name_base + '1')

    x = tf.add(x, shortcut)
    x = tf.nn.relu(x)
    return x


def resnet50(input_tensor, classes):
    x = conv2d_bn(input_tensor, 64, 7, strides=2, name='conv1')
    x = tf.layers.max_pooling2d(x, 3, strides=2, padding='same', name='pool1')

    x = conv_block(x, [64, 64, 256], stage=2, block='a', strides=1)
    x = identity_block(x, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, [128, 128, 512], stage=3, block='d')

    x = conv_block(x, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, [256, 256, 1024], stage=4, block='f')

    x = conv_block(x, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, [512, 512, 2048], stage=5, block='c')

    x = tf.layers.average_pooling2d(x, 7, strides=1, padding='valid', name='pool5')
    x = tf.layers.flatten(x)
    x = tf.layers.dense(x, classes, activation=None, name='fc1000')

    return x


def run_model():
    with tf.Session() as sess:
        input_tensor = tf.placeholder(tf.float32, shape=[None, 224, 224, 3], name="input_image")
        labels_tensor = tf.placeholder(tf.float32, shape=[None, num_classes], name="labels")
        learning_rate = 0.001

        logits = resnet50(input_tensor, num_classes)

        loss_op = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=labels_tensor))
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
        train_op = optimizer.minimize(loss_op)

        correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(labels_tensor, 1))
        accuracy_op = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

        sess.run(tf.global_variables_initializer())

        epochs = 10
        batch_size = 64

        for epoch in range(epochs):
            step = 0

            for batch_images, batch_labels_one_hot in load_images_chunk(train_image_paths, train_labels, batch_size):
                _, loss, accuracy = sess.run(
                    [train_op, loss_op, accuracy_op],
                    feed_dict={input_tensor: batch_images, labels_tensor: batch_labels_one_hot}
                )
                print(f"Epoch {epoch + 1}/{epochs}, Step: {step}, Loss: {loss:.4f}, Accuracy: {accuracy:.4f}")
                step = step + 1

            # Validate the model
            val_accuracy_list = []
            for batch_images, batch_labels_one_hot in load_images_chunk(val_image_paths, val_labels, batch_size):
                accuracy = sess.run(accuracy_op,
                                    feed_dict={input_tensor: batch_images, labels_tensor: batch_labels_one_hot})
                val_accuracy_list.append(accuracy)
            val_accuracy = np.mean(val_accuracy_list)
            print(f"Validation Accuracy: {val_accuracy:.4f}")


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    config_json = {}
    epl.init(epl.Config(config_json))
    print(epl.Env.get().cluster.gpu_num_per_worker)
    if epl.Env.get().cluster.gpu_num_per_worker > 1:
        # Avoid NCCL hang.
        os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    epl.set_default_strategy(epl.replicate(device_count=1))
    run_model()

I am confronted with the following issue:
Traceback (most recent call last):
File "resnet50_split3.py", line 203, in
run_model()
File "resnet50_split3.py", line 164, in run_model
sess.run(tf.global_variables_initializer())
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 453, in run
assign_ops = _init_local_resources(self, fn)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 416, in _init_local_resources
assign_ops = broadcast_variables()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 339, in broadcast_variables
bcast_variables = taskgraph.get_variables(replica_idx)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/ir/taskgraph.py", line 409, in get_variables
if id(var_tensor.taskgraph) != id(self):
AttributeError: 'NoneType' object has no attribute 'taskgraph'

Could you give me a hand when you are free? Thank you very much!

@SeaOfOcean
Copy link
Collaborator

you can use tf.train.MonitoredTrainingSession instead of tf.Session , and global_variables_initializer is not necessary when using MonitoredTrainingSession, you can refer https://github.com/alibaba/FastNN/blob/73b70c633117ccff4f1a270f461bacb96e0fc4ee/resnet/resnet_dp.py#L67

@Waterpine
Copy link
Author

Thanks for your reply! I modify the code as follows:

import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
import epl
import os


def conv_bn_relu(inputs, filters, kernel_size, stride, training):
    conv = tf.layers.conv2d(inputs, filters, kernel_size, strides=stride, padding='SAME', use_bias=False)
    bn = tf.layers.batch_normalization(conv, training=training)
    relu = tf.nn.relu(bn)
    return relu


def bottleneck_block(inputs, filters, stride, training):
    shortcut = inputs
    out = conv_bn_relu(inputs, filters, 1, 1, training)
    out = conv_bn_relu(out, filters, 3, stride, training)
    out = conv_bn_relu(out, 4 * filters, 1, 1, training)
    if stride != 1 or inputs.get_shape().as_list()[-1] != 4 * filters:
        shortcut = tf.layers.conv2d(inputs, 4 * filters, 1, strides=stride, padding='SAME', use_bias=False)
        shortcut = tf.layers.batch_normalization(shortcut, training=training)
    out = tf.add(out, shortcut)
    return out


def resnet50(inputs, training):
    out = conv_bn_relu(inputs, 64, 3, 1, training)
    out = bottleneck_block(out, 64, 1, training)
    out = bottleneck_block(out, 128, 2, training)
    out = bottleneck_block(out, 256, 2, training)
    out = bottleneck_block(out, 512, 2, training)
    out = tf.layers.average_pooling2d(out, 4, 1)
    out = tf.layers.flatten(out)
    out = tf.layers.dense(out, 10)
    return out


def run_model():
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()

    X_train, X_test = X_train.astype(np.float32) / 255.0, X_test.astype(np.float32) / 255.0
    y_train, y_test = y_train.astype(np.int32), y_test.astype(np.int32)

    images = tf.placeholder(tf.float32, shape=(None, 32, 32, 3), name='images')
    labels = tf.placeholder(tf.int32, shape=(None), name='labels')
    is_training = tf.placeholder(tf.bool, name='is_training')

    logits = resnet50(images, is_training)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(0.001)
    train_op = optimizer.minimize(loss, global_step=global_step)

    batch_size = 128
    n_epochs = 100
    hooks = [tf.train.StopAtStepHook(last_step=n_epochs * len(X_train) // batch_size)]

    def get_batch(data, labels, batch_size):
        idx = np.random.choice(np.arange(len(data)), batch_size, replace=False)
        return data[idx], labels[idx].flatten()

    with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
        while not sess.should_stop():
            batch_images, batch_labels = get_batch(X_train, y_train, batch_size)
            _, train_loss, step = sess.run(
                [train_op, loss, global_step],
                feed_dict={images: batch_images, labels: batch_labels, is_training: True}
            )
            if step % 100 == 0:
                print(f"Step {step}, Loss: {train_loss:.4f}")


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    config_json = {}
    epl.init(epl.Config(config_json))
    print(epl.Env.get().cluster.gpu_num_per_worker)
    if epl.Env.get().cluster.gpu_num_per_worker > 1:
        # Avoid NCCL hang.
        os.environ["NCCL_LAUNCH_MODE"] = "GROUP"
    epl.set_default_strategy(epl.replicate(device_count=1))
    run_model()

However, I am confronted with the following issue:

Traceback (most recent call last):
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
return fn(*args)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1350, in _run_fn
target_list, run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1443, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: From /job:worker/replica:0/task:0:
You must feed a value for placeholder tensor 'EPL_REPLICA_1/labels' with dtype int32
[[{{node EPL_REPLICA_1/labels}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "resnet50_split4.py", line 89, in
run_model()
File "resnet50_split4.py", line 74, in run_model
feed_dict={images: batch_images, labels: batch_labels, is_training: True}
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 754, in run
run_metadata=run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1259, in run
run_metadata=run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1360, in run
raise six.reraise(*original_exc_info)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/six.py", line 719, in reraise
raise value
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1345, in run
return self._sess.run(*args, **kwargs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1418, in run
run_metadata=run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1176, in run
return self._sess.run(*args, **kwargs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 464, in run
outputs = fn(self, actual_fetches, feed_dict, options, run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
run_metadata_ptr)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1180, in _run
feed_dict_tensor, options, run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
run_metadata)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: From /job:worker/replica:0/task:0:
You must feed a value for placeholder tensor 'EPL_REPLICA_1/labels' with dtype int32
[[node EPL_REPLICA_1/labels (defined at /users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]

Original stack trace for 'EPL_REPLICA_1/labels':
File "resnet50_split4.py", line 89, in
run_model()
File "resnet50_split4.py", line 69, in run_model
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 584, in MonitoredTrainingSession
stop_grace_period_secs=stop_grace_period_secs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1014, in init
stop_grace_period_secs=stop_grace_period_secs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 319, in init
res = fn(self, *args, **kwargs)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 725, in init
self._sess = _RecoverableSession(self._coordinated_creator)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1207, in init
_WrappedSession.init(self, self._create_session())
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 1212, in _create_session
return self._sess_creator.create_session()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 878, in create_session
self.tf_sess = self._session_creator.create_session()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 638, in create_session
self._scaffold.finalize()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 273, in finalize
fn(self)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/training/monitored_session.py", line 239, in finalize
ops.get_default_graph().finalize()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/hooks.py", line 261, in finalize
Parallel.get().do_parallelism()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/parallel.py", line 223, in do_parallelism
self.transformer.replicas_clone()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/graph_editor.py", line 427, in replicas_clone
self._forward_clone()
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/graph_editor.py", line 343, in _forward_clone
target_device)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/epl/parallel/ops.py", line 237, in node_clone_for_replicas
op_def=op_def)
File "/users/Master/anaconda3/envs/py37/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 1748, in init
self._traceback = tf_stack.extract_stack()

Could you give me a hand? Thank you very much!

@SeaOfOcean
Copy link
Collaborator

you should replace get_batch with tf.data.Dataset

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants