In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# !pip install pyjwt

In [3]:
import syft as sy
import torch as th
import os, websockets, json, requests, jwt
from torch import nn
from websocket import create_connection
from syft.lib.python.list import List
from syft import serialize, deserialize
from syft.lib.python.collections import OrderedDict
from syft.proto.lib.python.collections.ordered_dict_pb2 import OrderedDict as OrderedDictPB
from syft.proto.lib.python.list_pb2 import List as ListPB
from syft.lib.python.int import Int
from syft.proto.core.plan.plan_pb2 import Plan as PlanPB
from syft.core.plan.plan_builder import PLAN_BUILDER_VM, make_plan, build_plan_inputs, ROOT_CLIENT
from syft.federated.model_centric_fl_client import ModelCentricFLClient, GridError
import base64
th.random.manual_seed(1);

# Federated Learning - Model Centric MNIST Example: Create Plan

In [4]:
class MLP(sy.Module):
    """Two-layer perceptron: Linear(784, 100) -> ReLU -> Linear(100, 10).

    The layer classes are looked up on the ``torch_ref`` passed to the
    constructor.
    """

    def __init__(self, torch_ref):
        super().__init__(torch_ref=torch_ref)
        nn_ref = self.torch_ref.nn
        self.l1 = nn_ref.Linear(784, 100)
        self.a1 = nn_ref.ReLU()
        self.l2 = nn_ref.Linear(100, 10)

    def forward(self, x):
        """Flatten ``x`` into rows of 28 * 28 values and return the output of ``l2``."""
        flat = x.view(-1, 28 * 28)
        hidden = self.a1(self.l1(flat))
        return self.l2(hidden)

In [5]:
def set_params(model, params):
    """Overwrite the ``.data`` of each model parameter with the ``.data`` of the matching entry in ``params``.

    Parameters are paired positionally with ``zip``, so pairing stops at the
    shorter of the two sequences.
    """
    for target, source in zip(model.parameters(), params):
        target.data = source.data
        
def cross_entropy_loss(logits, targets, batch_size):
    """Return ``-(targets * log_softmax(logits)).sum() / batch_size``.

    The log-softmax is computed by hand along ``dim=1``. The global maximum of
    ``logits`` is subtracted first, which leaves the softmax unchanged but keeps
    ``exp`` from overflowing.
    """
    shifted = logits - logits.max()
    log_partition = shifted.exp().sum(dim=1, keepdim=True).log()
    log_probs = shifted - log_partition
    weighted_total = (targets * log_probs).sum()
    return -weighted_total / batch_size

def sgd_step(model, lr=0.1):
    """Apply one plain gradient-descent update to every parameter of ``model``.

    Each parameter becomes ``data - lr * grad`` and its gradient is then
    replaced by a zero tensor of the same shape. The whole update runs under
    ``ROOT_CLIENT.torch.no_grad()``.
    """
    with ROOT_CLIENT.torch.no_grad():
        for param in model.parameters():
            stepped = param.data - lr * param.grad
            param.data = stepped
            # NOTE(review): .get() presumably fetches the gradient value behind
            # a pointer before building the zero tensor -- confirm.
            param.grad = th.zeros_like(param.grad.get())

In [6]:
local_model = MLP(th)

In [7]:
@make_plan
def train(xs = th.rand([64*3, 1, 28, 28]), ys = th.randint(0, 10, [64*3, 10]),
          params = List(local_model.parameters()) ):
    # Training plan: load `params` into the model, run one SGD update on a
    # slice of 64 samples and return the model's parameters.
    #
    # The defaults are random placeholder tensors plus the current parameters
    # of `local_model`.
    # NOTE(review): presumably make_plan uses these defaults as example inputs
    # when building the plan -- confirm against the syft plan builder.
    
    model = local_model.send(ROOT_CLIENT)
    set_params(model, params)
    # range(1): only the first slice of 64 samples is used, although the
    # default xs/ys hold 64*3 samples.
    for i in range(1):
        indices = th.tensor(range(64*i, 64*(i+1)))
        x, y = xs.index_select(0, indices), ys.index_select(0, indices)
        out = model(x)
        # The batch size (64) is hard-coded here and in the index range above;
        # client_config["batch_size"] is not read by this function.
        loss = cross_entropy_loss(out, y, 64)
        loss.backward()
        # Called with sgd_step's default learning rate (lr=0.1).
        sgd_step(model)
        
    return model.parameters()

Cant make real_module pointable. You do not have permission to update Object with ID: <UID: 89828e2a40994ef49204001b9b952bc6>Please submit a request.


[2021-03-22T13:36:57.396790+0100][CRITICAL][logger]][17733] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: fe23c23750274f8cb0762276413c9391>.
[2021-03-22T13:36:57.439207+0100][CRITICAL][logger]][17733] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 190ace41cfab47b2842ca48d9810a5c0>.
[2021-03-22T13:36:57.468792+0100][CRITICAL][logger]][17733] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 466e15b91986420785368828aee91b73>.
[2021-03-22T13:36:57.488401+0100][CRITICAL][logger]][17733] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 747b2c37de6747389a8da125c257487a>.


In [8]:
@make_plan
def avg_plan(avg = List(local_model.parameters()), item = List(local_model.parameters()),
             num = Int(0)):
    # Averaging plan: fold `item` into the running average `avg`, element by
    # element, as (avg[i] * num + item[i]) / (num + 1).
    # NOTE(review): presumably `num` is the number of items already folded into
    # `avg` -- confirm against the caller on the grid side.
    new_avg = []
    # `param` is unused; elements are read by index from both lists instead.
    for i, param in enumerate(avg):
        new_avg.append((avg[i] * num + item[i]) / (num + 1))
    return new_avg

# Config & keys

In [9]:
# Identity of the hosted model; reused in every request below.
name, version = "mnist", "1.0"

# Settings handed to each worker.
client_config = dict(
    name=name,
    version=version,
    batch_size=64,
    lr=0.1,
    max_updates=1,  # custom syft.js option that limits number of training loops per worker
)

# Settings for the hosting side of the FL process.
server_config = dict(
    min_workers=2,
    max_workers=2,
    pool_selection="random",
    do_not_reuse_workers_until_cycle=6,
    cycle_length=28800,  # max cycle length in seconds
    num_cycles=30,  # max number of cycles
    max_diffs=1,  # number of diffs to collect before avg
    minimum_upload_speed=0,
    minimum_download_speed=0,
    iterative_plan=True,  # tells PyGrid that avg plan is executed per diff
)

In [10]:
def read_file(fname):
    """Return the entire contents of the text file at ``fname`` as one string."""
    with open(fname, mode="r") as handle:
        contents = handle.read()
    return contents

In [11]:
# Load the example RSA key pair (private key first, then public key),
# stripping surrounding whitespace from each.
private_key, public_key = (
    read_file(key_path).strip() for key_path in ("example_rsa", "example_rsa.pub")
)

# Attach the public key to the server config as its JWT authentication entry.
server_config["authentication"] = dict(type="jwt", pub_key=public_key)

# Auth

In [12]:
# host:port of the grid node; the cells below prefix it with ws:// and http://.
grid_address = "bob:7000"

In [13]:
# Create the model-centric FL client for the grid node and open its connection.
grid = ModelCentricFLClient(
    id="test",
    address=grid_address,
    secure=False,
)
grid.connect()

# Host

If the process already exists, you might need to clear the database first. To do that, set the path below correctly and run:

In [14]:
# !rm /Users/koen/workspace/PyGrid/apps/domain/src/datadomain.db

In [15]:
# Host the FL process: the model, the client training plan, the server
# averaging plan and both configs are sent in a single call.
response = grid.host_federated_training(
    model=local_model,
    client_plans=dict(training_plan=train),
    client_protocols={},
    server_averaging_plan=avg_plan,
    client_config=client_config,
    server_config=server_config,
)

In [16]:
# Display the value returned by grid.host_federated_training.
response

{'type': 'model-centric/host-training', 'data': {'status': 'success'}}

# Authenticate for cycle

In [17]:
# Helper function to make WS requests
def sendWsMessage(data):
    """Send ``data`` as JSON over a fresh WebSocket and return the decoded JSON reply.

    A new connection to ``ws://<grid_address>`` is opened for every call and is
    always closed before returning, including when sending, receiving or
    decoding the reply raises.

    Args:
        data: JSON-serializable request payload.

    Returns:
        The reply message parsed with ``json.loads``.
    """
    ws = create_connection('ws://' + grid_address)
    try:
        ws.send(json.dumps(data))
        return json.loads(ws.recv())
    finally:
        # Close explicitly so repeated calls do not accumulate open sockets.
        ws.close()

In [18]:
# Sign an empty-claims JWT with the example private key.
# PyJWT 1.x returns bytes from encode(), while PyJWT >= 2.0 returns str (which
# has no .decode()); accept both so the cell runs with either version.
auth_token = jwt.encode({}, private_key, algorithm='RS256')
if isinstance(auth_token, bytes):
    auth_token = auth_token.decode('ascii')

In [19]:
# Authenticate for the hosted model with the signed token; the reply's
# worker_id is read by the cycle request below.
auth_request = dict(
    type="model-centric/authenticate",
    data=dict(
        model_name=name,
        model_version=version,
        auth_token=auth_token,
    ),
)
auth_response = sendWsMessage(auth_request)
auth_response

{'type': 'model-centric/authenticate',
 'data': {'status': 'success',
  'worker_id': 'adee3e3f-c159-4a80-b388-19d435af5c6c',
  'requires_speed_test': True}}

# Do cycle request

In [20]:
# Ask to join a training cycle as the authenticated worker; ping, download and
# upload are fixed placeholder values.
cycle_request = dict(
    type="model-centric/cycle-request",
    data=dict(
        worker_id=auth_response['data']['worker_id'],
        model=name,
        version=version,
        ping=1,
        download=10000,
        upload=10000,
    ),
)
cycle_response = sendWsMessage(cycle_request)
print(
    'Cycle response:',
    json.dumps(cycle_response, indent=2).replace("\\n", "\n"),
)


Cycle response: {
  "type": "model-centric/cycle-request",
  "data": {
    "status": "accepted",
    "request_key": "a496c206e87e2b0268ed76e9a35dec3111c32091033fb708de372ed8a2e88e59",
    "version": "1.0",
    "model": "mnist",
    "plans": {
      "training_plan": 2
    },
    "protocols": {},
    "client_config": {
      "name": "mnist",
      "version": "1.0",
      "batch_size": 64,
      "lr": 0.1,
      "max_updates": 1
    },
    "model_id": 1
  }
}


# Download model

In [21]:
# Pull the identifiers needed by the download and report requests out of the
# authentication and cycle responses.
worker_id = auth_response['data']['worker_id']
request_key, model_id, training_plan_id = (
    cycle_response['data']['request_key'],
    cycle_response['data']['model_id'],
    cycle_response['data']['plans']['training_plan'],
)

In [22]:
def get_model(grid_address, worker_id, request_key, model_id):
    """Download model ``model_id`` from the grid and return it deserialized.

    Issues a GET to ``/model-centric/get-model`` on ``grid_address``, parses
    the response body as a ``ListPB`` protobuf message and returns the result
    of ``deserialize`` on it.

    Args:
        grid_address: ``host:port`` of the grid node.
        worker_id: Worker id sent as the ``worker_id`` query parameter.
        request_key: Key sent as the ``request_key`` query parameter.
        model_id: Id sent as the ``model_id`` query parameter.

    Raises:
        requests.HTTPError: if the grid answers with a 4xx/5xx status, so an
            error body is never handed to the protobuf parser.
    """
    req = requests.get(
        f"http://{grid_address}/model-centric/get-model",
        # Passed as params so that requests URL-encodes the values.
        params={
            "worker_id": worker_id,
            "request_key": request_key,
            "model_id": model_id,
        },
    )
    req.raise_for_status()
    pb = ListPB()
    pb.ParseFromString(req.content)
    return deserialize(pb)

In [23]:
# Model
model_params_downloaded = get_model(grid_address, worker_id, request_key, model_id)
print("Params shapes:", [p.shape for p in model_params_downloaded])

Params shapes: [torch.Size([100, 784]), torch.Size([100]), torch.Size([10, 100]), torch.Size([10])]


In [24]:
# Display the first downloaded parameter.
model_params_downloaded[0]

Parameter containing:
tensor([[ 0.0184, -0.0158, -0.0069,  ...,  0.0068, -0.0041,  0.0025],
        [-0.0274, -0.0224, -0.0309,  ..., -0.0029,  0.0013, -0.0167],
        [ 0.0282, -0.0095, -0.0340,  ..., -0.0141,  0.0056, -0.0335],
        ...,
        [-0.0170, -0.0294, -0.0351,  ..., -0.0320, -0.0291, -0.0083],
        [ 0.0207, -0.0126,  0.0167,  ..., -0.0350, -0.0347, -0.0292],
        [ 0.0182,  0.0104,  0.0114,  ..., -0.0278, -0.0205,  0.0123]],
       requires_grad=True)

# Download & Execute Plan

In [25]:
# Download the training plan assigned in the cycle response.
req = requests.get(
    f"http://{grid_address}/model-centric/get-plan",
    # Passed as params so that requests URL-encodes the values.
    params={
        "worker_id": worker_id,
        "request_key": request_key,
        "plan_id": training_plan_id,
        "receive_operations_as": "list",
    },
)
# Fail on a 4xx/5xx reply instead of handing an error body to the protobuf parser.
req.raise_for_status()
pb = PlanPB()
pb.ParseFromString(req.content)
plan = deserialize(pb)

In [26]:
# Random stand-in inputs with the same shapes as the default arguments of the
# `train` plan: 64 * 3 samples of 1x28x28, and integer targets in [0, 10).
xs = th.rand(64 * 3, 1, 28, 28)
ys = th.randint(low=0, high=10, size=(64 * 3, 10))

In [27]:
# Run the downloaded plan locally. The single-element unpacking raises if the
# plan returns anything other than exactly one value.
(res,) = plan(
    xs=xs,
    ys=ys,
    params=model_params_downloaded,
)

[2021-03-22T13:37:00.103602+0100][CRITICAL][logger]][17733] <class 'syft.core.store.store_memory.MemoryStore'> __delitem__ error <UID: 40ff4c125fdc49e28665e9f3789b8eac>.


# Report Model diff

In [28]:
# Element-wise difference between the plan output and local_model's parameters,
# serialized as a syft List for upload.
# NOTE(review): the first operand is the plan output (`res`) and the second is
# local_model's current parameters -- confirm this matches the sign convention
# the grid expects for a reported diff.
diff = [lhs - rhs for lhs, rhs in zip(res, local_model.parameters())]
diff_serialized = serialize(List(diff)).SerializeToString()

  grad = getattr(obj, "grad", None)


In [29]:
# Report message: the serialized diff is base64-encoded to ASCII text so that
# it can be embedded in the JSON payload.
params = dict(
    type="model-centric/report",
    data=dict(
        worker_id=worker_id,
        request_key=request_key,
        diff=base64.b64encode(diff_serialized).decode("ascii"),
    ),
)

In [30]:
# Send the report message and display the grid's reply.
sendWsMessage(params)

{'type': 'model-centric/report', 'data': {'status': 'success'}}

# Check new model

In [31]:
# Query parameters selecting the latest checkpoint of the hosted model.
req_params = dict(
    name=name,
    version=version,
    checkpoint="latest",
)

In [32]:
# NOTE: this rebinds `res`, which earlier held the output of the training plan;
# re-running the diff cell after this one would operate on the HTTP response.
res = requests.get(f"http://{grid_address}/model-centric/retrieve-model", params=req_params)
# Fail on a 4xx/5xx reply instead of handing an error body to the protobuf parser.
res.raise_for_status()

In [33]:
# Parse the retrieve-model response body as a ListPB message and deserialize it
# into the new model parameters.
params_pb = ListPB()
params_pb.ParseFromString(res.content)
new_model_params = deserialize(params_pb)

In [34]:
# Display the first parameter of the retrieved model.
new_model_params[0]

tensor([[ 0.0202, -0.0075, -0.0027,  ...,  0.0129,  0.0020,  0.0086],
        [-0.0262, -0.0207, -0.0279,  ..., -0.0014,  0.0025, -0.0133],
        [ 0.0337, -0.0049, -0.0296,  ..., -0.0078,  0.0094, -0.0271],
        ...,
        [-0.0182, -0.0309, -0.0409,  ..., -0.0353, -0.0303, -0.0142],
        [ 0.0195, -0.0134,  0.0152,  ..., -0.0370, -0.0363, -0.0299],
        [ 0.0187,  0.0109,  0.0119,  ..., -0.0276, -0.0200,  0.0124]],
       requires_grad=True)

In [36]:
# !rm /Users/koen/workspace/PyGrid/apps/domain/src/datadomain.db