forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_futures.py
163 lines (120 loc) · 4.45 KB
/
test_futures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import threading
import time
import torch
import unittest
from torch.futures import Future
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase, TemporaryFileName, run_tests
def add_one(fut):
    """``Future.then`` callback: block on *fut* and return its result plus one."""
    result = fut.wait()
    return result + 1
class TestFuture(TestCase):
    """Tests for ``torch.futures.Future`` and the ``collect_all`` / ``wait_all`` helpers."""

    def test_done(self) -> None:
        """``done()`` is False before a result is set and True afterwards."""
        f = Future[torch.Tensor]()
        self.assertFalse(f.done())
        f.set_result(torch.ones(2, 2))
        self.assertTrue(f.done())

    def test_done_exception(self) -> None:
        """A raising ``then`` callback yields a done future whose ``wait()`` re-raises."""
        err_msg = "Intentional Value Error"

        def raise_exception(unused_future):
            raise RuntimeError(err_msg)

        f1 = Future[torch.Tensor]()
        self.assertFalse(f1.done())
        f1.set_result(torch.ones(2, 2))
        self.assertTrue(f1.done())

        # f1 is already complete, so the chained future is expected to be
        # done as soon as then() returns.
        f2 = f1.then(raise_exception)
        self.assertTrue(f2.done())
        with self.assertRaisesRegex(RuntimeError, err_msg):
            f2.wait()

    def test_wait(self) -> None:
        """``wait()`` returns the value passed to ``set_result``."""
        f = Future[torch.Tensor]()
        f.set_result(torch.ones(2, 2))
        self.assertEqual(f.wait(), torch.ones(2, 2))

    def test_wait_multi_thread(self) -> None:
        """``wait()`` returns a result that another thread sets later."""

        def slow_set_future(fut, value):
            # Delay so the main thread is likely already blocked in wait().
            time.sleep(0.5)
            fut.set_result(value)

        f = Future[torch.Tensor]()
        t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
        t.start()
        self.assertEqual(f.wait(), torch.ones(2, 2))
        t.join()

    def test_mark_future_twice(self) -> None:
        """Setting a result on an already-completed future raises."""
        fut = Future[int]()
        fut.set_result(1)
        with self.assertRaisesRegex(
            RuntimeError,
            "Future can only be marked completed once"
        ):
            fut.set_result(1)

    def test_pickle_future(self):
        """Serializing a Future with ``torch.save`` raises."""
        fut = Future[int]()
        err_msg = "Can not pickle torch.futures.Future"
        with TemporaryFileName() as fname:
            with self.assertRaisesRegex(RuntimeError, err_msg):
                torch.save(fut, fname)

    def test_then(self):
        """A ``then`` callback receives the parent future and its return value
        becomes the chained future's result."""
        fut = Future[torch.Tensor]()
        then_fut = fut.then(lambda x: x.wait() + 1)
        fut.set_result(torch.ones(2, 2))
        self.assertEqual(fut.wait(), torch.ones(2, 2))
        self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)

    def test_chained_then(self):
        """A chain of 20 ``then(add_one)`` calls accumulates one increment per link."""
        fut = Future[torch.Tensor]()
        futs = []
        last_fut = fut
        for _ in range(20):
            last_fut = last_fut.then(add_one)
            futs.append(last_fut)

        fut.set_result(torch.ones(2, 2))

        # The i-th chained future has had add_one applied i + 1 times.
        for i, chained in enumerate(futs):
            self.assertEqual(chained.wait(), torch.ones(2, 2) + i + 1)

    def _test_error(self, cb, errMsg):
        """Chain *cb* onto a future completed with 5 and expect the chained
        future's ``wait()`` to raise a RuntimeError matching *errMsg*.

        Errors raised inside the callback (including non-RuntimeError ones,
        see ``test_then_raise``) are expected to surface as RuntimeError.
        """
        fut = Future[int]()
        then_fut = fut.then(cb)

        fut.set_result(5)
        # The parent future is unaffected by the failing callback.
        self.assertEqual(5, fut.wait())
        with self.assertRaisesRegex(RuntimeError, errMsg):
            then_fut.wait()

    def test_then_wrong_arg(self):
        """The callback receives the Future itself, not its value."""

        def wrong_arg(tensor):
            return tensor + 1

        self._test_error(wrong_arg, "unsupported operand type.*Future.*int")

    def test_then_no_arg(self):
        """A callback that accepts no arguments fails when invoked with the future."""

        def no_arg():
            return True

        self._test_error(no_arg, "takes 0 positional arguments but 1 was given")

    def test_then_raise(self):
        """An exception raised in the callback is propagated to ``wait()``."""

        def raise_value_error(fut):
            raise ValueError("Expected error")

        self._test_error(raise_value_error, "Expected error")

    def test_collect_all(self):
        """``collect_all`` completes once every input future completes."""
        fut1 = Future[int]()
        fut2 = Future[int]()
        fut_all = torch.futures.collect_all([fut1, fut2])

        def slow_in_thread(fut, value):
            time.sleep(0.1)
            fut.set_result(value)

        # fut2 completes immediately; fut1 completes later on a background thread.
        t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
        fut2.set_result(2)
        t.start()
        res = fut_all.wait()
        # collect_all yields the input futures, so each is waited on for its value.
        self.assertEqual(res[0].wait(), 1)
        self.assertEqual(res[1].wait(), 2)
        t.join()

    @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this testcase for Windows")
    def test_wait_all(self):
        """``wait_all`` returns the list of values, or raises if any future errored."""
        fut1 = Future[int]()
        fut2 = Future[int]()

        # No error version
        fut1.set_result(1)
        fut2.set_result(2)
        res = torch.futures.wait_all([fut1, fut2])
        self.assertEqual(res, [1, 2])

        # Version with an exception
        def raise_in_fut(fut):
            raise ValueError("Expected error")

        fut3 = fut1.then(raise_in_fut)
        with self.assertRaisesRegex(RuntimeError, "Expected error"):
            torch.futures.wait_all([fut3, fut2])
if __name__ == "__main__":
    # Hand control to PyTorch's shared test runner when run as a script.
    run_tests()