In [None]:
import torch
import torchvision
import torchvision.transforms as T
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import sampler
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from model import ResNet
from torch.autograd import Variable
import numpy as np
import copy
import networkx as nx 
import torch.optim as optim
import random

import os

# Set before any CUDA call is made in this notebook, so that the CUDA runtime
# only sees the GPU with index 1 (presumably a second physical card on the
# author's machine — adjust for your own hardware).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# Pick the first visible GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# One seed shared by every random number generator used in this notebook.
manualSeed = 1

# Seed NumPy, Python's `random`, and torch (CPU, current GPU, all GPUs), in that order.
for seed_rng in (np.random.seed,
                 random.seed,
                 torch.manual_seed,
                 torch.cuda.manual_seed,
                 torch.cuda.manual_seed_all):
    seed_rng(manualSeed)

# cuDNN is switched off entirely; the benchmark/deterministic flags are still
# set so that runs stay reproducible if cuDNN is re-enabled later.
torch.backends.cudnn.enabled = False
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Augmentation applied to training images only: random horizontal flip, then a
# random 32x32 crop taken after padding each side by 4 pixels.
transform_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
])

# Applied to every image (train and test): convert to a tensor, then normalise
# each of the three channels with mean 0.5 and standard deviation 0.5.
transform_normalize = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# ---- topology ----
DEVICE_NUM = 100                              # number of simulated edge devices
SERVER_NUM = 10                               # number of servers the devices are grouped under
DEVICE_PER_SERVER = DEVICE_NUM // SERVER_NUM  # devices attached to each server
ACTIVE_PER_SERVER = 3                         # devices sampled per server in every epoch

# ---- local data ----
# 50000 samples split evenly across devices (presumably the CIFAR-10 training-set size).
DATASIZE_LOCAL = 50000 // DEVICE_NUM
LABEL_DIVERSITY = 5                           # number of distinct class labels each device draws from
BATCH_SIZE = 32                               # mini-batch size of the per-device data loaders

# ---- local training / communication ----
STEP_NUM = 5                                  # upper bound of the random number of local iterations
FAULT_RATE = 0.05                             # probability that a device's upload hits the fault branch
SYNCHRONIZATION_FLAG = 1                      # not referenced by the cells visible here — TODO confirm use
CLOUD_STEP_NUM = 10                           # not referenced by the cells visible here — TODO confirm use

In [None]:
# Build one training loader and one test loader per edge device. Each device
# only sees LABEL_DIVERSITY randomly chosen classes (non-IID split).
np.random.seed(seed=1)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True,
    transform=T.Compose([transform_augment, transform_normalize]))

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True,
    transform=transform_normalize)

# The label arrays never change, so convert them once instead of once per
# label per device.
train_targets = np.asarray(trainset.targets)
test_targets = np.asarray(testset.targets)

trainloader = []      # trainloader[d]    -> training batches of device d
testloader_sub = []   # testloader_sub[d] -> test samples restricted to device d's labels
for device_ID in range(DEVICE_NUM):
    # Classes this device is allowed to see (drawn from Python's `random`).
    label_set = random.sample(range(0, 10), LABEL_DIVERSITY)

    # Training subset: a random DATASIZE_LOCAL-sized pick among all training
    # samples whose label is in label_set (drawn from NumPy's global RNG).
    train_candidates = np.where(np.isin(train_targets, label_set))[0]
    indx = np.random.permutation(train_candidates)[0:DATASIZE_LOCAL]
    trainset_indx = torch.utils.data.Subset(trainset, indx)
    trainloader.append(torch.utils.data.DataLoader(
        trainset_indx, batch_size=BATCH_SIZE, shuffle=True, num_workers=2))

    # Test subset: every test sample whose label is in label_set.
    test_indx = torch.utils.data.Subset(
        testset, np.where(np.isin(test_targets, label_set))[0])
    testloader_sub.append(torch.utils.data.DataLoader(
        test_indx, batch_size=BATCH_SIZE, shuffle=True, num_workers=2))

# Loader over the complete test set, delivered in batches of 10000.
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=8)

In [None]:
def average_state_dic(para_set):
    """Return the key-wise mean of a collection of state dicts.

    Args:
        para_set: non-empty collection of mappings that all contain the keys
            of the first mapping. Values must support ``+`` and division by a
            float (e.g. tensors or plain numbers).

    Returns:
        A new dict mapping every key of the first entry to the arithmetic
        mean of that key's values across all entries.

    Raises:
        IndexError: if ``para_set`` is empty.
    """
    states = list(para_set)
    count = float(len(states))
    # For tensors and numbers, both the summation and the division produce new
    # objects, so the inputs are left untouched without deep-copying them first.
    return {key: sum(state[key] for state in states) / count for key in states[0]}

In [None]:
# `net` is the model that gets trained; `net_const` is a second instance of the
# same architecture whose parameters never receive gradients (the training
# cell loads reference weights into it for the proximal penalty).
net = ResNet(3)
net_const = ResNet(3)
for frozen_param in net_const.parameters():
    frozen_param.requires_grad = False

# Move both models to the GPU as float32 when at least one CUDA device exists.
if torch.cuda.device_count() > 0:
    gpu_dtype = torch.cuda.FloatTensor
    net = net.cuda().type(gpu_dtype)
    net_const = net_const.cuda().type(gpu_dtype)

In [None]:
# Snapshot of the model as it is right now. `net.state_dict()` returns tensors
# that reference the live parameters, so without the deep copy `para` would
# silently follow every later optimizer step / load_state_dict on `net`
# instead of staying a fixed snapshot.
para = copy.deepcopy(net.state_dict())

# For each edge device, keep an independent copy of the model.
para_set = [copy.deepcopy(para) for _ in range(DEVICE_NUM)]

In [None]:
def _to_compute_device(batch, use_cuda):
    """Unpack a loader batch into (inputs, labels), moving both to the GPU when use_cuda is true."""
    inputs, labels = batch[0], batch[1]
    if use_cuda:
        inputs, labels = inputs.cuda(), labels.cuda()
    return inputs, labels


def _count_correct(model, loader, use_cuda):
    """Return (correct, total) top-1 prediction counts of model over loader, with autograd disabled."""
    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            images, labels = _to_compute_device(batch, use_cuda)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct, total


# ---- hyper-parameters -------------------------------------------------------
np.random.seed(seed=1)
learning_rate = 0.005
num_epochs = 250
RHO = 5  # weight of the proximal penalty sum ||w - z||^2 added to the loss
print('FedProx - RHO: %.1f, training for %d epochs with learning rate %f' % (RHO, num_epochs, learning_rate))
criterion = nn.CrossEntropyLoss()
# NOTE(review): a single optimizer is shared by all devices, so its momentum
# state carries over from one device to the next — confirm this is intended.
optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                      momentum=0.9, weight_decay=0.0001)
# Loop-invariant: whether batches have to be moved to the GPU.
use_cuda = torch.cuda.device_count() != 0

for epoch in range(num_epochs):
    start_time = time.time()
    print('Starting epoch %d / %d' % (epoch+1, num_epochs))

    # generate the index set of selected edge devices:
    # ACTIVE_PER_SERVER devices drawn without replacement from each server's range
    ACTIVE_DEVICE = []
    for server_ID in range(SERVER_NUM):
        first_device = server_ID * DEVICE_PER_SERVER
        ACTIVE_DEVICE.extend(random.sample(
            range(first_device, first_device + DEVICE_PER_SERVER), ACTIVE_PER_SERVER))

    # the selected edge devices update their models (edge-device side);
    # visited in ascending ID order so the sequence of random draws below is
    # the same as when scanning every device ID and skipping inactive ones
    for device_ID in sorted(ACTIVE_DEVICE):
        # number of local mini-batch steps this device performs in this round
        stopping_iter = random.randint(1, STEP_NUM)

        # both models start from the device's current parameters;
        # load_state_dict copies the values, so no extra deep copy is needed
        net.load_state_dict(para_set[device_ID])
        net_const.load_state_dict(para_set[device_ID])
        z_n = list(net_const.parameters())  # frozen reference weights for the penalty

        for iter_count, batch in enumerate(trainloader[device_ID], 1):
            inputs, labels = _to_compute_device(batch, use_cuda)
            optimizer.zero_grad()
            outputs = net(inputs)
            datafitting = criterion(outputs, labels)
            # proximal term: squared distance between current and reference weights
            penalty = sum(torch.norm(Ww - Zz) ** 2
                          for (Ww, Zz) in zip(net.parameters(), z_n))
            loss = datafitting + RHO * penalty
            loss.backward()
            optimizer.step()
            if iter_count == stopping_iter:
                break

        if np.random.binomial(1, FAULT_RATE) == 1:
            print('transmission fault!')
            # NOTE(review): `para` comes from an earlier cell. If it is the
            # object returned directly by net.state_dict(), it references the
            # live weights and this branch stores the same values as the else
            # branch — confirm which model a faulty upload should deliver.
            para_set[device_ID] = copy.deepcopy(para)
        else:
            para_set[device_ID] = copy.deepcopy(net.state_dict())

    # aggregate the updated models from the selected edge devices (cloud-server side)
    para_update = average_state_dic([para_set[i] for i in ACTIVE_DEVICE])
    # broadcast: every device (active or not) receives the averaged model
    for i in range(DEVICE_NUM):
        para_set[i] = copy.deepcopy(para_update)

    # check the performance on the test dataset
    # NOTE(review): net is never switched to eval() in this cell; if ResNet
    # contains BatchNorm/Dropout layers the evaluation runs in training mode — confirm.
    net.load_state_dict(para_update)
    correct, total = _count_correct(net, testloader, use_cuda)
    # NOTE(review): the first field always prints num_epochs + 1 — confirm intended.
    print('[%d, %d] test accuracy: %.2f %%' % (num_epochs + 1, epoch + 1, 100 * float(correct) / total))

    # Not used further in this cell, but the draw advances Python's `random`
    # state; it is kept so later device sampling stays unchanged.
    tested_device = random.randint(0, DEVICE_NUM-1)

    # accuracy of the aggregated model on every device's label-restricted
    # test subset, pooled over all devices
    correct = 0
    total = 0
    for device_ID in range(DEVICE_NUM):
        device_correct, device_total = _count_correct(net, testloader_sub[device_ID], use_cuda)
        correct += device_correct
        total += device_total
    print('[%d, %d] average device test accuracy: %.2f %%' % (num_epochs + 1, epoch + 1, 100 * float(correct) / total))

    print("--- %s seconds ---" % (time.time() - start_time))
