Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix of gauss_2d_large(seed=63) -> NaN #76

Merged
merged 5 commits into from
Jan 23, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
55 changes: 54 additions & 1 deletion kcsd/validation/csd_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,50 @@
Nencki Institute of Exprimental Biology, Warsaw.
'''
import numpy as np
from numpy import exp
from numpy import exp, isfinite
from functools import wraps


def repeatUntilValid(f):
"""
A decorator (wrapper).

If output of `f(..., seed)` contains either NaN or infinite, repeats
calculations for other `seed` (randomly generated for the current `seed`)
until the result is valid.

:param f: function of two arguments (the latter is `seed`)
:return: wrapped function f
"""
@wraps(f)
def wrapper(arg, seed=0):
for seed in seedSequence(seed):
result = f(arg, seed)
if isfinite(result).all():
return result

# Python 2.7 walkarround necessary for test purposes
if not hasattr(wrapper, '__wrapped__'):
setattr(wrapper, '__wrapped__', f)

return wrapper


def seedSequence(seed):
"""
Yields a sequence of unique, pseudorandom, deterministic seeds.

:param seed: beginning of the sequence
:return: seed generator
"""
previous = set()
rstate = np.random.RandomState(seed)
while True:
yield seed

previous.add(seed)
while seed in previous:
seed = rstate.randint(2 ** 32)


def get_states_1D(seed, n=1):
Expand Down Expand Up @@ -66,6 +109,7 @@ def get_states_2D(seed):
return states


@repeatUntilValid
def gauss_2d_large(csd_at, seed=0):
'''random quadpolar'large source' profile in 2012 paper in 2D'''
x, y = csd_at
Expand Down Expand Up @@ -339,3 +383,12 @@ def gauss_3d_mono3_f(csd_at):
# neat_4d_plot(chrg_x, chrg_y, chrg_z, f)

# plt.show()

# test of gauss_2d_large(seed=63) -> NaN fix
csd_at = np.mgrid[0.:1.:100j, 0.:1.:100j]
for seed in range(63):
assert (gauss_2d_large.__wrapped__(csd_at, seed) == gauss_2d_large(csd_at, seed)).all(),\
"decorated gauss_2d_large output differs for seed={}".format(seed)

assert isfinite(gauss_2d_large(csd_at, 63)).all(),\
"invalid output of gauss_2d_large(seed=63)"