# Federated Learning Training Plan: Host Plan & Model

Here we load the Plan and Model params created earlier in the "Create Plan" notebook
and host them on PyGrid.

After that, it should be possible to run an FL worker using
SwiftSyft, KotlinSyft, syft.js, or the Python FL worker
and train the hosted model on the worker's local data.

In [1]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")

import websockets
import json
import requests
import torch

import syft as sy
from syft.grid.clients.static_fl_client import StaticFLClient
from syft.serde import protobuf
from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB
from syft_proto.execution.v1.state_pb2 import State as StatePB

sy.make_hook(globals())
# force protobuf serialization for tensors
hook.local_worker.framework = None

Setting up Sandbox...
Done!


In [2]:
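async def sendWsMessage(data):
    # Open a short-lived WebSocket connection, send one JSON message,
    # and return the decoded JSON reply.
    # gatewayWsUrl is defined in a later cell and must be set before calling this.
    async with websockets.connect('ws://' + gatewayWsUrl) as websocket:
        await websocket.send(json.dumps(data))
        message = await websocket.recv()
        return json.loads(message)

def deserializeFromBin(worker, filename, pb):
    # Read a serialized protobuf message from disk and turn it into a syft object.
    with open(filename, "rb") as f:
        raw = f.read()
    pb.ParseFromString(raw)
    return protobuf.serde._unbufferize(worker, pb)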

## Step 4a: Host in PyGrid

Here we load the "ops list" Plan.
PyGrid should automatically translate it to other formats (e.g. torchscript).

In [3]:
# Load files with protobuf created in "Create Plan" notebook.
training_plan = deserializeFromBin(hook.local_worker, "tp_full.pb", PlanPB())
model_params_state = deserializeFromBin(hook.local_worker, "model_params.pb", StatePB())

Follow the PyGrid README.md to build the `openmined/grid-gateway` image from the latest `dev` branch
and spin up PyGrid using `docker-compose up --build`.

In [4]:
# Default gateway address when running locally 
gatewayWsUrl = "127.0.0.1:5000"
grid = StaticFLClient(id="test", address=gatewayWsUrl, secure=False)
grid.connect()

Define the model name, version, and the client and server configs.

In [5]:
# The worker must use this same name/version
name = "mnist"
version = "1.0.0"

client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.005,
    "max_updates": 100  # custom syft.js option that limits number of training loops per worker
}

server_config = {
    "min_workers": 3,
    "max_workers": 3,
    "pool_selection": "random",
    "num_cycles": 5,
    "do_not_reuse_workers_until_cycle": 4,
    "cycle_length": 28800,
    "max_diffs": 3,  # number of diffs to collect before avg
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
}
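
The server config mixes several counts that have to agree with one another. As a minimal sketch, the hypothetical helper `check_server_config` below flags obviously inconsistent values before the config is sent to PyGrid. The rules it applies are assumptions made for illustration (in particular, that each worker contributes at most one diff per cycle), not PyGrid's own validation logic.

```python
def check_server_config(cfg):
    """Return a list of problems found in a server config dict (hypothetical helper)."""
    problems = []
    # Counts that only make sense as positive integers.
    for key in ("min_workers", "max_workers", "num_cycles", "cycle_length", "max_diffs"):
        if not isinstance(cfg.get(key), int) or cfg[key] <= 0:
            problems.append(f"{key} should be a positive integer")
    if problems:
        return problems
    if cfg["min_workers"] > cfg["max_workers"]:
        problems.append("min_workers is larger than max_workers")
    # Assumption of this sketch: each worker reports at most one diff per cycle.
    if cfg["max_diffs"] > cfg["max_workers"]:
        problems.append("max_diffs is larger than max_workers")
    return problems


# Same values as the server config defined in this notebook.
sample_config = {
    "min_workers": 3,
    "max_workers": 3,
    "num_cycles": 5,
    "cycle_length": 28800,
    "max_diffs": 3,
}
print(check_server_config(sample_config))
print(check_server_config({**sample_config, "min_workers": 5}))
```

An empty list means the sketch found nothing to complain about; it does not guarantee that PyGrid will accept the config.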

### Authentication (optional)
Let's additionally protect the model with simple authentication for workers.

PyGrid supports authentication via a JWT token (HMAC, RSA) or an opaque token
via a remote API.

We'll try JWT/RSA. Suppose we generate RSA keys:
```
openssl genrsa -out private.pem
openssl rsa -in private.pem -pubout -out public.pem
```

In [6]:
private_key = """
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAzQMcI09qonB9OZT20X3Z/oigSmybR2xfBQ1YJ1oSjQ3YgV+G
FUuhEsGDgqt0rok9BreT4toHqniFixddncTHg7EJzU79KZelk2m9I2sEsKUqEsEF
lMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYvGFphwwh4TNJXxkCg69/RsvPBIPi2
9vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNVQhUFABDyWN4h/67M1eArGA540vyd
kYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+LzmjEnjTJqUzr7kM9Rzq3BY01DNi
TVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3ZQIDAQABAoIBAD+xbKeHv+BxxGYE
Yt5ZFEYhGnOk5GU/RRIjwDSRplvOZmpjTBwHoCZcmsgZDqo/FwekNzzuch1DTnIV
M0+V2EqQ0TPJC5xFcfqnikybrhxXZAfpkhtU+gR5lDb5Q+8mkhPAYZdNioG6PGPS
oGz8BsuxINhgJEfxvbVpVNWTdun6hLOAMZaH3DHgi0uyTBg8ofARoZP5RIbHwW+D
p+5vd9x/x7tByu76nd2UbMp3yqomlB5jQktqyilexCIknEnfb3i/9jqFv8qVE5P6
e3jdYoJY+FoomWhqEvtfPpmUFTY5lx4EERCb1qhWG3a7sVBqTwO6jJJBsxy3RLIS
Ic0qZcECgYEA6GsBP11a2T4InZ7cixd5qwSeznOFCzfDVvVNI8KUw+n4DOPndpao
TUskWOpoV8MyiEGdQHgmTOgGaCXN7bC0ERembK0J64FI3TdKKg0v5nKa7xHb7Qcv
t9ccrDZVn4y/Yk5PCqjNWTR3/wDR88XouzIGaWkGlili5IJqdLEvPvUCgYEA4dA+
5MNEQmNFezyWs//FS6G3lTRWgjlWg2E6BXXvkEag6G5SBD31v3q9JIjs+sYdOmwj
kfkQrxEtbs173xgYWzcDG1FI796LTlJ/YzuoKZml8vEF3T8C4Bkbl6qj9DZljb2j
ehjTv5jA256sSUEqOa/mtNFUbFlBjgOZh3TCsLECgYAc701tdRLdXuK1tNRiIJ8O
Enou26Thm6SfC9T5sbzRkyxFdo4XbnQvgz5YL36kBnIhEoIgR5UFGBHMH4C+qbQR
OK+IchZ9ElBe8gYyrAedmgD96GxH2xAuxAIW0oDgZyZgd71RZ2iBRY322kRJJAdw
Xq77qo6eXTKpni7grjpijQKBgDHWRAs5DVeZkTwhoyEW0fRfPKUxZ+ZVwUI9sxCB
dt3guKKTtoY5JoOcEyJ9FdBC6TB7rV4KGiSJJf3OXAhgyP9YpNbimbZW52fhzTuZ
bwO/ZWC40RKDVZ8f63cNsiGz37XopKvNzu36SJYv7tY8C5WvvLsrd/ZxvIYbRUcf
/dgBAoGBAMdR5DXBcOWk3+KyEHXw2qwWcGXyzxtca5SRNLPR2uXvrBYXbhFB/PVj
h3rGBsiZbnIvSnSIE+8fFe6MshTl2Qxzw+F2WV3OhhZLLtBnN5qqeSe9PdHLHm49
XDce6NV2D1mQLBe8648OI5CScQENuRGxF2/h9igeR4oRRsM1gzJN
-----END RSA PRIVATE KEY-----
""".strip()

public_key = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzQMcI09qonB9OZT20X3Z
/oigSmybR2xfBQ1YJ1oSjQ3YgV+GFUuhEsGDgqt0rok9BreT4toHqniFixddncTH
g7EJzU79KZelk2m9I2sEsKUqEsEFlMpkk9qkPHhJB5AQoClOijee7UNOF4yu3HYv
GFphwwh4TNJXxkCg69/RsvPBIPi29vXFQzFE7cbN6jSxiCtVrpt/w06jJUsEYgNV
QhUFABDyWN4h/67M1eArGA540vydkYdSIEQdknKHjPW62n4dvqDWxtnK0HyChsB+
LzmjEnjTJqUzr7kM9Rzq3BY01DNiTVcB2G8t/jICL+TegMGU08ANMKiDfSMGtpz3
ZQIDAQAB
-----END PUBLIC KEY-----
""".strip()

If we put the __public key__ into the model's authentication config,
PyGrid will validate that the submitted JWT auth token was signed with the matching private key.

In [7]:
server_config["authentication"] = {
    "type": "jwt",
    "pub_key": public_key,
}

Shoot!

If everything's good, a success status is returned.
If the name/version already exists in PyGrid, change them above or clean up the PyGrid database by re-creating the docker containers (e.g. `docker-compose up --force-recreate`).


In [8]:
response = grid.host_federated_training(
    model=model_params_state,
    client_plans={'training_plan': training_plan},
    client_protocols={},
    server_averaging_plan=None,
    client_config=client_config,
    server_config=server_config
)

print("Host response:", response)

Host response: {'type': 'federated/host-training', 'data': {'status': 'success'}}


Let's double-check that the data is loaded by requesting a cycle.

First, create an authentication token.

In [9]:
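!pip install pyjwt[crypto]
import jwt
auth_token = jwt.encode({}, private_key, algorithm='RS256')
# PyJWT < 2.0 returns bytes, PyJWT >= 2.0 returns str
if isinstance(auth_token, bytes):
    auth_token = auth_token.decode('ascii')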

print(auth_token)

eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.e30.Cn_0cSjCw1QKtcYDx_mYN_q9jO2KkpcUoiVbILmKVB4LUCQvZ7YeuyQ51r9h3562KQoSas_ehbjpz2dw1Dk24hQEoN6ObGxfJDOlemF5flvLO_sqAHJDGGE24JRE4lIAXRK6aGyy4f4kmlICL6wG8sGSpSrkZlrFLOVRJckTptgaiOTIm5Udfmi45NljPBQKVpqXFSmmb3dRy_e8g3l5eBVFLgrBhKPQ1VbNfRK712KlQWs7jJ31fGpW2NxMloO1qcd6rux48quivzQBCvyK8PV5Sqrfw_OMOoNLcSvzePDcZXa2nPHSu3qQIikUdZIeCnkJX-w0t8uEFG3DfH1fVA
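
A JWT is three base64url-encoded segments joined by dots: header, payload, and signature. As a small sketch using only the standard library, the hypothetical helper `decode_jwt_segment` below decodes the first two segments of the token printed above. It only inspects the token and does not verify the signature.

```python
import base64
import json


def decode_jwt_segment(segment):
    """Decode one base64url-encoded JWT segment into a Python object."""
    # JWTs strip the base64 padding, so restore it before decoding.
    padding = "=" * (-len(segment) % 4)
    return json.loads(base64.urlsafe_b64decode(segment + padding))


# Header and payload are copied from the token printed above; the signature
# is replaced by a placeholder because it is not inspected here.
token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.e30.<signature>"
header_b64, payload_b64, _signature = token.split(".")

header = decode_jwt_segment(header_b64)
payload = decode_jwt_segment(payload_b64)
print(header)   # the signing algorithm and token type
print(payload)  # the claims; empty here because jwt.encode was given {}
```

The header names RS256 as the signing algorithm, and the payload is empty because no claims were passed to `jwt.encode`.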


Make the authentication request:

In [10]:
auth_request = {
    "type": "federated/authenticate",
    "data": {
        "model_name": name,
        "model_version": version,
        "auth_token": auth_token,
    }
}
auth_response = await sendWsMessage(auth_request)
print('Auth response: ', json.dumps(auth_response, indent=2))

Auth response:  {
  "type": "federated/authenticate",
  "data": {
    "status": "success",
    "worker_id": "e0ce650c-1e28-469d-870f-1f28bf585753"
  }
}
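
The `await sendWsMessage(...)` call above relies on the notebook's support for top-level `await`. In a plain Python script, a coroutine has to be driven by an event loop, for example with `asyncio.run`. The sketch below illustrates the pattern with a hypothetical stand-in, `fake_send_ws_message`, that returns a canned reply instead of opening a WebSocket connection.

```python
import asyncio
import json


async def fake_send_ws_message(data):
    """Hypothetical stand-in for sendWsMessage; no network access is performed."""
    await asyncio.sleep(0)  # yield to the event loop, as a real socket call would
    return {"type": data["type"], "data": {"status": "success"}}


async def main():
    request = {"type": "federated/authenticate", "data": {}}
    return await fake_send_ws_message(request)


# asyncio.run creates an event loop, runs the coroutine, and closes the loop.
response = asyncio.run(main())
print(json.dumps(response, indent=2))
```

In a script, replacing the stand-in with the real `sendWsMessage` should be enough, assuming PyGrid is reachable at the configured address.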


In [11]:
cycle_request = {
    "type": "federated/cycle-request",
    "data": {
        "worker_id": auth_response['data']['worker_id'],
        "model": name,
        "version": version,
        "ping": 1,
        "download": 10000,
        "upload": 10000,
    }
}
cycle_response = await sendWsMessage(cycle_request)
print('Cycle response:', json.dumps(cycle_response, indent=2))

worker_id = auth_response['data']['worker_id']
request_key = cycle_response['data']['request_key']
model_id = cycle_response['data']['model_id'] 
training_plan_id = cycle_response['data']['plans']['training_plan']

Cycle response: {
  "type": "federated/cycle-request",
  "data": {
    "status": "accepted",
    "request_key": "647c7411419f21a480f3209a5f396eb22d622cd3e9f72341de3c10e746a310c5",
    "version": "1.0.0",
    "model": "mnist",
    "plans": {
      "training_plan": 2
    },
    "protocols": {},
    "client_config": {
      "name": "mnist",
      "version": "1.0.0",
      "batch_size": 64,
      "lr": 0.005,
      "max_updates": 100
    },
    "model_id": 1
  }
}
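
The cell above reads `request_key`, `model_id`, and the plan IDs straight out of the response, which fails with a `KeyError` if the cycle request was not accepted. As a hedged sketch, the hypothetical helper `extract_cycle_info` checks the status first. The shape of a non-accepted response used in the example is an assumption, not taken from PyGrid.

```python
def extract_cycle_info(cycle_response):
    """Pull the fields needed for downloads out of a cycle response (hypothetical helper)."""
    data = cycle_response["data"]
    if data.get("status") != "accepted":
        raise RuntimeError(f"Cycle request was not accepted: {data}")
    return {
        "request_key": data["request_key"],
        "model_id": data["model_id"],
        "plan_ids": dict(data["plans"]),
    }


# Trimmed copy of the cycle response printed above.
accepted = {
    "type": "federated/cycle-request",
    "data": {
        "status": "accepted",
        "request_key": "647c7411419f21a480f3209a5f396eb22d622cd3e9f72341de3c10e746a310c5",
        "plans": {"training_plan": 2},
        "model_id": 1,
    },
}
info = extract_cycle_info(accepted)
print(info["model_id"], info["plan_ids"])

# Assumed shape of a non-accepted response, used only to exercise the error path.
try:
    extract_cycle_info({"type": "federated/cycle-request", "data": {"status": "rejected"}})
except RuntimeError as err:
    print("Error:", err)
```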


Let's download the model and the plan (in each available format) and check that they are actually usable.


In [12]:
# Model
req = requests.get(f"http://{gatewayWsUrl}/federated/get-model?worker_id={worker_id}&request_key={request_key}&model_id={model_id}")
model_data = req.content
pb = StatePB()
pb.ParseFromString(model_data)
model_params_downloaded = protobuf.serde._unbufferize(hook.local_worker, pb)
print("Params shapes:", [p.shape for p in model_params_downloaded.tensors()])

Params shapes: [torch.Size([392, 784]), torch.Size([392]), torch.Size([10, 392]), torch.Size([10])]
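
The download URLs in these cells are assembled with f-strings, which works as long as no value contains characters that need escaping. A slightly more defensive sketch builds the query string with `urllib.parse.urlencode`. The helper name `build_download_url` is hypothetical, and the endpoint paths are taken from the cells in this notebook.

```python
from urllib.parse import urlencode


def build_download_url(address, endpoint, **params):
    """Build a PyGrid download URL with a properly escaped query string."""
    return f"http://{address}/federated/{endpoint}?{urlencode(params)}"


url = build_download_url(
    "127.0.0.1:5000",
    "get-model",
    worker_id="e0ce650c-1e28-469d-870f-1f28bf585753",
    request_key="abc",  # placeholder value, not a real request key
    model_id=1,
)
print(url)
```

The same helper could cover the plan downloads by passing `"get-plan"` as the endpoint, together with `plan_id` and `receive_operations_as` as keyword arguments.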


In [13]:
# Plan "list of ops"
req = requests.get(f"http://{gatewayWsUrl}/federated/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=list")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_ops = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_ops.code)
print(plan_ops.torchscript)

def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    var_0 = arg_5.t()
    var_1 = arg_1.matmul(var_0)
    var_2 = arg_6.add(var_1)
    var_3 = var_2.relu()
    var_4 = arg_7.t()
    var_5 = var_3.matmul(var_4)
    var_6 = arg_8.add(var_5)
    var_7 = var_6.max()
    var_8 = var_6.sub(var_7)
    var_9 = var_8.exp()
    var_10 = var_9.sum(dim=1, keepdim=True)
    var_11 = var_10.log()
    var_12 = var_8.sub(var_11)
    var_13 = arg_2.mul(var_12)
    var_14 = var_13.sum()
    var_15 = var_14.neg()
    out_1 = var_15.div(arg_3)
    var_16 = out_1.mul(0)
    var_17 = var_16.add(1)
    var_18 = var_17.div(arg_3)
    var_19 = var_18.mul(-1)
    var_20 = var_19.reshape([-1, 1])
    var_21 = var_13.mul(0)
    var_22 = var_21.add(1)
    var_23 = var_22.mul(var_20)
    var_24 = var_23.mul(arg_2)
    var_25 = var_24.add(0)
    var_26 = var_24.mul(-1)
    var_27 = var_26.sum(dim=[1], keepdim=True)
    var_28 = var_25.add(0)
    var_29 = var_28.add(0)
    var_30 = var_28.a

In [14]:
# Plan "torchscript"
req = requests.get(f"http://{gatewayWsUrl}/federated/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=torchscript")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_ts = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_ts.code)
print(plan_ts.torchscript.code)

def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    return out_1, out_2, out_3, out_4, out_5, out_6
def forward(self,
    argument_1: Tensor,
    argument_2: Tensor,
    argument_3: Tensor,
    argument_4: Tensor,
    argument_5: List[Tensor]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  _0, _1, _2, _3, = argument_5
  _4 = torch.add(_1, torch.matmul(argument_1, torch.t(_0)), alpha=1)
  _5 = torch.relu(_4)
  _6 = torch.t(_2)
  _7 = torch.add(_3, torch.matmul(_5, _6), alpha=1)
  _8 = torch.sub(_7, torch.max(_7), alpha=1)
  _9 = torch.exp(_8)
  _10 = torch.sum(_9, [1], True, dtype=None)
  _11 = torch.sub(_8, torch.log(_10), alpha=1)
  _12 = torch.mul(argument_2, _11)
  _13 = torch.div(torch.neg(torch.sum(_12, dtype=None)), argument_3)
  _14 = torch.add(torch.mul(_13, CONSTANTS.c0), CONSTANTS.c1, alpha=1)
  _15 = torch.mul(torch.div(_14, argument_3), CONSTANTS.c2)
  _16 = torch.reshape(_15, [-1, 1])
  _17 = torch.add(torch.mul(_12, CONSTANTS.c0), C

In [15]:
# Plan "tfjs"
req = requests.get(f"http://{gatewayWsUrl}/federated/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=tfjs")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_tfjs = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_tfjs.code)


def training_plan(arg_1, arg_2, arg_3, arg_4, arg_5, arg_6, arg_7, arg_8):
    var_0 = tf.transpose(arg_5)
    var_1 = tf.matMul(arg_1, var_0)
    var_2 = tf.add(arg_6, var_1)
    var_3 = tf.relu(var_2)
    var_4 = tf.transpose(arg_7)
    var_5 = tf.matMul(var_3, var_4)
    var_6 = tf.add(arg_8, var_5)
    var_7 = tf.max(var_6)
    var_8 = tf.sub(var_6, var_7)
    var_9 = tf.exp(var_8)
    var_10 = tf.sum(var_9, 1, keepdim=True)
    var_11 = tf.log(var_10)
    var_12 = tf.sub(var_8, var_11)
    var_13 = tf.mul(arg_2, var_12)
    var_14 = tf.sum(var_13)
    var_15 = tf.neg(var_14)
    out_1 = tf.div(var_15, arg_3)
    var_16 = tf.mul(out_1, 0)
    var_17 = tf.add(var_16, 1)
    var_18 = tf.div(var_17, arg_3)
    var_19 = tf.mul(var_18, -1)
    var_20 = reshape(var_19, [-1, 1])
    var_21 = tf.mul(var_13, 0)
    var_22 = tf.add(var_21, 1)
    var_23 = tf.mul(var_22, var_20)
    var_24 = tf.mul(var_23, arg_2)
    var_25 = tf.add(var_24, 0)
    var_26 = tf.mul(var_24, -1)
    var_27 = tf.s

## Step 5a: Train

To train the hosted model, use one of the existing FL workers:
 * Python FL Client: see the "[Execute Plan with Python FL Client](Execute%20Plan%20with%20Python%20FL%20Client.ipynb)" notebook, which
has an example of using the Python FL worker.
 * [SwiftSyft](https://github.com/OpenMined/SwiftSyft)
 * [KotlinSyft](https://github.com/OpenMined/KotlinSyft)
 * [syft.js](https://github.com/OpenMined/syft.js)


