In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import os
import random
from torch.autograd import Variable
import copy
from torch import nn, optim
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from collections import OrderedDict
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import csv
import time


In [2]:
def fix_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Python's built-in generator
    random.seed(seed)
    # NumPy
    np.random.seed(seed)
    # PyTorch (CPU and every CUDA device)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different convolution algorithms from
    # run to run, so it has to be disabled as well for reproducible results.
    torch.backends.cudnn.benchmark = False

SEED = 42
fix_seed(SEED)

In [3]:
class Argments():
  """Hyper-parameters and runtime settings shared by the whole notebook."""
  def __init__(self):
    self.batch_size = 20            # batch size of the training DataLoaders
    self.test_batch = 1000          # batch size of the validation/test DataLoaders
    self.global_epochs = 300        # epoch budget passed to global_train
    self.local_epochs = 2           # epochs each worker runs in Worker.local_train
    self.lr = None                  # filled in later from lr_list
    self.momentum = 0.9             # SGD momentum
    self.weight_decay = 10**-4.0    # SGD weight decay
    self.clip = 20.0                # max gradient norm for clip_grad_norm_
    self.partience = 10             # early-stopping patience (spelling kept for callers)
    self.worker_num = 20            # number of workers sampled from the full federation
    self.sample_num = 20            # workers drawn by Server.sample_worker
    # Use the first GPU when available, otherwise fall back to the CPU.
    self.device = torch.device('cuda:0'if torch.cuda.is_available() else'cpu')
    self.criterion = nn.CrossEntropyLoss()

args = Argments()

In [4]:
# Index into `lr_list` selecting which candidate learning rate is assigned to
# `args.lr` (this is a list position, not a learning-rate value).
lr = 0

In [6]:
# Candidate learning rates: powers of ten with exponents from -3.0 up to 0.5
# in steps of 0.5.
lr_list = [
    10**exponent
    for exponent in (-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5)
]

# `lr` is the list index chosen earlier in the notebook.
args.lr = lr_list[lr]

In [9]:
class LocalDataset(torch.utils.data.Dataset):
  """In-memory dataset holding one worker's samples as PyTorch tensors.

  Args:
    dataset: Sized collection exposing ``take(n)`` whose items are mappings
      with ``'pixels'`` and ``'label'`` entries, each providing ``.numpy()``
      (presumably a TensorFlow-style client dataset -- TODO confirm).
    worker_id: Identifier stored on the instance as ``id``.
  """
  def __init__(self,dataset,worker_id):
    self.data = []
    self.target = []
    self.id = worker_id
    for data in dataset.take(len(dataset)):
      # Add the leading axis on the NumPy side: building a tensor from a
      # Python list of ndarrays is a slow path, while the resulting shape
      # (1, ...) and dtype are the same either way.
      pixels = data['pixels'].numpy()
      self.data.append(torch.tensor(pixels[np.newaxis]))
      # Labels are stored as int64 scalars.
      self.target.append(torch.tensor(data['label'].numpy().astype(np.int64)))

  def __getitem__(self, index):
    """Return the ``(pixels, label)`` tensor pair stored at ``index``."""
    return self.data[index],self.target[index]

  def __len__(self):
    """Return the number of stored samples."""
    return len(self.data)

In [10]:
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at ``path``."""
    # NOTE(review): unpickling can execute arbitrary code -- only load files
    # from a trusted source.
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# One dataset per worker for the train and test splits respectively.
all_federated_trainset = _load_pickle('../data/federated_trainset_femnist.pickle')
all_federated_testset = _load_pickle('../data/federated_testset_femnist.pickle')
all_worker_num = len(all_federated_trainset)

In [11]:
# Draw `args.worker_num` distinct worker indices without replacement.
worker_id_list = random.sample(range(all_worker_num),args.worker_num)
print(worker_id_list)

# Keep only the sampled workers' train/test sets, in the order they were drawn.
federated_trainset = [all_federated_trainset[worker_id] for worker_id in worker_id_list]
federated_testset = [all_federated_testset[worker_id] for worker_id in worker_id_list]

[2619, 456, 102, 3037, 1126, 1003, 914, 571, 3016, 419, 2771, 3033, 2233, 356, 2418, 1728, 130, 122, 383, 895]


In [12]:
# Carve a validation set out of every worker's training data: 70% stays in
# the training set (rounded down) and the remainder becomes validation data.
federated_valset = [None]*args.worker_num
for i in range(args.worker_num):
  worker_data = federated_trainset[i]
  n_samples = len(worker_data)
  if n_samples!=1:
    train_size = int(n_samples * 0.7)
    val_size = n_samples - train_size
    split = torch.utils.data.random_split(worker_data, [train_size, val_size])
    federated_trainset[i],federated_valset[i] = split
  else:
    # A single sample cannot be split, so it is reused as the validation set.
    federated_valset[i] = copy.deepcopy(worker_data)

In [13]:
class GlobalDataset(torch.utils.data.Dataset):
  """Flat dataset made by concatenating several per-worker datasets.

  Args:
    federated_dataset: Iterable of datasets, each yielding ``(data, target)``
      pairs. Samples are stored in iteration order.
  """
  def __init__(self,federated_dataset):
    # Flatten worker by worker first, then separate inputs from targets.
    samples = [sample for dataset in federated_dataset for sample in dataset]
    self.data = [features for features, _ in samples]
    self.target = [label for _, label in samples]

  def __getitem__(self, index):
    """Return the ``(data, target)`` pair at ``index``."""
    return self.data[index],self.target[index]

  def __len__(self):
    """Return the total number of pooled samples."""
    return len(self.data)

In [14]:
# Pool the per-worker train/validation/test splits into three flat datasets
# (built in that order).
global_trainset, global_valset, global_testset = (
    GlobalDataset(split)
    for split in (federated_trainset, federated_valset, federated_testset)
)

In [15]:
# Training loader: small shuffled mini-batches.
global_trainloader = torch.utils.data.DataLoader(
    global_trainset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
)

# Validation/test loaders: large batches in a fixed order.
_eval_loader_options = dict(batch_size=args.test_batch, shuffle=False, num_workers=2)
global_valloader = torch.utils.data.DataLoader(global_valset, **_eval_loader_options)
global_testloader = torch.utils.data.DataLoader(global_testset, **_eval_loader_options)

In [17]:
class CNN1(nn.Module):
    """CNN with a single 3x3 convolution and a two-layer classifier head.

    ``linear_1`` expects 5408 (= 32 * 13 * 13) features, which is consistent
    with single-channel 28x28 inputs. The network returns 62 logits.
    """
    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it determines both the
        # state_dict layout and the order in which seeded weights are drawn.
        self.conv2d_1 = nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(5408, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images to a ``(batch, 62)`` tensor of logits."""
        # NOTE(review): no non-linearity follows the convolution -- confirm
        # that this is intended.
        features = self.max_pooling(self.conv2d_1(x))
        features = self.flatten(self.dropout_1(features))
        hidden = self.relu(self.linear_1(features))
        return self.linear_2(self.dropout_2(hidden))

class CNN2(nn.Module):
    """CNN with two stacked 3x3 convolutions and a two-layer classifier head.

    ``linear_1`` expects 9216 (= 64 * 12 * 12) features, which is consistent
    with single-channel 28x28 inputs. The network returns 62 logits.
    """
    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it determines both the
        # state_dict layout and the order in which seeded weights are drawn.
        self.conv2d_1 = nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(9216, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images to a ``(batch, 62)`` tensor of logits."""
        # NOTE(review): the convolutions are applied back to back with no
        # non-linearity in between -- confirm that this is intended.
        features = self.conv2d_2(self.conv2d_1(x))
        features = self.dropout_1(self.max_pooling(features))
        hidden = self.relu(self.linear_1(self.flatten(features)))
        return self.linear_2(self.dropout_2(hidden))

class CNN3(nn.Module):
    """CNN with three stacked 3x3 convolutions and a two-layer classifier head.

    ``linear_1`` expects 15488 (= 128 * 11 * 11) features, which is consistent
    with single-channel 28x28 inputs. The network returns 62 logits.
    """
    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it determines both the
        # state_dict layout and the order in which seeded weights are drawn.
        self.conv2d_1 = nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2d_3 = nn.Conv2d(64, 128, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(15488, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images to a ``(batch, 62)`` tensor of logits."""
        # NOTE(review): the convolutions are applied back to back with no
        # non-linearity in between -- confirm that this is intended.
        features = x
        for conv in (self.conv2d_1, self.conv2d_2, self.conv2d_3):
            features = conv(features)
        features = self.dropout_1(self.max_pooling(features))
        hidden = self.relu(self.linear_1(self.flatten(features)))
        return self.linear_2(self.dropout_2(hidden))
    
class CNN4(nn.Module):
    """CNN with four stacked 3x3 convolutions and a two-layer classifier head.

    ``linear_1`` expects 25600 (= 256 * 10 * 10) features, which is consistent
    with single-channel 28x28 inputs. The network returns 62 logits.
    """
    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it determines both the
        # state_dict layout and the order in which seeded weights are drawn.
        self.conv2d_1 = nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2d_3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv2d_4 = nn.Conv2d(128, 256, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(25600, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of images to a ``(batch, 62)`` tensor of logits."""
        # NOTE(review): the convolutions are applied back to back with no
        # non-linearity in between -- confirm that this is intended.
        features = x
        for conv in (self.conv2d_1, self.conv2d_2, self.conv2d_3, self.conv2d_4):
            features = conv(features)
        features = self.dropout_1(self.max_pooling(features))
        hidden = self.relu(self.linear_1(self.flatten(features)))
        return self.linear_2(self.dropout_2(hidden))

In [18]:
class Server():
  """Federated-learning coordinator that owns the global model.

  It creates workers, samples a subset of them, hands each one a copy of the
  global model and merges their trained weights back as a weighted sum.
  """
  def __init__(self):
    self.model = CNN2()

  def create_worker(self,federated_trainset,federated_valset,federated_testset):
    """Build one ``Worker`` per index below ``args.worker_num`` from the three index-aligned lists."""
    return [
        Worker(federated_trainset[idx],federated_valset[idx],federated_testset[idx])
        for idx in range(args.worker_num)
    ]

  def sample_worker(self,workers):
    """Return ``args.sample_num`` distinct workers drawn at random."""
    chosen_indices = random.sample(range(args.worker_num),args.sample_num)
    return [workers[idx] for idx in chosen_indices]


  def send_model(self,workers):
    """Give every worker a private copy of the global model and its aggregation weight.

    The weight is the worker's share of the training samples held by the
    given ``workers``.
    """
    total_samples = sum(worker.train_data_num for worker in workers)

    for worker in workers:
      worker.aggregation_weight = 1.0*worker.train_data_num/total_samples
      # Deep copy so that local training never touches the global weights.
      worker.model = copy.deepcopy(self.model).to(args.device)

  def aggregate_model(self,workers):
    """Load the weighted sum of the workers' parameters into the global model.

    Each worker's model is moved to the CPU and removed from the worker
    afterwards.
    """
    new_params = OrderedDict()
    for position,worker in enumerate(workers):
      weight = worker.aggregation_weight
      for key,value in worker.model.state_dict().items():
        contribution = value*weight
        if position==0:
          # The first worker initialises the accumulator for every key.
          new_params[key] = contribution
        else:
          new_params[key] += contribution
      worker.model = worker.model.to('cpu')
      del worker.model
    self.model.load_state_dict(new_params)

In [19]:
class Worker():
  """A federated client holding its own data loaders and a local model slot."""
  def __init__(self,trainset,valset,testset):
    make_loader = torch.utils.data.DataLoader
    # Training batches are small and shuffled; evaluation batches are large
    # and served in a fixed order.
    self.trainloader = make_loader(trainset,batch_size=args.batch_size,shuffle=True,num_workers=2)
    eval_options = dict(batch_size=args.test_batch,shuffle=False,num_workers=2)
    self.valloader = make_loader(valset,**eval_options)
    self.testloader = make_loader(testset,**eval_options)
    self.model = None  # assigned from outside before training
    self.train_data_num = len(trainset)
    self.test_data_num = len(testset)
    self.aggregation_weight = None  # assigned from outside

  def local_train(self):
    """Train the local model, then score it on the validation loader.

    Returns:
      Tuple ``(acc_train, loss_train, acc_valid, loss_valid)``.
    """
    train_acc,train_loss = local_train(self.model,args.criterion,self.trainloader,args.local_epochs)
    valid_acc,valid_loss = test(self.model,args.criterion,self.valloader)
    return train_acc,train_loss,valid_acc,valid_loss

    

In [20]:
def local_train(model,criterion,trainloader,epochs):
  """Train ``model`` with SGD on one loader and report the last epoch's metrics.

  A fresh SGD optimizer is built from ``args`` (lr, momentum, weight decay)
  and gradients are clipped to a norm of ``args.clip`` before every step.

  Args:
    model: Module to train in place; expected to live on ``args.device``.
    criterion: Loss function called as ``criterion(outputs, labels)``.
    trainloader: Iterable of ``(data, labels)`` batches.
    epochs: Number of passes over ``trainloader``; must be at least 1.

  Returns:
    Tuple ``(accuracy, loss)`` for the final epoch: accuracy in percent and
    the mean of the per-batch losses.

  Raises:
    ValueError: If ``epochs`` is smaller than 1 (no metrics would exist).
  """
  if epochs < 1:
    # Previously this surfaced as an UnboundLocalError at the return statement.
    raise ValueError('epochs must be at least 1, got {}'.format(epochs))

  optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
  model.train()
  for epoch in range(epochs):
    # Counters restart every epoch, so only the last epoch is reported.
    running_loss = 0.0
    correct = 0
    count = 0
    for (data,labels) in trainloader:
      # Tensors are used directly: the torch.autograd.Variable wrapper is a
      # deprecated no-op.
      data,labels = data.to(args.device),labels.to(args.device)
      optimizer.zero_grad()
      outputs = model(data)
      loss = criterion(outputs,labels)
      running_loss += loss.item()
      predicted = torch.argmax(outputs,dim=1)
      correct += (predicted==labels).sum().item()
      count += len(labels)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
      optimizer.step()

  return 100.0*correct/count,running_loss/len(trainloader)

In [21]:
def global_train(model,criterion,trainloader,valloader,epochs,partience=0,early_stop=False):
  """Train ``model`` with SGD, validating after every epoch.

  Args:
    model: Module to train in place; expected to live on ``args.device``.
    criterion: Loss function called as ``criterion(outputs, labels)``.
    trainloader: Iterable of ``(data, labels)`` training batches.
    valloader: Iterable of ``(data, labels)`` validation batches.
    epochs: Maximum number of epochs.
    partience: Patience handed to ``Early_Stopping`` (used only when
      ``early_stop`` is true).
    early_stop: Whether to stop once the validation loss stops improving.

  Returns:
    Four lists with one entry per completed epoch:
    ``(acc_train, loss_train, acc_valid, loss_valid)``. Accuracies are in
    percent; losses are means of the per-batch losses.
  """
  if early_stop:
    early_stopping = Early_Stopping(partience)

  acc_train = []
  loss_train = []
  acc_valid = []
  loss_valid = []
  optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
  for epoch in range(epochs):
    # ---- training pass ----
    running_loss = 0.0
    correct = 0
    count = 0
    model.train()
    for (data,labels) in trainloader:
      count += len(labels)
      # Tensors are used directly: the torch.autograd.Variable wrapper is a
      # deprecated no-op.
      data,labels = data.to(args.device),labels.to(args.device)
      optimizer.zero_grad()
      outputs = model(data)
      loss = criterion(outputs,labels)
      running_loss += loss.item()
      predicted = torch.argmax(outputs,dim=1)
      correct += (predicted==labels).sum().item()
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
      optimizer.step()
    acc_train.append(100.0*correct/count)
    loss_train.append(running_loss/len(trainloader))

    # ---- validation pass ----
    running_loss = 0.0
    correct = 0
    count = 0
    model.eval()
    # No gradients are needed here; disabling autograd avoids building a
    # graph for every validation batch.
    with torch.no_grad():
      for (data,labels) in valloader:
        count += len(labels)
        data,labels = data.to(args.device),labels.to(args.device)
        outputs = model(data)
        loss = criterion(outputs,labels)
        running_loss += loss.item()
        predicted = torch.argmax(outputs,dim=1)
        correct += (predicted==labels).sum().item()

    print('Epoch:{}  accuracy:{}  loss:{}'.format(epoch+1,100.0*correct/count,running_loss/len(valloader)))
    acc_valid.append(100.0*correct/count)
    loss_valid.append(running_loss/len(valloader))
    if early_stop:
      # The stopping criterion compares the summed validation batch losses.
      if early_stopping.validate(running_loss):
        print('Early Stop')
        return acc_train,loss_train,acc_valid,loss_valid

  return acc_train,loss_train,acc_valid,loss_valid

In [22]:
def test(model,criterion,testloader):
  """Evaluate ``model`` on every batch of ``testloader``.

  Args:
    model: Module to evaluate; it is switched to eval mode and expected to
      live on ``args.device``.
    criterion: Loss function called as ``criterion(outputs, labels)``.
    testloader: Iterable of ``(data, labels)`` batches.

  Returns:
    Tuple ``(accuracy, loss)``: accuracy in percent and the mean of the
    per-batch losses.
  """
  model.eval()
  running_loss = 0.0
  correct = 0
  count = 0
  # Evaluation needs no gradients; disabling autograd avoids building a
  # graph for every batch and so saves memory and time.
  with torch.no_grad():
    for (data,labels) in testloader:
      data,labels = data.to(args.device),labels.to(args.device)
      outputs = model(data)
      running_loss += criterion(outputs,labels).item()
      predicted = torch.argmax(outputs,dim=1)
      correct += (predicted==labels).sum().item()
      count += len(labels)

  accuracy = 100.0*correct/count
  loss = running_loss/len(testloader)

  return accuracy,loss

In [23]:
class Early_Stopping():
  """Patience-based stop signal driven by a monitored loss.

  ``loss`` holds the most recently accepted value and ``step`` counts the
  consecutive calls whose loss was strictly greater than it.
  """
  def __init__(self,partience):
    self.step = 0
    self.loss = float('inf')
    self.partience = partience

  def validate(self,loss):
    """Record ``loss`` and return True once the patience is exceeded."""
    worsened = self.loss<loss
    if not worsened:
      # Equal or lower loss: restart the counter and adopt the new value.
      self.step = 0
      self.loss = loss
      return False

    self.step += 1
    if self.step>self.partience:
      return True
    return False

In [24]:
# Centralized baseline: train a fresh CNN2 on the pooled training data with
# early stopping on the pooled validation data.
model = CNN2().to(args.device)

start = time.time()  # start time
acc_train,loss_train,acc_valid,loss_valid = global_train(
    model,
    args.criterion,
    global_trainloader,
    global_valloader,
    args.global_epochs,
    partience=args.partience,
    early_stop=True,
)
end = time.time()  # end time

Epoch:1  accuracy:2.8612303290414878  loss:3.7910149097442627
Epoch:2  accuracy:6.36623748211731  loss:3.7714728116989136
Epoch:3  accuracy:5.364806866952789  loss:3.7510420083999634
Epoch:4  accuracy:5.5078683834048645  loss:3.7319881916046143
Epoch:5  accuracy:11.587982832618026  loss:3.6789921522140503
Epoch:6  accuracy:10.44349070100143  loss:3.5750285387039185
Epoch:7  accuracy:20.314735336194563  loss:3.336100935935974
Epoch:8  accuracy:29.82832618025751  loss:3.0529645681381226
Epoch:9  accuracy:36.19456366237482  loss:2.7195905447006226
Epoch:10  accuracy:42.989985693848354  loss:2.412811756134033
Epoch:11  accuracy:47.85407725321888  loss:2.18495774269104
Epoch:12  accuracy:52.93276108726752  loss:2.0187483429908752
Epoch:13  accuracy:55.43633762517883  loss:1.8291654586791992
Epoch:14  accuracy:58.86981402002861  loss:1.7027303576469421
Epoch:15  accuracy:61.587982832618025  loss:1.5824919939041138
Epoch:16  accuracy:61.802575107296136  loss:1.4757599234580994
Epoch:17  accur

In [25]:
# Report the wall-clock duration of the centralized training run.
print('学習時間：{}秒'.format(end-start))# elapsed time = end time - start time

学習時間：157.90964365005493秒


In [26]:
# Build the server and one Worker per sampled client from the three
# index-aligned dataset lists.
server = Server()
workers = server.create_worker(federated_trainset,federated_valset,federated_testset)
# Replace the CNN2 created in Server.__init__ with the centrally trained model.
server.model = model

In [27]:
# Evaluate the centrally trained model on each worker's own test set.
acc_test = []
loss_test = []

server.model.to(args.device)

# Total number of test samples over all workers.
nums = sum(worker.test_data_num for worker in workers)

start = time.time()  # start time

for worker_number,worker in enumerate(workers, start=1):
  # NOTE(review): this weight is stored on the worker, but the averages
  # printed below are plain unweighted means -- confirm that is intended.
  worker.aggregation_weight = 1.0*worker.test_data_num/nums
  acc_tmp,loss_tmp = test(server.model,args.criterion,worker.testloader)
  acc_test.append(acc_tmp)
  loss_test.append(loss_tmp)
  print('Worker{} accuracy:{}  loss:{}'.format(worker_number,acc_tmp,loss_tmp))

end = time.time()  # end time

# Unweighted means over workers.
acc_test_avg = sum(acc_test)/len(acc_test)
loss_test_avg = sum(loss_test)/len(loss_test)
print('Test  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 accuracy:70.58823529411765  loss:0.9494462609291077
Worker2 accuracy:94.28571428571429  loss:0.1838645190000534
Worker3 accuracy:77.5  loss:0.757978081703186
Worker4 accuracy:72.22222222222223  loss:1.5084679126739502
Worker5 accuracy:79.41176470588235  loss:0.6983450055122375
Worker6 accuracy:80.95238095238095  loss:0.642073929309845
Worker7 accuracy:73.6842105263158  loss:0.9906824231147766
Worker8 accuracy:84.375  loss:0.4335269033908844
Worker9 accuracy:47.36842105263158  loss:1.7057507038116455
Worker10 accuracy:75.60975609756098  loss:0.6409780383110046
Worker11 accuracy:77.77777777777777  loss:0.7760308980941772
Worker12 accuracy:78.94736842105263  loss:0.6195851564407349
Worker13 accuracy:78.94736842105263  loss:0.7724058032035828
Worker14 accuracy:76.66666666666667  loss:1.6078630685806274
Worker15 accuracy:84.21052631578948  loss:0.5556695461273193
Worker16 accuracy:82.3529411764706  loss:0.7818105220794678
Worker17 accuracy:72.72727272727273  loss:1.6295963525772095


In [28]:
# Report the wall-clock duration of the per-worker evaluation.
print('推論時間：{}秒'.format(end-start))# elapsed time = end time - start time

推論時間：16.488682746887207秒


In [29]:
# Per-worker fine-tuning: every worker trains its own copy of the server
# model and is then scored on its validation and test data.
acc_tune_test = []
loss_tune_test = []
acc_tune_valid = []
loss_tune_valid = []

start = time.time()  # start time

for worker_number,worker in enumerate(workers, start=1):
    # Private copy, so fine-tuning never changes the shared server model.
    worker.model = copy.deepcopy(server.model).to(args.device)

    # local_train returns (train acc, train loss, valid acc, valid loss);
    # only the validation pair is recorded.
    acc_tmp,loss_tmp = worker.local_train()[2:]
    acc_tune_valid.append(acc_tmp)
    loss_tune_valid.append(loss_tmp)
    print('Worker{} Valid accuracy:{}  loss:{}'.format(worker_number,acc_tmp,loss_tmp))

    acc_tmp,loss_tmp = test(worker.model,args.criterion,worker.testloader)
    acc_tune_test.append(acc_tmp)
    loss_tune_test.append(loss_tmp)
    print('Worker{} Test accuracy:{}  loss:{}'.format(worker_number,acc_tmp,loss_tmp))

    # Move the fine-tuned copy off the device and drop it from the worker.
    worker.model = worker.model.to('cpu')
    del worker.model

end = time.time()  # end time

# Unweighted means over workers.
acc_valid_avg = sum(acc_tune_valid)/len(acc_tune_valid)
loss_valid_avg = sum(loss_tune_valid)/len(loss_tune_valid)
print('Validation(tune)  loss:{}  accuracy:{}'.format(loss_valid_avg,acc_valid_avg))
acc_test_avg = sum(acc_tune_test)/len(acc_tune_test)
loss_test_avg = sum(loss_tune_test)/len(loss_tune_test)
print('Test(tune)  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 Valid accuracy:86.04651162790698  loss:0.31840115785598755
Worker1 Test accuracy:82.3529411764706  loss:0.8582797050476074
Worker2 Valid accuracy:87.23404255319149  loss:0.6121088266372681
Worker2 Test accuracy:94.28571428571429  loss:0.17410723865032196
Worker3 Valid accuracy:75.23809523809524  loss:0.9452593922615051
Worker3 Test accuracy:82.5  loss:0.5617423057556152
Worker4 Valid accuracy:84.78260869565217  loss:0.6278173327445984
Worker4 Test accuracy:72.22222222222223  loss:1.3909858465194702
Worker5 Valid accuracy:83.51648351648352  loss:0.4485121965408325
Worker5 Test accuracy:85.29411764705883  loss:0.48553428053855896
Worker6 Valid accuracy:72.72727272727273  loss:0.9799725413322449
Worker6 Test accuracy:80.95238095238095  loss:0.8291905522346497
Worker7 Valid accuracy:81.0  loss:0.6946359276771545
Worker7 Test accuracy:71.05263157894737  loss:0.9714431762695312
Worker8 Valid accuracy:76.47058823529412  loss:0.8884298801422119
Worker8 Test accuracy:84.375  loss:0.4168

In [30]:
# Report the wall-clock duration of per-worker fine-tuning plus evaluation.
print('学習・推論時間（fine-tune）：{}秒'.format(end-start))# elapsed time = end time - start time

学習・推論時間（fine-tune）：69.1464204788208秒
