In [1]:
import syft as sy
from syft.service.policy.policy import MixedInputPolicy

In [2]:
from datasites import DATASITE_URLS

In [3]:
# Smoke test: attempt a login against every datasite listed in DATASITE_URLS.
# NOTE(review): the e-mail and password are hardcoded in the notebook; consider
# reading them from the environment or prompting for them instead.
for name, url in DATASITE_URLS.items():
    print("Trying", name, url)
    # `c` is rebound on every iteration and is not used again within this cell.
    c = sy.login(url=url, email="researcher@openmined.org", password="****")

Trying MNIST Part 1 http://localhost:54879
Logged into <MNIST Part 1: High side Datasite> as <researcher@openmined.org>
Trying MNIST Part 2 http://localhost:54880
Logged into <MNIST Part 2: High side Datasite> as <researcher@openmined.org>


In [4]:
# One logged-in client per datasite, keyed by the datasite's display name.
datasites = {
    name: sy.login(url=url, email="researcher@openmined.org", password="****")
    for name, url in DATASITE_URLS.items()
}

Logged into <MNIST Part 1: High side Datasite> as <researcher@openmined.org>
Logged into <MNIST Part 2: High side Datasite> as <researcher@openmined.org>


In [5]:
# Mock counterpart of the MNIST asset hosted on the first datasite; used below
# to dry-run the experiment locally.
mock_data = (
    datasites["MNIST Part 1"]
    .datasets["MNIST Dataset"]
    .assets["MNIST Data"]
    .mock
)


In [6]:
import numpy as np
import numpy.typing as npt
from typing import Union, TypeVar, Any

# Shorthand array types used by the annotations below.
NDArray = npt.NDArray[Any]
NDArrayInt = npt.NDArray[np.int_]
# `np.float_` was removed in NumPy 2.0; `np.float64` is the same dtype and is
# available on every NumPy version.
NDArrayFloat = npt.NDArray[np.float64]

# A TypeVar's name should match the variable it is assigned to.
Dataset = TypeVar("Dataset")
# One evaluation report: metric name -> scalar score or integer matrix.
Metric = TypeVar("Metric", bound=dict[str, Union[float, NDArrayInt]])
Metrics = TypeVar("Metrics", bound=tuple[Metric, Metric])  # train and test
# Model weights: parameter name -> float array.
ModelParams = TypeVar("ModelParams", bound=dict[str, NDArrayFloat])
# What an experiment returns: (train/test metrics, trained weights).
Result = TypeVar("Result", bound=tuple[Metrics, ModelParams])


In [7]:
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, random_split
import torch as th
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import matthews_corrcoef as mcc
from sklearn.metrics import confusion_matrix

def dl_experiment(
    data: tuple[th.Tensor, th.Tensor],   
    model_params: ModelParams = None,
    training_epochs: int = 1,
) -> Result:
    """Train and evaluate a small MLP classifier on one partition of the data.

    The data is split 80/20 into train and test subsets, a three-layer MLP is
    (optionally) initialised from ``model_params``, trained with Adam and
    cross-entropy loss, and then evaluated on both subsets.

    Args:
        data: ``(X, y)`` pair of tensors. Every sample in ``X`` must hold
            28 * 28 values (it is flattened with ``view(-1, 28 * 28)``).
            ``y`` is presumably a tensor of integer class labels in 0..9, as
            required by ``CrossEntropyLoss`` with 10 outputs (TODO confirm
            against the dataset).
        model_params: Optional mapping of state-dict key -> NumPy array used to
            initialise the model. ``None`` or an empty dict keeps the default
            random initialisation.
        training_epochs: Number of passes over the training subset.

    Returns:
        ``((train_metrics, test_metrics), model_params)`` where each metrics
        dict has the keys ``"mcc"`` (Matthews correlation coefficient) and
        ``"cm"`` (confusion matrix), and ``model_params`` maps each state-dict
        key to a NumPy array of the trained weights.
    """
    # This function is submitted through `sy.syft_function` and executed
    # remotely, so its dependencies are imported here rather than taken from
    # notebook-level globals.
    import numpy as np
    import torch as th
    import torch.nn as nn
    import torch.optim as optim
    from sklearn.metrics import confusion_matrix
    from sklearn.metrics import matthews_corrcoef as mcc
    from torch.utils.data import DataLoader, TensorDataset, random_split

    X, y = data
    full_ds = TensorDataset(X, y)

    n = len(full_ds)
    n_train = int(0.8 * n)
    n_test = n - n_train
    # Seed the split: the function is called once per federated round, and an
    # unseeded split would move samples between train and test from round to
    # round, so the "test" subset would contain samples already trained on.
    split_rng = th.Generator().manual_seed(42)
    train_ds, test_ds = random_split(
        full_ds, [n_train, n_test], generator=split_rng
    )

    def make_loader(ds, shuffle: bool):
        return DataLoader(ds, batch_size=64, shuffle=shuffle)

    train_loader = make_loader(train_ds, shuffle=True)
    test_loader = make_loader(test_ds, shuffle=False)

    class MLP(nn.Module):
        """784 -> 128 -> 64 -> 10 fully connected network with ReLU activations."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(28 * 28, 128)
            self.fc2 = nn.Linear(128, 64)
            self.classifier = nn.Linear(64, 10)

        def forward(self, x):
            x = x.view(-1, 28 * 28)
            x = th.relu(self.fc1(x))
            x = th.relu(self.fc2(x))
            return self.classifier(x)

    clf = MLP()
    if model_params:
        # Incoming weights are NumPy arrays; convert them back to tensors.
        clf.load_state_dict({k: th.from_numpy(v) for k, v in model_params.items()})

    # Prefer CUDA, then Apple MPS, and fall back to the CPU.
    device = th.device(
        "cuda"
        if th.cuda.is_available()
        else "mps"
        if th.backends.mps.is_available()
        else "cpu"
    )
    clf.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(clf.parameters(), lr=0.001)

    for epoch in range(training_epochs):
        clf.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = clf(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Export the weights as NumPy arrays on the CPU so that they can be
    # returned to, and averaged by, the caller.
    model_params = {k: t.detach().cpu().numpy() for k, t in clf.state_dict().items()}

    def evaluate(loader) -> dict:
        """Return MCC and confusion matrix of ``clf`` over every batch in ``loader``."""
        clf.eval()
        all_true, all_pred = [], []
        with th.no_grad():
            for images, labels in loader:
                images, labels = images.to(device), labels.to(device)
                outputs = clf(images)
                preds = outputs.argmax(dim=1)
                all_true.append(labels.cpu().numpy())
                all_pred.append(preds.cpu().numpy())
        y_true = np.concatenate(all_true)
        y_pred = np.concatenate(all_pred)
        return {
            "mcc": mcc(y_true, y_pred),
            "cm": confusion_matrix(y_true, y_pred),
        }

    training_metrics = evaluate(train_loader)
    test_metrics = evaluate(test_loader)

    return (training_metrics, test_metrics), model_params


In [10]:
# Dry-run the experiment locally on the mock data (default: one training epoch,
# fresh model) before submitting it to the datasites.
metrics, model_params = dl_experiment(data=mock_data)
metrics

({'mcc': 0.9414657470769033,
  'cm': array([[2286,    0,   11,    3,    2,   14,   18,    3,   14,    0],
         [   0, 2583,    7,    7,    2,    2,    3,    7,   27,    6],
         [   1,   36, 2159,   25,   13,   13,   13,   20,   48,    3],
         [   0,   16,   36, 2125,    2,   91,    2,   23,   56,   16],
         [   4,   12,    8,    0, 2236,    1,   21,    6,   12,   55],
         [   1,    3,    5,   26,    7, 2004,   21,    3,   26,    5],
         [   5,    9,    4,    0,    5,   30, 2368,    0,   16,    0],
         [   0,   20,   23,    2,   13,    0,    2, 2543,    7,   42],
         [   4,   29,   10,   14,    1,   32,    8,    4, 2234,    9],
         [   7,   13,    9,   37,   45,   18,    1,   65,   25, 2197]],
        dtype=int64)},
 {'mcc': 0.929080945032022,
  'cm': array([[548,   0,   2,   0,   3,   6,   4,   4,   7,   0],
         [  0, 681,   2,   1,   1,   1,   0,   2,   5,   1],
         [  1,  17, 574,   3,   6,   4,   5,   4,  16,   1],
         [  0,

In [11]:
# Inspect the weights returned by the local dry run (state-dict key -> NumPy array).
model_params

{'fc1.weight': array([[-0.00664685, -0.00020845, -0.02289739, ..., -0.0319391 ,
         -0.01161636,  0.01764384],
        [-0.02924345,  0.00986123,  0.03299898, ...,  0.01930963,
          0.00201001,  0.0252439 ],
        [-0.01224802,  0.03098344,  0.00014249, ...,  0.02470691,
         -0.03940117, -0.03090782],
        ...,
        [-0.01866433,  0.02861011, -0.00971252, ..., -0.01151635,
         -0.00728869,  0.02389746],
        [-0.01712284,  0.02329939,  0.02219441, ..., -0.00499684,
          0.0258819 ,  0.00579977],
        [-0.00547862, -0.01399998, -0.02055736, ...,  0.00085069,
         -0.00746994,  0.02985933]], dtype=float32),
 'fc1.bias': array([-0.02502752, -0.03071491,  0.03865002,  0.01011007,  0.00134084,
         0.01703331, -0.01000741,  0.01534151,  0.00380063,  0.03236474,
         0.02467354,  0.03593792, -0.01751275, -0.0083262 , -0.01026052,
         0.00670924,  0.01682197, -0.03031656, -0.01419003,  0.00867055,
        -0.00658413,  0.02009567, -0.034

In [12]:
from syft.service.policy.policy import MixedInputPolicy

# For every datasite: wrap `dl_experiment` as a Syft function bound to that
# datasite's MNIST asset, then create a project carrying the code request and
# send it.
for name, datasite in datasites.items():
    data_asset = datasite.datasets["MNIST Dataset"].assets["MNIST Data"]

    syft_fl_experiment = sy.syft_function(
        input_policy=MixedInputPolicy(
            client=datasite,
            data=data_asset,      # tell the policy which asset is allowed
            model_params=dict,
            training_epochs=int,
        )
    )(dl_experiment)

    dl_training_project = sy.Project(
        name="DL Experiment for FL on MNIST",
        description="""DL experiment on MNIST for Federated Learning.
        The function will be invoked iteratively with updated model parameters
        averaged across datasites.""",
        members=[datasite],
    )

    dl_training_project.create_code_request(syft_fl_experiment, datasite)
    # `project` is rebound on every iteration, so only the last one is kept.
    project = dl_training_project.send()


In [13]:
# Refresh every client and list what its `code` API exposes, to check that
# `dl_experiment` is available on each datasite.
for name, datasite in datasites.items():
    datasite.refresh()  # important: reloads the API
    print(name, "API code methods:", dir(datasite.code))


MNIST Part 1 API code methods: ['call', 'dl_experiment', 'get_all', 'get_all_for_user', 'get_by_id', 'get_by_service_func_name', 'path', 'request_code_execution', 'store_execution_output', 'submit']
MNIST Part 2 API code methods: ['call', 'dl_experiment', 'get_all', 'get_all_for_user', 'get_by_id', 'get_by_service_func_name', 'path', 'request_code_execution', 'store_execution_output', 'submit']


In [14]:
import numpy as np
from collections import defaultdict
def avg(all_model_params: list[ModelParams]) -> ModelParams:
    """Average model parameters element-wise across several parameter sets.

    The layer names are taken from the first entry of ``all_model_params``;
    every entry is expected to provide an array for each of those names.

    Args:
        all_model_params: One ``{layer name: array}`` mapping per participant.

    Returns:
        A ``{layer name: array}`` mapping holding, for each layer, the mean of
        that layer's arrays over all participants.
    """
    layer_names = all_model_params[0].keys()
    averaged = {}
    for layer in layer_names:
        per_participant = [params[layer] for params in all_model_params]
        averaged[layer] = np.average(per_participant, axis=0)
    return averaged
def fl_experiment(
    datasites,
    fl_epochs: int = 75,
    start_epoch: int = 0,
    training_epochs: int = 5,
    model_params: ModelParams = None,
):
    """Run federated averaging of `dl_experiment` across all datasites.

    In every federated round each datasite trains remotely, starting from the
    current global parameters; the parameters returned by all datasites are
    then averaged with `avg` to form the next global parameters.

    Args:
        datasites: Mapping of datasite name -> logged-in client. Each client
            must expose an approved ``code.dl_experiment`` and host the
            "MNIST Dataset" / "MNIST Data" asset.
        fl_epochs: Number of federated rounds to run.
        start_epoch: Index of the first round (e.g. to resume a previous run).
        training_epochs: Local training epochs per datasite per round.
        model_params: Optional initial global parameters; when falsy, every
            datasite starts from its own fresh model in the first round.

    Returns:
        ``(fl_metrics, fl_model_params)`` where ``fl_metrics`` maps each round
        index to a list of ``(metrics, params)`` tuples, one per datasite in
        iteration order, and ``fl_model_params`` holds the averaged parameters
        after the last round.
    """
    fl_model_params = dict() if not model_params else model_params
    fl_metrics = defaultdict(list)
    for epoch in range(start_epoch, (total_epochs := start_epoch + fl_epochs)):
        # Report on the first round, on every (total_epochs // 4)-th round
        # (every round when that quotient is 0), and on the last round.
        is_checkpoint = (
            (epoch == start_epoch)
            or (epoch % max(1, (total_epochs // 4)) == 0)
            or (epoch == total_epochs - 1)
        )
        for ds in datasites.values():
            data_asset = ds.datasets["MNIST Dataset"].assets["MNIST Data"]
            metrics, params = (
                ds.code.dl_experiment(
                    # Pass the asset handle itself: it is what the input policy
                    # was registered with, and it is resolved on the datasite.
                    # `data_asset.data` would instead try to read the private
                    # payload on the client side.
                    data=data_asset,
                    model_params=fl_model_params,
                    training_epochs=training_epochs,
                )
                .get_from(ds)
            )
            fl_metrics[epoch].append((metrics, params))
        if is_checkpoint:
            print("Epoch:", epoch)
            # Iterating the mapping yields names in the same order as
            # `.values()` above, so `idx` lines up with the stored results.
            for idx, name in enumerate(datasites):
                metrics = fl_metrics[epoch][idx][0]
                print(
                    f"\t {name}: "
                    f"Train:{metrics[0]['mcc']:.3f} | "
                    f"Test:{metrics[1]['mcc']:.3f}"
                )
        fl_model_params = avg([params for _, params in fl_metrics[epoch]])
    return fl_metrics, fl_model_params

In [15]:
fl_metrics, fl_model_params = fl_experiment(datasites)

In [102]:
# Single-round sanity check against one datasite.
# NOTE: `datasite` is whatever client was left bound by the most recent loop
# over `datasites`; this cell does not select one explicitly.
data_asset = datasite.datasets["MNIST Dataset"].assets["MNIST Data"]
metrics, params = datasite.code.dl_experiment(
    data=data_asset,  # required argument; it was previously looked up but never passed
    model_params={},  # empty dict = start from a fresh model; matches the `dict` type in the input policy
    training_epochs=1,
).get_from(datasite)


In [None]:
from matplotlib import pyplot as plt
from utils import plot_fl_metrics

# Plot the metrics collected over the federated rounds for every datasite.
plot_fl_metrics(datasites, fl_metrics, title="FL Experiment on MNIST Data w/ MLP")
plt.show()

from utils import plot_all_confusion_matrices

# Take the highest round index recorded in `fl_metrics`.
last_epoch = sorted(fl_metrics)[-1]
# For each datasite pick the *test* confusion matrix of that round:
# entry layout is (metrics, params), metrics is (train, test).
confusion_matrices = {
    name: fl_metrics[last_epoch][idx][0][1]["cm"]
    for idx, name in enumerate(datasites)
}
plot_all_confusion_matrices(
    confusion_matrices,
    title="Confusion Matrices of FL MLP on MNIST",
)
plt.show()