In [2]:
import torch
import torchvision
import torchvision.transforms as T
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import sampler
import torchvision.datasets as dset
from torch.utils.data import DataLoader
import numpy as np
import copy
import random
import torch.optim as optim

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

manualSeed = 1

np.random.seed(manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
torch.cuda.manual_seed_all(manualSeed)
torch.backends.cudnn.enabled = False 
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

cpu


In [4]:
DEVICE_NUM = 100
DATASIZE_LOCAL = int(60000/DEVICE_NUM)
SERVER_NUM = 10
DEVICE_PER_SERVER = int(DEVICE_NUM/SERVER_NUM)
BATCH_SIZE = 32
STEP_NUM = 5
LABEL_DIVERSITY = 6
ACTIVE_PER_SERVER = 3

In [5]:
np.random.seed(seed=0)

# generate non-IID datasets stored on edge devices
trainset = torchvision.datasets.MNIST('.data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

testset = torchvision.datasets.MNIST('.data/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

trainloader = []
testloader_sub = []
for device_ID in range(DEVICE_NUM):
    label_set = random.sample(range(0, 10), LABEL_DIVERSITY)
    idx = trainset.targets.clone().detach() == label_set[0]
    for label_val in label_set[1:]:
        idx += trainset.targets.clone().detach() == label_val
    indx = np.random.permutation(np.where(idx==1)[0])[0:DATASIZE_LOCAL]
    trainset_indx = torch.utils.data.Subset(trainset, indx)
    trainloader.append(torch.utils.data.DataLoader(trainset_indx, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2))
    idx = testset.targets.clone().detach() == label_set[0]
    for label_val in label_set[1:]:
        idx += testset.targets.clone().detach() == label_val
    test_indx = torch.utils.data.Subset(testset, np.where(idx==1)[0])
    testloader_sub.append(torch.utils.data.DataLoader(test_indx, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=2))
    

testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=8)

In [6]:
def average_state_dic(para_set):
    para_copy = copy.deepcopy(para_set)
    N = float(len(para_copy))
    return { k : sum(t[k] for t in para_copy)/N for k in para_copy[0] }

In [7]:
input_size = 784
hidden_sizes = [128, 64]
output_size = 10

net = nn.Sequential(  nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.LogSoftmax(dim=1))
net_const = nn.Sequential(  nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.LogSoftmax(dim=1))

for p in net_const.parameters():
    p.requires_grad = False

if torch.cuda.device_count() != 0:
    global gpu_dtype
    gpu_dtype = torch.cuda.FloatTensor
    
    net.cuda()
    net = net.type(gpu_dtype)
    
    net_const.cuda()
    net_const = net_const.type(gpu_dtype)

In [8]:
para = net.state_dict()
para_set = []

# for each edge device, save a copy of the model
for i in range(DEVICE_NUM):
    para_set.append(copy.deepcopy(para))

In [9]:
np.random.seed(seed=1)
learning_rate = 0.005
num_epochs = 500
RHO = 5
runtime_record = 0
wake_up_time_server = np.zeros(SERVER_NUM)
print('FedProx - RHO: %.1f, training for %d epochs with learning rate %f' % (RHO, num_epochs, learning_rate))
criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                              momentum=0.9, weight_decay=0.0001)

for epoch in range(num_epochs):
    start_time = time.time()
    print('Starting epoch %d / %d' % (epoch+1, num_epochs))
    
    # generate the index set of selected edge devices
    ACTIVE_DEVICE = []
    for server_ID in range(SERVER_NUM):
        ACTIVE_DEVICE.append(random.sample(range(server_ID*DEVICE_PER_SERVER, (server_ID+1)*DEVICE_PER_SERVER), ACTIVE_PER_SERVER))
    ACTIVE_DEVICE = [item for sublist in ACTIVE_DEVICE for item in sublist]
    
    # the selected edge devices update their models (edge-device side)
    for device_ID in range(DEVICE_NUM):
        if device_ID in ACTIVE_DEVICE:
            stopping_iter = random.randint(1, STEP_NUM)
            net.load_state_dict(para_set[device_ID])
            net_const.load_state_dict(copy.deepcopy(para_set[device_ID]))
            z_n = list(net_const.parameters())
            iter_count = 0
            for i, data in enumerate(trainloader[device_ID], 0):
                if torch.cuda.device_count() != 0:
                    inputs, labels = data[0].cuda(), data[1].cuda()
                else:
                    inputs, labels = data[0], data[1]
                inputs = inputs.view(inputs.shape[0], -1)
                optimizer.zero_grad()
                outputs = net(inputs)
                datafitting = criterion(outputs, labels)
                penalty = None
                for (Ww, Zz) in zip(net.parameters(), z_n):
                    if penalty is None:
                        penalty = torch.norm(Ww-Zz)**2
                    else:
                        penalty = penalty + torch.norm(Ww-Zz) ** 2
                loss = datafitting + RHO * penalty
                loss.backward()
                optimizer.step()
                iter_count += 1
                if iter_count == stopping_iter:
                    break
            para_set[device_ID] = copy.deepcopy(net.state_dict())

    # aggregate the updated models from the selected edge devices (cloud-server side)
    para_update = average_state_dic([para_set[i] for i in ACTIVE_DEVICE])
    for i in range(DEVICE_NUM):
        para_set[i] = copy.deepcopy(para_update)
    
    # check the performance on the test dataset
    net.load_state_dict(para_update)
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            if torch.cuda.device_count() != 0:
                images, labels = data[0].cuda(), data[1].cuda()
            else:
                images, labels = data[0], data[1]
            images = images.view(images.shape[0], -1)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        print('[%d, %d] test accuracy: %.2f %%' %  (num_epochs + 1, epoch + 1, 100 * float(correct) / total))
        
#     correct = 0
#     total = 0
#     for decive_ID in range(DEVICE_NUM):
#         with torch.no_grad():
#             for i, data in enumerate(testloader_sub[decive_ID], 0):
#                 if torch.cuda.device_count() != 0:
#                     images, labels = data[0].cuda(), data[1].cuda()
#                 else:
#                     images, labels = data[0], data[1]
#                 images = images.view(images.shape[0], -1)
#                 outputs = net(images)
#                 _, predicted = torch.max(outputs.data, 1)
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum().item()
#     print('[%d, %d] average device test accuracy: %.2f %%' % (num_epochs + 1, epoch + 1, 100 * float(correct) / total))
    
#     correct = 0
#     total = 0
#     for decive_ID in ACTIVE_DEVICE:
#         with torch.no_grad():
#             for i, data in enumerate(testloader_sub[decive_ID], 0):
#                 if torch.cuda.device_count() != 0:
#                     images, labels = data[0].cuda(), data[1].cuda()
#                 else:
#                     images, labels = data[0], data[1]
#                 images = images.view(images.shape[0], -1)
#                 outputs = net(images)
#                 _, predicted = torch.max(outputs.data, 1)
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum().item()
#     print('[%d, %d] average selected device test accuracy: %.2f %%' % (num_epochs + 1, epoch + 1, 100 * float(correct) / total))
            
    print("--- %s seconds ---" % (time.time() - start_time))


FedProx - RHO: 0, training for 250 epochs with learning rate 0.005000
Starting epoch 1 / 250


ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/Users/Ryan/miniconda3/envs/ryanenv/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-9-2b60bba48160>", line 31, in <module>
    for i, data in enumerate(trainloader[device_ID], 0):
  File "/Users/Ryan/miniconda3/envs/ryanenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/Users/Ryan/miniconda3/envs/ryanenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 841, in _next_data
    idx, data = self._get_data()
  File "/Users/Ryan/miniconda3/envs/ryanenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 808, in _get_data
    success, data = self._try_get_data()
  File "/Users/Ryan/miniconda3/envs/ryanenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 761, in _try_get_data
    data = self._data_queue.get(timeout=tim

KeyboardInterrupt: 