In [1]:
import rpyc
import numpy as np

In [2]:
c = rpyc.connect("localhost", 18861)
c.root

<__main__.ProcessingService object at 0x7f1da3b3ed00>

In [4]:
c.root.process([10, 10])

[100, 100]

In [5]:
c.close()

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x)

In [6]:
import rpyc
import numpy as np
from computer_utils import SimpleSplitter, SimpleMerger, Overseer
from multiprocessing import Queue, Manager, Process
import torch

rpyc.core.protocol.DEFAULT_CONFIG['allow_pickle'] = True


class Computer:
    
    class Specification:
        
        def __init__(self, weight, splitter_mode, op_name, merge_ax):
            self.weight = weight
            self.splitter_mode = splitter_mode
            self.op_name = op_name
            self.merge_ax = merge_ax
    
    def __init__(self, data, path):
        self.data = data
        self.object_queue = Queue()
        self.model = torch.load(path)
        
        self.specs = [
            self.Specification(
                weight = self.model.fc1.weight.detach().numpy(),
                splitter_mode = 'weight',
                op_name = 'DotProduct',
                merge_ax = 0
            ), 
            self.Specification(
                weight = self.model.fc1.bias.detach().numpy(),
                splitter_mode = 'both',
                op_name = 'Addition',
                merge_ax = 0
            ),
            self.Specification(
                weight = None,
                splitter_mode = 'data',
                op_name = 'ReLu',
                merge_ax = 0
            ),
            self.Specification(
                weight = self.model.fc2.weight.detach().numpy(),
                splitter_mode = 'weight',
                op_name = 'DotProduct',
                merge_ax = 0
            ), 
            self.Specification(
                weight = self.model.fc2.bias.detach().numpy(),
                splitter_mode = 'both',
                op_name = 'Addition',
                merge_ax = 0
            ),
            self.Specification(
                weight = np.array([[0]]),
                splitter_mode = 'weight',
                op_name = 'SoftMax',
                merge_ax = 0
            ),
        ] 
        
        self.op_names = []
        self.model.eval()
        
        self.conns = {}
        
        self.splitter = SimpleSplitter()
        self.merger = SimpleMerger()
        
    def add_connection(self, ip = 'localhost', port = 18861):
        self.conns[f'{ip}:{port}'] = rpyc.connect(ip, port)
        
    def compute(self):
        for spec in self.specs:
            # creating a queue for data storage
            queue = Queue()
            sample_num = self.splitter.split(self.data, spec.weight, queue, mode = spec.splitter_mode)
            
            # creating Overseers for the Processors
            self.ovs = [Overseer(conn, spec.op_name) for conn in self.conns.values()]
            
            # Manager for the shared resources
            manager = Manager()
            res = manager.dict()
            # initializing task completeness list
            shared_status = manager.list([0] * sample_num)
            
            # created added task list to the Overseer
            processes = [Process(target=O.put, args = (queue, res, shared_status)) for O in self.ovs]
            [p.start() for p in processes]
            [p.join() for p in processes]
            
            # merging the resuts 
            self.data = self.merger.merge(res, axis = spec.merge_ax)
        
        # closing connections
        for c in self.conns.values():
            c.close()
        
        return self.data
    
if __name__ == "__main__":
    data = np.ones((1, 784))
    cmp = Computer(data, 'model.pth')
    cmp.add_connection()
    cmp.add_connection(port = 12258)
    q = cmp.compute()
    print('Resulting data:', q, sep = '\n')

Stack!
Stack!
Stack!
Stack!
Stack!
Stack!
Resulting data:
[[4.38180807e-04 2.16057778e-12 4.38012854e-02 6.86038996e-01
  1.79831414e-11 3.68350047e-02 2.05820696e-07 3.02527847e-10
  2.32886321e-01 5.75566192e-09]]


In [4]:
q

array([[4.38180807e-04, 2.16057778e-12, 4.38012854e-02, 6.86038996e-01,
        1.79831414e-11, 3.68350047e-02, 2.05820696e-07, 3.02527847e-10,
        2.32886321e-01, 5.75566192e-09]])

In [15]:
model = torch.load('model.pth')
model.eval()
model.fc1(torch.ones((1, 784), requires_grad = False))

tensor([[ 4.1781,  6.7648, -2.9683,  6.1687,  3.4663,  0.9767,  0.6334, -0.0778,
          0.0219,  6.3792,  7.0007, -2.3949,  1.8431, -1.2723,  4.8371,  3.7381,
          3.0966,  0.5337, -0.0786, -2.5660, -5.8470,  2.5215,  1.5484,  2.8317,
          2.2553, -2.8291,  0.8447,  2.5690,  3.1215, -1.9210, -0.3486,  1.2491,
          5.9591,  4.0934,  3.4655,  8.0855,  4.9449, -3.3025,  2.2341, -0.6910,
         -0.8201,  5.1374,  8.6047,  4.6303, -3.5164,  6.0302,  3.7816,  4.2053,
         -0.9560,  2.1111]], grad_fn=<AddmmBackward>)

In [12]:
weights = model.fc1.weight.detach()
data = torch.ones((1, 784))

In [17]:
model.fc1.bias

Parameter containing:
tensor([-0.1251,  0.1645, -0.0103,  0.0858,  0.2245, -0.0058,  0.1645,  0.0218,
         0.1902,  0.0668, -0.0794, -0.1141, -0.1522, -0.0211,  0.0174, -0.1217,
        -0.0869,  0.0286, -0.0041,  0.1410,  0.1522,  0.1439,  0.0779,  0.0437,
        -0.0465, -0.0324,  0.1304,  0.0986,  0.0846,  0.1418, -0.0851,  0.1867,
        -0.0724,  0.1731, -0.0317, -0.0270,  0.0339,  0.1084, -0.0116,  0.0004,
         0.0542, -0.0673, -0.0363,  0.0425,  0.0897,  0.0531,  0.1969,  0.0214,
         0.0448, -0.0021], requires_grad=True)

In [8]:
q.get()[2].shape

(784,)

In [2]:
data = np.ones((2, 1))

In [3]:
data

array([[1.],
       [1.]])

In [4]:
cmp = Computer(data, 10)

In [5]:
cmp.add_connection()

In [6]:
cmp.process()

TypeError: put() missing 1 required positional argument: 'samples'

In [10]:
len({1: 5, 2: 3}.items())

2