In [1]:
import torch
import MiniFL as mfl

In [2]:
def run_client(i:int, client: mfl.algorithms.gd.GDClient):
    client.prepare()
    for _ in range(100):
        _ = client.step()
        
def run_master(master: mfl.algorithms.gd.GDMaster):
    master.prepare()
    for i in range(100):
        print(f"Master: {master.step()}")

In [3]:
from sklearn.datasets import load_svmlight_file
from copy import deepcopy

NUM_CLIENTS = 20

data, labels = load_svmlight_file("phishing.txt")
enc_labels = labels.copy()
data_dense = data.todense()


eval_data = (torch.from_numpy(data_dense).to(torch.float32), torch.from_numpy(enc_labels).to(torch.float32)[:, None])
partition_size = (len(eval_data[0]) - 1) // NUM_CLIENTS + 1
clients_data = [(x, y) for x, y in zip(torch.split(eval_data[0], partition_size, dim=0), torch.split(eval_data[1], partition_size, dim=0))]

NETWORK= torch.nn.Linear(eval_data[0].shape[1],1, bias=False)
LOSS = torch.nn.BCEWithLogitsLoss()

master_fn = mfl.fn.NNDifferentiableFn(
    model=NETWORK,
    data=eval_data,
    loss_fn=LOSS,
    batch_size=-1,
    seed=0,
)

client_fns = [mfl.fn.NNDifferentiableFn(
    model=deepcopy(NETWORK),
    data=clients_data[i],
    loss_fn=LOSS,
    batch_size=-1,
    seed=i,
) for i in range(NUM_CLIENTS)]

master, clients = mfl.algorithms.get_cocktailgd_master_and_clients(
    master_fn=master_fn,
    client_fns=client_fns,
    gamma=1,
    rand_p=0.5,
    top_p=0.5,
    # p=0.01,
)

In [4]:
import threading

client_threads = []
for i, client in enumerate(clients):
    client_threads.append(threading.Thread(target=run_client, args=(i, client)))
    client_threads[-1].start()
    
master_thread = threading.Thread(target=run_master, args=(master,))
master_thread.start()

master_thread.join()
for t in client_threads:
    t.join()
    


Master: 0.7053979635238647
Master: 0.7036067247390747
Master: 0.7003681063652039
Master: 0.6975622177124023
Master: 0.6895534992218018
Master: 0.685336709022522
Master: 0.677385151386261
Master: 0.6746073365211487
Master: 0.6680946350097656
Master: 0.6602680683135986
Master: 0.6576413512229919
Master: 0.6508966088294983
Master: 0.6483407616615295
Master: 0.6419381499290466
Master: 0.6350042819976807
Master: 0.632785439491272
Master: 0.6268914937973022
Master: 0.6225146651268005
Master: 0.6177380084991455
Master: 0.6132492423057556
Master: 0.6088639497756958
Master: 0.6048448085784912
Master: 0.6000187397003174
Master: 0.5959621667861938
Master: 0.5930341482162476
Master: 0.5907672047615051
Master: 0.5867668986320496
Master: 0.5839431881904602
Master: 0.5778256058692932
Master: 0.5750718116760254
Master: 0.5718031525611877
Master: 0.5659629702568054
Master: 0.5631652474403381
Master: 0.5614697933197021
Master: 0.5578610897064209
Master: 0.5531708598136902
Master: 0.550542950630188
Maste

In [5]:
clients[0].data_sender.n_bits_passed

10200