In [1]:
import torch
from torch import optim, nn
from torch.autograd import Variable
import syft as sy
hook = sy.TorchHook(torch)
# import pixiedust


In [2]:
@property
def location(self):
    m = self.__getitem__(0)
    w = m.weight[0]
    return w.location

nn.Sequential.location = location

In [3]:
# A Toy Dataset
x = torch.tensor([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0],[1,1,0,0],[1,0,1,0],[0,1,1,0],[1,1,1,0],[0,0,0,1],[1,0,0,1],[0,1,0,1],[0,0,1,1],[1,1,0,1],[1,0,1,1],[0,1,1,1],[1,1,1,1.]])
x.requires_grad_()
target = torch.tensor([[0],[0],[0],[0],[0],[0],[0],[0],[1],[1],[1],[1],[1],[1],[1],[1.]])


#   Variables for performance metrics
epochs = 20
lr = 0.2
counter = 0

# Define 2 chained models
models = [
    nn.Sequential(
        nn.Linear(4, 3),
        nn.Tanh()
    ),
    nn.Sequential(
        nn.Linear(3, 1),
        nn.Sigmoid()
    )
]

# Create optimisers for each segment and link to their segment
optimizers = [
    optim.SGD(params=model.parameters(),lr=lr)
    for model in models
]

# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = alice, bob

# Send Model Segments and Data to starting locations
model_locations = [alice, bob]

for model, location in zip(models, model_locations):
    model.send(location)
    
x = x.send(models[0].location)
target = target.send(models[1].location)

In [4]:
# %%pixie_debugger

def train():
    # Training Logic
    for iter in range(epochs):

        # 1) erase previous gradients (if they exist)
        for opt in optimizers:
            opt.zero_grad()

        # 2) make a prediction
        a  = models[0](x)
        
        print('*')
        print(a.location)
        ones = Variable(torch.ones(a.shape))
        ones = ones.send(a.location)
        print(ones)
        a.add(ones)
        print('$')

        # 3) send the activation signal to the next model
        a_to_send = a.detach()
        
        remote_a = a_to_send.move(models[1].location)
        # re-enable autograd here
        remote_a.requires_grad_()

        pred =  models[1](remote_a)

        # 3) calculate how much we missed
        loss = ((pred - target)**2).sum()

        # 4) figure out which weights caused us to miss
        loss.backward()
        
        # 5) Backprop gradient to model behind
        grad_a = remote_a.grad.clone()
        grad_a.move(models[0].location)
        
        # 5) This is where it breaks, these are both in the same location
        # and of the same width and length. I think the issue could be that
        # it is looking at the pointer and not the tensor behind the pointer
        print(a)
        print(grad_a)
        print("a: ",a.shape, "location: ",a.location)
        print("grad_a: ",grad_a.shape,"location: ",grad_a.location)
        print('====')
        print(a.location._objects[a.id_at_location])
        print(grad_a.location._objects[grad_a.id_at_location])
        print('====')
        
        ones = Variable(torch.ones(a.shape))
        ones = ones.send(a.location)
        a.backward(ones)


        # 5) change the weights
        for opt in optimizers:
            opt.step()

        # 6) print our progress
        # Do not use .data
        print(loss.detach())
        
train()

*
<VirtualWorker id:alice #objects:4>
(Wrapper)>[PointerTensor | me:76880242729 -> alice:77691046710]
backward
54016016075 77691046710 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
$
(Wrapper)>[PointerTensor | me:74453929069 -> alice:54016016075]
(Wrapper)>[PointerTensor | me:21263053600 -> alice:83110777809]
a:  torch.Size([16, 3]) location:  <VirtualWorker id:alice #objects:6>
grad_a:  torch.Size([16, 3]) location:  <VirtualWorker id:alice #objects:6>
====
tensor([[ 0.2782,  0.2705,  0.2178],
        [ 0.5266,  0.4460,  0.2571],
        [ 0.1375, -0.0825, -0.1112],
        [ 0.1314,  0.5827,  0.4651],
        [ 0.4121,  0.1191, -0.0698],
        [ 0.4069,  0.7008,  0.4971],
        

54189897086 575159298 tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
$
(Wrapper)>[PointerTensor | me:90396777249 -> alice:54189897086]
(Wrapper)>[PointerTensor | me:66500106077 -> alice:18885709279]
a:  torch.Size([16, 3]) location:  <VirtualWorker id:alice #objects:18>
grad_a:  torch.Size([16, 3]) location:  <VirtualWorker id:alice #objects:18>
====
tensor([[-0.9909, -0.9888, -0.9915],
        [-0.9991, -0.9989, -0.9995],
        [-0.9996, -0.9997, -0.9998],
        [-0.9996, -0.9983, -0.9990],
        [-1.0000, -1.0000, -1.0000],
        [-1.0000, -0.9998, -0.9999],
        [-1.0000, -1.0000, -1.0000],
        [-1.0000, -1.0000, -1.0000],
        [-0.9998, -0.9998, -0.9995],
        

In [5]:
sorted(["__getitem__",
    "__setitem__",
    "_getitem_public",
    "view",
    "permute",
    "add_",
    "sub_",
    "new",
    "chunk",
    "reshape","backward"])

['__getitem__',
 '__setitem__',
 '_getitem_public',
 'add_',
 'backward',
 'chunk',
 'new',
 'permute',
 'reshape',
 'sub_',
 'view']