In [1]:
%load_ext autoreload
# Reload all imported modules before executing each cell, so edits to library
# code (e.g. syft) are picked up without restarting the kernel.
%autoreload 2

In [2]:
import sys

import syft as sy
from syft.workers import WebsocketClientWorker
from syft.frameworks.torch.federated import utils

import torch as th
import torchvision
import syft as sy

# Hook PyTorch (`th`) with PySyft; the resulting hook is handed to the
# websocket client worker below via `kwargs_websocket`.
hook = sy.TorchHook(th)

## Preparation: start the websocket server workers

Each worker is represented by two parts: a local handle (the websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

So first, we need to create the remote worker. To do so, run the following command in a terminal from the root of the PySyft repository (this cannot be done from within the notebook):

```bash
python examples/experimental/Plans\ Jit\ Experimental/run_websocket_server.py --port 8777 --id alice
```

In [3]:
# Connection settings shared by the websocket client worker(s) in this notebook.
kwargs_websocket = dict(host="localhost", hook=hook, verbose=False)

# Local handle for the remote "alice" server started in the terminal above;
# id and port must match the --id / --port passed to run_websocket_server.py.
alice = WebsocketClientWorker(**kwargs_websocket, id="alice", port=8777)

### Create a TorchScript trace of the model

In [4]:
# Seed torch's global RNG so the model's random initial weights (and therefore
# the remote training losses) are reproducible on Restart & Run All.
th.manual_seed(0)

# The model to train remotely: a single linear layer, 10 input features -> 1 output.
model = th.nn.Linear(10, 1)

# An example input like the ones you would normally pass to the model's
# forward(): a batch of 1 sample with 10 features.
example = th.rand(1, 10)

# th.jit.trace runs `model` once on `example` and records the executed ops,
# yielding a torch.jit.ScriptModule that can be serialized and run without
# the original Python class.
traced_script_module = th.jit.trace(model, example)

### Serialization

In [5]:
import io

# th.jit.save accepts any writable file-like object, so the traced module is
# serialized straight into an in-memory buffer rather than a file on disk.
buffer = io.BytesIO()
th.jit.save(traced_script_module, f=buffer)

In [6]:
# Copy the entire buffer contents (getvalue ignores the current stream
# position) into a `bytes` object; this is what gets sent to the remote worker.
buffer_output = buffer.getvalue()

In [7]:
# Peek at the first bytes: the leading b'PK\x03\x04' is the ZIP local-file
# header signature, i.e. th.jit.save wrote the module as a zip archive.
buffer_output[:10]

b'PK\x03\x04\x00\x00\x08\x08\x00\x00'

### Run training on the websocket worker

In [8]:
# Send the serialized TorchScript model to alice; per the server console log
# further down, the training loop then runs on alice's side.
# NOTE(review): the printed `() {'serialized_model': ...}` output looks like a
# leftover debug print of the call's args/kwargs that dumps the whole model
# bytes; confirm where it comes from and remove it.
alice.run_script_remote(serialized_model=buffer_output)

() {'serialized_model': b'PK\x03\x04\x00\x00\x08\x08\x00\x00% \xb5N\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0f\x00\x13\x00archive/versionFB\x0f\x00ZZZZZZZZZZZZZZZ1\nPK\x07\x08S\xfcQg\x02\x00\x00\x00\x02\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00% \xb5N\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x17\x009\x00archive/code/archive.pyFB5\x00ZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZZop_version_set = 0\ndef forward(self,\n    argument_1: Tensor) -> Tensor:\n  _0 = self.weight\n  _1 = torch.addmm(self.bias, argument_1, torch.t(_0), beta=1, alpha=1)\n  return _1\nPK\x07\x08\xa8\x12\xd5\xef\xaf\x00\x00\x00\xaf\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00% \xb5N\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x16\x00\r\x00archive/attributes.pklFB\t\x00ZZZZZZZZZ\x80\x02(t.PK\x07\x08a)\x16|\x05\x00\x00\x00\x05\x00\x00\x00PK\x03\x04\x00\x00\x08\x08\x00\x00% \xb5N\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x11\x00<\x00archive/tensors/0FB8\x00ZZZZZZZZZZZZZZZZZZZZZZZZZ

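### Console output

The training itself runs on the server worker; this is what the terminal running `run_websocket_server.py` prints: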

```bash
(syft) user@host:~/PySyft$ python examples/experimental/Plans\ Jit\ Experimental/run_websocket_server.py --port 8777 --id alice
Train Epoch: 1  Loss: 0.488759
Train Epoch: 2  Loss: 0.297361
Train Epoch: 3  Loss: 0.180914
Train Epoch: 4  Loss: 0.110068
Train Epoch: 5  Loss: 0.066966
Train Epoch: 6  Loss: 0.040742
Train Epoch: 7  Loss: 0.024787
Train Epoch: 8  Loss: 0.015081
Train Epoch: 9  Loss: 0.009175
Train Epoch: 10 Loss: 0.005582
Total loss 0.005582111421972513
```