In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import syft as sy
import copy
import numpy as np
import time

import importlib
importlib.import_module('FLDataset')
from FLDataset import load_dataset, getActualImgs
from utils import averageModels, normalModels

In [2]:
class Arguments():
    """Hyper-parameters and run switches for the federated CIFAR experiment.

    All values are fixed defaults; the two derived fields (``split_size`` and
    ``samples``) are computed from ``images`` and ``clients``.
    """

    def __init__(self):
        # --- federation layout -------------------------------------------
        self.images = 50000          # image count used to size each client's split
        self.clients = 20            # number of virtual workers created
        self.rounds = 10             # federated rounds run by the training cell
        self.epochs = 5              # local epochs each client trains per round

        # --- optimisation ------------------------------------------------
        self.local_batches = 64      # batch size for client and global loaders
        self.lr = 0.01               # learning rate of each client's SGD optimiser

        # --- client sampling ---------------------------------------------
        self.C = 1.0                 # fraction of clients selected every round
        self.drop_rate = 0.0         # fraction of selected clients that skip training

        # --- reproducibility / logging -----------------------------------
        self.torch_seed = 0          # seed given to torch.manual_seed
        self.log_interval = 100      # batches between training progress prints
        self.iid = 'iid'             # partitioning mode forwarded to load_dataset

        # --- derived quantities ------------------------------------------
        images_per_client = self.images / self.clients
        self.split_size = int(images_per_client)
        self.samples = self.split_size / self.images   # one client's share of all images

        # --- runtime switches --------------------------------------------
        self.use_cuda = False        # request the GPU (used only if one is available)
        self.save_model = False      # save the global model's state_dict after training

args = Arguments()

# The GPU is used only when it is both requested in the config and actually present.
use_cuda = args.use_cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # NOTE(review): presumably meant as extra DataLoader options; no visible cell uses it.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [3]:
# Append this run to the shared results log. The handle is left open on purpose:
# test() writes one accuracy per round to it and the training cell closes it.
writer = open('20_experiments_cifar.csv', 'a')
run_header = 'Clients {}, Rounds {}, Epochs {}, IID {}'.format(
    args.clients, args.rounds, args.epochs, args.iid)
writer.write('\n\n\n' + run_header)
writer.write('\n')

1

In [4]:
hook = sy.TorchHook(torch)

# One dict per client; later cells add 'trainset', 'testset', 'samples',
# 'model' and 'optim' entries next to the worker stored under 'hook'.
clients = [
    {'hook': sy.VirtualWorker(hook, id="client{}".format(client_no))}
    for client_no in range(1, args.clients + 1)
]

In [5]:
# # Download MNIST manually using 'wget', then uncompress the file (not needed for this CIFAR-10 notebook)
# !wget www.di.ens.fr/~lelarge/MNIST.tar.gz
# !tar -zxvf MNIST.tar.gz

In [6]:
# Load the data through the project-local FLDataset helper. It returns two global
# datasets plus two per-client groupings that the next cell indexes by client position.
# NOTE(review): presumably train_group/test_group hold sample indices into
# global_train/global_test -- confirm in FLDataset.load_dataset.
global_train, global_test, train_group, test_group = load_dataset(args.clients, args.iid)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Give every client its local train/test batches and its 'samples' weight.
for inx, client in enumerate(clients):
    trainset_ind_list = list(train_group[inx])
    client['trainset'] = getActualImgs(global_train, trainset_ind_list, args.local_batches)

    testset_ind_list = list(test_group[inx])
    client['testset'] = getActualImgs(global_test, testset_ind_list, args.local_batches)

    # Algebraically this equals 1 / args.clients, i.e. every client gets the same
    # weight regardless of shard size (an empty shard raises ZeroDivisionError).
    # The earlier variant was len(trainset_ind_list) / args.images.
    # NOTE(review): presumably read by averageModels -- confirm in utils.
    n_local = len(trainset_ind_list)
    client['samples'] = n_local / (n_local * args.clients)

In [8]:
# CIFAR-10 test split, used to score the averaged global model after each round.
# Each channel is normalised with mean 0.5 and std 0.5.
normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([transforms.ToTensor(), normalize])

global_test_dataset = datasets.CIFAR10('./', train=False, download=True, transform=transform)
global_test_loader = DataLoader(
    global_test_dataset,
    batch_size=args.local_batches,
    shuffle=True,
)

Files already downloaded and verified


In [9]:
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images to log-probabilities over 10 classes.

    Spatial size through the network for a 32x32 input:
    conv1 (5x5) -> 28, pool -> 14, conv2 (5x5) -> 10, pool -> 5,
    which yields the 16 * 5 * 5 features expected by ``fc1``.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)   # stateless, shared by both conv stages
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 10).

        Args:
            x: image batch of shape (batch, 3, 32, 32).

        Raises:
            RuntimeError: if the feature map after the second pooling stage
                does not hold exactly 16 * 5 * 5 values per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Pin the batch dimension. view(-1, 400) would silently fold a wrong-sized
        # feature map into extra "samples" instead of failing.
        x = x.view(x.size(0), 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Log-probabilities, to be paired with F.nll_loss.
        return F.log_softmax(x, dim=1)

In [10]:
def ClientUpdate(args, device, client):
    """Train one client's model in place on that client's own batches.

    Training runs locally: the PySyft send()/get() calls that would move the
    model and batches to the client's virtual worker are disabled in this notebook.

    Args:
        args: settings object; reads ``epochs``, ``log_interval`` and ``local_batches``.
        device: torch device each batch is moved to before the forward pass.
        client: dict with 'model', 'optim', 'trainset' (iterable of
            ``(data, target)`` batches) and 'hook' (its ``id`` labels the log lines).

    Returns:
        None. The weights of ``client['model']`` are updated as a side effect.
    """
    net = client['model']
    optimizer = client['optim']
    net.train()

    for epoch in range(1, args.epochs + 1):
        for batch_idx, (data, target) in enumerate(client['trainset']):
            data = data.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            loss = F.nll_loss(net(data), target)
            loss.backward()
            optimizer.step()

            # Only every log_interval-th batch is reported.
            if batch_idx % args.log_interval != 0:
                continue

            n_batches = len(client['trainset'])
            # Sample counts are estimates: batch count times batch size.
            print('Model {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                client['hook'].id,
                epoch,
                batch_idx * args.local_batches,
                n_batches * args.local_batches,
                100. * batch_idx / n_batches,
                loss))

In [11]:
def test(args, model, device, test_loader, name):
    """Evaluate ``model`` on ``test_loader``, then print and log the result.

    Args:
        args: unused; kept so the call signature matches the other helpers.
        model: network that returns log-probabilities per class.
        device: torch device each batch is moved to.
        test_loader: loader yielding ``(data, target)`` batches; its ``dataset``
            length is the denominator for both average loss and accuracy.
        name: label inserted into the printed summary.

    Side effects:
        Prints a summary line and appends ``"<accuracy>,"`` to the module-level
        ``writer`` file handle.
    """
    model.eval()
    loss_sum, n_correct = 0, 0

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            log_probs = model(data)
            # Summed (not averaged) so the final division is per sample.
            loss_sum += F.nll_loss(log_probs, target, reduction='sum').item()
            # Predicted class = index of the largest log-probability.
            pred = log_probs.argmax(1, keepdim=True)
            n_correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(test_loader.dataset)
    avg_loss = loss_sum / n_samples
    accuracy = 100. * n_correct / n_samples

    print('\nTest set: Average loss for {} model: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        name, avg_loss, n_correct, n_samples, accuracy))
    writer.write('{},'.format(accuracy))

In [12]:
def Aggregate(all_models, sample_client):
    """Build a fresh ``Net`` whose parameters are the plain mean of the clients' parameters.

    Args:
        all_models: list of client dicts, each with 'hook' (its ``id`` is the
            lookup key) and 'model'.
        sample_client: id of the client whose state dict seeds the running sum
            and supplies the parameter names; it must be present in ``all_models``.

    Returns:
        A new ``Net`` on the module-level ``device`` loaded with the averaged state.
    """
    state_by_id = {entry['hook'].id: entry['model'].state_dict() for entry in all_models}

    seed_state = state_by_id[sample_client]
    averaged = copy.deepcopy(seed_state)   # copy so no client's tensors are modified
    divisor = float(len(all_models))

    for param_name in seed_state:
        # Add every other client's tensor to the seed client's copy, in insertion order.
        for client_id, state in state_by_id.items():
            if client_id != sample_client:
                averaged[param_name] += state[param_name]
        averaged[param_name] /= divisor

    merged = Net().to(device)
    merged.load_state_dict(averaged)
    return merged

In [13]:
# Federated training driver: initialise the models, run args.rounds rounds of
# local training + averaging, and log accuracy and wall-clock time to `writer`.

torch.manual_seed(args.torch_seed)
# Keep the global model on the same device as the client models so that test()
# can feed it batches that were moved to `device`.
global_model = Net().to(device)
# Optional comparison models (disabled):
# torch.manual_seed(args.torch_seed)
# normal_avg = Net()
# torch.manual_seed(args.torch_seed)
# fake_model = Net()

# Re-seed before every client so all clients start from identical weights.
for client in clients:
    torch.manual_seed(args.torch_seed)
    client['model'] = Net().to(device)
    client['optim'] = optim.SGD(client['model'].parameters(), lr=args.lr)

start_time = time.time()

# try/finally guarantees the results file is flushed and closed even when a
# round fails, so accuracies that were already written are not lost.
try:
    for fed_round in range(args.rounds):

        # Uncomment to draw a random fraction C every round:
        # args.C = float(format(np.random.random(), '.1f'))

        # Number of selected clients (at least one).
        m = int(max(args.C * args.clients, 1))

        # Selected devices; seeding with the round number makes the draw repeatable.
        np.random.seed(fed_round)
        selected_clients_inds = np.random.choice(range(len(clients)), m, replace=False)
        selected_clients = [clients[i] for i in selected_clients_inds]

        # Active devices: the selected clients that do not drop out this round.
        # Rounding to 6 decimals before truncating removes float noise, e.g.
        # (1 - 0.9) * 10 == 0.9999999999999998, which plain int() would turn into 0.
        np.random.seed(fed_round)
        n_active = int(round((1 - args.drop_rate) * m, 6))
        active_clients_inds = np.random.choice(selected_clients_inds, n_active, replace=False)
        active_clients = [clients[i] for i in active_clients_inds]

        # Training
        for client in active_clients:
            ClientUpdate(args, device, client)

        # Optional per-client evaluation (disabled):
        # for client in active_clients:
        #     test(args, client['model'], device, client['testset'], client['hook'].id)

        # Averaging
        global_model = averageModels(global_model, active_clients)
        # Alternative aggregators (disabled):
        # normal_avg = normalModels(normal_avg, active_clients)
        # fake_model = Aggregate(active_clients, 'client1')

        # Testing the average model
        test(args, global_model, device, global_test_loader, 'Global')
        # test(args, normal_avg, device, global_test_loader, 'Global')
        # test(args, fake_model, device, global_test_loader, 'Global')

        # Share the global model with every client, including inactive ones.
        for client in clients:
            client['model'].load_state_dict(global_model.state_dict())

    # Close the accuracy row, then record the elapsed time as HH:MM:SS.ss.
    writer.write('\n')
    end_time = time.time()
    hours, rem = divmod(end_time - start_time, 3600)
    minutes, seconds = divmod(rem, 60)
    writer.write('{:0>2}:{:0>2}:{:05.2f}'.format(int(hours), int(minutes), seconds))
    writer.write('\n')
finally:
    writer.close()

if args.save_model:
    torch.save(global_model.state_dict(), "FedAvg.pt")

  current_tensor = hook_self.torch.native_tensor(*args, **kwargs)



Test set: Average loss for Global model: 2.3028, Accuracy: 1026/10000 (10%)




Test set: Average loss for Global model: 2.3005, Accuracy: 1083/10000 (11%)




Test set: Average loss for Global model: 2.2971, Accuracy: 1133/10000 (11%)


Test set: Average loss for Global model: 2.2894, Accuracy: 1407/10000 (14%)




Test set: Average loss for Global model: 2.2670, Accuracy: 1920/10000 (19%)




Test set: Average loss for Global model: 2.1725, Accuracy: 2212/10000 (22%)




Test set: Average loss for Global model: 2.0276, Accuracy: 2829/10000 (28%)


Test set: Average loss for Global model: 1.9661, Accuracy: 2907/10000 (29%)




Test set: Average loss for Global model: 1.9192, Accuracy: 3029/10000 (30%)




Test set: Average loss for Global model: 1.8790, Accuracy: 3142/10000 (31%)

