Skip to content

Commit

Permalink
Removed option with window list of floating to avoid rounding errors (N…
Browse files Browse the repository at this point in the history
…euralEnsemble#172)

* Removed option with window list of floating to avoid rounding errors (now only list of int accepted) and adapted relative tests
* pep8
  • Loading branch information
pietroquaglio authored and alperyeg committed Sep 13, 2018
1 parent 4874ed0 commit 9fead94
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 221 deletions.
228 changes: 91 additions & 137 deletions elephant/spike_train_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ def __calculate_correlation_or_covariance(binned_sts, binary, corrcoef_norm):


def cross_correlation_histogram(
binned_st1, binned_st2, window='full', border_correction=False, binary=False,
kernel=None, method='speed', cross_corr_coef=False):
binned_st1, binned_st2, window='full', border_correction=False,
binary=False, kernel=None, method='speed', cross_corr_coef=False):
"""
Computes the cross-correlation histogram (CCH) between two binned spike
trains binned_st1 and binned_st2.
Expand All @@ -260,7 +260,7 @@ def cross_correlation_histogram(
binned_st1, binned_st2 : BinnedSpikeTrain
Binned spike trains to cross-correlate. The two spike trains must have
same t_start and t_stop
window : string or list (optional)
window : string or list of integer (optional)
‘full’: This returns the crosscorrelation at each point of overlap,
with an output shape of (N+M-1,). At the end-points of the
cross-correlogram, the signals do not overlap completely, and
Expand All @@ -269,12 +269,11 @@ def cross_correlation_histogram(
The cross-correlation product is only given for points where the
signals overlap completely.
Values outside the signal boundary have no effect.
list of integer (window[0]=minimum lag, window[1]=maximum lag): The
entries of window are two integers representing the left and
right extremes (expressed as number of bins) where the
crosscorrelation is computed
Default: 'full'
list of integer of of quantities (window[0]=minimum, window[1]=maximum
lag): The entries of window can be integer (number of bins) or
quantities (time units of the lag), in the second case they have to be
a multiple of the binsize
Default: 'Full'
border_correction : bool (optional)
whether to correct for the border effect. If True, the value of the
CCH at bin b (for b=-H,-H+1, ...,H, where H is the CCH half-length)
Expand Down Expand Up @@ -308,21 +307,23 @@ def cross_correlation_histogram(
which is more memory efficient but slower than the "speed" option.
Default: "speed"
cross_corr_coef : bool (optional)
Normalizes the CCH to obtain the cross-correlation coefficient
function ranging from -1 to 1 according to Equation (5.10) in
Normalizes the CCH to obtain the cross-correlation coefficient
function ranging from -1 to 1 according to Equation (5.10) in
"Analysis of parallel spike trains", 2010, Gruen & Rotter, Vol 7
Returns
-------
cch : AnalogSignal
Containing the cross-correlation histogram between binned_st1 and binned_st2.
Containing the cross-correlation histogram between binned_st1 and
binned_st2.
The central bin of the histogram represents correlation at zero
delay. Offset bins correspond to correlations at a delay equivalent
to the difference between the spike times of binned_st1 and those of binned_st2: an
entry at positive lags corresponds to a spike in binned_st2 following a
spike in binned_st1 bins to the right, and an entry at negative lags
corresponds to a spike in binned_st1 following a spike in binned_st2.
to the difference between the spike times of binned_st1 and those of
binned_st2: an entry at positive lags corresponds to a spike in
binned_st2 following a spike in binned_st1 bins to the right, and an
entry at negative lags corresponds to a spike in binned_st1 following
a spike in binned_st2.
To illustrate this definition, consider the two spike trains:
binned_st1: 0 0 0 0 1 0 0 0 0 0 0
Expand Down Expand Up @@ -355,7 +356,8 @@ def cross_correlation_histogram(
10. * pq.Hz, t_start=0 * pq.ms, t_stop=5000 * pq.ms),
binsize=5. * pq.ms)
>>> cc_hist = elephant.spike_train_correlation.cross_correlation_histogram(
>>> cc_hist = \
elephant.spike_train_correlation.cross_correlation_histogram(
binned_st1, binned_st2, window=[-30,30],
border_correction=False,
binary=False, kernel=None, method='memory')
Expand All @@ -373,11 +375,11 @@ def cross_correlation_histogram(
-----
cch
"""

def _cross_corr_coef(cch_result, binned_st1, binned_st2):
# Normalizes the CCH to obtain the cross-correlation
# Normalizes the CCH to obtain the cross-correlation
# coefficient function ranging from -1 to 1
N = max(binned_st1.num_bins, binned_st2.num_bins)
N = max(binned_st1.num_bins, binned_st2.num_bins)
Nx = len(binned_st1.spike_indices[0])
Ny = len(binned_st2.spike_indices[0])
spmat = [binned_st1.to_sparse_array(), binned_st2.to_sparse_array()]
Expand All @@ -386,10 +388,10 @@ def _cross_corr_coef(cch_result, binned_st1, binned_st2):
bin_counts_unique.append(s.data)
ii = np.dot(bin_counts_unique[0], bin_counts_unique[0])
jj = np.dot(bin_counts_unique[1], bin_counts_unique[1])
rho_xy = (cch_result - Nx*Ny/N) / np.sqrt( (ii-Nx**2./N)*(jj-Ny**2./N) )
rho_xy = (cch_result - Nx * Ny / N) / \
np.sqrt((ii - Nx**2. / N) * (jj - Ny**2. / N))
return rho_xy



def _border_correction(counts, max_num_bins, l, r):
# Correct the values taking into account lacking contributes
# at the edges
Expand All @@ -415,55 +417,12 @@ def _kernel_smoothing(counts, kern, l, r):
# Smooth the cross-correlation histogram with the kern
return np.convolve(counts, kern, mode='same')

def _cch_memory(binned_st1, binned_st2, win, border_corr, binary, kern):
def _cch_memory(binned_st1, binned_st2, left_edge, right_edge,
border_corr, binary, kern):

# Retrieve unclipped matrix
st1_spmat = binned_st1.to_sparse_array()
st2_spmat = binned_st2.to_sparse_array()
binsize = binned_st1.binsize
max_num_bins = max(binned_st1.num_bins, binned_st2.num_bins)

# Set the time window in which is computed the cch
if not isinstance(win, str):
# Window parameter given in number of bins (integer)
if isinstance(win[0], int) and isinstance(win[1], int):
# Check the window parameter values
if win[0] >= win[1] or win[0] <= -max_num_bins \
or win[1] >= max_num_bins:
raise ValueError(
"The window exceeds the length of the spike trains")
# Assign left and right edges of the cch
l, r = win[0], win[1]
# Window parameter given in time units
else:
# Check the window parameter values
if win[0].rescale(binsize.units).magnitude % \
binsize.magnitude != 0 or win[1].rescale(
binsize.units).magnitude % binsize.magnitude != 0:
raise ValueError(
"The window has to be a multiple of the binsize")
if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
or win[1] >= max_num_bins * binsize:
raise ValueError("The window exceeds the length of the"
" spike trains")
# Assign left and right edges of the cch
l, r = int(win[0].rescale(binsize.units) / binsize), int(
win[1].rescale(binsize.units) / binsize)
# Case without explicit window parameter
elif window == 'full':
# cch computed for all the possible entries
# Assign left and right edges of the cch
r = binned_st2.num_bins - 1
l = - binned_st1.num_bins + 1
# cch compute only for the entries that completely overlap
elif window == 'valid':
# cch computed only for valid entries
# Assign left and right edges of the cch
r = max(binned_st2.num_bins - binned_st1.num_bins, 0)
l = min(binned_st2.num_bins - binned_st1.num_bins, 0)
# Check the mode parameter
else:
raise KeyError("Invalid window parameter")

# For each row, extract the nonzero column indices
# and the corresponding # data in the matrix (for performance reasons)
Expand All @@ -482,25 +441,28 @@ def _cch_memory(binned_st1, binned_st2, win, border_corr, binary, kern):
# Initialize the counts to an array of zeroes,
# and the bin IDs to integers
# spanning the time axis
counts = np.zeros(np.abs(l) + np.abs(r) + 1)
bin_ids = np.arange(l, r + 1)
# Compute the CCH at lags in l,...,r only
counts = np.zeros(np.abs(left_edge) + np.abs(right_edge) + 1)
bin_ids = np.arange(left_edge, right_edge + 1)
# Compute the CCH at lags in left_edge,...,right_edge only
for idx, i in enumerate(st1_bin_idx_unique):
il = np.searchsorted(st2_bin_idx_unique, l + i)
ir = np.searchsorted(st2_bin_idx_unique, r + i, side='right')
il = np.searchsorted(st2_bin_idx_unique, left_edge + i)
ir = np.searchsorted(st2_bin_idx_unique,
right_edge + i, side='right')
timediff = st2_bin_idx_unique[il:ir] - i
assert ((timediff >= l) & (timediff <= r)).all(), 'Not all the '
assert ((timediff >= left_edge) & (
timediff <= right_edge)).all(), 'Not all the '
'entries of cch lie in the window'
counts[timediff + np.abs(l)] += (st1_bin_counts_unique[idx] *
st2_bin_counts_unique[il:ir])
counts[timediff + np.abs(left_edge)] += (
st1_bin_counts_unique[idx] * st2_bin_counts_unique[il:ir])
st2_bin_idx_unique = st2_bin_idx_unique[il:]
st2_bin_counts_unique = st2_bin_counts_unique[il:]
# Border correction
if border_corr is True:
counts = _border_correction(counts, max_num_bins, l, r)
counts = _border_correction(
counts, max_num_bins, left_edge, right_edge)
if kern is not None:
# Smoothing
counts = _kernel_smoothing(counts, kern, l, r)
counts = _kernel_smoothing(counts, kern, left_edge, right_edge)
# Transform the array count into an AnalogSignal
cch_result = neo.AnalogSignal(
signal=counts.reshape(counts.size, 1),
Expand All @@ -511,75 +473,34 @@ def _cch_memory(binned_st1, binned_st2, win, border_corr, binary, kern):
# central one
return cch_result, bin_ids

def _cch_speed(binned_st1, binned_st2, win, border_corr, binary, kern):
def _cch_speed(binned_st1, binned_st2, left_edge, right_edge, cch_mode,
border_corr, binary, kern):

# Retrieve the array of the binne spik train
# Retrieve the array of the binne spike train
st1_arr = binned_st1.to_array()[0, :]
st2_arr = binned_st2.to_array()[0, :]
binsize = binned_st1.binsize

# Convert the to binary version
if binary:
st1_arr = np.array(st1_arr > 0, dtype=int)
st2_arr = np.array(st2_arr > 0, dtype=int)
max_num_bins = max(len(st1_arr), len(st2_arr))

# Cross correlate the spiketrains

# Case explicit temporal window
if not isinstance(win, str):
# Window parameter given in number of bins (integer)
if isinstance(win[0], int) and isinstance(win[1], int):
# Check the window parameter values
if win[0] >= win[1] or win[0] <= -max_num_bins \
or win[1] >= max_num_bins:
raise ValueError(
"The window exceed the length of the spike trains")
# Assign left and right edges of the cch
l, r = win
# Window parameter given in time units
else:
# Check the window parameter values
if win[0].rescale(binsize.units).magnitude % \
binsize.magnitude != 0 or win[1].rescale(
binsize.units).magnitude % binsize.magnitude != 0:
raise ValueError(
"The window has to be a multiple of the binsize")
if win[0] >= win[1] or win[0] <= -max_num_bins * binsize \
or win[1] >= max_num_bins * binsize:
raise ValueError("The window exceed the length of the"
" spike trains")
# Assign left and right edges of the cch
l, r = int(win[0].rescale(binsize.units) / binsize), int(
win[1].rescale(binsize.units) / binsize)

# Zero padding
st1_arr = np.pad(
st1_arr, (int(np.abs(np.min([l, 0]))), np.max([r, 0])),
mode='constant')
if cch_mode == 'pad':
# Zero padding to stay between left_edge and right_edge
st1_arr = np.pad(st1_arr,
(int(np.abs(np.min([left_edge, 0]))), np.max(
[right_edge, 0])),
mode='constant')
cch_mode = 'valid'
else:
# Assign the edges of the cch for the different mode parameters
if win == 'full':
# Assign left and right edges of the cch
r = binned_st2.num_bins - 1
l = - binned_st1.num_bins + 1
# cch compute only for the entries that completely overlap
elif win == 'valid':
# Assign left and right edges of the cch
r = max(binned_st2.num_bins - binned_st1.num_bins, 0)
l = min(binned_st2.num_bins - binned_st1.num_bins, 0)
cch_mode = win

# Cross correlate the spike trains
counts = np.correlate(st2_arr, st1_arr, mode=cch_mode)
bin_ids = np.r_[l:r + 1]
bin_ids = np.r_[left_edge:right_edge + 1]
# Border correction
if border_corr is True:
counts = _border_correction(counts, max_num_bins, l, r)
counts = _border_correction(
counts, max_num_bins, left_edge, right_edge)
if kern is not None:
# Smoothing
counts = _kernel_smoothing(counts, kern, l, r)
counts = _kernel_smoothing(counts, kern, left_edge, right_edge)
# Transform the array count into an AnalogSignal
cch_result = neo.AnalogSignal(
signal=counts.reshape(counts.size, 1),
Expand All @@ -606,21 +527,54 @@ def _cch_speed(binned_st1, binned_st2, win, border_corr, binary, kern):
if not binned_st1.t_stop == binned_st2.t_stop:
raise AssertionError("Spike train must have same t stop")

# The maximum number of of bins
max_num_bins = max(binned_st1.num_bins, binned_st2.num_bins)

# Set the time window in which is computed the cch
# Window parameter given in number of bins (integer)
if isinstance(window[0], int) and isinstance(window[1], int):
# Check the window parameter values
if window[0] >= window[1] or window[0] <= -max_num_bins \
or window[1] >= max_num_bins:
raise ValueError(
"The window exceeds the length of the spike trains")
# Assign left and right edges of the cch
left_edge, right_edge = window[0], window[1]
# The mode in which to compute the cch for the speed implementation
cch_mode = 'pad'
# Case without explicit window parameter
elif window == 'full':
# cch computed for all the possible entries
# Assign left and right edges of the cch
right_edge = binned_st2.num_bins - 1
left_edge = - binned_st1.num_bins + 1
cch_mode = window
# cch compute only for the entries that completely overlap
elif window == 'valid':
# cch computed only for valid entries
# Assign left and right edges of the cch
right_edge = max(binned_st2.num_bins - binned_st1.num_bins, 0)
left_edge = min(binned_st2.num_bins - binned_st1.num_bins, 0)
cch_mode = window
# Check the mode parameter
else:
raise KeyError("Invalid window parameter")

if method == "memory":
cch_result, bin_ids = _cch_memory(
binned_st1, binned_st2, window, border_correction, binary,
kernel)
binned_st1, binned_st2, left_edge, right_edge, border_correction,
binary, kernel)
elif method == "speed":

cch_result, bin_ids = _cch_speed(
binned_st1, binned_st2, window, border_correction, binary,
kernel)
binned_st1, binned_st2, left_edge, right_edge, cch_mode,
border_correction, binary, kernel)

if cross_corr_coef:
cch_result = _cross_corr_coef(cch_result, binned_st1, binned_st2)

return cch_result, bin_ids


# Alias for common abbreviation
cch = cross_correlation_histogram

Expand Down Expand Up @@ -737,7 +691,7 @@ def run_T(spiketrain, N, dt):
PB = run_P(spiketrain_2, spiketrain_1, N2, N1, dt)
PB = PB / N2
index = 0.5 * (PA - TB) / (1 - PA * TB) + 0.5 * (PB - TA) / (
1 - PB * TA)
1 - PB * TA)
return index


Expand Down

0 comments on commit 9fead94

Please sign in to comment.