In [1]:
from centralized_package.client_centralized import load_data, load_model, train, test

In [2]:
import torch
from collections import OrderedDict

In [3]:
import flwr as fl

In [4]:
def set_parameters(model, parameters):
    """Overwrite ``model``'s state dict with the given sequence of arrays.

    Args:
        model: Object exposing ``state_dict()`` and ``load_state_dict()``.
        parameters: Iterable of array-likes, paired positionally with the
            keys of ``model.state_dict()``.

    Raises:
        Whatever ``load_state_dict(..., strict=True)`` raises when the
        rebuilt state dict does not match the model's keys.
    """
    new_state = OrderedDict()
    # Pair each array with the state-dict key at the same position.
    for name, array in zip(model.state_dict().keys(), parameters):
        new_state[name] = torch.tensor(array)
    model.load_state_dict(new_state, strict=True)

In [5]:
# Module-level model and data loaders. FlowerClient (defined in a later cell)
# reads `net`, `trainloader` and `testloader` as globals, so this cell must be
# run before the client is started.
net = load_model()
# load_data() returns the (train, test) loaders, unpacked in that order.
trainloader, testloader = load_data()

Files already downloaded and verified
Files already downloaded and verified


In [6]:
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the notebook-level ``net``.

    The client keeps no state of its own: it reads the module-level ``net``,
    ``trainloader`` and ``testloader``, which must be defined before any
    method is called.
    """

    def get_parameters(self, config):
        """Return the current weights of ``net`` as a list of NumPy arrays.

        ``config`` is accepted for interface compatibility and is not used.
        The arrays follow the iteration order of ``net.state_dict()``.
        """
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def fit(self, parameters, config):
        """Load ``parameters`` into ``net``, train for one epoch, return the result.

        Returns:
            A tuple ``(weights, num_examples, metrics)``: the updated weights
            as NumPy arrays, the size of the training dataset, and an empty
            metrics dict.
        """
        set_parameters(net, parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` into ``net`` and evaluate it on ``testloader``.

        Returns:
            A tuple ``(loss, num_examples, metrics)``: the loss as a float,
            the size of the test dataset, and ``{"accuracy": <float>}``.
        """
        set_parameters(net, parameters)
        loss, accuracy = test(net, testloader)
        # The example count must describe the data the loss was computed on,
        # i.e. the test set -- not the training set.
        return float(loss), len(testloader.dataset), {"accuracy": float(accuracy)}

In [7]:
# Connect a FlowerClient to the Flower server at 127.0.0.1:8080.
# No certificates are passed, so the gRPC connection is insecure (see the
# "Opened insecure gRPC connection" log line in this cell's output).
# NOTE(review): a server must already be listening on this address; the call
# appears to block until the server disconnects -- confirm.
# NOTE(review): newer Flower releases may deprecate start_numpy_client in
# favour of start_client -- confirm against the installed flwr version.
fl.client.start_numpy_client(
    server_address="127.0.0.1:8080", 
    client=FlowerClient()
)

INFO flwr 2023-11-30 14:13:50,239 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-11-30 14:13:50,254 | connection.py:42 | ChannelConnectivity.IDLE
DEBUG flwr 2023-11-30 14:13:50,257 | connection.py:42 | ChannelConnectivity.READY
DEBUG flwr 2023-11-30 14:16:27,614 | connection.py:141 | gRPC channel closed
INFO flwr 2023-11-30 14:16:27,624 | app.py:304 | Disconnect and shut down
