<a href="https://colab.research.google.com/github/Berenice2018/DeepLearning/blob/master/Securing_Federated_Learning_Encrypted_Gradients.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install syft

In [0]:
import torch as th
import syft as sy

from torch import nn, optim


In [0]:
class virtual_worker:
    """Bundle a PySyft VirtualWorker with its own data, target, model and optimizer.

    Args:
        hook: the ``sy.TorchHook`` instance handed to ``sy.VirtualWorker``.
        name: identifier for the ``sy.VirtualWorker``; also kept as ``self.name``.
        data: training inputs; ``data.shape[1]`` sets the model's input size.
        target: training targets.
        lr: learning rate of this worker's SGD optimizer.
        epochs: number of local training epochs (stored for the trainer).
    """

    def __init__(self, hook, name, data, target, lr=0.1, epochs=3):
        # Remote endpoint first, then this worker's private model.
        self.worker = sy.VirtualWorker(hook, name)
        self.name = name
        self.epochs = epochs
        self.data = data
        self.target = target

        in_features = data.shape[1]
        self.model = nn.Linear(in_features, 1)
        # The optimizer is bound to the model's parameters before they are sent.
        self.optim = optim.SGD(self.model.parameters(), lr=lr)

    def send_data(self):
        """Ship inputs, targets and the model to this worker's VirtualWorker.

        The pointers returned for the inputs and targets are stored as
        ``data_ptr`` and ``target_ptr``. ``self.model`` is rebound to the
        value returned by ``model.send``. The ``data``/``target`` attributes
        themselves are not reassigned.
        """
        destination = self.worker
        self.data_ptr = self.data.send(destination)
        self.target_ptr = self.target.send(destination)
        self.model = self.model.send(destination)

In [0]:

class federated_model:
    """Securely average linear models trained on several remote workers.

    Each remote worker trains its own model on its own (remote) data. Each
    worker's weight and bias are then converted to fixed precision and
    additively secret-shared among all remote workers plus a dedicated
    secure worker. The shares are summed while still encrypted. Only that
    aggregate is decrypted, averaged, and copied into ``self.model``.

    Args:
        hook: the ``sy.TorchHook`` used to create the secure worker.
        weight_dims: number of input features of the linear model.
        remote_workers: ``virtual_worker`` objects whose ``send_data`` has
            already been called.
    """

    def __init__(self, hook, weight_dims, remote_workers):
        self.remote_workers = remote_workers
        self.secure_worker = sy.VirtualWorker(hook, "secure_worker")
        self.weight_dims = weight_dims
        # Size the aggregate model from the declared feature count rather than
        # from a notebook-global ``data`` tensor.
        self.model = nn.Linear(weight_dims, 1)
        # Secret-shared parameters collected during the latest ``train`` round.
        self.remote_weights = []
        self.remote_biases = []

    def get_gradients(self):
        """Decrypt the averaged worker parameters into the aggregate model.

        Despite the name, this aggregates model *parameters*, not gradients.
        The shared tensors are added while still secret-shared, so only the
        sum is ever reconstructed. ``float_prec`` undoes the fixed-precision
        encoding, and the sum is divided by the number of contributions.
        Afterwards a plain forward pass on the last worker's local data
        prints the aggregate model's squared-error loss for debugging.

        Raises:
            ValueError: if no shared parameters have been collected yet.
        """
        if not self.remote_weights or not self.remote_biases:
            raise ValueError("no shared worker parameters to aggregate; call train() first")

        count = len(self.remote_weights)
        weights_data = self._decrypt_sum(self.remote_weights) / count
        bias_data = self._decrypt_sum(self.remote_biases) / count

        # copy_ writes into the existing parameters, so their shapes are kept.
        # (set_ would silently adopt the shape of the source tensor.)
        with th.no_grad():
            self.model.weight.copy_(weights_data)
            self.model.bias.copy_(bias_data)

        # For debugging: loss of the aggregate model on one worker's local data.
        data = self.remote_workers[-1].data
        labels = self.remote_workers[-1].target
        with th.no_grad():
            predictions = self.model(data)
            loss = ((predictions - labels) ** 2).sum()
        print("\n\tloss of secure worker = {:.6f}".format(loss.item()))

    @staticmethod
    def _decrypt_sum(shared_tensors):
        """Add secret-shared fixed-precision tensors, then reveal and decode the total."""
        total = shared_tensors[0]
        for shared in shared_tensors[1:]:
            total = total + shared
        return total.get().float_prec()

    def train(self):
        """Run one federated round: local training, encrypted sharing, aggregation.

        Raises:
            ValueError: if there are no remote workers.
        """
        if not self.remote_workers:
            raise ValueError("federated_model needs at least one remote worker")
        self._train_remote_workers()
        self._share_remote_parameters()
        self.get_gradients()  # aggregate all workers' parameters

    def _train_remote_workers(self):
        """Train each worker's remote model on its remote data with its own optimizer."""
        for worker_obj in self.remote_workers:
            worker_epochs = worker_obj.epochs

            print("*" * 30)
            print("\n\ttraining Model On Worker {}\n".format(worker_obj.name.title()))

            for epoch in range(worker_epochs):
                # Show 1-based epoch numbers so the last epoch reads N/N.
                print("\nepoch : {}/{}".format(epoch + 1, worker_epochs))
                worker_obj.optim.zero_grad()

                preds_ptr = worker_obj.model(worker_obj.data_ptr)  # forward pass on remote data
                loss_ptr = ((preds_ptr - worker_obj.target_ptr) ** 2).sum()  # sum of squared errors
                loss_ptr.backward()
                worker_obj.optim.step()

                # Only the scalar loss is fetched back, for logging.
                loss_value = loss_ptr.get()
                print("raw loss : {:.6f}".format(loss_value))

    def _share_remote_parameters(self):
        """Secret-share every worker's weight and bias among all share holders.

        Shares from any previous round are discarded, so ``get_gradients``
        only averages the current round's parameters.
        """
        # Every remote worker plus the secure worker holds one share, however
        # many remote workers there are.
        share_holders = [w.worker for w in self.remote_workers] + [self.secure_worker]

        self.remote_weights = []
        self.remote_biases = []
        for worker_obj in self.remote_workers:
            print("\n\tEncrypting weights and biases of worker : {}".format(worker_obj.name.title()))

            # NOTE(review): .get() brings the parameters back here in plaintext
            # before encryption; moving them straight to the secure worker
            # would avoid exposing them locally.
            weights_enc = worker_obj.model.weight.get().fix_prec()
            bias_enc = worker_obj.model.bias.get().fix_prec()

            self.remote_weights.append(weights_enc.share(*share_holders))
            self.remote_biases.append(bias_enc.share(*share_holders))
        


### Mock dataset

In [0]:
# One random sample with two features and a single regression target.
samples, features = 1, 2

# target is drawn before data; keep this order so seeded runs reproduce
# the same tensors.
target = th.rand(samples, 1)
data = th.rand(samples, features)

### Workers

In [11]:
# Hook PyTorch for PySyft (the .send()/.get() calls used below rely on it).
hook = sy.TorchHook(th)

worker_names = ["ada", "bob", "cyd"]

# One virtual_worker per name. Each receives the same mock data and target,
# then immediately ships data, target and its own model to its endpoint.
workers = []
for name in worker_names:
    worker_obj = virtual_worker(hook=hook, name=name, data=data, target=target)
    worker_obj.send_data()
    workers.append(worker_obj)

print(len(workers))

W0728 21:28:06.355416 140295094265728 hook.py:98] Torch was already hooked... skipping hooking process


3


In [0]:
# Instantiate the federated model and run one training round.
federated_model(hook=hook, weight_dims=features, remote_workers=workers).train()