Skip to content

Commit

Permalink
Fix error message for incorrect calls with unlock_instance (#1044)
Browse files Browse the repository at this point in the history
* first commit

* added a test
  • Loading branch information
nikfilippas committed Mar 23, 2023
1 parent d514c13 commit bccfc88
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
26 changes: 16 additions & 10 deletions pyccl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,14 @@ def __exit__(self, type, value, traceback):
self.object_lock.lock()

@classmethod
def unlock_instance(cls, func=None, *, argv=0, mutate=True):
def unlock_instance(cls, func=None, *, name=None, mutate=True):
"""Decorator that temporarily unlocks an instance of CCLObject.
Arguments:
func (``function``):
Function which changes one of its ``CCLObject`` arguments.
argv (``int``):
Which argument should be unlocked. Defaults to the first one.
name (``str``):
Name of the parameter to unlock. Defaults to the first one.
If not a ``CCLObject`` the decorator will do nothing.
mutate (``bool``):
If after the function ``instance_old != instance_new``, the
Expand All @@ -409,17 +409,23 @@ def unlock_instance(cls, func=None, *, argv=0, mutate=True):
"""
if func is None:
# called with parentheses
return functools.partial(cls.unlock_instance, argv=argv,
return functools.partial(cls.unlock_instance, name=name,
mutate=mutate)

if not hasattr(func, "__signature__"):
# store the function signature
func.__signature__ = signature(func)
names = list(func.__signature__.parameters.keys())
name = names[0] if name is None else name # default name
if name not in names:
# ensure the name makes sense
raise NameError(f"{name} does not exist in {func.__name__}.")

@functools.wraps(func)
def wrapper(*args, **kwargs):
# Pick argument from list of `args` or `kwargs` as needed.
size = len(args)
arg = args[argv] if size > argv else list(kwargs.values())[argv-size] # noqa
with UnlockInstance(arg, mutate=mutate):
out = func(*args, **kwargs)
return out
bound = func.__signature__.bind(*args, **kwargs)
with UnlockInstance(bound.arguments[name], mutate=mutate):
return func(*args, **kwargs)
return wrapper

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion pyccl/bcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def bcm_model_fka(cosmo, k, a):
return fka


@unlock_instance(mutate=True, argv=1)
@unlock_instance(mutate=True, name="pk2d")
def bcm_correct_pk2d(cosmo, pk2d):
"""Apply the BCM model correction factor to a given power spectrum.
Expand Down
26 changes: 22 additions & 4 deletions pyccl/tests/test_cclobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import functools


def all_subclasses(cls):
"""Get all subclasses of ``cls``. NOTE: Used in ``conftest.py``."""
return set(cls.__subclasses__()).union([s for c in cls.__subclasses__()
for s in all_subclasses(c)])


def test_fancy_repr():
# Test fancy-repr controls.
cosmo1 = ccl.CosmologyVanillaLCDM()
Expand Down Expand Up @@ -211,7 +217,19 @@ def wrapper(self, *args, **kwargs):
return wrapper


def all_subclasses(cls):
"""Get all subclasses of ``cls``. NOTE: Used in ``conftest.py``."""
return set(cls.__subclasses__()).union([s for c in cls.__subclasses__()
for s in all_subclasses(c)])
def test_unlock_instance_errors():
# Test that unlock_instance gives the correct errors.

# 1. Developer error
with pytest.raises(NameError):
@ccl.unlock_instance(name="hello")
def func1(item, pk, a0=0, *, a1=None, a2):
return

# 2. User error
@ccl.unlock_instance(name="pk")
def func2(item, pk, a0=0, *, a1=None, a2):
return

with pytest.raises(TypeError):
func2()

0 comments on commit bccfc88

Please sign in to comment.