In [1]:
import torch
from torch import nn
from torch.nn import functional as F

from TorchSpread import NetworkClient, NetworkManager
from TestMulti import ConvNet
from TestCorrectness import TwoLayerNet

from tqdm import tqdm
import numpy as np

from TorchSpread.RemoteManager import RemoteManager, RemoteHandler, RemoteCommands
from TorchSpread.utilities import serialize_int, serialize_buffer, deserialize_buffer
from TorchSpread.NetworkClient import RemoteClient

from torchvision import datasets
from lz4 import frame
import msgpack

In [2]:
output_shape = (3, )
output_type = torch.float32

input_shape = {'x': (8, )}
input_type = {'x': torch.float32}

manager = NetworkManager(input_shape, input_type, output_shape, output_type, 1024, TwoLayerNet,
                            network_args=[8, 32, 3], placement={'cpu': 1}, remote_manager=1234)
manager.start()

Starting Request Manager
Starting Response Manager
Starting Synchronization Manager
Starting Network b'N\x00' on cpu
Starting Local Network
Synchronizing initial weights
Starting remote manager on *:1234


In [7]:
data = {'x': torch.rand(1024, 8)}
# data = [torch.rand(1024, 8)]

In [9]:
with NetworkClient(manager.client_config, 1024) as client:
    print(client.predict(data))

tensor([[ 0.4074,  0.3170, -0.2234],
        [ 0.3462,  0.2686, -0.2014],
        [ 0.2283,  0.2391, -0.1812],
        ...,
        [ 0.2854,  0.1276, -0.1483],
        [ 0.3009,  0.1787, -0.1129],
        [ 0.3561,  0.2527, -0.2110]])


In [8]:
with RemoteClient(manager.client_config, 1024, 'localhost', 1234) as client:
    print(client.predict(data))

tensor([[ 0.4074,  0.3170, -0.2234],
        [ 0.3462,  0.2686, -0.2014],
        [ 0.2283,  0.2391, -0.1812],
        ...,
        [ 0.2854,  0.1276, -0.1483],
        [ 0.3009,  0.1787, -0.1129],
        [ 0.3561,  0.2527, -0.2110]])


In [10]:
with RemoteClient(manager.client_config, 1024, 'localhost', 1234) as client:
    client.predict_async(data)
    print(client.receive_async())

tensor([[ 0.4074,  0.3170, -0.2234],
        [ 0.3462,  0.2686, -0.2014],
        [ 0.2283,  0.2391, -0.1812],
        ...,
        [ 0.2854,  0.1276, -0.1483],
        [ 0.3009,  0.1787, -0.1129],
        [ 0.3561,  0.2527, -0.2110]])


In [11]:
manager.shutdown()

In [6]:
x = torch.rand((100, 64))
x

tensor([[0.4221, 0.2679, 0.7554,  ..., 0.1050, 0.0284, 0.5971],
        [0.6633, 0.6882, 0.4199,  ..., 0.0523, 0.6893, 0.4503],
        [0.9526, 0.6089, 0.2221,  ..., 0.1973, 0.5023, 0.0361],
        ...,
        [0.6948, 0.1353, 0.7705,  ..., 0.7079, 0.1976, 0.5450],
        [0.5291, 0.2652, 0.1406,  ..., 0.2775, 0.9925, 0.0059],
        [0.6024, 0.5355, 0.5793,  ..., 0.0253, 0.1379, 0.0032]])

In [29]:
from lz4 import frame
import _pickle as cPickle

In [97]:
def _serialize_buffer(buffer, compress):
    if isinstance(buffer, dict):
        return {key: _serialize_buffer(val, compress) for key, val in buffer.items()}
    elif isinstance(buffer, (list, tuple)):
        return [_serialize_buffer(val, compress) for val in buffer]
    elif isinstance(buffer, torch.Tensor):
        buffer = buffer.numpy()
        
    pickled = buffer.tobytes()
    pickled = frame.compress(pickled, compression_level=compress)
    
    shape = buffer.shape
    dtype = buffer.dtype.str
    
    return (shape, dtype, pickled)

def serialize_buffer(buffer, compress=3):
    serialized = _serialize_buffer(buffer, compress)
    serialized = cPickle.dumps(serialized)
    return serialized

def _deserialize_buffer(serialized):
    if isinstance(serialized, dict):
        return {key: _deserialize_buffer(val) for key, val in serialized.items()}
    elif isinstance(serialized, list):
        return [_deserialize_buffer(val) for val in serialized]    
    
    shape, dtype, pickled = serialized
    array = np.frombuffer(frame.decompress(pickled), np.dtype(dtype)).reshape(shape)
    return torch.from_numpy(array)
    
def deserialize_buffer(serialized):
    serialized = cPickle.loads(serialized)
    return _deserialize_buffer(serialized)

def _deserialize_buffer_into(to_buffer, serialized, size: int, start_index: int = 0):
    if isinstance(to_buffer, dict):
        for key, to_tensor in to_buffer.items():
            _deserialize_buffer_into(to_tensor, serialized[key], size, start_index)

    elif isinstance(to_buffer, (list, tuple)):
        for to_tensor, from_tensor in zip(to_buffer, serialized):
            _deserialize_buffer_into(to_tensor, from_tensor, size, start_index)

    else:
        shape, dtype, pickled = serialize
        from_buffer = np.frombuffer(frame.decompress(pickled), np.dtype(dtype)).reshape(shape)
        to_buffer[start_index:start_index + size].copy_(from_buffer[:size])
        
def deserialize_buffer_into(to_buffer, serialized, size: int, start_index: int = 0):
    serialized = cPickle.loads(serialized)
    return _deserialize_buffer_into(to_buffer, serialized, size, start_index)

In [83]:
np.ones(128).dtype.

SyntaxError: invalid syntax (<ipython-input-83-1dbc477277ca>, line 1)

In [91]:
buffer = {'x': torch.ones(1024, 64)}

In [98]:
deserialize_buffer(serialize_buffer(buffer, 0))

{'x': tensor([[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]])}

In [23]:
import numpy as np
import dill



In [24]:
%timeit cPickle.dumps((shape, pickled))

1.02 µs ± 5.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [21]:
send_pickle = dill.dumps((shape, pickled))

In [14]:
np.frombuffer(pickled).shape

(3200,)

In [3]:
client = NetworkClient(manager.client_config, 32)
client.register()

In [4]:
client.predict({'x': np.random.rand(12, 8).astype(np.float32)})

tensor([[ 0.0258, -0.3191,  0.2729],
        [-0.0946, -0.2570,  0.1323],
        [ 0.0021, -0.2726,  0.2965],
        [ 0.0216, -0.1899,  0.1824],
        [ 0.0584, -0.1536,  0.2238],
        [ 0.0502, -0.2478,  0.2079],
        [-0.0964, -0.2143,  0.2852],
        [-0.0412, -0.3186,  0.2473],
        [-0.1288, -0.2982,  0.2281],
        [-0.0270, -0.2914,  0.2370],
        [-0.0404, -0.2660,  0.2165],
        [-0.1373, -0.2541,  0.1506]])

In [5]:
client.predict_async({'x': torch.rand(12, 8)})

In [6]:
client.receive_async()

tensor([[ 0.0022, -0.3193,  0.2609],
        [-0.0137, -0.2544,  0.1116],
        [-0.1923, -0.2023,  0.0530],
        [ 0.0327, -0.1611,  0.2772],
        [ 0.0698, -0.1876,  0.1695],
        [-0.0821, -0.2483,  0.0891],
        [-0.0438, -0.2481,  0.2114],
        [-0.1529, -0.2210,  0.1987],
        [ 0.0363, -0.1481,  0.2720],
        [-0.0607, -0.3359,  0.2068],
        [-0.2062, -0.2370,  0.1008],
        [-0.0058, -0.2721,  0.2263]])

In [5]:
manager._local_network({'x': torch.zeros(12, 8)})

tensor([[ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633],
        [ 0.1722,  0.2221, -0.1633]], grad_fn=<AddmmBackward>)

In [1]:
manager.shutdown()

NameError: name 'manager' is not defined

In [None]:
for _ in tqdm(range(10000)):
    with NetworkClient(manager.client_config, 32) as client:
        pass

In [7]:
del clients

In [8]:
clients[0]

NameError: name 'clients' is not defined

In [6]:
client.predict_inplace()

tensor([[-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
        [-0.0816, -0.0583, -0.2316],
 

In [7]:
client.deregister()

In [4]:
manager.shutdown()

In [None]:
            
man = Manager()
d = man.dict()
workers = [TestWorker(manager.client_config, i, d) for i in range(128)]

for worker in workers:
    worker.start()
    
for worker in workers:
    worker.ready.wait()

for worker in workers:
    worker.ready.clear()
    worker.start_event.set()
    
for worker in workers:
    worker.ready.wait()    

x = torch.zeros(1, *input_shape)

for i in range(len(workers)):
    with torch.no_grad():
        print("{}:{}".format(i, manager._local_network(x + i).numpy()  - d[i]))

In [23]:
with manager.training_network() as network:
    optimizer = torch.optim.SGD(network.parameters(), 0.0001)
    x = torch.zeros(128, *input_shape) + torch.unsqueeze(torch.arange(0, 128, dtype=torch.float), 1)
    y = torch.zeros(128, *output_shape) + torch.unsqueeze(torch.arange(0, 128, dtype=torch.float), 1)
    
    loss = torch.nn.MSELoss()(network(x), y)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [25]:
multiprocessing.get_start_method()

'fork'

In [17]:
for worker in workers:
    worker.ready.clear()
    worker.start_event.set()
    
for worker in workers:
    worker.ready.wait()    

x = torch.zeros(1, *input_shape)

for i in range(len(workers)):
    with torch.no_grad():
        print("{}:{}".format(i, manager._local_network(x + i).numpy()  - d[i]))

0:[[0. 0. 0.]]
1:[[0. 0. 0.]]
2:[[0. 0. 0.]]
3:[[0. 0. 0.]]
4:[[0. 0. 0.]]
5:[[0. 0. 0.]]
6:[[0. 0. 0.]]
7:[[0. 0. 0.]]
8:[[0. 0. 0.]]
9:[[0. 0. 0.]]
10:[[0. 0. 0.]]
11:[[0. 0. 0.]]
12:[[0. 0. 0.]]
13:[[0. 0. 0.]]
14:[[0. 0. 0.]]
15:[[0. 0. 0.]]
16:[[0. 0. 0.]]
17:[[0. 0. 0.]]
18:[[0. 0. 0.]]
19:[[0. 0. 0.]]
20:[[0. 0. 0.]]
21:[[0. 0. 0.]]
22:[[0. 0. 0.]]
23:[[0. 0. 0.]]
24:[[0. 0. 0.]]
25:[[0. 0. 0.]]
26:[[0. 0. 0.]]
27:[[0. 0. 0.]]
28:[[0. 0. 0.]]
29:[[0. 0. 0.]]
30:[[0. 0. 0.]]
31:[[0. 0. 0.]]
32:[[0. 0. 0.]]
33:[[0. 0. 0.]]
34:[[0. 0. 0.]]
35:[[0. 0. 0.]]
36:[[0. 0. 0.]]
37:[[0. 0. 0.]]
38:[[0. 0. 0.]]
39:[[0. 0. 0.]]
40:[[0. 0. 0.]]
41:[[0. 0. 0.]]
42:[[0. 0. 0.]]
43:[[0. 0. 0.]]
44:[[0. 0. 0.]]
45:[[0. 0. 0.]]
46:[[0. 0. 0.]]
47:[[0. 0. 0.]]
48:[[0. 0. 0.]]
49:[[0. 0. 0.]]
50:[[0. 0. 0.]]
51:[[0. 0. 0.]]
52:[[0. 0. 0.]]
53:[[0. 0. 0.]]
54:[[0. 0. 0.]]
55:[[0. 0. 0.]]
56:[[0. 0. 0.]]
57:[[0. 0. 0.]]
58:[[0. 0. 0.]]
59:[[0. 0. 0.]]
60:[[0. 0. 0.]]
61:[[0. 0. 0.]]
62:[[0. 0. 0.]]
63

In [11]:
manager.shutdown()

In [8]:
x = torch.rand(1, 8)

In [9]:
for _ in range(32):
    for _ in range(1000):
        with torch.no_grad():
            manager._local_network(x)