@@ -66,15 +66,22 @@ class Spike2IO(BaseIO):
66
66
67
67
mode = 'file'
68
68
69
- def __init__ (self , filename = None ):
69
+ ced_units = False
70
+
71
+ def __init__ (self , filename = None , ced_units = False ):
70
72
"""
71
- This class read a smr file.
73
+ This class reads an smr file.
72
74
73
75
Arguments:
74
76
filename : the filename
77
+ ced_units: whether a spike trains should be added for each unit
78
+ as determined by Spike2's spike sorting (True), or if a spike
79
+ channel should be considered a single unit and will ignore
80
+ Spike2's spike sorting (False). Defaults to False.
75
81
"""
76
82
BaseIO .__init__ (self )
77
83
self .filename = filename
84
+ self .ced_units = ced_units
78
85
79
86
def read_segment (self , take_ideal_sampling_rate = False ,
80
87
lazy = False , cascade = True ):
@@ -132,11 +139,12 @@ def addannotations(ob, channelHeader):
132
139
seg .events .append (ea )
133
140
134
141
elif channelHeader .kind in [6 , 7 ]:
135
- sptr = self .read_one_channel_event_or_spike (
142
+ sptrs = self .read_one_channel_event_or_spike (
136
143
fid , i , header , lazy = lazy )
137
- if sptr is not None :
138
- addannotations (sptr , channelHeader )
139
- seg .spiketrains .append (sptr )
144
+ if sptrs is not None :
145
+ for sptr in sptrs :
146
+ addannotations (sptr , channelHeader )
147
+ seg .spiketrains .append (sptr )
140
148
141
149
fid .close ()
142
150
@@ -165,23 +173,21 @@ def read_header(self, filename=''):
165
173
('offset' , 'f4' ),
166
174
('unit' , 'S6' ), ]
167
175
channelHeader += HeaderReader (fid , np .dtype (dt ))
176
+
168
177
if header .system_id < 6 :
169
- channelHeader += HeaderReader (fid , np .dtype (
170
- [('divide' , 'i4' )])) # i8
171
- else :
172
- channelHeader += HeaderReader (fid , np .dtype (
173
- [('interleave' , 'i4' )])) # i8
178
+ channelHeader += HeaderReader (fid , np .dtype ([ ('divide' , 'i2' )]) )
179
+ else :
180
+ channelHeader += HeaderReader (fid , np .dtype ([ ('interleave' , 'i2' )]) )
181
+
174
182
if channelHeader .kind in [7 , 9 ]:
175
183
dt = [('min' , 'f4' ),
176
184
('max' , 'f4' ),
177
185
('unit' , 'S6' ), ]
178
186
channelHeader += HeaderReader (fid , np .dtype (dt ))
179
187
if header .system_id < 6 :
180
- channelHeader += HeaderReader (fid , np .dtype (
181
- [('divide' , 'i4' )])) # i8
182
- else :
183
- channelHeader += HeaderReader (fid , np .dtype (
184
- [('interleave' , 'i4' )])) # i8
188
+ channelHeader += HeaderReader (fid , np .dtype ([ ('divide' , 'i2' )]))
189
+ else :
190
+ channelHeader += HeaderReader (fid , np .dtype ([ ('interleave' , 'i2' )]) )
185
191
if channelHeader .kind in [4 ]:
186
192
dt = [('init_low' , 'u1' ),
187
193
('next_low' , 'u1' ), ]
@@ -352,7 +358,7 @@ def read_one_channel_event_or_spike(self, fid, channel_num, header,
352
358
elif channelHeader .kind in [6 , 7 ]:
353
359
# correct value for t_stop to be put in later
354
360
sptr = SpikeTrain ([] * pq .s , t_stop = 1e99 )
355
- sptr .annotate (channel_index = channel_num )
361
+ sptr .annotate (channel_index = channel_num , ced_unit = 0 )
356
362
sptr .lazy_shape = totalitems
357
363
return sptr
358
364
else :
@@ -431,13 +437,27 @@ def read_one_channel_event_or_spike(self, fid, channel_num, header,
431
437
t_stop = alltimes .max ()
432
438
else :
433
439
t_stop = 0.0
434
- sptr = SpikeTrain (alltimes ,
435
- waveforms = waveforms * unit ,
436
- sampling_rate = (1. / sample_interval ) * pq .Hz ,
437
- t_stop = t_stop )
438
- sptr .annotate (channel_index = channel_num )
439
440
440
- return sptr
441
+ if not self .ced_units :
442
+ sptr = SpikeTrain (alltimes ,
443
+ waveforms = waveforms * unit ,
444
+ sampling_rate = (1. / sample_interval )* pq .Hz ,
445
+ t_stop = t_stop
446
+ )
447
+ sptr .annotate (channel_index = channel_num , ced_unit = 0 )
448
+ return [sptr ]
449
+
450
+ sptrs = []
451
+ for i in set (alltrigs ['marker' ] & 255 ):
452
+ sptr = SpikeTrain (alltimes [alltrigs ['marker' ] == i ],
453
+ waveforms = waveforms [alltrigs ['marker' ] == i ]* unit ,
454
+ sampling_rate = (1. / sample_interval )* pq .Hz ,
455
+ t_stop = t_stop
456
+ )
457
+ sptr .annotate (channel_index = channel_num , ced_unit = i )
458
+ sptrs .append (sptr )
459
+
460
+ return sptrs
441
461
442
462
443
463
class HeaderReader (object ):
0 commit comments