### Imports & Settings

In [1]:
from collections import Counter
import time
import numpy as np
import nn_model as nn
from utils import properties as p
from utils import get_rnd_ethnicity
import hashlib

In [2]:
from noknow.core import ZK, ZKSignature, ZKParameters, ZKData, ZKProof
import copy

### Parameters

In [3]:
# Training hyper-parameters; LearningModel reads epochs_n and learning_rate
# as the defaults of its constructor.
epochs_n = 1
learning_rate = 0.1
seed = 19101995  # NOTE(review): not referenced by any cell visible in this notebook

# Labels that are registered with the ZK server, grouped by feature name.
features = dict(
    ETHNICITY=['ASIAN', 'BLACK', 'INDIAN', 'WHITE'],
    GENDER=['FEMALE', 'MALE'],
)

# Labels that are never registered; ZKPFEDFramework.test_server uses them to
# check that unregistered clients are rejected.
not_features = dict(
    ETHNICITY=['ROCK', 'PAPER', 'SCISSORS', 'UNCATEGORIZED'],
    GENDER=['ZERO', 'ONE'],
)

### ZKP Server

In [4]:
class ZKServer:
    """Server side of the zero-knowledge handshake built on ``noknow``.

    For every feature name the server keeps the loaded signatures of the
    clients registered for it and the signed tokens it issued for them.
    A proof is accepted when at least one registered client of the feature
    verifies it.
    """

    def __init__(self, password, features):
        """Create the per-feature registries and the server's own signature.

        Args:
            password: Secret used to create the server signature and to sign
                the tokens handed out by ``create_token``.
            features: Iterable of feature names (iterating a dict yields its
                keys); one empty registry is created per name.
        """
        self.password = password
        self.features = features
        # One list of registered clients and one list of issued tokens per feature.
        self.client_representations = {name: [] for name in features}
        self.tokens = {name: [] for name in features}
        self.initialized_zk_features = False
        self.zk = ZK.new(curve_name="secp384r1", hash_alg="sha3_512")
        self.signature: ZKSignature = self.zk.create_signature(password)

    def load_client_signature(self, signature):
        """Return a ``ZK`` object built from the params of a dumped client signature."""
        client_signature = ZKSignature.load(signature)
        return ZK(client_signature.params)

    def create_token(self, feature, signature):
        """Register a client signature under ``feature`` and issue a signed token.

        Args:
            feature: Feature name; must be one of the names given to ``__init__``
                (an unknown name raises ``KeyError``).
            signature: Dumped client signature (the value of ``ZKClient.signature``).

        Returns:
            The server-signed token, dumped with ``":"`` as separator.
        """
        client_signature = ZKSignature.load(signature)
        client_zk = ZK(client_signature.params)
        # The token is generated by the client's ZK object and signed with the
        # server password, so verify_client can later check it against
        # self.signature.
        token = self.zk.sign(self.password, client_zk.token()).dump(separator=":")
        self.tokens[feature].append(token)
        self.client_representations[feature].append(
            {'signature': client_signature, 'zk': client_zk}
        )
        return token

    def verify_clients(self, feature, signature):
        """Check one dumped proof against every client registered for ``feature``.

        Args:
            feature: Feature name whose registered clients are tried.
            signature: Dumped proof received from a client; its ``data`` field
                is loaded as the ``":"``-separated token.

        Returns:
            True as soon as one registered client verifies the proof,
            False if none does (including when no client is registered).
        """
        proof = ZKData.load(signature)
        token = ZKData.load(proof.data, ":")
        # any() stops at the first match instead of verifying every client.
        return any(
            self.verify_client(proof, token, entry['signature'], entry['zk'])
            for entry in self.client_representations[feature]
        )

    def verify_client(self, proof, token, client_signature, client_zk):
        """Verify the token against the server signature and the proof against one client."""
        return self.zk.verify(token, self.signature) and \
                client_zk.verify(proof, client_signature, data=token)

    def verify(self, feature, signatures):
        """Return True if at least one of the dumped proofs in ``signatures`` is accepted."""
        return any(self.verify_clients(feature, proof) for proof in signatures)

    def init_zk(self, clients):
        """Register every client in ``clients`` once; later calls are no-ops.

        Each client must expose ``feature`` and ``signature`` attributes.
        """
        if self.initialized_zk_features:
            print('Server already initialized')
            return
        for client in clients:
            self.create_token(client.feature, client.signature)
        self.initialized_zk_features = True

### ZKP Client

In [5]:
class ZKClient:
    """Client side of the zero-knowledge handshake.

    A client belongs to one ``feature`` and uses its ``label`` as the secret
    from which its signature and proofs are derived.
    """

    def __init__(self, feature, label):
        """Store feature and label and derive the dumped signature for the label."""
        self.feature = feature
        self.label = label
        self.zk = ZK.new(curve_name="secp256k1", hash_alg="sha3_256")
        # The dumped signature is what gets handed to the server for registration.
        signature = self.zk.create_signature(label)
        self.signature = signature.dump()

    def set_label(self, label):
        """Replace the label and recompute the dumped signature from it."""
        self.label = label
        fresh_signature = self.zk.create_signature(label)
        self.signature = fresh_signature.dump()

    def create_proof(self, server_token):
        """Sign ``server_token`` with the current label and return the dumped proof."""
        return self.zk.sign(self.label, server_token).dump()

    def create_proofs(self, server_tokens):
        """Return one dumped proof per token, in the order the tokens were given."""
        proofs = []
        for token in server_tokens:
            proofs.append(self.create_proof(token))
        return proofs

### Server Initializer

In [6]:
class ServerInitializer:
    """Builds one client per (feature, label) pair from a prototype client."""

    def __init__(self, features, client_prototype):
        """Remember the feature -> labels mapping and the prototype to clone."""
        self.features = features
        self.client_prototype = client_prototype

    def create_clients(self):
        """Return a list with one shallow copy of the prototype per label.

        Each copy gets its ``feature`` attribute set and its label assigned
        through ``set_label``. Because the copy is shallow, attributes that
        are not reassigned stay shared with the prototype.
        """
        def _spawn(feature, label):
            clone = copy.copy(self.client_prototype)
            clone.feature = feature
            clone.set_label(label)
            return clone

        return [
            _spawn(feature, label)
            for feature, labels in self.features.items()
            for label in labels
        ]

### Client Initializer

In [7]:
class ClientInitializer:
    """One-shot factory for the blank client prototype."""

    def __init__(self):
        # Flipped to True by the first successful generate_client_prototype call.
        self.generated = False

    def generate_client_prototype(self):
        """Create the prototype client, at most once per initializer.

        Returns:
            A ``ZKClient('', '')`` on the first call; ``None`` on every
            later call.
        """
        if self.generated:
            return None
        self.generated = True
        # ZKClient is the client class defined in this notebook; it provides
        # feature / set_label / signature / create_proof(s), which are the
        # members ServerInitializer and ZKServer use on the prototype's copies.
        return ZKClient('', '')

### Model classes

In [8]:
#Server 
class FairServer(ZKServer):
    """Training coordinator: owns the shared model, the workers and group counts."""

    def __init__(self, model):
        """Wrap ``model`` in a LearningModel and start with no workers or groups.

        NOTE(review): ZKServer.__init__ is not called here, so the attributes
        it would create (zk, signature, tokens, ...) do not exist on a FairServer.
        """
        self.model = LearningModel(model)
        self.workers = []
        self.groups = Counter()

    def register_worker(self, worker):
        """Add ``worker`` to the list of workers trained by ``train``."""
        self.workers.append(worker)

    def register_feature_group(self, group):
        """Count one more occurrence of ``group`` (workers send a hash string)."""
        self.groups[group] += 1

    def send_model(self, worker):
        """Hand the server's LearningModel object to ``worker``."""
        worker.load_model(self.model)

    def load_weights(self, W):
        """Load the weights ``W`` into the server's model."""
        self.model.set_weights(W)

    def request_training(self, w):
        """Ask worker ``w`` to run its training step."""
        w.train()

    def train(self):
        """For each registered worker in order: send the model, then request training."""
        for worker in self.workers:
            self.send_model(worker)
            self.request_training(worker)

    def fair_metrics(self):
        """Return the fair-metrics report; currently only the fixed header string."""
        return 'Fair metrics: '

    def predict(self, x):
        """Return the server model's prediction for ``x``."""
        return self.model.predict(x)

In [9]:
#Worker 
class Worker(ZKClient):
    """Training participant holding one data sample and one secret feature value."""

    def __init__(self, server, x, y, secret_feature):
        """Store the server handle, the local sample (x, y) and the secret feature.

        NOTE(review): ZKClient.__init__ is not called here, so a Worker has no
        zk / label / signature attributes.
        """
        self.server = server
        self.model = {}  # placeholder until load_model supplies the real model
        self.x, self.y = x, y
        self.secret_feature = secret_feature

    def send_registration(self):
        """Register this worker with its server."""
        self.server.register_worker(self)

    def send_feature_group(self):
        """Report the SHA-256 hex digest of the UTF-8 encoded secret feature to the server."""
        digest = hashlib.sha256(self.secret_feature.encode('utf-8')).hexdigest()
        self.server.register_feature_group(digest)

    def load_model(self, model):
        """Replace the local model with ``model``."""
        self.model = model

    def send_weights(self, W):
        """Push the weights ``W`` to the server."""
        self.server.load_weights(W)

    def train(self):
        """Train on the local sample, then send the weights and the feature-group hash."""
        self.model.train(self.x, self.y)
        trained_weights = self.model.get_weights()
        self.send_weights(trained_weights)
        self.send_feature_group()

In [10]:
class LearningModel:
    """Thin adapter around a network object that pins the training hyper-parameters.

    The wrapped ``model`` must provide ``load_weights``, ``get_weights``,
    ``fit`` and ``predict``; those are the only methods called on it here.
    """
    def __init__(self, model, epochs=epochs_n, learning_rate=learning_rate):
        # The defaults are read from the module-level ``epochs_n`` and
        # ``learning_rate`` once, when this class statement is executed.
        self.model = model
        self.epochs = epochs
        self.learning_rate = learning_rate
    
    def set_weights(self, W):
        """Load the weights ``W`` into the wrapped model."""
        self.model.load_weights(W)
        
    def get_weights(self):
        """Return the wrapped model's current weights."""
        return self.model.get_weights()
    
    def train(self, X, y):
        """Fit the wrapped model on (X, y) with the stored epochs and learning rate.

        The two hyper-parameters are passed positionally, as the third and
        fourth arguments of ``fit``.
        """
        self.model.fit(X, y, self.epochs, self.learning_rate)
    
    def predict(self, x):
        """Return the wrapped model's prediction for ``x``."""
        return self.model.predict(x)

### ZKP-FED-Framework

In [11]:
class ZKPFEDFramework:
    """Wires a ZK server to clients generated from a feature map and self-tests it."""

    def __init__(self, features, not_features, data=None):
        """Build the server, register one client per authorized label, run the self-test.

        Args:
            features: Mapping of feature name -> list of authorized labels.
            not_features: Mapping of feature name -> list of labels that must be
                rejected; only used by ``test_server``. Its feature names must
                also be present in ``features``, because the self-test looks up
                the server tokens by feature name.
            data: Currently unused. Defaults to ``None`` so that no mutable
                object is shared between instances.

        Raises:
            AssertionError: If the self-test in ``test_server`` fails.
        """
        self.client_prototype = ClientInitializer().generate_client_prototype()
        self.features = features
        self.not_features = not_features  # only consumed by test_server
        # NOTE(review): the server secret is a hard-coded demo value.
        self.server = ZKServer('password', features)
        self.server_initializer = ServerInitializer(features, self.client_prototype)
        self.server.init_zk(self.server_initializer.create_clients())
        self.test_server()

    def test_server(self):
        """Check on a separate mock server that only registered labels are accepted.

        Raises:
            AssertionError: If an authorized client is rejected or an
                unauthorized client is accepted.
        """
        # Initializers for the authorized labels and for the unauthorized ones.
        authorized_init = ServerInitializer(self.features, self.client_prototype)
        unauthorized_init = ServerInitializer(self.not_features, self.client_prototype)

        # The mock server must be built from this instance's features, not from
        # a module-level variable that happens to be called ``features``.
        mock_server = ZKServer('password', self.features)

        # Register the authorized clients in the mock server.
        mock_server.init_zk(authorized_init.create_clients())

        # Fresh authorized clients prove themselves against every token of their feature.
        authorized_clients = authorized_init.create_clients()
        accepted = [
            mock_server.verify(
                client.feature,
                client.create_proofs(mock_server.tokens[client.feature]),
            )
            for client in authorized_clients
        ]

        # Unauthorized clients try the first token issued for their feature.
        unauthorized_clients = unauthorized_init.create_clients()
        rejected = [
            mock_server.verify_clients(
                client.feature,
                client.create_proof(mock_server.tokens[client.feature][0]),
            )
            for client in unauthorized_clients
        ]

        assert all(accepted), 'a registered client was denied access'
        assert not any(rejected), 'a non-registered client was granted access'

### Utils functions

In [12]:
#Create a server
def create_server(model):
    """Return a new FairServer constructed from ``model``."""
    return FairServer(model)

#Create workers
def create_workers(server, X, y, secret_feature):
    """Build one Worker per index of ``X``.

    Worker ``i`` receives ``X[i]``, ``y[i]`` and ``secret_feature[i]``, so
    ``y`` and ``secret_feature`` must be indexable up to ``len(X) - 1``.
    """
    workers = []
    for idx in range(len(X)):
        workers.append(Worker(server, X[idx], y[idx], secret_feature[idx]))
    return workers

#Register all workers
def register_workers(workers):
    """Call ``send_registration()`` on every worker, in iteration order."""
    for worker in workers:
        worker.send_registration()
        
def start_training(server):
    """Run ``server.train()`` and report how long it took.

    Args:
        server: Object exposing a ``train()`` method.

    Returns:
        Elapsed seconds spent inside ``server.train()``, as a float measured
        with the monotonic ``time.perf_counter`` clock.
    """
    started = time.perf_counter()
    server.train()
    return time.perf_counter() - started
    
def znp_fed_training(model, X, y, secret):
    """Create a server and its workers, register the workers and run one training round.

    Returns:
        The tuple ``(server, workers)``.
    """
    server = create_server(model)
    workers = create_workers(server, X, y, secret)
    register_workers(workers)
    start_training(server)
    return server, workers

### Model for reference

In [13]:
# Reference network, trained directly without the server/worker wrappers.
net = nn.get_3l_nn(2, 2, 1)

# The four 2-bit inputs and their XOR targets; each sample keeps two leading
# singleton axes, so X has shape (4, 1, 1, 2) and y has shape (4, 1, 1, 1).
X = np.array([[[[0, 0]]], [[[0, 1]]], [[[1, 0]]], [[[1, 1]]]])
y = np.array([[[[0]]], [[[1]]], [[[1]]], [[[0]]]])

# Train on one sample at a time, in order.
for sample, target in zip(X, y):
    net.fit(sample, target, epochs=epochs_n, learning_rate=learning_rate)

# Predictions of the reference network on all four inputs.
out = net.predict(X)

### Model for Fed-Learning

In [14]:
# Second network with the same constructor arguments, trained through the
# server/worker pipeline.
net2 = nn.get_3l_nn(2, 2, 1)

# One randomly drawn ethnicity per sample serves as that worker's secret feature.
secret_groups = [get_rnd_ethnicity() for _ in range(len(X))]
s, ws = znp_fed_training(net2, X, y, secret_groups)

# Predictions of the server-side model on all four inputs.
out2 = s.predict(X)

### Test if the 2 models behave in the same way

In [15]:
# Exact comparison (same shape, same elements) of the reference network's
# predictions with the federated server's predictions.
np.array_equal(out, out2)

True

### Test if workers retain feature [ethnicity]

In [16]:
def encode_sha256(s):
    """Return the SHA-256 hex digest of the string ``s`` encoded as UTF-8."""
    return hashlib.sha256(s.encode('utf-8')).hexdigest()

In [17]:
# Recompute, from the workers' plain secret features, the hashed group counts
# the server should have collected during training.
w_groups = [worker.secret_feature for worker in ws]
h_groups = [encode_sha256(group) for group in w_groups]
cnt = Counter(h_groups)

# True only when the server's counter equals the recomputed one (and hence
# also equals their intersection).
(s.groups & cnt) == s.groups == cnt

True