In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import os
import random
from torch.autograd import Variable
import copy
from torch import nn, optim
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from collections import OrderedDict
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import csv
import time


In [2]:
def fix_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    # Python standard library
    random.seed(seed)
    # NumPy
    np.random.seed(seed)
    # PyTorch (CPU generator and every CUDA device)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN: request deterministic kernels and switch off the autotuner,
    # which may otherwise select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 42
fix_seed(SEED)

In [3]:
class Argments():
    """Container for the experiment's hyper-parameters and runtime settings.

    A single instance (``args``) is created below and read as a global by the
    dataset, model, training and evaluation code in this notebook.
    """

    def __init__(self):
        # Mini-batch size of the training loaders.
        self.batch_size = 10
        # Mini-batch size of the validation/test loaders.
        self.test_batch = 1000
        # Maximum number of epochs for the centralised (global) training run.
        self.global_epochs = 300
        # Number of epochs each worker trains locally.
        self.local_epochs = 2
        # Learning rate; left unset here and filled in from `lr_list` later.
        self.lr = None
        # SGD momentum and L2 weight decay.
        self.momentum = 0.9
        self.weight_decay = 10**-4.0
        # Maximum gradient norm passed to clip_grad_norm_.
        self.clip = 20.0
        # Early-stopping patience (attribute name kept as-is for compatibility).
        self.partience = 10
        # Number of workers taking part, and how many are sampled from them.
        self.worker_num = 20
        self.sample_num = 20
        # First CUDA device when available, CPU otherwise.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # Loss used for both training and evaluation.
        self.criterion = nn.CrossEntropyLoss()


args = Argments()

In [4]:
# Tuned value: this is an index into `lr_list` (not the learning rate itself);
# the selected entry is what gets stored in `args.lr`.
lr = 4

In [6]:
# Candidate learning rates: powers of ten for exponents -3.0 up to 0.5,
# spaced half a decade apart.
lr_list = [10**exponent for exponent in (-3.0, -2.5, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5)]

# `lr` holds the position of the tuned value inside the grid above.
args.lr = lr_list[lr]

In [9]:
class LocalDataset(torch.utils.data.Dataset):
    """In-memory copy of one worker's samples.

    Every item of ``dataset`` is read as a pair whose two members are
    themselves indexable; element ``[0]`` of the first member is stored as the
    input and element ``[0]`` of the second member as the target.

    Args:
        dataset: Indexable source supporting ``len()`` and ``dataset[i]``.
        worker_id: Identifier stored on the instance as ``self.id``.
    """

    def __init__(self, dataset, worker_id):
        self.data = []
        self.target = []
        self.id = worker_id
        for i in range(len(dataset)):
            # Fetch each sample exactly once: halves the __getitem__ cost and
            # keeps input and target from the same draw if the source is
            # stochastic.
            sample = dataset[i]
            self.data.append(sample[0][0])
            self.target.append(sample[1][0])

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair stored at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

In [10]:
# Load the pre-partitioned Shakespeare train and test collections from disk.
# Presumably each holds one entry per worker (they are indexed by worker id
# when workers are sampled) — confirm against the script that produced them.
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only load these files if they come from a trusted source.
with open('../data/federated_trainset_shakespeare.pickle', 'rb') as f:
    all_federated_trainset = pickle.load(f)
with open('../data/federated_testset_shakespeare.pickle', 'rb') as f:
    all_federated_testset = pickle.load(f)
# Number of entries in the training collection; used as the sampling pool size.
all_worker_num = len(all_federated_trainset)

In [11]:
# Draw `args.worker_num` distinct indices from the full pool using Python's
# `random` module (the generator that fix_seed seeds).
worker_id_list = random.sample(range(all_worker_num),args.worker_num)
print(worker_id_list)
# Keep only the sampled workers' train and test sets, in sampled order, so
# position k in both lists refers to the same worker.
federated_trainset = []
federated_testset = []
for i in worker_id_list:
    federated_trainset.append(all_federated_trainset[i])
    federated_testset.append(all_federated_testset[i])

[28, 6, 70, 62, 57, 35, 26, 139, 22, 108, 8, 7, 23, 55, 59, 129, 50, 107, 56, 114]


In [12]:
# Carve a validation set out of each worker's training data (70% train / 30% validation).
# NOTE(review): this overwrites federated_trainset[i] in place, so running the
# cell a second time splits the already-reduced training subset again.
federated_valset = [None]*args.worker_num
for i in range(args.worker_num):
  n_samples = len(federated_trainset[i])
  if n_samples==1:
    # A single sample cannot be split (int(1 * 0.7) == 0 would leave the
    # training set empty), so the same sample serves as both train and validation.
    federated_valset[i] = copy.deepcopy(federated_trainset[i])
  else:
    # Training share is rounded down; validation receives the remainder.
    train_size = int(len(federated_trainset[i]) * 0.7) 
    val_size = n_samples - train_size 
    # random_split is called without an explicit generator, so it draws from
    # torch's global RNG.
    federated_trainset[i],federated_valset[i] = torch.utils.data.random_split(federated_trainset[i], [train_size, val_size])

In [13]:
class GlobalDataset(torch.utils.data.Dataset):
    """Flat dataset obtained by concatenating several per-worker datasets.

    Args:
        federated_dataset: Iterable of datasets, each yielding
            ``(data, target)`` pairs when iterated.
    """

    def __init__(self, federated_dataset):
        inputs, labels = [], []
        # Walk the worker datasets in order so the global index follows
        # worker order, then sample order within each worker.
        for worker_dataset in federated_dataset:
            for sample, label in worker_dataset:
                inputs.append(sample)
                labels.append(label)
        self.data = inputs
        self.target = labels

    def __len__(self):
        """Total number of samples across all workers."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(data, target)`` pair at position ``index``."""
        return (self.data[index], self.target[index])

In [14]:
# Pool the sampled workers' splits into one flat dataset per split; these
# feed the centralised (non-federated) training run.
global_trainset = GlobalDataset(federated_trainset)
global_valset = GlobalDataset(federated_valset)
global_testset =  GlobalDataset(federated_testset)

In [15]:
# Loaders over the pooled data: only the training loader shuffles; validation
# and test use the larger evaluation batch size and keep a fixed order.
global_trainloader = torch.utils.data.DataLoader(global_trainset,batch_size=args.batch_size,shuffle=True,num_workers=2)
global_valloader = torch.utils.data.DataLoader(global_valset,batch_size=args.test_batch,shuffle=False,num_workers=2)
global_testloader = torch.utils.data.DataLoader(global_testset,batch_size=args.test_batch,shuffle=False,num_workers=2)

In [17]:
class _CharLSTM(nn.Module):
    """Shared embedding -> LSTM -> linear model behind RNN1..RNN4.

    Subclasses differ only in ``NUM_LAYERS``, the number of stacked LSTM
    layers.

    Args:
        embedding_dim: Size of each embedding vector.
        vocab_size: Number of embedding rows and number of output logits.
        hidden_size: LSTM hidden state size.
    """

    # Number of stacked LSTM layers; overridden by each concrete class.
    NUM_LAYERS = 1

    def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256):
        super().__init__()
        # Sub-modules are created in the order embeddings, lstm, fc; the order
        # matters because parameter initialisation consumes torch's RNG.
        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=self.NUM_LAYERS, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_seq):
        """Return one row of ``vocab_size`` logits per input sequence.

        Args:
            input_seq: Batch-first tensor of embedding indices
                (``batch_first=True``), presumably shaped (batch, seq_len) —
                confirm against the data pipeline.
        """
        embeds = self.embeddings(input_seq)
        # Note that the order of mini-batches is random, so there is no hidden
        # relationship among batches. The previous batch's hidden state is
        # therefore not fed in; the initial state is left at its zero default.
        lstm_out, _ = self.lstm(embeds)
        # Use the output at the last time step as the next-character prediction.
        final_hidden_state = lstm_out[:, -1]
        output = self.fc(final_hidden_state)
        return output


class RNN1(_CharLSTM):
    """Character model with a single LSTM layer."""
    NUM_LAYERS = 1


class RNN2(_CharLSTM):
    """Character model with two stacked LSTM layers."""
    NUM_LAYERS = 2


class RNN3(_CharLSTM):
    """Character model with three stacked LSTM layers."""
    NUM_LAYERS = 3


class RNN4(_CharLSTM):
    """Character model with four stacked LSTM layers."""
    NUM_LAYERS = 4

In [18]:
class Server():
    """Coordinator that owns the shared model and drives the workers."""

    def __init__(self):
        # The shared model starts out as a freshly initialised RNN2.
        self.model = RNN2()

    def create_worker(self, federated_trainset, federated_valset, federated_testset):
        """Build one Worker per index in ``range(args.worker_num)``.

        Each worker receives the train/validation/test datasets found at its
        index in the three given sequences.
        """
        return [
            Worker(federated_trainset[idx], federated_valset[idx], federated_testset[idx])
            for idx in range(args.worker_num)
        ]

    def sample_worker(self, workers):
        """Return ``args.sample_num`` workers chosen at random without replacement."""
        chosen_indices = random.sample(range(args.worker_num), args.sample_num)
        return [workers[idx] for idx in chosen_indices]

    def send_model(self, workers):
        """Give every worker a copy of the shared model and its aggregation weight.

        A worker's weight is its share of the training samples held by the
        given workers; the model copy is moved to ``args.device``.
        """
        total_samples = sum(worker.train_data_num for worker in workers)

        for worker in workers:
            worker.aggregation_weight = 1.0*worker.train_data_num/total_samples
            worker.model = copy.deepcopy(self.model).to(args.device)

    def aggregate_model(self, workers):
        """Load the weighted sum of the workers' state dicts into the shared model.

        Every state-dict entry is scaled by the worker's
        ``aggregation_weight`` and accumulated. Each worker's model is moved
        to the CPU and its ``model`` attribute deleted once it has been read.
        """
        averaged = OrderedDict()
        for position, worker in enumerate(workers):
            weight = worker.aggregation_weight
            for name, tensor in worker.model.state_dict().items():
                contribution = tensor*weight
                if position == 0:
                    # The first worker creates the accumulator entries.
                    averaged[name] = contribution
                else:
                    averaged[name] += contribution
            # Release the worker's copy now that it has been folded in.
            worker.model = worker.model.to('cpu')
            del worker.model
        self.model.load_state_dict(averaged)

In [19]:
class Worker():
    """A federated client holding its own train/validation/test loaders."""

    def __init__(self, trainset, valset, testset):
        make_loader = torch.utils.data.DataLoader
        # Only the training loader shuffles; the evaluation loaders keep a
        # fixed order and use the larger evaluation batch size.
        self.trainloader = make_loader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)
        self.valloader = make_loader(valset, batch_size=args.test_batch, shuffle=False, num_workers=2)
        self.testloader = make_loader(testset, batch_size=args.test_batch, shuffle=False, num_workers=2)
        # Sample counts, used elsewhere to compute aggregation weights.
        self.train_data_num = len(trainset)
        self.test_data_num = len(testset)
        # Both are assigned from outside before the worker is used.
        self.model = None
        self.aggregation_weight = None

    def local_train(self):
        """Train the worker's model locally, then score it on the validation split.

        Returns:
            ``(acc_train, loss_train, acc_valid, loss_valid)``.
        """
        # The module-level local_train() function does the optimisation; this
        # method only supplies the worker's own model and loaders.
        train_acc, train_loss = local_train(self.model, args.criterion, self.trainloader, args.local_epochs)
        valid_acc, valid_loss = test(self.model, args.criterion, self.valloader)
        return (train_acc, train_loss, valid_acc, valid_loss)

    

In [20]:
def local_train(model,criterion,trainloader,epochs):
    """Train ``model`` with SGD for ``epochs`` passes over ``trainloader``.

    The optimiser is built from ``args`` (lr, momentum, weight_decay),
    gradients are clipped to ``args.clip`` and batches are moved to
    ``args.device``.

    Args:
        model: Network to optimise in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        trainloader: Iterable of ``(data, labels)`` batches.
        epochs: Number of passes over ``trainloader``.

    Returns:
        ``(accuracy_percent, mean_batch_loss)`` measured during the last
        epoch; ``(0.0, 0.0)`` when no batch was processed.
    """
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
    model.train()
    # Defined up front so that epochs == 0 still returns well-defined metrics.
    running_loss = 0.0
    correct = 0
    count = 0
    for epoch in range(epochs):
        # Metrics restart every epoch; only the final epoch is reported.
        running_loss = 0.0
        correct = 0
        count = 0
        for (data,labels) in trainloader:
            # Tensors are used directly; the legacy Variable wrapper is not needed.
            data,labels = data.to(args.device),labels.to(args.device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs,labels)
            running_loss += loss.item()
            predicted = torch.argmax(outputs,dim=1)
            correct += (predicted==labels).sum().item()
            count += len(labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()

    # max(..., 1) avoids a ZeroDivisionError when the loader yielded nothing.
    return 100.0*correct/max(count,1),running_loss/max(len(trainloader),1)

In [21]:
def global_train(model,criterion,trainloader,valloader,epochs,partience=0,early_stop=False):
    """Train ``model`` with SGD and validate after every epoch.

    The optimiser is built from ``args`` (lr, momentum, weight_decay),
    gradients are clipped to ``args.clip`` and batches are moved to
    ``args.device``. One line of validation metrics is printed per epoch.

    Args:
        model: Network to optimise in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        trainloader: Iterable of ``(data, labels)`` training batches.
        valloader: Iterable of ``(data, labels)`` validation batches.
        epochs: Maximum number of epochs.
        partience: Patience handed to ``Early_Stopping`` when ``early_stop`` is set.
        early_stop: Whether to stop once the validation loss stops improving.

    Returns:
        ``(acc_train, loss_train, acc_valid, loss_valid)``: four lists with
        one entry per completed epoch (accuracy in percent, mean batch loss).
    """
    if early_stop:
        early_stopping = Early_Stopping(partience)

    acc_train = []
    loss_train = []
    acc_valid = []
    loss_valid = []
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum,weight_decay=args.weight_decay)
    for epoch in range(epochs):
        # ---- training pass ----
        running_loss = 0.0
        correct = 0
        count = 0
        model.train()
        for (data,labels) in trainloader:
            count += len(labels)
            # Tensors are used directly; the legacy Variable wrapper is not needed.
            data,labels = data.to(args.device),labels.to(args.device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs,labels)
            running_loss += loss.item()
            predicted = torch.argmax(outputs,dim=1)
            correct += (predicted==labels).sum().item()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
        acc_train.append(100.0*correct/count)
        loss_train.append(running_loss/len(trainloader))

        # ---- validation pass ----
        running_loss = 0.0
        correct = 0
        count = 0
        model.eval()
        # No gradients are needed for validation; disabling autograd avoids
        # building a graph for the large evaluation batches.
        with torch.no_grad():
            for (data,labels) in valloader:
                count += len(labels)
                data,labels = data.to(args.device),labels.to(args.device)
                outputs = model(data)
                loss = criterion(outputs,labels)
                running_loss += loss.item()
                predicted = torch.argmax(outputs,dim=1)
                correct += (predicted==labels).sum().item()

        print('Epoch:{}  accuracy:{}  loss:{}'.format(epoch+1,100.0*correct/count,running_loss/len(valloader)))
        acc_valid.append(100.0*correct/count)
        loss_valid.append(running_loss/len(valloader))
        if early_stop:
            # The summed (not averaged) validation loss is what is monitored.
            if early_stopping.validate(running_loss):
                print('Early Stop')
                return acc_train,loss_train,acc_valid,loss_valid

    return acc_train,loss_train,acc_valid,loss_valid

In [22]:
def test(model,criterion,testloader):
    """Evaluate ``model`` on ``testloader``.

    The model is switched to eval mode (and left there) and batches are moved
    to ``args.device``.

    Args:
        model: Network to evaluate.
        criterion: Loss called as ``criterion(outputs, labels)``.
        testloader: Iterable of ``(data, labels)`` batches.

    Returns:
        ``(accuracy, loss)``: accuracy in percent over all samples and the
        loss averaged over batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    count = 0
    # Evaluation needs no gradients; disabling autograd saves memory and time.
    with torch.no_grad():
        for (data,labels) in testloader:
            data,labels = data.to(args.device),labels.to(args.device)
            outputs = model(data)
            running_loss += criterion(outputs,labels).item()
            predicted = torch.argmax(outputs,dim=1)
            correct += (predicted==labels).sum().item()
            count += len(labels)

    accuracy = 100.0*correct/count
    loss = running_loss/len(testloader)

    return accuracy,loss

In [23]:
class Early_Stopping():
    """Patience counter driven by a monitored loss value.

    Args:
        partience: Number of consecutive non-improving checks tolerated
            before ``validate`` starts returning True.
    """

    def __init__(self, partience):
        self.partience = partience
        # Reference loss; starts at +inf so the first value always replaces it.
        self.loss = float('inf')
        # Consecutive checks whose loss was above the reference loss.
        self.step = 0

    def validate(self, loss):
        """Record ``loss`` and report whether training should stop.

        A loss strictly greater than the stored one increments the counter;
        any other value resets the counter and becomes the new reference.

        Returns:
            True once the counter exceeds ``partience``, otherwise False.
        """
        got_worse = self.loss < loss
        if not got_worse:
            self.step, self.loss = 0, loss
            return False

        self.step += 1
        return True if self.step > self.partience else False

In [24]:
# Centralised baseline: train a fresh two-layer-LSTM model on the pooled data
# of all sampled workers, stopping early on the validation loss.
model = RNN2()
model = model.to(args.device)

# start/end bracket the wall-clock time of the whole training run.
start = time.time()
acc_train,loss_train,acc_valid,loss_valid = global_train(model,args.criterion,global_trainloader,global_valloader,args.global_epochs,partience=args.partience,early_stop=True)
end = time.time()

Epoch:1  accuracy:18.73873873873874  loss:3.131506359577179
Epoch:2  accuracy:18.73873873873874  loss:3.1234628677368166
Epoch:3  accuracy:18.73873873873874  loss:3.1091001510620115
Epoch:4  accuracy:18.60860860860861  loss:3.068432021141052
Epoch:5  accuracy:20.31031031031031  loss:2.9974594116210938
Epoch:6  accuracy:22.06206206206206  loss:2.9100290060043337
Epoch:7  accuracy:23.943943943943943  loss:2.829341995716095
Epoch:8  accuracy:24.104104104104103  loss:2.789800024032593
Epoch:9  accuracy:23.953953953953953  loss:2.7204760909080505
Epoch:10  accuracy:26.426426426426428  loss:2.640570282936096
Epoch:11  accuracy:29.13913913913914  loss:2.580618166923523
Epoch:12  accuracy:29.954954954954953  loss:2.5127779603004456
Epoch:13  accuracy:31.066066066066067  loss:2.4634976029396056
Epoch:14  accuracy:32.56256256256256  loss:2.4282159209251404
Epoch:15  accuracy:33.548548548548546  loss:2.400114691257477
Epoch:16  accuracy:34.51951951951952  loss:2.357797622680664
Epoch:17  accuracy

In [26]:
# Build the server and one Worker per sampled client.
server = Server()
workers = server.create_worker(federated_trainset,federated_valset,federated_testset)
# Replace the server's freshly initialised RNN2 with the centrally trained
# model; the same object is shared, not copied.
server.model = model

In [27]:
# Evaluate the shared (centrally trained) model on each worker's own test set.
acc_test = []
loss_test = []

server.model.to(args.device)

# Total number of test samples over all workers.
nums = 0
for worker in workers:
  nums += worker.test_data_num

start = time.time()

for i,worker in enumerate(workers):
  # Each worker's share of the test samples.
  # NOTE(review): this weight is stored but not used by the averages printed
  # at the end of this cell, which are plain unweighted means — confirm that
  # is intended.
  worker.aggregation_weight = 1.0*worker.test_data_num/nums
  acc_tmp,loss_tmp = test(server.model,args.criterion,worker.testloader)
  acc_test.append(acc_tmp)
  loss_test.append(loss_tmp)
  print('Worker{} accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))

end = time.time()

# Unweighted mean over workers (every worker counts equally, whatever its size).
acc_test_avg = sum(acc_test)/len(acc_test)
loss_test_avg = sum(loss_test)/len(loss_test)
print('Test  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 accuracy:46.80382072005878  loss:1.9149193934031896
Worker2 accuracy:45.36082474226804  loss:2.0038139820098877
Worker3 accuracy:31.818181818181817  loss:2.385687828063965
Worker4 accuracy:46.96969696969697  loss:1.9381753206253052
Worker5 accuracy:62.5  loss:1.4323594570159912
Worker6 accuracy:38.095238095238095  loss:3.1111762523651123
Worker7 accuracy:40.0  loss:1.5672590732574463
Worker8 accuracy:51.65289256198347  loss:1.7215585708618164
Worker9 accuracy:66.66666666666667  loss:1.166546106338501
Worker10 accuracy:42.54032258064516  loss:2.0178048610687256
Worker11 accuracy:46.728971962616825  loss:1.8134396076202393
Worker12 accuracy:42.5  loss:1.5625320672988892
Worker13 accuracy:46.87740940632228  loss:1.8970635732014973
Worker14 accuracy:38.622754491017965  loss:2.2313292026519775
Worker15 accuracy:50.0  loss:1.5804443359375
Worker16 accuracy:47.15140069247718  loss:1.8578182756900787
Worker17 accuracy:41.02564102564103  loss:2.7087905406951904
Worker18 accuracy:51.3513

In [29]:
# Per-worker fine-tuning: every worker starts from a copy of the shared model,
# trains it on its own data, and is then scored on its validation and test sets.
acc_tune_test = []
loss_tune_test = []
acc_tune_valid = []
loss_tune_valid = []

start = time.time()

for i,worker in enumerate(workers):
    # Deep copy so that fine-tuning one worker never alters server.model.
    worker.model = copy.deepcopy(server.model)
    worker.model = worker.model.to(args.device)
    # local_train() returns (acc_train, loss_train, acc_valid, loss_valid);
    # only the validation metrics are kept here.
    _,_,acc_tmp,loss_tmp = worker.local_train()
    acc_tune_valid.append(acc_tmp)
    loss_tune_valid.append(loss_tmp)
    print('Worker{} Valid accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
    
    acc_tmp,loss_tmp = test(worker.model,args.criterion,worker.testloader)
    acc_tune_test.append(acc_tmp)
    loss_tune_test.append(loss_tmp)
    print('Worker{} Test accuracy:{}  loss:{}'.format(i+1,acc_tmp,loss_tmp))
    # Release the fine-tuned copy before moving on to the next worker.
    worker.model = worker.model.to('cpu')
    del worker.model

end = time.time()

# Unweighted means over workers for the fine-tuned models.
acc_valid_avg = sum(acc_tune_valid)/len(acc_tune_valid)
loss_valid_avg = sum(loss_tune_valid)/len(loss_tune_valid)
print('Validation(tune)  loss:{}  accuracy:{}'.format(loss_valid_avg,acc_valid_avg))
acc_test_avg = sum(acc_tune_test)/len(acc_tune_test)
loss_test_avg = sum(loss_tune_test)/len(loss_tune_test)
print('Test(tune)  loss:{}  accuracy:{}'.format(loss_test_avg,acc_test_avg))

Worker1 Valid accuracy:46.95652173913044  loss:1.9003532462649875
Worker1 Test accuracy:45.9808963997061  loss:1.9440261806760515
Worker2 Valid accuracy:43.26647564469914  loss:1.9634466171264648
Worker2 Test accuracy:44.67353951890034  loss:1.9227676391601562
Worker3 Valid accuracy:37.735849056603776  loss:2.1340293884277344
Worker3 Test accuracy:31.818181818181817  loss:2.458251714706421
Worker4 Valid accuracy:53.164556962025316  loss:1.6840870380401611
Worker4 Test accuracy:51.515151515151516  loss:1.970503807067871
Worker5 Valid accuracy:44.73684210526316  loss:2.3607890605926514
Worker5 Test accuracy:65.625  loss:1.4286117553710938
Worker6 Valid accuracy:33.333333333333336  loss:2.0992465019226074
Worker6 Test accuracy:38.095238095238095  loss:3.0134758949279785
Worker7 Valid accuracy:0.0  loss:3.5820915699005127
Worker7 Test accuracy:40.0  loss:1.5882651805877686
Worker8 Valid accuracy:50.171821305841924  loss:1.8159635066986084
Worker8 Test accuracy:51.65289256198347  loss:1.697