-
Notifications
You must be signed in to change notification settings - Fork 1
/
incnoise.py
450 lines (383 loc) · 15 KB
/
incnoise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
from math import log, exp
import itertools
import random
import functools
from collections import Counter
import rfutils
from pmonad import *
from math import log2
# "Control characters": sentinel symbols used inside string tuples.
HALT = "!H"     # end-of-string marker, appended to every string
START = "!S"    # start-of-string marker (not referenced in this file -- TODO confirm external use)
ERASED = "!E"   # placeholder for a symbol removed by erasure noise
DEFAULT_NUM_SAMPLES = 10000  # Monte Carlo sample count for the approx_* functions
def buildup(lst, start=0, end=None):
    """ Given an iterable [a, b, ...] generate tuples [], [a], [a, b], ...
    starting with the tuple of length `start`. If `end` is None, the last
    tuple is the whole of `lst`; otherwise `end` is an exclusive bound, so
    the last tuple generated has length `end - 1`. """
    for i in range(start, len(lst)+1 if end is None else end):
        yield tuple(lst[:i])
# take_alternating :: Bool x [a] -> [a]
def take_alternating(start, xs):
    """Yield every other element of `xs`: elements at even positions when
    `start` is truthy, elements at odd positions otherwise."""
    for index, item in enumerate(xs):
        if (index % 2 == 0) == bool(start):
            yield item

# take_even :: [a] -> [a]
def take_even(xs):
    """Yield the elements of `xs` at positions 0, 2, 4, ..."""
    return take_alternating(True, xs)

# take_odd :: [a] -> [a]
def take_odd(xs):
    """Yield the elements of `xs` at positions 1, 3, 5, ..."""
    return take_alternating(False, xs)
# replicate :: [a] x Int -> [a]
# (Fixed type comment: the original `[a] -> [a]` omitted the count argument.)
def replicate(xs, n):
    """Yield each element of `xs` repeated `n` consecutive times.

    replicate('ab', 2) -> 'a', 'a', 'b', 'b'. Yields nothing if n <= 0."""
    for x in xs:
        # itertools.repeat runs the inner repetition at C speed.
        yield from itertools.repeat(x, n)
####### Character-level noise functions ############
def switch_letter(letter):
    """Swap within the letter pairs A<->B and a<->b (case is preserved).
    Raises KeyError for any other character."""
    swaps = {'A': 'B', 'B': 'A', 'a': 'b', 'b': 'a'}
    return swaps[letter]
def switch_case(letter):
    """Flip the case of a cased letter; raise ValueError for characters that
    are neither lower- nor upper-case."""
    if letter.islower():
        return letter.upper()
    if letter.isupper():
        return letter.lower()
    raise ValueError
@enumerator
def switching_noise(letter, p):
    """Noise enumeration over `letter`: with probability p the paired letter
    (A<->B / a<->b), with probability p the case-flipped letter, and with
    probability 1 - 2p the original. Weights are yielded in log space.
    Requires p < 1/2 so that log(1 - 2*p) is defined."""
    yield switch_letter(letter), log(p)
    yield switch_case(letter), log(p)
    yield letter, log(1 - 2*p)
def sample_bit_noise(bit, p):
    """Sample a noisy copy of the bit character `bit` ('0' or '1'):
    flipped with probability `p`, unchanged otherwise."""
    return str(1 - int(bit)) if sample_flip(p) else bit
# enum_bit_noise :: Char x Float -> Enum Char
def enum_bit_noise(bit, p):
    """Enumeration over the bit character `bit`: flipped ('0'<->'1') with
    probability p, unchanged otherwise."""
    return Enumeration.flip(p).bind(
        Enumeration.lift_ret(lambda b: str(1 - int(bit)) if b else bit)
    )
def enum_bit_noise_by_word(bits, p):
    """Enumeration over noisy versions of the bit string `bits`: each bit is
    independently flipped with probability p, and the results are rejoined
    into a single string."""
    bits_enum = Enumeration.mapM(lambda bit: enum_bit_noise(bit, p), bits)
    return Enumeration.lift("".join)(bits_enum)
# maybe_erase :: a x Float -> Enum (Maybe a)
def maybe_erase(x, p):
    """Enumeration yielding ERASED with probability p, else `x` unchanged."""
    return Enumeration.flip(p).bind(
        Enumeration.lift_ret(lambda b: ERASED if b else x)
    )
####### Sequence-level noise functions ############
# successive_noise :: [a] x ([a] -> Enum [a]) -> Enum [a]
@enumerator
def successive_noise(iterable, noise):
    """Apply `noise` cumulatively: each time an element is appended to the
    accumulated sequence, the *entire* sequence-so-far is passed through
    `noise` again, so earlier elements are noised repeatedly."""
    @enumerator
    def apply_noise(acc, x):
        # Noise the accumulated (already-noisy) sequence extended by x.
        for noisy_new_acc, p in noise(acc + (x,)):
            yield noisy_new_acc, p
    return Enumeration.reduceM(apply_noise, iterable, initial=())
# successive_noise_by_symbol :: ([a] -> Enum [a]) -> ([a] -> Enum [a])
# (Comment fixed: it previously repeated the name `successive_noise`.)
def successive_noise_by_symbol(noise, *a, **k):
    """Lift a per-symbol noise e.f. into a successive sequence-level noise
    function. Extra positional/keyword arguments are forwarded to `noise`."""
    def noisy(xs):
        return successive_noise(
            xs,
            lambda xs: Enumeration.mapM(lambda x: noise(x, *a, **k), xs)
        )
    return noisy
def test_successive_noise():
    """Regression test for successive_noise with a string-tagging noise e.f.
    The second half rebinds `noise` to verify that `noise_f` picks up the new
    function via late binding. Expected log-probabilities are hard-coded."""
    @enumerator
    def noise(x):
        yield "A(%s)" % x, log(3/4)
        yield "B(%s)" % x, log(1/4)
    noise_f = lambda acc: Enumeration.mapM(noise, acc)
    sequence = tuple('xy')
    assert dict(successive_noise(sequence, noise_f)) == {
        ('A(A(x))', 'A(y)'): -0.8630462173553427,
        ('A(A(x))', 'B(y)'): -1.9616585060234524,
        ('A(B(x))', 'A(y)'): -1.9616585060234524,
        ('A(B(x))', 'B(y)'): -3.0602707946915624,
        ('B(A(x))', 'A(y)'): -1.9616585060234524,
        ('B(A(x))', 'B(y)'): -3.060270794691562,
        ('B(B(x))', 'A(y)'): -3.0602707946915624,
        ('B(B(x))', 'B(y)'): -4.1588830833596715
    }
    @enumerator
    def noise(x):
        yield x+1, log(3/4)
        yield x-1, log(1/4)
    # now noise_f will have a reference to the new noise function
    sequence = (1, 2)
    assert dict(successive_noise(sequence, noise_f)) == {
        (-1, 1): -4.1588830833596715,
        (-1, 3): -3.0602707946915624,
        (1, 1): -2.367123614131617,
        (1, 3): -1.2685113254635072,
        (3, 1): -1.9616585060234524,
        (3, 3): -0.8630462173553427
    }
# successive_erasure_noise :: [a] x Float -> Enum [Maybe a]
def successive_erasure_noise(xs, p):
    """Erasure noise applied successively: each time the sequence grows, every
    element (including ERASED placeholders) is re-exposed to erasure."""
    return successive_noise(xs, lambda xs: sequence_erasure_noise(xs, p))
def sample_successive_erasure_noise(xs, p):
    """Sample one noisy rendering of `xs`: each element is independently
    replaced by ERASED with probability `p`."""
    return tuple(ERASED if sample_flip(p) else x for x in xs)
def approx_successive_erasure_noise(xs, p, num_samples=DEFAULT_NUM_SAMPLES):
    """Monte Carlo approximation built from sample_successive_erasure_noise.

    NOTE(review): the sampler erases each element once, whereas the exact
    successive_erasure_noise re-noises the whole prefix as it grows -- confirm
    this single-pass approximation is intended."""
    return enumeration_from_sampling_function(
        lambda: sample_successive_erasure_noise(xs, p),
        num_samples
    )
# successive_deletion_noise :: [a] x Float -> Enum [a]
def successive_deletion_noise(xs, p):
    """Deletion noise applied successively: surviving elements are re-exposed
    to deletion every time the sequence grows."""
    return successive_noise(xs, lambda xs: sequence_deletion_noise(xs, p))
# successive_bit_noise :: [a] x Float -> Enum [a]
def successive_bit_noise(xs, p):
    """Bit-flip noise applied successively: every word's bits are re-noised
    each time the sequence grows."""
    return successive_noise(xs, lambda xs: sequence_bit_noise(xs, p))
def sample_successive_bit_noise(xs, p):
    """Sample successive bit noise: after appending each new word, re-noise
    the entire accumulated sequence, so earlier words suffer repeated noise."""
    acc = ()
    for word in xs:
        acc = sample_sequence_bit_noise(acc + (word,), p)
    return acc
def approx_successive_bit_noise(xs, p, num_samples=DEFAULT_NUM_SAMPLES):
    """Monte Carlo approximation of successive_bit_noise, built from
    repeated draws of sample_successive_bit_noise."""
    return enumeration_from_sampling_function(
        lambda: sample_successive_bit_noise(xs, p),
        num_samples
    )
# sequence_erasure_noise :: [a] x Float -> Enum [Maybe a]
def sequence_erasure_noise(xs, p):
    """Enumeration over versions of `xs` where each element is independently
    replaced by ERASED with probability p (single pass, length preserved)."""
    return Enumeration.mapM(lambda x: maybe_erase(x, p), xs)
# sequence_deletion_noise :: [a] x Float -> Enum [a]
def sequence_deletion_noise(xs, p):
    """Enumeration over subsequences of `xs`: each element is independently
    dropped with probability p (so the output may be shorter than `xs`)."""
    @enumerator
    def maybe_delete(acc, x):
        # With probability p keep acc as-is (deleting x); else append x.
        return Enumeration.flip(p).bind(
            Enumeration.lift_ret(lambda b: acc if b else acc + (x,))
        )
    return Enumeration.reduceM(maybe_delete, xs, initial=())
# sequence_bit_noise :: [a] x Float -> Enum [a]
def sequence_bit_noise(xs, p):
    """Enumeration over versions of `xs` (a sequence of bit-string words)
    where every bit of every word is independently flipped with probability p."""
    return Enumeration.mapM(lambda x: enum_bit_noise_by_word(x, p), xs)
def sample_sequence_bit_noise(xs, p):
    """Sample one noisy rendering of `xs`: each character of each word is
    independently passed through sample_bit_noise with flip probability `p`."""
    noisy_words = []
    for word in xs:
        noisy_words.append("".join(sample_bit_noise(bit, p) for bit in word))
    return tuple(noisy_words)
def approx_sequence_bit_noise(xs, p, num_samples=DEFAULT_NUM_SAMPLES):
    """Monte Carlo approximation of sequence_bit_noise, built from repeated
    draws of sample_sequence_bit_noise."""
    return enumeration_from_sampling_function(
        lambda: sample_sequence_bit_noise(xs, p),
        num_samples
    )
# noisy_channel_prefix_tree :: Enum [a] x ([a] -> Enum [b]) -> PT [a] a
def noisy_channel_prefix_tree(lang, noise):
    """Prefix tree for a noisy producer: a conditional distribution over next
    symbols indexed by the *real* prefix, marginalizing over the noisy
    (perceived) prefixes reachable under `noise`. Returns the result of
    .conditional() on the accumulated joint enumeration."""
    assert isinstance(lang, Enumeration)
    mul = lang.field.mul
    npt = noisy_prefix_tree(lang, noise)
    def traverse_from(real_prefix, perceived_prefix, p_so_far):
        # Emit the next-symbol distribution found at the perceived prefix,
        # keyed by the real prefix, weighted by the mass of reaching here.
        for value, p_value in npt[perceived_prefix]:
            yield (real_prefix, value), mul(p_so_far, p_value)
            if value != HALT:
                new_perceived_prefix = perceived_prefix + (value,)
                new_real_prefix = real_prefix + (value,)
                # Branch over every noisy rendering of the extended prefix.
                for noisy_prefix, p_noise in noise(new_perceived_prefix):
                    new_p_so_far = mul(mul(p_so_far, p_value), p_noise)
                    yield from traverse_from(
                        new_real_prefix,
                        noisy_prefix,
                        new_p_so_far
                    )
    return type(lang)(traverse_from((), (), lang.field.one)).conditional()
# lang_from_prefix_tree :: PT [a] a -> Enum [a]
def lang_from_prefix_tree(prefix_tree):
    """Invert a prefix tree: walk from the root multiplying conditional
    probabilities, yielding (string, p) for every path that reaches HALT.

    Bug fix: the accumulator was read as `p_prefix`, an undefined name
    (the parameter is `logp_prefix`), so any call raised NameError."""
    def traverse_from(prefix, logp_prefix):
        for value, p_value in prefix_tree[prefix]:
            p = ring.mul(logp_prefix, p_value)
            if value == HALT:
                yield prefix, p
            elif prefix + (value,) in prefix_tree:
                yield from traverse_from(prefix + (value,), p)
    # Every context distribution shares one field; take it from any entry.
    ring = prefix_tree[rfutils.first(prefix_tree.keys())].field
    return traverse_from((), ring.one)
def test_noisy_channel_prefix_tree():
    """Producer-side conditionals under successive erasure noise, p = .1:
    with the 4-string uniform language, the second symbol matches the first
    string's continuation 19/20 of the time."""
    lang = UniformEnumeration(['AA', 'Bb', 'aA', 'bb'])
    p = .1
    noise = lambda xs: successive_erasure_noise(xs, p)
    ncf = noisy_channel_prefix_tree(lang, noise)
    assert is_close(ncf[('A',)]['A'], log(19/20))
    assert is_close(ncf[('A',)]['b'], log(1/20))
    assert is_close(ncf[('B',)]['A'], log(1/20))
    assert is_close(ncf[('B',)]['b'], log(19/20))
def monadic_noisy_prefix_tree(lang, noise):
    """Monadic formulation of noisy_prefix_tree: draw a string, append HALT,
    pick a nonempty prefix uniformly, split it into (context, next symbol),
    noise the context, and condition on the noisy context."""
    # Amazingly, this does the same thing as noisy_prefix_tree!
    joint = lang >> (lambda s:
        certainly(tuple(s) + (HALT,)) >> (lambda s:
        uniform(buildup(s, start=1)) >> (lambda prefix:
        certainly((prefix[:-1], prefix[-1])) >> (lambda pair:
        noise(pair[0]) >> (lambda noisy_prefix:
        certainly((noisy_prefix, pair[-1])))))))
    return joint.conditional()
# noisy_prefix_tree :: Enum [a] x ([a] -> Enum [b]) -> PT [b] a
def noisy_prefix_tree(lang, noise):
    """ Convert a joint probability distribution into a noisy prefix tree
    probability distribution, where contexts have had a noise e.f. applied
    to them. The noise e.f. must operate over sequences. """
    assert isinstance(lang, Enumeration)
    d = {}
    add = lang.field.add
    mul = lang.field.mul
    for string, p in lang:
        symbols = list(string) + [HALT]
        for prefix in buildup(symbols, start=1):
            context, x = prefix[:-1], prefix[-1]
            for noisy_context, p_noise in noise(context):
                # Accumulate joint mass for (noisy context, next symbol).
                bucket = d.setdefault(noisy_context, {})
                p_joint = mul(p, p_noise)
                if x in bucket:
                    bucket[x] = add(bucket[x], p_joint)
                else:
                    bucket[x] = p_joint
    # Normalize each context's distribution over next symbols.
    for prefix, prefix_distro in d.items():
        d[prefix] = type(lang)(prefix_distro).normalize()
    return d
def test_noisy_prefix_tree():
    """Check comprehension-side conditionals under switching noise (p = .1)
    and under single-pass erasure noise. Derivation of 9/10 is given below."""
    lang = UniformEnumeration(['AA', 'Bb', 'aA', 'bb'])
    noise = lambda xs: successive_noise(
        xs,
        lambda xs: Enumeration.mapM(lambda x: switching_noise(x, .1), xs)
    )
    nf = noisy_prefix_tree(lang, noise)
    # p_C(A|A) = \frac{\sum_w p_L(A|w) p_N(A|w) p_L(w)}
    #                 {\sum_w p_N(A|w) p_L(w)}
    # Say p_L(w) = 1/4,
    # p_L(A|A) = 1, p_L(A|a) = 1, p_L(A|B) = 0, p_L(A|b) = 0
    # p_N(w'|w) = 1/10 where w' != w
    # p_N(w|w) = 8/10
    # p_C(A|A) = (1 * 8/10 * 1/4 + 1 * 1/10 * 1/4) / (1/4)
    #          = 9/10
    assert all(is_close(logp, log(1/4)) for logp in nf[()].dict.values())
    assert is_close(nf[('A',)]['A'], log(9/10))
    assert is_close(nf[('A',)]['b'], log(1/10))
    assert is_close(nf[('A', 'A')][HALT], log(1))
    p = .1
    noise = lambda xs: sequence_erasure_noise(xs, p)
    nf = noisy_prefix_tree(lang, noise)
    assert all(logp == log(1/4) for logp in nf[()].dict.values())
    assert is_close(nf[(ERASED,)]['A'], log(1/2))
    assert is_close(nf[(ERASED,)]['b'], log(1/2))
def is_close(x, y, eps=10**-5):
    """True iff x and y differ by less than eps (absolute tolerance)."""
    return -eps < x - y < eps
# prefix_tree :: Enum [a] -> PT [a] a
def prefix_tree(lang):
    """ Convert a joint probability distribution into a prefix tree probability
    distribution: a dict mapping each prefix (context) to a normalized
    distribution over next symbols (with HALT at string ends).

    Bug fix: the inner test was `if x in context`, which checked membership
    in the prefix *tuple* instead of the accumulator dict, so mass for a
    continuation reached via several strings was overwritten rather than
    summed (compare the correct test in noisy_prefix_tree). """
    assert isinstance(lang, Enumeration)
    add = lang.field.add
    d = {}
    for string, p in lang:
        string = list(string) + [HALT]
        prefixes = buildup(string, start=1)
        for prefix in prefixes:
            *context, x = prefix
            context = tuple(context)
            if context in d:
                if x in d[context]:
                    d[context][x] = add(d[context][x], p)
                else:
                    d[context][x] = p
            else:
                d[context] = {x: p}
    # Normalize the prefixes
    for prefix, prefix_distro in d.items():
        d[prefix] = type(lang)(prefix_distro).normalize()
    return d

def test_prefix_tree():
    """Expected values updated for the prefix_tree accumulation fix:
    p(c|a) is now .4 + .2 = .6, not the last-written .2."""
    d = prefix_tree(Enumeration([
        ('abc', log(.4)),
        ('acb', log(.4)),
        ('acd', log(.2))
    ]))
    d_dict = {k: dict(v) for k, v in d.items()}
    assert set(d_dict) == {
        (), ('a',), ('a', 'b'), ('a', 'b', 'c'),
        ('a', 'c'), ('a', 'c', 'b'), ('a', 'c', 'd')
    }
    assert is_close(d_dict[()]['a'], 0.0)
    # p(b|a) = .4, p(c|a) = .4 + .2 = .6
    assert is_close(d_dict[('a',)]['b'], log(.4))
    assert is_close(d_dict[('a',)]['c'], log(.6))
    assert is_close(d_dict[('a', 'b')]['c'], 0.0)
    assert is_close(d_dict[('a', 'b', 'c')][HALT], 0.0)
    # p(b|ac) = .4/.6, p(d|ac) = .2/.6
    assert is_close(d_dict[('a', 'c')]['b'], log(2/3))
    assert is_close(d_dict[('a', 'c')]['d'], log(1/3))
    assert is_close(d_dict[('a', 'c', 'b')][HALT], 0.0)
    assert is_close(d_dict[('a', 'c', 'd')][HALT], 0.0)
# cost :: Float -> Float
def cost(p):
    """Surprisal of probability `p`, in bits."""
    return -log2(p)
# internal_string_cost :: Enum [a] x [a] x ([a] -> Enum [a]) -> Float
def internal_string_cost(lang, string, noise):
assert lang.field == log_space
noise = rfutils.memoize(noise)
pt = noisy_prefix_tree(lang, noise)
return sum(internal_symbol_costs(pt, string, noise))
# internal_symbol_costs :: PT [b] a x [a] x ([a] -> Enum [b]) -> [Float]
def internal_symbol_costs(pt, string, noise):
    """For each position of `string` (plus HALT), yield the cost of the true
    next symbol given each noisy rendering of its context, weighted by the
    noise probability. Summing the yields gives the string's expected cost."""
    prefixes = buildup(list(string) + [HALT], start=1)
    for prefix in prefixes:
        *context, x = prefix
        context = tuple(context)
        # TODO this requires that the noise function enumerate exactly
        # the same things as in previous calls --
        # as a stopgap, try memoizing the random enumerator?
        for noisy_context, logp_noise in noise(context):
            yield exp(logp_noise) * cost(exp(pt[tuple(noisy_context)][x]))
# internal_lang_cost :: Enum [a] x ([a] -> Enum [b]) -> Float
def internal_lang_cost(lang, noise):
    """ The expected average surprisal of each word given noisy representation
    of the prefix. """
    # NOTE(review): unlike internal_string_cost and verbose_internal_lang_cost,
    # this does not assert lang.field == log_space -- confirm callers ensure it.
    noise = rfutils.memoize(noise)
    pt = noisy_prefix_tree(lang, noise)
    return lang.expectation(lambda s: sum(internal_symbol_costs(pt, s, noise)))
def verbose_internal_lang_cost(lang, noise):
    """As internal_lang_cost, but prints every probability term as it is
    accumulated. `lang` must be in log space. Unlike internal_lang_cost,
    this weights costs by exp(logp_s) directly instead of using
    lang.expectation, and does not memoize `noise`."""
    assert lang.field == log_space
    pt = noisy_prefix_tree(lang, noise)
    def string_costs():
        for s, logp_s in lang:
            print("p_L(%s) = %s" % (s, exp(logp_s)))
            s = list(s)
            the_cost = 0
            prefixes = buildup(list(s) + [HALT], start=1)
            for prefix in prefixes:
                *context, x = prefix
                for noisy_context, logp_noise in noise(context):
                    part_cost = cost(exp(pt[tuple(noisy_context)][x]))
                    print("p_N(%s | %s) = %s" % (noisy_context, context, exp(logp_noise)))
                    print("p(%s | %s) = %s" % (x, noisy_context, 2**(-part_cost)))
                    # Weight by string probability and noise probability.
                    the_cost += exp(logp_s) * exp(logp_noise) * part_cost
            yield the_cost
    return sum(string_costs())
# external_symbol_costs :: PT [a] a x [a] x ([a] -> Enum [b]) -> [Float]
def external_symbol_costs(pt, string, noise):
    """Yield the cost of each symbol of `string` (plus HALT) given its true
    context under the prefix tree `pt`. The `noise` parameter is unused here;
    it is kept for signature parity with internal_symbol_costs."""
    prefixes = buildup(list(string) + [HALT], start=1)
    for prefix in prefixes:
        *context, x = prefix
        yield cost(exp(pt[tuple(context)][x]))
# external_lang_cost :: Enum a x (a -> Enum b) -> Float
def external_lang_cost(lang, noise):
    """ The expected average surprisal of each word as generated by a noisy
    producer. `lang` must be in log space. """
    assert lang.field == log_space
    pt = noisy_channel_prefix_tree(lang, noise)
    return lang.expectation(lambda s: sum(external_symbol_costs(pt, s, noise)))
def sample_flip(p):
    """Bernoulli draw: return True with probability `p`."""
    draw = random.random()
    return draw < p
# Module-level example noise functions at rate p = .1, presumably for
# interactive use -- note they run on import and `p` is a module global.
p = .1
enoise = lambda xs: successive_erasure_noise(xs, p)   # erasure
dnoise = lambda xs: successive_deletion_noise(xs, p)  # deletion
snoise = successive_noise_by_symbol(switching_noise, p)  # letter/case switching
bnoise = lambda xs: successive_bit_noise(xs, p)       # bit flips
if __name__ == '__main__':
    # Run the test_* functions in this module.
    # NOTE(review): nose is unmaintained and fails on modern Python;
    # consider migrating to pytest.
    import nose
    nose.runmodule()