### Imports & Settings

In [1]:
from collections import Counter
import time

### Model classes

In [2]:
#Server 
class FairServer():
    def __init__(self):
        self.model = LearningModel()
        self.workers = set()
        self.groups = Counter()
        
    def register_worker(self, worker):
        self.workers.add(worker)
        
    def register_feature_group(self, group):
        self.groups.update(group)

    def send_model(self, worker):
        worker.load_model(self.model)
        
    def load_weights(self, W):
        self.model.set_weights(W)
        
    def request_training(self, w):
        w.train()
    
    def train(self):
        for w in self.workers:
            self.send_model(w)
            self.request_training(w)
            print(self.model.W)
            
    def fair_metrics(self):
        return 'Fair metrics: '

In [3]:
#Worker 
class Worker():
    def __init__(self, server, data):
        self.server = server
        self.model = {}
        self.data = data
        
    def send_registration(self):
        self.server.register_worker(self)
        
    def send_feature_group(self):
        self.server.register_feature_group('test')
        
    def load_model(self, model):
        self.model = model
        
    def send_weights(self, W):
        self.server.load_weights(W)
        
    def train(self):
        self.model.train(self.data)
        self.send_weights(self.model.W)
        self.send_feature_group()

In [4]:
class LearningModel:
    def __init__(self):
        self.W = [0, 0]
        pass
    
    def set_weights(self, W):
        self.W = W
    
    def train(self, X=None, y=None):
        self.W += [1, 1]
    
    def predict(self, x):
        pass 

### Utils functions

In [5]:
#Create a server
def create_server():
    return FairServer()

#Create workers
def create_workers(server, data=None):
    return[Worker(server, {}) for i in range(5)]

#Register all workers
def register_workers(workers):
    for worker in workers:
        worker.send_registration()
        

def start_training(server):
    start_train = time.time()
    server.train()
    end_train = time.time()
    print('Training time:' + str(end_train - start_train))
    
def znp_fed_training(model=None, data=None):
    s = create_server()
    ws = create_workers(s)
    register_workers(ws)
    start_training(s)
    print(s.fair_metrics())

In [6]:
znp_fed_training()

[0, 0, 1, 1]
[0, 0, 1, 1, 1, 1]
[0, 0, 1, 1, 1, 1, 1, 1]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Training time:0.003000020980834961
Fair metrics: 
