In [1]:
import torch
import MiniFL as mfl

In [2]:
def run_client(i: int, client: "mfl.algorithms.gd.GDClient", num_steps: int = 100):
    """Drive a single client through a fixed number of steps.

    Calls ``client.prepare()`` once, then ``client.step()`` ``num_steps``
    times. The value returned by each ``step()`` call is discarded.

    Args:
        i: Index of the client. The body does not use it; it is kept so
            existing callers that pass it positionally keep working.
        client: Object exposing ``prepare()`` and ``step()``.
        num_steps: Number of ``step()`` calls to make. Defaults to 100, the
            value that used to be hard-coded here.

    NOTE(review): presumably this must equal the number of steps the master
    performs, otherwise one side may wait on the other -- confirm against
    MiniFL.
    """
    client.prepare()
    for _ in range(num_steps):
        client.step()
        
def run_master(master: "mfl.algorithms.gd.GDMaster", num_steps: int = 100):
    """Drive the master through a fixed number of steps, printing each result.

    Calls ``master.prepare()`` once, then calls ``master.step()`` ``num_steps``
    times and prints every returned value prefixed with ``"Master: "``.

    Args:
        master: Object exposing ``prepare()`` and ``step()``.
        num_steps: Number of ``step()`` calls to make. Defaults to 100, the
            value that used to be hard-coded here.

    NOTE(review): presumably this must equal the number of steps every client
    performs, otherwise one side may wait on the other -- confirm against
    MiniFL.
    """
    master.prepare()
    for _ in range(num_steps):
        print(f"Master: {master.step()}")

In [3]:
from sklearn.datasets import load_svmlight_file
from copy import deepcopy

NUM_CLIENTS = 20

# Read the svmlight/libsvm-format file into a sparse feature matrix and a label vector.
data, labels = load_svmlight_file("phishing.txt")
# Labels are copied unchanged.
# NOTE(review): BCEWithLogitsLoss below expects targets in [0, 1]; presumably
# the file already stores 0/1 labels -- confirm.
enc_labels = labels.copy()
data_dense = data.todense()

# Full dataset as float32 tensors; `[:, None]` appends a trailing axis to the labels.
eval_data = (
    torch.from_numpy(data_dense).to(torch.float32),
    torch.from_numpy(enc_labels).to(torch.float32)[:, None],
)

# Ceiling division: rows per client, so every chunk except possibly the last has this size.
partition_size = (len(eval_data[0]) - 1) // NUM_CLIENTS + 1
clients_data = list(
    zip(
        torch.split(eval_data[0], partition_size, dim=0),
        torch.split(eval_data[1], partition_size, dim=0),
    )
)
# Ceil-sized chunks can produce fewer than NUM_CLIENTS partitions for some
# dataset sizes; fail here with a clear message instead of an IndexError below.
assert len(clients_data) == NUM_CLIENTS, (
    f"expected {NUM_CLIENTS} data partitions, got {len(clients_data)}"
)

# Single-output linear model without bias, one weight per input feature.
NETWORK = torch.nn.Linear(eval_data[0].shape[1], 1, bias=False)
LOSS = torch.nn.BCEWithLogitsLoss()

# The master function wraps NETWORK itself and is given the full dataset.
# NOTE(review): batch_size=-1 presumably means "use the full batch" -- confirm against MiniFL.
master_fn = mfl.fn.NNDifferentiableFn(
    model=NETWORK,
    data=eval_data,
    loss_fn=LOSS,
    batch_size=-1,
    seed=0,
)

# Each client gets its own deep copy of NETWORK, its own data partition and
# its own seed; the LOSS object is shared.
client_fns = [
    mfl.fn.NNDifferentiableFn(
        model=deepcopy(NETWORK),
        data=clients_data[i],
        loss_fn=LOSS,
        batch_size=-1,
        seed=i,
    )
    for i in range(NUM_CLIENTS)
]

master, clients = mfl.algorithms.get_gd_master_and_clients(
    master_fn=master_fn,
    client_fns=client_fns,
    gamma=1,
    # Optional keyword arguments left disabled.
    # NOTE(review): presumably compression settings -- confirm against MiniFL.
    # rand_p=0.5,
    # top_p=0.5,
    # p=0.1,
)

In [4]:
import threading

# One thread per client, each running `run_client` with the client's index.
client_threads = [
    threading.Thread(target=run_client, args=(idx, worker))
    for idx, worker in enumerate(clients)
]
for thread in client_threads:
    thread.start()

# The master thread is started only after every client thread has been started.
master_thread = threading.Thread(target=run_master, args=(master,))
master_thread.start()

# Wait for the master first, then for each client in launch order.
master_thread.join()
for thread in client_threads:
    thread.join()
    


Master: 0.6919958591461182
Master: 0.6855571866035461
Master: 0.679527997970581
Master: 0.6738094091415405
Master: 0.6683315634727478
Master: 0.663045346736908
Master: 0.6579157710075378
Master: 0.6529179811477661
Master: 0.6480341553688049
Master: 0.6432514786720276
Master: 0.6385605931282043
Master: 0.6339545845985413
Master: 0.6294281482696533
Master: 0.6249774098396301
Master: 0.620599091053009
Master: 0.616290807723999
Master: 0.6120502352714539
Master: 0.607875645160675
Master: 0.6037653684616089
Master: 0.5997180938720703
Master: 0.5957323908805847
Master: 0.5918070077896118
Master: 0.5879408717155457
Master: 0.5841327905654907
Master: 0.5803816318511963
Master: 0.5766863226890564
Master: 0.5730459690093994
Master: 0.5694595575332642
Master: 0.5659260153770447
Master: 0.5624443888664246
Master: 0.5590137839317322
Master: 0.5556332468986511
Master: 0.5523019433021545
Master: 0.5490188598632812
Master: 0.545783281326294
Master: 0.5425941348075867
Master: 0.5394508242607117
Master:

In [5]:
# Display the `n_bits_passed` counter held by the first client's `data_sender`.
# NOTE(review): presumably the number of bits this client sent during the run
# above -- confirm against MiniFL.
clients[0].data_sender.n_bits_passed

217600