/
test_utils.py
63 lines (48 loc) · 1.44 KB
/
test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import unittest
from collections import OrderedDict
import utils
def mk_table(a, b):
    """Zip ``a`` (keys) and ``b`` (values) into an insertion-ordered mapping.

    Both sequences must have the same length; the ``assert`` rejects a
    mismatch (unless Python runs with ``-O``) instead of letting ``zip``
    silently drop the surplus items.
    """
    assert len(a) == len(b)
    table = OrderedDict()
    for key, value in zip(a, b):
        table[key] = value
    return table
def tbij(tester, f, t, test_bf, test_bt):
    """Check ``utils.construct_bijection`` against an expected pair of maps.

    The input table is built with ``mk_table(f, t)``. ``construct_bijection``
    returns two mappings; the first must contain exactly the pairs
    ``zip(test_bf, test_bt)`` and the second exactly the reversed pairs.

    :param tester: ``unittest.TestCase`` instance used to report failures.
    :param f: keys of the input table.
    :param t: values of the input table (same length as ``f``).
    :param test_bf: expected keys of the first returned mapping.
    :param test_bt: expected values of the first returned mapping,
        index-aligned with ``test_bf``.
    """
    # zip() truncates to the shorter input, so a length mismatch in the
    # expected lists would silently drop pairs and could let an incomplete
    # expectation pass. Fail loudly instead.
    tester.assertEqual(
        len(test_bf), len(test_bt),
        "expected key/value lists differ in length",
    )
    forward, backward = utils.construct_bijection(mk_table(f, t))
    expected = set(zip(test_bf, test_bt))
    # Compare as sets of (key, value) pairs so that mapping order is ignored.
    tester.assertEqual(set(forward.items()), expected)
    tester.assertEqual(
        set(backward.items()), set((v, k) for k, v in expected)
    )
class TestBijection(unittest.TestCase):
    """Table-driven checks of ``utils.construct_bijection`` via ``tbij``."""

    def test_bij(self):
        """Input table {0: 3, 2: 2, 4: 1, 6: 4}.

        NOTE(review): the expected forward map drops the fixed point 2 -> 2
        and adds 1 -> 6 and 3 -> 0, presumably completing the partial map
        into a permutation -- confirm against utils.construct_bijection.
        """
        tbij(
            self,
            f=[0, 2, 4, 6],
            t=[3, 2, 1, 4],
            test_bf=[0, 1, 3, 4, 6],
            test_bt=[3, 6, 0, 1, 4],
        )

    def test_bij2(self):
        """Input table {0: 0, 1: 3, 3: 4, 4: 2, 5: 1}.

        NOTE(review): the expected forward map drops the fixed point 0 -> 0
        and adds 2 -> 5 -- same presumed completion rule as ``test_bij``.
        """
        tbij(
            self,
            f=[0, 1, 3, 4, 5],
            t=[0, 3, 4, 2, 1],
            test_bf=[1, 2, 3, 4, 5],
            test_bt=[3, 5, 4, 2, 1],
        )
class TestSampleByScores(unittest.TestCase):
    """Tests for ``utils.sample_by_scores(choices, scores)``."""

    def test_rnd(self):
        """Equal scores should give every choice roughly the same share.

        Draws ``trials`` samples from ``n_choices`` choices that all score 1.
        It then checks that each choice's hit count stays within ``n_sigma``
        binomial standard deviations of the uniform expectation.
        """
        n_choices = 10
        trials = 100000
        n_sigma = 5
        choices = list(range(n_choices))
        scores = [1] * n_choices
        picks = [0] * n_choices
        for _ in range(trials):
            # The choices are 0..n_choices-1, so the returned choice doubles
            # as an index into picks.
            take = utils.sample_by_scores(choices, scores)
            picks[take] += 1
        # Under a uniform sampler each count is Binomial(trials, 1/n_choices).
        # For the values above that means mean 10000 and sd ~94.9.
        p = 1.0 / n_choices
        expected = trials * p
        sd = (trials * p * (1 - p)) ** 0.5
        # The previous fixed floor of 9700 sat only ~3.2 sd below the mean,
        # so even a correct sampler failed in roughly 0.8% of runs.
        # A symmetric 5-sd band cuts spurious failures to a few per million.
        # It also bounds over-sampling tightly, which the floor alone did not.
        tolerance = n_sigma * sd
        for choice, count in enumerate(picks):
            self.assertLessEqual(
                abs(count - expected),
                tolerance,
                "choice %d picked %d times; expected %.0f +/- %.0f"
                % (choice, count, expected, tolerance),
            )

    def test_small(self):
        """A single choice is returned for each tested score, including 0."""
        for score in (0, 1, 100):
            self.assertEqual(
                utils.sample_by_scores((666,), (score,)),
                666,
                "score=%r" % (score,),
            )
if __name__ == "__main__":
    # Run every TestCase in this module when executed directly as a script.
    unittest.main()