<a href="https://colab.research.google.com/github/Prathulyan/Federated-Learning/blob/main/FedSGD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install syft==0.2.9

Collecting syft==0.2.9
  Downloading syft-0.2.9-py3-none-any.whl (433 kB)
[?25l[K     |▊                               | 10 kB 25.0 MB/s eta 0:00:01[K     |█▌                              | 20 kB 23.8 MB/s eta 0:00:01[K     |██▎                             | 30 kB 12.7 MB/s eta 0:00:01[K     |███                             | 40 kB 9.6 MB/s eta 0:00:01[K     |███▊                            | 51 kB 5.2 MB/s eta 0:00:01[K     |████▌                           | 61 kB 5.7 MB/s eta 0:00:01[K     |█████▎                          | 71 kB 5.5 MB/s eta 0:00:01[K     |██████                          | 81 kB 6.2 MB/s eta 0:00:01[K     |██████▉                         | 92 kB 6.3 MB/s eta 0:00:01[K     |███████▌                        | 102 kB 5.0 MB/s eta 0:00:01[K     |████████▎                       | 112 kB 5.0 MB/s eta 0:00:01[K     |█████████                       | 122 kB 5.0 MB/s eta 0:00:01[K     |█████████▉                      | 133 kB 5.0 MB/s eta 0:00:01[

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import syft as sy
import copy
import numpy as np

import importlib
importlib.import_module('FLDataset')
from FLDataset import load_dataset, getActualImgs
from utils import averageModels, averageGradients

In [15]:
class Arguments():
    """Experiment configuration for the federated training run.

    Every setting can be overridden with a keyword argument; ``Arguments()``
    with no arguments yields the same values the notebook originally
    hard-coded.

    Attributes:
        images: total number of training images the splits are computed from.
        clients: number of federated clients.
        epochs: number of training rounds.
        local_batches: ``images // clients``; passed on as the batch size.
        lr: learning rate handed to the optimizers.
        torch_seed: value given to ``torch.manual_seed``.
        log_interval: ``train`` logs every ``log_interval`` batches.
        iid: partitioning mode string forwarded to ``load_dataset``.
        split_size: number of images per client (``images // clients``).
        samples: fraction of all images held by one client.
        use_cuda: request CUDA (still subject to availability).
        save_model: whether the final global model is written to disk.
    """

    def __init__(self, images=60000, clients=3, epochs=3, lr=0.01,
                 torch_seed=0, log_interval=10, iid='iid',
                 use_cuda=False, save_model=False):
        self.images = images
        self.clients = clients
        self.epochs = epochs
        # Integer division keeps the split exact; int(images / clients)
        # went through a float first.
        self.split_size = images // clients
        self.local_batches = self.split_size
        self.lr = lr
        self.torch_seed = torch_seed
        self.log_interval = log_interval
        self.iid = iid
        self.samples = self.split_size / images
        self.use_cuda = use_cuda
        self.save_model = save_model

args = Arguments()

# The GPU is used only when it was requested in the config *and* CUDA is present.
use_cuda = args.use_cuda and torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # Loader options for the GPU case; no visible cell passes `kwargs` on.
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

In [5]:
# Hook torch once, then create one virtual worker per client ("client1", "client2", ...).
# Note that each dict's 'hook' entry holds the client's VirtualWorker, not the TorchHook.
hook = sy.TorchHook(torch)
clients = [
    {'hook': sy.VirtualWorker(hook, id="client{}".format(number))}
    for number in range(1, args.clients + 1)
]

In [6]:
# Download MNIST manually using 'wget' then uncompress the file
!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

--2021-07-30 11:44:13--  http://www.di.ens.fr/~lelarge/MNIST.tar.gz
Resolving www.di.ens.fr (www.di.ens.fr)... 129.199.99.14
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://www.di.ens.fr/~lelarge/MNIST.tar.gz [following]
--2021-07-30 11:44:14--  https://www.di.ens.fr/~lelarge/MNIST.tar.gz
Connecting to www.di.ens.fr (www.di.ens.fr)|129.199.99.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/x-gzip]
Saving to: ‘MNIST.tar.gz’

MNIST.tar.gz            [          <=>       ]  33.20M  16.8MB/s    in 2.0s    

2021-07-30 11:44:16 (16.8 MB/s) - ‘MNIST.tar.gz’ saved [34813078]

MNIST/
MNIST/raw/
MNIST/raw/train-labels-idx1-ubyte
MNIST/raw/t10k-labels-idx1-ubyte.gz
MNIST/raw/t10k-labels-idx1-ubyte
MNIST/raw/t10k-images-idx3-ubyte.gz
MNIST/raw/train-images-idx3-ubyte
MNIST/raw/train-labels-idx1-ubyte.gz
MNIST/raw/t10k-images-idx3-ubyte
MNIST/raw/tra

In [7]:
# Build the global train/test datasets and the per-client groups with the project helper.
# args.iid ('iid' here) is forwarded as the partitioning mode — presumably IID vs. non-IID;
# confirm in FLDataset. train_group / test_group are indexed by client position below.
global_train, global_test, train_group, test_group = load_dataset(args.clients, args.iid)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!


In [8]:
# Attach each client's local train/test data and its share of the training images.
for inx, client in enumerate(clients):
    # Entries of train_group[inx] are presumably sample indices into global_train — confirm in FLDataset.
    trainset_ind_list = list(train_group[inx])
    # getActualImgs is a project helper; args.local_batches is passed as its third
    # argument (presumably the batch size — confirm in FLDataset).
    client['trainset'] = getActualImgs(global_train, trainset_ind_list, args.local_batches)
    client['testset'] = getActualImgs(global_test, list(test_group[inx]), args.local_batches)
    # Fraction of args.images that this client holds.
    client['samples'] = len(trainset_ind_list) / args.images

In [9]:
# Loader over the whole global test set; it is what the global model is evaluated on.
# NOTE(review): batch_size reuses args.local_batches (images // clients), and shuffle=True
# is kept as-is even though test() only accumulates totals over the batches.
global_test_loader = DataLoader(global_test, batch_size=args.local_batches, shuffle=True)

In [10]:
class Net(nn.Module):
    """Small CNN: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then two linear layers.

    The flatten size ``4*4*50`` means the feature map after the second pooling
    stage must be 4x4 with 50 channels, which is what a 1x28x28 input produces.
    ``forward`` returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the original order (conv1, conv2, fc1, fc2) so
        # that seeded weight initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities."""
        features = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        features = F.max_pool2d(F.relu(self.conv2(features)), kernel_size=2, stride=2)
        flat = features.view(-1, 4 * 4 * 50)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=1)

In [11]:
def train(args, clientss, device, epoch):
    """Run one pass over a client's batches, leaving gradients on its model.

    For every batch the data and target are sent to the client's worker, the
    model is sent to the same location, a forward/backward pass is run, and
    the model is fetched back with ``get()``.  No optimizer step is taken, and
    gradients are zeroed at the start of each batch, so only the gradients of
    the last batch remain when the function returns.

    Args:
        args: configuration object; ``log_interval`` and ``local_batches`` are read.
        clientss: client dict with the keys 'hook', 'model', 'optim' and 'trainset'.
        device: torch device the batch tensors are moved to.
        epoch: epoch number, used only in the progress line.
    """
    # Bug fix: the body used to read a *global* named `client` and ignored this
    # parameter, so it only worked when the caller's loop variable had that name.
    client = clientss

    client['model'].train()
    for batch_idx, (data, target) in enumerate(client['trainset']):
        # Move the batch to the client's worker and the model to wherever the data lives.
        data = data.send(client['hook'])
        target = target.send(client['hook'])
        client['model'].send(data.location)

        data, target = data.to(device), target.to(device)
        client['optim'].zero_grad()
        output = client['model'](data)
        loss = F.nll_loss(output, target)
        loss.backward()
        # Deliberately no optimizer step here: this function only produces gradients.
        client['model'].get()  # bring the model back from the worker

        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # the loss lives on the worker until fetched
            total_batches = len(client['trainset'])
            print('Model {} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                client['hook'].id,
                epoch, batch_idx * args.local_batches, total_batches * args.local_batches,
                100. * batch_idx / total_batches, loss.item()))

In [12]:
def test(args, model, device, test_loader, name):
    model.eval()   
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(1, keepdim=True) # get the index of the max log-probability 
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss for {} model: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        name, test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [13]:
class FedSGDOptim(optim.Optimizer):
    """SGD-style optimizer whose gradients can come from another model.

    ``step(grad_model)`` updates each optimised parameter with the gradient
    stored on the positionally matching parameter of ``grad_model``:
    ``param <- param - lr * grad``.

    Args:
        params: iterable of parameters (or param-group dicts) to optimise.
        lr: learning rate.  ``None`` (the default) resolves to ``args.lr`` when
            the optimizer is constructed, so the class definition itself does
            not capture a stale value from an earlier notebook run.
    """

    def __init__(self, params, lr=None):
        if lr is None:
            lr = args.lr
        defaults = dict(lr=lr)
        super(FedSGDOptim, self).__init__(params, defaults)

    def step(self, grad_model=None, closure=None):
        """Apply one update.

        Args:
            grad_model: module whose parameters carry the gradients to apply,
                matched by position with the optimised parameters.  When
                ``None`` the optimised parameters' own ``.grad`` is used.
            closure: optional callable re-evaluating the model; its result is
                returned.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        # One iterator shared by all param groups, so a second group continues
        # where the first stopped instead of restarting at the first parameter.
        grad_sources = None if grad_model is None else iter(grad_model.parameters())

        for group in self.param_groups:
            lr = group['lr']
            params = group['params']
            sources = params if grad_sources is None else grad_sources
            for param, source in zip(params, sources):
                # Guard on the tensor the gradient is actually read from; the
                # old check looked at the target's grad and then dereferenced
                # the source's, which could crash or skip valid updates.
                if source.grad is None:
                    continue
                # add_(tensor, alpha=...) replaces the deprecated add_(scalar, tensor) form.
                param.data.add_(source.grad.data, alpha=-lr)

        return loss

In [16]:
# Global model and the optimizer that applies the averaged gradients to it.
# (The throw-away `grad_model = Net()` that used to be created here was always
# overwritten by averageGradients() before its first use, so it is gone.)
torch.manual_seed(args.torch_seed)
global_model = Net().to(device)
optimizer = FedSGDOptim(global_model.parameters(), lr=args.lr)

# Re-seeding right before each construction gives every client model the same
# initial weights as the global model.
for client in clients:
    torch.manual_seed(args.torch_seed)
    client['model'] = Net().to(device)
    client['optim'] = optim.SGD(client['model'].parameters(), lr=args.lr)

for epoch in range(1, args.epochs + 1):

    # Each client computes gradients on its local data.
    for client in clients:
        train(args, client, device, epoch)

    # averageGradients is a project helper from utils; its result is handed to
    # the optimizer as the gradient source (presumably a model carrying the
    # averaged client gradients — confirm in utils).
    grad_model = averageGradients(global_model, clients)

    # Evaluate the global model immediately before and after the update.
    test(args, global_model, device, global_test_loader, 'Global')
    optimizer.step(grad_model)
    test(args, global_model, device, global_test_loader, 'Global')

    # Share the updated global weights with every client.
    for client in clients:
        client['model'].load_state_dict(global_model.state_dict())

if args.save_model:
    torch.save(global_model.state_dict(), "FedSGD.pt")

  current_tensor = hook_self.torch.native_tensor(*args, **kwargs)



Test set: Average loss for Global model: 2.3129, Accuracy: 1004/10000 (10%)


Test set: Average loss for Global model: 2.3083, Accuracy: 1023/10000 (10%)


Test set: Average loss for Global model: 2.3083, Accuracy: 1023/10000 (10%)


Test set: Average loss for Global model: 2.3038, Accuracy: 1047/10000 (10%)


Test set: Average loss for Global model: 2.3038, Accuracy: 1047/10000 (10%)


Test set: Average loss for Global model: 2.2993, Accuracy: 1092/10000 (11%)

