In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import os
import random
from torch.autograd import Variable
import copy
from torch import nn, optim
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from collections import OrderedDict
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import csv
import time
import math


In [2]:
def fix_seed(seed):
    # random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

SEED = 42
fix_seed(SEED)

In [3]:
class Argments():
  def __init__(self):
    self.batch_size = 20
    self.test_batch = 1000
    self.global_epochs = 300
    self.local_epochs = 2
    self.lr = None
    self.momentum = 0.9
    self.weight_decay = 10**-4.0
    self.clip = 20.0
    self.partience = 300
    self.worker_num = 20
    self.sample_num = 20
    self.cluster_list = [1,2,3,4]
    self.cluster_num = None
    self.turn_of_cluster_num = [0,150,225,275]
    self.turn_of_replacement_model = list(range(self.global_epochs))
    self.unlabeleddata_size = 1000
    self.device = device = torch.device('cuda:0'if torch.cuda.is_available() else'cpu')
    self.criterion_ce = nn.CrossEntropyLoss()
    self.criterion_kl = nn.KLDivLoss(reduction='batchmean')

args = Argments()

In [4]:
lr = 0

In [6]:
lr_list = []
lr_list.append(10**-3.0)
lr_list.append(10**-2.5)
lr_list.append(10**-2.0)
lr_list.append(10**-1.5)
lr_list.append(10**-1.0)
lr_list.append(10**-0.5)
lr_list.append(10**0.0)
lr_list.append(10**0.5)

args.lr = lr_list[lr]

In [9]:
class LocalDataset(torch.utils.data.Dataset):
  def __init__(self,dataset,worker_id):
    self.data = []
    self.target = []
    self.id = worker_id
    for data in dataset.take(len(dataset)):
      self.data.append(torch.tensor([data['pixels'].numpy()]))
      self.target.append(torch.tensor(data['label'].numpy().astype(np.int64)))

  def __getitem__(self, index):
    return self.data[index],self.target[index]

  def __len__(self):
    return len(self.data)

In [10]:
with open('../data/federated_trainset_femnist.pickle', 'rb') as f:
    all_federated_trainset = pickle.load(f)
with open('../data/federated_testset_femnist.pickle', 'rb') as f:
    all_federated_testset = pickle.load(f)
all_worker_num = len(all_federated_trainset)

In [11]:
worker_id_list = random.sample(range(all_worker_num),args.worker_num)
print(worker_id_list)
federated_trainset = []
federated_testset = []
for i in worker_id_list:
    federated_trainset.append(all_federated_trainset[i])
    federated_testset.append(all_federated_testset[i])

[2619, 456, 102, 3037, 1126, 1003, 914, 571, 3016, 419, 2771, 3033, 2233, 356, 2418, 1728, 130, 122, 383, 895]


In [12]:
federated_valset = [None]*args.worker_num
for i in range(args.worker_num):
  n_samples = len(federated_trainset[i])
  if n_samples==1:
    federated_valset[i] = copy.deepcopy(federated_trainset[i])
  else:
    train_size = int(len(federated_trainset[i]) * 0.7) 
    val_size = n_samples - train_size 
    federated_trainset[i],federated_valset[i] = torch.utils.data.random_split(federated_trainset[i], [train_size, val_size])

In [13]:
class UnlabeledDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = []
    self.target = None

  def __getitem__(self, index):
    return self.data[index],'unlabeled'

  def __len__(self):
    return len(self.data)

In [14]:
unlabeled_dataset = UnlabeledDataset()

for i in range(all_worker_num):
    if i not in worker_id_list:
        unlabeled_dataset.data = unlabeled_dataset.data + all_federated_trainset[i].data
        
unlabeled_dataset,_ = torch.utils.data.random_split(unlabeled_dataset, [args.unlabeleddata_size, len(unlabeled_dataset)-args.unlabeleddata_size])

In [16]:
class CNN1(torch.nn.Module):
    def __init__(self):
        super(CNN1, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(5408, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv2d_1(x)
        x = self.max_pooling(x)
        x = self.dropout_1(x)
        x = self.flatten(x)
        x = self.relu(self.linear_1(x))
        x = self.dropout_2(x)
        x = self.linear_2(x)
        return x

class CNN2(torch.nn.Module):
    def __init__(self):
        super(CNN2, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = torch.nn.Conv2d(32, 64, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(9216, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv2d_1(x)
        x = self.conv2d_2(x)
        x = self.max_pooling(x)
        x = self.dropout_1(x)
        x = self.flatten(x)
        x = self.relu(self.linear_1(x))
        x = self.dropout_2(x)
        x = self.linear_2(x)
        return x

class CNN3(torch.nn.Module):
    def __init__(self):
        super(CNN3, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = torch.nn.Conv2d(32, 64, kernel_size=3)
        self.conv2d_3 = torch.nn.Conv2d(64, 128, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(15488, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv2d_1(x)
        x = self.conv2d_2(x)
        x = self.conv2d_3(x)
        x = self.max_pooling(x)
        x = self.dropout_1(x)
        x = self.flatten(x)
        x = self.relu(self.linear_1(x))
        x = self.dropout_2(x)
        x = self.linear_2(x)
        return x
    
class CNN4(torch.nn.Module):
    def __init__(self):
        super(CNN4, self).__init__()
        self.conv2d_1 = torch.nn.Conv2d(1, 32, kernel_size=3)
        self.max_pooling = nn.MaxPool2d(2, stride=2)
        self.conv2d_2 = torch.nn.Conv2d(32, 64, kernel_size=3)
        self.conv2d_3 = torch.nn.Conv2d(64, 128, kernel_size=3)
        self.conv2d_4 = torch.nn.Conv2d(128, 256, kernel_size=3)
        self.dropout_1 = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.linear_1 = nn.Linear(25600, 128)
        self.dropout_2 = nn.Dropout(0.5)
        self.linear_2 = nn.Linear(128,62)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv2d_1(x)
        x = self.conv2d_2(x)
        x = self.conv2d_3(x)
        x = self.conv2d_4(x)
        x = self.max_pooling(x)
        x = self.dropout_1(x)
        x = self.flatten(x)
        x = self.relu(self.linear_1(x))
        x = self.dropout_2(x)
        x = self.linear_2(x)
        return x

In [17]:
class KMeans(object):
    """KMeans 法でクラスタリングするクラス"""

    def __init__(self, n_clusters=2, max_iter=300):
        """コンストラクタ

        Args:
            n_clusters (int): クラスタ数
            max_iter (int): 最大イテレーション数
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

        self.cluster_centers_ = None

    def fit_predict(self, features):
        """クラスタリングを実施する

        Args:
            features (numpy.ndarray): ラベル付けするデータ

        Returns:
            numpy.ndarray: ラベルデータ
        """
            
        # 要素の中からセントロイド (重心) の初期値となる候補をクラスタ数だけ選び出す
        feature_indexes = np.arange(len(features))
        np.random.shuffle(feature_indexes)
        initial_centroid_indexes = feature_indexes[:self.n_clusters]
        self.cluster_centers_ = features[initial_centroid_indexes]

        # ラベル付けした結果となる配列はゼロで初期化しておく
        pred = np.zeros(features.shape)
        

        # クラスタリングをアップデートする
        for _ in range(self.max_iter):

            # 各特徴ベクトルから最短距離となるセントロイドを基準に新しいラベルをつける
            new_pred = np.array([
                np.array([
                    self.Euclidean_distance(p, centroid)
                    for centroid in self.cluster_centers_
                ]).argmin()
                for p in features
            ])

            if np.all(new_pred == pred):
                # 更新前と内容を比較して、もし同じなら終了
                break

            pred = new_pred
            
            # 各クラスタごとにセントロイド (重心) を再計算する
            self.cluster_centers_ = np.array([features[pred == i].mean(axis=0)
                                              for i in range(self.n_clusters)])

        return pred

    def KLD(self, p0, p1):
        P = torch.from_numpy(p0.astype(np.float32)).clone()
        Q = torch.from_numpy(p1.astype(np.float32)).clone()
        P = F.softmax(Variable(P), dim=1)
        Q = F.softmax(Variable(Q), dim=1)
        kld = ((P/(P+Q))*(P * (P / ((P/(P+Q))*P + (Q/(P+Q))*Q)).log())).sum() + ((Q/(P+Q))*(Q * (Q / ((P/(P+Q))*P + (Q/(P+Q))*Q)).log())).sum()
        return kld
    
    def Euclidean_distance(self, p0, p1):
        return np.sum((p0 - p1) ** 2)

In [18]:
class Server():
  def __init__(self,unlabeled_dataset):
    self.cluster = None
    self.models = None
    self.unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset,batch_size=args.batch_size,shuffle=False,num_workers=2)

  def create_worker(self,federated_trainset,federated_valset,federated_testset,client_best_model):
    workers = []
    for i in range(args.worker_num):
      workers.append(Worker(i,federated_trainset[i],federated_valset[i],federated_testset[i],client_best_model[i]))
    return workers

  def sample_worker(self,workers):
    sample_worker = []
    sample_worker_num = random.sample(range(args.worker_num),args.sample_num)
    for i in sample_worker_num:
      sample_worker.append(workers[i])
    return sample_worker

  def collect_model(self,workers):
    self.models = [None]*args.worker_num
    for worker in workers:
      self.models[worker.id] = copy.deepcopy(worker.local_model)

  def send_model(self,workers):
    for worker in workers:
      worker.local_model = copy.deepcopy(self.models[worker.id])
      worker.other_model = copy.deepcopy(self.models[worker.other_model_id])
        
  def return_model(self,workers):
    for worker in workers:
      worker.local_model = copy.deepcopy(self.models[worker.local_model_id])
      worker.local_model_id = worker.id
    del self.models
    
  def aggregate_model(self,workers):   
    new_params = []
    train_model_id = []
    train_model_id_count = []
    for worker in workers:
      worker_state = worker.local_model.state_dict()
      if worker.id in train_model_id:
        i = train_model_id.index(worker.id)
        for key in worker_state.keys():
          new_params[i][key] += worker_state[key]
        train_model_id_count[i] += 1
      else:
        new_params.append(OrderedDict())
        train_model_id.append(worker.id)
        train_model_id_count.append(1)
        i = train_model_id.index(worker.id)
        for key in worker_state.keys():
          new_params[i][key] = worker_state[key]
        
      worker_state = worker.other_model.state_dict()
      if worker.other_model_id in train_model_id:
        i = train_model_id.index(worker.other_model_id)
        for key in worker_state.keys():
          new_params[i][key] += worker_state[key]
        train_model_id_count[i] += 1
      else:
        new_params.append(OrderedDict())
        train_model_id.append(worker.other_model_id)
        train_model_id_count.append(1)
        i = train_model_id.index(worker.other_model_id)
        for key in worker_state.keys():
          new_params[i][key] = worker_state[key]
        
      worker.local_model = worker.local_model.to('cpu')
      worker.other_model = worker.other_model.to('cpu')
      del worker.local_model,worker.other_model
    
    for i,model_id in enumerate(train_model_id):
      for key in new_params[i].keys():
        new_params[i][key] = new_params[i][key]/train_model_id_count[i]
      self.models[model_id].load_state_dict(new_params[i])
      
  '''clustering by kmeans'''  
  def clustering(self,workers):
    if args.cluster_num==1:
        pred = [0]*len(workers)
        worker_id_list = []
        for worker in workers:
            worker_id_list.append(worker.id)
    else:
        with torch.no_grad():
            worker_softmax_targets = [[] for i in range(len(workers))]
            worker_id_list = []
            count = 0
            for i,model in enumerate(self.models):
              if model==None:
                pass
              else:
                model = model.to(args.device)
                model.eval()
                for data,_ in self.unlabeled_dataloader:
                  data = data.to(args.device)
                  worker_softmax_targets[count].append(model(data).to('cpu').detach().numpy())
                worker_softmax_targets[count] = np.array(worker_softmax_targets[count])
                model = model.to('cpu')
                worker_id_list.append(i)
                count += 1
            worker_softmax_targets = np.array(worker_softmax_targets)
            kmeans = KMeans(n_clusters=args.cluster_num)
            pred = kmeans.fit_predict(worker_softmax_targets)
    self.cluster = []
    for i in range(args.cluster_num):
      self.cluster.append([])
    for i,cls in enumerate(pred):
      self.cluster[cls].append(worker_id_list[i])
    for worker in workers:
      idx = worker_id_list.index(worker.id)
      worker.cluster_num = pred[idx]
        
  def decide_other_model(self,workers):
    for worker in workers:
      cls = worker.cluster_num
      '''if number of worker in cluster is one, other model is decided by random in all workers. '''
      if len(self.cluster[cls])==1:
        while True:
          other_worker = random.choice(workers)
          other_model_id = other_worker.id
          if worker.id!=other_model_id:
            break
      else:
        while True:
          other_model_id = random.choice(self.cluster[cls])
          if worker.id!=other_model_id:
            break
      worker.other_model_id = other_model_id

In [19]:
class Worker():
  def __init__(self,i,trainset,valset,testset,best_model):
    self.id = i
    self.cluster_num = None
    self.trainloader = torch.utils.data.DataLoader(trainset,batch_size=args.batch_size,shuffle=True,num_workers=2)
    self.valloader = torch.utils.data.DataLoader(valset,batch_size=args.test_batch,shuffle=False,num_workers=2)
    self.testloader = torch.utils.data.DataLoader(testset,batch_size=args.test_batch,shuffle=False,num_workers=2)
    if best_model==1:
      self.local_model = CNN1()
    elif best_model==2:
      self.local_model = CNN2()
    elif best_model==3:
      self.local_model = CNN3()
    elif best_model==4:
      self.local_model = CNN4()
    self.local_model_id = i
    self.other_model = None
    self.other_model_id = None
    self.train_data_num = len(trainset)
    self.test_data_num = len(testset)

  def local_train(self):
    self.local_model = self.local_model.to(args.device)
    self.other_model = self.other_model.to(args.device)
    local_optimizer = optim.SGD(self.local_model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
    other_optimizer = optim.SGD(self.other_model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
    self.local_model.train()
    self.other_model.train()
    for epoch in range(args.local_epochs):
      running_loss = 0.0
      correct = 0
      count = 0
      for (data,labels) in self.trainloader:
        data,labels = Variable(data),Variable(labels)
        data,labels = data.to(args.device),labels.to(args.device)
        local_optimizer.zero_grad()
        other_optimizer.zero_grad()
        local_outputs = self.local_model(data)
        other_outputs = self.other_model(data)
        #train local_model
        ce_loss = args.criterion_ce(local_outputs,labels)
        kl_loss = args.criterion_kl(F.log_softmax(local_outputs, dim = 1),F.softmax(Variable(other_outputs), dim=1))
        loss = ce_loss + kl_loss
        running_loss += loss.item()
        predicted = torch.argmax(local_outputs,dim=1)
        correct += (predicted==labels).sum().item()
        count += len(labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.local_model.parameters(), args.clip)
        local_optimizer.step()

        #train other_model
        ce_loss = args.criterion_ce(other_outputs,labels)
        kl_loss = args.criterion_kl(F.log_softmax(other_outputs, dim = 1),F.softmax(Variable(local_outputs), dim=1))
        loss = ce_loss + kl_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.other_model.parameters(), args.clip)
        other_optimizer.step()
        
    return 100.0*correct/count,running_loss/len(self.trainloader)

        
  def validate(self):
    acc,loss = test(self.local_model,args.criterion_ce,self.valloader)
    return acc,loss


  def model_replacement(self):
    _,loss_local = test(self.local_model,args.criterion_ce,self.valloader)
    _,loss_other = test(self.other_model,args.criterion_ce,self.valloader)
    if loss_other<loss_local:
        self.local_model_id = self.other_model_id

In [20]:
def train(model,criterion,trainloader,epochs):
  optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
  model.train()
  for epoch in range(epochs):
    running_loss = 0.0
    correct = 0
    count = 0
    for (data,labels) in trainloader:
      data,labels = Variable(data),Variable(labels)
      data,labels = data.to(args.device),labels.to(args.device)
      optimizer.zero_grad()
      outputs = model(data)
      loss = criterion(outputs,labels)
      running_loss += loss.item()
      predicted = torch.argmax(outputs,dim=1)
      correct += (predicted==labels).sum().item()
      count += len(labels)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
      optimizer.step()

  return 100.0*correct/count,running_loss/len(trainloader)



In [21]:
def test(model,criterion,testloader):
  model.eval()
  running_loss = 0.0
  correct = 0
  count = 0
  for (data,labels) in testloader:
    data,labels = data.to(args.device),labels.to(args.device)
    outputs = model(data)
    running_loss += criterion(outputs,labels).item()
    predicted = torch.argmax(outputs,dim=1)
    correct += (predicted==labels).sum().item()
    count += len(labels)

  accuracy = 100.0*correct/count
  loss = running_loss/len(testloader)


  return accuracy,loss

In [22]:
class Early_Stopping():
  def __init__(self,partience):
    self.step = 0
    self.loss = float('inf')
    self.partience = partience

  def validate(self,loss):
    if self.loss<loss:
      self.step += 1
      if self.step>self.partience:
        return True
    else:
      self.step = 0
      self.loss = loss

    return False

In [23]:
with open('../client_best_model/client_best_model_femnist_42.csv') as fp:
    csvList = list(csv.reader(fp))
client_best_model = [int(item) for subList in csvList for item in subList]

In [24]:
server = Server(unlabeled_dataset)
workers = server.create_worker(federated_trainset,federated_valset,federated_testset,client_best_model)
acc_train = []
loss_train = []
acc_valid = []
loss_valid = []

early_stopping = Early_Stopping(args.partience)

start = time.time()#開始時刻

for epoch in range(args.global_epochs):
  if epoch in args.turn_of_cluster_num:
    idx = args.turn_of_cluster_num.index(epoch)
    args.cluster_num = args.cluster_list[idx]
  sample_worker = server.sample_worker(workers)
  server.collect_model(sample_worker)
  server.clustering(sample_worker)
  server.decide_other_model(sample_worker)
  server.send_model(sample_worker)

  acc_train_avg = 0.0
  loss_train_avg = 0.0
  acc_valid_avg = 0.0
  loss_valid_avg = 0.0
  for worker in sample_worker:
    acc_train_tmp,loss_train_tmp = worker.local_train()
    acc_valid_tmp,loss_valid_tmp = worker.validate()
    acc_train_avg += acc_train_tmp/len(sample_worker)
    loss_train_avg += loss_train_tmp/len(sample_worker)
    acc_valid_avg += acc_valid_tmp/len(sample_worker)
    loss_valid_avg += loss_valid_tmp/len(sample_worker)
  if epoch in args.turn_of_replacement_model:
    for worker in sample_worker:
      worker.model_replacement()
  server.aggregate_model(sample_worker)
  server.return_model(sample_worker)
  '''
  server.model.to(args.device)
  for worker in workers:
    acc_valid_tmp,loss_valid_tmp = test(server.model,args.criterion,worker.valloader)
    acc_valid_avg += acc_valid_tmp/len(workers)
    loss_valid_avg += loss_valid_tmp/len(workers)
  server.model.to('cpu')
  '''
  print('Epoch{}  loss:{}  accuracy:{}'.format(epoch+1,loss_valid_avg,acc_valid_avg))
  acc_train.append(acc_train_avg)
  loss_train.append(loss_train_avg)
  acc_valid.append(acc_valid_avg)
  loss_valid.append(loss_valid_avg)

  if early_stopping.validate(loss_valid_avg):
    print('Early Stop')
    break
    
end = time.time()#終了時刻

Epoch1  loss:4.0051945567131035  accuracy:5.876749915470376
Epoch2  loss:3.8954055190086363  accuracy:5.754684343684445
Epoch3  loss:3.831409931182861  accuracy:6.335507727843882
Epoch4  loss:3.811757743358612  accuracy:6.642949078763494
Epoch5  loss:3.7782390713691707  accuracy:6.371958505787191
Epoch6  loss:3.7626258015632628  accuracy:6.664542030745779
Epoch7  loss:3.749310159683227  accuracy:6.234297552852456
Epoch8  loss:3.7319972634315497  accuracy:6.430966082884119
Epoch9  loss:3.6980793476104736  accuracy:5.567734076794969
Epoch10  loss:3.700629305839538  accuracy:6.281527238045111
Epoch11  loss:3.692843234539032  accuracy:7.336936758396965
Epoch12  loss:3.694174611568451  accuracy:7.339690759085998
Epoch13  loss:3.6989389419555665  accuracy:6.313135039803905
Epoch14  loss:3.6791916728019713  accuracy:7.4666842638657025
Epoch15  loss:3.6734932661056523  accuracy:7.7073199254710865
Epoch16  loss:3.6720670461654663  accuracy:7.190765647572414
Epoch17  loss:3.678414165973663  accu

Epoch136  loss:1.6468112230300904  accuracy:66.75875976705751
Epoch137  loss:1.6574361920356748  accuracy:66.66744854460647
Epoch138  loss:1.6314379990100862  accuracy:67.86413389884243
Epoch139  loss:1.630495315790176  accuracy:67.55215390321688
Epoch140  loss:1.6273400843143462  accuracy:67.8074623757512
Epoch141  loss:1.6179823219776153  accuracy:68.27628058892114
Epoch142  loss:1.6313174188137054  accuracy:68.63266441662824
Epoch143  loss:1.6101988613605498  accuracy:67.17164937549018
Epoch144  loss:1.5908301532268523  accuracy:67.71177731057105
Epoch145  loss:1.5991706907749172  accuracy:67.80763178664337
Epoch146  loss:1.5768540918827056  accuracy:68.16042790870408
Epoch147  loss:1.579812031984329  accuracy:68.55974612667504
Epoch148  loss:1.5727613031864167  accuracy:68.23834789915517
Epoch149  loss:1.5734138667583466  accuracy:67.38561748864939
Epoch150  loss:1.549539363384247  accuracy:67.37775458048634




Epoch151  loss:1.5401683032512663  accuracy:67.98176505454995
Epoch152  loss:1.542090892791748  accuracy:68.77130429312113
Epoch153  loss:1.5329599261283875  accuracy:68.74119593864323
Epoch154  loss:1.5306429862976076  accuracy:68.59215552443888
Epoch155  loss:1.5319107949733735  accuracy:68.77619911254864
Epoch156  loss:1.5277054607868197  accuracy:68.51912080998534
Epoch157  loss:1.5245430767536163  accuracy:68.14074871076599
Epoch158  loss:1.5201799690723419  accuracy:68.5165507875572
Epoch159  loss:1.500540918111801  accuracy:68.72167263540133
Epoch160  loss:1.499469131231308  accuracy:68.63326062103448
Epoch161  loss:1.4896495640277863  accuracy:68.78823531435685
Epoch162  loss:1.4806180655956267  accuracy:68.55418428054952
Epoch163  loss:1.4799287140369417  accuracy:69.29489783676563
Epoch164  loss:1.4635453224182127  accuracy:69.22185562497988
Epoch165  loss:1.4811067938804625  accuracy:68.19939202492588
Epoch166  loss:1.4755980908870698  accuracy:69.88853762601815
Epoch167  lo

  ret, rcount, out=ret, casting='unsafe', subok=False)


Epoch208  loss:1.3204728782176969  accuracy:71.24629383640226
Epoch209  loss:1.2796892732381822  accuracy:71.06687847313026
Epoch210  loss:1.2754454374313358  accuracy:71.6452924500822
Epoch211  loss:1.2718292385339738  accuracy:71.39920990246668
Epoch212  loss:1.2741100817918776  accuracy:72.0822761027109
Epoch213  loss:1.2618150472640992  accuracy:72.29423436513522
Epoch214  loss:1.268881341814995  accuracy:71.22930011605894
Epoch215  loss:1.2539768069982526  accuracy:72.17589764217679
Epoch216  loss:1.2668285816907885  accuracy:71.48729732970729
Epoch217  loss:1.2398518145084383  accuracy:71.20321652687217
Epoch218  loss:1.2382011979818344  accuracy:72.37186483296816
Epoch219  loss:1.2452979803085324  accuracy:72.10648980481376
Epoch220  loss:1.2373915046453479  accuracy:71.44619028088037
Epoch221  loss:1.2321339040994643  accuracy:71.51457019283932
Epoch222  loss:1.2364871233701709  accuracy:71.1450801280876
Epoch223  loss:1.2345329761505126  accuracy:72.16998626390807
Epoch224  lo

In [25]:
print('学習時間：{}秒'.format(end-start))#終了時刻-開始時刻でかかった時間

学習時間：12320.034575939178秒


In [26]:
acc_test = []
loss_test = []

start = time.time()#開始時刻

for i,worker in enumerate(workers):
  worker.local_model = worker.local_model.to(args.device)
  acc_tmp,loss_tmp = test(worker.local_model,args.criterion_ce,worker.testloader)
  acc_test.append(acc_tmp)
  loss_test.append(loss_tmp)
  print('Worker{} accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
  worker.local_model = worker.local_model.to('cpu')

end = time.time()#終了時刻

acc_test_avg = sum(acc_test)/len(acc_test)
loss_test_avg = sum(loss_test)/len(loss_test)
print('Test  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 accuracy:76.47058823529412  loss:1.3427379131317139
Worker2 accuracy:85.71428571428571  loss:0.5027276277542114
Worker3 accuracy:82.5  loss:0.9414807558059692
Worker4 accuracy:66.66666666666667  loss:1.77033269405365
Worker5 accuracy:91.17647058823529  loss:0.5614816546440125
Worker6 accuracy:80.95238095238095  loss:1.0939338207244873
Worker7 accuracy:68.42105263157895  loss:1.0617470741271973
Worker8 accuracy:78.125  loss:0.8840042352676392
Worker9 accuracy:52.63157894736842  loss:1.7317153215408325
Worker10 accuracy:78.04878048780488  loss:0.9648281931877136
Worker11 accuracy:66.66666666666667  loss:1.4878454208374023
Worker12 accuracy:73.6842105263158  loss:0.9826846718788147
Worker13 accuracy:84.21052631578948  loss:0.963140606880188
Worker14 accuracy:56.666666666666664  loss:1.873410701751709
Worker15 accuracy:78.94736842105263  loss:0.7993667125701904
Worker16 accuracy:70.58823529411765  loss:1.132392168045044
Worker17 accuracy:63.63636363636363  loss:1.3797324895858765
W

In [27]:
print('推論時間：{}秒'.format(end-start))#終了時刻-開始時刻でかかった時間

推論時間：6.750147581100464秒


In [31]:
acc_tune_test = []
loss_tune_test = []
acc_tune_valid = []
loss_tune_valid = []

start = time.time()#開始時刻

for i,worker in enumerate(workers):
    worker.local_model = worker.local_model.to(args.device)
    _,_ = train(worker.local_model,args.criterion_ce,worker.trainloader,args.local_epochs)
    acc_tmp,loss_tmp = test(worker.local_model,args.criterion_ce,worker.valloader)
    acc_tune_valid.append(acc_tmp)
    loss_tune_valid.append(loss_tmp)
    print('Worker{} Valid accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
    
    acc_tmp,loss_tmp = test(worker.local_model,args.criterion_ce,worker.testloader)
    acc_tune_test.append(acc_tmp)
    loss_tune_test.append(loss_tmp)
    print('Worker{} Test accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
    worker.local_model = worker.local_model.to('cpu')

end = time.time()#終了時刻

acc_valid_avg = sum(acc_tune_valid)/len(acc_tune_valid)
loss_valid_avg = sum(loss_tune_valid)/len(loss_tune_valid)
print('Validation(tune)  loss:{}  accuracy:{}'.format(loss_valid_avg,acc_valid_avg))
acc_test_avg = sum(acc_tune_test)/len(acc_tune_test)
loss_test_avg = sum(loss_tune_test)/len(loss_tune_test)
print('Test(tune)  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 Valid accuracy:79.06976744186046  loss:0.8725739121437073
Worker1 Test accuracy:76.47058823529412  loss:1.3105943202972412
Worker2 Valid accuracy:84.04255319148936  loss:0.7584352493286133
Worker2 Test accuracy:88.57142857142857  loss:0.33251848816871643
Worker3 Valid accuracy:71.42857142857143  loss:1.0110368728637695
Worker3 Test accuracy:80.0  loss:0.7700222134590149
Worker4 Valid accuracy:73.91304347826087  loss:1.0164458751678467
Worker4 Test accuracy:66.66666666666667  loss:1.738160252571106
Worker5 Valid accuracy:78.02197802197803  loss:0.8186147212982178
Worker5 Test accuracy:91.17647058823529  loss:0.5031121373176575
Worker6 Valid accuracy:70.9090909090909  loss:1.311125636100769
Worker6 Test accuracy:76.19047619047619  loss:0.9257825016975403
Worker7 Valid accuracy:81.0  loss:0.881182849407196
Worker7 Test accuracy:68.42105263157895  loss:1.061694622039795
Worker8 Valid accuracy:77.6470588235294  loss:0.8096266984939575
Worker8 Test accuracy:81.25  loss:0.713312864303

In [32]:
print('学習＋推論時間：{}秒'.format(end-start))#終了時刻-開始時刻でかかった時間

学習＋推論時間：28.86422610282898秒
