Skip to content

Commit

Permalink
Fixing index selection in spike_triggered_phase (see issue NeuralEnse…
Browse files Browse the repository at this point in the history
…mble#381) (NeuralEnsemble#382)

* adding exception for interpolate=True

* replace searchsorted and fix representation mismatch for interpolation=True
  • Loading branch information
rgutzen committed Nov 10, 2020
1 parent a6a0854 commit 6f81f6c
Showing 1 changed file with 9 additions and 17 deletions.
26 changes: 9 additions & 17 deletions elephant/phase_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,31 +125,24 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
sttimeind = np.where(np.logical_and(
spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0]

# Extract times for speed reasons
times = hilbert_transform[phase_i].times

# Find index into signal for each spike
ind_at_spike = np.round(
ind_at_spike = (
(spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) /
hilbert_transform[phase_i].sampling_period). \
simplified.magnitude.astype(int)

# Extract times for speed reasons
times = hilbert_transform[phase_i].times

# Append new list to the results for this spiketrain
result_phases.append([])
result_amps.append([])
result_times.append([])

# Step through all spikes
for spike_i, ind_at_spike_j in enumerate(ind_at_spike):
# Difference vector between actual spike time and sample point,
# positive if spike time is later than sample point
dv = spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]

# Make sure ind_at_spike is to the left of the spike time
if dv < 0 and ind_at_spike_j > 0:
ind_at_spike_j = ind_at_spike_j - 1

if interpolate:
if interpolate and ind_at_spike_j+1 < len(times):
# Get relative spike occurrence between the two closest signal
# sample points
# if z->0 spike is more to the left sample
Expand All @@ -160,10 +153,10 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
# Save hilbert_transform (interpolate on circle)
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1])
result_phases[spiketrain_i].append(
np.angle(
(1 - z) * np.exp(np.complex(0, p1)) +
z * np.exp(np.complex(0, p2))))
interpolation = (1 - z) * np.exp(np.complex(0, p1)) \
+ z * np.exp(np.complex(0, p2))
p12 = np.angle([interpolation])
result_phases[spiketrain_i].append(p12)

# Save amplitude
result_amps[spiketrain_i].append(
Expand All @@ -188,5 +181,4 @@ def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
result_amps[i] = pq.Quantity(entry, units=entry[0].units).flatten()
for i, entry in enumerate(result_times):
result_times[i] = pq.Quantity(entry, units=entry[0].units).flatten()

return result_phases, result_amps, result_times

0 comments on commit 6f81f6c

Please sign in to comment.