From f3b28e49085a0e41842273d8624adf617e4e3857 Mon Sep 17 00:00:00 2001 From: Thomas Robitaille Date: Wed, 10 Jul 2019 09:22:44 -0500 Subject: [PATCH] Merge pull request #8845 from astrofrog/find-current-module-bundle Fix find_current_module so that it can work in application bundles --- CHANGES.rst | 6 +++ astropy/modeling/core.py | 10 +---- astropy/modeling/tests/test_core.py | 49 +++++++++++++++++++++ astropy/utils/introspection.py | 52 +++++++++++++++++++++-- astropy/utils/tests/test_introspection.py | 17 ++++++++ 5 files changed, 122 insertions(+), 12 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index f9163d0f911..06c7e5d412a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2384,6 +2384,12 @@ astropy.units astropy.utils ^^^^^^^^^^^^^ +- Fix ``find_current_module`` so that it works properly if astropy is being used + inside a bundle such as that produced by PyInstaller. [#8845] + +- Fix path to renamed classes, which previously included duplicate path/module + information under certain circumstances. [#8845] + astropy.visualization ^^^^^^^^^^^^^^^^^^^^^ diff --git a/astropy/modeling/core.py b/astropy/modeling/core.py index 902194c35d0..3e5bfb4bfbb 100644 --- a/astropy/modeling/core.py +++ b/astropy/modeling/core.py @@ -207,7 +207,7 @@ def rename(cls, name): >>> from astropy.modeling.models import Rotation2D >>> SkyRotation = Rotation2D.rename('SkyRotation') >>> SkyRotation - + Name: SkyRotation (Rotation2D) Inputs: ('x', 'y') Outputs: ('x', 'y') @@ -227,13 +227,7 @@ def rename(cls, name): new_cls = type(name, (cls,), {}) new_cls.__module__ = modname - - if hasattr(cls, '__qualname__'): - if new_cls.__module__ == '__main__': - # __main__ is not added to a class's qualified name - new_cls.__qualname__ = name - else: - new_cls.__qualname__ = '{0}.{1}'.format(modname, name) + new_cls.__qualname__ = name return new_cls diff --git a/astropy/modeling/tests/test_core.py b/astropy/modeling/tests/test_core.py index e2f21e83ac7..00b74240e7a 100644 --- a/astropy/modeling/tests/test_core.py +++ b/astropy/modeling/tests/test_core.py @@ -1,10 +1,15 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst +import os +import sys +import subprocess + import pytest import numpy as np from inspect import signature from numpy.testing import assert_allclose +import astropy from astropy.modeling.core import Model, custom_model from astropy.modeling.parameters import Parameter from astropy.modeling import models @@ -380,3 +385,47 @@ def test_compound_deepcopy(): assert id(model._submodels[0]) != id(new_model._submodels[0]) assert id(model._submodels[1]) != id(new_model._submodels[1]) assert id(model._submodels[2]) != id(new_model._submodels[2]) + + +RENAMED_MODEL = models.Gaussian1D.rename('CustomGaussian') + +MODEL_RENAME_CODE = """ +from astropy.modeling.models import Gaussian1D +print(repr(Gaussian1D)) +print(repr(Gaussian1D.rename('CustomGaussian'))) +""".strip() + +MODEL_RENAME_EXPECTED = b""" + +Name: Gaussian1D +Inputs: ('x',) +Outputs: ('y',) +Fittable parameters: ('amplitude', 'mean', 'stddev') + +Name: CustomGaussian (Gaussian1D) +Inputs: ('x',) +Outputs: ('y',) +Fittable parameters: ('amplitude', 'mean', 'stddev') +""".strip() + + +def test_rename_path(tmpdir): + + # Regression test for a bug that caused the path to the class to be + # incorrect in a renamed model's __repr__. + + assert repr(RENAMED_MODEL).splitlines()[0] == "" + + # Make sure that when called from a user script, the class name includes + # __main__. + + env = os.environ.copy() + paths = [os.path.dirname(astropy.__path__[0])] + sys.path + env['PYTHONPATH'] = os.pathsep.join(paths) + + script = tmpdir.join('rename.py').strpath + with open(script, 'w') as f: + f.write(MODEL_RENAME_CODE) + + output = subprocess.check_output([sys.executable, script], env=env) + assert output.splitlines() == MODEL_RENAME_EXPECTED.splitlines() diff --git a/astropy/utils/introspection.py b/astropy/utils/introspection.py index e437b40c871..f6ac3e05b34 100644 --- a/astropy/utils/introspection.py +++ b/astropy/utils/introspection.py @@ -5,11 +5,12 @@ import inspect import re +import os +import sys import types import importlib from distutils.version import LooseVersion - __all__ = ['resolve_name', 'minversion', 'find_current_module', 'isinstancemethod'] @@ -239,7 +240,7 @@ def find(): return None if finddiff: - currmod = inspect.getmodule(frm) + currmod = _get_module_from_frame(frm) if finddiff is True: diffmods = [currmod] else: @@ -256,12 +257,55 @@ def find(): while frm: frmb = frm.f_back - modb = inspect.getmodule(frmb) + modb = _get_module_from_frame(frmb) if modb not in diffmods: return modb frm = frmb else: - return inspect.getmodule(frm) + return _get_module_from_frame(frm) + + +def _get_module_from_frame(frm): + """Uses inspect.getmodule() to get the module that the current frame's + code is running in. + + However, this does not work reliably for code imported from a zip file, + so this provides a fallback mechanism for that case which is less + reliable in general, but more reliable than inspect.getmodule() for this + particular case. + """ + + mod = inspect.getmodule(frm) + if mod is not None: + return mod + + # Check to see if we're importing from a bundle file. First ensure that + # __file__ is available in globals; this is cheap to check to bail out + # immediately if this fails + + if '__file__' in frm.f_globals and '__name__' in frm.f_globals: + + filename = frm.f_globals['__file__'] + + # Using __file__ from the frame's globals and getting it into the form + # of an absolute path name with .py at the end works pretty well for + # looking up the module using the same means as inspect.getmodule + + if filename[-4:].lower() in ('.pyc', '.pyo'): + filename = filename[:-4] + '.py' + filename = os.path.realpath(os.path.abspath(filename)) + if filename in inspect.modulesbyfile: + return sys.modules.get(inspect.modulesbyfile[filename]) + + # On Windows, inspect.modulesbyfile appears to have filenames stored + # in lowercase, so we check for this case too. + if filename.lower() in inspect.modulesbyfile: + return sys.modules.get(inspect.modulesbyfile[filename.lower()]) + + # Otherwise there are still some even trickier things that might be possible + # to track down the module, but we'll leave those out unless we find a case + # where it's really necessary. So return None if the module is not found. + return None def find_mod_objs(modname, onlylocals=False): diff --git a/astropy/utils/tests/test_introspection.py b/astropy/utils/tests/test_introspection.py index 2148bd355cf..bc8f509010d 100644 --- a/astropy/utils/tests/test_introspection.py +++ b/astropy/utils/tests/test_introspection.py @@ -2,6 +2,7 @@ # namedtuple is needed for find_mod_objs so it can have a non-local module from collections import namedtuple +from unittest import mock import pytest @@ -73,3 +74,19 @@ def test_minversion(): assert minversion(test_module, version) for version in bad_versions: assert not minversion(test_module, version) + + +def test_find_current_module_bundle(): + """ + Tests that the `find_current_module` function would work if used inside + an application bundle. Since we can't test this directly, we test what + would happen if inspect.getmodule returned `None`, which is what happens + inside PyInstaller and py2app bundles. + """ + with mock.patch('inspect.getmodule', return_value=None): + mod1 = 'astropy.utils.introspection' + mod2 = 'astropy.utils.tests.test_introspection' + mod3 = 'astropy.utils.tests.test_introspection' + assert find_current_module(0).__name__ == mod1 + assert find_current_module(1).__name__ == mod2 + assert find_current_module(0, True).__name__ == mod3