# Federated Learning Training Plan: Host Plan & Model

Here we load the Plan and model parameters created earlier in the "Create Plan" notebook
and host them on PyGrid.

After that, it should be possible to run an FL worker using
SwiftSyft, KotlinSyft, syft.js, or the Python FL worker
and train the hosted model on the worker's local data.

In [13]:
%load_ext autoreload
%autoreload 2
import warnings
warnings.filterwarnings("ignore")

import websockets
import json
import requests
import torch

import syft as sy
from syft.grid.grid_client import GridClient
from syft.serde import protobuf
from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB
from syft_proto.execution.v1.state_pb2 import State as StatePB

sy.make_hook(globals())
# force protobuf serialization for tensors
hook.local_worker.framework = None



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Setting up Sandbox...
Done!


In [14]:
async def sendWsMessage(data):
    async with websockets.connect('ws://' + gatewayWsUrl) as websocket:
        await websocket.send(json.dumps(data))
        message = await websocket.recv()
        return json.loads(message)

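def deserializeFromBin(worker, filename, pb):
    # Read the serialized protobuf bytes, parse them into `pb`,
    # then convert the protobuf message into a syft object.
    with open(filename, "rb") as f:
        data = f.read()  # avoid shadowing the built-in `bin`
    pb.ParseFromString(data)
    return protobuf.serde._unbufferize(worker, pb)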

## Step 4a: Host in PyGrid

Here we load the "ops list" Plan.
PyGrid should translate it to other types (e.g. torchscript) automatically.

In [15]:
# Load files with protobuf created in "Create Plan" notebook.
training_plan = deserializeFromBin(hook.local_worker, "tp_full.pb", PlanPB())
model_params_state = deserializeFromBin(hook.local_worker, "model_params.pb", StatePB())

Follow the PyGrid README.md to build the `openmined/grid-gateway` image from the latest `dev` branch
and spin up PyGrid using `docker-compose up --build`.

In [16]:
# Default gateway address when running locally 
gatewayWsUrl = "127.0.0.1:5000"
grid = GridClient(id="test", address=gatewayWsUrl, secure=False)
grid.connect()

Define the name, version, and configs.

In [17]:
# Workers must use the same name/version to reach this model
name = "mnist"
version = "1.0.0"
client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.005,
    "max_updates": 100,  # custom syft.js option that limits the number of training loops per worker
}

server_config = {
    "min_workers": 3,  # for now, this acts as the minimum number of worker diffs that triggers the cycle-end event
    "max_workers": 3,
    "pool_selection": "random",
    "num_cycles": 5,
    "do_not_reuse_workers_until_cycle": 4,
    "cycle_length": 28800,
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
}
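
Before hosting, it can help to sanity-check the two config dicts locally. The sketch below is only an illustration: `check_fl_configs` is a hypothetical helper, and the rules it applies (non-empty name/version, positive `batch_size` and `lr`, `min_workers` not above `max_workers`, positive `num_cycles`) are assumptions made for this example, not validation that PyGrid is known to perform.

```python
def check_fl_configs(client_config, server_config):
    """Return a list of problems found; an empty list means none were found.

    Hypothetical helper: the rules below are illustrative assumptions,
    not PyGrid's own validation logic.
    """
    problems = []

    # Workers look the model up by name and version, so both should be set.
    for key in ("name", "version"):
        if not client_config.get(key):
            problems.append(f"client_config is missing '{key}'")

    # Training hyperparameters should be positive when present.
    for key in ("batch_size", "lr"):
        value = client_config.get(key)
        if value is not None and value <= 0:
            problems.append(f"client_config['{key}'] should be positive, got {value}")

    # The worker bounds should be consistent with each other.
    min_workers = server_config.get("min_workers")
    max_workers = server_config.get("max_workers")
    if min_workers is not None and max_workers is not None and min_workers > max_workers:
        problems.append("server_config['min_workers'] exceeds 'max_workers'")

    num_cycles = server_config.get("num_cycles")
    if num_cycles is not None and num_cycles <= 0:
        problems.append(f"server_config['num_cycles'] should be positive, got {num_cycles}")

    return problems


# Reduced copies of the dicts defined above, so the sketch runs on its own.
example_client = {"name": "mnist", "version": "1.0.0", "batch_size": 64, "lr": 0.005}
example_server = {"min_workers": 3, "max_workers": 3, "num_cycles": 5}
print(check_fl_configs(example_client, example_server))
```

In the notebook itself, the call would simply be `check_fl_configs(client_config, server_config)`.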

Shoot!

If everything's good, success is returned.
If the name/version already exists in PyGrid, change them above or clean up the PyGrid database by re-creating the docker containers (e.g. `docker-compose up --force-recreate`).
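
When the name/version pair is already taken, one lightweight way to "change them above" is to bump the version string. This is a minimal sketch that assumes the version follows the `major.minor.patch` pattern used in this notebook (`"1.0.0"`); `bump_patch` is a hypothetical helper, not part of syft or PyGrid.

```python
def bump_patch(version):
    """Increment the last component of a 'major.minor.patch' version string.

    Hypothetical helper; assumes exactly three dot-separated integers.
    """
    major, minor, patch = (int(part) for part in version.split("."))
    return f"{major}.{minor}.{patch + 1}"


print(bump_patch("1.0.0"))
```

Workers must then request the model with the new version string, and `client_config["version"]` should be updated to match before hosting again.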


In [18]:
response = grid.host_federated_training(
    model=model_params_state,
    client_plans={'training_plan': training_plan},
    client_protocols={},
    server_averaging_plan=None,
    client_config=client_config,
    server_config=server_config
)

print("Host response:", response)

GridError: ('FL Process already exists.Traceback (most recent call last):\n  File "/app/grid/app/main/events/fl_events.py", line 56, in host_federated_training\n    server_config=server_config,\n  File "/app/grid/app/main/controller/fl_controller.py", line 55, in create_process\n    server_averaging_plan,\n  File "/app/grid/app/main/processes/process_manager.py", line 44, in create\n    raise FLProcessConflict\ngrid.app.main.exceptions.FLProcessConflict: FL Process already exists.\n', None)

Let's double-check that the data is loaded by requesting a cycle.

(The request is made directly here; these calls will become methods on the grid client in the future.)

In [None]:
auth_request = {
    "type": "federated/authenticate",
    "data": {}
}
auth_response = await sendWsMessage(auth_request)
print('Auth response: ', json.dumps(auth_response, indent=2))

cycle_request = {
    "type": "federated/cycle-request",
    "data": {
        "worker_id": auth_response['data']['worker_id'],
        "model": name,
        "version": version,
        "ping": 1,
        "download": 10000,
        "upload": 10000,
    }
}
cycle_response = await sendWsMessage(cycle_request)
print('Cycle response:', json.dumps(cycle_response, indent=2))

worker_id = auth_response['data']['worker_id']
request_key = cycle_response['data']['request_key']
model_id = cycle_response['data']['model_id'] 
training_plan_id = cycle_response['data']['plans']['training_plan']
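
The cell above indexes straight into the responses, so a rejected or malformed cycle response would surface as a bare `KeyError`. The following sketch wraps the same lookups with a clearer error. `extract_cycle_info` is a hypothetical helper; it only checks for the keys this notebook reads, and the layout of a rejected response is not shown here, so nothing about it is assumed.

```python
def extract_cycle_info(auth_response, cycle_response):
    """Collect the ids needed to download the model and the training plan.

    Hypothetical helper: it only looks for the keys the notebook itself reads.
    """
    auth_data = auth_response.get("data", {})
    cycle_data = cycle_response.get("data", {})

    if "worker_id" not in auth_data:
        raise ValueError("auth response has no 'worker_id'")

    missing = [key for key in ("request_key", "model_id", "plans") if key not in cycle_data]
    if missing:
        raise ValueError(f"cycle response is missing {missing}; full data: {cycle_data}")

    if "training_plan" not in cycle_data["plans"]:
        raise ValueError("cycle response lists no 'training_plan'")

    return {
        "worker_id": auth_data["worker_id"],
        "request_key": cycle_data["request_key"],
        "model_id": cycle_data["model_id"],
        "training_plan_id": cycle_data["plans"]["training_plan"],
    }


# Mock responses with made-up values, shaped like the fields read above.
mock_auth = {"data": {"worker_id": "worker-1"}}
mock_cycle = {"data": {"request_key": "key-1", "model_id": 1, "plans": {"training_plan": 2}}}
print(extract_cycle_info(mock_auth, mock_cycle))
```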

Let's download the model and the plan (both versions) and check that they are actually usable.

In [None]:
# Model
req = requests.get(f"http://{gatewayWsUrl}/federated/get-model?worker_id={worker_id}&request_key={request_key}&model_id={model_id}")
model_data = req.content
pb = StatePB()
pb.ParseFromString(model_data)
model_params_downloaded = protobuf.serde._unbufferize(hook.local_worker, pb)
print("Params shapes:", [p.shape for p in model_params_downloaded.tensors()])

In [None]:
# Plan "list of ops"
req = requests.get(f"http://{gatewayWsUrl}/federated/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=list")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_ops = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_ops.code)
print(plan_ops.torchscript)

In [None]:
# Plan "torchscript"
req = requests.get(f"http://{gatewayWsUrl}/federated/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=torchscript")
pb = PlanPB()
pb.ParseFromString(req.content)
plan_ts = protobuf.serde._unbufferize(hook.local_worker, pb)
print(plan_ts.code)
print(plan_ts.torchscript.code)
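
The three download cells above build their query strings by hand. As a small sketch, the same URLs can be assembled with the standard library's `urlencode`, which also escapes any special characters in the values. `federated_url` is a hypothetical helper, and the ids in the example are made up.

```python
from urllib.parse import urlencode


def federated_url(gateway, endpoint, **params):
    """Build a download URL such as http://<gateway>/federated/get-model?...

    Hypothetical helper; `gateway` is a "host:port" string like gatewayWsUrl.
    """
    return f"http://{gateway}/federated/{endpoint}?{urlencode(params)}"


# Example with made-up ids; in the notebook these come from the cycle response.
print(federated_url(
    "127.0.0.1:5000",
    "get-plan",
    worker_id="worker-1",
    request_key="key-1",
    plan_id=2,
    receive_operations_as="torchscript",
))
```

Alternatively, `requests.get` accepts a `params` dict and performs the same encoding itself.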

## Step 5a: Train

To train the hosted model, use one of the existing FL workers:
 * Python FL Client: see the "[Execute Plan with Python FL Client](Execute%20Plan%20with%20Python%20FL%20Client.ipynb)" notebook, which has an example of using the Python FL worker.
 * [SwiftSyft](https://github.com/OpenMined/SwiftSyft)
 * [KotlinSyft](https://github.com/OpenMined/KotlinSyft)
 * [syft.js](https://github.com/OpenMined/syft.js)


