### Imports

In [1]:
import sys
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from datetime import datetime
sys.path.append('..')
import phe as pallier

In [2]:
# Prefer the GPU whenever PyTorch reports one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    proc = "cuda"
else:
    proc = "cpu"
print(f"Using {proc} processor")

Using cpu processor


### Dataset

In [3]:
# Images are converted to tensors only; no other augmentation or normalisation is applied.
transform = torchvision.transforms.ToTensor()

# CIFAR-10 training split (downloaded into ./data when it is not already there).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# CIFAR-10 held-out split, stored in the same directory.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Single shared evaluation loader over the held-out split.
testloader = torch.utils.data.DataLoader(testset, batch_size=128)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


### Models

In [4]:
class ConvNet(torch.nn.Module):
    """Nine conv blocks, global average pooling, and a 10-way linear classifier.

    Every block is Conv2d (no bias) -> BatchNorm2d -> ReLU. The third and sixth
    blocks use stride 2; all others use stride 1.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) of each conv block, in execution order.
        layout = [
            (3, 32, 1),
            (32, 32, 1),
            (32, 64, 2),
            (64, 64, 1),
            (64, 64, 1),
            (64, 128, 2),
            (128, 128, 1),
            (128, 256, 1),
            (256, 256, 1),
        ]
        stages = [self.conv_block(c_in, c_out, stride=step) for c_in, c_out, step in layout]
        # Collapse each feature map to a single value per channel.
        stages.append(torch.nn.AdaptiveAvgPool2d(1))
        self.model = torch.nn.Sequential(*stages)
        self.classifier = torch.nn.Linear(256, 10)

    def conv_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Build one Conv2d -> BatchNorm2d -> ReLU stage as a Sequential."""
        conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = torch.nn.BatchNorm2d(out_channels)
        activation = torch.nn.ReLU(inplace=True)
        return torch.nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Return the classifier output for the image batch ``x``."""
        pooled = self.model(x)
        # The pooled map is unpacked as four dimensions; the two spatial ones are dropped.
        batch, channels, _height, _width = pooled.shape
        features = pooled.view(batch, channels)
        return self.classifier(features)

### Encryption Functions
Each of these functions returns an (encrypt, decrypt) pair of functions that the local devices use to encrypt their weight updates and to decrypt the merged weights.

In [5]:
def alg_base():
    """Return a no-op (encrypt, decrypt) pair: both functions hand back their argument unchanged."""
    def _encrypt(x):
        return x

    def _decrypt(x):
        return x

    return _encrypt, _decrypt

def alg_pallier():
    """Generate a fresh Paillier key pair and return its (encrypt, decrypt) bound methods.

    The first element is the public key's ``encrypt`` method and the second is
    the private key's ``decrypt`` method.
    """
    keypair = pallier.generate_paillier_keypair()
    encryption_key, decryption_key = keypair
    return encryption_key.encrypt, decryption_key.decrypt

### Merging Functions
These functions are used by the central server to get the new global model by merging the local encrypted weight updates

In [6]:
def merge_avg(cipher_weight_dicts):
    """Average the values stored under each key across all given weight dicts.

    Args:
        cipher_weight_dicts: sequence of mappings from parameter name to a value
            that supports ``+`` and division by an int.

    Returns:
        OrderedDict mapping each key to its summed value divided by
        ``len(cipher_weight_dicts)``. The result is empty when no keys were seen.
    """
    totals = OrderedDict()
    # First pass: accumulate a running sum per key, starting each sum from 0.
    for weights in cipher_weight_dicts:
        for name, value in weights.items():
            totals.setdefault(name, 0)
            totals[name] = totals[name] + value
    # Second pass: scale every sum by the number of contributing dicts.
    return OrderedDict(
        (name, summed / len(cipher_weight_dicts)) for name, summed in totals.items()
    )

def merge_sum(cipher_weight_dicts):
    """Sum the values stored under each key across all given weight dicts.

    Args:
        cipher_weight_dicts: iterable of mappings from parameter name to a value
            that supports ``+``.

    Returns:
        OrderedDict mapping each key to the total of its values, with keys in
        first-seen order.
    """
    totals = OrderedDict()
    for weights in cipher_weight_dicts:
        for name, value in weights.items():
            # Each running sum starts from 0 so the first value is added like any other.
            totals.setdefault(name, 0)
            totals[name] = totals[name] + value
    return totals

### Classes

In [7]:
class Device():
    """A simulated local device: private data subset, local model, and encrypt/decrypt pair.

    Attributes:
        id: identifier passed in by the caller.
        trainloader: DataLoader over this device's randomly chosen subset of ``trainset``.
        net: the local model (trained in place).
        epochs: number of local epochs run by each ``train()`` call.
        criterion / optimizer / scheduler: cross-entropy loss, SGD with momentum 0.9,
            and a MultiStepLR that decays the learning rate by 0.1 at 25/50/75% of ``epochs``.
        e, d: the encrypt and decrypt callables returned by ``func_alg()``.
    """

    def __init__(self, device_id, trainset, net, epochs, func_alg, data_pct=0.1, bsz=128, lr=0.1):
        """Set up the device's data subset, training objects, and encryption functions.

        Args:
            device_id: identifier used in log messages.
            trainset: dataset from which a random subset is sampled without replacement.
            net: model owned (and trained) by this device.
            epochs: local epochs per ``train()`` call; also used to place LR milestones.
            func_alg: zero-argument callable returning an ``(encrypt, decrypt)`` pair.
            data_pct: fraction of ``trainset`` given to this device.
            bsz: training batch size.
            lr: initial SGD learning rate.
        """
        print(f'\tInitializing Device {device_id}')
        # initialize core device properties
        self.id = device_id

        # initialize device dataset: a fixed random subset of indices drives the loader
        data_idxs = np.random.choice(len(trainset), size=int(data_pct * len(trainset)), replace=False)
        self.trainloader = torch.utils.data.DataLoader(trainset, batch_size=bsz, sampler=data_idxs)

        # initialize device net
        self.net = net
        self.epochs = epochs
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.SGD(self.net.parameters(), lr=lr, momentum=0.9)
        milestones = [int(0.25*epochs), int(0.50*epochs), int(0.75*epochs)]
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=milestones, gamma=0.1)

        # initialize encryption (timed, reported in minutes)
        print("Initializing Encryption //")
        start = datetime.now()
        self.e, self.d = func_alg()
        end = datetime.now()
        print("Have Keys")
        print((end - start).total_seconds() / 60.0)

    def train(self):
        """Train ``self.net`` on this device's data for ``self.epochs`` epochs."""
        self.net.train()
        for epoch in range(self.epochs):
            # running statistics for the (currently disabled) progress line below
            total_loss, correct, total = 0, 0, 0
            for batch_idx, (inputs, targets) in enumerate(self.trainloader):
                inputs, targets = inputs.to(proc), targets.to(proc)
                self.optimizer.zero_grad()
                outputs = self.net(inputs)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                predicted = outputs.max(1)[1]
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
                #sys.stdout.write(f'\r(Device {self.id}/Epoch {epoch}) Train Loss: {total_loss/(batch_idx+1):.3f} | Train Acc: {100.*correct/total:.3f}')
                #sys.stdout.flush() 

    def test(self):
        """Run ``self.net`` in eval mode over the global ``testloader`` without gradients."""
        self.net.eval()
        losses, correct, total = [], 0, 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(proc), targets.to(proc)
                outputs = self.net(inputs)
                loss = self.criterion(outputs, targets)

                losses.append(loss.item())
                predicted = outputs.max(1)[1]
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        #sys.stdout.write(f' | Test Loss: {sum(losses)/len(losses):.3f} | Test Acc: {100.*correct/total:.3f}\n')
        #sys.stdout.flush()         

    def transmit(self):
        """Encrypt every entry of the local ``state_dict`` and return the result.

        Returns:
            OrderedDict mapping each state-dict key to ``self.e(tensor)``.
        """
        cipher_weights = OrderedDict()

        print("Timing encryption of tensor objects")
        start = datetime.now()
        for key, tensor in self.net.state_dict().items():
            print("Tensor Object", key)
            cipher_weights[key] = self.e(tensor)
        # Take the end timestamp locally; relying on a notebook-level `end`
        # raises NameError on a fresh kernel and reports a stale time otherwise.
        end = datetime.now()

        print("Time to run")
        print((end - start).total_seconds() / 60.0)
        return cipher_weights

    def load(self, cipher_weights):
        """Decrypt ``cipher_weights`` with ``self.d`` and load them into ``self.net``.

        Args:
            cipher_weights: mapping from state-dict key to an encrypted value.
        """
        plain_weights = OrderedDict()

        print("Timing decryption of tensor objects")
        # Both timestamps are local so the reported duration covers only this loop.
        start = datetime.now()
        for key, tensor in cipher_weights.items():
            print("Tensor Object", key)
            plain_weights[key] = self.d(tensor)
        end = datetime.now()

        print("Time to run")
        print((end - start).total_seconds() / 60.0)
        self.net.load_state_dict(plain_weights)


class Server():
    """Central server that combines per-device weight dicts with a pluggable merge function.

    Attributes:
        func_merge: callable taking a list of weight dicts and returning the merged dict.
        weights: result of the most recent ``merge`` call (an empty OrderedDict before the first).
    """

    def __init__(self, func_merge=merge_avg):
        self.weights = OrderedDict()
        self.func_merge = func_merge

    def merge(self, cipher_weight_dicts):
        """Merge the given weight dicts and store the result on ``self.weights``."""
        merged = self.func_merge(cipher_weight_dicts)
        self.weights = merged

### Driver Code

In [None]:
def run(num_devices, local_epochs, rounds, round_device_pct, func_alg, func_merge):
    """Run the federated training simulation and return the first device.

    Each round: a random subset of devices trains locally, their encrypted
    weights are merged by the server, every device decrypts and loads the merged
    weights, and device 0 is evaluated.

    Args:
        num_devices: number of devices to create.
        local_epochs: local epochs each selected device trains per round.
        rounds: number of communication rounds.
        round_device_pct: fraction of devices selected per round (at least one is always picked).
        func_alg: zero-argument callable returning an ``(encrypt, decrypt)`` pair for a device.
        func_merge: merge function handed to the server.

    Returns:
        ``devices[0]`` after the final round.
    """
    print("Initializing Server")
    server = Server(func_merge=func_merge)

    print("Initializing Devices")
    devices = []
    for d_id in range(num_devices):
        local_net = ConvNet().to(proc)
        devices.append(Device(d_id, trainset, local_net, local_epochs, func_alg))

    # Number of participants is the same every round; never fewer than one.
    participants_per_round = max(int(num_devices * round_device_pct), 1)

    for round_num in range(rounds):
        print(f'Round {round_num}')
        chosen = np.random.choice(devices, size=participants_per_round, replace=False)
        round_devices = chosen.tolist()

        # All selected devices finish local training before any of them transmits.
        for device in round_devices:
            device.train()

        # The state_dict objects the server merges are all encrypted.
        updates = []
        for device in round_devices:
            updates.append(device.transmit())
        server.merge(updates)

        # For performance, the global model could instead be loaded only into
        # the devices selected for the next round.
        for device in devices:
            device.load(server.weights)  # the device decrypts and loads the merged weights
            # NOTE(review): this optimizer step runs right after loading, with freshly
            # cleared gradients; presumably it is here so scheduler.step() follows an
            # optimizer.step() -- confirm it does not alter the loaded weights.
            device.optimizer.zero_grad()
            device.optimizer.step()
            device.scheduler.step()

        devices[0].test()

    return devices[0]

# Time one full simulation: a single device, one local epoch per round, five rounds,
# Paillier encryption, and averaged merging. The elapsed time is printed in minutes.
start = datetime.now()
data_device = run(
    num_devices=1,
    local_epochs=1,
    rounds=5,
    round_device_pct=1,
    func_alg=alg_pallier,
    func_merge=merge_avg,
)
end = datetime.now()
print("Time to run")
print((end - start).total_seconds() / 60.0)

Initializing Server
Initializing Devices
	Initializing Device 0
Initializing Encryption //
Have Keys
0.00452965
Round 0
Timing encryption of tensor objects
Tensor Object model.0.0.weight
Paillier --> Encrypting Tensor
Paillier --> Flattening Tensor
Paillier --> Encrypting Array
Tensor Object model.0.1.weight
Paillier --> Encrypting Tensor
Paillier --> Flattening Tensor
Paillier --> Encrypting Array
Tensor Object model.0.1.bias
Paillier --> Encrypting Tensor
Paillier --> Flattening Tensor
Paillier --> Encrypting Array
Tensor Object model.0.1.running_mean
Paillier --> Encrypting Tensor
Paillier --> Flattening Tensor
Paillier --> Encrypting Array
Tensor Object model.0.1.running_var
Paillier --> Encrypting Tensor
Paillier --> Flattening Tensor
Paillier --> Encrypting Array
Tensor Object model.0.1.num_batches_tracked
Paillier --> Encrypting Tensor
Paillier --> Flattening Tensor
Paillier --> Encrypting Array
Tensor Object model.1.0.weight
Paillier --> Encrypting Tensor
Paillier --> Flattenin