In [1]:
import torch as th
import torchvision
import syft as sy

hook = sy.TorchHook(th)

## Pytorch hooked local execution

In [2]:
# An instance of your model.
model = th.nn.Linear(10, 1)

# An example input you would normally provide to your model's forward() method.
example = th.rand(1, 10)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = th.jit.trace(model, example)

### Serialization

In [3]:
import io
buffer = io.BytesIO()
th.jit.save(traced_script_module, buffer)

In [4]:
buffer_output = buffer.getvalue()

In [5]:
buffer_output[:10]

b'PK\x03\x04\x00\x00\x08\x08\x00\x00'

### Deserialization

In [6]:
other_buffer = io.BytesIO(buffer_output)

In [7]:
serialized_model = th.jit.load(other_buffer)

In [8]:
serialized_model(th.ones(2, 10))

tensor([[-0.0952],
        [-0.0952]], grad_fn=<DifferentiableGraphBackward>)

### Training serialized model

In [9]:
optimizer = th.optim.SGD(params=serialized_model.parameters(),lr=0.01)

In [12]:
epochs = 10
for epoch in range(1, epochs + 1):
    loss_accum = 0
    optimizer.zero_grad()
    pred = serialized_model(th.ones(1, 10))
    loss = ((pred - th.ones(1, 1))**2).sum()
    loss.backward()
    optimizer.step()

    loss_accum += float(loss)

    print('Train Epoch: {}\tLoss: {:.6f}'.format(
        epoch, loss.item()))

print('Total loss', loss_accum)

Train Epoch: 1	Loss: 0.008335
Train Epoch: 2	Loss: 0.005071
Train Epoch: 3	Loss: 0.003085
Train Epoch: 4	Loss: 0.001877
Train Epoch: 5	Loss: 0.001142
Train Epoch: 6	Loss: 0.000695
Train Epoch: 7	Loss: 0.000423
Train Epoch: 8	Loss: 0.000257
Train Epoch: 9	Loss: 0.000156
Train Epoch: 10	Loss: 0.000095
Total loss 9.519056038698182e-05


### Remote execution with virtual workers: not working

In [11]:
# A virtual worker
alice = sy.VirtualWorker(hook, id="alice")

# An instance of your model.
model = th.nn.Linear(10, 1)

# Send model to Alice
model_ptr = model.send(alice)

# An example of pointer tensor input you could potentially provide to your model's forward() method.
example = th.rand(1, 10).send(alice)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
# TODO: hook jit trace method
traced_script_module = th.jit.trace(model_ptr, example)

TypeError: linear() argument after ** must be a mapping, not list