In [1]:
import torch
import MiniFL as mfl

In [2]:
def run_client(i:int, client: mfl.algorithms.gd.Client):
    for _ in range(100):
        loss = client.send_grad_get_loss()
        # print(f"Client {i}: {loss}")
        client.apply_global_step()
        
def run_master(master: mfl.algorithms.gd.Master):
    for i in range(100):
        print(f"Master: {master.round()}")

In [3]:
from sklearn.datasets import load_svmlight_file

NUM_CLIENTS = 20

data, labels = load_svmlight_file("phishing.txt")
enc_labels = labels.copy()
data_dense = data.todense()

eval_data = (torch.from_numpy(data_dense).to(torch.float32), torch.from_numpy(enc_labels).to(torch.float32)[:, None])
clients_data = [(x, y) for x, y in zip(torch.split(eval_data[0], len(eval_data[0]) // NUM_CLIENTS, dim=0), torch.split(eval_data[1], len(eval_data[1]) // NUM_CLIENTS, dim=0))]

model = torch.nn.Linear(eval_data[0].shape[1], 1, bias=False)
loss_fn = torch.nn.BCEWithLogitsLoss()

master, clients = mfl.algorithms.gd.get_master_and_clients(
    lr=200,
    clients_data=clients_data,
    eval_data=eval_data,
    model=model,
    loss_fn=loss_fn,
)

In [4]:
import threading

client_threads = []
for i, client in enumerate(clients):
    client_threads.append(threading.Thread(target=run_client, args=(i, client)))
    client_threads[-1].start()
    
master_thread = threading.Thread(target=run_master, args=(master,))
master_thread.start()

master_thread.join()
for t in client_threads:
    t.join()
    


Master: 0.682962954044342
Master: 0.6829753518104553
Master: 0.6820846796035767
Master: 0.680315375328064
Master: 0.6803724765777588
Master: 0.6803614497184753
Master: 0.6839784383773804
Master: 0.685096025466919
Master: 0.685072660446167
Master: 0.6837006211280823
Master: 0.6840313673019409
Master: 0.6838663816452026
Master: 0.6838663816452026
Master: 0.6839240193367004
Master: 0.6839357018470764
Master: 0.6845172643661499
Master: 0.684398353099823
Master: 0.6845191717147827
Master: 0.6790302991867065
Master: 0.6791834235191345
Master: 0.6787835955619812
Master: 0.6778480410575867
Master: 0.6787332892417908
Master: 0.678732693195343
Master: 0.6787346005439758
Master: 0.6786686778068542
Master: 0.6676605343818665
Master: 0.6680176258087158
Master: 0.6661996245384216
Master: 0.6661432981491089
Master: 0.6662565469741821
Master: 0.665495753288269
Master: 0.6656352281570435
Master: 0.664916455745697
Master: 0.6657066941261292
Master: 0.6654624342918396
Master: 0.6646496653556824
Master: 0

In [5]:
clients[0].data_sender.n_bits_passed

400.0