In [1]:
import copy
import functools
import inspect
import json
import random

import torch
from torch.autograd import Variable

In [2]:
class Worker():
    """In-process stand-in for a worker that owns tensors.

    Workers exchange tensors as JSON strings (produced by ``ser`` and consumed
    by ``de`` on ``torch.FloatTensor``) and commands as plain dicts built by
    ``function2json``.
    """

    # NOTE(review): this is a class attribute, so every Worker instance shares
    # this single registry (the methods below only add keys to it; none rebinds
    # it per instance). Lookups elsewhere in the notebook may only succeed
    # because of this sharing -- audit the callers before giving each worker
    # its own dict.
    objects = {}

    def __init__(self, addr=0):
        # Integer address; serialized tensors record their owner by this value.
        self.addr = addr

    def register_object(self, obj, is_pointer_to_remote):
        """Assign ``obj`` a random id, make this worker its owner and store it.

        Args:
            obj: object to register; its ``id``, ``owner`` and
                ``is_pointer_to_remote`` attributes are set.
            is_pointer_to_remote: currently ignored -- the object is always
                marked local (False). NOTE(review): some callers pass True for
                freshly computed results, so wiring the flag through would
                change behavior; fix those callers first.

        Returns:
            ``obj`` itself.
        """
        # randint needs integer bounds: a float such as 1e10 is deprecated as a
        # randrange() argument and rejected outright by Python 3.12+.
        obj.id = random.randint(0, 10 ** 10)
        obj.owner = self
        obj.is_pointer_to_remote = False
        self.objects[obj.id] = obj
        return obj

    def send_obj(self, obj, to):
        """Send a serialized copy of ``obj`` to worker ``to``; ``obj`` becomes a pointer.

        The local ``obj`` keeps its data; only its ``owner`` and
        ``is_pointer_to_remote`` attributes change. Returns ``obj``.
        """
        to.receive_obj(obj.ser())
        obj.is_pointer_to_remote = True
        obj.owner = to
        return obj

    def request_obj(self, obj):
        """Fetch ``obj``'s data from its owner and store a local copy on this worker."""
        response = obj.owner.receive_obj_request(obj.id)
        return self.receive_obj(response)

    def receive_obj_request(self, obj_id):
        """Return the JSON serialization (with data) of the object stored under ``obj_id``."""
        return self.objects[obj_id].ser()

    def receive_obj(self, msg):
        """Deserialize ``msg`` (a JSON string) and store the result on this worker.

        Returns:
            The rebuilt tensor, owned by this worker and marked local.

        Raises:
            ValueError: if ``msg`` describes a type other than torch.FloatTensor.
        """
        dic = json.loads(msg)
        if dic['type'] != 'torch.FloatTensor':
            # Fail loudly instead of silently returning None.
            raise ValueError("unsupported object type: %r" % (dic['type'],))
        obj = torch.FloatTensor.de(dic)
        obj.is_pointer_to_remote = False
        obj.owner = self
        self.objects[obj.id] = obj
        return obj

    def send_command(self, command, to):
        """Deliver a command dict to worker ``to`` and return its JSON reply."""
        return to.receive_command(command)

    def receive_command(self, command):
        """Execute ``command`` on this worker and return the JSON-encoded result.

        Raises:
            ValueError: if ``command['base_type']`` is not supported (this used
                to surface as an UnboundLocalError).
        """
        if command['base_type'] == 'torch.FloatTensor':
            raw_response = torch.FloatTensor.process_command(self, command)
        else:
            raise ValueError("unsupported base_type: %r" % (command['base_type'],))
        return json.dumps(raw_response)

    def process_response(self, response):
        """Decode a reply from ``receive_command`` into tensors.

        Returns:
            None for an empty reply, the tensor itself for one entry, or a list
            of tensors for several entries.

        Raises:
            ValueError: if the reply is not a list of serialized tensors (for
                example the remote answered "command not found") or contains an
                unsupported tensor type.
        """
        messages = json.loads(response)
        if not isinstance(messages, list):
            raise ValueError("unexpected response from worker: %r" % (messages,))

        out_tensors = []
        for raw_msg in messages:
            msg = json.loads(raw_msg)
            if msg["type"] != "torch.FloatTensor":
                # Never append an unbound/stale tensor from a previous iteration.
                raise ValueError("unsupported tensor type in response: %r" % (msg["type"],))
            out_tensors.append(torch.FloatTensor.de(msg))

        if len(out_tensors) > 1:
            return out_tensors
        if len(out_tensors) == 1:
            return out_tensors[0]
        return None

    def function2json(self, obj, name, frame, ix):
        """Describe a pending method call as a JSON-serializable command dict.

        Args:
            obj: the receiver; ``obj.type()`` becomes ``base_type``.
            name: method name to run on the owning worker (e.g. 'add').
            frame: the calling method's frame; its named arguments are read
                with ``inspect.getargvalues`` and each must carry an ``.id``.
            ix: id of the receiver on the owning worker.

        Returns:
            dict with keys id, command, base_type, args, varargs, keywords,
            values (argument ids) and types (type names of those ids).
        """
        args, varargs, keywords, values = inspect.getargvalues(frame)

        command = {}
        command['id'] = ix  # placeholder id for the data that the worker holds
        command['command'] = name
        command['base_type'] = obj.type()
        command['args'] = args
        command['varargs'] = varargs
        command['keywords'] = keywords
        command['values'] = [values[arg].id for arg in args]
        # Type *names* rather than type objects keep the command JSON-serializable.
        command['types'] = [type(val).__name__ for val in command['values']]

        return command

    # NOTE(review): the three stubs below are unimplemented placeholders that
    # return None. They declare no `self`, so calling them on an instance binds
    # the instance to their first parameter (or raises TypeError).
    def json2function(command):
        """Placeholder; not implemented."""

    def object2json():
        """Placeholder; not implemented."""

    def json2object():
        """Placeholder; not implemented."""
        
me = Worker(0)

# Remote workers get addresses 1..10, so workers[i].addr == i + 1 (address 0 is `me`).
workers = [Worker(addr) for addr in range(1, 11)]

In [3]:
# GENERIC

def assign_workers():
    """Decorator factory that routes tensor methods to the tensor's owner.

    The decorated function is always called. If its first argument is a
    pointer to a remote tensor (``is_pointer_to_remote`` is true), the
    function's return value is treated as a command dict. That command is sent
    to the tensor's ``owner`` and the reply is decoded with
    ``me.process_response``. Otherwise the function's own result is returned
    unchanged.
    """
    def decorate(func):
        # wraps() keeps func's __name__/__doc__ on the installed method.
        @functools.wraps(func)
        def send_to_workers(*args, **kwargs):
            target = args[0]
            if not target.is_pointer_to_remote:
                return func(*args, **kwargs)
            command = func(*args, **kwargs)
            response = me.send_command(command, target.owner)
            return me.process_response(response)
        return send_to_workers
    return decorate


# FLOAT TENSOR FUNCTIONS
def hook_float_tensor___init__():
    """Patch torch.FloatTensor.__init__ so every new tensor is registered.

    The installed initializer takes the construction data as its first
    argument, which it does not forward to the base __init__. It also takes an
    optional ``owner`` worker, whose default is the ``me`` in scope when this
    hook runs. The new tensor is registered with that owner as a local object.
    """
    def new___init__(self, tensor, owner=me, *args, **kwargs):
        base_init = super(torch.FloatTensor, self).__init__
        base_init(*args, **kwargs)
        # register_object mutates the tensor in place; its return value is unused.
        owner.register_object(self, False)

    setattr(torch.FloatTensor, '__init__', new___init__)


def hook_float_tensor_add():
    """Install a worker-aware ``add`` on torch.FloatTensor.

    For a pointer tensor, ``add`` returns a command dict, which the
    ``assign_workers`` decorator ships to the tensor's owner. For a local
    tensor, it runs the original ``add`` and registers the result with the
    worker that owns ``self``. The original method is stashed once as
    ``torch.FloatTensor.old_add``, so re-running this hook does not wrap an
    already-wrapped ``add``.
    """
    @assign_workers()
    def new_add(self, other):
        if self.is_pointer_to_remote:
            # Describe this call (self/other ids) for the owning worker.
            frame = inspect.currentframe()
            return self.owner.function2json(self, 'add', frame, self.id)
        result = self.old_add(other)
        # The result lives where it was computed: register it with self's
        # owner rather than unconditionally with the local `me`.
        return self.owner.register_object(result, True)

    # hasattr only swallows AttributeError, unlike the previous bare `except:`.
    if not hasattr(torch.FloatTensor, 'old_add'):
        torch.FloatTensor.old_add = torch.FloatTensor.add

    torch.FloatTensor.add = new_add
    
def hook_float_tensor_serde():
    """Attach JSON (de)serialization helpers to torch.FloatTensor.

    ``ser(include_data=True)`` encodes a tensor's type, optional data, id and
    owner address as a JSON string. ``de(msg)`` accepts such a string or an
    already-decoded dict. It rebuilds a full tensor when ``data`` is present,
    otherwise an empty placeholder. In both cases it restores ``id`` and
    resolves ``owner`` from the recorded address.
    """
    def ser(self, include_data=True):
        msg = {'type': 'torch.FloatTensor'}
        if include_data:
            msg['data'] = self.tolist()
        msg['id'] = self.id
        msg['owner'] = self.owner.addr
        return json.dumps(msg)

    def _worker_at(addr):
        # Resolve an address to its Worker. `me` has address 0 and workers[i]
        # has address i + 1, so indexing `workers` by the address directly
        # (as this code used to) returns the wrong worker.
        for worker in [me] + list(workers):
            if worker.addr == addr:
                return worker
        raise KeyError('no worker with address %r' % (addr,))

    def de(msg):
        if isinstance(msg, str):
            msg = json.loads(msg)
        if 'data' in msg:
            v = torch.FloatTensor(msg['data'])
        else:
            # NOTE(review): data-less placeholder for a remote tensor. It may not
            # carry `is_pointer_to_remote` if torch.zeros bypasses the hooked
            # __init__ -- confirm whether it should be flagged as a pointer.
            v = torch.zeros(0)
        v.owner = _worker_at(msg['owner'])
        v.id = msg['id']
        return v

    torch.FloatTensor.ser = ser
    torch.FloatTensor.de = de
    
def hook_float_tensor_send():
    """Attach ``send(new_owner)``: transfer a tensor and return the same object as a pointer."""
    def send(self, new_owner):
        # The tensor's current owner performs the transfer.
        current_owner = self.owner
        current_owner.send_obj(self, new_owner)
        return self

    setattr(torch.FloatTensor, 'send', send)
    
def hook_float_tensor_get():
    """Attach ``get()``: ask the local worker ``me`` to fetch this tensor from its owner."""
    def get(self):
        # Returns the copy produced by me.request_obj, not `self`.
        return me.request_obj(self)

    setattr(torch.FloatTensor, 'get', get)
    
def hook_float_tensor_process_command():
    """Attach ``process_command(worker, command)`` to torch.FloatTensor.

    Only the 'add' command is supported. It looks up the first two ids in
    ``command['values']`` in ``worker.objects``, adds them, and returns a
    one-element list holding the result serialized without data. Any other
    command yields the string "command not found".
    """
    def process_command(worker, command):
        if command['command'] != 'add':
            return "command not found"
        store = worker.objects
        lhs = store[int(command['values'][0])]
        rhs = store[int(command['values'][1])]
        total = lhs.add(rhs)
        return [total.ser(False)]

    torch.FloatTensor.process_command = process_command
    
# Install every FloatTensor patch defined above, in the original order.
for install_hook in (
    hook_float_tensor_add,
    hook_float_tensor___init__,
    hook_float_tensor_serde,
    hook_float_tensor_send,
    hook_float_tensor_process_command,
    hook_float_tensor_get,
):
    install_hook()
del install_hook

In [4]:
# Ship two tensors to workers[5], add them there, then pull the sum back locally.
a, b = [torch.FloatTensor(data).send(workers[5]) for data in ([1, 2, 3, 4], [1, 1, 1, 1])]
c = a.add(b).get()

In [8]:
# `send` (via Worker.send_obj) re-pointed `a.owner` at the worker it was shipped to, workers[5].
a.owner

<__main__.Worker at 0x113da92b0>

In [10]:
# `c` is local after `.get()` (receive_obj marks it as not a pointer), so this add runs here via old_add.
# NOTE(review): `b` is still a pointer to workers[5]; its local object presumably keeps the data it had
# before `send`, and that is what this local add reads -- confirm mixing a local tensor with a remote
# pointer is intended.
d = c.add(b)

In [11]:
# The locally computed `d` is owned by `me` (the local worker).
d.owner

<__main__.Worker at 0x113da9278>

In [7]:
# NOTE(review): the saved counts show this cell ran as In [7], before `d = c.add(b)` ran as In [10];
# re-run the notebook top to bottom to reproduce this output without hidden state.
d


 3
 4
 5
 6
[torch.FloatTensor of size 4]