# Lesson 8 Final Project: Training with Additive Secret Sharing
Aggregate the gradients using additive secret sharing & fixed precision encoding.

Use at least 3 data owners - this way we don't have to trust a secure aggregator with the gradients. No one ever needs to see the gradients that aren't their own.

## Import modules and create hook

In [1]:
!pip install syft

from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import torch as th
import syft as sy 
import torchvision

import warnings
warnings.filterwarnings("ignore")

Collecting syft
[?25l  Downloading https://files.pythonhosted.org/packages/38/2e/16bdefc78eb089e1efa9704c33b8f76f035a30dc935bedd7cbb22f6dabaa/syft-0.1.21a1-py3-none-any.whl (219kB)
[K     |████████████████████████████████| 225kB 2.7MB/s 
[?25hCollecting zstd>=1.4.0.0 (from syft)
[?25l  Downloading https://files.pythonhosted.org/packages/22/37/6a7ba746ebddbd6cd06de84367515d6bc239acd94fb3e0b1c85788176ca2/zstd-1.4.1.0.tar.gz (454kB)
[K     |████████████████████████████████| 460kB 41.5MB/s 
Collecting tf-encrypted>=0.5.4 (from syft)
[?25l  Downloading https://files.pythonhosted.org/packages/55/ff/7dbd5fc77fcec0df1798268a6b72a2ab0150b854761bc39c77d566798f0b/tf_encrypted-0.5.7-py3-none-manylinux1_x86_64.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 39.5MB/s 
[?25hCollecting websocket-client>=0.56.0 (from syft)
[?25l  Downloading https://files.pythonhosted.org/packages/29/19/44753eab1fdb50770ac69605527e8859468f3c0fd7dc5a76dd9c4dbd7906/websocket_client-0.56.0-py2.py3-non

W0723 13:20:31.788064 140487084562304 secure_random.py:26] Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was '/usr/local/lib/python3.6/dist-packages/tf_encrypted/operations/secure_random/secure_random_module_tf_1.14.0.so'
W0723 13:20:31.807106 140487084562304 deprecation_wrapper.py:119] From /usr/local/lib/python3.6/dist-packages/tf_encrypted/session.py:26: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.



In [0]:
# Hook syft onto PyTorch - this adds PySyft functionality to the torch module.
hook = sy.TorchHook(th)

# One virtual worker per data owner; each worker's id matches its variable name.
bob, charlie, joe = (
    sy.VirtualWorker(hook, id=worker_id) for worker_id in ("bob", "charlie", "joe")
)

## Download datasets and create dataloaders

In [3]:
# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with mean 0.1307 and standard deviation 0.3081.
trsfm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Download the MNIST training split into ./data (if needed) with the pipeline attached.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=trsfm,
)

# Directly federate the imported dataset between the three workers.
federated_trainset = trainset.federate((bob, charlie, joe))
print(federated_trainset)



0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


9920512it [00:02, 3602605.18it/s]                             


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 56919.28it/s]                           
0it [00:00, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


1654784it [00:01, 953475.44it/s]                             
0it [00:00, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 21444.17it/s]            


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!
FederatedDataset
    Distributed accross: bob, charlie, joe
    Number of datapoints: 60000



In [0]:
def generate_loaders(workers_list, dataset, tp=0.2, batch_size=64):
    '''Split `dataset` evenly across workers and build federated loaders.

    Each worker receives one contiguous, equally sized slice of the dataset.
    Within every slice the first `1 - tp` fraction becomes training data and
    the remainder becomes test data. Both parts are sent to the worker that
    owns the slice.

    Args:
        workers_list: workers that will each hold one slice of the data.
        dataset: a torchvision MNIST dataset (its `data` / `targets` tensors
            are sliced directly).
        tp: proportion of every worker's slice reserved for the test set.
            Must lie in [0, 1]. Default is 0.2.
        batch_size: batch size used by both loaders. Default is 64.

    Returns:
        (trainloader, testloader): two `sy.FederatedDataLoader` objects.

    Raises:
        ValueError: if `workers_list` is empty or `tp` is outside [0, 1].
    '''
    if not workers_list:
        raise ValueError("workers_list must contain at least one worker")
    if not 0 <= tp <= 1:
        raise ValueError("tp must be a proportion between 0 and 1")

    num_workers = len(workers_list)

    # Total number of datapoints to be split.
    if isinstance(dataset, torchvision.datasets.mnist.MNIST):
        total = dataset.data.shape[0]
    else:
        total = dataset.shape[0]

    size = total // num_workers  # datapoints each federated worker will possess
    tr = int((1 - tp) * size)    # honour the requested test proportion
    te = size - tr

    # NOTE(review): slicing the raw tensors bypasses the dataset's `transform`,
    # so the workers receive un-normalised images - confirm this is intended.
    train_l, test_l = [], []
    for i, w in enumerate(workers_list):
        start = i * size
        split = start + tr
        stop = split + te
        curr_train = sy.BaseDataset(dataset.data[start:split],
                                    dataset.targets[start:split]).send(w)
        curr_test = sy.BaseDataset(dataset.data[split:stop],
                                   dataset.targets[split:stop]).send(w)
        train_l.append(curr_train)
        test_l.append(curr_test)

    # A FederatedDataset is created from a list of pointers to BaseDatasets.
    trainloader = sy.FederatedDataLoader(sy.FederatedDataset(train_l),
                                         batch_size=batch_size, shuffle=True)
    testloader = sy.FederatedDataLoader(sy.FederatedDataset(test_l),
                                        batch_size=batch_size, shuffle=True)

    return trainloader, testloader

In [8]:
# The data owners, in the order their dataset slices are assigned.
workers = [bob, charlie, joe]

# Build the federated train / test loaders with the default test proportion.
fed_trainloader, fed_testloader = generate_loaders(workers_list=workers, dataset=trainset)

<syft.frameworks.torch.federated.dataloader.FederatedDataLoader object at 0x7fc5195b4630>
((Wrapper)>[PointerTensor | me:22840010304 -> bob:18901421485], (Wrapper)>[PointerTensor | me:55742471293 -> bob:40390649273])


<VirtualWorker id:bob #objects:0>

# Security-level-n: Secret sharing
### Create model on central server, send it out to train, and always aggregate each gradient using additive secret sharing, then send it to the central server to be decrypted and used as an update.

In [6]:
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
     ('fc1', nn.Linear(784, 256)),
     ('relu1', nn.ReLU()),
     ('fc2', nn.Linear(256, 64)),
     ('relu2', nn.ReLU()),
     ('fc3', nn.Linear(64, 32)),
     ('relu3', nn.ReLU()),
     ('fc4', nn.Linear(32, 10)),
     # NLLLoss expects log-probabilities, so the head must be LogSoftmax.
     ('outputfx', nn.LogSoftmax(dim=1))
]))

criterion = nn.NLLLoss()

for round_iter in range(10):
    # Every data owner starts the round from its own COPY of the global model,
    # with its own optimiser bound to that copy's (remote) parameters.
    local_models = {w.id: model.copy().send(w) for w in workers}
    local_opts = {w_id: optim.SGD(params=local_model.parameters(), lr=0.1)
                  for w_id, local_model in local_models.items()}

    for d, t in fed_trainloader:
        # Pick the copy that lives on the same worker as this batch.
        owner_id = d.location.id
        model_ = local_models[owner_id]
        opt_ = local_opts[owner_id]

        # Flatten to (batch, 784) for fc1. The cast is there because nn.Linear
        # needs floating-point input - presumably the raw MNIST images are
        # uint8; TODO confirm.
        d = d.view(-1, 784).float()

        opt_.zero_grad()
        pred = model_(d)
        output = criterion(pred, t)
        output.backward()
        opt_.step()

    # Secure aggregation: each owner encodes its parameters in fixed precision
    # and splits them into additive shares held by all workers. The shares are
    # summed while still encrypted; only the SUM is decrypted, so no individual
    # owner's parameters are ever revealed to the central server.
    local_params = [list(local_model.parameters()) for local_model in local_models.values()]
    for p_idx, global_param in enumerate(model.parameters()):
        shared = [params[p_idx].data.fix_prec().share(*workers).get()
                  for params in local_params]
        encrypted_sum = shared[0]
        for extra in shared[1:]:
            encrypted_sum = encrypted_sum + extra
        averaged = encrypted_sum.get().float_prec() / len(local_params)
        global_param.data.set_(averaged)

    # Loss of the last batch of this round (fetched from its remote worker).
    print(output.get())

KeyError: ignored

## Security-level-1: Create model on central server, then send it out to train, and work with gradients directly on that remote body.

In [0]:
from collections import OrderedDict

model = nn.Sequential(OrderedDict([
     ('fc1', nn.Linear(784, 256)),
     ('relu1', nn.ReLU()),
     ('fc2', nn.Linear(256, 64)),
     ('relu2', nn.ReLU()),
     ('fc3', nn.Linear(64, 32)),
     ('relu3', nn.ReLU()),
     ('fc4', nn.Linear(32, 10)),
     # NLLLoss expects log-probabilities, so the head must be LogSoftmax.
     ('outputfx', nn.LogSoftmax(dim=1))
]))

opt = optim.SGD(params=model.parameters(), lr=0.1)
criterion = nn.NLLLoss()

for d, t in fed_trainloader:

    # Flatten to (batch, 784) for fc1. Using -1 for the batch dimension works
    # for any batch size, including a smaller final batch. The cast is there
    # because nn.Linear needs floating-point input - presumably the raw MNIST
    # images are uint8; TODO confirm.
    d = d.view(-1, 784).float()

    # Move the model to the worker that holds this batch.
    model.send(d.location)

    # Train the model remotely.
    opt.zero_grad()
    pred = model(d)
    output = criterion(pred, t)
    output.backward()
    opt.step()

    # Get the smarter model back.
    model.get()

# Test the model on a single batch
correct_tally = 0

d_test, t_test = next(iter(fed_testloader))
d_test = d_test.view(-1, 784).float()
model.send(d_test.location)

pred = model(d_test)
pred_ans = pred.argmax(dim=1) # or _, pred_ans = th.max(pred, 1) - adding dim=1 in torch.max() returns an additional list of the indices
labels = t_test.get()
total_tally = labels.shape[0]  # actual size of the evaluated batch
correct_tally += (pred_ans.get().eq(labels)).sum()

# Bring the model back so later cells start from a local model again.
model.get()
print('Correct tally: ' + str(correct_tally))



## Security-level-2: Local models & Trusted aggregator
Create model on central server > 

Send a model COPY to each remote device - each device should have its own optimiser, so that its updates only become visible after they have been averaged >

Weight & bias are sent to a trusted aggregator to be averaged > 

Multiple rounds of training happen over time


In [0]:
# Create model on central server

from collections import OrderedDict

# Global model held on the central server: a 784 -> 256 -> 64 -> 32 -> 10 MLP.
# NOTE(review): the head is nn.Softmax but the loss below is nn.NLLLoss, which
# expects log-probabilities - nn.LogSoftmax(dim=1) is probably intended.
model = nn.Sequential(OrderedDict([
     ('fc1', nn.Linear(784, 256)),
     ('relu1', nn.ReLU()),
     ('fc2', nn.Linear(256, 64)),
     ('relu2', nn.ReLU()),
     ('fc3', nn.Linear(64, 32)),
     ('relu3', nn.ReLU()), 
     ('fc4', nn.Linear(32, 10)),
     ('outputfx', nn.Softmax())
]))

# NOTE(review): both optimisers are bound to the GLOBAL model's parameters,
# not to the per-device copies described in the text above.
alice_opt = optim.SGD(params = model.parameters(), lr=0.01)
bob_opt = optim.SGD(params = model.parameters(), lr=0.01)
criterion = nn.NLLLoss()

# Send model COPY to each remote device 

# NOTE(review): `d`, `data`, `target`, `alices_model`, `bobs_model` and
# `secure_worker` are not defined anywhere in this cell; they must be
# created before this sketch can run.
for round_iter in range(10): # Code is not runnable - only skeletal idea here
    
    # The sent copy is not bound to a name, so it cannot be trained below.
    model.copy().send(d.location)
    
    # One training step; note that it runs on the global `model`,
    # and that `bob_opt` is never stepped in this loop.
    alice_opt.zero_grad()
    pred = model(data)
    output = criterion(pred, target)
    output.backward()
    alice_opt.step()
    
    # Instead of doing model.get() wholesale, we only update our global model's parameters with already-averaged values, on the trusted aggregator 
    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)
    
    # NOTE(review): `model` is an nn.Sequential whose layers are named fc1..fc4;
    # it presumably has no top-level `.weight` / `.bias`, so the averaging would
    # need to be done per layer. Also, `.get()` is applied to the result of
    # `set_(...)`, not to the averaged value before it is assigned.
    model.weight.data.set_((alices_model.weight.data + bobs_model.weight.data) / 2).get() # Set global model weight, then retrieve value
    model.bias.data.set_((alices_model.bias.data + bobs_model.bias.data) / 2).get() # Set global model bias, then retrieve value
    
    
    
    