# test_dist_autograd.py (forked from pytorch/pytorch)

from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import unittest
from functools import wraps

import six
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
from common_distributed import MultiProcessTestCase

if not dist.is_available():
    print("c10d not available, skipping tests")
    sys.exit(0)


def dist_init(func):
    """
    We use this decorator to set up and tear down state, since
    MultiProcessTestCase runs each `test*` method in a separate process and
    each process just runs the `test*` method without actually calling the
    `setUp` and `tearDown` methods of unittest.
    """
    @wraps(func)
    def wrapper(self):
        self.worker_id = self.rank
        store = dist.FileStore(self.file.name, self.world_size)
        dist.init_process_group(backend='gloo', rank=self.rank,
                                world_size=self.world_size, store=store)
        dist.init_model_parallel('worker%d' % self.rank)
        func(self)
        dist.join_rpc()
    return wrapper
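
# Note: `dist_init` is applied directly to each `test*` method (see
# TestDistAutograd below), so the process-group and RPC setup/teardown run
# inside the per-test worker process that MultiProcessTestCase launches,
# not in the parent process.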


@unittest.skipIf(not six.PY3, "Pytorch distributed autograd package "
                 "does not support python2")
class TestDistAutograd(MultiProcessTestCase):

    @property
    def world_size(self):
        return 4

    @dist_init
    def test_autograd_context(self):
        context_ids = []
        for i in range(1000):
            with dist_autograd.context() as context_id:
                self.assertEqual(
                    context_id,
                    dist_autograd._retrieve_context(context_id)._context_id())
                # First 16 bits should be worker_id.
                self.assertEqual(self.worker_id, context_id >> 48)
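                # (Assuming 64-bit context ids: the high 16 bits carry the
                # worker id, which is why shifting right by 48 recovers
                # self.worker_id.)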
                context_ids.append(context_id)

        for context_id in context_ids:
            with self.assertRaisesRegex(
                    RuntimeError,
                    'Could not find autograd context with id: {}'.format(context_id)):
                dist_autograd._retrieve_context(context_id)

    @dist_init
    def test_autograd_send_function(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = dist.rpc('worker{}'.format(dst_rank), torch.add,
                           args=(t1, t2))

            # Get send function.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))

            # Retrieve the next functions in the graph.
            next_funcs = send_functions[0].next_functions
            self.assertEqual(2, len(next_funcs))
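
            # Each entry in next_functions is a (function, input index) pair;
            # the index says which input of that function the gradient flows
            # into, which is 0 for AccumulateGrad nodes.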
            # We should now hit t1 and t2 in the autograd graph.
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[0][0].name())
            self.assertEqual(t1, next_funcs[0][0].variable)
            self.assertEqual(0, next_funcs[0][1])
            self.assertEqual('torch::autograd::AccumulateGrad',
                             next_funcs[1][0].name())
            self.assertEqual(t2, next_funcs[1][0].variable)
            self.assertEqual(0, next_funcs[1][1])
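
        # Exiting the `with` block above ends the context's lifetime, so
        # looking it up by id or asking for the current context should fail.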
        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()

    @dist_init
    def test_rpc_complex_args(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
            ret = dist.rpc('worker{}'.format(dst_rank), torch.stack,
                           args=(tensors,))
            self.assertEqual(torch.stack(tensors), ret)
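
            # Even though the tensors are nested inside a list argument,
            # the ones created with requires_grad=True should still be
            # recorded as inputs to the send function for this RPC.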
            # Verify the appropriate tensors have been attached to the
            # autograd graph.
            send_function = dist_autograd._current_context()._send_functions()[0]
            next_funcs = send_function.next_functions
            for i in range(num_tensors):
                if i % 2 == 0:
                    self.assertEqual('torch::autograd::AccumulateGrad',
                                     next_funcs[i][0].name())
                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
                else:
                    self.assertIsNone(next_funcs[i][0])


if __name__ == '__main__':
    unittest.main()