In [None]:
import torch
import torchvision
import torchvision.transforms as T
import time
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import sampler
import torchvision.datasets as dset
from torch.utils.data import DataLoader
from model import ResNet
from torch.autograd import Variable
import numpy as np
import copy
import networkx as nx 
import torch.optim as optim
import random

import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [None]:
# Use the first visible CUDA device when one is available, otherwise the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

manualSeed = 1

# Seed NumPy, Python's `random`, and torch (CPU, current GPU, all GPUs) with
# the same value, in that order.
for seed_fn in (np.random.seed,
                random.seed,
                torch.manual_seed,
                torch.cuda.manual_seed,
                torch.cuda.manual_seed_all):
    seed_fn(manualSeed)

# cuDNN is switched off entirely; the remaining two flags request
# deterministic, non-autotuned kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False

In [None]:
# Training-time augmentation: random horizontal flip followed by a random
# 32x32 crop taken from the image padded by 4 pixels on each side.
transform_augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(size=32, padding=4),
])

# Convert to a tensor, then normalize each of the three channels with
# mean 0.5 and std 0.5.
transform_normalize = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [None]:
# --- Topology -------------------------------------------------------------
DEVICE_NUM = 100                                   # total number of simulated devices
SERVER_NUM = 10                                    # number of servers
DEVICE_PER_SERVER = DEVICE_NUM // SERVER_NUM       # devices attached to each server
ACTIVE_PER_SERVER = 3                              # devices sampled per server each epoch

# --- Local data -----------------------------------------------------------
DATASIZE_LOCAL = 50000 // DEVICE_NUM               # training samples held by each device
LABEL_DIVERSITY = 10                               # number of distinct labels drawn per device
BATCH_SIZE = 32

# --- Optimisation schedule --------------------------------------------------
STEP_NUM = 5                                       # upper bound on local SGD steps per round
CLOUD_STEP_NUM = 1                                 # server update passes per epoch

In [None]:
np.random.seed(seed=1)

# CIFAR-10 train split with augmentation + normalization; test split with
# normalization only. Both are downloaded into ./data if missing.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                 transform=T.Compose([transform_augment, transform_normalize]))

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_normalize)

# Convert the label lists to arrays once. Previously a fresh tensor was built
# from the full label list for every (device, label) pair and the boolean
# masks were merged with in-place `+=`.
train_targets = np.asarray(trainset.targets)
test_targets = np.asarray(testset.targets)

trainloader = []      # one shuffled training loader per device
testloader_sub = []   # one test loader per device, restricted to that device's labels
for device_ID in range(DEVICE_NUM):
    # Labels this device is allowed to see.
    label_set = random.sample(range(0, 10), LABEL_DIVERSITY)

    # Sorted indices of all training samples whose label is in label_set,
    # then a random subset of at most DATASIZE_LOCAL of them.
    train_candidates = np.flatnonzero(np.isin(train_targets, label_set))
    indx = np.random.permutation(train_candidates)[0:DATASIZE_LOCAL]
    trainset_indx = torch.utils.data.Subset(trainset, indx)
    trainloader.append(torch.utils.data.DataLoader(trainset_indx, batch_size=BATCH_SIZE,
                                                   shuffle=True, num_workers=2))

    # Every test sample whose label is in label_set.
    test_candidates = np.flatnonzero(np.isin(test_targets, label_set))
    test_indx = torch.utils.data.Subset(testset, test_candidates)
    testloader_sub.append(torch.utils.data.DataLoader(test_indx, batch_size=BATCH_SIZE,
                                                      shuffle=True, num_workers=2))

# Loader over the full test set, unshuffled, in batches of 10000.
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,
                                         shuffle=False, num_workers=8)

In [None]:
def avg_dict(para_set):
    """Return the key-wise arithmetic mean of a sequence of dicts.

    The keys of the result are those of ``para_set[0]``; every other dict in
    ``para_set`` must contain the same keys. The notebook passes lists of
    model ``state_dict`` copies here.

    The inputs are not modified: ``sum`` and ``/`` build new values, so no
    defensive copy is made (deep-copying every dict on each call was pure
    overhead and fails for values that do not support ``deepcopy``).

    Raises:
        IndexError: if ``para_set`` is empty.
        KeyError: if a dict lacks one of the keys of ``para_set[0]``.
    """
    count = float(len(para_set))
    return {key: sum(entry[key] for entry in para_set) / count
            for key in para_set[0]}

def weighted_dict(para_set, weight):
    """Return the key-wise weighted sum ``sum_i weight[i] * para_set[i][key]``.

    The keys of the result are those of ``para_set[0]``. ``weight`` is indexed
    by position, so it must have at least ``len(para_set)`` entries; extra
    entries are ignored.

    The inputs are not modified: each scaled value is a new object, so the
    previous deep copy followed by an in-place ``dict.update`` is unnecessary.

    Raises:
        IndexError: if ``para_set`` is empty or ``weight`` is too short.
    """
    return {
        key: sum(entry[key] * weight[pos] for pos, entry in enumerate(para_set))
        for key in para_set[0]
    }

def sub_dict(primal_set1, primal_set2):
    """Return ``{k: primal_set1[k] - primal_set2[k]}`` for every key of ``primal_set1``.

    Keys missing from ``primal_set2`` are treated as 0, so the corresponding
    value of ``primal_set1`` is carried over unchanged by the subtraction.
    Keys that exist only in ``primal_set2`` are ignored.

    Neither input is modified; subtraction yields new values, so the operands
    are no longer deep-copied first.
    """
    return {key: value - primal_set2.get(key, 0)
            for key, value in primal_set1.items()}

def mul_dict(primal_set, ratio):
    """Return a new dict with every value of ``primal_set`` multiplied by ``ratio``.

    ``primal_set`` is not modified; multiplication yields new values, so the
    input is no longer deep-copied first.
    """
    return {key: value * ratio for key, value in primal_set.items()}

def l2_reg_para(primal_set):
    """Return the sum of squared L2 norms of the tensors in ``primal_set``.

    The tensors are used directly rather than deep-copied. Copying detached
    the result from the caller's tensors (gradients were accumulated on the
    throw-away copies instead) and raised for non-leaf tensors, which do not
    support ``deepcopy``.

    Args:
        primal_set: non-empty iterable of tensors.

    Returns:
        A 0-dim tensor; it participates in autograd when the inputs do.
    """
    return torch.sum(torch.stack([torch.norm(x) ** 2 for x in primal_set]))

def sub_para(primal_set1, primal_set2):
    """Return the element-wise differences ``[a - b for a, b in zip(...)]``.

    Pairs are formed with ``zip``, so the result has the length of the shorter
    input. The operands are used directly rather than deep-copied: subtraction
    never mutates them, and copying cut the results off from the callers'
    tensors in autograd and raised for non-leaf tensors.
    """
    return [left - right for left, right in zip(primal_set1, primal_set2)]

In [None]:
# `net` is the trainable working model; `net_const` is a second instance built
# with the same constructor argument whose parameters are frozen.
net = ResNet(3)
net_const = ResNet(3)
for frozen_param in net_const.parameters():
    frozen_param.requires_grad = False

# When at least one CUDA device is visible, move both models to the GPU and
# cast them to the CUDA float tensor type.
if torch.cuda.device_count() > 0:
    gpu_dtype = torch.cuda.FloatTensor

    net.cuda()
    net = net.type(gpu_dtype)

    net_const.cuda()
    net_const = net_const.type(gpu_dtype)

In [None]:
# Snapshot of the freshly built model; every replica below starts from an
# independent deep copy of it.
para = net.state_dict()

# One parameter set per device.
para_device = [copy.deepcopy(para) for _ in range(DEVICE_NUM)]

# One parameter set per server, plus a same-shaped per-server buffer.
para_server = [copy.deepcopy(para) for _ in range(SERVER_NUM)]
mean_device2server = [copy.deepcopy(para) for _ in range(SERVER_NUM)]

In [None]:
np.random.seed(seed=1)
learning_rate = 0.005
num_epochs = 1000
criterion = nn.CrossEntropyLoss()
# NOTE(review): a single momentum optimizer is shared by all devices, so SGD
# momentum buffers carry over from one device's local steps to the next
# device's. Confirm this is intended.
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
RHO = 1      # weight of the quadratic penalty tying device weights to server weights
lrz = 0.1    # step size of the server update
print('(synchronization) Device number: %d, server number: %d, training for %d epochs with learning rate %f, RHO %f, lrz %f, cloud_step %d' % 
      (DEVICE_NUM, SERVER_NUM, num_epochs, learning_rate, RHO, lrz, CLOUD_STEP_NUM))

# Batches are moved to the GPU under the same condition used to move the models.
use_cuda = torch.cuda.device_count() != 0

runtime_record = 0
for epoch in range(num_epochs):
    print('Starting epoch %d / %d' % (epoch+1, num_epochs))
    start_time = time.time()
    # Simulated per-epoch runtime: max of 10 |N(1, 1)| draws.
    runtime_record += np.max(np.abs(np.random.normal(1, 1, 10)))

    # Sample ACTIVE_PER_SERVER devices from each server's contiguous ID range.
    ACTIVE_DEVICE = []
    for server_ID in range(SERVER_NUM):
        first_ID = server_ID * DEVICE_PER_SERVER
        ACTIVE_DEVICE.extend(random.sample(range(first_ID, first_ID + DEVICE_PER_SERVER), ACTIVE_PER_SERVER))

    # ---- Local training on the active devices ------------------------------
    for server_ID in range(SERVER_NUM):
        # load_state_dict copies values into the module's own tensors, so the
        # stored dict can be passed directly without a defensive deepcopy.
        net_const.load_state_dict(para_server[server_ID])
        z_n = list(net_const.parameters())
        for device_ind in range(DEVICE_PER_SERVER):
            device_ID = device_ind + server_ID * DEVICE_PER_SERVER
            if device_ID not in ACTIVE_DEVICE:
                continue
            # Each active device runs a random number of SGD steps in [1, STEP_NUM].
            stopping_iter = random.randint(1, STEP_NUM)
            net.load_state_dict(para_device[device_ID])
            iter_count = 0
            for i, data in enumerate(trainloader[device_ID], 0):
                if use_cuda:
                    inputs, labels = data[0].cuda(), data[1].cuda()
                else:
                    inputs, labels = data[0], data[1]
                optimizer.zero_grad()
                outputs = net(inputs)
                fitting_loss = criterion(outputs, labels)
                # Sum over parameters of ||W - Z||^2, with Z the server weights.
                penalty = None
                for (Ww, Zz) in zip(net.parameters(), z_n):
                    sq_dist = torch.norm(Ww - Zz) ** 2
                    penalty = sq_dist if penalty is None else penalty + sq_dist
                loss = fitting_loss + RHO * penalty
                loss.backward()
                optimizer.step()
                iter_count += 1
                if iter_count == stopping_iter:
                    break
            # state_dict() references the live parameters, so this copy is required.
            para_device[device_ID] = copy.deepcopy(net.state_dict())

    # ---- Server update -------------------------------------------------------
    # The mean over all device parameters does not change inside the loops
    # below, so it is computed once per epoch instead of once per server/step.
    device_mean = avg_dict(para_device)
    for server_epoch in range(CLOUD_STEP_NUM):
        for server_ID in range(SERVER_NUM):
            # server <- server - lrz * (server - mean(device parameters))
            drift = sub_dict(para_server[server_ID], device_mean)
            para_server[server_ID] = sub_dict(para_server[server_ID], mul_dict(drift, lrz))

    # ---- Evaluation of the averaged server model on the full test set -------
    # NOTE(review): net_const is never switched to eval() in this notebook; if
    # ResNet contains BatchNorm/Dropout layers they run in training mode here.
    # Confirm whether that is intended.
    correct = 0
    total = 0
    net_const.load_state_dict(avg_dict(para_server))
    with torch.no_grad():
        for data in testloader:
            if use_cuda:
                images, labels = data[0].cuda(), data[1].cuda()
            else:
                images, labels = data[0], data[1]
            outputs = net_const(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # NOTE(review): the first field printed is the constant num_epochs + 1.
    print('[%d, %d] average device test accuracy: %.2f %%, runtime: %.2f' % (num_epochs + 1, epoch + 1, 100 * float(correct) / total, runtime_record))

    print("--- %s seconds ---" % (time.time() - start_time))