In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import os
import random
from torch.autograd import Variable
import copy
from torch import nn, optim
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from collections import OrderedDict
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import csv
import time
import math


In [2]:
def fix_seed(seed):
    # random
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Pytorch
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

SEED = 42
fix_seed(SEED)

In [3]:
class Argments():
  def __init__(self):
    self.batch_size = 10
    self.test_batch = 1000
    self.global_epochs = 100
    self.local_epochs = 2
    self.lr = None
    self.momentum = 0.9
    self.weight_decay = 10**-4.0
    self.clip = 20.0
    self.partience = 100
    self.worker_num = 20
    self.sample_num = 20
    self.cluster_list = [1,2,3,4]
    self.cluster_num = None
    self.turn_of_cluster_num = [0,50,75,90]
    self.turn_of_replacement_model = list(range(self.global_epochs))
    self.unlabeleddata_size = 1000
    self.device = device = torch.device('cuda:0'if torch.cuda.is_available() else'cpu')
    self.criterion_ce = nn.CrossEntropyLoss()
    self.criterion_kl = nn.KLDivLoss(reduction='batchmean')

args = Argments()

In [4]:
lr = 0

In [6]:
lr_list = []
lr_list.append(10**-3.0)
lr_list.append(10**-2.5)
lr_list.append(10**-2.0)
lr_list.append(10**-1.5)
lr_list.append(10**-1.0)
lr_list.append(10**-0.5)
lr_list.append(10**0.0)
lr_list.append(10**0.5)

args.lr = lr_list[lr]

In [9]:
class LocalDataset(torch.utils.data.Dataset):
  def __init__(self,dataset,worker_id):
    self.data = []
    self.target = []
    self.id = worker_id
    for i in range(len(dataset)):
      self.data.append(dataset[i][0][0])
      self.target.append(dataset[i][1][0])

  def __getitem__(self, index):
    return self.data[index],self.target[index]

  def __len__(self):
    return len(self.data)

In [10]:
with open('../data/federated_trainset_shakespeare.pickle', 'rb') as f:
    all_federated_trainset = pickle.load(f)
with open('../data/federated_testset_shakespeare.pickle', 'rb') as f:
    all_federated_testset = pickle.load(f)
all_worker_num = len(all_federated_trainset)

In [11]:
worker_id_list = random.sample(range(all_worker_num),args.worker_num)
print(worker_id_list)
federated_trainset = []
federated_testset = []
for i in worker_id_list:
    federated_trainset.append(all_federated_trainset[i])
    federated_testset.append(all_federated_testset[i])

[28, 6, 70, 62, 57, 35, 26, 139, 22, 108, 8, 7, 23, 55, 59, 129, 50, 107, 56, 114]


In [12]:
federated_valset = [None]*args.worker_num
for i in range(args.worker_num):
  n_samples = len(federated_trainset[i])
  if n_samples==1:
    federated_valset[i] = copy.deepcopy(federated_trainset[i])
  else:
    train_size = int(len(federated_trainset[i]) * 0.7) 
    val_size = n_samples - train_size 
    federated_trainset[i],federated_valset[i] = torch.utils.data.random_split(federated_trainset[i], [train_size, val_size])

In [13]:
class UnlabeledDataset(torch.utils.data.Dataset):
  def __init__(self):
    self.data = []
    self.target = None

  def __getitem__(self, index):
    return self.data[index],'unlabeled'

  def __len__(self):
    return len(self.data)

In [14]:
unlabeled_dataset = UnlabeledDataset()

for i in range(all_worker_num):
    if i not in worker_id_list:
        unlabeled_dataset.data = unlabeled_dataset.data + all_federated_trainset[i].data
        
unlabeled_dataset,_ = torch.utils.data.random_split(unlabeled_dataset, [args.unlabeleddata_size, len(unlabeled_dataset)-args.unlabeleddata_size])

In [16]:
class RNN1(nn.Module):
    def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256):
        super(RNN1, self).__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_seq):
        embeds = self.embeddings(input_seq)
        # Note that the order of mini-batch is random so there is no hidden relationship among batches.
        # So we do not input the previous batch's hidden state,
        # leaving the first hidden state zero `self.lstm(embeds, None)`.
        lstm_out, _ = self.lstm(embeds)
        # use the final hidden state as the next character prediction
        final_hidden_state = lstm_out[:, -1]
        output = self.fc(final_hidden_state)
        return output

class RNN2(nn.Module):
    def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256):
        super(RNN2, self).__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=2, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_seq):
        embeds = self.embeddings(input_seq)
        # Note that the order of mini-batch is random so there is no hidden relationship among batches.
        # So we do not input the previous batch's hidden state,
        # leaving the first hidden state zero `self.lstm(embeds, None)`.
        lstm_out, _ = self.lstm(embeds)
        # use the final hidden state as the next character prediction
        final_hidden_state = lstm_out[:, -1]
        output = self.fc(final_hidden_state)
        return output

class RNN3(nn.Module):
    def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256):
        super(RNN3, self).__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=3, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_seq):
        embeds = self.embeddings(input_seq)
        # Note that the order of mini-batch is random so there is no hidden relationship among batches.
        # So we do not input the previous batch's hidden state,
        # leaving the first hidden state zero `self.lstm(embeds, None)`.
        lstm_out, _ = self.lstm(embeds)
        # use the final hidden state as the next character prediction
        final_hidden_state = lstm_out[:, -1]
        output = self.fc(final_hidden_state)
        return output
    
class RNN4(nn.Module):
    def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256):
        super(RNN4, self).__init__()
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=4, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_seq):
        embeds = self.embeddings(input_seq)
        # Note that the order of mini-batch is random so there is no hidden relationship among batches.
        # So we do not input the previous batch's hidden state,
        # leaving the first hidden state zero `self.lstm(embeds, None)`.
        lstm_out, _ = self.lstm(embeds)
        # use the final hidden state as the next character prediction
        final_hidden_state = lstm_out[:, -1]
        output = self.fc(final_hidden_state)
        return output

In [17]:
class KMeans(object):
    """KMeans 法でクラスタリングするクラス"""

    def __init__(self, n_clusters=2, max_iter=300):
        """コンストラクタ

        Args:
            n_clusters (int): クラスタ数
            max_iter (int): 最大イテレーション数
        """
        self.n_clusters = n_clusters
        self.max_iter = max_iter

        self.cluster_centers_ = None

    def fit_predict(self, features):
        """クラスタリングを実施する

        Args:
            features (numpy.ndarray): ラベル付けするデータ

        Returns:
            numpy.ndarray: ラベルデータ
        """
            
        # 要素の中からセントロイド (重心) の初期値となる候補をクラスタ数だけ選び出す
        feature_indexes = np.arange(len(features))
        np.random.shuffle(feature_indexes)
        initial_centroid_indexes = feature_indexes[:self.n_clusters]
        self.cluster_centers_ = features[initial_centroid_indexes]

        # ラベル付けした結果となる配列はゼロで初期化しておく
        pred = np.zeros(features.shape)
        

        # クラスタリングをアップデートする
        for _ in range(self.max_iter):

            # 各特徴ベクトルから最短距離となるセントロイドを基準に新しいラベルをつける
            new_pred = np.array([
                np.array([
                    self.Euclidean_distance(p, centroid)
                    for centroid in self.cluster_centers_
                ]).argmin()
                for p in features
            ])

            if np.all(new_pred == pred):
                # 更新前と内容を比較して、もし同じなら終了
                break

            pred = new_pred
            
            # 各クラスタごとにセントロイド (重心) を再計算する
            self.cluster_centers_ = np.array([features[pred == i].mean(axis=0)
                                              for i in range(self.n_clusters)])

        return pred

    def KLD(self, p0, p1):
        P = torch.from_numpy(p0.astype(np.float32)).clone()
        Q = torch.from_numpy(p1.astype(np.float32)).clone()
        P = F.softmax(Variable(P), dim=1)
        Q = F.softmax(Variable(Q), dim=1)
        kld = ((P/(P+Q))*(P * (P / ((P/(P+Q))*P + (Q/(P+Q))*Q)).log())).sum() + ((Q/(P+Q))*(Q * (Q / ((P/(P+Q))*P + (Q/(P+Q))*Q)).log())).sum()
        return kld
    
    def Euclidean_distance(self, p0, p1):
        return np.sum((p0 - p1) ** 2)

In [18]:
class Server():
  def __init__(self,unlabeled_dataset):
    self.cluster = None
    self.models = None
    self.unlabeled_dataloader = torch.utils.data.DataLoader(unlabeled_dataset,batch_size=args.batch_size,shuffle=False,num_workers=2)

  def create_worker(self,federated_trainset,federated_valset,federated_testset,client_best_model):
    workers = []
    for i in range(args.worker_num):
      workers.append(Worker(i,federated_trainset[i],federated_valset[i],federated_testset[i],client_best_model[i]))
    return workers

  def sample_worker(self,workers):
    sample_worker = []
    sample_worker_num = random.sample(range(args.worker_num),args.sample_num)
    for i in sample_worker_num:
      sample_worker.append(workers[i])
    return sample_worker

  def collect_model(self,workers):
    self.models = [None]*args.worker_num
    for worker in workers:
      self.models[worker.id] = copy.deepcopy(worker.local_model)

  def send_model(self,workers):
    for worker in workers:
      worker.local_model = copy.deepcopy(self.models[worker.id])
      worker.other_model = copy.deepcopy(self.models[worker.other_model_id])
        
  def return_model(self,workers):
    for worker in workers:
      worker.local_model = copy.deepcopy(self.models[worker.local_model_id])
      worker.local_model_id = worker.id
    del self.models
    
  def aggregate_model(self,workers):   
    new_params = []
    train_model_id = []
    train_model_id_count = []
    for worker in workers:
      worker_state = worker.local_model.state_dict()
      if worker.id in train_model_id:
        i = train_model_id.index(worker.id)
        for key in worker_state.keys():
          new_params[i][key] += worker_state[key]
        train_model_id_count[i] += 1
      else:
        new_params.append(OrderedDict())
        train_model_id.append(worker.id)
        train_model_id_count.append(1)
        i = train_model_id.index(worker.id)
        for key in worker_state.keys():
          new_params[i][key] = worker_state[key]
        
      worker_state = worker.other_model.state_dict()
      if worker.other_model_id in train_model_id:
        i = train_model_id.index(worker.other_model_id)
        for key in worker_state.keys():
          new_params[i][key] += worker_state[key]
        train_model_id_count[i] += 1
      else:
        new_params.append(OrderedDict())
        train_model_id.append(worker.other_model_id)
        train_model_id_count.append(1)
        i = train_model_id.index(worker.other_model_id)
        for key in worker_state.keys():
          new_params[i][key] = worker_state[key]
        
      worker.local_model = worker.local_model.to('cpu')
      worker.other_model = worker.other_model.to('cpu')
      del worker.local_model,worker.other_model
    
    for i,model_id in enumerate(train_model_id):
      for key in new_params[i].keys():
        new_params[i][key] = new_params[i][key]/train_model_id_count[i]
      self.models[model_id].load_state_dict(new_params[i])
      
  '''clustering by kmeans'''  
  def clustering(self,workers):
    if args.cluster_num==1:
        pred = [0]*len(workers)
        worker_id_list = []
        for worker in workers:
            worker_id_list.append(worker.id)
    else:
        with torch.no_grad():
            worker_softmax_targets = [[] for i in range(len(workers))]
            worker_id_list = []
            count = 0
            for i,model in enumerate(self.models):
              if model==None:
                pass
              else:
                model = model.to(args.device)
                model.eval()
                for data,_ in self.unlabeled_dataloader:
                  data = data.to(args.device)
                  worker_softmax_targets[count].append(model(data).to('cpu').detach().numpy())
                worker_softmax_targets[count] = np.array(worker_softmax_targets[count])
                model = model.to('cpu')
                worker_id_list.append(i)
                count += 1
            worker_softmax_targets = np.array(worker_softmax_targets)
            kmeans = KMeans(n_clusters=args.cluster_num)
            pred = kmeans.fit_predict(worker_softmax_targets)
    self.cluster = []
    for i in range(args.cluster_num):
      self.cluster.append([])
    for i,cls in enumerate(pred):
      self.cluster[cls].append(worker_id_list[i])
    for worker in workers:
      idx = worker_id_list.index(worker.id)
      worker.cluster_num = pred[idx]
        
  def decide_other_model(self,workers):
    for worker in workers:
      cls = worker.cluster_num
      '''if number of worker in cluster is one, other model is decided by random in all workers. '''
      if len(self.cluster[cls])==1:
        while True:
          other_worker = random.choice(workers)
          other_model_id = other_worker.id
          if worker.id!=other_model_id:
            break
      else:
        while True:
          other_model_id = random.choice(self.cluster[cls])
          if worker.id!=other_model_id:
            break
      worker.other_model_id = other_model_id

In [19]:
class Worker():
  def __init__(self,i,trainset,valset,testset,best_model):
    self.id = i
    self.cluster_num = None
    self.trainloader = torch.utils.data.DataLoader(trainset,batch_size=args.batch_size,shuffle=True,num_workers=2)
    self.valloader = torch.utils.data.DataLoader(valset,batch_size=args.test_batch,shuffle=False,num_workers=2)
    self.testloader = torch.utils.data.DataLoader(testset,batch_size=args.test_batch,shuffle=False,num_workers=2)
    if best_model==1:
      self.local_model = RNN1()
    elif best_model==2:
      self.local_model = RNN2()
    elif best_model==3:
      self.local_model = RNN3()
    elif best_model==4:
      self.local_model = RNN4()
    self.local_model_id = i
    self.other_model = None
    self.other_model_id = None
    self.train_data_num = len(trainset)
    self.test_data_num = len(testset)

  def local_train(self):
    self.local_model = self.local_model.to(args.device)
    self.other_model = self.other_model.to(args.device)
    local_optimizer = optim.SGD(self.local_model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
    other_optimizer = optim.SGD(self.other_model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
    self.local_model.train()
    self.other_model.train()
    for epoch in range(args.local_epochs):
      running_loss = 0.0
      correct = 0
      count = 0
      for (data,labels) in self.trainloader:
        data,labels = Variable(data),Variable(labels)
        data,labels = data.to(args.device),labels.to(args.device)
        local_optimizer.zero_grad()
        other_optimizer.zero_grad()
        local_outputs = self.local_model(data)
        other_outputs = self.other_model(data)
        #train local_model
        ce_loss = args.criterion_ce(local_outputs,labels)
        kl_loss = args.criterion_kl(F.log_softmax(local_outputs, dim = 1),F.softmax(Variable(other_outputs), dim=1))
        loss = ce_loss + kl_loss
        running_loss += loss.item()
        predicted = torch.argmax(local_outputs,dim=1)
        correct += (predicted==labels).sum().item()
        count += len(labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.local_model.parameters(), args.clip)
        local_optimizer.step()

        #train other_model
        ce_loss = args.criterion_ce(other_outputs,labels)
        kl_loss = args.criterion_kl(F.log_softmax(other_outputs, dim = 1),F.softmax(Variable(local_outputs), dim=1))
        loss = ce_loss + kl_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.other_model.parameters(), args.clip)
        other_optimizer.step()
        
    return 100.0*correct/count,running_loss/len(self.trainloader)

        
  def validate(self):
    acc,loss = test(self.local_model,args.criterion_ce,self.valloader)
    return acc,loss


  def model_replacement(self):
    _,loss_local = test(self.local_model,args.criterion_ce,self.valloader)
    _,loss_other = test(self.other_model,args.criterion_ce,self.valloader)
    if loss_other<loss_local:
        self.local_model_id = self.other_model_id

In [20]:
def train(model,criterion,trainloader,epochs):
  optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
  model.train()
  for epoch in range(epochs):
    running_loss = 0.0
    correct = 0
    count = 0
    for (data,labels) in trainloader:
      data,labels = Variable(data),Variable(labels)
      data,labels = data.to(args.device),labels.to(args.device)
      optimizer.zero_grad()
      outputs = model(data)
      loss = criterion(outputs,labels)
      running_loss += loss.item()
      predicted = torch.argmax(outputs,dim=1)
      correct += (predicted==labels).sum().item()
      count += len(labels)
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
      optimizer.step()

  return 100.0*correct/count,running_loss/len(trainloader)



In [21]:
def test(model,criterion,testloader):
  model.eval()
  running_loss = 0.0
  correct = 0
  count = 0
  for (data,labels) in testloader:
    data,labels = data.to(args.device),labels.to(args.device)
    outputs = model(data)
    running_loss += criterion(outputs,labels).item()
    predicted = torch.argmax(outputs,dim=1)
    correct += (predicted==labels).sum().item()
    count += len(labels)

  accuracy = 100.0*correct/count
  loss = running_loss/len(testloader)


  return accuracy,loss

In [22]:
class Early_Stopping():
  def __init__(self,partience):
    self.step = 0
    self.loss = float('inf')
    self.partience = partience

  def validate(self,loss):
    if self.loss<loss:
      self.step += 1
      if self.step>self.partience:
        return True
    else:
      self.step = 0
      self.loss = loss

    return False

In [23]:
with open('../client_best_model/client_best_model_shakespeare_42.csv') as fp:
    csvList = list(csv.reader(fp))
client_best_model = [int(item) for subList in csvList for item in subList]

In [24]:
server = Server(unlabeled_dataset)
workers = server.create_worker(federated_trainset,federated_valset,federated_testset,client_best_model)
acc_train = []
loss_train = []
acc_valid = []
loss_valid = []

early_stopping = Early_Stopping(args.partience)

start = time.time()#開始時刻

for epoch in range(args.global_epochs):
  if epoch in args.turn_of_cluster_num:
    idx = args.turn_of_cluster_num.index(epoch)
    args.cluster_num = args.cluster_list[idx]
  sample_worker = server.sample_worker(workers)
  server.collect_model(sample_worker)
  server.clustering(sample_worker)
  server.decide_other_model(sample_worker)
  server.send_model(sample_worker)

  acc_train_avg = 0.0
  loss_train_avg = 0.0
  acc_valid_avg = 0.0
  loss_valid_avg = 0.0
  for worker in sample_worker:
    acc_train_tmp,loss_train_tmp = worker.local_train()
    acc_valid_tmp,loss_valid_tmp = worker.validate()
    acc_train_avg += acc_train_tmp/len(sample_worker)
    loss_train_avg += loss_train_tmp/len(sample_worker)
    acc_valid_avg += acc_valid_tmp/len(sample_worker)
    loss_valid_avg += loss_valid_tmp/len(sample_worker)
  if epoch in args.turn_of_replacement_model:
    for worker in sample_worker:
      worker.model_replacement()
  server.aggregate_model(sample_worker)
  server.return_model(sample_worker)
  '''
  server.model.to(args.device)
  for worker in workers:
    acc_valid_tmp,loss_valid_tmp = test(server.model,args.criterion,worker.valloader)
    acc_valid_avg += acc_valid_tmp/len(workers)
    loss_valid_avg += loss_valid_tmp/len(workers)
  server.model.to('cpu')
  '''
  print('Epoch{}  loss:{}  accuracy:{}'.format(epoch+1,loss_valid_avg,acc_valid_avg))
  acc_train.append(acc_train_avg)
  loss_train.append(loss_train_avg)
  acc_valid.append(acc_valid_avg)
  loss_valid.append(loss_valid_avg)

  if early_stopping.validate(loss_valid_avg):
    print('Early Stop')
    break
    
end = time.time()#終了時刻

Epoch1  loss:4.198258861899376  accuracy:10.651870824872297
Epoch2  loss:3.983953463037808  accuracy:14.60677448747968
Epoch3  loss:3.6042719238334238  accuracy:15.56038421250699
Epoch4  loss:3.3906404561466648  accuracy:15.83442143055214
Epoch5  loss:3.2383825984266066  accuracy:16.300496756722662
Epoch6  loss:3.142006513145234  accuracy:17.441770613295063
Epoch7  loss:3.08594811823633  accuracy:18.91312855055702
Epoch8  loss:3.0313406692610845  accuracy:20.111726739634488
Epoch9  loss:2.9843761026859283  accuracy:21.261069394706865
Epoch10  loss:2.955580001738337  accuracy:22.16170537770375
Epoch11  loss:2.9240746759706067  accuracy:21.35003284756231
Epoch12  loss:2.8586263298988337  accuracy:23.329559869580272
Epoch13  loss:2.8257503413491776  accuracy:24.177537814651917
Epoch14  loss:2.8007778465747837  accuracy:25.115764823963303
Epoch15  loss:2.7458615336153245  accuracy:27.32232453615374
Epoch16  loss:2.7278159297174884  accuracy:25.581527947044666
Epoch17  loss:2.70722918079959



Epoch51  loss:2.2434962915049663  accuracy:34.457767671141546
Epoch52  loss:2.245163856281174  accuracy:33.72202237568606
Epoch53  loss:2.2416701871487827  accuracy:33.193129908532114
Epoch54  loss:2.243698896798823  accuracy:32.76735209685523
Epoch55  loss:2.2499532027377023  accuracy:33.13213596553535
Epoch56  loss:2.2323292300105093  accuracy:33.62153202095396
Epoch57  loss:2.2358024014367  accuracy:33.31124371994409
Epoch58  loss:2.2343278924624124  accuracy:34.00073498238374
Epoch59  loss:2.220729173388746  accuracy:33.844746650154654
Epoch60  loss:2.2065521775020494  accuracy:33.33467048542896
Epoch61  loss:2.1964914313620993  accuracy:33.68698167350901
Epoch62  loss:2.1947868580619496  accuracy:33.315349843704766
Epoch63  loss:2.1934416815638547  accuracy:34.25414464444166
Epoch64  loss:2.193582003149721  accuracy:34.78852015926791
Epoch65  loss:2.191911637451914  accuracy:34.06386726516609
Epoch66  loss:2.1908940802017844  accuracy:34.240483985142774
Epoch67  loss:2.18565761066

  ret, rcount, out=ret, casting='unsafe', subok=False)


Epoch69  loss:2.1823716019590695  accuracy:34.9162722787044
Epoch70  loss:2.1806348691384  accuracy:35.244986729257306
Epoch71  loss:2.1775920857985813  accuracy:34.667028242400995
Epoch72  loss:2.1819047002328764  accuracy:35.34386350932925
Epoch73  loss:2.175524635116259  accuracy:35.34365956611227
Epoch74  loss:2.172938975857364  accuracy:34.49066964509605
Epoch75  loss:2.1717261882291896  accuracy:34.32120260185827
Epoch76  loss:2.1865069578091307  accuracy:34.10645893685997
Epoch77  loss:2.1807282881604304  accuracy:34.89359211120683
Epoch78  loss:2.185183694958687  accuracy:34.02240935520897
Epoch79  loss:2.185211420390342  accuracy:33.84983132342728
Epoch80  loss:2.1769041723675198  accuracy:33.81795612601263
Epoch81  loss:2.1744099178247978  accuracy:35.0802141123586
Epoch82  loss:2.1723133357034787  accuracy:34.71955526644721
Epoch83  loss:2.1767233891619573  accuracy:34.30402885843345
Epoch84  loss:2.1664575151271293  accuracy:34.49093806612127
Epoch85  loss:2.168232717282241

In [25]:
print('学習時間：{}秒'.format(end-start))#終了時刻-開始時刻でかかった時間

学習時間：14694.664063692093秒


In [26]:
acc_test = []
loss_test = []

start = time.time()#開始時刻

for i,worker in enumerate(workers):
  worker.local_model = worker.local_model.to(args.device)
  acc_tmp,loss_tmp = test(worker.local_model,args.criterion_ce,worker.testloader)
  acc_test.append(acc_tmp)
  loss_test.append(loss_tmp)
  print('Worker{} accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
  worker.local_model = worker.local_model.to('cpu')

end = time.time()#終了時刻

acc_test_avg = sum(acc_test)/len(acc_test)
loss_test_avg = sum(loss_test)/len(loss_test)
print('Test  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 accuracy:44.511388684790596  loss:1.9563690764563424
Worker2 accuracy:42.61168384879725  loss:2.114173173904419
Worker3 accuracy:43.18181818181818  loss:2.130749464035034
Worker4 accuracy:50.0  loss:2.2862415313720703
Worker5 accuracy:53.125  loss:1.8954030275344849
Worker6 accuracy:52.38095238095238  loss:2.5655641555786133
Worker7 accuracy:40.0  loss:1.4525091648101807
Worker8 accuracy:54.13223140495868  loss:1.798909068107605
Worker9 accuracy:66.66666666666667  loss:1.634102463722229
Worker10 accuracy:38.104838709677416  loss:2.1451563835144043
Worker11 accuracy:40.18691588785047  loss:2.0083608627319336
Worker12 accuracy:45.0  loss:1.7421486377716064
Worker13 accuracy:43.25366229760987  loss:1.9714645544687908
Worker14 accuracy:37.125748502994014  loss:2.209486961364746
Worker15 accuracy:50.0  loss:1.705851674079895
Worker16 accuracy:43.90934844192635  loss:1.9525646269321442
Worker17 accuracy:39.743589743589745  loss:2.620882749557495
Worker18 accuracy:43.24324324324324  l

In [27]:
print('推論時間：{}秒'.format(end-start))#終了時刻-開始時刻でかかった時間

推論時間：4.832446813583374秒


In [31]:
acc_tune_test = []
loss_tune_test = []
acc_tune_valid = []
loss_tune_valid = []

start = time.time()#開始時刻

for i,worker in enumerate(workers):
    worker.local_model = worker.local_model.to(args.device)
    _,_ = train(worker.local_model,args.criterion_ce,worker.trainloader,args.local_epochs)
    acc_tmp,loss_tmp = test(worker.local_model,args.criterion_ce,worker.valloader)
    acc_tune_valid.append(acc_tmp)
    loss_tune_valid.append(loss_tmp)
    print('Worker{} Valid accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
    
    acc_tmp,loss_tmp = test(worker.local_model,args.criterion_ce,worker.testloader)
    acc_tune_test.append(acc_tmp)
    loss_tune_test.append(loss_tmp)
    print('Worker{} Test accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
    worker.local_model = worker.local_model.to('cpu')

end = time.time()#終了時刻

acc_valid_avg = sum(acc_tune_valid)/len(acc_tune_valid)
loss_valid_avg = sum(loss_tune_valid)/len(loss_tune_valid)
print('Validation(tune)  loss:{}  accuracy:{}'.format(loss_valid_avg,acc_valid_avg))
acc_test_avg = sum(acc_tune_test)/len(acc_tune_test)
loss_test_avg = sum(loss_tune_test)/len(loss_tune_test)
print('Test(tune)  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 Valid accuracy:44.237599510104104  loss:1.978326506084866
Worker1 Test accuracy:43.27700220426157  loss:1.9968429122652327
Worker2 Valid accuracy:40.11461318051576  loss:2.127464532852173
Worker2 Test accuracy:41.92439862542955  loss:2.1207661628723145
Worker3 Valid accuracy:33.9622641509434  loss:2.1820123195648193
Worker3 Test accuracy:43.18181818181818  loss:2.1478054523468018
Worker4 Valid accuracy:36.70886075949367  loss:2.243539810180664
Worker4 Test accuracy:45.45454545454545  loss:2.2879741191864014
Worker5 Valid accuracy:39.473684210526315  loss:2.4351000785827637
Worker5 Test accuracy:50.0  loss:1.878347396850586
Worker6 Valid accuracy:41.666666666666664  loss:1.9529699087142944
Worker6 Test accuracy:52.38095238095238  loss:2.563145399093628
Worker7 Valid accuracy:0.0  loss:2.6686911582946777
Worker7 Test accuracy:40.0  loss:1.4974935054779053
Worker8 Valid accuracy:48.45360824742268  loss:1.8956695795059204
Worker8 Test accuracy:52.06611570247934  loss:1.816692352294

In [32]:
print('学習＋推論時間：{}秒'.format(end-start))#終了時刻-開始時刻でかかった時間

学習＋推論時間：75.32590579986572秒
