In [1]:
import torch
import MiniFL as mfl

In [2]:
def run_client(i: int, client: "mfl.algorithms.gd.GDClient", num_steps: int = 100) -> None:
    """Prepare one client and drive it through a fixed number of steps.

    Used in this notebook as a ``threading.Thread`` target, one thread per client.

    Args:
        i: Index of the client. The body does not use it; it is kept so that
            existing callers passing ``(i, client)`` keep working.
        client: Object exposing ``prepare()`` and ``step()``.
            NOTE(review): annotated as ``GDClient``, but the notebook passes
            the clients returned by ``get_permk_marina_master_and_clients`` --
            confirm the annotation is accurate.
        num_steps: How many times ``client.step()`` is called. Defaults to
            100, the value that was previously hard-coded. Presumably this has
            to equal the master's step count -- TODO confirm.
    """
    client.prepare()
    for _ in range(num_steps):
        # The value returned by step() is deliberately discarded on the client side.
        client.step()
        
def run_master(master: "mfl.algorithms.gd.GDMaster", num_steps: int = 100) -> None:
    """Prepare the master and run it for a fixed number of steps, printing each result.

    Args:
        master: Object exposing ``prepare()`` and ``step()``.
            NOTE(review): annotated as ``GDMaster``, but the notebook passes
            the master returned by ``get_permk_marina_master_and_clients`` --
            confirm the annotation is accurate.
        num_steps: How many times ``master.step()`` is called. Defaults to
            100, the value that was previously hard-coded. Presumably this has
            to equal the clients' step count -- TODO confirm.
    """
    master.prepare()
    for _ in range(num_steps):
        # Whatever step() returns is printed verbatim after the "Master: " prefix.
        print(f"Master: {master.step()}")

In [3]:
from sklearn.datasets import load_svmlight_file
from copy import deepcopy

NUM_CLIENTS = 20

# Read the svmlight/LIBSVM-format file: `data` is a sparse feature matrix,
# `labels` the matching target vector.
data, labels = load_svmlight_file("phishing.txt")
# NOTE(review): the labels are fed to BCEWithLogitsLoss unchanged, which
# expects targets in [0, 1] -- confirm the file encodes classes as {0, 1}.
enc_labels = labels.copy()
# toarray() gives a plain 2-D ndarray; todense() would return the discouraged
# np.matrix subclass.
data_dense = data.toarray()

# Full dataset as (features, targets); targets get a trailing axis so their
# shape is (n_samples, 1), matching the single-output linear model.
eval_data = (
    torch.from_numpy(data_dense).to(torch.float32),
    torch.from_numpy(enc_labels).to(torch.float32)[:, None],
)

# Ceiling division: NUM_CLIENTS chunks of this size are enough to cover every sample.
partition_size = (len(eval_data[0]) - 1) // NUM_CLIENTS + 1
# Contiguous, row-aligned (features, targets) shards, one per client.
clients_data = list(
    zip(
        torch.split(eval_data[0], partition_size, dim=0),
        torch.split(eval_data[1], partition_size, dim=0),
    )
)
# Splitting by a ceiling-sized chunk can yield fewer chunks than NUM_CLIENTS
# when the dataset is small; fail here with a clear message rather than with
# an IndexError while building client_fns below.
assert len(clients_data) == NUM_CLIENTS, (
    f"expected {NUM_CLIENTS} client partitions, got {len(clients_data)}"
)

# Bias-free linear model with one output logit per sample.
# NOTE(review): no torch seed is set in the visible cells, so the initial
# weights (and therefore the printed losses) presumably vary between runs.
NETWORK = torch.nn.Linear(eval_data[0].shape[1], 1, bias=False)
LOSS = torch.nn.BCEWithLogitsLoss()

# The master wraps NETWORK itself and the full dataset.
master_fn = mfl.fn.NNDifferentiableFn(
    model=NETWORK,
    data=eval_data,
    loss_fn=LOSS,
    batch_size=-1,
    seed=0,
)

# Each client gets its own deep copy of the model, its own shard, and its own seed.
client_fns = [
    mfl.fn.NNDifferentiableFn(
        model=deepcopy(NETWORK),
        data=clients_data[i],
        loss_fn=LOSS,
        batch_size=-1,
        seed=i,
    )
    for i in range(NUM_CLIENTS)
]

master, clients = mfl.algorithms.get_permk_marina_master_and_clients(
    master_fn=master_fn,
    client_fns=client_fns,
    gamma=1,
    # rand_p=0.5,
    # top_p=0.5,
    p=0.01,
)

In [4]:
import threading

# Build one worker thread per client, then launch them in client-index order.
client_threads = [
    threading.Thread(target=run_client, args=(i, client))
    for i, client in enumerate(clients)
]
for t in client_threads:
    t.start()

# The master runs on its own thread, started after all the clients.
master_thread = threading.Thread(target=run_master, args=(master,))
master_thread.start()

# Block until the master finishes, then until every client thread has finished.
master_thread.join()
for t in client_threads:
    t.join()
    


Master: 0.6802643537521362
Master: 0.6745249629020691
Master: 0.669035792350769
Master: 0.6637405753135681
Master: 0.6586084365844727
Master: 0.6536110043525696
Master: 0.6487317681312561
Master: 0.6439497470855713
Master: 0.6392579674720764
Master: 0.634649932384491
Master: 0.6301199197769165
Master: 0.6256650686264038
Master: 0.6212836503982544
Master: 0.6169703602790833
Master: 0.6127251982688904
Master: 0.6085468530654907
Master: 0.6044349670410156
Master: 0.600385308265686
Master: 0.5963951945304871
Master: 0.5924663543701172
Master: 0.5885939598083496
Master: 0.5847803354263306
Master: 0.5810237526893616
Master: 0.5773242115974426
Master: 0.5736814141273499
Master: 0.5700929164886475
Master: 0.566557765007019
Master: 0.5630752444267273
Master: 0.559644341468811
Master: 0.5562624335289001
Master: 0.5529279112815857
Master: 0.5496391654014587
Master: 0.5463957786560059
Master: 0.5431985259056091
Master: 0.5400471687316895
Master: 0.5369425415992737
Master: 0.5338832139968872
Master

In [5]:
# Show the `n_bits_passed` counter on the first client's `data_sender`
# (presumably the number of bits that client has sent so far -- TODO confirm).
clients[0].data_sender.n_bits_passed

16960