In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle
from PIL import Image
from sklearn.utils import shuffle


In [16]:
n_e = 2
batch_size = 32
T = 20
learning_rate = 0.001

In [17]:
def load_dataset():
    transforms_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,),(0.3081,))])
    mnist_data_train = datasets.MNIST('./data/mnist', train=True, download=True, transform=transforms_mnist)
    mnist_data_test = datasets.MNIST('./data/mnist', train=False, download=True, transform=transforms_mnist)

    return mnist_data_train, mnist_data_test

In [18]:
mnist_data_train, mnist_data_test = load_dataset()

In [19]:
train_set_array = mnist_data_train.data.numpy()
train_labels = mnist_data_train.targets.numpy()
test_set_array = mnist_data_test.data.numpy()

In [28]:

def shuffle_data(train_set_array, train_labels):
    
    local_data = np.zeros((2, train_set_array.shape[0], 28, 14))
    train_set_array, train_labels = shuffle(train_set_array, train_labels, random_state=0)
    
    for idx in range(train_set_array.shape[0]):
        img = train_set_array[idx]
        
        local_data[0,idx,:,:] = img[:, 0:14]
        local_data[1,idx,:,:] = img[:, 14:]
    
    return local_data, train_labels
    

In [21]:
class Local_Net(nn.Module):
    def __init__(self):
        super(Local_Net, self).__init__()
        self.fc1 = nn.Linear(28*10, 200)
#         self.fc2 = nn.Linear(200, 200)
        

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
        
        return x

In [22]:
class Global_Net(nn.Module):
    def __init__(self):
        super(Global_Net, self).__init__()
       
        self.fc1 = nn.Linear(32*320, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        output = self.fc2(x)
        return output

In [23]:
class Local_CNN_Net(nn.Module):
    def __init__(self):
        super(Local_CNN_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(2,2)
        self.dropout = nn.Dropout(p=0.2)
#         self.fc1 = nn.Linear(1024, 512)
#         self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
#         x = self.dropout(x)
        x = torch.flatten(x, 1)
#         print('x shape ' + str(x.shape))
#         x = F.relu(self.fc1(x))
        return x

In [30]:
class Global_CNN_Net(nn.Module):
    def __init__(self):
        super(Global_CNN_Net, self).__init__()
        
        self.fc1 = nn.Linear(1280, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        
        x = torch.flatten(x, 1)
#         print('shape after flatten ' + str(x.shape))
        x = F.relu(self.fc1(x))
        output = self.fc2(x)
        return x

In [25]:
class ClientUpdate(object):
    def __init__(self,learning_rate):
        self.learning_rate = learning_rate
       
    def calculate_output(self, model, data):
        
        output = model(data)
        
        return output
    
    def get_loss(self, y, target):
        loss_object = nn.CrossEntropyLoss()
        loss = loss_object(y_, target)
        
        return loss

In [31]:
model_list = []
optimizer_list = []
for i in range(n_e):
    local_net = Local_CNN_Net()
    model_list.append(local_net)    

In [32]:
num_examples = train_set_array.shape[0]
global_model = Global_CNN_Net()

for curr_round in tqdm(range(1, T+1)):
    
    local_data, train_labels = shuffle_data(train_set_array,train_labels )
    for offset in range(0,num_examples, batch_size):
        global_input = None
        end = offset + batch_size
        if end > num_examples:
            end = num_examples
            
        for i in range(n_e):
            
            input_ = torch.from_numpy(local_data[i,offset:end, :,:]).to(torch.float32)
            input_ = input_.unsqueeze(1)
            input_.requires_grad = True
#             print('input shape ' + str(input_.shape))
            out_ = model_list[i](input_)
            
            if global_input is None:
                global_input = out_
            else:
                global_input = torch.cat((global_input, out_), 1)

#             print(global_input.shape)
        y_pred = global_model(global_input)
        y_target = torch.from_numpy(train_labels[offset:end]).type(torch.LongTensor) 
        
        loss_object = nn.CrossEntropyLoss()
        loss = loss_object(y_pred, torch.squeeze(y_target))
        
        g_optimizer = torch.optim.Adam(global_model.parameters(), lr=learning_rate)
        
        loss.backward()
        
        g_optimizer.step()
        
        for i in range(n_e):
            
            l_optimizer = torch.optim.Adam(model_list[i].parameters(), lr=learning_rate)
            l_optimizer.step()
        
        if end == num_examples:
            
            predictions = y_pred.argmax(dim=1)  
            accuracy = predictions.eq(y_target.data.view_as(predictions)).sum().item()/ len(y_target)
            print('accuracy at round ' + str(curr_round) +' ' + str(accuracy))
        
        

  5%|████▏                                                                              | 1/20 [00:48<15:20, 48.45s/it]

accuracy at round 1 0.0625


 10%|████████▎                                                                          | 2/20 [01:34<14:07, 47.06s/it]

accuracy at round 2 0.125


 15%|████████████▍                                                                      | 3/20 [02:20<13:14, 46.74s/it]

accuracy at round 3 0.15625


 20%|████████████████▌                                                                  | 4/20 [03:08<12:30, 46.91s/it]

accuracy at round 4 0.03125


 25%|████████████████████▊                                                              | 5/20 [03:53<11:36, 46.40s/it]

accuracy at round 5 0.125


 30%|████████████████████████▉                                                          | 6/20 [04:40<10:51, 46.56s/it]

accuracy at round 6 0.125


 35%|█████████████████████████████                                                      | 7/20 [05:29<10:15, 47.38s/it]

accuracy at round 7 0.0625


 40%|█████████████████████████████████▏                                                 | 8/20 [06:19<09:40, 48.34s/it]

accuracy at round 8 0.0625


 45%|█████████████████████████████████████▎                                             | 9/20 [07:10<08:59, 49.08s/it]

accuracy at round 9 0.15625


 50%|█████████████████████████████████████████                                         | 10/20 [07:56<08:01, 48.17s/it]

accuracy at round 10 0.125


 55%|█████████████████████████████████████████████                                     | 11/20 [08:44<07:12, 48.06s/it]

accuracy at round 11 0.0625


 60%|█████████████████████████████████████████████████▏                                | 12/20 [09:33<06:25, 48.21s/it]

accuracy at round 12 0.15625


 65%|█████████████████████████████████████████████████████▎                            | 13/20 [10:23<05:43, 49.02s/it]

accuracy at round 13 0.125


 70%|█████████████████████████████████████████████████████████▍                        | 14/20 [11:17<05:01, 50.27s/it]

accuracy at round 14 0.15625


 75%|█████████████████████████████████████████████████████████████▌                    | 15/20 [12:05<04:07, 49.54s/it]

accuracy at round 15 0.125


 80%|█████████████████████████████████████████████████████████████████▌                | 16/20 [12:50<03:12, 48.25s/it]

accuracy at round 16 0.0625


 85%|█████████████████████████████████████████████████████████████████████▋            | 17/20 [13:35<02:21, 47.25s/it]

accuracy at round 17 0.0


 90%|█████████████████████████████████████████████████████████████████████████▊        | 18/20 [14:21<01:33, 46.85s/it]

accuracy at round 18 0.125


 95%|█████████████████████████████████████████████████████████████████████████████▉    | 19/20 [15:06<00:46, 46.50s/it]

accuracy at round 19 0.125


100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [15:51<00:00, 47.56s/it]

accuracy at round 20 0.0625



