In [1]:
import torch
from torch import nn
from torch.nn import functional as F

from NetworkManager import NetworkManager
from NetworkClient import NetworkClient
from TestMulti import TwoLayerNet

from tqdm import tqdm

ModuleNotFoundError: No module named 'NetworkManager'

In [22]:
from typing import Iterable, Iterator, Generator
def iterate_window(iterator: Iterable, n: int = 2) -> Generator:
    """ Iterate over a sliding window of an iterator with no overlap.

    Parameters
    ----------
    iterator: Iterable
        The target to iterate over. Total length should be a multiple of n.
    n: int
        Window size

    Returns
    -------
    Iterable
        Generator for veiweing every n elements of the iterator.
    """
    size = len(iterator)
    iterator = iter(iterator)
    for _ in range(0, size, n):
        yield (next(iterator) for _ in range(n))

In [16]:
a = [1, 2, 3, 4]

In [20]:
list(iterate_window(a))

[<generator object iterate_window.<locals>.<genexpr> at 0x7f5ed714ce58>,
 <generator object iterate_window.<locals>.<genexpr> at 0x7f5ed714c8b8>]

In [23]:
iterate_window(a)

<generator object iterate_window at 0x7f5ed714ced0>

In [7]:
b = iter(a)

In [12]:
next(b)

StopIteration: 

In [None]:
class TestWorker(torch.multiprocessing.Process):
    def __init__(self, config, idx, d):
        super().__init__()
        self.config = config
        self.idx = idx
        self.d = d
        
        self.ready = torch.multiprocessing.Event()
        self.start_event = torch.multiprocessing.Event()
    
    def run(self):
        client = NetworkClient.NetworkClient(self.config, 1)
        client.register()
        self.ready.set()
        
        while True:
            self.start_event.wait()
            self.start_event.clear()
            
            client.input_buffer[:] = torch.zeros_like(client.input_buffer) + self.idx
            out = client.predict_inplace()
            self.d[self.idx] = out.numpy()
            print("Worker {}: {}".format(self.idx, out))
                
            self.ready.set()

In [3]:
output_shape = (3, )
output_type = torch.float32

input_shape = {'x': (8, )}
input_type = {'x': torch.float32}

manager = NetworkManager(input_shape, input_type, output_shape, output_type, 1024, TwoLayerNet,
                            network_args=[8, 32, 3], placement={'cpu': 4})
manager.start()

Starting Request Manager
Starting Response Manager
Starting Synchronization Manager
Starting Network b'N\x00' on cpu
Starting Network b'N\x01' on cpu
Starting Network b'N\x02' on cpu
Starting Network b'N\x03' on cpu
Starting Local Network
Synchronizing initial weights


In [4]:
for _ in tqdm(range(10000)):
    with NetworkClient(manager.client_config, 32) as client:
        pass

100%|██████████| 10000/10000 [01:15<00:00, 132.35it/s]


In [7]:
del clients

In [8]:
clients[0]

NameError: name 'clients' is not defined

In [6]:
client.predict_inplace()

tensor([[-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
 

In [7]:
client.deregister()

In [4]:
manager.shutdown()

In [None]:
            
man = Manager()
d = man.dict()
workers = [TestWorker(manager.client_config, i, d) for i in range(128)]

for worker in workers:
    worker.start()
    
for worker in workers:
    worker.ready.wait()

for worker in workers:
    worker.ready.clear()
    worker.start_event.set()
    
for worker in workers:
    worker.ready.wait()    

x = torch.zeros(1, *input_shape)

for i in range(len(workers)):
    with torch.no_grad():
        print("{}:{}".format(i, manager._local_network(x + i).numpy()  - d[i]))

In [23]:
with manager.training_network() as network:
    optimizer = torch.optim.SGD(network.parameters(), 0.0001)
    x = torch.zeros(128, *input_shape) + torch.unsqueeze(torch.arange(0, 128, dtype=torch.float), 1)
    y = torch.zeros(128, *output_shape) + torch.unsqueeze(torch.arange(0, 128, dtype=torch.float), 1)
    
    loss = torch.nn.MSELoss()(network(x), y)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [25]:
multiprocessing.get_start_method()

'fork'

In [17]:
for worker in workers:
    worker.ready.clear()
    worker.start_event.set()
    
for worker in workers:
    worker.ready.wait()    

x = torch.zeros(1, *input_shape)

for i in range(len(workers)):
    with torch.no_grad():
        print("{}:{}".format(i, manager._local_network(x + i).numpy()  - d[i]))

0:[[0. 0. 0.]]
1:[[0. 0. 0.]]
2:[[0. 0. 0.]]
3:[[0. 0. 0.]]
4:[[0. 0. 0.]]
5:[[0. 0. 0.]]
6:[[0. 0. 0.]]
7:[[0. 0. 0.]]
8:[[0. 0. 0.]]
9:[[0. 0. 0.]]
10:[[0. 0. 0.]]
11:[[0. 0. 0.]]
12:[[0. 0. 0.]]
13:[[0. 0. 0.]]
14:[[0. 0. 0.]]
15:[[0. 0. 0.]]
16:[[0. 0. 0.]]
17:[[0. 0. 0.]]
18:[[0. 0. 0.]]
19:[[0. 0. 0.]]
20:[[0. 0. 0.]]
21:[[0. 0. 0.]]
22:[[0. 0. 0.]]
23:[[0. 0. 0.]]
24:[[0. 0. 0.]]
25:[[0. 0. 0.]]
26:[[0. 0. 0.]]
27:[[0. 0. 0.]]
28:[[0. 0. 0.]]
29:[[0. 0. 0.]]
30:[[0. 0. 0.]]
31:[[0. 0. 0.]]
32:[[0. 0. 0.]]
33:[[0. 0. 0.]]
34:[[0. 0. 0.]]
35:[[0. 0. 0.]]
36:[[0. 0. 0.]]
37:[[0. 0. 0.]]
38:[[0. 0. 0.]]
39:[[0. 0. 0.]]
40:[[0. 0. 0.]]
41:[[0. 0. 0.]]
42:[[0. 0. 0.]]
43:[[0. 0. 0.]]
44:[[0. 0. 0.]]
45:[[0. 0. 0.]]
46:[[0. 0. 0.]]
47:[[0. 0. 0.]]
48:[[0. 0. 0.]]
49:[[0. 0. 0.]]
50:[[0. 0. 0.]]
51:[[0. 0. 0.]]
52:[[0. 0. 0.]]
53:[[0. 0. 0.]]
54:[[0. 0. 0.]]
55:[[0. 0. 0.]]
56:[[0. 0. 0.]]
57:[[0. 0. 0.]]
58:[[0. 0. 0.]]
59:[[0. 0. 0.]]
60:[[0. 0. 0.]]
61:[[0. 0. 0.]]
62:[[0. 0. 0.]]
63

In [11]:
manager.shutdown()

In [8]:
x = torch.rand(1, 8)

In [9]:
for _ in range(32):
    for _ in range(1000):
        with torch.no_grad():
            manager._local_network(x)