
[installation] set_seed not an attribute #625

Closed
JiahaoYao opened this issue Jul 22, 2022 · 2 comments


@JiahaoYao
Contributor

I get the following error:

Traceback (most recent call last):
  File "main.py", line 65, in <module>
    app.run(main)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "main.py", line 60, in main
    train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
  File "/home/ubuntu/alpa/examples/mnist/train.py", line 152, in train_and_evaluate
    state, train_loss, train_accuracy = train_epoch(state, train_ds,
  File "/home/ubuntu/alpa/examples/mnist/train.py", line 99, in train_epoch
    state, loss, accuracy = train_step(state, batch_images, batch_labels)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/api.py", line 104, in __call__
    self._decode_args_and_get_executable(*args))
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/api.py", line 174, in _decode_args_and_get_executable
    executable = _compile_parallel_executable(f, in_tree, out_tree_hashable,
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/api.py", line 201, in _compile_parallel_executable
    return method.compile_executable(fun, in_tree, out_tree_thunk,
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/parallel_method.py", line 89, in compile_executable
    mesh = get_global_physical_mesh(create_if_not_exist=True)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/device_mesh.py", line 2127, in get_global_physical_mesh
    mesh = LocalPhysicalDeviceMesh()
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/device_mesh.py", line 816, in __init__
    self.set_runtime_random_seed(global_config.runtime_random_seed)
  File "/home/ubuntu/anaconda3/envs/tensorflow2_p38/lib/python3.8/site-packages/alpa/device_mesh.py", line 873, in set_runtime_random_seed
    d.set_seed(seed)
jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'jaxlib.xla_extension.GpuDevice' object has no attribute 'set_seed'
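The last frames show where the error comes from: LocalPhysicalDeviceMesh.__init__ calls set_runtime_random_seed, which calls `set_seed` on every local device. The sketch below is a simplified stand-in for that loop, not Alpa's actual code, and the two device classes are hypothetical stubs. It reproduces the same AttributeError whenever a device object lacks the method:

```python
class OldGpuDevice:
    """Hypothetical stub for a device object from an older jaxlib (no set_seed)."""


class NewGpuDevice:
    """Hypothetical stub for a device object that exposes set_seed."""

    def __init__(self):
        self.seed = None

    def set_seed(self, seed):
        # Record the seed so callers can check it was applied.
        self.seed = seed


def set_runtime_random_seed(devices, seed):
    # Simplified version of the loop at the bottom of the traceback:
    # every device must provide set_seed, otherwise AttributeError is raised.
    for d in devices:
        d.set_seed(seed)
```

With `[NewGpuDevice()]` the loop succeeds. With `[OldGpuDevice()]` it raises `AttributeError: 'OldGpuDevice' object has no attribute 'set_seed'`, which mirrors the GpuDevice error above.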
@JiahaoYao
Contributor Author

This error is caused by an old version of jaxlib: its GpuDevice objects have no set_seed attribute, as the dir() output below shows.

Python 3.7.3 (default, Mar 27 2019, 22:11:17)
[GCC 7.3.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
[GpuDevice(id=0, process_index=0)]
>>> jax.devices()[0]
GpuDevice(id=0, process_index=0)
>>> dir(jax.devices()[0])
['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'client', 'device_kind', 'device_vendor', 'host_id', 'id', 'live_buffers', 'platform', 'process_index', 'task_id', 'transfer_from_outfeed', 'transfer_to_infeed']
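Before launching Alpa, one way to confirm this cause is to check whether every device object exposes `set_seed`. This is a minimal sketch. The helper name `missing_device_methods` is hypothetical, and `jax` is only imported when no device list is passed in:

```python
def missing_device_methods(devices=None, required=("set_seed",)):
    """Return (device_id, method_name) pairs for required methods a device lacks.

    If `devices` is None, the local devices are taken from jax.devices().
    """
    if devices is None:
        import jax  # imported lazily so the helper can be tested without jax

        devices = jax.devices()
    missing = []
    for index, device in enumerate(devices):
        # Fall back to the list position if the object has no `id` attribute.
        device_id = getattr(device, "id", index)
        for name in required:
            if not hasattr(device, name):
                missing.append((device_id, name))
    return missing
```

In the environment shown above (one GPU, no `set_seed` in `dir()`), calling `missing_device_methods()` would be expected to return `[(0, 'set_seed')]`. An empty list means the devices expose the method Alpa calls.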

The solution is to update jaxlib. After updating, the device object exposes set_seed, along with several memory-related methods:

Python 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51)
[GCC 9.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject
>>> jax.devices()
[GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0), GpuDevice(id=2, process_index=0), GpuDevice(id=3, process_index=0)]
>>> jax.devices()[0]
GpuDevice(id=0, process_index=0)
>>> dir(jax.devices()[0])
['__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', 'available_memory', 'clear_memory_stats', 'client', 'device_kind', 'device_vendor', 'host_id', 'id', 'live_buffers', 'max_memory_allocated', 'memory_allocated', 'platform', 'process_index', 'set_seed', 'synchronize_all_activity', 'task_id', 'transfer_from_outfeed', 'transfer_to_infeed']
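Comparing the two `dir()` listings shows exactly which methods the updated jaxlib adds. The small sketch below computes that difference. The attribute names are copied from the two outputs above, with dunder names omitted:

```python
# Public attributes from the old jaxlib's dir(jax.devices()[0]) output.
OLD_ATTRS = [
    "client", "device_kind", "device_vendor", "host_id", "id", "live_buffers",
    "platform", "process_index", "task_id", "transfer_from_outfeed",
    "transfer_to_infeed",
]

# Public attributes from the updated jaxlib's dir(jax.devices()[0]) output.
NEW_ATTRS = [
    "available_memory", "clear_memory_stats", "client", "device_kind",
    "device_vendor", "host_id", "id", "live_buffers", "max_memory_allocated",
    "memory_allocated", "platform", "process_index", "set_seed",
    "synchronize_all_activity", "task_id", "transfer_from_outfeed",
    "transfer_to_infeed",
]


def added_attributes(old, new):
    """Return the names present in `new` but not in `old`, sorted."""
    return sorted(set(new) - set(old))


print(added_attributes(OLD_ATTRS, NEW_ATTRS))
# ['available_memory', 'clear_memory_stats', 'max_memory_allocated',
#  'memory_allocated', 'set_seed', 'synchronize_all_activity']
```

No attribute is removed by the update: `added_attributes(NEW_ATTRS, OLD_ATTRS)` returns an empty list.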

@JiahaoYao
Contributor Author

For more troubleshooting, see #496.
