In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import syft as sy
import copy
import numpy as np

from ipynb.fs.full.FLDataset import load_dataset, getActualImages
from ipynb.fs.full.utils import averageModels, averageGradients

In [None]:
class Arguments():
    """Run configuration for the FedDANE MNIST experiment (a plain attribute bag)."""

    def __init__(self):
        total_images, n_clients = 60000, 10

        # federation layout
        self.images = total_images       # training images to be split across clients
        self.clients = n_clients         # number of simulated PySyft workers
        # optimisation schedule
        self.rounds = 2                  # communication rounds of the outer loop
        self.epochs = 2                  # local epochs per ClientUpdate call
        self.local_batches = 1           # batch size handed to the data loaders
        self.lr = 0.01                   # optimizer step size
        self.C = 0.9                     # fraction of clients sampled each round
        self.mu = 0.1                    # coefficient of the proximal term in the optimizer
        # reproducibility / logging
        self.torch_seed = 0
        self.log_interval = 10           # log the training loss every N batches
        self.iid = 'iid'                 # partitioning mode forwarded to load_dataset
        # values derived from the layout above
        self.split_size = total_images // n_clients
        self.samples = self.split_size / total_images
        # runtime switches
        self.use_cuda = False
        self.save_model = False

args = Arguments()

# CUDA is used only when it is both requested in the config and actually available.
use_cuda = args.use_cuda and torch.cuda.is_available()
# Seed from the config rather than a hard-coded literal, so `args.torch_seed` is the single
# reproducibility knob. NumPy's global RNG is seeded too, in case the data-splitting helpers
# (not visible here) sample with it.
torch.manual_seed(args.torch_seed)
np.random.seed(args.torch_seed)
device = torch.device("cuda" if use_cuda else "cpu")
# Extra DataLoader options that only pay off when batches are copied to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

In [None]:
hook = sy.TorchHook(torch)
# One PySyft VirtualWorker per simulated client, with ids "client1" ... "client<N>".
# Each client is a dict; later cells add 'trainset', 'testset', 'samples', 'model' and 'optim'.
clients = [
    {'hook': sy.VirtualWorker(hook, id="client{}".format(worker_no))}
    for worker_no in range(1, args.clients + 1)
]

In [None]:
# Global MNIST train/test data plus per-client index groups (one entry per client).
(global_train, global_test,
 train_group, test_group) = load_dataset(args.clients, args.iid)

In [None]:
# Optional cap on the number of training images per client, for quick smoke tests
# (e.g. 10). None trains each client on its whole shard.
MAX_TRAIN_SAMPLES = None

for inx, client in enumerate(clients):
    train_indices = list(train_group[inx])
    if MAX_TRAIN_SAMPLES is not None:
        train_indices = train_indices[:MAX_TRAIN_SAMPLES]
    client['trainset'] = getActualImages(global_train, train_indices, args.local_batches)
    client['testset'] = getActualImages(global_test, list(test_group[inx]), args.local_batches)
    # Fraction of the training set this client actually trains on. It is computed from the same
    # indices used above, so the cap (if any) and this value cannot disagree.
    client['samples'] = len(train_indices) / args.images

In [None]:
# Held-out MNIST test split used to score the aggregated global model after every round.
# Evaluation-only batch size: `test` sums the per-sample loss and counts correct predictions,
# and Net has no batch-dependent layers, so this only affects speed.
TEST_BATCH_SIZE = 1000
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
global_test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
# No shuffling: evaluation order does not matter, and shuffling only consumes RNG state.
# `kwargs` carries the GPU-only loader options from the config cell.
global_test_loader = DataLoader(global_test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, **kwargs)

In [None]:
class Net(nn.Module):
    """Small CNN: two (conv -> ReLU -> 2x2 max-pool) stages, two linear layers, log-softmax.

    The 4*4*50 flatten size corresponds to single-channel 28x28 inputs
    (28 -5+1-> 24 -pool-> 12 -5+1-> 8 -pool-> 4). The output is
    log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order matters: under a fixed torch seed it determines the initial weights.
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2, 2)
        flat = x.view(-1, 4 * 4 * 50)
        logits = self.fc2(F.relu(self.fc1(flat)))
        return F.log_softmax(logits, dim=1)

In [None]:
def ClientUpdate(args, device, client, global_model):
    """Train one client's model on its PySyft worker for ``args.epochs`` epochs.

    Both the client model and ``global_model`` are sent to ``client['hook']`` for the
    duration of training, because the optimizer's ``step`` reads the global model's
    parameters and gradients. Both are fetched back afterwards. The fetches run in
    ``finally`` blocks, so an exception mid-training does not leave either model as a
    remote pointer for later cells.

    Args:
        args: run configuration; reads ``epochs``, ``log_interval`` and ``local_batches``.
        device: torch device each batch is moved to.
        client: dict with the worker ('hook'), 'model', 'optim' and 'trainset'
            (iterated as (data, target) batches; ``len`` is used for logging).
        global_model: current global model, passed to ``client['optim'].step``.
    """
    worker = client['hook']
    model = client['model']
    optimizer = client['optim']
    trainset = client['trainset']

    model.train()
    model.send(worker)
    try:
        global_model.send(worker)
        try:
            for epoch in range(1, args.epochs + 1):
                for batch_idx, (data, target) in enumerate(trainset):
                    data = data.send(worker)
                    target = target.send(worker)

                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = F.nll_loss(output, target)
                    loss.backward()

                    optimizer.step(global_model)

                    if batch_idx % args.log_interval == 0:
                        # The loss lives on the worker; fetch it to print it locally.
                        loss = loss.get()
                        print('Model {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                            worker.id,
                            epoch, batch_idx * args.local_batches, len(trainset) * args.local_batches,
                            100. * batch_idx / len(trainset), loss.item()))
        finally:
            global_model.get()
    finally:
        model.get()

In [None]:
def test(args, model, device, test_loader, name):
    model.eval()   
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss for {} model: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        name, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [None]:
class FedDANEOptim(optim.Optimizer):
    """SGD-style optimizer applying the FedDANE-style local update used in this notebook.

    For each local parameter ``w`` with matching global parameter ``w_g``, one step does::

        w -= lr * (grad(w) + (grad(w_g) - grad(w)) + mu * (w - w_g))

    The global-minus-local gradient correction is skipped when ``w_g`` has no gradient yet,
    e.g. in the first round before any gradient averaging. In that case only the proximal
    term is applied.

    NOTE(review): ``grad(w) + (grad(w_g) - grad(w))`` reduces algebraically to ``grad(w_g)``,
    so the local data gradient cancels out of the update. In DANE/FedDANE the correction
    uses the local gradient evaluated at the global model, not at the current local
    iterate. Fixing that needs the local gradient captured before local training.
    TODO confirm against the FedDANE paper and ``averageGradients``.

    Args:
        params: iterable of local-model parameters (or param-group dicts).
        lr: step size; defaults to ``args.lr``.
        mu: proximal coefficient; defaults to ``args.mu``.
    """

    def __init__(self, params, lr=args.lr, mu=args.mu):
        defaults = dict(lr=lr, mu=mu)
        super(FedDANEOptim, self).__init__(params, defaults)

    def step(self, global_model=None, closure=None):
        """Apply one update to every local parameter that has a gradient.

        Args:
            global_model: model whose parameters correspond one-to-one, in order, to the
                parameters of all param groups concatenated. If None, a plain SGD step
                ``w -= lr * grad(w)`` is taken.
            closure: optional callable that re-evaluates the loss.

        Returns:
            The value returned by ``closure``, or None.

        Raises:
            ValueError: if ``global_model`` has a different number of parameters.
        """
        loss = None
        if closure is not None:
            loss = closure()

        global_params = None
        if global_model is not None:
            global_params = list(global_model.parameters())
            n_local = sum(len(group['params']) for group in self.param_groups)
            if n_local != len(global_params):
                raise ValueError(
                    'global_model has {} parameters but the optimizer holds {}'.format(
                        len(global_params), n_local))

        # Position of the current group's first parameter within the flat global list, so
        # that every param group (not just the first) is matched to the right global tensor.
        offset = 0
        for group in self.param_groups:
            lr, mu = group['lr'], group['mu']
            for i, p_local in enumerate(group['params']):
                if p_local.grad is None:
                    continue
                d_p = p_local.grad.data  # local model grads
                if global_params is not None:
                    p_global = global_params[offset + i]
                    if p_global.grad is not None:
                        # gradient correction: (global grad - local grad)
                        d_p = d_p + (p_global.grad.data - p_local.grad.data)
                    # proximal term pulling the local weights towards the global ones
                    d_p = d_p + mu * (p_local.data - p_global.data)
                p_local.data.sub_(d_p, alpha=lr)
            offset += len(group['params'])

        return loss

In [None]:
# Global model and per-client replicas. Each is built right after re-seeding,
# so all of them start from identical weights.
torch.manual_seed(args.torch_seed)
global_model = Net().to(device)

for client in clients:
    torch.manual_seed(args.torch_seed)
    client['model'] = Net().to(device)
    client['optim'] = FedDANEOptim(client['model'].parameters(), lr=args.lr, mu=args.mu)

# m = number of clients sampled per round (at least one); it does not change between rounds.
m = int(max(args.C * args.clients, 1))

for fed_round in range(args.rounds):

    # S_t: clients whose gradients are averaged into the global model.
    np.random.seed(fed_round)
    selected_clients_inds = np.random.choice(range(len(clients)), m, replace=False)
    selected_clients = [clients[i] for i in selected_clients_inds]

    # Average of gradients. In round 0 the global model has no gradients yet.
    # S_t is drawn with seed `fed_round`, the same seed S'_{t-1} used below in the previous
    # round, so S_t is exactly the set of clients that trained last round.
    # NOTE(review): averageGradients presumably reads those clients' stored .grad, which come
    # from their last local step rather than from the current global weights — confirm.
    if fed_round > 0:
        global_model = averageGradients(global_model, selected_clients)

    # S'_t: clients that run local training this round.
    np.random.seed(fed_round + 1)
    sprime_clients_inds = np.random.choice(range(len(clients)), m, replace=False)
    sprime_clients = [clients[i] for i in sprime_clients_inds]

    # training
    for client in sprime_clients:
        ClientUpdate(args, device, client, global_model)

    # Aggregate the models that were actually trained this round (S'_t). Averaging over S_t
    # instead would mix in replicas that still hold last round's global weights.
    global_model = averageModels(global_model, sprime_clients)

    # testing
    test(args, global_model, device, global_test_loader, 'Global')

    # Broadcast the new global weights to every client for the next round.
    for client in clients:
        client['model'].load_state_dict(global_model.state_dict())

if args.save_model:
    torch.save(global_model.state_dict(), "FedDANE.pt")