Skip to content
Permalink
Browse files

ARROW-2920: [Python] Fix pytorch segfault

This fixes ARROW-2920 (see also ray-project/ray#2447) for me

Unfortunately we might not be able to have regression tests for this right now because we don't have CUDA in our test toolchain.

Author: Philipp Moritz <pcmoritz@gmail.com>

Closes #2329 from pcmoritz/fix-pytorch-segfault and squashes the following commits:

1d82825 <Philipp Moritz> fix
74bc93e <Philipp Moritz> add note
ff14c4d <Philipp Moritz> fix
b343ca6 <Philipp Moritz> add regression test
5f0cafa <Philipp Moritz> fix
2751679 <Philipp Moritz> fix
10c5a5c <Philipp Moritz> workaround for pyarrow segfault
  • Loading branch information...
pcmoritz authored and wesm committed Jul 27, 2018
1 parent 4ba2d19 commit 537e7f7fd503dd920c0b9f0cef8a2de86bc69e3b
Showing with 65 additions and 29 deletions.
  1. +2 −0 python/pyarrow/__init__.py
  2. +52 −29 python/pyarrow/compat.py
  3. +11 −0 python/pyarrow/tests/test_serialization.py
@@ -51,8 +51,10 @@ def parse_version(root):


# Workaround for https://issues.apache.org/jira/browse/ARROW-2657
# and https://issues.apache.org/jira/browse/ARROW-2920:
# preload the TensorFlow and PyTorch extensions before pyarrow.lib is
# imported, otherwise their symbols clash with ours.
#
# Match on the prefix rather than on exact values: Python < 3.3 appends
# the kernel major version to the platform name ('linux2', 'linux3'),
# while newer versions always report plain 'linux'.
if _sys.platform.startswith('linux'):
    compat.import_tensorflow_extension()
    compat.import_pytorch_extension()


from pyarrow.lib import cpu_count, set_cpu_count
@@ -160,31 +160,17 @@ def encode_file_path(path):
# will convert utf8 to utf16
return encoded_path

def import_tensorflow_extension():
def _iterate_python_module_paths(package_name):
"""
Load the TensorFlow extension if it exists.
Return an iterator to full paths of a python package.
This is used to load the TensorFlow extension before
pyarrow.lib. If we don't do this there are symbol clashes
between TensorFlow's use of threading and our global
thread pool, see also
https://issues.apache.org/jira/browse/ARROW-2657 and
https://github.com/apache/arrow/pull/2096.
This is a best effort and might fail (for example on Python 2).
It uses the official way of loading modules from
https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module
"""
import os
tensorflow_loaded = False

# Try to load the tensorflow extension directly
# This is a performance optimization, tensorflow will always be
# loaded via the "import tensorflow" statement below if this
# doesn't succeed.
#
# This uses the official way of loading modules from
# https://docs.python.org/3/library/importlib.html#approximating-importlib-import-module

try:
import importlib
absolute_name = importlib.util.resolve_name("tensorflow", None)
absolute_name = importlib.util.resolve_name(package_name, None)
except (ImportError, AttributeError):
# Sometimes, importlib is not available (e.g. Python 2)
# or importlib.util is not available (e.g. Python 2.7)
@@ -205,16 +191,37 @@ def import_tensorflow_extension():
if spec:
module = importlib.util.module_from_spec(spec)
for path in module.__path__:
ext = os.path.join(path, "libtensorflow_framework.so")
if os.path.exists(ext):
import ctypes
try:
ctypes.CDLL(ext)
except OSError:
pass
tensorflow_loaded = True
break
yield path

def import_tensorflow_extension():
"""
Load the TensorFlow extension if it exists.
This is used to load the TensorFlow extension before
pyarrow.lib. If we don't do this there are symbol clashes
between TensorFlow's use of threading and our global
thread pool, see also
https://issues.apache.org/jira/browse/ARROW-2657 and
https://github.com/apache/arrow/pull/2096.
"""
import os
tensorflow_loaded = False

# Try to load the tensorflow extension directly
# This is a performance optimization, tensorflow will always be
# loaded via the "import tensorflow" statement below if this
# doesn't succeed.

for path in _iterate_python_module_paths("tensorflow"):
ext = os.path.join(path, "libtensorflow_framework.so")
if os.path.exists(ext):
import ctypes
try:
ctypes.CDLL(ext)
except OSError:
pass
tensorflow_loaded = True
break

# If the above failed, try to load tensorflow the normal way
# (this is more expensive)
@@ -225,6 +232,22 @@ def import_tensorflow_extension():
except ImportError:
pass

def import_pytorch_extension():
    """
    Load the PyTorch extension if it exists.
    This is used to load the PyTorch extension before
    pyarrow.lib. If we don't do this there are symbol clashes
    between PyTorch's use of threading and our global
    thread pool, see also
    https://issues.apache.org/jira/browse/ARROW-2920

    This is a best effort: if torch is not installed, or its
    ``lib/libcaffe2.so`` is missing or cannot be loaded, the
    function silently does nothing so that importing pyarrow
    never fails because of it.
    """
    import ctypes
    import os

    for path in _iterate_python_module_paths("torch"):
        try:
            ctypes.CDLL(os.path.join(path, "lib/libcaffe2.so"))
        except OSError:
            # The shared library does not exist in this torch
            # installation (or the dynamic loader rejected it).
            # Nothing to preload in that case; mirror the behaviour
            # of import_tensorflow_extension and carry on.
            pass


integer_types = six.integer_types + (np.integer,)

@@ -369,6 +369,17 @@ def test_torch_serialization(large_buffer):
context=serialization_context)


@pytest.mark.skipif(not torch or not torch.cuda.is_available(),
                    reason="requires pytorch with CUDA")
def test_torch_cuda():
    """Regression test for ARROW-2920.

    Moving a convolution layer to the GPU used to segfault when torch
    had not been imported before pyarrow.
    """
    # NOTE: the problem is only reproduced when the pyarrow under test
    # was built in the manylinux1 environment; with other builds this
    # test passes regardless.
    conv_options = dict(kernel_size=3, stride=1, padding=1, bias=False)
    conv_layer = torch.nn.Conv2d(64, 2, **conv_options)
    conv_layer.cuda()


def test_numpy_immutable(large_buffer):
obj = np.zeros([10])

0 comments on commit 537e7f7

Please sign in to comment.
You can’t perform that action at this time.