In [5]:
# stdlib
import base64
import json

# third party
import jwt
import requests
import torch as th
from websocket import create_connection

# syft absolute
import syft as sy
from syft import deserialize
from syft import serialize
from syft.core.plan.plan_builder import ROOT_CLIENT
from syft.core.plan.plan_builder import make_plan
from syft.federated.model_centric_fl_client import ModelCentricFLClient
from syft.lib.python.int import Int
from syft.lib.python.list import List
from syft.proto.core.plan.plan_pb2 import Plan as PlanPB
from syft.proto.lib.python.list_pb2 import List as ListPB

In [6]:
th.random.manual_seed(42)

<torch._C.Generator at 0x114c07570>

# Federated Learning - Model Centric MNIST Example: Create Plan

## Step 1: Define the model

This model will train on MNIST data, it's very simple yet can demonstrate learning process.
There're 2 linear layers: 

* Linear 784x100
* ReLU
* Linear 100x10 

In [7]:
class MLP(sy.Module):
    def __init__(self, torch_ref):
        super().__init__(torch_ref=torch_ref)
        self.l1 = self.torch_ref.nn.Linear(784, 100)
        self.a1 = self.torch_ref.nn.ReLU()
        self.l2 = self.torch_ref.nn.Linear(100, 10)

    def forward(self, x):
        x_reshaped = x.view(-1, 28 * 28)
        l1_out = self.a1(self.l1(x_reshaped))
        l2_out = self.l2(l1_out)
        return l2_out

## Step 2: Define Training Plan

In [8]:
def set_params(model, params):
    for p, p_new in zip(model.parameters(), params):
        p.data = p_new.data


def cross_entropy_loss(logits, targets, batch_size):
    norm_logits = logits - logits.max()
    log_probs = norm_logits - norm_logits.exp().sum(dim=1, keepdim=True).log()
    return -(targets * log_probs).sum() / batch_size


def sgd_step(model, lr=0.1):
    with ROOT_CLIENT.torch.no_grad():
        for p in model.parameters():
            p.data = p.data - lr * p.grad
            p.grad = th.zeros_like(p.grad.get())

In [9]:
local_model = MLP(th)

In [10]:
@make_plan
def train(
    xs=th.rand([64 * 3, 1, 28, 28]),
    ys=th.randint(0, 10, [64 * 3, 10]),
    params=List(local_model.parameters()),
):

    model = local_model.send(ROOT_CLIENT)
    set_params(model, params)
    for i in range(1):
        indices = th.tensor(range(64 * i, 64 * (i + 1)))
        x, y = xs.index_select(0, indices), ys.index_select(0, indices)
        out = model(x)
        loss = cross_entropy_loss(out, y, 64)
        loss.backward()
        sgd_step(model)

    return model.parameters()

## Step 3: Define Averaging Plan

Averaging Plan is executed by PyGrid at the end of the cycle,
to average _diffs_ submitted by workers and update the model
and create new checkpoint for the next cycle.

_Diff_ is the difference between client-trained
model params and original model params,
so it has same number of tensors and tensor's shapes
as the model parameters.

We define Plan that processes one diff at a time.
Such Plans require `iterative_plan` flag set to `True`
in `server_config` when hosting FL model to PyGrid.

Plan below will calculate simple mean of each parameter.

In [11]:
@make_plan
def avg_plan(
    avg=List(local_model.parameters()), item=List(local_model.parameters()), num=Int(0)
):
    new_avg = []
    for i, param in enumerate(avg):
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
    return new_avg

# Config & keys

In [12]:
name = "mnist"
version = "1.0"

client_config = {
    "name": name,
    "version": version,
    "batch_size": 64,
    "lr": 0.1,
    "max_updates": 1,  # custom syft.js option that limits number of training loops per worker
}

server_config = {
    "min_workers": 2,
    "max_workers": 2,
    "pool_selection": "random",
    "do_not_reuse_workers_until_cycle": 6,
    "cycle_length": 28800,  # max cycle length in seconds
    "num_cycles": 30,  # max number of cycles
    "max_diffs": 1,  # number of diffs to collect before avg
    "minimum_upload_speed": 0,
    "minimum_download_speed": 0,
    "iterative_plan": True,  # tells PyGrid that avg plan is executed per diff
}

In [13]:
def read_file(fname):
    with open(fname, "r") as f:
        return f.read()

In [14]:
private_key = read_file("example_rsa").strip()
public_key = read_file("example_rsa.pub").strip()

server_config["authentication"] = {
    "type": "jwt",
    "pub_key": public_key,
}

## Step 4: Host in PyGrid

Let's now host everything in PyGrid so that it can be accessed by worker libraries (syft.js, KotlinSyft, SwiftSyft, or even PySyft itself).

# Auth

In [15]:
grid_address = "localhost:7000"

In [16]:
grid = ModelCentricFLClient(address=grid_address, secure=False)
grid.connect()

# Host

If the process already exists, might you need to clear the db. To do that, set path below correctly and run:

In [17]:
# !rm PyGrid/apps/domain/src/nodedatabase.db

In [18]:
response = grid.host_federated_training(
    model=local_model,
    client_plans={"training_plan": train},
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config,
)

In [19]:
response

{'type': 'model-centric/host-training', 'data': {'status': 'success'}}

# Authenticate for cycle

In [20]:
# Helper function to make WS requests
def sendWsMessage(data):
    ws = create_connection("ws://" + grid_address)
    ws.send(json.dumps(data))
    message = ws.recv()
    return json.loads(message)

In [21]:
auth_token = jwt.encode({}, private_key, algorithm="RS256").decode("ascii")

In [22]:
auth_request = {
    "type": "model-centric/authenticate",
    "data": {
        "model_name": name,
        "model_version": version,
        "auth_token": auth_token,
    },
}
auth_response = sendWsMessage(auth_request)
auth_response

{'type': 'model-centric/authenticate',
 'data': {'status': 'success',
  'worker_id': '99663011-6eea-4135-a998-a3e4ac5a4553',
  'requires_speed_test': True}}

# Do cycle request

In [23]:
cycle_request = {
    "type": "model-centric/cycle-request",
    "data": {
        "worker_id": auth_response["data"]["worker_id"],
        "model": name,
        "version": version,
        "ping": 1,
        "download": 10000,
        "upload": 10000,
    },
}
cycle_response = sendWsMessage(cycle_request)
print("Cycle response:", json.dumps(cycle_response, indent=2).replace("\\n", "\n"))

Cycle response: {
  "type": "model-centric/cycle-request",
  "data": {
    "status": "accepted",
    "request_key": "c5d6a1183aec817bc29de87d745e6760ad2748377c331c8e5d3adbddbdebf600",
    "version": "1.0",
    "model": "mnist",
    "plans": {
      "training_plan": 2
    },
    "protocols": {},
    "client_config": {
      "name": "mnist",
      "version": "1.0",
      "batch_size": 64,
      "lr": 0.1,
      "max_updates": 1
    },
    "model_id": 1
  }
}


# Download model

In [24]:
worker_id = auth_response["data"]["worker_id"]
request_key = cycle_response["data"]["request_key"]
model_id = cycle_response["data"]["model_id"]
training_plan_id = cycle_response["data"]["plans"]["training_plan"]

In [25]:
def get_model(grid_address, worker_id, request_key, model_id):
    req = requests.get(
        f"http://{grid_address}/model-centric/get-model?worker_id={worker_id}&request_key={request_key}&model_id={model_id}"
    )
    model_data = req.content
    pb = ListPB()
    pb.ParseFromString(req.content)
    return deserialize(pb)

In [26]:
# Model
model_params_downloaded = get_model(grid_address, worker_id, request_key, model_id)
print("Params shapes:", [p.shape for p in model_params_downloaded])

Params shapes: [torch.Size([100, 784]), torch.Size([100]), torch.Size([10, 100]), torch.Size([10])]


In [27]:
model_params_downloaded[0]

Parameter containing:
tensor([[ 0.0273,  0.0296, -0.0084,  ..., -0.0142,  0.0093,  0.0135],
        [-0.0188, -0.0354,  0.0187,  ..., -0.0106, -0.0001,  0.0115],
        [-0.0008,  0.0017,  0.0045,  ..., -0.0127, -0.0188,  0.0059],
        ...,
        [-0.0255,  0.0213,  0.0111,  ...,  0.0060, -0.0308,  0.0306],
        [-0.0323, -0.0083, -0.0017,  ...,  0.0317, -0.0348,  0.0304],
        [-0.0058,  0.0239, -0.0202,  ..., -0.0106,  0.0301, -0.0222]],
       requires_grad=True)

# Download & Execute Plan

In [28]:
req = requests.get(
    f"http://{grid_address}/model-centric/get-plan?worker_id={worker_id}&request_key={request_key}&plan_id={training_plan_id}&receive_operations_as=list"
)
pb = PlanPB()
pb.ParseFromString(req.content)
plan = deserialize(pb)

In [29]:
xs = th.rand([64 * 3, 1, 28, 28])
ys = th.randint(0, 10, [64 * 3, 10])

In [30]:
(res,) = plan(xs=xs, ys=ys, params=model_params_downloaded)

# Report Model diff

In [31]:
diff = [orig - new for orig, new in zip(res, local_model.parameters())]
diff_serialized = serialize((List(diff))).SerializeToString()

  grad = getattr(obj, "grad", None)


In [32]:
params = {
    "type": "model-centric/report",
    "data": {
        "worker_id": worker_id,
        "request_key": request_key,
        "diff": base64.b64encode(diff_serialized).decode("ascii"),
    },
}

In [33]:
sendWsMessage(params)

{'type': 'model-centric/report', 'data': {'status': 'success'}}

# Check new model

In [34]:
req_params = {
    "name": name,
    "version": version,
    "checkpoint": "latest",
}

In [35]:
res = requests.get(f"http://{grid_address}/model-centric/retrieve-model", req_params)

In [36]:
params_pb = ListPB()
params_pb.ParseFromString(res.content)
new_model_params = deserialize(params_pb)

In [37]:
new_model_params[0]

tensor([[ 0.0273,  0.0296, -0.0084,  ..., -0.0142,  0.0093,  0.0135],
        [-0.0205, -0.0405,  0.0165,  ..., -0.0154, -0.0009,  0.0094],
        [ 0.0012,  0.0043,  0.0069,  ..., -0.0116, -0.0177,  0.0073],
        ...,
        [-0.0228,  0.0258,  0.0132,  ...,  0.0085, -0.0291,  0.0339],
        [-0.0366, -0.0134, -0.0075,  ...,  0.0274, -0.0388,  0.0244],
        [-0.0018,  0.0306, -0.0153,  ..., -0.0078,  0.0365, -0.0224]],
       requires_grad=True)

In [38]:
# !rm PyGrid/apps/domain/src/nodedatabase.db

## Step 5: Train

To train hosted model, use one of the existing FL workers:
 * PySyft - see the "[MCFL - Execute Plan](mcfl_execute_plan.ipynb)" notebook that
has example of using Python FL worker.

Support for our Edge Clients is coming soon:
 * [SwiftSyft](https://github.com/OpenMined/SwiftSyft)
 * [KotlinSyft](https://github.com/OpenMined/KotlinSyft)
 * [syft.js](https://github.com/OpenMined/syft.js)