Skip to content

Commit

Permalink
Improve memory efficiency of _create_sparse_matrix in BinnedSpikeTrai…
Browse files Browse the repository at this point in the history
…n class (NeuralEnsemble#395)

Co-authored-by: kleinjohann <a.kleinjohann@fz-juelich.de>
  • Loading branch information
morales-gregorio and Kleinjohann committed Jan 12, 2021
1 parent 6193e66 commit 0be27f9
Showing 1 changed file with 25 additions and 5 deletions.
30 changes: 25 additions & 5 deletions elephant/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,11 +1065,23 @@ def _create_sparse_matrix(self, spiketrains):
Spike trains to bin.
"""

# The data type for numeric values
data_dtype = np.int32

if not _check_neo_spiketrain(spiketrains):
# a binned numpy array
sparse_matrix = sps.csr_matrix(spiketrains, dtype=np.int32)
sparse_matrix = sps.csr_matrix(spiketrains, dtype=data_dtype)
return sparse_matrix

# Get index dtype that can accomodate the largest index
# (this is the same dtype that will be used for the index arrays of the
# sparse matrix, so already using it here avoids array duplication)
shape = (len(spiketrains), self.n_bins)
numtype = np.int32
if max(shape) > np.iinfo(numtype).max:
numtype = np.int64

row_ids, column_ids = [], []
# data
counts = []
Expand All @@ -1089,21 +1101,29 @@ def _create_sparse_matrix(self, spiketrains):
valid_bins = bins[bins < self.n_bins]
n_discarded += len(bins) - len(valid_bins)
f, c = np.unique(valid_bins, return_counts=True)
# f inherits the dtype np.int32 from bins, but c is created in
# np.unique with the default int dtype (usually np.int64)
c = c.astype(data_dtype)
column_ids.append(f)
counts.append(c)
row_ids.append(np.repeat(idx, repeats=len(f)))
row_ids.append(np.repeat(idx, repeats=len(f)).astype(numtype))

if n_discarded > 0:
warnings.warn("Binning discarded {} last spike(s) of the "
"input spiketrain".format(n_discarded))

# Stacking preserves the data type. In any case, while creating
# the sparse matrix, a copy is performed even if we set 'copy' to False
# explicitly (however, this might change in future scipy versions -
# this depends on scipy csr matrix initialization implementation).
counts = np.hstack(counts)
row_ids = np.hstack(row_ids)
column_ids = np.hstack(column_ids)
row_ids = np.hstack(row_ids)

sparse_matrix = sps.csr_matrix((counts, (row_ids, column_ids)),
shape=(len(spiketrains), self.n_bins),
dtype=np.int32, copy=False)
shape=shape, dtype=data_dtype,
copy=False)

return sparse_matrix


Expand Down

0 comments on commit 0be27f9

Please sign in to comment.