## Model definition

In [1]:
## Model.py
## Example of Model
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
    """Two-conv-layer CNN for single-channel 28x28 images with 10 output classes.

    forward() returns log-probabilities (log_softmax over dim 1).
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)   # 1x28x28 -> 32x24x24, pooled to 32x12x12
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)  # 32x12x12 -> 64x8x8, pooled to 64x4x4
        # NOTE(review): defined but never applied in forward(); kept so the module layout is unchanged.
        self.conv2_drop = nn.Dropout2d()
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(1024, 2048)  # 1024 = 64 channels * 4 * 4 for a 28x28 input
        self.fc2 = nn.Linear(2048, 10)

    def forward(self, x):
        """Map an (N, 1, 28, 28) batch to (N, 10) log-probabilities."""
        x = F.relu(self.pool(self.conv1(x)))
        x = F.relu(self.pool(self.conv2(x)))
        # flatten(1) keeps the batch dimension. view(-1, 1024) would silently regroup
        # samples for a mis-sized input instead of failing with a shape error.
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def cnn():
    """Factory: return a freshly initialised CNN."""
    return CNN()


## DataLoader

In [2]:
## DataLoader.py
import torch
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
import numpy as np
import random

class loader(object):
    """Build MNIST DataLoaders for simulated federated-learning clients.

    Args:
        _data: dataset name; stored on the instance, but MNIST is always loaded.
        batch_size: batch size of the loaders returned by get_loader().
        type: "NON_IID" (a client receives every shard of 5 randomly chosen
            digit classes) or "IID" (a client receives every shard of every class).
    """

    # One index list per label; MNIST labels are the digits 0-9.
    NUM_CLASSES = 10

    def __init__(self, _data='mnist', batch_size=64, type="NON_IID"):
        self.batch_size = batch_size
        self.type = type
        self._data = _data

        self.__load_dataset()
        self.__get_index()

    def __load_dataset(self):
        """Download (if missing) and load the MNIST train/test splits under ./dataset/."""
        # (0.1307,), (0.3081,) are the commonly used MNIST mean / std.
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        self.train_mnist = datasets.MNIST('./dataset/', train=True, download=True,
                                          transform=mnist_transform)
        self.test_mnist = datasets.MNIST('./dataset/', train=False, download=True,
                                         transform=mnist_transform)

    def __get_index(self):
        """Expose the datasets and group training-set indices by label.

        Afterwards self.indices[c] holds, in ascending order, the indices of
        the training samples whose label is c.
        """
        self.train_dataset = self.train_mnist
        self.test_dataset = self.test_mnist

        # Read labels from `.targets` instead of iterating the dataset, which would
        # decode and transform every image just to obtain its label.
        labels = torch.as_tensor(self.train_dataset.targets).tolist()
        self.indices = [[] for _ in range(self.NUM_CLASSES)]
        for index, label in enumerate(labels):
            self.indices[label].append(index)

    def _class_shards(self, shards_per_class):
        """Split each class's indices into `shards_per_class` equal contiguous shards.

        Returns a list whose entry c is the list of shards (index lists) of class c.
        The shard size is computed per class, so no class is truncated or padded by
        another class's sample count; at most shards_per_class - 1 trailing samples
        of each class are left out.
        """
        shards = []
        for class_indices in self.indices:
            size = len(class_indices) // shards_per_class
            shards.append([class_indices[k * size:(k + 1) * size]
                           for k in range(shards_per_class)])
        return shards

    def _make_loaders(self, shards):
        """Return (shuffled train loader over the union of `shards`, test loader)."""
        subsets = [torch.utils.data.Subset(self.train_dataset, shard) for shard in shards]
        train_loader = DataLoader(torch.utils.data.ConcatDataset(subsets),
                                  batch_size=self.batch_size, shuffle=True)
        # Accuracy does not depend on sample order, so the test set is not shuffled.
        test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)
        return train_loader, test_loader

    def get_loader(self, rank):
        """Return (train_loader, test_loader) for one client.

        Args:
            rank: list-like. When non-empty, int(rank[0]) seeds the NON_IID class
                choice, so equal seeds always give equal data; when empty (or None)
                the choice is unseeded.

        Raises:
            ValueError: if self.type is neither "NON_IID" nor "IID".
        """
        seed = int(rank[0]) if rank else None

        if self.type == "NON_IID":
            # 2 shards per class; the client gets both shards of 5 distinct classes.
            class_shards = self._class_shards(shards_per_class=2)
            # A local generator seeded with `seed` makes the seed actually control
            # the selection without touching any global RNG state.
            rng = np.random.default_rng(seed)
            selected_classes = rng.choice(self.NUM_CLASSES, size=5, replace=False).tolist()
            selected_shards = [shard for class_idx in selected_classes
                               for shard in class_shards[class_idx]]
        elif self.type == "IID":
            # 5 shards per class; the client gets every shard of every class.
            # NOTE(review): every IID client therefore sees (nearly) the whole
            # training set - the data is not partitioned across clients.
            class_shards = self._class_shards(shards_per_class=5)
            selected_shards = [shard for shards in class_shards for shard in shards]
        else:
            raise ValueError(
                f"unknown split type {self.type!r}; expected 'NON_IID' or 'IID'")

        return self._make_loaders(selected_shards)
# Smoke test: build the loader with its default settings (MNIST, batch 64, NON_IID)
# and fetch the loaders for the seed list [1]; the tuple is displayed as cell output.
loader1 = loader(_data='mnist', batch_size=64, type="NON_IID")
loader1.get_loader([1])


(<torch.utils.data.dataloader.DataLoader at 0x1f30ba13910>,
 <torch.utils.data.dataloader.DataLoader at 0x1f30ba13940>)

## Client

In [3]:
## Client.py
import torch

from torch import nn
from torch.autograd import Variable
class Client(object):
    """A federated-learning client that trains a copy of its edge server's model on local data."""

    def __init__(self, rank, data_loader, local_epoch, ES):
        """
        Args:
            rank: client index; also offsets the torch seed.
            data_loader: (train_loader, test_loader) pair.
            local_epoch: number of train() calls made by each run().
            ES: the edge server this client reports to; must provide
                .device, .model, .clients, .count and .size.
        """
        # Fixed offset plus rank; note this reseeds torch's global RNG for every client created.
        seed = 19201077 + 19950920 + rank
        torch.manual_seed(seed)
        self.accuracy = []
        self.rank = rank
        self.local_epoch = local_epoch
        self.ES = ES
        self.device = ES.device
        self.test_loader = data_loader[1]
        # Keep the DataLoader itself so the batch iterator can be restarted when exhausted.
        self._train_data_loader = data_loader[0]
        # Iterator over training batches; each train() call consumes one batch from it.
        self.train_loader = iter(self._train_data_loader)


    ## Option
    # def test(self, model):
    #     test_correct = 0
    #     
    #     # eval mode on
    #     model.eval()
    #     device = self.device
    
    #     with torch.no_grad():
    #         for data, target in self.test_loader:
    #             data, target = Variable(data).to(device),Variable(target).to(device)
    #             output = model(data)
    #             pred = output.argmax(dim=1, keepdim=True)
    #             test_correct += pred.eq(target.view_as(pred)).sum().item()
    
    #     return test_correct / len(self.test_loader.dataset)

    def _next_batch(self):
        """Return the next (data, target) batch, starting a new pass when the current one ends.

        Returns None only if the training loader yields no batches at all.
        """
        try:
            return next(self.train_loader)
        except StopIteration:
            self.train_loader = iter(self._train_data_loader)
        return next(self.train_loader, None)

    def train(self, model):
        """Take one SGD step (lr=1e-2) on the next local batch and return model.state_dict()."""
        model.train()
        device = self.device

        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
        batch = self._next_batch()
        if batch is not None:
            data, target = batch
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            # CNN.forward already returns log-probabilities; CrossEntropyLoss re-applies
            # log_softmax, which leaves log-probabilities unchanged, so this equals the NLL loss.
            self.loss = nn.CrossEntropyLoss()(output, target)
            self.loss.backward()
            optimizer.step()

        return model.state_dict()

    def load_global_model(self):
        '''
        Return a fresh CNN on the ES device holding a copy of the ES model's weights
        (ES's w -> model's w).
        '''
        model = CNN().to(self.ES.device)
        model.load_state_dict(self.ES.model.state_dict())
        return model

    def run(self):
        """Train `local_epoch` steps from the ES's current model and store the result in the ES's next slot."""
        model = self.load_global_model()
        for _ in range(self.local_epoch):
            self.train(model=model)
        # Read the weights after the loop so local_epoch == 0 still uploads (unchanged) weights.
        weights = model.state_dict()

        self.ES.clients[self.ES.count % self.ES.size] = weights  # slot in the ES's clients list
        self.ES.count += 1  # number of client uploads this ES has received
        


## Edge Server

In [4]:
## Edge_Server.py
import torch
from torch import optim
from torch import nn
from torch.autograd import Variable
#import Client
class ES():
    """Edge server: averages the weights uploaded by its clients and can forward the result to a cloud server."""

    def __init__(self, size, data_loader, device, CS=None):
        self.size = size                      # number of client slots averaged per round
        self.test_loader = data_loader[1]
        self.accuracy = []                    # test accuracy recorded by each aggregate()
        self.clients = [None]*size            # latest uploaded state_dict per slot
        self.count = 0                        # uploads received; next slot is count % size
        self.model = CNN().to(device)
        self.device = device
        self.CS = CS    # parent cloud server, or None when this object is the cloud server

    def average_weights(self, clients):
        """Return the element-wise mean of the given state_dicts as a new dict.

        The input dicts and their tensors are not modified. This matters because
        uploaded state_dicts may share storage with live model parameters.

        Raises:
            ValueError: if `clients` is empty or contains an empty (None) slot.
        """
        if not clients:
            raise ValueError("average_weights needs at least one state_dict")
        if any(state is None for state in clients):
            raise ValueError("cannot average: some client slots have not been filled")
        n = len(clients)
        return {key: sum(state[key] for state in clients) / n for key in clients[0]}

    def aggregate(self):
        """Average the client weights into self.model, record its test accuracy and report it.

        An edge server (self.CS set) also uploads the averaged model to its cloud server.
        """
        weights = self.average_weights(self.clients)
        self.model.load_state_dict(weights)
        test_accuracy = self.test()
        self.accuracy.append(test_accuracy)
        if self.CS:
            # Edge server: up-propagate ES's w -> CS's w.
            self.upload_global_model()
            print('\n[ES Model]  Test Accuracy: {:.2f}%\n'.format(test_accuracy * 100.))
        else:
            # Cloud server: there is no parent, so only report.
            print("Cloud Server process : ",end="")
            print('\n**** [CS Model]  Test Accuracy: {:.2f}% ****\n'.format(test_accuracy * 100.))

    def global_weight(self):
        """Return this server's current state_dict."""
        return self.model.state_dict()

    def upload_global_model(self):
        """Store a detached copy of this model's weights in the cloud server's next client slot."""
        # state_dict() tensors alias the live parameters; copy them so the cloud
        # server's averaging can never reach back into this model.
        snapshot = {key: value.detach().clone() for key, value in self.model.state_dict().items()}
        self.CS.clients[self.CS.count % self.CS.size] = snapshot
        self.CS.count += 1

    def load_global_model(self):
        '''
        Copy the cloud server's weights into this server's model (CS's w -> ES's w).
        '''
        self.model.load_state_dict(self.CS.model.state_dict())

    def test(self):
        """Return the fraction of test_loader samples whose argmax prediction matches the target."""
        self.model.eval()
        correct = 0
        with torch.no_grad():
            for data, target in self.test_loader:
                # Plain tensors suffice; Variable wrapping has been a no-op since PyTorch 0.4.
                data, target = data.to(self.device), target.to(self.device)
                pred = self.model(data).argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        return correct / len(self.test_loader.dataset)
    


## Cloud Server

In [5]:
##Cloud_server.py
# import Edge_Server

class CS(ES):
    """Cloud server: an ES without a parent whose client slots receive the edge servers' uploads."""

    def __init__(self, size, data_loader, device):
        # A cloud server has no parent server, so its CS stays at the default None.
        super(CS, self).__init__(size, data_loader, device, CS=None)
        # `ESs` names the very same list object as `clients` (an alias, not a copy).
        self.ESs = self.clients

## Main

In [6]:
## main.py
import random
def fed_AVG(n_client, n_ES, ES_epoch, epoch, batch_size, device = "cuda" ,type = "IID"):
    '''Run hierarchical federated averaging: clients -> edge servers -> cloud server.

    n_client: number of clients per 1 Edge Server
    n_ES: number of edge servers
    ES_epoch: number of global rounds (each round: clients, ES aggregation, CS aggregation)
    epoch: local_epoch given to every Client (train() calls per round)
    batch_size: DataLoader batch size
    device: torch device string used for every model
    type: data split passed to `loader` ("IID" or "NON_IID")
    '''
    print('Initialize Dataset...')
    data_loader = loader('mnist', batch_size=batch_size, type = type)
    print('Initialize Edge Servers and Clients...')
    CS_1 = CS(size = n_ES, data_loader = data_loader.get_loader([]), device = device)

    ESs = []
    clients = []
    for es_idx in range(n_ES):
        edge = ES(size = n_client, data_loader = data_loader.get_loader([]), device = device, CS=CS_1)
        ESs.append(edge)
        group = []
        for client_idx in range(n_client):
            # get_loader only reads the first element of this list (as a seed);
            # it does not restrict the client to these 4 classes.
            seed_list = random.sample(range(0, 10), 4)
            client_loaders = data_loader.get_loader(seed_list)
            group.append(Client(rank=client_idx, data_loader=client_loaders,
                                local_epoch = epoch, ES = edge))
        clients.append(group)

    # federated learning
    for round_idx in range(ES_epoch):
        print('\n================== Edge Server Epoch {:>3} =================='.format(round_idx + 1))
        # every edge server trains its clients, then averages their weights
        for es_idx, (edge, group) in enumerate(zip(ESs, clients)):
            print("Edge Server :", es_idx, "process : ", end="")
            for client in group:
                client.run()
            edge.aggregate()
        # the cloud server averages the edge-server models ...
        CS_1.aggregate()
        # ... and pushes the result back down (CS's w -> ES's w)
        for edge in ESs:
            edge.load_global_model()


## Implementation Shell

In [8]:
# 10 clients per edge server, 5 edge servers, 30 global rounds, local_epoch=20, batch 64, NON_IID split.
fed_AVG(n_client=10, n_ES=5, ES_epoch=30, epoch=20, batch_size=64, device="cuda", type="NON_IID")

Initialize Dataset...
Initialize Edge Servers and Clients...

Edge Server : 0 process : 
[ES Model]  Test Accuracy: 10.11%

Edge Server : 1 process : 
[ES Model]  Test Accuracy: 15.11%

Edge Server : 2 process : 
[ES Model]  Test Accuracy: 23.22%

Edge Server : 3 process : 
[ES Model]  Test Accuracy: 29.42%

Edge Server : 4 process : 
[ES Model]  Test Accuracy: 19.40%

Cloud Server process : 
[CS Model]  Test Accuracy: 27.15%


Edge Server : 0 process : 
[ES Model]  Test Accuracy: 12.29%

Edge Server : 1 process : 
[ES Model]  Test Accuracy: 18.55%

Edge Server : 2 process : 
[ES Model]  Test Accuracy: 23.23%

Edge Server : 3 process : 
[ES Model]  Test Accuracy: 36.72%

Edge Server : 4 process : 
[ES Model]  Test Accuracy: 20.05%

Cloud Server process : 
[CS Model]  Test Accuracy: 33.39%


Edge Server : 0 process : 
[ES Model]  Test Accuracy: 21.88%

Edge Server : 1 process : 
[ES Model]  Test Accuracy: 31.07%

Edge Server : 2 process : 
[ES Model]  Test Accuracy: 42.11%

Edge Server 

KeyboardInterrupt: 