-
Notifications
You must be signed in to change notification settings - Fork 3
/
client.py
76 lines (58 loc) · 2.63 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import warnings
import flwr as fl
import numpy as np
import os
import argparse
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
import utils
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client wrapping the script's scikit-learn model.

    The methods operate on the module-level names ``model``, ``X_train``,
    ``y_train``, ``X_test`` and ``y_test``, which the script entry point
    defines before the client is started.
    """

    def get_parameters(self, config):
        """Return the local model's current parameters via ``utils``."""
        current_params = utils.get_model_parameters(model)
        return current_params

    def fit(self, parameters, config):
        """Load ``parameters`` into the model, fit it, and report the result.

        Returns a tuple of (updated parameters, number of training
        examples, empty metrics dict).
        """
        utils.set_model_params(model, parameters)

        # Warnings raised while fitting are silenced for this step only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model.fit(X_train, y_train)

        updated_params = utils.get_model_parameters(model)
        num_examples = len(X_train)
        return updated_params, num_examples, {}

    def evaluate(self, parameters, config):
        """Load ``parameters`` into the model and score it on the test split.

        Returns a tuple of (log loss, number of test examples, metrics
        dict holding ``"accuracy"``).
        """
        utils.set_model_params(model, parameters)

        probabilities = model.predict_proba(X_test)
        loss = log_loss(y_test, probabilities)
        metrics = {"accuracy": model.score(X_test, y_test)}
        return loss, len(X_test), metrics
if __name__ == "__main__":
    # Script entry point: parse arguments, load data, build the model, run
    # the Flower client against the server, then save the trained model.
    # The names X_train, y_train, X_test, y_test and model are deliberately
    # left at module level because FlowerClient's methods read them as globals.
    print("extracting arguments")
    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client
    # Algorithm related hyperparameters
    parser.add_argument("--penalty", type=str, default="l2")
    parser.add_argument("--max-iter", type=int, default=3)

    # Server ip address that client talks to
    # ("0.0.0.0:8080" for running on same machine)
    parser.add_argument("--server-address", type=str, default="0.0.0.0:8080")

    # Data, model, and output directories; defaults are read from the
    # SM_MODEL_DIR / SM_CHANNEL_TRAIN / SM_CHANNEL_TEST environment variables
    # and are None when those variables are unset.
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR"))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAIN"))
    parser.add_argument("--test", type=str, default=os.environ.get("SM_CHANNEL_TEST"))
    parser.add_argument("--train-file", type=str, default="cms_payment_train.csv")
    parser.add_argument("--test-file", type=str, default="cms_payment_validation.csv")

    # Unknown extra arguments are tolerated rather than rejected.
    args, _ = parser.parse_known_args()

    # Load data. The test file is read from the --test directory; when no
    # test directory was supplied, fall back to the training directory so
    # setups that keep both files in one place keep working.
    test_dir = args.test or args.train
    (X_train, y_train), (X_test, y_test) = utils.load_data(
        args.train, args.train_file, test_dir, args.test_file
    )
    print("Data Loaded", X_train.shape)

    # Initialize the model trained on the client
    model = LogisticRegression(
        penalty=args.penalty,
        max_iter=args.max_iter,
        warm_start=True,  # prevent refreshing weights when fitting
    )

    # Initialise the parameters once and print whatever the helper returned.
    initial_params = utils.set_initial_params(model)
    print("parameters : ", initial_params)

    fl_client = FlowerClient()
    print(fl_client)
    fl.client.start_numpy_client(
        server_address=args.server_address,
        client=fl_client,
    )

    # Runs only after start_numpy_client has returned.
    utils.save_model(args.model_dir, model)