### Notes

Note, this requires using PyTorch v0.3.1.  Somewhere between 0.3.1 and 0.4.0 parts of the backend were significantly rewritten, preventing us from performing the following hacks. (Likely has to do with them fusing Variable and Tensor).  That may change once their new API stabilizes.

The nice thing about how this is working is that it should be general enough to work for compute, tree, and federated modes of Grid, depending on how the `receive` function works under the hood.

In [1]:
from grid.clients.torch import TorchClient

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
client = TorchClient(verbose = False)


[34mUPDATE: [0mConnecting to IPFS... this can take a few seconds...

[32mSUCCESS: [0mConnected!!! - My ID: QmXJMbiCqQdFCUjwy63GMUDDKCfEabJRYo2RHPjheCW8mc

[34mUPDATE: [0mQuerying known workers...
	WORKER: /p2p-circuit/ipfs/QmXkWUybbTnfvFH8SUcrug6RGTLYTB23gSockKLxueR1vQ...[32mSUCCESS!!![0m
	WORKER: /p2p-circuit/ipfs/Qmaosc64H6Y29VFCFYJzJXCX9AuRp7RCsekLmajHNVEARD...[32mSUCCESS!!![0m
	WORKER: /p2p-circuit/ipfs/QmQabt3SWuDvjse9z7GAcH2BGQv4wH8bumkd4x5oXN2obX...[32mSUCCESS!!![0m
	WORKER: /p2p-circuit/ipfs/Qme8SQLibzaAPQSS4GRFQCqAXqVPVknZeDLPqeePYYka8d...[32mSUCCESS!!![0m

[34mUPDATE: [0mSearching for IPFS nodes - 21 found overall - 3 are OpenMined workers          

[32mSUCCESS: [0mFound 3 OpenMined nodes!!!



In [3]:
service_self = client.services['torch_service']

In [4]:
# service_self = client.services['torch_service']
# def hook_float_tensor___init__(service_self):
#     torch.FloatTensor.old___init__ = torch.FloatTensor.__init__
#     def new___init__(self, *args, **kwargs):
#         self.old___init__(*args, **kwargs)
#         self = service_self.register_object(self,False)

#     torch.FloatTensor.__init__ = new___init__

In [5]:
import torch
import inspect
from torch.autograd import Variable
import random
import re
from functools import wraps, partial, partialmethod
from types import *
import imp
# from contextlib import contextmanager

In [6]:
tensor_types = [torch.FloatTensor,
                torch.DoubleTensor,
                torch.HalfTensor,
                torch.ByteTensor,
                torch.CharTensor,
                torch.ShortTensor,
                torch.IntTensor,
                torch.LongTensor]

In [7]:
def get_tensorvars(command):
    args = command['args']
    kwargs = command['kwargs']
    arg_types = command['arg_types']
    kwarg_types = command['kwarg_types']
    tensorvar_args = [args[i] for i in range(len(args)) if arg_types[i] in tensor_types]
    tensorvar_kwargs = [kwargs[i][1] for i in range(len(kwargs)) if kwarg_types[i] in tensor_types]
    return tensorvar_args + tensorvar_kwargs
    
def check_tensorvars(tensorvars):
    has_remote = any([tensorvar.is_pointer_to_remote for tensorvar in tensorvars])
    multiple_owners = len(set([tensorvar.owner for tensorvar in tensorvars])) != 1
    return has_remote, multiple_owners

In [8]:
def assign_workers_function(worker_ids):
    def decorate(func):
        @wraps(func)
        def send_to_workers(*args, **kwargs):
            part = func(*args, **kwargs)
            command = compile_command(part)
            tensorvars = get_tensorvars(command)
            has_remote, multiple_owners = check_tensorvars(tensorvars)
            if not has_remote:
                return part.func(*args, **kwargs)
            elif multiple_owners:
                raise NotImplementedError('MPC not yet implemented: Torch objects need to be on the same machine in order to compute with them.')
            else:
                for worker in worker_ids:
                    print("Placeholder print for sending command to worker {}".format(worker))
                    args, kwargs = send_command(command)
                receive_commands(worker_ids)  ## Probably needs to happen async
                return args, kwargs
        return send_to_workers
    return decorate

In [9]:
def assign_workers_method(worker_ids):
    def decorate(method):
        @wraps(method)
        def send_to_workers(self, *args, **kwargs):
            part = method(self, *args, **kwargs)
            if self.is_pointer_to_remote:
                command = compile_command(part)
                for worker in worker_ids:
                    print("Placeholder print for sending command to worker {}".format(worker))
                    args, kwargs = send_command(command)
                receive_commands(worker_ids)  ## Probably needs to happen async
                return args, kwargs
            else:
                result = part.func(self, *args, **kwargs)
                if type(result) in tensor_types:
                    my_service = self.worker.services['torch_service']
                    result = my_service.register_object(result, False)
                return result
        return send_to_workers
    return decorate

In [10]:
# # Slightly modified to remove parent class dependency
# torch.FloatTensor.old___init__ = torch.FloatTensor.__init__
# def hook_float_tensor___init__():
#     def new___init__(self, tensor, owner=client.services['torch_service'], *args, **kwargs):
#         super(torch.FloatTensor, self).__init__(*args, **kwargs)
#         self = owner.register_object(self, False)

#     torch.FloatTensor.__init__ = new___init__

In [11]:
# service_self = client.services['torch_service']
# def hook_float_tensor___init__(service_self):
#     def new___init__(self, *args, **kwargs):
#         super(torch.FloatTensor, self).__init__(*args, **kwargs)
#         self = service_self.register_object(self,False)

#     torch.FloatTensor.__init__ = new___init__

In [12]:
# def assign_workers_factory(worker_ids):
#     def decorate(method):
#         @wraps(method)
#         def send_to_workers(self, *args, **kwargs):
#             part = method(self, *args, **kwargs)
#             command = compile_command(part)
#             for worker in worker_ids:
#                 print("Placeholder print for sending command to worker {}".format(worker))
#                 args, kwargs = send_command(command)
#             receive_commands(worker_ids)  ## Probably needs to happen async
#             return old_init(*args, **kwargs)
#         return send_to_workers
#     return decorate

In [13]:
service_self = client.services['torch_service']
def hook_tensor___init__(service_self, tensor_type):
    def new___init__(self, tensor, *args, **kwargs):
        super(tensor_type, self).__init__(*args, **kwargs)
        self = service_self.register_object(self,False)

    tensor_type.__init__ = new___init__

In [14]:
def hook_tensor___repr__(service_self, tensor_type):
        def __repr__(self):
            if(service_self.worker.id == self.owner):
                return self.old__repr__()
            else:
                return "[ {} - Location:{} ]".format(tensor_type, self.owner)

        # if haven't reserved the actual __repr__ function - reserve it now
        try:
            tensor_type.old__repr__
        except:
            tensor_type.old__repr__ = tensor_type.__repr__
            

        tensor_type.__repr__ = __repr__

In [15]:
def pass_func_args(func):
    @wraps(func)
    def pass_args(*args, **kwargs):
        return partial(func, *args, **kwargs)
    return pass_args

def pass_method_args(method):
    @wraps(method)
    def pass_args(*args, **kwargs):
        return partialmethod(method, *args, **kwargs)
    return pass_args

In [16]:
def send_command(command):
    print(command['command'])
    print([type(arg) for arg in command['args']])
    print([type(pair) for pair in command['kwargs']])
    print('===========')
    print()
    return command['args'], command['kwargs']

def receive_commands(worker_ids):
    print('Placeholder print for receiving commands from workers in the following list')
    print(worker_ids)

In [17]:
def compile_command(partial_func):
    func = partial_func.func
    args = partial_func.args
    kwargs = partial_func.keywords
    command = {}
    command['command'] = func.__name__
    command['command_type'] = type(func)
    command['args'] = args
    command['kwargs'] = kwargs
    command['arg_types'] = [type(x) for x in args]
    command['kwarg_types'] = [type(kwargs[x]) for x in kwargs]
    return command

In [18]:
%%time
for x in range(100000):
    y = torch.FloatTensor([[2,2],[2,2]])
    z = torch.FloatTensor([[1,1],[1,1]])
    res = y.add(z)

CPU times: user 1.22 s, sys: 23.1 ms, total: 1.24 s
Wall time: 1.25 s


In [19]:
%%time

for attr in dir(torch):
    if attr == 'typename':
        continue
    if type(torch.__getattribute__(attr)) in [FunctionType, BuiltinFunctionType]:
        torch.__setattr__(attr, assign_workers_function(['A1','B1', 'B2'])(pass_func_args(torch.__getattribute__(attr))))

exclude = ['ndimension', 'nelement', 'size','numel', 'ser', 'de']
for tensor_type in tensor_types:
    print('Hooking {}'.format(tensor_type))
    print('==============')
    if tensor_type is not torch.FloatTensor:
        hook_tensor___init__(service_self, tensor_type)
        hook_tensor___repr__(service_self, tensor_type)
    for attr in dir(tensor_type):
        lit = getattr(tensor_type, attr)
        is_desc = inspect.ismethoddescriptor(lit)
        is_func = type(lit)==FunctionType
        is_mappingproxy = attr == '__dict__'
        try:
            is_service_func = 'TorchService' in lit.__qualname__
        except:
            is_service_func = False
        is_base = attr in dir(object)
        is_old = re.match('old*', attr) is not None
        if attr in exclude:
            print(attr,' skipped')
            continue
        if (is_desc or (is_func and not is_service_func)) and not is_base and not is_old:
            print(attr)
            setattr(tensor_type, 'old_{}'.format(attr), lit)
            setattr(tensor_type, attr, assign_workers_method(['A1','B1', 'B2'])(pass_method_args(lit)))
        else:
            print(attr, ' skipped')
    print()

Hooking <class 'torch.FloatTensor'>
__add__
__and__
__array__
__array_wrap__
__bool__
__class__  skipped
__deepcopy__
__delattr__  skipped
__delitem__
__dict__  skipped
__dir__  skipped
__div__
__doc__  skipped
__eq__  skipped
__float__
__format__  skipped
__ge__  skipped
__getattribute__  skipped
__getitem__
__getstate__
__gt__  skipped
__hash__  skipped
__iadd__
__iand__
__idiv__
__ilshift__
__imul__
__init__  skipped
__init_subclass__  skipped
__int__
__invert__
__ior__
__ipow__
__irshift__
__isub__
__iter__
__itruediv__
__ixor__
__le__  skipped
__len__
__long__
__lshift__
__lt__  skipped
__matmul__
__mod__
__module__  skipped
__mul__
__ne__  skipped
__neg__
__new__  skipped
__nonzero__
__or__
__pow__
__radd__
__rdiv__
__reduce__  skipped
__reduce_ex__  skipped
__repr__  skipped
__rmul__
__rpow__
__rshift__
__rsub__
__rtruediv__
__setattr__  skipped
__setitem__
__setstate__
__sizeof__  skipped
__str__  skipped
__sub__
__subclasshook__  skipped
__truediv__
__weakref__  skipped
__xor_

In [20]:
%%time
for x in range(100000):
    y = torch.FloatTensor([[2,2],[2,2]])
    z = torch.FloatTensor([[1,1],[1,1]])
    res = y.add(z)

CPU times: user 2.04 s, sys: 446 ms, total: 2.49 s
Wall time: 2.01 s


# FloatTensor

In [21]:
x = y.add(z)

In [22]:
print(x.is_pointer_to_remote)
print(x.id)

False
664525414


In [23]:
x


 3  3
 3  3
[torch.FloatTensor of size 2x2]

In [24]:
x.fill_(0)


 0  0
 0  0
[torch.FloatTensor of size 2x2]

In [25]:
print(x)


 0  0
 0  0
[torch.FloatTensor of size 2x2]



Case when tensor isn't local

In [26]:
x.is_pointer_to_remote = True
x.owner = 'other_guy'

In [27]:
x.normal_()

Placeholder print for sending command to worker A1
normal_
[<class 'torch.FloatTensor'>]
[]

Placeholder print for sending command to worker B1
normal_
[<class 'torch.FloatTensor'>]
[]

Placeholder print for sending command to worker B2
normal_
[<class 'torch.FloatTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ torch.FloatTensor - Location:other_guy ],), {})

In [28]:
x.uniform_()

Placeholder print for sending command to worker A1
uniform_
[<class 'torch.FloatTensor'>]
[]

Placeholder print for sending command to worker B1
uniform_
[<class 'torch.FloatTensor'>]
[]

Placeholder print for sending command to worker B2
uniform_
[<class 'torch.FloatTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ torch.FloatTensor - Location:other_guy ],), {})

In [29]:
torch.add(x, x)

Placeholder print for sending command to worker A1
add
[<class 'torch.FloatTensor'>, <class 'torch.FloatTensor'>]
[]

Placeholder print for sending command to worker B1
add
[<class 'torch.FloatTensor'>, <class 'torch.FloatTensor'>]
[]

Placeholder print for sending command to worker B2
add
[<class 'torch.FloatTensor'>, <class 'torch.FloatTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ torch.FloatTensor - Location:other_guy ],
  [ torch.FloatTensor - Location:other_guy ]),
 {})

In [30]:
try:
    torch.add(x,y) # This should throw an error, since their attributes say they're not on the same machine.
except NotImplementedError:
    print('booped!')

booped!


# DoubleTensor

In [31]:
y = torch.DoubleTensor([[2,2],[2,2]])
z = torch.DoubleTensor(([[1,1],[1,1]]))

In [32]:
x = y.add(z)

In [33]:
print(x.is_pointer_to_remote)
print(x.id)

False
5174648574


In [34]:
x


 3  3
 3  3
[torch.DoubleTensor of size 2x2]

In [35]:
x.fill_(0)


 0  0
 0  0
[torch.DoubleTensor of size 2x2]

In [36]:
print(x)


 0  0
 0  0
[torch.DoubleTensor of size 2x2]



Case when tensor isn't local

In [37]:
x.is_pointer_to_remote = True
x.owner = 'other_guy'

In [38]:
x.normal_()

Placeholder print for sending command to worker A1
normal_
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B1
normal_
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B2
normal_
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.DoubleTensor'> - Location:other_guy ],), {})

In [39]:
x.uniform_()

Placeholder print for sending command to worker A1
uniform_
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B1
uniform_
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B2
uniform_
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.DoubleTensor'> - Location:other_guy ],), {})

In [40]:
torch.add(x, x)

Placeholder print for sending command to worker A1
add
[<class 'torch.DoubleTensor'>, <class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B1
add
[<class 'torch.DoubleTensor'>, <class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B2
add
[<class 'torch.DoubleTensor'>, <class 'torch.DoubleTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.DoubleTensor'> - Location:other_guy ],
  [ <class 'torch.DoubleTensor'> - Location:other_guy ]),
 {})

# HalfTensor

In [41]:
y = torch.HalfTensor([[2,2],[2,2]])
z = torch.HalfTensor(([[1,1],[1,1]]))

In [42]:
y.float()


 2  2
 2  2
[torch.FloatTensor of size 2x2]

Case when tensor isn't local

In [43]:
y.is_pointer_to_remote = True
y.owner = 'other_guy'

In [44]:
y.float()

Placeholder print for sending command to worker A1
float
[<class 'torch.HalfTensor'>]
[]

Placeholder print for sending command to worker B1
float
[<class 'torch.HalfTensor'>]
[]

Placeholder print for sending command to worker B2
float
[<class 'torch.HalfTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.HalfTensor'> - Location:other_guy ],), {})

In [45]:
torch.add(y, y)

Placeholder print for sending command to worker A1
add
[<class 'torch.HalfTensor'>, <class 'torch.HalfTensor'>]
[]

Placeholder print for sending command to worker B1
add
[<class 'torch.HalfTensor'>, <class 'torch.HalfTensor'>]
[]

Placeholder print for sending command to worker B2
add
[<class 'torch.HalfTensor'>, <class 'torch.HalfTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.HalfTensor'> - Location:other_guy ],
  [ <class 'torch.HalfTensor'> - Location:other_guy ]),
 {})

In [46]:
a = torch.HalfTensor([[2,2],[2,2]])
b = torch.HalfTensor(([[1,1],[1,1]]))

In [47]:
try:
    torch.add(a,b)
except:
    print('HalfTensor is weird')

HalfTensor is weird


# LongTensor

In [48]:
y = torch.DoubleTensor([[1,2],[3,4]])
z = torch.DoubleTensor(([[1,1],[1,1]]))

In [49]:
x = y.add(z)

In [50]:
print(x.is_pointer_to_remote)
print(x.id)

False
7411756871


In [51]:
x


 2  3
 4  5
[torch.DoubleTensor of size 2x2]

In [52]:
x.t()


 2  4
 3  5
[torch.DoubleTensor of size 2x2]

In [53]:
x.fill_(0)


 0  0
 0  0
[torch.DoubleTensor of size 2x2]

In [54]:
print(x)


 0  0
 0  0
[torch.DoubleTensor of size 2x2]



In [55]:
x.t()


 0  0
 0  0
[torch.DoubleTensor of size 2x2]

Case when tensor isn't local

In [56]:
x.is_pointer_to_remote = True
x.owner = 'other_guy'

In [57]:
#x.normal_()

In [58]:
#x.uniform_()

In [59]:
x.t()

Placeholder print for sending command to worker A1
t
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B1
t
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B2
t
[<class 'torch.DoubleTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.DoubleTensor'> - Location:other_guy ],), {})

In [60]:
torch.add(x, x)

Placeholder print for sending command to worker A1
add
[<class 'torch.DoubleTensor'>, <class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B1
add
[<class 'torch.DoubleTensor'>, <class 'torch.DoubleTensor'>]
[]

Placeholder print for sending command to worker B2
add
[<class 'torch.DoubleTensor'>, <class 'torch.DoubleTensor'>]
[]

Placeholder print for receiving commands from workers in the following list
['A1', 'B1', 'B2']


(([ <class 'torch.DoubleTensor'> - Location:other_guy ],
  [ <class 'torch.DoubleTensor'> - Location:other_guy ]),
 {})