### Imports & Settings

In [1]:
import hashlib
import random
import time
from collections import Counter

import numpy as np

import nn_model as nn
from utils import get_rnd_ethnicity
from utils import properties as p

### Parameters

In [2]:
epochs_n = 1          # epochs per fit() call (reference loop and LearningModel default)
learning_rate = 0.1   # learning rate per fit() call (reference loop and LearningModel default)
seed = 19101995

# `seed` was previously defined but never applied, so the random ethnicity labels
# (and any random weight init) changed on every run. Seed both RNGs because it is
# not visible here which one nn_model / get_rnd_ethnicity draw from (TODO confirm).
random.seed(seed)
np.random.seed(seed)

### Model classes

In [3]:
# Server
class FairServer:
    """Central server of the simulated federated setup.

    Holds a single LearningModel wrapping `model`, the list of registered
    workers, and a Counter of the (hashed) feature groups reported by workers.
    """

    def __init__(self, model):
        self.model = LearningModel(model)
        self.workers = []
        self.groups = Counter()  # hashed feature group -> number of reports

    def register_worker(self, worker):
        """Append `worker` to the workers that take part in training."""
        self.workers.append(worker)

    def register_feature_group(self, group):
        """Count one more report of the (hashed) feature `group`."""
        self.groups[group] += 1

    def send_model(self, worker):
        """Hand the server's model object itself (not a copy) to `worker`."""
        worker.load_model(self.model)

    def load_weights(self, W):
        """Overwrite the server model's weights with `W`."""
        self.model.set_weights(W)

    def request_training(self, w):
        """Ask worker `w` to run its local training step."""
        w.train()

    def train(self):
        """Run one round: every registered worker, in registration order,
        receives the model and then trains it."""
        for participant in self.workers:
            self.send_model(participant)
            self.request_training(participant)

    def fair_metrics(self):
        """Placeholder: currently returns only a fixed label string."""
        return 'Fair metrics: '

    def predict(self, x):
        """Predict `x` with the server's current model."""
        return self.model.predict(x)

In [4]:
# Worker
class Worker():
    """A data-holding participant in the simulated federated setup.

    Keeps its local training data (`x`, `y`) and one sensitive attribute
    (`secret_feature`); the attribute is only sent to the server as a SHA-256
    hex digest.
    """

    def __init__(self, server, x, y, secret_feature):
        self.server = server
        # No model until the server calls load_model(). None (instead of an empty
        # dict) makes "not loaded yet" explicit and lets train() fail clearly.
        self.model = None
        self.x = x
        self.y = y
        self.secret_feature = secret_feature

    def send_registration(self):
        """Register this worker with its server."""
        self.server.register_worker(self)

    def send_feature_group(self):
        """Report this worker's feature group to the server as a SHA-256 hex digest.

        NOTE(review): an unsalted hash of a low-cardinality attribute (such as an
        ethnicity label) is only pseudonymisation - anyone who can enumerate the
        candidate labels can invert it by hashing each one.
        """
        digest = hashlib.sha256(self.secret_feature.encode('utf-8')).hexdigest()
        self.server.register_feature_group(digest)

    def load_model(self, model):
        """Store the model object provided by the server."""
        self.model = model

    def send_weights(self, W):
        """Push weights `W` back to the server."""
        self.server.load_weights(W)

    def train(self):
        """Train the loaded model on local data, then send the resulting weights
        and this worker's hashed feature group to the server.

        Raises:
            RuntimeError: if no model has been loaded via load_model().
        """
        if self.model is None:
            raise RuntimeError('Worker.train() called before load_model()')
        self.model.train(self.x, self.y)
        self.send_weights(self.model.get_weights())
        # Reported on every train() call, so each additional training round
        # increments the server-side count for this group again.
        self.send_feature_group()

In [5]:
class LearningModel:
    """Thin wrapper pairing an nn_model network with its training settings.

    Args:
        model: network object exposing fit / predict / get_weights / load_weights.
        epochs: epochs per train() call.
        learning_rate: learning rate passed to fit().

    The defaults are read from the notebook's `epochs_n` / `learning_rate` when
    this cell runs; re-run the cell after changing those parameters.
    """
    def __init__(self, model, epochs=epochs_n, learning_rate=learning_rate):
        self.model = model
        self.epochs = epochs
        self.learning_rate = learning_rate

    def set_weights(self, W):
        """Load weights `W` into the wrapped network."""
        self.model.load_weights(W)

    def get_weights(self):
        """Return the wrapped network's current weights."""
        return self.model.get_weights()

    def train(self, X, y):
        """Fit the wrapped network on (X, y) with the stored settings."""
        # Pass the settings by keyword, exactly as the reference training loop does,
        # so they cannot be misbound if fit()'s positional parameter order differs.
        self.model.fit(X, y, epochs=self.epochs, learning_rate=self.learning_rate)

    def predict(self, x):
        """Return the wrapped network's predictions for `x`."""
        return self.model.predict(x)

### Utility functions

In [6]:
def create_server(model):
    """Build and return a FairServer that wraps `model`."""
    server = FairServer(model)
    return server

def create_workers(server, X, y, secret_feature):
    """Create one Worker per sample, all bound to `server`.

    Worker i receives X[i], y[i] and secret_feature[i].

    Raises:
        ValueError: if X, y and secret_feature differ in length. Without this
            check, extra labels or features were silently ignored, and a shorter
            y / secret_feature raised IndexError part-way through.
    """
    if not (len(X) == len(y) == len(secret_feature)):
        raise ValueError(
            'X, y and secret_feature must have the same length, got {}, {} and {}'
            .format(len(X), len(y), len(secret_feature))
        )
    return [Worker(server, X[i], y[i], secret_feature[i]) for i in range(len(X))]

def register_workers(workers):
    """Have each worker in `workers`, in order, register itself with its server."""
    for each in workers:
        each.send_registration()
        
def start_training(server):
    """Run `server.train()` and return the elapsed wall-clock time in seconds.

    Previously the timestamps were taken but discarded. Returning the duration
    lets callers report it, and callers that ignore the return value are unaffected.
    """
    start = time.perf_counter()  # monotonic, high-resolution clock for intervals
    server.train()
    return time.perf_counter() - start
    
def znp_fed_training(model, X, y, secret):
    """Simulate one federated training round.

    Builds a server around `model`, creates and registers one worker per sample
    (with its entry of `secret` as the secret feature), trains, and returns
    ``(server, workers)``.
    """
    server = create_server(model)
    workers = create_workers(server, X, y, secret)
    register_workers(workers)
    start_training(server)
    return server, workers

### Reference model (centralized, not federated)

In [7]:
# Reference network (not federated). It is trained one sample at a time, in the
# same order the federated workers train, so the two sets of predictions can be
# compared exactly. Layer sizes presumably mean (inputs, hidden, outputs); confirm in nn_model.
net = nn.get_3l_nn(2, 2, 1)

# XOR truth table; each sample keeps the extra leading axes that fit() is given.
X = np.array([[[[0, 0]]], [[[0, 1]]], [[[1, 0]]], [[[1, 1]]]])
y = np.array([[[[0]]], [[[1]]], [[[1]]], [[[0]]]])

# Train: one fit() call per sample.
for x_sample, y_sample in zip(X, y):
    net.fit(x_sample, y_sample, epochs=epochs_n, learning_rate=learning_rate)

# Predict on all samples.
out = net.predict(X)

### Federated-learning model

In [8]:
# Federated network: built the same way as the reference, then trained through
# the server/worker simulation with one worker per XOR sample.
net2 = nn.get_3l_nn(2, 2, 1)

# Each worker gets an ethnicity label from get_rnd_ethnicity() (presumably random)
# as its secret feature.
secret_labels = [get_rnd_ethnicity() for _ in range(len(X))]
s, ws = znp_fed_training(net2, X, y, secret_labels)

# Predict on all samples with the server's model.
out2 = s.predict(X)

### Check that both models produce identical predictions

In [9]:
# Both networks see the same samples in the same order with the same epochs and
# learning rate, so (assuming get_3l_nn builds identically initialised networks)
# their predictions should match exactly. Fail loudly instead of only displaying False.
models_match = np.array_equal(out, out2)
assert models_match, 'federated and reference predictions differ'
models_match

True

### Check that the server's hashed feature groups match the workers' secret feature (ethnicity)

In [13]:
def encode_sha256(s):
    """Return the hex SHA-256 digest of the UTF-8 encoding of `s`.

    This is the same hashing Worker.send_feature_group applies before reporting
    a feature group to the server.
    """
    return hashlib.sha256(s.encode('utf-8')).hexdigest()

In [15]:
# Recompute, client-side, the hashed feature groups the server should have
# received: one SHA-256 digest per worker's secret feature.
w_groups = [w.secret_feature for w in ws]
h_groups = [encode_sha256(label) for label in w_groups]
cnt = Counter(h_groups)

# Compare the two counters for equality. The previous check,
# `s.groups & cnt == s.groups`, only tested that s.groups is contained in cnt
# (`&` keeps the per-key minimum), so a server that missed a report still passed.
groups_match = s.groups == cnt
assert groups_match, "server feature-group counts differ from the workers' features"
groups_match

(['Indian', 'Other', 'White', 'Indian'],
 Counter({'67081fb2a08a41f6e8e8b79e6319f9bb5de5d30a84a681ca22fbea019b75114d': 2,
          'f97e9da0e3b879f0a9df979ae260a5f7e1371edb127c1862d4f861981166cdc1': 1,
          '3495e757855a5c678addcf32516274e2962d0572f065378dba689e22168f28dd': 1}),
 Counter({'67081fb2a08a41f6e8e8b79e6319f9bb5de5d30a84a681ca22fbea019b75114d': 2,
          'f97e9da0e3b879f0a9df979ae260a5f7e1371edb127c1862d4f861981166cdc1': 1,
          '3495e757855a5c678addcf32516274e2962d0572f065378dba689e22168f28dd': 1}))