In [1]:
import torch
import MiniFL as mfl

In [2]:
def run_client(i: int, client: "mfl.algorithms.gd.GDClient", num_steps: int = 100):
    """Drive one federated client for a fixed number of rounds.

    Calls ``client.prepare()`` once, then ``client.step()`` ``num_steps`` times,
    discarding each step's return value.

    Args:
        i: Index of the client. Not used by the function; it is accepted so callers
            can pass it through ``threading.Thread(args=(i, client))``.
        client: Object exposing ``prepare()`` and ``step()``. NOTE(review): the
            annotation names the GD client, but this notebook passes clients built
            by ``get_marina_master_and_clients`` — confirm the intended type.
        num_steps: Number of ``step()`` calls. Presumably this must equal the
            master's step count, otherwise one side would be left waiting for the
            other — confirm against MiniFL.
    """
    client.prepare()
    for _ in range(num_steps):
        client.step()
        
def run_master(master: "mfl.algorithms.gd.GDMaster", num_steps: int = 100):
    """Drive the federated master for a fixed number of rounds, printing each result.

    Calls ``master.prepare()`` once, then ``master.step()`` ``num_steps`` times and
    prints every returned value as ``"Master: <value>"``.

    Args:
        master: Object exposing ``prepare()`` and ``step()``. NOTE(review): the
            annotation names the GD master, but this notebook passes the master
            built by ``get_marina_master_and_clients`` — confirm the intended type.
        num_steps: Number of ``step()`` calls. Presumably this must equal every
            client's step count — confirm against MiniFL.

    Returns:
        List of the values returned by ``master.step()``, in call order, so they can
        be inspected or plotted after the run.
    """
    master.prepare()
    history = []
    for _ in range(num_steps):
        value = master.step()
        print(f"Master: {value}")
        history.append(value)
    return history

In [7]:
from sklearn.datasets import load_svmlight_file
import numpy as np

# --- configuration -----------------------------------------------------------
DATA_PATH = "phishing.txt"  # LIBSVM-format file, expected in the working directory
NUM_CLIENTS = 20
SEED = 0
GAMMA = 100  # passed as `gamma` to the MARINA factory (presumably the step size — confirm in MiniFL)
P = 0.1      # passed as `p` to the MARINA factory (presumably the full-sync probability — confirm)

# Makes the Linear layer's random initialisation reproducible across runs.
torch.manual_seed(SEED)

# --- load data ---------------------------------------------------------------
data, labels = load_svmlight_file(DATA_PATH)

# BCEWithLogitsLoss needs targets in {0, 1}. LIBSVM datasets use differing label
# conventions (e.g. {0, 1} vs {-1, +1}), so map the smaller label to 0 and the
# larger to 1 explicitly; for {0, 1} labels this is the identity.
label_values = np.unique(labels)
assert len(label_values) == 2, f"expected a binary dataset, got labels {label_values}"
enc_labels = (labels == label_values[1]).astype(np.float32)

# `.toarray()` returns a plain ndarray; `.todense()` returns the discouraged np.matrix.
features = torch.from_numpy(data.toarray()).to(torch.float32)
targets = torch.from_numpy(enc_labels)[:, None]  # column vector to match the (N, 1) model output
eval_data = (features, targets)  # NOTE: evaluation uses the full (training) dataset

# --- partition among clients -------------------------------------------------
# tensor_split always yields exactly NUM_CLIENTS contiguous shards (sizes differ by
# at most one). Splitting by a ceil'd chunk size with torch.split can yield fewer
# shards (e.g. 9 rows / 4 clients -> 3 shards), which would no longer line up with
# the NUM_CLIENTS compressors created below.
clients_data = list(zip(
    torch.tensor_split(features, NUM_CLIENTS, dim=0),
    torch.tensor_split(targets, NUM_CLIENTS, dim=0),
))
assert len(clients_data) == NUM_CLIENTS

# --- model and algorithm -----------------------------------------------------
model = torch.nn.Linear(features.shape[1], 1, bias=False)
loss_fn = torch.nn.BCEWithLogitsLoss()

master, clients = mfl.algorithms.get_marina_master_and_clients(
    gamma=GAMMA,
    clients_data=clients_data,
    eval_data=eval_data,
    model=model,
    loss_fn=loss_fn,
    compressors=[
        mfl.compressors.PermKUnbiasedCompressor(rank=rank, world_size=NUM_CLIENTS)
        for rank in range(NUM_CLIENTS)
    ],
    p=P,
)

In [8]:
import threading

# One thread per client, then one for the master; each thread runs the
# prepare()/step() loop of run_client / run_master. A comprehension keeps the
# loop variables out of the notebook namespace, and thread names make any
# "Exception in thread ..." traceback identify which participant failed.
client_threads = [
    threading.Thread(target=run_client, args=(client_idx, fl_client), name=f"client-{client_idx}")
    for client_idx, fl_client in enumerate(clients)
]
for worker in client_threads:
    worker.start()

master_thread = threading.Thread(target=run_master, args=(master,), name="master")
master_thread.start()

# Wait for every participant to finish.
# NOTE(review): if a client thread raises, its traceback is printed but the
# master may keep waiting on it, so master_thread.join() could block
# indefinitely — confirm how MiniFL handles a dead peer.
master_thread.join()
for worker in client_threads:
    worker.join()
    


Master: 0.6530994772911072
Master: 0.662650465965271
Master: 0.6606312394142151
Master: 0.6377838850021362
Master: 0.6397221088409424
Master: 0.6394876837730408
Master: 0.6383771300315857
Master: 0.636985182762146
Master: 0.6371131539344788
Master: 0.6359836459159851
Master: 0.6359114646911621
Master: 0.6359698176383972
Master: 0.63599693775177
Master: 0.6149914264678955
Master: 0.6144819855690002
Master: 0.6157960891723633
Master: 0.615833580493927
Master: 0.6157559752464294
Master: 0.61569744348526
Master: 0.6156364679336548
Master: 0.6156364679336548
Master: 0.6156404614448547
Master: 0.6156342029571533
Master: 0.6156347393989563
Master: 0.6156372427940369
Master: 0.6156644821166992
Master: 0.6156615614891052
Master: 0.6156803965568542
Master: 0.5965217351913452
Master: 0.5969430804252625
Master: 0.5968382954597473
Master: 0.5968732833862305
Master: 0.5970098972320557
Master: 0.596825361251831
Master: 0.5968292951583862
Master: 0.5968153476715088
Master: 0.5791747570037842
Master: 0

In [9]:
# NOTE(review): `n_bits_passed` presumably counts the bits client 0 has sent
# through its data sender during the run above — confirm in MiniFL.
# Only client 0 is inspected here; other clients' counts may differ.
clients[0].data_sender.n_bits_passed

31424