In [1]:
import os, sys
from easydict import EasyDict

import torch
import flwr as fl

# Agent
sys.path.append("..")
from src.agents.breast_cancer import BreastCancerLRAgent

In [2]:
# Paths and hyper-parameters handed to BreastCancerLRAgent, grouped by concern
# (data / model / train). EasyDict wraps the nested sections so they can be
# read as attributes as well as keys.
CONFIG = EasyDict(dict(
    data=dict(
        path_dir="../data/classification/breast_cancer",
        path_file="data.csv",
        pct_train=0.7,   # presumably the fraction of rows used for training — confirm in the agent's loader
        batch_size=4,
    ),
    model=dict(
        n_factor=30,     # presumably the number of input features — confirm in the model definition
    ),
    train=dict(
        max_epoch=5,
        lr=0.01,
        log_interval=20,
    ),
))

In [3]:
# Build the local agent (it owns the model and the train/valid loaders used by
# the Flower client). The agent's standalone `run()` is deliberately not called:
# in federated mode each local training epoch is triggered by the Flower server
# through the client's `fit` method.
bc_agent = BreastCancerLRAgent(CONFIG)

In [4]:
class BreastCancerLRClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a BreastCancerLRAgent's model.

    Flower exchanges model weights as a list of NumPy arrays. This client converts
    between that list and the agent's PyTorch ``state_dict`` (in ``state_dict`` key
    order) and delegates the actual training/validation work to the agent.

    Args:
        agent: The agent to wrap. Defaults to the notebook-global ``bc_agent`` so
            ``BreastCancerLRClient()`` keeps working as before.
    """

    def __init__(self, agent=None):
        # Injecting the agent avoids silently depending on whatever `bc_agent`
        # happens to be bound to in the kernel when a method runs.
        self.agent = bc_agent if agent is None else agent

    def get_parameters(self, config=None):
        """Return the model weights as NumPy arrays, in ``state_dict`` order.

        ``config`` is accepted and ignored: newer Flower releases pass it, older
        ones call ``get_parameters()`` with no argument.
        """
        return [tensor.cpu().numpy() for tensor in self.agent.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Load a list of NumPy arrays (in ``state_dict`` order) into the model.

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries (``zip`` would otherwise drop extras silently).
        """
        keys = self.agent.model.state_dict().keys()
        if len(parameters) != len(keys):
            raise ValueError(
                f"expected {len(keys)} parameter arrays, got {len(parameters)}"
            )
        # A plain dict preserves insertion order, so OrderedDict is not needed.
        # torch.tensor keeps each array's dtype (torch.Tensor would force float32).
        state_dict = {key: torch.tensor(array) for key, array in zip(keys, parameters)}
        self.agent.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Load the global weights, train one local epoch, and return the new weights.

        Returns:
            ``(weights, num_examples, metrics)`` as expected by Flower.
        """
        self.set_parameters(parameters)
        # train_one_epoch belongs to the agent (like validate); a torch nn.Module
        # has no such method. NOTE(review): confirm against BreastCancerLRAgent.
        self.agent.train_one_epoch()
        # Flower weights client updates by number of *examples*; len(DataLoader)
        # counts batches, so use the underlying dataset size instead.
        # Assumes loader_train is a torch DataLoader — TODO confirm.
        num_examples = len(self.agent.loader.loader_train.dataset)
        return self.get_parameters(), num_examples, {}

    def evaluate(self, parameters, config):
        """Load the global weights and validate them on the local validation split.

        Returns:
            ``(loss, num_examples, {"loss": loss})`` as expected by Flower.
        """
        self.set_parameters(parameters)
        loss = float(self.agent.validate())
        # Same examples-vs-batches reasoning as in fit (assumes a DataLoader).
        num_examples = len(self.agent.loader.loader_valid.dataset)
        return loss, num_examples, {"loss": loss}

In [5]:
# Connect this client to the Flower server. The server must already be running
# and listening on this port; otherwise gRPC fails with UNAVAILABLE /
# "failed to connect to all addresses".
# "[::]" is a wildcard *bind* address meant for servers; a client has to dial a
# concrete host, and "localhost" resolves to both ::1 and 127.0.0.1.
SERVER_ADDRESS = "localhost:8080"
fl.client.start_numpy_client(SERVER_ADDRESS, client=BreastCancerLRClient())

DEBUG flower 2021-03-26 14:13:25,579 | connection.py:36 | ChannelConnectivity.IDLE
DEBUG flower 2021-03-26 14:13:25,583 | connection.py:36 | ChannelConnectivity.TRANSIENT_FAILURE
INFO flower 2021-03-26 14:13:25,585 | app.py:61 | Opened (insecure) gRPC connection
DEBUG flower 2021-03-26 14:13:25,800 | connection.py:68 | Insecure gRPC channel closed


_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses"
	debug_error_string = "{"created":"@1616739205.581000000","description":"Failed to pick subchannel","file":"src/core/ext/filters/client_channel/client_channel.cc","file_line":4143,"referenced_errors":[{"created":"@1616739205.581000000","description":"failed to connect to all addresses","file":"src/core/ext/filters/client_channel/lb_policy/pick_first/pick_first.cc","file_line":398,"grpc_status":14}]}"
>