In [5]:
import torch
from torch import Tensor, nn
from typing import Collection, Tuple, Mapping
from queue import SimpleQueue
from copy import deepcopy

class BasicCompressor:
    def __init__(self, shapes: Mapping[str, torch.Size]):
        self.shapes = shapes
    
    def compress(self, grad_dict: Mapping[str, Tensor]) -> Collection[Tensor]:
        return (torch.cat(tuple(grad_dict[name].flatten() for name in sorted(grad_dict))),)
                    
    def decompress(self, data: Collection[Tensor]) -> Mapping[str, Tensor]:
        x = data[0]
        grad_dict = {}
        for name in sorted(self.shapes):
            shape = self.shapes[name]
            grad_dict[name] = x[:shape.numel()].view(*shape)
            x = x[shape.numel():]
        return grad_dict
    
    
def get_compressor(model: nn.Module) -> BasicCompressor:
    return BasicCompressor(shapes={k: v.shape for k, v in model.named_parameters()})
    
def get_num_bits(dtype: torch.dtype) -> int:
    if dtype.is_floating_point:
        return torch.finfo(dtype).bits
    else:
        return torch.iinfo(dtype).bits

class DataSender:
    def __init__(self, queue: SimpleQueue) -> None:
        self.queue = queue
        self.n_bits_passed = 0
        
    def send(self, data: Collection[Tensor]):
        for tensor in data:
            self.n_bits_passed += get_num_bits(tensor.dtype) * tensor.numel()
            
        self.queue.put(data)
        

class DataReceiver:
    def __init__(self, queue: SimpleQueue) -> None:
        self.queue = queue
        
    def recv(self) -> Collection[Tensor]:
        return self.queue.get()
    
def get_sender_receiver() -> Tuple[DataSender, DataReceiver]:
    queue = SimpleQueue()
    return DataSender(queue=queue), DataReceiver(queue=queue)
        
        

def get_grad_dict(module: nn.Module) -> Mapping[str, Tensor]:
    return {k:v.grad.detach() for k,v in module.named_parameters()}

def add_grad_dict(module: nn.Module, grad_dict: Mapping[str, Tensor]):
    for k,v in module.named_parameters():
        v.grad = grad_dict[k]

class Client:
    def __init__(self, lr: float, data: Tuple[Tensor, Tensor], model: nn.Module, loss_fn, data_sender: DataSender, data_receiver: DataReceiver, compressor: BasicCompressor):
        self.data = data
        self.model = model
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        self.loss_fn = loss_fn
        self.data_sender = data_sender
        self.data_receiver = data_receiver
        self.compressor = compressor
        
    def send_grad_get_loss(self):
        self.optimizer.zero_grad()
        
        loss = self.loss_fn(self.model(self.data[0]), self.data[1])
        loss.backward()
            
        grad_dict = get_grad_dict(self.model)
        msg = self.compressor.compress(grad_dict=grad_dict)
        self.data_sender.send(msg)
        return float(loss)
        
        
    def apply_global_step(self):
        msg = self.data_receiver.recv()
        grad_dict = self.compressor.decompress(msg)
        
        self.optimizer.zero_grad()
        add_grad_dict(self.model, grad_dict=grad_dict)
        self.optimizer.step()
        
        
class Master:
    def __init__(self, lr: float, eval_data: Tuple[Tensor, Tensor],  model: nn.Module, data_senders: Collection[DataSender], data_receivers: Collection[DataReceiver], compressors: Collection[BasicCompressor], loss_fn):
        self.eval_data = eval_data
        self.model = model
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr)
        
        self.data_senders = data_senders
        self.data_receivers = data_receivers
        self.compressors = compressors
        self.loss_fn = loss_fn
        
    def scale_grads(self, scale: float):
        for v in self.model.parameters():
            v.grad *= scale
    
    def round(self) -> float:
        self.model.zero_grad()
        for receiver, compressor in zip(self.data_receivers, self.compressors):
            msg = receiver.recv()
            grad_dict = compressor.decompress(msg)
            add_grad_dict(self.model, grad_dict=grad_dict)    
        self.scale_grads(1 / len(self.data_senders))
        self.optimizer.step()
        
        for sender, compressor in zip(self.data_senders, self.compressors):
            grad_dict = get_grad_dict(self.model)
            msg = compressor.compress(grad_dict=grad_dict)
            sender.send(msg)
            
        with torch.no_grad():
            return float(self.loss_fn(self.model(self.eval_data[0]), self.eval_data[1]))
            
            
def get_master_and_clients(lr: float, clients_data: Collection[Collection[Tuple[Tensor, Tensor]]], eval_data: Collection[Tuple[Tensor, Tensor]], model: nn.Module, loss_fn) -> Tuple[Master, Collection[Client]]:
    num_clients = len(clients_data)
    
    uplink_comms = [get_sender_receiver() for _ in range(num_clients)]
    downlink_comms = [get_sender_receiver() for _ in range(num_clients)]
    compressors = [get_compressor(model=model) for _ in range(num_clients)]
    client_models = [deepcopy(model) for _ in range(num_clients)]
    
    master = Master(
        lr=lr,
        eval_data=eval_data,
        model=model,
        data_senders=[s for s, r in downlink_comms],
        data_receivers=[r for s, r in uplink_comms],
        compressors=compressors,
        loss_fn=loss_fn,
    )
    
    clients = []
    for i in range(num_clients):
        client = Client(
            lr=lr,
            data=clients_data[i],
            model=client_models[i],
            loss_fn=loss_fn,
            data_sender=uplink_comms[i][0],
            data_receiver=downlink_comms[i][1],
            compressor=compressors[i],
        )
        clients.append(client)
    
    return master, clients
        

In [6]:
def run_client(i:int, client: Client):
    for _ in range(10):
        loss = client.send_grad_get_loss()
        # print(f"Client {i}: {loss}")
        client.apply_global_step()
        
def run_master(master: Master):
    for i in range(10):
        print(f"Master: {master.round()}")

In [7]:
from sklearn.datasets import load_svmlight_file

NUM_CLIENTS = 20

data, labels = load_svmlight_file("phishing.txt")
enc_labels = labels.copy()
data_dense = data.todense()

eval_data = (torch.from_numpy(data_dense).to(torch.float32), torch.from_numpy(enc_labels).to(torch.float32)[:, None])
clients_data = [(x, y) for x, y in zip(torch.split(eval_data[0], len(eval_data[0]) // NUM_CLIENTS, dim=0), torch.split(eval_data[1], len(eval_data[1]) // NUM_CLIENTS, dim=0))]

model = torch.nn.Linear(eval_data[0].shape[1], 1, bias=False)
loss_fn = torch.nn.BCEWithLogitsLoss()

master, clients = get_master_and_clients(
    lr=200,
    clients_data=clients_data,
    eval_data=eval_data,
    model=model,
    loss_fn=loss_fn,
)

In [8]:
import threading

client_threads = []
for i, client in enumerate(clients):
    client_threads.append(threading.Thread(target=run_client, args=(i, client)))
    client_threads[-1].start()
    
master_thread = threading.Thread(target=run_master, args=(master,))
master_thread.start()

master_thread.join()
for t in client_threads:
    t.join()
    


Master: 0.6558125615119934
Master: 0.611517608165741
Master: 0.5869171619415283
Master: 0.5614525675773621
Master: 0.5421739816665649
Master: 0.5248126983642578
Master: 0.5100699663162231
Master: 0.49705731868743896
Master: 0.4855799674987793
Master: 0.4753459393978119
