/
all_pairs_model.py
208 lines (188 loc) · 8.59 KB
/
all_pairs_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from shared import *
from hrtf_analysis import *
from models import *
import gc
class AllPairsModel(object):
    '''
    Initialise this object with an hrtfset, a cochlear range (cfmin, cfmax, cfN),
    a range of gains (gain_max in dB, gain_N) and a range of delays (delay_max,
    delay_N),
    and optionally:
    a model for the coincidence detector neurons (cd_model),
    a model for the filter neurons (filtergroup_model).
    The __call__ method returns a count (see docstring of that method).
    '''
    def __init__(self, hrtfset, cfmin, cfmax, cfN,
                 gain_max, gain_N, delay_max, delay_N,
                 cd_model=standard_cd_model,
                 filtergroup_model=standard_filtergroup_model,
                 ):
        self.hrtfset = hrtfset
        self.cfmin, self.cfmax, self.cfN = cfmin, cfmax, cfN
        self.cd_model = cd_model
        self.filtergroup_model = filtergroup_model
        self.gain_max = gain_max
        self.gain_N = gain_N
        self.delay_max = delay_max
        self.delay_N = delay_N
        self.num_indices = num_indices = hrtfset.num_indices
        cf = erbspace(cfmin, cfmax, cfN)
        # dummy sound, when we run apply() we replace it
        sound = Sound((silence(1*ms), silence(1*ms)))
        soundinput = DoNothingFilterbank(sound)
        # prepare gains filter
        # m = number of non-negative gain levels; gain_N is assumed odd so the
        # mirrored set below has exactly gain_N entries (// keeps this an int
        # division under Python 3 as well).
        m = (gain_N+1)//2
        gains_dB = linspace(0, gain_max, m)
        gains = 10**(gains_dB/20)
        # mirror: attenuations (1/g) below unity, amplifications above, with a
        # single shared unity gain in the middle
        gains = hstack((1/gains[::-1], gains[1:]))
        allgains = reshape(gains, (1, 1, gains.size))
        def apply_gains(y):
            # y has shape (nsamples, 2*cfN): first cfN channels are one ear,
            # last cfN the other.  Expand each channel by the gain_N gains,
            # applying them in opposite order to the two ears so that each
            # output pair tests one relative interaural gain.
            nsamples = y.shape[0]
            cfN = y.shape[1]//2
            y = reshape(y, (nsamples, 2*cfN, 1))
            y1 = y[:, :cfN, :]*allgains
            y2 = y[:, cfN:, :]*allgains[:, :, ::-1]
            y = hstack((y1, y2))
            y = reshape(y, (nsamples, y.size//nsamples))
            return y
        gfb = Gammatone(Repeat(soundinput, cfN), hstack((cf, cf)))
        gains_fb = FunctionFilterbank(gfb, apply_gains)
        gains_fb.nchannels = gfb.nchannels*gain_N
        compress = filtergroup_model['compress']
        # half-wave rectification followed by the model's compression function
        cochlea = FunctionFilterbank(gains_fb, lambda x:compress(clip(x, 0, Inf)))
        # Create the filterbank group
        eqs = Equations(filtergroup_model['eqs'], **filtergroup_model['parameters'])
        G = FilterbankGroup(cochlea, 'target_var', eqs,
                            threshold=filtergroup_model['threshold'],
                            reset=filtergroup_model['reset'],
                            refractory=filtergroup_model['refractory'])
        # create the synchrony group: one coincidence detector per
        # (channel, gain, delay) triple
        cd_eqs = Equations(cd_model['eqs'], **cd_model['parameters'])
        cd = NeuronGroup(cfN*gain_N*(delay_N*2-1), cd_eqs,
                         threshold=cd_model['threshold'],
                         reset=cd_model['reset'],
                         refractory=cd_model['refractory'],
                         clock=G.clock)
        # set up the synaptic connectivity: each detector receives the same
        # channel from both ears, with complementary left/right delays so the
        # full range [-delay_max, +delay_max] of relative delays is covered
        left_delays = hstack((zeros(delay_N-1), linspace(0, float(delay_max), delay_N)))
        right_delays = left_delays[::-1]
        cd_weight = cd_model['weight']
        C = Connection(G, cd, 'target_var', delay=True, max_delay=delay_max)
        for i, j, dl, dr in zip(repeat(arange(cfN*gain_N), 2*delay_N-1),
                                arange(cfN*gain_N*(delay_N*2-1)),
                                tile(left_delays, cfN*gain_N),
                                tile(right_delays, cfN*gain_N)):
            C[i, j] = cd_weight
            C[i+cfN*gain_N, j] = cd_weight
            C.delay[i, j] = dl
            C.delay[i+cfN*gain_N, j] = dr
        self.soundinput = soundinput
        self.filtergroup = G
        self.synchronygroup = cd
        self.synapses = C
        self.counter = SpikeCounter(cd)
        self.network = Network(G, cd, C, self.counter)

    def __call__(self, sound, index=None, **indexkwds):
        '''
        Apply all pairs filtering group to given sound, which should be a
        stereo sound unless you specify the HRTF index, or coordinates of
        the HRTF index as keyword arguments, in which case it should be a mono
        sound which will have the given HRTF applied to it. You can also
        specify index=hrtf. Returns the count of the neurons in the synchrony
        group with shape (cfN, gain_N, delay_N*2-1).
        '''
        hrtf = None
        # BUGFIX: check for an HRTF object before the generic not-None test.
        # Previously `index is not None` was tested first, which made the
        # isinstance branch unreachable and broke the documented index=hrtf
        # usage (an HRTF object was wrongly used as an hrtfset key).
        if isinstance(index, HRTF):
            hrtf = index
        elif index is not None:
            hrtf = self.hrtfset[index]
        elif len(indexkwds):
            hrtf = self.hrtfset(**indexkwds)
        if hrtf is not None:
            sound = hrtf(sound)
        self.soundinput.source = sound
        self.network.reinit()
        # re-run the models' init hooks so repeated calls start from a clean state
        self.filtergroup_model['init'](self.filtergroup,
                                       self.filtergroup_model['parameters'])
        self.cd_model['init'](self.synchronygroup, self.cd_model['parameters'])
        self.network.run(sound.duration, report='stderr')
        count = reshape(self.counter.count,
                        (self.cfN, self.gain_N, self.delay_N*2-1))
        return count
if __name__=='__main__':
    # Demo: run the all-pairs model on one randomly chosen HRTF direction of
    # an IRCAM subject, then plot the spike counts against the HRTF-derived
    # ITD/ILD "ground truth".
    from plot_count import ircam_plot_count  # NOTE(review): imported but never used below
    hrtfdb = get_ircam()
    subject = 1002
    hrtfset = hrtfdb.load_subject(subject)
    index = randint(hrtfset.num_indices)  # random direction index
    cfmin, cfmax, cfN = 150*Hz, 5*kHz, 80
    gain_max, gain_N = 8.0, 61
    delay_N = 35
    delay_max = delay_N/samplerate  # delays expressed in samples at the HRTF samplerate
    # Change this to 10*second for equivalent picture to the paper
    sound = whitenoise(200*ms).atlevel(80*dB)
    apmodel = AllPairsModel(hrtfset, cfmin, cfmax, cfN,
                            gain_max, gain_N, delay_max, delay_N)
    count = apmodel(sound, index)  # shape (cfN, gain_N, delay_N*2-1)
    # Complicated code to plot the output nicely
    freqlabels = array([150*Hz, 1*kHz, 2*kHz, 3*kHz, 4*kHz, 5*kHz])
    fig_mew = 1 # marker edge width (in points)
    num_indices = hrtfset.num_indices  # NOTE(review): assigned but never used below
    from scipy.ndimage.filters import *
    # Per-channel ITD/ILD extracted from the HRTF set for the chosen direction
    itd, ild = hrtfset_itd_ild(hrtfset, cfmin, cfmax, cfN)
    delays = array([itd[index][i] for i in xrange(cfN)])
    gains = array([ild[index][i] for i in xrange(cfN)])
    gains = 20*log10(gains)  # amplitude ratio -> dB
    # convert ITDs (in seconds) into bin indices on the delay axis of `count`
    delays = -array(delays*samplerate, dtype=int)+delay_N-1
    arrgains = linspace(-gain_max, gain_max, gain_N)
    # convert ILDs (in dB) into bin indices on the gain axis of `count`
    gains = digitize(gains, 0.5*(arrgains[1:]+arrgains[:-1]))
    gains = gain_N-1-gains  # flip to match the model's gain ordering
    def dofig(count, blur=0, blurmode='reflect', freqlabels=None):
        # Two-panel figure: left = count maximised over gains (channel x delay),
        # right = count maximised over delays (channel x gain).  `blur` applies
        # an optional Gaussian smoothing (sigma in bins) before plotting.
        # '+' markers show the HRTF-derived value, 'x' the model's argmax.
        count = array(count, dtype=float)
        ocount = count  # keep unblurred original for the second panel
        count = copy(ocount)
        count.shape = (cfN, gain_N, delay_N*2-1)
        count = amax(count, axis=1)  # best response over gains
        count.shape = (cfN, delay_N*2-1)
        subplot(121)
        count = gaussian_filter(count, blur, mode=blurmode)
        imshow(count, origin='lower left', interpolation='nearest', aspect='auto',
               extent=(-float(delay_N/samplerate/msecond), float(delay_N/samplerate/msecond), 0, cfN))
        plot((delays-delay_N)/samplerate/msecond, arange(cfN), '+', color=(0,0,0), mew=fig_mew)
        plot((argmax(count, axis=1)-delay_N)/samplerate/msecond, arange(cfN), 'x', color=(1,1,1), mew=fig_mew)
        axis((float(-delay_N/samplerate/msecond), float(delay_N/samplerate/msecond), 0, cfN))
        xlabel('Delay (ms)')
        if freqlabels is None:
            yticks([])
            ylabel('Channel')
        else:
            # label the channel axis with the requested frequencies, mapped to
            # their nearest ERB-spaced channel index
            cf = erbspace(cfmin, cfmax, cfN)
            j = digitize(freqlabels, .5*(cf[1:]+cf[:-1]))
            yticks(j, map(str, array(freqlabels, dtype=int)))
            ylabel('Channel (Hz)')
        subplot(122)
        count = copy(ocount)
        count.shape = (cfN, gain_N, delay_N*2-1)
        count = amax(count, axis=2)  # best response over delays
        count.shape = (cfN, gain_N)
        count = gaussian_filter(count, blur, mode=blurmode)
        imshow(count, origin='lower left', interpolation='nearest', aspect='auto')
        plot(gains, arange(cfN), '+', color=(0,0,0), mew=fig_mew)
        plot(argmax(count, axis=1), arange(cfN), 'x', color=(1,1,1), mew=fig_mew)
        axis('tight')
        xlabel('Relative gain (dB)')
        xticks([0, (gain_N-1)/2, gain_N-1], [str(min(arrgains)), '0', str(max(arrgains))])
        if freqlabels is None:
            yticks([])
            ylabel('Channel')
        else:
            cf = erbspace(cfmin, cfmax, cfN)
            j = digitize(freqlabels, .5*(cf[1:]+cf[:-1]))
            yticks(j, map(str, array(freqlabels, dtype=int)))
            ylabel('Channel (Hz)')
    dofig(count, freqlabels=freqlabels)
    figure()
    dofig(count, blur=1)#, freqlabels=[500, 1000, 2000, 3000, 4000, 5000])
    figure()
    dofig(count, blur=2)
    show()