Skip to content

Commit

Permalink
FIX: improper normalization in time_shift function.
Browse files Browse the repository at this point in the history
Cross-correlation measurements of time-shifts was not properly normalized. Took the opportunity to deprecate time_shift and time_shifts in favour of two new functions with correct normalization, register_time_shift(s).
  • Loading branch information
LaurentRDC committed Jun 12, 2018
1 parent aececc3 commit 62beb54
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 34 deletions.
4 changes: 2 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ Measurement of time-shifts between physically-equivalent time traces:
:toctree: functions/
:nosignatures:

time_shift
time_shifts
register_time_shift
register_time_shifts

Robust statistics
-----------------
Expand Down
2 changes: 1 addition & 1 deletion skued/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@
from .structure import (Atom, AtomicStructure, Crystal, Lattice,
symmetry_expansion, lattice_system, LatticeSystem)
from .thin_films import film_optical_coefficients
from .time_series import time_shift, time_shifts, nfftfreq, nfft, mad
from .time_series import time_shift, time_shifts, register_time_shift, register_time_shifts, nfftfreq, nfft, mad
from .voigt import gaussian, lorentzian, pseudo_voigt
2 changes: 1 addition & 1 deletion skued/time_series/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@

from .robust import mad
from .nfft_routines import nfft, nfftfreq
from .time_zero import time_shift, time_shifts
from .time_zero import time_shift, time_shifts, register_time_shift, register_time_shifts
33 changes: 19 additions & 14 deletions skued/time_series/tests/test_time_zero.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@

from .. import time_shift, time_shifts
from .. import register_time_shift, register_time_shifts
import unittest
import numpy as np

from scipy.ndimage.interpolation import shift as scipy_shift

np.random.seed(23)

class TestTimeShift(unittest.TestCase):
Expand All @@ -12,66 +14,69 @@ def test_trivial(self):
with self.subTest('Even length'):
trace1 = np.sin(2*np.pi*np.linspace(0, 10, 64))
trace2 = np.array(trace1, copy = True)
shift = time_shift(trace1, trace2)
shift = register_time_shift(trace1, trace2)
self.assertEqual(shift, 0)

with self.subTest('Odd length'):
trace1 = np.sin(2*np.pi*np.linspace(0, 10, 65))
trace2 = np.array(trace1, copy = True)
shift = time_shift(trace1, trace2)
shift = register_time_shift(trace1, trace2)
self.assertEqual(shift, 0)

def test_shift_no_noise(self):
""" Test measuring the time-shift between traces shifted from one another, without added noise """
trace1 = np.sin(2*np.pi*np.linspace(0, 10, 64))
trace2 = np.roll(trace1, 5)
shift = time_shift(trace1, trace2)
shift = register_time_shift(trace1, trace2)
self.assertEqual(shift, -5)

def test_shift_with_noise(self):
""" Test measuring the time-shift between traces shifted from one another, with added 10% gaussian noise """
trace1 = np.sin(2*np.pi*np.linspace(0, 10, 64))
trace2 = np.roll(trace1, 5)
trace2 = scipy_shift(trace1, 5)

trace1 = trace1[6:-6]
trace2 = trace2[6:-6]

trace1 += 0.1*np.random.random(size = trace1.shape)
trace2 += 0.1*np.random.random(size = trace2.shape)
shift = time_shift(trace1, trace2)
trace1 += 0.05*np.random.random(size = trace1.shape)
trace2 += 0.05*np.random.random(size = trace2.shape)
shift = register_time_shift(trace1, trace2)
self.assertEqual(shift, -5)

def test_shift_different_lengths(self):
""" Test that time_shift() raises an exception if the reference and trace do not have the same shape """
""" Test that register_time_shift() raises an exception if the reference and trace do not have the same shape """
with self.assertRaises(ValueError):
trace1 = np.empty((16,))
trace2 = np.empty((8,))
time_shift(trace1, trace2)
register_time_shift(trace1, trace2)

class TestTimeShifts(unittest.TestCase):

def test_trivial(self):
""" Test that the time-shifts between identical time traces """
with self.subTest('Even lengths'):
traces = [np.sin(2*np.pi*np.linspace(0, 10, 64)) for _ in range(10)]
shifts = time_shifts(traces)
shifts = register_time_shifts(traces)
self.assertTrue(np.allclose(shifts, np.zeros_like(shifts)))

with self.subTest('Odd lengths'):
traces = [np.sin(2*np.pi*np.linspace(0, 10, 31)) for _ in range(10)]
shifts = time_shifts(traces)
shifts = register_time_shifts(traces)
self.assertTrue(np.allclose(shifts, np.zeros_like(shifts)))

def test_output_shape(self):
""" Test the output shape """
with self.subTest('reference = None'):
traces = [np.sin(2*np.pi*np.linspace(0, 10, 64) + i) for i in range(10)]
shifts = time_shifts(traces)
shifts = register_time_shifts(traces)
self.assertTupleEqual(shifts.shape, (len(traces), ))
# The first shift should then be zero
# because it is the shift between the reference and itself
self.assertEqual(shifts[0], 0)

with self.subTest('reference is not None'):
traces = [np.sin(2*np.pi*np.linspace(0, 10, 64) + i) for i in range(10)]
shifts = time_shifts(traces, reference = np.array(traces[0], copy = True))
shifts = register_time_shifts(traces, reference = np.array(traces[0], copy = True))
self.assertTupleEqual(shifts.shape, (len(traces), ))


Expand Down
62 changes: 49 additions & 13 deletions skued/time_series/time_zero.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
# -*- coding: utf-8 -*-

from collections.abc import Sized
from functools import partial
from functools import lru_cache, partial

import numpy as np
from scipy.signal import correlate

from npstreams import peek, array_stream
from npstreams import array_stream, peek

from ..utils import deprecated

def time_shift(trace, reference, method = 'auto'):

# Save the normalization of correlations so that identical
# autocorrelations are saved.
@lru_cache(maxsize = 128)
def __xcorr_normalization(size, dtype):
arr = np.ones(shape = (size,), dtype = dtype)
return correlate(arr, arr, mode = 'full')

def register_time_shift(trace, reference, method = 'auto'):
"""
Measure the time shift between a time trace and a reference trace
by cross correlation.
by normalized cross correlation.
.. versionadded:: 1.0.1.1
Parameters
----------
trace : array-like
Time trace. Must be the same lengths as ``reference``.
trace : array-like, shape (N,)
Time trace. Must be the same lengths as ``reference``. Only 1D traces are supported.
reference : array-like
Reference trace.
method : str {'auto', 'fft', 'direct'}, optional
Expand All @@ -32,28 +42,40 @@ def time_shift(trace, reference, method = 'auto'):
Raises
------
ValueError : if ``trace`` and ``reference`` do not have the same shape.
ValueError : if ``trace`` is not a 1D array
See Also
--------
time_shifts : measure time-shift between multiple traces and a reference.
register_time_shifts : measure time-shift between multiple traces and a reference.
scipy.signal.choose_conv_method : contains more documentation on ``method``
"""
trace, reference = np.atleast_1d(trace, reference)
if trace.shape != reference.shape:
raise ValueError('Time trace and reference trace are expected to have the same shape, be received \
a time-trace of shape {} and a reference trace of shape {}'.format(trace.shape, reference.shape))

if trace.ndim > 1:
raise ValueError('Expected 1D time traces, but received traces of shape {}'.format(trace.shape))

trace = trace - trace.mean()
reference = reference - reference.mean()

xcorr = correlate(trace, reference, mode = 'same', method = method)
# Normalized cross-correlation
# Note : we use an external function to calculate normalization
# so that it can be efficiently cached
xcorr = correlate(trace, reference, mode = 'full', method = 'auto') / __xcorr_normalization(trace.size, trace.dtype)

# Generalize to the average of multiple maxima
maxima = np.transpose(np.nonzero(xcorr == xcorr.max()))
return np.mean(maxima) - int(xcorr.shape[0]/2)

@array_stream
def time_shifts(traces, reference = None, method = 'auto'):
def register_time_shifts(traces, reference = None, method = 'auto'):
"""
Measure the time shifts between time traces and a reference by cross-correlation.
.. versionadded:: 1.0.1.1
Parameters
----------
traces : iterable of ndarrays
Expand All @@ -76,10 +98,11 @@ def time_shifts(traces, reference = None, method = 'auto'):
Raises
------
ValueError : if not all traces have the same shape.
ValueError : if traces are not 1D arrays
See Also
--------
time_shift : measure time-shift between a single trace and a reference.
register_time_shift : measure time-shift between a single trace and a reference.
"""
# fromiter can preallocate the full array if the number of traces
# is known in advance
Expand All @@ -96,5 +119,18 @@ def time_shifts(traces, reference = None, method = 'auto'):

kwargs = {'reference': reference, 'method': method}

shifts = map(partial(time_shift, **kwargs), traces)
return np.fromiter(shifts, dtype = np.float, count = count)
shifts = map(partial(register_time_shift, **kwargs), traces)
return np.fromiter(shifts, dtype = np.float, count = count)

# TODO: Remove in an upcoming release
@deprecated('time_shift function is deprecated in favor of `register_time_shift`')
def time_shift(*args, **kwargs):
return register_time_shift(*args, **kwargs)

time_shift.__doc__ = register_time_shift.__doc__

@deprecated('time_shifts function is deprecated in favor or `register_time_shifts`')
def time_shifts(*args, **kwargs):
return register_time_shifts(*args, **kwargs)

time_shifts.__doc__ = register_time_shifts.__doc__
6 changes: 3 additions & 3 deletions skued/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from functools import wraps
from warnings import resetwarnings, simplefilter, warn


@contextmanager
def suppress_warnings():
"""
Expand Down Expand Up @@ -36,10 +37,9 @@ def decorator(func):
def newfunc(*args, **kwargs):
full_message = """Calls to {name} deprecated: {message}.
{name} will be removed in a future release.""".format(name = func.__name__, message = message)
with contextwarnings('always', DeprecationWarning):
warn(full_message, category = DeprecationWarning, stacklevel = 2)
warn(full_message, category = DeprecationWarning, stacklevel = 2)
return func(*args, **kwargs)

return newfunc

return decorator
return decorator

0 comments on commit 62beb54

Please sign in to comment.