Skip to content

Commit

Permalink
require std dev input for whitegaussnoise (#276)
Browse files Browse the repository at this point in the history
* require std dev input for `whitegaussnoise`

* fix bugs in tests

* change stdev to std
  • Loading branch information
stestoll committed Mar 9, 2022
1 parent 3d0919c commit 4954618
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
16 changes: 8 additions & 8 deletions deerlab/whitegaussnoise.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import numpy as np

def whitegaussnoise(t,level=1,rescale=False,seed=None):
def whitegaussnoise(t, std, rescale=False, seed=None):
r"""
Generates a vector of white Gaussian (normal) noise
Parameters
----------
t : array_like
Time axis.
level : float scalar
std : float scalar
Noise level, i.e. standard deviation of underlying Gaussian distribution.
rescale : boolean, optional
If true, rescales the noise vector such that its standard deviation is exactly equal
to ``level``. If false (default), the standard deviation of the noise vector can deviate
slightly from ``level``, particularly for short vectors.
If ``True``, rescales the noise vector such that its standard deviation is exactly equal
to ``std``. If ``False`` (default), the standard deviation of the noise vector can deviate
slightly from ``std``, particularly for short vectors.
seed : integer scalar, optional
If ``None`` (default), do not seed the random number generator. If an integer scalar is
given (e.g. ``seed=137``), seed the random number generator with this number.
Expand All @@ -36,12 +36,12 @@ def whitegaussnoise(t,level=1,rescale=False,seed=None):
if seed is not None:
np.random.seed(seed)

# Draw from standard normal distribution
# Draw from normal distribution
N = len(np.atleast_1d(t))
noise = np.random.normal(0,level,N)
noise = np.random.normal(0, std, N)

# Rescale to sample std if wanted
if rescale:
noise = level/np.std(noise)*noise
noise *= std/np.std(noise)

return noise
6 changes: 3 additions & 3 deletions test/test_model_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def test_bootCIs_parametric():
"Check the bootstrapped confidence intervals of the fitted parameters"
model = _getmodel('parametric')

noisydata = mock_data + whitegaussnoise(0.01,seed=1)
noisydata = mock_data + whitegaussnoise(x,0.01,seed=1)
fitResult = fit(model,noisydata,bootstrap=3)

assert_attributes_cis(fitResult,['mean1','mean2','width1','width2','amp1','amp2'])
Expand All @@ -526,7 +526,7 @@ def test_bootCIs_semiparametric():
"Check the bootstrapped confidence intervals of the fitted parameters"
model = _getmodel('semiparametric')

noisydata = mock_data + whitegaussnoise(0.01,seed=1)
noisydata = mock_data + whitegaussnoise(x,0.01,seed=1)
fitResult = fit(model,noisydata,bootstrap=3)

assert_attributes_cis(fitResult,['mean1','mean2','width1','width2','amp1','amp2'])
Expand All @@ -537,7 +537,7 @@ def test_bootCIs_nonparametric():
"Check the bootstrapped confidence intervals of the fitted parameters"
model = _getmodel('semiparametric')

noisydata = mock_data + whitegaussnoise(0.01,seed=1)
noisydata = mock_data + whitegaussnoise(x,0.01,seed=1)
fitResult = fit(model,noisydata,bootstrap=3)

assert_attributes_cis(fitResult,['amp1','amp2'])
Expand Down
17 changes: 17 additions & 0 deletions test/test_whitegaussnoise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from deerlab.utils.utils import assert_docstring
import numpy as np
import pytest
from deerlab import whitegaussnoise
from deerlab.utils import assert_docstring

Expand Down Expand Up @@ -56,8 +57,24 @@ def test_noseed():
assert not np.array_equal(noise1,noise2)
# ======================================================================


def test_docstring():
# ======================================================================
"Check that the docstring includes all variables and keywords."
assert_docstring(whitegaussnoise)
# ======================================================================


def test_requiredstd():
# ======================================================================
"Check that the standard deviation is a required argument"

N = 10
t = np.linspace(0,3,N)

with pytest.raises(TypeError):
noise = whitegaussnoise(t)

# ======================================================================


0 comments on commit 4954618

Please sign in to comment.