In [1]:
import zmq
import torch
import pickle
import torch
from io import BytesIO
from torch.multiprocessing.reductions import ForkingPickler, init_reductions, reduce_tensor
import multiprocessing
from time import time
import numpy as np
import utilities
import cbor

import NetworkWorker
import NetworkManager
import NetworkClient

from NetworkManager import RequestManager, ResponseManager

# Share tensors between processes through named shared-memory files instead of
# passing file descriptors around.
# NOTE(review): presumably chosen to stay under the per-process open-file
# limit with many worker processes — confirm.
SHARING_STRATEGY = "file_system"
torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY)

In [2]:
class TwoLayerNet(torch.nn.Module):
    """Small fully-connected ReLU network served by the NetworkManager benchmark.

    Despite its name it owns three ``torch.nn.Linear`` modules:

    * ``linear1``: input projection ``D_in -> H``
    * ``hidden``:  square layer ``H -> H`` applied ``num_hidden_passes`` times
      with *shared* weights
    * ``linear2``: output projection ``H -> D_out``

    Args:
        D_in: number of input features.
        H: hidden width.
        D_out: number of output features.
        num_hidden_passes: how many times ``hidden`` (each followed by a ReLU)
            is applied in ``forward``. Defaults to 2, the previously hard-coded
            value. Must be non-negative.

    Raises:
        ValueError: if ``num_hidden_passes`` is negative.
    """

    def __init__(self, D_in, H, D_out, num_hidden_passes=2):
        super().__init__()
        if num_hidden_passes < 0:
            raise ValueError(
                f"num_hidden_passes must be >= 0, got {num_hidden_passes}"
            )
        # Creation order is kept identical so seeded initialisation and the
        # state_dict keys (used for weight synchronisation) are unchanged.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.hidden = torch.nn.Linear(H, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        self.num_hidden_passes = num_hidden_passes

    def forward(self, x):
        """Map a ``(..., D_in)`` tensor to ``(..., D_out)``.

        ReLU is implemented as ``clamp(min=0)`` after the input projection and
        after every pass through the shared hidden layer; the output layer is
        linear (no activation).
        """
        h_relu = self.linear1(x).clamp(min=0)
        for _ in range(self.num_hidden_passes):
            h_relu = self.hidden(h_relu).clamp(min=0)
        return self.linear2(h_relu)

In [3]:
# Per-sample tensor signatures exchanged between clients and the served network.
input_shape, input_type = (8, ), torch.float32
output_shape, output_type = (3, ), torch.float32

In [4]:
# Build and start the manager. Layer widths are derived from the declared
# per-sample shapes so the model and the I/O signature cannot drift apart.
manager = NetworkManager.NetworkManager(
    TwoLayerNet,
    input_shape, input_type,
    output_shape, output_type,
    D_in=input_shape[0], H=32, D_out=output_shape[0],
    batch_size=64,
)
manager.start()

Starting Request Manager
Starting Response Manager
Starting Synchronization Manager
Starting Network b'N\x00'
Starting Local Network
Synchronizing initial weights


In [5]:
class TestWorker(torch.multiprocessing.Process):
    """Benchmark process that issues predictions through a ``NetworkClient``.

    Handshake with the parent (notebook) process:

    * ``ready`` is set once the client has registered, and again after each
      round of ``count`` predictions completes.
    * ``start_event`` is set by the parent to request one round.
    * ``stop()`` asks the process to leave its loop and exit.

    Args:
        config: client configuration forwarded to ``NetworkClient.NetworkClient``
            (the notebook passes ``manager.client_config``).
        count: number of ``predict_inplace`` calls per round.
    """

    def __init__(self, config, count = 1_000):
        super().__init__()
        self.config = config
        self.count = count

        self.ready = torch.multiprocessing.Event()
        self.start_event = torch.multiprocessing.Event()
        # Without an exit signal run() never returns, leaving these non-daemon
        # processes blocked on start_event after the benchmark is over.
        self.stop_event = torch.multiprocessing.Event()

    def stop(self):
        """Ask the worker to exit after its current round (call from the parent)."""
        self.stop_event.set()
        # Wake the worker if it is blocked waiting for the next round.
        self.start_event.set()

    def run(self):
        # NOTE(review): the meaning of the second argument (1) is defined by
        # NetworkClient — confirm.
        client = NetworkClient.NetworkClient(self.config, 1)
        client.register()
        self.ready.set()

        while True:
            self.start_event.wait()
            self.start_event.clear()
            if self.stop_event.is_set():
                break

            for _ in range(self.count):
                client.predict_inplace()

            self.ready.set()

In [6]:
# Benchmark configuration: NUM_WORKERS client processes, each issuing
# PREDICTIONS_PER_WORKER predictions per round.
NUM_WORKERS = 32
PREDICTIONS_PER_WORKER = 1_000

workers = [
    TestWorker(manager.client_config, PREDICTIONS_PER_WORKER)
    for _ in range(NUM_WORKERS)
]

for worker in workers:
    worker.start()

# Wait for every client to register. Poll with a timeout so that a worker
# which dies during registration raises here instead of hanging the kernel.
for worker in workers:
    while not worker.ready.wait(timeout=1.0):
        if not worker.is_alive():
            raise RuntimeError(
                f"worker pid={worker.pid} exited with code {worker.exitcode} "
                "before becoming ready"
            )

In [10]:
# Run one timed benchmark round: release every worker, then wait for all of
# them to finish their `count` predictions.
t_start = time()
for worker in workers:
    worker.ready.clear()
    worker.start_event.set()

# Poll with a timeout so a worker that crashes mid-round raises instead of
# hanging the kernel forever.
for worker in workers:
    while not worker.ready.wait(timeout=1.0):
        if not worker.is_alive():
            raise RuntimeError(
                f"worker pid={worker.pid} exited with code {worker.exitcode} "
                "during the benchmark round"
            )
elapsed = time() - t_start

total_predictions = sum(worker.count for worker in workers)
rate = total_predictions / elapsed if elapsed > 0 else float("inf")
print(f"{total_predictions:,} predictions in {elapsed:.2f}s ({rate:,.0f} predictions/s)")

In [11]:
# TestWorker processes loop forever waiting for start_event and are
# non-daemon, so stop and reap them explicitly before tearing down the
# manager; otherwise they outlive it (and multiprocessing joins non-daemon
# children at interpreter exit, which would block).
for worker in workers:
    if worker.is_alive():
        worker.terminate()
for worker in workers:
    worker.join()
manager.shutdown()

In [8]:
# One random sample (batch of 1, 8 features) for the single-process baseline.
x = torch.rand((1, 8))

In [9]:
# Single-process baseline: the same total number of predictions as the worker
# benchmark (32 workers x 1_000), run directly on the manager's local network.
# NOTE(review): in file order this cell follows manager.shutdown() (it was
# executed as In [9], before In [11]); confirm _local_network is still usable
# after shutdown on a Restart & Run All.
n_predictions = 32 * 1000
t_start = time()
# no_grad is loop-invariant: enter it once rather than once per prediction.
with torch.no_grad():
    for _ in range(n_predictions):
        manager._local_network(x)
elapsed = time() - t_start

rate = n_predictions / elapsed if elapsed > 0 else float("inf")
print(f"{n_predictions:,} local predictions in {elapsed:.2f}s ({rate:,.0f} predictions/s)")