Skip to content

Commit

Permalink
Merge pull request #8845 from astrofrog/find-current-module-bundle
Browse files Browse the repository at this point in the history
Fix find_current_module so that it can work in application bundles
  • Loading branch information
astrofrog authored and bsipocz committed Jul 15, 2019
1 parent d29151f commit f3b28e4
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 12 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Expand Up @@ -2384,6 +2384,12 @@ astropy.units
astropy.utils
^^^^^^^^^^^^^

- Fix ``find_current_module`` so that it works properly if astropy is being used
inside a bundle such as that produced by PyInstaller. [#8845]

- Fix path to renamed classes, which previously included duplicate path/module
information under certain circumstances. [#8845]

astropy.visualization
^^^^^^^^^^^^^^^^^^^^^

Expand Down
10 changes: 2 additions & 8 deletions astropy/modeling/core.py
Expand Up @@ -207,7 +207,7 @@ def rename(cls, name):
>>> from astropy.modeling.models import Rotation2D
>>> SkyRotation = Rotation2D.rename('SkyRotation')
>>> SkyRotation
<class '__main__.SkyRotation'>
<class 'astropy.modeling.core.SkyRotation'>
Name: SkyRotation (Rotation2D)
Inputs: ('x', 'y')
Outputs: ('x', 'y')
Expand All @@ -227,13 +227,7 @@ def rename(cls, name):

new_cls = type(name, (cls,), {})
new_cls.__module__ = modname

if hasattr(cls, '__qualname__'):
if new_cls.__module__ == '__main__':
# __main__ is not added to a class's qualified name
new_cls.__qualname__ = name
else:
new_cls.__qualname__ = '{0}.{1}'.format(modname, name)
new_cls.__qualname__ = name

return new_cls

Expand Down
49 changes: 49 additions & 0 deletions astropy/modeling/tests/test_core.py
@@ -1,10 +1,15 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

import os
import sys
import subprocess

import pytest
import numpy as np
from inspect import signature
from numpy.testing import assert_allclose

import astropy
from astropy.modeling.core import Model, custom_model
from astropy.modeling.parameters import Parameter
from astropy.modeling import models
Expand Down Expand Up @@ -380,3 +385,47 @@ def test_compound_deepcopy():
assert id(model._submodels[0]) != id(new_model._submodels[0])
assert id(model._submodels[1]) != id(new_model._submodels[1])
assert id(model._submodels[2]) != id(new_model._submodels[2])


RENAMED_MODEL = models.Gaussian1D.rename('CustomGaussian')

MODEL_RENAME_CODE = """
from astropy.modeling.models import Gaussian1D
print(repr(Gaussian1D))
print(repr(Gaussian1D.rename('CustomGaussian')))
""".strip()

MODEL_RENAME_EXPECTED = b"""
<class 'astropy.modeling.functional_models.Gaussian1D'>
Name: Gaussian1D
Inputs: ('x',)
Outputs: ('y',)
Fittable parameters: ('amplitude', 'mean', 'stddev')
<class '__main__.CustomGaussian'>
Name: CustomGaussian (Gaussian1D)
Inputs: ('x',)
Outputs: ('y',)
Fittable parameters: ('amplitude', 'mean', 'stddev')
""".strip()


def test_rename_path(tmpdir):

# Regression test for a bug that caused the path to the class to be
# incorrect in a renamed model's __repr__.

assert repr(RENAMED_MODEL).splitlines()[0] == "<class 'astropy.modeling.tests.test_core.CustomGaussian'>"

# Make sure that when called from a user script, the class name includes
# __main__.

env = os.environ.copy()
paths = [os.path.dirname(astropy.__path__[0])] + sys.path
env['PYTHONPATH'] = os.pathsep.join(paths)

script = tmpdir.join('rename.py').strpath
with open(script, 'w') as f:
f.write(MODEL_RENAME_CODE)

output = subprocess.check_output([sys.executable, script], env=env)
assert output.splitlines() == MODEL_RENAME_EXPECTED.splitlines()
52 changes: 48 additions & 4 deletions astropy/utils/introspection.py
Expand Up @@ -5,11 +5,12 @@

import inspect
import re
import os
import sys
import types
import importlib
from distutils.version import LooseVersion


__all__ = ['resolve_name', 'minversion', 'find_current_module',
'isinstancemethod']

Expand Down Expand Up @@ -239,7 +240,7 @@ def find():
return None

if finddiff:
currmod = inspect.getmodule(frm)
currmod = _get_module_from_frame(frm)
if finddiff is True:
diffmods = [currmod]
else:
Expand All @@ -256,12 +257,55 @@ def find():

while frm:
frmb = frm.f_back
modb = inspect.getmodule(frmb)
modb = _get_module_from_frame(frmb)
if modb not in diffmods:
return modb
frm = frmb
else:
return inspect.getmodule(frm)
return _get_module_from_frame(frm)


def _get_module_from_frame(frm):
"""Uses inspect.getmodule() to get the module that the current frame's
code is running in.
However, this does not work reliably for code imported from a zip file,
so this provides a fallback mechanism for that case which is less
reliable in general, but more reliable than inspect.getmodule() for this
particular case.
"""

mod = inspect.getmodule(frm)
if mod is not None:
return mod

# Check to see if we're importing from a bundle file. First ensure that
# __file__ is available in globals; this is cheap to check to bail out
# immediately if this fails

if '__file__' in frm.f_globals and '__name__' in frm.f_globals:

filename = frm.f_globals['__file__']

# Using __file__ from the frame's globals and getting it into the form
# of an absolute path name with .py at the end works pretty well for
# looking up the module using the same means as inspect.getmodule

if filename[-4:].lower() in ('.pyc', '.pyo'):
filename = filename[:-4] + '.py'
filename = os.path.realpath(os.path.abspath(filename))
if filename in inspect.modulesbyfile:
return sys.modules.get(inspect.modulesbyfile[filename])

# On Windows, inspect.modulesbyfile appears to have filenames stored
# in lowercase, so we check for this case too.
if filename.lower() in inspect.modulesbyfile:
return sys.modules.get(inspect.modulesbyfile[filename.lower()])

# Otherwise there are still some even trickier things that might be possible
# to track down the module, but we'll leave those out unless we find a case
# where it's really necessary. So return None if the module is not found.
return None


def find_mod_objs(modname, onlylocals=False):
Expand Down
17 changes: 17 additions & 0 deletions astropy/utils/tests/test_introspection.py
Expand Up @@ -2,6 +2,7 @@

# namedtuple is needed for find_mod_objs so it can have a non-local module
from collections import namedtuple
from unittest import mock

import pytest

Expand Down Expand Up @@ -73,3 +74,19 @@ def test_minversion():
assert minversion(test_module, version)
for version in bad_versions:
assert not minversion(test_module, version)


def test_find_current_module_bundle():
"""
Tests that the `find_current_module` function would work if used inside
an application bundle. Since we can't test this directly, we test what
would happen if inspect.getmodule returned `None`, which is what happens
inside PyInstaller and py2app bundles.
"""
with mock.patch('inspect.getmodule', return_value=None):
mod1 = 'astropy.utils.introspection'
mod2 = 'astropy.utils.tests.test_introspection'
mod3 = 'astropy.utils.tests.test_introspection'
assert find_current_module(0).__name__ == mod1
assert find_current_module(1).__name__ == mod2
assert find_current_module(0, True).__name__ == mod3

0 comments on commit f3b28e4

Please sign in to comment.